* initial commit

* whitespace

* get rid of torch import

* indentation

* less hardcoding

* add flux.1-dev

* jit

* no double

* t5 tidy up

* validation image

* reuse sdxl autoencoder

* typing changes

* empty lines

* remove unneeded comments

---------

Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
samm393 2024-09-24 03:08:04 +01:00, committed by GitHub
parent 31b9c74c77
commit 19c11792fd
3 changed files with 784 additions and 0 deletions

examples/flux1.py (new file, 496 lines)

@@ -0,0 +1,496 @@
# pip3 install sentencepiece
# This file incorporates code from the following:
# Github Name            | License | Link
# black-forest-labs/flux | Apache  | https://github.com/black-forest-labs/flux/tree/main/model_licenses
from tinygrad import Tensor, nn, dtypes, TinyJit
from tinygrad.nn.state import safe_load, load_state_dict
from tinygrad.helpers import fetch, tqdm, colored
from sdxl import FirstStage
from extra.models.clip import FrozenClosedClipEmbedder
from extra.models.t5 import T5Embedder
import numpy as np
import math, time, argparse, tempfile
from typing import List, Dict, Optional, Union, Tuple, Callable
from dataclasses import dataclass
from pathlib import Path
from PIL import Image

urls:dict = {
  "flux-schnell": "https://huggingface.co/black-forest-labs/FLUX.1-schnell/resolve/main/flux1-schnell.safetensors",
  "flux-dev": "https://huggingface.co/camenduru/FLUX.1-dev/resolve/main/flux1-dev.sft",
  "ae": "https://huggingface.co/black-forest-labs/FLUX.1-schnell/resolve/main/ae.safetensors",
  "T5_1_of_2": "https://huggingface.co/black-forest-labs/FLUX.1-schnell/resolve/main/text_encoder_2/model-00001-of-00002.safetensors",
  "T5_2_of_2": "https://huggingface.co/black-forest-labs/FLUX.1-schnell/resolve/main/text_encoder_2/model-00002-of-00002.safetensors",
  "T5_tokenizer": "https://huggingface.co/black-forest-labs/FLUX.1-schnell/resolve/main/tokenizer_2/spiece.model",
  "clip": "https://huggingface.co/black-forest-labs/FLUX.1-schnell/resolve/main/text_encoder/model.safetensors"
}

def tensor_identity(x:Tensor) -> Tensor: return x

class AutoEncoder:
  def __init__(self, scale_factor:float, shift_factor:float):
    self.decoder = FirstStage.Decoder(128, 3, 3, 16, [1, 2, 4, 4], 2, 256)
    self.scale_factor = scale_factor
    self.shift_factor = shift_factor

  def decode(self, z:Tensor) -> Tensor:
    z = z / self.scale_factor + self.shift_factor
    return self.decoder(z)

# Conditioner
class ClipEmbedder(FrozenClosedClipEmbedder):
  def __call__(self, texts:Union[str, List[str], Tensor]) -> Tensor:
    if isinstance(texts, str): texts = [texts]
    assert isinstance(texts, (list,tuple)), f"expected list of strings, got {type(texts).__name__}"
    tokens = Tensor.cat(*[Tensor(self.tokenizer.encode(text)) for text in texts], dim=0)
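    # pooled output: index the hidden states at the prompt's EOT token, found via argmax since EOT has the highest token id in CLIP's vocabulary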
    return self.transformer.text_model(tokens.reshape(len(texts),-1))[:, tokens.argmax(-1)]

# https://github.com/black-forest-labs/flux/blob/main/src/flux/math.py
def attention(q:Tensor, k:Tensor, v:Tensor, pe:Tensor) -> Tensor:
  q, k = apply_rope(q, k, pe)
  x = Tensor.scaled_dot_product_attention(q, k, v)
  return x.rearrange("B H L D -> B L (H D)")

def rope(pos:Tensor, dim:int, theta:int) -> Tensor:
  assert dim % 2 == 0
  scale = Tensor.arange(0, dim, 2, dtype=dtypes.float32, device=pos.device) / dim  # NOTE: this is torch.float64 in reference implementation
  omega = 1.0 / (theta**scale)
  out = Tensor.einsum("...n,d->...nd", pos, omega)
  out = Tensor.stack(Tensor.cos(out), -Tensor.sin(out), Tensor.sin(out), Tensor.cos(out), dim=-1)
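  # the four stacked entries form a 2x2 rotation matrix [[cos, -sin], [sin, cos]] for every (position, frequency) pair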
  out = out.rearrange("b n d (i j) -> b n d i j", i=2, j=2)
  return out.float()

def apply_rope(xq:Tensor, xk:Tensor, freqs_cis:Tensor) -> Tuple[Tensor, Tensor]:
  xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
  xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
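  # view the last dim as pairs of 2 and rotate each pair by the matching 2x2 matrix in freqs_cis (complex multiplication written out in reals)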
  xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
  xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
  return xq_out.reshape(*xq.shape).cast(xq.dtype), xk_out.reshape(*xk.shape).cast(xk.dtype)

# https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py
class EmbedND:
  def __init__(self, dim:int, theta:int, axes_dim:List[int]):
    self.dim = dim
    self.theta = theta
    self.axes_dim = axes_dim

  def __call__(self, ids:Tensor) -> Tensor:
    n_axes = ids.shape[-1]
    emb = Tensor.cat(*[rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], dim=-3)
    return emb.unsqueeze(1)

class MLPEmbedder:
  def __init__(self, in_dim:int, hidden_dim:int):
    self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
    self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)

  def __call__(self, x:Tensor) -> Tensor:
    return self.out_layer(self.in_layer(x).silu())

class QKNorm:
  def __init__(self, dim:int):
    self.query_norm = nn.RMSNorm(dim)
    self.key_norm = nn.RMSNorm(dim)

  def __call__(self, q:Tensor, k:Tensor) -> Tuple[Tensor, Tensor]:
    return self.query_norm(q), self.key_norm(k)

class SelfAttention:
  def __init__(self, dim:int, num_heads:int = 8, qkv_bias:bool = False):
    self.num_heads = num_heads
    head_dim = dim // num_heads
    self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
    self.norm = QKNorm(head_dim)
    self.proj = nn.Linear(dim, dim)

  def __call__(self, x:Tensor, pe:Tensor) -> Tensor:
    qkv = self.qkv(x)
    q, k, v = qkv.rearrange("B L (K H D) -> K B H L D", K=3, H=self.num_heads)
    q, k = self.norm(q, k)
    x = attention(q, k, v, pe=pe)
    return self.proj(x)

@dataclass
class ModulationOut:
  shift:Tensor
  scale:Tensor
  gate:Tensor

class Modulation:
  def __init__(self, dim:int, double:bool):
    self.is_double = double
    self.multiplier = 6 if double else 3
    self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)

  def __call__(self, vec:Tensor) -> Tuple[ModulationOut, Optional[ModulationOut]]:
    out = self.lin(vec.silu())[:, None, :].chunk(self.multiplier, dim=-1)
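    # a double Modulation yields two (shift, scale, gate) triples: one modulates the attention branch, one the MLP branch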
    return ModulationOut(*out[:3]), ModulationOut(*out[3:]) if self.is_double else None

class DoubleStreamBlock:
  def __init__(self, hidden_size:int, num_heads:int, mlp_ratio:float, qkv_bias:bool = False):
    mlp_hidden_dim = int(hidden_size * mlp_ratio)
    self.num_heads = num_heads
    self.hidden_size = hidden_size
    self.img_mod = Modulation(hidden_size, double=True)
    self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
    self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
    self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
    self.img_mlp = [nn.Linear(hidden_size, mlp_hidden_dim, bias=True), Tensor.gelu, nn.Linear(mlp_hidden_dim, hidden_size, bias=True)]
    self.txt_mod = Modulation(hidden_size, double=True)
    self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
    self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
    self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
    self.txt_mlp = [nn.Linear(hidden_size, mlp_hidden_dim, bias=True), Tensor.gelu, nn.Linear(mlp_hidden_dim, hidden_size, bias=True)]

  def __call__(self, img:Tensor, txt:Tensor, vec:Tensor, pe:Tensor) -> Tuple[Tensor, Tensor]:
    img_mod1, img_mod2 = self.img_mod(vec)
    txt_mod1, txt_mod2 = self.txt_mod(vec)
    assert img_mod2 is not None and txt_mod2 is not None

    # prepare image for attention
    img_modulated = self.img_norm1(img)
    img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
    img_qkv = self.img_attn.qkv(img_modulated)
    img_q, img_k, img_v = img_qkv.rearrange("B L (K H D) -> K B H L D", K=3, H=self.num_heads)
    img_q, img_k = self.img_attn.norm(img_q, img_k)

    # prepare txt for attention
    txt_modulated = self.txt_norm1(txt)
    txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
    txt_qkv = self.txt_attn.qkv(txt_modulated)
    txt_q, txt_k, txt_v = txt_qkv.rearrange("B L (K H D) -> K B H L D", K=3, H=self.num_heads)
    txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k)

    # run actual attention
    q = Tensor.cat(txt_q, img_q, dim=2)
    k = Tensor.cat(txt_k, img_k, dim=2)
    v = Tensor.cat(txt_v, img_v, dim=2)
    attn = attention(q, k, v, pe=pe)
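    # the joint sequence is ordered [txt, img]; split the attention output back into the two streams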
    txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]

    # calculate the img blocks
    img = img + img_mod1.gate * self.img_attn.proj(img_attn)
    img = img + img_mod2.gate * ((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift).sequential(self.img_mlp)

    # calculate the txt blocks
    txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
    txt = txt + txt_mod2.gate * ((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift).sequential(self.txt_mlp)
    return img, txt

class SingleStreamBlock:
  """
  A DiT block with parallel linear layers as described in
  https://arxiv.org/abs/2302.05442 and adapted modulation interface.
  """
  def __init__(self, hidden_size:int, num_heads:int, mlp_ratio:float=4.0, qk_scale:Optional[float]=None):
    self.hidden_dim = hidden_size
    self.num_heads = num_heads
    head_dim = hidden_size // num_heads
    self.scale = qk_scale or head_dim**-0.5
    self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
    # qkv and mlp_in
    self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
    # proj and mlp_out
    self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)
    self.norm = QKNorm(head_dim)
    self.hidden_size = hidden_size
    self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
    self.mlp_act = Tensor.gelu
    self.modulation = Modulation(hidden_size, double=False)

  def __call__(self, x:Tensor, vec:Tensor, pe:Tensor) -> Tensor:
    mod, _ = self.modulation(vec)
    x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
    qkv, mlp = Tensor.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
    q, k, v = qkv.rearrange("B L (K H D) -> K B H L D", K=3, H=self.num_heads)
    q, k = self.norm(q, k)

    # compute attention
    attn = attention(q, k, v, pe=pe)
    # compute activation in mlp stream, cat again and run second linear layer
    output = self.linear2(Tensor.cat(attn, self.mlp_act(mlp), dim=2))
    return x + mod.gate * output

class LastLayer:
  def __init__(self, hidden_size:int, patch_size:int, out_channels:int):
    self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
    self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
    self.adaLN_modulation:List[Callable[[Tensor], Tensor]] = [Tensor.silu, nn.Linear(hidden_size, 2 * hidden_size, bias=True)]

  def __call__(self, x:Tensor, vec:Tensor) -> Tensor:
    shift, scale = vec.sequential(self.adaLN_modulation).chunk(2, dim=1)
    x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
    return self.linear(x)

def timestep_embedding(t:Tensor, dim:int, max_period:int=10000, time_factor:float=1000.0) -> Tensor:
  """
  Create sinusoidal timestep embeddings.
  :param t: a 1-D Tensor of N indices, one per batch element.
            These may be fractional.
  :param dim: the dimension of the output.
  :param max_period: controls the minimum frequency of the embeddings.
  :return: an (N, D) Tensor of positional embeddings.
  """
  t = time_factor * t
  half = dim // 2
  freqs = Tensor.exp(-math.log(max_period) * Tensor.arange(0, stop=half, dtype=dtypes.float32) / half).to(t.device)
  args = t[:, None].float() * freqs[None]
  embedding = Tensor.cat(Tensor.cos(args), Tensor.sin(args), dim=-1)
  if dim % 2: embedding = Tensor.cat(*[embedding, Tensor.zeros_like(embedding[:, :1])], dim=-1)
  if Tensor.is_floating_point(t): embedding = embedding.cast(t.dtype)
  return embedding

# https://github.com/black-forest-labs/flux/blob/main/src/flux/model.py
class Flux:
  """
  Transformer model for flow matching on sequences.
  """
  def __init__(
    self,
    guidance_embed:bool,
    in_channels:int = 64,
    vec_in_dim:int = 768,
    context_in_dim:int = 4096,
    hidden_size:int = 3072,
    mlp_ratio:float = 4.0,
    num_heads:int = 24,
    depth:int = 19,
    depth_single_blocks:int = 38,
    axes_dim:Optional[List[int]] = None,
    theta:int = 10_000,
    qkv_bias:bool = True,
  ):
    axes_dim = axes_dim or [16, 56, 56]
    self.guidance_embed = guidance_embed
    self.in_channels = in_channels
    self.out_channels = self.in_channels
    if hidden_size % num_heads != 0:
      raise ValueError(f"Hidden size {hidden_size} must be divisible by num_heads {num_heads}")
    pe_dim = hidden_size // num_heads
    if sum(axes_dim) != pe_dim:
      raise ValueError(f"Got {axes_dim} but expected positional dim {pe_dim}")
    self.hidden_size = hidden_size
    self.num_heads = num_heads
    self.pe_embedder = EmbedND(dim=pe_dim, theta=theta, axes_dim=axes_dim)
    self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
    self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
    self.vector_in = MLPEmbedder(vec_in_dim, self.hidden_size)
    self.guidance_in:Callable[[Tensor], Tensor] = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if guidance_embed else tensor_identity
    self.txt_in = nn.Linear(context_in_dim, self.hidden_size)
    self.double_blocks = [DoubleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias) for _ in range(depth)]
    self.single_blocks = [SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=mlp_ratio) for _ in range(depth_single_blocks)]
    self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)

  def __call__(self, img:Tensor, img_ids:Tensor, txt:Tensor, txt_ids:Tensor, timesteps:Tensor, y:Tensor, guidance:Optional[Tensor] = None) -> Tensor:
    if img.ndim != 3 or txt.ndim != 3:
      raise ValueError("Input img and txt tensors must have 3 dimensions.")

    # running on sequences img
    img = self.img_in(img)
    vec = self.time_in(timestep_embedding(timesteps, 256))
    if self.guidance_embed:
      if guidance is None:
        raise ValueError("Didn't get guidance strength for guidance distilled model.")
      vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
    vec = vec + self.vector_in(y)
    txt = self.txt_in(txt)

    ids = Tensor.cat(txt_ids, img_ids, dim=1)
    pe = self.pe_embedder(ids)

    for double_block in self.double_blocks:
      img, txt = double_block(img=img, txt=txt, vec=vec, pe=pe)

    img = Tensor.cat(txt, img, dim=1)
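    # after the double-stream blocks, txt and img tokens are fused into a single sequence for the single-stream blocks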
    for single_block in self.single_blocks:
      img = single_block(img, vec=vec, pe=pe)
    img = img[:, txt.shape[1]:, ...]
    return self.final_layer(img, vec)  # (N, T, patch_size ** 2 * out_channels)

# https://github.com/black-forest-labs/flux/blob/main/src/flux/util.py
def load_flow_model(name:str):
  # Loading Flux
  print("Init model")
  model = Flux(guidance_embed=(name != "flux-schnell"))
  state_dict = {k.replace("scale", "weight"): v for k, v in safe_load(fetch(urls[name])).items()}
  load_state_dict(model, state_dict)
  return model

def load_T5(max_length:int=512):
  # max length 64, 128, 256 and 512 should work (if your sequence is short enough)
  print("Init T5")
  T5 = T5Embedder(max_length, fetch(urls["T5_tokenizer"]))
  pt_1 = fetch(urls["T5_1_of_2"])
  pt_2 = fetch(urls["T5_2_of_2"])
  load_state_dict(T5.encoder, safe_load(pt_1) | safe_load(pt_2), strict=False)
  return T5

def load_clip():
  print("Init Clip")
  clip = ClipEmbedder()
  load_state_dict(clip.transformer, safe_load(fetch(urls["clip"])))
  return clip

def load_ae() -> AutoEncoder:
  # Loading the autoencoder
  print("Init AE")
  ae = AutoEncoder(0.3611, 0.1159)
  load_state_dict(ae, safe_load(fetch(urls["ae"])))
  return ae

# https://github.com/black-forest-labs/flux/blob/main/src/flux/sampling.py
def prepare(T5:T5Embedder, clip:ClipEmbedder, img:Tensor, prompt:Union[str, List[str]]) -> Dict[str, Tensor]:
  bs, _, h, w = img.shape
  if bs == 1 and not isinstance(prompt, str):
    bs = len(prompt)

  img = img.rearrange("b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
  if img.shape[0] == 1 and bs > 1:
    img = img.expand((bs, *img.shape[1:]))

  img_ids = Tensor.zeros(h // 2, w // 2, 3).contiguous()
  img_ids[..., 1] = img_ids[..., 1] + Tensor.arange(h // 2)[:, None]
  img_ids[..., 2] = img_ids[..., 2] + Tensor.arange(w // 2)[None, :]
  img_ids = img_ids.rearrange("h w c -> 1 (h w) c")
  img_ids = img_ids.expand((bs, *img_ids.shape[1:]))

  if isinstance(prompt, str):
    prompt = [prompt]
  txt = T5(prompt).realize()
  if txt.shape[0] == 1 and bs > 1:
    txt = txt.expand((bs, *txt.shape[1:]))
  txt_ids = Tensor.zeros(bs, txt.shape[1], 3)

  vec = clip(prompt).realize()
  if vec.shape[0] == 1 and bs > 1:
    vec = vec.expand((bs, *vec.shape[1:]))

  return {"img": img, "img_ids": img_ids.to(img.device), "txt": txt.to(img.device), "txt_ids": txt_ids.to(img.device), "vec": vec.to(img.device)}

def get_schedule(num_steps:int, image_seq_len:int, base_shift:float=0.5, max_shift:float=1.15, shift:bool=True) -> List[float]:
  # extra step for zero
  step_size = -1.0 / num_steps
  timesteps = Tensor.arange(1, 0 + step_size, step_size)

  # shifting the schedule to favor high timesteps for higher signal images
  if shift:
    # estimate mu based on linear estimation between two points
    mu = 0.5 + (max_shift - base_shift) * (image_seq_len - 256) / (4096 - 256)
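    # time_shift: remap t -> exp(mu) / (exp(mu) + (1/t - 1)); larger mu pulls the schedule toward t=1 (more steps at high noise) for longer image sequences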
    timesteps = math.exp(mu) / (math.exp(mu) + (1 / timesteps - 1))

  return timesteps.tolist()

@TinyJit
def run(model, *args): return model(*args).realize()

def denoise(model, img:Tensor, img_ids:Tensor, txt:Tensor, txt_ids:Tensor, vec:Tensor, timesteps:List[float], guidance:float=4.0) -> Tensor:
  # this is ignored for schnell
  guidance_vec = Tensor((guidance,), device=img.device, dtype=img.dtype).expand((img.shape[0],))
  for t_curr, t_prev in tqdm(list(zip(timesteps[:-1], timesteps[1:])), "Denoising"):
    t_vec = Tensor((t_curr,), device=img.device, dtype=img.dtype).expand((img.shape[0],))
    pred = run(model, img, img_ids, txt, txt_ids, t_vec, vec, guidance_vec)
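    # Euler step of the rectified-flow ODE: step along the predicted velocity from t_curr to t_prev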
    img = img + (t_prev - t_curr) * pred
  return img

def unpack(x:Tensor, height:int, width:int) -> Tensor:
  return x.rearrange("b (h w) (c ph pw) -> b c (h ph) (w pw)", h=math.ceil(height / 16), w=math.ceil(width / 16), ph=2, pw=2)

# https://github.com/black-forest-labs/flux/blob/main/src/flux/cli.py
if __name__ == "__main__":
  default_prompt = "bananas and a can of coke"
  parser = argparse.ArgumentParser(description="Run Flux.1", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
  parser.add_argument("--name", type=str, default="flux-schnell", help="Name of the model to load")
  parser.add_argument("--width", type=int, default=512, help="width of the sample in pixels (should be a multiple of 16)")
  parser.add_argument("--height", type=int, default=512, help="height of the sample in pixels (should be a multiple of 16)")
  parser.add_argument("--seed", type=int, default=None, help="Set a seed for sampling")
  parser.add_argument("--prompt", type=str, default=default_prompt, help="Prompt used for sampling")
  parser.add_argument('--out', type=str, default=Path(tempfile.gettempdir()) / "rendered.png", help="Output filename")
  parser.add_argument("--num_steps", type=int, default=None, help="number of sampling steps (default 4 for schnell, 50 for guidance distilled)")  # noqa:E501
  parser.add_argument("--guidance", type=float, default=3.5, help="guidance value used for guidance distillation")
  parser.add_argument("--output_dir", type=str, default="output", help="output directory")
  args = parser.parse_args()

  if args.name not in ["flux-schnell", "flux-dev"]:
    raise ValueError(f"Got unknown model name: {args.name}, choose from flux-schnell and flux-dev")
  if args.num_steps is None:
    args.num_steps = 4 if args.name == "flux-schnell" else 50

  # allow for packing and conversion to latent space
  height = 16 * (args.height // 16)
  width = 16 * (args.width // 16)

  if args.seed is None: args.seed = Tensor._seed
  else: Tensor.manual_seed(args.seed)
  print(f"Generating with seed {args.seed}:\n{args.prompt}")
  t0 = time.perf_counter()

  # prepare input noise
  x = Tensor.randn(1, 16, 2 * math.ceil(height / 16), 2 * math.ceil(width / 16), dtype="bfloat16")

  # load text embedders
  T5 = load_T5(max_length=256 if args.name == "flux-schnell" else 512)
  clip = load_clip()

  # embed text to get inputs for model
  inp = prepare(T5, clip, x, prompt=args.prompt)
  timesteps = get_schedule(args.num_steps, inp["img"].shape[1], shift=(args.name != "flux-schnell"))

  # done with text embedders
  del T5, clip

  # load model
  model = load_flow_model(args.name)

  # denoise initial noise
  x = denoise(model, **inp, timesteps=timesteps, guidance=args.guidance)

  # done with model
  del model, run

  # load autoencoder
  ae = load_ae()

  # decode latents to pixel space
  x = unpack(x.float(), height, width)
  x = ae.decode(x).realize()
  t1 = time.perf_counter()
  print(f"Done in {t1 - t0:.1f}s. Saving {args.out}")

  # bring into PIL format and save
  x = x.clamp(-1, 1)
  x = x[0].rearrange("c h w -> h w c")
  x = (127.5 * (x + 1.0)).cast("uint8")

  img = Image.fromarray(x.numpy())
  img.save(args.out)

  # validation!
  if args.prompt == default_prompt and args.name == "flux-schnell" and args.seed == 0 and args.width == args.height == 512:
    ref_image = Tensor(np.array(Image.open("examples/flux1_seed0.png")))
    distance = (((x.cast(dtypes.float) - ref_image.cast(dtypes.float)) / ref_image.max())**2).mean().item()
    assert distance < 4e-3, colored(f"validation failed with {distance=}", "red")
    print(colored(f"output validated with {distance=}", "green"))

examples/flux1_seed0.png (new binary file, 286 KiB, not shown)

extra/models/t5.py (new file, 288 lines)

@@ -0,0 +1,288 @@
# pip3 install sentencepiece
# adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py
# coding=utf-8
# Copyright 2018 Mesh TensorFlow authors, T5 Authors and HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tinygrad T5 model."""
from tinygrad import nn, Tensor, dtypes
import math
from dataclasses import dataclass
from typing import List, Union, Optional, Tuple
from pathlib import Path
from sentencepiece import SentencePieceProcessor

# default config is t5-xxl
@dataclass
class T5Config:
  d_ff:int = 10240
  d_kv:int = 64
  d_model:int = 4096
  layer_norm_epsilon:float = 1e-6
  num_decoder_layers:int = 24
  num_heads:int = 64
  num_layers:int = 24
  relative_attention_num_buckets:int = 32
  relative_attention_max_distance:int = 128
  vocab_size:int = 32128

class T5Tokenizer:
  def __init__(self, spiece_path):
    self.spp = SentencePieceProcessor(str(spiece_path))

  def encode(self, text:str, max_length:int) -> List[int]:
    encoded = self.spp.Encode(text)
    if len(encoded) > max_length - 1: encoded = encoded[:max_length - 1]
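    # append the EOS token (id 1), then right-pad with the pad token (id 0) up to max_length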
    return encoded + [1] + [0]*(max_length - len(encoded) - 1)

class T5LayerNorm:
  def __init__(self, hidden_size:int, eps:float=1e-6):
    """
    Construct a layernorm module in the T5 style. No bias and no subtraction of mean.
    """
    self.weight = Tensor.ones(hidden_size)
    self.variance_epsilon = eps

  def __call__(self, hidden_states:Tensor) -> Tensor:
    # T5 uses a layer_norm which only scales and doesn't shift, also known as Root Mean Square
    # Layer Normalization (https://arxiv.org/abs/1910.07467), so the variance is calculated
    # without the mean and there is no bias. Additionally, we want to make sure that the
    # accumulation for half-precision inputs is done in fp32.
    variance = hidden_states.cast(dtypes.float32).pow(2).mean(-1, keepdim=True)
    hidden_states = hidden_states * Tensor.rsqrt(variance + self.variance_epsilon)

    # convert into half-precision if necessary
    if self.weight.dtype in [dtypes.float16, dtypes.bfloat16]:
      hidden_states = hidden_states.cast(self.weight.dtype)

    return self.weight * hidden_states

class T5DenseGatedActDense:
  def __init__(self, config:T5Config):
    self.wi_0 = nn.Linear(config.d_model, config.d_ff, bias=False)
    self.wi_1 = nn.Linear(config.d_model, config.d_ff, bias=False)
    self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)

  def __call__(self, hidden_states:Tensor) -> Tensor:
    hidden_gelu = self.wi_0(hidden_states).gelu()
    hidden_linear = self.wi_1(hidden_states)
    hidden_states = hidden_gelu * hidden_linear
    hidden_states = self.wo(hidden_states)
    return hidden_states

class T5LayerFF:
  def __init__(self, config:T5Config):
    self.DenseReluDense = T5DenseGatedActDense(config)
    self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)

  def __call__(self, hidden_states:Tensor) -> Tensor:
    forwarded_states = self.layer_norm(hidden_states)
    forwarded_states = self.DenseReluDense(forwarded_states)
    hidden_states = hidden_states + forwarded_states
    return hidden_states

class T5Attention:
  def __init__(self, config:T5Config, has_relative_attention_bias:bool=False):
    self.has_relative_attention_bias = has_relative_attention_bias
    self.relative_attention_num_buckets = config.relative_attention_num_buckets
    self.relative_attention_max_distance = config.relative_attention_max_distance
    self.d_model = config.d_model
    self.key_value_proj_dim = config.d_kv
    self.n_heads = config.num_heads
    self.inner_dim = self.n_heads * self.key_value_proj_dim

    # Mesh TensorFlow initialization to avoid scaling before softmax
    self.q = nn.Linear(self.d_model, self.inner_dim, bias=False)
    self.k = nn.Linear(self.d_model, self.inner_dim, bias=False)
    self.v = nn.Linear(self.d_model, self.inner_dim, bias=False)
    self.o = nn.Linear(self.inner_dim, self.d_model, bias=False)

    if self.has_relative_attention_bias:
      self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads)

  @staticmethod
  def _relative_position_bucket(relative_position:Tensor, num_buckets:int=32, max_distance:int=128) -> Tensor:
    """
    Adapted from Mesh Tensorflow:
    https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593

    Translate relative position to a bucket number for relative attention. The relative position is defined as
    memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
    position. This encoder-only adaptation is always bidirectional. We use smaller buckets for
    small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
    positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
    This should allow for more graceful generalization to longer sequences than the model has been trained on.

    Args:
      relative_position: an int32 Tensor
      num_buckets: an integer
      max_distance: an integer

    Returns:
      a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
    """
    relative_buckets = Tensor.zeros_like(relative_position)
    num_buckets //= 2
    relative_buckets += (relative_position > 0).cast(dtypes.long) * num_buckets
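    # bidirectional bucketing: positive relative positions (keys after the query) take the upper half of the buckets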
    relative_position = Tensor.abs(relative_position)

    # half of the buckets are for exact increments in positions
    max_exact = num_buckets // 2
    is_small = relative_position < max_exact

    # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
    relative_position_if_large = max_exact + (
      Tensor.log(relative_position.float() / max_exact)
      / math.log(max_distance / max_exact)
      * (num_buckets - max_exact)
    ).cast(dtypes.long)
    relative_position_if_large = Tensor.min(
      Tensor.stack(relative_position_if_large, Tensor.full_like(relative_position_if_large, num_buckets - 1)),
      axis=0,
    )

    relative_buckets += Tensor.where(is_small, relative_position, relative_position_if_large)
    return relative_buckets

  def compute_bias(self, query_length, key_length, device=None) -> Tensor:
    """Compute binned relative position bias"""
    if device is None:
      device = self.relative_attention_bias.weight.device
    context_position = Tensor.arange(query_length, dtype=dtypes.long, device=device)[:, None]
    memory_position = Tensor.arange(key_length, dtype=dtypes.long, device=device)[None, :]
    relative_position = memory_position - context_position  # shape (query_length, key_length)
    relative_position_bucket = self._relative_position_bucket(
      relative_position,  # shape (query_length, key_length)
      num_buckets=self.relative_attention_num_buckets,
      max_distance=self.relative_attention_max_distance,
    )
    values = self.relative_attention_bias(relative_position_bucket)  # shape (query_length, key_length, num_heads)
    values = values.permute([2, 0, 1]).unsqueeze(0)  # shape (1, num_heads, query_length, key_length)
    return values

  def __call__(self, hidden_states:Tensor, position_bias:Optional[Tensor]=None) -> Tuple[Tensor, Tensor]:
    """
    Self-attention over the input hidden states.
    """
    # Input is (batch_size, seq_length, dim)
    batch_size, key_length = hidden_states.shape[:2]

    def shape(states):
      """projection"""
      return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)

    def unshape(states):
      """reshape"""
      return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim)

    def project(hidden_states, proj_layer):
      """projects hidden states correctly to key/query states"""
      # self-attn
      # (batch_size, n_heads, seq_length, dim_per_head)
      return shape(proj_layer(hidden_states))

    # get query states
    query_states = shape(self.q(hidden_states))  # (batch_size, n_heads, seq_length, dim_per_head)

    # get key/value states
    key_states = project(hidden_states, self.k)
    value_states = project(hidden_states, self.v)

    # compute scores
    scores = Tensor.matmul(query_states, key_states.transpose(3, 2))
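    # NOTE: no 1/sqrt(d_k) scaling of the scores; T5 folds it into the weight initialization (the Mesh TensorFlow note above)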
    if position_bias is None:
      position_bias = self.compute_bias(key_length, key_length, device=scores.device)
    scores += position_bias

    attn_weights = Tensor.softmax(scores.float(), axis=-1).cast(scores.dtype)  # (batch_size, n_heads, seq_length, key_length)
    attn_output = unshape(Tensor.matmul(attn_weights, value_states))  # (batch_size, seq_length, dim)
    attn_output = self.o(attn_output)
    return attn_output, position_bias

class T5LayerSelfAttention:
  def __init__(self, config:T5Config, has_relative_attention_bias:bool=False):
    self.SelfAttention = T5Attention(config, has_relative_attention_bias=has_relative_attention_bias)
    self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)

  def __call__(self, hidden_states:Tensor, position_bias:Optional[Tensor]=None) -> Tuple[Tensor, Tensor]:
    normed_hidden_states = self.layer_norm(hidden_states)
    attention_output, position_bias = self.SelfAttention(normed_hidden_states, position_bias=position_bias)
    return hidden_states + attention_output, position_bias

class T5Block:
  def __init__(self, config:T5Config, has_relative_attention_bias:bool=False):
    self.layer = (T5LayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias),
                  T5LayerFF(config))

  def __call__(self, hidden_states:Tensor, position_bias:Optional[Tensor]=None) -> Tuple[Tensor, Tensor]:
    self_attention_outputs, position_bias = self.layer[0](hidden_states, position_bias=position_bias)
    hidden_states = self_attention_outputs

    # Apply Feed Forward layer
    hidden_states = self.layer[-1](hidden_states)
    return hidden_states, position_bias

class T5Stack:
  def __init__(self, config:T5Config, embed_tokens:nn.Embedding):
    self.config = config
    self.embed_tokens = embed_tokens
    self.block = [T5Block(config, has_relative_attention_bias=bool(i == 0)) for i in range(config.num_layers)]
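    # only the first block learns a relative attention bias; its returned position_bias is reused by all later blocks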
    self.final_layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)

  def __call__(self, input_ids:Tensor) -> Tensor:
    input_ids = input_ids.view(-1, input_ids.shape[-1])
    hidden_states, position_bias = self.embed_tokens(input_ids), None
    for layer_module in self.block:
      hidden_states, position_bias = layer_module(hidden_states, position_bias=position_bias)
    return self.final_layer_norm(hidden_states)

class T5EncoderModel:
  def __init__(self, config:T5Config):
    self.shared = nn.Embedding(config.vocab_size, config.d_model)
    self.encoder = T5Stack(config, self.shared)

  def __call__(self, input_ids:Tensor) -> Tensor:
    return self.encoder(input_ids)

class T5Embedder:
  def __init__(self, max_length:int, spiece_path:Union[str, Path]):
    self.tokenizer = T5Tokenizer(spiece_path)
    self.max_length = max_length
    config = T5Config()
    self.encoder = T5EncoderModel(config)

  def __call__(self, texts:Union[str, List[str]]) -> Tensor:
    if isinstance(texts, str): texts = [texts]
    toks = Tensor.cat(*[Tensor(self.tokenizer.encode(text, self.max_length)) for text in texts], dim=0)
    return self.encoder(toks)