adam in benchmark_train_efficientnet

This commit is contained in:
George Hotz 2022-07-19 09:33:07 -07:00
parent ef1100fdff
commit acbeaf0ba9
1 changed files with 3 additions and 1 deletions

View File

@ -17,13 +17,15 @@ BS = int(os.getenv("BS", 8))
CNT = int(os.getenv("CNT", 10))
BACKWARD = int(os.getenv("BACKWARD", 0))
TRAINING = int(os.getenv("TRAINING", 1))
ADAM = int(os.getenv("ADAM", 0))
if __name__ == "__main__":
print(f"NUM:{NUM} BS:{BS} CNT:{CNT}")
model = EfficientNet(NUM, classes=1000, has_se=False, track_running_stats=False)
parameters = get_parameters(model)
for p in parameters: p.realize()
optimizer = optim.SGD(parameters, lr=0.001)
if ADAM: optimizer = optim.Adam(parameters, lr=0.001)
else: optimizer = optim.SGD(parameters, lr=0.001)
Tensor.training = TRAINING
Tensor.no_grad = not BACKWARD