diff --git a/extra/training.py b/extra/training.py
index 549572d6..d207e052 100644
--- a/extra/training.py
+++ b/extra/training.py
@@ -15,7 +15,7 @@ def sparse_categorical_crossentropy(out, Y):
   return out.mul(y).mean()
 
 def train(model, X_train, Y_train, optim, steps, BS=128, lossfn=sparse_categorical_crossentropy,
-        transform=lambda x: x, target_transform=lambda x: x):
+        transform=lambda x: x, target_transform=lambda x: x, noloss=False):
   Tensor.training = True
   losses, accuracies = [], []
   for i in (t := trange(steps, disable=os.getenv('CI') is not None)):
@@ -31,14 +31,15 @@ def train(model, X_train, Y_train, optim, steps, BS=128, lossfn=sparse_categoric
     loss.backward()
     optim.step()
 
-    cat = np.argmax(out.cpu().data, axis=-1)
-    accuracy = (cat == y).mean()
-
-    # printing
-    loss = loss.detach().cpu().data
-    losses.append(loss)
-    accuracies.append(accuracy)
-    t.set_description("loss %.2f accuracy %.2f" % (loss, accuracy))
+    if not noloss:
+      cat = np.argmax(out.cpu().data, axis=-1)
+      accuracy = (cat == y).mean()
+
+      loss = loss.detach().cpu().data
+      losses.append(loss)
+      accuracies.append(accuracy)
+      t.set_description("loss %.2f accuracy %.2f" % (loss, accuracy))
 
 def evaluate(model, X_test, Y_test, num_classes=None, BS=128, return_predict=False, transform=lambda x: x,
              target_transform=lambda y: y):
diff --git a/test/test_mnist.py b/test/test_mnist.py
index 6ccb295c..3b760698 100644
--- a/test/test_mnist.py
+++ b/test/test_mnist.py
@@ -75,7 +75,8 @@ class TestMNIST(unittest.TestCase):
     np.random.seed(1337)
     model = TinyConvNet()
     optimizer = optim.SGD(model.parameters(), lr=0.001)
-    train(model, X_train, Y_train, optimizer, BS=69, steps=1)
+    train(model, X_train, Y_train, optimizer, BS=69, steps=1, noloss=True)
+    for p in model.parameters(): p.realize()
 
   def test_conv(self):
     np.random.seed(1337)
diff --git a/tinygrad/ops.py b/tinygrad/ops.py
index 5e3893c5..bf16a581 100644
--- a/tinygrad/ops.py
+++ b/tinygrad/ops.py
@@ -82,10 +82,11 @@ def log_op(optype : OpType, op : List[Op], ret : DeviceBuffer, inp : List[Device
     if nm(ret) not in G.nodes: G.add_node(nm(ret))
 
     if getattr(ret, "st", None) is not None and not ret.st.contiguous:
-      G.nodes[nm(ret)]['label'] = str(ret.shape)+"\n"+str(tuple(x[0] if x[1]!=0 else 0 for x in ret.st.views[-1].shape_strides))
+      #G.nodes[nm(ret)]['label'] = str(ret.shape)+"\n"+str(tuple(x[0] if x[1]!=0 else 0 for x in ret.st.views[-1].shape_strides))
+      G.nodes[nm(ret)]['label'] = str(tuple(x[0] if x[1]!=0 else 0 for x in ret.st.views[-1].shape_strides))
       dashed = True
-    else:
-      G.nodes[nm(ret)]['label'] = str(ret.shape)
+    elif optype == ReduceOps: G.nodes[nm(ret)]['label'] = str(inp[0].shape)+"\n"+str(ret.shape)
+    else: G.nodes[nm(ret)]['label'] = str(ret.shape)
 
     G.nodes[nm(ret)]['fillcolor'] = (top_colors[optype] + ('80' if dashed else '')) if optype in top_colors else "#ffffff"
     G.nodes[nm(ret)]['style'] = 'filled, dashed' if dashed else 'filled'
@@ -250,8 +251,8 @@ class LazyBuffer:
     # NOTE: if ret is in the cache, it can already be realized
     if REMOVE_MOVEMENT_NOPS and ret.realized is None and x.realized is None and ret.st.contiguous:
       root = get_lazybuffers(ret.op)[0]
-      if ret.st.shape == root.shape and root.st.contiguous:
-        return root
+      if root.st.contiguous and root != x:
+        return root.movement_op(MovementOps.RESHAPE, ret.st.shape) if ret.st.shape != root.shape else root
 
     return ret
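
Usage sketch (not part of the patch above): how the new noloss flag is meant to be driven, mirroring the updated one-step conv test: run the step with the loss/accuracy readback disabled, then force realization of the parameters explicitly. The stand-in model, the random data, and the tinygrad.nn optim import path are assumptions for illustration; only train()'s extended signature and the noloss/realize pattern come from the diff.

import numpy as np
from tinygrad.tensor import Tensor
from tinygrad.nn import optim          # assumed import path for the optim module used in the test
from extra.training import train

class TinyNet:                         # hypothetical stand-in for the test's TinyConvNet
  def __init__(self):
    self.l1 = Tensor.uniform(784, 10)
  def parameters(self):
    return [self.l1]
  def forward(self, x):
    return x.dot(self.l1).logsoftmax()

X_train = np.random.rand(512, 784).astype(np.float32)   # MNIST-shaped stand-in data
Y_train = np.random.randint(0, 10, size=(512,))

model = TinyNet()
optimizer = optim.SGD(model.parameters(), lr=0.001)

# noloss=True skips the argmax/accuracy computation and the loss .cpu() readback,
# so nothing in the loop forces the lazy graph to realize ...
train(model, X_train, Y_train, optimizer, steps=1, BS=69, noloss=True)

# ... realization is then triggered explicitly, per parameter, as the test now does.
for p in model.parameters():
  p.realize()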