#!/usr/bin/env python3
from typing import Optional, Union
import argparse
import numpy as np
import tiktoken
from tinygrad import Tensor, TinyJit, Device, GlobalCounters, Variable
from tinygrad.ops import UOp
from tinygrad.helpers import Timing, DEBUG, JIT, getenv, fetch, colored, trange
from tinygrad.nn import Embedding, Linear, LayerNorm
from tinygrad.nn.state import torch_load, load_state_dict, get_state_dict

MAX_CONTEXT = getenv("MAX_CONTEXT", 128)
HALF = getenv("HALF")
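# multi-head self-attention with a per-layer KV cache preallocated for MAX_CONTEXT positions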
class Attention:
  def __init__(self, dim, n_heads):
    self.c_attn = Linear(dim, 3*dim, bias=True)
    self.c_proj = Linear(dim, dim, bias=True)
    self.n_heads = n_heads
    self.dim = dim
    self.head_dim = dim // n_heads

  def __call__(self, x:Tensor, start_pos:Variable, mask:Optional[Tensor]) -> Tensor:
    if mask is not None or start_pos.val == 0:
      # no symbolic shape qkv when consuming prompts
      start_pos = start_pos.val

    if HALF: x = x.half()
    xqkv = self.c_attn(x)
    xq, xk, xv = [xqkv.shrink((None, None, (i*self.dim, (i+1)*self.dim))).reshape(None, None, self.n_heads, self.head_dim) for i in range(3)]
    bsz, seqlen, _, _ = xq.shape

    # create kv cache
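    # lazily allocated on the first call: index 0 holds keys, index 1 holds values, each (bsz, MAX_CONTEXT, n_heads, head_dim)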
    if not hasattr(self, "cache_kv"):
      self.cache_kv = Tensor.zeros(2, bsz, MAX_CONTEXT, self.n_heads, self.head_dim, dtype=x.dtype).contiguous().realize()

    # update the cache
    self.cache_kv.shrink((None, None, (start_pos, start_pos+seqlen), None, None)).assign(Tensor.stack(xk, xv)).realize()
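    # prompt processing (start_pos == 0) attends only to the freshly computed keys/values; decoding attends to the whole cache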
    if start_pos > 0:
      keys = self.cache_kv[0].shrink((None, (0, start_pos+seqlen), None, None))
      values = self.cache_kv[1].shrink((None, (0, start_pos+seqlen), None, None))
    else:
      keys = xk
      values = xv

    xq, keys, values = xq.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2)
    return self.c_proj(xq.scaled_dot_product_attention(keys, values, mask).transpose(1, 2).reshape(bsz, seqlen, self.dim))
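# position-wise MLP (dim -> hidden_dim -> dim) with GELU activation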
class FeedForward:
  def __init__(self, dim, hidden_dim):
    self.c_fc = Linear(dim, hidden_dim, bias=True)
    self.c_proj = Linear(hidden_dim, dim, bias=True)

  def __call__(self, x:Tensor) -> Tensor:
    return self.c_proj(self.c_fc(x).gelu())
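# pre-norm residual block: x + attn(ln_1(x)), then + mlp(ln_2(h))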
class TransformerBlock:
  def __init__(self, dim, n_heads, norm_eps):
    self.attn = Attention(dim, n_heads)
    self.mlp = FeedForward(dim, 4*dim)
    self.ln_1 = LayerNorm(dim, norm_eps)
    self.ln_2 = LayerNorm(dim, norm_eps)

  def __call__(self, x:Tensor, start_pos:Variable, mask:Optional[Tensor]):
    h = x + self.attn(self.ln_1(x), start_pos, mask).float()
    return (h + self.mlp(self.ln_2(h)))
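# GPT-2 model: token (wte) and position (wpe) embeddings, n_layers TransformerBlocks, final LayerNorm and lm_head; forward_jit is the TinyJit-wrapped forward used for single-token decode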
class Transformer:
  def __init__(self, dim, n_heads, n_layers, norm_eps, vocab_size, max_seq_len=1024):
    self.vocab_size = vocab_size
    self.wte = Embedding(vocab_size, dim)
    self.wpe = Embedding(max_seq_len, dim)
    self.h = [TransformerBlock(dim, n_heads, norm_eps) for _ in range(n_layers)]
    self.ln_f = LayerNorm(dim, norm_eps)
    self.lm_head = Linear(dim, vocab_size, bias=False)
    self.forward_jit = TinyJit(self.forward)

  def forward(self, tokens:Union[Tensor,UOp], start_pos:Variable, temperature:float=0.0):
    if not hasattr(self, 'allpos'): self.allpos = Tensor.arange(0, MAX_CONTEXT).reshape(1, -1).realize()
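    # at decode time a single token arrives as a bound Variable (UOp): slice its embedding row symbolically so the JIT can replay the same kernels each step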
    if isinstance(tokens, UOp):
      seqlen = 1
      tok_emb = self.wte.weight.shrink(((tokens, tokens+1), None))
    else:
      seqlen = tokens.shape[1]
      tok_emb = self.wte(tokens)

    pos_emb = self.wpe(self.allpos.shrink((None, (start_pos, start_pos+seqlen))))
    h = tok_emb + pos_emb

    if HALF: h = h.half()
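    # a causal mask is only needed while consuming a multi-token prompt; single-token decode attends to everything already in the cache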
    mask = Tensor.full((1, 1, seqlen, start_pos.val+seqlen), float("-inf"), dtype=h.dtype).triu(start_pos.val+1) if seqlen > 1 else None

    for hi in self.h: h = hi(h, start_pos, mask)

    logits = self.lm_head(self.ln_f(h))

    if logits.shape[1] == 0:
      # special case for empty prompt
      logits = Tensor.ones((logits.shape[0], self.vocab_size), dtype=logits.dtype, device=logits.device)
    else:
      logits = logits[:, -1, :]

    if temperature < 1e-6:
      ret = logits.argmax(-1)
    else:
      ret = (logits / temperature).softmax().multinomial()
    return ret.flatten().realize()

  def __call__(self, tokens:Union[Tensor,UOp], start_pos:Variable, temperature:float=0.0) -> Tensor:
    forward = (self.forward_jit if JIT and (isinstance(tokens, UOp) or tokens.shape[1] == 1) else self.forward)
    return forward(tokens, start_pos, temperature)
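# standard GPT-2 checkpoint configurations; all sizes share the same 50257-token BPE vocabulary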
VOCAB_SIZE = 50257
MODEL_PARAMS = {
  'gpt2': dict(n_layers=12, n_heads=12, dim=768, norm_eps=1e-5, vocab_size=VOCAB_SIZE),          # 124M params
  'gpt2-medium': dict(n_layers=24, n_heads=16, dim=1024, norm_eps=1e-5, vocab_size=VOCAB_SIZE),  # 350M params
  'gpt2-large': dict(n_layers=36, n_heads=20, dim=1280, norm_eps=1e-5, vocab_size=VOCAB_SIZE),   # 774M params
  'gpt2-xl': dict(n_layers=48, n_heads=25, dim=1600, norm_eps=1e-5, vocab_size=VOCAB_SIZE),      # 1558M params
}
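# thin wrapper: fetches the pretrained checkpoint from Hugging Face, adapts the Conv1D weight layout, and drives generation with the tiktoken GPT-2 BPE tokenizer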
class GPT2:
  @staticmethod
  def build(model_size="gpt2"):
    tokenizer = tiktoken.get_encoding("gpt2")

    model = Transformer(**MODEL_PARAMS[model_size])
    weights = torch_load(fetch(f'https://huggingface.co/{model_size}/resolve/main/pytorch_model.bin'))
    # special treatment for the Conv1D weights we need to transpose
    transposed = ('attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight')
    for k in weights:
      if k.endswith(transposed):
        weights[k] = weights[k].T
    # lm head and wte are tied
    weights['lm_head.weight'] = weights['wte.weight']

    load_state_dict(model, weights)
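    # with HALF=1, cast every parameter to float16 after loading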
    if HALF:
      for l in get_state_dict(model).values():
        l.replace(l.half().realize())

    return GPT2(model, tokenizer)

  def __init__(self, model, tokenizer):
    self.model = model
    self.tokenizer = tokenizer
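  # autoregressive sampling loop; every sequence in the batch starts from the same prompt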
  def generate(self, prompt:str, max_length:int, temperature:float, timing:bool=False, batch_size:int=1):
    prompt_tokens = self.tokenizer.encode(prompt, allowed_special={"<|endoftext|>"})
    toks = [prompt_tokens[:] for _ in range(batch_size)]
    start_pos = 0
    for _ in trange(max_length, disable=(timing==True)):
      GlobalCounters.reset()
      if timing: print("")
      st = GlobalCounters.time_sum_s
      with Timing("ran model in ", on_exit=(lambda et: (f", {(GlobalCounters.time_sum_s-st)*1e3:.2f} ms on GPU" if DEBUG>=2 else "")+
                  f", {GlobalCounters.global_ops*1e-9:.2f} GOPS, {GlobalCounters.global_mem*1e-9:.2f} GB"+
                  (f", {GlobalCounters.global_mem*1e-9/(GlobalCounters.time_sum_s-st):.2f} GB/s" if DEBUG>=2 else "")) if DEBUG else None, enabled=timing):
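        # once the prompt is consumed, feed the single new token as a bound Variable so the jitted decode path is reused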
        if batch_size == 1 and len(toks[0][start_pos:]) == 1:
          tokens = Variable("tokens", 0, VOCAB_SIZE).bind(toks[0][start_pos])
        else:
          tokens = Tensor([x[start_pos:] for x in toks])
        tok = self.model(tokens, Variable("start_pos", 1 if start_pos else 0, MAX_CONTEXT).bind(start_pos), temperature).numpy().tolist()
      start_pos = len(toks[0])
      for i,t in enumerate(tok): toks[i].append(t)
    return [self.tokenizer.decode(x) for x in toks]

# **** main code ****

if __name__ == "__main__":
  Tensor.no_grad = True
  print(f"using {Device.DEFAULT} backend")
  default_prompt = "What is the answer to life, the universe, and everything?"

  parser = argparse.ArgumentParser(description='Run GPT2 in tinygrad', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
  parser.add_argument('--prompt', type=str, default=default_prompt, help="Phrase to start with")
  parser.add_argument('--count', type=int, default=100, help="Max number of tokens to generate")
  parser.add_argument('--temperature', type=float, default=0.8, help="Temperature in the softmax")
  parser.add_argument('--model_size', type=str, default="gpt2-medium", help="Size of model to use [gpt2, gpt2-medium, gpt2-large, gpt2-xl]")
  parser.add_argument('--timing', action='store_true', help="Print timing per token")
  parser.add_argument('--seed', type=int, help="Set the random seed")
  parser.add_argument('--batch_size', type=int, default=1, help="Set the input batch size")
  parser.add_argument('--benchmark', type=int, default=-1, help="Benchmark GPT with the given number of tokens")
  parser.add_argument('--noshow', action='store_true', help="Don't show the output")
  args = parser.parse_args()

  if args.seed is not None:
    Tensor.manual_seed(args.seed)
    np.random.seed(args.seed)

  print(f"using {args.model_size}")
  gpt2 = GPT2.build(args.model_size)
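  # --benchmark N runs a single forward pass over N random tokens instead of sampling text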
  if args.benchmark != -1:
    gpt2.model(Tensor.rand(args.batch_size, args.benchmark), Variable("a", 0, MAX_CONTEXT).bind(0)).realize()
  else:
    texts = gpt2.generate(args.prompt, args.count, args.temperature, timing=args.timing, batch_size=args.batch_size)
    if not args.noshow:
      print('Generating text...')
      if len(texts) == 1: print(texts[0])
      else:
        for i,text in enumerate(texts): print(colored(f"Response {i}:", "green"), text)

    # validate output!
    if args.temperature == 0 and args.model_size == "gpt2-medium" and args.count == 10:
      expected = {
        default_prompt: "What is the answer to life, the universe, and everything?\n\nThe answer is that we are all one",
        "Hello.": "Hello. I'm a little late to the party, but",
      }
      try:
        assert texts[0] == expected[args.prompt]
        print(colored("output validated", "green"))
      except KeyError:
        pass
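# Example invocations (assuming this file is saved as examples/gpt2.py inside a tinygrad checkout):
#   python3 examples/gpt2.py --model_size gpt2 --prompt "Hello." --count 10 --temperature 0
#   HALF=1 MAX_CONTEXT=256 python3 examples/gpt2.py --model_size gpt2-medium --timing
#   python3 examples/gpt2.py --benchmark 64 --batch_size 4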