Merge branch 'main' into main

Max Bain
2025-01-13 10:09:20 +00:00
committed by GitHub
6 changed files with 52 additions and 31 deletions

View File

@@ -1,7 +1,8 @@
-""""
+"""
 Forced Alignment with Whisper
 C. Max Bain
 """
 from dataclasses import dataclass
 from typing import Iterable, Optional, Union, List
@@ -13,8 +14,13 @@ from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
 from .audio import SAMPLE_RATE, load_audio
 from .utils import interpolate_nans
-from .types import AlignedTranscriptionResult, SingleSegment, SingleAlignedSegment, SingleWordSegment
-import nltk
+from .types import (
+    AlignedTranscriptionResult,
+    SingleSegment,
+    SingleAlignedSegment,
+    SingleWordSegment,
+    SegmentData,
+)
 from nltk.tokenize.punkt import PunktSentenceTokenizer, PunktParameters
 PUNKT_ABBREVIATIONS = ['dr', 'vs', 'mr', 'mrs', 'prof']
@@ -131,6 +137,8 @@ def align(
     # 1. Preprocess to keep only characters in dictionary
     total_segments = len(transcript)
+    # Store temporary processing values
+    segment_data: dict[int, SegmentData] = {}
     for sdx, segment in enumerate(transcript):
         # strip spaces at beginning / end, but keep track of the amount.
         if print_progress:
@@ -175,11 +183,13 @@ def align(
         sentence_splitter = PunktSentenceTokenizer(punkt_param)
         sentence_spans = list(sentence_splitter.span_tokenize(text))
-        segment["clean_char"] = clean_char
-        segment["clean_cdx"] = clean_cdx
-        segment["clean_wdx"] = clean_wdx
-        segment["sentence_spans"] = sentence_spans
+        segment_data[sdx] = {
+            "clean_char": clean_char,
+            "clean_cdx": clean_cdx,
+            "clean_wdx": clean_wdx,
+            "sentence_spans": sentence_spans
+        }
     aligned_segments: List[SingleAlignedSegment] = []
     # 2. Get prediction matrix from alignment model & align
@@ -194,13 +204,14 @@ def align(
             "end": t2,
             "text": text,
             "words": [],
+            "chars": None,
         }
         if return_char_alignments:
             aligned_seg["chars"] = []
         # check we can align
-        if len(segment["clean_char"]) == 0:
+        if len(segment_data[sdx]["clean_char"]) == 0:
             print(f'Failed to align segment ("{segment["text"]}"): no characters in this segment found in model dictionary, resorting to original...')
             aligned_segments.append(aligned_seg)
             continue
@@ -210,7 +221,7 @@ def align(
             aligned_segments.append(aligned_seg)
             continue
-        text_clean = "".join(segment["clean_char"])
+        text_clean = "".join(segment_data[sdx]["clean_char"])
         tokens = [model_dictionary[c] for c in text_clean]
         f1 = int(t1 * SAMPLE_RATE)
@@ -261,8 +272,8 @@ def align(
         word_idx = 0
         for cdx, char in enumerate(text):
             start, end, score = None, None, None
-            if cdx in segment["clean_cdx"]:
-                char_seg = char_segments[segment["clean_cdx"].index(cdx)]
+            if cdx in segment_data[sdx]["clean_cdx"]:
+                char_seg = char_segments[segment_data[sdx]["clean_cdx"].index(cdx)]
                 start = round(char_seg.start * ratio + t1, 3)
                 end = round(char_seg.end * ratio + t1, 3)
                 score = round(char_seg.score, 3)
@@ -288,10 +299,10 @@ def align(
         aligned_subsegments = []
         # assign sentence_idx to each character index
         char_segments_arr["sentence-idx"] = None
-        for sdx, (sstart, send) in enumerate(segment["sentence_spans"]):
+        for sdx2, (sstart, send) in enumerate(segment_data[sdx]["sentence_spans"]):
             curr_chars = char_segments_arr.loc[(char_segments_arr.index >= sstart) & (char_segments_arr.index <= send)]
-            char_segments_arr.loc[(char_segments_arr.index >= sstart) & (char_segments_arr.index <= send), "sentence-idx"] = sdx
+            char_segments_arr.loc[(char_segments_arr.index >= sstart) & (char_segments_arr.index <= send), "sentence-idx"] = sdx2
             sentence_text = text[sstart:send]
             sentence_start = curr_chars["start"].min()
             end_chars = curr_chars[curr_chars["char"] != ' ']
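
Taken together, the alignment changes above stop writing scratch values (clean_char, clean_cdx, clean_wdx, sentence_spans) onto the transcript segments and instead collect them in a segment_data dict keyed by segment index. Below is a minimal, self-contained sketch of that pattern; the toy transcript, the character filter, and the single sentence span are illustrative stand-ins, not the real align() preprocessing.

from typing import Dict

transcript = [{"text": "hello world"}]           # toy stand-in for real transcript segments
model_chars = set("abcdefghijklmnopqrstuvwxyz")  # toy stand-in for the alignment model dictionary

segment_data: Dict[int, dict] = {}  # dict[int, SegmentData] in the real code
for sdx, segment in enumerate(transcript):
    text = segment["text"]
    # keep only characters the model knows, remembering where they came from
    clean_char = [c for c in text if c.lower() in model_chars]
    clean_cdx = [i for i, c in enumerate(text) if c.lower() in model_chars]
    clean_wdx = [w for w, word in enumerate(text.split())
                 if any(c.lower() in model_chars for c in word)]
    segment_data[sdx] = {
        "clean_char": clean_char,
        "clean_cdx": clean_cdx,
        "clean_wdx": clean_wdx,
        "sentence_spans": [(0, len(text))],      # one sentence in this toy example
    }

# later passes read the scratch values by index instead of segment["clean_char"]
assert len(segment_data[0]["clean_char"]) > 0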

View File

@@ -1,6 +1,5 @@
 import os
-import warnings
-from typing import List, NamedTuple, Optional, Union
+from typing import List, Optional, Union
 from dataclasses import replace
 import ctranslate2
@@ -13,8 +12,8 @@ from transformers import Pipeline
 from transformers.pipelines.pt_utils import PipelineIterator
 from .audio import N_SAMPLES, SAMPLE_RATE, load_audio, log_mel_spectrogram
-import whisperx.vads
 from .types import SingleSegment, TranscriptionResult
+from .vads import Vad, Silero, Pyannote
 def find_numeral_symbol_tokens(tokenizer):
     numeral_symbol_tokens = []
@@ -209,12 +208,12 @@ class FasterWhisperPipeline(Pipeline):
         # Pre-process audio and merge chunks as defined by the respective VAD child class
         # In case vad_model is manually assigned (see 'load_model') follow the functionality of pyannote toolkit
-        if issubclass(type(self.vad_model), whisperx.vads.Vad):
+        if issubclass(type(self.vad_model), Vad):
             waveform = self.vad_model.preprocess_audio(audio)
             merge_chunks = self.vad_model.merge_chunks
         else:
-            waveform = whisperx.vads.Pyannote.preprocess_audio(audio)
-            merge_chunks = whisperx.vads.Pyannote.merge_chunks
+            waveform = Pyannote.preprocess_audio(audio)
+            merge_chunks = Pyannote.merge_chunks
         vad_segments = self.vad_model({"waveform": waveform, "sample_rate": SAMPLE_RATE})
         vad_segments = merge_chunks(
@@ -304,8 +303,8 @@ def load_model(
     compute_type="float16",
     asr_options: Optional[dict] = None,
     language: Optional[str] = None,
-    vad_model = None,
-    vad_method: str = "pyannote",
+    vad_model: Optional[Vad]= None,
+    vad_method: Optional[str] = "pyannote",
     vad_options: Optional[dict] = None,
     model: Optional[WhisperModel] = None,
     task="transcribe",
@@ -318,7 +317,7 @@ def load_model(
         whisper_arch - The name of the Whisper model to load.
         device - The device to load the model on.
         compute_type - The compute type to use for the model.
-        vad_method: str - The vad method to use. vad_model has higher priority if is not None.
+        vad_method - The vad method to use. vad_model has higher priority if is not None.
         options - A dictionary of options to use for the model.
         language - The language of the model. (use English for now)
         model - The WhisperModel instance to use.
@@ -398,9 +397,9 @@ def load_model(
         vad_model = vad_model
     else:
         if vad_method == "silero":
-            vad_model = whisperx.vads.Silero(**default_vad_options)
+            vad_model = Silero(**default_vad_options)
         elif vad_method == "pyannote":
-            vad_model = whisperx.vads.Pyannote(torch.device(device), use_auth_token=None, **default_vad_options)
+            vad_model = Pyannote(torch.device(device), use_auth_token=None, **default_vad_options)
         else:
             raise ValueError(f"Invalid vad_method: {vad_method}")
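
For callers, the practical effect of the ASR changes is that the VAD backend can be selected by name via the new vad_method argument, while an explicitly supplied vad_model still takes precedence. A hedged usage sketch, assuming whisperx and the chosen VAD's dependencies are installed; the model size and device below are purely illustrative.

import whisperx

# default: the pyannote-based VAD is constructed internally
model = whisperx.load_model("small", device="cpu", compute_type="int8")

# select the Silero backend by name instead of the pyannote default
model_silero = whisperx.load_model(
    "small",
    device="cpu",
    compute_type="int8",
    vad_method="silero",
)

Any other string falls through to the ValueError raised in the last hunk above.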

View File

@@ -79,7 +79,7 @@ def assign_word_speakers(
 class Segment:
-    def __init__(self, start, end, speaker=None):
+    def __init__(self, start:int, end:int, speaker:Optional[str]=None):
         self.start = start
         self.end = end
         self.speaker = speaker

View File

@@ -1,4 +1,4 @@
-from typing import TypedDict, Optional, List
+from typing import TypedDict, Optional, List, Tuple
 class SingleWordSegment(TypedDict):
@@ -30,6 +30,17 @@ class SingleSegment(TypedDict):
     text: str
+class SegmentData(TypedDict):
+    """
+    Temporary processing data used during alignment.
+    Contains cleaned and preprocessed data for each segment.
+    """
+    clean_char: List[str]  # Cleaned characters that exist in model dictionary
+    clean_cdx: List[int]  # Original indices of cleaned characters
+    clean_wdx: List[int]  # Indices of words containing valid characters
+    sentence_spans: List[Tuple[int, int]]  # Start and end indices of sentences
 class SingleAlignedSegment(TypedDict):
     """
     A single segment (up to multiple sentences) of a speech with word alignment.
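
Because SegmentData is a TypedDict, the scratch dictionary built during alignment gets key and type checking from static analysers while remaining a plain dict at runtime. A small illustration, assuming a checker such as mypy is run over the code; the literal values are made up.

from typing import List, Tuple, TypedDict

class SegmentData(TypedDict):
    clean_char: List[str]
    clean_cdx: List[int]
    clean_wdx: List[int]
    sentence_spans: List[Tuple[int, int]]

good: SegmentData = {
    "clean_char": ["h", "i"],
    "clean_cdx": [0, 1],
    "clean_wdx": [0],
    "sentence_spans": [(0, 2)],
}

# bad: SegmentData = {"clean_char": ["h", "i"]}
# mypy would flag the three missing keys; at runtime a TypedDict is an ordinary
# dict, so nothing is enforced when the script actually executes.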

View File

@@ -241,7 +241,7 @@ class SubtitlesWriter(ResultWriter):
             line_count = 1
             # the next subtitle to yield (a list of word timings with whitespace)
             subtitle: list[dict] = []
-            times = []
+            times: list[tuple] = []
             last = result["segments"][0]["start"]
             for segment in result["segments"]:
                 for i, original_timing in enumerate(segment["words"]):

View File

@@ -26,8 +26,8 @@ class Vad:
         """
         curr_end = 0
         merged_segments = []
-        seg_idxs = []
-        speaker_idxs = []
+        seg_idxs: list[tuple]= []
+        speaker_idxs: list[Optional[str]] = []
         curr_start = segments[0].start
         for seg in segments:
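
The annotations added here suggest what the accumulators hold while VAD segments are merged into transcription chunks: seg_idxs collects (start, end) pairs and speaker_idxs the possibly absent speaker labels. The standalone sketch below re-implements that general merging idea under the stated assumptions of time-ordered segments and a maximum chunk length in seconds; it is an illustration, not the whisperx merge_chunks implementation.

from dataclasses import dataclass
from typing import List, Optional, Tuple

@dataclass
class Seg:
    start: float
    end: float
    speaker: Optional[str] = None

def merge(segments: List[Seg], chunk_size: float) -> List[dict]:
    """Greedily pack time-ordered VAD segments into chunks of at most chunk_size seconds."""
    merged: List[dict] = []
    seg_idxs: List[Tuple[float, float]] = []
    speaker_idxs: List[Optional[str]] = []
    curr_start = segments[0].start
    curr_end = segments[0].start
    for seg in segments:
        # close the current chunk once adding this segment would overflow it
        if seg.end - curr_start > chunk_size and seg_idxs:
            merged.append({"start": curr_start, "end": curr_end, "segments": seg_idxs})
            curr_start, seg_idxs, speaker_idxs = seg.start, [], []
        curr_end = seg.end
        seg_idxs.append((seg.start, seg.end))
        speaker_idxs.append(seg.speaker)
    merged.append({"start": curr_start, "end": curr_end, "segments": seg_idxs})
    return merged

print(merge([Seg(0.0, 4.0), Seg(5.0, 9.0), Seg(40.0, 44.0)], chunk_size=30.0))
# [{'start': 0.0, 'end': 9.0, 'segments': [(0.0, 4.0), (5.0, 9.0)]},
#  {'start': 40.0, 'end': 44.0, 'segments': [(40.0, 44.0)]}]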