feat: add SegmentData type for temporary processing during alignment

This commit is contained in:
Barabazs
2025-01-13 09:27:33 +01:00
parent 024bc8481b
commit 2f93e029c7
2 changed files with 21 additions and 3 deletions

View File

@@ -2,6 +2,7 @@
Forced Alignment with Whisper
C. Max Bain
"""
from dataclasses import dataclass
from typing import Iterable, Optional, Union, List
@@ -13,7 +14,13 @@ from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
from .audio import SAMPLE_RATE, load_audio
from .utils import interpolate_nans
from .types import AlignedTranscriptionResult, SingleSegment, SingleAlignedSegment, SingleWordSegment
from .types import (
AlignedTranscriptionResult,
SingleSegment,
SingleAlignedSegment,
SingleWordSegment,
SegmentData,
)
from nltk.tokenize.punkt import PunktSentenceTokenizer, PunktParameters
PUNKT_ABBREVIATIONS = ['dr', 'vs', 'mr', 'mrs', 'prof']
@@ -131,7 +138,7 @@ def align(
# 1. Preprocess to keep only characters in dictionary
total_segments = len(transcript)
# Store temporary processing values
segment_data = {}
segment_data: dict[int, SegmentData] = {}
for sdx, segment in enumerate(transcript):
# strip spaces at beginning / end, but keep track of the amount.
if print_progress: