mirror of https://github.com/commaai/tinygrad.git
simple bitcast 2 (#1445)
* simple bitcast 2
* bc 2
* empty
* Revert "empty"

  This reverts commit d8ee083655b67947afb1e577020b4395d001832c.
This commit is contained in:
parent 943b227cb1
commit d67e248d9b
@@ -27,6 +27,7 @@ def _assert_eq(tensor:Tensor, target_dtype:DType, target):
 def _test_op(fxn, target_dtype:DType, target): _assert_eq(fxn(), target_dtype, target)
 def _test_cast(a:Tensor, target_dtype:DType, target): _test_op(lambda: a.cast(target_dtype), target_dtype, target)
+def _test_bitcast(a:Tensor, target_dtype:DType, target): _test_op(lambda: a.bitcast(target_dtype), target_dtype, target)

 # tests no-op casts from source_dtype to target_dtypes
 def _test_casts_from(tensor_contents:List, source_dtype:DType, target_dtypes:List[DType], target_contents:Optional[List]=None):
@@ -110,6 +111,25 @@ class TestInt8Dtype(unittest.TestCase):

   def test_uint8_to_int8_overflow(self): _test_op(lambda: Tensor([255, 254, 253, 252], dtype=dtypes.uint8).cast(dtypes.int8), dtypes.int8, [-1, -2, -3, -4])

+@unittest.skipIf(Device.DEFAULT not in {"CPU", "TORCH"}, "only bitcast in CPU and TORCH")
+class TestBitCast(unittest.TestCase):
+  def test_float32_bitcast_to_int32(self): _test_bitcast(Tensor([1,2,3,4], dtype=dtypes.float32), dtypes.int32, [1065353216, 1073741824, 1077936128, 1082130432])
+  @unittest.skipIf(Device.DEFAULT == "TORCH", "no uint32 in torch")
+  def test_float32_bitcast_to_uint32(self): _test_bitcast(Tensor([1,2,3,4], dtype=dtypes.float32), dtypes.uint32, [1065353216, 1073741824, 1077936128, 1082130432])
+  def test_int32_bitcast_to_float32(self): _test_bitcast(Tensor([1065353216, 1073741824, 1077936128, 1082130432], dtype=dtypes.int32), dtypes.float32, [1.0, 2.0, 3.0, 4.0])
+
+  # NOTE: these are the same as normal casts
+  def test_int8_bitcast_to_uint8(self): _test_bitcast(Tensor([-1, -2, -3, -4], dtype=dtypes.int8), dtypes.uint8, [255, 254, 253, 252])
+  def test_uint8_bitcast_to_int8(self): _test_bitcast(Tensor([255, 254, 253, 252], dtype=dtypes.uint8), dtypes.int8, [-1, -2, -3, -4])
+  @unittest.skipIf(Device.DEFAULT == "TORCH", "no uint64 in torch")
+  def test_int64_bitcast_to_uint64(self): _test_bitcast(Tensor([-1, -2, -3, -4], dtype=dtypes.int64), dtypes.uint64, [18446744073709551615, 18446744073709551614, 18446744073709551613, 18446744073709551612])
+  @unittest.skipIf(Device.DEFAULT == "TORCH", "no uint64 in torch")
+  def test_uint64_bitcast_to_int64(self): _test_bitcast(Tensor([18446744073709551615, 18446744073709551614, 18446744073709551613, 18446744073709551612], dtype=dtypes.uint64), dtypes.int64, [-1, -2, -3, -4])
+
+  def test_shape_change_bitcast(self):
+    with self.assertRaises(AssertionError):
+      _test_bitcast(Tensor([100000], dtype=dtypes.float32), dtypes.uint8, [100000])
+
 class TestInt32Dtype(unittest.TestCase):
   def test_int32_to_np(self): _test_to_np(Tensor([1,2,3,4], dtype=dtypes.int32), np.int32, [1,2,3,4])

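For reference, the expected integers in these tests are just the IEEE 754 bit patterns of 1.0–4.0 read back as 32-bit integers; a minimal numpy sketch (independent of tinygrad) reproduces them:

```python
import numpy as np

floats = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)
print(floats.view(np.int32))    # [1065353216 1073741824 1077936128 1082130432] -- reinterpret bytes
print(floats.astype(np.int32))  # [1 2 3 4] -- what a normal cast does
```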
@@ -64,6 +64,9 @@ class TestUOps(unittest.TestCase):
   def test_cmpeq(self): self._test_bop_fxn(BinaryOps.CMPEQ, lambda a,b: float(a==b))
   # CMPLT and MOD aren't tested

+  # doesn't work in LLVM
+  #def test_add_int32(self): self._test_bop_fxn(BinaryOps.ADD, lambda a,b: a+b, dtypes.int32)
+
   def _test_top_fxn(self, bop, fxn, dt=dtypes.float32):
     for f in [_test_single_value, _test_single_value_const]:
       for a in [-2.0, 0, 1, 2.0]:
@@ -52,7 +52,7 @@ class Token(NamedTuple):
       assert self.offset is None
       return f"{self.dtype.name} {self.name}"
     if self.offset is None: return self.name
-    assert self.dtype in [dtypes._float4, dtypes._float2]
+    assert self.dtype in [dtypes._float4, dtypes._float2], f"{self.dtype} isn't okay with offset {self.offset}"
     return self.name+"."+"xyzw"[int(self.offset)]
   def __repr__(self): return f"<{self.name}>" if self.offset is None and self.dtype == dtypes.float32 else f"<{self.name}:{self.dtype.name}:{self.offset}>"

@@ -121,7 +121,7 @@ class MemOp(NamedTuple):
   invalid_value: Union[float, int] = 0.0

 class ConstOp(NamedTuple):
-  value: float
+  value: Union[float, int]

   # shared
   valid: Variable
@@ -98,11 +98,13 @@ class dtypes:
   float32: Final[DType] = DType(4, 4, "float", np.float32)
   float = float32
   int8: Final[DType] = DType(0, 1, "char", np.int8)
-  int32: Final[DType] = DType(1, 4, "int", np.int32)
-  int64: Final[DType] = DType(2, 8, "long", np.int64)
+  int16: Final[DType] = DType(1, 2, "short", np.int16)
+  int32: Final[DType] = DType(2, 4, "int", np.int32)
+  int64: Final[DType] = DType(3, 8, "long", np.int64)
   uint8: Final[DType] = DType(0, 1, "unsigned char", np.uint8)
-  uint32: Final[DType] = DType(1, 4, "unsigned int", np.uint32)
-  uint64: Final[DType] = DType(2, 8, "unsigned long", np.uint64)
+  uint16: Final[DType] = DType(1, 2, "unsigned short", np.uint16)
+  uint32: Final[DType] = DType(2, 4, "unsigned int", np.uint32)
+  uint64: Final[DType] = DType(3, 8, "unsigned long", np.uint64)

   # NOTE: bfloat16 isn't supported in numpy
   bfloat16: Final[DType] = DType(0, 2, "__bf16", None)
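The first two constructor arguments are the dtype's promotion priority and item size in bytes, so int32/int64 and uint32/uint64 move up one slot to make room for the new 16-bit types. A rough sketch of how that priority drives promotion via max(); the field names here are assumptions for illustration, not the exact helpers code:

```python
from typing import NamedTuple, Optional

class DType(NamedTuple):      # illustrative stand-in, field names assumed
  priority: int
  itemsize: int
  name: str
  np: Optional[type]

int16 = DType(1, 2, "short", None)
int32 = DType(2, 4, "int", None)
int64 = DType(3, 8, "long", None)

# elementwise ops pick the output dtype with max() over the inputs,
# so tuple ordering (priority first) decides the winner
print(max([int16, int32]).name)  # int
print(max([int32, int64]).name)  # long
```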
@@ -151,9 +151,9 @@ class LazyBuffer:
     if self.dtype.__class__ is ImageDType and self.optype != MovementOps and (prod(self.shape) != prod(cast(ImageDType, self.dtype).shape) or not any(self.shape[x]%4 == 0 for x in self.st.unit_stride_axes())):
       if self.op.op == MovementOps.RESHAPE:
         # put CAST before the final RESHAPE
-        self.op = LazyOp(MovementOps.RESHAPE, (LazyOp(UnaryOps.CAST, self.op.src, dtypes.float32),), self.op.arg)
+        self.op = LazyOp(MovementOps.RESHAPE, (LazyOp(UnaryOps.CAST, self.op.src, (dtypes.float32, False)),), self.op.arg)
       else:
-        self.op = LazyOp(UnaryOps.CAST, (self.op,), dtypes.float32)
+        self.op = LazyOp(UnaryOps.CAST, (self.op,), (dtypes.float32, False))
       self.dtype = dtypes.float32
     self.realized = Device[self.device].exec_ast(self.op, output=self, **self._device_extra_args())

@@ -186,12 +186,14 @@ class LazyBuffer:

   # NOTE: we also have to copy the numpy array on the way out...otherwise the underlying Tensor could be freed and use after free. improve this?
   def toCPU(self):
-    assert self.dtype.np, "numpy dtype is required for toCPU"
-    realized = self.cast(dtypes.from_np(self.dtype.np)).contiguous().realize().realized
+    assert self.dtype.np, f"{self.dtype} is not supported in toCPU"
+    realized = self.cast((dtypes.from_np(self.dtype.np), False)).contiguous().realize().realized
     ret = cast(RawBuffer, realized).toCPU().reshape(self.shape)
     return ret

-  def cast(self:LazyBuffer, arg:DType) -> LazyBuffer: return elementwise_op(UnaryOps.CAST, self, arg=arg) if self.dtype != arg else self
+  def cast(self:LazyBuffer, arg:Tuple[DType, bool]) -> LazyBuffer:
+    assert not arg[1] or self.dtype.itemsize == arg[0].itemsize, "can't bitcast mismatched dtype itemsizes"
+    return elementwise_op(UnaryOps.CAST, self, arg=arg) if self.dtype != arg[0] else self
   def unary_op(self:LazyBuffer, op:UnaryOps) -> LazyBuffer: return elementwise_op(op, self)
   def binary_op(self:LazyBuffer, op:BinaryOps, y:LazyBuffer) -> LazyBuffer: return elementwise_op(op, self, y)
   def ternary_op(self:LazyBuffer, op:TernaryOps, y: LazyBuffer, z:LazyBuffer) -> LazyBuffer: return elementwise_op(op, self, y, z)
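cast now takes a (dtype, bitcast) pair, and the itemsize assertion rejects reinterpretations that would change the element count (the case test_shape_change_bitcast expects to fail). A small numpy sketch of why, independent of tinygrad's internals:

```python
import numpy as np

x = np.array([100000.0], dtype=np.float32)
print(x.view(np.int32))  # 4 -> 4 bytes per element: still 1 element
print(x.view(np.uint8))  # 4 -> 1 byte per element: now 4 elements, shape changed
```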
@@ -307,7 +309,7 @@ def elementwise_op(op:Union[UnaryOps, BinaryOps, TernaryOps], *srcs:LazyBuffer,
   if SHUFFLE_MOVEMENT_OPS: srcs = _push_movement_ops(srcs)

   # get outputs now
-  out_device, out_shape, out_dtype = srcs[0].device, srcs[0].shape, max([x.dtype for x in srcs]) if op != UnaryOps.CAST else cast(DType, arg)
+  out_device, out_shape, out_dtype = srcs[0].device, srcs[0].shape, max([x.dtype for x in srcs]) if op != UnaryOps.CAST else cast(Tuple[DType, bool], arg)[0]

   # push all contiguous to the end of BinaryOps. kernels 198 -> 196
   if PUSH_CONTIGUOUS and any(not x.realized and x.op.op == LoadOps.CONTIGUOUS and len(x.op.src[0].children) <= 1 for x in srcs):
@@ -10,12 +10,12 @@ class Contiguous(Function):
   def backward(self, grad_output): return grad_output

 class Cast(Function):
-  __slots__ = "input_dtype"
-  def forward(self, x, dtype):
-    self.input_dtype = x.dtype
-    return x.cast(dtype)
+  __slots__ = "input_dtype", "bitcast"
+  def forward(self, x, dtype, bitcast=False):
+    self.input_dtype, self.bitcast = x.dtype, bitcast
+    return x.cast((dtype, bitcast))
   def backward(self, grad_output):
-    return grad_output.cast(self.input_dtype)
+    return grad_output.cast((self.input_dtype, self.bitcast))

 # ************* unary ops *************

@@ -113,7 +113,7 @@ class FlopCounter:
     return ret
 from tinygrad.shape.shapetracker import ShapeTracker
 shape_fxn_for_op: Dict[Op, Callable] = {
-  UnaryOps.CAST: lambda self,dtype: (self.shape, dtype, self.consume_flops()), # cast uses no flops
+  UnaryOps.CAST: lambda self,arg: (self.shape, arg[0], self.consume_flops()), # cast uses no flops
   **{op:lambda self: (self.shape, self.dtype, self.consume_flops() + prod(self.shape)) for op in UnaryOps if op != UnaryOps.CAST},
   **{op:lambda self,y: (self.shape, max(self.dtype, y.dtype), self.consume_flops() + y.consume_flops() + prod(self.shape)) for op in BinaryOps},
   **{op:lambda self,new_shape: (new_shape, self.dtype, self.consume_flops() + prod(self.shape)) for op in ReduceOps},
@@ -39,7 +39,7 @@ class RawBufferMapped(RawBufferCopyIn):

 # this one is simple enough that i moved it out of the runtimes
 class RawMallocBuffer(RawBufferMapped):
-  def __init__(self, size, dtype: DType): super().__init__(size, dtype, ({dtypes.float32: ctypes.c_float, dtypes.float16: ctypes.c_int16, dtypes.bfloat16: ctypes.c_int16, dtypes.int8: ctypes.c_int8, dtypes.uint8: ctypes.c_uint8, dtypes.bool: ctypes.c_uint8, dtypes.int32: ctypes.c_int32, dtypes.int64: ctypes.c_int64}[dtype] * size)())
+  def __init__(self, size, dtype: DType): super().__init__(size, dtype, ({dtypes.float32: ctypes.c_float, dtypes.float16: ctypes.c_int16, dtypes.bfloat16: ctypes.c_int16, dtypes.int8: ctypes.c_int8, dtypes.uint8: ctypes.c_uint8, dtypes.bool: ctypes.c_uint8, dtypes.int32: ctypes.c_int32, dtypes.uint32: ctypes.c_uint32, dtypes.int64: ctypes.c_int64, dtypes.uint64: ctypes.c_uint64}[dtype] * size)())
   def _buffer(self): return memoryview(self._buf)

 class RawBufferCopyInOut(RawBufferCopyIn):
@@ -32,7 +32,8 @@ def einsum_mulacc(einsum, get_strides, expand):
   return mulacc

 numpy_fxn_for_op: Dict[Op, Callable] = {**base_fxn_for_op, **{
-  UnaryOps.NOOP: lambda x: np.require(x, requirements='C'), UnaryOps.EXP2: np.exp2, UnaryOps.LOG2: np.log2, UnaryOps.CAST: lambda x,y: x.astype(y.np, copy=False), UnaryOps.SIN: np.sin,
+  UnaryOps.NOOP: lambda x: np.require(x, requirements='C'), UnaryOps.EXP2: np.exp2, UnaryOps.LOG2: np.log2, UnaryOps.SIN: np.sin,
+  UnaryOps.CAST: lambda x,y: x.view(y[0].np) if y[1] else x.astype(y[0].np, copy=False),
   BinaryOps.MAX: np.maximum, BinaryOps.CMPEQ: lambda x,y: (x==y).astype(promote_types(x,y)), BinaryOps.ADD: lambda x, y: np.add(*match_types(x, y)),
   BinaryOps.SUB: lambda x, y: np.subtract(*match_types(x, y)), BinaryOps.MUL: lambda x, y: np.multiply(*match_types(x, y)),
   BinaryOps.DIV: lambda x, y: np.divide(*match_types(x, y)), UnaryOps.SQRT: np.sqrt,
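In the numpy backend the two flavors of CAST map onto two different ndarray operations: astype converts values, while view reinterprets the existing buffer. A standalone comparison (plain numpy, not tinygrad):

```python
import numpy as np

b = np.array([1.5], dtype=np.float32)
print(b.astype(np.int32))  # [1]          -- converts (truncates) the value
print(b.view(np.int32))    # [1069547520] -- raw float32 bit pattern

a = np.array([-1, -2], dtype=np.int8)
print(a.astype(np.uint8))  # [255 254] -- for same-size integer types the two agree
print(a.view(np.uint8))    # [255 254]
```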
@@ -1,6 +1,6 @@
 import os, mmap
 from typing import Optional
-from typing import Callable, Dict
+from typing import Callable, Dict, Tuple
 from tinygrad.helpers import prod, DType
 from tinygrad.runtime.lib import RawBufferMapped
 from tinygrad.ops import Interpreted, Op, MovementOps, UnaryOps
@@ -21,7 +21,7 @@ class RawDiskBuffer(RawBufferMapped):
   def __del__(self):
     self._buf[2] -= 1
     if self._buf[2] == 0: self._buf[0].close()
-  def cast(self, new_dtype:DType): return RawDiskBuffer(self.size, new_dtype, buf=self._buf, shape=self.shape, offset=self.offset)
+  def cast(self, arg:Tuple[DType, bool]): return RawDiskBuffer(self.size, arg[0], buf=self._buf, shape=self.shape, offset=self.offset)
   def reshape(self, arg): return RawDiskBuffer(self.size, self.dtype, buf=self._buf, shape=arg, offset=self.offset)
   def shrink(self, arg):
     assert arg[1:] == tuple([(0,x) for x in self.shape[1:]]), f"can only slice the first dim of disk tensor {arg}"
@@ -10,7 +10,8 @@ type_map = {torch.float16: dtypes.float16, torch.float32: dtypes.float32, torch.
 inverse_type_map = {v:k for k,v in type_map.items()}

 torch_fxn_for_op: Dict[Op, Callable] = {**base_fxn_for_op, **{
-  UnaryOps.NOOP: lambda x: x.contiguous(), UnaryOps.SQRT: lambda x: x.sqrt(), UnaryOps.EXP2: lambda x: x.exp2(), UnaryOps.LOG2: lambda x: x.log2(), UnaryOps.CAST: lambda x,y: x.type(next(k for k,v in type_map.items() if v==y)), UnaryOps.SIN: torch.sin,
+  UnaryOps.NOOP: lambda x: x.contiguous(), UnaryOps.SQRT: lambda x: x.sqrt(), UnaryOps.EXP2: lambda x: x.exp2(), UnaryOps.LOG2: lambda x: x.log2(), UnaryOps.SIN: torch.sin,
+  UnaryOps.CAST: lambda x,y: (x.view if y[1] else x.type)(next(k for k,v in type_map.items() if v==y[0])),
   BinaryOps.MAX: torch.maximum, BinaryOps.CMPEQ: lambda x,y: (x==y).type(torch.promote_types(x.dtype, y.dtype)),
   MovementOps.PAD: lambda x, padding: torch.nn.functional.pad(x, [item for sublist in padding[::-1] for item in sublist]),
   TernaryOps.MULACC: einsum_mulacc(lambda s,a,b: torch.einsum(s, a.float(), b.float()).type(torch.promote_types(a.dtype, b.dtype)), lambda x: x.stride(), lambda x,s: x.expand(s)),
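The torch backend makes the same split: Tensor.type converts values and Tensor.view(dtype) reinterprets storage, which is why the tests above skip uint32/uint64 bitcasts on TORCH (torch has no such dtypes). A standalone torch example, assuming a reasonably recent torch:

```python
import torch

t = torch.tensor([1.0, 2.0], dtype=torch.float32)
print(t.type(torch.int32))  # tensor([1, 2], dtype=torch.int32)
print(t.view(torch.int32))  # tensor([1065353216, 1073741824], dtype=torch.int32)
```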
@@ -73,7 +73,8 @@ def torch_load(fn:str):
     # https://github.com/facebookresearch/llama/blob/6c7fe276574e78057f917549435a2554000a876d/llama/generation.py#L95
     # TODO: should this be done in the example instead? or maybe we don't need this anymore with better bfloat16 support
     if storage[1] == dtypes.bfloat16:
-      ret = ret.to("LLVM").half().to(Device.DEFAULT)
+      ret = ret.bitcast(dtypes.uint16).to("CPU").cast(dtypes.uint32).mul(1<<16).bitcast(dtypes.float32).to(Device.DEFAULT).half()
+      #ret = ret.to("LLVM").half().to(Device.DEFAULT)

     # 7 lines to deal with permuted tensors. NOTE: this currently requires reading off the disk
     shape_strides = [(s, st) for s,st in zip(size, stride) if s != 1]
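This swaps the LLVM round-trip for pure bit manipulation: bfloat16 is the upper 16 bits of a float32, so widening the uint16 pattern to uint32 and shifting it left 16 bits (done here with mul(1<<16)) gives the equivalent float32. A numpy sketch of the same trick, with hand-picked example bit patterns:

```python
import numpy as np

bf16_bits = np.array([0x3F80, 0x4000], dtype=np.uint16)  # bfloat16 patterns for 1.0, 2.0
f32 = (bf16_bits.astype(np.uint32) * np.uint32(1 << 16)).view(np.float32)
print(f32)  # [1. 2.]
```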
@@ -684,6 +684,7 @@ class Tensor:
   # ***** cast ops *****

   def cast(self, dtype:DType) -> Tensor: return mlops.Cast.apply(self, dtype=dtype) if self.dtype != dtype else self
+  def bitcast(self, dtype:DType) -> Tensor: return mlops.Cast.apply(self, dtype=dtype, bitcast=True) if self.dtype != dtype else self
   def float(self) -> Tensor: return self.cast(dtypes.float32)
   def half(self) -> Tensor: return self.cast(dtypes.float16)

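Taken together, the user-facing addition is Tensor.bitcast alongside Tensor.cast. A usage sketch based on the new tests; the import paths assume the repo layout at the time of this commit, and bitcast only runs on the CPU and TORCH backends:

```python
from tinygrad.tensor import Tensor
from tinygrad.helpers import dtypes

a = Tensor([1, 2, 3, 4], dtype=dtypes.float32)
print(a.cast(dtypes.int32).numpy())     # [1 2 3 4] -- converts values
print(a.bitcast(dtypes.int32).numpy())  # [1065353216 1073741824 1077936128 1082130432]
```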