mirror of https://github.com/commaai/tinygrad.git
fix CIFAR jit (#1657)
* update mask function
* kept 94 with the new fetcher; clean up batch fetcher
* 94.04% without cutmix
* 94.04% with cutmix
* move batch fetcher to avoid fetching additional batch last STEP
This commit is contained in:
parent
f00325e77d
commit
173850f599
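The central idea behind the JIT fix in the diff below is that every batch handed to the jitted train step must keep the same shape, so the last, shorter batch is filled by re-using the trailing BS samples instead of being yielded short. A minimal NumPy sketch of that batching trick (the function name and data are illustrative, not from the repo):

    import numpy as np

    def fetch_fixed_batches(X, Y, BS):
      # always yield exactly BS rows; the final batch re-uses the tail of the
      # shuffled order, so downstream (jitted) buffers never change shape
      order = np.random.permutation(X.shape[0])
      for i in range(0, X.shape[0], BS):
        batch_end = min(i + BS, X.shape[0])
        idx = order[batch_end - BS:batch_end]  # always BS indices
        yield X[idx], Y[idx]

    X, Y = np.arange(10).reshape(10, 1), np.arange(10)
    for x, y in fetch_fixed_batches(X, Y, BS=4):
      print(x.shape)  # every batch is (4, 1); the last one overlaps the previous batch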
@@ -42,7 +42,7 @@ hyp = {
     'kernel_size': 2, # kernel size for the whitening layer
     'batch_norm_momentum': .5,
     'cutmix_size': 3,
-    'cutmix_steps': 588, # original repo used epoch 6 which is roughly 6*98=588 STEPS
+    'cutmix_steps': 490, # different from original repo which used epoch > 12.1 - 6 which is roughly 7*98=686 STEPS
     'pad_amount': 2
   }
 }
@@ -143,14 +143,14 @@ def pad_reflect(X, size=2) -> Tensor:
   return X
 
 # return a binary mask in the format of BS x C x H x W where H x W contains a random square mask
-def make_square_mask(X, mask_size):
+def make_square_mask(shape, mask_size):
   is_even = int(mask_size % 2 == 0)
-  center_max = X.shape[-2]-mask_size//2-is_even
+  center_max = shape[-2]-mask_size//2-is_even
   center_min = mask_size//2-is_even
-  center = Tensor.rand(X.shape[0])*(center_max-center_min)+center_min
+  center = Tensor.rand(shape[0])*(center_max-center_min)+center_min
 
-  d_y = Tensor.arange(0, X.shape[-2]).reshape((1,1,X.shape[-2],1))
-  d_x = Tensor.arange(0, X.shape[-1]).reshape((1,1,1,X.shape[-1]))
+  d_y = Tensor.arange(0, shape[-2]).reshape((1,1,shape[-2],1))
+  d_x = Tensor.arange(0, shape[-1]).reshape((1,1,1,shape[-1]))
   d_y = d_y - center.reshape((-1,1,1,1))
   d_x = d_x - center.reshape((-1,1,1,1))
   d_y =(d_y >= -(mask_size / 2)) * (d_y <= mask_size / 2)
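For reference, the updated make_square_mask only needs the tensor's shape: it draws a random square center per image and compares broadcast row/column index grids against it. A rough NumPy equivalent of that idea (an illustrative sketch, not the repo's code):

    import numpy as np

    def square_mask_np(shape, mask_size):
      # shape is (BS, C, H, W); one random square of side mask_size per image
      bs, _, h, w = shape
      center_y = np.random.randint(mask_size//2, h - mask_size//2, size=(bs, 1, 1, 1))
      center_x = np.random.randint(mask_size//2, w - mask_size//2, size=(bs, 1, 1, 1))
      d_y = np.arange(h).reshape(1, 1, h, 1) - center_y  # broadcasts to (BS, 1, H, 1)
      d_x = np.arange(w).reshape(1, 1, 1, w) - center_x  # broadcasts to (BS, 1, 1, W)
      in_y = (d_y >= -(mask_size / 2)) * (d_y <= mask_size / 2)
      in_x = (d_x >= -(mask_size / 2)) * (d_x <= mask_size / 2)
      return in_y * in_x  # (BS, 1, H, W) boolean square mask

    mask = square_mask_np((2, 3, 8, 8), mask_size=3)
    print(mask.shape, mask[0, 0].sum())  # (2, 1, 8, 8), mask_size**2 ones per image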
@@ -160,34 +160,15 @@ def make_square_mask(X, mask_size):
   return mask
 
 def random_crop(X, crop_size=32):
-  mask = make_square_mask(X, crop_size)
+  mask = make_square_mask(X.shape, crop_size)
   mask = mask.repeat((1,3,1,1))
   X_cropped = Tensor(X.flatten().numpy()[mask.flatten().numpy().astype(bool)])
 
   return X_cropped.reshape((-1, 3, crop_size, crop_size))
 
-transform = [
-  lambda x: x.to(device=Device.DEFAULT).float(),
-  lambda x: x / 255.0, # scale
-  lambda x: (x - Tensor(cifar_mean).repeat((1024,1)).T.reshape(1,-1))/ Tensor(cifar_std).repeat((1024,1)).T.reshape(1,-1), # normalize
-  lambda x: x.reshape((-1,3,32,32)),
-  lambda x: pad_reflect(x, size=hyp['net']['pad_amount']),
-  lambda x: random_crop(x, crop_size=32),
-  lambda x: Tensor.where(Tensor.rand(x.shape[0],1,1,1) < 0.5, x[..., ::-1], x), # flip LR
-]
-
-transform_test = [
-  lambda x: x.to(device=Device.DEFAULT).float(),
-  lambda x: x / 255.0,
-  lambda x: (x - Tensor(cifar_mean).repeat((1024,1)).T.reshape(1,-1))/ Tensor(cifar_std).repeat((1024,1)).T.reshape(1,-1),
-  lambda x: x.reshape((-1,3,32,32)),
-]
-
-def cutmix(X, Y, mask_size=3, p=0.5):
-  if Tensor.rand(1) > p: return X, Y
-
+def cutmix(X, Y, mask_size=3):
   # fill the square with randomly selected images from the same batch
-  mask = make_square_mask(X, mask_size)
+  mask = make_square_mask(X.shape, mask_size)
   order = list(range(0, X.shape[0]))
   random.shuffle(order)
   X_patch = Tensor(X.numpy()[order,...])
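Cutmix, as kept here, pastes a square patch from another image in the same batch and mixes the one-hot labels by the pasted fraction, as in the Y_cutmix line further down the diff. A self-contained NumPy sketch of that mixing step (the names and the fixed demo mask are illustrative):

    import numpy as np

    def cutmix_np(X, Y_onehot, mask):
      # X: (BS, C, H, W), Y_onehot: (BS, 10), mask: (BS, 1, H, W) boolean square
      order = np.random.permutation(X.shape[0])     # donor images from the same batch
      X_patch, Y_patch = X[order], Y_onehot[order]
      X_cutmix = np.where(mask, X_patch, X)         # paste the donor's square
      mix_portion = mask.reshape(mask.shape[0], -1).mean(axis=1, keepdims=True)
      Y_cutmix = mix_portion * Y_patch + (1. - mix_portion) * Y_onehot
      return X_cutmix, Y_cutmix

    X = np.random.rand(4, 3, 8, 8)
    Y = np.eye(10)[np.random.randint(0, 10, size=4)]
    mask = np.zeros((4, 1, 8, 8), dtype=bool)
    mask[:, :, 2:5, 2:5] = True                     # a fixed 3x3 square for the demo
    Xc, Yc = cutmix_np(X, Y, mask)
    print(Xc.shape, Yc.sum(axis=1))                 # label rows still sum to 1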
@@ -197,23 +178,36 @@ def cutmix(X, Y, mask_size=3, p=0.5):
   Y_cutmix = mix_portion * Y_patch + (1. - mix_portion) * Y
   return X_cutmix, Y_cutmix
 
-def fetch_batches(X, Y, BS, seed, is_train=False):
+# the operations that remain inside batch fetcher is the ones that involves random operations
+def fetch_batches(X_in, Y_in, BS, seed, is_train):
+  step = 0
   while True:
     set_seed(seed)
+    X, Y = X_in, Y_in
     order = list(range(0, X.shape[0]))
     random.shuffle(order)
+    if is_train:
+      X = random_crop(X, crop_size=32)
+      X = Tensor.where(Tensor.rand(X.shape[0],1,1,1) < 0.5, X[..., ::-1], X) # flip LR
+      if step >= hyp['net']['cutmix_steps']: X, Y = cutmix(X, Y, mask_size=hyp['net']['cutmix_size'])
+    X, Y = X.numpy(), Y.numpy()
     for i in range(0, X.shape[0], BS):
-      # padding the last batch in order to match buffer size during JIT
+      # pad the last batch
       batch_end = min(i+BS, Y.shape[0])
-      # TODO need indexing support for tinygrad Tensor
-      x = Tensor(X.numpy()[order[batch_end-BS:batch_end],:])
-      y = Tensor(np.eye(10, dtype=np.float32)[Y.numpy()[order[batch_end-BS:batch_end]]])
-      x = x.sequential(transform) if is_train else x.sequential(transform_test)
+      x = Tensor(X[order[batch_end-BS:batch_end],:])
+      y = Tensor(Y[order[batch_end-BS:batch_end]])
+      step += 1
       yield x, y
 
     if not is_train: break
     seed += 1
 
+transform = [
+  lambda x: x / 255.0,
+  lambda x: (x - Tensor(cifar_mean).repeat((1024,1)).T.reshape(1,-1))/ Tensor(cifar_std).repeat((1024,1)).T.reshape(1,-1),
+  lambda x: x.reshape((-1,3,32,32))
+]
+
 def train_cifar(bs=BS, eval_bs=EVAL_BS, steps=STEPS, seed=32):
   # this import needs to be done here because this is running in a subprocess
   from extra.dist import OOB
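The new, slimmed-down transform list is applied once up front via x.sequential(transform), which runs the input through each callable in order. A stand-in plain-NumPy version of that pipeline pattern (the mean/std values here are the usual approximate CIFAR-10 statistics, not copied from the repo):

    from functools import reduce
    import numpy as np

    cifar_mean, cifar_std = [0.4914, 0.4822, 0.4465], [0.2470, 0.2435, 0.2616]  # approximate
    transform_np = [
      lambda x: x / 255.0,
      lambda x: (x - np.repeat(cifar_mean, 1024).reshape(1, -1)) / np.repeat(cifar_std, 1024).reshape(1, -1),
      lambda x: x.reshape((-1, 3, 32, 32)),
    ]

    def sequential(x, fns):
      # fold the input through each function, like Tensor.sequential(transform)
      return reduce(lambda acc, f: f(acc), fns, x)

    X = np.random.randint(0, 256, size=(8, 3*32*32)).astype(np.float32)
    print(sequential(X, transform_np).shape)  # (8, 3, 32, 32)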
@@ -228,9 +222,19 @@ def train_cifar(bs=BS, eval_bs=EVAL_BS, steps=STEPS, seed=32):
     X_test, Y_test = X_train, Y_train
   else:
     X_train, Y_train, X_test, Y_test = fetch_cifar()
+  # load data and label into GPU and convert to dtype accordingly
+  X_train, X_test = X_train.to(device=Device.DEFAULT).float(), X_test.to(device=Device.DEFAULT).float()
+  Y_train, Y_test = Y_train.to(device=Device.DEFAULT).float(), Y_test.to(device=Device.DEFAULT).float()
+  # one-hot encode labels
+  Y_train, Y_test = Tensor.eye(10)[Y_train], Tensor.eye(10)[Y_test]
+  # preprocess data
+  X_train, X_test = X_train.sequential(transform), X_test.sequential(transform)
 
   # precompute whitening patches
-  W = whitening(X_train.sequential(transform_test))
+  W = whitening(X_train)
+
+  # padding is not timed in the original repo since it can be done all at once
+  X_train = pad_reflect(X_train, size=hyp['net']['pad_amount'])
 
   model = SpeedyResNet(W)
 
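The per-batch np.eye(10, ...) label lookup that the old batch fetcher did is replaced by the single Tensor.eye(10)[Y_train] encoding added above. The same identity-matrix indexing trick in plain NumPy, for reference:

    import numpy as np

    labels = np.array([3, 0, 7, 7])
    one_hot = np.eye(10, dtype=np.float32)[labels]  # row i of the identity = one-hot for class i
    print(one_hot.shape)            # (4, 10)
    print(one_hot.argmax(axis=1))   # [3 0 7 7] -- round-trips the labels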
@@ -250,7 +254,7 @@ def train_cifar(bs=BS, eval_bs=EVAL_BS, steps=STEPS, seed=32):
 
   # NOTE taken from the hlb_CIFAR repository, might need to be tuned
   initial_div_factor = 1e16
-  final_lr_ratio = 0.022
+  final_lr_ratio = 0.02199
   pct_start = hyp['opt']['percent_start']
   lr_sched_bias = OneCycleLR(opt_bias, max_lr=hyp['opt']['bias_lr'] ,pct_start=pct_start, div_factor=initial_div_factor, final_div_factor=1./(initial_div_factor*final_lr_ratio), total_steps=STEPS)
   lr_sched_non_bias = OneCycleLR(opt_non_bias, max_lr=hyp['opt']['non_bias_lr'] ,pct_start=pct_start, div_factor=initial_div_factor, final_div_factor=1./(initial_div_factor*final_lr_ratio), total_steps=STEPS)
@@ -305,17 +309,11 @@ def train_cifar(bs=BS, eval_bs=EVAL_BS, steps=STEPS, seed=32):
   i = 0
   batcher = fetch_batches(X_train, Y_train, BS=BS, seed=seed, is_train=True)
   while i <= STEPS:
-    X, Y = next(batcher)
-    if i >= hyp['net']['cutmix_steps']: X, Y = cutmix(X, Y, mask_size=hyp['net']['cutmix_size'])
-    # further split batch if distributed
-    if getenv("DIST"):
-      X, Y = X.chunk(world_size, 0)[rank], Y.chunk(world_size, 0)[rank]
-
     if i%100 == 0 and i > 1:
       # Use Tensor.training = False here actually bricks batchnorm, even with track_running_stats=True
       corrects = []
       losses = []
-      for Xt, Yt in fetch_batches(X_test, Y_test, BS=EVAL_BS, seed=seed):
+      for Xt, Yt in fetch_batches(X_test, Y_test, BS=EVAL_BS, seed=seed, is_train=False):
         # further split batch if distributed
         if getenv("DIST"):
           Xt, Yt = Xt.chunk(min(world_size, 5), 0)[min(rank, 4)], Yt.chunk(min(world_size, 5), 0)[min(rank, 4)]
@@ -342,6 +340,10 @@ def train_cifar(bs=BS, eval_bs=EVAL_BS, steps=STEPS, seed=32):
         best_eval = acc
         print(f"eval {correct_sum}/{correct_len} {acc:.2f}%, {(sum(losses)/len(losses)):7.2f} val_loss STEP={i}")
-
+    if STEPS == 0 or i==STEPS: break
+    X, Y = next(batcher)
+    # further split batch if distributed
+    if getenv("DIST"):
+      X, Y = X.chunk(world_size, 0)[rank], Y.chunk(world_size, 0)[rank]
     GlobalCounters.reset()
     st = time.monotonic()
     loss = train_step_jitted(model, [opt_bias, opt_non_bias], [lr_sched_bias, lr_sched_non_bias], X, Y)
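The reordered loop above is what the last commit-message bullet refers to: the stop condition is now checked before pulling the next batch, so the generator is not asked for one more batch than the loop will use. A toy demonstration of why the ordering matters (the counts are illustrative):

    def counting_batcher():
      # stand-in generator that reports how many batches were actually fetched
      n = 0
      while True:
        n += 1
        yield n

    STEPS = 3

    # fetch first, then advance the step counter
    b, i, fetched = counting_batcher(), 0, 0
    while i <= STEPS:
      fetched = next(b)
      i += 1
    print("fetch-then-check pulled", fetched, "batches")  # 4

    # check the step budget first, then fetch (the new order)
    b, i, fetched = counting_batcher(), 0, 0
    while i <= STEPS:
      if STEPS == 0 or i == STEPS: break
      fetched = next(b)
      i += 1
    print("check-then-fetch pulled", fetched, "batches")  # 3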