tinygrad/extra/models/clip.py

464 lines
18 KiB
Python

from tinygrad import Tensor, dtypes
from tinygrad.helpers import fetch
from tinygrad.nn import Linear, LayerNorm, Embedding, Conv2d
from typing import List, Optional, Union, Tuple, Dict
from abc import ABC, abstractmethod
from functools import lru_cache
from PIL import Image
import numpy as np
import re, gzip
@lru_cache()
def default_bpe():
    """Return the result of ``fetch`` for the CLIP BPE vocabulary archive.

    The result is memoized with ``lru_cache`` so ``fetch`` is invoked at most once per process.
    """
    # Clip tokenizer, taken from https://github.com/openai/CLIP/blob/main/clip/simple_tokenizer.py (MIT license)
    vocab_url = "https://github.com/openai/CLIP/raw/main/clip/bpe_simple_vocab_16e6.txt.gz"
    return fetch(vocab_url, "bpe_simple_vocab_16e6.txt.gz")
class Tokenizer:
    """
    Namespace for CLIP Text Tokenizer components.
    """

    @staticmethod
    def get_pairs(word):
        """
        Return set of symbol pairs in a word.
        Word is represented as tuple of symbols (symbols being variable-length strings).
        """
        return set(zip(word, word[1:]))

    @staticmethod
    def whitespace_clean(text):
        """Collapse every run of whitespace into one space and strip both ends."""
        return re.sub(r'\s+', ' ', text).strip()

    @staticmethod
    def bytes_to_unicode():
        """
        Returns a dict mapping every utf-8 byte value (0..255) to a unicode string.
        The reversible bpe codes work on unicode strings.
        This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
        When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
        This is a significant percentage of your normal, say, 32K bpe vocab.
        To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
        And avoids mapping to whitespace/control characters the bpe code barfs on.

        The insertion order of the returned dict matters: ClipTokenizer derives token ids from it.
        """
        bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
        cs = bs[:]
        # bytes that map to themselves; a set keeps the membership test O(1)
        self_mapped = set(bs)
        n = 0
        for b in range(2**8):
            if b not in self_mapped:
                # remaining bytes are shifted to code points starting at 256
                bs.append(b)
                cs.append(2**8+n)
                n += 1
        return dict(zip(bs, (chr(c) for c in cs)))

    class ClipTokenizer:
        """Byte-level BPE tokenizer producing fixed-length (77) CLIP token id sequences."""

        def __init__(self):
            """Build the vocabulary and merge ranks from the file returned by ``default_bpe()``."""
            self.byte_encoder = Tokenizer.bytes_to_unicode()
            # the context manager closes the gzip handle once the merges are read
            with gzip.open(default_bpe()) as bpe_file:
                merges = bpe_file.read().decode("utf-8").split('\n')
            # skip the header line and keep only the merges that fit the 49408-entry vocabulary
            merges = merges[1:49152-256-2+1]
            merges = [tuple(merge.split()) for merge in merges]
            # vocabulary order: byte symbols, byte symbols + '</w>', merged symbols, special tokens
            vocab = list(self.byte_encoder.values())
            vocab = vocab + [v+'</w>' for v in vocab]
            vocab.extend(''.join(merge) for merge in merges)
            vocab.extend(['<|startoftext|>', '<|endoftext|>'])
            self.encoder = dict(zip(vocab, range(len(vocab))))
            self.bpe_ranks = dict(zip(merges, range(len(merges))))
            self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
            self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[^\s]+""", re.IGNORECASE)

        def bpe(self, token):
            """
            Apply the learned merges to one (non-empty) token.

            Returns the resulting symbols joined by single spaces; the last symbol carries the
            '</w>' end-of-word marker. Results of the merge loop are memoized in ``self.cache``.
            """
            if token in self.cache:
                return self.cache[token]
            word = tuple(token[:-1]) + (token[-1] + '</w>',)
            pairs = Tokenizer.get_pairs(word)
            if not pairs:
                # single-symbol token: nothing to merge
                return token+'</w>'
            while True:
                # pick the adjacent pair with the lowest merge rank; unknown pairs rank as infinity
                bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
                if bigram not in self.bpe_ranks:
                    break
                first, second = bigram
                new_word = []
                i = 0
                while i < len(word):
                    try:
                        j = word.index(first, i)
                    except ValueError:
                        # `first` does not occur again: copy the tail unchanged
                        new_word.extend(word[i:])
                        break
                    new_word.extend(word[i:j])
                    i = j
                    if word[i] == first and i < len(word)-1 and word[i+1] == second:
                        new_word.append(first+second)
                        i += 2
                    else:
                        new_word.append(word[i])
                        i += 1
                word = tuple(new_word)
                if len(word) == 1:
                    break
                pairs = Tokenizer.get_pairs(word)
            word = ' '.join(word)
            self.cache[token] = word
            return word

        def encode(self, text:str, pad_with_zeros:bool=False) -> List[int]:
            """
            Tokenize `text` into exactly 77 token ids.

            The text is whitespace-normalized and lower-cased, split with ``self.pat``, byte-encoded
            and BPE-merged. At most 75 BPE ids are kept; the result is framed by the start id (49406)
            and end id (49407) and padded with 0 if `pad_with_zeros` else with 49407.

            Raises KeyError if a BPE symbol is missing from ``self.encoder``.
            """
            bpe_tokens: List[int] = []
            text = Tokenizer.whitespace_clean(text.strip()).lower()
            for token in self.pat.findall(text):
                token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
                bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
            # Truncation, keeping two slots for start and end tokens.
            if len(bpe_tokens) > 75:
                bpe_tokens = bpe_tokens[:75]
            pad_id = 0 if pad_with_zeros else 49407
            return [49406] + bpe_tokens + [49407] + [pad_id] * (77 - len(bpe_tokens) - 2)
class Embedder(ABC):
    """Abstract interface for conditioning embedders; subclasses must implement ``__call__``."""

    # name of the input entry this embedder consumes (set by subclasses)
    input_key: str

    @abstractmethod
    def __call__(self, x:Union[str,List[str],Tensor]) -> Union[Tensor,Tuple[Tensor,...]]:
        """Embed `x` and return one tensor or a tuple of tensors."""
class Closed:
    """
    Namespace for OpenAI CLIP model components.
    """

    class ClipMlp:
        """Feed-forward sub-block: Linear(768, 3072) -> quick_gelu -> Linear(3072, 768)."""

        def __init__(self):
            self.fc1 = Linear(768, 3072)
            self.fc2 = Linear(3072, 768)

        def __call__(self, h:Tensor) -> Tensor:
            h = self.fc1(h)
            h = h.quick_gelu()
            h = self.fc2(h)
            return h

    class ClipAttention:
        """Multi-head self-attention with 12 heads over a 768-wide embedding."""

        def __init__(self):
            self.embed_dim = 768
            self.num_heads = 12
            self.head_dim = self.embed_dim // self.num_heads
            self.k_proj = Linear(self.embed_dim, self.embed_dim)
            self.v_proj = Linear(self.embed_dim, self.embed_dim)
            self.q_proj = Linear(self.embed_dim, self.embed_dim)
            self.out_proj = Linear(self.embed_dim, self.embed_dim)

        def __call__(self, hidden_states:Tensor, causal_attention_mask:Tensor) -> Tensor:
            """Attend over `hidden_states` (batch, seq, embed) using the given additive mask."""
            bsz, tgt_len, embed_dim = hidden_states.shape
            q,k,v = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states)
            # split the embedding into heads: (batch, seq, embed) -> (batch, heads, seq, head_dim)
            q,k,v = [x.reshape(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2) for x in (q,k,v)]
            attn_output = Tensor.scaled_dot_product_attention(q, k, v, attn_mask=causal_attention_mask)
            # merge the heads back before the output projection
            return self.out_proj(attn_output.transpose(1, 2).reshape(bsz, tgt_len, embed_dim))

    class ClipEncoderLayer:
        """Pre-LayerNorm transformer layer: attention and MLP, each wrapped in a residual connection."""

        def __init__(self):
            self.self_attn = Closed.ClipAttention()
            self.layer_norm1 = LayerNorm(768)
            self.mlp = Closed.ClipMlp()
            self.layer_norm2 = LayerNorm(768)

        def __call__(self, hidden_states:Tensor, causal_attention_mask:Tensor) -> Tensor:
            residual = hidden_states
            hidden_states = self.layer_norm1(hidden_states)
            hidden_states = self.self_attn(hidden_states, causal_attention_mask)
            hidden_states = residual + hidden_states
            residual = hidden_states
            hidden_states = self.layer_norm2(hidden_states)
            hidden_states = self.mlp(hidden_states)
            hidden_states = residual + hidden_states
            return hidden_states

    class ClipTextEmbeddings:
        """Sum of token embeddings (49408 entries) and position embeddings (77 entries), 768 wide."""

        def __init__(self):
            self.token_embedding = Embedding(49408, 768)
            self.position_embedding = Embedding(77, 768)

        def __call__(self, input_ids:Tensor, position_ids:Tensor) -> Tensor:
            return self.token_embedding(input_ids) + self.position_embedding(position_ids)

    class ClipEncoder:
        """Stack of `layer_count` ClipEncoderLayer instances."""

        def __init__(self, layer_count:int=12):
            self.layers = [Closed.ClipEncoderLayer() for _ in range(layer_count)]

        def __call__(self, x:Tensor, causal_attention_mask:Tensor, ret_layer_idx:Optional[int]=None) -> Tensor:
            """Run all layers, or only the first `ret_layer_idx` layers when it is given."""
            # the indexing of layers is NOT off by 1, the original code considers the "input" as the first hidden state
            layers = self.layers if ret_layer_idx is None else self.layers[:ret_layer_idx]
            for l in layers:
                x = l(x, causal_attention_mask)
            return x

    class ClipTextTransformer:
        """Embeddings + encoder stack + final LayerNorm for token id input."""

        def __init__(self, ret_layer_idx:Optional[int]=None):
            self.embeddings = Closed.ClipTextEmbeddings()
            self.encoder = Closed.ClipEncoder()
            self.final_layer_norm = LayerNorm(768)
            self.ret_layer_idx = ret_layer_idx

        def __call__(self, input_ids:Tensor) -> Tensor:
            """
            Encode `input_ids` (batch, seq).

            Returns the final-LayerNorm output when ``ret_layer_idx`` is None, otherwise the raw
            hidden state after the first ``ret_layer_idx`` layers (no final LayerNorm).
            """
            seq_len = input_ids.shape[1]
            x = self.embeddings(input_ids, Tensor.arange(seq_len).reshape(1, -1))
            # causal mask sized from the actual sequence length so it always matches the position ids
            causal_mask = Tensor.full((1, 1, seq_len, seq_len), float("-inf")).triu(1)
            x = self.encoder(x, causal_mask, self.ret_layer_idx)
            return self.final_layer_norm(x) if (self.ret_layer_idx is None) else x

    class ClipTextModel:
        """Thin holder exposing the transformer as ``text_model``."""

        def __init__(self, ret_layer_idx:Optional[int]):
            self.text_model = Closed.ClipTextTransformer(ret_layer_idx=ret_layer_idx)
# https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/encoders/modules.py#L331
class FrozenClosedClipEmbedder(Embedder):
    """Text embedder pairing the CLIP tokenizer with the ``Closed`` text model."""

    def __init__(self, ret_layer_idx:Optional[int]=None):
        self.tokenizer = Tokenizer.ClipTokenizer()
        self.transformer = Closed.ClipTextModel(ret_layer_idx)
        self.input_key = "txt"

    def __call__(self, texts:Union[str,List[str],Tensor]) -> Union[Tensor,Tuple[Tensor,...]]:
        """Tokenize one string or a list/tuple of strings and run them through the text model."""
        if isinstance(texts, str):
            texts = [texts]
        assert isinstance(texts, (list,tuple)), f"expected list of strings, got {type(texts).__name__}"
        encoded = [Tensor(self.tokenizer.encode(text)) for text in texts]
        # the per-text id vectors are concatenated flat, then folded into one row per text
        token_ids = Tensor.cat(*encoded, dim=0).reshape(len(texts),-1)
        return self.transformer.text_model(token_ids)
class Open:
    """
    Namespace for OpenCLIP model components.
    """

    class MultiheadAttention:
        """Multi-head self-attention with a fused q/k/v input projection."""

        def __init__(self, dims:int, n_heads:int):
            self.dims = dims
            self.n_heads = n_heads
            self.d_head = self.dims // self.n_heads
            # fused q/k/v projection parameters; created uninitialized here
            self.in_proj_bias = Tensor.empty(3*dims)
            self.in_proj_weight = Tensor.empty(3*dims, dims)
            self.out_proj = Linear(dims, dims)

        def __call__(self, x:Tensor, attn_mask:Optional[Tensor]=None) -> Tensor:
            """Self-attend over `x`, whose shape is unpacked as (T, B, C); the result has the same shape."""
            T,B,C = x.shape
            proj = x.linear(self.in_proj_weight.T, self.in_proj_bias)
            # separate the fused projection so that chunk(3) along the first axis yields q, k and v
            proj = proj.unflatten(-1, (3,C)).unsqueeze(0).transpose(0, -2)
            q,k,v = [
                y.reshape(T, B*self.n_heads, self.d_head).transpose(0, 1).reshape(B, self.n_heads, T, self.d_head)
                for y in proj.chunk(3)
            ]
            attn_output = Tensor.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
            # (B, heads, T, d_head) -> (T, B, heads, d_head) -> (T*B, C)
            attn_output = attn_output.permute(2, 0, 1, 3).reshape(T*B, C)
            attn_output = self.out_proj(attn_output)
            attn_output = attn_output.reshape(T, B, C)
            return attn_output

    class Mlp:
        """Feed-forward sub-block: Linear -> gelu -> Linear."""

        def __init__(self, dims, hidden_dims):
            self.c_fc = Linear(dims, hidden_dims)
            self.c_proj = Linear(hidden_dims, dims)

        def __call__(self, x:Tensor) -> Tensor:
            return x.sequential([self.c_fc, Tensor.gelu, self.c_proj])

    # https://github.com/mlfoundations/open_clip/blob/58e4e39aaabc6040839b0d2a7e8bf20979e4558a/src/open_clip/transformer.py#L210
    class ResidualAttentionBlock:
        """Pre-LayerNorm block: attention and MLP, each added back onto the residual stream."""

        def __init__(self, dims:int, n_heads:int, mlp_ratio:float):
            self.ln_1 = LayerNorm(dims)
            self.attn = Open.MultiheadAttention(dims, n_heads)
            self.ln_2 = LayerNorm(dims)
            self.mlp = Open.Mlp(dims, int(dims * mlp_ratio))

        def __call__(self, x:Tensor, attn_mask:Optional[Tensor]=None, transpose:bool=False) -> Tensor:
            """Apply the block; with `transpose`, axes 0 and 1 are swapped around the attention call only."""
            q_x = self.ln_1(x)
            attn_out = self.attn(q_x.transpose(0, 1) if transpose else q_x, attn_mask=attn_mask)
            attn_out = attn_out.transpose(0, 1) if transpose else attn_out
            x = x + attn_out
            x = x + self.mlp(self.ln_2(x))
            return x

    # https://github.com/mlfoundations/open_clip/blob/58e4e39aaabc6040839b0d2a7e8bf20979e4558a/src/open_clip/transformer.py#L317
    class ClipTransformer:
        """Stack of `layers` ResidualAttentionBlock instances."""

        def __init__(self, dims:int, layers:int, n_heads:int, mlp_ratio:float=4.0):
            self.resblocks = [
                Open.ResidualAttentionBlock(dims, n_heads, mlp_ratio) for _ in range(layers)
            ]

        def __call__(self, x:Tensor, attn_mask:Optional[Tensor]=None) -> Tensor:
            # every block is called with transpose=True
            for r in self.resblocks:
                x = r(x, attn_mask=attn_mask, transpose=True)
            return x

    # https://github.com/mlfoundations/open_clip/blob/58e4e39aaabc6040839b0d2a7e8bf20979e4558a/src/open_clip/model.py#L220
    # https://github.com/mlfoundations/open_clip/blob/58e4e39aaabc6040839b0d2a7e8bf20979e4558a/src/open_clip/transformer.py#L661
    class ClipTextTransformer:
        """Text tower: token + positional embeddings, causal transformer, final LayerNorm, projection."""

        def __init__(self, width:int, n_heads:int, layers:int, vocab_size:int=49408, ctx_length:int=77):
            self.token_embedding = Embedding(vocab_size, width)
            self.positional_embedding = Tensor.empty(ctx_length, width)
            self.transformer = Open.ClipTransformer(width, layers, n_heads)
            self.ln_final = LayerNorm(width)
            self.text_projection = Tensor.empty(width, width)
            # causal mask sized from ctx_length so it stays consistent with positional_embedding
            self.attn_mask = Tensor.full((ctx_length, ctx_length), float("-inf")).triu(1).realize()

        def __call__(self, text:Tensor) -> Tensor:
            """Encode token ids `text` (batch, seq) and return the projected pooled features."""
            seq_len = text.shape[1]
            x = self.token_embedding(text)
            x = x + self.positional_embedding[:seq_len]
            # slice the mask like the positional embedding so sequences shorter than ctx_length work
            x = self.transformer(x, attn_mask=self.attn_mask[:seq_len, :seq_len])
            x = self.ln_final(x)
            # NOTE(review): `x[:, idx]` selects the argmax positions of all rows for every batch row;
            # confirm this is intended for batch sizes > 1 before relying on the output shape.
            pooled = x[:, text.argmax(dim=-1)] @ self.text_projection
            return pooled

    class ClipVisionTransformer:
        """Vision tower: patch convolution, class token, transformer, projection to 1024 features."""

        def __init__(self, width:int, layers:int, d_head:int, image_size:int, patch_size:int):
            grid_size = image_size // patch_size
            n_heads = width // d_head
            assert n_heads * d_head == width
            self.conv1 = Conv2d(3, width, kernel_size=patch_size, stride=patch_size, bias=False)
            self.class_embedding = Tensor.empty(width)
            # one position per patch plus one for the class token
            self.positional_embedding = Tensor.empty(grid_size * grid_size + 1, width)
            self.transformer = Open.ClipTransformer(width, layers, n_heads)
            self.ln_pre = LayerNorm(width)
            self.ln_post = LayerNorm(width)
            self.proj = Tensor.empty(width, 1024)

        def __call__(self, x:Tensor) -> Tensor:
            """Encode an image batch and return the projected class-token features."""
            x = self.conv1(x)
            # flatten the spatial grid and move channels last: (B, width, H', W') -> (B, H'*W', width)
            x = x.reshape(x.shape[0], x.shape[1], -1).permute(0, 2, 1)
            # prepend the class token to every sequence
            x = self.class_embedding.reshape(1, 1, -1).expand(x.shape[0], 1, -1).cat(x, dim=1)
            x = x + self.positional_embedding
            x = self.ln_pre(x)
            x = self.transformer(x)
            x = self.ln_post(x)
            # pool by taking the class token (position 0)
            pooled = x[:, 0] @ self.proj
            return pooled
# https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/encoders/modules.py#L396
# https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/encoders/modules.py#L498
class FrozenOpenClipEmbedder(Embedder):
    """Text embedder built on ``Open.ClipTextTransformer`` that exposes the penultimate hidden state."""

    def __init__(self, dims:int, n_heads:int, layers:int, return_pooled:bool, ln_penultimate:bool=False):
        self.tokenizer = Tokenizer.ClipTokenizer()
        self.model = Open.ClipTextTransformer(dims, n_heads, layers)
        self.return_pooled = return_pooled
        self.input_key = "txt"
        self.ln_penultimate = ln_penultimate

    def tokenize(self, text:str, device:Optional[str]=None) -> Tensor:
        """Encode `text` (zero-padded) into an int64 tensor reshaped to a single row."""
        token_ids = self.tokenizer.encode(text, pad_with_zeros=True)
        return Tensor(token_ids, dtype=dtypes.int64, device=device).reshape(1,-1)

    def text_transformer_forward(self, x:Tensor, attn_mask:Optional[Tensor]=None):
        """Run every residual block; return (last output, input to the last block), axes 0 and 1 swapped."""
        for block in self.model.transformer.resblocks:
            # remember what went into this block: after the loop it is the penultimate state
            penultimate = x
            x = block(x, attn_mask=attn_mask)
        return x.permute(1, 0, 2), penultimate.permute(1, 0, 2)

    def embed_tokens(self, tokens:Tensor) -> Union[Tensor,Tuple[Tensor,...]]:
        """Return the penultimate hidden state, plus the pooled projection when ``return_pooled`` is set."""
        embedded = self.model.token_embedding(tokens).add(self.model.positional_embedding)
        x, penultimate = self.text_transformer_forward(embedded.permute(1,0,2), attn_mask=self.model.attn_mask)

        if self.ln_penultimate:
            penultimate = self.model.ln_final(penultimate)

        if not self.return_pooled:
            return penultimate

        x = self.model.ln_final(x)
        # gather, for each row, the hidden state at the position of the largest token id
        index = tokens.argmax(axis=-1).reshape(-1,1,1).expand(x.shape[0],1,x.shape[-1])
        pooled = x.gather(1, index).squeeze(1) @ self.model.text_projection
        return penultimate, pooled

    def __call__(self, texts:Union[str,List[str],Tensor]) -> Union[Tensor,Tuple[Tensor,...]]:
        """Tokenize one string or a list/tuple of strings and embed them."""
        if isinstance(texts, str):
            texts = [texts]
        assert isinstance(texts, (list,tuple)), f"expected list of strings, got {type(texts).__name__}"
        rows = [self.tokenize(text) for text in texts]
        return self.embed_tokens(Tensor.cat(*rows, dim=0))
# Model configurations keyed by architecture name.
clip_configs: Dict = {
    "ViT-H-14": {
        "dims": 1024,
        # keyword arguments for the vision tower
        "vision_cfg": {"width": 1280, "layers": 32, "d_head": 80, "image_size": 224, "patch_size": 14},
        # keyword arguments for the text tower
        "text_cfg": {"width": 1024, "n_heads": 16, "layers": 24, "ctx_length": 77, "vocab_size": 49408},
        "return_pooled": False,
        "ln_penultimate": True,
    },
}
class OpenClipEncoder:
    """Joint OpenCLIP image/text encoder used to score how well an image matches a token sequence."""

    def __init__(self, dims:int, text_cfg:Dict, vision_cfg:Dict, **_):
        """
        Build the vision and text towers.

        `dims` is accepted but not used here; any additional keyword arguments are ignored.
        """
        self.visual = Open.ClipVisionTransformer(**vision_cfg)
        text = Open.ClipTextTransformer(**text_cfg)
        # the text tower's parts are exposed as direct attributes of this encoder
        self.transformer = text.transformer
        self.token_embedding = text.token_embedding
        self.positional_embedding = text.positional_embedding
        self.ln_final = text.ln_final
        self.text_projection = text.text_projection
        # reuse the text tower's causal mask instead of building a second, separately sized one
        self.attn_mask = text.attn_mask
        # per-channel normalization constants, shaped (3, 1, 1) to broadcast over a CHW image
        self.mean = Tensor([0.48145466, 0.45782750, 0.40821073]).reshape(-1, 1, 1)
        self.std = Tensor([0.26862954, 0.26130258, 0.27577711]).reshape(-1, 1, 1)

    # TODO:
    # Should be doable in pure tinygrad, would just require some work and verification.
    # This is very desirable since it would allow for full generation->evaluation in a single JIT call.
    def prepare_image(self, image:Image.Image) -> Tensor:
        """
        Resize so the shorter side is 224, center-crop to 224x224 and normalize.

        Returns a float32 tensor in CHW layout, scaled to [0, 1] and then normalized with
        ``self.mean`` / ``self.std``.
        """
        SIZE = 224
        w, h = image.size
        # scale by the *shorter* side so both sides end up >= SIZE and the aspect ratio is preserved
        scale = max(SIZE / h, SIZE / w)
        # the max(..., SIZE) clamp guards against float truncation yielding SIZE-1
        image = image.resize((max(int(w*scale),SIZE),max(int(h*scale),SIZE)), Image.Resampling.BICUBIC)
        w, h = image.size
        if w > SIZE or h > SIZE:
            left = (w - SIZE) // 2
            top = (h - SIZE) // 2
            # PIL crop boxes are ordered (left, upper, right, lower)
            image = image.crop((left, top, left+SIZE, top+SIZE))
        x = Tensor(np.array(image.convert('RGB')))
        # HWC uint8 -> CHW float in [0, 1]
        x = x.permute(2, 0, 1).cast(dtypes.float32) / 255.0
        return (x - self.mean) / self.std

    def encode_tokens(self, tokens:Tensor) -> Tensor:
        """Run the text tower on token ids and return the projected features at the argmax token position."""
        x = self.token_embedding(tokens)
        x = x + self.positional_embedding
        x = self.transformer(x, attn_mask=self.attn_mask)
        x = self.ln_final(x)
        # NOTE(review): `x[:, idx]` selects the argmax positions of all rows for every batch row;
        # confirm this is intended for batch sizes > 1.
        x = x[:, tokens.argmax(axis=-1)]
        x = x @ self.text_projection
        return x

    def get_clip_score(self, tokens:Tensor, image:Tensor) -> Tensor:
        """Return the dot product of the L2-normalized image and text features (cosine similarity)."""
        image_features: Tensor = self.visual(image)
        image_features /= image_features.square().sum(-1, keepdim=True).sqrt() # L2 norm over the feature axis
        text_features = self.encode_tokens(tokens)
        text_features /= text_features.square().sum(-1, keepdim=True).sqrt() # L2 norm over the feature axis
        return (image_features * text_features).sum(axis=-1)