unify assign tests (#4247)

This commit is contained in:
qazal 2024-04-22 11:01:15 +03:00 committed by GitHub
parent 37f8be6450
commit a9bc7c1c49
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 18 additions and 3 deletions

View File

@ -115,14 +115,22 @@ class TestAssign(unittest.TestCase):
new = a + old_a new = a + old_a
np.testing.assert_allclose(new.numpy(), 4) np.testing.assert_allclose(new.numpy(), 4)
def test_assign_diamond(self): def test_assign_diamond_cycle(self):
# NOTE: should *not* raise AssertionError from numpy # NOTE: should *not* raise AssertionError from numpy
with self.assertRaises(RuntimeError): with self.assertRaises(RuntimeError):
a = Tensor.ones(4).contiguous().realize() a = Tensor.ones(4).contiguous().realize()
times_a = a*3 times_a = a*3
a.assign(Tensor.full((4,), 2.).contiguous()) a.assign(Tensor.full((4,), 2.).contiguous())
new = a + times_a new = a + (times_a-1)
np.testing.assert_allclose(new.numpy(), 5) np.testing.assert_allclose(new.numpy(), 4)
def test_assign_diamond_contiguous_cycle(self):
with self.assertRaises(RuntimeError):
a = Tensor.ones(4).contiguous().realize()
times_a = a*3
a.assign(Tensor.full((4,), 2.))
new = a.contiguous() + times_a-1
np.testing.assert_allclose(new.numpy(), 4)
def test_assign_diamond_possible(self): def test_assign_diamond_possible(self):
a = Tensor.ones(4).contiguous().realize() a = Tensor.ones(4).contiguous().realize()
@ -138,6 +146,13 @@ class TestAssign(unittest.TestCase):
new = a + (times_a-1).contiguous() new = a + (times_a-1).contiguous()
np.testing.assert_allclose(new.numpy(), 4) np.testing.assert_allclose(new.numpy(), 4)
def test_assign_diamond_both_contiguous(self):
a = Tensor.ones(4).contiguous().realize()
times_a = a*3
a.assign(Tensor.full((4,), 2.))
new = a.contiguous() + (times_a-1).contiguous()
np.testing.assert_allclose(new.numpy(), 4)
def test_assign_diamond_alt(self): def test_assign_diamond_alt(self):
a = Tensor.ones(4).contiguous().realize() a = Tensor.ones(4).contiguous().realize()
a.assign(Tensor.full((4,), 2.).contiguous()) a.assign(Tensor.full((4,), 2.).contiguous())