mirror of https://github.com/commaai/tinygrad.git
test speed v torch uses jit
This commit is contained in:
parent 693d4b89a4
commit de71c13934

extra/jit.py | 25
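This change makes the torch-vs-tinygrad speed tests run the tinygrad side through the TinyJit cache in extra/jit.py. Going by the tests touched in this diff, usage after the change looks roughly like the following; this is a minimal sketch based on test_simple_jit, not an exact excerpt, and it assumes the GPU backend (on other devices TinyJit just calls the wrapped function straight through).

from extra.jit import TinyJit
from tinygrad.tensor import Tensor

@TinyJit
def add(a, b): return (a+b).realize()  # a jitted function should return a realized Tensor

for _ in range(3):
  a, b = Tensor.randn(10, 10), Tensor.randn(10, 10)
  c = add(a, b)  # call 1 runs eagerly, call 2 records the launched kernels, call 3+ replays them
  print(c.numpy().sum())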
@@ -1,5 +1,6 @@
 from typing import Callable, List, Tuple
 import itertools
+from tinygrad.lazy import Device
 from tinygrad.tensor import Tensor
 from tinygrad.ops import DEBUG, GlobalCounters

@@ -12,20 +13,24 @@ class TinyJit:
     self.input_replace = {}

   def __call__(self, *args, **kwargs):
+    if Device.DEFAULT != "GPU": return self.fxn(*args, **kwargs)  # only jit on the GPU
     input_tensors = {k:v.realize().lazydata.realized._buf for k,v in itertools.chain(enumerate(args), kwargs.items()) if isinstance(v, Tensor)}
+    assert len(input_tensors) != 0, "no inputs to JIT"
     if self.cnt >= 2:
       for a,idx in self.input_replace.items(): a._buf = input_tensors[idx]
       for prg, args in self.jit_cache: prg(*args)
-    else:
-      if self.cnt == 1: GlobalCounters.cache = []
-      self.ret = self.fxn(*args, **kwargs).realize()
-      if self.cnt == 1:
-        self.jit_cache = GlobalCounters.cache
-        GlobalCounters.cache = None
+    elif self.cnt == 1:
+      GlobalCounters.cache = []
+      self.ret = self.fxn(*args, **kwargs)
+      self.jit_cache = GlobalCounters.cache
+      GlobalCounters.cache = None
+      assert len(self.jit_cache) != 0, "didn't JIT anything!"

       # get the inputs for replacement
       for prg, args in self.jit_cache:  # pylint: disable=E1133
         self.input_replace.update({a:[k for k,v in input_tensors.items() if v == a._buf][0] for a in args if a._buf in input_tensors.values()})
       assert set(self.input_replace.values()) == set(input_tensors.keys()), "some input tensors not found"
+    elif self.cnt == 0:
+      self.ret = self.fxn(*args, **kwargs)
     self.cnt += 1
     return self.ret
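As rewritten, TinyJit.__call__ dispatches on a call counter: the first call (cnt == 0) just runs the function, the second call (cnt == 1) runs it again while GlobalCounters.cache records every (program, args) pair that gets launched, and from the third call on (cnt >= 2) the Python function is skipped entirely, with the recorded kernels replayed against the new input buffers patched in via input_replace. One side effect worth noting (my reading of this hunk, not documented behavior): once replaying, the same self.ret Tensor is returned every time, with its output buffer rewritten in place.

# Hypothetical illustration of the call-count behavior described above (assumes GPU backend).
doubler = TinyJit(lambda x: (x*2).realize())
outs = [doubler(Tensor.randn(4, 4)) for _ in range(4)]
# outs[0]: eager warm-up; outs[1]: recording pass that fills jit_cache;
# outs[2], outs[3]: cache replays -- they are the very same Tensor object as outs[1],
# only the underlying buffer contents change between calls.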
@@ -63,7 +63,7 @@ def model_exec(run_onnx, using_graph, **inputs):
   ret = next(iter(run_onnx(inputs).values()))
   GlobalCounters.cache = []  # don't cache pre-realize
   if using_graph: graph.GRAPH = True
-  return ret
+  return ret.realize()

 def compile(dat, output_fn):
   Tensor.no_grad = True
@@ -8,7 +8,7 @@ from extra.jit import TinyJit
 class TestJit(unittest.TestCase):
   def test_simple_jit(self):
     @TinyJit
-    def add(a, b): return a+b
+    def add(a, b): return (a+b).realize()
     for _ in range(3):
       a = Tensor.randn(10, 10)
       b = Tensor.randn(10, 10)
@@ -17,7 +17,7 @@ class TestJit(unittest.TestCase):
   def test_kwargs_jit(self):
     @TinyJit
-    def add_kwargs(first, second): return first+second
+    def add_kwargs(first, second): return (first+second).realize()
     for _ in range(3):
       a = Tensor.randn(10, 10)
       b = Tensor.randn(10, 10)
@@ -26,12 +26,12 @@ class TestJit(unittest.TestCase):
   def test_array_jit(self):
     @TinyJit
-    def add_array(arr): return arr[0]+arr[1]
+    def add_array(a, arr): return (a+arr[0]).realize()
     for i in range(3):
       a = Tensor.randn(10, 10)
       b = Tensor.randn(10, 10)
       a.realize(), b.realize()
-      c = add_array([a,b])
+      c = add_array(a, [b])
       if i == 2:
         # should fail once jitted since jit can't handle arrays
         np.testing.assert_equal(np.any(np.not_equal(c.numpy(),a.numpy()+b.numpy())), True)
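The reworked test_array_jit pins down a limitation rather than a feature: input scanning in __call__ only walks positional and keyword arguments (itertools.chain(enumerate(args), kwargs.items())), so a Tensor tucked inside a Python list is never registered for buffer replacement. Once the cache is replaying, the kernels keep reading whatever buffer was captured during the recording call, which is why the test expects the third result to stop matching a + b. Roughly:

@TinyJit
def add_array(a, arr): return (a+arr[0]).realize()
# 'a' is a plain Tensor argument, so its buffer is swapped on every replay;
# 'arr[0]' is hidden inside a list, so replays silently reuse the buffer captured
# during the recording call -- hence the deliberate mismatch assertion in the test.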
@@ -11,6 +11,7 @@ from tinygrad.ops import GlobalCounters
 from tinygrad.tensor import Tensor
 from tinygrad.nn import Conv2d
 from tinygrad.helpers import colored, getenv
+from extra.jit import TinyJit
 try:
   from tinygrad.runtime.opencl import CL
 except ImportError:
@@ -39,6 +40,7 @@ def helper_test_speed(f1, *args):
     del ret
     GlobalCounters.global_ops = 0
     GlobalCounters.global_mem = 0
+    args = [(x+1).realize() if isinstance(x,Tensor) else (None if x is None else (x+1)) for x in args]  # cache defeats
     st = time.monotonic()
     ret = f1(*args)
     if CL is not None and ret.device in ["GPU"]:
@@ -52,22 +54,22 @@ def helper_test_speed(f1, *args):
   save_ops, save_mem = GlobalCounters.global_ops, GlobalCounters.global_mem
   return ret.cpu().numpy(), np.min(ets)

-def helper_test_generic_square(name, N, f1, f2):
+def helper_test_generic_square(name, N, f1, f2, onearg=False):
   torch.manual_seed(0)
   torch_a = (torch.rand(N, N) - 0.5).to(torch_device)
-  torch_b = (torch.rand(N, N) - 0.5).to(torch_device)
+  torch_b = (torch.rand(N, N) - 0.5).to(torch_device) if not onearg else None

   tiny_a = Tensor(torch_a.cpu().numpy())
-  tiny_b = Tensor(torch_b.cpu().numpy())
+  tiny_b = Tensor(torch_b.cpu().numpy()) if not onearg else None

-  helper_test_generic(f"{name:30s} {N:4d}x{N:4d}", partial(f1, torch_a, torch_b), partial(f2, tiny_a, tiny_b))
+  helper_test_generic(f"{name:30s} {N:4d}x{N:4d}", f1, (torch_a, torch_b), TinyJit(lambda a,b:f2(a,b).realize()), (tiny_a, tiny_b))

 prefix = None
-def helper_test_generic(name, f1, f2):
+def helper_test_generic(name, f1, f1_args, f2, f2_args):
   global prefix
   with torch.no_grad():
-    val_torch, et_torch = helper_test_speed(f1)
-    val_tinygrad, et_tinygrad = helper_test_speed(lambda: f2().realize())
+    val_torch, et_torch = helper_test_speed(f1, *f1_args)
+    val_tinygrad, et_tinygrad = helper_test_speed(f2, *f2_args)

   desc = "faster" if et_torch > et_tinygrad else "slower"
   flops = save_ops*1e-6
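The benchmark helpers now take each callable and its argument tuple separately, so helper_test_speed can feed fresh, cache-defeating inputs to a single TinyJit instance across timing iterations instead of timing a closure with the tensors baked in. The onearg flag goes with this: the JIT asserts that every Tensor argument shows up in some recorded kernel, so single-input tests pass None for the unused b rather than an untouched tensor. In summary (a paraphrase of this hunk, not additional code in the tree):

# old call shape: tensors baked into the callables
#   helper_test_generic(name, partial(f1, torch_a, torch_b), partial(f2, tiny_a, tiny_b))
# new call shape: callables and argument tuples kept separate, tinygrad side jitted
#   helper_test_generic(name, f1, (torch_a, torch_b),
#                       TinyJit(lambda a, b: f2(a, b).realize()), (tiny_a, tiny_b))
# with onearg=True, torch_b / tiny_b are None so the "some input tensors not found"
# assert in TinyJit is not tripped by an input the kernels never read.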
@ -92,24 +94,24 @@ class TestSpeed(unittest.TestCase):
|
||||||
|
|
||||||
def test_sum(self):
|
def test_sum(self):
|
||||||
def f(a, b): return a.sum()
|
def f(a, b): return a.sum()
|
||||||
helper_test_generic_square('sum', 4096, f, f)
|
helper_test_generic_square('sum', 4096, f, f, onearg=True)
|
||||||
|
|
||||||
def test_partial_sum(self):
|
def test_partial_sum(self):
|
||||||
R = 256
|
R = 256
|
||||||
def f(a, b): return a.reshape(int(4096//R), int(4096*R)).sum(axis=1)
|
def f(a, b): return a.reshape(int(4096//R), int(4096*R)).sum(axis=1)
|
||||||
helper_test_generic_square('partial_sum', 4096, f, f)
|
helper_test_generic_square('partial_sum', 4096, f, f, onearg=True)
|
||||||
|
|
||||||
def test_array_packing(self):
|
def test_array_packing(self):
|
||||||
N = 2048
|
N = 2048
|
||||||
def f(a, b): return a.reshape(N, N // 32, 32).permute(1,0,2).contiguous()
|
def f(a, b): return a.reshape(N, N // 32, 32).permute(1,0,2).contiguous()
|
||||||
helper_test_generic_square('array_packing', N, f, f)
|
helper_test_generic_square('array_packing', N, f, f, onearg=True)
|
||||||
|
|
||||||
def test_permute(self):
|
def test_permute(self):
|
||||||
for N in [1024, 4096]:
|
for N in [1024, 4096]:
|
||||||
# this is a 64MB tensor, M1 L1 cache is 128kB
|
# this is a 64MB tensor, M1 L1 cache is 128kB
|
||||||
# to fit easily in L1, rotations should be 128x128 chunks. 128x128 is also the AMX size
|
# to fit easily in L1, rotations should be 128x128 chunks. 128x128 is also the AMX size
|
||||||
def f(a, b): return a.permute(1,0).contiguous()
|
def f(a, b): return a.permute(1,0).contiguous()
|
||||||
helper_test_generic_square('permute', N, f, f)
|
helper_test_generic_square('permute', N, f, f, onearg=True)
|
||||||
|
|
||||||
def test_double_permute(self):
|
def test_double_permute(self):
|
||||||
N = 64
|
N = 64
|
||||||
|
@ -117,23 +119,23 @@ class TestSpeed(unittest.TestCase):
|
||||||
torch_a = (torch.rand(N, N, N, N) - 0.5).to(torch_device)
|
torch_a = (torch.rand(N, N, N, N) - 0.5).to(torch_device)
|
||||||
tiny_a = Tensor(torch_a.cpu().numpy())
|
tiny_a = Tensor(torch_a.cpu().numpy())
|
||||||
def f(a): return a.permute(1,0,3,2).contiguous()
|
def f(a): return a.permute(1,0,3,2).contiguous()
|
||||||
helper_test_generic(f"double_permute {tiny_a.shape}", partial(f, torch_a), partial(f, tiny_a))
|
helper_test_generic(f"double_permute {tiny_a.shape}", f, (torch_a,), TinyJit(lambda a: f(a).realize()), (tiny_a,))
|
||||||
|
|
||||||
def test_neg(self):
|
def test_neg(self):
|
||||||
def f(a, b): return -a
|
def f(a, b): return -a
|
||||||
helper_test_generic_square('neg', 4096, f, f)
|
helper_test_generic_square('neg', 4096, f, f, onearg=True)
|
||||||
|
|
||||||
def test_exp(self):
|
def test_exp(self):
|
||||||
def f(a, b): return a.exp()
|
def f(a, b): return a.exp()
|
||||||
helper_test_generic_square('exp', 2048, f, f)
|
helper_test_generic_square('exp', 2048, f, f, onearg=True)
|
||||||
|
|
||||||
def test_relu(self):
|
def test_relu(self):
|
||||||
def f(a, b): return a.relu()
|
def f(a, b): return a.relu()
|
||||||
helper_test_generic_square('relu', 4096, f, f)
|
helper_test_generic_square('relu', 4096, f, f, onearg=True)
|
||||||
|
|
||||||
def test_max(self):
|
def test_max(self):
|
||||||
def f(a, b): return a.max()
|
def f(a, b): return a.max()
|
||||||
helper_test_generic_square('max', 4096, f, f)
|
helper_test_generic_square('max', 4096, f, f, onearg=True)
|
||||||
|
|
||||||
def test_mul_sum(self):
|
def test_mul_sum(self):
|
||||||
def f(a, b): return (a*b).sum()
|
def f(a, b): return (a*b).sum()
|
||||||
|
@ -146,11 +148,11 @@ class TestSpeed(unittest.TestCase):
|
||||||
|
|
||||||
def test_add_constant(self):
|
def test_add_constant(self):
|
||||||
def f(a, b): return a+2.0
|
def f(a, b): return a+2.0
|
||||||
helper_test_generic_square('add_constant', 4096, f, f)
|
helper_test_generic_square('add_constant', 4096, f, f, onearg=True)
|
||||||
|
|
||||||
def test_add_constant_zero(self):
|
def test_add_constant_zero(self):
|
||||||
def f(a, b): return a+0.0
|
def f(a, b): return a+0.0
|
||||||
helper_test_generic_square('add_constant_zero', 4096, f, f)
|
helper_test_generic_square('add_constant_zero', 4096, f, f, onearg=True)
|
||||||
|
|
||||||
def test_add_sq(self):
|
def test_add_sq(self):
|
||||||
def f(a, b): return a*a + b*b
|
def f(a, b): return a*a + b*b
|
||||||
|
@@ -194,9 +196,9 @@ class TestSpeed(unittest.TestCase):
     tiny_conv = Conv2d(in_chans, out_chans, 3, bias=None, padding=1)
     tiny_conv.weight = Tensor(torch_conv.weight.detach().cpu().numpy())

-    def f1(): return torch_conv(torch_dat.permute(0,3,1,2))
-    def f2(): return tiny_conv(tiny_dat.permute(0,3,1,2)).realize()
-    helper_test_generic(f"conv bs:{bs:3d} chans:{in_chans:3d} -> {out_chans:3d}", f1, f2)
+    def f1(torch_dat): return torch_conv(torch_dat.permute(0,3,1,2))
+    def f2(tiny_dat): return tiny_conv(tiny_dat.permute(0,3,1,2)).realize()
+    helper_test_generic(f"conv bs:{bs:3d} chans:{in_chans:3d} -> {out_chans:3d}", f1, (torch_dat,), TinyJit(f2), (tiny_dat,))

   def test_conv2d(self):
     torch.manual_seed(0)
@@ -211,9 +213,9 @@ class TestSpeed(unittest.TestCase):
     tiny_conv = Conv2d(in_chans, out_chans, 3, bias=None)
     tiny_conv.weight = Tensor(torch_conv.weight.detach().cpu().numpy())

-    def f1(): return torch_conv(torch_dat)
-    def f2(): return tiny_conv(tiny_dat).realize()
-    helper_test_generic(f"conv bs:{bs:3d} chans:{in_chans:3d} -> {out_chans:3d}", f1, f2)
+    def f1(torch_dat): return torch_conv(torch_dat)
+    def f2(tiny_dat): return tiny_conv(tiny_dat).realize()
+    helper_test_generic(f"conv bs:{bs:3d} chans:{in_chans:3d} -> {out_chans:3d}", f1, (torch_dat,), TinyJit(f2), (tiny_dat,))

 if __name__ == '__main__':
   unittest.main()