mirror of https://github.com/commaai/tinygrad.git
assert reduce recompute (#4250)
parent a9bc7c1c49
commit 77a3780005
@@ -5,7 +5,7 @@
 import unittest
 from typing import List, Optional, Union
 from tinygrad.tensor import Tensor
-from tinygrad.ops import LoadOps
+from tinygrad.ops import LoadOps, ReduceOps
 from tinygrad.helpers import DEBUG, GRAPH, flatten
 from tinygrad.codegen.linearizer import Linearizer
 from tinygrad.features.graph import print_tree, realized_lazybuffer
@@ -157,6 +157,26 @@ class TestSchedule(unittest.TestCase):
     bc = b+c
     check_schedule(bc, 1)
 
+  def test_cache_reduce_parent(self):
+    x = Tensor.empty(32)
+    r0 = x.mean(axis=0, keepdim=True)
+    r1 = (x - r0).sum(axis=0).div(2)
+    out = r0 + r1
+    schedule = check_schedule(out, 2)
+    reduceops = [x for si in schedule for out in si.ast for x in out.lazyops if x.op in ReduceOps]
+    assert len(reduceops) == 2
+
+  def test_cache_reduce_multiple_children(self):
+    x = Tensor.empty(32)
+    y = Tensor.empty(4, 4)
+    r0 = x.mean(axis=0, keepdim=True)
+    r1 = (x - r0).sum(axis=0).div(2)
+    out0 = r0 + y
+    out1 = r1 + y
+    schedule = check_schedule([out0, out1], 4)
+    reduceops = [x for si in schedule for out in si.ast for x in out.lazyops if x.op in ReduceOps]
+    assert len(reduceops) == 2
+
   def test_fold_double_unary(self):
     y = Tensor.empty(2)
     out = y.sum(keepdim=True).sqrt().__neg__()
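Note (not part of the diff above): both new tests first check the kernel count via check_schedule and then count ReduceOps across the returned schedule. A minimal standalone sketch of that counting step, assuming the same schedule-item shape the tests rely on (each item exposes .ast, each output AST exposes .lazyops and .op); the helper name count_reduceops is hypothetical:

from tinygrad.ops import ReduceOps

def count_reduceops(schedule):
  # walk every output AST of every schedule item and keep only the reduce ops,
  # mirroring the list comprehension used in the tests above
  return sum(1 for si in schedule for out in si.ast for op in out.lazyops if op.op in ReduceOps)

# usage mirroring the assertions above, given a schedule returned by check_schedule:
#   assert count_reduceops(schedule) == 2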