mirror of
https://github.com/m-bain/whisperX.git
synced 2025-07-01 18:17:27 -04:00
Merge pull request #235 from sorgfresser/main
Add custom typing for results
This commit is contained in:
@ -3,7 +3,7 @@ Forced Alignment with Whisper
|
||||
C. Max Bain
|
||||
"""
|
||||
from dataclasses import dataclass
|
||||
from typing import Iterator, Union
|
||||
from typing import Iterator, Union, List
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
@ -13,6 +13,7 @@ from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
|
||||
|
||||
from .audio import SAMPLE_RATE, load_audio
|
||||
from .utils import interpolate_nans
|
||||
from .types import AlignedTranscriptionResult, SingleSegment, SingleAlignedSegment, SingleWordSegment
|
||||
import nltk
|
||||
|
||||
LANGUAGES_WITHOUT_SPACES = ["ja", "zh"]
|
||||
@ -80,14 +81,14 @@ def load_align_model(language_code, device, model_name=None, model_dir=None):
|
||||
|
||||
|
||||
def align(
|
||||
transcript: Iterator[dict],
|
||||
transcript: Iterator[SingleSegment],
|
||||
model: torch.nn.Module,
|
||||
align_model_metadata: dict,
|
||||
audio: Union[str, np.ndarray, torch.Tensor],
|
||||
device: str,
|
||||
interpolate_method: str = "nearest",
|
||||
return_char_alignments: bool = False,
|
||||
):
|
||||
) -> AlignedTranscriptionResult:
|
||||
"""
|
||||
Align phoneme recognition predictions to known transcription.
|
||||
"""
|
||||
@ -146,7 +147,7 @@ def align(
|
||||
segment["clean_wdx"] = clean_wdx
|
||||
segment["sentence_spans"] = sentence_spans
|
||||
|
||||
aligned_segments = []
|
||||
aligned_segments: List[SingleAlignedSegment] = []
|
||||
|
||||
# 2. Get prediction matrix from alignment model & align
|
||||
for sdx, segment in enumerate(transcript):
|
||||
@ -154,7 +155,7 @@ def align(
|
||||
t2 = segment["end"]
|
||||
text = segment["text"]
|
||||
|
||||
aligned_seg = {
|
||||
aligned_seg: SingleAlignedSegment = {
|
||||
"start": t1,
|
||||
"end": t2,
|
||||
"text": text,
|
||||
@ -301,7 +302,7 @@ def align(
|
||||
aligned_segments += aligned_subsegments
|
||||
|
||||
# create word_segments list
|
||||
word_segments = []
|
||||
word_segments: List[SingleWordSegment] = []
|
||||
for segment in aligned_segments:
|
||||
word_segments += segment["words"]
|
||||
|
||||
|
@ -11,7 +11,7 @@ from transformers.pipelines.pt_utils import PipelineIterator
|
||||
|
||||
from .audio import N_SAMPLES, SAMPLE_RATE, load_audio, log_mel_spectrogram
|
||||
from .vad import load_vad_model, merge_chunks
|
||||
|
||||
from .types import TranscriptionResult, SingleSegment
|
||||
|
||||
def load_model(whisper_arch, device, compute_type="float16", asr_options=None, language=None,
|
||||
vad_options=None, model=None):
|
||||
@ -215,7 +215,7 @@ class FasterWhisperPipeline(Pipeline):
|
||||
|
||||
def transcribe(
|
||||
self, audio: Union[str, np.ndarray], batch_size=None, num_workers=0
|
||||
):
|
||||
) -> TranscriptionResult:
|
||||
if isinstance(audio, str):
|
||||
audio = load_audio(audio)
|
||||
|
||||
@ -237,7 +237,7 @@ class FasterWhisperPipeline(Pipeline):
|
||||
else:
|
||||
language = self.tokenizer.language_code
|
||||
|
||||
segments = []
|
||||
segments: List[SingleSegment] = []
|
||||
batch_size = batch_size or self._batch_size
|
||||
for idx, out in enumerate(self.__call__(data(audio, vad_segments), batch_size=batch_size, num_workers=num_workers)):
|
||||
text = out['text']
|
||||
|
58
whisperx/types.py
Normal file
58
whisperx/types.py
Normal file
@ -0,0 +1,58 @@
|
||||
from typing import List, Optional, TypedDict
|
||||
|
||||
|
||||
class SingleWordSegment(TypedDict):
    """
    Dictionary shape describing one word of a transcript.

    All four keys are required.
    """

    # The word text.
    word: str
    # Start position of the word (presumably seconds into the audio -- TODO confirm).
    start: float
    # End position of the word, same unit as ``start``.
    end: float
    # Score attached to the word (presumably an alignment confidence -- TODO confirm).
    score: float
|
||||
|
||||
class SingleCharSegment(TypedDict):
    """
    Dictionary shape describing one character of a transcript.

    All four keys are required.
    """

    # The character text.
    char: str
    # Start position of the character (presumably seconds into the audio -- TODO confirm).
    start: float
    # End position of the character, same unit as ``start``.
    end: float
    # Score attached to the character (presumably an alignment confidence -- TODO confirm).
    score: float
|
||||
|
||||
|
||||
class SingleSegment(TypedDict):
    """
    Dictionary shape describing one transcript segment.

    A segment may cover anything up to several sentences. All three keys
    are required.
    """

    # Start position of the segment (presumably seconds into the audio -- TODO confirm).
    start: float
    # End position of the segment, same unit as ``start``.
    end: float
    # Transcribed text of the segment.
    text: str
|
||||
|
||||
|
||||
class SingleAlignedSegment(TypedDict):
    """
    A single segment (up to multiple sentences) of a speech with word alignment.

    Same keys as a plain segment (``start``, ``end``, ``text``) plus the
    per-word and, optionally, per-character entries for that segment.
    """

    # Start position of the segment (presumably seconds into the audio -- TODO confirm).
    start: float
    # End position of the segment, same unit as ``start``.
    end: float
    # Transcribed text of the segment.
    text: str
    # ``typing.List`` rather than the builtin ``list[...]``: annotations in a
    # class body are evaluated when the class is created, and subscripting the
    # builtin raises TypeError before Python 3.9.
    words: List[SingleWordSegment]
    # NOTE(review): declared Optional, i.e. the value may be None. Whether the
    # key can also be absent when character alignments are not requested is not
    # visible from this file -- confirm against the producer.
    chars: Optional[List[SingleCharSegment]]
|
||||
|
||||
|
||||
class TranscriptionResult(TypedDict):
    """
    A list of (unaligned) segments of a speech together with its language.
    """

    # ``typing.List`` rather than the builtin ``list[...]``: annotations in a
    # class body are evaluated when the class is created, and subscripting the
    # builtin raises TypeError before Python 3.9.
    segments: List[SingleSegment]
    # Language of the transcription (presumably a language code such as "en" -- TODO confirm).
    language: str
|
||||
|
||||
|
||||
class AlignedTranscriptionResult(TypedDict):
    """
    A list of aligned segments and a list of word segments of a speech.
    """

    # ``typing.List`` rather than the builtin ``list[...]``: annotations in a
    # class body are evaluated when the class is created, and subscripting the
    # builtin raises TypeError before Python 3.9.
    segments: List[SingleAlignedSegment]
    # Word entries for the whole result (presumably the words of all segments,
    # flattened -- TODO confirm against the producer).
    word_segments: List[SingleWordSegment]
|
Reference in New Issue
Block a user