refactor: improve type hints and clean up imports

This commit is contained in:
Barabazs
2025-01-13 08:28:27 +01:00
parent 73e644559d
commit f286e7f3de
5 changed files with 10 additions and 11 deletions

View File

@ -1,4 +1,4 @@
""""
"""
Forced Alignment with Whisper
C. Max Bain
"""
@ -14,7 +14,6 @@ from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
from .audio import SAMPLE_RATE, load_audio
from .utils import interpolate_nans
from .types import AlignedTranscriptionResult, SingleSegment, SingleAlignedSegment, SingleWordSegment
import nltk
from nltk.tokenize.punkt import PunktSentenceTokenizer, PunktParameters
PUNKT_ABBREVIATIONS = ['dr', 'vs', 'mr', 'mrs', 'prof']
@ -194,6 +193,7 @@ def align(
"end": t2,
"text": text,
"words": [],
"chars": None,
}
if return_char_alignments:

View File

@ -1,6 +1,5 @@
import os
import warnings
from typing import List, NamedTuple, Optional, Union
from typing import List, Optional, Union
from dataclasses import replace
import ctranslate2
@ -304,8 +303,8 @@ def load_model(
compute_type="float16",
asr_options: Optional[dict] = None,
language: Optional[str] = None,
vad_model = None,
vad_method = None,
vad_model: Optional[Vad]= None,
vad_method: Optional[str] = None,
vad_options: Optional[dict] = None,
model: Optional[WhisperModel] = None,
task="transcribe",
@ -318,7 +317,7 @@ def load_model(
whisper_arch - The name of the Whisper model to load.
device - The device to load the model on.
compute_type - The compute type to use for the model.
vad_method: str - The vad method to use. vad_model has higher priority if is not None.
vad_method - The vad method to use. vad_model has higher priority if is not None.
options - A dictionary of options to use for the model.
language - The language of the model. (use English for now)
model - The WhisperModel instance to use.

View File

@ -79,7 +79,7 @@ def assign_word_speakers(
class Segment:
def __init__(self, start, end, speaker=None):
def __init__(self, start:int, end:int, speaker:Optional[str]=None):
self.start = start
self.end = end
self.speaker = speaker

View File

@ -241,7 +241,7 @@ class SubtitlesWriter(ResultWriter):
line_count = 1
# the next subtitle to yield (a list of word timings with whitespace)
subtitle: list[dict] = []
times = []
times: list[tuple] = []
last = result["segments"][0]["start"]
for segment in result["segments"]:
for i, original_timing in enumerate(segment["words"]):

View File

@ -26,8 +26,8 @@ class Vad:
"""
curr_end = 0
merged_segments = []
seg_idxs = []
speaker_idxs = []
seg_idxs: list[tuple]= []
speaker_idxs: list[Optional[str]] = []
curr_start = segments[0].start
for seg in segments: