mirror of https://github.com/commaai/tinygrad.git
MLPerf BERT: Main training loop (#4288)
* BERT language modeling head + trunc normal initializers
* add train loop + helpers
* shuffle in dataloaders + slight changes in main loop
* beam change
* Minor changes
* random.shuffle
* HParam update
* Use deque for dataloader
* wandb bert project name
* half fixes
* BENCHMARK + remove epoch
* cast + print()

---------

Co-authored-by: chenyu <chenyu@fastmail.com>
parent 61c97d5305
commit 27613dd881

@@ -1,12 +1,13 @@
import os, random
from typing import List
import os, random, pickle, functools, itertools
from typing import List, Tuple
from pathlib import Path
import numpy as np
from PIL import Image
from tqdm import tqdm
import pickle
from tinygrad import dtypes, Tensor
from tinygrad.helpers import getenv, prod, Timing, Context
from multiprocessing import Queue, Process, shared_memory, connection, Lock, cpu_count
from collections import deque
from multiprocessing import Queue, Process, shared_memory, connection, Lock, cpu_count, Pool

class MyQueue:
  def __init__(self, multiple_readers=True, multiple_writers=True):
@@ -140,6 +141,7 @@ def batch_load_resnet(batch_size=64, val=False, shuffle=True, seed=None):
    shm.close()
    shm.unlink()

@functools.lru_cache(maxsize=128)
def load_bert_file(fn:str) -> List[dict]:
  with open(fn, "rb") as f: data = pickle.load(f)
  return data
@@ -147,7 +149,7 @@ def load_bert_file(fn:str) -> List[dict]:
def process_batch_bert(data: List[dict]) -> dict[str, Tensor]:
  return {
    "input_ids": Tensor(np.concatenate([s["input_ids"] for s in data], axis=0), dtype=dtypes.float32),
    "input_mask": Tensor(np.concatenate([s["input_mask"] for s in data], axis=0), dtype=dtypes.float32),
    "input_mask": Tensor(np.concatenate([s["input_mask"] for s in data], axis=0), dtype=dtypes.default_float),
    "segment_ids": Tensor(np.concatenate([s["segment_ids"] for s in data], axis=0), dtype=dtypes.float32),
    "masked_lm_positions": Tensor(np.concatenate([s["masked_lm_positions"] for s in data], axis=0), dtype=dtypes.float32),
    "masked_lm_ids": Tensor(np.concatenate([s["masked_lm_ids"] for s in data], axis=0), dtype=dtypes.float32),
@@ -155,22 +157,69 @@ def process_batch_bert(data: List[dict]) -> dict[str, Tensor]:
    "next_sentence_labels": Tensor(np.concatenate([s["next_sentence_labels"] for s in data], axis=0), dtype=dtypes.float32),
  }

# For train: Stop when we run through all data
# For val: Wrap around val dataset and never stop
# Reference: https://github.com/mlcommons/training/blob/1c8a098ae3e70962a4f7422c0b0bd35ae639e357/language_model/tensorflow/bert/run_pretraining.py, Line 420
def batch_load_bert(BS:int, val=False):
  from extra.datasets.wikipedia import get_wiki_train_files, get_wiki_val_files
  files = get_wiki_val_files() if val else get_wiki_train_files()
  blob, end = [], False
  while files: # As long as there is data, keep going
    while len(blob) < BS and not end: # Fill blob until there is enough for next step
      blob.extend(load_bert_file(files.pop(0)))
      if not files:
        if val: files = get_val_files()
        else: end = True # End of train data - avoid pop on empty file list
    if len(blob) >= BS: # if last train step does not have enough for a full batch
      yield process_batch_bert(blob[:BS])
      blob = blob[BS:]
def shuffle_parts(file_paths: List[str]) -> List[str]:
  parts = {}
  for f in file_paths:
    part = Path(f).stem.split('_')[0]
    if part not in parts: parts[part] = []
    parts[part].append(f)

  part_ids = list(parts.keys())
  random.shuffle(part_ids)

  shuffled_files = []
  for p in part_ids:
    parts[p].sort(key=lambda x: int(Path(x).stem.split('_')[1]))
    shuffled_files.extend(parts[p])
  return shuffled_files
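# Editor's illustration (not part of the commit): with the "{part}_{i}_of_{num_samples}.pkl" naming
# introduced later in this commit, e.g. ["3_0_of_512.pkl", "3_1_of_512.pkl", "7_0_of_480.pkl"],
# shuffle_parts shuffles the part ids (here 3 and 7) while keeping each part's files in numeric order.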

def random_sample(data: List[str]):
  index = random.randint(0, len(data) - 1)
  selected_sample = data[index]
  return selected_sample, index

def load_datasample(file_and_offset:Tuple[str, int]) -> List[dict]:
  data = load_bert_file(file_and_offset[0])
  return data[file_and_offset[1]]

# Reference: https://github.com/mlcommons/training/blob/1c8a098ae3e70962a4f7422c0b0bd35ae639e357/language_model/tensorflow/bert/run_pretraining.py, Line 394
def batch_load_train_bert(BS:int):
  from extra.datasets.wikipedia import get_wiki_train_files
  files = shuffle_parts(get_wiki_train_files())
  dataset = []
  for f in files:
    lists = [(f, o) for o in range(int(Path(f).stem.split("_")[3].split(".")[0]))]
    dataset.extend(lists)

  active_set = deque(dataset[:1000])
  remaining_set = deque(dataset[1000:])

  while dataset:
    blob = []
    for _ in range(BS):
      if active_set:
        index = random.randint(0, len(active_set) - 1)
        sample = active_set[index]
        active_set.remove(sample)
        blob.append(sample)
        if remaining_set:
          active_set.append(remaining_set.popleft())
    yield process_batch_bert([load_datasample(sample) for sample in blob])
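# Editor's note (illustration, not part of the commit): this appears to emulate the reference
# implementation's shuffle buffer: each batch is drawn at random from a 1000-sample active window,
# and every drawn (file, offset) pair is replaced from remaining_set until the stream runs out.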

# Reference: https://github.com/mlcommons/training/blob/1c8a098ae3e70962a4f7422c0b0bd35ae639e357/language_model/tensorflow/bert/run_pretraining.py, Line 416
def batch_load_val_bert(BS:int):
  from extra.datasets.wikipedia import get_wiki_val_files
  files = get_wiki_val_files()
  dataset = list(itertools.chain.from_iterable([load_bert_file(f) for f in files]))
  idx = 0
  while True:
    start_idx = (idx * BS) % len(dataset)
    end_idx = ((idx + 1) * BS) % len(dataset)
    if start_idx < end_idx:
      yield process_batch_bert(dataset[start_idx:end_idx])
    else: # wrap around the end to the beginning of the dataset
      yield process_batch_bert(dataset[start_idx:] + dataset[:end_idx])
    idx += 1
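# Worked example (editor's illustration): with len(dataset) == 10 and BS == 4 this yields
# dataset[0:4], dataset[4:8], then start_idx=8, end_idx=2 -> dataset[8:] + dataset[:2], and so on,
# cycling over the val set forever.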

if __name__ == "__main__":
  from extra.datasets.imagenet import get_train_files, get_val_files

@@ -1,7 +1,9 @@
from collections import OrderedDict
import unicodedata
import os, unicodedata, json, functools
import numpy as np
from tinygrad.nn import state
from tinygrad.tensor import Tensor
from tinygrad.dtype import dtypes

#
# checkpointing utils
@@ -190,3 +192,88 @@ def get_bert_qa_prediction(features, example, start_end_logits):
      orig_text = " ".join(orig_tokens)
      return _get_final_text(tok_text, orig_text)
  return "empty"

def get_mlperf_bert_model(config_path:str):
  from extra.models import bert
  from examples.mlperf.initializers import LinearBert, EmbeddingBert, LayerNormBert

  bert.Linear = LinearBert
  bert.Embedding = EmbeddingBert
  bert.LayerNorm = LayerNormBert

  from extra.models.bert import BertForMLPerf

  with open(config_path, "r") as f:
    config = json.load(f)
  return BertForMLPerf(
    config["hidden_size"],
    config["intermediate_size"],
    config["max_position_embeddings"],
    config["num_attention_heads"],
    config["num_hidden_layers"],
    config["type_vocab_size"],
    config["vocab_size"],
    config["attention_probs_dropout_prob"],
    config["hidden_dropout_prob"]
  )
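# Hedged usage sketch (editor's note, not part of the commit), matching the call made in train_bert() later in this commit:
#   model = get_mlperf_bert_model(BASEDIR / "bert_config.json")
# where bert_config.json is the MLPerf BERT config providing the keys read above.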

@functools.lru_cache(maxsize=None)
def load_tf_weights_to_dict(checkpoint_path):
  os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
  import tensorflow as tf
  reader = tf.train.load_checkpoint(checkpoint_path)
  var_to_shape_map = reader.get_variable_to_shape_map()
  weights_dict = {}

  for key in sorted(var_to_shape_map):
    weights_dict[key] = reader.get_tensor(key)
  return weights_dict

def tt(tf_tensor): return Tensor(tf_tensor, dtype=dtypes.float32)

def load_from_tf2_ckpt(key: str, ckpt_dir: str):
  p = "model/layer-3/"
  s = "/.ATTRIBUTES/VARIABLE_VALUE"
  tf_dict = load_tf_weights_to_dict(ckpt_dir)
  if key.startswith("model.embeddings"):
    if key.endswith("word_embeddings.weight"): return tt(tf_dict[p+"layer-1/embeddings"+s])
    elif key.endswith("position_embeddings.weight"): return tt(tf_dict[p+"layer-3/embeddings"+s])
    elif key.endswith("token_type_embeddings.weight"): return tt(tf_dict[p+"layer-4/embeddings"+s])
    elif key.endswith("LayerNorm.weight"): return tt(tf_dict[p+"layer-6/gamma"+s])
    elif key.endswith("LayerNorm.bias"): return tt(tf_dict[p+"layer-6/beta"+s])
    else: raise ValueError(f"Unknown key: {key}")
  elif key.startswith("model.encoder.layer"):
    l_id = str(int(key.split(".")[3]) + 10)
    if ".attention." in key:
      if key.endswith("self.query.weight"): return tt(tf_dict[p+f"layer-{l_id}/_attention_layer/_query_dense/kernel"+s])
      elif key.endswith("self.query.bias"): return tt(tf_dict[p+f"layer-{l_id}/_attention_layer/_query_dense/bias"+s])
      elif key.endswith("self.key.weight"): return tt(tf_dict[p+f"layer-{l_id}/_attention_layer/_key_dense/kernel"+s])
      elif key.endswith("self.key.bias"): return tt(tf_dict[p+f"layer-{l_id}/_attention_layer/_key_dense/bias"+s])
      elif key.endswith("self.value.weight"): return tt(tf_dict[p+f"layer-{l_id}/_attention_layer/_value_dense/kernel"+s])
      elif key.endswith("self.value.bias"): return tt(tf_dict[p+f"layer-{l_id}/_attention_layer/_value_dense/bias"+s])
      # Attention output
      elif key.endswith("output.dense.weight"): return tt(tf_dict[p+f"layer-{l_id}/_attention_output_dense/kernel"+s])
      elif key.endswith("output.dense.bias"): return tt(tf_dict[p+f"layer-{l_id}/_attention_output_dense/bias"+s])
      elif key.endswith("output.LayerNorm.weight"): return tt(tf_dict[p+f"layer-{l_id}/_attention_layer_norm/gamma"+s])
      elif key.endswith("output.LayerNorm.bias"): return tt(tf_dict[p+f"layer-{l_id}/_attention_layer_norm/beta"+s])
      else: raise ValueError(f"Unknown key: {key}")
    elif ".intermediate." in key:
      if key.endswith("dense.weight"): return tt(tf_dict[p+f"layer-{l_id}/_intermediate_dense/kernel"+s])
      elif key.endswith("dense.bias"): return tt(tf_dict[p+f"layer-{l_id}/_intermediate_dense/bias"+s])
      else: raise ValueError(f"Unknown key: {key}")
    elif ".output." in key:
      if key.endswith("dense.weight"): return tt(tf_dict[p+f"layer-{l_id}/_output_dense/kernel"+s])
      elif key.endswith("dense.bias"): return tt(tf_dict[p+f"layer-{l_id}/_output_dense/bias"+s])
      elif key.endswith("LayerNorm.weight"): return tt(tf_dict[p+f"layer-{l_id}/_output_layer_norm/gamma"+s])
      elif key.endswith("LayerNorm.bias"): return tt(tf_dict[p+f"layer-{l_id}/_output_layer_norm/beta"+s])
      else: raise ValueError(f"Unknown key: {key}")
  elif key.startswith("clsf_pooler.weight"): return tt(tf_dict[f"model/layer-3/layer-35/kernel"+s])
  elif key.startswith("clsf_pooler.bias"): return tt(tf_dict[f"model/layer-3/layer-35/bias"+s])
  elif key.startswith("clsf_output.weight"): return tt(tf_dict[f"model/layer-6/layer-1/kernel"+s])
  elif key.startswith("clsf_output.bias"): return tt(tf_dict[f"model/layer-6/layer-1/bias"+s])
  elif key.startswith("lm_transform.weight"): return tt(tf_dict[f"model/layer-5/layer-3/kernel"+s])
  elif key.startswith("lm_transform.bias"): return tt(tf_dict[f"model/layer-5/layer-3/bias"+s])
  elif key.startswith("lm_norm.weight"): return tt(tf_dict[f"model/layer-5/layer-4/gamma"+s])
  elif key.startswith("lm_norm.bias"): return tt(tf_dict[f"model/layer-5/layer-4/beta"+s])
  elif key.startswith("lm_output_bias"): return tt(tf_dict[f"model/layer-5/layer-6/bias"+s])
  else: raise ValueError(f"Unknown key: {key}")

@@ -1,4 +1,5 @@
import math
from typing import Union, Tuple

from tinygrad import Tensor, nn, dtypes
from tinygrad.helpers import prod, argfix
@@ -33,3 +34,35 @@ class Linear(nn.Linear):
    if bias: self.bias = Tensor.zeros(out_features, dtype=dtypes.float32)
  def __call__(self, x:Tensor):
    return x.linear(self.weight.cast(dtypes.default_float).transpose(), self.bias.cast(dtypes.default_float) if self.bias is not None else None)

class LinearBert(nn.Linear):
  def __init__(self, in_features, out_features, bias=True, std=0.02):
    self.weight = std * rand_truncn(out_features, in_features, dtype=dtypes.float32)
    self.bias = Tensor.zeros(out_features, dtype=dtypes.float32) if bias else None

  def __call__(self, x:Tensor):
    return x.linear(self.weight.cast(dtypes.default_float).transpose(), self.bias.cast(dtypes.default_float) if self.bias is not None else None)

class EmbeddingBert(nn.Embedding):
  def __init__(self, vocab_size:int, embed_size:int, std=0.02):
    self.vocab_sz, self.embed_sz = vocab_size, embed_size
    self.weight = std * rand_truncn(vocab_size, embed_size, dtype=dtypes.float32)

  def __call__(self, idx:Tensor) -> Tensor:
    if idx.numel() == 0: return Tensor.empty(idx.shape+(self.embed_sz,), dtype=self.weight.dtype, device=self.weight.device)
    arange_shp, weight_shp, big_shp = (1, 1, self.vocab_sz, 1), (1, 1, self.vocab_sz, self.embed_sz), idx.shape+(self.vocab_sz, self.embed_sz,)
    if not hasattr(self, 'arange'): self.arange = Tensor.arange(self.vocab_sz, requires_grad=False, device=self.weight.device).reshape(arange_shp)
    arange, idx, vals = self.arange.expand(big_shp), idx.reshape(idx.shape+(1, 1,)).expand(big_shp), self.weight.cast(dtypes.default_float).reshape(weight_shp).expand(big_shp)
    return (arange == idx).mul(vals).sum(2, acc_dtype=vals.dtype)

class LayerNormBert:
  def __init__(self, normalized_shape:Union[int, Tuple[int, ...]], eps:float=1e-12, elementwise_affine:bool=True):
    self.normalized_shape = (normalized_shape,) if isinstance(normalized_shape, int) else tuple(normalized_shape)
    self.axis, self.eps, self.elementwise_affine = tuple(-1-i for i in range(len(self.normalized_shape))), eps, elementwise_affine
    self.weight, self.bias = (Tensor.ones(*self.normalized_shape, dtype=dtypes.float32), Tensor.zeros(*self.normalized_shape, dtype=dtypes.float32)) if elementwise_affine else (None, None)

  def __call__(self, x:Tensor):
    assert self.normalized_shape == x.shape[-len(self.normalized_shape):], f"last dimensions of {x.shape} must match {self.normalized_shape}"
    xn = x.cast(dtypes.float32).layernorm(eps=self.eps, axis=self.axis).cast(x.dtype)
    if not self.elementwise_affine: return xn
    return (xn * self.weight.cast(dtypes.default_float) + self.bias.cast(dtypes.default_float))

@@ -1,13 +1,12 @@
import functools
import os
import time
import os, time, math, functools
from pathlib import Path
from tqdm import tqdm
import multiprocessing

from tinygrad import Device, GlobalCounters, Tensor, TinyJit, dtypes
from tinygrad.helpers import getenv, BEAM, WINO
from tinygrad.nn.state import get_parameters, get_state_dict, safe_load, safe_save
from tinygrad.nn.optim import LARS, SGD, OptimizerGroup
from tinygrad.nn.optim import LAMB, LARS, SGD, OptimizerGroup

from extra.lr_scheduler import LRSchedulerGroup
from examples.mlperf.helpers import get_training_state, load_training_state
@@ -259,8 +258,251 @@ def train_rnnt():
  pass

def train_bert():
  # TODO: BERT
  pass
  # NOTE: pip install tensorflow, wandb required
  from examples.mlperf.dataloader import batch_load_train_bert, batch_load_val_bert
  from examples.mlperf.helpers import get_mlperf_bert_model, load_from_tf2_ckpt
  from examples.mlperf.lr_schedulers import PolynomialDecayWithWarmup

  config = {}
  BASEDIR = getenv("BASEDIR", Path(__file__).parent.parents[1] / "extra" / "datasets" / "wiki")

  GPUS = config["GPUS"] = [f"{Device.DEFAULT}:{i}" for i in range(getenv("GPUS", 1))]
  print(f"Training on {GPUS}")
  for x in GPUS: Device[x]
  seed = config["seed"] = getenv("SEED", 12345)

  # ** hyperparameters **
  BS = config["GLOBAL_BATCH_SIZE"] = getenv("BS", 4 * len(GPUS)) # FP32 4090: 6 GPUS -> BS24
  EVAL_BS = config["EVAL_BS"] = getenv("EVAL_BS", 4 * len(GPUS))
  max_lr = config["OPT_BASE_LEARNING_RATE"] = getenv("OPT_BASE_LEARNING_RATE", 0.000004166 * BS)

  train_steps = config["TRAIN_STEPS"] = getenv("TRAIN_STEPS", 4800000 // BS)
  warmup_steps = config["NUM_WARMUP_STEPS"] = getenv("NUM_WARMUP_STEPS", train_steps // 10)
  max_eval_steps = config["MAX_EVAL_STEPS"] = getenv("MAX_EVAL_STEPS", (10000 + EVAL_BS - 1) // EVAL_BS) # EVAL_BS * MAX_EVAL_STEPS >= 10000
  eval_step_freq = config["EVAL_STEP_FREQ"] = int((math.floor(0.05 * (230.23 * BS + 3000000) / 25000) * 25000) / BS) # Round down
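  # Worked example (editor's illustration): for BS=24, floor(0.05 * (230.23*24 + 3000000) / 25000) = 6,
  # so eval_step_freq = 6 * 25000 / 24 = 6250, i.e. an eval pass every 6250 training steps.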
  save_ckpt_freq = config["SAVE_CKPT_FREQ"] = getenv("SAVE_CKPT_FREQ", 1000)
  keep_ckpt_amount = config["KEEP_CKPT_AMOUNT"] = getenv("KEEP_CKPT_AMOUNT", 5)
  init_ckpt = config["INIT_CKPT_DIR"] = getenv("INIT_CKPT_DIR", BASEDIR)

  decay = config["decay"] = getenv("DECAY", 0.01)
  poly_power = config["poly_power"] = getenv("POLY_POWER", 1.0)

  target, achieved = getenv("TARGET", 0.72), False

  config["DEFAULT_FLOAT"] = dtypes.default_float.name
  config["TRAIN_BEAM"] = TRAIN_BEAM = getenv("TRAIN_BEAM", BEAM.value)
  config["EVAL_BEAM"] = EVAL_BEAM = getenv("EVAL_BEAM", BEAM.value)

  Tensor.manual_seed(seed) # seed for weight initialization

  model = get_mlperf_bert_model(BASEDIR / "bert_config.json")

  # shard weights and initialize in order
  for tinygrad_key, x in get_state_dict(model).items():
    if init_ckpt and not tinygrad_key.endswith("lm_output.weight"): # lm_output.weight already is word embedding
      t = load_from_tf2_ckpt(key=tinygrad_key, ckpt_dir=init_ckpt)
      if any(k in tinygrad_key for k in ["intermediate.dense.weight", "output.dense.weight", "clsf_output.weight"]) and "attention" not in tinygrad_key:
        t = t.transpose()
      elif any(k in tinygrad_key for k in ["self", "output.dense", "clsf_pooler", "lm_transform"]) and "weight" in tinygrad_key:
        t = t.reshape(*x.shape).transpose()
      elif all(k in tinygrad_key for k in ["self", "bias"]):
        t = t.reshape(*x.shape)
      x.assign(t).realize().to_(GPUS)
    x.realize().to_(GPUS)
  parameters = get_parameters(model)

  assert 10000 <= (EVAL_BS * max_eval_steps), "Evaluation batchsize * max_eval_steps must greater or equal 10000 to iterate over full eval dataset"

  # ** Log hparams **
  for key, value in config.items():
    print(f'HParam: "{key}": {value}')

  # ** Optimizer **
  skip_list = [v for k, v in get_state_dict(model).items() if "bias" in k or "LayerNorm" in k]
  parameters = [x for x in parameters if x not in set(skip_list)]
  optimizer = LAMB(parameters, 1 / warmup_steps, eps=1e-6, wd=decay, adam=False)
  optimizer_skip = LAMB(skip_list, 1 / warmup_steps, eps=1e-6, wd=0.0, adam=False)
  optimizer_group = OptimizerGroup(optimizer, optimizer_skip)

  # ** LR scheduler **
  scheduler = PolynomialDecayWithWarmup(optimizer, max_lr, 0, train_steps, warmup_steps, power=poly_power)
  print(f"Training with batch size {BS} for one epoch with {train_steps} steps")

  # ** resume from checkpointing **
  start_step = 0
  if ckpt:=getenv("RESUME", ""):
    load_training_state(model, optimizer_group, scheduler, safe_load(ckpt))
    start_step = scheduler.epoch_counter.numpy().item()
    print(f"resuming from {ckpt} at step {start_step}")

  # ** init wandb **
  WANDB = getenv("WANDB")
  if WANDB:
    import wandb
    wandb_args = {"id": wandb_id, "resume": "must"} if (wandb_id := getenv("WANDB_RESUME", "")) else {}
    wandb.init(config=config, **wandb_args, project="MLPerf-BERT")

  BENCHMARK = getenv("BENCHMARK")

  @TinyJit
  def train_step(input_ids:Tensor, segment_ids:Tensor, attention_mask:Tensor, masked_positions:Tensor, masked_lm_ids:Tensor, masked_lm_weights:Tensor, next_sentence_labels:Tensor):
    lm_logits, clsf_logits = model(input_ids, segment_ids, attention_mask, masked_positions)
    lm_loss = lm_logits.sparse_categorical_crossentropy(masked_lm_ids, ignore_index=masked_lm_weights)
    clsf_loss = clsf_logits.binary_crossentropy_logits(next_sentence_labels)
    loss = lm_loss + clsf_loss

    if not getenv('DISABLE_BACKWARD', 0):
      optimizer_group.zero_grad()
      loss.backward()

      optimizer_group.step()
      scheduler.step()
    return loss.realize()

  @TinyJit
  def eval_step(input_ids:Tensor, segment_ids:Tensor, attention_mask:Tensor, masked_positions:Tensor, masked_lm_ids:Tensor, masked_lm_weights:Tensor, next_sentence_labels:Tensor):
    lm_logits, clsf_logits = model(input_ids, segment_ids, attention_mask, masked_positions)

    clsf_predictions = clsf_logits.log_softmax().argmax(-1)
    clsf_accuracy = (clsf_predictions == next_sentence_labels).float().mean()

    mlm_predictions = lm_logits.log_softmax().argmax(-1)
    mask = (masked_lm_weights == 1.0)
    mlm_accuracy = (mlm_predictions == masked_lm_ids).where(mask, 0).sum() / mask.float().sum()

    lm_loss = lm_logits.sparse_categorical_crossentropy(masked_lm_ids, ignore_index=masked_lm_weights)
    clsf_loss = clsf_logits.binary_crossentropy_logits(next_sentence_labels)
    return {
      "masked_lm_accuracy": mlm_accuracy.realize(),
      "masked_lm_loss": lm_loss.realize(),
      "next_sentence_accuracy": clsf_accuracy.realize(),
      "next_sentence_loss": clsf_loss.realize()
    }

  def data_get(it):
    data: dict[str, Tensor] = next(it)
    for key in data.keys(): data[key].shard_(GPUS, axis=0)
    return data
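  # Editor's note (illustrative): shard_(GPUS, axis=0) splits each [BS, ...] field of the batch across
  # the devices in GPUS along the batch dimension, e.g. BS=24 on 6 GPUS leaves 4 samples per device.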

  eval_it = iter(batch_load_val_bert(EVAL_BS))
  train_it = iter(tqdm(batch_load_train_bert(BS), total=train_steps, disable=BENCHMARK))

  step_times = []
  # ** train loop **
  wc_start = time.perf_counter()
  Tensor.training = True
  BEAM.value = TRAIN_BEAM
  i, train_data = 0, data_get(train_it)
  while train_data is not None and i < train_steps and not achieved:
    st = time.perf_counter()
    GlobalCounters.reset()
    loss = train_step(train_data["input_ids"], train_data["segment_ids"], train_data["input_mask"], train_data["masked_lm_positions"], \
                      train_data["masked_lm_ids"], train_data["masked_lm_weights"], train_data["next_sentence_labels"])

    pt = time.perf_counter()

    try:
      next_data = data_get(train_it)
    except StopIteration:
      next_data = None

    dt = time.perf_counter()

    device_str = loss.device if isinstance(loss.device, str) else f"{loss.device[0]} * {len(loss.device)}"
    loss = loss.numpy().item()

    cl = time.perf_counter()
    if BENCHMARK: step_times.append(cl - st)

    tqdm.write(
      f"{i:5} {((cl - st)) * 1000.0:7.2f} ms run, {(pt - st) * 1000.0:7.2f} ms python, {(dt - pt) * 1000.0:6.2f} ms fetch data, "
      f"{(cl - dt) * 1000.0:7.2f} ms {device_str}, {loss:5.2f} loss, {optimizer.lr.numpy()[0]:.6f} LR, "
      f"{GlobalCounters.mem_used / 1e9:.2f} GB used, {GlobalCounters.global_ops * 1e-9 / (cl - st):9.2f} GFLOPS")
    if WANDB:
      wandb.log({"lr": optimizer.lr.numpy(), "train/loss": loss, "train/step_time": cl - st,
                 "train/python_time": pt - st, "train/data_time": dt - pt, "train/cl_time": cl - dt,
                 "train/GFLOPS": GlobalCounters.global_ops * 1e-9 / (cl - st)})

    train_data, next_data = next_data, None
    i += 1

    if i == BENCHMARK:
      median_step_time = sorted(step_times)[(BENCHMARK + 1) // 2] # in seconds
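      # Editor's note: with BENCHMARK=10 this picks index (10+1)//2 = 5, the upper median of the
      # recorded step times.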
      estimated_total_minutes = int(median_step_time * train_steps / 60)
      print(f"Estimated training time: {estimated_total_minutes // 60}h{estimated_total_minutes % 60}m")
      print(f"epoch global_ops: {train_steps * GlobalCounters.global_ops:_}, "
            f"epoch global_mem: {train_steps * GlobalCounters.global_mem:_}")
      return

    # ** eval loop **
    if i % eval_step_freq == 0 or i == 1:
      train_step.reset() # free the train step memory :(
      eval_loss = []
      eval_accuracy = []
      eval_times = []
      Tensor.training = False
      BEAM.value = EVAL_BEAM

      for _ in tqdm(range(max_eval_steps), desc="Evaluating", total=max_eval_steps, disable=BENCHMARK):
        eval_data = data_get(eval_it)
        GlobalCounters.reset()
        st = time.time()

        eval_result: dict[str, Tensor] = eval_step(eval_data["input_ids"], eval_data["segment_ids"], eval_data["input_mask"], eval_data["masked_lm_positions"], \
                                                   eval_data["masked_lm_ids"], eval_data["masked_lm_weights"], eval_data["next_sentence_labels"])

        lm_loss, clsf_loss = eval_result["masked_lm_loss"].numpy().item(), eval_result["next_sentence_loss"].numpy().item()
        mlm_accuracy, clsf_accuracy = eval_result["masked_lm_accuracy"].numpy().item(), eval_result["next_sentence_accuracy"].numpy().item()
        eval_loss.append([lm_loss, clsf_loss])
        eval_accuracy.append([mlm_accuracy, clsf_accuracy])

        et = time.time()
        eval_times.append(et - st)

      eval_step.reset()
      total_lm_loss = sum(pair[0] for pair in eval_loss) / len(eval_loss)
      total_clsf_loss = sum(pair[1] for pair in eval_loss) / len(eval_loss)
      total_lm_accuracy = sum(pair[0] for pair in eval_accuracy) / len(eval_accuracy)
      total_clsf_accuracy = sum(pair[1] for pair in eval_accuracy) / len(eval_accuracy)
      total_fw_time = sum(eval_times) / len(eval_times)
      results = f"eval lm loss: {total_lm_loss:.2f}, eval clsf loss: {total_clsf_loss:.2f}, eval lm accuracy: {total_lm_accuracy:.6f}, \
        eval clsf accuracy: {total_clsf_accuracy:.2f}, avg eval step time: {total_fw_time:.2f}"
      tqdm.write(results)
      with open(getenv("EVAL_LOG", "./eval_log.txt"), "a") as file: file.write(results + "\n")

      if WANDB:
        wandb.log({"eval/lm_loss": total_lm_loss, "eval/clsf_loss": total_clsf_loss, "eval/lm_accuracy": total_lm_accuracy, \
                   "eval/clsf_accuracy": total_clsf_accuracy, "eval/forward_time": total_fw_time})

      # save model if achieved target
      if not achieved and total_lm_accuracy >= target:
        wc_end = time.perf_counter()
        if not os.path.exists(ckpt_dir := getenv('CKPT_DIR', "./ckpts")): os.mkdir(ckpt_dir)
        fn = f"{ckpt_dir}/bert-large.safe"
        safe_save(get_state_dict(model), fn)
        print(f" *** Model saved to {fn} ***")

        total_seconds = wc_end - wc_start
        hours = int(total_seconds // 3600)
        minutes = int((total_seconds % 3600) // 60)
        seconds = total_seconds % 60
        print(f"Reference Convergence point reached after {i * BS} datasamples and {hours}h{minutes}m{seconds:.2f}s.")
        achieved = True

    if getenv("CKPT") and i % save_ckpt_freq == 0:
      if not os.path.exists(ckpt_dir := getenv('CKPT_DIR', "./ckpts")): os.mkdir(ckpt_dir)
      if WANDB and wandb.run is not None:
        fn = f"{ckpt_dir}/{time.strftime('%Y%m%d_%H%M%S')}_{wandb.run.id}.safe"
      else:
        fn = f"{ckpt_dir}/{time.strftime('%Y%m%d_%H%M%S')}.safe"
      print(f"saving ckpt to {fn}")
      safe_save(get_training_state(model, optimizer_group, scheduler), fn)
      ckpt_files = [f for f in os.listdir(ckpt_dir) if os.path.isfile(os.path.join(ckpt_dir, f))]
      ckpt_files.sort(key=lambda x: os.path.getmtime(os.path.join(ckpt_dir, x)))
      while len(ckpt_files) > keep_ckpt_amount:
        last = ckpt_files.pop(0)
        print(f"Removing old ckpt {last}")
        os.remove(os.path.join(ckpt_dir, last))

def train_maskrcnn():
  # TODO: Mask RCNN

@@ -353,7 +353,7 @@ def process_part(part:int):
  tokenizer = Tokenizer(getenv("BASEDIR", Path(__file__).parent / "wiki") / "vocab.txt")
  os.makedirs(BASEDIR / "train" / str(part), exist_ok=True)
  for i, feature_batch in enumerate(process_iterate(tokenizer, val=False, part=part)):
    with open(BASEDIR / f"train/{str(part)}/{part}_{i}.pkl", "wb") as f:
    with open(BASEDIR / f"train/{str(part)}/{part}_{i}_of_{len(feature_batch)}.pkl", "wb") as f:
      pickle.dump(feature_batch, f)

def process_iterate(tokenizer:Tokenizer, val:bool=False, part:int=0) -> list[dict]: # Convert raw text to masked NSP samples

@@ -407,5 +407,5 @@ if __name__ == "__main__":
    part = int(sys.argv[2])
    os.makedirs(BASEDIR / "train" / str(part), exist_ok=True)
    for i, feature_batch in tqdm(enumerate(process_iterate(tokenizer, val=False, part=part))):
      with open(BASEDIR / f"train/{str(part)}/{part}_{i}.pkl", "wb") as f:
      with open(BASEDIR / f"train/{str(part)}/{part}_{i}_of_{len(feature_batch)}.pkl", "wb") as f:
        pickle.dump(feature_batch, f)

@@ -1,8 +1,10 @@
from tinygrad.tensor import Tensor
from tinygrad import nn
from tinygrad import nn, dtypes
from tinygrad.helpers import fetch, get_child
from pathlib import Path

from examples.mlperf.initializers import LinearBert, LayerNormBert

# allow for monkeypatching
Embedding = nn.Embedding
Linear = nn.Linear
@@ -37,6 +39,33 @@ class BertForQuestionAnswering:

    return Tensor.stack([start_logits, end_logits])

class BertForMLPerf:
  def __init__(self, hidden_size:int, intermediate_size:int, max_position_embeddings:int, num_attention_heads:int, num_hidden_layers:int, type_vocab_size:int, vocab_size:int, attention_probs_dropout_prob:float, hidden_dropout_prob:float) -> None:
    self.model = Bert(hidden_size, intermediate_size, max_position_embeddings, num_attention_heads, num_hidden_layers, type_vocab_size, vocab_size, attention_probs_dropout_prob, hidden_dropout_prob)
    # for clsf:
    self.clsf_pooler = LinearBert(hidden_size, hidden_size) # [bs, seq, hidden] -> [bs, hidden]
    self.clsf_pooling_activation = Tensor.tanh
    self.clsf_output = LinearBert(hidden_size, 2) # [bs, hidden] -> [bs, 2]

    # for lm:
    self.lm_transform = LinearBert(hidden_size, hidden_size)
    self.lm_transform_activation = gelu
    self.lm_norm = LayerNormBert(hidden_size, eps=1e-12)
    self.lm_output = LinearBert(hidden_size, vocab_size, bias=False) # [bs, seq, hidden] -> [bs, seq, vocab]
    self.lm_output.weight = self.model.embeddings.word_embeddings.weight
    self.lm_output_bias = Tensor.zeros(vocab_size, dtype=dtypes.float32)

  def __call__(self, input_ids:Tensor, segment_ids:Tensor, attention_mask:Tensor, masked_positions:Tensor):
    output = self.model(input_ids, attention_mask, segment_ids)
    clsf_logits = self.clsf_output(self.clsf_pooling_activation(self.clsf_pooler(output[:, 0]))).cast(dtypes.float32)

    masked_positions = masked_positions[:, :, None].expand(-1, -1, output.shape[-1])
    h_masked = Tensor.gather(output, masked_positions, 1)
    h_masked = self.lm_norm(self.lm_transform_activation(self.lm_transform(h_masked)))
    lm_logits = self.lm_output(h_masked) + self.lm_output_bias

    return lm_logits, clsf_logits
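  # Editor's note on shapes (following the comments above): lm_logits is [bs, num_masked_positions, vocab_size]
  # after gathering the masked positions, and clsf_logits is [bs, 2] for next-sentence prediction.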

class Bert:
  def __init__(self, hidden_size, intermediate_size, max_position_embeddings, num_attention_heads, num_hidden_layers, type_vocab_size, vocab_size, attention_probs_dropout_prob, hidden_dropout_prob):
    self.embeddings = BertEmbeddings(hidden_size, max_position_embeddings, type_vocab_size, vocab_size, hidden_dropout_prob)

@@ -106,6 +135,9 @@ class BertOutput:
    hidden_states = self.LayerNorm(hidden_states + input_tensor)
    return hidden_states

def gelu(x):
  return x * 0.5 * (1.0 + erf(x / 1.41421))

# approximation of the error function
def erf(x):
  t = (1 + 0.3275911 * x.abs()).reciprocal()

@@ -118,7 +150,7 @@ class BertIntermediate:
  def __call__(self, hidden_states):
    x = self.dense(hidden_states)
    # tinygrad gelu is openai gelu but we need the original bert gelu
    return x * 0.5 * (1.0 + erf(x / 1.41421))
    return gelu(x)

class BertAttention:
  def __init__(self, hidden_size, num_attention_heads, attention_probs_dropout_prob, hidden_dropout_prob):

@@ -64,7 +64,7 @@ if __name__ == "__main__":
  assert os.path.isfile(args.tf_records), f"The specified TFRecords file {args.tf_records} does not exist."

  preprocessed_samples = []
  for file_name in sorted(os.listdir(args.preprocessed_part_dir), key=lambda x: int(x.split("_")[1].split(".")[0]) if not args.is_eval else int(x.split(".")[0])): # 0_3.pkl -> 3 # noqa: E501
  for file_name in sorted(os.listdir(args.preprocessed_part_dir), key=lambda x: int(x.split("_")[1]) if not args.is_eval else int(x.split(".")[0])): # 0_3.pkl -> 3 # noqa: E501
    with open(os.path.join(args.preprocessed_part_dir, file_name), 'rb') as f:
      samples = pickle.load(f)
      preprocessed_samples.extend(samples)