12 Commits

6 changed files with 29 additions and 23 deletions

README.md

@@ -61,23 +61,23 @@ This repository refines the timestamps of openAI's Whisper model via forced alignment
 <h2 align="left" id="setup">Setup ⚙️</h2>
-Tested for PyTorch 0.11, Python 3.8 (use other versions at your own risk!)
+Tested for PyTorch 2.0, Python 3.10 (use other versions at your own risk!)
 GPU execution requires the NVIDIA libraries cuBLAS 11.x and cuDNN 8.x to be installed on the system. Please refer to the [CTranslate2 documentation](https://opennmt.net/CTranslate2/installation.html).
-### 1. Create Python3.8 environment
+### 1. Create Python3.10 environment
-`conda create --name whisperx python=3.8`
+`conda create --name whisperx python=3.10`
 `conda activate whisperx`
-### 2. Install PyTorch 0.11.0, e.g. for Linux and Windows:
+### 2. Install PyTorch2.0, e.g. for Linux and Windows CUDA11.7:
-`pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 torchaudio==0.11.0 --extra-index-url https://download.pytorch.org/whl/cu113`
+`pip3 install torch torchvision torchaudio`
-See other methods [here.](https://pytorch.org/get-started/previous-versions/#wheel-4)
+See other methods [here.](https://pytorch.org/get-started/locally/)
 ### 3. Install this repo

requirements.txt

@@ -1,6 +1,5 @@
-torch==1.11.0
+torch==2.0.0
-torchaudio==0.11.0
+torchaudio==2.0.1
-pyannote.audio
 faster-whisper
 transformers
 ffmpeg-python==0.2.0

setup.py

@@ -6,7 +6,7 @@ from setuptools import setup, find_packages
 setup(
     name="whisperx",
     py_modules=["whisperx"],
-    version="3.0.0",
+    version="3.0.2",
     description="Time-Accurate Automatic Speech Recognition using Whisper.",
     readme="README.md",
     python_requires=">=3.8",
@@ -19,7 +19,7 @@ setup(
         for r in pkg_resources.parse_requirements(
             open(os.path.join(os.path.dirname(__file__), "requirements.txt"))
         )
-    ],
+    ] + ["pyannote.audio @ git+https://github.com/pyannote/pyannote-audio@11b56a137a578db9335efc00298f6ec1932e6317"],
     entry_points = {
         'console_scripts': ['whisperx=whisperx.transcribe:cli'],
     },

whisperx/alignment.py

@@ -38,6 +38,7 @@ DEFAULT_ALIGN_MODELS_HF = {
     "fa": "jonatasgrosman/wav2vec2-large-xlsr-53-persian",
     "el": "jonatasgrosman/wav2vec2-large-xlsr-53-greek",
     "tr": "mpoyraz/wav2vec2-xls-r-300m-cv7-turkish",
+    "he": "imvladikon/wav2vec2-xls-r-300m-hebrew",
 }
@@ -231,8 +232,13 @@ def align(
         emission = emissions[0].cpu().detach()
-        trellis = get_trellis(emission, tokens)
-        path = backtrack(trellis, emission, tokens)
+        blank_id = 0
+        for char, code in model_dictionary.items():
+            if char == '[pad]' or char == '<pad>':
+                blank_id = code
+        trellis = get_trellis(emission, tokens, blank_id)
+        path = backtrack(trellis, emission, tokens, blank_id)
         if path is None:
             print(f'Failed to align segment ("{segment["text"]}"): backtrack failed, resorting to original...')
             break
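The change above looks up the model's pad token and uses it as the CTC blank index instead of hard-coding 0. A minimal standalone sketch of the same idea; the dictionary literal is a hypothetical character-to-index mapping, not taken from this diff (in `align()` it comes from the loaded wav2vec2 alignment model):

```python
# Illustrative only: a hypothetical wav2vec2-style vocabulary mapping.
model_dictionary = {"<pad>": 0, "|": 4, "e": 5, "t": 6}

blank_id = 0  # fall back to index 0 if no pad token is present
for char, code in model_dictionary.items():
    if char in ("[pad]", "<pad>"):
        blank_id = code  # wav2vec2-style models use the pad token as the CTC blank

# blank_id is then passed to get_trellis()/backtrack() so the alignment
# no longer assumes the blank token sits at index 0.
```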
@@ -262,8 +268,8 @@ def align(
             start, end, score = None, None, None
             if cdx in clean_cdx:
                 char_seg = char_segments[clean_cdx.index(cdx)]
-                start = char_seg.start * ratio + t1
-                end = char_seg.end * ratio + t1
+                start = round(char_seg.start * ratio + t1, 3)
+                end = round(char_seg.end * ratio + t1, 3)
                 score = char_seg.score
             char_segments_arr["char"].append(char)
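This change just snaps character timestamps to millisecond precision. A tiny illustration with made-up numbers; `ratio` (a seconds-per-frame scale) and `t1` (the segment start time) are assumptions about the surrounding `align()` code, which this hunk does not show:

```python
# Illustrative values only: ratio and t1 are hypothetical here.
char_start_frame = 123   # char_seg.start: index into the alignment trellis
ratio = 0.02             # assumed frames-to-seconds conversion factor
t1 = 15.0                # assumed segment start time in seconds

start = round(char_start_frame * ratio + t1, 3)  # 3 decimals for cleaner output timestamps
print(start)  # 17.46
```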
@@ -439,8 +445,8 @@ def align(
             word_list.append(
                 {
                     "word": curr_text.rstrip(),
-                    "start": word_start,
-                    "end": word_end,
+                    "start": wseg.iloc[wdx]['start'],
+                    "end": wseg.iloc[wdx]['end'],
                 }
             )

whisperx/diarize.py

@@ -1,14 +1,19 @@
 import numpy as np
 import pandas as pd
 from pyannote.audio import Pipeline
+from typing import Optional, Union
+import torch

 class DiarizationPipeline:
     def __init__(
         self,
         model_name="pyannote/speaker-diarization@2.1",
         use_auth_token=None,
+        device: Optional[Union[str, torch.device]] = "cpu",
     ):
-        self.model = Pipeline.from_pretrained(model_name, use_auth_token=use_auth_token)
+        if isinstance(device, str):
+            device = torch.device(device)
+        self.model = Pipeline.from_pretrained(model_name, use_auth_token=use_auth_token).to(device)

     def __call__(self, audio, min_speakers=None, max_speakers=None):
         segments = self.model(audio, min_speakers=min_speakers, max_speakers=max_speakers)
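With the new `device` argument, the diarization pipeline can be placed on the GPU explicitly rather than defaulting to CPU. A hedged usage sketch; the import path, the token string, and the audio filename are placeholders for illustration, not taken from this diff:

```python
# Usage sketch for the updated DiarizationPipeline (illustrative only).
# Assumes the class lives in whisperx/diarize.py; token and audio path are placeholders.
import torch
from whisperx.diarize import DiarizationPipeline

device = "cuda" if torch.cuda.is_available() else "cpu"
diarize_model = DiarizationPipeline(use_auth_token="hf_xxx", device=device)

# min_speakers/max_speakers are passed straight through to the pyannote pipeline
diarize_segments = diarize_model("audio.wav", min_speakers=2, max_speakers=4)
```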

whisperx/transcribe.py

@@ -72,7 +72,6 @@ def cli():
     parser.add_argument("--hf_token", type=str, default=None, help="Hugging Face Access Token to access PyAnnote gated models")
     # parser.add_argument("--model_flush", action="store_true", help="Flush memory from each model after use, reduces GPU requirement but slower processing >1 audio file.")
-    parser.add_argument("--tmp_dir", default=None, help="Temporary directory to write audio file if input if not .wav format (only for VAD).")
     # fmt: on

     args = parser.parse_args().__dict__
@@ -86,10 +85,6 @@ def cli():
     # model_flush: bool = args.pop("model_flush")
     os.makedirs(output_dir, exist_ok=True)

-    tmp_dir: str = args.pop("tmp_dir")
-    if tmp_dir is not None:
-        os.makedirs(tmp_dir, exist_ok=True)

     align_model: str = args.pop("align_model")
     interpolate_method: str = args.pop("interpolate_method")
     no_align: bool = args.pop("no_align")
@@ -193,6 +188,7 @@ def cli():
         if hf_token is None:
             print("Warning, no --hf_token used, needs to be saved in environment variable, otherwise will throw error loading diarization model...")
         tmp_results = results
+        print(">>Performing diarization...")
         results = []
         diarize_model = DiarizationPipeline(use_auth_token=hf_token)
         for result, input_audio_path in tmp_results: