mirror of https://github.com/commaai/tinygrad.git
fix tests hopefully, more stable diffusion
This commit is contained in:
parent c01a8c5c2d
commit 4dadd95e3c
@@ -3,6 +3,7 @@
 # this is sd-v1-4.ckpt
 FILENAME = "/Users/kafka/fun/mps/stable-diffusion/models/ldm/stable-diffusion-v1/model.ckpt"
 
+import numpy as np
 from extra.utils import fake_torch_load_zipped, get_child
 from tinygrad.nn import Conv2d
 from tinygrad.tensor import Tensor
@@ -14,9 +15,9 @@ class Normalize:
     self.weight = Tensor.uniform(in_channels)
     self.bias = Tensor.uniform(in_channels)
 
-  def forward(self, x):
+  def __call__(self, x):
     # TODO: write groupnorm
-    pass
+    return x
 
 class AttnBlock:
   def __init__(self, in_channels):
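Note on the groupnorm TODO: in the reference ldm code, Normalize is torch.nn.GroupNorm(32, in_channels, eps=1e-6, affine=True), so __call__ eventually has to normalize over channel groups rather than pass x through. A minimal numpy sketch of that computation (group count and eps taken from ldm, not from this diff):

import numpy as np

def groupnorm(x, weight, bias, groups=32, eps=1e-6):
  # x is NCHW; normalize each group of C//groups channels over (channels, H, W)
  n, c, h, w = x.shape
  xg = x.reshape(n, groups, -1)
  mean = xg.mean(axis=2, keepdims=True)
  var = xg.var(axis=2, keepdims=True)
  xg = (xg - mean) / np.sqrt(var + eps)
  # affine scale/shift with the per-channel weight and bias defined above
  return xg.reshape(n, c, h, w) * weight.reshape(1, c, 1, 1) + bias.reshape(1, c, 1, 1)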
@@ -26,24 +27,33 @@ class AttnBlock:
     self.v = Conv2d(in_channels, in_channels, 1)
     self.proj_out = Conv2d(in_channels, in_channels, 1)
 
+  def __call__(self, x):
+    # TODO: write attention
+    print("attention:", x.shape)
+    return x
+
 class ResnetBlock:
   def __init__(self, in_channels, out_channels=None):
     self.norm1 = Normalize(in_channels)
-    self.conv1 = Conv2d(in_channels, out_channels, 3)
+    self.conv1 = Conv2d(in_channels, out_channels, 3, padding=1)
     self.norm2 = Normalize(out_channels)
-    self.conv2 = Conv2d(out_channels, out_channels, 3)
-    if in_channels != out_channels:
-      self.nin_shortcut = Conv2d(in_channels, out_channels, 1)
+    self.conv2 = Conv2d(out_channels, out_channels, 3, padding=1)
+    self.nin_shortcut = Conv2d(in_channels, out_channels, 1) if in_channels != out_channels else lambda x: x
+
+  def __call__(self, x):
+    h = self.conv1(self.norm1(x).swish())
+    h = self.conv2(self.norm2(h).swish())
+    return self.nin_shortcut(x) + h
 
 class Decoder:
   def __init__(self):
     sz = [(128, 256), (256, 512), (512, 512), (512, 512)]
-    self.conv_in = Conv2d(4,512,3)
+    self.conv_in = Conv2d(4,512,3, padding=1)
 
     arr = []
     for i,s in enumerate(sz):
       x = {}
-      x['upsample'] = {"conv": Conv2d(s[0], s[0], 3, stride=2, padding=(0,1,0,1))}
+      if i != 0: x['upsample'] = {"conv": Conv2d(s[0], s[0], 3, padding=1)}
       x['block'] = [ResnetBlock(s[1], s[0]),
                     ResnetBlock(s[0], s[0]),
                     ResnetBlock(s[0], s[0])]
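Note on the attention TODO: in ldm's AttnBlock this is single-head self-attention over the h*w spatial positions, using the q, k, v and proj_out 1x1 convs from __init__ (plus a Normalize on the input, omitted here for brevity) and a residual add. A numpy sketch under those assumptions:

import numpy as np

def attn_block(x, q, k, v, proj_out):
  # x: NCHW; q/k/v/proj_out: the 1x1 convs above, treated as callables
  n, c, h, w = x.shape
  q_, k_, v_ = (t(x).reshape(n, c, h*w) for t in (q, k, v))
  att = np.einsum('ncq,nck->nqk', q_, k_) / np.sqrt(c)  # scaled dot product, (n, hw, hw)
  att = np.exp(att - att.max(axis=-1, keepdims=True))
  att /= att.sum(axis=-1, keepdims=True)                # softmax over key positions
  out = np.einsum('nck,nqk->ncq', v_, att).reshape(n, c, h, w)
  return x + proj_out(out)                              # residual connection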
@@ -58,12 +68,33 @@ class Decoder:
     }
 
     self.norm_out = Normalize(128)
-    self.conv_out = Conv2d(128, 3, 3)
+    self.conv_out = Conv2d(128, 3, 3, padding=1)
+
+  def __call__(self, x):
+    x = self.conv_in(x)
+
+    x = self.mid['block_1'](x)
+    x = self.mid['attn_1'](x)
+    x = self.mid['block_2'](x)
+
+    for l in self.up[::-1]:
+      print("decode", x.shape)
+      for b in l['block']: x = b(x)
+      if 'upsample' in l:
+        # https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html ?
+        bs,c,py,px = x.shape
+        x = x.reshape((bs, c, py, 1, px, 1))
+        x = x.expand((bs, c, py, 2, px, 2))
+        x = x.reshape((bs, c, py*2, px*2))
+        x = l['upsample']['conv'](x)
+
+    return self.conv_out(self.norm_out(x).swish())
+
 
 class Encoder:
   def __init__(self, decode=False):
     sz = [(128, 128), (128, 256), (256, 512), (512, 512)]
-    self.conv_in = Conv2d(3,128,3)
+    self.conv_in = Conv2d(3,128,3, padding=1)
 
     arr = []
     for i,s in enumerate(sz):
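The reshape/expand/reshape dance in Decoder.__call__ is 2x nearest-neighbor upsampling, i.e. what the linked torch.nn.functional.interpolate does with scale_factor=2, mode="nearest": insert a singleton axis next to each spatial axis, broadcast it to 2, and fold the axes back together. The same trick in plain numpy:

import numpy as np

x = np.arange(4).reshape(1, 1, 2, 2)              # tiny NCHW example
bs, c, py, px = x.shape
up = x.reshape(bs, c, py, 1, px, 1)               # singleton axes beside H and W
up = np.broadcast_to(up, (bs, c, py, 2, px, 2))   # duplicate every pixel 2x2
up = up.reshape(bs, c, py*2, px*2)                # fold back: (1, 1, 4, 4)
assert (up[0, 0, :2, :2] == x[0, 0, 0, 0]).all()  # top-left 2x2 block is pixel (0,0)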
@@ -81,7 +112,21 @@ class Encoder:
     }
 
     self.norm_out = Normalize(block_in)
-    self.conv_out = Conv2d(block_in, 8, 3)
+    self.conv_out = Conv2d(block_in, 8, 3, padding=1)
+
+  def __call__(self, x):
+    x = self.conv_in(x)
+    for l in self.down:
+      print("encode", x.shape)
+      for b in l['block']: x = b(x)
+      if 'downsample' in l: x = l['downsample']['conv'](x)
+
+    x = self.mid['block_1'](x)
+    x = self.mid['attn_1'](x)
+    x = self.mid['block_2'](x)
+
+    return self.conv_out(self.norm_out(x).swish())
+
 
 class AutoencoderKL:
   def __init__(self):
@@ -90,10 +135,21 @@ class AutoencoderKL:
     self.quant_conv = Conv2d(8, 8, 1)
     self.post_quant_conv = Conv2d(4, 4, 1)
 
+  def __call__(self, x):
+    latent = self.encoder(x)
+    latent = self.quant_conv(latent)
+    latent = latent[:, 0:4]  # only the means
+    print("latent", latent.shape)
+    latent = self.post_quant_conv(latent)
+    return self.decoder(latent)
+
 class StableDiffusion:
   def __init__(self):
     self.first_stage_model = AutoencoderKL()
 
+  def __call__(self, x):
+    return self.first_stage_model(x)
+
 model = StableDiffusion()
 
 for k,v in dat['state_dict'].items():
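The latent[:, 0:4] slice in AutoencoderKL.__call__ keeps only the posterior means: the encoder's 8 output channels are 4 means plus 4 log-variances of a diagonal Gaussian (ldm's DiagonalGaussianDistribution), so this round-trip is deterministic. A hedged numpy sketch of the stochastic encode this shortcut skips:

import numpy as np

def sample_latent(latent8, rng=np.random.default_rng()):
  # latent8: the (n, 8, h, w) output of quant_conv; split per ldm's channel layout
  mean, logvar = latent8[:, 0:4], latent8[:, 4:8]
  return mean + np.exp(0.5 * logvar) * rng.standard_normal(mean.shape)  # reparameterized sample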
@@ -106,6 +162,18 @@ for k,v in dat['state_dict'].items():
   assert w.shape == v.shape
 
 
+IMG = "/Users/kafka/fun/mps/stable-diffusion/outputs/txt2img-samples/grid-0006.png"
+from PIL import Image
+img = Tensor(np.array(Image.open(IMG))).permute((2,0,1)).reshape((1,3,512,512))
+print(img.shape)
+x = model(img)
+print(x.shape)
+x = x[0]
+print(x.shape)
+
+dat = x.numpy()
+
 
 # ** ldm.models.autoencoder.AutoencoderKL
 # 3x512x512 <--> 4x64x64 (16384)
+# decode torch.Size([1, 4, 64, 64]) torch.Size([1, 3, 512, 512])
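A note on the test harness above: permute((2,0,1)) turns PIL's HWC array into CHW, and the reshape then prepends the batch axis, which assumes grid-0006.png is exactly 512x512 RGB (an RGBA or differently sized grid image would make the reshape fail). The equivalent numpy move:

import numpy as np

hwc = np.zeros((512, 512, 3), dtype=np.uint8)  # stand-in for np.array(Image.open(IMG))
nchw = hwc.transpose(2, 0, 1)[None]            # HWC -> CHW, then add the batch axis
assert nchw.shape == (1, 3, 512, 512)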
@@ -124,8 +124,7 @@ def fake_torch_load(b0):
     assert ll == obj_size
     bytes_size = {np.float32: 4, np.int64: 8}[storage_type]
     mydat = fb0.read(ll * bytes_size)
-    np_array[:] = np.frombuffer(mydat, storage_type)
-    np_array.shape = np_shape
+    np.copyto(np_array, np.frombuffer(mydat, storage_type).reshape(np_shape))
 
     # numpy stores its strides in bytes
     real_strides = tuple([x*bytes_size for x in np_strides])
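The np.copyto change sidesteps in-place shape mutation: assigning to np_array.shape raises whenever the new shape can't be produced without copying (e.g. on a non-contiguous array), while copyto just writes the buffer's values and leaves the destination's shape and strides untouched. A quick demonstration of the failure mode:

import numpy as np

a = np.arange(6).reshape(2, 3).T   # non-contiguous transposed view
try:
  a.shape = (6,)                   # in-place reshape would need a copy
except AttributeError as e:
  print(e)                         # "Incompatible shape for in-place modification..."
np.copyto(np.empty((3, 2)), a)     # copies values, destination layout unchanged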
@@ -261,9 +261,10 @@ class Tensor:
   # ***** functional nn ops *****
 
   # TODO: fix the kwargs problem, then remove these
-  # NOTE: perhaps don't, since they create NOOPs if the shape already matches
+  # NOTE: perhaps don't, since they create NOOPs if the shape already matches. (now lazy should do this)
   def reshape(self, shape): return self._reshape(shape=shape) if tuple(self.shape) != tuple(shape) else self
   def expand(self, shape): return self._expand(shape=shape) if tuple(self.shape) != tuple(shape) else self
+  def permute(self, order): return self._permute(order=order)
 
   def linear(self, weight:Tensor, bias:Tensor):
     shp = [1] * (len(self.shape)-1) + [-1]
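The guard in reshape/expand means a call with the tensor's current shape returns self and emits no op at all, which the updated NOTE says lazy evaluation should eventually make redundant. A small usage sketch (assuming Tensor.ones from tinygrad's constructors):

from tinygrad.tensor import Tensor

t = Tensor.ones(2, 3)
assert t.reshape((2, 3)) is t       # shape already matches: returns self, no op created
assert t.reshape((3, 2)) is not t   # a real reshape still goes through _reshape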