refactor: improve type hints and clean up imports

This commit is contained in:
Barabazs
2025-01-13 08:28:27 +01:00
parent 73e644559d
commit f286e7f3de
5 changed files with 10 additions and 11 deletions

View File

@@ -1,4 +1,4 @@
"""" """
Forced Alignment with Whisper Forced Alignment with Whisper
C. Max Bain C. Max Bain
""" """
@@ -14,7 +14,6 @@ from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
from .audio import SAMPLE_RATE, load_audio from .audio import SAMPLE_RATE, load_audio
from .utils import interpolate_nans from .utils import interpolate_nans
from .types import AlignedTranscriptionResult, SingleSegment, SingleAlignedSegment, SingleWordSegment from .types import AlignedTranscriptionResult, SingleSegment, SingleAlignedSegment, SingleWordSegment
import nltk
from nltk.tokenize.punkt import PunktSentenceTokenizer, PunktParameters from nltk.tokenize.punkt import PunktSentenceTokenizer, PunktParameters
PUNKT_ABBREVIATIONS = ['dr', 'vs', 'mr', 'mrs', 'prof'] PUNKT_ABBREVIATIONS = ['dr', 'vs', 'mr', 'mrs', 'prof']
@@ -194,6 +193,7 @@ def align(
"end": t2, "end": t2,
"text": text, "text": text,
"words": [], "words": [],
"chars": None,
} }
if return_char_alignments: if return_char_alignments:

View File

@@ -1,6 +1,5 @@
import os import os
import warnings from typing import List, Optional, Union
from typing import List, NamedTuple, Optional, Union
from dataclasses import replace from dataclasses import replace
import ctranslate2 import ctranslate2
@@ -304,8 +303,8 @@ def load_model(
compute_type="float16", compute_type="float16",
asr_options: Optional[dict] = None, asr_options: Optional[dict] = None,
language: Optional[str] = None, language: Optional[str] = None,
vad_model = None, vad_model: Optional[Vad]= None,
vad_method = None, vad_method: Optional[str] = None,
vad_options: Optional[dict] = None, vad_options: Optional[dict] = None,
model: Optional[WhisperModel] = None, model: Optional[WhisperModel] = None,
task="transcribe", task="transcribe",
@@ -318,7 +317,7 @@ def load_model(
whisper_arch - The name of the Whisper model to load. whisper_arch - The name of the Whisper model to load.
device - The device to load the model on. device - The device to load the model on.
compute_type - The compute type to use for the model. compute_type - The compute type to use for the model.
vad_method: str - The vad method to use. vad_model has higher priority if is not None. vad_method - The vad method to use. vad_model has higher priority if is not None.
options - A dictionary of options to use for the model. options - A dictionary of options to use for the model.
language - The language of the model. (use English for now) language - The language of the model. (use English for now)
model - The WhisperModel instance to use. model - The WhisperModel instance to use.

View File

@@ -79,7 +79,7 @@ def assign_word_speakers(
class Segment: class Segment:
def __init__(self, start, end, speaker=None): def __init__(self, start:int, end:int, speaker:Optional[str]=None):
self.start = start self.start = start
self.end = end self.end = end
self.speaker = speaker self.speaker = speaker

View File

@@ -241,7 +241,7 @@ class SubtitlesWriter(ResultWriter):
line_count = 1 line_count = 1
# the next subtitle to yield (a list of word timings with whitespace) # the next subtitle to yield (a list of word timings with whitespace)
subtitle: list[dict] = [] subtitle: list[dict] = []
times = [] times: list[tuple] = []
last = result["segments"][0]["start"] last = result["segments"][0]["start"]
for segment in result["segments"]: for segment in result["segments"]:
for i, original_timing in enumerate(segment["words"]): for i, original_timing in enumerate(segment["words"]):

View File

@@ -26,8 +26,8 @@ class Vad:
""" """
curr_end = 0 curr_end = 0
merged_segments = [] merged_segments = []
seg_idxs = [] seg_idxs: list[tuple]= []
speaker_idxs = [] speaker_idxs: list[Optional[str]] = []
curr_start = segments[0].start curr_start = segments[0].start
for seg in segments: for seg in segments: