# tinygrad/examples/sdxl.py
# This file incorporates code from the following:
# Github Name | License | Link
# Stability-AI/generative-models | MIT | https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/LICENSE-CODE
# mlfoundations/open_clip | MIT | https://github.com/mlfoundations/open_clip/blob/58e4e39aaabc6040839b0d2a7e8bf20979e4558a/LICENSE
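#
# Run (from the tinygrad repo root), for example:
#   python3 examples/sdxl.py --prompt "a horse sized cat eating a bagel" --steps 10 --seed 0
# Weights are fetched automatically unless --weights points at a local safetensors file.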
from tinygrad import Tensor, TinyJit, dtypes
from tinygrad.nn import Conv2d, GroupNorm
from tinygrad.nn.state import safe_load, load_state_dict
from tinygrad.helpers import fetch, trange, colored, Timing, GlobalCounters
from extra.models.clip import FrozenClosedClipEmbedder, FrozenOpenClipEmbedder
from extra.models.unet import UNetModel, Upsample, Downsample, timestep_embedding
from examples.stable_diffusion import ResnetBlock, Mid
import numpy as np
from typing import Dict, List, Callable, Optional, Any, Set, Tuple
import argparse, tempfile
from abc import ABC, abstractmethod
from pathlib import Path
from PIL import Image
# configs:
# https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/configs/inference/sd_xl_base.yaml
# https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/configs/inference/sd_xl_refiner.yaml
configs: Dict = {
  "SDXL_Base": {
    "model": {"adm_in_ch": 2816, "in_ch": 4, "out_ch": 4, "model_ch": 320, "attention_resolutions": [4, 2], "num_res_blocks": 2, "channel_mult": [1, 2, 4], "d_head": 64, "transformer_depth": [1, 2, 10], "ctx_dim": 2048, "use_linear": True},
    "conditioner": {"concat_embedders": ["original_size_as_tuple", "crop_coords_top_left", "target_size_as_tuple"]},
    "first_stage_model": {"ch": 128, "in_ch": 3, "out_ch": 3, "z_ch": 4, "ch_mult": [1, 2, 4, 4], "num_res_blocks": 2, "resolution": 256},
    "denoiser": {"num_idx": 1000},
  },
  "SDXL_Refiner": {
    "model": {"adm_in_ch": 2560, "in_ch": 4, "out_ch": 4, "model_ch": 384, "attention_resolutions": [4, 2], "num_res_blocks": 2, "channel_mult": [1, 2, 4, 4], "d_head": 64, "transformer_depth": [4, 4, 4, 4], "ctx_dim": [1280, 1280, 1280, 1280], "use_linear": True},
    "conditioner": {"concat_embedders": ["original_size_as_tuple", "crop_coords_top_left", "aesthetic_score"]},
    "first_stage_model": {"ch": 128, "in_ch": 3, "out_ch": 3, "z_ch": 4, "ch_mult": [1, 2, 4, 4], "num_res_blocks": 2, "resolution": 256},
    "denoiser": {"num_idx": 1000},
  },
}
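# NOTE: adm_in_ch is the width of the "vector" conditioning: 1280 (pooled OpenCLIP text embedding)
# plus 256 per conditioning scalar, so 2816 = 1280 + 6*256 for the base model (three pairs of
# values) and 2560 = 1280 + 5*256 for the refiner (two pairs plus the single aesthetic score).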

def tensor_identity(x:Tensor) -> Tensor:
  return x

class DiffusionModel:
  def __init__(self, *args, **kwargs):
    self.diffusion_model = UNetModel(*args, **kwargs)

class Embedder(ABC):
  input_key: str

  @abstractmethod
  def __call__(self, x:Tensor) -> Tensor:
    pass

# https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/encoders/modules.py#L913
class ConcatTimestepEmbedderND(Embedder):
  def __init__(self, outdim:int, input_key:str):
    self.outdim = outdim
    self.input_key = input_key

  def __call__(self, x:Tensor):
    assert len(x.shape) == 2
    emb = timestep_embedding(x.flatten(), self.outdim)
    emb = emb.reshape((x.shape[0],-1))
    return emb
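# NOTE: ConcatTimestepEmbedderND reuses the UNet's sinusoidal timestep embedding to encode image
# metadata: each scalar in a (N, k) input becomes a 256-dim vector, flattened back to (N, k*256).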

# https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/encoders/modules.py#L71
class Conditioner:
  OUTPUT_DIM2KEYS = {2: "vector", 3: "crossattn", 4: "concat", 5: "concat"}
  KEY2CATDIM = {"vector": 1, "crossattn": 2, "concat": 1}
  embedders: List[Embedder]

  def __init__(self, concat_embedders:List[str]):
    self.embedders = [
      FrozenClosedClipEmbedder(ret_layer_idx=11),
      FrozenOpenClipEmbedder(dims=1280, n_heads=20, layers=32, return_pooled=True),
    ]
    for input_key in concat_embedders:
      self.embedders.append(ConcatTimestepEmbedderND(256, input_key))

  def get_keys(self) -> Set[str]:
    return set(e.input_key for e in self.embedders)

  def __call__(self, batch:Dict, force_zero_embeddings:List=[]) -> Dict[str,Tensor]:
    output: Dict[str,Tensor] = {}
    for embedder in self.embedders:
      emb_out = embedder(batch[embedder.input_key])
      if isinstance(emb_out, Tensor):
        emb_out = [emb_out]
      else:
        assert isinstance(emb_out, (list, tuple))
      for emb in emb_out:
        if embedder.input_key in force_zero_embeddings:
          emb = Tensor.zeros_like(emb)
        out_key = self.OUTPUT_DIM2KEYS[len(emb.shape)]
        if out_key in output:
          output[out_key] = Tensor.cat(output[out_key], emb, dim=self.KEY2CATDIM[out_key])
        else:
          output[out_key] = emb
    return output
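# NOTE: embedder outputs are routed by tensor rank: 2-dim tensors (the pooled OpenCLIP embedding
# and the metadata embeddings above) are concatenated into "vector", while the 3-dim token
# sequences from both CLIP text encoders are concatenated on the feature dim into "crossattn".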

class FirstStage:
  """
  Namespace for First Stage Model components
  """

  # https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/diffusionmodules/model.py#L487
  class Encoder:
    def __init__(self, ch:int, in_ch:int, out_ch:int, z_ch:int, ch_mult:List[int], num_res_blocks:int, resolution:int):
      self.conv_in = Conv2d(in_ch, ch, kernel_size=3, stride=1, padding=1)
      in_ch_mult = (1,) + tuple(ch_mult)

      class BlockEntry:
        def __init__(self, block:List[ResnetBlock], downsample):
          self.block = block
          self.downsample = downsample
      self.down: List[BlockEntry] = []
      for i_level in range(len(ch_mult)):
        block = []
        block_in = ch * in_ch_mult[i_level]
        block_out = ch * ch_mult[i_level]
        for _ in range(num_res_blocks):
          block.append(ResnetBlock(block_in, block_out))
          block_in = block_out
        downsample = tensor_identity if (i_level == len(ch_mult)-1) else Downsample(block_in)
        self.down.append(BlockEntry(block, downsample))

      self.mid = Mid(block_in)
      self.norm_out = GroupNorm(32, block_in)
      self.conv_out = Conv2d(block_in, 2*z_ch, kernel_size=3, stride=1, padding=1)

    def __call__(self, x:Tensor) -> Tensor:
      h = self.conv_in(x)
      for down in self.down:
        for block in down.block:
          h = block(h)
        h = down.downsample(h)
      h = h.sequential([self.mid.block_1, self.mid.attn_1, self.mid.block_2])
      h = h.sequential([self.norm_out, Tensor.swish, self.conv_out])
      return h
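  # NOTE: conv_out emits 2*z_ch channels, the moments (mean and log-variance) of the latent
  # Gaussian in the reference autoencoder; the encode path is not exercised by this script.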

  # https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/diffusionmodules/model.py#L604
  class Decoder:
    def __init__(self, ch:int, in_ch:int, out_ch:int, z_ch:int, ch_mult:List[int], num_res_blocks:int, resolution:int):
      block_in = ch * ch_mult[-1]
      curr_res = resolution // 2 ** (len(ch_mult) - 1)
      self.z_shape = (1, z_ch, curr_res, curr_res)

      self.conv_in = Conv2d(z_ch, block_in, kernel_size=3, stride=1, padding=1)
      self.mid = Mid(block_in)

      class BlockEntry:
        def __init__(self, block:List[ResnetBlock], upsample:Callable[[Any],Any]):
          self.block = block
          self.upsample = upsample
      self.up: List[BlockEntry] = []
      for i_level in reversed(range(len(ch_mult))):
        block = []
        block_out = ch * ch_mult[i_level]
        for _ in range(num_res_blocks + 1):
          block.append(ResnetBlock(block_in, block_out))
          block_in = block_out
        upsample = tensor_identity if i_level == 0 else Upsample(block_in)
        self.up.insert(0, BlockEntry(block, upsample))  # type: ignore

      self.norm_out = GroupNorm(32, block_in)
      self.conv_out = Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)

    def __call__(self, z:Tensor) -> Tensor:
      h = z.sequential([self.conv_in, self.mid.block_1, self.mid.attn_1, self.mid.block_2])
      for up in self.up[::-1]:
        for block in up.block:
          h = block(h)
        h = up.upsample(h)
      h = h.sequential([self.norm_out, Tensor.swish, self.conv_out])
      return h

# https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/models/autoencoder.py#L102
# https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/models/autoencoder.py#L437
# https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/models/autoencoder.py#L508
class FirstStageModel:
  def __init__(self, embed_dim:int=4, **kwargs):
    self.encoder = FirstStage.Encoder(**kwargs)
    self.decoder = FirstStage.Decoder(**kwargs)
    self.quant_conv = Conv2d(2*kwargs["z_ch"], 2*embed_dim, 1)
    self.post_quant_conv = Conv2d(embed_dim, kwargs["z_ch"], 1)

  def decode(self, z:Tensor) -> Tensor:
    return z.sequential([self.post_quant_conv, self.decoder])
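# NOTE: only decode() is called during sampling; the encoder and quant_conv are presumably kept
# to mirror the reference autoencoder layout and the parameter names in the checkpoint.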

# https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/diffusionmodules/discretizer.py#L42
class LegacyDDPMDiscretization:
  def __init__(self, linear_start:float=0.00085, linear_end:float=0.0120, num_timesteps:int=1000):
    self.num_timesteps = num_timesteps
    betas = np.linspace(linear_start**0.5, linear_end**0.5, num_timesteps, dtype=np.float32) ** 2
    alphas = 1.0 - betas
    self.alphas_cumprod = np.cumprod(alphas, axis=0)

  def __call__(self, n:int, flip:bool=False) -> Tensor:
    assert n <= self.num_timesteps, f"n ({n}) cannot exceed num_timesteps ({self.num_timesteps})"
    if n < self.num_timesteps:
      timesteps = np.linspace(self.num_timesteps - 1, 0, n, endpoint=False).astype(int)[::-1]
      alphas_cumprod = self.alphas_cumprod[timesteps]
    else:
      alphas_cumprod = self.alphas_cumprod
    sigmas = Tensor((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
    sigmas = Tensor.cat(Tensor.zeros((1,)), sigmas)
    return sigmas if flip else sigmas.flip(axis=0)  # sigmas is "pre-flipped", need to do opposite of flag
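# NOTE: this is the DDPM-to-EDM noise conversion sigma_i = sqrt((1 - abar_i) / abar_i), where
# abar_i is the cumulative alpha product; the prepended zero marks the noise-free endpoint.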

def append_dims(x:Tensor, t:Tensor) -> Tensor:
  dims_to_append = len(t.shape) - len(x.shape)
  assert dims_to_append >= 0
  return x.reshape(x.shape + (1,)*dims_to_append)

@TinyJit
def run(model, x, tms, ctx, y, c_out, add):
  return (model(x, tms, ctx, y)*c_out + add).realize()
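# NOTE: TinyJit captures the UNet kernels so later sampler steps replay them directly; all
# arguments to run() must therefore keep the same shapes and dtypes across calls.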

# https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/models/diffusion.py#L19
class SDXL:
  def __init__(self, config:Dict):
    self.conditioner = Conditioner(**config["conditioner"])
    self.first_stage_model = FirstStageModel(**config["first_stage_model"])
    self.model = DiffusionModel(**config["model"])
    self.discretization = LegacyDDPMDiscretization()
    self.sigmas = self.discretization(config["denoiser"]["num_idx"], flip=True)

  # https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/inference/helpers.py#L173
  def create_conditioning(self, pos_prompt:str, img_width:int, img_height:int, aesthetic_score:float=5.0) -> Tuple[Dict,Dict]:
    N = 1  # this script samples a single image at a time (matches the module-level N below)
    batch_c : Dict = {
      "txt": pos_prompt,
      "original_size_as_tuple": Tensor([img_height,img_width]).repeat(N,1),
      "crop_coords_top_left": Tensor([0,0]).repeat(N,1),
      "target_size_as_tuple": Tensor([img_height,img_width]).repeat(N,1),
      "aesthetic_score": Tensor([aesthetic_score]).repeat(N,1),
    }
    batch_uc: Dict = {
      "txt": "",
      "original_size_as_tuple": Tensor([img_height,img_width]).repeat(N,1),
      "crop_coords_top_left": Tensor([0,0]).repeat(N,1),
      "target_size_as_tuple": Tensor([img_height,img_width]).repeat(N,1),
      "aesthetic_score": Tensor([aesthetic_score]).repeat(N,1),
    }
    return self.conditioner(batch_c), self.conditioner(batch_uc, force_zero_embeddings=["txt"])

  # https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/diffusionmodules/denoiser.py#L42
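  # The UNet is fed x*c_in and its output is folded back as model(...)*c_out + x, with
  # c_in = 1/sqrt(sigma^2 + 1) and c_out = -sigma (eps-prediction scaling); c_noise is the
  # index of the discrete training timestep whose sigma is nearest the requested one.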
  def denoise(self, x:Tensor, sigma:Tensor, cond:Dict) -> Tensor:

    def sigma_to_idx(s:Tensor) -> Tensor:
      dists = s - self.sigmas.unsqueeze(1)
      return dists.abs().argmin(axis=0).view(*s.shape)

    sigma = self.sigmas[sigma_to_idx(sigma)]
    sigma_shape = sigma.shape
    sigma = append_dims(sigma, x)

    c_out = -sigma
    c_in = 1 / (sigma**2 + 1.0) ** 0.5
    c_noise = sigma_to_idx(sigma.reshape(sigma_shape))

    def prep(*tensors:Tensor):
      return tuple(t.cast(dtypes.float16).realize() for t in tensors)

    return run(self.model.diffusion_model, *prep(x*c_in, c_noise, cond["crossattn"], cond["vector"], c_out, x))
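
  # 0.13025 is SDXL's first-stage scale_factor; decoding undoes it before the VAE decoder runs.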
  def decode(self, x:Tensor) -> Tensor:
    return self.first_stage_model.decode(1.0 / 0.13025 * x)

class VanillaCFG:
  def __init__(self, scale:float):
    self.scale = scale

  def prepare_inputs(self, x:Tensor, s:Tensor, c:Dict, uc:Dict) -> Tuple[Tensor,Tensor,Dict]:
    c_out = {}
    for k in c:
      assert k in ["vector", "crossattn", "concat"]
      c_out[k] = Tensor.cat(uc[k], c[k], dim=0)
    return Tensor.cat(x, x), Tensor.cat(s, s), c_out

  def __call__(self, x:Tensor, sigma:Tensor) -> Tensor:
    x_u, x_c = x.chunk(2)
    x_pred = x_u + self.scale*(x_c - x_u)
    return x_pred
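# NOTE: classifier-free guidance. prepare_inputs doubles the batch (unconditional half first), so
# one UNet call scores both; __call__ then extrapolates toward the prompt: x_u + scale*(x_c - x_u).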

# https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/diffusionmodules/sampling.py#L21
# https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/diffusionmodules/sampling.py#L287
class DPMPP2MSampler:
  def __init__(self, cfg_scale:float):
    self.discretization = LegacyDDPMDiscretization()
    self.guider = VanillaCFG(cfg_scale)

  def sampler_step(self, old_denoised:Optional[Tensor], prev_sigma:Optional[Tensor], sigma:Tensor, next_sigma:Tensor, denoiser, x:Tensor, c:Dict, uc:Dict) -> Tuple[Tensor,Tensor]:
    denoised = denoiser(*self.guider.prepare_inputs(x, sigma, c, uc))
    denoised = self.guider(denoised, sigma)

    t, t_next = sigma.log().neg(), next_sigma.log().neg()
    h = t_next - t
    r = None if prev_sigma is None else (t - prev_sigma.log().neg()) / h

    mults = [t_next.neg().exp()/t.neg().exp(), (-h).exp().sub(1)]
    if r is not None:
      mults.extend([1 + 1/(2*r), 1/(2*r)])
    mults = [append_dims(m, x) for m in mults]

    x_standard = mults[0]*x - mults[1]*denoised
    if (old_denoised is None) or (next_sigma.sum().numpy().item() < 1e-14):
      return x_standard, denoised

    denoised_d = mults[2]*denoised - mults[3]*old_denoised
    x_advanced = mults[0]*x - mults[1]*denoised_d
    x = Tensor.where(append_dims(next_sigma, x) > 0.0, x_advanced, x_standard)
    return x, denoised
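  # NOTE: in sampler_step, t = -log(sigma) and x_standard is the first-order DPM-Solver++ update,
  # used for the first and last steps; x_advanced adds the second-order multistep correction that
  # reuses the previous step's denoised output.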

  def __call__(self, denoiser, x:Tensor, c:Dict, uc:Dict, num_steps:int, timing=False) -> Tensor:
    sigmas = self.discretization(num_steps)
    x *= Tensor.sqrt(1.0 + sigmas[0] ** 2.0)
    num_sigmas = len(sigmas)

    old_denoised = None
    for i in trange(num_sigmas - 1):
      with Timing("step in ", enabled=timing, on_exit=lambda _: f", using {GlobalCounters.mem_used/1e9:.2f} GB"):
        x, old_denoised = self.sampler_step(
          old_denoised=old_denoised,
          prev_sigma=(None if i==0 else sigmas[i-1].reshape(x.shape[0])),
          sigma=sigmas[i].reshape(x.shape[0]),
          next_sigma=sigmas[i+1].reshape(x.shape[0]),
          denoiser=denoiser,
          x=x,
          c=c,
          uc=uc,
        )
        x.realize()
        old_denoised.realize()
    return x

if __name__ == "__main__":
  default_prompt = "a horse sized cat eating a bagel"

  parser = argparse.ArgumentParser(description="Run SDXL", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
  parser.add_argument('--steps', type=int, default=10, help="The number of diffusion steps")
  parser.add_argument('--prompt', type=str, default=default_prompt, help="Description of image to generate")
  parser.add_argument('--out', type=str, default=Path(tempfile.gettempdir()) / "rendered.png", help="Output filename")
  parser.add_argument('--seed', type=int, help="Set the random latent seed")
  parser.add_argument('--guidance', type=float, default=6.0, help="Prompt strength")
  parser.add_argument('--width', type=int, default=1024, help="The output image width")
  parser.add_argument('--height', type=int, default=1024, help="The output image height")
  parser.add_argument('--weights', type=str, help="Custom path to weights")
  parser.add_argument('--timing', action='store_true', help="Print timing per step")
  parser.add_argument('--noshow', action='store_true', help="Don't show the image")
  args = parser.parse_args()

  Tensor.no_grad = True
  if args.seed is not None:
    Tensor.manual_seed(args.seed)

  model = SDXL(configs["SDXL_Base"])

  default_weight_url = 'https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/resolve/main/sd_xl_base_1.0.safetensors'
  weights = args.weights if args.weights else fetch(default_weight_url, 'sd_xl_base_1.0.safetensors')
  load_state_dict(model, safe_load(weights), strict=False)

  N = 1
  C = 4
  F = 8

  assert args.width % F == 0, f"img_width must be multiple of {F}, got {args.width}"
  assert args.height % F == 0, f"img_height must be multiple of {F}, got {args.height}"

  c, uc = model.create_conditioning(args.prompt, args.width, args.height)
  del model.conditioner
  for v in c.values(): v.realize()
  for v in uc.values(): v.realize()
  print("created batch")
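
  # The sampler works in latent space: C=4 channels at 1/F (= 1/8) of the output resolution, so a
  # 1024x1024 image is generated as a (1, 4, 128, 128) latent and decoded by the first stage model.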
  # https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/inference/helpers.py#L101
  shape = (N, C, args.height // F, args.width // F)
  randn = Tensor.randn(shape)

  sampler = DPMPP2MSampler(args.guidance)
  z = sampler(model.denoise, randn, c, uc, args.steps, timing=args.timing)
  print("created samples")
  x = model.decode(z).realize()
  print("decoded samples")

  # make image correct size and scale
  x = (x + 1.0) / 2.0
  x = x.reshape(3,args.height,args.width).permute(1,2,0).clip(0,1).mul(255).cast(dtypes.uint8)
  print(x.shape)

  im = Image.fromarray(x.numpy())
  print(f"saving {args.out}")
  im.save(args.out)

  if not args.noshow:
    im.show()

  # validation!
  if args.prompt == default_prompt and args.steps == 10 and args.seed == 0 and args.guidance == 6.0 and args.width == args.height == 1024 \
      and not args.weights:
    ref_image = Tensor(np.array(Image.open(Path(__file__).parent / "sdxl_seed0.png")))
    distance = (((x - ref_image).cast(dtypes.float) / ref_image.max())**2).mean().item()
    assert distance < 2e-3, colored(f"validation failed with {distance=}", "red")
    print(colored(f"output validated with {distance=}", "green"))