whisper: support audio >30s (#2378)

* whisper: support audio >30s

* make prompt indexing consistent with reference repo

* fix online
mmmkkaaayy 2023-11-21 14:37:51 -08:00 committed by GitHub
parent 7220f5c9fc
commit 7f0cc4a4e8
3 changed files with 97 additions and 61 deletions

View File

@@ -5,7 +5,7 @@ import pathlib
import base64
import multiprocessing
import numpy as np
from typing import Optional, Union, Literal
from typing import Optional, Union, Literal, List
from extra.utils import download_file
from tinygrad.jit import TinyJit
from tinygrad.nn.state import torch_load, load_state_dict
@@ -101,7 +101,7 @@ class AudioEncoder:
class TextDecoder:
def __init__(self, n_vocab, n_text_ctx, n_text_state, n_text_head, n_text_layer, **_):
self.max_tokens_to_sample = n_text_ctx // 2
self.max_self_attn_cache_len = self.max_tokens_to_sample * 2 # double the size as an extra buffer for prefix/start tokens
self.max_self_attn_cache_len = self.max_tokens_to_sample * 2 + 5 # roughly prompt + start toks + max_tokens_to_sample
self.token_embedding = nn.Embedding(n_vocab, n_text_state)
self.positional_embedding = Tensor.empty(n_text_ctx, n_text_state)
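As a sanity check on the new cache bound, here is a back-of-the-envelope sketch; it assumes n_text_ctx = 448 (the text context length of the released Whisper checkpoints, not shown in this hunk):

# Hypothetical arithmetic, assuming n_text_ctx = 448 as in the released Whisper checkpoints.
n_text_ctx = 448
max_tokens_to_sample = n_text_ctx // 2                   # 224
max_self_attn_cache_len = max_tokens_to_sample * 2 + 5   # 453
# Worst case per segment:
#   <|startofprev|> plus up to max_tokens_to_sample - 1 prompt tokens -> 224
#   start tokens (sot, language, task, notimestamps)                  -> at most 4
#   newly sampled tokens                                              -> at most 224
# 224 + 4 + 224 = 452 <= 453, so the self-attention cache never overflows.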
@@ -113,6 +113,7 @@ class TextDecoder:
self.start_output_tok = TinyJit(self.output_tok)
self.after_start_output_tok = TinyJit(self.output_tok)
# if layernorm supported symbolic shapes, we wouldn't need this hacky 'streaming' param (which should be called something more descriptive like 'x_is_start_toks_only')
def __call__(self, x: Tensor, pos: int, encoded_audio: Tensor, streaming=False):
seqlen = x.shape[-1]
x = self.token_embedding(x) + self.positional_embedding[pos:pos+seqlen]
@@ -138,13 +139,22 @@ class Whisper:
RATE = 16000
CHUNK = 1600
RECORD_SECONDS = 10
SEGMENT_SECONDS=30
SAMPLES_PER_SEGMENT = RATE * SEGMENT_SECONDS # 480000
N_FFT = 400
HOP_LENGTH = 160
N_MELS = 80
FRAMES_PER_SEGMENT = SAMPLES_PER_SEGMENT // HOP_LENGTH # 3000
def prep_audio(waveform, batch_size) -> np.ndarray:
assert waveform is not None
def pad_or_trim(arr, target_len=480000):
def prep_audio(waveforms: List[np.ndarray], batch_size: int, truncate=False) -> np.ndarray:
"""
:param waveforms: A list of possibly variable length 16000Hz audio samples
:param batch_size: The batch_size associated with the Whisper model being used to transcribe the audio.
Used to prevent JIT mismatch errors since the encoder does not accept symbolic shapes
:param truncate: If true, truncates (or pads) audio to exactly 30s for a single encoder pass
:return: mel spectrogram of the given waveforms
"""
def pad_or_trim(arr, target_len):
curr_len = len(arr)
if curr_len == target_len:
return arr
@@ -153,18 +163,15 @@ def prep_audio(waveform, batch_size) -> np.ndarray:
else:
return arr[:target_len]
waveform = np.array(list(map(pad_or_trim, waveform)))
assert waveform.shape[0] <= batch_size
# pad the waveform to match the model's batch_size to avoid JIT shape mismatch errors.
# if operations like conv in the AudioEncoder could support symbolic shapes, then we wouldn't need to do this here
if waveform.shape[0] < batch_size:
waveform = np.pad(waveform, pad_width=((0, batch_size - waveform.shape[0]), (0, 0)))
max_len = SAMPLES_PER_SEGMENT if truncate else max(len(wav) for wav in waveforms)
if (r := max_len % SAMPLES_PER_SEGMENT) > 0: max_len += SAMPLES_PER_SEGMENT - r
waveforms = np.array(list(map(lambda w: pad_or_trim(w, max_len), waveforms)))
assert waveforms.shape[0] <= batch_size
if waveforms.shape[0] < batch_size:
# we could have a symbolic batch_size dim instead of manually padding here if conv/layernorm supported symbolic shapes
waveforms = np.pad(waveforms, pad_width=((0, batch_size - waveforms.shape[0]), (0, 0)))
N_FFT = 400
HOP_LENGTH = 160
N_MELS = 80
stft = librosa.stft(waveform, n_fft=N_FFT, hop_length=HOP_LENGTH, window='hann', dtype=np.csingle)
stft = librosa.stft(waveforms, n_fft=N_FFT, hop_length=HOP_LENGTH, window='hann', dtype=np.csingle)
magnitudes = np.absolute(stft[..., :-1]) ** 2
mel_spec = librosa.filters.mel(sr=RATE, n_fft=N_FFT, n_mels=N_MELS) @ magnitudes
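A minimal sketch of the segment rounding that the new prep_audio performs (the 45-second clip is purely an illustrative input, not something from the diff):

# Illustrative only: how a hypothetical 45 s clip is padded up to whole 30 s segments.
RATE = 16000
SEGMENT_SECONDS = 30
HOP_LENGTH = 160
SAMPLES_PER_SEGMENT = RATE * SEGMENT_SECONDS              # 480000

max_len = 45 * RATE                                       # 720000 samples
if (r := max_len % SAMPLES_PER_SEGMENT) > 0:
  max_len += SAMPLES_PER_SEGMENT - r                      # padded to 960000, i.e. two full segments
num_frames = max_len // HOP_LENGTH                        # 6000 mel frames, FRAMES_PER_SEGMENT per segment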
@@ -216,23 +223,6 @@ def get_encoding(encoding_name):
mergeable_ranks=ranks,
special_tokens=special_tokens)
def img(x):
import matplotlib.pyplot as plt
plt.imshow(x.numpy())
plt.show()
def listener(q):
import pyaudio
p = pyaudio.PyAudio()
stream = p.open(format=pyaudio.paInt16, channels=1, rate=RATE, input=True, frames_per_buffer=CHUNK)
print("listening")
for _ in range(0, int(RATE / CHUNK * RECORD_SECONDS)):
data = stream.read(CHUNK)
waveform = ((np.frombuffer(data, np.int16)/32768).astype(np.float32)*3)
q.put(waveform)
print("done listening")
MODEL_URLS = {
"tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",
"tiny": "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt",
@@ -264,15 +254,18 @@ def load_file_waveform(filename):
def transcribe_file(model, enc, filename):
return transcribe_waveform(model, enc, [load_file_waveform(filename)])
def transcribe_waveform(model, enc, waveforms):
def transcribe_waveform(model, enc, waveforms, truncate=False):
"""
Expects an array of shape (N,S) where N is the number of waveforms to transcribe in parallel and S is the number of 16000Hz samples
Returns the transcribed text if a single waveform is provided, or an array of transcriptions if multiple are provided
"""
N_audio = len(waveforms)
log_spec = prep_audio(waveforms, model.batch_size)
encoded_audio = model.encoder.encode(Tensor(log_spec))
# print('encoded audio', np.sum(encoded_audio.numpy()))
log_spec = prep_audio(waveforms, model.batch_size, truncate)
if log_spec.shape[-1] > FRAMES_PER_SEGMENT and N_audio > 1:
# we don't support multi-segment batching because the size of the prompt tokens would be different for each item in the batch
# if we really want this feature, we can consider padding or trimming prompt tokens of varying lengths to make them consistent
raise Exception("Multi-segment transcription not supported with batch audio input")
start_tokens = [enc._special_tokens["<|startoftranscript|>"]]
if model.is_multilingual:
@@ -281,29 +274,54 @@ def transcribe_waveform(model, enc, waveforms):
start_tokens.append(language_token)
start_tokens.append(enc._special_tokens["<|transcribe|>"])
start_tokens.append(enc._special_tokens["<|notimestamps|>"])
transcription_start_index = len(start_tokens)
eot = enc._special_tokens["<|endoftext|>"]
tokens = np.tile(start_tokens, (log_spec.shape[0], 1))
transcription_tokens = [np.array([], dtype=np.int32)] * log_spec.shape[0]
for curr_frame in range(0, log_spec.shape[-1], FRAMES_PER_SEGMENT):
encoded_audio = model.encoder.encode(Tensor(log_spec[:, :, curr_frame:curr_frame + FRAMES_PER_SEGMENT]))
pos = 0
curr_segment_tokens = np.tile(start_tokens, (log_spec.shape[0], 1))
if curr_frame > 0:
# pass the previously inferred tokens as 'prompt' - https://github.com/openai/whisper/discussions/117#discussioncomment-3727051
prompt = np.concatenate((
[enc._special_tokens["<|startofprev|>"]],
transcription_tokens[0][-model.decoder.max_tokens_to_sample+1:],
start_tokens))
curr_segment_tokens = np.tile(prompt, (log_spec.shape[0], 1))
transcription_start_index = len(curr_segment_tokens[0])
for i in range(model.decoder.max_tokens_to_sample):
out = model.decoder(Tensor(tokens if i == 0 else tokens[:, -1:]), pos, encoded_audio)
out = model.decoder(Tensor(curr_segment_tokens if i == 0 else curr_segment_tokens[:, -1:]), pos, encoded_audio, streaming=curr_frame > 0)
next_tokens = out[:, -1].argmax(axis=-1).numpy().astype(np.int32)
next_tokens[tokens[:, -1] == eot] = eot
tokens = np.concatenate((tokens, next_tokens.reshape(-1, 1)), axis=1)
pos = tokens.shape[-1] - 1
if DEBUG >= 1: print(i, list(map(lambda tokens: enc.decode(tokens), tokens)))
if (tokens[:, -1] == eot).all():
next_tokens[curr_segment_tokens[:, -1] == eot] = eot
curr_segment_tokens = np.concatenate((curr_segment_tokens, next_tokens.reshape(-1, 1)), axis=1)
pos = curr_segment_tokens.shape[-1] - 1
if DEBUG >= 1: print(i, list(map(lambda tokens: enc.decode(tokens), curr_segment_tokens)))
if (curr_segment_tokens[:, -1] == eot).all():
break
transcriptions = []
for t in tokens:
for i, t in enumerate(curr_segment_tokens):
eot_index = np.where(t == eot)[0]
eot_index = None if len(eot_index) == 0 else eot_index[0]
transcriptions.append(enc.decode(t[transcription_start_index:eot_index]).strip())
transcription_tokens[i] = np.concatenate((transcription_tokens[i], t[transcription_start_index:eot_index]))
transcriptions = list(map(lambda tokens: enc.decode(tokens).strip(), transcription_tokens))
return transcriptions[:N_audio] if N_audio > 1 else transcriptions[0]
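To make the segment loop's prompting a bit more concrete, here is a minimal standalone sketch of how the prompt for the second and later segments is laid out (the token ids below are placeholders, not real tokenizer ids):

# Illustrative sketch of the <|startofprev|> prompt layout used for segments after the first.
# All ids are placeholders; the real ones come from enc._special_tokens.
import numpy as np

startofprev_tok = 1
start_tokens = [2, 3, 4]                    # e.g. <|startoftranscript|>, <|transcribe|>, <|notimestamps|>
max_tokens_to_sample = 224                  # n_text_ctx // 2 for the released checkpoints
prev_tokens = np.arange(300)                # pretend 300 tokens were decoded for earlier segments

prompt = np.concatenate((
  [startofprev_tok],
  prev_tokens[-max_tokens_to_sample + 1:],  # keep at most the 223 most recent tokens
  start_tokens))
transcription_start_index = len(prompt)     # sampling for the new segment begins after the prompt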
CHUNK = 1600
RECORD_SECONDS = 10
def listener(q):
import pyaudio
p = pyaudio.PyAudio()
stream = p.open(format=pyaudio.paInt16, channels=1, rate=RATE, input=True, frames_per_buffer=CHUNK)
print("listening")
for _ in range(0, int(RATE / CHUNK * RECORD_SECONDS)):
data = stream.read(CHUNK)
waveform = ((np.frombuffer(data, np.int16)/32768).astype(np.float32)*3)
q.put(waveform)
print("done listening")
if __name__ == "__main__":
model, enc = init_whisper("small.en" if getenv("SMALL") else "tiny.en", batch_size=1)
@@ -312,7 +330,6 @@ if __name__ == "__main__":
print(transcribe_file(model, enc, sys.argv[1]))
else:
# online
q = multiprocessing.Queue()
p = multiprocessing.Process(target=listener, args=(q,))
p.daemon = True
@@ -328,7 +345,7 @@ if __name__ == "__main__":
else: total = np.concatenate([total, waveform])
did_read = True
if did_read:
log_spec = prep_audio(total.reshape(1, -1), 1)
log_spec = prep_audio(total.reshape(1, -1), model.batch_size, truncate=True)
encoded_audio = model.encoder.encode(Tensor(log_spec))
# pass the previously inferred tokens as 'prefix' - https://github.com/openai/whisper/discussions/117#discussioncomment-3727051
out = model.decoder(Tensor([lst]), 0, encoded_audio, streaming=True).realize()

View File

@@ -36,7 +36,7 @@ def run_evaluation(model_name, tinygrad_expected_wer, reference_wer):
references = []
for audio, texts in tqdm.tqdm(loader):
transcriptions = transcribe_waveform(model, enc, audio.numpy())
transcriptions = transcribe_waveform(model, enc, audio.numpy(), truncate=True)
hypotheses.extend(transcriptions)
references.extend(texts)

View File

@@ -1,7 +1,7 @@
import unittest
import pathlib
from examples.whisper import init_whisper, load_file_waveform, transcribe_file, transcribe_waveform
from tinygrad.helpers import CI
from tinygrad.helpers import CI, fetch
from tinygrad.ops import Device
# Audio generated with the command on MacOS:
@@ -11,6 +11,9 @@ TEST_FILE_1 = str(pathlib.Path(__file__).parent / "whisper/test.wav")
TRANSCRIPTION_1 = "Could you please let me out of the box?"
TEST_FILE_2 = str(pathlib.Path(__file__).parent / "whisper/test2.wav")
TRANSCRIPTION_2 = "a slightly longer audio file so that we can test batch transcriptions of varying length."
# TODO this file will possibly not survive long. find another 1-2 minute sound file online to transcribe
TEST_FILE_3_URL = 'https://homepage.ntu.edu.tw/~karchung/miniconversations/mc45.mp3'
TRANSCRIPTION_3 = "Just lie back and relax. Is the level of pressure about right? Yes, it's fine, and I'd like conditioner please. Sure. I'm going to start the second lathering now. Would you like some Q-tips? How'd you like it cut? I'd like my bangs and the back trimmed, and I'd like the rest thinned out a bit and layered. Where would you like the part? On the left, right about here. Here, have a look. What do you think? It's fine. Here's a thousand anti-dollars. It's 30-ant extra for the rants. Here's your change and receipt. Thank you, and please come again. So how do you like it? It could have been worse, but you'll notice that I didn't ask her for her card. Hmm, yeah. Maybe you can try that place over there next time."
@unittest.skipIf(CI and Device.DEFAULT in ["LLVM", "CLANG", "CPU"], "Not working on LLVM, slow on others")
class TestWhisper(unittest.TestCase):
@@ -28,9 +31,11 @@ class TestWhisper(unittest.TestCase):
def test_transcribe_file1(self):
self.assertEqual(transcribe_file(self.model, self.enc, TEST_FILE_1), TRANSCRIPTION_1)
@unittest.skipIf(CI, "too many tests for CI")
def test_transcribe_file2(self):
self.assertEqual(transcribe_file(self.model, self.enc, TEST_FILE_2), TRANSCRIPTION_2)
@unittest.skipIf(CI, "too many tests for CI")
def test_transcribe_batch12(self):
waveforms = [load_file_waveform(TEST_FILE_1), load_file_waveform(TEST_FILE_2)]
transcriptions = transcribe_waveform(self.model, self.enc, waveforms)
@@ -44,3 +49,17 @@ class TestWhisper(unittest.TestCase):
self.assertEqual(2, len(transcriptions))
self.assertEqual(TRANSCRIPTION_2, transcriptions[0])
self.assertEqual(TRANSCRIPTION_1, transcriptions[1])
@unittest.skipIf(CI, "too long for CI")
def test_transcribe_long(self):
waveform = [load_file_waveform(fetch(TEST_FILE_3_URL))]
transcription = transcribe_waveform(self.model, self.enc, waveform)
self.assertEqual(TRANSCRIPTION_3, transcription)
@unittest.skipIf(CI, "too long for CI")
def test_transcribe_long_no_batch(self):
waveforms = [load_file_waveform(fetch(TEST_FILE_3_URL)), load_file_waveform(TEST_FILE_1)]
with self.assertRaises(Exception):
transcribe_waveform(self.model, self.enc, waveforms)