mirror of https://github.com/commaai/tinygrad.git
no codegen import in ops [pr] (#6888)
* no codegen import in ops [pr] * @track_rewrites * all functions need this * polish
This commit is contained in:
parent
f7f94cd62f
commit
9250452da4
|
@ -5,8 +5,8 @@ from collections import defaultdict
|
|||
from typing import Optional, List, Tuple, cast, Dict, Final, DefaultDict
|
||||
from enum import Enum, auto
|
||||
|
||||
from tinygrad.ops import BinaryOps, UNSAFE_PAD_OPS, KernelInfo, BUFFER_UOPS, UOp, UOps, print_uops, type_verify, graph_rewrite, PatternMatcher
|
||||
from tinygrad.ops import resolve
|
||||
from tinygrad.ops import UNSAFE_PAD_OPS, BUFFER_UOPS, BinaryOps, KernelInfo, UOp, UOps, PatternMatcher, print_uops, type_verify, resolve, \
|
||||
graph_rewrite, track_rewrites
|
||||
from tinygrad.device import Device
|
||||
from tinygrad.renderer import Renderer, TensorCore, Program
|
||||
from tinygrad.dtype import ImageDType, PtrDType
|
||||
|
@ -704,6 +704,7 @@ class Kernel:
|
|||
|
||||
# **** this is the lowerer ****
|
||||
|
||||
@track_rewrites
|
||||
def linearize(self) -> Kernel:
|
||||
modified_ast = self.get_optimized_ast()
|
||||
|
||||
|
|
|
@ -2,9 +2,10 @@ import sys, pickle, atexit, uuid
|
|||
from collections import defaultdict, deque
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, Tuple, List, Dict, Optional, DefaultDict, cast
|
||||
from tinygrad.ops import REDUCE_ALU, MetaOps, ReduceOps, UNSAFE_PAD_OPS, TernaryOps, UnaryOps, UOp, UOps, PatternMatcher, UPat, graph_rewrite, resolve
|
||||
from tinygrad.helpers import GRAPH, DEBUG, MULTIOUTPUT, SAVE_SCHEDULE, FUSE_CONV_BW, FUSE_ARANGE, AST_REWRITE, \
|
||||
GlobalCounters, all_same, colored, diskcache_put, prod, dedup, all_int, merge_dicts, getenv, Metadata, unwrap
|
||||
from tinygrad.ops import REDUCE_ALU, UNSAFE_PAD_OPS, MetaOps, ReduceOps, TernaryOps, UnaryOps, UOp, UOps, PatternMatcher, UPat, resolve, \
|
||||
graph_rewrite, track_rewrites
|
||||
from tinygrad.helpers import GRAPH, DEBUG, MULTIOUTPUT, SAVE_SCHEDULE, FUSE_CONV_BW, FUSE_ARANGE, AST_REWRITE, GlobalCounters, all_same, \
|
||||
colored, diskcache_put, prod, dedup, all_int, merge_dicts, getenv, Metadata, unwrap
|
||||
from tinygrad.shape.symbolic import Variable, sint
|
||||
from tinygrad.dtype import ImageDType, dtypes
|
||||
from tinygrad.engine.lazy import LazyBuffer
|
||||
|
@ -124,6 +125,7 @@ reduceop_fusor = PatternMatcher([
|
|||
|
||||
enumerate_bufs = PatternMatcher([(UPat(UOps.BUFFER, name="x"), lambda ctx,x: UOp(UOps.DEFINE_GLOBAL, x.dtype, (), ctx.bufs.index(x.arg[0])))])
|
||||
|
||||
@track_rewrites
|
||||
def full_ast_rewrite(base_sink:UOp, ctx:ScheduleItemContext) -> UOp:
|
||||
if not AST_REWRITE: return base_sink
|
||||
sink = graph_rewrite(base_sink, reduceop_fusor)
|
||||
|
|
|
@ -1,6 +1,5 @@
|
|||
from __future__ import annotations
|
||||
from typing import Any, List, Optional, Set, Union, Tuple, Dict, Callable, cast, TYPE_CHECKING, TypeVar
|
||||
from types import FrameType
|
||||
import sys, time, functools, itertools, math, operator, hashlib, os, types, pickle
|
||||
from enum import auto, IntEnum, Enum
|
||||
from dataclasses import dataclass, field
|
||||
|
@ -587,8 +586,17 @@ class TrackedRewriteContext:
|
|||
loc: Tuple[str, int] # location that called graph_rewrite
|
||||
sink: UOp # the sink passed into the rewrite
|
||||
rewrites: List[Tuple[UOp, UOp, UPat]] = field(default_factory=list) # all rewrites of sparents. (before, after, UPat)
|
||||
|
||||
rewrite_stack: List[Tuple[Any, List[TrackedRewriteContext]]] = []
|
||||
contexts: List[Tuple[Any, List[TrackedRewriteContext]]] = []
|
||||
def track_rewrites(func):
|
||||
def __wrapper(self, *args, **kwargs):
|
||||
if TRACK_MATCH_STATS >= 2: rewrite_stack.append((self, []))
|
||||
ret = func(self, *args, **kwargs)
|
||||
if TRACK_MATCH_STATS >= 2: contexts.append(rewrite_stack.pop())
|
||||
return ret
|
||||
return __wrapper
|
||||
|
||||
class TrackedPatternMatcher(PatternMatcher):
|
||||
def __init__(self, patterns:List[Tuple[UPat, Callable]]):
|
||||
super().__init__(patterns)
|
||||
|
@ -649,16 +657,9 @@ class RewriteContext:
|
|||
return ret
|
||||
|
||||
def graph_rewrite(sink:UOp, pm:PatternMatcher, ctx=None) -> UOp:
|
||||
if TRACK_MATCH_STATS >= 2:
|
||||
from tinygrad.codegen.kernel import Kernel
|
||||
frm = sys._getframe(1)
|
||||
# get Kernel we are rewriting in the context of
|
||||
frm_walk: Optional[FrameType] = frm
|
||||
while frm_walk is not None and not isinstance(kernel:=frm_walk.f_locals.get("self", None), Kernel): kernel, frm_walk = None, frm_walk.f_back
|
||||
rewrite_stack.append((kernel, [TrackedRewriteContext(((frm:=sys._getframe(1)).f_code.co_filename, frm.f_lineno), sink)]))
|
||||
ret = RewriteContext(pm, ctx).rewrite(sink)
|
||||
if TRACK_MATCH_STATS >= 2: contexts.append(rewrite_stack.pop())
|
||||
return ret
|
||||
if TRACK_MATCH_STATS >= 2 and len(rewrite_stack) != 0:
|
||||
rewrite_stack[-1][1].append(TrackedRewriteContext(((frm:=sys._getframe(1)).f_code.co_filename, frm.f_lineno), sink))
|
||||
return RewriteContext(pm, ctx).rewrite(sink)
|
||||
|
||||
# ***** uop type spec *****
|
||||
|
||||
|
|
|
@ -5,7 +5,7 @@ os.environ["TRACK_MATCH_STATS"] = "2"
|
|||
os.environ["PRINT_MATCH_STATS"] = "0"
|
||||
from tinygrad import Tensor
|
||||
from tinygrad.engine.realize import lower_schedule
|
||||
from tinygrad.ops import TrackedRewriteContext, UOp, UOps, graph_rewrite, PatternMatcher, UPat, contexts, KernelInfo, BinaryOps
|
||||
from tinygrad.ops import TrackedRewriteContext, UOp, UOps, graph_rewrite, PatternMatcher, UPat, contexts, KernelInfo, BinaryOps, track_rewrites
|
||||
from tinygrad.dtype import dtypes, PtrDType
|
||||
from tinygrad.helpers import Context, all_same, DEBUG, getenv
|
||||
from tinygrad.codegen.uopgraph import sym, devectorize, float4_folding
|
||||
|
@ -61,7 +61,9 @@ class TestViz(unittest.TestCase):
|
|||
lambda root,const: UOp.const_like(root, const.arg) if all_same(root.src) else None),
|
||||
(UPat(UOps.GEP, name="root", src=(UPat(UOps.CONST, name="x"),), location="test"), lambda root,x: root.const_like(x.arg))
|
||||
])
|
||||
ret = graph_rewrite(sink, pm)
|
||||
@track_rewrites
|
||||
def f(k): return graph_rewrite(sink, pm)
|
||||
ret = f("test_rewrite")
|
||||
if DEBUG >= 4: print_diff(sink, ret)
|
||||
graphs,_,_ = reconstruct_graph(contexts[0][1][0])
|
||||
assert graphs[-1].key == ret.key
|
||||
|
@ -92,7 +94,9 @@ class TestViz(unittest.TestCase):
|
|||
x11,
|
||||
x7,)),)),)),))
|
||||
pm = sym+(devectorize+float4_folding)
|
||||
new_sink = graph_rewrite(sink, pm)
|
||||
@track_rewrites
|
||||
def f(k): return graph_rewrite(sink, pm)
|
||||
new_sink = f("test_rewrite")
|
||||
if DEBUG >= 4: print_diff(sink, new_sink, unified=0)
|
||||
self.assert_valid_ctx(contexts)
|
||||
assert all(ctx.loc[0].split("/")[-1] == __file__.split("/")[-1] for _,ctxs in contexts for ctx in ctxs)
|
||||
|
|
Loading…
Reference in New Issue