From 0f28e932244e1b4eb37ea9f5c2bd54adb306e8a0 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Mon, 30 Sep 2024 21:54:46 +0800 Subject: [PATCH] add pickle support for pattern matchers [run_process_replay] (#6816) * add pickle support for pattern matchers [run_process_replay] * cleaner and all * no closures * fix tests * revert that * final * cleaner * python 3.8 fix * add round trip back * this * waste lines on this. that's the final line count * max print better * more targetted fix * regrettably add 3.8 support --- .github/workflows/test.yml | 4 ++-- sz.py | 2 +- test/test_pickle.py | 19 +++++++++++++++++-- test/unit/test_pattern_matcher.py | 14 ++++++++++++++ tinygrad/helpers.py | 13 ++++++++++++- tinygrad/ops.py | 19 ++++++++++++++++--- 6 files changed, 62 insertions(+), 9 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 8a400e2f..b2d52675 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -151,8 +151,8 @@ jobs: PYTHONPATH=$GITHUB_WORKSPACE BS=2 STEPS=10 python beautiful_mnist.py - name: Test DEBUG run: DEBUG=100 python3 -c "from tinygrad import Tensor; N = 1024; a, b = Tensor.rand(N, N), Tensor.rand(N, N); c = (a.reshape(N, 1, N) * b.T.reshape(1, N, N)).sum(axis=2); print((c.numpy() - (a.numpy() @ b.numpy())).mean())" - - name: Repo line count <9800 lines - run: MAX_LINE_COUNT=9800 python sz.py + - name: Repo line count <= 9999 lines + run: MAX_LINE_COUNT=9999 python sz.py testopencl: strategy: diff --git a/sz.py b/sz.py index aa775392..ffed8807 100755 --- a/sz.py +++ b/sz.py @@ -74,4 +74,4 @@ if __name__ == "__main__": total_lines = sum([x[1] for x in table]) print(f"\ntotal line count: {total_lines}") max_line_count = int(os.getenv("MAX_LINE_COUNT", "-1")) - assert max_line_count == -1 or total_lines < max_line_count, f"OVER {max_line_count} LINES" + assert max_line_count == -1 or total_lines <= max_line_count, f"OVER {max_line_count} LINES" diff --git 
a/test/test_pickle.py b/test/test_pickle.py index 050cb4ce..d66769e6 100644 --- a/test/test_pickle.py +++ b/test/test_pickle.py @@ -1,10 +1,25 @@ -import unittest, pickle +import unittest, pickle, types import numpy as np from test.helpers import assert_equiv_uops -from tinygrad import Tensor, TinyJit, Variable +from tinygrad import Tensor, TinyJit, Variable, dtypes from tinygrad.engine.schedule import create_schedule +from tinygrad.ops import PatternMatcher, UPat, UOp class TestPickle(unittest.TestCase): + def test_pickle_code_object(self): + y = lambda x: x*2 # noqa: E731 + code_str = pickle.dumps(y.__code__) + fxn = types.FunctionType(pickle.loads(code_str), globals()) + self.assertEqual(fxn(2), 4) + + def test_pickle_pattern_matcher(self): + pm = PatternMatcher([(UPat.cvar('x'), lambda x: x*2)]) + sink = UOp.const(dtypes.int, 2) + tt = pm.rewrite(sink) + pm_str = pickle.dumps(pm) + pm2 = pickle.loads(pm_str) + self.assertEqual(pm2.rewrite(sink).key, tt.key) + def test_pickle_realized_tensor(self): t = Tensor.rand(10, 10).realize() st = pickle.dumps(t) diff --git a/test/unit/test_pattern_matcher.py b/test/unit/test_pattern_matcher.py index 1354a0cb..7563bc60 100644 --- a/test/unit/test_pattern_matcher.py +++ b/test/unit/test_pattern_matcher.py @@ -11,6 +11,7 @@ class TestPatternMatcher(unittest.TestCase): self.assertEqual(matcher.rewrite(c1), c1) self.assertEqual(matcher.rewrite(c2), None) + @unittest.skip("closures aren't supported on pattern matchers") def test_match_sz_0(self): match_cnt = 0 def fxn(x): @@ -25,6 +26,19 @@ class TestPatternMatcher(unittest.TestCase): c1 = matcher.rewrite(c1) self.assertEqual(match_cnt, 1) + def test_match_sz_0_ctx(self): + def fxn(ctx, x): + ctx.append(True) + assert len(x.src) == 0 + return UOp(UOps.CONST, src=(UOp(UOps.CONST),)) + matcher = PatternMatcher([(UPat(UOps.CONST, src=(), name="x"), fxn)]) + c1 = UOp(UOps.CONST, dtypes.float, arg=1.0) + # second rewrite shouldn't match anything + ctx = [] + c1 = matcher.rewrite(c1, 
ctx) + c1 = matcher.rewrite(c1, ctx) + self.assertEqual(len(ctx), 1) + def test_uop(self): matcher = PatternMatcher([(UPat(UOps.CONST, name="x"), lambda x: x)]) c1 = UOp(UOps.CONST, dtypes.float, arg=1.0) diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py index 16349ddf..5218a8ac 100644 --- a/tinygrad/helpers.py +++ b/tinygrad/helpers.py @@ -1,6 +1,6 @@ from __future__ import annotations import os, functools, platform, time, re, contextlib, operator, hashlib, pickle, sqlite3, tempfile, pathlib, string, ctypes, sys, gzip -import itertools, urllib.request, subprocess, shutil, math, json, contextvars +import itertools, urllib.request, subprocess, shutil, math, json, contextvars, types, copyreg, inspect from dataclasses import dataclass from typing import Dict, Tuple, Union, List, ClassVar, Optional, Iterable, Any, TypeVar, TYPE_CHECKING, Callable, Sequence if TYPE_CHECKING: # TODO: remove this and import TypeGuard from typing once minimum python supported version is 3.10 @@ -373,3 +373,14 @@ def pretty_print(x:Any, rep:Callable, srcfn=lambda x: x.src, cache=None, d=0)->s if (cx:=cache.setdefault(x, [0,0,False]))[2]: return f"{' '*d} x{cx[0]}" cx[2], srcs = True, ('None' if srcfn(x) is None else ''.join(f'\n{pretty_print(s, rep, srcfn, cache, d+2)},' for s in srcfn(x))) return f"{' '*d}{f'x{cx[0]}:=' * (cx[1]>1)}{rep(x)}" % srcs + +# *** universal support for code object pickling + +def _reconstruct_code(*args): return types.CodeType(*args) +def _serialize_code(code:types.CodeType): + # NOTE: this works in Python 3.8 and up + if sys.version_info >= (3, 10): args = inspect.signature(types.CodeType).parameters + else: args = ['argcount', 'posonlyargcount', 'kwonlyargcount', 'nlocals', 'stacksize', 'flags', 'codestring', + 'constants', 'names', 'varnames', 'filename', 'name', 'firstlineno', 'lnotab', 'freevars', 'cellvars'] + return _reconstruct_code, tuple(code.__getattribute__('co_'+x.replace('codestring', 'code').replace('constants', 'consts')) for x in args) 
+copyreg.pickle(types.CodeType, _serialize_code) diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 0bbcfdf9..7b8fc617 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -1,7 +1,7 @@ from __future__ import annotations from typing import Any, List, Optional, Set, Union, Tuple, Dict, Callable, cast, TYPE_CHECKING, TypeVar from types import FrameType -import sys, time, functools, itertools, math, operator, hashlib, os +import sys, time, functools, itertools, math, operator, hashlib, os, types, pickle from enum import auto, IntEnum, Enum from dataclasses import dataclass, field from tinygrad.dtype import ConstType, ImageDType, PtrDType, dtypes, DType, truncate @@ -473,6 +473,14 @@ class UPatAny(UPat): if (match:=x.match(uop, store.copy())): return match return [] +def deconstruct_function(fxn:Callable) -> Tuple: + new_globals = {k:v for k,v in fxn.__globals__.items() if k in fxn.__code__.co_names} + for co in fxn.__code__.co_consts: + if isinstance(co, types.CodeType): new_globals.update({k:v for k,v in fxn.__globals__.items() if k in co.co_names}) + new_code_obj = pickle.loads(pickle.dumps(fxn.__code__)) if getenv("TEST_PICKLE") else fxn.__code__ # NOTE: optional round trip through pickle! 
+ assert fxn.__closure__ is None, "closures are not supported in pattern matchers" + return new_code_obj, new_globals, fxn.__name__, fxn.__defaults__ + class PatternMatcher: def __init__(self, patterns:List[Tuple[UPat, Callable]]): self.patterns = patterns @@ -481,7 +489,12 @@ class PatternMatcher: # uop is required, arg is optional for p,fxn in self.patterns: assert p.op is not None - for uop in p.op: self.pdict.setdefault((uop, p.arg), []).append((p, fxn, p.early_reject)) + tuple_fxn = fxn if isinstance(fxn, tuple) else deconstruct_function(fxn) + tuple_fxn[1]['__builtins__'] = __builtins__ # NOTE: Python 3.8 requires this for "all" and "len" and friends + real_fxn = types.FunctionType(*tuple_fxn) + for uop in p.op: self.pdict.setdefault((uop, p.arg), []).append((p, real_fxn, p.early_reject)) + + def __reduce__(self): return PatternMatcher, ([(x,deconstruct_function(fxn) if fxn.__name__ == "<lambda>" else fxn) for x,fxn in self.patterns],) @functools.lru_cache(None) # pylint: disable=method-cache-max-size-none def __add__(self, more:PatternMatcher): return PatternMatcher(self.patterns+more.patterns) @@ -531,7 +544,7 @@ class TrackedPatternMatcher(PatternMatcher): if TRACK_MATCH_STATS: PatternMatcher = TrackedPatternMatcher # type: ignore - import atexit, pickle + import atexit @atexit.register def print_match_stats(): if TRACK_MATCH_STATS >= 2: