mirror of https://github.com/commaai/tinygrad.git
fix llama n_kv_heads in kvcache (#2267)
* fix llama n_kv_heads in kvcache
* trigger ci
This commit is contained in:
parent
78623ba204
commit
880e693207
@@ -84,7 +84,7 @@ class Attention:
     # create kv cache
     if not hasattr(self, "cache_k"):
-      self.cache_k, self.cache_v = Tensor.zeros(bsz, MAX_CONTEXT, self.n_heads, self.head_dim), Tensor.zeros(bsz, MAX_CONTEXT, self.n_heads, self.head_dim)
+      self.cache_k, self.cache_v = Tensor.zeros(bsz, MAX_CONTEXT, self.n_kv_heads, self.head_dim), Tensor.zeros(bsz, MAX_CONTEXT, self.n_kv_heads, self.head_dim)

     keys = self.cache_k.shrink((None, (0, start_pos), None, None)).cat(xk, dim=1)
     values = self.cache_v.shrink((None, (0, start_pos), None, None)).cat(xv, dim=1)
Loading…
Reference in New Issue