mirror of https://github.com/commaai/tinygrad.git
56 lines
2.6 KiB
Python
56 lines
2.6 KiB
Python
# Test whether pretrained weights from first BERT pretraining phase have been loaded correctly
|
|
# Usage:
|
|
# 1. Download the BERT checkoints with `wikipedia_download.py`
|
|
# Command: BASEDIR=/path/to/wiki python3 wikipedia_download.py
|
|
# 2. Run this script. (Adjust EVAL_BS and GPUS as needed)
|
|
# Command: EVAL_BEAM=4 DEFAULT_FLOAT=half GPUS=6 BASEDIR=/path/to/wiki python3 test/external/mlperf_bert/external_test_checkpoint_loading.py
|
|
import os
|
|
from tqdm import tqdm
|
|
|
|
from tinygrad.tensor import Tensor
|
|
from tinygrad.device import Device
|
|
from tinygrad.helpers import getenv
|
|
from tinygrad.nn.state import get_state_dict
|
|
|
|
from examples.mlperf.helpers import get_mlperf_bert_model, get_data_bert
|
|
from examples.mlperf.dataloader import batch_load_val_bert
|
|
from examples.mlperf.model_train import eval_step_bert
|
|
|
|
if __name__ == "__main__":
|
|
BASEDIR = os.environ["BASEDIR"] = getenv("BASEDIR", "/raid/datasets/wiki")
|
|
INIT_CKPT_DIR = getenv("INIT_CKPT_DIR", BASEDIR)
|
|
GPUS = [f"{Device.DEFAULT}:{i}" for i in range(getenv("GPUS", 1))]
|
|
EVAL_BS = getenv("EVAL_BS", 4 * len(GPUS))
|
|
max_eval_steps = (10000 + EVAL_BS - 1) // EVAL_BS
|
|
|
|
for i in range(10):
|
|
assert os.path.exists(os.path.join(BASEDIR, "eval", f"{i}.pkl")), \
|
|
f"File {i}.pkl does not exist in {os.path.join(BASEDIR, 'eval')}"
|
|
|
|
required_files = ["checkpoint", "model.ckpt-28252.data-00000-of-00001", "model.ckpt-28252.index", "model.ckpt-28252.meta"]
|
|
assert all(os.path.exists(os.path.join(INIT_CKPT_DIR, f)) for f in required_files), \
|
|
f"Missing checkpoint files in INIT_CKPT_DIR: {required_files}"
|
|
|
|
Tensor.training = False
|
|
|
|
model = get_mlperf_bert_model(INIT_CKPT_DIR)
|
|
|
|
for _, x in get_state_dict(model).items():
|
|
x.realize().to_(GPUS)
|
|
|
|
eval_accuracy = []
|
|
eval_it = iter(batch_load_val_bert(EVAL_BS))
|
|
|
|
for _ in tqdm(range(max_eval_steps), desc="Evaluating", total=max_eval_steps):
|
|
eval_data = get_data_bert(GPUS, eval_it)
|
|
eval_result: dict[str, Tensor] = eval_step_bert(model, eval_data["input_ids"], eval_data["segment_ids"], eval_data["input_mask"], \
|
|
eval_data["masked_lm_positions"], eval_data["masked_lm_ids"], \
|
|
eval_data["masked_lm_weights"], eval_data["next_sentence_labels"])
|
|
|
|
mlm_accuracy = eval_result["masked_lm_accuracy"].numpy().item()
|
|
eval_accuracy.append(mlm_accuracy)
|
|
|
|
total_lm_accuracy = sum(eval_accuracy) / len(eval_accuracy)
|
|
assert total_lm_accuracy >= 0.34, "Checkpoint loaded incorrectly. Accuracy should be very close to 0.34085 as per MLPerf BERT README."
|
|
print(f"Checkpoint loaded correctly. Accuracy of {total_lm_accuracy*100:.3f}% achieved. (Reference: 34.085%)")
|