track et in viz [pr] (#7088)

This commit is contained in:
qazal 2024-10-16 07:53:08 +03:00 committed by GitHub
parent 40f33c110b
commit 9c9c241e58
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 8 additions and 8 deletions

View File

@ -597,9 +597,9 @@ TRACK_MATCH_STATS = ContextVar("TRACK_MATCH_STATS", 2 if getenv("VIZ") else 0)
match_stats:Dict[UPat, List[Union[int, float]]] = dict()
@dataclass(frozen=True)
class TrackedRewriteContext:
loc: Tuple[str, int] # location that called graph_rewrite
sink: UOp # the sink passed into the rewrite
matches: List[Tuple[UOp, Optional[UOp], Optional[UPat]]] = field(default_factory=list) # all matches of sparents. (start, replace, _maybe_ UPat)
loc: Tuple[str, int] # location that called graph_rewrite
sink: UOp # the sink passed into the rewrite
matches: List[Tuple[UOp, Optional[UOp], Optional[UPat], float]] = field(default_factory=list) # all matches of sparents
rewrite_stack: List[Tuple[Any, List[TrackedRewriteContext]]] = []
contexts: List[Tuple[Any, List[TrackedRewriteContext]]] = []
@ -632,10 +632,10 @@ class TrackedPatternMatcher(PatternMatcher):
match_stats[p][2] += (et:=time.perf_counter()-st)
match_stats[p][3] += et
if TRACK_MATCH_STATS >= 3: print(f"{et*1e6:7.2f} us -- ", p.printable())
if TRACK_MATCH_STATS >= 2 and len(rewrite_stack) != 0 and isinstance(ret, UOp): rewrite_stack[-1][1][-1].matches.append((uop, ret, p))
if TRACK_MATCH_STATS >= 2 and len(rewrite_stack) != 0 and isinstance(ret, UOp): rewrite_stack[-1][1][-1].matches.append((uop, ret, p, et))
return ret # NOTE: if it returns None, we keep trying to match
match_stats[p][2] += time.perf_counter()-st
if TRACK_MATCH_STATS >= 2 and len(rewrite_stack) != 0: rewrite_stack[-1][1][-1].matches.append((uop, ret, None))
if TRACK_MATCH_STATS >= 2 and len(rewrite_stack) != 0: rewrite_stack[-1][1][-1].matches.append((uop, ret, None, 0))
return None
if TRACK_MATCH_STATS:

View File

@ -24,7 +24,7 @@ class GraphRewriteMetadata:
"""The Python line calling graph_rewrite"""
kernel_name: Optional[str]
"""The kernel calling graph_rewrite"""
upats: List[Tuple[Tuple[str, int], str]]
upats: List[Tuple[Tuple[str, int], str, float]]
"""List of all the applied UPats"""
@dataclass
@ -47,7 +47,7 @@ def get_metadata(contexts:List[Tuple[Any, List[TrackedRewriteContext]]]) -> List
name = to_function_name(k.name) if isinstance(k, Kernel) else None
for ctx in ctxs:
if ctx.sink.op is UOps.CONST: continue
upats = [(upat.location, upat.printable()) for _,_,upat in ctx.matches if upat is not None]
upats = [(upat.location, upat.printable(), tm) for _,_,upat,tm in ctx.matches if upat is not None]
if name not in kernels: kernels[name] = []
kernels[name].append((k, ctx, GraphRewriteMetadata(ctx.loc, lines(ctx.loc[0])[ctx.loc[1]-1].strip(), name, upats)))
return list(kernels.values())
@ -72,7 +72,7 @@ def get_details(k:Any, ctx:TrackedRewriteContext, metadata:GraphRewriteMetadata)
g = GraphRewriteDetails(**asdict(metadata), graphs=[ctx.sink], diffs=[], changed_nodes=[], kernel_code=_prg(k))
replaces: Dict[UOp, UOp] = {}
sink = ctx.sink
for i,(u0,u1,upat) in enumerate(ctx.matches):
for i,(u0,u1,upat,_) in enumerate(ctx.matches):
replaces[u0] = u0 if u1 is None else u1
# if the match didn't result in a rewrite we move forward
if u1 is None: continue