mirror of https://github.com/commaai/tinygrad.git
cleanup llama apply_rotary_emb and other helpers (#2950)
* cleanup llama apply_rotary_emb and other helpers: use ellipsis indexing and other higher-level tensor functions. Disable the half @ half -> half tensor core, as it fails the uop dtype checks.
* keep the hip 8x8->8 wmma
parent 61e255d197
commit ad4472e6e8
@@ -10,17 +10,17 @@ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> Tensor:
 
 # (a+i*b) * (c+i*d) = (ac-bd) + i*(ad+bc)
 def complex_mult(A, c, d):
-  a,b = A[:, :, :, :, 0:1], A[:, :, :, :, 1:2]
+  a,b = A[..., 0:1], A[..., 1:2]
   ro = a*c - b*d
   co = a*d + b*c
   return ro.cat(co, dim=-1)
 
 def apply_rotary_emb(xq, xk, freqs_cis) -> Tuple[Tensor, Tensor]:
-  assert freqs_cis.shape[1] == xq.shape[1] and freqs_cis.shape[1] == xk.shape[1], f"freqs_cis shape mismatch {freqs_cis.shape} xq:{xq.shape} xk:{xk.shape}"
+  assert freqs_cis.shape[1] == xq.shape[1] == xk.shape[1], f"freqs_cis shape mismatch {freqs_cis.shape} xq:{xq.shape} xk:{xk.shape}"
   xq = xq.reshape(*xq.shape[0:-1], -1, 2)
   xk = xk.reshape(*xk.shape[0:-1], -1, 2)
-  assert len(xq.shape) == 5 and len(xk.shape) == 5 and len(freqs_cis.shape) == 5
-  c, d = freqs_cis[:, :xq.shape[1], :, :, 0:1], freqs_cis[:, :xq.shape[1], :, :, 1:2]
+  assert len(xq.shape) == len(xk.shape) == len(freqs_cis.shape) == 5
+  c, d = freqs_cis[..., 0:1], freqs_cis[..., 1:2]
   xq_out = complex_mult(xq, c, d)
   xk_out = complex_mult(xk, c, d)
   return xq_out.flatten(3), xk_out.flatten(3)
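The ellipsis rewrite above leans on two facts: A[..., 0:1] is just shorthand for slicing the last axis of a 5-D tensor, and complex_mult is ordinary complex multiplication on (real, imag) pairs stored in that last axis. A minimal numpy sketch of both, with made-up shapes rather than the actual llama shapes:

import numpy as np

A = np.random.rand(1, 3, 2, 4, 2)  # toy 5-D tensor; last axis holds (real, imag) pairs

# ellipsis indexing is shorthand for spelling out every leading dimension
assert (A[..., 0:1] == A[:, :, :, :, 0:1]).all()

# complex_mult computes (a+i*b) * (c+i*d) = (ac-bd) + i*(ad+bc) on those pairs
a, b = A[..., 0:1], A[..., 1:2]
c, d = 0.6, 0.8                    # stand-in for one cos/sin entry of freqs_cis
ro, co = a*c - b*d, a*d + b*c
assert np.allclose(ro + 1j*co, (a + 1j*b) * (c + 1j*d))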
@@ -28,7 +28,8 @@ def apply_rotary_emb(xq, xk, freqs_cis) -> Tuple[Tensor, Tensor]:
 def repeat_kv(x:Tensor, n_rep:int) -> Tensor:
   bs, seqlen, n_kv_heads, head_dim = x.shape
   if n_rep == 1: return x
-  return x.reshape(bs, seqlen, n_kv_heads, 1, head_dim).expand(bs, seqlen, n_kv_heads, n_rep, head_dim).reshape(bs, seqlen, n_kv_heads * n_rep, head_dim)
+  # NOTE: this is different from x.repeat((1, 1, n_rep, 1))
+  return x.repeat((1, 1, 1, n_rep)).reshape(bs, seqlen, n_kv_heads * n_rep, head_dim)
 
 class RMSNorm:
   def __init__(self, dim, eps=1e-6):
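The NOTE in repeat_kv is the subtle part: repeating the last axis and then reshaping keeps each kv head's copies adjacent, which is what the old reshape/expand/reshape produced, whereas repeating the head axis directly would tile the whole block of heads. A quick numpy sketch with toy sizes, using np.tile as a stand-in for Tensor.repeat and np.broadcast_to for expand:

import numpy as np

bs, seqlen, n_kv_heads, head_dim, n_rep = 1, 2, 3, 4, 2
x = np.arange(bs*seqlen*n_kv_heads*head_dim).reshape(bs, seqlen, n_kv_heads, head_dim)

# old: add a size-1 axis, expand it to n_rep, fold it into the head axis -> heads ordered [0,0,1,1,2,2]
old = np.broadcast_to(x.reshape(bs, seqlen, n_kv_heads, 1, head_dim),
                      (bs, seqlen, n_kv_heads, n_rep, head_dim)).reshape(bs, seqlen, n_kv_heads*n_rep, head_dim)

# new: repeat the last axis, then split it back into (n_rep, head_dim) -> same ordering
new = np.tile(x, (1, 1, 1, n_rep)).reshape(bs, seqlen, n_kv_heads*n_rep, head_dim)

# repeating the head axis itself would order the copies [0,1,2,0,1,2] instead
wrong = np.tile(x, (1, 1, n_rep, 1))

assert (old == new).all() and not (old == wrong).all()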
@@ -54,7 +55,7 @@ class Attention:
 
   def __call__(self, x:Tensor, start_pos:Union[Variable,int], freqs_cis:Tensor, mask:Optional[Tensor]) -> Tensor:
-    x = x.half()
-    xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
+    xq, xk, xv = self.wq(x).half(), self.wk(x).half(), self.wv(x).half()
    xq = xq.reshape(xq.shape[0], xq.shape[1], self.n_heads, self.head_dim)
    xk = xk.reshape(xk.shape[0], xk.shape[1], self.n_kv_heads, self.head_dim)
    xv = xv.reshape(xv.shape[0], xv.shape[1], self.n_kv_heads, self.head_dim)
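The attention change only moves where the half cast happens: instead of casting x up front and feeding the projections half inputs, the projection outputs are cast to half afterwards. A rough numpy analogy of the two cast placements (the dtypes below are illustrative assumptions, not a claim about what llama actually feeds in):

import numpy as np

x = np.random.rand(2, 16).astype(np.float32)   # stand-in activations
w = np.random.rand(16, 16).astype(np.float16)  # stand-in projection weights

old = x.astype(np.float16) @ w        # cast the input first: the matmul itself is half @ half
new = (x @ w).astype(np.float16)      # matmul first, cast the result down afterwards

assert old.dtype == new.dtype == np.float16    # same output dtype, different compute path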
@@ -66,11 +67,11 @@ class Attention:
       self.cache_k = Tensor.zeros(bsz, self.max_context, self.n_kv_heads, self.head_dim, dtype=x.dtype)
       self.cache_v = Tensor.zeros(bsz, self.max_context, self.n_kv_heads, self.head_dim, dtype=x.dtype)
 
+    # TODO: fix coder, old hack did not work after the uop dtype check
     keys = self.cache_k.shrink((None, (0, start_pos), None, None)).cat(xk, dim=1)
     values = self.cache_v.shrink((None, (0, start_pos), None, None)).cat(xv, dim=1)
 
     # update the cache
+    assert keys.dtype == self.cache_k.dtype and values.dtype == self.cache_v.dtype, f"{keys.dtype=}, {values.dtype=}, {self.cache_k.dtype=}, {self.cache_v.dtype=}"
     self.cache_k.assign(keys.pad((None,(0,self.max_context-start_pos-seqlen),None,None)).contiguous()).realize()
     self.cache_v.assign(values.pad((None,(0,self.max_context-start_pos-seqlen),None,None)).contiguous()).realize()
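For reference, the shrink/cat/pad round trip in the cache update keeps the first start_pos positions of the cache, appends the new seqlen keys, and zero-pads back out to max_context before assigning over the cache buffer. A small numpy sketch with toy sizes (the batch size and dtype here are arbitrary choices, not the llama defaults):

import numpy as np

max_context, start_pos, seqlen, n_kv_heads, head_dim = 8, 3, 2, 1, 4
cache_k = np.zeros((1, max_context, n_kv_heads, head_dim), dtype=np.float16)
xk = np.ones((1, seqlen, n_kv_heads, head_dim), dtype=np.float16)

# shrink + cat: keep positions [0, start_pos) and append the new keys after them
keys = np.concatenate([cache_k[:, :start_pos], xk], axis=1)          # (1, start_pos+seqlen, 1, 4)

# pad: zero-fill the tail so the result covers the whole max_context buffer again
tail = np.zeros((1, max_context - start_pos - seqlen, n_kv_heads, head_dim), dtype=keys.dtype)
cache_k = np.concatenate([keys, tail], axis=1)                        # plays the role of assign

assert cache_k.shape == (1, max_context, n_kv_heads, head_dim) and cache_k.dtype == np.float16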
@@ -38,7 +38,8 @@ class TensorCore:
 tensor_cores: Dict[str, List[TensorCore]] = {
   "METAL": [
     TensorCore(device="METAL", dims=[8,8,8], dtype_in=dtypes.float, dtype_out=dtypes.float, upcast_dim=0, threads=[(0,2),(1,4),(0,2),(1,2)], thread_local_sizes=[2,2,2], thread_local_aliases= [ [[4],[0],[2],[0],[-1, 1, 3],[0]], [[0],[3],[0],[1],[2, 4],[-1]], [[4],[3],[2],[1],[0],[-1]] ], arch="arm64"), # noqa: E501
-    TensorCore(device="METAL", dims=[8,8,8], dtype_in=dtypes.half, dtype_out=dtypes.half, upcast_dim=0, threads=[(0,2),(1,4),(0,2),(1,2)], thread_local_sizes=[2,2,2], thread_local_aliases= [ [[4],[0],[2],[0],[-1, 1, 3],[0]], [[0],[3],[0],[1],[2, 4],[-1]], [[4],[3],[2],[1],[0],[-1]] ], arch="arm64"), # noqa: E501
+    # TODO: enable half @ half -> half tensor core with correct dtypes in uop
+    # TensorCore(device="METAL", dims=[8,8,8], dtype_in=dtypes.half, dtype_out=dtypes.half, upcast_dim=0, threads=[(0,2),(1,4),(0,2),(1,2)], thread_local_sizes=[2,2,2], thread_local_aliases= [ [[4],[0],[2],[0],[-1, 1, 3],[0]], [[0],[3],[0],[1],[2, 4],[-1]], [[4],[3],[2],[1],[0],[-1]] ], arch="arm64"), # noqa: E501
   ],
   "HIP": [
     TensorCore(device="HIP", dims=[16,16,16], dtype_in=dtypes.half, dtype_out=dtypes.float, upcast_dim=1, threads=[(0,16),(1,2)], thread_local_sizes=[16,16,8], thread_local_aliases=[ [[0],[0],[-1],[1]], [[0],[1],[-1],[0]], [[0],[1],[0],[2,-1]] ]), # noqa: E501
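With the half -> half entry commented out, this hunk leaves the METAL list exposing only the float -> float variant, so any selection keyed on the dtype pair simply stops matching half @ half -> half. A toy standalone sketch of that effect (namedtuples standing in for the real TensorCore objects; this is not tinygrad's actual selection code):

from collections import namedtuple

TC = namedtuple("TC", ["device", "dtype_in", "dtype_out"])   # toy stand-in for TensorCore

metal_tcs = [TC("METAL", "float", "float")]                   # the half->half entry is commented out above

def pick(tcs, dtype_in, dtype_out):
  # whatever selects a tensor core by dtype pair finds no match for the disabled combination
  return [tc for tc in tcs if (tc.dtype_in, tc.dtype_out) == (dtype_in, dtype_out)]

assert pick(metal_tcs, "float", "float") != []   # float accumulate path still available
assert pick(metal_tcs, "half", "half") == []     # half @ half -> half path disabled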