mirror of https://github.com/commaai/tinygrad.git
prune graph
This commit is contained in:
parent
eda6f071b2
commit
f76d41812b
|
@ -29,6 +29,7 @@ def train(model, X_train, Y_train, optim, steps, BS=128, lossfn=sparse_categoric
|
|||
loss = lossfn(out, y)
|
||||
optim.zero_grad()
|
||||
loss.backward()
|
||||
if noloss: del loss
|
||||
optim.step()
|
||||
|
||||
# printing
|
||||
|
|
|
@ -11,11 +11,8 @@ def model_step(lm):
|
|||
optimizer = optim.SGD(get_parameters(lm), lr=0.001)
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
del x,loss
|
||||
optimizer.step()
|
||||
#out = loss.detach().numpy()
|
||||
for p in optimizer.params:
|
||||
p.realize()
|
||||
#x.grad.realize()
|
||||
Tensor.training = False
|
||||
|
||||
|
||||
|
|
|
@ -57,7 +57,13 @@ class TestMNIST(unittest.TestCase):
|
|||
np.random.seed(1337)
|
||||
model = TinyBobNet()
|
||||
optimizer = optim.SGD(model.parameters(), lr=0.001)
|
||||
train(model, X_train, Y_train, optimizer, BS=69, steps=5)
|
||||
train(model, X_train, Y_train, optimizer, BS=69, steps=3)
|
||||
|
||||
def test_sgd_sixstep(self):
  """Run six SGD training steps on a fresh TinyBobNet with ``noloss=True``.

  With ``noloss`` set, ``train`` deletes its reference to the loss tensor
  after ``backward()`` and before the optimizer step.
  """
  # fixed seed so repeated runs draw the same random numbers
  np.random.seed(1337)
  net = TinyBobNet()
  sgd = optim.SGD(net.parameters(), lr=0.001)
  train(net, X_train, Y_train, sgd, steps=6, BS=69, noloss=True)
|
||||
|
||||
def test_adam_onestep(self):
|
||||
np.random.seed(1337)
|
||||
|
|
|
@ -16,6 +16,9 @@ class BatchNorm2D:
|
|||
|
||||
def __call__(self, x):
|
||||
if Tensor.training:
|
||||
# This requires two full memory accesses to x
|
||||
# https://github.com/pytorch/pytorch/blob/c618dc13d2aa23625cb0d7ada694137532a4fa33/aten/src/ATen/native/cuda/Normalization.cuh
|
||||
# There are "online" algorithms that fix this
|
||||
x_detached = x.detach()
|
||||
batch_mean = x_detached.mean(axis=(0,2,3))
|
||||
y = (x_detached - batch_mean.reshape(shape=[1, -1, 1, 1]))
|
||||
|
|
|
@ -51,6 +51,17 @@ if GRAPH:
|
|||
G = nx.DiGraph()
|
||||
def save_graph_exit():
|
||||
for k,v in cnts.items(): print(k, v)
|
||||
if int(os.getenv("PRUNEGRAPH", 0)):
|
||||
dead_nodes = []
|
||||
for n in G.nodes:
|
||||
if G.nodes[n]['fillcolor'] in ["#80ff8080", "#80ff80"]:
|
||||
for x,_ in G.in_edges(n):
|
||||
for _,y in G.out_edges(n):
|
||||
G.add_edge(x, y)
|
||||
dead_nodes.append(n)
|
||||
if G.nodes[n]['fillcolor'] in ["#FFFF8080", "#FFFF80"]:
|
||||
dead_nodes.append(n)
|
||||
for n in dead_nodes: G.remove_node(n)
|
||||
print("saving", G)
|
||||
nx.drawing.nx_pydot.write_dot(G, '/tmp/net.dot')
|
||||
# -Gnslimit=100 can make it finish, but you won't like the results
|
||||
|
|
Loading…
Reference in New Issue