diff --git a/models/transformer.py b/models/transformer.py index c25ff94f..289f49e9 100644 --- a/models/transformer.py +++ b/models/transformer.py @@ -45,9 +45,9 @@ class TransformerBlock: x = x + x.layernorm().linear(*self.ln2).linear(*self.ff1).gelu().linear(*self.ff2).dropout(0.1) else: x = x + self.attn(x).dropout(0.1) - x = x.layernorm().linear(self.ln1) - x = x + x.linear(self.ff1).relu().linear(self.ff2).dropout(0.1) - x = x.layernorm().linear(self.ln2) + x = x.layernorm().linear(*self.ln1) + x = x + x.linear(*self.ff1).relu().linear(*self.ff2).dropout(0.1) + x = x.layernorm().linear(*self.ln2) return x class Transformer: