mirror of https://github.com/commaai/tinygrad.git
move graph_dedup out of class [run_process_replay] (#5197)
This commit is contained in:
parent
d094a6828f
commit
345bcc2099
|
@ -3,7 +3,7 @@ from test.helpers import TestUOps
|
|||
from tinygrad import dtypes, Variable
|
||||
from tinygrad.dtype import PtrDType
|
||||
from tinygrad.ops import BinaryOps, TernaryOps, UnaryOps
|
||||
from tinygrad.codegen.uops import UOpGraph, UOps, UOp, PatternMatcher, graph_rewrite
|
||||
from tinygrad.codegen.uops import UOpGraph, UOps, UOp, PatternMatcher, graph_rewrite, graph_dedup
|
||||
#from tinygrad.engine.graph import print_tree
|
||||
|
||||
simple_pm = PatternMatcher([
|
||||
|
@ -14,6 +14,12 @@ simple_pm = PatternMatcher([
|
|||
])
|
||||
|
||||
class TestGraphRewrite(unittest.TestCase):
|
||||
def test_dedup(self):
|
||||
v1 = UOp(UOps.DEFINE_VAR, dtypes.float)
|
||||
v2 = UOp(UOps.DEFINE_VAR, dtypes.float)
|
||||
nout = graph_dedup(v1+v2)
|
||||
self.assertIs(nout.src[0], nout.src[1])
|
||||
|
||||
def test_simple(self):
|
||||
c1 = UOp.const(dtypes.float, 1.0)
|
||||
c2 = UOp.const(dtypes.float, 2.0)
|
||||
|
|
|
@ -296,6 +296,28 @@ def graph_rewrite(sink:UOp, pm:PatternMatcher) -> UOp:
|
|||
return __inner_rewrite(new_n) if (new_n := pm.rewrite(n)) else n
|
||||
return __inner_rewrite(sink)
|
||||
|
||||
def graph_dedup(sink:UOp):
|
||||
# add nodes to graph in reverse BFS order
|
||||
# dedup all nodes
|
||||
in_degree: Dict[UOp, int] = {}
|
||||
children: Dict[UOp, List[UOp]] = {}
|
||||
get_children_dfs(sink, children, in_degree)
|
||||
|
||||
nodes: Dict[Tuple, UOp] = {}
|
||||
queue = deque([k for k, v in in_degree.items() if v == 0])
|
||||
replace_nodes: Dict[UOp, UOp] = {}
|
||||
while queue:
|
||||
n = queue.popleft()
|
||||
if n in replace_nodes: continue
|
||||
key = (n.op, n.dtype, tuple(replace_nodes.get(x, x) for x in n.src), n.arg)
|
||||
if found:=nodes.get(key): replace_nodes[n] = found
|
||||
else: replace_nodes[n] = nodes[key] = UOp(*key)
|
||||
for x in children[n]:
|
||||
in_degree[x] -= 1
|
||||
if in_degree[x] == 0:
|
||||
queue.append(x)
|
||||
return replace_nodes.get(sink, sink)
|
||||
|
||||
class UOpGraph:
|
||||
def __init__(self, sinks:List[UOp]):
|
||||
self.sinks: List[UOp] = sinks
|
||||
|
@ -321,34 +343,13 @@ class UOpGraph:
|
|||
for i,u in enumerate(self):
|
||||
print(f"{i:4d} {str(u.op):20s}: {str(u.dtype) if u.dtype is not None else '':25s} " f"{str([self.uops.index(x) for x in u.src]):32s} {u.arg}")
|
||||
|
||||
def graph_dedup(self, sink:UOp):
|
||||
# add nodes to graph in reverse BFS order
|
||||
# dedup all nodes
|
||||
in_degree: Dict[UOp, int] = {}
|
||||
children: Dict[UOp, List[UOp]] = {}
|
||||
get_children_dfs(sink, children, in_degree)
|
||||
|
||||
queue = deque([k for k, v in in_degree.items() if v == 0])
|
||||
replace_nodes: Dict[UOp, UOp] = {}
|
||||
while queue:
|
||||
n = queue.popleft()
|
||||
if n in replace_nodes: continue
|
||||
key = (n.op, n.dtype, tuple(replace_nodes.get(x, x) for x in n.src), n.arg)
|
||||
if found:=self.nodes.get(key): replace_nodes[n] = found
|
||||
else: replace_nodes[n] = self.nodes[key] = UOp(*key)
|
||||
for x in children[n]:
|
||||
in_degree[x] -= 1
|
||||
if in_degree[x] == 0:
|
||||
queue.append(x)
|
||||
return replace_nodes.get(sink, sink)
|
||||
|
||||
def linearize(self, extra_pm:Optional[PatternMatcher]=None, type_verify=True):
|
||||
# NOTE: relinearizering should be okay
|
||||
#assert self._uops is None, "already linearized"
|
||||
self.nodes: Dict[Tuple, UOp] = {}
|
||||
sink = UOp(UOps.SINK, None, tuple(self.sinks))
|
||||
|
||||
# dedup all nodes in graph
|
||||
sink = self.graph_dedup(UOp(UOps.SINK, None, tuple(self.sinks)))
|
||||
sink = graph_dedup(sink)
|
||||
|
||||
# do graph rewrite
|
||||
sink = graph_rewrite(sink, constant_folder)
|
||||
|
|
Loading…
Reference in New Issue