mirror of
https://github.com/m-bain/whisperX.git
synced 2025-07-01 18:17:27 -04:00
style: minor code formatting
This commit is contained in:
@ -30,16 +30,16 @@ class DiarizationPipeline:
|
|||||||
) -> Union[tuple[pd.DataFrame, Optional[dict[str, list[float]]]], pd.DataFrame]:
|
) -> Union[tuple[pd.DataFrame, Optional[dict[str, list[float]]]], pd.DataFrame]:
|
||||||
"""
|
"""
|
||||||
Perform speaker diarization on audio.
|
Perform speaker diarization on audio.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
audio: Path to audio file or audio array
|
audio: Path to audio file or audio array
|
||||||
num_speakers: Exact number of speakers (if known)
|
num_speakers: Exact number of speakers (if known)
|
||||||
min_speakers: Minimum number of speakers to detect
|
min_speakers: Minimum number of speakers to detect
|
||||||
max_speakers: Maximum number of speakers to detect
|
max_speakers: Maximum number of speakers to detect
|
||||||
return_embeddings: Whether to return speaker embeddings
|
return_embeddings: Whether to return speaker embeddings
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
If return_embeddings is True:
|
If return_embeddings is True:
|
||||||
Tuple of (diarization dataframe, speaker embeddings dictionary)
|
Tuple of (diarization dataframe, speaker embeddings dictionary)
|
||||||
Otherwise:
|
Otherwise:
|
||||||
Just the diarization dataframe
|
Just the diarization dataframe
|
||||||
@ -53,18 +53,18 @@ class DiarizationPipeline:
|
|||||||
|
|
||||||
if return_embeddings:
|
if return_embeddings:
|
||||||
diarization, embeddings = self.model(
|
diarization, embeddings = self.model(
|
||||||
audio_data,
|
audio_data,
|
||||||
num_speakers=num_speakers,
|
num_speakers=num_speakers,
|
||||||
min_speakers=min_speakers,
|
min_speakers=min_speakers,
|
||||||
max_speakers=max_speakers,
|
max_speakers=max_speakers,
|
||||||
return_embeddings=True
|
return_embeddings=True,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
diarization = self.model(
|
diarization = self.model(
|
||||||
audio_data,
|
audio_data,
|
||||||
num_speakers=num_speakers,
|
num_speakers=num_speakers,
|
||||||
min_speakers=min_speakers,
|
min_speakers=min_speakers,
|
||||||
max_speakers=max_speakers
|
max_speakers=max_speakers,
|
||||||
)
|
)
|
||||||
embeddings = None
|
embeddings = None
|
||||||
|
|
||||||
@ -91,13 +91,13 @@ def assign_word_speakers(
|
|||||||
) -> Union[AlignedTranscriptionResult, TranscriptionResult]:
|
) -> Union[AlignedTranscriptionResult, TranscriptionResult]:
|
||||||
"""
|
"""
|
||||||
Assign speakers to words and segments in the transcript.
|
Assign speakers to words and segments in the transcript.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
diarize_df: Diarization dataframe from DiarizationPipeline
|
diarize_df: Diarization dataframe from DiarizationPipeline
|
||||||
transcript_result: Transcription result to augment with speaker labels
|
transcript_result: Transcription result to augment with speaker labels
|
||||||
speaker_embeddings: Optional dictionary mapping speaker IDs to embedding vectors
|
speaker_embeddings: Optional dictionary mapping speaker IDs to embedding vectors
|
||||||
fill_nearest: If True, assign speakers even when there's no direct time overlap
|
fill_nearest: If True, assign speakers even when there's no direct time overlap
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Updated transcript_result with speaker assignments and optionally embeddings
|
Updated transcript_result with speaker assignments and optionally embeddings
|
||||||
"""
|
"""
|
||||||
@ -131,12 +131,12 @@ def assign_word_speakers(
|
|||||||
# sum over speakers
|
# sum over speakers
|
||||||
speaker = dia_tmp.groupby("speaker")["intersection"].sum().sort_values(ascending=False).index[0]
|
speaker = dia_tmp.groupby("speaker")["intersection"].sum().sort_values(ascending=False).index[0]
|
||||||
word["speaker"] = speaker
|
word["speaker"] = speaker
|
||||||
|
|
||||||
# Add speaker embeddings to the result if provided
|
# Add speaker embeddings to the result if provided
|
||||||
if speaker_embeddings is not None:
|
if speaker_embeddings is not None:
|
||||||
transcript_result["speaker_embeddings"] = speaker_embeddings
|
transcript_result["speaker_embeddings"] = speaker_embeddings
|
||||||
|
|
||||||
return transcript_result
|
return transcript_result
|
||||||
|
|
||||||
|
|
||||||
class Segment:
|
class Segment:
|
||||||
|
Reference in New Issue
Block a user