Files
whisperX/whisperx/transcribe.py
Barabazs ffedc5cdf0 fix: speaker embedding bug (#1178)
* fix: improve handling of speaker embeddings in transcribe_task

* chore: bump version to 3.4.1
2025-06-25 13:55:20 +02:00

235 lines
8.5 KiB
Python

import argparse
import gc
import os
import warnings
import numpy as np
import torch
from whisperx.alignment import align, load_align_model
from whisperx.asr import load_model
from whisperx.audio import load_audio
from whisperx.diarize import DiarizationPipeline, assign_word_speakers
from whisperx.types import AlignedTranscriptionResult, TranscriptionResult
from whisperx.utils import LANGUAGES, TO_LANGUAGE_CODE, get_writer
def transcribe_task(args: dict, parser: argparse.ArgumentParser):
    """Transcription task to be called from CLI.

    Runs the pipeline stage by stage over all inputs: transcribe every audio
    file with one loaded ASR model, optionally align the transcripts,
    optionally diarize them, and finally pass each result to the output writer.
    Each model is dropped (``del`` + ``gc.collect()`` + CUDA cache flush) before
    the next stage loads its own.

    Args:
        args: Dictionary of command-line arguments. Options consumed here are
            popped from it, and ``args["language"]`` is normalized in place.
        parser: argparse.ArgumentParser object, used to report word-level
            output options that are incompatible with ``--no_align``.

    Raises:
        ValueError: If ``args["language"]`` is neither found in ``LANGUAGES``
            nor a key of ``TO_LANGUAGE_CODE``.
    """
    # fmt: off
    model_name: str = args.pop("model")
    batch_size: int = args.pop("batch_size")
    model_dir: str = args.pop("model_dir")
    model_cache_only: bool = args.pop("model_cache_only")
    output_dir: str = args.pop("output_dir")
    output_format: str = args.pop("output_format")
    device: str = args.pop("device")
    device_index: int = args.pop("device_index")
    compute_type: str = args.pop("compute_type")
    verbose: bool = args.pop("verbose")

    os.makedirs(output_dir, exist_ok=True)

    # Kept under a separate name so the loaded alignment model below does not
    # shadow the user-supplied model name.
    align_model_name: str = args.pop("align_model")
    interpolate_method: str = args.pop("interpolate_method")
    no_align: bool = args.pop("no_align")
    task: str = args.pop("task")
    if task == "translate":
        # translation cannot be aligned
        no_align = True

    return_char_alignments: bool = args.pop("return_char_alignments")

    hf_token: str = args.pop("hf_token")
    vad_method: str = args.pop("vad_method")
    vad_onset: float = args.pop("vad_onset")
    vad_offset: float = args.pop("vad_offset")

    chunk_size: int = args.pop("chunk_size")

    diarize: bool = args.pop("diarize")
    min_speakers: int = args.pop("min_speakers")
    max_speakers: int = args.pop("max_speakers")
    diarize_model_name: str = args.pop("diarize_model")
    print_progress: bool = args.pop("print_progress")
    return_speaker_embeddings: bool = args.pop("speaker_embeddings")

    if return_speaker_embeddings and not diarize:
        warnings.warn("--speaker_embeddings has no effect without --diarize")

    # Normalize the requested language: lower-case it, then map it through
    # TO_LANGUAGE_CODE when it is not already present in LANGUAGES.
    if args["language"] is not None:
        args["language"] = args["language"].lower()
        if args["language"] not in LANGUAGES:
            if args["language"] in TO_LANGUAGE_CODE:
                args["language"] = TO_LANGUAGE_CODE[args["language"]]
            else:
                raise ValueError(f"Unsupported language: {args['language']}")

    # Model names ending in ".en" force English, warning only when the user
    # explicitly asked for something else.
    if model_name.endswith(".en") and args["language"] != "en":
        if args["language"] is not None:
            warnings.warn(
                f"{model_name} is an English-only model but received '{args['language']}'; using English instead."
            )
        args["language"] = "en"
    align_language = (
        args["language"] if args["language"] is not None else "en"
    )  # default to loading english if not specified

    # With an increment, build the fallback schedule temperature, +inc, ... up
    # to 1.0 (the 1e-6 keeps 1.0 itself inside arange's half-open range).
    temperature = args.pop("temperature")
    if (increment := args.pop("temperature_increment_on_fallback")) is not None:
        temperature = tuple(np.arange(temperature, 1.0 + 1e-6, increment))
    else:
        temperature = [temperature]

    # 4 threads are passed to load_model unless --threads gives a positive count,
    # in which case the same count is also applied to torch.
    faster_whisper_threads = 4
    if (threads := args.pop("threads")) > 0:
        torch.set_num_threads(threads)
        faster_whisper_threads = threads

    asr_options = {
        "beam_size": args.pop("beam_size"),
        "patience": args.pop("patience"),
        "length_penalty": args.pop("length_penalty"),
        "temperatures": temperature,
        "compression_ratio_threshold": args.pop("compression_ratio_threshold"),
        "log_prob_threshold": args.pop("logprob_threshold"),
        "no_speech_threshold": args.pop("no_speech_threshold"),
        "condition_on_previous_text": False,
        "initial_prompt": args.pop("initial_prompt"),
        "suppress_tokens": [int(x) for x in args.pop("suppress_tokens").split(",")],
        "suppress_numerals": args.pop("suppress_numerals"),
    }

    writer = get_writer(output_format, output_dir)
    word_options = ["highlight_words", "max_line_count", "max_line_width"]
    if no_align:
        # These options are rejected whenever alignment is disabled.
        for option in word_options:
            if args[option]:
                parser.error(f"--{option} not possible with --no_align")
    if args["max_line_count"] and not args["max_line_width"]:
        warnings.warn("--max_line_count has no effect without --max_line_width")
    writer_args = {arg: args.pop(arg) for arg in word_options}

    # Part 1: VAD & ASR Loop
    results = []
    model = load_model(
        model_name,
        device=device,
        device_index=device_index,
        download_root=model_dir,
        compute_type=compute_type,
        language=args["language"],
        asr_options=asr_options,
        vad_method=vad_method,
        vad_options={
            "chunk_size": chunk_size,
            "vad_onset": vad_onset,
            "vad_offset": vad_offset,
        },
        task=task,
        local_files_only=model_cache_only,
        threads=faster_whisper_threads,
    )

    for audio_path in args.pop("audio"):
        audio = load_audio(audio_path)
        # >> VAD & ASR
        print(">>Performing transcription...")
        result: TranscriptionResult = model.transcribe(
            audio,
            batch_size=batch_size,
            chunk_size=chunk_size,
            print_progress=print_progress,
            verbose=verbose,
        )
        results.append((result, audio_path))

    # Unload Whisper and VAD
    del model
    gc.collect()
    torch.cuda.empty_cache()

    # Part 2: Align Loop
    if not no_align:
        tmp_results = results
        results = []
        align_model, align_metadata = load_align_model(
            align_language, device, model_name=align_model_name
        )
        for result, audio_path in tmp_results:
            # >> Align
            if len(tmp_results) > 1:
                input_audio = audio_path
            else:
                # Single input: reuse the waveform already loaded in part 1.
                input_audio = audio

            # Remember the transcription's language so it can be carried onto
            # the aligned result, which replaces `result` below.
            source_language = result.get("language")

            if align_model is not None and len(result["segments"]) > 0:
                # One lookup serves the comparison, the message and the reload,
                # so a result without a "language" key cannot raise KeyError.
                result_language = result.get("language", "en")
                if result_language != align_metadata["language"]:
                    # load new language
                    print(
                        f"New language found ({result_language})! Previous was ({align_metadata['language']}), loading new alignment model for new language..."
                    )
                    align_model, align_metadata = load_align_model(
                        result_language, device
                    )
                print(">>Performing alignment...")
                result: AlignedTranscriptionResult = align(
                    result["segments"],
                    align_model,
                    align_metadata,
                    input_audio,
                    device,
                    interpolate_method=interpolate_method,
                    return_char_alignments=return_char_alignments,
                    print_progress=print_progress,
                )
                if source_language is not None:
                    result["language"] = source_language

            results.append((result, audio_path))

        # Unload align model
        del align_model
        gc.collect()
        torch.cuda.empty_cache()

    # >> Diarize
    if diarize:
        if hf_token is None:
            print(
                "Warning, no --hf_token used, needs to be saved in environment variable, otherwise will throw error loading diarization model..."
            )
        tmp_results = results
        print(">>Performing diarization...")
        print(">>Using model:", diarize_model_name)
        results = []
        diarize_model = DiarizationPipeline(model_name=diarize_model_name, use_auth_token=hf_token, device=device)
        for result, input_audio_path in tmp_results:
            diarize_result = diarize_model(
                input_audio_path,
                min_speakers=min_speakers,
                max_speakers=max_speakers,
                return_embeddings=return_speaker_embeddings
            )

            # The pipeline's return value is unpacked as a
            # (segments, embeddings) pair only when embeddings were requested.
            if return_speaker_embeddings:
                diarize_segments, speaker_embeddings = diarize_result
            else:
                diarize_segments = diarize_result
                speaker_embeddings = None

            result = assign_word_speakers(diarize_segments, result, speaker_embeddings)
            results.append((result, input_audio_path))

    # >> Write
    for result, audio_path in results:
        # Keep the language carried through from transcription; align_language
        # is only a fallback, since it is the "en" placeholder whenever no
        # language was requested on the command line.
        if result.get("language") is None:
            result["language"] = align_language
        writer(result, audio_path, writer_args)