From 8de0e2af516f390126585058f00f84b7a5301578 Mon Sep 17 00:00:00 2001 From: Dudu Asulin <46293514+davidas1@users.noreply.github.com> Date: Wed, 2 Aug 2023 10:11:43 +0300 Subject: [PATCH 1/5] make diarization faster --- README.md | 2 +- whisperx/diarize.py | 8 +++++++- whisperx/transcribe.py | 3 ++- 3 files changed, 10 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index b52401b..a497604 100644 --- a/README.md +++ b/README.md @@ -177,7 +177,7 @@ print(result["segments"]) # after alignment diarize_model = whisperx.DiarizationPipeline(use_auth_token=YOUR_HF_TOKEN, device=device) # add min/max number of speakers if known -diarize_segments = diarize_model(audio_file) +diarize_segments = diarize_model(audio) # diarize_model(audio_file, min_speakers=min_speakers, max_speakers=max_speakers) result = whisperx.assign_word_speakers(diarize_segments, result) diff --git a/whisperx/diarize.py b/whisperx/diarize.py index 320d2a4..2a9bd69 100644 --- a/whisperx/diarize.py +++ b/whisperx/diarize.py @@ -4,6 +4,8 @@ from pyannote.audio import Pipeline from typing import Optional, Union import torch +from .audio import SAMPLE_RATE + class DiarizationPipeline: def __init__( self, @@ -16,7 +18,11 @@ class DiarizationPipeline: self.model = Pipeline.from_pretrained(model_name, use_auth_token=use_auth_token).to(device) def __call__(self, audio, min_speakers=None, max_speakers=None): - segments = self.model(audio, min_speakers=min_speakers, max_speakers=max_speakers) + audio_data = { + 'waveform': torch.from_numpy(audio[None, :]), + 'sample_rate': SAMPLE_RATE + } + segments = self.model(audio_data, min_speakers=min_speakers, max_speakers=max_speakers) diarize_df = pd.DataFrame(segments.itertracks(yield_label=True)) diarize_df['start'] = diarize_df[0].apply(lambda x: x.start) diarize_df['end'] = diarize_df[0].apply(lambda x: x.end) diff --git a/whisperx/transcribe.py b/whisperx/transcribe.py index 1cc144e..be2dfaf 100644 --- a/whisperx/transcribe.py +++ b/whisperx/transcribe.py @@ -202,7 +202,8 @@ def cli(): results = [] diarize_model = DiarizationPipeline(use_auth_token=hf_token, device=device) for result, input_audio_path in tmp_results: - diarize_segments = diarize_model(input_audio_path, min_speakers=min_speakers, max_speakers=max_speakers) + audio = load_audio(input_audio_path) + diarize_segments = diarize_model(audio, min_speakers=min_speakers, max_speakers=max_speakers) result = assign_word_speakers(diarize_segments, result) results.append((result, input_audio_path)) # >> Write From 7eb9692cb95f21e1a06913c28f656cf5280cff23 Mon Sep 17 00:00:00 2001 From: Dudu Asulin <46293514+davidas1@users.noreply.github.com> Date: Wed, 2 Aug 2023 10:32:02 +0300 Subject: [PATCH 2/5] more --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index a497604..3be1a3c 100644 --- a/README.md +++ b/README.md @@ -178,7 +178,7 @@ diarize_model = whisperx.DiarizationPipeline(use_auth_token=YOUR_HF_TOKEN, devic # add min/max number of speakers if known diarize_segments = diarize_model(audio) -# diarize_model(audio_file, min_speakers=min_speakers, max_speakers=max_speakers) +# diarize_model(audio, min_speakers=min_speakers, max_speakers=max_speakers) result = whisperx.assign_word_speakers(diarize_segments, result) print(diarize_segments) From da6ed83dc98c18256dcd27942d8d211f810c53c0 Mon Sep 17 00:00:00 2001 From: Dudu Asulin <46293514+davidas1@users.noreply.github.com> Date: Wed, 2 Aug 2023 10:34:42 +0300 Subject: [PATCH 3/5] more --- whisperx/diarize.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/whisperx/diarize.py b/whisperx/diarize.py index 2a9bd69..eae6a19 100644 --- a/whisperx/diarize.py +++ b/whisperx/diarize.py @@ -4,7 +4,7 @@ from pyannote.audio import Pipeline from typing import Optional, Union import torch -from .audio import SAMPLE_RATE +from .audio import load_audio, SAMPLE_RATE class DiarizationPipeline: def __init__( @@ -18,6 +18,8 @@ class DiarizationPipeline: self.model = Pipeline.from_pretrained(model_name, use_auth_token=use_auth_token).to(device) def __call__(self, audio, min_speakers=None, max_speakers=None): + if isinstance(audio, str): + audio = load_audio(audio) audio_data = { 'waveform': torch.from_numpy(audio[None, :]), 'sample_rate': SAMPLE_RATE From 577db33430a5b7691ff3476e849846b1f3dff87a Mon Sep 17 00:00:00 2001 From: Dudu Asulin <46293514+davidas1@users.noreply.github.com> Date: Wed, 2 Aug 2023 10:35:20 +0300 Subject: [PATCH 4/5] more --- whisperx/transcribe.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/whisperx/transcribe.py b/whisperx/transcribe.py index be2dfaf..1cc144e 100644 --- a/whisperx/transcribe.py +++ b/whisperx/transcribe.py @@ -202,8 +202,7 @@ def cli(): results = [] diarize_model = DiarizationPipeline(use_auth_token=hf_token, device=device) for result, input_audio_path in tmp_results: - audio = load_audio(input_audio_path) - diarize_segments = diarize_model(audio, min_speakers=min_speakers, max_speakers=max_speakers) + diarize_segments = diarize_model(input_audio_path, min_speakers=min_speakers, max_speakers=max_speakers) result = assign_word_speakers(diarize_segments, result) results.append((result, input_audio_path)) # >> Write From 9e3145ceadc6f9826a9374334c647fc3ce92b093 Mon Sep 17 00:00:00 2001 From: Dudu Asulin <46293514+davidas1@users.noreply.github.com> Date: Wed, 2 Aug 2023 10:36:56 +0300 Subject: [PATCH 5/5] more --- whisperx/diarize.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/whisperx/diarize.py b/whisperx/diarize.py index eae6a19..e50dc0f 100644 --- a/whisperx/diarize.py +++ b/whisperx/diarize.py @@ -17,7 +17,7 @@ class DiarizationPipeline: device = torch.device(device) self.model = Pipeline.from_pretrained(model_name, use_auth_token=use_auth_token).to(device) - def __call__(self, audio, min_speakers=None, max_speakers=None): + def __call__(self, audio: Union[str, np.ndarray], min_speakers=None, max_speakers=None): if isinstance(audio, str): audio = load_audio(audio) audio_data = {