mirror of https://github.com/commaai/tinygrad.git
beautiful_mnist in torch
This commit is contained in:
parent
604fb60143
commit
aade18d20c
|
@ -0,0 +1,55 @@
|
|||
from tinygrad import dtypes
|
||||
from tinygrad.helpers import trange
|
||||
from tinygrad.nn.datasets import mnist
|
||||
import torch
|
||||
from torch import nn, optim
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.c1 = nn.Conv2d(1, 32, 5)
|
||||
self.c2 = nn.Conv2d(32, 32, 5)
|
||||
self.bn1 = nn.BatchNorm2d(32)
|
||||
self.m1 = nn.MaxPool2d(2)
|
||||
self.c3 = nn.Conv2d(32, 64, 3)
|
||||
self.c4 = nn.Conv2d(64, 64, 3)
|
||||
self.bn2 = nn.BatchNorm2d(64)
|
||||
self.m2 = nn.MaxPool2d(2)
|
||||
self.lin = nn.Linear(576, 10)
|
||||
def forward(self, x):
|
||||
x = nn.functional.relu(self.c1(x))
|
||||
x = nn.functional.relu(self.c2(x), 0)
|
||||
x = self.m1(self.bn1(x))
|
||||
x = nn.functional.relu(self.c3(x), 0)
|
||||
x = nn.functional.relu(self.c4(x), 0)
|
||||
x = self.m2(self.bn2(x))
|
||||
return self.lin(torch.flatten(x, 1))
|
||||
|
||||
if __name__ == "__main__":
|
||||
mps_device = torch.device("mps")
|
||||
X_train, Y_train, X_test, Y_test = mnist()
|
||||
X_train = torch.tensor(X_train.float().numpy(), device=mps_device)
|
||||
Y_train = torch.tensor(Y_train.cast(dtypes.int64).numpy(), device=mps_device)
|
||||
X_test = torch.tensor(X_test.float().numpy(), device=mps_device)
|
||||
Y_test = torch.tensor(Y_test.cast(dtypes.int64).numpy(), device=mps_device)
|
||||
|
||||
model = Model().to(mps_device)
|
||||
optimizer = optim.Adam(model.parameters(), 1e-3)
|
||||
|
||||
loss_fn = nn.CrossEntropyLoss()
|
||||
#@torch.compile
|
||||
def step(samples):
|
||||
X,Y = X_train[samples], Y_train[samples]
|
||||
out = model(X)
|
||||
loss = loss_fn(out, Y)
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
return loss
|
||||
|
||||
test_acc = float('nan')
|
||||
for i in (t:=trange(70)):
|
||||
samples = torch.randint(0, X_train.shape[0], (512,)) # putting this in JIT didn't work well
|
||||
loss = step(samples)
|
||||
if i%10 == 9: test_acc = ((model(X_test).argmax(axis=-1) == Y_test).sum() * 100 / X_test.shape[0]).item()
|
||||
t.set_description(f"loss: {loss.item():6.2f} test_accuracy: {test_acc:5.2f}%")
|
Loading…
Reference in New Issue