mirror of https://github.com/commaai/tinygrad.git
rename to realize_reduceop (#5713)
* rename to realize_reduceop * shorter comment
This commit is contained in:
parent
05e02ddfb3
commit
f02124ffa0
|
@ -209,8 +209,8 @@ def _get_inputs(buf:LazyBuffer, r:LazyBuffer, realizes:Dict[LazyBuffer, None], c
|
|||
return cache.setdefault(buf, set.union(set(), *iter(_get_inputs(x.base, r, realizes, cache, False) for x in buf.srcs)))
|
||||
|
||||
def _get_isolated_children(r:LazyBuffer, reduce_for_op:Dict[LazyBuffer, LazyBuffer], children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]],\
|
||||
realizes:Dict[LazyBuffer, None], group:Dict[LazyBuffer, None], forced_realize:bool) -> Dict[LazyBuffer, None]:
|
||||
rc_parents, cache, is_complete = deque(group), set(), not forced_realize and len(group) > 1
|
||||
realizes:Dict[LazyBuffer, None], group:Dict[LazyBuffer, None], realize_reduceop:bool) -> Dict[LazyBuffer, None]:
|
||||
rc_parents, cache, is_complete = deque(group), set(), not realize_reduceop and len(group) > 1
|
||||
while rc_parents and is_complete:
|
||||
if (p:=rc_parents.pop()) in cache: continue
|
||||
cache.add(p)
|
||||
|
@ -219,7 +219,7 @@ def _get_isolated_children(r:LazyBuffer, reduce_for_op:Dict[LazyBuffer, LazyBuff
|
|||
rc_parents.extend(x.base for x in p.srcs if x.base.realized is None and x.base is not r)
|
||||
if not is_complete:
|
||||
group = {tr:None for tr in group if tr is not r and len(_get_inputs(tr, r, realizes, {})) == 1}
|
||||
# if it can only group some children, we have to realize the reduceop
|
||||
# if only some children can group, we have to realize the reduceop
|
||||
if group: realizes[r] = None
|
||||
return group
|
||||
# search descendants of the reduceop that can cleanly group
|
||||
|
@ -253,10 +253,10 @@ def _graph_schedule(outs:List[LazyBuffer], seen:Set[LazyBuffer]):
|
|||
_recursive_group(r, r.st, r, children, realizes, reduce_for_op, reduceop_children, cache=set())
|
||||
# max one reduceop per kernel
|
||||
can_chase = all(tr not in reduce_for_op for tr in reduceop_children)
|
||||
# TODO: forced_realize is poorly named now
|
||||
forced_realize = any(not can_group for can_group in reduceop_children.values())
|
||||
if len(group:={tr:None for tr,can_group in reduceop_children.items() if can_group}) > 1 or forced_realize:
|
||||
group = _get_isolated_children(r, reduce_for_op, children, realizes, group, forced_realize)
|
||||
realize_reduceop = any(not can_group for can_group in reduceop_children.values())
|
||||
group = {tr:None for tr,can_group in reduceop_children.items() if can_group}
|
||||
if len(group) > 1 or realize_reduceop:
|
||||
group = _get_isolated_children(r, reduce_for_op, children, realizes, group, realize_reduceop)
|
||||
# can only fuse assign if no other assign_target is used in the kernel
|
||||
if any(x.op is MetaOps.ASSIGN for x in group):
|
||||
parents = deque((r, *group))
|
||||
|
@ -278,13 +278,13 @@ def _graph_schedule(outs:List[LazyBuffer], seen:Set[LazyBuffer]):
|
|||
st = st + st_childs[0].st
|
||||
if not st.contiguous or tr_next.op in ReduceOps: break
|
||||
tr = tr_next
|
||||
# don't cast to higher size before store (tr cannot be realized if forced_realize)
|
||||
# don't cast to higher size before store
|
||||
if tr.op is UnaryOps.CAST and tr.arg.itemsize > tr.srcs[0].dtype.itemsize:
|
||||
tr = tr.srcs[0].base
|
||||
reduce_for_op[tr] = r
|
||||
if not FUSE_AS_ONE_KERNEL: realizes[tr] = None
|
||||
else:
|
||||
if forced_realize: realizes[r] = None
|
||||
if realize_reduceop: realizes[r] = None
|
||||
reduce_for_op.update((tr, r) for tr in group)
|
||||
|
||||
# fuse double reduces with no other child
|
||||
|
|
Loading…
Reference in New Issue