mirror of https://github.com/commaai/tinygrad.git
149 lines
4.6 KiB
Python
149 lines
4.6 KiB
Python
import json
|
|
import os
|
|
from pathlib import Path
|
|
from transformers import BertTokenizer
|
|
import numpy as np
|
|
from tinygrad.helpers import fetch
|
|
|
|
# Directory "squad" located next to this file; init_dataset() creates it and stores the dataset JSON there.
BASEDIR = Path(__file__).parent / "squad"
|
|
def init_dataset():
    """Fetch the SQuAD v1.1 dev set and flatten it into a list of QA examples.

    The JSON file is fetched into ``BASEDIR`` and parsed. Each paragraph's
    context is split into whitespace-delimited tokens, and one example is
    emitted per question attached to that paragraph.

    Returns:
        A list of dicts with the keys ``"id"``, ``"question"``, ``"context"``
        (list of whitespace-separated tokens of the paragraph) and
        ``"answers"`` (list of answer texts). Examples that come from the same
        paragraph share the same ``"context"`` list object.
    """
    os.makedirs(BASEDIR, exist_ok=True)
    dataset_path = BASEDIR / "dev-v1.1.json"
    fetch("https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json", dataset_path)
    # JSON text is UTF-8; be explicit instead of relying on the platform's
    # locale encoding, which can fail or mis-decode non-ASCII contexts.
    with open(dataset_path, encoding="utf-8") as f:
        data = json.load(f)["data"]

    # space, tab, CR, LF and U+202F (narrow no-break space) separate tokens
    whitespace = " \t\r\n\u202f"

    examples = []
    for article in data:
        for paragraph in article["paragraphs"]:
            doc_tokens = []
            prev_is_whitespace = True
            for c in paragraph["context"]:
                if c in whitespace:
                    prev_is_whitespace = True
                    continue
                if prev_is_whitespace:
                    # first character after a separator starts a new token
                    doc_tokens.append(c)
                else:
                    doc_tokens[-1] += c
                prev_is_whitespace = False

            for qa in paragraph["qas"]:
                examples.append({
                    "id": qa["id"],
                    "question": qa["question"],
                    "context": doc_tokens,
                    "answers": [answer["text"] for answer in qa["answers"]],
                })
    return examples
|
|
|
|
def _check_is_max_context(doc_spans, cur_span_index, position):
|
|
best_score, best_span_index = None, None
|
|
for di, (doc_start, doc_length) in enumerate(doc_spans):
|
|
end = doc_start + doc_length - 1
|
|
if position < doc_start:
|
|
continue
|
|
if position > end:
|
|
continue
|
|
num_left_context = position - doc_start
|
|
num_right_context = end - position
|
|
score = min(num_left_context, num_right_context) + 0.01 * doc_length
|
|
if best_score is None or score > best_score:
|
|
best_score = score
|
|
best_span_index = di
|
|
return cur_span_index == best_span_index
|
|
|
|
def convert_example_to_features(example, tokenizer, max_seq_length=384, max_query_length=64, doc_stride=128):
    """Turn one QA example into fixed-length model inputs, one per document window.

    The question is tokenized and truncated to ``max_query_length`` tokens.
    Each context word is split into sub-tokens, and the resulting sub-token
    sequence is cut into overlapping windows so that
    ``[CLS] question [SEP] window [SEP]`` never exceeds ``max_seq_length``.

    Args:
        example: dict with a ``"question"`` string and a ``"context"`` sequence
            of word strings.
        tokenizer: object providing ``tokenize(text)`` and
            ``convert_tokens_to_ids(tokens)``.
        max_seq_length: total length every feature is padded to.
        max_query_length: maximum number of question tokens kept.
        doc_stride: how far the window start advances between windows.

    Returns:
        A list with one dict per window containing ``"input_ids"``,
        ``"input_mask"`` and ``"segment_ids"`` (float32 arrays of shape
        ``(1, max_seq_length)``), ``"token_to_orig_map"`` (feature position ->
        index of the originating context word), ``"token_is_max_context"``
        (feature position -> bool) and ``"tokens"`` (the unpadded token list).
        The list is empty when the context yields no tokens.

    Raises:
        ValueError: if ``max_query_length`` is negative, ``doc_stride`` is not
            positive, or no room is left for document tokens.
    """
    if max_query_length < 0:
        raise ValueError(f"max_query_length must be >= 0, got {max_query_length}")
    if doc_stride < 1:
        # a non-positive stride would never advance the window below
        raise ValueError(f"doc_stride must be >= 1, got {doc_stride}")

    query_tokens = tokenizer.tokenize(example["question"])[:max_query_length]

    # sub-tokenize the context, remembering which word each sub-token came from
    tok_to_orig_index = []
    all_doc_tokens = []
    for i, token in enumerate(example["context"]):
        for sub_token in tokenizer.tokenize(token):
            tok_to_orig_index.append(i)
            all_doc_tokens.append(sub_token)

    # 3 slots are reserved for [CLS] and the two [SEP] markers
    max_tokens_for_doc = max_seq_length - len(query_tokens) - 3
    if max_tokens_for_doc < 1:
        raise ValueError(f"max_seq_length={max_seq_length} leaves no room for document tokens")

    # sliding windows (start, length) over the sub-token sequence
    doc_spans = []
    start_offset = 0
    while start_offset < len(all_doc_tokens):
        length = min(len(all_doc_tokens) - start_offset, max_tokens_for_doc)
        doc_spans.append((start_offset, length))
        if start_offset + length == len(all_doc_tokens):
            break
        start_offset += min(length, doc_stride)

    outputs = []
    for di, (doc_start, doc_length) in enumerate(doc_spans):
        tokens = ["[CLS]", *query_tokens, "[SEP]"]
        segment_ids = [0] * len(tokens)
        token_to_orig_map = {}
        token_is_max_context = {}

        for split_token_index in range(doc_start, doc_start + doc_length):
            token_to_orig_map[len(tokens)] = tok_to_orig_index[split_token_index]
            token_is_max_context[len(tokens)] = _check_is_max_context(doc_spans, di, split_token_index)
            tokens.append(all_doc_tokens[split_token_index])
            segment_ids.append(1)
        tokens.append("[SEP]")
        segment_ids.append(1)

        input_ids = tokenizer.convert_tokens_to_ids(tokens)
        input_mask = [1] * len(input_ids)

        # zero-pad all three sequences up to max_seq_length
        padding = [0] * (max_seq_length - len(input_ids))
        input_ids.extend(padding)
        input_mask.extend(padding)
        segment_ids.extend(padding)

        assert len(input_ids) == max_seq_length
        assert len(input_mask) == max_seq_length
        assert len(segment_ids) == max_seq_length

        outputs.append({
            "input_ids": np.expand_dims(np.array(input_ids), 0).astype(np.float32),
            "input_mask": np.expand_dims(np.array(input_mask), 0).astype(np.float32),
            "segment_ids": np.expand_dims(np.array(segment_ids), 0).astype(np.float32),
            "token_to_orig_map": token_to_orig_map,
            "token_is_max_context": token_is_max_context,
            "tokens": tokens,
        })

    return outputs
|
|
|
|
def iterate(tokenizer, start=0):
    """Yield ``(features, example)`` for every dataset example from index ``start`` on.

    The dataset is loaded with ``init_dataset()`` and its size is printed once
    before the first pair is produced.
    """
    dataset = init_dataset()
    total = len(dataset)
    print(f"there are {total} pairs in the dataset")

    for idx in range(start, total):
        sample = dataset[idx]
        # all features of the example are yielded together because the f1 score
        # is the maximum over all of them
        yield convert_example_to_features(sample, tokenizer), sample
|
|
|
|
if __name__ == "__main__":
    # Manual smoke test: build the first example's features and print them.
    vocab_path = Path(__file__).parents[2] / "weights" / "bert_vocab.txt"
    tokenizer = BertTokenizer(str(vocab_path))

    features, example = next(iterate(tokenizer))
    first_feature = features[0]
    print(" ".join(first_feature["tokens"]))
    print(first_feature["input_ids"].shape, example)
|