fix CIFAR jit (#1657)

* update mask function

* kept 94 with the new fetcher

clean up batch fetcher

* 94.04% without cutmix

* 94.04% with cutmix

* move batch fetcher to avoid fetching additional batch last STEP
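The loop change described in that last bullet is easiest to see in isolation. Below is a toy, self-contained sketch (the names fetch_batches, STEPS, and the loop body here are stand-ins, not the real script): the batch is now fetched only after the eval and termination checks, so the generator is never advanced for a step that will not run.

def fetch_batches():
  n = 0
  while True:
    print(f"fetched batch {n}")
    yield n
    n += 1

STEPS = 3
batcher = fetch_batches()
i = 0
while i <= STEPS:
  if i % 100 == 0 and i > 1:
    pass                          # the eval block runs here in the real script
  if STEPS == 0 or i == STEPS: break
  batch = next(batcher)           # moved below the break: nothing is fetched for the step that exits
  i += 1
# prints "fetched batch 0", "1", "2" - exactly STEPS batches; with the fetch at the top of the
# loop, a fourth, unused batch would have been pulled on the final iteration.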
Yixiang Gao 2023-08-24 19:14:40 -04:00 committed by GitHub
parent f00325e77d
commit 173850f599
1 changed file with 45 additions and 43 deletions

@@ -42,7 +42,7 @@ hyp = {
     'kernel_size': 2, # kernel size for the whitening layer
     'batch_norm_momentum': .5,
     'cutmix_size': 3,
-    'cutmix_steps': 588, # original repo used epoch 6 which is roughly 6*98=588 STEPS
+    'cutmix_steps': 490, # different from original repo which used epoch > 12.1 - 6 which is roughly 7*98=686 STEPS
     'pad_amount': 2
   }
 }
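For reference, the step counts quoted in both comments come from the same arithmetic; a quick check, assuming the 50000-image CIFAR-10 train set and the batch size of roughly 512 implied by the "98 steps per epoch" figure:

import math

steps_per_epoch = math.ceil(50000 / 512)  # ~98 steps per epoch
print(6 * steps_per_epoch)                # 588, the old setting
print(5 * steps_per_epoch)                # 490, the new setting
print(7 * steps_per_epoch)                # 686, the figure quoted in the new comment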
@@ -143,14 +143,14 @@ def pad_reflect(X, size=2) -> Tensor:
   return X

 # return a binary mask in the format of BS x C x H x W where H x W contains a random square mask
-def make_square_mask(X, mask_size):
+def make_square_mask(shape, mask_size):
   is_even = int(mask_size % 2 == 0)
-  center_max = X.shape[-2]-mask_size//2-is_even
+  center_max = shape[-2]-mask_size//2-is_even
   center_min = mask_size//2-is_even
-  center = Tensor.rand(X.shape[0])*(center_max-center_min)+center_min
-  d_y = Tensor.arange(0, X.shape[-2]).reshape((1,1,X.shape[-2],1))
-  d_x = Tensor.arange(0, X.shape[-1]).reshape((1,1,1,X.shape[-1]))
+  center = Tensor.rand(shape[0])*(center_max-center_min)+center_min
+  d_y = Tensor.arange(0, shape[-2]).reshape((1,1,shape[-2],1))
+  d_x = Tensor.arange(0, shape[-1]).reshape((1,1,1,shape[-1]))
   d_y = d_y - center.reshape((-1,1,1,1))
   d_x = d_x - center.reshape((-1,1,1,1))
   d_y =(d_y >= -(mask_size / 2)) * (d_y <= mask_size / 2)
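The mask only depends on the batch shape, which is what the new signature makes explicit. A standalone numpy rendering of the same broadcasting trick, for illustration only (this is not the tinygrad code above):

import numpy as np

def square_mask_np(shape, mask_size, seed=0):
  BS, _, H, W = shape
  rng = np.random.default_rng(seed)
  is_even = int(mask_size % 2 == 0)
  center_min = mask_size//2 - is_even
  center_max = H - mask_size//2 - is_even
  center = rng.random(BS)*(center_max - center_min) + center_min   # one square center per image
  d_y = np.arange(H).reshape(1,1,H,1) - center.reshape(-1,1,1,1)   # offset from center along H
  d_x = np.arange(W).reshape(1,1,1,W) - center.reshape(-1,1,1,1)   # offset from center along W
  in_y = (d_y >= -(mask_size/2)) & (d_y <= mask_size/2)
  in_x = (d_x >= -(mask_size/2)) & (d_x <= mask_size/2)
  return in_y & in_x                                               # BS x 1 x H x W boolean mask

mask = square_mask_np((4, 3, 36, 36), mask_size=32)
print(mask.shape, mask[0,0].sum())  # (4, 1, 36, 36); roughly a 32x32 block of True per image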
@@ -160,34 +160,15 @@ def make_square_mask(X, mask_size):
   return mask

 def random_crop(X, crop_size=32):
-  mask = make_square_mask(X, crop_size)
+  mask = make_square_mask(X.shape, crop_size)
   mask = mask.repeat((1,3,1,1))
   X_cropped = Tensor(X.flatten().numpy()[mask.flatten().numpy().astype(bool)])
   return X_cropped.reshape((-1, 3, crop_size, crop_size))

-transform = [
-  lambda x: x.to(device=Device.DEFAULT).float(),
-  lambda x: x / 255.0, # scale
-  lambda x: (x - Tensor(cifar_mean).repeat((1024,1)).T.reshape(1,-1))/ Tensor(cifar_std).repeat((1024,1)).T.reshape(1,-1), # normalize
-  lambda x: x.reshape((-1,3,32,32)),
-  lambda x: pad_reflect(x, size=hyp['net']['pad_amount']),
-  lambda x: random_crop(x, crop_size=32),
-  lambda x: Tensor.where(Tensor.rand(x.shape[0],1,1,1) < 0.5, x[..., ::-1], x), # flip LR
-]
-
-transform_test = [
-  lambda x: x.to(device=Device.DEFAULT).float(),
-  lambda x: x / 255.0,
-  lambda x: (x - Tensor(cifar_mean).repeat((1024,1)).T.reshape(1,-1))/ Tensor(cifar_std).repeat((1024,1)).T.reshape(1,-1),
-  lambda x: x.reshape((-1,3,32,32)),
-]
-
-def cutmix(X, Y, mask_size=3, p=0.5):
-  if Tensor.rand(1) > p: return X, Y
+def cutmix(X, Y, mask_size=3):
   # fill the square with randomly selected images from the same batch
-  mask = make_square_mask(X, mask_size)
+  mask = make_square_mask(X.shape, mask_size)
   order = list(range(0, X.shape[0]))
   random.shuffle(order)
   X_patch = Tensor(X.numpy()[order,...])
@@ -197,23 +178,36 @@ def cutmix(X, Y, mask_size=3, p=0.5):
   Y_cutmix = mix_portion * Y_patch + (1. - mix_portion) * Y
   return X_cutmix, Y_cutmix

-def fetch_batches(X, Y, BS, seed, is_train=False):
+# the operations that remain inside batch fetcher is the ones that involves random operations
+def fetch_batches(X_in, Y_in, BS, seed, is_train):
+  step = 0
   while True:
     set_seed(seed)
+    X, Y = X_in, Y_in
     order = list(range(0, X.shape[0]))
     random.shuffle(order)
+    if is_train:
+      X = random_crop(X, crop_size=32)
+      X = Tensor.where(Tensor.rand(X.shape[0],1,1,1) < 0.5, X[..., ::-1], X) # flip LR
+      if step >= hyp['net']['cutmix_steps']: X, Y = cutmix(X, Y, mask_size=hyp['net']['cutmix_size'])
+    X, Y = X.numpy(), Y.numpy()
     for i in range(0, X.shape[0], BS):
-      # padding the last batch in order to match buffer size during JIT
+      # pad the last batch
       batch_end = min(i+BS, Y.shape[0])
-      # TODO need indexing support for tinygrad Tensor
-      x = Tensor(X.numpy()[order[batch_end-BS:batch_end],:])
-      y = Tensor(np.eye(10, dtype=np.float32)[Y.numpy()[order[batch_end-BS:batch_end]]])
-      x = x.sequential(transform) if is_train else x.sequential(transform_test)
+      x = Tensor(X[order[batch_end-BS:batch_end],:])
+      y = Tensor(Y[order[batch_end-BS:batch_end]])
+      step += 1
       yield x, y
     if not is_train: break
     seed += 1

+transform = [
+  lambda x: x / 255.0,
+  lambda x: (x - Tensor(cifar_mean).repeat((1024,1)).T.reshape(1,-1))/ Tensor(cifar_std).repeat((1024,1)).T.reshape(1,-1),
+  lambda x: x.reshape((-1,3,32,32))
+]
+
 def train_cifar(bs=BS, eval_bs=EVAL_BS, steps=STEPS, seed=32):
   # this import needs to be done here because this is running in a subprocess
   from extra.dist import OOB
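The "pad the last batch" slice above is what keeps buffer sizes constant for the JIT. A tiny self-contained example of the same indexing, with toy sizes rather than the real loader:

N, BS = 10, 4
order = list(range(N))
for i in range(0, N, BS):
  batch_end = min(i + BS, N)
  print(order[batch_end - BS:batch_end])
# [0, 1, 2, 3]
# [4, 5, 6, 7]
# [6, 7, 8, 9]   <- the short final batch is back-filled from the previous one, so every batch has exactly BS rows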
@@ -228,9 +222,19 @@ def train_cifar(bs=BS, eval_bs=EVAL_BS, steps=STEPS, seed=32):
     X_test, Y_test = X_train, Y_train
   else:
     X_train, Y_train, X_test, Y_test = fetch_cifar()

+  # load data and label into GPU and convert to dtype accordingly
+  X_train, X_test = X_train.to(device=Device.DEFAULT).float(), X_test.to(device=Device.DEFAULT).float()
+  Y_train, Y_test = Y_train.to(device=Device.DEFAULT).float(), Y_test.to(device=Device.DEFAULT).float()
+  # one-hot encode labels
+  Y_train, Y_test = Tensor.eye(10)[Y_train], Tensor.eye(10)[Y_test]
+  # preprocess data
+  X_train, X_test = X_train.sequential(transform), X_test.sequential(transform)
+
   # precompute whitening patches
-  W = whitening(X_train.sequential(transform_test))
+  W = whitening(X_train)
+  # padding is not timed in the original repo since it can be done all at once
+  X_train = pad_reflect(X_train, size=hyp['net']['pad_amount'])

   model = SpeedyResNet(W)
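The Tensor.eye(10)[Y] one-hot encoding added above is the same identity-matrix indexing trick the old fetcher did with np.eye; in plain numpy terms:

import numpy as np

Y = np.array([3, 0, 9])                  # integer class labels
print(np.eye(10, dtype=np.float32)[Y])   # each label selects one row of the identity matrix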
@@ -250,7 +254,7 @@ def train_cifar(bs=BS, eval_bs=EVAL_BS, steps=STEPS, seed=32):
   # NOTE taken from the hlb_CIFAR repository, might need to be tuned
   initial_div_factor = 1e16
-  final_lr_ratio = 0.022
+  final_lr_ratio = 0.02199
   pct_start = hyp['opt']['percent_start']
   lr_sched_bias = OneCycleLR(opt_bias, max_lr=hyp['opt']['bias_lr'] ,pct_start=pct_start, div_factor=initial_div_factor, final_div_factor=1./(initial_div_factor*final_lr_ratio), total_steps=STEPS)
   lr_sched_non_bias = OneCycleLR(opt_non_bias, max_lr=hyp['opt']['non_bias_lr'] ,pct_start=pct_start, div_factor=initial_div_factor, final_div_factor=1./(initial_div_factor*final_lr_ratio), total_steps=STEPS)
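On why final_div_factor is written as 1/(initial_div_factor*final_lr_ratio): assuming OneCycleLR here follows the usual PyTorch-style convention (initial_lr = max_lr/div_factor and the schedule ends at initial_lr/final_div_factor), the two div factors cancel and the final learning rate is simply max_lr * final_lr_ratio. A quick check with a placeholder max_lr:

max_lr = 0.01                     # placeholder standing in for hyp['opt']['bias_lr'] / 'non_bias_lr'
initial_div_factor = 1e16
final_lr_ratio = 0.02199
initial_lr = max_lr / initial_div_factor
final_div_factor = 1. / (initial_div_factor * final_lr_ratio)
print(initial_lr / final_div_factor, max_lr * final_lr_ratio)   # both ~0.00022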
@@ -305,17 +309,11 @@ def train_cifar(bs=BS, eval_bs=EVAL_BS, steps=STEPS, seed=32):
   i = 0
   batcher = fetch_batches(X_train, Y_train, BS=BS, seed=seed, is_train=True)
   while i <= STEPS:
-    X, Y = next(batcher)
-    if i >= hyp['net']['cutmix_steps']: X, Y = cutmix(X, Y, mask_size=hyp['net']['cutmix_size'])
-    # further split batch if distributed
-    if getenv("DIST"):
-      X, Y = X.chunk(world_size, 0)[rank], Y.chunk(world_size, 0)[rank]
     if i%100 == 0 and i > 1:
       # Use Tensor.training = False here actually bricks batchnorm, even with track_running_stats=True
       corrects = []
       losses = []
-      for Xt, Yt in fetch_batches(X_test, Y_test, BS=EVAL_BS, seed=seed):
+      for Xt, Yt in fetch_batches(X_test, Y_test, BS=EVAL_BS, seed=seed, is_train=False):
         # further split batch if distributed
         if getenv("DIST"):
           Xt, Yt = Xt.chunk(min(world_size, 5), 0)[min(rank, 4)], Yt.chunk(min(world_size, 5), 0)[min(rank, 4)]
@@ -342,6 +340,10 @@ def train_cifar(bs=BS, eval_bs=EVAL_BS, steps=STEPS, seed=32):
         best_eval = acc
         print(f"eval {correct_sum}/{correct_len} {acc:.2f}%, {(sum(losses)/len(losses)):7.2f} val_loss STEP={i}")
     if STEPS == 0 or i==STEPS: break
+    X, Y = next(batcher)
+    # further split batch if distributed
+    if getenv("DIST"):
+      X, Y = X.chunk(world_size, 0)[rank], Y.chunk(world_size, 0)[rank]
     GlobalCounters.reset()
     st = time.monotonic()
     loss = train_step_jitted(model, [opt_bias, opt_non_bias], [lr_sched_bias, lr_sched_non_bias], X, Y)