mirror of https://github.com/commaai/tinygrad.git
clean up graph dedup function [run_process_replay] (#5169)
This commit is contained in:
parent
3a04e518ec
commit
396ce6cfc9
|
@ -1,7 +1,7 @@
|
|||
from __future__ import annotations
|
||||
from typing import Iterator, Optional, Tuple, Any, Dict, List, DefaultDict, Set, Callable, Union, cast, TypeVar
|
||||
import functools, itertools, heapq, math
|
||||
from collections import defaultdict
|
||||
from collections import defaultdict, deque
|
||||
from enum import Enum, auto
|
||||
from dataclasses import dataclass, field
|
||||
from tinygrad.dtype import ConstType, dtypes, DType
|
||||
|
@ -334,23 +334,22 @@ class UOpGraph:
|
|||
def graph_dedup(self, sink:UOp):
|
||||
# add nodes to graph in reverse BFS order
|
||||
# dedup all nodes
|
||||
# TODO: i feel like this BFS is written in a few places, possible to library it?
|
||||
early_in_degree: Dict[UOp, int] = {}
|
||||
in_degree: Dict[UOp, int] = {}
|
||||
children: Dict[UOp, List[UOp]] = {}
|
||||
get_children_dfs(sink, children, early_in_degree)
|
||||
get_children_dfs(sink, children, in_degree)
|
||||
|
||||
early_queue = [k for k, v in early_in_degree.items() if v == 0]
|
||||
queue = deque([k for k, v in in_degree.items() if v == 0])
|
||||
replace_nodes: Dict[UOp, UOp] = {}
|
||||
while len(early_queue):
|
||||
n = early_queue.pop(0)
|
||||
while queue:
|
||||
n = queue.popleft()
|
||||
if n in replace_nodes: continue
|
||||
key = (n.op, n.dtype, tuple(replace_nodes.get(x, x) for x in n.src), n.arg)
|
||||
if found:=self.nodes.get(key): replace_nodes[n] = found
|
||||
else: replace_nodes[n] = self.nodes[key] = UOp(*key)
|
||||
for x in children[n]:
|
||||
early_in_degree[x] -= 1
|
||||
if early_in_degree[x] == 0:
|
||||
early_queue.append(x)
|
||||
in_degree[x] -= 1
|
||||
if in_degree[x] == 0:
|
||||
queue.append(x)
|
||||
return replace_nodes.get(sink, sink)
|
||||
|
||||
def linearize(self, extra_pm:Optional[PatternMatcher]=None, type_verify=True):
|
||||
|
|
Loading…
Reference in New Issue