mirror of https://github.com/commaai/tinygrad.git
pad fusion tests (#4570)
* what breaks * Revert "what breaks" This reverts commit e79f679283c853cbadf09bf41fd18bb9601a83ee. * simplest case * one unsafe op * expand+pad, shrink+pad * safe case * refactor
This commit is contained in:
parent
7afca52796
commit
355e1c135c
|
@ -3,7 +3,9 @@
|
|||
# NOTE: this has overlap with external_test_opt.py
|
||||
|
||||
import unittest
|
||||
import numpy as np
|
||||
from typing import List, Optional, Union
|
||||
from tinygrad.engine.realize import run_schedule
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.ops import BinaryOps, LoadOps, ReduceOps
|
||||
from tinygrad.helpers import DEBUG, flatten
|
||||
|
@ -775,5 +777,35 @@ class TestSchedule(unittest.TestCase):
|
|||
f = (b - d).sum() - e
|
||||
check_schedule([c, d, e, f], 3)
|
||||
|
||||
def test_pad_reduce_safe(self):
|
||||
Tensor.manual_seed(0)
|
||||
a = Tensor.rand(3, 4, 5).realize()
|
||||
b = Tensor.rand(3, 4, 5).realize()
|
||||
out = (a + b).pad(((0, 1), (0, 1), (0, 1)), 1.0).sum().contiguous()
|
||||
run_schedule(check_schedule(out, 1))
|
||||
np.testing.assert_allclose(out.numpy(), np.pad(a.numpy()+b.numpy(), ((0, 1), (0, 1), (0, 1)), constant_values=1.0).sum())
|
||||
|
||||
def test_pad_reduce_usafe(self):
|
||||
Tensor.manual_seed(0)
|
||||
a = Tensor.rand(3, 4, 5).realize()
|
||||
out = a.log2().pad(((0, 1), (0, 1), (0, 1)), 1.0).sum().contiguous()
|
||||
run_schedule(check_schedule(out, 2))
|
||||
np.testing.assert_allclose(out.numpy(), np.pad(np.log2(a.numpy()), ((0, 1), (0, 1), (0, 1)), constant_values=1.0).sum())
|
||||
|
||||
def test_shrink_pad_safe(self):
|
||||
a = Tensor.ones((3, )).contiguous().realize()
|
||||
b = Tensor.ones((3, )).contiguous().realize()
|
||||
out = (a + b).shrink(((0, 1),)).pad(((0, 1),)).contiguous()
|
||||
run_schedule(check_schedule(out, 1))
|
||||
np.testing.assert_equal(out.numpy(), [2, 0])
|
||||
|
||||
# TODO: should not shuffle unsafe pad ops through any pads, even if buffer is shrunk overall (#3437)
|
||||
def test_shrink_pad_unsafe(self):
|
||||
a = Tensor.ones((3, )).contiguous().realize()
|
||||
out = a.exp2().shrink(((0, 1),)).pad(((0, 1),)).contiguous()
|
||||
run_schedule(check_schedule(out, 1))
|
||||
with self.assertRaises(AssertionError):
|
||||
np.testing.assert_equal(out.numpy(), [2, 0])
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main(verbosity=2)
|
||||
|
|
Loading…
Reference in New Issue