mirror of https://github.com/commaai/tinygrad.git
CIFAR trainer + various bugfixes / improvements (#6146)
* move cifar into datasets * support for pathlib Tensors, tar_extract, and fetch gunzip * too early for Device.DEFAULT * simpler hlb_cifar + .to(None) is default * new compiler failure, start beautiful_cifar * beautiful cifar runs but is broken * jit train step * cleaner * std_mean, not mean_std * more correct * fast indexing * don't print that * torch load broken * add eval * nicer bar * decoraters are the way to do this * bounds check the load * a few ops * batchnorm bugfix, if track_running_stats is False, use online estimate * full timing * fix fusion * unneeded realize * master tensor
This commit is contained in:
parent
296368f0dd
commit
9faf205601
|
@ -0,0 +1,168 @@
|
|||
import time
|
||||
start_tm = time.perf_counter()
|
||||
import math
|
||||
from typing import Tuple
|
||||
import numpy as np
|
||||
from tinygrad import Tensor, nn, GlobalCounters, TinyJit, dtypes
|
||||
from tinygrad.helpers import partition, trange, getenv, Context
|
||||
from extra.lr_scheduler import OneCycleLR
|
||||
|
||||
dtypes.default_float = dtypes.half
|
||||
|
||||
# from https://github.com/tysam-code/hlb-CIFAR10/blob/main/main.py
|
||||
batchsize = getenv("BS", 1024)
|
||||
bias_scaler = 64
|
||||
hyp = {
|
||||
'opt': {
|
||||
'bias_lr': 1.525 * bias_scaler/512, # TODO: Is there maybe a better way to express the bias and batchnorm scaling? :'))))
|
||||
'non_bias_lr': 1.525 / 512,
|
||||
'bias_decay': 6.687e-4 * batchsize/bias_scaler,
|
||||
'non_bias_decay': 6.687e-4 * batchsize,
|
||||
'scaling_factor': 1./9,
|
||||
'percent_start': .23,
|
||||
'loss_scale_scaler': 1./32, # * Regularizer inside the loss summing (range: ~1/512 - 16+). FP8 should help with this somewhat too, whenever it comes out. :)
|
||||
},
|
||||
'net': {
|
||||
'whitening': {
|
||||
'kernel_size': 2,
|
||||
'num_examples': 50000,
|
||||
},
|
||||
'batch_norm_momentum': .4, # * Don't forget momentum is 1 - momentum here (due to a quirk in the original paper... >:( )
|
||||
'cutmix_size': 3,
|
||||
'cutmix_epochs': 6,
|
||||
'pad_amount': 2,
|
||||
'base_depth': 64 ## This should be a factor of 8 in some way to stay tensor core friendly
|
||||
},
|
||||
'misc': {
|
||||
'ema': {
|
||||
'epochs': 10, # Slight bug in that this counts only full epochs and then additionally runs the EMA for any fractional epochs at the end too
|
||||
'decay_base': .95,
|
||||
'decay_pow': 3.,
|
||||
'every_n_steps': 5,
|
||||
},
|
||||
'train_epochs': 12,
|
||||
#'train_epochs': 12.1,
|
||||
'device': 'cuda',
|
||||
'data_location': 'data.pt',
|
||||
}
|
||||
}
|
||||
|
||||
scaler = 2. ## You can play with this on your own if you want, for the first beta I wanted to keep things simple (for now) and leave it out of the hyperparams dict
|
||||
depths = {
|
||||
'init': round(scaler**-1*hyp['net']['base_depth']), # 32 w/ scaler at base value
|
||||
'block1': round(scaler** 0*hyp['net']['base_depth']), # 64 w/ scaler at base value
|
||||
'block2': round(scaler** 2*hyp['net']['base_depth']), # 256 w/ scaler at base value
|
||||
'block3': round(scaler** 3*hyp['net']['base_depth']), # 512 w/ scaler at base value
|
||||
'num_classes': 10
|
||||
}
|
||||
whiten_conv_depth = 3*hyp['net']['whitening']['kernel_size']**2
|
||||
|
||||
class ConvGroup:
|
||||
def __init__(self, channels_in, channels_out):
|
||||
self.conv1 = nn.Conv2d(channels_in, channels_out, kernel_size=3, padding=1, bias=False)
|
||||
self.conv2 = nn.Conv2d(channels_out, channels_out, kernel_size=3, padding=1, bias=False)
|
||||
self.norm1 = nn.BatchNorm(channels_out, track_running_stats=False, eps=1e-12, momentum=hyp['net']['batch_norm_momentum'])
|
||||
self.norm2 = nn.BatchNorm(channels_out, track_running_stats=False, eps=1e-12, momentum=hyp['net']['batch_norm_momentum'])
|
||||
self.norm1.weight.requires_grad = False
|
||||
self.norm2.weight.requires_grad = False
|
||||
def __call__(self, x:Tensor) -> Tensor:
|
||||
x = self.norm1(self.conv1(x).max_pool2d().float()).cast(dtypes.default_float).quick_gelu()
|
||||
return self.norm2(self.conv2(x).float()).cast(dtypes.default_float).quick_gelu()
|
||||
|
||||
class SpeedyConvNet:
|
||||
def __init__(self):
|
||||
self.whiten = nn.Conv2d(3, 2*whiten_conv_depth, kernel_size=hyp['net']['whitening']['kernel_size'], padding=0, bias=False)
|
||||
self.conv_group_1 = ConvGroup(2*whiten_conv_depth, depths['block1'])
|
||||
self.conv_group_2 = ConvGroup(depths['block1'], depths['block2'])
|
||||
self.conv_group_3 = ConvGroup(depths['block2'], depths['block3'])
|
||||
self.linear = nn.Linear(depths['block3'], depths['num_classes'], bias=False)
|
||||
def __call__(self, x:Tensor) -> Tensor:
|
||||
x = self.whiten(x).quick_gelu()
|
||||
x = x.sequential([self.conv_group_1, self.conv_group_2, self.conv_group_3])
|
||||
return self.linear(x.max(axis=(2,3))) * hyp['opt']['scaling_factor']
|
||||
|
||||
if __name__ == "__main__":
|
||||
# *** dataset ***
|
||||
X_train, Y_train, X_test, Y_test = nn.datasets.cifar()
|
||||
# TODO: without this line indexing doesn't fuse!
|
||||
X_train, Y_train, X_test, Y_test = [x.contiguous() for x in [X_train, Y_train, X_test, Y_test]]
|
||||
cifar10_std, cifar10_mean = X_train.float().std_mean(axis=(0, 2, 3))
|
||||
def preprocess(X:Tensor, Y:Tensor) -> Tuple[Tensor, Tensor]:
|
||||
return ((X - cifar10_mean.view(1, -1, 1, 1)) / cifar10_std.view(1, -1, 1, 1)).cast(dtypes.default_float), Y.one_hot(depths['num_classes'])
|
||||
|
||||
# *** model ***
|
||||
model = SpeedyConvNet()
|
||||
state_dict = nn.state.get_state_dict(model)
|
||||
|
||||
#for k,v in nn.state.torch_load("/tmp/cifar_net.pt").items(): print(k)
|
||||
|
||||
params_bias, params_non_bias = partition(state_dict.items(), lambda x: 'bias' in x[0])
|
||||
opt_bias = nn.optim.SGD([x[1] for x in params_bias], lr=0.01, momentum=.85, nesterov=True, weight_decay=hyp['opt']['bias_decay'])
|
||||
opt_non_bias = nn.optim.SGD([x[1] for x in params_non_bias], lr=0.01, momentum=.85, nesterov=True, weight_decay=hyp['opt']['non_bias_decay'])
|
||||
opt = nn.optim.OptimizerGroup(opt_bias, opt_non_bias)
|
||||
|
||||
num_steps_per_epoch = X_train.size(0) // batchsize
|
||||
total_train_steps = math.ceil(num_steps_per_epoch * hyp['misc']['train_epochs'])
|
||||
loss_batchsize_scaler = 512/batchsize
|
||||
|
||||
pct_start = hyp['opt']['percent_start']
|
||||
initial_div_factor = 1e16 # basically to make the initial lr ~0 or so :D
|
||||
final_lr_ratio = .07 # Actually pretty important, apparently!
|
||||
lr_sched_bias = OneCycleLR(opt_bias, max_lr=hyp['opt']['bias_lr'], pct_start=pct_start, div_factor=initial_div_factor, final_div_factor=1./(initial_div_factor*final_lr_ratio), total_steps=total_train_steps)
|
||||
lr_sched_non_bias = OneCycleLR(opt_non_bias, max_lr=hyp['opt']['non_bias_lr'], pct_start=pct_start, div_factor=initial_div_factor, final_div_factor=1./(initial_div_factor*final_lr_ratio), total_steps=total_train_steps)
|
||||
|
||||
def loss_fn(out, Y):
|
||||
return out.cross_entropy(Y, reduction='none', label_smoothing=0.2).mul(hyp['opt']['loss_scale_scaler']*loss_batchsize_scaler).sum().div(hyp['opt']['loss_scale_scaler'])
|
||||
|
||||
@TinyJit
|
||||
@Tensor.train()
|
||||
def train_step(idxs:Tensor) -> Tensor:
|
||||
with Context(SPLIT_REDUCEOP=0, FUSE_ARANGE=1):
|
||||
X = X_train[idxs]
|
||||
Y = Y_train[idxs].realize(X)
|
||||
X, Y = preprocess(X, Y)
|
||||
out = model(X)
|
||||
loss = loss_fn(out, Y)
|
||||
opt.zero_grad()
|
||||
loss.backward()
|
||||
opt.step()
|
||||
lr_sched_bias.step()
|
||||
lr_sched_non_bias.step()
|
||||
return loss / (batchsize*loss_batchsize_scaler)
|
||||
|
||||
eval_batchsize = 2500
|
||||
@TinyJit
|
||||
@Tensor.test()
|
||||
def val_step() -> Tensor:
|
||||
# TODO with Tensor.no_grad()
|
||||
Tensor.no_grad = True
|
||||
loss, acc = [], []
|
||||
for i in range(0, X_test.size(0), eval_batchsize):
|
||||
X, Y = preprocess(X_test[i:i+eval_batchsize], Y_test[i:i+eval_batchsize])
|
||||
out = model(X)
|
||||
loss.append(loss_fn(out, Y))
|
||||
acc.append((out.argmax(-1).one_hot(depths['num_classes']) * Y).sum() / eval_batchsize)
|
||||
ret = Tensor.stack(*loss).mean() / (batchsize*loss_batchsize_scaler), Tensor.stack(*acc).mean()
|
||||
Tensor.no_grad = False
|
||||
return ret
|
||||
|
||||
np.random.seed(1337)
|
||||
for epoch in range(math.ceil(hyp['misc']['train_epochs'])):
|
||||
# TODO: move to tinygrad
|
||||
gst = time.perf_counter()
|
||||
idxs = np.arange(X_train.shape[0])
|
||||
np.random.shuffle(idxs)
|
||||
tidxs = Tensor(idxs, dtype='int')[:num_steps_per_epoch*batchsize].reshape(num_steps_per_epoch, batchsize) # NOTE: long doesn't fold
|
||||
train_loss = 0
|
||||
for epoch_step in (t:=trange(num_steps_per_epoch)):
|
||||
st = time.perf_counter()
|
||||
GlobalCounters.reset()
|
||||
loss = train_step(tidxs[epoch_step].contiguous()).float().item()
|
||||
t.set_description(f"*** loss: {loss:5.3f} lr: {opt_non_bias.lr.item():.6f}"
|
||||
f" tm: {(et:=(time.perf_counter()-st))*1000:6.2f} ms {GlobalCounters.global_ops/(1e9*et):7.0f} GFLOPS")
|
||||
train_loss += loss
|
||||
gmt = time.perf_counter()
|
||||
GlobalCounters.reset()
|
||||
val_loss, acc = [x.float().item() for x in val_step()]
|
||||
get = time.perf_counter()
|
||||
print(f"\033[F*** epoch {epoch:3d} tm: {(gmt-gst):5.2f} s val_tm: {(get-gmt):5.2f} s train_loss: {train_loss/num_steps_per_epoch:5.3f} val_loss: {val_loss:5.3f} eval acc: {acc*100:5.2f}% @ {get-start_tm:6.2f} s ")
|
|
@ -9,5 +9,9 @@ class TestCompileFailures(unittest.TestCase):
|
|||
def test_interpolate_atari(self):
|
||||
self.compile(Tensor.empty(210, 160, dtype='uint8').interpolate((64, 64)))
|
||||
|
||||
@unittest.skip("FIXME: broken on METAL")
|
||||
def test_add_max_uchar(self):
|
||||
self.compile((Tensor.empty(1024, dtype='uint8') + Tensor.empty(1024, dtype='uint8')).max())
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
|
@ -34,7 +34,7 @@ def get_kernel(renderer:Renderer, ast:UOp) -> Kernel:
|
|||
lins: List[Tuple[str, Kernel]] = [(f"beam{BEAM.value}", k), (("tc" if used_tensor_cores else "hc"), k_opt)]
|
||||
if used_tensor_cores: lins.append(("hc", Kernel(ast, opts=renderer).hand_coded_optimizations()))
|
||||
timed = sorted([(nm, tk, time_linearizer(tk, rawbufs, allow_test_size=False, clear_l2=True)) for nm, tk in lins], key=lambda x: x[2])
|
||||
if DEBUG >= 1: print(" < ".join(f"{nm:6s} : {lin.colored_shape(30, dense=True)} : {tm*1e6:8.2f} us" for nm, lin, tm in timed))
|
||||
if DEBUG >= 3: print(" < ".join(f"{nm:6s} : {lin.colored_shape(30, dense=True)} : {tm*1e6:8.2f} us" for nm, lin, tm in timed))
|
||||
k = timed[0][1]
|
||||
if logkerns is not None and logkerns_level > 1: logkerns.writelines([f"{(lin.ast, lin.applied_opts)}\n" for (_,lin,_) in timed[1:]])
|
||||
if beam_compare == 2:
|
||||
|
|
Loading…
Reference in New Issue