From 801ecd4a075be80cc89c9503afa99f35c9d4fee4 Mon Sep 17 00:00:00 2001
From: George Hotz
Date: Mon, 12 Sep 2022 09:20:12 -0700
Subject: [PATCH] cleanup clip tokenizer

---
 examples/stable_diffusion.py     | 193 +++++++++---
 .../bpe_simple_vocab_16e6.txt.gz | Bin
 2 files changed, 98 insertions(+), 95 deletions(-)
 rename {examples => weights}/bpe_simple_vocab_16e6.txt.gz (100%)

diff --git a/examples/stable_diffusion.py b/examples/stable_diffusion.py
index 2982c29d..3def7195 100644
--- a/examples/stable_diffusion.py
+++ b/examples/stable_diffusion.py
@@ -490,112 +490,112 @@ class CLIPTextTransformer:
 # Clip tokenizer, taken from https://github.com/openai/CLIP/blob/main/clip/simple_tokenizer.py (MIT license)
 @lru_cache()
 def default_bpe():
-    return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
+  return os.path.join(os.path.dirname(os.path.abspath(__file__)), "../weights/bpe_simple_vocab_16e6.txt.gz")
 
 def get_pairs(word):
-    """Return set of symbol pairs in a word.
-    Word is represented as tuple of symbols (symbols being variable-length strings).
-    """
-    pairs = set()
-    prev_char = word[0]
-    for char in word[1:]:
-        pairs.add((prev_char, char))
-        prev_char = char
-    return pairs
+  """Return set of symbol pairs in a word.
+  Word is represented as tuple of symbols (symbols being variable-length strings).
+  """
+  pairs = set()
+  prev_char = word[0]
+  for char in word[1:]:
+    pairs.add((prev_char, char))
+    prev_char = char
+  return pairs
 
 def whitespace_clean(text):
-    text = re.sub(r'\s+', ' ', text)
-    text = text.strip()
-    return text
+  text = re.sub(r'\s+', ' ', text)
+  text = text.strip()
+  return text
 
 def bytes_to_unicode():
-    """
-    Returns list of utf-8 byte and a corresponding list of unicode strings.
-    The reversible bpe codes work on unicode strings.
-    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
-    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
-    This is a signficant percentage of your normal, say, 32K bpe vocab.
-    To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
-    And avoids mapping to whitespace/control characters the bpe code barfs on.
-    """
-    bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
-    cs = bs[:]
-    n = 0
-    for b in range(2**8):
-        if b not in bs:
-            bs.append(b)
-            cs.append(2**8+n)
-            n += 1
-    cs = [chr(n) for n in cs]
-    return dict(zip(bs, cs))
+  """
+  Returns list of utf-8 byte and a corresponding list of unicode strings.
+  The reversible bpe codes work on unicode strings.
+  This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
+  When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
+  This is a signficant percentage of your normal, say, 32K bpe vocab.
+  To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
+  And avoids mapping to whitespace/control characters the bpe code barfs on.
+  """
+  bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
+  cs = bs[:]
+  n = 0
+  for b in range(2**8):
+    if b not in bs:
+      bs.append(b)
+      cs.append(2**8+n)
+      n += 1
+  cs = [chr(n) for n in cs]
+  return dict(zip(bs, cs))
 
 class ClipTokenizer(object):
-    def __init__(self, bpe_path: str = default_bpe()):
-        self.byte_encoder = bytes_to_unicode()
-        merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
-        merges = merges[1:49152-256-2+1]
-        merges = [tuple(merge.split()) for merge in merges]
-        vocab = list(bytes_to_unicode().values())
-        vocab = vocab + [v+'</w>' for v in vocab]
-        for merge in merges:
-            vocab.append(''.join(merge))
-        vocab.extend(['<|startoftext|>', '<|endoftext|>'])
-        self.encoder = dict(zip(vocab, range(len(vocab))))
-        self.bpe_ranks = dict(zip(merges, range(len(merges))))
-        self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
-        self.pat = self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[^\s]+""", re.IGNORECASE)
+  def __init__(self, bpe_path: str = default_bpe()):
+    self.byte_encoder = bytes_to_unicode()
+    merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
+    merges = merges[1:49152-256-2+1]
+    merges = [tuple(merge.split()) for merge in merges]
+    vocab = list(bytes_to_unicode().values())
+    vocab = vocab + [v+'</w>' for v in vocab]
+    for merge in merges:
+      vocab.append(''.join(merge))
+    vocab.extend(['<|startoftext|>', '<|endoftext|>'])
+    self.encoder = dict(zip(vocab, range(len(vocab))))
+    self.bpe_ranks = dict(zip(merges, range(len(merges))))
+    self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
+    self.pat = self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[^\s]+""", re.IGNORECASE)
 
-    def bpe(self, token):
-        if token in self.cache:
-            return self.cache[token]
-        word = tuple(token[:-1]) + ( token[-1] + '</w>',)
+  def bpe(self, token):
+    if token in self.cache:
+      return self.cache[token]
+    word = tuple(token[:-1]) + ( token[-1] + '</w>',)
+    pairs = get_pairs(word)
+
+    if not pairs:
+      return token+'</w>'
+
+    while True:
+      bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
+      if bigram not in self.bpe_ranks:
+        break
+      first, second = bigram
+      new_word = []
+      i = 0
+      while i < len(word):
+        try:
+          j = word.index(first, i)
+          new_word.extend(word[i:j])
+          i = j
+        except Exception:
+          new_word.extend(word[i:])
+          break
+
+        if word[i] == first and i < len(word)-1 and word[i+1] == second:
+          new_word.append(first+second)
+          i += 2
+        else:
+          new_word.append(word[i])
+          i += 1
+      new_word = tuple(new_word)
+      word = new_word
+      if len(word) == 1:
+        break
+      else:
         pairs = get_pairs(word)
+    word = ' '.join(word)
+    self.cache[token] = word
+    return word
 
-        if not pairs:
-            return token+'</w>'
-
-        while True:
-            bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
-            if bigram not in self.bpe_ranks:
-                break
-            first, second = bigram
-            new_word = []
-            i = 0
-            while i < len(word):
-                try:
-                    j = word.index(first, i)
-                    new_word.extend(word[i:j])
-                    i = j
-                except:
-                    new_word.extend(word[i:])
-                    break
-
-                if word[i] == first and i < len(word)-1 and word[i+1] == second:
-                    new_word.append(first+second)
-                    i += 2
-                else:
-                    new_word.append(word[i])
-                    i += 1
-            new_word = tuple(new_word)
-            word = new_word
-            if len(word) == 1:
-                break
-            else:
-                pairs = get_pairs(word)
-        word = ' '.join(word)
-        self.cache[token] = word
-        return word
-
-    def encode(self, text):
-        bpe_tokens = []
-        text = whitespace_clean(text.strip()).lower()
-        for token in re.findall(self.pat, text):
-            token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
-            bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
-        # Truncation, keeping two slots for start and end tokens.
-        if len(bpe_tokens) > 75:
-            bpe_tokens = bpe_tokens[:75]
-        return [49406] + bpe_tokens + [49407] * (77 - len(bpe_tokens) - 1)
+  def encode(self, text):
+    bpe_tokens = []
+    text = whitespace_clean(text.strip()).lower()
+    for token in re.findall(self.pat, text):
+      token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
+      bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
+    # Truncation, keeping two slots for start and end tokens.
+    if len(bpe_tokens) > 75:
+      bpe_tokens = bpe_tokens[:75]
+    return [49406] + bpe_tokens + [49407] * (77 - len(bpe_tokens) - 1)
 
 class StableDiffusion:
   def __init__(self):
@@ -650,6 +650,9 @@ if __name__ == "__main__":
 
   tokenizer = ClipTokenizer()
   phrase = tokenizer.encode("a horse sized cat eating a bagel")
+  _phrase = [49406, 320, 4558, 9832, 2368, 4371, 320, 28777, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407]
+  assert tuple(phrase) == tuple(_phrase)
+
   # phrase = tokenizer.encode("penguin with fire extinguisher")
   context = model.cond_stage_model.transformer.text_model(phrase).realize()
 
diff --git a/examples/bpe_simple_vocab_16e6.txt.gz b/weights/bpe_simple_vocab_16e6.txt.gz
similarity index 100%
rename from examples/bpe_simple_vocab_16e6.txt.gz
rename to weights/bpe_simple_vocab_16e6.txt.gz
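
A quick standalone illustration of the property the bytes_to_unicode() docstring above describes (a reversible byte-to-unicode mapping). This is a sketch, not part of the patch, assuming plain Python 3 with bytes_to_unicode() defined exactly as in the new hunk:

  byte_encoder = bytes_to_unicode()                      # {byte value: printable unicode char}
  byte_decoder = {v: k for k, v in byte_encoder.items()}
  assert len(byte_encoder) == 256                        # every possible byte value is covered
  assert len(set(byte_encoder.values())) == 256          # and no two bytes share a character

  s = "a horse sized cat"
  mapped = ''.join(byte_encoder[b] for b in s.encode('utf-8'))
  assert bytes(byte_decoder[c] for c in mapped).decode('utf-8') == s   # the mapping round-trips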
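
The magic numbers in ClipTokenizer.__init__ and encode() fit together as follows; a small arithmetic sketch using only constants that appear in the diff:

  # 256 raw byte symbols, 256 byte symbols with the '</w>' end-of-word suffix,
  # the merges kept from bpe_simple_vocab_16e6.txt.gz, and 2 special tokens.
  n_merges = 49152 - 256 - 2             # length of the slice merges[1:49152-256-2+1]
  vocab_size = 256 + 256 + n_merges + 2
  assert vocab_size == 49408
  assert vocab_size - 2 == 49406         # id of '<|startoftext|>', prepended by encode()
  assert vocab_size - 1 == 49407         # id of '<|endoftext|>', used to pad to 77 entries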
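
A usage sketch of what the new assert under __main__ pins down: encode() always returns a 77-entry context, start token first, at most 75 content tokens, padded with the end token. The import path and a clean module import are assumptions (tinygrad installed, run from the repo root, weights/bpe_simple_vocab_16e6.txt.gz present after the rename):

  from examples.stable_diffusion import ClipTokenizer    # assumed import path

  ids = ClipTokenizer().encode("a horse sized cat eating a bagel")
  assert len(ids) == 77                                   # [49406] + up to 75 tokens + 49407 padding
  assert ids[0] == 49406 and ids[-1] == 49407
  assert ids[:8] == [49406, 320, 4558, 9832, 2368, 4371, 320, 28777]   # matches _phrase above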