Revert "SD: Refactor AttnBlock, CrossAttention, CLIPAttention to share code (#1513)" (#1515)

This reverts commit 85e02311a2.
This commit is contained in:
George Hotz 2023-08-10 09:08:51 -07:00 committed by GitHub
parent 85e02311a2
commit c82bd59b85
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 35 additions and 7 deletions

View File

@ -14,6 +14,8 @@ from tinygrad.nn import Conv2d, Linear, GroupNorm, LayerNorm, Embedding
from extra.utils import download_file
from tinygrad.state import torch_load, load_state_dict
# TODO: refactor AttnBlock, CrossAttention, CLIPAttention to share code
class AttnBlock:
def __init__(self, in_channels):
self.norm = GroupNorm(32, in_channels)
@ -29,8 +31,19 @@ class AttnBlock:
# compute attention
b,c,h,w = q.shape
q,k,v = [x.reshape(b,c,h*w).transpose(1,2) for x in (q,k,v)]
h_ = Tensor.scaled_dot_product_attention(q,k,v).transpose(1,2).reshape(b,c,h,w)
q = q.reshape(b,c,h*w)
q = q.permute(0,2,1) # b,hw,c
k = k.reshape(b,c,h*w) # b,c,hw
w_ = q @ k
w_ = w_ * (c**(-0.5))
w_ = w_.softmax()
# attend to values
v = v.reshape(b,c,h*w)
w_ = w_.permute(0,2,1)
h_ = v @ w_
h_ = h_.reshape(b,c,h,w)
return x + self.proj_out(h_)
class ResnetBlock:
@ -165,6 +178,7 @@ class CrossAttention:
self.to_q = Linear(query_dim, n_heads*d_head, bias=False)
self.to_k = Linear(context_dim, n_heads*d_head, bias=False)
self.to_v = Linear(context_dim, n_heads*d_head, bias=False)
self.scale = d_head ** -0.5
self.num_heads = n_heads
self.head_size = d_head
self.to_out = [Linear(n_heads*d_head, query_dim)]
@ -172,8 +186,14 @@ class CrossAttention:
def __call__(self, x, context=None):
context = x if context is None else context
q,k,v = self.to_q(x), self.to_k(context), self.to_v(context)
q,k,v = [y.reshape(x.shape[0], -1, self.num_heads, self.head_size).transpose(1,2) for y in (q,k,v)]
attention = Tensor.scaled_dot_product_attention(q, k, v).transpose(1,2)
q = q.reshape(x.shape[0], -1, self.num_heads, self.head_size).permute(0,2,1,3) # (bs, num_heads, time, head_size)
k = k.reshape(x.shape[0], -1, self.num_heads, self.head_size).permute(0,2,3,1) # (bs, num_heads, head_size, time)
v = v.reshape(x.shape[0], -1, self.num_heads, self.head_size).permute(0,2,1,3) # (bs, num_heads, time, head_size)
score = q.dot(k) * self.scale
weights = score.softmax() # (bs, num_heads, time, time)
attention = weights.dot(v).permute(0,2,1,3) # (bs, time, num_heads, head_size)
h_ = attention.reshape(shape=(x.shape[0], -1, self.num_heads * self.head_size))
return h_.sequential(self.to_out)
@ -342,6 +362,7 @@ class CLIPAttention:
self.embed_dim = 768
self.num_heads = 12
self.head_dim = self.embed_dim // self.num_heads
self.scale = self.head_dim**-0.5
self.k_proj = Linear(self.embed_dim, self.embed_dim)
self.v_proj = Linear(self.embed_dim, self.embed_dim)
self.q_proj = Linear(self.embed_dim, self.embed_dim)
@ -353,7 +374,7 @@ class CLIPAttention:
def __call__(self, hidden_states, causal_attention_mask):
bsz, tgt_len, embed_dim = hidden_states.shape
query_states = self.q_proj(hidden_states)
query_states = self.q_proj(hidden_states) * self.scale
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
@ -363,8 +384,15 @@ class CLIPAttention:
src_len = key_states.shape[1]
value_states = value_states.reshape(*proj_shape)
causal_attention_mask = causal_attention_mask.reshape(bsz * self.num_heads, tgt_len, src_len)
attn_output = Tensor.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask=causal_attention_mask)
attn_weights = query_states @ key_states.permute(0,2,1)
attn_weights = attn_weights.reshape(bsz, self.num_heads, tgt_len, src_len) + causal_attention_mask
attn_weights = attn_weights.reshape(bsz * self.num_heads, tgt_len, src_len)
attn_weights = attn_weights.softmax()
attn_output = attn_weights @ value_states
attn_output = attn_output.reshape(bsz, self.num_heads, tgt_len, self.head_dim)
attn_output = attn_output.permute(0,2,1,3)
attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)