Topological sort, zero_grads (#119)

* Topological sort, zero_grads

* Bug fix, add test

* Add zero_grads

* Put deepwalk function in backward

* Move zero_grad to optim

* Fix gradcheck hack

Co-authored-by: holonomicjl <58403584+holonomicjl@users.noreply.github.com>
adamritter 2020-11-16 04:25:29 +00:00 committed by GitHub
parent a35425189d
commit 5ea3d76dfb
5 changed files with 32 additions and 15 deletions

View File

@@ -56,6 +56,7 @@ class TinyConvNet:
 def train(model, optim, steps, BS=128, gpu=False):
   losses, accuracies = [], []
   for i in (t := trange(steps, disable=os.getenv('CI') is not None)):
+    optim.zero_grad()
     samp = np.random.randint(0, X_train.shape[0], size=(BS))
     x = Tensor(X_train[samp].reshape((-1, 28*28)).astype(np.float32), gpu=gpu)
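
Note: with backward() below now accumulating into existing gradients instead of overwriting them, anything left over from the previous iteration would bleed into the next step, hence the optim.zero_grad() call at the top of the loop. A quick illustration of the same effect in PyTorch (already a test dependency here; this snippet is a sketch, not part of the diff):

    import torch

    # without clearing, the second backward() adds onto the first
    w = torch.tensor(1.0, requires_grad=True)
    (2 * w).backward()
    (2 * w).backward()
    assert w.grad.item() == 4.0

    # what optim.zero_grad() now does at the top of every step
    w.grad = None
    (2 * w).backward()
    assert w.grad.item() == 2.0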

View File

@@ -63,6 +63,8 @@ class TestOps(unittest.TestCase):
     helper_test_op([(45,65)], lambda x: torch.nn.LogSoftmax(dim=1)(x), Tensor.logsoftmax, atol=1e-7, grad_atol=1e-7, gpu=self.gpu)
   def test_tanh(self):
     helper_test_op([(45,65)], lambda x: x.tanh(), Tensor.tanh, atol=1e-6, grad_atol=1e-6)
+  def test_topo_sort(self):
+    helper_test_op([(45,65)], lambda x: (x+x)*x, lambda x: x.add(x).mul(x), atol=1e-6, grad_atol=1e-6)
   def test_broadcast_full(self):
     for torch_op, tinygrad_op in [(torch.add, Tensor.add), (torch.sub, Tensor.sub), (torch.mul, Tensor.mul),
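
The new test uses an expression in which x feeds the graph along two paths, so its gradient must be accumulated rather than overwritten: d/dx[(x+x)*x] = d/dx[2x^2] = 4x. A quick cross-check in PyTorch (a sketch, not part of the diff):

    import torch

    x = torch.randn(3, requires_grad=True)
    ((x + x) * x).sum().backward()
    assert torch.allclose(x.grad, 4 * x.detach())  # both paths contribute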

View File

@@ -11,6 +11,8 @@ def jacobian(func, input):
   J = np.zeros((jo,ji), dtype=np.float32)
   for o in range(jo):
+    input.grad = None
+    output=func(input)
     # tinygrad doesn't support slicing, tiny-hack to select
     # the needed scalar an backpropagate only through it
     o_scalar = Tensor(mask_like(output.data, o, 1.)).mul(output).sum()
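
For context, mask_like acts as a one-hot selector over the flattened output, so multiplying by it and summing collapses the output to a single scalar that can be backpropagated through. A minimal sketch of that assumed behaviour (the real helper in the repo may differ in detail):

    import numpy as np

    # assumed shape of the helper: one-hot over the flattened array,
    # reshaped back to the array's shape
    def mask_like(like, idx, value=1.0):
      mask = np.zeros_like(like).reshape(-1)
      mask[idx] = value
      return mask.reshape(like.shape)

    out = np.arange(6, dtype=np.float32).reshape(2, 3)
    assert (mask_like(out, 4) * out).sum() == out.reshape(-1)[4]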

View File

@@ -7,6 +7,10 @@ class Optimizer:
   def __init__(self, params):
     self.params = params
+  def zero_grad(self):
+    for param in self.params:
+      param.grad = None # PyTorch defaults to set to 0
 class SGD(Optimizer):
   def __init__(self, params, lr=0.001):
     super(SGD, self).__init__(params)
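
Resetting grad to None rather than zeroing a buffer works because the backward pass recreates the gradient on first write and accumulates afterwards, so zero_grad() never needs to allocate anything. A standalone sketch of the pattern (not tinygrad code; Param and accumulate are made-up stand-ins):

    import numpy as np

    class Param:
      def __init__(self, data):
        self.data, self.grad = data, None

    def accumulate(param, g):
      # mirrors `t.grad = Tensor(g) if t.grad is None else (t.grad + Tensor(g))`
      param.grad = g if param.grad is None else param.grad + g

    p = Param(np.zeros(3))
    for g in (np.ones(3), 2 * np.ones(3)):  # two paths contribute to p
      accumulate(p, g)
    assert (p.grad == 3).all()
    p.grad = None                           # what Optimizer.zero_grad() does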

View File

@@ -109,7 +109,6 @@
     return Tensor(np.eye(dim).astype(np.float32))
   def backward(self, allow_fill=True):
-    #print("running backward on", self)
     if self._ctx is None:
       return
@@ -119,21 +118,30 @@
       assert self.data.shape == (1,)
       self.grad = Tensor(np.ones(self.data.shape, dtype=self.data.dtype), gpu=self.gpu)
     assert(self.grad is not None)
+    visited, nodes = set(), []
+    def deepwalk(node):
+      visited.add(self)
+      if node._ctx:
+        for i in node._ctx.parents:
+          if i not in visited:
+            deepwalk(i)
+        nodes.append(node)
+    deepwalk(self)
-    with ProfileOp(self._ctx.__class__.__name__, [self.grad], backward=True):
-      grads = self._ctx.backward(self._ctx, self.grad.data)
-    if len(self._ctx.parents) == 1:
-      grads = [grads]
-    for t,g in zip(self._ctx.parents, grads):
-      if g is None:
-        continue
-      if g.shape != t.data.shape:
-        print("grad shape must match tensor shape in %r, %r != %r" %
-          (self._ctx, g.shape, t.data.shape))
-        assert(False)
-      t.grad = Tensor(g)
-      t.backward(False)
+    for t0 in reversed(nodes):
+      assert (t0.grad is not None)
+      with ProfileOp(t0._ctx.__class__.__name__, [t0.grad], backward=True):
+        grads = t0._ctx.backward(t0._ctx, t0.grad.data)
+      if len(t0._ctx.parents) == 1:
+        grads = [grads]
+      for t,g in zip(t0._ctx.parents, grads):
+        if g is None:
+          continue
+        if g.shape != t.data.shape:
+          print("grad shape must match tensor shape in %r, %r != %r" %
+            (self._ctx, g.shape, t.data.shape))
+          assert(False)
+        t.grad = Tensor(g) if t.grad is None else (t.grad + Tensor(g))
   # ***** tinygrad supports CPU and GPU *****
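
In the hunk above, deepwalk appends a node only after all of its parents have been walked, producing a topological order of the tensors that have a creating op; the pass then visits that list once in reverse, accumulating gradients at shared inputs instead of recursing through backward(). A standalone sketch of the technique (not tinygrad code):

    # minimal scalar autograd with a DFS topological sort
    class Node:
      def __init__(self, value, parents=(), backward_fn=None):
        self.value, self.parents, self.backward_fn = value, parents, backward_fn
        self.grad = None

    def add(a, b):
      return Node(a.value + b.value, (a, b), lambda g: (g, g))

    def mul(a, b):
      return Node(a.value * b.value, (a, b), lambda g: (g * b.value, g * a.value))

    def backward(root):
      visited, nodes = set(), []
      def deepwalk(node):          # append a node after all of its parents
        visited.add(node)
        if node.backward_fn:
          for p in node.parents:
            if p not in visited:
              deepwalk(p)
          nodes.append(node)
      deepwalk(root)
      root.grad = 1.0
      for n in reversed(nodes):    # reverse topological order
        for p, g in zip(n.parents, n.backward_fn(n.grad)):
          p.grad = g if p.grad is None else p.grad + g  # accumulate at shared inputs

    x = Node(3.0)
    y = mul(add(x, x), x)          # same expression as test_topo_sort
    backward(y)
    assert x.grad == 12.0          # d/dx 2*x**2 = 4x = 12 at x = 3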