MLPerf BERT: Main training loop (#4288)

* BERT language modeling head + trunc normal initializers

* add train loop + helpers

* shuffle in dataloaders + slight changes in main loop

* beam change

* Minor changes

* random.shuffle

* HParam update

* Use deque for dataloader

* wandb bert project name

* half fixes

* BENCHMARK + remove epoch

* cast + print()

---------

Co-authored-by: chenyu <chenyu@fastmail.com>
Elias Wahl 2024-04-29 20:35:27 +02:00 committed by GitHub
parent 61c97d5305
commit 27613dd881
7 changed files with 476 additions and 33 deletions

View File

@@ -1,12 +1,13 @@
import os, random
from typing import List
import os, random, pickle, functools, itertools
from typing import List, Tuple
from pathlib import Path
import numpy as np
from PIL import Image
from tqdm import tqdm
import pickle
from tinygrad import dtypes, Tensor
from tinygrad.helpers import getenv, prod, Timing, Context
from multiprocessing import Queue, Process, shared_memory, connection, Lock, cpu_count
from collections import deque
from multiprocessing import Queue, Process, shared_memory, connection, Lock, cpu_count, Pool
class MyQueue:
def __init__(self, multiple_readers=True, multiple_writers=True):
@@ -140,6 +141,7 @@ def batch_load_resnet(batch_size=64, val=False, shuffle=True, seed=None):
shm.close()
shm.unlink()
@functools.lru_cache(maxsize=128)
def load_bert_file(fn:str) -> List[dict]:
with open(fn, "rb") as f: data = pickle.load(f)
return data
@@ -147,7 +149,7 @@ def load_bert_file(fn:str) -> List[dict]:
def process_batch_bert(data: List[dict]) -> dict[str, Tensor]:
return {
"input_ids": Tensor(np.concatenate([s["input_ids"] for s in data], axis=0), dtype=dtypes.float32),
"input_mask": Tensor(np.concatenate([s["input_mask"] for s in data], axis=0), dtype=dtypes.float32),
"input_mask": Tensor(np.concatenate([s["input_mask"] for s in data], axis=0), dtype=dtypes.default_float),
"segment_ids": Tensor(np.concatenate([s["segment_ids"] for s in data], axis=0), dtype=dtypes.float32),
"masked_lm_positions": Tensor(np.concatenate([s["masked_lm_positions"] for s in data], axis=0), dtype=dtypes.float32),
"masked_lm_ids": Tensor(np.concatenate([s["masked_lm_ids"] for s in data], axis=0), dtype=dtypes.float32),
@@ -155,22 +157,69 @@ def process_batch_bert(data: List[dict]) -> dict[str, Tensor]:
"next_sentence_labels": Tensor(np.concatenate([s["next_sentence_labels"] for s in data], axis=0), dtype=dtypes.float32),
}
# For train: Stop when we run through all data
# For val: Wrap around val dataset and never stop
# Reference: https://github.com/mlcommons/training/blob/1c8a098ae3e70962a4f7422c0b0bd35ae639e357/language_model/tensorflow/bert/run_pretraining.py, Line 420
def batch_load_bert(BS:int, val=False):
from extra.datasets.wikipedia import get_wiki_train_files, get_wiki_val_files
files = get_wiki_val_files() if val else get_wiki_train_files()
blob, end = [], False
while files: # As long as there is data, keep going
while len(blob) < BS and not end: # Fill blob until there is enough for next step
blob.extend(load_bert_file(files.pop(0)))
if not files:
if val: files = get_val_files()
else: end = True # End of train data - avoid pop on empty file list
if len(blob) >= BS: # if last train step does not have enough for a full batch
yield process_batch_bert(blob[:BS])
blob = blob[BS:]
def shuffle_parts(file_paths: List[str]) -> List[str]:
parts = {}
for f in file_paths:
part = Path(f).stem.split('_')[0]
if part not in parts: parts[part] = []
parts[part].append(f)
part_ids = list(parts.keys())
random.shuffle(part_ids)
shuffled_files = []
for p in part_ids:
parts[p].sort(key=lambda x: int(Path(x).stem.split('_')[1]))
shuffled_files.extend(parts[p])
return shuffled_files
def random_sample(data: List[str]):
index = random.randint(0, len(data) - 1)
selected_sample = data[index]
return selected_sample, index
def load_datasample(file_and_offset:Tuple[str, int]) -> List[dict]:
data = load_bert_file(file_and_offset[0])
return data[file_and_offset[1]]
# Reference: https://github.com/mlcommons/training/blob/1c8a098ae3e70962a4f7422c0b0bd35ae639e357/language_model/tensorflow/bert/run_pretraining.py, Line 394
def batch_load_train_bert(BS:int):
from extra.datasets.wikipedia import get_wiki_train_files
files = shuffle_parts(get_wiki_train_files())
dataset = []
for f in files:
lists = [(f, o) for o in range(int(Path(f).stem.split("_")[3].split(".")[0]))]
dataset.extend(lists)
active_set = deque(dataset[:1000])
remaining_set = deque(dataset[1000:])
while dataset:
blob = []
for _ in range(BS):
if active_set:
index = random.randint(0, len(active_set) - 1)
sample = active_set[index]
active_set.remove(sample)
blob.append(sample)
if remaining_set:
active_set.append(remaining_set.popleft())
yield process_batch_bert([load_datasample(sample) for sample in blob])
# Reference: https://github.com/mlcommons/training/blob/1c8a098ae3e70962a4f7422c0b0bd35ae639e357/language_model/tensorflow/bert/run_pretraining.py, Line 416
def batch_load_val_bert(BS:int):
from extra.datasets.wikipedia import get_wiki_val_files
files = get_wiki_val_files()
dataset = list(itertools.chain.from_iterable([load_bert_file(f) for f in files]))
idx = 0
while True:
start_idx = (idx * BS) % len(dataset)
end_idx = ((idx + 1) * BS) % len(dataset)
if start_idx < end_idx:
yield process_batch_bert(dataset[start_idx:end_idx])
else: # wrap around the end to the beginning of the dataset
yield process_batch_bert(dataset[start_idx:] + dataset[:end_idx])
idx += 1
if __name__ == "__main__":
from extra.datasets.imagenet import get_train_files, get_val_files
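Note (not part of the commit): batch_load_train_bert above implements the MLPerf reference's shuffle-buffer sampling: keep a ~1000-sample active window, draw uniformly from it, and top it up from the remaining, part-shuffled stream. Below is a minimal self-contained sketch of that scheme on plain integers; unlike the generator above, which relies on the training loop to stop pulling from it, this version terminates once the buffer drains.

import random
from collections import deque

def shuffle_buffer_batches(samples, batch_size, buffer_size=1000):
  # active: small window we sample from; remaining: rest of the stream, consumed in order
  active, remaining = deque(samples[:buffer_size]), deque(samples[buffer_size:])
  while active:
    batch = []
    for _ in range(batch_size):
      if not active: break
      i = random.randint(0, len(active) - 1)
      batch.append(active[i])
      del active[i]                          # same O(n) cost class as the .remove() used above
      if remaining: active.append(remaining.popleft())
    yield batch

# e.g. list(shuffle_buffer_batches(list(range(10)), 3)) -> four batches covering all ten samples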

View File

@@ -1,7 +1,9 @@
from collections import OrderedDict
import unicodedata
import os, unicodedata, json, functools
import numpy as np
from tinygrad.nn import state
from tinygrad.tensor import Tensor
from tinygrad.dtype import dtypes
#
# checkpointing utils
@@ -190,3 +192,88 @@ def get_bert_qa_prediction(features, example, start_end_logits):
orig_text = " ".join(orig_tokens)
return _get_final_text(tok_text, orig_text)
return "empty"
def get_mlperf_bert_model(config_path:str):
from extra.models import bert
from examples.mlperf.initializers import LinearBert, EmbeddingBert, LayerNormBert
bert.Linear = LinearBert
bert.Embedding = EmbeddingBert
bert.LayerNorm = LayerNormBert
from extra.models.bert import BertForMLPerf
with open(config_path, "r") as f:
config = json.load(f)
return BertForMLPerf(
config["hidden_size"],
config["intermediate_size"],
config["max_position_embeddings"],
config["num_attention_heads"],
config["num_hidden_layers"],
config["type_vocab_size"],
config["vocab_size"],
config["attention_probs_dropout_prob"],
config["hidden_dropout_prob"]
)
@functools.lru_cache(maxsize=None)
def load_tf_weights_to_dict(checkpoint_path):
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import tensorflow as tf
reader = tf.train.load_checkpoint(checkpoint_path)
var_to_shape_map = reader.get_variable_to_shape_map()
weights_dict = {}
for key in sorted(var_to_shape_map):
weights_dict[key] = reader.get_tensor(key)
return weights_dict
def tt(tf_tensor): return Tensor(tf_tensor, dtype=dtypes.float32)
def load_from_tf2_ckpt(key: str, ckpt_dir: str):
p = "model/layer-3/"
s = "/.ATTRIBUTES/VARIABLE_VALUE"
tf_dict = load_tf_weights_to_dict(ckpt_dir)
if key.startswith("model.embeddings"):
if key.endswith("word_embeddings.weight"): return tt(tf_dict[p+"layer-1/embeddings"+s])
elif key.endswith("position_embeddings.weight"): return tt(tf_dict[p+"layer-3/embeddings"+s])
elif key.endswith("token_type_embeddings.weight"): return tt(tf_dict[p+"layer-4/embeddings"+s])
elif key.endswith("LayerNorm.weight"): return tt(tf_dict[p+"layer-6/gamma"+s])
elif key.endswith("LayerNorm.bias"): return tt(tf_dict[p+"layer-6/beta"+s])
else: raise ValueError(f"Unknown key: {key}")
elif key.startswith("model.encoder.layer"):
l_id = str(int(key.split(".")[3]) + 10)
if ".attention." in key:
if key.endswith("self.query.weight"): return tt(tf_dict[p+f"layer-{l_id}/_attention_layer/_query_dense/kernel"+s])
elif key.endswith("self.query.bias"): return tt(tf_dict[p+f"layer-{l_id}/_attention_layer/_query_dense/bias"+s])
elif key.endswith("self.key.weight"): return tt(tf_dict[p+f"layer-{l_id}/_attention_layer/_key_dense/kernel"+s])
elif key.endswith("self.key.bias"): return tt(tf_dict[p+f"layer-{l_id}/_attention_layer/_key_dense/bias"+s])
elif key.endswith("self.value.weight"): return tt(tf_dict[p+f"layer-{l_id}/_attention_layer/_value_dense/kernel"+s])
elif key.endswith("self.value.bias"): return tt(tf_dict[p+f"layer-{l_id}/_attention_layer/_value_dense/bias"+s])
# Attention output
elif key.endswith("output.dense.weight"): return tt(tf_dict[p+f"layer-{l_id}/_attention_output_dense/kernel"+s])
elif key.endswith("output.dense.bias"): return tt(tf_dict[p+f"layer-{l_id}/_attention_output_dense/bias"+s])
elif key.endswith("output.LayerNorm.weight"): return tt(tf_dict[p+f"layer-{l_id}/_attention_layer_norm/gamma"+s])
elif key.endswith("output.LayerNorm.bias"): return tt(tf_dict[p+f"layer-{l_id}/_attention_layer_norm/beta"+s])
else: raise ValueError(f"Unknown key: {key}")
elif ".intermediate." in key:
if key.endswith("dense.weight"): return tt(tf_dict[p+f"layer-{l_id}/_intermediate_dense/kernel"+s])
elif key.endswith("dense.bias"): return tt(tf_dict[p+f"layer-{l_id}/_intermediate_dense/bias"+s])
else: raise ValueError(f"Unknown key: {key}")
elif ".output." in key:
if key.endswith("dense.weight"): return tt(tf_dict[p+f"layer-{l_id}/_output_dense/kernel"+s])
elif key.endswith("dense.bias"): return tt(tf_dict[p+f"layer-{l_id}/_output_dense/bias"+s])
elif key.endswith("LayerNorm.weight"): return tt(tf_dict[p+f"layer-{l_id}/_output_layer_norm/gamma"+s])
elif key.endswith("LayerNorm.bias"): return tt(tf_dict[p+f"layer-{l_id}/_output_layer_norm/beta"+s])
else: raise ValueError(f"Unknown key: {key}")
elif key.startswith("clsf_pooler.weight"): return tt(tf_dict[f"model/layer-3/layer-35/kernel"+s])
elif key.startswith("clsf_pooler.bias"): return tt(tf_dict[f"model/layer-3/layer-35/bias"+s])
elif key.startswith("clsf_output.weight"): return tt(tf_dict[f"model/layer-6/layer-1/kernel"+s])
elif key.startswith("clsf_output.bias"): return tt(tf_dict[f"model/layer-6/layer-1/bias"+s])
elif key.startswith("lm_transform.weight"): return tt(tf_dict[f"model/layer-5/layer-3/kernel"+s])
elif key.startswith("lm_transform.bias"): return tt(tf_dict[f"model/layer-5/layer-3/bias"+s])
elif key.startswith("lm_norm.weight"): return tt(tf_dict[f"model/layer-5/layer-4/gamma"+s])
elif key.startswith("lm_norm.bias"): return tt(tf_dict[f"model/layer-5/layer-4/beta"+s])
elif key.startswith("lm_output_bias"): return tt(tf_dict[f"model/layer-5/layer-6/bias"+s])
else: raise ValueError(f"Unknown key: {key}")
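Note (not part of the commit): a usage sketch of the two helpers above, mirroring what train_bert does later in this diff. The config and checkpoint paths are placeholders, and the transposes/reshapes needed for the dense kernels are left to the training code.

from examples.mlperf.helpers import get_mlperf_bert_model, load_from_tf2_ckpt
from tinygrad.nn.state import get_state_dict

model = get_mlperf_bert_model("wiki/bert_config.json")           # placeholder path
for key, param in get_state_dict(model).items():
  if key.endswith("lm_output.weight"): continue                  # tied to the word embedding, never loaded from the ckpt
  w = load_from_tf2_ckpt(key=key, ckpt_dir="wiki/tf2_ckpt")      # placeholder path
  print(key, w.shape, param.shape)                               # shapes may differ until train_bert transposes/reshapes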

View File

@@ -1,4 +1,5 @@
import math
from typing import Union, Tuple
from tinygrad import Tensor, nn, dtypes
from tinygrad.helpers import prod, argfix
@@ -33,3 +34,35 @@ class Linear(nn.Linear):
if bias: self.bias = Tensor.zeros(out_features, dtype=dtypes.float32)
def __call__(self, x:Tensor):
return x.linear(self.weight.cast(dtypes.default_float).transpose(), self.bias.cast(dtypes.default_float) if self.bias is not None else None)
class LinearBert(nn.Linear):
def __init__(self, in_features, out_features, bias=True, std=0.02):
self.weight = std * rand_truncn(out_features, in_features, dtype=dtypes.float32)
self.bias = Tensor.zeros(out_features, dtype=dtypes.float32) if bias else None
def __call__(self, x:Tensor):
return x.linear(self.weight.cast(dtypes.default_float).transpose(), self.bias.cast(dtypes.default_float) if self.bias is not None else None)
class EmbeddingBert(nn.Embedding):
def __init__(self, vocab_size:int, embed_size:int, std=0.02):
self.vocab_sz, self.embed_sz = vocab_size, embed_size
self.weight = std * rand_truncn(vocab_size, embed_size, dtype=dtypes.float32)
def __call__(self, idx:Tensor) -> Tensor:
if idx.numel() == 0: return Tensor.empty(idx.shape+(self.embed_sz,), dtype=self.weight.dtype, device=self.weight.device)
arange_shp, weight_shp, big_shp = (1, 1, self.vocab_sz, 1), (1, 1, self.vocab_sz, self.embed_sz), idx.shape+(self.vocab_sz, self.embed_sz,)
if not hasattr(self, 'arange'): self.arange = Tensor.arange(self.vocab_sz, requires_grad=False, device=self.weight.device).reshape(arange_shp)
arange, idx, vals = self.arange.expand(big_shp), idx.reshape(idx.shape+(1, 1,)).expand(big_shp), self.weight.cast(dtypes.default_float).reshape(weight_shp).expand(big_shp)
return (arange == idx).mul(vals).sum(2, acc_dtype=vals.dtype)
class LayerNormBert:
def __init__(self, normalized_shape:Union[int, Tuple[int, ...]], eps:float=1e-12, elementwise_affine:bool=True):
self.normalized_shape = (normalized_shape,) if isinstance(normalized_shape, int) else tuple(normalized_shape)
self.axis, self.eps, self.elementwise_affine = tuple(-1-i for i in range(len(self.normalized_shape))), eps, elementwise_affine
self.weight, self.bias = (Tensor.ones(*self.normalized_shape, dtype=dtypes.float32), Tensor.zeros(*self.normalized_shape, dtype=dtypes.float32)) if elementwise_affine else (None, None)
def __call__(self, x:Tensor):
assert self.normalized_shape == x.shape[-len(self.normalized_shape):], f"last dimensions of {x.shape} must match {self.normalized_shape}"
xn = x.cast(dtypes.float32).layernorm(eps=self.eps, axis=self.axis).cast(x.dtype)
if not self.elementwise_affine: return xn
return (xn * self.weight.cast(dtypes.default_float) + self.bias.cast(dtypes.default_float))
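Note (not part of the commit): the initializers above call rand_truncn, which is defined earlier in examples/mlperf/initializers.py and not shown in this hunk. The sketch below has the same call shape and assumes it draws a standard normal truncated to two standard deviations, as the BERT reference initializer does; the real helper may use a different sampling method.

import numpy as np
from tinygrad import Tensor, dtypes

def rand_truncn_sketch(*shape, dtype=dtypes.float32, bound=2.0) -> Tensor:
  x = np.random.standard_normal(size=shape)
  bad = np.abs(x) > bound
  while bad.any():                           # resample out-of-range draws instead of clipping them
    x[bad] = np.random.standard_normal(size=int(bad.sum()))
    bad = np.abs(x) > bound
  return Tensor(x.astype(np.float32), dtype=dtype)

# e.g. weight = 0.02 * rand_truncn_sketch(1024, 1024)  # matches the std=0.02 scaling used above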

View File

@@ -1,13 +1,12 @@
import functools
import os
import time
import os, time, math, functools
from pathlib import Path
from tqdm import tqdm
import multiprocessing
from tinygrad import Device, GlobalCounters, Tensor, TinyJit, dtypes
from tinygrad.helpers import getenv, BEAM, WINO
from tinygrad.nn.state import get_parameters, get_state_dict, safe_load, safe_save
from tinygrad.nn.optim import LARS, SGD, OptimizerGroup
from tinygrad.nn.optim import LAMB, LARS, SGD, OptimizerGroup
from extra.lr_scheduler import LRSchedulerGroup
from examples.mlperf.helpers import get_training_state, load_training_state
@@ -259,8 +258,251 @@ def train_rnnt():
pass
def train_bert():
# TODO: BERT
pass
# NOTE: pip install tensorflow, wandb required
from examples.mlperf.dataloader import batch_load_train_bert, batch_load_val_bert
from examples.mlperf.helpers import get_mlperf_bert_model, load_from_tf2_ckpt
from examples.mlperf.lr_schedulers import PolynomialDecayWithWarmup
config = {}
BASEDIR = getenv("BASEDIR", Path(__file__).parent.parents[1] / "extra" / "datasets" / "wiki")
GPUS = config["GPUS"] = [f"{Device.DEFAULT}:{i}" for i in range(getenv("GPUS", 1))]
print(f"Training on {GPUS}")
for x in GPUS: Device[x]
seed = config["seed"] = getenv("SEED", 12345)
# ** hyperparameters **
BS = config["GLOBAL_BATCH_SIZE"] = getenv("BS", 4 * len(GPUS)) # FP32 4090: 6 GPUS -> BS24
EVAL_BS = config["EVAL_BS"] = getenv("EVAL_BS", 4 * len(GPUS))
max_lr = config["OPT_BASE_LEARNING_RATE"] = getenv("OPT_BASE_LEARNING_RATE", 0.000004166 * BS)
train_steps = config["TRAIN_STEPS"] = getenv("TRAIN_STEPS", 4800000 // BS)
warmup_steps = config["NUM_WARMUP_STEPS"] = getenv("NUM_WARMUP_STEPS", train_steps // 10)
max_eval_steps = config["MAX_EVAL_STEPS"] = getenv("MAX_EVAL_STEPS", (10000 + EVAL_BS - 1) // EVAL_BS) # EVAL_BS * MAX_EVAL_STEPS >= 10000
eval_step_freq = config["EVAL_STEP_FREQ"] = int((math.floor(0.05 * (230.23 * BS + 3000000) / 25000) * 25000) / BS) # Round down
save_ckpt_freq = config["SAVE_CKPT_FREQ"] = getenv("SAVE_CKPT_FREQ", 1000)
keep_ckpt_amount = config["KEEP_CKPT_AMOUNT"] = getenv("KEEP_CKPT_AMOUNT", 5)
init_ckpt = config["INIT_CKPT_DIR"] = getenv("INIT_CKPT_DIR", BASEDIR)
decay = config["decay"] = getenv("DECAY", 0.01)
poly_power = config["poly_power"] = getenv("POLY_POWER", 1.0)
target, achieved = getenv("TARGET", 0.72), False
config["DEFAULT_FLOAT"] = dtypes.default_float.name
config["TRAIN_BEAM"] = TRAIN_BEAM = getenv("TRAIN_BEAM", BEAM.value)
config["EVAL_BEAM"] = EVAL_BEAM = getenv("EVAL_BEAM", BEAM.value)
Tensor.manual_seed(seed) # seed for weight initialization
model = get_mlperf_bert_model(BASEDIR / "bert_config.json")
# shard weights and initialize in order
for tinygrad_key, x in get_state_dict(model).items():
if init_ckpt and not tinygrad_key.endswith("lm_output.weight"): # lm_output.weight already is word embedding
t = load_from_tf2_ckpt(key=tinygrad_key, ckpt_dir=init_ckpt)
if any(k in tinygrad_key for k in ["intermediate.dense.weight", "output.dense.weight", "clsf_output.weight"]) and "attention" not in tinygrad_key:
t = t.transpose()
elif any(k in tinygrad_key for k in ["self", "output.dense", "clsf_pooler", "lm_transform"]) and "weight" in tinygrad_key:
t = t.reshape(*x.shape).transpose()
elif all(k in tinygrad_key for k in ["self", "bias"]):
t = t.reshape(*x.shape)
x.assign(t).realize().to_(GPUS)
x.realize().to_(GPUS)
parameters = get_parameters(model)
assert 10000 <= (EVAL_BS * max_eval_steps), "Evaluation batch size * max_eval_steps must be greater than or equal to 10000 to iterate over the full eval dataset"
# ** Log hparams **
for key, value in config.items():
print(f'HParam: "{key}": {value}')
# ** Optimizer **
skip_list = [v for k, v in get_state_dict(model).items() if "bias" in k or "LayerNorm" in k]
parameters = [x for x in parameters if x not in set(skip_list)]
optimizer = LAMB(parameters, 1 / warmup_steps, eps=1e-6, wd=decay, adam=False)
optimizer_skip = LAMB(skip_list, 1 / warmup_steps, eps=1e-6, wd=0.0, adam=False)
optimizer_group = OptimizerGroup(optimizer, optimizer_skip)
# ** LR scheduler **
scheduler = PolynomialDecayWithWarmup(optimizer, max_lr, 0, train_steps, warmup_steps, power=poly_power)
print(f"Training with batch size {BS} for one epoch with {train_steps} steps")
# ** resume from checkpointing **
start_step = 0
if ckpt:=getenv("RESUME", ""):
load_training_state(model, optimizer_group, scheduler, safe_load(ckpt))
start_step = scheduler.epoch_counter.numpy().item()
print(f"resuming from {ckpt} at step {start_step}")
# ** init wandb **
WANDB = getenv("WANDB")
if WANDB:
import wandb
wandb_args = {"id": wandb_id, "resume": "must"} if (wandb_id := getenv("WANDB_RESUME", "")) else {}
wandb.init(config=config, **wandb_args, project="MLPerf-BERT")
BENCHMARK = getenv("BENCHMARK")
@TinyJit
def train_step(input_ids:Tensor, segment_ids:Tensor, attention_mask:Tensor, masked_positions:Tensor, masked_lm_ids:Tensor, masked_lm_weights:Tensor, next_sentence_labels:Tensor):
lm_logits, clsf_logits = model(input_ids, segment_ids, attention_mask, masked_positions)
lm_loss = lm_logits.sparse_categorical_crossentropy(masked_lm_ids, ignore_index=masked_lm_weights)
clsf_loss = clsf_logits.binary_crossentropy_logits(next_sentence_labels)
loss = lm_loss + clsf_loss
if not getenv('DISABLE_BACKWARD', 0):
optimizer_group.zero_grad()
loss.backward()
optimizer_group.step()
scheduler.step()
return loss.realize()
@TinyJit
def eval_step(input_ids:Tensor, segment_ids:Tensor, attention_mask:Tensor, masked_positions:Tensor, masked_lm_ids:Tensor, masked_lm_weights:Tensor, next_sentence_labels:Tensor):
lm_logits, clsf_logits = model(input_ids, segment_ids, attention_mask, masked_positions)
clsf_predictions = clsf_logits.log_softmax().argmax(-1)
clsf_accuracy = (clsf_predictions == next_sentence_labels).float().mean()
mlm_predictions = lm_logits.log_softmax().argmax(-1)
mask = (masked_lm_weights == 1.0)
mlm_accuracy = (mlm_predictions == masked_lm_ids).where(mask, 0).sum() / mask.float().sum()
lm_loss = lm_logits.sparse_categorical_crossentropy(masked_lm_ids, ignore_index=masked_lm_weights)
clsf_loss = clsf_logits.binary_crossentropy_logits(next_sentence_labels)
return {
"masked_lm_accuracy": mlm_accuracy.realize(),
"masked_lm_loss": lm_loss.realize(),
"next_sentence_accuracy": clsf_accuracy.realize(),
"next_sentence_loss": clsf_loss.realize()
}
def data_get(it):
data: dict[str, Tensor] = next(it)
for key in data.keys(): data[key].shard_(GPUS, axis=0)
return data
eval_it = iter(batch_load_val_bert(EVAL_BS))
train_it = iter(tqdm(batch_load_train_bert(BS), total=train_steps, disable=BENCHMARK))
step_times = []
# ** train loop **
wc_start = time.perf_counter()
Tensor.training = True
BEAM.value = TRAIN_BEAM
i, train_data = 0, data_get(train_it)
while train_data is not None and i < train_steps and not achieved:
st = time.perf_counter()
GlobalCounters.reset()
loss = train_step(train_data["input_ids"], train_data["segment_ids"], train_data["input_mask"], train_data["masked_lm_positions"], \
train_data["masked_lm_ids"], train_data["masked_lm_weights"], train_data["next_sentence_labels"])
pt = time.perf_counter()
try:
next_data = data_get(train_it)
except StopIteration:
next_data = None
dt = time.perf_counter()
device_str = loss.device if isinstance(loss.device, str) else f"{loss.device[0]} * {len(loss.device)}"
loss = loss.numpy().item()
cl = time.perf_counter()
if BENCHMARK: step_times.append(cl - st)
tqdm.write(
f"{i:5} {((cl - st)) * 1000.0:7.2f} ms run, {(pt - st) * 1000.0:7.2f} ms python, {(dt - pt) * 1000.0:6.2f} ms fetch data, "
f"{(cl - dt) * 1000.0:7.2f} ms {device_str}, {loss:5.2f} loss, {optimizer.lr.numpy()[0]:.6f} LR, "
f"{GlobalCounters.mem_used / 1e9:.2f} GB used, {GlobalCounters.global_ops * 1e-9 / (cl - st):9.2f} GFLOPS")
if WANDB:
wandb.log({"lr": optimizer.lr.numpy(), "train/loss": loss, "train/step_time": cl - st,
"train/python_time": pt - st, "train/data_time": dt - pt, "train/cl_time": cl - dt,
"train/GFLOPS": GlobalCounters.global_ops * 1e-9 / (cl - st)})
train_data, next_data = next_data, None
i += 1
if i == BENCHMARK:
median_step_time = sorted(step_times)[(BENCHMARK + 1) // 2] # in seconds
estimated_total_minutes = int(median_step_time * train_steps / 60)
print(f"Estimated training time: {estimated_total_minutes // 60}h{estimated_total_minutes % 60}m")
print(f"epoch global_ops: {train_steps * GlobalCounters.global_ops:_}, "
f"epoch global_mem: {train_steps * GlobalCounters.global_mem:_}")
return
# ** eval loop **
if i % eval_step_freq == 0 or i == 1:
train_step.reset() # free the train step memory :(
eval_loss = []
eval_accuracy = []
eval_times = []
Tensor.training = False
BEAM.value = EVAL_BEAM
for _ in tqdm(range(max_eval_steps), desc="Evaluating", total=max_eval_steps, disable=BENCHMARK):
eval_data = data_get(eval_it)
GlobalCounters.reset()
st = time.time()
eval_result: dict[str, Tensor] = eval_step(eval_data["input_ids"], eval_data["segment_ids"], eval_data["input_mask"], eval_data["masked_lm_positions"], \
eval_data["masked_lm_ids"], eval_data["masked_lm_weights"], eval_data["next_sentence_labels"])
lm_loss, clsf_loss = eval_result["masked_lm_loss"].numpy().item(), eval_result["next_sentence_loss"].numpy().item()
mlm_accuracy, clsf_accuracy = eval_result["masked_lm_accuracy"].numpy().item(), eval_result["next_sentence_accuracy"].numpy().item()
eval_loss.append([lm_loss, clsf_loss])
eval_accuracy.append([mlm_accuracy, clsf_accuracy])
et = time.time()
eval_times.append(et - st)
eval_step.reset()
total_lm_loss = sum(pair[0] for pair in eval_loss) / len(eval_loss)
total_clsf_loss = sum(pair[1] for pair in eval_loss) / len(eval_loss)
total_lm_accuracy = sum(pair[0] for pair in eval_accuracy) / len(eval_accuracy)
total_clsf_accuracy = sum(pair[1] for pair in eval_accuracy) / len(eval_accuracy)
total_fw_time = sum(eval_times) / len(eval_times)
results = f"eval lm loss: {total_lm_loss:.2f}, eval clsf loss: {total_clsf_loss:.2f}, eval lm accuracy: {total_lm_accuracy:.6f}, \
eval clsf accuracy: {total_clsf_accuracy:.2f}, avg eval step time: {total_fw_time:.2f}"
tqdm.write(results)
with open(getenv("EVAL_LOG", "./eval_log.txt"), "a") as file: file.write(results + "\n")
if WANDB:
wandb.log({"eval/lm_loss": total_lm_loss, "eval/clsf_loss": total_clsf_loss, "eval/lm_accuracy": total_lm_accuracy, \
"eval/clsf_accuracy": total_clsf_accuracy, "eval/forward_time": total_fw_time})
# save model if achieved target
if not achieved and total_lm_accuracy >= target:
wc_end = time.perf_counter()
if not os.path.exists(ckpt_dir := getenv('CKPT_DIR', "./ckpts")): os.mkdir(ckpt_dir)
fn = f"{ckpt_dir}/bert-large.safe"
safe_save(get_state_dict(model), fn)
print(f" *** Model saved to {fn} ***")
total_seconds = wc_end - wc_start
hours = int(total_seconds // 3600)
minutes = int((total_seconds % 3600) // 60)
seconds = total_seconds % 60
print(f"Reference Convergence point reached after {i * BS} datasamples and {hours}h{minutes}m{seconds:.2f}s.")
achieved = True
if getenv("CKPT") and i % save_ckpt_freq == 0:
if not os.path.exists(ckpt_dir := getenv('CKPT_DIR', "./ckpts")): os.mkdir(ckpt_dir)
if WANDB and wandb.run is not None:
fn = f"{ckpt_dir}/{time.strftime('%Y%m%d_%H%M%S')}_{wandb.run.id}.safe"
else:
fn = f"{ckpt_dir}/{time.strftime('%Y%m%d_%H%M%S')}.safe"
print(f"saving ckpt to {fn}")
safe_save(get_training_state(model, optimizer_group, scheduler), fn)
ckpt_files = [f for f in os.listdir(ckpt_dir) if os.path.isfile(os.path.join(ckpt_dir, f))]
ckpt_files.sort(key=lambda x: os.path.getmtime(os.path.join(ckpt_dir, x)))
while len(ckpt_files) > keep_ckpt_amount:
last = ckpt_files.pop(0)
print(f"Removing old ckpt {last}")
os.remove(os.path.join(ckpt_dir, last))
def train_maskrcnn():
# TODO: Mask RCNN
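Note (not part of the commit): PolynomialDecayWithWarmup comes from examples/mlperf/lr_schedulers.py and is not shown in this diff. Below is a scalar sketch of the curve train_bert requests from it, under the assumption of the standard MLPerf/TF behaviour: linear warmup to max_lr, then polynomial decay to end_lr over the remaining steps. The real scheduler drives optimizer.lr as a Tensor; this is only the math for inspection.

def poly_lr(step:int, max_lr:float, end_lr:float, train_steps:int, warmup_steps:int, power:float=1.0) -> float:
  # warmup: ramp linearly from 0 up to max_lr
  if step < warmup_steps: return max_lr * step / warmup_steps
  # decay: polynomial interpolation from max_lr down to end_lr over the remaining steps
  frac = min(1.0, (step - warmup_steps) / max(1, train_steps - warmup_steps))
  return (max_lr - end_lr) * (1.0 - frac) ** power + end_lr

# With the defaults above and BS=24: max_lr = 0.000004166*24 ≈ 1e-4, train_steps = 4_800_000//24 = 200_000,
# warmup_steps = 20_000, so poly_lr(20_000, 1e-4, 0, 200_000, 20_000) returns the 1e-4 peak.

For the same BS=24, the eval_step_freq formula above works out to int(math.floor(0.05 * (230.23 * 24 + 3_000_000) / 25000) * 25000 / 24) = 6250, i.e. an eval pass roughly every 6250 training steps.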

View File

@@ -353,7 +353,7 @@ def process_part(part:int):
tokenizer = Tokenizer(getenv("BASEDIR", Path(__file__).parent / "wiki") / "vocab.txt")
os.makedirs(BASEDIR / "train" / str(part), exist_ok=True)
for i, feature_batch in enumerate(process_iterate(tokenizer, val=False, part=part)):
with open(BASEDIR / f"train/{str(part)}/{part}_{i}.pkl", "wb") as f:
with open(BASEDIR / f"train/{str(part)}/{part}_{i}_of_{len(feature_batch)}.pkl", "wb") as f:
pickle.dump(feature_batch, f)
def process_iterate(tokenizer:Tokenizer, val:bool=False, part:int=0) -> list[dict]: # Convert raw text to masked NSP samples
@@ -407,5 +407,5 @@ if __name__ == "__main__":
part = int(sys.argv[2])
os.makedirs(BASEDIR / "train" / str(part), exist_ok=True)
for i, feature_batch in tqdm(enumerate(process_iterate(tokenizer, val=False, part=part))):
with open(BASEDIR / f"train/{str(part)}/{part}_{i}.pkl", "wb") as f:
with open(BASEDIR / f"train/{str(part)}/{part}_{i}_of_{len(feature_batch)}.pkl", "wb") as f:
pickle.dump(feature_batch, f)
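Note (not part of the commit): the file name now encodes the number of samples in each shard, which is what the new dataloader relies on. A small sketch of how a made-up file name in the new {part}_{i}_of_{count}.pkl format is parsed by shuffle_parts and batch_load_train_bert in examples/mlperf/dataloader.py:

from pathlib import Path

stem = Path("3_17_of_512.pkl").stem          # hypothetical shard: part 3, shard 17, 512 samples inside
part, shard, count = stem.split("_")[0], stem.split("_")[1], stem.split("_")[3]
assert (part, shard, count) == ("3", "17", "512")
# shuffle_parts groups files by `part` and orders them by int(shard); batch_load_train_bert expands
# each file into int(count) (file, offset) pairs without having to unpickle it first.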

View File

@@ -1,8 +1,10 @@
from tinygrad.tensor import Tensor
from tinygrad import nn
from tinygrad import nn, dtypes
from tinygrad.helpers import fetch, get_child
from pathlib import Path
from examples.mlperf.initializers import LinearBert, LayerNormBert
# allow for monkeypatching
Embedding = nn.Embedding
Linear = nn.Linear
@@ -37,6 +39,33 @@ class BertForQuestionAnswering:
return Tensor.stack([start_logits, end_logits])
class BertForMLPerf:
def __init__(self, hidden_size:int, intermediate_size:int, max_position_embeddings:int, num_attention_heads:int, num_hidden_layers:int, type_vocab_size:int, vocab_size:int, attention_probs_dropout_prob:float, hidden_dropout_prob:float) -> None:
self.model = Bert(hidden_size, intermediate_size, max_position_embeddings, num_attention_heads, num_hidden_layers, type_vocab_size, vocab_size, attention_probs_dropout_prob, hidden_dropout_prob)
# for clsf:
self.clsf_pooler = LinearBert(hidden_size, hidden_size) # [bs, seq, hidden] -> [bs, hidden]
self.clsf_pooling_activation = Tensor.tanh
self.clsf_output = LinearBert(hidden_size, 2) # [bs, hidden] -> [bs, 2]
# for lm:
self.lm_transform = LinearBert(hidden_size, hidden_size)
self.lm_transform_activation = gelu
self.lm_norm = LayerNormBert(hidden_size, eps=1e-12)
self.lm_output = LinearBert(hidden_size, vocab_size, bias=False) # [bs, seq, hidden] -> [bs, seq, vocab]
self.lm_output.weight = self.model.embeddings.word_embeddings.weight
self.lm_output_bias = Tensor.zeros(vocab_size, dtype=dtypes.float32)
def __call__(self, input_ids:Tensor, segment_ids:Tensor, attention_mask:Tensor, masked_positions:Tensor):
output = self.model(input_ids, attention_mask, segment_ids)
clsf_logits = self.clsf_output(self.clsf_pooling_activation(self.clsf_pooler(output[:, 0]))).cast(dtypes.float32)
masked_positions = masked_positions[:, :, None].expand(-1, -1, output.shape[-1])
h_masked = Tensor.gather(output, masked_positions, 1)
h_masked = self.lm_norm(self.lm_transform_activation(self.lm_transform(h_masked)))
lm_logits = self.lm_output(h_masked) + self.lm_output_bias
return lm_logits, clsf_logits
class Bert:
def __init__(self, hidden_size, intermediate_size, max_position_embeddings, num_attention_heads, num_hidden_layers, type_vocab_size, vocab_size, attention_probs_dropout_prob, hidden_dropout_prob):
self.embeddings = BertEmbeddings(hidden_size, max_position_embeddings, type_vocab_size, vocab_size, hidden_dropout_prob)
@@ -106,6 +135,9 @@ class BertOutput:
hidden_states = self.LayerNorm(hidden_states + input_tensor)
return hidden_states
def gelu(x):
return x * 0.5 * (1.0 + erf(x / 1.41421))
# approximation of the error function
def erf(x):
t = (1 + 0.3275911 * x.abs()).reciprocal()
@@ -118,7 +150,7 @@ class BertIntermediate:
def __call__(self, hidden_states):
x = self.dense(hidden_states)
# tinygrad gelu is openai gelu but we need the original bert gelu
return x * 0.5 * (1.0 + erf(x / 1.41421))
return gelu(x)
class BertAttention:
def __init__(self, hidden_size, num_attention_heads, attention_probs_dropout_prob, hidden_dropout_prob):
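Note (not part of the commit): the hunk above truncates the body of erf after its first line. The scalar sanity check below assumes the remaining lines implement the Abramowitz-Stegun 7.1.26 polynomial that the 0.3275911 constant belongs to (accurate to about 1.5e-7), and checks the gelu built on top of it against math.erf.

import math

def erf_approx(x: float) -> float:
  sign, t = (1.0 if x >= 0 else -1.0), 1.0 / (1.0 + 0.3275911 * abs(x))
  poly = (((((1.061405429 * t - 1.453152027) * t) + 1.421413741) * t - 0.284496736) * t + 0.254829592) * t
  return sign * (1.0 - poly * math.exp(-x * x))

def gelu_ref(x: float) -> float:
  return x * 0.5 * (1.0 + erf_approx(x / 1.41421))   # same form as the gelu() added above

for v in (-2.0, -0.5, 0.0, 0.5, 2.0):
  assert abs(erf_approx(v) - math.erf(v)) < 2e-7
  assert abs(gelu_ref(v) - 0.5 * v * (1.0 + math.erf(v / math.sqrt(2)))) < 1e-5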

View File

@@ -64,7 +64,7 @@ if __name__ == "__main__":
assert os.path.isfile(args.tf_records), f"The specified TFRecords file {args.tf_records} does not exist."
preprocessed_samples = []
for file_name in sorted(os.listdir(args.preprocessed_part_dir), key=lambda x: int(x.split("_")[1].split(".")[0]) if not args.is_eval else int(x.split(".")[0])): # 0_3.pkl -> 3 # noqa: E501
for file_name in sorted(os.listdir(args.preprocessed_part_dir), key=lambda x: int(x.split("_")[1]) if not args.is_eval else int(x.split(".")[0])): # 0_3.pkl -> 3 # noqa: E501
with open(os.path.join(args.preprocessed_part_dir, file_name), 'rb') as f:
samples = pickle.load(f)
preprocessed_samples.extend(samples)