mirror of https://github.com/commaai/tinygrad.git
Add MLLogger (#5125)
* add MLPerf logger
* eval steps
* start with step 1
* compliance for 3.1.0 and 4.0.0
* more compliance
* assert, comment and contiguous
parent 16405b973a
commit e267f3161d
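For orientation, the diff below gates MLPerf compliance logging behind the LOGMLPERF, INITMLPERF, and RUNMLPERF environment variables and emits events through mlperf_logging's mllog. The snippet that follows is a minimal, self-contained sketch of that logging pattern only, not the commit's full sequence; the filename, root_dir, and the single INIT_START/INIT_STOP pair are illustrative, and it assumes the mlperf-logging package is installed (pip install mlperf-logging).

# Sketch of the mllog usage this commit adds to train_bert (illustrative, not the full flow).
from pathlib import Path
from mlperf_logging import mllog
import mlperf_logging.mllog.constants as mllog_constants

mllog.config(filename="bert.log")             # compliance events are appended to this file
mllog.config(root_dir=Path.cwd().as_posix())  # the commit uses Path(__file__).parents[3] instead
MLLOGGER = mllog.get_mllogger()
MLLOGGER.logger.propagate = False             # keep mllog output off the root logger

# Submission metadata and the init window, as logged during the INITMLPERF pass.
MLLOGGER.event(key=mllog_constants.SUBMISSION_ORG, value="tinycorp")
MLLOGGER.event(key=mllog_constants.SUBMISSION_BENCHMARK, value=mllog_constants.BERT)
MLLOGGER.start(key=mllog_constants.INIT_START, value=None)
# ... model, optimizer, and dataloader setup would happen here ...
MLLOGGER.event(key=mllog_constants.INIT_STOP, value=None)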
@@ -2,7 +2,7 @@ from collections import OrderedDict
 import unicodedata
 import numpy as np
 from tinygrad.nn import state
-from tinygrad.tensor import Tensor
+from tinygrad.tensor import Tensor, dtypes
 from tinygrad.helpers import getenv

 #
@@ -225,3 +225,14 @@ def get_data_bert(GPUS:list[str], it):
   data: dict[str, Tensor] = next(it)
   for key in data.keys(): data[key].shard_(GPUS, axis=0)
   return data
+
+def get_fake_data_bert(GPUS:list[str], BS:int):
+  return {
+    "input_ids": Tensor.zeros((BS, 512), dtype=dtypes.float32).contiguous().shard_(GPUS, axis=0),
+    "input_mask": Tensor.zeros((BS, 512), dtype=dtypes.default_float).contiguous().shard_(GPUS, axis=0),
+    "segment_ids": Tensor.zeros((BS, 512), dtype=dtypes.float32).contiguous().shard_(GPUS, axis=0),
+    "masked_lm_positions": Tensor.zeros((BS, 512), dtype=dtypes.float32).contiguous().shard_(GPUS, axis=0),
+    "masked_lm_ids": Tensor.zeros((BS, 512), dtype=dtypes.float32).contiguous().shard_(GPUS, axis=0),
+    "masked_lm_weights": Tensor.zeros((BS, 512), dtype=dtypes.float32).contiguous().shard_(GPUS, axis=0),
+    "next_sentence_labels": Tensor.zeros((BS, 1), dtype=dtypes.float32).contiguous().shard_(GPUS, axis=0),
+  }
@@ -386,7 +386,7 @@ def eval_step_bert(model, input_ids:Tensor, segment_ids:Tensor, attention_mask:T
 def train_bert():
   # NOTE: pip install tensorflow, wandb required
   from examples.mlperf.dataloader import batch_load_train_bert, batch_load_val_bert
-  from examples.mlperf.helpers import get_mlperf_bert_model, get_data_bert
+  from examples.mlperf.helpers import get_mlperf_bert_model, get_data_bert, get_fake_data_bert
   from examples.mlperf.lr_schedulers import PolynomialDecayWithWarmup

   config = {}
@@ -397,6 +397,35 @@ def train_bert():
   for x in GPUS: Device[x]
   seed = config["seed"] = getenv("SEED", 12345)

+  INITMLPERF = getenv("INITMLPERF")
+  RUNMLPERF = getenv("RUNMLPERF")
+  if getenv("LOGMLPERF"):
+    from mlperf_logging import mllog
+    import mlperf_logging.mllog.constants as mllog_constants
+
+    mllog.config(filename="bert.log")
+    mllog.config(root_dir=Path(__file__).parents[3].as_posix())
+    MLLOGGER = mllog.get_mllogger()
+    MLLOGGER.logger.propagate = False
+
+    if INITMLPERF:
+      assert BENCHMARK, f"BENCHMARK must be set for INITMLPERF"
+      MLLOGGER.event(key=mllog_constants.SUBMISSION_ORG, value="tinycorp")
+      MLLOGGER.event(key=mllog_constants.SUBMISSION_PLATFORM, value=getenv("SUBMISSION_PLATFORM", "tinybox"))
+      MLLOGGER.event(key=mllog_constants.SUBMISSION_DIVISION, value=mllog_constants.CLOSED)
+      MLLOGGER.event(key=mllog_constants.SUBMISSION_STATUS, value=mllog_constants.ONPREM)
+
+      MLLOGGER.event(key=mllog_constants.SUBMISSION_BENCHMARK, value=mllog_constants.BERT)
+
+      diskcache_clear()
+      MLLOGGER.event(key=mllog_constants.CACHE_CLEAR, value=True)
+      MLLOGGER.start(key=mllog_constants.INIT_START, value=None)
+
+    if RUNMLPERF:
+      MLLOGGER.start(key=mllog_constants.RUN_START, value=None)
+  else:
+    MLLOGGER = None
+
   # ** hyperparameters **
   BS = config["GLOBAL_BATCH_SIZE"] = getenv("BS", 16 * len(GPUS) if dtypes.default_float in (dtypes.float16, dtypes.bfloat16) else 8 * len(GPUS))
   EVAL_BS = config["EVAL_BS"] = getenv("EVAL_BS", 1 * len(GPUS))
@@ -413,6 +442,7 @@ def train_bert():

   loss_scaler = config["LOSS_SCALER"] = getenv("LOSS_SCALER", 2.0**9 if dtypes.default_float == dtypes.float16 else 1.0)
   decay = config["DECAY"] = getenv("DECAY", 0.01)
   epsilon = config["EPSILON"] = getenv("EPSILON", 1e-6)
+  poly_power = config["POLY_POWER"] = getenv("POLY_POWER", 1.0)

   target, achieved = getenv("TARGET", 0.72), False
@@ -438,8 +468,8 @@ def train_bert():
   # ** Optimizer **
   parameters_no_wd = [v for k, v in get_state_dict(model).items() if "bias" in k or "LayerNorm" in k]
   parameters = [x for x in parameters if x not in set(parameters_no_wd)]
-  optimizer_wd = LAMB(parameters, lr=max_lr, eps=1e-6, weight_decay=decay, adam=False)
-  optimizer_no_wd = LAMB(parameters_no_wd, lr=max_lr, eps=1e-6, weight_decay=0.0, adam=False)
+  optimizer_wd = LAMB(parameters, lr=max_lr, eps=epsilon, weight_decay=decay, adam=False)
+  optimizer_no_wd = LAMB(parameters_no_wd, lr=max_lr, eps=epsilon, weight_decay=0.0, adam=False)
   optimizer_group = OptimizerGroup(optimizer_wd, optimizer_no_wd)

   # ** LR scheduler **
@@ -448,8 +478,32 @@ def train_bert():
   scheduler_group = LRSchedulerGroup(scheduler_wd, scheduler_no_wd)
   print(f"training with batch size {BS} for one epoch with {train_steps} steps")

+  # log mlperf hparams
+  if MLLOGGER:
+    if RUNMLPERF:
+      MLLOGGER.event(key=mllog_constants.GLOBAL_BATCH_SIZE, value=config["GLOBAL_BATCH_SIZE"])
+      MLLOGGER.event(key=mllog_constants.MAX_SEQUENCE_LENGTH, value=512)
+      MLLOGGER.event(key="max_predictions_per_seq", value=76)
+
+      MLLOGGER.event(key=mllog_constants.OPT_NAME, value="LAMB")
+      MLLOGGER.event(key=mllog_constants.OPT_BASE_LR, value=config["OPT_BASE_LEARNING_RATE"])
+      MLLOGGER.event(key=mllog_constants.OPT_LAMB_WEIGHT_DECAY, value=config["DECAY"])
+      MLLOGGER.event(key=mllog_constants.OPT_LAMB_BETA_1, value=optimizer_wd.b1)
+      MLLOGGER.event(key=mllog_constants.OPT_LAMB_BETA_2, value=optimizer_wd.b2)
+      MLLOGGER.event(key=mllog_constants.OPT_LAMB_LR_DECAY_POLY_POWER, value=config["POLY_POWER"])
+      MLLOGGER.event(key=mllog_constants.OPT_LAMB_EPSILON, value=config["EPSILON"])
+
+      MLLOGGER.event(key=mllog_constants.OPT_LR_WARMUP_STEPS, value=config["NUM_WARMUP_STEPS"])
+      MLLOGGER.event(key=mllog_constants.NUM_WARMUP_STEPS, value=config["NUM_WARMUP_STEPS"])
+      MLLOGGER.event(key='start_warmup_step', value=0)
+      MLLOGGER.event(key='opt_learning_rate_training_steps', value=config["TRAIN_STEPS"])
+      MLLOGGER.event(key=mllog_constants.GRADIENT_ACCUMULATION_STEPS, value=1)
+      MLLOGGER.event(key=mllog_constants.EVAL_SAMPLES, value=config["EVAL_BS"] * config["MAX_EVAL_STEPS"])
+      MLLOGGER.event(key=mllog_constants.TRAIN_SAMPLES, value=config["GLOBAL_BATCH_SIZE"] * config["TRAIN_STEPS"])
+
   # ** resume from checkpointing **
-  start_step = 0
+  start_step = 1
+  previous_step = None
   if ckpt:=getenv("RESUME", ""):
     load_training_state(model, optimizer_group, scheduler_group, safe_load(ckpt))
     start_step = int(scheduler_wd.epoch_counter.numpy().item())
@@ -464,13 +518,17 @@ def train_bert():

   BENCHMARK = getenv("BENCHMARK")

-  eval_it = iter(batch_load_val_bert(EVAL_BS))
-  train_it = iter(tqdm(batch_load_train_bert(BS, start_step), initial=start_step, total=train_steps, disable=BENCHMARK))
+  if not INITMLPERF:
+    eval_it = iter(batch_load_val_bert(EVAL_BS))
+    train_it = iter(tqdm(batch_load_train_bert(BS, start_step), initial=start_step, total=train_steps, disable=BENCHMARK))

   step_times = []
   # ** train loop **
   wc_start = time.perf_counter()
-  i, train_data = start_step, get_data_bert(GPUS, train_it)
+  if INITMLPERF:
+    i, train_data = start_step, get_fake_data_bert(GPUS, BS)
+  else:
+    i, train_data = start_step, get_data_bert(GPUS, train_it)
   while train_data is not None and i < train_steps and not achieved:
     Tensor.training = True
     BEAM.value = TRAIN_BEAM
@@ -483,7 +541,10 @@ def train_bert():
     pt = time.perf_counter()

     try:
-      next_data = get_data_bert(GPUS, train_it)
+      if INITMLPERF:
+        next_data = get_fake_data_bert(GPUS, BS)
+      else:
+        next_data = get_data_bert(GPUS, train_it)
     except StopIteration:
       next_data = None

@@ -513,11 +574,12 @@ def train_bert():
       print(f"Estimated training time: {estimated_total_minutes // 60}h{estimated_total_minutes % 60}m")
       print(f"epoch global_ops: {train_steps * GlobalCounters.global_ops:_}, "
             f"epoch global_mem: {train_steps * GlobalCounters.global_mem:_}")
       return

     # ** eval loop **
-    if i % eval_step_freq == 0 or i == 1:
-      train_step_bert.reset() # free the train step memory :(
+    if i % eval_step_freq == 0 or (BENCHMARK and i == BENCHMARK):
+      if MLLOGGER and RUNMLPERF:
+        MLLOGGER.start(key=mllog_constants.EVAL_START, value=None, metadata={"epoch_num": 1, "epoch_count": 1, "step_num": i})
+      train_step_bert.reset()
       eval_lm_losses = []
       eval_clsf_losses = []
       eval_lm_accs = []
@@ -527,7 +589,10 @@ def train_bert():
       BEAM.value = EVAL_BEAM

       for j in tqdm(range(max_eval_steps), desc="Evaluating", total=max_eval_steps, disable=BENCHMARK):
-        eval_data = get_data_bert(GPUS, eval_it)
+        if INITMLPERF:
+          eval_data = get_fake_data_bert(GPUS, BS)
+        else:
+          eval_data = get_data_bert(GPUS, eval_it)
         GlobalCounters.reset()
         st = time.time()

@@ -546,7 +611,11 @@ def train_bert():
         et = time.time()
         eval_times.append(et - st)

-        if BENCHMARK and j == BENCHMARK: break
+        if BENCHMARK and j == BENCHMARK:
+          # assume INITMLPERF has BENCHMARK set
+          if MLLOGGER and INITMLPERF:
+            MLLOGGER.event(key=mllog_constants.INIT_STOP, value=None)
+          return

       eval_step_bert.reset()
       avg_lm_loss = sum(eval_lm_losses) / len(eval_lm_losses)
@@ -557,12 +626,15 @@ def train_bert():
       results = f"eval lm loss: {avg_lm_loss:.2f}, eval clsf loss: {avg_clsf_loss:.2f}, eval lm accuracy: {avg_lm_acc:.6f}, \
        eval clsf accuracy: {avg_clsf_acc:.2f}, avg eval step time: {avg_fw_time:.2f}"
       tqdm.write(results)
       with open(getenv("EVAL_LOG", "./eval_log.txt"), "a") as file: file.write(results + "\n")

       if WANDB:
         wandb.log({"eval/lm_loss": avg_lm_loss, "eval/clsf_loss": avg_clsf_loss, "eval/lm_accuracy": avg_lm_acc, \
                    "eval/clsf_accuracy": avg_clsf_acc, "eval/forward_time": avg_fw_time})

+      if MLLOGGER and RUNMLPERF:
+        MLLOGGER.end(key=mllog_constants.EVAL_STOP, value=i, metadata={"epoch_count": 1, "step_num": i, "samples_count": config["EVAL_BS"] * config["MAX_EVAL_STEPS"]})
+        MLLOGGER.event(key=mllog_constants.EVAL_ACCURACY, value=avg_lm_acc, metadata={"epoch_num": 1, "masked_lm_accuracy": avg_lm_acc})
+
       # save model if achieved target
       if not achieved and avg_lm_acc >= target:
         wc_end = time.perf_counter()
@@ -577,10 +649,16 @@ def train_bert():
         seconds = total_seconds % 60
         print(f"Reference Convergence point reached after {i * BS} datasamples and {hours}h{minutes}m{seconds:.2f}s.")
         achieved = True
+        if MLLOGGER and RUNMLPERF:
+          MLLOGGER.end(key=mllog_constants.RUN_STOP, metadata=dict(status=mllog_constants.SUCCESS))
+        # stop once hitting the target
+        break

-    if getenv("CKPT") and i % save_ckpt_freq == 0:
+    if getenv("CKPT", 1) and i % save_ckpt_freq == 0:
+      if MLLOGGER and RUNMLPERF:
+        if previous_step:
+          MLLOGGER.end(key=mllog_constants.BLOCK_STOP, value=None, metadata={"first_epoch_num": 1, "epoch_num": 1, "first_step_num": i, "step_num": i, "step_count": i - previous_step})
+        MLLOGGER.start(key="checkpoint_start", value=None, metadata={"step_num" : i})
       if not os.path.exists(ckpt_dir := save_ckpt_dir): os.mkdir(ckpt_dir)
       if WANDB and wandb.run is not None:
         fn = f"{ckpt_dir}/{time.strftime('%Y%m%d_%H%M%S')}_{wandb.run.id}.safe"
@@ -594,6 +672,10 @@ def train_bert():
         last = ckpt_files.pop(0)
         print(f"Removing old ckpt {last}")
         os.remove(os.path.join(ckpt_dir, last))
+      if MLLOGGER and RUNMLPERF:
+        MLLOGGER.end(key="checkpoint_stop", value=None, metadata={"step_num": i})
+        MLLOGGER.start(key=mllog_constants.BLOCK_START, value=None, metadata={"first_epoch_num": 1, "epoch_num": 1, "epoch_count": 1, "samples_count": config["EVAL_BS"] * config["MAX_EVAL_STEPS"], "step_num": i, "first_step_num": i+1})
+        previous_step = i

 def train_maskrcnn():
   # TODO: Mask RCNN
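As a quick sanity check of the new fake-data path, get_fake_data_bert can be exercised outside the training loop. The sketch below is not part of the commit; the device list and batch size are illustrative, and it assumes a single default device is acceptable to shard_.

# Illustrative smoke test for get_fake_data_bert (hypothetical usage, not from the commit).
from tinygrad import Device
from examples.mlperf.helpers import get_fake_data_bert

GPUS = [Device.DEFAULT]        # hypothetical single-device setup
data = get_fake_data_bert(GPUS, BS=8)
# every value is a zero tensor of shape (8, 512), except next_sentence_labels with shape (8, 1)
print({k: v.shape for k, v in data.items()})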