From 77a378000527e921bfd6a6d3cf55cc84789359c9 Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Mon, 22 Apr 2024 16:12:39 +0300 Subject: [PATCH] assert reduce recompute (#4250) --- test/test_schedule.py | 22 +++++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/test/test_schedule.py b/test/test_schedule.py index a6045f0a..aa84ab73 100644 --- a/test/test_schedule.py +++ b/test/test_schedule.py @@ -5,7 +5,7 @@ import unittest from typing import List, Optional, Union from tinygrad.tensor import Tensor -from tinygrad.ops import LoadOps +from tinygrad.ops import LoadOps, ReduceOps from tinygrad.helpers import DEBUG, GRAPH, flatten from tinygrad.codegen.linearizer import Linearizer from tinygrad.features.graph import print_tree, realized_lazybuffer @@ -157,6 +157,26 @@ class TestSchedule(unittest.TestCase): bc = b+c check_schedule(bc, 1) + def test_cache_reduce_parent(self): + x = Tensor.empty(32) + r0 = x.mean(axis=0, keepdim=True) + r1 = (x - r0).sum(axis=0).div(2) + out = r0 + r1 + schedule = check_schedule(out, 2) + reduceops = [x for si in schedule for out in si.ast for x in out.lazyops if x.op in ReduceOps] + assert len(reduceops) == 2 + + def test_cache_reduce_multiple_children(self): + x = Tensor.empty(32) + y = Tensor.empty(4, 4) + r0 = x.mean(axis=0, keepdim=True) + r1 = (x - r0).sum(axis=0).div(2) + out0 = r0 + y + out1 = r1 + y + schedule = check_schedule([out0, out1], 4) + reduceops = [x for si in schedule for out in si.ast for x in out.lazyops if x.op in ReduceOps] + assert len(reduceops) == 2 + def test_fold_double_unary(self): y = Tensor.empty(2) out = y.sum(keepdim=True).sqrt().__neg__()