pickle main pattern matcher [run_process_replay] (#6827)

* pickle main pattern matcher [run_process_replay]

* del line
This commit is contained in:
George Hotz 2024-10-01 13:58:42 +08:00 committed by GitHub
parent d726eb6f48
commit 8a93c48901
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 21 additions and 8 deletions

View File

@ -20,6 +20,10 @@ class TestPickle(unittest.TestCase):
pm2 = pickle.loads(pm_str)
self.assertEqual(pm2.rewrite(sink).key, tt.key)
def test_pickle_main_pattern_matcher(self):
from tinygrad.codegen.uopgraph import sym
pickle.dumps(sym)
def test_pickle_realized_tensor(self):
t = Tensor.rand(10, 10).realize()
st = pickle.dumps(t)

View File

@ -17,11 +17,13 @@ class TestContextVars(unittest.TestCase):
_TMP = ContextVar("_TMP", 5)
self.assertEqual(_TMP.value, 5)
@unittest.expectedFailure
def test_multiple_creation_ignored(self):
_TMP2 = ContextVar("_TMP2", 1)
_TMP2 = ContextVar("_TMP2", 2)
self.assertEqual(_TMP2.value, 1)
@unittest.expectedFailure
def test_new_var_inside_context(self):
# Creating a _new_ variable inside a context should not have any effect on its scope (?)
with Context(VARIABLE=1):
@ -29,6 +31,7 @@ class TestContextVars(unittest.TestCase):
_TMP3 = ContextVar("_TMP3", 2)
self.assertEqual(_TMP3.value, 1)
@unittest.expectedFailure
def test_value_accross_modules(self):
# Mocking module import by invoking the code but not in our globals().
exec('from tinygrad.helpers import ContextVar;C = ContextVar("C", 13)', {}) # pylint:disable=exec-used
@ -36,6 +39,7 @@ class TestContextVars(unittest.TestCase):
C = ContextVar("C", 0)
self.assertEqual(C.value, 13)
@unittest.expectedFailure
def test_assignment_across_modules(self):
B = ContextVar("B", 1)
# local assignment
@ -56,6 +60,7 @@ class TestContextVars(unittest.TestCase):
with Context(SOMETHING_ELSE=1):
pass
@unittest.expectedFailure
def test_inside_context_assignment(self):
with Context(VARIABLE=4):
# What you can and cannot do inside a context.
@ -70,6 +75,7 @@ class TestContextVars(unittest.TestCase):
# Related to 2. above. Note that VARIABLE is back to 0 again as expected.
self.assertEqual(VARIABLE.value, 0)
@unittest.expectedFailure
def test_new_var_inside_context_other_module(self):
with Context(VARIABLE=1):
_NEW2 = ContextVar("_NEW2", 0)

View File

@ -1,6 +1,6 @@
from __future__ import annotations
import os, functools, platform, time, re, contextlib, operator, hashlib, pickle, sqlite3, tempfile, pathlib, string, ctypes, sys, gzip
import itertools, urllib.request, subprocess, shutil, math, json, contextvars, types, copyreg, inspect
import itertools, urllib.request, subprocess, shutil, math, json, contextvars, types, copyreg, inspect, importlib
from dataclasses import dataclass
from typing import Dict, Tuple, Union, List, ClassVar, Optional, Iterable, Any, TypeVar, TYPE_CHECKING, Callable, Sequence
if TYPE_CHECKING: # TODO: remove this and import TypeGuard from typing once minimum python supported version is 3.10
@ -103,11 +103,10 @@ class ContextVar:
_cache: ClassVar[Dict[str, ContextVar]] = {}
value: int
key: str
def __new__(cls, key, default_value):
if key in ContextVar._cache: return ContextVar._cache[key]
instance = ContextVar._cache[key] = super().__new__(cls)
instance.value, instance.key = getenv(key, default_value), key
return instance
def __init__(self, key, default_value):
assert key not in ContextVar._cache, f"attempt to recreate ContextVar {key}"
ContextVar._cache[key] = self
self.value, self.key = getenv(key, default_value), key
def __bool__(self): return bool(self.value)
def __ge__(self, x): return self.value >= x
def __gt__(self, x): return self.value > x
@ -384,3 +383,6 @@ def _serialize_code(code:types.CodeType):
'constants', 'names', 'varnames', 'filename', 'name', 'firstlineno', 'lnotab', 'freevars', 'cellvars']
return _reconstruct_code, tuple(code.__getattribute__('co_'+x.replace('codestring', 'code').replace('constants', 'consts')) for x in args)
copyreg.pickle(types.CodeType, _serialize_code)
def _serialize_module(module:types.ModuleType): return importlib.import_module, (module.__name__,)
copyreg.pickle(types.ModuleType, _serialize_module)

View File

@ -490,9 +490,10 @@ def deconstruct_function(fxn:Callable) -> Tuple:
new_globals = {k:v for k,v in fxn.__globals__.items() if k in fxn.__code__.co_names}
for co in fxn.__code__.co_consts:
if isinstance(co, types.CodeType): new_globals.update({k:v for k,v in fxn.__globals__.items() if k in co.co_names})
new_code_obj = pickle.loads(pickle.dumps(fxn.__code__)) if getenv("TEST_PICKLE") else fxn.__code__ # NOTE: optional round trip through pickle!
# NOTE: optional round trip through pickle!
assert fxn.__closure__ is None, "closures are not supported in pattern matchers"
return new_code_obj, new_globals, fxn.__name__, fxn.__defaults__
ret = fxn.__code__, new_globals, fxn.__name__, fxn.__defaults__
return pickle.loads(pickle.dumps(ret)) if getenv("TEST_PICKLE") else ret
class PatternMatcher:
def __init__(self, patterns:List[Tuple[UPat, Callable]]):