diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index fbaa4c87..64912a42 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -128,8 +128,8 @@ jobs: run: STEPS=10 python3 examples/hlb_cifar10.py | tee train_cifar.txt - name: Run 10 CIFAR training steps w winograd run: WINO=1 STEPS=10 python3 examples/hlb_cifar10.py | tee train_cifar_wino.txt - #- name: Run 10 CIFAR training steps w WINO/HALF/HIP - # run: HALF=1 HIP=1 WINO=1 STEPS=10 python3 examples/hlb_cifar10.py | tee train_cifar_wino_half_hip.txt + - name: Run 10 CIFAR training steps w WINO/HALF/HIP + run: HALF=1 HIP=1 WINO=1 STEPS=10 python3 examples/hlb_cifar10.py | tee train_cifar_wino_half_hip.txt - uses: actions/upload-artifact@v3 with: name: Speed (AMD) diff --git a/extra/triton/triton.py b/extra/triton/triton.py index a8fb6e6b..32453782 100644 --- a/extra/triton/triton.py +++ b/extra/triton/triton.py @@ -29,15 +29,15 @@ def get_max(var): def remove_single_scalar_curly_braces(ptx_code): return '\n'.join([re.sub(r'\{\s*(%\w+)\s*\}', r'\1', line) for line in ptx_code.split('\n')]) -def render_const(args): - return (('-' if args<0 else '') + 'tl.where(1,float("inf"),0)') if math.isinf(args) else ('tl.where(1,float("nan"),0)' if math.isnan(args) else str(args)) +def render_const(args,dtype:DType): + return (('-' if args<0 else '') + 'tl.where(1,float("inf"),0)') if math.isinf(args) else ('tl.where(1,float("nan"),0)' if math.isnan(args) else f"{int(args)}" if dtypes.is_int(dtype) else str(args)) def render_cast(x:str, dtype:DType): return f"{x}.to({triton_dtypes[dtype]})" def define_scalar(local_size, dtype, args): - if len(local_size) > 0: return f"tl.full(({','.join([str(next_power_of_2(x)) for x in local_size])},),{render_const(args)}, dtype={triton_dtypes[dtype]})" - return render_const(args) + if len(local_size) > 0: return f"tl.full(({','.join([str(next_power_of_2(x)) for x in local_size])},),{render_const(args,dtype)}, dtype={triton_dtypes[dtype]})" + return render_const(args,dtype) def uops_to_triton(function_name:str, uops:List[UOp]): local_size: List[int] = [] diff --git a/test/external/external_test_onnx_backend.py b/test/external/external_test_onnx_backend.py index 4df86c88..c3059203 100644 --- a/test/external/external_test_onnx_backend.py +++ b/test/external/external_test_onnx_backend.py @@ -177,8 +177,8 @@ if Device.DEFAULT in ['GPU', 'METAL']: backend_test.exclude('test_mish_expanded_cpu') # weird inaccuracy backend_test.exclude('test_eyelike_with_dtype_cpu') # backend does not support dtype: Double -# Segfaults in CI -if Device.DEFAULT in ['LLVM', 'CUDA'] and CI: +# Segfaults in CI, GPU requires cl_khr_fp16 +if Device.DEFAULT in ['LLVM', 'CUDA', 'GPU'] and CI: backend_test.exclude('test_max_float16_cpu') backend_test.exclude('test_min_float16_cpu') diff --git a/test/models/test_real_world.py b/test/models/test_real_world.py index ce5f6644..4b3cba67 100644 --- a/test/models/test_real_world.py +++ b/test/models/test_real_world.py @@ -51,7 +51,7 @@ class TestRealWorld(unittest.TestCase): helper_test("test_sd", lambda: (Tensor.randn(1, 4, 64, 64),Tensor.randn(1, 77, 768)), test, 18.0, 953) @unittest.skipIf(Device.DEFAULT == "LLVM", "LLVM segmentation fault") - @unittest.skipIf(Device.DEFAULT == "LLVM" and CI, "too long on CI LLVM") + @unittest.skipIf(Device.DEFAULT in ["LLVM", "GPU"] and CI, "too long on CI LLVM, GPU requires cl_khr_fp1") def test_llama(self): Tensor.default_type = dtypes.float16 @@ -63,7 +63,7 @@ class TestRealWorld(unittest.TestCase): # TODO: test first token vs rest properly, also memory test is broken with CacheCollector helper_test("test_llama", lambda: (Tensor([[1,2,3,4]]),), test, 0.22 if CI else 13.5, 181 if CI else 685, all_jitted=True) - @unittest.skipIf(Device.DEFAULT == "LLVM" and CI, "too long on CI LLVM") + @unittest.skipIf(Device.DEFAULT in ["LLVM", "GPU"] and CI, "too long on CI LLVM, GPU requires cl_khr_fp16") def test_gpt2(self): Tensor.default_type = dtypes.float16 @@ -102,4 +102,4 @@ class TestRealWorld(unittest.TestCase): #Device.DEFAULT = old_default if __name__ == '__main__': - unittest.main() + unittest.main() \ No newline at end of file diff --git a/test/models/test_whisper.py b/test/models/test_whisper.py index 0423c22c..323e4dc0 100644 --- a/test/models/test_whisper.py +++ b/test/models/test_whisper.py @@ -15,7 +15,7 @@ TRANSCRIPTION_2 = "a slightly longer audio file so that we can test batch transc TEST_FILE_3_URL = 'https://homepage.ntu.edu.tw/~karchung/miniconversations/mc45.mp3' TRANSCRIPTION_3 = "Just lie back and relax. Is the level of pressure about right? Yes, it's fine, and I'd like conditioner please. Sure. I'm going to start the second lathering now. Would you like some Q-tips? How'd you like it cut? I'd like my bangs and the back trimmed, and I'd like the rest thinned out a bit and layered. Where would you like the part? On the left, right about here. Here, have a look. What do you think? It's fine. Here's a thousand anti-dollars. It's 30-ant extra for the rants. Here's your change and receipt. Thank you, and please come again. So how do you like it? It could have been worse, but you'll notice that I didn't ask her for her card. Hmm, yeah. Maybe you can try that place over there next time." -@unittest.skipIf(CI and Device.DEFAULT in ["LLVM", "CLANG", "CPU"], "Not working on LLVM, slow on others") +@unittest.skipIf(CI and Device.DEFAULT in ["LLVM", "CLANG", "CPU", "GPU"], "Not working on LLVM, slow on others. GPU reequires cl_khr_fp16") class TestWhisper(unittest.TestCase): @classmethod def setUpClass(cls): diff --git a/test/test_hip_rdna3.py b/test/test_hip_rdna3.py index 4ca29edc..dba22647 100644 --- a/test/test_hip_rdna3.py +++ b/test/test_hip_rdna3.py @@ -22,7 +22,6 @@ class TestHIPCompilationRDNA(unittest.TestCase): output = model(input) output.numpy() - @unittest.expectedFailure def test_compile_hip_speedyresnet_hf(self): Tensor.default_type = dtypes.float16 @@ -34,4 +33,4 @@ class TestHIPCompilationRDNA(unittest.TestCase): output.numpy() if __name__ == "__main__": - unittest.main() \ No newline at end of file + unittest.main() diff --git a/test/test_linearizer_failures.py b/test/test_linearizer_failures.py index ba9f20a3..0ae18a7e 100644 --- a/test/test_linearizer_failures.py +++ b/test/test_linearizer_failures.py @@ -79,7 +79,7 @@ class TestLinearizerFailures(unittest.TestCase): ast = helper_add_store(ast) helper_test_lin(Linearizer(ast), opts, failed_platforms=["LLVM"]) - @unittest.skipIf(Device.DEFAULT=="LLVM" and not OSX, "Segmentation fault on ubuntu") + @unittest.skipIf((Device.DEFAULT=="LLVM" and not OSX) or (Device.DEFAULT == "GPU" and CI), "Segmentation fault on ubuntu, GPU requires cl_khr_fp16") def test_failure_8(self): ast = LazyOp(op=UnaryOps.SQRT, src=(LazyOp(op=BinaryOps.DIV, src=(LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=1.0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 1, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=True),)))), LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(1, 1, 4096), strides=(0, 0, 1), offset=0, mask=None, contiguous=True),)))), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 1, 4096), strides=(0, 0, 1), offset=0, mask=None, contiguous=True),))))), arg=None), LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(1, 1, 4096), strides=(0, 0, 1), offset=0, mask=None, contiguous=True),)))), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 1, 4096), strides=(0, 0, 1), offset=0, mask=None, contiguous=True),))))), arg=None)), arg=None),), arg=(1, 1, 1)), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.000244140625, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 1, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=True),))))), arg=None), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=1e-06, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 1, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=True),))))), arg=None)), arg=None),), arg=None) opts = [Opt(op=OptOps.UNROLL, axis=0, amt=4), Opt(op=OptOps.UNROLL, axis=0, amt=4), Opt(op=OptOps.UNROLL, axis=0, amt=4), Opt(op=OptOps.UNROLL, axis=0, amt=4)] diff --git a/test/test_specific_conv.py b/test/test_specific_conv.py index fa7f53db..c8bab971 100644 --- a/test/test_specific_conv.py +++ b/test/test_specific_conv.py @@ -20,7 +20,7 @@ class TestSpecific(unittest.TestCase): w = Tensor.randn(2048, 512) (x @ w).reshape(1, 128, 4).contiguous().realize() - @unittest.skipIf(Device.DEFAULT in ["LLVM", "WEBGPU", "CUDA"], "Broken on LLVM, WEBGPU and CUDA") + @unittest.skipIf(Device.DEFAULT in ["LLVM", "WEBGPU", "GPU", "CUDA"], "Broken on LLVM and webgpu, GPU requires cl_khr_fp16") def test_big_vec_mul(self): # from LLaMA # 0 buffer<4096, dtypes.float> [View((1024, 1, 1, 4), (4, 0, 0, 1), 0, None)] @@ -52,4 +52,4 @@ class TestSpecific(unittest.TestCase): x.conv2d(w, stride=2, padding=1).permute(0,2,3,1).reshape(18, 18*384//4, 4).contiguous().realize() if __name__ == '__main__': - unittest.main() \ No newline at end of file + unittest.main() diff --git a/test/test_uops.py b/test/test_uops.py index 1ba35c54..4d055106 100644 --- a/test/test_uops.py +++ b/test/test_uops.py @@ -14,7 +14,7 @@ def _uops_to_prg(uops): runtime_args=runtime_args).build(Device[Device.DEFAULT].compiler, Device[Device.DEFAULT].runtime) def uop(uops:List[UOp], uop:UOps, dtype:Optional[DType], vin:Tuple[UOp, ...], arg:Any=None) -> UOp: - uops.append(UOp(uop, dtype, tuple(vin), arg)) + uops.append(UOp(uop, dtype if arg != BinaryOps.CMPLT else dtypes.bool, tuple(vin), arg)) return uops[-1] def _test_single_value(vals, op, dtype): @@ -43,7 +43,7 @@ def _test_single_value_const(vals, op, dtype): class TestUOps(unittest.TestCase): def _equal(self, v1, v2): - if not (math.isnan(v1) and math.isnan(v2)): self.assertAlmostEqual(v1, v2, places=5) + if not (math.isnan(v1) and math.isnan(v2)): self.assertAlmostEqual(v1, v2, places=5) if v1.dtype != np.bool_ else self.assertEqual(v1, v2) def _test_uop_fxn(self, bop, fxn, dt=dtypes.float32): for f in [_test_single_value, _test_single_value_const]: @@ -78,7 +78,7 @@ class TestFloatUOps(TestUOps): def test_mul(self): self._test_bop_fxn(BinaryOps.MUL, lambda a,b: a*b) def test_div(self): self._test_bop_fxn(BinaryOps.DIV, lambda a,b: a/b if b != 0 else a*float('inf')) def test_max(self): self._test_bop_fxn(BinaryOps.MAX, lambda a,b: max(a,b)) - def test_cmplt(self): self._test_bop_fxn(BinaryOps.CMPLT, lambda a,b: float(a UOp: return self.uop(UOps.CONST, dtype, tuple(), b, insert_before=insert_before) + def cast(self, val: UOp, dtype) -> UOp: return self.uop(UOps.CAST, dtype, (val,)) if val.dtype != dtype else val + + def get_reduce_acc(self, op, dtype:DType): + if op == ReduceOps.SUM: return 0.0 if dtypes.is_float(dtype) else 0 + elif op == ReduceOps.MAX: return -math.inf if dtypes.is_float(dtype) else -2**31 if dtypes.is_int(dtype) else False render_ops: Any = { Variable: lambda self, ops, ctx: ctx.loop_uops[self.expr], NumNode: lambda self, ops, ctx: ctx.const(self.b), MulNode: lambda self, ops, ctx: ctx.uop_alu_idx(self.a.render(ops, ctx), self.b, ops, ctx, BinaryOps.MUL), @@ -74,7 +79,8 @@ class Linearizer(Kernel): (g_idx, g_valid), amt, dim = self.sts[i].expr_idxs(fake_idxs), 1, None else: g_idx, g_valid = self.sts[i].expr_idxs(fake_idxs) - localtype = dtypes.float32 if amt == 1 else dtypes.float.vec(amt) + localtype = buf.dtype if amt == 1 else buf.dtype.vec(amt) + if isinstance(buf.dtype, ImageDType): localtype = dtypes.float if amt == 1 else dtypes.float.vec(amt) e_idxs, e_valids = g_idx.expand(expand_vars), g_valid.expand(expand_vars) @@ -237,7 +243,7 @@ class Linearizer(Kernel): fake_reduce_idxs = [x*0 for x in reduce_idxs] # define accumulator - acc = self.global_load(0, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, {ReduceOps.SUM: 0.0, ReduceOps.MAX: -math.inf}[cast(ReduceOps, self.reduceop.op)]) + acc = self.global_load(0, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, self.get_reduce_acc(self.reduceop.op, self.bufs[0].dtype)) if self.tensor_core: def calc_tc_idxs(local_size: int, aliases: List[List[int]]): @@ -343,7 +349,7 @@ class Linearizer(Kernel): # NOTE: this structure is the same as the reduce op above # define late accumulator - acc = self.global_load(-1, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, {ReduceOps.SUM: 0.0, ReduceOps.MAX: -math.inf}[cast(ReduceOps, self.reduceop.op)]) + acc = self.global_load(-1, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, self.get_reduce_acc(self.reduceop.op, self.bufs[-1].dtype)) # late reduce loop loop_ctx = render_loop(end_local_idxs) @@ -455,8 +461,14 @@ class Linearizer(Kernel): self.applied_opts_cache = self.applied_opts[:] return self - def uop(self, uop:UOps, dtype:Optional[DType], vin:Tuple[UOp, ...], arg:Any=None, cachable=True, insert_before=None, simplify=True) -> UOp: + def uop(self, uop:UOps, dtype:Optional[DType]=None, vin:Tuple[UOp, ...]=tuple(), arg:Any=None, cachable=True, insert_before=None, simplify=True) -> UOp: key = (uop, dtype, vin, arg) + if uop == UOps.PHI and vin[1].dtype != dtype: vin = (vin[0], self.cast(vin[1], dtype)) + vin[1:] + if uop == UOps.ALU: # upcast vins to the same dtype + upcast_dtype = dtypes.float if arg == TernaryOps.MULACC else max(cast(DType, x.dtype) for x in vin) # MULACC is only supported in float + if arg == TernaryOps.WHERE: vin = (vin[0],) + tuple(self.cast(x, upcast_dtype) for x in vin[1:]) # the first arg is always bool + else: vin = tuple(self.cast(x, upcast_dtype) for x in vin) + dtype = dtype or upcast_dtype # some ops like BinaryOps.CMPLT return bool if simplify: if uop == UOps.PHI and len(vin) == 2: return vin[1] # a phi without loops is a noop if uop == UOps.GEP and vin[0].uop == UOps.CONST: return self.const(vin[0].arg, dtype, insert_before) @@ -502,11 +514,11 @@ class Linearizer(Kernel): ret: List[UOp] = [] input_acc = acc[:] for val, off in zip(zip(*values), cast(List[int], offs)): - acc[off] = self.uop(UOps.ALU, dtypes.float32, val+(acc[off],), ops[x.op]) + acc[off] = self.uop(UOps.ALU, vin=val+(acc[off],), arg=ops[x.op]) ret.append(acc[off]) for off in range(len(acc)): if input_acc[off] != acc[off]: - acc[off] = self.uop(UOps.PHI, dtypes.float32, (input_acc[off], acc[off]) + tuple(loop_ctx)) + acc[off] = self.uop(UOps.PHI, input_acc[off].dtype, (input_acc[off], acc[off]) + tuple(loop_ctx)) else: - ret = [self.uop(UOps.ALU, dtypes.float32, val, x.op) for val in zip(*values)] + ret = [self.uop(UOps.ALU, dtype=dtypes.bool if x.op == BinaryOps.CMPLT else None, vin=val, arg=x.op) for val in zip(*values)] return ret diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py index ba8916d1..f359e9d1 100644 --- a/tinygrad/renderer/cstyle.py +++ b/tinygrad/renderer/cstyle.py @@ -47,18 +47,18 @@ class CStyleLanguage(NamedTuple): return f"{self.float4.replace('float4', var_dtype.name)}({','.join(x)})" # returns a str expression of the const with the given type - def render_const(self, x:Union[float,int], var_dtype) -> str: + def render_const(self, x:Union[float,int,bool], var_dtype) -> str: if math.isnan(x): val = "NAN" elif math.isinf(x): val = ("-" if x < 0 else "") + "INFINITY" - else: val = f"{float(x)}f" if dtypes.is_float(var_dtype) else f"{int(x)}" - return self.render_cast([val]*var_dtype.sz, var_dtype) if var_dtype.sz > 1 else val + else: val = f"{float(x)}f" if dtypes.is_float(var_dtype) else f"{int(x)}" if dtypes.is_int(var_dtype) else f"{bool(x)}".lower() + return self.render_cast([val]*var_dtype.sz, var_dtype) if var_dtype.sz > 1 or var_dtype not in [dtypes.float, dtypes.int, dtypes.bool] else val # returns a str expression of the loaded value with the output type def render_load(self, output_dtype, buf_name, buf_dtype, idx, local=False) -> str: if isinstance(buf_dtype, ImageDType): assert output_dtype == dtypes.float.vec(4), f"images must be float4, getting {output_dtype}" return f"read_imagef({buf_name}, smp, {idx})" - if self.uses_vload and buf_dtype == dtypes.float16: + if self.uses_vload and buf_dtype.scalar() == dtypes.float16 and output_dtype.scalar() != dtypes.float16: return f"vload_half{'' if output_dtype.sz == 1 else str(output_dtype.sz)}(0, {buf_name}+{idx})" if output_dtype.sz > 1: out_val = f"*(({self.smem_prefix if local and self.smem_prefix_for_cast else self.buffer_prefix}{buf_dtype.name}{output_dtype.sz}*)({buf_name}+{idx}))" @@ -95,7 +95,7 @@ class CStyleLanguage(NamedTuple): if isinstance(buf_dtype, ImageDType): assert var_dtype == dtypes.float.vec(4), "images must be float4" return f"write_imagef({buf_name}, {idx}, {var_name});" - if self.uses_vload and buf_dtype == dtypes.float16 and var_dtype != dtypes.float16: + if self.uses_vload and buf_dtype.scalar() == dtypes.float16 and var_dtype.scalar() != dtypes.float16: return f"vstore_half{'' if var_dtype.sz == 1 else str(var_dtype.sz)}({var_name}, 0, {buf_name}+{idx});" if var_dtype.sz > 1: return f"*(({self.smem_prefix if local and self.smem_prefix_for_cast else self.buffer_prefix}{buf_dtype.name}{var_dtype.sz}*)({buf_name}+{idx})) = ({buf_dtype.name}{var_dtype.sz}){var_name};" @@ -156,8 +156,6 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> Tu # remove parens if ALU types are the same. TODO: can do more here if vin[0].uop == UOps.ALU and vin[0].arg == args and args in {BinaryOps.ADD, BinaryOps.SUB, BinaryOps.MUL}: val = lang.code_for_op[args](strip_parens(r[vin[0]]), *[r[x] for x in vin[1:]], dtype) - elif args == BinaryOps.MAX: - val = lang.code_for_op[args](*[lang.render_cast([r[x]], dtype) if x.dtype != dtype else r[x] for x in vin] + [dtype]) else: val = lang.code_for_op[args](*[r[x] for x in vin] + [dtype]) assert child_count[u] != 0, f"childless ALU op found {u}" @@ -292,10 +290,31 @@ __device__ float4 vload_half4(size_t offset, const half *p) { return make_float4 __device__ void vstore_half(float data, size_t offset, half *p) { *(p + offset) = (half)data; } __device__ void vstore_half2(float2 data, size_t offset, half *p) { *(p + offset*2) = (half)data.x; *(p + offset*2 + 1) = (half)data.y; } __device__ void vstore_half4(float4 data, size_t offset, half *p) { *(p + offset*4) = (half)data.x; *(p + offset*4 + 1) = (half)data.y; *(p + offset*4 + 2) = (half)data.z; *(p + offset*4 + 3) = (half)data.w; } +__device__ half exp2(half x) { return hexp2(x); } +__device__ half log2(half x) { return hlog2(x); } +__device__ half sin(half x) { return hsin(x); } +__device__ half sqrt(half x) { return hsqrt(x); } +__device__ half hmax(half a, half b) { return __hgt(a, b) ? a : b; } +__device__ half operator%(const half &a, const half &b) { return __hsub(a, __hmul(b, __float2half(floorf(__half2float(a) / __half2float(b))))); } +__device__ bool operator!=(const half &a, const int &b) { return (float)a != b; } + +// HACKS for ALU ops on half and result of half2 GEP +__device__ half operator+(const half &a, const unsigned short &b) { return __hadd(a, (half)(b)); } +__device__ half operator-(const half &a, const unsigned short &b) { return __hsub(a, (half)(b)); } +__device__ half operator*(const half &a, const unsigned short &b) { return __hmul(a, (half)(b)); } +__device__ half operator/(const half &a, const unsigned short &b) { return __hdiv(a, (half)(b)); } +__device__ bool operator<(const half &a, const unsigned short &b) { return __hlt(a, (half)(b)); } +// now the other way +__device__ half operator+(const unsigned short &a, const half &b) { return __hadd((half)(a), b); } +__device__ half operator-(const unsigned short &a, const half &b) { return __hsub((half)(a), b); } +__device__ half operator*(const unsigned short &a, const half &b) { return __hmul((half)(a), b); } +__device__ half operator/(const unsigned short &a, const half &b) { return __hdiv((half)(a), b); } +__device__ bool operator<(const unsigned short &a, const half &b) { return __hlt((half)(a), b); } """ gid = [f'blockIdx.{chr(120+i)}' for i in range(3)] lid = [f'threadIdx.{chr(120+i)}' for i in range(3)] xid = [f'(blockIdx.{chr(120+i)}*blockDim.{chr(120+i)}+threadIdx.{chr(120+i)})' for i in range(3)] + code_for_op = {**CStyleLanguage().code_for_op, BinaryOps.MAX: lambda a,b,dtype: f"max({a},{b})" if dtype != dtypes.half else f"hmax({a},{b})", TernaryOps.WHERE: lambda a,b,c,dtype: f"({a}!=0?{b}:{c})" if dtype != dtypes.half else f"(half)({a}!=0?{b}:{c})"} HIPRenderer = functools.partial(uops_to_cstyle, HIPLanguage()) # TODO: how much of this can be merged with above? @@ -338,9 +357,6 @@ class WGSLLanguage(CStyleLanguage): if self.type_map[var_dtype]: return f"{self.type_map[var_dtype]}({x[0]})" raise NotImplementedError(f"no cast for {var_dtype}") - def render_load(self, output_dtype, buf_name, buf_dtype, idx, local=False) -> str: - return f"f32({super().render_load(output_dtype, buf_name, buf_dtype, idx, local)})" - def render_store(self, buf_name:str, buf_dtype:DType, var_name:str, var_dtype:DType, idx, local=False) -> str: return f"{buf_name}[{idx}] = {self.render_cast([var_name], buf_dtype) if var_dtype != buf_dtype else var_name};" WGSLRenderer = functools.partial(uops_to_cstyle, WGSLLanguage()) diff --git a/tinygrad/renderer/llvmir.py b/tinygrad/renderer/llvmir.py index 50ad9fa3..99801ece 100644 --- a/tinygrad/renderer/llvmir.py +++ b/tinygrad/renderer/llvmir.py @@ -6,8 +6,9 @@ from tinygrad.ops import Op, UnaryOps, BinaryOps, TernaryOps LLVM_FAST_MATH_FLAGS = ('nsz', 'arcp', 'contract', 'afn', 'reassoc') # All from fast math, but nnan and ninf +def is_bool(t:ir.Type): return isinstance(t, ir.IntType) and t.width == 1 code_for_op: Final[Dict[Op, Callable]] = { - UnaryOps.NEG: lambda builder,x: builder.neg(x) if isinstance(x.type, ir.IntType) else builder.fneg(x, flags=LLVM_FAST_MATH_FLAGS) if isinstance(x.type, ir.FloatType) else builder.not_(x), + UnaryOps.NEG: lambda builder,x: builder.xor(x, ir.Constant(ir.IntType(1), 1)) if is_bool(x.type) else builder.neg(x) if isinstance(x.type, ir.IntType) else builder.fneg(x, flags=LLVM_FAST_MATH_FLAGS), UnaryOps.EXP2: lambda builder,x: builder.call(builder._block.module.declare_intrinsic('llvm.exp2', [ir.FloatType()]), [x], fastmath=LLVM_FAST_MATH_FLAGS), UnaryOps.LOG2: lambda builder,x: builder.call(builder._block.module.declare_intrinsic('llvm.log2', [ir.FloatType()]), [x], fastmath=LLVM_FAST_MATH_FLAGS), UnaryOps.SIN: lambda builder,x: builder.call(builder._block.module.declare_intrinsic('llvm.sin', [ir.FloatType()]), [x], fastmath=LLVM_FAST_MATH_FLAGS), @@ -16,12 +17,12 @@ code_for_op: Final[Dict[Op, Callable]] = { BinaryOps.SUB: lambda builder,x,y: builder.sub(x,y) if isinstance(x.type, ir.IntType) else builder.fsub(x,y, flags=LLVM_FAST_MATH_FLAGS), BinaryOps.MUL: lambda builder,x,y: builder.mul(x,y) if isinstance(x.type, ir.IntType) else builder.fmul(x,y, flags=LLVM_FAST_MATH_FLAGS), BinaryOps.DIV: lambda builder,x,y: builder.sdiv(x,y) if isinstance(x.type, ir.IntType) else builder.fdiv(x,y, flags=LLVM_FAST_MATH_FLAGS), - # TODO: this should be casted - BinaryOps.CMPLT: lambda builder,x,y: builder.zext(builder.icmp_signed("<", x, y),ir.IntType(32)) if isinstance(x.type, ir.IntType) else builder.uitofp(builder.fcmp_ordered("<", x, y, flags=LLVM_FAST_MATH_FLAGS), ir.FloatType()), - BinaryOps.MAX: lambda builder,x,y: builder.select(builder.fcmp_unordered(">", x, y, flags=LLVM_FAST_MATH_FLAGS), x, y, flags=LLVM_FAST_MATH_FLAGS) if isinstance(x.type, ir.FloatType) else builder.select(builder.icmp_signed(">", x, y), x, y), - BinaryOps.MOD: lambda builder,x,y: builder.srem(x,y) if isinstance(x.type, ir.IntType) else builder.frem(x,y) if isinstance(x.type, ir.FloatType) else builder.urem(x,y), + BinaryOps.CMPLT: lambda builder,x,y: builder.icmp_unsigned("<", x, y) if is_bool(x.type) else builder.icmp_signed("<", x, y) if isinstance(x.type, ir.IntType) else builder.fcmp_unordered("<", x, y, flags=LLVM_FAST_MATH_FLAGS), + BinaryOps.MAX: lambda builder,x,y: builder.select(builder.icmp_unsigned(">", x, y) if is_bool(x.type) else builder.icmp_signed(">", x, y) if isinstance(x.type, ir.IntType) else builder.fcmp_unordered(">", x, y, flags=LLVM_FAST_MATH_FLAGS), x, y), + BinaryOps.MOD: lambda builder,x,y: builder.urem(x,y) if is_bool(x.type) else builder.srem(x,y) if isinstance(x.type, ir.IntType) else builder.frem(x,y), TernaryOps.MULACC: lambda builder,x,y,z: builder.fadd(builder.fmul(x,y, flags=LLVM_FAST_MATH_FLAGS), z, flags=LLVM_FAST_MATH_FLAGS), - TernaryOps.WHERE: lambda builder,x,y,z: builder.select(builder.fcmp_unordered("!=", x, ir.Constant(ir.FloatType(), 0), flags=LLVM_FAST_MATH_FLAGS) if isinstance(x.type, ir.FloatType) else builder.trunc(x, ir.IntType(1)), y, z), + TernaryOps.WHERE: lambda builder,x,y,z: builder.select(builder.trunc(x, ir.IntType(1)) if isinstance(x.type, ir.IntType) else builder.fcmp_unordered("!=", x, ir.Constant(ir.FloatType(), 0), flags=LLVM_FAST_MATH_FLAGS), y, z + ), } dtype_to_llvm_dtype = {dtypes.float64:ir.DoubleType(), dtypes.float16:ir.HalfType(), dtypes.bfloat16:ir.IntType(16), dtypes.float32:ir.FloatType(), dtypes.int8:ir.IntType(8), dtypes.uint8:ir.IntType(8), dtypes.bool: ir.IntType(1), dtypes.int64: ir.IntType(64), dtypes.int32: ir.IntType(32), dtypes._arg_int32: ir.IntType(32), dtypes.int16:ir.IntType(16), dtypes.uint16:ir.IntType(16), dtypes.uint32:ir.IntType(32), dtypes.uint64:ir.IntType(64)} @@ -98,7 +99,7 @@ def uops_to_llvm_ir(function_name:str, uops:List[UOp]) -> Tuple[str, Dict]: phis = [] for rp in reduce_phis: incoming = lvars[rp] - lvars[rp] = bb[-1].phi(ir.FloatType()) + lvars[rp] = bb[-1].phi(dtype_to_llvm_dtype[rp.dtype]) lvars[rp].add_incoming(incoming, bb[-2]._block) phis.append((rp, lvars[rp])) @@ -146,7 +147,7 @@ def uops_to_llvm_ir(function_name:str, uops:List[UOp]) -> Tuple[str, Dict]: with bb[-1].if_then(bb[-1].trunc(lvars[vin[3]], ir.IntType(1))): store_op() else: store_op() if uop == UOps.ALU: - lvars[u] = cast(bb, code_for_op[args](bb[-1], *[cast(bb, lvars[x], x.dtype, dtypes.float) for x in vin]), dtypes.float, dtype) + lvars[u] = code_for_op[args](bb[-1], *[lvars[x] for x in vin]) if uop == UOps.CAST: lvars[u] = cast(bb, lvars[vin[0]], vin[0].dtype, dtype) bb[-1].ret_void() diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 771159b0..835fa172 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -634,7 +634,7 @@ class Tensor: def square(self): return self*self def clip(self, min_, max_): return self.maximum(min_).minimum(max_) def abs(self): return self.relu() + (-self).relu() - def sign(self): return self / (self.abs() + 1e-10) + def sign(self): return ((self.float()) / (self.float().abs() + 1e-12)).cast(self.dtype) def reciprocal(self): return 1.0/self # ***** activation functions (unary) *****