extreme llama speed, 57.34 tok/s (#5827)

* extreme llama speed

* mergable
This commit is contained in:
George Hotz 2024-07-30 18:32:09 -07:00 committed by GitHub
parent e6879035a0
commit 21c5e8e1b7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 13 additions and 5 deletions

View File

@ -4,7 +4,7 @@
#typeguard.importhook.install_import_hook('tinygrad')
from pathlib import Path
from typing import List
from typing import List, Optional
import argparse, json
import numpy as np
np.set_printoptions(linewidth=200)
@ -464,17 +464,19 @@ After you are done speaking, output [EOS]. You are not Chad.
toks = new_toks
assert outputted == llama.tokenizer.decode(toks)
tok_tensor: Optional[Tensor] = None
for i in range(args.count):
GlobalCounters.reset()
if args.timing or args.profile: print("")
st = GlobalCounters.time_sum_s
next_tok = Tensor([toks[start_pos:]], device=device) if tok_tensor is None or (len(toks)-start_pos) > 1 else tok_tensor.reshape(1, 1)
with Profiling(enabled=args.profile):
with Timing("total ", enabled=args.timing, on_exit=lambda x: f", {1e9/x:.2f} tok/s, {GlobalCounters.global_mem/x:.2f} GB/s, param {param_bytes/x:.2f} GB/s"):
with Timing("enqueue in ", on_exit=(lambda et: (f", {(GlobalCounters.time_sum_s-st)*1e3:.2f} ms on GPU" if DEBUG>=2 else "")+
f", {GlobalCounters.global_ops*1e-9:.2f} GOPS, {GlobalCounters.global_mem*1e-9:.2f} GB"+
(f", {GlobalCounters.global_mem*1e-9/(GlobalCounters.time_sum_s-st):.2f} GB/s, param {param_bytes*1e-9/(GlobalCounters.time_sum_s-st):.2f} GB/s" if DEBUG>=2 else "")) if DEBUG else None, enabled=args.timing):
tok_tensor = llama.model(Tensor([toks[start_pos:]], device=device), start_pos, args.temperature)
tok_tensor = llama.model(next_tok, start_pos, args.temperature)
tok = tok_tensor.item()
# use the kv cache

View File

@ -46,7 +46,13 @@ class Attention:
self.wo = linear(self.n_heads * self.head_dim, dim, bias=False)
def __call__(self, x:Tensor, start_pos:Union[Variable,int], freqs_cis:Tensor, mask:Optional[Tensor]) -> Tensor:
if getenv("WQKV"):
if not hasattr(self, 'wqkv'): self.wqkv = Tensor.cat(self.wq.weight, self.wk.weight, self.wv.weight)
xqkv = x @ self.wqkv.T
xq, xk, xv = xqkv.split([self.wq.weight.shape[0], self.wk.weight.shape[0], self.wv.weight.shape[0]], dim=2)
else:
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
xq = xq.reshape(xq.shape[0], xq.shape[1], self.n_heads, self.head_dim)
xk = xk.reshape(xk.shape[0], xk.shape[1], self.n_kv_heads, self.head_dim)
xv = xv.reshape(xv.shape[0], xv.shape[1], self.n_kv_heads, self.head_dim)
@ -92,7 +98,7 @@ class TransformerBlock:
def __call__(self, x:Tensor, start_pos:Union[Variable,int], freqs_cis:Tensor, mask:Optional[Tensor]):
h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask)
return (h + self.feed_forward(self.ffn_norm(h))).contiguous()
return h + self.feed_forward(self.ffn_norm(h))
# standard openai sampling
def sample(logits: Tensor, temp: float, k: int, p: float, af: float, ap: float):

View File

@ -37,7 +37,7 @@ def fold_expanded(ex, buf):
used = set()
for rootsrc, offsets in offsets_rootsrc.items():
for o in offsets:
for fold_length in [4] if is_image else [4, 2]:
for fold_length in [4] if is_image else ([8,4,2] if buf.dtype == PtrDType(dtypes.half) and getenv("ALLOW_HALF8") else [4,2]):
if all((rootsrc,o+i) not in used and o+i in offsets for i in range(fold_length)):
load_1 = new_srcs[offsets[o]]
new_src = list(load_1.src)