add lars to nn (#3750)

* feat: add lars

* feat: don't remove this comment

* clean: smaller diff

* clean: shorter line

* feat: remove mlperf lars, switch resnet

* fix: fully remove mlperf lars

* clean: comment

* feat: contiguous

* feat: no weight decay on skip params

* feat: optimizergroup

* feat: classic momentum

* fix: pylint

* clean: move comment

* fix: correct algo

* feat: lrschedulergroup

* feat: skip list tests

* feat: :| forgot that params are a thing

* feat: remove skip_list params from main params

* feat: set moment

---------

Co-authored-by: chenyu <chenyu@fastmail.com>
wozeparrot 2024-03-24 11:43:12 -04:00 committed by GitHub
parent 8c8b57fd5f
commit 9a9cac58f9
6 changed files with 87 additions and 70 deletions
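
For orientation before the diffs: the commit moves LARS into tinygrad.nn.optim, reimplements SGD as LARS with the trust coefficient pinned to 0.0, and adds OptimizerGroup/LRSchedulerGroup for driving several optimizers and schedulers together. Below is a minimal usage sketch of the new optimizer, not code from the commit; the Linear model, data and hyperparameters are illustrative placeholders.

# minimal sketch of the new API; model, data and hyperparameters are placeholders,
# only the LARS import and its keyword names come from this commit
from tinygrad import Tensor
from tinygrad.nn import Linear
from tinygrad.nn.state import get_parameters
from tinygrad.nn.optim import LARS

model = Linear(16, 4)
opt = LARS(get_parameters(model), lr=0.01, momentum=0.9, weight_decay=1e-4)  # trust-ratio SGD
# SGD(...) is now a thin wrapper that returns LARS(..., tcoef=0.0)

x, y = Tensor.rand(8, 16), Tensor([0, 1, 2, 3, 0, 1, 2, 3])
with Tensor.train():
  opt.zero_grad()
  loss = model(x).sparse_categorical_crossentropy(y)
  loss.backward()
  opt.step()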


@@ -6,7 +6,9 @@ from tqdm import tqdm
from tinygrad import Device, GlobalCounters, Tensor, TinyJit, dtypes
from tinygrad.helpers import getenv, BEAM, WINO
from tinygrad.nn.state import get_parameters, get_state_dict, safe_load, safe_save
from tinygrad.nn.optim import LARS, SGD, OptimizerGroup
from extra.lr_scheduler import LRSchedulerGroup
from examples.mlperf.helpers import get_training_state, load_training_state
def train_resnet():
@@ -60,20 +62,26 @@ def train_resnet():
config["SYNCBN"] = getenv("SYNCBN")
# ** Optimizer **
from examples.mlperf.optimizers import LARS
skip_list = {v for k, v in get_state_dict(model).items() if "bn" in k or "bias" in k or "downsample.1" in k}
optimizer = LARS(parameters, base_lr, momentum=.9, weight_decay=decay, skip_list=skip_list)
skip_list = [v for k, v in get_state_dict(model).items() if "bn" in k or "bias" in k or "downsample.1" in k]
parameters = [x for x in parameters if x not in set(skip_list)]
optimizer = LARS(parameters, base_lr, momentum=.9, weight_decay=decay)
optimizer_skip = SGD(skip_list, base_lr, momentum=.9, weight_decay=0.0, classic=True)
optimizer_group = OptimizerGroup(optimizer, optimizer_skip)
# ** LR scheduler **
scheduler = PolynomialDecayWithWarmup(optimizer, initial_lr=base_lr, end_lr=1e-4,
train_steps=epochs * steps_in_train_epoch,
warmup=lr_warmup_epochs * steps_in_train_epoch)
scheduler_skip = PolynomialDecayWithWarmup(optimizer_skip, initial_lr=base_lr, end_lr=1e-4,
train_steps=epochs * steps_in_train_epoch,
warmup=lr_warmup_epochs * steps_in_train_epoch)
scheduler_group = LRSchedulerGroup(scheduler, scheduler_skip)
print(f"training with batch size {BS} for {epochs} epochs")
# ** resume from checkpointing **
start_epoch = 0
if ckpt:=getenv("RESUME", ""):
load_training_state(model, optimizer, scheduler, safe_load(ckpt))
load_training_state(model, optimizer_group, scheduler_group, safe_load(ckpt))
start_epoch = int(scheduler.epoch_counter.numpy().item() / steps_in_train_epoch)
print(f"resuming from {ckpt} at epoch {start_epoch}")
@@ -93,14 +101,14 @@ def train_resnet():
def normalize(x): return x.permute([0, 3, 1, 2]) - input_mean
@TinyJit
def train_step(X, Y):
optimizer.zero_grad()
optimizer_group.zero_grad()
X = normalize(X)
out = model.forward(X)
loss = out.sparse_categorical_crossentropy(Y, label_smoothing=0.1)
top_1 = (out.argmax(-1) == Y).sum()
loss.backward()
optimizer.step()
scheduler.step()
optimizer_group.step()
scheduler_group.step()
return loss.realize(), top_1.realize()
@TinyJit
def eval_step(X, Y):
@@ -214,7 +222,7 @@ def train_resnet():
else:
fn = f"./ckpts/{time.strftime('%Y%m%d_%H%M%S')}_e{e}.safe"
print(f"saving ckpt to {fn}")
safe_save(get_training_state(model, optimizer, scheduler), fn)
safe_save(get_training_state(model, optimizer_group, scheduler_group), fn)
def train_retinanet():
# TODO: Retinanet
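
The train_resnet change above follows the usual LARS recipe of exempting BatchNorm parameters and biases from both weight decay and the trust ratio: those tensors are split out into a plain SGD group and everything is driven through OptimizerGroup and LRSchedulerGroup. A condensed sketch of that pattern; the build_optimizer helper and its arguments are hypothetical, while the classes and the skip-list filter come from the diff.

# condensed sketch of the parameter split used above; build_optimizer is a
# hypothetical helper, base_lr/decay/train_steps/warmup_steps are placeholders
from tinygrad.nn.state import get_parameters, get_state_dict
from tinygrad.nn.optim import LARS, SGD, OptimizerGroup
from extra.lr_scheduler import LRSchedulerGroup
from examples.mlperf.lr_schedulers import PolynomialDecayWithWarmup

def build_optimizer(model, base_lr, decay, train_steps, warmup_steps):
  parameters = get_parameters(model)
  # BN tensors, biases and the downsample BN get no weight decay and no trust ratio
  skip_list = [v for k, v in get_state_dict(model).items() if "bn" in k or "bias" in k or "downsample.1" in k]
  parameters = [x for x in parameters if x not in set(skip_list)]
  optimizer_group = OptimizerGroup(LARS(parameters, base_lr, momentum=.9, weight_decay=decay),
                                   SGD(skip_list, base_lr, momentum=.9, weight_decay=0.0, classic=True))
  scheduler_group = LRSchedulerGroup(*[PolynomialDecayWithWarmup(o, initial_lr=base_lr, end_lr=1e-4,
                                                                 train_steps=train_steps, warmup=warmup_steps)
                                       for o in optimizer_group.optimizers])
  return optimizer_group, scheduler_group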


@@ -1,38 +0,0 @@
from typing import List, Set
from tinygrad import Tensor
from tinygrad.nn.optim import Optimizer
# https://github.com/mlcommons/training/blob/master/image_classification/tensorflow2/lars_optimizer.py
class LARS(Optimizer):
def __init__(self, params: List[Tensor], lr, momentum=0.9, weight_decay=1e-4, eta=0.001, eps=0.0, skip_list=None, nesterov=False):
super().__init__(params, lr)
assert momentum >= 0.0 and weight_decay >= 0.0
self.momentum, self.weight_decay, self.eta, self.eps, self.nesterov = momentum, weight_decay, eta, eps, nesterov
self.b = [Tensor.zeros(*t.shape, device=t.device, requires_grad=False) for t in self.params]
self.skip_list = set(skip_list or [])
def step(self):
for i, t in enumerate(self.params):
assert t.grad is not None
g = t.grad.contiguous()
w = t.detach()
if t not in self.skip_list:
g_norm = (g * g).sum().sqrt()
w_norm = (w * w).sum().sqrt()
trust_ratio = ((w_norm > 0) * (g_norm > 0)).where(
self.eta * w_norm / (g_norm + self.weight_decay * w_norm + self.eps),
1.0)
scaled_lr = self.lr * trust_ratio
g = g + self.weight_decay * t.detach()
else:
scaled_lr = self.lr
g = g * scaled_lr
if self.momentum:
self.b[i].assign(self.momentum * self.b[i] + g)
g = (g + self.momentum * self.b[i]) if self.nesterov else self.b[i]
t.assign(t.detach() - g)
self.realize(self.b)


@@ -14,6 +14,11 @@ class LR_Scheduler:
self.epoch_counter.assign(self.epoch_counter + 1).realize()
self.optimizer.lr.assign(self.get_lr()).realize()
class LRSchedulerGroup:
def __init__(self, *schedulers: LR_Scheduler): self.schedulers = schedulers
def step(self) -> None:
for s in self.schedulers: s.step()
class MultiStepLR(LR_Scheduler):
def __init__(self, optimizer: Optimizer, milestones: List[int], gamma=0.1):
super().__init__(optimizer)
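
LRSchedulerGroup is the scheduler-side counterpart of OptimizerGroup: step() simply fans out to every wrapped scheduler so grouped optimizers keep their learning rates in lockstep. A tiny usage sketch; the tensors, learning rate and milestones are placeholders.

# tiny sketch; tensors, lr and milestones are illustrative only
from tinygrad import Tensor
from tinygrad.nn.optim import SGD, OptimizerGroup
from extra.lr_scheduler import MultiStepLR, LRSchedulerGroup

w_main = Tensor.zeros(4, requires_grad=True)
w_skip = Tensor.zeros(4, requires_grad=True)
opt = OptimizerGroup(SGD([w_main], lr=0.1), SGD([w_skip], lr=0.1))
sched = LRSchedulerGroup(*[MultiStepLR(o, milestones=[2, 4], gamma=0.1) for o in opt.optimizers])
# one call per training step advances every scheduler and updates each optimizer's lr
sched.step()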


@@ -4,11 +4,11 @@ import numpy as np
import tensorflow as tf
import tensorflow_addons as tfa
from tensorflow.python.ops import math_ops
from extra.lr_scheduler import LRSchedulerGroup
from tinygrad.tensor import Tensor
from tinygrad.nn.optim import LAMB
from tinygrad.nn.optim import LAMB, LARS, SGD, OptimizerGroup
from examples.mlperf.optimizers import LARS
from test.external.mlperf_resnet.lars_optimizer import LARSOptimizer
from examples.mlperf.lr_schedulers import PolynomialDecayWithWarmup
@@ -55,7 +55,7 @@ def step(optim, steps=1, kwargs={}, scheduler=None, schedopts=None, do_optim=Tru
out = net.forward()
optim.zero_grad()
out.backward()
lrs.append(optim.lr.numpy().item())
lrs.append(optim.lr.item() if not isinstance(optim, OptimizerGroup) else optim.optimizers[0].lr.item())
if do_optim: optim.step()
if scheduler is not None: scheduler.step()
return lrs, net.x.detach().numpy(), net.W.detach().numpy()
@@ -83,11 +83,19 @@ def step_tf(optim, steps=1, kwargs={}, scheduler=None, schedopts=None, do_optim=
optim._iterations.assign_add(1)
return lrs, net.x.numpy(), net.W.numpy()
# skip_list=True -> skip W
def create_tiny_lars(params, lr, skip_list=False): return LARS(params, lr, skip_list=[params[1]] if skip_list else None)
# skip list is skipping W
def create_tiny_lars(params, lr, skip_list=False):
if skip_list: return OptimizerGroup(LARS([params[0]], lr), SGD([params[1]], lr, classic=True, weight_decay=0., momentum=.9))
return LARS(params, lr)
def create_tf_lars(lr, skip_list=False): return LARSOptimizer(lr, skip_list=["W"] if skip_list else None)
def create_tf_polylr(initial_lr, end_lr, train_steps, warmup, power=2):
def create_tiny_polylr(optim, initial_lr, end_lr, train_steps, warmup, power=2, skip_list=False):
assert power == 2
if skip_list: return LRSchedulerGroup(
PolynomialDecayWithWarmup(optim[0], initial_lr, end_lr, train_steps, warmup, power),
PolynomialDecayWithWarmup(optim[1], initial_lr, end_lr, train_steps, warmup, power))
return PolynomialDecayWithWarmup(optim, initial_lr, end_lr, train_steps, warmup, power)
def create_tf_polylr(initial_lr, end_lr, train_steps, warmup, power=2, skip_list=False):
assert power == 2
return PolynomialDecayWithWarmup_tf(1, 1, train_steps,
initial_learning_rate=initial_lr, end_learning_rate=end_lr, warmup_epochs=warmup)
@@ -102,7 +110,7 @@ class ExternalTestOptim(unittest.TestCase):
def _test_lars(self, steps, opts, atol, rtol): self._test_optim(create_tiny_lars, create_tf_lars, steps, opts, atol, rtol)
def _test_lars_polylr(self, steps, opts, schedopts, atol, rtol, do_optim=True):
self._test_optim(create_tiny_lars, create_tf_lars, steps, opts, atol, rtol,
tiny_sched=PolynomialDecayWithWarmup, tf_sched=create_tf_polylr, schedopts=schedopts, do_optim=do_optim)
tiny_sched=create_tiny_polylr, tf_sched=create_tf_polylr, schedopts=schedopts, do_optim=do_optim)
def test_lamb(self): self._test_lamb(1, {'lr': 0.001}, 1e-5, 0)
def test_lamb_high_lr(self): self._test_lamb(1, {'lr': 10}, 1e-5, 1e-5)
@@ -112,9 +120,12 @@ class ExternalTestOptim(unittest.TestCase):
def test_lars(self): self._test_lars(1, {'lr': 0.01}, 1e-5, 0)
def test_lars_high_lr(self): self._test_lars(1, {'lr': 10}, 1e-5, 1e-5)
def test_multistep_lars(self): self._test_lamb(10, {'lr': 0.001}, 1e-5, 0)
def test_multistep_lars_high_lr(self): self._test_lamb(10, {'lr': 10}, 1e-5, 3e-4)
def test_lars_skip_list(self): self._test_lars(1, {'lr': 0.01, 'skip_list': True}, 1e-5, 0)
def test_multistep_lars(self): self._test_lars(10, {'lr': 0.001}, 1e-5, 0)
def test_multistep_lars_high_lr(self): self._test_lars(10, {'lr': 10}, 1e-5, 3e-4)
def test_lars_skip(self): self._test_lars(10, {'lr': 10, 'skip_list': True}, 1e-5, 3e-4)
def test_lars_skip_high_lr(self): self._test_lars(1, {'lr': 10, 'skip_list': True}, 1e-5, 1e-5)
def test_lars_skip_multistep(self): self._test_lars(10, {'lr': 0.001, 'skip_list': True}, 1e-5, 0)
def test_lars_skip_multistep_high_lr(self): self._test_lars(10, {'lr': 10, 'skip_list': True}, 1e-5, 3e-4)
def test_lars_polylr(self):
self._test_lars_polylr(10, {'lr': 1.0}, {
@@ -130,6 +141,15 @@ class ExternalTestOptim(unittest.TestCase):
'train_steps': 100,
'warmup': 43
}, 1e-5, 1e-5, do_optim=False)
def test_lars_polylr_skip(self):
self._test_lars_polylr(10, {'lr': 1.0, 'skip_list': True}, {
'initial_lr': 1.0,
'end_lr': 1e-4,
'train_steps': 10,
'warmup': 3,
'skip_list': True
}, 1e-5, 1e-5)
@unittest.skip("slow, but you can run this locally to check")
def test_lars_polylr_resnet(self):
train_files = 1_281_167


@@ -237,7 +237,7 @@ class TestMultiTensor(unittest.TestCase):
import sys, pathlib
sys.path.append((pathlib.Path(__file__).parent.parent / "extra" / "models").as_posix())
from resnet import ResNet18
from examples.mlperf.optimizers import LARS
from tinygrad.nn.optim import LARS
fake_image = Tensor.rand((2, 3, 224, 224))
fake_image_sharded = fake_image.shard((d0, d1), axis=0)


@@ -1,6 +1,6 @@
# sorted in order of increasing complexity
from typing import List
from tinygrad.helpers import dedup, getenv
from typing import List, Optional
from tinygrad.helpers import dedup, flatten, getenv
from tinygrad.tensor import Tensor
class Optimizer:
@@ -21,26 +21,48 @@ class Optimizer:
def realize(self, extra=None):
Tensor.corealize(extra + self.params + self.buffers if extra is not None else self.params + self.buffers)
def step(self) -> None: raise NotImplementedError
def step(self, extra:Optional[List[Tensor]]=None): self.realize(self._step() + (extra if extra is not None else []))
def _step(self) -> List[Tensor]: raise NotImplementedError
class SGD(Optimizer):
def __init__(self, params: List[Tensor], lr=0.001, momentum=0, weight_decay=0.0, nesterov=False):
class OptimizerGroup(Optimizer):
def __init__(self, *optimizers: Optimizer): # pylint: disable=super-init-not-called
self.optimizers = optimizers
self.params, self.buffers = flatten([o.params for o in self.optimizers]), flatten([o.buffers for o in self.optimizers])
def __getitem__(self, i): return self.optimizers[i]
def zero_grad(self): [o.zero_grad() for o in self.optimizers]
def _step(self) -> List[Tensor]: return [x for o in self.optimizers for x in o._step()]
# LARS is essentially just a trust ratio applied to SGD, so if we set the trust coeff to 0.0 it's just standard SGD.
def SGD(params: List[Tensor], lr=0.001, momentum=0.0, weight_decay=0.0, nesterov=False, classic=False):
return LARS(params, lr, momentum, weight_decay, nesterov, classic, tcoef=0.0)
class LARS(Optimizer):
def __init__(self, params:List[Tensor], lr=0.001, momentum=0.9, weight_decay=1e-4, nesterov=False, classic=True, tcoef=0.001):
super().__init__(params, lr)
self.momentum, self.wd, self.nesterov = momentum, weight_decay, nesterov
self.momentum, self.wd, self.nesterov, self.classic, self.tcoef = momentum, weight_decay, nesterov, classic, tcoef
self.b = [Tensor.zeros(*t.shape, device=t.device, requires_grad=False) for t in self.params] if self.momentum else []
# https://pytorch.org/docs/stable/generated/torch.optim.SGD.html
def step(self) -> None:
def _step(self) -> List[Tensor]:
for i, t in enumerate(self.params):
assert t.grad is not None
# contiguous is needed since the grads can allegedly form a "diamond"
# TODO: fix this in lazy.py
g = t.grad.contiguous() + self.wd * t.detach()
g = t.grad.contiguous()
if self.tcoef != 0:
r1 = t.detach().square().sum().sqrt()
r2 = g.square().sum().sqrt()
r = (r1 > 0).where((r2 > 0).where(self.tcoef * r1 / (r2 + self.wd * r1), 1.0), 1.0)
else: r = 1.0
g = g + self.wd * t.detach()
# classic momentum does post learning rate update
if self.classic: g = g * r * self.lr
if self.momentum:
self.b[i].assign(self.momentum * self.b[i] + g) # NOTE: self.b[i] is zero on the first run, no if required
g = (g + self.momentum * self.b[i]) if self.nesterov else self.b[i]
t.assign(t.detach() - g * self.lr)
self.realize(self.b)
# popular momentum does pre learning rate update
if not self.classic: g = g * r * self.lr
t.assign(t.detach() - g)
return self.b
# LAMB is essentially just the trust ratio part of LARS applied to Adam/W, so if we set the trust ratio to 1.0 it's just Adam/W.
def AdamW(params: List[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-8, wd=0.01): return LAMB(params, lr, b1, b2, eps, wd, adam=True)
@@ -53,7 +75,7 @@ class LAMB(Optimizer):
self.m = [Tensor.zeros(*t.shape, device=t.device, requires_grad=False) for t in self.params]
self.v = [Tensor.zeros(*t.shape, device=t.device, requires_grad=False) for t in self.params]
def step(self) -> None:
def _step(self) -> List[Tensor]:
self.t.assign(self.t + 1)
for i, t in enumerate(self.params):
assert t.grad is not None
@@ -69,4 +91,4 @@ class LAMB(Optimizer):
else:
r = 1.0
t.assign(t.detach() - self.lr * r * up)
self.realize([self.t] + self.m + self.v)
return [self.t] + self.m + self.v
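
Spelled out, for a parameter w with gradient g, weight decay λ (wd), momentum μ, learning rate α (lr), trust coefficient η (tcoef) and momentum buffer b (initialised to zero), the update performed by LARS._step above is, ignoring Nesterov (note the trust ratio is computed from the raw gradient, before weight decay is added):

r = \begin{cases} \eta\,\lVert w\rVert / (\lVert g\rVert + \lambda\,\lVert w\rVert) & \text{if } \lVert w\rVert > 0 \text{ and } \lVert g\rVert > 0 \\ 1 & \text{otherwise} \end{cases}
g \leftarrow g + \lambda w
\text{classic=True:}\quad b \leftarrow \mu b + \alpha r g, \qquad w \leftarrow w - b
\text{classic=False:}\quad b \leftarrow \mu b + g, \qquad w \leftarrow w - \alpha r b

With tcoef=0 the trust ratio is pinned to r=1 and the classic=False path reduces to the standard SGD-with-momentum update, which is exactly what the SGD wrapper above returns.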