assert reduce recompute (#4250)

This commit is contained in:
qazal 2024-04-22 16:12:39 +03:00 committed by GitHub
parent a9bc7c1c49
commit 77a3780005
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 21 additions and 1 deletions

View File

@ -5,7 +5,7 @@
import unittest import unittest
from typing import List, Optional, Union from typing import List, Optional, Union
from tinygrad.tensor import Tensor from tinygrad.tensor import Tensor
from tinygrad.ops import LoadOps from tinygrad.ops import LoadOps, ReduceOps
from tinygrad.helpers import DEBUG, GRAPH, flatten from tinygrad.helpers import DEBUG, GRAPH, flatten
from tinygrad.codegen.linearizer import Linearizer from tinygrad.codegen.linearizer import Linearizer
from tinygrad.features.graph import print_tree, realized_lazybuffer from tinygrad.features.graph import print_tree, realized_lazybuffer
@ -157,6 +157,26 @@ class TestSchedule(unittest.TestCase):
bc = b+c bc = b+c
check_schedule(bc, 1) check_schedule(bc, 1)
def test_cache_reduce_parent(self):
x = Tensor.empty(32)
r0 = x.mean(axis=0, keepdim=True)
r1 = (x - r0).sum(axis=0).div(2)
out = r0 + r1
schedule = check_schedule(out, 2)
reduceops = [x for si in schedule for out in si.ast for x in out.lazyops if x.op in ReduceOps]
assert len(reduceops) == 2
def test_cache_reduce_multiple_children(self):
x = Tensor.empty(32)
y = Tensor.empty(4, 4)
r0 = x.mean(axis=0, keepdim=True)
r1 = (x - r0).sum(axis=0).div(2)
out0 = r0 + y
out1 = r1 + y
schedule = check_schedule([out0, out1], 4)
reduceops = [x for si in schedule for out in si.ast for x in out.lazyops if x.op in ReduceOps]
assert len(reduceops) == 2
def test_fold_double_unary(self): def test_fold_double_unary(self):
y = Tensor.empty(2) y = Tensor.empty(2)
out = y.sum(keepdim=True).sqrt().__neg__() out = y.sum(keepdim=True).sqrt().__neg__()