From 0b47818e0fbc71df46717c907d62928f598c8c78 Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Thu, 2 May 2024 15:15:30 +0300 Subject: [PATCH] simpler reduceop children chasing (#4350) * simplest case * midreduce case * all tests * pending things * unify tests --- test/test_schedule.py | 34 +++++++++++++++++++++++++++++++++- tinygrad/engine/schedule.py | 5 ++--- 2 files changed, 35 insertions(+), 4 deletions(-) diff --git a/test/test_schedule.py b/test/test_schedule.py index 86c3e99c..8470d533 100644 --- a/test/test_schedule.py +++ b/test/test_schedule.py @@ -5,7 +5,7 @@ import unittest from typing import List, Optional, Union from tinygrad.tensor import Tensor -from tinygrad.ops import LoadOps, ReduceOps +from tinygrad.ops import BinaryOps, LoadOps, ReduceOps from tinygrad.helpers import DEBUG, GRAPH, flatten from tinygrad.codegen.linearizer import Linearizer from tinygrad.features.graph import print_tree, realized_lazybuffer @@ -680,5 +680,37 @@ class TestSchedule(unittest.TestCase): # sched = check_schedule([b, c], 4) # doesn't store either in half because it doesn't chase + def test_reduce_simple_chase(self): + a = Tensor.empty(4, 4, 4) + r = a.sum(0) + 6 + b = r.sum(0) * 4 + c = r.sum(1) * 2 + schedule = check_schedule([b, c], 3) + assert schedule[0].ast[0].src[0].op is BinaryOps.ADD + + def test_push_permute_chase(self): + a = Tensor.empty(4, 4, 4) + b = Tensor.empty(4, 4) + r = a.sum(2) + b + d = r.T * 4 + e = r * d + schedule = check_schedule([d, e], 3) + assert schedule[0].ast[0].src[0].op is BinaryOps.ADD + + def test_push_shrink_chase(self): + a = Tensor.empty(16, 16) + b = Tensor.empty(4) + c = Tensor.empty(16, ) + r = a.sum(1) + c + d = r[:4] * b + schedule = check_schedule(d, 2) + assert schedule[0].ast[0].src[0].op is BinaryOps.ADD + + def test_midreduce_nochase(self): + a = Tensor.empty(16, 16) + b = (a.sum(0) + a.max(1)) + 2 + schedule = check_schedule(b, 2) + assert schedule[0].ast[0].src[0].op is ReduceOps.MAX + if __name__ == '__main__': unittest.main(verbosity=2) diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index e4ca558e..2d3bc529 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -162,8 +162,8 @@ def _graph_schedule(outs:List[LazyBuffer], seen:Set[LazyBuffer]) -> Tuple[Defaul realized_children[tr] = st # can only reduce contiguous # max one reduceop per kernel - if not st.contiguous or st.size != r.st.size or (tr in reduce_for_op and reduce_for_op[tr] != r): - can_chase = tr not in reduce_for_op or reduce_for_op[tr] == r + if not st.contiguous or st.size != r.st.size or tr in reduce_for_op: + can_chase = tr not in reduce_for_op forced_realize = True break if len(realized_children) > 1: @@ -174,7 +174,6 @@ def _graph_schedule(outs:List[LazyBuffer], seen:Set[LazyBuffer]) -> Tuple[Defaul if p is r: continue # max one reduceop per kernel if p.op in ReduceOps: - can_chase = tr not in reduce_for_op or reduce_for_op[tr] == r forced_realize = True break for x in p.srcs: rc_parents.append(x.base)