mirror of https://github.com/commaai/tinygrad.git
add support for retain_graph in backward (#6145)
* add support for retain_graph in backward
* fix: don't accumulate grad on non-leaf tensors
* fix order
* fix: do not delete grad on leaves
* fix linter
* fix: can't exactly match torch behaviour internally
* allow numerical room for test
* refactor
This commit is contained in:
parent
0c5189de25
commit
724e408736
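In short, `backward(retain_graph=True)` keeps each tensor's `_ctx` alive so the same graph can be walked again, and gradients on leaf tensors accumulate across calls. A minimal usage sketch (assuming the `from tinygrad import Tensor` import path; the expected values are worked out by hand, not taken from the diff):

```python
from tinygrad import Tensor

x = Tensor([1.0, 2.0, 3.0], requires_grad=True)
out = (x * 2).sum()

out.backward(retain_graph=True)  # keep the graph: _ctx is not deleted
out.backward()                   # second pass is legal and accumulates into x.grad
print(x.grad.numpy())            # expected: [4. 4. 4.]  (two passes of d(out)/dx = 2)
```

As in the test below, the last call can drop `retain_graph` when no further backward pass is needed.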
@@ -57,6 +57,26 @@ class TestTinygrad(unittest.TestCase):
     for x,y in zip(test_tinygrad(), test_pytorch()):
       np.testing.assert_allclose(x, y, atol=1e-5)
 
+  # A simple test is to check that we can accumulate gradients (run backward twice or more times)
+  # This will only work if retain_graph works.
+  def test_retain_graph(self):
+    x = Tensor(x_init, requires_grad=True)
+    W = Tensor(W_init, requires_grad=True)
+    m = Tensor(m_init)
+    out = x.dot(W).relu()
+    out = out.log_softmax()
+    out = out.mul(m).add(m).sum()
+    out.backward(retain_graph=True)
+    xgrad,wgrad = x.grad.numpy(), W.grad.numpy()
+    out.backward(retain_graph=True)
+    xgrad2,wgrad2 = x.grad.numpy(), W.grad.numpy()
+    out.backward() # no need to retain again since we will not re-run backward
+    xgrad3,wgrad3 = x.grad.numpy(), W.grad.numpy()
+    np.testing.assert_allclose(xgrad3, xgrad * 3., atol=1e-6)
+    np.testing.assert_allclose(wgrad3, wgrad * 3., atol=1e-6)
+    np.testing.assert_allclose(xgrad2, xgrad * 2., atol=1e-6)
+    np.testing.assert_allclose(wgrad2, wgrad * 2., atol=1e-6)
+
   @unittest.expectedFailure
   def test_second_order_backward_pass(self):
     def test_pytorch():
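The commit message notes that torch's behaviour can't be matched exactly internally, but the observable semantics the test checks (leaf grads doubling and tripling across retained backward calls) are the same as PyTorch's. An illustrative PyTorch counterpart, assuming `torch` is installed and using made-up inputs (this is not part of the diff):

```python
import numpy as np
import torch

x = torch.tensor([[1.0, 2.0], [3.0, 4.0]], requires_grad=True)
W = torch.tensor([[1.0, 0.5], [0.5, 1.0]], requires_grad=True)

out = (x @ W).relu().log_softmax(dim=1).sum()
out.backward(retain_graph=True)   # keep the graph for a second pass
g1 = x.grad.clone()
out.backward()                    # legal because the graph was retained
np.testing.assert_allclose(x.grad.numpy(), 2 * g1.numpy(), atol=1e-6)
```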
@@ -749,23 +749,26 @@ class Tensor:
   def _deepwalk(self):
     def _walk(node, visited):
       visited.add(node)
-      if getattr(node, "_ctx", None):
+      # if tensor is not leaf, reset grad
+      if (ctx := getattr(node, "_ctx", None)) is not None and len(ctx.parents) != 0: node.grad = None
+      if ctx:
         for i in node._ctx.parents:
           if i not in visited: yield from _walk(i, visited)
         yield node
     return list(_walk(self, set()))
 
-  def backward(self, gradient:Optional[Tensor]=None) -> Tensor:
+  def backward(self, gradient:Optional[Tensor]=None, retain_graph:bool=False) -> Tensor:
     """
     Propagates the gradient of a tensor backwards through the computation graph.
     If the 'gradient' argument is not provided, the tensor must be a scalar, and the gradient is implicitly set to 1.0.
-
+    If 'retain_graph' is false, the graph used to compute the grads will be freed. Otherwise, it will be kept. Keeping it can increase memory usage.
     ```python exec="true" source="above" session="tensor" result="python"
     t = Tensor([1.0, 2.0, 3.0, 4.0], requires_grad=True)
     t.sum().backward()
     print(t.grad.numpy())
     ```
     """
+    toposorted = self._deepwalk()
     if gradient is None:
       assert self.shape == tuple(), "when no gradient is provided, backward must be called on a scalar tensor"
       # fill in the first grad with one. don't use Tensor.ones because we don't need contiguous
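The walrus-assignment line is what implements "don't accumulate grad on non-leaf tensors": each backward call first walks the graph and clears the grad of every tensor that still carries a `_ctx` with parents, so only leaves accumulate across calls. A small sketch of the resulting behaviour (assuming `from tinygrad import Tensor`; the expected values are worked out by hand):

```python
from tinygrad import Tensor

x = Tensor([2.0], requires_grad=True)   # leaf: no _ctx, grad accumulates
y = x * 3                               # non-leaf: has a _ctx, grad is reset each pass
z = (y * y).sum()                       # z = 9*x^2, so dz/dx = 18*x = 36 and dz/dy = 2*y = 12

z.backward(retain_graph=True)
print(x.grad.numpy(), y.grad.numpy())   # expected: [36.] [12.]
z.backward()
print(x.grad.numpy(), y.grad.numpy())   # expected: [72.] [12.]  (leaf accumulated, non-leaf reset)
```

Computing `toposorted` before `self.grad` is assigned also matters: the walk resets the grad of the (usually non-leaf) tensor backward is called on, so the reset has to happen first; this appears to be the "fix order" item in the commit message.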
@@ -774,8 +777,7 @@ class Tensor:
 
     assert self.shape == gradient.shape, f"grad shape must match tensor shape, {gradient.shape!r} != {self.shape!r}"
     self.grad = gradient
-
-    for t0 in reversed(self._deepwalk()):
+    for t0 in reversed(toposorted):
       if t0.grad is None: raise RuntimeError(f"tensor {t0} has no grad")
       token = _METADATA.set(dataclasses.replace(md, backward=True) if (md := t0._ctx.metadata) is not None else None)
       grads = t0._ctx.backward(t0.grad.lazydata)
@@ -786,7 +788,7 @@ class Tensor:
         if g is not None and t.requires_grad:
           assert g.shape == t.shape, f"grad shape must match tensor shape, {g.shape!r} != {t.shape!r}"
           t.grad = g if t.grad is None else (t.grad + g)
-      del t0._ctx
+      if not retain_graph: del t0._ctx
     return self
 
   # ***** movement low level ops *****
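The freeing itself is just reference dropping: deleting `_ctx` (which holds `parents` and whatever the op saved for backward) leaves `_deepwalk` with nothing to traverse, so the intermediates can be garbage collected. A toy sketch of that design choice, using hypothetical `ToyCtx`/`ToyTensor` stand-ins rather than tinygrad classes:

```python
# Toy stand-ins, not tinygrad code: they only mimic the _ctx/parents linkage
# that _deepwalk relies on.
class ToyCtx:
  def __init__(self, parents): self.parents = parents

class ToyTensor:
  def __init__(self, ctx=None): self._ctx = ctx

def walk(node):
  # mirrors _walk: only nodes that still carry a _ctx are visited and yielded
  if getattr(node, "_ctx", None):
    for p in node._ctx.parents: yield from walk(p)
    yield node

leaf = ToyTensor()                   # a leaf never has a _ctx
out = ToyTensor(ctx=ToyCtx([leaf]))  # an op output keeps its inputs alive via _ctx

print(len(list(walk(out))))          # 1: the op node is reachable, a backward pass could run
del out._ctx                         # what backward() does when retain_graph is False
print(len(list(walk(out))))          # 0: the graph is gone, a second pass has nothing to use
```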