mirror of https://github.com/commaai/tinygrad.git
whisper: support audio >30s (#2378)
* whisper: support audio >30s
* make prompt indexing consistent with reference repo
* fix online
commit 7f0cc4a4e8
parent 7220f5c9fc
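
For orientation, here is a minimal usage sketch of the API this commit changes (editor's illustration, not part of the diff; the audio path is a placeholder):

# Editor's sketch of the new behaviour; "long_audio.wav" is a placeholder path.
from examples.whisper import init_whisper, load_file_waveform, transcribe_waveform

model, enc = init_whisper("tiny.en", batch_size=1)

# >30s audio is now padded to a multiple of 30s and decoded segment by segment,
# with each segment's text fed to the next segment as a prompt.
text = transcribe_waveform(model, enc, [load_file_waveform("long_audio.wav")])

# truncate=True keeps the old single-segment behaviour (pad or trim to exactly 30s),
# which is what the evaluation and online paths in this diff opt into.
text_30s = transcribe_waveform(model, enc, [load_file_waveform("long_audio.wav")], truncate=True)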
@@ -5,7 +5,7 @@ import pathlib
 import base64
 import multiprocessing
 import numpy as np
-from typing import Optional, Union, Literal
+from typing import Optional, Union, Literal, List
 from extra.utils import download_file
 from tinygrad.jit import TinyJit
 from tinygrad.nn.state import torch_load, load_state_dict
@@ -101,7 +101,7 @@ class AudioEncoder:
 class TextDecoder:
   def __init__(self, n_vocab, n_text_ctx, n_text_state, n_text_head, n_text_layer, **_):
     self.max_tokens_to_sample = n_text_ctx // 2
-    self.max_self_attn_cache_len = self.max_tokens_to_sample * 2 # double the size as an extra buffer for prefix/start tokens
+    self.max_self_attn_cache_len = self.max_tokens_to_sample * 2 + 5 # roughly prompt + start toks + max_tokens_to_sample
 
     self.token_embedding = nn.Embedding(n_vocab, n_text_state)
     self.positional_embedding = Tensor.empty(n_text_ctx, n_text_state)
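
To make the new cache sizing concrete (editor's note, not in the diff): the released Whisper checkpoints use n_text_ctx = 448, so the self-attention cache grows from 448 to 453 entries, enough for a near-full prompt, the handful of start tokens, and the sampled tokens.

# Editor's arithmetic sketch; n_text_ctx = 448 for the released Whisper models.
n_text_ctx = 448
max_tokens_to_sample = n_text_ctx // 2        # 224
old_cache_len = max_tokens_to_sample * 2      # 448
new_cache_len = max_tokens_to_sample * 2 + 5  # 453: ~224 prompt tokens + start tokens + 224 sampled tokens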
@@ -113,6 +113,7 @@ class TextDecoder:
     self.start_output_tok = TinyJit(self.output_tok)
     self.after_start_output_tok = TinyJit(self.output_tok)
 
+  # if layernorm supported symbolic shapes, we wouldn't need this hacky 'streaming' param (which should be called something more descriptive like 'x_is_start_toks_only')
   def __call__(self, x: Tensor, pos: int, encoded_audio: Tensor, streaming=False):
     seqlen = x.shape[-1]
     x = self.token_embedding(x) + self.positional_embedding[pos:pos+seqlen]
@@ -138,13 +139,22 @@ class Whisper:
 
 
 RATE = 16000
-CHUNK = 1600
-RECORD_SECONDS = 10
+SEGMENT_SECONDS=30
+SAMPLES_PER_SEGMENT = RATE * SEGMENT_SECONDS # 480000
+N_FFT = 400
+HOP_LENGTH = 160
+N_MELS = 80
+FRAMES_PER_SEGMENT = SAMPLES_PER_SEGMENT // HOP_LENGTH # 3000
 
-def prep_audio(waveform, batch_size) -> np.ndarray:
-  assert waveform is not None
-
-  def pad_or_trim(arr, target_len=480000):
+def prep_audio(waveforms: List[np.ndarray], batch_size: int, truncate=False) -> np.ndarray:
+  """
+  :param waveforms: A list of possibly variable length 16000Hz audio samples
+  :param batch_size: The batch_size associated with the Whisper model being used to transcribe the audio.
+    Used to prevent JIT mismatch errors since the encoder does not accept symbolic shapes
+  :param truncate: If true, truncates (or pads) audio to exactly 30s for a single encoder pass
+  :return: mel spectrogram of the given waveforms
+  """
+  def pad_or_trim(arr, target_len):
     curr_len = len(arr)
     if curr_len == target_len:
       return arr
@@ -153,18 +163,15 @@ def prep_audio(waveform, batch_size) -> np.ndarray:
     else:
       return arr[:target_len]
 
-  waveform = np.array(list(map(pad_or_trim, waveform)))
-  assert waveform.shape[0] <= batch_size
-  # pad the waveform to match the model's batch_size to avoid JIT shape mismatch errors.
-  # if operations like conv in the AudioEncoder could support symbolic shapes, then we wouldn't need to do this here
-  if waveform.shape[0] < batch_size:
-    waveform = np.pad(waveform, pad_width=((0, batch_size - waveform.shape[0]), (0, 0)))
+  max_len = SAMPLES_PER_SEGMENT if truncate else max(len(wav) for wav in waveforms)
+  if (r := max_len % SAMPLES_PER_SEGMENT) > 0: max_len += SAMPLES_PER_SEGMENT - r
+  waveforms = np.array(list(map(lambda w: pad_or_trim(w, max_len), waveforms)))
+  assert waveforms.shape[0] <= batch_size
+  if waveforms.shape[0] < batch_size:
+    # we could have a symbolic batch_size dim instead of manually padding here if conv/layernorm supported symbolic shapes
+    waveforms = np.pad(waveforms, pad_width=((0, batch_size - waveforms.shape[0]), (0, 0)))
 
-  N_FFT = 400
-  HOP_LENGTH = 160
-  N_MELS = 80
-
-  stft = librosa.stft(waveform, n_fft=N_FFT, hop_length=HOP_LENGTH, window='hann', dtype=np.csingle)
+  stft = librosa.stft(waveforms, n_fft=N_FFT, hop_length=HOP_LENGTH, window='hann', dtype=np.csingle)
   magnitudes = np.absolute(stft[..., :-1]) ** 2
   mel_spec = librosa.filters.mel(sr=RATE, n_fft=N_FFT, n_mels=N_MELS) @ magnitudes
 
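
The padding rule above is the core of the >30s support: instead of always forcing 480000 samples, prep_audio now rounds the longest waveform up to the next 30-second boundary. A worked example (editor's sketch, values illustrative):

# Editor's sketch of the rounding in prep_audio; a 45s clip becomes two 30s encoder segments.
RATE = 16000
SEGMENT_SECONDS = 30
SAMPLES_PER_SEGMENT = RATE * SEGMENT_SECONDS            # 480000
HOP_LENGTH = 160
FRAMES_PER_SEGMENT = SAMPLES_PER_SEGMENT // HOP_LENGTH  # 3000

max_len = 45 * RATE                                     # 720000 samples of input audio
if (r := max_len % SAMPLES_PER_SEGMENT) > 0: max_len += SAMPLES_PER_SEGMENT - r
print(max_len)                                          # 960000 -> padded to 60s
print(max_len // SAMPLES_PER_SEGMENT)                   # 2 segments of 3000 mel frames each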
@@ -216,23 +223,6 @@ def get_encoding(encoding_name):
     mergeable_ranks=ranks,
     special_tokens=special_tokens)
 
-def img(x):
-  import matplotlib.pyplot as plt
-  plt.imshow(x.numpy())
-  plt.show()
-
-def listener(q):
-  import pyaudio
-  p = pyaudio.PyAudio()
-  stream = p.open(format=pyaudio.paInt16, channels=1, rate=RATE, input=True, frames_per_buffer=CHUNK)
-  print("listening")
-  for _ in range(0, int(RATE / CHUNK * RECORD_SECONDS)):
-    data = stream.read(CHUNK)
-    waveform = ((np.frombuffer(data, np.int16)/32768).astype(np.float32)*3)
-    q.put(waveform)
-  print("done listening")
-
-
 MODEL_URLS = {
   "tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",
   "tiny": "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt",
@@ -264,15 +254,18 @@ def load_file_waveform(filename):
 def transcribe_file(model, enc, filename):
   return transcribe_waveform(model, enc, [load_file_waveform(filename)])
 
-def transcribe_waveform(model, enc, waveforms):
+def transcribe_waveform(model, enc, waveforms, truncate=False):
   """
   Expects an array of shape (N,S) where N is the number waveforms to transcribe in parallel and S is number of 16000Hz samples
   Returns the transcribed text if a single waveform is provided, or an array of transcriptions if multiple are provided
   """
   N_audio = len(waveforms)
-  log_spec = prep_audio(waveforms, model.batch_size)
-  encoded_audio = model.encoder.encode(Tensor(log_spec))
-  # print('encoded audio', np.sum(encoded_audio.numpy()))
+  log_spec = prep_audio(waveforms, model.batch_size, truncate)
+
+  if log_spec.shape[-1] > FRAMES_PER_SEGMENT and N_audio > 1:
+    # we don't support multi-segment batching because the size of the prompt tokens would be different for each item in the batch
+    # if we really want this feature, we can consider padding or trimming prompt tokens of varying lengths to make them consistent
+    raise Exception("Multi-segment transcription not supported with batch audio input")
 
   start_tokens = [enc._special_tokens["<|startoftranscript|>"]]
   if model.is_multilingual:
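
The guard above means batching and multi-segment decoding are mutually exclusive for now: several clips can be transcribed together only if they all fit in a single 30-second segment, while a longer clip must be transcribed on its own (exactly what test_transcribe_long_no_batch at the bottom of this diff checks). Editor's sketch with placeholder file names:

# Editor's sketch; model/enc come from init_whisper, paths are placeholders.
short_a = load_file_waveform("short_a.wav")   # < 30s
short_b = load_file_waveform("short_b.wav")   # < 30s
long_c  = load_file_waveform("long_c.wav")    # > 30s

transcribe_waveform(model, enc, [short_a, short_b])  # ok: one segment, batched
transcribe_waveform(model, enc, [long_c])            # ok: multiple segments, batch of 1
transcribe_waveform(model, enc, [long_c, short_a])   # raises: multi-segment + batch > 1 not supported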
@@ -281,29 +274,54 @@ def transcribe_waveform(model, enc, waveforms):
     start_tokens.append(language_token)
   start_tokens.append(enc._special_tokens["<|transcribe|>"])
   start_tokens.append(enc._special_tokens["<|notimestamps|>"])
 
   transcription_start_index = len(start_tokens)
   eot = enc._special_tokens["<|endoftext|>"]
-  tokens = np.tile(start_tokens, (log_spec.shape[0], 1))
+  transcription_tokens = [np.array([], dtype=np.int32)] * log_spec.shape[0]
 
-  pos = 0
-  for i in range(model.decoder.max_tokens_to_sample):
-    out = model.decoder(Tensor(tokens if i == 0 else tokens[:, -1:]), pos, encoded_audio)
-    next_tokens = out[:, -1].argmax(axis=-1).numpy().astype(np.int32)
-    next_tokens[tokens[:, -1] == eot] = eot
-    tokens = np.concatenate((tokens, next_tokens.reshape(-1, 1)), axis=1)
-    pos = tokens.shape[-1] - 1
-    if DEBUG >= 1: print(i, list(map(lambda tokens: enc.decode(tokens), tokens)))
-    if (tokens[:, -1] == eot).all():
-      break
+  for curr_frame in range(0, log_spec.shape[-1], FRAMES_PER_SEGMENT):
+    encoded_audio = model.encoder.encode(Tensor(log_spec[:, :, curr_frame:curr_frame + FRAMES_PER_SEGMENT]))
+    pos = 0
+    curr_segment_tokens = np.tile(start_tokens, (log_spec.shape[0], 1))
+    if curr_frame > 0:
+      # pass the previously inferred tokens as 'prompt' - https://github.com/openai/whisper/discussions/117#discussioncomment-3727051
+      prompt = np.concatenate((
+        [enc._special_tokens["<|startofprev|>"]],
+        transcription_tokens[0][-model.decoder.max_tokens_to_sample+1:],
+        start_tokens))
+      curr_segment_tokens = np.tile(prompt, (log_spec.shape[0], 1))
+      transcription_start_index = len(curr_segment_tokens[0])
 
-  transcriptions = []
-  for t in tokens:
-    eot_index = np.where(t == eot)[0]
-    eot_index = None if len(eot_index) == 0 else eot_index[0]
-    transcriptions.append(enc.decode(t[transcription_start_index:eot_index]).strip())
+    for i in range(model.decoder.max_tokens_to_sample):
+      out = model.decoder(Tensor(curr_segment_tokens if i == 0 else curr_segment_tokens[:, -1:]), pos, encoded_audio, streaming=curr_frame > 0)
+      next_tokens = out[:, -1].argmax(axis=-1).numpy().astype(np.int32)
+      next_tokens[curr_segment_tokens[:, -1] == eot] = eot
+      curr_segment_tokens = np.concatenate((curr_segment_tokens, next_tokens.reshape(-1, 1)), axis=1)
+      pos = curr_segment_tokens.shape[-1] - 1
+      if DEBUG >= 1: print(i, list(map(lambda tokens: enc.decode(tokens), curr_segment_tokens)))
+      if (curr_segment_tokens[:, -1] == eot).all():
+        break
+
+    for i, t in enumerate(curr_segment_tokens):
+      eot_index = np.where(t == eot)[0]
+      eot_index = None if len(eot_index) == 0 else eot_index[0]
+      transcription_tokens[i] = np.concatenate((transcription_tokens[i], t[transcription_start_index:eot_index]))
+
+  transcriptions = list(map(lambda tokens: enc.decode(tokens).strip(), transcription_tokens))
   return transcriptions[:N_audio] if N_audio > 1 else transcriptions[0]
 
+CHUNK = 1600
+RECORD_SECONDS = 10
+
+def listener(q):
+  import pyaudio
+  p = pyaudio.PyAudio()
+  stream = p.open(format=pyaudio.paInt16, channels=1, rate=RATE, input=True, frames_per_buffer=CHUNK)
+  print("listening")
+  for _ in range(0, int(RATE / CHUNK * RECORD_SECONDS)):
+    data = stream.read(CHUNK)
+    waveform = ((np.frombuffer(data, np.int16)/32768).astype(np.float32)*3)
+    q.put(waveform)
+  print("done listening")
+
 if __name__ == "__main__":
   model, enc = init_whisper("small.en" if getenv("SMALL") else "tiny.en", batch_size=1)
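
The per-segment loop added above is where the prompt indexing mentioned in the commit message happens: for every segment after the first, the previously decoded tokens are prepended behind <|startofprev|>, and transcription_start_index is moved past the whole prompt so prompt tokens never leak into the output. A hedged sketch of that layout (token ids below are made up for illustration):

import numpy as np

# Editor's sketch of the prompt layout; token ids are illustrative, not the real vocabulary ids.
startofprev, startoftranscript, transcribe_tok, notimestamps = 3, 0, 1, 2
start_tokens = [startoftranscript, transcribe_tok, notimestamps]
previous_segment_tokens = np.array([400, 401, 402], dtype=np.int32)  # text decoded from the prior 30s segment
max_tokens_to_sample = 224  # n_text_ctx // 2

prompt = np.concatenate((
  [startofprev],
  previous_segment_tokens[-max_tokens_to_sample+1:],  # keep at most the last 223 previously decoded tokens
  start_tokens))
transcription_start_index = len(prompt)  # 1 + 3 + 3 = 7: decoding output starts after the prompt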
@@ -312,7 +330,6 @@ if __name__ == "__main__":
     print(transcribe_file(model, enc, sys.argv[1]))
   else:
     # online
-
     q = multiprocessing.Queue()
     p = multiprocessing.Process(target=listener, args=(q,))
     p.daemon = True
@@ -328,7 +345,7 @@ if __name__ == "__main__":
         else: total = np.concatenate([total, waveform])
         did_read = True
       if did_read:
-        log_spec = prep_audio(total.reshape(1, -1), 1)
+        log_spec = prep_audio(total.reshape(1, -1), model.batch_size, truncate=True)
         encoded_audio = model.encoder.encode(Tensor(log_spec))
       # pass the previously inferred tokens as 'prefix' - https://github.com/openai/whisper/discussions/117#discussioncomment-3727051
       out = model.decoder(Tensor([lst]), 0, encoded_audio, streaming=True).realize()
@@ -36,7 +36,7 @@ def run_evaluation(model_name, tinygrad_expected_wer, reference_wer):
   references = []
 
   for audio, texts in tqdm.tqdm(loader):
-    transcriptions = transcribe_waveform(model, enc, audio.numpy())
+    transcriptions = transcribe_waveform(model, enc, audio.numpy(), truncate=True)
     hypotheses.extend(transcriptions)
     references.extend(texts)
 
@@ -1,7 +1,7 @@
 import unittest
 import pathlib
 from examples.whisper import init_whisper, load_file_waveform, transcribe_file, transcribe_waveform
-from tinygrad.helpers import CI
+from tinygrad.helpers import CI, fetch
 from tinygrad.ops import Device
 
 # Audio generated with the command on MacOS:
@@ -11,6 +11,9 @@ TEST_FILE_1 = str(pathlib.Path(__file__).parent / "whisper/test.wav")
 TRANSCRIPTION_1 = "Could you please let me out of the box?"
 TEST_FILE_2 = str(pathlib.Path(__file__).parent / "whisper/test2.wav")
 TRANSCRIPTION_2 = "a slightly longer audio file so that we can test batch transcriptions of varying length."
+# TODO this file will possibly not survive long. find another 1-2 minute sound file online to transcribe
+TEST_FILE_3_URL = 'https://homepage.ntu.edu.tw/~karchung/miniconversations/mc45.mp3'
+TRANSCRIPTION_3 = "Just lie back and relax. Is the level of pressure about right? Yes, it's fine, and I'd like conditioner please. Sure. I'm going to start the second lathering now. Would you like some Q-tips? How'd you like it cut? I'd like my bangs and the back trimmed, and I'd like the rest thinned out a bit and layered. Where would you like the part? On the left, right about here. Here, have a look. What do you think? It's fine. Here's a thousand anti-dollars. It's 30-ant extra for the rants. Here's your change and receipt. Thank you, and please come again. So how do you like it? It could have been worse, but you'll notice that I didn't ask her for her card. Hmm, yeah. Maybe you can try that place over there next time."
 
 @unittest.skipIf(CI and Device.DEFAULT in ["LLVM", "CLANG", "CPU"], "Not working on LLVM, slow on others")
 class TestWhisper(unittest.TestCase):
@@ -28,9 +31,11 @@ class TestWhisper(unittest.TestCase):
   def test_transcribe_file1(self):
     self.assertEqual(transcribe_file(self.model, self.enc, TEST_FILE_1), TRANSCRIPTION_1)
 
+  @unittest.skipIf(CI, "too many tests for CI")
   def test_transcribe_file2(self):
     self.assertEqual(transcribe_file(self.model, self.enc, TEST_FILE_2), TRANSCRIPTION_2)
 
+  @unittest.skipIf(CI, "too many tests for CI")
   def test_transcribe_batch12(self):
     waveforms = [load_file_waveform(TEST_FILE_1), load_file_waveform(TEST_FILE_2)]
     transcriptions = transcribe_waveform(self.model, self.enc, waveforms)
@@ -44,3 +49,17 @@ class TestWhisper(unittest.TestCase):
     self.assertEqual(2, len(transcriptions))
     self.assertEqual(TRANSCRIPTION_2, transcriptions[0])
     self.assertEqual(TRANSCRIPTION_1, transcriptions[1])
+
+  @unittest.skipIf(CI, "too long for CI")
+  def test_transcribe_long(self):
+    waveform = [load_file_waveform(fetch(TEST_FILE_3_URL))]
+    transcription = transcribe_waveform(self.model, self.enc, waveform)
+    self.assertEqual(TRANSCRIPTION_3, transcription)
+
+  @unittest.skipIf(CI, "too long for CI")
+  def test_transcribe_long_no_batch(self):
+    waveforms = [load_file_waveform(fetch(TEST_FILE_3_URL)), load_file_waveform(TEST_FILE_1)]
+    with self.assertRaises(Exception):
+      transcribe_waveform(self.model, self.enc, waveforms)