import unittest
import torch
import tqdm
import torchaudio
import pathlib
import jiwer
import os
import numpy as np
from whisper.normalizers import EnglishTextNormalizer
from examples.whisper import init_whisper, transcribe_waveform

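# Evaluates tinygrad's Whisper port on LibriSpeech test-clean: transcribe every utterance,
# normalize both transcripts, and assert the word error rate (WER) matches a pinned value.
# The dataset (and, presumably, the model weights) are downloaded on first run.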
class TestWhisperLibriSpeech(unittest.TestCase):
  # reference WERs determined by running https://github.com/openai/whisper/blob/main/notebooks/LibriSpeech.ipynb
  # the values should be consistent with section D.1.1 of the paper: https://cdn.openai.com/papers/whisper.pdf#page=22
  # tinygrad WERs do not match these exactly, due to what appear to be precision differences vs torch
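  # for intuition, WER = (substitutions + deletions + insertions) / words in the reference, e.g.
  # jiwer.wer("the quick brown fox", "the quick brown dog") == 0.25 (one substitution over four words)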
  def test_en_tiny(self):
    run_evaluation("tiny.en", 0.056629001883239174, 0.05655609406528749)

  def test_tiny(self):
    run_evaluation("tiny", 0.0771121409407306, 0.07558413638335187)

  def test_en_base(self):
    run_evaluation("base.en", 0.041412520064205455, 0.04271408904897505)

  def test_en_small(self):
    run_evaluation("small.en", 0.03369011117172363, 0.030531615969223228)

def run_evaluation(model_name, tinygrad_expected_wer, reference_wer):
  dataset = LibriSpeech()
  batch_size = 16
  loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size)

  model, enc = init_whisper(model_name, batch_size=batch_size)

  hypotheses = []
  references = []

  for audio, texts in tqdm.tqdm(loader):
    transcriptions = transcribe_waveform(model, enc, audio.numpy(), truncate=True)
    hypotheses.extend(transcriptions)
    references.extend(texts)

  # normalize both sides so casing/punctuation differences don't count as word errors
  normalizer = EnglishTextNormalizer()
  normalized_hypotheses = [normalizer(text) for text in hypotheses]
  normalized_references = [normalizer(text) for text in references]
  wer = jiwer.wer(normalized_hypotheses, normalized_references)

  np.testing.assert_almost_equal(wer, tinygrad_expected_wer)
  print(f'tinygrad WER {wer} vs reference WER {reference_wer}')
  del model, enc

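# a minimal standalone sketch of the transcription API used above (assumptions: batch_size=1 works,
# and silence in just yields an uninteresting transcript while exercising the same code path):
#   model, enc = init_whisper("tiny.en", batch_size=1)
#   waveforms = np.zeros((1, 480000), dtype=np.float32)  # (batch, samples) at 16 kHz, as in the loop above
#   print(transcribe_waveform(model, enc, waveforms, truncate=True))
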
class LibriSpeech(torch.utils.data.Dataset):
  def __init__(self):
    folder = pathlib.Path(__file__).parent.parent.parent / "extra" / "datasets" / "librispeech"
    if not os.path.exists(folder):
      os.makedirs(folder)
    self.dataset = torchaudio.datasets.LIBRISPEECH(
      root=folder,
      url="test-clean",
      download=True,
    )

  def __len__(self):
    return len(self.dataset)

  def __getitem__(self, item):
    # torchaudio returns (waveform, sample_rate, transcript, speaker_id, chapter_id, utterance_id)
    audio, sample_rate, text, _, _, _ = self.dataset[item]
    assert sample_rate == 16000
    return pad_or_trim_tensor(audio[0]), text

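# Whisper consumes fixed 30-second windows at 16 kHz, so clips are padded/trimmed to 16000 * 30 = 480000 samples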
def pad_or_trim_tensor(tensor, target_len=480000):
  curr_len = len(tensor)
  if curr_len == target_len:
    return tensor
  elif curr_len < target_len:
    return torch.cat((tensor, torch.zeros(target_len - curr_len)))
  else:
    return tensor[:target_len]

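# e.g. pad_or_trim_tensor(torch.ones(3), target_len=5) -> tensor([1., 1., 1., 0., 0.])
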
if __name__ == '__main__':
  unittest.main()