Commit Graph

872 Commits

Author SHA1 Message Date
George Hotz d4b662c318
new openpilot compile (#6573)
* new openpilot compile

* note, copyout doesn't work for images
2024-09-18 14:22:50 +08:00
kormann f5dd25d376
enable whisper batch for long sequences (#6458)
* long batch +test

* long batch +test

* cleanup

* rollback syntactic changes

---------

Co-authored-by: chenyu <chenyu@fastmail.com>
2024-09-17 00:42:10 -04:00
chenyu 798be6bb74
add gated read_image count in openpilot compile2 (#6546)
530 to go
2024-09-16 21:17:00 -04:00
Francis Lata b7ce9a1530
UNet3D MLPerf (#3470)
* add training set transforms

* add DICE cross entropy loss

* convert pred and label to Tensor when calculating DICE score

* cleanups and allow train dataset batching

* fix DICE CE loss calculation

* jitted training step

* clean up DICE CE loss calculation

* initial support for sharding

* Revert "initial support for sharding"

This reverts commit e3670813b8a67469e7f694e09f2d15a8c40065da.

* minor updates

* cleanup imports

* add support for sharding

* apply temp patch to try to avoid OOM

* revert cstyle changes

* add gradient acc

* hotfix

* add FP16 support

* add ability to train on smaller image sizes

* add support for saving and loading checkpoints + clean up various modes

* fix issue with using smaller patch size + update W&B logging

* disable LR_WARMUP_EPOCHS

* updates

* minor cleanups

* cleanup

* update order of transformations

* more cleanups

* realize loss

* cleanup

* more cleanup

* some cleanups

* add RAM usage

* minor cleanups

* add support for gradient accumulation

* cleanup imports

* minor updates to not use GA_STEPS

* remove FP16 option since it's available now globally

* update multi-GPU setup

* add timing logs for training loop

* go back to using existing dataloader and add ability to preprocess data to save time

* clean up optimization and re-enable JIT and multi-GPU support for training and evaluation

* free train and eval steps memory

* cleanups and scale batch size based on the number of GPUs

* fix GlobalCounters import

* fix seed

* fix W&B setup

* update default batch size

* add back metric divergence check

* put back JIT on UNet3d eval

* move dataset preprocessing inside training code

* add test for dice_loss

* add config logging support to W&B and other cleanups

* change how default float is getting retrieved

* remove TinyJit import duplicate

* update config logging to W&B and remove JIT on eval_step

* no need for caching preprocessed data anymore

* fix how evaluation is run and how often

* add support for LR scaling

* fix issue with gaussian being moved to scipy.signal.windows

* remove DICE loss unit test

* fix issue where loss isn't compatible with multi-GPU

* add individual BEAM control for train and eval steps

* fix ndimage scipy import

* add BENCHMARK

* cleanups on BENCHMARK + fix rand_flip augmentation during training

* cleanup train and eval BEAM envs

* add checkpointing support after every eval

* cleanup model_eval

* disable grad during eval

* use new preprocessing dataset mechanism

* remove unused import

* use training and inference_mode contexts

* start eval after benchmarking

* add data fetching time

* cleanup decorators

* more cleanups on training script

* add message during benchmarking mode

* realize when reassigning LR on scheduler and update default number of epochs

* add JIT on eval step

* remove JIT on eval_step

* add train dataloader for unet3d

* move checkpointing to be done after every epoch

* revert removal of JIT on unet3d inference

* save checkpoint if metric is not successful

* Revert "add train dataloader for unet3d"

This reverts commit c166d129dfbe2e1c46d1937135a60b4ed25caa3d.

* Revert "Revert "add train dataloader for unet3d""

This reverts commit 36366c65d26f59ed1227acb670d5ce7b997606ae.

* hotfix: seed was defaulting to a value of 0

* fix SEED value

* remove the usage of context managers for setting BEAM and going from training to inference

* support new stack API for calculating eval loss and metric

* Revert "remove the usage of context managers for setting BEAM and going from training to inference"

This reverts commit 2c0ba8d322ec912bd8617cbe167c542e9ba229d9.

* check training and test preprocessed folders separately

* clean up imports and log FUSE_CONV_BW

* use train and val preprocessing constants

* add kits19 dataset setup script

* update to use the new test decorator for disabling grad

* update kits19 dataset setup script

* add docs on how to train the model

* set default value for BASEDIR

* add detailed instruction about BASEDIR usage

---------

Co-authored-by: chenyu <chenyu@fastmail.com>
2024-09-10 04:37:28 -04:00
kormann f6f4f3222f
whisper long batch (#6335)
* reset

* test

* only part refactor
2024-09-09 21:03:59 -04:00
qazal 935b6b658f
delete seen from the scheduler api [run_process_replay] (#6427)
docs
2024-09-09 16:26:34 +08:00
wozeparrot cb61cfce24
feat: example and extra tweaks (#6310) 2024-08-28 19:26:11 -07:00
Tobias Fischer 3517aa89d9
sdxl batched inference fixes (#6293) 2024-08-28 07:44:58 -04:00
qazal 552fbd5527
update llm.c with UOp ast [run_process_replay] (#6296) 2024-08-27 15:04:54 +03:00
chenyu c9a9631818
no UnaryOps.NEG in generated UOp patterns (#6209)
* no UnaryOps.NEG in generated UOp patterns

removed pattern `x * (-1) -> -x` and `x != True`

* those are fine because NEG became CMPNE and True

* fix sd validation L2 norm
2024-08-21 11:08:22 -04:00
George Hotz 9faf205601
CIFAR trainer + various bugfixes / improvements (#6146)
* move cifar into datasets

* support for pathlib Tensors, tar_extract, and fetch gunzip

* too early for Device.DEFAULT

* simpler hlb_cifar + .to(None) is default

* new compiler failure, start beautiful_cifar

* beautiful cifar runs but is broken

* jit train step

* cleaner

* std_mean, not mean_std

* more correct

* fast indexing

* don't print that

* torch load broken

* add eval

* nicer bar

* decorators are the way to do this

* bounds check the load

* a few ops

* batchnorm bugfix: if track_running_stats is False, use online estimate

* full timing

* fix fusion

* unneeded realize

* master tensor
2024-08-20 16:58:46 -07:00
George Hotz d9c62a33c3
add cifar to datasets.py (#6210) 2024-08-20 11:42:49 -07:00
George Hotz 17a043edad
tensor inference (#6156)
* tensor inference

* test is even better name
2024-08-18 00:19:28 -07:00
qazal 28c75bf2a6
merge uops with ops (#6111)
Co-authored-by: chenyu <chenyu@fastmail.com>
2024-08-16 18:17:57 -04:00
qazal c23d44c779
AST is UOp (#6030)
* most of the work from the uops2 branch

* schedule

* realize

* kernel

* lowerer

* search

* green

* merge uops with ops

* Revert "merge uops with ops"

This reverts commit 1408a59f12c97e3466679884266b247cf9df46bc.

* fix benchmark

* remove extra dedup
2024-08-16 22:09:00 +03:00
George Hotz 14b613e281 add STEPS to beautiful_mnist 2024-08-10 15:23:44 -07:00
wozeparrot d269bc95fa
faster tinychat (#5993) 2024-08-08 19:16:26 -07:00
Elias Wahl c9b4602854
no load in INITMLPERF (#5957) 2024-08-08 11:28:24 -04:00
Elias Wahl c9862e17d4
MLPERF BERT submission scripts (#5931)
* green

* red

* fix benchmark

* log

* count train samples

* oops. 4.0 -> 4.1

* note to todo

* no pillow
2024-08-06 18:09:18 -04:00
chenyu 1dab75ae37
clean up mlperf dataloader import (#5940)
use tinygrad tqdm for the dataset; PIL Image is only needed for resnet
2024-08-06 17:10:08 -04:00
George Hotz e077bc7baf
move memory planner to realize (#5937) 2024-08-06 10:41:29 -07:00
Elias Wahl 937bf5fe12
better hparam (#5891) 2024-08-03 12:38:53 -04:00
Elias Wahl 4a114756f6
New BERT dataloader (#5881)
* One file == One topic

* update test

* new dataloader

* update train script

* get index is faster
2024-08-02 15:12:23 -04:00
David Hou 9a485f36e4
shard kvcache (#5830) 2024-07-30 20:29:54 -07:00
George Hotz 21c5e8e1b7
extreme llama speed, 57.34 tok/s (#5827)
* extreme llama speed

* mergable
2024-07-30 18:32:09 -07:00
Francis Lata a0baff7a3d
update dataloader script example (#5818) 2024-07-30 15:18:29 -04:00
wozeparrot eebb1b9922
feat: temperature 0 llama3 benchmark (#5806) 2024-07-30 12:05:36 -07:00
wozeparrot 639af3f823
llama3 temperature flag (#5803) 2024-07-29 16:33:51 -07:00
George Hotz 8b34ee2f52
remove global_size and local_size from Kernel class [run_process_replay] (#5720)
* remove global_size and local_size from Kernel class [run_process_replay]

* sizes from the prg
2024-07-25 13:55:08 -07:00
George Hotz 7f5282b2f5
tests if the linearizer is generating dumb code (#5611)
* tests if the linearizer is generating dumb code

* push consts to the end

* sort adds

* sorted add and mul

* this better

* simple expand/contract

* no math contract/expand
2024-07-20 20:36:32 -07:00
George Hotz b399ccd6ef
BEAM bugfix, kernels dedup now (#5617)
* BEAM bugfix, kernels dedup now

* getenv is default
2024-07-20 19:43:50 -07:00
chenyu d71308ed68
copy mlperf 4.0 to mlperf 4.1 (#5614) 2024-07-20 16:12:00 -04:00
George Hotz 1113e47f96 print best in MCTS + light up the winner in hcopt 2024-07-20 09:39:36 -07:00
George Hotz 06e336bccb
mcts search (#5598)
* mcts search

* mcts cleanups

* mcts cleanup

* random shuffle children order

* mcts in handcode_opt

* src and remove_node

* debug 3 to print ast

* print the type

* mcts in extra
2024-07-19 21:38:39 -07:00
George Hotz 0ad87021e2
move acc to end (#5568)
* move acc to end

* confirmed pictures are the same

* relax that

* Update test_ops.py
2024-07-19 03:06:52 -07:00
George Hotz 2de82b8a5d
remove get_lazyop_info (#5570)
* don't use get_lazyop_info more

* keep that min

* no ptx for that test
2024-07-19 03:05:33 -07:00
kormann 2c4add6844
pretty print lazy op per default (#5505)
* pretty lop

* min diff

* walrus

* fix

* min diff

* simplify

* pretty helper function

* ws

* pretty uop upat

* tests

* stricter tests

* test passes

* ws

* stronger upat test

* delete print_tree

* min diff

* stricter exp test

* fix merge

* stronger uops eval test

* +readable and deep upat test

* +readable and deep upat test

* sort inv fix

* fix

* revert allowed_len
2024-07-18 09:34:08 -07:00
George Hotz fa7e734b49
MetaOps.KERNEL (#5543) 2024-07-17 19:41:23 -07:00
chenyu 4193095f67
fix handcode_opt.py with DEBUG=2 (#5530)
only one ast per kernel now
2024-07-17 14:50:47 -04:00
George Hotz a9f5a764dc
make BatchNorm work for 2D and 3D (#5477)
* make BatchNorm work for 2D and 3D

* beautiful mnist shouldn't use BatchNorm2d
2024-07-14 11:39:58 -07:00
George Hotz aade18d20c beautiful_mnist in torch 2024-07-14 11:09:58 -07:00
George Hotz cdf63e41bf mnist mlx example uses compile to be fair to tinyjit 2024-07-13 18:14:45 -07:00
George Hotz 8940530290 add mlx beautiful_mnist example 2024-07-13 17:55:47 -07:00
chenyu 28972418c4
s/get_linearizer/get_kernel [run_process_replay] (#5467) 2024-07-13 20:32:22 -04:00
Francis Lata 0345577032
UNet3D dataloader shared memory fix (#5465)
* create separate SharedMemory between inputs and labels

* update path check for shared mem

* clean up unit test for dataset
2024-07-13 20:26:00 -04:00
chenyu 4df63da190
clean up rest of the loadop [run_process_replay] (#5440)
to metaop and filter_sink
2024-07-12 23:38:51 -04:00
George Hotz 03c2dc8bd7
lowerer is kernel [run_process_replay] (#5437) 2024-07-12 18:50:55 -07:00
chenyu 9a187e6102
fix handcode_opt script (#5435)
* fix handcode_opt script

* run in ci

* real run in ci

* HALF=0
2024-07-12 20:52:28 -04:00
George Hotz 870dc8c350
s/Linearizer/Lowerer [run_process_replay] (#5428) 2024-07-12 15:54:07 -07:00
George Hotz 6707c778d0
scheduleitem is not Tuple [run_process_replay] (#5425)
* scheduleitem is not Tuple [run_process_replay]

* fix tests

* fix op + fuzzers

* fix mop test
2024-07-12 15:13:19 -07:00
George Hotz f6ef283e6a
s/loadops/metaops [run_process_replay] (#5421) 2024-07-12 13:26:50 -07:00
wozeparrot d1cbd6bb95
unify handcode_resnet_opt and handcode_bert_opt (#5418) 2024-07-12 12:05:01 -07:00
wozeparrot b7cc75a9df
usage summary in handcode opt (#5414) 2024-07-12 11:21:18 -07:00
George Hotz 8390feb7b9
optim.OptimizerGroup in hlb_cifar (#5401) 2024-07-11 20:14:36 -07:00
wozeparrot c24d495ef9
metadata in handcode_opt (#5400) 2024-07-11 17:45:34 -07:00
George Hotz 5232e405ce hotfix: add BS to beautiful_mnist 2024-07-11 10:55:05 -07:00
wozeparrot c9b3ae6bbf
fix llama.py chat mode assert (#5366) 2024-07-10 18:06:14 -07:00
wozeparrot fa873df9c1
bring tinychat more inline with tinyos' version (#5358) 2024-07-10 13:13:52 -07:00
chenyu 322c37e621
use helpers.JIT in llama and gpt2 examples (#5350)
* use helpers.JIT in llama and gpt2 examples

replaced getenv("JIT"), effectively made gpt2 default jit

* fix test_gpt2
2024-07-09 15:04:43 -04:00
Elias Wahl 73bddc44f6
Fix fake dataloader (#5326) 2024-07-08 09:07:44 -04:00
chenyu 43c3f73fbc
handcode_bert_opt.py (#5295)
similar to handcode_resnet50_opt.py, one file to check bert kernels without a dataset.
2024-07-05 11:01:20 -04:00
Tobias Fischer 0c3a35e5c2
Stable Diffusion v2 Inference (#5283)
* model implementation

* clip fix, more qol options
2024-07-03 22:47:10 -04:00
reddyn12 d3e244d8b7
prev speed improvements (#5252)
Co-authored-by: reddyn <nikidsniper@gmail.com>
2024-07-03 09:06:01 -07:00
chenyu 191463a919
add timing to SDXL (#5273) 2024-07-02 23:29:54 -04:00
chenyu b2c3a28a5e
nn.RMSNorm (#5272)
the norm itself isn't significant enough to add as a Tensor method, but we would want Tensor.normalize
2024-07-02 21:39:01 -04:00
Tobias Fischer 8c9c1cf62f
Pulled CLIP and UNet into Separate Files (#5253)
* pulled clip and unet into separate files

* reference cleanup, lru cache fix

* better pool indexing
2024-07-01 22:33:01 -04:00
chenyu b9122ecdaf
revert stable diffusion validation with threefry (#5248)
* Revert "use threefry in stable diffusion benchmark (#4988)"

This reverts commit 44dfa37c70.

* sdxl and validation fix

* relax threshold
2024-07-01 14:43:47 -04:00
George Hotz 3df47bc21e
OpenELM + repeat_interleave (#5234)
* start writing openelm

* progress...hit bug

* repeat_interleave support

* gqa

* add rotary embedding

* spp

* i think it runs correctly

* broken

* output is good now

* cleanups

* no io_uring on android
2024-06-30 15:18:39 -07:00
chenyu 88763eb9ff
fix stable_diffusion with fp16 (#5239) 2024-06-30 12:59:31 -04:00
chenyu 7090eac8cb
validate sdxl output and put it in benchmark (#5211)
* validate sdxl output and put it in benchmark

* don't print fetch progress_bar in CI
2024-06-28 11:40:52 -04:00
chenyu 63fa4e2a0e
fix seed = 0 in sdxl (#5209)
removed a few unneeded realize and contiguous calls too
2024-06-28 08:48:59 -04:00
Tobias Fischer 4688f97d48
Add SDXL Inference to Examples (#5206)
* added sdxl inference code

* fixed trailing whitespace

* use original impl code, removed unneeded numpy calls
2024-06-28 07:42:28 -04:00
chenyu 0ba093dea0
hotfix: only validate stable diffusion when using threefry (#5166) 2024-06-26 16:50:38 -04:00
chenyu e4a5870b36
validate stable_diffusion output (#5163)
changed default steps, forgot to update validation
2024-06-26 16:42:21 -04:00
nimlgen 21b225ac45
llama3 download works (#5160) 2024-06-26 22:45:13 +03:00
wozeparrot c91b3c4079
shard llama3 on 0 sometimes (#5157) 2024-06-26 11:50:57 -07:00
Elias Wahl e267f3161d
Add MLLogger (#5125)
* add MLPerf logger

* eval steps

* start with step 1

* compliance for 3.1.0 and 4.0.0

* more compliance

* assert, comment and contiguous
2024-06-26 12:23:56 -04:00
David Hou 3604642847
Llama shard axis 0 sometimes (#5123)
* make buffer view optional with a flag [run_process_replay]

* do not view when sharding to save memory [run_process_replay]

* llama shard axis=0 sometimes

---------

Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
Co-authored-by: chenyu <chenyu@fastmail.com>
2024-06-26 10:35:25 -04:00
chenyu dade7677cf
validate llama3 output only with model "LLaMA-3/8B-SF-DPO" (#5138) 2024-06-24 20:58:25 -04:00
chenyu 055e616302
cleanup mnist data load in beautiful_mnist (#5106) 2024-06-22 18:31:51 -04:00
chenyu e356807696
tinytqdm.set_description and tinytrange (#5101) 2024-06-22 14:45:06 -04:00
chenyu 8080298739
s/tinytqdm/tqdm (#5103)
except in the unit test where tqdm is imported
2024-06-22 14:18:26 -04:00
chenyu e468601226
update llama attention casting (#5096)
* update llama attention casting

updated scaled_dot_product_attention middle cast and removed hard-coded half in llama attention.

* fix that
2024-06-22 10:57:17 -04:00
wozeparrot acb715c64c
fix: llama3 special tokens (#5045) 2024-06-18 17:08:44 -07:00
chenyu a3ed4176c8
use tinytqdm in active tests and examples (#5038)
* use tinytqdm in active tests and examples

stress test this before 0.9.1

* no set_description
2024-06-18 16:01:19 -04:00
Elias Wahl f31ef11537
Better default hparams for large BS (#5030)
* better default hparams for large BS

* bf16 too

* use tuple
2024-06-18 11:13:06 -04:00
Elias Wahl 7bfa9101c0
Float in scaled dot product attention (#4985)
* Monkeypatch scaled-dot-product-attention

* Use dot instead of matmul

* new api

* imports

* least_upper_dtype
2024-06-18 08:16:41 -04:00
chenyu c52352bd9a
fix yolov8 example (#5003)
it was creating a Tensor from a list of numpy arrays, which is not supported after Tensor creation from a list stopped using numpy.
2024-06-16 20:47:29 -04:00
chenyu 44dfa37c70
use threefry in stable diffusion benchmark (#4988)
also updated default steps to 10; it's easier to tell the image is following the prompt.
2024-06-15 20:25:29 -04:00
wozeparrot ce1ed374c9
more tinychat fixes (#4971) 2024-06-15 16:29:39 -07:00
wozeparrot 8209cd3c55
easier llama3 + fetch subdir (#4938) 2024-06-14 13:47:27 -07:00
chenyu 67e8df4969
remove numpy from dtype (#4969)
replaced all dtype.np with _to_np_dtype defined in tensor.py.

after this, the only numpy usages are (1) Tensor(np.ndarray), (2) constructing .numpy() output, (3) numpy random buffer
2024-06-14 15:38:45 -04:00
wozeparrot 2a974ff257
fix: no readablestream await of, too new (#4965) 2024-06-14 11:22:19 -07:00
Elias Wahl d2e3c391e8
Residual in MLM loss + Change default steps (#4935)
* Residual in mlm loss

* Reduce default steps to 160K * 24

* oops

* comment
2024-06-12 16:09:18 -04:00
wozeparrot 3d13c23bfa
llama3 `--download_model` (#4922) 2024-06-11 22:59:59 -07:00
wozeparrot 2849d0a2a1
fix copying to clipboard on a non secure context (#4890) 2024-06-08 16:51:47 -07:00
wozeparrot 6c24eda522
feat: tinychat (#4869) 2024-06-08 12:05:45 -07:00
Brennan Kinney 9445946cae
docs: Update referenced yaml in `yolov8.py` (#4871)
YAML files have since been relocated.
2024-06-08 15:05:00 -04:00
Nik 085c0bbf6b
add mlperf train subset of openimages (#4841) 2024-06-05 10:10:11 -04:00
Elias Wahl e576aca044
Disable dropout (#4837) 2024-06-04 18:57:26 -04:00