diff --git a/test/imported/test_indexing.py b/test/imported/test_indexing.py index f76e79a3..f7a82efb 100644 --- a/test/imported/test_indexing.py +++ b/test/imported/test_indexing.py @@ -1094,6 +1094,8 @@ class TestIndexing(unittest.TestCase): r[zero] numpy_testing_assert_equal_helper(r, r[...]) + # TODO fancy setitem + ''' def test_setitem_scalars(self): zero = Tensor(0, dtype=dtypes.int64) @@ -1121,6 +1123,7 @@ class TestIndexing(unittest.TestCase): # TODO: weird inaccuracy Max relative difference: 3.85322971e-08 # numpy_testing_assert_equal_helper(9.9, r) np.testing.assert_allclose(9.9, r, rtol=1e-7) + ''' def test_basic_advanced_combined(self): # From the NumPy indexing example @@ -1518,7 +1521,10 @@ class TestNumpy(unittest.TestCase): def test_broaderrors_indexing(self): a = Tensor.zeros(5, 5) self.assertRaises(IndexError, a.__getitem__, ([0, 1], [0, 1, 2])) + # TODO: fancy setitem + ''' self.assertRaises(IndexError, a.contiguous().__setitem__, ([0, 1], [0, 1, 2]), 0) + ''' # TODO out of bound getitem does not raise error ''' diff --git a/test/test_setitem.py b/test/test_setitem.py index d28f9eb9..00a02b29 100644 --- a/test/test_setitem.py +++ b/test/test_setitem.py @@ -42,6 +42,41 @@ class TestSetitem(unittest.TestCase): assert not t.lazydata.st.contiguous with self.assertRaises(AssertionError): t[1] = 5 + def test_setitem_inplace_operator(self): + t = Tensor.arange(4).reshape(2, 2).contiguous() + t[1] += 2 + np.testing.assert_allclose(t.numpy(), [[0, 1], [4, 5]]) + + t = Tensor.arange(4).reshape(2, 2).contiguous() + t[1] -= 1 + np.testing.assert_allclose(t.numpy(), [[0, 1], [1, 2]]) + + t = Tensor.arange(4).reshape(2, 2).contiguous() + t[1] *= 2 + np.testing.assert_allclose(t.numpy(), [[0, 1], [4, 6]]) + + # NOTE: have to manually cast setitem target to least_upper_float for div + t = Tensor.arange(4, dtype=dtypes.float).reshape(2, 2).contiguous() + t[1] /= 2 + np.testing.assert_allclose(t.numpy(), [[0, 1], [1, 1.5]]) + + t = Tensor.arange(4).reshape(2, 2).contiguous() + t[1] **= 2 + np.testing.assert_allclose(t.numpy(), [[0, 1], [4, 9]]) + + t = Tensor.arange(4).reshape(2, 2).contiguous() + t[1] ^= 5 + np.testing.assert_allclose(t.numpy(), [[0, 1], [7, 6]]) + + @unittest.expectedFailure + def test_setitem_consecutive_inplace_operator(self): + t = Tensor.arange(4).reshape(2, 2).contiguous() + t[1] += 2 + t = t.contiguous() + # TODO: RuntimeError: must be contiguous for assign ShapeTracker(views=(View(shape=(2,), strides=(1,), offset=2, mask=None, contiguous=False),)) + t[1] -= 1 + np.testing.assert_allclose(t.numpy(), [[0, 1], [3, 4]]) + # TODO: implement fancy setitem @unittest.expectedFailure def test_fancy_setitem(self): diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 8ab65c9a..cd3e7714 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -824,6 +824,8 @@ class Tensor: if not isinstance(v, (Tensor, float, int, bool)): raise TypeError(f"can't set a {type(v).__name__} to a Tensor") if not isinstance(v, Tensor): v = Tensor(v, device=self.device, dtype=self.dtype) if self.requires_grad or v.requires_grad: raise NotImplementedError("setitem with requires_grad is not supported") + if isinstance(indices, (Tensor, list)) or (isinstance(indices, tuple) and any(isinstance(i, (Tensor, list)) for i in indices)): + raise NotImplementedError("Advanced indexing setitem is not currently supported") assign_to = self.realize().__getitem__(indices) # NOTE: contiguous to prevent const folding.