From 5ea3d76dfb54585dc0ecee7ec1ad567c165c922c Mon Sep 17 00:00:00 2001 From: adamritter <58403584+adamritter@users.noreply.github.com> Date: Mon, 16 Nov 2020 04:25:29 +0000 Subject: [PATCH] Topological sort, zero_grads (#119) * Topological sort, zero_grads * Bug fix, add test * Add zero_grads * Put deepwalk function in backward * Move zero_grad to optim * Fix gradcheck hack Co-authored-by: holonomicjl <58403584+holonomicjl@users.noreply.github.com> --- test/test_mnist.py | 1 + test/test_ops.py | 2 ++ tinygrad/gradcheck.py | 2 ++ tinygrad/optim.py | 4 ++++ tinygrad/tensor.py | 38 +++++++++++++++++++++++--------------- 5 files changed, 32 insertions(+), 15 deletions(-) diff --git a/test/test_mnist.py b/test/test_mnist.py index 2eef2df1..5c0bd6b4 100644 --- a/test/test_mnist.py +++ b/test/test_mnist.py @@ -56,6 +56,7 @@ class TinyConvNet: def train(model, optim, steps, BS=128, gpu=False): losses, accuracies = [], [] for i in (t := trange(steps, disable=os.getenv('CI') is not None)): + optim.zero_grad() samp = np.random.randint(0, X_train.shape[0], size=(BS)) x = Tensor(X_train[samp].reshape((-1, 28*28)).astype(np.float32), gpu=gpu) diff --git a/test/test_ops.py b/test/test_ops.py index cf7b0f53..0ce6ac84 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -63,6 +63,8 @@ class TestOps(unittest.TestCase): helper_test_op([(45,65)], lambda x: torch.nn.LogSoftmax(dim=1)(x), Tensor.logsoftmax, atol=1e-7, grad_atol=1e-7, gpu=self.gpu) def test_tanh(self): helper_test_op([(45,65)], lambda x: x.tanh(), Tensor.tanh, atol=1e-6, grad_atol=1e-6) + def test_topo_sort(self): + helper_test_op([(45,65)], lambda x: (x+x)*x, lambda x: x.add(x).mul(x), atol=1e-6, grad_atol=1e-6) def test_broadcast_full(self): for torch_op, tinygrad_op in [(torch.add, Tensor.add), (torch.sub, Tensor.sub), (torch.mul, Tensor.mul), diff --git a/tinygrad/gradcheck.py b/tinygrad/gradcheck.py index 71aa9739..e56873c5 100644 --- a/tinygrad/gradcheck.py +++ b/tinygrad/gradcheck.py @@ -11,6 +11,8 @@ def jacobian(func, input): J = np.zeros((jo,ji), dtype=np.float32) for o in range(jo): + input.grad = None + output=func(input) # tinygrad doesn't support slicing, tiny-hack to select # the needed scalar an backpropagate only through it o_scalar = Tensor(mask_like(output.data, o, 1.)).mul(output).sum() diff --git a/tinygrad/optim.py b/tinygrad/optim.py index 99bd6534..a78f860d 100644 --- a/tinygrad/optim.py +++ b/tinygrad/optim.py @@ -7,6 +7,10 @@ class Optimizer: def __init__(self, params): self.params = params + def zero_grad(self): + for param in self.params: + param.grad = None # PyTorch defaults to set to 0 + class SGD(Optimizer): def __init__(self, params, lr=0.001): super(SGD, self).__init__(params) diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index e77a5586..3cc741d4 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -109,7 +109,6 @@ class Tensor: return Tensor(np.eye(dim).astype(np.float32)) def backward(self, allow_fill=True): - #print("running backward on", self) if self._ctx is None: return @@ -119,21 +118,30 @@ class Tensor: assert self.data.shape == (1,) self.grad = Tensor(np.ones(self.data.shape, dtype=self.data.dtype), gpu=self.gpu) - assert(self.grad is not None) + visited, nodes = set(), [] + def deepwalk(node): + visited.add(self) + if node._ctx: + for i in node._ctx.parents: + if i not in visited: + deepwalk(i) + nodes.append(node) + deepwalk(self) - with ProfileOp(self._ctx.__class__.__name__, [self.grad], backward=True): - grads = self._ctx.backward(self._ctx, self.grad.data) - if len(self._ctx.parents) == 1: - grads = [grads] - for t,g in zip(self._ctx.parents, grads): - if g is None: - continue - if g.shape != t.data.shape: - print("grad shape must match tensor shape in %r, %r != %r" % - (self._ctx, g.shape, t.data.shape)) - assert(False) - t.grad = Tensor(g) - t.backward(False) + for t0 in reversed(nodes): + assert (t0.grad is not None) + with ProfileOp(t0._ctx.__class__.__name__, [t0.grad], backward=True): + grads = t0._ctx.backward(t0._ctx, t0.grad.data) + if len(t0._ctx.parents) == 1: + grads = [grads] + for t,g in zip(t0._ctx.parents, grads): + if g is None: + continue + if g.shape != t.data.shape: + print("grad shape must match tensor shape in %r, %r != %r" % + (self._ctx, g.shape, t.data.shape)) + assert(False) + t.grad = Tensor(g) if t.grad is None else (t.grad + Tensor(g)) # ***** tinygrad supports CPU and GPU *****