mirror of https://github.com/commaai/tinygrad.git
This reverts commit 85e02311a2.

This commit is contained in:
parent 85e02311a2
commit c82bd59b85
@@ -14,6 +14,8 @@ from tinygrad.nn import Conv2d, Linear, GroupNorm, LayerNorm, Embedding
 from extra.utils import download_file
 from tinygrad.state import torch_load, load_state_dict
 
+# TODO: refactor AttnBlock, CrossAttention, CLIPAttention to share code
+
 class AttnBlock:
   def __init__(self, in_channels):
     self.norm = GroupNorm(32, in_channels)
@@ -29,8 +31,19 @@ class AttnBlock:
 
     # compute attention
     b,c,h,w = q.shape
-    q,k,v = [x.reshape(b,c,h*w).transpose(1,2) for x in (q,k,v)]
-    h_ = Tensor.scaled_dot_product_attention(q,k,v).transpose(1,2).reshape(b,c,h,w)
+    q = q.reshape(b,c,h*w)
+    q = q.permute(0,2,1)   # b,hw,c
+    k = k.reshape(b,c,h*w) # b,c,hw
+    w_ = q @ k
+    w_ = w_ * (c**(-0.5))
+    w_ = w_.softmax()
+
+    # attend to values
+    v = v.reshape(b,c,h*w)
+    w_ = w_.permute(0,2,1)
+    h_ = v @ w_
+    h_ = h_.reshape(b,c,h,w)
+
     return x + self.proj_out(h_)
 
 class ResnetBlock:
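Side note on the AttnBlock hunk (illustration, not part of the diff): the restored lines compute softmax(q·kᵀ · c^-0.5)·v by hand over the flattened h*w positions, which is the same quantity the removed Tensor.scaled_dot_product_attention call produced. A minimal equivalence sketch, assuming tinygrad's Tensor.randn/.numpy() and the tinygrad.tensor import path:

import numpy as np
from tinygrad.tensor import Tensor  # import path is an assumption for this era of tinygrad

b, c, h, w = 1, 8, 4, 4
q, k, v = [Tensor.randn(b, c, h, w) for _ in range(3)]

# manual path, mirroring the restored AttnBlock code
q_ = q.reshape(b, c, h*w).permute(0, 2, 1)   # b, hw, c
k_ = k.reshape(b, c, h*w)                    # b, c, hw
w_ = (q_ @ k_ * (c ** -0.5)).softmax()       # b, hw, hw
manual = (v.reshape(b, c, h*w) @ w_.permute(0, 2, 1)).reshape(b, c, h, w)

# fused path, mirroring the removed lines
qf, kf, vf = [t.reshape(b, c, h*w).transpose(1, 2) for t in (q, k, v)]
fused = Tensor.scaled_dot_product_attention(qf, kf, vf).transpose(1, 2).reshape(b, c, h, w)

print(np.allclose(manual.numpy(), fused.numpy(), atol=1e-5))  # True if the two paths agree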
@@ -165,6 +178,7 @@ class CrossAttention:
     self.to_q = Linear(query_dim, n_heads*d_head, bias=False)
     self.to_k = Linear(context_dim, n_heads*d_head, bias=False)
     self.to_v = Linear(context_dim, n_heads*d_head, bias=False)
+    self.scale = d_head ** -0.5
     self.num_heads = n_heads
     self.head_size = d_head
     self.to_out = [Linear(n_heads*d_head, query_dim)]
@@ -172,8 +186,14 @@ class CrossAttention:
   def __call__(self, x, context=None):
     context = x if context is None else context
     q,k,v = self.to_q(x), self.to_k(context), self.to_v(context)
-    q,k,v = [y.reshape(x.shape[0], -1, self.num_heads, self.head_size).transpose(1,2) for y in (q,k,v)]
-    attention = Tensor.scaled_dot_product_attention(q, k, v).transpose(1,2)
+    q = q.reshape(x.shape[0], -1, self.num_heads, self.head_size).permute(0,2,1,3)  # (bs, num_heads, time, head_size)
+    k = k.reshape(x.shape[0], -1, self.num_heads, self.head_size).permute(0,2,3,1)  # (bs, num_heads, head_size, time)
+    v = v.reshape(x.shape[0], -1, self.num_heads, self.head_size).permute(0,2,1,3)  # (bs, num_heads, time, head_size)
+
+    score = q.dot(k) * self.scale
+    weights = score.softmax()                     # (bs, num_heads, time, time)
+    attention = weights.dot(v).permute(0,2,1,3)   # (bs, time, num_heads, head_size)
+
     h_ = attention.reshape(shape=(x.shape[0], -1, self.num_heads * self.head_size))
     return h_.sequential(self.to_out)
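Side note on the CrossAttention hunk (illustration, not part of the diff): the permutes restored here split the projections into heads, move the head axis in front, and transpose k so q·k contracts over head_size. A plain-numpy sketch of the same shape bookkeeping, with made-up sizes:

import numpy as np

bs, time, ctx_len, n_heads, d_head = 2, 5, 7, 8, 40
q = np.random.randn(bs, time, n_heads*d_head)
k = np.random.randn(bs, ctx_len, n_heads*d_head)
v = np.random.randn(bs, ctx_len, n_heads*d_head)

qh = q.reshape(bs, time, n_heads, d_head).transpose(0, 2, 1, 3)     # (bs, heads, time, d_head)
kh = k.reshape(bs, ctx_len, n_heads, d_head).transpose(0, 2, 3, 1)  # (bs, heads, d_head, ctx)
vh = v.reshape(bs, ctx_len, n_heads, d_head).transpose(0, 2, 1, 3)  # (bs, heads, ctx, d_head)

score = qh @ kh * d_head**-0.5                  # (bs, heads, time, ctx)
weights = np.exp(score - score.max(-1, keepdims=True))
weights /= weights.sum(-1, keepdims=True)       # softmax over the context axis
out = (weights @ vh).transpose(0, 2, 1, 3).reshape(bs, time, n_heads*d_head)
print(out.shape)  # (2, 5, 320)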
@@ -342,6 +362,7 @@ class CLIPAttention:
     self.embed_dim = 768
     self.num_heads = 12
     self.head_dim = self.embed_dim // self.num_heads
+    self.scale = self.head_dim**-0.5
     self.k_proj = Linear(self.embed_dim, self.embed_dim)
     self.v_proj = Linear(self.embed_dim, self.embed_dim)
     self.q_proj = Linear(self.embed_dim, self.embed_dim)
@@ -353,7 +374,7 @@ class CLIPAttention:
   def __call__(self, hidden_states, causal_attention_mask):
     bsz, tgt_len, embed_dim = hidden_states.shape
 
-    query_states = self.q_proj(hidden_states)
+    query_states = self.q_proj(hidden_states) * self.scale
     key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
     value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
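Side note on the query pre-scaling (illustrative, not from the diff): multiplying the query by head_dim**-0.5 right after the projection is equivalent to scaling the q @ kᵀ scores afterwards, which is why the manual attention in the next hunk never applies the scale a second time. A one-line check:

import numpy as np
q, k = np.random.randn(4, 64), np.random.randn(4, 64)
scale = 64 ** -0.5
print(np.allclose((q * scale) @ k.T, (q @ k.T) * scale))  # True: matmul is linear in q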
@@ -363,8 +384,15 @@ class CLIPAttention:
     src_len = key_states.shape[1]
     value_states = value_states.reshape(*proj_shape)
 
-    causal_attention_mask = causal_attention_mask.reshape(bsz * self.num_heads, tgt_len, src_len)
-    attn_output = Tensor.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask=causal_attention_mask)
+    attn_weights = query_states @ key_states.permute(0,2,1)
+
+    attn_weights = attn_weights.reshape(bsz, self.num_heads, tgt_len, src_len) + causal_attention_mask
+    attn_weights = attn_weights.reshape(bsz * self.num_heads, tgt_len, src_len)
+
+    attn_weights = attn_weights.softmax()
+
+    attn_output = attn_weights @ value_states
+
     attn_output = attn_output.reshape(bsz, self.num_heads, tgt_len, self.head_dim)
     attn_output = attn_output.permute(0,2,1,3)
     attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
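Side note on the causal mask (a sketch under assumptions; the mask construction itself is outside this diff): the mask restored here is added to the raw scores before the softmax, and an additive causal mask of this kind carries -inf above the diagonal, so each position ends up attending only to itself and earlier tokens:

import numpy as np

tgt_len = 4
scores = np.random.randn(tgt_len, tgt_len)
causal_mask = np.triu(np.full((tgt_len, tgt_len), -np.inf), k=1)  # 0 on/below the diagonal, -inf above

masked = scores + causal_mask
weights = np.exp(masked - masked.max(-1, keepdims=True))
weights /= weights.sum(-1, keepdims=True)
print(np.allclose(weights, np.tril(weights)))  # True: no weight on future positions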