diff --git a/examples/whisper.py b/examples/whisper.py index 5caf1f2f..1e25da65 100644 --- a/examples/whisper.py +++ b/examples/whisper.py @@ -1,18 +1,13 @@ # thanks to https://github.com/openai/whisper for a good chunk of MIT licensed code -import sys -import pathlib -import base64 -import multiprocessing -import numpy as np +import sys, base64, multiprocessing, itertools from typing import Optional, Union, Literal, List -from tinygrad.engine.jit import TinyJit + +from tinygrad import Tensor, TinyJit, Variable, nn from tinygrad.nn.state import torch_load, load_state_dict -from tinygrad.helpers import getenv, DEBUG, CI, fetch -import tinygrad.nn as nn -from tinygrad.shape.symbolic import Variable -from tinygrad.tensor import Tensor -import itertools +from tinygrad.helpers import getenv, DEBUG, fetch + +import numpy as np import librosa class MultiHeadAttention: @@ -33,9 +28,8 @@ class MultiHeadAttention: if not hasattr(self, 'cache_k'): self.cache_k, self.cache_v = k, v else: - # see test_jitted_read_assign in test_jit.py. more context https://github.com/tinygrad/tinygrad/pull/2360#issuecomment-1817989994 - self.cache_k.assign(k+1-1).realize() - self.cache_v.assign(v+1-1).realize() + self.cache_k.assign(k).realize() + self.cache_v.assign(v).realize() else: k, v = self.cache_k, self.cache_v else: