diff --git a/examples/mlperf/model_train.py b/examples/mlperf/model_train.py index b3aac3e5..4c661ebd 100644 --- a/examples/mlperf/model_train.py +++ b/examples/mlperf/model_train.py @@ -802,7 +802,7 @@ def train_bert(): if i % eval_step_freq == 0 or (BENCHMARK and i == BENCHMARK): if MLLOGGER and RUNMLPERF: MLLOGGER.start(key=mllog_constants.EVAL_START, value=None, metadata={"epoch_num": 1, "epoch_count": 1, "step_num": i}) - train_step_bert.reset() + if getenv("RESET_STEP", 1): train_step_bert.reset() eval_lm_losses = [] eval_clsf_losses = [] eval_lm_accs = [] @@ -840,7 +840,7 @@ def train_bert(): MLLOGGER.event(key=mllog_constants.INIT_STOP, value=None) return - eval_step_bert.reset() + if getenv("RESET_STEP", 1): eval_step_bert.reset() del eval_data, eval_result avg_lm_loss = sum(eval_lm_losses) / len(eval_lm_losses) avg_clsf_loss = sum(eval_clsf_losses) / len(eval_clsf_losses)