diff --git a/test/test_dtype.py b/test/test_dtype.py
index e651523e..e146a46a 100644
--- a/test/test_dtype.py
+++ b/test/test_dtype.py
@@ -3,7 +3,7 @@ import numpy as np
 from tinygrad.helpers import getenv, DType, DEBUG
 from tinygrad.lazy import Device
 from tinygrad.tensor import Tensor, dtypes
-from extra.utils import OSX
+from extra.utils import OSX, temp
 
 def _test_to_np(a:Tensor, np_dtype, target):
   print(a)
@@ -26,6 +26,39 @@ def _test_add_upcast(a:Tensor, b:Tensor, target_dtype:DType, target): _test_op(l
 def _test_mul_upcast(a:Tensor, b:Tensor, target_dtype:DType, target): _test_op(lambda: a*b, target_dtype, target)
 def _test_matmul_upcast(a:Tensor, b:Tensor, target_dtype:DType, target): _test_op(lambda: a@b, target_dtype, target)
+
+class TestBFloat16DType(unittest.TestCase):
+  def test_bf16_to_float(self):
+    with self.assertRaises(AssertionError):
+      _test_cast(Tensor([100000], dtype=dtypes.bfloat16), dtypes.float32, [100000])
+
+  def test_float_to_bf16(self):
+    with self.assertRaises(AssertionError):
+      _test_cast(Tensor([100000], dtype=dtypes.float32), dtypes.bfloat16, [100000])
+
+  # torch.tensor([10000, -1, -1000, -10000, 20]).type(torch.bfloat16)
+
+  @unittest.skipIf(Device.DEFAULT not in ["LLVM"], "bf16 only on LLVM")
+  def test_bf16(self):
+    t = Tensor([10000, -1, -1000, -10000, 20]).cast(dtypes.bfloat16)
+    t.realize()
+    back = t.cast(dtypes.float32)
+    assert tuple(back.numpy().tolist()) == (9984., -1, -1000, -9984, 20)
+
+  @unittest.skipIf(Device.DEFAULT not in ["LLVM"], "bf16 only on LLVM")
+  def test_bf16_disk_write_read(self):
+    t = Tensor([10000, -1, -1000, -10000, 20]).cast(dtypes.float32)
+    t.to(f"disk:{temp('f32')}").realize()
+
+    # hack to "cast" f32 -> bf16
+    dat = open(temp('f32'), "rb").read()
+    adat = b''.join([dat[i+2:i+4] for i in range(0, len(dat), 4)])
+    with open(temp('bf16'), "wb") as f: f.write(adat)
+
+    t = Tensor.empty(5, dtype=dtypes.bfloat16, device=f"disk:{temp('bf16')}").llvm().realize()
+    back = t.cast(dtypes.float32)
+    assert tuple(back.numpy().tolist()) == (9984., -1, -1000, -9984, 20)
+
 
 # for GPU, cl_khr_fp16 isn't supported (except now we don't need it!)
 # for LLVM, it segfaults because it can't link to the casting function
 @unittest.skipIf((getenv("CI", "") != "" and Device.DEFAULT in ["LLVM"]) or Device.DEFAULT == "WEBGPU", "float16 broken in some CI backends")
diff --git a/tinygrad/codegen/llvmir.py b/tinygrad/codegen/llvmir.py
index b18b6ca6..dc94fea3 100644
--- a/tinygrad/codegen/llvmir.py
+++ b/tinygrad/codegen/llvmir.py
@@ -41,7 +41,7 @@ def uops_to_llvm_ir(uops:List[UOp]) -> str:
   buf_index = {x:i for i,x in enumerate(buf_to_dtype.keys())}
 
   # create llvm function
-  dtype_to_llvm_dtype = {dtypes.float16:ir.HalfType(), dtypes.float32:ir.FloatType(), dtypes.int8:ir.IntType(8), dtypes.uint8:ir.IntType(8), dtypes.bool: ir.IntType(1), dtypes.int64: ir.IntType(64), dtypes.int32: ir.IntType(32)}
+  dtype_to_llvm_dtype = {dtypes.float16:ir.HalfType(), dtypes.bfloat16:ir.IntType(16), dtypes.float32:ir.FloatType(), dtypes.int8:ir.IntType(8), dtypes.uint8:ir.IntType(8), dtypes.bool: ir.IntType(1), dtypes.int64: ir.IntType(64), dtypes.int32: ir.IntType(32)}
   func_dtypes = [dtype_to_llvm_dtype[dtype] for dtype in buf_to_dtype.values()]
   func = ir.Function(module, ir.FunctionType(ir.VoidType(), [x.as_pointer() for x in func_dtypes]), name='exec')
 
@@ -104,6 +104,10 @@ def uops_to_llvm_ir(uops:List[UOp]) -> str:
       if args.memory_dtype != newvar.dtype:
         if dtypes.is_int(args.memory_dtype): val = bb[-1].uitofp(val, ir.FloatType()) if dtypes.is_unsigned(args.memory_dtype) else bb[-1].sitofp(val, ir.FloatType())
+        elif args.memory_dtype == dtypes.bfloat16:
+          val = bb[-1].sext(val, ir.IntType(32))
+          val = bb[-1].shl(val, ir.Constant(ir.IntType(32), 16))
+          val = bb[-1].bitcast(val, ir.FloatType())
         else: val = bb[-1].fpext(val, ir.FloatType())
       lvars[newvar] = val
@@ -114,6 +118,10 @@ def uops_to_llvm_ir(uops:List[UOp]) -> str:
       if args.memory_dtype != vin[0].dtype:
         if dtypes.is_int(args.memory_dtype): element = bb[-1].fptoui(element, dtype_to_llvm_dtype[args.memory_dtype]) if dtypes.is_unsigned(args.memory_dtype) else bb[-1].fptosi(element, dtype_to_llvm_dtype[args.memory_dtype])
+        elif args.memory_dtype == dtypes.bfloat16:
+          element = bb[-1].bitcast(element, ir.IntType(32))
+          element = bb[-1].lshr(element, ir.Constant(ir.IntType(32), 16))
+          element = bb[-1].trunc(element, ir.IntType(16))
         else: element = bb[-1].fptrunc(element, dtype_to_llvm_dtype[args.memory_dtype])
       bb[-1].store(element, bb[-1].gep(func.args[buf_index[args.name]], [idx], inbounds=True))
diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py
index d7636437..2f89a55a 100644
--- a/tinygrad/helpers.py
+++ b/tinygrad/helpers.py
@@ -101,6 +101,9 @@ class dtypes:
   uint32: Final[DType] = DType(1, 4, "unsigned int", np.uint32)
   uint64: Final[DType] = DType(2, 8, "unsigned long", np.uint64)
 
+  # NOTE: bfloat16 isn't supported in numpy
+  bfloat16: Final[DType] = DType(0, 2, "__bf16", None)
+
   # NOTE: these are internal dtypes, should probably check for that
   _half4: Final[DType] = DType(0, 2*4, "half4", None, 4)
   _float2: Final[DType] = DType(4, 4*2, "float2", None, 2)
diff --git a/tinygrad/lazy.py b/tinygrad/lazy.py
index 9c173d2f..cf68b506 100644
--- a/tinygrad/lazy.py
+++ b/tinygrad/lazy.py
@@ -185,6 +185,7 @@ class LazyBuffer:
 
   # NOTE: we also have to copy the numpy array on the way out...otherwise the underlying Tensor could be freed and use after free. improve this?
   def toCPU(self):
+    assert self.dtype.np, "numpy dtype is required for toCPU"
     realized = self.cast(dtypes.from_np(self.dtype.np)).contiguous().realize().realized
     ret = cast(RawBuffer, realized).toCPU().reshape(self.shape)
     return ret
diff --git a/tinygrad/runtime/lib.py b/tinygrad/runtime/lib.py
index 96ab7da2..7ed80cbf 100644
--- a/tinygrad/runtime/lib.py
+++ b/tinygrad/runtime/lib.py
@@ -39,7 +39,7 @@ class RawBufferMapped(RawBufferCopyIn):
 
 # this one is simple enough that i moved it out of the runtimes
 class RawMallocBuffer(RawBufferMapped):
-  def __init__(self, size, dtype: DType): super().__init__(size, dtype, ({dtypes.float32: ctypes.c_float, dtypes.float16: ctypes.c_int16, dtypes.int8: ctypes.c_int8, dtypes.uint8: ctypes.c_uint8, dtypes.bool: ctypes.c_uint8, dtypes.int32: ctypes.c_int32, dtypes.int64: ctypes.c_int64}[dtype] * size)())
+  def __init__(self, size, dtype: DType): super().__init__(size, dtype, ({dtypes.float32: ctypes.c_float, dtypes.float16: ctypes.c_int16, dtypes.bfloat16: ctypes.c_int16, dtypes.int8: ctypes.c_int8, dtypes.uint8: ctypes.c_uint8, dtypes.bool: ctypes.c_uint8, dtypes.int32: ctypes.c_int32, dtypes.int64: ctypes.c_int64}[dtype] * size)())
   def _buffer(self): return memoryview(self._buf)
 
 class RawBufferCopyInOut(RawBufferCopyIn):
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index 544b3652..cb3a1df7 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -63,10 +63,10 @@ class Tensor:
       return
 
     if data.__class__ is list:
+      assert dtype is None or dtype.np is not None, f"{dtype} doesn't have a numpy dtype"
       data = np.array(data, dtype=(dtype or Tensor.default_type).np)
 
-    if data.__class__ is np.ndarray:
-      data = cast(np.ndarray, data)
+    if isinstance(data, np.ndarray):
       data = LazyBuffer.fromCPU(data)
       self.lazydata = data if data.device == device else LazyBuffer.loadop(LoadOps.FROM, data.shape, data.dtype, device, src=data)
       return
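Side note (not part of the patch): the disk "hack" in test_bf16_disk_write_read and the new llvmir.py load/store paths rely on the same trick. bfloat16 is treated as the top 16 bits of a float32, so a store drops the low mantissa bits (lshr 16 + trunc, plain truncation with no rounding) and a load just shifts them back up (shl 16 + bitcast). A minimal pure-Python sketch of that round trip using struct (the helper names are mine, not from the repo):

import struct

def f32_to_bf16_bits(x: float) -> int:
  # mirrors the store path: bitcast f32 -> i32, shift right by 16, keep the top 16 bits
  (i,) = struct.unpack("<I", struct.pack("<f", x))
  return i >> 16

def bf16_bits_to_f32(b: int) -> float:
  # mirrors the load path: widen to i32, shift left by 16, bitcast back to f32
  (x,) = struct.unpack("<f", struct.pack("<I", (b & 0xffff) << 16))
  return x

for v in [10000, -1, -1000, -10000, 20]:
  print(v, "->", bf16_bits_to_f32(f32_to_bf16_bits(float(v))))
# 10000 -> 9984.0, -1 -> -1.0, -1000 -> -1000.0, -10000 -> -9984.0, 20 -> 20.0

The expected values in test_bf16 come straight from this truncation: 10000.0 is 0x461c4000 as a float32, and keeping only the top 16 bits (0x461c, re-expanded to 0x461c0000) gives 9984.0. The byte slice dat[i+2:i+4] in the disk test is the same shift-by-16, done on little-endian bytes.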