diff --git a/test/test_uop_graph.py b/test/test_uop_graph.py index d3dbe110..5f8ff84a 100644 --- a/test/test_uop_graph.py +++ b/test/test_uop_graph.py @@ -3,7 +3,7 @@ from test.helpers import TestUOps from tinygrad import dtypes, Variable from tinygrad.dtype import PtrDType from tinygrad.ops import BinaryOps, TernaryOps, UnaryOps -from tinygrad.codegen.uops import UOpGraph, UOps, UOp, PatternMatcher, graph_rewrite +from tinygrad.codegen.uops import UOpGraph, UOps, UOp, PatternMatcher, graph_rewrite, graph_dedup #from tinygrad.engine.graph import print_tree simple_pm = PatternMatcher([ @@ -14,6 +14,12 @@ simple_pm = PatternMatcher([ ]) class TestGraphRewrite(unittest.TestCase): + def test_dedup(self): + v1 = UOp(UOps.DEFINE_VAR, dtypes.float) + v2 = UOp(UOps.DEFINE_VAR, dtypes.float) + nout = graph_dedup(v1+v2) + self.assertIs(nout.src[0], nout.src[1]) + def test_simple(self): c1 = UOp.const(dtypes.float, 1.0) c2 = UOp.const(dtypes.float, 2.0) diff --git a/tinygrad/codegen/uops.py b/tinygrad/codegen/uops.py index 390583c5..383c2afe 100644 --- a/tinygrad/codegen/uops.py +++ b/tinygrad/codegen/uops.py @@ -296,6 +296,28 @@ def graph_rewrite(sink:UOp, pm:PatternMatcher) -> UOp: return __inner_rewrite(new_n) if (new_n := pm.rewrite(n)) else n return __inner_rewrite(sink) +def graph_dedup(sink:UOp): + # add nodes to graph in reverse BFS order + # dedup all nodes + in_degree: Dict[UOp, int] = {} + children: Dict[UOp, List[UOp]] = {} + get_children_dfs(sink, children, in_degree) + + nodes: Dict[Tuple, UOp] = {} + queue = deque([k for k, v in in_degree.items() if v == 0]) + replace_nodes: Dict[UOp, UOp] = {} + while queue: + n = queue.popleft() + if n in replace_nodes: continue + key = (n.op, n.dtype, tuple(replace_nodes.get(x, x) for x in n.src), n.arg) + if found:=nodes.get(key): replace_nodes[n] = found + else: replace_nodes[n] = nodes[key] = UOp(*key) + for x in children[n]: + in_degree[x] -= 1 + if in_degree[x] == 0: + queue.append(x) + return replace_nodes.get(sink, sink) + class UOpGraph: def __init__(self, sinks:List[UOp]): self.sinks: List[UOp] = sinks @@ -321,34 +343,13 @@ class UOpGraph: for i,u in enumerate(self): print(f"{i:4d} {str(u.op):20s}: {str(u.dtype) if u.dtype is not None else '':25s} " f"{str([self.uops.index(x) for x in u.src]):32s} {u.arg}") - def graph_dedup(self, sink:UOp): - # add nodes to graph in reverse BFS order - # dedup all nodes - in_degree: Dict[UOp, int] = {} - children: Dict[UOp, List[UOp]] = {} - get_children_dfs(sink, children, in_degree) - - queue = deque([k for k, v in in_degree.items() if v == 0]) - replace_nodes: Dict[UOp, UOp] = {} - while queue: - n = queue.popleft() - if n in replace_nodes: continue - key = (n.op, n.dtype, tuple(replace_nodes.get(x, x) for x in n.src), n.arg) - if found:=self.nodes.get(key): replace_nodes[n] = found - else: replace_nodes[n] = self.nodes[key] = UOp(*key) - for x in children[n]: - in_degree[x] -= 1 - if in_degree[x] == 0: - queue.append(x) - return replace_nodes.get(sink, sink) - def linearize(self, extra_pm:Optional[PatternMatcher]=None, type_verify=True): # NOTE: relinearizering should be okay #assert self._uops is None, "already linearized" - self.nodes: Dict[Tuple, UOp] = {} + sink = UOp(UOps.SINK, None, tuple(self.sinks)) # dedup all nodes in graph - sink = self.graph_dedup(UOp(UOps.SINK, None, tuple(self.sinks))) + sink = graph_dedup(sink) # do graph rewrite sink = graph_rewrite(sink, constant_folder)