mirror of https://github.com/commaai/tinygrad.git
simple bitcast 2 (#1445)
* simple bitcast 2
* bc 2
* empty
* Revert "empty"

This reverts commit d8ee083655b67947afb1e577020b4395d001832c.
This commit is contained in:
parent 943b227cb1
commit d67e248d9b
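The hunks below add a bitcast (reinterpret-the-bytes) path alongside the existing value-converting cast. As a reference for the distinction the new tests encode, here is a minimal sketch using numpy; numpy is only used for illustration, though the PR's CPU backend ends up using the same view-vs-astype split.

import numpy as np

a = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)
print(a.astype(np.int32))  # [1 2 3 4]                -> a normal cast converts the values
print(a.view(np.int32))    # [1065353216 1073741824 1077936128 1082130432] -> a bitcast reinterprets the same bytes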
@@ -27,6 +27,7 @@ def _assert_eq(tensor:Tensor, target_dtype:DType, target):
 def _test_op(fxn, target_dtype:DType, target): _assert_eq(fxn(), target_dtype, target)
 def _test_cast(a:Tensor, target_dtype:DType, target): _test_op(lambda: a.cast(target_dtype), target_dtype, target)
+def _test_bitcast(a:Tensor, target_dtype:DType, target): _test_op(lambda: a.bitcast(target_dtype), target_dtype, target)
 
 # tests no-op casts from source_dtype to target_dtypes
 def _test_casts_from(tensor_contents:List, source_dtype:DType, target_dtypes:List[DType], target_contents:Optional[List]=None):
@@ -110,6 +111,25 @@ class TestInt8Dtype(unittest.TestCase):
   def test_uint8_to_int8_overflow(self): _test_op(lambda: Tensor([255, 254, 253, 252], dtype=dtypes.uint8).cast(dtypes.int8), dtypes.int8, [-1, -2, -3, -4])
 
+@unittest.skipIf(Device.DEFAULT not in {"CPU", "TORCH"}, "only bitcast in CPU and TORCH")
+class TestBitCast(unittest.TestCase):
+  def test_float32_bitcast_to_int32(self): _test_bitcast(Tensor([1,2,3,4], dtype=dtypes.float32), dtypes.int32, [1065353216, 1073741824, 1077936128, 1082130432])
+  @unittest.skipIf(Device.DEFAULT == "TORCH", "no uint32 in torch")
+  def test_float32_bitcast_to_uint32(self): _test_bitcast(Tensor([1,2,3,4], dtype=dtypes.float32), dtypes.uint32, [1065353216, 1073741824, 1077936128, 1082130432])
+  def test_int32_bitcast_to_float32(self): _test_bitcast(Tensor([1065353216, 1073741824, 1077936128, 1082130432], dtype=dtypes.int32), dtypes.float32, [1.0, 2.0, 3.0, 4.0])
+
+  # NOTE: these are the same as normal casts
+  def test_int8_bitcast_to_uint8(self): _test_bitcast(Tensor([-1, -2, -3, -4], dtype=dtypes.int8), dtypes.uint8, [255, 254, 253, 252])
+  def test_uint8_bitcast_to_int8(self): _test_bitcast(Tensor([255, 254, 253, 252], dtype=dtypes.uint8), dtypes.int8, [-1, -2, -3, -4])
+  @unittest.skipIf(Device.DEFAULT == "TORCH", "no uint64 in torch")
+  def test_int64_bitcast_to_uint64(self): _test_bitcast(Tensor([-1, -2, -3, -4], dtype=dtypes.int64), dtypes.uint64, [18446744073709551615, 18446744073709551614, 18446744073709551613, 18446744073709551612])
+  @unittest.skipIf(Device.DEFAULT == "TORCH", "no uint64 in torch")
+  def test_uint64_bitcast_to_int64(self): _test_bitcast(Tensor([18446744073709551615, 18446744073709551614, 18446744073709551613, 18446744073709551612], dtype=dtypes.uint64), dtypes.int64, [-1, -2, -3, -4])
+
+  def test_shape_change_bitcast(self):
+    with self.assertRaises(AssertionError):
+      _test_bitcast(Tensor([100000], dtype=dtypes.float32), dtypes.uint8, [100000])
+
 class TestInt32Dtype(unittest.TestCase):
   def test_int32_to_np(self): _test_to_np(Tensor([1,2,3,4], dtype=dtypes.int32), np.int32, [1,2,3,4])
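An aside on the expected values in TestBitCast: they are the raw IEEE-754 bit patterns of the input floats (1.0f is 0x3F800000 == 1065353216). A quick standalone check, assuming nothing beyond the Python standard library:

import struct

def f32_bits(x: float) -> int:
  # pack the float as 4 little-endian bytes, then read the same bytes back as a uint32
  return struct.unpack("<I", struct.pack("<f", x))[0]

print([f32_bits(x) for x in [1.0, 2.0, 3.0, 4.0]])  # [1065353216, 1073741824, 1077936128, 1082130432]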
@@ -64,6 +64,9 @@ class TestUOps(unittest.TestCase):
   def test_cmpeq(self): self._test_bop_fxn(BinaryOps.CMPEQ, lambda a,b: float(a==b))
   # CMPLT and MOD aren't tested
 
+  # doesn't work in LLVM
+  #def test_add_int32(self): self._test_bop_fxn(BinaryOps.ADD, lambda a,b: a+b, dtypes.int32)
+
   def _test_top_fxn(self, bop, fxn, dt=dtypes.float32):
     for f in [_test_single_value, _test_single_value_const]:
       for a in [-2.0, 0, 1, 2.0]:
@@ -52,7 +52,7 @@ class Token(NamedTuple):
       assert self.offset is None
       return f"{self.dtype.name} {self.name}"
     if self.offset is None: return self.name
-    assert self.dtype in [dtypes._float4, dtypes._float2]
+    assert self.dtype in [dtypes._float4, dtypes._float2], f"{self.dtype} isn't okay with offset {self.offset}"
     return self.name+"."+"xyzw"[int(self.offset)]
   def __repr__(self): return f"<{self.name}>" if self.offset is None and self.dtype == dtypes.float32 else f"<{self.name}:{self.dtype.name}:{self.offset}>"
@@ -121,7 +121,7 @@ class MemOp(NamedTuple):
   invalid_value: Union[float, int] = 0.0
 
 class ConstOp(NamedTuple):
-  value: float
+  value: Union[float, int]
 
   # shared
   valid: Variable
@@ -98,11 +98,13 @@ class dtypes:
   float32: Final[DType] = DType(4, 4, "float", np.float32)
   float = float32
   int8: Final[DType] = DType(0, 1, "char", np.int8)
-  int32: Final[DType] = DType(1, 4, "int", np.int32)
-  int64: Final[DType] = DType(2, 8, "long", np.int64)
+  int16: Final[DType] = DType(1, 2, "short", np.int16)
+  int32: Final[DType] = DType(2, 4, "int", np.int32)
+  int64: Final[DType] = DType(3, 8, "long", np.int64)
   uint8: Final[DType] = DType(0, 1, "unsigned char", np.uint8)
-  uint32: Final[DType] = DType(1, 4, "unsigned int", np.uint32)
-  uint64: Final[DType] = DType(2, 8, "unsigned long", np.uint64)
+  uint16: Final[DType] = DType(1, 2, "unsigned short", np.uint16)
+  uint32: Final[DType] = DType(2, 4, "unsigned int", np.uint32)
+  uint64: Final[DType] = DType(3, 8, "unsigned long", np.uint64)
 
   # NOTE: bfloat16 isn't supported in numpy
   bfloat16: Final[DType] = DType(0, 2, "__bf16", None)
@@ -151,9 +151,9 @@ class LazyBuffer:
       if self.dtype.__class__ is ImageDType and self.optype != MovementOps and (prod(self.shape) != prod(cast(ImageDType, self.dtype).shape) or not any(self.shape[x]%4 == 0 for x in self.st.unit_stride_axes())):
         if self.op.op == MovementOps.RESHAPE:
           # put CAST before the final RESHAPE
-          self.op = LazyOp(MovementOps.RESHAPE, (LazyOp(UnaryOps.CAST, self.op.src, dtypes.float32),), self.op.arg)
+          self.op = LazyOp(MovementOps.RESHAPE, (LazyOp(UnaryOps.CAST, self.op.src, (dtypes.float32, False)),), self.op.arg)
         else:
-          self.op = LazyOp(UnaryOps.CAST, (self.op,), dtypes.float32)
+          self.op = LazyOp(UnaryOps.CAST, (self.op,), (dtypes.float32, False))
         self.dtype = dtypes.float32
       self.realized = Device[self.device].exec_ast(self.op, output=self, **self._device_extra_args())
@@ -186,12 +186,14 @@ class LazyBuffer:
 
   # NOTE: we also have to copy the numpy array on the way out...otherwise the underlying Tensor could be freed and use after free. improve this?
   def toCPU(self):
-    assert self.dtype.np, "numpy dtype is required for toCPU"
-    realized = self.cast(dtypes.from_np(self.dtype.np)).contiguous().realize().realized
+    assert self.dtype.np, f"{self.dtype} is not supported in toCPU"
+    realized = self.cast((dtypes.from_np(self.dtype.np), False)).contiguous().realize().realized
     ret = cast(RawBuffer, realized).toCPU().reshape(self.shape)
     return ret
 
-  def cast(self:LazyBuffer, arg:DType) -> LazyBuffer: return elementwise_op(UnaryOps.CAST, self, arg=arg) if self.dtype != arg else self
+  def cast(self:LazyBuffer, arg:Tuple[DType, bool]) -> LazyBuffer:
+    assert not arg[1] or self.dtype.itemsize == arg[0].itemsize, "can't bitcast mismatched dtype itemsizes"
+    return elementwise_op(UnaryOps.CAST, self, arg=arg) if self.dtype != arg[0] else self
   def unary_op(self:LazyBuffer, op:UnaryOps) -> LazyBuffer: return elementwise_op(op, self)
   def binary_op(self:LazyBuffer, op:BinaryOps, y:LazyBuffer) -> LazyBuffer: return elementwise_op(op, self, y)
   def ternary_op(self:LazyBuffer, op:TernaryOps, y: LazyBuffer, z:LazyBuffer) -> LazyBuffer: return elementwise_op(op, self, y, z)
@@ -307,7 +309,7 @@ def elementwise_op(op:Union[UnaryOps, BinaryOps, TernaryOps], *srcs:LazyBuffer,
   if SHUFFLE_MOVEMENT_OPS: srcs = _push_movement_ops(srcs)
 
   # get outputs now
-  out_device, out_shape, out_dtype = srcs[0].device, srcs[0].shape, max([x.dtype for x in srcs]) if op != UnaryOps.CAST else cast(DType, arg)
+  out_device, out_shape, out_dtype = srcs[0].device, srcs[0].shape, max([x.dtype for x in srcs]) if op != UnaryOps.CAST else cast(Tuple[DType, bool], arg)[0]
 
   # push all contiguous to the end of BinaryOps. kernels 198 -> 196
   if PUSH_CONTIGUOUS and any(not x.realized and x.op.op == LoadOps.CONTIGUOUS and len(x.op.src[0].children) <= 1 for x in srcs):
@@ -10,12 +10,12 @@ class Contiguous(Function):
   def backward(self, grad_output): return grad_output
 
 class Cast(Function):
-  __slots__ = "input_dtype"
-  def forward(self, x, dtype):
-    self.input_dtype = x.dtype
-    return x.cast(dtype)
+  __slots__ = "input_dtype", "bitcast"
+  def forward(self, x, dtype, bitcast=False):
+    self.input_dtype, self.bitcast = x.dtype, bitcast
+    return x.cast((dtype, bitcast))
   def backward(self, grad_output):
-    return grad_output.cast(self.input_dtype)
+    return grad_output.cast((self.input_dtype, self.bitcast))
 
 # ************* unary ops *************
@@ -113,7 +113,7 @@ class FlopCounter:
     return ret
 from tinygrad.shape.shapetracker import ShapeTracker
 shape_fxn_for_op: Dict[Op, Callable] = {
-  UnaryOps.CAST: lambda self,dtype: (self.shape, dtype, self.consume_flops()), # cast uses no flops
+  UnaryOps.CAST: lambda self,arg: (self.shape, arg[0], self.consume_flops()), # cast uses no flops
   **{op:lambda self: (self.shape, self.dtype, self.consume_flops() + prod(self.shape)) for op in UnaryOps if op != UnaryOps.CAST},
   **{op:lambda self,y: (self.shape, max(self.dtype, y.dtype), self.consume_flops() + y.consume_flops() + prod(self.shape)) for op in BinaryOps},
   **{op:lambda self,new_shape: (new_shape, self.dtype, self.consume_flops() + prod(self.shape)) for op in ReduceOps},
@@ -39,7 +39,7 @@ class RawBufferMapped(RawBufferCopyIn):
 
 # this one is simple enough that i moved it out of the runtimes
 class RawMallocBuffer(RawBufferMapped):
-  def __init__(self, size, dtype: DType): super().__init__(size, dtype, ({dtypes.float32: ctypes.c_float, dtypes.float16: ctypes.c_int16, dtypes.bfloat16: ctypes.c_int16, dtypes.int8: ctypes.c_int8, dtypes.uint8: ctypes.c_uint8, dtypes.bool: ctypes.c_uint8, dtypes.int32: ctypes.c_int32, dtypes.int64: ctypes.c_int64}[dtype] * size)())
+  def __init__(self, size, dtype: DType): super().__init__(size, dtype, ({dtypes.float32: ctypes.c_float, dtypes.float16: ctypes.c_int16, dtypes.bfloat16: ctypes.c_int16, dtypes.int8: ctypes.c_int8, dtypes.uint8: ctypes.c_uint8, dtypes.bool: ctypes.c_uint8, dtypes.int32: ctypes.c_int32, dtypes.uint32: ctypes.c_uint32, dtypes.int64: ctypes.c_int64, dtypes.uint64: ctypes.c_uint64}[dtype] * size)())
   def _buffer(self): return memoryview(self._buf)
 
 class RawBufferCopyInOut(RawBufferCopyIn):
@@ -32,7 +32,8 @@ def einsum_mulacc(einsum, get_strides, expand):
   return mulacc
 
 numpy_fxn_for_op: Dict[Op, Callable] = {**base_fxn_for_op, **{
-  UnaryOps.NOOP: lambda x: np.require(x, requirements='C'), UnaryOps.EXP2: np.exp2, UnaryOps.LOG2: np.log2, UnaryOps.CAST: lambda x,y: x.astype(y.np, copy=False), UnaryOps.SIN: np.sin,
+  UnaryOps.NOOP: lambda x: np.require(x, requirements='C'), UnaryOps.EXP2: np.exp2, UnaryOps.LOG2: np.log2, UnaryOps.SIN: np.sin,
+  UnaryOps.CAST: lambda x,y: x.view(y[0].np) if y[1] else x.astype(y[0].np, copy=False),
   BinaryOps.MAX: np.maximum, BinaryOps.CMPEQ: lambda x,y: (x==y).astype(promote_types(x,y)), BinaryOps.ADD: lambda x, y: np.add(*match_types(x, y)),
   BinaryOps.SUB: lambda x, y: np.subtract(*match_types(x, y)), BinaryOps.MUL: lambda x, y: np.multiply(*match_types(x, y)),
   BinaryOps.DIV: lambda x, y: np.divide(*match_types(x, y)), UnaryOps.SQRT: np.sqrt,
@@ -1,6 +1,6 @@
 import os, mmap
 from typing import Optional
-from typing import Callable, Dict
+from typing import Callable, Dict, Tuple
 from tinygrad.helpers import prod, DType
 from tinygrad.runtime.lib import RawBufferMapped
 from tinygrad.ops import Interpreted, Op, MovementOps, UnaryOps
@@ -21,7 +21,7 @@ class RawDiskBuffer(RawBufferMapped):
   def __del__(self):
     self._buf[2] -= 1
     if self._buf[2] == 0: self._buf[0].close()
-  def cast(self, new_dtype:DType): return RawDiskBuffer(self.size, new_dtype, buf=self._buf, shape=self.shape, offset=self.offset)
+  def cast(self, arg:Tuple[DType, bool]): return RawDiskBuffer(self.size, arg[0], buf=self._buf, shape=self.shape, offset=self.offset)
   def reshape(self, arg): return RawDiskBuffer(self.size, self.dtype, buf=self._buf, shape=arg, offset=self.offset)
   def shrink(self, arg):
     assert arg[1:] == tuple([(0,x) for x in self.shape[1:]]), f"can only slice the first dim of disk tensor {arg}"
@@ -10,7 +10,8 @@ type_map = {torch.float16: dtypes.float16, torch.float32: dtypes.float32, torch.
 inverse_type_map = {v:k for k,v in type_map.items()}
 
 torch_fxn_for_op: Dict[Op, Callable] = {**base_fxn_for_op, **{
-  UnaryOps.NOOP: lambda x: x.contiguous(), UnaryOps.SQRT: lambda x: x.sqrt(), UnaryOps.EXP2: lambda x: x.exp2(), UnaryOps.LOG2: lambda x: x.log2(), UnaryOps.CAST: lambda x,y: x.type(next(k for k,v in type_map.items() if v==y)), UnaryOps.SIN: torch.sin,
+  UnaryOps.NOOP: lambda x: x.contiguous(), UnaryOps.SQRT: lambda x: x.sqrt(), UnaryOps.EXP2: lambda x: x.exp2(), UnaryOps.LOG2: lambda x: x.log2(), UnaryOps.SIN: torch.sin,
+  UnaryOps.CAST: lambda x,y: (x.view if y[1] else x.type)(next(k for k,v in type_map.items() if v==y[0])),
   BinaryOps.MAX: torch.maximum, BinaryOps.CMPEQ: lambda x,y: (x==y).type(torch.promote_types(x.dtype, y.dtype)),
   MovementOps.PAD: lambda x, padding: torch.nn.functional.pad(x, [item for sublist in padding[::-1] for item in sublist]),
   TernaryOps.MULACC: einsum_mulacc(lambda s,a,b: torch.einsum(s, a.float(), b.float()).type(torch.promote_types(a.dtype, b.dtype)), lambda x: x.stride(), lambda x,s: x.expand(s)),
@@ -73,7 +73,8 @@ def torch_load(fn:str):
     # https://github.com/facebookresearch/llama/blob/6c7fe276574e78057f917549435a2554000a876d/llama/generation.py#L95
     # TODO: should this be done in the example instead? or maybe we don't need this anymore with better bfloat16 support
     if storage[1] == dtypes.bfloat16:
-      ret = ret.to("LLVM").half().to(Device.DEFAULT)
+      ret = ret.bitcast(dtypes.uint16).to("CPU").cast(dtypes.uint32).mul(1<<16).bitcast(dtypes.float32).to(Device.DEFAULT).half()
+      #ret = ret.to("LLVM").half().to(Device.DEFAULT)
 
     # 7 lines to deal with permuted tensors. NOTE: this currently requires reading off the disk
     shape_strides = [(s, st) for s,st in zip(size, stride) if s != 1]
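Why the new bfloat16 load path works: a bfloat16 value is just the top 16 bits of the corresponding float32, so widening the raw uint16 bits to uint32, shifting them into the high half (the mul by 1<<16), and bitcasting yields the float32 value. A numpy check of the same arithmetic (0x3F80 is the bfloat16 pattern for 1.0):

import numpy as np

bf16_bits = np.array([0x3F80, 0x4000], dtype=np.uint16)     # bfloat16 bit patterns for 1.0 and 2.0
f32 = (bf16_bits.astype(np.uint32) << 16).view(np.float32)  # same idea as .cast(uint32).mul(1<<16).bitcast(float32)
print(f32)  # [1. 2.]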
@@ -684,6 +684,7 @@ class Tensor:
   # ***** cast ops *****
 
   def cast(self, dtype:DType) -> Tensor: return mlops.Cast.apply(self, dtype=dtype) if self.dtype != dtype else self
+  def bitcast(self, dtype:DType) -> Tensor: return mlops.Cast.apply(self, dtype=dtype, bitcast=True) if self.dtype != dtype else self
   def float(self) -> Tensor: return self.cast(dtypes.float32)
   def half(self) -> Tensor: return self.cast(dtypes.float16)
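End-to-end, the new Tensor.bitcast sits next to Tensor.cast. A short usage sketch mirroring the tests; the import paths assume the repo layout at the time of this commit, and bitcast is only wired up on the CPU and TORCH backends:

from tinygrad.tensor import Tensor
from tinygrad.helpers import dtypes

t = Tensor([1.0, 2.0, 3.0, 4.0], dtype=dtypes.float32)
print(t.cast(dtypes.int32).numpy())     # [1 2 3 4]                      -- values converted
print(t.bitcast(dtypes.int32).numpy())  # [1065353216 1073741824 ...]    -- bytes reinterpreted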