faster deepwalk

This commit is contained in:
George Hotz 2022-01-15 20:57:57 -08:00
parent 7025c9bbeb
commit 931500a098
1 changed files with 5 additions and 4 deletions

View File

@ -110,13 +110,14 @@ class Tensor:
# ***** toposort and backward pass *****
def deepwalk(self):
def _deepwalk(node, visited, nodes):
visited, nodes = set(), []
def _deepwalk(node):
visited.add(node)
if node._ctx:
[_deepwalk(i, visited, nodes) for i in node._ctx.parents if i not in visited]
[_deepwalk(i) for i in node._ctx.parents if i not in visited]
nodes.append(node)
return nodes
return _deepwalk(self, set(), [])
_deepwalk(self)
return nodes
def backward(self):
assert self.shape == (1,)