mirror of https://github.com/commaai/tinygrad.git
add pickle support for pattern matchers [run_process_replay] (#6816)
* add pickle support for pattern matchers [run_process_replay] * cleaner and all * no closures * fix tests * revert that * final * cleaner * python 3.8 fix * add round trip back * this * waste lines on this. that's the final line count * max print better * more targeted fix * regrettably add 3.8 support
This commit is contained in:
parent
f59517754e
commit
0f28e93224
|
@ -151,8 +151,8 @@ jobs:
|
|||
PYTHONPATH=$GITHUB_WORKSPACE BS=2 STEPS=10 python beautiful_mnist.py
|
||||
- name: Test DEBUG
|
||||
run: DEBUG=100 python3 -c "from tinygrad import Tensor; N = 1024; a, b = Tensor.rand(N, N), Tensor.rand(N, N); c = (a.reshape(N, 1, N) * b.T.reshape(1, N, N)).sum(axis=2); print((c.numpy() - (a.numpy() @ b.numpy())).mean())"
|
||||
- name: Repo line count <9800 lines
|
||||
run: MAX_LINE_COUNT=9800 python sz.py
|
||||
- name: Repo line count <= 9999 lines
|
||||
run: MAX_LINE_COUNT=9999 python sz.py
|
||||
|
||||
testopencl:
|
||||
strategy:
|
||||
|
|
2
sz.py
2
sz.py
|
@ -74,4 +74,4 @@ if __name__ == "__main__":
|
|||
total_lines = sum([x[1] for x in table])
|
||||
print(f"\ntotal line count: {total_lines}")
|
||||
max_line_count = int(os.getenv("MAX_LINE_COUNT", "-1"))
|
||||
assert max_line_count == -1 or total_lines < max_line_count, f"OVER {max_line_count} LINES"
|
||||
assert max_line_count == -1 or total_lines <= max_line_count, f"OVER {max_line_count} LINES"
|
||||
|
|
|
@ -1,10 +1,25 @@
|
|||
import unittest, pickle
|
||||
import unittest, pickle, types
|
||||
import numpy as np
|
||||
from test.helpers import assert_equiv_uops
|
||||
from tinygrad import Tensor, TinyJit, Variable
|
||||
from tinygrad import Tensor, TinyJit, Variable, dtypes
|
||||
from tinygrad.engine.schedule import create_schedule
|
||||
from tinygrad.ops import PatternMatcher, UPat, UOp
|
||||
|
||||
class TestPickle(unittest.TestCase):
|
||||
def test_pickle_code_object(self):
  """A lambda's code object survives a pickle round trip (via the copyreg hook in helpers)."""
  double = lambda x: x * 2  # noqa: E731
  payload = pickle.dumps(double.__code__)
  restored = types.FunctionType(pickle.loads(payload), globals())
  self.assertEqual(restored(2), 4)
|
||||
|
||||
def test_pickle_pattern_matcher(self):
  """A PatternMatcher built from a lambda rule pickles, and the copy rewrites identically."""
  matcher = PatternMatcher([(UPat.cvar('x'), lambda x: x*2)])
  sink = UOp.const(dtypes.int, 2)
  # compute the expected rewrite before the round trip
  expected = matcher.rewrite(sink)
  restored = pickle.loads(pickle.dumps(matcher))
  self.assertEqual(restored.rewrite(sink).key, expected.key)
|
||||
|
||||
def test_pickle_realized_tensor(self):
|
||||
t = Tensor.rand(10, 10).realize()
|
||||
st = pickle.dumps(t)
|
||||
|
|
|
@ -11,6 +11,7 @@ class TestPatternMatcher(unittest.TestCase):
|
|||
self.assertEqual(matcher.rewrite(c1), c1)
|
||||
self.assertEqual(matcher.rewrite(c2), None)
|
||||
|
||||
@unittest.skip("closures aren't supported on pattern matchers")
|
||||
def test_match_sz_0(self):
|
||||
match_cnt = 0
|
||||
def fxn(x):
|
||||
|
@ -25,6 +26,19 @@ class TestPatternMatcher(unittest.TestCase):
|
|||
c1 = matcher.rewrite(c1)
|
||||
self.assertEqual(match_cnt, 1)
|
||||
|
||||
def test_match_sz_0_ctx(self):
  """A ctx-taking rule fires once: its result gains a src, so the zero-src pattern stops matching."""
  def record_fn(ctx, x):
    ctx.append(True)
    assert len(x.src) == 0
    return UOp(UOps.CONST, src=(UOp(UOps.CONST),))
  matcher = PatternMatcher([(UPat(UOps.CONST, src=(), name="x"), record_fn)])
  node = UOp(UOps.CONST, dtypes.float, arg=1.0)
  ctx = []
  node = matcher.rewrite(node, ctx)
  # second rewrite shouldn't match anything
  node = matcher.rewrite(node, ctx)
  self.assertEqual(len(ctx), 1)
|
||||
|
||||
def test_uop(self):
|
||||
matcher = PatternMatcher([(UPat(UOps.CONST, name="x"), lambda x: x)])
|
||||
c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
from __future__ import annotations
|
||||
import os, functools, platform, time, re, contextlib, operator, hashlib, pickle, sqlite3, tempfile, pathlib, string, ctypes, sys, gzip
|
||||
import itertools, urllib.request, subprocess, shutil, math, json, contextvars
|
||||
import itertools, urllib.request, subprocess, shutil, math, json, contextvars, types, copyreg, inspect
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Tuple, Union, List, ClassVar, Optional, Iterable, Any, TypeVar, TYPE_CHECKING, Callable, Sequence
|
||||
if TYPE_CHECKING: # TODO: remove this and import TypeGuard from typing once minimum python supported version is 3.10
|
||||
|
@ -373,3 +373,14 @@ def pretty_print(x:Any, rep:Callable, srcfn=lambda x: x.src, cache=None, d=0)->s
|
|||
if (cx:=cache.setdefault(x, [0,0,False]))[2]: return f"{' '*d} x{cx[0]}"
|
||||
cx[2], srcs = True, ('None' if srcfn(x) is None else ''.join(f'\n{pretty_print(s, rep, srcfn, cache, d+2)},' for s in srcfn(x)))
|
||||
return f"{' '*d}{f'x{cx[0]}:=' * (cx[1]>1)}{rep(x)}" % srcs
|
||||
|
||||
# *** universal support for code object pickling

def _reconstruct_code(*args):
  """Rebuild a types.CodeType from the constructor arguments saved by _serialize_code."""
  return types.CodeType(*args)

def _serialize_code(code: types.CodeType):
  """copyreg reducer: map a code object to (_reconstruct_code, constructor_args)."""
  # NOTE: this works in Python 3.8 and up
  if sys.version_info >= (3, 10):
    # the CodeType constructor signature is introspectable on 3.10+
    arg_names = inspect.signature(types.CodeType).parameters
  else:
    arg_names = ['argcount', 'posonlyargcount', 'kwonlyargcount', 'nlocals', 'stacksize', 'flags', 'codestring',
                 'constants', 'names', 'varnames', 'filename', 'name', 'firstlineno', 'lnotab', 'freevars', 'cellvars']
  # two constructor parameter names don't follow the co_<name> attribute naming exactly
  attrs = ('co_' + n.replace('codestring', 'code').replace('constants', 'consts') for n in arg_names)
  return _reconstruct_code, tuple(getattr(code, a) for a in attrs)
copyreg.pickle(types.CodeType, _serialize_code)
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
from __future__ import annotations
|
||||
from typing import Any, List, Optional, Set, Union, Tuple, Dict, Callable, cast, TYPE_CHECKING, TypeVar
|
||||
from types import FrameType
|
||||
import sys, time, functools, itertools, math, operator, hashlib, os
|
||||
import sys, time, functools, itertools, math, operator, hashlib, os, types, pickle
|
||||
from enum import auto, IntEnum, Enum
|
||||
from dataclasses import dataclass, field
|
||||
from tinygrad.dtype import ConstType, ImageDType, PtrDType, dtypes, DType, truncate
|
||||
|
@ -473,6 +473,14 @@ class UPatAny(UPat):
|
|||
if (match:=x.match(uop, store.copy())): return match
|
||||
return []
|
||||
|
||||
def deconstruct_function(fxn:Callable) -> Tuple:
  """Split a plain (closure-free) function into pickle-friendly parts: (code, trimmed globals, name, defaults)."""
  # keep only the globals the function body actually references
  new_globals = {k:v for k,v in fxn.__globals__.items() if k in fxn.__code__.co_names}
  # nested code objects (inner lambdas / comprehensions) may reference additional globals
  for const in fxn.__code__.co_consts:
    if isinstance(const, types.CodeType):
      new_globals.update({k:v for k,v in fxn.__globals__.items() if k in const.co_names})
  # NOTE: optional round trip through pickle! exercises code-object serialization under TEST_PICKLE
  new_code = pickle.loads(pickle.dumps(fxn.__code__)) if getenv("TEST_PICKLE") else fxn.__code__
  assert fxn.__closure__ is None, "closures are not supported in pattern matchers"
  return new_code, new_globals, fxn.__name__, fxn.__defaults__
|
||||
|
||||
class PatternMatcher:
|
||||
def __init__(self, patterns:List[Tuple[UPat, Callable]]):
|
||||
self.patterns = patterns
|
||||
|
@ -481,7 +489,12 @@ class PatternMatcher:
|
|||
# uop is required, arg is optional
|
||||
for p,fxn in self.patterns:
|
||||
assert p.op is not None
|
||||
for uop in p.op: self.pdict.setdefault((uop, p.arg), []).append((p, fxn, p.early_reject))
|
||||
tuple_fxn = fxn if isinstance(fxn, tuple) else deconstruct_function(fxn)
|
||||
tuple_fxn[1]['__builtins__'] = __builtins__ # NOTE: Python 3.8 requires this for "all" and "len" and friends
|
||||
real_fxn = types.FunctionType(*tuple_fxn)
|
||||
for uop in p.op: self.pdict.setdefault((uop, p.arg), []).append((p, real_fxn, p.early_reject))
|
||||
|
||||
def __reduce__(self):
  """Pickle support: lambdas can't be pickled directly, so serialize them as deconstructed tuples.

  Entries whose fxn slot is already a tuple (a PatternMatcher that has been through one
  pickle round trip keeps the tuple in self.patterns — see __init__'s isinstance check)
  are passed through unchanged so re-pickling doesn't crash on tuple.__name__.
  """
  patterns = [(p, deconstruct_function(f) if callable(f) and f.__name__ == "<lambda>" else f)
              for p, f in self.patterns]
  return PatternMatcher, (patterns,)
|
||||
|
||||
@functools.lru_cache(None)  # pylint: disable=method-cache-max-size-none
def __add__(self, more:PatternMatcher):
  """Concatenate two matchers' rules; cached so repeated adds reuse the merged matcher."""
  return PatternMatcher(self.patterns + more.patterns)
|
||||
|
@ -531,7 +544,7 @@ class TrackedPatternMatcher(PatternMatcher):
|
|||
|
||||
if TRACK_MATCH_STATS:
|
||||
PatternMatcher = TrackedPatternMatcher # type: ignore
|
||||
import atexit, pickle
|
||||
import atexit
|
||||
@atexit.register
|
||||
def print_match_stats():
|
||||
if TRACK_MATCH_STATS >= 2:
|
||||
|
|
Loading…
Reference in New Issue