mirror of https://github.com/commaai/tinygrad.git
extreme llama speed, 57.34 tok/s (#5827)

* extreme llama speed
* mergable
parent e6879035a0
commit 21c5e8e1b7

@@ -4,7 +4,7 @@
 #typeguard.importhook.install_import_hook('tinygrad')
 from pathlib import Path
-from typing import List
+from typing import List, Optional
 import argparse, json
 import numpy as np
 np.set_printoptions(linewidth=200)

@@ -464,17 +464,19 @@ After you are done speaking, output [EOS]. You are not Chad.
   toks = new_toks
   assert outputted == llama.tokenizer.decode(toks)
 
+  tok_tensor: Optional[Tensor] = None
   for i in range(args.count):
     GlobalCounters.reset()
 
     if args.timing or args.profile: print("")
     st = GlobalCounters.time_sum_s
+    next_tok = Tensor([toks[start_pos:]], device=device) if tok_tensor is None or (len(toks)-start_pos) > 1 else tok_tensor.reshape(1, 1)
     with Profiling(enabled=args.profile):
       with Timing("total ", enabled=args.timing, on_exit=lambda x: f", {1e9/x:.2f} tok/s, {GlobalCounters.global_mem/x:.2f} GB/s, param {param_bytes/x:.2f} GB/s"):
         with Timing("enqueue in ", on_exit=(lambda et: (f", {(GlobalCounters.time_sum_s-st)*1e3:.2f} ms on GPU" if DEBUG>=2 else "")+
                     f", {GlobalCounters.global_ops*1e-9:.2f} GOPS, {GlobalCounters.global_mem*1e-9:.2f} GB"+
                     (f", {GlobalCounters.global_mem*1e-9/(GlobalCounters.time_sum_s-st):.2f} GB/s, param {param_bytes*1e-9/(GlobalCounters.time_sum_s-st):.2f} GB/s" if DEBUG>=2 else "")) if DEBUG else None, enabled=args.timing):
-          tok_tensor = llama.model(Tensor([toks[start_pos:]], device=device), start_pos, args.temperature)
+          tok_tensor = llama.model(next_tok, start_pos, args.temperature)
         tok = tok_tensor.item()
 
     # use the kv cache

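
The new next_tok line above is the core of the decode-loop change: after the first step, the sampled token is already sitting on the device as a 1-element tensor, so it is fed straight back into the model instead of being round-tripped through Python and re-uploaded. Below is a stripped-down sketch of that pattern; the model is passed in as a callable standing in for llama.model, and the bookkeeping names are illustrative (only Tensor is tinygrad API here).

from typing import Callable, List, Optional
from tinygrad import Tensor

def decode(model: Callable, toks: List[int], start_pos: int, count: int,
           temperature: float, device: str) -> List[int]:
  # sketch of the loop in examples/llama.py around the hunk above
  tok_tensor: Optional[Tensor] = None
  for _ in range(count):
    if tok_tensor is None or (len(toks) - start_pos) > 1:
      # prefill (or first step): upload all pending tokens from host memory
      next_tok = Tensor([toks[start_pos:]], device=device)
    else:
      # steady-state decode: the previous output is already a 1-element tensor
      # on the device, so reshape it to (batch=1, seqlen=1) instead of re-uploading
      next_tok = tok_tensor.reshape(1, 1)
    tok_tensor = model(next_tok, start_pos, temperature)
    tok = tok_tensor.item()        # the single device->host sync per token
    start_pos = len(toks)          # kv cache now covers everything currently in toks
    toks.append(tok)               # the new token is fed on the next iteration
  return toks
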
@@ -46,7 +46,13 @@ class Attention:
     self.wo = linear(self.n_heads * self.head_dim, dim, bias=False)
 
   def __call__(self, x:Tensor, start_pos:Union[Variable,int], freqs_cis:Tensor, mask:Optional[Tensor]) -> Tensor:
-    xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
+    if getenv("WQKV"):
+      if not hasattr(self, 'wqkv'): self.wqkv = Tensor.cat(self.wq.weight, self.wk.weight, self.wv.weight)
+      xqkv = x @ self.wqkv.T
+      xq, xk, xv = xqkv.split([self.wq.weight.shape[0], self.wk.weight.shape[0], self.wv.weight.shape[0]], dim=2)
+    else:
+      xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
+
     xq = xq.reshape(xq.shape[0], xq.shape[1], self.n_heads, self.head_dim)
     xk = xk.reshape(xk.shape[0], xk.shape[1], self.n_kv_heads, self.head_dim)
     xv = xv.reshape(xv.shape[0], xv.shape[1], self.n_kv_heads, self.head_dim)

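
The WQKV branch above folds the three projection matmuls into a single GEMM against a row-concatenated weight and then splits the result back into q/k/v, trading three small kernels for one larger one. A minimal, self-contained illustration of the same trick with made-up shapes (dim=8, q_out=8, kv_out=4, seqlen=3); everything here is for demonstration only.

from tinygrad import Tensor

dim, q_out, kv_out = 8, 8, 4
wq, wk, wv = [Tensor.randn(out, dim) for out in (q_out, kv_out, kv_out)]
x = Tensor.randn(1, 3, dim)       # (batch, seqlen, dim)

# unfused: three matmuls against three weight matrices
xq0, xk0, xv0 = x @ wq.T, x @ wk.T, x @ wv.T

# fused: one matmul against the concatenated weight, then a split on the last dim
wqkv = Tensor.cat(wq, wk, wv)     # (q_out + kv_out + kv_out, dim)
xqkv = x @ wqkv.T
xq1, xk1, xv1 = xqkv.split([q_out, kv_out, kv_out], dim=2)

# same numbers up to float error, one GEMM instead of three
for a, b in [(xq0, xq1), (xk0, xk1), (xv0, xv1)]:
  assert (a - b).abs().max().item() < 1e-5
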
@@ -92,7 +98,7 @@ class TransformerBlock:
 
   def __call__(self, x:Tensor, start_pos:Union[Variable,int], freqs_cis:Tensor, mask:Optional[Tensor]):
     h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask)
-    return (h + self.feed_forward(self.ffn_norm(h))).contiguous()
+    return h + self.feed_forward(self.ffn_norm(h))
 
 # standard openai sampling
 def sample(logits: Tensor, temp: float, k: int, p: float, af: float, ap: float):

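
Dropping the .contiguous() on the block output removes a forced realization boundary after every transformer layer, so the scheduler is free to fuse the residual add into whatever consumes it. A toy demonstration of that effect, unrelated to the llama graph itself: the relu/exp ops are stand-ins, and the printed counts are simply whatever the scheduler produces on your backend.

from tinygrad import Tensor, GlobalCounters

def block(x: Tensor, cut: bool) -> Tensor:
  h = x + x.relu()                 # stand-in for the attention residual
  out = h + h.exp()                # stand-in for the feed-forward residual
  # .contiguous() forces this expression to be realized as its own kernel
  return out.contiguous() if cut else out

GlobalCounters.reset()
(block(Tensor.randn(16, 16), cut=True) * 2).realize()
print("kernels with .contiguous():   ", GlobalCounters.kernel_count)

GlobalCounters.reset()
(block(Tensor.randn(16, 16), cut=False) * 2).realize()
print("kernels without .contiguous():", GlobalCounters.kernel_count)
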
@@ -37,7 +37,7 @@ def fold_expanded(ex, buf):
   used = set()
   for rootsrc, offsets in offsets_rootsrc.items():
     for o in offsets:
-      for fold_length in [4] if is_image else [4, 2]:
+      for fold_length in [4] if is_image else ([8,4,2] if buf.dtype == PtrDType(dtypes.half) and getenv("ALLOW_HALF8") else [4,2]):
        if all((rootsrc,o+i) not in used and o+i in offsets for i in range(fold_length)):
          load_1 = new_srcs[offsets[o]]
          new_src = list(load_1.src)

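
The widened fold_length list is easier to read pulled out on its own: for half pointers, an 8-wide fold packs 8 x 2-byte loads into one 16-byte access, the same width as the existing float4 path. A restated sketch of just that selection logic, with buf and is_image replaced by plain parameters (import paths assume this commit's layout of tinygrad.dtype and tinygrad.helpers).

from typing import List
from tinygrad.dtype import PtrDType, dtypes
from tinygrad.helpers import getenv

def candidate_fold_lengths(is_image: bool, buf_dtype) -> List[int]:
  # image buffers keep the existing 4-wide fold
  if is_image: return [4]
  # half pointers may additionally try an 8-wide fold (8 x 2 bytes = 16 bytes,
  # a float4-sized access) when ALLOW_HALF8 is set; otherwise fall back to 4/2
  if buf_dtype == PtrDType(dtypes.half) and getenv("ALLOW_HALF8"): return [8, 4, 2]
  return [4, 2]
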