JIT=0 llama.py should not jit (#2609)

This commit is contained in:
chenyu 2023-12-04 20:21:07 -05:00 committed by GitHub
parent 41d696145d
commit 6ba6349c97
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed file with 2 additions and 1 deletion

View File

@ -1,5 +1,6 @@
from typing import Tuple, Union, Optional, Dict
from tinygrad import Tensor, Variable, TinyJit, dtypes, nn, Device
from tinygrad.helpers import getenv
# https://github.com/facebookresearch/llama/blob/1076b9c51c77ad06e9d7ba8a4c6df775741732bd/llama/model.py#L47
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> Tensor:
@ -118,7 +119,7 @@ class Transformer:
def __call__(self, tokens:Tensor, start_pos:Variable, temperature:float=0.0):
# TODO: better way to handle the first call v.s. the rest?
if tokens.shape[0:2] == (1,1) and self.forward_jit:
if tokens.shape[0:2] == (1,1) and self.forward_jit and getenv("JIT", 1):
assert start_pos > 0
return self.forward_jit(tokens, Variable("start_pos", 1, self.max_context).bind(start_pos), temperature)
return self.forward(tokens, start_pos, temperature)