uops dfs cleanup (#5147)

* uops dfs cleanup

* Update uops.py
This commit is contained in:
George Hotz 2024-06-25 18:51:42 -07:00 committed by GitHub
parent 6841ea3baf
commit 63ba2d05d1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 19 additions and 27 deletions

View File

@ -1,6 +1,6 @@
import numpy as np
from dataclasses import replace
from typing import DefaultDict, Dict, List, Set, Tuple
from typing import Dict, List, Set, Tuple
from tinygrad.codegen.uops import UOp, UOpGraph, UOps
from tinygrad.device import Buffer, Device
from tinygrad.engine.realize import CompiledRunner
@ -8,7 +8,7 @@ from tinygrad.helpers import DEBUG, colored, getenv
from tinygrad.shape.symbolic import Variable
from tinygrad.tensor import _to_np_dtype
def fuzz_uops(graph:DefaultDict[UOp, List[UOp]], in_degree:DefaultDict[UOp, int], loops_children:Dict[UOp, Set[UOp]]):
def fuzz_uops(graph:Dict[UOp, List[UOp]], in_degree:Dict[UOp, int], loops_children:Dict[UOp, Set[UOp]]):
paths: List[List[UOp]] = []
# TODO: express DEFINE_ACC and loop children conditions in the graph, builtin.
for p in find_all_toposorts(graph, in_degree):
@ -50,7 +50,7 @@ class UOpsFuzzerRunner(CompiledRunner):
print(colored(name, "red"))
raise e
def find_all_toposorts(graph:DefaultDict[UOp, List[UOp]], in_degree:DefaultDict[UOp, int]) -> List[Tuple[UOp, ...]]:
def find_all_toposorts(graph:Dict[UOp, List[UOp]], in_degree:Dict[UOp, int]) -> List[Tuple[UOp, ...]]:
visited: Set[UOp] = set()
ret: List[Tuple[UOp, ...]] = []
path: List[UOp] = []

View File

@ -278,6 +278,14 @@ constant_folder = PatternMatcher([
# *** uop graph ***
def get_children_dfs(u:UOp, children:Dict[UOp, List[UOp]], in_degree:Dict[UOp, int]):
if u in children: return
children[u] = []
for x in u.src:
get_children_dfs(x, children, in_degree)
children[x].append(u)
in_degree[u] = len(u.src)
class UOpGraph:
def __init__(self, sinks:List[UOp]):
self.sinks: List[UOp] = sinks
@ -335,15 +343,10 @@ class UOpGraph:
# add nodes to graph in reverse BFS order
# dedup all nodes
# TODO: i feel like this BFS is written in a few places, possible to library it?
unprocessed_nodes = [sink]
early_in_degree: Dict[UOp, int] = {}
children: DefaultDict[UOp, List[UOp]] = defaultdict(list)
while len(unprocessed_nodes):
n = unprocessed_nodes.pop(0)
if n in early_in_degree: continue
early_in_degree[n] = len(n.src)
for x in n.src: children[x].append(n)
unprocessed_nodes += list(n.src)
children: Dict[UOp, List[UOp]] = {}
get_children_dfs(sink, children, early_in_degree)
early_queue = [k for k, v in early_in_degree.items() if v == 0]
replace_nodes: Dict[UOp, UOp] = {}
while len(early_queue):
@ -372,21 +375,9 @@ class UOpGraph:
# filter nodes that don't link to a sink
# BFS toposort
graph: DefaultDict[UOp, List[UOp]] = defaultdict(list)
in_degree: DefaultDict[UOp, int] = defaultdict(int)
loops:List[UOp] = []
ifs:List[UOp] = []
nodes: Dict[UOp, None] = {}
def add_parents(u:UOp):
if u in nodes: return
nodes[u] = None
for x in u.src:
add_parents(x)
in_degree[u] += 1
graph[x].append(u)
if u.op is UOps.RANGE: loops.append(u)
if u.op is UOps.IF: ifs.append(u)
add_parents(sink)
graph: Dict[UOp, List[UOp]] = {}
in_degree: Dict[UOp, int] = {}
get_children_dfs(sink, graph, in_degree)
@functools.lru_cache(None)
def get_recursive_children(x:UOp, end:UOps, include_self=False) -> Set[UOp]:
@ -394,6 +385,7 @@ class UOpGraph:
return set.union(set((x,)) if include_self else set(), *([get_recursive_children(u, end, True) for u in graph[x] if x.op is not end]))
# scope children impact the toposort and END* insertion
end_for_uop = {UOps.IF:(UOps.STORE, UOps.ENDIF), UOps.RANGE:(UOps.PHI, UOps.ENDRANGE)}
loops, ifs = [x for x in in_degree if x.op is UOps.RANGE], [x for x in in_degree if x.op is UOps.IF]
scope_children = {p:get_recursive_children(p, end_for_uop[p.op][0]) for p in (loops+ifs)[::-1]}
queue:List[Tuple[int, UOp]] = []
@ -404,7 +396,7 @@ class UOpGraph:
if l.op is UOps.RANGE and u in ss: priority -= l.arg[0]*1000 + l.arg[1]
heapq.heappush(queue, (priority, u))
for u in nodes:
for u in graph:
if in_degree[u] == 0: push(u)
if getenv("FUZZ_UOPS", 0):