mirror of https://github.com/commaai/tinygrad.git
parent
6841ea3baf
commit
63ba2d05d1
|
@ -1,6 +1,6 @@
|
|||
import numpy as np
|
||||
from dataclasses import replace
|
||||
from typing import DefaultDict, Dict, List, Set, Tuple
|
||||
from typing import Dict, List, Set, Tuple
|
||||
from tinygrad.codegen.uops import UOp, UOpGraph, UOps
|
||||
from tinygrad.device import Buffer, Device
|
||||
from tinygrad.engine.realize import CompiledRunner
|
||||
|
@ -8,7 +8,7 @@ from tinygrad.helpers import DEBUG, colored, getenv
|
|||
from tinygrad.shape.symbolic import Variable
|
||||
from tinygrad.tensor import _to_np_dtype
|
||||
|
||||
def fuzz_uops(graph:DefaultDict[UOp, List[UOp]], in_degree:DefaultDict[UOp, int], loops_children:Dict[UOp, Set[UOp]]):
|
||||
def fuzz_uops(graph:Dict[UOp, List[UOp]], in_degree:Dict[UOp, int], loops_children:Dict[UOp, Set[UOp]]):
|
||||
paths: List[List[UOp]] = []
|
||||
# TODO: express DEFINE_ACC and loop children conditions in the graph, builtin.
|
||||
for p in find_all_toposorts(graph, in_degree):
|
||||
|
@ -50,7 +50,7 @@ class UOpsFuzzerRunner(CompiledRunner):
|
|||
print(colored(name, "red"))
|
||||
raise e
|
||||
|
||||
def find_all_toposorts(graph:DefaultDict[UOp, List[UOp]], in_degree:DefaultDict[UOp, int]) -> List[Tuple[UOp, ...]]:
|
||||
def find_all_toposorts(graph:Dict[UOp, List[UOp]], in_degree:Dict[UOp, int]) -> List[Tuple[UOp, ...]]:
|
||||
visited: Set[UOp] = set()
|
||||
ret: List[Tuple[UOp, ...]] = []
|
||||
path: List[UOp] = []
|
||||
|
|
|
@ -278,6 +278,14 @@ constant_folder = PatternMatcher([
|
|||
|
||||
# *** uop graph ***
|
||||
|
||||
def get_children_dfs(u:UOp, children:Dict[UOp, List[UOp]], in_degree:Dict[UOp, int]):
|
||||
if u in children: return
|
||||
children[u] = []
|
||||
for x in u.src:
|
||||
get_children_dfs(x, children, in_degree)
|
||||
children[x].append(u)
|
||||
in_degree[u] = len(u.src)
|
||||
|
||||
class UOpGraph:
|
||||
def __init__(self, sinks:List[UOp]):
|
||||
self.sinks: List[UOp] = sinks
|
||||
|
@ -335,15 +343,10 @@ class UOpGraph:
|
|||
# add nodes to graph in reverse BFS order
|
||||
# dedup all nodes
|
||||
# TODO: i feel like this BFS is written in a few places, possible to library it?
|
||||
unprocessed_nodes = [sink]
|
||||
early_in_degree: Dict[UOp, int] = {}
|
||||
children: DefaultDict[UOp, List[UOp]] = defaultdict(list)
|
||||
while len(unprocessed_nodes):
|
||||
n = unprocessed_nodes.pop(0)
|
||||
if n in early_in_degree: continue
|
||||
early_in_degree[n] = len(n.src)
|
||||
for x in n.src: children[x].append(n)
|
||||
unprocessed_nodes += list(n.src)
|
||||
children: Dict[UOp, List[UOp]] = {}
|
||||
get_children_dfs(sink, children, early_in_degree)
|
||||
|
||||
early_queue = [k for k, v in early_in_degree.items() if v == 0]
|
||||
replace_nodes: Dict[UOp, UOp] = {}
|
||||
while len(early_queue):
|
||||
|
@ -372,21 +375,9 @@ class UOpGraph:
|
|||
|
||||
# filter nodes that don't link to a sink
|
||||
# BFS toposort
|
||||
graph: DefaultDict[UOp, List[UOp]] = defaultdict(list)
|
||||
in_degree: DefaultDict[UOp, int] = defaultdict(int)
|
||||
loops:List[UOp] = []
|
||||
ifs:List[UOp] = []
|
||||
nodes: Dict[UOp, None] = {}
|
||||
def add_parents(u:UOp):
|
||||
if u in nodes: return
|
||||
nodes[u] = None
|
||||
for x in u.src:
|
||||
add_parents(x)
|
||||
in_degree[u] += 1
|
||||
graph[x].append(u)
|
||||
if u.op is UOps.RANGE: loops.append(u)
|
||||
if u.op is UOps.IF: ifs.append(u)
|
||||
add_parents(sink)
|
||||
graph: Dict[UOp, List[UOp]] = {}
|
||||
in_degree: Dict[UOp, int] = {}
|
||||
get_children_dfs(sink, graph, in_degree)
|
||||
|
||||
@functools.lru_cache(None)
|
||||
def get_recursive_children(x:UOp, end:UOps, include_self=False) -> Set[UOp]:
|
||||
|
@ -394,6 +385,7 @@ class UOpGraph:
|
|||
return set.union(set((x,)) if include_self else set(), *([get_recursive_children(u, end, True) for u in graph[x] if x.op is not end]))
|
||||
# scope children impact the toposort and END* insertion
|
||||
end_for_uop = {UOps.IF:(UOps.STORE, UOps.ENDIF), UOps.RANGE:(UOps.PHI, UOps.ENDRANGE)}
|
||||
loops, ifs = [x for x in in_degree if x.op is UOps.RANGE], [x for x in in_degree if x.op is UOps.IF]
|
||||
scope_children = {p:get_recursive_children(p, end_for_uop[p.op][0]) for p in (loops+ifs)[::-1]}
|
||||
|
||||
queue:List[Tuple[int, UOp]] = []
|
||||
|
@ -404,7 +396,7 @@ class UOpGraph:
|
|||
if l.op is UOps.RANGE and u in ss: priority -= l.arg[0]*1000 + l.arg[1]
|
||||
heapq.heappush(queue, (priority, u))
|
||||
|
||||
for u in nodes:
|
||||
for u in graph:
|
||||
if in_degree[u] == 0: push(u)
|
||||
|
||||
if getenv("FUZZ_UOPS", 0):
|
||||
|
|
Loading…
Reference in New Issue