clean up graph dedup function [run_process_replay] (#5169)

This commit is contained in:
George Hotz 2024-06-26 15:07:34 -07:00 committed by GitHub
parent 3a04e518ec
commit 396ce6cfc9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 9 additions and 10 deletions

View File

@ -1,7 +1,7 @@
from __future__ import annotations from __future__ import annotations
from typing import Iterator, Optional, Tuple, Any, Dict, List, DefaultDict, Set, Callable, Union, cast, TypeVar from typing import Iterator, Optional, Tuple, Any, Dict, List, DefaultDict, Set, Callable, Union, cast, TypeVar
import functools, itertools, heapq, math import functools, itertools, heapq, math
from collections import defaultdict from collections import defaultdict, deque
from enum import Enum, auto from enum import Enum, auto
from dataclasses import dataclass, field from dataclasses import dataclass, field
from tinygrad.dtype import ConstType, dtypes, DType from tinygrad.dtype import ConstType, dtypes, DType
@ -334,23 +334,22 @@ class UOpGraph:
def graph_dedup(self, sink:UOp): def graph_dedup(self, sink:UOp):
# add nodes to graph in reverse BFS order # add nodes to graph in reverse BFS order
# dedup all nodes # dedup all nodes
# TODO: i feel like this BFS is written in a few places, possible to library it? in_degree: Dict[UOp, int] = {}
early_in_degree: Dict[UOp, int] = {}
children: Dict[UOp, List[UOp]] = {} children: Dict[UOp, List[UOp]] = {}
get_children_dfs(sink, children, early_in_degree) get_children_dfs(sink, children, in_degree)
early_queue = [k for k, v in early_in_degree.items() if v == 0] queue = deque([k for k, v in in_degree.items() if v == 0])
replace_nodes: Dict[UOp, UOp] = {} replace_nodes: Dict[UOp, UOp] = {}
while len(early_queue): while queue:
n = early_queue.pop(0) n = queue.popleft()
if n in replace_nodes: continue if n in replace_nodes: continue
key = (n.op, n.dtype, tuple(replace_nodes.get(x, x) for x in n.src), n.arg) key = (n.op, n.dtype, tuple(replace_nodes.get(x, x) for x in n.src), n.arg)
if found:=self.nodes.get(key): replace_nodes[n] = found if found:=self.nodes.get(key): replace_nodes[n] = found
else: replace_nodes[n] = self.nodes[key] = UOp(*key) else: replace_nodes[n] = self.nodes[key] = UOp(*key)
for x in children[n]: for x in children[n]:
early_in_degree[x] -= 1 in_degree[x] -= 1
if early_in_degree[x] == 0: if in_degree[x] == 0:
early_queue.append(x) queue.append(x)
return replace_nodes.get(sink, sink) return replace_nodes.get(sink, sink)
def linearize(self, extra_pm:Optional[PatternMatcher]=None, type_verify=True): def linearize(self, extra_pm:Optional[PatternMatcher]=None, type_verify=True):