mirror of https://github.com/commaai/tinygrad.git
extra/gemm: add a simple_conv.py along with correctness check (#3236)
* extra/gemm: add a simple_conv.py along with correctness check The goal is to easily test tensor core triggering situations * test: add tests for acc_dtype handling and fixed typing
This commit is contained in:
parent
0aad8d238b
commit
4273aabe31
|
@ -0,0 +1,30 @@
|
|||
from tinygrad.helpers import getenv
|
||||
from tinygrad import dtypes, Tensor
|
||||
|
||||
dtype_in = dtypes.half if getenv("HALF") else dtypes.float
|
||||
acc_dtype = dtypes.half if getenv("ACC_HALF") else None
|
||||
|
||||
CNT = getenv("CNT", 8)
|
||||
BS = getenv("BS", 16)
|
||||
CIN = getenv("CIN", 128)
|
||||
COUT = getenv("COUT", 128)
|
||||
HW = getenv("HW", 128)
|
||||
K = getenv("K", 3)
|
||||
PADDING = getenv("PADDING", 1)
|
||||
COMP = getenv("COMP", 0)
|
||||
|
||||
FLOPS = BS*K*K*CIN*HW*HW*COUT*2
|
||||
def rand_input(): return Tensor.rand(BS, CIN, HW, HW, dtype=dtype_in).realize(), Tensor.rand(COUT, CIN, K, K, dtype=dtype_in).realize()
|
||||
|
||||
a, b = rand_input()
|
||||
for i in range(CNT):
|
||||
if i > 0 and getenv("RAND", 0) != 0:
|
||||
a, b = rand_input()
|
||||
c = a.conv2d(b, padding=PADDING, acc_dtype=acc_dtype).realize()
|
||||
|
||||
if COMP:
|
||||
import numpy as np, time, torch
|
||||
torch_device = "cuda:0" if torch.cuda.is_available() else ("mps" if getenv("MPS", 0) else "cpu")
|
||||
ta, tb = torch.from_numpy(a.numpy()).to(torch_device), torch.from_numpy(b.numpy()).to(torch_device)
|
||||
tc = torch.nn.functional.conv2d(ta, tb, padding=PADDING)
|
||||
np.testing.assert_allclose(c.numpy(), tc.cpu(), atol=1e-4, rtol=3e-2)
|
|
@ -12,7 +12,7 @@ from tinygrad.tensor import Tensor
|
|||
from tinygrad.jit import CacheCollector
|
||||
from tinygrad.realize import run_schedule
|
||||
from tinygrad.helpers import prod, Context
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.dtype import DType, dtypes
|
||||
|
||||
@unittest.skipIf(not isinstance(Device[Device.DEFAULT], Compiled), "linearizer is only for compiled backends")
|
||||
class TestLinearizer(unittest.TestCase):
|
||||
|
@ -81,6 +81,28 @@ class TestLinearizer(unittest.TestCase):
|
|||
local = [uop for uop in k.uops if uop.uop == UOps.DEFINE_ACC]
|
||||
assert local[0].dtype == acc_dtype
|
||||
|
||||
def test_arg_acc_dtype(self):
|
||||
def helper_arg_acc_dtype(c: Tensor, expected_dtype:DType):
|
||||
k = Linearizer(c.lazydata.schedule()[-1].ast)
|
||||
k.linearize()
|
||||
local = [uop for uop in k.uops if uop.uop == UOps.DEFINE_ACC]
|
||||
assert local[0].dtype == expected_dtype
|
||||
|
||||
tests = (
|
||||
(dtypes.float16, None, dtypes.float),
|
||||
(dtypes.bfloat16, None, dtypes.float),
|
||||
(dtypes.float, None, dtypes.float),
|
||||
(dtypes.float16, dtypes.float16, dtypes.float16),
|
||||
(dtypes.bfloat16, dtypes.bfloat16, dtypes.bfloat16),
|
||||
(dtypes.float, dtypes.float16, dtypes.float16),
|
||||
)
|
||||
for tensor_dtype, acc_dtype, expected_dtype in tests:
|
||||
a, b = Tensor.rand(8, 8, dtype=tensor_dtype), Tensor.rand(8, 8, dtype=tensor_dtype)
|
||||
helper_arg_acc_dtype(a.sum(acc_dtype=acc_dtype), expected_dtype)
|
||||
helper_arg_acc_dtype(a.matmul(b, acc_dtype=acc_dtype), expected_dtype)
|
||||
d, w = Tensor.rand(4, 8, 8, 8, dtype=tensor_dtype), Tensor.rand(8, 8, 2, 2, dtype=tensor_dtype)
|
||||
helper_arg_acc_dtype(d.conv2d(w, acc_dtype=acc_dtype), expected_dtype)
|
||||
|
||||
@unittest.skipUnless(Device.DEFAULT in tensor_cores, "No tensor cores for device")
|
||||
def test_tensor_cores(self):
|
||||
for tc in tensor_cores[Device.DEFAULT]:
|
||||
|
|
|
@ -528,7 +528,7 @@ class Tensor:
|
|||
ret = fxn.apply(self, new_shape=tuple([1 if i in axis_ else s for i,s in enumerate(self.shape)]))
|
||||
return ret if keepdim else ret.reshape(shape=shape)
|
||||
|
||||
def sum(self, axis=None, keepdim=False, acc_dtype=None):
|
||||
def sum(self, axis=None, keepdim=False, acc_dtype:Optional[DType]=None):
|
||||
if acc_dtype is None: acc_dtype = least_upper_dtype(self.dtype, dtypes.uint) if dtypes.is_unsigned(self.dtype) else \
|
||||
least_upper_dtype(self.dtype, dtypes.int) if (dtypes.is_int(self.dtype) or self.dtype==dtypes.bool) else \
|
||||
least_upper_dtype(self.dtype, dtypes.float)
|
||||
|
@ -635,7 +635,7 @@ class Tensor:
|
|||
padding = flatten((((k-1)*d-p,(k-1)*d-p+op) for k,d,p,op in reversed(list(zip(HW, make_pair(dilation, len(HW)), make_pair(padding, len(HW)), make_pair(output_padding, len(HW))))))) # noqa: E501
|
||||
return x.conv2d(w.flatten(end_dim=1), groups=groups, bias=bias, dilation=dilation, padding=padding)
|
||||
|
||||
def conv2d(self, weight:Tensor, bias:Optional[Tensor]=None, groups=1, stride=1, dilation=1, padding=0) -> Tensor:
|
||||
def conv2d(self, weight:Tensor, bias:Optional[Tensor]=None, groups=1, stride=1, dilation=1, padding=0, acc_dtype:Optional[DType]=None) -> Tensor:
|
||||
(bs,cin_), (cout,cin), HW = self.shape[:2], weight.shape[:2], weight.shape[2:]
|
||||
assert groups*cin == cin_ and len(self.shape) == len(weight.shape), f"Input Tensor shape {self.shape} does not match the shape of the weights {weight.shape}. ({groups*cin} vs. {cin_})" # noqa: E501
|
||||
if isinstance(padding, (tuple,list)): assert len(padding) == 2*len(HW) or len(padding) == len(HW), f"Expected padding of length {2*len(HW)} or {len(HW)}, but got {len(padding)} for tensor of shape {self.shape}" # noqa: E501
|
||||
|
@ -649,7 +649,7 @@ class Tensor:
|
|||
x = x.reshape(bs, groups, cin, 1, *oyx, *HW).expand(bs, groups, cin, rcout, *oyx, *HW).permute(0,1,3,*[4+i for i in range(len(oyx))],2,*[4+len(oyx)+i for i in range(len(HW))]) # noqa: E501
|
||||
|
||||
# conv! broadcasted to (bs, groups, rcout, *oyx, cin, *HW)
|
||||
ret = (x * weight.reshape(1, groups, rcout, *[1] * len(oyx), cin, *HW)).sum([-1-i for i in range(1+len(oyx))], keepdim=True).reshape(bs, cout, *oyx) # noqa: E501
|
||||
ret = (x * weight.reshape(1, groups, rcout, *[1] * len(oyx), cin, *HW)).sum([-1-i for i in range(1+len(oyx))], keepdim=True, acc_dtype=acc_dtype).reshape(bs, cout, *oyx) # noqa: E501
|
||||
return ret if bias is None else ret.add(bias.reshape(1, -1, *[1] * len(HW)))
|
||||
|
||||
# winograd conv 3 kernel f(4x4,3x3) see: http://arxiv.org/abs/1509.09308
|
||||
|
@ -676,7 +676,7 @@ class Tensor:
|
|||
dfactors = apply_matrix(winograd_Bt, d).contiguous().reshape(*HWI, bs, groups, 1, cin, *tyx)
|
||||
|
||||
# matmul; sum across cin: (HWI, bs, groups, rcout, *tyx); then HWI -> HWO: (HWO, bs, groups, rcout, *tyx)
|
||||
ret = apply_matrix(winograd_At, (gfactors * dfactors).sum(axis=-1-len(HW)))
|
||||
ret = apply_matrix(winograd_At, (gfactors * dfactors).sum(axis=-1-len(HW), acc_dtype=acc_dtype))
|
||||
|
||||
# interleave tyx and HWO: (bs, groups, rcout, oy, HO, ox, WO)
|
||||
ret = ret.permute([*range(len(HW), len(ret.shape)-len(HW)), *[i+o for i in range(len(HW)) for o in [len(ret.shape)-len(HW),0]]])
|
||||
|
@ -685,7 +685,7 @@ class Tensor:
|
|||
|
||||
return (ret if bias is None else ret.add(bias.reshape(1, -1, *[1 for _ in range(len(HW))]))).contiguous().contiguous_backward()
|
||||
|
||||
def dot(self, w:Tensor, acc_dtype=None) -> Tensor:
|
||||
def dot(self, w:Tensor, acc_dtype:Optional[DType]=None) -> Tensor:
|
||||
n1, n2 = len(self.shape), len(w.shape)
|
||||
assert n1 != 0 and n2 != 0, f"both arguments to matmul need to be at least 1D, but they are {n1}D and {n2}D"
|
||||
assert self.shape[-1] == w.shape[-min(n2, 2)], f"Input Tensor shapes {self.shape} and {w.shape} cannot be multiplied ({self.shape[-1]} != {w.shape[-min(n2, 2)]})" # noqa: E501
|
||||
|
|
Loading…
Reference in New Issue