lazy.py: cache consts (#3577)

* lazy.py: cache consts

* add regression test

* always always cache const

* bump by 1
This commit is contained in:
George Hotz 2024-03-02 03:50:05 -08:00 committed by GitHub
parent fb8acd1851
commit 41f0a25b53
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 8 additions and 1 deletions

View File

@@ -76,7 +76,7 @@ class TestRealWorld(unittest.TestCase):
@TinyJit
def test(t): return model(t, 0).realize()
# TODO: test first token vs rest properly
helper_test("test_llama", lambda: (Tensor([[1,2,3,4]]),), test, 0.27 if CI else 14.9, 191 if CI else 719, all_jitted=True)
helper_test("test_llama", lambda: (Tensor([[1,2,3,4]]),), test, 0.27 if CI else 14.9, 192 if CI else 719, all_jitted=True)
@unittest.skipIf(Device.DEFAULT in ["LLVM", "GPU"] and CI, "too long on CI LLVM, GPU requires cl_khr_fp16")
def test_gpt2(self):

View File

@@ -423,5 +423,11 @@ class TestSchedule(unittest.TestCase):
out = x + y
check_schedule(out, 2) # TODO: this should be 1
def test_const_no_recompute(self):
  """Regression test for identical constant expressions.

  Builds the expression ``Tensor(2) + Tensor(2)`` twice, makes each result
  contiguous, adds the two together and hands the final tensor to
  ``check_schedule`` with an expected value of 2.
  """
  # Two separately constructed, but textually identical, constant sums.
  first_sum = Tensor(2) + Tensor(2)
  second_sum = Tensor(2) + Tensor(2)
  first_contig = first_sum.contiguous()
  second_contig = second_sum.contiguous()
  combined = first_contig + second_contig
  # NOTE(review): 2 is presumably the expected number of scheduled kernels;
  # confirm against check_schedule's definition.
  check_schedule(combined, 2)
if __name__ == '__main__':
unittest.main(verbosity=2)

View File

@@ -13,6 +13,7 @@ lazycache: Dict[Any, ReferenceType[LazyBuffer]] = {}
def create_lazybuffer(device:str, st:ShapeTracker, dtype:DType, op:Optional[Op]=None, arg:Any=None, srcs:Tuple[LazyBuffer, ...]=(),
base:Optional[LazyBuffer]=None, enable_cache=bool(getenv("LAZYCACHE", 1))):
if st.size == 0 and op not in {LoadOps.SYNC, LoadOps.WAIT}: op, arg, srcs, base = LoadOps.CONST, 0, (), None
if op == LoadOps.CONST: enable_cache = True
cache_key = (device, st, dtype, op, arg, tuple(ref(x) for x in srcs)) if base is None else (st, ref(base))
if (rret := lazycache.get(cache_key, None)): return cast(LazyBuffer, rret()) # NOTE: this should always be a live reference