Refactor to class style (#4804)

Elias Wahl 2024-06-04 23:08:31 +02:00 committed by GitHub
parent 1b8bed4a26
commit 04e237328b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 144 additions and 157 deletions

View File

@@ -195,12 +195,11 @@ def get_bert_qa_prediction(features, example, start_end_logits):
return "empty"
def get_mlperf_bert_config():
"""Config is BERT-large"""
return {
"attention_probs_dropout_prob": 0.1,
"hidden_act": "gelu",
"hidden_dropout_prob": 0.1,
"hidden_size": 1024,
"initializer_range": 0.02,
"intermediate_size": 4096,
"max_position_embeddings": 512,
"num_attention_heads": 16,
@@ -209,7 +208,7 @@ def get_mlperf_bert_config():
"vocab_size": 30522
}
def get_mlperf_bert_model():
def get_mlperf_bert_model(checkpoint_path:str):
from extra.models import bert
from examples.mlperf.initializers import LinearBert, EmbeddingBert, LayerNormBert
@@ -217,95 +216,10 @@ def get_mlperf_bert_model():
bert.Embedding = EmbeddingBert
bert.LayerNorm = LayerNormBert
from extra.models.bert import BertForMLPerf
config = get_mlperf_bert_config()
return BertForMLPerf(
config["hidden_size"],
config["intermediate_size"],
config["max_position_embeddings"],
config["num_attention_heads"],
config["num_hidden_layers"],
config["type_vocab_size"],
config["vocab_size"],
config["attention_probs_dropout_prob"],
config["hidden_dropout_prob"]
)
def init_bert_from_checkpoint(model, ckpt_dir:str):
for tinygrad_key, x in get_state_dict(model).items():
if not tinygrad_key.endswith("lm_output.weight"): # lm_output.weight already is word embedding
t = load_from_tf2_ckpt(key=tinygrad_key, ckpt_dir=ckpt_dir)
if any(k in tinygrad_key for k in ["intermediate.dense.weight", "output.dense.weight", "clsf_output.weight"]) and "attention" not in tinygrad_key:
t = t.transpose()
elif any(k in tinygrad_key for k in ["self", "output.dense", "clsf_pooler", "lm_transform"]) and "weight" in tinygrad_key:
t = t.reshape(*x.shape).transpose()
elif all(k in tinygrad_key for k in ["self", "bias"]):
t = t.reshape(*x.shape)
x.assign(t)
from extra.models.bert import BertForPretraining
return BertForPretraining(**get_mlperf_bert_config()).load_from_pretrained(checkpoint_path)
def get_data_bert(GPUS:list[str], it):
data: dict[str, Tensor] = next(it)
for key in data.keys(): data[key].shard_(GPUS, axis=0)
return data
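For context, a minimal sketch of how a batch dict is sharded across devices by this helper, assuming a GPUS list built from Device.DEFAULT the way the training scripts do; the device count and tensor shapes below are placeholders, not real wiki data:

```python
from tinygrad import Tensor, dtypes
from tinygrad.device import Device

GPUS = [f"{Device.DEFAULT}:{i}" for i in range(2)]  # assumes at least two devices are visible

batch = {
  "input_ids": Tensor.zeros(4, 512, dtype=dtypes.int32),    # placeholder shapes only
  "segment_ids": Tensor.zeros(4, 512, dtype=dtypes.int32),
}
for key in batch: batch[key].shard_(GPUS, axis=0)  # split along the batch dimension
print(batch["input_ids"].device)                   # a tuple of the two device names
```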
@functools.lru_cache(maxsize=None)
def load_tf_weights_to_dict(checkpoint_path):
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import tensorflow as tf
reader = tf.train.load_checkpoint(checkpoint_path)
var_to_shape_map = reader.get_variable_to_shape_map()
weights_dict = {}
for key in sorted(var_to_shape_map):
weights_dict[key] = reader.get_tensor(key)
return weights_dict
def tt(tf_tensor): return Tensor(tf_tensor, dtype=dtypes.float32)
def load_from_tf2_ckpt(key: str, ckpt_dir: str):
p = "model/layer-3/"
s = "/.ATTRIBUTES/VARIABLE_VALUE"
tf_dict = load_tf_weights_to_dict(ckpt_dir)
if key.startswith("model.embeddings"):
if key.endswith("word_embeddings.weight"): return tt(tf_dict[p+"layer-1/embeddings"+s])
elif key.endswith("position_embeddings.weight"): return tt(tf_dict[p+"layer-3/embeddings"+s])
elif key.endswith("token_type_embeddings.weight"): return tt(tf_dict[p+"layer-4/embeddings"+s])
elif key.endswith("LayerNorm.weight"): return tt(tf_dict[p+"layer-6/gamma"+s])
elif key.endswith("LayerNorm.bias"): return tt(tf_dict[p+"layer-6/beta"+s])
else: raise ValueError(f"Unknown key: {key}")
elif key.startswith("model.encoder.layer"):
l_id = str(int(key.split(".")[3]) + 10)
if ".attention." in key:
if key.endswith("self.query.weight"): return tt(tf_dict[p+f"layer-{l_id}/_attention_layer/_query_dense/kernel"+s])
elif key.endswith("self.query.bias"): return tt(tf_dict[p+f"layer-{l_id}/_attention_layer/_query_dense/bias"+s])
elif key.endswith("self.key.weight"): return tt(tf_dict[p+f"layer-{l_id}/_attention_layer/_key_dense/kernel"+s])
elif key.endswith("self.key.bias"): return tt(tf_dict[p+f"layer-{l_id}/_attention_layer/_key_dense/bias"+s])
elif key.endswith("self.value.weight"): return tt(tf_dict[p+f"layer-{l_id}/_attention_layer/_value_dense/kernel"+s])
elif key.endswith("self.value.bias"): return tt(tf_dict[p+f"layer-{l_id}/_attention_layer/_value_dense/bias"+s])
# Attention output
elif key.endswith("output.dense.weight"): return tt(tf_dict[p+f"layer-{l_id}/_attention_output_dense/kernel"+s])
elif key.endswith("output.dense.bias"): return tt(tf_dict[p+f"layer-{l_id}/_attention_output_dense/bias"+s])
elif key.endswith("output.LayerNorm.weight"): return tt(tf_dict[p+f"layer-{l_id}/_attention_layer_norm/gamma"+s])
elif key.endswith("output.LayerNorm.bias"): return tt(tf_dict[p+f"layer-{l_id}/_attention_layer_norm/beta"+s])
else: raise ValueError(f"Unknown key: {key}")
elif ".intermediate." in key:
if key.endswith("dense.weight"): return tt(tf_dict[p+f"layer-{l_id}/_intermediate_dense/kernel"+s])
elif key.endswith("dense.bias"): return tt(tf_dict[p+f"layer-{l_id}/_intermediate_dense/bias"+s])
else: raise ValueError(f"Unknown key: {key}")
elif ".output." in key:
if key.endswith("dense.weight"): return tt(tf_dict[p+f"layer-{l_id}/_output_dense/kernel"+s])
elif key.endswith("dense.bias"): return tt(tf_dict[p+f"layer-{l_id}/_output_dense/bias"+s])
elif key.endswith("LayerNorm.weight"): return tt(tf_dict[p+f"layer-{l_id}/_output_layer_norm/gamma"+s])
elif key.endswith("LayerNorm.bias"): return tt(tf_dict[p+f"layer-{l_id}/_output_layer_norm/beta"+s])
else: raise ValueError(f"Unknown key: {key}")
elif key.startswith("clsf_pooler.weight"): return tt(tf_dict[f"model/layer-3/layer-35/kernel"+s])
elif key.startswith("clsf_pooler.bias"): return tt(tf_dict[f"model/layer-3/layer-35/bias"+s])
elif key.startswith("clsf_output.weight"): return tt(tf_dict[f"model/layer-6/layer-1/kernel"+s])
elif key.startswith("clsf_output.bias"): return tt(tf_dict[f"model/layer-6/layer-1/bias"+s])
elif key.startswith("lm_transform.weight"): return tt(tf_dict[f"model/layer-5/layer-3/kernel"+s])
elif key.startswith("lm_transform.bias"): return tt(tf_dict[f"model/layer-5/layer-3/bias"+s])
elif key.startswith("lm_norm.weight"): return tt(tf_dict[f"model/layer-5/layer-4/gamma"+s])
elif key.startswith("lm_norm.bias"): return tt(tf_dict[f"model/layer-5/layer-4/beta"+s])
elif key.startswith("lm_output_bias"): return tt(tf_dict[f"model/layer-5/layer-6/bias"+s])
else: raise ValueError(f"Unknown key: {key}")
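Both the removed per-key loader above and the new load_from_pretrained path read raw variables straight out of the TF checkpoint. A minimal sketch of inspecting one with TensorFlow's checkpoint utilities; the path is a placeholder for the downloaded model.ckpt-28252 prefix:

```python
import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"   # silence TF startup logging, as the helpers do
import tensorflow as tf

ckpt = "/path/to/wiki/model.ckpt-28252"    # placeholder checkpoint prefix

names = tf.train.list_variables(ckpt)      # list of (name, shape) pairs
for name, shape in names[:5]: print(name, shape)

name, shape = names[0]
arr = tf.train.load_variable(ckpt, name)   # numpy array with the stored weights
assert list(arr.shape) == shape
```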

View File

@@ -355,50 +355,38 @@ def train_rnnt():
@TinyJit
def train_step_bert(model, optimizer, scheduler, loss_scaler:float, input_ids:Tensor, segment_ids:Tensor, attention_mask:Tensor, masked_positions:Tensor, masked_lm_ids:Tensor, masked_lm_weights:Tensor, next_sentence_labels:Tensor):
lm_logits, clsf_logits = model(input_ids, segment_ids, attention_mask, masked_positions)
lm_loss = lm_logits.sparse_categorical_crossentropy(masked_lm_ids, ignore_index=masked_lm_weights)
clsf_loss = clsf_logits.binary_crossentropy_logits(next_sentence_labels)
loss = lm_loss + clsf_loss
optimizer.zero_grad()
if not getenv('DISABLE_BACKWARD', 0):
optimizer.zero_grad()
(loss * loss_scaler).backward()
lm_logits, seq_relationship_logits = model(input_ids, attention_mask, masked_positions, segment_ids)
loss = model.loss(lm_logits, seq_relationship_logits, masked_lm_ids, masked_lm_weights, next_sentence_labels)
(loss * loss_scaler).backward()
global_norm = Tensor([0.0], dtype=dtypes.float32, device=optimizer[0].device).realize()
for p in optimizer.params:
p.grad = p.grad / loss_scaler
global_norm += p.grad.float().square().sum()
global_norm = global_norm.sqrt()
for p in optimizer.params: p.grad = (p.grad / Tensor.where(global_norm > 1.0, global_norm, 1.0)).cast(p.grad.dtype)
global_norm = Tensor([0.0], dtype=dtypes.float32, device=optimizer[0].device).realize()
for p in optimizer.params:
p.grad = p.grad / loss_scaler
global_norm += p.grad.float().square().sum()
global_norm = global_norm.sqrt()
for p in optimizer.params: p.grad = (p.grad / Tensor.where(global_norm > 1.0, global_norm, 1.0)).cast(p.grad.dtype)
optimizer.step()
scheduler.step()
optimizer.step()
scheduler.step()
return loss.realize()
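The refactored step keeps the same gradient handling: gradients are unscaled by loss_scaler, a global L2 norm is accumulated over all parameters, and every gradient is divided by max(global_norm, 1.0) before optimizer.step(). A NumPy sketch of just that arithmetic, with made-up gradients:

```python
import numpy as np

loss_scaler = 2.0 ** 10                                   # hypothetical static loss scale
grads = [np.random.randn(4, 4).astype(np.float32) * loss_scaler for _ in range(3)]

grads = [g / loss_scaler for g in grads]                  # undo the loss scaling
global_norm = np.sqrt(sum(float((g.astype(np.float32) ** 2).sum()) for g in grads))
clip = max(global_norm, 1.0)                              # Tensor.where(global_norm > 1.0, global_norm, 1.0)
grads = [g / clip for g in grads]
print(global_norm, np.sqrt(sum(float((g ** 2).sum()) for g in grads)))  # second value is <= 1.0
```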
@TinyJit
def eval_step_bert(model, input_ids:Tensor, segment_ids:Tensor, attention_mask:Tensor, masked_positions:Tensor, masked_lm_ids:Tensor, masked_lm_weights:Tensor, next_sentence_labels:Tensor):
lm_logits, clsf_logits = model(input_ids, segment_ids, attention_mask, masked_positions)
clsf_predictions = clsf_logits.log_softmax().argmax(-1)
clsf_accuracy = (clsf_predictions == next_sentence_labels).mean()
mlm_predictions = lm_logits.log_softmax().argmax(-1)
mask = (masked_lm_weights == 1.0)
mlm_accuracy = (mlm_predictions == masked_lm_ids).where(mask, 0).sum() / mask.sum()
lm_loss = lm_logits.sparse_categorical_crossentropy(masked_lm_ids, ignore_index=masked_lm_weights)
clsf_loss = clsf_logits.binary_crossentropy_logits(next_sentence_labels)
lm_logits, seq_relationship_logits = model(input_ids, attention_mask, masked_positions, segment_ids)
masked_lm_accuracy, seq_relationship_accuracy, masked_lm_loss, next_sentence_loss = model.accuracy(lm_logits, seq_relationship_logits, masked_lm_ids, masked_lm_weights, next_sentence_labels)
return {
"masked_lm_accuracy": mlm_accuracy.realize(),
"masked_lm_loss": lm_loss.realize(),
"next_sentence_accuracy": clsf_accuracy.realize(),
"next_sentence_loss": clsf_loss.realize()
"masked_lm_accuracy": masked_lm_accuracy.realize(),
"next_sentence_accuracy": seq_relationship_accuracy.realize(),
"masked_lm_loss": masked_lm_loss.realize(),
"next_sentence_loss": next_sentence_loss.realize()
}
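The metrics returned here now come from model.accuracy (defined in extra/models/bert.py further down). A NumPy sketch of the masked-LM part of that calculation, where id 0 marks an unused mask slot and is excluded from the denominator; the token ids below are made up:

```python
import numpy as np

masked_lm_ids = np.array([[101, 42,  0, 0],   # 0 = padded mask slot
                          [  7,  7, 13, 0]])
predictions   = np.array([[101, 40,  0, 0],
                          [  7,  9, 13, 5]])

valid = masked_lm_ids != 0
masked_lm_accuracy = ((predictions == masked_lm_ids) & valid).sum() / valid.sum()
print(masked_lm_accuracy)  # 3 correct out of 5 valid positions -> 0.6
```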
def train_bert():
# NOTE: pip install tensorflow, wandb required
from examples.mlperf.dataloader import batch_load_train_bert, batch_load_val_bert
from examples.mlperf.helpers import get_mlperf_bert_model, init_bert_from_checkpoint, get_data_bert
from examples.mlperf.helpers import get_mlperf_bert_model, get_data_bert
from examples.mlperf.lr_schedulers import PolynomialDecayWithWarmup
config = {}
@@ -435,8 +423,7 @@ def train_bert():
Tensor.manual_seed(seed) # seed for weight initialization
model = get_mlperf_bert_model()
if init_ckpt: init_bert_from_checkpoint(model, init_ckpt)
model = get_mlperf_bert_model(init_ckpt)
for _, x in get_state_dict(model).items():
x.realize().to_(GPUS)

View File

@@ -39,8 +39,9 @@ def download_wikipedia(path:str):
os.makedirs(path, exist_ok=True)
gdrive_download("https://drive.google.com/uc?id=1fbGClQMi2CoMv7fwrwTC5YYPooQBdcFW", os.path.join(path, "bert_config.json"))
gdrive_download("https://drive.google.com/uc?id=1USK108J6hMM_d27xCHi738qBL8_BT1u1", os.path.join(path, "vocab.txt"))
gdrive_download("https://drive.google.com/uc?id=1pJhVkACK3p_7Uc-1pAzRaOXodNeeHZ7F", os.path.join(path, "model.ckpt-28252.data-00000-of-00001"))
gdrive_download("https://drive.google.com/uc?id=1oVBgtSxkXC9rH2SXJv85RXR9-WrMPy-Q", os.path.join(path, "model.ckpt-28252.index"))
gdrive_download("https://drive.google.com/uc?id=1chiTBljF0Eh1U5pKs6ureVHgSbtU8OG_", os.path.join(path, "model.ckpt-28252.data-00000-of-00001"))
gdrive_download("https://drive.google.com/uc?id=1Q47V3K3jFRkbJ2zGCrKkKk-n0fvMZsa0", os.path.join(path, "model.ckpt-28252.index"))
gdrive_download("https://drive.google.com/uc?id=1vAcVmXSLsLeQ1q7gvHnQUSth5W_f_pwv", os.path.join(path, "model.ckpt-28252.meta"))
with open(os.path.join(path, "checkpoint"), "w") as f: f.write('model_checkpoint_path: "model.ckpt-28252"\nall_model_checkpoint_paths: "model.ckpt-28252"')
if getenv("WIKI_TRAIN", 0):
gdrive_download("https://drive.google.com/uc?id=1tmMgLwoBvbEJEHXh77sqrXYw5RpqT8R_", os.path.join(path, "bert_reference_results_text_md5.txt"))

View File

@@ -1,9 +1,9 @@
from tinygrad.tensor import Tensor
import re, os
from pathlib import Path
from tinygrad.tensor import Tensor, cast
from tinygrad import nn, dtypes
from tinygrad.helpers import fetch, get_child
from pathlib import Path
from examples.mlperf.initializers import LinearBert, LayerNormBert
from tinygrad.nn.state import get_parameters
# allow for monkeypatching
Embedding = nn.Embedding
@@ -39,35 +39,121 @@ class BertForQuestionAnswering:
return Tensor.stack(start_logits, end_logits)
class BertForMLPerf:
def __init__(self, hidden_size:int, intermediate_size:int, max_position_embeddings:int, num_attention_heads:int, num_hidden_layers:int, type_vocab_size:int, vocab_size:int, attention_probs_dropout_prob:float, hidden_dropout_prob:float) -> None:
self.model = Bert(hidden_size, intermediate_size, max_position_embeddings, num_attention_heads, num_hidden_layers, type_vocab_size, vocab_size, attention_probs_dropout_prob, hidden_dropout_prob)
# for clsf:
self.clsf_pooler = LinearBert(hidden_size, hidden_size) # [bs, seq, hidden] -> [bs, hidden]
self.clsf_pooling_activation = Tensor.tanh
self.clsf_output = LinearBert(hidden_size, 2) # [bs, hidden] -> [bs, 2]
class BertForPretraining:
def __init__(self, hidden_size:int=1024, intermediate_size:int=4096, max_position_embeddings:int=512, num_attention_heads:int=16, num_hidden_layers:int=24, type_vocab_size:int=2, vocab_size:int=30522, attention_probs_dropout_prob:float=0.1, hidden_dropout_prob:float=0.1):
"""Default is BERT-large"""
self.bert = Bert(hidden_size, intermediate_size, max_position_embeddings, num_attention_heads, num_hidden_layers, type_vocab_size, vocab_size, attention_probs_dropout_prob, hidden_dropout_prob)
self.cls = BertPreTrainingHeads(hidden_size, vocab_size, self.bert.embeddings.word_embeddings.weight)
def __call__(self, input_ids:Tensor, attention_mask:Tensor, masked_lm_positions:Tensor, token_type_ids:Tensor):
output = self.bert(input_ids, attention_mask, token_type_ids)
return self.cls(output, masked_lm_positions)
def loss(self, prediction_logits:Tensor, seq_relationship_logits:Tensor, masked_lm_ids:Tensor, masked_lm_weights:Tensor, next_sentence_labels:Tensor):
masked_lm_loss = prediction_logits.sparse_categorical_crossentropy(masked_lm_ids, ignore_index=masked_lm_weights)
next_sentence_loss = seq_relationship_logits.binary_crossentropy_logits(next_sentence_labels)
return masked_lm_loss + next_sentence_loss
def accuracy(self, prediction_logits:Tensor, seq_relationship_logits:Tensor, masked_lm_ids:Tensor, masked_lm_weights:Tensor, next_sentence_labels:Tensor):
# for lm:
self.lm_transform = LinearBert(hidden_size, hidden_size)
self.lm_transform_activation = gelu
self.lm_norm = LayerNormBert(hidden_size, eps=1e-12)
self.lm_output = LinearBert(hidden_size, vocab_size, bias=False) # [bs, seq, hidden] -> [bs, seq, vocab]
self.lm_output.weight = self.model.embeddings.word_embeddings.weight
self.lm_output_bias = Tensor.zeros(vocab_size, dtype=dtypes.float32)
valid = masked_lm_ids != 0
masked_lm_predictions = prediction_logits.log_softmax().argmax(-1)
masked_lm_accuracy = (masked_lm_predictions == masked_lm_ids) * valid
masked_lm_loss = prediction_logits.sparse_categorical_crossentropy(masked_lm_ids, ignore_index=masked_lm_weights)
def __call__(self, input_ids:Tensor, segment_ids:Tensor, attention_mask:Tensor, masked_positions:Tensor):
output = self.model(input_ids, attention_mask, segment_ids)
clsf_logits = self.clsf_output(self.clsf_pooling_activation(self.clsf_pooler(output[:, 0]))).cast(dtypes.float32)
seq_relationship_predictions = seq_relationship_logits.log_softmax().argmax(-1)
seq_relationship_accuracy = (seq_relationship_predictions == next_sentence_labels)
next_sentence_loss = seq_relationship_logits.binary_crossentropy_logits(next_sentence_labels)
# gather only the masked_positions we care about
counter = Tensor.arange(output.shape[1], requires_grad=False, device=output.device).reshape(1, 1, output.shape[1]).expand(*masked_positions.shape, output.shape[1])
onehot = counter == masked_positions.unsqueeze(2).expand(*masked_positions.shape, output.shape[1])
h_masked = onehot @ output
return masked_lm_accuracy.sum() / valid.sum(), seq_relationship_accuracy.mean(), masked_lm_loss, next_sentence_loss
def load_from_pretrained(self, tf_weight_path:str=Path(__file__).parent.parent / "datasets" / "wiki"):
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # Mute tf flag info
# load from tensorflow
import tensorflow as tf
import numpy as np
h_masked = self.lm_norm(self.lm_transform_activation(self.lm_transform(h_masked)))
lm_logits = self.lm_output(h_masked) + self.lm_output_bias
state_dict = {}
for name, _ in tf.train.list_variables(str(tf_weight_path)):
state_dict[name] = tf.train.load_variable(str(tf_weight_path), name)
return lm_logits, clsf_logits
for k, v in state_dict.items():
m = k.split("/")
if any(n in ["adam_v", "adam_m", "global_step", "LAMB", "LAMB_1", "beta1_power", "beta2_power"] for n in m):
continue
pointer = self
n = m[-1] # this is just to stop python from complaining about possibly unbound local variable
for i, n in enumerate(m):
if re.fullmatch(r'[A-Za-z]+_\d+', n):
l = re.split(r'_(\d+)', n)[:-1]
else:
l = [n]
if l[0] in ["kernel", "gamma", "output_weights"]:
pointer = getattr(pointer, "weight")
elif l[0] in ["output_bias", "beta"]:
pointer = getattr(pointer, "bias")
elif l[0] == "pooler":
pointer = getattr(getattr(self, "cls"), "pooler")
else:
pointer = getattr(pointer, l[0])
if len(l) == 2: # layers
pointer = pointer[int(l[1])]
if n[-11:] == "_embeddings":
pointer = getattr(pointer, "weight")
elif n == "kernel":
v = np.transpose(v)
cast(Tensor, pointer).assign(v).realize()
params = get_parameters(self)
count = 0
for p in params:
param_count = 1
for s in p.shape:
param_count *= s
count += param_count
print(f"Total parameters: {count / 1000 / 1000}M")
return self
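load_from_pretrained maps slash-separated TF variable names onto attributes of this class, splitting suffixes like layer_3 into an attribute plus a list index and routing kernel/gamma to .weight. A stripped-down sketch of that pointer walk; the object tree below is a hypothetical stand-in, not the real model:

```python
import re
from types import SimpleNamespace as NS

# hypothetical, tiny stand-in for the module tree walked by load_from_pretrained
query = NS(weight="query.weight", bias="query.bias")
attention = NS(); setattr(attention, "self", NS(query=query))   # "self" is a real attribute name in the BERT tree
model = NS(encoder=NS(layer=[NS(), NS(attention=attention)]))

def walk(pointer, tf_name: str):
  for part in tf_name.split("/"):
    # "layer_1" splits into the attribute "layer" plus the index 1; plain names stay whole
    l = re.split(r"_(\d+)", part)[:-1] if re.fullmatch(r"[A-Za-z]+_\d+", part) else [part]
    pointer = getattr(pointer, "weight") if l[0] == "kernel" else getattr(pointer, l[0])
    if len(l) == 2: pointer = pointer[int(l[1])]
  return pointer

print(walk(model, "encoder/layer_1/attention/self/query/kernel"))  # -> "query.weight"
```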
class BertPreTrainingHeads:
def __init__(self, hidden_size:int, vocab_size:int, embeddings_weight:Tensor):
self.predictions = BertLMPredictionHead(hidden_size, vocab_size, embeddings_weight)
self.pooler = BertPooler(hidden_size)
self.seq_relationship = Linear(hidden_size, 2)
def __call__(self, sequence_output:Tensor, masked_lm_positions:Tensor):
prediction_logits = self.predictions(gather(sequence_output, masked_lm_positions))
seq_relationship_logits = self.seq_relationship(self.pooler(sequence_output))
return prediction_logits, seq_relationship_logits
class BertLMPredictionHead:
def __init__(self, hidden_size:int, vocab_size:int, embeddings_weight:Tensor):
self.transform = BertPredictionHeadTransform(hidden_size)
self.embedding_weight = embeddings_weight
self.bias = Tensor.zeros(vocab_size, dtype=dtypes.float32)
def __call__(self, hidden_states:Tensor):
return self.transform(hidden_states) @ self.embedding_weight.T + self.bias
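The prediction head is weight-tied: the final projection reuses the word-embedding matrix passed in as embeddings_weight, and only the vocab-size bias is a fresh parameter. A shape-only NumPy sketch of that last matmul, with toy sizes standing in for hidden_size and vocab_size:

```python
import numpy as np

hidden_size, vocab_size, bs, num_masked = 8, 32, 2, 5    # toy sizes, not BERT-large
embeddings_weight = np.random.randn(vocab_size, hidden_size)  # shared with the input embedding
bias = np.zeros(vocab_size)

h = np.random.randn(bs, num_masked, hidden_size)         # transformed hidden states at masked positions
logits = h @ embeddings_weight.T + bias                  # [bs, num_masked, vocab_size]
print(logits.shape)                                      # (2, 5, 32)
```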
class BertPredictionHeadTransform:
def __init__(self, hidden_size:int):
self.dense = Linear(hidden_size, hidden_size)
self.LayerNorm = LayerNorm(hidden_size, eps=1e-12)
def __call__(self, hidden_states:Tensor):
return self.LayerNorm(gelu(self.dense(hidden_states)))
class BertPooler:
def __init__(self, hidden_size:int):
self.dense = Linear(hidden_size, hidden_size)
def __call__(self, hidden_states:Tensor):
return self.dense(hidden_states[:, 0]).tanh()
def gather(prediction_logits:Tensor, masked_lm_positions:Tensor):
counter = Tensor.arange(prediction_logits.shape[1], device=prediction_logits.device, requires_grad=False).reshape(1, 1, prediction_logits.shape[1]).expand(*masked_lm_positions.shape, prediction_logits.shape[1])
onehot = counter == masked_lm_positions.unsqueeze(2).expand(*masked_lm_positions.shape, prediction_logits.shape[1])
return onehot @ prediction_logits
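gather selects the hidden states at masked_lm_positions with an arange/compare/one-hot matmul rather than fancy indexing. A NumPy sketch checking that the one-hot trick matches direct indexing on random data:

```python
import numpy as np

bs, seq, hidden, n_masked = 2, 6, 4, 3
sequence_output = np.random.randn(bs, seq, hidden)
masked_lm_positions = np.array([[1, 4, 5], [0, 2, 3]])

counter = np.arange(seq).reshape(1, 1, seq)                          # [1, 1, seq]
onehot = (counter == masked_lm_positions[..., None]).astype(float)   # [bs, n_masked, seq]
gathered = onehot @ sequence_output                                  # [bs, n_masked, hidden]

reference = sequence_output[np.arange(bs)[:, None], masked_lm_positions]
print(np.allclose(gathered, reference))  # True
```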
class Bert:
def __init__(self, hidden_size, intermediate_size, max_position_embeddings, num_attention_heads, num_hidden_layers, type_vocab_size, vocab_size, attention_probs_dropout_prob, hidden_dropout_prob):

View File

@@ -12,7 +12,7 @@ from tinygrad.device import Device
from tinygrad.helpers import getenv
from tinygrad.nn.state import get_state_dict
from examples.mlperf.helpers import get_mlperf_bert_model, init_bert_from_checkpoint, get_data_bert
from examples.mlperf.helpers import get_mlperf_bert_model, get_data_bert
from examples.mlperf.dataloader import batch_load_val_bert
from examples.mlperf.model_train import eval_step_bert
@@ -27,14 +27,13 @@ if __name__ == "__main__":
assert os.path.exists(os.path.join(BASEDIR, "eval", f"{i}.pkl")), \
f"File {i}.pkl does not exist in {os.path.join(BASEDIR, 'eval')}"
required_files = ["checkpoint", "model.ckpt-28252.data-00000-of-00001", "model.ckpt-28252.index"]
required_files = ["checkpoint", "model.ckpt-28252.data-00000-of-00001", "model.ckpt-28252.index", "model.ckpt-28252.meta"]
assert all(os.path.exists(os.path.join(INIT_CKPT_DIR, f)) for f in required_files), \
f"Missing checkpoint files in INIT_CKPT_DIR: {required_files}"
Tensor.training = False
model = get_mlperf_bert_model()
init_bert_from_checkpoint(model, INIT_CKPT_DIR) # Test the actual loading of the checkpoint
model = get_mlperf_bert_model(INIT_CKPT_DIR)
for _, x in get_state_dict(model).items():
x.realize().to_(GPUS)