From 0ecc417dd2a9d7bb4be3b2877f503b44c4cec827 Mon Sep 17 00:00:00 2001
From: qazal <77887910+Qazalin@users.noreply.github.com>
Date: Mon, 7 Oct 2024 18:24:04 +0300
Subject: [PATCH] prep for viz move to core [pr] (#6938)

* prep for viz move to core [pr]

* polish
---
 viz/serve.py    | 133 +++++++++++++++++++++++++++---------------------
 viz/spec.py     |  26 ----------
 viz/test_viz.py |  39 +++++++-------
 3 files changed, 93 insertions(+), 105 deletions(-)
 delete mode 100644 viz/spec.py

diff --git a/viz/serve.py b/viz/serve.py
index 3c56c105..2ac515cf 100755
--- a/viz/serve.py
+++ b/viz/serve.py
@@ -1,35 +1,54 @@
 #!/usr/bin/env python3
-from collections import defaultdict
-from typing import Any, DefaultDict, Dict, List, Optional, Tuple
-import pickle, os, sys, time, threading, webbrowser, json, difflib, contextlib, multiprocessing, functools
-from dataclasses import asdict
-from urllib.parse import parse_qs, urlparse
+import multiprocessing, pickle, functools, difflib, os, threading, json, time, sys, webbrowser
 from http.server import HTTPServer, BaseHTTPRequestHandler
-from tinygrad.codegen.kernel import Kernel
+from urllib.parse import parse_qs, urlparse
+from dataclasses import asdict, dataclass
+from typing import Any, Dict, List, Tuple, Optional
 from tinygrad.helpers import getenv, to_function_name, tqdm
 from tinygrad.ops import TrackedRewriteContext, UOp, UOps, lines
-from tinygrad.engine.graph import uops_colors, word_wrap
-from viz.spec import GraphRewriteDetails, GraphRewriteMetadata
+from tinygrad.engine.graph import word_wrap, uops_colors
+from tinygrad.codegen.kernel import Kernel
 
-def reconstruct_graph(ctx:TrackedRewriteContext) -> Tuple[List[UOp], List[List[str]], List[List[int]]]:
-  uops: List[UOp] = [ctx.sink]
-  diffs: List[List[str]] = []
-  changed_nodes: List[List[int]] = []
-  seen_replaces: Dict[UOp, UOp] = {}
-  for i, (first, rewritten, upat) in enumerate(ctx.rewrites):
-    # first, rewrite this UOp with the current rewrite + all the seen rewrites before this
-    seen_replaces[first] = rewritten
-    new_sink = replace_uop(uops[-1], {**seen_replaces})
-    # sanity check
-    if new_sink is uops[-1]:
-      raise AssertionError(f"rewritten sink wasn't rewritten! {i} {upat.location}")
-    # update ret data
-    changed_nodes.append([id(x) for x in rewritten.sparents if x.op is not UOps.CONST])
-    diffs.append(list(difflib.unified_diff(str(first).splitlines(), str(rewritten).splitlines())))
-    uops.append(new_sink)
-  return uops, diffs, changed_nodes
+# ** API spec
 
-def uop_to_json(x:UOp) -> Dict[int, Tuple[str, str, List[int], str, str]]:
+@dataclass
+class GraphRewriteMetadata:
+  """Specifies metadata about a single call to graph_rewrite"""
+  loc: Tuple[str, int]
+  """File_path, Lineno"""
+  code_line: str
+  """The Python line calling graph_rewrite"""
+  kernel_name: Optional[str]
+  """The kernel calling graph_rewrite"""
+  upats: List[Tuple[Tuple[str, int], str]]
+  """List of all the applied UPats"""
+
+@dataclass
+class GraphRewriteDetails(GraphRewriteMetadata):
+  """Full details about a single call to graph_rewrite"""
+  graphs: List[Dict[int, Tuple[str, str, List[int], str, str]]]
+  """Sink at every step of graph_rewrite"""
+  diffs: List[List[str]]
+  """.diff style before and after of the rewritten UOp child"""
+  changed_nodes: List[List[int]]
+  """Nodes that changed at every step of graph_rewrite"""
+  kernel_code: Optional[str]
+  """The program after all rewrites"""
+
+# ** API functions
+
+def get_metadata(contexts:List[Tuple[Any, List[TrackedRewriteContext]]]) -> List[List[Tuple[Any, TrackedRewriteContext, GraphRewriteMetadata]]]:
+  kernels: Dict[Optional[str], List[Tuple[Any, TrackedRewriteContext, GraphRewriteMetadata]]] = {}
+  for k,ctxs in contexts:
+    name = to_function_name(k.name) if isinstance(k, Kernel) else None
+    for ctx in ctxs:
+      if ctx.sink.op is UOps.CONST: continue
+      upats = [(upat.location, upat.printable()) for _,_,upat in ctx.rewrites]
+      if name not in kernels: kernels[name] = []
+      kernels[name].append((k, ctx, GraphRewriteMetadata(ctx.loc, lines(ctx.loc[0])[ctx.loc[1]-1].strip(), name, upats)))
+  return list(kernels.values())
+
+def _uop_to_json(x:UOp) -> Dict[int, Tuple[str, str, List[int], str, str]]:
   assert isinstance(x, UOp)
   graph: Dict[int, Tuple[str, str, List[int], str, str]] = {}
   for u in x.sparents:
@@ -37,31 +56,32 @@ def uop_to_json(x:UOp) -> Dict[int, Tuple[str, str, List[int], str, str]]:
     label = f"{str(u.op)[5:]}{(' '+word_wrap(str(u.arg).replace(':', ''))) if u.arg is not None else ''}\n{str(u.dtype)}"
     for idx,x in enumerate(u.src):
       if x.op is UOps.CONST: label += f"\nCONST{idx} {x.arg:g}"
-    if getenv("WITH_SHAPE"):
-      with contextlib.suppress(Exception): # if the UOp is indexed already it's fine
-        if u.st is not None: label += f"\n{u.st.shape}"
     graph[id(u)] = (label, str(u.dtype), [id(x) for x in u.src if x.op is not UOps.CONST], str(u.arg), uops_colors.get(u.op, "#ffffff"))
   return graph
-
-def replace_uop(base:UOp, replaces:Dict[UOp, UOp]) -> UOp:
+def _replace_uop(base:UOp, replaces:Dict[UOp, UOp]) -> UOp:
   if (found:=replaces.get(base)) is not None: return found
-  replaces[base] = ret = base.replace(src=tuple(replace_uop(x, replaces) for x in base.src))
+  replaces[base] = ret = base.replace(src=tuple(_replace_uop(x, replaces) for x in base.src))
   return ret
-
-def load_kernels(contexts:List[Tuple[Any, List[TrackedRewriteContext]]]) -> DefaultDict[str, List[Tuple[GraphRewriteMetadata, \
-    TrackedRewriteContext, Any]]]:
-  kernels = defaultdict(list)
-  for k,rewrites in contexts:
-    if isinstance(k, Kernel): name = to_function_name(k.name)
-    else: name = None
-    for ctx in rewrites:
-      if ctx.sink.op is UOps.CONST: continue
-      upats = [(upat.location, upat.printable()) for _,_,upat in ctx.rewrites]
-      kernels[name].append((GraphRewriteMetadata(ctx.loc, lines(ctx.loc[0])[ctx.loc[1]-1].strip(), name, upats), ctx, k))
-  return kernels
-
 @functools.lru_cache(None)
-def get_src(k) -> Optional[str]: return k.to_program().src if isinstance(k, Kernel) else None
+def _prg(k:Optional[Kernel]) -> Optional[str]: return k.to_program().src if isinstance(k, Kernel) else None
+def get_details(k:Any, ctx:TrackedRewriteContext, metadata:GraphRewriteMetadata) -> GraphRewriteDetails:
+  g = GraphRewriteDetails(**asdict(metadata), graphs=[_uop_to_json(ctx.sink)], diffs=[], changed_nodes=[], kernel_code=_prg(k))
+  replaces: Dict[UOp, UOp] = {}
+  sink = ctx.sink
+  for i,(u0,u1,upat) in enumerate(ctx.rewrites):
+    # first, rewrite this UOp with the current rewrite + all the seen rewrites before this
+    replaces[u0] = u1
+    new_sink = _replace_uop(sink, {**replaces})
+    # sanity check
+    if new_sink is sink:
+      raise AssertionError(f"rewritten sink wasn't rewritten! {i} {upat.location}")
+    # update ret data
+    g.changed_nodes.append([id(x) for x in u1.sparents if x.op is not UOps.CONST])
+    g.diffs.append(list(difflib.unified_diff(str(u0).splitlines(), str(u1).splitlines())))
+    g.graphs.append(_uop_to_json(sink:=new_sink))
+  return g
+
+# ** HTTP server
 
 class Handler(BaseHTTPRequestHandler):
   def do_GET(self):
@@ -83,17 +103,15 @@ class Handler(BaseHTTPRequestHandler):
       self.end_headers()
       query = parse_qs(url.query)
       if (qkernel:=query.get("kernel")) is not None:
-        metadata, ctx, k = list(kernels.values())[int(qkernel[0])][int(query["idx"][0])]
-        graphs, diffs, changed_nodes = reconstruct_graph(ctx)
-        ret = json.dumps(asdict(GraphRewriteDetails(**asdict(metadata), graphs=list(map(uop_to_json, graphs)),
-                                                    diffs=diffs, changed_nodes=changed_nodes, kernel_code=get_src(k)))).encode()
-      else: ret = json.dumps([list(map(lambda x:asdict(x[0]), v)) for v in kernels.values()]).encode()
+        ret = json.dumps(asdict(get_details(*kernels[int(qkernel[0])][int(query["idx"][0])]))).encode()
+      else: ret = json.dumps([list(map(lambda x:asdict(x[2]), v)) for v in kernels]).encode()
     else:
       self.send_response(404)
       ret = b""
     return self.wfile.write(ret)
 
-BROWSER = getenv("BROWSER", 1)
+# ** main loop
+
 stop_reloader = threading.Event()
 def reloader():
   mtime = os.stat(__file__).st_mtime
@@ -108,18 +126,17 @@ if __name__ == "__main__":
   print("*** viz is starting")
   with open("/tmp/rewrites.pkl", "rb") as f: contexts: List[Tuple[Any, List[TrackedRewriteContext]]] = pickle.load(f)
   print("*** unpickled saved rewrites")
-  kernels = load_kernels(contexts)
+  kernels = get_metadata(contexts)
   if getenv("FUZZ_VIZ"):
-    for v in tqdm(kernels.values()):
-      for _,ctx,_ in v: reconstruct_graph(ctx)
+    ret = [get_details(*args) for v in tqdm(kernels) for args in v]
+    print(f"fuzzed {len(ret)} rewrite details")
   print("*** loaded kernels")
   server = HTTPServer(('', 8000), Handler)
   st = time.perf_counter()
   reloader_thread = threading.Thread(target=reloader)
   reloader_thread.start()
-  if BROWSER: webbrowser.open("http://localhost:8000")
-  try:
-    server.serve_forever()
+  if getenv("BROWSER", 1): webbrowser.open("http://localhost:8000")
+  try: server.serve_forever()
   except KeyboardInterrupt:
     print("*** viz is shutting down...")
     stop_reloader.set()
diff --git a/viz/spec.py b/viz/spec.py
deleted file mode 100644
index db5ec078..00000000
--- a/viz/spec.py
+++ /dev/null
@@ -1,26 +0,0 @@
-from dataclasses import dataclass
-from typing import Dict, List, Optional, Tuple
-
-@dataclass(frozen=True)
-class GraphRewriteMetadata:
-  """Specifies metadata about a single call to graph_rewrite"""
-  loc: Tuple[str, int]
-  """File_path, Lineno"""
-  code_line: str
-  """The Python line calling graph_rewrite"""
-  kernel_name: Optional[str]
-  """The kernel calling graph_rewrite"""
-  upats: List[Tuple[Tuple[str, int], str]]
-  """List of all the applied UPats"""
-
-@dataclass(frozen=True)
-class GraphRewriteDetails(GraphRewriteMetadata):
-  """Full details about a single call to graph_rewrite"""
-  graphs: List[Dict[int, Tuple[str, str, List[int], str, str]]]
-  """Sink at every step of graph_rewrite"""
-  diffs: List[List[str]]
-  """.diff style before and after of the rewritten UOp child"""
-  changed_nodes: List[List[int]]
-  """Nodes that changed at every step of graph_rewrite"""
-  kernel_code: Optional[str]
-  """The program after all rewrites"""
diff --git a/viz/test_viz.py b/viz/test_viz.py
index c40be1ba..4531668e 100644
--- a/viz/test_viz.py
+++ b/viz/test_viz.py
@@ -3,15 +3,13 @@ import unittest
 import os, itertools
 os.environ["TRACK_MATCH_STATS"] = "2"
 os.environ["PRINT_MATCH_STATS"] = "0"
-from tinygrad import Tensor
+from tinygrad import Tensor, dtypes
 from tinygrad.engine.realize import lower_schedule
+from tinygrad.dtype import PtrDType
+from tinygrad.helpers import Context, all_same, getenv
 from tinygrad.ops import TrackedRewriteContext, UOp, UOps, graph_rewrite, PatternMatcher, UPat, contexts, KernelInfo, BinaryOps, track_rewrites
-from tinygrad.dtype import dtypes, PtrDType
-from tinygrad.helpers import Context, all_same, DEBUG, getenv
 from tinygrad.codegen.uopgraph import sym, devectorize, float4_folding
-from test.external.process_replay.helpers import print_diff
-from viz.serve import reconstruct_graph, uop_to_json, load_kernels
-from viz.spec import GraphRewriteMetadata
+from viz.serve import GraphRewriteMetadata, get_metadata, get_details, _uop_to_json
 
 def group_rewrites(kernels:List[GraphRewriteMetadata]): return {k:list(v) for k,v in itertools.groupby(kernels, lambda x:x.loc)}
 
@@ -22,7 +20,7 @@ class TestViz(unittest.TestCase):
 
   def assert_valid_ctx(self, contexts:List[Tuple[Any,List[TrackedRewriteContext]]]):
     assert len(contexts) != 0
-    load_kernels(contexts)
+    get_metadata(contexts)
 
   def assert_valid_graph(self, t):
     contexts.clear()
@@ -41,10 +39,10 @@ class TestViz(unittest.TestCase):
     schedule2 = Tensor.zeros(4, 1).contiguous().exp().schedule()
     list(lower_schedule(schedule1))
     list(lower_schedule(schedule2))
-    with Context(TRACK_MATCH_STATS=0): ret = list(load_kernels(contexts).values())
+    with Context(TRACK_MATCH_STATS=0): ret = get_metadata(contexts)
     assert len(ret) == 3
-    assert all(len([x for x,_,_ in y if "schedule" in x.loc[0]]) == 0 for y in ret[1:])
-    assert all(len([x for x,_,_ in y if "uopgraph" in x.loc[0]]) != 0 for y in ret[1:])
+    assert all(len([x for _,_,x in y if "schedule" in x.loc[0]]) == 0 for y in ret[1:])
+    assert all(len([x for _,_,x in y if "uopgraph" in x.loc[0]]) != 0 for y in ret[1:])
 
   def test_gemm_diff(self):
     x = Tensor.empty(64, 64).realize()
@@ -64,10 +62,10 @@ class TestViz(unittest.TestCase):
     @track_rewrites
     def f(k): return graph_rewrite(sink, pm)
     ret = f("test_rewrite")
-    if DEBUG >= 4: print_diff(sink, ret)
-    graphs,_,_ = reconstruct_graph(contexts[0][1][0])
-    assert graphs[-1].key == ret.key
     self.assert_valid_ctx(contexts)
+    args = get_metadata(contexts)[0][0]
+    g = get_details(*args)
+    assert g.graphs[-1] == _uop_to_json(ret)
 
   def test_devectorize_viz(self):
     sink = UOp(UOps.SINK, dtypes.void, arg=KernelInfo(local_dims=1, upcasted=1, dont_use_locals=False), src=(
@@ -96,8 +94,7 @@ class TestViz(unittest.TestCase):
     pm = sym+(devectorize+float4_folding)
     @track_rewrites
     def f(k): return graph_rewrite(sink, pm)
-    new_sink = f("test_rewrite")
-    if DEBUG >= 4: print_diff(sink, new_sink, unified=0)
+    f("test_rewrite")
     self.assert_valid_ctx(contexts)
     assert all(ctx.loc[0].split("/")[-1] == __file__.split("/")[-1] for _,ctxs in contexts for ctx in ctxs)
 
@@ -111,9 +108,9 @@ class TestViz(unittest.TestCase):
     a = Tensor.empty(4, 4).contiguous().realize()+2
     b = Tensor.empty(4, 4).contiguous().realize()+2
     Tensor.schedule(a, b)
-    with Context(TRACK_MATCH_STATS=0): kernels = load_kernels(contexts)
+    with Context(TRACK_MATCH_STATS=0): kernels = get_metadata(contexts)
     self.assertEqual(len(kernels), 1)
-    rewrites = [x[0] for x in list(kernels.values())[0]]
+    rewrites = [x[2] for x in kernels[0]]
     assert all(len(v) == 1 for k,v in group_rewrites(rewrites).items() if "schedule.py" in k)
 
   def test_no_dedup_different_opts(self):
@@ -122,23 +119,23 @@ class TestViz(unittest.TestCase):
     s = a.schedule()
     with Context(NOOPT=1): list(lower_schedule(s.copy()))
     with Context(NOOPT=0): list(lower_schedule(s.copy()))
-    with Context(TRACK_MATCH_STATS=0): kernels = list(load_kernels(contexts).values())[1:]
+    with Context(TRACK_MATCH_STATS=0): kernels = get_metadata(contexts)[1:]
    self.assertEqual(len(kernels), 2)
-    rewrites = [x[0] for x in kernels[0]]
+    rewrites = [x[2] for x in kernels[0]]
     assert all(len(v) == 1 for _,v in group_rewrites(rewrites).items())
 
   def test_fold_const_nodes(self):
     a = Tensor.empty(4, 4)+2
     contexts.clear()
     sink = a.schedule()[-1].ast
-    ret = uop_to_json(sink)
+    ret = _uop_to_json(sink)
     assert not any(v[0].startswith("CONST") for v in ret.values())
     assert len([x for x in ret.values() if "CONST" in x[0]]) == 1
 
   @unittest.skip("VIZ for a single CONST isn't supported anymore")
   def test_no_fold_single_const(self):
     node = UOp(UOps.CONST, dtypes.float, (), 1.0)
-    ret = uop_to_json(node, base=node)
+    ret = _uop_to_json(node, base=node)
     assert len(ret) == 1
 
 if __name__ == "__main__":
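
Usage note (illustrative, not part of the patch): with the spec dataclasses and the two API functions now living in viz/serve.py, the viewer's data can be produced without the HTTP server, the same way the FUZZ_VIZ branch in __main__ does. A minimal sketch, assuming a /tmp/rewrites.pkl written by a tinygrad run with TRACK_MATCH_STATS=2 at this commit; the loop structure mirrors the code above, everything else is just an example driver:

import pickle
from viz.serve import get_metadata, get_details

with open("/tmp/rewrites.pkl", "rb") as f: contexts = pickle.load(f)
kernels = get_metadata(contexts)        # one List[(kernel, ctx, GraphRewriteMetadata)] per kernel name
for group in kernels:
  for k, ctx, metadata in group:
    g = get_details(k, ctx, metadata)   # GraphRewriteDetails: graphs, diffs, changed_nodes, kernel_code
    print(metadata.loc, len(g.graphs)-1, "rewrite steps")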