mirror of https://github.com/commaai/tinygrad.git
update bert epoch logging (#6940)
* update bert epoch logging epoch for bert is simply number of examples seen (which is used for RCP check) * update total steps too * more changes
This commit is contained in:
parent
0498e846a5
commit
a78c96273a
|
@ -649,7 +649,7 @@ def train_bert():
|
|||
EVAL_BS = config["EVAL_BS"] = getenv("EVAL_BS", 1 * len(GPUS))
|
||||
max_lr = config["OPT_BASE_LEARNING_RATE"] = getenv("OPT_BASE_LEARNING_RATE", 0.0001 * math.sqrt(BS/66))
|
||||
|
||||
train_steps = config["TRAIN_STEPS"] = getenv("TRAIN_STEPS", 3000000 // BS)
|
||||
train_steps = config["TRAIN_STEPS"] = getenv("TRAIN_STEPS", 3300000 // BS)
|
||||
warmup_steps = config["NUM_WARMUP_STEPS"] = getenv("NUM_WARMUP_STEPS", 1)
|
||||
max_eval_steps = config["MAX_EVAL_STEPS"] = getenv("MAX_EVAL_STEPS", (10000 + EVAL_BS - 1) // EVAL_BS) # EVAL_BS * MAX_EVAL_STEPS >= 10000
|
||||
eval_step_freq = config["EVAL_STEP_FREQ"] = getenv("EVAL_STEP_FREQ", int((math.floor(0.05 * (230.23 * BS + 3000000) / 25000) * 25000) / BS)) # Round down
|
||||
|
@ -749,15 +749,12 @@ def train_bert():
|
|||
if RUNMLPERF:
|
||||
# only load real data with RUNMLPERF
|
||||
i, train_data = start_step, get_data_bert(GPUS, train_it)
|
||||
if MLLOGGER:
|
||||
MLLOGGER.start(key=mllog_constants.EPOCH_START, value=i*BS, metadata={"epoch_num": i*BS})
|
||||
else:
|
||||
i, train_data = start_step, get_fake_data_bert(GPUS, BS)
|
||||
|
||||
epoch_started = False
|
||||
while train_data is not None and i < train_steps and not achieved:
|
||||
if not epoch_started and MLLOGGER and RUNMLPERF:
|
||||
MLLOGGER.start(key=mllog_constants.EPOCH_START, value=i+1, metadata=dict(epoch_num=i+1))
|
||||
epoch_started = True
|
||||
|
||||
Tensor.training = True
|
||||
BEAM.value = TRAIN_BEAM
|
||||
st = time.perf_counter()
|
||||
|
@ -791,7 +788,7 @@ def train_bert():
|
|||
if WANDB:
|
||||
wandb.log({"lr": optimizer_wd.lr.numpy(), "train/loss": loss, "train/step_time": cl - st,
|
||||
"train/python_time": pt - st, "train/data_time": dt - pt, "train/cl_time": cl - dt,
|
||||
"train/GFLOPS": GlobalCounters.global_ops * 1e-9 / (cl - st)})
|
||||
"train/GFLOPS": GlobalCounters.global_ops * 1e-9 / (cl - st), "epoch": (i+1)*BS})
|
||||
|
||||
train_data, next_data = next_data, None
|
||||
i += 1
|
||||
|
@ -806,9 +803,7 @@ def train_bert():
|
|||
# ** eval loop **
|
||||
if i % eval_step_freq == 0 or (BENCHMARK and i == BENCHMARK):
|
||||
if MLLOGGER and RUNMLPERF:
|
||||
epoch_started = False
|
||||
MLLOGGER.event(key=mllog_constants.EPOCH_STOP, value=i+1, metadata=dict(epoch_num=i+1))
|
||||
MLLOGGER.start(key=mllog_constants.EVAL_START, value=None, metadata={"epoch_num": 1, "epoch_count": 1, "step_num": i})
|
||||
MLLOGGER.start(key=mllog_constants.EVAL_START, value=None, metadata={"epoch_num": i*BS, "step_num": i})
|
||||
if getenv("RESET_STEP", 1): train_step_bert.reset()
|
||||
eval_lm_losses = []
|
||||
eval_clsf_losses = []
|
||||
|
@ -863,12 +858,13 @@ def train_bert():
|
|||
"eval/clsf_accuracy": avg_clsf_acc, "eval/forward_time": avg_fw_time})
|
||||
|
||||
if MLLOGGER and RUNMLPERF:
|
||||
MLLOGGER.end(key=mllog_constants.EVAL_STOP, value=i, metadata={"epoch_count": 1, "step_num": i, "samples_count": config["EVAL_BS"] * config["MAX_EVAL_STEPS"]})
|
||||
MLLOGGER.event(key=mllog_constants.EVAL_ACCURACY, value=avg_lm_acc, metadata={"epoch_num": 1, "masked_lm_accuracy": avg_lm_acc})
|
||||
MLLOGGER.end(key=mllog_constants.EVAL_STOP, value=i*BS, metadata={"epoch_count": i*BS, "step_num": i, "samples_count": config["EVAL_BS"] * config["MAX_EVAL_STEPS"]})
|
||||
MLLOGGER.event(key=mllog_constants.EVAL_ACCURACY, value=avg_lm_acc, metadata={"epoch_num": i*BS, "masked_lm_accuracy": avg_lm_acc})
|
||||
|
||||
# save model if achieved target
|
||||
if not achieved and avg_lm_acc >= target:
|
||||
wc_end = time.perf_counter()
|
||||
if getenv("CKPT"):
|
||||
if not os.path.exists(ckpt_dir := save_ckpt_dir): os.mkdir(ckpt_dir)
|
||||
fn = f"{ckpt_dir}/bert-large.safe"
|
||||
safe_save(get_state_dict(model), fn)
|
||||
|
@ -881,6 +877,7 @@ def train_bert():
|
|||
print(f"Reference Convergence point reached after {i * BS} datasamples and {hours}h{minutes}m{seconds:.2f}s.")
|
||||
achieved = True
|
||||
if MLLOGGER and RUNMLPERF:
|
||||
MLLOGGER.event(key=mllog_constants.EPOCH_STOP, value=i*BS, metadata={"epoch_num": i*BS})
|
||||
MLLOGGER.end(key=mllog_constants.RUN_STOP, metadata=dict(status=mllog_constants.SUCCESS))
|
||||
# stop once hitting the target
|
||||
break
|
||||
|
|
|
@ -6,7 +6,8 @@
|
|||
"number_of_nodes": "1",
|
||||
"host_processors_per_node": "1",
|
||||
"host_processor_model_name": "AMD EPYC 7532 32-Core Processor",
|
||||
"host_processor_core_count": "64",
|
||||
"host_processor_core_count": "32",
|
||||
"host_processor_vcpu_count": "64",
|
||||
"host_processor_frequency": "",
|
||||
"host_processor_caches": "",
|
||||
"host_processor_interconnect": "",
|
||||
|
|
|
@ -6,7 +6,8 @@
|
|||
"number_of_nodes": "1",
|
||||
"host_processors_per_node": "1",
|
||||
"host_processor_model_name": "AMD EPYC 7532 32-Core Processor",
|
||||
"host_processor_core_count": "64",
|
||||
"host_processor_core_count": "32",
|
||||
"host_processor_vcpu_count": "64",
|
||||
"host_processor_frequency": "",
|
||||
"host_processor_caches": "",
|
||||
"host_processor_interconnect": "",
|
||||
|
|
Loading…
Reference in New Issue