# tinygrad/extra/models/t5.py

# pip3 install sentencepiece
# adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py
# coding=utf-8
# Copyright 2018 Mesh TensorFlow authors, T5 Authors and HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tinygrad T5 model."""
from tinygrad import nn, Tensor, dtypes
import math
from dataclasses import dataclass
from typing import List, Union, Optional, Tuple
from pathlib import Path
from sentencepiece import SentencePieceProcessor

# default config is t5-xxl
@dataclass
class T5Config:
  d_ff:int = 10240
  d_kv:int = 64
  d_model:int = 4096
  layer_norm_epsilon:float = 1e-6
  num_decoder_layers:int = 24
  num_heads:int = 64
  num_layers:int = 24
  relative_attention_num_buckets:int = 32
  relative_attention_max_distance:int = 128
  vocab_size:int = 32128

class T5Tokenizer:
  def __init__(self, spiece_path):
    self.spp = SentencePieceProcessor(str(spiece_path))
  def encode(self, text:str, max_length:int) -> List[int]:
    encoded = self.spp.Encode(text)
    if len(encoded) > max_length - 1: encoded = encoded[:max_length - 1]
    return encoded + [1] + [0]*(max_length - len(encoded) - 1)
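
# encode() truncates to at most max_length-1 sentencepiece ids, appends id 1 (T5's </s> end-of-text
# token) and right-pads with id 0 (T5's <pad> token), so every prompt comes back as exactly
# max_length ids ready to be batched.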

class T5LayerNorm:
  def __init__(self, hidden_size:int, eps:float=1e-6):
    """
    Construct a layernorm module in the T5 style. No bias and no subtraction of mean.
    """
    self.weight = Tensor.ones(hidden_size)
    self.variance_epsilon = eps
  def __call__(self, hidden_states:Tensor) -> Tensor:
    # T5 uses a layer_norm which only scales and doesn't shift, also known as Root Mean Square
    # Layer Normalization (https://arxiv.org/abs/1910.07467), thus the variance is calculated
    # without subtracting the mean and there is no bias. Additionally we want to make sure that
    # the accumulation for half-precision inputs is done in fp32.
    variance = hidden_states.cast(dtypes.float32).pow(2).mean(-1, keepdim=True)
    hidden_states = hidden_states * Tensor.rsqrt(variance + self.variance_epsilon)
    # convert into half-precision if necessary
    if self.weight.dtype in [dtypes.float16, dtypes.bfloat16]:
      hidden_states = hidden_states.cast(self.weight.dtype)
    return self.weight * hidden_states
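
# In formula form the normalization above is y = weight * x / sqrt(mean(x_i^2) + eps):
# RMSNorm keeps only a learned per-channel scale, with no mean subtraction and no bias.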

class T5DenseGatedActDense:
  def __init__(self, config:T5Config):
    self.wi_0 = nn.Linear(config.d_model, config.d_ff, bias=False)
    self.wi_1 = nn.Linear(config.d_model, config.d_ff, bias=False)
    self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
  def __call__(self, hidden_states:Tensor) -> Tensor:
    hidden_gelu = self.wi_0(hidden_states).gelu()
    hidden_linear = self.wi_1(hidden_states)
    hidden_states = hidden_gelu * hidden_linear
    hidden_states = self.wo(hidden_states)
    return hidden_states
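
# This is the gated-GELU ("GEGLU") feed-forward used by T5 v1.1-style checkpoints such as t5-xxl:
# out = wo(gelu(wi_0(x)) * wi_1(x)), hence the two input projections and no bias terms. T5LayerFF
# below keeps the attribute name DenseReluDense so its weight keys line up with the reference checkpoints.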

class T5LayerFF:
  def __init__(self, config:T5Config):
    self.DenseReluDense = T5DenseGatedActDense(config)
    self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
  def __call__(self, hidden_states:Tensor) -> Tensor:
    forwarded_states = self.layer_norm(hidden_states)
    forwarded_states = self.DenseReluDense(forwarded_states)
    hidden_states = hidden_states + forwarded_states
    return hidden_states

class T5Attention:
  def __init__(self, config:T5Config, has_relative_attention_bias:bool=False):
    self.has_relative_attention_bias = has_relative_attention_bias
    self.relative_attention_num_buckets = config.relative_attention_num_buckets
    self.relative_attention_max_distance = config.relative_attention_max_distance
    self.d_model = config.d_model
    self.key_value_proj_dim = config.d_kv
    self.n_heads = config.num_heads
    self.inner_dim = self.n_heads * self.key_value_proj_dim
    # Mesh TensorFlow initialization to avoid scaling before softmax
    self.q = nn.Linear(self.d_model, self.inner_dim, bias=False)
    self.k = nn.Linear(self.d_model, self.inner_dim, bias=False)
    self.v = nn.Linear(self.d_model, self.inner_dim, bias=False)
    self.o = nn.Linear(self.inner_dim, self.d_model, bias=False)
    if self.has_relative_attention_bias:
      self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads)
  @staticmethod
  def _relative_position_bucket(relative_position:Tensor, num_buckets:int=32, max_distance:int=128) -> Tensor:
    """
    Adapted from Mesh Tensorflow:
    https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593
    Translate relative position to a bucket number for relative attention. The relative position is defined as
    memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
    position. This encoder-only variant is always bidirectional. We use smaller buckets for small absolute
    relative_position and larger buckets for larger absolute relative_positions. All relative positions
    >= max_distance map to the same bucket. All relative positions <= -max_distance map to the same bucket.
    This should allow for more graceful generalization to longer sequences than the model has been trained on.
    Args:
      relative_position: an integer Tensor
      num_buckets: an integer
      max_distance: an integer
    Returns:
      a Tensor with the same shape as relative_position, containing integer values in the range [0, num_buckets)
    """
    relative_buckets = Tensor.zeros_like(relative_position)
    num_buckets //= 2
    relative_buckets += (relative_position > 0).cast(dtypes.long) * num_buckets
    relative_position = Tensor.abs(relative_position)
    # half of the buckets are for exact increments in positions
    max_exact = num_buckets // 2
    is_small = relative_position < max_exact
    # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
    relative_position_if_large = max_exact + (
      Tensor.log(relative_position.float() / max_exact)
      / math.log(max_distance / max_exact)
      * (num_buckets - max_exact)
    ).cast(dtypes.long)
    relative_position_if_large = Tensor.min(
      Tensor.stack(
        relative_position_if_large, Tensor.full_like(relative_position_if_large, num_buckets - 1)
      ),
      axis=0,
    )
    relative_buckets += Tensor.where(is_small, relative_position, relative_position_if_large)
    return relative_buckets
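  # Worked example with the defaults (num_buckets=32, max_distance=128): after the halving above,
  # buckets 0-15 serve relative_position <= 0 and buckets 16-31 serve relative_position > 0. Small
  # offsets (|rel| < 8) each get their own bucket and larger offsets share logarithmically wider
  # buckets, e.g. rel=-3 -> bucket 3, rel=+3 -> bucket 19, rel=+50 -> bucket 29, rel >= 128 -> bucket 31.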
  def compute_bias(self, query_length, key_length, device=None) -> Tensor:
    """Compute binned relative position bias"""
    if device is None:
      device = self.relative_attention_bias.weight.device
    context_position = Tensor.arange(query_length, dtype=dtypes.long, device=device)[:, None]
    memory_position = Tensor.arange(key_length, dtype=dtypes.long, device=device)[None, :]
    relative_position = memory_position - context_position  # shape (query_length, key_length)
    relative_position_bucket = self._relative_position_bucket(
      relative_position,  # shape (query_length, key_length)
      num_buckets=self.relative_attention_num_buckets,
      max_distance=self.relative_attention_max_distance,
    )
    values = self.relative_attention_bias(relative_position_bucket)  # shape (query_length, key_length, num_heads)
    values = values.permute([2, 0, 1]).unsqueeze(0)  # shape (1, num_heads, query_length, key_length)
    return values
  def __call__(self, hidden_states:Tensor, position_bias:Optional[Tensor]=None) -> Tuple[Tensor,Tensor]:
    """
    Encoder self-attention over hidden_states; returns the attention output and the (possibly newly
    computed) relative position bias so callers can reuse it across layers.
    """
    # Input is (batch_size, seq_length, dim)
    batch_size, key_length = hidden_states.shape[:2]
    def shape(states):
      """projection"""
      return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
    def unshape(states):
      """reshape"""
      return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim)
    def project(hidden_states, proj_layer):
      """projects hidden states correctly to key/query states"""
      # self-attn
      # (batch_size, n_heads, seq_length, dim_per_head)
      return shape(proj_layer(hidden_states))
    # get query states
    query_states = shape(self.q(hidden_states))  # (batch_size, n_heads, seq_length, dim_per_head)
    # get key/value states
    key_states = project(hidden_states, self.k)
    value_states = project(hidden_states, self.v)
    # compute scores
    scores = Tensor.matmul(query_states, key_states.transpose(3, 2))
    if position_bias is None:
      position_bias = self.compute_bias(key_length, key_length, device=scores.device)
    scores += position_bias
    attn_weights = Tensor.softmax(scores.float(), axis=-1).cast(scores.dtype)  # (batch_size, n_heads, seq_length, key_length)
    attn_output = unshape(Tensor.matmul(attn_weights, value_states))  # (batch_size, seq_length, dim)
    attn_output = self.o(attn_output)
    return attn_output, position_bias

class T5LayerSelfAttention:
  def __init__(self, config:T5Config, has_relative_attention_bias:bool=False):
    self.SelfAttention = T5Attention(config, has_relative_attention_bias=has_relative_attention_bias)
    self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
  def __call__(self, hidden_states:Tensor, position_bias:Optional[Tensor]=None) -> Tuple[Tensor, Tensor]:
    normed_hidden_states = self.layer_norm(hidden_states)
    attention_output, position_bias = self.SelfAttention(normed_hidden_states, position_bias=position_bias)
    return hidden_states + attention_output, position_bias

class T5Block:
  def __init__(self, config:T5Config, has_relative_attention_bias:bool=False):
    self.layer = (T5LayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias),
                  T5LayerFF(config))
  def __call__(self, hidden_states:Tensor, position_bias:Optional[Tensor]=None) -> Tuple[Tensor, Tensor]:
    self_attention_outputs, position_bias = self.layer[0](hidden_states, position_bias=position_bias)
    hidden_states = self_attention_outputs
    # Apply Feed Forward layer
    hidden_states = self.layer[-1](hidden_states)
    return hidden_states, position_bias

class T5Stack:
  def __init__(self, config:T5Config, embed_tokens:nn.Embedding):
    self.config = config
    self.embed_tokens = embed_tokens
    self.block = [T5Block(config, has_relative_attention_bias=bool(i == 0)) for i in range(config.num_layers)]
    self.final_layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
  def __call__(self, input_ids:Tensor) -> Tensor:
    input_ids = input_ids.view(-1, input_ids.shape[-1])
    hidden_states, position_bias = self.embed_tokens(input_ids), None
    for layer_module in self.block:
      hidden_states, position_bias = layer_module(hidden_states, position_bias=position_bias)
    return self.final_layer_norm(hidden_states)
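
# Only the first block owns a relative_attention_bias table; T5Stack passes the position_bias that
# block produces into every later block, so the bias is computed once per forward pass and shared.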

class T5EncoderModel:
  def __init__(self, config:T5Config):
    self.shared = nn.Embedding(config.vocab_size, config.d_model)
    self.encoder = T5Stack(config, self.shared)
  def __call__(self, input_ids:Tensor) -> Tensor:
    return self.encoder(input_ids)

class T5Embedder:
  def __init__(self, max_length:int, spiece_path:Union[str, Path]):
    self.tokenizer = T5Tokenizer(spiece_path)
    self.max_length = max_length
    config = T5Config()
    self.encoder = T5EncoderModel(config)
  def __call__(self, texts:Union[str, List[str]]) -> Tensor:
    if isinstance(texts, str): texts = [texts]
    toks = Tensor.cat(*[Tensor(self.tokenizer.encode(text, self.max_length)) for text in texts], dim=0)
    return self.encoder(toks)
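
# Minimal smoke-test sketch. The tiny config and token ids below are made up for illustration only;
# real use goes through T5Embedder with a sentencepiece model file and pretrained t5-xxl encoder weights.
if __name__ == "__main__":
  cfg = T5Config(d_ff=256, d_kv=16, d_model=64, num_heads=4, num_layers=2)  # not a real T5 size
  enc = T5EncoderModel(cfg)
  toks = Tensor([[3, 9, 1712, 1, 0, 0, 0, 0]])  # arbitrary ids, then EOS (1) and padding (0)
  out = enc(toks)
  print(out.shape)  # expected (1, 8, 64): batch of one 8-token sequence, d_model-sized features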