mirror of https://github.com/commaai/tinygrad.git
lazy.py: cache consts (#3577)
* lazy.py: cache consts * add regression test * always always cache const * bump by 1
This commit is contained in:
parent
fb8acd1851
commit
41f0a25b53
|
@@ -76,7 +76,7 @@ class TestRealWorld(unittest.TestCase):
|
|||
@TinyJit
|
||||
def test(t): return model(t, 0).realize()
|
||||
# TODO: test first token vs rest properly
|
||||
- helper_test("test_llama", lambda: (Tensor([[1,2,3,4]]),), test, 0.27 if CI else 14.9, 191 if CI else 719, all_jitted=True)
|
||||
+ helper_test("test_llama", lambda: (Tensor([[1,2,3,4]]),), test, 0.27 if CI else 14.9, 192 if CI else 719, all_jitted=True)
|
||||
|
||||
@unittest.skipIf(Device.DEFAULT in ["LLVM", "GPU"] and CI, "too long on CI LLVM, GPU requires cl_khr_fp16")
|
||||
def test_gpt2(self):
|
||||
|
|
|
@@ -423,5 +423,11 @@ class TestSchedule(unittest.TestCase):
|
|||
out = x + y
|
||||
check_schedule(out, 2) # TODO: this should be 1
|
||||
|
||||
def test_const_no_recompute(self):
|
||||
x = Tensor(2) + Tensor(2)
|
||||
y = Tensor(2) + Tensor(2)
|
||||
out = x.contiguous() + y.contiguous()
|
||||
check_schedule(out, 2)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main(verbosity=2)
|
||||
|
|
|
@@ -13,6 +13,7 @@ lazycache: Dict[Any, ReferenceType[LazyBuffer]] = {}
|
|||
def create_lazybuffer(device:str, st:ShapeTracker, dtype:DType, op:Optional[Op]=None, arg:Any=None, srcs:Tuple[LazyBuffer, ...]=(),
|
||||
base:Optional[LazyBuffer]=None, enable_cache=bool(getenv("LAZYCACHE", 1))):
|
||||
if st.size == 0 and op not in {LoadOps.SYNC, LoadOps.WAIT}: op, arg, srcs, base = LoadOps.CONST, 0, (), None
|
||||
if op == LoadOps.CONST: enable_cache = True
|
||||
|
||||
cache_key = (device, st, dtype, op, arg, tuple(ref(x) for x in srcs)) if base is None else (st, ref(base))
|
||||
if (rret := lazycache.get(cache_key, None)): return cast(LazyBuffer, rret()) # NOTE: this should always be a live reference
|
||||
|
|
Loading…
Reference in New Issue