join expands

George Hotz 2022-07-17 13:42:05 -07:00
parent cfabbbd6bb
commit 73b0471b25
3 changed files with 8 additions and 2 deletions


@@ -31,6 +31,9 @@ def train(model, X_train, Y_train, optim, steps, BS=128, lossfn=sparse_categoric
     loss.backward()
     optim.step()
+    # TODO: corealize
+    for p in optim.params: p.realize()
     # printing
     if not noloss:
       cat = np.argmax(out.cpu().data, axis=-1)
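Note on the new realize() loop: optim.step() only builds lazy graphs for the updated weights, so forcing each parameter with realize() makes the update actually execute once per step instead of deferring an ever-growing graph to the end of training. A minimal sketch of the behavior, assuming the Tensor/SGD API of tinygrad around this commit (the import paths are an assumption):

from tinygrad.tensor import Tensor
from tinygrad.optim import SGD  # assumed mid-2022 import path

w = Tensor.uniform(4, 4)
opt = SGD([w], lr=0.001)
loss = (w * w).mean()
loss.backward()
opt.step()                          # w now holds a lazy, unevaluated update
for p in opt.params: p.realize()    # the new loop: force the update to run now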
@@ -40,6 +43,7 @@ def train(model, X_train, Y_train, optim, steps, BS=128, lossfn=sparse_categoric
       losses.append(loss)
       accuracies.append(accuracy)
       t.set_description("loss %.2f accuracy %.2f" % (loss, accuracy))

 def evaluate(model, X_test, Y_test, num_classes=None, BS=128, return_predict=False, transform=lambda x: x,
              target_transform=lambda y: y):


@@ -57,7 +57,7 @@ class TestMNIST(unittest.TestCase):
     np.random.seed(1337)
     model = TinyBobNet()
     optimizer = optim.SGD(model.parameters(), lr=0.001)
-    train(model, X_train, Y_train, optimizer, BS=69, steps=3)
+    train(model, X_train, Y_train, optimizer, BS=69, steps=5)

   def test_adam_onestep(self):
     np.random.seed(1337)


@@ -225,6 +225,7 @@ class LazyBuffer:
     return LazyBuffer(x.device, tuple(new_shape), ReduceOps, LazyOp(op, (x,), tuple(new_shape))) if x.shape != tuple(new_shape) else x

   # syntactic sugar around PAD and SHRINK
+  # TODO: turn RESHAPE into EXPAND and CONTRACT (current EXPAND should be REPEAT)
   def slice(x:LazyBuffer, arg):
     padding = [(max(0, -p[0]), max(0, p[1]-x.shape[i])) for i,p in enumerate(arg)]
     return x.movement_op(MovementOps.PAD, padding).movement_op(MovementOps.SHRINK, tuple((p[0] + padding[i][0], p[1] + padding[i][0]) for i,p in enumerate(arg)))
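As the comment says, slice is sugar: a slice that reaches outside the buffer becomes a PAD that grows the buffer just enough, followed by a SHRINK whose window is shifted by the leading pad. A quick worked example of that index arithmetic (plain Python, nothing tinygrad-specific; the values are hypothetical):

shape = (4,)
arg = [(-1, 6)]  # slice [-1:6] on a length-4 axis: 1 past the start, 2 past the end
padding = [(max(0, -p[0]), max(0, p[1]-shape[i])) for i, p in enumerate(arg)]
print(padding)   # [(1, 2)] -> PAD to length 4+1+2 = 7
shrink = tuple((p[0]+padding[i][0], p[1]+padding[i][0]) for i, p in enumerate(arg))
print(shrink)    # ((0, 7),) -> SHRINK takes the whole padded buffer, a 7-wide result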
@@ -242,6 +243,7 @@ class LazyBuffer:
     # two ops in a row is one op
     if op == MovementOps.RESHAPE and x.realized is None and x.op.op == MovementOps.RESHAPE: return x.op.src[0].movement_op(op, arg)
+    if op == MovementOps.EXPAND and x.realized is None and x.op.op == MovementOps.EXPAND: return x.op.src[0].movement_op(op, arg)
     if op == MovementOps.PERMUTE and x.realized is None and x.op.op == MovementOps.PERMUTE: return x.op.src[0].movement_op(op, tuple(x.op.arg[i] for i in arg))
     if op == MovementOps.SHRINK and x.realized is None and x.op.op == MovementOps.SHRINK: return x.op.src[0].movement_op(op, arg)
     if op == MovementOps.PAD and x.realized is None and x.op.op == MovementOps.PAD: return x.op.src[0].movement_op(op, tuple((b1+b2, e1+e2) for (b1,e1),(b2,e2) in zip(x.op.arg, arg)))
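This is where the commit earns its title: two EXPANDs in a row now join into one, the same way RESHAPE/PERMUTE/SHRINK/PAD chains already collapsed. For most ops the outer arg is simply re-applied to the grandparent; PERMUTE and PAD need their args composed. How those compositions behave, with plain tuples (hypothetical values, just for illustration):

p1, p2 = (2, 0, 1), (1, 2, 0)                 # p2 applied after p1
print(tuple(p1[i] for i in p2))               # (0, 1, 2): these inverse permutes cancel to identity
pad1, pad2 = ((1, 0), (0, 2)), ((0, 1), (3, 0))
print(tuple((b1+b2, e1+e2) for (b1, e1), (b2, e2) in zip(pad1, pad2)))  # ((1, 1), (3, 2)): per-axis pads add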
@@ -249,7 +251,7 @@ class LazyBuffer:
     # some permutes are actually just reshapes
     if op == MovementOps.PERMUTE and ShapeTracker(x.shape).movement_op(op, arg).contiguous: return x.movement_op(MovementOps.RESHAPE, tuple(x.shape[i] for i in arg))

-    if (SHUFFLE_MOVEMENT_OPS or (SHUFFLE_RESHAPE_OPS and op == MovementOps.RESHAPE)) and x.optype == BinaryOps and x.realized is None and (SHUFFLE_PAD_OPS or op != MovementOps.PAD):
+    if (SHUFFLE_MOVEMENT_OPS or (SHUFFLE_RESHAPE_OPS and op == MovementOps.RESHAPE)) and x.optype == BinaryOps and x.realized is None and (SHUFFLE_PAD_OPS or op != MovementOps.PAD) and op != MovementOps.STRIDED:
       # if this MovementOp is being applied to a BinaryOp, apply the MovementOp to all the BinaryOp inputs instead
       def replace_with_movement_op(y:Union[LazyOp, LazyBuffer]) -> LazyBuffer:
         if isinstance(y, LazyBuffer): return y.movement_op(op, arg)
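The guard gains `and op != MovementOps.STRIDED`, so STRIDED is no longer shuffled through binary ops; the others still get pushed onto each input so the elementwise op stays outermost and can fuse. Why that push preserves semantics for an elementwise op, sketched with numpy as a stand-in (not tinygrad code):

import numpy as np
a, b = np.arange(6.).reshape(2, 3), np.ones((2, 3))
lhs = (a + b).reshape(3, 2)              # movement op applied to the BinaryOp's output
rhs = a.reshape(3, 2) + b.reshape(3, 2)  # movement op pushed onto each input instead
assert np.array_equal(lhs, rhs)          # identical: the rewrite is safe for RESHAPE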