From 0c3a35e5c2a05e17d0b429aa0d92852e9db36e97 Mon Sep 17 00:00:00 2001
From: Tobias Fischer
Date: Wed, 3 Jul 2024 22:47:10 -0400
Subject: [PATCH] Stable Diffusion v2 Inference (#5283)

* model implementation

* clip fix, more qol options
---
 examples/sdv2.py     | 147 +++++++++++++++++++++++++++++++++++++++++++
 examples/sdxl.py     |   2 +-
 extra/models/clip.py |   6 +-
 3 files changed, 153 insertions(+), 2 deletions(-)
 create mode 100644 examples/sdv2.py

diff --git a/examples/sdv2.py b/examples/sdv2.py
new file mode 100644
index 00000000..a3f4d30b
--- /dev/null
+++ b/examples/sdv2.py
@@ -0,0 +1,147 @@
+from tinygrad import Tensor, dtypes, TinyJit
+from tinygrad.helpers import fetch
+from tinygrad.nn.state import safe_load, load_state_dict, get_state_dict
+from examples.stable_diffusion import AutoencoderKL, get_alphas_cumprod
+from examples.sdxl import DPMPP2MSampler, append_dims, LegacyDDPMDiscretization
+from extra.models.unet import UNetModel
+from extra.models.clip import FrozenOpenClipEmbedder
+
+from typing import Dict
+import argparse, tempfile, os
+from pathlib import Path
+from PIL import Image
+
+class DiffusionModel:
+  def __init__(self, model:UNetModel):
+    self.diffusion_model = model
+
+@TinyJit
+def run(model, x, tms, ctx, c_out, add):
+  return (model(x, tms, ctx)*c_out + add).realize()
+
+# https://github.com/Stability-AI/stablediffusion/blob/cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf/ldm/models/diffusion/ddpm.py#L521
+class StableDiffusionV2:
+  def __init__(self, unet_config:Dict, cond_stage_config:Dict, parameterization:str="v"):
+    self.model = DiffusionModel(UNetModel(**unet_config))
+    self.first_stage_model = AutoencoderKL()
+    self.cond_stage_model = FrozenOpenClipEmbedder(**cond_stage_config)
+    self.alphas_cumprod = get_alphas_cumprod()
+    self.parameterization = parameterization
+
+    self.discretization = LegacyDDPMDiscretization()
+    self.sigmas = self.discretization(1000, flip=True)
+
+  def denoise(self, x:Tensor, sigma:Tensor, cond:Dict) -> Tensor:
+
+    def sigma_to_idx(s:Tensor) -> Tensor:
+      dists = s - self.sigmas.unsqueeze(1)
+      return dists.abs().argmin(axis=0).view(*s.shape)
+
+    sigma = self.sigmas[sigma_to_idx(sigma)]
+    sigma_shape = sigma.shape
+    sigma = append_dims(sigma, x)
+
+    c_skip = 1.0 / (sigma**2 + 1.0)
+    c_out = -sigma / (sigma**2 + 1.0) ** 0.5
+    c_in = 1.0 / (sigma**2 + 1.0) ** 0.5
+    c_noise = sigma_to_idx(sigma.reshape(sigma_shape))
+
+    def prep(*tensors:Tensor):
+      return tuple(t.cast(dtypes.float16).realize() for t in tensors)
+
+    return run(self.model.diffusion_model, *prep(x*c_in, c_noise, cond["crossattn"], c_out, x*c_skip))
+
+  def decode(self, x:Tensor, height:int, width:int) -> Tensor:
+    x = self.first_stage_model.post_quant_conv(1/0.18215 * x)
+    x = self.first_stage_model.decoder(x)
+
+    # make image correct size and scale
+    x = (x + 1.0) / 2.0
+    x = x.reshape(3,height,width).permute(1,2,0).clip(0,1).mul(255).cast(dtypes.uint8)
+    return x
+
+params: Dict = {
+  "unet_config": {
+    "adm_in_ch": None,
+    "in_ch": 4,
+    "out_ch": 4,
+    "model_ch": 320,
+    "attention_resolutions": [4, 2, 1],
+    "num_res_blocks": 2,
+    "channel_mult": [1, 2, 4, 4],
+    "d_head": 64,
+    "transformer_depth": [1, 1, 1, 1],
+    "ctx_dim": 1024,
+    "use_linear": True,
+  },
+  "cond_stage_config": {
+    "dims": 1024,
+    "n_heads": 16,
+    "layers": 24,
+    "return_pooled": False,
+    "ln_penultimate": True,
+  }
+}
+
+if __name__ == "__main__":
+  default_prompt = "a horse sized cat eating a bagel"
+  parser = argparse.ArgumentParser(description='Run Stable Diffusion v2.X',
+    formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+  parser.add_argument('--steps', type=int, default=10, help="The number of diffusion steps")
+  parser.add_argument('--prompt', type=str, default=default_prompt, help="Description of image to generate")
+  parser.add_argument('--out', type=str, default=Path(tempfile.gettempdir()) / "rendered.png", help="Output filename")
+  parser.add_argument('--seed', type=int, help="Set the random latent seed")
+  parser.add_argument('--guidance', type=float, default=7.5, help="Prompt strength")
+  parser.add_argument('--width', type=int, default=768, help="The output image width")
+  parser.add_argument('--height', type=int, default=768, help="The output image height")
+  parser.add_argument('--weights-fn', type=str, help="Filename of weights to use")
+  parser.add_argument('--weights-url', type=str, help="Custom URL to download weights from")
+  parser.add_argument('--timing', action='store_true', help="Print timing per step")
+  parser.add_argument('--noshow', action='store_true', help="Don't show the image")
+  parser.add_argument('--fp16', action='store_true', help="Cast the weights to float16")
+  args = parser.parse_args()
+
+  N = 1
+  C = 4
+  F = 8
+  assert args.width % F == 0, f"img_width must be multiple of {F}, got {args.width}"
+  assert args.height % F == 0, f"img_height must be multiple of {F}, got {args.height}"
+
+  Tensor.no_grad = True
+  if args.seed is not None:
+    Tensor.manual_seed(args.seed)
+
+  model = StableDiffusionV2(**params)
+
+  default_weights_url = 'https://huggingface.co/stabilityai/stable-diffusion-2-1/resolve/main/v2-1_768-ema-pruned.safetensors'
+  weights_fn = args.weights_fn
+  if not weights_fn:
+    weights_url = args.weights_url if args.weights_url else default_weights_url
+    weights_fn = fetch(weights_url, os.path.basename(str(weights_url)))
+  load_state_dict(model, safe_load(weights_fn), strict=False)
+
+  if args.fp16:
+    for k,v in get_state_dict(model).items():
+      if k.startswith("model"):
+        v.replace(v.cast(dtypes.float16).realize())
+
+  c = { "crossattn": model.cond_stage_model(args.prompt) }
+  uc = { "crossattn": model.cond_stage_model("") }
+  del model.cond_stage_model
+  print("created conditioning")
+
+  shape = (N, C, args.height // F, args.width // F)
+  randn = Tensor.randn(shape)
+
+  sampler = DPMPP2MSampler(args.guidance)
+  z = sampler(model.denoise, randn, c, uc, args.steps, timing=args.timing)
+  print("created samples")
+  x = model.decode(z, args.height, args.width).realize()
+  print("decoded samples")
+  print(x.shape)
+
+  im = Image.fromarray(x.numpy())
+  print(f"saving {args.out}")
+  im.save(args.out)
+
+  if not args.noshow:
+    im.show()
diff --git a/examples/sdxl.py b/examples/sdxl.py
index a9d753fa..b6ea7439 100644
--- a/examples/sdxl.py
+++ b/examples/sdxl.py
@@ -370,7 +370,7 @@ if __name__ == "__main__":
   parser.add_argument('--width', type=int, default=1024, help="The output image width")
   parser.add_argument('--height', type=int, default=1024, help="The output image height")
   parser.add_argument('--weights', type=str, help="Custom path to weights")
-  parser.add_argument('--timing', action='store_true', help="Print timing per step")
+  parser.add_argument('--timing', action='store_true', help="Print timing per step")
   parser.add_argument('--noshow', action='store_true', help="Don't show the image")
   args = parser.parse_args()
 
diff --git a/extra/models/clip.py b/extra/models/clip.py
index 0ab70e0b..bdcf0eee 100644
--- a/extra/models/clip.py
+++ b/extra/models/clip.py
@@ -323,11 +323,12 @@ class Open:
 # https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/encoders/modules.py#L396
 # https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/encoders/modules.py#L498
 class FrozenOpenClipEmbedder(Embedder):
-  def __init__(self, dims:int, n_heads:int, layers:int, return_pooled:bool):
+  def __init__(self, dims:int, n_heads:int, layers:int, return_pooled:bool, ln_penultimate:bool=False):
     self.tokenizer = Tokenizer.ClipTokenizer()
     self.model = Open.ClipTextTransformer(dims, n_heads, layers)
     self.return_pooled = return_pooled
     self.input_key = "txt"
+    self.ln_penultimate = ln_penultimate
 
   def text_transformer_forward(self, x:Tensor, attn_mask:Optional[Tensor]=None):
     for r in self.model.transformer.resblocks:
@@ -340,6 +341,9 @@ class FrozenOpenClipEmbedder(Embedder):
     x = self.model.token_embedding(tokens).add(self.model.positional_embedding).permute(1,0,2)
     x, penultimate = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
 
+    if self.ln_penultimate:
+      penultimate = self.model.ln_final(penultimate)
+
     if self.return_pooled:
       x = self.model.ln_final(x)
       pooled = x[Tensor.arange(x.shape[0]), tokens.argmax(axis=-1).numpy().item()] @ self.model.text_projection
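
The ln_penultimate flag added above mirrors the SD2 OpenCLIP text encoder,
which conditions the UNet on the penultimate transformer layer passed through
the final LayerNorm (SDXL's OpenCLIP embedder instead sets return_pooled).
A quick way to inspect the resulting context tensor, sketched assuming a
tinygrad checkout with this patch applied (no weights are loaded here, so
only the shape is meaningful):

  from extra.models.clip import FrozenOpenClipEmbedder

  # same cond_stage_config as examples/sdv2.py above
  clip = FrozenOpenClipEmbedder(dims=1024, n_heads=16, layers=24,
                                return_pooled=False, ln_penultimate=True)
  ctx = clip("a horse sized cat eating a bagel")  # tokenized and padded to 77 tokens
  print(ctx.shape)  # (1, 77, 1024): batch x context length x dims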
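For reference, the c_skip/c_out/c_in factors in StableDiffusionV2.denoise are
the standard v-prediction scalings. A minimal stand-alone sketch (plain
Python; scalars stand in for tensors, and where denoise() feeds c_in*x to the
UNet this constructs the UNet's target v directly) checking that they recover
the clean latent x0:

  import math, random

  sigma = 2.5                                # any positive noise level
  alpha_t = 1.0 / math.sqrt(sigma**2 + 1.0)  # VP alpha_t, so alpha_t**2 + sigma_t**2 == 1
  sigma_t = sigma * alpha_t                  # VP sigma_t; sigma == sigma_t / alpha_t

  c_skip = 1.0 / (sigma**2 + 1.0)            # same formulas as in denoise()
  c_out = -sigma / (sigma**2 + 1.0) ** 0.5
  c_in = 1.0 / (sigma**2 + 1.0) ** 0.5       # equals alpha_t

  x0, eps = random.gauss(0, 1), random.gauss(0, 1)
  x = x0 + sigma * eps                       # noisy latent at noise level sigma
  v = alpha_t * eps - sigma_t * x0           # what a v-parameterized UNet is trained to predict
  assert abs((c_skip * x + c_out * v) - x0) < 1e-9  # matches run(...): model*c_out + x*c_skip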
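Usage sketch, assuming a tinygrad checkout with this patch applied (every flag
is defined in examples/sdv2.py above; the v2-1_768-ema-pruned weights are
fetched from the default Hugging Face URL unless --weights-fn or --weights-url
is given):

  python3 examples/sdv2.py --prompt "a horse sized cat eating a bagel" \
    --steps 10 --seed 42 --fp16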