mirror of https://github.com/commaai/tinygrad.git
use at least float32 for optim.lr (#4297)
* use at least float32 for optim.lr

  when doing mixed precision training (float32 weights, default_float=half), still use float32 to store lr. it would have been upcast later in the actual weight update anyway, but by then the precision would already have been lost. this improved resnet convergence significantly.

* undo type annotation
This commit is contained in:
parent 6f792b727b
commit 5ae252ae83
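As a rough illustration of the precision problem described above (a sketch, not taken from the commit): float16 carries only about 3 significant decimal digits and overflows above ~65504, so an lr held in half both rounds off typical scheduled values and turns the lr=1e10 used by the new test into inf.

import numpy as np

lr = 2.4e-4                              # an illustrative scheduled learning rate
print(np.float16(lr) - np.float32(lr))   # nonzero: rounding error from the half representation
print(np.float16(1e10))                  # inf: exceeds half's max (~65504)
print(np.float32(1e10))                  # 1e+10: representable in float32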
@@ -66,7 +66,8 @@ class CosineAnnealingLR(LR_Scheduler):
     self.eta_max = optimizer.lr.numpy()[0]
 
   def get_lr(self) -> Tensor:
-    return Tensor([self.eta_min + 0.5 * (self.eta_max - self.eta_min) * (1 + math.cos((self.epoch_counter.numpy()[0]/self.T_max) * math.pi))], device=self.optimizer.device)
+    lr = self.eta_min + 0.5 * (self.eta_max - self.eta_min) * (1 + math.cos((self.epoch_counter.numpy()[0]/self.T_max) * math.pi))
+    return Tensor([lr], device=self.optimizer.device, dtype=self.optimizer.lr.dtype)
 
 class OneCycleLR(LR_Scheduler):
   def __init__(self, optimizer: Optimizer, max_lr: float, div_factor: float, final_div_factor: float, total_steps: int, pct_start: float,
@@ -88,4 +89,4 @@ class OneCycleLR(LR_Scheduler):
     return (self.epoch_counter < self.total_steps*self.pct_start).where(
       self._annealing_linear(self.initial_lr, self.max_lr, self.epoch_counter/(self.total_steps*self.pct_start)),
       self._annealing_linear(self.max_lr, self.min_lr, (self.epoch_counter-(self.total_steps*self.pct_start))/(self.total_steps*(1-self.pct_start)))
-    )
+    ).cast(self.optimizer.lr.dtype)
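A small usage sketch of what the scheduler change buys (setup and import path assumed, not from the commit): get_lr() now returns a tensor in the optimizer's lr dtype, so stepping the scheduler under default_float=half no longer hands a half tensor to a float32 lr.

from tinygrad import Tensor
from tinygrad.nn.optim import SGD
from extra.lr_scheduler import CosineAnnealingLR  # assumed location of the classes in this hunk

w = Tensor.randn(4, 4, requires_grad=True)
opt = SGD([w], lr=0.01)                    # opt.lr is a Tensor unless CONST_LR is set
sched = CosineAnnealingLR(opt, T_max=100)
print(sched.get_lr().dtype, opt.lr.dtype)  # expected to match after this change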
@@ -1,9 +1,10 @@
 import numpy as np
 import torch
 import unittest
-from tinygrad import Tensor, Device
+from tinygrad import Tensor, Device, dtypes
 from tinygrad.nn.optim import Adam, SGD, AdamW
 from tinygrad.helpers import CI
+from test.helpers import is_dtype_supported
 
 np.random.seed(1337)
 x_init = np.random.randn(1,4).astype(np.float32)
@@ -105,5 +106,14 @@ class TestOptim(unittest.TestCase):
 
     np.testing.assert_allclose(losses[0], losses[1], atol=1e-4, rtol=0)
 
+  @unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
+  def test_mixed_precision(self):
+    old_default_float, dtypes.default_float = dtypes.default_float, dtypes.half
+    # weight update would overflow without upcasting
+    self._test_sgd(10, {'lr': 1e10}, 1e-6, 3e-4)
+    self._test_adam(1, {'lr': 1e10}, 1e-4, 1e-4)
+    self._test_adamw(1, {'lr': 1e10}, 1e-4, 1e-4)
+    dtypes.default_float = old_default_float
+
 if __name__ == '__main__':
   unittest.main()
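Outside the test harness, the scenario test_mixed_precision covers looks roughly like this (a sketch assuming a backend where half is supported and CONST_LR is unset):

from tinygrad import Tensor, dtypes
from tinygrad.nn.optim import SGD

dtypes.default_float = dtypes.half       # new tensors default to half
w = Tensor.randn(4, 4, dtype=dtypes.float32, requires_grad=True)  # float32 master weight
opt = SGD([w], lr=1e-5)
print(opt.lr.dtype)                      # at least float32 after this change; previously it followed default_float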
@@ -2,6 +2,7 @@
 from typing import List
 from tinygrad.helpers import dedup, flatten, getenv
 from tinygrad.tensor import Tensor
+from tinygrad.dtype import dtypes, least_upper_dtype
 
 class Optimizer:
   def __init__(self, params: List[Tensor], lr: float):
@@ -13,7 +14,9 @@ class Optimizer:
     assert len(self.params) != 0, "optimizer must have at least one param"
     self.device = self.params[0].device
     self.buffers: List[Tensor] = dedup([x for x in params if not x.requires_grad]) # buffers are still realized
-    self.lr = lr if getenv("CONST_LR") else Tensor([lr], requires_grad=False, device=self.device).contiguous()
+    # store lr in at least float32 precision
+    self.lr = Tensor(lr if getenv("CONST_LR") else [lr], requires_grad=False, device=self.device,
+                     dtype=least_upper_dtype(dtypes.default_float, dtypes.float32))
 
   def zero_grad(self):
     for param in self.params: param.grad = None
@@ -59,7 +62,7 @@ class LARS(Optimizer):
       g = (g + self.momentum * self.b[i]) if self.nesterov else self.b[i]
       # popular momentum does pre learning rate update
      if not self.classic: g = g * r * self.lr
-      t.assign(t.detach() - g)
+      t.assign((t.detach() - g).cast(t.dtype))
     return self.b
 
 # LAMB is essentially just the trust ratio part of LARS applied to Adam/W so if we just set the trust ratio to 1.0 its just Adam/W.
@@ -89,5 +92,5 @@ class LAMB(Optimizer):
         r = Tensor.where(r1 > 0, Tensor.where(r2 > 0, r1 / r2, 1.0), 1.0)
       else:
         r = 1.0
-      t.assign(t.detach() - self.lr * r * up)
+      t.assign((t.detach() - self.lr * r * up).cast(t.dtype))
     return [self.t] + self.m + self.v
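The helper doing the work here is least_upper_dtype, which returns the promoted dtype of its arguments; a quick illustrative sketch (not from the commit) of how the lr dtype is chosen:

from tinygrad.dtype import dtypes, least_upper_dtype

print(least_upper_dtype(dtypes.half, dtypes.float32))     # float32: a half default still stores lr in float32
print(least_upper_dtype(dtypes.float64, dtypes.float32))  # float64: "at least" float32, not "exactly" float32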