diff --git a/examples/whisper.py b/examples/whisper.py
index 2dae8d17..8406b0a4 100644
--- a/examples/whisper.py
+++ b/examples/whisper.py
@@ -5,59 +5,80 @@ import pathlib
 import base64
 import multiprocessing
 import numpy as np
-from typing import Optional
+from typing import Optional, Union, Literal
 from extra.utils import download_file
+from tinygrad.jit import TinyJit
 from tinygrad.nn.state import torch_load, load_state_dict
-from tinygrad.helpers import getenv
+from tinygrad.helpers import getenv, DEBUG, CI
 import tinygrad.nn as nn
+from tinygrad.shape.symbolic import Variable
 from tinygrad.tensor import Tensor
 import itertools
 import librosa
 
-# TODO: you have written this fifteen times
 class MultiHeadAttention:
-  def __init__(self, n_state, n_head):
+  def __init__(self, n_state, n_head, kv_caching: Literal['cross', 'self']=None, max_self_attn_cache_len=None):
     self.n_head = n_head
     self.query = nn.Linear(n_state, n_state)
     self.key = nn.Linear(n_state, n_state, bias=False)
     self.value = nn.Linear(n_state, n_state)
     self.out = nn.Linear(n_state, n_state)
-  def __call__(self, x:Tensor, xa:Optional[Tensor]=None, mask:Optional[Tensor]=None):
+    self.kv_caching = kv_caching
+    self.max_self_attn_cache_len = max_self_attn_cache_len
+
+  def __call__(self, x:Tensor, xa:Optional[Tensor]=None, mask:Optional[Tensor]=None, len: Union[Variable,int]=None):
+    if self.kv_caching == 'cross':
+      if xa is not None:
+        k, v = self.key(xa), self.value(xa)
+        if not hasattr(self, 'cache_k'):
+          self.cache_k, self.cache_v = k, v
+        else:
+          # see test_jitted_read_assign in test_jit.py
+          self.cache_k.assign(k+1-1).realize()
+          self.cache_v.assign(v+1-1).realize()
+      else:
+        k, v = self.cache_k, self.cache_v
+    else:
+      k, v = self.key(x), self.value(x)
+      if self.kv_caching == 'self':
+        if not hasattr(self, 'cache_k'):
+          self.cache_k = Tensor.zeros(x.shape[0], self.max_self_attn_cache_len, x.shape[2])
+          self.cache_v = Tensor.zeros(x.shape[0], self.max_self_attn_cache_len, x.shape[2])
+        k = self.cache_k.shrink((None, (0, len), None)).cat(k, dim=1)
+        v = self.cache_v.shrink((None, (0, len), None)).cat(v, dim=1)
+        padding = self.max_self_attn_cache_len-len-x.shape[1]
+        self.cache_k.assign(k.pad((None, (0, padding), None)).contiguous()).realize()
+        self.cache_v.assign(v.pad((None, (0, padding), None)).contiguous()).realize()
+
     q = self.query(x)
-    k = self.key(xa or x)
-    v = self.value(xa or x)
-    wv, qk = self.qkv_attention(q, k, v, mask)
-    # NOTE: we aren't returning qk
+    n_ctx = q.shape[1]
+    assert(q.shape[-1] == k.shape[-1] == v.shape[-1])
+    head_dim = q.shape[-1] // self.n_head
+    q = q.reshape(*q.shape[:2], self.n_head, head_dim).permute(0, 2, 1, 3)
+    k = k.reshape(*k.shape[:2], self.n_head, head_dim).permute(0, 2, 1, 3)
+    v = v.reshape(*v.shape[:2], self.n_head, head_dim).permute(0, 2, 1, 3)
+    attn = Tensor.scaled_dot_product_attention(q, k, v, mask[:n_ctx,:n_ctx] if mask is not None else None)
+    wv = attn.permute(0, 2, 1, 3).flatten(start_dim=2)
     return self.out(wv)
 
-  def qkv_attention(self, q, k, v, mask=None):
-    n_batch, n_ctx, n_state = q.shape
-    scale = (n_state // self.n_head) ** -0.25
-    q = q.reshape(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale
-    k = k.reshape(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scale
-    v = v.reshape(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
-    qk = q @ k
-    if mask is not None: qk = qk + mask[:n_ctx, :n_ctx]
-    w = qk.softmax(-1)
-    return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach()
 
 class ResidualAttentionBlock:
-  def __init__(self, n_state, n_head, cross_attention=False):
-    self.attn = MultiHeadAttention(n_state, n_head)
+  def __init__(self, n_state, n_head, is_decoder_block=False, max_self_attn_cache_len=None):
+    self.attn = MultiHeadAttention(n_state, n_head, kv_caching='self' if is_decoder_block else None, max_self_attn_cache_len=max_self_attn_cache_len)
     self.attn_ln = nn.LayerNorm(n_state)
-    self.cross_attn = MultiHeadAttention(n_state, n_head) if cross_attention else None
-    self.cross_attn_ln = nn.LayerNorm(n_state) if cross_attention else None
+    self.cross_attn = MultiHeadAttention(n_state, n_head, kv_caching='cross') if is_decoder_block else None
+    self.cross_attn_ln = nn.LayerNorm(n_state) if is_decoder_block else None
     self.mlp = [nn.Linear(n_state, n_state*4), Tensor.gelu, nn.Linear(n_state*4, n_state)]
     self.mlp_ln = nn.LayerNorm(n_state)
 
-  def __call__(self, x, xa=None, mask=None):
-    x = x + self.attn(self.attn_ln(x), mask=mask)
+  def __call__(self, x, xa=None, mask=None, len: Union[Variable, int]=None):
+    x = x + self.attn(self.attn_ln(x), mask=mask, len=len)
     if self.cross_attn: x = x + self.cross_attn(self.cross_attn_ln(x), xa)
     x = x + self.mlp_ln(x).sequential(self.mlp)
-    return x
+    return x.realize()
 
 class AudioEncoder:
   def __init__(self, n_mels, n_audio_ctx, n_audio_state, n_audio_head, n_audio_layer, **_):
@@ -66,6 +87,7 @@ class AudioEncoder:
     self.blocks = [ResidualAttentionBlock(n_audio_state, n_audio_head) for _ in range(n_audio_layer)]
     self.ln_post = nn.LayerNorm(n_audio_state)
     self.positional_embedding = Tensor.empty(n_audio_ctx, n_audio_state)
+    self.encode = TinyJit(self.__call__)
 
   def __call__(self, x):
     x = self.conv1(x).gelu()
@@ -74,61 +96,82 @@ class AudioEncoder:
     x = x + self.positional_embedding[:x.shape[1]]
     x = x.sequential(self.blocks)
     x = self.ln_post(x)
-    return x
+    return x.realize()
 
 class TextDecoder:
   def __init__(self, n_vocab, n_text_ctx, n_text_state, n_text_head, n_text_layer, **_):
+    self.max_tokens_to_sample = n_text_ctx // 2
+    self.max_self_attn_cache_len = self.max_tokens_to_sample * 2 # double the size as an extra buffer for prefix/start tokens
+
     self.token_embedding = nn.Embedding(n_vocab, n_text_state)
     self.positional_embedding = Tensor.empty(n_text_ctx, n_text_state)
-    self.blocks = [ResidualAttentionBlock(n_text_state, n_text_head, cross_attention=True) for _ in range(n_text_layer)]
+    self.blocks = [ResidualAttentionBlock(n_text_state, n_text_head, is_decoder_block=True, max_self_attn_cache_len=self.max_self_attn_cache_len) for _ in range(n_text_layer)]
     self.ln = nn.LayerNorm(n_text_state)
-    #mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)
+    self.mask = Tensor.full((n_text_ctx, n_text_ctx), -np.inf).triu(1).realize()
+    self.blocks_start_tok = [TinyJit(block.__call__) for block in self.blocks]
+    self.blocks_after_start_tok = [TinyJit(block.__call__) for block in self.blocks]
+    self.start_output_tok = TinyJit(self.output_tok)
+    self.after_start_output_tok = TinyJit(self.output_tok)
 
-  def __call__(self, x, xa):
-    offset = 0
-    x = self.token_embedding(x) + self.positional_embedding[offset : offset + x.shape[-1]]
+  def __call__(self, x: Tensor, pos: int, encoded_audio: Tensor, streaming=False):
+    seqlen = x.shape[-1]
+    x = self.token_embedding(x) + self.positional_embedding[pos:pos+seqlen]
+    if pos == 0:
+      for block in (self.blocks if streaming else self.blocks_start_tok):
+        x = block(x, xa=encoded_audio, mask=self.mask, len=0) # pass xa for cross attn kv caching
+      return self.output_tok(x) if streaming else self.start_output_tok(x)
+    else:
+      for block in self.blocks_after_start_tok:
+        len_v = Variable("self_attn_cache_len", 1, self.max_self_attn_cache_len).bind(pos)
+        x = block(x, mask=self.mask, len=len_v)
+      return self.after_start_output_tok(x)
 
-    seqlen, start_pos = x.shape[1], 0
-
-    mask = np.full((1, 1, seqlen, start_pos + seqlen), float("-inf"), dtype=np.float32)
-    mask = np.triu(mask, k=start_pos + 1) # TODO: this is hard to do in tinygrad
-    mask = Tensor(mask)
-
-    for block in self.blocks: x = block(x, xa, mask)
-    x = self.ln(x)
-    return x @ self.token_embedding.weight.T
+  def output_tok(self, x):
+    return (self.ln(x) @ self.token_embedding.weight.T).realize()
 
 class Whisper:
-  def __init__(self, dims):
+  def __init__(self, dims, batch_size=1):
     self.encoder = AudioEncoder(**dims)
     self.decoder = TextDecoder(**dims)
+    self.is_multilingual = dims["n_vocab"] == 51865
+    self.batch_size = batch_size
 
-  def __call__(self, mel:Tensor, tokens:Tensor):
-    return self.decoder(tokens, self.encoder(mel))
 
 RATE = 16000
 CHUNK = 1600
 RECORD_SECONDS = 10
 
-def prep_audio(waveform) -> Tensor:
+def prep_audio(waveform, batch_size) -> np.ndarray:
+  assert waveform is not None
+
+  def pad_or_trim(arr, target_len=480000):
+    curr_len = len(arr)
+    if curr_len == target_len:
+      return arr
+    elif curr_len < target_len:
+      return np.pad(arr, (0, target_len - curr_len), 'constant')
+    else:
+      return arr[:target_len]
+
+  waveform = np.array(list(map(pad_or_trim, waveform)))
+  assert waveform.shape[0] <= batch_size
+  # pad the waveform to match the model's batch_size to avoid JIT shape mismatch errors.
+  # if operations like conv in the AudioEncoder could support symbolic shapes, then we wouldn't need to do this here
+  if waveform.shape[0] < batch_size:
+    waveform = np.pad(waveform, pad_width=((0, batch_size - waveform.shape[0]), (0, 0)))
+
   N_FFT = 400
   HOP_LENGTH = 160
   N_MELS = 80
-  assert waveform is not None
-  waveform = waveform.reshape(1, -1)
-  stft = librosa.stft(waveform, n_fft=N_FFT, hop_length=HOP_LENGTH, window='hann', dtype=np.float32)
-  magnitudes = stft[..., :-1] ** 2
+  stft = librosa.stft(waveform, n_fft=N_FFT, hop_length=HOP_LENGTH, window='hann', dtype=np.csingle)
+  magnitudes = np.absolute(stft[..., :-1]) ** 2
   mel_spec = librosa.filters.mel(sr=RATE, n_fft=N_FFT, n_mels=N_MELS) @ magnitudes
-  log_spec = np.log10(np.clip(mel_spec, 1e-10, mel_spec.max() + 1e8))
+
+  log_spec = np.log10(np.clip(mel_spec, 1e-10, None))
+  log_spec = np.maximum(log_spec, log_spec.max() - 8.0)
   log_spec = (log_spec + 4.0) / 4.0
-  # https://github.com/openai/whisper/blob/b38a1f20f4b23f3f3099af2c3e0ca95627276ddf/whisper/audio.py#L19
-  n_frames = log_spec.shape[2]
-  if n_frames < 3000:
-    log_spec = np.pad(log_spec, ((0, 0), (0, 0), (0, 3000 - n_frames)))
-
-  #print(waveform.shape, log_spec.shape)
   return log_spec
 
 LANGUAGES = {
@@ -145,9 +188,11 @@ LANGUAGES = {
 }
 
 BASE = pathlib.Path(__file__).parents[1] / "weights"
-def get_encoding(n_vocab_in):
-  download_file("https://raw.githubusercontent.com/openai/whisper/main/whisper/assets/gpt2.tiktoken", BASE / "gpt2.tiktoken")
-  ranks = {base64.b64decode(token): int(rank) for token, rank in (line.split() for line in open(BASE / "gpt2.tiktoken") if line)}
+def get_encoding(encoding_name):
+  filename = encoding_name + ".tiktoken"
+  download_file("https://raw.githubusercontent.com/openai/whisper/main/whisper/assets/" + filename, BASE / filename)
+  with open(BASE / filename) as f:
+    ranks = {base64.b64decode(token): int(rank) for token, rank in (line.split() for line in f if line)}
   n_vocab = len(ranks)
   specials = [
     "<|endoftext|>",
@@ -163,10 +208,9 @@ def get_encoding(n_vocab_in):
   ]
   special_tokens = dict(zip(specials, itertools.count(n_vocab)))
   n_vocab += len(specials)
-  assert n_vocab == n_vocab_in
   import tiktoken
   return tiktoken.Encoding(
-    name="bob",
+    name=encoding_name,
     explicit_n_vocab=n_vocab,
     pat_str=r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""",
     mergeable_ranks=ranks,
@@ -202,39 +246,70 @@ MODEL_URLS = {
   "large-v2": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
   "large": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
 }
-def init_whisper(model_name="tiny.en"):
+def init_whisper(model_name="tiny.en", batch_size=1):
   assert MODEL_URLS[model_name] is not None
 
   filename = BASE / "whisper-{}.pt".format(model_name)
   download_file(MODEL_URLS[model_name], filename)
   state = torch_load(filename)
-  model = Whisper(state['dims'])
-  load_state_dict(model, state['model_state_dict'])
-
-  enc = get_encoding(state['dims']['n_vocab'])
+  model = Whisper(state['dims'], batch_size)
+  load_state_dict(model, state['model_state_dict'], strict=False)
+  enc = get_encoding("multilingual" if model.is_multilingual else "gpt2")
   return model, enc
 
-def transcribe_file(model, enc, filename):
-  waveform, sample_rate = librosa.load(filename, sr=RATE)
-  log_spec = prep_audio(waveform)
-  lst = [enc._special_tokens["<|startoftranscript|>"], enc._special_tokens["<|notimestamps|>"]]
-  dat = model.encoder(Tensor(log_spec)).realize()
+def load_file_waveform(filename):
+  waveform, _ = librosa.load(filename, sr=RATE)
+  return waveform
 
-  for i in range(50):
-    out = model.decoder(Tensor([lst]), dat).realize()
-    idx = int(out[0,-1].argmax().numpy().item())
-    lst.append(idx)
-    transcription = enc.decode(lst)
-    print(transcription)
-    if lst[-1] == enc._special_tokens["<|endoftext|>"]:
-      return transcription
+def transcribe_file(model, enc, filename):
+  return transcribe_waveform(model, enc, [load_file_waveform(filename)])
+
+def transcribe_waveform(model, enc, waveforms):
+  """
+  Expects an array of shape (N,S) where N is the number of waveforms to transcribe in parallel and S is the number of 16000Hz samples
+  Returns the transcribed text if a single waveform is provided, or an array of transcriptions if multiple are provided
+  """
+  N_audio = len(waveforms)
+  log_spec = prep_audio(waveforms, model.batch_size)
+  encoded_audio = model.encoder.encode(Tensor(log_spec))
+  # print('encoded audio', np.sum(encoded_audio.numpy()))
+
+  start_tokens = [enc._special_tokens["<|startoftranscript|>"]]
+  if model.is_multilingual:
+    # TODO detect language
+    language_token = enc._special_tokens["<|startoftranscript|>"] + 1 + tuple(LANGUAGES.keys()).index("en")
+    start_tokens.append(language_token)
+    start_tokens.append(enc._special_tokens["<|transcribe|>"])
+  start_tokens.append(enc._special_tokens["<|notimestamps|>"])
+
+  transcription_start_index = len(start_tokens)
+  eot = enc._special_tokens["<|endoftext|>"]
+  tokens = np.tile(start_tokens, (log_spec.shape[0], 1))
+
+  pos = 0
+  for i in range(model.decoder.max_tokens_to_sample):
+    out = model.decoder(Tensor(tokens if i == 0 else tokens[:, -1:]), pos, encoded_audio)
+    next_tokens = out[:, -1].argmax(axis=-1).numpy().astype(np.int32)
+    next_tokens[tokens[:, -1] == eot] = eot
+    tokens = np.concatenate((tokens, next_tokens.reshape(-1, 1)), axis=1)
+    pos = tokens.shape[-1] - 1
+    if DEBUG >= 1: print(i, list(map(lambda tokens: enc.decode(tokens), tokens)))
+    if (tokens[:, -1] == eot).all():
+      break
+
+  transcriptions = []
+  for t in tokens:
+    eot_index = np.where(t == eot)[0]
+    eot_index = None if len(eot_index) == 0 else eot_index[0]
+    transcriptions.append(enc.decode(t[transcription_start_index:eot_index]).strip())
+  return transcriptions[:N_audio] if N_audio > 1 else transcriptions[0]
 
 if __name__ == "__main__":
-  model, enc = init_whisper("small.en" if getenv("SMALL") else "tiny.en")
+  model, enc = init_whisper("small.en" if getenv("SMALL") else "tiny.en", batch_size=1)
 
   if len(sys.argv) > 1:
-    transcribe_file(model, enc, sys.argv[1])
+    print(transcribe_file(model, enc, sys.argv[1]))
   else:
     # online
@@ -253,13 +328,13 @@ if __name__ == "__main__":
       else:
         total = np.concatenate([total, waveform])
       did_read = True
     if did_read:
-      log_spec = prep_audio(total)
-      encoded_audio = model.encoder(Tensor(log_spec)).realize()
-      out = model.decoder(Tensor([lst]), encoded_audio).realize()
+      log_spec = prep_audio(total.reshape(1, -1), 1)
+      encoded_audio = model.encoder.encode(Tensor(log_spec))
+      # pass the previously inferred tokens as 'prefix' - https://github.com/openai/whisper/discussions/117#discussioncomment-3727051
+      out = model.decoder(Tensor([lst]), 0, encoded_audio, streaming=True).realize()
     idx = int(out[0,-1].argmax().numpy().item())
     lst.append(idx)
     dec = enc.decode(lst)
     print(dec) # DO NOT REMOVE PRINT. IT'S VERY IMPORTANT
     if dec.endswith("<|endoftext|>"):
-      #total = total[:, 320*(len(lst)-1):]
       lst.pop()
diff --git a/test/external/external_test_whisper_librispeech.py b/test/external/external_test_whisper_librispeech.py
new file mode 100644
index 00000000..49ce6a3c
--- /dev/null
+++ b/test/external/external_test_whisper_librispeech.py
@@ -0,0 +1,83 @@
+import unittest
+import torch
+import tqdm
+import torchaudio
+import pathlib
+import jiwer
+import os
+import numpy as np
+from whisper.normalizers import EnglishTextNormalizer
+from examples.whisper import init_whisper, transcribe_waveform
+
+class TestWhisperLibriSpeech(unittest.TestCase):
+  # reference WERs determined by running https://github.com/openai/whisper/blob/main/notebooks/LibriSpeech.ipynb
+  # the values should be consistent with the paper D.1.1 https://cdn.openai.com/papers/whisper.pdf#page=22
+  # tinygrad WERs do not perfectly match due to what seem to be precision differences vs torch
+  def test_en_tiny(self):
+    run_evaluation("tiny.en", 0.056629001883239174, 0.05655609406528749)
+
+  def test_tiny(self):
+    run_evaluation("tiny", 0.0771121409407306, 0.07558413638335187)
+
+  def test_en_base(self):
+    run_evaluation("base.en", 0.041412520064205455, 0.04271408904897505)
+
+  def test_en_small(self):
+    run_evaluation("small.en", 0.03369011117172363, 0.030531615969223228)
+
+def run_evaluation(model_name, tinygrad_expected_wer, reference_wer):
+  dataset = LibriSpeech()
+  batch_size=16
+  loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size)
+
+  model, enc = init_whisper(model_name, batch_size=batch_size)
+
+  hypotheses = []
+  references = []
+
+  for audio, texts in tqdm.tqdm(loader):
+    transcriptions = transcribe_waveform(model, enc, audio.numpy())
+    hypotheses.extend(transcriptions)
+    references.extend(texts)
+
+  normalizer = EnglishTextNormalizer()
+  normalized_hypotheses = [normalizer(text) for text in hypotheses]
+  normalized_references = [normalizer(text) for text in references]
+  wer = jiwer.wer(normalized_hypotheses, normalized_references)
+
+  np.testing.assert_almost_equal(wer, tinygrad_expected_wer)
+  print(f'tinygrad WER {wer} vs reference WER {reference_wer}')
+  del model, enc
+
+class LibriSpeech(torch.utils.data.Dataset):
+  def __init__(self):
+    dir = pathlib.Path(__file__).parent.parent.parent / "extra" / "datasets" / "librispeech"
+    if not os.path.exists(dir):
+      os.makedirs(dir)
+
+    self.dataset = torchaudio.datasets.LIBRISPEECH(
+      root=dir,
+      url="test-clean",
+      download=True,
+    )
+
+  def __len__(self):
+    return len(self.dataset)
+
+  def __getitem__(self, item):
+    audio, sample_rate, text, _, _, _ = self.dataset[item]
+    assert sample_rate == 16000
+    return pad_or_trim_tensor(audio[0]), text
+
+def pad_or_trim_tensor(tensor, target_len=480000):
+  curr_len = len(tensor)
+  if curr_len == target_len:
+    return tensor
+  elif curr_len < target_len:
+    return torch.cat((tensor, torch.zeros(target_len - curr_len)))
+  else:
+    return tensor[:target_len]
+
+
+if __name__ == '__main__':
+  unittest.main()
diff --git a/test/models/test_whisper.py b/test/models/test_whisper.py
index b68dea50..835c238f 100644
--- a/test/models/test_whisper.py
+++ b/test/models/test_whisper.py
@@ -1,13 +1,13 @@
 import unittest
 import pathlib
+from examples.whisper import init_whisper, load_file_waveform, transcribe_file, transcribe_waveform
 from tinygrad.ops import Device
-from examples.whisper import init_whisper, transcribe_file
 
 @unittest.skipUnless(Device.DEFAULT == "METAL", "Some non-metal backends spend too long trying to allocate a 20GB array")
 class TestWhisper(unittest.TestCase):
   @classmethod
   def setUpClass(cls):
-    model, enc = init_whisper("tiny.en")
+    model, enc = init_whisper("tiny.en", batch_size=2)
     cls.model = model
     cls.enc = enc
 
@@ -22,4 +22,15 @@ class TestWhisper(unittest.TestCase):
     # We use the WAVE type because it's easier to decode in CI test environments
     filename = str(pathlib.Path(__file__).parent / "whisper/test.wav")
     transcription = transcribe_file(self.model, self.enc, filename)
-    self.assertEqual("<|startoftranscript|><|notimestamps|> Could you please let me out of the box?<|endoftext|>", transcription)
+    self.assertEqual("Could you please let me out of the box?", transcription)
+
+  def test_transcribe_batch(self):
+    file1 = str(pathlib.Path(__file__).parent / "whisper/test.wav")
+    file2 = str(pathlib.Path(__file__).parent / "whisper/test2.wav")
+
+    waveforms = [load_file_waveform(file1), load_file_waveform(file2)]
+
+    transcriptions = transcribe_waveform(self.model, self.enc, waveforms)
+    self.assertEqual(2, len(transcriptions))
+    self.assertEqual("Could you please let me out of the box?", transcriptions[0])
+    self.assertEqual("a slightly longer audio file so that we can test batch transcriptions of varying length.", transcriptions[1])
diff --git a/test/models/whisper/test2.wav b/test/models/whisper/test2.wav
new file mode 100644
index 00000000..e76b04fc
Binary files /dev/null and b/test/models/whisper/test2.wav differ
diff --git a/test/test_jit.py b/test/test_jit.py
index 99dd3dee..6439124c 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -7,8 +7,7 @@ import pytest
 
 pytestmark = pytest.mark.webgpu
 
-# NOTE: METAL fails, might be platform and optimization options dependent.
-@unittest.skipUnless(Device.DEFAULT not in ["METAL", "WEBGPU"], f"no JIT on {Device.DEFAULT}")
+@unittest.skipUnless(Device.DEFAULT != "WEBGPU", f"no JIT on {Device.DEFAULT}")
 class TestJit(unittest.TestCase):
   def test_simple_jit(self):
     @TinyJit
@@ -190,5 +189,56 @@ class TestJit(unittest.TestCase):
             [0., 2., 3., 1., 0.]]
     np.testing.assert_allclose(want, Y)
 
+  def test_jitted_read_assign(self):
+    class Cache:
+      def __init__(self):
+        self.good_cache = Tensor.zeros(1)
+        self.bad_cache = Tensor.zeros(1)
+        self.good_jitted = TinyJit(self.good)
+        self.bad_jitted = TinyJit(self.bad)
+
+      def good(self, y, cache_v=None):
+        if cache_v is not None:
+          self.good_cache.assign(cache_v+1-1).realize()
+        return (self.good_cache + y).realize()  # need + y to provide inputs to JIT
+
+      def bad(self, y, cache_v=None):
+        if cache_v is not None:
+          self.bad_cache.assign(cache_v).realize()
+        return (self.bad_cache + y).realize()
+
+    cache = Cache()
+    np.testing.assert_equal([0], cache.good_cache.numpy())
+    np.testing.assert_equal([0], cache.bad_cache.numpy())
+
+    zero = Tensor([0])
+    one = Tensor([1])
+    two = Tensor([2])
+
+    # save [1] in the caches
+    cache.good(zero, one)
+    cache.bad(zero, one)
+    np.testing.assert_equal([1], cache.good_cache.numpy())
+    np.testing.assert_equal([1], cache.bad_cache.numpy())
+
+    for i in range(5):
+      cache.good_jitted(zero)
+      cache.bad_jitted(zero)
+
+    # verify the jitted calls read 1 from the cache
+    np.testing.assert_equal([1], cache.good_jitted(zero).numpy())
+    np.testing.assert_equal([1], cache.bad_jitted(zero).numpy())
+
+    # save [2] in the caches
+    cache.good(zero, two)
+    cache.bad(zero, two)
+    np.testing.assert_equal([2], cache.good_cache)
+    np.testing.assert_equal([2], cache.bad_cache)
+
+    # verify the jitted calls read 2 from the cache
+    np.testing.assert_equal([2], cache.good_jitted(zero).numpy())
+    # but the bad_jitted doesn't!
+    np.testing.assert_equal([1], cache.bad_jitted(zero).numpy())
+
 if __name__ == '__main__':
   unittest.main()
\ No newline at end of file
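
Editor's note on the `k+1-1` / `v+1-1` assigns in MultiHeadAttention: the new test_jitted_read_assign above is the motivating case. As I read it, when a cache update is captured by TinyJit, a plain `cache.assign(new_value)` can amount to a buffer rebind that the jitted replay never re-executes, so jitted reads keep seeing the stale contents; the `+1-1` appears to force an elementwise kernel whose output is written into the cache's existing buffer, which the JIT does replay. A minimal, hypothetical sketch of the two patterns (the `cache` / `new_v` names are illustrative, not from the patch):

from tinygrad.tensor import Tensor

cache = Tensor.zeros(1)
new_v = Tensor([3.0])

# JIT-safe pattern used by the patch: the no-op arithmetic forces a kernel
# whose result lands in cache's existing buffer.
cache.assign(new_v + 1 - 1).realize()

# The pattern the test flags as "bad": a plain assign, which jitted readers
# may not observe (see test_jitted_read_assign above).
# cache.assign(new_v).realize()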
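
For orientation, a minimal sketch of how the batched API introduced in examples/whisper.py is meant to be driven, mirroring test_transcribe_batch above (the WAV paths are placeholders, not files shipped with the patch):

from examples.whisper import init_whisper, load_file_waveform, transcribe_waveform

# batch_size must cover the number of waveforms passed in one call;
# prep_audio pads the batch up to this size so the JIT sees fixed shapes.
model, enc = init_whisper("tiny.en", batch_size=2)

waveforms = [load_file_waveform("clip1.wav"), load_file_waveform("clip2.wav")]  # placeholder paths
transcriptions = transcribe_waveform(model, enc, waveforms)  # a list of strings when more than one waveform is given
print(transcriptions)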