fix onehot embed

George Hotz 2020-12-27 18:50:38 -05:00
parent d864e1c71a
commit 65b07d2f4f
1 changed file with 5 additions and 6 deletions


@@ -59,10 +59,8 @@ class TransformerBlock:
     weights = score.logsoftmax() # (bs, num_heads, T, T)
     attention = weights.dot(value).transpose(order=(0,2,1,3))
     x = inputs + attention.reshape(shape=(-1, self.num_heads * self.head_size)).dot(self.final)
-    print(x.shape)
     # layernorm
     x = x + x.dot(self.ff1).relu().dot(self.ff2)
-    print(x.shape)
     # layernorm
     return x.reshape(shape=(bs, -1, self.num_heads * self.head_size))
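
The two removed print(x.shape) lines were leftover debug output in the attention block. For reference, a minimal NumPy sketch (not the tinygrad code; bs, num_heads, T, and head_size are made-up sizes) of the shape flow that the surrounding lines imply:

import numpy as np

bs, num_heads, T, head_size = 2, 4, 6, 8  # illustrative sizes only

weights = np.random.rand(bs, num_heads, T, T)          # logsoftmax output: (bs, num_heads, T, T)
value   = np.random.rand(bs, num_heads, T, head_size)  # (bs, num_heads, T, head_size)

# weights.dot(value).transpose(order=(0,2,1,3)) in the diff above
attention = (weights @ value).transpose(0, 2, 1, 3)    # (bs, T, num_heads, head_size)

# reshape(shape=(-1, num_heads*head_size)) flattens batch and time together
x = attention.reshape(-1, num_heads * head_size)
assert x.shape == (bs * T, num_heads * head_size)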
@@ -78,11 +76,12 @@ class Transformer:
   def forward(self, x):
     bs = x.shape[0]
     xnp = x.cpu().data
-    onehot = np.zeros((bs*x.shape[1], self.maxlen+self.syms), dtype=np.float32)
-    print(onehot.shape)
+    onehot = np.zeros((bs, x.shape[1], self.maxlen+self.syms), dtype=np.float32)
     for i in range(x.shape[1]):
-      onehot[range(bs*i, bs*(i+1)), i] = 1
-      onehot[range(bs*i, bs*(i+1)), self.maxlen + xnp[:, i]] = 1
+      onehot[range(bs), i, i] = 1
+      onehot[range(bs), i, self.maxlen + xnp[:, i]] = 1
+    onehot = onehot.reshape(bs*x.shape[1], self.maxlen+self.syms)
     x = Tensor(onehot, device=x.device).dot(self.embed).reshape(shape=(bs, x.shape[1], -1))
     for t in self.tbs:
       x = t(x)
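
The fix builds the one-hot input as a (bs, T, maxlen+syms) array indexed by batch and position, and only then flattens it to (bs*T, maxlen+syms). The previous code wrote to row i*bs + b (position-major), while the final reshape to (bs, x.shape[1], -1) reads rows batch-major, so rows ended up associated with the wrong (batch, position) pairs. A minimal NumPy sketch of the new construction, with made-up sizes and random tokens standing in for x.cpu().data:

import numpy as np

bs, T, maxlen, syms = 2, 6, 6, 10               # illustrative sizes only
xnp = np.random.randint(0, syms, size=(bs, T))  # stand-in for x.cpu().data

onehot = np.zeros((bs, T, maxlen + syms), dtype=np.float32)
for i in range(T):
  onehot[range(bs), i, i] = 1                   # position one-hot (first maxlen columns)
  onehot[range(bs), i, maxlen + xnp[:, i]] = 1  # symbol one-hot (last syms columns)
onehot = onehot.reshape(bs * T, maxlen + syms)  # batch-major rows: row index b*T + i

# every row carries exactly one position bit and one symbol bit
assert (onehot.sum(axis=1) == 2).all()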