tinygrad/viz/serve.py

118 lines
5.4 KiB
Python
Executable File

#!/usr/bin/env python3
from collections import defaultdict
from typing import DefaultDict, Dict, List, Optional, Tuple
import pickle, os, sys, time, threading, webbrowser, json, difflib, contextlib, multiprocessing, functools
from dataclasses import asdict
from urllib.parse import parse_qs, urlparse
from http.server import HTTPServer, BaseHTTPRequestHandler
from tinygrad.helpers import getenv, to_function_name
from tinygrad.ops import TrackedRewriteContext, UOp, UOps, UPat, lines
from tinygrad.engine.graph import uops_colors, word_wrap
from viz.spec import GraphRewriteDetails, GraphRewriteMetadata
def reconstruct_graph(sink:UOp, rewrites:List[Tuple[UOp, UOp, UPat]]) -> Tuple[List[UOp], List[List[str]], List[List[int]]]:
  """Replay `rewrites` on `sink` one at a time and record every intermediate graph.

  Returns (uops, diffs, changed_nodes):
    uops[0] is `sink`; uops[k+1] is the sink after applying rewrites 0..k cumulatively.
    diffs[k] is a unified diff of str(first) vs str(rewritten) for rewrite k.
    changed_nodes[0] is empty; changed_nodes[k+1] holds the id()s of the non-CONST UOps in rewrite k's
    `rewritten.sparents`.
  Raises AssertionError if applying a rewrite leaves the previous sink object unchanged.
  """
  uops: List[UOp] = [sink]
  diffs: List[List[str]] = []
  changed_nodes: List[List[int]] = [[]]
  # cumulative key -> replacement map; one entry is added per rewrite
  seen_replaces: Dict[bytes, UOp] = {}
  for step, (before, after, _upat) in enumerate(rewrites):
    prev_sink = uops[-1]
    seen_replaces[before.key] = after
    # hand replace_uop a copy: it memoizes every visited node into the dict it receives
    new_sink = replace_uop(prev_sink, dict(seen_replaces))
    # sanity check
    assert new_sink is not prev_sink, f"rewritten sink wasn't rewritten! {step}\n{new_sink}\n{prev_sink}"
    changed_nodes.append([id(node) for node in after.sparents if node.op is not UOps.CONST])
    diffs.append(list(difflib.unified_diff(str(before).splitlines(), str(after).splitlines())))
    uops.append(new_sink)
  return uops, diffs, changed_nodes
def uop_to_json(x:UOp) -> Dict[int, Tuple[str, str, List[int], str, str]]:
  """Serialize the UOp graph rooted at `x` into a JSON-friendly dict for the frontend.

  CONST UOps are not emitted as nodes. Each CONST source is instead folded into its consumer's label as a
  "CONST<src index> <value>" line.

  Returns {id(u): (label, str(u.dtype), ids of u's non-CONST sources, str(u.arg), fill color)} for every
  non-CONST UOp in x.sparents. The color comes from uops_colors and defaults to "#ffffff".
  """
  assert isinstance(x, UOp)
  # hoisted out of the per-node loop: the env var does not change while building one graph
  with_shape = getenv("WITH_SHAPE")
  graph: Dict[int, Tuple[str, str, List[int], str, str]] = {}
  for u in x.sparents:
    if u.op is UOps.CONST: continue
    # [5:] drops the 5-char "UOps." prefix from the enum's str()
    arg_str = (' '+word_wrap(str(u.arg).replace(':', ''))) if u.arg is not None else ''
    label = f"{str(u.op)[5:]}{arg_str}\n{str(u.dtype)}"
    # use `s`, not `x`, so the function parameter is not rebound inside the loop
    for idx, s in enumerate(u.src):
      if s.op is UOps.CONST: label += f"\nCONST{idx} {s.arg:g}"
    if with_shape:
      with contextlib.suppress(Exception): # if the UOp is indexed already it's fine
        if u.st is not None: label += f"\n{u.st.shape}"
    graph[id(u)] = (label, str(u.dtype), [id(s) for s in u.src if s.op is not UOps.CONST], str(u.arg), uops_colors.get(u.op, "#ffffff"))
  return graph
def replace_uop(base:UOp, replaces:Dict[bytes, UOp]) -> UOp:
  """Return `base` with every UOp whose `key` appears in `replaces` swapped for its mapped replacement.

  `replaces` doubles as a memo table: the result for every visited node is written back into it under that
  node's key. Callers that want to keep their map intact must pass a copy.

  A replacement is returned as-is; its own sources are not revisited. A node none of whose sources changed
  is returned as the same object. Otherwise a new UOp is built with the same op/dtype/arg and the
  rewritten sources.
  """
  found = replaces.get(base.key)
  if found is not None: return found
  rebuilt_srcs = tuple(replace_uop(child, replaces) for child in base.src)
  if rebuilt_srcs == base.src:
    result = base
  else:
    result = UOp(base.op, base.dtype, rebuilt_srcs, base.arg)
  replaces[base.key] = result
  return result
def load_kernels(contexts:List[TrackedRewriteContext]) -> DefaultDict[Optional[str], List[Tuple[GraphRewriteMetadata, TrackedRewriteContext]]]:
  """Group tracked rewrite contexts by the function name of the kernel they belong to.

  Contexts with no kernel (ctx.kernel is None) are grouped under the key None, hence the Optional[str] key type.
  Each entry pairs a context with the GraphRewriteMetadata listed in the frontend. The metadata holds:
    - the context's loc,
    - the stripped source line at that loc,
    - the kernel name,
    - a (location, printable()) pair for every UPat recorded in ctx.rewrites.
  """
  # named `grouped` so it does not shadow the module-global `kernels` that Handler serves from
  grouped: DefaultDict[Optional[str], List[Tuple[GraphRewriteMetadata, TrackedRewriteContext]]] = defaultdict(list)
  for ctx in contexts:
    name = to_function_name(ctx.kernel.name) if ctx.kernel is not None else None
    upats = [(upat.location, upat.printable()) for _, _, upat in ctx.rewrites]
    # NOTE(review): the -1 implies ctx.loc is (filename, 1-based line number) and lines() returns that file's
    # lines -- confirm against tinygrad.ops
    src_line = lines(ctx.loc[0])[ctx.loc[1]-1].strip()
    grouped[name].append((GraphRewriteMetadata(ctx.loc, src_line, name, upats), ctx))
  return grouped
@functools.lru_cache(maxsize=None)
def get_src(k) -> Optional[str]:
  """Return the rendered program source for kernel `k`, or None when `k` is falsy (e.g. None).

  Results are memoized without bound, so each distinct `k` is lowered with to_program() at most once.
  """
  if not k: return None
  program = k.to_program()
  return program.src
class Handler(BaseHTTPRequestHandler):
  """Serves the viz frontend and its JSON API over GET.

  Routes:
    /favicon.svg  -> favicon.svg from this script's directory
    /             -> index.html from this script's directory
    /kernels      -> without a query: one list of GraphRewriteMetadata dicts per kernel group
                     with ?kernel=K&idx=I: GraphRewriteDetails for rewrite context I of kernel group K
  Any other path gets a 404. A /kernels query with a missing, non-integer or out-of-range kernel/idx gets a 400.
  Reads the module-global `kernels` built in __main__.
  """
  def do_GET(self):
    url = urlparse(self.path)
    # build the whole response before sending anything, so exactly one status line and one set of headers go out
    status, content_type, ret = 404, None, b""
    if url.path == "/favicon.svg":
      status, content_type, ret = 200, "image/svg+xml", self._read_static("favicon.svg")
    elif url.path == "/":
      status, content_type, ret = 200, "text/html", self._read_static("index.html")
    elif url.path == "/kernels":
      status, content_type, ret = self._kernels_response(parse_qs(url.query))
    self.send_response(status)
    if content_type is not None: self.send_header("Content-type", content_type)
    # end_headers() flushes the buffered status line and headers; without it nothing is written to the client
    self.end_headers()
    return self.wfile.write(ret)

  @staticmethod
  def _read_static(fn:str) -> bytes:
    # static frontend assets live next to this script
    with open(os.path.join(os.path.dirname(__file__), fn), "rb") as f: return f.read()

  @staticmethod
  def _kernels_response(query:Dict[str, List[str]]) -> Tuple[int, Optional[str], bytes]:
    # returns (status, content type or None, body) for the /kernels endpoint
    if (qkernel:=query.get("kernel")) is None:
      return 200, "application/json", json.dumps([[asdict(metadata) for metadata, _ in v] for v in kernels.values()]).encode()
    try: metadata, ctx = list(kernels.values())[int(qkernel[0])][int(query["idx"][0])]
    except (KeyError, IndexError, ValueError): return 400, None, b""  # missing idx, bad integer, or out of range
    graphs, diffs, changed_nodes = reconstruct_graph(ctx.sink, ctx.rewrites)
    return 200, "application/json", json.dumps(asdict(GraphRewriteDetails(**asdict(metadata), graphs=list(map(uop_to_json, graphs)),
      diffs=diffs, changed_nodes=changed_nodes, kernel_code=get_src(ctx.kernel)))).encode()
# Set during shutdown to make the reloader() polling loop exit.
stop_reloader = threading.Event()
# Read via tinygrad's getenv with default 1; __main__ opens the viz page in a web browser only when this is truthy.
BROWSER = getenv("BROWSER", 1)
def reloader():
  """Poll this script's mtime every 0.1s and re-exec the interpreter with the same argv once it changes.

  Returns when the module-global `stop_reloader` event is set.
  """
  initial_mtime = os.stat(__file__).st_mtime
  while True:
    if stop_reloader.is_set(): break
    if os.stat(__file__).st_mtime != initial_mtime:
      print("reloading server...")
      # execv replaces the current process image; it does not return on success
      os.execv(sys.executable, [sys.executable, *sys.argv])
    time.sleep(0.1)
if __name__ == "__main__":
  multiprocessing.current_process().name = "VizProcess" # disallow opening of devices
  print("*** viz is starting")
  # NOTE(security): pickle.load can execute arbitrary code embedded in the file, and /tmp is usually
  # world-writable. Only run viz where no untrusted user can write /tmp/rewrites.pkl.
  with open("/tmp/rewrites.pkl", "rb") as f: contexts: List[TrackedRewriteContext] = pickle.load(f)
  print("*** unpickled saved rewrites")
  # module-global on purpose: Handler reads `kernels` when serving /kernels
  kernels = load_kernels(contexts)
  print("*** loaded kernels")
  server = HTTPServer(('', 8000), Handler)
  reloader_thread = threading.Thread(target=reloader)
  reloader_thread.start()
  try:
    if BROWSER: webbrowser.open("http://localhost:8000")
    server.serve_forever()
  except KeyboardInterrupt:
    print("*** viz is shutting down...")
  finally:
    # Always stop the reloader: it is a non-daemon thread and would otherwise keep the process alive
    # after any error. Also release port 8000.
    stop_reloader.set()
    server.server_close()