use at least int32 and uint32 for sum output (#2926)

* use at least int32 and uint32 for sum output

* use the correct type for acc

* fix opencl

* llvm mulacc
chenyu 2023-12-24 01:14:54 -05:00 committed by GitHub
parent d424babe2c
commit b55b55d56e
6 changed files with 65 additions and 37 deletions

View File

@@ -4,7 +4,7 @@ from tinygrad.helpers import getenv
 # https://github.com/facebookresearch/llama/blob/1076b9c51c77ad06e9d7ba8a4c6df775741732bd/llama/model.py#L47
 def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> Tensor:
-  freqs = 1.0 / (theta ** (Tensor.arange(0, dim, 2)[:(dim // 2)] / dim))
+  freqs = 1.0 / (theta ** (Tensor.arange(0, dim, 2, dtype=dtypes.default_float)[:(dim // 2)] / dim))
   freqs = Tensor.arange(end).unsqueeze(dim=1)*freqs.unsqueeze(dim=0)
   return Tensor.stack([Tensor.cos(freqs), Tensor.sin(freqs)], dim=-1).reshape(1, end, 1, dim//2, 2)
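
Note on this hunk: Tensor.arange picks dtypes.default_int when every argument is an integer (see the tensor.py hunk at the end of this diff), so without the explicit dtype the freqs exponent would start from an integer tensor. A minimal sketch of the selection rule, with illustrative values:

from tinygrad.tensor import Tensor
from tinygrad.helpers import dtypes

assert Tensor.arange(0, 8, 2).dtype == dtypes.default_int      # all-int args
assert Tensor.arange(0, 8, 0.5).dtype == dtypes.default_float  # any float arg
assert Tensor.arange(0, 8, 2, dtype=dtypes.default_float).dtype == dtypes.default_float  # pinned explicitly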

View File

@@ -77,11 +77,11 @@ class TestDType(unittest.TestCase):
   def test_same_size_ops(self):
     list(map(
-      lambda dtype: _test_ops(a_dtype=self.DTYPE, b_dtype=dtype,
-        target_dtype=least_upper_dtype(self.DTYPE, dtype)) if dtype.itemsize == self.DTYPE.itemsize else None,
+      lambda dtype: _test_ops(a_dtype=self.DTYPE, b_dtype=dtype) if dtype.itemsize == self.DTYPE.itemsize else None,
       get_available_cast_dtypes(self.DTYPE)
     ))
-  def test_upcast_ops(self): list(map(
+  def test_upcast_ops(self):
+    list(map(
       lambda dtype: _test_ops(a_dtype=self.DTYPE, b_dtype=dtype) if dtype.itemsize > self.DTYPE.itemsize else None,
       get_available_cast_dtypes(self.DTYPE)
     ))
@@ -305,16 +305,16 @@ class TestTypeSpec(unittest.TestCase):
     assert Tensor.arange(3, 9, 0.7).dtype == dtypes.default_float
     assert Tensor.arange(3, 8.5, 3).dtype == dtypes.default_float
-core_types = list(DTYPES_DICT.values())
-floats = [dt for dt in core_types if dtypes.is_float(dt)]
+core_dtypes = list(DTYPES_DICT.values())
+floats = [dt for dt in core_dtypes if dtypes.is_float(dt)]
 class TestTypePromotion(unittest.TestCase):
-  @given(st.sampled_from(core_types))
+  @given(st.sampled_from(core_dtypes))
   def test_self_promo_to_self(self, dtype):
     assert least_upper_dtype(dtype) == dtype
     assert least_upper_dtype(dtype, dtype) == dtype
     assert least_upper_dtype(dtype, dtype, dtype) == dtype
-  @given(st.sampled_from(core_types), st.sampled_from(core_types))
+  @given(st.sampled_from(core_dtypes), st.sampled_from(core_dtypes))
   def test_promo_resulted_higher_than_inputs(self, dtype1, dtype2):
     result = least_upper_dtype(dtype1, dtype2)
     assert result >= dtype1 and result >= dtype2
@@ -400,5 +400,24 @@ class TestAutoCastType(unittest.TestCase):
     assert (Tensor([0, 1], dtype=dtypes.float32) + True).dtype == dtypes.float32
     assert (Tensor([0, 1], dtype=dtypes.float64) + True).dtype == dtypes.float64
+
+  def test_sum(self):
+    assert (Tensor([0, 1], dtype=dtypes.bool)).sum().dtype == dtypes.int32
+    assert (Tensor([0, 1], dtype=dtypes.int8)).sum().dtype == dtypes.int32
+    assert (Tensor([0, 1], dtype=dtypes.int16)).sum().dtype == dtypes.int32
+    assert (Tensor([0, 1], dtype=dtypes.int32)).sum().dtype == dtypes.int32
+    assert (Tensor([0, 1], dtype=dtypes.int64)).sum().dtype == dtypes.int64
+    assert (Tensor([0, 1], dtype=dtypes.uint8)).sum().dtype == dtypes.uint32
+    assert (Tensor([0, 1], dtype=dtypes.uint16)).sum().dtype == dtypes.uint32
+    assert (Tensor([0, 1], dtype=dtypes.uint32)).sum().dtype == dtypes.uint32
+    assert (Tensor([0, 1], dtype=dtypes.uint64)).sum().dtype == dtypes.uint64
+    assert (Tensor([0, 1], dtype=dtypes.float16)).sum().dtype == dtypes.float16
+    assert (Tensor([0, 1], dtype=dtypes.bfloat16)).sum().dtype == dtypes.bfloat16
+    assert (Tensor([0, 1], dtype=dtypes.float32)).sum().dtype == dtypes.float32
+    assert (Tensor([0, 1], dtype=dtypes.float64)).sum().dtype == dtypes.float64
+
+  @given(st.sampled_from(core_dtypes), st.sampled_from(core_dtypes))
+  def test_matmul(self, dt1, dt2):
+    assert (Tensor([0, 1], dtype=dt1) @ Tensor([0, 1], dtype=dt2)).dtype == least_upper_dtype(dt1, dt2)
 if __name__ == '__main__':
   unittest.main()
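
Restated, the rule these tests pin down: integer and bool sums accumulate in at least 32 bits, floats are untouched. A standalone sketch of the expected output dtype (expected_sum_dtype is a hypothetical helper mirroring the tensor.py change below, not library API; import paths are as of this commit):

from tinygrad.helpers import dtypes, least_upper_dtype

def expected_sum_dtype(dt):
  # order matters: tinygrad's dtypes.is_int also covers unsigned ints
  if dtypes.is_unsigned(dt): return least_upper_dtype(dt, dtypes.uint32)
  if dtypes.is_int(dt) or dt == dtypes.bool: return least_upper_dtype(dt, dtypes.int32)
  return dt  # float16/bfloat16/float32/float64 keep their dtype

assert expected_sum_dtype(dtypes.int8) == dtypes.int32
assert expected_sum_dtype(dtypes.uint64) == dtypes.uint64  # already at least 32 bits
assert expected_sum_dtype(dtypes.float16) == dtypes.float16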

View File

@@ -6,7 +6,7 @@ from enum import Enum, auto
 from dataclasses import dataclass
 from tinygrad.helpers import colored, ImageDType, DEBUG, dtypes, DType, prod, PtrDType, getenv, all_same, to_function_name, flatten
-from tinygrad.ops import LazyOp, UnaryOps, BinaryOps, TernaryOps, ReduceOps, ConstBuffer, MemBuffer, BufferOps, vars_from_ast
+from tinygrad.ops import LazyOp, UnaryOps, BinaryOps, TernaryOps, ReduceOps, ConstBuffer, MemBuffer, BufferOps, vars_from_ast, get_lazyop_info
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.shape.symbolic import Variable, NumNode, VariableOrNum, Node, SumNode, MulNode, DivNode, ModNode, LtNode, AndNode
 from tinygrad.codegen.kernel import LocalBuffer, Kernel
@@ -50,9 +50,12 @@ class Linearizer(Kernel):
   def cast(self, val: UOp, dtype) -> UOp: return self.uop(UOps.CAST, dtype, (val,)) if val.dtype != dtype else val
-  def get_reduce_acc(self, op, dtype:DType):
-    if op == ReduceOps.SUM: return 0.0 if dtypes.is_float(dtype) else 0
-    elif op == ReduceOps.MAX: return -math.inf if dtypes.is_float(dtype) else -2**31 if dtypes.is_int(dtype) else False
+  def get_reduce_acc(self, reduceop:LazyOp):
+    dtype = get_lazyop_info(reduceop).dtype
+    if reduceop.op == ReduceOps.SUM: return 0.0 if dtypes.is_float(dtype) else 0
+    elif reduceop.op == ReduceOps.MAX:
+      if dtypes.is_int(dtype): return 0 if dtypes.is_unsigned(dtype) else -2**(dtype.itemsize*8-1)
+      return -math.inf if dtypes.is_float(dtype) else False
 
   # NOTE: once images are loaded, we uop them as their base float
   def get_base_dtype(self, dt:DType): return dt.base if isinstance(dt, ImageDType) else dt
@@ -69,7 +72,7 @@ class Linearizer(Kernel):
   def global_load(self, i:int, idxs:Sequence[Node], acc=None, barrier:Optional[UOp]=None) -> List[UOp]:
     buf = self.bufs[i]
-    localtype = self.get_base_dtype(buf.dtype)
+    localtype = self.get_base_dtype(buf.dtype if acc is None else get_lazyop_info(self.reduceop).dtype)
     const = buf.val if isinstance(buf, ConstBuffer) else acc
     def rename_var(v: VariableOrNum, expr: str): return v if isinstance(v, NumNode) else Variable(expr, v.min, v.max)
@@ -199,8 +202,9 @@ class Linearizer(Kernel):
     if self.group_for_reduce:
       # TODO: the strides of this can be controlled
       self.sts.append(ShapeTracker.from_shape(tuple([1] * self.global_dims + list(self.full_shape[self.global_dims:self.global_dims+self.local_dims+len(self.group_for_reduce)]) + [1] * (self.shape_len - self.upcasted - len(self.group_for_reduce) - self.first_reduce) + [x[0] for x in self.upcasted_axis(0)]))) # noqa: E501
-      self.bufs.append(LocalBuffer("temp", self.sts[-1].size()))
-      self.buf_uops.append(self.uop(UOps.DEFINE_LOCAL, PtrDType(dtypes.float32), (), ("temp", self.sts[-1].size())))
+      temp_dtype = self.get_base_dtype(get_lazyop_info(self.reduceop).dtype)
+      self.bufs.append(LocalBuffer("temp", self.sts[-1].size(), temp_dtype))
+      self.buf_uops.append(self.uop(UOps.DEFINE_LOCAL, PtrDType(temp_dtype), (), ("temp", self.sts[-1].size())))
 
     # kernel name (before late upcast)
     self.name = ("r_" if self.reduceop else "E_") + colored('_', 'BLACK').join([colored(str(x), c) for x,c in zip(self.full_shape, self.colors())])
@@ -250,7 +254,7 @@ class Linearizer(Kernel):
       fake_reduce_idxs = [x*0 for x in reduce_idxs]
 
       # define accumulator
-      acc = self.global_load(0, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, self.get_reduce_acc(self.reduceop.op, self.bufs[0].dtype))
+      acc = self.global_load(0, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, self.get_reduce_acc(self.reduceop))
 
       if self.tensor_core:
         def calc_tc_idxs(local_size: int, aliases: List[List[int]]):
@@ -356,7 +360,7 @@ class Linearizer(Kernel):
       # NOTE: this structure is the same as the reduce op above
 
       # define late accumulator
-      acc = self.global_load(-1, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, self.get_reduce_acc(self.reduceop.op, self.bufs[-1].dtype)) # noqa: E501
+      acc = self.global_load(-1, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, self.get_reduce_acc(self.reduceop))
 
       # late reduce loop
       loop_ctx = render_loop(end_local_idxs)
@@ -482,12 +486,6 @@ class Linearizer(Kernel):
       if arg == BinaryOps.CMPLT: assert dtype == dtypes.bool, f"{arg} output dtype mismatch {dtype=} != {dtypes.bool}"
       if arg == TernaryOps.WHERE: assert vin[0].dtype == dtypes.bool, f"{arg} selector dtype mismatch {vin[0].dtype=} != {dtypes.bool}"
     if uop == UOps.PHI and vin[1].dtype != dtype: vin = (vin[0], self.cast(vin[1], dtype)) + vin[1:]
-    if uop == UOps.ALU: # upcast vins to the same dtype
-      upcast_dtype = dtypes.float if arg == TernaryOps.MULACC else max(cast(DType, x.dtype) for x in vin) # MULACC is only supported in float
-      if arg == TernaryOps.WHERE: vin = (vin[0],) + tuple(self.cast(x, upcast_dtype) for x in vin[1:]) # the first arg is always bool
-      else: vin = tuple(self.cast(x, upcast_dtype) for x in vin)
-      dtype = dtype or upcast_dtype # some ops like BinaryOps.CMPLT return bool
-
     if simplify:
       if uop == UOps.PHI and len(vin) == 2: return vin[1] # a phi without loops is a noop
       if uop == UOps.GEP and vin[0].uop == UOps.CONST: return self.const(vin[0].arg, dtype, insert_before)
@@ -535,12 +533,12 @@ class Linearizer(Kernel):
       ret: List[UOp] = []
       input_acc = acc[:]
       for val, off in zip(zip(*values), cast(List[int], offs)):
-        acc[off] = self.uop(UOps.ALU, vin=val+(acc[off],), arg=ops[x.op])
+        acc[off] = self.uop(UOps.ALU, acc[off].dtype, vin=val+(acc[off],), arg=ops[x.op])
         ret.append(acc[off])
       for off in range(len(acc)):
         if input_acc[off] != acc[off]:
           acc[off] = self.uop(UOps.PHI, input_acc[off].dtype, (input_acc[off], acc[off]) + tuple(loop_ctx))
     else:
-      ret = [self.uop(UOps.ALU, dtype=dtypes.bool if x.op == BinaryOps.CMPLT else None, vin=val, arg=x.op) for val in zip(*values)]
+      ret = [self.uop(UOps.ALU, dtype=dtypes.bool if x.op == BinaryOps.CMPLT else val[-1].dtype, vin=val, arg=x.op) for val in zip(*values)]
     cache[x] = ret
     return ret
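
The accumulator identity now derives from the reduce's output dtype instead of the buffer dtype, and MAX no longer hardcodes -2**31. A hand-checked restatement of that logic (acc_identity is a hypothetical stand-in for the Linearizer method, not library API):

import math
from tinygrad.helpers import dtypes

def acc_identity(op, dtype):
  if op == "SUM": return 0.0 if dtypes.is_float(dtype) else 0
  if op == "MAX":
    # smallest representable value: 0 for unsigned, -2**(bits-1) for signed ints
    if dtypes.is_int(dtype): return 0 if dtypes.is_unsigned(dtype) else -2**(dtype.itemsize*8-1)
    return -math.inf if dtypes.is_float(dtype) else False  # bool reduces from False

assert acc_identity("MAX", dtypes.int8) == -128     # itemsize 1 -> -2**7
assert acc_identity("MAX", dtypes.int64) == -2**63  # the old hardcoded -2**31 was wrong here
assert acc_identity("MAX", dtypes.uint16) == 0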

View File

@@ -223,7 +223,8 @@ class OpenCLLanguage(CStyleLanguage):
   xid = [f'get_global_id({i})' for i in range(3)]
   uses_vload = True
   # NOTE: mad is used so the loads aren't reordered into the math on 845
-  code_for_op = {**CStyleLanguage().code_for_op, TernaryOps.MULACC: lambda a,b,c,dtype: f"mad({a},{b},{c})"}
+  code_for_op = {**CStyleLanguage().code_for_op,
+    TernaryOps.MULACC: lambda a,b,c,dtype: f"mad({a},{b},{c})" if dtypes.is_float(dtype) else f"(({a}*{b})+{c})"}
   type_map = { dtypes.uint8: "uchar", dtypes.uint32: "uint", dtypes.uint16: "ushort", dtypes.uint64: "ulong" }
   def render_cast(self, x, var_dtype, bitcast=False) -> str:
     return f"as_{self.type_map.get(var_dtype) or var_dtype.name}({x[0]})" if bitcast else super().render_cast(x, var_dtype)

View File

@@ -8,6 +8,11 @@ MFLAGS = ('nsz', 'arcp', 'contract', 'afn', 'reassoc') # All from fast math, but
 def is_bool_or_unsigned(dtype: DType): return dtype == dtypes.bool or dtypes.is_unsigned(dtype)
 
+def mulacc(bb, x, y, z, input_dtype, output_dtype):
+  if dtypes.is_float(output_dtype): return bb[-1].fadd(bb[-1].fmul(x, y, flags=MFLAGS), z, flags=MFLAGS)
+  # no fast math flags for int add and mul
+  return bb[-1].add(cast(bb, bb[-1].mul(x, y), input_dtype, output_dtype), z)
+
 code_for_op: Final[Dict[Op, Callable]] = {
   UnaryOps.NEG: lambda builder, x, var_dtype: builder.xor(x, ir.Constant(ir.IntType(1), 1)) if var_dtype == dtypes.bool else builder.neg(x) if dtypes.is_int(var_dtype) else builder.fneg(x, flags=MFLAGS), # noqa: E501
   UnaryOps.EXP2: lambda builder, x, var_dtype: builder.call(builder._block.module.declare_intrinsic('llvm.exp2', [x.type]), [x], fastmath=MFLAGS),
@@ -25,7 +30,6 @@ code_for_op: Final[Dict[Op, Callable]] = {
   BinaryOps.MOD: lambda builder, x, y, var_dtype:
     builder.urem(x, y) if is_bool_or_unsigned(var_dtype) else builder.srem(x, y) if dtypes.is_int(var_dtype) else builder.frem(x, y),
   BinaryOps.XOR: lambda builder, x, y, var_dtype: builder.xor(x, y),
-  TernaryOps.MULACC: lambda builder, x, y, z, var_dtype: builder.fadd(builder.fmul(x, y, flags=MFLAGS), z, flags=MFLAGS),
   TernaryOps.WHERE: lambda builder, x, y, z, var_dtype: builder.select(x, y, z),
 }
@@ -156,7 +160,9 @@ def uops_to_llvm_ir(function_name:str, uops:List[UOp]) -> Tuple[str, Dict]:
       if len(vin) > 3:
         with bb[-1].if_then(bb[-1].trunc(lvars[vin[3]], ir.IntType(1))): store_op()
       else: store_op()
-    if uop == UOps.ALU: lvars[u] = code_for_op[args](bb[-1], *[lvars[x] for x in vin] + [dtype if args != BinaryOps.CMPLT else vin[0].dtype])
+    if uop == UOps.ALU:
+      if args == TernaryOps.MULACC: lvars[u] = mulacc(bb, lvars[vin[0]], lvars[vin[1]], lvars[vin[2]], vin[0].dtype, dtype)
+      else: lvars[u] = code_for_op[args](bb[-1], *[lvars[x] for x in vin] + [dtype if args != BinaryOps.CMPLT else vin[0].dtype])
     if uop == UOps.CAST: lvars[u] = cast(bb, lvars[vin[0]], vin[0].dtype, dtype, bitcast=isinstance(args, tuple) and args[1])
   bb[-1].ret_void()
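
On the LLVM side, fast-math flags exist only on floating-point instructions, so the integer path multiplies in the input dtype, widens the product to the accumulator dtype, then adds. A minimal llvmlite sketch of that pattern (assuming signed int8 inputs and an int32 accumulator; the real helper picks sext or zext via its cast function):

from llvmlite import ir

i8, i32 = ir.IntType(8), ir.IntType(32)
fn = ir.Function(ir.Module(), ir.FunctionType(i32, [i8, i8, i32]), name="mulacc_i8")
bb = ir.IRBuilder(fn.append_basic_block())
x, y, z = fn.args
prod = bb.mul(x, y)                    # integer mul: MFLAGS would be rejected here
bb.ret(bb.add(bb.sext(prod, i32), z))  # widen the product, then accumulate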

View File

@@ -174,7 +174,7 @@ class Tensor:
   def arange(start, stop=None, step=1, **kwargs):
     if stop is None: stop, start = start, 0
     dtype = kwargs.pop("dtype", dtypes.default_float if any(isinstance(x, float) for x in (start, stop, step)) else dtypes.default_int)
-    return Tensor.full((math.ceil((stop-start)/step),), step, dtype=dtype, **kwargs).cumsum() + (start - step)
+    return (Tensor.full((math.ceil((stop-start)/step),), step, dtype=dtype, **kwargs).cumsum() + (start - step)).cast(dtype)
 
   @staticmethod
   def eye(dim:int, **kwargs):
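
The trailing .cast(dtype) above is needed because arange is built on cumsum, cumsum is built on sum, and sum can now return a wider dtype than requested (an int8 arange would otherwise come back int32). Expected behavior, restated:

from tinygrad.tensor import Tensor
from tinygrad.helpers import dtypes

assert Tensor.arange(4, dtype=dtypes.int8).dtype == dtypes.int8  # cast back after the int32 cumsum
assert Tensor.arange(4).dtype == dtypes.default_int
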
@@ -484,7 +484,11 @@ class Tensor:
     ret = fxn.apply(self, new_shape=tuple([1 if i in axis_ else s for i,s in enumerate(self.shape)]))
     return ret if keepdim else ret.reshape(shape=shape)
 
-  def sum(self, axis=None, keepdim=False): return self._reduce(mlops.Sum, axis, keepdim)
+  def sum(self, axis=None, keepdim=False):
+    output_dtype = least_upper_dtype(self.dtype, dtypes.uint) if dtypes.is_unsigned(self.dtype) else \
+      least_upper_dtype(self.dtype, dtypes.int) if (dtypes.is_int(self.dtype) or self.dtype==dtypes.bool) else self.dtype
+    return self.cast(output_dtype)._reduce(mlops.Sum, axis, keepdim)
   def max(self, axis=None, keepdim=False): return self._reduce(mlops.Max, axis, keepdim)
   def min(self, axis=None, keepdim=False): return -((-self).max(axis=axis, keepdim=keepdim))
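
This is the core change of the PR: bool and integer sums get a 32-bit-or-wider accumulator, which avoids trivial wraparound for small dtypes while leaving int64/uint64 and all floats untouched. A worked example (values checked by hand):

from tinygrad.tensor import Tensor
from tinygrad.helpers import dtypes

t = Tensor([100, 100, 100], dtype=dtypes.uint8)
assert t.sum().dtype == dtypes.uint32  # 300 no longer wraps mod 256
assert Tensor([1, 2], dtype=dtypes.int64).sum().dtype == dtypes.int64  # already wide enough
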
@@ -640,7 +644,7 @@ class Tensor:
     assert self.shape[-1] == w.shape[-min(n2, 2)], f"Input Tensor shapes {self.shape} and {w.shape} cannot be multiplied ({self.shape[-1]} != {w.shape[-min(n2, 2)]})" # noqa: E501
     x = self.reshape(*self.shape[0:-1], *[1]*min(n1-1, n2-1, 1), self.shape[-1])
     w = w.reshape(*w.shape[0:-2], *[1]*min(n1-1, n2-1, 1), *w.shape[-min(n2, 2):]).transpose(-1, -min(n2, 2))
-    return (x*w).sum(-1)
+    return (x*w).sum(-1).cast(least_upper_dtype(x.dtype, w.dtype))
 
   def _cumsum(self, axis:int=0, _first_zero=False) -> Tensor:
     return self.transpose(axis,-1).pad2d((self.shape[axis]-int(not _first_zero),0))._pool((self.shape[axis],)).sum(-1).transpose(axis,-1)
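
matmul reduces with sum, so without the cast an int8 @ int8 would now surface as int32; the cast restores the usual promotion rule least_upper_dtype(x.dtype, w.dtype), matching test_matmul above. For example:

from tinygrad.tensor import Tensor
from tinygrad.helpers import dtypes

a = Tensor([[1, 2]], dtype=dtypes.int8)
b = Tensor([[3], [4]], dtype=dtypes.int8)
assert (a @ b).dtype == dtypes.int8  # summed in int32 internally, cast back to int8
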
@@ -661,9 +665,9 @@ class Tensor:
     assert all_int((r,c)), "does not support symbolic"
     return Tensor.arange(r, **kwargs).unsqueeze(1).expand(r,c) <= Tensor.arange(-k, c-k, **kwargs).unsqueeze(0).expand(r,c)
   def triu(self, k:int=0) -> Tensor:
-    return Tensor._tri(self.shape[-2], self.shape[-1], k=k, dtype=self.dtype, device=self.device).where(self, Tensor.zeros_like(self))
+    return Tensor._tri(self.shape[-2], self.shape[-1], k=k, device=self.device).where(self, Tensor.zeros_like(self))
   def tril(self, k:int=0) -> Tensor:
-    return Tensor._tri(self.shape[-2], self.shape[-1], k=k+1, dtype=self.dtype, device=self.device).where(Tensor.zeros_like(self), self)
+    return Tensor._tri(self.shape[-2], self.shape[-1], k=k+1, device=self.device).where(Tensor.zeros_like(self), self)
 
 # ***** mlops (unary) *****
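
Dropping dtype=self.dtype here means _tri builds its intermediate aranges in the default dtype rather than the tensor's own; since where() takes its output dtype from the value branches, the triu/tril results are unchanged (the diff doesn't state the motivation, so that reading is an inference). Sanity check, restated:

from tinygrad.tensor import Tensor
from tinygrad.helpers import dtypes

x = Tensor([[1, 2], [3, 4]], dtype=dtypes.int8)
assert x.triu().dtype == x.dtype  # output dtype comes from where()'s value branches
assert x.tril().dtype == x.dtype  # the mask's own dtype never reaches the result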