pad fusion tests (#4570)

* what breaks

* Revert "what breaks"

This reverts commit e79f679283c853cbadf09bf41fd18bb9601a83ee.

* simplest case

* one unsafe op

* expand+pad, shrink+pad

* safe case

* refactor
This commit is contained in:
qazal 2024-05-15 01:34:46 +08:00 committed by GitHub
parent 7afca52796
commit 355e1c135c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 32 additions and 0 deletions

View File

@ -3,7 +3,9 @@
# NOTE: this has overlap with external_test_opt.py
import unittest
import numpy as np
from typing import List, Optional, Union
from tinygrad.engine.realize import run_schedule
from tinygrad.tensor import Tensor
from tinygrad.ops import BinaryOps, LoadOps, ReduceOps
from tinygrad.helpers import DEBUG, flatten
@ -775,5 +777,35 @@ class TestSchedule(unittest.TestCase):
f = (b - d).sum() - e
check_schedule([c, d, e, f], 3)
def test_pad_reduce_safe(self):
Tensor.manual_seed(0)
a = Tensor.rand(3, 4, 5).realize()
b = Tensor.rand(3, 4, 5).realize()
out = (a + b).pad(((0, 1), (0, 1), (0, 1)), 1.0).sum().contiguous()
run_schedule(check_schedule(out, 1))
np.testing.assert_allclose(out.numpy(), np.pad(a.numpy()+b.numpy(), ((0, 1), (0, 1), (0, 1)), constant_values=1.0).sum())
def test_pad_reduce_usafe(self):
Tensor.manual_seed(0)
a = Tensor.rand(3, 4, 5).realize()
out = a.log2().pad(((0, 1), (0, 1), (0, 1)), 1.0).sum().contiguous()
run_schedule(check_schedule(out, 2))
np.testing.assert_allclose(out.numpy(), np.pad(np.log2(a.numpy()), ((0, 1), (0, 1), (0, 1)), constant_values=1.0).sum())
def test_shrink_pad_safe(self):
a = Tensor.ones((3, )).contiguous().realize()
b = Tensor.ones((3, )).contiguous().realize()
out = (a + b).shrink(((0, 1),)).pad(((0, 1),)).contiguous()
run_schedule(check_schedule(out, 1))
np.testing.assert_equal(out.numpy(), [2, 0])
# TODO: should not shuffle unsafe pad ops through any pads, even if buffer is shrunk overall (#3437)
def test_shrink_pad_unsafe(self):
a = Tensor.ones((3, )).contiguous().realize()
out = a.exp2().shrink(((0, 1),)).pad(((0, 1),)).contiguous()
run_schedule(check_schedule(out, 1))
with self.assertRaises(AssertionError):
np.testing.assert_equal(out.numpy(), [2, 0])
if __name__ == '__main__':
unittest.main(verbosity=2)