import time
import unittest

import numpy as np
import torch

from tinygrad.tensor import Tensor
from tinygrad.helpers import getenv, IMAGE

# Read once at import time; a truthy value disables every backward-pass check.
FORWARD_ONLY = getenv("FORWARD_ONLY", 0)


def helper_test_op(shps, torch_fxn, tinygrad_fxn=None, atol=1e-6, rtol=1e-3,
                   grad_atol=1e-4, grad_rtol=1e-3, forward_only=False,
                   vals=None, a=-0.5, b=3):
    """Run one op through torch and tinygrad and check that the results agree.

    Inputs are built once as torch tensors and mirrored into tinygrad
    ``Tensor`` objects, so both callables see the same values. The forward
    outputs are compared with ``np.testing.assert_allclose``; unless disabled,
    ``square().mean()`` of each output is back-propagated and the input
    gradients are compared as well. A one-line timing summary is printed.

    Args:
        shps: List of input shapes to fill with random float32 data, or
            ``None`` to take the inputs from ``vals`` instead.
        torch_fxn: Callable applied to the torch inputs.
        tinygrad_fxn: Callable applied to the tinygrad inputs. Defaults to
            ``torch_fxn`` when ``None``.
        atol, rtol: Tolerances for the forward comparison.
        grad_atol, grad_rtol: Tolerances for the gradient comparison.
        forward_only: Skip the backward pass for this call. The module-level
            ``FORWARD_ONLY`` flag skips it for every call.
        vals: Explicit input values, used only when ``shps`` is ``None``.
        a, b: Random inputs are computed as ``(np.random.random(...) + a) * b``.

    Raises:
        AssertionError: If the output shapes differ.
        Exception: If values differ beyond tolerance; the underlying
            comparison error is attached as ``__cause__``.
    """
    if tinygrad_fxn is None:
        tinygrad_fxn = torch_fxn

    # Reseed on every call so each invocation sees the same random inputs.
    torch.manual_seed(0)
    np.random.seed(0)
    if shps is None:
        ts = [torch.tensor(x, requires_grad=True) for x in vals]
    else:
        ts = [
            torch.tensor((np.random.random(size=x).astype(np.float32) + a) * b,
                         requires_grad=True)
            for x in shps
        ]

    tst = [Tensor(x.detach().numpy(), requires_grad=not FORWARD_ONLY) for x in ts]

    st = time.monotonic()
    out = torch_fxn(*ts)
    torch_fp = time.monotonic() - st

    st = time.monotonic()
    ret = tinygrad_fxn(*tst).realize()
    tinygrad_fp = time.monotonic() - st

    def compare(s, x, y, atol, rtol):
        # The shape check is skipped when the torch result is a scalar.
        if y.shape != tuple():
            assert x.shape == y.shape, f"shape mismatch (tinygrad){x.shape} != (torch){y.shape}"
        try:
            np.testing.assert_allclose(x, y, atol=atol, rtol=rtol)
        except Exception as err:
            # Chain the cause so the numeric mismatch report is not lost.
            raise Exception(f"{s} failed shape {x.shape}") from err

    compare("forward pass", ret.numpy(), out.detach().numpy(), atol=atol, rtol=rtol)

    torch_fbp, tinygrad_fbp = np.nan, np.nan
    if not forward_only and not FORWARD_ONLY:
        st = time.monotonic()
        out.square().mean().backward()
        torch_fbp = time.monotonic() - st

        st = time.monotonic()
        ret.square().mean().backward()
        for tt in tst:
            tt.grad.realize()
        tinygrad_fbp = time.monotonic() - st

        for i, (t, tt) in enumerate(zip(ts, tst)):
            compare(f"backward pass tensor {i}", tt.grad.numpy(), t.grad.detach().numpy(),
                    atol=grad_atol, rtol=grad_rtol)

    print("\ntesting %40r torch/tinygrad fp: %.2f / %.2f ms bp: %.2f / %.2f ms " %
          (shps, torch_fp*1000, tinygrad_fp*1000, torch_fbp*1000, tinygrad_fbp*1000), end="")


class TestOps(unittest.TestCase):
    """Checks tinygrad tensor ops against their torch counterparts."""

    # NOTE: test methods deliberately carry no docstrings; unittest would print
    # the docstring instead of the method name when run with verbosity=2.

    def test_zeros(self):
        helper_test_op([], lambda: torch.zeros(45, 65), lambda: Tensor.zeros(45, 65), forward_only=True)

    def test_ones(self):
        helper_test_op([], lambda: torch.ones(45, 65), lambda: Tensor.ones(45, 65), forward_only=True)

    def test_eye(self):
        helper_test_op([], lambda: torch.eye(10), lambda: Tensor.eye(10), forward_only=True)

    def test_arange(self):
        helper_test_op([], lambda: torch.arange(10), lambda: Tensor.arange(10), forward_only=True)

    def _test_cmp(self, fxn, reverse=True):
        """Exercise comparison ``fxn`` on several shape pairs and on fixed values.

        With ``reverse`` set, the form with the Python scalar on the left-hand
        side (``fxn(2, y)``) is checked too.
        """
        for shps in [[(3, 4, 5), (3, 4, 5)], [(3, 4, 5), (5,)], [(5,), (3, 4, 5)]]:
            helper_test_op(shps, fxn, fxn, forward_only=True)
        helper_test_op(None, fxn, fxn, forward_only=True, vals=[[0., 1, 2], [2., 1, 0]])
        helper_test_op(None, lambda x, y: fxn(x, 2), lambda x, y: fxn(x, 2),
                       forward_only=True, vals=[[0., 1, 2], [2., 1, 0]])
        if reverse:
            helper_test_op(None, lambda x, y: fxn(2, y), lambda x, y: fxn(2, y),
                           forward_only=True, vals=[[0., 1, 2], [2., 1, 0]])

    def test_cmp_eq(self):
        self._test_cmp(lambda x, y: x.eq(y), reverse=False)

    def test_cmp_gt(self):
        self._test_cmp(lambda x, y: x > y)

    def test_cmp_ge(self):
        self._test_cmp(lambda x, y: x >= y)

    def test_cmp_lt(self):
        # NOTE(review): the previous text here read `lambda x,y: x=2`, which is
        # not valid Python. Restored to `x < y` to match the test name; text
        # after the `<` looks to have been dropped, so confirm against history
        # whether further tests belong between this one and the next.
        self._test_cmp(lambda x, y: x < y)

    def test_medium_grouped_conv2d(self):
        bs = 1
        groups = 2
        rcout = 2
        cin = 2
        helper_test_op([(bs, groups*cin, 1, 1), (groups*rcout, cin, 1, 1)],
                       lambda x, w: torch.nn.functional.conv2d(x, w, groups=groups).relu(),
                       lambda x, w: Tensor.conv2d(x, w, groups=groups).relu(),
                       atol=1e-4, grad_rtol=1e-5, forward_only=IMAGE>=2)

    def test_depthwise_conv2d(self):
        bs = 1
        groups = 32
        rcout = 1
        cin = 1
        helper_test_op([(bs, groups*cin, 32, 32), (groups*rcout, cin, 1, 1)],
                       lambda x, w: torch.nn.functional.conv2d(x, w, groups=groups).relu(),
                       lambda x, w: Tensor.conv2d(x, w, groups=groups).relu(),
                       atol=1e-4, grad_rtol=1e-5)

    def test_grouped_conv2d(self):
        bs = 4
        groups = 5
        rcout = 7
        cin = 3
        helper_test_op([(bs, groups*cin, 5, 5), (groups*rcout, cin, 3, 3)],
                       lambda x, w: torch.nn.functional.conv2d(x, w, groups=groups).relu(),
                       lambda x, w: Tensor.conv2d(x, w, groups=groups).relu(),
                       atol=1e-4, grad_rtol=1e-5)

    def test_fancy_conv2d(self):
        bs = 2
        cin = 3
        cout = 1
        groups = 3
        H, W = 3, 3
        helper_test_op([(bs, cin, 11, 28), (groups*cout, cin//groups, H, W)],
                       lambda x, w: torch.nn.functional.conv2d(x, w, groups=groups).relu(),
                       lambda x, w: Tensor.conv2d(x, w, groups=groups).relu(),
                       atol=1e-4, grad_rtol=1e-5)

    def test_strided_conv2d_simple(self):
        bs, H, W = 2, 3, 1
        helper_test_op([(bs, 1, 5, 1), (1, 1, H, W)],
                       lambda x, w: torch.nn.functional.conv2d(x, w, stride=2).relu(),
                       lambda x, w: Tensor.conv2d(x, w, stride=2).relu(), atol=1e-4)

    def test_strided_conv2d(self):
        bs = 4
        cin = 3
        H, W = 3, 3
        # Both frameworks read the walrus-bound `stride` so they cannot drift apart.
        with self.subTest(stride := 2):
            helper_test_op([(bs, cin, 11, 28), (4, cin, H, W)],
                           lambda x, w: torch.nn.functional.conv2d(x, w, stride=stride).relu(),
                           lambda x, w: Tensor.conv2d(x, w, stride=stride).relu(), atol=1e-4)
        with self.subTest(stride := (2, 1)):
            helper_test_op([(bs, cin, 11, 28), (4, cin, H, W)],
                           lambda x, w: torch.nn.functional.conv2d(x, w, stride=stride).relu(),
                           lambda x, w: Tensor.conv2d(x, w, stride=stride).relu(), atol=1e-4)

    def test_negative_padding_conv2d(self):
        n, k = 10, 3
        # Negative padding is checked against torch by cropping the input instead.
        helper_test_op([(1, 1, n, n), (1, 1, k, k)],
                       lambda x, w: torch.nn.functional.conv2d(x[:, :, 1:-1, 1:-1], w).relu(),
                       lambda x, w: Tensor.conv2d(x, w, padding=-1).relu(), atol=1e-4)
        helper_test_op([(1, 1, n, n), (1, 1, k, k)],
                       lambda x, w: torch.nn.functional.conv2d(x[:, :, 1:, 1:], w).relu(),
                       lambda x, w: Tensor.conv2d(x, w, padding=(-1, 0, -1, 0)).relu(), atol=1e-4)

    def test_simple_padding_conv2d(self):
        p = (1, 1, 1, 1)
        helper_test_op(None,
                       lambda x, w: torch.nn.functional.conv2d(torch.nn.functional.pad(x, p), w).relu(),
                       lambda x, w: Tensor.conv2d(x, w, padding=p).relu(),
                       atol=1e-4, vals=[[[[[2., 3.]]]], [[[[1.]]]]])

    def test_asymmetric_padding_conv2d(self):
        for p in [(0, 1, 0, 1), (2, 1, 2, 1), (2, 0, 2, 1)]:
            with self.subTest(padding := p):
                for n in [3, 4]:
                    for k in [2]:
                        # One call per configuration: helper_test_op reseeds its
                        # inputs, so repeating the call would add no coverage.
                        helper_test_op([(1, 1, n, n), (1, 1, k, k)],
                                       lambda x, w: torch.nn.functional.conv2d(torch.nn.functional.pad(x, p), w).relu(),
                                       lambda x, w: Tensor.conv2d(x, w, padding=p).relu(), atol=1e-4)

    def test_padded_conv2d(self):
        bs = 4
        cin = 3
        H, W = 3, 3
        for p in [2, (2, 1), (2, 2)]:
            with self.subTest(padding := p):
                helper_test_op([(bs, cin, 11, 28), (4, cin, H, W)],
                               lambda x, w: torch.nn.functional.conv2d(x, w, padding=padding).relu(),
                               lambda x, w: Tensor.conv2d(x, w, padding=padding).relu(), atol=1e-4)

    def test_padded_conv2d_bs1(self):
        bs = 1
        cin = 3
        H, W = 3, 3
        padding = 1
        helper_test_op([(bs, cin, 11, 28), (4, cin, H, W)],
                       lambda x, w: torch.nn.functional.conv2d(x, w, padding=padding).relu(),
                       lambda x, w: Tensor.conv2d(x, w, padding=padding).relu(), atol=1e-4)

    def test_dilated_conv2d(self):
        bs = 4
        cin = 3
        H, W = 3, 3
        for d in [2, (2, 1)]:
            with self.subTest(dilation := d):
                helper_test_op([(bs, cin, 11, 28), (4, cin, H, W)],
                               lambda x, w: torch.nn.functional.conv2d(x, w, dilation=dilation).relu(),
                               lambda x, w: Tensor.conv2d(x, w, dilation=dilation).relu(), atol=1e-4)

    def test_maxpool2d_simple(self):
        ksz = (2, 2)
        helper_test_op([(1, 1, 2, 3)],
                       lambda x: torch.nn.functional.max_pool2d(x, kernel_size=ksz),
                       lambda x: Tensor.max_pool2d(x, kernel_size=ksz))

    def test_maxpool2d(self):
        for ksz in [(2, 2), (3, 3), 2, 3, (3, 2), (5, 5), (5, 1)]:
            with self.subTest(kernel_size=ksz):
                helper_test_op([(32, 2, 110, 28)],
                               lambda x: torch.nn.functional.max_pool2d(x, kernel_size=ksz),
                               lambda x: Tensor.max_pool2d(x, kernel_size=ksz))

    def test_maxpool2d_bigger_stride(self):
        for stride in [(2, 3), (3, 2), 2, 3]:
            with self.subTest(stride=stride):
                helper_test_op([(32, 2, 110, 28)],
                               lambda x: torch.nn.functional.max_pool2d(x, kernel_size=(2, 2), stride=stride),
                               lambda x: Tensor.max_pool2d(x, kernel_size=(2, 2), stride=stride))

    def test_maxpool2d_unit_stride(self):
        helper_test_op([(32, 2, 110, 28)],
                       lambda x: torch.nn.functional.max_pool2d(x, kernel_size=(5, 5), stride=1),
                       lambda x: Tensor.max_pool2d(x, kernel_size=(5, 5), stride=1))

    def test_maxpool2d_smaller_stride(self):
        for stride in [(2, 3), (3, 2), 2, 3]:
            with self.subTest(stride=stride):
                helper_test_op([(32, 2, 110, 28)],
                               lambda x: torch.nn.functional.max_pool2d(x, kernel_size=(5, 5), stride=stride),
                               lambda x: Tensor.max_pool2d(x, kernel_size=(5, 5), stride=stride))

    def test_avgpool2d(self):
        shape = (32, 2, 111, 28)
        for ksz in [(2, 2), (3, 3), (3, 2), (5, 5), (5, 1)]:
            with self.subTest(kernel_size=ksz):
                helper_test_op([shape],
                               lambda x: torch.nn.functional.avg_pool2d(x, kernel_size=ksz),
                               lambda x: Tensor.avg_pool2d(x, kernel_size=ksz), rtol=1e-5)

    def test_global_avgpool2d(self):
        helper_test_op([(32, 2, 111, 28)],
                       lambda x: torch.nn.functional.avg_pool2d(x, kernel_size=(111, 28)),
                       lambda x: Tensor.avg_pool2d(x, kernel_size=(111, 28)), rtol=1e-5)

    def test_cat(self):
        for dim in range(-1, 2):
            helper_test_op([(45, 65), (45, 65)],
                           lambda x, y: torch.cat((x, y), dim),
                           lambda x, y: x.cat(y, dim=dim))

    def test_multicat(self):
        for dim in range(-1, 2):
            helper_test_op([(45, 65), (45, 65), (45, 65)],
                           lambda x, y, z: torch.cat((x, y, z), dim),
                           lambda x, y, z: x.cat(y, z, dim=dim))

    def test_clip(self):
        helper_test_op([(45, 65)], lambda x: x.clip(-2.3, 1.2), lambda x: x.clip(-2.3, 1.2))

    def test_matvec(self):
        helper_test_op([(1, 128), (128, 128), (128, 128)],
                       lambda x, y, z: (x@y).relu()@z, atol=1e-4)


if __name__ == '__main__':
    np.random.seed(1337)
    unittest.main(verbosity=2)