Bitcast support / fast bf16 load (#2011)

* bitcast renderers

* fast llama load

* make it one kernel

* regression testing p1: re-enable test_dtype for all backends

fix GPU

* regression testing p2: fuzz all possible cases against numpy

remove hardcoded tests since the fuzzer covers them

* define ushort

* fix indent, probably need flake8 back for CI to catch

---------

Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
Author: qazal, 2023-12-05 19:19:28 -05:00, committed by GitHub
parent 232ed2af3f
commit be09cc87c1
5 changed files with 26 additions and 28 deletions
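
For orientation before the diffs: a bitcast reinterprets a tensor's underlying bytes as another dtype of the same itemsize, without converting the values. A minimal usage sketch against this revision (the import paths are an assumption about where Tensor and dtypes lived at the time; adjust if they have since moved):

from tinygrad.tensor import Tensor
from tinygrad.helpers import dtypes

# 1.0f is stored as the bit pattern 0x3F800000, so a float32 -> int32 bitcast yields 1065353216
a = Tensor([1.0, 2.0], dtype=dtypes.float32)
print(a.bitcast(dtypes.int32).numpy())  # [1065353216 1073741824]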


@@ -32,8 +32,8 @@ def remove_single_scalar_curly_braces(ptx_code):
 def render_const(args,dtype:DType):
   return (('-' if args<0 else '') + 'tl.where(1,float("inf"),0)') if math.isinf(args) else ('tl.where(1,float("nan"),0)' if math.isnan(args) else f"{int(args)}" if dtypes.is_int(dtype) else str(args))
 
-def render_cast(x:str, dtype:DType):
-  return f"{x}.to({triton_dtypes[dtype]})"
+def render_cast(x:str, dtype:DType, bitcast=False):
+  return f"{x}.to({triton_dtypes[dtype]}, bitcast={bitcast})"
 
 def define_scalar(local_size, dtype, args):
   if len(local_size) > 0: return f"tl.full(({','.join([str(next_power_of_2(x)) for x in local_size])},),{render_const(args,dtype)}, dtype={triton_dtypes[dtype]})"

@@ -108,7 +108,7 @@ def uops_to_triton(function_name:str, uops:List[UOp]):
       kk(f"{args[1]} = tl.arange({0}, {next_power_of_2(args[2])})")
       local_size.append(args[2])
       r[u] = args[1]
-    elif uop == UOps.CAST and dtype is not None: r[u] = render_cast(r[vin[0]], dtype)
+    elif uop == UOps.CAST and dtype is not None: r[u] = render_cast(r[vin[0]], dtype, isinstance(args, tuple) and args[1])
     else: raise NotImplementedError(f"unimplemented: {uop}")
 
   prg = f"import triton\nimport triton.language as tl\ntl.core.TRITON_MAX_TENSOR_NUMEL = float('inf')\n@triton.jit\ndef {function_name}("+','.join(f"{buf[0]}" for buf in bufs)+"):\n"


@@ -42,7 +42,7 @@ def _assert_eq(tensor:Tensor, target_dtype:DType, target):
 def _test_op(fxn, target_dtype:DType, target): _assert_eq(fxn(), target_dtype, target)
 def _test_cast(a:Tensor, target_dtype:DType): _test_op(lambda: a.cast(target_dtype), target_dtype, a.numpy().astype(target_dtype.np).tolist())
-def _test_bitcast(a:Tensor, target_dtype:DType, target): _test_op(lambda: a.bitcast(target_dtype), target_dtype, target)
+def _test_bitcast(a:Tensor, target_dtype:DType, target=None): _test_op(lambda: a.bitcast(target_dtype), target_dtype, target or a.numpy().view(target_dtype.np).tolist())
 
 class TestDType(unittest.TestCase):
   DTYPE: Any = None

@@ -82,6 +82,12 @@ class TestDType(unittest.TestCase):
       lambda dtype: _test_ops(a_dtype=dtype, b_dtype=self.DTYPE) if dtype.itemsize < self.DTYPE.itemsize else None,
       get_available_cast_dtypes(self.DTYPE)
     ))
+  def test_bitcast(self):
+    if self.DTYPE == dtypes.bool: raise unittest.SkipTest("no bools in bitcast")
+    list(map(
+      lambda dtype: _test_bitcast(Tensor(self.DATA, dtype=self.DTYPE), dtype) if dtype.itemsize == self.DTYPE.itemsize and dtype != dtypes.bool else None,
+      get_available_cast_dtypes(self.DTYPE)
+    ))
 
 def _test_ops(a_dtype:DType, b_dtype:DType, target_dtype=None):
   if not is_dtype_supported(a_dtype) or not is_dtype_supported(b_dtype): return

@@ -140,21 +146,7 @@ class TestUint8Dtype(TestDType):
   @unittest.skipIf(getenv("CUDA",0)==1 or getenv("PTX", 0)==1, "cuda saturation works differently")
   def test_uint8_to_int8_overflow(self): _test_op(lambda: Tensor([255, 254, 253, 252], dtype=dtypes.uint8).cast(dtypes.int8), dtypes.int8, [-1, -2, -3, -4])
 
-@unittest.skipIf(Device.DEFAULT not in {"CPU", "TORCH"}, "only bitcast in CPU and TORCH")
 class TestBitCast(unittest.TestCase):
-  def test_float32_bitcast_to_int32(self): _test_bitcast(Tensor([1,2,3,4], dtype=dtypes.float32), dtypes.int32, [1065353216, 1073741824, 1077936128, 1082130432])
-  @unittest.skipIf(Device.DEFAULT == "TORCH", "no uint32 in torch")
-  def test_float32_bitcast_to_uint32(self): _test_bitcast(Tensor([1,2,3,4], dtype=dtypes.float32), dtypes.uint32, [1065353216, 1073741824, 1077936128, 1082130432])
-  def test_int32_bitcast_to_float32(self): _test_bitcast(Tensor([1065353216, 1073741824, 1077936128, 1082130432], dtype=dtypes.int32), dtypes.float32, [1.0, 2.0, 3.0, 4.0])
-
-  # NOTE: these are the same as normal casts
-  def test_int8_bitcast_to_uint8(self): _test_bitcast(Tensor([-1, -2, -3, -4], dtype=dtypes.int8), dtypes.uint8, [255, 254, 253, 252])
-  def test_uint8_bitcast_to_int8(self): _test_bitcast(Tensor([255, 254, 253, 252], dtype=dtypes.uint8), dtypes.int8, [-1, -2, -3, -4])
-
-  @unittest.skipIf(Device.DEFAULT == "TORCH", "no uint64 in torch")
-  def test_int64_bitcast_to_uint64(self): _test_bitcast(Tensor([-1, -2, -3, -4], dtype=dtypes.int64), dtypes.uint64, [18446744073709551615, 18446744073709551614, 18446744073709551613, 18446744073709551612])
-  @unittest.skipIf(Device.DEFAULT == "TORCH", "no uint64 in torch")
-  def test_uint64_bitcast_to_int64(self): _test_bitcast(Tensor([18446744073709551615, 18446744073709551614, 18446744073709551613, 18446744073709551612], dtype=dtypes.uint64), dtypes.int64, [-1, -2, -3, -4])
   def test_shape_change_bitcast(self):
     with self.assertRaises(AssertionError):
       _test_bitcast(Tensor([100000], dtype=dtypes.float32), dtypes.uint8, [100000])
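
The fuzzer's oracle is numpy's ndarray.view, which reinterprets the same buffer under another dtype of equal itemsize, so every same-size dtype pair gets checked without hand-written expected values. A standalone numpy illustration of the values it expects (plain numpy, no tinygrad):

import numpy as np

data = np.array([1, 2, 3, 4], dtype=np.float32)
print(data.view(np.int32))   # [1065353216 1073741824 1077936128 1082130432]
print(data.view(np.uint32))  # same bit patterns, viewed unsigned
# a float32 -> uint8 view would change the shape; Tensor.bitcast asserts on that case
# (see test_shape_change_bitcast above)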


@@ -71,14 +71,13 @@ def torch_load(fn:str):
     lens[storage[2]] = storage[4] * storage[1].itemsize
     if storage[2] not in offsets: return None
     byte_offset = offsets[storage[2]]+storage_offset*storage[1].itemsize
-    ret = t[byte_offset:byte_offset+prod(size)].cast(storage[1])
+    ret = t[byte_offset:byte_offset+prod(size)]
 
     # convert bfloat16 -> float16 using LLVM for Llama 2
     # upstream LLaMA also does this conversion:
     # https://github.com/facebookresearch/llama/blob/6c7fe276574e78057f917549435a2554000a876d/llama/generation.py#L95
     # TODO: should this be done in the example instead? or maybe we don't need this anymore with better bfloat16 support
-    if storage[1] == dtypes.bfloat16:
-      ret = ret.bitcast(dtypes.uint16).to("CPU").cast(dtypes.uint32).mul(1<<16).bitcast(dtypes.float32).to(Device.DEFAULT).half()
-      #ret = ret.to("LLVM").half().to(Device.DEFAULT)
+    if storage[1] == dtypes.bfloat16: ret = ret.cast(dtypes.uint16).to(Device.DEFAULT).cast(dtypes.uint32).mul(1<<16).bitcast(dtypes.float32).half()
+    else: ret = ret.cast(storage[1])
 
     # 7 lines to deal with permuted tensors. NOTE: this currently requires reading off the disk
     shape_strides = [(s, st) for s,st in zip(size, stride) if s != 1]
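
The single-kernel bfloat16 path above relies on bfloat16 being the top 16 bits of a float32: widen the raw 16-bit words to 32 bits, shift left by 16, bitcast to float32, then narrow to half. A standalone numpy sketch of the same bit manipulation (illustration only, with values chosen to be exactly representable):

import numpy as np

# raw bfloat16 storage words for [1.0, -2.5, 3.140625] (top 16 bits of their float32 encodings)
bf16_words = np.array([0x3F80, 0xC020, 0x4049], dtype=np.uint16)
f32 = (bf16_words.astype(np.uint32) << 16).view(np.float32)  # cast to uint32, mul by 1<<16, bitcast
print(f32)                     # [ 1.       -2.5       3.140625]
print(f32.astype(np.float16))  # the final .half() step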


@@ -40,7 +40,8 @@ class CStyleLanguage(NamedTuple):
   }
 
   # returns a str expression of the casted xs with the given type
-  def render_cast(self, x:List[str], var_dtype:DType) -> str:
+  def render_cast(self, x:List[str], var_dtype:DType, bitcast=False) -> str:
+    if bitcast: return f"(*(({self.buffer_prefix}{var_dtype.name}*)&{x[0]}))"
     if len(x) == 1: return f"({var_dtype.name})({x[0]})"
     assert len(x) == var_dtype.sz, f"cast is wrong size {len(x)} != {var_dtype.sz}"
     assert self.float4 is not None, "vectorized cast is not supported on this platform"

@@ -187,7 +188,7 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> Tu
       kk(lang.render_store(r[vin[0]], vin[0].dtype, r[vin[2]], vin[2].dtype, strip_parens(r[vin[1]]), vin[0].uop == UOps.DEFINE_LOCAL))
       if len(vin) > 3: kk("}")
     elif uop == UOps.CAST and dtype is not None:
-      val = lang.render_cast([r[x] for x in vin], dtype)
+      val = lang.render_cast([r[x] for x in vin], dtype, bitcast=isinstance(args, tuple) and args[1])
       if child_count[u] <= 1: r[u] = val
       else: kk(f"{lang.generic_var_prefix if lang.generic_var_prefix else dtype.name} {ssa(u,'cast')} = {val};")
     elif uop == UOps.DEFINE_LOCAL:

@@ -224,6 +225,9 @@ class OpenCLLanguage(CStyleLanguage):
   uses_vload = True
   # NOTE: mad is used so the loads aren't reordered into the math on 845
   code_for_op = {**CStyleLanguage().code_for_op, TernaryOps.MULACC: lambda a,b,c,dtype: f"mad({a},{b},{c})"}
+  type_map = { dtypes.uint8: "uchar", dtypes.uint32: "uint", dtypes.uint16: "ushort", dtypes.uint64: "ulong" }
+  def render_cast(self, x, var_dtype, bitcast=False) -> str:
+    return f"as_{self.type_map.get(var_dtype) or var_dtype.name}({x[0]})" if bitcast else super().render_cast(x, var_dtype)
 OpenCLRenderer = functools.partial(uops_to_cstyle, OpenCLLanguage())
 
 class MetalLanguage(CStyleLanguage):

@@ -237,6 +241,8 @@ class MetalLanguage(CStyleLanguage):
   gid = [f"gid.{chr(120+i)}" for i in range(3)]
   lid = [f"lid.{chr(120+i)}" for i in range(3)]
   extra_args = ['uint3 gid [[threadgroup_position_in_grid]]', 'uint3 lid [[thread_position_in_threadgroup]]']
+  def render_cast(self, x: List[str], var_dtype: DType, bitcast=False) -> str:
+    return f"as_type<{var_dtype.name}>({x[0]})" if bitcast else super().render_cast(x, var_dtype)
 MetalRenderer = functools.partial(uops_to_cstyle, MetalLanguage())
 
 class CUDALanguage(CStyleLanguage):

@@ -350,8 +356,8 @@ class WGSLLanguage(CStyleLanguage):
   def render_conditional(self, cond:str, x:str, y:str) -> str:
     return f"select(f32({y}), {x}, bool({cond}))"
 
-  def render_cast(self, x:List[str], var_dtype:DType) -> str:
-    if self.type_map[var_dtype]: return f"{self.type_map[var_dtype]}({x[0]})"
+  def render_cast(self, x:List[str], var_dtype:DType, bitcast=False) -> str:
+    if self.type_map[var_dtype]: return f"bitcast<{self.type_map[var_dtype]}>({x[0]})" if bitcast else f"{self.type_map[var_dtype]}({x[0]})"
     raise NotImplementedError(f"no cast for {var_dtype}")
 
   def render_store(self, buf_name:str, buf_dtype:DType, var_name:str, var_dtype:DType, idx, local=False) -> str:
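
Taken together, the per-backend spellings of the same bitcast can be checked interactively. A rough sketch (the import paths, the empty default buffer_prefix, and WGSL's f32 type_map entry are assumptions about this revision):

from tinygrad.helpers import dtypes
from tinygrad.renderer.cstyle import CStyleLanguage, OpenCLLanguage, MetalLanguage, WGSLLanguage

for lang in (CStyleLanguage(), OpenCLLanguage(), MetalLanguage(), WGSLLanguage()):
  print(f"{lang.__class__.__name__:16s}", lang.render_cast(["acc0"], dtypes.float32, bitcast=True))
# expected, roughly:
#   CStyleLanguage   (*((float*)&acc0))
#   OpenCLLanguage   as_float(acc0)
#   MetalLanguage    as_type<float>(acc0)
#   WGSLLanguage     bitcast<f32>(acc0)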


@@ -28,8 +28,9 @@ code_for_op: Final[Dict[Op, Callable]] = {
 dtype_to_llvm_dtype = {dtypes.float64:ir.DoubleType(), dtypes.float16:ir.HalfType(), dtypes.bfloat16:ir.IntType(16), dtypes.float32:ir.FloatType(), dtypes.int8:ir.IntType(8), dtypes.uint8:ir.IntType(8), dtypes.bool: ir.IntType(1), dtypes.int64: ir.IntType(64), dtypes.int32: ir.IntType(32), dtypes._arg_int32: ir.IntType(32), dtypes.int16:ir.IntType(16), dtypes.uint16:ir.IntType(16), dtypes.uint32:ir.IntType(32), dtypes.uint64:ir.IntType(64)}
 
-def cast(bb, val, input_type, output_type):
+def cast(bb, val, input_type, output_type, bitcast=False):
   if input_type == output_type: return val
+  if bitcast: return bb[-1].bitcast(val, dtype_to_llvm_dtype[output_type])
 
   if dtypes.is_float(input_type):
     if dtypes.is_float(output_type):

@@ -149,7 +150,7 @@ def uops_to_llvm_ir(function_name:str, uops:List[UOp]) -> Tuple[str, Dict]:
      else: store_op()
    if uop == UOps.ALU:
      lvars[u] = code_for_op[args](bb[-1], *[lvars[x] for x in vin])
-    if uop == UOps.CAST: lvars[u] = cast(bb, lvars[vin[0]], vin[0].dtype, dtype)
+    if uop == UOps.CAST: lvars[u] = cast(bb, lvars[vin[0]], vin[0].dtype, dtype, bitcast=isinstance(args, tuple) and args[1])
 
   bb[-1].ret_void()
   return str(module), {}
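
On the LLVM backend, the bitcast maps straight onto llvmlite's IRBuilder.bitcast, which reinterprets the SSA value at no runtime cost. A minimal standalone llvmlite sketch of the same instruction (not the tinygrad code path):

from llvmlite import ir

module = ir.Module(name="bitcast_demo")
fn = ir.Function(module, ir.FunctionType(ir.FloatType(), [ir.IntType(32)]), name="int_bits_to_float")
bb = ir.IRBuilder(fn.append_basic_block("entry"))
# same shape as cast(..., bitcast=True) above: reinterpret an i32 value as a float
bb.ret(bb.bitcast(fn.args[0], ir.FloatType()))
print(module)  # prints the textual LLVM IR containing a 'bitcast i32 ... to float'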