# https://arxiv.org/pdf/2112.10752.pdf
# https://github.com/ekagra-ranjan/huggingface-blog/blob/main/stable_diffusion.md
import tempfile
from pathlib import Path
import argparse
from collections import namedtuple
from typing import Dict, Any

from PIL import Image
import numpy as np
from tinygrad import Device, GlobalCounters, dtypes, Tensor, TinyJit
from tinygrad.helpers import Timing, Context, getenv, fetch, colored, tqdm
from tinygrad.nn import Conv2d, GroupNorm
from tinygrad.nn.state import torch_load, load_state_dict, get_state_dict
from extra.models.clip import Closed, Tokenizer
from extra.models.unet import UNetModel

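# The next three classes build up the VAE (AutoencoderKL) used as the first stage.
# AttnBlock is single-head self-attention over the flattened h*w spatial positions,
# with 1x1 convolutions acting as the q/k/v and output projections.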
class AttnBlock:
  def __init__(self, in_channels):
    self.norm = GroupNorm(32, in_channels)
    self.q = Conv2d(in_channels, in_channels, 1)
    self.k = Conv2d(in_channels, in_channels, 1)
    self.v = Conv2d(in_channels, in_channels, 1)
    self.proj_out = Conv2d(in_channels, in_channels, 1)

  # copied from AttnBlock in ldm repo
  def __call__(self, x):
    h_ = self.norm(x)
    q,k,v = self.q(h_), self.k(h_), self.v(h_)

    # compute attention
    b,c,h,w = q.shape
    q,k,v = [x.reshape(b,c,h*w).transpose(1,2) for x in (q,k,v)]
    h_ = Tensor.scaled_dot_product_attention(q,k,v).transpose(1,2).reshape(b,c,h,w)
    return x + self.proj_out(h_)

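# ResnetBlock is a pre-activation residual block (GroupNorm -> swish -> conv, twice);
# a 1x1 nin_shortcut projects the skip connection when the channel count changes.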
class ResnetBlock:
  def __init__(self, in_channels, out_channels=None):
    self.norm1 = GroupNorm(32, in_channels)
    self.conv1 = Conv2d(in_channels, out_channels, 3, padding=1)
    self.norm2 = GroupNorm(32, out_channels)
    self.conv2 = Conv2d(out_channels, out_channels, 3, padding=1)
    self.nin_shortcut = Conv2d(in_channels, out_channels, 1) if in_channels != out_channels else lambda x: x

  def __call__(self, x):
    h = self.conv1(self.norm1(x).swish())
    h = self.conv2(self.norm2(h).swish())
    return self.nin_shortcut(x) + h

class Mid:
  def __init__(self, block_in):
    self.block_1 = ResnetBlock(block_in, block_in)
    self.attn_1 = AttnBlock(block_in)
    self.block_2 = ResnetBlock(block_in, block_in)

  def __call__(self, x):
    return x.sequential([self.block_1, self.attn_1, self.block_2])

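# Decoder maps the 4x64x64 latent back to a 3x512x512 image. Each up level is three
# ResnetBlocks; all but the final level (self.up is walked in reverse) end with a
# nearest-neighbor 2x upsample and a 3x3 conv, growing 64x64 -> 512x512 over three doublings.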
class Decoder:
  def __init__(self):
    sz = [(128, 256), (256, 512), (512, 512), (512, 512)]
    self.conv_in = Conv2d(4,512,3, padding=1)
    self.mid = Mid(512)

    arr = []
    for i,s in enumerate(sz):
      arr.append({"block":
        [ResnetBlock(s[1], s[0]),
         ResnetBlock(s[0], s[0]),
         ResnetBlock(s[0], s[0])]})
      if i != 0: arr[-1]['upsample'] = {"conv": Conv2d(s[0], s[0], 3, padding=1)}
    self.up = arr

    self.norm_out = GroupNorm(32, 128)
    self.conv_out = Conv2d(128, 3, 3, padding=1)

  def __call__(self, x):
    x = self.conv_in(x)
    x = self.mid(x)

    for l in self.up[::-1]:
      print("decode", x.shape)
      for b in l['block']: x = b(x)
      if 'upsample' in l:
        # nearest-neighbor 2x upsample via reshape/expand, like
        # torch.nn.functional.interpolate(x, scale_factor=2, mode="nearest")
        # https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html
        bs,c,py,px = x.shape
        x = x.reshape(bs, c, py, 1, px, 1).expand(bs, c, py, 2, px, 2).reshape(bs, c, py*2, px*2)
        x = l['upsample']['conv'](x)
      x.realize()

    return self.conv_out(self.norm_out(x).swish())

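# Encoder is the mirror image: 3x512x512 -> 8x64x64 (4 mean + 4 logvar channels).
# The stride-2 downsample convs use asymmetric (0,1,0,1) padding, presumably to match
# the reference ldm implementation, which pads on one side before the strided conv.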
class Encoder:
  def __init__(self):
    sz = [(128, 128), (128, 256), (256, 512), (512, 512)]
    self.conv_in = Conv2d(3,128,3, padding=1)

    arr = []
    for i,s in enumerate(sz):
      arr.append({"block":
        [ResnetBlock(s[0], s[1]),
         ResnetBlock(s[1], s[1])]})
      if i != 3: arr[-1]['downsample'] = {"conv": Conv2d(s[1], s[1], 3, stride=2, padding=(0,1,0,1))}
    self.down = arr

    self.mid = Mid(512)
    self.norm_out = GroupNorm(32, 512)
    self.conv_out = Conv2d(512, 8, 3, padding=1)

  def __call__(self, x):
    x = self.conv_in(x)

    for l in self.down:
      print("encode", x.shape)
      for b in l['block']: x = b(x)
      if 'downsample' in l: x = l['downsample']['conv'](x)

    x = self.mid(x)
    return self.conv_out(self.norm_out(x).swish())

class AutoencoderKL:
  def __init__(self):
    self.encoder = Encoder()
    self.decoder = Decoder()
    self.quant_conv = Conv2d(8, 8, 1)
    self.post_quant_conv = Conv2d(4, 4, 1)

  def __call__(self, x):
    latent = self.encoder(x)
    latent = self.quant_conv(latent)
    latent = latent[:, 0:4] # only the means
    print("latent", latent.shape)
    latent = self.post_quant_conv(latent)
    return self.decoder(latent)

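# DDPM "scaled linear" noise schedule: betas are linear in sqrt space between
# sqrt(beta_start) and sqrt(beta_end), then squared, so
#   alphas_cumprod[t] = prod_{s<=t} (1 - beta_s)
# which the sampler indexes by timestep below.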
def get_alphas_cumprod(beta_start=0.00085, beta_end=0.0120, n_training_steps=1000):
  betas = np.linspace(beta_start ** 0.5, beta_end ** 0.5, n_training_steps, dtype=np.float32) ** 2
  alphas = 1.0 - betas
  alphas_cumprod = np.cumprod(alphas, axis=0)
  return Tensor(alphas_cumprod)

unet_params: Dict[str,Any] = {
  "adm_in_ch": None,
  "in_ch": 4,
  "out_ch": 4,
  "model_ch": 320,
  "attention_resolutions": [4, 2, 1],
  "num_res_blocks": 2,
  "channel_mult": [1, 2, 4, 4],
  "n_heads": 8,
  "transformer_depth": [1, 1, 1, 1],
  "ctx_dim": 768,
  "use_linear": False,
}

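# StableDiffusion wires the three sub-models together. The namedtuple wrappers keep
# attribute paths (model.diffusion_model.*, cond_stage_model.transformer.text_model.*)
# lined up with the key names in the CompVis checkpoint so load_state_dict can match them.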
class StableDiffusion:
  def __init__(self):
    self.alphas_cumprod = get_alphas_cumprod()
    self.model = namedtuple("DiffusionModel", ["diffusion_model"])(diffusion_model = UNetModel(**unet_params))
    self.first_stage_model = AutoencoderKL()
    self.cond_stage_model = namedtuple("CondStageModel", ["transformer"])(transformer = namedtuple("Transformer", ["text_model"])(text_model = Closed.ClipTextTransformer()))

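  # deterministic DDIM update (eta = 0, so sigma_t = 0):
  #   pred_x0 = (x - sqrt(1 - a_t) * e_t) / sqrt(a_t)
  #   x_prev  = sqrt(a_prev) * pred_x0 + sqrt(1 - a_prev - sigma_t**2) * e_t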
  def get_x_prev_and_pred_x0(self, x, e_t, a_t, a_prev):
    temperature = 1
    sigma_t = 0
    sqrt_one_minus_at = (1-a_t).sqrt()
    #print(a_t, a_prev, sigma_t, sqrt_one_minus_at)

    pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()

    # direction pointing to x_t
    dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t

    x_prev = a_prev.sqrt() * pred_x0 + dir_xt
    return x_prev, pred_x0

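  # classifier-free guidance: run the U-Net once on a batch of two (unconditional
  # context stacked with the prompt context) and extrapolate away from the
  # unconditional prediction by the guidance scale.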
  def get_model_output(self, unconditional_context, context, latent, timestep, unconditional_guidance_scale):
    # put into diffuser
    latents = self.model.diffusion_model(latent.expand(2, *latent.shape[1:]), timestep, unconditional_context.cat(context, dim=0))
    unconditional_latent, latent = latents[0:1], latents[1:2]

    e_t = unconditional_latent + unconditional_guidance_scale * (latent - unconditional_latent)
    return e_t

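  # VAE decode: undo the 0.18215 latent scaling used at training time, run the
  # decoder, then map the output from [-1, 1] to a 512x512x3 image in [0, 255].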
  def decode(self, x):
    x = self.first_stage_model.post_quant_conv(1/0.18215 * x)
    x = self.first_stage_model.decoder(x)

    # make image correct size and scale
    x = (x + 1.0) / 2.0
    x = x.reshape(3,512,512).permute(1,2,0).clip(0,1)*255
    return x.cast(dtypes.uint8) if Device.DEFAULT != "WEBGPU" else x

  def __call__(self, unconditional_context, context, latent, timestep, alphas, alphas_prev, guidance):
    e_t = self.get_model_output(unconditional_context, context, latent, timestep, guidance)
    x_prev, _ = self.get_x_prev_and_pred_x0(latent, e_t, alphas, alphas_prev)
    #e_t_next = get_model_output(x_prev)
    #e_t_prime = (e_t + e_t_next) / 2
    #x_prev, pred_x0 = get_x_prev_and_pred_x0(latent, e_t_prime, index)
    return x_prev.realize()

# ** ldm.models.autoencoder.AutoencoderKL (done!)
# 3x512x512 <--> 4x64x64 (16384)
# decode torch.Size([1, 4, 64, 64]) torch.Size([1, 3, 512, 512])
# section 4.3 of paper
# first_stage_model.encoder, first_stage_model.decoder

# ** ldm.modules.diffusionmodules.openaimodel.UNetModel
# this is the denoising network that runs once per sampling step (the latent diffusion model proper)
# input: 4x64x64
# output: 4x64x64
# model.diffusion_model
# it uses self-attention plus cross-attention on the CLIP text context

# ** ldm.modules.encoders.modules.FrozenCLIPEmbedder
# cond_stage_model.transformer.text_model

if __name__ == "__main__":
  default_prompt = "a horse sized cat eating a bagel"
  parser = argparse.ArgumentParser(description='Run Stable Diffusion', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
  parser.add_argument('--steps', type=int, default=5, help="Number of steps in diffusion")
  parser.add_argument('--prompt', type=str, default=default_prompt, help="Phrase to render")
  parser.add_argument('--out', type=str, default=Path(tempfile.gettempdir()) / "rendered.png", help="Output filename")
  parser.add_argument('--noshow', action='store_true', help="Don't show the image")
  parser.add_argument('--fp16', action='store_true', help="Cast the weights to float16")
  parser.add_argument('--timing', action='store_true', help="Print timing per step")
  parser.add_argument('--seed', type=int, help="Set the random latent seed")
  parser.add_argument('--guidance', type=float, default=7.5, help="Prompt strength")
  args = parser.parse_args()

  Tensor.no_grad = True
  model = StableDiffusion()

  # load in weights
  load_state_dict(model, torch_load(fetch('https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt', 'sd-v1-4.ckpt'))['state_dict'], strict=False)

  if args.fp16:
    for k,v in get_state_dict(model).items():
      if k.startswith("model"):
        v.replace(v.cast(dtypes.float16).realize())

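  # Encode both the prompt and an empty prompt with CLIP: the empty-string embedding
  # is the unconditional context used for classifier-free guidance.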
  # run through CLIP to get context
  tokenizer = Tokenizer.ClipTokenizer()
  prompt = Tensor([tokenizer.encode(args.prompt)])
  context = model.cond_stage_model.transformer.text_model(prompt).realize()
  print("got CLIP context", context.shape)

  prompt = Tensor([tokenizer.encode("")])
  unconditional_context = model.cond_stage_model.transformer.text_model(prompt).realize()
  print("got unconditional CLIP context", unconditional_context.shape)

  # done with clip model
  del model.cond_stage_model

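  # DDIM sampling schedule: args.steps evenly spaced training timesteps (e.g. steps=5
  # gives [1, 201, 401, 601, 801]); gather the matching cumulative alphas, with
  # alphas_prev shifted one step and seeded with 1.0 for the final denoising step.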
  timesteps = list(range(1, 1000, 1000//args.steps))
  print(f"running for {timesteps} timesteps")
  alphas = model.alphas_cumprod[Tensor(timesteps)]
  alphas_prev = Tensor([1.0]).cat(alphas[:-1])

  # start with random noise
  if args.seed is not None: Tensor.manual_seed(args.seed)
  latent = Tensor.randn(1,4,64,64)

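  # TinyJit records the kernels launched by the wrapped call and replays that graph on
  # later invocations, so denoising steps after warm-up skip Python dispatch; this
  # requires the input tensors to keep the same shapes across steps (they do here).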
  @TinyJit
  def run(model, *x): return model(*x).realize()

  # this is diffusion
  with Context(BEAM=getenv("LATEBEAM")):
    for index, timestep in (t:=tqdm(list(enumerate(timesteps))[::-1])):
      GlobalCounters.reset()
      t.set_description("%3d %3d" % (index, timestep))
      with Timing("step in ", enabled=args.timing, on_exit=lambda _: f", using {GlobalCounters.mem_used/1e9:.2f} GB"):
        tid = Tensor([index])
        latent = run(model, unconditional_context, context, latent, Tensor([timestep]), alphas[tid], alphas_prev[tid], Tensor([args.guidance]))
        if args.timing: Device[Device.DEFAULT].synchronize()
  del run

  # upsample latent space to image with autoencoder
  x = model.decode(latent)
  print(x.shape)

  # save image
  im = Image.fromarray(x.numpy().astype(np.uint8, copy=False))
  print(f"saving {args.out}")
  im.save(args.out)
  # Open image.
  if not args.noshow: im.show()

  # validation!
  if args.prompt == default_prompt and args.steps == 5 and args.seed == 0 and args.guidance == 7.5:
    ref_image = Tensor(np.array(Image.open(Path(__file__).parent / "stable_diffusion_seed0.png")))
    distance = (((x.cast(dtypes.float) - ref_image.cast(dtypes.float)) / ref_image.max())**2).mean().item()
    assert distance < 3e-4, colored(f"validation failed with {distance=}", "red")
    print(colored(f"output validated with {distance=}", "green"))