mirror of
https://github.com/m-bain/whisperX.git
synced 2025-07-01 18:17:27 -04:00
Custom result types
This commit is contained in:
@ -3,7 +3,7 @@ Forced Alignment with Whisper
|
||||
C. Max Bain
|
||||
"""
|
||||
from dataclasses import dataclass
|
||||
from typing import Iterator, Union
|
||||
from typing import Iterator, Union, List
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
@ -13,6 +13,7 @@ from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
|
||||
|
||||
from .audio import SAMPLE_RATE, load_audio
|
||||
from .utils import interpolate_nans
|
||||
from .types import AlignedTranscriptionResult, SingleSegment, SingleAlignedSegment, SingleWordSegment
|
||||
import nltk
|
||||
|
||||
LANGUAGES_WITHOUT_SPACES = ["ja", "zh"]
|
||||
@ -80,14 +81,14 @@ def load_align_model(language_code, device, model_name=None, model_dir=None):
|
||||
|
||||
|
||||
def align(
|
||||
transcript: Iterator[dict],
|
||||
transcript: Iterator[SingleSegment],
|
||||
model: torch.nn.Module,
|
||||
align_model_metadata: dict,
|
||||
audio: Union[str, np.ndarray, torch.Tensor],
|
||||
device: str,
|
||||
interpolate_method: str = "nearest",
|
||||
return_char_alignments: bool = False,
|
||||
):
|
||||
) -> AlignedTranscriptionResult:
|
||||
"""
|
||||
Align phoneme recognition predictions to known transcription.
|
||||
"""
|
||||
@ -146,7 +147,7 @@ def align(
|
||||
segment["clean_wdx"] = clean_wdx
|
||||
segment["sentence_spans"] = sentence_spans
|
||||
|
||||
aligned_segments = []
|
||||
aligned_segments: List[SingleAlignedSegment] = []
|
||||
|
||||
# 2. Get prediction matrix from alignment model & align
|
||||
for sdx, segment in enumerate(transcript):
|
||||
@ -154,7 +155,7 @@ def align(
|
||||
t2 = segment["end"]
|
||||
text = segment["text"]
|
||||
|
||||
aligned_seg = {
|
||||
aligned_seg: SingleAlignedSegment = {
|
||||
"start": t1,
|
||||
"end": t2,
|
||||
"text": text,
|
||||
@ -301,7 +302,7 @@ def align(
|
||||
aligned_segments += aligned_subsegments
|
||||
|
||||
# create word_segments list
|
||||
word_segments = []
|
||||
word_segments: List[SingleWordSegment] = []
|
||||
for segment in aligned_segments:
|
||||
word_segments += segment["words"]
|
||||
|
||||
|
Reference in New Issue
Block a user