From 56fbd408a124b2fc85ff3cddb93223c6f37cb7cc Mon Sep 17 00:00:00 2001
From: qazal <77887910+Qazalin@users.noreply.github.com>
Date: Wed, 16 Oct 2024 11:38:27 +0300
Subject: [PATCH] viz print the sink tree as it's rewritten [pr] (#7094)
---
tinygrad/viz/index.html | 13 ++++++++-----
tinygrad/viz/serve.py | 2 +-
2 files changed, 9 insertions(+), 6 deletions(-)
diff --git a/tinygrad/viz/index.html b/tinygrad/viz/index.html
index 6d7f8c1d..52c60881 100644
--- a/tinygrad/viz/index.html
+++ b/tinygrad/viz/index.html
@@ -311,13 +311,16 @@
metadata.style.userSelect = "initial";
}
// ** code blocks
+ let code = ret.uops[currentRewrite];
+ let lang = "python"
if (ret.kernel_code != null) {
- const code = ret.kernel_code.replaceAll("<", "<").replaceAll(">", ">");
- const pre = Object.assign(document.createElement("pre"), { innerHTML: `${DOMPurify.sanitize(code)}
`,
- className: "code-block language-cpp" });
- hljs.highlightElement(pre);
- metadata.appendChild(pre);
+ code = ret.kernel_code.replaceAll("<", "<").replaceAll(">", ">");
+ lang = "cpp";
}
+ const codeBlock = Object.assign(document.createElement("pre"), { innerHTML: `${DOMPurify.sanitize(code)}
`,
+ className: `code-block language-${lang}` });
+ hljs.highlightElement(codeBlock);
+ metadata.appendChild(codeBlock);
// ** rewrite list
if (ret.graphs.length > 1) {
const rewriteList = Object.assign(document.createElement("div"), { className: "rewrite-list" })
diff --git a/tinygrad/viz/serve.py b/tinygrad/viz/serve.py
index 32ab61fc..152e5702 100755
--- a/tinygrad/viz/serve.py
+++ b/tinygrad/viz/serve.py
@@ -104,7 +104,7 @@ class Handler(BaseHTTPRequestHandler):
query = parse_qs(url.query)
if (qkernel:=query.get("kernel")) is not None:
g = get_details(*kernels[int(qkernel[0])][int(query["idx"][0])])
- ret = json.dumps({**asdict(g), "graphs": list(map(uop_to_json, g.graphs))}).encode()
+ ret = json.dumps({**asdict(g), "graphs": list(map(uop_to_json, g.graphs)), "uops": list(map(str, g.graphs))}).encode()
else: ret = json.dumps([list(map(lambda x:asdict(x[2]), v)) for v in kernels]).encode()
else:
self.send_response(404)