18 Commits

Author SHA1 Message Date
4e2ac4e4e9 torch2.0, remove compile for now, round to times to 3 decimal 2023-05-04 20:38:13 +01:00
d2116b98ca Merge pull request #210 from sorgfresser/v3: Update pyannote and torch version 2023-05-04 20:32:06 +01:00
d8f0ef4a19 Set diarization device manually 2023-05-04 16:25:34 +02:00
1b62c61c71 Merge pull request #216 from aramlang/blank_id-fix: Enable Hebrew support 2023-05-04 01:13:23 +01:00
2d59eb9726 Add torch compile to log mel spectrogram 2023-05-03 23:17:44 +02:00
cb53661070 Enable Hebrew support 2023-05-03 11:26:12 -05:00
2a6830492c Fix pyannote to specific commit 2023-05-02 20:25:56 +02:00
da3aabe181 Merge branch 'm-bain:v3' into v3 2023-05-02 18:55:43 +02:00
067189248f Use pyannote develop branch and torch version 2 2023-05-02 18:44:43 +02:00
9fb51412c0 Merge pull request #208 from arnavmehta7/patch-1 2023-05-02 10:55:13 +01:00
64ca208cc8 Fixed the word_start variable not initialized bug. 2023-05-02 13:13:02 +05:30
5becc99e56 Version bump pyannote, pytorch 2023-05-01 13:47:41 +02:00
e24ca9e0a2 Merge pull request #205 from prashanthellina/v3-fix-diarization 2023-04-30 21:08:45 +01:00
601c91140f references #202, attempt to fix speaker diarization failing in v3 2023-04-30 17:33:24 +00:00
31a9ec7466 Merge pull request #204 from sorgfresser/v3 2023-04-30 18:29:46 +01:00
b9c8c5072b Pad language detection if audio is too short 2023-04-30 18:34:18 +02:00
a903e57cf1 Merge pull request #199 from thomasmol/v3 2023-04-29 23:35:42 +01:00
cb176a186e added num_workers to fix pickling error 2023-04-29 19:51:05 +02:00
7 changed files with 43 additions and 28 deletions

README.md

@@ -61,23 +61,23 @@ This repository refines the timestamps of openAI's Whisper model via forced alignment
 <h2 align="left" id="setup">Setup ⚙️</h2>
-Tested for PyTorch 0.11, Python 3.8 (use other versions at your own risk!)
+Tested for PyTorch 2.0, Python 3.10 (use other versions at your own risk!)
 GPU execution requires the NVIDIA libraries cuBLAS 11.x and cuDNN 8.x to be installed on the system. Please refer to the [CTranslate2 documentation](https://opennmt.net/CTranslate2/installation.html).
-### 1. Create Python3.8 environment
+### 1. Create Python3.10 environment
-`conda create --name whisperx python=3.8`
+`conda create --name whisperx python=3.10`
 `conda activate whisperx`
-### 2. Install PyTorch 0.11.0, e.g. for Linux and Windows:
+### 2. Install PyTorch2.0, e.g. for Linux and Windows CUDA11.7:
-`pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 torchaudio==0.11.0 --extra-index-url https://download.pytorch.org/whl/cu113`
+`pip3 install torch torchvision torchaudio`
-See other methods [here.](https://pytorch.org/get-started/previous-versions/#wheel-4)
+See other methods [here.](https://pytorch.org/get-started/locally/)
 ### 3. Install this repo
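A quick way to confirm the upgraded environment after these steps (a hypothetical check, not part of the repository):

```python
# Hypothetical post-install sanity check: confirms the torch 2.0 wheels
# import correctly and whether CUDA (cuBLAS/cuDNN) is visible.
import torch

print(torch.__version__)           # expect a 2.0.x version string
print(torch.cuda.is_available())   # True only if the CUDA libraries are found
```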

requirements.txt

@@ -1,6 +1,5 @@
-torch==1.11.0
-torchaudio==0.11.0
-pyannote.audio
+torch==2.0.0
+torchaudio==2.0.1
 faster-whisper
 transformers
 ffmpeg-python==0.2.0

setup.py

@@ -6,7 +6,7 @@ from setuptools import setup, find_packages
 setup(
     name="whisperx",
     py_modules=["whisperx"],
-    version="3.0.0",
+    version="3.0.2",
     description="Time-Accurate Automatic Speech Recognition using Whisper.",
     readme="README.md",
     python_requires=">=3.8",
@@ -19,7 +19,7 @@ setup(
         for r in pkg_resources.parse_requirements(
             open(os.path.join(os.path.dirname(__file__), "requirements.txt"))
         )
-    ],
+    ] + ["pyannote.audio @ git+https://github.com/pyannote/pyannote-audio@11b56a137a578db9335efc00298f6ec1932e6317"],
     entry_points = {
         'console_scripts': ['whisperx=whisperx.transcribe:cli'],
     },
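The appended entry is a PEP 508 direct reference, so pip installs pyannote.audio from the pinned git commit rather than PyPI. A minimal sketch of how the final dependency list is assembled (mirroring the diff, not copied from the repo):

```python
# Sketch: parse requirements.txt entries, then append one PEP 508 direct
# reference that pins pyannote.audio to a specific commit, as in setup.py.
import pkg_resources

PYANNOTE_PIN = (
    "pyannote.audio @ git+https://github.com/pyannote/"
    "pyannote-audio@11b56a137a578db9335efc00298f6ec1932e6317"
)

with open("requirements.txt") as f:
    install_requires = [str(r) for r in pkg_resources.parse_requirements(f)]
install_requires.append(PYANNOTE_PIN)
```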

whisperx/alignment.py

@@ -38,6 +38,7 @@ DEFAULT_ALIGN_MODELS_HF = {
     "fa": "jonatasgrosman/wav2vec2-large-xlsr-53-persian",
     "el": "jonatasgrosman/wav2vec2-large-xlsr-53-greek",
     "tr": "mpoyraz/wav2vec2-xls-r-300m-cv7-turkish",
+    "he": "imvladikon/wav2vec2-xls-r-300m-hebrew",
 }
@@ -231,8 +232,13 @@ def align(
         emission = emissions[0].cpu().detach()
-        trellis = get_trellis(emission, tokens)
-        path = backtrack(trellis, emission, tokens)
+        blank_id = 0
+        for char, code in model_dictionary.items():
+            if char == '[pad]' or char == '<pad>':
+                blank_id = code
+        trellis = get_trellis(emission, tokens, blank_id)
+        path = backtrack(trellis, emission, tokens, blank_id)
         if path is None:
             print(f'Failed to align segment ("{segment["text"]}"): backtrack failed, resorting to original...')
             break
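The fix matters because some alignment models, including the Hebrew one added above, store the CTC blank under a pad token whose index is not 0. A toy illustration of the lookup, with a hypothetical dictionary:

```python
# Toy illustration of the blank_id lookup above; this model_dictionary is
# hypothetical, shaped like the char-to-index map a wav2vec2 model exposes.
model_dictionary = {"<pad>": 37, "|": 0, "a": 1, "b": 2}

blank_id = 0  # previous behaviour: CTC blank assumed to sit at index 0
for char, code in model_dictionary.items():
    if char in ("[pad]", "<pad>"):
        blank_id = code  # some models place the blank at a non-zero index

assert blank_id == 37
```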
@@ -262,8 +268,8 @@ def align(
         start, end, score = None, None, None
         if cdx in clean_cdx:
             char_seg = char_segments[clean_cdx.index(cdx)]
-            start = char_seg.start * ratio + t1
-            end = char_seg.end * ratio + t1
+            start = round(char_seg.start * ratio + t1, 3)
+            end = round(char_seg.end * ratio + t1, 3)
             score = char_seg.score
         char_segments_arr["char"].append(char)
@@ -439,8 +445,8 @@ def align(
         word_list.append(
             {
                 "word": curr_text.rstrip(),
-                "start": word_start,
-                "end": word_end,
+                "start": wseg.iloc[wdx]['start'],
+                "end": wseg.iloc[wdx]['end'],
             }
         )
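This removes a `word_start` local that could be read before assignment (the bug fixed in #208); timings now come straight from the word-segments DataFrame row. A small illustration with a made-up frame:

```python
# Made-up word-segments frame illustrating the fix above: start/end are
# read from the DataFrame row rather than a possibly-unset local variable.
import pandas as pd

wseg = pd.DataFrame([{"start": 0.501, "end": 0.934}])
wdx = 0
word = {"word": "hello",
        "start": wseg.iloc[wdx]["start"],
        "end": wseg.iloc[wdx]["end"]}
```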
@@ -450,8 +456,8 @@ def align(
             "end": srow["end"],
             "text": text,
             "words": word_list,
-            # "word-segments": wseg,
-            # "char-segments": cseg
+            "word-segments": wseg,
+            "char-segments": cseg
         }
     )

whisperx/asr.py

@@ -207,7 +207,7 @@ class FasterWhisperPipeline(Pipeline):
         return final_iterator

     def transcribe(
-        self, audio: Union[str, np.ndarray], batch_size=None
+        self, audio: Union[str, np.ndarray], batch_size=None, num_workers=0
     ):
         if isinstance(audio, str):
             audio = load_audio(audio)
@@ -232,7 +232,7 @@ class FasterWhisperPipeline(Pipeline):
         segments = []
         batch_size = batch_size or self._batch_size
-        for idx, out in enumerate(self.__call__(data(audio, vad_segments), batch_size=batch_size)):
+        for idx, out in enumerate(self.__call__(data(audio, vad_segments), batch_size=batch_size, num_workers=num_workers)):
             text = out['text']
             if batch_size in [0, 1, None]:
                 text = text[0]
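Threading `num_workers` through to the underlying Hugging Face `Pipeline.__call__` is what fixes the pickling error from #199. A hypothetical call site, assuming `model` is a loaded FasterWhisperPipeline:

```python
# Hypothetical usage: num_workers=0 keeps dataloading in the main process,
# which avoids the pickling error; values > 0 opt back into worker processes.
result = model.transcribe("audio.wav", batch_size=8, num_workers=0)
```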
@@ -251,7 +251,10 @@ class FasterWhisperPipeline(Pipeline):
     def detect_language(self, audio: np.ndarray):
-        segment = log_mel_spectrogram(audio[: N_SAMPLES], padding=0)
+        if audio.shape[0] < N_SAMPLES:
+            print("Warning: audio is shorter than 30s, language detection may be inaccurate.")
+        segment = log_mel_spectrogram(audio[: N_SAMPLES],
+                                      padding=0 if audio.shape[0] >= N_SAMPLES else N_SAMPLES - audio.shape[0])
         encoder_output = self.model.encode(segment)
         results = self.model.model.detect_language(encoder_output)
         language_token, language_probability = results[0][0]
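The padding expression tops short clips up to the full 30-second detection window rather than feeding a truncated mel spectrogram to the encoder. The arithmetic, spelled out (the N_SAMPLES value is assumed from Whisper's 16 kHz, 30 s convention):

```python
# Padding arithmetic from the diff above; N_SAMPLES is assumed to be
# Whisper's 30 s window at 16 kHz (30 * 16000 = 480000 samples).
N_SAMPLES = 480_000
n = 200_000  # e.g. a 12.5 s clip
padding = 0 if n >= N_SAMPLES else N_SAMPLES - n
assert padding == 280_000  # the mel computation pads up to a full window
```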

whisperx/diarize.py

@@ -1,14 +1,19 @@
 import numpy as np
 import pandas as pd
 from pyannote.audio import Pipeline
+from typing import Optional, Union
+import torch

 class DiarizationPipeline:
     def __init__(
         self,
         model_name="pyannote/speaker-diarization@2.1",
         use_auth_token=None,
+        device: Optional[Union[str, torch.device]] = "cpu",
     ):
-        self.model = Pipeline.from_pretrained(model_name, use_auth_token=use_auth_token)
+        if isinstance(device, str):
+            device = torch.device(device)
+        self.model = Pipeline.from_pretrained(model_name, use_auth_token=use_auth_token).to(device)

     def __call__(self, audio, min_speakers=None, max_speakers=None):
         segments = self.model(audio, min_speakers=min_speakers, max_speakers=max_speakers)
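With the device now settable manually, a hypothetical usage (the import path follows the repo layout; the token placeholder is illustrative):

```python
# Hypothetical usage of the updated pipeline: "cuda" is converted to a
# torch.device in __init__, and the pyannote model is moved onto it.
from whisperx.diarize import DiarizationPipeline

diarize_model = DiarizationPipeline(use_auth_token="<hf-token>", device="cuda")
diarize_segments = diarize_model("audio.wav", min_speakers=1, max_speakers=3)
```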

whisperx/transcribe.py

@@ -72,7 +72,6 @@ def cli():
     parser.add_argument("--hf_token", type=str, default=None, help="Hugging Face Access Token to access PyAnnote gated models")
     # parser.add_argument("--model_flush", action="store_true", help="Flush memory from each model after use, reduces GPU requirement but slower processing >1 audio file.")
-    parser.add_argument("--tmp_dir", default=None, help="Temporary directory to write audio file if input if not .wav format (only for VAD).")
     # fmt: on

     args = parser.parse_args().__dict__
@@ -86,10 +85,6 @@ def cli():
     # model_flush: bool = args.pop("model_flush")
     os.makedirs(output_dir, exist_ok=True)
-    tmp_dir: str = args.pop("tmp_dir")
-    if tmp_dir is not None:
-        os.makedirs(tmp_dir, exist_ok=True)
     align_model: str = args.pop("align_model")
     interpolate_method: str = args.pop("interpolate_method")
     no_align: bool = args.pop("no_align")
@@ -193,6 +188,7 @@ def cli():
         if hf_token is None:
             print("Warning, no --hf_token used, needs to be saved in environment variable, otherwise will throw error loading diarization model...")
         tmp_results = results
+        print(">>Performing diarization...")
         results = []
         diarize_model = DiarizationPipeline(use_auth_token=hf_token)
         for result, input_audio_path in tmp_results:
@@ -203,6 +199,12 @@ def cli():
     # >> Write
     for result, audio_path in results:
+        # Remove pandas dataframes from result so that
+        # we can serialize the result with json
+        for seg in result["segments"]:
+            seg.pop("word-segments", None)
+            seg.pop("char-segments", None)
         writer(result, audio_path, writer_args)

 if __name__ == "__main__":
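The pops matter because segments now carry pandas DataFrames (re-enabled in the alignment.py diff above), which the standard json encoder rejects. A minimal demonstration:

```python
# Minimal demonstration of why the pops above are needed: json.dumps
# raises TypeError on pandas objects, so DataFrames must be stripped
# from each segment before the writers serialize the result.
import json
import pandas as pd

seg = {"text": "hello", "word-segments": pd.DataFrame()}
try:
    json.dumps(seg)
except TypeError as err:
    print(err)  # "Object of type DataFrame is not JSON serializable"

seg.pop("word-segments", None)
json.dumps(seg)  # succeeds once the DataFrame is removed
```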