diff --git a/extra/triton/triton.py b/extra/triton/triton.py
index 32453782..d1ff8149 100644
--- a/extra/triton/triton.py
+++ b/extra/triton/triton.py
@@ -32,8 +32,8 @@ def remove_single_scalar_curly_braces(ptx_code):
 def render_const(args,dtype:DType):
   return (('-' if args<0 else '') + 'tl.where(1,float("inf"),0)') if math.isinf(args) else ('tl.where(1,float("nan"),0)' if math.isnan(args) else f"{int(args)}" if dtypes.is_int(dtype) else str(args))

-def render_cast(x:str, dtype:DType):
-  return f"{x}.to({triton_dtypes[dtype]})"
+def render_cast(x:str, dtype:DType, bitcast=False):
+  return f"{x}.to({triton_dtypes[dtype]}, bitcast={bitcast})"

 def define_scalar(local_size, dtype, args):
   if len(local_size) > 0: return f"tl.full(({','.join([str(next_power_of_2(x)) for x in local_size])},),{render_const(args,dtype)}, dtype={triton_dtypes[dtype]})"
@@ -108,7 +108,7 @@ def uops_to_triton(function_name:str, uops:List[UOp]):
       kk(f"{args[1]} = tl.arange({0}, {next_power_of_2(args[2])})")
       local_size.append(args[2])
       r[u] = args[1]
-    elif uop == UOps.CAST and dtype is not None: r[u] = render_cast(r[vin[0]], dtype)
+    elif uop == UOps.CAST and dtype is not None: r[u] = render_cast(r[vin[0]], dtype, isinstance(args, tuple) and args[1])
     else: raise NotImplementedError(f"unimplemented: {uop}")

   prg = f"import triton\nimport triton.language as tl\ntl.core.TRITON_MAX_TENSOR_NUMEL = float('inf')\n@triton.jit\ndef {function_name}("+','.join(f"{buf[0]}" for buf in bufs)+"):\n"
diff --git a/test/test_dtype.py b/test/test_dtype.py
index 5266ce62..7f5f6763 100644
--- a/test/test_dtype.py
+++ b/test/test_dtype.py
@@ -42,7 +42,7 @@ def _assert_eq(tensor:Tensor, target_dtype:DType, target):

 def _test_op(fxn, target_dtype:DType, target): _assert_eq(fxn(), target_dtype, target)
 def _test_cast(a:Tensor, target_dtype:DType): _test_op(lambda: a.cast(target_dtype), target_dtype, a.numpy().astype(target_dtype.np).tolist())
-def _test_bitcast(a:Tensor, target_dtype:DType, target): _test_op(lambda: a.bitcast(target_dtype), target_dtype, target)
+def _test_bitcast(a:Tensor, target_dtype:DType, target=None): _test_op(lambda: a.bitcast(target_dtype), target_dtype, target or a.numpy().view(target_dtype.np).tolist())

 class TestDType(unittest.TestCase):
   DTYPE: Any = None
@@ -82,6 +82,12 @@ class TestDType(unittest.TestCase):
       lambda dtype: _test_ops(a_dtype=dtype, b_dtype=self.DTYPE) if dtype.itemsize < self.DTYPE.itemsize else None,
       get_available_cast_dtypes(self.DTYPE)
     ))
+  def test_bitcast(self):
+    if self.DTYPE == dtypes.bool: raise unittest.SkipTest("no bools in bitcast")
+    list(map(
+      lambda dtype: _test_bitcast(Tensor(self.DATA, dtype=self.DTYPE), dtype) if dtype.itemsize == self.DTYPE.itemsize and dtype != dtypes.bool else None,
+      get_available_cast_dtypes(self.DTYPE)
+    ))

 def _test_ops(a_dtype:DType, b_dtype:DType, target_dtype=None):
   if not is_dtype_supported(a_dtype) or not is_dtype_supported(b_dtype): return
@@ -140,21 +146,7 @@ class TestUint8Dtype(TestDType):
   @unittest.skipIf(getenv("CUDA",0)==1 or getenv("PTX", 0)==1, "cuda saturation works differently")
   def test_uint8_to_int8_overflow(self): _test_op(lambda: Tensor([255, 254, 253, 252], dtype=dtypes.uint8).cast(dtypes.int8), dtypes.int8, [-1, -2, -3, -4])

-@unittest.skipIf(Device.DEFAULT not in {"CPU", "TORCH"}, "only bitcast in CPU and TORCH")
 class TestBitCast(unittest.TestCase):
-  def test_float32_bitcast_to_int32(self): _test_bitcast(Tensor([1,2,3,4], dtype=dtypes.float32), dtypes.int32, [1065353216, 1073741824, 1077936128, 1082130432])
-  @unittest.skipIf(Device.DEFAULT == "TORCH", "no uint32 in torch")
-  def test_float32_bitcast_to_uint32(self): _test_bitcast(Tensor([1,2,3,4], dtype=dtypes.float32), dtypes.uint32, [1065353216, 1073741824, 1077936128, 1082130432])
-  def test_int32_bitcast_to_float32(self): _test_bitcast(Tensor([1065353216, 1073741824, 1077936128, 1082130432], dtype=dtypes.int32), dtypes.float32, [1.0, 2.0, 3.0, 4.0])
-
-  # NOTE: these are the same as normal casts
-  def test_int8_bitcast_to_uint8(self): _test_bitcast(Tensor([-1, -2, -3, -4], dtype=dtypes.int8), dtypes.uint8, [255, 254, 253, 252])
-  def test_uint8_bitcast_to_int8(self): _test_bitcast(Tensor([255, 254, 253, 252], dtype=dtypes.uint8), dtypes.int8, [-1, -2, -3, -4])
-  @unittest.skipIf(Device.DEFAULT == "TORCH", "no uint64 in torch")
-  def test_int64_bitcast_to_uint64(self): _test_bitcast(Tensor([-1, -2, -3, -4], dtype=dtypes.int64), dtypes.uint64, [18446744073709551615, 18446744073709551614, 18446744073709551613, 18446744073709551612])
-  @unittest.skipIf(Device.DEFAULT == "TORCH", "no uint64 in torch")
-  def test_uint64_bitcast_to_int64(self): _test_bitcast(Tensor([18446744073709551615, 18446744073709551614, 18446744073709551613, 18446744073709551612], dtype=dtypes.uint64), dtypes.int64, [-1, -2, -3, -4])
-
   def test_shape_change_bitcast(self):
     with self.assertRaises(AssertionError):
       _test_bitcast(Tensor([100000], dtype=dtypes.float32), dtypes.uint8, [100000])
diff --git a/tinygrad/nn/state.py b/tinygrad/nn/state.py
index a1adf348..ce8eb637 100644
--- a/tinygrad/nn/state.py
+++ b/tinygrad/nn/state.py
@@ -71,14 +71,13 @@ def torch_load(fn:str):
     lens[storage[2]] = storage[4] * storage[1].itemsize
     if storage[2] not in offsets: return None
     byte_offset = offsets[storage[2]]+storage_offset*storage[1].itemsize
-    ret = t[byte_offset:byte_offset+prod(size)].cast(storage[1])
+    ret = t[byte_offset:byte_offset+prod(size)]

     # convert bfloat16 -> float16 using LLVM for Llama 2
     # upstream LLaMA also does this conversion:
     # https://github.com/facebookresearch/llama/blob/6c7fe276574e78057f917549435a2554000a876d/llama/generation.py#L95
     # TODO: should this be done in the example instead? or maybe we don't need this anymore with better bfloat16 support
-    if storage[1] == dtypes.bfloat16:
-      ret = ret.bitcast(dtypes.uint16).to("CPU").cast(dtypes.uint32).mul(1<<16).bitcast(dtypes.float32).to(Device.DEFAULT).half()
-      #ret = ret.to("LLVM").half().to(Device.DEFAULT)
+    if storage[1] == dtypes.bfloat16: ret = ret.cast(dtypes.uint16).to(Device.DEFAULT).cast(dtypes.uint32).mul(1<<16).bitcast(dtypes.float32).half()
+    else: ret = ret.cast(storage[1])

     # 7 lines to deal with permuted tensors. NOTE: this currently requires reading off the disk
     shape_strides = [(s, st) for s,st in zip(size, stride) if s != 1]
diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py
index a6f089f8..06975ec5 100644
--- a/tinygrad/renderer/cstyle.py
+++ b/tinygrad/renderer/cstyle.py
@@ -40,7 +40,8 @@ class CStyleLanguage(NamedTuple):
   }

   # returns a str expression of the casted xs with the given type
-  def render_cast(self, x:List[str], var_dtype:DType) -> str:
+  def render_cast(self, x:List[str], var_dtype:DType, bitcast=False) -> str:
+    if bitcast: return f"(*(({self.buffer_prefix}{var_dtype.name}*)&{x[0]}))"
     if len(x) == 1: return f"({var_dtype.name})({x[0]})"
     assert len(x) == var_dtype.sz, f"cast is wrong size {len(x)} != {var_dtype.sz}"
     assert self.float4 is not None, "vectorized cast is not supported on this platform"
@@ -187,7 +188,7 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> Tu
       kk(lang.render_store(r[vin[0]], vin[0].dtype, r[vin[2]], vin[2].dtype, strip_parens(r[vin[1]]), vin[0].uop == UOps.DEFINE_LOCAL))
       if len(vin) > 3: kk("}")
     elif uop == UOps.CAST and dtype is not None:
-      val = lang.render_cast([r[x] for x in vin], dtype)
+      val = lang.render_cast([r[x] for x in vin], dtype, bitcast=isinstance(args, tuple) and args[1])
       if child_count[u] <= 1: r[u] = val
       else: kk(f"{lang.generic_var_prefix if lang.generic_var_prefix else dtype.name} {ssa(u,'cast')} = {val};")
     elif uop == UOps.DEFINE_LOCAL:
@@ -224,6 +225,9 @@ class OpenCLLanguage(CStyleLanguage):
   uses_vload = True
   # NOTE: mad is used so the loads aren't reordered into the math on 845
   code_for_op = {**CStyleLanguage().code_for_op, TernaryOps.MULACC: lambda a,b,c,dtype: f"mad({a},{b},{c})"}
+  type_map = { dtypes.uint8: "uchar", dtypes.uint32: "uint", dtypes.uint16: "ushort", dtypes.uint64: "ulong" }
+  def render_cast(self, x, var_dtype, bitcast=False) -> str:
+    return f"as_{self.type_map.get(var_dtype) or var_dtype.name}({x[0]})" if bitcast else super().render_cast(x, var_dtype)
 OpenCLRenderer = functools.partial(uops_to_cstyle, OpenCLLanguage())

 class MetalLanguage(CStyleLanguage):
@@ -237,6 +241,8 @@ class MetalLanguage(CStyleLanguage):
   gid = [f"gid.{chr(120+i)}" for i in range(3)]
   lid = [f"lid.{chr(120+i)}" for i in range(3)]
   extra_args = ['uint3 gid [[threadgroup_position_in_grid]]', 'uint3 lid [[thread_position_in_threadgroup]]']
+  def render_cast(self, x: List[str], var_dtype: DType, bitcast=False) -> str:
+    return f"as_type<{var_dtype.name}>({x[0]})" if bitcast else super().render_cast(x, var_dtype)
 MetalRenderer = functools.partial(uops_to_cstyle, MetalLanguage())

 class CUDALanguage(CStyleLanguage):
@@ -350,8 +356,8 @@ class WGSLLanguage(CStyleLanguage):
   def render_conditional(self, cond:str, x:str, y:str) -> str:
     return f"select(f32({y}), {x}, bool({cond}))"

-  def render_cast(self, x:List[str], var_dtype:DType) -> str:
-    if self.type_map[var_dtype]: return f"{self.type_map[var_dtype]}({x[0]})"
+  def render_cast(self, x:List[str], var_dtype:DType, bitcast=False) -> str:
+    if self.type_map[var_dtype]: return f"bitcast<{self.type_map[var_dtype]}>({x[0]})" if bitcast else f"{self.type_map[var_dtype]}({x[0]})"
     raise NotImplementedError(f"no cast for {var_dtype}")

   def render_store(self, buf_name:str, buf_dtype:DType, var_name:str, var_dtype:DType, idx, local=False) -> str:
diff --git a/tinygrad/renderer/llvmir.py b/tinygrad/renderer/llvmir.py
index 27ea4622..0d543d47 100644
--- a/tinygrad/renderer/llvmir.py
+++ b/tinygrad/renderer/llvmir.py
@@ -28,8 +28,9 @@ code_for_op: Final[Dict[Op, Callable]] = {

 dtype_to_llvm_dtype = {dtypes.float64:ir.DoubleType(), dtypes.float16:ir.HalfType(), dtypes.bfloat16:ir.IntType(16), dtypes.float32:ir.FloatType(), dtypes.int8:ir.IntType(8), dtypes.uint8:ir.IntType(8), dtypes.bool: ir.IntType(1), dtypes.int64: ir.IntType(64), dtypes.int32: ir.IntType(32), dtypes._arg_int32: ir.IntType(32), dtypes.int16:ir.IntType(16), dtypes.uint16:ir.IntType(16), dtypes.uint32:ir.IntType(32), dtypes.uint64:ir.IntType(64)}

-def cast(bb, val, input_type, output_type):
+def cast(bb, val, input_type, output_type, bitcast=False):
   if input_type == output_type: return val
+  if bitcast: return bb[-1].bitcast(val, dtype_to_llvm_dtype[output_type])

   if dtypes.is_float(input_type):
     if dtypes.is_float(output_type):
@@ -149,7 +150,7 @@ def uops_to_llvm_ir(function_name:str, uops:List[UOp]) -> Tuple[str, Dict]:
       else: store_op()
     if uop == UOps.ALU: lvars[u] = code_for_op[args](bb[-1], *[lvars[x] for x in vin])
-    if uop == UOps.CAST: lvars[u] = cast(bb, lvars[vin[0]], vin[0].dtype, dtype)
+    if uop == UOps.CAST: lvars[u] = cast(bb, lvars[vin[0]], vin[0].dtype, dtype, bitcast=isinstance(args, tuple) and args[1])

   bb[-1].ret_void()
   return str(module), {}
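
For reference, a small standalone numpy sketch (illustration only, not part of the patch; numpy stands in for the tensor backends) of the two behaviors the diff exercises: a bitcast reinterprets the underlying bits instead of converting values, so the integers expected by the float32 tests above are exactly the IEEE-754 bit patterns of 1.0 through 4.0, and the torch_load change widens the raw bfloat16 bits into the high half of a uint32 before bitcasting to float32, since a bfloat16 is the top 16 bits of a float32.

import numpy as np

# value cast vs bitcast: 1.0f converts to 1, but its bit pattern is 0x3f800000
a = np.array([1, 2, 3, 4], dtype=np.float32)
assert a.astype(np.int32).tolist() == [1, 2, 3, 4]  # normal cast converts values
assert a.view(np.int32).tolist() == [1065353216, 1073741824, 1077936128, 1082130432]  # bitcast reinterprets bits

# bfloat16 -> float32 widening, mirroring the torch_load change: shift the raw
# 16-bit pattern into the high half of a uint32 (mul by 1<<16), then reinterpret
# as float32. 0x3f80 and 0x4000 are the bfloat16 bit patterns of 1.0 and 2.0,
# used here only as example inputs.
bf16_bits = np.array([0x3f80, 0x4000], dtype=np.uint16)
f32 = (bf16_bits.astype(np.uint32) * (1 << 16)).view(np.float32)
assert f32.tolist() == [1.0, 2.0]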