12 Commits

6 changed files with 29 additions and 23 deletions

View File

@ -61,23 +61,23 @@ This repository refines the timestamps of openAI's Whisper model via forced alig
<h2 align="left" id="setup">Setup ⚙️</h2>
Tested for PyTorch 0.11, Python 3.8 (use other versions at your own risk!)
Tested for PyTorch 2.0, Python 3.10 (use other versions at your own risk!)
GPU execution requires the NVIDIA libraries cuBLAS 11.x and cuDNN 8.x to be installed on the system. Please refer to the [CTranslate2 documentation](https://opennmt.net/CTranslate2/installation.html).
### 1. Create Python3.8 environment
### 1. Create Python3.10 environment
`conda create --name whisperx python=3.8`
`conda create --name whisperx python=3.10`
`conda activate whisperx`
### 2. Install PyTorch 0.11.0, e.g. for Linux and Windows:
### 2. Install PyTorch2.0, e.g. for Linux and Windows CUDA11.7:
`pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 torchaudio==0.11.0 --extra-index-url https://download.pytorch.org/whl/cu113`
`pip3 install torch torchvision torchaudio`
See other methods [here.](https://pytorch.org/get-started/previous-versions/#wheel-4)
See other methods [here.](https://pytorch.org/get-started/locally/)
### 3. Install this repo

View File

@ -1,6 +1,5 @@
torch==1.11.0
torchaudio==0.11.0
pyannote.audio
torch==2.0.0
torchaudio==2.0.1
faster-whisper
transformers
ffmpeg-python==0.2.0

View File

@ -6,7 +6,7 @@ from setuptools import setup, find_packages
setup(
name="whisperx",
py_modules=["whisperx"],
version="3.0.0",
version="3.0.2",
description="Time-Accurate Automatic Speech Recognition using Whisper.",
readme="README.md",
python_requires=">=3.8",
@ -19,7 +19,7 @@ setup(
for r in pkg_resources.parse_requirements(
open(os.path.join(os.path.dirname(__file__), "requirements.txt"))
)
],
] + ["pyannote.audio @ git+https://github.com/pyannote/pyannote-audio@11b56a137a578db9335efc00298f6ec1932e6317"],
entry_points = {
'console_scripts': ['whisperx=whisperx.transcribe:cli'],
},

View File

@ -38,6 +38,7 @@ DEFAULT_ALIGN_MODELS_HF = {
"fa": "jonatasgrosman/wav2vec2-large-xlsr-53-persian",
"el": "jonatasgrosman/wav2vec2-large-xlsr-53-greek",
"tr": "mpoyraz/wav2vec2-xls-r-300m-cv7-turkish",
"he": "imvladikon/wav2vec2-xls-r-300m-hebrew",
}
@ -231,8 +232,13 @@ def align(
emission = emissions[0].cpu().detach()
trellis = get_trellis(emission, tokens)
path = backtrack(trellis, emission, tokens)
blank_id = 0
for char, code in model_dictionary.items():
if char == '[pad]' or char == '<pad>':
blank_id = code
trellis = get_trellis(emission, tokens, blank_id)
path = backtrack(trellis, emission, tokens, blank_id)
if path is None:
print(f'Failed to align segment ("{segment["text"]}"): backtrack failed, resorting to original...')
break
@ -262,8 +268,8 @@ def align(
start, end, score = None, None, None
if cdx in clean_cdx:
char_seg = char_segments[clean_cdx.index(cdx)]
start = char_seg.start * ratio + t1
end = char_seg.end * ratio + t1
start = round(char_seg.start * ratio + t1, 3)
end = round(char_seg.end * ratio + t1, 3)
score = char_seg.score
char_segments_arr["char"].append(char)
@ -439,8 +445,8 @@ def align(
word_list.append(
{
"word": curr_text.rstrip(),
"start": word_start,
"end": word_end,
"start": wseg.iloc[wdx]['start'],
"end": wseg.iloc[wdx]['end'],
}
)

View File

@ -1,14 +1,19 @@
import numpy as np
import pandas as pd
from pyannote.audio import Pipeline
from typing import Optional, Union
import torch
class DiarizationPipeline:
def __init__(
self,
model_name="pyannote/speaker-diarization@2.1",
use_auth_token=None,
device: Optional[Union[str, torch.device]] = "cpu",
):
self.model = Pipeline.from_pretrained(model_name, use_auth_token=use_auth_token)
if isinstance(device, str):
device = torch.device(device)
self.model = Pipeline.from_pretrained(model_name, use_auth_token=use_auth_token).to(device)
def __call__(self, audio, min_speakers=None, max_speakers=None):
segments = self.model(audio, min_speakers=min_speakers, max_speakers=max_speakers)

View File

@ -72,7 +72,6 @@ def cli():
parser.add_argument("--hf_token", type=str, default=None, help="Hugging Face Access Token to access PyAnnote gated models")
# parser.add_argument("--model_flush", action="store_true", help="Flush memory from each model after use, reduces GPU requirement but slower processing >1 audio file.")
parser.add_argument("--tmp_dir", default=None, help="Temporary directory to write audio file if input if not .wav format (only for VAD).")
# fmt: on
args = parser.parse_args().__dict__
@ -86,10 +85,6 @@ def cli():
# model_flush: bool = args.pop("model_flush")
os.makedirs(output_dir, exist_ok=True)
tmp_dir: str = args.pop("tmp_dir")
if tmp_dir is not None:
os.makedirs(tmp_dir, exist_ok=True)
align_model: str = args.pop("align_model")
interpolate_method: str = args.pop("interpolate_method")
no_align: bool = args.pop("no_align")
@ -193,6 +188,7 @@ def cli():
if hf_token is None:
print("Warning, no --hf_token used, needs to be saved in environment variable, otherwise will throw error loading diarization model...")
tmp_results = results
print(">>Performing diarization...")
results = []
diarize_model = DiarizationPipeline(use_auth_token=hf_token)
for result, input_audio_path in tmp_results: