Commit Graph

7 Commits

Author SHA1 Message Date
wozeparrot 9a9cac58f9
add lars to nn (#3750)
* feat: add lars

* feat: don't remove this comment

* clean: smaller diff

* clean: shorter line

* feat: remove mlperf lars, switch resnet

* fix: fully remove mlperf lars

* clean: comment

* feat: contiguous

* feat: no weight decay on skip params

* feat: optimizergroup

* feat: classic momentum

* fix: pylint

* clean: move comment

* fix: correct algo

* feat: lrschedulergroup

* feat: skip list tests

* feat: :| forgot that params are a thing

* feat: remove skip_list params from main params

* feat: set moment

---------

Co-authored-by: chenyu <chenyu@fastmail.com>
2024-03-24 11:43:12 -04:00
chenyu 24d004a89b
hotfix check ckpts before writing achieved model (#3901)
this killed tinybox green run
2024-03-23 17:16:38 -04:00
chenyu e1c5aa9cce
estimated resnet training time for BENCHMARK (#3769) 2024-03-15 22:36:58 -04:00
chenyu 4bd5535d72
update mlperf resnet default hparams (#3758)
we might be able to have higher lr given smaller BS, but this is good.

Trained to 75.9%
https://wandb.ai/chenyuxyz/tinygrad-examples_mlperf/runs/xi2f48se/overview
2024-03-15 12:09:26 -04:00
David Hou 199f7c4342
MLPerf Resnet (cleaned up) (#3573)
* this is a lot of stuff

TEST_TRAIN env for less data

don't diskcache get_train_files

debug message

no lr_scaler for fp32

comment, typo

type stuff

don't destructure proc

make batchnorm parameters float

make batchnorm parameters float

resnet18, checkpointing

hack up checkpointing to keep the names in there

oops

wandb_resume

lower lr

eval/ckpt use e+1

lars

report top_1_acc

some wandb stuff

split fw and bw steps to save memory

oops

save model when reach target

formatting

make sgd hparams consistent

just always write the cats tag...

pass X and Y into backward_step to trigger input replace

shuffle eval set to fix batchnorm eval

dataset is sorted by class, so the means and variances are all wrong

small cleanup

hack restore only one copy of each tensor

do bufs from lin after cache check (lru should handle it fine)

record epoch in wandb

more digits for topk in eval

more env vars

small cleanup

cleanup hack tricks

cleanup hack tricks

don't save ckpt for testeval

cleanup

diskcache train file glob

clean up a little

device_str

SCE into tensor

small

small

log_softmax out of resnet.py

oops

hack :(

comments

HeNormal, track gradient norm

oops

log SYNCBN to wandb

real truncnorm

less samples for truncated normal

custom init for Linear

log layer stats

small

Revert "small"

This reverts commit 988f4c1cf35ca4be6c31facafccdd1e177469f2f.

Revert "log layer stats"

This reverts commit 9d9822458524c514939adeee34b88356cd191cb0.

rename BNSYNC to SYNCBN to be consistent with cifar

optional TRACK_NORMS

fix label smoothing :/

lars skip list

only weight decay if not in skip list

comment

default 0 TRACK_NORMS

don't allocate beam scratch buffers if in cache

clean up data pipeline, unsplit train/test, put back a hack

remove print

run test_indexing on remu (#3404)

* emulated ops_hip infra

* add int4

* include test_indexing in remu

* Revert "Merge branch 'remu-dev-mac'"

This reverts commit 6870457e57dc5fa70169189fd33b24dbbee99c40, reversing
changes made to 3c4c8c9e16.

fix bad seeding

UnsyncBatchNorm2d but with synced trainable weights

label downsample batchnorm in Bottleneck

:/

:/

i mean... it runs... its hits the acc... its fast...

new unsyncbatchnorm for resnet

small fix

don't do assign buffer reuse for axis change

* remove changes

* remove changes

* move LARS out of tinygrad/

* rand_truncn rename

* whitespace

* stray whitespace

* no more gnorms

* delete some dataloading stuff

* remove comment

* clean up train script

* small comments

* move checkpointing stuff to mlperf helpers

* if WANDB

* small comments

* remove whitespace change

* new unsynced bn

* clean up prints / loop vars

* whitespace

* undo nn changes

* clean up loops

* rearrange getenvs

* cpu_count()

* PolynomialLR whitespace

* move he_normal out

* cap warmup in polylr

* rearrange wandb log

* realize both x and y in data_get

* use double quotes

* combine prints in ckpts resume

* take UBN from cifar

* running_var

* whitespace

* whitespace

* typo

* if instead of ternary for resnet downsample

* clean up dataloader cleanup a little?

* separate rng for shuffle

* clean up imports in model_train

* clean up imports

* don't realize copyin in data_get

* remove TESTEVAL (train dataloader didn't get freed every loop)

* adjust wandb_config entries a little

* clean up wandb config dict

* reduce lines

* whitespace

* shorter lines

* put shm unlink back, but it doesn't seem to do anything

* don't pass seed per task

* monkeypatch batchnorm

* the reseed was wrong

* add epoch number to desc

* don't unsyncedbatchnorm is syncbn=1

* put back downsample name

* eval every epoch

* Revert "the reseed was wrong"

This reverts commit 3440a07dff3f40e8a8d156ca3f1938558a59249f.

* cast lr in onecycle

* support fp16

* cut off kernel if expand after reduce

* test polynomial lr

* move polynomiallr to examples/mlperf

* working PolynomialDecayWithWarmup + tests.......

add lars_util.py, oops

* keep lars_util.py as intact as possible, simplify our interface

* no more half

* polylr and lars were merged

* undo search change

* override Linear init

* remove half stuff from model_train

* update scheduler init with new args

* don't divide by input mean

* mistake in resnet.py

* restore whitespace in resnet.py

* add test_data_parallel_resnet_train_step

* move initializers out of resnet.py

* unused imports

* log_softmax to model output in test to fix precision flakiness

* log_softmax to model output in test to fix precision flakiness

* oops, don't realize here

* is None

* realize initializations in order for determinism

* BENCHMARK flag for number of steps

* add resnet to bechmark.yml

* return instead of break

* missing return

* cpu_count, rearrange benchmark.yml

* unused variable

* disable tqdm if BENCHMARK

* getenv WARMUP_EPOCHS

* unlink disktensor shm file if exists

* terminate instead of join

* properly shut down queues

* use hip in benchmark for now

---------

Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
2024-03-14 00:53:41 -04:00
Yixiang Gao 094d3d71be
with Tensor.train() (#1935)
* add with.train

* remove the rest TODOs

* fix pyflake

* fix pyflake error

* fix mypy
2023-09-28 18:02:31 -07:00
wozeparrot 0fc4cf72a2
feat: add train scaffolding (#859) 2023-05-30 07:10:40 -07:00