simple bitcast 2 (#1445)

* simple bitcast 2

* bc 2

* empty

* Revert "empty"

This reverts commit d8ee083655b67947afb1e577020b4395d001832c.
This commit is contained in:
George Hotz 2023-08-06 00:30:50 -07:00 committed by GitHub
parent 943b227cb1
commit d67e248d9b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 55 additions and 24 deletions

View File

@ -27,6 +27,7 @@ def _assert_eq(tensor:Tensor, target_dtype:DType, target):
def _test_op(fxn, target_dtype:DType, target): _assert_eq(fxn(), target_dtype, target) def _test_op(fxn, target_dtype:DType, target): _assert_eq(fxn(), target_dtype, target)
def _test_cast(a:Tensor, target_dtype:DType, target): _test_op(lambda: a.cast(target_dtype), target_dtype, target) def _test_cast(a:Tensor, target_dtype:DType, target): _test_op(lambda: a.cast(target_dtype), target_dtype, target)
def _test_bitcast(a:Tensor, target_dtype:DType, target): _test_op(lambda: a.bitcast(target_dtype), target_dtype, target)
# tests no-op casts from source_dtype to target_dtypes # tests no-op casts from source_dtype to target_dtypes
def _test_casts_from(tensor_contents:List, source_dtype:DType, target_dtypes:List[DType], target_contents:Optional[List]=None): def _test_casts_from(tensor_contents:List, source_dtype:DType, target_dtypes:List[DType], target_contents:Optional[List]=None):
@ -110,6 +111,25 @@ class TestInt8Dtype(unittest.TestCase):
def test_uint8_to_int8_overflow(self): _test_op(lambda: Tensor([255, 254, 253, 252], dtype=dtypes.uint8).cast(dtypes.int8), dtypes.int8, [-1, -2, -3, -4]) def test_uint8_to_int8_overflow(self): _test_op(lambda: Tensor([255, 254, 253, 252], dtype=dtypes.uint8).cast(dtypes.int8), dtypes.int8, [-1, -2, -3, -4])
@unittest.skipIf(Device.DEFAULT not in {"CPU", "TORCH"}, "only bitcast in CPU and TORCH")
class TestBitCast(unittest.TestCase):
def test_float32_bitcast_to_int32(self): _test_bitcast(Tensor([1,2,3,4], dtype=dtypes.float32), dtypes.int32, [1065353216, 1073741824, 1077936128, 1082130432])
@unittest.skipIf(Device.DEFAULT == "TORCH", "no uint32 in torch")
def test_float32_bitcast_to_uint32(self): _test_bitcast(Tensor([1,2,3,4], dtype=dtypes.float32), dtypes.uint32, [1065353216, 1073741824, 1077936128, 1082130432])
def test_int32_bitcast_to_float32(self): _test_bitcast(Tensor([1065353216, 1073741824, 1077936128, 1082130432], dtype=dtypes.int32), dtypes.float32, [1.0, 2.0, 3.0, 4.0])
# NOTE: these are the same as normal casts
def test_int8_bitcast_to_uint8(self): _test_bitcast(Tensor([-1, -2, -3, -4], dtype=dtypes.int8), dtypes.uint8, [255, 254, 253, 252])
def test_uint8_bitcast_to_int8(self): _test_bitcast(Tensor([255, 254, 253, 252], dtype=dtypes.uint8), dtypes.int8, [-1, -2, -3, -4])
@unittest.skipIf(Device.DEFAULT == "TORCH", "no uint64 in torch")
def test_int64_bitcast_to_uint64(self): _test_bitcast(Tensor([-1, -2, -3, -4], dtype=dtypes.int64), dtypes.uint64, [18446744073709551615, 18446744073709551614, 18446744073709551613, 18446744073709551612])
@unittest.skipIf(Device.DEFAULT == "TORCH", "no uint64 in torch")
def test_uint64_bitcast_to_int64(self): _test_bitcast(Tensor([18446744073709551615, 18446744073709551614, 18446744073709551613, 18446744073709551612], dtype=dtypes.uint64), dtypes.int64, [-1, -2, -3, -4])
def test_shape_change_bitcast(self):
with self.assertRaises(AssertionError):
_test_bitcast(Tensor([100000], dtype=dtypes.float32), dtypes.uint8, [100000])
class TestInt32Dtype(unittest.TestCase): class TestInt32Dtype(unittest.TestCase):
def test_int32_to_np(self): _test_to_np(Tensor([1,2,3,4], dtype=dtypes.int32), np.int32, [1,2,3,4]) def test_int32_to_np(self): _test_to_np(Tensor([1,2,3,4], dtype=dtypes.int32), np.int32, [1,2,3,4])

View File

@ -64,6 +64,9 @@ class TestUOps(unittest.TestCase):
def test_cmpeq(self): self._test_bop_fxn(BinaryOps.CMPEQ, lambda a,b: float(a==b)) def test_cmpeq(self): self._test_bop_fxn(BinaryOps.CMPEQ, lambda a,b: float(a==b))
# CMPLT and MOD aren't tested # CMPLT and MOD aren't tested
# doesn't work in LLVM
#def test_add_int32(self): self._test_bop_fxn(BinaryOps.ADD, lambda a,b: a+b, dtypes.int32)
def _test_top_fxn(self, bop, fxn, dt=dtypes.float32): def _test_top_fxn(self, bop, fxn, dt=dtypes.float32):
for f in [_test_single_value, _test_single_value_const]: for f in [_test_single_value, _test_single_value_const]:
for a in [-2.0, 0, 1, 2.0]: for a in [-2.0, 0, 1, 2.0]:

View File

@ -52,7 +52,7 @@ class Token(NamedTuple):
assert self.offset is None assert self.offset is None
return f"{self.dtype.name} {self.name}" return f"{self.dtype.name} {self.name}"
if self.offset is None: return self.name if self.offset is None: return self.name
assert self.dtype in [dtypes._float4, dtypes._float2] assert self.dtype in [dtypes._float4, dtypes._float2], f"{self.dtype} isn't okay with offset {self.offset}"
return self.name+"."+"xyzw"[int(self.offset)] return self.name+"."+"xyzw"[int(self.offset)]
def __repr__(self): return f"<{self.name}>" if self.offset is None and self.dtype == dtypes.float32 else f"<{self.name}:{self.dtype.name}:{self.offset}>" def __repr__(self): return f"<{self.name}>" if self.offset is None and self.dtype == dtypes.float32 else f"<{self.name}:{self.dtype.name}:{self.offset}>"
@ -121,7 +121,7 @@ class MemOp(NamedTuple):
invalid_value: Union[float, int] = 0.0 invalid_value: Union[float, int] = 0.0
class ConstOp(NamedTuple): class ConstOp(NamedTuple):
value: float value: Union[float, int]
# shared # shared
valid: Variable valid: Variable

View File

@ -98,11 +98,13 @@ class dtypes:
float32: Final[DType] = DType(4, 4, "float", np.float32) float32: Final[DType] = DType(4, 4, "float", np.float32)
float = float32 float = float32
int8: Final[DType] = DType(0, 1, "char", np.int8) int8: Final[DType] = DType(0, 1, "char", np.int8)
int32: Final[DType] = DType(1, 4, "int", np.int32) int16: Final[DType] = DType(1, 2, "short", np.int16)
int64: Final[DType] = DType(2, 8, "long", np.int64) int32: Final[DType] = DType(2, 4, "int", np.int32)
int64: Final[DType] = DType(3, 8, "long", np.int64)
uint8: Final[DType] = DType(0, 1, "unsigned char", np.uint8) uint8: Final[DType] = DType(0, 1, "unsigned char", np.uint8)
uint32: Final[DType] = DType(1, 4, "unsigned int", np.uint32) uint16: Final[DType] = DType(1, 2, "unsigned short", np.uint16)
uint64: Final[DType] = DType(2, 8, "unsigned long", np.uint64) uint32: Final[DType] = DType(2, 4, "unsigned int", np.uint32)
uint64: Final[DType] = DType(3, 8, "unsigned long", np.uint64)
# NOTE: bfloat16 isn't supported in numpy # NOTE: bfloat16 isn't supported in numpy
bfloat16: Final[DType] = DType(0, 2, "__bf16", None) bfloat16: Final[DType] = DType(0, 2, "__bf16", None)

View File

@ -151,9 +151,9 @@ class LazyBuffer:
if self.dtype.__class__ is ImageDType and self.optype != MovementOps and (prod(self.shape) != prod(cast(ImageDType, self.dtype).shape) or not any(self.shape[x]%4 == 0 for x in self.st.unit_stride_axes())): if self.dtype.__class__ is ImageDType and self.optype != MovementOps and (prod(self.shape) != prod(cast(ImageDType, self.dtype).shape) or not any(self.shape[x]%4 == 0 for x in self.st.unit_stride_axes())):
if self.op.op == MovementOps.RESHAPE: if self.op.op == MovementOps.RESHAPE:
# put CAST before the final RESHAPE # put CAST before the final RESHAPE
self.op = LazyOp(MovementOps.RESHAPE, (LazyOp(UnaryOps.CAST, self.op.src, dtypes.float32),), self.op.arg) self.op = LazyOp(MovementOps.RESHAPE, (LazyOp(UnaryOps.CAST, self.op.src, (dtypes.float32, False)),), self.op.arg)
else: else:
self.op = LazyOp(UnaryOps.CAST, (self.op,), dtypes.float32) self.op = LazyOp(UnaryOps.CAST, (self.op,), (dtypes.float32, False))
self.dtype = dtypes.float32 self.dtype = dtypes.float32
self.realized = Device[self.device].exec_ast(self.op, output=self, **self._device_extra_args()) self.realized = Device[self.device].exec_ast(self.op, output=self, **self._device_extra_args())
@ -186,12 +186,14 @@ class LazyBuffer:
# NOTE: we also have to copy the numpy array on the way out...otherwise the underlying Tensor could be freed and use after free. improve this? # NOTE: we also have to copy the numpy array on the way out...otherwise the underlying Tensor could be freed and use after free. improve this?
def toCPU(self): def toCPU(self):
assert self.dtype.np, "numpy dtype is required for toCPU" assert self.dtype.np, f"{self.dtype} is not supported in toCPU"
realized = self.cast(dtypes.from_np(self.dtype.np)).contiguous().realize().realized realized = self.cast((dtypes.from_np(self.dtype.np), False)).contiguous().realize().realized
ret = cast(RawBuffer, realized).toCPU().reshape(self.shape) ret = cast(RawBuffer, realized).toCPU().reshape(self.shape)
return ret return ret
def cast(self:LazyBuffer, arg:DType) -> LazyBuffer: return elementwise_op(UnaryOps.CAST, self, arg=arg) if self.dtype != arg else self def cast(self:LazyBuffer, arg:Tuple[DType, bool]) -> LazyBuffer:
assert not arg[1] or self.dtype.itemsize == arg[0].itemsize, "can't bitcast mismatched dtype itemsizes"
return elementwise_op(UnaryOps.CAST, self, arg=arg) if self.dtype != arg[0] else self
def unary_op(self:LazyBuffer, op:UnaryOps) -> LazyBuffer: return elementwise_op(op, self) def unary_op(self:LazyBuffer, op:UnaryOps) -> LazyBuffer: return elementwise_op(op, self)
def binary_op(self:LazyBuffer, op:BinaryOps, y:LazyBuffer) -> LazyBuffer: return elementwise_op(op, self, y) def binary_op(self:LazyBuffer, op:BinaryOps, y:LazyBuffer) -> LazyBuffer: return elementwise_op(op, self, y)
def ternary_op(self:LazyBuffer, op:TernaryOps, y: LazyBuffer, z:LazyBuffer) -> LazyBuffer: return elementwise_op(op, self, y, z) def ternary_op(self:LazyBuffer, op:TernaryOps, y: LazyBuffer, z:LazyBuffer) -> LazyBuffer: return elementwise_op(op, self, y, z)
@ -307,7 +309,7 @@ def elementwise_op(op:Union[UnaryOps, BinaryOps, TernaryOps], *srcs:LazyBuffer,
if SHUFFLE_MOVEMENT_OPS: srcs = _push_movement_ops(srcs) if SHUFFLE_MOVEMENT_OPS: srcs = _push_movement_ops(srcs)
# get outputs now # get outputs now
out_device, out_shape, out_dtype = srcs[0].device, srcs[0].shape, max([x.dtype for x in srcs]) if op != UnaryOps.CAST else cast(DType, arg) out_device, out_shape, out_dtype = srcs[0].device, srcs[0].shape, max([x.dtype for x in srcs]) if op != UnaryOps.CAST else cast(Tuple[DType, bool], arg)[0]
# push all contiguous to the end of BinaryOps. kernels 198 -> 196 # push all contiguous to the end of BinaryOps. kernels 198 -> 196
if PUSH_CONTIGUOUS and any(not x.realized and x.op.op == LoadOps.CONTIGUOUS and len(x.op.src[0].children) <= 1 for x in srcs): if PUSH_CONTIGUOUS and any(not x.realized and x.op.op == LoadOps.CONTIGUOUS and len(x.op.src[0].children) <= 1 for x in srcs):

View File

@ -10,12 +10,12 @@ class Contiguous(Function):
def backward(self, grad_output): return grad_output def backward(self, grad_output): return grad_output
class Cast(Function): class Cast(Function):
__slots__ = "input_dtype" __slots__ = "input_dtype", "bitcast"
def forward(self, x, dtype): def forward(self, x, dtype, bitcast=False):
self.input_dtype = x.dtype self.input_dtype, self.bitcast = x.dtype, bitcast
return x.cast(dtype) return x.cast((dtype, bitcast))
def backward(self, grad_output): def backward(self, grad_output):
return grad_output.cast(self.input_dtype) return grad_output.cast((self.input_dtype, self.bitcast))
# ************* unary ops ************* # ************* unary ops *************

View File

@ -113,7 +113,7 @@ class FlopCounter:
return ret return ret
from tinygrad.shape.shapetracker import ShapeTracker from tinygrad.shape.shapetracker import ShapeTracker
shape_fxn_for_op: Dict[Op, Callable] = { shape_fxn_for_op: Dict[Op, Callable] = {
UnaryOps.CAST: lambda self,dtype: (self.shape, dtype, self.consume_flops()), # cast uses no flops UnaryOps.CAST: lambda self,arg: (self.shape, arg[0], self.consume_flops()), # cast uses no flops
**{op:lambda self: (self.shape, self.dtype, self.consume_flops() + prod(self.shape)) for op in UnaryOps if op != UnaryOps.CAST}, **{op:lambda self: (self.shape, self.dtype, self.consume_flops() + prod(self.shape)) for op in UnaryOps if op != UnaryOps.CAST},
**{op:lambda self,y: (self.shape, max(self.dtype, y.dtype), self.consume_flops() + y.consume_flops() + prod(self.shape)) for op in BinaryOps}, **{op:lambda self,y: (self.shape, max(self.dtype, y.dtype), self.consume_flops() + y.consume_flops() + prod(self.shape)) for op in BinaryOps},
**{op:lambda self,new_shape: (new_shape, self.dtype, self.consume_flops() + prod(self.shape)) for op in ReduceOps}, **{op:lambda self,new_shape: (new_shape, self.dtype, self.consume_flops() + prod(self.shape)) for op in ReduceOps},

View File

@ -39,7 +39,7 @@ class RawBufferMapped(RawBufferCopyIn):
# this one is simple enough that i moved it out of the runtimes # this one is simple enough that i moved it out of the runtimes
class RawMallocBuffer(RawBufferMapped): class RawMallocBuffer(RawBufferMapped):
def __init__(self, size, dtype: DType): super().__init__(size, dtype, ({dtypes.float32: ctypes.c_float, dtypes.float16: ctypes.c_int16, dtypes.bfloat16: ctypes.c_int16, dtypes.int8: ctypes.c_int8, dtypes.uint8: ctypes.c_uint8, dtypes.bool: ctypes.c_uint8, dtypes.int32: ctypes.c_int32, dtypes.int64: ctypes.c_int64}[dtype] * size)()) def __init__(self, size, dtype: DType): super().__init__(size, dtype, ({dtypes.float32: ctypes.c_float, dtypes.float16: ctypes.c_int16, dtypes.bfloat16: ctypes.c_int16, dtypes.int8: ctypes.c_int8, dtypes.uint8: ctypes.c_uint8, dtypes.bool: ctypes.c_uint8, dtypes.int32: ctypes.c_int32, dtypes.uint32: ctypes.c_uint32, dtypes.int64: ctypes.c_int64, dtypes.uint64: ctypes.c_uint64}[dtype] * size)())
def _buffer(self): return memoryview(self._buf) def _buffer(self): return memoryview(self._buf)
class RawBufferCopyInOut(RawBufferCopyIn): class RawBufferCopyInOut(RawBufferCopyIn):

View File

@ -32,7 +32,8 @@ def einsum_mulacc(einsum, get_strides, expand):
return mulacc return mulacc
numpy_fxn_for_op: Dict[Op, Callable] = {**base_fxn_for_op, **{ numpy_fxn_for_op: Dict[Op, Callable] = {**base_fxn_for_op, **{
UnaryOps.NOOP: lambda x: np.require(x, requirements='C'), UnaryOps.EXP2: np.exp2, UnaryOps.LOG2: np.log2, UnaryOps.CAST: lambda x,y: x.astype(y.np, copy=False), UnaryOps.SIN: np.sin, UnaryOps.NOOP: lambda x: np.require(x, requirements='C'), UnaryOps.EXP2: np.exp2, UnaryOps.LOG2: np.log2, UnaryOps.SIN: np.sin,
UnaryOps.CAST: lambda x,y: x.view(y[0].np) if y[1] else x.astype(y[0].np, copy=False),
BinaryOps.MAX: np.maximum, BinaryOps.CMPEQ: lambda x,y: (x==y).astype(promote_types(x,y)), BinaryOps.ADD: lambda x, y: np.add(*match_types(x, y)), BinaryOps.MAX: np.maximum, BinaryOps.CMPEQ: lambda x,y: (x==y).astype(promote_types(x,y)), BinaryOps.ADD: lambda x, y: np.add(*match_types(x, y)),
BinaryOps.SUB: lambda x, y: np.subtract(*match_types(x, y)), BinaryOps.MUL: lambda x, y: np.multiply(*match_types(x, y)), BinaryOps.SUB: lambda x, y: np.subtract(*match_types(x, y)), BinaryOps.MUL: lambda x, y: np.multiply(*match_types(x, y)),
BinaryOps.DIV: lambda x, y: np.divide(*match_types(x, y)), UnaryOps.SQRT: np.sqrt, BinaryOps.DIV: lambda x, y: np.divide(*match_types(x, y)), UnaryOps.SQRT: np.sqrt,

View File

@ -1,6 +1,6 @@
import os, mmap import os, mmap
from typing import Optional from typing import Optional
from typing import Callable, Dict from typing import Callable, Dict, Tuple
from tinygrad.helpers import prod, DType from tinygrad.helpers import prod, DType
from tinygrad.runtime.lib import RawBufferMapped from tinygrad.runtime.lib import RawBufferMapped
from tinygrad.ops import Interpreted, Op, MovementOps, UnaryOps from tinygrad.ops import Interpreted, Op, MovementOps, UnaryOps
@ -21,7 +21,7 @@ class RawDiskBuffer(RawBufferMapped):
def __del__(self): def __del__(self):
self._buf[2] -= 1 self._buf[2] -= 1
if self._buf[2] == 0: self._buf[0].close() if self._buf[2] == 0: self._buf[0].close()
def cast(self, new_dtype:DType): return RawDiskBuffer(self.size, new_dtype, buf=self._buf, shape=self.shape, offset=self.offset) def cast(self, arg:Tuple[DType, bool]): return RawDiskBuffer(self.size, arg[0], buf=self._buf, shape=self.shape, offset=self.offset)
def reshape(self, arg): return RawDiskBuffer(self.size, self.dtype, buf=self._buf, shape=arg, offset=self.offset) def reshape(self, arg): return RawDiskBuffer(self.size, self.dtype, buf=self._buf, shape=arg, offset=self.offset)
def shrink(self, arg): def shrink(self, arg):
assert arg[1:] == tuple([(0,x) for x in self.shape[1:]]), f"can only slice the first dim of disk tensor {arg}" assert arg[1:] == tuple([(0,x) for x in self.shape[1:]]), f"can only slice the first dim of disk tensor {arg}"

View File

@ -10,7 +10,8 @@ type_map = {torch.float16: dtypes.float16, torch.float32: dtypes.float32, torch.
inverse_type_map = {v:k for k,v in type_map.items()} inverse_type_map = {v:k for k,v in type_map.items()}
torch_fxn_for_op: Dict[Op, Callable] = {**base_fxn_for_op, **{ torch_fxn_for_op: Dict[Op, Callable] = {**base_fxn_for_op, **{
UnaryOps.NOOP: lambda x: x.contiguous(), UnaryOps.SQRT: lambda x: x.sqrt(), UnaryOps.EXP2: lambda x: x.exp2(), UnaryOps.LOG2: lambda x: x.log2(), UnaryOps.CAST: lambda x,y: x.type(next(k for k,v in type_map.items() if v==y)), UnaryOps.SIN: torch.sin, UnaryOps.NOOP: lambda x: x.contiguous(), UnaryOps.SQRT: lambda x: x.sqrt(), UnaryOps.EXP2: lambda x: x.exp2(), UnaryOps.LOG2: lambda x: x.log2(), UnaryOps.SIN: torch.sin,
UnaryOps.CAST: lambda x,y: (x.view if y[1] else x.type)(next(k for k,v in type_map.items() if v==y[0])),
BinaryOps.MAX: torch.maximum, BinaryOps.CMPEQ: lambda x,y: (x==y).type(torch.promote_types(x.dtype, y.dtype)), BinaryOps.MAX: torch.maximum, BinaryOps.CMPEQ: lambda x,y: (x==y).type(torch.promote_types(x.dtype, y.dtype)),
MovementOps.PAD: lambda x, padding: torch.nn.functional.pad(x, [item for sublist in padding[::-1] for item in sublist]), MovementOps.PAD: lambda x, padding: torch.nn.functional.pad(x, [item for sublist in padding[::-1] for item in sublist]),
TernaryOps.MULACC: einsum_mulacc(lambda s,a,b: torch.einsum(s, a.float(), b.float()).type(torch.promote_types(a.dtype, b.dtype)), lambda x: x.stride(), lambda x,s: x.expand(s)), TernaryOps.MULACC: einsum_mulacc(lambda s,a,b: torch.einsum(s, a.float(), b.float()).type(torch.promote_types(a.dtype, b.dtype)), lambda x: x.stride(), lambda x,s: x.expand(s)),

View File

@ -73,7 +73,8 @@ def torch_load(fn:str):
# https://github.com/facebookresearch/llama/blob/6c7fe276574e78057f917549435a2554000a876d/llama/generation.py#L95 # https://github.com/facebookresearch/llama/blob/6c7fe276574e78057f917549435a2554000a876d/llama/generation.py#L95
# TODO: should this be done in the example instead? or maybe we don't need this anymore with better bfloat16 support # TODO: should this be done in the example instead? or maybe we don't need this anymore with better bfloat16 support
if storage[1] == dtypes.bfloat16: if storage[1] == dtypes.bfloat16:
ret = ret.to("LLVM").half().to(Device.DEFAULT) ret = ret.bitcast(dtypes.uint16).to("CPU").cast(dtypes.uint32).mul(1<<16).bitcast(dtypes.float32).to(Device.DEFAULT).half()
#ret = ret.to("LLVM").half().to(Device.DEFAULT)
# 7 lines to deal with permuted tensors. NOTE: this currently requires reading off the disk # 7 lines to deal with permuted tensors. NOTE: this currently requires reading off the disk
shape_strides = [(s, st) for s,st in zip(size, stride) if s != 1] shape_strides = [(s, st) for s,st in zip(size, stride) if s != 1]

View File

@ -684,6 +684,7 @@ class Tensor:
# ***** cast ops ***** # ***** cast ops *****
def cast(self, dtype:DType) -> Tensor: return mlops.Cast.apply(self, dtype=dtype) if self.dtype != dtype else self def cast(self, dtype:DType) -> Tensor: return mlops.Cast.apply(self, dtype=dtype) if self.dtype != dtype else self
def bitcast(self, dtype:DType) -> Tensor: return mlops.Cast.apply(self, dtype=dtype, bitcast=True) if self.dtype != dtype else self
def float(self) -> Tensor: return self.cast(dtypes.float32) def float(self) -> Tensor: return self.cast(dtypes.float32)
def half(self) -> Tensor: return self.cast(dtypes.float16) def half(self) -> Tensor: return self.cast(dtypes.float16)