# tinygrad/test/external/external_test_whisper_libri...
# (file-viewer header: 84 lines, 2.7 KiB, Python)

import unittest
import torch
import tqdm
import torchaudio
import pathlib
import jiwer
import os
import numpy as np
from whisper.normalizers import EnglishTextNormalizer
from examples.whisper import init_whisper, transcribe_waveform
class TestWhisperLibriSpeech(unittest.TestCase):
    """Word-error-rate (WER) regression checks for several Whisper model sizes.

    Every test hands run_evaluation a model name plus two WER values: the one
    the tinygrad run is expected to reproduce and the reference one.
    """

    # Reference WERs were determined by running
    # https://github.com/openai/whisper/blob/main/notebooks/LibriSpeech.ipynb
    # and should be consistent with section D.1.1 of the paper:
    # https://cdn.openai.com/papers/whisper.pdf#page=22
    # The tinygrad WERs do not perfectly match them, due to what seem to be
    # precision differences vs torch.

    def test_en_tiny(self):
        run_evaluation(
            "tiny.en",
            tinygrad_expected_wer=0.056629001883239174,
            reference_wer=0.05655609406528749,
        )

    def test_tiny(self):
        run_evaluation(
            "tiny",
            tinygrad_expected_wer=0.0771121409407306,
            reference_wer=0.07558413638335187,
        )

    def test_en_base(self):
        run_evaluation(
            "base.en",
            tinygrad_expected_wer=0.041412520064205455,
            reference_wer=0.04271408904897505,
        )

    def test_en_small(self):
        run_evaluation(
            "small.en",
            tinygrad_expected_wer=0.03369011117172363,
            reference_wer=0.030531615969223228,
        )
def run_evaluation(model_name, tinygrad_expected_wer, reference_wer):
    """Transcribe the LibriSpeech dataset with a Whisper model and check its WER.

    Args:
        model_name: Model identifier forwarded to init_whisper (e.g. "tiny.en").
        tinygrad_expected_wer: WER the run has to reproduce; compared with
            np.testing.assert_almost_equal.
        reference_wer: Reference WER; it is only printed next to the measured
            value and never asserted on.

    Raises:
        AssertionError: If the measured WER is not almost equal to
            tinygrad_expected_wer.
    """
    samples = LibriSpeech()
    batch_size = 16
    batches = torch.utils.data.DataLoader(samples, batch_size=batch_size)
    model, enc = init_whisper(model_name, batch_size=batch_size)

    # Collect the model output and the ground-truth transcript of every batch.
    predicted, expected = [], []
    for audio, texts in tqdm.tqdm(batches):
        predicted += transcribe_waveform(model, enc, audio.numpy(), truncate=True)
        expected += texts

    # Both sides go through the same text normalizer before scoring.
    normalizer = EnglishTextNormalizer()
    clean_predicted = list(map(normalizer, predicted))
    clean_expected = list(map(normalizer, expected))

    # NOTE(review): jiwer documents wer(reference, hypothesis), yet the
    # hypotheses are passed first here. The pinned expected WERs were
    # presumably produced with this order, so it is kept - confirm before
    # swapping the arguments.
    wer = jiwer.wer(clean_predicted, clean_expected)

    np.testing.assert_almost_equal(wer, tinygrad_expected_wer)
    print(f'tinygrad WER {wer} vs reference WER {reference_wer}')
    del model, enc
class LibriSpeech(torch.utils.data.Dataset):
    """Dataset wrapper around torchaudio's LIBRISPEECH "test-clean" split.

    Items are (audio, text) pairs: the first channel of the recording after
    pad_or_trim_tensor, together with its transcript.
    """

    def __init__(self):
        """Create the local dataset folder if needed and open the dataset.

        The data lives in extra/datasets/librispeech, three directory levels
        above this file; torchaudio is asked to download it (download=True).
        """
        folder = pathlib.Path(__file__).parent.parent.parent / "extra" / "datasets" / "librispeech"
        # exist_ok=True replaces the former exists()/makedirs() pair, which
        # could fail if another process created the folder between the two calls.
        folder.mkdir(parents=True, exist_ok=True)
        self.dataset = torchaudio.datasets.LIBRISPEECH(
            root=folder,
            url="test-clean",
            download=True,
        )

    def __len__(self):
        """Return the number of items in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, item):
        """Return (audio, text) for index ``item``.

        Raises:
            AssertionError: If the recording's sample rate is not 16000.
        """
        audio, sample_rate, text, _, _, _ = self.dataset[item]
        assert sample_rate == 16000, f"expected a sample rate of 16000, got {sample_rate}"
        # Only the first row of the audio tensor (audio[0]) is used.
        return pad_or_trim_tensor(audio[0]), text
def pad_or_trim_tensor(tensor, target_len=480000):
    """Return ``tensor`` forced to exactly ``target_len`` elements.

    Shorter input is zero-padded at the end, longer input is cut off at
    ``target_len``, and input that already has the right length is returned
    as-is (the same object, not a copy).

    Args:
        tensor: 1-D torch tensor; its length is taken with ``len``.
        target_len: Desired length. The default of 480000 equals 30 seconds
            at a sample rate of 16000.

    Returns:
        A tensor of length ``target_len`` with the dtype and device of the
        input.
    """
    curr_len = len(tensor)
    if curr_len == target_len:
        return tensor
    if curr_len > target_len:
        return tensor[:target_len]
    # new_zeros inherits dtype and device from the input, so the padded result
    # has the same dtype/device as the trimmed and pass-through results
    # (a plain torch.zeros would always be default-dtype and on the CPU).
    padding = tensor.new_zeros(target_len - curr_len)
    return torch.cat((tensor, padding))
# Run the test suite when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()