tinygrad/examples/mlperf/model_train.py

import functools
import os
import time
from tqdm import tqdm
from tinygrad import Device, GlobalCounters, Tensor, TinyJit, dtypes
from tinygrad.helpers import getenv, BEAM, WINO
from tinygrad.nn.state import get_parameters, get_state_dict, safe_load, safe_save
from tinygrad.nn.optim import LARS, SGD, OptimizerGroup
from extra.lr_scheduler import LRSchedulerGroup
from examples.mlperf.helpers import get_training_state, load_training_state
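
# Example invocation (a sketch, not a canonical command: the env vars below are the ones read
# via getenv in train_resnet, and the entry point is assumed to be wired up at the bottom of this file):
#   GPUS=6 BS=624 EPOCHS=41 WANDB=1 python3 examples/mlperf/model_train.py
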
def train_resnet():
  from extra.models import resnet
  from examples.mlperf.dataloader import batch_load_resnet
  from extra.datasets.imagenet import get_train_files, get_val_files
  from examples.mlperf.lr_schedulers import PolynomialDecayWithWarmup
  from examples.mlperf.initializers import Conv2dHeNormal, Linear
  from examples.hlb_cifar10 import UnsyncedBatchNorm

  config = {}
  seed = config["seed"] = getenv("SEED", 42)
  Tensor.manual_seed(seed)  # seed for weight initialization

  GPUS = config["GPUS"] = [f"{Device.DEFAULT}:{i}" for i in range(getenv("GPUS", 1))]
  print(f"Training on {GPUS}")
  for x in GPUS: Device[x]

  # ** model definition and initializers **
  num_classes = 1000
  resnet.Conv2d = Conv2dHeNormal
  resnet.Linear = Linear
  if not getenv("SYNCBN"): resnet.BatchNorm = functools.partial(UnsyncedBatchNorm, num_devices=len(GPUS))
  model = resnet.ResNet50(num_classes)
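
  # Data-parallel layout (descriptive note): trainable weights are replicated across GPUS with
  # to_(), while UnsyncedBatchNorm running_mean/running_var are sharded on axis 0 so that each
  # device keeps its own batch statistics.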
  # shard weights and initialize in order
  for k, x in get_state_dict(model).items():
    if not getenv("SYNCBN") and ("running_mean" in k or "running_var" in k):
      x.realize().shard_(GPUS, axis=0)
    else:
      x.realize().to_(GPUS)
  parameters = get_parameters(model)

  # ** hyperparameters **
  epochs = config["epochs"] = getenv("EPOCHS", 41)
  BS = config["BS"] = getenv("BS", 104 * len(GPUS))  # fp32 GPUS<=6 7900xtx can fit BS=112
  EVAL_BS = config["EVAL_BS"] = getenv("EVAL_BS", BS)
  base_lr = config["base_lr"] = getenv("LR", 8.5 * (BS/2048))
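  # note: the default base LR scales linearly with the global batch size, anchored at 8.5 for BS=2048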
  lr_warmup_epochs = config["lr_warmup_epochs"] = getenv("WARMUP_EPOCHS", 5)
  decay = config["decay"] = getenv("DECAY", 2e-4)

  target, achieved = getenv("TARGET", 0.759), False
  eval_start_epoch = getenv("EVAL_START_EPOCH", 0)
  eval_epochs = getenv("EVAL_EPOCHS", 1)

  steps_in_train_epoch = config["steps_in_train_epoch"] = (len(get_train_files()) // BS)
  steps_in_val_epoch = config["steps_in_val_epoch"] = (len(get_val_files()) // EVAL_BS)

  config["BEAM"] = BEAM.value
  config["WINO"] = WINO.value
  config["SYNCBN"] = getenv("SYNCBN")

  # ** Optimizer **
  skip_list = [v for k, v in get_state_dict(model).items() if "bn" in k or "bias" in k or "downsample.1" in k]
  parameters = [x for x in parameters if x not in set(skip_list)]
  optimizer = LARS(parameters, base_lr, momentum=.9, weight_decay=decay)
  optimizer_skip = SGD(skip_list, base_lr, momentum=.9, weight_decay=0.0, classic=True)
  optimizer_group = OptimizerGroup(optimizer, optimizer_skip)
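  # BatchNorm parameters, biases, and the downsample BN ("downsample.1") are excluded from LARS
  # and trained with plain SGD without weight decay, the usual LARS skip-list convention.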

  # ** LR scheduler **
  scheduler = PolynomialDecayWithWarmup(optimizer, initial_lr=base_lr, end_lr=1e-4,
                                        train_steps=epochs * steps_in_train_epoch,
                                        warmup=lr_warmup_epochs * steps_in_train_epoch)
  scheduler_skip = PolynomialDecayWithWarmup(optimizer_skip, initial_lr=base_lr, end_lr=1e-4,
                                             train_steps=epochs * steps_in_train_epoch,
                                             warmup=lr_warmup_epochs * steps_in_train_epoch)
  scheduler_group = LRSchedulerGroup(scheduler, scheduler_skip)
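  # Rough shape of the schedule (a sketch; see examples/mlperf/lr_schedulers.py for the exact formula):
  #   warmup:  lr = base_lr * step / warmup_steps
  #   decay:   lr falls polynomially from base_lr to end_lr over the remaining train_steps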
  print(f"training with batch size {BS} for {epochs} epochs")

  # ** resume from checkpointing **
  start_epoch = 0
  if ckpt:=getenv("RESUME", ""):
    load_training_state(model, optimizer_group, scheduler_group, safe_load(ckpt))
    start_epoch = int(scheduler.epoch_counter.numpy().item() / steps_in_train_epoch)
    print(f"resuming from {ckpt} at epoch {start_epoch}")

  # ** init wandb **
  WANDB = getenv("WANDB")
  if WANDB:
    import wandb
    wandb_args = {"id": wandb_id, "resume": "must"} if (wandb_id := getenv("WANDB_RESUME", "")) else {}
    wandb.init(config=config, **wandb_args)

  BENCHMARK = getenv("BENCHMARK")
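  # BENCHMARK is expected to cap the run at that many steps (timing mode); tqdm is also disabled when it is set.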

  # ** jitted steps **
  input_mean = Tensor([123.68, 116.78, 103.94], device=GPUS, dtype=dtypes.float32).reshape(1, -1, 1, 1)
  # mlperf reference resnet does not divide by input_std for some reason
  # input_std = Tensor([0.229, 0.224, 0.225], device=GPUS, dtype=dtypes.float32).reshape(1, -1, 1, 1)
  def normalize(x): return x.permute([0, 3, 1, 2]) - input_mean
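  # notes: batches arrive NHWC, so normalize() permutes to NCHW and subtracts the per-channel
  # ImageNet mean (no std division, per the comment above). train_step below is wrapped in
  # TinyJit, which captures the kernels after warm-up calls and then replays them, so it must
  # always be called with tensors of matching shapes, dtypes, and device placement.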
  @TinyJit
  def train_step(X, Y):
    optimizer_group.zero_grad()
X = normalize(X)
out = model.forward(X)
loss = out.sparse_categorical_crossentropy(Y, label_smoothing=0.1)
top_1 = (out.argmax(-1) == Y).sum()
loss.backward()
optimizer_group.step()
scheduler_group.step()
return loss.realize(), top_1.realize()
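# same forward pass and loss as train_step but with no backward/optimizer, so it gets its own jitted graph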
@TinyJit
def eval_step(X, Y):
X = normalize(X)
out = model.forward(X)
loss = out.sparse_categorical_crossentropy(Y, label_smoothing=0.1)
top_1 = (out.argmax(-1) == Y).sum()
return loss.realize(), top_1.realize()
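# pull the next batch, shard x/y across GPUS along the batch axis, and pass the loader cookie through
# so the underlying buffer isn't reused until the step that consumed it has finished (see "return old cookie" below)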
def data_get(it):
x, y, cookie = next(it)
return x.shard(GPUS, axis=0).realize(), Tensor(y, requires_grad=False).shard(GPUS, axis=0), cookie
# ** epoch loop **
step_times = []
for e in range(start_epoch, epochs):
# ** train loop **
Tensor.training = True
it = iter(tqdm(batch_load_resnet(batch_size=BS, val=False, shuffle=True, seed=seed*epochs + e),
total=steps_in_train_epoch, desc=f"epoch {e}", disable=BENCHMARK))
i, proc = 0, data_get(it)
st = time.perf_counter()
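# per-step timing: st = step start, pt = after the jitted train_step is queued (python time),
# dt = after the next batch is fetched, cl = after loss/top_1 reach the host (device run + sync)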
while proc is not None:
GlobalCounters.reset()
(loss, top_1_acc), proc = train_step(proc[0], proc[1]), proc[2]
pt = time.perf_counter()
try:
next_proc = data_get(it)
except StopIteration:
next_proc = None
dt = time.perf_counter()
device_str = loss.device if isinstance(loss.device, str) else f"{loss.device[0]} * {len(loss.device)}"
loss, top_1_acc = loss.numpy().item(), top_1_acc.numpy().item() / BS
cl = time.perf_counter()
if BENCHMARK:
step_times.append(cl - st)
tqdm.write(
f"{i:5} {((cl - st)) * 1000.0:7.2f} ms run, {(pt - st) * 1000.0:7.2f} ms python, {(dt - pt) * 1000.0:6.2f} ms fetch data, "
f"{(cl - dt) * 1000.0:7.2f} ms {device_str}, {loss:5.2f} loss, {top_1_acc:3.2f} acc, {optimizer.lr.numpy()[0]:.6f} LR, "
f"{GlobalCounters.mem_used / 1e9:.2f} GB used, {GlobalCounters.global_ops * 1e-9 / (cl - st):9.2f} GFLOPS")
if WANDB:
wandb.log({"lr": optimizer.lr.numpy(), "train/loss": loss, "train/top_1_acc": top_1_acc, "train/step_time": cl - st,
"train/python_time": pt - st, "train/data_time": dt - pt, "train/cl_time": cl - dt,
"train/GFLOPS": GlobalCounters.global_ops * 1e-9 / (cl - st), "epoch": e + (i + 1) / steps_in_train_epoch})
st = cl
proc, next_proc = next_proc, None # return old cookie
i += 1
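# BENCHMARK runs this many steps and then returns, projecting a full run from the median step time;
# e.g. with made-up numbers, a 0.5s median step * 2500 steps/epoch * 40 epochs / 3600 ~= 13.9 hours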
if i == BENCHMARK:
median_step_time = sorted(step_times)[(BENCHMARK + 1) // 2] # in seconds
estimated_total_hours = median_step_time * steps_in_train_epoch * epochs / 60 / 60
print(f"Estimated training time: {estimated_total_hours:.0f}h{(estimated_total_hours - int(estimated_total_hours)) * 60:.0f}m")
return
# ** eval loop **
if (e + 1 - eval_start_epoch) % eval_epochs == 0:
train_step.reset() # free the train step memory :(
eval_loss = []
eval_times = []
eval_top_1_acc = []
Tensor.training = False
it = iter(tqdm(batch_load_resnet(batch_size=EVAL_BS, val=True, shuffle=False), total=steps_in_val_epoch))
proc = data_get(it)
while proc is not None:
GlobalCounters.reset()
st = time.time()
(loss, top_1_acc), proc = eval_step(proc[0], proc[1]), proc[2] # drop inputs, keep cookie
try:
next_proc = data_get(it)
except StopIteration:
next_proc = None
loss, top_1_acc = loss.numpy().item(), top_1_acc.numpy().item() / EVAL_BS
eval_loss.append(loss)
eval_top_1_acc.append(top_1_acc)
proc, next_proc = next_proc, None # return old cookie
et = time.time()
eval_times.append(et - st)
eval_step.reset()
total_loss = sum(eval_loss) / len(eval_loss)
total_top_1 = sum(eval_top_1_acc) / len(eval_top_1_acc)
total_fw_time = sum(eval_times) / len(eval_times)
tqdm.write(f"eval loss: {total_loss:.2f}, eval time: {total_fw_time:.2f}, eval top 1 acc: {total_top_1:.3f}")
if WANDB:
wandb.log({"eval/loss": total_loss, "eval/top_1_acc": total_top_1, "eval/forward_time": total_fw_time, "epoch": e + 1})
# save model if achieved target
if not achieved and total_top_1 >= target:
if not os.path.exists("./ckpts"): os.mkdir("./ckpts")
fn = f"./ckpts/resnet50.safe"
safe_save(get_state_dict(model), fn)
print(f" *** Model saved to {fn} ***")
achieved = True
# checkpoint every time we eval
if getenv("CKPT"):
if not os.path.exists("./ckpts"): os.mkdir("./ckpts")
if WANDB and wandb.run is not None:
fn = f"./ckpts/{time.strftime('%Y%m%d_%H%M%S')}_{wandb.run.id}_e{e}.safe"
else:
fn = f"./ckpts/{time.strftime('%Y%m%d_%H%M%S')}_e{e}.safe"
print(f"saving ckpt to {fn}")
safe_save(get_training_state(model, optimizer_group, scheduler_group), fn)
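# get_training_state is assumed to bundle model, optimizer_group and scheduler_group state so a run can be
# resumed from this .safe file later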
def train_retinanet():
# TODO: Retinanet
pass
def train_unet3d():
# TODO: Unet3d
pass
def train_rnnt():
# TODO: RNN-T
pass
def train_bert():
# TODO: BERT
pass
def train_maskrcnn():
# TODO: Mask RCNN
pass
if __name__ == "__main__":
with Tensor.train():
for m in getenv("MODEL", "resnet,retinanet,unet3d,rnnt,bert,maskrcnn").split(","):
nm = f"train_{m}"
if nm in globals():
print(f"training {m}")
globals()[nm]()
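# example invocations (illustrative only; env vars and devices depend on your setup):
#   MODEL=resnet BENCHMARK=10 python3 examples/mlperf/model_train.py    # quick timing run
#   MODEL=resnet WANDB=1 CKPT=1 python3 examples/mlperf/model_train.py  # full training with logging + checkpoints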