diff --git a/test/test_dtype.py b/test/test_dtype.py
index 608a959b..1d2d69a7 100644
--- a/test/test_dtype.py
+++ b/test/test_dtype.py
@@ -27,6 +27,7 @@ def _assert_eq(tensor:Tensor, target_dtype:DType, target):
 
 def _test_op(fxn, target_dtype:DType, target): _assert_eq(fxn(), target_dtype, target)
 def _test_cast(a:Tensor, target_dtype:DType, target): _test_op(lambda: a.cast(target_dtype), target_dtype, target)
+def _test_bitcast(a:Tensor, target_dtype:DType, target): _test_op(lambda: a.bitcast(target_dtype), target_dtype, target)
 
 # tests no-op casts from source_dtype to target_dtypes
 def _test_casts_from(tensor_contents:List, source_dtype:DType, target_dtypes:List[DType], target_contents:Optional[List]=None):
@@ -110,6 +111,25 @@ class TestInt8Dtype(unittest.TestCase):
   def test_uint8_to_int8_overflow(self): _test_op(lambda: Tensor([255, 254, 253, 252], dtype=dtypes.uint8).cast(dtypes.int8), dtypes.int8, [-1, -2, -3, -4])
 
+@unittest.skipIf(Device.DEFAULT not in {"CPU", "TORCH"}, "only bitcast in CPU and TORCH")
+class TestBitCast(unittest.TestCase):
+  def test_float32_bitcast_to_int32(self): _test_bitcast(Tensor([1,2,3,4], dtype=dtypes.float32), dtypes.int32, [1065353216, 1073741824, 1077936128, 1082130432])
+  @unittest.skipIf(Device.DEFAULT == "TORCH", "no uint32 in torch")
+  def test_float32_bitcast_to_uint32(self): _test_bitcast(Tensor([1,2,3,4], dtype=dtypes.float32), dtypes.uint32, [1065353216, 1073741824, 1077936128, 1082130432])
+  def test_int32_bitcast_to_float32(self): _test_bitcast(Tensor([1065353216, 1073741824, 1077936128, 1082130432], dtype=dtypes.int32), dtypes.float32, [1.0, 2.0, 3.0, 4.0])
+
+  # NOTE: these are the same as normal casts
+  def test_int8_bitcast_to_uint8(self): _test_bitcast(Tensor([-1, -2, -3, -4], dtype=dtypes.int8), dtypes.uint8, [255, 254, 253, 252])
+  def test_uint8_bitcast_to_int8(self): _test_bitcast(Tensor([255, 254, 253, 252], dtype=dtypes.uint8), dtypes.int8, [-1, -2, -3, -4])
+  @unittest.skipIf(Device.DEFAULT == "TORCH", "no uint64 in torch")
+  def test_int64_bitcast_to_uint64(self): _test_bitcast(Tensor([-1, -2, -3, -4], dtype=dtypes.int64), dtypes.uint64, [18446744073709551615, 18446744073709551614, 18446744073709551613, 18446744073709551612])
+  @unittest.skipIf(Device.DEFAULT == "TORCH", "no uint64 in torch")
+  def test_uint64_bitcast_to_int64(self): _test_bitcast(Tensor([18446744073709551615, 18446744073709551614, 18446744073709551613, 18446744073709551612], dtype=dtypes.uint64), dtypes.int64, [-1, -2, -3, -4])
+
+  def test_shape_change_bitcast(self):
+    with self.assertRaises(AssertionError):
+      _test_bitcast(Tensor([100000], dtype=dtypes.float32), dtypes.uint8, [100000])
+
 class TestInt32Dtype(unittest.TestCase):
   def test_int32_to_np(self): _test_to_np(Tensor([1,2,3,4], dtype=dtypes.int32), np.int32, [1,2,3,4])
diff --git a/test/test_uops.py b/test/test_uops.py
index 1f7901e0..3b75a1b9 100644
--- a/test/test_uops.py
+++ b/test/test_uops.py
@@ -64,6 +64,9 @@ class TestUOps(unittest.TestCase):
   def test_cmpeq(self): self._test_bop_fxn(BinaryOps.CMPEQ, lambda a,b: float(a==b))
   # CMPLT and MOD aren't tested
 
+  # doesn't work in LLVM
+  #def test_add_int32(self): self._test_bop_fxn(BinaryOps.ADD, lambda a,b: a+b, dtypes.int32)
+
   def _test_top_fxn(self, bop, fxn, dt=dtypes.float32):
     for f in [_test_single_value, _test_single_value_const]:
       for a in [-2.0, 0, 1, 2.0]:
diff --git a/tinygrad/codegen/linearizer.py b/tinygrad/codegen/linearizer.py
index d77dffbe..0a9c4635 100644
--- a/tinygrad/codegen/linearizer.py
+++ b/tinygrad/codegen/linearizer.py
@@ -52,7 +52,7 @@ class Token(NamedTuple):
       assert self.offset is None
       return f"{self.dtype.name} {self.name}"
     if self.offset is None: return self.name
-    assert self.dtype in [dtypes._float4, dtypes._float2]
+    assert self.dtype in [dtypes._float4, dtypes._float2], f"{self.dtype} isn't okay with offset {self.offset}"
     return self.name+"."+"xyzw"[int(self.offset)]
   def __repr__(self): return f"<{self.name}>" if self.offset is None and self.dtype == dtypes.float32 else f"<{self.name}:{self.dtype.name}:{self.offset}>"
@@ -121,7 +121,7 @@ class MemOp(NamedTuple):
   invalid_value: Union[float, int] = 0.0
 
 class ConstOp(NamedTuple):
-  value: float
+  value: Union[float, int]
 
   # shared
   valid: Variable
diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py
index 3cf9d885..7397aaec 100644
--- a/tinygrad/helpers.py
+++ b/tinygrad/helpers.py
@@ -98,11 +98,13 @@ class dtypes:
   float32: Final[DType] = DType(4, 4, "float", np.float32)
   float = float32
   int8: Final[DType] = DType(0, 1, "char", np.int8)
-  int32: Final[DType] = DType(1, 4, "int", np.int32)
-  int64: Final[DType] = DType(2, 8, "long", np.int64)
+  int16: Final[DType] = DType(1, 2, "short", np.int16)
+  int32: Final[DType] = DType(2, 4, "int", np.int32)
+  int64: Final[DType] = DType(3, 8, "long", np.int64)
   uint8: Final[DType] = DType(0, 1, "unsigned char", np.uint8)
-  uint32: Final[DType] = DType(1, 4, "unsigned int", np.uint32)
-  uint64: Final[DType] = DType(2, 8, "unsigned long", np.uint64)
+  uint16: Final[DType] = DType(1, 2, "unsigned short", np.uint16)
+  uint32: Final[DType] = DType(2, 4, "unsigned int", np.uint32)
+  uint64: Final[DType] = DType(3, 8, "unsigned long", np.uint64)
   # NOTE: bfloat16 isn't supported in numpy
   bfloat16: Final[DType] = DType(0, 2, "__bf16", None)
diff --git a/tinygrad/lazy.py b/tinygrad/lazy.py
index 3200ba28..38867885 100644
--- a/tinygrad/lazy.py
+++ b/tinygrad/lazy.py
@@ -151,9 +151,9 @@ class LazyBuffer:
     if self.dtype.__class__ is ImageDType and self.optype != MovementOps and (prod(self.shape) != prod(cast(ImageDType, self.dtype).shape) or not any(self.shape[x]%4 == 0 for x in self.st.unit_stride_axes())):
       if self.op.op == MovementOps.RESHAPE:
         # put CAST before the final RESHAPE
-        self.op = LazyOp(MovementOps.RESHAPE, (LazyOp(UnaryOps.CAST, self.op.src, dtypes.float32),), self.op.arg)
+        self.op = LazyOp(MovementOps.RESHAPE, (LazyOp(UnaryOps.CAST, self.op.src, (dtypes.float32, False)),), self.op.arg)
       else:
-        self.op = LazyOp(UnaryOps.CAST, (self.op,), dtypes.float32)
+        self.op = LazyOp(UnaryOps.CAST, (self.op,), (dtypes.float32, False))
       self.dtype = dtypes.float32
 
     self.realized = Device[self.device].exec_ast(self.op, output=self, **self._device_extra_args())
@@ -186,12 +186,14 @@
   # NOTE: we also have to copy the numpy array on the way out...otherwise the underlying Tensor could be freed and use after free. improve this?
   def toCPU(self):
-    assert self.dtype.np, "numpy dtype is required for toCPU"
-    realized = self.cast(dtypes.from_np(self.dtype.np)).contiguous().realize().realized
+    assert self.dtype.np, f"{self.dtype} is not supported in toCPU"
+    realized = self.cast((dtypes.from_np(self.dtype.np), False)).contiguous().realize().realized
     ret = cast(RawBuffer, realized).toCPU().reshape(self.shape)
     return ret
 
-  def cast(self:LazyBuffer, arg:DType) -> LazyBuffer: return elementwise_op(UnaryOps.CAST, self, arg=arg) if self.dtype != arg else self
+  def cast(self:LazyBuffer, arg:Tuple[DType, bool]) -> LazyBuffer:
+    assert not arg[1] or self.dtype.itemsize == arg[0].itemsize, "can't bitcast mismatched dtype itemsizes"
+    return elementwise_op(UnaryOps.CAST, self, arg=arg) if self.dtype != arg[0] else self
   def unary_op(self:LazyBuffer, op:UnaryOps) -> LazyBuffer: return elementwise_op(op, self)
   def binary_op(self:LazyBuffer, op:BinaryOps, y:LazyBuffer) -> LazyBuffer: return elementwise_op(op, self, y)
   def ternary_op(self:LazyBuffer, op:TernaryOps, y: LazyBuffer, z:LazyBuffer) -> LazyBuffer: return elementwise_op(op, self, y, z)
@@ -307,7 +309,7 @@ def elementwise_op(op:Union[UnaryOps, BinaryOps, TernaryOps], *srcs:LazyBuffer,
   if SHUFFLE_MOVEMENT_OPS: srcs = _push_movement_ops(srcs)
 
   # get outputs now
-  out_device, out_shape, out_dtype = srcs[0].device, srcs[0].shape, max([x.dtype for x in srcs]) if op != UnaryOps.CAST else cast(DType, arg)
+  out_device, out_shape, out_dtype = srcs[0].device, srcs[0].shape, max([x.dtype for x in srcs]) if op != UnaryOps.CAST else cast(Tuple[DType, bool], arg)[0]
 
   # push all contiguous to the end of BinaryOps. kernels 198 -> 196
   if PUSH_CONTIGUOUS and any(not x.realized and x.op.op == LoadOps.CONTIGUOUS and len(x.op.src[0].children) <= 1 for x in srcs):
diff --git a/tinygrad/mlops.py b/tinygrad/mlops.py
index fd859f0e..3cf762c0 100644
--- a/tinygrad/mlops.py
+++ b/tinygrad/mlops.py
@@ -10,12 +10,12 @@
   def backward(self, grad_output): return grad_output
 
 class Cast(Function):
-  __slots__ = "input_dtype"
-  def forward(self, x, dtype):
-    self.input_dtype = x.dtype
-    return x.cast(dtype)
+  __slots__ = "input_dtype", "bitcast"
+  def forward(self, x, dtype, bitcast=False):
+    self.input_dtype, self.bitcast = x.dtype, bitcast
+    return x.cast((dtype, bitcast))
   def backward(self, grad_output):
-    return grad_output.cast(self.input_dtype)
+    return grad_output.cast((self.input_dtype, self.bitcast))
 
 # ************* unary ops *************
diff --git a/tinygrad/ops.py b/tinygrad/ops.py
index 300057f3..b2ded5a9 100644
--- a/tinygrad/ops.py
+++ b/tinygrad/ops.py
@@ -113,7 +113,7 @@ class FlopCounter:
     return ret
 
 from tinygrad.shape.shapetracker import ShapeTracker
 shape_fxn_for_op: Dict[Op, Callable] = {
-  UnaryOps.CAST: lambda self,dtype: (self.shape, dtype, self.consume_flops()), # cast uses no flops
+  UnaryOps.CAST: lambda self,arg: (self.shape, arg[0], self.consume_flops()), # cast uses no flops
   **{op:lambda self: (self.shape, self.dtype, self.consume_flops() + prod(self.shape)) for op in UnaryOps if op != UnaryOps.CAST},
   **{op:lambda self,y: (self.shape, max(self.dtype, y.dtype), self.consume_flops() + y.consume_flops() + prod(self.shape)) for op in BinaryOps},
   **{op:lambda self,new_shape: (new_shape, self.dtype, self.consume_flops() + prod(self.shape)) for op in ReduceOps},
diff --git a/tinygrad/runtime/lib.py b/tinygrad/runtime/lib.py
index 7ed80cbf..6b32ffa9 100644
--- a/tinygrad/runtime/lib.py
+++ b/tinygrad/runtime/lib.py
@@ -39,7 +39,7 @@ class RawBufferMapped(RawBufferCopyIn):
 
 # this one is simple enough that i moved it out of the runtimes
 class RawMallocBuffer(RawBufferMapped):
-  def __init__(self, size, dtype: DType): super().__init__(size, dtype, ({dtypes.float32: ctypes.c_float, dtypes.float16: ctypes.c_int16, dtypes.bfloat16: ctypes.c_int16, dtypes.int8: ctypes.c_int8, dtypes.uint8: ctypes.c_uint8, dtypes.bool: ctypes.c_uint8, dtypes.int32: ctypes.c_int32, dtypes.int64: ctypes.c_int64}[dtype] * size)())
+  def __init__(self, size, dtype: DType): super().__init__(size, dtype, ({dtypes.float32: ctypes.c_float, dtypes.float16: ctypes.c_int16, dtypes.bfloat16: ctypes.c_int16, dtypes.int8: ctypes.c_int8, dtypes.uint8: ctypes.c_uint8, dtypes.bool: ctypes.c_uint8, dtypes.int32: ctypes.c_int32, dtypes.uint32: ctypes.c_uint32, dtypes.int64: ctypes.c_int64, dtypes.uint64: ctypes.c_uint64}[dtype] * size)())
   def _buffer(self): return memoryview(self._buf)
 
 class RawBufferCopyInOut(RawBufferCopyIn):
diff --git a/tinygrad/runtime/ops_cpu.py b/tinygrad/runtime/ops_cpu.py
index ff8fa301..648648ac 100644
--- a/tinygrad/runtime/ops_cpu.py
+++ b/tinygrad/runtime/ops_cpu.py
@@ -32,7 +32,8 @@ def einsum_mulacc(einsum, get_strides, expand):
   return mulacc
 
 numpy_fxn_for_op: Dict[Op, Callable] = {**base_fxn_for_op, **{
-  UnaryOps.NOOP: lambda x: np.require(x, requirements='C'), UnaryOps.EXP2: np.exp2, UnaryOps.LOG2: np.log2, UnaryOps.CAST: lambda x,y: x.astype(y.np, copy=False), UnaryOps.SIN: np.sin,
+  UnaryOps.NOOP: lambda x: np.require(x, requirements='C'), UnaryOps.EXP2: np.exp2, UnaryOps.LOG2: np.log2, UnaryOps.SIN: np.sin,
+  UnaryOps.CAST: lambda x,y: x.view(y[0].np) if y[1] else x.astype(y[0].np, copy=False),
   BinaryOps.MAX: np.maximum, BinaryOps.CMPEQ: lambda x,y: (x==y).astype(promote_types(x,y)), BinaryOps.ADD: lambda x, y: np.add(*match_types(x, y)),
   BinaryOps.SUB: lambda x, y: np.subtract(*match_types(x, y)), BinaryOps.MUL: lambda x, y: np.multiply(*match_types(x, y)),
   BinaryOps.DIV: lambda x, y: np.divide(*match_types(x, y)), UnaryOps.SQRT: np.sqrt,
diff --git a/tinygrad/runtime/ops_disk.py b/tinygrad/runtime/ops_disk.py
index a29d117b..b48c32a3 100644
--- a/tinygrad/runtime/ops_disk.py
+++ b/tinygrad/runtime/ops_disk.py
@@ -1,6 +1,6 @@
 import os, mmap
 from typing import Optional
-from typing import Callable, Dict
+from typing import Callable, Dict, Tuple
 from tinygrad.helpers import prod, DType
 from tinygrad.runtime.lib import RawBufferMapped
 from tinygrad.ops import Interpreted, Op, MovementOps, UnaryOps
@@ -21,7 +21,7 @@ class RawDiskBuffer(RawBufferMapped):
   def __del__(self):
     self._buf[2] -= 1
     if self._buf[2] == 0: self._buf[0].close()
-  def cast(self, new_dtype:DType): return RawDiskBuffer(self.size, new_dtype, buf=self._buf, shape=self.shape, offset=self.offset)
+  def cast(self, arg:Tuple[DType, bool]): return RawDiskBuffer(self.size, arg[0], buf=self._buf, shape=self.shape, offset=self.offset)
   def reshape(self, arg): return RawDiskBuffer(self.size, self.dtype, buf=self._buf, shape=arg, offset=self.offset)
   def shrink(self, arg):
     assert arg[1:] == tuple([(0,x) for x in self.shape[1:]]), f"can only slice the first dim of disk tensor {arg}"
diff --git a/tinygrad/runtime/ops_torch.py b/tinygrad/runtime/ops_torch.py
index b68b2bc1..2d959e27 100644
--- a/tinygrad/runtime/ops_torch.py
+++ b/tinygrad/runtime/ops_torch.py
@@ -10,7 +10,8 @@ type_map = {torch.float16: dtypes.float16, torch.float32: dtypes.float32, torch.
 inverse_type_map = {v:k for k,v in type_map.items()}
 
 torch_fxn_for_op: Dict[Op, Callable] = {**base_fxn_for_op, **{
-  UnaryOps.NOOP: lambda x: x.contiguous(), UnaryOps.SQRT: lambda x: x.sqrt(), UnaryOps.EXP2: lambda x: x.exp2(), UnaryOps.LOG2: lambda x: x.log2(), UnaryOps.CAST: lambda x,y: x.type(next(k for k,v in type_map.items() if v==y)), UnaryOps.SIN: torch.sin,
+  UnaryOps.NOOP: lambda x: x.contiguous(), UnaryOps.SQRT: lambda x: x.sqrt(), UnaryOps.EXP2: lambda x: x.exp2(), UnaryOps.LOG2: lambda x: x.log2(), UnaryOps.SIN: torch.sin,
+  UnaryOps.CAST: lambda x,y: (x.view if y[1] else x.type)(next(k for k,v in type_map.items() if v==y[0])),
   BinaryOps.MAX: torch.maximum, BinaryOps.CMPEQ: lambda x,y: (x==y).type(torch.promote_types(x.dtype, y.dtype)),
   MovementOps.PAD: lambda x, padding: torch.nn.functional.pad(x, [item for sublist in padding[::-1] for item in sublist]),
   TernaryOps.MULACC: einsum_mulacc(lambda s,a,b: torch.einsum(s, a.float(), b.float()).type(torch.promote_types(a.dtype, b.dtype)), lambda x: x.stride(), lambda x,s: x.expand(s)),
diff --git a/tinygrad/state.py b/tinygrad/state.py
index af0f04f1..6b63cb95 100644
--- a/tinygrad/state.py
+++ b/tinygrad/state.py
@@ -73,7 +73,8 @@ def torch_load(fn:str):
     # https://github.com/facebookresearch/llama/blob/6c7fe276574e78057f917549435a2554000a876d/llama/generation.py#L95
     # TODO: should this be done in the example instead? or maybe we don't need this anymore with better bfloat16 support
     if storage[1] == dtypes.bfloat16:
-      ret = ret.to("LLVM").half().to(Device.DEFAULT)
+      ret = ret.bitcast(dtypes.uint16).to("CPU").cast(dtypes.uint32).mul(1<<16).bitcast(dtypes.float32).to(Device.DEFAULT).half()
+      #ret = ret.to("LLVM").half().to(Device.DEFAULT)
 
     # 7 lines to deal with permuted tensors. NOTE: this currently requires reading off the disk
     shape_strides = [(s, st) for s,st in zip(size, stride) if s != 1]
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index 1eefd86d..c479c25f 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -684,6 +684,7 @@ class Tensor:
   # ***** cast ops *****
 
   def cast(self, dtype:DType) -> Tensor: return mlops.Cast.apply(self, dtype=dtype) if self.dtype != dtype else self
+  def bitcast(self, dtype:DType) -> Tensor: return mlops.Cast.apply(self, dtype=dtype, bitcast=True) if self.dtype != dtype else self
   def float(self) -> Tensor: return self.cast(dtypes.float32)
   def half(self) -> Tensor: return self.cast(dtypes.float16)
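Reviewer note (not part of the patch): a minimal usage sketch of the `Tensor.bitcast` API this diff adds, assuming the default CPU (numpy) backend; the expected values mirror the new tests in test_dtype.py.

```python
# sketch only: bitcast reinterprets the underlying bits, while cast converts the values
from tinygrad.tensor import Tensor
from tinygrad.helpers import dtypes

a = Tensor([1.0, 2.0, 3.0, 4.0], dtype=dtypes.float32)
print(a.bitcast(dtypes.int32).numpy())  # [1065353216 1073741824 1077936128 1082130432]
print(a.cast(dtypes.int32).numpy())     # [1 2 3 4]

# itemsizes must match: float32 -> uint8 trips the new assertion in LazyBuffer.cast
try:
  a.bitcast(dtypes.uint8)
except AssertionError as e:
  print(e)  # can't bitcast mismatched dtype itemsizes
```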
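On the state.py change: the new torch_load path decodes bfloat16 without the LLVM round-trip by reinterpreting the stored 16-bit pattern as uint16, widening it to uint32, shifting it into the high half with mul(1<<16), and bitcasting the result to float32 (the trailing .half() then converts to float16 as before). A small numpy sketch, illustration only and assuming little-endian layout, of why this recovers the value:

```python
import numpy as np

f32 = np.array([3.14159], dtype=np.float32)
bf16_bits = f32.view(np.uint16)[1::2]              # bfloat16 is the high 16 bits of a float32
widened = bf16_bits.astype(np.uint32) * (1 << 16)  # same effect as cast(uint32).mul(1<<16)
recovered = widened.view(np.float32)               # bitcast back to float32
print(recovered)                                   # [3.140625] -- truncated to bfloat16 precision
```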