Ensure freqs as type float32 in freqs_cis (#1798)

This commit is contained in:
badcc 2023-09-06 10:24:15 -07:00 committed by GitHub
parent 35072877ef
commit fd25792c8b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 1 additions and 1 deletions

View File

@ -23,7 +23,7 @@ JIT = getenv("JIT", 0 if CI else 1)
# https://github.com/facebookresearch/llama/blob/1076b9c51c77ad06e9d7ba8a4c6df775741732bd/llama/model.py#L47
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
freqs = 1.0 / (theta ** (np.arange(0, dim, 2, dtype=np.float32)[:(dim // 2)] / dim))
freqs = np.outer(np.arange(end, dtype=np.float32), freqs)
freqs = np.outer(np.arange(end, dtype=np.float32), freqs.astype(np.float32))
return np.stack([np.cos(freqs), np.sin(freqs)], axis=-1).reshape(1, end, 1, dim//2, 2)
# (a+i*b) * (c+i*d) = (ac-bd) + i*(ad+bc)