feat: add version and Python version arguments to CLI

This commit is contained in:
Barabazs
2025-05-01 10:43:02 +02:00
parent cd59f21d1a
commit ac0c8bd79a

View File

@ -1,7 +1,10 @@
import argparse import argparse
import gc import gc
import os import os
import sys
import warnings import warnings
import importlib.metadata
import platform
import numpy as np import numpy as np
import torch import torch
@ -85,6 +88,8 @@ def cli():
parser.add_argument("--hf_token", type=str, default=None, help="Hugging Face Access Token to access PyAnnote gated models") parser.add_argument("--hf_token", type=str, default=None, help="Hugging Face Access Token to access PyAnnote gated models")
parser.add_argument("--print_progress", type=str2bool, default = False, help = "if True, progress will be printed in transcribe() and align() methods.") parser.add_argument("--print_progress", type=str2bool, default = False, help = "if True, progress will be printed in transcribe() and align() methods.")
parser.add_argument("--version", "-V", action="version", version=f"%(prog)s {importlib.metadata.version('whisperx')}",help="Show whisperx version information and exit")
parser.add_argument("--python-version", "-P", action="version", version=f"Python {platform.python_version()} ({platform.python_implementation()})",help="Show python version information and exit")
# fmt: on # fmt: on
args = parser.parse_args().__dict__ args = parser.parse_args().__dict__
@ -138,7 +143,9 @@ def cli():
f"{model_name} is an English-only model but received '{args['language']}'; using English instead." f"{model_name} is an English-only model but received '{args['language']}'; using English instead."
) )
args["language"] = "en" args["language"] = "en"
align_language = args["language"] if args["language"] is not None else "en" # default to loading english if not specified align_language = (
args["language"] if args["language"] is not None else "en"
) # default to loading english if not specified
temperature = args.pop("temperature") temperature = args.pop("temperature")
if (increment := args.pop("temperature_increment_on_fallback")) is not None: if (increment := args.pop("temperature_increment_on_fallback")) is not None:
@ -179,7 +186,24 @@ def cli():
results = [] results = []
tmp_results = [] tmp_results = []
# model = load_model(model_name, device=device, download_root=model_dir) # model = load_model(model_name, device=device, download_root=model_dir)
model = load_model(model_name, device=device, device_index=device_index, download_root=model_dir, compute_type=compute_type, language=args['language'], asr_options=asr_options, vad_method=vad_method, vad_options={"chunk_size":chunk_size, "vad_onset": vad_onset, "vad_offset": vad_offset}, task=task, local_files_only=model_cache_only, threads=faster_whisper_threads) model = load_model(
model_name,
device=device,
device_index=device_index,
download_root=model_dir,
compute_type=compute_type,
language=args["language"],
asr_options=asr_options,
vad_method=vad_method,
vad_options={
"chunk_size": chunk_size,
"vad_onset": vad_onset,
"vad_offset": vad_offset,
},
task=task,
local_files_only=model_cache_only,
threads=faster_whisper_threads,
)
for audio_path in args.pop("audio"): for audio_path in args.pop("audio"):
audio = load_audio(audio_path) audio = load_audio(audio_path)
@ -203,7 +227,9 @@ def cli():
if not no_align: if not no_align:
tmp_results = results tmp_results = results
results = [] results = []
align_model, align_metadata = load_align_model(align_language, device, model_name=align_model) align_model, align_metadata = load_align_model(
align_language, device, model_name=align_model
)
for result, audio_path in tmp_results: for result, audio_path in tmp_results:
# >> Align # >> Align
if len(tmp_results) > 1: if len(tmp_results) > 1:
@ -215,8 +241,12 @@ def cli():
if align_model is not None and len(result["segments"]) > 0: if align_model is not None and len(result["segments"]) > 0:
if result.get("language", "en") != align_metadata["language"]: if result.get("language", "en") != align_metadata["language"]:
# load new language # load new language
print(f"New language found ({result['language']})! Previous was ({align_metadata['language']}), loading new alignment model for new language...") print(
align_model, align_metadata = load_align_model(result["language"], device) f"New language found ({result['language']})! Previous was ({align_metadata['language']}), loading new alignment model for new language..."
)
align_model, align_metadata = load_align_model(
result["language"], device
)
print(">>Performing alignment...") print(">>Performing alignment...")
result: AlignedTranscriptionResult = align( result: AlignedTranscriptionResult = align(
result["segments"], result["segments"],
@ -239,13 +269,17 @@ def cli():
# >> Diarize # >> Diarize
if diarize: if diarize:
if hf_token is None: if hf_token is None:
print("Warning, no --hf_token used, needs to be saved in environment variable, otherwise will throw error loading diarization model...") print(
"Warning, no --hf_token used, needs to be saved in environment variable, otherwise will throw error loading diarization model..."
)
tmp_results = results tmp_results = results
print(">>Performing diarization...") print(">>Performing diarization...")
results = [] results = []
diarize_model = DiarizationPipeline(use_auth_token=hf_token, device=device) diarize_model = DiarizationPipeline(use_auth_token=hf_token, device=device)
for result, input_audio_path in tmp_results: for result, input_audio_path in tmp_results:
diarize_segments = diarize_model(input_audio_path, min_speakers=min_speakers, max_speakers=max_speakers) diarize_segments = diarize_model(
input_audio_path, min_speakers=min_speakers, max_speakers=max_speakers
)
result = assign_word_speakers(diarize_segments, result) result = assign_word_speakers(diarize_segments, result)
results.append((result, input_audio_path)) results.append((result, input_audio_path))
# >> Write # >> Write
@ -253,5 +287,6 @@ def cli():
result["language"] = align_language result["language"] = align_language
writer(result, audio_path, writer_args) writer(result, audio_path, writer_args)
if __name__ == "__main__": if __name__ == "__main__":
cli() cli()