mirror of https://github.com/commaai/tinygrad.git
assert reduce recompute (#4250)
This commit is contained in:
parent
a9bc7c1c49
commit
77a3780005
|
@ -5,7 +5,7 @@
|
||||||
import unittest
|
import unittest
|
||||||
from typing import List, Optional, Union
|
from typing import List, Optional, Union
|
||||||
from tinygrad.tensor import Tensor
|
from tinygrad.tensor import Tensor
|
||||||
from tinygrad.ops import LoadOps
|
from tinygrad.ops import LoadOps, ReduceOps
|
||||||
from tinygrad.helpers import DEBUG, GRAPH, flatten
|
from tinygrad.helpers import DEBUG, GRAPH, flatten
|
||||||
from tinygrad.codegen.linearizer import Linearizer
|
from tinygrad.codegen.linearizer import Linearizer
|
||||||
from tinygrad.features.graph import print_tree, realized_lazybuffer
|
from tinygrad.features.graph import print_tree, realized_lazybuffer
|
||||||
|
@ -157,6 +157,26 @@ class TestSchedule(unittest.TestCase):
|
||||||
bc = b+c
|
bc = b+c
|
||||||
check_schedule(bc, 1)
|
check_schedule(bc, 1)
|
||||||
|
|
||||||
|
def test_cache_reduce_parent(self):
|
||||||
|
x = Tensor.empty(32)
|
||||||
|
r0 = x.mean(axis=0, keepdim=True)
|
||||||
|
r1 = (x - r0).sum(axis=0).div(2)
|
||||||
|
out = r0 + r1
|
||||||
|
schedule = check_schedule(out, 2)
|
||||||
|
reduceops = [x for si in schedule for out in si.ast for x in out.lazyops if x.op in ReduceOps]
|
||||||
|
assert len(reduceops) == 2
|
||||||
|
|
||||||
|
def test_cache_reduce_multiple_children(self):
|
||||||
|
x = Tensor.empty(32)
|
||||||
|
y = Tensor.empty(4, 4)
|
||||||
|
r0 = x.mean(axis=0, keepdim=True)
|
||||||
|
r1 = (x - r0).sum(axis=0).div(2)
|
||||||
|
out0 = r0 + y
|
||||||
|
out1 = r1 + y
|
||||||
|
schedule = check_schedule([out0, out1], 4)
|
||||||
|
reduceops = [x for si in schedule for out in si.ast for x in out.lazyops if x.op in ReduceOps]
|
||||||
|
assert len(reduceops) == 2
|
||||||
|
|
||||||
def test_fold_double_unary(self):
|
def test_fold_double_unary(self):
|
||||||
y = Tensor.empty(2)
|
y = Tensor.empty(2)
|
||||||
out = y.sum(keepdim=True).sqrt().__neg__()
|
out = y.sum(keepdim=True).sqrt().__neg__()
|
||||||
|
|
Loading…
Reference in New Issue