mirror of https://github.com/commaai/tinygrad.git
rewrite the jit in the context of new schedule (#4162)
* rewrite the jit in the context of new schedule
* mypy better
* fix placeholder
* tests
* all functionality should work
* fix tests
* no CacheCollector
parent b67f759780
commit ebc94c9d6c
@@ -1,66 +0,0 @@
-#!/usr/bin/env python3
-import gc
-import time
-from tqdm import trange
-from extra.models.efficientnet import EfficientNet
-from tinygrad.nn.state import get_parameters
-from tinygrad.nn import optim
-from tinygrad import Tensor, GlobalCounters
-from tinygrad.helpers import getenv
-from tinygrad.engine.jit import CacheCollector
-
-def tensors_allocated():
-  return sum(isinstance(x, Tensor) for x in gc.get_objects())
-
-NUM = getenv("NUM", 2)
-BS = getenv("BS", 8)
-CNT = getenv("CNT", 10)
-BACKWARD = getenv("BACKWARD", 0)
-TRAINING = getenv("TRAINING", 1)
-ADAM = getenv("ADAM", 0)
-CLCACHE = getenv("CLCACHE", 0)
-
-if __name__ == "__main__":
-  print(f"NUM:{NUM} BS:{BS} CNT:{CNT}")
-  model = EfficientNet(NUM, classes=1000, has_se=False, track_running_stats=False)
-  parameters = get_parameters(model)
-  for p in parameters: p.realize()
-  if ADAM: optimizer = optim.Adam(parameters, lr=0.001)
-  else: optimizer = optim.SGD(parameters, lr=0.001)
-
-  Tensor.training = TRAINING
-  Tensor.no_grad = not BACKWARD
-  for i in trange(CNT):
-    GlobalCounters.reset()
-    cpy = time.monotonic()
-    x_train = Tensor.randn(BS, 3, 224, 224, requires_grad=False).realize()
-    y_train = Tensor.randn(BS, 1000, requires_grad=False).realize()
-
-    # TODO: replace with TinyJit
-    if i < 3 or not CLCACHE:
-      st = time.monotonic()
-      out = model.forward(x_train)
-      loss = out.log_softmax().mul(y_train).mean()
-      if i == 2 and CLCACHE: CacheCollector.start()
-      if BACKWARD:
-        optimizer.zero_grad()
-        loss.backward()
-        optimizer.step()
-      mt = time.monotonic()
-      loss.realize()
-      for p in parameters:
-        p.realize()
-      et = time.monotonic()
-    else:
-      st = mt = time.monotonic()
-      for prg, args in cl_cache: prg(*args)
-      et = time.monotonic()
-
-    if i == 2 and CLCACHE:
-      cl_cache = CacheCollector.finish()
-
-    mem_used = GlobalCounters.mem_used
-    loss_cpu = loss.detach().numpy()
-    cl = time.monotonic()
-
-    print(f"{(st-cpy)*1000.0:7.2f} ms cpy, {(cl-st)*1000.0:7.2f} ms run, {(mt-st)*1000.0:7.2f} ms build, {(et-mt)*1000.0:7.2f} ms realize, {(cl-et)*1000.0:7.2f} ms CL, {loss_cpu:7.2f} loss, {tensors_allocated():4d} tensors, {mem_used/1e9:.2f} GB used, {GlobalCounters.global_ops*1e-9/(cl-st):9.2f} GFLOPS")
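The removed script drove its cached path by hand through CacheCollector and the cl_cache list; the TODO in the deleted code points at TinyJit as the intended replacement. A rough sketch of the same training step under TinyJit follows (the train_step wrapper and its hyperparameters are illustrative, not part of this commit):

from tinygrad import Tensor, TinyJit, GlobalCounters
from tinygrad.nn import optim
from tinygrad.nn.state import get_parameters
from extra.models.efficientnet import EfficientNet

model = EfficientNet(0, classes=1000, has_se=False, track_running_stats=False)
opt = optim.SGD(get_parameters(model), lr=0.001)
Tensor.training = True

@TinyJit
def train_step(x:Tensor, y:Tensor) -> Tensor:
  opt.zero_grad()
  loss = model.forward(x).log_softmax().mul(y).mean()
  loss.backward()
  opt.step()
  return loss.realize()

for _ in range(10):
  GlobalCounters.reset()
  # call 0 runs eagerly, call 1 captures the kernels, later calls replay them
  loss = train_step(Tensor.randn(8, 3, 224, 224), Tensor.randn(8, 1000))
  print(loss.item())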
@@ -10,24 +10,27 @@ from tinygrad.helpers import getenv
 from tinygrad.nn import optim
 #from tinygrad.lazy import PUSH_PERMUTES
 PUSH_PERMUTES = False
-from tinygrad.engine.jit import CacheCollector
+from tinygrad.engine.realize import capturing
 
 class CLCache:
   def __init__(self, allowed=None, strict=False, preclear=True, var_vals=None):
     self.allowed, self.strict, self.preclear, self.var_vals = allowed, strict, preclear, var_vals if var_vals is not None else {}
+    self.count = 0
+  def add(self, ei): self.count += 1
   def __enter__(self):
     if self.preclear:
       gc.collect()
       for x in [x for x in gc.get_objects() if isinstance(x, Tensor)]:
         x.realize()
       GlobalCounters.reset()
-    CacheCollector.start(self.var_vals)
+    capturing.append(self)
     print("cache: entering")
     return self
   def __exit__(self, type, value, traceback):
-    cache = CacheCollector.finish()
-    print(f"cache: exiting with size {len(cache)}", f"allowed {self.allowed}" if self.allowed is not None else "")
+    capturing.clear()
+    print(f"cache: exiting with size {self.count}", f"allowed {self.allowed}" if self.allowed is not None else "")
     if self.allowed is not None:
-      assert len(cache) <= self.allowed and (not self.strict or len(cache) == self.allowed), f"used too many kernels! {len(cache)} > {self.allowed}"
+      assert self.count <= self.allowed and (not self.strict or self.count == self.allowed), f"used too many kernels! {self.count} > {self.allowed}"
 
 from extra.models.convnext import ConvNeXt
 from extra.models.efficientnet import EfficientNet
@@ -77,9 +80,9 @@ class TestInferenceMinKernels(unittest.TestCase):
     model = ViT(embed_dim=192, num_heads=3)
     for p in get_parameters(model): p.assign(np.zeros(p.shape, dtype=p.dtype.np))
     img = Tensor.randn(1, 3, 224, 224)
-    with CLCache(222): # NOTE: this is way too high
+    with CLCache(222) as cache: # NOTE: this is way too high
       out = model.forward(img)
-      assert len(CacheCollector.cache) == 0, "ViT prerealized?"
+      assert cache.count == 0, "ViT prerealized?"
       out.realize()
 
   @unittest.skip("llama is fp16 but CI does not have fp16")
@@ -97,12 +100,12 @@ class TestOptBinOp(unittest.TestCase):
   def _test_no_binop_rerun(self, f1, f2=None, allowed=1):
     a = Tensor.randn(16, 16)
     b = Tensor.randn(16, 16)
-    with CLCache():
+    with CLCache() as cache:
       c = f1(a, b)
       if f2 is not None: d = f2(a, b)
       c.realize()
       if f2 is not None: d.realize()
-      assert len(CacheCollector.cache) == allowed, "binop was rerun!"
+      assert cache.count == allowed, "binop was rerun!"
     if f2 is not None: np.testing.assert_allclose(c.numpy().ravel(), d.numpy().ravel(), rtol=1e-3, atol=1e-5)
 
   def test_no_binop_rerun(self): return self._test_no_binop_rerun(lambda a,b: a*b, lambda a,b: (a*b).reshape(16, 16, 1))
@@ -125,22 +128,22 @@ class TestOptReduceLoop(unittest.TestCase):
   def test_loop_left(self):
     a = Tensor.randn(16, 16)
     b = Tensor.randn(16, 16)
-    with CLCache():
+    with CLCache() as cache:
       t = a.sum(0)
       b = t.reshape(16,1).expand(16,16).sum(0)
       c = (t+b)
       c.realize()
-      assert len(CacheCollector.cache) == 2, "loop left fusion broken"
+      assert cache.count == 2, "loop left fusion broken"
 
   def test_loop_right(self):
     a = Tensor.randn(16, 16)
     b = Tensor.randn(16, 16)
-    with CLCache():
+    with CLCache() as cache:
       t = a.sum(0)
       b = t.reshape(16,1).expand(16,16).sum(0)
       c = (b+t)
       c.realize()
-      assert len(CacheCollector.cache) == 2, "loop right fusion broken"
+      assert cache.count == 2, "loop right fusion broken"
 
 @unittest.skipUnless(Device.DEFAULT == "GPU", "Not Implemented")
 class TestOptWChild(unittest.TestCase):
@@ -148,12 +151,12 @@ class TestOptWChild(unittest.TestCase):
   def test_unrealized_child(self):
     a = Tensor.randn(16, 16)
     b = Tensor.randn(16, 16)
-    with CLCache():
+    with CLCache() as cache:
       c = (a*b).sum()
       d = c+1
       e = c+2 # noqa: F841
       d.realize()
-      assert len(CacheCollector.cache) == 2, "don't fuse if you have children"
+      assert cache.count == 2, "don't fuse if you have children"
 
 @unittest.skipUnless(Device.DEFAULT == "GPU", "Not Implemented")
 class TestOpt(unittest.TestCase):
@@ -168,34 +171,34 @@ class TestOpt(unittest.TestCase):
   def test_fold_reduce_elementwise(self):
     img = Tensor.ones(32).contiguous()
     addme = Tensor.ones(1)
-    with CLCache():
+    with CLCache() as cache:
       ret = img.sum() + addme
       ret.realize()
-      assert len(CacheCollector.cache) == 1, "optimizer didn't fold reduce/elementwise"
+      assert cache.count == 1, "optimizer didn't fold reduce/elementwise"
     assert ret.item() == 33
 
   def test_fold_batchnorm(self):
     with Tensor.train():
       img = Tensor.ones(1,32,4,4).contiguous()
       bn = nn.BatchNorm2d(32, track_running_stats=False)
-      with CLCache():
+      with CLCache() as cache:
         img_bn = bn(img).realize()
         print(img_bn)
-        assert len(CacheCollector.cache) == 3, f"optimizer didn't fold batchnorm, got {len(CacheCollector.cache)}"
+        assert cache.count == 3, f"optimizer didn't fold batchnorm, got {cache.count}"
 
   def test_fold_conv_sgd(self):
     with Tensor.train():
       img = Tensor.ones(2,3,4,4)
       c1 = nn.Conv2d(3,32,3)
       opt = optim.SGD(get_parameters(c1))
-      with CLCache():
+      with CLCache() as cache:
         opt.zero_grad()
         c1(img).relu().sum().backward()
         opt.step()
         # TODO: this should be 4, but the sum output child stays around
         # with pushing_permutes it can be 3
         # TODO: broken with optim fixes
-        assert len(CacheCollector.cache) in [4,5,6], f"optimizer didn't fold conv-backward SGD, got {len(CacheCollector.cache)}"
+        assert cache.count in [4,5,6], f"optimizer didn't fold conv-backward SGD, got {cache.count}"
 
   def test_fold_2convs_sgd(self):
     with Tensor.train():
@@ -239,74 +242,74 @@ class TestOpt(unittest.TestCase):
       bn = nn.BatchNorm2d(32, track_running_stats=False)
       # precache the bn
       bn(c1(img)).relu().realize()
-      with CLCache():
+      with CLCache() as cache:
         bn(c1(img)).relu().realize()
-        assert len(CacheCollector.cache) == 1, f"optimizer didn't fold conv-batchnorm at test time, got {len(CacheCollector.cache)}"
+        assert cache.count == 1, f"optimizer didn't fold conv-batchnorm at test time, got {cache.count}"
 
   def test_fold_conv_batchnorm(self):
     with Tensor.train():
       img = Tensor.ones(1,3,8,8)
       c1 = nn.Conv2d(3,32,3)
       bn = nn.BatchNorm2d(32, track_running_stats=False)
-      with CLCache():
+      with CLCache() as cache:
         img_conv = bn(c1(img)).relu().realize()
         print(img_conv)
-        assert len(CacheCollector.cache) == 4, f"optimizer didn't fold conv-batchnorm, got {len(CacheCollector.cache)}"
+        assert cache.count == 4, f"optimizer didn't fold conv-batchnorm, got {cache.count}"
 
   def test_fold_conv_elu(self):
     img = Tensor.ones(1,4,8,8)
     c1 = nn.Conv2d(4, 4, kernel_size=3)
     c2 = nn.Conv2d(4, 4, kernel_size=3)
-    with CLCache():
+    with CLCache() as cache:
       img_conv = img.sequential([c1, Tensor.elu, c2, Tensor.elu]).realize()
       print(img_conv)
-      assert len(CacheCollector.cache) == 2, "optimizer didn't fold conv/elu"
+      assert cache.count == 2, "optimizer didn't fold conv/elu"
 
   def test_fold_conv_relu(self):
     img = Tensor.ones(1,4,8,8)
     c1 = nn.Conv2d(4, 4, kernel_size=3)
     c2 = nn.Conv2d(4, 4, kernel_size=3)
-    with CLCache():
+    with CLCache() as cache:
       img_conv = img.sequential([c1, Tensor.relu, c2, Tensor.relu]).realize()
       print(img_conv)
-      assert len(CacheCollector.cache) == 2, "optimizer didn't fold conv/relu"
+      assert cache.count == 2, "optimizer didn't fold conv/relu"
 
   def test_fold_conv_relu_nobias(self):
     img = Tensor.ones(1,4,8,8)
     c1 = nn.Conv2d(4, 4, kernel_size=3, bias=False)
     c2 = nn.Conv2d(4, 4, kernel_size=3, bias=False)
-    with CLCache():
+    with CLCache() as cache:
       img_conv = img.sequential([c1, Tensor.relu, c2, Tensor.relu]).realize()
       print(img_conv)
-      assert len(CacheCollector.cache) == 2, "optimizer didn't fold conv/relu"
+      assert cache.count == 2, "optimizer didn't fold conv/relu"
 
   def test_permute_was_pushed(self):
     a = Tensor.randn(16, 16, 16)
-    with CLCache(2):
+    with CLCache(2) as cache:
       c = a.sum(2)
       d = c.permute(1,0).contiguous()
       d.realize()
-    cache_len = len(CacheCollector.cache)
+    cache_len = cache.count
     np.testing.assert_allclose(a.numpy().sum(2).transpose(1,0), d.numpy(), rtol=1e-3, atol=1e-5)
     if PUSH_PERMUTES: assert cache_len == 1, "permute wasn't pushed!"
 
   def test_permute_was_pushed_through_contract_reshape(self):
     a = Tensor.randn(4, 4, 4, 4, 4)
-    with CLCache(2):
+    with CLCache(2) as cache:
       c = a.sum(-1)
       d = c.reshape(16,16).permute(1,0).contiguous()
       d.realize()
-    cache_len = len(CacheCollector.cache)
+    cache_len = cache.count
     np.testing.assert_allclose(a.numpy().sum(-1).reshape(16,16).transpose(1,0), d.numpy(), rtol=1e-3, atol=1e-5)
     if PUSH_PERMUTES: assert cache_len == 1, "permute wasn't pushed!"
 
   def test_permute_was_pushed_through_contractw1s_reshape(self):
     a = Tensor.randn(4, 4, 4, 4, 4)
-    with CLCache(2):
+    with CLCache(2) as cache:
       c = a.sum(-1)
       d = c.reshape(16,1,16).permute(2,1,0).contiguous()
       d.realize()
-    cache_len = len(CacheCollector.cache)
+    cache_len = cache.count
     np.testing.assert_allclose(a.numpy().sum(-1).reshape(16,1,16).transpose(2,1,0), d.numpy(), rtol=1e-3, atol=1e-5)
     if PUSH_PERMUTES: assert cache_len == 1, "permute wasn't pushed!"
@@ -315,35 +318,35 @@ class TestOpt(unittest.TestCase):
   @unittest.skipIf(not PUSH_PERMUTES, "this test requires PUSH_PERMUTES")
   def test_permute_was_pushed_through_expand_reshape(self):
     a = Tensor.randn(16, 16, 16)
-    with CLCache():
+    with CLCache() as cache:
      c = a.sum(2)
      d = c.reshape(4,4,4,4).permute(2,3,0,1).contiguous()
      d.realize()
-    cache_len = len(CacheCollector.cache)
+    cache_len = cache.count
     np.testing.assert_allclose(a.numpy().sum(2).transpose(1,0).reshape(4,4,4,4), d.numpy(), rtol=1e-3, atol=1e-5)
     if PUSH_PERMUTES: assert cache_len == 1, "permute wasn't pushed!"
 
   @unittest.skipIf(PUSH_PERMUTES, "this test is broken with PUSH_PERMUTES")
   def test_no_reduceop_rerun(self):
     a = Tensor.randn(16, 16, 16)
-    with CLCache():
+    with CLCache() as cache:
       c = a.sum(2)
       d = a.sum(2).permute(1,0)
       c.realize()
       d.realize()
-    cache_len = len(CacheCollector.cache)
+    cache_len = cache.count
     np.testing.assert_allclose(c.numpy().transpose(1,0), d.numpy(), rtol=1e-3, atol=1e-5)
     assert cache_len == 1, "reduceop was rerun!"
 
   @unittest.skipIf(PUSH_PERMUTES, "this test is broken with PUSH_PERMUTES")
   def test_no_reduceop_rerun_alt(self):
     a = Tensor.randn(16, 16, 16)
-    with CLCache():
+    with CLCache() as cache:
       c = a.sum(2).permute(1,0)
       d = a.sum(2)
       c.realize()
       d.realize()
-    cache_len = len(CacheCollector.cache)
+    cache_len = cache.count
     np.testing.assert_allclose(c.numpy(), d.numpy().transpose(1,0), rtol=1e-3, atol=1e-5)
     assert cache_len == 1, "reduceop was rerun!"
@@ -9,9 +9,8 @@ from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.shape.view import View
 from tinygrad.shape.symbolic import MulNode, Variable, NumNode, Node
 from tinygrad.tensor import Tensor
-from tinygrad.engine.jit import CacheCollector
 from tinygrad.engine.schedule import create_schedule
-from tinygrad.engine.realize import run_schedule
+from tinygrad.engine.realize import run_schedule, lower_schedule
 from tinygrad.helpers import prod, Context, getenv
 from tinygrad.dtype import DType, dtypes
 from tinygrad.codegen.uops import UOpGraph
@@ -20,9 +19,10 @@ class TestLinearizer(unittest.TestCase):
   def test_arg_dedup(self):
     a, b = Tensor.randn(4), Tensor.randn(4)
     np_a, np_b = a.numpy(), b.numpy()
-    CacheCollector.start()
-    c = ((a.shrink(((0, 2),)) - a.shrink(((2, 4),))) - (b.shrink(((0, 2),)) - b.shrink(((2, 4),)))).realize()
-    rawbufs = CacheCollector.finish()[0].rawbufs
+    c = ((a.shrink(((0, 2),)) - a.shrink(((2, 4),))) - (b.shrink(((0, 2),)) - b.shrink(((2, 4),))))
+    lowered = list(lower_schedule(create_schedule([c.lazydata])))
+    for ei in lowered: ei.run()
+    rawbufs = lowered[-1].rawbufs
     assert len(rawbufs) == 3 and set(rawbufs[1:]) == {a.lazydata.base.realized, b.lazydata.base.realized}
     np_c = (np_a[:2] - np_a[2:]) - (np_b[:2] - np_b[2:])
     np.testing.assert_allclose(np_c, c.numpy(), atol=1e-4, rtol=1e-4)
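With CacheCollector gone, tests that need the per-kernel buffers lower a schedule themselves, as test_arg_dedup now does above. A minimal sketch of that pattern outside the test (variable names are illustrative):

from tinygrad import Tensor
from tinygrad.engine.schedule import create_schedule
from tinygrad.engine.realize import lower_schedule

a, b = Tensor.randn(4), Tensor.randn(4)
out = a.shrink(((0, 2),)) - b.shrink(((0, 2),))
# lower the schedule into ExecItems, run them, then inspect what the last kernel touched
eis = list(lower_schedule(create_schedule([out.lazydata])))
for ei in eis: ei.run()
print(len(eis), [buf.size for buf in eis[-1].rawbufs])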
@@ -3,7 +3,7 @@ from collections import defaultdict
 from typing import TYPE_CHECKING, Any, List, Optional, Dict, Tuple, ClassVar, NamedTuple
 import importlib, inspect, functools, pathlib, time, ctypes, os
 from tinygrad.helpers import ansilen, prod, getenv, colored, all_int, to_function_name, from_mv, flat_mv, diskcache_get, diskcache_put
-from tinygrad.helpers import DEBUG, CACHECOLLECTING, BEAM, NOOPT, GlobalCounters
+from tinygrad.helpers import DEBUG, BEAM, NOOPT, GlobalCounters
 from tinygrad.shape.symbolic import Variable, sym_infer, sint
 from tinygrad.ops import LazyOp, get_lazyop_info
 from tinygrad.buffer import Buffer, BufferOptions
@@ -44,11 +44,7 @@ class Runner:
     self.op_estimate:sint = 0
     self.mem_estimate:sint = 0
   def exec(self, rawbufs:List[Buffer], var_vals:Optional[Dict[Variable, int]]=None) -> Optional[float]:
-    var_vals = var_vals if var_vals is not None else {}
-    from tinygrad.engine.jit import CacheCollector
-    et = self(rawbufs, var_vals)
-    if CACHECOLLECTING: CacheCollector.add(self, rawbufs, var_vals)
-    return et
+    return self(rawbufs, {} if var_vals is None else var_vals)
   def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False, jit=False) -> Optional[float]:
     raise NotImplementedError("override this")
 
@@ -1,33 +1,28 @@
 from __future__ import annotations
-from typing import Callable, List, Tuple, Dict, cast, Union, Optional, TypeVar, Generic
+from typing import TypeVar, Generic, Callable, List, Tuple, Union, Dict, cast, Optional
 import functools, itertools, operator
-from tinygrad.nn.state import get_parameters
-from tinygrad.dtype import DType
-from tinygrad.helpers import DEBUG, merge_dicts, getenv, all_int, Context, GRAPH, flatten, GraphException
-from tinygrad.device import Compiled, Runner, CompiledRunner, Buffer, BufferXfer, MultiDeviceJITGraph, Device
+from dataclasses import dataclass
 from tinygrad.tensor import Tensor
 from tinygrad.lazy import LazyBuffer
 from tinygrad.features.multi import MultiLazyBuffer
+from tinygrad.helpers import flatten, merge_dicts, DEBUG, Context, GRAPH, BEAM, getenv, all_int, GraphException
+from tinygrad.device import Buffer, Runner, CompiledRunner, BufferXfer, Compiled, MultiDeviceJITGraph, Device
+from tinygrad.dtype import DType
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.shape.symbolic import Variable, sint
-from tinygrad.engine.realize import ExecItem
+from tinygrad.engine.realize import ExecItem, capturing
+from tinygrad.nn.state import get_parameters
 from weakref import ref, WeakKeyDictionary
 
 # TODO: these graph functions probably shouldn't exist here
-
 def get_jit_stats(jit_cache: List[ExecItem]) -> Tuple[sint, int]:
   return functools.reduce(operator.add, [ji.prg.op_estimate for ji in jit_cache if isinstance(ji.prg, CompiledRunner)], 0), \
     functools.reduce(operator.add, [ji.prg.mem_estimate for ji in jit_cache if isinstance(ji.prg, CompiledRunner)], 0)
-def get_input_replace(jit_cache: List[ExecItem], input_rawbuffers:List[Buffer]) -> Dict[Tuple[int, int], int]:
-  input_replace: Dict[Tuple[int, int], int] = {}
-  for j,ji in enumerate(jit_cache):
-    for i,a in enumerate(ji.rawbufs):
-      if a in input_rawbuffers:
-        input_replace[(j,i)] = input_rawbuffers.index(a)
-  return input_replace
 def get_jc_idxs_with_updatable_launch_dims(jit_cache: List[ExecItem]) -> List[int]:
-  return [j for j,ji in enumerate(jit_cache) if isinstance(ji.prg, CompiledRunner) and ((ji.prg.global_size and not all_int(ji.prg.global_size)) or (ji.prg.local_size and not all_int(ji.prg.local_size)))] # noqa: E501
+  return [j for j,ji in enumerate(jit_cache) if isinstance(ji.prg, CompiledRunner) and \
+    ((ji.prg.global_size and not all_int(ji.prg.global_size)) or (ji.prg.local_size and not all_int(ji.prg.local_size)))]
 def get_jc_idxs_with_updatable_var_vals(jit_cache: List[ExecItem]) -> List[int]:
   return [j for j,ji in enumerate(jit_cache) if isinstance(ji.prg, CompiledRunner) and ji.prg.vars]
 
 def apply_graph_to_jit(jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]) -> List[ExecItem]:
   # Split JIT cache into batches for faster graph execution.
   # This allows the accelerator to run some batches while subsequent graphs are still being updated.
@@ -68,7 +63,38 @@ def apply_graph_to_jit(jit_cache: List[ExecItem], input_rawbuffers: List[Buffer]
   if len(current_batch) > 0: flush_batch()
   return graphed_jit_cache
 
 # *** JIT ***
+def get_input_replace(jit_cache: List[ExecItem], input_rawbuffers:List[Buffer]) -> Dict[Tuple[int, int], int]:
+  input_replace: Dict[Tuple[int, int], int] = {}
+  for j,ji in enumerate(jit_cache):
+    for i,a in enumerate(ji.rawbufs):
+      if a in input_rawbuffers:
+        input_replace[(j,i)] = input_rawbuffers.index(a)
+  return input_replace
+
+class PlaceHolder:
+  placeholders: WeakKeyDictionary[Buffer, PlaceHolder] = WeakKeyDictionary()
+  def __init__(self, buf:Buffer):
+    self.size, self.dtype, self.device, self.ref, self.bufid, self.options = buf.size, buf.dtype, buf.device, ref(buf), id(buf._buf), buf.options
+  def to_tuple(self): return (self.size, self.dtype, self.device, self.bufid, self.options)
+  def __hash__(self): return hash(self.to_tuple())
+  def __eq__(self, x): return isinstance(x, PlaceHolder) and self.to_tuple() == x.to_tuple()
+  @staticmethod
+  def create_if_needed(buf:Buffer) -> Union[PlaceHolder, Buffer]:
+    if found:=PlaceHolder.placeholders.get(buf, None): return found
+    if hasattr(buf, '_buf'): return buf
+    PlaceHolder.placeholders[buf] = ret = PlaceHolder(buf.ensure_allocated()) # TODO: do I need to allocate here?
+    return ret
+
+  def alloc_if_needed(self, buffer_cache: Dict[PlaceHolder, Buffer]) -> Buffer:
+    ret = self.ref()
+    if ret: return ret
+    if self not in buffer_cache: buffer_cache[self] = Buffer(self.device, self.size, self.dtype, options=self.options).allocate()
+    return buffer_cache[self]
+
+@dataclass(frozen=True)
+class WeakExecItem:
+  prg: Runner
+  rawbufs: List[Union[PlaceHolder, Buffer]]
 
 ReturnType = TypeVar('ReturnType')
 class TinyJit(Generic[ReturnType]):
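PlaceHolder exists so captured kernels don't pin their intermediate buffers: only a weak reference is kept, and alloc_if_needed hands back the original buffer if it is still alive, or allocates one replacement per distinct placeholder otherwise. A toy illustration of that weak-reference trick (standalone classes, not the tinygrad ones):

import weakref

class FakeBuffer:
  def __init__(self, size): self.size = size

class ToyPlaceHolder:
  def __init__(self, buf): self.size, self.ref = buf.size, weakref.ref(buf)
  def alloc_if_needed(self, cache:dict) -> FakeBuffer:
    if (buf := self.ref()) is not None: return buf        # captured buffer still alive: reuse it
    return cache.setdefault(self, FakeBuffer(self.size))  # otherwise allocate a replacement once

buf = FakeBuffer(16)
ph, cache = ToyPlaceHolder(buf), {}
assert ph.alloc_if_needed(cache) is buf      # original buffer still around
del buf                                      # intermediate freed between capture and cache build
assert ph.alloc_if_needed(cache).size == 16  # weakref died, a replacement is allocated and cached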
@@ -76,62 +102,56 @@ class TinyJit(Generic[ReturnType]):
     self.fxn = fxn
     self.reset()
 
+  def add(self, ei:ExecItem):
+    self._cc.append(WeakExecItem(ei.prg, [PlaceHolder.create_if_needed(buf) for buf in ei.rawbufs if buf is not None]))
+
   def reset(self):
+    self._cc: List[WeakExecItem] = []
     self.jit_cache: List[ExecItem] = []
     self.input_replace: Dict[Tuple[int, int], int] = {}
     self.cnt: int = 0
     self.ret: Optional[ReturnType] = None
-    self.expected_vals: Optional[Tuple[Variable, ...]] = None
-    self.expected_name_sts_dtype_device: Optional[Tuple[Tuple[Union[int, str], ShapeTracker, DType, Union[str, Tuple[str, ...]]], ...]] = None
 
-  # add support for instance methods
-  def __get__(self, obj, objtype): return functools.partial(self.__call__, obj)
+  def __get__(self, obj, objtype): return functools.partial(self.__call__, obj) # add support for instance methods
 
   def __call__(self, *args, **kwargs) -> ReturnType:
-    # all inputs (except const) are realized
-    input_tensors: Dict[Union[int, str], Tensor] = { cast(Union[int, str], k):v for k,v in itertools.chain(enumerate(args), sorted(kwargs.items())) if v.__class__ is Tensor } # noqa: E501
-    Tensor.corealize(input_tensors.values())
-    input_lbs: Dict[Union[int, str], Union[LazyBuffer, MultiLazyBuffer]] = {k:v.lazydata for k,v in input_tensors.items()}
-    expected_name_sts_dtype_device = tuple([(k, v.st.unbind()[0] if isinstance(v, LazyBuffer) else ShapeTracker.from_shape(v.shape), v.dtype, v.device) for k,v in input_lbs.items()]) #noqa: E501
-
-    # get rawbuffers
-    lbs: List[LazyBuffer] = [v for v in input_lbs.values() if isinstance(v, LazyBuffer)] + \
-      flatten([mlb.lbs for mlb in input_lbs.values() if isinstance(mlb, MultiLazyBuffer)])
+    input_tensors: List[Tuple[Union[int, str], Tensor]] = \
+      [(cast(Union[int, str], k),v) for k,v in itertools.chain(enumerate(args), sorted(kwargs.items())) if v.__class__ is Tensor]
+    Tensor.corealize([x[1] for x in input_tensors])
+    lbs: List[LazyBuffer] = flatten([v.lazydata.lbs for _,v in input_tensors])
+    expected_sts_var_dtype_device = [(*x.st.unbind(), x.dtype, x.device) for x in lbs]
     input_rawbuffers: List[Buffer] = [v.base.realized for v in lbs if v.base.realized is not None]
     assert len(set(input_rawbuffers)) == len(input_rawbuffers), "duplicate inputs to JIT"
+    var_vals: Dict[Variable, int] = merge_dicts([x[1] for x in expected_sts_var_dtype_device] + \
+      [dict(x.unbind() for x in itertools.chain(args, kwargs.values()) if isinstance(x, Variable))])
 
-    # get variables: they can either be in Tensors or passed in as arguments, and all must be bound. these are all global
-    var_vals: Dict[Variable, int] = merge_dicts([arg.st.var_vals for arg in lbs] + [dict(x.unbind() for x in itertools.chain(args, kwargs.values()) if isinstance(x, Variable))]) # noqa: E501
-    expected_vals = tuple(var_vals.keys())
-
+    expected_names, expected_lbs = [x[0] for x in input_tensors], [(x[0], tuple(x[1].keys()), x[2], x[3]) for x in expected_sts_var_dtype_device]
     if self.cnt >= 2:
       # jit exec
-      assert self.expected_vals == expected_vals and self.expected_name_sts_dtype_device is not None, "missing/mismatch of var_vals"
-      assert all(x[0] == y[0] and x[1].views == y[1].views and x[2] == y[2] and x[3] == y[3]
-        for x,y in zip(self.expected_name_sts_dtype_device, expected_name_sts_dtype_device)), \
-        f"mismatch of input tensors, expected {self.expected_name_sts_dtype_device} got {expected_name_sts_dtype_device}"
+      assert self.expected_names == expected_names and self.expected_lbs == expected_lbs, "args mismatch in JIT"
      for (j,i),input_idx in self.input_replace.items(): self.jit_cache[j].rawbufs[i] = input_rawbuffers[input_idx]
      if DEBUG >= 1: print(f"jit execs {len(self.jit_cache)} kernels")
-      for ji in self.jit_cache: ji.prg(cast(List[Buffer], ji.rawbufs), var_vals, wait=DEBUG>=2, jit=True)
+      for ei in self.jit_cache: ei.run(var_vals, jit=True)
     elif self.cnt == 1:
       # jit capture
-      self.expected_vals, self.expected_name_sts_dtype_device = expected_vals, expected_name_sts_dtype_device
-      CacheCollector.start(var_vals)
-      with Context(GRAPH=getenv("JITGRAPH", GRAPH.value)):
+      self.expected_names: List[Union[int, str]] = expected_names
+      self.expected_lbs: List[Tuple[ShapeTracker, Tuple[Variable, ...], DType, str]] = expected_lbs
+      with Context(GRAPH=getenv("JITGRAPH", GRAPH.value), BEAM=getenv("JITBEAM", BEAM.value)):
+        capturing.append(self)
         self.ret = self.fxn(*args, **kwargs)
         Tensor.corealize(get_parameters(self.ret))
-      self.jit_cache = CacheCollector.finish()
-      assert len(self.jit_cache) != 0, "didn't JIT anything!"
-      # TODO: reset doesn't work if we delete this
-      #del self.fxn
-      if DEBUG >= 1 and len(set(get_input_replace(self.jit_cache, input_rawbuffers).values())) != len(input_rawbuffers):
-        print("WARNING: some input tensors not found")
+        capturing.clear()
+      assert len(self._cc), "didn't JIT anything!"
+      buffer_cache: Dict[PlaceHolder, Buffer] = {}
+      self.jit_cache = \
+        [ExecItem(ei.prg, [x.alloc_if_needed(buffer_cache) if isinstance(x, PlaceHolder) else x for x in ei.rawbufs]) for ei in self._cc]
+      del self._cc
       if DEBUG >= 1: print(f"JIT captured {len(self.jit_cache)} kernels with {len(input_rawbuffers)} inputs")
 
       # Condense the items into a graph executor.
       if getenv("JIT") != 2: self.jit_cache = apply_graph_to_jit(self.jit_cache, input_rawbuffers, var_vals)
 
       self.input_replace = get_input_replace(self.jit_cache, input_rawbuffers)
       if DEBUG >= 1 and len(set(self.input_replace.values())) != len(input_rawbuffers): print("WARNING: some input tensors not found")
     elif self.cnt == 0:
       # jit ignore
       self.ret = self.fxn(*args, **kwargs)
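The call counter drives three phases: cnt 0 runs the function normally, cnt 1 re-runs it while registered in `capturing` so every ExecItem lands in `_cc`, and cnt >= 2 replays the cached ExecItems with the new input buffers patched in via input_replace. The user-facing behavior is unchanged; a minimal usage sketch:

from tinygrad import Tensor, TinyJit

@TinyJit
def double(x:Tensor) -> Tensor:
  return (x * 2).realize()

for i in range(4):
  # call 0: eager, call 1: capture, calls 2+: replay the captured kernels
  print(i, double(Tensor.randn(4, 4)).numpy().sum())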
@@ -141,42 +161,4 @@ class TinyJit(Generic[ReturnType]):
     for (j,i) in self.input_replace.keys(): self.jit_cache[j].rawbufs[i] = None
 
     self.cnt += 1
-    return cast(ReturnType, self.ret)
-
-class PlaceHolder:
-  def __init__(self, buf:Buffer):
-    self.size, self.dtype, self.device, self.ref, self.bufid, self.options = buf.size, buf.dtype, buf.device, ref(buf), id(buf._buf), buf.options
-  def to_tuple(self): return (self.size, self.dtype, self.device, self.bufid, self.options)
-  def __hash__(self): return hash(self.to_tuple())
-  def __eq__(self, x): return isinstance(x, PlaceHolder) and self.to_tuple() == x.to_tuple()
-  def alloc_if_needed(self, buffer_cache: Dict[PlaceHolder, Buffer]) -> Buffer:
-    ret = self.ref()
-    if ret: return ret
-    if self not in buffer_cache: buffer_cache[self] = Buffer(self.device, self.size, self.dtype, options=self.options).allocate()
-    return buffer_cache[self]
-
-class _CacheCollector:
-  def __init__(self):
-    self.cache: Optional[List[Tuple[Runner, List[Union[Buffer, PlaceHolder]]]]] = None
-
-  def start(self, var_vals:Optional[Dict[Variable, int]]=None):
-    self.cache = []
-    self.placeholders: WeakKeyDictionary[Buffer, PlaceHolder] = WeakKeyDictionary()
-    self.var_vals = var_vals if var_vals is not None else {}
-
-  def add(self, prg, rawbufs:List[Buffer], var_vals:Dict[Variable, int]):
-    if self.cache is None: return
-    for k,v in var_vals.items(): assert k in self.var_vals and self.var_vals[k] == v, f"var_vals {k} mismatch {v} != {self.var_vals.get(k)}"
-
-    # Buffer optimization is allowed only for kernel operations. Avoids for copies (prevents parallelism) and syncs (incorrect buffer reuse).
-    if isinstance(prg, CompiledRunner):
-      for i in range(prg.outcount): self.placeholders[rawbufs[i]] = PlaceHolder(rawbufs[i])
-
-    self.cache.append((prg, [self.placeholders.get(x, x) if isinstance(x, Buffer) else x for x in rawbufs]))
-
-  def finish(self) -> List[ExecItem]:
-    if self.cache is None: return []
-    buffer_cache: Dict[PlaceHolder, Buffer] = {}
-    saved_cache, self.cache = self.cache, None
-    return [ExecItem(prg, [x.alloc_if_needed(buffer_cache) if isinstance(x, PlaceHolder) else x for x in pl]) for prg, pl in saved_cache]
-CacheCollector = _CacheCollector()
+    return self.ret
@@ -10,8 +10,8 @@ from tinygrad.shape.symbolic import Variable
 class ExecItem:
   prg: Runner
   rawbufs: List[Optional[Buffer]]
-  def run(self, var_vals:Optional[Dict[Variable, int]]=None):
-    self.prg.exec([cast(Buffer, x).ensure_allocated() for x in self.rawbufs], var_vals if var_vals is not None else {})
+  def run(self, var_vals:Optional[Dict[Variable, int]]=None, wait=False, jit=False):
+    self.prg([cast(Buffer, x).ensure_allocated() for x in self.rawbufs], var_vals if var_vals is not None else {}, wait=wait, jit=jit)
 
 class CustomOp(Runner):
   def __init__(self, fxn):
@@ -38,5 +38,9 @@ def lower_schedule_item(si:ScheduleItem) -> Runner:
 def lower_schedule(schedule:List[ScheduleItem]) -> Generator[ExecItem, None, None]:
   while len(schedule): yield ExecItem(lower_schedule_item(si:=schedule.pop(0)), list(si.outputs+si.inputs))
 
+capturing: List = [] # put classes with an add method in here
+
 def run_schedule(schedule:List[ScheduleItem], var_vals:Optional[Dict[Variable, int]]=None):
-  for ei in lower_schedule(schedule): ei.run(var_vals)
+  for ei in lower_schedule(schedule):
+    if len(capturing): capturing[0].add(ei)
+    ei.run(var_vals)
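Anything appended to `capturing` with an `add` method now sees every ExecItem as run_schedule executes it; this is the hook both TinyJit and the rewritten CLCache test helper use. A small sketch of a custom collector (the KernelCounter class is illustrative):

from tinygrad import Tensor
from tinygrad.engine.realize import capturing

class KernelCounter:
  def __init__(self): self.count = 0
  def add(self, ei): self.count += 1  # called once per ExecItem that run_schedule executes

counter = KernelCounter()
capturing.append(counter)
try:
  (Tensor.ones(16).contiguous().sum() + 1).realize()
finally:
  capturing.clear()
print("kernels run:", counter.count)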
@@ -52,6 +52,10 @@ class LazyBuffer:
   @property
   def base(self) -> LazyBuffer: return self._base if self._base is not None else self
 
+  # same API as multi
+  @property
+  def lbs(self) -> List[LazyBuffer]: return [self]
+
   @staticmethod
   def loadop(op, shape:Tuple[sint,...], dtype:DType, device:str, arg=None, src:Tuple[LazyBuffer, ...]=(), enable_cache=False) -> LazyBuffer:
     assert isinstance(src, tuple)
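The new `lbs` property gives LazyBuffer the same accessor MultiLazyBuffer already has, so the JIT can flatten its inputs uniformly (`flatten([v.lazydata.lbs for _,v in input_tensors])`) without type checks. A quick illustration of the single-device case:

from tinygrad import Tensor

t = Tensor.randn(4)
print(t.lazydata.lbs)  # a one-element list holding the LazyBuffer itself
# for a tensor sharded across devices, lazydata is a MultiLazyBuffer and .lbs has one entry per shard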