prep for viz move to core [pr] (#6938)

* prep for viz move to core [pr]

* polish
qazal 2024-10-07 18:24:04 +03:00 committed by GitHub
parent e4c0743188
commit 0ecc417dd2
3 changed files with 93 additions and 105 deletions

viz/serve.py

@@ -1,35 +1,54 @@
#!/usr/bin/env python3
from collections import defaultdict
from typing import Any, DefaultDict, Dict, List, Optional, Tuple
import pickle, os, sys, time, threading, webbrowser, json, difflib, contextlib, multiprocessing, functools
from dataclasses import asdict
from urllib.parse import parse_qs, urlparse
import multiprocessing, pickle, functools, difflib, os, threading, json, time, sys, webbrowser
from http.server import HTTPServer, BaseHTTPRequestHandler
from tinygrad.codegen.kernel import Kernel
from urllib.parse import parse_qs, urlparse
from dataclasses import asdict, dataclass
from typing import Any, Dict, List, Tuple, Optional
from tinygrad.helpers import getenv, to_function_name, tqdm
from tinygrad.ops import TrackedRewriteContext, UOp, UOps, lines
from tinygrad.engine.graph import uops_colors, word_wrap
from viz.spec import GraphRewriteDetails, GraphRewriteMetadata
from tinygrad.engine.graph import word_wrap, uops_colors
from tinygrad.codegen.kernel import Kernel
def reconstruct_graph(ctx:TrackedRewriteContext) -> Tuple[List[UOp], List[List[str]], List[List[int]]]:
uops: List[UOp] = [ctx.sink]
diffs: List[List[str]] = []
changed_nodes: List[List[int]] = []
seen_replaces: Dict[UOp, UOp] = {}
for i, (first, rewritten, upat) in enumerate(ctx.rewrites):
# first, rewrite this UOp with the current rewrite + all the seen rewrites before this
seen_replaces[first] = rewritten
new_sink = replace_uop(uops[-1], {**seen_replaces})
# sanity check
if new_sink is uops[-1]:
raise AssertionError(f"rewritten sink wasn't rewritten! {i} {upat.location}")
# update ret data
changed_nodes.append([id(x) for x in rewritten.sparents if x.op is not UOps.CONST])
diffs.append(list(difflib.unified_diff(str(first).splitlines(), str(rewritten).splitlines())))
uops.append(new_sink)
return uops, diffs, changed_nodes
# ** API spec
def uop_to_json(x:UOp) -> Dict[int, Tuple[str, str, List[int], str, str]]:
@dataclass
class GraphRewriteMetadata:
"""Specifies metadata about a single call to graph_rewrite"""
loc: Tuple[str, int]
"""File_path, Lineno"""
code_line: str
"""The Python line calling graph_rewrite"""
kernel_name: Optional[str]
"""The kernel calling graph_rewrite"""
upats: List[Tuple[Tuple[str, int], str]]
"""List of all the applied UPats"""
@dataclass
class GraphRewriteDetails(GraphRewriteMetadata):
"""Full details about a single call to graph_rewrite"""
graphs: List[Dict[int, Tuple[str, str, List[int], str, str]]]
"""Sink at every step of graph_rewrite"""
diffs: List[List[str]]
""".diff style before and after of the rewritten UOp child"""
changed_nodes: List[List[int]]
"""Nodes that changed at every step of graph_rewrite"""
kernel_code: Optional[str]
"""The program after all rewrites"""
# ** API functions
def get_metadata(contexts:List[Tuple[Any, List[TrackedRewriteContext]]]) -> List[List[Tuple[Any, TrackedRewriteContext, GraphRewriteMetadata]]]:
kernels: Dict[Optional[str], List[Tuple[Any, TrackedRewriteContext, GraphRewriteMetadata]]] = {}
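# group graph_rewrite calls by kernel name (None for calls made outside a Kernel); each entry is (kernel, ctx, metadata), the exact argument tuple get_details takes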
for k,ctxs in contexts:
name = to_function_name(k.name) if isinstance(k, Kernel) else None
for ctx in ctxs:
if ctx.sink.op is UOps.CONST: continue
upats = [(upat.location, upat.printable()) for _,_,upat in ctx.rewrites]
if name not in kernels: kernels[name] = []
kernels[name].append((k, ctx, GraphRewriteMetadata(ctx.loc, lines(ctx.loc[0])[ctx.loc[1]-1].strip(), name, upats)))
return list(kernels.values())
def _uop_to_json(x:UOp) -> Dict[int, Tuple[str, str, List[int], str, str]]:
assert isinstance(x, UOp)
graph: Dict[int, Tuple[str, str, List[int], str, str]] = {}
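# each entry maps id(uop) -> (label, dtype, src ids, arg, fill color); CONST sources are folded into the parent's label instead of becoming nodes of their own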
for u in x.sparents:
@@ -37,31 +56,32 @@ def uop_to_json(x:UOp) -> Dict[int, Tuple[str, str, List[int], str, str]]:
label = f"{str(u.op)[5:]}{(' '+word_wrap(str(u.arg).replace(':', ''))) if u.arg is not None else ''}\n{str(u.dtype)}"
for idx,x in enumerate(u.src):
if x.op is UOps.CONST: label += f"\nCONST{idx} {x.arg:g}"
if getenv("WITH_SHAPE"):
with contextlib.suppress(Exception): # if the UOp is indexed already it's fine
if u.st is not None: label += f"\n{u.st.shape}"
graph[id(u)] = (label, str(u.dtype), [id(x) for x in u.src if x.op is not UOps.CONST], str(u.arg), uops_colors.get(u.op, "#ffffff"))
return graph
def replace_uop(base:UOp, replaces:Dict[UOp, UOp]) -> UOp:
def _replace_uop(base:UOp, replaces:Dict[UOp, UOp]) -> UOp:
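# memoized bottom-up substitution: replaces maps old UOp -> new UOp and grows as sources are rebuilt, so shared subtrees are only rewritten once (callers pass {**replaces} so each step keeps its own memo)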
if (found:=replaces.get(base)) is not None: return found
replaces[base] = ret = base.replace(src=tuple(replace_uop(x, replaces) for x in base.src))
replaces[base] = ret = base.replace(src=tuple(_replace_uop(x, replaces) for x in base.src))
return ret
def load_kernels(contexts:List[Tuple[Any, List[TrackedRewriteContext]]]) -> DefaultDict[str, List[Tuple[GraphRewriteMetadata, \
TrackedRewriteContext, Any]]]:
kernels = defaultdict(list)
for k,rewrites in contexts:
if isinstance(k, Kernel): name = to_function_name(k.name)
else: name = None
for ctx in rewrites:
if ctx.sink.op is UOps.CONST: continue
upats = [(upat.location, upat.printable()) for _,_,upat in ctx.rewrites]
kernels[name].append((GraphRewriteMetadata(ctx.loc, lines(ctx.loc[0])[ctx.loc[1]-1].strip(), name, upats), ctx, k))
return kernels
@functools.lru_cache(None)
def get_src(k) -> Optional[str]: return k.to_program().src if isinstance(k, Kernel) else None
def _prg(k:Optional[Kernel]) -> Optional[str]: return k.to_program().src if isinstance(k, Kernel) else None
def get_details(k:Any, ctx:TrackedRewriteContext, metadata:GraphRewriteMetadata) -> GraphRewriteDetails:
g = GraphRewriteDetails(**asdict(metadata), graphs=[_uop_to_json(ctx.sink)], diffs=[], changed_nodes=[], kernel_code=_prg(k))
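# graphs[0] is the sink before any rewrite; each loop iteration below appends the sink after one more rewrite has been applied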
replaces: Dict[UOp, UOp] = {}
sink = ctx.sink
for i,(u0,u1,upat) in enumerate(ctx.rewrites):
# first, rewrite this UOp with the current rewrite + all the seen rewrites before this
replaces[u0] = u1
new_sink = _replace_uop(sink, {**replaces})
# sanity check
if new_sink is sink:
raise AssertionError(f"rewritten sink wasn't rewritten! {i} {upat.location}")
# update ret data
g.changed_nodes.append([id(x) for x in u1.sparents if x.op is not UOps.CONST])
g.diffs.append(list(difflib.unified_diff(str(u0).splitlines(), str(u1).splitlines())))
g.graphs.append(_uop_to_json(sink:=new_sink))
return g
# ** HTTP server
class Handler(BaseHTTPRequestHandler):
def do_GET(self):
@@ -83,17 +103,15 @@ class Handler(BaseHTTPRequestHandler):
self.end_headers()
query = parse_qs(url.query)
if (qkernel:=query.get("kernel")) is not None:
metadata, ctx, k = list(kernels.values())[int(qkernel[0])][int(query["idx"][0])]
graphs, diffs, changed_nodes = reconstruct_graph(ctx)
ret = json.dumps(asdict(GraphRewriteDetails(**asdict(metadata), graphs=list(map(uop_to_json, graphs)),
diffs=diffs, changed_nodes=changed_nodes, kernel_code=get_src(k)))).encode()
else: ret = json.dumps([list(map(lambda x:asdict(x[0]), v)) for v in kernels.values()]).encode()
ret = json.dumps(asdict(get_details(*kernels[int(qkernel[0])][int(query["idx"][0])]))).encode()
else: ret = json.dumps([list(map(lambda x:asdict(x[2]), v)) for v in kernels]).encode()
else:
self.send_response(404)
ret = b""
return self.wfile.write(ret)
BROWSER = getenv("BROWSER", 1)
# ** main loop
stop_reloader = threading.Event()
def reloader():
mtime = os.stat(__file__).st_mtime
@@ -108,18 +126,17 @@ if __name__ == "__main__":
print("*** viz is starting")
with open("/tmp/rewrites.pkl", "rb") as f: contexts: List[Tuple[Any, List[TrackedRewriteContext]]] = pickle.load(f)
print("*** unpickled saved rewrites")
kernels = load_kernels(contexts)
kernels = get_metadata(contexts)
if getenv("FUZZ_VIZ"):
for v in tqdm(kernels.values()):
for _,ctx,_ in v: reconstruct_graph(ctx)
ret = [get_details(*args) for v in tqdm(kernels) for args in v]
print(f"fuzzed {len(ret)} rewrite details")
print("*** loaded kernels")
server = HTTPServer(('', 8000), Handler)
st = time.perf_counter()
reloader_thread = threading.Thread(target=reloader)
reloader_thread.start()
if BROWSER: webbrowser.open("http://localhost:8000")
try:
server.serve_forever()
if getenv("BROWSER", 1): webbrowser.open("http://localhost:8000")
try: server.serve_forever()
except KeyboardInterrupt:
print("*** viz is shutting down...")
stop_reloader.set()
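
A minimal sketch (not part of the diff) of how the refactored pieces compose outside the HTTP server, mirroring the FUZZ_VIZ path above. Every name used here appears in the new serve.py; the only assumption is that /tmp/rewrites.pkl was written by a tracked run, as in the __main__ block.

import pickle
from viz.serve import get_metadata, get_details

with open("/tmp/rewrites.pkl", "rb") as f: contexts = pickle.load(f)  # rewrites saved by a tracked run
for kernel_group in get_metadata(contexts):  # one inner list per kernel name
  for args in kernel_group:                  # args = (kernel, TrackedRewriteContext, GraphRewriteMetadata)
    details = get_details(*args)             # replays every rewrite step for that graph_rewrite call
    print(details.kernel_name, len(details.graphs), "graphs")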

viz/spec.py

@@ -1,26 +0,0 @@
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
@dataclass(frozen=True)
class GraphRewriteMetadata:
"""Specifies metadata about a single call to graph_rewrite"""
loc: Tuple[str, int]
"""File_path, Lineno"""
code_line: str
"""The Python line calling graph_rewrite"""
kernel_name: Optional[str]
"""The kernel calling graph_rewrite"""
upats: List[Tuple[Tuple[str, int], str]]
"""List of all the applied UPats"""
@dataclass(frozen=True)
class GraphRewriteDetails(GraphRewriteMetadata):
"""Full details about a single call to graph_rewrite"""
graphs: List[Dict[int, Tuple[str, str, List[int], str, str]]]
"""Sink at every step of graph_rewrite"""
diffs: List[List[str]]
""".diff style before and after of the rewritten UOp child"""
changed_nodes: List[List[int]]
"""Nodes that changed at every step of graph_rewrite"""
kernel_code: Optional[str]
"""The program after all rewrites"""

test/test_viz.py

@@ -3,15 +3,13 @@ import unittest
import os, itertools
os.environ["TRACK_MATCH_STATS"] = "2"
os.environ["PRINT_MATCH_STATS"] = "0"
from tinygrad import Tensor
from tinygrad import Tensor, dtypes
from tinygrad.engine.realize import lower_schedule
from tinygrad.dtype import PtrDType
from tinygrad.helpers import Context, all_same, getenv
from tinygrad.ops import TrackedRewriteContext, UOp, UOps, graph_rewrite, PatternMatcher, UPat, contexts, KernelInfo, BinaryOps, track_rewrites
from tinygrad.dtype import dtypes, PtrDType
from tinygrad.helpers import Context, all_same, DEBUG, getenv
from tinygrad.codegen.uopgraph import sym, devectorize, float4_folding
from test.external.process_replay.helpers import print_diff
from viz.serve import reconstruct_graph, uop_to_json, load_kernels
from viz.spec import GraphRewriteMetadata
from viz.serve import GraphRewriteMetadata, get_metadata, get_details, _uop_to_json
def group_rewrites(kernels:List[GraphRewriteMetadata]): return {k:list(v) for k,v in itertools.groupby(kernels, lambda x:x.loc)}
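# note: itertools.groupby only merges consecutive entries, so this assumes metadata sharing a loc arrive adjacent in the list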
@@ -22,7 +20,7 @@ class TestViz(unittest.TestCase):
def assert_valid_ctx(self, contexts:List[Tuple[Any,List[TrackedRewriteContext]]]):
assert len(contexts) != 0
load_kernels(contexts)
get_metadata(contexts)
def assert_valid_graph(self, t):
contexts.clear()
@@ -41,10 +39,10 @@ class TestViz(unittest.TestCase):
schedule2 = Tensor.zeros(4, 1).contiguous().exp().schedule()
list(lower_schedule(schedule1))
list(lower_schedule(schedule2))
with Context(TRACK_MATCH_STATS=0): ret = list(load_kernels(contexts).values())
with Context(TRACK_MATCH_STATS=0): ret = get_metadata(contexts)
assert len(ret) == 3
assert all(len([x for x,_,_ in y if "schedule" in x.loc[0]]) == 0 for y in ret[1:])
assert all(len([x for x,_,_ in y if "uopgraph" in x.loc[0]]) != 0 for y in ret[1:])
assert all(len([x for _,_,x in y if "schedule" in x.loc[0]]) == 0 for y in ret[1:])
assert all(len([x for _,_,x in y if "uopgraph" in x.loc[0]]) != 0 for y in ret[1:])
def test_gemm_diff(self):
x = Tensor.empty(64, 64).realize()
@@ -64,10 +62,10 @@ class TestViz(unittest.TestCase):
@track_rewrites
def f(k): return graph_rewrite(sink, pm)
ret = f("test_rewrite")
if DEBUG >= 4: print_diff(sink, ret)
graphs,_,_ = reconstruct_graph(contexts[0][1][0])
assert graphs[-1].key == ret.key
self.assert_valid_ctx(contexts)
args = get_metadata(contexts)[0][0]
g = get_details(*args)
assert g.graphs[-1] == _uop_to_json(ret)
def test_devectorize_viz(self):
sink = UOp(UOps.SINK, dtypes.void, arg=KernelInfo(local_dims=1, upcasted=1, dont_use_locals=False), src=(
@@ -96,8 +94,7 @@ class TestViz(unittest.TestCase):
pm = sym+(devectorize+float4_folding)
@track_rewrites
def f(k): return graph_rewrite(sink, pm)
new_sink = f("test_rewrite")
if DEBUG >= 4: print_diff(sink, new_sink, unified=0)
f("test_rewrite")
self.assert_valid_ctx(contexts)
assert all(ctx.loc[0].split("/")[-1] == __file__.split("/")[-1] for _,ctxs in contexts for ctx in ctxs)
@@ -111,9 +108,9 @@ class TestViz(unittest.TestCase):
a = Tensor.empty(4, 4).contiguous().realize()+2
b = Tensor.empty(4, 4).contiguous().realize()+2
Tensor.schedule(a, b)
with Context(TRACK_MATCH_STATS=0): kernels = load_kernels(contexts)
with Context(TRACK_MATCH_STATS=0): kernels = get_metadata(contexts)
self.assertEqual(len(kernels), 1)
rewrites = [x[0] for x in list(kernels.values())[0]]
rewrites = [x[2] for x in kernels[0]]
assert all(len(v) == 1 for k,v in group_rewrites(rewrites).items() if "schedule.py" in k)
def test_no_dedup_different_opts(self):
@@ -122,23 +119,23 @@ class TestViz(unittest.TestCase):
s = a.schedule()
with Context(NOOPT=1): list(lower_schedule(s.copy()))
with Context(NOOPT=0): list(lower_schedule(s.copy()))
with Context(TRACK_MATCH_STATS=0): kernels = list(load_kernels(contexts).values())[1:]
with Context(TRACK_MATCH_STATS=0): kernels = get_metadata(contexts)[1:]
self.assertEqual(len(kernels), 2)
rewrites = [x[0] for x in kernels[0]]
rewrites = [x[2] for x in kernels[0]]
assert all(len(v) == 1 for _,v in group_rewrites(rewrites).items())
def test_fold_const_nodes(self):
a = Tensor.empty(4, 4)+2
contexts.clear()
sink = a.schedule()[-1].ast
ret = uop_to_json(sink)
ret = _uop_to_json(sink)
assert not any(v[0].startswith("CONST") for v in ret.values())
assert len([x for x in ret.values() if "CONST" in x[0]]) == 1
@unittest.skip("VIZ for a single CONST isn't supported anymore")
def test_no_fold_single_const(self):
node = UOp(UOps.CONST, dtypes.float, (), 1.0)
ret = uop_to_json(node, base=node)
ret = _uop_to_json(node, base=node)
assert len(ret) == 1
if __name__ == "__main__":