mirror of https://github.com/commaai/tinygrad.git
Flux.1 (#6334)
* initial commit
* whitespace
* get rid of torch import
* indentation
* less hardcoding
* add flux.1-dev
* jit
* no double
* t5 tidy up
* validation image
* reuse sdxl autoencoder
* typing changes
* empty lines
* remove unneeded comments

Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
parent 31b9c74c77
commit 19c11792fd

@@ -0,0 +1,496 @@
# pip3 install sentencepiece

# This file incorporates code from the following:
# Github Name | License | Link
# black-forest-labs/flux | Apache | https://github.com/black-forest-labs/flux/tree/main/model_licenses

from tinygrad import Tensor, nn, dtypes, TinyJit
from tinygrad.nn.state import safe_load, load_state_dict
from tinygrad.helpers import fetch, tqdm, colored
from sdxl import FirstStage
from extra.models.clip import FrozenClosedClipEmbedder
from extra.models.t5 import T5Embedder
import numpy as np

import math, time, argparse, tempfile
from typing import List, Dict, Optional, Union, Tuple, Callable
from dataclasses import dataclass
from pathlib import Path
from PIL import Image

urls:dict = {
  "flux-schnell": "https://huggingface.co/black-forest-labs/FLUX.1-schnell/resolve/main/flux1-schnell.safetensors",
  "flux-dev": "https://huggingface.co/camenduru/FLUX.1-dev/resolve/main/flux1-dev.sft",
  "ae": "https://huggingface.co/black-forest-labs/FLUX.1-schnell/resolve/main/ae.safetensors",
  "T5_1_of_2": "https://huggingface.co/black-forest-labs/FLUX.1-schnell/resolve/main/text_encoder_2/model-00001-of-00002.safetensors",
  "T5_2_of_2": "https://huggingface.co/black-forest-labs/FLUX.1-schnell/resolve/main/text_encoder_2/model-00002-of-00002.safetensors",
  "T5_tokenizer": "https://huggingface.co/black-forest-labs/FLUX.1-schnell/resolve/main/tokenizer_2/spiece.model",
  "clip": "https://huggingface.co/black-forest-labs/FLUX.1-schnell/resolve/main/text_encoder/model.safetensors"
}

def tensor_identity(x:Tensor) -> Tensor: return x

class AutoEncoder:
  def __init__(self, scale_factor:float, shift_factor:float):
    self.decoder = FirstStage.Decoder(128, 3, 3, 16, [1, 2, 4, 4], 2, 256)
    self.scale_factor = scale_factor
    self.shift_factor = shift_factor

  def decode(self, z:Tensor) -> Tensor:
    z = z / self.scale_factor + self.shift_factor
    return self.decoder(z)

# Conditioner
class ClipEmbedder(FrozenClosedClipEmbedder):
  def __call__(self, texts:Union[str, List[str], Tensor]) -> Tensor:
    if isinstance(texts, str): texts = [texts]
    assert isinstance(texts, (list,tuple)), f"expected list of strings, got {type(texts).__name__}"
    tokens = Tensor.cat(*[Tensor(self.tokenizer.encode(text)) for text in texts], dim=0)
    return self.transformer.text_model(tokens.reshape(len(texts),-1))[:, tokens.argmax(-1)]

# https://github.com/black-forest-labs/flux/blob/main/src/flux/math.py
def attention(q:Tensor, k:Tensor, v:Tensor, pe:Tensor) -> Tensor:
  q, k = apply_rope(q, k, pe)
  x = Tensor.scaled_dot_product_attention(q, k, v)
  return x.rearrange("B H L D -> B L (H D)")
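
# NOTE: illustrative shapes, assuming the Flux defaults below (24 heads, hidden 3072): q/k/v enter
# as (B, 24, L, 128) and pe as (B, 1, L, 64, 2, 2) from EmbedND; apply_rope rotates q/k per
# position, and the result is flattened back to (B, L, 3072) for the following Linear projections.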

def rope(pos:Tensor, dim:int, theta:int) -> Tensor:
  assert dim % 2 == 0
  scale = Tensor.arange(0, dim, 2, dtype=dtypes.float32, device=pos.device) / dim  # NOTE: this is torch.float64 in reference implementation
  omega = 1.0 / (theta**scale)
  out = Tensor.einsum("...n,d->...nd", pos, omega)
  out = Tensor.stack(Tensor.cos(out), -Tensor.sin(out), Tensor.sin(out), Tensor.cos(out), dim=-1)
  out = out.rearrange("b n d (i j) -> b n d i j", i=2, j=2)
  return out.float()

def apply_rope(xq:Tensor, xk:Tensor, freqs_cis:Tensor) -> Tuple[Tensor, Tensor]:
  xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
  xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
  xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
  xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
  return xq_out.reshape(*xq.shape).cast(xq.dtype), xk_out.reshape(*xk.shape).cast(xk.dtype)
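
# NOTE: minimal sketch of how rope/apply_rope compose (hypothetical values, not from the source):
#   pos = Tensor.arange(4)[None]            # (1, 4) token positions
#   pe = rope(pos, 16, 10_000)              # (1, 4, 8, 2, 2) rotation matrices [[cos, -sin], [sin, cos]]
#   q2, k2 = apply_rope(q, k, pe[:, None])  # same shapes as q, k, pairs of channels rotated per position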

# https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py
class EmbedND:
  def __init__(self, dim:int, theta:int, axes_dim:List[int]):
    self.dim = dim
    self.theta = theta
    self.axes_dim = axes_dim

  def __call__(self, ids:Tensor) -> Tensor:
    n_axes = ids.shape[-1]
    emb = Tensor.cat(*[rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], dim=-3)
    return emb.unsqueeze(1)
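
  # NOTE: each position id carries one component per axis (3 here: a constant, the row and the
  # column, as built in prepare() below); every axis gets its own rope() of size axes_dim[i], and
  # the results are concatenated along the frequency dim, so sum(axes_dim) must equal pe_dim.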

class MLPEmbedder:
  def __init__(self, in_dim:int, hidden_dim:int):
    self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
    self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)

  def __call__(self, x:Tensor) -> Tensor:
    return self.out_layer(self.in_layer(x).silu())

class QKNorm:
  def __init__(self, dim:int):
    self.query_norm = nn.RMSNorm(dim)
    self.key_norm = nn.RMSNorm(dim)

  def __call__(self, q:Tensor, k:Tensor) -> Tuple[Tensor, Tensor]:
    return self.query_norm(q), self.key_norm(k)

class SelfAttention:
  def __init__(self, dim:int, num_heads:int = 8, qkv_bias:bool = False):
    self.num_heads = num_heads
    head_dim = dim // num_heads

    self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
    self.norm = QKNorm(head_dim)
    self.proj = nn.Linear(dim, dim)

  def __call__(self, x:Tensor, pe:Tensor) -> Tensor:
    qkv = self.qkv(x)
    q, k, v = qkv.rearrange("B L (K H D) -> K B H L D", K=3, H=self.num_heads)
    q, k = self.norm(q, k)
    x = attention(q, k, v, pe=pe)
    return self.proj(x)

@dataclass
class ModulationOut:
  shift:Tensor
  scale:Tensor
  gate:Tensor

class Modulation:
  def __init__(self, dim:int, double:bool):
    self.is_double = double
    self.multiplier = 6 if double else 3
    self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)

  def __call__(self, vec:Tensor) -> Tuple[ModulationOut, Optional[ModulationOut]]:
    out = self.lin(vec.silu())[:, None, :].chunk(self.multiplier, dim=-1)
    return ModulationOut(*out[:3]), ModulationOut(*out[3:]) if self.is_double else None
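
  # NOTE: illustrative: with double=True the single Linear emits 6*dim values which chunk into two
  # (shift, scale, gate) triples, one modulating the attention branch and one the MLP branch of a
  # block; with double=False only one triple is produced and the second tuple element is None.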

class DoubleStreamBlock:
  def __init__(self, hidden_size:int, num_heads:int, mlp_ratio:float, qkv_bias:bool = False):
    mlp_hidden_dim = int(hidden_size * mlp_ratio)
    self.num_heads = num_heads
    self.hidden_size = hidden_size
    self.img_mod = Modulation(hidden_size, double=True)
    self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
    self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)

    self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
    self.img_mlp = [nn.Linear(hidden_size, mlp_hidden_dim, bias=True), Tensor.gelu, nn.Linear(mlp_hidden_dim, hidden_size, bias=True)]

    self.txt_mod = Modulation(hidden_size, double=True)
    self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
    self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)

    self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
    self.txt_mlp = [nn.Linear(hidden_size, mlp_hidden_dim, bias=True), Tensor.gelu, nn.Linear(mlp_hidden_dim, hidden_size, bias=True)]

  def __call__(self, img:Tensor, txt:Tensor, vec:Tensor, pe:Tensor) -> tuple[Tensor, Tensor]:
    img_mod1, img_mod2 = self.img_mod(vec)
    txt_mod1, txt_mod2 = self.txt_mod(vec)
    assert img_mod2 is not None and txt_mod2 is not None
    # prepare image for attention
    img_modulated = self.img_norm1(img)
    img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
    img_qkv = self.img_attn.qkv(img_modulated)
    img_q, img_k, img_v = img_qkv.rearrange("B L (K H D) -> K B H L D", K=3, H=self.num_heads)
    img_q, img_k = self.img_attn.norm(img_q, img_k)

    # prepare txt for attention
    txt_modulated = self.txt_norm1(txt)
    txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
    txt_qkv = self.txt_attn.qkv(txt_modulated)
    txt_q, txt_k, txt_v = txt_qkv.rearrange("B L (K H D) -> K B H L D", K=3, H=self.num_heads)
    txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k)

    # run actual attention
    q = Tensor.cat(txt_q, img_q, dim=2)
    k = Tensor.cat(txt_k, img_k, dim=2)
    v = Tensor.cat(txt_v, img_v, dim=2)

    attn = attention(q, k, v, pe=pe)
    txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]

    # calculate the img blocks
    img = img + img_mod1.gate * self.img_attn.proj(img_attn)
    img = img + img_mod2.gate * ((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift).sequential(self.img_mlp)

    # calculate the txt blocks
    txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
    txt = txt + txt_mod2.gate * ((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift).sequential(self.txt_mlp)
    return img, txt


class SingleStreamBlock:
  """
  A DiT block with parallel linear layers as described in
  https://arxiv.org/abs/2302.05442 and adapted modulation interface.
  """

  def __init__(self, hidden_size:int, num_heads:int, mlp_ratio:float=4.0, qk_scale:Optional[float]=None):
    self.hidden_dim = hidden_size
    self.num_heads = num_heads
    head_dim = hidden_size // num_heads
    self.scale = qk_scale or head_dim**-0.5

    self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
    # qkv and mlp_in
    self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
    # proj and mlp_out
    self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)

    self.norm = QKNorm(head_dim)

    self.hidden_size = hidden_size
    self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)

    self.mlp_act = Tensor.gelu
    self.modulation = Modulation(hidden_size, double=False)

  def __call__(self, x:Tensor, vec:Tensor, pe:Tensor) -> Tensor:
    mod, _ = self.modulation(vec)
    x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
    qkv, mlp = Tensor.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
    q, k, v = qkv.rearrange("B L (K H D) -> K B H L D", K=3, H=self.num_heads)
    q, k = self.norm(q, k)

    # compute attention
    attn = attention(q, k, v, pe=pe)
    # compute activation in mlp stream, cat again and run second linear layer
    output = self.linear2(Tensor.cat(attn, self.mlp_act(mlp), dim=2))
    return x + mod.gate * output


class LastLayer:
  def __init__(self, hidden_size:int, patch_size:int, out_channels:int):
    self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
    self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
    self.adaLN_modulation:List[Callable[[Tensor], Tensor]] = [Tensor.silu, nn.Linear(hidden_size, 2 * hidden_size, bias=True)]

  def __call__(self, x:Tensor, vec:Tensor) -> Tensor:
    shift, scale = vec.sequential(self.adaLN_modulation).chunk(2, dim=1)
    x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
    return self.linear(x)

def timestep_embedding(t:Tensor, dim:int, max_period:int=10000, time_factor:float=1000.0) -> Tensor:
  """
  Create sinusoidal timestep embeddings.
  :param t: a 1-D Tensor of N indices, one per batch element. These may be fractional.
  :param dim: the dimension of the output.
  :param max_period: controls the minimum frequency of the embeddings.
  :return: an (N, D) Tensor of positional embeddings.
  """
  t = time_factor * t
  half = dim // 2
  freqs = Tensor.exp(-math.log(max_period) * Tensor.arange(0, stop=half, dtype=dtypes.float32) / half).to(t.device)

  args = t[:, None].float() * freqs[None]
  embedding = Tensor.cat(Tensor.cos(args), Tensor.sin(args), dim=-1)
  if dim % 2: embedding = Tensor.cat(*[embedding, Tensor.zeros_like(embedding[:, :1])], dim=-1)
  if Tensor.is_floating_point(t): embedding = embedding.cast(t.dtype)
  return embedding
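
# NOTE: quick sanity sketch (assumed values): timestep_embedding(Tensor([1.0]), 256) has shape
# (1, 256); the first 128 entries are cos terms and the last 128 sin terms, with frequencies
# log-spaced from 1 down to roughly 1/max_period, after t is scaled by time_factor=1000.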

# https://github.com/black-forest-labs/flux/blob/main/src/flux/model.py
class Flux:
  """
  Transformer model for flow matching on sequences.
  """

  def __init__(
    self,
    guidance_embed:bool,
    in_channels:int = 64,
    vec_in_dim:int = 768,
    context_in_dim:int = 4096,
    hidden_size:int = 3072,
    mlp_ratio:float = 4.0,
    num_heads:int = 24,
    depth:int = 19,
    depth_single_blocks:int = 38,
    axes_dim:Optional[List[int]] = None,
    theta:int = 10_000,
    qkv_bias:bool = True,
  ):

    axes_dim = axes_dim or [16, 56, 56]
    self.guidance_embed = guidance_embed
    self.in_channels = in_channels
    self.out_channels = self.in_channels
    if hidden_size % num_heads != 0:
      raise ValueError(f"Hidden size {hidden_size} must be divisible by num_heads {num_heads}")
    pe_dim = hidden_size // num_heads
    if sum(axes_dim) != pe_dim:
      raise ValueError(f"Got {axes_dim} but expected positional dim {pe_dim}")
    self.hidden_size = hidden_size
    self.num_heads = num_heads
    self.pe_embedder = EmbedND(dim=pe_dim, theta=theta, axes_dim=axes_dim)
    self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
    self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
    self.vector_in = MLPEmbedder(vec_in_dim, self.hidden_size)
    self.guidance_in:Callable[[Tensor], Tensor] = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if guidance_embed else tensor_identity
    self.txt_in = nn.Linear(context_in_dim, self.hidden_size)

    self.double_blocks = [DoubleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias) for _ in range(depth)]
    self.single_blocks = [SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=mlp_ratio) for _ in range(depth_single_blocks)]
    self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)

  def __call__(self, img:Tensor, img_ids:Tensor, txt:Tensor, txt_ids:Tensor, timesteps:Tensor, y:Tensor, guidance:Optional[Tensor] = None) -> Tensor:
    if img.ndim != 3 or txt.ndim != 3:
      raise ValueError("Input img and txt tensors must have 3 dimensions.")
    # running on sequences img
    img = self.img_in(img)
    vec = self.time_in(timestep_embedding(timesteps, 256))
    if self.guidance_embed:
      if guidance is None:
        raise ValueError("Didn't get guidance strength for guidance distilled model.")
      vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
    vec = vec + self.vector_in(y)
    txt = self.txt_in(txt)
    ids = Tensor.cat(txt_ids, img_ids, dim=1)
    pe = self.pe_embedder(ids)
    for double_block in self.double_blocks:
      img, txt = double_block(img=img, txt=txt, vec=vec, pe=pe)

    img = Tensor.cat(txt, img, dim=1)
    for single_block in self.single_blocks:
      img = single_block(img, vec=vec, pe=pe)

    img = img[:, txt.shape[1] :, ...]

    return self.final_layer(img, vec)  # (N, T, patch_size ** 2 * out_channels)

# https://github.com/black-forest-labs/flux/blob/main/src/flux/util.py
def load_flow_model(name:str):
  # Loading Flux
  print("Init model")
  model = Flux(guidance_embed=(name != "flux-schnell"))
  state_dict = {k.replace("scale", "weight"): v for k, v in safe_load(fetch(urls[name])).items()}
  load_state_dict(model, state_dict)
  return model

def load_T5(max_length:int=512):
  # max length 64, 128, 256 and 512 should work (if your sequence is short enough)
  print("Init T5")
  T5 = T5Embedder(max_length, fetch(urls["T5_tokenizer"]))
  pt_1 = fetch(urls["T5_1_of_2"])
  pt_2 = fetch(urls["T5_2_of_2"])
  load_state_dict(T5.encoder, safe_load(pt_1) | safe_load(pt_2), strict=False)
  return T5

def load_clip():
  print("Init Clip")
  clip = ClipEmbedder()
  load_state_dict(clip.transformer, safe_load(fetch(urls["clip"])))
  return clip

def load_ae() -> AutoEncoder:
  # Loading the autoencoder
  print("Init AE")
  ae = AutoEncoder(0.3611, 0.1159)
  load_state_dict(ae, safe_load(fetch(urls["ae"])))
  return ae

# https://github.com/black-forest-labs/flux/blob/main/src/flux/sampling.py
def prepare(T5:T5Embedder, clip:ClipEmbedder, img:Tensor, prompt:Union[str, List[str]]) -> Dict[str, Tensor]:
  bs, _, h, w = img.shape
  if bs == 1 and not isinstance(prompt, str):
    bs = len(prompt)

  img = img.rearrange("b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
  if img.shape[0] == 1 and bs > 1:
    img = img.expand((bs, *img.shape[1:]))

  img_ids = Tensor.zeros(h // 2, w // 2, 3).contiguous()
  img_ids[..., 1] = img_ids[..., 1] + Tensor.arange(h // 2)[:, None]
  img_ids[..., 2] = img_ids[..., 2] + Tensor.arange(w // 2)[None, :]
  img_ids = img_ids.rearrange("h w c -> 1 (h w) c")
  img_ids = img_ids.expand((bs, *img_ids.shape[1:]))

  if isinstance(prompt, str):
    prompt = [prompt]
  txt = T5(prompt).realize()
  if txt.shape[0] == 1 and bs > 1:
    txt = txt.expand((bs, *txt.shape[1:]))
  txt_ids = Tensor.zeros(bs, txt.shape[1], 3)

  vec = clip(prompt).realize()
  if vec.shape[0] == 1 and bs > 1:
    vec = vec.expand((bs, *vec.shape[1:]))

  return {"img": img, "img_ids": img_ids.to(img.device), "txt": txt.to(img.device), "txt_ids": txt_ids.to(img.device), "vec": vec.to(img.device)}

def get_schedule(num_steps:int, image_seq_len:int, base_shift:float=0.5, max_shift:float=1.15, shift:bool=True) -> List[float]:
  # extra step for zero
  step_size = -1.0 / num_steps
  timesteps = Tensor.arange(1, 0 + step_size, step_size)

  # shifting the schedule to favor high timesteps for higher signal images
  if shift:
    # estimate mu based on linear estimation between two points
    mu = 0.5 + (max_shift - base_shift) * (image_seq_len - 256) / (4096 - 256)
    timesteps = math.exp(mu) / (math.exp(mu) + (1 / timesteps - 1))
  return timesteps.tolist()
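
# NOTE: example (exact by construction): get_schedule(4, seq_len, shift=False) returns
# [1.0, 0.75, 0.5, 0.25, 0.0]; with shift=True the same grid is warped toward t=1 by exp(mu),
# where mu grows linearly with the token count between base_shift and max_shift.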

@TinyJit
def run(model, *args): return model(*args).realize()

def denoise(model, img:Tensor, img_ids:Tensor, txt:Tensor, txt_ids:Tensor, vec:Tensor, timesteps:List[float], guidance:float=4.0) -> Tensor:
  # this is ignored for schnell
  guidance_vec = Tensor((guidance,), device=img.device, dtype=img.dtype).expand((img.shape[0],))
  for t_curr, t_prev in tqdm(list(zip(timesteps[:-1], timesteps[1:])), "Denoising"):
    t_vec = Tensor((t_curr,), device=img.device, dtype=img.dtype).expand((img.shape[0],))
    pred = run(model, img, img_ids, txt, txt_ids, t_vec, vec, guidance_vec)
    img = img + (t_prev - t_curr) * pred

  return img

def unpack(x:Tensor, height:int, width:int) -> Tensor:
  return x.rearrange("b (h w) (c ph pw) -> b c (h ph) (w pw)", h=math.ceil(height / 16), w=math.ceil(width / 16), ph=2, pw=2)
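
# NOTE: shape example (follows from the rearranges, values assumed): at 512x512 the latent is
# (1, 16, 64, 64), packed for the transformer as (1, 1024, 64), and unpacked back here before
# the autoencoder decode.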

# https://github.com/black-forest-labs/flux/blob/main/src/flux/cli.py
if __name__ == "__main__":
  default_prompt = "bananas and a can of coke"
  parser = argparse.ArgumentParser(description="Run Flux.1", formatter_class=argparse.ArgumentDefaultsHelpFormatter)

  parser.add_argument("--name", type=str, default="flux-schnell", help="Name of the model to load")
  parser.add_argument("--width", type=int, default=512, help="width of the sample in pixels (should be a multiple of 16)")
  parser.add_argument("--height", type=int, default=512, help="height of the sample in pixels (should be a multiple of 16)")
  parser.add_argument("--seed", type=int, default=None, help="Set a seed for sampling")
  parser.add_argument("--prompt", type=str, default=default_prompt, help="Prompt used for sampling")
  parser.add_argument('--out', type=str, default=Path(tempfile.gettempdir()) / "rendered.png", help="Output filename")
  parser.add_argument("--num_steps", type=int, default=None, help="number of sampling steps (default 4 for schnell, 50 for guidance distilled)") #noqa:E501
  parser.add_argument("--guidance", type=float, default=3.5, help="guidance value used for guidance distillation")
  parser.add_argument("--output_dir", type=str, default="output", help="output directory")
  args = parser.parse_args()

  if args.name not in ["flux-schnell", "flux-dev"]:
    raise ValueError(f"Got unknown model name: {args.name}, choose from flux-schnell and flux-dev")

  if args.num_steps is None:
    args.num_steps = 4 if args.name == "flux-schnell" else 50

  # allow for packing and conversion to latent space
  height = 16 * (args.height // 16)
  width = 16 * (args.width // 16)

  if args.seed is None: args.seed = Tensor._seed
  else: Tensor.manual_seed(args.seed)

  print(f"Generating with seed {args.seed}:\n{args.prompt}")
  t0 = time.perf_counter()

  # prepare input noise
  x = Tensor.randn(1, 16, 2 * math.ceil(height / 16), 2 * math.ceil(width / 16), dtype="bfloat16")

  # load text embedders
  T5 = load_T5(max_length=256 if args.name == "flux-schnell" else 512)
  clip = load_clip()

  # embed text to get inputs for model
  inp = prepare(T5, clip, x, prompt=args.prompt)
  timesteps = get_schedule(args.num_steps, inp["img"].shape[1], shift=(args.name != "flux-schnell"))

  # done with text embedders
  del T5, clip

  # load model
  model = load_flow_model(args.name)

  # denoise initial noise
  x = denoise(model, **inp, timesteps=timesteps, guidance=args.guidance)

  # done with model
  del model, run

  # load autoencoder
  ae = load_ae()

  # decode latents to pixel space
  x = unpack(x.float(), height, width)
  x = ae.decode(x).realize()

  t1 = time.perf_counter()
  print(f"Done in {t1 - t0:.1f}s. Saving {args.out}")

  # bring into PIL format and save
  x = x.clamp(-1, 1)
  x = x[0].rearrange("c h w -> h w c")
  x = (127.5 * (x + 1.0)).cast("uint8")

  img = Image.fromarray(x.numpy())

  img.save(args.out)

  # validation!
  if args.prompt == default_prompt and args.name == "flux-schnell" and args.seed == 0 and args.width == args.height == 512:
    ref_image = Tensor(np.array(Image.open("examples/flux1_seed0.png")))
    distance = (((x.cast(dtypes.float) - ref_image.cast(dtypes.float)) / ref_image.max())**2).mean().item()
    assert distance < 4e-3, colored(f"validation failed with {distance=}", "red")
    print(colored(f"output validated with {distance=}", "green"))
(new binary file not shown: validation reference image, 286 KiB)

@@ -0,0 +1,288 @@
# pip3 install sentencepiece

# adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py

# coding=utf-8
# Copyright 2018 Mesh TensorFlow authors, T5 Authors and HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tinygrad T5 model."""
from tinygrad import nn, Tensor, dtypes

import math
from dataclasses import dataclass
from typing import List, Union, Optional, Tuple
from pathlib import Path
from sentencepiece import SentencePieceProcessor

# default config is t5-xxl
@dataclass
class T5Config:
  d_ff:int = 10240
  d_kv:int = 64
  d_model:int = 4096
  layer_norm_epsilon:float = 1e-6
  num_decoder_layers:int = 24
  num_heads:int = 64
  num_layers:int = 24
  relative_attention_num_buckets:int = 32
  relative_attention_max_distance:int = 128
  vocab_size:int = 32128

class T5Tokenizer:
  def __init__(self, spiece_path):
    self.spp = SentencePieceProcessor(str(spiece_path))

  def encode(self, text:str, max_length:int) -> List[int]:
    encoded = self.spp.Encode(text)
    if len(encoded) > max_length - 1: encoded = encoded[:max_length - 1]
    return encoded + [1] + [0]*(max_length - len(encoded) - 1)
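
  # NOTE: illustrative only: encode("a photo of a cat", 16) yields the sentencepiece ids followed
  # by the EOS id 1 and zero padding up to the fixed length 16; inputs longer than max_length - 1
  # tokens are truncated so EOS always fits.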

class T5LayerNorm:
  def __init__(self, hidden_size:int, eps:float=1e-6):
    """
    Construct a layernorm module in the T5 style. No bias and no subtraction of mean.
    """
    self.weight = Tensor.ones(hidden_size)
    self.variance_epsilon = eps

  def __call__(self, hidden_states:Tensor) -> Tensor:
    # T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean
    # Square Layer Normalization https://arxiv.org/abs/1910.07467 thus variance is calculated
    # w/o mean and there is no bias. Additionally we want to make sure that the accumulation for
    # half-precision inputs is done in fp32

    variance = hidden_states.cast(dtypes.float32).pow(2).mean(-1, keepdim=True)
    hidden_states = hidden_states * Tensor.rsqrt(variance + self.variance_epsilon)

    # convert into half-precision if necessary
    if self.weight.dtype in [dtypes.float16, dtypes.bfloat16]:
      hidden_states = hidden_states.cast(self.weight.dtype)

    return self.weight * hidden_states


class T5DenseGatedActDense:
  def __init__(self, config:T5Config):
    self.wi_0 = nn.Linear(config.d_model, config.d_ff, bias=False)
    self.wi_1 = nn.Linear(config.d_model, config.d_ff, bias=False)
    self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)

  def __call__(self, hidden_states:Tensor) -> Tensor:
    hidden_gelu = self.wi_0(hidden_states).gelu()
    hidden_linear = self.wi_1(hidden_states)
    hidden_states = hidden_gelu * hidden_linear
    hidden_states = self.wo(hidden_states)
    return hidden_states


class T5LayerFF:
  def __init__(self, config:T5Config):
    self.DenseReluDense = T5DenseGatedActDense(config)
    self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)

  def __call__(self, hidden_states:Tensor) -> Tensor:
    forwarded_states = self.layer_norm(hidden_states)
    forwarded_states = self.DenseReluDense(forwarded_states)
    hidden_states = hidden_states + forwarded_states
    return hidden_states


class T5Attention:
  def __init__(self, config:T5Config, has_relative_attention_bias:bool=False):
    self.has_relative_attention_bias = has_relative_attention_bias
    self.relative_attention_num_buckets = config.relative_attention_num_buckets
    self.relative_attention_max_distance = config.relative_attention_max_distance
    self.d_model = config.d_model
    self.key_value_proj_dim = config.d_kv
    self.n_heads = config.num_heads
    self.inner_dim = self.n_heads * self.key_value_proj_dim

    # Mesh TensorFlow initialization to avoid scaling before softmax
    self.q = nn.Linear(self.d_model, self.inner_dim, bias=False)
    self.k = nn.Linear(self.d_model, self.inner_dim, bias=False)
    self.v = nn.Linear(self.d_model, self.inner_dim, bias=False)
    self.o = nn.Linear(self.inner_dim, self.d_model, bias=False)

    if self.has_relative_attention_bias:
      self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads)

  @staticmethod
  def _relative_position_bucket(relative_position:Tensor, num_buckets:int=32, max_distance:int=128) -> Tensor:
    """
    Adapted from Mesh Tensorflow:
    https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593

    Translate relative position to a bucket number for relative attention. The relative position is defined as
    memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
    position. We use smaller buckets for small absolute relative_position and larger buckets for larger absolute
    relative_positions. All relative positions >=max_distance map to the same bucket. All relative positions
    <=-max_distance map to the same bucket. This should allow for more graceful generalization to longer
    sequences than the model has been trained on.

    Args:
      relative_position: an int32 Tensor
      num_buckets: an integer
      max_distance: an integer

    Returns:
      a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
    """
    relative_buckets = Tensor.zeros_like(relative_position)
    num_buckets //= 2
    relative_buckets += (relative_position > 0).cast(dtypes.long) * num_buckets
    relative_position = Tensor.abs(relative_position)

    # half of the buckets are for exact increments in positions
    max_exact = num_buckets // 2
    is_small = relative_position < max_exact

    # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
    relative_position_if_large = max_exact + (
      Tensor.log(relative_position.float() / max_exact)
      / math.log(max_distance / max_exact)
      * (num_buckets - max_exact)
    ).cast(dtypes.long)

    relative_position_if_large = Tensor.min(
      Tensor.stack(
        relative_position_if_large, Tensor.full_like(relative_position_if_large, num_buckets - 1)
      ),
      axis=0,
    )
    relative_buckets += Tensor.where(is_small, relative_position, relative_position_if_large)
    return relative_buckets
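
  # NOTE: worked example (computed from the code above): with num_buckets=32 (halved to 16 per
  # direction) and max_distance=128, an offset of +5 lands in bucket 16 + 5 = 21 and -5 in
  # bucket 5; offsets at or beyond max_distance saturate at buckets 31 and 15 respectively.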

  def compute_bias(self, query_length, key_length, device=None) -> Tensor:
    """Compute binned relative position bias"""
    if device is None:
      device = self.relative_attention_bias.weight.device
    context_position = Tensor.arange(query_length, dtype=dtypes.long, device=device)[:, None]
    memory_position = Tensor.arange(key_length, dtype=dtypes.long, device=device)[None, :]
    relative_position = memory_position - context_position  # shape (query_length, key_length)
    relative_position_bucket = self._relative_position_bucket(
      relative_position,  # shape (query_length, key_length)
      num_buckets=self.relative_attention_num_buckets,
      max_distance=self.relative_attention_max_distance,
    )
    values = self.relative_attention_bias(relative_position_bucket)  # shape (query_length, key_length, num_heads)
    values = values.permute([2, 0, 1]).unsqueeze(0)  # shape (1, num_heads, query_length, key_length)
    return values

  def __call__(self, hidden_states:Tensor, position_bias:Optional[Tensor]=None) -> Tuple[Tensor,Tensor]:
    """
    Self-attention over hidden_states, reusing a previously computed position_bias when given.
    """
    # Input is (batch_size, seq_length, dim)
    batch_size, key_length = hidden_states.shape[:2]

    def shape(states):
      """projection"""
      return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)

    def unshape(states):
      """reshape"""
      return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim)

    def project(hidden_states, proj_layer):
      """projects hidden states correctly to key/query states"""
      # self-attn
      # (batch_size, n_heads, seq_length, dim_per_head)
      return shape(proj_layer(hidden_states))

    # get query states
    query_states = shape(self.q(hidden_states))  # (batch_size, n_heads, seq_length, dim_per_head)

    # get key/value states
    key_states = project(hidden_states, self.k)
    value_states = project(hidden_states, self.v)

    # compute scores
    scores = Tensor.matmul(query_states, key_states.transpose(3, 2))

    if position_bias is None:
      position_bias = self.compute_bias(key_length, key_length, device=scores.device)

    scores += position_bias
    attn_weights = Tensor.softmax(scores.float(), axis=-1).cast(scores.dtype)  # (batch_size, n_heads, seq_length, key_length)

    attn_output = unshape(Tensor.matmul(attn_weights, value_states))  # (batch_size, seq_length, dim)
    attn_output = self.o(attn_output)

    return attn_output, position_bias


class T5LayerSelfAttention:
  def __init__(self, config:T5Config, has_relative_attention_bias:bool=False):
    self.SelfAttention = T5Attention(config, has_relative_attention_bias=has_relative_attention_bias)
    self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)

  def __call__(self, hidden_states:Tensor, position_bias:Optional[Tensor]=None) -> Tuple[Tensor, Tensor]:
    normed_hidden_states = self.layer_norm(hidden_states)
    attention_output, position_bias = self.SelfAttention(normed_hidden_states, position_bias=position_bias)
    return hidden_states + attention_output, position_bias


class T5Block:
  def __init__(self, config:T5Config, has_relative_attention_bias:bool=False):
    self.layer = (T5LayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias),
                  T5LayerFF(config))

  def __call__(self, hidden_states:Tensor, position_bias:Optional[Tensor]=None) -> Tuple[Tensor, Tensor]:
    self_attention_outputs, position_bias = self.layer[0](hidden_states, position_bias=position_bias)
    hidden_states = self_attention_outputs

    # Apply Feed Forward layer
    hidden_states = self.layer[-1](hidden_states)

    return hidden_states, position_bias


class T5Stack:
  def __init__(self, config:T5Config, embed_tokens:nn.Embedding):
    self.config = config
    self.embed_tokens = embed_tokens
    self.block = [T5Block(config, has_relative_attention_bias=bool(i == 0)) for i in range(config.num_layers)]
    self.final_layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)

  def __call__(self, input_ids:Tensor) -> Tensor:
    input_ids = input_ids.view(-1, input_ids.shape[-1])

    hidden_states, position_bias = self.embed_tokens(input_ids), None

    for layer_module in self.block:
      hidden_states, position_bias = layer_module(hidden_states, position_bias=position_bias)

    return self.final_layer_norm(hidden_states)


class T5EncoderModel:
  def __init__(self, config:T5Config):
    self.shared = nn.Embedding(config.vocab_size, config.d_model)
    self.encoder = T5Stack(config, self.shared)

  def __call__(self, input_ids:Tensor) -> Tensor:
    return self.encoder(input_ids)

class T5Embedder:
  def __init__(self, max_length:int, spiece_path:Union[str, Path]):
    self.tokenizer = T5Tokenizer(spiece_path)
    self.max_length = max_length
    config = T5Config()
    self.encoder = T5EncoderModel(config)

  def __call__(self, texts:Union[str, List[str]]) -> Tensor:
    if isinstance(texts, str): texts = [texts]
    toks = Tensor.cat(*[Tensor(self.tokenizer.encode(text, self.max_length)) for text in texts], dim=0)
    return self.encoder(toks)