fix UnboundLocalError when running Compiler with DISABLE_COMPILER_CACHE (#3296)

This commit is contained in:
Felix Wu 2024-02-02 03:12:33 +01:00 committed by GitHub
parent a5bf4afc1a
commit 021eea3a52
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 24 additions and 2 deletions

View File

@ -1,6 +1,9 @@
#!/usr/bin/env python
import unittest
from tinygrad.device import Device
from unittest.mock import patch
import os
from tinygrad.device import Device, Compiler
from tinygrad.helpers import diskcache_get, diskcache_put, getenv
class TestDevice(unittest.TestCase):
def test_canonicalize(self):
@ -15,5 +18,24 @@ class TestDevice(unittest.TestCase):
assert Device.canonicalize("GPU:2") == "GPU:2"
assert Device.canonicalize("disk:/dev/shm/test") == "DISK:/dev/shm/test"
class MockCompiler(Compiler):
def __init__(self, key): super().__init__(key)
def compile(self, src) -> bytes: return src.encode()
class TestCompiler(unittest.TestCase):
def test_compile_cached(self):
diskcache_put("key", "123", None) # clear cache
getenv.cache_clear()
with patch.dict(os.environ, {"DISABLE_COMPILER_CACHE": "0"}, clear=True):
assert MockCompiler("key").compile_cached("123") == str.encode("123")
assert diskcache_get("key", "123") == str.encode("123")
def test_compile_cached_disabled(self):
diskcache_put("disabled_key", "123", None) # clear cache
getenv.cache_clear()
with patch.dict(os.environ, {"DISABLE_COMPILER_CACHE": "1"}, clear=True):
assert MockCompiler("disabled_key").compile_cached("123") == str.encode("123")
assert diskcache_get("disabled_key", "123") is None
if __name__ == "__main__":
unittest.main()

View File

@ -268,7 +268,7 @@ class Compiler:
def render(self, name:str, uops) -> str: raise NotImplementedError("need a render function")
def compile(self, src:str) -> bytes: raise NotImplementedError("need a compile function")
def compile_cached(self, src:str) -> bytes:
if self.cachekey is not None: lib = diskcache_get(self.cachekey, src)
lib = diskcache_get(self.cachekey, src) if self.cachekey is not None else None
if lib is None:
lib = self.compile(src)
if self.cachekey is not None: diskcache_put(self.cachekey, src, lib)