mirror of https://github.com/commaai/tinygrad.git
layernorm fixes transformer instability
This commit is contained in:
parent 628d21f899
commit 2e89e75dcb
@@ -23,6 +23,12 @@ def make_dataset():
   return ds_X_train, ds_Y_train, ds_X_test, ds_Y_test
 
+def layernorm(x, eps=1e-5):
+  layer_mean = x.mean(axis=(0,1))
+  y = (x - layer_mean.reshape(shape=[1, 1, -1]))
+  layer_var = (y*y).mean(axis=(0,1))
+  return y.div(layer_var.add(eps).reshape(shape=[1, 1, -1]))
+
 class TransformerBlock:
   def __init__(self, embed_dim, num_heads):
     # Multi-Head Attention
 
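Note on the layernorm added above: it averages over axis=(0,1), so each embedding channel is normalized across the whole batch and sequence, and it divides by (var + eps) directly, without a square root. A textbook layernorm instead normalizes each token's vector over the last axis and divides by the standard deviation. A minimal NumPy sketch of that reference form, for comparison (illustrative only, not part of this commit):

import numpy as np

def layernorm_ref(x, eps=1e-5):
  # x: (bs, T, embed); normalize each token's embedding vector (last axis)
  mean = x.mean(axis=-1, keepdims=True)
  var = ((x - mean) ** 2).mean(axis=-1, keepdims=True)
  return (x - mean) / np.sqrt(var + eps)  # note the sqrt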
@@ -55,12 +61,16 @@ class TransformerBlock:
     value = value.transpose(order=(0,2,1,3))                    # (bs, num_heads, T, head_size)
 
     score = query.dot(key) * (1 / np.sqrt(self.head_size))
-    weights = score.softmax() # (bs, num_heads, T, T)
-    attention = weights.dot(value).transpose(order=(0,2,1,3))
+    weights = score.softmax()                                   # (bs, num_heads, T, T)
+    attention = weights.dot(value).transpose(order=(0,2,1,3))   # (bs, T, num_heads, head_size)
     x = inputs + attention.reshape(shape=(-1, self.num_heads * self.head_size)).dot(self.final)
+    # layernorm
+    x = x.reshape(shape=(bs, -1, self.num_heads * self.head_size))
+    x = layernorm(x)
+    x = x.reshape(shape=(-1, self.num_heads * self.head_size))
     x = x + x.dot(self.ff1).relu().dot(self.ff2)
+    # layernorm
+    x = x.reshape(shape=(bs, -1, self.num_heads * self.head_size))
+    x = layernorm(x)
+    x = x.reshape(shape=(-1, self.num_heads * self.head_size))
     return x.reshape(shape=(bs, -1, self.num_heads * self.head_size))
 
 class Transformer:
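For reference, the attention arithmetic in this hunk as a self-contained NumPy sketch; the shapes mirror the comments above, and the random inputs and sizes are purely illustrative:

import numpy as np

bs, num_heads, T, head_size = 2, 4, 8, 16
query = np.random.randn(bs, num_heads, T, head_size)
key   = np.random.randn(bs, num_heads, head_size, T)   # already transposed for the matmul
value = np.random.randn(bs, num_heads, T, head_size)

score = query @ key * (1 / np.sqrt(head_size))         # (bs, num_heads, T, T)
e = np.exp(score - score.max(axis=-1, keepdims=True))  # numerically stable softmax
weights = e / e.sum(axis=-1, keepdims=True)            # softmax over the last axis
attention = (weights @ value).transpose(0, 2, 1, 3)    # (bs, T, num_heads, head_size)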
@@ -93,7 +103,7 @@ if __name__ == "__main__":
 
   X_train, Y_train, X_test, Y_test = make_dataset()
   optim = Adam(get_parameters(model), lr=0.001)
-  train(model, X_train, Y_train, optim, 500)
+  train(model, X_train, Y_train, optim, 500, BS=16)
 
   evaluate(model, X_test, Y_test, num_classes=10)
 
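The train call now passes BS=16 explicitly. The train helper itself is not shown in this diff; assuming it draws a random minibatch per step, as the tinygrad example scripts typically do, one step looks roughly like the sketch below. compute_loss is a hypothetical stand-in, since the real loss construction lives inside train():

import numpy as np

for step in range(500):
  samp = np.random.randint(0, X_train.shape[0], size=(16,))  # random minibatch of BS=16
  x, y = X_train[samp], Y_train[samp]
  out = model.forward(Tensor(x))  # Tensor is tinygrad's array type
  loss = compute_loss(out, y)     # hypothetical helper, illustrative only
  loss.backward()
  optim.step()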
@@ -28,3 +28,4 @@ class BatchNorm2D:
   def normalize(self, x, mean, var):
     x = (x - mean.reshape(shape=[1, -1, 1, 1])) * self.weight.reshape(shape=[1, -1, 1, 1])
     return x.div(var.add(self.eps).reshape(shape=[1, -1, 1, 1])**0.5) + self.bias.reshape(shape=[1, -1, 1, 1])
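BatchNorm2D.normalize in the last hunk is the standard affine batchnorm, with per-channel statistics broadcast over (N, C, H, W) via the [1, -1, 1, 1] reshapes. It scales by weight before dividing by sqrt(var + eps), which is algebraically the same as the usual (x - mean) / sqrt(var + eps) * weight + bias. A plain NumPy rendering of that formula (illustrative):

import numpy as np

def normalize_ref(x, mean, var, weight, bias, eps=1e-5):
  # x: (N, C, H, W); mean, var, weight, bias: (C,)
  s = (1, -1, 1, 1)  # broadcast per channel
  xhat = (x - mean.reshape(s)) / np.sqrt(var.reshape(s) + eps)
  return xhat * weight.reshape(s) + bias.reshape(s)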