Mirror of https://github.com/m-bain/whisperX.git (synced 2025-07-01 18:17:27 -04:00)

Compare commits

18 commits:
- 4e2ac4e4e9
- d2116b98ca
- d8f0ef4a19
- 1b62c61c71
- 2d59eb9726
- cb53661070
- 2a6830492c
- da3aabe181
- 067189248f
- 9fb51412c0
- 64ca208cc8
- 5becc99e56
- e24ca9e0a2
- 601c91140f
- 31a9ec7466
- b9c8c5072b
- a903e57cf1
- cb176a186e
README.md (12 lines changed)

@@ -61,23 +61,23 @@ This repository refines the timestamps of openAI's Whisper model via forced alig
 <h2 align="left" id="setup">Setup ⚙️</h2>

-Tested for PyTorch 0.11, Python 3.8 (use other versions at your own risk!)
+Tested for PyTorch 2.0, Python 3.10 (use other versions at your own risk!)

 GPU execution requires the NVIDIA libraries cuBLAS 11.x and cuDNN 8.x to be installed on the system. Please refer to the [CTranslate2 documentation](https://opennmt.net/CTranslate2/installation.html).

-### 1. Create Python3.8 environment
+### 1. Create Python3.10 environment

-`conda create --name whisperx python=3.8`
+`conda create --name whisperx python=3.10`

 `conda activate whisperx`

-### 2. Install PyTorch 0.11.0, e.g. for Linux and Windows:
+### 2. Install PyTorch2.0, e.g. for Linux and Windows CUDA11.7:

-`pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 torchaudio==0.11.0 --extra-index-url https://download.pytorch.org/whl/cu113`
+`pip3 install torch torchvision torchaudio`

-See other methods [here.](https://pytorch.org/get-started/previous-versions/#wheel-4)
+See other methods [here.](https://pytorch.org/get-started/locally/)

 ### 3. Install this repo
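The updated setup therefore targets Python 3.10 and PyTorch 2.0 with the CUDA 11.x libraries. A quick sanity check after following steps 1 and 2 (a minimal sketch, not part of the repository; it only uses standard PyTorch calls):

```python
# Minimal post-install check (not from the repo): confirm the PyTorch 2.x build
# and that the CUDA runtime (cuBLAS/cuDNN) is visible to torch.
import torch

print(torch.__version__)            # expect a 2.0.x build
print(torch.cuda.is_available())    # True when the NVIDIA libraries are set up
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))
```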
requirements.txt

@@ -1,6 +1,5 @@
-torch==1.11.0
-torchaudio==0.11.0
-pyannote.audio
+torch==2.0.0
+torchaudio==2.0.1
+faster-whisper
 transformers
 ffmpeg-python==0.2.0
setup.py (4 lines changed)

@@ -6,7 +6,7 @@ from setuptools import setup, find_packages
 setup(
     name="whisperx",
     py_modules=["whisperx"],
-    version="3.0.0",
+    version="3.0.2",
     description="Time-Accurate Automatic Speech Recognition using Whisper.",
     readme="README.md",
     python_requires=">=3.8",

@@ -19,7 +19,7 @@ setup(
         for r in pkg_resources.parse_requirements(
             open(os.path.join(os.path.dirname(__file__), "requirements.txt"))
         )
-    ],
+    ] + ["pyannote.audio @ git+https://github.com/pyannote/pyannote-audio@11b56a137a578db9335efc00298f6ec1932e6317"],
     entry_points = {
         'console_scripts': ['whisperx=whisperx.transcribe:cli'],
     },
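Version bump aside, the key packaging change is that `pyannote.audio` is no longer listed in requirements.txt; it is appended to `install_requires` as a direct URL requirement pinned to an exact commit. A minimal sketch of how that list is assembled (the inline requirements mirror the diff above; `pkg_resources` ships with setuptools):

```python
# Sketch of the install_requires construction: parse the pinned requirements,
# then append pyannote.audio pinned to a specific commit via a direct URL.
import pkg_resources

requirements_txt = """\
torch==2.0.0
torchaudio==2.0.1
faster-whisper
transformers
ffmpeg-python==0.2.0
"""

install_requires = [
    str(r) for r in pkg_resources.parse_requirements(requirements_txt)
] + [
    "pyannote.audio @ git+https://github.com/pyannote/pyannote-audio"
    "@11b56a137a578db9335efc00298f6ec1932e6317"
]
print(install_requires)
```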
whisperx/alignment.py

@@ -38,6 +38,7 @@ DEFAULT_ALIGN_MODELS_HF = {
     "fa": "jonatasgrosman/wav2vec2-large-xlsr-53-persian",
     "el": "jonatasgrosman/wav2vec2-large-xlsr-53-greek",
     "tr": "mpoyraz/wav2vec2-xls-r-300m-cv7-turkish",
+    "he": "imvladikon/wav2vec2-xls-r-300m-hebrew",
 }

@@ -231,8 +232,13 @@ def align(
         emission = emissions[0].cpu().detach()

-        trellis = get_trellis(emission, tokens)
-        path = backtrack(trellis, emission, tokens)
+        blank_id = 0
+        for char, code in model_dictionary.items():
+            if char == '[pad]' or char == '<pad>':
+                blank_id = code
+
+        trellis = get_trellis(emission, tokens, blank_id)
+        path = backtrack(trellis, emission, tokens, blank_id)
         if path is None:
             print(f'Failed to align segment ("{segment["text"]}"): backtrack failed, resorting to original...')
             break
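The second hunk stops hard-coding the CTC blank index: `align()` now looks up the pad token in the alignment model's character dictionary and threads the resulting `blank_id` through `get_trellis` and `backtrack`. A small, self-contained sketch of that lookup (the example dictionary is invented for illustration, not taken from a real model):

```python
# Resolve the CTC blank id from a wav2vec2-style {character: token_id} mapping,
# falling back to 0 when no pad token is present (mirrors the hunk above).
def resolve_blank_id(model_dictionary: dict) -> int:
    for char, code in model_dictionary.items():
        if char in ("[pad]", "<pad>"):
            return code
    return 0

# Illustrative dictionary whose pad token does not sit at index 0.
example_dict = {"<s>": 0, "</s>": 1, "<unk>": 2, "<pad>": 3, "|": 4, "a": 5}
print(resolve_blank_id(example_dict))       # -> 3
print(resolve_blank_id({"a": 0, "b": 1}))   # -> 0 (no pad token, default blank)
```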
@@ -262,8 +268,8 @@ def align(
             start, end, score = None, None, None
             if cdx in clean_cdx:
                 char_seg = char_segments[clean_cdx.index(cdx)]
-                start = char_seg.start * ratio + t1
-                end = char_seg.end * ratio + t1
+                start = round(char_seg.start * ratio + t1, 3)
+                end = round(char_seg.end * ratio + t1, 3)
                 score = char_seg.score

             char_segments_arr["char"].append(char)

@@ -439,8 +445,8 @@ def align(
             word_list.append(
                 {
                     "word": curr_text.rstrip(),
-                    "start": word_start,
-                    "end": word_end,
+                    "start": wseg.iloc[wdx]['start'],
+                    "end": wseg.iloc[wdx]['end'],
                 }
             )

@@ -450,8 +456,8 @@ def align(
                 "end": srow["end"],
                 "text": text,
                 "words": word_list,
-                # "word-segments": wseg,
-                # "char-segments": cseg
+                "word-segments": wseg,
+                "char-segments": cseg
             }
         )
whisperx/asr.py

@@ -207,7 +207,7 @@ class FasterWhisperPipeline(Pipeline):
         return final_iterator

     def transcribe(
-        self, audio: Union[str, np.ndarray], batch_size=None
+        self, audio: Union[str, np.ndarray], batch_size=None, num_workers=0
     ):
         if isinstance(audio, str):
             audio = load_audio(audio)

@@ -232,7 +232,7 @@ class FasterWhisperPipeline(Pipeline):

         segments = []
         batch_size = batch_size or self._batch_size
-        for idx, out in enumerate(self.__call__(data(audio, vad_segments), batch_size=batch_size)):
+        for idx, out in enumerate(self.__call__(data(audio, vad_segments), batch_size=batch_size, num_workers=num_workers)):
             text = out['text']
             if batch_size in [0, 1, None]:
                 text = text[0]
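`transcribe()` now forwards a `num_workers` argument to the underlying pipeline call alongside `batch_size`. A hedged usage sketch; `model` is assumed to be an already-constructed `FasterWhisperPipeline` (its loader is not part of this diff), and the audio path is a placeholder:

```python
# Hedged usage sketch: `model` is assumed to be a FasterWhisperPipeline built
# elsewhere; a string path is loaded internally via load_audio().
from typing import Union

import numpy as np


def run_transcription(model, audio: Union[str, np.ndarray]):
    # batch_size=None falls back to the pipeline's default batch size;
    # num_workers is passed through to the pipeline's __call__.
    return model.transcribe(audio, batch_size=8, num_workers=2)
```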
@@ -251,7 +251,10 @@ class FasterWhisperPipeline(Pipeline):

     def detect_language(self, audio: np.ndarray):
-        segment = log_mel_spectrogram(audio[: N_SAMPLES], padding=0)
+        if audio.shape[0] < N_SAMPLES:
+            print("Warning: audio is shorter than 30s, language detection may be inaccurate.")
+        segment = log_mel_spectrogram(audio[: N_SAMPLES],
+                                      padding=0 if audio.shape[0] >= N_SAMPLES else N_SAMPLES - audio.shape[0])
         encoder_output = self.model.encode(segment)
         results = self.model.model.detect_language(encoder_output)
         language_token, language_probability = results[0][0]
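For language detection, audio shorter than the 30-second window is now padded up to a full window instead of always using `padding=0`, with a warning that detection may be less accurate. The padding arithmetic, sketched with the usual Whisper constant of 16 kHz times 30 s (an assumption here; the constant's value is not shown in the diff):

```python
# Padding arithmetic for detect_language (sketch). N_SAMPLES is assumed to be
# the usual Whisper constant: 30 seconds of audio at 16 kHz.
N_SAMPLES = 30 * 16_000  # 480_000 samples

def mel_padding(n_audio_samples: int) -> int:
    # Mirror of: padding=0 if audio.shape[0] >= N_SAMPLES else N_SAMPLES - audio.shape[0]
    return 0 if n_audio_samples >= N_SAMPLES else N_SAMPLES - n_audio_samples

print(mel_padding(480_000))  # 0: a full 30 s window needs no padding
print(mel_padding(160_000))  # 320000: a 10 s clip is padded up to 30 s
```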
whisperx/diarize.py

@@ -1,14 +1,19 @@
 import numpy as np
 import pandas as pd
 from pyannote.audio import Pipeline
+from typing import Optional, Union
+import torch

 class DiarizationPipeline:
     def __init__(
         self,
         model_name="pyannote/speaker-diarization@2.1",
         use_auth_token=None,
+        device: Optional[Union[str, torch.device]] = "cpu",
     ):
-        self.model = Pipeline.from_pretrained(model_name, use_auth_token=use_auth_token)
+        if isinstance(device, str):
+            device = torch.device(device)
+        self.model = Pipeline.from_pretrained(model_name, use_auth_token=use_auth_token).to(device)

     def __call__(self, audio, min_speakers=None, max_speakers=None):
         segments = self.model(audio, min_speakers=min_speakers, max_speakers=max_speakers)
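`DiarizationPipeline` now accepts a `device` argument and moves the pyannote pipeline onto it instead of always using the default device. A hedged usage sketch (the import path, token, and audio path are assumptions for illustration; only the constructor and call signature come from the diff):

```python
# Hedged usage sketch: the token and audio path are placeholders; the module
# path whisperx.diarize is assumed, only the class interface comes from the diff.
import torch
from whisperx.diarize import DiarizationPipeline

device = "cuda" if torch.cuda.is_available() else "cpu"
diarize_model = DiarizationPipeline(use_auth_token="hf_xxx", device=device)

# min/max speakers are optional bounds forwarded to the pyannote pipeline.
diarize_segments = diarize_model("audio.wav", min_speakers=1, max_speakers=3)
print(diarize_segments)
```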
whisperx/transcribe.py

@@ -72,7 +72,6 @@ def cli():
     parser.add_argument("--hf_token", type=str, default=None, help="Hugging Face Access Token to access PyAnnote gated models")
     # parser.add_argument("--model_flush", action="store_true", help="Flush memory from each model after use, reduces GPU requirement but slower processing >1 audio file.")
-    parser.add_argument("--tmp_dir", default=None, help="Temporary directory to write audio file if input if not .wav format (only for VAD).")
     # fmt: on

     args = parser.parse_args().__dict__

@@ -86,10 +85,6 @@ def cli():
     # model_flush: bool = args.pop("model_flush")
     os.makedirs(output_dir, exist_ok=True)

-    tmp_dir: str = args.pop("tmp_dir")
-    if tmp_dir is not None:
-        os.makedirs(tmp_dir, exist_ok=True)
-
     align_model: str = args.pop("align_model")
     interpolate_method: str = args.pop("interpolate_method")
     no_align: bool = args.pop("no_align")

@@ -193,6 +188,7 @@ def cli():
         if hf_token is None:
             print("Warning, no --hf_token used, needs to be saved in environment variable, otherwise will throw error loading diarization model...")
+        tmp_results = results
         print(">>Performing diarization...")
         results = []
         diarize_model = DiarizationPipeline(use_auth_token=hf_token)
         for result, input_audio_path in tmp_results:

@@ -203,6 +199,12 @@ def cli():

     # >> Write
     for result, audio_path in results:
+        # Remove pandas dataframes from result so that
+        # we can serialize the result with json
+        for seg in result["segments"]:
+            seg.pop("word-segments", None)
+            seg.pop("char-segments", None)
+
         writer(result, audio_path, writer_args)

 if __name__ == "__main__":
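The new write step drops the pandas DataFrames (`"word-segments"`, `"char-segments"`) that the updated `align()` now attaches to each segment, since DataFrames cannot be serialized to JSON. A self-contained sketch of the same clean-up (the segment contents below are invented for illustration):

```python
# Strip non-JSON-serializable DataFrame fields before dumping a result to JSON
# (sketch; the segment values below are invented for illustration).
import json

import pandas as pd

result = {
    "segments": [
        {
            "start": 0.0,
            "end": 1.5,
            "text": "hello world",
            "word-segments": pd.DataFrame({"start": [0.0, 0.8], "end": [0.7, 1.5]}),
            "char-segments": pd.DataFrame({"char": list("hello")}),
        }
    ]
}

for seg in result["segments"]:
    seg.pop("word-segments", None)   # same clean-up as in cli()
    seg.pop("char-segments", None)

print(json.dumps(result))  # now serializes without a TypeError
```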