support using str to specify dtype (#5897)

* support using str to specify dtype

in Tensor creation, in the arguments to `cast` and `bitcast`, and in `acc_dtype`

* more tests
chenyu 2024-08-04 12:56:28 -04:00 committed by GitHub
parent 4f9221e8dd
commit c67e9887f7
4 changed files with 45 additions and 23 deletions
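
In short, a dtype can now be given as a string anywhere a `DType` was accepted. A minimal usage sketch (not part of the diff), mirroring the new test below:

```python
from tinygrad import Tensor, dtypes

t = Tensor([1.0, 2.0, 3.0], dtype="float32")       # same as dtype=dtypes.float32
assert t.cast("int8").dtype == dtypes.int8         # str accepted by cast
assert t.bitcast("uint32").dtype == dtypes.uint32  # and by bitcast (itemsize must still match)
t.sum(acc_dtype="int16")                           # and by acc_dtype in reductions
```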

@@ -395,6 +395,23 @@ class TestTypeSpec(unittest.TestCase):
subprocess.run(['DEFAULT_FLOAT=TYPO python3 -c "from tinygrad import dtypes"'],
shell=True, check=True)
def test_dtype_str_arg(self):
  n = np.random.normal(0, 1, (10, 10)).astype(np.float32)
  tested = 0
  for dtype_str, dtype in [
    ("bool", dtypes.bool), ("int8", dtypes.int8), ("int", dtypes.int), ("uint32", dtypes.uint32), ("float32", dtypes.float32)]:
    np.testing.assert_equal(Tensor(n, dtype=dtype_str).numpy(), Tensor(n, dtype=dtype).numpy())
    np.testing.assert_equal(Tensor(n).cast(dtype_str).numpy(), Tensor(n).cast(dtype).numpy())
    if dtype.itemsize == 4:
      np.testing.assert_equal(Tensor(n).bitcast(dtype_str).numpy(), Tensor(n).bitcast(dtype).numpy())
      tested += 1
  assert tested == 3
  with self.assertRaises(AttributeError): Tensor([1, 2, 3], dtype="nonexistdtype")
  with self.assertRaises(AttributeError): Tensor([1, 2, 3], dtype="")
  np.testing.assert_equal(Tensor(n).sum(acc_dtype="int16").numpy(), Tensor(n).sum(acc_dtype=dtypes.int16).numpy())
@given(strat.sampled_from(dtype_ints), strat.sampled_from(dtype_floats))
def test_creation(self, default_int, default_float):
dtypes.default_int, dtypes.default_float = default_int, default_float

@@ -96,6 +96,9 @@ if (env_default_float := getenv("DEFAULT_FLOAT", "")):
dtypes.default_float = getattr(dtypes, env_default_float.lower())
assert dtypes.is_float(dtypes.default_float), f"{env_default_float} is not a float dtype"
DTypeLike = Union[str, DType]
def to_dtype(dtype:DTypeLike) -> DType: return dtype if isinstance(dtype, DType) else getattr(dtypes, dtype)
# https://jax.readthedocs.io/en/latest/jep/9407-type-promotion.html
# we don't support weak type and complex type
promo_lattice = { dtypes.bool: [dtypes.int8, dtypes.uint8], dtypes.int8: [dtypes.int16], dtypes.int16: [dtypes.int32], dtypes.int32: [dtypes.int64],
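
The string handling funnels through the new `to_dtype` helper above: a `DType` passes through unchanged, and a string is resolved via `getattr(dtypes, ...)`, so an unknown name raises `AttributeError` (which the new test checks). A small sketch of that behavior:

```python
from tinygrad.dtype import dtypes, to_dtype

assert to_dtype(dtypes.float32) is dtypes.float32  # DType passes through unchanged
assert to_dtype("float32") is dtypes.float32       # str is looked up on dtypes
try:
  to_dtype("nonexistdtype")
except AttributeError:
  pass  # unknown names raise, as exercised by the new test
```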

@@ -1,6 +1,6 @@
from __future__ import annotations
from typing import Union, Optional, Any, Tuple, List, get_args
from tinygrad.dtype import dtypes, DType, ConstType
from tinygrad.dtype import dtypes, DType, DTypeLike, ConstType, to_dtype
from tinygrad.helpers import prod, getenv, all_int, all_same, DEBUG, _METADATA, Metadata
from tinygrad.ops import MetaOps, UnaryOps, BinaryOps, TernaryOps, ReduceOps, Op, exec_alu, python_alu, reduce_st
from tinygrad.shape.symbolic import sint, Variable
@@ -9,9 +9,10 @@ from tinygrad.device import Buffer
from weakref import ref, ReferenceType, WeakValueDictionary
lazycache: WeakValueDictionary[Any, LazyBuffer] = WeakValueDictionary()
def create_lazybuffer(device:str, st:ShapeTracker, dtype:DType, op:Optional[Op]=None, arg:Any=None, srcs:Tuple[LazyBuffer, ...]=(),
def create_lazybuffer(device:str, st:ShapeTracker, dtype:DTypeLike, op:Optional[Op]=None, arg:Any=None, srcs:Tuple[LazyBuffer, ...]=(),
base:Optional[LazyBuffer]=None, enable_cache=bool(getenv("LAZYCACHE", 1))):
if st.size == 0: op, arg, srcs, base = MetaOps.CONST, 0, (), None
dtype = to_dtype(dtype)
if op is MetaOps.CONST: arg, enable_cache = dtypes.as_const(arg, dtype) if not isinstance(arg, Variable) else arg, True
cache_key = (device, st, dtype, op, arg, tuple(ref(x) for x in srcs)) if base is None else (st, ref(base))
@@ -23,10 +24,10 @@ def create_lazybuffer(device:str, st:ShapeTracker, dtype:DType, op:Optional[Op]=
view_supported_devices = {"LLVM", "CLANG", "CUDA", "NV", "AMD", "METAL", "DISK"}
class LazyBuffer:
def __init__(self, device:str, st:ShapeTracker, dtype:DType,
def __init__(self, device:str, st:ShapeTracker, dtype:DTypeLike,
op:Optional[Op]=None, arg:Any=None, srcs:Tuple[LazyBuffer, ...]=(),
base:Optional[LazyBuffer]=None, metadata:Optional[Metadata]=None):
self.device, self.st, self.dtype, self.shape, self.size, self.metadata = device, st, dtype, st.shape, st.size, metadata
self.device, self.st, self.dtype, self.shape, self.size, self.metadata = device, st, to_dtype(dtype), st.shape, st.size, metadata
self._base: Optional[LazyBuffer] = None
if base is None:
# properties on base
@@ -35,9 +36,9 @@ class LazyBuffer:
if self.op is MetaOps.VIEW:
# some LazyBuffers can be processed with only a view, no AST required
self.buffer: Buffer = srcs[0].base.buffer.view(st.size, dtype, srcs[0].st.views[0].offset * srcs[0].dtype.itemsize)
self.buffer: Buffer = srcs[0].base.buffer.view(st.size, self.dtype, srcs[0].st.views[0].offset * srcs[0].dtype.itemsize)
else:
self.buffer = srcs[1].base.buffer if self.op is MetaOps.ASSIGN else Buffer(device, self.size, dtype)
self.buffer = srcs[1].base.buffer if self.op is MetaOps.ASSIGN else Buffer(device, self.size, self.dtype)
self.buffer.ref(1)
self.contiguous_child: Optional[Tuple[ReferenceType[LazyBuffer], ShapeTracker]] = None
self.forced_realize = False
@@ -66,7 +67,7 @@ class LazyBuffer:
def lbs(self) -> List[LazyBuffer]: return [self]
@staticmethod
def metaop(op, shape:Tuple[sint,...], dtype:DType, device:str, arg=None, src:Tuple[LazyBuffer, ...]=(), enable_cache=False) -> LazyBuffer:
def metaop(op, shape:Tuple[sint,...], dtype:DTypeLike, device:str, arg=None, src:Tuple[LazyBuffer, ...]=(), enable_cache=False) -> LazyBuffer:
assert isinstance(src, tuple)
return create_lazybuffer(device, ShapeTracker.from_shape(shape), dtype, op, arg, src, enable_cache=enable_cache)

@@ -7,7 +7,7 @@ from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Seque
from collections import defaultdict
import numpy as np
from tinygrad.dtype import DType, dtypes, ImageDType, ConstType, least_upper_float, least_upper_dtype, sum_acc_dtype
from tinygrad.dtype import DType, DTypeLike, dtypes, ImageDType, ConstType, least_upper_float, least_upper_dtype, sum_acc_dtype, to_dtype
from tinygrad.helpers import argfix, make_pair, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, get_shape, fully_flatten, dedup
from tinygrad.helpers import IMAGE, DEBUG, WINO, THREEFRY, _METADATA, Metadata, TRACEMETA
from tinygrad.lazy import LazyBuffer
@@ -106,7 +106,8 @@ class Tensor:
no_grad: ClassVar[bool] = False
def __init__(self, data:Union[None, ConstType, List, Tuple, LazyBuffer, np.ndarray, bytes, MultiLazyBuffer, Variable],
device:Optional[Union[str, tuple, list]]=None, dtype:Optional[DType]=None, requires_grad:Optional[bool]=None):
device:Optional[Union[str, tuple, list]]=None, dtype:Optional[DTypeLike]=None, requires_grad:Optional[bool]=None):
if dtype is not None: dtype = to_dtype(dtype)
assert dtype is None or isinstance(dtype, DType), f"invalid dtype {dtype}"
device = tuple(Device.canonicalize(x) for x in device) if isinstance(device, (tuple, list)) else Device.canonicalize(device)
@@ -361,7 +362,7 @@ class Tensor:
# ***** creation entrypoint *****
@staticmethod
def _metaop(op, shape, device:Optional[Union[Tuple[str, ...], str]]=None, dtype:Optional[DType]=None, arg=None, **kwargs):
def _metaop(op, shape, device:Optional[Union[Tuple[str, ...], str]]=None, dtype:Optional[DTypeLike]=None, arg=None, **kwargs):
if isinstance(device, tuple):
return Tensor(MultiLazyBuffer([LazyBuffer.metaop(op, shape, dtype or dtypes.default_float, Device.canonicalize(d), arg) \
for d in device], None), device, dtype, **kwargs)
@@ -403,7 +404,7 @@ class Tensor:
Tensor._seed, Tensor._rng_counter = seed, Tensor([0], dtype=dtypes.uint32, requires_grad=False)
@staticmethod
def rand(*shape, device:Optional[Union[Tuple[str, ...], str]]=None, dtype:Optional[DType]=None, **kwargs):
def rand(*shape, device:Optional[Union[Tuple[str, ...], str]]=None, dtype:Optional[DTypeLike]=None, **kwargs):
"""
Creates a tensor with the given shape, filled with random values from a uniform distribution over the interval `[0, 1)`.
@@ -419,7 +420,7 @@ class Tensor:
if Tensor._rng_counter is None: Tensor._rng_counter = Tensor([0], dtype=dtypes.uint32, requires_grad=False)
if not THREEFRY.value:
# for bfloat16, numpy rand passes buffer in float
if (dtype or dtypes.default_float) == dtypes.bfloat16:
if to_dtype(dtype or dtypes.default_float) == dtypes.bfloat16:
return Tensor.rand(*shape, **kwargs, device=device, dtype=dtypes.float).cast(dtypes.bfloat16)
return Tensor._metaop(MetaOps.CUSTOM, argfix(*shape), arg=custom_random, device=device, dtype=dtype, **kwargs)
@@ -588,7 +589,7 @@ class Tensor:
# ***** rng hlops *****
@staticmethod
def randn(*shape, dtype:Optional[DType]=None, **kwargs) -> Tensor:
def randn(*shape, dtype:Optional[DTypeLike]=None, **kwargs) -> Tensor:
"""
Creates a tensor with the given shape, filled with random values from a normal distribution with mean `0` and standard deviation `1`.
If `dtype` is not specified, the default type is used.
@@ -1308,7 +1309,7 @@ class Tensor:
ret = fxn.apply(self, axis=axis_)
return ret if keepdim else ret.reshape(tuple(s for i,s in enumerate(self.shape) if i not in axis_))
def sum(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False, acc_dtype:Optional[DType]=None):
def sum(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False, acc_dtype:Optional[DTypeLike]=None):
"""
Sums the elements of the tensor along the specified axis or axes.
@@ -1630,7 +1631,7 @@ class Tensor:
return (-self).argmax(axis=axis, keepdim=keepdim)
@staticmethod
def einsum(formula:str, *raw_xs, acc_dtype:Optional[DType]=None) -> Tensor:
def einsum(formula:str, *raw_xs, acc_dtype:Optional[DTypeLike]=None) -> Tensor:
"""
Sums the product of the elements of the input tensors according to a formula based on the Einstein summation convention.
@@ -1734,7 +1735,7 @@ class Tensor:
padding_ = self._padding2d(padding, len(k_ := make_pair(kernel_size)))
return self.pad2d(padding_, value=float('-inf'))._pool(k_, stride if stride is not None else k_, dilation).max(axis=tuple(range(-len(k_), 0)))
def conv2d(self, weight:Tensor, bias:Optional[Tensor]=None, groups=1, stride=1, dilation=1, padding=0, acc_dtype:Optional[DType]=None) -> Tensor:
def conv2d(self, weight:Tensor, bias:Tensor|None=None, groups=1, stride=1, dilation=1, padding=0, acc_dtype:DTypeLike|None=None) -> Tensor:
"""
Applies a convolution over a tensor with a given `weight` and optional `bias`.
@@ -1821,7 +1822,7 @@ class Tensor:
padding = flatten((((k-1)*d-p,(k-1)*d-p+op) for k,d,p,op in reversed(list(zip(HW, dilation, padding, output_padding)))))
return x.conv2d(w.flatten(end_dim=1), groups=groups, bias=bias, dilation=dilation, padding=padding)
def dot(self, w:Tensor, acc_dtype:Optional[DType]=None) -> Tensor:
def dot(self, w:Tensor, acc_dtype:Optional[DTypeLike]=None) -> Tensor:
"""
Performs dot product between two tensors.
@@ -1840,7 +1841,7 @@ class Tensor:
w = w.reshape(*w.shape[0:-2], *[1]*min(n1-1, n2-1, 1), *w.shape[-min(n2, 2):]).transpose(-1, -min(n2, 2))
return (x*w).sum(-1, acc_dtype=acc_dtype).cast(least_upper_dtype(x.dtype, w.dtype) if acc_dtype is None else acc_dtype)
def matmul(self, x:Tensor, reverse=False, acc_dtype:Optional[DType]=None) -> Tensor:
def matmul(self, x:Tensor, reverse=False, acc_dtype:Optional[DTypeLike]=None) -> Tensor:
"""
Performs matrix multiplication between two tensors.
@@ -2960,12 +2961,12 @@ class Tensor:
# ***** cast ops *****
def llvm_bf16_cast(self, dtype:DType):
def llvm_bf16_cast(self, dtype:DTypeLike):
# hack for devices that don't support bfloat16
assert self.dtype == dtypes.bfloat16
return self.to("LLVM").bitcast(dtypes.uint16).cast(dtypes.uint32).mul(1<<16).bitcast(dtypes.float32).cast(dtype)
def cast(self, dtype:DType) -> Tensor:
def cast(self, dtype:DTypeLike) -> Tensor:
"""
Casts `self` to the given `dtype`.
@@ -2978,9 +2979,9 @@ class Tensor:
print(t.dtype, t.numpy())
```
"""
return self if self.dtype == dtype else F.Cast.apply(self, dtype=dtype)
return self if self.dtype == (dt:=to_dtype(dtype)) else F.Cast.apply(self, dtype=dt)
def bitcast(self, dtype:DType) -> Tensor:
def bitcast(self, dtype:DTypeLike) -> Tensor:
"""
Bitcasts `self` to the given `dtype` of the same itemsize.
@@ -2996,7 +2997,7 @@ class Tensor:
```
"""
if self.requires_grad: raise RuntimeError("can't backprop through bitcast")
return F.Cast.apply(self, dtype=dtype, bitcast=True) if self.dtype != dtype else self
return F.Cast.apply(self, dtype=dt, bitcast=True) if self.dtype != (dt:=to_dtype(dtype)) else self
def float(self) -> Tensor:
"""