add support for retain_graph in backward (#6145)

* add support for retain_graph in backward

* fix: don't accumulate grad on non-leaf tensors

* fix order

* fix: do not delete grad on leaves

* fix linter

* fix: can't exactly match torch behaviour internally

* allow numerical room for test

* refactor
David González Martínez 2024-08-19 01:08:31 +02:00 committed by GitHub
parent 0c5189de25
commit 724e408736
2 changed files with 28 additions and 6 deletions
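In short, the new `retain_graph` flag lets `backward()` be called more than once on the same output: the saved contexts are kept alive instead of being freed, and leaf gradients accumulate across calls. A minimal usage sketch of the API added here (the tensor values and `atol` are illustrative, not taken from the diff):

```python
import numpy as np
from tinygrad import Tensor

x = Tensor([[1.0, 2.0], [3.0, 4.0]], requires_grad=True)
loss = (x * x).sum()              # scalar output, so no explicit gradient is needed

loss.backward(retain_graph=True)  # keep the graph so backward can run again
g1 = x.grad.numpy()

loss.backward()                   # final pass; the default frees the graph afterwards
np.testing.assert_allclose(x.grad.numpy(), 2 * g1, atol=1e-6)  # leaf grads accumulate
```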

@@ -57,6 +57,26 @@ class TestTinygrad(unittest.TestCase):
     for x,y in zip(test_tinygrad(), test_pytorch()):
       np.testing.assert_allclose(x, y, atol=1e-5)
 
+  # A simple test is to check that we can accumulate gradients (run backward twice or more times)
+  # This will only work if retain_graph works.
+  def test_retain_graph(self):
+    x = Tensor(x_init, requires_grad=True)
+    W = Tensor(W_init, requires_grad=True)
+    m = Tensor(m_init)
+    out = x.dot(W).relu()
+    out = out.log_softmax()
+    out = out.mul(m).add(m).sum()
+    out.backward(retain_graph=True)
+    xgrad,wgrad = x.grad.numpy(), W.grad.numpy()
+    out.backward(retain_graph=True)
+    xgrad2,wgrad2 = x.grad.numpy(), W.grad.numpy()
+    out.backward()  # no need to retain again since we will not re-run backward
+    xgrad3,wgrad3 = x.grad.numpy(), W.grad.numpy()
+    np.testing.assert_allclose(xgrad3, xgrad * 3., atol=1e-6)
+    np.testing.assert_allclose(wgrad3, wgrad * 3., atol=1e-6)
+    np.testing.assert_allclose(xgrad2, xgrad * 2., atol=1e-6)
+    np.testing.assert_allclose(wgrad2, wgrad * 2., atol=1e-6)
+
   @unittest.expectedFailure
   def test_second_order_backward_pass(self):
     def test_pytorch():
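
The commit note about not exactly matching torch behaviour refers to internals; at the user-facing level the test above mirrors what PyTorch does, where a repeated backward also requires `retain_graph=True` and leaf `.grad` buffers accumulate. A rough PyTorch counterpart for comparison (not part of the diff):

```python
import torch

x = torch.tensor([[1.0, 2.0], [3.0, 4.0]], requires_grad=True)
loss = (x * x).sum()

loss.backward(retain_graph=True)  # a second backward without this would raise
g1 = x.grad.clone()

loss.backward()                   # no retain_graph needed on the final pass
assert torch.allclose(x.grad, 2 * g1)  # .grad accumulates across calls
```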


@@ -749,23 +749,26 @@ class Tensor:
   def _deepwalk(self):
     def _walk(node, visited):
       visited.add(node)
-      if getattr(node, "_ctx", None):
+      # if tensor is not leaf, reset grad
+      if (ctx := getattr(node, "_ctx", None)) is not None and len(ctx.parents) != 0: node.grad = None
+      if ctx:
         for i in node._ctx.parents:
           if i not in visited: yield from _walk(i, visited)
       yield node
     return list(_walk(self, set()))
 
-  def backward(self, gradient:Optional[Tensor]=None) -> Tensor:
+  def backward(self, gradient:Optional[Tensor]=None, retain_graph:bool=False) -> Tensor:
     """
     Propagates the gradient of a tensor backwards through the computation graph.
     If the 'gradient' argument is not provided, the tensor must be a scalar, and the gradient is implicitly set to 1.0.
+    If 'retain_graph' is false, the graph used to compute the grads will be freed. Otherwise, it will be kept. Keeping it can increase memory usage.
     ```python exec="true" source="above" session="tensor" result="python"
     t = Tensor([1.0, 2.0, 3.0, 4.0], requires_grad=True)
     t.sum().backward()
     print(t.grad.numpy())
     ```
     """
+    toposorted = self._deepwalk()
     if gradient is None:
       assert self.shape == tuple(), "when no gradient is provided, backward must be called on a scalar tensor"
       # fill in the first grad with one. don't use Tensor.ones because we don't need contiguous
@@ -774,8 +777,7 @@ class Tensor:
     assert self.shape == gradient.shape, f"grad shape must match tensor shape, {gradient.shape!r} != {self.shape!r}"
     self.grad = gradient
-    for t0 in reversed(self._deepwalk()):
+    for t0 in reversed(toposorted):
       if t0.grad is None: raise RuntimeError(f"tensor {t0} has no grad")
       token = _METADATA.set(dataclasses.replace(md, backward=True) if (md := t0._ctx.metadata) is not None else None)
       grads = t0._ctx.backward(t0.grad.lazydata)
@@ -786,7 +788,7 @@ class Tensor:
         if g is not None and t.requires_grad:
           assert g.shape == t.shape, f"grad shape must match tensor shape, {g.shape!r} != {t.shape!r}"
           t.grad = g if t.grad is None else (t.grad + g)
-      del t0._ctx
+      if not retain_graph: del t0._ctx
     return self
 
   # ***** movement low level ops *****
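
Two details of the implementation are worth noting. The walk now clears `.grad` on non-leaf tensors at the start of each traversal, so only leaf gradients accumulate across repeated passes, and the saved contexts are deleted only when `retain_graph` is false. With the default, a second `backward()` on the same output is therefore expected to fail, since the graph it would need to walk has already been freed. A small sketch of that expectation (the exact exception type is an assumption, not something the diff specifies):

```python
from tinygrad import Tensor

x = Tensor([1.0, 2.0, 3.0], requires_grad=True)
out = (x * 2).sum()

out.backward()            # default retain_graph=False: _ctx is dropped after the pass
try:
  out.backward()          # the freed graph can no longer be walked back to the leaves
except Exception as e:    # hedge: exact error type/message is an assumption
  print("second backward failed:", type(e).__name__)
```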