mirror of https://github.com/commaai/tinygrad.git
Bitcast support / fast bf16 load (#2011)
* bitcast renderers
* fast llama load
* make it one kernel
* regression testing p1: re-enable test_dtype for all backends, fix GPU
* regression testing p2: fuzz all possible cases against numpy, remove hand-coded tests since the fuzzer covers them
* define ushort
* fix indent, probably need flake8 back for CI to catch

---------

Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
parent 232ed2af3f
commit be09cc87c1
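For reference, a minimal numpy-only sketch of the difference between a cast and the bitcast this PR wires through every backend; the integer values match the hand-coded tests touched below (1.0f has bit pattern 0x3F800000 = 1065353216, and so on):

import numpy as np

a = np.array([1, 2, 3, 4], dtype=np.float32)
print(a.astype(np.int32))  # cast: converts the values   -> [1 2 3 4]
print(a.view(np.int32))    # bitcast: reinterprets the bits -> [1065353216 1073741824 1077936128 1082130432]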
@@ -32,8 +32,8 @@ def remove_single_scalar_curly_braces(ptx_code):
 def render_const(args,dtype:DType):
 return (('-' if args<0 else '') + 'tl.where(1,float("inf"),0)') if math.isinf(args) else ('tl.where(1,float("nan"),0)' if math.isnan(args) else f"{int(args)}" if dtypes.is_int(dtype) else str(args))

-def render_cast(x:str, dtype:DType):
-return f"{x}.to({triton_dtypes[dtype]})"
+def render_cast(x:str, dtype:DType, bitcast=False):
+return f"{x}.to({triton_dtypes[dtype]}, bitcast={bitcast})"

 def define_scalar(local_size, dtype, args):
 if len(local_size) > 0: return f"tl.full(({','.join([str(next_power_of_2(x)) for x in local_size])},),{render_const(args,dtype)}, dtype={triton_dtypes[dtype]})"
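A rough standalone sketch (not part of the diff) of the string the updated Triton render_cast would emit for a bitcast; the variable name and the string-keyed stand-in for the real triton_dtypes table are invented for illustration:

triton_dtypes = {"int32": "tl.int32"}  # stand-in for the renderer's DType-keyed table

def render_cast(x, dtype, bitcast=False):
  return f"{x}.to({triton_dtypes[dtype]}, bitcast={bitcast})"

print(render_cast("val0", "int32", bitcast=True))  # val0.to(tl.int32, bitcast=True)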
@@ -108,7 +108,7 @@ def uops_to_triton(function_name:str, uops:List[UOp]):
 kk(f"{args[1]} = tl.arange({0}, {next_power_of_2(args[2])})")
 local_size.append(args[2])
 r[u] = args[1]
-elif uop == UOps.CAST and dtype is not None: r[u] = render_cast(r[vin[0]], dtype)
+elif uop == UOps.CAST and dtype is not None: r[u] = render_cast(r[vin[0]], dtype, isinstance(args, tuple) and args[1])
 else: raise NotImplementedError(f"unimplemented: {uop}")

 prg = f"import triton\nimport triton.language as tl\ntl.core.TRITON_MAX_TENSOR_NUMEL = float('inf')\n@triton.jit\ndef {function_name}("+','.join(f"{buf[0]}" for buf in bufs)+"):\n"

@@ -42,7 +42,7 @@ def _assert_eq(tensor:Tensor, target_dtype:DType, target):

 def _test_op(fxn, target_dtype:DType, target): _assert_eq(fxn(), target_dtype, target)
 def _test_cast(a:Tensor, target_dtype:DType): _test_op(lambda: a.cast(target_dtype), target_dtype, a.numpy().astype(target_dtype.np).tolist())
-def _test_bitcast(a:Tensor, target_dtype:DType, target): _test_op(lambda: a.bitcast(target_dtype), target_dtype, target)
+def _test_bitcast(a:Tensor, target_dtype:DType, target=None): _test_op(lambda: a.bitcast(target_dtype), target_dtype, target or a.numpy().view(target_dtype.np).tolist())

 class TestDType(unittest.TestCase):
 DTYPE: Any = None
@@ -82,6 +82,12 @@ class TestDType(unittest.TestCase):
 lambda dtype: _test_ops(a_dtype=dtype, b_dtype=self.DTYPE) if dtype.itemsize < self.DTYPE.itemsize else None,
 get_available_cast_dtypes(self.DTYPE)
 ))
+def test_bitcast(self):
+if self.DTYPE == dtypes.bool: raise unittest.SkipTest("no bools in bitcast")
+list(map(
+lambda dtype: _test_bitcast(Tensor(self.DATA, dtype=self.DTYPE), dtype) if dtype.itemsize == self.DTYPE.itemsize and dtype != dtypes.bool else None,
+get_available_cast_dtypes(self.DTYPE)
+))

 def _test_ops(a_dtype:DType, b_dtype:DType, target_dtype=None):
 if not is_dtype_supported(a_dtype) or not is_dtype_supported(b_dtype): return
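The fuzz idea from the commit message ("fuzz all possible cases against numpy"), sketched standalone with numpy only; the dtype list is an illustrative subset, not the real get_available_cast_dtypes:

import itertools
import numpy as np

dtypes_to_try = [np.float32, np.int32, np.uint32, np.int16, np.uint16, np.int8, np.uint8]
data = [1, 2, 3, 4]

for src, dst in itertools.product(dtypes_to_try, repeat=2):
  if np.dtype(src).itemsize != np.dtype(dst).itemsize: continue   # bitcast only between same-size dtypes
  expected = np.array(data, dtype=src).view(dst)                  # numpy's view is the reference bitcast
  # the real test feeds the same data through Tensor(...).bitcast(...) and compares against `expected`
  assert expected.view(src).tolist() == np.array(data, dtype=src).tolist()  # a bitcast round-trips exactly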
@@ -140,21 +146,7 @@ class TestUint8Dtype(TestDType):
 @unittest.skipIf(getenv("CUDA",0)==1 or getenv("PTX", 0)==1, "cuda saturation works differently")
 def test_uint8_to_int8_overflow(self): _test_op(lambda: Tensor([255, 254, 253, 252], dtype=dtypes.uint8).cast(dtypes.int8), dtypes.int8, [-1, -2, -3, -4])

-@unittest.skipIf(Device.DEFAULT not in {"CPU", "TORCH"}, "only bitcast in CPU and TORCH")
-class TestBitCast(unittest.TestCase):
-def test_float32_bitcast_to_int32(self): _test_bitcast(Tensor([1,2,3,4], dtype=dtypes.float32), dtypes.int32, [1065353216, 1073741824, 1077936128, 1082130432])
-@unittest.skipIf(Device.DEFAULT == "TORCH", "no uint32 in torch")
-def test_float32_bitcast_to_uint32(self): _test_bitcast(Tensor([1,2,3,4], dtype=dtypes.float32), dtypes.uint32, [1065353216, 1073741824, 1077936128, 1082130432])
-def test_int32_bitcast_to_float32(self): _test_bitcast(Tensor([1065353216, 1073741824, 1077936128, 1082130432], dtype=dtypes.int32), dtypes.float32, [1.0, 2.0, 3.0, 4.0])

-# NOTE: these are the same as normal casts
-def test_int8_bitcast_to_uint8(self): _test_bitcast(Tensor([-1, -2, -3, -4], dtype=dtypes.int8), dtypes.uint8, [255, 254, 253, 252])
-def test_uint8_bitcast_to_int8(self): _test_bitcast(Tensor([255, 254, 253, 252], dtype=dtypes.uint8), dtypes.int8, [-1, -2, -3, -4])
-@unittest.skipIf(Device.DEFAULT == "TORCH", "no uint64 in torch")
-def test_int64_bitcast_to_uint64(self): _test_bitcast(Tensor([-1, -2, -3, -4], dtype=dtypes.int64), dtypes.uint64, [18446744073709551615, 18446744073709551614, 18446744073709551613, 18446744073709551612])
-@unittest.skipIf(Device.DEFAULT == "TORCH", "no uint64 in torch")
-def test_uint64_bitcast_to_int64(self): _test_bitcast(Tensor([18446744073709551615, 18446744073709551614, 18446744073709551613, 18446744073709551612], dtype=dtypes.uint64), dtypes.int64, [-1, -2, -3, -4])
-
 def test_shape_change_bitcast(self):
 with self.assertRaises(AssertionError):
 _test_bitcast(Tensor([100000], dtype=dtypes.float32), dtypes.uint8, [100000])
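test_shape_change_bitcast, kept above, pins down the one case the fuzzer never generates: reinterpreting float32 bytes as uint8 would change the element count, which the bitcast refuses. A numpy sketch of the shape change it guards against:

import numpy as np

a = np.array([100000], dtype=np.float32)
print(a.shape, a.view(np.uint8).shape)  # (1,) (4,) -> element count changes, so the size-preserving bitcast asserts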
@@ -71,14 +71,13 @@ def torch_load(fn:str):
 lens[storage[2]] = storage[4] * storage[1].itemsize
 if storage[2] not in offsets: return None
 byte_offset = offsets[storage[2]]+storage_offset*storage[1].itemsize
-ret = t[byte_offset:byte_offset+prod(size)].cast(storage[1])
+ret = t[byte_offset:byte_offset+prod(size)]
 # convert bfloat16 -> float16 using LLVM for Llama 2
 # upstream LLaMA also does this conversion:
 # https://github.com/facebookresearch/llama/blob/6c7fe276574e78057f917549435a2554000a876d/llama/generation.py#L95
 # TODO: should this be done in the example instead? or maybe we don't need this anymore with better bfloat16 support
-if storage[1] == dtypes.bfloat16:
-ret = ret.bitcast(dtypes.uint16).to("CPU").cast(dtypes.uint32).mul(1<<16).bitcast(dtypes.float32).to(Device.DEFAULT).half()
-#ret = ret.to("LLVM").half().to(Device.DEFAULT)
+if storage[1] == dtypes.bfloat16: ret = ret.cast(dtypes.uint16).to(Device.DEFAULT).cast(dtypes.uint32).mul(1<<16).bitcast(dtypes.float32).half()
+else: ret = ret.cast(storage[1])

 # 7 lines to deal with permuted tensors. NOTE: this currently requires reading off the disk
 shape_strides = [(s, st) for s,st in zip(size, stride) if s != 1]

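This is the "fast bf16 load": read the raw bfloat16 bits, widen to uint32, move them into the high half of a float32, then drop to half, all on the target device instead of a round trip through CPU. The same bit trick in plain numpy (numpy has no bfloat16, so uint16 stands in for the raw storage; the diff multiplies by 1<<16, which is the same shift):

import numpy as np

# raw bfloat16 bit patterns for [1.0, 2.0, 3.0, 4.0]; bfloat16 is the top 16 bits of a float32
bf16_bits = np.array([0x3F80, 0x4000, 0x4040, 0x4080], dtype=np.uint16)

f32 = (bf16_bits.astype(np.uint32) << 16).view(np.float32)  # widen, shift into the high half, reinterpret
print(f32)                     # [1. 2. 3. 4.]
print(f32.astype(np.float16))  # the .half() at the end of the load path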
@@ -40,7 +40,8 @@ class CStyleLanguage(NamedTuple):
 }

 # returns a str expression of the casted xs with the given type
-def render_cast(self, x:List[str], var_dtype:DType) -> str:
+def render_cast(self, x:List[str], var_dtype:DType, bitcast=False) -> str:
+if bitcast: return f"(*(({self.buffer_prefix}{var_dtype.name}*)&{x[0]}))"
 if len(x) == 1: return f"({var_dtype.name})({x[0]})"
 assert len(x) == var_dtype.sz, f"cast is wrong size {len(x)} != {var_dtype.sz}"
 assert self.float4 is not None, "vectorized cast is not supported on this platform"
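The generic C-style fallback emits a pointer pun rather than a value conversion. A hedged standalone stub mirroring the new branch (the class, variable names, and the empty buffer_prefix are invented stand-ins):

class Lang:
  buffer_prefix = ""  # a real language may prepend an address-space qualifier here

def render_bitcast(lang, dtype_name, x):
  # same f-string shape as the new bitcast branch above
  return f"(*(({lang.buffer_prefix}{dtype_name}*)&{x}))"

print(render_bitcast(Lang(), "int", "val0"))  # (*((int*)&val0)) -- reinterprets val0's bytes in the emitted C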
@@ -187,7 +188,7 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> Tu
 kk(lang.render_store(r[vin[0]], vin[0].dtype, r[vin[2]], vin[2].dtype, strip_parens(r[vin[1]]), vin[0].uop == UOps.DEFINE_LOCAL))
 if len(vin) > 3: kk("}")
 elif uop == UOps.CAST and dtype is not None:
-val = lang.render_cast([r[x] for x in vin], dtype)
+val = lang.render_cast([r[x] for x in vin], dtype, bitcast=isinstance(args, tuple) and args[1])
 if child_count[u] <= 1: r[u] = val
 else: kk(f"{lang.generic_var_prefix if lang.generic_var_prefix else dtype.name} {ssa(u,'cast')} = {val};")
 elif uop == UOps.DEFINE_LOCAL:
@@ -224,6 +225,9 @@ class OpenCLLanguage(CStyleLanguage):
 uses_vload = True
 # NOTE: mad is used so the loads aren't reordered into the math on 845
 code_for_op = {**CStyleLanguage().code_for_op, TernaryOps.MULACC: lambda a,b,c,dtype: f"mad({a},{b},{c})"}
+type_map = { dtypes.uint8: "uchar", dtypes.uint32: "uint", dtypes.uint16: "ushort", dtypes.uint64: "ulong" }
+def render_cast(self, x, var_dtype, bitcast=False) -> str:
+return f"as_{self.type_map.get(var_dtype) or var_dtype.name}({x[0]})" if bitcast else super().render_cast(x, var_dtype)
 OpenCLRenderer = functools.partial(uops_to_cstyle, OpenCLLanguage())

 class MetalLanguage(CStyleLanguage):
@@ -237,6 +241,8 @@ class MetalLanguage(CStyleLanguage):
 gid = [f"gid.{chr(120+i)}" for i in range(3)]
 lid = [f"lid.{chr(120+i)}" for i in range(3)]
 extra_args = ['uint3 gid [[threadgroup_position_in_grid]]', 'uint3 lid [[thread_position_in_threadgroup]]']
+def render_cast(self, x: List[str], var_dtype: DType, bitcast=False) -> str:
+return f"as_type<{var_dtype.name}>({x[0]})" if bitcast else super().render_cast(x, var_dtype)
 MetalRenderer = functools.partial(uops_to_cstyle, MetalLanguage())

 class CUDALanguage(CStyleLanguage):
@@ -350,8 +356,8 @@ class WGSLLanguage(CStyleLanguage):
 def render_conditional(self, cond:str, x:str, y:str) -> str:
 return f"select(f32({y}), {x}, bool({cond}))"

-def render_cast(self, x:List[str], var_dtype:DType) -> str:
-if self.type_map[var_dtype]: return f"{self.type_map[var_dtype]}({x[0]})"
+def render_cast(self, x:List[str], var_dtype:DType, bitcast=False) -> str:
+if self.type_map[var_dtype]: return f"bitcast<{self.type_map[var_dtype]}>({x[0]})" if bitcast else f"{self.type_map[var_dtype]}({x[0]})"
 raise NotImplementedError(f"no cast for {var_dtype}")

 def render_store(self, buf_name:str, buf_dtype:DType, var_name:str, var_dtype:DType, idx, local=False) -> str:

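Taken together, the renderers lower the same CAST-with-bitcast UOp to each backend's native reinterpret syntax. An approximate side-by-side for a value x reinterpreted as float32; exact dtype spellings come from each renderer's type tables, so treat these as illustrative:

emitted = {
  "C-style": "(*((float*)&x))",               # pointer pun from CStyleLanguage.render_cast
  "OpenCL":  "as_float(x)",                   # as_<type>() reinterpret builtin
  "Metal":   "as_type<float>(x)",             # Metal as_type<> reinterpret
  "WGSL":    "bitcast<f32>(x)",               # WGSL bitcast<> builtin
  "Triton":  "x.to(tl.float32, bitcast=True)",
  "LLVM":    "bitcast i32 %x to float",       # IR instruction emitted via llvmlite, assuming a 32-bit integer source
}
for backend, expr in emitted.items(): print(f"{backend:8s} {expr}")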
@@ -28,8 +28,9 @@ code_for_op: Final[Dict[Op, Callable]] = {

 dtype_to_llvm_dtype = {dtypes.float64:ir.DoubleType(), dtypes.float16:ir.HalfType(), dtypes.bfloat16:ir.IntType(16), dtypes.float32:ir.FloatType(), dtypes.int8:ir.IntType(8), dtypes.uint8:ir.IntType(8), dtypes.bool: ir.IntType(1), dtypes.int64: ir.IntType(64), dtypes.int32: ir.IntType(32), dtypes._arg_int32: ir.IntType(32), dtypes.int16:ir.IntType(16), dtypes.uint16:ir.IntType(16), dtypes.uint32:ir.IntType(32), dtypes.uint64:ir.IntType(64)}

-def cast(bb, val, input_type, output_type):
+def cast(bb, val, input_type, output_type, bitcast=False):
 if input_type == output_type: return val
+if bitcast: return bb[-1].bitcast(val, dtype_to_llvm_dtype[output_type])

 if dtypes.is_float(input_type):
 if dtypes.is_float(output_type):
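Here bb[-1] is an llvmlite IRBuilder, so the new branch maps straight onto llvmlite's bitcast instruction. A minimal standalone sketch, with an invented function name and signature:

from llvmlite import ir

# build a tiny function float f(i32 x) that reinterprets the i32 bits as a float
module = ir.Module(name="bitcast_sketch")
fnty = ir.FunctionType(ir.FloatType(), [ir.IntType(32)])
fn = ir.Function(module, fnty, name="f")
builder = ir.IRBuilder(fn.append_basic_block("entry"))
builder.ret(builder.bitcast(fn.args[0], ir.FloatType()))  # same builder call the renderer now makes
print(module)  # the printed IR contains a "bitcast i32 ... to float" instruction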
@@ -149,7 +150,7 @@ def uops_to_llvm_ir(function_name:str, uops:List[UOp]) -> Tuple[str, Dict]:
 else: store_op()
 if uop == UOps.ALU:
 lvars[u] = code_for_op[args](bb[-1], *[lvars[x] for x in vin])
-if uop == UOps.CAST: lvars[u] = cast(bb, lvars[vin[0]], vin[0].dtype, dtype)
+if uop == UOps.CAST: lvars[u] = cast(bb, lvars[vin[0]], vin[0].dtype, dtype, bitcast=isinstance(args, tuple) and args[1])

 bb[-1].ret_void()
 return str(module), {}
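End to end, the user-facing entry point exercised by the new tests is Tensor.bitcast. A short usage sketch, with import paths as of this commit; the printed integers follow from the IEEE-754 encodings shown earlier:

from tinygrad.tensor import Tensor
from tinygrad.helpers import dtypes

t = Tensor([1.0, 2.0, 3.0, 4.0], dtype=dtypes.float32)
print(t.bitcast(dtypes.int32).numpy())  # [1065353216 1073741824 1077936128 1082130432]
print(t.cast(dtypes.int32).numpy())     # [1 2 3 4]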