Mirror of https://github.com/m-bain/whisperX.git (synced 2025-07-01 18:17:27 -04:00)

Compare commits: 18 commits
| SHA1 |
|---|
| 4e2ac4e4e9 |
| d2116b98ca |
| d8f0ef4a19 |
| 1b62c61c71 |
| 2d59eb9726 |
| cb53661070 |
| 2a6830492c |
| da3aabe181 |
| 067189248f |
| 9fb51412c0 |
| 64ca208cc8 |
| 5becc99e56 |
| e24ca9e0a2 |
| 601c91140f |
| 31a9ec7466 |
| b9c8c5072b |
| a903e57cf1 |
| cb176a186e |
README.md (12 lines changed)
```diff
@@ -61,23 +61,23 @@ This repository refines the timestamps of openAI's Whisper model via forced alignment
 
 <h2 align="left" id="setup">Setup ⚙️</h2>
 
-Tested for PyTorch 0.11, Python 3.8 (use other versions at your own risk!)
+Tested for PyTorch 2.0, Python 3.10 (use other versions at your own risk!)
 
 GPU execution requires the NVIDIA libraries cuBLAS 11.x and cuDNN 8.x to be installed on the system. Please refer to the [CTranslate2 documentation](https://opennmt.net/CTranslate2/installation.html).
 
 
-### 1. Create Python3.8 environment
+### 1. Create Python3.10 environment
 
-`conda create --name whisperx python=3.8`
+`conda create --name whisperx python=3.10`
 
 `conda activate whisperx`
 
 
-### 2. Install PyTorch 0.11.0, e.g. for Linux and Windows:
+### 2. Install PyTorch2.0, e.g. for Linux and Windows CUDA11.7:
 
-`pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 torchaudio==0.11.0 --extra-index-url https://download.pytorch.org/whl/cu113`
+`pip3 install torch torchvision torchaudio`
 
-See other methods [here.](https://pytorch.org/get-started/previous-versions/#wheel-4)
+See other methods [here.](https://pytorch.org/get-started/locally/)
 
 ### 3. Install this repo
 
```
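A quick post-install sanity check, not part of the diff itself, is to confirm that a PyTorch 2.x build with CUDA support is active before installing whisperX. A minimal sketch:

```python
import torch

# Verify the PyTorch version and whether a CUDA device is visible.
print("torch:", torch.__version__)            # expected to start with "2."
print("cuda available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("device:", torch.cuda.get_device_name(0))
```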
```diff
@@ -1,6 +1,5 @@
-torch==1.11.0
+torch==2.0.0
-torchaudio==0.11.0
+torchaudio==2.0.1
-pyannote.audio
 faster-whisper
 transformers
 ffmpeg-python==0.2.0
```
setup.py (4 lines changed)
```diff
@@ -6,7 +6,7 @@ from setuptools import setup, find_packages
 setup(
     name="whisperx",
     py_modules=["whisperx"],
-    version="3.0.0",
+    version="3.0.2",
     description="Time-Accurate Automatic Speech Recognition using Whisper.",
     readme="README.md",
     python_requires=">=3.8",
@@ -19,7 +19,7 @@ setup(
         for r in pkg_resources.parse_requirements(
             open(os.path.join(os.path.dirname(__file__), "requirements.txt"))
         )
-    ],
+    ] + ["pyannote.audio @ git+https://github.com/pyannote/pyannote-audio@11b56a137a578db9335efc00298f6ec1932e6317"],
     entry_points = {
         'console_scripts': ['whisperx=whisperx.transcribe:cli'],
     },
```
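Read together with the preceding hunk (which appears to be requirements.txt, though the file header was not captured in this view), the pyannote.audio pin moves out of the plain requirements list and into setup.py as a PEP 508 direct reference to a specific git commit. A toy reproduction of how install_requires ends up assembled, assuming the same parsing helper, under fabricated requirement strings:

```python
import pkg_resources

# Toy sketch (not the project's actual setup.py): parse pinned requirements,
# then append the pyannote.audio direct reference in "name @ url" syntax.
pinned = ["torch==2.0.0", "torchaudio==2.0.1", "faster-whisper",
          "transformers", "ffmpeg-python==0.2.0"]
install_requires = [str(r) for r in pkg_resources.parse_requirements(pinned)] + [
    "pyannote.audio @ git+https://github.com/pyannote/pyannote-audio@11b56a137a578db9335efc00298f6ec1932e6317"
]
print(install_requires)
```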
```diff
@@ -38,6 +38,7 @@ DEFAULT_ALIGN_MODELS_HF = {
     "fa": "jonatasgrosman/wav2vec2-large-xlsr-53-persian",
     "el": "jonatasgrosman/wav2vec2-large-xlsr-53-greek",
     "tr": "mpoyraz/wav2vec2-xls-r-300m-cv7-turkish",
+    "he": "imvladikon/wav2vec2-xls-r-300m-hebrew",
 }
 
 
```
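With the new "he" entry in DEFAULT_ALIGN_MODELS_HF (this hunk appears to come from the alignment module), Hebrew should resolve to the imvladikon wav2vec2 model without an explicit model name. A hedged usage sketch, assuming load_align_model keeps its usual (language_code, device) signature and a GPU is available:

```python
import whisperx

# Hypothetical call: load the default Hebrew alignment model added in this change.
align_model, metadata = whisperx.load_align_model(language_code="he", device="cuda")
```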
```diff
@@ -231,8 +232,13 @@ def align(
 
         emission = emissions[0].cpu().detach()
 
-        trellis = get_trellis(emission, tokens)
-        path = backtrack(trellis, emission, tokens)
+        blank_id = 0
+        for char, code in model_dictionary.items():
+            if char == '[pad]' or char == '<pad>':
+                blank_id = code
+
+        trellis = get_trellis(emission, tokens, blank_id)
+        path = backtrack(trellis, emission, tokens, blank_id)
         if path is None:
             print(f'Failed to align segment ("{segment["text"]}"): backtrack failed, resorting to original...')
             break
```
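The change stops hard-coding the CTC blank token as id 0 and instead looks it up from the model dictionary, since different wav2vec2 vocabularies place their pad/blank token at different ids. A self-contained sketch of the lookup with a fabricated toy dictionary:

```python
# Toy dictionary (not a real wav2vec2 vocabulary) where the pad token is not id 0.
model_dictionary = {"|": 4, "e": 5, "t": 6, "[pad]": 30}

blank_id = 0
for char, code in model_dictionary.items():
    if char == '[pad]' or char == '<pad>':
        blank_id = code

print(blank_id)  # -> 30, the id passed on to get_trellis/backtrack
```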
```diff
@@ -262,8 +268,8 @@ def align(
             start, end, score = None, None, None
             if cdx in clean_cdx:
                 char_seg = char_segments[clean_cdx.index(cdx)]
-                start = char_seg.start * ratio + t1
-                end = char_seg.end * ratio + t1
+                start = round(char_seg.start * ratio + t1, 3)
+                end = round(char_seg.end * ratio + t1, 3)
                 score = char_seg.score
 
             char_segments_arr["char"].append(char)
```
```diff
@@ -439,8 +445,8 @@ def align(
                 word_list.append(
                     {
                         "word": curr_text.rstrip(),
-                        "start": word_start,
-                        "end": word_end,
+                        "start": wseg.iloc[wdx]['start'],
+                        "end": wseg.iloc[wdx]['end'],
                     }
                 )
 
```
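The word timings now come straight from the per-segment word DataFrame rather than from separately tracked word_start/word_end variables. A toy illustration of the .iloc access pattern above, with fabricated values:

```python
import pandas as pd

# wseg stands in for the per-segment word-timing DataFrame; .iloc[wdx] selects
# the wdx-th word row by position and its aligned start/end are written out.
wseg = pd.DataFrame({"start": [0.12, 0.55, 1.02], "end": [0.50, 0.98, 1.40]})
wdx = 1
print(wseg.iloc[wdx]['start'], wseg.iloc[wdx]['end'])  # 0.55 0.98
```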
```diff
@@ -450,8 +456,8 @@ def align(
                 "end": srow["end"],
                 "text": text,
                 "words": word_list,
-                # "word-segments": wseg,
-                # "char-segments": cseg
+                "word-segments": wseg,
+                "char-segments": cseg
             }
         )
 
```
```diff
@@ -207,7 +207,7 @@ class FasterWhisperPipeline(Pipeline):
         return final_iterator
 
     def transcribe(
-        self, audio: Union[str, np.ndarray], batch_size=None
+        self, audio: Union[str, np.ndarray], batch_size=None, num_workers=0
     ):
         if isinstance(audio, str):
             audio = load_audio(audio)
```
```diff
@@ -232,7 +232,7 @@ class FasterWhisperPipeline(Pipeline):
 
         segments = []
         batch_size = batch_size or self._batch_size
-        for idx, out in enumerate(self.__call__(data(audio, vad_segments), batch_size=batch_size)):
+        for idx, out in enumerate(self.__call__(data(audio, vad_segments), batch_size=batch_size, num_workers=num_workers)):
             text = out['text']
             if batch_size in [0, 1, None]:
                 text = text[0]
```
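The two hunks above expose a num_workers argument on FasterWhisperPipeline.transcribe and forward it to the underlying pipeline call so batch preparation can use worker processes. A hedged usage sketch; the model name, compute_type, and audio path are placeholders and not part of the diff:

```python
import whisperx

# Assumed entry points: whisperx.load_model / load_audio as in other 3.x releases.
model = whisperx.load_model("large-v2", device="cuda", compute_type="float16")
audio = whisperx.load_audio("audio.wav")

# batch_size and the new num_workers are forwarded to the batched pipeline call.
result = model.transcribe(audio, batch_size=8, num_workers=2)
print(result["segments"][0]["text"])
```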
```diff
@@ -251,7 +251,10 @@ class FasterWhisperPipeline(Pipeline):
 
 
     def detect_language(self, audio: np.ndarray):
-        segment = log_mel_spectrogram(audio[: N_SAMPLES], padding=0)
+        if audio.shape[0] < N_SAMPLES:
+            print("Warning: audio is shorter than 30s, language detection may be inaccurate.")
+        segment = log_mel_spectrogram(audio[: N_SAMPLES],
+                                      padding=0 if audio.shape[0] >= N_SAMPLES else N_SAMPLES - audio.shape[0])
         encoder_output = self.model.encode(segment)
         results = self.model.model.detect_language(encoder_output)
         language_token, language_probability = results[0][0]
```
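The language-detection fix pads short clips up to the 30-second window instead of passing them through unpadded. A worked example of the new padding expression, assuming Whisper's usual 16 kHz sample rate and N_SAMPLES of 480000:

```python
SAMPLE_RATE = 16000
N_SAMPLES = 30 * SAMPLE_RATE      # 480000 samples = 30 s window

n = 10 * SAMPLE_RATE              # a 10 s clip -> 160000 samples
padding = 0 if n >= N_SAMPLES else N_SAMPLES - n
print(padding)                    # -> 320000 samples of padding, plus the short-audio warning
```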
```diff
@@ -1,14 +1,19 @@
 import numpy as np
 import pandas as pd
 from pyannote.audio import Pipeline
+from typing import Optional, Union
+import torch
 
 class DiarizationPipeline:
     def __init__(
         self,
         model_name="pyannote/speaker-diarization@2.1",
         use_auth_token=None,
+        device: Optional[Union[str, torch.device]] = "cpu",
     ):
-        self.model = Pipeline.from_pretrained(model_name, use_auth_token=use_auth_token)
+        if isinstance(device, str):
+            device = torch.device(device)
+        self.model = Pipeline.from_pretrained(model_name, use_auth_token=use_auth_token).to(device)
 
     def __call__(self, audio, min_speakers=None, max_speakers=None):
         segments = self.model(audio, min_speakers=min_speakers, max_speakers=max_speakers)
```
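DiarizationPipeline now accepts a device argument and moves the pyannote pipeline onto it instead of always running on the default device. A hedged usage sketch; the token and audio path are placeholders, and the top-level whisperx export is an assumption (the class itself lives in the diarize module):

```python
import whisperx
# alternatively: from whisperx.diarize import DiarizationPipeline

# Load the gated pyannote model onto the GPU via the new device argument.
diarize_model = whisperx.DiarizationPipeline(use_auth_token="hf_...", device="cuda")
diarize_segments = diarize_model("audio.wav", min_speakers=1, max_speakers=2)
```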
```diff
@@ -72,7 +72,6 @@ def cli():
 
     parser.add_argument("--hf_token", type=str, default=None, help="Hugging Face Access Token to access PyAnnote gated models")
     # parser.add_argument("--model_flush", action="store_true", help="Flush memory from each model after use, reduces GPU requirement but slower processing >1 audio file.")
-    parser.add_argument("--tmp_dir", default=None, help="Temporary directory to write audio file if input if not .wav format (only for VAD).")
     # fmt: on
 
     args = parser.parse_args().__dict__
```
```diff
@@ -86,10 +85,6 @@ def cli():
     # model_flush: bool = args.pop("model_flush")
     os.makedirs(output_dir, exist_ok=True)
 
-    tmp_dir: str = args.pop("tmp_dir")
-    if tmp_dir is not None:
-        os.makedirs(tmp_dir, exist_ok=True)
-
     align_model: str = args.pop("align_model")
     interpolate_method: str = args.pop("interpolate_method")
     no_align: bool = args.pop("no_align")
```
```diff
@@ -193,6 +188,7 @@ def cli():
         if hf_token is None:
             print("Warning, no --hf_token used, needs to be saved in environment variable, otherwise will throw error loading diarization model...")
         tmp_results = results
+        print(">>Performing diarization...")
         results = []
         diarize_model = DiarizationPipeline(use_auth_token=hf_token)
         for result, input_audio_path in tmp_results:
```
```diff
@@ -203,6 +199,12 @@ def cli():
 
     # >> Write
     for result, audio_path in results:
+        # Remove pandas dataframes from result so that
+        # we can serialize the result with json
+        for seg in result["segments"]:
+            seg.pop("word-segments", None)
+            seg.pop("char-segments", None)
+
         writer(result, audio_path, writer_args)
 
 if __name__ == "__main__":
```
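Since the align step now keeps the "word-segments" and "char-segments" DataFrames in each segment, the CLI strips them before writing output. A self-contained sketch with toy data showing why the pop() calls are needed:

```python
import json
import pandas as pd

# Toy segment: pandas DataFrames are not JSON-serializable, so leaving
# "word-segments" in place would make a JSON writer raise TypeError.
seg = {
    "text": "hello world",
    "words": [{"word": "hello"}],
    "word-segments": pd.DataFrame({"start": [0.0], "end": [0.4]}),
}

try:
    json.dumps(seg)
except TypeError as e:
    print("without pop:", e)

seg.pop("word-segments", None)
seg.pop("char-segments", None)
print(json.dumps(seg))  # serializes cleanly once the DataFrames are removed
```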