diff --git a/examples/flux1.py b/examples/flux1.py new file mode 100644 index 00000000..66d44823 --- /dev/null +++ b/examples/flux1.py @@ -0,0 +1,496 @@ +# pip3 install sentencepiece + +# This file incorporates code from the following: +# Github Name | License | Link +# black-forest-labs/flux | Apache | https://github.com/black-forest-labs/flux/tree/main/model_licenses + +from tinygrad import Tensor, nn, dtypes, TinyJit +from tinygrad.nn.state import safe_load, load_state_dict +from tinygrad.helpers import fetch, tqdm, colored +from sdxl import FirstStage +from extra.models.clip import FrozenClosedClipEmbedder +from extra.models.t5 import T5Embedder +import numpy as np + +import math, time, argparse, tempfile +from typing import List, Dict, Optional, Union, Tuple, Callable +from dataclasses import dataclass +from pathlib import Path +from PIL import Image + +urls:dict = { + "flux-schnell": "https://huggingface.co/black-forest-labs/FLUX.1-schnell/resolve/main/flux1-schnell.safetensors", + "flux-dev": "https://huggingface.co/camenduru/FLUX.1-dev/resolve/main/flux1-dev.sft", + "ae": "https://huggingface.co/black-forest-labs/FLUX.1-schnell/resolve/main/ae.safetensors", + "T5_1_of_2": "https://huggingface.co/black-forest-labs/FLUX.1-schnell/resolve/main/text_encoder_2/model-00001-of-00002.safetensors", + "T5_2_of_2": "https://huggingface.co/black-forest-labs/FLUX.1-schnell/resolve/main/text_encoder_2/model-00002-of-00002.safetensors", + "T5_tokenizer": "https://huggingface.co/black-forest-labs/FLUX.1-schnell/resolve/main/tokenizer_2/spiece.model", + "clip": "https://huggingface.co/black-forest-labs/FLUX.1-schnell/resolve/main/text_encoder/model.safetensors" +} + +def tensor_identity(x:Tensor) -> Tensor: return x + +class AutoEncoder: + def __init__(self, scale_factor:float, shift_factor:float): + self.decoder = FirstStage.Decoder(128, 3, 3, 16, [1, 2, 4, 4], 2, 256) + self.scale_factor = scale_factor + self.shift_factor = shift_factor + + def decode(self, z:Tensor) -> Tensor: + z = z / self.scale_factor + self.shift_factor + return self.decoder(z) + +# Conditioner +class ClipEmbedder(FrozenClosedClipEmbedder): + def __call__(self, texts:Union[str, List[str], Tensor]) -> Tensor: + if isinstance(texts, str): texts = [texts] + assert isinstance(texts, (list,tuple)), f"expected list of strings, got {type(texts).__name__}" + tokens = Tensor.cat(*[Tensor(self.tokenizer.encode(text)) for text in texts], dim=0) + return self.transformer.text_model(tokens.reshape(len(texts),-1))[:, tokens.argmax(-1)] + +# https://github.com/black-forest-labs/flux/blob/main/src/flux/math.py +def attention(q:Tensor, k:Tensor, v:Tensor, pe:Tensor) -> Tensor: + q, k = apply_rope(q, k, pe) + x = Tensor.scaled_dot_product_attention(q, k, v) + return x.rearrange("B H L D -> B L (H D)") + +def rope(pos:Tensor, dim:int, theta:int) -> Tensor: + assert dim % 2 == 0 + scale = Tensor.arange(0, dim, 2, dtype=dtypes.float32, device=pos.device) / dim # NOTE: this is torch.float64 in reference implementation + omega = 1.0 / (theta**scale) + out = Tensor.einsum("...n,d->...nd", pos, omega) + out = Tensor.stack(Tensor.cos(out), -Tensor.sin(out), Tensor.sin(out), Tensor.cos(out), dim=-1) + out = out.rearrange("b n d (i j) -> b n d i j", i=2, j=2) + return out.float() + +def apply_rope(xq:Tensor, xk:Tensor, freqs_cis:Tensor) -> Tuple[Tensor, Tensor]: + xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2) + xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2) + xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1] + xk_out = 
freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1] + return xq_out.reshape(*xq.shape).cast(xq.dtype), xk_out.reshape(*xk.shape).cast(xk.dtype) + + +# https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py +class EmbedND: + def __init__(self, dim:int, theta:int, axes_dim:List[int]): + self.dim = dim + self.theta = theta + self.axes_dim = axes_dim + + def __call__(self, ids:Tensor) -> Tensor: + n_axes = ids.shape[-1] + emb = Tensor.cat(*[rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], dim=-3) + return emb.unsqueeze(1) + +class MLPEmbedder: + def __init__(self, in_dim:int, hidden_dim:int): + self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True) + self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True) + + def __call__(self, x:Tensor) -> Tensor: + return self.out_layer(self.in_layer(x).silu()) + +class QKNorm: + def __init__(self, dim:int): + self.query_norm = nn.RMSNorm(dim) + self.key_norm = nn.RMSNorm(dim) + + def __call__(self, q:Tensor, k:Tensor) -> Tuple[Tensor, Tensor]: + return self.query_norm(q), self.key_norm(k) + +class SelfAttention: + def __init__(self, dim:int, num_heads:int = 8, qkv_bias:bool = False): + self.num_heads = num_heads + head_dim = dim // num_heads + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.norm = QKNorm(head_dim) + self.proj = nn.Linear(dim, dim) + + def __call__(self, x:Tensor, pe:Tensor) -> Tensor: + qkv = self.qkv(x) + q, k, v = qkv.rearrange("B L (K H D) -> K B H L D", K=3, H=self.num_heads) + q, k = self.norm(q, k) + x = attention(q, k, v, pe=pe) + return self.proj(x) + +@dataclass +class ModulationOut: + shift:Tensor + scale:Tensor + gate:Tensor + +class Modulation: + def __init__(self, dim:int, double:bool): + self.is_double = double + self.multiplier = 6 if double else 3 + self.lin = nn.Linear(dim, self.multiplier * dim, bias=True) + + def __call__(self, vec:Tensor) -> Tuple[ModulationOut, Optional[ModulationOut]]: + out = self.lin(vec.silu())[:, None, :].chunk(self.multiplier, dim=-1) + return ModulationOut(*out[:3]), ModulationOut(*out[3:]) if self.is_double else None + +class DoubleStreamBlock: + def __init__(self, hidden_size:int, num_heads:int, mlp_ratio:float, qkv_bias:bool = False): + mlp_hidden_dim = int(hidden_size * mlp_ratio) + self.num_heads = num_heads + self.hidden_size = hidden_size + self.img_mod = Modulation(hidden_size, double=True) + self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias) + + self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.img_mlp = [nn.Linear(hidden_size, mlp_hidden_dim, bias=True), Tensor.gelu, nn.Linear(mlp_hidden_dim, hidden_size, bias=True)] + + self.txt_mod = Modulation(hidden_size, double=True) + self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias) + + self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.txt_mlp = [nn.Linear(hidden_size, mlp_hidden_dim, bias=True), Tensor.gelu, nn.Linear(mlp_hidden_dim, hidden_size, bias=True)] + + def __call__(self, img:Tensor, txt:Tensor, vec:Tensor, pe:Tensor) -> tuple[Tensor, Tensor]: + img_mod1, img_mod2 = self.img_mod(vec) + txt_mod1, txt_mod2 = self.txt_mod(vec) + assert img_mod2 is not None and txt_mod2 is not None + # prepare image for attention + img_modulated = self.img_norm1(img) + 
img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
+    img_qkv = self.img_attn.qkv(img_modulated)
+    img_q, img_k, img_v = img_qkv.rearrange("B L (K H D) -> K B H L D", K=3, H=self.num_heads)
+    img_q, img_k = self.img_attn.norm(img_q, img_k)
+
+    # prepare txt for attention
+    txt_modulated = self.txt_norm1(txt)
+    txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
+    txt_qkv = self.txt_attn.qkv(txt_modulated)
+    txt_q, txt_k, txt_v = txt_qkv.rearrange("B L (K H D) -> K B H L D", K=3, H=self.num_heads)
+    txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k)
+
+    # run actual attention
+    q = Tensor.cat(txt_q, img_q, dim=2)
+    k = Tensor.cat(txt_k, img_k, dim=2)
+    v = Tensor.cat(txt_v, img_v, dim=2)
+
+    attn = attention(q, k, v, pe=pe)
+    txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
+
+    # calculate the img blocks
+    img = img + img_mod1.gate * self.img_attn.proj(img_attn)
+    img = img + img_mod2.gate * ((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift).sequential(self.img_mlp)
+
+    # calculate the txt blocks
+    txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
+    txt = txt + txt_mod2.gate * ((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift).sequential(self.txt_mlp)
+    return img, txt
+
+
+class SingleStreamBlock:
+  """
+  A DiT block with parallel linear layers as described in
+  https://arxiv.org/abs/2302.05442 and adapted modulation interface.
+  """
+
+  def __init__(self, hidden_size:int, num_heads:int, mlp_ratio:float=4.0, qk_scale:Optional[float]=None):
+    self.hidden_dim = hidden_size
+    self.num_heads = num_heads
+    head_dim = hidden_size // num_heads
+    self.scale = qk_scale or head_dim**-0.5
+
+    self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
+    # qkv and mlp_in
+    self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
+    # proj and mlp_out
+    self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)
+
+    self.norm = QKNorm(head_dim)
+
+    self.hidden_size = hidden_size
+    self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+
+    self.mlp_act = Tensor.gelu
+    self.modulation = Modulation(hidden_size, double=False)
+
+  def __call__(self, x:Tensor, vec:Tensor, pe:Tensor) -> Tensor:
+    mod, _ = self.modulation(vec)
+    x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
+    qkv, mlp = Tensor.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
+    q, k, v = qkv.rearrange("B L (K H D) -> K B H L D", K=3, H=self.num_heads)
+    q, k = self.norm(q, k)
+
+    # compute attention
+    attn = attention(q, k, v, pe=pe)
+    # compute activation in mlp stream, cat again and run second linear layer
+    output = self.linear2(Tensor.cat(attn, self.mlp_act(mlp), dim=2))
+    return x + mod.gate * output
+
+
+class LastLayer:
+  def __init__(self, hidden_size:int, patch_size:int, out_channels:int):
+    self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+    self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
+    self.adaLN_modulation:List[Callable[[Tensor], Tensor]] = [Tensor.silu, nn.Linear(hidden_size, 2 * hidden_size, bias=True)]
+
+  def __call__(self, x:Tensor, vec:Tensor) -> Tensor:
+    shift, scale = vec.sequential(self.adaLN_modulation).chunk(2, dim=1)
+    x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
+    return self.linear(x)
+
+def timestep_embedding(t:Tensor, dim:int, max_period:int=10000, time_factor:float=1000.0) -> Tensor:
+  """
+  Create sinusoidal timestep embeddings.
+ :param t: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an (N, D) Tensor of positional embeddings. + """ + t = time_factor * t + half = dim // 2 + freqs = Tensor.exp(-math.log(max_period) * Tensor.arange(0, stop=half, dtype=dtypes.float32) / half).to(t.device) + + args = t[:, None].float() * freqs[None] + embedding = Tensor.cat(Tensor.cos(args), Tensor.sin(args), dim=-1) + if dim % 2: embedding = Tensor.cat(*[embedding, Tensor.zeros_like(embedding[:, :1])], dim=-1) + if Tensor.is_floating_point(t): embedding = embedding.cast(t.dtype) + return embedding + +# https://github.com/black-forest-labs/flux/blob/main/src/flux/model.py +class Flux: + """ + Transformer model for flow matching on sequences. + """ + + def __init__( + self, + guidance_embed:bool, + in_channels:int = 64, + vec_in_dim:int = 768, + context_in_dim:int = 4096, + hidden_size:int = 3072, + mlp_ratio:float = 4.0, + num_heads:int = 24, + depth:int = 19, + depth_single_blocks:int = 38, + axes_dim:Optional[List[int]] = None, + theta:int = 10_000, + qkv_bias:bool = True, + ): + + axes_dim = axes_dim or [16, 56, 56] + self.guidance_embed = guidance_embed + self.in_channels = in_channels + self.out_channels = self.in_channels + if hidden_size % num_heads != 0: + raise ValueError(f"Hidden size {hidden_size} must be divisible by num_heads {num_heads}") + pe_dim = hidden_size // num_heads + if sum(axes_dim) != pe_dim: + raise ValueError(f"Got {axes_dim} but expected positional dim {pe_dim}") + self.hidden_size = hidden_size + self.num_heads = num_heads + self.pe_embedder = EmbedND(dim=pe_dim, theta=theta, axes_dim=axes_dim) + self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True) + self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) + self.vector_in = MLPEmbedder(vec_in_dim, self.hidden_size) + self.guidance_in:Callable[[Tensor], Tensor] = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if guidance_embed else tensor_identity + self.txt_in = nn.Linear(context_in_dim, self.hidden_size) + + self.double_blocks = [DoubleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias) for _ in range(depth)] + self.single_blocks = [SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=mlp_ratio) for _ in range(depth_single_blocks)] + self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels) + + def __call__(self, img:Tensor, img_ids:Tensor, txt:Tensor, txt_ids:Tensor, timesteps:Tensor, y:Tensor, guidance:Optional[Tensor] = None) -> Tensor: + if img.ndim != 3 or txt.ndim != 3: + raise ValueError("Input img and txt tensors must have 3 dimensions.") + # running on sequences img + img = self.img_in(img) + vec = self.time_in(timestep_embedding(timesteps, 256)) + if self.guidance_embed: + if guidance is None: + raise ValueError("Didn't get guidance strength for guidance distilled model.") + vec = vec + self.guidance_in(timestep_embedding(guidance, 256)) + vec = vec + self.vector_in(y) + txt = self.txt_in(txt) + ids = Tensor.cat(txt_ids, img_ids, dim=1) + pe = self.pe_embedder(ids) + for double_block in self.double_blocks: + img, txt = double_block(img=img, txt=txt, vec=vec, pe=pe) + + img = Tensor.cat(txt, img, dim=1) + for single_block in self.single_blocks: + img = single_block(img, vec=vec, pe=pe) + + img = img[:, txt.shape[1] :, ...] 
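+    # NOTE: the txt tokens concatenated for the single-stream blocks are sliced off above, so only
+    # the image tokens are projected back to patch space by the final layer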
+ + return self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels) + +# https://github.com/black-forest-labs/flux/blob/main/src/flux/util.py +def load_flow_model(name:str): + # Loading Flux + print("Init model") + model = Flux(guidance_embed=(name != "flux-schnell")) + state_dict = {k.replace("scale", "weight"): v for k, v in safe_load(fetch(urls[name])).items()} + load_state_dict(model, state_dict) + return model + +def load_T5(max_length:int=512): + # max length 64, 128, 256 and 512 should work (if your sequence is short enough) + print("Init T5") + T5 = T5Embedder(max_length, fetch(urls["T5_tokenizer"])) + pt_1 = fetch(urls["T5_1_of_2"]) + pt_2 = fetch(urls["T5_2_of_2"]) + load_state_dict(T5.encoder, safe_load(pt_1) | safe_load(pt_2), strict=False) + return T5 + +def load_clip(): + print("Init Clip") + clip = ClipEmbedder() + load_state_dict(clip.transformer, safe_load(fetch(urls["clip"]))) + return clip + +def load_ae() -> AutoEncoder: + # Loading the autoencoder + print("Init AE") + ae = AutoEncoder(0.3611, 0.1159) + load_state_dict(ae, safe_load(fetch(urls["ae"]))) + return ae + +# https://github.com/black-forest-labs/flux/blob/main/src/flux/sampling.py +def prepare(T5:T5Embedder, clip:ClipEmbedder, img:Tensor, prompt:Union[str, List[str]]) -> Dict[str, Tensor]: + bs, _, h, w = img.shape + if bs == 1 and not isinstance(prompt, str): + bs = len(prompt) + + img = img.rearrange("b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) + if img.shape[0] == 1 and bs > 1: + img = img.expand((bs, *img.shape[1:])) + + img_ids = Tensor.zeros(h // 2, w // 2, 3).contiguous() + img_ids[..., 1] = img_ids[..., 1] + Tensor.arange(h // 2)[:, None] + img_ids[..., 2] = img_ids[..., 2] + Tensor.arange(w // 2)[None, :] + img_ids = img_ids.rearrange("h w c -> 1 (h w) c") + img_ids = img_ids.expand((bs, *img_ids.shape[1:])) + + if isinstance(prompt, str): + prompt = [prompt] + txt = T5(prompt).realize() + if txt.shape[0] == 1 and bs > 1: + txt = txt.expand((bs, *txt.shape[1:])) + txt_ids = Tensor.zeros(bs, txt.shape[1], 3) + + vec = clip(prompt).realize() + if vec.shape[0] == 1 and bs > 1: + vec = vec.expand((bs, *vec.shape[1:])) + + return {"img": img, "img_ids": img_ids.to(img.device), "txt": txt.to(img.device), "txt_ids": txt_ids.to(img.device), "vec": vec.to(img.device)} + + +def get_schedule(num_steps:int, image_seq_len:int, base_shift:float=0.5, max_shift:float=1.15, shift:bool=True) -> List[float]: + # extra step for zero + step_size = -1.0 / num_steps + timesteps = Tensor.arange(1, 0 + step_size, step_size) + + # shifting the schedule to favor high timesteps for higher signal images + if shift: + # estimate mu based on linear estimation between two points + mu = 0.5 + (max_shift - base_shift) * (image_seq_len - 256) / (4096 - 256) + timesteps = math.exp(mu) / (math.exp(mu) + (1 / timesteps - 1)) + return timesteps.tolist() + +@TinyJit +def run(model, *args): return model(*args).realize() + +def denoise(model, img:Tensor, img_ids:Tensor, txt:Tensor, txt_ids:Tensor, vec:Tensor, timesteps:List[float], guidance:float=4.0) -> Tensor: + # this is ignored for schnell + guidance_vec = Tensor((guidance,), device=img.device, dtype=img.dtype).expand((img.shape[0],)) + for t_curr, t_prev in tqdm(list(zip(timesteps[:-1], timesteps[1:])), "Denoising"): + t_vec = Tensor((t_curr,), device=img.device, dtype=img.dtype).expand((img.shape[0],)) + pred = run(model, img, img_ids, txt, txt_ids, t_vec, vec, guidance_vec) + img = img + (t_prev - t_curr) * pred + + return img + +def unpack(x:Tensor, height:int, 
width:int) -> Tensor:
+  return x.rearrange("b (h w) (c ph pw) -> b c (h ph) (w pw)", h=math.ceil(height / 16), w=math.ceil(width / 16), ph=2, pw=2)
+
+# https://github.com/black-forest-labs/flux/blob/main/src/flux/cli.py
+if __name__ == "__main__":
+  default_prompt = "bananas and a can of coke"
+  parser = argparse.ArgumentParser(description="Run Flux.1", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+
+  parser.add_argument("--name", type=str, default="flux-schnell", help="Name of the model to load")
+  parser.add_argument("--width", type=int, default=512, help="width of the sample in pixels (should be a multiple of 16)")
+  parser.add_argument("--height", type=int, default=512, help="height of the sample in pixels (should be a multiple of 16)")
+  parser.add_argument("--seed", type=int, default=None, help="Set a seed for sampling")
+  parser.add_argument("--prompt", type=str, default=default_prompt, help="Prompt used for sampling")
+  parser.add_argument('--out', type=str, default=Path(tempfile.gettempdir()) / "rendered.png", help="Output filename")
+  parser.add_argument("--num_steps", type=int, default=None, help="number of sampling steps (default 4 for schnell, 50 for guidance distilled)") #noqa:E501
+  parser.add_argument("--guidance", type=float, default=3.5, help="guidance value used for guidance distillation")
+  parser.add_argument("--output_dir", type=str, default="output", help="output directory")
+  args = parser.parse_args()
+
+  if args.name not in ["flux-schnell", "flux-dev"]:
+    raise ValueError(f"Got unknown model name: {args.name}, choose from flux-schnell and flux-dev")
+
+  if args.num_steps is None:
+    args.num_steps = 4 if args.name == "flux-schnell" else 50
+
+  # allow for packing and conversion to latent space
+  height = 16 * (args.height // 16)
+  width = 16 * (args.width // 16)
+
+  if args.seed is None: args.seed = Tensor._seed
+  else: Tensor.manual_seed(args.seed)
+
+  print(f"Generating with seed {args.seed}:\n{args.prompt}")
+  t0 = time.perf_counter()
+
+  # prepare input noise
+  x = Tensor.randn(1, 16, 2 * math.ceil(height / 16), 2 * math.ceil(width / 16), dtype="bfloat16")
+
+  # load text embedders
+  T5 = load_T5(max_length=256 if args.name == "flux-schnell" else 512)
+  clip = load_clip()
+
+  # embed text to get inputs for model
+  inp = prepare(T5, clip, x, prompt=args.prompt)
+  timesteps = get_schedule(args.num_steps, inp["img"].shape[1], shift=(args.name != "flux-schnell"))
+
+  # done with text embedders
+  del T5, clip
+
+  # load model
+  model = load_flow_model(args.name)
+
+  # denoise initial noise
+  x = denoise(model, **inp, timesteps=timesteps, guidance=args.guidance)
+
+  # done with model
+  del model, run
+
+  # load autoencoder
+  ae = load_ae()
+
+  # decode latents to pixel space
+  x = unpack(x.float(), height, width)
+  x = ae.decode(x).realize()
+
+  t1 = time.perf_counter()
+  print(f"Done in {t1 - t0:.1f}s. Saving {args.out}")
+
+  # bring into PIL format and save
+  x = x.clamp(-1, 1)
+  x = x[0].rearrange("c h w -> h w c")
+  x = (127.5 * (x + 1.0)).cast("uint8")
+
+  img = Image.fromarray(x.numpy())
+
+  img.save(args.out)
+
+  # validation!
+  if args.prompt == default_prompt and args.name == "flux-schnell" and args.seed == 0 and args.width == args.height == 512:
+    ref_image = Tensor(np.array(Image.open("examples/flux1_seed0.png")))
+    distance = (((x.cast(dtypes.float) - ref_image.cast(dtypes.float)) / ref_image.max())**2).mean().item()
+    assert distance < 4e-3, colored(f"validation failed with {distance=}", "red")
+    print(colored(f"output validated with {distance=}", "green"))
\ No newline at end of file
diff --git a/examples/flux1_seed0.png b/examples/flux1_seed0.png
new file mode 100644
index 00000000..3acd1719
Binary files /dev/null and b/examples/flux1_seed0.png differ
diff --git a/extra/models/t5.py b/extra/models/t5.py
new file mode 100644
index 00000000..fd0ba34a
--- /dev/null
+++ b/extra/models/t5.py
@@ -0,0 +1,288 @@
+# pip3 install sentencepiece
+
+# adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py
+
+# coding=utf-8
+# Copyright 2018 Mesh TensorFlow authors, T5 Authors and HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tinygrad T5 model."""
+from tinygrad import nn, Tensor, dtypes
+
+import math
+from dataclasses import dataclass
+from typing import List, Union, Optional, Tuple
+from pathlib import Path
+from sentencepiece import SentencePieceProcessor
+
+# default config is t5-xxl
+@dataclass
+class T5Config:
+  d_ff:int = 10240
+  d_kv:int = 64
+  d_model:int = 4096
+  layer_norm_epsilon:float = 1e-6
+  num_decoder_layers:int = 24
+  num_heads:int = 64
+  num_layers:int = 24
+  relative_attention_num_buckets:int = 32
+  relative_attention_max_distance:int = 128
+  vocab_size:int = 32128
+
+class T5Tokenizer:
+  def __init__(self, spiece_path):
+    self.spp = SentencePieceProcessor(str(spiece_path))
+
+  def encode(self, text:str, max_length:int) -> List[int]:
+    encoded = self.spp.Encode(text)
+    if len(encoded) > max_length - 1: encoded = encoded[:max_length - 1]
+    return encoded + [1] + [0]*(max_length - len(encoded) - 1)
+
+class T5LayerNorm:
+  def __init__(self, hidden_size:int, eps:float=1e-6):
+    """
+    Construct a layernorm module in the T5 style. No bias and no subtraction of mean.
+    """
+    self.weight = Tensor.ones(hidden_size)
+    self.variance_epsilon = eps
+
+  def __call__(self, hidden_states:Tensor) -> Tensor:
+    # T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean
+    # Square Layer Normalization https://arxiv.org/abs/1910.07467, thus variance is calculated
+    # w/o mean and there is no bias.
Additionally we want to make sure that the accumulation for + # half-precision inputs is done in fp32 + + variance = hidden_states.cast(dtypes.float32).pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * Tensor.rsqrt(variance + self.variance_epsilon) + + # convert into half-precision if necessary + if self.weight.dtype in [dtypes.float16, dtypes.bfloat16]: + hidden_states = hidden_states.cast(self.weight.dtype) + + return self.weight * hidden_states + + +class T5DenseGatedActDense: + def __init__(self, config:T5Config): + self.wi_0 = nn.Linear(config.d_model, config.d_ff, bias=False) + self.wi_1 = nn.Linear(config.d_model, config.d_ff, bias=False) + self.wo = nn.Linear(config.d_ff, config.d_model, bias=False) + + def __call__(self, hidden_states:Tensor) -> Tensor: + hidden_gelu = self.wi_0(hidden_states).gelu() + hidden_linear = self.wi_1(hidden_states) + hidden_states = hidden_gelu * hidden_linear + hidden_states = self.wo(hidden_states) + return hidden_states + + +class T5LayerFF: + def __init__(self, config:T5Config): + self.DenseReluDense = T5DenseGatedActDense(config) + self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) + + def __call__(self, hidden_states:Tensor) -> Tensor: + forwarded_states = self.layer_norm(hidden_states) + forwarded_states = self.DenseReluDense(forwarded_states) + hidden_states = hidden_states + forwarded_states + return hidden_states + + +class T5Attention: + def __init__(self, config:T5Config, has_relative_attention_bias:bool=False): + self.has_relative_attention_bias = has_relative_attention_bias + self.relative_attention_num_buckets = config.relative_attention_num_buckets + self.relative_attention_max_distance = config.relative_attention_max_distance + self.d_model = config.d_model + self.key_value_proj_dim = config.d_kv + self.n_heads = config.num_heads + self.inner_dim = self.n_heads * self.key_value_proj_dim + + # Mesh TensorFlow initialization to avoid scaling before softmax + self.q = nn.Linear(self.d_model, self.inner_dim, bias=False) + self.k = nn.Linear(self.d_model, self.inner_dim, bias=False) + self.v = nn.Linear(self.d_model, self.inner_dim, bias=False) + self.o = nn.Linear(self.inner_dim, self.d_model, bias=False) + + if self.has_relative_attention_bias: + self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads) + + @staticmethod + def _relative_position_bucket(relative_position:Tensor, num_buckets:int=32, max_distance:int=128) -> Tensor: + """ + Adapted from Mesh Tensorflow: + https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593 + + Translate relative position to a bucket number for relative attention. The relative position is defined as + memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to + position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for + small absolute relative_position and larger buckets for larger absolute relative_positions. All relative + positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket. 
+ This should allow for more graceful generalization to longer sequences than the model has been trained on + + Args: + relative_position: an int32 Tensor + bidirectional: a boolean - whether the attention is bidirectional + num_buckets: an integer + max_distance: an integer + + Returns: + a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets) + """ + relative_buckets = Tensor.zeros_like(relative_position) + num_buckets //= 2 + relative_buckets += (relative_position > 0).cast(dtypes.long) * num_buckets + relative_position = Tensor.abs(relative_position) + + # half of the buckets are for exact increments in positions + max_exact = num_buckets // 2 + is_small = relative_position < max_exact + + # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance + relative_position_if_large = max_exact + ( + Tensor.log(relative_position.float() / max_exact) + / math.log(max_distance / max_exact) + * (num_buckets - max_exact) + ).cast(dtypes.long) + + relative_position_if_large = Tensor.min( + Tensor.stack( + relative_position_if_large, Tensor.full_like(relative_position_if_large, num_buckets - 1) + ), + axis=0, + ) + relative_buckets += Tensor.where(is_small, relative_position, relative_position_if_large) + return relative_buckets + + def compute_bias(self, query_length, key_length, device=None) -> Tensor: + """Compute binned relative position bias""" + if device is None: + device = self.relative_attention_bias.weight.device + context_position = Tensor.arange(query_length, dtype=dtypes.long, device=device)[:, None] + memory_position = Tensor.arange(key_length, dtype=dtypes.long, device=device)[None, :] + relative_position = memory_position - context_position # shape (query_length, key_length) + relative_position_bucket = self._relative_position_bucket( + relative_position, # shape (query_length, key_length) + num_buckets=self.relative_attention_num_buckets, + max_distance=self.relative_attention_max_distance, + ) + values = self.relative_attention_bias(relative_position_bucket) # shape (query_length, key_length, num_heads) + values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length) + return values + + def __call__(self, hidden_states:Tensor, position_bias:Optional[Tensor]=None) -> Tuple[Tensor,Tensor]: + """ + Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). 
+ """ + # Input is (batch_size, seq_length, dim) + batch_size, key_length = hidden_states.shape[:2] + + def shape(states): + """projection""" + return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + + def unshape(states): + """reshape""" + return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim) + + def project(hidden_states, proj_layer): + """projects hidden states correctly to key/query states""" + # self-attn + # (batch_size, n_heads, seq_length, dim_per_head) + return shape(proj_layer(hidden_states)) + + # get query states + query_states = shape(self.q(hidden_states)) # (batch_size, n_heads, seq_length, dim_per_head) + + # get key/value states + key_states = project(hidden_states, self.k) + value_states = project(hidden_states, self.v) + + # compute scores + scores = Tensor.matmul(query_states, key_states.transpose(3, 2)) + + if position_bias is None: + position_bias = self.compute_bias(key_length, key_length, device=scores.device) + + scores += position_bias + attn_weights = Tensor.softmax(scores.float(), axis=-1).cast(scores.dtype) # (batch_size, n_heads, seq_length, key_length) + + attn_output = unshape(Tensor.matmul(attn_weights, value_states)) # (batch_size, seq_length, dim) + attn_output = self.o(attn_output) + + return attn_output, position_bias + + +class T5LayerSelfAttention: + def __init__(self, config:T5Config, has_relative_attention_bias:bool=False): + self.SelfAttention = T5Attention(config, has_relative_attention_bias=has_relative_attention_bias) + self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) + + def __call__(self, hidden_states:Tensor, position_bias:Optional[Tensor]=None) -> Tuple[Tensor, Tensor]: + normed_hidden_states = self.layer_norm(hidden_states) + attention_output, position_bias = self.SelfAttention(normed_hidden_states, position_bias=position_bias) + return hidden_states + attention_output, position_bias + + +class T5Block: + def __init__(self, config:T5Config, has_relative_attention_bias:bool=False): + self.layer = (T5LayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias), + T5LayerFF(config)) + + def __call__(self, hidden_states:Tensor, position_bias:Optional[Tensor]=None) -> Tuple[Tensor, Tensor]: + self_attention_outputs, position_bias = self.layer[0](hidden_states, position_bias=position_bias) + hidden_states = self_attention_outputs + + # Apply Feed Forward layer + hidden_states = self.layer[-1](hidden_states) + + return hidden_states, position_bias + + +class T5Stack: + def __init__(self, config:T5Config, embed_tokens:nn.Embedding): + self.config = config + self.embed_tokens = embed_tokens + self.block = [T5Block(config, has_relative_attention_bias=bool(i == 0)) for i in range(config.num_layers)] + self.final_layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) + + def __call__(self, input_ids:Tensor) -> Tensor: + input_ids = input_ids.view(-1, input_ids.shape[-1]) + + hidden_states, position_bias = self.embed_tokens(input_ids), None + + for layer_module in self.block: + hidden_states, position_bias = layer_module(hidden_states, position_bias=position_bias) + + return self.final_layer_norm(hidden_states) + + +class T5EncoderModel: + def __init__(self, config:T5Config): + self.shared = nn.Embedding(config.vocab_size, config.d_model) + self.encoder = T5Stack(config, self.shared) + + def __call__(self, input_ids:Tensor) -> Tensor: + return self.encoder(input_ids) + +class T5Embedder: + def __init__(self, max_length:int, 
spiece_path:Union[str, Path]): + self.tokenizer = T5Tokenizer(spiece_path) + self.max_length = max_length + config = T5Config() + self.encoder = T5EncoderModel(config) + + def __call__(self, texts:Union[str, List[str]]) -> Tensor: + if isinstance(texts, str): texts = [texts] + toks = Tensor.cat(*[Tensor(self.tokenizer.encode(text, self.max_length)) for text in texts], dim=0) + return self.encoder(toks) \ No newline at end of file
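Note (not part of the patch): the example is driven from the CLI, e.g. python3 examples/flux1.py --prompt "bananas and a can of coke" --seed 0. To smoke-test the new T5 encoder on its own, a minimal sketch along the following lines should work; it simply mirrors what load_T5 in examples/flux1.py does (same tokenizer and checkpoint URLs as the urls table above), and the printed shape is indicative only.

# hedged usage sketch for extra/models/t5.py -- mirrors load_T5() from examples/flux1.py
from tinygrad.nn.state import safe_load, load_state_dict
from tinygrad.helpers import fetch
from extra.models.t5 import T5Embedder

# same artifacts that load_T5() fetches via the urls table in examples/flux1.py
spiece = fetch("https://huggingface.co/black-forest-labs/FLUX.1-schnell/resolve/main/tokenizer_2/spiece.model")
shard1 = fetch("https://huggingface.co/black-forest-labs/FLUX.1-schnell/resolve/main/text_encoder_2/model-00001-of-00002.safetensors")
shard2 = fetch("https://huggingface.co/black-forest-labs/FLUX.1-schnell/resolve/main/text_encoder_2/model-00002-of-00002.safetensors")

t5 = T5Embedder(max_length=256, spiece_path=spiece)
load_state_dict(t5.encoder, safe_load(shard1) | safe_load(shard2), strict=False)

emb = t5("bananas and a can of coke").realize()
print(emb.shape)  # expected around (1, 256, 4096): (batch, max_length, d_model) for a single prompt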
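Also not part of the patch: a small, hypothetical sketch for eyeballing the relative-position bucketing that T5Attention uses to build its attention bias, mirroring the construction inside compute_bias; values shown in the comments describe the qualitative behaviour only.

# hedged sketch of T5Attention._relative_position_bucket on a tiny grid
from tinygrad import Tensor, dtypes
from extra.models.t5 import T5Attention

context_position = Tensor.arange(8, dtype=dtypes.long)[:, None]  # query positions
memory_position = Tensor.arange(8, dtype=dtypes.long)[None, :]   # key positions
relative_position = memory_position - context_position           # same construction as compute_bias
buckets = T5Attention._relative_position_bucket(relative_position, num_buckets=32, max_distance=128)
# small offsets map to their own buckets, larger offsets share logarithmically spaced buckets,
# and positive (future) vs negative (past) offsets land in separate halves of the bucket range
print(buckets.numpy())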