import time
start_tm = time.perf_counter()
import math
from typing import Tuple, cast
import numpy as np
from tinygrad import Tensor, nn, GlobalCounters, TinyJit, dtypes
from tinygrad.helpers import partition, trange, getenv, Context
from extra.lr_scheduler import OneCycleLR

dtypes.default_float = dtypes.half
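# NOTE: with default_float set to half, weights and activations are created in fp16; the
# BatchNorm calls below explicitly upcast to float32 (.float()) and cast back, presumably to
# keep the normalization statistics stable in mixed precision.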

# from https://github.com/tysam-code/hlb-CIFAR10/blob/main/main.py
batchsize = getenv("BS", 1024)
bias_scaler = 64
hyp = {
  'opt': {
    'bias_lr': 1.525 * bias_scaler/512, # TODO: Is there maybe a better way to express the bias and batchnorm scaling? :'))))
    'non_bias_lr': 1.525 / 512,
    'bias_decay': 6.687e-4 * batchsize/bias_scaler,
    'non_bias_decay': 6.687e-4 * batchsize,
    'scaling_factor': 1./9,
    'percent_start': .23,
    'loss_scale_scaler': 1./32, # * Regularizer inside the loss summing (range: ~1/512 - 16+). FP8 should help with this somewhat too, whenever it comes out. :)
  },
  'net': {
    'whitening': {
      'kernel_size': 2,
      'num_examples': 50000,
    },
    'batch_norm_momentum': .4, # * Don't forget momentum is 1 - momentum here (due to a quirk in the original paper... >:( )
    'cutmix_size': 3,
    'cutmix_epochs': 6,
    'pad_amount': 2,
    'base_depth': 64  ## This should be a multiple of 8 to stay tensor core friendly
  },
  'misc': {
    'ema': {
      'epochs': 10, # Slight bug in that this counts only full epochs and then additionally runs the EMA for any fractional epochs at the end too
      'decay_base': .95,
      'decay_pow': 3.,
      'every_n_steps': 5,
    },
    'train_epochs': 12,
    #'train_epochs': 12.1,
    'device': 'cuda',
    'data_location': 'data.pt',
  }
}
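# With the defaults above (BS=1024, bias_scaler=64) these work out to roughly:
#   bias_lr     = 1.525*64/512   ~ 0.19     non_bias_lr    = 1.525/512     ~ 0.003
#   bias_decay  = 6.687e-4*16    ~ 0.011    non_bias_decay = 6.687e-4*1024 ~ 0.68
# i.e. weight decay scales up with the batch size, and the bias/batchnorm group trades a much
# larger learning rate for a much smaller decay (the bias_scaler split inherited from hlb-CIFAR10).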

scaler = 2. ## You can play with this on your own if you want, for the first beta I wanted to keep things simple (for now) and leave it out of the hyperparams dict
depths = {
  'init':   round(scaler**-1*hyp['net']['base_depth']), # 32  w/ scaler at base value
  'block1': round(scaler** 0*hyp['net']['base_depth']), # 64  w/ scaler at base value
  'block2': round(scaler** 2*hyp['net']['base_depth']), # 256 w/ scaler at base value
  'block3': round(scaler** 3*hyp['net']['base_depth']), # 512 w/ scaler at base value
  'num_classes': 10
}
whiten_conv_depth = 3*hyp['net']['whitening']['kernel_size']**2
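# whiten_conv_depth = 3 channels * kernel_size**2 = 12, so the whitening conv below has 2*12 = 24
# output filters. In hlb-CIFAR10 those come in +/- pairs of patch-whitening eigenvectors; as far as
# this file shows, here it is simply a learnable 2x2 conv of that width.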

class ConvGroup:
  def __init__(self, channels_in, channels_out):
    self.conv1 = nn.Conv2d(channels_in, channels_out, kernel_size=3, padding=1, bias=False)
    self.conv2 = nn.Conv2d(channels_out, channels_out, kernel_size=3, padding=1, bias=False)
    self.norm1 = nn.BatchNorm(channels_out, track_running_stats=False, eps=1e-12, momentum=hyp['net']['batch_norm_momentum'])
    self.norm2 = nn.BatchNorm(channels_out, track_running_stats=False, eps=1e-12, momentum=hyp['net']['batch_norm_momentum'])
    cast(Tensor, self.norm1.weight).requires_grad = False
    cast(Tensor, self.norm2.weight).requires_grad = False
  def __call__(self, x:Tensor) -> Tensor:
    x = self.norm1(self.conv1(x).max_pool2d().float()).cast(dtypes.default_float).quick_gelu()
    return self.norm2(self.conv2(x).float()).cast(dtypes.default_float).quick_gelu()
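# Each ConvGroup is: 3x3 conv -> 2x2 max pool (halving H and W) -> BatchNorm (computed in fp32,
# cast back to half) -> quick_gelu, then a second 3x3 conv -> BatchNorm -> quick_gelu at the same
# resolution. The norm scale (weight) is frozen above, so only the norm bias trains, which is
# presumably why the hyperparameter comments talk about "bias and batchnorm scaling" as one group.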

class SpeedyConvNet:
  def __init__(self):
    self.whiten = nn.Conv2d(3, 2*whiten_conv_depth, kernel_size=hyp['net']['whitening']['kernel_size'], padding=0, bias=False)
    self.conv_group_1 = ConvGroup(2*whiten_conv_depth, depths['block1'])
    self.conv_group_2 = ConvGroup(depths['block1'], depths['block2'])
    self.conv_group_3 = ConvGroup(depths['block2'], depths['block3'])
    self.linear = nn.Linear(depths['block3'], depths['num_classes'], bias=False)
  def __call__(self, x:Tensor) -> Tensor:
    x = self.whiten(x).quick_gelu()
    x = x.sequential([self.conv_group_1, self.conv_group_2, self.conv_group_3])
    return self.linear(x.max(axis=(2,3))) * hyp['opt']['scaling_factor']
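# Rough forward-pass shapes on CIFAR-10 (32x32 input): the 2x2 whitening conv with no padding gives
# 31x31, and each ConvGroup's pool halves the spatial dims (31 -> 15 -> 7 -> 3, assuming the default
# 2x2 max pool). The global max over axes (2,3) reduces to (batch, block3) before the bias-free
# linear head, and multiplying the logits by scaling_factor = 1/9 acts like a fixed softmax temperature.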

if __name__ == "__main__":
  # *** dataset ***
  X_train, Y_train, X_test, Y_test = nn.datasets.cifar()
  # TODO: without this line indexing doesn't fuse!
  X_train, Y_train, X_test, Y_test = [x.contiguous() for x in [X_train, Y_train, X_test, Y_test]]
  cifar10_std, cifar10_mean = X_train.float().std_mean(axis=(0, 2, 3))
  def preprocess(X:Tensor, Y:Tensor) -> Tuple[Tensor, Tensor]:
    return ((X - cifar10_mean.view(1, -1, 1, 1)) / cifar10_std.view(1, -1, 1, 1)).cast(dtypes.default_float), Y.one_hot(depths['num_classes'])
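  # preprocess normalizes each channel with statistics computed over the full training set
  # (per-channel mean/std broadcast over NCHW), casts back to the fp16 default, and one-hots the
  # integer labels, so the train and eval paths below see identically scaled inputs and targets.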

  # *** model ***
  model = SpeedyConvNet()
  state_dict = nn.state.get_state_dict(model)

  #for k,v in nn.state.torch_load("/tmp/cifar_net.pt").items(): print(k)

  params_bias, params_non_bias = partition(state_dict.items(), lambda x: 'bias' in x[0])
  opt_bias = nn.optim.SGD([x[1] for x in params_bias], lr=0.01, momentum=.85, nesterov=True, weight_decay=hyp['opt']['bias_decay'])
  opt_non_bias = nn.optim.SGD([x[1] for x in params_non_bias], lr=0.01, momentum=.85, nesterov=True, weight_decay=hyp['opt']['non_bias_decay'])
  opt = nn.optim.OptimizerGroup(opt_bias, opt_non_bias)
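  # partition splits the state dict by parameter name: anything containing 'bias' (in practice the
  # batchnorm biases, since every conv/linear here is created with bias=False) goes in the high-LR,
  # low-decay group, everything else in the low-LR, high-decay group. The frozen batchnorm weights
  # (requires_grad=False) should simply drop out of the optimizer, and lr=0.01 is only a placeholder
  # that the OneCycleLR schedulers below overwrite on every step.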

  num_steps_per_epoch = X_train.size(0) // batchsize
  total_train_steps = math.ceil(num_steps_per_epoch * hyp['misc']['train_epochs'])
  loss_batchsize_scaler = 512/batchsize

  pct_start = hyp['opt']['percent_start']
  initial_div_factor = 1e16 # basically to make the initial lr ~0 or so :D
  final_lr_ratio = .07 # Actually pretty important, apparently!
  lr_sched_bias = OneCycleLR(opt_bias, max_lr=hyp['opt']['bias_lr'], pct_start=pct_start, div_factor=initial_div_factor, final_div_factor=1./(initial_div_factor*final_lr_ratio), total_steps=total_train_steps)
  lr_sched_non_bias = OneCycleLR(opt_non_bias, max_lr=hyp['opt']['non_bias_lr'], pct_start=pct_start, div_factor=initial_div_factor, final_div_factor=1./(initial_div_factor*final_lr_ratio), total_steps=total_train_steps)
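  # OneCycleLR ramps from ~0 (max_lr/1e16) up to max_lr over the first percent_start (23%) of
  # total_train_steps, then anneals back down. Since final_div_factor is defined as
  # 1/(initial_div_factor*final_lr_ratio), the two huge factors cancel and the final LR should land
  # at roughly final_lr_ratio * max_lr, i.e. about 7% of the peak.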

  def loss_fn(out, Y):
    return out.cross_entropy(Y, reduction='none', label_smoothing=0.2).mul(hyp['opt']['loss_scale_scaler']*loss_batchsize_scaler).sum().div(hyp['opt']['loss_scale_scaler'])
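  # Mathematically this is just sum(per-example CE) * loss_batchsize_scaler: the loss_scale_scaler
  # multiplied in before the sum is divided back out after it, so its only job is numerical, scaling
  # each fp16 term before the big reduction (the "regularizer inside the loss summing" from the hyp
  # comment). train_step below divides by batchsize*loss_batchsize_scaler, so the reported number is
  # the plain mean cross-entropy.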

  @TinyJit
  @Tensor.train()
  def train_step(idxs:Tensor) -> Tensor:
    with Context(SPLIT_REDUCEOP=0, FUSE_ARANGE=1):
      X = X_train[idxs]
      Y = Y_train[idxs].realize(X)
    X, Y = preprocess(X, Y)
    out = model(X)
    loss = loss_fn(out, Y)
    opt.zero_grad()
    loss.backward()
    opt.step()
    lr_sched_bias.step()
    lr_sched_non_bias.step()
    return loss / (batchsize*loss_batchsize_scaler)
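  # TinyJit captures the whole step (gather, forward, backward, optimizer and scheduler updates) as
  # one replayable graph after the warm-up calls, which is why the function takes a fixed-shape
  # Tensor of indices rather than a freshly sliced batch. FUSE_ARANGE lets the X_train[idxs] gather
  # fuse into a single kernel (see the contiguous() TODO above), and Y_train[idxs].realize(X) forces
  # both gathers to be scheduled while those Context flags are active.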

  eval_batchsize = 2500
  @TinyJit
  @Tensor.test()
  def val_step() -> Tuple[Tensor, Tensor]:
    # TODO with Tensor.no_grad()
    Tensor.no_grad = True
    loss, acc = [], []
    for i in range(0, X_test.size(0), eval_batchsize):
      X, Y = preprocess(X_test[i:i+eval_batchsize], Y_test[i:i+eval_batchsize])
      out = model(X)
      loss.append(loss_fn(out, Y))
      acc.append((out.argmax(-1).one_hot(depths['num_classes']) * Y).sum() / eval_batchsize)
    ret = Tensor.stack(*loss).mean() / (batchsize*loss_batchsize_scaler), Tensor.stack(*acc).mean()
    Tensor.no_grad = False
    return ret
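  # val_step evaluates the 10000 test images as four fixed 2500-image batches inside one jitted call.
  # Accuracy one-hots the argmax prediction and multiplies with the one-hot labels, so each correct
  # example contributes exactly 1 before dividing by eval_batchsize. Note the loss divisor reuses the
  # train batchsize, so the printed val_loss appears to differ from the plain per-example mean by a
  # constant factor of eval_batchsize/batchsize.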

  np.random.seed(1337)
  for epoch in range(math.ceil(hyp['misc']['train_epochs'])):
    # TODO: move to tinygrad
    gst = time.perf_counter()
    idxs = np.arange(X_train.shape[0])
    np.random.shuffle(idxs)
    tidxs = Tensor(idxs, dtype='int')[:num_steps_per_epoch*batchsize].reshape(num_steps_per_epoch, batchsize) # NOTE: long doesn't fold
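    # the shuffled index array is truncated to a whole number of batches and reshaped to
    # (num_steps_per_epoch, batchsize), so every train_step call receives an int32 index tensor of
    # identical shape, which the TinyJit-captured step above requires.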
    train_loss:float = 0
    for epoch_step in (t:=trange(num_steps_per_epoch)):
      st = time.perf_counter()
      GlobalCounters.reset()
      loss = train_step(tidxs[epoch_step].contiguous()).float().item()
      t.set_description(f"*** loss: {loss:5.3f} lr: {opt_non_bias.lr.item():.6f}"
                        f" tm: {(et:=(time.perf_counter()-st))*1000:6.2f} ms {GlobalCounters.global_ops/(1e9*et):7.0f} GFLOPS")
      train_loss += loss
    gmt = time.perf_counter()
    GlobalCounters.reset()
    val_loss, acc = [x.float().item() for x in val_step()]
    get = time.perf_counter()
    print(f"\033[F*** epoch {epoch:3d} tm: {(gmt-gst):5.2f} s val_tm: {(get-gmt):5.2f} s train_loss: {train_loss/num_steps_per_epoch:5.3f} val_loss: {val_loss:5.3f} eval acc: {acc*100:5.2f}% @ {get-start_tm:6.2f} s ")