optim.OptimizerGroup in hlb_cifar (#5401)

Commit: 8390feb7b9 (parent 01fbd18209)
Author: George Hotz, 2024-07-11 20:14:36 -07:00, committed by GitHub
1 changed file with 3 additions and 6 deletions

@@ -334,12 +334,9 @@ def train_cifar():
   if not getenv("DISABLE_BACKWARD"):
-    # index 0 for bias and 1 for non-bias
-    optimizer[0].zero_grad()
-    optimizer[1].zero_grad()
+    optimizer.zero_grad()
     loss.backward()
-    optimizer[0].step()
-    optimizer[1].step()
+    optimizer.step()
     lr_scheduler[0].step()
     lr_scheduler[1].step()
   return loss.realize()
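
For context: optim.OptimizerGroup lets the train step drive both optimizers through a single zero_grad()/step() call instead of indexing into a list. The sketch below illustrates that pattern, assuming a tinygrad-style Optimizer base class whose step() drives a _step() hook and which exposes flat params/buffers lists; it is a minimal illustration of the idea, not necessarily tinygrad's exact implementation.

# Minimal sketch of the optimizer-group pattern. Assumes tinygrad's
# Optimizer base class (tinygrad.nn.optim.Optimizer) with `params`,
# `buffers`, and a `_step()` hook; illustration only, not verbatim library code.
from tinygrad import Tensor
from tinygrad.nn.optim import Optimizer

class OptimizerGroup(Optimizer):
  def __init__(self, *optimizers: Optimizer):  # skip super().__init__: members own their state
    self.optimizers = optimizers
    self.params  = [p for o in optimizers for p in o.params]   # flattened view of all params
    self.buffers = [b for o in optimizers for b in o.buffers]  # flattened optimizer state
  def __getitem__(self, i) -> Optimizer: return self.optimizers[i]  # indexed access still works
  def zero_grad(self): [o.zero_grad() for o in self.optimizers]     # fan out to every member
  def _step(self) -> list[Tensor]: return [t for o in self.optimizers for t in o._step()]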
@@ -408,7 +405,7 @@ def train_cifar():
     Y.shard_(GPUS, axis=0)
   with Context(BEAM=getenv("LATEBEAM", BEAM.value), WINO=getenv("LATEWINO", WINO.value)):
-    loss = train_step_jitted(model, [opt_bias, opt_non_bias], [lr_sched_bias, lr_sched_non_bias], X, Y)
+    loss = train_step_jitted(model, optim.OptimizerGroup(opt_bias, opt_non_bias), [lr_sched_bias, lr_sched_non_bias], X, Y)
   et = time.monotonic()
   loss_cpu = loss.numpy()
   # EMA for network weights
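
A usage sketch of the call-site change: two SGD optimizers with different weight decay (the common reason for a bias/non-bias split) grouped into one object. The tiny model, hyperparameters, and parameter-splitting heuristic below are hypothetical stand-ins for what hlb_cifar actually configures.

# Hypothetical end-to-end usage; model, split heuristic, and hyperparameters
# are illustrative, not the values hlb_cifar uses.
from tinygrad import Tensor, nn
from tinygrad.nn import optim
from tinygrad.nn.state import get_parameters

class TinyNet:
  def __init__(self): self.l1, self.l2 = nn.Linear(16, 32), nn.Linear(32, 10)
  def __call__(self, x: Tensor) -> Tensor: return self.l2(self.l1(x).relu())

model = TinyNet()
params = get_parameters(model)
bias_params     = [p for p in params if len(p.shape) == 1]  # 1-D tensors: biases
non_bias_params = [p for p in params if len(p.shape) != 1]  # weight matrices

opt_bias     = optim.SGD(bias_params,     lr=0.01, momentum=0.85, nesterov=True, weight_decay=0.0)
opt_non_bias = optim.SGD(non_bias_params, lr=0.01, momentum=0.85, nesterov=True, weight_decay=5e-4)
optimizer = optim.OptimizerGroup(opt_bias, opt_non_bias)

with Tensor.train():
  loss = model(Tensor.randn(4, 16)).sum()
  optimizer.zero_grad()   # one call covers both members
  loss.backward()
  optimizer.step()        # schedulers can still reach members via optimizer[0], optimizer[1]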