mirror of https://github.com/commaai/tinygrad.git
prep for viz move to core [pr] (#6938)
* prep for viz move to core [pr] * polish
This commit is contained in:
parent
e4c0743188
commit
0ecc417dd2
133
viz/serve.py
133
viz/serve.py
|
@ -1,35 +1,54 @@
|
|||
#!/usr/bin/env python3
|
||||
from collections import defaultdict
|
||||
from typing import Any, DefaultDict, Dict, List, Optional, Tuple
|
||||
import pickle, os, sys, time, threading, webbrowser, json, difflib, contextlib, multiprocessing, functools
|
||||
from dataclasses import asdict
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
import multiprocessing, pickle, functools, difflib, os, threading, json, time, sys, webbrowser
|
||||
from http.server import HTTPServer, BaseHTTPRequestHandler
|
||||
from tinygrad.codegen.kernel import Kernel
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
from dataclasses import asdict, dataclass
|
||||
from typing import Any, Dict, List, Tuple, Optional
|
||||
from tinygrad.helpers import getenv, to_function_name, tqdm
|
||||
from tinygrad.ops import TrackedRewriteContext, UOp, UOps, lines
|
||||
from tinygrad.engine.graph import uops_colors, word_wrap
|
||||
from viz.spec import GraphRewriteDetails, GraphRewriteMetadata
|
||||
from tinygrad.engine.graph import word_wrap, uops_colors
|
||||
from tinygrad.codegen.kernel import Kernel
|
||||
|
||||
def reconstruct_graph(ctx:TrackedRewriteContext) -> Tuple[List[UOp], List[List[str]], List[List[int]]]:
|
||||
uops: List[UOp] = [ctx.sink]
|
||||
diffs: List[List[str]] = []
|
||||
changed_nodes: List[List[int]] = []
|
||||
seen_replaces: Dict[UOp, UOp] = {}
|
||||
for i, (first, rewritten, upat) in enumerate(ctx.rewrites):
|
||||
# first, rewrite this UOp with the current rewrite + all the seen rewrites before this
|
||||
seen_replaces[first] = rewritten
|
||||
new_sink = replace_uop(uops[-1], {**seen_replaces})
|
||||
# sanity check
|
||||
if new_sink is uops[-1]:
|
||||
raise AssertionError(f"rewritten sink wasn't rewritten! {i} {upat.location}")
|
||||
# update ret data
|
||||
changed_nodes.append([id(x) for x in rewritten.sparents if x.op is not UOps.CONST])
|
||||
diffs.append(list(difflib.unified_diff(str(first).splitlines(), str(rewritten).splitlines())))
|
||||
uops.append(new_sink)
|
||||
return uops, diffs, changed_nodes
|
||||
# ** API spec
|
||||
|
||||
def uop_to_json(x:UOp) -> Dict[int, Tuple[str, str, List[int], str, str]]:
|
||||
@dataclass
|
||||
class GraphRewriteMetadata:
|
||||
"""Specifies metadata about a single call to graph_rewrite"""
|
||||
loc: Tuple[str, int]
|
||||
"""File_path, Lineno"""
|
||||
code_line: str
|
||||
"""The Python line calling graph_rewrite"""
|
||||
kernel_name: Optional[str]
|
||||
"""The kernel calling graph_rewrite"""
|
||||
upats: List[Tuple[Tuple[str, int], str]]
|
||||
"""List of all the applied UPats"""
|
||||
|
||||
@dataclass
|
||||
class GraphRewriteDetails(GraphRewriteMetadata):
|
||||
"""Full details about a single call to graph_rewrite"""
|
||||
graphs: List[Dict[int, Tuple[str, str, List[int], str, str]]]
|
||||
"""Sink at every step of graph_rewrite"""
|
||||
diffs: List[List[str]]
|
||||
""".diff style before and after of the rewritten UOp child"""
|
||||
changed_nodes: List[List[int]]
|
||||
"""Nodes that changed at every step of graph_rewrite"""
|
||||
kernel_code: Optional[str]
|
||||
"""The program after all rewrites"""
|
||||
|
||||
# ** API functions
|
||||
|
||||
def get_metadata(contexts:List[Tuple[Any, List[TrackedRewriteContext]]]) -> List[List[Tuple[Any, TrackedRewriteContext, GraphRewriteMetadata]]]:
|
||||
kernels: Dict[Optional[str], List[Tuple[Any, TrackedRewriteContext, GraphRewriteMetadata]]] = {}
|
||||
for k,ctxs in contexts:
|
||||
name = to_function_name(k.name) if isinstance(k, Kernel) else None
|
||||
for ctx in ctxs:
|
||||
if ctx.sink.op is UOps.CONST: continue
|
||||
upats = [(upat.location, upat.printable()) for _,_,upat in ctx.rewrites]
|
||||
if name not in kernels: kernels[name] = []
|
||||
kernels[name].append((k, ctx, GraphRewriteMetadata(ctx.loc, lines(ctx.loc[0])[ctx.loc[1]-1].strip(), name, upats)))
|
||||
return list(kernels.values())
|
||||
|
||||
def _uop_to_json(x:UOp) -> Dict[int, Tuple[str, str, List[int], str, str]]:
|
||||
assert isinstance(x, UOp)
|
||||
graph: Dict[int, Tuple[str, str, List[int], str, str]] = {}
|
||||
for u in x.sparents:
|
||||
|
@ -37,31 +56,32 @@ def uop_to_json(x:UOp) -> Dict[int, Tuple[str, str, List[int], str, str]]:
|
|||
label = f"{str(u.op)[5:]}{(' '+word_wrap(str(u.arg).replace(':', ''))) if u.arg is not None else ''}\n{str(u.dtype)}"
|
||||
for idx,x in enumerate(u.src):
|
||||
if x.op is UOps.CONST: label += f"\nCONST{idx} {x.arg:g}"
|
||||
if getenv("WITH_SHAPE"):
|
||||
with contextlib.suppress(Exception): # if the UOp is indexed already it's fine
|
||||
if u.st is not None: label += f"\n{u.st.shape}"
|
||||
graph[id(u)] = (label, str(u.dtype), [id(x) for x in u.src if x.op is not UOps.CONST], str(u.arg), uops_colors.get(u.op, "#ffffff"))
|
||||
return graph
|
||||
|
||||
def replace_uop(base:UOp, replaces:Dict[UOp, UOp]) -> UOp:
|
||||
def _replace_uop(base:UOp, replaces:Dict[UOp, UOp]) -> UOp:
|
||||
if (found:=replaces.get(base)) is not None: return found
|
||||
replaces[base] = ret = base.replace(src=tuple(replace_uop(x, replaces) for x in base.src))
|
||||
replaces[base] = ret = base.replace(src=tuple(_replace_uop(x, replaces) for x in base.src))
|
||||
return ret
|
||||
|
||||
def load_kernels(contexts:List[Tuple[Any, List[TrackedRewriteContext]]]) -> DefaultDict[str, List[Tuple[GraphRewriteMetadata, \
|
||||
TrackedRewriteContext, Any]]]:
|
||||
kernels = defaultdict(list)
|
||||
for k,rewrites in contexts:
|
||||
if isinstance(k, Kernel): name = to_function_name(k.name)
|
||||
else: name = None
|
||||
for ctx in rewrites:
|
||||
if ctx.sink.op is UOps.CONST: continue
|
||||
upats = [(upat.location, upat.printable()) for _,_,upat in ctx.rewrites]
|
||||
kernels[name].append((GraphRewriteMetadata(ctx.loc, lines(ctx.loc[0])[ctx.loc[1]-1].strip(), name, upats), ctx, k))
|
||||
return kernels
|
||||
|
||||
@functools.lru_cache(None)
|
||||
def get_src(k) -> Optional[str]: return k.to_program().src if isinstance(k, Kernel) else None
|
||||
def _prg(k:Optional[Kernel]) -> Optional[str]: return k.to_program().src if isinstance(k, Kernel) else None
|
||||
def get_details(k:Any, ctx:TrackedRewriteContext, metadata:GraphRewriteMetadata) -> GraphRewriteDetails:
|
||||
g = GraphRewriteDetails(**asdict(metadata), graphs=[_uop_to_json(ctx.sink)], diffs=[], changed_nodes=[], kernel_code=_prg(k))
|
||||
replaces: Dict[UOp, UOp] = {}
|
||||
sink = ctx.sink
|
||||
for i,(u0,u1,upat) in enumerate(ctx.rewrites):
|
||||
# first, rewrite this UOp with the current rewrite + all the seen rewrites before this
|
||||
replaces[u0] = u1
|
||||
new_sink = _replace_uop(sink, {**replaces})
|
||||
# sanity check
|
||||
if new_sink is sink:
|
||||
raise AssertionError(f"rewritten sink wasn't rewritten! {i} {upat.location}")
|
||||
# update ret data
|
||||
g.changed_nodes.append([id(x) for x in u1.sparents if x.op is not UOps.CONST])
|
||||
g.diffs.append(list(difflib.unified_diff(str(u0).splitlines(), str(u1).splitlines())))
|
||||
g.graphs.append(_uop_to_json(sink:=new_sink))
|
||||
return g
|
||||
|
||||
# ** HTTP server
|
||||
|
||||
class Handler(BaseHTTPRequestHandler):
|
||||
def do_GET(self):
|
||||
|
@ -83,17 +103,15 @@ class Handler(BaseHTTPRequestHandler):
|
|||
self.end_headers()
|
||||
query = parse_qs(url.query)
|
||||
if (qkernel:=query.get("kernel")) is not None:
|
||||
metadata, ctx, k = list(kernels.values())[int(qkernel[0])][int(query["idx"][0])]
|
||||
graphs, diffs, changed_nodes = reconstruct_graph(ctx)
|
||||
ret = json.dumps(asdict(GraphRewriteDetails(**asdict(metadata), graphs=list(map(uop_to_json, graphs)),
|
||||
diffs=diffs, changed_nodes=changed_nodes, kernel_code=get_src(k)))).encode()
|
||||
else: ret = json.dumps([list(map(lambda x:asdict(x[0]), v)) for v in kernels.values()]).encode()
|
||||
ret = json.dumps(asdict(get_details(*kernels[int(qkernel[0])][int(query["idx"][0])]))).encode()
|
||||
else: ret = json.dumps([list(map(lambda x:asdict(x[2]), v)) for v in kernels]).encode()
|
||||
else:
|
||||
self.send_response(404)
|
||||
ret = b""
|
||||
return self.wfile.write(ret)
|
||||
|
||||
BROWSER = getenv("BROWSER", 1)
|
||||
# ** main loop
|
||||
|
||||
stop_reloader = threading.Event()
|
||||
def reloader():
|
||||
mtime = os.stat(__file__).st_mtime
|
||||
|
@ -108,18 +126,17 @@ if __name__ == "__main__":
|
|||
print("*** viz is starting")
|
||||
with open("/tmp/rewrites.pkl", "rb") as f: contexts: List[Tuple[Any, List[TrackedRewriteContext]]] = pickle.load(f)
|
||||
print("*** unpickled saved rewrites")
|
||||
kernels = load_kernels(contexts)
|
||||
kernels = get_metadata(contexts)
|
||||
if getenv("FUZZ_VIZ"):
|
||||
for v in tqdm(kernels.values()):
|
||||
for _,ctx,_ in v: reconstruct_graph(ctx)
|
||||
ret = [get_details(*args) for v in tqdm(kernels) for args in v]
|
||||
print(f"fuzzed {len(ret)} rewrite details")
|
||||
print("*** loaded kernels")
|
||||
server = HTTPServer(('', 8000), Handler)
|
||||
st = time.perf_counter()
|
||||
reloader_thread = threading.Thread(target=reloader)
|
||||
reloader_thread.start()
|
||||
if BROWSER: webbrowser.open("http://localhost:8000")
|
||||
try:
|
||||
server.serve_forever()
|
||||
if getenv("BROWSER", 1): webbrowser.open("http://localhost:8000")
|
||||
try: server.serve_forever()
|
||||
except KeyboardInterrupt:
|
||||
print("*** viz is shutting down...")
|
||||
stop_reloader.set()
|
||||
|
|
26
viz/spec.py
26
viz/spec.py
|
@ -1,26 +0,0 @@
|
|||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class GraphRewriteMetadata:
|
||||
"""Specifies metadata about a single call to graph_rewrite"""
|
||||
loc: Tuple[str, int]
|
||||
"""File_path, Lineno"""
|
||||
code_line: str
|
||||
"""The Python line calling graph_rewrite"""
|
||||
kernel_name: Optional[str]
|
||||
"""The kernel calling graph_rewrite"""
|
||||
upats: List[Tuple[Tuple[str, int], str]]
|
||||
"""List of all the applied UPats"""
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class GraphRewriteDetails(GraphRewriteMetadata):
|
||||
"""Full details about a single call to graph_rewrite"""
|
||||
graphs: List[Dict[int, Tuple[str, str, List[int], str, str]]]
|
||||
"""Sink at every step of graph_rewrite"""
|
||||
diffs: List[List[str]]
|
||||
""".diff style before and after of the rewritten UOp child"""
|
||||
changed_nodes: List[List[int]]
|
||||
"""Nodes that changed at every step of graph_rewrite"""
|
||||
kernel_code: Optional[str]
|
||||
"""The program after all rewrites"""
|
|
@ -3,15 +3,13 @@ import unittest
|
|||
import os, itertools
|
||||
os.environ["TRACK_MATCH_STATS"] = "2"
|
||||
os.environ["PRINT_MATCH_STATS"] = "0"
|
||||
from tinygrad import Tensor
|
||||
from tinygrad import Tensor, dtypes
|
||||
from tinygrad.engine.realize import lower_schedule
|
||||
from tinygrad.dtype import PtrDType
|
||||
from tinygrad.helpers import Context, all_same, getenv
|
||||
from tinygrad.ops import TrackedRewriteContext, UOp, UOps, graph_rewrite, PatternMatcher, UPat, contexts, KernelInfo, BinaryOps, track_rewrites
|
||||
from tinygrad.dtype import dtypes, PtrDType
|
||||
from tinygrad.helpers import Context, all_same, DEBUG, getenv
|
||||
from tinygrad.codegen.uopgraph import sym, devectorize, float4_folding
|
||||
from test.external.process_replay.helpers import print_diff
|
||||
from viz.serve import reconstruct_graph, uop_to_json, load_kernels
|
||||
from viz.spec import GraphRewriteMetadata
|
||||
from viz.serve import GraphRewriteMetadata, get_metadata, get_details, _uop_to_json
|
||||
|
||||
def group_rewrites(kernels:List[GraphRewriteMetadata]): return {k:list(v) for k,v in itertools.groupby(kernels, lambda x:x.loc)}
|
||||
|
||||
|
@ -22,7 +20,7 @@ class TestViz(unittest.TestCase):
|
|||
|
||||
def assert_valid_ctx(self, contexts:List[Tuple[Any,List[TrackedRewriteContext]]]):
|
||||
assert len(contexts) != 0
|
||||
load_kernels(contexts)
|
||||
get_metadata(contexts)
|
||||
|
||||
def assert_valid_graph(self, t):
|
||||
contexts.clear()
|
||||
|
@ -41,10 +39,10 @@ class TestViz(unittest.TestCase):
|
|||
schedule2 = Tensor.zeros(4, 1).contiguous().exp().schedule()
|
||||
list(lower_schedule(schedule1))
|
||||
list(lower_schedule(schedule2))
|
||||
with Context(TRACK_MATCH_STATS=0): ret = list(load_kernels(contexts).values())
|
||||
with Context(TRACK_MATCH_STATS=0): ret = get_metadata(contexts)
|
||||
assert len(ret) == 3
|
||||
assert all(len([x for x,_,_ in y if "schedule" in x.loc[0]]) == 0 for y in ret[1:])
|
||||
assert all(len([x for x,_,_ in y if "uopgraph" in x.loc[0]]) != 0 for y in ret[1:])
|
||||
assert all(len([x for _,_,x in y if "schedule" in x.loc[0]]) == 0 for y in ret[1:])
|
||||
assert all(len([x for _,_,x in y if "uopgraph" in x.loc[0]]) != 0 for y in ret[1:])
|
||||
|
||||
def test_gemm_diff(self):
|
||||
x = Tensor.empty(64, 64).realize()
|
||||
|
@ -64,10 +62,10 @@ class TestViz(unittest.TestCase):
|
|||
@track_rewrites
|
||||
def f(k): return graph_rewrite(sink, pm)
|
||||
ret = f("test_rewrite")
|
||||
if DEBUG >= 4: print_diff(sink, ret)
|
||||
graphs,_,_ = reconstruct_graph(contexts[0][1][0])
|
||||
assert graphs[-1].key == ret.key
|
||||
self.assert_valid_ctx(contexts)
|
||||
args = get_metadata(contexts)[0][0]
|
||||
g = get_details(*args)
|
||||
assert g.graphs[-1] == _uop_to_json(ret)
|
||||
|
||||
def test_devectorize_viz(self):
|
||||
sink = UOp(UOps.SINK, dtypes.void, arg=KernelInfo(local_dims=1, upcasted=1, dont_use_locals=False), src=(
|
||||
|
@ -96,8 +94,7 @@ class TestViz(unittest.TestCase):
|
|||
pm = sym+(devectorize+float4_folding)
|
||||
@track_rewrites
|
||||
def f(k): return graph_rewrite(sink, pm)
|
||||
new_sink = f("test_rewrite")
|
||||
if DEBUG >= 4: print_diff(sink, new_sink, unified=0)
|
||||
f("test_rewrite")
|
||||
self.assert_valid_ctx(contexts)
|
||||
assert all(ctx.loc[0].split("/")[-1] == __file__.split("/")[-1] for _,ctxs in contexts for ctx in ctxs)
|
||||
|
||||
|
@ -111,9 +108,9 @@ class TestViz(unittest.TestCase):
|
|||
a = Tensor.empty(4, 4).contiguous().realize()+2
|
||||
b = Tensor.empty(4, 4).contiguous().realize()+2
|
||||
Tensor.schedule(a, b)
|
||||
with Context(TRACK_MATCH_STATS=0): kernels = load_kernels(contexts)
|
||||
with Context(TRACK_MATCH_STATS=0): kernels = get_metadata(contexts)
|
||||
self.assertEqual(len(kernels), 1)
|
||||
rewrites = [x[0] for x in list(kernels.values())[0]]
|
||||
rewrites = [x[2] for x in kernels[0]]
|
||||
assert all(len(v) == 1 for k,v in group_rewrites(rewrites).items() if "schedule.py" in k)
|
||||
|
||||
def test_no_dedup_different_opts(self):
|
||||
|
@ -122,23 +119,23 @@ class TestViz(unittest.TestCase):
|
|||
s = a.schedule()
|
||||
with Context(NOOPT=1): list(lower_schedule(s.copy()))
|
||||
with Context(NOOPT=0): list(lower_schedule(s.copy()))
|
||||
with Context(TRACK_MATCH_STATS=0): kernels = list(load_kernels(contexts).values())[1:]
|
||||
with Context(TRACK_MATCH_STATS=0): kernels = get_metadata(contexts)[1:]
|
||||
self.assertEqual(len(kernels), 2)
|
||||
rewrites = [x[0] for x in kernels[0]]
|
||||
rewrites = [x[2] for x in kernels[0]]
|
||||
assert all(len(v) == 1 for _,v in group_rewrites(rewrites).items())
|
||||
|
||||
def test_fold_const_nodes(self):
|
||||
a = Tensor.empty(4, 4)+2
|
||||
contexts.clear()
|
||||
sink = a.schedule()[-1].ast
|
||||
ret = uop_to_json(sink)
|
||||
ret = _uop_to_json(sink)
|
||||
assert not any(v[0].startswith("CONST") for v in ret.values())
|
||||
assert len([x for x in ret.values() if "CONST" in x[0]]) == 1
|
||||
|
||||
@unittest.skip("VIZ for a single CONST isn't supported anymore")
|
||||
def test_no_fold_single_const(self):
|
||||
node = UOp(UOps.CONST, dtypes.float, (), 1.0)
|
||||
ret = uop_to_json(node, base=node)
|
||||
ret = _uop_to_json(node, base=node)
|
||||
assert len(ret) == 1
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
Loading…
Reference in New Issue