From eabf35dff0d80ff3cabc946b65d2faf42797e671 Mon Sep 17 00:00:00 2001 From: Simon Date: Mon, 8 May 2023 20:45:34 +0200 Subject: [PATCH] Custom result types --- whisperx/alignment.py | 13 +++++----- whisperx/asr.py | 6 ++--- whisperx/types.py | 58 +++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 68 insertions(+), 9 deletions(-) create mode 100644 whisperx/types.py diff --git a/whisperx/alignment.py b/whisperx/alignment.py index b873475..eb8d4b6 100644 --- a/whisperx/alignment.py +++ b/whisperx/alignment.py @@ -3,7 +3,7 @@ Forced Alignment with Whisper C. Max Bain """ from dataclasses import dataclass -from typing import Iterator, Union +from typing import Iterator, Union, List import numpy as np import pandas as pd @@ -13,6 +13,7 @@ from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor from .audio import SAMPLE_RATE, load_audio from .utils import interpolate_nans +from .types import AlignedTranscriptionResult, SingleSegment, SingleAlignedSegment, SingleWordSegment import nltk LANGUAGES_WITHOUT_SPACES = ["ja", "zh"] @@ -80,14 +81,14 @@ def load_align_model(language_code, device, model_name=None, model_dir=None): def align( - transcript: Iterator[dict], + transcript: Iterator[SingleSegment], model: torch.nn.Module, align_model_metadata: dict, audio: Union[str, np.ndarray, torch.Tensor], device: str, interpolate_method: str = "nearest", return_char_alignments: bool = False, -): +) -> AlignedTranscriptionResult: """ Align phoneme recognition predictions to known transcription. """ @@ -146,7 +147,7 @@ def align( segment["clean_wdx"] = clean_wdx segment["sentence_spans"] = sentence_spans - aligned_segments = [] + aligned_segments: List[SingleAlignedSegment] = [] # 2. Get prediction matrix from alignment model & align for sdx, segment in enumerate(transcript): @@ -154,7 +155,7 @@ def align( t2 = segment["end"] text = segment["text"] - aligned_seg = { + aligned_seg: SingleAlignedSegment = { "start": t1, "end": t2, "text": text, @@ -301,7 +302,7 @@ def align( aligned_segments += aligned_subsegments # create word_segments list - word_segments = [] + word_segments: List[SingleWordSegment] = [] for segment in aligned_segments: word_segments += segment["words"] diff --git a/whisperx/asr.py b/whisperx/asr.py index 21357ec..e131ae1 100644 --- a/whisperx/asr.py +++ b/whisperx/asr.py @@ -11,7 +11,7 @@ from transformers.pipelines.pt_utils import PipelineIterator from .audio import N_SAMPLES, SAMPLE_RATE, load_audio, log_mel_spectrogram from .vad import load_vad_model, merge_chunks - +from .types import TranscriptionResult, SingleSegment def load_model(whisper_arch, device, compute_type="float16", asr_options=None, language=None, vad_options=None, model=None): @@ -215,7 +215,7 @@ class FasterWhisperPipeline(Pipeline): def transcribe( self, audio: Union[str, np.ndarray], batch_size=None, num_workers=0 - ): + ) -> TranscriptionResult: if isinstance(audio, str): audio = load_audio(audio) @@ -237,7 +237,7 @@ class FasterWhisperPipeline(Pipeline): else: language = self.tokenizer.language_code - segments = [] + segments: List[SingleSegment] = [] batch_size = batch_size or self._batch_size for idx, out in enumerate(self.__call__(data(audio, vad_segments), batch_size=batch_size, num_workers=num_workers)): text = out['text'] diff --git a/whisperx/types.py b/whisperx/types.py new file mode 100644 index 0000000..75d4485 --- /dev/null +++ b/whisperx/types.py @@ -0,0 +1,58 @@ +from typing import TypedDict, Optional + + +class SingleWordSegment(TypedDict): + """ + A single word of a speech. + """ + word: str + start: float + end: float + score: float + +class SingleCharSegment(TypedDict): + """ + A single char of a speech. + """ + char: str + start: float + end: float + score: float + + +class SingleSegment(TypedDict): + """ + A single segment (up to multiple sentences) of a speech. + """ + + start: float + end: float + text: str + + +class SingleAlignedSegment(TypedDict): + """ + A single segment (up to multiple sentences) of a speech with word alignment. + """ + + start: float + end: float + text: str + words: list[SingleWordSegment] + chars: Optional[list[SingleCharSegment]] + + +class TranscriptionResult(TypedDict): + """ + A list of segments and word segments of a speech. + """ + segments: list[SingleSegment] + language: str + + +class AlignedTranscriptionResult(TypedDict): + """ + A list of segments and word segments of a speech. + """ + segments: list[SingleAlignedSegment] + word_segments: list[SingleWordSegment]