mirror of
https://github.com/m-bain/whisperX.git
synced 2025-07-01 18:17:27 -04:00
Compare commits
7 Commits
v3.3.4
...
1843f3553a
Author | SHA1 | Date | |
---|---|---|---|
1843f3553a | |||
d700b56c9c | |||
b343241253 | |||
6fe0a8784a | |||
5012650d0f | |||
108bd0c400 | |||
c72c627d10 |
3
.github/workflows/build-and-release.yml
vendored
3
.github/workflows/build-and-release.yml
vendored
@ -17,6 +17,9 @@ jobs:
|
|||||||
version: "0.5.14"
|
version: "0.5.14"
|
||||||
python-version: "3.9"
|
python-version: "3.9"
|
||||||
|
|
||||||
|
- name: Check if lockfile is up to date
|
||||||
|
run: uv lock --check
|
||||||
|
|
||||||
- name: Build package
|
- name: Build package
|
||||||
run: uv build
|
run: uv build
|
||||||
|
|
||||||
|
3
.github/workflows/python-compatibility.yml
vendored
3
.github/workflows/python-compatibility.yml
vendored
@ -23,6 +23,9 @@ jobs:
|
|||||||
version: "0.5.14"
|
version: "0.5.14"
|
||||||
python-version: ${{ matrix.python-version }}
|
python-version: ${{ matrix.python-version }}
|
||||||
|
|
||||||
|
- name: Check if lockfile is up to date
|
||||||
|
run: uv lock --check
|
||||||
|
|
||||||
- name: Install the project
|
- name: Install the project
|
||||||
run: uv sync --all-extras
|
run: uv sync --all-extras
|
||||||
|
|
||||||
|
23
README.md
23
README.md
@ -97,6 +97,25 @@ uv sync --all-extras --dev
|
|||||||
|
|
||||||
You may also need to install ffmpeg, rust etc. Follow openAI instructions here https://github.com/openai/whisper#setup.
|
You may also need to install ffmpeg, rust etc. Follow openAI instructions here https://github.com/openai/whisper#setup.
|
||||||
|
|
||||||
|
### Common Issues & Troubleshooting 🔧
|
||||||
|
|
||||||
|
#### libcudnn Dependencies (GPU Users)
|
||||||
|
|
||||||
|
If you're using WhisperX with GPU support and encounter errors like:
|
||||||
|
|
||||||
|
- `Could not load library libcudnn_ops_infer.so.8`
|
||||||
|
- `Unable to load any of {libcudnn_cnn.so.9.1.0, libcudnn_cnn.so.9.1, libcudnn_cnn.so.9, libcudnn_cnn.so}`
|
||||||
|
- `libcudnn_ops_infer.so.8: cannot open shared object file: No such file or directory`
|
||||||
|
|
||||||
|
This means your system is missing the CUDA Deep Neural Network library (cuDNN). This library is needed for GPU acceleration but isn't always installed by default.
|
||||||
|
|
||||||
|
**Install cuDNN (example for apt based systems):**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
sudo apt update
|
||||||
|
sudo apt install libcudnn8 libcudnn8-dev -y
|
||||||
|
```
|
||||||
|
|
||||||
### Speaker Diarization
|
### Speaker Diarization
|
||||||
|
|
||||||
To **enable Speaker Diarization**, include your Hugging Face access token (read) that you can generate from [Here](https://huggingface.co/settings/tokens) after the `--hf_token` argument and accept the user agreement for the following models: [Segmentation](https://huggingface.co/pyannote/segmentation-3.0) and [Speaker-Diarization-3.1](https://huggingface.co/pyannote/speaker-diarization-3.1) (if you choose to use Speaker-Diarization 2.x, follow requirements [here](https://huggingface.co/pyannote/speaker-diarization) instead.)
|
To **enable Speaker Diarization**, include your Hugging Face access token (read) that you can generate from [Here](https://huggingface.co/settings/tokens) after the `--hf_token` argument and accept the user agreement for the following models: [Segmentation](https://huggingface.co/pyannote/segmentation-3.0) and [Speaker-Diarization-3.1](https://huggingface.co/pyannote/speaker-diarization-3.1) (if you choose to use Speaker-Diarization 2.x, follow requirements [here](https://huggingface.co/pyannote/speaker-diarization) instead.)
|
||||||
@ -170,7 +189,7 @@ result = model.transcribe(audio, batch_size=batch_size)
|
|||||||
print(result["segments"]) # before alignment
|
print(result["segments"]) # before alignment
|
||||||
|
|
||||||
# delete model if low on GPU resources
|
# delete model if low on GPU resources
|
||||||
# import gc; gc.collect(); torch.cuda.empty_cache(); del model
|
# import gc; import torch; gc.collect(); torch.cuda.empty_cache(); del model
|
||||||
|
|
||||||
# 2. Align whisper output
|
# 2. Align whisper output
|
||||||
model_a, metadata = whisperx.load_align_model(language_code=result["language"], device=device)
|
model_a, metadata = whisperx.load_align_model(language_code=result["language"], device=device)
|
||||||
@ -179,7 +198,7 @@ result = whisperx.align(result["segments"], model_a, metadata, audio, device, re
|
|||||||
print(result["segments"]) # after alignment
|
print(result["segments"]) # after alignment
|
||||||
|
|
||||||
# delete model if low on GPU resources
|
# delete model if low on GPU resources
|
||||||
# import gc; gc.collect(); torch.cuda.empty_cache(); del model_a
|
# import gc; import torch; gc.collect(); torch.cuda.empty_cache(); del model_a
|
||||||
|
|
||||||
# 3. Assign speaker labels
|
# 3. Assign speaker labels
|
||||||
diarize_model = whisperx.diarize.DiarizationPipeline(use_auth_token=YOUR_HF_TOKEN, device=device)
|
diarize_model = whisperx.diarize.DiarizationPipeline(use_auth_token=YOUR_HF_TOKEN, device=device)
|
||||||
|
2
uv.lock
generated
2
uv.lock
generated
@ -2787,7 +2787,7 @@ wheels = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "whisperx"
|
name = "whisperx"
|
||||||
version = "3.3.3"
|
version = "3.3.4"
|
||||||
source = { editable = "." }
|
source = { editable = "." }
|
||||||
dependencies = [
|
dependencies = [
|
||||||
{ name = "ctranslate2" },
|
{ name = "ctranslate2" },
|
||||||
|
@ -43,6 +43,7 @@ def cli():
|
|||||||
parser.add_argument("--diarize", action="store_true", help="Apply diarization to assign speaker labels to each segment/word")
|
parser.add_argument("--diarize", action="store_true", help="Apply diarization to assign speaker labels to each segment/word")
|
||||||
parser.add_argument("--min_speakers", default=None, type=int, help="Minimum number of speakers to in audio file")
|
parser.add_argument("--min_speakers", default=None, type=int, help="Minimum number of speakers to in audio file")
|
||||||
parser.add_argument("--max_speakers", default=None, type=int, help="Maximum number of speakers to in audio file")
|
parser.add_argument("--max_speakers", default=None, type=int, help="Maximum number of speakers to in audio file")
|
||||||
|
parser.add_argument("--diarize_model", default="pyannote/speaker-diarization-3.1", type=str, help="Name of the speaker diarization model to use")
|
||||||
|
|
||||||
parser.add_argument("--temperature", type=float, default=0, help="temperature to use for sampling")
|
parser.add_argument("--temperature", type=float, default=0, help="temperature to use for sampling")
|
||||||
parser.add_argument("--best_of", type=optional_int, default=5, help="number of candidates when sampling with non-zero temperature")
|
parser.add_argument("--best_of", type=optional_int, default=5, help="number of candidates when sampling with non-zero temperature")
|
||||||
|
@ -5,7 +5,7 @@ C. Max Bain
|
|||||||
import math
|
import math
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Iterable, Optional, Union, List
|
from typing import Iterable, Union, List, Callable, Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
@ -120,6 +120,7 @@ def align(
|
|||||||
return_char_alignments: bool = False,
|
return_char_alignments: bool = False,
|
||||||
print_progress: bool = False,
|
print_progress: bool = False,
|
||||||
combined_progress: bool = False,
|
combined_progress: bool = False,
|
||||||
|
on_progress: Callable[[int, int], None] = None
|
||||||
) -> AlignedTranscriptionResult:
|
) -> AlignedTranscriptionResult:
|
||||||
"""
|
"""
|
||||||
Align phoneme recognition predictions to known transcription.
|
Align phoneme recognition predictions to known transcription.
|
||||||
@ -149,6 +150,9 @@ def align(
|
|||||||
percent_complete = (50 + base_progress / 2) if combined_progress else base_progress
|
percent_complete = (50 + base_progress / 2) if combined_progress else base_progress
|
||||||
print(f"Progress: {percent_complete:.2f}%...")
|
print(f"Progress: {percent_complete:.2f}%...")
|
||||||
|
|
||||||
|
if on_progress:
|
||||||
|
on_progress(sdx + 1, total_segments)
|
||||||
|
|
||||||
num_leading = len(segment["text"]) - len(segment["text"].lstrip())
|
num_leading = len(segment["text"]) - len(segment["text"].lstrip())
|
||||||
num_trailing = len(segment["text"]) - len(segment["text"].rstrip())
|
num_trailing = len(segment["text"]) - len(segment["text"].rstrip())
|
||||||
text = segment["text"]
|
text = segment["text"]
|
||||||
|
@ -1,6 +1,8 @@
|
|||||||
import os
|
import os
|
||||||
from typing import List, Optional, Union
|
|
||||||
from dataclasses import replace
|
from dataclasses import replace
|
||||||
|
import warnings
|
||||||
|
from typing import List, Union, Optional, NamedTuple, Callable
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
import ctranslate2
|
import ctranslate2
|
||||||
import faster_whisper
|
import faster_whisper
|
||||||
@ -103,6 +105,12 @@ class FasterWhisperPipeline(Pipeline):
|
|||||||
# - add support for timestamp mode
|
# - add support for timestamp mode
|
||||||
# - add support for custom inference kwargs
|
# - add support for custom inference kwargs
|
||||||
|
|
||||||
|
class TranscriptionState(Enum):
|
||||||
|
LOADING_AUDIO = "loading_audio"
|
||||||
|
GENERATING_VAD_SEGMENTS = "generating_vad_segments"
|
||||||
|
TRANSCRIBING = "transcribing"
|
||||||
|
FINISHED = "finished"
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model: WhisperModel,
|
model: WhisperModel,
|
||||||
@ -197,8 +205,12 @@ class FasterWhisperPipeline(Pipeline):
|
|||||||
print_progress=False,
|
print_progress=False,
|
||||||
combined_progress=False,
|
combined_progress=False,
|
||||||
verbose=False,
|
verbose=False,
|
||||||
|
on_progress: Callable[[TranscriptionState, Optional[int], Optional[int]], None] = None,
|
||||||
) -> TranscriptionResult:
|
) -> TranscriptionResult:
|
||||||
if isinstance(audio, str):
|
if isinstance(audio, str):
|
||||||
|
if on_progress:
|
||||||
|
on_progress(self.__class__.TranscriptionState.LOADING_AUDIO)
|
||||||
|
|
||||||
audio = load_audio(audio)
|
audio = load_audio(audio)
|
||||||
|
|
||||||
def data(audio, segments):
|
def data(audio, segments):
|
||||||
@ -216,6 +228,8 @@ class FasterWhisperPipeline(Pipeline):
|
|||||||
else:
|
else:
|
||||||
waveform = Pyannote.preprocess_audio(audio)
|
waveform = Pyannote.preprocess_audio(audio)
|
||||||
merge_chunks = Pyannote.merge_chunks
|
merge_chunks = Pyannote.merge_chunks
|
||||||
|
if on_progress:
|
||||||
|
on_progress(self.__class__.TranscriptionState.GENERATING_VAD_SEGMENTS)
|
||||||
|
|
||||||
vad_segments = self.vad_model({"waveform": waveform, "sample_rate": SAMPLE_RATE})
|
vad_segments = self.vad_model({"waveform": waveform, "sample_rate": SAMPLE_RATE})
|
||||||
vad_segments = merge_chunks(
|
vad_segments = merge_chunks(
|
||||||
@ -255,16 +269,22 @@ class FasterWhisperPipeline(Pipeline):
|
|||||||
segments: List[SingleSegment] = []
|
segments: List[SingleSegment] = []
|
||||||
batch_size = batch_size or self._batch_size
|
batch_size = batch_size or self._batch_size
|
||||||
total_segments = len(vad_segments)
|
total_segments = len(vad_segments)
|
||||||
|
|
||||||
|
if on_progress:
|
||||||
|
on_progress(self.__class__.TranscriptionState.TRANSCRIBING, 0, total_segments)
|
||||||
|
|
||||||
for idx, out in enumerate(self.__call__(data(audio, vad_segments), batch_size=batch_size, num_workers=num_workers)):
|
for idx, out in enumerate(self.__call__(data(audio, vad_segments), batch_size=batch_size, num_workers=num_workers)):
|
||||||
if print_progress:
|
if print_progress:
|
||||||
base_progress = ((idx + 1) / total_segments) * 100
|
base_progress = ((idx + 1) / total_segments) * 100
|
||||||
percent_complete = base_progress / 2 if combined_progress else base_progress
|
percent_complete = base_progress / 2 if combined_progress else base_progress
|
||||||
print(f"Progress: {percent_complete:.2f}%...")
|
print(f"Progress: {percent_complete:.2f}%...")
|
||||||
|
|
||||||
|
if on_progress:
|
||||||
|
on_progress(self.__class__.TranscriptionState.TRANSCRIBING, idx + 1, total_segments)
|
||||||
|
|
||||||
text = out['text']
|
text = out['text']
|
||||||
if batch_size in [0, 1, None]:
|
if batch_size in [0, 1, None]:
|
||||||
text = text[0]
|
text = text[0]
|
||||||
if verbose:
|
|
||||||
print(f"Transcript: [{round(vad_segments[idx]['start'], 3)} --> {round(vad_segments[idx]['end'], 3)}] {text}")
|
|
||||||
segments.append(
|
segments.append(
|
||||||
{
|
{
|
||||||
"text": text,
|
"text": text,
|
||||||
@ -273,6 +293,9 @@ class FasterWhisperPipeline(Pipeline):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if on_progress:
|
||||||
|
on_progress(self.__class__.TranscriptionState.FINISHED)
|
||||||
|
|
||||||
# revert the tokenizer if multilingual inference is enabled
|
# revert the tokenizer if multilingual inference is enabled
|
||||||
if self.preset_language is None:
|
if self.preset_language is None:
|
||||||
self.tokenizer = None
|
self.tokenizer = None
|
||||||
|
@ -11,13 +11,14 @@ from whisperx.types import TranscriptionResult, AlignedTranscriptionResult
|
|||||||
class DiarizationPipeline:
|
class DiarizationPipeline:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model_name="pyannote/speaker-diarization-3.1",
|
model_name=None,
|
||||||
use_auth_token=None,
|
use_auth_token=None,
|
||||||
device: Optional[Union[str, torch.device]] = "cpu",
|
device: Optional[Union[str, torch.device]] = "cpu",
|
||||||
):
|
):
|
||||||
if isinstance(device, str):
|
if isinstance(device, str):
|
||||||
device = torch.device(device)
|
device = torch.device(device)
|
||||||
self.model = Pipeline.from_pretrained(model_name, use_auth_token=use_auth_token).to(device)
|
model_config = model_name or "pyannote/speaker-diarization-3.1"
|
||||||
|
self.model = Pipeline.from_pretrained(model_config, use_auth_token=use_auth_token).to(device)
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
|
@ -57,6 +57,7 @@ def transcribe_task(args: dict, parser: argparse.ArgumentParser):
|
|||||||
diarize: bool = args.pop("diarize")
|
diarize: bool = args.pop("diarize")
|
||||||
min_speakers: int = args.pop("min_speakers")
|
min_speakers: int = args.pop("min_speakers")
|
||||||
max_speakers: int = args.pop("max_speakers")
|
max_speakers: int = args.pop("max_speakers")
|
||||||
|
diarize_model_name: str = args.pop("diarize_model")
|
||||||
print_progress: bool = args.pop("print_progress")
|
print_progress: bool = args.pop("print_progress")
|
||||||
|
|
||||||
if args["language"] is not None:
|
if args["language"] is not None:
|
||||||
@ -204,8 +205,9 @@ def transcribe_task(args: dict, parser: argparse.ArgumentParser):
|
|||||||
)
|
)
|
||||||
tmp_results = results
|
tmp_results = results
|
||||||
print(">>Performing diarization...")
|
print(">>Performing diarization...")
|
||||||
|
print(">>Using model:", diarize_model_name)
|
||||||
results = []
|
results = []
|
||||||
diarize_model = DiarizationPipeline(use_auth_token=hf_token, device=device)
|
diarize_model = DiarizationPipeline(model_name=diarize_model_name, use_auth_token=hf_token, device=device)
|
||||||
for result, input_audio_path in tmp_results:
|
for result, input_audio_path in tmp_results:
|
||||||
diarize_segments = diarize_model(
|
diarize_segments = diarize_model(
|
||||||
input_audio_path, min_speakers=min_speakers, max_speakers=max_speakers
|
input_audio_path, min_speakers=min_speakers, max_speakers=max_speakers
|
||||||
|
Reference in New Issue
Block a user