test speed v torch uses jit

George Hotz 2023-02-12 07:43:17 -08:00
parent 693d4b89a4
commit de71c13934
4 changed files with 46 additions and 39 deletions

@@ -1,5 +1,6 @@
 from typing import Callable, List, Tuple
 import itertools
+from tinygrad.lazy import Device
 from tinygrad.tensor import Tensor
 from tinygrad.ops import DEBUG, GlobalCounters
@@ -12,20 +13,24 @@ class TinyJit:
     self.input_replace = {}

   def __call__(self, *args, **kwargs):
+    if Device.DEFAULT != "GPU": return self.fxn(*args, **kwargs) # only jit on the GPU
     input_tensors = {k:v.realize().lazydata.realized._buf for k,v in itertools.chain(enumerate(args), kwargs.items()) if isinstance(v, Tensor)}
+    assert len(input_tensors) != 0, "no inputs to JIT"
     if self.cnt >= 2:
       for a,idx in self.input_replace.items(): a._buf = input_tensors[idx]
       for prg, args in self.jit_cache: prg(*args)
-    else:
-      if self.cnt == 1: GlobalCounters.cache = []
-      self.ret = self.fxn(*args, **kwargs).realize()
-      if self.cnt == 1:
-        self.jit_cache = GlobalCounters.cache
-        GlobalCounters.cache = None
-        # get the inputs for replacement
-        for prg, args in self.jit_cache: # pylint: disable=E1133
-          self.input_replace.update({a:[k for k,v in input_tensors.items() if v == a._buf][0] for a in args if a._buf in input_tensors.values()})
-        assert set(self.input_replace.values()) == set(input_tensors.keys()), "some input tensors not found"
+    elif self.cnt == 1:
+      GlobalCounters.cache = []
+      self.ret = self.fxn(*args, **kwargs)
+      self.jit_cache = GlobalCounters.cache
+      GlobalCounters.cache = None
+      assert len(self.jit_cache) != 0, "didn't JIT anything!"
+      # get the inputs for replacement
+      for prg, args in self.jit_cache: # pylint: disable=E1133
+        self.input_replace.update({a:[k for k,v in input_tensors.items() if v == a._buf][0] for a in args if a._buf in input_tensors.values()})
+      assert set(self.input_replace.values()) == set(input_tensors.keys()), "some input tensors not found"
+    elif self.cnt == 0:
+      self.ret = self.fxn(*args, **kwargs)
     self.cnt += 1
     return self.ret
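
For context, a minimal sketch of how the reworked TinyJit above is meant to be driven; it mirrors test_simple_jit later in this diff and is not itself part of the commit. Call 0 runs the function eagerly, call 1 re-runs it while recording the launched kernels into GlobalCounters.cache, and calls 2 and later replay the recorded kernels with the new input buffers patched in via input_replace. The decorated function must return a realized Tensor, the inputs must already be realized, and the JIT only engages when Device.DEFAULT is "GPU"; the 10x10 shapes are illustrative.

from tinygrad.tensor import Tensor
from extra.jit import TinyJit

@TinyJit
def add(a, b):
  # the jitted function has to hand back a realized Tensor so its kernels get recorded
  return (a+b).realize()

for i in range(3):
  a, b = Tensor.randn(10, 10), Tensor.randn(10, 10)
  a.realize(), b.realize()  # realized inputs expose the buffers the JIT tracks
  c = add(a, b)             # i=0: eager, i=1: record kernels, i>=2: replay from the cache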

@@ -63,7 +63,7 @@ def model_exec(run_onnx, using_graph, **inputs):
   ret = next(iter(run_onnx(inputs).values()))
   GlobalCounters.cache = [] # don't cache pre-realize
   if using_graph: graph.GRAPH = True
-  return ret
+  return ret.realize()

 def compile(dat, output_fn):
   Tensor.no_grad = True

@@ -8,7 +8,7 @@ from extra.jit import TinyJit
 class TestJit(unittest.TestCase):
   def test_simple_jit(self):
     @TinyJit
-    def add(a, b): return a+b
+    def add(a, b): return (a+b).realize()
     for _ in range(3):
       a = Tensor.randn(10, 10)
       b = Tensor.randn(10, 10)
@@ -17,7 +17,7 @@ class TestJit(unittest.TestCase):
   def test_kwargs_jit(self):
     @TinyJit
-    def add_kwargs(first, second): return first+second
+    def add_kwargs(first, second): return (first+second).realize()
     for _ in range(3):
       a = Tensor.randn(10, 10)
       b = Tensor.randn(10, 10)
@@ -26,12 +26,12 @@ class TestJit(unittest.TestCase):
   def test_array_jit(self):
     @TinyJit
-    def add_array(arr): return arr[0]+arr[1]
+    def add_array(a, arr): return (a+arr[0]).realize()
     for i in range(3):
       a = Tensor.randn(10, 10)
       b = Tensor.randn(10, 10)
       a.realize(), b.realize()
-      c = add_array([a,b])
+      c = add_array(a, [b])
       if i == 2:
         # should fail once jitted since jit can't handle arrays
         np.testing.assert_equal(np.any(np.not_equal(c.numpy(),a.numpy()+b.numpy())), True)

@@ -11,6 +11,7 @@ from tinygrad.ops import GlobalCounters
 from tinygrad.tensor import Tensor
 from tinygrad.nn import Conv2d
 from tinygrad.helpers import colored, getenv
+from extra.jit import TinyJit
 try:
   from tinygrad.runtime.opencl import CL
 except ImportError:
@@ -39,6 +40,7 @@ def helper_test_speed(f1, *args):
     del ret
     GlobalCounters.global_ops = 0
     GlobalCounters.global_mem = 0
+    args = [(x+1).realize() if isinstance(x,Tensor) else (None if x is None else (x+1)) for x in args] # cache defeats
     st = time.monotonic()
     ret = f1(*args)
     if CL is not None and ret.device in ["GPU"]:
@@ -52,22 +54,22 @@ def helper_test_speed(f1, *args):
   save_ops, save_mem = GlobalCounters.global_ops, GlobalCounters.global_mem
   return ret.cpu().numpy(), np.min(ets)

-def helper_test_generic_square(name, N, f1, f2):
+def helper_test_generic_square(name, N, f1, f2, onearg=False):
   torch.manual_seed(0)
   torch_a = (torch.rand(N, N) - 0.5).to(torch_device)
-  torch_b = (torch.rand(N, N) - 0.5).to(torch_device)
+  torch_b = (torch.rand(N, N) - 0.5).to(torch_device) if not onearg else None
   tiny_a = Tensor(torch_a.cpu().numpy())
-  tiny_b = Tensor(torch_b.cpu().numpy())
-  helper_test_generic(f"{name:30s} {N:4d}x{N:4d}", partial(f1, torch_a, torch_b), partial(f2, tiny_a, tiny_b))
+  tiny_b = Tensor(torch_b.cpu().numpy()) if not onearg else None
+  helper_test_generic(f"{name:30s} {N:4d}x{N:4d}", f1, (torch_a, torch_b), TinyJit(lambda a,b:f2(a,b).realize()), (tiny_a, tiny_b))

 prefix = None
-def helper_test_generic(name, f1, f2):
+def helper_test_generic(name, f1, f1_args, f2, f2_args):
   global prefix
   with torch.no_grad():
-    val_torch, et_torch = helper_test_speed(f1)
-    val_tinygrad, et_tinygrad = helper_test_speed(lambda: f2().realize())
+    val_torch, et_torch = helper_test_speed(f1, *f1_args)
+    val_tinygrad, et_tinygrad = helper_test_speed(f2, *f2_args)
   desc = "faster" if et_torch > et_tinygrad else "slower"
   flops = save_ops*1e-6
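
The refactor above changes helper_test_generic from taking two zero-argument closures to taking each framework's callable plus an explicit argument tuple (f1, f1_args, f2, f2_args): helper_test_speed can then perturb the inputs every iteration (the new "cache defeats" line) and the tinygrad callable can be wrapped once in TinyJit so repeat calls replay cached kernels. Below is a minimal, self-contained sketch of that timing pattern outside the unittest harness; the bench helper, the elementwise-add workload, N=1024, and the warmup/iteration counts are illustrative rather than taken from the repo, and it omits both the per-iteration input perturbation and the GPU queue synchronization that helper_test_speed performs.

import time
import torch
from tinygrad.tensor import Tensor
from extra.jit import TinyJit

def bench(f, args, warmup=2, iters=5):
  # two warmup calls let TinyJit record its kernel cache before timing starts
  for _ in range(warmup): f(*args)
  ets = []
  for _ in range(iters):
    st = time.monotonic()
    f(*args)
    ets.append((time.monotonic() - st) * 1000)
  return min(ets)

N = 1024
torch_a, torch_b = torch.rand(N, N) - 0.5, torch.rand(N, N) - 0.5
tiny_a, tiny_b = Tensor(torch_a.numpy()).realize(), Tensor(torch_b.numpy()).realize()

et_torch = bench(lambda a, b: a + b, (torch_a, torch_b))
et_tiny = bench(TinyJit(lambda a, b: (a + b).realize()), (tiny_a, tiny_b))
print(f"add {N}x{N}: torch {et_torch:.2f} ms, tinygrad (jitted) {et_tiny:.2f} ms")
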
@@ -92,24 +94,24 @@ class TestSpeed(unittest.TestCase):
   def test_sum(self):
     def f(a, b): return a.sum()
-    helper_test_generic_square('sum', 4096, f, f)
+    helper_test_generic_square('sum', 4096, f, f, onearg=True)

   def test_partial_sum(self):
     R = 256
     def f(a, b): return a.reshape(int(4096//R), int(4096*R)).sum(axis=1)
-    helper_test_generic_square('partial_sum', 4096, f, f)
+    helper_test_generic_square('partial_sum', 4096, f, f, onearg=True)

   def test_array_packing(self):
     N = 2048
     def f(a, b): return a.reshape(N, N // 32, 32).permute(1,0,2).contiguous()
-    helper_test_generic_square('array_packing', N, f, f)
+    helper_test_generic_square('array_packing', N, f, f, onearg=True)

   def test_permute(self):
     for N in [1024, 4096]:
       # this is a 64MB tensor, M1 L1 cache is 128kB
       # to fit easily in L1, rotations should be 128x128 chunks. 128x128 is also the AMX size
       def f(a, b): return a.permute(1,0).contiguous()
-      helper_test_generic_square('permute', N, f, f)
+      helper_test_generic_square('permute', N, f, f, onearg=True)

   def test_double_permute(self):
     N = 64
@@ -117,23 +119,23 @@ class TestSpeed(unittest.TestCase):
     torch_a = (torch.rand(N, N, N, N) - 0.5).to(torch_device)
     tiny_a = Tensor(torch_a.cpu().numpy())
     def f(a): return a.permute(1,0,3,2).contiguous()
-    helper_test_generic(f"double_permute {tiny_a.shape}", partial(f, torch_a), partial(f, tiny_a))
+    helper_test_generic(f"double_permute {tiny_a.shape}", f, (torch_a,), TinyJit(lambda a: f(a).realize()), (tiny_a,))

   def test_neg(self):
     def f(a, b): return -a
-    helper_test_generic_square('neg', 4096, f, f)
+    helper_test_generic_square('neg', 4096, f, f, onearg=True)

   def test_exp(self):
     def f(a, b): return a.exp()
-    helper_test_generic_square('exp', 2048, f, f)
+    helper_test_generic_square('exp', 2048, f, f, onearg=True)

   def test_relu(self):
     def f(a, b): return a.relu()
-    helper_test_generic_square('relu', 4096, f, f)
+    helper_test_generic_square('relu', 4096, f, f, onearg=True)

   def test_max(self):
     def f(a, b): return a.max()
-    helper_test_generic_square('max', 4096, f, f)
+    helper_test_generic_square('max', 4096, f, f, onearg=True)

   def test_mul_sum(self):
     def f(a, b): return (a*b).sum()
@@ -146,11 +148,11 @@ class TestSpeed(unittest.TestCase):
   def test_add_constant(self):
     def f(a, b): return a+2.0
-    helper_test_generic_square('add_constant', 4096, f, f)
+    helper_test_generic_square('add_constant', 4096, f, f, onearg=True)

   def test_add_constant_zero(self):
     def f(a, b): return a+0.0
-    helper_test_generic_square('add_constant_zero', 4096, f, f)
+    helper_test_generic_square('add_constant_zero', 4096, f, f, onearg=True)

   def test_add_sq(self):
     def f(a, b): return a*a + b*b
@@ -194,9 +196,9 @@ class TestSpeed(unittest.TestCase):
     tiny_conv = Conv2d(in_chans, out_chans, 3, bias=None, padding=1)
     tiny_conv.weight = Tensor(torch_conv.weight.detach().cpu().numpy())
-    def f1(): return torch_conv(torch_dat.permute(0,3,1,2))
-    def f2(): return tiny_conv(tiny_dat.permute(0,3,1,2)).realize()
-    helper_test_generic(f"conv bs:{bs:3d} chans:{in_chans:3d} -> {out_chans:3d}", f1, f2)
+    def f1(torch_dat): return torch_conv(torch_dat.permute(0,3,1,2))
+    def f2(tiny_dat): return tiny_conv(tiny_dat.permute(0,3,1,2)).realize()
+    helper_test_generic(f"conv bs:{bs:3d} chans:{in_chans:3d} -> {out_chans:3d}", f1, (torch_dat,), TinyJit(f2), (tiny_dat,))

   def test_conv2d(self):
     torch.manual_seed(0)
@@ -211,9 +213,9 @@ class TestSpeed(unittest.TestCase):
     tiny_conv = Conv2d(in_chans, out_chans, 3, bias=None)
     tiny_conv.weight = Tensor(torch_conv.weight.detach().cpu().numpy())
-    def f1(): return torch_conv(torch_dat)
-    def f2(): return tiny_conv(tiny_dat).realize()
-    helper_test_generic(f"conv bs:{bs:3d} chans:{in_chans:3d} -> {out_chans:3d}", f1, f2)
+    def f1(torch_dat): return torch_conv(torch_dat)
+    def f2(tiny_dat): return tiny_conv(tiny_dat).realize()
+    helper_test_generic(f"conv bs:{bs:3d} chans:{in_chans:3d} -> {out_chans:3d}", f1, (torch_dat,), TinyJit(f2), (tiny_dat,))

 if __name__ == '__main__':
   unittest.main()