mirror of https://github.com/commaai/tinygrad.git
revert isolated dags scheduling (#5724)
This commit is contained in:
parent
845b0d1c9d
commit
1b53207b4f
|
@ -209,7 +209,7 @@ class TestSchedule(unittest.TestCase):
|
|||
|
||||
def test_fold_conv_batchnorm_optim(self):
|
||||
# this is too high
|
||||
for optim, cnt in [(nn.optim.Adam, 18), (nn.optim.SGD, 16)]:
|
||||
for optim, cnt in [(nn.optim.Adam, 19), (nn.optim.SGD, 17)]:
|
||||
with self.subTest(optim=optim.__name__):
|
||||
with Tensor.train():
|
||||
img = Tensor.ones(1,3,4,4)
|
||||
|
@ -256,7 +256,7 @@ class TestSchedule(unittest.TestCase):
|
|||
fw = bn(x).contiguous_backward().relu().contiguous()
|
||||
fw.sum().backward()
|
||||
# TODO: this is too many
|
||||
check_schedule([x.grad, bn.weight.grad, bn.bias.grad, fw], 9)
|
||||
check_schedule([x.grad, bn.weight.grad, bn.bias.grad, fw], 10)
|
||||
|
||||
def test_fold_conv_relu(self):
|
||||
c1 = nn.Conv2d(3,16,3)
|
||||
|
@ -620,7 +620,7 @@ class TestSchedule(unittest.TestCase):
|
|||
out0 = a.sum() + b.sum() + 2
|
||||
out1 = a.sum() + b.sum() + 4
|
||||
# run_schedule(check_schedule([out0, out1], 1))
|
||||
run_schedule(check_schedule([out0, out1], 2))
|
||||
run_schedule(check_schedule([out0, out1], 4))
|
||||
np.testing.assert_allclose(out0.numpy(), a.numpy().sum()+b.numpy().sum()+2, atol=1e-4, rtol=1e-4)
|
||||
np.testing.assert_allclose(out1.numpy(), a.numpy().sum()+b.numpy().sum()+4, atol=1e-4, rtol=1e-4)
|
||||
|
||||
|
@ -649,7 +649,7 @@ class TestSchedule(unittest.TestCase):
|
|||
out1 = b.max() + out0*2
|
||||
out2 = a.sum() + out1
|
||||
# run_schedule(check_schedule([out0, out1, out2], 1))
|
||||
run_schedule(check_schedule([out0, out1, out2], 3))
|
||||
run_schedule(check_schedule([out0, out1, out2], 4))
|
||||
np.testing.assert_allclose(out0.numpy(), out0_np:=a.numpy().sum()+4, atol=1e-4, rtol=1e-6)
|
||||
np.testing.assert_allclose(out1.numpy(), out1_np:=b.numpy().max() + out0_np*2, atol=1e-4, rtol=1e-6)
|
||||
np.testing.assert_allclose(out2.numpy(), a.numpy().sum() + out1_np, atol=1e-4, rtol=1e-6)
|
||||
|
@ -1096,7 +1096,7 @@ class TestSchedule(unittest.TestCase):
|
|||
c = a.sum() + 2
|
||||
d = (a.sum() - b.sum()) * 4
|
||||
# run_schedule(check_schedule([c, d], 1))
|
||||
run_schedule(check_schedule([c, d], 2))
|
||||
run_schedule(check_schedule([c, d], 3))
|
||||
np.testing.assert_allclose(c.numpy(), a.numpy().sum()+2, atol=1e-4, rtol=1e-4)
|
||||
np.testing.assert_allclose(d.numpy(), (a.numpy().sum() - b.numpy().sum()) * 4, atol=1e-4, rtol=1e-4)
|
||||
|
||||
|
|
|
@ -186,47 +186,35 @@ def _is_padding_okay(buf:LazyBuffer, realizes:Dict[LazyBuffer, None]) -> bool:
|
|||
return all(_is_padding_okay(x.base, realizes) for x in buf.srcs)
|
||||
|
||||
def _recursive_group(tr:LazyBuffer, st:ShapeTracker, r:LazyBuffer, children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]],
|
||||
realizes:Dict[LazyBuffer, None], reduce_for_op:Dict[LazyBuffer, LazyBuffer], group:Dict[LazyBuffer, bool], cache:Set):
|
||||
realizes:Dict[LazyBuffer, None], reduce_for_op:Dict[LazyBuffer, LazyBuffer], group:Dict[LazyBuffer, None], cache:Set):
|
||||
"""recursively search the LazyBuffer for groupable children, realize the LazyBuffer if a child can't group"""
|
||||
if (tr, st) in cache: return
|
||||
cache.add((tr, st))
|
||||
if tr in realizes and tr is not r:
|
||||
# can only fuse contiguous
|
||||
# max one reduceop per kernel
|
||||
group[tr] = st.contiguous and st.size == r.st.size and tr not in reduce_for_op
|
||||
return
|
||||
if not st.contiguous or st.size != r.st.size or tr in reduce_for_op: group.setdefault(r)
|
||||
return group.setdefault(tr)
|
||||
for tr_next in children[tr]:
|
||||
# max one reduceop per kernel
|
||||
if tr_next.op in ReduceOps: return group.setdefault(r)
|
||||
# can only fuse contiguous
|
||||
if tr_next.op in ReduceOps or len(st_childs:=dedup(s for s in tr_next.srcs if s.base == tr)) > 1:
|
||||
group[tr_next] = False
|
||||
return
|
||||
if len(st_childs:=dedup(s for s in tr_next.srcs if s.base == tr)) > 1: return group.setdefault(r)
|
||||
_recursive_group(tr_next, st+st_childs[0].st, r, children, realizes, reduce_for_op, group, cache)
|
||||
|
||||
def _get_inputs(buf:LazyBuffer, r:LazyBuffer, realizes:Dict[LazyBuffer, None], cache:Dict, first=True) -> Set[LazyBuffer]:
|
||||
if buf.realized is not None or buf.op is MetaOps.CONST or buf in cache: return cache.get(buf, set())
|
||||
if not first and (buf in realizes or buf is r): cache.setdefault(buf, set((buf,)))
|
||||
return cache.setdefault(buf, set.union(set(), *iter(_get_inputs(x.base, r, realizes, cache, False) for x in buf.srcs)))
|
||||
|
||||
def _get_isolated_children(r:LazyBuffer, reduce_for_op:Dict[LazyBuffer, LazyBuffer], children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]],\
|
||||
realizes:Dict[LazyBuffer, None], group:Dict[LazyBuffer, None], realize_reduceop:bool) -> Dict[LazyBuffer, None]:
|
||||
rc_parents, cache, is_complete = deque(group), set(), not realize_reduceop and len(group) > 1
|
||||
while rc_parents and is_complete:
|
||||
realizes:Dict[LazyBuffer, None], group:Dict[LazyBuffer, None]) -> Dict[LazyBuffer, None]:
|
||||
rc_parents, cache = deque(group), set()
|
||||
while rc_parents:
|
||||
if (p:=rc_parents.pop()) in cache: continue
|
||||
cache.add(p)
|
||||
# max one reduceop per kernel
|
||||
if p.op in ReduceOps: is_complete = False
|
||||
if p.op in ReduceOps: return {}
|
||||
rc_parents.extend(x.base for x in p.srcs if x.base.realized is None and x.base is not r)
|
||||
if not is_complete:
|
||||
group = {tr:None for tr in group if tr is not r and len(_get_inputs(tr, r, realizes, {})) == 1}
|
||||
# if only some children can group, we have to realize the reduceop
|
||||
if group: realizes[r] = None
|
||||
return group
|
||||
# search descendants of the reduceop that can cleanly group
|
||||
descendants: Dict[LazyBuffer, bool] = {}
|
||||
descendants: Dict[LazyBuffer, None] = {}
|
||||
for tr in group: _recursive_group(tr, tr.st, tr, children, realizes, reduce_for_op, descendants, cache=set())
|
||||
descendants_to_group = {tr:None for tr,can_group in descendants.items() if can_group}
|
||||
return merge_dicts([group, descendants_to_group if len(descendants_to_group) == len(descendants) else {}])
|
||||
return merge_dicts([group, {} if any(tr in group for tr in descendants) else descendants])
|
||||
|
||||
def _graph_schedule(outs:List[LazyBuffer], seen:Set[LazyBuffer]):
|
||||
"""create a graph for realizing the outputs"""
|
||||
|
@ -249,23 +237,23 @@ def _graph_schedule(outs:List[LazyBuffer], seen:Set[LazyBuffer]):
|
|||
for r in allbufs:
|
||||
if r.op not in ReduceOps or r in realizes: continue
|
||||
|
||||
reduceop_children: Dict[LazyBuffer, bool] = {}
|
||||
_recursive_group(r, r.st, r, children, realizes, reduce_for_op, reduceop_children, cache=set())
|
||||
group: Dict[LazyBuffer, None] = {}
|
||||
_recursive_group(r, r.st, r, children, realizes, reduce_for_op, group, cache=set())
|
||||
# max one reduceop per kernel
|
||||
can_chase = all(tr not in reduce_for_op for tr in reduceop_children)
|
||||
realize_reduceop = any(not can_group for can_group in reduceop_children.values())
|
||||
group = {tr:None for tr,can_group in reduceop_children.items() if can_group}
|
||||
if len(group) > 1 or realize_reduceop:
|
||||
group = _get_isolated_children(r, reduce_for_op, children, realizes, group, realize_reduceop)
|
||||
can_chase = all(tr not in reduce_for_op for tr in group)
|
||||
# TODO: forced_realize exists because the scheduler is incapable of checking for self-contained DAGs
|
||||
forced_realize = r in group
|
||||
if not forced_realize and len(group) > 1:
|
||||
group = _get_isolated_children(r, reduce_for_op, children, realizes, group)
|
||||
# can only fuse assign if no other assign_target is used in the kernel
|
||||
if any(x.op is MetaOps.ASSIGN for x in group):
|
||||
if not forced_realize and any(x.op is MetaOps.ASSIGN for x in group):
|
||||
parents = deque((r, *group))
|
||||
while parents and group:
|
||||
while parents and not forced_realize:
|
||||
if (p:=parents.pop().base).realized or p in realizes:
|
||||
if p in assign_targets and assign_targets[p] not in group: group, can_chase = {}, False
|
||||
if p in assign_targets and assign_targets[p] not in group: forced_realize, can_chase = True, False
|
||||
continue
|
||||
parents.extend(p.srcs)
|
||||
if not group:
|
||||
if forced_realize or not group:
|
||||
tr = r
|
||||
if can_chase:
|
||||
# can chase this down to contiguous children
|
||||
|
@ -278,14 +266,12 @@ def _graph_schedule(outs:List[LazyBuffer], seen:Set[LazyBuffer]):
|
|||
st = st + st_childs[0].st
|
||||
if not st.contiguous or tr_next.op in ReduceOps: break
|
||||
tr = tr_next
|
||||
# don't cast to higher size before store
|
||||
# don't cast to higher size before store (tr cannot be realized if forced_realize)
|
||||
if tr.op is UnaryOps.CAST and tr.arg.itemsize > tr.srcs[0].dtype.itemsize:
|
||||
tr = tr.srcs[0].base
|
||||
reduce_for_op[tr] = r
|
||||
if not FUSE_AS_ONE_KERNEL: realizes[tr] = None
|
||||
else:
|
||||
if realize_reduceop: realizes[r] = None
|
||||
reduce_for_op.update((tr, r) for tr in group)
|
||||
else: reduce_for_op.update((tr, r) for tr in group)
|
||||
|
||||
# fuse double reduces with no other child
|
||||
if FUSE_CONV_BW:
|
||||
|
|
Loading…
Reference in New Issue