setitem support setting python const (#4111)

This commit is contained in:
chenyu 2024-04-08 11:37:50 -04:00 committed by GitHub
parent f8dc82a8a7
commit dbd39ab78a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 19 additions and 12 deletions

View File

@ -708,7 +708,7 @@ class TestIndexing(unittest.TestCase):
numpy_testing_assert_equal_helper(v[::11], [0])
numpy_testing_assert_equal_helper(v[1:6:2], [1, 3, 5])
# TODO setitem
# TODO setitem with stride
'''
def test_step_assignment(self):
v = Tensor.zeros(4, 4)
@ -733,7 +733,6 @@ class TestIndexing(unittest.TestCase):
numpy_testing_assert_equal_helper(v[boolIndices], Tensor([True]))
numpy_testing_assert_equal_helper(len(w), 2)
# TODO setitem
@unittest.skip("bool indexing not supported")
def test_bool_indices_accumulate(self):
mask = Tensor.zeros(size=(10, ), dtype=dtypes.bool)
@ -892,7 +891,6 @@ class TestIndexing(unittest.TestCase):
r = v[c > 0]
numpy_testing_assert_equal_helper(r.shape, (num_ones, 3))
# TODO setitem
@unittest.skip("bool indexing not supported")
def test_jit_indexing(self):
def fn1(x):
@ -1144,8 +1142,6 @@ class TestIndexing(unittest.TestCase):
self.assertNotEqual(x, unmodified)
'''
# TODO setitem
'''
def test_int_assignment(self):
x = Tensor.arange(0, 4).reshape(2, 2)
x[1] = 5
@ -1154,7 +1150,6 @@ class TestIndexing(unittest.TestCase):
x = Tensor.arange(0, 4).reshape(2, 2)
x[1] = Tensor.arange(5, 7)
numpy_testing_assert_equal_helper(x.numpy().tolist(), [[0, 1], [5, 6]])
'''
# TODO setitem
'''
@ -1524,10 +1519,9 @@ class TestNumpy(unittest.TestCase):
def test_broaderrors_indexing(self):
a = Tensor.zeros(5, 5)
self.assertRaises(IndexError, a.__getitem__, ([0, 1], [0, 1, 2]))
# TODO setitem
# self.assertRaises(IndexError, a.__setitem__, ([0, 1], [0, 1, 2]), 0)
self.assertRaises(IndexError, a.__setitem__, ([0, 1], [0, 1, 2]), 0)
# TODO setitem
# TODO out of bound getitem does not raise error
'''
def test_trivial_fancy_out_of_bounds(self):
a = Tensor.zeros(5)

View File

@ -10,6 +10,17 @@ class TestSetitem(unittest.TestCase):
n[2:4, 3:5] = np.ones((2, 2))
np.testing.assert_allclose(t.numpy(), n)
t = Tensor.zeros(6, 6).contiguous().realize()
t[2:4, 3:5] = 1.0
n = np.zeros((6, 6))
n[2:4, 3:5] = 1.0
np.testing.assert_allclose(t.numpy(), n)
def test_setitem_into_unrealized(self):
t = Tensor.arange(4).reshape(2, 2)
t[1] = 5
np.testing.assert_allclose(t.numpy(), [[0, 1], [5, 5]])
def test_simple_jit_setitem(self):
@TinyJit
def f(t:Tensor, a:Tensor):

View File

@ -510,11 +510,13 @@ class Tensor:
ret = ret.permute(ret_dims[first_dim:first_dim+max_idx_dim] + ret_dims[:first_dim] + ret_dims[first_dim+max_idx_dim:])
return ret
def __setitem__(self, indices, v:Tensor):
def __setitem__(self, indices, v:Union[Tensor, ConstType]):
if isinstance(self.device, str) and self.device.startswith("DISK"): return self.__getitem__(indices).assign(v)
# TODO: support python const v
# TODO: broadcast v to the shape here, refactor for const v and one way broadcast_shape
assign_to = self.__getitem__(indices)
if not isinstance(v, Tensor): v = Tensor(v, self.device, self.dtype)
assign_to = self.realize().__getitem__(indices)
# NOTE: we check that indices is valid first
assert self.lazydata.contiguous(), "setitem target needs to be contiguous"
# NOTE: contiguous to prevent const folding.
return assign_to.assign(v._broadcast_to(broadcast_shape(assign_to.shape, v.shape)).contiguous()).realize()