mirror of https://github.com/commaai/tinygrad.git
JIT=0 llama.py should not jit (#2609)
This commit is contained in:
parent 41d696145d
commit 6ba6349c97
@@ -1,5 +1,6 @@
 from typing import Tuple, Union, Optional, Dict
 from tinygrad import Tensor, Variable, TinyJit, dtypes, nn, Device
+from tinygrad.helpers import getenv
 
 # https://github.com/facebookresearch/llama/blob/1076b9c51c77ad06e9d7ba8a4c6df775741732bd/llama/model.py#L47
 def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> Tensor:
@@ -118,7 +119,7 @@ class Transformer:
 
   def __call__(self, tokens:Tensor, start_pos:Variable, temperature:float=0.0):
     # TODO: better way to handle the first call v.s. the rest?
-    if tokens.shape[0:2] == (1,1) and self.forward_jit:
+    if tokens.shape[0:2] == (1,1) and self.forward_jit and getenv("JIT", 1):
       assert start_pos > 0
       return self.forward_jit(tokens, Variable("start_pos", 1, self.max_context).bind(start_pos), temperature)
     return self.forward(tokens, start_pos, temperature)
|
|
Loading…
Reference in New Issue
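The change gates llama.py's jitted single-token path on getenv("JIT", 1) from tinygrad.helpers: the JIT environment variable defaults to 1, so running with JIT=0 makes the condition falsy and every call goes through the plain forward. Below is a minimal sketch of that gating pattern; TinyModel and its fields are illustrative stand-ins, not the actual Transformer from llama.py.

from tinygrad import Tensor, TinyJit
from tinygrad.helpers import getenv

class TinyModel:
  # Illustrative model with both a jitted and a plain forward path (not llama.py code).
  def __init__(self):
    self.w = Tensor.randn(4, 4)
    self.forward_jit = TinyJit(self.forward)  # TinyJit-wrapped copy of forward

  def forward(self, x: Tensor) -> Tensor:
    return (x @ self.w).relu().realize()

  def __call__(self, x: Tensor) -> Tensor:
    # getenv("JIT", 1) reads the JIT env var as an int, defaulting to 1,
    # so JIT=0 in the environment skips the jitted path entirely.
    if getenv("JIT", 1):
      return self.forward_jit(x)
    return self.forward(x)

out = TinyModel()(Tensor.randn(1, 4))

Running the sketch as JIT=0 python3 sketch.py takes the non-jitted branch on every call, which matches what this commit makes llama.py do when JIT=0 is set.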