rename to realize_reduceop (#5713)

* rename to realize_reduceop

* shorter comment
This commit is contained in:
qazal 2024-07-26 01:57:33 +08:00 committed by GitHub
parent 05e02ddfb3
commit f02124ffa0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 9 additions and 9 deletions

View File

@ -209,8 +209,8 @@ def _get_inputs(buf:LazyBuffer, r:LazyBuffer, realizes:Dict[LazyBuffer, None], c
return cache.setdefault(buf, set.union(set(), *iter(_get_inputs(x.base, r, realizes, cache, False) for x in buf.srcs)))
def _get_isolated_children(r:LazyBuffer, reduce_for_op:Dict[LazyBuffer, LazyBuffer], children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]],\
realizes:Dict[LazyBuffer, None], group:Dict[LazyBuffer, None], forced_realize:bool) -> Dict[LazyBuffer, None]:
rc_parents, cache, is_complete = deque(group), set(), not forced_realize and len(group) > 1
realizes:Dict[LazyBuffer, None], group:Dict[LazyBuffer, None], realize_reduceop:bool) -> Dict[LazyBuffer, None]:
rc_parents, cache, is_complete = deque(group), set(), not realize_reduceop and len(group) > 1
while rc_parents and is_complete:
if (p:=rc_parents.pop()) in cache: continue
cache.add(p)
@ -219,7 +219,7 @@ def _get_isolated_children(r:LazyBuffer, reduce_for_op:Dict[LazyBuffer, LazyBuff
rc_parents.extend(x.base for x in p.srcs if x.base.realized is None and x.base is not r)
if not is_complete:
group = {tr:None for tr in group if tr is not r and len(_get_inputs(tr, r, realizes, {})) == 1}
# if it can only group some children, we have to realize the reduceop
# if only some children can group, we have to realize the reduceop
if group: realizes[r] = None
return group
# search descendants of the reduceop that can cleanly group
@ -253,10 +253,10 @@ def _graph_schedule(outs:List[LazyBuffer], seen:Set[LazyBuffer]):
_recursive_group(r, r.st, r, children, realizes, reduce_for_op, reduceop_children, cache=set())
# max one reduceop per kernel
can_chase = all(tr not in reduce_for_op for tr in reduceop_children)
# TODO: forced_realize is poorly named now
forced_realize = any(not can_group for can_group in reduceop_children.values())
if len(group:={tr:None for tr,can_group in reduceop_children.items() if can_group}) > 1 or forced_realize:
group = _get_isolated_children(r, reduce_for_op, children, realizes, group, forced_realize)
realize_reduceop = any(not can_group for can_group in reduceop_children.values())
group = {tr:None for tr,can_group in reduceop_children.items() if can_group}
if len(group) > 1 or realize_reduceop:
group = _get_isolated_children(r, reduce_for_op, children, realizes, group, realize_reduceop)
# can only fuse assign if no other assign_target is used in the kernel
if any(x.op is MetaOps.ASSIGN for x in group):
parents = deque((r, *group))
@ -278,13 +278,13 @@ def _graph_schedule(outs:List[LazyBuffer], seen:Set[LazyBuffer]):
st = st + st_childs[0].st
if not st.contiguous or tr_next.op in ReduceOps: break
tr = tr_next
# don't cast to higher size before store (tr cannot be realized if forced_realize)
# don't cast to higher size before store
if tr.op is UnaryOps.CAST and tr.arg.itemsize > tr.srcs[0].dtype.itemsize:
tr = tr.srcs[0].base
reduce_for_op[tr] = r
if not FUSE_AS_ONE_KERNEL: realizes[tr] = None
else:
if forced_realize: realizes[r] = None
if realize_reduceop: realizes[r] = None
reduce_for_op.update((tr, r) for tr in group)
# fuse double reduces with no other child