mirror of https://github.com/commaai/tinygrad.git
412 lines
15 KiB
Python
412 lines
15 KiB
Python
# Preprocessing of downloaded text from Wikipedia for MLPerf BERT training
|
|
# This is a modified version of the original script:
|
|
# https://github.com/mlcommons/training/blob/master/language_model/tensorflow/bert/cleanup_scripts/create_pretraining_data.py
|
|
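# ENV VARS:
# BASEDIR - Data directory (default: the `wiki` directory next to this file)
# MAX_SEQ_LENGTH - Maximum sequence length (default 512)
# MAX_PREDICTIONS_PER_SEQ - Maximum number of masked LM predictions per sequence (default 76)
# RANDOM_SEED - Random seed (default 12345)
# DUPE_FACTOR - Number of times to duplicate the input data with different masks (default 10)
# MASKED_LM_PROB - Probability of masking a token (default 0.15)
# SHORT_SEQ_PROB - Probability of picking a sequence shorter than MAX_SEQ_LENGTH (default 0.1)
# NUM_WORKERS - Number of worker processes used by "pre-train all" (default min(CPU count, 32))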
|
|
|
|
import os, sys, pickle, random, unicodedata
from pathlib import Path
from typing import Iterator

import numpy as np
from tqdm import tqdm
from tqdm.contrib.concurrent import process_map

from tinygrad.helpers import diskcache, getenv
|
|
|
|
# Root of the Wikipedia data tree; override with the BASEDIR env var, otherwise the `wiki` directory beside this file.
BASEDIR = getenv('BASEDIR', Path(__file__).parent.joinpath("wiki"))
|
|
|
|
################### Tokenization #####################
|
|
|
|
def _is_whitespace(char:str) -> bool:
|
|
if char == " " or char == "\t" or char == "\n" or char == "\r":
|
|
return True
|
|
return unicodedata.category(char) == "Zs"
|
|
|
|
def _is_control(char:str) -> bool:
|
|
if char == "\t" or char == "\n" or char == "\r":
|
|
return False
|
|
return unicodedata.category(char).startswith("C")
|
|
|
|
def _is_punctuation(char:str) -> bool:
|
|
# range(33, 48) -> ! " # $ % & ' ( ) * + , - . /
|
|
# range(58, 65) -> : ; < = > ? @
|
|
# range(91, 97) -> [ \ ] ^ _
|
|
# range(123, 127) -> { | } ~
|
|
if (cp := ord(char)) in range(33, 48) or cp in range(58, 65) or cp in range(91, 97) or cp in range(123, 127):
|
|
return True
|
|
return unicodedata.category(char).startswith("P")
|
|
|
|
def _is_chinese_char(cp:int) -> bool:
|
|
if ((cp >= 0x4E00 and cp <= 0x9FFF) or
|
|
(cp >= 0x3400 and cp <= 0x4DBF) or
|
|
(cp >= 0x20000 and cp <= 0x2A6DF) or
|
|
(cp >= 0x2A700 and cp <= 0x2B73F) or
|
|
(cp >= 0x2B740 and cp <= 0x2B81F) or
|
|
(cp >= 0x2B820 and cp <= 0x2CEAF) or
|
|
(cp >= 0xF900 and cp <= 0xFAFF) or
|
|
(cp >= 0x2F800 and cp <= 0x2FA1F)):
|
|
return True
|
|
return False
|
|
|
|
def _run_split_on_punc(text:str) -> list[str]:
  """Split `text` so each punctuation character is its own piece and runs of other characters stay together.

  The exact strings "[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]" are returned unsplit.
  NOTE(review): Tokenizer.tokenize lowercases before calling this, so that special-token check never matches there.
  """
  if text in ("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]"):
    return [text]
  pieces = []
  pending = []  # characters of the current non-punctuation run
  for ch in text:
    if _is_punctuation(ch):
      if pending:
        pieces.append("".join(pending))
        pending = []
      pieces.append(ch)
    else:
      pending.append(ch)
  if pending:
    pieces.append("".join(pending))
  return pieces
|
|
|
|
def _run_strip_accents(text:str) -> str:
|
|
output = []
|
|
for char in unicodedata.normalize("NFD", text):
|
|
if unicodedata.category(char) != "Mn":
|
|
output.append(char)
|
|
return "".join(output)
|
|
|
|
def _clean_text(text:str) -> str:
  """Drop NUL, U+FFFD and control characters from `text`, and replace every whitespace character with a plain space."""
  kept = []
  for ch in text:
    cp = ord(ch)
    if cp == 0 or cp == 0xfffd or _is_control(ch):
      continue
    kept.append(" " if _is_whitespace(ch) else ch)
  return "".join(kept)
|
|
|
|
def _tokenize_chinese_chars(text:str) -> str:
  """Pad every CJK ideograph in `text` with a space on each side so whitespace splitting isolates it."""
  return "".join(f" {ch} " if _is_chinese_char(ord(ch)) else ch for ch in text)
|
|
|
|
def whitespace_tokenize(text):
  """Split `text` on runs of whitespace; empty or all-whitespace input yields []."""
  # str.split() with no separator already ignores leading/trailing whitespace.
  return text.split()
|
|
|
|
def _wordpiece_tokenize(text:str, vocab:dict[str, int]) -> list[str]:
|
|
text = text.decode("utf-8", "ignore") if isinstance(text, bytes) else text
|
|
output_tokens = []
|
|
for token in text.strip().split():
|
|
chars = list(token)
|
|
if len(chars) > 200:
|
|
output_tokens.append("[UNK]")
|
|
continue
|
|
|
|
is_bad = False
|
|
start = 0
|
|
sub_tokens = []
|
|
while start < len(chars):
|
|
end = len(chars)
|
|
cur_substr = None
|
|
while start < end:
|
|
substr = "".join(chars[start:end])
|
|
if start > 0: substr = "##" + substr
|
|
if substr in vocab:
|
|
cur_substr = substr
|
|
break
|
|
end -= 1
|
|
if cur_substr is None:
|
|
is_bad = True
|
|
break
|
|
sub_tokens.append(cur_substr)
|
|
start = end
|
|
|
|
if is_bad: output_tokens.append("[UNK]")
|
|
else: output_tokens.extend(sub_tokens)
|
|
return output_tokens
|
|
|
|
class Tokenizer:
  """Uncased BERT tokenizer: basic cleanup/splitting followed by greedy WordPiece lookup in a vocab file."""
  def __init__(self, vocab_file):
    # One token per line; ids are assigned in order of first appearance, skipping blank lines and duplicates.
    self.vocab = {}
    # Decode explicitly as UTF-8 (dropping undecodable bytes) rather than with the platform default encoding,
    # since the vocab may contain non-ASCII wordpieces. Text-mode lines are already str, so no bytes handling is needed.
    with open(vocab_file, encoding="utf-8", errors="ignore") as f:
      for line in f:
        if (token := line.strip()) and token not in self.vocab: self.vocab[token] = len(self.vocab)
    self.inv_vocab = {v: k for k, v in self.vocab.items()}

  def tokenize(self, text:str) -> list[str]:
    """Clean, lowercase, strip accents, split on whitespace/punctuation (CJK chars isolated), then WordPiece-split.

    `text` may also be bytes, which is decoded as UTF-8 ignoring invalid sequences.
    """
    # BasicTokenizer
    text = text.decode("utf-8", "ignore") if isinstance(text, bytes) else text
    split_tokens = []
    for token in whitespace_tokenize(_tokenize_chinese_chars(_clean_text(text))):
      split_tokens.extend(_run_split_on_punc(_run_strip_accents(token.lower())))
    split_tokens = " ".join(split_tokens).strip().split()
    # WordpieceTokenizer
    tokens = []
    for token in split_tokens:
      tokens.extend(_wordpiece_tokenize(token, self.vocab))
    return tokens

  def convert_tokens_to_ids(self, tokens:list[str]) -> list[int]:
    """Map tokens to vocab ids; raises KeyError for tokens not in the vocab."""
    return [self.vocab[token] for token in tokens]

  def convert_ids_to_tokens(self, ids:list[int]) -> list[str]:
    """Map vocab ids back to tokens; raises KeyError for unknown ids."""
    return [self.inv_vocab[i] for i in ids]
|
|
|
|
##################### Feature transformation #####################
|
|
|
|
def truncate_seq_pair(tokens_a:list[str], tokens_b:list[str], max_num_tokens:int, rng:random.Random) -> None:
  """Trim the two lists in place until their combined length is at most `max_num_tokens`.

  Each step removes one token from the longer list (tokens_b on a tie), from the front or the back with equal
  probability according to `rng`.
  """
  while len(tokens_a) + len(tokens_b) > max_num_tokens:
    longer = tokens_a if len(tokens_a) > len(tokens_b) else tokens_b
    assert len(longer) >= 1
    if rng.random() < 0.5:
      del longer[0]
    else:
      longer.pop()
|
|
|
|
def create_masked_lm_predictions(tokens:list[str], tokenizer:Tokenizer, rng:random.Random, vocab_words:list[str]) -> tuple[list[str], list[int], list[str]]:
  """Apply BERT-style random masking to a copy of `tokens`.

  Positions other than [CLS]/[SEP] are shuffled with `rng`. The first
  min(MAX_PREDICTIONS_PER_SEQ (default 76), max(1, round(len(tokens) * MASKED_LM_PROB (default 0.15)))) of them
  are chosen. Each chosen position becomes "[MASK]" with probability 0.8. Otherwise it keeps its token or
  (equally likely) gets a random entry of `vocab_words` with an index in [0, len(tokenizer.vocab) - 1].

  Returns (masked tokens, chosen positions in ascending order, original tokens at those positions).
  """
  candidates = [idx for idx, tok in enumerate(tokens) if tok not in ("[CLS]", "[SEP]")]
  rng.shuffle(candidates)
  masked = list(tokens)
  budget = min(getenv('MAX_PREDICTIONS_PER_SEQ', 76), max(1, int(round(len(tokens) * getenv("MASKED_LM_PROB", 0.15)))))

  # Candidate indices are unique, so taking a prefix of the shuffled list picks distinct positions.
  # RNG draws must follow this shuffled order to stay reproducible.
  chosen = candidates[:budget]
  for idx in chosen:
    if rng.random() < 0.8:
      replacement = "[MASK]"
    elif rng.random() < 0.5:
      replacement = tokens[idx]
    else:
      replacement = vocab_words[rng.randint(0, len(tokenizer.vocab) - 1)]
    masked[idx] = replacement

  positions = sorted(chosen)
  labels = [tokens[idx] for idx in positions]
  return masked, positions, labels
|
|
|
|
def create_instances_from_document(rng:random.Random, tokenizer:Tokenizer, doc:list[list[str]], di:int, documents:list[list[list[str]]]) -> list[dict]:
  """Split one document into (segment A, segment B) pairs for masked-LM + next-sentence-prediction training.

  `doc` is a list of segments (each a list of wordpiece tokens) and `di` is its index in `documents`.
  Segments are accumulated into a chunk until it reaches the target length or the document ends. The target
  length is MAX_SEQ_LENGTH - 3, or with probability SHORT_SEQ_PROB a random shorter length. The chunk's first
  `a_end` segments become tokens_a. With probability 0.5 (always, if the chunk has one segment) tokens_b is
  drawn from a random document and the chunk's unused segments are revisited. Otherwise tokens_b is the rest
  of the chunk. The pair is truncated, laid out as [CLS] A [SEP] B [SEP], and masked.

  Returns a list of dicts with keys "tokens", "segment_ids", "masked_lm_positions", "masked_lm_labels",
  "is_random_next".
  """
  max_num_tokens = getenv('MAX_SEQ_LENGTH', 512) - 3 # [CLS] + 2 * [SEP]

  # Usually fill up to max_num_tokens; with probability SHORT_SEQ_PROB (default 0.1) aim for a shorter random length.
  target_seq_length = max_num_tokens
  if rng.random() < getenv("SHORT_SEQ_PROB", 0.1):
    target_seq_length = rng.randint(2, max_num_tokens)

  instances = []
  current_chunk = []
  current_length = 0
  i = 0
  while i < len(doc):
    segment = doc[i]
    current_chunk.append(segment)
    current_length += len(segment)
    # Emit an instance once the chunk is long enough or the document is exhausted.
    if i == len(doc) - 1 or current_length >= target_seq_length:
      if current_chunk:
        # Number of leading chunk segments that form sentence A (at least 1).
        a_end = 1
        if len(current_chunk) >= 2:
          a_end = rng.randint(1, len(current_chunk) - 1)

        tokens_a = []
        for j in range(a_end):
          tokens_a.extend(current_chunk[j])

        tokens_b = []
        is_random_next = False
        if len(current_chunk) == 1 or rng.random() < 0.5:
          is_random_next = True
          target_b_length = target_seq_length - len(tokens_a)

          # Up to 10 attempts to draw a document other than this one.
          # NOTE(review): if every draw equals `di` (e.g. only one document), the last draw is used anyway,
          # so the "random" B segment can come from this same document.
          for _ in range(10):
            random_document_index = rng.randint(0, len(documents) - 1)
            if random_document_index != di:
              break

          random_document = documents[random_document_index]
          random_start = rng.randint(0, len(random_document) - 1)
          for j in range(random_start, len(random_document)):
            tokens_b.extend(random_document[j])
            if len(tokens_b) >= target_b_length:
              break

          # Segments after a_end were not used; step i back so they start the next chunk.
          num_unused_segments = len(current_chunk) - a_end
          i -= num_unused_segments
        else:
          # Actual next sentence: B is the remainder of the chunk.
          is_random_next = False
          for j in range(a_end, len(current_chunk)):
            tokens_b.extend(current_chunk[j])
        truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng)

        assert len(tokens_a) >= 1
        assert len(tokens_b) >= 1

        # Segment id 0 covers [CLS], A and the first [SEP]; segment id 1 covers B and the final [SEP].
        tokens = []
        segment_ids = []
        tokens.append("[CLS]")
        segment_ids.append(0)
        for token in tokens_a:
          tokens.append(token)
          segment_ids.append(0)
        tokens.append("[SEP]")
        segment_ids.append(0)
        for token in tokens_b:
          tokens.append(token)
          segment_ids.append(1)
        tokens.append("[SEP]")
        segment_ids.append(1)

        tokens, masked_lm_positions, masked_lm_labels = create_masked_lm_predictions(tokens, tokenizer, rng, list(tokenizer.vocab.keys()))
        instances.append({
          "tokens": tokens,
          "segment_ids": segment_ids,
          "masked_lm_positions": masked_lm_positions,
          "masked_lm_labels": masked_lm_labels,
          "is_random_next": is_random_next
        })
      current_chunk = []
      current_length = 0
    i += 1
  return instances
|
|
|
|
def get_documents(rng:random.Random, tokenizer:Tokenizer, fn:str) -> list[list[list[str]]]:
  """Read BASEDIR/`fn` into shuffled, tokenized documents.

  Each non-blank line is tokenized and appended as one segment (a list of wordpiece tokens) to the current
  document. A blank line starts a new document. Empty documents are dropped and the rest are shuffled with `rng`.
  """
  documents:list[list[list[str]]] = [[]]
  # Stream lines instead of materializing them all with readlines(). Decode explicitly as UTF-8, dropping bad
  # bytes (the intent of the old per-line bytes.decode(..., "ignore"), which never ran in text mode).
  with open(BASEDIR / fn, encoding="utf-8", errors="ignore") as f:
    for line in f:
      if not (line := line.strip()):
        documents.append([])
      elif (tokens := tokenizer.tokenize(line)):
        documents[-1].append(tokens)
  documents = [doc for doc in documents if doc]
  rng.shuffle(documents)
  return documents
|
|
|
|
def get_instances(rng:random.Random, tokenizer:Tokenizer, documents:list[list[list[str]]]) -> list[dict]:
  """Make DUPE_FACTOR (default 10) passes over `documents` and return all instances, shuffled with `rng`.

  Every pass draws from the same advancing `rng`, so repeated passes get different masks.
  """
  dupe_factor = getenv('DUPE_FACTOR', 10)
  instances = [
    instance
    for _ in range(dupe_factor)
    for di, doc in enumerate(documents)
    for instance in create_instances_from_document(rng, tokenizer, doc, di, documents)
  ]
  rng.shuffle(instances)
  return instances
|
|
|
|
def instance_to_features(instance:dict, tokenizer:Tokenizer) -> dict:
  """Convert one instance dict into zero-padded numpy features, each with a leading axis of size 1.

  Token fields are padded to MAX_SEQ_LENGTH (default 512) and masked-LM fields to MAX_PREDICTIONS_PER_SEQ
  (default 76). Padded prediction slots get weight 0.0.
  `instance` is not modified: its lists are copied before padding. The same instance can therefore be converted
  more than once, as the eval path in process_iterate does when there are fewer than 10000 instances.
  """
  max_seq_length = getenv('MAX_SEQ_LENGTH', 512)
  max_predictions = getenv("MAX_PREDICTIONS_PER_SEQ", 76)

  input_ids = tokenizer.convert_tokens_to_ids(instance["tokens"])
  assert len(input_ids) <= max_seq_length
  seq_pad = max_seq_length - len(input_ids)
  input_mask = [1] * len(input_ids) + [0] * seq_pad
  input_ids = input_ids + [0] * seq_pad
  # Copy before padding: appending to instance["segment_ids"] in place made a second conversion pad it again
  # and fail the length check below.
  segment_ids = list(instance["segment_ids"]) + [0] * seq_pad
  assert len(input_ids) == max_seq_length
  assert len(input_mask) == max_seq_length
  assert len(segment_ids) == max_seq_length

  masked_lm_ids = tokenizer.convert_tokens_to_ids(instance["masked_lm_labels"])
  masked_lm_weights = [1.0] * len(masked_lm_ids)
  # Same copy-before-pad fix. In-place padding of the positions left masked_lm_ids unpadded on a second conversion.
  pred_pad = max(0, max_predictions - len(instance["masked_lm_positions"]))
  masked_lm_positions = list(instance["masked_lm_positions"]) + [0] * pred_pad
  masked_lm_ids += [0] * pred_pad
  masked_lm_weights += [0.0] * pred_pad

  next_sentence_label = 1 if instance["is_random_next"] else 0

  return {
    "input_ids": np.expand_dims(np.array(input_ids, dtype=np.int32), 0),
    "input_mask": np.expand_dims(np.array(input_mask, dtype=np.int32), 0),
    "segment_ids": np.expand_dims(np.array(segment_ids, dtype=np.int32), 0),
    "masked_lm_positions": np.expand_dims(np.array(masked_lm_positions, dtype=np.int32), 0),
    "masked_lm_ids": np.expand_dims(np.array(masked_lm_ids, dtype=np.int32), 0),
    "masked_lm_weights": np.expand_dims(np.array(masked_lm_weights, dtype=np.float32), 0),
    "next_sentence_labels": np.expand_dims(np.array([next_sentence_label], dtype=np.int32), 0),
  }
|
|
|
|
def process_part(part:int):
  """Featurize training part `part` and pickle each batch to BASEDIR/train/<part>/<part>_<i>_of_<n>.pkl.

  Builds its own Tokenizer on every call; this is the function process_map runs for "pre-train all".
  """
  # Use the module-level BASEDIR instead of re-evaluating the same getenv expression.
  tokenizer = Tokenizer(BASEDIR / "vocab.txt")
  out_dir = BASEDIR / "train" / str(part)
  os.makedirs(out_dir, exist_ok=True)
  for i, feature_batch in enumerate(process_iterate(tokenizer, val=False, part=part)):
    # NOTE: <n> in the file name is this batch's size, not the number of batches.
    with open(out_dir / f"{part}_{i}_of_{len(feature_batch)}.pkl", "wb") as f:
      pickle.dump(feature_batch, f)
|
|
|
|
def process_iterate(tokenizer:Tokenizer, val:bool=False, part:int=0) -> Iterator[list[dict]]:
  """Convert raw text to masked-LM/NSP samples and yield them as lists of feature dicts.

  val=True: read results4/eval.txt, pick 10000 instances evenly spaced over the shuffled list and yield them as
  10 batches of 1000. val=False: read results4/part-{part:05d}-of-00500 and yield all instances in batches of at
  most 1000.
  The RNG is seeded from RANDOM_SEED (default 12345) on every call, so a call is deterministic for a given file.
  Raises ValueError if val=True and no eval instances could be built.
  """
  rng = random.Random(getenv('RANDOM_SEED', 12345))

  if val:
    tqdm.write("Getting samples from dataset")
    documents = get_documents(rng, tokenizer, "results4/eval.txt")
    instances = get_instances(rng, tokenizer, documents)
    if not instances:
      raise ValueError("no eval instances were generated from results4/eval.txt")

    tqdm.write(f"There are {len(instances)} samples in the dataset")
    tqdm.write("Picking 10000 samples")

    # Evenly spaced picks; with fewer than 10000 instances some instances are picked more than once.
    pick_ratio = len(instances) / 10000
    picks = [instance_to_features(instances[int(inst*pick_ratio)], tokenizer) for inst in range(10000)]
    for batch in range(10):
      yield picks[batch*1000:(batch+1)*1000]
  else:
    documents = get_documents(rng, tokenizer, f"results4/part-{part:05d}-of-00500")
    instances = get_instances(rng, tokenizer, documents)

    while instances:
      batch_size = min(1000, len(instances)) # We batch 1000 samples to one file
      batch = instances[:batch_size]
      # Drop consumed instances from the list (presumably so they can be freed as batches are written).
      del instances[:batch_size]
      yield [instance_to_features(instance, tokenizer) for instance in batch]
|
|
|
|
##################### Load files #####################
|
|
|
|
def get_wiki_val_files():
  """Return the pickled eval batches in BASEDIR/eval, sorted by path."""
  eval_dir = BASEDIR / "eval"
  return sorted(eval_dir.glob("*.pkl"))
|
|
|
|
@diskcache
def get_wiki_train_files():
  """Return every pickled training batch under BASEDIR/train/<part>/, sorted by path.

  NOTE(review): wrapped in tinygrad's diskcache. Presumably the list persists across runs, so files generated
  later may not show up until that cache is cleared. Confirm.
  """
  train_dir = BASEDIR / "train"
  return sorted(train_dir.glob("*/*.pkl"))
|
|
|
|
if __name__ == "__main__":
  # CLI: "pre-eval" writes 10 eval batches; "pre-train <part>" or "pre-train all" writes training batches.
  usage = "Usage: python wikipedia.py pre-eval|pre-train [part]|all"
  assert len(sys.argv) > 1, usage

  tokenizer = Tokenizer(BASEDIR / "vocab.txt")

  if sys.argv[1] == "pre-eval": # Generate 10000 eval samples
    os.makedirs(BASEDIR / "eval", exist_ok=True)

    for i, feature_batch in tqdm(enumerate(process_iterate(tokenizer, val=True)), total=10):
      with open(BASEDIR / f"eval/{i}.pkl", "wb") as f:
        pickle.dump(feature_batch, f)
  elif sys.argv[1] == "pre-train":
    # Previously a missing part argument crashed with an IndexError.
    assert len(sys.argv) > 2, usage
    os.makedirs(BASEDIR / "train", exist_ok=True)
    if sys.argv[2] == "all": # Use all 500 parts for training generation
      # os.cpu_count() returns None when the CPU count cannot be determined.
      process_map(process_part, list(range(500)), max_workers=getenv('NUM_WORKERS', min(os.cpu_count() or 1, 32)), chunksize=1)
    else: # Use a specific part for training generation
      part = int(sys.argv[2])
      os.makedirs(BASEDIR / "train" / str(part), exist_ok=True)
      for i, feature_batch in tqdm(enumerate(process_iterate(tokenizer, val=False, part=part))):
        with open(BASEDIR / f"train/{part}/{part}_{i}_of_{len(feature_batch)}.pkl", "wb") as f:
          pickle.dump(feature_batch, f)
  else:
    # Unknown commands used to fall through silently.
    sys.exit(usage)
|