mirror of https://github.com/commaai/tinygrad.git
remove HIP in core tinygrad (#3810)
* remove HIP in core tinygrad: the CI test uses the RHIP device and the HSA compiler (LinearizerOptions), so it is fine to remove HIP from the tensor core tables. Also updated the README and the EMULATE tensor core test flags.
* EMULATE_CUDA
parent a7afd2f6bf
commit 5dd048a378
@@ -35,16 +35,18 @@ jobs:
 IMAGE=2 PYTHON=1 python3 test/test_ops.py TestOps.test_simple_conv2d
 - name: Test emulated METAL tensor cores
 run: DEBUG=2 EMULATE_METAL=1 FORWARD_ONLY=1 PYTHON=1 python3 test/test_ops.py TestOps.test_big_gemm
-- name: Test emulated HIP tensor cores
+- name: Test emulated HSA tensor cores
 run: |
-PYTHONPATH=. DEBUG=2 EMULATE_HIP=1 FORWARD_ONLY=1 PYTHON=1 N=16 HALF=1 ACC_HALF=0 python3 ./extra/gemm/simple_matmul.py
-PYTHONPATH=. DEBUG=2 EMULATE_HIP=1 FORWARD_ONLY=1 PYTHON=1 N=64 HALF=1 ACC_HALF=0 python3 ./extra/gemm/simple_matmul.py
-PYTHONPATH=. DEBUG=2 EMULATE_HIP=1 FORWARD_ONLY=1 PYTHON=1 N=16 HALF=1 ACC_HALF=1 python3 ./extra/gemm/simple_matmul.py
-PYTHONPATH=. DEBUG=2 EMULATE_HIP=1 FORWARD_ONLY=1 PYTHON=1 N=64 HALF=1 ACC_HALF=1 python3 ./extra/gemm/simple_matmul.py
+PYTHONPATH=. DEBUG=2 EMULATE_HSA=1 FORWARD_ONLY=1 PYTHON=1 N=16 HALF=1 ACC_HALF=0 python3 ./extra/gemm/simple_matmul.py
+PYTHONPATH=. DEBUG=2 EMULATE_HSA=1 FORWARD_ONLY=1 PYTHON=1 N=64 HALF=1 ACC_HALF=0 python3 ./extra/gemm/simple_matmul.py
+PYTHONPATH=. DEBUG=2 EMULATE_HSA=1 FORWARD_ONLY=1 PYTHON=1 N=16 HALF=1 ACC_HALF=1 python3 ./extra/gemm/simple_matmul.py
+PYTHONPATH=. DEBUG=2 EMULATE_HSA=1 FORWARD_ONLY=1 PYTHON=1 N=64 HALF=1 ACC_HALF=1 python3 ./extra/gemm/simple_matmul.py
+- name: Test emulated CUDA tensor cores
+run: DEBUG=2 EMULATE_CUDA=1 FORWARD_ONLY=1 PYTHON=1 python3 test/test_ops.py TestOps.test_big_gemm
 - name: Full test tensor cores
 run: |
 DEBUG=2 EMULATE_METAL=1 FORWARD_ONLY=1 PYTHON=1 python3 ./test/test_linearizer.py TestLinearizer.test_tensor_cores
-DEBUG=2 EMULATE_HIP=1 FORWARD_ONLY=1 PYTHON=1 python3 ./test/test_linearizer.py TestLinearizer.test_tensor_cores
+DEBUG=2 EMULATE_HSA=1 FORWARD_ONLY=1 PYTHON=1 python3 ./test/test_linearizer.py TestLinearizer.test_tensor_cores
 DEBUG=2 EMULATE_CUDA=1 FORWARD_ONLY=1 PYTHON=1 python3 ./test/test_linearizer.py TestLinearizer.test_tensor_cores
 - name: Test dtype with Python emulator
 run: PYTHONPATH=. DEBUG=2 PYTHON=1 python3 test/test_dtype.py
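This CI hunk renames the emulated AMD tensor core job from EMULATE_HIP to EMULATE_HSA and adds an emulated CUDA tensor core test. The sketch below shows roughly what these jobs exercise, assuming the PYTHON (emulator) device and the EMULATE_* flags behave as in the workflow above; it is not the CI script itself, and the import paths are assumptions for this era of the codebase.

# Minimal sketch, assuming the PYTHON emulator backend and the EMULATE_HSA flag
# used by the workflow above. The emulator is slow; keep N small.
import os
os.environ["PYTHON"] = "1"        # select the pure-Python emulator device
os.environ["EMULATE_HSA"] = "1"   # emulate HSA (AMD WMMA) tensor cores
os.environ["FORWARD_ONLY"] = "1"

import numpy as np
from tinygrad.tensor import Tensor
from tinygrad.dtype import dtypes  # import path is an assumption for this tinygrad era

N = 16
a, b = Tensor.rand(N, N, dtype=dtypes.half), Tensor.rand(N, N, dtype=dtypes.half)
c = (a @ b).numpy()
ref = a.numpy().astype(np.float32) @ b.numpy().astype(np.float32)
np.testing.assert_allclose(c, ref, atol=1e-2, rtol=1e-2)  # loose tolerance for half inputs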
@@ -81,7 +81,7 @@ tinygrad already supports numerous accelerators, including:
 - [x] [LLVM](tinygrad/runtime/ops_llvm.py)
 - [x] [METAL](tinygrad/runtime/ops_metal.py)
 - [x] [CUDA](tinygrad/runtime/ops_cuda.py)
-- [x] [HIP](tinygrad/runtime/ops_hip.py)
+- [x] [HSA](tinygrad/runtime/ops_hsa.py)
 
 And it is easy to add more! Your accelerator of choice only needs to support a total of ~25 low level ops.
 More information can be found in the [documentation for adding new accelerators](/docs/adding_new_accelerators.md).
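In the README, the AMD entry in the accelerator list now points at the HSA runtime instead of HIP. A user-facing way to read that list is sketched below, assuming the usual Device/Tensor APIs of this era; the explicit `device=` string is illustrative and needs the corresponding runtime installed.

# Minimal sketch: backend names from the README list are valid device strings.
from tinygrad.tensor import Tensor
from tinygrad.device import Device   # module path is an assumption for this era

print(Device.DEFAULT)                # autodetected backend, e.g. "HSA", "CUDA", or "METAL"
x = Tensor.ones(4, 4, device="HSA")  # explicitly target the renamed AMD backend
print((x + 1).numpy())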
@@ -63,7 +63,7 @@ class TestLinearizerOverflow(unittest.TestCase):
 opts = [Opt(op=OptOps.UPCAST, axis=3, amt=4), Opt(op=OptOps.LOCAL, axis=3, amt=16), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.LOCAL, axis=2, amt=8), Opt(op=OptOps.UPCAST, axis=1, amt=2), Opt(op=OptOps.UPCAST, axis=2, amt=4)]
 _test_overflow(ast, opts)
 
-#@unittest.skipIf(Device.DEFAULT not in {"GPU", "HIP", "HSA", "CUDA", "METAL"}, "only backends with locals")
+#@unittest.skipIf(Device.DEFAULT not in {"GPU", "HSA", "CUDA", "METAL"}, "only backends with locals")
 @unittest.skipIf(CI, "slow")
 class TestLinearizerOverflowAlt(unittest.TestCase):
 def test_overflow_1(self):
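Here the commented-out skip list simply drops "HIP" from the backends with local memory. For context, a hand-written opts list like the one above gets applied to a Linearizer before code generation; the sketch below shows that general pattern, with the import paths and the helper name as assumptions (the real `_test_overflow` lives in this test file).

# Minimal sketch of applying hand-picked opts to a kernel, assuming the
# Linearizer/Opt APIs of this era; apply_opts_and_linearize is a made-up name.
from tinygrad.codegen.linearizer import Linearizer
from tinygrad.codegen.kernel import Opt, OptOps

def apply_opts_and_linearize(ast, opts):
  lin = Linearizer(ast)
  for opt in opts: lin.apply_opt(opt)  # e.g. Opt(op=OptOps.UPCAST, axis=3, amt=4)
  lin.linearize()                      # the overflow tests check this stays within dim limits
  return lin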
@@ -56,7 +56,7 @@ tensor_cores: Dict[str, List[TensorCore]] = {
 TensorCore(dims=[8,8,8], dtype_in=dtypes.half, dtype_out=dtypes.float, wmma_func="__metal_wmma<half2,simdgroup_float8x8,float2>", threads=[(0,2),(1,4),(0,2),(1,2)], thread_local_sizes=[[2],[2],[2]], thread_local_aliases=[ [[4],[0],[2],[0],[-1, 1, 3],[0]], [[0],[3],[0],[1],[2, 4],[-1]], [[4],[3],[2],[1],[0],[-1]] ]), # noqa: E501
 TensorCore(dims=[8,8,8], dtype_in=dtypes.half, dtype_out=dtypes.half, wmma_func="__metal_wmma<half2,simdgroup_half8x8,half2>", threads=[(0,2),(1,4),(0,2),(1,2)], thread_local_sizes=[[2],[2],[2]], thread_local_aliases=[ [[4],[0],[2],[0],[-1, 1, 3],[0]], [[0],[3],[0],[1],[2, 4],[-1]], [[4],[3],[2],[1],[0],[-1]] ]), # noqa: E501
 ],
-"HIP": [
+"HSA": [
 TensorCore(dims=[16,16,16], dtype_in=dtypes.half, dtype_out=dtypes.float, wmma_func="__builtin_amdgcn_wmma_f32_16x16x16_f16_w32", threads=[(0,16),(1,2)], thread_local_sizes=[[16],[16],[8]], thread_local_aliases=[ [[0],[0],[-1],[1]], [[0],[1],[-1],[0]], [[0],[1],[0],[2,-1]] ]), # noqa: E501
 TensorCore(dims=[16,16,16], dtype_in=dtypes.half, dtype_out=dtypes.half, wmma_func="__hip_wmma_f16_f16", threads=[(0,16),(1,2)], thread_local_sizes=[[16],[16],[8]], thread_local_aliases=[ [[0],[0],[-1],[1]], [[0],[1],[-1],[0]], [[0],[1],[0],[2,-1]] ]), # noqa: E501
 ],
@@ -64,7 +64,6 @@ tensor_cores: Dict[str, List[TensorCore]] = {
 TensorCore(dims=[8,16,16], dtype_in=dtypes.half, dtype_out=dtypes.float, wmma_func="__cuda_mma_m16n8k16_f16_f32", threads=[(0,2),(0,2),(1,2),(1,2),(0,2)], thread_local_sizes=[[2,2,2],[2,2],[2,2]], thread_local_aliases=[ [[0],[-2],[5],[0],[0],[-1,1,2,-3],[3,4]], [[5],[0],[0],[4],[3],[-1,1,2,-2],[0]], [[2],[-2],[5],[1],[-1],[0],[3,4]] ]), # noqa: E501
 ],
 }
-tensor_cores["HSA"] = tensor_cores["HIP"]
 
 class LocalBuffer(NamedTuple):
 name: str
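With "HSA" now a first-class key in the tensor core table, the alias line `tensor_cores["HSA"] = tensor_cores["HIP"]` is gone. A lookup keyed by device name then works directly, as sketched below; the import location of the table is an assumption for this era of the codebase, while the field names come from the diff above.

# Minimal sketch: filter the tensor core descriptors for a device by dtype.
from tinygrad.codegen.kernel import tensor_cores  # assumed location of the table above
from tinygrad.dtype import dtypes

device = "HSA"  # previously resolved through the removed "HIP" alias
for tc in tensor_cores.get(device, []):
  if tc.dtype_in == dtypes.half and tc.dtype_out == dtypes.float:
    print(tc.dims, tc.wmma_func)  # e.g. [16, 16, 16] __builtin_amdgcn_wmma_f32_16x16x16_f16_w32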
@@ -107,7 +107,7 @@ def beam_search(lin:Linearizer, rawbufs, amt:int, allow_test_size=True) -> Linea
 beam: List[Tuple[Linearizer, float]] = []
 seen_libs = set()
 
-default_parallel, min_progress_micros = 1 if lin.opts.device in {"CUDA", "HIP", "HSA"} else 0, getenv("BEAM_MIN_PROGRESS",0.01)
+default_parallel, min_progress_micros = 1 if lin.opts.device in {"CUDA", "HSA"} else 0, getenv("BEAM_MIN_PROGRESS",0.01)
 if beam_pool is None and getenv("PARALLEL", default_parallel): beam_pool = multiprocessing.Pool(multiprocessing.cpu_count(), _init_worker)
 
 try:
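The beam search change narrows the set of devices that enable parallel compilation by default to CUDA and HSA; the PARALLEL environment variable still overrides it. The pattern in the changed line, restated as a small helper (a sketch, not the actual function):

# Minimal sketch of the getenv(name, per-device-default) pattern from the line above.
from tinygrad.helpers import getenv

def use_parallel_beam(device: str) -> bool:
  default_parallel = 1 if device in {"CUDA", "HSA"} else 0  # "HIP" no longer listed
  return bool(getenv("PARALLEL", default_parallel))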
@@ -82,11 +82,10 @@ class LazyBuffer:
 def is_unrealized_contiguous_const(self): return self.base == self and not self.base.realized and self.op is LoadOps.CONST
 
 def _copy(self, device:str) -> LazyBuffer:
-sync_size = 1 if self.device.startswith("HIP") else 0
 if self.device.startswith("EXT") or self.device.startswith("DISK"):
 # DISK/EXT don't sync
 return create_lazybuffer(device, ShapeTracker.from_shape(self.shape), self.dtype, LoadOps.COPY, None, (self,), enable_cache=False)
-sync = LazyBuffer.loadop(LoadOps.SYNC, (sync_size,), dtypes.uint32, self.device, src=(self,), enable_cache=True)
+sync = LazyBuffer.loadop(LoadOps.SYNC, (0,), dtypes.uint32, self.device, src=(self,), enable_cache=True)
 wait = LazyBuffer.loadop(LoadOps.WAIT, (0,), dtypes.uint32, device, src=(sync,), enable_cache=True)
 return create_lazybuffer(device, ShapeTracker.from_shape(self.shape), self.dtype, LoadOps.COPY, None, (self, wait), enable_cache=False)
 
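After dropping the HIP-only `sync_size`, every cross-device copy builds a zero-sized SYNC on the source device, a WAIT on the destination, and then the COPY itself. A user-level transfer that goes through this path is sketched below; the device names are purely illustrative and depend on what runtimes are available.

# Minimal sketch: a cross-device transfer that exercises _copy above.
from tinygrad.tensor import Tensor

x = Tensor.arange(8, device="CLANG")  # source device; names here are illustrative
y = x.to("HSA")                       # schedules SYNC (source) -> WAIT (dest) -> COPY
print(y.numpy())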
@@ -188,7 +188,7 @@ class PythonProgram:
 
 class PythonCompiler(Compiler):
 linearizer_opts = LinearizerOptions("METAL", has_tensor_cores=True) if getenv("EMULATE_METAL") else \
-(LinearizerOptions("HIP", has_tensor_cores=True) if getenv("EMULATE_HIP") else \
+(LinearizerOptions("HSA", has_tensor_cores=True) if getenv("EMULATE_HSA") else \
 (LinearizerOptions("CUDA", has_tensor_cores=True) if getenv("EMULATE_CUDA") else LinearizerOptions("PYTHON")))
 def render(self, name:str, uops:UOpGraph) -> str:
 lops = [(u.uop, u.dtype, [uops.uops.index(v) for v in u.vin], u.arg) for u in uops]
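The Python emulator's `linearizer_opts` now keys off EMULATE_HSA instead of EMULATE_HIP, alongside EMULATE_METAL and EMULATE_CUDA. The same selection is restated below as a small helper for readability; this is a sketch, the import path for LinearizerOptions is an assumption, and the real code keeps the nested conditional shown in the diff.

# Minimal sketch of the EMULATE_* flag dispatch shown in the diff above.
from tinygrad.helpers import getenv
from tinygrad.codegen.kernel import LinearizerOptions  # assumed location in this era

def emulated_linearizer_opts() -> LinearizerOptions:
  for dev in ("METAL", "HSA", "CUDA"):                  # checked in the same order as the diff
    if getenv(f"EMULATE_{dev}"): return LinearizerOptions(dev, has_tensor_cores=True)
  return LinearizerOptions("PYTHON")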