Custom result types

This commit is contained in:
Simon
2023-05-08 20:45:34 +02:00
parent b50aafb17b
commit eabf35dff0
3 changed files with 68 additions and 9 deletions

View File

@ -3,7 +3,7 @@ Forced Alignment with Whisper
C. Max Bain
"""
from dataclasses import dataclass
from typing import Iterator, Union
from typing import Iterator, Union, List
import numpy as np
import pandas as pd
@ -13,6 +13,7 @@ from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
from .audio import SAMPLE_RATE, load_audio
from .utils import interpolate_nans
from .types import AlignedTranscriptionResult, SingleSegment, SingleAlignedSegment, SingleWordSegment
import nltk
LANGUAGES_WITHOUT_SPACES = ["ja", "zh"]
@ -80,14 +81,14 @@ def load_align_model(language_code, device, model_name=None, model_dir=None):
def align(
transcript: Iterator[dict],
transcript: Iterator[SingleSegment],
model: torch.nn.Module,
align_model_metadata: dict,
audio: Union[str, np.ndarray, torch.Tensor],
device: str,
interpolate_method: str = "nearest",
return_char_alignments: bool = False,
):
) -> AlignedTranscriptionResult:
"""
Align phoneme recognition predictions to known transcription.
"""
@ -146,7 +147,7 @@ def align(
segment["clean_wdx"] = clean_wdx
segment["sentence_spans"] = sentence_spans
aligned_segments = []
aligned_segments: List[SingleAlignedSegment] = []
# 2. Get prediction matrix from alignment model & align
for sdx, segment in enumerate(transcript):
@ -154,7 +155,7 @@ def align(
t2 = segment["end"]
text = segment["text"]
aligned_seg = {
aligned_seg: SingleAlignedSegment = {
"start": t1,
"end": t2,
"text": text,
@ -301,7 +302,7 @@ def align(
aligned_segments += aligned_subsegments
# create word_segments list
word_segments = []
word_segments: List[SingleWordSegment] = []
for segment in aligned_segments:
word_segments += segment["words"]

View File

@ -11,7 +11,7 @@ from transformers.pipelines.pt_utils import PipelineIterator
from .audio import N_SAMPLES, SAMPLE_RATE, load_audio, log_mel_spectrogram
from .vad import load_vad_model, merge_chunks
from .types import TranscriptionResult, SingleSegment
def load_model(whisper_arch, device, compute_type="float16", asr_options=None, language=None,
vad_options=None, model=None):
@ -215,7 +215,7 @@ class FasterWhisperPipeline(Pipeline):
def transcribe(
self, audio: Union[str, np.ndarray], batch_size=None, num_workers=0
):
) -> TranscriptionResult:
if isinstance(audio, str):
audio = load_audio(audio)
@ -237,7 +237,7 @@ class FasterWhisperPipeline(Pipeline):
else:
language = self.tokenizer.language_code
segments = []
segments: List[SingleSegment] = []
batch_size = batch_size or self._batch_size
for idx, out in enumerate(self.__call__(data(audio, vad_segments), batch_size=batch_size, num_workers=num_workers)):
text = out['text']

58
whisperx/types.py Normal file
View File

@ -0,0 +1,58 @@
from typing import List, Optional, TypedDict
class SingleWordSegment(TypedDict):
    """Dictionary schema describing one word of a speech transcript."""

    # Text of the word.
    word: str
    # Where the word begins and ends (presumably seconds into the audio -- TODO confirm).
    start: float
    end: float
    # Score attached to the word (presumably alignment confidence -- TODO confirm).
    score: float
class SingleCharSegment(TypedDict):
    """Dictionary schema describing one character of a speech transcript."""

    # The character itself.
    char: str
    # Where the character begins and ends (presumably seconds into the audio -- TODO confirm).
    start: float
    end: float
    # Score attached to the character (presumably alignment confidence -- TODO confirm).
    score: float
class SingleSegment(TypedDict):
    """Dictionary schema for one segment of speech, which may span several sentences."""

    # Where the segment begins and ends (presumably seconds into the audio -- TODO confirm).
    start: float
    end: float
    # Text of the segment.
    text: str
class SingleAlignedSegment(TypedDict):
    """One segment of speech (possibly several sentences) with word alignment.

    Container annotations use ``typing.List`` rather than the builtin
    ``list[...]``: TypedDict field annotations are evaluated when the class is
    created, and subscripting the builtin ``list`` raises ``TypeError`` on
    Python versions before 3.9.
    """

    # Where the segment begins and ends (presumably seconds into the audio -- TODO confirm).
    start: float
    end: float
    # Text of the segment.
    text: str
    # Per-word entries belonging to this segment.
    words: List[SingleWordSegment]
    # Per-character entries; declared Optional, so the value may be None.
    chars: Optional[List[SingleCharSegment]]
class TranscriptionResult(TypedDict):
    """Transcription output: the list of segments plus a language string.

    ``typing.List`` is used instead of the builtin ``list[...]`` because
    TypedDict field annotations are evaluated when the class is created, and
    subscripting the builtin ``list`` raises ``TypeError`` on Python versions
    before 3.9.
    """

    # Transcribed segments.
    segments: List[SingleSegment]
    # Language of the transcription (presumably a language code such as "en" -- TODO confirm).
    language: str
class AlignedTranscriptionResult(TypedDict):
    """Alignment output: the aligned segments plus a list of word entries.

    ``typing.List`` is used instead of the builtin ``list[...]`` because
    TypedDict field annotations are evaluated when the class is created, and
    subscripting the builtin ``list`` raises ``TypeError`` on Python versions
    before 3.9.
    """

    # Segments, each carrying its own word (and optionally character) entries.
    segments: List[SingleAlignedSegment]
    # Word entries as one list (presumably the segments' words concatenated -- TODO confirm).
    word_segments: List[SingleWordSegment]