mirror of https://github.com/commaai/tinygrad.git
unify assign tests (#4247)
This commit is contained in:
parent
37f8be6450
commit
a9bc7c1c49
|
@ -115,14 +115,22 @@ class TestAssign(unittest.TestCase):
|
||||||
new = a + old_a
|
new = a + old_a
|
||||||
np.testing.assert_allclose(new.numpy(), 4)
|
np.testing.assert_allclose(new.numpy(), 4)
|
||||||
|
|
||||||
def test_assign_diamond(self):
|
def test_assign_diamond_cycle(self):
|
||||||
# NOTE: should *not* raise AssertionError from numpy
|
# NOTE: should *not* raise AssertionError from numpy
|
||||||
with self.assertRaises(RuntimeError):
|
with self.assertRaises(RuntimeError):
|
||||||
a = Tensor.ones(4).contiguous().realize()
|
a = Tensor.ones(4).contiguous().realize()
|
||||||
times_a = a*3
|
times_a = a*3
|
||||||
a.assign(Tensor.full((4,), 2.).contiguous())
|
a.assign(Tensor.full((4,), 2.).contiguous())
|
||||||
new = a + times_a
|
new = a + (times_a-1)
|
||||||
np.testing.assert_allclose(new.numpy(), 5)
|
np.testing.assert_allclose(new.numpy(), 4)
|
||||||
|
|
||||||
|
def test_assign_diamond_contiguous_cycle(self):
|
||||||
|
with self.assertRaises(RuntimeError):
|
||||||
|
a = Tensor.ones(4).contiguous().realize()
|
||||||
|
times_a = a*3
|
||||||
|
a.assign(Tensor.full((4,), 2.))
|
||||||
|
new = a.contiguous() + times_a-1
|
||||||
|
np.testing.assert_allclose(new.numpy(), 4)
|
||||||
|
|
||||||
def test_assign_diamond_possible(self):
|
def test_assign_diamond_possible(self):
|
||||||
a = Tensor.ones(4).contiguous().realize()
|
a = Tensor.ones(4).contiguous().realize()
|
||||||
|
@ -138,6 +146,13 @@ class TestAssign(unittest.TestCase):
|
||||||
new = a + (times_a-1).contiguous()
|
new = a + (times_a-1).contiguous()
|
||||||
np.testing.assert_allclose(new.numpy(), 4)
|
np.testing.assert_allclose(new.numpy(), 4)
|
||||||
|
|
||||||
|
def test_assign_diamond_both_contiguous(self):
|
||||||
|
a = Tensor.ones(4).contiguous().realize()
|
||||||
|
times_a = a*3
|
||||||
|
a.assign(Tensor.full((4,), 2.))
|
||||||
|
new = a.contiguous() + (times_a-1).contiguous()
|
||||||
|
np.testing.assert_allclose(new.numpy(), 4)
|
||||||
|
|
||||||
def test_assign_diamond_alt(self):
|
def test_assign_diamond_alt(self):
|
||||||
a = Tensor.ones(4).contiguous().realize()
|
a = Tensor.ones(4).contiguous().realize()
|
||||||
a.assign(Tensor.full((4,), 2.).contiguous())
|
a.assign(Tensor.full((4,), 2.).contiguous())
|
||||||
|
|
Loading…
Reference in New Issue