From 355e1c135cb3badef9507db2a15020a0c8c94da7 Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Wed, 15 May 2024 01:34:46 +0800 Subject: [PATCH] pad fusion tests (#4570) * what breaks * Revert "what breaks" This reverts commit e79f679283c853cbadf09bf41fd18bb9601a83ee. * simplest case * one unsafe op * expand+pad, shrink+pad * safe case * refactor --- test/test_schedule.py | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/test/test_schedule.py b/test/test_schedule.py index c7773126..bdfa32c8 100644 --- a/test/test_schedule.py +++ b/test/test_schedule.py @@ -3,7 +3,9 @@ # NOTE: this has overlap with external_test_opt.py import unittest +import numpy as np from typing import List, Optional, Union +from tinygrad.engine.realize import run_schedule from tinygrad.tensor import Tensor from tinygrad.ops import BinaryOps, LoadOps, ReduceOps from tinygrad.helpers import DEBUG, flatten @@ -775,5 +777,35 @@ class TestSchedule(unittest.TestCase): f = (b - d).sum() - e check_schedule([c, d, e, f], 3) + def test_pad_reduce_safe(self): + Tensor.manual_seed(0) + a = Tensor.rand(3, 4, 5).realize() + b = Tensor.rand(3, 4, 5).realize() + out = (a + b).pad(((0, 1), (0, 1), (0, 1)), 1.0).sum().contiguous() + run_schedule(check_schedule(out, 1)) + np.testing.assert_allclose(out.numpy(), np.pad(a.numpy()+b.numpy(), ((0, 1), (0, 1), (0, 1)), constant_values=1.0).sum()) + + def test_pad_reduce_usafe(self): + Tensor.manual_seed(0) + a = Tensor.rand(3, 4, 5).realize() + out = a.log2().pad(((0, 1), (0, 1), (0, 1)), 1.0).sum().contiguous() + run_schedule(check_schedule(out, 2)) + np.testing.assert_allclose(out.numpy(), np.pad(np.log2(a.numpy()), ((0, 1), (0, 1), (0, 1)), constant_values=1.0).sum()) + + def test_shrink_pad_safe(self): + a = Tensor.ones((3, )).contiguous().realize() + b = Tensor.ones((3, )).contiguous().realize() + out = (a + b).shrink(((0, 1),)).pad(((0, 1),)).contiguous() + run_schedule(check_schedule(out, 1)) + np.testing.assert_equal(out.numpy(), [2, 0]) + + # TODO: should not shuffle unsafe pad ops through any pads, even if buffer is shrunk overall (#3437) + def test_shrink_pad_unsafe(self): + a = Tensor.ones((3, )).contiguous().realize() + out = a.exp2().shrink(((0, 1),)).pad(((0, 1),)).contiguous() + run_schedule(check_schedule(out, 1)) + with self.assertRaises(AssertionError): + np.testing.assert_equal(out.numpy(), [2, 0]) + if __name__ == '__main__': unittest.main(verbosity=2)