added dot affine

George Hotz 2021-11-29 12:55:56 -05:00
parent 30eb3afbe1
commit c7f795ca1e
2 changed files with 12 additions and 9 deletions


@@ -18,14 +18,14 @@ class TransformerBlock:
     assert self.head_size * self.num_heads == embed_dim
     # added bias
-    self.query_dense = (Tensor.uniform(embed_dim, embed_dim), Tensor.uniform(embed_dim))
-    self.key_dense = (Tensor.uniform(embed_dim, embed_dim), Tensor.uniform(embed_dim))
-    self.value_dense = (Tensor.uniform(embed_dim, embed_dim), Tensor.uniform(embed_dim))
+    self.query_dense = (Tensor.uniform(embed_dim, embed_dim), Tensor.zeros(embed_dim))
+    self.key_dense = (Tensor.uniform(embed_dim, embed_dim), Tensor.zeros(embed_dim))
+    self.value_dense = (Tensor.uniform(embed_dim, embed_dim), Tensor.zeros(embed_dim))
-    self.final = Tensor.uniform(embed_dim, embed_dim)
+    self.final = (Tensor.uniform(embed_dim, embed_dim), Tensor.zeros(embed_dim))
-    self.ff1 = Tensor.uniform(embed_dim, ff_dim)
-    self.ff2 = Tensor.uniform(ff_dim, embed_dim)
+    self.ff1 = (Tensor.uniform(embed_dim, ff_dim), Tensor.zeros(ff_dim))
+    self.ff2 = (Tensor.uniform(ff_dim, embed_dim), Tensor.zeros(embed_dim))
   def __call__(self, x):
     # bs x T x embed_dim
@@ -34,7 +34,7 @@ class TransformerBlock:
     inputs = x.reshape(shape=(-1, embed_dim))
     # run multi head attention (bs, T, num_heads, head_size)
-    query, key, value = [inputs.dot(y[0]).add(y[1].reshape(shape=[1, -1])) \
+    query, key, value = [inputs.affine(y) \
       .reshape(shape=(bs, -1, self.num_heads, self.head_size)) \
       for y in [self.query_dense, self.key_dense, self.value_dense]]
@@ -46,9 +46,9 @@ class TransformerBlock:
     weights = score.softmax() # (bs, num_heads, T, T)
     attention = weights.dot(value).transpose(order=(0,2,1,3)) # (bs, T, num_heads, head_size)
-    x = inputs + attention.reshape(shape=(-1, embed_dim)).dot(self.final).dropout(0.1)
+    x = inputs + attention.reshape(shape=(-1, embed_dim)).affine(self.final).dropout(0.1)
     x = layernorm(x, embed_dim)
-    x = x + x.dot(self.ff1).relu().dot(self.ff2).dropout(0.1)
+    x = x + x.affine(self.ff1).relu().affine(self.ff2).dropout(0.1)
     x = layernorm(x, embed_dim)
     return x.reshape(shape=(bs, -1, embed_dim))
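
The refactor is behavior-preserving: affine((w, b)) computes the same x.dot(w) + b that the explicit dot/add chain did, with the bias reshaped to (1, -1) so it broadcasts across rows. A minimal sketch of that equivalence, using NumPy as a stand-in for tinygrad tensors (shapes here are illustrative, not taken from the diff):

import numpy as np

x = np.random.randn(4, 8).astype(np.float32)   # (batch*T, embed_dim)
w = np.random.randn(8, 8).astype(np.float32)   # weight
b = np.random.randn(8).astype(np.float32)      # bias

old = x.dot(w) + b.reshape(1, -1)  # the removed dot(w).add(b.reshape([1, -1])) pattern
new = x.dot(w) + b                 # what affine((w, b)) computes; broadcasting adds the bias per row
assert np.allclose(old, new)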


@@ -275,6 +275,9 @@ class Tensor:
   def max_pool2d(self, kernel_size=(2,2)):
     return self._pool2d(*kernel_size).max(axis=(3,5))
+  def affine(self, params):
+    return self.dot(params[0]).add(params[1].reshape(shape=[1, -1]))

 # An instantiation of the Function is the Context
 class Function:
   def __new__(cls, *args, **kwargs):
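
With the helper in place, a layer stores its parameters as a (weight, bias) tuple and applies both in one call. A hedged usage sketch, assuming the 2021-era tinygrad API exactly as it appears in this diff (Tensor.uniform, Tensor.zeros, reshape(shape=...), and the import path tinygrad.tensor):

from tinygrad.tensor import Tensor

embed_dim, ff_dim = 128, 512
ff1 = (Tensor.uniform(embed_dim, ff_dim), Tensor.zeros(ff_dim))  # (weight, bias) pair, as in TransformerBlock

x = Tensor.uniform(32, embed_dim)  # (batch, embed_dim) activations
y = x.affine(ff1)                  # equivalent to x.dot(ff1[0]).add(ff1[1].reshape(shape=[1, -1]))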