refactor: add type hints

This commit is contained in:
Barabazs
2025-01-05 11:26:18 +01:00
parent 0f7f9f9f83
commit 9a8967f27e
6 changed files with 111 additions and 57 deletions

View File

@@ -5,6 +5,7 @@ from typing import Optional, Union
import torch
from .audio import load_audio, SAMPLE_RATE
from .types import TranscriptionResult, AlignedTranscriptionResult
class DiarizationPipeline:
@@ -18,7 +19,13 @@ class DiarizationPipeline:
device = torch.device(device)
self.model = Pipeline.from_pretrained(model_name, use_auth_token=use_auth_token).to(device)
def __call__(self, audio: Union[str, np.ndarray], num_speakers=None, min_speakers=None, max_speakers=None):
def __call__(
self,
audio: Union[str, np.ndarray],
num_speakers: Optional[int] = None,
min_speakers: Optional[int] = None,
max_speakers: Optional[int] = None,
):
if isinstance(audio, str):
audio = load_audio(audio)
audio_data = {
@@ -32,7 +39,11 @@ class DiarizationPipeline:
return diarize_df
def assign_word_speakers(diarize_df, transcript_result, fill_nearest=False):
def assign_word_speakers(
diarize_df: pd.DataFrame,
transcript_result: Union[AlignedTranscriptionResult, TranscriptionResult],
fill_nearest=False,
) -> dict:
transcript_segments = transcript_result["segments"]
for seg in transcript_segments:
# assign speaker to segment (if any)