From 1b53207b4fae1ac37d1aa06d272ab99a8516e460 Mon Sep 17 00:00:00 2001
From: qazal <77887910+Qazalin@users.noreply.github.com>
Date: Fri, 26 Jul 2024 07:45:12 +0800
Subject: [PATCH] revert isolated dags scheduling (#5724)

---
 test/test_schedule.py       | 10 +++---
 tinygrad/engine/schedule.py | 62 ++++++++++++++-----------------
 2 files changed, 29 insertions(+), 43 deletions(-)

diff --git a/test/test_schedule.py b/test/test_schedule.py
index ab59748a..7d74ad64 100644
--- a/test/test_schedule.py
+++ b/test/test_schedule.py
@@ -209,7 +209,7 @@ class TestSchedule(unittest.TestCase):
 
   def test_fold_conv_batchnorm_optim(self):
     # this is too high
-    for optim, cnt in [(nn.optim.Adam, 18), (nn.optim.SGD, 16)]:
+    for optim, cnt in [(nn.optim.Adam, 19), (nn.optim.SGD, 17)]:
       with self.subTest(optim=optim.__name__):
         with Tensor.train():
           img = Tensor.ones(1,3,4,4)
@@ -256,7 +256,7 @@ class TestSchedule(unittest.TestCase):
     fw = bn(x).contiguous_backward().relu().contiguous()
     fw.sum().backward()
     # TODO: this is too many
-    check_schedule([x.grad, bn.weight.grad, bn.bias.grad, fw], 9)
+    check_schedule([x.grad, bn.weight.grad, bn.bias.grad, fw], 10)
 
   def test_fold_conv_relu(self):
     c1 = nn.Conv2d(3,16,3)
@@ -620,7 +620,7 @@ class TestSchedule(unittest.TestCase):
     out0 = a.sum() + b.sum() + 2
     out1 = a.sum() + b.sum() + 4
     # run_schedule(check_schedule([out0, out1], 1))
-    run_schedule(check_schedule([out0, out1], 2))
+    run_schedule(check_schedule([out0, out1], 4))
     np.testing.assert_allclose(out0.numpy(), a.numpy().sum()+b.numpy().sum()+2, atol=1e-4, rtol=1e-4)
     np.testing.assert_allclose(out1.numpy(), a.numpy().sum()+b.numpy().sum()+4, atol=1e-4, rtol=1e-4)
 
@@ -649,7 +649,7 @@ class TestSchedule(unittest.TestCase):
     out1 = b.max() + out0*2
     out2 = a.sum() + out1
     # run_schedule(check_schedule([out0, out1, out2], 1))
-    run_schedule(check_schedule([out0, out1, out2], 3))
+    run_schedule(check_schedule([out0, out1, out2], 4))
     np.testing.assert_allclose(out0.numpy(), out0_np:=a.numpy().sum()+4, atol=1e-4, rtol=1e-6)
     np.testing.assert_allclose(out1.numpy(), out1_np:=b.numpy().max() + out0_np*2, atol=1e-4, rtol=1e-6)
     np.testing.assert_allclose(out2.numpy(), a.numpy().sum() + out1_np, atol=1e-4, rtol=1e-6)
@@ -1096,7 +1096,7 @@ class TestSchedule(unittest.TestCase):
     c = a.sum() + 2
     d = (a.sum() - b.sum()) * 4
     # run_schedule(check_schedule([c, d], 1))
-    run_schedule(check_schedule([c, d], 2))
+    run_schedule(check_schedule([c, d], 3))
     np.testing.assert_allclose(c.numpy(), a.numpy().sum()+2, atol=1e-4, rtol=1e-4)
     np.testing.assert_allclose(d.numpy(), (a.numpy().sum() - b.numpy().sum()) * 4, atol=1e-4, rtol=1e-4)
 
diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py
index 60fa5a23..7b9e8c11 100644
--- a/tinygrad/engine/schedule.py
+++ b/tinygrad/engine/schedule.py
@@ -186,47 +186,35 @@ def _is_padding_okay(buf:LazyBuffer, realizes:Dict[LazyBuffer, None]) -> bool:
   return all(_is_padding_okay(x.base, realizes) for x in buf.srcs)
 
 def _recursive_group(tr:LazyBuffer, st:ShapeTracker, r:LazyBuffer, children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]],
-                     realizes:Dict[LazyBuffer, None], reduce_for_op:Dict[LazyBuffer, LazyBuffer], group:Dict[LazyBuffer, bool], cache:Set):
+                     realizes:Dict[LazyBuffer, None], reduce_for_op:Dict[LazyBuffer, LazyBuffer], group:Dict[LazyBuffer, None], cache:Set):
   """recursively search the LazyBuffer for groupable children, realize the LazyBuffer if a child can't group"""
   if (tr, st) in cache: return
   cache.add((tr, st))
   if tr in realizes and tr is not r:
     # can only fuse contiguous
     # max one reduceop per kernel
-    group[tr] = st.contiguous and st.size == r.st.size and tr not in reduce_for_op
-    return
+    if not st.contiguous or st.size != r.st.size or tr in reduce_for_op: group.setdefault(r)
+    return group.setdefault(tr)
   for tr_next in children[tr]:
     # max one reduceop per kernel
+    if tr_next.op in ReduceOps: return group.setdefault(r)
     # can only fuse contiguous
-    if tr_next.op in ReduceOps or len(st_childs:=dedup(s for s in tr_next.srcs if s.base == tr)) > 1:
-      group[tr_next] = False
-      return
+    if len(st_childs:=dedup(s for s in tr_next.srcs if s.base == tr)) > 1: return group.setdefault(r)
     _recursive_group(tr_next, st+st_childs[0].st, r, children, realizes, reduce_for_op, group, cache)
 
-def _get_inputs(buf:LazyBuffer, r:LazyBuffer, realizes:Dict[LazyBuffer, None], cache:Dict, first=True) -> Set[LazyBuffer]:
-  if buf.realized is not None or buf.op is MetaOps.CONST or buf in cache: return cache.get(buf, set())
-  if not first and (buf in realizes or buf is r): cache.setdefault(buf, set((buf,)))
-  return cache.setdefault(buf, set.union(set(), *iter(_get_inputs(x.base, r, realizes, cache, False) for x in buf.srcs)))
-
 def _get_isolated_children(r:LazyBuffer, reduce_for_op:Dict[LazyBuffer, LazyBuffer], children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]],\
-    realizes:Dict[LazyBuffer, None], group:Dict[LazyBuffer, None], realize_reduceop:bool) -> Dict[LazyBuffer, None]:
-  rc_parents, cache, is_complete = deque(group), set(), not realize_reduceop and len(group) > 1
-  while rc_parents and is_complete:
+    realizes:Dict[LazyBuffer, None], group:Dict[LazyBuffer, None]) -> Dict[LazyBuffer, None]:
+  rc_parents, cache = deque(group), set()
+  while rc_parents:
     if (p:=rc_parents.pop()) in cache: continue
     cache.add(p)
     # max one reduceop per kernel
-    if p.op in ReduceOps: is_complete = False
+    if p.op in ReduceOps: return {}
     rc_parents.extend(x.base for x in p.srcs if x.base.realized is None and x.base is not r)
-  if not is_complete:
-    group = {tr:None for tr in group if tr is not r and len(_get_inputs(tr, r, realizes, {})) == 1}
-    # if only some children can group, we have to realize the reduceop
-    if group: realizes[r] = None
-    return group
   # search descendants of the reduceop that can cleanly group
-  descendants: Dict[LazyBuffer, bool] = {}
+  descendants: Dict[LazyBuffer, None] = {}
   for tr in group: _recursive_group(tr, tr.st, tr, children, realizes, reduce_for_op, descendants, cache=set())
-  descendants_to_group = {tr:None for tr,can_group in descendants.items() if can_group}
-  return merge_dicts([group, descendants_to_group if len(descendants_to_group) == len(descendants) else {}])
+  return merge_dicts([group, {} if any(tr in group for tr in descendants) else descendants])
 
 def _graph_schedule(outs:List[LazyBuffer], seen:Set[LazyBuffer]):
   """create a graph for realizing the outputs"""
@@ -249,23 +237,23 @@ def _graph_schedule(outs:List[LazyBuffer], seen:Set[LazyBuffer]):
   reduce_for_op: Dict[LazyBuffer, LazyBuffer] = {}
   for r in allbufs:
     if r.op not in ReduceOps or r in realizes: continue
-    reduceop_children: Dict[LazyBuffer, bool] = {}
-    _recursive_group(r, r.st, r, children, realizes, reduce_for_op, reduceop_children, cache=set())
+    group: Dict[LazyBuffer, None] = {}
+    _recursive_group(r, r.st, r, children, realizes, reduce_for_op, group, cache=set())
     # max one reduceop per kernel
-    can_chase = all(tr not in reduce_for_op for tr in reduceop_children)
-    realize_reduceop = any(not can_group for can_group in reduceop_children.values())
-    group = {tr:None for tr,can_group in reduceop_children.items() if can_group}
-    if len(group) > 1 or realize_reduceop:
-      group = _get_isolated_children(r, reduce_for_op, children, realizes, group, realize_reduceop)
+    can_chase = all(tr not in reduce_for_op for tr in group)
+    # TODO: forced_realize exists because the scheduler is incapable of checking for self-contained DAGs
+    forced_realize = r in group
+    if not forced_realize and len(group) > 1:
+      group = _get_isolated_children(r, reduce_for_op, children, realizes, group)
     # can only fuse assign if no other assign_target is used in the kernel
-    if any(x.op is MetaOps.ASSIGN for x in group):
+    if not forced_realize and any(x.op is MetaOps.ASSIGN for x in group):
       parents = deque((r, *group))
-      while parents and group:
+      while parents and not forced_realize:
         if (p:=parents.pop().base).realized or p in realizes:
-          if p in assign_targets and assign_targets[p] not in group: group, can_chase = {}, False
+          if p in assign_targets and assign_targets[p] not in group: forced_realize, can_chase = True, False
           continue
         parents.extend(p.srcs)
-    if not group:
+    if forced_realize or not group:
       tr = r
       if can_chase:
         # can chase this down to contiguous children
@@ -278,14 +266,12 @@ def _graph_schedule(outs:List[LazyBuffer], seen:Set[LazyBuffer]):
           st = st + st_childs[0].st
           if not st.contiguous or tr_next.op in ReduceOps: break
          tr = tr_next
-        # don't cast to higher size before store
+        # don't cast to higher size before store (tr cannot be realized if forced_realize)
         if tr.op is UnaryOps.CAST and tr.arg.itemsize > tr.srcs[0].dtype.itemsize: tr = tr.srcs[0].base
       reduce_for_op[tr] = r
       if not FUSE_AS_ONE_KERNEL: realizes[tr] = None
-    else:
-      if realize_reduceop: realizes[r] = None
-      reduce_for_op.update((tr, r) for tr in group)
+    else: reduce_for_op.update((tr, r) for tr in group)
 
   # fuse double reduces with no other child
   if FUSE_CONV_BW:
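
Note on the grouping convention this patch reverts to (a minimal sketch, not part of the patch and not tinygrad's real API): after the revert, _recursive_group signals "this reduce cannot fuse" by inserting the reduce itself into its own group via group.setdefault(r), and _graph_schedule reads that back as forced_realize = r in group. The toy below models only that convention; Node, is_reduce, and group_reduce are hypothetical stand-ins for LazyBuffer, ReduceOps membership, and _recursive_group, and it collects all elementwise descendants rather than only realized outputs.

# Hypothetical sketch of the "max one reduceop per kernel" rule above,
# NOT tinygrad's implementation.
from dataclasses import dataclass, field
from typing import Dict, List

@dataclass(eq=False)  # identity-based hash, like LazyBuffer
class Node:
  name: str
  is_reduce: bool = False
  children: List["Node"] = field(default_factory=list)

def group_reduce(r: Node) -> Dict[Node, None]:
  """return the fusion group for reduce r; r itself appearing in the group means "force realize r"."""
  group: Dict[Node, None] = {}
  stack, seen = [r], {r}
  while stack:
    for child in stack.pop().children:
      if child in seen: continue
      seen.add(child)
      if child.is_reduce:       # max one reduceop per kernel:
        group.setdefault(r)     # mark the root for realization, like group.setdefault(r) above
        return group
      group.setdefault(child)   # elementwise child fuses into the kernel
      stack.append(child)
  return group

# usage: a reduce fuses with an elementwise child; a second reduce downstream
# puts r back in its own group, i.e. forced_realize
r1, add2, r2 = Node("sum", is_reduce=True), Node("add2"), Node("max", is_reduce=True)
r1.children.append(add2)
assert group_reduce(r1) == {add2: None}   # one fused kernel
add2.children.append(r2)
assert r1 in group_reduce(r1)             # forced_realize = r in group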