mirror of https://github.com/commaai/tinygrad.git
simpler reduceop children chasing (#4350)
* simplest case
* midreduce case
* all tests
* pending things
* unify tests
parent 22376e53b7
commit 0b47818e0f
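For orientation before the diff: "reduceop children chasing" is the scheduler step that decides where to place the realize point for the elementwise ops hanging off a ReduceOp. The snippet below is not part of the commit; it only restates the "simplest case" from the commit message, using the Tensor API and the check_schedule helper that test_schedule.py already defines (it builds a schedule and asserts the kernel count).

# not in the commit: the "simplest case" from the message, spelled out
from tinygrad.tensor import Tensor

a = Tensor.empty(4, 4, 4)
r = a.sum(0) + 6     # a reduce with an elementwise op on top
b = r.sum(0) * 4     # r feeds two further reduces, so r itself has to be realized
c = r.sum(1) * 2
# without chasing, the +6 would land in its own elementwise kernel (4 kernels);
# chasing folds it into the kernel that computes a.sum(0), leaving 3 kernels
check_schedule([b, c], 3)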
@@ -5,7 +5,7 @@
 import unittest
 from typing import List, Optional, Union
 from tinygrad.tensor import Tensor
-from tinygrad.ops import LoadOps, ReduceOps
+from tinygrad.ops import BinaryOps, LoadOps, ReduceOps
 from tinygrad.helpers import DEBUG, GRAPH, flatten
 from tinygrad.codegen.linearizer import Linearizer
 from tinygrad.features.graph import print_tree, realized_lazybuffer
@@ -680,5 +680,37 @@ class TestSchedule(unittest.TestCase):
     # sched = check_schedule([b, c], 4)
     # doesn't store either in half because it doesn't chase

+  def test_reduce_simple_chase(self):
+    a = Tensor.empty(4, 4, 4)
+    r = a.sum(0) + 6
+    b = r.sum(0) * 4
+    c = r.sum(1) * 2
+    schedule = check_schedule([b, c], 3)
+    assert schedule[0].ast[0].src[0].op is BinaryOps.ADD
+
+  def test_push_permute_chase(self):
+    a = Tensor.empty(4, 4, 4)
+    b = Tensor.empty(4, 4)
+    r = a.sum(2) + b
+    d = r.T * 4
+    e = r * d
+    schedule = check_schedule([d, e], 3)
+    assert schedule[0].ast[0].src[0].op is BinaryOps.ADD
+
+  def test_push_shrink_chase(self):
+    a = Tensor.empty(16, 16)
+    b = Tensor.empty(4)
+    c = Tensor.empty(16, )
+    r = a.sum(1) + c
+    d = r[:4] * b
+    schedule = check_schedule(d, 2)
+    assert schedule[0].ast[0].src[0].op is BinaryOps.ADD
+
+  def test_midreduce_nochase(self):
+    a = Tensor.empty(16, 16)
+    b = (a.sum(0) + a.max(1)) + 2
+    schedule = check_schedule(b, 2)
+    assert schedule[0].ast[0].src[0].op is ReduceOps.MAX
+
 if __name__ == '__main__':
   unittest.main(verbosity=2)
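If a kernel-count assertion above ever needs debugging, one rough way to see what actually got fused is to print the scheduled ASTs. This is only a sketch, meant to be pasted inside one of the tests: it reuses the schedule list returned by check_schedule and the print_tree helper imported at the top of the file, and assumes each schedule item exposes its output ASTs as the .ast tuple the assertions already index.

# sketch only (e.g. inside test_reduce_simple_chase), not part of the commit
schedule = check_schedule([b, c], 3)
for i, si in enumerate(schedule):
  print(f"kernel {i}:")
  # kernel 0 should show a store whose source is the ADD of the SUM and the constant 6
  for op in si.ast: print_tree(op)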
@@ -162,8 +162,8 @@ def _graph_schedule(outs:List[LazyBuffer], seen:Set[LazyBuffer]) -> Tuple[Defaul
         realized_children[tr] = st
         # can only reduce contiguous
         # max one reduceop per kernel
-        if not st.contiguous or st.size != r.st.size or (tr in reduce_for_op and reduce_for_op[tr] != r):
-          can_chase = tr not in reduce_for_op or reduce_for_op[tr] == r
+        if not st.contiguous or st.size != r.st.size or tr in reduce_for_op:
+          can_chase = tr not in reduce_for_op
           forced_realize = True
           break
     if len(realized_children) > 1:
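Spelled out, the two changed lines above make two independent decisions. The function below is only a paraphrase for readability, reusing the st, r, tr, and reduce_for_op names from the hunk; it is not code that exists in the repo.

# paraphrase of the simplified check above, not actual scheduler code
def realize_decision(st, r, tr, reduce_for_op):
  # force a realize at child tr if its view is not contiguous, its size differs from
  # the reduce output, or tr is already claimed by some reduce (one reduceop per kernel)
  forced_realize = not st.contiguous or st.size != r.st.size or tr in reduce_for_op
  # chasing the realize point upward is now only allowed when tr is unclaimed;
  # the old code also allowed it when reduce_for_op[tr] was this same reduce r
  can_chase = tr not in reduce_for_op
  return forced_realize, can_chase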
@@ -174,7 +174,6 @@ def _graph_schedule(outs:List[LazyBuffer], seen:Set[LazyBuffer]) -> Tuple[Defaul
         if p is r: continue
         # max one reduceop per kernel
         if p.op in ReduceOps:
-          can_chase = tr not in reduce_for_op or reduce_for_op[tr] == r
           forced_realize = True
           break
         for x in p.srcs: rc_parents.append(x.base)
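The second hunk is the mid-reduce case from the commit message (covered by test_midreduce_nochase): while walking parents to move the realize point, hitting another ReduceOp now simply forces the realize and stops, without recomputing can_chase. Below is a loose paraphrase of that walk; the surrounding loop is not shown in the hunk, so the pop-loop shape here is an assumption.

# loose paraphrase of the parent walk above, not actual scheduler code
while rc_parents:
  p = rc_parents.pop()
  if p is r: continue
  # max one reduceop per kernel: a reduce in the middle of the chain ends the chase
  # (the deleted line used to also recompute can_chase at this point)
  if p.op in ReduceOps:
    forced_realize = True
    break
  for x in p.srcs: rc_parents.append(x.base)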