From 41f0a25b5324b6d4b43abd290da6d3a15c984e83 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Sat, 2 Mar 2024 03:50:05 -0800 Subject: [PATCH] lazy.py: cache consts (#3577) * lazy.py: cache consts * add regression test * always always cache const * bump by 1 --- test/models/test_real_world.py | 2 +- test/test_schedule.py | 6 ++++++ tinygrad/lazy.py | 1 + 3 files changed, 8 insertions(+), 1 deletion(-) diff --git a/test/models/test_real_world.py b/test/models/test_real_world.py index 935987b8..c40309f6 100644 --- a/test/models/test_real_world.py +++ b/test/models/test_real_world.py @@ -76,7 +76,7 @@ class TestRealWorld(unittest.TestCase): @TinyJit def test(t): return model(t, 0).realize() # TODO: test first token vs rest properly - helper_test("test_llama", lambda: (Tensor([[1,2,3,4]]),), test, 0.27 if CI else 14.9, 191 if CI else 719, all_jitted=True) + helper_test("test_llama", lambda: (Tensor([[1,2,3,4]]),), test, 0.27 if CI else 14.9, 192 if CI else 719, all_jitted=True) @unittest.skipIf(Device.DEFAULT in ["LLVM", "GPU"] and CI, "too long on CI LLVM, GPU requires cl_khr_fp16") def test_gpt2(self): diff --git a/test/test_schedule.py b/test/test_schedule.py index 733cf0f3..67fc0d58 100644 --- a/test/test_schedule.py +++ b/test/test_schedule.py @@ -423,5 +423,11 @@ class TestSchedule(unittest.TestCase): out = x + y check_schedule(out, 2) # TODO: this should be 1 + def test_const_no_recompute(self): + x = Tensor(2) + Tensor(2) + y = Tensor(2) + Tensor(2) + out = x.contiguous() + y.contiguous() + check_schedule(out, 2) + if __name__ == '__main__': unittest.main(verbosity=2) diff --git a/tinygrad/lazy.py b/tinygrad/lazy.py index 0fdd948b..b142f5b0 100644 --- a/tinygrad/lazy.py +++ b/tinygrad/lazy.py @@ -13,6 +13,7 @@ lazycache: Dict[Any, ReferenceType[LazyBuffer]] = {} def create_lazybuffer(device:str, st:ShapeTracker, dtype:DType, op:Optional[Op]=None, arg:Any=None, srcs:Tuple[LazyBuffer, ...]=(), base:Optional[LazyBuffer]=None, enable_cache=bool(getenv("LAZYCACHE", 1))): if st.size == 0 and op not in {LoadOps.SYNC, LoadOps.WAIT}: op, arg, srcs, base = LoadOps.CONST, 0, (), None + if op == LoadOps.CONST: enable_cache = True cache_key = (device, st, dtype, op, arg, tuple(ref(x) for x in srcs)) if base is None else (st, ref(base)) if (rret := lazycache.get(cache_key, None)): return cast(LazyBuffer, rret()) # NOTE: this should always be a live reference