Add MLLogger (#5125)

* add MLPerf logger

* eval steps

* start with step 1

* compliance for 3.1.0 and 4.0.0

* more compliance

* assert, comment and contiguous
Elias Wahl 2024-06-26 18:23:56 +02:00 committed by GitHub
parent 16405b973a
commit e267f3161d
2 changed files with 109 additions and 16 deletions
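
For orientation before the diff: all new logging is gated behind three environment variables. LOGMLPERF turns on the mllog logger (writing to bert.log); INITMLPERF marks the untimed initialization pass, which now runs on fake data and ends with INIT_STOP; RUNMLPERF marks the timed pass on real data, bracketed by RUN_START/RUN_STOP. A typical two-phase invocation would be something like LOGMLPERF=1 INITMLPERF=1 BENCHMARK=10 for the setup pass, then LOGMLPERF=1 RUNMLPERF=1 for the timed run (values here are illustrative, not prescribed by the commit).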

examples/mlperf/helpers.py

@@ -2,7 +2,7 @@ from collections import OrderedDict
 import unicodedata
 import numpy as np
 from tinygrad.nn import state
-from tinygrad.tensor import Tensor
+from tinygrad.tensor import Tensor, dtypes
 from tinygrad.helpers import getenv
 #
@@ -225,3 +225,14 @@ def get_data_bert(GPUS:list[str], it):
   data: dict[str, Tensor] = next(it)
   for key in data.keys(): data[key].shard_(GPUS, axis=0)
   return data
+
+def get_fake_data_bert(GPUS:list[str], BS:int):
+  return {
+    "input_ids": Tensor.zeros((BS, 512), dtype=dtypes.float32).contiguous().shard_(GPUS, axis=0),
+    "input_mask": Tensor.zeros((BS, 512), dtype=dtypes.default_float).contiguous().shard_(GPUS, axis=0),
+    "segment_ids": Tensor.zeros((BS, 512), dtype=dtypes.float32).contiguous().shard_(GPUS, axis=0),
+    "masked_lm_positions": Tensor.zeros((BS, 512), dtype=dtypes.float32).contiguous().shard_(GPUS, axis=0),
+    "masked_lm_ids": Tensor.zeros((BS, 512), dtype=dtypes.float32).contiguous().shard_(GPUS, axis=0),
+    "masked_lm_weights": Tensor.zeros((BS, 512), dtype=dtypes.float32).contiguous().shard_(GPUS, axis=0),
+    "next_sentence_labels": Tensor.zeros((BS, 1), dtype=dtypes.float32).contiguous().shard_(GPUS, axis=0),
+  }
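
A minimal usage sketch for the new helper (device list and batch size are assumed for illustration, not part of the commit):

  # hypothetical smoke test: fake batches mirror the global shapes of the real dataloader output
  from tinygrad import Device
  from examples.mlperf.helpers import get_fake_data_bert
  GPUS = [f"{Device.DEFAULT}:{i}" for i in range(2)]
  fake = get_fake_data_bert(GPUS, BS=4)
  assert fake["input_ids"].shape == (4, 512)
  assert fake["next_sentence_labels"].shape == (4, 1)

The .contiguous() before shard_ (the "contiguous" of the commit message) presumably forces each zeros tensor to materialize as a real buffer, so sharding splits actual memory instead of a lazy zero-fill.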

examples/mlperf/model_train.py

@@ -386,7 +386,7 @@ def eval_step_bert(model, input_ids:Tensor, segment_ids:Tensor, attention_mask:T
 def train_bert():
   # NOTE: pip install tensorflow, wandb required
   from examples.mlperf.dataloader import batch_load_train_bert, batch_load_val_bert
-  from examples.mlperf.helpers import get_mlperf_bert_model, get_data_bert
+  from examples.mlperf.helpers import get_mlperf_bert_model, get_data_bert, get_fake_data_bert
   from examples.mlperf.lr_schedulers import PolynomialDecayWithWarmup
 
   config = {}
@@ -397,6 +397,35 @@ def train_bert():
   for x in GPUS: Device[x]
   seed = config["seed"] = getenv("SEED", 12345)
 
+  INITMLPERF = getenv("INITMLPERF")
+  RUNMLPERF = getenv("RUNMLPERF")
+  if getenv("LOGMLPERF"):
+    from mlperf_logging import mllog
+    import mlperf_logging.mllog.constants as mllog_constants
+    mllog.config(filename="bert.log")
+    mllog.config(root_dir=Path(__file__).parents[3].as_posix())
+    MLLOGGER = mllog.get_mllogger()
+    MLLOGGER.logger.propagate = False
+
+    if INITMLPERF:
+      assert BENCHMARK, f"BENCHMARK must be set for INITMLPERF"
+      MLLOGGER.event(key=mllog_constants.SUBMISSION_ORG, value="tinycorp")
+      MLLOGGER.event(key=mllog_constants.SUBMISSION_PLATFORM, value=getenv("SUBMISSION_PLATFORM", "tinybox"))
+      MLLOGGER.event(key=mllog_constants.SUBMISSION_DIVISION, value=mllog_constants.CLOSED)
+      MLLOGGER.event(key=mllog_constants.SUBMISSION_STATUS, value=mllog_constants.ONPREM)
+      MLLOGGER.event(key=mllog_constants.SUBMISSION_BENCHMARK, value=mllog_constants.BERT)
+
+      diskcache_clear()
+      MLLOGGER.event(key=mllog_constants.CACHE_CLEAR, value=True)
+      MLLOGGER.start(key=mllog_constants.INIT_START, value=None)
+
+    if RUNMLPERF:
+      MLLOGGER.start(key=mllog_constants.RUN_START, value=None)
+  else:
+    MLLOGGER = None
+
   # ** hyperparameters **
   BS = config["GLOBAL_BATCH_SIZE"] = getenv("BS", 16 * len(GPUS) if dtypes.default_float in (dtypes.float16, dtypes.bfloat16) else 8 * len(GPUS))
   EVAL_BS = config["EVAL_BS"] = getenv("EVAL_BS", 1 * len(GPUS))
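
For reference, a standalone sketch of the mlperf_logging API this block relies on (the filename is hypothetical; the log line in the comment is approximate):

  from mlperf_logging import mllog
  import mlperf_logging.mllog.constants as mllog_constants
  mllog.config(filename="demo.log")
  logger = mllog.get_mllogger()
  logger.event(key=mllog_constants.SUBMISSION_ORG, value="tinycorp")  # point-in-time record
  logger.start(key=mllog_constants.INIT_START, value=None)            # opens an interval
  logger.end(key=mllog_constants.INIT_STOP, value=None)               # closes it
  # each call appends a line like:
  # :::MLLOG {"key": "submission_org", "value": "tinycorp", "event_type": "POINT_IN_TIME", "time_ms": ...}

Setting MLLOGGER.logger.propagate = False keeps those records from being duplicated through the root Python logger.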
@@ -413,6 +442,7 @@ def train_bert():
   loss_scaler = config["LOSS_SCALER"] = getenv("LOSS_SCALER", 2.0**9 if dtypes.default_float == dtypes.float16 else 1.0)
   decay = config["DECAY"] = getenv("DECAY", 0.01)
   epsilon = config["EPSILON"] = getenv("EPSILON", 1e-6)
+  poly_power = config["POLY_POWER"] = getenv("POLY_POWER", 1.0)
 
   target, achieved = getenv("TARGET", 0.72), False
@@ -438,8 +468,8 @@ def train_bert():
   # ** Optimizer **
   parameters_no_wd = [v for k, v in get_state_dict(model).items() if "bias" in k or "LayerNorm" in k]
   parameters = [x for x in parameters if x not in set(parameters_no_wd)]
-  optimizer_wd = LAMB(parameters, lr=max_lr, eps=1e-6, weight_decay=decay, adam=False)
-  optimizer_no_wd = LAMB(parameters_no_wd, lr=max_lr, eps=1e-6, weight_decay=0.0, adam=False)
+  optimizer_wd = LAMB(parameters, lr=max_lr, eps=epsilon, weight_decay=decay, adam=False)
+  optimizer_no_wd = LAMB(parameters_no_wd, lr=max_lr, eps=epsilon, weight_decay=0.0, adam=False)
   optimizer_group = OptimizerGroup(optimizer_wd, optimizer_no_wd)
 
   # ** LR scheduler **
@@ -448,8 +478,32 @@ def train_bert():
   scheduler_group = LRSchedulerGroup(scheduler_wd, scheduler_no_wd)
   print(f"training with batch size {BS} for one epoch with {train_steps} steps")
 
+  # log mlperf hparams
+  if MLLOGGER:
+    if RUNMLPERF:
+      MLLOGGER.event(key=mllog_constants.GLOBAL_BATCH_SIZE, value=config["GLOBAL_BATCH_SIZE"])
+      MLLOGGER.event(key=mllog_constants.MAX_SEQUENCE_LENGTH, value=512)
+      MLLOGGER.event(key="max_predictions_per_seq", value=76)
+      MLLOGGER.event(key=mllog_constants.OPT_NAME, value="LAMB")
+      MLLOGGER.event(key=mllog_constants.OPT_BASE_LR, value=config["OPT_BASE_LEARNING_RATE"])
+      MLLOGGER.event(key=mllog_constants.OPT_LAMB_WEIGHT_DECAY, value=config["DECAY"])
+      MLLOGGER.event(key=mllog_constants.OPT_LAMB_BETA_1, value=optimizer_wd.b1)
+      MLLOGGER.event(key=mllog_constants.OPT_LAMB_BETA_2, value=optimizer_wd.b2)
+      MLLOGGER.event(key=mllog_constants.OPT_LAMB_LR_DECAY_POLY_POWER, value=config["POLY_POWER"])
+      MLLOGGER.event(key=mllog_constants.OPT_LAMB_EPSILON, value=config["EPSILON"])
+      MLLOGGER.event(key=mllog_constants.OPT_LR_WARMUP_STEPS, value=config["NUM_WARMUP_STEPS"])
+      MLLOGGER.event(key=mllog_constants.NUM_WARMUP_STEPS, value=config["NUM_WARMUP_STEPS"])
+      MLLOGGER.event(key="start_warmup_step", value=0)
+      MLLOGGER.event(key="opt_learning_rate_training_steps", value=config["TRAIN_STEPS"])
+      MLLOGGER.event(key=mllog_constants.GRADIENT_ACCUMULATION_STEPS, value=1)
+      MLLOGGER.event(key=mllog_constants.EVAL_SAMPLES, value=config["EVAL_BS"] * config["MAX_EVAL_STEPS"])
+      MLLOGGER.event(key=mllog_constants.TRAIN_SAMPLES, value=config["GLOBAL_BATCH_SIZE"] * config["TRAIN_STEPS"])
+
   # ** resume from checkpointing **
-  start_step = 0
+  start_step = 1
+  previous_step = None
   if ckpt:=getenv("RESUME", ""):
     load_training_state(model, optimizer_group, scheduler_group, safe_load(ckpt))
     start_step = int(scheduler_wd.epoch_counter.numpy().item())
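
The warmup and poly-power values logged above parameterize a polynomial-decay-with-warmup schedule; a sketch of the standard formula (tinygrad's PolynomialDecayWithWarmup may differ in boundary details such as an end-LR floor):

  def poly_lr(step:int, max_lr:float, warmup_steps:int, total_steps:int, power:float=1.0) -> float:
    # linear warmup from 0 to max_lr, then polynomial decay back toward 0 at total_steps
    if step < warmup_steps: return max_lr * step / warmup_steps
    return max_lr * (1.0 - (step - warmup_steps) / (total_steps - warmup_steps)) ** power

With POLY_POWER=1.0 (the default) this reduces to linear decay after warmup. Note also start_step moving from 0 to 1: step numbering, and hence the step_num metadata in the log, now starts at 1, matching the "start with step 1" commit message.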
@@ -464,13 +518,17 @@ def train_bert():
   BENCHMARK = getenv("BENCHMARK")
 
-  eval_it = iter(batch_load_val_bert(EVAL_BS))
-  train_it = iter(tqdm(batch_load_train_bert(BS, start_step), initial=start_step, total=train_steps, disable=BENCHMARK))
+  if not INITMLPERF:
+    eval_it = iter(batch_load_val_bert(EVAL_BS))
+    train_it = iter(tqdm(batch_load_train_bert(BS, start_step), initial=start_step, total=train_steps, disable=BENCHMARK))
 
   step_times = []
   # ** train loop **
   wc_start = time.perf_counter()
-  i, train_data = start_step, get_data_bert(GPUS, train_it)
+  if INITMLPERF:
+    i, train_data = start_step, get_fake_data_bert(GPUS, BS)
+  else:
+    i, train_data = start_step, get_data_bert(GPUS, train_it)
+
   while train_data is not None and i < train_steps and not achieved:
     Tensor.training = True
     BEAM.value = TRAIN_BEAM
@@ -483,7 +541,10 @@ def train_bert():
     pt = time.perf_counter()
 
     try:
-      next_data = get_data_bert(GPUS, train_it)
+      if INITMLPERF:
+        next_data = get_fake_data_bert(GPUS, BS)
+      else:
+        next_data = get_data_bert(GPUS, train_it)
     except StopIteration:
       next_data = None
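
The try/except above is a simple prefetch: the next batch is requested while the current step's work is still in flight, and StopIteration ends the epoch cleanly. Reduced to a skeleton (run_step and loader are placeholders, not names from this file):

  data = next(loader)
  while data is not None:
    run_step(data)                     # queue work for the current batch
    try: data = next(loader)           # overlap: fetch the next batch meanwhile
    except StopIteration: data = None  # loader exhausted -> exit loop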
@@ -513,11 +574,12 @@ def train_bert():
       print(f"Estimated training time: {estimated_total_minutes // 60}h{estimated_total_minutes % 60}m")
       print(f"epoch global_ops: {train_steps * GlobalCounters.global_ops:_}, "
             f"epoch global_mem: {train_steps * GlobalCounters.global_mem:_}")
-      return
 
     # ** eval loop **
-    if i % eval_step_freq == 0 or i == 1:
-      train_step_bert.reset()  # free the train step memory :(
+    if i % eval_step_freq == 0 or (BENCHMARK and i == BENCHMARK):
+      if MLLOGGER and RUNMLPERF:
+        MLLOGGER.start(key=mllog_constants.EVAL_START, value=None, metadata={"epoch_num": 1, "epoch_count": 1, "step_num": i})
+      train_step_bert.reset()
       eval_lm_losses = []
       eval_clsf_losses = []
       eval_lm_accs = []
@@ -527,7 +589,10 @@ def train_bert():
       BEAM.value = EVAL_BEAM
 
       for j in tqdm(range(max_eval_steps), desc="Evaluating", total=max_eval_steps, disable=BENCHMARK):
-        eval_data = get_data_bert(GPUS, eval_it)
+        if INITMLPERF:
+          eval_data = get_fake_data_bert(GPUS, BS)
+        else:
+          eval_data = get_data_bert(GPUS, eval_it)
         GlobalCounters.reset()
         st = time.time()
@@ -546,7 +611,11 @@ def train_bert():
         et = time.time()
         eval_times.append(et - st)
 
-        if BENCHMARK and j == BENCHMARK: break
+        if BENCHMARK and j == BENCHMARK:
+          # assume INITMLPERF has BENCHMARK set
+          if MLLOGGER and INITMLPERF:
+            MLLOGGER.event(key=mllog_constants.INIT_STOP, value=None)
+          return
 
       eval_step_bert.reset()
       avg_lm_loss = sum(eval_lm_losses) / len(eval_lm_losses)
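
Since an INITMLPERF run uses BENCHMARK to bound both the train and the eval loop, this is where the dry run exits; the log should end with the init interval closed, roughly (format approximate):

  :::MLLOG {"key": "init_stop", ...}

One detail worth flagging: INIT_STOP is emitted here with MLLOGGER.event() rather than MLLOGGER.end(), so it lands as a point-in-time record; per the "compliance for 3.1.0 and 4.0.0" commit message, the checker evidently accepts this.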
@@ -557,12 +626,15 @@ def train_bert():
       results = f"eval lm loss: {avg_lm_loss:.2f}, eval clsf loss: {avg_clsf_loss:.2f}, eval lm accuracy: {avg_lm_acc:.6f}, \
 eval clsf accuracy: {avg_clsf_acc:.2f}, avg eval step time: {avg_fw_time:.2f}"
       tqdm.write(results)
       with open(getenv("EVAL_LOG", "./eval_log.txt"), "a") as file: file.write(results + "\n")
 
       if WANDB:
         wandb.log({"eval/lm_loss": avg_lm_loss, "eval/clsf_loss": avg_clsf_loss, "eval/lm_accuracy": avg_lm_acc, \
                    "eval/clsf_accuracy": avg_clsf_acc, "eval/forward_time": avg_fw_time})
+      if MLLOGGER and RUNMLPERF:
+        MLLOGGER.end(key=mllog_constants.EVAL_STOP, value=i, metadata={"epoch_count": 1, "step_num": i, "samples_count": config["EVAL_BS"] * config["MAX_EVAL_STEPS"]})
+        MLLOGGER.event(key=mllog_constants.EVAL_ACCURACY, value=avg_lm_acc, metadata={"epoch_num": 1, "masked_lm_accuracy": avg_lm_acc})
 
       # save model if achieved target
       if not achieved and avg_lm_acc >= target:
         wc_end = time.perf_counter()
@@ -577,10 +649,16 @@ def train_bert():
         seconds = total_seconds % 60
         print(f"Reference Convergence point reached after {i * BS} datasamples and {hours}h{minutes}m{seconds:.2f}s.")
         achieved = True
+        if MLLOGGER and RUNMLPERF:
+          MLLOGGER.end(key=mllog_constants.RUN_STOP, metadata=dict(status=mllog_constants.SUCCESS))
         # stop once hitting the target
         break
 
-    if getenv("CKPT") and i % save_ckpt_freq == 0:
+    if getenv("CKPT", 1) and i % save_ckpt_freq == 0:
+      if MLLOGGER and RUNMLPERF:
+        if previous_step:
+          MLLOGGER.end(key=mllog_constants.BLOCK_STOP, value=None, metadata={"first_epoch_num": 1, "epoch_num": 1, "first_step_num": i, "step_num": i, "step_count": i - previous_step})
+        MLLOGGER.start(key="checkpoint_start", value=None, metadata={"step_num": i})
       if not os.path.exists(ckpt_dir := save_ckpt_dir): os.mkdir(ckpt_dir)
       if WANDB and wandb.run is not None:
         fn = f"{ckpt_dir}/{time.strftime('%Y%m%d_%H%M%S')}_{wandb.run.id}.safe"
@@ -594,6 +672,10 @@ def train_bert():
         last = ckpt_files.pop(0)
         print(f"Removing old ckpt {last}")
         os.remove(os.path.join(ckpt_dir, last))
+      if MLLOGGER and RUNMLPERF:
+        MLLOGGER.end(key="checkpoint_stop", value=None, metadata={"step_num": i})
+        MLLOGGER.start(key=mllog_constants.BLOCK_START, value=None, metadata={"first_epoch_num": 1, "epoch_num": 1, "epoch_count": 1, "samples_count": config["EVAL_BS"] * config["MAX_EVAL_STEPS"], "step_num": i, "first_step_num": i+1})
+        previous_step = i
 
 def train_maskrcnn():
   # TODO: Mask RCNN
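
Taken together, a RUNMLPERF run brackets the training timeline roughly like this (a sketch of the intended nesting, not verbatim log output):

  run_start
    block_start (first_step_num = 1)
      ... training steps ...
      eval_start / eval_stop + eval_accuracy   (every eval_step_freq steps, or at i == BENCHMARK)
      checkpoint_start / checkpoint_stop       (every save_ckpt_freq steps)
    block_stop (step_count = steps since the previous checkpoint)
  run_stop (status: success, once masked-lm accuracy reaches TARGET, default 0.72)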