mirror of https://github.com/commaai/tinygrad.git
parent
e07c7668b3
commit
3da152f0fe
|
@ -42,9 +42,9 @@ class _LBScheduleItem:
|
|||
inputs: Tuple[LazyBuffer, ...]
|
||||
var_vals: Dict[Variable, int]
|
||||
|
||||
# recursively create a lazyop
|
||||
def _recursive_lazyop(buf:LazyBuffer, inputs:List[LazyBuffer], outbufs:Tuple[LazyBuffer, ...], var_vals:Dict[Variable, int], st:ShapeTracker,
|
||||
realizes:Dict[LazyBuffer, None], cache, assign_to:Optional[LazyBuffer]=None, assign_idx:Optional[int]=None) -> LazyOp:
|
||||
"""recursively create a lazyop"""
|
||||
if (buf, st) in cache: return cache[(buf, st)]
|
||||
if buf != buf.base:
|
||||
st = buf.st + st
|
||||
|
@ -95,6 +95,7 @@ def _recursive_lazyop(buf:LazyBuffer, inputs:List[LazyBuffer], outbufs:Tuple[Laz
|
|||
return ret
|
||||
|
||||
def _schedule_group(outs:Tuple[LazyBuffer, ...], realizes:Dict[LazyBuffer, None], reduce_for_op: Dict[LazyBuffer, LazyBuffer]) -> _LBScheduleItem:
|
||||
"""create a schedule item from a list of outputs"""
|
||||
inputs: List[LazyBuffer] = []
|
||||
ast: List[LazyOp] = []
|
||||
var_vals: Dict[Variable, int] = merge_dicts([out.st.var_vals.copy() for out in outs])
|
||||
|
@ -115,9 +116,9 @@ def _schedule_group(outs:Tuple[LazyBuffer, ...], realizes:Dict[LazyBuffer, None]
|
|||
|
||||
# *** DAG creation: decide which LazyBuffers should realize ***
|
||||
|
||||
# recursively search the entire graph for all LazyBuffers, insert realizes after expands
|
||||
def _recurse_lb(buf:LazyBuffer, realizes:Dict[LazyBuffer, None], allbufs:Dict[LazyBuffer, None],
|
||||
simple_pads:Set[LazyBuffer], children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]], scheduled=False):
|
||||
"""recursively search the entire graph for all LazyBuffers, insert realizes after expands"""
|
||||
if buf in allbufs or buf.base.realized: return
|
||||
if GRAPH: log_lazybuffer(buf, scheduled)
|
||||
if buf.base != buf:
|
||||
|
@ -148,9 +149,9 @@ def _is_padding_okay(buf:LazyBuffer, realizes:Dict[LazyBuffer, None]) -> bool:
|
|||
if buf.op in UNSAFE_PAD_OPS: return False
|
||||
return all(_is_padding_okay(x.base, realizes) for x in buf.srcs)
|
||||
|
||||
# recursively search the LazyBuffer for groupable children, realize the LazyBuffer if a child can't group
|
||||
def _recursive_group(tr:LazyBuffer, st:ShapeTracker, r:LazyBuffer, children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]],
|
||||
realizes:Dict[LazyBuffer, None], reduce_for_op:Dict[LazyBuffer, LazyBuffer], group:Set[LazyBuffer]):
|
||||
"""recursively search the LazyBuffer for groupable children, realize the LazyBuffer if a child can't group"""
|
||||
if tr in realizes:
|
||||
# can only fuse contiguous
|
||||
# max one reduceop per kernel
|
||||
|
@ -166,6 +167,7 @@ def _recursive_group(tr:LazyBuffer, st:ShapeTracker, r:LazyBuffer, children:Defa
|
|||
|
||||
def _graph_schedule(outs:List[LazyBuffer], seen:Set[LazyBuffer]) -> Tuple[DefaultDict[LazyBuffer, List[LazyBuffer]], DefaultDict[LazyBuffer, int],
|
||||
Dict[LazyBuffer, _LBScheduleItem]]:
|
||||
"""create a graph for realizing the outputs"""
|
||||
# start by just realizing the buffers passed in
|
||||
realizes: Dict[LazyBuffer, None] = {x.base: None for x in outs if not x.base.realized}
|
||||
allbufs: Dict[LazyBuffer, None] = {}
|
||||
|
@ -181,7 +183,7 @@ def _graph_schedule(outs:List[LazyBuffer], seen:Set[LazyBuffer]) -> Tuple[Defaul
|
|||
|
||||
# find all reduces, and pair them to a elementwise op. if they can't be cleanly paired, force realize the reduce (or a contig child)
|
||||
reduce_for_op: Dict[LazyBuffer, LazyBuffer] = {}
|
||||
for r in allbufs.keys():
|
||||
for r in allbufs:
|
||||
if r != r.base or r.op not in ReduceOps or r in realizes: continue
|
||||
|
||||
group: Set[LazyBuffer] = set()
|
||||
|
@ -191,11 +193,13 @@ def _graph_schedule(outs:List[LazyBuffer], seen:Set[LazyBuffer]) -> Tuple[Defaul
|
|||
# TODO: forced_realize exists because the scheduler is incapable of checking for self-contained DAGs
|
||||
forced_realize = r in group
|
||||
if not forced_realize and len(group) > 1:
|
||||
# create a multi output kernel if the LazyBufferss can cleanly group
|
||||
rc_parents, rc_children = deque(group), deque(group)
|
||||
while rc_parents and not forced_realize:
|
||||
# max one reduceop per kernel
|
||||
if (p:=rc_parents.pop()).op in ReduceOps: forced_realize = True
|
||||
else: rc_parents.extend(x.base for x in p.srcs if x.base.realized is None and x.base is not r)
|
||||
# search descendants of the reduceop that can cleanly group
|
||||
realized_descendants: Set[LazyBuffer] = set()
|
||||
while rc_children and not forced_realize:
|
||||
if (c:=rc_children.pop()).op in ReduceOps or not c.st.contiguous or c.st.size != r.st.size or c in reduce_for_op:
|
||||
|
@ -204,6 +208,7 @@ def _graph_schedule(outs:List[LazyBuffer], seen:Set[LazyBuffer]) -> Tuple[Defaul
|
|||
if c in realizes and c not in group: realized_descendants.add(c)
|
||||
rc_children.extend(x for x in children[c] if x.realized is None and x.device == r.device)
|
||||
group.update(realized_descendants)
|
||||
# can only fuse assign if no other assign_target is used in the kernel
|
||||
if not forced_realize and any(x.op is LoadOps.ASSIGN for x in group):
|
||||
parents = deque((r, *group))
|
||||
while parents and not forced_realize:
|
||||
|
|
Loading…
Reference in New Issue