bfloat16 in LLVM (enough for llama 2) (#1293)

* add bf16 support to LLVM

* bf16 read works
George Hotz 2023-07-19 20:18:32 -07:00 committed by GitHub
parent 74e63fe4ee
commit ca77d6cd72
6 changed files with 50 additions and 5 deletions


@@ -3,7 +3,7 @@ import numpy as np
from tinygrad.helpers import getenv, DType, DEBUG
from tinygrad.lazy import Device
from tinygrad.tensor import Tensor, dtypes
from extra.utils import OSX
from extra.utils import OSX, temp
def _test_to_np(a:Tensor, np_dtype, target):
  print(a)
@@ -26,6 +26,39 @@ def _test_add_upcast(a:Tensor, b:Tensor, target_dtype:DType, target): _test_op(lambda: a+b, target_dtype, target)
def _test_mul_upcast(a:Tensor, b:Tensor, target_dtype:DType, target): _test_op(lambda: a*b, target_dtype, target)
def _test_matmul_upcast(a:Tensor, b:Tensor, target_dtype:DType, target): _test_op(lambda: a@b, target_dtype, target)

class TestBFloat16DType(unittest.TestCase):
  def test_bf16_to_float(self):
    with self.assertRaises(AssertionError):
      _test_cast(Tensor([100000], dtype=dtypes.bfloat16), dtypes.float32, [100000])

  def test_float_to_bf16(self):
    with self.assertRaises(AssertionError):
      _test_cast(Tensor([100000], dtype=dtypes.float32), dtypes.bfloat16, [100000])

  # torch.tensor([10000, -1, -1000, -10000, 20]).type(torch.bfloat16)

  @unittest.skipIf(Device.DEFAULT not in ["LLVM"], "bf16 only on LLVM")
  def test_bf16(self):
    t = Tensor([10000, -1, -1000, -10000, 20]).cast(dtypes.bfloat16)
    t.realize()
    back = t.cast(dtypes.float32)
    assert tuple(back.numpy().tolist()) == (9984., -1, -1000, -9984, 20)

  @unittest.skipIf(Device.DEFAULT not in ["LLVM"], "bf16 only on LLVM")
  def test_bf16_disk_write_read(self):
    t = Tensor([10000, -1, -1000, -10000, 20]).cast(dtypes.float32)
    t.to(f"disk:{temp('f32')}").realize()

    # hack to "cast" f32 -> bf16
    dat = open(temp('f32'), "rb").read()
    adat = b''.join([dat[i+2:i+4] for i in range(0, len(dat), 4)])
    with open(temp('bf16'), "wb") as f: f.write(adat)

    t = Tensor.empty(5, dtype=dtypes.bfloat16, device=f"disk:{temp('bf16')}").llvm().realize()
    back = t.cast(dtypes.float32)
    assert tuple(back.numpy().tolist()) == (9984., -1, -1000, -9984, 20)
# for GPU, cl_khr_fp16 isn't supported (except now we don't need it!)
# for LLVM, it segfaults because it can't link to the casting function
@unittest.skipIf((getenv("CI", "") != "" and Device.DEFAULT in ["LLVM"]) or Device.DEFAULT == "WEBGPU", "float16 broken in some CI backends")
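
For reference (not part of the diff): the byte-slice "cast" in test_bf16_disk_write_read works because, on a little-endian host, a float32 keeps its sign, exponent, and top mantissa bits in the last two of its four bytes, so taking dat[i+2:i+4] is exactly a truncating float32 -> bfloat16 conversion. A standalone numpy check of the same round trip, reproducing the expected values from the tests:

import numpy as np

vals = np.array([10000, -1, -1000, -10000, 20], dtype=np.float32)
dat = vals.tobytes()                                            # native little-endian float32 bytes
bf16 = b''.join(dat[i+2:i+4] for i in range(0, len(dat), 4))    # keep the high 2 bytes of each float

# widen back: put the 16 kept bits into the high half of a uint32 and reinterpret as float32
hi = np.frombuffer(bf16, dtype=np.uint16).astype(np.uint32)
print((hi << np.uint32(16)).view(np.float32).tolist())          # [9984.0, -1.0, -1000.0, -9984.0, 20.0]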


@@ -41,7 +41,7 @@ def uops_to_llvm_ir(uops:List[UOp]) -> str:
  buf_index = {x:i for i,x in enumerate(buf_to_dtype.keys())}
  # create llvm function
  dtype_to_llvm_dtype = {dtypes.float16:ir.HalfType(), dtypes.float32:ir.FloatType(), dtypes.int8:ir.IntType(8), dtypes.uint8:ir.IntType(8), dtypes.bool: ir.IntType(1), dtypes.int64: ir.IntType(64), dtypes.int32: ir.IntType(32)}
  dtype_to_llvm_dtype = {dtypes.float16:ir.HalfType(), dtypes.bfloat16:ir.IntType(16), dtypes.float32:ir.FloatType(), dtypes.int8:ir.IntType(8), dtypes.uint8:ir.IntType(8), dtypes.bool: ir.IntType(1), dtypes.int64: ir.IntType(64), dtypes.int32: ir.IntType(32)}
  func_dtypes = [dtype_to_llvm_dtype[dtype] for dtype in buf_to_dtype.values()]
  func = ir.Function(module, ir.FunctionType(ir.VoidType(), [x.as_pointer() for x in func_dtypes]), name='exec')
@@ -104,6 +104,10 @@ def uops_to_llvm_ir(uops:List[UOp]) -> str:
      if args.memory_dtype != newvar.dtype:
        if dtypes.is_int(args.memory_dtype):
          val = bb[-1].uitofp(val, ir.FloatType()) if dtypes.is_unsigned(args.memory_dtype) else bb[-1].sitofp(val, ir.FloatType())
        elif args.memory_dtype == dtypes.bfloat16:
          val = bb[-1].sext(val, ir.IntType(32))
          val = bb[-1].shl(val, ir.Constant(ir.IntType(32), 16))
          val = bb[-1].bitcast(val, ir.FloatType())
        else:
          val = bb[-1].fpext(val, ir.FloatType())
      lvars[newvar] = val
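
For reference (not part of the diff): bfloat16 is mapped to i16 above, presumably because llvmlite exposes no bfloat IR type, so the load widens the bits by hand: shift them into the high half of an i32 and bitcast to float. The diff sign-extends before the shift; since the shift pushes those upper bits out again, a zero-extend gives the same float bits. A standalone llvmlite sketch of the widening (module and function names here are illustrative only):

from llvmlite import ir

mod = ir.Module(name="bf16_demo")
fn = ir.Function(mod, ir.FunctionType(ir.FloatType(), [ir.IntType(16)]), name="bf16_to_f32")
bb = ir.IRBuilder(fn.append_basic_block())
x = bb.zext(fn.args[0], ir.IntType(32))          # i16 -> i32 (sext in the diff; the shift hides the difference)
x = bb.shl(x, ir.Constant(ir.IntType(32), 16))   # move the bf16 payload into the float32 high half
bb.ret(bb.bitcast(x, ir.FloatType()))            # reinterpret the i32 bits as float
print(mod)                                       # prints the textual LLVM IR
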
@@ -114,6 +118,10 @@ def uops_to_llvm_ir(uops:List[UOp]) -> str:
      if args.memory_dtype != vin[0].dtype:
        if dtypes.is_int(args.memory_dtype):
          element = bb[-1].fptoui(element, dtype_to_llvm_dtype[args.memory_dtype]) if dtypes.is_unsigned(args.memory_dtype) else bb[-1].fptosi(element, dtype_to_llvm_dtype[args.memory_dtype])
        elif args.memory_dtype == dtypes.bfloat16:
          element = bb[-1].bitcast(element, ir.IntType(32))
          element = bb[-1].lshr(element, ir.Constant(ir.IntType(32), 16))
          element = bb[-1].trunc(element, ir.IntType(16))
        else:
          element = bb[-1].fptrunc(element, dtype_to_llvm_dtype[args.memory_dtype])
      bb[-1].store(element, bb[-1].gep(func.args[buf_index[args.name]], [idx], inbounds=True))
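
For reference (not part of the diff): the store path is the exact inverse of the load, a plain truncating float32 -> bfloat16 conversion (bitcast to i32, shift the high half down, truncate to i16), with no round-to-nearest step. Dropping the low mantissa bits is why 10000.0 reads back as 9984.0 in the tests. The same conversion as a tiny Python reference model:

import struct

def f32_to_bf16_bits(x: float) -> int:
  (bits,) = struct.unpack("<I", struct.pack("<f", x))
  return bits >> 16                                  # drop the low 16 mantissa bits, no rounding

def bf16_bits_to_f32(b: int) -> float:
  return struct.unpack("<f", struct.pack("<I", b << 16))[0]

print(bf16_bits_to_f32(f32_to_bf16_bits(10000.0)))   # 9984.0, the value the tests expect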


@@ -101,6 +101,9 @@ class dtypes:
  uint32: Final[DType] = DType(1, 4, "unsigned int", np.uint32)
  uint64: Final[DType] = DType(2, 8, "unsigned long", np.uint64)

  # NOTE: bfloat16 isn't supported in numpy
  bfloat16: Final[DType] = DType(0, 2, "__bf16", None)

  # NOTE: these are internal dtypes, should probably check for that
  _half4: Final[DType] = DType(0, 2*4, "half4", None, 4)
  _float2: Final[DType] = DType(4, 4*2, "float2", None, 2)
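
Since the new entry carries no numpy dtype (its np field is None), anything that hands data to numpy has to cast to a numpy-backed dtype first. A quick, purely illustrative way to poke at the new entry (assumes a tinygrad checkout at this commit):

from tinygrad.tensor import dtypes

print(dtypes.bfloat16)      # the 2-byte "__bf16" entry added above
print(dtypes.bfloat16.np)   # None -- so .numpy() paths must go through a numpy-backed dtype first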


@@ -185,6 +185,7 @@ class LazyBuffer:
  # NOTE: we also have to copy the numpy array on the way out...otherwise the underlying Tensor could be freed and use after free. improve this?
  def toCPU(self):
    assert self.dtype.np, "numpy dtype is required for toCPU"
    realized = self.cast(dtypes.from_np(self.dtype.np)).contiguous().realize().realized
    ret = cast(RawBuffer, realized).toCPU().reshape(self.shape)
    return ret
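
In practice this is why reading bfloat16 data back goes through an explicit float32 cast, exactly as the tests above do: calling .numpy() on a bf16 tensor directly would hit the new assert, since dtype.np is None. A minimal usage sketch (LLVM backend assumed, mirroring test_bf16):

from tinygrad.tensor import Tensor, dtypes

t = Tensor([10000, -1, -1000, -10000, 20]).cast(dtypes.bfloat16).realize()
print(t.cast(dtypes.float32).numpy())   # bf16 precision: 10000/-10000 come back as 9984/-9984
# t.numpy() without the cast would fail the "numpy dtype is required for toCPU" assert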


@@ -39,7 +39,7 @@ class RawBufferMapped(RawBufferCopyIn):
# this one is simple enough that i moved it out of the runtimes
class RawMallocBuffer(RawBufferMapped):
  def __init__(self, size, dtype: DType): super().__init__(size, dtype, ({dtypes.float32: ctypes.c_float, dtypes.float16: ctypes.c_int16, dtypes.int8: ctypes.c_int8, dtypes.uint8: ctypes.c_uint8, dtypes.bool: ctypes.c_uint8, dtypes.int32: ctypes.c_int32, dtypes.int64: ctypes.c_int64}[dtype] * size)())
  def __init__(self, size, dtype: DType): super().__init__(size, dtype, ({dtypes.float32: ctypes.c_float, dtypes.float16: ctypes.c_int16, dtypes.bfloat16: ctypes.c_int16, dtypes.int8: ctypes.c_int8, dtypes.uint8: ctypes.c_uint8, dtypes.bool: ctypes.c_uint8, dtypes.int32: ctypes.c_int32, dtypes.int64: ctypes.c_int64}[dtype] * size)())
  def _buffer(self): return memoryview(self._buf)

class RawBufferCopyInOut(RawBufferCopyIn):
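
There is no bfloat16 ctypes type, so (like float16 above) the CPU-side allocation is plain 16-bit integer storage holding the raw bit patterns; the LLVM kernel is what interprets them. A small illustration of what RawMallocBuffer ends up allocating (not from the diff):

import ctypes

buf = (ctypes.c_int16 * 5)()     # what RawMallocBuffer builds for dtypes.bfloat16
mv = memoryview(buf)             # what _buffer() returns
print(mv.nbytes, mv.itemsize)    # 10 2 -> five 2-byte slots of raw bf16 bits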


@@ -63,10 +63,10 @@ class Tensor:
      return
    if data.__class__ is list:
      assert dtype is None or dtype.np is not None, f"{dtype} doesn't have a numpy dtype"
      data = np.array(data, dtype=(dtype or Tensor.default_type).np)
    if data.__class__ is np.ndarray:
      data = cast(np.ndarray, data)
    if isinstance(data, np.ndarray):
      data = LazyBuffer.fromCPU(data)
      self.lazydata = data if data.device == device else LazyBuffer.loadop(LoadOps.FROM, data.shape, data.dtype, device, src=data)
      return
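
One consequence of the assert in this constructor path: a bfloat16 tensor can't be built straight from a Python list, because the dtype has no numpy counterpart. As the new tests do, bf16 tensors are created by casting a float32 tensor (LLVM backend assumed for the cast to realize):

from tinygrad.tensor import Tensor, dtypes

t = Tensor([10000, -1, -1000, -10000, 20]).cast(dtypes.bfloat16)   # ok: build in float32, then cast
# Tensor([1, 2], dtype=dtypes.bfloat16) would trip the "doesn't have a numpy dtype" assert above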