mirror of https://github.com/commaai/tinygrad.git
parent
643cbdfd50
commit
d25046e66a
|
@ -86,6 +86,17 @@ def helper_test_generic_square(name, N, f1, f2, onearg=False):
|
|||
|
||||
helper_test_generic(f"{name:30s} {N:5d}x{N:5d}", f1, (torch_a, torch_b), TinyJit(lambda a,b:f2(a,b).realize()), (tiny_a, tiny_b))
|
||||
|
||||
def helper_test_matvec(name, N, M):
|
||||
torch.manual_seed(0)
|
||||
dt = torch.float32
|
||||
torch_a = (torch.rand(N, dtype=dt) - 0.5).to(torch_device)
|
||||
torch_b = (torch.rand(N, M, dtype=dt) - 0.5).to(torch_device)
|
||||
|
||||
tiny_a = Tensor(torch_a.cpu().numpy())
|
||||
tiny_b = Tensor(torch_b.cpu().numpy())
|
||||
|
||||
helper_test_generic(f"{name:30s} {N:5d}x{M:5d}", lambda a,b: a@b, (torch_a, torch_b), TinyJit(lambda a,b:(a@b).realize()), (tiny_a, tiny_b))
|
||||
|
||||
prefix = None
|
||||
def helper_test_generic(name, f1, f1_args, f2, f2_args):
|
||||
global prefix
|
||||
|
@ -129,6 +140,8 @@ class TestBigSpeed(unittest.TestCase):
|
|||
def test_large_conv_1x1(self): helper_test_conv(bs=32, in_chans=128, out_chans=128, kernel_size=1, img_size_y=128, img_size_x=128)
|
||||
def test_large_conv_3x3(self): helper_test_conv(bs=4, in_chans=128, out_chans=128, kernel_size=3, img_size_y=130, img_size_x=130)
|
||||
def test_large_conv_5x5(self): helper_test_conv(bs=4, in_chans=128, out_chans=128, kernel_size=5, img_size_y=130, img_size_x=130)
|
||||
def test_matvec_4096_16384(self): helper_test_matvec('matvec_4096_16384', 4096, 16384)
|
||||
def test_matvec_16384_4096(self): helper_test_matvec('matvec_16384_4096', 16384, 4096)
|
||||
|
||||
@unittest.skipIf(getenv("BIG") == 1, "only big tests")
|
||||
class TestSpeed(unittest.TestCase):
|
||||
|
@ -247,6 +260,11 @@ class TestSpeed(unittest.TestCase):
|
|||
def f2(a, b): return (a.permute(1,0).reshape(N, 1, N).expand(N, N, N) * b.permute(1,0).reshape(1, N, N).expand(N, N, N)).sum(axis=2)
|
||||
helper_test_generic_square('gemm_unrolled_permute_lr', N, f1, f2)
|
||||
|
||||
def test_matvec_1024_1024(self): helper_test_matvec('matvec_1024_1024', 1024, 1024)
|
||||
def test_matvec_1024_4096(self): helper_test_matvec('matvec_1024_4096', 1024, 4096)
|
||||
def test_matvec_4096_1024(self): helper_test_matvec('matvec_4096_1024', 4096, 1024)
|
||||
def test_matvec_4096_4096(self): helper_test_matvec('matvec_4096_4096', 4096, 4096)
|
||||
|
||||
def test_openpilot_conv2d(self):
|
||||
bs, in_chans, out_chans = 1,12,32
|
||||
torch.manual_seed(0)
|
||||
|
|
Loading…
Reference in New Issue