rewrite the jit in the context of new schedule (#4162)

* rewrite the jit in the context of new schedule

* mypy better

* fix placeholder

* tests

* all functionality should work

* fix tests

* no CacheCollector
This commit is contained in:
George Hotz 2024-04-12 21:54:36 -07:00 committed by GitHub
parent b67f759780
commit ebc94c9d6c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 134 additions and 211 deletions

View File

@ -1,66 +0,0 @@
#!/usr/bin/env python3
import gc
import time
from tqdm import trange
from extra.models.efficientnet import EfficientNet
from tinygrad.nn.state import get_parameters
from tinygrad.nn import optim
from tinygrad import Tensor, GlobalCounters
from tinygrad.helpers import getenv
from tinygrad.engine.jit import CacheCollector
def tensors_allocated():
return sum(isinstance(x, Tensor) for x in gc.get_objects())
NUM = getenv("NUM", 2)
BS = getenv("BS", 8)
CNT = getenv("CNT", 10)
BACKWARD = getenv("BACKWARD", 0)
TRAINING = getenv("TRAINING", 1)
ADAM = getenv("ADAM", 0)
CLCACHE = getenv("CLCACHE", 0)
if __name__ == "__main__":
print(f"NUM:{NUM} BS:{BS} CNT:{CNT}")
model = EfficientNet(NUM, classes=1000, has_se=False, track_running_stats=False)
parameters = get_parameters(model)
for p in parameters: p.realize()
if ADAM: optimizer = optim.Adam(parameters, lr=0.001)
else: optimizer = optim.SGD(parameters, lr=0.001)
Tensor.training = TRAINING
Tensor.no_grad = not BACKWARD
for i in trange(CNT):
GlobalCounters.reset()
cpy = time.monotonic()
x_train = Tensor.randn(BS, 3, 224, 224, requires_grad=False).realize()
y_train = Tensor.randn(BS, 1000, requires_grad=False).realize()
# TODO: replace with TinyJit
if i < 3 or not CLCACHE:
st = time.monotonic()
out = model.forward(x_train)
loss = out.log_softmax().mul(y_train).mean()
if i == 2 and CLCACHE: CacheCollector.start()
if BACKWARD:
optimizer.zero_grad()
loss.backward()
optimizer.step()
mt = time.monotonic()
loss.realize()
for p in parameters:
p.realize()
et = time.monotonic()
else:
st = mt = time.monotonic()
for prg, args in cl_cache: prg(*args)
et = time.monotonic()
if i == 2 and CLCACHE:
cl_cache = CacheCollector.finish()
mem_used = GlobalCounters.mem_used
loss_cpu = loss.detach().numpy()
cl = time.monotonic()
print(f"{(st-cpy)*1000.0:7.2f} ms cpy, {(cl-st)*1000.0:7.2f} ms run, {(mt-st)*1000.0:7.2f} ms build, {(et-mt)*1000.0:7.2f} ms realize, {(cl-et)*1000.0:7.2f} ms CL, {loss_cpu:7.2f} loss, {tensors_allocated():4d} tensors, {mem_used/1e9:.2f} GB used, {GlobalCounters.global_ops*1e-9/(cl-st):9.2f} GFLOPS")

View File

@ -10,24 +10,27 @@ from tinygrad.helpers import getenv
from tinygrad.nn import optim
#from tinygrad.lazy import PUSH_PERMUTES
PUSH_PERMUTES = False
from tinygrad.engine.jit import CacheCollector
from tinygrad.engine.realize import capturing
class CLCache:
def __init__(self, allowed=None, strict=False, preclear=True, var_vals=None):
self.allowed, self.strict, self.preclear, self.var_vals = allowed, strict, preclear, var_vals if var_vals is not None else {}
self.count = 0
def add(self, ei): self.count += 1
def __enter__(self):
if self.preclear:
gc.collect()
for x in [x for x in gc.get_objects() if isinstance(x, Tensor)]:
x.realize()
GlobalCounters.reset()
CacheCollector.start(self.var_vals)
capturing.append(self)
print("cache: entering")
return self
def __exit__(self, type, value, traceback):
cache = CacheCollector.finish()
print(f"cache: exiting with size {len(cache)}", f"allowed {self.allowed}" if self.allowed is not None else "")
capturing.clear()
print(f"cache: exiting with size {self.count}", f"allowed {self.allowed}" if self.allowed is not None else "")
if self.allowed is not None:
assert len(cache) <= self.allowed and (not self.strict or len(cache) == self.allowed), f"used too many kernels! {len(cache)} > {self.allowed}"
assert self.count <= self.allowed and (not self.strict or self.count == self.allowed), f"used too many kernels! {self.count} > {self.allowed}"
from extra.models.convnext import ConvNeXt
from extra.models.efficientnet import EfficientNet
@ -77,9 +80,9 @@ class TestInferenceMinKernels(unittest.TestCase):
model = ViT(embed_dim=192, num_heads=3)
for p in get_parameters(model): p.assign(np.zeros(p.shape, dtype=p.dtype.np))
img = Tensor.randn(1, 3, 224, 224)
with CLCache(222): # NOTE: this is way too high
with CLCache(222) as cache: # NOTE: this is way too high
out = model.forward(img)
assert len(CacheCollector.cache) == 0, "ViT prerealized?"
assert cache.count == 0, "ViT prerealized?"
out.realize()
@unittest.skip("llama is fp16 but CI does not have fp16")
@ -97,12 +100,12 @@ class TestOptBinOp(unittest.TestCase):
def _test_no_binop_rerun(self, f1, f2=None, allowed=1):
a = Tensor.randn(16, 16)
b = Tensor.randn(16, 16)
with CLCache():
with CLCache() as cache:
c = f1(a, b)
if f2 is not None: d = f2(a, b)
c.realize()
if f2 is not None: d.realize()
assert len(CacheCollector.cache) == allowed, "binop was rerun!"
assert cache.count == allowed, "binop was rerun!"
if f2 is not None: np.testing.assert_allclose(c.numpy().ravel(), d.numpy().ravel(), rtol=1e-3, atol=1e-5)
def test_no_binop_rerun(self): return self._test_no_binop_rerun(lambda a,b: a*b, lambda a,b: (a*b).reshape(16, 16, 1))
@ -125,22 +128,22 @@ class TestOptReduceLoop(unittest.TestCase):
def test_loop_left(self):
a = Tensor.randn(16, 16)
b = Tensor.randn(16, 16)
with CLCache():
with CLCache() as cache:
t = a.sum(0)
b = t.reshape(16,1).expand(16,16).sum(0)
c = (t+b)
c.realize()
assert len(CacheCollector.cache) == 2, "loop left fusion broken"
assert cache.count == 2, "loop left fusion broken"
def test_loop_right(self):
a = Tensor.randn(16, 16)
b = Tensor.randn(16, 16)
with CLCache():
with CLCache() as cache:
t = a.sum(0)
b = t.reshape(16,1).expand(16,16).sum(0)
c = (b+t)
c.realize()
assert len(CacheCollector.cache) == 2, "loop right fusion broken"
assert cache.count == 2, "loop right fusion broken"
@unittest.skipUnless(Device.DEFAULT == "GPU", "Not Implemented")
class TestOptWChild(unittest.TestCase):
@ -148,12 +151,12 @@ class TestOptWChild(unittest.TestCase):
def test_unrealized_child(self):
a = Tensor.randn(16, 16)
b = Tensor.randn(16, 16)
with CLCache():
with CLCache() as cache:
c = (a*b).sum()
d = c+1
e = c+2 # noqa: F841
d.realize()
assert len(CacheCollector.cache) == 2, "don't fuse if you have children"
assert cache.count == 2, "don't fuse if you have children"
@unittest.skipUnless(Device.DEFAULT == "GPU", "Not Implemented")
class TestOpt(unittest.TestCase):
@ -168,34 +171,34 @@ class TestOpt(unittest.TestCase):
def test_fold_reduce_elementwise(self):
img = Tensor.ones(32).contiguous()
addme = Tensor.ones(1)
with CLCache():
with CLCache() as cache:
ret = img.sum() + addme
ret.realize()
assert len(CacheCollector.cache) == 1, "optimizer didn't fold reduce/elementwise"
assert cache.count == 1, "optimizer didn't fold reduce/elementwise"
assert ret.item() == 33
def test_fold_batchnorm(self):
with Tensor.train():
img = Tensor.ones(1,32,4,4).contiguous()
bn = nn.BatchNorm2d(32, track_running_stats=False)
with CLCache():
with CLCache() as cache:
img_bn = bn(img).realize()
print(img_bn)
assert len(CacheCollector.cache) == 3, f"optimizer didn't fold batchnorm, got {len(CacheCollector.cache)}"
assert cache.count == 3, f"optimizer didn't fold batchnorm, got {cache.count}"
def test_fold_conv_sgd(self):
with Tensor.train():
img = Tensor.ones(2,3,4,4)
c1 = nn.Conv2d(3,32,3)
opt = optim.SGD(get_parameters(c1))
with CLCache():
with CLCache() as cache:
opt.zero_grad()
c1(img).relu().sum().backward()
opt.step()
# TODO: this should be 4, but the sum output child stays around
# with pushing_permutes it can be 3
# TODO: broken with optim fixes
assert len(CacheCollector.cache) in [4,5,6], f"optimizer didn't fold conv-backward SGD, got {len(CacheCollector.cache)}"
assert cache.count in [4,5,6], f"optimizer didn't fold conv-backward SGD, got {cache.count}"
def test_fold_2convs_sgd(self):
with Tensor.train():
@ -239,74 +242,74 @@ class TestOpt(unittest.TestCase):
bn = nn.BatchNorm2d(32, track_running_stats=False)
# precache the bn
bn(c1(img)).relu().realize()
with CLCache():
with CLCache() as cache:
bn(c1(img)).relu().realize()
assert len(CacheCollector.cache) == 1, f"optimizer didn't fold conv-batchnorm at test time, got {len(CacheCollector.cache)}"
assert cache.count == 1, f"optimizer didn't fold conv-batchnorm at test time, got {cache.count}"
def test_fold_conv_batchnorm(self):
with Tensor.train():
img = Tensor.ones(1,3,8,8)
c1 = nn.Conv2d(3,32,3)
bn = nn.BatchNorm2d(32, track_running_stats=False)
with CLCache():
with CLCache() as cache:
img_conv = bn(c1(img)).relu().realize()
print(img_conv)
assert len(CacheCollector.cache) == 4, f"optimizer didn't fold conv-batchnorm, got {len(CacheCollector.cache)}"
assert cache.count == 4, f"optimizer didn't fold conv-batchnorm, got {cache.count}"
def test_fold_conv_elu(self):
img = Tensor.ones(1,4,8,8)
c1 = nn.Conv2d(4, 4, kernel_size=3)
c2 = nn.Conv2d(4, 4, kernel_size=3)
with CLCache():
with CLCache() as cache:
img_conv = img.sequential([c1, Tensor.elu, c2, Tensor.elu]).realize()
print(img_conv)
assert len(CacheCollector.cache) == 2, "optimizer didn't fold conv/elu"
assert cache.count == 2, "optimizer didn't fold conv/elu"
def test_fold_conv_relu(self):
img = Tensor.ones(1,4,8,8)
c1 = nn.Conv2d(4, 4, kernel_size=3)
c2 = nn.Conv2d(4, 4, kernel_size=3)
with CLCache():
with CLCache() as cache:
img_conv = img.sequential([c1, Tensor.relu, c2, Tensor.relu]).realize()
print(img_conv)
assert len(CacheCollector.cache) == 2, "optimizer didn't fold conv/relu"
assert cache.count == 2, "optimizer didn't fold conv/relu"
def test_fold_conv_relu_nobias(self):
img = Tensor.ones(1,4,8,8)
c1 = nn.Conv2d(4, 4, kernel_size=3, bias=False)
c2 = nn.Conv2d(4, 4, kernel_size=3, bias=False)
with CLCache():
with CLCache() as cache:
img_conv = img.sequential([c1, Tensor.relu, c2, Tensor.relu]).realize()
print(img_conv)
assert len(CacheCollector.cache) == 2, "optimizer didn't fold conv/relu"
assert cache.count == 2, "optimizer didn't fold conv/relu"
def test_permute_was_pushed(self):
a = Tensor.randn(16, 16, 16)
with CLCache(2):
with CLCache(2) as cache:
c = a.sum(2)
d = c.permute(1,0).contiguous()
d.realize()
cache_len = len(CacheCollector.cache)
cache_len = cache.count
np.testing.assert_allclose(a.numpy().sum(2).transpose(1,0), d.numpy(), rtol=1e-3, atol=1e-5)
if PUSH_PERMUTES: assert cache_len == 1, "permute wasn't pushed!"
def test_permute_was_pushed_through_contract_reshape(self):
a = Tensor.randn(4, 4, 4, 4, 4)
with CLCache(2):
with CLCache(2) as cache:
c = a.sum(-1)
d = c.reshape(16,16).permute(1,0).contiguous()
d.realize()
cache_len = len(CacheCollector.cache)
cache_len = cache.count
np.testing.assert_allclose(a.numpy().sum(-1).reshape(16,16).transpose(1,0), d.numpy(), rtol=1e-3, atol=1e-5)
if PUSH_PERMUTES: assert cache_len == 1, "permute wasn't pushed!"
def test_permute_was_pushed_through_contractw1s_reshape(self):
a = Tensor.randn(4, 4, 4, 4, 4)
with CLCache(2):
with CLCache(2) as cache:
c = a.sum(-1)
d = c.reshape(16,1,16).permute(2,1,0).contiguous()
d.realize()
cache_len = len(CacheCollector.cache)
cache_len = cache.count
np.testing.assert_allclose(a.numpy().sum(-1).reshape(16,1,16).transpose(2,1,0), d.numpy(), rtol=1e-3, atol=1e-5)
if PUSH_PERMUTES: assert cache_len == 1, "permute wasn't pushed!"
@ -315,35 +318,35 @@ class TestOpt(unittest.TestCase):
@unittest.skipIf(not PUSH_PERMUTES, "this test requires PUSH_PERMUTES")
def test_permute_was_pushed_through_expand_reshape(self):
a = Tensor.randn(16, 16, 16)
with CLCache():
with CLCache() as cache:
c = a.sum(2)
d = c.reshape(4,4,4,4).permute(2,3,0,1).contiguous()
d.realize()
cache_len = len(CacheCollector.cache)
cache_len = cache.count
np.testing.assert_allclose(a.numpy().sum(2).transpose(1,0).reshape(4,4,4,4), d.numpy(), rtol=1e-3, atol=1e-5)
if PUSH_PERMUTES: assert cache_len == 1, "permute wasn't pushed!"
@unittest.skipIf(PUSH_PERMUTES, "this test is broken with PUSH_PERMUTES")
def test_no_reduceop_rerun(self):
a = Tensor.randn(16, 16, 16)
with CLCache():
with CLCache() as cache:
c = a.sum(2)
d = a.sum(2).permute(1,0)
c.realize()
d.realize()
cache_len = len(CacheCollector.cache)
cache_len = cache.count
np.testing.assert_allclose(c.numpy().transpose(1,0), d.numpy(), rtol=1e-3, atol=1e-5)
assert cache_len == 1, "reduceop was rerun!"
@unittest.skipIf(PUSH_PERMUTES, "this test is broken with PUSH_PERMUTES")
def test_no_reduceop_rerun_alt(self):
a = Tensor.randn(16, 16, 16)
with CLCache():
with CLCache() as cache:
c = a.sum(2).permute(1,0)
d = a.sum(2)
c.realize()
d.realize()
cache_len = len(CacheCollector.cache)
cache_len = cache.count
np.testing.assert_allclose(c.numpy(), d.numpy().transpose(1,0), rtol=1e-3, atol=1e-5)
assert cache_len == 1, "reduceop was rerun!"

View File

@ -9,9 +9,8 @@ from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View
from tinygrad.shape.symbolic import MulNode, Variable, NumNode, Node
from tinygrad.tensor import Tensor
from tinygrad.engine.jit import CacheCollector
from tinygrad.engine.schedule import create_schedule
from tinygrad.engine.realize import run_schedule
from tinygrad.engine.realize import run_schedule, lower_schedule
from tinygrad.helpers import prod, Context, getenv
from tinygrad.dtype import DType, dtypes
from tinygrad.codegen.uops import UOpGraph
@ -20,9 +19,10 @@ class TestLinearizer(unittest.TestCase):
def test_arg_dedup(self):
a, b = Tensor.randn(4), Tensor.randn(4)
np_a, np_b = a.numpy(), b.numpy()
CacheCollector.start()
c = ((a.shrink(((0, 2),)) - a.shrink(((2, 4),))) - (b.shrink(((0, 2),)) - b.shrink(((2, 4),)))).realize()
rawbufs = CacheCollector.finish()[0].rawbufs
c = ((a.shrink(((0, 2),)) - a.shrink(((2, 4),))) - (b.shrink(((0, 2),)) - b.shrink(((2, 4),))))
lowered = list(lower_schedule(create_schedule([c.lazydata])))
for ei in lowered: ei.run()
rawbufs = lowered[-1].rawbufs
assert len(rawbufs) == 3 and set(rawbufs[1:]) == {a.lazydata.base.realized, b.lazydata.base.realized}
np_c = (np_a[:2] - np_a[2:]) - (np_b[:2] - np_b[2:])
np.testing.assert_allclose(np_c, c.numpy(), atol=1e-4, rtol=1e-4)

View File

@ -3,7 +3,7 @@ from collections import defaultdict
from typing import TYPE_CHECKING, Any, List, Optional, Dict, Tuple, ClassVar, NamedTuple
import importlib, inspect, functools, pathlib, time, ctypes, os
from tinygrad.helpers import ansilen, prod, getenv, colored, all_int, to_function_name, from_mv, flat_mv, diskcache_get, diskcache_put
from tinygrad.helpers import DEBUG, CACHECOLLECTING, BEAM, NOOPT, GlobalCounters
from tinygrad.helpers import DEBUG, BEAM, NOOPT, GlobalCounters
from tinygrad.shape.symbolic import Variable, sym_infer, sint
from tinygrad.ops import LazyOp, get_lazyop_info
from tinygrad.buffer import Buffer, BufferOptions
@ -44,11 +44,7 @@ class Runner:
self.op_estimate:sint = 0
self.mem_estimate:sint = 0
def exec(self, rawbufs:List[Buffer], var_vals:Optional[Dict[Variable, int]]=None) -> Optional[float]:
var_vals = var_vals if var_vals is not None else {}
from tinygrad.engine.jit import CacheCollector
et = self(rawbufs, var_vals)
if CACHECOLLECTING: CacheCollector.add(self, rawbufs, var_vals)
return et
return self(rawbufs, {} if var_vals is None else var_vals)
def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False, jit=False) -> Optional[float]:
raise NotImplementedError("override this")

View File

@ -1,33 +1,28 @@
from __future__ import annotations
from typing import Callable, List, Tuple, Dict, cast, Union, Optional, TypeVar, Generic
from typing import TypeVar, Generic, Callable, List, Tuple, Union, Dict, cast, Optional
import functools, itertools, operator
from tinygrad.nn.state import get_parameters
from tinygrad.dtype import DType
from tinygrad.helpers import DEBUG, merge_dicts, getenv, all_int, Context, GRAPH, flatten, GraphException
from tinygrad.device import Compiled, Runner, CompiledRunner, Buffer, BufferXfer, MultiDeviceJITGraph, Device
from dataclasses import dataclass
from tinygrad.tensor import Tensor
from tinygrad.lazy import LazyBuffer
from tinygrad.features.multi import MultiLazyBuffer
from tinygrad.helpers import flatten, merge_dicts, DEBUG, Context, GRAPH, BEAM, getenv, all_int, GraphException
from tinygrad.device import Buffer, Runner, CompiledRunner, BufferXfer, Compiled, MultiDeviceJITGraph, Device
from tinygrad.dtype import DType
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.symbolic import Variable, sint
from tinygrad.engine.realize import ExecItem
from tinygrad.engine.realize import ExecItem, capturing
from tinygrad.nn.state import get_parameters
from weakref import ref, WeakKeyDictionary
# TODO: these graph functions probably shouldn't exist here
def get_jit_stats(jit_cache: List[ExecItem]) -> Tuple[sint, int]:
return functools.reduce(operator.add, [ji.prg.op_estimate for ji in jit_cache if isinstance(ji.prg, CompiledRunner)], 0), \
functools.reduce(operator.add, [ji.prg.mem_estimate for ji in jit_cache if isinstance(ji.prg, CompiledRunner)], 0)
def get_input_replace(jit_cache: List[ExecItem], input_rawbuffers:List[Buffer]) -> Dict[Tuple[int, int], int]:
input_replace: Dict[Tuple[int, int], int] = {}
for j,ji in enumerate(jit_cache):
for i,a in enumerate(ji.rawbufs):
if a in input_rawbuffers:
input_replace[(j,i)] = input_rawbuffers.index(a)
return input_replace
def get_jc_idxs_with_updatable_launch_dims(jit_cache: List[ExecItem]) -> List[int]:
return [j for j,ji in enumerate(jit_cache) if isinstance(ji.prg, CompiledRunner) and ((ji.prg.global_size and not all_int(ji.prg.global_size)) or (ji.prg.local_size and not all_int(ji.prg.local_size)))] # noqa: E501
return [j for j,ji in enumerate(jit_cache) if isinstance(ji.prg, CompiledRunner) and \
((ji.prg.global_size and not all_int(ji.prg.global_size)) or (ji.prg.local_size and not all_int(ji.prg.local_size)))]
def get_jc_idxs_with_updatable_var_vals(jit_cache: List[ExecItem]) -> List[int]:
return [j for j,ji in enumerate(jit_cache) if isinstance(ji.prg, CompiledRunner) and ji.prg.vars]
def apply_graph_to_jit(jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]) -> List[ExecItem]:
# Split JIT cache into batches for faster graph execution.
# This allows the accelerator to run some batches while subsequent graphs are still being updated.
@ -68,7 +63,38 @@ def apply_graph_to_jit(jit_cache: List[ExecItem], input_rawbuffers: List[Buffer]
if len(current_batch) > 0: flush_batch()
return graphed_jit_cache
# *** JIT ***
def get_input_replace(jit_cache: List[ExecItem], input_rawbuffers:List[Buffer]) -> Dict[Tuple[int, int], int]:
input_replace: Dict[Tuple[int, int], int] = {}
for j,ji in enumerate(jit_cache):
for i,a in enumerate(ji.rawbufs):
if a in input_rawbuffers:
input_replace[(j,i)] = input_rawbuffers.index(a)
return input_replace
class PlaceHolder:
placeholders: WeakKeyDictionary[Buffer, PlaceHolder] = WeakKeyDictionary()
def __init__(self, buf:Buffer):
self.size, self.dtype, self.device, self.ref, self.bufid, self.options = buf.size, buf.dtype, buf.device, ref(buf), id(buf._buf), buf.options
def to_tuple(self): return (self.size, self.dtype, self.device, self.bufid, self.options)
def __hash__(self): return hash(self.to_tuple())
def __eq__(self, x): return isinstance(x, PlaceHolder) and self.to_tuple() == x.to_tuple()
@staticmethod
def create_if_needed(buf:Buffer) -> Union[PlaceHolder, Buffer]:
if found:=PlaceHolder.placeholders.get(buf, None): return found
if hasattr(buf, '_buf'): return buf
PlaceHolder.placeholders[buf] = ret = PlaceHolder(buf.ensure_allocated()) # TODO: do I need to allocate here?
return ret
def alloc_if_needed(self, buffer_cache: Dict[PlaceHolder, Buffer]) -> Buffer:
ret = self.ref()
if ret: return ret
if self not in buffer_cache: buffer_cache[self] = Buffer(self.device, self.size, self.dtype, options=self.options).allocate()
return buffer_cache[self]
@dataclass(frozen=True)
class WeakExecItem:
prg: Runner
rawbufs: List[Union[PlaceHolder, Buffer]]
ReturnType = TypeVar('ReturnType')
class TinyJit(Generic[ReturnType]):
@ -76,62 +102,56 @@ class TinyJit(Generic[ReturnType]):
self.fxn = fxn
self.reset()
def add(self, ei:ExecItem):
self._cc.append(WeakExecItem(ei.prg, [PlaceHolder.create_if_needed(buf) for buf in ei.rawbufs if buf is not None]))
def reset(self):
self._cc: List[WeakExecItem] = []
self.jit_cache: List[ExecItem] = []
self.input_replace: Dict[Tuple[int, int], int] = {}
self.cnt: int = 0
self.ret: Optional[ReturnType] = None
self.expected_vals: Optional[Tuple[Variable, ...]] = None
self.expected_name_sts_dtype_device: Optional[Tuple[Tuple[Union[int, str], ShapeTracker, DType, Union[str, Tuple[str, ...]]], ...]] = None
# add support for instance methods
def __get__(self, obj, objtype): return functools.partial(self.__call__, obj)
def __get__(self, obj, objtype): return functools.partial(self.__call__, obj) # add support for instance methods
def __call__(self, *args, **kwargs) -> ReturnType:
# all inputs (except const) are realized
input_tensors: Dict[Union[int, str], Tensor] = { cast(Union[int, str], k):v for k,v in itertools.chain(enumerate(args), sorted(kwargs.items())) if v.__class__ is Tensor } # noqa: E501
Tensor.corealize(input_tensors.values())
input_lbs: Dict[Union[int, str], Union[LazyBuffer, MultiLazyBuffer]] = {k:v.lazydata for k,v in input_tensors.items()}
expected_name_sts_dtype_device = tuple([(k, v.st.unbind()[0] if isinstance(v, LazyBuffer) else ShapeTracker.from_shape(v.shape), v.dtype, v.device) for k,v in input_lbs.items()]) #noqa: E501
# get rawbuffers
lbs: List[LazyBuffer] = [v for v in input_lbs.values() if isinstance(v, LazyBuffer)] + \
flatten([mlb.lbs for mlb in input_lbs.values() if isinstance(mlb, MultiLazyBuffer)])
input_tensors: List[Tuple[Union[int, str], Tensor]] = \
[(cast(Union[int, str], k),v) for k,v in itertools.chain(enumerate(args), sorted(kwargs.items())) if v.__class__ is Tensor]
Tensor.corealize([x[1] for x in input_tensors])
lbs: List[LazyBuffer] = flatten([v.lazydata.lbs for _,v in input_tensors])
expected_sts_var_dtype_device = [(*x.st.unbind(), x.dtype, x.device) for x in lbs]
input_rawbuffers: List[Buffer] = [v.base.realized for v in lbs if v.base.realized is not None]
assert len(set(input_rawbuffers)) == len(input_rawbuffers), "duplicate inputs to JIT"
var_vals: Dict[Variable, int] = merge_dicts([x[1] for x in expected_sts_var_dtype_device] + \
[dict(x.unbind() for x in itertools.chain(args, kwargs.values()) if isinstance(x, Variable))])
# get variables: they can either be in Tensors or passed in as arguments, and all must be bound. these are all global
var_vals: Dict[Variable, int] = merge_dicts([arg.st.var_vals for arg in lbs] + [dict(x.unbind() for x in itertools.chain(args, kwargs.values()) if isinstance(x, Variable))]) # noqa: E501
expected_vals = tuple(var_vals.keys())
expected_names, expected_lbs = [x[0] for x in input_tensors], [(x[0], tuple(x[1].keys()), x[2], x[3]) for x in expected_sts_var_dtype_device]
if self.cnt >= 2:
# jit exec
assert self.expected_vals == expected_vals and self.expected_name_sts_dtype_device is not None, "missing/mismatch of var_vals"
assert all(x[0] == y[0] and x[1].views == y[1].views and x[2] == y[2] and x[3] == y[3]
for x,y in zip(self.expected_name_sts_dtype_device, expected_name_sts_dtype_device)), \
f"mismatch of input tensors, expected {self.expected_name_sts_dtype_device} got {expected_name_sts_dtype_device}"
assert self.expected_names == expected_names and self.expected_lbs == expected_lbs, "args mismatch in JIT"
for (j,i),input_idx in self.input_replace.items(): self.jit_cache[j].rawbufs[i] = input_rawbuffers[input_idx]
if DEBUG >= 1: print(f"jit execs {len(self.jit_cache)} kernels")
for ji in self.jit_cache: ji.prg(cast(List[Buffer], ji.rawbufs), var_vals, wait=DEBUG>=2, jit=True)
for ei in self.jit_cache: ei.run(var_vals, jit=True)
elif self.cnt == 1:
# jit capture
self.expected_vals, self.expected_name_sts_dtype_device = expected_vals, expected_name_sts_dtype_device
CacheCollector.start(var_vals)
with Context(GRAPH=getenv("JITGRAPH", GRAPH.value)):
self.expected_names: List[Union[int, str]] = expected_names
self.expected_lbs: List[Tuple[ShapeTracker, Tuple[Variable, ...], DType, str]] = expected_lbs
with Context(GRAPH=getenv("JITGRAPH", GRAPH.value), BEAM=getenv("JITBEAM", BEAM.value)):
capturing.append(self)
self.ret = self.fxn(*args, **kwargs)
Tensor.corealize(get_parameters(self.ret))
self.jit_cache = CacheCollector.finish()
assert len(self.jit_cache) != 0, "didn't JIT anything!"
# TODO: reset doesn't work if we delete this
#del self.fxn
if DEBUG >= 1 and len(set(get_input_replace(self.jit_cache, input_rawbuffers).values())) != len(input_rawbuffers):
print("WARNING: some input tensors not found")
capturing.clear()
assert len(self._cc), "didn't JIT anything!"
buffer_cache: Dict[PlaceHolder, Buffer] = {}
self.jit_cache = \
[ExecItem(ei.prg, [x.alloc_if_needed(buffer_cache) if isinstance(x, PlaceHolder) else x for x in ei.rawbufs]) for ei in self._cc]
del self._cc
if DEBUG >= 1: print(f"JIT captured {len(self.jit_cache)} kernels with {len(input_rawbuffers)} inputs")
# Condense the items into a graph executor.
if getenv("JIT") != 2: self.jit_cache = apply_graph_to_jit(self.jit_cache, input_rawbuffers, var_vals)
self.input_replace = get_input_replace(self.jit_cache, input_rawbuffers)
if DEBUG >= 1 and len(set(self.input_replace.values())) != len(input_rawbuffers): print("WARNING: some input tensors not found")
elif self.cnt == 0:
# jit ignore
self.ret = self.fxn(*args, **kwargs)
@ -141,42 +161,4 @@ class TinyJit(Generic[ReturnType]):
for (j,i) in self.input_replace.keys(): self.jit_cache[j].rawbufs[i] = None
self.cnt += 1
return cast(ReturnType, self.ret)
class PlaceHolder:
def __init__(self, buf:Buffer):
self.size, self.dtype, self.device, self.ref, self.bufid, self.options = buf.size, buf.dtype, buf.device, ref(buf), id(buf._buf), buf.options
def to_tuple(self): return (self.size, self.dtype, self.device, self.bufid, self.options)
def __hash__(self): return hash(self.to_tuple())
def __eq__(self, x): return isinstance(x, PlaceHolder) and self.to_tuple() == x.to_tuple()
def alloc_if_needed(self, buffer_cache: Dict[PlaceHolder, Buffer]) -> Buffer:
ret = self.ref()
if ret: return ret
if self not in buffer_cache: buffer_cache[self] = Buffer(self.device, self.size, self.dtype, options=self.options).allocate()
return buffer_cache[self]
class _CacheCollector:
def __init__(self):
self.cache: Optional[List[Tuple[Runner, List[Union[Buffer, PlaceHolder]]]]] = None
def start(self, var_vals:Optional[Dict[Variable, int]]=None):
self.cache = []
self.placeholders: WeakKeyDictionary[Buffer, PlaceHolder] = WeakKeyDictionary()
self.var_vals = var_vals if var_vals is not None else {}
def add(self, prg, rawbufs:List[Buffer], var_vals:Dict[Variable, int]):
if self.cache is None: return
for k,v in var_vals.items(): assert k in self.var_vals and self.var_vals[k] == v, f"var_vals {k} mismatch {v} != {self.var_vals.get(k)}"
# Buffer optimization is allowed only for kernel operations. Avoids for copies (prevents parallelism) and syncs (incorrect buffer reuse).
if isinstance(prg, CompiledRunner):
for i in range(prg.outcount): self.placeholders[rawbufs[i]] = PlaceHolder(rawbufs[i])
self.cache.append((prg, [self.placeholders.get(x, x) if isinstance(x, Buffer) else x for x in rawbufs]))
def finish(self) -> List[ExecItem]:
if self.cache is None: return []
buffer_cache: Dict[PlaceHolder, Buffer] = {}
saved_cache, self.cache = self.cache, None
return [ExecItem(prg, [x.alloc_if_needed(buffer_cache) if isinstance(x, PlaceHolder) else x for x in pl]) for prg, pl in saved_cache]
CacheCollector = _CacheCollector()
return self.ret

View File

@ -10,8 +10,8 @@ from tinygrad.shape.symbolic import Variable
class ExecItem:
prg: Runner
rawbufs: List[Optional[Buffer]]
def run(self, var_vals:Optional[Dict[Variable, int]]=None):
self.prg.exec([cast(Buffer, x).ensure_allocated() for x in self.rawbufs], var_vals if var_vals is not None else {})
def run(self, var_vals:Optional[Dict[Variable, int]]=None, wait=False, jit=False):
self.prg([cast(Buffer, x).ensure_allocated() for x in self.rawbufs], var_vals if var_vals is not None else {}, wait=wait, jit=jit)
class CustomOp(Runner):
def __init__(self, fxn):
@ -38,5 +38,9 @@ def lower_schedule_item(si:ScheduleItem) -> Runner:
def lower_schedule(schedule:List[ScheduleItem]) -> Generator[ExecItem, None, None]:
while len(schedule): yield ExecItem(lower_schedule_item(si:=schedule.pop(0)), list(si.outputs+si.inputs))
capturing: List = [] # put classes with an add method in here
def run_schedule(schedule:List[ScheduleItem], var_vals:Optional[Dict[Variable, int]]=None):
for ei in lower_schedule(schedule): ei.run(var_vals)
for ei in lower_schedule(schedule):
if len(capturing): capturing[0].add(ei)
ei.run(var_vals)

View File

@ -52,6 +52,10 @@ class LazyBuffer:
@property
def base(self) -> LazyBuffer: return self._base if self._base is not None else self
# same API as multi
@property
def lbs(self) -> List[LazyBuffer]: return [self]
@staticmethod
def loadop(op, shape:Tuple[sint,...], dtype:DType, device:str, arg=None, src:Tuple[LazyBuffer, ...]=(), enable_cache=False) -> LazyBuffer:
assert isinstance(src, tuple)