mirror of https://github.com/commaai/tinygrad.git
setitem in-place operator tests (#4577)
* tests and error * rename to in-place * add a note * more comments * more comments * disable folded advanced setitem tests for now
This commit is contained in:
parent
0fa57b8ce9
commit
089eeec271
|
@ -1094,6 +1094,8 @@ class TestIndexing(unittest.TestCase):
|
|||
r[zero]
|
||||
numpy_testing_assert_equal_helper(r, r[...])
|
||||
|
||||
# TODO fancy setitem
|
||||
'''
|
||||
def test_setitem_scalars(self):
|
||||
zero = Tensor(0, dtype=dtypes.int64)
|
||||
|
||||
|
@ -1121,6 +1123,7 @@ class TestIndexing(unittest.TestCase):
|
|||
# TODO: weird inaccuracy Max relative difference: 3.85322971e-08
|
||||
# numpy_testing_assert_equal_helper(9.9, r)
|
||||
np.testing.assert_allclose(9.9, r, rtol=1e-7)
|
||||
'''
|
||||
|
||||
def test_basic_advanced_combined(self):
|
||||
# From the NumPy indexing example
|
||||
|
@ -1518,7 +1521,10 @@ class TestNumpy(unittest.TestCase):
|
|||
def test_broaderrors_indexing(self):
|
||||
a = Tensor.zeros(5, 5)
|
||||
self.assertRaises(IndexError, a.__getitem__, ([0, 1], [0, 1, 2]))
|
||||
# TODO: fancy setitem
|
||||
'''
|
||||
self.assertRaises(IndexError, a.contiguous().__setitem__, ([0, 1], [0, 1, 2]), 0)
|
||||
'''
|
||||
|
||||
# TODO out of bound getitem does not raise error
|
||||
'''
|
||||
|
|
|
@ -42,6 +42,41 @@ class TestSetitem(unittest.TestCase):
|
|||
assert not t.lazydata.st.contiguous
|
||||
with self.assertRaises(AssertionError): t[1] = 5
|
||||
|
||||
def test_setitem_inplace_operator(self):
|
||||
t = Tensor.arange(4).reshape(2, 2).contiguous()
|
||||
t[1] += 2
|
||||
np.testing.assert_allclose(t.numpy(), [[0, 1], [4, 5]])
|
||||
|
||||
t = Tensor.arange(4).reshape(2, 2).contiguous()
|
||||
t[1] -= 1
|
||||
np.testing.assert_allclose(t.numpy(), [[0, 1], [1, 2]])
|
||||
|
||||
t = Tensor.arange(4).reshape(2, 2).contiguous()
|
||||
t[1] *= 2
|
||||
np.testing.assert_allclose(t.numpy(), [[0, 1], [4, 6]])
|
||||
|
||||
# NOTE: have to manually cast setitem target to least_upper_float for div
|
||||
t = Tensor.arange(4, dtype=dtypes.float).reshape(2, 2).contiguous()
|
||||
t[1] /= 2
|
||||
np.testing.assert_allclose(t.numpy(), [[0, 1], [1, 1.5]])
|
||||
|
||||
t = Tensor.arange(4).reshape(2, 2).contiguous()
|
||||
t[1] **= 2
|
||||
np.testing.assert_allclose(t.numpy(), [[0, 1], [4, 9]])
|
||||
|
||||
t = Tensor.arange(4).reshape(2, 2).contiguous()
|
||||
t[1] ^= 5
|
||||
np.testing.assert_allclose(t.numpy(), [[0, 1], [7, 6]])
|
||||
|
||||
@unittest.expectedFailure
|
||||
def test_setitem_consecutive_inplace_operator(self):
|
||||
t = Tensor.arange(4).reshape(2, 2).contiguous()
|
||||
t[1] += 2
|
||||
t = t.contiguous()
|
||||
# TODO: RuntimeError: must be contiguous for assign ShapeTracker(views=(View(shape=(2,), strides=(1,), offset=2, mask=None, contiguous=False),))
|
||||
t[1] -= 1
|
||||
np.testing.assert_allclose(t.numpy(), [[0, 1], [3, 4]])
|
||||
|
||||
# TODO: implement fancy setitem
|
||||
@unittest.expectedFailure
|
||||
def test_fancy_setitem(self):
|
||||
|
|
|
@ -824,6 +824,8 @@ class Tensor:
|
|||
if not isinstance(v, (Tensor, float, int, bool)): raise TypeError(f"can't set a {type(v).__name__} to a Tensor")
|
||||
if not isinstance(v, Tensor): v = Tensor(v, device=self.device, dtype=self.dtype)
|
||||
if self.requires_grad or v.requires_grad: raise NotImplementedError("setitem with requires_grad is not supported")
|
||||
if isinstance(indices, (Tensor, list)) or (isinstance(indices, tuple) and any(isinstance(i, (Tensor, list)) for i in indices)):
|
||||
raise NotImplementedError("Advanced indexing setitem is not currently supported")
|
||||
|
||||
assign_to = self.realize().__getitem__(indices)
|
||||
# NOTE: contiguous to prevent const folding.
|
||||
|
|
Loading…
Reference in New Issue