prune graph

This commit is contained in:
George Hotz 2022-07-17 15:38:43 -07:00
parent eda6f071b2
commit f76d41812b
5 changed files with 23 additions and 5 deletions

View File

@ -29,6 +29,7 @@ def train(model, X_train, Y_train, optim, steps, BS=128, lossfn=sparse_categoric
loss = lossfn(out, y)
optim.zero_grad()
loss.backward()
if noloss: del loss
optim.step()
# printing

View File

@ -11,11 +11,8 @@ def model_step(lm):
optimizer = optim.SGD(get_parameters(lm), lr=0.001)
optimizer.zero_grad()
loss.backward()
del x,loss
optimizer.step()
#out = loss.detach().numpy()
for p in optimizer.params:
p.realize()
#x.grad.realize()
Tensor.training = False

View File

@ -57,7 +57,13 @@ class TestMNIST(unittest.TestCase):
np.random.seed(1337)
model = TinyBobNet()
optimizer = optim.SGD(model.parameters(), lr=0.001)
train(model, X_train, Y_train, optimizer, BS=69, steps=5)
train(model, X_train, Y_train, optimizer, BS=69, steps=3)
def test_sgd_sixstep(self):
np.random.seed(1337)
model = TinyBobNet()
optimizer = optim.SGD(model.parameters(), lr=0.001)
train(model, X_train, Y_train, optimizer, BS=69, steps=6, noloss=True)
def test_adam_onestep(self):
np.random.seed(1337)

View File

@ -16,6 +16,9 @@ class BatchNorm2D:
def __call__(self, x):
if Tensor.training:
# This requires two full memory accesses to x
# https://github.com/pytorch/pytorch/blob/c618dc13d2aa23625cb0d7ada694137532a4fa33/aten/src/ATen/native/cuda/Normalization.cuh
# There's "online" algorithms that fix this
x_detached = x.detach()
batch_mean = x_detached.mean(axis=(0,2,3))
y = (x_detached - batch_mean.reshape(shape=[1, -1, 1, 1]))

View File

@ -51,6 +51,17 @@ if GRAPH:
G = nx.DiGraph()
def save_graph_exit():
for k,v in cnts.items(): print(k, v)
if int(os.getenv("PRUNEGRAPH", 0)):
dead_nodes = []
for n in G.nodes:
if G.nodes[n]['fillcolor'] in ["#80ff8080", "#80ff80"]:
for x,_ in G.in_edges(n):
for _,y in G.out_edges(n):
G.add_edge(x, y)
dead_nodes.append(n)
if G.nodes[n]['fillcolor'] in ["#FFFF8080", "#FFFF80"]:
dead_nodes.append(n)
for n in dead_nodes: G.remove_node(n)
print("saving", G)
nx.drawing.nx_pydot.write_dot(G, '/tmp/net.dot')
# -Gnslimit=100 can make it finish, but you won't like results