mirror of https://github.com/commaai/tinygrad.git
add lars to nn (#3750)
* feat: add lars
* feat: don't remove this comment
* clean: smaller diff
* clean: shorter line
* feat: remove mlperf lars, switch resnet
* fix: fully remove mlperf lars
* clean: comment
* feat: contiguous
* feat: no weight decay on skip params
* feat: optimizergroup
* feat: classic momentum
* fix: pylint
* clean: move comment
* fix: correct algo
* feat: lrschedulergroup
* feat: skip list tests
* feat: :| forgot that params are a thing
* feat: remove skip_list params from main params
* feat: set moment

---------

Co-authored-by: chenyu <chenyu@fastmail.com>
This commit is contained in:
parent 8c8b57fd5f
commit 9a9cac58f9
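Not part of the commit, but for orientation: a minimal sketch of how the pieces introduced here (LARS, SGD and OptimizerGroup from tinygrad.nn.optim, LRSchedulerGroup from extra.lr_scheduler) are meant to compose, mirroring the ResNet setup in the first hunks below. The toy module and every hyperparameter are invented for illustration, and the imports assume you run from the tinygrad repo root at this revision.

from tinygrad import Tensor
from tinygrad.nn import Conv2d, BatchNorm2d
from tinygrad.nn.state import get_state_dict, get_parameters
from tinygrad.nn.optim import LARS, SGD, OptimizerGroup
from extra.lr_scheduler import LRSchedulerGroup
from examples.mlperf.lr_schedulers import PolynomialDecayWithWarmup

class ToyNet:  # hypothetical model, only for illustration
  def __init__(self):
    self.conv = Conv2d(3, 8, 3, bias=False)
    self.bn = BatchNorm2d(8)
  def __call__(self, x): return self.bn(self.conv(x)).relu()

net = ToyNet()
# bn and bias tensors skip LARS (no trust ratio, no weight decay) and get plain SGD instead
skip = [v for k, v in get_state_dict(net).items() if "bn" in k or "bias" in k]
params = [p for p in get_parameters(net) if p not in set(skip)]
opt = OptimizerGroup(LARS(params, 0.1, momentum=.9, weight_decay=2e-4),
                     SGD(skip, 0.1, momentum=.9, weight_decay=0.0, classic=True))
sched = LRSchedulerGroup(*[PolynomialDecayWithWarmup(o, initial_lr=0.1, end_lr=1e-4, train_steps=100, warmup=10)
                           for o in opt.optimizers])

with Tensor.train():
  x, y = Tensor.rand(4, 3, 8, 8), Tensor.rand(4, 8, 6, 6)
  opt.zero_grad()
  ((net(x) - y) ** 2).mean().backward()
  opt.step()    # steps both sub-optimizers, realized together
  sched.step()  # steps both schedulers in lockstep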
@@ -6,7 +6,9 @@ from tqdm import tqdm
 from tinygrad import Device, GlobalCounters, Tensor, TinyJit, dtypes
 from tinygrad.helpers import getenv, BEAM, WINO
 from tinygrad.nn.state import get_parameters, get_state_dict, safe_load, safe_save
+from tinygrad.nn.optim import LARS, SGD, OptimizerGroup

+from extra.lr_scheduler import LRSchedulerGroup
 from examples.mlperf.helpers import get_training_state, load_training_state

 def train_resnet():
@@ -60,20 +62,26 @@ def train_resnet():
   config["SYNCBN"] = getenv("SYNCBN")

   # ** Optimizer **
-  from examples.mlperf.optimizers import LARS
-  skip_list = {v for k, v in get_state_dict(model).items() if "bn" in k or "bias" in k or "downsample.1" in k}
-  optimizer = LARS(parameters, base_lr, momentum=.9, weight_decay=decay, skip_list=skip_list)
+  skip_list = [v for k, v in get_state_dict(model).items() if "bn" in k or "bias" in k or "downsample.1" in k]
+  parameters = [x for x in parameters if x not in set(skip_list)]
+  optimizer = LARS(parameters, base_lr, momentum=.9, weight_decay=decay)
+  optimizer_skip = SGD(skip_list, base_lr, momentum=.9, weight_decay=0.0, classic=True)
+  optimizer_group = OptimizerGroup(optimizer, optimizer_skip)

   # ** LR scheduler **
   scheduler = PolynomialDecayWithWarmup(optimizer, initial_lr=base_lr, end_lr=1e-4,
                                         train_steps=epochs * steps_in_train_epoch,
                                         warmup=lr_warmup_epochs * steps_in_train_epoch)
+  scheduler_skip = PolynomialDecayWithWarmup(optimizer_skip, initial_lr=base_lr, end_lr=1e-4,
+                                             train_steps=epochs * steps_in_train_epoch,
+                                             warmup=lr_warmup_epochs * steps_in_train_epoch)
+  scheduler_group = LRSchedulerGroup(scheduler, scheduler_skip)
   print(f"training with batch size {BS} for {epochs} epochs")

   # ** resume from checkpointing **
   start_epoch = 0
   if ckpt:=getenv("RESUME", ""):
-    load_training_state(model, optimizer, scheduler, safe_load(ckpt))
+    load_training_state(model, optimizer_group, scheduler_group, safe_load(ckpt))
     start_epoch = int(scheduler.epoch_counter.numpy().item() / steps_in_train_epoch)
     print(f"resuming from {ckpt} at epoch {start_epoch}")

@@ -93,14 +101,14 @@ def train_resnet():
   def normalize(x): return x.permute([0, 3, 1, 2]) - input_mean
   @TinyJit
   def train_step(X, Y):
-    optimizer.zero_grad()
+    optimizer_group.zero_grad()
     X = normalize(X)
     out = model.forward(X)
     loss = out.sparse_categorical_crossentropy(Y, label_smoothing=0.1)
     top_1 = (out.argmax(-1) == Y).sum()
     loss.backward()
-    optimizer.step()
-    scheduler.step()
+    optimizer_group.step()
+    scheduler_group.step()
     return loss.realize(), top_1.realize()
   @TinyJit
   def eval_step(X, Y):
@@ -214,7 +222,7 @@ def train_resnet():
       else:
         fn = f"./ckpts/{time.strftime('%Y%m%d_%H%M%S')}_e{e}.safe"
       print(f"saving ckpt to {fn}")
-      safe_save(get_training_state(model, optimizer, scheduler), fn)
+      safe_save(get_training_state(model, optimizer_group, scheduler_group), fn)

 def train_retinanet():
   # TODO: Retinanet
@@ -1,38 +0,0 @@
-from typing import List, Set
-
-from tinygrad import Tensor
-from tinygrad.nn.optim import Optimizer
-
-# https://github.com/mlcommons/training/blob/master/image_classification/tensorflow2/lars_optimizer.py
-class LARS(Optimizer):
-  def __init__(self, params: List[Tensor], lr, momentum=0.9, weight_decay=1e-4, eta=0.001, eps=0.0, skip_list=None, nesterov=False):
-    super().__init__(params, lr)
-    assert momentum >= 0.0 and weight_decay >= 0.0
-    self.momentum, self.weight_decay, self.eta, self.eps, self.nesterov = momentum, weight_decay, eta, eps, nesterov
-    self.b = [Tensor.zeros(*t.shape, device=t.device, requires_grad=False) for t in self.params]
-    self.skip_list = set(skip_list or [])
-
-  def step(self):
-    for i, t in enumerate(self.params):
-      assert t.grad is not None
-      g = t.grad.contiguous()
-      w = t.detach()
-
-      if t not in self.skip_list:
-        g_norm = (g * g).sum().sqrt()
-        w_norm = (w * w).sum().sqrt()
-        trust_ratio = ((w_norm > 0) * (g_norm > 0)).where(
-          self.eta * w_norm / (g_norm + self.weight_decay * w_norm + self.eps),
-          1.0)
-
-        scaled_lr = self.lr * trust_ratio
-        g = g + self.weight_decay * t.detach()
-      else:
-        scaled_lr = self.lr
-
-      g = g * scaled_lr
-      if self.momentum:
-        self.b[i].assign(self.momentum * self.b[i] + g)
-        g = (g + self.momentum * self.b[i]) if self.nesterov else self.b[i]
-      t.assign(t.detach() - g)
-    self.realize(self.b)
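For reference, the update rule this removed MLPerf port implemented, written out per parameter tensor w with gradient g (α = lr, λ = weight_decay, μ = momentum, η = eta, ε = eps):

r = \eta\,\lVert w\rVert \,/\, \bigl(\lVert g\rVert + \lambda\,\lVert w\rVert + \epsilon\bigr) \quad \text{if } \lVert w\rVert>0,\ \lVert g\rVert>0 \text{ and } w \notin \text{skip\_list}, \text{ else } r = 1
g \leftarrow \alpha\, r\,(g + \lambda w) \qquad \text{(skipped tensors: } g \leftarrow \alpha g \text{, no weight decay, no trust ratio)}
b \leftarrow \mu b + g, \qquad g \leftarrow b \ \text{(or } g + \mu b \text{ with Nesterov)}, \qquad w \leftarrow w - g

The replacement in tinygrad.nn.optim (further down in this diff) computes the same trust ratio with tcoef playing the role of eta and no eps term; the skip behaviour moves out of the optimizer and into the SGD member of an OptimizerGroup.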
@@ -14,6 +14,11 @@ class LR_Scheduler:
     self.epoch_counter.assign(self.epoch_counter + 1).realize()
     self.optimizer.lr.assign(self.get_lr()).realize()

+class LRSchedulerGroup:
+  def __init__(self, *schedulers: LR_Scheduler): self.schedulers = schedulers
+  def step(self) -> None:
+    for s in self.schedulers: s.step()
+
 class MultiStepLR(LR_Scheduler):
   def __init__(self, optimizer: Optimizer, milestones: List[int], gamma=0.1):
     super().__init__(optimizer)
@@ -4,11 +4,11 @@ import numpy as np
 import tensorflow as tf
 import tensorflow_addons as tfa
 from tensorflow.python.ops import math_ops
+from extra.lr_scheduler import LRSchedulerGroup

 from tinygrad.tensor import Tensor
-from tinygrad.nn.optim import LAMB
+from tinygrad.nn.optim import LAMB, LARS, SGD, OptimizerGroup

-from examples.mlperf.optimizers import LARS
 from test.external.mlperf_resnet.lars_optimizer import LARSOptimizer

 from examples.mlperf.lr_schedulers import PolynomialDecayWithWarmup
@@ -55,7 +55,7 @@ def step(optim, steps=1, kwargs={}, scheduler=None, schedopts=None, do_optim=Tru
     out = net.forward()
     optim.zero_grad()
     out.backward()
-    lrs.append(optim.lr.numpy().item())
+    lrs.append(optim.lr.item() if not isinstance(optim, OptimizerGroup) else optim.optimizers[0].lr.item())
     if do_optim: optim.step()
     if scheduler is not None: scheduler.step()
   return lrs, net.x.detach().numpy(), net.W.detach().numpy()
@@ -83,11 +83,19 @@ def step_tf(optim, steps=1, kwargs={}, scheduler=None, schedopts=None, do_optim=
     optim._iterations.assign_add(1)
   return lrs, net.x.numpy(), net.W.numpy()

-# skip_list=True -> skip W
-def create_tiny_lars(params, lr, skip_list=False): return LARS(params, lr, skip_list=[params[1]] if skip_list else None)
+# skip list is skipping W
+def create_tiny_lars(params, lr, skip_list=False):
+  if skip_list: return OptimizerGroup(LARS([params[0]], lr), SGD([params[1]], lr, classic=True, weight_decay=0., momentum=.9))
+  return LARS(params, lr)
 def create_tf_lars(lr, skip_list=False): return LARSOptimizer(lr, skip_list=["W"] if skip_list else None)

-def create_tf_polylr(initial_lr, end_lr, train_steps, warmup, power=2):
+def create_tiny_polylr(optim, initial_lr, end_lr, train_steps, warmup, power=2, skip_list=False):
   assert power == 2
+  if skip_list: return LRSchedulerGroup(
+    PolynomialDecayWithWarmup(optim[0], initial_lr, end_lr, train_steps, warmup, power),
+    PolynomialDecayWithWarmup(optim[1], initial_lr, end_lr, train_steps, warmup, power))
+  return PolynomialDecayWithWarmup(optim, initial_lr, end_lr, train_steps, warmup, power)
+def create_tf_polylr(initial_lr, end_lr, train_steps, warmup, power=2, skip_list=False):
+  assert power == 2
   return PolynomialDecayWithWarmup_tf(1, 1, train_steps,
                                       initial_learning_rate=initial_lr, end_learning_rate=end_lr, warmup_epochs=warmup)
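A hedged sketch of how these helpers are meant to be driven in the skip case (the two tensors stand in for the test net's x and W; the scheduler values mirror the test opts further down). Indexing the group with optim[0] and optim[1] is what the new OptimizerGroup.__getitem__ is for:

x, W = Tensor.rand(4), Tensor.rand(4)  # stand-ins for the test net's x and W
optim = create_tiny_lars([x, W], 0.01, skip_list=True)  # OptimizerGroup(LARS([x], ...), SGD([W], ...))
sched = create_tiny_polylr(optim, 1.0, 1e-4, 10, 3, skip_list=True)  # one PolynomialDecayWithWarmup per sub-optimizer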
@@ -102,7 +110,7 @@ class ExternalTestOptim(unittest.TestCase):
   def _test_lars(self, steps, opts, atol, rtol): self._test_optim(create_tiny_lars, create_tf_lars, steps, opts, atol, rtol)
   def _test_lars_polylr(self, steps, opts, schedopts, atol, rtol, do_optim=True):
     self._test_optim(create_tiny_lars, create_tf_lars, steps, opts, atol, rtol,
-                     tiny_sched=PolynomialDecayWithWarmup, tf_sched=create_tf_polylr, schedopts=schedopts, do_optim=do_optim)
+                     tiny_sched=create_tiny_polylr, tf_sched=create_tf_polylr, schedopts=schedopts, do_optim=do_optim)

   def test_lamb(self): self._test_lamb(1, {'lr': 0.001}, 1e-5, 0)
   def test_lamb_high_lr(self): self._test_lamb(1, {'lr': 10}, 1e-5, 1e-5)
@@ -112,9 +120,12 @@ class ExternalTestOptim(unittest.TestCase):

   def test_lars(self): self._test_lars(1, {'lr': 0.01}, 1e-5, 0)
   def test_lars_high_lr(self): self._test_lars(1, {'lr': 10}, 1e-5, 1e-5)
-  def test_multistep_lars(self): self._test_lamb(10, {'lr': 0.001}, 1e-5, 0)
-  def test_multistep_lars_high_lr(self): self._test_lamb(10, {'lr': 10}, 1e-5, 3e-4)
-  def test_lars_skip_list(self): self._test_lars(1, {'lr': 0.01, 'skip_list': True}, 1e-5, 0)
+  def test_multistep_lars(self): self._test_lars(10, {'lr': 0.001}, 1e-5, 0)
+  def test_multistep_lars_high_lr(self): self._test_lars(10, {'lr': 10}, 1e-5, 3e-4)
+  def test_lars_skip(self): self._test_lars(10, {'lr': 10, 'skip_list': True}, 1e-5, 3e-4)
+  def test_lars_skip_high_lr(self): self._test_lars(1, {'lr': 10, 'skip_list': True}, 1e-5, 1e-5)
+  def test_lars_skip_multistep(self): self._test_lars(10, {'lr': 0.001, 'skip_list': True}, 1e-5, 0)
+  def test_lars_skip_multistep_high_lr(self): self._test_lars(10, {'lr': 10, 'skip_list': True}, 1e-5, 3e-4)

   def test_lars_polylr(self):
     self._test_lars_polylr(10, {'lr': 1.0}, {
@@ -130,6 +141,15 @@ class ExternalTestOptim(unittest.TestCase):
       'train_steps': 100,
       'warmup': 43
     }, 1e-5, 1e-5, do_optim=False)
+  def test_lars_polylr_skip(self):
+    self._test_lars_polylr(10, {'lr': 1.0, 'skip_list': True}, {
+      'initial_lr': 1.0,
+      'end_lr': 1e-4,
+      'train_steps': 10,
+      'warmup': 3,
+      'skip_list': True
+    }, 1e-5, 1e-5)
+
   @unittest.skip("slow, but you can run this locally to check")
   def test_lars_polylr_resnet(self):
     train_files = 1_281_167
@@ -237,7 +237,7 @@ class TestMultiTensor(unittest.TestCase):
     import sys, pathlib
     sys.path.append((pathlib.Path(__file__).parent.parent / "extra" / "models").as_posix())
     from resnet import ResNet18
-    from examples.mlperf.optimizers import LARS
+    from tinygrad.nn.optim import LARS

     fake_image = Tensor.rand((2, 3, 224, 224))
     fake_image_sharded = fake_image.shard((d0, d1), axis=0)
@@ -1,6 +1,6 @@
 # sorted in order of increasing complexity
-from typing import List
-from tinygrad.helpers import dedup, getenv
+from typing import List, Optional
+from tinygrad.helpers import dedup, flatten, getenv
 from tinygrad.tensor import Tensor

 class Optimizer:
@@ -21,26 +21,48 @@ class Optimizer:
   def realize(self, extra=None):
     Tensor.corealize(extra + self.params + self.buffers if extra is not None else self.params + self.buffers)

-  def step(self) -> None: raise NotImplementedError
+  def step(self, extra:Optional[List[Tensor]]=None): self.realize(self._step() + (extra if extra is not None else []))
+  def _step(self) -> List[Tensor]: raise NotImplementedError

-class SGD(Optimizer):
-  def __init__(self, params: List[Tensor], lr=0.001, momentum=0, weight_decay=0.0, nesterov=False):
+class OptimizerGroup(Optimizer):
+  def __init__(self, *optimizers: Optimizer): # pylint: disable=super-init-not-called
+    self.optimizers = optimizers
+    self.params, self.buffers = flatten([o.params for o in self.optimizers]), flatten([o.buffers for o in self.optimizers])
+  def __getitem__(self, i): return self.optimizers[i]
+  def zero_grad(self): [o.zero_grad() for o in self.optimizers]
+  def _step(self) -> List[Tensor]: return [x for o in self.optimizers for x in o._step()]
+
+# LARS is essentially just trust ratio to SGD so if we just set the trust coeff 0.0 its just standard SGD.
+def SGD(params: List[Tensor], lr=0.001, momentum=0.0, weight_decay=0.0, nesterov=False, classic=False):
+  return LARS(params, lr, momentum, weight_decay, nesterov, classic, tcoef=0.0)
+
+class LARS(Optimizer):
+  def __init__(self, params:List[Tensor], lr=0.001, momentum=0.9, weight_decay=1e-4, nesterov=False, classic=True, tcoef=0.001):
     super().__init__(params, lr)
-    self.momentum, self.wd, self.nesterov = momentum, weight_decay, nesterov
+    self.momentum, self.wd, self.nesterov, self.classic, self.tcoef = momentum, weight_decay, nesterov, classic, tcoef
     self.b = [Tensor.zeros(*t.shape, device=t.device, requires_grad=False) for t in self.params] if self.momentum else []

-  # https://pytorch.org/docs/stable/generated/torch.optim.SGD.html
-  def step(self) -> None:
+  def _step(self) -> List[Tensor]:
     for i, t in enumerate(self.params):
       assert t.grad is not None
       # contiguous is needed since the grads can allegedly form a "diamond"
       # TODO: fix this in lazy.py
-      g = t.grad.contiguous() + self.wd * t.detach()
+      g = t.grad.contiguous()
+      if self.tcoef != 0:
+        r1 = t.detach().square().sum().sqrt()
+        r2 = g.square().sum().sqrt()
+        r = (r1 > 0).where((r2 > 0).where(self.tcoef * r1 / (r2 + self.wd * r1), 1.0), 1.0)
+      else: r = 1.0
+      g = g + self.wd * t.detach()
+      # classic momentum does post learning rate update
+      if self.classic: g = g * r * self.lr
       if self.momentum:
         self.b[i].assign(self.momentum * self.b[i] + g) # NOTE: self.b[i] is zero on the first run, no if required
         g = (g + self.momentum * self.b[i]) if self.nesterov else self.b[i]
-      t.assign(t.detach() - g * self.lr)
-    self.realize(self.b)
+      # popular momentum does pre learning rate update
+      if not self.classic: g = g * r * self.lr
+      t.assign(t.detach() - g)
+    return self.b

 # LAMB is essentially just the trust ratio part of LARS applied to Adam/W so if we just set the trust ratio to 1.0 its just Adam/W.
 def AdamW(params: List[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-8, wd=0.01): return LAMB(params, lr, b1, b2, eps, wd, adam=True)
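Written out, the update the new LARS._step performs for each parameter tensor w with gradient g (α = lr, λ = weight_decay, μ = momentum, η = tcoef):

r = \eta\,\lVert w\rVert \,/\, \bigl(\lVert g\rVert + \lambda\,\lVert w\rVert\bigr) \quad \text{if } \eta \ne 0,\ \lVert w\rVert>0,\ \lVert g\rVert>0; \qquad \text{otherwise } r = 1
g \leftarrow g + \lambda w
\text{classic: } g \leftarrow \alpha r g;\quad b \leftarrow \mu b + g;\quad g \leftarrow b \ \text{(or } g + \mu b \text{ with Nesterov)};\quad w \leftarrow w - g
\text{popular: } b \leftarrow \mu b + g;\quad g \leftarrow b \ \text{(or } g + \mu b\text{)};\quad w \leftarrow w - \alpha r g

With η = 0 the ratio is pinned to 1 and the class reduces to plain SGD with weight decay and momentum, which is exactly what the new SGD() constructor returns (classic=False by default, i.e. the usual PyTorch-style ordering where the learning rate is applied after the momentum update).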
@@ -53,7 +75,7 @@ class LAMB(Optimizer):
     self.m = [Tensor.zeros(*t.shape, device=t.device, requires_grad=False) for t in self.params]
     self.v = [Tensor.zeros(*t.shape, device=t.device, requires_grad=False) for t in self.params]

-  def step(self) -> None:
+  def _step(self) -> List[Tensor]:
     self.t.assign(self.t + 1)
     for i, t in enumerate(self.params):
       assert t.grad is not None
@@ -69,4 +91,4 @@ class LAMB(Optimizer):
       else:
         r = 1.0
       t.assign(t.detach() - self.lr * r * up)
-    self.realize([self.t] + self.m + self.v)
+    return [self.t] + self.m + self.v
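The base-class change above is the contract every optimizer now follows: _step() performs the assigns and returns whatever extra state tensors need realizing, while Optimizer.step() realizes that list together with params and buffers in one corealize, and OptimizerGroup._step() simply concatenates the lists. A minimal sketch of a custom optimizer against that contract; PlainGD is invented for illustration and is not part of the commit:

from typing import List
from tinygrad import Tensor
from tinygrad.nn.optim import Optimizer

class PlainGD(Optimizer):  # hypothetical example, not in tinygrad
  def _step(self) -> List[Tensor]:
    for t in self.params:
      assert t.grad is not None
      t.assign(t.detach() - self.lr * t.grad)
    return []  # no momentum/Adam state of its own; step() still realizes params and buffers

w = Tensor.rand(3, requires_grad=True)
with Tensor.train():
  (w * w).sum().backward()
  PlainGD([w], 0.001).step()  # one corealize covers the assigned update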