always record matches in viz (#7073)

* always record matches in viz

* simpler
This commit is contained in:
qazal 2024-10-15 23:03:12 +03:00 committed by GitHub
parent b025495e5c
commit 545e79969f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 14 additions and 12 deletions

View File

@ -43,7 +43,6 @@ class TestViz(unittest.TestCase):
self.assertEqual(uops[0], a*2)
self.assertEqual(uops[1], graph_rewrite(a+a, pm))
@unittest.expectedFailure
def test_rewrite_with_ctx(self):
a = UOp(UOps.LOAD, dtypes.int, (UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), 0), UOp.const(dtypes.int, 0)))
b = UOp(UOps.LOAD, dtypes.int, (UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), 1), UOp.const(dtypes.int, 0)))

View File

@ -596,9 +596,9 @@ TRACK_MATCH_STATS = ContextVar("TRACK_MATCH_STATS", 2 if getenv("VIZ") else 0)
match_stats:Dict[UPat, List[Union[int, float]]] = dict()
@dataclass(frozen=True)
class TrackedRewriteContext:
loc: Tuple[str, int] # location that called graph_rewrite
sink: UOp # the sink passed into the rewrite
rewrites: List[Tuple[UOp, UOp, UPat]] = field(default_factory=list) # all rewrites of sparents. (before, after, UPat)
loc: Tuple[str, int] # location that called graph_rewrite
sink: UOp # the sink passed into the rewrite
matches: List[Tuple[UOp, Optional[UOp], Optional[UPat]]] = field(default_factory=list) # all matches of sparents. (start, replace, _maybe_ UPat)
rewrite_stack: List[Tuple[Any, List[TrackedRewriteContext]]] = []
contexts: List[Tuple[Any, List[TrackedRewriteContext]]] = []
@ -631,9 +631,10 @@ class TrackedPatternMatcher(PatternMatcher):
match_stats[p][2] += (et:=time.perf_counter()-st)
match_stats[p][3] += et
if TRACK_MATCH_STATS >= 3: print(f"{et*1e6:7.2f} us -- ", p.printable())
if TRACK_MATCH_STATS >= 2 and len(rewrite_stack) != 0 and isinstance(ret, UOp): rewrite_stack[-1][1][-1].rewrites.append((uop, ret, p))
if TRACK_MATCH_STATS >= 2 and len(rewrite_stack) != 0 and isinstance(ret, UOp): rewrite_stack[-1][1][-1].matches.append((uop, ret, p))
return ret # NOTE: if it returns None, we keep trying to match
match_stats[p][2] += time.perf_counter()-st
if TRACK_MATCH_STATS >= 2 and len(rewrite_stack) != 0: rewrite_stack[-1][1][-1].matches.append((uop, ret, None))
return None
if TRACK_MATCH_STATS:
@ -643,7 +644,7 @@ if TRACK_MATCH_STATS:
def print_match_stats():
if TRACK_MATCH_STATS >= 2:
with open("/tmp/rewrites.pkl", "wb") as f:
print(f"rewrote {len(contexts)} graphs and applied {sum(len(r.rewrites) for _,x in contexts for r in x)} rules, saved to /tmp/rewrites.pkl")
print(f"rewrote {len(contexts)} graphs and matched {sum(len(r.matches) for _,x in contexts for r in x)} times, saved to /tmp/rewrites.pkl")
pickle.dump(contexts, f)
if getenv("VIZ"):
os.environ["VIZ"] = "0"

View File

@ -4,7 +4,7 @@ from http.server import HTTPServer, BaseHTTPRequestHandler
from urllib.parse import parse_qs, urlparse
from dataclasses import asdict, dataclass
from typing import Any, Dict, List, Tuple, Optional
from tinygrad.helpers import colored, getenv, to_function_name, tqdm, word_wrap
from tinygrad.helpers import colored, getenv, to_function_name, tqdm, unwrap, word_wrap
from tinygrad.ops import TrackedRewriteContext, UOp, UOps, lines
from tinygrad.codegen.kernel import Kernel
@ -47,7 +47,7 @@ def get_metadata(contexts:List[Tuple[Any, List[TrackedRewriteContext]]]) -> List
name = to_function_name(k.name) if isinstance(k, Kernel) else None
for ctx in ctxs:
if ctx.sink.op is UOps.CONST: continue
upats = [(upat.location, upat.printable()) for _,_,upat in ctx.rewrites]
upats = [(upat.location, upat.printable()) for _,_,upat in ctx.matches if upat is not None]
if name not in kernels: kernels[name] = []
kernels[name].append((k, ctx, GraphRewriteMetadata(ctx.loc, lines(ctx.loc[0])[ctx.loc[1]-1].strip(), name, upats)))
return list(kernels.values())
@ -72,13 +72,15 @@ def get_details(k:Any, ctx:TrackedRewriteContext, metadata:GraphRewriteMetadata)
g = GraphRewriteDetails(**asdict(metadata), graphs=[ctx.sink], diffs=[], changed_nodes=[], kernel_code=_prg(k))
replaces: Dict[UOp, UOp] = {}
sink = ctx.sink
for i,(u0,u1,upat) in enumerate(ctx.rewrites):
# first, rewrite this UOp with the current rewrite + all the seen rewrites before this
replaces[u0] = u1
for i,(u0,u1,upat) in enumerate(ctx.matches):
replaces[u0] = u0 if u1 is None else u1
# if the match didn't result in a rewrite we move forward
if u1 is None: continue
# first, rewrite this UOp with the current rewrite + all the seen matches before this
new_sink = _replace_uop(sink, {**replaces})
# sanity check
if new_sink is sink:
raise AssertionError(f"rewritten sink wasn't rewritten! {i} {upat.location}")
raise AssertionError(f"rewritten sink wasn't rewritten! {i} {unwrap(upat).location}")
# update ret data
g.changed_nodes.append([id(x) for x in u1.sparents if x.op is not UOps.CONST])
g.diffs.append(list(difflib.unified_diff(str(u0).splitlines(), str(u1).splitlines())))