use at least float32 for optim.lr (#4297)

* use at least float32 for optim.lr

when doing mixed precision training (float32 weights, default_float=half), still store lr in float32.
the lr would have been upcast later in the actual weight update anyway, but by then precision would already have been lost.
this improved resnet convergence significantly

* undo type annotation
chenyu 2024-04-25 14:42:28 -04:00 committed by GitHub
parent 6f792b727b
commit 5ae252ae83
3 changed files with 20 additions and 6 deletions
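
To make the precision point concrete: float16 carries only about three significant decimal digits, so a scheduled learning rate stored in half is already rounded before the weight update ever upcasts it. A minimal numpy sketch of that effect (illustrative only, not part of this commit):

import numpy as np

lr = 2.8734e-4              # e.g. a cosine-annealed learning rate
lr_half = np.float16(lr)    # what default_float=half would have stored
lr_fp32 = np.float32(lr)    # what the optimizer stores after this change
print(float(lr_half), float(lr_fp32))   # half is rounded to roughly 2.873e-4
print(abs(float(lr_half) - lr) / lr)    # relative error on the order of 1e-4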


@@ -66,7 +66,8 @@ class CosineAnnealingLR(LR_Scheduler):
     self.eta_max = optimizer.lr.numpy()[0]
   def get_lr(self) -> Tensor:
-    return Tensor([self.eta_min + 0.5 * (self.eta_max - self.eta_min) * (1 + math.cos((self.epoch_counter.numpy()[0]/self.T_max) * math.pi))], device=self.optimizer.device)
+    lr = self.eta_min + 0.5 * (self.eta_max - self.eta_min) * (1 + math.cos((self.epoch_counter.numpy()[0]/self.T_max) * math.pi))
+    return Tensor([lr], device=self.optimizer.device, dtype=self.optimizer.lr.dtype)
 class OneCycleLR(LR_Scheduler):
   def __init__(self, optimizer: Optimizer, max_lr: float, div_factor: float, final_div_factor: float, total_steps: int, pct_start: float,
@@ -88,4 +89,4 @@ class OneCycleLR(LR_Scheduler):
     return (self.epoch_counter < self.total_steps*self.pct_start).where(
       self._annealing_linear(self.initial_lr, self.max_lr, self.epoch_counter/(self.total_steps*self.pct_start)),
       self._annealing_linear(self.max_lr, self.min_lr, (self.epoch_counter-(self.total_steps*self.pct_start))/(self.total_steps*(1-self.pct_start)))
-    )
+    ).cast(self.optimizer.lr.dtype)

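Without the explicit dtype, the Tensor the scheduler builds would pick up dtypes.default_float and silently drop back to half in a mixed-precision run. A minimal sketch of the difference (illustrative, not from the diff), using only names that appear above:

from tinygrad import Tensor, dtypes

old_default_float, dtypes.default_float = dtypes.default_float, dtypes.half
a = Tensor([2.8734e-4])                        # follows default_float, so half
b = Tensor([2.8734e-4], dtype=dtypes.float32)  # explicit dtype, like the new get_lr
print(a.dtype, b.dtype)                        # expected: half vs float32
dtypes.default_float = old_default_float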

@@ -1,9 +1,10 @@
 import numpy as np
 import torch
 import unittest
-from tinygrad import Tensor, Device
+from tinygrad import Tensor, Device, dtypes
 from tinygrad.nn.optim import Adam, SGD, AdamW
 from tinygrad.helpers import CI
+from test.helpers import is_dtype_supported
 np.random.seed(1337)
 x_init = np.random.randn(1,4).astype(np.float32)
@@ -105,5 +106,14 @@ class TestOptim(unittest.TestCase):
     np.testing.assert_allclose(losses[0], losses[1], atol=1e-4, rtol=0)
+  @unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
+  def test_mixed_precision(self):
+    old_default_float, dtypes.default_float = dtypes.default_float, dtypes.half
+    # weight update would overflow without upcasting
+    self._test_sgd(10, {'lr': 1e10}, 1e-6, 3e-4)
+    self._test_adam(1, {'lr': 1e10}, 1e-4, 1e-4)
+    self._test_adamw(1, {'lr': 1e10}, 1e-4, 1e-4)
+    dtypes.default_float = old_default_float
 if __name__ == '__main__':
   unittest.main()

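The lr=1e10 in the new test is extreme on purpose: float16 tops out around 65504, so a half-precision lr would already be inf before any multiply happens, while a float32 lr survives and only the final result is cast back down. A rough numpy illustration (not from the test itself):

import numpy as np

print(np.float16(1e10))                     # inf: lr itself overflows in half
print(np.float32(1e10) * np.float16(1e-4))  # ~1e6: fine once lr is kept in float32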

@@ -2,6 +2,7 @@
 from typing import List
 from tinygrad.helpers import dedup, flatten, getenv
 from tinygrad.tensor import Tensor
+from tinygrad.dtype import dtypes, least_upper_dtype
 class Optimizer:
   def __init__(self, params: List[Tensor], lr: float):
@@ -13,7 +14,9 @@ class Optimizer:
     assert len(self.params) != 0, "optimizer must have at least one param"
     self.device = self.params[0].device
     self.buffers: List[Tensor] = dedup([x for x in params if not x.requires_grad]) # buffers are still realized
-    self.lr = lr if getenv("CONST_LR") else Tensor([lr], requires_grad=False, device=self.device).contiguous()
+    # store lr in at least float32 precision
+    self.lr = Tensor(lr if getenv("CONST_LR") else [lr], requires_grad=False, device=self.device,
+                     dtype=least_upper_dtype(dtypes.default_float, dtypes.float32))
   def zero_grad(self):
     for param in self.params: param.grad = None
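
least_upper_dtype picks the promoted dtype of its two arguments, so the lr tensor ends up float32 when default_float is half or float32, and float64 if someone trains with default_float=double, which is exactly the "at least float32" in the title. A quick sketch of that expectation (assuming tinygrad's usual promotion rules):

from tinygrad.dtype import dtypes, least_upper_dtype

# lr dtype is the promotion of default_float with float32: never below float32
print(least_upper_dtype(dtypes.half, dtypes.float32))     # expected: float32
print(least_upper_dtype(dtypes.float32, dtypes.float32))  # expected: float32
print(least_upper_dtype(dtypes.float64, dtypes.float32))  # expected: float64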
@@ -59,7 +62,7 @@ class LARS(Optimizer):
       g = (g + self.momentum * self.b[i]) if self.nesterov else self.b[i]
       # popular momentum does pre learning rate update
       if not self.classic: g = g * r * self.lr
-      t.assign(t.detach() - g)
+      t.assign((t.detach() - g).cast(t.dtype))
     return self.b
 # LAMB is essentially just the trust ratio part of LARS applied to Adam/W so if we just set the trust ratio to 1.0 its just Adam/W.
@@ -89,5 +92,5 @@ class LAMB(Optimizer):
         r = Tensor.where(r1 > 0, Tensor.where(r2 > 0, r1 / r2, 1.0), 1.0)
       else:
         r = 1.0
-      t.assign(t.detach() - self.lr * r * up)
+      t.assign((t.detach() - self.lr * r * up).cast(t.dtype))
     return [self.t] + self.m + self.v
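
With self.lr now in float32, the update expressions (t.detach() - g and t.detach() - self.lr * r * up) promote to float32 even when the parameters themselves are half, so the added .cast(t.dtype) brings the result back to each parameter's own dtype before the assign. A small sketch of that promotion (illustrative, not the optimizer code):

from tinygrad import Tensor, dtypes

t = Tensor([1.0, 2.0], dtype=dtypes.half)   # half-precision parameter
lr = Tensor([1e-3], dtype=dtypes.float32)   # lr stored in float32
g = Tensor([0.5, 0.5], dtype=dtypes.half)   # gradient in half

update = t.detach() - lr * g                # promotes to float32
print(update.dtype)                         # expected: float32
print(update.cast(t.dtype).dtype)           # back to half for the assign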