codellama (#1702)
* add codellama with pre-downloaded weights
* add rope_theta, fix param
* fix test
* add 7B-Python
* add 7B-Instruct
* replace single quotes with double

Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
parent a2745819f6
commit 66a6bbd029
@@ -133,13 +133,14 @@ class TransformerBlock:
     return (h + self.feed_forward(self.ffn_norm(h))).realize(), cache_k.realize(), cache_v.realize()
 
 class Transformer:
-  def __init__(self, dim, multiple_of, n_heads, n_layers, norm_eps, vocab_size, linear=Linear, max_batch_size=32, max_seq_len=1024, ffn_dim_multiplier=None, n_kv_heads=None):
+  def __init__(self, dim, multiple_of, n_heads, n_layers, norm_eps, vocab_size, linear=Linear, max_batch_size=32, max_seq_len=1024, ffn_dim_multiplier=None, n_kv_heads=None, rope_theta=10000):
     self.layers = [TransformerBlock(dim, multiple_of, n_heads, n_kv_heads, norm_eps, linear, ffn_dim_multiplier) for _ in range(n_layers)]
     self.kv_caches = [(None, None) for _ in range(n_layers)]
     self.norm = RMSNorm(dim, norm_eps)
     self.tok_embeddings = Embedding(vocab_size, dim)
     self.output = linear(dim, vocab_size, bias=False)
-    self.freqs_cis = Tensor(precompute_freqs_cis(dim // n_heads, max_seq_len * 2))
+    self.freqs_cis = Tensor(precompute_freqs_cis(dim // n_heads, max_seq_len * 2, rope_theta))
     self.norm_output = lambda x: self.output(self.norm(x))
 
     self.tok_embeddings_jitted = TinyJit(lambda x: self.tok_embeddings(x).realize())
     self.postprocess_jitted = TinyJit(self.postprocess)
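For context on the new rope_theta argument: rotary position embeddings derive one frequency per pair of head channels from the base theta, so Code LLaMA's base of 1000000 (versus the LLaMA default of 10000) slows the rotations and stretches the usable context window. A minimal NumPy sketch of the standard formulation, not code from this diff:

import numpy as np

def rope_frequencies(head_dim, theta):
  # one inverse-power frequency per pair of channels, mirroring the usual precompute_freqs_cis math
  return 1.0 / (theta ** (np.arange(0, head_dim, 2) / head_dim))

print(rope_frequencies(128, 10000.0)[-1])    # ~1.2e-4: lowest frequency at the default base
print(rope_frequencies(128, 1000000.0)[-1])  # ~1.2e-6: far longer wavelengths for Code LLaMA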
@@ -176,41 +177,77 @@ class Transformer:
     return self.postprocess(h, temperature)
 
 # **** files and arguments ****
 
-VOCAB_SIZE = 32000
 MODEL_PARAMS = {
-  1: {
+  "1": {
     "7B": {
-      "args": {"dim": 4096, "multiple_of": 256, "n_heads": 32, "n_layers": 32, "norm_eps": 1e-06, "vocab_size": VOCAB_SIZE},
+      "args": {"dim": 4096, "multiple_of": 256, "n_heads": 32, "n_layers": 32, "norm_eps": 1e-06, "vocab_size": 32000},
       "files": 1,
     },
     "13B": {
-      "args": {"dim": 5120, "multiple_of": 256, "n_heads": 40, "n_layers": 40, "norm_eps": 1e-06, "vocab_size": VOCAB_SIZE},
+      "args": {"dim": 5120, "multiple_of": 256, "n_heads": 40, "n_layers": 40, "norm_eps": 1e-06, "vocab_size": 32000},
       "files": 2,
     },
     "30B": {
-      "args": {"dim": 6656, "multiple_of": 256, "n_heads": 52, "n_layers": 60, "norm_eps": 1e-06, "vocab_size": VOCAB_SIZE},
+      "args": {"dim": 6656, "multiple_of": 256, "n_heads": 52, "n_layers": 60, "norm_eps": 1e-06, "vocab_size": 32000},
       "files": 4,
     },
     "65B": {
-      "args": {"dim": 8192, "multiple_of": 256, "n_heads": 64, "n_layers": 80, "norm_eps": 1e-05, "vocab_size": VOCAB_SIZE},
+      "args": {"dim": 8192, "multiple_of": 256, "n_heads": 64, "n_layers": 80, "norm_eps": 1e-05, "vocab_size": 32000},
       "files": 8,
     },
   },
-  2: {
+  "2": {
     "7B": {
-      "args": {"dim": 4096, "multiple_of": 256, "n_heads": 32, "n_layers": 32, "norm_eps": 1e-05, "vocab_size": VOCAB_SIZE},
+      "args": {"dim": 4096, "multiple_of": 256, "n_heads": 32, "n_layers": 32, "norm_eps": 1e-05, "vocab_size": 32000},
       "files": 1,
     },
     "13B": {
-      "args": {"dim": 5120, "multiple_of": 256, "n_heads": 40, "n_layers": 40, "norm_eps": 1e-05, "vocab_size": VOCAB_SIZE},
+      "args": {"dim": 5120, "multiple_of": 256, "n_heads": 40, "n_layers": 40, "norm_eps": 1e-05, "vocab_size": 32000},
       "files": 2,
     },
     "70B": {
-      "args": {"dim": 8192, "multiple_of": 4096, "ffn_dim_multiplier": 1.3, "n_heads": 64, "n_kv_heads": 8, "n_layers": 80, "norm_eps": 1e-05, "vocab_size": VOCAB_SIZE},
+      "args": {"dim": 8192, "multiple_of": 4096, "ffn_dim_multiplier": 1.3, "n_heads": 64, "n_kv_heads": 8, "n_layers": 80, "norm_eps": 1e-05, "vocab_size": 32000},
       "files": 8,
     },
   },
+  "code": {
+    "7B": {
+      "args": {"dim": 4096, "n_layers": 32, "n_heads": 32, "multiple_of": 256, "ffn_dim_multiplier": 1.0, "norm_eps": 1e-5, "rope_theta": 1000000, "vocab_size": 32016},
+      "files": 1,
+    },
+    "7B-Python": {
+      "args": {"dim": 4096, "n_layers": 32, "n_heads": 32, "multiple_of": 256, "ffn_dim_multiplier": 1.0, "norm_eps": 1e-5, "rope_theta": 1000000, "vocab_size": 32000},
+      "files": 1,
+    },
+    "7B-Instruct": {
+      "args": {"dim": 4096, "n_layers": 32, "n_heads": 32, "multiple_of": 256, "ffn_dim_multiplier": 1.0, "norm_eps": 1e-5, "rope_theta": 1000000, "vocab_size": 32016},
+      "files": 1,
+    },
+    "13B": {
+      "args": {"dim": 5120, "n_layers": 40, "n_heads": 40, "multiple_of": 256, "ffn_dim_multiplier": 1.0, "norm_eps": 1e-5, "rope_theta": 1000000, "vocab_size": 32016},
+      "files": 2,
+    },
+    "13B-Python": {
+      "args": {"dim": 5120, "n_layers": 40, "n_heads": 40, "multiple_of": 256, "ffn_dim_multiplier": 1.0, "norm_eps": 1e-5, "rope_theta": 1000000, "vocab_size": 32000},
+      "files": 2,
+    },
+    "13B-Instruct": {
+      "args": {"dim": 5120, "n_layers": 40, "n_heads": 40, "multiple_of": 256, "ffn_dim_multiplier": 1.0, "norm_eps": 1e-5, "rope_theta": 1000000, "vocab_size": 32000},
+      "files": 2,
+    },
+    "34B": {
+      "args": {"dim": 8192, "n_layers": 48, "n_heads": 64, "n_kv_heads": 8, "multiple_of": 256, "ffn_dim_multiplier": 1.0, "norm_eps": 1e-5, "rope_theta": 1000000, "vocab_size": 32016},
+      "files": 4,
+    },
+    "34B-Python": {
+      "args": {"dim": 8192, "n_layers": 48, "n_heads": 64, "n_kv_heads": 8, "multiple_of": 256, "ffn_dim_multiplier": 1.0, "norm_eps": 1e-5, "rope_theta": 1000000, "vocab_size": 32000},
+      "files": 4,
+    },
+    "34B-Instruct": {
+      "args": {"dim": 8192, "n_layers": 48, "n_heads": 64, "n_kv_heads": 8, "multiple_of": 256, "ffn_dim_multiplier": 1.0, "norm_eps": 1e-5, "rope_theta": 1000000, "vocab_size": 32000},
+      "files": 4,
+    },
+  }
 }
 
 # **** helper functions ****
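With the generation keys now strings, LLaMA 1, LLaMA 2, and Code LLaMA share one table, and resolving a model is two plain dict lookups; a hypothetical usage sketch matching what LLaMa.build does below:

params = MODEL_PARAMS["code"]["7B-Python"]
model = Transformer(**params["args"])  # rope_theta=1000000 flows through as a keyword argument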
@@ -219,7 +256,7 @@ def concat_weights(models):
     disk_tensors = [model[name] for model in models]
     if len(disk_tensors) == 1 or len(disk_tensors[0].shape) == 1:
       return disk_tensors[0].to(device=Device.DEFAULT)
-    axis = 1 if name.startswith('tok_embeddings.') or name.endswith('.attention.wo.weight') or name.endswith('.feed_forward.w2.weight') else 0
+    axis = 1 if name.startswith("tok_embeddings.") or name.endswith(".attention.wo.weight") or name.endswith(".feed_forward.w2.weight") else 0
     lazy_tensors = [data.to(device=Device.DEFAULT) for data in disk_tensors]
     return lazy_tensors[0].cat(*lazy_tensors[1:], dim=axis)
   return {name: convert(name) for name in {name: None for model in models for name in model}}
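concat_weights stitches multi-file checkpoints back together: most weights were sharded along their output dimension and rejoin on axis 0, while tok_embeddings, attention.wo, and feed_forward.w2 were sharded along the input dimension and rejoin on axis 1. A toy NumPy illustration with hypothetical shapes:

import numpy as np

shard_a, shard_b = np.ones((4, 8)), np.zeros((4, 8))
print(np.concatenate([shard_a, shard_b], axis=0).shape)  # (8, 8): e.g. attention.wq shards
print(np.concatenate([shard_a, shard_b], axis=1).shape)  # (4, 16): e.g. attention.wo shards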
@@ -229,22 +266,22 @@ def load(fn:str):
     with open(fn) as fp: weight_map = json.load(fp)['weight_map']
     parts = {n: load(Path(fn).parent / Path(n).name) for n in set(weight_map.values())}
     return {k: parts[n][k] for k, n in weight_map.items()}
-  elif fn.endswith('.safetensors'):
+  elif fn.endswith(".safetensors"):
     return safe_load(fn)
   else:
     return torch_load(fn)
 
 def convert_from_huggingface(weights, model):
   keymap = {
-    'model.embed_tokens.weight': 'tok_embeddings.weight',
-    **{f'model.layers.{l}.input_layernorm.weight': f'layers.{l}.attention_norm.weight' for l in range(len(model.layers))},
-    **{f'model.layers.{l}.self_attn.{x}_proj.weight': f'layers.{l}.attention.w{x}.weight' for x in ['q', 'k', 'v', 'o'] for l in range(len(model.layers))},
-    **{f'model.layers.{l}.post_attention_layernorm.weight': f'layers.{l}.ffn_norm.weight' for l in range(len(model.layers))},
-    **{f'model.layers.{l}.mlp.{x}_proj.weight': f'layers.{l}.feed_forward.w{y}.weight' for x, y in {'gate': '1', 'down': '2', 'up': '3'}.items() for l in range(len(model.layers))},
-    'model.norm.weight': 'norm.weight',
-    'lm_head.weight': 'output.weight',
+    "model.embed_tokens.weight": "tok_embeddings.weight",
+    **{f"model.layers.{l}.input_layernorm.weight": f"layers.{l}.attention_norm.weight" for l in range(len(model.layers))},
+    **{f"model.layers.{l}.self_attn.{x}_proj.weight": f"layers.{l}.attention.w{x}.weight" for x in ["q", "k", "v", "o"] for l in range(len(model.layers))},
+    **{f"model.layers.{l}.post_attention_layernorm.weight": f"layers.{l}.ffn_norm.weight" for l in range(len(model.layers))},
+    **{f"model.layers.{l}.mlp.{x}_proj.weight": f"layers.{l}.feed_forward.w{y}.weight" for x, y in {"gate": "1", "down": "2", "up": "3"}.items() for l in range(len(model.layers))},
+    "model.norm.weight": "norm.weight",
+    "lm_head.weight": "output.weight",
   }
-  return {keymap[k]: v for k,v in weights.items() if '.rotary_emb.' not in k}
+  return {keymap[k]: v for k,v in weights.items() if ".rotary_emb." not in k}
 
 class AbsmaxQuantizedLinear:
   def __init__(self, in_features, out_features, bias=False):
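The keymap in convert_from_huggingface renames HuggingFace checkpoint entries to this file's naming scheme and drops the precomputed rotary tables. A toy round trip with a hypothetical two-entry weights dict:

toy_keymap = {"model.layers.0.self_attn.q_proj.weight": "layers.0.attention.wq.weight"}
toy_weights = {
  "model.layers.0.self_attn.q_proj.weight": "W",
  "model.layers.0.self_attn.rotary_emb.inv_freq": "dropped",  # filtered out, as above
}
print({toy_keymap[k]: v for k, v in toy_weights.items() if ".rotary_emb." not in k})
# {'layers.0.attention.wq.weight': 'W'}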
@@ -259,7 +296,7 @@ class AbsmaxQuantizedLinear:
   def quantize(tensors):
     new_tensors = {}
     for name,v in tensors.items():
-      if 'feed_forward' in name or ('attention.w') in name or name == 'output.weight':
+      if "feed_forward" in name or ("attention.w") in name or name == "output.weight":
         scale = v.abs().max(axis=1) / 127.0
         int8_weight = (v.T/scale).T.cast(dtype=dtypes.int8)
         new_tensors[name] = int8_weight
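quantize implements absmax int8 quantization: each selected matrix gets one float scale per output row, chosen so the row's largest magnitude maps to 127, and the weights are stored as int8. A self-contained NumPy sketch of the same round trip, with illustrative values only:

import numpy as np

w = np.array([[0.5, -1.0, 0.25], [2.0, 4.0, -8.0]], dtype=np.float32)
scale = np.abs(w).max(axis=1) / 127.0           # per-row scale, as in quantize above
int8_w = (w / scale[:, None]).astype(np.int8)   # absmax-scaled int8 weights
restored = int8_w.astype(np.float32) * scale[:, None]
print(np.abs(w - restored).max())  # reconstruction error stays below one scale step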
@@ -270,10 +307,10 @@ class AbsmaxQuantizedLinear:
 
 class LLaMa:
   @staticmethod
-  def build(model_path, tokenizer_path, model_gen=1, model_size="7B", quantize=False):
+  def build(model_path, tokenizer_path, model_gen="1", model_size="7B", quantize=False):
     from sentencepiece import SentencePieceProcessor
     sp_model = SentencePieceProcessor(model_file=str(tokenizer_path))
-    assert sp_model.vocab_size() == VOCAB_SIZE
+    assert sp_model.vocab_size() == MODEL_PARAMS[model_gen][model_size]["args"]["vocab_size"]
 
     params = MODEL_PARAMS[model_gen][model_size]
     model = Transformer(**params["args"], linear=AbsmaxQuantizedLinear) if quantize else Transformer(**params["args"])
@@ -282,7 +319,7 @@ class LLaMa:
       weights = concat_weights([load(filename) for filename in [f"{model_path}/consolidated.{i:02d}.pth" for i in range(params["files"])]])
     else:
       weights = load(str(model_path))
-      if 'model.embed_tokens.weight' in weights:
+      if "model.embed_tokens.weight" in weights:
         weights = convert_from_huggingface(weights, model)
 
     if quantize:
@@ -313,27 +350,81 @@ class LLaMa:
     return output
 
 # **** main code ****
+"""
+test:
+python3 examples/llama.py --temperature=0 --count=50 --prompt="Hello."
+output:
+Hello. I'm a 20 year old male. I'm a student at the University of Texas at Austin. I'm a sophomore majoring in Computer Science.
+
+test:
+python3 examples/llama.py --gen='2' --temperature=0 --count=50 --prompt="Hello."
+output:
+Hello. I'm a 20 year old girl who is looking for a good lay in Palm Coast. I don't care whether it's at your place or not, as long as it's clean.
+
+test:
+python3 examples/llama.py --gen="code" --temperature=0.2 --count=50 --prompt="\
+import argparse
+
+def main(string: str):
+  print(string)
+  print(string[::-1])
+
+if __name__ == "__main__":"
+output:
+  parser = argparse.ArgumentParser()
+  parser.add_argument('string', type=str, help='string to be reversed')
+  args = parser.parse_args()
+  main(args.string)
+
+test:
+python3 examples/llama.py --gen="code" --size="7B-Python" --temperature=0.2 --count=70 --prompt="def add_elements(arr,k):"
+output:
+  for i in range(len(arr)):
+    arr[i] += k
+  return arr
+
+
+arr = [1, 2, 3, 4, 5]
+k = 2
+print(add_elements(arr, k))
+
+test:
+python3 examples/llama.py --gen="code" --size="7B-Instruct" --temperature=0.2 --count=120 --prompt="write a function in c++ that adds three float numbers"
+output:
+\begin{code}
+#include<iostream>
+using namespace std;
+
+float add(float a, float b, float c)
+{
+    return a+b+c;
+}
+
+int main()
+{
+    float a, b, c;
+    cout<<"Enter three numbers: ";
+    cin>>a>>b>>c;
+    cout<<"The sum is: "<<add(a,b,c);
+    return 0;
+}
+\end{code}
+"""
 if __name__ == "__main__":
   Tensor.no_grad = True
   print(f"using {Device.DEFAULT} backend")
 
-  parser = argparse.ArgumentParser(description='Run LLaMA in tinygrad', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
-  # test: python3 examples/llama.py --prompt="Hello." --temperature=0
-  # Hello. I'm a 20 year old male. I'm a student at the University of Texas at Austin. I'm a sophomore majoring in Computer Science.
-  # test: python3 examples/llama.py --gen 2 --prompt="Hello." --temperature=0
-  # Hello. I'm a 20 year old girl who is looking for a good lay in Palm Coast. I don't care whether it's at your place or not, as long as it's clean.
-  parser.add_argument('--prompt', type=str, default=None, help="Phrase to start with. Without this, it goes into chatbot mode")
-  parser.add_argument('--count', type=int, default=1000, help="Max number of tokens to generate")
-  parser.add_argument('--personality', type=str, default="Stacy", help="Personality, can be Stacy, George, Gary, or Lexie")
-
-  parser.add_argument('--temperature', type=float, default=0.7, help="Temperature in the softmax")
-  parser.add_argument('--timing', action='store_true', help="Print timing per token")
-  parser.add_argument('--profile', action='store_true', help="Output profile data to out.prof")
-  parser.add_argument('--size', type=str, default="7B", help="Size of model to use [7B, 13B, 30B, 65B] for Gen 1, [7B, 13B, 70B] for Gen 2")
-  parser.add_argument('--gen', type=int, default="1", help="Generation of the model to use [1, 2]")
-  parser.add_argument('--quantize', action='store_true', help="Quantize the weights to int8 in memory")
-  parser.add_argument('--model', type=Path, default=None, help="Folder with the original weights to load, or single .index.json, .safetensors or .bin file")
+  parser = argparse.ArgumentParser(description="Run LLaMA in tinygrad", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+  parser.add_argument("--prompt", type=str, default=None, help="Phrase to start with. Without this, it goes into chatbot mode")
+  parser.add_argument("--count", type=int, default=1000, help="Max number of tokens to generate")
+  parser.add_argument("--personality", type=str, default="Stacy", help="Personality, can be Stacy, George, Gary, or Lexie")
+  parser.add_argument("--temperature", type=float, default=0.7, help="Temperature in the softmax")
+  parser.add_argument("--timing", action="store_true", help="Print timing per token")
+  parser.add_argument("--profile", action="store_true", help="Output profile data to out.prof")
+  parser.add_argument("--size", type=str, default="7B", help="Size of model to use [7B, 13B, 30B, 65B] for Gen 1, [7B, 13B, 70B] for Gen 2, [7B, 13B, 34B] for Code LLaMA")
+  parser.add_argument("--gen", default="1", help="Generation of the model to use ['1', '2', 'code']")
+  parser.add_argument("--quantize", action="store_true", help="Quantize the weights to int8 in memory")
+  parser.add_argument("--model", type=Path, default=None, help="Folder with the original weights to load, or single .index.json, .safetensors or .bin file")
 
   args = parser.parse_args()
   chatbot = args.prompt == None
@@ -427,8 +518,7 @@ After you are done speaking, output [EOS]. You are not Chad.
 
   # *** prompt engineers stop here ****
 
-
-  LLAMA_SUFFIX = {1: "", 2: "-2"}[args.gen]
+  LLAMA_SUFFIX = {"1": "", "2": "-2", "code": "-code"}[args.gen]
   MODEL_PATH = args.model or Path(__file__).parents[1] / f"weights/LLaMA{LLAMA_SUFFIX}/{args.size}"
   TOKENIZER_PATH = (MODEL_PATH if MODEL_PATH.is_dir() else MODEL_PATH.parent) / "tokenizer.model"
   print(f"using LLaMA{LLAMA_SUFFIX}-{args.size} model")
@@ -502,4 +592,4 @@ After you are done speaking, output [EOS]. You are not Chad.
   if args.profile:
     profiler.disable()
     stats = pstats.Stats(profiler)
-    stats.dump_stats('out.prof')
+    stats.dump_stats("out.prof")
@@ -28,7 +28,7 @@ class TestLLaMASpeed(unittest.TestCase):
     Device[Device.DEFAULT].buffer = RawFakeBuffer
 
     print("testing llama python run time")
-    model = Transformer(**MODEL_PARAMS[1]["7B"]["args"])
+    model = Transformer(**MODEL_PARAMS["1"]["7B"]["args"])
     print("built model")
     # assign fake tensors to the values
     for v in get_state_dict(model).values(): v.assign(Tensor.empty(*v.shape, dtype=v.dtype))