"""MNIST convnet in PyTorch, mirroring tinygrad's beautiful_mnist example.

Loads MNIST via tinygrad's dataset helper, trains a small CNN for 70 steps,
and reports test-set accuracy every 10 steps.
"""
from tinygrad import dtypes
from tinygrad.helpers import trange
from tinygrad.nn.datasets import mnist
import torch
from torch import nn, optim

class Model(nn.Module):
  """Small MNIST convnet: two conv blocks (conv-conv-batchnorm-maxpool) and a linear head."""
  def __init__(self):
    super().__init__()
    self.c1 = nn.Conv2d(1, 32, 5)
    self.c2 = nn.Conv2d(32, 32, 5)
    self.bn1 = nn.BatchNorm2d(32)
    self.m1 = nn.MaxPool2d(2)
    self.c3 = nn.Conv2d(32, 64, 3)
    self.c4 = nn.Conv2d(64, 64, 3)
    self.bn2 = nn.BatchNorm2d(64)
    self.m2 = nn.MaxPool2d(2)
    # 28 -> conv5 -> 24 -> conv5 -> 20 -> pool -> 10 -> conv3 -> 8 -> conv3 -> 6 -> pool -> 3,
    # so the flattened feature map is 64 * 3 * 3 = 576
    self.lin = nn.Linear(576, 10)

  def forward(self, x):
    # NOTE: the original passed a stray positional 0 to relu on c2/c3/c4; that binds to
    # `inplace` and is falsy (the default), so dropping it is behavior-identical and
    # consistent with the c1 call.
    x = nn.functional.relu(self.c1(x))
    x = nn.functional.relu(self.c2(x))
    x = self.m1(self.bn1(x))
    x = nn.functional.relu(self.c3(x))
    x = nn.functional.relu(self.c4(x))
    x = self.m2(self.bn2(x))
    return self.lin(torch.flatten(x, 1))

if __name__ == "__main__":
  # pick the best available accelerator instead of hard-coding MPS,
  # which raises on any non-Apple machine
  if torch.backends.mps.is_available(): device = torch.device("mps")
  elif torch.cuda.is_available(): device = torch.device("cuda")
  else: device = torch.device("cpu")

  X_train, Y_train, X_test, Y_test = mnist()
  # tinygrad tensors -> torch tensors; labels must be int64 for CrossEntropyLoss
  X_train = torch.tensor(X_train.float().numpy(), device=device)
  Y_train = torch.tensor(Y_train.cast(dtypes.int64).numpy(), device=device)
  X_test = torch.tensor(X_test.float().numpy(), device=device)
  Y_test = torch.tensor(Y_test.cast(dtypes.int64).numpy(), device=device)

  model = Model().to(device)
  optimizer = optim.Adam(model.parameters(), 1e-3)
  loss_fn = nn.CrossEntropyLoss()

  #@torch.compile   # optional: compile the training step for speed
  def step(samples):
    """Run one optimizer step on the minibatch indexed by `samples`; return the loss tensor."""
    X, Y = X_train[samples], Y_train[samples]
    out = model(X)
    loss = loss_fn(out, Y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss

  test_acc = float('nan')
  for i in (t:=trange(70)):
    samples = torch.randint(0, X_train.shape[0], (512,))  # putting this in JIT didn't work well
    loss = step(samples)
    if i%10 == 9:
      # evaluate in inference mode: no autograd graph over the whole test set,
      # and BatchNorm uses its running statistics instead of test-batch statistics
      model.eval()
      with torch.no_grad():
        test_acc = ((model(X_test).argmax(axis=-1) == Y_test).sum() * 100 / X_test.shape[0]).item()
      model.train()
    t.set_description(f"loss: {loss.item():6.2f} test_accuracy: {test_acc:5.2f}%")