refactor: add type hints

Barabazs
2025-01-05 11:26:18 +01:00
parent 0f7f9f9f83
commit 9a8967f27e
6 changed files with 111 additions and 57 deletions

View File

@@ -3,7 +3,7 @@ Forced Alignment with Whisper
 C. Max Bain
 """
 from dataclasses import dataclass
-from typing import Iterable, Union, List
+from typing import Iterable, Optional, Union, List
 import numpy as np
 import pandas as pd
@@ -65,7 +65,7 @@ DEFAULT_ALIGN_MODELS_HF = {
 }
-def load_align_model(language_code, device, model_name=None, model_dir=None):
+def load_align_model(language_code: str, device: str, model_name: Optional[str] = None, model_dir=None):
     if model_name is None:
         # use default model
         if language_code in DEFAULT_ALIGN_MODELS_TORCH:
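For reference, the annotated load_align_model matches the usual call pattern; this is an illustrative sketch, not part of the diff. The top-level whisperx import path, the model/metadata return pair, and the example checkpoint name are assumptions based on the package's public API.

import whisperx

device = "cpu"

# language_code: str, device: str, model_name: Optional[str] = None
# returns the wav2vec2 alignment model together with its metadata dict
align_model, metadata = whisperx.load_align_model(language_code="en", device=device)

# an explicit checkpoint can still be passed through the now-typed model_name parameter
de_model, de_metadata = whisperx.load_align_model(
    language_code="de",
    device=device,
    model_name="jonatasgrosman/wav2vec2-large-xlsr-53-german",  # example checkpoint, assumed available
)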

View File

@@ -1,20 +1,20 @@
 import os
 import warnings
-from typing import List, Union, Optional, NamedTuple
+from typing import List, NamedTuple, Optional, Union
 import ctranslate2
 import faster_whisper
 import numpy as np
 import torch
 from faster_whisper.tokenizer import Tokenizer
-from faster_whisper.transcribe import (TranscriptionOptions,
-                                       get_ctranslate2_storage)
+from faster_whisper.transcribe import TranscriptionOptions, get_ctranslate2_storage
 from transformers import Pipeline
 from transformers.pipelines.pt_utils import PipelineIterator
 from .audio import N_SAMPLES, SAMPLE_RATE, load_audio, log_mel_spectrogram
-from .vad import load_vad_model, merge_chunks
-from .types import TranscriptionResult, SingleSegment
+from .types import SingleSegment, TranscriptionResult
+from .vad import VoiceActivitySegmentation, load_vad_model, merge_chunks
 def find_numeral_symbol_tokens(tokenizer):
     numeral_symbol_tokens = []
@@ -104,16 +104,16 @@ class FasterWhisperPipeline(Pipeline):
     def __init__(
         self,
-        model,
-        vad,
+        model: WhisperModel,
+        vad: VoiceActivitySegmentation,
         vad_params: dict,
         options: NamedTuple,
-        tokenizer=None,
+        tokenizer: Optional[Tokenizer] = None,
         device: Union[int, str, "torch.device"] = -1,
         framework="pt",
         language: Optional[str] = None,
         suppress_numerals: bool = False,
-        **kwargs
+        **kwargs,
     ):
         self.model = model
         self.tokenizer = tokenizer
@@ -165,7 +165,13 @@ class FasterWhisperPipeline(Pipeline):
         return model_outputs
     def get_iterator(
-        self, inputs, num_workers: int, batch_size: int, preprocess_params, forward_params, postprocess_params
+        self,
+        inputs,
+        num_workers: int,
+        batch_size: int,
+        preprocess_params: dict,
+        forward_params: dict,
+        postprocess_params: dict,
     ):
         dataset = PipelineIterator(inputs, self.preprocess, preprocess_params)
         if "TOKENIZERS_PARALLELISM" not in os.environ:
@@ -180,7 +186,16 @@ class FasterWhisperPipeline(Pipeline):
         return final_iterator
     def transcribe(
-        self, audio: Union[str, np.ndarray], batch_size=None, num_workers=0, language=None, task=None, chunk_size=30, print_progress = False, combined_progress=False, verbose=False
+        self,
+        audio: Union[str, np.ndarray],
+        batch_size: Optional[int] = None,
+        num_workers=0,
+        language: Optional[str] = None,
+        task: Optional[str] = None,
+        chunk_size=30,
+        print_progress=False,
+        combined_progress=False,
+        verbose=False,
     ) -> TranscriptionResult:
         if isinstance(audio, str):
             audio = load_audio(audio)
@@ -258,8 +273,7 @@ class FasterWhisperPipeline(Pipeline):
         return {"segments": segments, "language": language}
-    def detect_language(self, audio: np.ndarray):
+    def detect_language(self, audio: np.ndarray) -> str:
         if audio.shape[0] < N_SAMPLES:
             print("Warning: audio is shorter than 30s, language detection may be inaccurate.")
         model_n_mels = self.model.feat_kwargs.get("feature_size")
@@ -273,33 +287,36 @@ class FasterWhisperPipeline(Pipeline):
         print(f"Detected language: {language} ({language_probability:.2f}) in first 30s of audio...")
         return language
-def load_model(whisper_arch,
-               device,
+def load_model(
+    whisper_arch: str,
+    device: str,
     device_index=0,
     compute_type="float16",
-    asr_options=None,
+    asr_options: Optional[dict] = None,
     language: Optional[str] = None,
-    vad_model=None,
-    vad_options=None,
+    vad_model: Optional[VoiceActivitySegmentation] = None,
+    vad_options: Optional[dict] = None,
     model: Optional[WhisperModel] = None,
     task="transcribe",
-    download_root=None,
+    download_root: Optional[str] = None,
     local_files_only=False,
-    threads=4):
-    '''Load a Whisper model for inference.
+    threads=4,
+) -> FasterWhisperPipeline:
+    """Load a Whisper model for inference.
     Args:
-        whisper_arch: str - The name of the Whisper model to load.
-        device: str - The device to load the model on.
-        compute_type: str - The compute type to use for the model.
-        options: dict - A dictionary of options to use for the model.
-        language: str - The language of the model. (use English for now)
-        model: Optional[WhisperModel] - The WhisperModel instance to use.
-        download_root: Optional[str] - The root directory to download the model to.
-        local_files_only: bool - If `True`, avoid downloading the file and return the path to the local cached file if it exists.
-        threads: int - The number of cpu threads to use per worker, e.g. will be multiplied by num workers.
+        whisper_arch - The name of the Whisper model to load.
+        device - The device to load the model on.
+        compute_type - The compute type to use for the model.
+        options - A dictionary of options to use for the model.
+        language - The language of the model. (use English for now)
+        model - The WhisperModel instance to use.
+        download_root - The root directory to download the model to.
+        local_files_only - If `True`, avoid downloading the file and return the path to the local cached file if it exists.
+        threads - The number of cpu threads to use per worker, e.g. will be multiplied by num workers.
     Returns:
        A Whisper pipeline.
-    '''
+    """
     if whisper_arch.endswith(".en"):
         language = "en"
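Taken together, the asr annotations describe the intended call pattern. A minimal usage sketch follows, assuming these helpers are re-exported at package level; the model name, audio path, and batch size are placeholders, not part of the commit.

import whisperx

device = "cuda"
audio = whisperx.load_audio("audio.wav")  # placeholder path

# load_model(whisper_arch: str, device: str, ...) -> FasterWhisperPipeline
model = whisperx.load_model("large-v2", device, compute_type="float16")

# transcribe(...) -> TranscriptionResult, i.e. a dict with "segments" and "language"
result = model.transcribe(audio, batch_size=16, print_progress=True)
print(result["language"], len(result["segments"]))

# detect_language(audio: np.ndarray) -> str now advertises its return type as well
print(model.detect_language(audio))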

View File

@@ -22,7 +22,7 @@ FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH)  # 10ms per audio frame
 TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN)  # 20ms per audio token
-def load_audio(file: str, sr: int = SAMPLE_RATE):
+def load_audio(file: str, sr: int = SAMPLE_RATE) -> np.ndarray:
     """
     Open an audio file and read as mono waveform, resampling as necessary

View File

@@ -1,5 +1,8 @@
 # conjunctions.py
+from typing import Set
 conjunctions_by_language = {
     'en': {'and', 'whether', 'or', 'as', 'but', 'so', 'for', 'nor', 'which', 'yet', 'although', 'since', 'unless', 'when', 'while', 'because', 'if', 'how', 'that', 'than', 'who', 'where', 'what', 'near', 'before', 'after', 'across', 'through', 'until', 'once', 'whereas', 'even', 'both', 'either', 'neither', 'though'},
     'fr': {'et', 'ou', 'mais', 'parce', 'bien', 'pendant', 'quand', 'où', 'comme', 'si', 'que', 'avant', 'après', 'aussitôt', 'jusqu’à', 'à', 'malgré', 'donc', 'tant', 'puisque', 'ni', 'soit', 'bien', 'encore', 'dès', 'lorsque'},
@@ -36,8 +39,9 @@ commas_by_language = {
     'ur': '،'
 }
-def get_conjunctions(lang_code):
+def get_conjunctions(lang_code: str) -> Set[str]:
     return conjunctions_by_language.get(lang_code, set())
-def get_comma(lang_code):
-    return commas_by_language.get(lang_code, ',')
+def get_comma(lang_code: str) -> str:
+    return commas_by_language.get(lang_code, ",")
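The two helpers are small enough that the new annotations fully describe their behaviour; a quick sanity-check sketch, assuming the module is importable as whisperx.conjunctions (the "zz" code is just a made-up unknown language):

from whisperx.conjunctions import get_comma, get_conjunctions

# get_conjunctions(lang_code: str) -> Set[str]; unknown codes fall back to an empty set
assert "and" in get_conjunctions("en")
assert get_conjunctions("zz") == set()

# get_comma(lang_code: str) -> str; unknown codes fall back to ","
assert get_comma("ur") == "،"
assert get_comma("zz") == ","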

View File

@@ -5,6 +5,7 @@ from typing import Optional, Union
 import torch
 from .audio import load_audio, SAMPLE_RATE
+from .types import TranscriptionResult, AlignedTranscriptionResult
 class DiarizationPipeline:
@@ -18,7 +19,13 @@ class DiarizationPipeline:
         device = torch.device(device)
         self.model = Pipeline.from_pretrained(model_name, use_auth_token=use_auth_token).to(device)
-    def __call__(self, audio: Union[str, np.ndarray], num_speakers=None, min_speakers=None, max_speakers=None):
+    def __call__(
+        self,
+        audio: Union[str, np.ndarray],
+        num_speakers: Optional[int] = None,
+        min_speakers: Optional[int] = None,
+        max_speakers: Optional[int] = None,
+    ):
         if isinstance(audio, str):
             audio = load_audio(audio)
         audio_data = {
@@ -32,7 +39,11 @@ class DiarizationPipeline:
         return diarize_df
-def assign_word_speakers(diarize_df, transcript_result, fill_nearest=False):
+def assign_word_speakers(
+    diarize_df: pd.DataFrame,
+    transcript_result: Union[AlignedTranscriptionResult, TranscriptionResult],
+    fill_nearest=False,
+) -> dict:
     transcript_segments = transcript_result["segments"]
     for seg in transcript_segments:
         # assign speaker to segment (if any)
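For context, the annotated __call__ and assign_word_speakers map onto the usual diarization flow. This is an illustrative sketch assuming the package-level re-exports; the Hugging Face token, audio path, model name, and speaker counts are placeholders.

import whisperx

device = "cuda"
audio = whisperx.load_audio("audio.wav")  # placeholder path

model = whisperx.load_model("large-v2", device)
result = model.transcribe(audio, batch_size=16)

# num_speakers / min_speakers / max_speakers are now typed as Optional[int]
diarize_model = whisperx.DiarizationPipeline(use_auth_token="hf_xxx", device=device)  # placeholder token
diarize_segments = diarize_model(audio, min_speakers=1, max_speakers=2)

# assign_word_speakers accepts a plain TranscriptionResult or an AlignedTranscriptionResult
result = whisperx.assign_word_speakers(diarize_segments, result)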

View File

@@ -10,8 +10,15 @@ from .alignment import align, load_align_model
 from .asr import load_model
 from .audio import load_audio
 from .diarize import DiarizationPipeline, assign_word_speakers
-from .utils import (LANGUAGES, TO_LANGUAGE_CODE, get_writer, optional_float,
-                    optional_int, str2bool)
+from .types import AlignedTranscriptionResult, TranscriptionResult
+from .utils import (
+    LANGUAGES,
+    TO_LANGUAGE_CODE,
+    get_writer,
+    optional_float,
+    optional_int,
+    str2bool,
+)
 def cli():
@@ -174,7 +181,13 @@ def cli():
         audio = load_audio(audio_path)
         # >> VAD & ASR
         print(">>Performing transcription...")
-        result = model.transcribe(audio, batch_size=batch_size, chunk_size=chunk_size, print_progress=print_progress, verbose=verbose)
+        result: TranscriptionResult = model.transcribe(
+            audio,
+            batch_size=batch_size,
+            chunk_size=chunk_size,
+            print_progress=print_progress,
+            verbose=verbose,
+        )
         results.append((result, audio_path))
     # Unload Whisper and VAD
@@ -201,7 +214,16 @@ def cli():
             print(f"New language found ({result['language']})! Previous was ({align_metadata['language']}), loading new alignment model for new language...")
             align_model, align_metadata = load_align_model(result["language"], device)
         print(">>Performing alignment...")
-        result = align(result["segments"], align_model, align_metadata, input_audio, device, interpolate_method=interpolate_method, return_char_alignments=return_char_alignments, print_progress=print_progress)
+        result: AlignedTranscriptionResult = align(
+            result["segments"],
+            align_model,
+            align_metadata,
+            input_audio,
+            device,
+            interpolate_method=interpolate_method,
+            return_char_alignments=return_char_alignments,
+            print_progress=print_progress,
+        )
         results.append((result, audio_path))