mirror of https://github.com/commaai/tinygrad.git
always record matches in viz (#7073)
* always record matches in viz * simpler
This commit is contained in:
parent
b025495e5c
commit
545e79969f
|
@ -43,7 +43,6 @@ class TestViz(unittest.TestCase):
|
|||
self.assertEqual(uops[0], a*2)
|
||||
self.assertEqual(uops[1], graph_rewrite(a+a, pm))
|
||||
|
||||
@unittest.expectedFailure
|
||||
def test_rewrite_with_ctx(self):
|
||||
a = UOp(UOps.LOAD, dtypes.int, (UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), 0), UOp.const(dtypes.int, 0)))
|
||||
b = UOp(UOps.LOAD, dtypes.int, (UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), 1), UOp.const(dtypes.int, 0)))
|
||||
|
|
|
@ -596,9 +596,9 @@ TRACK_MATCH_STATS = ContextVar("TRACK_MATCH_STATS", 2 if getenv("VIZ") else 0)
|
|||
match_stats:Dict[UPat, List[Union[int, float]]] = dict()
|
||||
@dataclass(frozen=True)
|
||||
class TrackedRewriteContext:
|
||||
loc: Tuple[str, int] # location that called graph_rewrite
|
||||
sink: UOp # the sink passed into the rewrite
|
||||
rewrites: List[Tuple[UOp, UOp, UPat]] = field(default_factory=list) # all rewrites of sparents. (before, after, UPat)
|
||||
loc: Tuple[str, int] # location that called graph_rewrite
|
||||
sink: UOp # the sink passed into the rewrite
|
||||
matches: List[Tuple[UOp, Optional[UOp], Optional[UPat]]] = field(default_factory=list) # all matches of sparents. (start, replace, _maybe_ UPat)
|
||||
|
||||
rewrite_stack: List[Tuple[Any, List[TrackedRewriteContext]]] = []
|
||||
contexts: List[Tuple[Any, List[TrackedRewriteContext]]] = []
|
||||
|
@ -631,9 +631,10 @@ class TrackedPatternMatcher(PatternMatcher):
|
|||
match_stats[p][2] += (et:=time.perf_counter()-st)
|
||||
match_stats[p][3] += et
|
||||
if TRACK_MATCH_STATS >= 3: print(f"{et*1e6:7.2f} us -- ", p.printable())
|
||||
if TRACK_MATCH_STATS >= 2 and len(rewrite_stack) != 0 and isinstance(ret, UOp): rewrite_stack[-1][1][-1].rewrites.append((uop, ret, p))
|
||||
if TRACK_MATCH_STATS >= 2 and len(rewrite_stack) != 0 and isinstance(ret, UOp): rewrite_stack[-1][1][-1].matches.append((uop, ret, p))
|
||||
return ret # NOTE: if it returns None, we keep trying to match
|
||||
match_stats[p][2] += time.perf_counter()-st
|
||||
if TRACK_MATCH_STATS >= 2 and len(rewrite_stack) != 0: rewrite_stack[-1][1][-1].matches.append((uop, ret, None))
|
||||
return None
|
||||
|
||||
if TRACK_MATCH_STATS:
|
||||
|
@ -643,7 +644,7 @@ if TRACK_MATCH_STATS:
|
|||
def print_match_stats():
|
||||
if TRACK_MATCH_STATS >= 2:
|
||||
with open("/tmp/rewrites.pkl", "wb") as f:
|
||||
print(f"rewrote {len(contexts)} graphs and applied {sum(len(r.rewrites) for _,x in contexts for r in x)} rules, saved to /tmp/rewrites.pkl")
|
||||
print(f"rewrote {len(contexts)} graphs and matched {sum(len(r.matches) for _,x in contexts for r in x)} times, saved to /tmp/rewrites.pkl")
|
||||
pickle.dump(contexts, f)
|
||||
if getenv("VIZ"):
|
||||
os.environ["VIZ"] = "0"
|
||||
|
|
|
@ -4,7 +4,7 @@ from http.server import HTTPServer, BaseHTTPRequestHandler
|
|||
from urllib.parse import parse_qs, urlparse
|
||||
from dataclasses import asdict, dataclass
|
||||
from typing import Any, Dict, List, Tuple, Optional
|
||||
from tinygrad.helpers import colored, getenv, to_function_name, tqdm, word_wrap
|
||||
from tinygrad.helpers import colored, getenv, to_function_name, tqdm, unwrap, word_wrap
|
||||
from tinygrad.ops import TrackedRewriteContext, UOp, UOps, lines
|
||||
from tinygrad.codegen.kernel import Kernel
|
||||
|
||||
|
@ -47,7 +47,7 @@ def get_metadata(contexts:List[Tuple[Any, List[TrackedRewriteContext]]]) -> List
|
|||
name = to_function_name(k.name) if isinstance(k, Kernel) else None
|
||||
for ctx in ctxs:
|
||||
if ctx.sink.op is UOps.CONST: continue
|
||||
upats = [(upat.location, upat.printable()) for _,_,upat in ctx.rewrites]
|
||||
upats = [(upat.location, upat.printable()) for _,_,upat in ctx.matches if upat is not None]
|
||||
if name not in kernels: kernels[name] = []
|
||||
kernels[name].append((k, ctx, GraphRewriteMetadata(ctx.loc, lines(ctx.loc[0])[ctx.loc[1]-1].strip(), name, upats)))
|
||||
return list(kernels.values())
|
||||
|
@ -72,13 +72,15 @@ def get_details(k:Any, ctx:TrackedRewriteContext, metadata:GraphRewriteMetadata)
|
|||
g = GraphRewriteDetails(**asdict(metadata), graphs=[ctx.sink], diffs=[], changed_nodes=[], kernel_code=_prg(k))
|
||||
replaces: Dict[UOp, UOp] = {}
|
||||
sink = ctx.sink
|
||||
for i,(u0,u1,upat) in enumerate(ctx.rewrites):
|
||||
# first, rewrite this UOp with the current rewrite + all the seen rewrites before this
|
||||
replaces[u0] = u1
|
||||
for i,(u0,u1,upat) in enumerate(ctx.matches):
|
||||
replaces[u0] = u0 if u1 is None else u1
|
||||
# if the match didn't result in a rewrite we move forward
|
||||
if u1 is None: continue
|
||||
# first, rewrite this UOp with the current rewrite + all the seen matches before this
|
||||
new_sink = _replace_uop(sink, {**replaces})
|
||||
# sanity check
|
||||
if new_sink is sink:
|
||||
raise AssertionError(f"rewritten sink wasn't rewritten! {i} {upat.location}")
|
||||
raise AssertionError(f"rewritten sink wasn't rewritten! {i} {unwrap(upat).location}")
|
||||
# update ret data
|
||||
g.changed_nodes.append([id(x) for x in u1.sparents if x.op is not UOps.CONST])
|
||||
g.diffs.append(list(difflib.unified_diff(str(u0).splitlines(), str(u1).splitlines())))
|
||||
|
|
Loading…
Reference in New Issue