viz prep refactor for tracked scope decorator [pr] (#6920)

* viz prep refactor for tracked scope decorator [pr]

* fix fuzzer
This commit is contained in:
qazal 2024-10-06 16:02:09 +03:00 committed by GitHub
parent 837f9c6832
commit 10ff1d6fb9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 29 additions and 34 deletions

View File

@ -10,7 +10,6 @@ from tinygrad.helpers import ContextVar, prod, getenv, all_same
if TYPE_CHECKING:
from tinygrad.shape.symbolic import Variable, sint
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.codegen.kernel import Kernel
# wrapper around IntEnum that preserves Enum.__str__ and makes auto() unique across all FastEnum subclasses
class FastEnum(IntEnum):
@ -587,10 +586,9 @@ match_stats:Dict[UPat, List[Union[int, float]]] = dict()
class TrackedRewriteContext:
loc: Tuple[str, int] # location that called graph_rewrite
sink: UOp # the sink passed into the rewrite
kernel: Optional[Kernel] = None # the kernel being rewritten
rewrites: List[Tuple[UOp, UOp, UPat]] = field(default_factory=list) # all rewrites of sparents. (before, after, UPat)
contexts: List[TrackedRewriteContext] = []
rewrite_stack: List[TrackedRewriteContext] = []
rewrite_stack: List[Tuple[Any, List[TrackedRewriteContext]]] = []
contexts: List[Tuple[Any, List[TrackedRewriteContext]]] = []
class TrackedPatternMatcher(PatternMatcher):
def __init__(self, patterns:List[Tuple[UPat, Callable]]):
super().__init__(patterns)
@ -611,7 +609,7 @@ class TrackedPatternMatcher(PatternMatcher):
match_stats[p][2] += (et:=time.perf_counter()-st)
match_stats[p][3] += et
if TRACK_MATCH_STATS >= 3: print(f"{et*1e6:7.2f} us -- ", p.printable())
if TRACK_MATCH_STATS >= 2 and len(rewrite_stack) != 0 and isinstance(ret, UOp): rewrite_stack[-1].rewrites.append((uop, ret, p))
if TRACK_MATCH_STATS >= 2 and len(rewrite_stack) != 0 and isinstance(ret, UOp): rewrite_stack[-1][1][-1].rewrites.append((uop, ret, p))
return ret # NOTE: if it returns None, we keep trying to match
match_stats[p][2] += time.perf_counter()-st
return None
@ -623,7 +621,7 @@ if TRACK_MATCH_STATS:
def print_match_stats():
if TRACK_MATCH_STATS >= 2:
with open("/tmp/rewrites.pkl", "wb") as f:
print(f"rewrote {len(contexts)} graphs and applied {sum(len(x.rewrites) for x in contexts)} rules, saved to /tmp/rewrites.pkl")
print(f"rewrote {len(contexts)} graphs and applied {sum(len(r.rewrites) for _,x in contexts for r in x)} rules, saved to /tmp/rewrites.pkl")
pickle.dump(contexts, f)
if getenv("VIZ"):
os.environ["VIZ"] = "0"
@ -657,7 +655,7 @@ def graph_rewrite(sink:UOp, pm:PatternMatcher, ctx=None) -> UOp:
# get Kernel we are rewriting in the context of
frm_walk: Optional[FrameType] = frm
while frm_walk is not None and not isinstance(kernel:=frm_walk.f_locals.get("self", None), Kernel): kernel, frm_walk = None, frm_walk.f_back
rewrite_stack.append(TrackedRewriteContext((frm.f_code.co_filename, frm.f_lineno), sink, kernel))
rewrite_stack.append((kernel, [TrackedRewriteContext(((frm:=sys._getframe(1)).f_code.co_filename, frm.f_lineno), sink)]))
ret = RewriteContext(pm, ctx).rewrite(sink)
if TRACK_MATCH_STATS >= 2: contexts.append(rewrite_stack.pop())
return ret

View File

@ -1,10 +1,11 @@
#!/usr/bin/env python3
from collections import defaultdict
from typing import DefaultDict, Dict, List, Optional, Tuple
from typing import Any, DefaultDict, Dict, List, Optional, Tuple
import pickle, os, sys, time, threading, webbrowser, json, difflib, contextlib, multiprocessing, functools
from dataclasses import asdict
from urllib.parse import parse_qs, urlparse
from http.server import HTTPServer, BaseHTTPRequestHandler
from tinygrad.codegen.kernel import Kernel
from tinygrad.helpers import getenv, to_function_name, tqdm
from tinygrad.ops import TrackedRewriteContext, UOp, UOps, lines
from tinygrad.engine.graph import uops_colors, word_wrap
@ -47,17 +48,20 @@ def replace_uop(base:UOp, replaces:Dict[UOp, UOp]) -> UOp:
replaces[base] = ret = base.replace(src=tuple(replace_uop(x, replaces) for x in base.src))
return ret
def load_kernels(contexts) -> DefaultDict[str, List[Tuple[GraphRewriteMetadata, TrackedRewriteContext]]]:
def load_kernels(contexts:List[Tuple[Any, List[TrackedRewriteContext]]]) -> DefaultDict[str, List[Tuple[GraphRewriteMetadata, \
TrackedRewriteContext, Any]]]:
kernels = defaultdict(list)
for ctx in contexts:
if ctx.sink.op is UOps.CONST: continue
name = to_function_name(ctx.kernel.name) if ctx.kernel is not None else None
upats = [(upat.location, upat.printable()) for _,_,upat in ctx.rewrites]
kernels[name].append((GraphRewriteMetadata(ctx.loc, lines(ctx.loc[0])[ctx.loc[1]-1].strip(), name, upats), ctx))
for k,rewrites in contexts:
if isinstance(k, Kernel): name = to_function_name(k.name)
else: name = None
for ctx in rewrites:
if ctx.sink.op is UOps.CONST: continue
upats = [(upat.location, upat.printable()) for _,_,upat in ctx.rewrites]
kernels[name].append((GraphRewriteMetadata(ctx.loc, lines(ctx.loc[0])[ctx.loc[1]-1].strip(), name, upats), ctx, k))
return kernels
@functools.lru_cache(None)
def get_src(k) -> Optional[str]: return k.to_program().src if k else None
def get_src(k) -> Optional[str]: return k.to_program().src if isinstance(k, Kernel) else None
class Handler(BaseHTTPRequestHandler):
def do_GET(self):
@ -79,10 +83,10 @@ class Handler(BaseHTTPRequestHandler):
self.end_headers()
query = parse_qs(url.query)
if (qkernel:=query.get("kernel")) is not None:
metadata, ctx = list(kernels.values())[int(qkernel[0])][int(query["idx"][0])]
metadata, ctx, k = list(kernels.values())[int(qkernel[0])][int(query["idx"][0])]
graphs, diffs, changed_nodes = reconstruct_graph(ctx)
ret = json.dumps(asdict(GraphRewriteDetails(**asdict(metadata), graphs=list(map(uop_to_json, graphs)),
diffs=diffs, changed_nodes=changed_nodes, kernel_code=get_src(ctx.kernel)))).encode()
diffs=diffs, changed_nodes=changed_nodes, kernel_code=get_src(k)))).encode()
else: ret = json.dumps([list(map(lambda x:asdict(x[0]), v)) for v in kernels.values()]).encode()
else:
self.send_response(404)
@ -102,12 +106,12 @@ def reloader():
if __name__ == "__main__":
multiprocessing.current_process().name = "VizProcess" # disallow opening of devices
print("*** viz is starting")
with open("/tmp/rewrites.pkl", "rb") as f: contexts: List[TrackedRewriteContext] = pickle.load(f)
with open("/tmp/rewrites.pkl", "rb") as f: contexts: List[Tuple[Any, List[TrackedRewriteContext]]] = pickle.load(f)
print("*** unpickled saved rewrites")
kernels = load_kernels(contexts)
if getenv("FUZZ_VIZ"):
for v in tqdm(kernels.values()):
for _,ctx in v: reconstruct_graph(ctx)
for _,ctx,_ in v: reconstruct_graph(ctx)
print("*** loaded kernels")
server = HTTPServer(('', 8000), Handler)
st = time.perf_counter()

View File

@ -1,4 +1,4 @@
from typing import List
from typing import Any, List, Tuple
import unittest
import os, itertools
os.environ["TRACK_MATCH_STATS"] = "2"
@ -7,7 +7,7 @@ from tinygrad import Tensor
from tinygrad.engine.realize import lower_schedule
from tinygrad.ops import TrackedRewriteContext, UOp, UOps, graph_rewrite, PatternMatcher, UPat, contexts, KernelInfo, BinaryOps
from tinygrad.dtype import dtypes, PtrDType
from tinygrad.helpers import Context, all_same, DEBUG, colored, getenv
from tinygrad.helpers import Context, all_same, DEBUG, getenv
from tinygrad.codegen.uopgraph import sym, devectorize, float4_folding
from test.external.process_replay.helpers import print_diff
from viz.serve import reconstruct_graph, uop_to_json, load_kernels
@ -20,16 +20,9 @@ class TestViz(unittest.TestCase):
from tinygrad.ops import contexts
if not getenv("VIZ"): contexts.clear()
def assert_valid_ctx(self, contexts:List[TrackedRewriteContext]):
def assert_valid_ctx(self, contexts:List[Tuple[Any,List[TrackedRewriteContext]]]):
assert len(contexts) != 0
for i,ctx in enumerate(contexts):
try: graphs,_,_ = reconstruct_graph(ctx)
except Exception as e:
print(colored(f"failed to create graph for ctx {i}", "red"))
raise e
for j,(x,y) in enumerate(zip(graphs, graphs[1:])):
if x.key == y.key:
raise AssertionError(f"failed to generate the correct diff at rewrite {j} ctx {i}")
load_kernels(contexts)
def assert_valid_graph(self, t):
contexts.clear()
@ -50,8 +43,8 @@ class TestViz(unittest.TestCase):
list(lower_schedule(schedule2))
with Context(TRACK_MATCH_STATS=0): ret = list(load_kernels(contexts).values())
assert len(ret) == 3
assert all(len([x for x,_ in y if "schedule" in x.loc[0]]) == 0 for y in ret[1:])
assert all(len([x for x,_ in y if "uopgraph" in x.loc[0]]) != 0 for y in ret[1:])
assert all(len([x for x,_,_ in y if "schedule" in x.loc[0]]) == 0 for y in ret[1:])
assert all(len([x for x,_,_ in y if "uopgraph" in x.loc[0]]) != 0 for y in ret[1:])
def test_gemm_diff(self):
x = Tensor.empty(64, 64).realize()
@ -70,7 +63,7 @@ class TestViz(unittest.TestCase):
])
ret = graph_rewrite(sink, pm)
if DEBUG >= 4: print_diff(sink, ret)
graphs,_,_ = reconstruct_graph(contexts[0])
graphs,_,_ = reconstruct_graph(contexts[0][1][0])
assert graphs[-1].key == ret.key
self.assert_valid_ctx(contexts)
@ -102,7 +95,7 @@ class TestViz(unittest.TestCase):
new_sink = graph_rewrite(sink, pm)
if DEBUG >= 4: print_diff(sink, new_sink, unified=0)
self.assert_valid_ctx(contexts)
assert all(ctx.loc[0].split("/")[-1] == __file__.split("/")[-1] for ctx in contexts)
assert all(ctx.loc[0].split("/")[-1] == __file__.split("/")[-1] for _,ctxs in contexts for ctx in ctxs)
def test_no_ctx(self):
simple_pm = PatternMatcher([(UPat(UOps.CONST), lambda:True)])