import argparse
import gc
import os
import warnings

import numpy as np
import torch

from whisperx.alignment import align, load_align_model
from whisperx.asr import load_model
from whisperx.audio import load_audio
from whisperx.diarize import DiarizationPipeline, assign_word_speakers
from whisperx.types import AlignedTranscriptionResult, TranscriptionResult
from whisperx.utils import LANGUAGES, TO_LANGUAGE_CODE, get_writer


def transcribe_task(args: dict, parser: argparse.ArgumentParser):
    """Transcription task to be called from CLI.

    Args:
        args: Dictionary of command-line arguments.
        parser: argparse.ArgumentParser object.
    """
    # fmt: off

    model_name: str = args.pop("model")
    batch_size: int = args.pop("batch_size")
    model_dir: str = args.pop("model_dir")
    model_cache_only: bool = args.pop("model_cache_only")
    output_dir: str = args.pop("output_dir")
    output_format: str = args.pop("output_format")
    device: str = args.pop("device")
    device_index: int = args.pop("device_index")
    compute_type: str = args.pop("compute_type")
    verbose: bool = args.pop("verbose")

    # model_flush: bool = args.pop("model_flush")
    os.makedirs(output_dir, exist_ok=True)

    align_model: str = args.pop("align_model")
    interpolate_method: str = args.pop("interpolate_method")
    no_align: bool = args.pop("no_align")
    task: str = args.pop("task")
    if task == "translate":
        # translation cannot be aligned
        no_align = True

    return_char_alignments: bool = args.pop("return_char_alignments")

    hf_token: str = args.pop("hf_token")
    vad_method: str = args.pop("vad_method")
    vad_onset: float = args.pop("vad_onset")
    vad_offset: float = args.pop("vad_offset")

    chunk_size: int = args.pop("chunk_size")

    diarize: bool = args.pop("diarize")
    min_speakers: int = args.pop("min_speakers")
    max_speakers: int = args.pop("max_speakers")
    diarize_model_name: str = args.pop("diarize_model")
    print_progress: bool = args.pop("print_progress")
    return_speaker_embeddings: bool = args.pop("speaker_embeddings")

    if return_speaker_embeddings and not diarize:
        warnings.warn("--speaker_embeddings has no effect without --diarize")

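    # Normalize the requested language: lowercase it and map full language names
    # (e.g. "french") to their ISO codes before validating against LANGUAGES.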
    if args["language"] is not None:
        args["language"] = args["language"].lower()
        if args["language"] not in LANGUAGES:
            if args["language"] in TO_LANGUAGE_CODE:
                args["language"] = TO_LANGUAGE_CODE[args["language"]]
            else:
                raise ValueError(f"Unsupported language: {args['language']}")

    if model_name.endswith(".en") and args["language"] != "en":
        if args["language"] is not None:
            warnings.warn(
                f"{model_name} is an English-only model but received '{args['language']}'; using English instead."
            )
        args["language"] = "en"
    align_language = (
        args["language"] if args["language"] is not None else "en"
    )  # default to loading english if not specified

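    # Build the temperature fallback schedule: either a single value or a tuple of
    # increasing temperatures tried when decoding at a lower temperature fails.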
    temperature = args.pop("temperature")
    if (increment := args.pop("temperature_increment_on_fallback")) is not None:
        temperature = tuple(np.arange(temperature, 1.0 + 1e-6, increment))
    else:
        temperature = [temperature]

    faster_whisper_threads = 4
    if (threads := args.pop("threads")) > 0:
        torch.set_num_threads(threads)
        faster_whisper_threads = threads

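    # Decoding options forwarded to the ASR backend; condition_on_previous_text is
    # disabled because each VAD chunk is transcribed independently.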
    asr_options = {
        "beam_size": args.pop("beam_size"),
        "patience": args.pop("patience"),
        "length_penalty": args.pop("length_penalty"),
        "temperatures": temperature,
        "compression_ratio_threshold": args.pop("compression_ratio_threshold"),
        "log_prob_threshold": args.pop("logprob_threshold"),
        "no_speech_threshold": args.pop("no_speech_threshold"),
        "condition_on_previous_text": False,
        "initial_prompt": args.pop("initial_prompt"),
        "suppress_tokens": [int(x) for x in args.pop("suppress_tokens").split(",")],
        "suppress_numerals": args.pop("suppress_numerals"),
    }

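    # Set up the output writer and validate subtitle word options, which require
    # word-level alignment.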
    writer = get_writer(output_format, output_dir)
    word_options = ["highlight_words", "max_line_count", "max_line_width"]
    if no_align:
        for option in word_options:
            if args[option]:
                parser.error(f"--{option} not possible with --no_align")
    if args["max_line_count"] and not args["max_line_width"]:
        warnings.warn("--max_line_count has no effect without --max_line_width")
    writer_args = {arg: args.pop(arg) for arg in word_options}

    # Part 1: VAD & ASR Loop
    results = []
    tmp_results = []
    # model = load_model(model_name, device=device, download_root=model_dir)
    model = load_model(
        model_name,
        device=device,
        device_index=device_index,
        download_root=model_dir,
        compute_type=compute_type,
        language=args["language"],
        asr_options=asr_options,
        vad_method=vad_method,
        vad_options={
            "chunk_size": chunk_size,
            "vad_onset": vad_onset,
            "vad_offset": vad_offset,
        },
        task=task,
        local_files_only=model_cache_only,
        threads=faster_whisper_threads,
    )

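    # Transcribe each input file with the batched VAD + ASR pipeline, keeping the
    # raw results for the later alignment and diarization passes.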
    for audio_path in args.pop("audio"):
        audio = load_audio(audio_path)
        # >> VAD & ASR
        print(">>Performing transcription...")
        result: TranscriptionResult = model.transcribe(
            audio,
            batch_size=batch_size,
            chunk_size=chunk_size,
            print_progress=print_progress,
            verbose=verbose,
        )
        results.append((result, audio_path))

    # Unload Whisper and VAD
    del model
    gc.collect()
    torch.cuda.empty_cache()

    # Part 2: Align Loop
    if not no_align:
        tmp_results = results
        results = []
        align_model, align_metadata = load_align_model(
            align_language, device, model_name=align_model
        )
        for result, audio_path in tmp_results:
            # >> Align
            if len(tmp_results) > 1:
                input_audio = audio_path
            else:
                # lazily load audio from part 1
                input_audio = audio

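            # Only align when an alignment model is available and the transcript
            # actually contains segments; reload the model if the detected language
            # differs from the one currently loaded.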
2023-04-09 15:58:55 +08:00
|
|
|
if align_model is not None and len(result["segments"]) > 0:
|
|
|
|
if result.get("language", "en") != align_metadata["language"]:
|
|
|
|
# load new language
|
2025-05-01 10:43:02 +02:00
|
|
|
print(
|
|
|
|
f"New language found ({result['language']})! Previous was ({align_metadata['language']}), loading new alignment model for new language..."
|
|
|
|
)
|
|
|
|
align_model, align_metadata = load_align_model(
|
|
|
|
result["language"], device
|
|
|
|
)
|
2023-04-09 15:58:55 +08:00
|
|
|
print(">>Performing alignment...")
|
2025-01-05 11:26:18 +01:00
|
|
|
result: AlignedTranscriptionResult = align(
|
|
|
|
result["segments"],
|
|
|
|
align_model,
|
|
|
|
align_metadata,
|
|
|
|
input_audio,
|
|
|
|
device,
|
|
|
|
interpolate_method=interpolate_method,
|
|
|
|
return_char_alignments=return_char_alignments,
|
|
|
|
print_progress=print_progress,
|
|
|
|
)

            results.append((result, audio_path))

        # Unload align model
        del align_model
        gc.collect()
        torch.cuda.empty_cache()

    # >> Diarize
    if diarize:
        if hf_token is None:
            print(
                "Warning: no --hf_token provided; loading the diarization model will fail unless a token is available in the environment."
            )
        tmp_results = results
        print(">>Performing diarization...")
        print(">>Using model:", diarize_model_name)
        results = []
        diarize_model = DiarizationPipeline(model_name=diarize_model_name, use_auth_token=hf_token, device=device)
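        # Diarize each input file and merge speaker labels (and, optionally,
        # speaker embeddings) into the corresponding transcription result.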
        for result, input_audio_path in tmp_results:
            diarize_result = diarize_model(
                input_audio_path,
                min_speakers=min_speakers,
                max_speakers=max_speakers,
                return_embeddings=return_speaker_embeddings,
            )

            if return_speaker_embeddings:
                diarize_segments, speaker_embeddings = diarize_result
            else:
                diarize_segments = diarize_result
                speaker_embeddings = None

            result = assign_word_speakers(diarize_segments, result, speaker_embeddings)
            results.append((result, input_audio_path))

    # >> Write
    for result, audio_path in results:
        result["language"] = align_language
        writer(result, audio_path, writer_args)