mirror of https://github.com/commaai/tinygrad.git
viz highlight new nodes (#6665)
* p2 * ret adds and dels * maybe that way * add additions * simpler test_viz
This commit is contained in:
parent
da5b741656
commit
e9248b9e27
|
@ -155,13 +155,17 @@
|
|||
<div class="container metadata"></div>
|
||||
</div>
|
||||
<script>
|
||||
function renderGraph(graph) {
|
||||
const g = new dagreD3.graphlib.Graph().setGraph({ rankdir: "LR" }).setDefaultEdgeLabel(function() { return {}; });
|
||||
function renderGraph(graph, additions) {
|
||||
const g = new dagreD3.graphlib.Graph({ compound: true }).setGraph({ rankdir: "LR" }).setDefaultEdgeLabel(function() { return {}; });
|
||||
g.setNode("addition", {label: "", clusterLabelPos: "top", style: additions.length !== 0 ? "fill: #2ea04326" : "display: none;"});
|
||||
for ([k,u] of Object.entries(graph)) {
|
||||
g.setNode(k, {label: u[0], style: `fill: ${u[4]}; rx: 8; ry: 8;` });
|
||||
for (src of u[2]) {
|
||||
g.setEdge(src, k)
|
||||
}
|
||||
if (additions.includes(parseInt(k))) {
|
||||
g.setParent(k, "addition");
|
||||
}
|
||||
}
|
||||
const svg = d3.select("svg");
|
||||
const inner = svg.select("g");
|
||||
|
@ -237,7 +241,7 @@
|
|||
ret = await (await fetch(`/graph?kernel_idx=${currentKernel}&uop_idx=${currentUOp}`)).json();
|
||||
cache[cacheKey] = ret;
|
||||
}
|
||||
renderGraph(ret[0].graphs[currentRewrite][0]);
|
||||
renderGraph(ret[0].graphs[currentRewrite], ret[0].additions[currentRewrite]);
|
||||
const metadata = document.querySelector(".container.metadata");
|
||||
metadata.innerHTML = "";
|
||||
metadata.appendChild(Object.assign(document.createElement("pre"), { textContent: ret[0].loc }));
|
||||
|
@ -296,7 +300,7 @@
|
|||
event.preventDefault()
|
||||
currentUOp = 0;
|
||||
currentRewrite = 0;
|
||||
currentKernel = Math.min(Array.from(Object.keys(kernels)).length, currentKernel+1)
|
||||
currentKernel = Math.min(Array.from(Object.keys(kernels)).length-1, currentKernel+1)
|
||||
return main()
|
||||
}
|
||||
}
|
||||
|
|
15
viz/serve.py
15
viz/serve.py
|
@ -2,7 +2,7 @@
|
|||
from __future__ import annotations
|
||||
from typing import Dict, List, Tuple
|
||||
import pickle, os, sys, time, threading, webbrowser, json, difflib, contextlib
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, asdict
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
from http.server import HTTPServer, BaseHTTPRequestHandler
|
||||
from tinygrad import Device
|
||||
|
@ -17,15 +17,16 @@ from tinygrad.engine.schedule import full_ast_rewrite
|
|||
@dataclass(frozen=True)
|
||||
class UOpRet:
|
||||
loc: str
|
||||
graphs: List[Tuple[UOp, UOp, UOp, UOp]] # snapshot of the entire AST after each rewrite
|
||||
graphs: List[UOp] # snapshot of the entire AST after each rewrite
|
||||
diffs: List[Tuple[str, Tuple[str, int], List[str]]] # the diffs for each rewrite
|
||||
extra: List[List[str]] # these become code blocks in the UI
|
||||
additions: List[List[int]]
|
||||
@staticmethod
|
||||
def from_ctx(ctx:TrackedRewriteContext) -> UOpRet:
|
||||
uops: List[UOp] = [ctx.sink]
|
||||
graphs: List[Tuple[UOp, UOp, UOp, UOp]] = [(ctx.sink, ctx.sink, ctx.sink, ctx.sink)]
|
||||
diffs: List[Tuple[str, Tuple[str, int], List[str]]] = []
|
||||
extra: List[List[str]] = [[str(ctx.sink)]]
|
||||
additions: List[List[int]] = [[]]
|
||||
seen_replaces: Dict[bytes, UOp] = {}
|
||||
for i, (first, rewritten, pattern) in enumerate(ctx.rewrites):
|
||||
if pattern.location[0].split("/")[-1] == "ops.py": continue
|
||||
|
@ -35,14 +36,12 @@ class UOpRet:
|
|||
# sanity check
|
||||
assert new_sink is not uops[-1], f"rewritten sink wasn't rewritten! {i}\n{new_sink}\n{uops[-1]}"
|
||||
# update ret data
|
||||
additions.append([id(x) for x in rewritten.sparents])
|
||||
diffs.append((str(pattern), pattern.location, list(difflib.unified_diff(str(first).splitlines(), str(rewritten).splitlines()))))
|
||||
graphs.append((new_sink, uops[-1], rewritten, first))
|
||||
uops.append(new_sink)
|
||||
extra.append([str(new_sink)])
|
||||
return UOpRet(ctx.loc, graphs, diffs, extra)
|
||||
def to_json(self) -> Dict:
|
||||
return {"loc": self.loc, "graphs": [[uop_to_json(x) for x in graph] for graph in self.graphs],
|
||||
"diffs": self.diffs, "extra": self.extra}
|
||||
return UOpRet(ctx.loc, uops, diffs, extra, additions)
|
||||
def to_json(self) -> Dict: return {**asdict(self), "graphs": list(map(uop_to_json, self.graphs))}
|
||||
|
||||
def uop_to_json(x:UOp) -> Dict[int, Tuple[str, str, List[int], str, str]]:
|
||||
assert isinstance(x, UOp)
|
||||
|
|
|
@ -23,8 +23,7 @@ class TestViz(unittest.TestCase):
|
|||
except Exception as e:
|
||||
print(colored(f"failed to create graph for ctx {i}", "red"))
|
||||
raise e
|
||||
rewrites = [x[0] for x in ret.graphs]
|
||||
for j,(x,y) in enumerate(zip(rewrites, rewrites[1:])):
|
||||
for j,(x,y) in enumerate(zip(ret.graphs, ret.graphs[1:])):
|
||||
if x.key == y.key:
|
||||
raise AssertionError(f"failed to generate the correct diff at rewrite {j} ctx {i}")
|
||||
|
||||
|
@ -68,7 +67,7 @@ class TestViz(unittest.TestCase):
|
|||
ret = graph_rewrite(sink, pm)
|
||||
if DEBUG >= 4: print_diff(sink, ret)
|
||||
g = UOpRet.from_ctx(contexts[0])
|
||||
assert g.graphs[-1][0].key == ret.key
|
||||
assert g.graphs[-1].key == ret.key
|
||||
self.assert_valid_ctx(contexts)
|
||||
|
||||
def test_devectorize_viz(self):
|
||||
|
|
Loading…
Reference in New Issue