mirror of https://github.com/commaai/tinygrad.git
137 lines
4.5 KiB
Python
137 lines
4.5 KiB
Python
#!/usr/bin/env python
|
|
#inspired by https://github.com/Matuzas77/MNIST-0.17/blob/master/MNIST_final_solution.ipynb
|
|
import sys
|
|
import numpy as np
|
|
from tinygrad.nn.state import get_parameters
|
|
from tinygrad.tensor import Tensor
|
|
from tinygrad.nn import BatchNorm2d, optim
|
|
from tinygrad.helpers import getenv
|
|
from extra.datasets import fetch_mnist
|
|
from extra.augment import augment_img
|
|
from extra.training import train, evaluate
|
|
# Runtime switches, looked up by name through tinygrad's getenv helper
# (environment variables GPU, QUICK and DEBUG). Evaluated left to right.
GPU, QUICK, DEBUG = getenv("GPU"), getenv("QUICK"), getenv("DEBUG")
|
|
|
|
class SqueezeExciteBlock2D:
  """Squeeze-and-excitation block that rescales each channel of its input.

  The input is average-pooled over its full spatial extent ("squeeze"), fed
  through a two-layer bottleneck (ReLU, then sigmoid) to get one gate per
  channel ("excitation"), and the gates are multiplied back onto the input.
  """

  def __init__(self, filters, reduction=32):
    """
    Args:
      filters: number of input channels (the block reshapes to (-1, filters)).
      reduction: bottleneck ratio; the hidden layer has filters//reduction
        units. Defaults to 32, the ratio previously hard-coded here.
    """
    self.filters = filters
    # Clamp to at least one hidden unit: with filters < reduction the hidden
    # layer would be zero-width and the gate would no longer depend on the input.
    hidden = max(1, filters // reduction)
    self.weight1 = Tensor.scaled_uniform(filters, hidden)
    self.bias1 = Tensor.scaled_uniform(1, hidden)
    self.weight2 = Tensor.scaled_uniform(hidden, filters)
    self.bias2 = Tensor.scaled_uniform(1, filters)

  def __call__(self, input):
    """Return `input` scaled per channel; assumes (N, C, H, W) with C == filters."""
    # GlobalAveragePool2D: the pooling kernel spans the whole H x W plane
    se = input.avg_pool2d(kernel_size=(input.shape[2], input.shape[3]))
    se = se.reshape(shape=(-1, self.filters))
    se = (se.dot(self.weight1) + self.bias1).relu()
    se = (se.dot(self.weight2) + self.bias2).sigmoid()
    # reshape the gates to (N, C, 1, 1) so they broadcast over H and W
    return input.mul(se.reshape(shape=(-1, self.filters, 1, 1)))
|
|
|
|
class ConvBlock:
  """Three padded conv + bias + ReLU layers, then BatchNorm and squeeze-excite.

  Args:
    h, w: spatial height and width the input is reshaped to.
    inp: number of input channels.
    filters: number of output channels of every conv layer.
    conv: square kernel size; inputs are padded by conv//2 on each side,
      which keeps the spatial size for odd kernels (1 pixel for the default 3).
  """

  def __init__(self, h, w, inp, filters=128, conv=3):
    self.h, self.w = h, w
    self.inp = inp
    # init weights: the first conv maps inp -> filters, the other two filters -> filters
    self.cweights = [Tensor.scaled_uniform(filters, inp if i == 0 else filters, conv, conv) for i in range(3)]
    self.cbiases = [Tensor.scaled_uniform(1, filters, 1, 1) for _ in range(3)]
    # init layers: normalize the `filters` channels the convs produce
    # (previously hard-coded to 128, which broke any non-default `filters`)
    self._bn = BatchNorm2d(filters)
    self._seb = SqueezeExciteBlock2D(filters)

  def __call__(self, input):
    # NCHW: height before width (the old (w, h) order only worked because h == w)
    x = input.reshape(shape=(-1, self.inp, self.h, self.w))
    for cweight, cbias in zip(self.cweights, self.cbiases):
      pad = cweight.shape[2] // 2  # == 1 for the default 3x3 kernel
      x = x.pad2d(padding=[pad, pad, pad, pad]).conv2d(cweight).add(cbias).relu()
    x = self._bn(x)
    return self._seb(x)
|
|
|
|
class BigConvNet:
  """MNIST classifier: three ConvBlocks, then global avg- and max-pool heads.

  The two pooled 128-d feature vectors each go through their own linear map
  (weight1 / weight2) and the results are summed into 10 logits.
  """

  def __init__(self):
    self.conv = [ConvBlock(28,28,1), ConvBlock(28,28,128), ConvBlock(14,14,128)]
    self.weight1 = Tensor.scaled_uniform(128,10)
    self.weight2 = Tensor.scaled_uniform(128,10)

  def parameters(self):
    """Return all model tensors (as collected by get_parameters) for the optimizer.

    With DEBUG set, also print the shape of each tensor that is not explicitly
    frozen, plus their total element count. The returned list is the same
    whether or not DEBUG is set, so the debug flag cannot change training.
    """
    pars = get_parameters(self)
    if DEBUG:
      # NOTE(review): requires_grad may be None until an optimizer claims the
      # tensor, so only exclude tensors explicitly marked False.
      trainable = [par for par in pars if par.requires_grad is not False]
      for par in trainable:
        print(par.shape)
      print('no of parameters', sum(int(np.prod(par.shape)) for par in trainable))
    return pars

  def save(self, filename):
    """Write every parameter, in get_parameters order, to `filename + '.npy'`."""
    with open(filename+'.npy', 'wb') as f:
      for par in get_parameters(self):
        np.save(f, par.numpy())

  def load(self, filename):
    """Read parameters written by `save` from `filename + '.npy'`, in the same order.

    A parameter that cannot be loaded (e.g. shape mismatch, truncated file) is
    reported and skipped. Failure to open the file propagates to the caller.
    """
    with open(filename+'.npy', 'rb') as f:
      for par in get_parameters(self):
        try:
          # Write through Tensor.assign: the array returned by .numpy() is not
          # guaranteed to alias the tensor's storage, so mutating it can be lost.
          par.assign(Tensor(np.load(f), device=par.device))
          if GPU:
            par.gpu_()  # in-place move; .gpu() returns a new tensor that was discarded
        except Exception as e:
          print('Could not load parameter', e)

  def forward(self, x):
    x = self.conv[0](x)
    x = self.conv[1](x)
    x = x.avg_pool2d(kernel_size=(2,2))  # 28x28 -> 14x14
    x = self.conv[2](x)
    x1 = x.avg_pool2d(kernel_size=(14,14)).reshape(shape=(-1,128))  # global average pool
    x2 = x.max_pool2d(kernel_size=(14,14)).reshape(shape=(-1,128))  # global max pool
    xo = x1.dot(self.weight1) + x2.dot(self.weight2)
    return xo
|
|
|
|
|
|
if __name__ == "__main__":
  # learning-rate schedule: one (lr, epochs) stage per zipped pair
  lrs = [1e-4, 1e-5] if QUICK else [1e-3, 1e-4, 1e-5, 1e-5]
  epochss = [2, 1] if QUICK else [13, 3, 3, 1]
  BS = 32

  # cross-entropy plus an L1 penalty on the two classifier-head weight matrices;
  # `model` is resolved when the lambda is called, after it is created below
  lmbd = 0.00025
  lossfn = lambda out,y: out.sparse_categorical_crossentropy(y) + lmbd*(model.weight1.abs() + model.weight2.abs()).sum()
  X_train, Y_train, X_test, Y_test = fetch_mnist()
  X_train = X_train.reshape(-1, 28, 28).astype(np.uint8)
  X_test = X_test.reshape(-1, 28, 28).astype(np.uint8)
  steps = len(X_train)//BS
  np.random.seed(1337)
  if QUICK:
    steps = 1
    X_test, Y_test = X_test[:BS], Y_test[:BS]

  model = BigConvNet()

  if len(sys.argv) > 1:
    # Guard only the file access: a failure inside evaluate() must surface as
    # itself instead of being reported as a weight-loading problem.
    try:
      model.load(sys.argv[1])
    except OSError:
      print('could not load weights "'+sys.argv[1]+'".')
    else:
      print('Loaded weights "'+sys.argv[1]+'", evaluating...')
      evaluate(model, X_test, Y_test, BS=BS)

  if GPU:
    for x in get_parameters(model):
      x.gpu_()

  for lr, epochs in zip(lrs, epochss):
    optimizer = optim.Adam(model.parameters(), lr=lr)
    for epoch in range(1, epochs+1):
      # first epoch without augmentation
      X_aug = X_train if epoch == 1 else augment_img(X_train)
      train(model, X_aug, Y_train, optimizer, steps=steps, lossfn=lossfn, BS=BS)
      accuracy = evaluate(model, X_test, Y_test, BS=BS)
      model.save(f'examples/checkpoint{accuracy * 1e6:.0f}')
|