mirror of https://github.com/commaai/tinygrad.git
setitem support setting python const (#4111)
This commit is contained in:
parent
f8dc82a8a7
commit
dbd39ab78a
|
@ -708,7 +708,7 @@ class TestIndexing(unittest.TestCase):
|
|||
numpy_testing_assert_equal_helper(v[::11], [0])
|
||||
numpy_testing_assert_equal_helper(v[1:6:2], [1, 3, 5])
|
||||
|
||||
# TODO setitem
|
||||
# TODO setitem with stride
|
||||
'''
|
||||
def test_step_assignment(self):
|
||||
v = Tensor.zeros(4, 4)
|
||||
|
@ -733,7 +733,6 @@ class TestIndexing(unittest.TestCase):
|
|||
numpy_testing_assert_equal_helper(v[boolIndices], Tensor([True]))
|
||||
numpy_testing_assert_equal_helper(len(w), 2)
|
||||
|
||||
# TODO setitem
|
||||
@unittest.skip("bool indexing not supported")
|
||||
def test_bool_indices_accumulate(self):
|
||||
mask = Tensor.zeros(size=(10, ), dtype=dtypes.bool)
|
||||
|
@ -892,7 +891,6 @@ class TestIndexing(unittest.TestCase):
|
|||
r = v[c > 0]
|
||||
numpy_testing_assert_equal_helper(r.shape, (num_ones, 3))
|
||||
|
||||
# TODO setitem
|
||||
@unittest.skip("bool indexing not supported")
|
||||
def test_jit_indexing(self):
|
||||
def fn1(x):
|
||||
|
@ -1144,8 +1142,6 @@ class TestIndexing(unittest.TestCase):
|
|||
self.assertNotEqual(x, unmodified)
|
||||
'''
|
||||
|
||||
# TODO setitem
|
||||
'''
|
||||
def test_int_assignment(self):
|
||||
x = Tensor.arange(0, 4).reshape(2, 2)
|
||||
x[1] = 5
|
||||
|
@ -1154,7 +1150,6 @@ class TestIndexing(unittest.TestCase):
|
|||
x = Tensor.arange(0, 4).reshape(2, 2)
|
||||
x[1] = Tensor.arange(5, 7)
|
||||
numpy_testing_assert_equal_helper(x.numpy().tolist(), [[0, 1], [5, 6]])
|
||||
'''
|
||||
|
||||
# TODO setitem
|
||||
'''
|
||||
|
@ -1524,10 +1519,9 @@ class TestNumpy(unittest.TestCase):
|
|||
def test_broaderrors_indexing(self):
|
||||
a = Tensor.zeros(5, 5)
|
||||
self.assertRaises(IndexError, a.__getitem__, ([0, 1], [0, 1, 2]))
|
||||
# TODO setitem
|
||||
# self.assertRaises(IndexError, a.__setitem__, ([0, 1], [0, 1, 2]), 0)
|
||||
self.assertRaises(IndexError, a.__setitem__, ([0, 1], [0, 1, 2]), 0)
|
||||
|
||||
# TODO setitem
|
||||
# TODO out of bound getitem does not raise error
|
||||
'''
|
||||
def test_trivial_fancy_out_of_bounds(self):
|
||||
a = Tensor.zeros(5)
|
||||
|
|
|
@ -10,6 +10,17 @@ class TestSetitem(unittest.TestCase):
|
|||
n[2:4, 3:5] = np.ones((2, 2))
|
||||
np.testing.assert_allclose(t.numpy(), n)
|
||||
|
||||
t = Tensor.zeros(6, 6).contiguous().realize()
|
||||
t[2:4, 3:5] = 1.0
|
||||
n = np.zeros((6, 6))
|
||||
n[2:4, 3:5] = 1.0
|
||||
np.testing.assert_allclose(t.numpy(), n)
|
||||
|
||||
def test_setitem_into_unrealized(self):
|
||||
t = Tensor.arange(4).reshape(2, 2)
|
||||
t[1] = 5
|
||||
np.testing.assert_allclose(t.numpy(), [[0, 1], [5, 5]])
|
||||
|
||||
def test_simple_jit_setitem(self):
|
||||
@TinyJit
|
||||
def f(t:Tensor, a:Tensor):
|
||||
|
|
|
@ -510,11 +510,13 @@ class Tensor:
|
|||
ret = ret.permute(ret_dims[first_dim:first_dim+max_idx_dim] + ret_dims[:first_dim] + ret_dims[first_dim+max_idx_dim:])
|
||||
return ret
|
||||
|
||||
def __setitem__(self, indices, v:Tensor):
|
||||
def __setitem__(self, indices, v:Union[Tensor, ConstType]):
|
||||
if isinstance(self.device, str) and self.device.startswith("DISK"): return self.__getitem__(indices).assign(v)
|
||||
# TODO: support python const v
|
||||
# TODO: broadcast v to the shape here, refactor for const v and one way broadcast_shape
|
||||
assign_to = self.__getitem__(indices)
|
||||
if not isinstance(v, Tensor): v = Tensor(v, self.device, self.dtype)
|
||||
assign_to = self.realize().__getitem__(indices)
|
||||
# NOTE: we check that indices is valid first
|
||||
assert self.lazydata.contiguous(), "setitem target needs to be contiguous"
|
||||
# NOTE: contiguous to prevent const folding.
|
||||
return assign_to.assign(v._broadcast_to(broadcast_shape(assign_to.shape, v.shape)).contiguous()).realize()
|
||||
|
||||
|
|
Loading…
Reference in New Issue