From 56fbd408a124b2fc85ff3cddb93223c6f37cb7cc Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Wed, 16 Oct 2024 11:38:27 +0300 Subject: [PATCH] viz print the sink tree as it's rewritten [pr] (#7094) --- tinygrad/viz/index.html | 13 ++++++++----- tinygrad/viz/serve.py | 2 +- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/tinygrad/viz/index.html b/tinygrad/viz/index.html index 6d7f8c1d..52c60881 100644 --- a/tinygrad/viz/index.html +++ b/tinygrad/viz/index.html @@ -311,13 +311,16 @@ metadata.style.userSelect = "initial"; } // ** code blocks + let code = ret.uops[currentRewrite]; + let lang = "python" if (ret.kernel_code != null) { - const code = ret.kernel_code.replaceAll("<", "<").replaceAll(">", ">"); - const pre = Object.assign(document.createElement("pre"), { innerHTML: `${DOMPurify.sanitize(code)}`, - className: "code-block language-cpp" }); - hljs.highlightElement(pre); - metadata.appendChild(pre); + code = ret.kernel_code.replaceAll("<", "<").replaceAll(">", ">"); + lang = "cpp"; } + const codeBlock = Object.assign(document.createElement("pre"), { innerHTML: `${DOMPurify.sanitize(code)}`, + className: `code-block language-${lang}` }); + hljs.highlightElement(codeBlock); + metadata.appendChild(codeBlock); // ** rewrite list if (ret.graphs.length > 1) { const rewriteList = Object.assign(document.createElement("div"), { className: "rewrite-list" }) diff --git a/tinygrad/viz/serve.py b/tinygrad/viz/serve.py index 32ab61fc..152e5702 100755 --- a/tinygrad/viz/serve.py +++ b/tinygrad/viz/serve.py @@ -104,7 +104,7 @@ class Handler(BaseHTTPRequestHandler): query = parse_qs(url.query) if (qkernel:=query.get("kernel")) is not None: g = get_details(*kernels[int(qkernel[0])][int(query["idx"][0])]) - ret = json.dumps({**asdict(g), "graphs": list(map(uop_to_json, g.graphs))}).encode() + ret = json.dumps({**asdict(g), "graphs": list(map(uop_to_json, g.graphs)), "uops": list(map(str, g.graphs))}).encode() else: ret = json.dumps([list(map(lambda x:asdict(x[2]), v)) for v in kernels]).encode() else: self.send_response(404)