mirror of https://github.com/commaai/tinygrad.git
use at least int32 and uint32 for sum output (#2926)

* use at least int32 and uint32 for sum output
* use the correct type for acc
* fix opencl
* llvm mulacc

commit b55b55d56e (parent d424babe2c)
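Motivation: accumulating a sum in the input's own narrow dtype overflows almost immediately (a uint8 accumulator wraps past 255), so bool and integer inputs now accumulate in at least 32 bits, while floats keep their dtype. A standalone sketch of the resulting rule, mirroring the new Tensor.sum at the bottom of this diff (string dtype names and the helper here are illustrative, not tinygrad API):

  # Standalone sketch of the new sum output-dtype rule; the real code uses
  # least_upper_dtype. Plain-string dtype names are for illustration only.
  INT_RANK  = ["int8", "int16", "int32", "int64"]
  UINT_RANK = ["uint8", "uint16", "uint32", "uint64"]

  def sum_output_dtype(dt: str) -> str:
    if dt == "bool": return "int32"  # summing bools counts them
    if dt in UINT_RANK: return UINT_RANK[max(UINT_RANK.index(dt), UINT_RANK.index("uint32"))]
    if dt in INT_RANK:  return INT_RANK[max(INT_RANK.index(dt), INT_RANK.index("int32"))]
    return dt  # floats are left alone

  assert sum_output_dtype("uint8")   == "uint32"
  assert sum_output_dtype("int64")   == "int64"
  assert sum_output_dtype("float16") == "float16"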
@@ -4,7 +4,7 @@ from tinygrad.helpers import getenv
 # https://github.com/facebookresearch/llama/blob/1076b9c51c77ad06e9d7ba8a4c6df775741732bd/llama/model.py#L47
 def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> Tensor:
-  freqs = 1.0 / (theta ** (Tensor.arange(0, dim, 2)[:(dim // 2)] / dim))
+  freqs = 1.0 / (theta ** (Tensor.arange(0, dim, 2, dtype=dtypes.default_float)[:(dim // 2)] / dim))
   freqs = Tensor.arange(end).unsqueeze(dim=1)*freqs.unsqueeze(dim=0)
   return Tensor.stack([Tensor.cos(freqs), Tensor.sin(freqs)], dim=-1).reshape(1, end, 1, dim//2, 2)
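Why this hunk: with the arange change at the bottom of this commit, Tensor.arange with all-integer arguments returns dtypes.default_int, so the rotary-frequency computation must request dtypes.default_float explicitly to stay in floating point.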
@@ -77,18 +77,18 @@ class TestDType(unittest.TestCase):
   def test_same_size_ops(self):
     list(map(
-      lambda dtype: _test_ops(a_dtype=self.DTYPE, b_dtype=dtype,
-                              target_dtype=least_upper_dtype(self.DTYPE, dtype)) if dtype.itemsize == self.DTYPE.itemsize else None,
+      lambda dtype: _test_ops(a_dtype=self.DTYPE, b_dtype=dtype) if dtype.itemsize == self.DTYPE.itemsize else None,
       get_available_cast_dtypes(self.DTYPE)
     ))
-  def test_upcast_ops(self): list(map(
-    lambda dtype: _test_ops(a_dtype=self.DTYPE, b_dtype=dtype) if dtype.itemsize > self.DTYPE.itemsize else None,
-    get_available_cast_dtypes(self.DTYPE)
+  def test_upcast_ops(self):
+    list(map(
+      lambda dtype: _test_ops(a_dtype=self.DTYPE, b_dtype=dtype) if dtype.itemsize > self.DTYPE.itemsize else None,
+      get_available_cast_dtypes(self.DTYPE)
     ))
   def test_upcast_to_ops(self):
     list(map(
-    lambda dtype: _test_ops(a_dtype=dtype, b_dtype=self.DTYPE) if dtype.itemsize < self.DTYPE.itemsize else None,
-    get_available_cast_dtypes(self.DTYPE)
+      lambda dtype: _test_ops(a_dtype=dtype, b_dtype=self.DTYPE) if dtype.itemsize < self.DTYPE.itemsize else None,
+      get_available_cast_dtypes(self.DTYPE)
     ))
   def test_bitcast(self):
     if self.DTYPE == dtypes.bool: raise unittest.SkipTest("no bools in bitcast")
@@ -305,16 +305,16 @@ class TestTypeSpec(unittest.TestCase):
     assert Tensor.arange(3, 9, 0.7).dtype == dtypes.default_float
     assert Tensor.arange(3, 8.5, 3).dtype == dtypes.default_float

-core_types = list(DTYPES_DICT.values())
-floats = [dt for dt in core_types if dtypes.is_float(dt)]
+core_dtypes = list(DTYPES_DICT.values())
+floats = [dt for dt in core_dtypes if dtypes.is_float(dt)]
 class TestTypePromotion(unittest.TestCase):
-  @given(st.sampled_from(core_types))
+  @given(st.sampled_from(core_dtypes))
   def test_self_promo_to_self(self, dtype):
     assert least_upper_dtype(dtype) == dtype
     assert least_upper_dtype(dtype, dtype) == dtype
     assert least_upper_dtype(dtype, dtype, dtype) == dtype

-  @given(st.sampled_from(core_types), st.sampled_from(core_types))
+  @given(st.sampled_from(core_dtypes), st.sampled_from(core_dtypes))
   def test_promo_resulted_higher_than_inputs(self, dtype1, dtype2):
     result = least_upper_dtype(dtype1, dtype2)
     assert result >= dtype1 and result >= dtype2
@@ -400,5 +400,24 @@ class TestAutoCastType(unittest.TestCase):
     assert (Tensor([0, 1], dtype=dtypes.float32) + True).dtype == dtypes.float32
     assert (Tensor([0, 1], dtype=dtypes.float64) + True).dtype == dtypes.float64

+  def test_sum(self):
+    assert (Tensor([0, 1], dtype=dtypes.bool)).sum().dtype == dtypes.int32
+    assert (Tensor([0, 1], dtype=dtypes.int8)).sum().dtype == dtypes.int32
+    assert (Tensor([0, 1], dtype=dtypes.int16)).sum().dtype == dtypes.int32
+    assert (Tensor([0, 1], dtype=dtypes.int32)).sum().dtype == dtypes.int32
+    assert (Tensor([0, 1], dtype=dtypes.int64)).sum().dtype == dtypes.int64
+    assert (Tensor([0, 1], dtype=dtypes.uint8)).sum().dtype == dtypes.uint32
+    assert (Tensor([0, 1], dtype=dtypes.uint16)).sum().dtype == dtypes.uint32
+    assert (Tensor([0, 1], dtype=dtypes.uint32)).sum().dtype == dtypes.uint32
+    assert (Tensor([0, 1], dtype=dtypes.uint64)).sum().dtype == dtypes.uint64
+    assert (Tensor([0, 1], dtype=dtypes.float16)).sum().dtype == dtypes.float16
+    assert (Tensor([0, 1], dtype=dtypes.bfloat16)).sum().dtype == dtypes.bfloat16
+    assert (Tensor([0, 1], dtype=dtypes.float32)).sum().dtype == dtypes.float32
+    assert (Tensor([0, 1], dtype=dtypes.float64)).sum().dtype == dtypes.float64
+
+  @given(st.sampled_from(core_dtypes), st.sampled_from(core_dtypes))
+  def test_matmul(self, dt1, dt2):
+    assert (Tensor([0, 1], dtype=dt1) @ Tensor([0, 1], dtype=dt2)).dtype == least_upper_dtype(dt1, dt2)
+
 if __name__ == '__main__':
   unittest.main()
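The new test_sum spells out the whole promotion table. A quick interactive check, a sketch assuming tinygrad at this commit (where dtypes is exported from tinygrad.helpers, as the linearizer imports below confirm):

  from tinygrad.tensor import Tensor
  from tinygrad.helpers import dtypes

  print(Tensor([0, 1], dtype=dtypes.uint8).sum().dtype)    # dtypes.uint32
  print(Tensor([0, 1], dtype=dtypes.int64).sum().dtype)    # dtypes.int64
  print(Tensor([0, 1], dtype=dtypes.float16).sum().dtype)  # dtypes.float16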
@@ -6,7 +6,7 @@ from enum import Enum, auto
 from dataclasses import dataclass

 from tinygrad.helpers import colored, ImageDType, DEBUG, dtypes, DType, prod, PtrDType, getenv, all_same, to_function_name, flatten
-from tinygrad.ops import LazyOp, UnaryOps, BinaryOps, TernaryOps, ReduceOps, ConstBuffer, MemBuffer, BufferOps, vars_from_ast
+from tinygrad.ops import LazyOp, UnaryOps, BinaryOps, TernaryOps, ReduceOps, ConstBuffer, MemBuffer, BufferOps, vars_from_ast, get_lazyop_info
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.shape.symbolic import Variable, NumNode, VariableOrNum, Node, SumNode, MulNode, DivNode, ModNode, LtNode, AndNode
 from tinygrad.codegen.kernel import LocalBuffer, Kernel
@@ -50,9 +50,12 @@ class Linearizer(Kernel):
   def cast(self, val: UOp, dtype) -> UOp: return self.uop(UOps.CAST, dtype, (val,)) if val.dtype != dtype else val

-  def get_reduce_acc(self, op, dtype:DType):
-    if op == ReduceOps.SUM: return 0.0 if dtypes.is_float(dtype) else 0
-    elif op == ReduceOps.MAX: return -math.inf if dtypes.is_float(dtype) else -2**31 if dtypes.is_int(dtype) else False
+  def get_reduce_acc(self, reduceop:LazyOp):
+    dtype = get_lazyop_info(reduceop).dtype
+    if reduceop.op == ReduceOps.SUM: return 0.0 if dtypes.is_float(dtype) else 0
+    elif reduceop.op == ReduceOps.MAX:
+      if dtypes.is_int(dtype): return 0 if dtypes.is_unsigned(dtype) else -2**(dtype.itemsize*8-1)
+      return -math.inf if dtypes.is_float(dtype) else False

   # NOTE: once images are loaded, we uop them as their base float
   def get_base_dtype(self, dt:DType): return dt.base if isinstance(dt, ImageDType) else dt
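The old MAX branch hard-coded -2**31 as the identity for every integer dtype, which is wrong for anything other than int32 (and for all unsigned types). The new code derives the identity from the dtype itself; a standalone sketch mirroring that logic:

  # Mirrors the new get_reduce_acc MAX branch: unsigned ints start at 0,
  # signed ints at their minimum representable value, floats at -inf
  # (and bools would start at False).
  def max_identity(is_int: bool, is_unsigned: bool, itemsize: int):
    if is_int: return 0 if is_unsigned else -2**(itemsize*8-1)
    return -float("inf")

  assert max_identity(True, False, 1) == -128            # int8
  assert max_identity(True, False, 8) == -(2**63)        # int64
  assert max_identity(True, True, 4) == 0                # uint32
  assert max_identity(False, False, 4) == -float("inf")  # float32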
@@ -69,7 +72,7 @@ class Linearizer(Kernel):
   def global_load(self, i:int, idxs:Sequence[Node], acc=None, barrier:Optional[UOp]=None) -> List[UOp]:
     buf = self.bufs[i]
-    localtype = self.get_base_dtype(buf.dtype)
+    localtype = self.get_base_dtype(buf.dtype if acc is None else get_lazyop_info(self.reduceop).dtype)
     const = buf.val if isinstance(buf, ConstBuffer) else acc

     def rename_var(v: VariableOrNum, expr: str): return v if isinstance(v, NumNode) else Variable(expr, v.min, v.max)
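Why this hunk: when global_load is materializing an accumulator (acc is not None), the load's dtype now comes from the reduce op's output dtype rather than from the buffer being read, so e.g. a uint8 buffer summed into uint32 gets a genuine uint32 accumulator.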
@@ -199,8 +202,9 @@ class Linearizer(Kernel):
     if self.group_for_reduce:
       # TODO: the strides of this can be controlled
       self.sts.append(ShapeTracker.from_shape(tuple([1] * self.global_dims + list(self.full_shape[self.global_dims:self.global_dims+self.local_dims+len(self.group_for_reduce)]) + [1] * (self.shape_len - self.upcasted - len(self.group_for_reduce) - self.first_reduce) + [x[0] for x in self.upcasted_axis(0)])))  # noqa: E501
-      self.bufs.append(LocalBuffer("temp", self.sts[-1].size()))
-      self.buf_uops.append(self.uop(UOps.DEFINE_LOCAL, PtrDType(dtypes.float32), (), ("temp", self.sts[-1].size())))
+      temp_dtype = self.get_base_dtype(get_lazyop_info(self.reduceop).dtype)
+      self.bufs.append(LocalBuffer("temp", self.sts[-1].size(), temp_dtype))
+      self.buf_uops.append(self.uop(UOps.DEFINE_LOCAL, PtrDType(temp_dtype), (), ("temp", self.sts[-1].size())))

     # kernel name (before late upcast)
     self.name = ("r_" if self.reduceop else "E_") + colored('_', 'BLACK').join([colored(str(x), c) for x,c in zip(self.full_shape, self.colors())])
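Same idea for grouped reduces: the "temp" LocalBuffer holding partial results was hard-coded to float32; it now carries the reduce's output dtype, so integer reductions no longer round-trip through float in local memory.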
@@ -250,7 +254,7 @@ class Linearizer(Kernel):
     fake_reduce_idxs = [x*0 for x in reduce_idxs]

     # define accumulator
-    acc = self.global_load(0, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, self.get_reduce_acc(self.reduceop.op, self.bufs[0].dtype))
+    acc = self.global_load(0, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, self.get_reduce_acc(self.reduceop))

     if self.tensor_core:
       def calc_tc_idxs(local_size: int, aliases: List[List[int]]):
@@ -356,7 +360,7 @@ class Linearizer(Kernel):
       # NOTE: this structure is the same as the reduce op above

       # define late accumulator
-      acc = self.global_load(-1, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, self.get_reduce_acc(self.reduceop.op, self.bufs[-1].dtype))  # noqa: E501
+      acc = self.global_load(-1, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, self.get_reduce_acc(self.reduceop))

       # late reduce loop
       loop_ctx = render_loop(end_local_idxs)
@@ -482,12 +486,6 @@ class Linearizer(Kernel):
       if arg == BinaryOps.CMPLT: assert dtype == dtypes.bool, f"{arg} output dtype mismatch {dtype=} != {dtypes.bool}"
       if arg == TernaryOps.WHERE: assert vin[0].dtype == dtypes.bool, f"{arg} selector dtype mismatch {vin[0].dtype=} != {dtypes.bool}"

-    if uop == UOps.PHI and vin[1].dtype != dtype: vin = (vin[0], self.cast(vin[1], dtype)) + vin[1:]
-    if uop == UOps.ALU: # upcast vins to the same dtype
-      upcast_dtype = dtypes.float if arg == TernaryOps.MULACC else max(cast(DType, x.dtype) for x in vin)  # MULACC is only supported in float
-      if arg == TernaryOps.WHERE: vin = (vin[0],) + tuple(self.cast(x, upcast_dtype) for x in vin[1:])  # the first arg is always bool
-      else: vin = tuple(self.cast(x, upcast_dtype) for x in vin)
-      dtype = dtype or upcast_dtype  # some ops like BinaryOps.CMPLT return bool
     if simplify:
       if uop == UOps.PHI and len(vin) == 2: return vin[1]  # a phi without loops is a noop
       if uop == UOps.GEP and vin[0].uop == UOps.CONST: return self.const(vin[0].arg, dtype, insert_before)
@@ -535,12 +533,12 @@ class Linearizer(Kernel):
       ret: List[UOp] = []
       input_acc = acc[:]
       for val, off in zip(zip(*values), cast(List[int], offs)):
-        acc[off] = self.uop(UOps.ALU, vin=val+(acc[off],), arg=ops[x.op])
+        acc[off] = self.uop(UOps.ALU, acc[off].dtype, vin=val+(acc[off],), arg=ops[x.op])
         ret.append(acc[off])
       for off in range(len(acc)):
         if input_acc[off] != acc[off]:
           acc[off] = self.uop(UOps.PHI, input_acc[off].dtype, (input_acc[off], acc[off]) + tuple(loop_ctx))
     else:
-      ret = [self.uop(UOps.ALU, dtype=dtypes.bool if x.op == BinaryOps.CMPLT else None, vin=val, arg=x.op) for val in zip(*values)]
+      ret = [self.uop(UOps.ALU, dtype=dtypes.bool if x.op == BinaryOps.CMPLT else val[-1].dtype, vin=val, arg=x.op) for val in zip(*values)]
     cache[x] = ret
     return ret
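These two uop hunks replace the old blanket rule "MULACC is float-only; upcast every ALU input to a common dtype" with explicit dtypes at the call sites: an accumulating ALU uop inherits its accumulator's dtype, and a plain ALU uop takes its last input's dtype (CMPLT stays bool). Dropping the float-only MULACC rule is what obliges the OpenCL and LLVM backends below to handle integer MULACC themselves.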
@@ -223,7 +223,8 @@ class OpenCLLanguage(CStyleLanguage):
   xid = [f'get_global_id({i})' for i in range(3)]
   uses_vload = True
   # NOTE: mad is used so the loads aren't reordered into the math on 845
-  code_for_op = {**CStyleLanguage().code_for_op, TernaryOps.MULACC: lambda a,b,c,dtype: f"mad({a},{b},{c})"}
+  code_for_op = {**CStyleLanguage().code_for_op,
+                 TernaryOps.MULACC: lambda a,b,c,dtype: f"mad({a},{b},{c})" if dtypes.is_float(dtype) else f"(({a}*{b})+{c})"}
   type_map = { dtypes.uint8: "uchar", dtypes.uint32: "uint", dtypes.uint16: "ushort", dtypes.uint64: "ulong" }
   def render_cast(self, x, var_dtype, bitcast=False) -> str:
     return f"as_{self.type_map.get(var_dtype) or var_dtype.name}({x[0]})" if bitcast else super().render_cast(x, var_dtype)
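OpenCL's mad() is defined only for floating-point types, so once integer MULACC can reach this renderer it must fall back to an explicit multiply-add. A minimal sketch of the rendering rule (a stand-in function, not the real lambda):

  def render_mulacc(a: str, b: str, c: str, is_float: bool) -> str:
    # mad() keeps the loads from being reordered into the math on the 845
    # (see the NOTE above), but it only exists for floats; ints get (a*b)+c.
    return f"mad({a},{b},{c})" if is_float else f"(({a}*{b})+{c})"

  assert render_mulacc("x", "y", "acc", True)  == "mad(x,y,acc)"
  assert render_mulacc("x", "y", "acc", False) == "((x*y)+acc)"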
@@ -8,6 +8,11 @@ MFLAGS = ('nsz', 'arcp', 'contract', 'afn', 'reassoc')  # All from fast math, but
 def is_bool_or_unsigned(dtype: DType): return dtype == dtypes.bool or dtypes.is_unsigned(dtype)

+def mulacc(bb, x, y, z, input_dtype, output_dtype):
+  if dtypes.is_float(output_dtype): return bb[-1].fadd(bb[-1].fmul(x, y, flags=MFLAGS), z, flags=MFLAGS)
+  # no fast math flags for int add and mul
+  return bb[-1].add(cast(bb, bb[-1].mul(x, y), input_dtype, output_dtype), z)
+
 code_for_op: Final[Dict[Op, Callable]] = {
   UnaryOps.NEG: lambda builder, x, var_dtype: builder.xor(x, ir.Constant(ir.IntType(1), 1)) if var_dtype == dtypes.bool else builder.neg(x) if dtypes.is_int(var_dtype) else builder.fneg(x, flags=MFLAGS),  # noqa: E501
   UnaryOps.EXP2: lambda builder, x, var_dtype: builder.call(builder._block.module.declare_intrinsic('llvm.exp2', [x.type]), [x], fastmath=MFLAGS),
@@ -25,7 +30,6 @@ code_for_op: Final[Dict[Op, Callable]] = {
   BinaryOps.MOD: lambda builder, x, y, var_dtype:
     builder.urem(x, y) if is_bool_or_unsigned(var_dtype) else builder.srem(x, y) if dtypes.is_int(var_dtype) else builder.frem(x, y),
   BinaryOps.XOR: lambda builder, x, y, var_dtype: builder.xor(x, y),
-  TernaryOps.MULACC: lambda builder, x, y, z, var_dtype: builder.fadd(builder.fmul(x, y, flags=MFLAGS), z, flags=MFLAGS),
   TernaryOps.WHERE: lambda builder, x, y, z, var_dtype: builder.select(x, y, z),
 }
@@ -156,7 +160,9 @@ def uops_to_llvm_ir(function_name:str, uops:List[UOp]) -> Tuple[str, Dict]:
       if len(vin) > 3:
         with bb[-1].if_then(bb[-1].trunc(lvars[vin[3]], ir.IntType(1))): store_op()
       else: store_op()
-    if uop == UOps.ALU: lvars[u] = code_for_op[args](bb[-1], *[lvars[x] for x in vin] + [dtype if args != BinaryOps.CMPLT else vin[0].dtype])
+    if uop == UOps.ALU:
+      if args == TernaryOps.MULACC: lvars[u] = mulacc(bb, lvars[vin[0]], lvars[vin[1]], lvars[vin[2]], vin[0].dtype, dtype)
+      else: lvars[u] = code_for_op[args](bb[-1], *[lvars[x] for x in vin] + [dtype if args != BinaryOps.CMPLT else vin[0].dtype])
     if uop == UOps.CAST: lvars[u] = cast(bb, lvars[vin[0]], vin[0].dtype, dtype, bitcast=isinstance(args, tuple) and args[1])

   bb[-1].ret_void()
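Note the ordering in the integer path of mulacc above: the multiply happens in the *input* dtype and only the product is cast up to the accumulator dtype, so a narrow product can still wrap before it is widened. A numpy sketch of those semantics (numpy is used purely for illustration):

  import numpy as np

  x = np.array([16], dtype=np.uint8)
  prod = x * x                                 # computed in uint8: 256 wraps to 0
  acc = prod.astype(np.uint32) + np.uint32(5)  # cast after the multiply, then add
  print(acc)                                   # [5], not [261]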
@@ -174,7 +174,7 @@ class Tensor:
   def arange(start, stop=None, step=1, **kwargs):
     if stop is None: stop, start = start, 0
     dtype = kwargs.pop("dtype", dtypes.default_float if any(isinstance(x, float) for x in (start, stop, step)) else dtypes.default_int)
-    return Tensor.full((math.ceil((stop-start)/step),), step, dtype=dtype, **kwargs).cumsum() + (start - step)
+    return (Tensor.full((math.ceil((stop-start)/step),), step, dtype=dtype, **kwargs).cumsum() + (start - step)).cast(dtype)

   @staticmethod
   def eye(dim:int, **kwargs):
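The trailing .cast(dtype) is needed because arange is built on cumsum, cumsum is built on sum, and sum now promotes small integer dtypes; without the cast, an explicit arange(..., dtype=dtypes.int8) would come back widened to int32. A hedged check, assuming tinygrad at this commit:

  from tinygrad.tensor import Tensor
  from tinygrad.helpers import dtypes

  t = Tensor.arange(0, 4, 1, dtype=dtypes.int8)
  print(t.dtype)  # dtypes.int8: the final cast restores the requested dtype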
@@ -484,7 +484,11 @@ class Tensor:
     ret = fxn.apply(self, new_shape=tuple([1 if i in axis_ else s for i,s in enumerate(self.shape)]))
     return ret if keepdim else ret.reshape(shape=shape)

-  def sum(self, axis=None, keepdim=False): return self._reduce(mlops.Sum, axis, keepdim)
+  def sum(self, axis=None, keepdim=False):
+    output_dtype = least_upper_dtype(self.dtype, dtypes.uint) if dtypes.is_unsigned(self.dtype) else \
+                   least_upper_dtype(self.dtype, dtypes.int) if (dtypes.is_int(self.dtype) or self.dtype==dtypes.bool) else self.dtype
+    return self.cast(output_dtype)._reduce(mlops.Sum, axis, keepdim)
+
   def max(self, axis=None, keepdim=False): return self._reduce(mlops.Max, axis, keepdim)
   def min(self, axis=None, keepdim=False): return -((-self).max(axis=axis, keepdim=keepdim))
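This is the headline change: given the test table above, dtypes.int and dtypes.uint here are the 32-bit types, so least_upper_dtype against them is exactly "at least int32/uint32". A hedged before/after demo, assuming tinygrad at this commit:

  from tinygrad.tensor import Tensor
  from tinygrad.helpers import dtypes

  t = Tensor([200, 200], dtype=dtypes.uint8)
  s = t.sum()
  print(s.dtype)    # dtypes.uint32: promoted, where it used to stay uint8
  print(s.numpy())  # 400; a uint8 accumulator would have wrapped to 144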
@@ -640,7 +644,7 @@ class Tensor:
     assert self.shape[-1] == w.shape[-min(n2, 2)], f"Input Tensor shapes {self.shape} and {w.shape} cannot be multiplied ({self.shape[-1]} != {w.shape[-min(n2, 2)]})"  # noqa: E501
     x = self.reshape(*self.shape[0:-1], *[1]*min(n1-1, n2-1, 1), self.shape[-1])
     w = w.reshape(*w.shape[0:-2], *[1]*min(n1-1, n2-1, 1), *w.shape[-min(n2, 2):]).transpose(-1, -min(n2, 2))
-    return (x*w).sum(-1)
+    return (x*w).sum(-1).cast(least_upper_dtype(x.dtype, w.dtype))

   def _cumsum(self, axis:int=0, _first_zero=False) -> Tensor:
     return self.transpose(axis,-1).pad2d((self.shape[axis]-int(not _first_zero),0))._pool((self.shape[axis],)).sum(-1).transpose(axis,-1)
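The cast back matters because sum now widens the accumulator: the dot product accumulates in the promoted dtype, but callers still see least_upper_dtype(x.dtype, w.dtype), which is exactly what the new test_matmul asserts.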
@@ -661,9 +665,9 @@ class Tensor:
     assert all_int((r,c)), "does not support symbolic"
     return Tensor.arange(r, **kwargs).unsqueeze(1).expand(r,c) <= Tensor.arange(-k, c-k, **kwargs).unsqueeze(0).expand(r,c)
   def triu(self, k:int=0) -> Tensor:
-    return Tensor._tri(self.shape[-2], self.shape[-1], k=k, dtype=self.dtype, device=self.device).where(self, Tensor.zeros_like(self))
+    return Tensor._tri(self.shape[-2], self.shape[-1], k=k, device=self.device).where(self, Tensor.zeros_like(self))
   def tril(self, k:int=0) -> Tensor:
-    return Tensor._tri(self.shape[-2], self.shape[-1], k=k+1, dtype=self.dtype, device=self.device).where(Tensor.zeros_like(self), self)
+    return Tensor._tri(self.shape[-2], self.shape[-1], k=k+1, device=self.device).where(Tensor.zeros_like(self), self)

   # ***** mlops (unary) *****
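Dropping dtype=self.dtype from the _tri calls is consistent with the rest of the commit: _tri returns the boolean mask of a <= comparison, and where() already takes its output dtype from self, so building the helper arange tensors in self.dtype (which may be bool or a narrow int) was unnecessary and, presumably, could misbehave under the stricter arange typing.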