mirror of https://github.com/commaai/tinygrad.git
fix UnboundLocalError when running Compiler with DISABLE_COMPILER_CACHE (#3296)
This commit is contained in:
parent
a5bf4afc1a
commit
021eea3a52
|
@ -1,6 +1,9 @@
|
|||
#!/usr/bin/env python
|
||||
import unittest
|
||||
from tinygrad.device import Device
|
||||
from unittest.mock import patch
|
||||
import os
|
||||
from tinygrad.device import Device, Compiler
|
||||
from tinygrad.helpers import diskcache_get, diskcache_put, getenv
|
||||
|
||||
class TestDevice(unittest.TestCase):
|
||||
def test_canonicalize(self):
|
||||
|
@ -15,5 +18,24 @@ class TestDevice(unittest.TestCase):
|
|||
assert Device.canonicalize("GPU:2") == "GPU:2"
|
||||
assert Device.canonicalize("disk:/dev/shm/test") == "DISK:/dev/shm/test"
|
||||
|
||||
class MockCompiler(Compiler):
|
||||
def __init__(self, key): super().__init__(key)
|
||||
def compile(self, src) -> bytes: return src.encode()
|
||||
|
||||
class TestCompiler(unittest.TestCase):
|
||||
def test_compile_cached(self):
|
||||
diskcache_put("key", "123", None) # clear cache
|
||||
getenv.cache_clear()
|
||||
with patch.dict(os.environ, {"DISABLE_COMPILER_CACHE": "0"}, clear=True):
|
||||
assert MockCompiler("key").compile_cached("123") == str.encode("123")
|
||||
assert diskcache_get("key", "123") == str.encode("123")
|
||||
|
||||
def test_compile_cached_disabled(self):
|
||||
diskcache_put("disabled_key", "123", None) # clear cache
|
||||
getenv.cache_clear()
|
||||
with patch.dict(os.environ, {"DISABLE_COMPILER_CACHE": "1"}, clear=True):
|
||||
assert MockCompiler("disabled_key").compile_cached("123") == str.encode("123")
|
||||
assert diskcache_get("disabled_key", "123") is None
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
|
@ -268,7 +268,7 @@ class Compiler:
|
|||
def render(self, name:str, uops) -> str: raise NotImplementedError("need a render function")
|
||||
def compile(self, src:str) -> bytes: raise NotImplementedError("need a compile function")
|
||||
def compile_cached(self, src:str) -> bytes:
|
||||
if self.cachekey is not None: lib = diskcache_get(self.cachekey, src)
|
||||
lib = diskcache_get(self.cachekey, src) if self.cachekey is not None else None
|
||||
if lib is None:
|
||||
lib = self.compile(src)
|
||||
if self.cachekey is not None: diskcache_put(self.cachekey, src, lib)
|
||||
|
|
Loading…
Reference in New Issue