mirror of https://github.com/commaai/tinygrad.git
pickle main pattern matcher [run_process_replay] (#6827)
* pickle main pattern matcher [run_process_replay] * del line
This commit is contained in:
parent
d726eb6f48
commit
8a93c48901
|
@ -20,6 +20,10 @@ class TestPickle(unittest.TestCase):
|
|||
pm2 = pickle.loads(pm_str)
|
||||
self.assertEqual(pm2.rewrite(sink).key, tt.key)
|
||||
|
||||
def test_pickle_main_pattern_matcher(self):
|
||||
from tinygrad.codegen.uopgraph import sym
|
||||
pickle.dumps(sym)
|
||||
|
||||
def test_pickle_realized_tensor(self):
|
||||
t = Tensor.rand(10, 10).realize()
|
||||
st = pickle.dumps(t)
|
||||
|
|
|
@ -17,11 +17,13 @@ class TestContextVars(unittest.TestCase):
|
|||
_TMP = ContextVar("_TMP", 5)
|
||||
self.assertEqual(_TMP.value, 5)
|
||||
|
||||
@unittest.expectedFailure
|
||||
def test_multiple_creation_ignored(self):
|
||||
_TMP2 = ContextVar("_TMP2", 1)
|
||||
_TMP2 = ContextVar("_TMP2", 2)
|
||||
self.assertEqual(_TMP2.value, 1)
|
||||
|
||||
@unittest.expectedFailure
|
||||
def test_new_var_inside_context(self):
|
||||
# Creating a _new_ variable inside a context should not have any effect on its scope (?)
|
||||
with Context(VARIABLE=1):
|
||||
|
@ -29,6 +31,7 @@ class TestContextVars(unittest.TestCase):
|
|||
_TMP3 = ContextVar("_TMP3", 2)
|
||||
self.assertEqual(_TMP3.value, 1)
|
||||
|
||||
@unittest.expectedFailure
|
||||
def test_value_accross_modules(self):
|
||||
# Mocking module import by invoking the code but not in our globals().
|
||||
exec('from tinygrad.helpers import ContextVar;C = ContextVar("C", 13)', {}) # pylint:disable=exec-used
|
||||
|
@ -36,6 +39,7 @@ class TestContextVars(unittest.TestCase):
|
|||
C = ContextVar("C", 0)
|
||||
self.assertEqual(C.value, 13)
|
||||
|
||||
@unittest.expectedFailure
|
||||
def test_assignment_across_modules(self):
|
||||
B = ContextVar("B", 1)
|
||||
# local assignment
|
||||
|
@ -56,6 +60,7 @@ class TestContextVars(unittest.TestCase):
|
|||
with Context(SOMETHING_ELSE=1):
|
||||
pass
|
||||
|
||||
@unittest.expectedFailure
|
||||
def test_inside_context_assignment(self):
|
||||
with Context(VARIABLE=4):
|
||||
# What you can and cannot do inside a context.
|
||||
|
@ -70,6 +75,7 @@ class TestContextVars(unittest.TestCase):
|
|||
# Related to 2. above. Note that VARIABLE is back to 0 again as expected.
|
||||
self.assertEqual(VARIABLE.value, 0)
|
||||
|
||||
@unittest.expectedFailure
|
||||
def test_new_var_inside_context_other_module(self):
|
||||
with Context(VARIABLE=1):
|
||||
_NEW2 = ContextVar("_NEW2", 0)
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
from __future__ import annotations
|
||||
import os, functools, platform, time, re, contextlib, operator, hashlib, pickle, sqlite3, tempfile, pathlib, string, ctypes, sys, gzip
|
||||
import itertools, urllib.request, subprocess, shutil, math, json, contextvars, types, copyreg, inspect
|
||||
import itertools, urllib.request, subprocess, shutil, math, json, contextvars, types, copyreg, inspect, importlib
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Tuple, Union, List, ClassVar, Optional, Iterable, Any, TypeVar, TYPE_CHECKING, Callable, Sequence
|
||||
if TYPE_CHECKING: # TODO: remove this and import TypeGuard from typing once minimum python supported version is 3.10
|
||||
|
@ -103,11 +103,10 @@ class ContextVar:
|
|||
_cache: ClassVar[Dict[str, ContextVar]] = {}
|
||||
value: int
|
||||
key: str
|
||||
def __new__(cls, key, default_value):
|
||||
if key in ContextVar._cache: return ContextVar._cache[key]
|
||||
instance = ContextVar._cache[key] = super().__new__(cls)
|
||||
instance.value, instance.key = getenv(key, default_value), key
|
||||
return instance
|
||||
def __init__(self, key, default_value):
|
||||
assert key not in ContextVar._cache, f"attempt to recreate ContextVar {key}"
|
||||
ContextVar._cache[key] = self
|
||||
self.value, self.key = getenv(key, default_value), key
|
||||
def __bool__(self): return bool(self.value)
|
||||
def __ge__(self, x): return self.value >= x
|
||||
def __gt__(self, x): return self.value > x
|
||||
|
@ -384,3 +383,6 @@ def _serialize_code(code:types.CodeType):
|
|||
'constants', 'names', 'varnames', 'filename', 'name', 'firstlineno', 'lnotab', 'freevars', 'cellvars']
|
||||
return _reconstruct_code, tuple(code.__getattribute__('co_'+x.replace('codestring', 'code').replace('constants', 'consts')) for x in args)
|
||||
copyreg.pickle(types.CodeType, _serialize_code)
|
||||
|
||||
def _serialize_module(module:types.ModuleType): return importlib.import_module, (module.__name__,)
|
||||
copyreg.pickle(types.ModuleType, _serialize_module)
|
||||
|
|
|
@ -490,9 +490,10 @@ def deconstruct_function(fxn:Callable) -> Tuple:
|
|||
new_globals = {k:v for k,v in fxn.__globals__.items() if k in fxn.__code__.co_names}
|
||||
for co in fxn.__code__.co_consts:
|
||||
if isinstance(co, types.CodeType): new_globals.update({k:v for k,v in fxn.__globals__.items() if k in co.co_names})
|
||||
new_code_obj = pickle.loads(pickle.dumps(fxn.__code__)) if getenv("TEST_PICKLE") else fxn.__code__ # NOTE: optional round trip through pickle!
|
||||
# NOTE: optional round trip through pickle!
|
||||
assert fxn.__closure__ is None, "closures are not supported in pattern matchers"
|
||||
return new_code_obj, new_globals, fxn.__name__, fxn.__defaults__
|
||||
ret = fxn.__code__, new_globals, fxn.__name__, fxn.__defaults__
|
||||
return pickle.loads(pickle.dumps(ret)) if getenv("TEST_PICKLE") else ret
|
||||
|
||||
class PatternMatcher:
|
||||
def __init__(self, patterns:List[Tuple[UPat, Callable]]):
|
||||
|
|
Loading…
Reference in New Issue