2021-06-22 00:37:24 +08:00
|
|
|
#!/usr/bin/env python3
|
|
|
|
import numpy as np
|
|
|
|
from PIL import Image
|
|
|
|
|
2023-08-22 22:36:24 +08:00
|
|
|
from tinygrad.nn.state import get_parameters
|
2023-04-12 16:18:39 +08:00
|
|
|
from tinygrad.nn import optim
|
2023-02-11 02:09:37 +08:00
|
|
|
from tinygrad.helpers import getenv
|
2021-06-22 00:37:24 +08:00
|
|
|
from extra.training import train, evaluate
|
2023-11-14 12:18:40 +08:00
|
|
|
from extra.models.resnet import ResNet
|
2023-07-08 01:43:44 +08:00
|
|
|
from extra.datasets import fetch_mnist
|
2021-06-22 00:37:24 +08:00
|
|
|
|
|
|
|
|
|
|
|
class ComposeTransforms:
  """Chain a sequence of callables into a single transform.

  Calling an instance feeds the input through each callable in order; every
  callable receives the previous one's output and the last output is returned.
  With an empty sequence the input is returned unchanged.
  """
  def __init__(self, trans):
    # ordered callables, applied first-to-last by __call__
    self.trans = trans

  def __call__(self, x):
    result = x
    for step_fn in self.trans:
      result = step_fn(result)
    return result
|
|
|
|
|
|
|
|
if __name__ == "__main__":
  # Fine-tune a ResNet on MNIST: the 28x28 grayscale digits are upscaled to
  # 64x64 and replicated across 3 channels before reaching the model.
  X_train, Y_train, X_test, Y_test = fetch_mnist()
  # reshape the flat samples into 28x28 images; uint8 so PIL infers mode 'L'
  X_train = X_train.reshape(-1, 28, 28).astype(np.uint8)
  X_test = X_test.reshape(-1, 28, 28).astype(np.uint8)
  classes = 10

  # Environment knobs; the defaults reproduce the original hard-coded run.
  TRANSFER = getenv('TRANSFER')
  EPOCHS = getenv('EPOCHS', 5)
  STEPS = getenv('STEPS', 100)  # passed to train() each epoch; presumably steps per epoch
  BS = getenv('BS', 32)

  model = ResNet(getenv('NUM', 18), num_classes=classes)
  if TRANSFER:
    model.load_from_pretrained()

  lr = 5e-3
  transform = ComposeTransforms([
    # no explicit mode: Pillow infers 'L' for 2-D uint8 arrays, and the
    # `mode` argument of Image.fromarray is deprecated
    lambda x: [Image.fromarray(xx).resize((64, 64)) for xx in x],
    lambda x: np.stack([np.asarray(xx) for xx in x], 0),
    lambda x: x / 255.0,
    # add a channel axis and repeat it 3 times -> (N, 3, 64, 64) float32
    lambda x: np.tile(np.expand_dims(x, 1), (1, 3, 1, 1)).astype(np.float32),
  ])

  for epoch in range(EPOCHS):
    # a new optimizer each epoch picks up the decayed lr (momentum state restarts)
    optimizer = optim.SGD(get_parameters(model), lr=lr, momentum=0.9)
    train(model, X_train, Y_train, optimizer, STEPS, BS=BS, transform=transform)
    evaluate(model, X_test, Y_test, num_classes=classes, transform=transform)
    # only announce a decay when another epoch will actually use it
    if epoch + 1 < EPOCHS:
      lr /= 1.2
      print(f'reducing lr to {lr:.7f}')
|