mirror of https://github.com/commaai/tinygrad.git
fix codellama params and repeat_kv (#2181)
This commit is contained in:
parent
c7f4dd6cb0
commit
8548b20b23
|
@ -46,7 +46,7 @@ def apply_rotary_emb(xq, xk, freqs_cis) -> Tuple[Tensor, Tensor]:
|
|||
def repeat_kv(x:Tensor, n_rep:int) -> Tensor:
|
||||
bs, seqlen, n_kv_heads, head_dim = x.shape
|
||||
if n_rep == 1: return x
|
||||
return x[:, :, :, None, :].expand(bs, seqlen, n_kv_heads, n_rep, head_dim).reshape(bs, seqlen, n_kv_heads * n_rep, head_dim)
|
||||
return x.reshape(bs, seqlen, n_kv_heads, 1, head_dim).expand(bs, seqlen, n_kv_heads, n_rep, head_dim).reshape(bs, seqlen, n_kv_heads * n_rep, head_dim)
|
||||
|
||||
class RMSNorm:
|
||||
def __init__(self, dim, eps=1e-6):
|
||||
|
@ -224,11 +224,11 @@ MODEL_PARAMS = {
|
|||
"files": 2,
|
||||
},
|
||||
"13B-Instruct": {
|
||||
"args": {"dim": 5120, "n_layers": 40, "n_headvocab_sizes": 40, "multiple_of": 256, "ffn_dim_multiplier": 1.0, "norm_eps": 1e-5, "rope_theta": 1000000, "vocab_size": 32000},
|
||||
"args": {"dim": 5120, "n_layers": 40, "n_heads": 40, "multiple_of": 256, "ffn_dim_multiplier": 1.0, "norm_eps": 1e-5, "rope_theta": 1000000, "vocab_size": 32016},
|
||||
"files": 2,
|
||||
},
|
||||
"34B": {
|
||||
"args": {"dim": 8192, "n_layers": 48, "n_heads": 64, "n_kv_heads": 8, "multiple_of": 256, "ffn_dim_multiplier": 1.0, "norm_eps": 1e-5, "rope_theta": 1000000, "vocab_size": 32016},
|
||||
"args": {"dim": 8192, "n_layers": 48, "n_heads": 64, "n_kv_heads": 8, "multiple_of": 256, "ffn_dim_multiplier": 1.0, "norm_eps": 1e-5, "rope_theta": 1000000, "vocab_size": 32000},
|
||||
"files": 4,
|
||||
},
|
||||
"34B-Python": {
|
||||
|
@ -302,7 +302,7 @@ class LLaMa:
|
|||
def build(model_path, tokenizer_path, model_gen="1", model_size="7B", quantize=False):
|
||||
from sentencepiece import SentencePieceProcessor
|
||||
sp_model = SentencePieceProcessor(model_file=str(tokenizer_path))
|
||||
assert sp_model.vocab_size() == MODEL_PARAMS[model_gen][model_size]["args"]["vocab_size"]
|
||||
assert sp_model.vocab_size() == MODEL_PARAMS[model_gen][model_size]["args"]["vocab_size"], f"{sp_model.vocab_size()=} not equal to {MODEL_PARAMS[model_gen][model_size]['args']['vocab_size']}"
|
||||
|
||||
params = MODEL_PARAMS[model_gen][model_size]
|
||||
model = Transformer(**params["args"], linear=AbsmaxQuantizedLinear) if quantize else Transformer(**params["args"])
|
||||
|
|
Loading…
Reference in New Issue