Refactor to class style (#4804)

Authored by Elias Wahl on 2024-06-04 23:08:31 +02:00, committed by GitHub
parent 1b8bed4a26
commit 04e237328b
5 changed files with 144 additions and 157 deletions

View File

@@ -195,12 +195,11 @@ def get_bert_qa_prediction(features, example, start_end_logits):
   return "empty"

 def get_mlperf_bert_config():
+  """Config is BERT-large"""
   return {
     "attention_probs_dropout_prob": 0.1,
-    "hidden_act": "gelu",
     "hidden_dropout_prob": 0.1,
     "hidden_size": 1024,
-    "initializer_range": 0.02,
     "intermediate_size": 4096,
     "max_position_embeddings": 512,
     "num_attention_heads": 16,
@@ -209,7 +208,7 @@ def get_mlperf_bert_config():
     "vocab_size": 30522
   }

-def get_mlperf_bert_model():
+def get_mlperf_bert_model(checkpoint_path:str):
   from extra.models import bert
   from examples.mlperf.initializers import LinearBert, EmbeddingBert, LayerNormBert
@@ -217,95 +216,10 @@ def get_mlperf_bert_model():
   bert.Embedding = EmbeddingBert
   bert.LayerNorm = LayerNormBert

-  from extra.models.bert import BertForMLPerf
-
-  config = get_mlperf_bert_config()
-  return BertForMLPerf(
-    config["hidden_size"],
-    config["intermediate_size"],
-    config["max_position_embeddings"],
-    config["num_attention_heads"],
-    config["num_hidden_layers"],
-    config["type_vocab_size"],
-    config["vocab_size"],
-    config["attention_probs_dropout_prob"],
-    config["hidden_dropout_prob"]
-  )
-
-def init_bert_from_checkpoint(model, ckpt_dir:str):
-  for tinygrad_key, x in get_state_dict(model).items():
-    if not tinygrad_key.endswith("lm_output.weight"): # lm_output.weight already is word embedding
-      t = load_from_tf2_ckpt(key=tinygrad_key, ckpt_dir=ckpt_dir)
-      if any(k in tinygrad_key for k in ["intermediate.dense.weight", "output.dense.weight", "clsf_output.weight"]) and "attention" not in tinygrad_key:
-        t = t.transpose()
-      elif any(k in tinygrad_key for k in ["self", "output.dense", "clsf_pooler", "lm_transform"]) and "weight" in tinygrad_key:
-        t = t.reshape(*x.shape).transpose()
-      elif all(k in tinygrad_key for k in ["self", "bias"]):
-        t = t.reshape(*x.shape)
-      x.assign(t)
+  from extra.models.bert import BertForPretraining
+  return BertForPretraining(**get_mlperf_bert_config()).load_from_pretrained(checkpoint_path)

 def get_data_bert(GPUS:list[str], it):
   data: dict[str, Tensor] = next(it)
   for key in data.keys(): data[key].shard_(GPUS, axis=0)
   return data
-
-@functools.lru_cache(maxsize=None)
-def load_tf_weights_to_dict(checkpoint_path):
-  os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
-  import tensorflow as tf
-
-  reader = tf.train.load_checkpoint(checkpoint_path)
-  var_to_shape_map = reader.get_variable_to_shape_map()
-  weights_dict = {}
-
-  for key in sorted(var_to_shape_map):
-    weights_dict[key] = reader.get_tensor(key)
-  return weights_dict
-
-def tt(tf_tensor): return Tensor(tf_tensor, dtype=dtypes.float32)
-
-def load_from_tf2_ckpt(key: str, ckpt_dir: str):
-  p = "model/layer-3/"
-  s = "/.ATTRIBUTES/VARIABLE_VALUE"
-  tf_dict = load_tf_weights_to_dict(ckpt_dir)
-  if key.startswith("model.embeddings"):
-    if key.endswith("word_embeddings.weight"): return tt(tf_dict[p+"layer-1/embeddings"+s])
-    elif key.endswith("position_embeddings.weight"): return tt(tf_dict[p+"layer-3/embeddings"+s])
-    elif key.endswith("token_type_embeddings.weight"): return tt(tf_dict[p+"layer-4/embeddings"+s])
-    elif key.endswith("LayerNorm.weight"): return tt(tf_dict[p+"layer-6/gamma"+s])
-    elif key.endswith("LayerNorm.bias"): return tt(tf_dict[p+"layer-6/beta"+s])
-    else: raise ValueError(f"Unknown key: {key}")
-  elif key.startswith("model.encoder.layer"):
-    l_id = str(int(key.split(".")[3]) + 10)
-    if ".attention." in key:
-      if key.endswith("self.query.weight"): return tt(tf_dict[p+f"layer-{l_id}/_attention_layer/_query_dense/kernel"+s])
-      elif key.endswith("self.query.bias"): return tt(tf_dict[p+f"layer-{l_id}/_attention_layer/_query_dense/bias"+s])
-      elif key.endswith("self.key.weight"): return tt(tf_dict[p+f"layer-{l_id}/_attention_layer/_key_dense/kernel"+s])
-      elif key.endswith("self.key.bias"): return tt(tf_dict[p+f"layer-{l_id}/_attention_layer/_key_dense/bias"+s])
-      elif key.endswith("self.value.weight"): return tt(tf_dict[p+f"layer-{l_id}/_attention_layer/_value_dense/kernel"+s])
-      elif key.endswith("self.value.bias"): return tt(tf_dict[p+f"layer-{l_id}/_attention_layer/_value_dense/bias"+s])
-      # Attention output
-      elif key.endswith("output.dense.weight"): return tt(tf_dict[p+f"layer-{l_id}/_attention_output_dense/kernel"+s])
-      elif key.endswith("output.dense.bias"): return tt(tf_dict[p+f"layer-{l_id}/_attention_output_dense/bias"+s])
-      elif key.endswith("output.LayerNorm.weight"): return tt(tf_dict[p+f"layer-{l_id}/_attention_layer_norm/gamma"+s])
-      elif key.endswith("output.LayerNorm.bias"): return tt(tf_dict[p+f"layer-{l_id}/_attention_layer_norm/beta"+s])
-      else: raise ValueError(f"Unknown key: {key}")
-    elif ".intermediate." in key:
-      if key.endswith("dense.weight"): return tt(tf_dict[p+f"layer-{l_id}/_intermediate_dense/kernel"+s])
-      elif key.endswith("dense.bias"): return tt(tf_dict[p+f"layer-{l_id}/_intermediate_dense/bias"+s])
-      else: raise ValueError(f"Unknown key: {key}")
-    elif ".output." in key:
-      if key.endswith("dense.weight"): return tt(tf_dict[p+f"layer-{l_id}/_output_dense/kernel"+s])
-      elif key.endswith("dense.bias"): return tt(tf_dict[p+f"layer-{l_id}/_output_dense/bias"+s])
-      elif key.endswith("LayerNorm.weight"): return tt(tf_dict[p+f"layer-{l_id}/_output_layer_norm/gamma"+s])
-      elif key.endswith("LayerNorm.bias"): return tt(tf_dict[p+f"layer-{l_id}/_output_layer_norm/beta"+s])
-      else: raise ValueError(f"Unknown key: {key}")
-  elif key.startswith("clsf_pooler.weight"): return tt(tf_dict[f"model/layer-3/layer-35/kernel"+s])
-  elif key.startswith("clsf_pooler.bias"): return tt(tf_dict[f"model/layer-3/layer-35/bias"+s])
-  elif key.startswith("clsf_output.weight"): return tt(tf_dict[f"model/layer-6/layer-1/kernel"+s])
-  elif key.startswith("clsf_output.bias"): return tt(tf_dict[f"model/layer-6/layer-1/bias"+s])
-  elif key.startswith("lm_transform.weight"): return tt(tf_dict[f"model/layer-5/layer-3/kernel"+s])
-  elif key.startswith("lm_transform.bias"): return tt(tf_dict[f"model/layer-5/layer-3/bias"+s])
-  elif key.startswith("lm_norm.weight"): return tt(tf_dict[f"model/layer-5/layer-4/gamma"+s])
-  elif key.startswith("lm_norm.bias"): return tt(tf_dict[f"model/layer-5/layer-4/beta"+s])
-  elif key.startswith("lm_output_bias"): return tt(tf_dict[f"model/layer-5/layer-6/bias"+s])
-  else: raise ValueError(f"Unknown key: {key}")
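
Note on the new helper API: get_mlperf_bert_model now takes the checkpoint path and performs the TensorFlow checkpoint load itself, so the separate init_bert_from_checkpoint / load_from_tf2_ckpt machinery above goes away. A minimal usage sketch (the path is an assumption; it should point at a directory containing the model.ckpt-28252 files and the checkpoint manifest):

from examples.mlperf.helpers import get_mlperf_bert_model

# builds BertForPretraining with the BERT-large config and loads the TF checkpoint in one call
model = get_mlperf_bert_model("/path/to/datasets/wiki")  # hypothetical checkpoint directory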

View File

@@ -355,50 +355,38 @@ def train_rnnt():

 @TinyJit
 def train_step_bert(model, optimizer, scheduler, loss_scaler:float, input_ids:Tensor, segment_ids:Tensor, attention_mask:Tensor, masked_positions:Tensor, masked_lm_ids:Tensor, masked_lm_weights:Tensor, next_sentence_labels:Tensor):
-  lm_logits, clsf_logits = model(input_ids, segment_ids, attention_mask, masked_positions)
-  lm_loss = lm_logits.sparse_categorical_crossentropy(masked_lm_ids, ignore_index=masked_lm_weights)
-  clsf_loss = clsf_logits.binary_crossentropy_logits(next_sentence_labels)
-  loss = lm_loss + clsf_loss
-
-  if not getenv('DISABLE_BACKWARD', 0):
-    optimizer.zero_grad()
+  optimizer.zero_grad()
+
+  lm_logits, seq_relationship_logits = model(input_ids, attention_mask, masked_positions, segment_ids)
+  loss = model.loss(lm_logits, seq_relationship_logits, masked_lm_ids, masked_lm_weights, next_sentence_labels)

   (loss * loss_scaler).backward()

   global_norm = Tensor([0.0], dtype=dtypes.float32, device=optimizer[0].device).realize()
   for p in optimizer.params:
     p.grad = p.grad / loss_scaler
     global_norm += p.grad.float().square().sum()
   global_norm = global_norm.sqrt()
   for p in optimizer.params: p.grad = (p.grad / Tensor.where(global_norm > 1.0, global_norm, 1.0)).cast(p.grad.dtype)

   optimizer.step()
   scheduler.step()
   return loss.realize()

 @TinyJit
 def eval_step_bert(model, input_ids:Tensor, segment_ids:Tensor, attention_mask:Tensor, masked_positions:Tensor, masked_lm_ids:Tensor, masked_lm_weights:Tensor, next_sentence_labels:Tensor):
-  lm_logits, clsf_logits = model(input_ids, segment_ids, attention_mask, masked_positions)
-
-  clsf_predictions = clsf_logits.log_softmax().argmax(-1)
-  clsf_accuracy = (clsf_predictions == next_sentence_labels).mean()
-
-  mlm_predictions = lm_logits.log_softmax().argmax(-1)
-  mask = (masked_lm_weights == 1.0)
-  mlm_accuracy = (mlm_predictions == masked_lm_ids).where(mask, 0).sum() / mask.sum()
-
-  lm_loss = lm_logits.sparse_categorical_crossentropy(masked_lm_ids, ignore_index=masked_lm_weights)
-  clsf_loss = clsf_logits.binary_crossentropy_logits(next_sentence_labels)
+  lm_logits, seq_relationship_logits = model(input_ids, attention_mask, masked_positions, segment_ids)
+  masked_lm_accuracy, seq_relationship_accuracy, masked_lm_loss, next_sentence_loss = model.accuracy(lm_logits, seq_relationship_logits, masked_lm_ids, masked_lm_weights, next_sentence_labels)
   return {
-    "masked_lm_accuracy": mlm_accuracy.realize(),
-    "masked_lm_loss": lm_loss.realize(),
-    "next_sentence_accuracy": clsf_accuracy.realize(),
-    "next_sentence_loss": clsf_loss.realize()
+    "masked_lm_accuracy": masked_lm_accuracy.realize(),
+    "next_sentence_accuracy": seq_relationship_accuracy.realize(),
+    "masked_lm_loss": masked_lm_loss.realize(),
+    "next_sentence_loss": next_sentence_loss.realize()
   }

 def train_bert():
   # NOTE: pip install tensorflow, wandb required
   from examples.mlperf.dataloader import batch_load_train_bert, batch_load_val_bert
-  from examples.mlperf.helpers import get_mlperf_bert_model, init_bert_from_checkpoint, get_data_bert
+  from examples.mlperf.helpers import get_mlperf_bert_model, get_data_bert
   from examples.mlperf.lr_schedulers import PolynomialDecayWithWarmup

   config = {}
@@ -435,8 +423,7 @@ def train_bert():
   Tensor.manual_seed(seed)  # seed for weight initialization

-  model = get_mlperf_bert_model()
-  if init_ckpt: init_bert_from_checkpoint(model, init_ckpt)
+  model = get_mlperf_bert_model(init_ckpt)

   for _, x in get_state_dict(model).items():
     x.realize().to_(GPUS)
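
The refactored train step asks the model for its combined loss (masked-LM sparse cross-entropy plus next-sentence binary cross-entropy with logits) and keeps the existing global-norm gradient clipping before optimizer.step(). A framework-agnostic sketch of that clipping rule, with made-up gradient values:

import numpy as np

grads = [np.array([3.0, 4.0]), np.array([12.0])]                 # per-parameter gradients
global_norm = np.sqrt(sum(float((g * g).sum()) for g in grads))  # sqrt(9 + 16 + 144) = 13.0
scale = global_norm if global_norm > 1.0 else 1.0                # mirrors Tensor.where(global_norm > 1.0, global_norm, 1.0)
clipped = [g / scale for g in grads]
print(global_norm, np.sqrt(sum(float((g * g).sum()) for g in clipped)))  # 13.0 1.0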

View File

@@ -39,8 +39,9 @@ def download_wikipedia(path:str):
   os.makedirs(path, exist_ok=True)
   gdrive_download("https://drive.google.com/uc?id=1fbGClQMi2CoMv7fwrwTC5YYPooQBdcFW", os.path.join(path, "bert_config.json"))
   gdrive_download("https://drive.google.com/uc?id=1USK108J6hMM_d27xCHi738qBL8_BT1u1", os.path.join(path, "vocab.txt"))
-  gdrive_download("https://drive.google.com/uc?id=1pJhVkACK3p_7Uc-1pAzRaOXodNeeHZ7F", os.path.join(path, "model.ckpt-28252.data-00000-of-00001"))
-  gdrive_download("https://drive.google.com/uc?id=1oVBgtSxkXC9rH2SXJv85RXR9-WrMPy-Q", os.path.join(path, "model.ckpt-28252.index"))
+  gdrive_download("https://drive.google.com/uc?id=1chiTBljF0Eh1U5pKs6ureVHgSbtU8OG_", os.path.join(path, "model.ckpt-28252.data-00000-of-00001"))
+  gdrive_download("https://drive.google.com/uc?id=1Q47V3K3jFRkbJ2zGCrKkKk-n0fvMZsa0", os.path.join(path, "model.ckpt-28252.index"))
+  gdrive_download("https://drive.google.com/uc?id=1vAcVmXSLsLeQ1q7gvHnQUSth5W_f_pwv", os.path.join(path, "model.ckpt-28252.meta"))
   with open(os.path.join(path, "checkpoint"), "w") as f: f.write('model_checkpoint_path: "model.ckpt-28252"\nall_model_checkpoint_paths: "model.ckpt-28252"')
   if getenv("WIKI_TRAIN", 0):
     gdrive_download("https://drive.google.com/uc?id=1tmMgLwoBvbEJEHXh77sqrXYw5RpqT8R_", os.path.join(path, "bert_reference_results_text_md5.txt"))

View File

@@ -1,9 +1,9 @@
-from tinygrad.tensor import Tensor
+import re, os
+from pathlib import Path
+from tinygrad.tensor import Tensor, cast
 from tinygrad import nn, dtypes
 from tinygrad.helpers import fetch, get_child
-from pathlib import Path
-from examples.mlperf.initializers import LinearBert, LayerNormBert
+from tinygrad.nn.state import get_parameters

 # allow for monkeypatching
 Embedding = nn.Embedding
@ -39,35 +39,121 @@ class BertForQuestionAnswering:
return Tensor.stack(start_logits, end_logits) return Tensor.stack(start_logits, end_logits)
class BertForMLPerf: class BertForPretraining:
def __init__(self, hidden_size:int, intermediate_size:int, max_position_embeddings:int, num_attention_heads:int, num_hidden_layers:int, type_vocab_size:int, vocab_size:int, attention_probs_dropout_prob:float, hidden_dropout_prob:float) -> None: def __init__(self, hidden_size:int=1024, intermediate_size:int=4096, max_position_embeddings:int=512, num_attention_heads:int=16, num_hidden_layers:int=24, type_vocab_size:int=2, vocab_size:int=30522, attention_probs_dropout_prob:float=0.1, hidden_dropout_prob:float=0.1):
self.model = Bert(hidden_size, intermediate_size, max_position_embeddings, num_attention_heads, num_hidden_layers, type_vocab_size, vocab_size, attention_probs_dropout_prob, hidden_dropout_prob) """Default is BERT-large"""
# for clsf: self.bert = Bert(hidden_size, intermediate_size, max_position_embeddings, num_attention_heads, num_hidden_layers, type_vocab_size, vocab_size, attention_probs_dropout_prob, hidden_dropout_prob)
self.clsf_pooler = LinearBert(hidden_size, hidden_size) # [bs, seq, hidden] -> [bs, hidden] self.cls = BertPreTrainingHeads(hidden_size, vocab_size, self.bert.embeddings.word_embeddings.weight)
self.clsf_pooling_activation = Tensor.tanh
self.clsf_output = LinearBert(hidden_size, 2) # [bs, hidden] -> [bs, 2]
# for lm: def __call__(self, input_ids:Tensor, attention_mask:Tensor, masked_lm_positions:Tensor, token_type_ids:Tensor):
self.lm_transform = LinearBert(hidden_size, hidden_size) output = self.bert(input_ids, attention_mask, token_type_ids)
self.lm_transform_activation = gelu return self.cls(output, masked_lm_positions)
self.lm_norm = LayerNormBert(hidden_size, eps=1e-12)
self.lm_output = LinearBert(hidden_size, vocab_size, bias=False) # [bs, seq, hidden] -> [bs, seq, vocab]
self.lm_output.weight = self.model.embeddings.word_embeddings.weight
self.lm_output_bias = Tensor.zeros(vocab_size, dtype=dtypes.float32)
def __call__(self, input_ids:Tensor, segment_ids:Tensor, attention_mask:Tensor, masked_positions:Tensor): def loss(self, prediction_logits:Tensor, seq_relationship_logits:Tensor, masked_lm_ids:Tensor, masked_lm_weights:Tensor, next_sentence_labels:Tensor):
output = self.model(input_ids, attention_mask, segment_ids) masked_lm_loss = prediction_logits.sparse_categorical_crossentropy(masked_lm_ids, ignore_index=masked_lm_weights)
clsf_logits = self.clsf_output(self.clsf_pooling_activation(self.clsf_pooler(output[:, 0]))).cast(dtypes.float32) next_sentence_loss = seq_relationship_logits.binary_crossentropy_logits(next_sentence_labels)
return masked_lm_loss + next_sentence_loss
# gather only the masked_positions we care about def accuracy(self, prediction_logits:Tensor, seq_relationship_logits:Tensor, masked_lm_ids:Tensor, masked_lm_weights:Tensor, next_sentence_labels:Tensor):
counter = Tensor.arange(output.shape[1], requires_grad=False, device=output.device).reshape(1, 1, output.shape[1]).expand(*masked_positions.shape, output.shape[1])
onehot = counter == masked_positions.unsqueeze(2).expand(*masked_positions.shape, output.shape[1])
h_masked = onehot @ output
h_masked = self.lm_norm(self.lm_transform_activation(self.lm_transform(h_masked))) valid = masked_lm_ids != 0
lm_logits = self.lm_output(h_masked) + self.lm_output_bias masked_lm_predictions = prediction_logits.log_softmax().argmax(-1)
masked_lm_accuracy = (masked_lm_predictions == masked_lm_ids) * valid
masked_lm_loss = prediction_logits.sparse_categorical_crossentropy(masked_lm_ids, ignore_index=masked_lm_weights)
return lm_logits, clsf_logits seq_relationship_predictions = seq_relationship_logits.log_softmax().argmax(-1)
seq_relationship_accuracy = (seq_relationship_predictions == next_sentence_labels)
next_sentence_loss = seq_relationship_logits.binary_crossentropy_logits(next_sentence_labels)
return masked_lm_accuracy.sum() / valid.sum(), seq_relationship_accuracy.mean(), masked_lm_loss, next_sentence_loss
def load_from_pretrained(self, tf_weight_path:str=Path(__file__).parent.parent / "datasets" / "wiki"):
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # Mute tf flag info
# load from tensorflow
import tensorflow as tf
import numpy as np
state_dict = {}
for name, _ in tf.train.list_variables(str(tf_weight_path)):
state_dict[name] = tf.train.load_variable(str(tf_weight_path), name)
for k, v in state_dict.items():
m = k.split("/")
if any(n in ["adam_v", "adam_m", "global_step", "LAMB", "LAMB_1", "beta1_power", "beta2_power"] for n in m):
continue
pointer = self
n = m[-1] # this is just to stop python from complaining about possibly unbound local variable
for i, n in enumerate(m):
if re.fullmatch(r'[A-Za-z]+_\d+', n):
l = re.split(r'_(\d+)', n)[:-1]
else:
l = [n]
if l[0] in ["kernel", "gamma", "output_weights"]:
pointer = getattr(pointer, "weight")
elif l[0] in ["output_bias", "beta"]:
pointer = getattr(pointer, "bias")
elif l[0] == "pooler":
pointer = getattr(getattr(self, "cls"), "pooler")
else:
pointer = getattr(pointer, l[0])
if len(l) == 2: # layers
pointer = pointer[int(l[1])]
if n[-11:] == "_embeddings":
pointer = getattr(pointer, "weight")
elif n == "kernel":
v = np.transpose(v)
cast(Tensor, pointer).assign(v).realize()
params = get_parameters(self)
count = 0
for p in params:
param_count = 1
for s in p.shape:
param_count *= s
count += param_count
print(f"Total parameters: {count / 1000 / 1000}M")
return self
class BertPreTrainingHeads:
def __init__(self, hidden_size:int, vocab_size:int, embeddings_weight:Tensor):
self.predictions = BertLMPredictionHead(hidden_size, vocab_size, embeddings_weight)
self.pooler = BertPooler(hidden_size)
self.seq_relationship = Linear(hidden_size, 2)
def __call__(self, sequence_output:Tensor, masked_lm_positions:Tensor):
prediction_logits = self.predictions(gather(sequence_output, masked_lm_positions))
seq_relationship_logits = self.seq_relationship(self.pooler(sequence_output))
return prediction_logits, seq_relationship_logits
class BertLMPredictionHead:
def __init__(self, hidden_size:int, vocab_size:int, embeddings_weight:Tensor):
self.transform = BertPredictionHeadTransform(hidden_size)
self.embedding_weight = embeddings_weight
self.bias = Tensor.zeros(vocab_size, dtype=dtypes.float32)
def __call__(self, hidden_states:Tensor):
return self.transform(hidden_states) @ self.embedding_weight.T + self.bias
class BertPredictionHeadTransform:
def __init__(self, hidden_size:int):
self.dense = Linear(hidden_size, hidden_size)
self.LayerNorm = LayerNorm(hidden_size, eps=1e-12)
def __call__(self, hidden_states:Tensor):
return self.LayerNorm(gelu(self.dense(hidden_states)))
class BertPooler:
def __init__(self, hidden_size:int):
self.dense = Linear(hidden_size, hidden_size)
def __call__(self, hidden_states:Tensor):
return self.dense(hidden_states[:, 0]).tanh()
def gather(prediction_logits:Tensor, masked_lm_positions:Tensor):
counter = Tensor.arange(prediction_logits.shape[1], device=prediction_logits.device, requires_grad=False).reshape(1, 1, prediction_logits.shape[1]).expand(*masked_lm_positions.shape, prediction_logits.shape[1])
onehot = counter == masked_lm_positions.unsqueeze(2).expand(*masked_lm_positions.shape, prediction_logits.shape[1])
return onehot @ prediction_logits
class Bert: class Bert:
def __init__(self, hidden_size, intermediate_size, max_position_embeddings, num_attention_heads, num_hidden_layers, type_vocab_size, vocab_size, attention_probs_dropout_prob, hidden_dropout_prob): def __init__(self, hidden_size, intermediate_size, max_position_embeddings, num_attention_heads, num_hidden_layers, type_vocab_size, vocab_size, attention_probs_dropout_prob, hidden_dropout_prob):
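
gather() pulls out the hidden states at the masked positions with a one-hot comparison followed by a matmul rather than indexing. A small NumPy sketch of the same trick with illustrative shapes:

import numpy as np

seq = np.arange(12, dtype=np.float32).reshape(1, 4, 3)          # [bs=1, seq_len=4, hidden=3]
positions = np.array([[1, 3]])                                   # [bs=1, num_masked=2]
counter = np.arange(4).reshape(1, 1, 4)                          # [1, 1, seq_len]
onehot = (counter == positions[..., None]).astype(np.float32)    # [bs, num_masked, seq_len]
print(onehot @ seq)                                              # rows 1 and 3: [[[3. 4. 5.] [9. 10. 11.]]]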

View File

@@ -12,7 +12,7 @@ from tinygrad.device import Device
 from tinygrad.helpers import getenv
 from tinygrad.nn.state import get_state_dict

-from examples.mlperf.helpers import get_mlperf_bert_model, init_bert_from_checkpoint, get_data_bert
+from examples.mlperf.helpers import get_mlperf_bert_model, get_data_bert
 from examples.mlperf.dataloader import batch_load_val_bert
 from examples.mlperf.model_train import eval_step_bert
@@ -27,14 +27,13 @@ if __name__ == "__main__":
     assert os.path.exists(os.path.join(BASEDIR, "eval", f"{i}.pkl")), \
       f"File {i}.pkl does not exist in {os.path.join(BASEDIR, 'eval')}"

-  required_files = ["checkpoint", "model.ckpt-28252.data-00000-of-00001", "model.ckpt-28252.index"]
+  required_files = ["checkpoint", "model.ckpt-28252.data-00000-of-00001", "model.ckpt-28252.index", "model.ckpt-28252.meta"]
   assert all(os.path.exists(os.path.join(INIT_CKPT_DIR, f)) for f in required_files), \
     f"Missing checkpoint files in INIT_CKPT_DIR: {required_files}"

   Tensor.training = False

-  model = get_mlperf_bert_model()
-  init_bert_from_checkpoint(model, INIT_CKPT_DIR)
+  model = get_mlperf_bert_model(INIT_CKPT_DIR)
+  # Test the actual loading of the checkpoint
   for _, x in get_state_dict(model).items():
     x.realize().to_(GPUS)
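
For reference, the new accuracy() path counts a masked-LM prediction only at positions whose label id is non-zero, replacing the old masked_lm_weights == 1.0 mask. A NumPy sketch of that reduction with made-up values:

import numpy as np

masked_lm_ids = np.array([[5, 7, 0, 0]])     # two real targets, two padded slots
predictions   = np.array([[5, 9, 0, 0]])     # argmax over the prediction logits
valid = masked_lm_ids != 0
print(((predictions == masked_lm_ids) & valid).sum() / valid.sum())  # 0.5: one of two real targets is correct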