whisper: support batch inference, add librispeech WER test (#2074)

* whisper: support batch inference, add librispeech WER test, add kv caching and JIT

* remove JIT_SUPPORTED_DEVICE

---------

Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
mmmkkaaayy 2023-11-16 13:50:08 -08:00 committed by GitHub
parent 3baaf298d6
commit 8235da11dd
5 changed files with 309 additions and 90 deletions
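For orientation, a minimal usage sketch of the batched API this commit introduces, mirroring the new tests below (the audio file paths are placeholders):

from examples.whisper import init_whisper, load_file_waveform, transcribe_waveform

# batch_size fixes the decoder batch dimension up front so the JIT sees stable shapes
model, enc = init_whisper("tiny.en", batch_size=2)

# each waveform is a 16 kHz mono array; batches smaller than batch_size are padded internally
waveforms = [load_file_waveform("audio1.wav"), load_file_waveform("audio2.wav")]
transcriptions = transcribe_waveform(model, enc, waveforms)  # one transcription string per input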


@@ -5,59 +5,80 @@ import pathlib
import base64
import multiprocessing
import numpy as np
from typing import Optional
from typing import Optional, Union, Literal
from extra.utils import download_file
from tinygrad.jit import TinyJit
from tinygrad.nn.state import torch_load, load_state_dict
from tinygrad.helpers import getenv
from tinygrad.helpers import getenv, DEBUG, CI
import tinygrad.nn as nn
from tinygrad.shape.symbolic import Variable
from tinygrad.tensor import Tensor
import itertools
import librosa
# TODO: you have written this fifteen times
class MultiHeadAttention:
def __init__(self, n_state, n_head):
def __init__(self, n_state, n_head, kv_caching: Literal['cross', 'self']=None, max_self_attn_cache_len=None):
self.n_head = n_head
self.query = nn.Linear(n_state, n_state)
self.key = nn.Linear(n_state, n_state, bias=False)
self.value = nn.Linear(n_state, n_state)
self.out = nn.Linear(n_state, n_state)
def __call__(self, x:Tensor, xa:Optional[Tensor]=None, mask:Optional[Tensor]=None):
self.kv_caching = kv_caching
self.max_self_attn_cache_len = max_self_attn_cache_len
def __call__(self, x:Tensor, xa:Optional[Tensor]=None, mask:Optional[Tensor]=None, len: Union[Variable,int]=None):
if self.kv_caching == 'cross':
if xa is not None:
k, v = self.key(xa), self.value(xa)
if not hasattr(self, 'cache_k'):
self.cache_k, self.cache_v = k, v
else:
# see test_jitted_read_assign in test_jit.py
self.cache_k.assign(k+1-1).realize()
self.cache_v.assign(v+1-1).realize()
else:
k, v = self.cache_k, self.cache_v
else:
k, v = self.key(x), self.value(x)
if self.kv_caching == 'self':
if not hasattr(self, 'cache_k'):
self.cache_k = Tensor.zeros(x.shape[0], self.max_self_attn_cache_len, x.shape[2])
self.cache_v = Tensor.zeros(x.shape[0], self.max_self_attn_cache_len, x.shape[2])
k = self.cache_k.shrink((None, (0, len), None)).cat(k, dim=1)
v = self.cache_v.shrink((None, (0, len), None)).cat(v, dim=1)
padding = self.max_self_attn_cache_len-len-x.shape[1]
self.cache_k.assign(k.pad((None, (0, padding), None)).contiguous()).realize()
self.cache_v.assign(v.pad((None, (0, padding), None)).contiguous()).realize()
q = self.query(x)
k = self.key(xa or x)
v = self.value(xa or x)
wv, qk = self.qkv_attention(q, k, v, mask)
# NOTE: we aren't returning qk
n_ctx = q.shape[1]
assert(q.shape[-1] == k.shape[-1] == v.shape[-1])
head_dim = q.shape[-1] // self.n_head
q = q.reshape(*q.shape[:2], self.n_head, head_dim).permute(0, 2, 1, 3)
k = k.reshape(*k.shape[:2], self.n_head, head_dim).permute(0, 2, 1, 3)
v = v.reshape(*v.shape[:2], self.n_head, head_dim).permute(0, 2, 1, 3)
attn = Tensor.scaled_dot_product_attention(q, k, v, mask[:n_ctx,:n_ctx] if mask is not None else None)
wv = attn.permute(0, 2, 1, 3).flatten(start_dim=2)
return self.out(wv)
def qkv_attention(self, q, k, v, mask=None):
n_batch, n_ctx, n_state = q.shape
scale = (n_state // self.n_head) ** -0.25
q = q.reshape(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale
k = k.reshape(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scale
v = v.reshape(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
qk = q @ k
if mask is not None: qk = qk + mask[:n_ctx, :n_ctx]
w = qk.softmax(-1)
return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach()
class ResidualAttentionBlock:
def __init__(self, n_state, n_head, cross_attention=False):
self.attn = MultiHeadAttention(n_state, n_head)
def __init__(self, n_state, n_head, is_decoder_block=False, max_self_attn_cache_len=None):
self.attn = MultiHeadAttention(n_state, n_head, kv_caching='self' if is_decoder_block else None, max_self_attn_cache_len=max_self_attn_cache_len)
self.attn_ln = nn.LayerNorm(n_state)
self.cross_attn = MultiHeadAttention(n_state, n_head) if cross_attention else None
self.cross_attn_ln = nn.LayerNorm(n_state) if cross_attention else None
self.cross_attn = MultiHeadAttention(n_state, n_head, kv_caching='cross') if is_decoder_block else None
self.cross_attn_ln = nn.LayerNorm(n_state) if is_decoder_block else None
self.mlp = [nn.Linear(n_state, n_state*4), Tensor.gelu, nn.Linear(n_state*4, n_state)]
self.mlp_ln = nn.LayerNorm(n_state)
def __call__(self, x, xa=None, mask=None):
x = x + self.attn(self.attn_ln(x), mask=mask)
def __call__(self, x, xa=None, mask=None, len: Union[Variable, int]=None):
x = x + self.attn(self.attn_ln(x), mask=mask, len=len)
if self.cross_attn: x = x + self.cross_attn(self.cross_attn_ln(x), xa)
x = x + self.mlp_ln(x).sequential(self.mlp)
return x
return x.realize()
class AudioEncoder:
def __init__(self, n_mels, n_audio_ctx, n_audio_state, n_audio_head, n_audio_layer, **_):
@@ -66,6 +87,7 @@ class AudioEncoder:
self.blocks = [ResidualAttentionBlock(n_audio_state, n_audio_head) for _ in range(n_audio_layer)]
self.ln_post = nn.LayerNorm(n_audio_state)
self.positional_embedding = Tensor.empty(n_audio_ctx, n_audio_state)
self.encode = TinyJit(self.__call__)
def __call__(self, x):
x = self.conv1(x).gelu()
@@ -74,61 +96,82 @@ class AudioEncoder:
x = x + self.positional_embedding[:x.shape[1]]
x = x.sequential(self.blocks)
x = self.ln_post(x)
return x
return x.realize()
class TextDecoder:
def __init__(self, n_vocab, n_text_ctx, n_text_state, n_text_head, n_text_layer, **_):
self.max_tokens_to_sample = n_text_ctx // 2
self.max_self_attn_cache_len = self.max_tokens_to_sample * 2 # double the size as an extra buffer for prefix/start tokens
self.token_embedding = nn.Embedding(n_vocab, n_text_state)
self.positional_embedding = Tensor.empty(n_text_ctx, n_text_state)
self.blocks = [ResidualAttentionBlock(n_text_state, n_text_head, cross_attention=True) for _ in range(n_text_layer)]
self.blocks = [ResidualAttentionBlock(n_text_state, n_text_head, is_decoder_block=True, max_self_attn_cache_len=self.max_self_attn_cache_len) for _ in range(n_text_layer)]
self.ln = nn.LayerNorm(n_text_state)
#mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)
self.mask = Tensor.full((n_text_ctx, n_text_ctx), -np.inf).triu(1).realize()
self.blocks_start_tok = [TinyJit(block.__call__) for block in self.blocks]
self.blocks_after_start_tok = [TinyJit(block.__call__) for block in self.blocks]
self.start_output_tok = TinyJit(self.output_tok)
self.after_start_output_tok = TinyJit(self.output_tok)
def __call__(self, x, xa):
offset = 0
x = self.token_embedding(x) + self.positional_embedding[offset : offset + x.shape[-1]]
def __call__(self, x: Tensor, pos: int, encoded_audio: Tensor, streaming=False):
seqlen = x.shape[-1]
x = self.token_embedding(x) + self.positional_embedding[pos:pos+seqlen]
if pos == 0:
for block in (self.blocks if streaming else self.blocks_start_tok):
x = block(x, xa=encoded_audio, mask=self.mask, len=0) # pass xa for cross attn kv caching
return self.output_tok(x) if streaming else self.start_output_tok(x)
else:
for block in self.blocks_after_start_tok:
len_v = Variable("self_attn_cache_len", 1, self.max_self_attn_cache_len).bind(pos)
x = block(x, mask=self.mask, len=len_v)
return self.after_start_output_tok(x)
seqlen, start_pos = x.shape[1], 0
mask = np.full((1, 1, seqlen, start_pos + seqlen), float("-inf"), dtype=np.float32)
mask = np.triu(mask, k=start_pos + 1) # TODO: this is hard to do in tinygrad
mask = Tensor(mask)
for block in self.blocks: x = block(x, xa, mask)
x = self.ln(x)
return x @ self.token_embedding.weight.T
def output_tok(self, x):
return (self.ln(x) @ self.token_embedding.weight.T).realize()
class Whisper:
def __init__(self, dims):
def __init__(self, dims, batch_size=1):
self.encoder = AudioEncoder(**dims)
self.decoder = TextDecoder(**dims)
self.is_multilingual = dims["n_vocab"] == 51865
self.batch_size = batch_size
def __call__(self, mel:Tensor, tokens:Tensor):
return self.decoder(tokens, self.encoder(mel))
RATE = 16000
CHUNK = 1600
RECORD_SECONDS = 10
def prep_audio(waveform) -> Tensor:
def prep_audio(waveform, batch_size) -> np.ndarray:
assert waveform is not None
def pad_or_trim(arr, target_len=480000):
curr_len = len(arr)
if curr_len == target_len:
return arr
elif curr_len < target_len:
return np.pad(arr, (0, target_len - curr_len), 'constant')
else:
return arr[:target_len]
waveform = np.array(list(map(pad_or_trim, waveform)))
assert waveform.shape[0] <= batch_size
# pad the waveform to match the model's batch_size to avoid JIT shape mismatch errors.
# if operations like conv in the AudioEncoder could support symbolic shapes, then we wouldn't need to do this here
if waveform.shape[0] < batch_size:
waveform = np.pad(waveform, pad_width=((0, batch_size - waveform.shape[0]), (0, 0)))
N_FFT = 400
HOP_LENGTH = 160
N_MELS = 80
assert waveform is not None
waveform = waveform.reshape(1, -1)
stft = librosa.stft(waveform, n_fft=N_FFT, hop_length=HOP_LENGTH, window='hann', dtype=np.float32)
magnitudes = stft[..., :-1] ** 2
stft = librosa.stft(waveform, n_fft=N_FFT, hop_length=HOP_LENGTH, window='hann', dtype=np.csingle)
magnitudes = np.absolute(stft[..., :-1]) ** 2
mel_spec = librosa.filters.mel(sr=RATE, n_fft=N_FFT, n_mels=N_MELS) @ magnitudes
log_spec = np.log10(np.clip(mel_spec, 1e-10, mel_spec.max() + 1e8))
log_spec = np.log10(np.clip(mel_spec, 1e-10, None))
log_spec = np.maximum(log_spec, log_spec.max() - 8.0)
log_spec = (log_spec + 4.0) / 4.0
# https://github.com/openai/whisper/blob/b38a1f20f4b23f3f3099af2c3e0ca95627276ddf/whisper/audio.py#L19
n_frames = log_spec.shape[2]
if n_frames < 3000:
log_spec = np.pad(log_spec, ((0, 0), (0, 0), (0, 3000 - n_frames)))
#print(waveform.shape, log_spec.shape)
return log_spec
LANGUAGES = {
@@ -145,9 +188,11 @@ LANGUAGES = {
}
BASE = pathlib.Path(__file__).parents[1] / "weights"
def get_encoding(n_vocab_in):
download_file("https://raw.githubusercontent.com/openai/whisper/main/whisper/assets/gpt2.tiktoken", BASE / "gpt2.tiktoken")
ranks = {base64.b64decode(token): int(rank) for token, rank in (line.split() for line in open(BASE / "gpt2.tiktoken") if line)}
def get_encoding(encoding_name):
filename = encoding_name + ".tiktoken"
download_file("https://raw.githubusercontent.com/openai/whisper/main/whisper/assets/" + filename, BASE / filename)
with open(BASE / filename) as f:
ranks = {base64.b64decode(token): int(rank) for token, rank in (line.split() for line in f if line)}
n_vocab = len(ranks)
specials = [
"<|endoftext|>",
@@ -163,10 +208,9 @@ def get_encoding(n_vocab_in):
]
special_tokens = dict(zip(specials, itertools.count(n_vocab)))
n_vocab += len(specials)
assert n_vocab == n_vocab_in
import tiktoken
return tiktoken.Encoding(
name="bob",
name=encoding_name,
explicit_n_vocab=n_vocab,
pat_str=r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""",
mergeable_ranks=ranks,
@@ -202,39 +246,70 @@ MODEL_URLS = {
"large-v2": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
"large": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
}
def init_whisper(model_name="tiny.en"):
def init_whisper(model_name="tiny.en", batch_size=1):
assert MODEL_URLS[model_name] is not None
filename = BASE / "whisper-{}.pt".format(model_name)
download_file(MODEL_URLS[model_name], filename)
state = torch_load(filename)
model = Whisper(state['dims'])
load_state_dict(model, state['model_state_dict'])
enc = get_encoding(state['dims']['n_vocab'])
model = Whisper(state['dims'], batch_size)
load_state_dict(model, state['model_state_dict'], strict=False)
enc = get_encoding("multilingual" if model.is_multilingual else "gpt2")
return model, enc
def transcribe_file(model, enc, filename):
waveform, sample_rate = librosa.load(filename, sr=RATE)
log_spec = prep_audio(waveform)
lst = [enc._special_tokens["<|startoftranscript|>"], enc._special_tokens["<|notimestamps|>"]]
dat = model.encoder(Tensor(log_spec)).realize()
def load_file_waveform(filename):
waveform, _ = librosa.load(filename, sr=RATE)
return waveform
for i in range(50):
out = model.decoder(Tensor([lst]), dat).realize()
idx = int(out[0,-1].argmax().numpy().item())
lst.append(idx)
transcription = enc.decode(lst)
print(transcription)
if lst[-1] == enc._special_tokens["<|endoftext|>"]:
return transcription
def transcribe_file(model, enc, filename):
return transcribe_waveform(model, enc, [load_file_waveform(filename)])
def transcribe_waveform(model, enc, waveforms):
"""
Expects an array of shape (N,S) where N is the number of waveforms to transcribe in parallel and S is the number of 16000Hz samples
Returns the transcribed text if a single waveform is provided, or an array of transcriptions if multiple are provided
"""
N_audio = len(waveforms)
log_spec = prep_audio(waveforms, model.batch_size)
encoded_audio = model.encoder.encode(Tensor(log_spec))
# print('encoded audio', np.sum(encoded_audio.numpy()))
start_tokens = [enc._special_tokens["<|startoftranscript|>"]]
if model.is_multilingual:
# TODO detect language
language_token = enc._special_tokens["<|startoftranscript|>"] + 1 + tuple(LANGUAGES.keys()).index("en")
start_tokens.append(language_token)
start_tokens.append(enc._special_tokens["<|transcribe|>"])
start_tokens.append(enc._special_tokens["<|notimestamps|>"])
transcription_start_index = len(start_tokens)
eot = enc._special_tokens["<|endoftext|>"]
tokens = np.tile(start_tokens, (log_spec.shape[0], 1))
pos = 0
for i in range(model.decoder.max_tokens_to_sample):
out = model.decoder(Tensor(tokens if i == 0 else tokens[:, -1:]), pos, encoded_audio)
next_tokens = out[:, -1].argmax(axis=-1).numpy().astype(np.int32)
next_tokens[tokens[:, -1] == eot] = eot
tokens = np.concatenate((tokens, next_tokens.reshape(-1, 1)), axis=1)
pos = tokens.shape[-1] - 1
if DEBUG >= 1: print(i, list(map(lambda tokens: enc.decode(tokens), tokens)))
if (tokens[:, -1] == eot).all():
break
transcriptions = []
for t in tokens:
eot_index = np.where(t == eot)[0]
eot_index = None if len(eot_index) == 0 else eot_index[0]
transcriptions.append(enc.decode(t[transcription_start_index:eot_index]).strip())
return transcriptions[:N_audio] if N_audio > 1 else transcriptions[0]
if __name__ == "__main__":
model, enc = init_whisper("small.en" if getenv("SMALL") else "tiny.en")
model, enc = init_whisper("small.en" if getenv("SMALL") else "tiny.en", batch_size=1)
if len(sys.argv) > 1:
transcribe_file(model, enc, sys.argv[1])
print(transcribe_file(model, enc, sys.argv[1]))
else:
# online
@@ -253,13 +328,13 @@ if __name__ == "__main__":
else: total = np.concatenate([total, waveform])
did_read = True
if did_read:
log_spec = prep_audio(total)
encoded_audio = model.encoder(Tensor(log_spec)).realize()
out = model.decoder(Tensor([lst]), encoded_audio).realize()
log_spec = prep_audio(total.reshape(1, -1), 1)
encoded_audio = model.encoder.encode(Tensor(log_spec))
# pass the previously inferred tokens as 'prefix' - https://github.com/openai/whisper/discussions/117#discussioncomment-3727051
out = model.decoder(Tensor([lst]), 0, encoded_audio, streaming=True).realize()
idx = int(out[0,-1].argmax().numpy().item())
lst.append(idx)
dec = enc.decode(lst)
print(dec) # DO NOT REMOVE PRINT. IT'S VERY IMPORTANT
if dec.endswith("<|endoftext|>"):
#total = total[:, 320*(len(lst)-1):]
lst.pop()
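
A side note on the JIT structure above: TinyJit replays captured kernels and expects the shapes it was captured with, which is why the decoder keeps separate jitted block lists for the multi-token start prompt and the later single-token steps, and binds the growing self-attention cache length to a symbolic Variable instead of letting a concrete shape change every step. A toy sketch of the shape constraint (the step function is illustrative, not part of the commit):

from tinygrad.jit import TinyJit
from tinygrad.tensor import Tensor

@TinyJit
def step(x: Tensor) -> Tensor:
  return (x * 2).realize()

step(Tensor.ones(2, 4))   # first calls run while TinyJit captures the kernels
step(Tensor.ones(2, 4))   # repeat calls with the same input shape reuse what was captured
# step(Tensor.ones(2, 1)) # a different shape would not match the captured kernels, hence
#                         # blocks_start_tok vs blocks_after_start_tok and Variable(...).bind(pos)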


@@ -0,0 +1,83 @@
import unittest
import torch
import tqdm
import torchaudio
import pathlib
import jiwer
import os
import numpy as np
from whisper.normalizers import EnglishTextNormalizer
from examples.whisper import init_whisper, transcribe_waveform
class TestWhisperLibriSpeech(unittest.TestCase):
# reference WERs determined by running https://github.com/openai/whisper/blob/main/notebooks/LibriSpeech.ipynb
# the values should be consistent with the paper D.1.1 https://cdn.openai.com/papers/whisper.pdf#page=22
# tinygrad WERs do not perfectly match due to what seem to be precision differences vs torch
def test_en_tiny(self):
run_evaluation("tiny.en", 0.056629001883239174, 0.05655609406528749)
def test_tiny(self):
run_evaluation("tiny", 0.0771121409407306, 0.07558413638335187)
def test_en_base(self):
run_evaluation("base.en", 0.041412520064205455, 0.04271408904897505)
def test_en_small(self):
run_evaluation("small.en", 0.03369011117172363, 0.030531615969223228)
def run_evaluation(model_name, tinygrad_expected_wer, reference_wer):
dataset = LibriSpeech()
batch_size=16
loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size)
model, enc = init_whisper(model_name, batch_size=batch_size)
hypotheses = []
references = []
for audio, texts in tqdm.tqdm(loader):
transcriptions = transcribe_waveform(model, enc, audio.numpy())
hypotheses.extend(transcriptions)
references.extend(texts)
normalizer = EnglishTextNormalizer()
normalized_hypotheses = [normalizer(text) for text in hypotheses]
normalized_references = [normalizer(text) for text in references]
wer = jiwer.wer(normalized_hypotheses, normalized_references)
np.testing.assert_almost_equal(wer, tinygrad_expected_wer)
print(f'tinygrad WER {wer} vs reference WER {reference_wer}')
del model, enc
class LibriSpeech(torch.utils.data.Dataset):
def __init__(self):
dir = pathlib.Path(__file__).parent.parent.parent / "extra" / "datasets" / "librispeech"
if not os.path.exists(dir):
os.makedirs(dir)
self.dataset = torchaudio.datasets.LIBRISPEECH(
root=dir,
url="test-clean",
download=True,
)
def __len__(self):
return len(self.dataset)
def __getitem__(self, item):
audio, sample_rate, text, _, _, _ = self.dataset[item]
assert sample_rate == 16000
return pad_or_trim_tensor(audio[0]), text
def pad_or_trim_tensor(tensor, target_len=480000):
curr_len = len(tensor)
if curr_len == target_len:
return tensor
elif curr_len < target_len:
return torch.cat((tensor, torch.zeros(target_len - curr_len)))
else:
return tensor[:target_len]
if __name__ == '__main__':
unittest.main()
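
For reference, a small sketch of the normalize-then-score step the test performs (example strings made up; EnglishTextNormalizer lowercases and strips punctuation, which is why both hypotheses and references are normalized before computing WER):

import jiwer
from whisper.normalizers import EnglishTextNormalizer

normalizer = EnglishTextNormalizer()
ref = normalizer("COULD YOU PLEASE LET ME OUT OF THE BOX")   # LibriSpeech transcripts are uppercase
hyp = normalizer("Could you please let me out of the box?")  # model output carries casing/punctuation
print(jiwer.wer([ref], [hyp]))  # 0.0 once both sides normalize to the same text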


@@ -1,13 +1,13 @@
import unittest
import pathlib
from examples.whisper import init_whisper, load_file_waveform, transcribe_file, transcribe_waveform
from tinygrad.ops import Device
from examples.whisper import init_whisper, transcribe_file
@unittest.skipUnless(Device.DEFAULT == "METAL", "Some non-metal backends spend too long trying to allocate a 20GB array")
class TestWhisper(unittest.TestCase):
@classmethod
def setUpClass(cls):
model, enc = init_whisper("tiny.en")
model, enc = init_whisper("tiny.en", batch_size=2)
cls.model = model
cls.enc = enc
@@ -22,4 +22,15 @@ class TestWhisper(unittest.TestCase):
# We use the WAVE type because it's easier to decode in CI test environments
filename = str(pathlib.Path(__file__).parent / "whisper/test.wav")
transcription = transcribe_file(self.model, self.enc, filename)
self.assertEqual("<|startoftranscript|><|notimestamps|> Could you please let me out of the box?<|endoftext|>", transcription)
self.assertEqual("Could you please let me out of the box?", transcription)
def test_transcribe_batch(self):
file1 = str(pathlib.Path(__file__).parent / "whisper/test.wav")
file2 = str(pathlib.Path(__file__).parent / "whisper/test2.wav")
waveforms = [load_file_waveform(file1), load_file_waveform(file2)]
transcriptions = transcribe_waveform(self.model, self.enc, waveforms)
self.assertEqual(2, len(transcriptions))
self.assertEqual("Could you please let me out of the box?", transcriptions[0])
self.assertEqual("a slightly longer audio file so that we can test batch transcriptions of varying length.", transcriptions[1])

Binary file not shown.


@@ -7,8 +7,7 @@ import pytest
pytestmark = pytest.mark.webgpu
# NOTE: METAL fails, might be platform and optimization options dependent.
@unittest.skipUnless(Device.DEFAULT not in ["METAL", "WEBGPU"], f"no JIT on {Device.DEFAULT}")
@unittest.skipUnless(Device.DEFAULT != "WEBGPU", f"no JIT on {Device.DEFAULT}")
class TestJit(unittest.TestCase):
def test_simple_jit(self):
@TinyJit
@@ -190,5 +189,56 @@ class TestJit(unittest.TestCase):
[0., 2., 3., 1., 0.]]
np.testing.assert_allclose(want, Y)
def test_jitted_read_assign(self):
class Cache:
def __init__(self):
self.good_cache = Tensor.zeros(1)
self.bad_cache = Tensor.zeros(1)
self.good_jitted = TinyJit(self.good)
self.bad_jitted = TinyJit(self.bad)
def good(self, y, cache_v=None):
if cache_v is not None:
self.good_cache.assign(cache_v+1-1).realize()
return (self.good_cache + y).realize() # need + y to provide inputs to JIT
def bad(self, y, cache_v=None):
if cache_v is not None:
self.bad_cache.assign(cache_v).realize()
return (self.bad_cache + y).realize()
cache = Cache()
np.testing.assert_equal([0], cache.good_cache.numpy())
np.testing.assert_equal([0], cache.bad_cache.numpy())
zero = Tensor([0])
one = Tensor([1])
two = Tensor([2])
# save [1] in the caches
cache.good(zero, one)
cache.bad(zero, one)
np.testing.assert_equal([1], cache.good_cache.numpy())
np.testing.assert_equal([1], cache.bad_cache.numpy())
for i in range(5):
cache.good_jitted(zero)
cache.bad_jitted(zero)
# verify the jitted calls read 1 from the cache
np.testing.assert_equal([1], cache.good_jitted(zero).numpy())
np.testing.assert_equal([1], cache.bad_jitted(zero).numpy())
# save [2] in the caches
cache.good(zero, two)
cache.bad(zero, two)
np.testing.assert_equal([2], cache.good_cache.numpy())
np.testing.assert_equal([2], cache.bad_cache.numpy())
# verify the jitted calls read 2 from the cache
np.testing.assert_equal([2], cache.good_jitted(zero).numpy())
# but the bad_jitted doesn't!
np.testing.assert_equal([1], cache.bad_jitted(zero).numpy())
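# a hedged explanation of the difference (not part of the original test comments): the JIT
# captures the cache tensors' underlying buffers when it records kernels. bad() assigns an
# already-computed tensor, which can rebind the cache to a fresh buffer the captured kernels
# never read, so bad_jitted keeps returning the stale value; good()'s +1-1 forces an
# elementwise kernel that writes the update into the existing cache buffer.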
if __name__ == '__main__':
unittest.main()