Merge branch 'main' into main

Max Bain, 2025-01-13 10:09:20 +00:00 (committed by GitHub)
6 changed files with 52 additions and 31 deletions

View File

@@ -1,7 +1,8 @@
""""
"""
Forced Alignment with Whisper
C. Max Bain
"""
from dataclasses import dataclass
from typing import Iterable, Optional, Union, List
@@ -13,8 +14,13 @@ from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
from .audio import SAMPLE_RATE, load_audio
from .utils import interpolate_nans
from .types import AlignedTranscriptionResult, SingleSegment, SingleAlignedSegment, SingleWordSegment
import nltk
from .types import (
AlignedTranscriptionResult,
SingleSegment,
SingleAlignedSegment,
SingleWordSegment,
SegmentData,
)
from nltk.tokenize.punkt import PunktSentenceTokenizer, PunktParameters
PUNKT_ABBREVIATIONS = ['dr', 'vs', 'mr', 'mrs', 'prof']
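
For context on the sentence spans built later in this file: each segment's text is split with NLTK's Punkt tokenizer, seeded with the small abbreviation list above so titles such as "Dr." do not end a sentence. A minimal, hedged sketch of that usage (the sample text and printed offsets are illustrative, not from the repository):

from nltk.tokenize.punkt import PunktParameters, PunktSentenceTokenizer

PUNKT_ABBREVIATIONS = ['dr', 'vs', 'mr', 'mrs', 'prof']

# Register the abbreviations so Punkt does not split after "Dr." etc.
punkt_param = PunktParameters()
punkt_param.abbrev_types = set(PUNKT_ABBREVIATIONS)
sentence_splitter = PunktSentenceTokenizer(punkt_param)

text = "Dr. Smith arrived. He sat down."
# span_tokenize yields (start, end) character offsets for each sentence.
sentence_spans = list(sentence_splitter.span_tokenize(text))
print(sentence_spans)  # e.g. [(0, 18), (19, 31)]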
@@ -131,6 +137,8 @@ def align(
# 1. Preprocess to keep only characters in dictionary
total_segments = len(transcript)
# Store temporary processing values
segment_data: dict[int, SegmentData] = {}
for sdx, segment in enumerate(transcript):
# strip spaces at beginning / end, but keep track of the amount.
if print_progress:
@@ -175,11 +183,13 @@ def align(
sentence_splitter = PunktSentenceTokenizer(punkt_param)
sentence_spans = list(sentence_splitter.span_tokenize(text))
segment["clean_char"] = clean_char
segment["clean_cdx"] = clean_cdx
segment["clean_wdx"] = clean_wdx
segment["sentence_spans"] = sentence_spans
segment_data[sdx] = {
"clean_char": clean_char,
"clean_cdx": clean_cdx,
"clean_wdx": clean_wdx,
"sentence_spans": sentence_spans
}
aligned_segments: List[SingleAlignedSegment] = []
# 2. Get prediction matrix from alignment model & align
@@ -194,13 +204,14 @@
"end": t2,
"text": text,
"words": [],
"chars": None,
}
if return_char_alignments:
aligned_seg["chars"] = []
# check we can align
if len(segment["clean_char"]) == 0:
if len(segment_data[sdx]["clean_char"]) == 0:
print(f'Failed to align segment ("{segment["text"]}"): no characters in this segment found in model dictionary, resorting to original...')
aligned_segments.append(aligned_seg)
continue
@@ -210,7 +221,7 @@
aligned_segments.append(aligned_seg)
continue
text_clean = "".join(segment["clean_char"])
text_clean = "".join(segment_data[sdx]["clean_char"])
tokens = [model_dictionary[c] for c in text_clean]
f1 = int(t1 * SAMPLE_RATE)
@@ -261,8 +272,8 @@
word_idx = 0
for cdx, char in enumerate(text):
start, end, score = None, None, None
if cdx in segment["clean_cdx"]:
char_seg = char_segments[segment["clean_cdx"].index(cdx)]
if cdx in segment_data[sdx]["clean_cdx"]:
char_seg = char_segments[segment_data[sdx]["clean_cdx"].index(cdx)]
start = round(char_seg.start * ratio + t1, 3)
end = round(char_seg.end * ratio + t1, 3)
score = round(char_seg.score, 3)
@@ -288,10 +299,10 @@
aligned_subsegments = []
# assign sentence_idx to each character index
char_segments_arr["sentence-idx"] = None
for sdx, (sstart, send) in enumerate(segment["sentence_spans"]):
for sdx2, (sstart, send) in enumerate(segment_data[sdx]["sentence_spans"]):
curr_chars = char_segments_arr.loc[(char_segments_arr.index >= sstart) & (char_segments_arr.index <= send)]
char_segments_arr.loc[(char_segments_arr.index >= sstart) & (char_segments_arr.index <= send), "sentence-idx"] = sdx
char_segments_arr.loc[(char_segments_arr.index >= sstart) & (char_segments_arr.index <= send), "sentence-idx"] = sdx2
sentence_text = text[sstart:send]
sentence_start = curr_chars["start"].min()
end_chars = curr_chars[curr_chars["char"] != ' ']
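
Taken together, the alignment hunks above replace the pattern of stashing scratch values on each transcript segment (segment["clean_char"], segment["clean_cdx"], and so on) with a side dictionary, segment_data, keyed by segment index, and they rename the inner sentence-loop variable to sdx2 so it no longer shadows the outer sdx. A small hypothetical sketch of that pattern (toy cleaning logic, not the real model-dictionary filtering):

from typing import Dict

transcript = [{"text": "hello there"}, {"text": "bye"}]  # toy stand-in for real segments
model_chars = set("abcdefghijklmnopqrstuvwxyz ")         # toy "model dictionary"

# Scratch values live in a side dict keyed by segment index (typed as
# dict[int, SegmentData] in the real code) instead of being written back
# onto the transcript segments themselves.
segment_data: Dict[int, dict] = {}
for sdx, segment in enumerate(transcript):
    text = segment["text"]
    clean_cdx = [cdx for cdx, c in enumerate(text) if c in model_chars]
    segment_data[sdx] = {
        "clean_char": [text[cdx] for cdx in clean_cdx],
        "clean_cdx": clean_cdx,
        "clean_wdx": list(range(len(text.split()))),
        "sentence_spans": [(0, len(text))],
    }

# Later passes index the same dict; the inner loop gets its own name (sdx2)
# so it cannot shadow the outer segment index.
for sdx, segment in enumerate(transcript):
    for sdx2, (sstart, send) in enumerate(segment_data[sdx]["sentence_spans"]):
        print(sdx, sdx2, segment["text"][sstart:send])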

View File

@@ -1,6 +1,5 @@
import os
import warnings
from typing import List, NamedTuple, Optional, Union
from typing import List, Optional, Union
from dataclasses import replace
import ctranslate2
@@ -13,8 +12,8 @@ from transformers import Pipeline
from transformers.pipelines.pt_utils import PipelineIterator
from .audio import N_SAMPLES, SAMPLE_RATE, load_audio, log_mel_spectrogram
import whisperx.vads
from .types import SingleSegment, TranscriptionResult
from .vads import Vad, Silero, Pyannote
def find_numeral_symbol_tokens(tokenizer):
numeral_symbol_tokens = []
@@ -209,12 +208,12 @@ class FasterWhisperPipeline(Pipeline):
# Pre-process audio and merge chunks as defined by the respective VAD child class
# In case vad_model is manually assigned (see 'load_model') follow the functionality of pyannote toolkit
if issubclass(type(self.vad_model), whisperx.vads.Vad):
if issubclass(type(self.vad_model), Vad):
waveform = self.vad_model.preprocess_audio(audio)
merge_chunks = self.vad_model.merge_chunks
else:
waveform = whisperx.vads.Pyannote.preprocess_audio(audio)
merge_chunks = whisperx.vads.Pyannote.merge_chunks
waveform = Pyannote.preprocess_audio(audio)
merge_chunks = Pyannote.merge_chunks
vad_segments = self.vad_model({"waveform": waveform, "sample_rate": SAMPLE_RATE})
vad_segments = merge_chunks(
@@ -304,8 +303,8 @@ def load_model(
compute_type="float16",
asr_options: Optional[dict] = None,
language: Optional[str] = None,
vad_model = None,
vad_method: str = "pyannote",
vad_model: Optional[Vad] = None,
vad_method: Optional[str] = "pyannote",
vad_options: Optional[dict] = None,
model: Optional[WhisperModel] = None,
task="transcribe",
@@ -318,7 +317,7 @@ def load_model(
whisper_arch - The name of the Whisper model to load.
device - The device to load the model on.
compute_type - The compute type to use for the model.
vad_method: str - The vad method to use. vad_model has higher priority if is not None.
vad_method - The vad method to use. vad_model has higher priority if it is not None.
options - A dictionary of options to use for the model.
language - The language of the model. (use English for now)
model - The WhisperModel instance to use.
@@ -398,9 +397,9 @@ def load_model(
vad_model = vad_model
else:
if vad_method == "silero":
vad_model = whisperx.vads.Silero(**default_vad_options)
vad_model = Silero(**default_vad_options)
elif vad_method == "pyannote":
vad_model = whisperx.vads.Pyannote(torch.device(device), use_auth_token=None, **default_vad_options)
vad_model = Pyannote(torch.device(device), use_auth_token=None, **default_vad_options)
else:
raise ValueError(f"Invalid vad_method: {vad_method}")
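
The asr.py hunks switch from module-qualified references (whisperx.vads.Pyannote and friends) to direct imports of Vad, Silero and Pyannote, type the vad_model parameter as Optional[Vad], and let vad_method pick the backend by name, raising ValueError for anything else. A hedged usage sketch; the model name and device below are assumptions, while the argument names come from the diff:

import whisperx

# Pick the VAD backend by name; a pre-built Vad instance passed via vad_model
# would take priority over vad_method.
model = whisperx.load_model(
    "large-v2",              # whisper_arch (assumed example)
    "cuda",                  # device (assumed example)
    compute_type="float16",
    vad_method="silero",     # or "pyannote" (the default); other values raise ValueError
)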

View File

@@ -79,7 +79,7 @@ def assign_word_speakers(
class Segment:
def __init__(self, start, end, speaker=None):
def __init__(self, start: int, end: int, speaker: Optional[str] = None):
self.start = start
self.end = end
self.speaker = speaker
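
The diarize.py change is purely additive typing on the small Segment helper; construction is unchanged. A trivial hedged illustration (the speaker label is a made-up example):

from typing import Optional

class Segment:
    def __init__(self, start: int, end: int, speaker: Optional[str] = None):
        self.start = start
        self.end = end
        self.speaker = speaker

seg = Segment(0, 5, speaker="SPEAKER_00")
print(seg.start, seg.end, seg.speaker)  # 0 5 SPEAKER_00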

View File

@@ -1,4 +1,4 @@
from typing import TypedDict, Optional, List
from typing import TypedDict, Optional, List, Tuple
class SingleWordSegment(TypedDict):
@@ -30,6 +30,17 @@ class SingleSegment(TypedDict):
text: str
class SegmentData(TypedDict):
"""
Temporary processing data used during alignment.
Contains cleaned and preprocessed data for each segment.
"""
clean_char: List[str] # Cleaned characters that exist in model dictionary
clean_cdx: List[int] # Original indices of cleaned characters
clean_wdx: List[int] # Indices of words containing valid characters
sentence_spans: List[Tuple[int, int]] # Start and end indices of sentences
class SingleAlignedSegment(TypedDict):
"""
A single segment (up to multiple sentences) of a speech with word alignment.
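
Because SegmentData is a TypedDict, a static checker can verify that the alignment scratch values are written with the expected keys and element types. A minimal hedged example of a well-formed value:

from typing import List, Tuple, TypedDict

class SegmentData(TypedDict):
    clean_char: List[str]
    clean_cdx: List[int]
    clean_wdx: List[int]
    sentence_spans: List[Tuple[int, int]]

example: SegmentData = {
    "clean_char": ["h", "i"],
    "clean_cdx": [0, 1],
    "clean_wdx": [0],
    "sentence_spans": [(0, 2)],
}
# A checker such as mypy would flag a missing key or a str where ints are expected.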

View File

@@ -241,7 +241,7 @@ class SubtitlesWriter(ResultWriter):
line_count = 1
# the next subtitle to yield (a list of word timings with whitespace)
subtitle: list[dict] = []
times = []
times: list[tuple] = []
last = result["segments"][0]["start"]
for segment in result["segments"]:
for i, original_timing in enumerate(segment["words"]):

View File

@@ -26,8 +26,8 @@ class Vad:
"""
curr_end = 0
merged_segments = []
seg_idxs = []
speaker_idxs = []
seg_idxs: list[tuple] = []
speaker_idxs: list[Optional[str]] = []
curr_start = segments[0].start
for seg in segments: