mirror of https://github.com/commaai/tinygrad.git
288 lines
12 KiB
Python
288 lines
12 KiB
Python
# pip3 install sentencepiece
|
|
|
|
# adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py
|
|
|
|
# coding=utf-8
|
|
# Copyright 2018 Mesh TensorFlow authors, T5 Authors and HuggingFace Inc. team.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""Tinygrad T5 model."""
|
|
from tinygrad import nn, Tensor, dtypes
|
|
|
|
import math
|
|
from dataclasses import dataclass
|
|
from typing import List, Union, Optional, Tuple
|
|
from pathlib import Path
|
|
from sentencepiece import SentencePieceProcessor
|
|
|
|
# default config is t5-xxl
|
|
@dataclass
|
|
class T5Config:
|
|
d_ff:int = 10240
|
|
d_kv:int = 64
|
|
d_model:int = 4096
|
|
layer_norm_epsilon:float = 1e-6
|
|
num_decoder_layers:int = 24
|
|
num_heads:int = 64
|
|
num_layers:int = 24
|
|
relative_attention_num_buckets:int = 32
|
|
relative_attention_max_distance:int = 128
|
|
vocab_size:int = 32128
|
|
|
|
class T5Tokenizer:
|
|
def __init__(self, spiece_path):
|
|
self.spp = SentencePieceProcessor(str(spiece_path))
|
|
|
|
def encode(self, text:str, max_length:int) -> List[int]:
|
|
encoded = self.spp.Encode(text)
|
|
if len(encoded) > max_length - 1: encoded = encoded[:max_length - 1]
|
|
return encoded + [1] + [0]*(max_length - len(encoded) - 1)
|
|
|
|
class T5LayerNorm:
|
|
def __init__(self, hidden_size:int, eps:float=1e-6):
|
|
"""
|
|
Construct a layernorm module in the T5 style. No bias and no subtraction of mean.
|
|
"""
|
|
self.weight = Tensor.ones(hidden_size)
|
|
self.variance_epsilon = eps
|
|
|
|
def __call__(self, hidden_states:Tensor) -> Tensor:
|
|
# T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean
|
|
# Square Layer Normalization https://arxiv.org/abs/1910.07467 thus varience is calculated
|
|
# w/o mean and there is no bias. Additionally we want to make sure that the accumulation for
|
|
# half-precision inputs is done in fp32
|
|
|
|
variance = hidden_states.cast(dtypes.float32).pow(2).mean(-1, keepdim=True)
|
|
hidden_states = hidden_states * Tensor.rsqrt(variance + self.variance_epsilon)
|
|
|
|
# convert into half-precision if necessary
|
|
if self.weight.dtype in [dtypes.float16, dtypes.bfloat16]:
|
|
hidden_states = hidden_states.cast(self.weight.dtype)
|
|
|
|
return self.weight * hidden_states
|
|
|
|
|
|
class T5DenseGatedActDense:
|
|
def __init__(self, config:T5Config):
|
|
self.wi_0 = nn.Linear(config.d_model, config.d_ff, bias=False)
|
|
self.wi_1 = nn.Linear(config.d_model, config.d_ff, bias=False)
|
|
self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
|
|
|
|
def __call__(self, hidden_states:Tensor) -> Tensor:
|
|
hidden_gelu = self.wi_0(hidden_states).gelu()
|
|
hidden_linear = self.wi_1(hidden_states)
|
|
hidden_states = hidden_gelu * hidden_linear
|
|
hidden_states = self.wo(hidden_states)
|
|
return hidden_states
|
|
|
|
|
|
class T5LayerFF:
|
|
def __init__(self, config:T5Config):
|
|
self.DenseReluDense = T5DenseGatedActDense(config)
|
|
self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
|
|
|
|
def __call__(self, hidden_states:Tensor) -> Tensor:
|
|
forwarded_states = self.layer_norm(hidden_states)
|
|
forwarded_states = self.DenseReluDense(forwarded_states)
|
|
hidden_states = hidden_states + forwarded_states
|
|
return hidden_states
|
|
|
|
|
|
class T5Attention:
|
|
def __init__(self, config:T5Config, has_relative_attention_bias:bool=False):
|
|
self.has_relative_attention_bias = has_relative_attention_bias
|
|
self.relative_attention_num_buckets = config.relative_attention_num_buckets
|
|
self.relative_attention_max_distance = config.relative_attention_max_distance
|
|
self.d_model = config.d_model
|
|
self.key_value_proj_dim = config.d_kv
|
|
self.n_heads = config.num_heads
|
|
self.inner_dim = self.n_heads * self.key_value_proj_dim
|
|
|
|
# Mesh TensorFlow initialization to avoid scaling before softmax
|
|
self.q = nn.Linear(self.d_model, self.inner_dim, bias=False)
|
|
self.k = nn.Linear(self.d_model, self.inner_dim, bias=False)
|
|
self.v = nn.Linear(self.d_model, self.inner_dim, bias=False)
|
|
self.o = nn.Linear(self.inner_dim, self.d_model, bias=False)
|
|
|
|
if self.has_relative_attention_bias:
|
|
self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads)
|
|
|
|
@staticmethod
|
|
def _relative_position_bucket(relative_position:Tensor, num_buckets:int=32, max_distance:int=128) -> Tensor:
|
|
"""
|
|
Adapted from Mesh Tensorflow:
|
|
https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593
|
|
|
|
Translate relative position to a bucket number for relative attention. The relative position is defined as
|
|
memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
|
|
position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
|
|
small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
|
|
positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
|
|
This should allow for more graceful generalization to longer sequences than the model has been trained on
|
|
|
|
Args:
|
|
relative_position: an int32 Tensor
|
|
bidirectional: a boolean - whether the attention is bidirectional
|
|
num_buckets: an integer
|
|
max_distance: an integer
|
|
|
|
Returns:
|
|
a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
|
|
"""
|
|
relative_buckets = Tensor.zeros_like(relative_position)
|
|
num_buckets //= 2
|
|
relative_buckets += (relative_position > 0).cast(dtypes.long) * num_buckets
|
|
relative_position = Tensor.abs(relative_position)
|
|
|
|
# half of the buckets are for exact increments in positions
|
|
max_exact = num_buckets // 2
|
|
is_small = relative_position < max_exact
|
|
|
|
# The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
|
|
relative_position_if_large = max_exact + (
|
|
Tensor.log(relative_position.float() / max_exact)
|
|
/ math.log(max_distance / max_exact)
|
|
* (num_buckets - max_exact)
|
|
).cast(dtypes.long)
|
|
|
|
relative_position_if_large = Tensor.min(
|
|
Tensor.stack(
|
|
relative_position_if_large, Tensor.full_like(relative_position_if_large, num_buckets - 1)
|
|
),
|
|
axis=0,
|
|
)
|
|
relative_buckets += Tensor.where(is_small, relative_position, relative_position_if_large)
|
|
return relative_buckets
|
|
|
|
def compute_bias(self, query_length, key_length, device=None) -> Tensor:
|
|
"""Compute binned relative position bias"""
|
|
if device is None:
|
|
device = self.relative_attention_bias.weight.device
|
|
context_position = Tensor.arange(query_length, dtype=dtypes.long, device=device)[:, None]
|
|
memory_position = Tensor.arange(key_length, dtype=dtypes.long, device=device)[None, :]
|
|
relative_position = memory_position - context_position # shape (query_length, key_length)
|
|
relative_position_bucket = self._relative_position_bucket(
|
|
relative_position, # shape (query_length, key_length)
|
|
num_buckets=self.relative_attention_num_buckets,
|
|
max_distance=self.relative_attention_max_distance,
|
|
)
|
|
values = self.relative_attention_bias(relative_position_bucket) # shape (query_length, key_length, num_heads)
|
|
values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length)
|
|
return values
|
|
|
|
def __call__(self, hidden_states:Tensor, position_bias:Optional[Tensor]=None) -> Tuple[Tensor,Tensor]:
|
|
"""
|
|
Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
|
|
"""
|
|
# Input is (batch_size, seq_length, dim)
|
|
batch_size, key_length = hidden_states.shape[:2]
|
|
|
|
def shape(states):
|
|
"""projection"""
|
|
return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
|
|
|
|
def unshape(states):
|
|
"""reshape"""
|
|
return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim)
|
|
|
|
def project(hidden_states, proj_layer):
|
|
"""projects hidden states correctly to key/query states"""
|
|
# self-attn
|
|
# (batch_size, n_heads, seq_length, dim_per_head)
|
|
return shape(proj_layer(hidden_states))
|
|
|
|
# get query states
|
|
query_states = shape(self.q(hidden_states)) # (batch_size, n_heads, seq_length, dim_per_head)
|
|
|
|
# get key/value states
|
|
key_states = project(hidden_states, self.k)
|
|
value_states = project(hidden_states, self.v)
|
|
|
|
# compute scores
|
|
scores = Tensor.matmul(query_states, key_states.transpose(3, 2))
|
|
|
|
if position_bias is None:
|
|
position_bias = self.compute_bias(key_length, key_length, device=scores.device)
|
|
|
|
scores += position_bias
|
|
attn_weights = Tensor.softmax(scores.float(), axis=-1).cast(scores.dtype) # (batch_size, n_heads, seq_length, key_length)
|
|
|
|
attn_output = unshape(Tensor.matmul(attn_weights, value_states)) # (batch_size, seq_length, dim)
|
|
attn_output = self.o(attn_output)
|
|
|
|
return attn_output, position_bias
|
|
|
|
|
|
class T5LayerSelfAttention:
|
|
def __init__(self, config:T5Config, has_relative_attention_bias:bool=False):
|
|
self.SelfAttention = T5Attention(config, has_relative_attention_bias=has_relative_attention_bias)
|
|
self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
|
|
|
|
def __call__(self, hidden_states:Tensor, position_bias:Optional[Tensor]=None) -> Tuple[Tensor, Tensor]:
|
|
normed_hidden_states = self.layer_norm(hidden_states)
|
|
attention_output, position_bias = self.SelfAttention(normed_hidden_states, position_bias=position_bias)
|
|
return hidden_states + attention_output, position_bias
|
|
|
|
|
|
class T5Block:
|
|
def __init__(self, config:T5Config, has_relative_attention_bias:bool=False):
|
|
self.layer = (T5LayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias),
|
|
T5LayerFF(config))
|
|
|
|
def __call__(self, hidden_states:Tensor, position_bias:Optional[Tensor]=None) -> Tuple[Tensor, Tensor]:
|
|
self_attention_outputs, position_bias = self.layer[0](hidden_states, position_bias=position_bias)
|
|
hidden_states = self_attention_outputs
|
|
|
|
# Apply Feed Forward layer
|
|
hidden_states = self.layer[-1](hidden_states)
|
|
|
|
return hidden_states, position_bias
|
|
|
|
|
|
class T5Stack:
|
|
def __init__(self, config:T5Config, embed_tokens:nn.Embedding):
|
|
self.config = config
|
|
self.embed_tokens = embed_tokens
|
|
self.block = [T5Block(config, has_relative_attention_bias=bool(i == 0)) for i in range(config.num_layers)]
|
|
self.final_layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
|
|
|
|
def __call__(self, input_ids:Tensor) -> Tensor:
|
|
input_ids = input_ids.view(-1, input_ids.shape[-1])
|
|
|
|
hidden_states, position_bias = self.embed_tokens(input_ids), None
|
|
|
|
for layer_module in self.block:
|
|
hidden_states, position_bias = layer_module(hidden_states, position_bias=position_bias)
|
|
|
|
return self.final_layer_norm(hidden_states)
|
|
|
|
|
|
class T5EncoderModel:
|
|
def __init__(self, config:T5Config):
|
|
self.shared = nn.Embedding(config.vocab_size, config.d_model)
|
|
self.encoder = T5Stack(config, self.shared)
|
|
|
|
def __call__(self, input_ids:Tensor) -> Tensor:
|
|
return self.encoder(input_ids)
|
|
|
|
class T5Embedder:
|
|
def __init__(self, max_length:int, spiece_path:Union[str, Path]):
|
|
self.tokenizer = T5Tokenizer(spiece_path)
|
|
self.max_length = max_length
|
|
config = T5Config()
|
|
self.encoder = T5EncoderModel(config)
|
|
|
|
def __call__(self, texts:Union[str, List[str]]) -> Tensor:
|
|
if isinstance(texts, str): texts = [texts]
|
|
toks = Tensor.cat(*[Tensor(self.tokenizer.encode(text, self.max_length)) for text in texts], dim=0)
|
|
return self.encoder(toks) |