mirror of https://github.com/commaai/tinygrad.git
Add LLaMA-2 support (#1284)
Co-authored-by: wozeparrot <wozeparrot@gmail.com>
This commit is contained in:
parent d89fb729e5
commit cd60b8561c
@@ -92,9 +92,12 @@ class Attention:
     return self.wo(output)
 
 class FeedForward:
-  def __init__(self, dim, hidden_dim, multiple_of):
+  def __init__(self, dim, hidden_dim, multiple_of, ffn_dim_multiplier=None):
     # TODO: what is this?
     hidden_dim = int(2 * hidden_dim / 3)
+    # custom dim factor multiplier
+    if ffn_dim_multiplier is not None:
+      hidden_dim = int(ffn_dim_multiplier * hidden_dim)
     hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
     self.w1 = Linear(dim, hidden_dim, bias=False)
     self.w2 = Linear(hidden_dim, dim, bias=False)
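The ffn_dim_multiplier plumbing above is what lets LLaMA-2 pick its wider FFN. As a standalone illustration of the arithmetic (the helper name is hypothetical; the 70B inputs come from the commented-out MODEL_PARAMS entry further down, and the expected outputs match the published 7B and 70B checkpoint shapes):

def ffn_hidden_dim(dim, multiple_of, ffn_dim_multiplier=None):
  # mirrors FeedForward.__init__ above, starting from the 4*dim that TransformerBlock passes in
  hidden_dim = 4 * dim
  hidden_dim = int(2 * hidden_dim / 3)                # SwiGLU FFN uses ~2/3 of the classic 4*dim
  if ffn_dim_multiplier is not None:                  # LLaMA-2 70B sets this to 1.3
    hidden_dim = int(ffn_dim_multiplier * hidden_dim)
  return multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)  # round up to a multiple

assert ffn_hidden_dim(4096, 256) == 11008        # LLaMA 7B FFN width
assert ffn_hidden_dim(8192, 4096, 1.3) == 28672  # LLaMA-2 70B FFN width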
@@ -104,9 +107,9 @@ class FeedForward:
     return self.w2(self.w1(x).silu() * self.w3(x))
 
 class TransformerBlock:
-  def __init__(self, dim, multiple_of, n_heads, norm_eps):
+  def __init__(self, dim, multiple_of, n_heads, norm_eps, ffn_dim_multiplier=None):
     self.attention = Attention(dim, n_heads)
-    self.feed_forward = FeedForward(dim, 4*dim, multiple_of)
+    self.feed_forward = FeedForward(dim, 4*dim, multiple_of, ffn_dim_multiplier)
     self.attention_norm = RMSNorm(dim, norm_eps)
     self.ffn_norm = RMSNorm(dim, norm_eps)
     if getenv("JIT"):
@@ -130,8 +133,8 @@ class TransformerBlock:
     return self._post(x, output)
 
 class Transformer:
-  def __init__(self, dim, multiple_of, n_heads, n_layers, norm_eps, vocab_size, max_batch_size=32, max_seq_len=1024):
-    self.layers = [TransformerBlock(dim, multiple_of, n_heads, norm_eps) for _ in range(n_layers)]
+  def __init__(self, dim, multiple_of, n_heads, n_layers, norm_eps, vocab_size, max_batch_size=32, max_seq_len=1024, ffn_dim_multiplier=None):
+    self.layers = [TransformerBlock(dim, multiple_of, n_heads, norm_eps, ffn_dim_multiplier) for _ in range(n_layers)]
     self.norm = RMSNorm(dim, norm_eps)
     self.tok_embeddings = Embedding(vocab_size, dim)
     self.output = Linear(dim, vocab_size, bias=False)
@@ -150,23 +153,42 @@ class Transformer:
 
 # **** files and arguments ****
 
-WEIGHTS_DIR = Path(__file__).parent.parent / "weights/LLaMA/"
-TOKENIZER_FILENAME = WEIGHTS_DIR / "tokenizer.model"
 VOCAB_SIZE = 32000
-
-args_small = {"dim": 512, "multiple_of": 256, "n_heads": 8, "n_layers": 8, "norm_eps": 1e-05, "vocab_size": VOCAB_SIZE}
-
-args_7B = {"dim": 4096, "multiple_of": 256, "n_heads": 32, "n_layers": 32, "norm_eps": 1e-06, "vocab_size": VOCAB_SIZE}
-WEIGHTS_7B_FILENAME = WEIGHTS_DIR / "7B/consolidated.00.pth"
-
-args_13B = {"dim": 5120, "multiple_of": 256, "n_heads": 40, "n_layers": 40, "norm_eps": 1e-06, "vocab_size": VOCAB_SIZE}
-WEIGHTS_13B_FILENAMES = [WEIGHTS_DIR / "13B/consolidated.00.pth", WEIGHTS_DIR / "13B/consolidated.01.pth"]
-
-args_30B = {"dim": 6656, "multiple_of": 256, "n_heads": 52, "n_layers": 60, "norm_eps": 1e-06, "vocab_size": VOCAB_SIZE}
-WEIGHTS_30B_FILENAMES = [WEIGHTS_DIR / "30B/consolidated.00.pth", WEIGHTS_DIR / "30B/consolidated.01.pth", WEIGHTS_DIR / "30B/consolidated.02.pth", WEIGHTS_DIR / "30B/consolidated.03.pth"]
-
-args_65B = {"dim": 8192, "multiple_of": 256, "n_heads": 64, "n_layers": 80, "norm_eps": 1e-05, "vocab_size": VOCAB_SIZE}
-WEIGHTS_65B_FILENAMES = [WEIGHTS_DIR / "65B/consolidated.00.pth", WEIGHTS_DIR / "65B/consolidated.01.pth", WEIGHTS_DIR / "65B/consolidated.02.pth", WEIGHTS_DIR / "65B/consolidated.03.pth", WEIGHTS_DIR / "65B/consolidated.04.pth", WEIGHTS_DIR / "65B/consolidated.05.pth", WEIGHTS_DIR / "65B/consolidated.06.pth", WEIGHTS_DIR / "65B/consolidated.07.pth"]
+MODEL_PARAMS = {
+  1: {
+    "7B": {
+      "args": {"dim": 4096, "multiple_of": 256, "n_heads": 32, "n_layers": 32, "norm_eps": 1e-06, "vocab_size": VOCAB_SIZE},
+      "files": 1,
+    },
+    "13B": {
+      "args": {"dim": 5120, "multiple_of": 256, "n_heads": 40, "n_layers": 40, "norm_eps": 1e-06, "vocab_size": VOCAB_SIZE},
+      "files": 2,
+    },
+    "30B": {
+      "args": {"dim": 6656, "multiple_of": 256, "n_heads": 52, "n_layers": 60, "norm_eps": 1e-06, "vocab_size": VOCAB_SIZE},
+      "files": 4,
+    },
+    "65B": {
+      "args": {"dim": 8192, "multiple_of": 256, "n_heads": 64, "n_layers": 80, "norm_eps": 1e-05, "vocab_size": VOCAB_SIZE},
+      "files": 8,
+    },
+  },
+  2: {
+    "7B": {
+      "args": {"dim": 4096, "multiple_of": 256, "n_heads": 32, "n_layers": 32, "norm_eps": 1e-05, "vocab_size": VOCAB_SIZE},
+      "files": 1,
+    },
+    "13B": {
+      "args": {"dim": 5120, "multiple_of": 256, "n_heads": 40, "n_layers": 40, "norm_eps": 1e-05, "vocab_size": VOCAB_SIZE},
+      "files": 2,
+    },
+    # # 70B is disabled because we do not yet implement n_kv_heads argument
+    # "70B": {
+    #   "args": {"dim": 8192, "multiple_of": 4096, "ffn_dim_multiplier": 1.3, "n_heads": 64, "n_kv_heads": 8, "n_layers": 80, "norm_eps": 1e-05, "vocab_size": VOCAB_SIZE},
+    #   "files": 8,
+    # },
+  },
+}
 
 # **** helper functions ****
 def sample(logits, temperature):
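The trailing context line above references sample(logits, temperature), whose body is not touched by this commit and is not shown here. For orientation, a minimal sketch of temperature sampling consistent with the --temperature=0 greedy test below; this is an illustration, not the file's exact implementation:

import numpy as np

def sample_sketch(logits: np.ndarray, temperature: float) -> int:
  # greedy decoding when temperature is (near) zero, otherwise softmax sampling
  if temperature < 1e-6:
    return int(logits.argmax())
  scaled = logits / temperature
  probs = np.exp(scaled - scaled.max())
  probs /= probs.sum()
  return int(np.random.choice(len(probs), p=probs))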
@@ -199,11 +221,8 @@ def concat_weights(models):
 if __name__ == "__main__":
   Tensor.no_grad = True
   print(f"using {Device.DEFAULT} backend")
-  from sentencepiece import SentencePieceProcessor
-  sp_model = SentencePieceProcessor(model_file=str(TOKENIZER_FILENAME))
-  assert sp_model.vocab_size() == VOCAB_SIZE
 
-  parser = argparse.ArgumentParser(description='Run LLaMA 7B in tinygrad', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+  parser = argparse.ArgumentParser(description='Run LLaMA in tinygrad', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
   # test: python3 examples/llama.py --prompt="Hello." --temperature=0
   # Hello. I'm a 20 year old male. I'm a student at the University of Texas at Austin. I'm a sophomore majoring in Computer Science.
   parser.add_argument('--prompt', type=str, default=None, help="Phrase to start with. Without this, it goes into chatbot mode")
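The hunk header above points at concat_weights(models), which this commit leaves unchanged and whose body is not shown. It is what recombines the sharded consolidated.XX.pth files loaded in the next hunk. A hedged numpy sketch of the idea (the real function works on tinygrad tensors, and the exact set of parameters sharded along the second axis is an assumption here):

import numpy as np

def concat_weights_sketch(models: list) -> dict:
  # models: one dict of arrays per checkpoint shard, all with the same keys.
  # 1-D params (norms) are replicated across shards; 2-D weights are split either
  # along axis 1 (assumed: tok_embeddings, wo, w2) or along axis 0 (everything else).
  def convert(name):
    parts = [m[name] for m in models]
    if len(parts) == 1 or parts[0].ndim == 1:
      return parts[0]
    axis = 1 if name.startswith("tok_embeddings.") or name.endswith(".wo.weight") or name.endswith(".w2.weight") else 0
    return np.concatenate(parts, axis=axis)
  return {name: convert(name) for name in models[0]}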
@@ -213,31 +232,29 @@ if __name__ == "__main__":
   parser.add_argument('--temperature', type=float, default=0.7, help="Temperature in the softmax")
   parser.add_argument('--timing', action='store_true', help="Print timing per token")
   parser.add_argument('--profile', action='store_true', help="Output profile data to out.prof")
-  parser.add_argument('--size', type=str, default="7B", help="Size of model to use [7B, 13B, 30B, 65B]")
+  parser.add_argument('--size', type=str, default="7B", help="Size of model to use [7B, 13B, 30B, 65B] for Gen 1, [7B, 13B] for Gen 2")
+  parser.add_argument('--gen', type=int, default="1", help="Generation of the model to use [1, 2]")
 
   args = parser.parse_args()
   chatbot = args.prompt == None
 
+  LLAMA_SUFFIX = {1: "", 2: "-2"}[args.gen]
+  WEIGHTS_DIR = Path(__file__).parent.parent / f"weights/LLaMA{LLAMA_SUFFIX}/"
+  TOKENIZER_FILENAME = WEIGHTS_DIR / "tokenizer.model"
+
+  from sentencepiece import SentencePieceProcessor
+  sp_model = SentencePieceProcessor(model_file=str(TOKENIZER_FILENAME))
+  assert sp_model.vocab_size() == VOCAB_SIZE
+
   from tinygrad.state import torch_load, load_state_dict
-  if args.size == "65B":
-    print("using 65B model")
-    model = Transformer(**args_65B)
-    weights = [torch_load(filename) for filename in WEIGHTS_65B_FILENAMES]
-    load_state_dict(model, concat_weights(weights), strict=False)
-  elif args.size == "30B":
-    print("using 30B model")
-    model = Transformer(**args_30B)
-    weights = [torch_load(filename) for filename in WEIGHTS_30B_FILENAMES]
-    load_state_dict(model, concat_weights(weights), strict=False)
-  elif args.size == "13B":
-    print("using 13B model")
-    model = Transformer(**args_13B)
-    weights = [torch_load(filename) for filename in WEIGHTS_13B_FILENAMES]
-    load_state_dict(model, concat_weights(weights), strict=False)
+  print(f"using LLaMA{LLAMA_SUFFIX}-{args.size} model")
+  params = MODEL_PARAMS[args.gen][args.size]
+  model = Transformer(**params["args"])
+  weights = [torch_load(WEIGHTS_DIR / f"{args.size}/consolidated.{i:02d}.pth") for i in range(params["files"])]
+  if len(weights) == 1:
+    load_state_dict(model, weights[0], strict=False)
   else:
-    print("using 7B model")
-    model = Transformer(**args_7B)
-    load_state_dict(model, torch_load(WEIGHTS_7B_FILENAME), strict=False)
+    load_state_dict(model, concat_weights(weights), strict=False)
 
   # *** prompt engineers work here ****
 
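With --gen and --size wired into MODEL_PARAMS, model selection is two flags, and weights are looked up under weights/LLaMA/ (gen 1) or weights/LLaMA-2/ (gen 2) next to the repo. Hypothetical invocations based on the flags above (the first is the test command already quoted in the diff; paths and sizes are illustrative):

# LLaMA 1 7B, single prompt, greedy decoding
python3 examples/llama.py --prompt="Hello." --temperature=0

# LLaMA-2 13B, chatbot mode (no --prompt); expects weights/LLaMA-2/13B/consolidated.00.pth and consolidated.01.pth
python3 examples/llama.py --gen 2 --size 13B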
@@ -1,6 +1,6 @@
 # NOTE: this only tests the speed of the LLaMA codegen, it doesn't actually run the net
 import unittest, time
-from examples.llama import Transformer, args_7B
+from examples.llama import Transformer, MODEL_PARAMS
 from test.test_net_speed import start_profile, stop_profile
 from tinygrad.tensor import Tensor
 from tinygrad.lazy import Device
@@ -19,7 +19,7 @@ class TestLLaMASpeed(unittest.TestCase):
     print("using", Device['fake'].codegen)
 
     print("testing llama python run time")
-    model = Transformer(**args_7B)
+    model = Transformer(**MODEL_PARAMS[1]["7B"]["args"])
     print("built model")
     # assign fake tensors to the values
     for v in get_state_dict(model).values(): v.assign(Tensor.empty(*v.shape, dtype=v.dtype))
@@ -8,7 +8,7 @@ from tinygrad.lazy import Device
 from tinygrad.helpers import CI, dtypes
 
 from examples.hlb_cifar10 import SpeedyResNet
-from examples.llama import Transformer, args_7B
+from examples.llama import Transformer, MODEL_PARAMS
 from examples.stable_diffusion import UNetModel
 
 def helper_test(nm, gen, train, max_memory_allowed, max_kernels_allowed):
@@ -55,7 +55,7 @@ class TestRealWorld(unittest.TestCase):
     Tensor.default_type = dtypes.float16
 
     args_tiny = {"dim": 1024, "multiple_of": 256, "n_heads": 8, "n_layers": 8, "norm_eps": 1e-05, "vocab_size": 1000}
-    model = Transformer(**(args_tiny if CI else args_7B))
+    model = Transformer(**(args_tiny if CI else MODEL_PARAMS[1]["7B"]["args"]))
     derandomize_model(model)
     @TinyJit
     def test(t): return model(t, 0).realize()
@@ -4,6 +4,7 @@ from typing import Dict, Union, List
 from tinygrad.tensor import Tensor
 from tinygrad.helpers import dtypes, prod, argsort, DEBUG, Timing, GlobalCounters
 from tinygrad.shape.shapetracker import strides_for_shape
+from tinygrad.lazy import Device
 
 safe_dtypes = {"F16": dtypes.float16, "F32": dtypes.float32, "U8": dtypes.uint8, "I8": dtypes.int8, "I32": dtypes.int32, "I64": dtypes.int64}
 inverse_safe_dtypes = {v:k for k,v in safe_dtypes.items()}
@@ -67,6 +68,11 @@ def torch_load(fn:str):
     if storage[2] not in offsets: return None
     byte_offset = offsets[storage[2]]+storage_offset*storage[1].itemsize
     ret = t[byte_offset:byte_offset+prod(size)].cast(storage[1])
+    # convert bfloat16 -> float16 using LLVM
+    # upstream LLaMA also does this conversion:
+    # https://github.com/facebookresearch/llama/blob/6c7fe276574e78057f917549435a2554000a876d/llama/generation.py#L95
+    if storage[1] == dtypes.bfloat16:
+      ret = ret.to("LLVM").half().to(Device.DEFAULT).realize()
 
     # 7 lines to deal with permuted tensors. NOTE: this currently requires reading off the disk
     shape_strides = [(s, st) for s,st in zip(size, stride) if s != 1]
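The cast goes through the LLVM backend because the default backend may not understand bfloat16 buffers. As a backend-independent illustration of what the conversion amounts to (a hypothetical helper, not how state.py does it): bfloat16 is just the top 16 bits of a float32, so widening is a 16-bit left shift, after which narrowing to float16 is an ordinary cast.

import numpy as np

def bf16_to_fp16_sketch(raw: bytes) -> np.ndarray:
  # reinterpret the raw bfloat16 words (little-endian assumed), shift them into the
  # high half of a uint32 to get valid float32 bit patterns, then narrow to float16
  bf16 = np.frombuffer(raw, dtype=np.uint16)
  fp32 = (bf16.astype(np.uint32) << 16).view(np.float32)
  return fp32.astype(np.float16)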
@@ -80,7 +86,7 @@ def torch_load(fn:str):
 
     return ret.reshape(size)
 
-  intercept = {"HalfStorage": dtypes.float16, "FloatStorage": dtypes.float32, "IntStorage": dtypes.int32, "LongStorage": dtypes.int64, "_rebuild_tensor_v2": _rebuild_tensor_v2}
+  intercept = {"HalfStorage": dtypes.float16, "FloatStorage": dtypes.float32, "BFloat16Storage": dtypes.bfloat16, "IntStorage": dtypes.int32, "LongStorage": dtypes.int64, "_rebuild_tensor_v2": _rebuild_tensor_v2}
   whitelist = {"torch", "collections", "numpy", "_codecs"} # NOTE: this is not for security, only speed
   class Dummy: pass
   class TorchPickle(pickle.Unpickler):
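Registering BFloat16Storage in intercept is what lets the unpickler recognize bfloat16 storages in LLaMA-2 checkpoints in the first place. A hedged round-trip check of the new path, assuming PyTorch is available to write the file, the LLVM backend is built for the cast, and torch_load handles the format torch.save emits; the path is arbitrary:

import torch
from tinygrad.state import torch_load

torch.save({"w": torch.randn(4, 4, dtype=torch.bfloat16)}, "/tmp/bf16_test.pth")
ckpt = torch_load("/tmp/bf16_test.pth")
print(ckpt["w"].dtype)  # expect float16 after the bfloat16 -> float16 conversion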