* add codellama with pre-downloaded weights

* add rope_theta, fix param

* fix test

* add 7B-Python

* add 7B-Instruct

* replace single quotes with double

---------

Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
Yixiang Gao 2023-09-02 10:45:12 -05:00 committed by GitHub
parent a2745819f6
commit 66a6bbd029
2 changed files with 137 additions and 47 deletions

examples/llama.py

@@ -133,13 +133,14 @@ class TransformerBlock:
return (h + self.feed_forward(self.ffn_norm(h))).realize(), cache_k.realize(), cache_v.realize()
class Transformer:
def __init__(self, dim, multiple_of, n_heads, n_layers, norm_eps, vocab_size, linear=Linear, max_batch_size=32, max_seq_len=1024, ffn_dim_multiplier=None, n_kv_heads=None):
def __init__(self, dim, multiple_of, n_heads, n_layers, norm_eps, vocab_size, linear=Linear, max_batch_size=32, max_seq_len=1024, ffn_dim_multiplier=None, n_kv_heads=None, rope_theta=10000):
self.layers = [TransformerBlock(dim, multiple_of, n_heads, n_kv_heads, norm_eps, linear, ffn_dim_multiplier) for _ in range(n_layers)]
self.kv_caches = [(None, None) for _ in range(n_layers)]
self.norm = RMSNorm(dim, norm_eps)
self.tok_embeddings = Embedding(vocab_size, dim)
self.output = linear(dim, vocab_size, bias=False)
self.freqs_cis = Tensor(precompute_freqs_cis(dim // n_heads, max_seq_len * 2))
self.freqs_cis = Tensor(precompute_freqs_cis(dim // n_heads, max_seq_len * 2, rope_theta))
self.norm_output = lambda x: self.output(self.norm(x))
self.tok_embeddings_jitted = TinyJit(lambda x: self.tok_embeddings(x).realize())
self.postprocess_jitted = TinyJit(self.postprocess)
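For reference, rope_theta feeds the RoPE frequency table. A minimal sketch of what precompute_freqs_cis plausibly computes under the standard RoPE formulation (the actual tinygrad helper may differ in dtype and layout); raising theta from 10000 to 1000000 stretches the rotation wavelengths, which is how Code Llama copes with long contexts:

import numpy as np

def precompute_freqs_cis_sketch(dim: int, end: int, theta: float = 10000.0):
  freqs = 1.0 / (theta ** (np.arange(0, dim, 2)[: dim // 2] / dim))  # one frequency per dim pair
  angles = np.outer(np.arange(end), freqs)  # (end, dim//2) angles: position * frequency
  return np.exp(1j * angles)  # complex rotations e^{i * pos * freq}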
@@ -176,41 +177,77 @@ class Transformer:
return self.postprocess(h, temperature)
# **** files and arguments ****
VOCAB_SIZE = 32000
MODEL_PARAMS = {
1: {
"1": {
"7B": {
"args": {"dim": 4096, "multiple_of": 256, "n_heads": 32, "n_layers": 32, "norm_eps": 1e-06, "vocab_size": VOCAB_SIZE},
"args": {"dim": 4096, "multiple_of": 256, "n_heads": 32, "n_layers": 32, "norm_eps": 1e-06, "vocab_size": 32000},
"files": 1,
},
"13B": {
"args": {"dim": 5120, "multiple_of": 256, "n_heads": 40, "n_layers": 40, "norm_eps": 1e-06, "vocab_size": VOCAB_SIZE},
"args": {"dim": 5120, "multiple_of": 256, "n_heads": 40, "n_layers": 40, "norm_eps": 1e-06, "vocab_size": 32000},
"files": 2,
},
"30B": {
"args": {"dim": 6656, "multiple_of": 256, "n_heads": 52, "n_layers": 60, "norm_eps": 1e-06, "vocab_size": VOCAB_SIZE},
"args": {"dim": 6656, "multiple_of": 256, "n_heads": 52, "n_layers": 60, "norm_eps": 1e-06, "vocab_size": 32000},
"files": 4,
},
"65B": {
"args": {"dim": 8192, "multiple_of": 256, "n_heads": 64, "n_layers": 80, "norm_eps": 1e-05, "vocab_size": VOCAB_SIZE},
"args": {"dim": 8192, "multiple_of": 256, "n_heads": 64, "n_layers": 80, "norm_eps": 1e-05, "vocab_size": 32000},
"files": 8,
},
},
2: {
"2": {
"7B": {
"args": {"dim": 4096, "multiple_of": 256, "n_heads": 32, "n_layers": 32, "norm_eps": 1e-05, "vocab_size": VOCAB_SIZE},
"args": {"dim": 4096, "multiple_of": 256, "n_heads": 32, "n_layers": 32, "norm_eps": 1e-05, "vocab_size": 32000},
"files": 1,
},
"13B": {
"args": {"dim": 5120, "multiple_of": 256, "n_heads": 40, "n_layers": 40, "norm_eps": 1e-05, "vocab_size": VOCAB_SIZE},
"args": {"dim": 5120, "multiple_of": 256, "n_heads": 40, "n_layers": 40, "norm_eps": 1e-05, "vocab_size": 32000},
"files": 2,
},
"70B": {
"args": {"dim": 8192, "multiple_of": 4096, "ffn_dim_multiplier": 1.3, "n_heads": 64, "n_kv_heads": 8, "n_layers": 80, "norm_eps": 1e-05, "vocab_size": VOCAB_SIZE},
"args": {"dim": 8192, "multiple_of": 4096, "ffn_dim_multiplier": 1.3, "n_heads": 64, "n_kv_heads": 8, "n_layers": 80, "norm_eps": 1e-05, "vocab_size": 32000},
"files": 8,
},
},
"code": {
"7B": {
"args": {"dim": 4096, "n_layers": 32, "n_heads": 32, "multiple_of": 256, "ffn_dim_multiplier": 1.0, "norm_eps": 1e-5, "rope_theta": 1000000, "vocab_size": 32016},
"files": 1,
},
"7B-Python": {
"args": {"dim": 4096, "n_layers": 32, "n_heads": 32, "multiple_of": 256, "ffn_dim_multiplier": 1.0, "norm_eps": 1e-5, "rope_theta": 1000000, "vocab_size": 32000},
"files": 1,
},
"7B-Instruct": {
"args": {"dim": 4096, "n_layers": 32, "n_heads": 32, "multiple_of": 256, "ffn_dim_multiplier": 1.0, "norm_eps": 1e-5, "rope_theta": 1000000, "vocab_size": 32016},
"files": 1,
},
"13B": {
"args": {"dim": 5120, "n_layers": 40, "n_heads": 40, "multiple_of": 256, "ffn_dim_multiplier": 1.0, "norm_eps": 1e-5, "rope_theta": 1000000, "vocab_size": 32016},
"files": 2,
},
"13B-Python": {
"args": {"dim": 5120, "n_layers": 40, "n_heads": 40, "multiple_of": 256, "ffn_dim_multiplier": 1.0, "norm_eps": 1e-5, "rope_theta": 1000000, "vocab_size": 32000},
"files": 2,
},
"13B-Instruct": {
"args": {"dim": 5120, "n_layers": 40, "n_headvocab_sizes": 40, "multiple_of": 256, "ffn_dim_multiplier": 1.0, "norm_eps": 1e-5, "rope_theta": 1000000, "vocab_size": 32000},
"files": 2,
},
"34B": {
"args": {"dim": 8192, "n_layers": 48, "n_heads": 64, "n_kv_heads": 8, "multiple_of": 256, "ffn_dim_multiplier": 1.0, "norm_eps": 1e-5, "rope_theta": 1000000, "vocab_size": 32016},
"files": 4,
},
"34B-Python": {
"args": {"dim": 8192, "n_layers": 48, "n_heads": 64, "n_kv_heads": 8, "multiple_of": 256, "ffn_dim_multiplier": 1.0, "norm_eps": 1e-5, "rope_theta": 1000000, "vocab_size": 32000},
"files": 4,
},
"34B-Instruct": {
"args": {"dim": 8192, "n_layers": 48, "n_heads": 64, "n_kv_heads": 8, "multiple_of": 256, "ffn_dim_multiplier": 1.0, "norm_eps": 1e-5, "rope_theta": 1000000, "vocab_size": 32000},
"files": 4,
},
}
}
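A config is looked up by generation key, then size key, mirroring what LLaMa.build does further down; a minimal usage sketch:

params = MODEL_PARAMS["code"]["7B-Instruct"]  # gen "code", size "7B-Instruct"
model = Transformer(**params["args"])  # rope_theta=1000000 is threaded through to freqs_cis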
# **** helper functions ****
@@ -219,7 +256,7 @@ def concat_weights(models):
disk_tensors = [model[name] for model in models]
if len(disk_tensors) == 1 or len(disk_tensors[0].shape) == 1:
return disk_tensors[0].to(device=Device.DEFAULT)
axis = 1 if name.startswith('tok_embeddings.') or name.endswith('.attention.wo.weight') or name.endswith('.feed_forward.w2.weight') else 0
axis = 1 if name.startswith("tok_embeddings.") or name.endswith(".attention.wo.weight") or name.endswith(".feed_forward.w2.weight") else 0
lazy_tensors = [data.to(device=Device.DEFAULT) for data in disk_tensors]
return lazy_tensors[0].cat(*lazy_tensors[1:], dim=axis)
return {name: convert(name) for name in {name: None for model in models for name in model}}
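The axis choice reflects how the original checkpoints shard each matrix; a toy illustration, assuming row-parallel shards rejoin on axis 0 while the column-parallel ones (tok_embeddings, wo, w2) rejoin on axis 1:

import numpy as np
shard_a, shard_b = np.zeros((2, 2)), np.ones((2, 2))
np.concatenate([shard_a, shard_b], axis=0).shape  # (4, 2): row-parallel weight
np.concatenate([shard_a, shard_b], axis=1).shape  # (2, 4): column-parallel weight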
@@ -229,22 +266,22 @@ def load(fn:str):
with open(fn) as fp: weight_map = json.load(fp)['weight_map']
parts = {n: load(Path(fn).parent / Path(n).name) for n in set(weight_map.values())}
return {k: parts[n][k] for k, n in weight_map.items()}
elif fn.endswith('.safetensors'):
elif fn.endswith(".safetensors"):
return safe_load(fn)
else:
return torch_load(fn)
def convert_from_huggingface(weights, model):
keymap = {
'model.embed_tokens.weight': 'tok_embeddings.weight',
**{f'model.layers.{l}.input_layernorm.weight': f'layers.{l}.attention_norm.weight' for l in range(len(model.layers))},
**{f'model.layers.{l}.self_attn.{x}_proj.weight': f'layers.{l}.attention.w{x}.weight' for x in ['q', 'k', 'v', 'o'] for l in range(len(model.layers))},
**{f'model.layers.{l}.post_attention_layernorm.weight': f'layers.{l}.ffn_norm.weight' for l in range(len(model.layers))},
**{f'model.layers.{l}.mlp.{x}_proj.weight': f'layers.{l}.feed_forward.w{y}.weight' for x, y in {'gate': '1', 'down': '2', 'up': '3'}.items() for l in range(len(model.layers))},
'model.norm.weight': 'norm.weight',
'lm_head.weight': 'output.weight',
"model.embed_tokens.weight": "tok_embeddings.weight",
**{f"model.layers.{l}.input_layernorm.weight": f"layers.{l}.attention_norm.weight" for l in range(len(model.layers))},
**{f"model.layers.{l}.self_attn.{x}_proj.weight": f"layers.{l}.attention.w{x}.weight" for x in ["q", "k", "v", "o"] for l in range(len(model.layers))},
**{f"model.layers.{l}.post_attention_layernorm.weight": f"layers.{l}.ffn_norm.weight" for l in range(len(model.layers))},
**{f"model.layers.{l}.mlp.{x}_proj.weight": f"layers.{l}.feed_forward.w{y}.weight" for x, y in {"gate": "1", "down": "2", "up": "3"}.items() for l in range(len(model.layers))},
"model.norm.weight": "norm.weight",
"lm_head.weight": "output.weight",
}
return {keymap[k]: v for k,v in weights.items() if '.rotary_emb.' not in k}
return {keymap[k]: v for k,v in weights.items() if ".rotary_emb." not in k}
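For orientation, two of the translations the keymap performs:

# keymap["model.layers.0.self_attn.q_proj.weight"] -> "layers.0.attention.wq.weight"
# keymap["model.layers.0.mlp.down_proj.weight"]    -> "layers.0.feed_forward.w2.weight"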
class AbsmaxQuantizedLinear:
def __init__(self, in_features, out_features, bias=False):
@@ -259,7 +296,7 @@ class AbsmaxQuantizedLinear:
def quantize(tensors):
new_tensors = {}
for name,v in tensors.items():
if 'feed_forward' in name or ('attention.w') in name or name == 'output.weight':
if "feed_forward" in name or ("attention.w") in name or name == "output.weight":
scale = v.abs().max(axis=1) / 127.0
int8_weight = (v.T/scale).T.cast(dtype=dtypes.int8)
new_tensors[name] = int8_weight
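The scheme is plain per-row absmax int8 quantization; a self-contained numpy sketch of the round trip (note the tinygrad code casts directly to int8, so values truncate rather than round):

import numpy as np
w = np.random.randn(4, 8).astype(np.float32)
scale = np.abs(w).max(axis=1) / 127.0  # one scale per output row
int8_w = (w.T / scale).T.astype(np.int8)  # entries now lie in [-127, 127]
restored = (int8_w.T * scale).T  # what the forward pass effectively sees
assert np.abs(w - restored).max() <= scale.max()  # error bounded by one quantization step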
@@ -270,10 +307,10 @@ class AbsmaxQuantizedLinear:
class LLaMa:
@staticmethod
def build(model_path, tokenizer_path, model_gen=1, model_size="7B", quantize=False):
def build(model_path, tokenizer_path, model_gen="1", model_size="7B", quantize=False):
from sentencepiece import SentencePieceProcessor
sp_model = SentencePieceProcessor(model_file=str(tokenizer_path))
assert sp_model.vocab_size() == VOCAB_SIZE
assert sp_model.vocab_size() == MODEL_PARAMS[model_gen][model_size]["args"]["vocab_size"]
params = MODEL_PARAMS[model_gen][model_size]
model = Transformer(**params["args"], linear=AbsmaxQuantizedLinear) if quantize else Transformer(**params["args"])
@@ -282,7 +319,7 @@ class LLaMa:
weights = concat_weights([load(filename) for filename in [f"{model_path}/consolidated.{i:02d}.pth" for i in range(params["files"])]])
else:
weights = load(str(model_path))
if 'model.embed_tokens.weight' in weights:
if "model.embed_tokens.weight" in weights:
weights = convert_from_huggingface(weights, model)
if quantize:
@@ -313,27 +350,81 @@ class LLaMa:
return output
# **** main code ****
"""
test:
python3 examples/llama.py --temperature=0 --count=50 --prompt="Hello."
output:
Hello. I'm a 20 year old male. I'm a student at the University of Texas at Austin. I'm a sophomore majoring in Computer Science.
test:
python3 examples/llama.py --gen='2' --temperature=0 --count=50 --prompt="Hello."
output:
Hello. I'm a 20 year old girl who is looking for a good lay in Palm Coast. I don't care whether it's at your place or not, as long as it's clean.
test:
python3 examples/llama.py --gen="code" --temperature=0.2 --count=50 --prompt="\
import argparse
def main(string: str):
print(string)
print(string[::-1])
if __name__ == "__main__":"
output:
parser = argparse.ArgumentParser()
parser.add_argument('string', type=str, help='string to be reversed')
args = parser.parse_args()
main(args.string)
test:
python3 examples/llama.py --gen="code" --size="7B-Python" --temperature=0.2 --count=70 --prompt="def add_elements(arr,k):"
output:
for i in range(len(arr)):
arr[i] += k
return arr
arr = [1, 2, 3, 4, 5]
k = 2
print(add_elements(arr, k))
test:
python3 examples/llama.py --gen="code" --size="7B-Instruct" --temperature=0.2 --count=120 --prompt="write a function in c++ that adds three float numbers"
output:
\begin{code}
#include<iostream>
using namespace std;
float add(float a, float b, float c)
{
return a+b+c;
}
int main()
{
float a, b, c;
cout<<"Enter three numbers: ";
cin>>a>>b>>c;
cout<<"The sum is: "<<add(a,b,c);
return 0;
}
\end{code}
"""
if __name__ == "__main__":
Tensor.no_grad = True
print(f"using {Device.DEFAULT} backend")
parser = argparse.ArgumentParser(description='Run LLaMA in tinygrad', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
# test: python3 examples/llama.py --prompt="Hello." --temperature=0
# Hello. I'm a 20 year old male. I'm a student at the University of Texas at Austin. I'm a sophomore majoring in Computer Science.
# test: python3 examples/llama.py --gen 2 --prompt="Hello." --temperature=0
# Hello. I'm a 20 year old girl who is looking for a good lay in Palm Coast. I don't care whether it's at your place or not, as long as it's clean.
parser.add_argument('--prompt', type=str, default=None, help="Phrase to start with. Without this, it goes into chatbot mode")
parser.add_argument('--count', type=int, default=1000, help="Max number of tokens to generate")
parser.add_argument('--personality', type=str, default="Stacy", help="Personality, can be Stacy, George, Gary, or Lexie")
parser.add_argument('--temperature', type=float, default=0.7, help="Temperature in the softmax")
parser.add_argument('--timing', action='store_true', help="Print timing per token")
parser.add_argument('--profile', action='store_true', help="Output profile data to out.prof")
parser.add_argument('--size', type=str, default="7B", help="Size of model to use [7B, 13B, 30B, 65B] for Gen 1, [7B, 13B, 70B] for Gen 2")
parser.add_argument('--gen', type=int, default="1", help="Generation of the model to use [1, 2]")
parser.add_argument('--quantize', action='store_true', help="Quantize the weights to int8 in memory")
parser.add_argument('--model', type=Path, default=None, help="Folder with the original weights to load, or single .index.json, .safetensors or .bin file")
parser = argparse.ArgumentParser(description="Run LLaMA in tinygrad", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("--prompt", type=str, default=None, help="Phrase to start with. Without this, it goes into chatbot mode")
parser.add_argument("--count", type=int, default=1000, help="Max number of tokens to generate")
parser.add_argument("--personality", type=str, default="Stacy", help="Personality, can be Stacy, George, Gary, or Lexie")
parser.add_argument("--temperature", type=float, default=0.7, help="Temperature in the softmax")
parser.add_argument("--timing", action="store_true", help="Print timing per token")
parser.add_argument("--profile", action="store_true", help="Output profile data to out.prof")
parser.add_argument("--size", type=str, default="7B", help="Size of model to use [7B, 13B, 30B, 65B] for Gen 1, [7B, 13B, 70B] for Gen 2, [7B, 13B, 34B] for Code LLaMA")
parser.add_argument("--gen", default="1", help="Generation of the model to use ['1', '2', 'code']")
parser.add_argument("--quantize", action="store_true", help="Quantize the weights to int8 in memory")
parser.add_argument("--model", type=Path, default=None, help="Folder with the original weights to load, or single .index.json, .safetensors or .bin file")
args = parser.parse_args()
chatbot = args.prompt == None
@@ -427,8 +518,7 @@ After you are done speaking, output [EOS]. You are not Chad.
# *** prompt engineers stop here ****
LLAMA_SUFFIX = {1: "", 2: "-2"}[args.gen]
LLAMA_SUFFIX = {"1": "", "2": "-2", "code": "-code"}[args.gen]
MODEL_PATH = args.model or Path(__file__).parents[1] / f"weights/LLaMA{LLAMA_SUFFIX}/{args.size}"
TOKENIZER_PATH = (MODEL_PATH if MODEL_PATH.is_dir() else MODEL_PATH.parent) / "tokenizer.model"
print(f"using LLaMA{LLAMA_SUFFIX}-{args.size} model")
@@ -502,4 +592,4 @@ After you are done speaking, output [EOS]. You are not Chad.
if args.profile:
profiler.disable()
stats = pstats.Stats(profiler)
stats.dump_stats('out.prof')
stats.dump_stats("out.prof")
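The dumped profile can be inspected afterwards with the standard pstats API, e.g.:

import pstats
pstats.Stats("out.prof").sort_stats("cumtime").print_stats(20)  # top 20 by cumulative time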


@@ -28,7 +28,7 @@ class TestLLaMASpeed(unittest.TestCase):
Device[Device.DEFAULT].buffer = RawFakeBuffer
print("testing llama python run time")
model = Transformer(**MODEL_PARAMS[1]["7B"]["args"])
model = Transformer(**MODEL_PARAMS["1"]["7B"]["args"])
print("built model")
# assign fake tensors to the values
for v in get_state_dict(model).values(): v.assign(Tensor.empty(*v.shape, dtype=v.dtype))