# pip3 install sentencepiece

# This file incorporates code from the following:
# Github Name | License | Link
# black-forest-labs/flux | Apache | https://github.com/black-forest-labs/flux/tree/main/model_licenses

from tinygrad import Tensor, nn, dtypes, TinyJit
from tinygrad.nn.state import safe_load, load_state_dict
from tinygrad.helpers import fetch, tqdm, colored
from sdxl import FirstStage
from extra.models.clip import FrozenClosedClipEmbedder
from extra.models.t5 import T5Embedder
import numpy as np

import math, time, argparse, tempfile
from typing import List, Dict, Optional, Union, Tuple, Callable
from dataclasses import dataclass
from pathlib import Path
from PIL import Image

urls:dict = {
  "flux-schnell": "https://huggingface.co/black-forest-labs/FLUX.1-schnell/resolve/main/flux1-schnell.safetensors",
  "flux-dev": "https://huggingface.co/camenduru/FLUX.1-dev/resolve/main/flux1-dev.sft",
  "ae": "https://huggingface.co/black-forest-labs/FLUX.1-schnell/resolve/main/ae.safetensors",
  "T5_1_of_2": "https://huggingface.co/black-forest-labs/FLUX.1-schnell/resolve/main/text_encoder_2/model-00001-of-00002.safetensors",
  "T5_2_of_2": "https://huggingface.co/black-forest-labs/FLUX.1-schnell/resolve/main/text_encoder_2/model-00002-of-00002.safetensors",
  "T5_tokenizer": "https://huggingface.co/black-forest-labs/FLUX.1-schnell/resolve/main/tokenizer_2/spiece.model",
  "clip": "https://huggingface.co/black-forest-labs/FLUX.1-schnell/resolve/main/text_encoder/model.safetensors"
}

def tensor_identity(x:Tensor) -> Tensor: return x

class AutoEncoder:
  def __init__(self, scale_factor:float, shift_factor:float):
    self.decoder = FirstStage.Decoder(128, 3, 3, 16, [1, 2, 4, 4], 2, 256)
    self.scale_factor = scale_factor
    self.shift_factor = shift_factor

  def decode(self, z:Tensor) -> Tensor:
    z = z / self.scale_factor + self.shift_factor
    return self.decoder(z)

# Conditioner
class ClipEmbedder(FrozenClosedClipEmbedder):
  def __call__(self, texts:Union[str, List[str], Tensor]) -> Tensor:
    if isinstance(texts, str): texts = [texts]
    assert isinstance(texts, (list,tuple)), f"expected list of strings, got {type(texts).__name__}"
    tokens = Tensor.cat(*[Tensor(self.tokenizer.encode(text)) for text in texts], dim=0)
    return self.transformer.text_model(tokens.reshape(len(texts),-1))[:, tokens.argmax(-1)]
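
# Pooling note: CLIP's end-of-text token has the highest id in its vocabulary, so
# for a single prompt, indexing the sequence output at tokens.argmax(-1) selects
# the EOT position and yields one pooled embedding per prompt.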

# https://github.com/black-forest-labs/flux/blob/main/src/flux/math.py
def attention(q:Tensor, k:Tensor, v:Tensor, pe:Tensor) -> Tensor:
  q, k = apply_rope(q, k, pe)
  x = Tensor.scaled_dot_product_attention(q, k, v)
  return x.rearrange("B H L D -> B L (H D)")

def rope(pos:Tensor, dim:int, theta:int) -> Tensor:
  assert dim % 2 == 0
  scale = Tensor.arange(0, dim, 2, dtype=dtypes.float32, device=pos.device) / dim  # NOTE: this is torch.float64 in the reference implementation
  omega = 1.0 / (theta**scale)
  out = Tensor.einsum("...n,d->...nd", pos, omega)
  out = Tensor.stack(Tensor.cos(out), -Tensor.sin(out), Tensor.sin(out), Tensor.cos(out), dim=-1)
  out = out.rearrange("b n d (i j) -> b n d i j", i=2, j=2)
  return out.float()

def apply_rope(xq:Tensor, xk:Tensor, freqs_cis:Tensor) -> Tuple[Tensor, Tensor]:
  xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
  xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
  xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
  xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
  return xq_out.reshape(*xq.shape).cast(xq.dtype), xk_out.reshape(*xk.shape).cast(xk.dtype)
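
# A minimal shape sketch (hypothetical values, not used by the model): rope() maps
# positions to per-frequency 2x2 rotation matrices [[cos, -sin], [sin, cos]], and
# apply_rope() rotates each adjacent feature pair of q and k by its position's matrix.
def _rope_demo() -> Tuple[Tensor, Tensor]:
  pos = Tensor.arange(4, dtype=dtypes.float32).reshape(1, 4)      # (b=1, n=4) token positions
  pe = rope(pos, dim=16, theta=10_000).reshape(1, 1, 4, 8, 2, 2)  # add the head axis EmbedND would
  q, k = Tensor.randn(1, 1, 4, 16), Tensor.randn(1, 1, 4, 16)     # (B, H, L, D) with D == dim
  return apply_rope(q, k, pe)                                     # same shapes back, now position-rotated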

# https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py
class EmbedND:
  def __init__(self, dim:int, theta:int, axes_dim:List[int]):
    self.dim = dim
    self.theta = theta
    self.axes_dim = axes_dim

  def __call__(self, ids:Tensor) -> Tensor:
    n_axes = ids.shape[-1]
    emb = Tensor.cat(*[rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], dim=-3)
    return emb.unsqueeze(1)

class MLPEmbedder:
  def __init__(self, in_dim:int, hidden_dim:int):
    self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
    self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)

  def __call__(self, x:Tensor) -> Tensor:
    return self.out_layer(self.in_layer(x).silu())

class QKNorm:
  def __init__(self, dim:int):
    self.query_norm = nn.RMSNorm(dim)
    self.key_norm = nn.RMSNorm(dim)

  def __call__(self, q:Tensor, k:Tensor) -> Tuple[Tensor, Tensor]:
    return self.query_norm(q), self.key_norm(k)

class SelfAttention:
  def __init__(self, dim:int, num_heads:int = 8, qkv_bias:bool = False):
    self.num_heads = num_heads
    head_dim = dim // num_heads

    self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
    self.norm = QKNorm(head_dim)
    self.proj = nn.Linear(dim, dim)

  def __call__(self, x:Tensor, pe:Tensor) -> Tensor:
    qkv = self.qkv(x)
    q, k, v = qkv.rearrange("B L (K H D) -> K B H L D", K=3, H=self.num_heads)
    q, k = self.norm(q, k)
    x = attention(q, k, v, pe=pe)
    return self.proj(x)

@dataclass
class ModulationOut:
  shift:Tensor
  scale:Tensor
  gate:Tensor

class Modulation:
  def __init__(self, dim:int, double:bool):
    self.is_double = double
    self.multiplier = 6 if double else 3
    self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)

  def __call__(self, vec:Tensor) -> Tuple[ModulationOut, Optional[ModulationOut]]:
    out = self.lin(vec.silu())[:, None, :].chunk(self.multiplier, dim=-1)
    return ModulationOut(*out[:3]), ModulationOut(*out[3:]) if self.is_double else None
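
# Hedged sketch (not part of the model): Modulation produces the adaLN-style
# (shift, scale, gate) triple that the blocks below apply as
# x + gate * ((1 + scale) * norm(x) + shift).
def _modulation_demo() -> ModulationOut:
  mod = Modulation(dim=8, double=False)
  vec = Tensor.randn(2, 8)  # (batch, dim) conditioning vector
  out, _ = mod(vec)         # shift, scale, gate each have shape (2, 1, 8)
  return out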

class DoubleStreamBlock:
  def __init__(self, hidden_size:int, num_heads:int, mlp_ratio:float, qkv_bias:bool = False):
    mlp_hidden_dim = int(hidden_size * mlp_ratio)
    self.num_heads = num_heads
    self.hidden_size = hidden_size
    self.img_mod = Modulation(hidden_size, double=True)
    self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
    self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)

    self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
    self.img_mlp = [nn.Linear(hidden_size, mlp_hidden_dim, bias=True), Tensor.gelu, nn.Linear(mlp_hidden_dim, hidden_size, bias=True)]

    self.txt_mod = Modulation(hidden_size, double=True)
    self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
    self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)

    self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
    self.txt_mlp = [nn.Linear(hidden_size, mlp_hidden_dim, bias=True), Tensor.gelu, nn.Linear(mlp_hidden_dim, hidden_size, bias=True)]

  def __call__(self, img:Tensor, txt:Tensor, vec:Tensor, pe:Tensor) -> Tuple[Tensor, Tensor]:
    img_mod1, img_mod2 = self.img_mod(vec)
    txt_mod1, txt_mod2 = self.txt_mod(vec)
    assert img_mod2 is not None and txt_mod2 is not None
    # prepare image for attention
    img_modulated = self.img_norm1(img)
    img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
    img_qkv = self.img_attn.qkv(img_modulated)
    img_q, img_k, img_v = img_qkv.rearrange("B L (K H D) -> K B H L D", K=3, H=self.num_heads)
    img_q, img_k = self.img_attn.norm(img_q, img_k)

    # prepare txt for attention
    txt_modulated = self.txt_norm1(txt)
    txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
    txt_qkv = self.txt_attn.qkv(txt_modulated)
    txt_q, txt_k, txt_v = txt_qkv.rearrange("B L (K H D) -> K B H L D", K=3, H=self.num_heads)
    txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k)

    # run actual attention
    q = Tensor.cat(txt_q, img_q, dim=2)
    k = Tensor.cat(txt_k, img_k, dim=2)
    v = Tensor.cat(txt_v, img_v, dim=2)

    attn = attention(q, k, v, pe=pe)
    txt_attn, img_attn = attn[:, :txt.shape[1]], attn[:, txt.shape[1]:]

    # calculate the img blocks
    img = img + img_mod1.gate * self.img_attn.proj(img_attn)
    img = img + img_mod2.gate * ((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift).sequential(self.img_mlp)

    # calculate the txt blocks
    txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
    txt = txt + txt_mod2.gate * ((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift).sequential(self.txt_mlp)
    return img, txt

class SingleStreamBlock:
  """
  A DiT block with parallel linear layers as described in
  https://arxiv.org/abs/2302.05442 and adapted modulation interface.
  """

  def __init__(self, hidden_size:int, num_heads:int, mlp_ratio:float=4.0, qk_scale:Optional[float]=None):
    self.hidden_dim = hidden_size
    self.num_heads = num_heads
    head_dim = hidden_size // num_heads
    self.scale = qk_scale or head_dim**-0.5

    self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
    # qkv and mlp_in
    self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
    # proj and mlp_out
    self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)

    self.norm = QKNorm(head_dim)

    self.hidden_size = hidden_size
    self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)

    self.mlp_act = Tensor.gelu
    self.modulation = Modulation(hidden_size, double=False)

  def __call__(self, x:Tensor, vec:Tensor, pe:Tensor) -> Tensor:
    mod, _ = self.modulation(vec)
    x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
    qkv, mlp = Tensor.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
    q, k, v = qkv.rearrange("B L (K H D) -> K B H L D", K=3, H=self.num_heads)
    q, k = self.norm(q, k)

    # compute attention
    attn = attention(q, k, v, pe=pe)
    # compute activation in mlp stream, cat again and run second linear layer
    output = self.linear2(Tensor.cat(attn, self.mlp_act(mlp), dim=2))
    return x + mod.gate * output

class LastLayer:
  def __init__(self, hidden_size:int, patch_size:int, out_channels:int):
    self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
    self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
    self.adaLN_modulation:List[Callable[[Tensor], Tensor]] = [Tensor.silu, nn.Linear(hidden_size, 2 * hidden_size, bias=True)]

  def __call__(self, x:Tensor, vec:Tensor) -> Tensor:
    shift, scale = vec.sequential(self.adaLN_modulation).chunk(2, dim=1)
    x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
    return self.linear(x)

def timestep_embedding(t:Tensor, dim:int, max_period:int=10000, time_factor:float=1000.0) -> Tensor:
  """
  Create sinusoidal timestep embeddings.
  :param t: a 1-D Tensor of N indices, one per batch element. These may be fractional.
  :param dim: the dimension of the output.
  :param max_period: controls the minimum frequency of the embeddings.
  :param time_factor: scale applied to t before embedding.
  :return: an (N, D) Tensor of positional embeddings.
  """
  t = time_factor * t
  half = dim // 2
  freqs = Tensor.exp(-math.log(max_period) * Tensor.arange(0, stop=half, dtype=dtypes.float32) / half).to(t.device)

  args = t[:, None].float() * freqs[None]
  embedding = Tensor.cat(Tensor.cos(args), Tensor.sin(args), dim=-1)
  if dim % 2: embedding = Tensor.cat(*[embedding, Tensor.zeros_like(embedding[:, :1])], dim=-1)
  if Tensor.is_floating_point(t): embedding = embedding.cast(t.dtype)
  return embedding
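
# Hedged sketch (not used at runtime): two fractional timesteps become two
# 256-dim sinusoidal vectors, the input that time_in and guidance_in expect.
def _timestep_embedding_demo() -> Tensor:
  t = Tensor([0.25, 1.0])            # one timestep per batch element
  return timestep_embedding(t, 256)  # (2, 256): [cos(t*f) | sin(t*f)] over 128 frequencies f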

# https://github.com/black-forest-labs/flux/blob/main/src/flux/model.py
class Flux:
  """
  Transformer model for flow matching on sequences.
  """

  def __init__(
    self,
    guidance_embed:bool,
    in_channels:int = 64,
    vec_in_dim:int = 768,
    context_in_dim:int = 4096,
    hidden_size:int = 3072,
    mlp_ratio:float = 4.0,
    num_heads:int = 24,
    depth:int = 19,
    depth_single_blocks:int = 38,
    axes_dim:Optional[List[int]] = None,
    theta:int = 10_000,
    qkv_bias:bool = True,
  ):

    axes_dim = axes_dim or [16, 56, 56]
    self.guidance_embed = guidance_embed
    self.in_channels = in_channels
    self.out_channels = self.in_channels
    if hidden_size % num_heads != 0:
      raise ValueError(f"Hidden size {hidden_size} must be divisible by num_heads {num_heads}")
    pe_dim = hidden_size // num_heads
    if sum(axes_dim) != pe_dim:
      raise ValueError(f"Got {axes_dim} but expected positional dim {pe_dim}")
    self.hidden_size = hidden_size
    self.num_heads = num_heads
    self.pe_embedder = EmbedND(dim=pe_dim, theta=theta, axes_dim=axes_dim)
    self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
    self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
    self.vector_in = MLPEmbedder(vec_in_dim, self.hidden_size)
    self.guidance_in:Callable[[Tensor], Tensor] = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if guidance_embed else tensor_identity
    self.txt_in = nn.Linear(context_in_dim, self.hidden_size)

    self.double_blocks = [DoubleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias) for _ in range(depth)]
    self.single_blocks = [SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=mlp_ratio) for _ in range(depth_single_blocks)]
    self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)

  def __call__(self, img:Tensor, img_ids:Tensor, txt:Tensor, txt_ids:Tensor, timesteps:Tensor, y:Tensor, guidance:Optional[Tensor] = None) -> Tensor:
    if img.ndim != 3 or txt.ndim != 3:
      raise ValueError("Input img and txt tensors must have 3 dimensions.")
    # running on sequences img
    img = self.img_in(img)
    vec = self.time_in(timestep_embedding(timesteps, 256))
    if self.guidance_embed:
      if guidance is None:
        raise ValueError("Didn't get guidance strength for guidance distilled model.")
      vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
    vec = vec + self.vector_in(y)
    txt = self.txt_in(txt)
    ids = Tensor.cat(txt_ids, img_ids, dim=1)
    pe = self.pe_embedder(ids)
    for double_block in self.double_blocks:
      img, txt = double_block(img=img, txt=txt, vec=vec, pe=pe)

    img = Tensor.cat(txt, img, dim=1)
    for single_block in self.single_blocks:
      img = single_block(img, vec=vec, pe=pe)

    img = img[:, txt.shape[1]:, ...]

    return self.final_layer(img, vec)  # (N, T, patch_size ** 2 * out_channels)
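
# Flow summary: text and image tokens pass through 19 double-stream blocks with
# separate weights per modality, are then concatenated into one sequence for 38
# single-stream blocks, and the image tokens are finally projected back to
# patch space by LastLayer.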

# https://github.com/black-forest-labs/flux/blob/main/src/flux/util.py
def load_flow_model(name:str):
  # Loading Flux
  print("Init model")
  model = Flux(guidance_embed=(name != "flux-schnell"))
  state_dict = {k.replace("scale", "weight"): v for k, v in safe_load(fetch(urls[name])).items()}
  load_state_dict(model, state_dict)
  return model

def load_T5(max_length:int=512):
  # max length 64, 128, 256 and 512 should work (if your sequence is short enough)
  print("Init T5")
  T5 = T5Embedder(max_length, fetch(urls["T5_tokenizer"]))
  pt_1 = fetch(urls["T5_1_of_2"])
  pt_2 = fetch(urls["T5_2_of_2"])
  load_state_dict(T5.encoder, safe_load(pt_1) | safe_load(pt_2), strict=False)
  return T5

def load_clip():
  print("Init Clip")
  clip = ClipEmbedder()
  load_state_dict(clip.transformer, safe_load(fetch(urls["clip"])))
  return clip

def load_ae() -> AutoEncoder:
  # Loading the autoencoder
  print("Init AE")
  ae = AutoEncoder(0.3611, 0.1159)
  load_state_dict(ae, safe_load(fetch(urls["ae"])))
  return ae

# https://github.com/black-forest-labs/flux/blob/main/src/flux/sampling.py
def prepare(T5:T5Embedder, clip:ClipEmbedder, img:Tensor, prompt:Union[str, List[str]]) -> Dict[str, Tensor]:
  bs, _, h, w = img.shape
  if bs == 1 and not isinstance(prompt, str):
    bs = len(prompt)

  img = img.rearrange("b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
  if img.shape[0] == 1 and bs > 1:
    img = img.expand((bs, *img.shape[1:]))

  img_ids = Tensor.zeros(h // 2, w // 2, 3).contiguous()
  img_ids[..., 1] = img_ids[..., 1] + Tensor.arange(h // 2)[:, None]
  img_ids[..., 2] = img_ids[..., 2] + Tensor.arange(w // 2)[None, :]
  img_ids = img_ids.rearrange("h w c -> 1 (h w) c")
  img_ids = img_ids.expand((bs, *img_ids.shape[1:]))

  if isinstance(prompt, str):
    prompt = [prompt]
  txt = T5(prompt).realize()
  if txt.shape[0] == 1 and bs > 1:
    txt = txt.expand((bs, *txt.shape[1:]))
  txt_ids = Tensor.zeros(bs, txt.shape[1], 3)

  vec = clip(prompt).realize()
  if vec.shape[0] == 1 and bs > 1:
    vec = vec.expand((bs, *vec.shape[1:]))

  return {"img": img, "img_ids": img_ids.to(img.device), "txt": txt.to(img.device), "txt_ids": txt_ids.to(img.device), "vec": vec.to(img.device)}

def get_schedule(num_steps:int, image_seq_len:int, base_shift:float=0.5, max_shift:float=1.15, shift:bool=True) -> List[float]:
  # extra step for zero
  step_size = -1.0 / num_steps
  timesteps = Tensor.arange(1, 0 + step_size, step_size)

  # shifting the schedule to favor high timesteps for higher signal images
  if shift:
    # estimate mu by interpolating linearly between (256, base_shift) and (4096, max_shift)
    mu = base_shift + (max_shift - base_shift) * (image_seq_len - 256) / (4096 - 256)
    timesteps = math.exp(mu) / (math.exp(mu) + (1 / timesteps - 1))
  return timesteps.tolist()
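
# Worked example (assumed defaults): a 512x512 image packs to a sequence of
# (512/16)**2 = 1024 tokens, so mu = 0.5 + (1.15 - 0.5) * (1024 - 256) / (4096 - 256) ~= 0.63,
# and each t is remapped to e^mu / (e^mu + 1/t - 1), spending more steps near t=1 (high noise).
def _schedule_demo() -> List[float]:
  return get_schedule(4, 1024, shift=False)  # schnell: [1.0, 0.75, 0.5, 0.25, 0.0]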

@TinyJit
def run(model, *args): return model(*args).realize()

def denoise(model, img:Tensor, img_ids:Tensor, txt:Tensor, txt_ids:Tensor, vec:Tensor, timesteps:List[float], guidance:float=4.0) -> Tensor:
  # this is ignored for schnell
  guidance_vec = Tensor((guidance,), device=img.device, dtype=img.dtype).expand((img.shape[0],))
  for t_curr, t_prev in tqdm(list(zip(timesteps[:-1], timesteps[1:])), "Denoising"):
    t_vec = Tensor((t_curr,), device=img.device, dtype=img.dtype).expand((img.shape[0],))
    pred = run(model, img, img_ids, txt, txt_ids, t_vec, vec, guidance_vec)
    img = img + (t_prev - t_curr) * pred

  return img
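
# The loop above is plain Euler integration of the learned velocity field:
# x(t_prev) = x(t_curr) + (t_prev - t_curr) * v(x, t_curr). With the schnell
# schedule [1.0, 0.75, 0.5, 0.25, 0.0], each step moves the latent a quarter of
# the way along the predicted flow from noise (t=1) toward the image (t=0).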

def unpack(x:Tensor, height:int, width:int) -> Tensor:
  return x.rearrange("b (h w) (c ph pw) -> b c (h ph) (w pw)", h=math.ceil(height / 16), w=math.ceil(width / 16), ph=2, pw=2)
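
# Hedged round-trip sketch (not used at runtime): prepare() flattens 2x2 latent
# patches into a token sequence, and unpack() inverts that after denoising.
def _pack_unpack_demo() -> Tensor:
  z = Tensor.randn(1, 16, 64, 64)                                          # latent for a 512x512 image
  seq = z.rearrange("b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)  # (1, 1024, 64)
  return unpack(seq, 512, 512)                                             # back to (1, 16, 64, 64)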

# https://github.com/black-forest-labs/flux/blob/main/src/flux/cli.py
if __name__ == "__main__":
  default_prompt = "bananas and a can of coke"
  parser = argparse.ArgumentParser(description="Run Flux.1", formatter_class=argparse.ArgumentDefaultsHelpFormatter)

  parser.add_argument("--name", type=str, default="flux-schnell", help="Name of the model to load")
  parser.add_argument("--width", type=int, default=512, help="width of the sample in pixels (should be a multiple of 16)")
  parser.add_argument("--height", type=int, default=512, help="height of the sample in pixels (should be a multiple of 16)")
  parser.add_argument("--seed", type=int, default=None, help="Set a seed for sampling")
  parser.add_argument("--prompt", type=str, default=default_prompt, help="Prompt used for sampling")
  parser.add_argument("--out", type=str, default=Path(tempfile.gettempdir()) / "rendered.png", help="Output filename")
  parser.add_argument("--num_steps", type=int, default=None, help="number of sampling steps (default 4 for schnell, 50 for guidance distilled)")  # noqa:E501
  parser.add_argument("--guidance", type=float, default=3.5, help="guidance value used for guidance distillation")
  parser.add_argument("--output_dir", type=str, default="output", help="output directory")
  args = parser.parse_args()

  if args.name not in ["flux-schnell", "flux-dev"]:
    raise ValueError(f"Got unknown model name: {args.name}, choose from flux-schnell and flux-dev")

  if args.num_steps is None:
    args.num_steps = 4 if args.name == "flux-schnell" else 50

  # allow for packing and conversion to latent space
  height = 16 * (args.height // 16)
  width = 16 * (args.width // 16)

  if args.seed is None: args.seed = Tensor._seed
  else: Tensor.manual_seed(args.seed)

  print(f"Generating with seed {args.seed}:\n{args.prompt}")
  t0 = time.perf_counter()

  # prepare input noise
  x = Tensor.randn(1, 16, 2 * math.ceil(height / 16), 2 * math.ceil(width / 16), dtype="bfloat16")

  # load text embedders
  T5 = load_T5(max_length=256 if args.name == "flux-schnell" else 512)
  clip = load_clip()

  # embed text to get inputs for model
  inp = prepare(T5, clip, x, prompt=args.prompt)
  timesteps = get_schedule(args.num_steps, inp["img"].shape[1], shift=(args.name != "flux-schnell"))

  # done with text embedders
  del T5, clip

  # load model
  model = load_flow_model(args.name)

  # denoise initial noise
  x = denoise(model, **inp, timesteps=timesteps, guidance=args.guidance)

  # done with model
  del model, run

  # load autoencoder
  ae = load_ae()

  # decode latents to pixel space
  x = unpack(x.float(), height, width)
  x = ae.decode(x).realize()

  t1 = time.perf_counter()
  print(f"Done in {t1 - t0:.1f}s. Saving {args.out}")

  # bring into PIL format and save
  x = x.clamp(-1, 1)
  x = x[0].rearrange("c h w -> h w c")
  x = (127.5 * (x + 1.0)).cast("uint8")

  img = Image.fromarray(x.numpy())

  img.save(args.out)

  # validation!
  if args.prompt == default_prompt and args.name == "flux-schnell" and args.seed == 0 and args.width == args.height == 512:
    ref_image = Tensor(np.array(Image.open("examples/flux1_seed0.png")))
    distance = (((x.cast(dtypes.float) - ref_image.cast(dtypes.float)) / ref_image.max())**2).mean().item()
    assert distance < 4e-3, colored(f"validation failed with {distance=}", "red")
    print(colored(f"output validated with {distance=}", "green"))