diff --git a/extra/gemm/simple_conv.py b/extra/gemm/simple_conv.py
new file mode 100644
index 00000000..0606163a
--- /dev/null
+++ b/extra/gemm/simple_conv.py
@@ -0,0 +1,30 @@
+from tinygrad.helpers import getenv
+from tinygrad import dtypes, Tensor
+
+dtype_in = dtypes.half if getenv("HALF") else dtypes.float
+acc_dtype = dtypes.half if getenv("ACC_HALF") else None
+
+CNT = getenv("CNT", 8)
+BS = getenv("BS", 16)
+CIN = getenv("CIN", 128)
+COUT = getenv("COUT", 128)
+HW = getenv("HW", 128)
+K = getenv("K", 3)
+PADDING = getenv("PADDING", 1)
+COMP = getenv("COMP", 0)
+
+FLOPS = BS*K*K*CIN*HW*HW*COUT*2
+def rand_input(): return Tensor.rand(BS, CIN, HW, HW, dtype=dtype_in).realize(), Tensor.rand(COUT, CIN, K, K, dtype=dtype_in).realize()
+
+a, b = rand_input()
+for i in range(CNT):
+  if i > 0 and getenv("RAND", 0) != 0:
+    a, b = rand_input()
+  c = a.conv2d(b, padding=PADDING, acc_dtype=acc_dtype).realize()
+
+if COMP:
+  import numpy as np, time, torch
+  torch_device = "cuda:0" if torch.cuda.is_available() else ("mps" if getenv("MPS", 0) else "cpu")
+  ta, tb = torch.from_numpy(a.numpy()).to(torch_device), torch.from_numpy(b.numpy()).to(torch_device)
+  tc = torch.nn.functional.conv2d(ta, tb, padding=PADDING)
+  np.testing.assert_allclose(c.numpy(), tc.cpu(), atol=1e-4, rtol=3e-2)
diff --git a/test/test_linearizer.py b/test/test_linearizer.py
index dd1883a4..d710691e 100644
--- a/test/test_linearizer.py
+++ b/test/test_linearizer.py
@@ -12,7 +12,7 @@ from tinygrad.tensor import Tensor
 from tinygrad.jit import CacheCollector
 from tinygrad.realize import run_schedule
 from tinygrad.helpers import prod, Context
-from tinygrad.dtype import dtypes
+from tinygrad.dtype import DType, dtypes

 @unittest.skipIf(not isinstance(Device[Device.DEFAULT], Compiled), "linearizer is only for compiled backends")
 class TestLinearizer(unittest.TestCase):
@@ -81,6 +81,28 @@ class TestLinearizer(unittest.TestCase):
       local = [uop for uop in k.uops if uop.uop == UOps.DEFINE_ACC]
       assert local[0].dtype == acc_dtype

+  def test_arg_acc_dtype(self):
+    def helper_arg_acc_dtype(c: Tensor, expected_dtype:DType):
+      k = Linearizer(c.lazydata.schedule()[-1].ast)
+      k.linearize()
+      local = [uop for uop in k.uops if uop.uop == UOps.DEFINE_ACC]
+      assert local[0].dtype == expected_dtype
+
+    tests = (
+      (dtypes.float16, None, dtypes.float),
+      (dtypes.bfloat16, None, dtypes.float),
+      (dtypes.float, None, dtypes.float),
+      (dtypes.float16, dtypes.float16, dtypes.float16),
+      (dtypes.bfloat16, dtypes.bfloat16, dtypes.bfloat16),
+      (dtypes.float, dtypes.float16, dtypes.float16),
+    )
+    for tensor_dtype, acc_dtype, expected_dtype in tests:
+      a, b = Tensor.rand(8, 8, dtype=tensor_dtype), Tensor.rand(8, 8, dtype=tensor_dtype)
+      helper_arg_acc_dtype(a.sum(acc_dtype=acc_dtype), expected_dtype)
+      helper_arg_acc_dtype(a.matmul(b, acc_dtype=acc_dtype), expected_dtype)
+      d, w = Tensor.rand(4, 8, 8, 8, dtype=tensor_dtype), Tensor.rand(8, 8, 2, 2, dtype=tensor_dtype)
+      helper_arg_acc_dtype(d.conv2d(w, acc_dtype=acc_dtype), expected_dtype)
+
   @unittest.skipUnless(Device.DEFAULT in tensor_cores, "No tensor cores for device")
   def test_tensor_cores(self):
     for tc in tensor_cores[Device.DEFAULT]:
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index ebd05208..de436480 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -528,7 +528,7 @@ class Tensor:
     ret = fxn.apply(self, new_shape=tuple([1 if i in axis_ else s for i,s in enumerate(self.shape)]))
     return ret if keepdim else ret.reshape(shape=shape)

-  def sum(self, axis=None, keepdim=False, acc_dtype=None):
+  def sum(self, axis=None, keepdim=False, acc_dtype:Optional[DType]=None):
     if acc_dtype is None: acc_dtype = least_upper_dtype(self.dtype, dtypes.uint) if dtypes.is_unsigned(self.dtype) else \
       least_upper_dtype(self.dtype, dtypes.int) if (dtypes.is_int(self.dtype) or self.dtype==dtypes.bool) else \
       least_upper_dtype(self.dtype, dtypes.float)
@@ -635,7 +635,7 @@ class Tensor:
     padding = flatten((((k-1)*d-p,(k-1)*d-p+op) for k,d,p,op in reversed(list(zip(HW, make_pair(dilation, len(HW)), make_pair(padding, len(HW)), make_pair(output_padding, len(HW)))))))  # noqa: E501
     return x.conv2d(w.flatten(end_dim=1), groups=groups, bias=bias, dilation=dilation, padding=padding)

-  def conv2d(self, weight:Tensor, bias:Optional[Tensor]=None, groups=1, stride=1, dilation=1, padding=0) -> Tensor:
+  def conv2d(self, weight:Tensor, bias:Optional[Tensor]=None, groups=1, stride=1, dilation=1, padding=0, acc_dtype:Optional[DType]=None) -> Tensor:
     (bs,cin_), (cout,cin), HW = self.shape[:2], weight.shape[:2], weight.shape[2:]
     assert groups*cin == cin_ and len(self.shape) == len(weight.shape), f"Input Tensor shape {self.shape} does not match the shape of the weights {weight.shape}. ({groups*cin} vs. {cin_})"  # noqa: E501
     if isinstance(padding, (tuple,list)): assert len(padding) == 2*len(HW) or len(padding) == len(HW), f"Expected padding of length {2*len(HW)} or {len(HW)}, but got {len(padding)} for tensor of shape {self.shape}"  # noqa: E501
@@ -649,7 +649,7 @@ class Tensor:
       x = x.reshape(bs, groups, cin, 1, *oyx, *HW).expand(bs, groups, cin, rcout, *oyx, *HW).permute(0,1,3,*[4+i for i in range(len(oyx))],2,*[4+len(oyx)+i for i in range(len(HW))])  # noqa: E501

       # conv! broadcasted to (bs, groups, rcout, *oyx, cin, *HW)
-      ret = (x * weight.reshape(1, groups, rcout, *[1] * len(oyx), cin, *HW)).sum([-1-i for i in range(1+len(oyx))], keepdim=True).reshape(bs, cout, *oyx)  # noqa: E501
+      ret = (x * weight.reshape(1, groups, rcout, *[1] * len(oyx), cin, *HW)).sum([-1-i for i in range(1+len(oyx))], keepdim=True, acc_dtype=acc_dtype).reshape(bs, cout, *oyx)  # noqa: E501
       return ret if bias is None else ret.add(bias.reshape(1, -1, *[1] * len(HW)))

     # winograd conv 3 kernel f(4x4,3x3) see: http://arxiv.org/abs/1509.09308
@@ -676,7 +676,7 @@ class Tensor:
     dfactors = apply_matrix(winograd_Bt, d).contiguous().reshape(*HWI, bs, groups, 1, cin, *tyx)

     # matmul; sum across cin: (HWI, bs, groups, rcout, *tyx); then HWI -> HWO: (HWO, bs, groups, rcout, *tyx)
-    ret = apply_matrix(winograd_At, (gfactors * dfactors).sum(axis=-1-len(HW)))
+    ret = apply_matrix(winograd_At, (gfactors * dfactors).sum(axis=-1-len(HW), acc_dtype=acc_dtype))

     # interleave tyx and HWO: (bs, groups, rcout, oy, HO, ox, WO)
     ret = ret.permute([*range(len(HW), len(ret.shape)-len(HW)), *[i+o for i in range(len(HW)) for o in [len(ret.shape)-len(HW),0]]])
@@ -685,7 +685,7 @@ class Tensor:
     return (ret if bias is None else ret.add(bias.reshape(1, -1, *[1 for _ in range(len(HW))]))).contiguous().contiguous_backward()

-  def dot(self, w:Tensor, acc_dtype=None) -> Tensor:
+  def dot(self, w:Tensor, acc_dtype:Optional[DType]=None) -> Tensor:
     n1, n2 = len(self.shape), len(w.shape)
     assert n1 != 0 and n2 != 0, f"both arguments to matmul need to be at least 1D, but they are {n1}D and {n2}D"
     assert self.shape[-1] == w.shape[-min(n2, 2)], f"Input Tensor shapes {self.shape} and {w.shape} cannot be multiplied ({self.shape[-1]} != {w.shape[-min(n2, 2)]})"  # noqa: E501