noop removal can replace with reshape

George Hotz 2022-07-16 08:32:42 -07:00
parent d985217fa4
commit d04b274cd2
3 changed files with 17 additions and 14 deletions

@@ -15,7 +15,7 @@ def sparse_categorical_crossentropy(out, Y):
   return out.mul(y).mean()
 
 def train(model, X_train, Y_train, optim, steps, BS=128, lossfn=sparse_categorical_crossentropy,
-          transform=lambda x: x, target_transform=lambda x: x):
+          transform=lambda x: x, target_transform=lambda x: x, noloss=False):
   Tensor.training = True
   losses, accuracies = [], []
   for i in (t := trange(steps, disable=os.getenv('CI') is not None)):
@@ -31,14 +31,15 @@ def train(model, X_train, Y_train, optim, steps, BS=128, lossfn=sparse_categoric
     loss.backward()
     optim.step()
 
-    cat = np.argmax(out.cpu().data, axis=-1)
-    accuracy = (cat == y).mean()
 
     # printing
-    loss = loss.detach().cpu().data
-    losses.append(loss)
-    accuracies.append(accuracy)
-    t.set_description("loss %.2f accuracy %.2f" % (loss, accuracy))
+    if not noloss:
+      cat = np.argmax(out.cpu().data, axis=-1)
+      accuracy = (cat == y).mean()
+      loss = loss.detach().cpu().data
+      losses.append(loss)
+      accuracies.append(accuracy)
+      t.set_description("loss %.2f accuracy %.2f" % (loss, accuracy))
 
 def evaluate(model, X_test, Y_test, num_classes=None, BS=128, return_predict=False, transform=lambda x: x,
              target_transform=lambda y: y):
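
The point of the new flag: out.cpu() and loss.detach().cpu() each force the lazy graph to realize every step, so a run that only cares about building or benchmarking the graph can now skip the readback entirely. A hypothetical call, assuming the signature above:

  # with noloss=True nothing in the loop copies data off the device,
  # so no realization is forced and losses/accuracies stay empty
  train(model, X_train, Y_train, optim.SGD(model.parameters(), lr=0.001),
        steps=1, noloss=True)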

@@ -75,7 +75,8 @@ class TestMNIST(unittest.TestCase):
     np.random.seed(1337)
     model = TinyConvNet()
     optimizer = optim.SGD(model.parameters(), lr=0.001)
-    train(model, X_train, Y_train, optimizer, BS=69, steps=1)
+    train(model, X_train, Y_train, optimizer, BS=69, steps=1, noloss=True)
+    for p in model.parameters(): p.realize()
 
   def test_conv(self):
     np.random.seed(1337)
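
Why the test also gains the realize loop: with noloss=True nothing in the training loop reads data back, so the lazy graph would be built but never executed. Forcing each parameter makes the optimizer step actually run; a sketch of the same pattern, using the realize() call the diff itself relies on:

  # noloss=True means no implicit realization; trigger it explicitly so
  # the test still exercises kernel construction and execution
  for p in model.parameters():
    p.realize()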

@@ -82,10 +82,11 @@ def log_op(optype : OpType, op : List[Op], ret : DeviceBuffer, inp : List[Device
     if nm(ret) not in G.nodes: G.add_node(nm(ret))
     if getattr(ret, "st", None) is not None and not ret.st.contiguous:
-      G.nodes[nm(ret)]['label'] = str(ret.shape)+"\n"+str(tuple(x[0] if x[1]!=0 else 0 for x in ret.st.views[-1].shape_strides))
+      #G.nodes[nm(ret)]['label'] = str(ret.shape)+"\n"+str(tuple(x[0] if x[1]!=0 else 0 for x in ret.st.views[-1].shape_strides))
+      G.nodes[nm(ret)]['label'] = str(tuple(x[0] if x[1]!=0 else 0 for x in ret.st.views[-1].shape_strides))
       dashed = True
     else:
       G.nodes[nm(ret)]['label'] = str(ret.shape)
   elif optype == ReduceOps: G.nodes[nm(ret)]['label'] = str(inp[0].shape)+"\n"+str(ret.shape)
   else: G.nodes[nm(ret)]['label'] = str(ret.shape)
   G.nodes[nm(ret)]['fillcolor'] = (top_colors[optype] + ('80' if dashed else '')) if optype in top_colors else "#ffffff"
   G.nodes[nm(ret)]['style'] = 'filled, dashed' if dashed else 'filled'
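
For reference, the changed label now shows only the view's effective extents: each axis renders as its size when its stride is nonzero, and as 0 for a broadcast axis. An illustrative evaluation, assuming shape_strides yields (size, stride) pairs per axis:

  # a (3, 1) tensor expanded to (3, 4): the expanded axis has stride 0
  shape_strides = [(3, 1), (4, 0)]
  print(tuple(x[0] if x[1] != 0 else 0 for x in shape_strides))  # -> (3, 0)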
@@ -250,8 +251,8 @@ class LazyBuffer:
 
     # NOTE: if ret is in the cache, it can already be realized
     if REMOVE_MOVEMENT_NOPS and ret.realized is None and x.realized is None and ret.st.contiguous:
       root = get_lazybuffers(ret.op)[0]
-      if ret.st.shape == root.shape and root.st.contiguous:
-        return root
+      if root.st.contiguous and root != x:
+        return root.movement_op(MovementOps.RESHAPE, ret.st.shape) if ret.st.shape != root.shape else root
 
     return ret
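
The payoff of this last hunk is the commit title: the old check could only drop a movement-op chain whose final shape matched the root exactly, while the new one collapses any chain that ends contiguous into at most one RESHAPE of the root. The added root != x guard keeps a lone movement op from being rewritten into a reshape of its own input, which would otherwise re-enter this same path. A stand-in sketch of the rule, with illustrative names rather than tinygrad's real classes:

  # the chain is summarized by its final shape and whether the resulting
  # view is contiguous over the root buffer's data
  def collapse_chain(root_shape, final_shape, final_contiguous):
    if not final_contiguous: return None          # chain can't be removed
    if final_shape == root_shape: return "root"   # old rule: exact-shape noop only
    return ("RESHAPE", final_shape)               # new rule: one reshape replaces it

  # (2,3,4) -> reshape (6,4) -> reshape (24,): two movement ops become one
  print(collapse_chain((2, 3, 4), (24,), True))   # -> ('RESHAPE', (24,))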