mirror of https://github.com/m-bain/whisperX.git (synced 2025-07-01 18:17:27 -04:00)

Merge pull request #235 from sorgfresser/main
Add custom typing for results
whisperx/alignment.py

@@ -3,7 +3,7 @@ Forced Alignment with Whisper
 C. Max Bain
 """
 from dataclasses import dataclass
-from typing import Iterator, Union
+from typing import Iterator, Union, List

 import numpy as np
 import pandas as pd
@@ -13,6 +13,7 @@ from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

 from .audio import SAMPLE_RATE, load_audio
 from .utils import interpolate_nans
+from .types import AlignedTranscriptionResult, SingleSegment, SingleAlignedSegment, SingleWordSegment
 import nltk

 LANGUAGES_WITHOUT_SPACES = ["ja", "zh"]
@@ -80,14 +81,14 @@ def load_align_model(language_code, device, model_name=None, model_dir=None):


 def align(
-    transcript: Iterator[dict],
+    transcript: Iterator[SingleSegment],
     model: torch.nn.Module,
     align_model_metadata: dict,
     audio: Union[str, np.ndarray, torch.Tensor],
     device: str,
     interpolate_method: str = "nearest",
     return_char_alignments: bool = False,
-):
+) -> AlignedTranscriptionResult:
     """
     Align phoneme recognition predictions to known transcription.
     """
@@ -146,7 +147,7 @@ def align(
         segment["clean_wdx"] = clean_wdx
         segment["sentence_spans"] = sentence_spans

-    aligned_segments = []
+    aligned_segments: List[SingleAlignedSegment] = []

     # 2. Get prediction matrix from alignment model & align
     for sdx, segment in enumerate(transcript):
@@ -154,7 +155,7 @@ def align(
         t2 = segment["end"]
         text = segment["text"]

-        aligned_seg = {
+        aligned_seg: SingleAlignedSegment = {
             "start": t1,
             "end": t2,
             "text": text,
@@ -301,7 +302,7 @@ def align(
         aligned_segments += aligned_subsegments

     # create word_segments list
-    word_segments = []
+    word_segments: List[SingleWordSegment] = []
     for segment in aligned_segments:
         word_segments += segment["words"]

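A minimal usage sketch of the newly typed align() API (not part of this commit): it assumes the package-level whisperx.load_audio, whisperx.load_align_model, and whisperx.align helpers exported by the project at this point, and the audio path, language code, and segment values are placeholders.

import whisperx

device = "cpu"
audio = whisperx.load_audio("audio.wav")  # placeholder file
# SingleSegment-shaped dicts; normally these come from a prior transcription pass
segments = [{"start": 0.0, "end": 2.5, "text": "hello world"}]
align_model, metadata = whisperx.load_align_model(language_code="en", device=device)
aligned = whisperx.align(segments, align_model, metadata, audio, device)
# aligned is an AlignedTranscriptionResult: typed "segments" plus a flat "word_segments" list
for word in aligned["word_segments"]:
    # .get() is used defensively: words the aligner cannot place may omit timing keys
    print(word["word"], word.get("start"), word.get("end"), word.get("score"))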
whisperx/asr.py

@@ -11,7 +11,7 @@ from transformers.pipelines.pt_utils import PipelineIterator

 from .audio import N_SAMPLES, SAMPLE_RATE, load_audio, log_mel_spectrogram
 from .vad import load_vad_model, merge_chunks
+from .types import TranscriptionResult, SingleSegment

 def load_model(whisper_arch, device, compute_type="float16", asr_options=None, language=None,
                vad_options=None, model=None):

@@ -215,7 +215,7 @@ class FasterWhisperPipeline(Pipeline):

     def transcribe(
         self, audio: Union[str, np.ndarray], batch_size=None, num_workers=0
-    ):
+    ) -> TranscriptionResult:
         if isinstance(audio, str):
             audio = load_audio(audio)

@@ -237,7 +237,7 @@ class FasterWhisperPipeline(Pipeline):
         else:
             language = self.tokenizer.language_code

-        segments = []
+        segments: List[SingleSegment] = []
         batch_size = batch_size or self._batch_size
         for idx, out in enumerate(self.__call__(data(audio, vad_segments), batch_size=batch_size, num_workers=num_workers)):
             text = out['text']
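An illustrative sketch of the transcription side (not part of this commit): with the new annotation, FasterWhisperPipeline.transcribe() is declared to return a TranscriptionResult, i.e. a dict with "segments" (SingleSegment entries) and "language". The model size, device, compute type, and audio path below are placeholder choices.

import whisperx

model = whisperx.load_model("small", "cpu", compute_type="int8")
result = model.transcribe("audio.wav", batch_size=8)  # -> TranscriptionResult
print(result["language"])
for seg in result["segments"]:  # each seg is a SingleSegment
    print(f'[{seg["start"]:.2f} -> {seg["end"]:.2f}] {seg["text"]}')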
whisperx/types.py (new file, 58 lines)

@@ -0,0 +1,58 @@
+from typing import TypedDict, Optional
+
+
+class SingleWordSegment(TypedDict):
+    """
+    A single word of a speech.
+    """
+    word: str
+    start: float
+    end: float
+    score: float
+
+class SingleCharSegment(TypedDict):
+    """
+    A single char of a speech.
+    """
+    char: str
+    start: float
+    end: float
+    score: float
+
+
+class SingleSegment(TypedDict):
+    """
+    A single segment (up to multiple sentences) of a speech.
+    """
+
+    start: float
+    end: float
+    text: str
+
+
+class SingleAlignedSegment(TypedDict):
+    """
+    A single segment (up to multiple sentences) of a speech with word alignment.
+    """
+
+    start: float
+    end: float
+    text: str
+    words: list[SingleWordSegment]
+    chars: Optional[list[SingleCharSegment]]
+
+
+class TranscriptionResult(TypedDict):
+    """
+    A list of segments and word segments of a speech.
+    """
+    segments: list[SingleSegment]
+    language: str
+
+
+class AlignedTranscriptionResult(TypedDict):
+    """
+    A list of segments and word segments of a speech.
+    """
+    segments: list[SingleAlignedSegment]
+    word_segments: list[SingleWordSegment]
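A short consumption sketch (not part of the commit): because these are TypedDicts, they describe dict shapes for static checkers without adding any runtime wrapper, so existing code that builds and reads plain dicts keeps working. The values below are made-up examples.

from whisperx.types import SingleSegment, SingleWordSegment, TranscriptionResult

seg: SingleSegment = {"start": 0.0, "end": 1.2, "text": "hello"}
word: SingleWordSegment = {"word": "hello", "start": 0.0, "end": 0.4, "score": 0.98}
result: TranscriptionResult = {"segments": [seg], "language": "en"}

assert isinstance(result, dict)  # a TypedDict instance is just a dict at runtime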