mirror of https://github.com/commaai/tinygrad.git
viz lowerer and graph_rewrite dedup try 2 (#6652)
This commit is contained in:
parent
6b65d8c461
commit
d1bae42d35
|
@ -378,8 +378,8 @@ def flops_mem(uops:List[UOp], ignore_indexing=False) -> Tuple[sint, sint]:
|
|||
|
||||
def get_location() -> Tuple[str, int]:
|
||||
frm = sys._getframe(1)
|
||||
# find the real frame in the file that has the UPat
|
||||
while frm.f_back is not None and any(fp == frm.f_back.f_code.co_filename.split("/")[-1] for fp in {"ops.py", "uopgraph.py", "schedule.py"}):
|
||||
# find the real frame in the file that has the UPat, TODO: is there a better way to do this?
|
||||
while frm.f_back is not None and frm.f_back.f_code.co_filename.split("/")[-1] in {"ops.py", "uopgraph.py", "schedule.py", "lowerer.py"}:
|
||||
frm = frm.f_back
|
||||
return frm.f_code.co_filename, frm.f_lineno
|
||||
@functools.lru_cache(None)
|
||||
|
|
12
viz/serve.py
12
viz/serve.py
|
@ -67,9 +67,9 @@ def replace_uop(base:UOp, replaces:Dict[bytes, UOp]) -> UOp:
|
|||
class KernelRet:
|
||||
name: str
|
||||
code: str
|
||||
ctxs: List[TrackedRewriteContext]
|
||||
ctxs: Dict[Tuple[str, bytes], TrackedRewriteContext]
|
||||
def to_json(self) -> Dict:
|
||||
return {"name":self.name, "code":self.code, "ctxs":[x.loc for x in self.ctxs]}
|
||||
return {"name":self.name, "code":self.code, "ctxs":[x.loc for x in self.ctxs.values()]}
|
||||
|
||||
def load_kernels(contexts:List[TrackedRewriteContext]) -> List[KernelRet]:
|
||||
ret: Dict[str, KernelRet] = {}
|
||||
|
@ -79,8 +79,8 @@ def load_kernels(contexts:List[TrackedRewriteContext]) -> List[KernelRet]:
|
|||
if ctx.loc.split("/")[-1].split(":")[0] == "schedule.py":
|
||||
with Context(TRACK_MATCH_STATS=0): kernel_name, code = (prg:=get_runner(Device.DEFAULT, full_ast_rewrite(ctx.sink)).p).name, prg.src
|
||||
elif ctx.kernel_name is not None: kernel_name = ctx.kernel_name
|
||||
if ret.get(k:=to_function_name(kernel_name)) is None: ret[k] = KernelRet(k, code, [])
|
||||
ret[k].ctxs.append(ctx)
|
||||
if ret.get(k:=to_function_name(kernel_name)) is None: ret[k] = KernelRet(k, code, {})
|
||||
ret[k].ctxs[(ctx.loc, ctx.sink.key)] = ctx
|
||||
return list(ret.values())
|
||||
|
||||
class Handler(BaseHTTPRequestHandler):
|
||||
|
@ -112,8 +112,8 @@ class Handler(BaseHTTPRequestHandler):
|
|||
with open("/tmp/rewrites.pkl", "rb") as f: contexts: List[TrackedRewriteContext] = pickle.load(f)
|
||||
kernels = load_kernels(contexts)
|
||||
k = kernels[int(query["kernel_idx"][0])]
|
||||
g = UOpRet.from_ctx(k.ctxs[int(query["uop_idx"][0])])
|
||||
ret = json.dumps((g.to_json(), [x.loc for x in k.ctxs])).encode()
|
||||
g = UOpRet.from_ctx(list(k.ctxs.values())[int(query["uop_idx"][0])])
|
||||
ret = json.dumps((g.to_json(), [x.loc for x in k.ctxs.values()])).encode()
|
||||
else:
|
||||
self.send_response(404)
|
||||
ret = b""
|
||||
|
|
|
@ -6,7 +6,7 @@ from tinygrad import Tensor
|
|||
from tinygrad.engine.realize import lower_schedule
|
||||
from tinygrad.ops import UOp, UOps, graph_rewrite, PatternMatcher, UPat, contexts, KernelInfo, BinaryOps
|
||||
from tinygrad.dtype import dtypes, PtrDType
|
||||
from tinygrad.helpers import CI, all_same, DEBUG, colored, getenv
|
||||
from tinygrad.helpers import CI, Context, all_same, DEBUG, colored, getenv
|
||||
from tinygrad.codegen.uopgraph import constant_folder, devectorize, float4_folding
|
||||
from test.external.process_replay.helpers import print_diff
|
||||
from viz.serve import UOpRet, load_kernels
|
||||
|
@ -47,8 +47,8 @@ class TestViz(unittest.TestCase):
|
|||
list(lower_schedule(schedule2))
|
||||
ret = load_kernels(contexts)
|
||||
assert len(ret) == 2
|
||||
assert all(len([x for x in y.ctxs if "schedule" in x.loc]) != 0 for y in ret)
|
||||
assert all(len([x for x in y.ctxs if "uopgraph" in x.loc]) != 0 for y in ret)
|
||||
assert all(len([x for x in y.ctxs.values() if "schedule" in x.loc]) != 0 for y in ret)
|
||||
assert all(len([x for x in y.ctxs.values() if "uopgraph" in x.loc]) != 0 for y in ret)
|
||||
|
||||
def test_gemm_diff(self):
|
||||
x = Tensor.empty(64, 64).realize()
|
||||
|
@ -115,5 +115,28 @@ class TestViz(unittest.TestCase):
|
|||
simple_pm.rewrite(UOp.const(dtypes.int, 2))
|
||||
self.assertEqual(len(contexts), 0)
|
||||
|
||||
def test_dedup_ast(self):
|
||||
contexts.clear()
|
||||
a = Tensor.randn(4, 4)+2
|
||||
b = Tensor.randn(4, 4)+2
|
||||
Tensor.schedule(a, b)
|
||||
kernels = load_kernels(contexts)
|
||||
self.assertEqual(len(kernels), 1)
|
||||
schedule_ctxs = [x for x in kernels[0].ctxs.values() if x.loc.split("/")[-1].split(":")[0] == "schedule.py"]
|
||||
self.assertEqual(len(schedule_ctxs), 1)
|
||||
|
||||
def test_no_dedup_different_opts(self):
|
||||
contexts.clear()
|
||||
a = Tensor.empty(4, 4)+Tensor.empty(4, 4)
|
||||
s = a.schedule()
|
||||
with Context(NOOPT=1): list(lower_schedule(s.copy()))
|
||||
with Context(NOOPT=0): list(lower_schedule(s.copy()))
|
||||
kernels = load_kernels(contexts)
|
||||
self.assertEqual(len(kernels), 2)
|
||||
schedule_ctxs = [x for x in kernels[0].ctxs.values() if x.loc.split("/")[-1].split(":")[0] == "schedule.py"]
|
||||
self.assertEqual(len(schedule_ctxs), 1)
|
||||
schedule_ctxs = [x for x in kernels[1].ctxs.values() if x.loc.split("/")[-1].split(":")[0] == "schedule.py"]
|
||||
self.assertEqual(len(schedule_ctxs), 0)
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
Loading…
Reference in New Issue