tinygrad/examples/gpt2.py

216 lines
9.2 KiB
Python

#!/usr/bin/env python3
from typing import Optional, Union
import argparse
import numpy as np
import tiktoken
from tinygrad import Tensor, TinyJit, Device, GlobalCounters, Variable
from tinygrad.helpers import Timing, DEBUG, JIT, getenv, fetch, colored, trange
from tinygrad.nn import Embedding, Linear, LayerNorm
from tinygrad.nn.state import torch_load, load_state_dict, get_state_dict
MAX_CONTEXT = getenv("MAX_CONTEXT", 128)
HALF = getenv("HALF")
class Attention:
def __init__(self, dim, n_heads):
self.c_attn = Linear(dim, 3*dim, bias=True)
self.c_proj = Linear(dim, dim, bias=True)
self.n_heads = n_heads
self.dim = dim
self.head_dim = dim // n_heads
def __call__(self, x:Tensor, start_pos:Variable, mask:Optional[Tensor]) -> Tensor:
if mask is not None or start_pos.val == 0:
# no symbolic shape qkv when consuming prompts
start_pos = start_pos.val
if HALF: x = x.half()
xqkv = self.c_attn(x)
xq, xk, xv = [xqkv.shrink((None, None, (i*self.dim, (i+1)*self.dim))).reshape(None, None, self.n_heads, self.head_dim) for i in range(3)]
bsz, seqlen, _, _ = xq.shape
# create kv cache
if not hasattr(self, "cache_kv"):
self.cache_kv = Tensor.zeros(2, bsz, MAX_CONTEXT, self.n_heads, self.head_dim, dtype=x.dtype).contiguous().realize()
# update the cache
self.cache_kv.shrink((None, None,(start_pos,start_pos+seqlen),None,None)).assign(Tensor.stack(xk, xv)).realize()
if start_pos > 0:
keys = self.cache_kv[0].shrink((None, (0, start_pos+seqlen), None, None))
values = self.cache_kv[1].shrink((None, (0, start_pos+seqlen), None, None))
else:
keys = xk
values = xv
xq, keys, values = xq.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2)
return self.c_proj(xq.scaled_dot_product_attention(keys, values, mask).transpose(1, 2).reshape(bsz, seqlen, self.dim))
class FeedForward:
def __init__(self, dim, hidden_dim):
self.c_fc = Linear(dim, hidden_dim, bias=True)
self.c_proj = Linear(hidden_dim, dim, bias=True)
def __call__(self, x:Tensor) -> Tensor:
return self.c_proj(self.c_fc(x).gelu())
class TransformerBlock:
def __init__(self, dim, n_heads, norm_eps):
self.attn = Attention(dim, n_heads)
self.mlp = FeedForward(dim, 4*dim)
self.ln_1 = LayerNorm(dim, norm_eps)
self.ln_2 = LayerNorm(dim, norm_eps)
def __call__(self, x:Tensor, start_pos:Variable, mask:Optional[Tensor]):
h = x + self.attn(self.ln_1(x), start_pos, mask).float()
return (h + self.mlp(self.ln_2(h)))
class Transformer:
def __init__(self, dim, n_heads, n_layers, norm_eps, vocab_size, max_seq_len=1024):
self.vocab_size = vocab_size
self.wte = Embedding(vocab_size, dim)
self.wpe = Embedding(max_seq_len, dim)
self.h = [TransformerBlock(dim, n_heads, norm_eps) for _ in range(n_layers)]
self.ln_f = LayerNorm(dim, norm_eps)
self.lm_head = Linear(dim, vocab_size, bias=False)
self.forward_jit = TinyJit(self.forward)
def forward(self, tokens:Union[Tensor,Variable], start_pos:Variable, temperature:float=0.0):
if not hasattr(self, 'allpos'): self.allpos = Tensor.arange(0, MAX_CONTEXT).reshape(1, -1).realize()
if isinstance(tokens, Variable):
seqlen = 1
tok_emb = self.wte.weight.shrink(((tokens, tokens+1), None))
else:
seqlen = tokens.shape[1]
tok_emb = self.wte(tokens)
pos_emb = self.wpe(self.allpos.shrink((None, (start_pos, start_pos+seqlen))))
h = tok_emb + pos_emb
if HALF: h = h.half()
mask = Tensor.full((1, 1, seqlen, start_pos.val+seqlen), float("-inf"), dtype=h.dtype).triu(start_pos.val+1) if seqlen > 1 else None
for hi in self.h: h = hi(h, start_pos, mask)
logits = self.lm_head(self.ln_f(h))
if logits.shape[1] == 0:
# special case for empty prompt
logits = Tensor.ones((logits.shape[0], self.vocab_size), dtype=logits.dtype, device=logits.device)
else:
logits = logits[:, -1, :]
if temperature < 1e-6:
ret = logits.argmax(-1)
else:
ret = (logits / temperature).softmax().multinomial()
return ret.flatten().realize()
def __call__(self, tokens:Tensor, start_pos:Variable, temperature:float=0.0) -> Tensor:
forward = (self.forward_jit if JIT and (isinstance(tokens, Variable) or tokens.shape[1] == 1) else self.forward)
return forward(tokens, start_pos, temperature)
VOCAB_SIZE = 50257
MODEL_PARAMS = {
'gpt2': dict(n_layers=12, n_heads=12, dim=768, norm_eps=1e-5, vocab_size=VOCAB_SIZE), # 124M params
'gpt2-medium': dict(n_layers=24, n_heads=16, dim=1024, norm_eps=1e-5, vocab_size=VOCAB_SIZE), # 350M params
'gpt2-large': dict(n_layers=36, n_heads=20, dim=1280, norm_eps=1e-5, vocab_size=VOCAB_SIZE), # 774M params
'gpt2-xl': dict(n_layers=48, n_heads=25, dim=1600, norm_eps=1e-5, vocab_size=VOCAB_SIZE), # 1558M params
}
class GPT2:
@staticmethod
def build(model_size="gpt2"):
tokenizer = tiktoken.get_encoding("gpt2")
model = Transformer(**MODEL_PARAMS[model_size])
weights = torch_load(fetch(f'https://huggingface.co/{model_size}/resolve/main/pytorch_model.bin'))
# special treatment for the Conv1D weights we need to transpose
transposed = ('attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight')
for k in weights:
if k.endswith(transposed):
weights[k] = weights[k].T
# lm head and wte are tied
weights['lm_head.weight'] = weights['wte.weight']
load_state_dict(model, weights)
if HALF:
for l in get_state_dict(model).values():
l.replace(l.half().realize())
return GPT2(model, tokenizer)
def __init__(self, model, tokenizer):
self.model = model
self.tokenizer = tokenizer
def generate(self, prompt:str, max_length:int, temperature:float, timing:bool=False, batch_size:int=1):
prompt_tokens = self.tokenizer.encode(prompt, allowed_special={"<|endoftext|>"})
toks = [prompt_tokens[:] for _ in range(batch_size)]
start_pos = 0
for _ in trange(max_length, disable=(timing==True)):
GlobalCounters.reset()
if timing: print("")
st = GlobalCounters.time_sum_s
with Timing("ran model in ", on_exit=(lambda et: (f", {(GlobalCounters.time_sum_s-st)*1e3:.2f} ms on GPU" if DEBUG>=2 else "")+
f", {GlobalCounters.global_ops*1e-9:.2f} GOPS, {GlobalCounters.global_mem*1e-9:.2f} GB"+
(f", {GlobalCounters.global_mem*1e-9/(GlobalCounters.time_sum_s-st):.2f} GB/s" if DEBUG>=2 else "")) if DEBUG else None, enabled=timing):
if batch_size == 1 and len(toks[0][start_pos:]) == 1:
tokens = Variable("tokens", 0, VOCAB_SIZE).bind(toks[0][start_pos])
else:
tokens = Tensor([x[start_pos:] for x in toks])
tok = self.model(tokens, Variable("start_pos", 1 if start_pos else 0, MAX_CONTEXT).bind(start_pos), temperature).numpy().tolist()
start_pos = len(toks[0])
for i,t in enumerate(tok): toks[i].append(t)
return [self.tokenizer.decode(x) for x in toks]
# **** main code ****
if __name__ == "__main__":
Tensor.no_grad = True
print(f"using {Device.DEFAULT} backend")
default_prompt = "What is the answer to life, the universe, and everything?"
parser = argparse.ArgumentParser(description='Run GPT2 in tinygrad', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--prompt', type=str, default=default_prompt, help="Phrase to start with")
parser.add_argument('--count', type=int, default=100, help="Max number of tokens to generate")
parser.add_argument('--temperature', type=float, default=0.8, help="Temperature in the softmax")
parser.add_argument('--model_size', type=str, default="gpt2-medium", help="Size of model to use [gpt2, gpt2-medium, gpt2-large, gpt2-xl]")
parser.add_argument('--timing', action='store_true', help="Print timing per token")
parser.add_argument('--seed', type=int, help="Set the random seed")
parser.add_argument('--batch_size', type=int, default=1, help="Set the input batch size")
parser.add_argument('--benchmark', type=int, default=-1, help="Benchmark GPT with the given number of tokens")
parser.add_argument('--noshow', action='store_true', help="Don't show the output")
args = parser.parse_args()
if args.seed is not None:
Tensor.manual_seed(args.seed)
np.random.seed(args.seed)
print(f"using {args.model_size}")
gpt2 = GPT2.build(args.model_size)
if args.benchmark != -1:
gpt2.model(Tensor.rand(args.batch_size, args.benchmark), Variable("a", 0, MAX_CONTEXT).bind(0)).realize()
else:
texts = gpt2.generate(args.prompt, args.count, args.temperature, timing=args.timing, batch_size=args.batch_size)
if not args.noshow:
print('Generating text...')
if len(texts) == 1: print(texts[0])
else:
for i,text in enumerate(texts): print(colored(f"Response {i}:", "green"), text)
# validate output!
if args.temperature == 0 and args.model_size == "gpt2-medium" and args.count == 10:
expected = {
default_prompt: "What is the answer to life, the universe, and everything?\n\nThe answer is that we are all one",
"Hello.": "Hello. I'm a little late to the party, but",
}
try:
assert texts[0] == expected[args.prompt]
print(colored("output validated", "green"))
except KeyError:
pass