clean up graph dedup function [run_process_replay] (#5169)

This commit is contained in:
George Hotz 2024-06-26 15:07:34 -07:00 committed by GitHub
parent 3a04e518ec
commit 396ce6cfc9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 9 additions and 10 deletions

View File

@ -1,7 +1,7 @@
from __future__ import annotations
from typing import Iterator, Optional, Tuple, Any, Dict, List, DefaultDict, Set, Callable, Union, cast, TypeVar
import functools, itertools, heapq, math
from collections import defaultdict
from collections import defaultdict, deque
from enum import Enum, auto
from dataclasses import dataclass, field
from tinygrad.dtype import ConstType, dtypes, DType
@ -334,23 +334,22 @@ class UOpGraph:
def graph_dedup(self, sink:UOp):
# add nodes to graph in reverse BFS order
# dedup all nodes
# TODO: i feel like this BFS is written in a few places, possible to library it?
early_in_degree: Dict[UOp, int] = {}
in_degree: Dict[UOp, int] = {}
children: Dict[UOp, List[UOp]] = {}
get_children_dfs(sink, children, early_in_degree)
get_children_dfs(sink, children, in_degree)
early_queue = [k for k, v in early_in_degree.items() if v == 0]
queue = deque([k for k, v in in_degree.items() if v == 0])
replace_nodes: Dict[UOp, UOp] = {}
while len(early_queue):
n = early_queue.pop(0)
while queue:
n = queue.popleft()
if n in replace_nodes: continue
key = (n.op, n.dtype, tuple(replace_nodes.get(x, x) for x in n.src), n.arg)
if found:=self.nodes.get(key): replace_nodes[n] = found
else: replace_nodes[n] = self.nodes[key] = UOp(*key)
for x in children[n]:
early_in_degree[x] -= 1
if early_in_degree[x] == 0:
early_queue.append(x)
in_degree[x] -= 1
if in_degree[x] == 0:
queue.append(x)
return replace_nodes.get(sink, sink)
def linearize(self, extra_pm:Optional[PatternMatcher]=None, type_verify=True):