Mirror of https://github.com/m-bain/whisperX.git
fix long segments, break into sentences using nltk, improve align logic, improve diarize (sentence-based)
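
Note: the diff below covers whisperx/asr.py only; the nltk-based sentence splitting named in the commit message presumably lives in the alignment/diarization code, which this view does not show. As a rough sketch of that idea, assuming the standard Punkt tokenizer (the function name is illustrative, not the commit's actual code):

    import nltk
    nltk.download("punkt", quiet=True)  # one-time download of the Punkt sentence model
    from nltk.tokenize import sent_tokenize

    def split_segment_text(text: str) -> list[str]:
        # Break one long transcript segment into sentence-sized pieces so that
        # alignment and speaker assignment can operate per sentence.
        return sent_tokenize(text)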
whisperx/asr.py (155 changed lines)
@@ -78,7 +78,7 @@ def load_model(whisper_arch, device, compute_type="float16", asr_options=None, l
 class WhisperModel(faster_whisper.WhisperModel):
     '''
     FasterWhisperModel provides batched inference for faster-whisper.
-    Currently only works in non-timestamp mode.
+    Currently only works in non-timestamp mode and fixed prompt for all samples in batch.
     '''

     def generate_segment_batched(self, features: np.ndarray, tokenizer: faster_whisper.tokenizer.Tokenizer, options: faster_whisper.transcribe.TranscriptionOptions, encoder_output = None):
@@ -140,6 +140,13 @@ class WhisperModel(faster_whisper.WhisperModel):
         return self.model.encode(features, to_cpu=to_cpu)

 class FasterWhisperPipeline(Pipeline):
+    """
+    Huggingface Pipeline wrapper for FasterWhisperModel.
+    """
+    # TODO:
+    # - add support for timestamp mode
+    # - add support for custom inference kwargs
+
     def __init__(
         self,
         model,
@@ -261,149 +268,3 @@ class FasterWhisperPipeline(Pipeline):
         language = language_token[2:-2]
         print(f"Detected language: {language} ({language_probability:.2f}) in first 30s of audio...")
         return language
-
-if __name__ == "__main__":
-    main_type = "simple"
-    import time
-
-    import jiwer
-    from tqdm import tqdm
-    from whisper.normalizers import EnglishTextNormalizer
-
-    from benchmark.tedlium import parse_tedlium_annos
-
-    if main_type == "complex":
-        from faster_whisper.tokenizer import Tokenizer
-        from faster_whisper.transcribe import TranscriptionOptions
-        from faster_whisper.vad import (SpeechTimestampsMap,
-                                        get_speech_timestamps)
-
-        from whisperx.vad import load_vad_model, merge_chunks
-
-        from .audio import SAMPLE_RATE, load_audio, log_mel_spectrogram
-        faster_t_options = TranscriptionOptions(
-            beam_size=5,
-            best_of=5,
-            patience=1,
-            length_penalty=1,
-            temperatures=[0.0, 0.2, 0.4, 0.6, 0.8, 1.0],
-            compression_ratio_threshold=2.4,
-            log_prob_threshold=-1.0,
-            no_speech_threshold=0.6,
-            condition_on_previous_text=False,
-            initial_prompt=None,
-            prefix=None,
-            suppress_blank=True,
-            suppress_tokens=[-1],
-            without_timestamps=True,
-            max_initial_timestamp=0.0,
-            word_timestamps=False,
-            prepend_punctuations="\"'“¿([{-",
-            append_punctuations="\"'.。,,!!??::”)]}、"
-        )
-        whisper_arch = "large-v2"
-        device = "cuda"
-        batch_size = 16
-        model = WhisperModel(whisper_arch, device="cuda", compute_type="float16",)
-        tokenizer = Tokenizer(model.hf_tokenizer, model.model.is_multilingual, task="transcribe", language="en")
-        model = FasterWhisperPipeline(model, tokenizer, faster_t_options, device=-1)
-        fn = "DanielKahneman_2010.wav"
-        wav_dir = f"/tmp/test/wav/"
-        vad_model = load_vad_model("cuda", 0.6, 0.3)
-        audio = load_audio(os.path.join(wav_dir, fn))
-        vad_segments = vad_model({"waveform": torch.from_numpy(audio).unsqueeze(0), "sample_rate": SAMPLE_RATE})
-        vad_segments = merge_chunks(vad_segments, 30)
-
-        def data(audio, segments):
-            for seg in segments:
-                f1 = int(seg['start'] * SAMPLE_RATE)
-                f2 = int(seg['end'] * SAMPLE_RATE)
-                # print(f2-f1)
-                yield {'inputs': audio[f1:f2]}
-        vad_method="pyannote"
-
-        wav_dir = f"/tmp/test/wav/"
-        wer_li = []
-        time_li = []
-        for fn in os.listdir(wav_dir):
-            if fn == "RobertGupta_2010U.wav":
-                continue
-            base_fn = fn.split('.')[0]
-            audio_fp = os.path.join(wav_dir, fn)
-
-            audio = load_audio(audio_fp)
-            t1 = time.time()
-            if vad_method == "pyannote":
-                vad_segments = vad_model({"waveform": torch.from_numpy(audio).unsqueeze(0), "sample_rate": SAMPLE_RATE})
-                vad_segments = merge_chunks(vad_segments, 30)
-            elif vad_method == "silero":
-                vad_segments = get_speech_timestamps(audio, threshold=0.5, max_speech_duration_s=30)
-                vad_segments = [{"start": x["start"] / SAMPLE_RATE, "end": x["end"] / SAMPLE_RATE} for x in vad_segments]
-                new_segs = []
-                curr_start = vad_segments[0]['start']
-                curr_end = vad_segments[0]['end']
-                for seg in vad_segments[1:]:
-                    if seg['end'] - curr_start > 30:
-                        new_segs.append({"start": curr_start, "end": curr_end})
-                        curr_start = seg['start']
-                        curr_end = seg['end']
-                    else:
-                        curr_end = seg['end']
-                new_segs.append({"start": curr_start, "end": curr_end})
-                vad_segments = new_segs
-            text = []
-            # for idx, out in tqdm(enumerate(model(data(audio_fp, vad_segments), batch_size=batch_size)), total=len(vad_segments)):
-            for idx, out in enumerate(model(data(audio, vad_segments), batch_size=batch_size)):
-                text.append(out['text'])
-            t2 = time.time()
-            if batch_size == 1:
-                text = [x[0] for x in text]
-            text = " ".join(text)
-
-            normalizer = EnglishTextNormalizer()
-            text = normalizer(text)
-            gt_corpus = normalizer(parse_tedlium_annos(base_fn, "/tmp/test/"))
-
-            wer_result = jiwer.wer(gt_corpus, text)
-            print("WER: %.2f \t time: %.2f \t [%s]" % (wer_result * 100, t2-t1, fn))
-
-            wer_li.append(wer_result)
-            time_li.append(t2-t1)
-        print("# Avg Mean...")
-        print("WER: %.2f" % (sum(wer_li) * 100/len(wer_li)))
-        print("Time: %.2f" % (sum(time_li)/len(time_li)))
-    elif main_type == "simple":
-        model = load_model(
-            "large-v2",
-            device="cuda",
-            language="en",
-        )
-
-        wav_dir = f"/tmp/test/wav/"
-        wer_li = []
-        time_li = []
-        for fn in os.listdir(wav_dir):
-            if fn == "RobertGupta_2010U.wav":
-                continue
-            # fn = "DanielKahneman_2010.wav"
-            base_fn = fn.split('.')[0]
-            audio_fp = os.path.join(wav_dir, fn)
-
-            audio = load_audio(audio_fp)
-            t1 = time.time()
-            out = model.transcribe(audio_fp, batch_size=8)["segments"]
-            t2 = time.time()
-
-            text = " ".join([x['text'] for x in out])
-            normalizer = EnglishTextNormalizer()
-            text = normalizer(text)
-            gt_corpus = normalizer(parse_tedlium_annos(base_fn, "/tmp/test/"))
-
-            wer_result = jiwer.wer(gt_corpus, text)
-            print("WER: %.2f \t time: %.2f \t [%s]" % (wer_result * 100, t2-t1, fn))
-
-            wer_li.append(wer_result)
-            time_li.append(t2-t1)
-        print("# Avg Mean...")
-        print("WER: %.2f" % (sum(wer_li) * 100/len(wer_li)))
-        print("Time: %.2f" % (sum(time_li)/len(time_li)))
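
For context, the benchmark block removed above exercised the same entry points that remain in the library. A minimal sketch of that surviving path, using the arguments from the deleted code (the audio path is a placeholder):

    from whisperx.asr import load_model

    # load_model wraps WhisperModel + FasterWhisperPipeline, per the hunks above
    model = load_model("large-v2", device="cuda", language="en")
    segments = model.transcribe("/path/to/audio.wav", batch_size=8)["segments"]
    text = " ".join(seg["text"] for seg in segments)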