mirror of https://github.com/commaai/tinygrad.git
whisper: support batch inference, add librispeech WER test (#2074)
* whisper: support batch inference, add librispeech WER test, add kv caching and JIT
* remove JIT_SUPPORTED_DEVICE

Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
This commit is contained in:
parent 3baaf298d6
commit 8235da11dd
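For context, a minimal usage sketch of the batched API this commit introduces (the file names below are placeholders, not files from the repo):

  # hypothetical example: transcribe two audio files in one batch
  from examples.whisper import init_whisper, load_file_waveform, transcribe_waveform

  model, enc = init_whisper("tiny.en", batch_size=2)
  waveforms = [load_file_waveform("first.wav"), load_file_waveform("second.wav")]
  # transcribe_waveform returns a list of transcriptions when more than one waveform is passed
  print(transcribe_waveform(model, enc, waveforms))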
@@ -5,59 +5,80 @@ import pathlib
import base64
import multiprocessing
import numpy as np
from typing import Optional
from typing import Optional, Union, Literal
from extra.utils import download_file
from tinygrad.jit import TinyJit
from tinygrad.nn.state import torch_load, load_state_dict
from tinygrad.helpers import getenv
from tinygrad.helpers import getenv, DEBUG, CI
import tinygrad.nn as nn
from tinygrad.shape.symbolic import Variable
from tinygrad.tensor import Tensor
import itertools
import librosa

# TODO: you have written this fifteen times
class MultiHeadAttention:
  def __init__(self, n_state, n_head):
  def __init__(self, n_state, n_head, kv_caching: Literal['cross', 'self']=None, max_self_attn_cache_len=None):
    self.n_head = n_head
    self.query = nn.Linear(n_state, n_state)
    self.key = nn.Linear(n_state, n_state, bias=False)
    self.value = nn.Linear(n_state, n_state)
    self.out = nn.Linear(n_state, n_state)

  def __call__(self, x:Tensor, xa:Optional[Tensor]=None, mask:Optional[Tensor]=None):
    self.kv_caching = kv_caching
    self.max_self_attn_cache_len = max_self_attn_cache_len

  def __call__(self, x:Tensor, xa:Optional[Tensor]=None, mask:Optional[Tensor]=None, len: Union[Variable,int]=None):
    if self.kv_caching == 'cross':
      if xa is not None:
        k, v = self.key(xa), self.value(xa)
        if not hasattr(self, 'cache_k'):
          self.cache_k, self.cache_v = k, v
        else:
          # see test_jitted_read_assign in test_jit.py
          self.cache_k.assign(k+1-1).realize()
          self.cache_v.assign(v+1-1).realize()
      else:
        k, v = self.cache_k, self.cache_v
    else:
      k, v = self.key(x), self.value(x)
      if self.kv_caching == 'self':
        if not hasattr(self, 'cache_k'):
          self.cache_k = Tensor.zeros(x.shape[0], self.max_self_attn_cache_len, x.shape[2])
          self.cache_v = Tensor.zeros(x.shape[0], self.max_self_attn_cache_len, x.shape[2])
        k = self.cache_k.shrink((None, (0, len), None)).cat(k, dim=1)
        v = self.cache_v.shrink((None, (0, len), None)).cat(v, dim=1)
        padding = self.max_self_attn_cache_len-len-x.shape[1]
        self.cache_k.assign(k.pad((None, (0, padding), None)).contiguous()).realize()
        self.cache_v.assign(v.pad((None, (0, padding), None)).contiguous()).realize()

    q = self.query(x)
    k = self.key(xa or x)
    v = self.value(xa or x)
    wv, qk = self.qkv_attention(q, k, v, mask)
    # NOTE: we aren't returning qk
    n_ctx = q.shape[1]
    assert(q.shape[-1] == k.shape[-1] == v.shape[-1])
    head_dim = q.shape[-1] // self.n_head
    q = q.reshape(*q.shape[:2], self.n_head, head_dim).permute(0, 2, 1, 3)
    k = k.reshape(*k.shape[:2], self.n_head, head_dim).permute(0, 2, 1, 3)
    v = v.reshape(*v.shape[:2], self.n_head, head_dim).permute(0, 2, 1, 3)
    attn = Tensor.scaled_dot_product_attention(q, k, v, mask[:n_ctx,:n_ctx] if mask is not None else None)
    wv = attn.permute(0, 2, 1, 3).flatten(start_dim=2)
    return self.out(wv)

  def qkv_attention(self, q, k, v, mask=None):
    n_batch, n_ctx, n_state = q.shape
    scale = (n_state // self.n_head) ** -0.25
    q = q.reshape(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale
    k = k.reshape(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scale
    v = v.reshape(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
    qk = q @ k
    if mask is not None: qk = qk + mask[:n_ctx, :n_ctx]
    w = qk.softmax(-1)
    return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach()

class ResidualAttentionBlock:
  def __init__(self, n_state, n_head, cross_attention=False):
    self.attn = MultiHeadAttention(n_state, n_head)
  def __init__(self, n_state, n_head, is_decoder_block=False, max_self_attn_cache_len=None):
    self.attn = MultiHeadAttention(n_state, n_head, kv_caching='self' if is_decoder_block else None, max_self_attn_cache_len=max_self_attn_cache_len)
    self.attn_ln = nn.LayerNorm(n_state)

    self.cross_attn = MultiHeadAttention(n_state, n_head) if cross_attention else None
    self.cross_attn_ln = nn.LayerNorm(n_state) if cross_attention else None
    self.cross_attn = MultiHeadAttention(n_state, n_head, kv_caching='cross') if is_decoder_block else None
    self.cross_attn_ln = nn.LayerNorm(n_state) if is_decoder_block else None

    self.mlp = [nn.Linear(n_state, n_state*4), Tensor.gelu, nn.Linear(n_state*4, n_state)]
    self.mlp_ln = nn.LayerNorm(n_state)

  def __call__(self, x, xa=None, mask=None):
    x = x + self.attn(self.attn_ln(x), mask=mask)
  def __call__(self, x, xa=None, mask=None, len: Union[Variable, int]=None):
    x = x + self.attn(self.attn_ln(x), mask=mask, len=len)
    if self.cross_attn: x = x + self.cross_attn(self.cross_attn_ln(x), xa)
    x = x + self.mlp_ln(x).sequential(self.mlp)
    return x
    return x.realize()

class AudioEncoder:
  def __init__(self, n_mels, n_audio_ctx, n_audio_state, n_audio_head, n_audio_layer, **_):
@@ -66,6 +87,7 @@ class AudioEncoder:
    self.blocks = [ResidualAttentionBlock(n_audio_state, n_audio_head) for _ in range(n_audio_layer)]
    self.ln_post = nn.LayerNorm(n_audio_state)
    self.positional_embedding = Tensor.empty(n_audio_ctx, n_audio_state)
    self.encode = TinyJit(self.__call__)

  def __call__(self, x):
    x = self.conv1(x).gelu()
@@ -74,61 +96,82 @@ class AudioEncoder:
    x = x + self.positional_embedding[:x.shape[1]]
    x = x.sequential(self.blocks)
    x = self.ln_post(x)
    return x
    return x.realize()

class TextDecoder:
  def __init__(self, n_vocab, n_text_ctx, n_text_state, n_text_head, n_text_layer, **_):
    self.max_tokens_to_sample = n_text_ctx // 2
    self.max_self_attn_cache_len = self.max_tokens_to_sample * 2 # double the size as an extra buffer for prefix/start tokens

    self.token_embedding = nn.Embedding(n_vocab, n_text_state)
    self.positional_embedding = Tensor.empty(n_text_ctx, n_text_state)
    self.blocks = [ResidualAttentionBlock(n_text_state, n_text_head, cross_attention=True) for _ in range(n_text_layer)]
    self.blocks = [ResidualAttentionBlock(n_text_state, n_text_head, is_decoder_block=True, max_self_attn_cache_len=self.max_self_attn_cache_len) for _ in range(n_text_layer)]
    self.ln = nn.LayerNorm(n_text_state)
    #mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)
    self.mask = Tensor.full((n_text_ctx, n_text_ctx), -np.inf).triu(1).realize()
    self.blocks_start_tok = [TinyJit(block.__call__) for block in self.blocks]
    self.blocks_after_start_tok = [TinyJit(block.__call__) for block in self.blocks]
    self.start_output_tok = TinyJit(self.output_tok)
    self.after_start_output_tok = TinyJit(self.output_tok)

  def __call__(self, x, xa):
    offset = 0
    x = self.token_embedding(x) + self.positional_embedding[offset : offset + x.shape[-1]]
  def __call__(self, x: Tensor, pos: int, encoded_audio: Tensor, streaming=False):
    seqlen = x.shape[-1]
    x = self.token_embedding(x) + self.positional_embedding[pos:pos+seqlen]
    if pos == 0:
      for block in (self.blocks if streaming else self.blocks_start_tok):
        x = block(x, xa=encoded_audio, mask=self.mask, len=0) # pass xa for cross attn kv caching
      return self.output_tok(x) if streaming else self.start_output_tok(x)
    else:
      for block in self.blocks_after_start_tok:
        len_v = Variable("self_attn_cache_len", 1, self.max_self_attn_cache_len).bind(pos)
        x = block(x, mask=self.mask, len=len_v)
      return self.after_start_output_tok(x)

    seqlen, start_pos = x.shape[1], 0

    mask = np.full((1, 1, seqlen, start_pos + seqlen), float("-inf"), dtype=np.float32)
    mask = np.triu(mask, k=start_pos + 1) # TODO: this is hard to do in tinygrad
    mask = Tensor(mask)

    for block in self.blocks: x = block(x, xa, mask)
    x = self.ln(x)
    return x @ self.token_embedding.weight.T
  def output_tok(self, x):
    return (self.ln(x) @ self.token_embedding.weight.T).realize()

class Whisper:
  def __init__(self, dims):
  def __init__(self, dims, batch_size=1):
    self.encoder = AudioEncoder(**dims)
    self.decoder = TextDecoder(**dims)
    self.is_multilingual = dims["n_vocab"] == 51865
    self.batch_size = batch_size

  def __call__(self, mel:Tensor, tokens:Tensor):
    return self.decoder(tokens, self.encoder(mel))

RATE = 16000
CHUNK = 1600
RECORD_SECONDS = 10

def prep_audio(waveform) -> Tensor:
def prep_audio(waveform, batch_size) -> np.ndarray:
  assert waveform is not None

  def pad_or_trim(arr, target_len=480000):
    curr_len = len(arr)
    if curr_len == target_len:
      return arr
    elif curr_len < target_len:
      return np.pad(arr, (0, target_len - curr_len), 'constant')
    else:
      return arr[:target_len]

  waveform = np.array(list(map(pad_or_trim, waveform)))
  assert waveform.shape[0] <= batch_size
  # pad the waveform to match the model's batch_size to avoid JIT shape mismatch errors.
  # if operations like conv in the AudioEncoder could support symbolic shapes, then we wouldn't need to do this here
  if waveform.shape[0] < batch_size:
    waveform = np.pad(waveform, pad_width=((0, batch_size - waveform.shape[0]), (0, 0)))

  N_FFT = 400
  HOP_LENGTH = 160
  N_MELS = 80
  assert waveform is not None
  waveform = waveform.reshape(1, -1)

  stft = librosa.stft(waveform, n_fft=N_FFT, hop_length=HOP_LENGTH, window='hann', dtype=np.float32)
  magnitudes = stft[..., :-1] ** 2
  stft = librosa.stft(waveform, n_fft=N_FFT, hop_length=HOP_LENGTH, window='hann', dtype=np.csingle)
  magnitudes = np.absolute(stft[..., :-1]) ** 2
  mel_spec = librosa.filters.mel(sr=RATE, n_fft=N_FFT, n_mels=N_MELS) @ magnitudes
  log_spec = np.log10(np.clip(mel_spec, 1e-10, mel_spec.max() + 1e8))

  log_spec = np.log10(np.clip(mel_spec, 1e-10, None))
  log_spec = np.maximum(log_spec, log_spec.max() - 8.0)
  log_spec = (log_spec + 4.0) / 4.0

  # https://github.com/openai/whisper/blob/b38a1f20f4b23f3f3099af2c3e0ca95627276ddf/whisper/audio.py#L19
  n_frames = log_spec.shape[2]
  if n_frames < 3000:
    log_spec = np.pad(log_spec, ((0, 0), (0, 0), (0, 3000 - n_frames)))

  #print(waveform.shape, log_spec.shape)
  return log_spec

LANGUAGES = {
@@ -145,9 +188,11 @@ LANGUAGES = {
}

BASE = pathlib.Path(__file__).parents[1] / "weights"
def get_encoding(n_vocab_in):
  download_file("https://raw.githubusercontent.com/openai/whisper/main/whisper/assets/gpt2.tiktoken", BASE / "gpt2.tiktoken")
  ranks = {base64.b64decode(token): int(rank) for token, rank in (line.split() for line in open(BASE / "gpt2.tiktoken") if line)}
def get_encoding(encoding_name):
  filename = encoding_name + ".tiktoken"
  download_file("https://raw.githubusercontent.com/openai/whisper/main/whisper/assets/" + filename, BASE / filename)
  with open(BASE / filename) as f:
    ranks = {base64.b64decode(token): int(rank) for token, rank in (line.split() for line in f if line)}
  n_vocab = len(ranks)
  specials = [
    "<|endoftext|>",
@@ -163,10 +208,9 @@ def get_encoding(n_vocab_in):
  ]
  special_tokens = dict(zip(specials, itertools.count(n_vocab)))
  n_vocab += len(specials)
  assert n_vocab == n_vocab_in
  import tiktoken
  return tiktoken.Encoding(
    name="bob",
    name=encoding_name,
    explicit_n_vocab=n_vocab,
    pat_str=r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""",
    mergeable_ranks=ranks,
@@ -202,39 +246,70 @@ MODEL_URLS = {
  "large-v2": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
  "large": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
}
def init_whisper(model_name="tiny.en"):
def init_whisper(model_name="tiny.en", batch_size=1):
  assert MODEL_URLS[model_name] is not None

  filename = BASE / "whisper-{}.pt".format(model_name)
  download_file(MODEL_URLS[model_name], filename)
  state = torch_load(filename)
  model = Whisper(state['dims'])
  load_state_dict(model, state['model_state_dict'])

  enc = get_encoding(state['dims']['n_vocab'])
  model = Whisper(state['dims'], batch_size)
  load_state_dict(model, state['model_state_dict'], strict=False)
  enc = get_encoding("multilingual" if model.is_multilingual else "gpt2")
  return model, enc

def transcribe_file(model, enc, filename):
  waveform, sample_rate = librosa.load(filename, sr=RATE)
  log_spec = prep_audio(waveform)
  lst = [enc._special_tokens["<|startoftranscript|>"], enc._special_tokens["<|notimestamps|>"]]
  dat = model.encoder(Tensor(log_spec)).realize()
def load_file_waveform(filename):
  waveform, _ = librosa.load(filename, sr=RATE)
  return waveform

  for i in range(50):
    out = model.decoder(Tensor([lst]), dat).realize()
    idx = int(out[0,-1].argmax().numpy().item())
    lst.append(idx)
    transcription = enc.decode(lst)
    print(transcription)
    if lst[-1] == enc._special_tokens["<|endoftext|>"]:
      return transcription
def transcribe_file(model, enc, filename):
  return transcribe_waveform(model, enc, [load_file_waveform(filename)])

def transcribe_waveform(model, enc, waveforms):
  """
  Expects an array of shape (N,S) where N is the number waveforms to transcribe in parallel and S is number of 16000Hz samples
  Returns the transcribed text if a single waveform is provided, or an array of transcriptions if multiple are provided
  """
  N_audio = len(waveforms)
  log_spec = prep_audio(waveforms, model.batch_size)
  encoded_audio = model.encoder.encode(Tensor(log_spec))
  # print('encoded audio', np.sum(encoded_audio.numpy()))

  start_tokens = [enc._special_tokens["<|startoftranscript|>"]]
  if model.is_multilingual:
    # TODO detect language
    language_token = enc._special_tokens["<|startoftranscript|>"] + 1 + tuple(LANGUAGES.keys()).index("en")
    start_tokens.append(language_token)
    start_tokens.append(enc._special_tokens["<|transcribe|>"])
  start_tokens.append(enc._special_tokens["<|notimestamps|>"])

  transcription_start_index = len(start_tokens)
  eot = enc._special_tokens["<|endoftext|>"]
  tokens = np.tile(start_tokens, (log_spec.shape[0], 1))

  pos = 0
  for i in range(model.decoder.max_tokens_to_sample):
    out = model.decoder(Tensor(tokens if i == 0 else tokens[:, -1:]), pos, encoded_audio)
    next_tokens = out[:, -1].argmax(axis=-1).numpy().astype(np.int32)
    next_tokens[tokens[:, -1] == eot] = eot
    tokens = np.concatenate((tokens, next_tokens.reshape(-1, 1)), axis=1)
    pos = tokens.shape[-1] - 1
    if DEBUG >= 1: print(i, list(map(lambda tokens: enc.decode(tokens), tokens)))
    if (tokens[:, -1] == eot).all():
      break

  transcriptions = []
  for t in tokens:
    eot_index = np.where(t == eot)[0]
    eot_index = None if len(eot_index) == 0 else eot_index[0]
    transcriptions.append(enc.decode(t[transcription_start_index:eot_index]).strip())
  return transcriptions[:N_audio] if N_audio > 1 else transcriptions[0]


if __name__ == "__main__":
  model, enc = init_whisper("small.en" if getenv("SMALL") else "tiny.en")
  model, enc = init_whisper("small.en" if getenv("SMALL") else "tiny.en", batch_size=1)

  if len(sys.argv) > 1:
    transcribe_file(model, enc, sys.argv[1])
    print(transcribe_file(model, enc, sys.argv[1]))
  else:
    # online
@@ -253,13 +328,13 @@ if __name__ == "__main__":
        else: total = np.concatenate([total, waveform])
        did_read = True
      if did_read:
        log_spec = prep_audio(total)
        encoded_audio = model.encoder(Tensor(log_spec)).realize()
      out = model.decoder(Tensor([lst]), encoded_audio).realize()
        log_spec = prep_audio(total.reshape(1, -1), 1)
        encoded_audio = model.encoder.encode(Tensor(log_spec))
      # pass the previously inferred tokens as 'prefix' - https://github.com/openai/whisper/discussions/117#discussioncomment-3727051
      out = model.decoder(Tensor([lst]), 0, encoded_audio, streaming=True).realize()
      idx = int(out[0,-1].argmax().numpy().item())
      lst.append(idx)
      dec = enc.decode(lst)
      print(dec) # DO NOT REMOVE PRINT. IT'S VERY IMPORTANT
      if dec.endswith("<|endoftext|>"):
        #total = total[:, 320*(len(lst)-1):]
        lst.pop()
@@ -0,0 +1,83 @@
import unittest
import torch
import tqdm
import torchaudio
import pathlib
import jiwer
import os
import numpy as np
from whisper.normalizers import EnglishTextNormalizer
from examples.whisper import init_whisper, transcribe_waveform

class TestWhisperLibriSpeech(unittest.TestCase):
  # reference WERs determined by running https://github.com/openai/whisper/blob/main/notebooks/LibriSpeech.ipynb
  # the values should be consistent with the paper D.1.1 https://cdn.openai.com/papers/whisper.pdf#page=22
  # tinygrad WERs do not perfectly match due to what seem to be precision differences vs torch
  def test_en_tiny(self):
    run_evaluation("tiny.en", 0.056629001883239174, 0.05655609406528749)

  def test_tiny(self):
    run_evaluation("tiny", 0.0771121409407306, 0.07558413638335187)

  def test_en_base(self):
    run_evaluation("base.en", 0.041412520064205455, 0.04271408904897505)

  def test_en_small(self):
    run_evaluation("small.en", 0.03369011117172363, 0.030531615969223228)

def run_evaluation(model_name, tinygrad_expected_wer, reference_wer):
  dataset = LibriSpeech()
  batch_size=16
  loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size)

  model, enc = init_whisper(model_name, batch_size=batch_size)

  hypotheses = []
  references = []

  for audio, texts in tqdm.tqdm(loader):
    transcriptions = transcribe_waveform(model, enc, audio.numpy())
    hypotheses.extend(transcriptions)
    references.extend(texts)

  normalizer = EnglishTextNormalizer()
  normalized_hypotheses = [normalizer(text) for text in hypotheses]
  normalized_references = [normalizer(text) for text in references]
  wer = jiwer.wer(normalized_hypotheses, normalized_references)

  np.testing.assert_almost_equal(wer, tinygrad_expected_wer)
  print(f'tinygrad WER {wer} vs reference WER {reference_wer}')
  del model, enc

class LibriSpeech(torch.utils.data.Dataset):
  def __init__(self):
    dir = pathlib.Path(__file__).parent.parent.parent / "extra" / "datasets" / "librispeech"
    if not os.path.exists(dir):
      os.makedirs(dir)

    self.dataset = torchaudio.datasets.LIBRISPEECH(
      root=dir,
      url="test-clean",
      download=True,
    )

  def __len__(self):
    return len(self.dataset)

  def __getitem__(self, item):
    audio, sample_rate, text, _, _, _ = self.dataset[item]
    assert sample_rate == 16000
    return pad_or_trim_tensor(audio[0]), text

def pad_or_trim_tensor(tensor, target_len=480000):
  curr_len = len(tensor)
  if curr_len == target_len:
    return tensor
  elif curr_len < target_len:
    return torch.cat((tensor, torch.zeros(target_len - curr_len)))
  else:
    return tensor[:target_len]


if __name__ == '__main__':
  unittest.main()
@@ -1,13 +1,13 @@
import unittest
import pathlib
from examples.whisper import init_whisper, load_file_waveform, transcribe_file, transcribe_waveform
from tinygrad.ops import Device
from examples.whisper import init_whisper, transcribe_file

@unittest.skipUnless(Device.DEFAULT == "METAL", "Some non-metal backends spend too long trying to allocate a 20GB array")
class TestWhisper(unittest.TestCase):
  @classmethod
  def setUpClass(cls):
    model, enc = init_whisper("tiny.en")
    model, enc = init_whisper("tiny.en", batch_size=2)
    cls.model = model
    cls.enc = enc
@@ -22,4 +22,15 @@ class TestWhisper(unittest.TestCase):
    # We use the WAVE type because it's easier to decode in CI test environments
    filename = str(pathlib.Path(__file__).parent / "whisper/test.wav")
    transcription = transcribe_file(self.model, self.enc, filename)
    self.assertEqual("<|startoftranscript|><|notimestamps|> Could you please let me out of the box?<|endoftext|>", transcription)
    self.assertEqual("Could you please let me out of the box?", transcription)

  def test_transcribe_batch(self):
    file1 = str(pathlib.Path(__file__).parent / "whisper/test.wav")
    file2 = str(pathlib.Path(__file__).parent / "whisper/test2.wav")

    waveforms = [load_file_waveform(file1), load_file_waveform(file2)]

    transcriptions = transcribe_waveform(self.model, self.enc, waveforms)
    self.assertEqual(2, len(transcriptions))
    self.assertEqual("Could you please let me out of the box?", transcriptions[0])
    self.assertEqual("a slightly longer audio file so that we can test batch transcriptions of varying length.", transcriptions[1])
Binary file not shown.
@@ -7,8 +7,7 @@ import pytest

pytestmark = pytest.mark.webgpu

# NOTE: METAL fails, might be platform and optimization options dependent.
@unittest.skipUnless(Device.DEFAULT not in ["METAL", "WEBGPU"], f"no JIT on {Device.DEFAULT}")
@unittest.skipUnless(Device.DEFAULT != "WEBGPU", f"no JIT on {Device.DEFAULT}")
class TestJit(unittest.TestCase):
  def test_simple_jit(self):
    @TinyJit
@@ -190,5 +189,56 @@ class TestJit(unittest.TestCase):
    [0., 2., 3., 1., 0.]]
    np.testing.assert_allclose(want, Y)

  def test_jitted_read_assign(self):
    class Cache:
      def __init__(self):
        self.good_cache = Tensor.zeros(1)
        self.bad_cache = Tensor.zeros(1)
        self.good_jitted = TinyJit(self.good)
        self.bad_jitted = TinyJit(self.bad)

      def good(self, y, cache_v=None):
        if cache_v is not None:
          self.good_cache.assign(cache_v+1-1).realize()
        return (self.good_cache + y).realize() # need + y to provide inputs to JIT

      def bad(self, y, cache_v=None):
        if cache_v is not None:
          self.bad_cache.assign(cache_v).realize()
        return (self.bad_cache + y).realize()

    cache = Cache()
    np.testing.assert_equal([0], cache.good_cache.numpy())
    np.testing.assert_equal([0], cache.bad_cache.numpy())

    zero = Tensor([0])
    one = Tensor([1])
    two = Tensor([2])

    # save [1] in the caches
    cache.good(zero, one)
    cache.bad(zero, one)
    np.testing.assert_equal([1], cache.good_cache.numpy())
    np.testing.assert_equal([1], cache.bad_cache.numpy())

    for i in range(5):
      cache.good_jitted(zero)
      cache.bad_jitted(zero)

    # verify the jitted calls read 1 from the cache
    np.testing.assert_equal([1], cache.good_jitted(zero).numpy())
    np.testing.assert_equal([1], cache.bad_jitted(zero).numpy())

    # save [2] in the caches
    cache.good(zero, two)
    cache.bad(zero, two)
    np.testing.assert_equal([2], cache.good_cache)
    np.testing.assert_equal([2], cache.bad_cache)

    # verify the jitted calls read 2 from the cache
    np.testing.assert_equal([2], cache.good_jitted(zero).numpy())
    # but the bad_jitted doesn't!
    np.testing.assert_equal([1], cache.bad_jitted(zero).numpy())

if __name__ == '__main__':
  unittest.main()
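A closing note on the caching scheme above: the decoder's self-attention cache is read with a length bound to a symbolic Variable, so TinyJit can reuse one compiled decode step for every position instead of specializing per sequence length, while the cross-attention cache relies on the `assign(k+1-1)` workaround exercised by test_jitted_read_assign. A minimal sketch of the symbolic binding (max_len and pos are illustrative values, not taken from the diff):

  # illustrative sketch of the symbolic cache length used in TextDecoder.__call__
  from tinygrad.shape.symbolic import Variable

  max_len = 448   # stands in for decoder.max_self_attn_cache_len
  pos = 17        # number of tokens already in the cache
  len_v = Variable("self_attn_cache_len", 1, max_len).bind(pos)
  # the jitted block then shrinks the cache to the first len_v entries before
  # concatenating the new keys/values, e.g. x = block(x, mask=mask, len=len_v)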