fix llama n_kv_heads in kvcache (#2267)

* fix llama n_kv_heads in kvcache

* trigger ci
This commit is contained in:
chenyu 2023-11-10 21:44:39 -05:00 committed by GitHub
parent 78623ba204
commit 880e693207
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed file with 1 addition and 1 deletion

View File

@@ -84,7 +84,7 @@ class Attention:
# create kv cache
if not hasattr(self, "cache_k"):
-self.cache_k, self.cache_v = Tensor.zeros(bsz, MAX_CONTEXT, self.n_heads, self.head_dim), Tensor.zeros(bsz, MAX_CONTEXT, self.n_heads, self.head_dim)
+self.cache_k, self.cache_v = Tensor.zeros(bsz, MAX_CONTEXT, self.n_kv_heads, self.head_dim), Tensor.zeros(bsz, MAX_CONTEXT, self.n_kv_heads, self.head_dim)
keys = self.cache_k.shrink((None, (0, start_pos), None, None)).cat(xk, dim=1)
values = self.cache_v.shrink((None, (0, start_pos), None, None)).cat(xv, dim=1)