simpler reduceop children chasing (#4350)

* simplest case

* midreduce case

* all tests

* pending things

* unify tests
This commit is contained in:
qazal 2024-05-02 15:15:30 +03:00 committed by GitHub
parent 22376e53b7
commit 0b47818e0f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 35 additions and 4 deletions

View File

@ -5,7 +5,7 @@
import unittest
from typing import List, Optional, Union
from tinygrad.tensor import Tensor
from tinygrad.ops import LoadOps, ReduceOps
from tinygrad.ops import BinaryOps, LoadOps, ReduceOps
from tinygrad.helpers import DEBUG, GRAPH, flatten
from tinygrad.codegen.linearizer import Linearizer
from tinygrad.features.graph import print_tree, realized_lazybuffer
@ -680,5 +680,37 @@ class TestSchedule(unittest.TestCase):
# sched = check_schedule([b, c], 4)
# doesn't store either in half because it doesn't chase
def test_reduce_simple_chase(self):
a = Tensor.empty(4, 4, 4)
r = a.sum(0) + 6
b = r.sum(0) * 4
c = r.sum(1) * 2
schedule = check_schedule([b, c], 3)
assert schedule[0].ast[0].src[0].op is BinaryOps.ADD
def test_push_permute_chase(self):
a = Tensor.empty(4, 4, 4)
b = Tensor.empty(4, 4)
r = a.sum(2) + b
d = r.T * 4
e = r * d
schedule = check_schedule([d, e], 3)
assert schedule[0].ast[0].src[0].op is BinaryOps.ADD
def test_push_shrink_chase(self):
a = Tensor.empty(16, 16)
b = Tensor.empty(4)
c = Tensor.empty(16, )
r = a.sum(1) + c
d = r[:4] * b
schedule = check_schedule(d, 2)
assert schedule[0].ast[0].src[0].op is BinaryOps.ADD
def test_midreduce_nochase(self):
a = Tensor.empty(16, 16)
b = (a.sum(0) + a.max(1)) + 2
schedule = check_schedule(b, 2)
assert schedule[0].ast[0].src[0].op is ReduceOps.MAX
if __name__ == '__main__':
unittest.main(verbosity=2)

View File

@ -162,8 +162,8 @@ def _graph_schedule(outs:List[LazyBuffer], seen:Set[LazyBuffer]) -> Tuple[Defaul
realized_children[tr] = st
# can only reduce contiguous
# max one reduceop per kernel
if not st.contiguous or st.size != r.st.size or (tr in reduce_for_op and reduce_for_op[tr] != r):
can_chase = tr not in reduce_for_op or reduce_for_op[tr] == r
if not st.contiguous or st.size != r.st.size or tr in reduce_for_op:
can_chase = tr not in reduce_for_op
forced_realize = True
break
if len(realized_children) > 1:
@ -174,7 +174,6 @@ def _graph_schedule(outs:List[LazyBuffer], seen:Set[LazyBuffer]) -> Tuple[Defaul
if p is r: continue
# max one reduceop per kernel
if p.op in ReduceOps:
can_chase = tr not in reduce_for_op or reduce_for_op[tr] == r
forced_realize = True
break
for x in p.srcs: rc_parents.append(x.base)