From 9250452da4f6db3b137281a9f7cb7df9a9efa0b9 Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Mon, 7 Oct 2024 15:54:21 +0300 Subject: [PATCH] no codegen import in ops [pr] (#6888) * no codegen import in ops [pr] * @track_rewrites * all functions need this * polish --- tinygrad/codegen/kernel.py | 5 +++-- tinygrad/engine/schedule.py | 8 +++++--- tinygrad/ops.py | 23 ++++++++++++----------- viz/test_viz.py | 10 +++++++--- 4 files changed, 27 insertions(+), 19 deletions(-) diff --git a/tinygrad/codegen/kernel.py b/tinygrad/codegen/kernel.py index ddbacc5c..0d4eabf7 100644 --- a/tinygrad/codegen/kernel.py +++ b/tinygrad/codegen/kernel.py @@ -5,8 +5,8 @@ from collections import defaultdict from typing import Optional, List, Tuple, cast, Dict, Final, DefaultDict from enum import Enum, auto -from tinygrad.ops import BinaryOps, UNSAFE_PAD_OPS, KernelInfo, BUFFER_UOPS, UOp, UOps, print_uops, type_verify, graph_rewrite, PatternMatcher -from tinygrad.ops import resolve +from tinygrad.ops import UNSAFE_PAD_OPS, BUFFER_UOPS, BinaryOps, KernelInfo, UOp, UOps, PatternMatcher, print_uops, type_verify, resolve, \ + graph_rewrite, track_rewrites from tinygrad.device import Device from tinygrad.renderer import Renderer, TensorCore, Program from tinygrad.dtype import ImageDType, PtrDType @@ -704,6 +704,7 @@ class Kernel: # **** this is the lowerer **** + @track_rewrites def linearize(self) -> Kernel: modified_ast = self.get_optimized_ast() diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index fdb38768..c492440a 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -2,9 +2,10 @@ import sys, pickle, atexit, uuid from collections import defaultdict, deque from dataclasses import dataclass from typing import Callable, Tuple, List, Dict, Optional, DefaultDict, cast -from tinygrad.ops import REDUCE_ALU, MetaOps, ReduceOps, UNSAFE_PAD_OPS, TernaryOps, UnaryOps, UOp, UOps, PatternMatcher, UPat, graph_rewrite, resolve -from tinygrad.helpers import GRAPH, DEBUG, MULTIOUTPUT, SAVE_SCHEDULE, FUSE_CONV_BW, FUSE_ARANGE, AST_REWRITE, \ - GlobalCounters, all_same, colored, diskcache_put, prod, dedup, all_int, merge_dicts, getenv, Metadata, unwrap +from tinygrad.ops import REDUCE_ALU, UNSAFE_PAD_OPS, MetaOps, ReduceOps, TernaryOps, UnaryOps, UOp, UOps, PatternMatcher, UPat, resolve, \ + graph_rewrite, track_rewrites +from tinygrad.helpers import GRAPH, DEBUG, MULTIOUTPUT, SAVE_SCHEDULE, FUSE_CONV_BW, FUSE_ARANGE, AST_REWRITE, GlobalCounters, all_same, \ + colored, diskcache_put, prod, dedup, all_int, merge_dicts, getenv, Metadata, unwrap from tinygrad.shape.symbolic import Variable, sint from tinygrad.dtype import ImageDType, dtypes from tinygrad.engine.lazy import LazyBuffer @@ -124,6 +125,7 @@ reduceop_fusor = PatternMatcher([ enumerate_bufs = PatternMatcher([(UPat(UOps.BUFFER, name="x"), lambda ctx,x: UOp(UOps.DEFINE_GLOBAL, x.dtype, (), ctx.bufs.index(x.arg[0])))]) +@track_rewrites def full_ast_rewrite(base_sink:UOp, ctx:ScheduleItemContext) -> UOp: if not AST_REWRITE: return base_sink sink = graph_rewrite(base_sink, reduceop_fusor) diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 0ad12675..8be96d99 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -1,6 +1,5 @@ from __future__ import annotations from typing import Any, List, Optional, Set, Union, Tuple, Dict, Callable, cast, TYPE_CHECKING, TypeVar -from types import FrameType import sys, time, functools, itertools, math, operator, hashlib, os, types, pickle from enum import auto, IntEnum, Enum from dataclasses import dataclass, field @@ -587,8 +586,17 @@ class TrackedRewriteContext: loc: Tuple[str, int] # location that called graph_rewrite sink: UOp # the sink passed into the rewrite rewrites: List[Tuple[UOp, UOp, UPat]] = field(default_factory=list) # all rewrites of sparents. (before, after, UPat) + rewrite_stack: List[Tuple[Any, List[TrackedRewriteContext]]] = [] contexts: List[Tuple[Any, List[TrackedRewriteContext]]] = [] +def track_rewrites(func): + def __wrapper(self, *args, **kwargs): + if TRACK_MATCH_STATS >= 2: rewrite_stack.append((self, [])) + ret = func(self, *args, **kwargs) + if TRACK_MATCH_STATS >= 2: contexts.append(rewrite_stack.pop()) + return ret + return __wrapper + class TrackedPatternMatcher(PatternMatcher): def __init__(self, patterns:List[Tuple[UPat, Callable]]): super().__init__(patterns) @@ -649,16 +657,9 @@ class RewriteContext: return ret def graph_rewrite(sink:UOp, pm:PatternMatcher, ctx=None) -> UOp: - if TRACK_MATCH_STATS >= 2: - from tinygrad.codegen.kernel import Kernel - frm = sys._getframe(1) - # get Kernel we are rewriting in the context of - frm_walk: Optional[FrameType] = frm - while frm_walk is not None and not isinstance(kernel:=frm_walk.f_locals.get("self", None), Kernel): kernel, frm_walk = None, frm_walk.f_back - rewrite_stack.append((kernel, [TrackedRewriteContext(((frm:=sys._getframe(1)).f_code.co_filename, frm.f_lineno), sink)])) - ret = RewriteContext(pm, ctx).rewrite(sink) - if TRACK_MATCH_STATS >= 2: contexts.append(rewrite_stack.pop()) - return ret + if TRACK_MATCH_STATS >= 2 and len(rewrite_stack) != 0: + rewrite_stack[-1][1].append(TrackedRewriteContext(((frm:=sys._getframe(1)).f_code.co_filename, frm.f_lineno), sink)) + return RewriteContext(pm, ctx).rewrite(sink) # ***** uop type spec ***** diff --git a/viz/test_viz.py b/viz/test_viz.py index 089b83dc..c40be1ba 100644 --- a/viz/test_viz.py +++ b/viz/test_viz.py @@ -5,7 +5,7 @@ os.environ["TRACK_MATCH_STATS"] = "2" os.environ["PRINT_MATCH_STATS"] = "0" from tinygrad import Tensor from tinygrad.engine.realize import lower_schedule -from tinygrad.ops import TrackedRewriteContext, UOp, UOps, graph_rewrite, PatternMatcher, UPat, contexts, KernelInfo, BinaryOps +from tinygrad.ops import TrackedRewriteContext, UOp, UOps, graph_rewrite, PatternMatcher, UPat, contexts, KernelInfo, BinaryOps, track_rewrites from tinygrad.dtype import dtypes, PtrDType from tinygrad.helpers import Context, all_same, DEBUG, getenv from tinygrad.codegen.uopgraph import sym, devectorize, float4_folding @@ -61,7 +61,9 @@ class TestViz(unittest.TestCase): lambda root,const: UOp.const_like(root, const.arg) if all_same(root.src) else None), (UPat(UOps.GEP, name="root", src=(UPat(UOps.CONST, name="x"),), location="test"), lambda root,x: root.const_like(x.arg)) ]) - ret = graph_rewrite(sink, pm) + @track_rewrites + def f(k): return graph_rewrite(sink, pm) + ret = f("test_rewrite") if DEBUG >= 4: print_diff(sink, ret) graphs,_,_ = reconstruct_graph(contexts[0][1][0]) assert graphs[-1].key == ret.key @@ -92,7 +94,9 @@ class TestViz(unittest.TestCase): x11, x7,)),)),)),)) pm = sym+(devectorize+float4_folding) - new_sink = graph_rewrite(sink, pm) + @track_rewrites + def f(k): return graph_rewrite(sink, pm) + new_sink = f("test_rewrite") if DEBUG >= 4: print_diff(sink, new_sink, unified=0) self.assert_valid_ctx(contexts) assert all(ctx.loc[0].split("/")[-1] == __file__.split("/")[-1] for _,ctxs in contexts for ctx in ctxs)