mirror of https://github.com/commaai/tinygrad.git
track et in viz [pr] (#7088)
This commit is contained in:
parent
40f33c110b
commit
9c9c241e58
|
@ -597,9 +597,9 @@ TRACK_MATCH_STATS = ContextVar("TRACK_MATCH_STATS", 2 if getenv("VIZ") else 0)
|
|||
match_stats:Dict[UPat, List[Union[int, float]]] = dict()
|
||||
@dataclass(frozen=True)
|
||||
class TrackedRewriteContext:
|
||||
loc: Tuple[str, int] # location that called graph_rewrite
|
||||
sink: UOp # the sink passed into the rewrite
|
||||
matches: List[Tuple[UOp, Optional[UOp], Optional[UPat]]] = field(default_factory=list) # all matches of sparents. (start, replace, _maybe_ UPat)
|
||||
loc: Tuple[str, int] # location that called graph_rewrite
|
||||
sink: UOp # the sink passed into the rewrite
|
||||
matches: List[Tuple[UOp, Optional[UOp], Optional[UPat], float]] = field(default_factory=list) # all matches of sparents
|
||||
|
||||
rewrite_stack: List[Tuple[Any, List[TrackedRewriteContext]]] = []
|
||||
contexts: List[Tuple[Any, List[TrackedRewriteContext]]] = []
|
||||
|
@ -632,10 +632,10 @@ class TrackedPatternMatcher(PatternMatcher):
|
|||
match_stats[p][2] += (et:=time.perf_counter()-st)
|
||||
match_stats[p][3] += et
|
||||
if TRACK_MATCH_STATS >= 3: print(f"{et*1e6:7.2f} us -- ", p.printable())
|
||||
if TRACK_MATCH_STATS >= 2 and len(rewrite_stack) != 0 and isinstance(ret, UOp): rewrite_stack[-1][1][-1].matches.append((uop, ret, p))
|
||||
if TRACK_MATCH_STATS >= 2 and len(rewrite_stack) != 0 and isinstance(ret, UOp): rewrite_stack[-1][1][-1].matches.append((uop, ret, p, et))
|
||||
return ret # NOTE: if it returns None, we keep trying to match
|
||||
match_stats[p][2] += time.perf_counter()-st
|
||||
if TRACK_MATCH_STATS >= 2 and len(rewrite_stack) != 0: rewrite_stack[-1][1][-1].matches.append((uop, ret, None))
|
||||
if TRACK_MATCH_STATS >= 2 and len(rewrite_stack) != 0: rewrite_stack[-1][1][-1].matches.append((uop, ret, None, 0))
|
||||
return None
|
||||
|
||||
if TRACK_MATCH_STATS:
|
||||
|
|
|
@ -24,7 +24,7 @@ class GraphRewriteMetadata:
|
|||
"""The Python line calling graph_rewrite"""
|
||||
kernel_name: Optional[str]
|
||||
"""The kernel calling graph_rewrite"""
|
||||
upats: List[Tuple[Tuple[str, int], str]]
|
||||
upats: List[Tuple[Tuple[str, int], str, float]]
|
||||
"""List of all the applied UPats"""
|
||||
|
||||
@dataclass
|
||||
|
@ -47,7 +47,7 @@ def get_metadata(contexts:List[Tuple[Any, List[TrackedRewriteContext]]]) -> List
|
|||
name = to_function_name(k.name) if isinstance(k, Kernel) else None
|
||||
for ctx in ctxs:
|
||||
if ctx.sink.op is UOps.CONST: continue
|
||||
upats = [(upat.location, upat.printable()) for _,_,upat in ctx.matches if upat is not None]
|
||||
upats = [(upat.location, upat.printable(), tm) for _,_,upat,tm in ctx.matches if upat is not None]
|
||||
if name not in kernels: kernels[name] = []
|
||||
kernels[name].append((k, ctx, GraphRewriteMetadata(ctx.loc, lines(ctx.loc[0])[ctx.loc[1]-1].strip(), name, upats)))
|
||||
return list(kernels.values())
|
||||
|
@ -72,7 +72,7 @@ def get_details(k:Any, ctx:TrackedRewriteContext, metadata:GraphRewriteMetadata)
|
|||
g = GraphRewriteDetails(**asdict(metadata), graphs=[ctx.sink], diffs=[], changed_nodes=[], kernel_code=_prg(k))
|
||||
replaces: Dict[UOp, UOp] = {}
|
||||
sink = ctx.sink
|
||||
for i,(u0,u1,upat) in enumerate(ctx.matches):
|
||||
for i,(u0,u1,upat,_) in enumerate(ctx.matches):
|
||||
replaces[u0] = u0 if u1 is None else u1
|
||||
# if the match didn't result in a rewrite we move forward
|
||||
if u1 is None: continue
|
||||
|
|
Loading…
Reference in New Issue