mirror of https://github.com/commaai/tinygrad.git
hotfix: revert llama change
parent 2e6c39b0b2
commit e79a11b99c
@@ -101,7 +101,7 @@ class TransformerBlock:
   def __call__(self, x:Tensor, start_pos:Union[Variable,int], freqs_cis:Tensor, mask:Optional[Tensor]):
     h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask)
-    return h + self.feed_forward(self.ffn_norm(h).half())
+    return (h + self.feed_forward(self.ffn_norm(h).half())).realize()
 
 class Transformer:
   def __init__(self, dim:int, hidden_dim:int, n_heads:int, n_layers:int, norm_eps:float, vocab_size, linear=nn.Linear, n_kv_heads=None, rope_theta=10000, max_context=1024, jit=True, feed_forward=FeedForward):
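For context: tinygrad Tensor operations are lazy, so removing the .realize() call defers the block's computation; this hotfix restores it so each TransformerBlock's output is materialized eagerly. A minimal sketch of that behavior, assuming the tinygrad Tensor import path as of this commit:

from tinygrad.tensor import Tensor

# Ops on tinygrad Tensors are lazy: this line only builds a compute graph.
h = Tensor([1.0, 2.0]) + Tensor([3.0, 4.0])

# realize() schedules and runs the pending graph on the device,
# materializing h's data, mirroring the call restored in the diff above.
out = h.realize()  # realize() returns the tensor itself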