pad fusion tests (#4570)

* what breaks

* Revert "what breaks"

This reverts commit e79f679283c853cbadf09bf41fd18bb9601a83ee.

* simplest case

* one unsafe op

* expand+pad, shrink+pad

* safe case

* refactor
This commit is contained in:
qazal 2024-05-15 01:34:46 +08:00 committed by GitHub
parent 7afca52796
commit 355e1c135c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 32 additions and 0 deletions

View File

@ -3,7 +3,9 @@
# NOTE: this has overlap with external_test_opt.py # NOTE: this has overlap with external_test_opt.py
import unittest import unittest
import numpy as np
from typing import List, Optional, Union from typing import List, Optional, Union
from tinygrad.engine.realize import run_schedule
from tinygrad.tensor import Tensor from tinygrad.tensor import Tensor
from tinygrad.ops import BinaryOps, LoadOps, ReduceOps from tinygrad.ops import BinaryOps, LoadOps, ReduceOps
from tinygrad.helpers import DEBUG, flatten from tinygrad.helpers import DEBUG, flatten
@ -775,5 +777,35 @@ class TestSchedule(unittest.TestCase):
f = (b - d).sum() - e f = (b - d).sum() - e
check_schedule([c, d, e, f], 3) check_schedule([c, d, e, f], 3)
def test_pad_reduce_safe(self):
Tensor.manual_seed(0)
a = Tensor.rand(3, 4, 5).realize()
b = Tensor.rand(3, 4, 5).realize()
out = (a + b).pad(((0, 1), (0, 1), (0, 1)), 1.0).sum().contiguous()
run_schedule(check_schedule(out, 1))
np.testing.assert_allclose(out.numpy(), np.pad(a.numpy()+b.numpy(), ((0, 1), (0, 1), (0, 1)), constant_values=1.0).sum())
def test_pad_reduce_usafe(self):
Tensor.manual_seed(0)
a = Tensor.rand(3, 4, 5).realize()
out = a.log2().pad(((0, 1), (0, 1), (0, 1)), 1.0).sum().contiguous()
run_schedule(check_schedule(out, 2))
np.testing.assert_allclose(out.numpy(), np.pad(np.log2(a.numpy()), ((0, 1), (0, 1), (0, 1)), constant_values=1.0).sum())
def test_shrink_pad_safe(self):
a = Tensor.ones((3, )).contiguous().realize()
b = Tensor.ones((3, )).contiguous().realize()
out = (a + b).shrink(((0, 1),)).pad(((0, 1),)).contiguous()
run_schedule(check_schedule(out, 1))
np.testing.assert_equal(out.numpy(), [2, 0])
# TODO: should not shuffle unsafe pad ops through any pads, even if buffer is shrunk overall (#3437)
def test_shrink_pad_unsafe(self):
a = Tensor.ones((3, )).contiguous().realize()
out = a.exp2().shrink(((0, 1),)).pad(((0, 1),)).contiguous()
run_schedule(check_schedule(out, 1))
with self.assertRaises(AssertionError):
np.testing.assert_equal(out.numpy(), [2, 0])
if __name__ == '__main__': if __name__ == '__main__':
unittest.main(verbosity=2) unittest.main(verbosity=2)