import functools
import os
import time

from tqdm import tqdm

from tinygrad import Device, GlobalCounters, Tensor, TinyJit, dtypes
from tinygrad.helpers import getenv, BEAM, WINO
from tinygrad.nn.state import get_parameters, get_state_dict, safe_load, safe_save
from tinygrad.nn.optim import LARS, SGD, OptimizerGroup
from extra.lr_scheduler import LRSchedulerGroup
from examples.mlperf.helpers import get_training_state, load_training_state


def train_resnet():
    """Train ResNet50 on ImageNet, with periodic evaluation and optional checkpointing.

    All configuration is read from environment variables through ``getenv``:

    * ``SEED``, ``GPUS``, ``SYNCBN`` -- RNG seed, number of devices, and whether to keep
      the default (synced) batch norm instead of ``UnsyncedBatchNorm``.
    * ``EPOCHS``, ``BS``, ``EVAL_BS``, ``LR``, ``WARMUP_EPOCHS``, ``DECAY`` -- hyperparameters.
    * ``TARGET``, ``EVAL_START_EPOCH``, ``EVAL_EPOCHS`` -- evaluation schedule and the top-1
      accuracy at which the model weights are saved once.
    * ``RESUME`` -- path of a training-state checkpoint to load before training.
    * ``WANDB``, ``WANDB_RESUME`` -- enable wandb logging / resume an existing run id.
    * ``BENCHMARK`` -- if set to N, run N training steps, print a training-time estimate and return.
    * ``CKPT`` -- if set, save the full training state after every evaluation.

    Returns:
        None. Results are printed, optionally logged to wandb, and written under ``./ckpts``.
    """
    from extra.models import resnet
    from examples.mlperf.dataloader import batch_load_resnet
    from extra.datasets.imagenet import get_train_files, get_val_files
    from examples.mlperf.lr_schedulers import PolynomialDecayWithWarmup
    from examples.mlperf.initializers import Conv2dHeNormal, Linear
    from examples.hlb_cifar10 import UnsyncedBatchNorm

    config = {}
    seed = config["seed"] = getenv("SEED", 42)
    Tensor.manual_seed(seed)  # seed for weight initialization

    GPUS = config["GPUS"] = [f"{Device.DEFAULT}:{i}" for i in range(getenv("GPUS", 1))]
    print(f"Training on {GPUS}")
    # index every device up front (presumably so each one is initialized before sharding -- confirm)
    for x in GPUS:
        Device[x]

    # ** model definition and initializers **
    num_classes = 1000
    resnet.Conv2d = Conv2dHeNormal
    resnet.Linear = Linear
    if not getenv("SYNCBN"):
        resnet.BatchNorm = functools.partial(UnsyncedBatchNorm, num_devices=len(GPUS))
    model = resnet.ResNet50(num_classes)

    # shard weights and initialize in order
    for k, x in get_state_dict(model).items():
        if not getenv("SYNCBN") and ("running_mean" in k or "running_var" in k):
            x.realize().shard_(GPUS, axis=0)
        else:
            x.realize().to_(GPUS)
    parameters = get_parameters(model)

    # ** hyperparameters **
    epochs = config["epochs"] = getenv("EPOCHS", 41)
    BS = config["BS"] = getenv("BS", 104 * len(GPUS))  # fp32 GPUS<=6 7900xtx can fit BS=112
    EVAL_BS = config["EVAL_BS"] = getenv("EVAL_BS", BS)
    base_lr = config["base_lr"] = getenv("LR", 8.5 * (BS/2048))
    lr_warmup_epochs = config["lr_warmup_epochs"] = getenv("WARMUP_EPOCHS", 5)
    decay = config["decay"] = getenv("DECAY", 2e-4)

    target, achieved = getenv("TARGET", 0.759), False
    eval_start_epoch = getenv("EVAL_START_EPOCH", 0)
    eval_epochs = getenv("EVAL_EPOCHS", 1)

    steps_in_train_epoch = config["steps_in_train_epoch"] = (len(get_train_files()) // BS)
    steps_in_val_epoch = config["steps_in_val_epoch"] = (len(get_val_files()) // EVAL_BS)

    config["BEAM"] = BEAM.value
    config["WINO"] = WINO.value
    config["SYNCBN"] = getenv("SYNCBN")

    # ** Optimizer **
    # batch-norm, bias and "downsample.1" tensors go to plain SGD without weight decay; the rest use LARS
    skip_list = [v for k, v in get_state_dict(model).items() if "bn" in k or "bias" in k or "downsample.1" in k]
    skip_set = set(skip_list)  # built once instead of once per parameter
    parameters = [x for x in parameters if x not in skip_set]
    optimizer = LARS(parameters, base_lr, momentum=.9, weight_decay=decay)
    optimizer_skip = SGD(skip_list, base_lr, momentum=.9, weight_decay=0.0, classic=True)
    optimizer_group = OptimizerGroup(optimizer, optimizer_skip)

    # ** LR scheduler **
    scheduler = PolynomialDecayWithWarmup(optimizer, initial_lr=base_lr, end_lr=1e-4,
                                          train_steps=epochs * steps_in_train_epoch,
                                          warmup=lr_warmup_epochs * steps_in_train_epoch)
    scheduler_skip = PolynomialDecayWithWarmup(optimizer_skip, initial_lr=base_lr, end_lr=1e-4,
                                               train_steps=epochs * steps_in_train_epoch,
                                               warmup=lr_warmup_epochs * steps_in_train_epoch)
    scheduler_group = LRSchedulerGroup(scheduler, scheduler_skip)
    print(f"training with batch size {BS} for {epochs} epochs")

    # ** resume from checkpointing **
    start_epoch = 0
    if ckpt:=getenv("RESUME", ""):
        load_training_state(model, optimizer_group, scheduler_group, safe_load(ckpt))
        # the scheduler counts steps, so convert its counter back to whole epochs
        start_epoch = int(scheduler.epoch_counter.numpy().item() / steps_in_train_epoch)
        print(f"resuming from {ckpt} at epoch {start_epoch}")

    # ** init wandb **
    WANDB = getenv("WANDB")
    if WANDB:
        import wandb
        wandb_args = {"id": wandb_id, "resume": "must"} if (wandb_id := getenv("WANDB_RESUME", "")) else {}
        wandb.init(config=config, **wandb_args)

    BENCHMARK = getenv("BENCHMARK")

    # ** jitted steps **
    input_mean = Tensor([123.68, 116.78, 103.94], device=GPUS, dtype=dtypes.float32).reshape(1, -1, 1, 1)
    # mlperf reference resnet does not divide by input_std for some reason
    # input_std = Tensor([0.229, 0.224, 0.225], device=GPUS, dtype=dtypes.float32).reshape(1, -1, 1, 1)

    def normalize(x):
        # move the last axis to position 1 and subtract the per-channel mean
        return x.permute([0, 3, 1, 2]) - input_mean

    @TinyJit
    def train_step(X, Y):
        # one optimization step; returns the realized loss and the count of correct top-1 predictions
        optimizer_group.zero_grad()
        X = normalize(X)
        out = model.forward(X)
        loss = out.sparse_categorical_crossentropy(Y, label_smoothing=0.1)
        top_1 = (out.argmax(-1) == Y).sum()
        loss.backward()
        optimizer_group.step()
        scheduler_group.step()
        return loss.realize(), top_1.realize()

    @TinyJit
    def eval_step(X, Y):
        # forward pass only; returns the realized loss and the count of correct top-1 predictions
        X = normalize(X)
        out = model.forward(X)
        loss = out.sparse_categorical_crossentropy(Y, label_smoothing=0.1)
        top_1 = (out.argmax(-1) == Y).sum()
        return loss.realize(), top_1.realize()

    def data_get(it):
        # fetch the next batch and shard it across GPUS; the cookie is passed through untouched
        x, y, cookie = next(it)
        return x.shard(GPUS, axis=0).realize(), Tensor(y, requires_grad=False).shard(GPUS, axis=0), cookie

    # ** epoch loop **
    step_times = []
    for e in range(start_epoch, epochs):
        # ** train loop **
        Tensor.training = True
        it = iter(tqdm(batch_load_resnet(batch_size=BS, val=False, shuffle=True, seed=seed*epochs + e),
                       total=steps_in_train_epoch, desc=f"epoch {e}", disable=BENCHMARK))
        i, proc = 0, data_get(it)
        st = time.perf_counter()
        while proc is not None:
            GlobalCounters.reset()
            (loss, top_1_acc), proc = train_step(proc[0], proc[1]), proc[2]

            pt = time.perf_counter()

            # fetch the next batch before reading results back from the current step
            try:
                next_proc = data_get(it)
            except StopIteration:
                next_proc = None

            dt = time.perf_counter()

            device_str = loss.device if isinstance(loss.device, str) else f"{loss.device[0]} * {len(loss.device)}"
            loss, top_1_acc = loss.numpy().item(), top_1_acc.numpy().item() / BS

            cl = time.perf_counter()
            if BENCHMARK:
                step_times.append(cl - st)

            tqdm.write(
                f"{i:5} {((cl - st)) * 1000.0:7.2f} ms run, {(pt - st) * 1000.0:7.2f} ms python, {(dt - pt) * 1000.0:6.2f} ms fetch data, "
                f"{(cl - dt) * 1000.0:7.2f} ms {device_str}, {loss:5.2f} loss, {top_1_acc:3.2f} acc, {optimizer.lr.numpy()[0]:.6f} LR, "
                f"{GlobalCounters.mem_used / 1e9:.2f} GB used, {GlobalCounters.global_ops * 1e-9 / (cl - st):9.2f} GFLOPS")
            if WANDB:
                wandb.log({"lr": optimizer.lr.numpy(), "train/loss": loss, "train/top_1_acc": top_1_acc,
                           "train/step_time": cl - st, "train/python_time": pt - st, "train/data_time": dt - pt,
                           "train/cl_time": cl - dt, "train/GFLOPS": GlobalCounters.global_ops * 1e-9 / (cl - st),
                           "epoch": e + (i + 1) / steps_in_train_epoch})

            st = cl
            proc, next_proc = next_proc, None  # return old cookie
            i += 1

            if i == BENCHMARK:
                # middle element of the sorted samples; len // 2 stays in range even for a single sample
                median_step_time = sorted(step_times)[len(step_times) // 2]  # in seconds
                # work in whole minutes so the hour part is floored rather than rounded up
                estimated_total_minutes = int(median_step_time * steps_in_train_epoch * epochs / 60)
                print(f"Estimated training time: {estimated_total_minutes // 60}h{estimated_total_minutes % 60}m")
                return

        # ** eval loop **
        if (e + 1 - eval_start_epoch) % eval_epochs == 0:
            train_step.reset()  # free the train step memory :(
            eval_loss = []
            eval_times = []
            eval_top_1_acc = []
            Tensor.training = False

            it = iter(tqdm(batch_load_resnet(batch_size=EVAL_BS, val=True, shuffle=False), total=steps_in_val_epoch))
            proc = data_get(it)
            while proc is not None:
                GlobalCounters.reset()
                st = time.time()

                (loss, top_1_acc), proc = eval_step(proc[0], proc[1]), proc[2]  # drop inputs, keep cookie

                try:
                    next_proc = data_get(it)
                except StopIteration:
                    next_proc = None

                loss, top_1_acc = loss.numpy().item(), top_1_acc.numpy().item() / EVAL_BS
                eval_loss.append(loss)
                eval_top_1_acc.append(top_1_acc)
                proc, next_proc = next_proc, None  # return old cookie

                et = time.time()
                eval_times.append(et - st)

            eval_step.reset()
            total_loss = sum(eval_loss) / len(eval_loss)
            total_top_1 = sum(eval_top_1_acc) / len(eval_top_1_acc)
            total_fw_time = sum(eval_times) / len(eval_times)
            tqdm.write(f"eval loss: {total_loss:.2f}, eval time: {total_fw_time:.2f}, eval top 1 acc: {total_top_1:.3f}")
            if WANDB:
                wandb.log({"eval/loss": total_loss, "eval/top_1_acc": total_top_1,
                           "eval/forward_time": total_fw_time, "epoch": e + 1})

            # save model if achieved target
            if not achieved and total_top_1 >= target:
                os.makedirs("./ckpts", exist_ok=True)
                # NOTE(review): the file name was missing in the damaged source; "resnet50.safe" is restored -- confirm
                fn = "./ckpts/resnet50.safe"
                safe_save(get_state_dict(model), fn)
                print(f" *** Model saved to {fn} ***")
                achieved = True

            # checkpoint every time we eval
            if getenv("CKPT"):
                os.makedirs("./ckpts", exist_ok=True)
                if WANDB and wandb.run is not None:
                    # include the wandb run id so checkpoints can be matched to their run
                    fn = f"./ckpts/{time.strftime('%Y%m%d_%H%M%S')}_{wandb.run.id}_e{e}.safe"
                else:
                    fn = f"./ckpts/{time.strftime('%Y%m%d_%H%M%S')}_e{e}.safe"
                print(f"saving ckpt to {fn}")
                safe_save(get_training_state(model, optimizer_group, scheduler_group), fn)


def train_retinanet():
    # TODO: Retinanet
    pass


def train_unet3d():
    # TODO: Unet3d
    pass


def train_rnnt():
    # TODO: RNN-T
    pass


def train_bert():
    # TODO: BERT
    pass


def train_maskrcnn():
    # TODO: Mask RCNN
    pass


if __name__ == "__main__":
    with Tensor.train():
        # MODEL is a comma-separated list; names without a matching train_<name> function are skipped
        for m in getenv("MODEL", "resnet,retinanet,unet3d,rnnt,bert,maskrcnn").split(","):
            nm = f"train_{m}"
            if nm in globals():
                print(f"training {m}")
                globals()[nm]()