setitem in-place operator tests (#4577)

* tests and error

* rename to in-place

* add a note

* more comments

* more comments

* disable folded advanced setitem tests for now
This commit is contained in:
geohotstan 2024-05-14 13:28:02 +08:00 committed by GitHub
parent 0fa57b8ce9
commit 089eeec271
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 43 additions and 0 deletions

View File

@ -1094,6 +1094,8 @@ class TestIndexing(unittest.TestCase):
r[zero]
numpy_testing_assert_equal_helper(r, r[...])
# TODO fancy setitem
'''
def test_setitem_scalars(self):
zero = Tensor(0, dtype=dtypes.int64)
@ -1121,6 +1123,7 @@ class TestIndexing(unittest.TestCase):
# TODO: weird inaccuracy Max relative difference: 3.85322971e-08
# numpy_testing_assert_equal_helper(9.9, r)
np.testing.assert_allclose(9.9, r, rtol=1e-7)
'''
def test_basic_advanced_combined(self):
# From the NumPy indexing example
@ -1518,7 +1521,10 @@ class TestNumpy(unittest.TestCase):
def test_broaderrors_indexing(self):
a = Tensor.zeros(5, 5)
self.assertRaises(IndexError, a.__getitem__, ([0, 1], [0, 1, 2]))
# TODO: fancy setitem
'''
self.assertRaises(IndexError, a.contiguous().__setitem__, ([0, 1], [0, 1, 2]), 0)
'''
# TODO out of bound getitem does not raise error
'''

View File

@ -42,6 +42,41 @@ class TestSetitem(unittest.TestCase):
assert not t.lazydata.st.contiguous
with self.assertRaises(AssertionError): t[1] = 5
def test_setitem_inplace_operator(self):
t = Tensor.arange(4).reshape(2, 2).contiguous()
t[1] += 2
np.testing.assert_allclose(t.numpy(), [[0, 1], [4, 5]])
t = Tensor.arange(4).reshape(2, 2).contiguous()
t[1] -= 1
np.testing.assert_allclose(t.numpy(), [[0, 1], [1, 2]])
t = Tensor.arange(4).reshape(2, 2).contiguous()
t[1] *= 2
np.testing.assert_allclose(t.numpy(), [[0, 1], [4, 6]])
# NOTE: have to manually cast setitem target to least_upper_float for div
t = Tensor.arange(4, dtype=dtypes.float).reshape(2, 2).contiguous()
t[1] /= 2
np.testing.assert_allclose(t.numpy(), [[0, 1], [1, 1.5]])
t = Tensor.arange(4).reshape(2, 2).contiguous()
t[1] **= 2
np.testing.assert_allclose(t.numpy(), [[0, 1], [4, 9]])
t = Tensor.arange(4).reshape(2, 2).contiguous()
t[1] ^= 5
np.testing.assert_allclose(t.numpy(), [[0, 1], [7, 6]])
@unittest.expectedFailure
def test_setitem_consecutive_inplace_operator(self):
t = Tensor.arange(4).reshape(2, 2).contiguous()
t[1] += 2
t = t.contiguous()
# TODO: RuntimeError: must be contiguous for assign ShapeTracker(views=(View(shape=(2,), strides=(1,), offset=2, mask=None, contiguous=False),))
t[1] -= 1
np.testing.assert_allclose(t.numpy(), [[0, 1], [3, 4]])
# TODO: implement fancy setitem
@unittest.expectedFailure
def test_fancy_setitem(self):

View File

@ -824,6 +824,8 @@ class Tensor:
if not isinstance(v, (Tensor, float, int, bool)): raise TypeError(f"can't set a {type(v).__name__} to a Tensor")
if not isinstance(v, Tensor): v = Tensor(v, device=self.device, dtype=self.dtype)
if self.requires_grad or v.requires_grad: raise NotImplementedError("setitem with requires_grad is not supported")
if isinstance(indices, (Tensor, list)) or (isinstance(indices, tuple) and any(isinstance(i, (Tensor, list)) for i in indices)):
raise NotImplementedError("Advanced indexing setitem is not currently supported")
assign_to = self.realize().__getitem__(indices)
# NOTE: contiguous to prevent const folding.