mirror of https://github.com/commaai/tinygrad.git
Stable Diffusion v2 Inference (#5283)
* model implementation * clip fix, more qol options
This commit is contained in:
parent
e5ba385f03
commit
0c3a35e5c2
|
@ -0,0 +1,147 @@
|
|||
from tinygrad import Tensor, dtypes, TinyJit
|
||||
from tinygrad.helpers import fetch
|
||||
from tinygrad.nn.state import safe_load, load_state_dict, get_state_dict
|
||||
from examples.stable_diffusion import AutoencoderKL, get_alphas_cumprod
|
||||
from examples.sdxl import DPMPP2MSampler, append_dims, LegacyDDPMDiscretization
|
||||
from extra.models.unet import UNetModel
|
||||
from extra.models.clip import FrozenOpenClipEmbedder
|
||||
|
||||
from typing import Dict
|
||||
import argparse, tempfile, os
|
||||
from pathlib import Path
|
||||
from PIL import Image
|
||||
|
||||
class DiffusionModel:
|
||||
def __init__(self, model:UNetModel):
|
||||
self.diffusion_model = model
|
||||
|
||||
@TinyJit
|
||||
def run(model, x, tms, ctx, c_out, add):
|
||||
return (model(x, tms, ctx)*c_out + add).realize()
|
||||
|
||||
# https://github.com/Stability-AI/stablediffusion/blob/cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf/ldm/models/diffusion/ddpm.py#L521
|
||||
class StableDiffusionV2:
|
||||
def __init__(self, unet_config:Dict, cond_stage_config:Dict, parameterization:str="v"):
|
||||
self.model = DiffusionModel(UNetModel(**unet_config))
|
||||
self.first_stage_model = AutoencoderKL()
|
||||
self.cond_stage_model = FrozenOpenClipEmbedder(**cond_stage_config)
|
||||
self.alphas_cumprod = get_alphas_cumprod()
|
||||
self.parameterization = parameterization
|
||||
|
||||
self.discretization = LegacyDDPMDiscretization()
|
||||
self.sigmas = self.discretization(1000, flip=True)
|
||||
|
||||
def denoise(self, x:Tensor, sigma:Tensor, cond:Dict) -> Tensor:
|
||||
|
||||
def sigma_to_idx(s:Tensor) -> Tensor:
|
||||
dists = s - self.sigmas.unsqueeze(1)
|
||||
return dists.abs().argmin(axis=0).view(*s.shape)
|
||||
|
||||
sigma = self.sigmas[sigma_to_idx(sigma)]
|
||||
sigma_shape = sigma.shape
|
||||
sigma = append_dims(sigma, x)
|
||||
|
||||
c_skip = 1.0 / (sigma**2 + 1.0)
|
||||
c_out = -sigma / (sigma**2 + 1.0) ** 0.5
|
||||
c_in = 1.0 / (sigma**2 + 1.0) ** 0.5
|
||||
c_noise = sigma_to_idx(sigma.reshape(sigma_shape))
|
||||
|
||||
def prep(*tensors:Tensor):
|
||||
return tuple(t.cast(dtypes.float16).realize() for t in tensors)
|
||||
|
||||
return run(self.model.diffusion_model, *prep(x*c_in, c_noise, cond["crossattn"], c_out, x*c_skip))
|
||||
|
||||
def decode(self, x:Tensor, height:int, width:int) -> Tensor:
|
||||
x = self.first_stage_model.post_quant_conv(1/0.18215 * x)
|
||||
x = self.first_stage_model.decoder(x)
|
||||
|
||||
# make image correct size and scale
|
||||
x = (x + 1.0) / 2.0
|
||||
x = x.reshape(3,height,width).permute(1,2,0).clip(0,1).mul(255).cast(dtypes.uint8)
|
||||
return x
|
||||
|
||||
params: Dict = {
|
||||
"unet_config": {
|
||||
"adm_in_ch": None,
|
||||
"in_ch": 4,
|
||||
"out_ch": 4,
|
||||
"model_ch": 320,
|
||||
"attention_resolutions": [4, 2, 1],
|
||||
"num_res_blocks": 2,
|
||||
"channel_mult": [1, 2, 4, 4],
|
||||
"d_head": 64,
|
||||
"transformer_depth": [1, 1, 1, 1],
|
||||
"ctx_dim": 1024,
|
||||
"use_linear": True,
|
||||
},
|
||||
"cond_stage_config": {
|
||||
"dims": 1024,
|
||||
"n_heads": 16,
|
||||
"layers": 24,
|
||||
"return_pooled": False,
|
||||
"ln_penultimate": True,
|
||||
}
|
||||
}
|
||||
|
||||
if __name__ == "__main__":
|
||||
default_prompt = "a horse sized cat eating a bagel"
|
||||
parser = argparse.ArgumentParser(description='Run Stable Diffusion v2.X', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||
parser.add_argument('--steps', type=int, default=10, help="The number of diffusion steps")
|
||||
parser.add_argument('--prompt', type=str, default=default_prompt, help="Description of image to generate")
|
||||
parser.add_argument('--out', type=str, default=Path(tempfile.gettempdir()) / "rendered.png", help="Output filename")
|
||||
parser.add_argument('--seed', type=int, help="Set the random latent seed")
|
||||
parser.add_argument('--guidance', type=float, default=7.5, help="Prompt strength")
|
||||
parser.add_argument('--width', type=int, default=768, help="The output image width")
|
||||
parser.add_argument('--height', type=int, default=768, help="The output image height")
|
||||
parser.add_argument('--weights-fn', type=str, help="Filename of weights to use")
|
||||
parser.add_argument('--weights-url', type=str, help="Custom URL to download weights from")
|
||||
parser.add_argument('--timing', action='store_true', help="Print timing per step")
|
||||
parser.add_argument('--noshow', action='store_true', help="Don't show the image")
|
||||
parser.add_argument('--fp16', action='store_true', help="Cast the weights to float16")
|
||||
args = parser.parse_args()
|
||||
|
||||
N = 1
|
||||
C = 4
|
||||
F = 8
|
||||
assert args.width % F == 0, f"img_width must be multiple of {F}, got {args.width}"
|
||||
assert args.height % F == 0, f"img_height must be multiple of {F}, got {args.height}"
|
||||
|
||||
Tensor.no_grad = True
|
||||
if args.seed is not None:
|
||||
Tensor.manual_seed(args.seed)
|
||||
|
||||
model = StableDiffusionV2(**params)
|
||||
|
||||
default_weights_url = 'https://huggingface.co/stabilityai/stable-diffusion-2-1/resolve/main/v2-1_768-ema-pruned.safetensors'
|
||||
weights_fn = args.weights_fn
|
||||
if not weights_fn:
|
||||
weights_url = args.weights_url if args.weights_url else default_weights_url
|
||||
weights_fn = fetch(weights_url, os.path.basename(str(weights_url)))
|
||||
load_state_dict(model, safe_load(weights_fn), strict=False)
|
||||
|
||||
if args.fp16:
|
||||
for k,v in get_state_dict(model).items():
|
||||
if k.startswith("model"):
|
||||
v.replace(v.cast(dtypes.float16).realize())
|
||||
|
||||
c = { "crossattn": model.cond_stage_model(args.prompt) }
|
||||
uc = { "crossattn": model.cond_stage_model("") }
|
||||
del model.cond_stage_model
|
||||
print("created conditioning")
|
||||
|
||||
shape = (N, C, args.height // F, args.width // F)
|
||||
randn = Tensor.randn(shape)
|
||||
|
||||
sampler = DPMPP2MSampler(args.guidance)
|
||||
z = sampler(model.denoise, randn, c, uc, args.steps, timing=args.timing)
|
||||
print("created samples")
|
||||
x = model.decode(z, args.height, args.width).realize()
|
||||
print("decoded samples")
|
||||
print(x.shape)
|
||||
|
||||
im = Image.fromarray(x.numpy())
|
||||
print(f"saving {args.out}")
|
||||
im.save(args.out)
|
||||
|
||||
if not args.noshow:
|
||||
im.show()
|
|
@ -370,7 +370,7 @@ if __name__ == "__main__":
|
|||
parser.add_argument('--width', type=int, default=1024, help="The output image width")
|
||||
parser.add_argument('--height', type=int, default=1024, help="The output image height")
|
||||
parser.add_argument('--weights', type=str, help="Custom path to weights")
|
||||
parser.add_argument('--timing', action='store_true', help="Print timing per step")
|
||||
parser.add_argument('--timing', action='store_true', help="Print timing per step")
|
||||
parser.add_argument('--noshow', action='store_true', help="Don't show the image")
|
||||
args = parser.parse_args()
|
||||
|
||||
|
|
|
@ -323,11 +323,12 @@ class Open:
|
|||
# https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/encoders/modules.py#L396
|
||||
# https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/encoders/modules.py#L498
|
||||
class FrozenOpenClipEmbedder(Embedder):
|
||||
def __init__(self, dims:int, n_heads:int, layers:int, return_pooled:bool):
|
||||
def __init__(self, dims:int, n_heads:int, layers:int, return_pooled:bool, ln_penultimate:bool=False):
|
||||
self.tokenizer = Tokenizer.ClipTokenizer()
|
||||
self.model = Open.ClipTextTransformer(dims, n_heads, layers)
|
||||
self.return_pooled = return_pooled
|
||||
self.input_key = "txt"
|
||||
self.ln_penultimate = ln_penultimate
|
||||
|
||||
def text_transformer_forward(self, x:Tensor, attn_mask:Optional[Tensor]=None):
|
||||
for r in self.model.transformer.resblocks:
|
||||
|
@ -340,6 +341,9 @@ class FrozenOpenClipEmbedder(Embedder):
|
|||
x = self.model.token_embedding(tokens).add(self.model.positional_embedding).permute(1,0,2)
|
||||
x, penultimate = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
|
||||
|
||||
if self.ln_penultimate:
|
||||
penultimate = self.model.ln_final(penultimate)
|
||||
|
||||
if self.return_pooled:
|
||||
x = self.model.ln_final(x)
|
||||
pooled = x[Tensor.arange(x.shape[0]), tokens.argmax(axis=-1).numpy().item()] @ self.model.text_projection
|
||||
|
|
Loading…
Reference in New Issue