tinygrad/extra/datasets/wikipedia.py

399 lines
14 KiB
Python

# Preprocessing of downloaded text from Wikipedia for MLPerf BERT training
# This is a modified version of the original script:
# https://github.com/mlcommons/training/blob/master/language_model/tensorflow/bert/cleanup_scripts/create_pretraining_data.py
# ENV VARS:
# MAX_SEQ_LENGTH - Maximum sequence length
# MAX_PREDICTIONS_PER_SEQ - Maximum number of masked LM predictions per sequence
# RANDOM_SEED - Random seed
# DUPE_FACTOR - Number of times to duplicate the input data with different masks
# MASKED_LM_PROB - Probability of masking a token
# SHORT_SEQ_PROB - Probability of picking a sequence shorter than MAX_SEQ_LENGTH
import os, sys, pickle, random, unicodedata
from pathlib import Path
import numpy as np
from tqdm import tqdm
from tqdm.contrib.concurrent import process_map
from tinygrad.helpers import diskcache, getenv
BASEDIR = getenv('BASEDIR', Path(__file__).parent / "wiki")
################### Tokenization #####################
def _is_whitespace(char:str) -> bool:
if char == " " or char == "\t" or char == "\n" or char == "\r":
return True
return unicodedata.category(char) == "Zs"
def _is_control(char:str) -> bool:
if char == "\t" or char == "\n" or char == "\r":
return False
return unicodedata.category(char).startswith("C")
def _is_punctuation(char:str) -> bool:
# range(33, 48) -> ! " # $ % & ' ( ) * + , - . /
# range(58, 65) -> : ; < = > ? @
# range(91, 97) -> [ \ ] ^ _
# range(123, 127) -> { | } ~
if (cp := ord(char)) in range(33, 48) or cp in range(58, 65) or cp in range(91, 97) or cp in range(123, 127):
return True
return unicodedata.category(char).startswith("P")
def _is_chinese_char(cp:int) -> bool:
if ((cp >= 0x4E00 and cp <= 0x9FFF) or
(cp >= 0x3400 and cp <= 0x4DBF) or
(cp >= 0x20000 and cp <= 0x2A6DF) or
(cp >= 0x2A700 and cp <= 0x2B73F) or
(cp >= 0x2B740 and cp <= 0x2B81F) or
(cp >= 0x2B820 and cp <= 0x2CEAF) or
(cp >= 0xF900 and cp <= 0xFAFF) or
(cp >= 0x2F800 and cp <= 0x2FA1F)):
return True
return False
def _run_split_on_punc(text:str) -> list[str]:
if text in ("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]"):
return [text]
start_new_word = True
output = []
for i in range(len(text)):
if _is_punctuation(char := text[i]):
output.append([char])
start_new_word = True
else:
if start_new_word:
output.append([])
start_new_word = False
output[-1].append(char)
return ["".join(x) for x in output]
def _run_strip_accents(text:str) -> str:
output = []
for char in unicodedata.normalize("NFD", text):
if unicodedata.category(char) != "Mn":
output.append(char)
return "".join(output)
def _clean_text(text:str) -> str:
output = []
for char in text:
if not ((cp := ord(char)) == 0 or cp == 0xfffd or _is_control(char)):
output.append(" " if _is_whitespace(char) else char)
return "".join(output)
def _tokenize_chinese_chars(text:str) -> str:
output = []
for char in text:
cp = ord(char)
if _is_chinese_char(cp):
output.append(" ")
output.append(char)
output.append(" ")
else:
output.append(char)
return "".join(output)
def whitespace_tokenize(text):
if not (text := text.strip()): return []
return text.split()
def _wordpiece_tokenize(text:str, vocab:dict[str, int]) -> list[str]:
text = text.decode("utf-8", "ignore") if isinstance(text, bytes) else text
output_tokens = []
for token in text.strip().split():
chars = list(token)
if len(chars) > 200:
output_tokens.append("[UNK]")
continue
is_bad = False
start = 0
sub_tokens = []
while start < len(chars):
end = len(chars)
cur_substr = None
while start < end:
substr = "".join(chars[start:end])
if start > 0: substr = "##" + substr
if substr in vocab:
cur_substr = substr
break
end -= 1
if cur_substr is None:
is_bad = True
break
sub_tokens.append(cur_substr)
start = end
if is_bad: output_tokens.append("[UNK]")
else: output_tokens.extend(sub_tokens)
return output_tokens
class Tokenizer:
def __init__(self, vocab_file):
self.vocab = {}
with open(vocab_file) as f:
for line in f:
line = line.decode("utf-8", "ignore") if isinstance(line, bytes) else line
if (token := line.strip()) and token not in self.vocab: self.vocab[token] = len(self.vocab)
self.inv_vocab = {v: k for k, v in self.vocab.items()}
def tokenize(self, text:str) -> list[str]:
# BasicTokenizer
split_tokens = []
for token in whitespace_tokenize(_tokenize_chinese_chars(_clean_text(text.decode("utf-8", "ignore") if isinstance(text, bytes) else text))):
split_tokens.extend(_run_split_on_punc(_run_strip_accents(token.lower())))
split_tokens = " ".join(split_tokens).strip().split()
# WordpieceTokenizer
tokens = []
for token in split_tokens:
tokens.extend(_wordpiece_tokenize(token, self.vocab))
return tokens
def convert_tokens_to_ids(self, tokens:list[str]) -> list[int]: return [self.vocab[token] for token in tokens]
def convert_ids_to_tokens(self, ids:list[int]) -> list[str]: return [self.inv_vocab[id] for id in ids]
##################### Feature transformation #####################
def truncate_seq_pair(tokens_a:list[str], tokens_b:list[str], max_num_tokens:int, rng:random.Random) -> None:
while True:
total_length = len(tokens_a) + len(tokens_b)
if total_length <= max_num_tokens:
break
trunc_tokens = tokens_a if len(tokens_a) > len(tokens_b) else tokens_b
assert len(trunc_tokens) >= 1
if rng.random() < 0.5:
del trunc_tokens[0]
else:
trunc_tokens.pop()
def create_masked_lm_predictions(tokens:list[str], tokenizer:Tokenizer, rng:random.Random, vocab_words:list[str]) -> tuple[list[str], list[int], list[str]]:
cand_indices = []
for i, token in enumerate(tokens):
if token == "[CLS]" or token == "[SEP]":
continue
cand_indices.append(i)
rng.shuffle(cand_indices)
output_tokens = list(tokens)
num_to_predict = min(getenv('MAX_PREDICTIONS_PER_SEQ', 76), max(1, int(round(len(tokens) * getenv("MASKED_LM_PROB", 0.15)))))
masked_lms = []
covered_indices = set()
for index in cand_indices:
if len(masked_lms) >= num_to_predict:
break
if index in covered_indices:
continue
covered_indices.add(index)
masked_token = None
if rng.random() < 0.8:
masked_token = "[MASK]"
else:
if rng.random() < 0.5:
masked_token = tokens[index]
else:
masked_token = vocab_words[rng.randint(0, len(tokenizer.vocab) - 1)]
output_tokens[index] = masked_token
masked_lms.append((index, tokens[index]))
masked_lms = sorted(masked_lms, key=lambda x: x[0])
masked_lm_positions = []
masked_lm_labels = []
for p in masked_lms:
masked_lm_positions.append(p[0])
masked_lm_labels.append(p[1])
return output_tokens, masked_lm_positions, masked_lm_labels
def create_instances_from_document(rng:random.Random, tokenizer:Tokenizer, doc:list[str], di:int, documents:list[list[str]]) -> list[dict]:
max_num_tokens = getenv('MAX_SEQ_LENGTH', 512) - 3 # [CLS] + 2 * [SEP]
target_seq_length = max_num_tokens
if rng.random() < getenv("SHORT_SEQ_PROB", 0.1):
target_seq_length = rng.randint(2, max_num_tokens)
instances = []
current_chunk = []
current_length = 0
i = 0
while i < len(doc):
segment = doc[i]
current_chunk.append(segment)
current_length += len(segment)
if i == len(doc) - 1 or current_length >= target_seq_length:
if current_chunk:
a_end = 1
if len(current_chunk) >= 2:
a_end = rng.randint(1, len(current_chunk) - 1)
tokens_a = []
for j in range(a_end):
tokens_a.extend(current_chunk[j])
tokens_b = []
is_random_next = False
if len(current_chunk) == 1 or rng.random() < 0.5:
is_random_next = True
target_b_length = target_seq_length - len(tokens_a)
for _ in range(10):
random_document_index = rng.randint(0, len(documents) - 1)
if random_document_index != di:
break
random_document = documents[random_document_index]
random_start = rng.randint(0, len(random_document) - 1)
for j in range(random_start, len(random_document)):
tokens_b.extend(random_document[j])
if len(tokens_b) >= target_b_length:
break
num_unused_segments = len(current_chunk) - a_end
i -= num_unused_segments
else:
is_random_next = False
for j in range(a_end, len(current_chunk)):
tokens_b.extend(current_chunk[j])
truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng)
assert len(tokens_a) >= 1
assert len(tokens_b) >= 1
tokens = []
segment_ids = []
tokens.append("[CLS]")
segment_ids.append(0)
for token in tokens_a:
tokens.append(token)
segment_ids.append(0)
tokens.append("[SEP]")
segment_ids.append(0)
for token in tokens_b:
tokens.append(token)
segment_ids.append(1)
tokens.append("[SEP]")
segment_ids.append(1)
tokens, masked_lm_positions, masked_lm_labels = create_masked_lm_predictions(tokens, tokenizer, rng, list(tokenizer.vocab.keys()))
instances.append({
"tokens": tokens,
"segment_ids": segment_ids,
"masked_lm_positions": masked_lm_positions,
"masked_lm_labels": masked_lm_labels,
"is_random_next": is_random_next
})
current_chunk = []
current_length = 0
i += 1
return instances
def get_documents(rng:random.Random, tokenizer:Tokenizer, fn:str) -> list[list[str]]:
documents = [[]]
with open(BASEDIR / fn) as f:
for line in f.readlines():
if not (line := line.decode("utf-8", "ignore") if isinstance(line, bytes) else line): break
if not (line := line.strip()): documents.append([])
if (tokens := tokenizer.tokenize(line)): documents[-1].append(tokens)
documents = [x for x in documents if x]
rng.shuffle(documents)
return documents
def get_instances(rng:random.Random, tokenizer:Tokenizer, documents:list[list[str]]) -> list[dict]:
instances = []
for _ in range(getenv('DUPE_FACTOR', 10)):
for di, doc in enumerate(documents):
instances.extend(create_instances_from_document(rng, tokenizer, doc, di, documents))
rng.shuffle(instances)
return instances
def instance_to_features(instance:dict, tokenizer:Tokenizer) -> dict:
input_ids = tokenizer.convert_tokens_to_ids(instance["tokens"])
input_mask = [1] * len(input_ids)
segment_ids = instance["segment_ids"]
max_seq_length = getenv('MAX_SEQ_LENGTH', 512)
assert len(input_ids) <= max_seq_length
while len(input_ids) < max_seq_length:
input_ids.append(0)
input_mask.append(0)
segment_ids.append(0)
assert len(input_ids) == max_seq_length
assert len(input_mask) == max_seq_length
assert len(segment_ids) == max_seq_length
masked_lm_positions = instance["masked_lm_positions"]
masked_lm_ids = tokenizer.convert_tokens_to_ids(instance["masked_lm_labels"])
masked_lm_weights = [1.0] * len(masked_lm_ids)
while len(masked_lm_positions) < getenv("MAX_PREDICTIONS_PER_SEQ", 76):
masked_lm_positions.append(0)
masked_lm_ids.append(0)
masked_lm_weights.append(0.0)
next_sentence_label = 1 if instance["is_random_next"] else 0
return {
"input_ids": np.expand_dims(np.array(input_ids, dtype=np.int32), 0),
"input_mask": np.expand_dims(np.array(input_mask, dtype=np.int32), 0),
"segment_ids": np.expand_dims(np.array(segment_ids, dtype=np.int32), 0),
"masked_lm_positions": np.expand_dims(np.array(masked_lm_positions, dtype=np.int32), 0),
"masked_lm_ids": np.expand_dims(np.array(masked_lm_ids, dtype=np.int32), 0),
"masked_lm_weights": np.expand_dims(np.array(masked_lm_weights, dtype=np.float32), 0),
"next_sentence_labels": np.expand_dims(np.array([next_sentence_label], dtype=np.int32), 0),
}
def process_part(part:int):
tokenizer = Tokenizer(getenv("BASEDIR", Path(__file__).parent / "wiki") / "vocab.txt")
os.makedirs(BASEDIR / "train", exist_ok=True)
if os.path.exists(BASEDIR / f"train/{str(part)}.pkl"): return
features = get_features_from_part(tokenizer, val=False, part=part)
with open(BASEDIR / f"train/{str(part)}.pkl", "wb") as f:
pickle.dump(features, f)
def get_features_from_part(tokenizer:Tokenizer, val:bool=False, part:int=0) -> list[dict]: # Convert raw text to masked NSP samples
rng = random.Random(getenv('RANDOM_SEED', 12345))
if val:
tqdm.write("Getting samples from dataset")
documents = get_documents(rng, tokenizer, "results4/eval.txt")
instances = get_instances(rng, tokenizer, documents)
tqdm.write(f"There are {len(instances)} samples in the dataset")
tqdm.write(f"Picking 10000 samples")
pick_ratio = len(instances) / 10000
return [instance_to_features(instances[int(inst*pick_ratio)], tokenizer) for inst in range(10000)]
else:
documents = get_documents(rng, tokenizer, f"results4/part-{part:05d}-of-00500")
instances = get_instances(rng, tokenizer, documents)
return [instance_to_features(instance, tokenizer) for instance in instances]
##################### Load files #####################
@diskcache
def get_wiki_train_files(): return sorted(list((BASEDIR / "train/").glob("*.pkl")))
if __name__ == "__main__":
tokenizer = Tokenizer(getenv("BASEDIR", Path(__file__).parent / "wiki") / "vocab.txt")
assert len(sys.argv) > 1, "Usage: python wikipedia.py pre-eval|pre-train [part]|all"
if sys.argv[1] == "pre-eval": # Generate 10000 eval samples
with open(BASEDIR / "eval.pkl", "wb") as f:
pickle.dump(get_features_from_part(tokenizer, val=True), f)
elif sys.argv[1] == "pre-train":
if sys.argv[2] == "all": # Use all 500 parts for training generation
process_map(process_part, [part for part in range(500)], max_workers=getenv('NUM_WORKERS', min(os.cpu_count(), 32)), chunksize=1)
else: # Use a specific part for training generation
part = sys.argv[2]
print(f"Processing part {part}...")
process_part(int(part))