From c7f795ca1e0e9ff8ecf112ad8b2a6db87c806eea Mon Sep 17 00:00:00 2001
From: George Hotz
Date: Mon, 29 Nov 2021 12:55:56 -0500
Subject: [PATCH] added dot affine

---
 models/transformer.py | 18 +++++++++---------
 tinygrad/tensor.py    |  3 +++
 2 files changed, 12 insertions(+), 9 deletions(-)

diff --git a/models/transformer.py b/models/transformer.py
index 2f5de17f..9a114920 100644
--- a/models/transformer.py
+++ b/models/transformer.py
@@ -18,14 +18,14 @@ class TransformerBlock:
     assert self.head_size * self.num_heads == embed_dim
 
     # added bias
-    self.query_dense = (Tensor.uniform(embed_dim, embed_dim), Tensor.uniform(embed_dim))
-    self.key_dense = (Tensor.uniform(embed_dim, embed_dim), Tensor.uniform(embed_dim))
-    self.value_dense = (Tensor.uniform(embed_dim, embed_dim), Tensor.uniform(embed_dim))
+    self.query_dense = (Tensor.uniform(embed_dim, embed_dim), Tensor.zeros(embed_dim))
+    self.key_dense = (Tensor.uniform(embed_dim, embed_dim), Tensor.zeros(embed_dim))
+    self.value_dense = (Tensor.uniform(embed_dim, embed_dim), Tensor.zeros(embed_dim))
 
-    self.final = Tensor.uniform(embed_dim, embed_dim)
+    self.final = (Tensor.uniform(embed_dim, embed_dim), Tensor.zeros(embed_dim))
 
-    self.ff1 = Tensor.uniform(embed_dim, ff_dim)
-    self.ff2 = Tensor.uniform(ff_dim, embed_dim)
+    self.ff1 = (Tensor.uniform(embed_dim, ff_dim), Tensor.zeros(ff_dim))
+    self.ff2 = (Tensor.uniform(ff_dim, embed_dim), Tensor.zeros(embed_dim))
 
   def __call__(self, x):
     # bs x T x embed_dim
@@ -34,7 +34,7 @@ class TransformerBlock:
     inputs = x.reshape(shape=(-1, embed_dim))
 
     # run multi head attention (bs, T, num_heads, head_size)
-    query, key, value = [inputs.dot(y[0]).add(y[1].reshape(shape=[1, -1])) \
+    query, key, value = [inputs.affine(y) \
       .reshape(shape=(bs, -1, self.num_heads, self.head_size)) \
       for y in [self.query_dense, self.key_dense, self.value_dense]]
 
@@ -46,9 +46,9 @@ class TransformerBlock:
     weights = score.softmax()                                   # (bs, num_heads, T, T)
     attention = weights.dot(value).transpose(order=(0,2,1,3))   # (bs, T, num_heads, head_size)
 
-    x = inputs + attention.reshape(shape=(-1, embed_dim)).dot(self.final).dropout(0.1)
+    x = inputs + attention.reshape(shape=(-1, embed_dim)).affine(self.final).dropout(0.1)
     x = layernorm(x, embed_dim)
-    x = x + x.dot(self.ff1).relu().dot(self.ff2).dropout(0.1)
+    x = x + x.affine(self.ff1).relu().affine(self.ff2).dropout(0.1)
     x = layernorm(x, embed_dim)
 
     return x.reshape(shape=(bs, -1, embed_dim))
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index 2beeca15..7810621c 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -275,6 +275,9 @@ class Tensor:
   def max_pool2d(self, kernel_size=(2,2)):
     return self._pool2d(*kernel_size).max(axis=(3,5))
 
+  def affine(self, params):
+    return self.dot(params[0]).add(params[1].reshape(shape=[1, -1]))
+
 # An instantiation of the Function is the Context class
 class Function:
   def __new__(cls, *args, **kwargs):
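
Note on the new helper (not part of the commit message): Tensor.affine(params) simply fuses the
"matmul with the weight, then broadcast-add the bias" pattern that the transformer block previously
spelled out for every projection. A minimal usage sketch, assuming the 2021-era tinygrad API shown
in this patch; the sizes and the "layer" name are illustrative, not taken from the commit:

  from tinygrad.tensor import Tensor

  # a "dense layer" in this patch is just a (weight, bias) tuple
  layer = (Tensor.uniform(128, 128), Tensor.zeros(128))

  x = Tensor.uniform(32, 128)   # (bs*T, embed_dim), already flattened as in TransformerBlock
  out = x.affine(layer)         # same as x.dot(layer[0]).add(layer[1].reshape(shape=[1, -1]))

Keeping each layer as a bare (weight, bias) tuple means affine() needs no Layer class; the block's
query_dense, key_dense, value_dense, final, ff1 and ff2 all reuse the same helper.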