mirror of
https://github.com/m-bain/whisperX.git
synced 2025-07-01 18:17:27 -04:00
refactor: add type hints
This commit is contained in:
@ -5,6 +5,7 @@ from typing import Optional, Union
|
||||
import torch
|
||||
|
||||
from .audio import load_audio, SAMPLE_RATE
|
||||
from .types import TranscriptionResult, AlignedTranscriptionResult
|
||||
|
||||
|
||||
class DiarizationPipeline:
|
||||
@ -18,7 +19,13 @@ class DiarizationPipeline:
|
||||
device = torch.device(device)
|
||||
self.model = Pipeline.from_pretrained(model_name, use_auth_token=use_auth_token).to(device)
|
||||
|
||||
def __call__(self, audio: Union[str, np.ndarray], num_speakers=None, min_speakers=None, max_speakers=None):
|
||||
def __call__(
|
||||
self,
|
||||
audio: Union[str, np.ndarray],
|
||||
num_speakers: Optional[int] = None,
|
||||
min_speakers: Optional[int] = None,
|
||||
max_speakers: Optional[int] = None,
|
||||
):
|
||||
if isinstance(audio, str):
|
||||
audio = load_audio(audio)
|
||||
audio_data = {
|
||||
@ -32,7 +39,11 @@ class DiarizationPipeline:
|
||||
return diarize_df
|
||||
|
||||
|
||||
def assign_word_speakers(diarize_df, transcript_result, fill_nearest=False):
|
||||
def assign_word_speakers(
|
||||
diarize_df: pd.DataFrame,
|
||||
transcript_result: Union[AlignedTranscriptionResult, TranscriptionResult],
|
||||
fill_nearest=False,
|
||||
) -> dict:
|
||||
transcript_segments = transcript_result["segments"]
|
||||
for seg in transcript_segments:
|
||||
# assign speaker to segment (if any)
|
||||
|
Reference in New Issue
Block a user