Custom result types

This commit is contained in:
Simon
2023-05-08 20:45:34 +02:00
parent b50aafb17b
commit eabf35dff0
3 changed files with 68 additions and 9 deletions

View File

@ -3,7 +3,7 @@ Forced Alignment with Whisper
C. Max Bain
"""
from dataclasses import dataclass
from typing import Iterator, Union
from typing import Iterator, Union, List
import numpy as np
import pandas as pd
@ -13,6 +13,7 @@ from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
from .audio import SAMPLE_RATE, load_audio
from .utils import interpolate_nans
from .types import AlignedTranscriptionResult, SingleSegment, SingleAlignedSegment, SingleWordSegment
import nltk
LANGUAGES_WITHOUT_SPACES = ["ja", "zh"]
@ -80,14 +81,14 @@ def load_align_model(language_code, device, model_name=None, model_dir=None):
def align(
transcript: Iterator[dict],
transcript: Iterator[SingleSegment],
model: torch.nn.Module,
align_model_metadata: dict,
audio: Union[str, np.ndarray, torch.Tensor],
device: str,
interpolate_method: str = "nearest",
return_char_alignments: bool = False,
):
) -> AlignedTranscriptionResult:
"""
Align phoneme recognition predictions to known transcription.
"""
@ -146,7 +147,7 @@ def align(
segment["clean_wdx"] = clean_wdx
segment["sentence_spans"] = sentence_spans
aligned_segments = []
aligned_segments: List[SingleAlignedSegment] = []
# 2. Get prediction matrix from alignment model & align
for sdx, segment in enumerate(transcript):
@ -154,7 +155,7 @@ def align(
t2 = segment["end"]
text = segment["text"]
aligned_seg = {
aligned_seg: SingleAlignedSegment = {
"start": t1,
"end": t2,
"text": text,
@ -301,7 +302,7 @@ def align(
aligned_segments += aligned_subsegments
# create word_segments list
word_segments = []
word_segments: List[SingleWordSegment] = []
for segment in aligned_segments:
word_segments += segment["words"]