CIFAR trainer + various bugfixes / improvements (#6146)

* move cifar into datasets

* support for pathlib Tensors, tar_extract, and fetch gunzip

* too early for Device.DEFAULT

* simpler hlb_cifar + .to(None) is default

* new compiler failure, start beautiful_cifar

* beautiful cifar runs but is broken

* jit train step

* cleaner

* std_mean, not mean_std

* more correct

* fast indexing

* don't print that

* torch load broken

* add eval

* nicer bar

* decoraters are the way to do this

* bounds check the load

* a few ops

* batchnorm bugfix, if track_running_stats is False, use online estimate

* full timing

* fix fusion

* unneeded realize

* master tensor
This commit is contained in:
George Hotz 2024-08-20 16:58:46 -07:00 committed by GitHub
parent 296368f0dd
commit 9faf205601
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 173 additions and 1 deletions

168
examples/beautiful_cifar.py Normal file
View File

@ -0,0 +1,168 @@
import time
start_tm = time.perf_counter()
import math
from typing import Tuple
import numpy as np
from tinygrad import Tensor, nn, GlobalCounters, TinyJit, dtypes
from tinygrad.helpers import partition, trange, getenv, Context
from extra.lr_scheduler import OneCycleLR
dtypes.default_float = dtypes.half
# from https://github.com/tysam-code/hlb-CIFAR10/blob/main/main.py
batchsize = getenv("BS", 1024)
bias_scaler = 64
hyp = {
'opt': {
'bias_lr': 1.525 * bias_scaler/512, # TODO: Is there maybe a better way to express the bias and batchnorm scaling? :'))))
'non_bias_lr': 1.525 / 512,
'bias_decay': 6.687e-4 * batchsize/bias_scaler,
'non_bias_decay': 6.687e-4 * batchsize,
'scaling_factor': 1./9,
'percent_start': .23,
'loss_scale_scaler': 1./32, # * Regularizer inside the loss summing (range: ~1/512 - 16+). FP8 should help with this somewhat too, whenever it comes out. :)
},
'net': {
'whitening': {
'kernel_size': 2,
'num_examples': 50000,
},
'batch_norm_momentum': .4, # * Don't forget momentum is 1 - momentum here (due to a quirk in the original paper... >:( )
'cutmix_size': 3,
'cutmix_epochs': 6,
'pad_amount': 2,
'base_depth': 64 ## This should be a factor of 8 in some way to stay tensor core friendly
},
'misc': {
'ema': {
'epochs': 10, # Slight bug in that this counts only full epochs and then additionally runs the EMA for any fractional epochs at the end too
'decay_base': .95,
'decay_pow': 3.,
'every_n_steps': 5,
},
'train_epochs': 12,
#'train_epochs': 12.1,
'device': 'cuda',
'data_location': 'data.pt',
}
}
scaler = 2. ## You can play with this on your own if you want, for the first beta I wanted to keep things simple (for now) and leave it out of the hyperparams dict
depths = {
'init': round(scaler**-1*hyp['net']['base_depth']), # 32 w/ scaler at base value
'block1': round(scaler** 0*hyp['net']['base_depth']), # 64 w/ scaler at base value
'block2': round(scaler** 2*hyp['net']['base_depth']), # 256 w/ scaler at base value
'block3': round(scaler** 3*hyp['net']['base_depth']), # 512 w/ scaler at base value
'num_classes': 10
}
whiten_conv_depth = 3*hyp['net']['whitening']['kernel_size']**2
class ConvGroup:
def __init__(self, channels_in, channels_out):
self.conv1 = nn.Conv2d(channels_in, channels_out, kernel_size=3, padding=1, bias=False)
self.conv2 = nn.Conv2d(channels_out, channels_out, kernel_size=3, padding=1, bias=False)
self.norm1 = nn.BatchNorm(channels_out, track_running_stats=False, eps=1e-12, momentum=hyp['net']['batch_norm_momentum'])
self.norm2 = nn.BatchNorm(channels_out, track_running_stats=False, eps=1e-12, momentum=hyp['net']['batch_norm_momentum'])
self.norm1.weight.requires_grad = False
self.norm2.weight.requires_grad = False
def __call__(self, x:Tensor) -> Tensor:
x = self.norm1(self.conv1(x).max_pool2d().float()).cast(dtypes.default_float).quick_gelu()
return self.norm2(self.conv2(x).float()).cast(dtypes.default_float).quick_gelu()
class SpeedyConvNet:
def __init__(self):
self.whiten = nn.Conv2d(3, 2*whiten_conv_depth, kernel_size=hyp['net']['whitening']['kernel_size'], padding=0, bias=False)
self.conv_group_1 = ConvGroup(2*whiten_conv_depth, depths['block1'])
self.conv_group_2 = ConvGroup(depths['block1'], depths['block2'])
self.conv_group_3 = ConvGroup(depths['block2'], depths['block3'])
self.linear = nn.Linear(depths['block3'], depths['num_classes'], bias=False)
def __call__(self, x:Tensor) -> Tensor:
x = self.whiten(x).quick_gelu()
x = x.sequential([self.conv_group_1, self.conv_group_2, self.conv_group_3])
return self.linear(x.max(axis=(2,3))) * hyp['opt']['scaling_factor']
if __name__ == "__main__":
# *** dataset ***
X_train, Y_train, X_test, Y_test = nn.datasets.cifar()
# TODO: without this line indexing doesn't fuse!
X_train, Y_train, X_test, Y_test = [x.contiguous() for x in [X_train, Y_train, X_test, Y_test]]
cifar10_std, cifar10_mean = X_train.float().std_mean(axis=(0, 2, 3))
def preprocess(X:Tensor, Y:Tensor) -> Tuple[Tensor, Tensor]:
return ((X - cifar10_mean.view(1, -1, 1, 1)) / cifar10_std.view(1, -1, 1, 1)).cast(dtypes.default_float), Y.one_hot(depths['num_classes'])
# *** model ***
model = SpeedyConvNet()
state_dict = nn.state.get_state_dict(model)
#for k,v in nn.state.torch_load("/tmp/cifar_net.pt").items(): print(k)
params_bias, params_non_bias = partition(state_dict.items(), lambda x: 'bias' in x[0])
opt_bias = nn.optim.SGD([x[1] for x in params_bias], lr=0.01, momentum=.85, nesterov=True, weight_decay=hyp['opt']['bias_decay'])
opt_non_bias = nn.optim.SGD([x[1] for x in params_non_bias], lr=0.01, momentum=.85, nesterov=True, weight_decay=hyp['opt']['non_bias_decay'])
opt = nn.optim.OptimizerGroup(opt_bias, opt_non_bias)
num_steps_per_epoch = X_train.size(0) // batchsize
total_train_steps = math.ceil(num_steps_per_epoch * hyp['misc']['train_epochs'])
loss_batchsize_scaler = 512/batchsize
pct_start = hyp['opt']['percent_start']
initial_div_factor = 1e16 # basically to make the initial lr ~0 or so :D
final_lr_ratio = .07 # Actually pretty important, apparently!
lr_sched_bias = OneCycleLR(opt_bias, max_lr=hyp['opt']['bias_lr'], pct_start=pct_start, div_factor=initial_div_factor, final_div_factor=1./(initial_div_factor*final_lr_ratio), total_steps=total_train_steps)
lr_sched_non_bias = OneCycleLR(opt_non_bias, max_lr=hyp['opt']['non_bias_lr'], pct_start=pct_start, div_factor=initial_div_factor, final_div_factor=1./(initial_div_factor*final_lr_ratio), total_steps=total_train_steps)
def loss_fn(out, Y):
return out.cross_entropy(Y, reduction='none', label_smoothing=0.2).mul(hyp['opt']['loss_scale_scaler']*loss_batchsize_scaler).sum().div(hyp['opt']['loss_scale_scaler'])
@TinyJit
@Tensor.train()
def train_step(idxs:Tensor) -> Tensor:
with Context(SPLIT_REDUCEOP=0, FUSE_ARANGE=1):
X = X_train[idxs]
Y = Y_train[idxs].realize(X)
X, Y = preprocess(X, Y)
out = model(X)
loss = loss_fn(out, Y)
opt.zero_grad()
loss.backward()
opt.step()
lr_sched_bias.step()
lr_sched_non_bias.step()
return loss / (batchsize*loss_batchsize_scaler)
eval_batchsize = 2500
@TinyJit
@Tensor.test()
def val_step() -> Tensor:
# TODO with Tensor.no_grad()
Tensor.no_grad = True
loss, acc = [], []
for i in range(0, X_test.size(0), eval_batchsize):
X, Y = preprocess(X_test[i:i+eval_batchsize], Y_test[i:i+eval_batchsize])
out = model(X)
loss.append(loss_fn(out, Y))
acc.append((out.argmax(-1).one_hot(depths['num_classes']) * Y).sum() / eval_batchsize)
ret = Tensor.stack(*loss).mean() / (batchsize*loss_batchsize_scaler), Tensor.stack(*acc).mean()
Tensor.no_grad = False
return ret
np.random.seed(1337)
for epoch in range(math.ceil(hyp['misc']['train_epochs'])):
# TODO: move to tinygrad
gst = time.perf_counter()
idxs = np.arange(X_train.shape[0])
np.random.shuffle(idxs)
tidxs = Tensor(idxs, dtype='int')[:num_steps_per_epoch*batchsize].reshape(num_steps_per_epoch, batchsize) # NOTE: long doesn't fold
train_loss = 0
for epoch_step in (t:=trange(num_steps_per_epoch)):
st = time.perf_counter()
GlobalCounters.reset()
loss = train_step(tidxs[epoch_step].contiguous()).float().item()
t.set_description(f"*** loss: {loss:5.3f} lr: {opt_non_bias.lr.item():.6f}"
f" tm: {(et:=(time.perf_counter()-st))*1000:6.2f} ms {GlobalCounters.global_ops/(1e9*et):7.0f} GFLOPS")
train_loss += loss
gmt = time.perf_counter()
GlobalCounters.reset()
val_loss, acc = [x.float().item() for x in val_step()]
get = time.perf_counter()
print(f"\033[F*** epoch {epoch:3d} tm: {(gmt-gst):5.2f} s val_tm: {(get-gmt):5.2f} s train_loss: {train_loss/num_steps_per_epoch:5.3f} val_loss: {val_loss:5.3f} eval acc: {acc*100:5.2f}% @ {get-start_tm:6.2f} s ")

View File

@ -9,5 +9,9 @@ class TestCompileFailures(unittest.TestCase):
def test_interpolate_atari(self):
self.compile(Tensor.empty(210, 160, dtype='uint8').interpolate((64, 64)))
@unittest.skip("FIXME: broken on METAL")
def test_add_max_uchar(self):
self.compile((Tensor.empty(1024, dtype='uint8') + Tensor.empty(1024, dtype='uint8')).max())
if __name__ == '__main__':
unittest.main()

View File

@ -34,7 +34,7 @@ def get_kernel(renderer:Renderer, ast:UOp) -> Kernel:
lins: List[Tuple[str, Kernel]] = [(f"beam{BEAM.value}", k), (("tc" if used_tensor_cores else "hc"), k_opt)]
if used_tensor_cores: lins.append(("hc", Kernel(ast, opts=renderer).hand_coded_optimizations()))
timed = sorted([(nm, tk, time_linearizer(tk, rawbufs, allow_test_size=False, clear_l2=True)) for nm, tk in lins], key=lambda x: x[2])
if DEBUG >= 1: print(" < ".join(f"{nm:6s} : {lin.colored_shape(30, dense=True)} : {tm*1e6:8.2f} us" for nm, lin, tm in timed))
if DEBUG >= 3: print(" < ".join(f"{nm:6s} : {lin.colored_shape(30, dense=True)} : {tm*1e6:8.2f} us" for nm, lin, tm in timed))
k = timed[0][1]
if logkerns is not None and logkerns_level > 1: logkerns.writelines([f"{(lin.ast, lin.applied_opts)}\n" for (_,lin,_) in timed[1:]])
if beam_compare == 2: