feat: add SegmentData type for temporary processing during alignment

This commit is contained in:
Barabazs
2025-01-13 09:27:33 +01:00
parent 024bc8481b
commit 2f93e029c7
2 changed files with 21 additions and 3 deletions

View File

@ -2,6 +2,7 @@
Forced Alignment with Whisper Forced Alignment with Whisper
C. Max Bain C. Max Bain
""" """
from dataclasses import dataclass from dataclasses import dataclass
from typing import Iterable, Optional, Union, List from typing import Iterable, Optional, Union, List
@ -13,7 +14,13 @@ from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
from .audio import SAMPLE_RATE, load_audio from .audio import SAMPLE_RATE, load_audio
from .utils import interpolate_nans from .utils import interpolate_nans
from .types import AlignedTranscriptionResult, SingleSegment, SingleAlignedSegment, SingleWordSegment from .types import (
AlignedTranscriptionResult,
SingleSegment,
SingleAlignedSegment,
SingleWordSegment,
SegmentData,
)
from nltk.tokenize.punkt import PunktSentenceTokenizer, PunktParameters from nltk.tokenize.punkt import PunktSentenceTokenizer, PunktParameters
PUNKT_ABBREVIATIONS = ['dr', 'vs', 'mr', 'mrs', 'prof'] PUNKT_ABBREVIATIONS = ['dr', 'vs', 'mr', 'mrs', 'prof']
@ -131,7 +138,7 @@ def align(
# 1. Preprocess to keep only characters in dictionary # 1. Preprocess to keep only characters in dictionary
total_segments = len(transcript) total_segments = len(transcript)
# Store temporary processing values # Store temporary processing values
segment_data = {} segment_data: dict[int, SegmentData] = {}
for sdx, segment in enumerate(transcript): for sdx, segment in enumerate(transcript):
# strip spaces at beginning / end, but keep track of the amount. # strip spaces at beginning / end, but keep track of the amount.
if print_progress: if print_progress:

View File

@ -1,4 +1,4 @@
from typing import TypedDict, Optional, List from typing import TypedDict, Optional, List, Tuple
class SingleWordSegment(TypedDict): class SingleWordSegment(TypedDict):
@ -30,6 +30,17 @@ class SingleSegment(TypedDict):
text: str text: str
class SegmentData(TypedDict):
    """
    Temporary processing data used during alignment.

    Holds the cleaned and preprocessed values computed for a single
    transcript segment. This is intermediate working state for the
    alignment step.
    """

    # Cleaned characters of the segment that exist in the model dictionary.
    clean_char: List[str]
    # Original indices of the cleaned characters.
    # NOTE(review): presumably parallel to ``clean_char`` -- confirm.
    clean_cdx: List[int]
    # Indices of the words that contain valid characters.
    clean_wdx: List[int]
    # (start, end) index pair for each sentence in the segment.
    # NOTE(review): presumably character offsets into the segment text -- confirm.
    sentence_spans: List[Tuple[int, int]]
class SingleAlignedSegment(TypedDict): class SingleAlignedSegment(TypedDict):
""" """
A single segment (up to multiple sentences) of a speech with word alignment. A single segment (up to multiple sentences) of a speech with word alignment.