move graph_dedup out of class [run_process_replay] (#5197)

This commit is contained in:
George Hotz 2024-06-27 12:04:00 -07:00 committed by GitHub
parent d094a6828f
commit 345bcc2099
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 31 additions and 24 deletions

View File

@ -3,7 +3,7 @@ from test.helpers import TestUOps
from tinygrad import dtypes, Variable
from tinygrad.dtype import PtrDType
from tinygrad.ops import BinaryOps, TernaryOps, UnaryOps
from tinygrad.codegen.uops import UOpGraph, UOps, UOp, PatternMatcher, graph_rewrite
from tinygrad.codegen.uops import UOpGraph, UOps, UOp, PatternMatcher, graph_rewrite, graph_dedup
#from tinygrad.engine.graph import print_tree
simple_pm = PatternMatcher([
@ -14,6 +14,12 @@ simple_pm = PatternMatcher([
])
class TestGraphRewrite(unittest.TestCase):
def test_dedup(self):
v1 = UOp(UOps.DEFINE_VAR, dtypes.float)
v2 = UOp(UOps.DEFINE_VAR, dtypes.float)
nout = graph_dedup(v1+v2)
self.assertIs(nout.src[0], nout.src[1])
def test_simple(self):
c1 = UOp.const(dtypes.float, 1.0)
c2 = UOp.const(dtypes.float, 2.0)

View File

@ -296,6 +296,28 @@ def graph_rewrite(sink:UOp, pm:PatternMatcher) -> UOp:
return __inner_rewrite(new_n) if (new_n := pm.rewrite(n)) else n
return __inner_rewrite(sink)
def graph_dedup(sink:UOp):
# add nodes to graph in reverse BFS order
# dedup all nodes
in_degree: Dict[UOp, int] = {}
children: Dict[UOp, List[UOp]] = {}
get_children_dfs(sink, children, in_degree)
nodes: Dict[Tuple, UOp] = {}
queue = deque([k for k, v in in_degree.items() if v == 0])
replace_nodes: Dict[UOp, UOp] = {}
while queue:
n = queue.popleft()
if n in replace_nodes: continue
key = (n.op, n.dtype, tuple(replace_nodes.get(x, x) for x in n.src), n.arg)
if found:=nodes.get(key): replace_nodes[n] = found
else: replace_nodes[n] = nodes[key] = UOp(*key)
for x in children[n]:
in_degree[x] -= 1
if in_degree[x] == 0:
queue.append(x)
return replace_nodes.get(sink, sink)
class UOpGraph:
def __init__(self, sinks:List[UOp]):
self.sinks: List[UOp] = sinks
@ -321,34 +343,13 @@ class UOpGraph:
for i,u in enumerate(self):
print(f"{i:4d} {str(u.op):20s}: {str(u.dtype) if u.dtype is not None else '':25s} " f"{str([self.uops.index(x) for x in u.src]):32s} {u.arg}")
def graph_dedup(self, sink:UOp):
# add nodes to graph in reverse BFS order
# dedup all nodes
in_degree: Dict[UOp, int] = {}
children: Dict[UOp, List[UOp]] = {}
get_children_dfs(sink, children, in_degree)
queue = deque([k for k, v in in_degree.items() if v == 0])
replace_nodes: Dict[UOp, UOp] = {}
while queue:
n = queue.popleft()
if n in replace_nodes: continue
key = (n.op, n.dtype, tuple(replace_nodes.get(x, x) for x in n.src), n.arg)
if found:=self.nodes.get(key): replace_nodes[n] = found
else: replace_nodes[n] = self.nodes[key] = UOp(*key)
for x in children[n]:
in_degree[x] -= 1
if in_degree[x] == 0:
queue.append(x)
return replace_nodes.get(sink, sink)
def linearize(self, extra_pm:Optional[PatternMatcher]=None, type_verify=True):
# NOTE: relinearizering should be okay
#assert self._uops is None, "already linearized"
self.nodes: Dict[Tuple, UOp] = {}
sink = UOp(UOps.SINK, None, tuple(self.sinks))
# dedup all nodes in graph
sink = self.graph_dedup(UOp(UOps.SINK, None, tuple(self.sinks)))
sink = graph_dedup(sink)
# do graph rewrite
sink = graph_rewrite(sink, constant_folder)