add pickle support for pattern matchers [run_process_replay] (#6816)

* add pickle support for pattern matchers [run_process_replay]

* cleaner and all

* no closures

* fix tests

* revert that

* final

* cleaner

* python 3.8 fix

* add round trip back

* this

* waste lines on this. that's the final line count

* max print better

* more targeted fix

* regrettably add 3.8 support
This commit is contained in:
George Hotz 2024-09-30 21:54:46 +08:00 committed by GitHub
parent f59517754e
commit 0f28e93224
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 62 additions and 9 deletions

View File

@ -151,8 +151,8 @@ jobs:
PYTHONPATH=$GITHUB_WORKSPACE BS=2 STEPS=10 python beautiful_mnist.py PYTHONPATH=$GITHUB_WORKSPACE BS=2 STEPS=10 python beautiful_mnist.py
- name: Test DEBUG - name: Test DEBUG
run: DEBUG=100 python3 -c "from tinygrad import Tensor; N = 1024; a, b = Tensor.rand(N, N), Tensor.rand(N, N); c = (a.reshape(N, 1, N) * b.T.reshape(1, N, N)).sum(axis=2); print((c.numpy() - (a.numpy() @ b.numpy())).mean())" run: DEBUG=100 python3 -c "from tinygrad import Tensor; N = 1024; a, b = Tensor.rand(N, N), Tensor.rand(N, N); c = (a.reshape(N, 1, N) * b.T.reshape(1, N, N)).sum(axis=2); print((c.numpy() - (a.numpy() @ b.numpy())).mean())"
- name: Repo line count <9800 lines - name: Repo line count <= 9999 lines
run: MAX_LINE_COUNT=9800 python sz.py run: MAX_LINE_COUNT=9999 python sz.py
testopencl: testopencl:
strategy: strategy:

2
sz.py
View File

@ -74,4 +74,4 @@ if __name__ == "__main__":
total_lines = sum([x[1] for x in table]) total_lines = sum([x[1] for x in table])
print(f"\ntotal line count: {total_lines}") print(f"\ntotal line count: {total_lines}")
max_line_count = int(os.getenv("MAX_LINE_COUNT", "-1")) max_line_count = int(os.getenv("MAX_LINE_COUNT", "-1"))
assert max_line_count == -1 or total_lines < max_line_count, f"OVER {max_line_count} LINES" assert max_line_count == -1 or total_lines <= max_line_count, f"OVER {max_line_count} LINES"

View File

@ -1,10 +1,25 @@
import unittest, pickle import unittest, pickle, types
import numpy as np import numpy as np
from test.helpers import assert_equiv_uops from test.helpers import assert_equiv_uops
from tinygrad import Tensor, TinyJit, Variable from tinygrad import Tensor, TinyJit, Variable, dtypes
from tinygrad.engine.schedule import create_schedule from tinygrad.engine.schedule import create_schedule
from tinygrad.ops import PatternMatcher, UPat, UOp
class TestPickle(unittest.TestCase): class TestPickle(unittest.TestCase):
def test_pickle_code_object(self):
  """Code objects should survive a pickle round trip (relies on the copyreg pickler registered in helpers)."""
  double = lambda x: x*2 # noqa: E731
  restored = types.FunctionType(pickle.loads(pickle.dumps(double.__code__)), globals())
  self.assertEqual(restored(2), 4)
def test_pickle_pattern_matcher(self):
  """A PatternMatcher must produce the same rewrite after a pickle round trip."""
  matcher = PatternMatcher([(UPat.cvar('x'), lambda x: x*2)])
  target = UOp.const(dtypes.int, 2)
  expected = matcher.rewrite(target)
  restored = pickle.loads(pickle.dumps(matcher))
  self.assertEqual(restored.rewrite(target).key, expected.key)
def test_pickle_realized_tensor(self): def test_pickle_realized_tensor(self):
t = Tensor.rand(10, 10).realize() t = Tensor.rand(10, 10).realize()
st = pickle.dumps(t) st = pickle.dumps(t)

View File

@ -11,6 +11,7 @@ class TestPatternMatcher(unittest.TestCase):
self.assertEqual(matcher.rewrite(c1), c1) self.assertEqual(matcher.rewrite(c1), c1)
self.assertEqual(matcher.rewrite(c2), None) self.assertEqual(matcher.rewrite(c2), None)
@unittest.skip("closures aren't supported on pattern matchers")
def test_match_sz_0(self): def test_match_sz_0(self):
match_cnt = 0 match_cnt = 0
def fxn(x): def fxn(x):
@ -25,6 +26,19 @@ class TestPatternMatcher(unittest.TestCase):
c1 = matcher.rewrite(c1) c1 = matcher.rewrite(c1)
self.assertEqual(match_cnt, 1) self.assertEqual(match_cnt, 1)
def test_match_sz_0_ctx(self):
  """A zero-src pattern fires once; the rewritten uop gains a src so the second rewrite must not match."""
  # NOTE: the callback takes ctx as an argument instead of closing over state — closures
  # aren't supported in pattern matchers; keep the parameter names ctx/x (bound by name).
  def fxn(ctx, x):
    ctx.append(True)
    assert len(x.src) == 0
    return UOp(UOps.CONST, src=(UOp(UOps.CONST),))
  matcher = PatternMatcher([(UPat(UOps.CONST, src=(), name="x"), fxn)])
  uop = UOp(UOps.CONST, dtypes.float, arg=1.0)
  hits = []
  # second rewrite shouldn't match anything
  uop = matcher.rewrite(uop, hits)
  uop = matcher.rewrite(uop, hits)
  self.assertEqual(len(hits), 1)
def test_uop(self): def test_uop(self):
matcher = PatternMatcher([(UPat(UOps.CONST, name="x"), lambda x: x)]) matcher = PatternMatcher([(UPat(UOps.CONST, name="x"), lambda x: x)])
c1 = UOp(UOps.CONST, dtypes.float, arg=1.0) c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)

View File

@ -1,6 +1,6 @@
from __future__ import annotations from __future__ import annotations
import os, functools, platform, time, re, contextlib, operator, hashlib, pickle, sqlite3, tempfile, pathlib, string, ctypes, sys, gzip import os, functools, platform, time, re, contextlib, operator, hashlib, pickle, sqlite3, tempfile, pathlib, string, ctypes, sys, gzip
import itertools, urllib.request, subprocess, shutil, math, json, contextvars import itertools, urllib.request, subprocess, shutil, math, json, contextvars, types, copyreg, inspect
from dataclasses import dataclass from dataclasses import dataclass
from typing import Dict, Tuple, Union, List, ClassVar, Optional, Iterable, Any, TypeVar, TYPE_CHECKING, Callable, Sequence from typing import Dict, Tuple, Union, List, ClassVar, Optional, Iterable, Any, TypeVar, TYPE_CHECKING, Callable, Sequence
if TYPE_CHECKING: # TODO: remove this and import TypeGuard from typing once minimum python supported version is 3.10 if TYPE_CHECKING: # TODO: remove this and import TypeGuard from typing once minimum python supported version is 3.10
@ -373,3 +373,14 @@ def pretty_print(x:Any, rep:Callable, srcfn=lambda x: x.src, cache=None, d=0)->s
if (cx:=cache.setdefault(x, [0,0,False]))[2]: return f"{' '*d} x{cx[0]}" if (cx:=cache.setdefault(x, [0,0,False]))[2]: return f"{' '*d} x{cx[0]}"
cx[2], srcs = True, ('None' if srcfn(x) is None else ''.join(f'\n{pretty_print(s, rep, srcfn, cache, d+2)},' for s in srcfn(x))) cx[2], srcs = True, ('None' if srcfn(x) is None else ''.join(f'\n{pretty_print(s, rep, srcfn, cache, d+2)},' for s in srcfn(x)))
return f"{' '*d}{f'x{cx[0]}:=' * (cx[1]>1)}{rep(x)}" % srcs return f"{' '*d}{f'x{cx[0]}:=' * (cx[1]>1)}{rep(x)}" % srcs
# *** universal support for code object pickling

def _reconstruct_code(*args): return types.CodeType(*args)

def _serialize_code(code:types.CodeType):
  """Reducer for types.CodeType: return (constructor, field tuple) so code objects can pickle.

  NOTE: this works in Python 3.8 and up. On 3.10+ the CodeType constructor signature is
  introspected; on older versions a hardcoded field list is used.
  """
  if sys.version_info >= (3, 10):
    fields = list(inspect.signature(types.CodeType).parameters)
  else:
    fields = ['argcount', 'posonlyargcount', 'kwonlyargcount', 'nlocals', 'stacksize', 'flags', 'codestring',
              'constants', 'names', 'varnames', 'filename', 'name', 'firstlineno', 'lnotab', 'freevars', 'cellvars']
  # constructor parameters are named "codestring"/"constants", but the attributes are co_code/co_consts
  remap = {'codestring': 'code', 'constants': 'consts'}
  return _reconstruct_code, tuple(getattr(code, 'co_'+remap.get(f, f)) for f in fields)
copyreg.pickle(types.CodeType, _serialize_code)

View File

@ -1,7 +1,7 @@
from __future__ import annotations from __future__ import annotations
from typing import Any, List, Optional, Set, Union, Tuple, Dict, Callable, cast, TYPE_CHECKING, TypeVar from typing import Any, List, Optional, Set, Union, Tuple, Dict, Callable, cast, TYPE_CHECKING, TypeVar
from types import FrameType from types import FrameType
import sys, time, functools, itertools, math, operator, hashlib, os import sys, time, functools, itertools, math, operator, hashlib, os, types, pickle
from enum import auto, IntEnum, Enum from enum import auto, IntEnum, Enum
from dataclasses import dataclass, field from dataclasses import dataclass, field
from tinygrad.dtype import ConstType, ImageDType, PtrDType, dtypes, DType, truncate from tinygrad.dtype import ConstType, ImageDType, PtrDType, dtypes, DType, truncate
@ -473,6 +473,14 @@ class UPatAny(UPat):
if (match:=x.match(uop, store.copy())): return match if (match:=x.match(uop, store.copy())): return match
return [] return []
def deconstruct_function(fxn:Callable) -> Tuple:
  """Split fxn into picklable pieces (code, referenced globals, name, defaults) usable by types.FunctionType."""
  # keep only the globals the function's code references, plus those of any directly nested code objects
  new_globals = {k:v for k,v in fxn.__globals__.items() if k in fxn.__code__.co_names}
  for const in fxn.__code__.co_consts:
    if isinstance(const, types.CodeType):
      new_globals.update({k:v for k,v in fxn.__globals__.items() if k in const.co_names})
  # NOTE: optional round trip through pickle!
  new_code_obj = fxn.__code__ if not getenv("TEST_PICKLE") else pickle.loads(pickle.dumps(fxn.__code__))
  assert fxn.__closure__ is None, "closures are not supported in pattern matchers"
  return new_code_obj, new_globals, fxn.__name__, fxn.__defaults__
class PatternMatcher: class PatternMatcher:
def __init__(self, patterns:List[Tuple[UPat, Callable]]): def __init__(self, patterns:List[Tuple[UPat, Callable]]):
self.patterns = patterns self.patterns = patterns
@ -481,7 +489,12 @@ class PatternMatcher:
# uop is required, arg is optional # uop is required, arg is optional
for p,fxn in self.patterns: for p,fxn in self.patterns:
assert p.op is not None assert p.op is not None
for uop in p.op: self.pdict.setdefault((uop, p.arg), []).append((p, fxn, p.early_reject)) tuple_fxn = fxn if isinstance(fxn, tuple) else deconstruct_function(fxn)
tuple_fxn[1]['__builtins__'] = __builtins__ # NOTE: Python 3.8 requires this for "all" and "len" and friends
real_fxn = types.FunctionType(*tuple_fxn)
for uop in p.op: self.pdict.setdefault((uop, p.arg), []).append((p, real_fxn, p.early_reject))
def __reduce__(self): return PatternMatcher, ([(x,deconstruct_function(fxn) if fxn.__name__ == "<lambda>" else fxn) for x,fxn in self.patterns],)
@functools.lru_cache(None) # pylint: disable=method-cache-max-size-none @functools.lru_cache(None) # pylint: disable=method-cache-max-size-none
def __add__(self, more:PatternMatcher): return PatternMatcher(self.patterns+more.patterns) def __add__(self, more:PatternMatcher): return PatternMatcher(self.patterns+more.patterns)
@ -531,7 +544,7 @@ class TrackedPatternMatcher(PatternMatcher):
if TRACK_MATCH_STATS: if TRACK_MATCH_STATS:
PatternMatcher = TrackedPatternMatcher # type: ignore PatternMatcher = TrackedPatternMatcher # type: ignore
import atexit, pickle import atexit
@atexit.register @atexit.register
def print_match_stats(): def print_match_stats():
if TRACK_MATCH_STATS >= 2: if TRACK_MATCH_STATS >= 2: