mirror of https://github.com/commaai/tinygrad.git
support using str to specify dtype (#5897)
* support using str to specify dtype in Tensor creation, in args to `cast` and `bitcast`, and in acc_dtype
* more tests
This commit is contained in:
parent 4f9221e8dd
commit c67e9887f7
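For orientation before the diff: a minimal, illustrative sketch of the string-dtype shorthand this commit adds (it mirrors the new `test_dtype_str_arg` test below; dtype strings are resolved against attributes of `dtypes`):

    from tinygrad import Tensor, dtypes

    t = Tensor([1.0, 2.0, 3.0], dtype="float32")  # same as dtype=dtypes.float32
    u = t.cast("int32")                           # cast accepts a str
    v = t.bitcast("uint32")                       # bitcast too (itemsize must match)
    s = t.sum(acc_dtype="int16")                  # reductions accept a str acc_dtype
    assert u.dtype == dtypes.int32 and v.dtype == dtypes.uint32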
@@ -395,6 +395,23 @@ class TestTypeSpec(unittest.TestCase):
       subprocess.run(['DEFAULT_FLOAT=TYPO python3 -c "from tinygrad import dtypes"'],
                       shell=True, check=True)
 
+  def test_dtype_str_arg(self):
+    n = np.random.normal(0, 1, (10, 10)).astype(np.float32)
+    tested = 0
+    for dtype_str, dtype in [
+      ("bool", dtypes.bool), ("int8", dtypes.int8), ("int", dtypes.int), ("uint32", dtypes.uint32), ("float32", dtypes.float32)]:
+      np.testing.assert_equal(Tensor(n, dtype=dtype_str).numpy(), Tensor(n, dtype=dtype).numpy())
+      np.testing.assert_equal(Tensor(n).cast(dtype_str).numpy(), Tensor(n).cast(dtype).numpy())
+      if dtype.itemsize == 4:
+        np.testing.assert_equal(Tensor(n).bitcast(dtype_str).numpy(), Tensor(n).bitcast(dtype).numpy())
+        tested += 1
+    assert tested == 3
+
+    with self.assertRaises(AttributeError): Tensor([1, 2, 3], dtype="nonexistdtype")
+    with self.assertRaises(AttributeError): Tensor([1, 2, 3], dtype="")
+
+    np.testing.assert_equal(Tensor(n).sum(acc_dtype="int16").numpy(), Tensor(n).sum(acc_dtype=dtypes.int16).numpy())
+
   @given(strat.sampled_from(dtype_ints), strat.sampled_from(dtype_floats))
   def test_creation(self, default_int, default_float):
     dtypes.default_int, dtypes.default_float = default_int, default_float

@@ -96,6 +96,9 @@ if (env_default_float := getenv("DEFAULT_FLOAT", "")):
   dtypes.default_float = getattr(dtypes, env_default_float.lower())
   assert dtypes.is_float(dtypes.default_float), f"{env_default_float} is not a float dtype"
 
+DTypeLike = Union[str, DType]
+def to_dtype(dtype:DTypeLike) -> DType: return dtype if isinstance(dtype, DType) else getattr(dtypes, dtype)
+
 # https://jax.readthedocs.io/en/latest/jep/9407-type-promotion.html
 # we don't support weak type and complex type
 promo_lattice = { dtypes.bool: [dtypes.int8, dtypes.uint8], dtypes.int8: [dtypes.int16], dtypes.int16: [dtypes.int32], dtypes.int32: [dtypes.int64],

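The `to_dtype` helper added above is the single conversion point: a `DType` passes through unchanged, while a string is looked up with `getattr(dtypes, ...)`, so an unknown name raises `AttributeError` (which the new test asserts). A quick illustrative sketch:

    from tinygrad.dtype import dtypes, to_dtype

    assert to_dtype("int8") == dtypes.int8              # str resolved via getattr(dtypes, "int8")
    assert to_dtype(dtypes.float32) == dtypes.float32   # DType passes through unchanged
    # to_dtype("nonexistdtype") raises AttributeError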
@@ -1,6 +1,6 @@
 from __future__ import annotations
 from typing import Union, Optional, Any, Tuple, List, get_args
-from tinygrad.dtype import dtypes, DType, ConstType
+from tinygrad.dtype import dtypes, DType, DTypeLike, ConstType, to_dtype
 from tinygrad.helpers import prod, getenv, all_int, all_same, DEBUG, _METADATA, Metadata
 from tinygrad.ops import MetaOps, UnaryOps, BinaryOps, TernaryOps, ReduceOps, Op, exec_alu, python_alu, reduce_st
 from tinygrad.shape.symbolic import sint, Variable

@@ -9,9 +9,10 @@ from tinygrad.device import Buffer
 from weakref import ref, ReferenceType, WeakValueDictionary
 
 lazycache: WeakValueDictionary[Any, LazyBuffer] = WeakValueDictionary()
-def create_lazybuffer(device:str, st:ShapeTracker, dtype:DType, op:Optional[Op]=None, arg:Any=None, srcs:Tuple[LazyBuffer, ...]=(),
+def create_lazybuffer(device:str, st:ShapeTracker, dtype:DTypeLike, op:Optional[Op]=None, arg:Any=None, srcs:Tuple[LazyBuffer, ...]=(),
                       base:Optional[LazyBuffer]=None, enable_cache=bool(getenv("LAZYCACHE", 1))):
   if st.size == 0: op, arg, srcs, base = MetaOps.CONST, 0, (), None
+  dtype = to_dtype(dtype)
   if op is MetaOps.CONST: arg, enable_cache = dtypes.as_const(arg, dtype) if not isinstance(arg, Variable) else arg, True
 
   cache_key = (device, st, dtype, op, arg, tuple(ref(x) for x in srcs)) if base is None else (st, ref(base))

@@ -23,10 +24,10 @@ def create_lazybuffer(device:str, st:ShapeTracker, dtype:DType, op:Optional[Op]=
 
 view_supported_devices = {"LLVM", "CLANG", "CUDA", "NV", "AMD", "METAL", "DISK"}
 class LazyBuffer:
-  def __init__(self, device:str, st:ShapeTracker, dtype:DType,
+  def __init__(self, device:str, st:ShapeTracker, dtype:DTypeLike,
                op:Optional[Op]=None, arg:Any=None, srcs:Tuple[LazyBuffer, ...]=(),
                base:Optional[LazyBuffer]=None, metadata:Optional[Metadata]=None):
-    self.device, self.st, self.dtype, self.shape, self.size, self.metadata = device, st, dtype, st.shape, st.size, metadata
+    self.device, self.st, self.dtype, self.shape, self.size, self.metadata = device, st, to_dtype(dtype), st.shape, st.size, metadata
     self._base: Optional[LazyBuffer] = None
     if base is None:
       # properties on base

@@ -35,9 +36,9 @@ class LazyBuffer:
 
       if self.op is MetaOps.VIEW:
         # some LazyBuffers can be processed with only a view, no AST required
-        self.buffer: Buffer = srcs[0].base.buffer.view(st.size, dtype, srcs[0].st.views[0].offset * srcs[0].dtype.itemsize)
+        self.buffer: Buffer = srcs[0].base.buffer.view(st.size, self.dtype, srcs[0].st.views[0].offset * srcs[0].dtype.itemsize)
       else:
-        self.buffer = srcs[1].base.buffer if self.op is MetaOps.ASSIGN else Buffer(device, self.size, dtype)
+        self.buffer = srcs[1].base.buffer if self.op is MetaOps.ASSIGN else Buffer(device, self.size, self.dtype)
       self.buffer.ref(1)
       self.contiguous_child: Optional[Tuple[ReferenceType[LazyBuffer], ShapeTracker]] = None
       self.forced_realize = False

@@ -66,7 +67,7 @@ class LazyBuffer:
   def lbs(self) -> List[LazyBuffer]: return [self]
 
   @staticmethod
-  def metaop(op, shape:Tuple[sint,...], dtype:DType, device:str, arg=None, src:Tuple[LazyBuffer, ...]=(), enable_cache=False) -> LazyBuffer:
+  def metaop(op, shape:Tuple[sint,...], dtype:DTypeLike, device:str, arg=None, src:Tuple[LazyBuffer, ...]=(), enable_cache=False) -> LazyBuffer:
     assert isinstance(src, tuple)
     return create_lazybuffer(device, ShapeTracker.from_shape(shape), dtype, op, arg, src, enable_cache=enable_cache)
 

@@ -7,7 +7,7 @@ from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Seque
 from collections import defaultdict
 import numpy as np
 
-from tinygrad.dtype import DType, dtypes, ImageDType, ConstType, least_upper_float, least_upper_dtype, sum_acc_dtype
+from tinygrad.dtype import DType, DTypeLike, dtypes, ImageDType, ConstType, least_upper_float, least_upper_dtype, sum_acc_dtype, to_dtype
 from tinygrad.helpers import argfix, make_pair, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, get_shape, fully_flatten, dedup
 from tinygrad.helpers import IMAGE, DEBUG, WINO, THREEFRY, _METADATA, Metadata, TRACEMETA
 from tinygrad.lazy import LazyBuffer

@@ -106,7 +106,8 @@ class Tensor:
   no_grad: ClassVar[bool] = False
 
   def __init__(self, data:Union[None, ConstType, List, Tuple, LazyBuffer, np.ndarray, bytes, MultiLazyBuffer, Variable],
-               device:Optional[Union[str, tuple, list]]=None, dtype:Optional[DType]=None, requires_grad:Optional[bool]=None):
+               device:Optional[Union[str, tuple, list]]=None, dtype:Optional[DTypeLike]=None, requires_grad:Optional[bool]=None):
+    if dtype is not None: dtype = to_dtype(dtype)
     assert dtype is None or isinstance(dtype, DType), f"invalid dtype {dtype}"
     device = tuple(Device.canonicalize(x) for x in device) if isinstance(device, (tuple, list)) else Device.canonicalize(device)
 

@@ -361,7 +362,7 @@ class Tensor:
   # ***** creation entrypoint *****
 
   @staticmethod
-  def _metaop(op, shape, device:Optional[Union[Tuple[str, ...], str]]=None, dtype:Optional[DType]=None, arg=None, **kwargs):
+  def _metaop(op, shape, device:Optional[Union[Tuple[str, ...], str]]=None, dtype:Optional[DTypeLike]=None, arg=None, **kwargs):
     if isinstance(device, tuple):
       return Tensor(MultiLazyBuffer([LazyBuffer.metaop(op, shape, dtype or dtypes.default_float, Device.canonicalize(d), arg) \
         for d in device], None), device, dtype, **kwargs)

@@ -403,7 +404,7 @@ class Tensor:
     Tensor._seed, Tensor._rng_counter = seed, Tensor([0], dtype=dtypes.uint32, requires_grad=False)
 
   @staticmethod
-  def rand(*shape, device:Optional[Union[Tuple[str, ...], str]]=None, dtype:Optional[DType]=None, **kwargs):
+  def rand(*shape, device:Optional[Union[Tuple[str, ...], str]]=None, dtype:Optional[DTypeLike]=None, **kwargs):
     """
     Creates a tensor with the given shape, filled with random values from a uniform distribution over the interval `[0, 1)`.
 

@@ -419,7 +420,7 @@ class Tensor:
     if Tensor._rng_counter is None: Tensor._rng_counter = Tensor([0], dtype=dtypes.uint32, requires_grad=False)
     if not THREEFRY.value:
       # for bfloat16, numpy rand passes buffer in float
-      if (dtype or dtypes.default_float) == dtypes.bfloat16:
+      if to_dtype(dtype or dtypes.default_float) == dtypes.bfloat16:
         return Tensor.rand(*shape, **kwargs, device=device, dtype=dtypes.float).cast(dtypes.bfloat16)
       return Tensor._metaop(MetaOps.CUSTOM, argfix(*shape), arg=custom_random, device=device, dtype=dtype, **kwargs)
 

@@ -588,7 +589,7 @@ class Tensor:
   # ***** rng hlops *****
 
   @staticmethod
-  def randn(*shape, dtype:Optional[DType]=None, **kwargs) -> Tensor:
+  def randn(*shape, dtype:Optional[DTypeLike]=None, **kwargs) -> Tensor:
     """
     Creates a tensor with the given shape, filled with random values from a normal distribution with mean `0` and standard deviation `1`.
     If `dtype` is not specified, the default type is used.

@@ -1308,7 +1309,7 @@ class Tensor:
     ret = fxn.apply(self, axis=axis_)
     return ret if keepdim else ret.reshape(tuple(s for i,s in enumerate(self.shape) if i not in axis_))
 
-  def sum(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False, acc_dtype:Optional[DType]=None):
+  def sum(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False, acc_dtype:Optional[DTypeLike]=None):
     """
     Sums the elements of the tensor along the specified axis or axes.
 

@@ -1630,7 +1631,7 @@ class Tensor:
     return (-self).argmax(axis=axis, keepdim=keepdim)
 
   @staticmethod
-  def einsum(formula:str, *raw_xs, acc_dtype:Optional[DType]=None) -> Tensor:
+  def einsum(formula:str, *raw_xs, acc_dtype:Optional[DTypeLike]=None) -> Tensor:
     """
     Sums the product of the elements of the input tensors according to a formula based on the Einstein summation convention.
 

@@ -1734,7 +1735,7 @@ class Tensor:
     padding_ = self._padding2d(padding, len(k_ := make_pair(kernel_size)))
     return self.pad2d(padding_, value=float('-inf'))._pool(k_, stride if stride is not None else k_, dilation).max(axis=tuple(range(-len(k_), 0)))
 
-  def conv2d(self, weight:Tensor, bias:Optional[Tensor]=None, groups=1, stride=1, dilation=1, padding=0, acc_dtype:Optional[DType]=None) -> Tensor:
+  def conv2d(self, weight:Tensor, bias:Tensor|None=None, groups=1, stride=1, dilation=1, padding=0, acc_dtype:DTypeLike|None=None) -> Tensor:
     """
     Applies a convolution over a tensor with a given `weight` and optional `bias`.
 

@@ -1821,7 +1822,7 @@ class Tensor:
     padding = flatten((((k-1)*d-p,(k-1)*d-p+op) for k,d,p,op in reversed(list(zip(HW, dilation, padding, output_padding)))))
     return x.conv2d(w.flatten(end_dim=1), groups=groups, bias=bias, dilation=dilation, padding=padding)
 
-  def dot(self, w:Tensor, acc_dtype:Optional[DType]=None) -> Tensor:
+  def dot(self, w:Tensor, acc_dtype:Optional[DTypeLike]=None) -> Tensor:
     """
     Performs dot product between two tensors.
 

@@ -1840,7 +1841,7 @@ class Tensor:
     w = w.reshape(*w.shape[0:-2], *[1]*min(n1-1, n2-1, 1), *w.shape[-min(n2, 2):]).transpose(-1, -min(n2, 2))
     return (x*w).sum(-1, acc_dtype=acc_dtype).cast(least_upper_dtype(x.dtype, w.dtype) if acc_dtype is None else acc_dtype)
 
-  def matmul(self, x:Tensor, reverse=False, acc_dtype:Optional[DType]=None) -> Tensor:
+  def matmul(self, x:Tensor, reverse=False, acc_dtype:Optional[DTypeLike]=None) -> Tensor:
     """
     Performs matrix multiplication between two tensors.
 

@@ -2960,12 +2961,12 @@ class Tensor:
 
   # ***** cast ops *****
 
-  def llvm_bf16_cast(self, dtype:DType):
+  def llvm_bf16_cast(self, dtype:DTypeLike):
     # hack for devices that don't support bfloat16
     assert self.dtype == dtypes.bfloat16
     return self.to("LLVM").bitcast(dtypes.uint16).cast(dtypes.uint32).mul(1<<16).bitcast(dtypes.float32).cast(dtype)
 
-  def cast(self, dtype:DType) -> Tensor:
+  def cast(self, dtype:DTypeLike) -> Tensor:
     """
     Casts `self` to the given `dtype`.
 

@@ -2978,9 +2979,9 @@ class Tensor:
     print(t.dtype, t.numpy())
     ```
     """
-    return self if self.dtype == dtype else F.Cast.apply(self, dtype=dtype)
+    return self if self.dtype == (dt:=to_dtype(dtype)) else F.Cast.apply(self, dtype=dt)
 
-  def bitcast(self, dtype:DType) -> Tensor:
+  def bitcast(self, dtype:DTypeLike) -> Tensor:
     """
     Bitcasts `self` to the given `dtype` of the same itemsize.
 

@@ -2996,7 +2997,7 @@ class Tensor:
     ```
     """
    if self.requires_grad: raise RuntimeError("can't backprop through bitcast")
-    return F.Cast.apply(self, dtype=dtype, bitcast=True) if self.dtype != dtype else self
+    return F.Cast.apply(self, dtype=dt, bitcast=True) if self.dtype != (dt:=to_dtype(dtype)) else self
 
   def float(self) -> Tensor:
     """
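Note on the `bitcast` return above: in a Python conditional expression the condition is evaluated before either branch, so `(dt:=to_dtype(dtype))` in the condition binds `dt` before the true branch `F.Cast.apply(self, dtype=dt, bitcast=True)` reads it.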