hotfix: revert llama change

George Hotz 2024-04-10 20:13:15 -07:00
parent 2e6c39b0b2
commit e79a11b99c
1 changed file with 1 addition and 1 deletion

@@ -101,7 +101,7 @@ class TransformerBlock:
   def __call__(self, x:Tensor, start_pos:Union[Variable,int], freqs_cis:Tensor, mask:Optional[Tensor]):
     h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask)
-    return h + self.feed_forward(self.ffn_norm(h).half())
+    return (h + self.feed_forward(self.ffn_norm(h).half())).realize()
 
 class Transformer:
   def __init__(self, dim:int, hidden_dim:int, n_heads:int, n_layers:int, norm_eps:float, vocab_size, linear=nn.Linear, n_kv_heads=None, rope_theta=10000, max_context=1024, jit=True, feed_forward=FeedForward):
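
For context, the reverted line restores a call to tinygrad's Tensor.realize() at the end of each TransformerBlock. A minimal sketch (not part of this commit, using the standard tinygrad Tensor API) of what .realize() does:

# Tinygrad tensors are lazy: operations build a computation graph,
# and .realize() forces that graph to be scheduled and executed.
from tinygrad import Tensor

x = Tensor.rand(4, 4)
y = (x + x).relu()   # nothing runs yet; y is a lazy expression
y = y.realize()      # kernels are compiled and executed here
print(y.numpy())     # .numpy() would also trigger realization if it hadn't happened

Calling .realize() per block, as the restored line does, evaluates each block's output as it is produced rather than deferring the whole model's graph to the end.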