mirror of
https://github.com/m-bain/whisperX.git
synced 2025-07-01 18:17:27 -04:00
refactor: improve type hints and clean up imports
This commit is contained in:
@ -1,4 +1,4 @@
|
|||||||
""""
|
"""
|
||||||
Forced Alignment with Whisper
|
Forced Alignment with Whisper
|
||||||
C. Max Bain
|
C. Max Bain
|
||||||
"""
|
"""
|
||||||
@ -14,7 +14,6 @@ from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
|
|||||||
from .audio import SAMPLE_RATE, load_audio
|
from .audio import SAMPLE_RATE, load_audio
|
||||||
from .utils import interpolate_nans
|
from .utils import interpolate_nans
|
||||||
from .types import AlignedTranscriptionResult, SingleSegment, SingleAlignedSegment, SingleWordSegment
|
from .types import AlignedTranscriptionResult, SingleSegment, SingleAlignedSegment, SingleWordSegment
|
||||||
import nltk
|
|
||||||
from nltk.tokenize.punkt import PunktSentenceTokenizer, PunktParameters
|
from nltk.tokenize.punkt import PunktSentenceTokenizer, PunktParameters
|
||||||
|
|
||||||
PUNKT_ABBREVIATIONS = ['dr', 'vs', 'mr', 'mrs', 'prof']
|
PUNKT_ABBREVIATIONS = ['dr', 'vs', 'mr', 'mrs', 'prof']
|
||||||
@ -194,6 +193,7 @@ def align(
|
|||||||
"end": t2,
|
"end": t2,
|
||||||
"text": text,
|
"text": text,
|
||||||
"words": [],
|
"words": [],
|
||||||
|
"chars": None,
|
||||||
}
|
}
|
||||||
|
|
||||||
if return_char_alignments:
|
if return_char_alignments:
|
||||||
|
@ -1,6 +1,5 @@
|
|||||||
import os
|
import os
|
||||||
import warnings
|
from typing import List, Optional, Union
|
||||||
from typing import List, NamedTuple, Optional, Union
|
|
||||||
from dataclasses import replace
|
from dataclasses import replace
|
||||||
|
|
||||||
import ctranslate2
|
import ctranslate2
|
||||||
@ -304,8 +303,8 @@ def load_model(
|
|||||||
compute_type="float16",
|
compute_type="float16",
|
||||||
asr_options: Optional[dict] = None,
|
asr_options: Optional[dict] = None,
|
||||||
language: Optional[str] = None,
|
language: Optional[str] = None,
|
||||||
vad_model = None,
|
vad_model: Optional[Vad]= None,
|
||||||
vad_method = None,
|
vad_method: Optional[str] = None,
|
||||||
vad_options: Optional[dict] = None,
|
vad_options: Optional[dict] = None,
|
||||||
model: Optional[WhisperModel] = None,
|
model: Optional[WhisperModel] = None,
|
||||||
task="transcribe",
|
task="transcribe",
|
||||||
@ -318,7 +317,7 @@ def load_model(
|
|||||||
whisper_arch - The name of the Whisper model to load.
|
whisper_arch - The name of the Whisper model to load.
|
||||||
device - The device to load the model on.
|
device - The device to load the model on.
|
||||||
compute_type - The compute type to use for the model.
|
compute_type - The compute type to use for the model.
|
||||||
vad_method: str - The vad method to use. vad_model has higher priority if is not None.
|
vad_method - The vad method to use. vad_model has higher priority if is not None.
|
||||||
options - A dictionary of options to use for the model.
|
options - A dictionary of options to use for the model.
|
||||||
language - The language of the model. (use English for now)
|
language - The language of the model. (use English for now)
|
||||||
model - The WhisperModel instance to use.
|
model - The WhisperModel instance to use.
|
||||||
|
@ -79,7 +79,7 @@ def assign_word_speakers(
|
|||||||
|
|
||||||
|
|
||||||
class Segment:
|
class Segment:
|
||||||
def __init__(self, start, end, speaker=None):
|
def __init__(self, start:int, end:int, speaker:Optional[str]=None):
|
||||||
self.start = start
|
self.start = start
|
||||||
self.end = end
|
self.end = end
|
||||||
self.speaker = speaker
|
self.speaker = speaker
|
||||||
|
@ -241,7 +241,7 @@ class SubtitlesWriter(ResultWriter):
|
|||||||
line_count = 1
|
line_count = 1
|
||||||
# the next subtitle to yield (a list of word timings with whitespace)
|
# the next subtitle to yield (a list of word timings with whitespace)
|
||||||
subtitle: list[dict] = []
|
subtitle: list[dict] = []
|
||||||
times = []
|
times: list[tuple] = []
|
||||||
last = result["segments"][0]["start"]
|
last = result["segments"][0]["start"]
|
||||||
for segment in result["segments"]:
|
for segment in result["segments"]:
|
||||||
for i, original_timing in enumerate(segment["words"]):
|
for i, original_timing in enumerate(segment["words"]):
|
||||||
|
@ -26,8 +26,8 @@ class Vad:
|
|||||||
"""
|
"""
|
||||||
curr_end = 0
|
curr_end = 0
|
||||||
merged_segments = []
|
merged_segments = []
|
||||||
seg_idxs = []
|
seg_idxs: list[tuple]= []
|
||||||
speaker_idxs = []
|
speaker_idxs: list[Optional[str]] = []
|
||||||
|
|
||||||
curr_start = segments[0].start
|
curr_start = segments[0].start
|
||||||
for seg in segments:
|
for seg in segments:
|
||||||
|
Reference in New Issue
Block a user