refactor: add type hints

This commit is contained in:
Barabazs
2025-01-05 11:26:18 +01:00
parent 0f7f9f9f83
commit 9a8967f27e
6 changed files with 111 additions and 57 deletions

View File

@ -10,8 +10,15 @@ from .alignment import align, load_align_model
from .asr import load_model
from .audio import load_audio
from .diarize import DiarizationPipeline, assign_word_speakers
from .utils import (LANGUAGES, TO_LANGUAGE_CODE, get_writer, optional_float,
optional_int, str2bool)
from .types import AlignedTranscriptionResult, TranscriptionResult
from .utils import (
LANGUAGES,
TO_LANGUAGE_CODE,
get_writer,
optional_float,
optional_int,
str2bool,
)
def cli():
@ -95,7 +102,7 @@ def cli():
align_model: str = args.pop("align_model")
interpolate_method: str = args.pop("interpolate_method")
no_align: bool = args.pop("no_align")
task : str = args.pop("task")
task: str = args.pop("task")
if task == "translate":
# translation cannot be aligned
no_align = True
@ -174,7 +181,13 @@ def cli():
audio = load_audio(audio_path)
# >> VAD & ASR
print(">>Performing transcription...")
result = model.transcribe(audio, batch_size=batch_size, chunk_size=chunk_size, print_progress=print_progress, verbose=verbose)
result: TranscriptionResult = model.transcribe(
audio,
batch_size=batch_size,
chunk_size=chunk_size,
print_progress=print_progress,
verbose=verbose,
)
results.append((result, audio_path))
# Unload Whisper and VAD
@ -201,7 +214,16 @@ def cli():
print(f"New language found ({result['language']})! Previous was ({align_metadata['language']}), loading new alignment model for new language...")
align_model, align_metadata = load_align_model(result["language"], device)
print(">>Performing alignment...")
result = align(result["segments"], align_model, align_metadata, input_audio, device, interpolate_method=interpolate_method, return_char_alignments=return_char_alignments, print_progress=print_progress)
result: AlignedTranscriptionResult = align(
result["segments"],
align_model,
align_metadata,
input_audio,
device,
interpolate_method=interpolate_method,
return_char_alignments=return_char_alignments,
print_progress=print_progress,
)
results.append((result, audio_path))