From c225e62dd22e7bd42c8258458223770694292cbf Mon Sep 17 00:00:00 2001 From: George Hotz Date: Wed, 9 Dec 2020 02:52:28 -0800 Subject: [PATCH] touchups --- examples/serious_mnist.py | 1 - test/test_gc.py | 2 +- tinygrad/tensor.py | 10 +++------- 3 files changed, 4 insertions(+), 9 deletions(-) diff --git a/examples/serious_mnist.py b/examples/serious_mnist.py index 4fbee4f7..1137f07a 100644 --- a/examples/serious_mnist.py +++ b/examples/serious_mnist.py @@ -44,7 +44,6 @@ class SeriousModel: if __name__ == "__main__": model = SeriousModel() params = get_parameters(model) - print(len(params)) if GPU: [x.cuda_() for x in params] optimizer = optim.Adam(params, lr=0.001) diff --git a/test/test_gc.py b/test/test_gc.py index 56076bac..7aa90921 100644 --- a/test/test_gc.py +++ b/test/test_gc.py @@ -4,7 +4,7 @@ import unittest from tinygrad.tensor import Tensor, GPU def tensors_allocated(): - return sum([isinstance(x, Tensor) for x in gc.get_objects()]) + return sum([isinstance(x, Tensor) for x in gc.get_objects()]) class TestGC(unittest.TestCase): gpu = False diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index df73b22e..4c33341b 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -122,9 +122,7 @@ class Tensor: # ***** toposort and backward pass ***** - def deepwalk(self, visited=None, nodes=None): - if visited == None and nodes == None: - visited, nodes = set(), [] + def deepwalk(self, visited: set, nodes: list): visited.add(self) if self._ctx: [i.deepwalk(visited, nodes) for i in self._ctx.parents if i not in visited] @@ -132,15 +130,13 @@ class Tensor: return nodes def backward(self): - if self._ctx is None: - return + assert self.shape == (1,) # fill in the first grad with one # this is "implicit gradient creation" - assert self.shape == (1,) self.grad = Tensor(np.ones(self.shape, dtype=self.dtype), gpu=self.gpu, requires_grad=False) - for t0 in reversed(self.deepwalk()): + for t0 in reversed(self.deepwalk(set(), [])): assert (t0.grad is not None) with ProfileOp(t0._ctx.__class__.__name__, [t0.grad], backward=True): grads = t0._ctx.backward(t0._ctx, t0.grad.data)