Wikipedia download script for MLPerf BERT training (#4202)

* wikipedia download script

* add link

* checksum ValueError

* ops
Elias Wahl 2024-04-17 22:34:57 +02:00 committed by GitHub
parent f75020a903
commit 6eef8ee22a
3 changed files with 55 additions and 1 deletion

.gitignore

@@ -24,6 +24,7 @@ disassemblers/cuda_ioctl_sniffer
extra/datasets/cifar-10-python.tar.gz
extra/datasets/librispeech/
extra/datasets/imagenet/
extra/datasets/wiki/
extra/datasets/kits19
extra/datasets/kits19/
extra/datasets/squad/


@@ -175,4 +175,4 @@ BS | [8, 16, 32, 64, 128] | batch size to use
Variable | Possible Value(s) | Description
---|---|---
IMGNET_TRAIN | [1] | also download the imagenet training data
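These switches are plain environment variables read through tinygrad's getenv helper, which returns the default (0) when the variable is unset. A minimal sketch of the gating pattern, which the WIKI_TRAIN and VERIFY_CHECKSUM flags in the new script below follow as well (download_training_data is a hypothetical stand-in, shown only for illustration):

  from tinygrad.helpers import getenv

  def download_training_data():  # hypothetical stand-in for the real download step
    print("downloading training data...")

  if getenv("IMGNET_TRAIN", 0):  # unset or 0 skips the step; IMGNET_TRAIN=1 enables it
    download_training_data()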


@@ -0,0 +1,53 @@
# pip install gdown
# Downloads the 2020 wikipedia dataset used for MLPerf BERT training
import os, hashlib
from pathlib import Path
import tarfile
import gdown
from tqdm import tqdm
from tinygrad.helpers import getenv

def gdrive_download(url:str, path:str):
  if not os.path.exists(path): gdown.download(url, path)  # skip files that have already been downloaded

def wikipedia_uncompress_and_extract(file:str, path:str, small:bool=False):
  if not os.path.exists(os.path.join(path, "results4")):  # skip if the archive has already been extracted
    print("Uncompressing and extracting file...")
    with tarfile.open(file, 'r:gz') as tar:
      if small:  # extract without a progress bar
        for member in tar.getmembers(): tar.extract(path=path, member=member)
      else:
        for member in tqdm(iterable=tar.getmembers(), total=len(tar.getmembers())): tar.extract(path=path, member=member)
    os.remove(file)  # delete the archive once extraction has finished
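# note: results_text.tar.gz is expected to unpack into a "results4" directory; both the existence check at the
# top of wikipedia_uncompress_and_extract and the verify_checksum call further down rely on that layout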

def verify_checksum(folder_path:str, checksum_path:str):
  print("Verifying checksums...")
  with open(checksum_path, 'r') as f:
    for line in f:
      expected_checksum, folder_name = line.split()
      file_path = os.path.join(folder_path, folder_name[2:]) # remove './' from the start of the folder name
      hasher = hashlib.md5()
      with open(file_path, 'rb') as data:  # separate handle so the checksum file iterator above is not shadowed
        for buf in iter(lambda: data.read(4096), b''): hasher.update(buf)
      if hasher.hexdigest() != expected_checksum:
        raise ValueError(f"Checksum does not match for file: {file_path}")
  print("All checksums match.")

def download_wikipedia(path:str):
  # Links from: https://github.com/mlcommons/training/blob/master/language_model/tensorflow/bert/dataset.md
  os.makedirs(path, exist_ok=True)
  gdrive_download("https://drive.google.com/uc?id=1fbGClQMi2CoMv7fwrwTC5YYPooQBdcFW", os.path.join(path, "bert_config.json"))
  gdrive_download("https://drive.google.com/uc?id=1USK108J6hMM_d27xCHi738qBL8_BT1u1", os.path.join(path, "vocab.txt"))
  gdrive_download("https://drive.google.com/uc?id=1pJhVkACK3p_7Uc-1pAzRaOXodNeeHZ7F", os.path.join(path, "model.ckpt-28252.data-00000-of-00001"))
  gdrive_download("https://drive.google.com/uc?id=1oVBgtSxkXC9rH2SXJv85RXR9-WrMPy-Q", os.path.join(path, "model.ckpt-28252.index"))
  with open(os.path.join(path, "checkpoint"), "w") as f: f.write('model_checkpoint_path: "model.ckpt-28252"\nall_model_checkpoint_paths: "model.ckpt-28252"')
  if getenv("WIKI_TRAIN", 0):
    gdrive_download("https://drive.google.com/uc?id=1tmMgLwoBvbEJEHXh77sqrXYw5RpqT8R_", os.path.join(path, "bert_reference_results_text_md5.txt"))
    gdrive_download("https://drive.google.com/uc?id=14xV2OUGSQDG_yDBrmbSdcDC-QGeqpfs_", os.path.join(path, "results_text.tar.gz"))
    wikipedia_uncompress_and_extract(os.path.join(path, "results_text.tar.gz"), path)
    if getenv("VERIFY_CHECKSUM", 0):
      verify_checksum(os.path.join(path, "results4"), os.path.join(path, "bert_reference_results_text_md5.txt"))

if __name__ == "__main__":
  download_wikipedia(getenv("BASEDIR", os.path.join(Path(__file__).parent, "wiki")))
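A sketch of how the script might be run (the script name and paths below are illustrative; BASEDIR, WIKI_TRAIN and VERIFY_CHECKSUM are all read through getenv, so they can be set per invocation):

  BASEDIR=/data/wiki python3 <path to this script>                  # config, vocab and checkpoint only
  WIKI_TRAIN=1 python3 <path to this script>                        # additionally fetch the training text
  WIKI_TRAIN=1 VERIFY_CHECKSUM=1 python3 <path to this script>      # ... and verify its md5 checksums afterwards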