viz highlight new nodes (#6665)

* p2

* ret adds and dels

* maybe that way

* add additions

* simpler test_viz
This commit is contained in:
qazal 2024-09-23 13:46:18 +08:00 committed by GitHub
parent da5b741656
commit e9248b9e27
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 17 additions and 15 deletions

View File

@ -155,13 +155,17 @@
<div class="container metadata"></div>
</div>
<script>
function renderGraph(graph) {
const g = new dagreD3.graphlib.Graph().setGraph({ rankdir: "LR" }).setDefaultEdgeLabel(function() { return {}; });
function renderGraph(graph, additions) {
const g = new dagreD3.graphlib.Graph({ compound: true }).setGraph({ rankdir: "LR" }).setDefaultEdgeLabel(function() { return {}; });
g.setNode("addition", {label: "", clusterLabelPos: "top", style: additions.length !== 0 ? "fill: #2ea04326" : "display: none;"});
for ([k,u] of Object.entries(graph)) {
g.setNode(k, {label: u[0], style: `fill: ${u[4]}; rx: 8; ry: 8;` });
for (src of u[2]) {
g.setEdge(src, k)
}
if (additions.includes(parseInt(k))) {
g.setParent(k, "addition");
}
}
const svg = d3.select("svg");
const inner = svg.select("g");
@ -237,7 +241,7 @@
ret = await (await fetch(`/graph?kernel_idx=${currentKernel}&uop_idx=${currentUOp}`)).json();
cache[cacheKey] = ret;
}
renderGraph(ret[0].graphs[currentRewrite][0]);
renderGraph(ret[0].graphs[currentRewrite], ret[0].additions[currentRewrite]);
const metadata = document.querySelector(".container.metadata");
metadata.innerHTML = "";
metadata.appendChild(Object.assign(document.createElement("pre"), { textContent: ret[0].loc }));
@ -296,7 +300,7 @@
event.preventDefault()
currentUOp = 0;
currentRewrite = 0;
currentKernel = Math.min(Array.from(Object.keys(kernels)).length, currentKernel+1)
currentKernel = Math.min(Array.from(Object.keys(kernels)).length-1, currentKernel+1)
return main()
}
}

View File

@ -2,7 +2,7 @@
from __future__ import annotations
from typing import Dict, List, Tuple
import pickle, os, sys, time, threading, webbrowser, json, difflib, contextlib
from dataclasses import dataclass
from dataclasses import dataclass, asdict
from urllib.parse import parse_qs, urlparse
from http.server import HTTPServer, BaseHTTPRequestHandler
from tinygrad import Device
@ -17,15 +17,16 @@ from tinygrad.engine.schedule import full_ast_rewrite
@dataclass(frozen=True)
class UOpRet:
loc: str
graphs: List[Tuple[UOp, UOp, UOp, UOp]] # snapshot of the entire AST after each rewrite
graphs: List[UOp] # snapshot of the entire AST after each rewrite
diffs: List[Tuple[str, Tuple[str, int], List[str]]] # the diffs for each rewrite
extra: List[List[str]] # these become code blocks in the UI
additions: List[List[int]]
@staticmethod
def from_ctx(ctx:TrackedRewriteContext) -> UOpRet:
uops: List[UOp] = [ctx.sink]
graphs: List[Tuple[UOp, UOp, UOp, UOp]] = [(ctx.sink, ctx.sink, ctx.sink, ctx.sink)]
diffs: List[Tuple[str, Tuple[str, int], List[str]]] = []
extra: List[List[str]] = [[str(ctx.sink)]]
additions: List[List[int]] = [[]]
seen_replaces: Dict[bytes, UOp] = {}
for i, (first, rewritten, pattern) in enumerate(ctx.rewrites):
if pattern.location[0].split("/")[-1] == "ops.py": continue
@ -35,14 +36,12 @@ class UOpRet:
# sanity check
assert new_sink is not uops[-1], f"rewritten sink wasn't rewritten! {i}\n{new_sink}\n{uops[-1]}"
# update ret data
additions.append([id(x) for x in rewritten.sparents])
diffs.append((str(pattern), pattern.location, list(difflib.unified_diff(str(first).splitlines(), str(rewritten).splitlines()))))
graphs.append((new_sink, uops[-1], rewritten, first))
uops.append(new_sink)
extra.append([str(new_sink)])
return UOpRet(ctx.loc, graphs, diffs, extra)
def to_json(self) -> Dict:
return {"loc": self.loc, "graphs": [[uop_to_json(x) for x in graph] for graph in self.graphs],
"diffs": self.diffs, "extra": self.extra}
return UOpRet(ctx.loc, uops, diffs, extra, additions)
def to_json(self) -> Dict: return {**asdict(self), "graphs": list(map(uop_to_json, self.graphs))}
def uop_to_json(x:UOp) -> Dict[int, Tuple[str, str, List[int], str, str]]:
assert isinstance(x, UOp)

View File

@ -23,8 +23,7 @@ class TestViz(unittest.TestCase):
except Exception as e:
print(colored(f"failed to create graph for ctx {i}", "red"))
raise e
rewrites = [x[0] for x in ret.graphs]
for j,(x,y) in enumerate(zip(rewrites, rewrites[1:])):
for j,(x,y) in enumerate(zip(ret.graphs, ret.graphs[1:])):
if x.key == y.key:
raise AssertionError(f"failed to generate the correct diff at rewrite {j} ctx {i}")
@ -68,7 +67,7 @@ class TestViz(unittest.TestCase):
ret = graph_rewrite(sink, pm)
if DEBUG >= 4: print_diff(sink, ret)
g = UOpRet.from_ctx(contexts[0])
assert g.graphs[-1][0].key == ret.key
assert g.graphs[-1].key == ret.key
self.assert_valid_ctx(contexts)
def test_devectorize_viz(self):