# Preprocessing of downloaded text from Wikipedia for MLPerf BERT training
# This is a modified version of the original script:
# https://github.com/mlcommons/training/blob/master/language_model/tensorflow/bert/cleanup_scripts/create_pretraining_data.py
# ENV VARS:
#   MAX_SEQ_LENGTH          - Maximum sequence length
#   MAX_PREDICTIONS_PER_SEQ - Maximum number of masked LM predictions per sequence
#   RANDOM_SEED             - Random seed
#   DUPE_FACTOR             - Number of times to duplicate the input data with different masks
#   MASKED_LM_PROB          - Probability of masking a token
#   SHORT_SEQ_PROB          - Probability of picking a sequence shorter than MAX_SEQ_LENGTH
import os, sys, pickle, random, unicodedata
from pathlib import Path
import numpy as np
from tqdm import tqdm
from tqdm.contrib.concurrent import process_map
from tinygrad.helpers import diskcache, getenv

BASEDIR = getenv('BASEDIR', Path(__file__).parent / "wiki")

################### Tokenization #####################

def _is_whitespace(char:str) -> bool:
  if char == " " or char == "\t" or char == "\n" or char == "\r": return True
  return unicodedata.category(char) == "Zs"

def _is_control(char:str) -> bool:
  if char == "\t" or char == "\n" or char == "\r": return False
  return unicodedata.category(char).startswith("C")

def _is_punctuation(char:str) -> bool:
  # range(33, 48)   -> ! " # $ % & ' ( ) * + , - . /
  # range(58, 65)   -> : ; < = > ? @
  # range(91, 97)   -> [ \ ] ^ _ `
  # range(123, 127) -> { | } ~
  if (cp := ord(char)) in range(33, 48) or cp in range(58, 65) or cp in range(91, 97) or cp in range(123, 127): return True
  return unicodedata.category(char).startswith("P")

def _is_chinese_char(cp:int) -> bool:
  if ((cp >= 0x4E00 and cp <= 0x9FFF) or (cp >= 0x3400 and cp <= 0x4DBF) or
      (cp >= 0x20000 and cp <= 0x2A6DF) or (cp >= 0x2A700 and cp <= 0x2B73F) or
      (cp >= 0x2B740 and cp <= 0x2B81F) or (cp >= 0x2B820 and cp <= 0x2CEAF) or
      (cp >= 0xF900 and cp <= 0xFAFF) or (cp >= 0x2F800 and cp <= 0x2FA1F)):
    return True
  return False

def _run_split_on_punc(text:str) -> list[str]:
  if text in ("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]"): return [text]
  start_new_word = True
  output = []
  for i in range(len(text)):
    if _is_punctuation(char := text[i]):
      output.append([char])
      start_new_word = True
    else:
      if start_new_word: output.append([])
      start_new_word = False
      output[-1].append(char)
  return ["".join(x) for x in output]

def _run_strip_accents(text:str) -> str:
  output = []
  for char in unicodedata.normalize("NFD", text):
    if unicodedata.category(char) != "Mn": output.append(char)
  return "".join(output)

def _clean_text(text:str) -> str:
  output = []
  for char in text:
    if not ((cp := ord(char)) == 0 or cp == 0xfffd or _is_control(char)):
      output.append(" " if _is_whitespace(char) else char)
  return "".join(output)

def _tokenize_chinese_chars(text:str) -> str:
  output = []
  for char in text:
    cp = ord(char)
    if _is_chinese_char(cp):
      output.append(" ")
      output.append(char)
      output.append(" ")
    else:
      output.append(char)
  return "".join(output)

def whitespace_tokenize(text):
  if not (text := text.strip()): return []
  return text.split()

def _wordpiece_tokenize(text:str, vocab:dict[str, int]) -> list[str]:
  text = text.decode("utf-8", "ignore") if isinstance(text, bytes) else text
  output_tokens = []
  for token in text.strip().split():
    chars = list(token)
    if len(chars) > 200:
      output_tokens.append("[UNK]")
      continue
    is_bad = False
    start = 0
    sub_tokens = []
    while start < len(chars):
      end = len(chars)
      cur_substr = None
      while start < end:
        substr = "".join(chars[start:end])
        if start > 0: substr = "##" + substr
        if substr in vocab:
          cur_substr = substr
          break
        end -= 1
      if cur_substr is None:
        is_bad = True
        break
      sub_tokens.append(cur_substr)
      start = end
    if is_bad: output_tokens.append("[UNK]")
    else: output_tokens.extend(sub_tokens)
  return output_tokens

class Tokenizer:
  def __init__(self, vocab_file):
    self.vocab = {}
    with open(vocab_file) as f:
      for line in f:
        line = line.decode("utf-8", "ignore") if isinstance(line, bytes) else line
        if (token := line.strip()) and token not in self.vocab:
          self.vocab[token] = len(self.vocab)
    self.inv_vocab = {v: k for k, v in self.vocab.items()}

  def tokenize(self, text:str) -> list[str]:
    # BasicTokenizer
    split_tokens = []
    for token in whitespace_tokenize(_tokenize_chinese_chars(_clean_text(text.decode("utf-8", "ignore") if isinstance(text, bytes) else text))):
      split_tokens.extend(_run_split_on_punc(_run_strip_accents(token.lower())))
    split_tokens = " ".join(split_tokens).strip().split()
    # WordpieceTokenizer
    tokens = []
    for token in split_tokens:
      tokens.extend(_wordpiece_tokenize(token, self.vocab))
    return tokens

  def convert_tokens_to_ids(self, tokens:list[str]) -> list[int]: return [self.vocab[token] for token in tokens]
  def convert_ids_to_tokens(self, ids:list[int]) -> list[str]: return [self.inv_vocab[id] for id in ids]
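
# Illustrative usage of Tokenizer (a sketch only, nothing here is executed; it assumes the
# reference BERT uncased vocab.txt has been downloaded to BASEDIR). Basic tokenization
# lowercases, strips accents and splits punctuation, then wordpiece greedily matches the
# longest known subword, prefixing continuations with "##":
#   tok = Tokenizer(BASEDIR / "vocab.txt")
#   tok.tokenize("He was unaffable.")  # -> roughly ['he', 'was', 'un', '##aff', '##able', '.']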

##################### Feature transformation #####################

def truncate_seq_pair(tokens_a:list[str], tokens_b:list[str], max_num_tokens:int, rng:random.Random) -> None:
  while True:
    total_length = len(tokens_a) + len(tokens_b)
    if total_length <= max_num_tokens: break
    trunc_tokens = tokens_a if len(tokens_a) > len(tokens_b) else tokens_b
    assert len(trunc_tokens) >= 1
    # truncate from the front or the back at random to avoid biasing the position of the cut
    if rng.random() < 0.5: del trunc_tokens[0]
    else: trunc_tokens.pop()

def create_masked_lm_predictions(tokens:list[str], tokenizer:Tokenizer, rng:random.Random, vocab_words:list[str]) -> tuple[list[str], list[int], list[str]]:
  cand_indices = []
  for i, token in enumerate(tokens):
    if token == "[CLS]" or token == "[SEP]": continue
    cand_indices.append(i)
  rng.shuffle(cand_indices)

  output_tokens = list(tokens)
  num_to_predict = min(getenv('MAX_PREDICTIONS_PER_SEQ', 76), max(1, int(round(len(tokens) * getenv("MASKED_LM_PROB", 0.15)))))

  masked_lms = []
  covered_indices = set()
  for index in cand_indices:
    if len(masked_lms) >= num_to_predict: break
    if index in covered_indices: continue
    covered_indices.add(index)

    masked_token = None
    if rng.random() < 0.8:
      # 80% of the time, replace with [MASK]
      masked_token = "[MASK]"
    else:
      if rng.random() < 0.5:
        # 10% of the time, keep the original token
        masked_token = tokens[index]
      else:
        # 10% of the time, replace with a random word from the vocab
        masked_token = vocab_words[rng.randint(0, len(tokenizer.vocab) - 1)]

    output_tokens[index] = masked_token
    masked_lms.append((index, tokens[index]))

  masked_lms = sorted(masked_lms, key=lambda x: x[0])

  masked_lm_positions = []
  masked_lm_labels = []
  for p in masked_lms:
    masked_lm_positions.append(p[0])
    masked_lm_labels.append(p[1])
  return output_tokens, masked_lm_positions, masked_lm_labels
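
# Worked example of the masking budget above (derived from the defaults, not new behavior):
# with MAX_SEQ_LENGTH=512, MASKED_LM_PROB=0.15 and MAX_PREDICTIONS_PER_SEQ=76, a full-length
# sequence yields num_to_predict = min(76, round(512 * 0.15)) = min(76, 77) = 76 masked
# positions; [CLS] and [SEP] are never candidates, but they do count toward len(tokens).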

def create_instances_from_document(rng:random.Random, tokenizer:Tokenizer, doc:list[str], di:int, documents:list[list[str]]) -> list[dict]:
  max_num_tokens = getenv('MAX_SEQ_LENGTH', 512) - 3  # [CLS] + 2 * [SEP]

  target_seq_length = max_num_tokens
  if rng.random() < getenv("SHORT_SEQ_PROB", 0.1):
    target_seq_length = rng.randint(2, max_num_tokens)

  instances = []
  current_chunk = []
  current_length = 0

  i = 0
  while i < len(doc):
    segment = doc[i]
    current_chunk.append(segment)
    current_length += len(segment)
    if i == len(doc) - 1 or current_length >= target_seq_length:
      if current_chunk:
        # `a_end` is how many segments from `current_chunk` go into the A (first) sentence
        a_end = 1
        if len(current_chunk) >= 2:
          a_end = rng.randint(1, len(current_chunk) - 1)

        tokens_a = []
        for j in range(a_end):
          tokens_a.extend(current_chunk[j])

        tokens_b = []
        is_random_next = False
        if len(current_chunk) == 1 or rng.random() < 0.5:
          # random next sentence: take the B sentence from a different, randomly chosen document
          is_random_next = True
          target_b_length = target_seq_length - len(tokens_a)

          for _ in range(10):
            random_document_index = rng.randint(0, len(documents) - 1)
            if random_document_index != di: break

          random_document = documents[random_document_index]
          random_start = rng.randint(0, len(random_document) - 1)
          for j in range(random_start, len(random_document)):
            tokens_b.extend(random_document[j])
            if len(tokens_b) >= target_b_length: break

          # put back the segments that were not used for A so they are not lost
          num_unused_segments = len(current_chunk) - a_end
          i -= num_unused_segments
        else:
          # actual next sentence: B is the remainder of the current chunk
          is_random_next = False
          for j in range(a_end, len(current_chunk)):
            tokens_b.extend(current_chunk[j])

        truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng)

        assert len(tokens_a) >= 1
        assert len(tokens_b) >= 1

        tokens = []
        segment_ids = []

        tokens.append("[CLS]")
        segment_ids.append(0)
        for token in tokens_a:
          tokens.append(token)
          segment_ids.append(0)
        tokens.append("[SEP]")
        segment_ids.append(0)

        for token in tokens_b:
          tokens.append(token)
          segment_ids.append(1)
        tokens.append("[SEP]")
        segment_ids.append(1)

        tokens, masked_lm_positions, masked_lm_labels = create_masked_lm_predictions(tokens, tokenizer, rng, list(tokenizer.vocab.keys()))
        instances.append({
          "tokens": tokens,
          "segment_ids": segment_ids,
          "masked_lm_positions": masked_lm_positions,
          "masked_lm_labels": masked_lm_labels,
          "is_random_next": is_random_next
        })
      current_chunk = []
      current_length = 0
    i += 1
  return instances

def get_documents(rng:random.Random, tokenizer:Tokenizer, fn:str) -> list[list[str]]:
  documents = [[]]
  with open(BASEDIR / fn) as f:
    for line in f.readlines():
      if not (line := line.decode("utf-8", "ignore") if isinstance(line, bytes) else line): break
      if not (line := line.strip()): documents.append([])
      if (tokens := tokenizer.tokenize(line)): documents[-1].append(tokens)
  documents = [x for x in documents if x]
  rng.shuffle(documents)
  return documents

def get_instances(rng:random.Random, tokenizer:Tokenizer, documents:list[list[str]]) -> list[dict]:
  instances = []
  for _ in range(getenv('DUPE_FACTOR', 10)):
    for di, doc in enumerate(documents):
      instances.extend(create_instances_from_document(rng, tokenizer, doc, di, documents))
  rng.shuffle(instances)
  return instances

def instance_to_features(instance:dict, tokenizer:Tokenizer) -> dict:
  input_ids = tokenizer.convert_tokens_to_ids(instance["tokens"])
  input_mask = [1] * len(input_ids)
  segment_ids = instance["segment_ids"]
  max_seq_length = getenv('MAX_SEQ_LENGTH', 512)
  assert len(input_ids) <= max_seq_length

  while len(input_ids) < max_seq_length:
    input_ids.append(0)
    input_mask.append(0)
    segment_ids.append(0)

  assert len(input_ids) == max_seq_length
  assert len(input_mask) == max_seq_length
  assert len(segment_ids) == max_seq_length

  masked_lm_positions = instance["masked_lm_positions"]
  masked_lm_ids = tokenizer.convert_tokens_to_ids(instance["masked_lm_labels"])
  masked_lm_weights = [1.0] * len(masked_lm_ids)

  while len(masked_lm_positions) < getenv("MAX_PREDICTIONS_PER_SEQ", 76):
    masked_lm_positions.append(0)
    masked_lm_ids.append(0)
    masked_lm_weights.append(0.0)

  next_sentence_label = 1 if instance["is_random_next"] else 0

  return {
    "input_ids": np.expand_dims(np.array(input_ids, dtype=np.int32), 0),
    "input_mask": np.expand_dims(np.array(input_mask, dtype=np.int32), 0),
    "segment_ids": np.expand_dims(np.array(segment_ids, dtype=np.int32), 0),
    "masked_lm_positions": np.expand_dims(np.array(masked_lm_positions, dtype=np.int32), 0),
    "masked_lm_ids": np.expand_dims(np.array(masked_lm_ids, dtype=np.int32), 0),
    "masked_lm_weights": np.expand_dims(np.array(masked_lm_weights, dtype=np.float32), 0),
    "next_sentence_labels": np.expand_dims(np.array([next_sentence_label], dtype=np.int32), 0),
  }
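
# Shapes of one feature dict returned above (with the default env vars, so N=512 and P=76):
#   input_ids, input_mask, segment_ids  -> (1, N) int32
#   masked_lm_positions, masked_lm_ids  -> (1, P) int32
#   masked_lm_weights                   -> (1, P) float32
#   next_sentence_labels                -> (1, 1) int32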

def process_part(part:int):
  tokenizer = Tokenizer(getenv("BASEDIR", Path(__file__).parent / "wiki") / "vocab.txt")
  os.makedirs(BASEDIR / "train" / str(part), exist_ok=True)
  for i, feature_batch in enumerate(process_iterate(tokenizer, val=False, part=part)):
    with open(BASEDIR / f"train/{str(part)}/{part}_{i}_of_{len(feature_batch)}.pkl", "wb") as f:
      pickle.dump(feature_batch, f)

def process_iterate(tokenizer:Tokenizer, val:bool=False, part:int=0) -> list[dict]:
  # Convert raw text to masked NSP samples
  rng = random.Random(getenv('RANDOM_SEED', 12345))
  if val:
    tqdm.write("Getting samples from dataset")
    documents = get_documents(rng, tokenizer, "results4/eval.txt")
    instances = get_instances(rng, tokenizer, documents)

    tqdm.write(f"There are {len(instances)} samples in the dataset")
    tqdm.write("Picking 10000 samples")
    pick_ratio = len(instances) / 10000

    picks = [instance_to_features(instances[int(inst*pick_ratio)], tokenizer) for inst in range(10000)]
    for batch in range(10):
      yield picks[batch*1000:(batch+1)*1000]
  else:
    documents = get_documents(rng, tokenizer, f"results4/part-{part:05d}-of-00500")
    instances = get_instances(rng, tokenizer, documents)

    while len(instances) > 0:
      batch_size = min(1000, len(instances))  # We batch 1000 samples to one file
      batch = instances[:batch_size]
      del instances[:batch_size]
      yield [instance_to_features(instance, tokenizer) for instance in batch]

##################### Load files #####################

def get_wiki_val_files(): return sorted(list((BASEDIR / "eval/").glob("*.pkl")))

@diskcache
def get_wiki_train_files(): return sorted(list((BASEDIR / "train/").glob("*/*.pkl")))

if __name__ == "__main__":
  tokenizer = Tokenizer(getenv("BASEDIR", Path(__file__).parent / "wiki") / "vocab.txt")

  assert len(sys.argv) > 1, "Usage: python wikipedia.py pre-eval|pre-train [part]|all"
  if sys.argv[1] == "pre-eval":
    # Generate 10000 eval samples
    os.makedirs(BASEDIR / "eval", exist_ok=True)
    for i, feature_batch in tqdm(enumerate(process_iterate(tokenizer, val=True)), total=10):
      with open(BASEDIR / f"eval/{i}.pkl", "wb") as f:
        pickle.dump(feature_batch, f)
  elif sys.argv[1] == "pre-train":
    os.makedirs(BASEDIR / "train", exist_ok=True)
    if sys.argv[2] == "all":
      # Use all 500 parts for training generation
      process_map(process_part, [part for part in range(500)], max_workers=getenv('NUM_WORKERS', min(os.cpu_count(), 32)), chunksize=1)
    else:
      # Use a specific part for training generation
      part = int(sys.argv[2])
      os.makedirs(BASEDIR / "train" / str(part), exist_ok=True)
      for i, feature_batch in tqdm(enumerate(process_iterate(tokenizer, val=False, part=part))):
        with open(BASEDIR / f"train/{str(part)}/{part}_{i}_of_{len(feature_batch)}.pkl", "wb") as f:
          pickle.dump(feature_batch, f)
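
# Consuming the output (an illustrative sketch, not invoked by this script): every generated
# .pkl file holds a list of up to 1000 feature dicts, so a training data loader could do
#   for fn in get_wiki_train_files():
#     with open(fn, "rb") as f:
#       for feature in pickle.load(f):
#         ...  # feature["input_ids"], feature["masked_lm_positions"], etc. are (1, N) numpy arrays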