From f286e7f3de9b08838e7f83e971d841717c44aaa7 Mon Sep 17 00:00:00 2001 From: Barabazs <31799121+Barabazs@users.noreply.github.com> Date: Mon, 13 Jan 2025 08:28:27 +0100 Subject: [PATCH] refactor: improve type hints and clean up imports --- whisperx/alignment.py | 4 ++-- whisperx/asr.py | 9 ++++----- whisperx/diarize.py | 2 +- whisperx/utils.py | 2 +- whisperx/vads/vad.py | 4 ++-- 5 files changed, 10 insertions(+), 11 deletions(-) diff --git a/whisperx/alignment.py b/whisperx/alignment.py index d6241bb..ae91828 100644 --- a/whisperx/alignment.py +++ b/whisperx/alignment.py @@ -1,4 +1,4 @@ -"""" +""" Forced Alignment with Whisper C. Max Bain """ @@ -14,7 +14,6 @@ from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor from .audio import SAMPLE_RATE, load_audio from .utils import interpolate_nans from .types import AlignedTranscriptionResult, SingleSegment, SingleAlignedSegment, SingleWordSegment -import nltk from nltk.tokenize.punkt import PunktSentenceTokenizer, PunktParameters PUNKT_ABBREVIATIONS = ['dr', 'vs', 'mr', 'mrs', 'prof'] @@ -194,6 +193,7 @@ def align( "end": t2, "text": text, "words": [], + "chars": None, } if return_char_alignments: diff --git a/whisperx/asr.py b/whisperx/asr.py index 52e7972..c71c90a 100644 --- a/whisperx/asr.py +++ b/whisperx/asr.py @@ -1,6 +1,5 @@ import os -import warnings -from typing import List, NamedTuple, Optional, Union +from typing import List, Optional, Union from dataclasses import replace import ctranslate2 @@ -304,8 +303,8 @@ def load_model( compute_type="float16", asr_options: Optional[dict] = None, language: Optional[str] = None, - vad_model = None, - vad_method = None, + vad_model: Optional[Vad]= None, + vad_method: Optional[str] = None, vad_options: Optional[dict] = None, model: Optional[WhisperModel] = None, task="transcribe", @@ -318,7 +317,7 @@ def load_model( whisper_arch - The name of the Whisper model to load. device - The device to load the model on. compute_type - The compute type to use for the model. - vad_method: str - The vad method to use. vad_model has higher priority if is not None. + vad_method - The vad method to use. vad_model has higher priority if is not None. options - A dictionary of options to use for the model. language - The language of the model. (use English for now) model - The WhisperModel instance to use. diff --git a/whisperx/diarize.py b/whisperx/diarize.py index 2a636e6..5b685b3 100644 --- a/whisperx/diarize.py +++ b/whisperx/diarize.py @@ -79,7 +79,7 @@ def assign_word_speakers( class Segment: - def __init__(self, start, end, speaker=None): + def __init__(self, start:int, end:int, speaker:Optional[str]=None): self.start = start self.end = end self.speaker = speaker diff --git a/whisperx/utils.py b/whisperx/utils.py index 0b440b7..dfe3cf2 100644 --- a/whisperx/utils.py +++ b/whisperx/utils.py @@ -241,7 +241,7 @@ class SubtitlesWriter(ResultWriter): line_count = 1 # the next subtitle to yield (a list of word timings with whitespace) subtitle: list[dict] = [] - times = [] + times: list[tuple] = [] last = result["segments"][0]["start"] for segment in result["segments"]: for i, original_timing in enumerate(segment["words"]): diff --git a/whisperx/vads/vad.py b/whisperx/vads/vad.py index d96184c..d3ffbb1 100644 --- a/whisperx/vads/vad.py +++ b/whisperx/vads/vad.py @@ -26,8 +26,8 @@ class Vad: """ curr_end = 0 merged_segments = [] - seg_idxs = [] - speaker_idxs = [] + seg_idxs: list[tuple]= [] + speaker_idxs: list[Optional[str]] = [] curr_start = segments[0].start for seg in segments: