
56 lines
2.6 KiB

# Test whether pretrained weights from first BERT pretraining phase have been loaded correctly
# Usage:
# 1. Download the BERT checkpoints with `wikipedia_download.py`
# Command: BASEDIR=/path/to/wiki python3 wikipedia_download.py
# 2. Run this script. (Adjust EVAL_BS and GPUS as needed)
# Command: EVAL_BEAM=4 DEFAULT_FLOAT=half GPUS=6 BASEDIR=/path/to/wiki python3 test/external/mlperf_bert/external_test_checkpoint_loading.py
import os
from tqdm import tqdm
from tinygrad.tensor import Tensor
from tinygrad.device import Device
from tinygrad.helpers import getenv
from tinygrad.nn.state import get_state_dict
from examples.mlperf.helpers import get_mlperf_bert_model, get_data_bert
from examples.mlperf.dataloader import batch_load_val_bert
from examples.mlperf.model_train import eval_step_bert
if __name__ == "__main__":
    # Script entry point: load the phase-1 BERT checkpoint, run the evaluation
    # loop over the validation batches and assert that the mean masked-LM
    # accuracy reaches the reference value quoted in the assertion below.

    # BASEDIR is both bound locally and written back to the environment
    # (presumably so the imported mlperf helpers resolve the same dataset
    # root -- verify against examples.mlperf).
    BASEDIR = os.environ["BASEDIR"] = getenv("BASEDIR", "/raid/datasets/wiki")
    # Directory containing the checkpoint files. Overridable via the
    # INIT_CKPT_DIR environment variable; defaults to BASEDIR, the location
    # the usage notes pass to `wikipedia_download.py`.
    INIT_CKPT_DIR = getenv("INIT_CKPT_DIR", BASEDIR)
    GPUS = [f"{Device.DEFAULT}:{i}" for i in range(getenv("GPUS", 1))]
    EVAL_BS = getenv("EVAL_BS", 4 * len(GPUS))
    # Ceiling division: number of batches of size EVAL_BS needed to cover
    # 10000 samples (presumably the size of the eval set -- TODO confirm).
    max_eval_steps = (10000 + EVAL_BS - 1) // EVAL_BS

    # Fail early, before any model is built, if evaluation shards are missing.
    for i in range(10):
        assert os.path.exists(os.path.join(BASEDIR, "eval", f"{i}.pkl")), \
            f"File {i}.pkl does not exist in {os.path.join(BASEDIR, 'eval')}"

    required_files = ["checkpoint", "model.ckpt-28252.data-00000-of-00001", "model.ckpt-28252.index", "model.ckpt-28252.meta"]
    assert all(os.path.exists(os.path.join(INIT_CKPT_DIR, f)) for f in required_files), \
        f"Missing checkpoint files in INIT_CKPT_DIR: {required_files}"

    Tensor.training = False
    model = get_mlperf_bert_model(INIT_CKPT_DIR)

    # Realize every weight and place it on the evaluation devices so the
    # parameters live on the same devices as the batches produced by
    # get_data_bert(GPUS, ...).
    for x in get_state_dict(model).values():
        x.realize().to_(GPUS)

    eval_accuracy = []
    eval_it = iter(batch_load_val_bert(EVAL_BS))

    for _ in tqdm(range(max_eval_steps), desc="Evaluating", total=max_eval_steps):
        eval_data = get_data_bert(GPUS, eval_it)
        eval_result: dict[str, Tensor] = eval_step_bert(
            model, eval_data["input_ids"], eval_data["segment_ids"], eval_data["input_mask"],
            eval_data["masked_lm_positions"], eval_data["masked_lm_ids"],
            eval_data["masked_lm_weights"], eval_data["next_sentence_labels"])

        mlm_accuracy = eval_result["masked_lm_accuracy"].numpy().item()
        # Record the per-batch accuracy; without this the mean below would
        # divide by zero.
        eval_accuracy.append(mlm_accuracy)

    # Unweighted mean of the per-batch masked-LM accuracies.
    total_lm_accuracy = sum(eval_accuracy) / len(eval_accuracy)

    assert total_lm_accuracy >= 0.34, "Checkpoint loaded incorrectly. Accuracy should be very close to 0.34085 as per MLPerf BERT README."
    print(f"Checkpoint loaded correctly. Accuracy of {total_lm_accuracy*100:.3f}% achieved. (Reference: 34.085%)")