feat: add version and Python version arguments to CLI

This commit is contained in:
Barabazs
2025-05-01 10:43:02 +02:00
parent cd59f21d1a
commit ac0c8bd79a

View File

@ -1,7 +1,10 @@
import argparse
import gc
import os
import sys
import warnings
import importlib.metadata
import platform
import numpy as np
import torch
@ -85,6 +88,8 @@ def cli():
parser.add_argument("--hf_token", type=str, default=None, help="Hugging Face Access Token to access PyAnnote gated models")
parser.add_argument("--print_progress", type=str2bool, default = False, help = "if True, progress will be printed in transcribe() and align() methods.")
parser.add_argument("--version", "-V", action="version", version=f"%(prog)s {importlib.metadata.version('whisperx')}",help="Show whisperx version information and exit")
parser.add_argument("--python-version", "-P", action="version", version=f"Python {platform.python_version()} ({platform.python_implementation()})",help="Show python version information and exit")
# fmt: on
args = parser.parse_args().__dict__
@ -138,7 +143,9 @@ def cli():
f"{model_name} is an English-only model but received '{args['language']}'; using English instead."
)
args["language"] = "en"
align_language = args["language"] if args["language"] is not None else "en" # default to loading english if not specified
align_language = (
args["language"] if args["language"] is not None else "en"
) # default to loading english if not specified
temperature = args.pop("temperature")
if (increment := args.pop("temperature_increment_on_fallback")) is not None:
@ -174,12 +181,29 @@ def cli():
if args["max_line_count"] and not args["max_line_width"]:
warnings.warn("--max_line_count has no effect without --max_line_width")
writer_args = {arg: args.pop(arg) for arg in word_options}
# Part 1: VAD & ASR Loop
results = []
tmp_results = []
# model = load_model(model_name, device=device, download_root=model_dir)
model = load_model(model_name, device=device, device_index=device_index, download_root=model_dir, compute_type=compute_type, language=args['language'], asr_options=asr_options, vad_method=vad_method, vad_options={"chunk_size":chunk_size, "vad_onset": vad_onset, "vad_offset": vad_offset}, task=task, local_files_only=model_cache_only, threads=faster_whisper_threads)
model = load_model(
model_name,
device=device,
device_index=device_index,
download_root=model_dir,
compute_type=compute_type,
language=args["language"],
asr_options=asr_options,
vad_method=vad_method,
vad_options={
"chunk_size": chunk_size,
"vad_onset": vad_onset,
"vad_offset": vad_offset,
},
task=task,
local_files_only=model_cache_only,
threads=faster_whisper_threads,
)
for audio_path in args.pop("audio"):
audio = load_audio(audio_path)
@ -203,7 +227,9 @@ def cli():
if not no_align:
tmp_results = results
results = []
align_model, align_metadata = load_align_model(align_language, device, model_name=align_model)
align_model, align_metadata = load_align_model(
align_language, device, model_name=align_model
)
for result, audio_path in tmp_results:
# >> Align
if len(tmp_results) > 1:
@ -215,8 +241,12 @@ def cli():
if align_model is not None and len(result["segments"]) > 0:
if result.get("language", "en") != align_metadata["language"]:
# load new language
print(f"New language found ({result['language']})! Previous was ({align_metadata['language']}), loading new alignment model for new language...")
align_model, align_metadata = load_align_model(result["language"], device)
print(
f"New language found ({result['language']})! Previous was ({align_metadata['language']}), loading new alignment model for new language..."
)
align_model, align_metadata = load_align_model(
result["language"], device
)
print(">>Performing alignment...")
result: AlignedTranscriptionResult = align(
result["segments"],
@ -239,13 +269,17 @@ def cli():
# >> Diarize
if diarize:
if hf_token is None:
print("Warning, no --hf_token used, needs to be saved in environment variable, otherwise will throw error loading diarization model...")
print(
"Warning, no --hf_token used, needs to be saved in environment variable, otherwise will throw error loading diarization model..."
)
tmp_results = results
print(">>Performing diarization...")
results = []
diarize_model = DiarizationPipeline(use_auth_token=hf_token, device=device)
for result, input_audio_path in tmp_results:
diarize_segments = diarize_model(input_audio_path, min_speakers=min_speakers, max_speakers=max_speakers)
diarize_segments = diarize_model(
input_audio_path, min_speakers=min_speakers, max_speakers=max_speakers
)
result = assign_word_speakers(diarize_segments, result)
results.append((result, input_audio_path))
# >> Write
@ -253,5 +287,6 @@ def cli():
result["language"] = align_language
writer(result, audio_path, writer_args)
if __name__ == "__main__":
cli()