mirror of https://github.com/commaai/tinygrad.git
raise error if setitem tensors have requires_grad (#4575)
* raise error if setitem tensors have requires_grad working on supporting this, first properly raises error * NotImplementedError
This commit is contained in:
parent
f7d08bd454
commit
0fa57b8ce9
|
@ -81,5 +81,23 @@ class TestSetitem(unittest.TestCase):
|
|||
np.testing.assert_allclose(t.numpy(), n)
|
||||
np.testing.assert_allclose(t.numpy(), [[1,1,1,1,1,1],[2,2,2,2,2,2],[3,3,3,3,3,3],[4,4,4,4,4,4],[5,5,5,5,5,5],[6,6,6,6,6,6]])
|
||||
|
||||
class TestWithGrad(unittest.TestCase):
|
||||
def test_no_requires_grad_works(self):
|
||||
z = Tensor.rand(8, 8)
|
||||
x = Tensor.rand(8)
|
||||
z[:3] = x
|
||||
|
||||
def test_set_into_requires_grad(self):
|
||||
z = Tensor.rand(8, 8, requires_grad=True)
|
||||
x = Tensor.rand(8)
|
||||
with self.assertRaises(NotImplementedError):
|
||||
z[:3] = x
|
||||
|
||||
def test_set_with_requires_grad(self):
|
||||
z = Tensor.rand(8, 8)
|
||||
x = Tensor.rand(8, requires_grad=True)
|
||||
with self.assertRaises(NotImplementedError):
|
||||
z[:3] = x
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
|
@ -823,6 +823,8 @@ class Tensor:
|
|||
assert all(lb.st.contiguous for lb in self.lazydata.lbs), "setitem target needs to be contiguous"
|
||||
if not isinstance(v, (Tensor, float, int, bool)): raise TypeError(f"can't set a {type(v).__name__} to a Tensor")
|
||||
if not isinstance(v, Tensor): v = Tensor(v, device=self.device, dtype=self.dtype)
|
||||
if self.requires_grad or v.requires_grad: raise NotImplementedError("setitem with requires_grad is not supported")
|
||||
|
||||
assign_to = self.realize().__getitem__(indices)
|
||||
# NOTE: contiguous to prevent const folding.
|
||||
v = v.cast(assign_to.dtype)._broadcast_to(broadcast_shape(assign_to.shape, v.shape)).contiguous()
|
||||
|
|
Loading…
Reference in New Issue