raise error if setitem tensors have requires_grad (#4575)

* raise error if setitem tensors have requires_grad

working on supporting this, first properly raises error

* NotImplementedError
This commit is contained in:
chenyu 2024-05-13 18:56:47 -04:00 committed by GitHub
parent f7d08bd454
commit 0fa57b8ce9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 20 additions and 0 deletions

View File

@ -81,5 +81,23 @@ class TestSetitem(unittest.TestCase):
np.testing.assert_allclose(t.numpy(), n)
np.testing.assert_allclose(t.numpy(), [[1,1,1,1,1,1],[2,2,2,2,2,2],[3,3,3,3,3,3],[4,4,4,4,4,4],[5,5,5,5,5,5],[6,6,6,6,6,6]])
class TestWithGrad(unittest.TestCase):
def test_no_requires_grad_works(self):
z = Tensor.rand(8, 8)
x = Tensor.rand(8)
z[:3] = x
def test_set_into_requires_grad(self):
z = Tensor.rand(8, 8, requires_grad=True)
x = Tensor.rand(8)
with self.assertRaises(NotImplementedError):
z[:3] = x
def test_set_with_requires_grad(self):
z = Tensor.rand(8, 8)
x = Tensor.rand(8, requires_grad=True)
with self.assertRaises(NotImplementedError):
z[:3] = x
if __name__ == '__main__':
unittest.main()

View File

@ -823,6 +823,8 @@ class Tensor:
assert all(lb.st.contiguous for lb in self.lazydata.lbs), "setitem target needs to be contiguous"
if not isinstance(v, (Tensor, float, int, bool)): raise TypeError(f"can't set a {type(v).__name__} to a Tensor")
if not isinstance(v, Tensor): v = Tensor(v, device=self.device, dtype=self.dtype)
if self.requires_grad or v.requires_grad: raise NotImplementedError("setitem with requires_grad is not supported")
assign_to = self.realize().__getitem__(indices)
# NOTE: contiguous to prevent const folding.
v = v.cast(assign_to.dtype)._broadcast_to(broadcast_shape(assign_to.shape, v.shape)).contiguous()