bfloat16 in LLVM (enough for llama 2) (#1293)
* add bf16 support to LLVM
* bf16 read works
parent 74e63fe4ee
commit ca77d6cd72
@@ -3,7 +3,7 @@ import numpy as np
 from tinygrad.helpers import getenv, DType, DEBUG
 from tinygrad.lazy import Device
 from tinygrad.tensor import Tensor, dtypes
-from extra.utils import OSX
+from extra.utils import OSX, temp
 
 def _test_to_np(a:Tensor, np_dtype, target):
   print(a)
@@ -26,6 +26,39 @@ def _test_add_upcast(a:Tensor, b:Tensor, target_dtype:DType, target): _test_op(l
 def _test_mul_upcast(a:Tensor, b:Tensor, target_dtype:DType, target): _test_op(lambda: a*b, target_dtype, target)
 def _test_matmul_upcast(a:Tensor, b:Tensor, target_dtype:DType, target): _test_op(lambda: a@b, target_dtype, target)
 
+class TestBFloat16DType(unittest.TestCase):
+  def test_bf16_to_float(self):
+    with self.assertRaises(AssertionError):
+      _test_cast(Tensor([100000], dtype=dtypes.bfloat16), dtypes.float32, [100000])
+
+  def test_float_to_bf16(self):
+    with self.assertRaises(AssertionError):
+      _test_cast(Tensor([100000], dtype=dtypes.float32), dtypes.bfloat16, [100000])
+
+  # torch.tensor([10000, -1, -1000, -10000, 20]).type(torch.bfloat16)
+
+  @unittest.skipIf(Device.DEFAULT not in ["LLVM"], "bf16 only on LLVM")
+  def test_bf16(self):
+    t = Tensor([10000, -1, -1000, -10000, 20]).cast(dtypes.bfloat16)
+    t.realize()
+    back = t.cast(dtypes.float32)
+    assert tuple(back.numpy().tolist()) == (9984., -1, -1000, -9984, 20)
+
+  @unittest.skipIf(Device.DEFAULT not in ["LLVM"], "bf16 only on LLVM")
+  def test_bf16_disk_write_read(self):
+    t = Tensor([10000, -1, -1000, -10000, 20]).cast(dtypes.float32)
+    t.to(f"disk:{temp('f32')}").realize()
+
+    # hack to "cast" f32 -> bf16
+    dat = open(temp('f32'), "rb").read()
+    adat = b''.join([dat[i+2:i+4] for i in range(0, len(dat), 4)])
+    with open(temp('bf16'), "wb") as f: f.write(adat)
+
+    t = Tensor.empty(5, dtype=dtypes.bfloat16, device=f"disk:{temp('bf16')}").llvm().realize()
+    back = t.cast(dtypes.float32)
+    assert tuple(back.numpy().tolist()) == (9984., -1, -1000, -9984, 20)
+
 # for GPU, cl_khr_fp16 isn't supported (except now we don't need it!)
 # for LLVM, it segfaults because it can't link to the casting function
 @unittest.skipIf((getenv("CI", "") != "" and Device.DEFAULT in ["LLVM"]) or Device.DEFAULT == "WEBGPU", "float16 broken in some CI backends")
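Note (not part of the diff): the expected tuple (9984., -1, -1000, -9984, 20) comes from dropping the low 16 bits of each float32, which is what the byte-slicing hack in test_bf16_disk_write_read does (that hack assumes a little-endian host). A minimal numpy sketch that reproduces the same truncation with integer shifts:

```python
import numpy as np

f32 = np.array([10000, -1, -1000, -10000, 20], dtype=np.float32)
bits = f32.view(np.uint32)
bf16_bits = (bits >> 16).astype(np.uint16)                    # keep sign, exponent, and top 7 mantissa bits
back = (bf16_bits.astype(np.uint32) << 16).view(np.float32)   # refill the low 16 bits with zeros
print(back.tolist())  # [9984.0, -1.0, -1000.0, -9984.0, 20.0]
```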
@@ -41,7 +41,7 @@ def uops_to_llvm_ir(uops:List[UOp]) -> str:
   buf_index = {x:i for i,x in enumerate(buf_to_dtype.keys())}
 
   # create llvm function
-  dtype_to_llvm_dtype = {dtypes.float16:ir.HalfType(), dtypes.float32:ir.FloatType(), dtypes.int8:ir.IntType(8), dtypes.uint8:ir.IntType(8), dtypes.bool: ir.IntType(1), dtypes.int64: ir.IntType(64), dtypes.int32: ir.IntType(32)}
+  dtype_to_llvm_dtype = {dtypes.float16:ir.HalfType(), dtypes.bfloat16:ir.IntType(16), dtypes.float32:ir.FloatType(), dtypes.int8:ir.IntType(8), dtypes.uint8:ir.IntType(8), dtypes.bool: ir.IntType(1), dtypes.int64: ir.IntType(64), dtypes.int32: ir.IntType(32)}
   func_dtypes = [dtype_to_llvm_dtype[dtype] for dtype in buf_to_dtype.values()]
   func = ir.Function(module, ir.FunctionType(ir.VoidType(), [x.as_pointer() for x in func_dtypes]), name='exec')
 
@@ -104,6 +104,10 @@ def uops_to_llvm_ir(uops:List[UOp]) -> str:
         if args.memory_dtype != newvar.dtype:
           if dtypes.is_int(args.memory_dtype):
             val = bb[-1].uitofp(val, ir.FloatType()) if dtypes.is_unsigned(args.memory_dtype) else bb[-1].sitofp(val, ir.FloatType())
+          elif args.memory_dtype == dtypes.bfloat16:
+            val = bb[-1].sext(val, ir.IntType(32))
+            val = bb[-1].shl(val, ir.Constant(ir.IntType(32), 16))
+            val = bb[-1].bitcast(val, ir.FloatType())
           else:
             val = bb[-1].fpext(val, ir.FloatType())
         lvars[newvar] = val
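Note (not part of the diff): because bf16 buffers are typed as plain i16 in dtype_to_llvm_dtype above, a load widens to float32 by placing the 16 raw bits in the high half of an i32 and bitcasting. A standalone llvmlite sketch of that sequence (module and function names are illustrative, not tinygrad's):

```python
from llvmlite import ir

module = ir.Module(name="bf16_demo")
fn = ir.Function(module, ir.FunctionType(ir.FloatType(), [ir.IntType(16).as_pointer()]), name="bf16_load")
bb = ir.IRBuilder(fn.append_basic_block("entry"))

raw = bb.load(fn.args[0])                                    # i16 holding the raw bf16 bits
widened = bb.sext(raw, ir.IntType(32))                       # widen to i32
shifted = bb.shl(widened, ir.Constant(ir.IntType(32), 16))   # move the bits into the float32 high half
val = bb.bitcast(shifted, ir.FloatType())                    # reinterpret as float32
bb.ret(val)
print(module)                                                # textual LLVM IR
```

Since the shift left by 16 overwrites the upper half anyway, sext and zext are interchangeable here.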
@@ -114,6 +118,10 @@ def uops_to_llvm_ir(uops:List[UOp]) -> str:
         if args.memory_dtype != vin[0].dtype:
           if dtypes.is_int(args.memory_dtype):
             element = bb[-1].fptoui(element, dtype_to_llvm_dtype[args.memory_dtype]) if dtypes.is_unsigned(args.memory_dtype) else bb[-1].fptosi(element, dtype_to_llvm_dtype[args.memory_dtype])
+          elif args.memory_dtype == dtypes.bfloat16:
+            element = bb[-1].bitcast(element, ir.IntType(32))
+            element = bb[-1].lshr(element, ir.Constant(ir.IntType(32), 16))
+            element = bb[-1].trunc(element, ir.IntType(16))
           else:
             element = bb[-1].fptrunc(element, dtype_to_llvm_dtype[args.memory_dtype])
         bb[-1].store(element, bb[-1].gep(func.args[buf_index[args.name]], [idx], inbounds=True))
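Note (not part of the diff): the store path is the mirror image: bitcast float32 to i32, shift the high half down, truncate to i16. That is plain truncation, not round-to-nearest. An llvmlite sketch under the same assumptions as the load sketch above:

```python
from llvmlite import ir

module = ir.Module(name="bf16_demo_store")
fn = ir.Function(module, ir.FunctionType(ir.VoidType(), [ir.IntType(16).as_pointer(), ir.FloatType()]), name="bf16_store")
bb = ir.IRBuilder(fn.append_basic_block("entry"))

bits = bb.bitcast(fn.args[1], ir.IntType(32))            # reinterpret float32 as i32
high = bb.lshr(bits, ir.Constant(ir.IntType(32), 16))    # keep the top 16 bits
bf16 = bb.trunc(high, ir.IntType(16))                    # narrow to i16, the raw bf16 bits
bb.store(bf16, fn.args[0])
bb.ret_void()
print(module)
```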
@@ -101,6 +101,9 @@ class dtypes:
   uint32: Final[DType] = DType(1, 4, "unsigned int", np.uint32)
   uint64: Final[DType] = DType(2, 8, "unsigned long", np.uint64)
 
+  # NOTE: bfloat16 isn't supported in numpy
+  bfloat16: Final[DType] = DType(0, 2, "__bf16", None)
+
   # NOTE: these are internal dtypes, should probably check for that
   _half4: Final[DType] = DType(0, 2*4, "half4", None, 4)
   _float2: Final[DType] = DType(4, 4*2, "float2", None, 2)
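Note (not part of the diff): numpy has no bfloat16, so the new DType carries None where the other dtypes carry a numpy type. A tiny sketch of what that looks like; the attribute names (itemsize, np, name) are assumed from tinygrad's DType namedtuple fields:

```python
from tinygrad.tensor import dtypes

print(dtypes.bfloat16.itemsize)  # 2, two bytes per element, same width as float16
print(dtypes.bfloat16.np)        # None, no numpy counterpart, which the new asserts below guard against
print(dtypes.bfloat16.name)      # "__bf16"
```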
@@ -185,6 +185,7 @@ class LazyBuffer:
 
   # NOTE: we also have to copy the numpy array on the way out...otherwise the underlying Tensor could be freed and use after free. improve this?
   def toCPU(self):
+    assert self.dtype.np, "numpy dtype is required for toCPU"
     realized = self.cast(dtypes.from_np(self.dtype.np)).contiguous().realize().realized
     ret = cast(RawBuffer, realized).toCPU().reshape(self.shape)
     return ret
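Note (not part of the diff): the new assert makes the failure explicit: a bf16 tensor can't be read back into numpy directly and has to be cast to a numpy-representable dtype first, which is what the tests do. A small usage sketch, assuming the LLVM backend (matching the skipIf in the tests):

```python
from tinygrad.tensor import Tensor, dtypes

t = Tensor([10000, -1, -1000, -10000, 20]).cast(dtypes.bfloat16).realize()
# t.numpy() would trip the new assert, since dtypes.bfloat16.np is None
print(t.cast(dtypes.float32).numpy())  # the truncated values from the test, e.g. 9984.0 for 10000
```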
@@ -39,7 +39,7 @@ class RawBufferMapped(RawBufferCopyIn):
 
 # this one is simple enough that i moved it out of the runtimes
 class RawMallocBuffer(RawBufferMapped):
-  def __init__(self, size, dtype: DType): super().__init__(size, dtype, ({dtypes.float32: ctypes.c_float, dtypes.float16: ctypes.c_int16, dtypes.int8: ctypes.c_int8, dtypes.uint8: ctypes.c_uint8, dtypes.bool: ctypes.c_uint8, dtypes.int32: ctypes.c_int32, dtypes.int64: ctypes.c_int64}[dtype] * size)())
+  def __init__(self, size, dtype: DType): super().__init__(size, dtype, ({dtypes.float32: ctypes.c_float, dtypes.float16: ctypes.c_int16, dtypes.bfloat16: ctypes.c_int16, dtypes.int8: ctypes.c_int8, dtypes.uint8: ctypes.c_uint8, dtypes.bool: ctypes.c_uint8, dtypes.int32: ctypes.c_int32, dtypes.int64: ctypes.c_int64}[dtype] * size)())
   def _buffer(self): return memoryview(self._buf)
 
 class RawBufferCopyInOut(RawBufferCopyIn):
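Note (not part of the diff): ctypes has no bf16 type either, so malloc-backed bf16 buffers reuse c_int16, the same storage the float16 entry already uses; the bytes are treated as raw bit patterns, not numbers. For illustration only:

```python
import ctypes

buf = (ctypes.c_int16 * 5)()   # what RawMallocBuffer would allocate for a 5-element bf16 buffer
print(ctypes.sizeof(buf))      # 10 bytes of raw bf16 bits
```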
@@ -63,10 +63,10 @@ class Tensor:
       return
 
     if data.__class__ is list:
+      assert dtype is None or dtype.np is not None, f"{dtype} doesn't have a numpy dtype"
       data = np.array(data, dtype=(dtype or Tensor.default_type).np)
 
-    if data.__class__ is np.ndarray:
-      data = cast(np.ndarray, data)
+    if isinstance(data, np.ndarray):
       data = LazyBuffer.fromCPU(data)
       self.lazydata = data if data.device == device else LazyBuffer.loadop(LoadOps.FROM, data.shape, data.dtype, device, src=data)
       return
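Note (not part of the diff): the added assert turns list construction with a numpy-less dtype into a clear error instead of a confusing np.array failure; it is what test_bf16_to_float above expects. Sketch:

```python
from tinygrad.tensor import Tensor, dtypes

try:
  Tensor([100000], dtype=dtypes.bfloat16)   # list data is staged through numpy, and bf16 has no numpy dtype
except AssertionError as e:
  print(e)                                  # "... doesn't have a numpy dtype"

# the supported route is to build from a numpy-backed dtype and cast
t = Tensor([100000]).cast(dtypes.bfloat16)
```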