viz lowerer and graph_rewrite dedup try 2 (#6652)

This commit is contained in:
qazal 2024-09-22 21:09:46 +08:00 committed by GitHub
parent 6b65d8c461
commit d1bae42d35
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 34 additions and 11 deletions

View File

@ -378,8 +378,8 @@ def flops_mem(uops:List[UOp], ignore_indexing=False) -> Tuple[sint, sint]:
def get_location() -> Tuple[str, int]:
frm = sys._getframe(1)
# find the real frame in the file that has the UPat
while frm.f_back is not None and any(fp == frm.f_back.f_code.co_filename.split("/")[-1] for fp in {"ops.py", "uopgraph.py", "schedule.py"}):
# find the real frame in the file that has the UPat, TODO: is there a better way to do this?
while frm.f_back is not None and frm.f_back.f_code.co_filename.split("/")[-1] in {"ops.py", "uopgraph.py", "schedule.py", "lowerer.py"}:
frm = frm.f_back
return frm.f_code.co_filename, frm.f_lineno
@functools.lru_cache(None)

View File

@ -67,9 +67,9 @@ def replace_uop(base:UOp, replaces:Dict[bytes, UOp]) -> UOp:
class KernelRet:
name: str
code: str
ctxs: List[TrackedRewriteContext]
ctxs: Dict[Tuple[str, bytes], TrackedRewriteContext]
def to_json(self) -> Dict:
return {"name":self.name, "code":self.code, "ctxs":[x.loc for x in self.ctxs]}
return {"name":self.name, "code":self.code, "ctxs":[x.loc for x in self.ctxs.values()]}
def load_kernels(contexts:List[TrackedRewriteContext]) -> List[KernelRet]:
ret: Dict[str, KernelRet] = {}
@ -79,8 +79,8 @@ def load_kernels(contexts:List[TrackedRewriteContext]) -> List[KernelRet]:
if ctx.loc.split("/")[-1].split(":")[0] == "schedule.py":
with Context(TRACK_MATCH_STATS=0): kernel_name, code = (prg:=get_runner(Device.DEFAULT, full_ast_rewrite(ctx.sink)).p).name, prg.src
elif ctx.kernel_name is not None: kernel_name = ctx.kernel_name
if ret.get(k:=to_function_name(kernel_name)) is None: ret[k] = KernelRet(k, code, [])
ret[k].ctxs.append(ctx)
if ret.get(k:=to_function_name(kernel_name)) is None: ret[k] = KernelRet(k, code, {})
ret[k].ctxs[(ctx.loc, ctx.sink.key)] = ctx
return list(ret.values())
class Handler(BaseHTTPRequestHandler):
@ -112,8 +112,8 @@ class Handler(BaseHTTPRequestHandler):
with open("/tmp/rewrites.pkl", "rb") as f: contexts: List[TrackedRewriteContext] = pickle.load(f)
kernels = load_kernels(contexts)
k = kernels[int(query["kernel_idx"][0])]
g = UOpRet.from_ctx(k.ctxs[int(query["uop_idx"][0])])
ret = json.dumps((g.to_json(), [x.loc for x in k.ctxs])).encode()
g = UOpRet.from_ctx(list(k.ctxs.values())[int(query["uop_idx"][0])])
ret = json.dumps((g.to_json(), [x.loc for x in k.ctxs.values()])).encode()
else:
self.send_response(404)
ret = b""

View File

@ -6,7 +6,7 @@ from tinygrad import Tensor
from tinygrad.engine.realize import lower_schedule
from tinygrad.ops import UOp, UOps, graph_rewrite, PatternMatcher, UPat, contexts, KernelInfo, BinaryOps
from tinygrad.dtype import dtypes, PtrDType
from tinygrad.helpers import CI, all_same, DEBUG, colored, getenv
from tinygrad.helpers import CI, Context, all_same, DEBUG, colored, getenv
from tinygrad.codegen.uopgraph import constant_folder, devectorize, float4_folding
from test.external.process_replay.helpers import print_diff
from viz.serve import UOpRet, load_kernels
@ -47,8 +47,8 @@ class TestViz(unittest.TestCase):
list(lower_schedule(schedule2))
ret = load_kernels(contexts)
assert len(ret) == 2
assert all(len([x for x in y.ctxs if "schedule" in x.loc]) != 0 for y in ret)
assert all(len([x for x in y.ctxs if "uopgraph" in x.loc]) != 0 for y in ret)
assert all(len([x for x in y.ctxs.values() if "schedule" in x.loc]) != 0 for y in ret)
assert all(len([x for x in y.ctxs.values() if "uopgraph" in x.loc]) != 0 for y in ret)
def test_gemm_diff(self):
x = Tensor.empty(64, 64).realize()
@ -115,5 +115,28 @@ class TestViz(unittest.TestCase):
simple_pm.rewrite(UOp.const(dtypes.int, 2))
self.assertEqual(len(contexts), 0)
def test_dedup_ast(self):
contexts.clear()
a = Tensor.randn(4, 4)+2
b = Tensor.randn(4, 4)+2
Tensor.schedule(a, b)
kernels = load_kernels(contexts)
self.assertEqual(len(kernels), 1)
schedule_ctxs = [x for x in kernels[0].ctxs.values() if x.loc.split("/")[-1].split(":")[0] == "schedule.py"]
self.assertEqual(len(schedule_ctxs), 1)
def test_no_dedup_different_opts(self):
contexts.clear()
a = Tensor.empty(4, 4)+Tensor.empty(4, 4)
s = a.schedule()
with Context(NOOPT=1): list(lower_schedule(s.copy()))
with Context(NOOPT=0): list(lower_schedule(s.copy()))
kernels = load_kernels(contexts)
self.assertEqual(len(kernels), 2)
schedule_ctxs = [x for x in kernels[0].ctxs.values() if x.loc.split("/")[-1].split(":")[0] == "schedule.py"]
self.assertEqual(len(schedule_ctxs), 1)
schedule_ctxs = [x for x in kernels[1].ctxs.values() if x.loc.split("/")[-1].split(":")[0] == "schedule.py"]
self.assertEqual(len(schedule_ctxs), 0)
if __name__ == "__main__":
unittest.main()