diff --git a/test/test_uop_graph.py b/test/test_uop_graph.py index ec335740..36ef139e 100644 --- a/test/test_uop_graph.py +++ b/test/test_uop_graph.py @@ -82,7 +82,7 @@ class TestGraphRewrite(unittest.TestCase): b = UOp(UOps.DEFINE_VAR, dtypes.int, (UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)), arg=Variable('b', 0, 1)) c = UOp(UOps.DEFINE_VAR, dtypes.int, (UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)), arg=Variable('c', 0, 1)) d = UOp(UOps.DEFINE_VAR, dtypes.int, (UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)), arg=Variable('d', 0, 1)) - outs = [2+a, 2+a+d+3+b+c+4, UOp(UOps.ALU, a.dtype, src=(a.const(2), a), arg=BinaryOps.ADD)] + outs = [2+a, 2+a+d+3+b+c+4, UOp(UOps.ALU, a.dtype, src=(a.const(2), a), arg=BinaryOps.ADD), (4+d)+c+(2+a)+b] for out in outs: sink = graph_rewrite(out, constant_folder) print(sink) diff --git a/tinygrad/codegen/uopgraph.py b/tinygrad/codegen/uopgraph.py index 4aeaf8a0..f17e7057 100644 --- a/tinygrad/codegen/uopgraph.py +++ b/tinygrad/codegen/uopgraph.py @@ -5,7 +5,7 @@ from collections import defaultdict from tinygrad.dtype import dtypes, PtrDType, ImageDType, DType from tinygrad.ops import UnaryOps, BinaryOps, exec_alu from tinygrad.helpers import DEBUG, getenv, flatten, dedup, TRANSCENDENTAL, prod, CI, all_same, partition -from tinygrad.codegen.uops import UOp, NOp, UOps, UPat, PatternMatcher, END_FOR_UOP, type_verify +from tinygrad.codegen.uops import UOp, NOp, UOps, UPat, PatternMatcher, END_FOR_UOP, type_verify, print_uops from tinygrad.codegen.transcendental import xexp2, xlog2, xsin, TRANSCENDENTAL_SUPPORTED_DTYPES if TYPE_CHECKING: from tinygrad.renderer import Renderer @@ -310,7 +310,7 @@ constant_folder = PatternMatcher([ lambda root: UOp(UOps.SINK, root.dtype, a, root.arg) if len(a:=tuple(x for x in root.src if x.op is not UOps.NOOP)) != len(root.src) else None), # ** move add consts to end (NOTE: this is still happening before constant folding) ** (UPat(UOps.ALU, BinaryOps.ADD, src=(UPat(UOps.CONST, name='c1'), UPat(name='x'))), lambda c1,x: x+c1 if x.op is not UOps.CONST else None), - (UPat(UOps.ALU, BinaryOps.ADD, src=(UPat(UOps.ALU, BinaryOps.ADD, src=(UPat(name='x'), UPat(UOps.CONST, name='c1'))), UPat(name='y'))), + (UPat(UOps.ALU, BinaryOps.ADD, src=[UPat(UOps.ALU, BinaryOps.ADD, src=(UPat(name='x'), UPat(UOps.CONST, name='c1'))), UPat(name='y')]), lambda x,c1,y: (x+y)+c1), ]) @@ -485,10 +485,7 @@ class UOpGraph: from tinygrad.engine.graph import graph_uops graph_uops(self.uops) - def print(self): - for i,u in enumerate(self): - formatted_parents = [self.uops.index(x) if x.op is not UOps.CONST else f"{x.arg}" for x in u.src] - print(f"{i:4d} {str(u.op):20s}: {str(u.dtype) if u.dtype is not None else '':25s} " f"{str(formatted_parents):32s} {u.arg}") + def print(self): print_uops(self.uops) cnt = 0 def linearize(self, extra_pm:Optional[PatternMatcher]=None, skip_check=False) -> UOpGraph: diff --git a/tinygrad/codegen/uops.py b/tinygrad/codegen/uops.py index ed392a78..6f72f461 100644 --- a/tinygrad/codegen/uops.py +++ b/tinygrad/codegen/uops.py @@ -237,6 +237,11 @@ def uop_alu_resolve(u:UOp) -> sint: if u.op is UOps.ALU: return exec_alu(u.arg, cast(DType,u.dtype), tuple(map(uop_alu_resolve, u.src))) raise RuntimeError(f"ALU resolve fail @ {u.op}") +def print_uops(uops:List[UOp]): + for i,u in enumerate(uops): + formatted_parents = [uops.index(x) if x.op is not UOps.CONST else f"{x.arg}" for x in u.src] + print(f"{i:4d} {str(u.op):20s}: {str(u.dtype) if u.dtype is not None else '':25s} " f"{str(formatted_parents):32s} {u.arg}") + def flops_mem(uops:List[UOp], ignore_indexing=False) -> Tuple[sint, sint]: flops: sint = 0 mem: sint = 0