diff --git a/test/test_tensor.py b/test/test_tensor.py
index af27b0aa..1af4dad2 100644
--- a/test/test_tensor.py
+++ b/test/test_tensor.py
@@ -57,6 +57,26 @@ class TestTinygrad(unittest.TestCase):
     for x,y in zip(test_tinygrad(), test_pytorch()):
       np.testing.assert_allclose(x, y, atol=1e-5)
 
+  # A simple test is to check that we can accumulate gradients (run backward twice or more times)
+  # This will only work if retain_graph works.
+  def test_retain_graph(self):
+    x = Tensor(x_init, requires_grad=True)
+    W = Tensor(W_init, requires_grad=True)
+    m = Tensor(m_init)
+    out = x.dot(W).relu()
+    out = out.log_softmax()
+    out = out.mul(m).add(m).sum()
+    out.backward(retain_graph=True)
+    xgrad,wgrad = x.grad.numpy(), W.grad.numpy()
+    out.backward(retain_graph=True)
+    xgrad2,wgrad2 = x.grad.numpy(), W.grad.numpy()
+    out.backward() # no need to retain again since we will not re-run backward
+    xgrad3,wgrad3 = x.grad.numpy(), W.grad.numpy()
+    np.testing.assert_allclose(xgrad3, xgrad * 3., atol=1e-6)
+    np.testing.assert_allclose(wgrad3, wgrad * 3., atol=1e-6)
+    np.testing.assert_allclose(xgrad2, xgrad * 2., atol=1e-6)
+    np.testing.assert_allclose(wgrad2, wgrad * 2., atol=1e-6)
+
   @unittest.expectedFailure
   def test_second_order_backward_pass(self):
     def test_pytorch():
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index 51ab1889..6b6e1f44 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -749,23 +749,26 @@ class Tensor:
   def _deepwalk(self):
     def _walk(node, visited):
       visited.add(node)
-      if getattr(node, "_ctx", None):
+      # if tensor is not leaf, reset grad
+      if (ctx := getattr(node, "_ctx", None)) is not None and len(ctx.parents) != 0: node.grad = None
+      if ctx:
         for i in node._ctx.parents:
           if i not in visited: yield from _walk(i, visited)
       yield node
     return list(_walk(self, set()))
 
-  def backward(self, gradient:Optional[Tensor]=None) -> Tensor:
+  def backward(self, gradient:Optional[Tensor]=None, retain_graph:bool=False) -> Tensor:
     """
     Propagates the gradient of a tensor backwards through the computation graph.
     If the 'gradient' argument is not provided, the tensor must be a scalar, and the gradient is implicitly set to 1.0.
-
+    If 'retain_graph' is false, the graph used to compute the grads will be freed. Otherwise, it will be kept. Keeping it can increase memory usage.
     ```python exec="true" source="above" session="tensor" result="python"
     t = Tensor([1.0, 2.0, 3.0, 4.0], requires_grad=True)
     t.sum().backward()
     print(t.grad.numpy())
     ```
     """
+    toposorted = self._deepwalk()
     if gradient is None:
       assert self.shape == tuple(), "when no gradient is provided, backward must be called on a scalar tensor"
       # fill in the first grad with one. don't use Tensor.ones because we don't need contiguous
@@ -774,8 +777,7 @@ class Tensor:
 
     assert self.shape == gradient.shape, f"grad shape must match tensor shape, {gradient.shape!r} != {self.shape!r}"
     self.grad = gradient
-
-    for t0 in reversed(self._deepwalk()):
+    for t0 in reversed(toposorted):
       if t0.grad is None: raise RuntimeError(f"tensor {t0} has no grad")
       token = _METADATA.set(dataclasses.replace(md, backward=True) if (md := t0._ctx.metadata) is not None else None)
       grads = t0._ctx.backward(t0.grad.lazydata)
@@ -786,7 +788,7 @@ class Tensor:
         if g is not None and t.requires_grad:
           assert g.shape == t.shape, f"grad shape must match tensor shape, {g.shape!r} != {t.shape!r}"
           t.grad = g if t.grad is None else (t.grad + g)
-      del t0._ctx
+      if not retain_graph: del t0._ctx
     return self
 
   # ***** movement low level ops *****
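
Usage sketch (illustrative, not part of the diff): mirroring `test_retain_graph` above, a second `backward()` call on the same output only works if the first call kept the graph alive with `retain_graph=True`, since each pass otherwise deletes `_ctx` as it walks the graph; gradients then accumulate into `.grad`. The shapes below are arbitrary stand-ins.

```python
import numpy as np
from tinygrad import Tensor

x = Tensor(np.random.randn(1, 3).astype(np.float32), requires_grad=True)
W = Tensor(np.random.randn(3, 3).astype(np.float32), requires_grad=True)
out = x.dot(W).relu().sum()

out.backward(retain_graph=True)   # keep _ctx so the graph can be walked again
g1 = x.grad.numpy().copy()
out.backward()                    # final pass: the graph may be freed now
np.testing.assert_allclose(x.grad.numpy(), g1 * 2., atol=1e-6)  # grads accumulated
```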