log resnet TRAIN_BEAM / EVAL_BEAM (#4181)

Also run the first eval in benchmark mode if either TRAIN_BEAM or EVAL_BEAM is positive.
This commit is contained in:
chenyu 2024-04-15 19:29:08 -04:00 committed by GitHub
parent 9d2273235c
commit d5b67c1ca3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed file with 11 additions and 6 deletions

View File

@@ -28,6 +28,9 @@ def train_resnet():
print(f"Training on {GPUS}")
for x in GPUS: Device[x]
TRAIN_BEAM = getenv("TRAIN_BEAM", BEAM.value)
EVAL_BEAM = getenv("EVAL_BEAM", BEAM.value)
# ** model definition and initializers **
num_classes = 1000
resnet.Conv2d = Conv2dHeNormal
@@ -61,9 +64,11 @@ def train_resnet():
steps_in_val_epoch = config["steps_in_val_epoch"] = (len(get_val_files()) // EVAL_BS)
config["DEFAULT_FLOAT"] = dtypes.default_float.name
config["BEAM"] = BEAM.value
config["WINO"] = WINO.value
config["SYNCBN"] = getenv("SYNCBN")
config["BEAM"] = BEAM.value
config["TRAIN_BEAM"] = TRAIN_BEAM
config["EVAL_BEAM"] = EVAL_BEAM
config["WINO"] = WINO.value
config["SYNCBN"] = getenv("SYNCBN")
# ** Optimizer **
skip_list = [v for k, v in get_state_dict(model).items() if "bn" in k or "bias" in k or "downsample.1" in k]
@@ -105,7 +110,7 @@ def train_resnet():
def normalize(x): return (x.permute([0, 3, 1, 2]) - input_mean).cast(dtypes.default_float)
@TinyJit
def train_step(X, Y):
with Context(BEAM=getenv("TRAIN_BEAM", BEAM.value)):
with Context(BEAM=TRAIN_BEAM):
optimizer_group.zero_grad()
X = normalize(X)
out = model.forward(X)
@@ -119,7 +124,7 @@ def train_resnet():
@TinyJit
def eval_step(X, Y):
with Context(BEAM=getenv("EVAL_BEAM", BEAM.value)):
with Context(BEAM=EVAL_BEAM):
X = normalize(X)
out = model.forward(X)
loss = out.cast(dtypes.float32).sparse_categorical_crossentropy(Y, label_smoothing=0.1)
@@ -177,7 +182,7 @@ def train_resnet():
estimated_total_minutes = int(median_step_time * steps_in_train_epoch * epochs / 60)
print(f"Estimated training time: {estimated_total_minutes // 60}h{estimated_total_minutes % 60}m")
# if we are doing beam search, run the first eval too
if BEAM.value and e == start_epoch: break
if (TRAIN_BEAM or EVAL_BEAM) and e == start_epoch: break
return
# ** eval loop **