extra matcher from renderer [run_process_replay] (#6130)

* extra matcher from renderer

* cache_pm [run_process_replay]
This commit is contained in:
George Hotz 2024-08-16 23:53:11 -07:00 committed by GitHub
parent 9bc81c6db4
commit 3a2d724cb2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 12 additions and 9 deletions

View File

@ -13,7 +13,7 @@ from tinygrad.codegen.uopgraph import linearize_uop
from test.helpers import is_dtype_supported, TestUOps as TestEqUOps
def _uops_to_prg(uops_list):
uops = linearize_uop(uops_list, extra_pm=Device[Device.DEFAULT].renderer.extra_matcher)
uops = linearize_uop(uops_list, opts=Device[Device.DEFAULT].renderer)
src = Device[Device.DEFAULT].renderer.render("test", uops)
has_local = Device[Device.DEFAULT].renderer.has_local
return CompiledRunner(Program("test", src, Device.DEFAULT, uops=uops,
@ -326,7 +326,7 @@ class TestAssembly(unittest.TestCase):
l1 = UOp(UOps.LOAD, dtypes.int, (g1, c1))
a1 = UOp(UOps.ALU, dtypes.int, (l1, c1), BinaryOps.MUL)
a2 = UOp(UOps.ALU, dtypes.int, (l1, c2), BinaryOps.MUL)
uops = linearize_uop([a1,a2], extra_pm=Device[Device.DEFAULT].renderer.extra_matcher)
uops = linearize_uop([a1,a2], opts=Device[Device.DEFAULT].renderer)
Device[Device.DEFAULT].renderer.render("test", uops)
self.assertEqual(uops[-1].arg, BinaryOps.SHL)
self.assertEqual(uops[-2].arg, BinaryOps.MUL)
@ -338,7 +338,7 @@ class TestAssembly(unittest.TestCase):
l1 = UOp(UOps.LOAD, dtypes.int, (g1, c1))
a1 = UOp(UOps.ALU, dtypes.int, (l1, c1), BinaryOps.IDIV)
a2 = UOp(UOps.ALU, dtypes.int, (l1, c2), BinaryOps.IDIV)
uops = linearize_uop([a1,a2], extra_pm=Device[Device.DEFAULT].renderer.extra_matcher)
uops = linearize_uop([a1,a2], opts=Device[Device.DEFAULT].renderer)
Device[Device.DEFAULT].renderer.render("test", uops)
self.assertEqual(uops[-1].arg, BinaryOps.SHR)
self.assertEqual(uops[-2].arg, BinaryOps.IDIV)

View File

@ -737,7 +737,7 @@ class Kernel:
print(self.applied_opts)
verify_ast(modified_ast)
self.uops:List[UOp] = linearize_uop(ast_to_uop(modified_ast, self.opts), self.opts, extra_pm=self.opts.extra_matcher)
self.uops:List[UOp] = linearize_uop(ast_to_uop(modified_ast, self.opts), self.opts)
if DEBUG >= 5: print_uops(self.uops)
if getenv("GRAPHUOPS"):
from tinygrad.engine.graph import graph_uops

View File

@ -151,6 +151,7 @@ def div_folding(x:UOp, c:int) -> Optional[UOp]:
# ***** transcendental *****
@functools.lru_cache(None)
def transcendental_folding(ops):
  """Build the PatternMatcher holding the transcendental rewrite rules.

  One rule is created for each of EXP2, LOG2 and SIN whose op is NOT contained
  in `ops`; each rule matches a single-source ALU uop (source named "d") with a
  dtype in TRANSCENDENTAL_SUPPORTED_DTYPES and maps it to the matching x*
  rewrite function.

  The result is memoized with lru_cache, so `ops` has to be hashable.
  """
  candidates = ((UnaryOps.EXP2, xexp2), (UnaryOps.LOG2, xlog2), (UnaryOps.SIN, xsin))
  rules = []
  for alu_op, rewrite_fn in candidates:
    # ops listed by the caller are left alone: no rewrite rule is emitted for them
    if alu_op in ops: continue
    pattern = UPat(UOps.ALU, dtype=TRANSCENDENTAL_SUPPORTED_DTYPES, src=(UPat(name="d"),), arg=alu_op)
    rules.append((pattern, cast(Callable, rewrite_fn)))
  return PatternMatcher(rules)
@ -493,6 +494,9 @@ reducer = PatternMatcher([
(UPat(UOps.LOAD, src=(UPat(name="buf"), UPat()), allow_any_len=True, name="load"), fix_unfoldable_image_load),
])
# Rewrites pyint to int32: any CONST, ALU, SPECIAL or RANGE uop whose dtype is
# dtypes.pyint is replaced by a uop with the same op, src and arg but dtype
# dtypes.int32. Built once at module level so it is not reconstructed per call.
no_pyint = PatternMatcher([(UPat({UOps.CONST, UOps.ALU, UOps.SPECIAL, UOps.RANGE}, dtype=dtypes.pyint, name="x"),
lambda x: UOp(x.op, dtypes.int32, x.src, x.arg))])
# *** uop graph ***
def get_children_dfs(u:UOp, children:Dict[UOp, List[UOp]], srcs:Dict[UOp, Dict[UOp, None]], in_degree:Dict[UOp, int]):
@ -518,19 +522,18 @@ def graph_rewrite(sink:UOp, pm:PatternMatcher) -> UOp:
return __inner_rewrite(sink)
# module-level counter; declared global and incremented inside linearize_uop
linearize_cnt = 0
def linearize_uop(sink_in:Union[UOp, List[UOp]], opts:Optional[Renderer]=None, extra_pm:Optional[PatternMatcher]=None, skip_check=False) -> List[UOp]:
def linearize_uop(sink_in:Union[UOp, List[UOp]], opts:Optional[Renderer]=None, skip_check=False) -> List[UOp]:
global linearize_cnt, acc_number
sink: UOp = sink_in if isinstance(sink_in, UOp) else UOp(UOps.SINK, None, tuple(sink_in))
assert sink.op is UOps.SINK, f"sink isn't sink, it's {sink.op}"
folder = constant_folder + transcendental_folding({} if TRANSCENDENTAL >= 2 or opts is None else opts.code_for_op.keys())
folder = constant_folder + transcendental_folding(tuple() if TRANSCENDENTAL >= 2 or opts is None else tuple(opts.code_for_op.keys()))
# do graph rewrite
acc_number = 0
sink = graph_rewrite(sink, folder)
# rewrite pyint to int32
sink = graph_rewrite(sink, PatternMatcher([(UPat({UOps.CONST, UOps.ALU, UOps.SPECIAL, UOps.RANGE}, dtype=dtypes.pyint, name="x"),
lambda x: UOp(x.op, dtypes.int32, x.src, x.arg))]))
sink = graph_rewrite(sink, no_pyint)
# expand
linearize_cnt += 1
@ -539,7 +542,7 @@ def linearize_uop(sink_in:Union[UOp, List[UOp]], opts:Optional[Renderer]=None, e
sink = graph_rewrite(sink, folder+expander+reducer)
# for PTX only
if extra_pm: sink = graph_rewrite(sink, folder+extra_pm)
if opts is not None and opts.extra_matcher is not None: sink = graph_rewrite(sink, folder+opts.extra_matcher)
# filter nodes that don't link to a sink
# BFS toposort