diff --git a/test/test_uops.py b/test/test_uops.py index 4f04909b..16f8937c 100644 --- a/test/test_uops.py +++ b/test/test_uops.py @@ -13,7 +13,7 @@ from tinygrad.codegen.uopgraph import linearize_uop from test.helpers import is_dtype_supported, TestUOps as TestEqUOps def _uops_to_prg(uops_list): - uops = linearize_uop(uops_list, extra_pm=Device[Device.DEFAULT].renderer.extra_matcher) + uops = linearize_uop(uops_list, opts=Device[Device.DEFAULT].renderer) src = Device[Device.DEFAULT].renderer.render("test", uops) has_local = Device[Device.DEFAULT].renderer.has_local return CompiledRunner(Program("test", src, Device.DEFAULT, uops=uops, @@ -326,7 +326,7 @@ class TestAssembly(unittest.TestCase): l1 = UOp(UOps.LOAD, dtypes.int, (g1, c1)) a1 = UOp(UOps.ALU, dtypes.int, (l1, c1), BinaryOps.MUL) a2 = UOp(UOps.ALU, dtypes.int, (l1, c2), BinaryOps.MUL) - uops = linearize_uop([a1,a2], extra_pm=Device[Device.DEFAULT].renderer.extra_matcher) + uops = linearize_uop([a1,a2], opts=Device[Device.DEFAULT].renderer) Device[Device.DEFAULT].renderer.render("test", uops) self.assertEqual(uops[-1].arg, BinaryOps.SHL) self.assertEqual(uops[-2].arg, BinaryOps.MUL) @@ -338,7 +338,7 @@ class TestAssembly(unittest.TestCase): l1 = UOp(UOps.LOAD, dtypes.int, (g1, c1)) a1 = UOp(UOps.ALU, dtypes.int, (l1, c1), BinaryOps.IDIV) a2 = UOp(UOps.ALU, dtypes.int, (l1, c2), BinaryOps.IDIV) - uops = linearize_uop([a1,a2], extra_pm=Device[Device.DEFAULT].renderer.extra_matcher) + uops = linearize_uop([a1,a2], opts=Device[Device.DEFAULT].renderer) Device[Device.DEFAULT].renderer.render("test", uops) self.assertEqual(uops[-1].arg, BinaryOps.SHR) self.assertEqual(uops[-2].arg, BinaryOps.IDIV) diff --git a/tinygrad/codegen/kernel.py b/tinygrad/codegen/kernel.py index 9a63eb2f..fccde8f2 100644 --- a/tinygrad/codegen/kernel.py +++ b/tinygrad/codegen/kernel.py @@ -737,7 +737,7 @@ class Kernel: print(self.applied_opts) verify_ast(modified_ast) - self.uops:List[UOp] = linearize_uop(ast_to_uop(modified_ast, self.opts), self.opts, extra_pm=self.opts.extra_matcher) + self.uops:List[UOp] = linearize_uop(ast_to_uop(modified_ast, self.opts), self.opts) if DEBUG >= 5: print_uops(self.uops) if getenv("GRAPHUOPS"): from tinygrad.engine.graph import graph_uops diff --git a/tinygrad/codegen/uopgraph.py b/tinygrad/codegen/uopgraph.py index ad13f59b..98e12e21 100644 --- a/tinygrad/codegen/uopgraph.py +++ b/tinygrad/codegen/uopgraph.py @@ -151,6 +151,7 @@ def div_folding(x:UOp, c:int) -> Optional[UOp]: # ***** transcendental ***** +@functools.lru_cache(None) def transcendental_folding(ops): return PatternMatcher([(UPat(UOps.ALU, dtype=TRANSCENDENTAL_SUPPORTED_DTYPES, src=(UPat(name="d"),), arg=k), cast(Callable, v)) for k,v in ((UnaryOps.EXP2, xexp2), (UnaryOps.LOG2, xlog2), (UnaryOps.SIN, xsin)) if k not in ops]) @@ -493,6 +494,9 @@ reducer = PatternMatcher([ (UPat(UOps.LOAD, src=(UPat(name="buf"), UPat()), allow_any_len=True, name="load"), fix_unfoldable_image_load), ]) +no_pyint = PatternMatcher([(UPat({UOps.CONST, UOps.ALU, UOps.SPECIAL, UOps.RANGE}, dtype=dtypes.pyint, name="x"), + lambda x: UOp(x.op, dtypes.int32, x.src, x.arg))]) + # *** uop graph *** def get_children_dfs(u:UOp, children:Dict[UOp, List[UOp]], srcs:Dict[UOp, Dict[UOp, None]], in_degree:Dict[UOp, int]): @@ -518,19 +522,18 @@ def graph_rewrite(sink:UOp, pm:PatternMatcher) -> UOp: return __inner_rewrite(sink) linearize_cnt = 0 -def linearize_uop(sink_in:Union[UOp, List[UOp]], opts:Optional[Renderer]=None, extra_pm:Optional[PatternMatcher]=None, skip_check=False) -> List[UOp]: +def linearize_uop(sink_in:Union[UOp, List[UOp]], opts:Optional[Renderer]=None, skip_check=False) -> List[UOp]: global linearize_cnt, acc_number sink: UOp = sink_in if isinstance(sink_in, UOp) else UOp(UOps.SINK, None, tuple(sink_in)) assert sink.op is UOps.SINK, f"sink isn't sink, it's {sink.op}" - folder = constant_folder + transcendental_folding({} if TRANSCENDENTAL >= 2 or opts is None else opts.code_for_op.keys()) + folder = constant_folder + transcendental_folding(tuple() if TRANSCENDENTAL >= 2 or opts is None else tuple(opts.code_for_op.keys())) # do graph rewrite acc_number = 0 sink = graph_rewrite(sink, folder) # rewrite pyint to int32 - sink = graph_rewrite(sink, PatternMatcher([(UPat({UOps.CONST, UOps.ALU, UOps.SPECIAL, UOps.RANGE}, dtype=dtypes.pyint, name="x"), - lambda x: UOp(x.op, dtypes.int32, x.src, x.arg))])) + sink = graph_rewrite(sink, no_pyint) # expand linearize_cnt += 1 @@ -539,7 +542,7 @@ def linearize_uop(sink_in:Union[UOp, List[UOp]], opts:Optional[Renderer]=None, e sink = graph_rewrite(sink, folder+expander+reducer) # for PTX only - if extra_pm: sink = graph_rewrite(sink, folder+extra_pm) + if opts is not None and opts.extra_matcher is not None: sink = graph_rewrite(sink, folder+opts.extra_matcher) # filter nodes that don't link to a sink # BFS toposort