mirror of https://github.com/commaai/tinygrad.git
Non fp32 math (#2264)
* `global_load` and `global_store` using buffer dtype * `UOps.PHI` in all dtypes * `UOps.ALU` in all dtypes * `UOps.CONST` & `UOps.DEFINE_ACC` in all dtypes * -- endof implementation -- +tiny lint changes * these tests require the fp16 extention you can run them locally to confirm they're green: (GPT2 test is broken in master for mac, see [this](https://discord.com/channels/1068976834382925865/1069001075828469790/1177993277958533261) `GPU=1 python3 -m pytest test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_dequantizelinear_e4m3fn_float16_cpu test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_max_float16_cpu test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_min_float16_cpu test/models/test_real_world.py::TestRealWorld::test_llama test/models/test_real_world.py::TestRealWorld::test_gpt2 test/models/test_whisper.py test/test_specific_conv.py::TestSpecific::test_big_vec_mul` skip the new test_linearizer_failures in CI GPU because of the fp16 extention This passes on a real GPU since the extention is available: `GPU=1 python3 -m pytest test/test_linearizer_failures.py::TestLinearizerFailures::test_failure_8` see CI logs [here](https://github.com/tinygrad/tinygrad/actions/runs/6996590597/job/19032641427#step:14:644) * these tests fail in CI due to segfaults and CPU crashes To confirm they're green locally, you can run the following commands: 1. For the tests skipped in test_ops.py (note: CLANG is very slow) `for var in GPU CUDA CLANG; do export $var=1; for test in test/test_ops.py::TestOps::test_slice_fancy_indexing_no_dim_collapse test/test_ops.py::TestOps::test_slice_fancy_indexing_dim_collapse_int test/test_ops.py::TestOps::test_slice_fancy_indexing_dim_inject_none test/test_ops.py::TestOps::test_slice_fancy_indexing_dim_inject_and_collapse; do python3 -m pytest $test; done; unset $var; done` 2. For the ONNX tests skipped in CLANG: ``` CLANG=1 python3 -m pytest test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_ai_onnx_ml_array_feature_extractor_cpu \ test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_gather_elements_0_cpu \ test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_nllloss_NCd1_expanded_cpu \ test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_sce_mean_weight_ii_3d_cpu \ test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_gather_elements_1_cpu \ test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_sce_NCd1_mean_weight_negative_ii_cpu \ test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_nllloss_NCd1_weight_expanded_cpu \ test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_nllloss_NCd1d2d3_none_no_weight_negative_ii_expanded_cpu \ test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_nllloss_NCd1_ii_expanded_cpu \ test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_sce_mean_weight_ii_4d_cpu \ test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_sce_mean_weight_ii_3d_log_prob_cpu \ test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_gather_elements_negative_indices_cpu \ test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_sce_NCd1d2d3d4d5_mean_weight_log_prob_cpu \ test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_sce_NCd1_mean_weight_negative_ii_log_prob_cpu \ test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_nllloss_NCd1d2_no_weight_reduction_mean_ii_expanded_cpu \ test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_sce_NCd1d2d3d4d5_mean_weight_cpu \ test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_nllloss_NCd1d2d3d4d5_mean_weight_expanded_cpu \ test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_nllloss_NCd1_mean_weight_negative_ii_expanded_cpu \ test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_sce_mean_weight_ii_4d_log_prob_cpu \ test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_nllloss_NCd1d2_with_weight_reduction_mean_expanded_cpu \ test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_nllloss_NCd1_weight_ii_expanded_cpu \ test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_nllloss_NCd1d2_with_weight_reduction_sum_ii_expanded_cpu \ test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_nllloss_NCd1d2_with_weight_reduction_sum_expanded_cpu \ test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_nllloss_NCd1d2_expanded_cpu \ test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_nllloss_NCd1d2_reduction_sum_expanded_cpu \ test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_nllloss_NCd1d2d3d4d5_none_no_weight_expanded_cpu \ test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_nllloss_NCd1d2d3_sum_weight_high_ii_expanded_cpu \ test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_nllloss_NCd1d2_reduction_mean_expanded_cpu \ test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_nllloss_NCd1d2_with_weight_expanded_cpu ``` 3. The LLVM test I skipped here is already [skipped in master for all backends](https://github.com/tinygrad/tinygrad/blob/master/test/external/external_test_onnx_backend.py#L186), I just made it more specific `LLVM=1 python3 -m pytest test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_dequantizelinear_e4m3fn_float16_cpu` * Revert "these tests fail in CI due to segfaults and CPU crashes" This reverts commit 15db57014381a4449d563526ac6c870e36257658. * merge with cleanup-vectorized-hip-renders * barely working HIP P1, ALU ops need a refactor? * manage the fact that in HIP [half2 is actually an unsigned int vec](f921880387/hip/include/hip/amd_detail/amd_hip_fp16.h (L59)
) and half is a totally different __half that [has an unsigned int element in it](f921880387/hip/include/hip/amd_detail/amd_hip_fp16.h (L50)
) but can't be accessed [because it's private](f921880387/hip/include/hip/amd_detail/amd_hip_fp16.h (L86)
). If you just do this: ``` half2 val0 = // ... half val1 = // ... ``` then you can't do: ``` val0.x + val1 // error: use of overloaded operator '+' is ambiguous (with operand types 'unsigned short' and 'half' (aka '__half')) ``` * update the sign definition to avoid division by zero in all dtypes * diff cleanup p1: why were these in the diff anyways * less hacky HIP, enable CIFAR fp16 benchmark, test ops for HIP in CI! add ALU ops overloads for HIP this will make HIP max work handle mod Revert "handle mod" This reverts commit 370fd4b3fbe99b6ae8cc293d005b106628205933. update max to use hmax add HIP GEP render logic enable CIFAR fp16 benchmark test ops for HIP back to store as float because this only works for float4 grouping right now test_ops for hip!! always sign * back to the sign we had before because we cant do a backward pass on a Less node * remove old hacks HIP compiling test_ops in CI takes ~9 mins, not doing it for now new HIP ALUs * reduce accs done right * refactor to function * no device hacks hacks p2 the other way * LLVM ALU ops half, float and double are all float update max * update test_uops, cmplt is always a bool in the real linearizer. assertAlmostEqual is wrong when ret is bool * cleanup LLVM wrong code * dummy change for the CUDA install glitch --------- Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
This commit is contained in:
parent
1ac958a058
commit
4380ccb169
|
@ -128,8 +128,8 @@ jobs:
|
|||
run: STEPS=10 python3 examples/hlb_cifar10.py | tee train_cifar.txt
|
||||
- name: Run 10 CIFAR training steps w winograd
|
||||
run: WINO=1 STEPS=10 python3 examples/hlb_cifar10.py | tee train_cifar_wino.txt
|
||||
#- name: Run 10 CIFAR training steps w WINO/HALF/HIP
|
||||
# run: HALF=1 HIP=1 WINO=1 STEPS=10 python3 examples/hlb_cifar10.py | tee train_cifar_wino_half_hip.txt
|
||||
- name: Run 10 CIFAR training steps w WINO/HALF/HIP
|
||||
run: HALF=1 HIP=1 WINO=1 STEPS=10 python3 examples/hlb_cifar10.py | tee train_cifar_wino_half_hip.txt
|
||||
- uses: actions/upload-artifact@v3
|
||||
with:
|
||||
name: Speed (AMD)
|
||||
|
|
|
@ -29,15 +29,15 @@ def get_max(var):
|
|||
def remove_single_scalar_curly_braces(ptx_code):
|
||||
return '\n'.join([re.sub(r'\{\s*(%\w+)\s*\}', r'\1', line) for line in ptx_code.split('\n')])
|
||||
|
||||
def render_const(args):
|
||||
return (('-' if args<0 else '') + 'tl.where(1,float("inf"),0)') if math.isinf(args) else ('tl.where(1,float("nan"),0)' if math.isnan(args) else str(args))
|
||||
def render_const(args,dtype:DType):
|
||||
return (('-' if args<0 else '') + 'tl.where(1,float("inf"),0)') if math.isinf(args) else ('tl.where(1,float("nan"),0)' if math.isnan(args) else f"{int(args)}" if dtypes.is_int(dtype) else str(args))
|
||||
|
||||
def render_cast(x:str, dtype:DType):
|
||||
return f"{x}.to({triton_dtypes[dtype]})"
|
||||
|
||||
def define_scalar(local_size, dtype, args):
|
||||
if len(local_size) > 0: return f"tl.full(({','.join([str(next_power_of_2(x)) for x in local_size])},),{render_const(args)}, dtype={triton_dtypes[dtype]})"
|
||||
return render_const(args)
|
||||
if len(local_size) > 0: return f"tl.full(({','.join([str(next_power_of_2(x)) for x in local_size])},),{render_const(args,dtype)}, dtype={triton_dtypes[dtype]})"
|
||||
return render_const(args,dtype)
|
||||
|
||||
def uops_to_triton(function_name:str, uops:List[UOp]):
|
||||
local_size: List[int] = []
|
||||
|
|
|
@ -177,8 +177,8 @@ if Device.DEFAULT in ['GPU', 'METAL']:
|
|||
backend_test.exclude('test_mish_expanded_cpu') # weird inaccuracy
|
||||
backend_test.exclude('test_eyelike_with_dtype_cpu') # backend does not support dtype: Double
|
||||
|
||||
# Segfaults in CI
|
||||
if Device.DEFAULT in ['LLVM', 'CUDA'] and CI:
|
||||
# Segfaults in CI, GPU requires cl_khr_fp16
|
||||
if Device.DEFAULT in ['LLVM', 'CUDA', 'GPU'] and CI:
|
||||
backend_test.exclude('test_max_float16_cpu')
|
||||
backend_test.exclude('test_min_float16_cpu')
|
||||
|
||||
|
|
|
@ -51,7 +51,7 @@ class TestRealWorld(unittest.TestCase):
|
|||
helper_test("test_sd", lambda: (Tensor.randn(1, 4, 64, 64),Tensor.randn(1, 77, 768)), test, 18.0, 953)
|
||||
|
||||
@unittest.skipIf(Device.DEFAULT == "LLVM", "LLVM segmentation fault")
|
||||
@unittest.skipIf(Device.DEFAULT == "LLVM" and CI, "too long on CI LLVM")
|
||||
@unittest.skipIf(Device.DEFAULT in ["LLVM", "GPU"] and CI, "too long on CI LLVM, GPU requires cl_khr_fp1")
|
||||
def test_llama(self):
|
||||
Tensor.default_type = dtypes.float16
|
||||
|
||||
|
@ -63,7 +63,7 @@ class TestRealWorld(unittest.TestCase):
|
|||
# TODO: test first token vs rest properly, also memory test is broken with CacheCollector
|
||||
helper_test("test_llama", lambda: (Tensor([[1,2,3,4]]),), test, 0.22 if CI else 13.5, 181 if CI else 685, all_jitted=True)
|
||||
|
||||
@unittest.skipIf(Device.DEFAULT == "LLVM" and CI, "too long on CI LLVM")
|
||||
@unittest.skipIf(Device.DEFAULT in ["LLVM", "GPU"] and CI, "too long on CI LLVM, GPU requires cl_khr_fp16")
|
||||
def test_gpt2(self):
|
||||
Tensor.default_type = dtypes.float16
|
||||
|
||||
|
@ -102,4 +102,4 @@ class TestRealWorld(unittest.TestCase):
|
|||
#Device.DEFAULT = old_default
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
unittest.main()
|
|
@ -15,7 +15,7 @@ TRANSCRIPTION_2 = "a slightly longer audio file so that we can test batch transc
|
|||
TEST_FILE_3_URL = 'https://homepage.ntu.edu.tw/~karchung/miniconversations/mc45.mp3'
|
||||
TRANSCRIPTION_3 = "Just lie back and relax. Is the level of pressure about right? Yes, it's fine, and I'd like conditioner please. Sure. I'm going to start the second lathering now. Would you like some Q-tips? How'd you like it cut? I'd like my bangs and the back trimmed, and I'd like the rest thinned out a bit and layered. Where would you like the part? On the left, right about here. Here, have a look. What do you think? It's fine. Here's a thousand anti-dollars. It's 30-ant extra for the rants. Here's your change and receipt. Thank you, and please come again. So how do you like it? It could have been worse, but you'll notice that I didn't ask her for her card. Hmm, yeah. Maybe you can try that place over there next time."
|
||||
|
||||
@unittest.skipIf(CI and Device.DEFAULT in ["LLVM", "CLANG", "CPU"], "Not working on LLVM, slow on others")
|
||||
@unittest.skipIf(CI and Device.DEFAULT in ["LLVM", "CLANG", "CPU", "GPU"], "Not working on LLVM, slow on others. GPU reequires cl_khr_fp16")
|
||||
class TestWhisper(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
|
|
|
@ -22,7 +22,6 @@ class TestHIPCompilationRDNA(unittest.TestCase):
|
|||
output = model(input)
|
||||
output.numpy()
|
||||
|
||||
@unittest.expectedFailure
|
||||
def test_compile_hip_speedyresnet_hf(self):
|
||||
Tensor.default_type = dtypes.float16
|
||||
|
||||
|
@ -34,4 +33,4 @@ class TestHIPCompilationRDNA(unittest.TestCase):
|
|||
output.numpy()
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
unittest.main()
|
||||
|
|
|
@ -79,7 +79,7 @@ class TestLinearizerFailures(unittest.TestCase):
|
|||
ast = helper_add_store(ast)
|
||||
helper_test_lin(Linearizer(ast), opts, failed_platforms=["LLVM"])
|
||||
|
||||
@unittest.skipIf(Device.DEFAULT=="LLVM" and not OSX, "Segmentation fault on ubuntu")
|
||||
@unittest.skipIf((Device.DEFAULT=="LLVM" and not OSX) or (Device.DEFAULT == "GPU" and CI), "Segmentation fault on ubuntu, GPU requires cl_khr_fp16")
|
||||
def test_failure_8(self):
|
||||
ast = LazyOp(op=UnaryOps.SQRT, src=(LazyOp(op=BinaryOps.DIV, src=(LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=1.0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 1, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=True),)))), LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(1, 1, 4096), strides=(0, 0, 1), offset=0, mask=None, contiguous=True),)))), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 1, 4096), strides=(0, 0, 1), offset=0, mask=None, contiguous=True),))))), arg=None), LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(1, 1, 4096), strides=(0, 0, 1), offset=0, mask=None, contiguous=True),)))), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 1, 4096), strides=(0, 0, 1), offset=0, mask=None, contiguous=True),))))), arg=None)), arg=None),), arg=(1, 1, 1)), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.000244140625, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 1, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=True),))))), arg=None), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=1e-06, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 1, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=True),))))), arg=None)), arg=None),), arg=None)
|
||||
opts = [Opt(op=OptOps.UNROLL, axis=0, amt=4), Opt(op=OptOps.UNROLL, axis=0, amt=4), Opt(op=OptOps.UNROLL, axis=0, amt=4), Opt(op=OptOps.UNROLL, axis=0, amt=4)]
|
||||
|
|
|
@ -20,7 +20,7 @@ class TestSpecific(unittest.TestCase):
|
|||
w = Tensor.randn(2048, 512)
|
||||
(x @ w).reshape(1, 128, 4).contiguous().realize()
|
||||
|
||||
@unittest.skipIf(Device.DEFAULT in ["LLVM", "WEBGPU", "CUDA"], "Broken on LLVM, WEBGPU and CUDA")
|
||||
@unittest.skipIf(Device.DEFAULT in ["LLVM", "WEBGPU", "GPU", "CUDA"], "Broken on LLVM and webgpu, GPU requires cl_khr_fp16")
|
||||
def test_big_vec_mul(self):
|
||||
# from LLaMA
|
||||
# 0 buffer<4096, dtypes.float> [View((1024, 1, 1, 4), (4, 0, 0, 1), 0, None)]
|
||||
|
@ -52,4 +52,4 @@ class TestSpecific(unittest.TestCase):
|
|||
x.conv2d(w, stride=2, padding=1).permute(0,2,3,1).reshape(18, 18*384//4, 4).contiguous().realize()
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
unittest.main()
|
||||
|
|
|
@ -14,7 +14,7 @@ def _uops_to_prg(uops):
|
|||
runtime_args=runtime_args).build(Device[Device.DEFAULT].compiler, Device[Device.DEFAULT].runtime)
|
||||
|
||||
def uop(uops:List[UOp], uop:UOps, dtype:Optional[DType], vin:Tuple[UOp, ...], arg:Any=None) -> UOp:
|
||||
uops.append(UOp(uop, dtype, tuple(vin), arg))
|
||||
uops.append(UOp(uop, dtype if arg != BinaryOps.CMPLT else dtypes.bool, tuple(vin), arg))
|
||||
return uops[-1]
|
||||
|
||||
def _test_single_value(vals, op, dtype):
|
||||
|
@ -43,7 +43,7 @@ def _test_single_value_const(vals, op, dtype):
|
|||
|
||||
class TestUOps(unittest.TestCase):
|
||||
def _equal(self, v1, v2):
|
||||
if not (math.isnan(v1) and math.isnan(v2)): self.assertAlmostEqual(v1, v2, places=5)
|
||||
if not (math.isnan(v1) and math.isnan(v2)): self.assertAlmostEqual(v1, v2, places=5) if v1.dtype != np.bool_ else self.assertEqual(v1, v2)
|
||||
|
||||
def _test_uop_fxn(self, bop, fxn, dt=dtypes.float32):
|
||||
for f in [_test_single_value, _test_single_value_const]:
|
||||
|
@ -78,7 +78,7 @@ class TestFloatUOps(TestUOps):
|
|||
def test_mul(self): self._test_bop_fxn(BinaryOps.MUL, lambda a,b: a*b)
|
||||
def test_div(self): self._test_bop_fxn(BinaryOps.DIV, lambda a,b: a/b if b != 0 else a*float('inf'))
|
||||
def test_max(self): self._test_bop_fxn(BinaryOps.MAX, lambda a,b: max(a,b))
|
||||
def test_cmplt(self): self._test_bop_fxn(BinaryOps.CMPLT, lambda a,b: float(a<b))
|
||||
def test_cmplt(self): self._test_bop_fxn(BinaryOps.CMPLT, lambda a,b: a<b)
|
||||
# MOD isn't tested on floats
|
||||
|
||||
def test_mulacc(self): self._test_top_fxn(TernaryOps.MULACC, lambda a,b,c: (a*b)+c)
|
||||
|
|
|
@ -46,6 +46,11 @@ class Linearizer(Kernel):
|
|||
|
||||
# NOTE: the consts have to be cached for deduping of downstream uops to work
|
||||
def const(self, b:Union[int,float], dtype=dtypes.int32, insert_before=None) -> UOp: return self.uop(UOps.CONST, dtype, tuple(), b, insert_before=insert_before)
|
||||
def cast(self, val: UOp, dtype) -> UOp: return self.uop(UOps.CAST, dtype, (val,)) if val.dtype != dtype else val
|
||||
|
||||
def get_reduce_acc(self, op, dtype:DType):
|
||||
if op == ReduceOps.SUM: return 0.0 if dtypes.is_float(dtype) else 0
|
||||
elif op == ReduceOps.MAX: return -math.inf if dtypes.is_float(dtype) else -2**31 if dtypes.is_int(dtype) else False
|
||||
|
||||
render_ops: Any = { Variable: lambda self, ops, ctx: ctx.loop_uops[self.expr], NumNode: lambda self, ops, ctx: ctx.const(self.b),
|
||||
MulNode: lambda self, ops, ctx: ctx.uop_alu_idx(self.a.render(ops, ctx), self.b, ops, ctx, BinaryOps.MUL),
|
||||
|
@ -74,7 +79,8 @@ class Linearizer(Kernel):
|
|||
(g_idx, g_valid), amt, dim = self.sts[i].expr_idxs(fake_idxs), 1, None
|
||||
else:
|
||||
g_idx, g_valid = self.sts[i].expr_idxs(fake_idxs)
|
||||
localtype = dtypes.float32 if amt == 1 else dtypes.float.vec(amt)
|
||||
localtype = buf.dtype if amt == 1 else buf.dtype.vec(amt)
|
||||
if isinstance(buf.dtype, ImageDType): localtype = dtypes.float if amt == 1 else dtypes.float.vec(amt)
|
||||
|
||||
e_idxs, e_valids = g_idx.expand(expand_vars), g_valid.expand(expand_vars)
|
||||
|
||||
|
@ -237,7 +243,7 @@ class Linearizer(Kernel):
|
|||
fake_reduce_idxs = [x*0 for x in reduce_idxs]
|
||||
|
||||
# define accumulator
|
||||
acc = self.global_load(0, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, {ReduceOps.SUM: 0.0, ReduceOps.MAX: -math.inf}[cast(ReduceOps, self.reduceop.op)])
|
||||
acc = self.global_load(0, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, self.get_reduce_acc(self.reduceop.op, self.bufs[0].dtype))
|
||||
|
||||
if self.tensor_core:
|
||||
def calc_tc_idxs(local_size: int, aliases: List[List[int]]):
|
||||
|
@ -343,7 +349,7 @@ class Linearizer(Kernel):
|
|||
# NOTE: this structure is the same as the reduce op above
|
||||
|
||||
# define late accumulator
|
||||
acc = self.global_load(-1, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, {ReduceOps.SUM: 0.0, ReduceOps.MAX: -math.inf}[cast(ReduceOps, self.reduceop.op)])
|
||||
acc = self.global_load(-1, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, self.get_reduce_acc(self.reduceop.op, self.bufs[-1].dtype))
|
||||
|
||||
# late reduce loop
|
||||
loop_ctx = render_loop(end_local_idxs)
|
||||
|
@ -455,8 +461,14 @@ class Linearizer(Kernel):
|
|||
self.applied_opts_cache = self.applied_opts[:]
|
||||
return self
|
||||
|
||||
def uop(self, uop:UOps, dtype:Optional[DType], vin:Tuple[UOp, ...], arg:Any=None, cachable=True, insert_before=None, simplify=True) -> UOp:
|
||||
def uop(self, uop:UOps, dtype:Optional[DType]=None, vin:Tuple[UOp, ...]=tuple(), arg:Any=None, cachable=True, insert_before=None, simplify=True) -> UOp:
|
||||
key = (uop, dtype, vin, arg)
|
||||
if uop == UOps.PHI and vin[1].dtype != dtype: vin = (vin[0], self.cast(vin[1], dtype)) + vin[1:]
|
||||
if uop == UOps.ALU: # upcast vins to the same dtype
|
||||
upcast_dtype = dtypes.float if arg == TernaryOps.MULACC else max(cast(DType, x.dtype) for x in vin) # MULACC is only supported in float
|
||||
if arg == TernaryOps.WHERE: vin = (vin[0],) + tuple(self.cast(x, upcast_dtype) for x in vin[1:]) # the first arg is always bool
|
||||
else: vin = tuple(self.cast(x, upcast_dtype) for x in vin)
|
||||
dtype = dtype or upcast_dtype # some ops like BinaryOps.CMPLT return bool
|
||||
if simplify:
|
||||
if uop == UOps.PHI and len(vin) == 2: return vin[1] # a phi without loops is a noop
|
||||
if uop == UOps.GEP and vin[0].uop == UOps.CONST: return self.const(vin[0].arg, dtype, insert_before)
|
||||
|
@ -502,11 +514,11 @@ class Linearizer(Kernel):
|
|||
ret: List[UOp] = []
|
||||
input_acc = acc[:]
|
||||
for val, off in zip(zip(*values), cast(List[int], offs)):
|
||||
acc[off] = self.uop(UOps.ALU, dtypes.float32, val+(acc[off],), ops[x.op])
|
||||
acc[off] = self.uop(UOps.ALU, vin=val+(acc[off],), arg=ops[x.op])
|
||||
ret.append(acc[off])
|
||||
for off in range(len(acc)):
|
||||
if input_acc[off] != acc[off]:
|
||||
acc[off] = self.uop(UOps.PHI, dtypes.float32, (input_acc[off], acc[off]) + tuple(loop_ctx))
|
||||
acc[off] = self.uop(UOps.PHI, input_acc[off].dtype, (input_acc[off], acc[off]) + tuple(loop_ctx))
|
||||
else:
|
||||
ret = [self.uop(UOps.ALU, dtypes.float32, val, x.op) for val in zip(*values)]
|
||||
ret = [self.uop(UOps.ALU, dtype=dtypes.bool if x.op == BinaryOps.CMPLT else None, vin=val, arg=x.op) for val in zip(*values)]
|
||||
return ret
|
||||
|
|
|
@ -47,18 +47,18 @@ class CStyleLanguage(NamedTuple):
|
|||
return f"{self.float4.replace('float4', var_dtype.name)}({','.join(x)})"
|
||||
|
||||
# returns a str expression of the const with the given type
|
||||
def render_const(self, x:Union[float,int], var_dtype) -> str:
|
||||
def render_const(self, x:Union[float,int,bool], var_dtype) -> str:
|
||||
if math.isnan(x): val = "NAN"
|
||||
elif math.isinf(x): val = ("-" if x < 0 else "") + "INFINITY"
|
||||
else: val = f"{float(x)}f" if dtypes.is_float(var_dtype) else f"{int(x)}"
|
||||
return self.render_cast([val]*var_dtype.sz, var_dtype) if var_dtype.sz > 1 else val
|
||||
else: val = f"{float(x)}f" if dtypes.is_float(var_dtype) else f"{int(x)}" if dtypes.is_int(var_dtype) else f"{bool(x)}".lower()
|
||||
return self.render_cast([val]*var_dtype.sz, var_dtype) if var_dtype.sz > 1 or var_dtype not in [dtypes.float, dtypes.int, dtypes.bool] else val
|
||||
|
||||
# returns a str expression of the loaded value with the output type
|
||||
def render_load(self, output_dtype, buf_name, buf_dtype, idx, local=False) -> str:
|
||||
if isinstance(buf_dtype, ImageDType):
|
||||
assert output_dtype == dtypes.float.vec(4), f"images must be float4, getting {output_dtype}"
|
||||
return f"read_imagef({buf_name}, smp, {idx})"
|
||||
if self.uses_vload and buf_dtype == dtypes.float16:
|
||||
if self.uses_vload and buf_dtype.scalar() == dtypes.float16 and output_dtype.scalar() != dtypes.float16:
|
||||
return f"vload_half{'' if output_dtype.sz == 1 else str(output_dtype.sz)}(0, {buf_name}+{idx})"
|
||||
if output_dtype.sz > 1:
|
||||
out_val = f"*(({self.smem_prefix if local and self.smem_prefix_for_cast else self.buffer_prefix}{buf_dtype.name}{output_dtype.sz}*)({buf_name}+{idx}))"
|
||||
|
@ -95,7 +95,7 @@ class CStyleLanguage(NamedTuple):
|
|||
if isinstance(buf_dtype, ImageDType):
|
||||
assert var_dtype == dtypes.float.vec(4), "images must be float4"
|
||||
return f"write_imagef({buf_name}, {idx}, {var_name});"
|
||||
if self.uses_vload and buf_dtype == dtypes.float16 and var_dtype != dtypes.float16:
|
||||
if self.uses_vload and buf_dtype.scalar() == dtypes.float16 and var_dtype.scalar() != dtypes.float16:
|
||||
return f"vstore_half{'' if var_dtype.sz == 1 else str(var_dtype.sz)}({var_name}, 0, {buf_name}+{idx});"
|
||||
if var_dtype.sz > 1:
|
||||
return f"*(({self.smem_prefix if local and self.smem_prefix_for_cast else self.buffer_prefix}{buf_dtype.name}{var_dtype.sz}*)({buf_name}+{idx})) = ({buf_dtype.name}{var_dtype.sz}){var_name};"
|
||||
|
@ -156,8 +156,6 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> Tu
|
|||
# remove parens if ALU types are the same. TODO: can do more here
|
||||
if vin[0].uop == UOps.ALU and vin[0].arg == args and args in {BinaryOps.ADD, BinaryOps.SUB, BinaryOps.MUL}:
|
||||
val = lang.code_for_op[args](strip_parens(r[vin[0]]), *[r[x] for x in vin[1:]], dtype)
|
||||
elif args == BinaryOps.MAX:
|
||||
val = lang.code_for_op[args](*[lang.render_cast([r[x]], dtype) if x.dtype != dtype else r[x] for x in vin] + [dtype])
|
||||
else:
|
||||
val = lang.code_for_op[args](*[r[x] for x in vin] + [dtype])
|
||||
assert child_count[u] != 0, f"childless ALU op found {u}"
|
||||
|
@ -292,10 +290,31 @@ __device__ float4 vload_half4(size_t offset, const half *p) { return make_float4
|
|||
__device__ void vstore_half(float data, size_t offset, half *p) { *(p + offset) = (half)data; }
|
||||
__device__ void vstore_half2(float2 data, size_t offset, half *p) { *(p + offset*2) = (half)data.x; *(p + offset*2 + 1) = (half)data.y; }
|
||||
__device__ void vstore_half4(float4 data, size_t offset, half *p) { *(p + offset*4) = (half)data.x; *(p + offset*4 + 1) = (half)data.y; *(p + offset*4 + 2) = (half)data.z; *(p + offset*4 + 3) = (half)data.w; }
|
||||
__device__ half exp2(half x) { return hexp2(x); }
|
||||
__device__ half log2(half x) { return hlog2(x); }
|
||||
__device__ half sin(half x) { return hsin(x); }
|
||||
__device__ half sqrt(half x) { return hsqrt(x); }
|
||||
__device__ half hmax(half a, half b) { return __hgt(a, b) ? a : b; }
|
||||
__device__ half operator%(const half &a, const half &b) { return __hsub(a, __hmul(b, __float2half(floorf(__half2float(a) / __half2float(b))))); }
|
||||
__device__ bool operator!=(const half &a, const int &b) { return (float)a != b; }
|
||||
|
||||
// HACKS for ALU ops on half and result of half2 GEP
|
||||
__device__ half operator+(const half &a, const unsigned short &b) { return __hadd(a, (half)(b)); }
|
||||
__device__ half operator-(const half &a, const unsigned short &b) { return __hsub(a, (half)(b)); }
|
||||
__device__ half operator*(const half &a, const unsigned short &b) { return __hmul(a, (half)(b)); }
|
||||
__device__ half operator/(const half &a, const unsigned short &b) { return __hdiv(a, (half)(b)); }
|
||||
__device__ bool operator<(const half &a, const unsigned short &b) { return __hlt(a, (half)(b)); }
|
||||
// now the other way
|
||||
__device__ half operator+(const unsigned short &a, const half &b) { return __hadd((half)(a), b); }
|
||||
__device__ half operator-(const unsigned short &a, const half &b) { return __hsub((half)(a), b); }
|
||||
__device__ half operator*(const unsigned short &a, const half &b) { return __hmul((half)(a), b); }
|
||||
__device__ half operator/(const unsigned short &a, const half &b) { return __hdiv((half)(a), b); }
|
||||
__device__ bool operator<(const unsigned short &a, const half &b) { return __hlt((half)(a), b); }
|
||||
"""
|
||||
gid = [f'blockIdx.{chr(120+i)}' for i in range(3)]
|
||||
lid = [f'threadIdx.{chr(120+i)}' for i in range(3)]
|
||||
xid = [f'(blockIdx.{chr(120+i)}*blockDim.{chr(120+i)}+threadIdx.{chr(120+i)})' for i in range(3)]
|
||||
code_for_op = {**CStyleLanguage().code_for_op, BinaryOps.MAX: lambda a,b,dtype: f"max({a},{b})" if dtype != dtypes.half else f"hmax({a},{b})", TernaryOps.WHERE: lambda a,b,c,dtype: f"({a}!=0?{b}:{c})" if dtype != dtypes.half else f"(half)({a}!=0?{b}:{c})"}
|
||||
HIPRenderer = functools.partial(uops_to_cstyle, HIPLanguage())
|
||||
|
||||
# TODO: how much of this can be merged with above?
|
||||
|
@ -338,9 +357,6 @@ class WGSLLanguage(CStyleLanguage):
|
|||
if self.type_map[var_dtype]: return f"{self.type_map[var_dtype]}({x[0]})"
|
||||
raise NotImplementedError(f"no cast for {var_dtype}")
|
||||
|
||||
def render_load(self, output_dtype, buf_name, buf_dtype, idx, local=False) -> str:
|
||||
return f"f32({super().render_load(output_dtype, buf_name, buf_dtype, idx, local)})"
|
||||
|
||||
def render_store(self, buf_name:str, buf_dtype:DType, var_name:str, var_dtype:DType, idx, local=False) -> str:
|
||||
return f"{buf_name}[{idx}] = {self.render_cast([var_name], buf_dtype) if var_dtype != buf_dtype else var_name};"
|
||||
WGSLRenderer = functools.partial(uops_to_cstyle, WGSLLanguage())
|
||||
|
|
|
@ -6,8 +6,9 @@ from tinygrad.ops import Op, UnaryOps, BinaryOps, TernaryOps
|
|||
|
||||
LLVM_FAST_MATH_FLAGS = ('nsz', 'arcp', 'contract', 'afn', 'reassoc') # All from fast math, but nnan and ninf
|
||||
|
||||
def is_bool(t:ir.Type): return isinstance(t, ir.IntType) and t.width == 1
|
||||
code_for_op: Final[Dict[Op, Callable]] = {
|
||||
UnaryOps.NEG: lambda builder,x: builder.neg(x) if isinstance(x.type, ir.IntType) else builder.fneg(x, flags=LLVM_FAST_MATH_FLAGS) if isinstance(x.type, ir.FloatType) else builder.not_(x),
|
||||
UnaryOps.NEG: lambda builder,x: builder.xor(x, ir.Constant(ir.IntType(1), 1)) if is_bool(x.type) else builder.neg(x) if isinstance(x.type, ir.IntType) else builder.fneg(x, flags=LLVM_FAST_MATH_FLAGS),
|
||||
UnaryOps.EXP2: lambda builder,x: builder.call(builder._block.module.declare_intrinsic('llvm.exp2', [ir.FloatType()]), [x], fastmath=LLVM_FAST_MATH_FLAGS),
|
||||
UnaryOps.LOG2: lambda builder,x: builder.call(builder._block.module.declare_intrinsic('llvm.log2', [ir.FloatType()]), [x], fastmath=LLVM_FAST_MATH_FLAGS),
|
||||
UnaryOps.SIN: lambda builder,x: builder.call(builder._block.module.declare_intrinsic('llvm.sin', [ir.FloatType()]), [x], fastmath=LLVM_FAST_MATH_FLAGS),
|
||||
|
@ -16,12 +17,12 @@ code_for_op: Final[Dict[Op, Callable]] = {
|
|||
BinaryOps.SUB: lambda builder,x,y: builder.sub(x,y) if isinstance(x.type, ir.IntType) else builder.fsub(x,y, flags=LLVM_FAST_MATH_FLAGS),
|
||||
BinaryOps.MUL: lambda builder,x,y: builder.mul(x,y) if isinstance(x.type, ir.IntType) else builder.fmul(x,y, flags=LLVM_FAST_MATH_FLAGS),
|
||||
BinaryOps.DIV: lambda builder,x,y: builder.sdiv(x,y) if isinstance(x.type, ir.IntType) else builder.fdiv(x,y, flags=LLVM_FAST_MATH_FLAGS),
|
||||
# TODO: this should be casted
|
||||
BinaryOps.CMPLT: lambda builder,x,y: builder.zext(builder.icmp_signed("<", x, y),ir.IntType(32)) if isinstance(x.type, ir.IntType) else builder.uitofp(builder.fcmp_ordered("<", x, y, flags=LLVM_FAST_MATH_FLAGS), ir.FloatType()),
|
||||
BinaryOps.MAX: lambda builder,x,y: builder.select(builder.fcmp_unordered(">", x, y, flags=LLVM_FAST_MATH_FLAGS), x, y, flags=LLVM_FAST_MATH_FLAGS) if isinstance(x.type, ir.FloatType) else builder.select(builder.icmp_signed(">", x, y), x, y),
|
||||
BinaryOps.MOD: lambda builder,x,y: builder.srem(x,y) if isinstance(x.type, ir.IntType) else builder.frem(x,y) if isinstance(x.type, ir.FloatType) else builder.urem(x,y),
|
||||
BinaryOps.CMPLT: lambda builder,x,y: builder.icmp_unsigned("<", x, y) if is_bool(x.type) else builder.icmp_signed("<", x, y) if isinstance(x.type, ir.IntType) else builder.fcmp_unordered("<", x, y, flags=LLVM_FAST_MATH_FLAGS),
|
||||
BinaryOps.MAX: lambda builder,x,y: builder.select(builder.icmp_unsigned(">", x, y) if is_bool(x.type) else builder.icmp_signed(">", x, y) if isinstance(x.type, ir.IntType) else builder.fcmp_unordered(">", x, y, flags=LLVM_FAST_MATH_FLAGS), x, y),
|
||||
BinaryOps.MOD: lambda builder,x,y: builder.urem(x,y) if is_bool(x.type) else builder.srem(x,y) if isinstance(x.type, ir.IntType) else builder.frem(x,y),
|
||||
TernaryOps.MULACC: lambda builder,x,y,z: builder.fadd(builder.fmul(x,y, flags=LLVM_FAST_MATH_FLAGS), z, flags=LLVM_FAST_MATH_FLAGS),
|
||||
TernaryOps.WHERE: lambda builder,x,y,z: builder.select(builder.fcmp_unordered("!=", x, ir.Constant(ir.FloatType(), 0), flags=LLVM_FAST_MATH_FLAGS) if isinstance(x.type, ir.FloatType) else builder.trunc(x, ir.IntType(1)), y, z),
|
||||
TernaryOps.WHERE: lambda builder,x,y,z: builder.select(builder.trunc(x, ir.IntType(1)) if isinstance(x.type, ir.IntType) else builder.fcmp_unordered("!=", x, ir.Constant(ir.FloatType(), 0), flags=LLVM_FAST_MATH_FLAGS), y, z
|
||||
),
|
||||
}
|
||||
|
||||
dtype_to_llvm_dtype = {dtypes.float64:ir.DoubleType(), dtypes.float16:ir.HalfType(), dtypes.bfloat16:ir.IntType(16), dtypes.float32:ir.FloatType(), dtypes.int8:ir.IntType(8), dtypes.uint8:ir.IntType(8), dtypes.bool: ir.IntType(1), dtypes.int64: ir.IntType(64), dtypes.int32: ir.IntType(32), dtypes._arg_int32: ir.IntType(32), dtypes.int16:ir.IntType(16), dtypes.uint16:ir.IntType(16), dtypes.uint32:ir.IntType(32), dtypes.uint64:ir.IntType(64)}
|
||||
|
@ -98,7 +99,7 @@ def uops_to_llvm_ir(function_name:str, uops:List[UOp]) -> Tuple[str, Dict]:
|
|||
phis = []
|
||||
for rp in reduce_phis:
|
||||
incoming = lvars[rp]
|
||||
lvars[rp] = bb[-1].phi(ir.FloatType())
|
||||
lvars[rp] = bb[-1].phi(dtype_to_llvm_dtype[rp.dtype])
|
||||
lvars[rp].add_incoming(incoming, bb[-2]._block)
|
||||
phis.append((rp, lvars[rp]))
|
||||
|
||||
|
@ -146,7 +147,7 @@ def uops_to_llvm_ir(function_name:str, uops:List[UOp]) -> Tuple[str, Dict]:
|
|||
with bb[-1].if_then(bb[-1].trunc(lvars[vin[3]], ir.IntType(1))): store_op()
|
||||
else: store_op()
|
||||
if uop == UOps.ALU:
|
||||
lvars[u] = cast(bb, code_for_op[args](bb[-1], *[cast(bb, lvars[x], x.dtype, dtypes.float) for x in vin]), dtypes.float, dtype)
|
||||
lvars[u] = code_for_op[args](bb[-1], *[lvars[x] for x in vin])
|
||||
if uop == UOps.CAST: lvars[u] = cast(bb, lvars[vin[0]], vin[0].dtype, dtype)
|
||||
|
||||
bb[-1].ret_void()
|
||||
|
|
|
@ -634,7 +634,7 @@ class Tensor:
|
|||
def square(self): return self*self
|
||||
def clip(self, min_, max_): return self.maximum(min_).minimum(max_)
|
||||
def abs(self): return self.relu() + (-self).relu()
|
||||
def sign(self): return self / (self.abs() + 1e-10)
|
||||
def sign(self): return ((self.float()) / (self.float().abs() + 1e-12)).cast(self.dtype)
|
||||
def reciprocal(self): return 1.0/self
|
||||
|
||||
# ***** activation functions (unary) *****
|
||||
|
|
Loading…
Reference in New Issue