import tinygrad.nn as nn
from tinygrad import Tensor, dtypes
from tinygrad.nn.state import torch_load
from tinygrad.helpers import fetch, get_child

# allow monkeypatching in layer implementations
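# e.g. the MLPerf ResNet training script swaps BatchNorm for an unsynced multi-device variant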
BatchNorm = nn.BatchNorm2d
Conv2d = nn.Conv2d
Linear = nn.Linear

class BasicBlock:
  expansion = 1

  def __init__(self, in_planes, planes, stride=1, groups=1, base_width=64):
    assert groups == 1 and base_width == 64, "BasicBlock only supports groups=1 and base_width=64"
    self.conv1 = Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
    self.bn1 = BatchNorm(planes)
    self.conv2 = Conv2d(planes, planes, kernel_size=3, padding=1, stride=1, bias=False)
    self.bn2 = BatchNorm(planes)
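    # a 1x1 conv + BN projection brings the shortcut to the right shape when the stride or
    # channel count changes; otherwise downsample stays empty and the skip path is the identity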
    self.downsample = []
    if stride != 1 or in_planes != self.expansion*planes:
      self.downsample = [
        Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
        BatchNorm(self.expansion*planes)
      ]

  def __call__(self, x):
    out = self.bn1(self.conv1(x)).relu()
    out = self.bn2(self.conv2(out))
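    # residual add: sum in the (possibly projected) input, then apply the final ReLU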
    out = out + x.sequential(self.downsample)
    out = out.relu()
    return out

class Bottleneck:
  # NOTE: stride_in_1x1=False, this is the v1.5 variant
  expansion = 4

  def __init__(self, in_planes, planes, stride=1, stride_in_1x1=False, groups=1, base_width=64):
    width = int(planes * (base_width / 64.0)) * groups
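    # base_width and groups widen the 3x3 conv (wide/ResNeXt-style variants); with the defaults
    # (groups=1, base_width=64) width == planes, i.e. the plain ResNet bottleneck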
    # NOTE: the original implementation places stride at the first convolution (self.conv1), control with stride_in_1x1
    self.conv1 = Conv2d(in_planes, width, kernel_size=1, stride=stride if stride_in_1x1 else 1, bias=False)
    self.bn1 = BatchNorm(width)
    self.conv2 = Conv2d(width, width, kernel_size=3, padding=1, stride=1 if stride_in_1x1 else stride, groups=groups, bias=False)
    self.bn2 = BatchNorm(width)
    self.conv3 = Conv2d(width, self.expansion*planes, kernel_size=1, bias=False)
    self.bn3 = BatchNorm(self.expansion*planes)
    self.downsample = []
    if stride != 1 or in_planes != self.expansion*planes:
      self.downsample = [
        Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
        BatchNorm(self.expansion*planes)
      ]

  def __call__(self, x):
    out = self.bn1(self.conv1(x)).relu()
    out = self.bn2(self.conv2(out)).relu()
    out = self.bn3(self.conv3(out))
    out = out + x.sequential(self.downsample)
    out = out.relu()
    return out

class ResNet:
  def __init__(self, num, num_classes=None, groups=1, width_per_group=64, stride_in_1x1=False):
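    # num selects a standard depth (18/34/50/101/152); with num_classes=None no fc head is created
    # and forward() runs in feature-extraction mode (see is_feature_only below)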
    self.num = num
    self.block = {
      18: BasicBlock,
      34: BasicBlock,
      50: Bottleneck,
      101: Bottleneck,
      152: Bottleneck
    }[num]

    self.num_blocks = {
      18: [2,2,2,2],
      34: [3,4,6,3],
      50: [3,4,6,3],
      101: [3,4,23,3],
      152: [3,8,36,3]
    }[num]

    self.in_planes = 64
    self.groups = groups
    self.base_width = width_per_group
    self.conv1 = Conv2d(3, 64, kernel_size=7, stride=2, bias=False, padding=3)
    self.bn1 = BatchNorm(64)
    self.layer1 = self._make_layer(self.block, 64, self.num_blocks[0], stride=1, stride_in_1x1=stride_in_1x1)
    self.layer2 = self._make_layer(self.block, 128, self.num_blocks[1], stride=2, stride_in_1x1=stride_in_1x1)
    self.layer3 = self._make_layer(self.block, 256, self.num_blocks[2], stride=2, stride_in_1x1=stride_in_1x1)
    self.layer4 = self._make_layer(self.block, 512, self.num_blocks[3], stride=2, stride_in_1x1=stride_in_1x1)
    self.fc = Linear(512 * self.block.expansion, num_classes) if num_classes is not None else None
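    # e.g. ResNet(50, num_classes=1000) builds the ImageNet-style classifier; ResNet(50) builds a feature backbone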

  def _make_layer(self, block, planes, num_blocks, stride, stride_in_1x1):
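    # only the first block of a layer applies the requested stride; later blocks run at stride 1,
    # and in_planes advances to the block's expanded output width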
    strides = [stride] + [1] * (num_blocks-1)
    layers = []
    for stride in strides:
      if block == Bottleneck:
        layers.append(block(self.in_planes, planes, stride, stride_in_1x1, self.groups, self.base_width))
      else:
        layers.append(block(self.in_planes, planes, stride, self.groups, self.base_width))
      self.in_planes = planes * block.expansion
    return layers

  def forward(self, x):
    is_feature_only = self.fc is None
    if is_feature_only: features = []
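    # without a classification head the model serves as a backbone (e.g. for the MaskRCNN example):
    # collect the output of each residual stage instead of only the final logits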
    out = self.bn1(self.conv1(x)).relu()
    out = out.pad2d([1,1,1,1]).max_pool2d((3,3), 2)
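    # stem: the 7x7 stride-2 conv plus this padded 3x3 stride-2 max pool downsample the input 4x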
    out = out.sequential(self.layer1)
if is_feature_only: features.append(out)
|
2021-11-30 07:05:31 +08:00
|
|
|
out = out.sequential(self.layer2)
|
2023-06-26 06:37:51 +08:00
|
|
|
if is_feature_only: features.append(out)
|
2021-11-30 07:05:31 +08:00
|
|
|
out = out.sequential(self.layer3)
|
2023-06-26 06:37:51 +08:00
|
|
|
if is_feature_only: features.append(out)
|
2021-11-30 07:05:31 +08:00
|
|
|
out = out.sequential(self.layer4)
|
2023-06-26 06:37:51 +08:00
|
|
|
if is_feature_only: features.append(out)
|
|
|
|
if not is_feature_only:
|
|
|
|
out = out.mean([2,3])
|
2024-03-28 13:25:37 +08:00
|
|
|
out = self.fc(out.cast(dtypes.float32))
|
2023-06-26 06:37:51 +08:00
|
|
|
return out
|
|
|
|
return features
|
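# Note (illustrative, not part of the original file): when self.fc is None the
# forward pass above returns the layer1-layer4 outputs instead of logits, which
# is the feature-pyramid input that detection models such as Mask R-CNN expect.
# A minimal usage sketch, assuming the ResNet constructor defined in this file:
#
#   backbone = ResNet(50, num_classes=1000)
#   backbone.load_from_pretrained()
#   backbone.fc = None                              # switch to feature-only mode
#   feats = backbone(Tensor.rand(1, 3, 224, 224))   # list of 4 feature maps
#   # for a 224x224 input the maps typically have strides 4, 8, 16 and 32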
2021-06-22 00:37:24 +08:00
|
|
|
|
2023-09-30 00:34:51 +08:00
|
|
|
def __call__(self, x:Tensor) -> Tensor:
|
2021-06-22 00:37:24 +08:00
|
|
|
return self.forward(x)
|
|
|
|
|
2021-12-01 05:14:54 +08:00
|
|
|
def load_from_pretrained(self):
|
2022-01-16 12:22:10 +08:00
|
|
|
model_urls = {
|
2023-05-29 11:20:16 +08:00
|
|
|
(18, 1, 64): 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
|
|
|
|
(34, 1, 64): 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
|
|
|
|
(50, 1, 64): 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
|
|
|
|
(50, 32, 4): 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
|
|
|
|
(101, 1, 64): 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
|
|
|
|
(152, 1, 64): 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
|
2022-01-16 12:22:10 +08:00
|
|
|
}
|
|
|
|
|
2023-05-29 11:20:16 +08:00
|
|
|
self.url = model_urls[(self.num, self.groups, self.base_width)]
|
2024-09-06 12:50:21 +08:00
|
|
|
for k, dat in torch_load(fetch(self.url)).items():
|
2023-11-23 09:41:12 +08:00
|
|
|
obj: Tensor = get_child(self, k)
|
2022-01-16 12:22:01 +08:00
|
|
|
|
2022-01-16 12:22:10 +08:00
|
|
|
if 'fc.' in k and obj.shape != dat.shape:
|
2022-06-06 08:12:43 +08:00
|
|
|
print("skipping fully connected layer")
|
2022-01-16 12:22:10 +08:00
|
|
|
continue  # skip the fc weights when transfer learning to a different class count
|
|
|
|
|
2024-06-27 06:44:10 +08:00
|
|
|
if 'bn' not in k and 'downsample' not in k: assert obj.shape == dat.shape, (k, obj.shape, dat.shape)
|
2024-09-06 13:06:02 +08:00
|
|
|
obj.assign(dat.to(obj.device).reshape(obj.shape))
|
2021-06-22 00:37:24 +08:00
|
|
|
|
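# Illustrative sketch (not part of the original file) of the key traversal that
# load_from_pretrained relies on. The real helper is tinygrad.helpers.get_child,
# imported at the top of this file; this re-implementation is an assumption about
# its behaviour and may differ in detail. Numeric path components index into
# lists (the blocks inside layer1..layer4), everything else is attribute access,
# so "layer1.0.conv1.weight" resolves to model.layer1[0].conv1.weight.
def _get_child_sketch(obj, key: str):
  for part in key.split('.'):
    obj = obj[int(part)] if part.isnumeric() else getattr(obj, part)
  return obj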
2022-01-16 12:22:10 +08:00
|
|
|
ResNet18 = lambda num_classes=1000: ResNet(18, num_classes=num_classes)
|
|
|
|
ResNet34 = lambda num_classes=1000: ResNet(34, num_classes=num_classes)
|
|
|
|
ResNet50 = lambda num_classes=1000: ResNet(50, num_classes=num_classes)
|
|
|
|
ResNet101 = lambda num_classes=1000: ResNet(101, num_classes=num_classes)
|
|
|
|
ResNet152 = lambda num_classes=1000: ResNet(152, num_classes=num_classes)
|
2024-06-27 06:44:10 +08:00
|
|
|
ResNeXt50_32X4D = lambda num_classes=1000: ResNet(50, num_classes=num_classes, groups=32, width_per_group=4)
|
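# A minimal transfer-learning sketch (hypothetical helper, not in the original
# file): because load_from_pretrained skips 'fc.' weights whose shapes differ,
# a model built with a custom class count loads the ImageNet backbone and keeps
# its freshly initialized head. The 10-class head is an arbitrary example.
def finetune_resnet50(num_classes=10):
  model = ResNet50(num_classes=num_classes)
  model.load_from_pretrained()  # prints "skipping fully connected layer" for fc.weight/fc.bias
  return model                  # fine-tune model.fc (and optionally the backbone) on the new task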
2024-09-06 12:50:21 +08:00
|
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
|
model = ResNet18()
|
|
|
|
model.load_from_pretrained()
|
|
|
|
from tinygrad import GlobalCounters, TinyJit
|
|
|
|
jmodel = TinyJit(model)
|
|
|
|
jmodel(Tensor.rand(1, 3, 224, 224)).realize()
|
|
|
|
GlobalCounters.reset()
|
2024-10-16 11:40:07 +08:00
|
|
|
jmodel(Tensor.rand(1, 3, 224, 224)).realize()
|
2024-09-06 12:50:21 +08:00
|
|
|
for i in range(10): jmodel(Tensor.rand(1, 3, 224, 224))
|
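# Optional extension of the benchmark above (a sketch, not in the original file):
# the warm-up calls let TinyJit capture the kernels, and resetting the counters
# immediately before a call yields the op/memory totals of one forward pass.
# global_ops, global_mem and kernel_count are the counter names in recent
# tinygrad releases; treat them as an assumption and check tinygrad.helpers if absent.
GlobalCounters.reset()
jmodel(Tensor.rand(1, 3, 224, 224)).realize()
print(f"{GlobalCounters.global_ops/1e9:.2f} GOPS, {GlobalCounters.global_mem/1e9:.2f} GB moved, {GlobalCounters.kernel_count} kernels")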