mirror of https://github.com/commaai/tinygrad.git
optim.OptimizerGroup in hlb_cifar (#5401)
This commit is contained in:
parent 01fbd18209
commit 8390feb7b9
@@ -334,12 +334,9 @@ def train_cifar():
     if not getenv("DISABLE_BACKWARD"):
-      # index 0 for bias and 1 for non-bias
-      optimizer[0].zero_grad()
-      optimizer[1].zero_grad()
+      optimizer.zero_grad()
       loss.backward()

-      optimizer[0].step()
-      optimizer[1].step()
+      optimizer.step()
       lr_scheduler[0].step()
       lr_scheduler[1].step()
     return loss.realize()

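Note: the net effect of this hunk is that the jitted train step drives both parameter groups through a single object, so zero_grad() and step() are each called once and fan out to every wrapped optimizer. A minimal, self-contained sketch of that fan-out idea, with made-up names rather than tinygrad's actual OptimizerGroup implementation:

# sketch only: a grouped optimizer forwarding zero_grad()/step() to its members
class GroupSketch:
  def __init__(self, *optimizers): self.optimizers = list(optimizers)
  def zero_grad(self):
    for o in self.optimizers: o.zero_grad()   # clear every member's grads
  def step(self):
    for o in self.optimizers: o.step()        # apply every member's update

class FakeOpt:
  def __init__(self, name): self.name = name
  def zero_grad(self): print(f"{self.name}.zero_grad()")
  def step(self): print(f"{self.name}.step()")

optimizer = GroupSketch(FakeOpt("opt_bias"), FakeOpt("opt_non_bias"))
optimizer.zero_grad()  # fans out to opt_bias and opt_non_bias
optimizer.step()       # fans out to opt_bias and opt_non_bias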
@@ -408,7 +405,7 @@ def train_cifar():
       Y.shard_(GPUS, axis=0)

     with Context(BEAM=getenv("LATEBEAM", BEAM.value), WINO=getenv("LATEWINO", WINO.value)):
-      loss = train_step_jitted(model, [opt_bias, opt_non_bias], [lr_sched_bias, lr_sched_non_bias], X, Y)
+      loss = train_step_jitted(model, optim.OptimizerGroup(opt_bias, opt_non_bias), [lr_sched_bias, lr_sched_non_bias], X, Y)
     et = time.monotonic()
     loss_cpu = loss.numpy()
     # EMA for network weights
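For reference, a hedged sketch of how a call site can build the grouped optimizer that gets passed to train_step_jitted above; the model, the bias/non-bias split, and the hyperparameters here are placeholders for illustration, not the ones hlb_cifar10 uses:

from tinygrad import Tensor, nn
from tinygrad.nn import optim
from tinygrad.nn.state import get_parameters

class TinyModel:
  def __init__(self): self.l1 = nn.Linear(16, 4)
  def __call__(self, x: Tensor) -> Tensor: return self.l1(x).relu()

model = TinyModel()
params = get_parameters(model)
# assumption for this sketch: treat 1-D tensors as bias-like params, everything else as non-bias
bias_params     = [p for p in params if len(p.shape) == 1]
non_bias_params = [p for p in params if len(p.shape) > 1]

opt_bias     = optim.SGD(bias_params, lr=0.01, momentum=0.9, nesterov=True)
opt_non_bias = optim.SGD(non_bias_params, lr=0.01, momentum=0.9, nesterov=True, weight_decay=5e-4)
optimizer = optim.OptimizerGroup(opt_bias, opt_non_bias)  # one object wrapping both groups

with Tensor.train():
  loss = (model(Tensor.rand(8, 16)) ** 2).mean()
  optimizer.zero_grad()   # replaces optimizer[0].zero_grad() / optimizer[1].zero_grad()
  loss.backward()
  optimizer.step()        # replaces optimizer[0].step() / optimizer[1].step()

Only the two optimizers are collapsed into a group here; as the second hunk shows, the lr schedulers are still passed to train_step_jitted as a plain list.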