mirror of https://github.com/commaai/tinygrad.git
Refactor to class style (#4804)
This commit is contained in:
parent 1b8bed4a26
commit 04e237328b
@@ -195,12 +195,11 @@ def get_bert_qa_prediction(features, example, start_end_logits):
  return "empty"

def get_mlperf_bert_config():
  """Config is BERT-large"""
  return {
    "attention_probs_dropout_prob": 0.1,
    "hidden_act": "gelu",
    "hidden_dropout_prob": 0.1,
    "hidden_size": 1024,
    "initializer_range": 0.02,
    "intermediate_size": 4096,
    "max_position_embeddings": 512,
    "num_attention_heads": 16,
@@ -209,7 +208,7 @@ def get_mlperf_bert_config():
    "vocab_size": 30522
  }

def get_mlperf_bert_model():
def get_mlperf_bert_model(checkpoint_path:str):
  from extra.models import bert
  from examples.mlperf.initializers import LinearBert, EmbeddingBert, LayerNormBert

@@ -217,95 +216,10 @@ def get_mlperf_bert_model():
  bert.Embedding = EmbeddingBert
  bert.LayerNorm = LayerNormBert

  from extra.models.bert import BertForMLPerf

  config = get_mlperf_bert_config()
  return BertForMLPerf(
    config["hidden_size"],
    config["intermediate_size"],
    config["max_position_embeddings"],
    config["num_attention_heads"],
    config["num_hidden_layers"],
    config["type_vocab_size"],
    config["vocab_size"],
    config["attention_probs_dropout_prob"],
    config["hidden_dropout_prob"]
  )

def init_bert_from_checkpoint(model, ckpt_dir:str):
  for tinygrad_key, x in get_state_dict(model).items():
    if not tinygrad_key.endswith("lm_output.weight"): # lm_output.weight already is word embedding
      t = load_from_tf2_ckpt(key=tinygrad_key, ckpt_dir=ckpt_dir)
      if any(k in tinygrad_key for k in ["intermediate.dense.weight", "output.dense.weight", "clsf_output.weight"]) and "attention" not in tinygrad_key:
        t = t.transpose()
      elif any(k in tinygrad_key for k in ["self", "output.dense", "clsf_pooler", "lm_transform"]) and "weight" in tinygrad_key:
        t = t.reshape(*x.shape).transpose()
      elif all(k in tinygrad_key for k in ["self", "bias"]):
        t = t.reshape(*x.shape)
      x.assign(t)
  from extra.models.bert import BertForPretraining
  return BertForPretraining(**get_mlperf_bert_config()).load_from_pretrained(checkpoint_path)

def get_data_bert(GPUS:list[str], it):
  data: dict[str, Tensor] = next(it)
  for key in data.keys(): data[key].shard_(GPUS, axis=0)
  return data

@functools.lru_cache(maxsize=None)
def load_tf_weights_to_dict(checkpoint_path):
  os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
  import tensorflow as tf
  reader = tf.train.load_checkpoint(checkpoint_path)
  var_to_shape_map = reader.get_variable_to_shape_map()
  weights_dict = {}

  for key in sorted(var_to_shape_map):
    weights_dict[key] = reader.get_tensor(key)
  return weights_dict

def tt(tf_tensor): return Tensor(tf_tensor, dtype=dtypes.float32)

def load_from_tf2_ckpt(key: str, ckpt_dir: str):
  p = "model/layer-3/"
  s = "/.ATTRIBUTES/VARIABLE_VALUE"
  tf_dict = load_tf_weights_to_dict(ckpt_dir)
  if key.startswith("model.embeddings"):
    if key.endswith("word_embeddings.weight"): return tt(tf_dict[p+"layer-1/embeddings"+s])
    elif key.endswith("position_embeddings.weight"): return tt(tf_dict[p+"layer-3/embeddings"+s])
    elif key.endswith("token_type_embeddings.weight"): return tt(tf_dict[p+"layer-4/embeddings"+s])
    elif key.endswith("LayerNorm.weight"): return tt(tf_dict[p+"layer-6/gamma"+s])
    elif key.endswith("LayerNorm.bias"): return tt(tf_dict[p+"layer-6/beta"+s])
    else: raise ValueError(f"Unknown key: {key}")
  elif key.startswith("model.encoder.layer"):
    l_id = str(int(key.split(".")[3]) + 10)
    if ".attention." in key:
      if key.endswith("self.query.weight"): return tt(tf_dict[p+f"layer-{l_id}/_attention_layer/_query_dense/kernel"+s])
      elif key.endswith("self.query.bias"): return tt(tf_dict[p+f"layer-{l_id}/_attention_layer/_query_dense/bias"+s])
      elif key.endswith("self.key.weight"): return tt(tf_dict[p+f"layer-{l_id}/_attention_layer/_key_dense/kernel"+s])
      elif key.endswith("self.key.bias"): return tt(tf_dict[p+f"layer-{l_id}/_attention_layer/_key_dense/bias"+s])
      elif key.endswith("self.value.weight"): return tt(tf_dict[p+f"layer-{l_id}/_attention_layer/_value_dense/kernel"+s])
      elif key.endswith("self.value.bias"): return tt(tf_dict[p+f"layer-{l_id}/_attention_layer/_value_dense/bias"+s])
      # Attention output
      elif key.endswith("output.dense.weight"): return tt(tf_dict[p+f"layer-{l_id}/_attention_output_dense/kernel"+s])
      elif key.endswith("output.dense.bias"): return tt(tf_dict[p+f"layer-{l_id}/_attention_output_dense/bias"+s])
      elif key.endswith("output.LayerNorm.weight"): return tt(tf_dict[p+f"layer-{l_id}/_attention_layer_norm/gamma"+s])
      elif key.endswith("output.LayerNorm.bias"): return tt(tf_dict[p+f"layer-{l_id}/_attention_layer_norm/beta"+s])
      else: raise ValueError(f"Unknown key: {key}")
    elif ".intermediate." in key:
      if key.endswith("dense.weight"): return tt(tf_dict[p+f"layer-{l_id}/_intermediate_dense/kernel"+s])
      elif key.endswith("dense.bias"): return tt(tf_dict[p+f"layer-{l_id}/_intermediate_dense/bias"+s])
      else: raise ValueError(f"Unknown key: {key}")
    elif ".output." in key:
      if key.endswith("dense.weight"): return tt(tf_dict[p+f"layer-{l_id}/_output_dense/kernel"+s])
      elif key.endswith("dense.bias"): return tt(tf_dict[p+f"layer-{l_id}/_output_dense/bias"+s])
      elif key.endswith("LayerNorm.weight"): return tt(tf_dict[p+f"layer-{l_id}/_output_layer_norm/gamma"+s])
      elif key.endswith("LayerNorm.bias"): return tt(tf_dict[p+f"layer-{l_id}/_output_layer_norm/beta"+s])
      else: raise ValueError(f"Unknown key: {key}")
  elif key.startswith("clsf_pooler.weight"): return tt(tf_dict[f"model/layer-3/layer-35/kernel"+s])
  elif key.startswith("clsf_pooler.bias"): return tt(tf_dict[f"model/layer-3/layer-35/bias"+s])
  elif key.startswith("clsf_output.weight"): return tt(tf_dict[f"model/layer-6/layer-1/kernel"+s])
  elif key.startswith("clsf_output.bias"): return tt(tf_dict[f"model/layer-6/layer-1/bias"+s])
  elif key.startswith("lm_transform.weight"): return tt(tf_dict[f"model/layer-5/layer-3/kernel"+s])
  elif key.startswith("lm_transform.bias"): return tt(tf_dict[f"model/layer-5/layer-3/bias"+s])
  elif key.startswith("lm_norm.weight"): return tt(tf_dict[f"model/layer-5/layer-4/gamma"+s])
  elif key.startswith("lm_norm.bias"): return tt(tf_dict[f"model/layer-5/layer-4/beta"+s])
  elif key.startswith("lm_output_bias"): return tt(tf_dict[f"model/layer-5/layer-6/bias"+s])
  else: raise ValueError(f"Unknown key: {key}")
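
Note: with this refactor the helper both builds and loads the model in one call. A minimal usage sketch, assuming the wiki checkpoint files (model.ckpt-28252.*) are already downloaded; the directory path below is a placeholder, not taken from the diff:

  from examples.mlperf.helpers import get_mlperf_bert_model

  # builds BertForPretraining from get_mlperf_bert_config() and loads the TF checkpoint
  model = get_mlperf_bert_model("/path/to/datasets/wiki")
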
@@ -355,50 +355,38 @@ def train_rnnt():

@TinyJit
def train_step_bert(model, optimizer, scheduler, loss_scaler:float, input_ids:Tensor, segment_ids:Tensor, attention_mask:Tensor, masked_positions:Tensor, masked_lm_ids:Tensor, masked_lm_weights:Tensor, next_sentence_labels:Tensor):
  lm_logits, clsf_logits = model(input_ids, segment_ids, attention_mask, masked_positions)
  lm_loss = lm_logits.sparse_categorical_crossentropy(masked_lm_ids, ignore_index=masked_lm_weights)
  clsf_loss = clsf_logits.binary_crossentropy_logits(next_sentence_labels)
  loss = lm_loss + clsf_loss
  optimizer.zero_grad()

  if not getenv('DISABLE_BACKWARD', 0):
    optimizer.zero_grad()
    (loss * loss_scaler).backward()
    lm_logits, seq_relationship_logits = model(input_ids, attention_mask, masked_positions, segment_ids)
    loss = model.loss(lm_logits, seq_relationship_logits, masked_lm_ids, masked_lm_weights, next_sentence_labels)
    (loss * loss_scaler).backward()

    global_norm = Tensor([0.0], dtype=dtypes.float32, device=optimizer[0].device).realize()
    for p in optimizer.params:
      p.grad = p.grad / loss_scaler
      global_norm += p.grad.float().square().sum()
    global_norm = global_norm.sqrt()
    for p in optimizer.params: p.grad = (p.grad / Tensor.where(global_norm > 1.0, global_norm, 1.0)).cast(p.grad.dtype)
    global_norm = Tensor([0.0], dtype=dtypes.float32, device=optimizer[0].device).realize()
    for p in optimizer.params:
      p.grad = p.grad / loss_scaler
      global_norm += p.grad.float().square().sum()
    global_norm = global_norm.sqrt()
    for p in optimizer.params: p.grad = (p.grad / Tensor.where(global_norm > 1.0, global_norm, 1.0)).cast(p.grad.dtype)

    optimizer.step()
    scheduler.step()
    optimizer.step()
    scheduler.step()
  return loss.realize()

@TinyJit
def eval_step_bert(model, input_ids:Tensor, segment_ids:Tensor, attention_mask:Tensor, masked_positions:Tensor, masked_lm_ids:Tensor, masked_lm_weights:Tensor, next_sentence_labels:Tensor):
  lm_logits, clsf_logits = model(input_ids, segment_ids, attention_mask, masked_positions)

  clsf_predictions = clsf_logits.log_softmax().argmax(-1)
  clsf_accuracy = (clsf_predictions == next_sentence_labels).mean()

  mlm_predictions = lm_logits.log_softmax().argmax(-1)
  mask = (masked_lm_weights == 1.0)
  mlm_accuracy = (mlm_predictions == masked_lm_ids).where(mask, 0).sum() / mask.sum()

  lm_loss = lm_logits.sparse_categorical_crossentropy(masked_lm_ids, ignore_index=masked_lm_weights)
  clsf_loss = clsf_logits.binary_crossentropy_logits(next_sentence_labels)
  lm_logits, seq_relationship_logits = model(input_ids, attention_mask, masked_positions, segment_ids)
  masked_lm_accuracy, seq_relationship_accuracy, masked_lm_loss, next_sentence_loss = model.accuracy(lm_logits, seq_relationship_logits, masked_lm_ids, masked_lm_weights, next_sentence_labels)
  return {
    "masked_lm_accuracy": mlm_accuracy.realize(),
    "masked_lm_loss": lm_loss.realize(),
    "next_sentence_accuracy": clsf_accuracy.realize(),
    "next_sentence_loss": clsf_loss.realize()
    "masked_lm_accuracy": masked_lm_accuracy.realize(),
    "next_sentence_accuracy": seq_relationship_accuracy.realize(),
    "masked_lm_loss": masked_lm_loss.realize(),
    "next_sentence_loss": next_sentence_loss.realize()
  }

def train_bert():
  # NOTE: pip install tensorflow, wandb required
  from examples.mlperf.dataloader import batch_load_train_bert, batch_load_val_bert
  from examples.mlperf.helpers import get_mlperf_bert_model, init_bert_from_checkpoint, get_data_bert
  from examples.mlperf.helpers import get_mlperf_bert_model, get_data_bert
  from examples.mlperf.lr_schedulers import PolynomialDecayWithWarmup

  config = {}
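
Note: both the old and new train_step_bert bodies unscale the gradients by loss_scaler and then clip them by their global L2 norm before optimizer.step(). A standalone sketch of that logic, with plain Python floats standing in for tinygrad tensors (the helper name is illustrative, not from the diff):

  import math

  def unscale_and_clip(grads, loss_scaler):
    grads = [g / loss_scaler for g in grads]            # undo the loss scaling applied before backward()
    global_norm = math.sqrt(sum(g * g for g in grads))  # L2 norm over all gradient elements
    denom = global_norm if global_norm > 1.0 else 1.0   # only shrink when the norm exceeds 1.0
    return [g / denom for g in grads]
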
@@ -435,8 +423,7 @@ def train_bert():

  Tensor.manual_seed(seed) # seed for weight initialization

  model = get_mlperf_bert_model()
  if init_ckpt: init_bert_from_checkpoint(model, init_ckpt)
  model = get_mlperf_bert_model(init_ckpt)

  for _, x in get_state_dict(model).items():
    x.realize().to_(GPUS)
@@ -39,8 +39,9 @@ def download_wikipedia(path:str):
  os.makedirs(path, exist_ok=True)
  gdrive_download("https://drive.google.com/uc?id=1fbGClQMi2CoMv7fwrwTC5YYPooQBdcFW", os.path.join(path, "bert_config.json"))
  gdrive_download("https://drive.google.com/uc?id=1USK108J6hMM_d27xCHi738qBL8_BT1u1", os.path.join(path, "vocab.txt"))
  gdrive_download("https://drive.google.com/uc?id=1pJhVkACK3p_7Uc-1pAzRaOXodNeeHZ7F", os.path.join(path, "model.ckpt-28252.data-00000-of-00001"))
  gdrive_download("https://drive.google.com/uc?id=1oVBgtSxkXC9rH2SXJv85RXR9-WrMPy-Q", os.path.join(path, "model.ckpt-28252.index"))
  gdrive_download("https://drive.google.com/uc?id=1chiTBljF0Eh1U5pKs6ureVHgSbtU8OG_", os.path.join(path, "model.ckpt-28252.data-00000-of-00001"))
  gdrive_download("https://drive.google.com/uc?id=1Q47V3K3jFRkbJ2zGCrKkKk-n0fvMZsa0", os.path.join(path, "model.ckpt-28252.index"))
  gdrive_download("https://drive.google.com/uc?id=1vAcVmXSLsLeQ1q7gvHnQUSth5W_f_pwv", os.path.join(path, "model.ckpt-28252.meta"))
  with open(os.path.join(path, "checkpoint"), "w") as f: f.write('model_checkpoint_path: "model.ckpt-28252"\nall_model_checkpoint_paths: "model.ckpt-28252"')
  if getenv("WIKI_TRAIN", 0):
    gdrive_download("https://drive.google.com/uc?id=1tmMgLwoBvbEJEHXh77sqrXYw5RpqT8R_", os.path.join(path, "bert_reference_results_text_md5.txt"))
@@ -1,9 +1,9 @@
from tinygrad.tensor import Tensor
import re, os
from pathlib import Path
from tinygrad.tensor import Tensor, cast
from tinygrad import nn, dtypes
from tinygrad.helpers import fetch, get_child
from pathlib import Path

from examples.mlperf.initializers import LinearBert, LayerNormBert
from tinygrad.nn.state import get_parameters

# allow for monkeypatching
Embedding = nn.Embedding
@@ -39,35 +39,121 @@ class BertForQuestionAnswering:

    return Tensor.stack(start_logits, end_logits)

class BertForMLPerf:
  def __init__(self, hidden_size:int, intermediate_size:int, max_position_embeddings:int, num_attention_heads:int, num_hidden_layers:int, type_vocab_size:int, vocab_size:int, attention_probs_dropout_prob:float, hidden_dropout_prob:float) -> None:
    self.model = Bert(hidden_size, intermediate_size, max_position_embeddings, num_attention_heads, num_hidden_layers, type_vocab_size, vocab_size, attention_probs_dropout_prob, hidden_dropout_prob)
    # for clsf:
    self.clsf_pooler = LinearBert(hidden_size, hidden_size) # [bs, seq, hidden] -> [bs, hidden]
    self.clsf_pooling_activation = Tensor.tanh
    self.clsf_output = LinearBert(hidden_size, 2) # [bs, hidden] -> [bs, 2]
class BertForPretraining:
  def __init__(self, hidden_size:int=1024, intermediate_size:int=4096, max_position_embeddings:int=512, num_attention_heads:int=16, num_hidden_layers:int=24, type_vocab_size:int=2, vocab_size:int=30522, attention_probs_dropout_prob:float=0.1, hidden_dropout_prob:float=0.1):
    """Default is BERT-large"""
    self.bert = Bert(hidden_size, intermediate_size, max_position_embeddings, num_attention_heads, num_hidden_layers, type_vocab_size, vocab_size, attention_probs_dropout_prob, hidden_dropout_prob)
    self.cls = BertPreTrainingHeads(hidden_size, vocab_size, self.bert.embeddings.word_embeddings.weight)

  def __call__(self, input_ids:Tensor, attention_mask:Tensor, masked_lm_positions:Tensor, token_type_ids:Tensor):
    output = self.bert(input_ids, attention_mask, token_type_ids)
    return self.cls(output, masked_lm_positions)

  def loss(self, prediction_logits:Tensor, seq_relationship_logits:Tensor, masked_lm_ids:Tensor, masked_lm_weights:Tensor, next_sentence_labels:Tensor):
    masked_lm_loss = prediction_logits.sparse_categorical_crossentropy(masked_lm_ids, ignore_index=masked_lm_weights)
    next_sentence_loss = seq_relationship_logits.binary_crossentropy_logits(next_sentence_labels)
    return masked_lm_loss + next_sentence_loss

  def accuracy(self, prediction_logits:Tensor, seq_relationship_logits:Tensor, masked_lm_ids:Tensor, masked_lm_weights:Tensor, next_sentence_labels:Tensor):

    # for lm:
    self.lm_transform = LinearBert(hidden_size, hidden_size)
    self.lm_transform_activation = gelu
    self.lm_norm = LayerNormBert(hidden_size, eps=1e-12)
    self.lm_output = LinearBert(hidden_size, vocab_size, bias=False) # [bs, seq, hidden] -> [bs, seq, vocab]
    self.lm_output.weight = self.model.embeddings.word_embeddings.weight
    self.lm_output_bias = Tensor.zeros(vocab_size, dtype=dtypes.float32)
    valid = masked_lm_ids != 0
    masked_lm_predictions = prediction_logits.log_softmax().argmax(-1)
    masked_lm_accuracy = (masked_lm_predictions == masked_lm_ids) * valid
    masked_lm_loss = prediction_logits.sparse_categorical_crossentropy(masked_lm_ids, ignore_index=masked_lm_weights)

  def __call__(self, input_ids:Tensor, segment_ids:Tensor, attention_mask:Tensor, masked_positions:Tensor):
    output = self.model(input_ids, attention_mask, segment_ids)
    clsf_logits = self.clsf_output(self.clsf_pooling_activation(self.clsf_pooler(output[:, 0]))).cast(dtypes.float32)
    seq_relationship_predictions = seq_relationship_logits.log_softmax().argmax(-1)
    seq_relationship_accuracy = (seq_relationship_predictions == next_sentence_labels)
    next_sentence_loss = seq_relationship_logits.binary_crossentropy_logits(next_sentence_labels)

    # gather only the masked_positions we care about
    counter = Tensor.arange(output.shape[1], requires_grad=False, device=output.device).reshape(1, 1, output.shape[1]).expand(*masked_positions.shape, output.shape[1])
    onehot = counter == masked_positions.unsqueeze(2).expand(*masked_positions.shape, output.shape[1])
    h_masked = onehot @ output
    return masked_lm_accuracy.sum() / valid.sum(), seq_relationship_accuracy.mean(), masked_lm_loss, next_sentence_loss

  def load_from_pretrained(self, tf_weight_path:str=Path(__file__).parent.parent / "datasets" / "wiki"):
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # Mute tf flag info
    # load from tensorflow
    import tensorflow as tf
    import numpy as np

    h_masked = self.lm_norm(self.lm_transform_activation(self.lm_transform(h_masked)))
    lm_logits = self.lm_output(h_masked) + self.lm_output_bias
    state_dict = {}
    for name, _ in tf.train.list_variables(str(tf_weight_path)):
      state_dict[name] = tf.train.load_variable(str(tf_weight_path), name)

    return lm_logits, clsf_logits
    for k, v in state_dict.items():
      m = k.split("/")
      if any(n in ["adam_v", "adam_m", "global_step", "LAMB", "LAMB_1", "beta1_power", "beta2_power"] for n in m):
        continue

      pointer = self
      n = m[-1] # this is just to stop python from complaining about possibly unbound local variable
      for i, n in enumerate(m):
        if re.fullmatch(r'[A-Za-z]+_\d+', n):
          l = re.split(r'_(\d+)', n)[:-1]
        else:
          l = [n]
        if l[0] in ["kernel", "gamma", "output_weights"]:
          pointer = getattr(pointer, "weight")
        elif l[0] in ["output_bias", "beta"]:
          pointer = getattr(pointer, "bias")
        elif l[0] == "pooler":
          pointer = getattr(getattr(self, "cls"), "pooler")
        else:
          pointer = getattr(pointer, l[0])
        if len(l) == 2: # layers
          pointer = pointer[int(l[1])]
      if n[-11:] == "_embeddings":
        pointer = getattr(pointer, "weight")
      elif n == "kernel":
        v = np.transpose(v)
      cast(Tensor, pointer).assign(v).realize()

    params = get_parameters(self)
    count = 0
    for p in params:
      param_count = 1
      for s in p.shape:
        param_count *= s
      count += param_count
    print(f"Total parameters: {count / 1000 / 1000}M")
    return self

class BertPreTrainingHeads:
  def __init__(self, hidden_size:int, vocab_size:int, embeddings_weight:Tensor):
    self.predictions = BertLMPredictionHead(hidden_size, vocab_size, embeddings_weight)
    self.pooler = BertPooler(hidden_size)
    self.seq_relationship = Linear(hidden_size, 2)

  def __call__(self, sequence_output:Tensor, masked_lm_positions:Tensor):
    prediction_logits = self.predictions(gather(sequence_output, masked_lm_positions))
    seq_relationship_logits = self.seq_relationship(self.pooler(sequence_output))
    return prediction_logits, seq_relationship_logits

class BertLMPredictionHead:
  def __init__(self, hidden_size:int, vocab_size:int, embeddings_weight:Tensor):
    self.transform = BertPredictionHeadTransform(hidden_size)
    self.embedding_weight = embeddings_weight
    self.bias = Tensor.zeros(vocab_size, dtype=dtypes.float32)

  def __call__(self, hidden_states:Tensor):
    return self.transform(hidden_states) @ self.embedding_weight.T + self.bias

class BertPredictionHeadTransform:
  def __init__(self, hidden_size:int):
    self.dense = Linear(hidden_size, hidden_size)
    self.LayerNorm = LayerNorm(hidden_size, eps=1e-12)

  def __call__(self, hidden_states:Tensor):
    return self.LayerNorm(gelu(self.dense(hidden_states)))

class BertPooler:
  def __init__(self, hidden_size:int):
    self.dense = Linear(hidden_size, hidden_size)

  def __call__(self, hidden_states:Tensor):
    return self.dense(hidden_states[:, 0]).tanh()

def gather(prediction_logits:Tensor, masked_lm_positions:Tensor):
  counter = Tensor.arange(prediction_logits.shape[1], device=prediction_logits.device, requires_grad=False).reshape(1, 1, prediction_logits.shape[1]).expand(*masked_lm_positions.shape, prediction_logits.shape[1])
  onehot = counter == masked_lm_positions.unsqueeze(2).expand(*masked_lm_positions.shape, prediction_logits.shape[1])
  return onehot @ prediction_logits

class Bert:
  def __init__(self, hidden_size, intermediate_size, max_position_embeddings, num_attention_heads, num_hidden_layers, type_vocab_size, vocab_size, attention_probs_dropout_prob, hidden_dropout_prob):
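
Note: the gather helper above selects the hidden states at the masked positions with a one-hot matmul, which keeps the selection as a single batched tensor op. For one example it is equivalent to plain row indexing, roughly (list-based sketch, names illustrative, not from the diff):

  def gather_rows(sequence_output, masked_lm_positions):
    # sequence_output: [seq_len][hidden] nested lists, masked_lm_positions: list of ints
    return [sequence_output[pos] for pos in masked_lm_positions]
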
@@ -12,7 +12,7 @@ from tinygrad.device import Device
from tinygrad.helpers import getenv
from tinygrad.nn.state import get_state_dict

from examples.mlperf.helpers import get_mlperf_bert_model, init_bert_from_checkpoint, get_data_bert
from examples.mlperf.helpers import get_mlperf_bert_model, get_data_bert
from examples.mlperf.dataloader import batch_load_val_bert
from examples.mlperf.model_train import eval_step_bert

@@ -27,14 +27,13 @@ if __name__ == "__main__":
    assert os.path.exists(os.path.join(BASEDIR, "eval", f"{i}.pkl")), \
      f"File {i}.pkl does not exist in {os.path.join(BASEDIR, 'eval')}"

  required_files = ["checkpoint", "model.ckpt-28252.data-00000-of-00001", "model.ckpt-28252.index"]
  required_files = ["checkpoint", "model.ckpt-28252.data-00000-of-00001", "model.ckpt-28252.index", "model.ckpt-28252.meta"]
  assert all(os.path.exists(os.path.join(INIT_CKPT_DIR, f)) for f in required_files), \
    f"Missing checkpoint files in INIT_CKPT_DIR: {required_files}"

  Tensor.training = False

  model = get_mlperf_bert_model()
  init_bert_from_checkpoint(model, INIT_CKPT_DIR) # Test the actual loading of the checkpoint
  model = get_mlperf_bert_model(INIT_CKPT_DIR)

  for _, x in get_state_dict(model).items():
    x.realize().to_(GPUS)