Add LLaMA-2 support (#1284)

Co-authored-by: wozeparrot <wozeparrot@gmail.com>
Pavol Rusnak 2023-07-24 23:12:02 +02:00 committed by GitHub
parent d89fb729e5
commit cd60b8561c
4 changed files with 72 additions and 49 deletions

examples/llama.py

@@ -92,9 +92,12 @@ class Attention:
return self.wo(output)
class FeedForward:
-def __init__(self, dim, hidden_dim, multiple_of):
+def __init__(self, dim, hidden_dim, multiple_of, ffn_dim_multiplier=None):
# TODO: what is this?
hidden_dim = int(2 * hidden_dim / 3)
+# custom dim factor multiplier
+if ffn_dim_multiplier is not None:
+hidden_dim = int(ffn_dim_multiplier * hidden_dim)
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
self.w1 = Linear(dim, hidden_dim, bias=False)
self.w2 = Linear(hidden_dim, dim, bias=False)
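The hidden_dim arithmetic above follows the upstream LLaMA FFN sizing: the caller passes 4*dim, which is scaled by 2/3, optionally scaled again by the new ffn_dim_multiplier (currently only appearing in the disabled 70B config further down), and finally rounded up to a multiple of multiple_of. A standalone sketch of the same computation, using values from the configs in this file (the helper is illustrative, not part of the commit):

def ffn_hidden_dim(dim, multiple_of, ffn_dim_multiplier=None):
  # mirrors FeedForward.__init__ above, with hidden_dim passed in as 4*dim
  hidden_dim = int(2 * (4 * dim) / 3)
  if ffn_dim_multiplier is not None:
    hidden_dim = int(ffn_dim_multiplier * hidden_dim)
  return multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

print(ffn_hidden_dim(4096, 256))        # 11008 -- the 7B FFN width
print(ffn_hidden_dim(8192, 4096, 1.3))  # 28672 -- the (disabled) 70B config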
@@ -104,9 +107,9 @@ class FeedForward:
return self.w2(self.w1(x).silu() * self.w3(x))
class TransformerBlock:
-def __init__(self, dim, multiple_of, n_heads, norm_eps):
+def __init__(self, dim, multiple_of, n_heads, norm_eps, ffn_dim_multiplier=None):
self.attention = Attention(dim, n_heads)
-self.feed_forward = FeedForward(dim, 4*dim, multiple_of)
+self.feed_forward = FeedForward(dim, 4*dim, multiple_of, ffn_dim_multiplier)
self.attention_norm = RMSNorm(dim, norm_eps)
self.ffn_norm = RMSNorm(dim, norm_eps)
if getenv("JIT"):
@@ -130,8 +133,8 @@ class TransformerBlock:
return self._post(x, output)
class Transformer:
-def __init__(self, dim, multiple_of, n_heads, n_layers, norm_eps, vocab_size, max_batch_size=32, max_seq_len=1024):
-self.layers = [TransformerBlock(dim, multiple_of, n_heads, norm_eps) for _ in range(n_layers)]
+def __init__(self, dim, multiple_of, n_heads, n_layers, norm_eps, vocab_size, max_batch_size=32, max_seq_len=1024, ffn_dim_multiplier=None):
+self.layers = [TransformerBlock(dim, multiple_of, n_heads, norm_eps, ffn_dim_multiplier) for _ in range(n_layers)]
self.norm = RMSNorm(dim, norm_eps)
self.tok_embeddings = Embedding(vocab_size, dim)
self.output = Linear(dim, vocab_size, bias=False)
@@ -150,23 +153,42 @@ class Transformer:
# **** files and arguments ****
-WEIGHTS_DIR = Path(__file__).parent.parent / "weights/LLaMA/"
-TOKENIZER_FILENAME = WEIGHTS_DIR / "tokenizer.model"
VOCAB_SIZE = 32000
-args_small = {"dim": 512, "multiple_of": 256, "n_heads": 8, "n_layers": 8, "norm_eps": 1e-05, "vocab_size": VOCAB_SIZE}
-args_7B = {"dim": 4096, "multiple_of": 256, "n_heads": 32, "n_layers": 32, "norm_eps": 1e-06, "vocab_size": VOCAB_SIZE}
-WEIGHTS_7B_FILENAME = WEIGHTS_DIR / "7B/consolidated.00.pth"
-args_13B = {"dim": 5120, "multiple_of": 256, "n_heads": 40, "n_layers": 40, "norm_eps": 1e-06, "vocab_size": VOCAB_SIZE}
-WEIGHTS_13B_FILENAMES = [WEIGHTS_DIR / "13B/consolidated.00.pth", WEIGHTS_DIR / "13B/consolidated.01.pth"]
-args_30B = {"dim": 6656, "multiple_of": 256, "n_heads": 52, "n_layers": 60, "norm_eps": 1e-06, "vocab_size": VOCAB_SIZE}
-WEIGHTS_30B_FILENAMES = [WEIGHTS_DIR / "30B/consolidated.00.pth", WEIGHTS_DIR / "30B/consolidated.01.pth", WEIGHTS_DIR / "30B/consolidated.02.pth", WEIGHTS_DIR / "30B/consolidated.03.pth"]
-args_65B = {"dim": 8192, "multiple_of": 256, "n_heads": 64, "n_layers": 80, "norm_eps": 1e-05, "vocab_size": VOCAB_SIZE}
-WEIGHTS_65B_FILENAMES = [WEIGHTS_DIR / "65B/consolidated.00.pth", WEIGHTS_DIR / "65B/consolidated.01.pth", WEIGHTS_DIR / "65B/consolidated.02.pth", WEIGHTS_DIR / "65B/consolidated.03.pth", WEIGHTS_DIR / "65B/consolidated.04.pth", WEIGHTS_DIR / "65B/consolidated.05.pth", WEIGHTS_DIR / "65B/consolidated.06.pth", WEIGHTS_DIR / "65B/consolidated.07.pth"]
+MODEL_PARAMS = {
+1: {
+"7B": {
+"args": {"dim": 4096, "multiple_of": 256, "n_heads": 32, "n_layers": 32, "norm_eps": 1e-06, "vocab_size": VOCAB_SIZE},
+"files": 1,
+},
+"13B": {
+"args": {"dim": 5120, "multiple_of": 256, "n_heads": 40, "n_layers": 40, "norm_eps": 1e-06, "vocab_size": VOCAB_SIZE},
+"files": 2,
+},
+"30B": {
+"args": {"dim": 6656, "multiple_of": 256, "n_heads": 52, "n_layers": 60, "norm_eps": 1e-06, "vocab_size": VOCAB_SIZE},
+"files": 4,
+},
+"65B": {
+"args": {"dim": 8192, "multiple_of": 256, "n_heads": 64, "n_layers": 80, "norm_eps": 1e-05, "vocab_size": VOCAB_SIZE},
+"files": 8,
+},
+},
+2: {
+"7B": {
+"args": {"dim": 4096, "multiple_of": 256, "n_heads": 32, "n_layers": 32, "norm_eps": 1e-05, "vocab_size": VOCAB_SIZE},
+"files": 1,
+},
+"13B": {
+"args": {"dim": 5120, "multiple_of": 256, "n_heads": 40, "n_layers": 40, "norm_eps": 1e-05, "vocab_size": VOCAB_SIZE},
+"files": 2,
+},
+# # 70B is disabled because we do not yet implement n_kv_heads argument
+# "70B": {
+# "args": {"dim": 8192, "multiple_of": 4096, "ffn_dim_multiplier": 1.3, "n_heads": 64, "n_kv_heads": 8, "n_layers": 80, "norm_eps": 1e-05, "vocab_size": VOCAB_SIZE},
+# "files": 8,
+# },
+},
+}
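MODEL_PARAMS is keyed first by model generation (1 for LLaMA, 2 for LLaMA-2) and then by size; "args" holds the keyword arguments for Transformer and "files" the number of consolidated.*.pth shards. At a given size, the only numeric difference between the two generations is norm_eps (1e-06 vs 1e-05). A minimal illustration of how the __main__ changes below consume the table (paths assume the weights/LLaMA-2/ layout introduced in this commit):

params = MODEL_PARAMS[2]["13B"]   # LLaMA-2 13B
model_args = params["args"]       # kwargs for Transformer(**model_args)
n_shards = params["files"]        # 2 -> weights/LLaMA-2/13B/consolidated.00.pth and .01.pth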
# **** helper functions ****
def sample(logits, temperature):
@@ -199,11 +221,8 @@ def concat_weights(models):
if __name__ == "__main__":
Tensor.no_grad = True
print(f"using {Device.DEFAULT} backend")
-from sentencepiece import SentencePieceProcessor
-sp_model = SentencePieceProcessor(model_file=str(TOKENIZER_FILENAME))
-assert sp_model.vocab_size() == VOCAB_SIZE
-parser = argparse.ArgumentParser(description='Run LLaMA 7B in tinygrad', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+parser = argparse.ArgumentParser(description='Run LLaMA in tinygrad', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
# test: python3 examples/llama.py --prompt="Hello." --temperature=0
# Hello. I'm a 20 year old male. I'm a student at the University of Texas at Austin. I'm a sophomore majoring in Computer Science.
parser.add_argument('--prompt', type=str, default=None, help="Phrase to start with. Without this, it goes into chatbot mode")
@@ -213,31 +232,29 @@ if __name__ == "__main__":
parser.add_argument('--temperature', type=float, default=0.7, help="Temperature in the softmax")
parser.add_argument('--timing', action='store_true', help="Print timing per token")
parser.add_argument('--profile', action='store_true', help="Output profile data to out.prof")
-parser.add_argument('--size', type=str, default="7B", help="Size of model to use [7B, 13B, 30B, 65B]")
+parser.add_argument('--size', type=str, default="7B", help="Size of model to use [7B, 13B, 30B, 65B] for Gen 1, [7B, 13B] for Gen 2")
+parser.add_argument('--gen', type=int, default="1", help="Generation of the model to use [1, 2]")
args = parser.parse_args()
chatbot = args.prompt == None
+LLAMA_SUFFIX = {1: "", 2: "-2"}[args.gen]
+WEIGHTS_DIR = Path(__file__).parent.parent / f"weights/LLaMA{LLAMA_SUFFIX}/"
+TOKENIZER_FILENAME = WEIGHTS_DIR / "tokenizer.model"
+from sentencepiece import SentencePieceProcessor
+sp_model = SentencePieceProcessor(model_file=str(TOKENIZER_FILENAME))
+assert sp_model.vocab_size() == VOCAB_SIZE
from tinygrad.state import torch_load, load_state_dict
if args.size == "65B":
print("using 65B model")
model = Transformer(**args_65B)
weights = [torch_load(filename) for filename in WEIGHTS_65B_FILENAMES]
load_state_dict(model, concat_weights(weights), strict=False)
elif args.size == "30B":
print("using 30B model")
model = Transformer(**args_30B)
weights = [torch_load(filename) for filename in WEIGHTS_30B_FILENAMES]
load_state_dict(model, concat_weights(weights), strict=False)
elif args.size == "13B":
print("using 13B model")
model = Transformer(**args_13B)
weights = [torch_load(filename) for filename in WEIGHTS_13B_FILENAMES]
load_state_dict(model, concat_weights(weights), strict=False)
print(f"using LLaMA{LLAMA_SUFFIX}-{args.size} model")
params = MODEL_PARAMS[args.gen][args.size]
model = Transformer(**params["args"])
weights = [torch_load(WEIGHTS_DIR / f"{args.size}/consolidated.{i:02d}.pth") for i in range(params["files"])]
if len(weights) == 1:
load_state_dict(model, weights[0], strict=False)
else:
print("using 7B model")
model = Transformer(**args_7B)
load_state_dict(model, torch_load(WEIGHTS_7B_FILENAME), strict=False)
load_state_dict(model, concat_weights(weights), strict=False)
# *** prompt engineers work here ****
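The per-size if/elif chain is gone: the script now looks up MODEL_PARAMS[args.gen][args.size], loads params["files"] checkpoint shards, and either loads the single state dict directly or merges the shards with concat_weights (defined earlier in examples/llama.py, outside this diff). Example invocations, assuming gen-1 weights under weights/LLaMA/<size>/ and gen-2 weights under weights/LLaMA-2/<size>/ (the first line is the existing test prompt from the comment above):

python3 examples/llama.py --prompt="Hello." --temperature=0
python3 examples/llama.py --gen 2 --size 7B --prompt="Hello."
python3 examples/llama.py --gen 2 --size 13B --prompt="Hello."  # two shards, merged by concat_weights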

test/external/external_test_speed_llama.py

@@ -1,6 +1,6 @@
# NOTE: this only tests the speed of the LLaMA codegen, it doesn't actually run the net
import unittest, time
-from examples.llama import Transformer, args_7B
+from examples.llama import Transformer, MODEL_PARAMS
from test.test_net_speed import start_profile, stop_profile
from tinygrad.tensor import Tensor
from tinygrad.lazy import Device
@@ -19,7 +19,7 @@ class TestLLaMASpeed(unittest.TestCase):
print("using", Device['fake'].codegen)
print("testing llama python run time")
-model = Transformer(**args_7B)
+model = Transformer(**MODEL_PARAMS[1]["7B"]["args"])
print("built model")
# assign fake tensors to the values
for v in get_state_dict(model).values(): v.assign(Tensor.empty(*v.shape, dtype=v.dtype))

test/test_real_world.py

@@ -8,7 +8,7 @@ from tinygrad.lazy import Device
from tinygrad.helpers import CI, dtypes
from examples.hlb_cifar10 import SpeedyResNet
-from examples.llama import Transformer, args_7B
+from examples.llama import Transformer, MODEL_PARAMS
from examples.stable_diffusion import UNetModel
def helper_test(nm, gen, train, max_memory_allowed, max_kernels_allowed):
@@ -55,7 +55,7 @@ class TestRealWorld(unittest.TestCase):
Tensor.default_type = dtypes.float16
args_tiny = {"dim": 1024, "multiple_of": 256, "n_heads": 8, "n_layers": 8, "norm_eps": 1e-05, "vocab_size": 1000}
-model = Transformer(**(args_tiny if CI else args_7B))
+model = Transformer(**(args_tiny if CI else MODEL_PARAMS[1]["7B"]["args"]))
derandomize_model(model)
@TinyJit
def test(t): return model(t, 0).realize()

tinygrad/state.py

@@ -4,6 +4,7 @@ from typing import Dict, Union, List
from tinygrad.tensor import Tensor
from tinygrad.helpers import dtypes, prod, argsort, DEBUG, Timing, GlobalCounters
from tinygrad.shape.shapetracker import strides_for_shape
+from tinygrad.lazy import Device
safe_dtypes = {"F16": dtypes.float16, "F32": dtypes.float32, "U8": dtypes.uint8, "I8": dtypes.int8, "I32": dtypes.int32, "I64": dtypes.int64}
inverse_safe_dtypes = {v:k for k,v in safe_dtypes.items()}
@@ -67,6 +68,11 @@ def torch_load(fn:str):
if storage[2] not in offsets: return None
byte_offset = offsets[storage[2]]+storage_offset*storage[1].itemsize
ret = t[byte_offset:byte_offset+prod(size)].cast(storage[1])
+# convert bfloat16 -> float16 using LLVM
+# upstream LLaMA also does this conversion:
+# https://github.com/facebookresearch/llama/blob/6c7fe276574e78057f917549435a2554000a876d/llama/generation.py#L95
+if storage[1] == dtypes.bfloat16:
+ret = ret.to("LLVM").half().to(Device.DEFAULT).realize()
# 7 lines to deal with permuted tensors. NOTE: this currently requires reading off the disk
shape_strides = [(s, st) for s,st in zip(size, stride) if s != 1]
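LLaMA-2 checkpoints ship their tensors as bfloat16, so torch_load now detours through the LLVM backend to cast them to float16 before moving the data back to Device.DEFAULT (hence the new tinygrad.lazy.Device import above). The cast itself is simple because bfloat16 is just the upper 16 bits of a float32; a self-contained numpy sketch of the same conversion, for illustration only (this is not how the commit implements it, and the helper name is mine):

import numpy as np

def bf16_to_fp16(raw: bytes) -> np.ndarray:
  # bfloat16 is the top half of an IEEE-754 float32: widen to float32, then narrow to float16
  bf16 = np.frombuffer(raw, dtype=np.uint16)
  f32 = (bf16.astype(np.uint32) << 16).view(np.float32)
  return f32.astype(np.float16)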
@@ -80,7 +86,7 @@ def torch_load(fn:str):
return ret.reshape(size)
-intercept = {"HalfStorage": dtypes.float16, "FloatStorage": dtypes.float32, "IntStorage": dtypes.int32, "LongStorage": dtypes.int64, "_rebuild_tensor_v2": _rebuild_tensor_v2}
+intercept = {"HalfStorage": dtypes.float16, "FloatStorage": dtypes.float32, "BFloat16Storage": dtypes.bfloat16, "IntStorage": dtypes.int32, "LongStorage": dtypes.int64, "_rebuild_tensor_v2": _rebuild_tensor_v2}
whitelist = {"torch", "collections", "numpy", "_codecs"} # NOTE: this is not for security, only speed
class Dummy: pass
class TorchPickle(pickle.Unpickler):