scheduler docs 2 (#4551)

* docs

* delete cleanups
This commit is contained in:
qazal 2024-05-12 17:15:39 +08:00 committed by GitHub
parent e07c7668b3
commit 3da152f0fe
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 9 additions and 4 deletions

View File

@ -42,9 +42,9 @@ class _LBScheduleItem:
inputs: Tuple[LazyBuffer, ...]
var_vals: Dict[Variable, int]
# recursively create a lazyop
def _recursive_lazyop(buf:LazyBuffer, inputs:List[LazyBuffer], outbufs:Tuple[LazyBuffer, ...], var_vals:Dict[Variable, int], st:ShapeTracker,
realizes:Dict[LazyBuffer, None], cache, assign_to:Optional[LazyBuffer]=None, assign_idx:Optional[int]=None) -> LazyOp:
"""recursively create a lazyop"""
if (buf, st) in cache: return cache[(buf, st)]
if buf != buf.base:
st = buf.st + st
@ -95,6 +95,7 @@ def _recursive_lazyop(buf:LazyBuffer, inputs:List[LazyBuffer], outbufs:Tuple[Laz
return ret
def _schedule_group(outs:Tuple[LazyBuffer, ...], realizes:Dict[LazyBuffer, None], reduce_for_op: Dict[LazyBuffer, LazyBuffer]) -> _LBScheduleItem:
"""create a schedule item from a list of outputs"""
inputs: List[LazyBuffer] = []
ast: List[LazyOp] = []
var_vals: Dict[Variable, int] = merge_dicts([out.st.var_vals.copy() for out in outs])
@ -115,9 +116,9 @@ def _schedule_group(outs:Tuple[LazyBuffer, ...], realizes:Dict[LazyBuffer, None]
# *** DAG creation: decide which LazyBuffers should realize ***
# recursively search the entire graph for all LazyBuffers, insert realizes after expands
def _recurse_lb(buf:LazyBuffer, realizes:Dict[LazyBuffer, None], allbufs:Dict[LazyBuffer, None],
simple_pads:Set[LazyBuffer], children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]], scheduled=False):
"""recursively search the entire graph for all LazyBuffers, insert realizes after expands"""
if buf in allbufs or buf.base.realized: return
if GRAPH: log_lazybuffer(buf, scheduled)
if buf.base != buf:
@ -148,9 +149,9 @@ def _is_padding_okay(buf:LazyBuffer, realizes:Dict[LazyBuffer, None]) -> bool:
if buf.op in UNSAFE_PAD_OPS: return False
return all(_is_padding_okay(x.base, realizes) for x in buf.srcs)
# recursively search the LazyBuffer for groupable children, realize the LazyBuffer if a child can't group
def _recursive_group(tr:LazyBuffer, st:ShapeTracker, r:LazyBuffer, children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]],
realizes:Dict[LazyBuffer, None], reduce_for_op:Dict[LazyBuffer, LazyBuffer], group:Set[LazyBuffer]):
"""recursively search the LazyBuffer for groupable children, realize the LazyBuffer if a child can't group"""
if tr in realizes:
# can only fuse contiguous
# max one reduceop per kernel
@ -166,6 +167,7 @@ def _recursive_group(tr:LazyBuffer, st:ShapeTracker, r:LazyBuffer, children:Defa
def _graph_schedule(outs:List[LazyBuffer], seen:Set[LazyBuffer]) -> Tuple[DefaultDict[LazyBuffer, List[LazyBuffer]], DefaultDict[LazyBuffer, int],
Dict[LazyBuffer, _LBScheduleItem]]:
"""create a graph for realizing the outputs"""
# start by just realizing the buffers passed in
realizes: Dict[LazyBuffer, None] = {x.base: None for x in outs if not x.base.realized}
allbufs: Dict[LazyBuffer, None] = {}
@ -181,7 +183,7 @@ def _graph_schedule(outs:List[LazyBuffer], seen:Set[LazyBuffer]) -> Tuple[Defaul
# find all reduces, and pair them to an elementwise op. if they can't be cleanly paired, force realize the reduce (or a contig child)
reduce_for_op: Dict[LazyBuffer, LazyBuffer] = {}
for r in allbufs.keys():
for r in allbufs:
if r != r.base or r.op not in ReduceOps or r in realizes: continue
group: Set[LazyBuffer] = set()
@ -191,11 +193,13 @@ def _graph_schedule(outs:List[LazyBuffer], seen:Set[LazyBuffer]) -> Tuple[Defaul
# TODO: forced_realize exists because the scheduler is incapable of checking for self-contained DAGs
forced_realize = r in group
if not forced_realize and len(group) > 1:
# create a multi output kernel if the LazyBuffers can cleanly group
rc_parents, rc_children = deque(group), deque(group)
while rc_parents and not forced_realize:
# max one reduceop per kernel
if (p:=rc_parents.pop()).op in ReduceOps: forced_realize = True
else: rc_parents.extend(x.base for x in p.srcs if x.base.realized is None and x.base is not r)
# search descendants of the reduceop that can cleanly group
realized_descendants: Set[LazyBuffer] = set()
while rc_children and not forced_realize:
if (c:=rc_children.pop()).op in ReduceOps or not c.st.contiguous or c.st.size != r.st.size or c in reduce_for_op:
@ -204,6 +208,7 @@ def _graph_schedule(outs:List[LazyBuffer], seen:Set[LazyBuffer]) -> Tuple[Defaul
if c in realizes and c not in group: realized_descendants.add(c)
rc_children.extend(x for x in children[c] if x.realized is None and x.device == r.device)
group.update(realized_descendants)
# can only fuse assign if no other assign_target is used in the kernel
if not forced_realize and any(x.op is LoadOps.ASSIGN for x in group):
parents = deque((r, *group))
while parents and not forced_realize: