diff --git a/extra/training.py b/extra/training.py
index d207e052..8c6e4328 100644
--- a/extra/training.py
+++ b/extra/training.py
@@ -31,6 +31,9 @@ def train(model, X_train, Y_train, optim, steps, BS=128, lossfn=sparse_categoric
     loss.backward()
     optim.step()
 
+    # TODO: corealize
+    for p in optim.params: p.realize()
+
     # printing
     if not noloss:
       cat = np.argmax(out.cpu().data, axis=-1)
@@ -40,6 +43,7 @@ def train(model, X_train, Y_train, optim, steps, BS=128, lossfn=sparse_categoric
       losses.append(loss)
      accuracies.append(accuracy)
       t.set_description("loss %.2f accuracy %.2f" % (loss, accuracy))
+
 
 def evaluate(model, X_test, Y_test, num_classes=None, BS=128, return_predict=False, transform=lambda x: x,
              target_transform=lambda y: y):
diff --git a/test/test_mnist.py b/test/test_mnist.py
index fb315a0b..fe94b303 100644
--- a/test/test_mnist.py
+++ b/test/test_mnist.py
@@ -57,7 +57,7 @@ class TestMNIST(unittest.TestCase):
     np.random.seed(1337)
     model = TinyBobNet()
     optimizer = optim.SGD(model.parameters(), lr=0.001)
-    train(model, X_train, Y_train, optimizer, BS=69, steps=3)
+    train(model, X_train, Y_train, optimizer, BS=69, steps=5)
 
   def test_adam_onestep(self):
     np.random.seed(1337)
diff --git a/tinygrad/ops.py b/tinygrad/ops.py
index 0b5ff006..757a8209 100644
--- a/tinygrad/ops.py
+++ b/tinygrad/ops.py
@@ -225,6 +225,7 @@ class LazyBuffer:
     return LazyBuffer(x.device, tuple(new_shape), ReduceOps, LazyOp(op, (x,), tuple(new_shape))) if x.shape != tuple(new_shape) else x
 
   # syntactic sugar around PAD and SHRINK
+  # TODO: turn RESHAPE into EXPAND and CONTRACT (current EXPAND should be REPEAT)
   def slice(x:LazyBuffer, arg):
     padding = [(max(0, -p[0]), max(0, p[1]-x.shape[i])) for i,p in enumerate(arg)]
     return x.movement_op(MovementOps.PAD, padding).movement_op(MovementOps.SHRINK, tuple((p[0] + padding[i][0], p[1] + padding[i][0]) for i,p in enumerate(arg)))
@@ -242,6 +243,7 @@ class LazyBuffer:
 
     # two ops in a row is one op
     if op == MovementOps.RESHAPE and x.realized is None and x.op.op == MovementOps.RESHAPE: return x.op.src[0].movement_op(op, arg)
+    if op == MovementOps.EXPAND and x.realized is None and x.op.op == MovementOps.EXPAND: return x.op.src[0].movement_op(op, arg)
     if op == MovementOps.PERMUTE and x.realized is None and x.op.op == MovementOps.PERMUTE: return x.op.src[0].movement_op(op, tuple(x.op.arg[i] for i in arg))
     if op == MovementOps.SHRINK and x.realized is None and x.op.op == MovementOps.SHRINK: return x.op.src[0].movement_op(op, arg)
     if op == MovementOps.PAD and x.realized is None and x.op.op == MovementOps.PAD: return x.op.src[0].movement_op(op, tuple((b1+b2, e1+e2) for (b1,e1),(b2,e2) in zip(x.op.arg, arg)))
@@ -249,7 +251,7 @@ class LazyBuffer:
 
     # some permutes are actually just reshapes
     if op == MovementOps.PERMUTE and ShapeTracker(x.shape).movement_op(op, arg).contiguous: return x.movement_op(MovementOps.RESHAPE, tuple(x.shape[i] for i in arg))
-    if (SHUFFLE_MOVEMENT_OPS or (SHUFFLE_RESHAPE_OPS and op == MovementOps.RESHAPE)) and x.optype == BinaryOps and x.realized is None and (SHUFFLE_PAD_OPS or op != MovementOps.PAD):
+    if (SHUFFLE_MOVEMENT_OPS or (SHUFFLE_RESHAPE_OPS and op == MovementOps.RESHAPE)) and x.optype == BinaryOps and x.realized is None and (SHUFFLE_PAD_OPS or op != MovementOps.PAD) and op != MovementOps.STRIDED:
       # if this MovementOp is being applied to a BinaryOp, apply the MovementOp to all the BinaryOp inputs instead
       def replace_with_movement_op(y:Union[LazyOp, LazyBuffer]) -> LazyBuffer:
         if isinstance(y, LazyBuffer): return y.movement_op(op, arg)
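
For context on the tinygrad/ops.py hunk: the existing "two ops in a row is one op" fusion collapses consecutive RESHAPEs, and this patch adds the same rule for EXPAND. A minimal NumPy sketch (a stand-in for illustration, not tinygrad's LazyBuffer API) of why the inner movement op can be dropped:

import numpy as np

x = np.arange(6)

# two RESHAPEs in a row are equivalent to one RESHAPE with the final shape
assert np.array_equal(x.reshape(2, 3).reshape(3, 2), x.reshape(3, 2))

# two broadcast-style EXPANDs in a row are equivalent to one EXPAND with the
# final shape (assuming EXPAND broadcasts size-1 dims here; the TODO in the
# diff notes the current EXPAND may later be treated as REPEAT)
y = np.arange(3).reshape(1, 1, 3)
twice = np.broadcast_to(np.broadcast_to(y, (1, 4, 3)), (2, 4, 3))
once = np.broadcast_to(y, (2, 4, 3))
assert np.array_equal(twice, once)

Collapsing the inner node keeps the lazy graph shallow, so realization later sees a single movement op instead of a chain of two.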