feat: add SegmentData type for temporary processing during alignment

This commit is contained in:
Barabazs
2025-01-13 09:27:33 +01:00
parent 024bc8481b
commit 2f93e029c7
2 changed files with 21 additions and 3 deletions

View File

@ -2,6 +2,7 @@
Forced Alignment with Whisper
C. Max Bain
"""
from dataclasses import dataclass
from typing import Iterable, Optional, Union, List
@ -13,7 +14,13 @@ from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
from .audio import SAMPLE_RATE, load_audio
from .utils import interpolate_nans
from .types import AlignedTranscriptionResult, SingleSegment, SingleAlignedSegment, SingleWordSegment
from .types import (
AlignedTranscriptionResult,
SingleSegment,
SingleAlignedSegment,
SingleWordSegment,
SegmentData,
)
from nltk.tokenize.punkt import PunktSentenceTokenizer, PunktParameters
PUNKT_ABBREVIATIONS = ['dr', 'vs', 'mr', 'mrs', 'prof']
@ -131,7 +138,7 @@ def align(
# 1. Preprocess to keep only characters in dictionary
total_segments = len(transcript)
# Store temporary processing values
segment_data = {}
segment_data: dict[int, SegmentData] = {}
for sdx, segment in enumerate(transcript):
# strip spaces at beginning / end, but keep track of the amount.
if print_progress:

View File

@ -1,4 +1,4 @@
from typing import TypedDict, Optional, List
from typing import TypedDict, Optional, List, Tuple
class SingleWordSegment(TypedDict):
@ -30,6 +30,17 @@ class SingleSegment(TypedDict):
text: str
class SegmentData(TypedDict):
"""
Temporary processing data used during alignment.
Contains cleaned and preprocessed data for each segment.
"""
clean_char: List[str] # Cleaned characters that exist in model dictionary
clean_cdx: List[int] # Original indices of cleaned characters
clean_wdx: List[int] # Indices of words containing valid characters
sentence_spans: List[Tuple[int, int]] # Start and end indices of sentences
class SingleAlignedSegment(TypedDict):
"""
A single segment (up to multiple sentences) of a speech with word alignment.