mirror of https://github.com/commaai/tinygrad.git
viz late to_json [pr] (#7070)
This commit is contained in:
parent
52d8afde2b
commit
1a45e94f5d
|
@ -3,7 +3,7 @@ import unittest
|
|||
from tinygrad.dtype import PtrDType, dtypes
|
||||
from tinygrad.ops import TRACK_MATCH_STATS, BinaryOps, TrackedPatternMatcher as PatternMatcher, UOp, UOps, UPat, \
|
||||
graph_rewrite, contexts, track_rewrites
|
||||
from tinygrad.viz.serve import _replace_uop, get_details, get_metadata
|
||||
from tinygrad.viz.serve import get_details, get_metadata, uop_to_json
|
||||
|
||||
@track_rewrites
|
||||
def rewrite(sink:UOp, pm:PatternMatcher, ctx=None): return graph_rewrite(sink, pm, ctx)
|
||||
|
@ -12,14 +12,9 @@ def helper_test_viz(sink:UOp, pm:PatternMatcher, ctx=None) -> List[UOp]:
|
|||
rewrite(sink, pm, ctx)
|
||||
assert len(contexts) == 1
|
||||
assert len(contexts[0][1]) == 1
|
||||
ctx = contexts[0][1][0]
|
||||
uops = [ctx.sink]
|
||||
replaces: Dict[UOp, UOp] = {}
|
||||
for u0,u1,_ in ctx.rewrites:
|
||||
replaces[u0] = u1
|
||||
new_sink = _replace_uop(uops[-1], {**replaces})
|
||||
uops.append(new_sink)
|
||||
return uops[1:]
|
||||
k = get_metadata(contexts)[0][0]
|
||||
g = get_details(*k)
|
||||
return g.graphs[1:]
|
||||
|
||||
class TestViz(unittest.TestCase):
|
||||
def setUp(self):
|
||||
|
@ -91,12 +86,8 @@ class TestViz(unittest.TestCase):
|
|||
self.assertEqual(len(ret), 1)
|
||||
|
||||
def test_fold_const(self):
|
||||
pm = PatternMatcher([
|
||||
(UPat.var("x")*1, lambda x:x),
|
||||
])
|
||||
a = UOp(UOps.LOAD, dtypes.int, (UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), 0), UOp.const(dtypes.int, 0)))
|
||||
rewrite(a, pm)
|
||||
graph = get_details(*get_metadata(contexts)[0][0]).graphs[-1]
|
||||
graph = uop_to_json(a)
|
||||
assert not any(v[0].startswith("CONST") for v in graph.values())
|
||||
assert len([x for x in graph.values() if "CONST" in x[0]]) == 1
|
||||
|
||||
|
|
|
@ -30,7 +30,7 @@ class GraphRewriteMetadata:
|
|||
@dataclass
|
||||
class GraphRewriteDetails(GraphRewriteMetadata):
|
||||
"""Full details about a single call to graph_rewrite"""
|
||||
graphs: List[Dict[int, Tuple[str, str, List[int], str, str]]]
|
||||
graphs: List[UOp]
|
||||
"""Sink at every step of graph_rewrite"""
|
||||
diffs: List[List[str]]
|
||||
""".diff style before and after of the rewritten UOp child"""
|
||||
|
@ -52,7 +52,7 @@ def get_metadata(contexts:List[Tuple[Any, List[TrackedRewriteContext]]]) -> List
|
|||
kernels[name].append((k, ctx, GraphRewriteMetadata(ctx.loc, lines(ctx.loc[0])[ctx.loc[1]-1].strip(), name, upats)))
|
||||
return list(kernels.values())
|
||||
|
||||
def _uop_to_json(x:UOp) -> Dict[int, Tuple[str, str, List[int], str, str]]:
|
||||
def uop_to_json(x:UOp) -> Dict[int, Tuple[str, str, List[int], str, str]]:
|
||||
assert isinstance(x, UOp)
|
||||
graph: Dict[int, Tuple[str, str, List[int], str, str]] = {}
|
||||
for u in x.sparents:
|
||||
|
@ -69,7 +69,7 @@ def _replace_uop(base:UOp, replaces:Dict[UOp, UOp]) -> UOp:
|
|||
@functools.lru_cache(None)
|
||||
def _prg(k:Optional[Kernel]) -> Optional[str]: return k.to_program().src if isinstance(k, Kernel) else None
|
||||
def get_details(k:Any, ctx:TrackedRewriteContext, metadata:GraphRewriteMetadata) -> GraphRewriteDetails:
|
||||
g = GraphRewriteDetails(**asdict(metadata), graphs=[_uop_to_json(ctx.sink)], diffs=[], changed_nodes=[], kernel_code=_prg(k))
|
||||
g = GraphRewriteDetails(**asdict(metadata), graphs=[ctx.sink], diffs=[], changed_nodes=[], kernel_code=_prg(k))
|
||||
replaces: Dict[UOp, UOp] = {}
|
||||
sink = ctx.sink
|
||||
for i,(u0,u1,upat) in enumerate(ctx.rewrites):
|
||||
|
@ -82,7 +82,7 @@ def get_details(k:Any, ctx:TrackedRewriteContext, metadata:GraphRewriteMetadata)
|
|||
# update ret data
|
||||
g.changed_nodes.append([id(x) for x in u1.sparents if x.op is not UOps.CONST])
|
||||
g.diffs.append(list(difflib.unified_diff(str(u0).splitlines(), str(u1).splitlines())))
|
||||
g.graphs.append(_uop_to_json(sink:=new_sink))
|
||||
g.graphs.append(sink:=new_sink)
|
||||
return g
|
||||
|
||||
# ** HTTP server
|
||||
|
@ -101,7 +101,8 @@ class Handler(BaseHTTPRequestHandler):
|
|||
self.end_headers()
|
||||
query = parse_qs(url.query)
|
||||
if (qkernel:=query.get("kernel")) is not None:
|
||||
ret = json.dumps(asdict(get_details(*kernels[int(qkernel[0])][int(query["idx"][0])]))).encode()
|
||||
g = get_details(*kernels[int(qkernel[0])][int(query["idx"][0])])
|
||||
ret = json.dumps({**asdict(g), "graphs": list(map(uop_to_json, g.graphs))}).encode()
|
||||
else: ret = json.dumps([list(map(lambda x:asdict(x[2]), v)) for v in kernels]).encode()
|
||||
else:
|
||||
self.send_response(404)
|
||||
|
|
Loading…
Reference in New Issue