mirror of https://github.com/commaai/tinygrad.git
clean up graph dedup function [run_process_replay] (#5169)
This commit is contained in:
parent
3a04e518ec
commit
396ce6cfc9
|
@ -1,7 +1,7 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
from typing import Iterator, Optional, Tuple, Any, Dict, List, DefaultDict, Set, Callable, Union, cast, TypeVar
|
from typing import Iterator, Optional, Tuple, Any, Dict, List, DefaultDict, Set, Callable, Union, cast, TypeVar
|
||||||
import functools, itertools, heapq, math
|
import functools, itertools, heapq, math
|
||||||
from collections import defaultdict
|
from collections import defaultdict, deque
|
||||||
from enum import Enum, auto
|
from enum import Enum, auto
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from tinygrad.dtype import ConstType, dtypes, DType
|
from tinygrad.dtype import ConstType, dtypes, DType
|
||||||
|
@ -334,23 +334,22 @@ class UOpGraph:
|
||||||
def graph_dedup(self, sink:UOp):
|
def graph_dedup(self, sink:UOp):
|
||||||
# add nodes to graph in reverse BFS order
|
# add nodes to graph in reverse BFS order
|
||||||
# dedup all nodes
|
# dedup all nodes
|
||||||
# TODO: i feel like this BFS is written in a few places, possible to library it?
|
in_degree: Dict[UOp, int] = {}
|
||||||
early_in_degree: Dict[UOp, int] = {}
|
|
||||||
children: Dict[UOp, List[UOp]] = {}
|
children: Dict[UOp, List[UOp]] = {}
|
||||||
get_children_dfs(sink, children, early_in_degree)
|
get_children_dfs(sink, children, in_degree)
|
||||||
|
|
||||||
early_queue = [k for k, v in early_in_degree.items() if v == 0]
|
queue = deque([k for k, v in in_degree.items() if v == 0])
|
||||||
replace_nodes: Dict[UOp, UOp] = {}
|
replace_nodes: Dict[UOp, UOp] = {}
|
||||||
while len(early_queue):
|
while queue:
|
||||||
n = early_queue.pop(0)
|
n = queue.popleft()
|
||||||
if n in replace_nodes: continue
|
if n in replace_nodes: continue
|
||||||
key = (n.op, n.dtype, tuple(replace_nodes.get(x, x) for x in n.src), n.arg)
|
key = (n.op, n.dtype, tuple(replace_nodes.get(x, x) for x in n.src), n.arg)
|
||||||
if found:=self.nodes.get(key): replace_nodes[n] = found
|
if found:=self.nodes.get(key): replace_nodes[n] = found
|
||||||
else: replace_nodes[n] = self.nodes[key] = UOp(*key)
|
else: replace_nodes[n] = self.nodes[key] = UOp(*key)
|
||||||
for x in children[n]:
|
for x in children[n]:
|
||||||
early_in_degree[x] -= 1
|
in_degree[x] -= 1
|
||||||
if early_in_degree[x] == 0:
|
if in_degree[x] == 0:
|
||||||
early_queue.append(x)
|
queue.append(x)
|
||||||
return replace_nodes.get(sink, sink)
|
return replace_nodes.get(sink, sink)
|
||||||
|
|
||||||
def linearize(self, extra_pm:Optional[PatternMatcher]=None, type_verify=True):
|
def linearize(self, extra_pm:Optional[PatternMatcher]=None, type_verify=True):
|
||||||
|
|
Loading…
Reference in New Issue