# Mirror of https://github.com/m-bain/whisperX.git
# Synced 2025-07-01 18:17:27 -04:00
# 221 lines · 8.0 KiB · Python
import argparse
|
|
import gc
|
|
import os
|
|
import warnings
|
|
|
|
import numpy as np
|
|
import torch
|
|
|
|
from whisperx.alignment import align, load_align_model
|
|
from whisperx.asr import load_model
|
|
from whisperx.audio import load_audio
|
|
from whisperx.diarize import DiarizationPipeline, assign_word_speakers
|
|
from whisperx.types import AlignedTranscriptionResult, TranscriptionResult
|
|
from whisperx.utils import LANGUAGES, TO_LANGUAGE_CODE, get_writer
|
|
|
|
|
|
def transcribe_task(args: dict, parser: argparse.ArgumentParser):
    """Transcription task to be called from CLI.

    Runs the pipeline over every file in ``args["audio"]``:
    VAD + ASR, optional forced alignment, optional speaker diarization, and
    finally writes each result with the writer chosen by ``--output_format``.
    The Whisper/VAD model and the alignment model are deleted (followed by
    ``gc.collect()`` and ``torch.cuda.empty_cache()``) before the next stage
    is loaded.

    Args:
        args: Dictionary of command-line arguments. It is mutated: options
            are popped as they are read, and ``args["language"]`` is
            normalized in place.
        parser: argparse.ArgumentParser object, used to reject invalid
            option combinations via ``parser.error``.

    Raises:
        ValueError: If ``--language`` is neither a key of ``LANGUAGES`` nor
            a key of ``TO_LANGUAGE_CODE``.
    """
    model_name: str = args.pop("model")
    batch_size: int = args.pop("batch_size")
    model_dir: str = args.pop("model_dir")
    model_cache_only: bool = args.pop("model_cache_only")
    output_dir: str = args.pop("output_dir")
    output_format: str = args.pop("output_format")
    device: str = args.pop("device")
    device_index: int = args.pop("device_index")
    compute_type: str = args.pop("compute_type")
    verbose: bool = args.pop("verbose")

    os.makedirs(output_dir, exist_ok=True)

    # Name of the alignment model requested on the CLI (may be None); kept
    # separate from the loaded model object below.
    align_model_name: str = args.pop("align_model")
    interpolate_method: str = args.pop("interpolate_method")
    no_align: bool = args.pop("no_align")
    task: str = args.pop("task")
    if task == "translate":
        # translation cannot be aligned
        no_align = True

    return_char_alignments: bool = args.pop("return_char_alignments")

    hf_token: str = args.pop("hf_token")
    vad_method: str = args.pop("vad_method")
    vad_onset: float = args.pop("vad_onset")
    vad_offset: float = args.pop("vad_offset")

    chunk_size: int = args.pop("chunk_size")

    diarize: bool = args.pop("diarize")
    min_speakers: int = args.pop("min_speakers")
    max_speakers: int = args.pop("max_speakers")
    diarize_model_name: str = args.pop("diarize_model")
    print_progress: bool = args.pop("print_progress")

    # Lower-case --language; if it is not already a key of LANGUAGES, map it
    # through TO_LANGUAGE_CODE (presumably full language names -> codes).
    if args["language"] is not None:
        args["language"] = args["language"].lower()
        if args["language"] not in LANGUAGES:
            if args["language"] in TO_LANGUAGE_CODE:
                args["language"] = TO_LANGUAGE_CODE[args["language"]]
            else:
                raise ValueError(f"Unsupported language: {args['language']}")

    # English-only checkpoints are forced to English.
    if model_name.endswith(".en") and args["language"] != "en":
        if args["language"] is not None:
            warnings.warn(
                f"{model_name} is an English-only model but received '{args['language']}'; using English instead."
            )
        args["language"] = "en"

    # Language of the first alignment model to load; default to English if
    # not specified. Also the fallback label for results lacking a language.
    align_language = args["language"] if args["language"] is not None else "en"

    temperature = args.pop("temperature")
    if (increment := args.pop("temperature_increment_on_fallback")) is not None:
        # Fallback schedule start, start+inc, ... up to 1.0 inclusive; the
        # 1e-6 keeps 1.0 itself despite float rounding in np.arange.
        temperature = tuple(np.arange(temperature, 1.0 + 1e-6, increment))
    else:
        temperature = [temperature]

    faster_whisper_threads = 4
    if (threads := args.pop("threads")) > 0:
        torch.set_num_threads(threads)
        faster_whisper_threads = threads

    asr_options = {
        "beam_size": args.pop("beam_size"),
        "patience": args.pop("patience"),
        "length_penalty": args.pop("length_penalty"),
        "temperatures": temperature,
        "compression_ratio_threshold": args.pop("compression_ratio_threshold"),
        "log_prob_threshold": args.pop("logprob_threshold"),
        "no_speech_threshold": args.pop("no_speech_threshold"),
        "condition_on_previous_text": False,
        "initial_prompt": args.pop("initial_prompt"),
        "suppress_tokens": [int(x) for x in args.pop("suppress_tokens").split(",")],
        "suppress_numerals": args.pop("suppress_numerals"),
    }

    writer = get_writer(output_format, output_dir)
    word_options = ["highlight_words", "max_line_count", "max_line_width"]
    if no_align:
        # Word-level subtitle options are rejected whenever alignment is off,
        # including the implicit disable for --task translate.
        for option in word_options:
            if args[option]:
                parser.error(f"--{option} not possible with --no_align")
    if args["max_line_count"] and not args["max_line_width"]:
        warnings.warn("--max_line_count has no effect without --max_line_width")
    writer_args = {arg: args.pop(arg) for arg in word_options}

    # Part 1: VAD & ASR Loop
    results = []
    model = load_model(
        model_name,
        device=device,
        device_index=device_index,
        download_root=model_dir,
        compute_type=compute_type,
        language=args["language"],
        asr_options=asr_options,
        vad_method=vad_method,
        vad_options={
            "chunk_size": chunk_size,
            "vad_onset": vad_onset,
            "vad_offset": vad_offset,
        },
        task=task,
        local_files_only=model_cache_only,
        threads=faster_whisper_threads,
    )

    for audio_path in args.pop("audio"):
        audio = load_audio(audio_path)
        # >> VAD & ASR
        print(">>Performing transcription...")
        result: TranscriptionResult = model.transcribe(
            audio,
            batch_size=batch_size,
            chunk_size=chunk_size,
            print_progress=print_progress,
            verbose=verbose,
        )
        results.append((result, audio_path))

    # Unload Whisper and VAD
    del model
    gc.collect()
    torch.cuda.empty_cache()

    # Part 2: Align Loop
    if not no_align:
        transcribed = results
        results = []
        align_model, align_metadata = load_align_model(
            align_language, device, model_name=align_model_name
        )
        for result, audio_path in transcribed:
            # >> Align
            # With a single file, reuse the waveform still in memory from
            # Part 1; otherwise pass the path and let align() load it.
            input_audio = audio_path if len(transcribed) > 1 else audio

            # align() returns only segment data, so capture this result's
            # language first and restore it on the aligned result.
            result_language = result.get("language") or align_language

            if align_model is not None and len(result["segments"]) > 0:
                if result_language != align_metadata["language"]:
                    # load new language
                    print(
                        f"New language found ({result_language})! Previous was ({align_metadata['language']}), loading new alignment model for new language..."
                    )
                    align_model, align_metadata = load_align_model(
                        result_language, device
                    )
                print(">>Performing alignment...")
                aligned: AlignedTranscriptionResult = align(
                    result["segments"],
                    align_model,
                    align_metadata,
                    input_audio,
                    device,
                    interpolate_method=interpolate_method,
                    return_char_alignments=return_char_alignments,
                    print_progress=print_progress,
                )
                aligned["language"] = result_language
                result = aligned

            results.append((result, audio_path))

        # Unload align model
        del align_model
        gc.collect()
        torch.cuda.empty_cache()

    # >> Diarize
    if diarize:
        if hf_token is None:
            print(
                "Warning, no --hf_token used, needs to be saved in environment variable, otherwise will throw error loading diarization model..."
            )
        undiarized = results
        print(">>Performing diarization...")
        print(">>Using model:", diarize_model_name)
        results = []
        diarize_model = DiarizationPipeline(model_name=diarize_model_name, use_auth_token=hf_token, device=device)
        for result, input_audio_path in undiarized:
            diarize_segments = diarize_model(
                input_audio_path, min_speakers=min_speakers, max_speakers=max_speakers
            )
            result = assign_word_speakers(diarize_segments, result)
            results.append((result, input_audio_path))

    # >> Write
    for result, audio_path in results:
        # Keep each result's own (possibly auto-detected) language rather
        # than stamping every file with align_language, which is "en"
        # whenever --language was omitted. For translation the previous
        # labelling (align_language) is kept unchanged.
        if task == "translate" or not result.get("language"):
            result["language"] = align_language
        writer(result, audio_path, writer_args)
|