mirror of https://github.com/commaai/tinygrad.git
Ensure freqs as type float32 in freqs_cis (#1798)
This commit is contained in:
parent
35072877ef
commit
fd25792c8b
|
@ -23,7 +23,7 @@ JIT = getenv("JIT", 0 if CI else 1)
|
|||
# https://github.com/facebookresearch/llama/blob/1076b9c51c77ad06e9d7ba8a4c6df775741732bd/llama/model.py#L47
|
||||
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
|
||||
freqs = 1.0 / (theta ** (np.arange(0, dim, 2, dtype=np.float32)[:(dim // 2)] / dim))
|
||||
freqs = np.outer(np.arange(end, dtype=np.float32), freqs)
|
||||
freqs = np.outer(np.arange(end, dtype=np.float32), freqs.astype(np.float32))
|
||||
return np.stack([np.cos(freqs), np.sin(freqs)], axis=-1).reshape(1, end, 1, dim//2, 2)
|
||||
|
||||
# (a+i*b) * (c+i*d) = (ac-bd) + i*(ad+bc)
|
||||
|
|
Loading…
Reference in New Issue