mirror of
https://github.com/m-bain/whisperX.git
synced 2025-07-01 18:17:27 -04:00
Merge branch 'main' into main
This commit is contained in:
@ -1,7 +1,8 @@
|
||||
""""
|
||||
"""
|
||||
Forced Alignment with Whisper
|
||||
C. Max Bain
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Iterable, Optional, Union, List
|
||||
|
||||
@ -13,8 +14,13 @@ from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
|
||||
|
||||
from .audio import SAMPLE_RATE, load_audio
|
||||
from .utils import interpolate_nans
|
||||
from .types import AlignedTranscriptionResult, SingleSegment, SingleAlignedSegment, SingleWordSegment
|
||||
import nltk
|
||||
from .types import (
|
||||
AlignedTranscriptionResult,
|
||||
SingleSegment,
|
||||
SingleAlignedSegment,
|
||||
SingleWordSegment,
|
||||
SegmentData,
|
||||
)
|
||||
from nltk.tokenize.punkt import PunktSentenceTokenizer, PunktParameters
|
||||
|
||||
PUNKT_ABBREVIATIONS = ['dr', 'vs', 'mr', 'mrs', 'prof']
|
||||
@ -131,6 +137,8 @@ def align(
|
||||
|
||||
# 1. Preprocess to keep only characters in dictionary
|
||||
total_segments = len(transcript)
|
||||
# Store temporary processing values
|
||||
segment_data: dict[int, SegmentData] = {}
|
||||
for sdx, segment in enumerate(transcript):
|
||||
# strip spaces at beginning / end, but keep track of the amount.
|
||||
if print_progress:
|
||||
@ -175,11 +183,13 @@ def align(
|
||||
sentence_splitter = PunktSentenceTokenizer(punkt_param)
|
||||
sentence_spans = list(sentence_splitter.span_tokenize(text))
|
||||
|
||||
segment["clean_char"] = clean_char
|
||||
segment["clean_cdx"] = clean_cdx
|
||||
segment["clean_wdx"] = clean_wdx
|
||||
segment["sentence_spans"] = sentence_spans
|
||||
|
||||
segment_data[sdx] = {
|
||||
"clean_char": clean_char,
|
||||
"clean_cdx": clean_cdx,
|
||||
"clean_wdx": clean_wdx,
|
||||
"sentence_spans": sentence_spans
|
||||
}
|
||||
|
||||
aligned_segments: List[SingleAlignedSegment] = []
|
||||
|
||||
# 2. Get prediction matrix from alignment model & align
|
||||
@ -194,13 +204,14 @@ def align(
|
||||
"end": t2,
|
||||
"text": text,
|
||||
"words": [],
|
||||
"chars": None,
|
||||
}
|
||||
|
||||
if return_char_alignments:
|
||||
aligned_seg["chars"] = []
|
||||
|
||||
# check we can align
|
||||
if len(segment["clean_char"]) == 0:
|
||||
if len(segment_data[sdx]["clean_char"]) == 0:
|
||||
print(f'Failed to align segment ("{segment["text"]}"): no characters in this segment found in model dictionary, resorting to original...')
|
||||
aligned_segments.append(aligned_seg)
|
||||
continue
|
||||
@ -210,7 +221,7 @@ def align(
|
||||
aligned_segments.append(aligned_seg)
|
||||
continue
|
||||
|
||||
text_clean = "".join(segment["clean_char"])
|
||||
text_clean = "".join(segment_data[sdx]["clean_char"])
|
||||
tokens = [model_dictionary[c] for c in text_clean]
|
||||
|
||||
f1 = int(t1 * SAMPLE_RATE)
|
||||
@ -261,8 +272,8 @@ def align(
|
||||
word_idx = 0
|
||||
for cdx, char in enumerate(text):
|
||||
start, end, score = None, None, None
|
||||
if cdx in segment["clean_cdx"]:
|
||||
char_seg = char_segments[segment["clean_cdx"].index(cdx)]
|
||||
if cdx in segment_data[sdx]["clean_cdx"]:
|
||||
char_seg = char_segments[segment_data[sdx]["clean_cdx"].index(cdx)]
|
||||
start = round(char_seg.start * ratio + t1, 3)
|
||||
end = round(char_seg.end * ratio + t1, 3)
|
||||
score = round(char_seg.score, 3)
|
||||
@ -288,10 +299,10 @@ def align(
|
||||
aligned_subsegments = []
|
||||
# assign sentence_idx to each character index
|
||||
char_segments_arr["sentence-idx"] = None
|
||||
for sdx, (sstart, send) in enumerate(segment["sentence_spans"]):
|
||||
for sdx2, (sstart, send) in enumerate(segment_data[sdx]["sentence_spans"]):
|
||||
curr_chars = char_segments_arr.loc[(char_segments_arr.index >= sstart) & (char_segments_arr.index <= send)]
|
||||
char_segments_arr.loc[(char_segments_arr.index >= sstart) & (char_segments_arr.index <= send), "sentence-idx"] = sdx
|
||||
|
||||
char_segments_arr.loc[(char_segments_arr.index >= sstart) & (char_segments_arr.index <= send), "sentence-idx"] = sdx2
|
||||
|
||||
sentence_text = text[sstart:send]
|
||||
sentence_start = curr_chars["start"].min()
|
||||
end_chars = curr_chars[curr_chars["char"] != ' ']
|
||||
|
@ -1,6 +1,5 @@
|
||||
import os
|
||||
import warnings
|
||||
from typing import List, NamedTuple, Optional, Union
|
||||
from typing import List, Optional, Union
|
||||
from dataclasses import replace
|
||||
|
||||
import ctranslate2
|
||||
@ -13,8 +12,8 @@ from transformers import Pipeline
|
||||
from transformers.pipelines.pt_utils import PipelineIterator
|
||||
|
||||
from .audio import N_SAMPLES, SAMPLE_RATE, load_audio, log_mel_spectrogram
|
||||
import whisperx.vads
|
||||
from .types import SingleSegment, TranscriptionResult
|
||||
from .vads import Vad, Silero, Pyannote
|
||||
|
||||
def find_numeral_symbol_tokens(tokenizer):
|
||||
numeral_symbol_tokens = []
|
||||
@ -209,12 +208,12 @@ class FasterWhisperPipeline(Pipeline):
|
||||
|
||||
# Pre-process audio and merge chunks as defined by the respective VAD child class
|
||||
# In case vad_model is manually assigned (see 'load_model') follow the functionality of pyannote toolkit
|
||||
if issubclass(type(self.vad_model), whisperx.vads.Vad):
|
||||
if issubclass(type(self.vad_model), Vad):
|
||||
waveform = self.vad_model.preprocess_audio(audio)
|
||||
merge_chunks = self.vad_model.merge_chunks
|
||||
else:
|
||||
waveform = whisperx.vads.Pyannote.preprocess_audio(audio)
|
||||
merge_chunks = whisperx.vads.Pyannote.merge_chunks
|
||||
waveform = Pyannote.preprocess_audio(audio)
|
||||
merge_chunks = Pyannote.merge_chunks
|
||||
|
||||
vad_segments = self.vad_model({"waveform": waveform, "sample_rate": SAMPLE_RATE})
|
||||
vad_segments = merge_chunks(
|
||||
@ -304,8 +303,8 @@ def load_model(
|
||||
compute_type="float16",
|
||||
asr_options: Optional[dict] = None,
|
||||
language: Optional[str] = None,
|
||||
vad_model = None,
|
||||
vad_method: str = "pyannote",
|
||||
vad_model: Optional[Vad]= None,
|
||||
vad_method: Optional[str] = "pyannote",
|
||||
vad_options: Optional[dict] = None,
|
||||
model: Optional[WhisperModel] = None,
|
||||
task="transcribe",
|
||||
@ -318,7 +317,7 @@ def load_model(
|
||||
whisper_arch - The name of the Whisper model to load.
|
||||
device - The device to load the model on.
|
||||
compute_type - The compute type to use for the model.
|
||||
vad_method: str - The vad method to use. vad_model has higher priority if is not None.
|
||||
vad_method - The vad method to use. vad_model has higher priority if is not None.
|
||||
options - A dictionary of options to use for the model.
|
||||
language - The language of the model. (use English for now)
|
||||
model - The WhisperModel instance to use.
|
||||
@ -398,9 +397,9 @@ def load_model(
|
||||
vad_model = vad_model
|
||||
else:
|
||||
if vad_method == "silero":
|
||||
vad_model = whisperx.vads.Silero(**default_vad_options)
|
||||
vad_model = Silero(**default_vad_options)
|
||||
elif vad_method == "pyannote":
|
||||
vad_model = whisperx.vads.Pyannote(torch.device(device), use_auth_token=None, **default_vad_options)
|
||||
vad_model = Pyannote(torch.device(device), use_auth_token=None, **default_vad_options)
|
||||
else:
|
||||
raise ValueError(f"Invalid vad_method: {vad_method}")
|
||||
|
||||
|
@ -79,7 +79,7 @@ def assign_word_speakers(
|
||||
|
||||
|
||||
class Segment:
|
||||
def __init__(self, start, end, speaker=None):
|
||||
def __init__(self, start:int, end:int, speaker:Optional[str]=None):
|
||||
self.start = start
|
||||
self.end = end
|
||||
self.speaker = speaker
|
||||
|
@ -1,4 +1,4 @@
|
||||
from typing import TypedDict, Optional, List
|
||||
from typing import TypedDict, Optional, List, Tuple
|
||||
|
||||
|
||||
class SingleWordSegment(TypedDict):
|
||||
@ -30,6 +30,17 @@ class SingleSegment(TypedDict):
|
||||
text: str
|
||||
|
||||
|
||||
class SegmentData(TypedDict):
|
||||
"""
|
||||
Temporary processing data used during alignment.
|
||||
Contains cleaned and preprocessed data for each segment.
|
||||
"""
|
||||
clean_char: List[str] # Cleaned characters that exist in model dictionary
|
||||
clean_cdx: List[int] # Original indices of cleaned characters
|
||||
clean_wdx: List[int] # Indices of words containing valid characters
|
||||
sentence_spans: List[Tuple[int, int]] # Start and end indices of sentences
|
||||
|
||||
|
||||
class SingleAlignedSegment(TypedDict):
|
||||
"""
|
||||
A single segment (up to multiple sentences) of a speech with word alignment.
|
||||
|
@ -241,7 +241,7 @@ class SubtitlesWriter(ResultWriter):
|
||||
line_count = 1
|
||||
# the next subtitle to yield (a list of word timings with whitespace)
|
||||
subtitle: list[dict] = []
|
||||
times = []
|
||||
times: list[tuple] = []
|
||||
last = result["segments"][0]["start"]
|
||||
for segment in result["segments"]:
|
||||
for i, original_timing in enumerate(segment["words"]):
|
||||
|
@ -26,8 +26,8 @@ class Vad:
|
||||
"""
|
||||
curr_end = 0
|
||||
merged_segments = []
|
||||
seg_idxs = []
|
||||
speaker_idxs = []
|
||||
seg_idxs: list[tuple]= []
|
||||
speaker_idxs: list[Optional[str]] = []
|
||||
|
||||
curr_start = segments[0].start
|
||||
for seg in segments:
|
||||
|
Reference in New Issue
Block a user