mirror of https://github.com/commaai/tinygrad.git
viz prep refactor for tracked scope decorator [pr] (#6920)
* viz prep refactor for tracked scope decorator [pr] * fix fuzzer
This commit is contained in:
parent
837f9c6832
commit
10ff1d6fb9
|
@ -10,7 +10,6 @@ from tinygrad.helpers import ContextVar, prod, getenv, all_same
|
|||
if TYPE_CHECKING:
|
||||
from tinygrad.shape.symbolic import Variable, sint
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.codegen.kernel import Kernel
|
||||
|
||||
# wrapper around IntEnum that preserves Enum.__str__ and makes auto() unique across all FastEnum subclasses
|
||||
class FastEnum(IntEnum):
|
||||
|
@ -587,10 +586,9 @@ match_stats:Dict[UPat, List[Union[int, float]]] = dict()
|
|||
class TrackedRewriteContext:
|
||||
loc: Tuple[str, int] # location that called graph_rewrite
|
||||
sink: UOp # the sink passed into the rewrite
|
||||
kernel: Optional[Kernel] = None # the kernel being rewritten
|
||||
rewrites: List[Tuple[UOp, UOp, UPat]] = field(default_factory=list) # all rewrites of sparents. (before, after, UPat)
|
||||
contexts: List[TrackedRewriteContext] = []
|
||||
rewrite_stack: List[TrackedRewriteContext] = []
|
||||
rewrite_stack: List[Tuple[Any, List[TrackedRewriteContext]]] = []
|
||||
contexts: List[Tuple[Any, List[TrackedRewriteContext]]] = []
|
||||
class TrackedPatternMatcher(PatternMatcher):
|
||||
def __init__(self, patterns:List[Tuple[UPat, Callable]]):
|
||||
super().__init__(patterns)
|
||||
|
@ -611,7 +609,7 @@ class TrackedPatternMatcher(PatternMatcher):
|
|||
match_stats[p][2] += (et:=time.perf_counter()-st)
|
||||
match_stats[p][3] += et
|
||||
if TRACK_MATCH_STATS >= 3: print(f"{et*1e6:7.2f} us -- ", p.printable())
|
||||
if TRACK_MATCH_STATS >= 2 and len(rewrite_stack) != 0 and isinstance(ret, UOp): rewrite_stack[-1].rewrites.append((uop, ret, p))
|
||||
if TRACK_MATCH_STATS >= 2 and len(rewrite_stack) != 0 and isinstance(ret, UOp): rewrite_stack[-1][1][-1].rewrites.append((uop, ret, p))
|
||||
return ret # NOTE: if it returns None, we keep trying to match
|
||||
match_stats[p][2] += time.perf_counter()-st
|
||||
return None
|
||||
|
@ -623,7 +621,7 @@ if TRACK_MATCH_STATS:
|
|||
def print_match_stats():
|
||||
if TRACK_MATCH_STATS >= 2:
|
||||
with open("/tmp/rewrites.pkl", "wb") as f:
|
||||
print(f"rewrote {len(contexts)} graphs and applied {sum(len(x.rewrites) for x in contexts)} rules, saved to /tmp/rewrites.pkl")
|
||||
print(f"rewrote {len(contexts)} graphs and applied {sum(len(r.rewrites) for _,x in contexts for r in x)} rules, saved to /tmp/rewrites.pkl")
|
||||
pickle.dump(contexts, f)
|
||||
if getenv("VIZ"):
|
||||
os.environ["VIZ"] = "0"
|
||||
|
@ -657,7 +655,7 @@ def graph_rewrite(sink:UOp, pm:PatternMatcher, ctx=None) -> UOp:
|
|||
# get Kernel we are rewriting in the context of
|
||||
frm_walk: Optional[FrameType] = frm
|
||||
while frm_walk is not None and not isinstance(kernel:=frm_walk.f_locals.get("self", None), Kernel): kernel, frm_walk = None, frm_walk.f_back
|
||||
rewrite_stack.append(TrackedRewriteContext((frm.f_code.co_filename, frm.f_lineno), sink, kernel))
|
||||
rewrite_stack.append((kernel, [TrackedRewriteContext(((frm:=sys._getframe(1)).f_code.co_filename, frm.f_lineno), sink)]))
|
||||
ret = RewriteContext(pm, ctx).rewrite(sink)
|
||||
if TRACK_MATCH_STATS >= 2: contexts.append(rewrite_stack.pop())
|
||||
return ret
|
||||
|
|
28
viz/serve.py
28
viz/serve.py
|
@ -1,10 +1,11 @@
|
|||
#!/usr/bin/env python3
|
||||
from collections import defaultdict
|
||||
from typing import DefaultDict, Dict, List, Optional, Tuple
|
||||
from typing import Any, DefaultDict, Dict, List, Optional, Tuple
|
||||
import pickle, os, sys, time, threading, webbrowser, json, difflib, contextlib, multiprocessing, functools
|
||||
from dataclasses import asdict
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
from http.server import HTTPServer, BaseHTTPRequestHandler
|
||||
from tinygrad.codegen.kernel import Kernel
|
||||
from tinygrad.helpers import getenv, to_function_name, tqdm
|
||||
from tinygrad.ops import TrackedRewriteContext, UOp, UOps, lines
|
||||
from tinygrad.engine.graph import uops_colors, word_wrap
|
||||
|
@ -47,17 +48,20 @@ def replace_uop(base:UOp, replaces:Dict[UOp, UOp]) -> UOp:
|
|||
replaces[base] = ret = base.replace(src=tuple(replace_uop(x, replaces) for x in base.src))
|
||||
return ret
|
||||
|
||||
def load_kernels(contexts) -> DefaultDict[str, List[Tuple[GraphRewriteMetadata, TrackedRewriteContext]]]:
|
||||
def load_kernels(contexts:List[Tuple[Any, List[TrackedRewriteContext]]]) -> DefaultDict[str, List[Tuple[GraphRewriteMetadata, \
|
||||
TrackedRewriteContext, Any]]]:
|
||||
kernels = defaultdict(list)
|
||||
for ctx in contexts:
|
||||
if ctx.sink.op is UOps.CONST: continue
|
||||
name = to_function_name(ctx.kernel.name) if ctx.kernel is not None else None
|
||||
upats = [(upat.location, upat.printable()) for _,_,upat in ctx.rewrites]
|
||||
kernels[name].append((GraphRewriteMetadata(ctx.loc, lines(ctx.loc[0])[ctx.loc[1]-1].strip(), name, upats), ctx))
|
||||
for k,rewrites in contexts:
|
||||
if isinstance(k, Kernel): name = to_function_name(k.name)
|
||||
else: name = None
|
||||
for ctx in rewrites:
|
||||
if ctx.sink.op is UOps.CONST: continue
|
||||
upats = [(upat.location, upat.printable()) for _,_,upat in ctx.rewrites]
|
||||
kernels[name].append((GraphRewriteMetadata(ctx.loc, lines(ctx.loc[0])[ctx.loc[1]-1].strip(), name, upats), ctx, k))
|
||||
return kernels
|
||||
|
||||
@functools.lru_cache(None)
|
||||
def get_src(k) -> Optional[str]: return k.to_program().src if k else None
|
||||
def get_src(k) -> Optional[str]: return k.to_program().src if isinstance(k, Kernel) else None
|
||||
|
||||
class Handler(BaseHTTPRequestHandler):
|
||||
def do_GET(self):
|
||||
|
@ -79,10 +83,10 @@ class Handler(BaseHTTPRequestHandler):
|
|||
self.end_headers()
|
||||
query = parse_qs(url.query)
|
||||
if (qkernel:=query.get("kernel")) is not None:
|
||||
metadata, ctx = list(kernels.values())[int(qkernel[0])][int(query["idx"][0])]
|
||||
metadata, ctx, k = list(kernels.values())[int(qkernel[0])][int(query["idx"][0])]
|
||||
graphs, diffs, changed_nodes = reconstruct_graph(ctx)
|
||||
ret = json.dumps(asdict(GraphRewriteDetails(**asdict(metadata), graphs=list(map(uop_to_json, graphs)),
|
||||
diffs=diffs, changed_nodes=changed_nodes, kernel_code=get_src(ctx.kernel)))).encode()
|
||||
diffs=diffs, changed_nodes=changed_nodes, kernel_code=get_src(k)))).encode()
|
||||
else: ret = json.dumps([list(map(lambda x:asdict(x[0]), v)) for v in kernels.values()]).encode()
|
||||
else:
|
||||
self.send_response(404)
|
||||
|
@ -102,12 +106,12 @@ def reloader():
|
|||
if __name__ == "__main__":
|
||||
multiprocessing.current_process().name = "VizProcess" # disallow opening of devices
|
||||
print("*** viz is starting")
|
||||
with open("/tmp/rewrites.pkl", "rb") as f: contexts: List[TrackedRewriteContext] = pickle.load(f)
|
||||
with open("/tmp/rewrites.pkl", "rb") as f: contexts: List[Tuple[Any, List[TrackedRewriteContext]]] = pickle.load(f)
|
||||
print("*** unpickled saved rewrites")
|
||||
kernels = load_kernels(contexts)
|
||||
if getenv("FUZZ_VIZ"):
|
||||
for v in tqdm(kernels.values()):
|
||||
for _,ctx in v: reconstruct_graph(ctx)
|
||||
for _,ctx,_ in v: reconstruct_graph(ctx)
|
||||
print("*** loaded kernels")
|
||||
server = HTTPServer(('', 8000), Handler)
|
||||
st = time.perf_counter()
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
from typing import List
|
||||
from typing import Any, List, Tuple
|
||||
import unittest
|
||||
import os, itertools
|
||||
os.environ["TRACK_MATCH_STATS"] = "2"
|
||||
|
@ -7,7 +7,7 @@ from tinygrad import Tensor
|
|||
from tinygrad.engine.realize import lower_schedule
|
||||
from tinygrad.ops import TrackedRewriteContext, UOp, UOps, graph_rewrite, PatternMatcher, UPat, contexts, KernelInfo, BinaryOps
|
||||
from tinygrad.dtype import dtypes, PtrDType
|
||||
from tinygrad.helpers import Context, all_same, DEBUG, colored, getenv
|
||||
from tinygrad.helpers import Context, all_same, DEBUG, getenv
|
||||
from tinygrad.codegen.uopgraph import sym, devectorize, float4_folding
|
||||
from test.external.process_replay.helpers import print_diff
|
||||
from viz.serve import reconstruct_graph, uop_to_json, load_kernels
|
||||
|
@ -20,16 +20,9 @@ class TestViz(unittest.TestCase):
|
|||
from tinygrad.ops import contexts
|
||||
if not getenv("VIZ"): contexts.clear()
|
||||
|
||||
def assert_valid_ctx(self, contexts:List[TrackedRewriteContext]):
|
||||
def assert_valid_ctx(self, contexts:List[Tuple[Any,List[TrackedRewriteContext]]]):
|
||||
assert len(contexts) != 0
|
||||
for i,ctx in enumerate(contexts):
|
||||
try: graphs,_,_ = reconstruct_graph(ctx)
|
||||
except Exception as e:
|
||||
print(colored(f"failed to create graph for ctx {i}", "red"))
|
||||
raise e
|
||||
for j,(x,y) in enumerate(zip(graphs, graphs[1:])):
|
||||
if x.key == y.key:
|
||||
raise AssertionError(f"failed to generate the correct diff at rewrite {j} ctx {i}")
|
||||
load_kernels(contexts)
|
||||
|
||||
def assert_valid_graph(self, t):
|
||||
contexts.clear()
|
||||
|
@ -50,8 +43,8 @@ class TestViz(unittest.TestCase):
|
|||
list(lower_schedule(schedule2))
|
||||
with Context(TRACK_MATCH_STATS=0): ret = list(load_kernels(contexts).values())
|
||||
assert len(ret) == 3
|
||||
assert all(len([x for x,_ in y if "schedule" in x.loc[0]]) == 0 for y in ret[1:])
|
||||
assert all(len([x for x,_ in y if "uopgraph" in x.loc[0]]) != 0 for y in ret[1:])
|
||||
assert all(len([x for x,_,_ in y if "schedule" in x.loc[0]]) == 0 for y in ret[1:])
|
||||
assert all(len([x for x,_,_ in y if "uopgraph" in x.loc[0]]) != 0 for y in ret[1:])
|
||||
|
||||
def test_gemm_diff(self):
|
||||
x = Tensor.empty(64, 64).realize()
|
||||
|
@ -70,7 +63,7 @@ class TestViz(unittest.TestCase):
|
|||
])
|
||||
ret = graph_rewrite(sink, pm)
|
||||
if DEBUG >= 4: print_diff(sink, ret)
|
||||
graphs,_,_ = reconstruct_graph(contexts[0])
|
||||
graphs,_,_ = reconstruct_graph(contexts[0][1][0])
|
||||
assert graphs[-1].key == ret.key
|
||||
self.assert_valid_ctx(contexts)
|
||||
|
||||
|
@ -102,7 +95,7 @@ class TestViz(unittest.TestCase):
|
|||
new_sink = graph_rewrite(sink, pm)
|
||||
if DEBUG >= 4: print_diff(sink, new_sink, unified=0)
|
||||
self.assert_valid_ctx(contexts)
|
||||
assert all(ctx.loc[0].split("/")[-1] == __file__.split("/")[-1] for ctx in contexts)
|
||||
assert all(ctx.loc[0].split("/")[-1] == __file__.split("/")[-1] for _,ctxs in contexts for ctx in ctxs)
|
||||
|
||||
def test_no_ctx(self):
|
||||
simple_pm = PatternMatcher([(UPat(UOps.CONST), lambda:True)])
|
||||
|
|
Loading…
Reference in New Issue