graph_dedup function [run_process_replay] (#4955)

This commit is contained in:
George Hotz 2024-06-14 04:24:37 -07:00 committed by GitHub
parent 63a8add2c2
commit 14189bca68
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 12 additions and 8 deletions

View File

@ -17,7 +17,8 @@ if __name__ == "__main__":
with Timing("***** model schedule in "):
sched = out.schedule()
with Profiling(PROFILE):
# snakeviz /tmp/schedule.prof
with Profiling(PROFILE, fn="/tmp/schedule.prof"):
with Timing("***** model lower in "):
eis = list(lower_schedule(sched))

View File

@ -307,15 +307,10 @@ class UOpGraph:
assert run_cnt < 100, "exceeded 100 rewrite loops!"
return sink
def linearize(self, extra_pm:Optional[PatternMatcher]=None, type_verify=True):
# NOTE: relinearizering should be okay
#assert self._uops is None, "already linearized"
self.nodes: Dict[Tuple, UOp] = {}
def graph_dedup(self, sink):
# add nodes to graph in reverse BFS order
# dedup all nodes
# TODO: i feel like this BFS is written in a few places, possible to library it?
sink = UOp(UOps.SINK, None, tuple(self.sinks))
unprocessed_nodes = [sink]
early_in_degree: DefaultDict[UOp, int] = defaultdict(int)
children: DefaultDict[UOp, List[UOp]] = defaultdict(list)
@ -340,7 +335,15 @@ class UOpGraph:
early_in_degree[x] -= 1
if early_in_degree[x] == 0:
early_queue.append(x)
sink = replace_nodes.get(sink, sink)
return replace_nodes.get(sink, sink)
def linearize(self, extra_pm:Optional[PatternMatcher]=None, type_verify=True):
# NOTE: relinearizering should be okay
#assert self._uops is None, "already linearized"
self.nodes: Dict[Tuple, UOp] = {}
# dedup all nodes in graph
sink = self.graph_dedup(UOp(UOps.SINK, None, tuple(self.sinks)))
# do graph rewrite
sink = self.graph_rewrite(sink, constant_folder)