no codegen import in ops [pr] (#6888)

* no codegen import in ops [pr]

* @track_rewrites

* all functions need this

* polish
This commit is contained in:
qazal 2024-10-07 15:54:21 +03:00 committed by GitHub
parent f7f94cd62f
commit 9250452da4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 27 additions and 19 deletions

View File

@ -5,8 +5,8 @@ from collections import defaultdict
from typing import Optional, List, Tuple, cast, Dict, Final, DefaultDict
from enum import Enum, auto
from tinygrad.ops import BinaryOps, UNSAFE_PAD_OPS, KernelInfo, BUFFER_UOPS, UOp, UOps, print_uops, type_verify, graph_rewrite, PatternMatcher
from tinygrad.ops import resolve
from tinygrad.ops import UNSAFE_PAD_OPS, BUFFER_UOPS, BinaryOps, KernelInfo, UOp, UOps, PatternMatcher, print_uops, type_verify, resolve, \
graph_rewrite, track_rewrites
from tinygrad.device import Device
from tinygrad.renderer import Renderer, TensorCore, Program
from tinygrad.dtype import ImageDType, PtrDType
@ -704,6 +704,7 @@ class Kernel:
# **** this is the lowerer ****
@track_rewrites
def linearize(self) -> Kernel:
modified_ast = self.get_optimized_ast()

View File

@ -2,9 +2,10 @@ import sys, pickle, atexit, uuid
from collections import defaultdict, deque
from dataclasses import dataclass
from typing import Callable, Tuple, List, Dict, Optional, DefaultDict, cast
from tinygrad.ops import REDUCE_ALU, MetaOps, ReduceOps, UNSAFE_PAD_OPS, TernaryOps, UnaryOps, UOp, UOps, PatternMatcher, UPat, graph_rewrite, resolve
from tinygrad.helpers import GRAPH, DEBUG, MULTIOUTPUT, SAVE_SCHEDULE, FUSE_CONV_BW, FUSE_ARANGE, AST_REWRITE, \
GlobalCounters, all_same, colored, diskcache_put, prod, dedup, all_int, merge_dicts, getenv, Metadata, unwrap
from tinygrad.ops import REDUCE_ALU, UNSAFE_PAD_OPS, MetaOps, ReduceOps, TernaryOps, UnaryOps, UOp, UOps, PatternMatcher, UPat, resolve, \
graph_rewrite, track_rewrites
from tinygrad.helpers import GRAPH, DEBUG, MULTIOUTPUT, SAVE_SCHEDULE, FUSE_CONV_BW, FUSE_ARANGE, AST_REWRITE, GlobalCounters, all_same, \
colored, diskcache_put, prod, dedup, all_int, merge_dicts, getenv, Metadata, unwrap
from tinygrad.shape.symbolic import Variable, sint
from tinygrad.dtype import ImageDType, dtypes
from tinygrad.engine.lazy import LazyBuffer
@ -124,6 +125,7 @@ reduceop_fusor = PatternMatcher([
enumerate_bufs = PatternMatcher([(UPat(UOps.BUFFER, name="x"), lambda ctx,x: UOp(UOps.DEFINE_GLOBAL, x.dtype, (), ctx.bufs.index(x.arg[0])))])
@track_rewrites
def full_ast_rewrite(base_sink:UOp, ctx:ScheduleItemContext) -> UOp:
if not AST_REWRITE: return base_sink
sink = graph_rewrite(base_sink, reduceop_fusor)

View File

@ -1,6 +1,5 @@
from __future__ import annotations
from typing import Any, List, Optional, Set, Union, Tuple, Dict, Callable, cast, TYPE_CHECKING, TypeVar
from types import FrameType
import sys, time, functools, itertools, math, operator, hashlib, os, types, pickle
from enum import auto, IntEnum, Enum
from dataclasses import dataclass, field
@ -587,8 +586,17 @@ class TrackedRewriteContext:
loc: Tuple[str, int] # location that called graph_rewrite
sink: UOp # the sink passed into the rewrite
rewrites: List[Tuple[UOp, UOp, UPat]] = field(default_factory=list) # all rewrites of sparents. (before, after, UPat)
rewrite_stack: List[Tuple[Any, List[TrackedRewriteContext]]] = []
contexts: List[Tuple[Any, List[TrackedRewriteContext]]] = []
def track_rewrites(func):
def __wrapper(self, *args, **kwargs):
if TRACK_MATCH_STATS >= 2: rewrite_stack.append((self, []))
ret = func(self, *args, **kwargs)
if TRACK_MATCH_STATS >= 2: contexts.append(rewrite_stack.pop())
return ret
return __wrapper
class TrackedPatternMatcher(PatternMatcher):
def __init__(self, patterns:List[Tuple[UPat, Callable]]):
super().__init__(patterns)
@ -649,16 +657,9 @@ class RewriteContext:
return ret
def graph_rewrite(sink:UOp, pm:PatternMatcher, ctx=None) -> UOp:
if TRACK_MATCH_STATS >= 2:
from tinygrad.codegen.kernel import Kernel
frm = sys._getframe(1)
# get Kernel we are rewriting in the context of
frm_walk: Optional[FrameType] = frm
while frm_walk is not None and not isinstance(kernel:=frm_walk.f_locals.get("self", None), Kernel): kernel, frm_walk = None, frm_walk.f_back
rewrite_stack.append((kernel, [TrackedRewriteContext(((frm:=sys._getframe(1)).f_code.co_filename, frm.f_lineno), sink)]))
ret = RewriteContext(pm, ctx).rewrite(sink)
if TRACK_MATCH_STATS >= 2: contexts.append(rewrite_stack.pop())
return ret
if TRACK_MATCH_STATS >= 2 and len(rewrite_stack) != 0:
rewrite_stack[-1][1].append(TrackedRewriteContext(((frm:=sys._getframe(1)).f_code.co_filename, frm.f_lineno), sink))
return RewriteContext(pm, ctx).rewrite(sink)
# ***** uop type spec *****

View File

@ -5,7 +5,7 @@ os.environ["TRACK_MATCH_STATS"] = "2"
os.environ["PRINT_MATCH_STATS"] = "0"
from tinygrad import Tensor
from tinygrad.engine.realize import lower_schedule
from tinygrad.ops import TrackedRewriteContext, UOp, UOps, graph_rewrite, PatternMatcher, UPat, contexts, KernelInfo, BinaryOps
from tinygrad.ops import TrackedRewriteContext, UOp, UOps, graph_rewrite, PatternMatcher, UPat, contexts, KernelInfo, BinaryOps, track_rewrites
from tinygrad.dtype import dtypes, PtrDType
from tinygrad.helpers import Context, all_same, DEBUG, getenv
from tinygrad.codegen.uopgraph import sym, devectorize, float4_folding
@ -61,7 +61,9 @@ class TestViz(unittest.TestCase):
lambda root,const: UOp.const_like(root, const.arg) if all_same(root.src) else None),
(UPat(UOps.GEP, name="root", src=(UPat(UOps.CONST, name="x"),), location="test"), lambda root,x: root.const_like(x.arg))
])
ret = graph_rewrite(sink, pm)
@track_rewrites
def f(k): return graph_rewrite(sink, pm)
ret = f("test_rewrite")
if DEBUG >= 4: print_diff(sink, ret)
graphs,_,_ = reconstruct_graph(contexts[0][1][0])
assert graphs[-1].key == ret.key
@ -92,7 +94,9 @@ class TestViz(unittest.TestCase):
x11,
x7,)),)),)),))
pm = sym+(devectorize+float4_folding)
new_sink = graph_rewrite(sink, pm)
@track_rewrites
def f(k): return graph_rewrite(sink, pm)
new_sink = f("test_rewrite")
if DEBUG >= 4: print_diff(sink, new_sink, unified=0)
self.assert_valid_ctx(contexts)
assert all(ctx.loc[0].split("/")[-1] == __file__.split("/")[-1] for _,ctxs in contexts for ctx in ctxs)