Mirror of https://github.com/m-bain/whisperX.git (synced 2025-07-01 18:17:27 -04:00)
Compare commits
233 Commits
(Commit list: 233 abbreviated commit SHAs from the compare view; the author, date, and commit-message columns were not preserved in this mirror.)
.github/workflows/build-and-release.yml (vendored, new file, 37 lines)
@@ -0,0 +1,37 @@
name: Build and release

on:
  release:
    types: [published]

jobs:
  build:
    runs-on: ubuntu-latest
    steps:
      - name: Checkout
        uses: actions/checkout@v4
        with:
          ref: ${{ github.ref_name }}

      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: "3.9"

      - name: Install dependencies
        run: |
          python -m pip install build

      - name: Build wheels
        run: python -m build --wheel

      - name: Release to Github
        uses: softprops/action-gh-release@v2
        with:
          files: dist/*

      - name: Publish package to PyPi
        uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
        with:
          user: __token__
          password: ${{ secrets.PYPI_API_TOKEN }}
.github/workflows/python-compatibility.yml (vendored, new file, 32 lines)
@@ -0,0 +1,32 @@
name: Python Compatibility Test

on:
  push:
    branches: [main]
  pull_request:
    branches: [main]
  workflow_dispatch: # Allows manual triggering from GitHub UI

jobs:
  test:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: ["3.9", "3.10", "3.11", "3.12"]

    steps:
      - uses: actions/checkout@v4

      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}

      - name: Install package
        run: |
          python -m pip install --upgrade pip
          pip install .

      - name: Test import
        run: |
          python -c "import whisperx; print('Successfully imported whisperx')"
.github/workflows/tmp.yml (vendored, new file, 35 lines)
@@ -0,0 +1,35 @@
name: Python Compatibility Test (PyPi)

on:
  push:
    branches: [main]
  pull_request:
    branches: [main]
  workflow_dispatch: # Allows manual triggering from GitHub UI

jobs:
  test:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: ["3.9", "3.10", "3.11", "3.12"]

    steps:
      - uses: actions/checkout@v4

      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}

      - name: Install package
        run: |
          pip install whisperx

      - name: Print packages
        run: |
          pip list

      - name: Test import
        run: |
          python -c "import whisperx; print('Successfully imported whisperx')"
.gitignore (vendored, 172 changed lines)
@@ -1,3 +1,171 @@
whisperx.egg-info/
**/__pycache__/
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# UV
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
#uv.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
.pdm.toml
.pdm-python
.pdm-build/

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

# PyPI configuration file
.pypirc
LICENSE (39 changed lines)
@@ -1,27 +1,24 @@
Copyright (c) 2022, Max Bain
All rights reserved.
BSD 2-Clause License

Copyright (c) 2024, Max Bain

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright
   notice, this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright
   notice, this list of conditions and the following disclaimer in the
   documentation and/or other materials provided with the distribution.
3. All advertising materials mentioning features or use of this software
   must display the following acknowledgement:
   This product includes software developed by Max Bain.
4. Neither the name of Max Bain nor the
   names of its contributors may be used to endorse or promote products
   derived from this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDER ''AS IS'' AND ANY
EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
1. Redistributions of source code must retain the above copyright notice, this
   list of conditions and the following disclaimer.

2. Redistributions in binary form must reproduce the above copyright notice,
   this list of conditions and the following disclaimer in the documentation
   and/or other materials provided with the distribution.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
MANIFEST.in
@@ -1,4 +1,3 @@
include whisperx/assets/*
include whisperx/assets/gpt2/*
include whisperx/assets/multilingual/*
include whisperx/normalizers/english.json
include LICENSE
include requirements.txt
README.md (79 changed lines)
@@ -23,7 +23,7 @@
</p>


<img width="1216" align="center" alt="whisperx-arch" src="figures/pipeline.png">
<img width="1216" align="center" alt="whisperx-arch" src="https://raw.githubusercontent.com/m-bain/whisperX/refs/heads/main/figures/pipeline.png">


<!-- <p align="left">Whisper-Based Automatic Speech Recognition (ASR) with improved timestamp accuracy + quality via forced phoneme alignment and voice-activity based batching for fast inference.</p> -->
@@ -54,6 +54,8 @@ This repository provides fast automatic speech recognition (70x realtime with la

<h2 align="left", id="highlights">New🚨</h2>

- 1st place at [Ego4d transcription challenge](https://eval.ai/web/challenges/challenge-page/1637/leaderboard/3931/WER) 🏆
- _WhisperX_ accepted at INTERSPEECH 2023
- v3 transcript segment-per-sentence: using nltk sent_tokenize for better subtitlting & better diarization
- v3 released, 70x speed-up open-sourced. Using batched whisper with [faster-whisper](https://github.com/guillaumekln/faster-whisper) backend!
- v2 released, code cleanup, imports whisper library VAD filtering is now turned on by default, as in the paper.
@@ -72,31 +74,53 @@ GPU execution requires the NVIDIA libraries cuBLAS 11.x and cuDNN 8.x to be inst
`conda activate whisperx`


### 2. Install PyTorch2.0, e.g. for Linux and Windows CUDA11.7:
### 2. Install PyTorch, e.g. for Linux and Windows CUDA11.8:

`conda install pytorch==2.0.0 torchvision==0.15.0 torchaudio==2.0.0 pytorch-cuda=11.7 -c pytorch -c nvidia`
`conda install pytorch==2.0.0 torchaudio==2.0.0 pytorch-cuda=11.8 -c pytorch -c nvidia`

See other methods [here.](https://pytorch.org/get-started/previous-versions/#v200)

### 3. Install this repo
### 3. Install WhisperX

`pip install git+https://github.com/m-bain/whisperx.git`
You have several installation options:

If already installed, update package to most recent commit
#### Option A: Stable Release (recommended)
Install the latest stable version from PyPI:

`pip install git+https://github.com/m-bain/whisperx.git --upgrade`

If wishing to modify this package, clone and install in editable mode:
```bash
pip install whisperx
```
$ git clone https://github.com/m-bain/whisperX.git
$ cd whisperX
$ pip install -e .

#### Option B: Development Version
Install the latest development version directly from GitHub (may be unstable):

```bash
pip install git+https://github.com/m-bain/whisperx.git
```

If already installed, update to the most recent commit:

```bash
pip install git+https://github.com/m-bain/whisperx.git --upgrade
```

#### Option C: Development Mode
If you wish to modify the package, clone and install in editable mode:
```bash
git clone https://github.com/m-bain/whisperX.git
cd whisperX
pip install -e .
```

> **Note**: The development version may contain experimental features and bugs. Use the stable PyPI release for production environments.

You may also need to install ffmpeg, rust etc. Follow openAI instructions here https://github.com/openai/whisper#setup.

### Speaker Diarization
To **enable Speaker. Diarization**, include your Hugging Face access token that you can generate from [Here](https://huggingface.co/settings/tokens) after the `--hf_token` argument and accept the user agreement for the following models: [Segmentation](https://huggingface.co/pyannote/segmentation) , [Voice Activity Detection (VAD)](https://huggingface.co/pyannote/voice-activity-detection) , and [Speaker Diarization](https://huggingface.co/pyannote/speaker-diarization)
To **enable Speaker Diarization**, include your Hugging Face access token (read) that you can generate from [Here](https://huggingface.co/settings/tokens) after the `--hf_token` argument and accept the user agreement for the following models: [Segmentation](https://huggingface.co/pyannote/segmentation-3.0) and [Speaker-Diarization-3.1](https://huggingface.co/pyannote/speaker-diarization-3.1) (if you choose to use Speaker-Diarization 2.x, follow requirements [here](https://huggingface.co/pyannote/speaker-diarization) instead.)

> **Note**<br>
> As of Oct 11, 2023, there is a known issue regarding slow performance with pyannote/Speaker-Diarization-3.0 in whisperX. It is due to dependency conflicts between faster-whisper and pyannote-audio 3.0.0. Please see [this issue](https://github.com/m-bain/whisperX/issues/499) for more details and potential workarounds.


<h2 align="left" id="example">Usage 💬 (command line)</h2>
@@ -126,6 +150,10 @@ To label the transcript with speaker ID's (set number of speakers if known e.g.

    whisperx examples/sample01.wav --model large-v2 --diarize --highlight_words True

To run on CPU instead of GPU (and for running on Mac OS X):

    whisperx examples/sample01.wav --compute_type int8

### Other languages

The phoneme ASR alignment model is *language-specific*, for tested languages these models are [automatically picked from torchaudio pipelines or huggingface](https://github.com/m-bain/whisperX/blob/e909f2f766b23b2000f2d95df41f9b844ac53e49/whisperx/transcribe.py#L22).
@@ -156,6 +184,10 @@ compute_type = "float16" # change to "int8" if low on GPU mem (may reduce accura

# 1. Transcribe with original whisper (batched)
model = whisperx.load_model("large-v2", device, compute_type=compute_type)

# save model to local path (optional)
# model_dir = "/path/"
# model = whisperx.load_model("large-v2", device, compute_type=compute_type, download_root=model_dir)

audio = whisperx.load_audio(audio_file)
result = model.transcribe(audio, batch_size=batch_size)
print(result["segments"]) # before alignment
@@ -176,14 +208,21 @@ print(result["segments"]) # after alignment

diarize_model = whisperx.DiarizationPipeline(use_auth_token=YOUR_HF_TOKEN, device=device)

# add min/max number of speakers if known
diarize_segments = diarize_model(audio_file)
# diarize_model(audio_file, min_speakers=min_speakers, max_speakers=max_speakers)
diarize_segments = diarize_model(audio)
# diarize_model(audio, min_speakers=min_speakers, max_speakers=max_speakers)

result = whisperx.assign_word_speakers(diarize_segments, result)
print(diarize_segments)
print(result["segments"]) # segments are now assigned speaker IDs
```

## Demos 🚀

[](https://replicate.com/victor-upmeet/whisperx)
[](https://replicate.com/daanelson/whisperx)
[](https://replicate.com/carnifexer/whisperx)

If you don't have access to your own GPUs, use the links above to try out WhisperX.

<h2 align="left" id="whisper-mod">Technical Details 👷♂️</h2>

@@ -196,14 +235,14 @@ To reduce GPU memory requirements, try any of the following (2. & 3. can affect

Transcription differences from openai's whisper:
1. Transcription without timestamps. To enable single pass batching, whisper inference is performed `--without_timestamps True`, this ensures 1 forward pass per sample in the batch. However, this can cause discrepancies the default whisper output.
2. VAD-based segment transcription, unlike the buffered transcription of openai's. In Wthe WhisperX paper we show this reduces WER, and enables accurate batched inference
2. VAD-based segment transcription, unlike the buffered transcription of openai's. In the WhisperX paper we show this reduces WER, and enables accurate batched inference
3. `--condition_on_prev_text` is set to `False` by default (reduces hallucination)

<h2 align="left" id="limitations">Limitations ⚠️</h2>

- Transcript words which do not contain characters in the alignment models dictionary e.g. "2014." or "£13.60" cannot be aligned and therefore are not given a timing.
- Overlapping speech is not handled particularly well by whisper nor whisperx
- Diarization is far from perfect (working on this with custom model v4 -- see contact me).
- Diarization is far from perfect
- Language specific wav2vec2 model is needed


@@ -247,7 +286,7 @@ Bug finding and pull requests are also highly appreciated to keep this project g
<h2 align="left" id="contact">Contact/Support 📇</h2>


Contact maxhbain@gmail.com for queries. WhisperX v4 development is underway with with siginificantly improved diarization. To support v4 and get early access, get in touch.
Contact maxhbain@gmail.com for queries.

<a href="https://www.buymeacoffee.com/maxhbain" target="_blank"><img src="https://cdn.buymeacoffee.com/buttons/default-orange.png" alt="Buy Me A Coffee" height="41" width="174"></a>

@@ -261,7 +300,7 @@ Borrows important alignment code from [PyTorch tutorial on forced alignment](htt
And uses the wonderful pyannote VAD / Diarization https://github.com/pyannote/pyannote-audio


Valuable VAD & Diarization Models from [pyannote audio][https://github.com/pyannote/pyannote-audio]
Valuable VAD & Diarization Models from [pyannote audio](https://github.com/pyannote/pyannote-audio)

Great backend from [faster-whisper](https://github.com/guillaumekln/faster-whisper) and [CTranslate2](https://github.com/OpenNMT/CTranslate2)

@@ -276,7 +315,7 @@ If you use this in your research, please cite the paper:

@article{bain2022whisperx,
  title={WhisperX: Time-Accurate Speech Transcription of Long-Form Audio},
  author={Bain, Max and Huh, Jaesung and Han, Tengda and Zisserman, Andrew},
  journal={arXiv preprint, arXiv:2303.00747},
  journal={INTERSPEECH 2023},
  year={2023}
}
```
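For reference, the updated README snippets above fit together as a single pipeline. The following is a minimal sketch, not part of the diff; the audio path, token placeholder, and device choice are illustrative assumptions.

```python
import whisperx

device = "cuda"                         # use "cpu" with compute_type="int8" on machines without a GPU
audio_file = "examples/sample01.wav"    # assumed sample path from the README
batch_size = 16

# 1. Batched transcription with the faster-whisper backend
model = whisperx.load_model("large-v2", device, compute_type="float16")
audio = whisperx.load_audio(audio_file)
result = model.transcribe(audio, batch_size=batch_size)

# 2. Forced phoneme alignment for accurate word-level timestamps
align_model, metadata = whisperx.load_align_model(language_code=result["language"], device=device)
result = whisperx.align(result["segments"], align_model, metadata, audio, device)

# 3. Speaker diarization (requires a Hugging Face token with the pyannote model agreements accepted)
diarize_model = whisperx.DiarizationPipeline(use_auth_token="YOUR_HF_TOKEN", device=device)
diarize_segments = diarize_model(audio)  # the diff passes the loaded audio array, not the file path
result = whisperx.assign_word_speakers(diarize_segments, result)
print(result["segments"])
```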
requirements.txt
@@ -1,8 +1,8 @@
torch==2.0.0
torchaudio==2.0.1
faster-whisper
torch>=2
torchaudio>=2
faster-whisper==1.1.0
ctranslate2>=4.5.0
transformers
ffmpeg-python==0.2.0
pandas
setuptools==65.6.3
nltk
setuptools>=65
nltk
setup.py (23 changed lines)
@@ -1,28 +1,33 @@
import os

import pkg_resources
from setuptools import setup, find_packages
from setuptools import find_packages, setup

with open("README.md", "r", encoding="utf-8") as f:
    long_description = f.read()

setup(
    name="whisperx",
    py_modules=["whisperx"],
    version="3.1.1",
    version="3.3.2",
    description="Time-Accurate Automatic Speech Recognition using Whisper.",
    readme="README.md",
    python_requires=">=3.8",
    long_description=long_description,
    long_description_content_type="text/markdown",
    python_requires=">=3.9, <3.13",
    author="Max Bain",
    url="https://github.com/m-bain/whisperx",
    license="MIT",
    license="BSD-2-Clause",
    packages=find_packages(exclude=["tests*"]),
    install_requires=[
        str(r)
        for r in pkg_resources.parse_requirements(
            open(os.path.join(os.path.dirname(__file__), "requirements.txt"))
        )
    ] + ["pyannote.audio @ git+https://github.com/pyannote/pyannote-audio@11b56a137a578db9335efc00298f6ec1932e6317"],
    entry_points = {
        'console_scripts': ['whisperx=whisperx.transcribe:cli'],
    ]
    + [f"pyannote.audio==3.3.2"],
    entry_points={
        "console_scripts": ["whisperx=whisperx.transcribe:cli"],
    },
    include_package_data=True,
    extras_require={'dev': ['pytest']},
    extras_require={"dev": ["pytest"]},
)
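The setup.py above keeps deriving install_requires from requirements.txt via pkg_resources. A small sketch of what that parsing produces; it assumes it is run next to requirements.txt, and note that pkg_resources ships with setuptools but is deprecated in newer releases.

```python
import os
import pkg_resources

# Parse the pinned requirements the same way setup.py does and show the
# resulting install_requires strings, e.g. ['torch>=2', 'faster-whisper==1.1.0', ...]
here = os.path.dirname(os.path.abspath(__file__))
with open(os.path.join(here, "requirements.txt")) as f:
    install_requires = [str(req) for req in pkg_resources.parse_requirements(f)]
print(install_requires)
```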
whisperx/SubtitlesProcessor.py (new file, 227 lines)
@@ -0,0 +1,227 @@
import math
from .conjunctions import get_conjunctions, get_comma
from typing import TextIO

def normal_round(n):
    if n - math.floor(n) < 0.5:
        return math.floor(n)
    return math.ceil(n)


def format_timestamp(seconds: float, is_vtt: bool = False):

    assert seconds >= 0, "non-negative timestamp expected"
    milliseconds = round(seconds * 1000.0)

    hours = milliseconds // 3_600_000
    milliseconds -= hours * 3_600_000

    minutes = milliseconds // 60_000
    milliseconds -= minutes * 60_000

    seconds = milliseconds // 1_000
    milliseconds -= seconds * 1_000

    separator = '.' if is_vtt else ','

    hours_marker = f"{hours:02d}:"
    return (
        f"{hours_marker}{minutes:02d}:{seconds:02d}{separator}{milliseconds:03d}"
    )


class SubtitlesProcessor:
    def __init__(self, segments, lang, max_line_length = 45, min_char_length_splitter = 30, is_vtt = False):
        self.comma = get_comma(lang)
        self.conjunctions = set(get_conjunctions(lang))
        self.segments = segments
        self.lang = lang
        self.max_line_length = max_line_length
        self.min_char_length_splitter = min_char_length_splitter
        self.is_vtt = is_vtt
        complex_script_languages = ['th', 'lo', 'my', 'km', 'am', 'ko', 'ja', 'zh', 'ti', 'ta', 'te', 'kn', 'ml', 'hi', 'ne', 'mr', 'ar', 'fa', 'ur', 'ka']
        if self.lang in complex_script_languages:
            self.max_line_length = 30
            self.min_char_length_splitter = 20

    def estimate_timestamp_for_word(self, words, i, next_segment_start_time=None):
        k = 0.25
        has_prev_end = i > 0 and 'end' in words[i - 1]
        has_next_start = i < len(words) - 1 and 'start' in words[i + 1]

        if has_prev_end:
            words[i]['start'] = words[i - 1]['end']
            if has_next_start:
                words[i]['end'] = words[i + 1]['start']
            else:
                if next_segment_start_time:
                    words[i]['end'] = next_segment_start_time if next_segment_start_time - words[i - 1]['end'] <= 1 else next_segment_start_time - 0.5
                else:
                    words[i]['end'] = words[i]['start'] + len(words[i]['word']) * k

        elif has_next_start:
            words[i]['start'] = words[i + 1]['start'] - len(words[i]['word']) * k
            words[i]['end'] = words[i + 1]['start']

        else:
            if next_segment_start_time:
                words[i]['start'] = next_segment_start_time - 1
                words[i]['end'] = next_segment_start_time - 0.5
            else:
                words[i]['start'] = 0
                words[i]['end'] = 0


    def process_segments(self, advanced_splitting=True):
        subtitles = []
        for i, segment in enumerate(self.segments):
            next_segment_start_time = self.segments[i + 1]['start'] if i + 1 < len(self.segments) else None

            if advanced_splitting:

                split_points = self.determine_advanced_split_points(segment, next_segment_start_time)
                subtitles.extend(self.generate_subtitles_from_split_points(segment, split_points, next_segment_start_time))
            else:
                words = segment['words']
                for i, word in enumerate(words):
                    if 'start' not in word or 'end' not in word:
                        self.estimate_timestamp_for_word(words, i, next_segment_start_time)

                subtitles.append({
                    'start': segment['start'],
                    'end': segment['end'],
                    'text': segment['text']
                })

        return subtitles

    def determine_advanced_split_points(self, segment, next_segment_start_time=None):
        split_points = []
        last_split_point = 0
        char_count = 0

        words = segment.get('words', segment['text'].split())
        add_space = 0 if self.lang in ['zh', 'ja'] else 1

        total_char_count = sum(len(word['word']) if isinstance(word, dict) else len(word) + add_space for word in words)
        char_count_after = total_char_count

        for i, word in enumerate(words):
            word_text = word['word'] if isinstance(word, dict) else word
            word_length = len(word_text) + add_space
            char_count += word_length
            char_count_after -= word_length

            char_count_before = char_count - word_length

            if isinstance(word, dict) and ('start' not in word or 'end' not in word):
                self.estimate_timestamp_for_word(words, i, next_segment_start_time)

            if char_count >= self.max_line_length:
                midpoint = normal_round((last_split_point + i) / 2)
                if char_count_before >= self.min_char_length_splitter:
                    split_points.append(midpoint)
                    last_split_point = midpoint + 1
                    char_count = sum(len(words[j]['word']) if isinstance(words[j], dict) else len(words[j]) + add_space for j in range(last_split_point, i + 1))

            elif word_text.endswith(self.comma) and char_count_before >= self.min_char_length_splitter and char_count_after >= self.min_char_length_splitter:
                split_points.append(i)
                last_split_point = i + 1
                char_count = 0

            elif word_text.lower() in self.conjunctions and char_count_before >= self.min_char_length_splitter and char_count_after >= self.min_char_length_splitter:
                split_points.append(i - 1)
                last_split_point = i
                char_count = word_length

        return split_points


    def generate_subtitles_from_split_points(self, segment, split_points, next_start_time=None):
        subtitles = []

        words = segment.get('words', segment['text'].split())
        total_word_count = len(words)
        total_time = segment['end'] - segment['start']
        elapsed_time = segment['start']
        prefix = ' ' if self.lang not in ['zh', 'ja'] else ''
        start_idx = 0
        for split_point in split_points:

            fragment_words = words[start_idx:split_point + 1]
            current_word_count = len(fragment_words)

            if isinstance(fragment_words[0], dict):
                start_time = fragment_words[0]['start']
                end_time = fragment_words[-1]['end']
                next_start_time_for_word = words[split_point + 1]['start'] if split_point + 1 < len(words) else None
                if next_start_time_for_word and (next_start_time_for_word - end_time) <= 0.8:
                    end_time = next_start_time_for_word
            else:
                fragment = prefix.join(fragment_words).strip()
                current_duration = (current_word_count / total_word_count) * total_time
                start_time = elapsed_time
                end_time = elapsed_time + current_duration
                elapsed_time += current_duration

            subtitles.append({
                'start': start_time,
                'end': end_time,
                'text': fragment if not isinstance(fragment_words[0], dict) else prefix.join(word['word'] for word in fragment_words)
            })

            start_idx = split_point + 1

        # Handle the last fragment
        if start_idx < len(words):
            fragment_words = words[start_idx:]
            current_word_count = len(fragment_words)

            if isinstance(fragment_words[0], dict):
                start_time = fragment_words[0]['start']
                end_time = fragment_words[-1]['end']
            else:
                fragment = prefix.join(fragment_words).strip()
                current_duration = (current_word_count / total_word_count) * total_time
                start_time = elapsed_time
                end_time = elapsed_time + current_duration

            if next_start_time and (next_start_time - end_time) <= 0.8:
                end_time = next_start_time

            subtitles.append({
                'start': start_time,
                'end': end_time if end_time is not None else segment['end'],
                'text': fragment if not isinstance(fragment_words[0], dict) else prefix.join(word['word'] for word in fragment_words)
            })

        return subtitles


    def save(self, filename="subtitles.srt", advanced_splitting=True):

        subtitles = self.process_segments(advanced_splitting)

        def write_subtitle(file, idx, start_time, end_time, text):

            file.write(f"{idx}\n")
            file.write(f"{start_time} --> {end_time}\n")
            file.write(text + "\n\n")

        with open(filename, 'w', encoding='utf-8') as file:
            if self.is_vtt:
                file.write("WEBVTT\n\n")

            if advanced_splitting:
                for idx, subtitle in enumerate(subtitles, 1):
                    start_time = format_timestamp(subtitle['start'], self.is_vtt)
                    end_time = format_timestamp(subtitle['end'], self.is_vtt)
                    text = subtitle['text'].strip()
                    write_subtitle(file, idx, start_time, end_time, text)

        return len(subtitles)
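The new SubtitlesProcessor consumes aligned segments (dicts with start, end, text, and per-word timings) and writes numbered SRT or VTT cues via save(). A minimal usage sketch; the result variable is assumed to be the output of whisperx.align and "en" the transcript language:

```python
from whisperx.SubtitlesProcessor import SubtitlesProcessor

# result["segments"] is assumed to come from whisperx.align(), so each segment
# carries "start", "end", "text" and a "words" list with per-word timestamps.
srt_processor = SubtitlesProcessor(result["segments"], lang="en", is_vtt=False)
num_cues = srt_processor.save("subtitles.srt", advanced_splitting=True)
print(f"wrote {num_cues} SRT cues")

# is_vtt=True switches the timestamp separator to '.' and writes the WEBVTT header.
vtt_processor = SubtitlesProcessor(result["segments"], lang="en", is_vtt=True)
vtt_processor.save("subtitles.vtt", advanced_splitting=True)
```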
whisperx/alignment.py
@@ -3,7 +3,7 @@ Forced Alignment with Whisper
C. Max Bain
"""
from dataclasses import dataclass
from typing import Iterator, Union, List
from typing import Iterable, Optional, Union, List

import numpy as np
import pandas as pd
@@ -15,6 +15,9 @@ from .audio import SAMPLE_RATE, load_audio
from .utils import interpolate_nans
from .types import AlignedTranscriptionResult, SingleSegment, SingleAlignedSegment, SingleWordSegment
import nltk
from nltk.tokenize.punkt import PunktSentenceTokenizer, PunktParameters

PUNKT_ABBREVIATIONS = ['dr', 'vs', 'mr', 'mrs', 'prof']

LANGUAGES_WITHOUT_SPACES = ["ja", "zh"]

@@ -33,6 +36,7 @@ DEFAULT_ALIGN_MODELS_HF = {
    "uk": "Yehor/wav2vec2-xls-r-300m-uk-with-small-lm",
    "pt": "jonatasgrosman/wav2vec2-large-xlsr-53-portuguese",
    "ar": "jonatasgrosman/wav2vec2-large-xlsr-53-arabic",
    "cs": "comodoro/wav2vec2-xls-r-300m-cs-250",
    "ru": "jonatasgrosman/wav2vec2-large-xlsr-53-russian",
    "pl": "jonatasgrosman/wav2vec2-large-xlsr-53-polish",
    "hu": "jonatasgrosman/wav2vec2-large-xlsr-53-hungarian",
@@ -42,10 +46,26 @@ DEFAULT_ALIGN_MODELS_HF = {
    "tr": "mpoyraz/wav2vec2-xls-r-300m-cv7-turkish",
    "da": "saattrupdan/wav2vec2-xls-r-300m-ftspeech",
    "he": "imvladikon/wav2vec2-xls-r-300m-hebrew",
    "vi": 'nguyenvulebinh/wav2vec2-base-vi',
    "ko": "kresnik/wav2vec2-large-xlsr-korean",
    "ur": "kingabzpro/wav2vec2-large-xls-r-300m-Urdu",
    "te": "anuragshas/wav2vec2-large-xlsr-53-telugu",
    "hi": "theainerd/Wav2Vec2-large-xlsr-hindi",
    "ca": "softcatala/wav2vec2-large-xlsr-catala",
    "ml": "gvs/wav2vec2-large-xlsr-malayalam",
    "no": "NbAiLab/nb-wav2vec2-1b-bokmaal-v2",
    "nn": "NbAiLab/nb-wav2vec2-1b-nynorsk",
    "sk": "comodoro/wav2vec2-xls-r-300m-sk-cv8",
    "sl": "anton-l/wav2vec2-large-xlsr-53-slovenian",
    "hr": "classla/wav2vec2-xls-r-parlaspeech-hr",
    "ro": "gigant/romanian-wav2vec2",
    "eu": "stefan-it/wav2vec2-large-xlsr-53-basque",
    "gl": "ifrz/wav2vec2-large-xlsr-galician",
    "ka": "xsway/wav2vec2-large-xlsr-georgian",
}


def load_align_model(language_code, device, model_name=None, model_dir=None):
def load_align_model(language_code: str, device: str, model_name: Optional[str] = None, model_dir=None):
    if model_name is None:
        # use default model
        if language_code in DEFAULT_ALIGN_MODELS_TORCH:
@@ -65,8 +85,8 @@ def load_align_model(language_code, device, model_name=None, model_dir=None):
        align_dictionary = {c.lower(): i for i, c in enumerate(labels)}
    else:
        try:
            processor = Wav2Vec2Processor.from_pretrained(model_name)
            align_model = Wav2Vec2ForCTC.from_pretrained(model_name)
            processor = Wav2Vec2Processor.from_pretrained(model_name, cache_dir=model_dir)
            align_model = Wav2Vec2ForCTC.from_pretrained(model_name, cache_dir=model_dir)
        except Exception as e:
            print(e)
            print(f"Error loading model from huggingface, check https://huggingface.co/models for finetuned wav2vec2.0 models")
@@ -82,13 +102,15 @@


def align(
    transcript: Iterator[SingleSegment],
    transcript: Iterable[SingleSegment],
    model: torch.nn.Module,
    align_model_metadata: dict,
    audio: Union[str, np.ndarray, torch.Tensor],
    device: str,
    interpolate_method: str = "nearest",
    return_char_alignments: bool = False,
    print_progress: bool = False,
    combined_progress: bool = False,
) -> AlignedTranscriptionResult:
    """
    Align phoneme recognition predictions to known transcription.
@@ -108,8 +130,14 @@
    model_type = align_model_metadata["type"]

    # 1. Preprocess to keep only characters in dictionary
    total_segments = len(transcript)
    for sdx, segment in enumerate(transcript):
        # strip spaces at beginning / end, but keep track of the amount.
        if print_progress:
            base_progress = ((sdx + 1) / total_segments) * 100
            percent_complete = (50 + base_progress / 2) if combined_progress else base_progress
            print(f"Progress: {percent_complete:.2f}%...")

        num_leading = len(segment["text"]) - len(segment["text"].lstrip())
        num_trailing = len(segment["text"]) - len(segment["text"].rstrip())
        text = segment["text"]
@@ -141,7 +169,11 @@
            if any([c in model_dictionary.keys() for c in wrd]):
                clean_wdx.append(wdx)

        sentence_spans = list(nltk.tokenize.punkt.PunktSentenceTokenizer().span_tokenize(text))

        punkt_param = PunktParameters()
        punkt_param.abbrev_types = set(PUNKT_ABBREVIATIONS)
        sentence_splitter = PunktSentenceTokenizer(punkt_param)
        sentence_spans = list(sentence_splitter.span_tokenize(text))

        segment["clean_char"] = clean_char
        segment["clean_cdx"] = clean_cdx
@@ -149,9 +181,10 @@
        segment["sentence_spans"] = sentence_spans

    aligned_segments: List[SingleAlignedSegment] = []


    # 2. Get prediction matrix from alignment model & align
    for sdx, segment in enumerate(transcript):

        t1 = segment["start"]
        t2 = segment["end"]
        text = segment["text"]
@@ -172,8 +205,8 @@
            aligned_segments.append(aligned_seg)
            continue

        if t1 >= MAX_DURATION or t2 - t1 < 0.02:
            print("Failed to align segment: original start time longer than audio duration, skipping...")
        if t1 >= MAX_DURATION:
            print(f'Failed to align segment ("{segment["text"]}"): original start time longer than audio duration, skipping...')
            aligned_segments.append(aligned_seg)
            continue

@@ -185,10 +218,18 @@
        # TODO: Probably can get some speedup gain with batched inference here
        waveform_segment = audio[:, f1:f2]

        # Handle the minimum input length for wav2vec2 models
        if waveform_segment.shape[-1] < 400:
            lengths = torch.as_tensor([waveform_segment.shape[-1]]).to(device)
            waveform_segment = torch.nn.functional.pad(
                waveform_segment, (0, 400 - waveform_segment.shape[-1])
            )
        else:
            lengths = None

        with torch.inference_mode():
            if model_type == "torchaudio":
                emissions, _ = model(waveform_segment.to(device))
                emissions, _ = model(waveform_segment.to(device), lengths=lengths)
            elif model_type == "huggingface":
                emissions = model(waveform_segment.to(device)).logits
            else:
@@ -253,7 +294,8 @@

            sentence_text = text[sstart:send]
            sentence_start = curr_chars["start"].min()
            sentence_end = curr_chars["end"].max()
            end_chars = curr_chars[curr_chars["char"] != ' ']
            sentence_end = end_chars["end"].max()
            sentence_words = []

            for word_idx in curr_chars["word-idx"].unique():
@@ -300,6 +342,8 @@
        aligned_subsegments["end"] = interpolate_nans(aligned_subsegments["end"], method=interpolate_method)
        # concatenate sentences with same timestamps
        agg_dict = {"text": " ".join, "words": "sum"}
        if model_lang in LANGUAGES_WITHOUT_SPACES:
            agg_dict["text"] = "".join
        if return_char_alignments:
            agg_dict["chars"] = "sum"
        aligned_subsegments= aligned_subsegments.groupby(["start", "end"], as_index=False).agg(agg_dict)
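The alignment changes above replace the stock Punkt sentence splitter with one primed with a short abbreviation list (PUNKT_ABBREVIATIONS), so sentence spans are no longer cut after titles such as "Dr.". A standalone sketch of the effect, assuming only that nltk is installed:

```python
from nltk.tokenize.punkt import PunktParameters, PunktSentenceTokenizer

text = "Dr. Smith met Mr. Jones at the lab. They reviewed the results."

# Untrained default tokenizer: the abbreviations are unknown, so it may split after "Dr." and "Mr."
default_spans = list(PunktSentenceTokenizer().span_tokenize(text))

# Tokenizer primed the way whisperx.alignment now does it
params = PunktParameters()
params.abbrev_types = set(["dr", "vs", "mr", "mrs", "prof"])
primed_spans = list(PunktSentenceTokenizer(params).span_tokenize(text))

print(default_spans)  # typically more, shorter spans
print(primed_spans)   # two spans, one per real sentence
```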
329
whisperx/asr.py
329
whisperx/asr.py
@ -1,79 +1,30 @@
|
||||
import os
|
||||
import warnings
|
||||
from typing import List, Union
|
||||
from typing import List, NamedTuple, Optional, Union
|
||||
from dataclasses import replace
|
||||
|
||||
import ctranslate2
|
||||
import faster_whisper
|
||||
import numpy as np
|
||||
import torch
|
||||
from faster_whisper.tokenizer import Tokenizer
|
||||
from faster_whisper.transcribe import TranscriptionOptions, get_ctranslate2_storage
|
||||
from transformers import Pipeline
|
||||
from transformers.pipelines.pt_utils import PipelineIterator
|
||||
|
||||
from .audio import N_SAMPLES, SAMPLE_RATE, load_audio, log_mel_spectrogram
|
||||
from .vad import load_vad_model, merge_chunks
|
||||
from .types import TranscriptionResult, SingleSegment
|
||||
|
||||
def load_model(whisper_arch, device, compute_type="float16", asr_options=None, language=None,
|
||||
vad_options=None, model=None, task="transcribe"):
|
||||
'''Load a Whisper model for inference.
|
||||
Args:
|
||||
whisper_arch: str - The name of the Whisper model to load.
|
||||
device: str - The device to load the model on.
|
||||
compute_type: str - The compute type to use for the model.
|
||||
options: dict - A dictionary of options to use for the model.
|
||||
language: str - The language of the model. (use English for now)
|
||||
Returns:
|
||||
A Whisper pipeline.
|
||||
'''
|
||||
|
||||
if whisper_arch.endswith(".en"):
|
||||
language = "en"
|
||||
|
||||
model = WhisperModel(whisper_arch, device=device, compute_type=compute_type)
|
||||
if language is not None:
|
||||
tokenizer = faster_whisper.tokenizer.Tokenizer(model.hf_tokenizer, model.model.is_multilingual, task=task, language=language)
|
||||
else:
|
||||
print("No language specified, language will be first be detected for each audio file (increases inference time).")
|
||||
tokenizer = None
|
||||
|
||||
default_asr_options = {
|
||||
"beam_size": 5,
|
||||
"best_of": 5,
|
||||
"patience": 1,
|
||||
"length_penalty": 1,
|
||||
"temperatures": [0.0, 0.2, 0.4, 0.6, 0.8, 1.0],
|
||||
"compression_ratio_threshold": 2.4,
|
||||
"log_prob_threshold": -1.0,
|
||||
"no_speech_threshold": 0.6,
|
||||
"condition_on_previous_text": False,
|
||||
"initial_prompt": None,
|
||||
"prefix": None,
|
||||
"suppress_blank": True,
|
||||
"suppress_tokens": [-1],
|
||||
"without_timestamps": True,
|
||||
"max_initial_timestamp": 0.0,
|
||||
"word_timestamps": False,
|
||||
"prepend_punctuations": "\"'“¿([{-",
|
||||
"append_punctuations": "\"'.。,,!!??::”)]}、"
|
||||
}
|
||||
|
||||
if asr_options is not None:
|
||||
default_asr_options.update(asr_options)
|
||||
default_asr_options = faster_whisper.transcribe.TranscriptionOptions(**default_asr_options)
|
||||
|
||||
default_vad_options = {
|
||||
"vad_onset": 0.500,
|
||||
"vad_offset": 0.363
|
||||
}
|
||||
|
||||
if vad_options is not None:
|
||||
default_vad_options.update(vad_options)
|
||||
|
||||
vad_model = load_vad_model(torch.device(device), use_auth_token=None, **default_vad_options)
|
||||
|
||||
return FasterWhisperPipeline(model, vad_model, default_asr_options, tokenizer)
|
||||
from .types import SingleSegment, TranscriptionResult
|
||||
from .vad import VoiceActivitySegmentation, load_vad_model, merge_chunks
|
||||
|
||||
|
||||
def find_numeral_symbol_tokens(tokenizer):
|
||||
numeral_symbol_tokens = []
|
||||
for i in range(tokenizer.eot):
|
||||
token = tokenizer.decode([i]).removeprefix(" ")
|
||||
has_numeral_symbol = any(c in "0123456789%$£" for c in token)
|
||||
if has_numeral_symbol:
|
||||
numeral_symbol_tokens.append(i)
|
||||
return numeral_symbol_tokens
|
||||
|
||||
class WhisperModel(faster_whisper.WhisperModel):
|
||||
'''
|
||||
@ -81,7 +32,13 @@ class WhisperModel(faster_whisper.WhisperModel):
|
||||
Currently only works in non-timestamp mode and fixed prompt for all samples in batch.
|
||||
'''
|
||||
|
||||
def generate_segment_batched(self, features: np.ndarray, tokenizer: faster_whisper.tokenizer.Tokenizer, options: faster_whisper.transcribe.TranscriptionOptions, encoder_output = None):
|
||||
def generate_segment_batched(
|
||||
self,
|
||||
features: np.ndarray,
|
||||
tokenizer: Tokenizer,
|
||||
options: TranscriptionOptions,
|
||||
encoder_output=None,
|
||||
):
|
||||
batch_size = features.shape[0]
|
||||
all_tokens = []
|
||||
prompt_reset_since = 0
|
||||
@ -106,15 +63,14 @@ class WhisperModel(faster_whisper.WhisperModel):
|
||||
result = self.model.generate(
|
||||
encoder_output,
|
||||
[prompt] * batch_size,
|
||||
# length_penalty=options.length_penalty,
|
||||
# max_length=self.max_length,
|
||||
# return_scores=True,
|
||||
# return_no_speech_prob=True,
|
||||
# suppress_blank=options.suppress_blank,
|
||||
# suppress_tokens=options.suppress_tokens,
|
||||
# max_initial_timestamp_index=max_initial_timestamp_index,
|
||||
beam_size=options.beam_size,
|
||||
patience=options.patience,
|
||||
length_penalty=options.length_penalty,
|
||||
max_length=self.max_length,
|
||||
suppress_blank=options.suppress_blank,
|
||||
suppress_tokens=options.suppress_tokens,
|
||||
)
|
||||
|
||||
|
||||
tokens_batch = [x.sequences_ids[0] for x in result]
|
||||
|
||||
def decode_batch(tokens: List[List[int]]) -> str:
|
||||
@ -127,7 +83,7 @@ class WhisperModel(faster_whisper.WhisperModel):
|
||||
text = decode_batch(tokens_batch)
|
||||
|
||||
return text
|
||||
|
||||
|
||||
def encode(self, features: np.ndarray) -> ctranslate2.StorageView:
|
||||
# When the model is running on multiple GPUs, the encoder output should be moved
|
||||
# to the CPU since we don't know which GPU will handle the next job.
|
||||
@ -135,10 +91,10 @@ class WhisperModel(faster_whisper.WhisperModel):
|
||||
# unsqueeze if batch size = 1
|
||||
if len(features.shape) == 2:
|
||||
features = np.expand_dims(features, 0)
|
||||
features = faster_whisper.transcribe.get_ctranslate2_storage(features)
|
||||
|
||||
features = get_ctranslate2_storage(features)
|
||||
|
||||
return self.model.encode(features, to_cpu=to_cpu)
|
||||
|
||||
|
||||
class FasterWhisperPipeline(Pipeline):
|
||||
"""
|
||||
Huggingface Pipeline wrapper for FasterWhisperModel.
|
||||
@ -148,18 +104,23 @@ class FasterWhisperPipeline(Pipeline):
|
||||
# - add support for custom inference kwargs
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model,
|
||||
vad,
|
||||
options,
|
||||
tokenizer=None,
|
||||
device: Union[int, str, "torch.device"] = -1,
|
||||
framework = "pt",
|
||||
**kwargs
|
||||
self,
|
||||
model: WhisperModel,
|
||||
vad: VoiceActivitySegmentation,
|
||||
vad_params: dict,
|
||||
options: TranscriptionOptions,
|
||||
tokenizer: Optional[Tokenizer] = None,
|
||||
device: Union[int, str, "torch.device"] = -1,
|
||||
framework="pt",
|
||||
language: Optional[str] = None,
|
||||
suppress_numerals: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
self.model = model
|
||||
self.tokenizer = tokenizer
|
||||
self.options = options
|
||||
self.preset_language = language
|
||||
self.suppress_numerals = suppress_numerals
|
||||
self._batch_size = kwargs.pop("batch_size", None)
|
||||
self._num_workers = 1
|
||||
self._preprocess_params, self._forward_params, self._postprocess_params = self._sanitize_parameters(**kwargs)
|
||||
@ -176,9 +137,10 @@ class FasterWhisperPipeline(Pipeline):
|
||||
self.device = torch.device(f"cuda:{device}")
|
||||
else:
|
||||
self.device = device
|
||||
|
||||
|
||||
super(Pipeline, self).__init__()
|
||||
self.vad_model = vad
|
||||
self._vad_params = vad_params
|
||||
|
||||
def _sanitize_parameters(self, **kwargs):
|
||||
preprocess_kwargs = {}
|
||||
@ -188,18 +150,29 @@ class FasterWhisperPipeline(Pipeline):
|
||||
|
||||
def preprocess(self, audio):
|
||||
audio = audio['inputs']
|
||||
features = log_mel_spectrogram(audio, padding=N_SAMPLES - audio.shape[0])
|
||||
model_n_mels = self.model.feat_kwargs.get("feature_size")
|
||||
features = log_mel_spectrogram(
|
||||
audio,
|
||||
n_mels=model_n_mels if model_n_mels is not None else 80,
|
||||
padding=N_SAMPLES - audio.shape[0],
|
||||
)
|
||||
return {'inputs': features}
|
||||
|
||||
def _forward(self, model_inputs):
|
||||
outputs = self.model.generate_segment_batched(model_inputs['inputs'], self.tokenizer, self.options)
|
||||
return {'text': outputs}
|
||||
|
||||
|
||||
def postprocess(self, model_outputs):
|
||||
return model_outputs
|
||||
|
||||
def get_iterator(
|
||||
self, inputs, num_workers: int, batch_size: int, preprocess_params, forward_params, postprocess_params
|
||||
self,
|
||||
inputs,
|
||||
num_workers: int,
|
||||
batch_size: int,
|
||||
preprocess_params: dict,
|
||||
forward_params: dict,
|
||||
postprocess_params: dict,
|
||||
):
|
||||
dataset = PipelineIterator(inputs, self.preprocess, preprocess_params)
|
||||
if "TOKENIZERS_PARALLELISM" not in os.environ:
|
||||
@ -214,11 +187,20 @@ class FasterWhisperPipeline(Pipeline):
|
||||
return final_iterator
|
||||
|
||||
def transcribe(
|
||||
self, audio: Union[str, np.ndarray], batch_size=None, num_workers=0
|
||||
self,
|
||||
audio: Union[str, np.ndarray],
|
||||
batch_size: Optional[int] = None,
|
||||
num_workers=0,
|
||||
language: Optional[str] = None,
|
||||
task: Optional[str] = None,
|
||||
chunk_size=30,
|
||||
print_progress=False,
|
||||
combined_progress=False,
|
||||
verbose=False,
|
||||
) -> TranscriptionResult:
|
||||
if isinstance(audio, str):
|
||||
audio = load_audio(audio)
|
||||
|
||||
|
||||
def data(audio, segments):
|
||||
for seg in segments:
|
||||
f1 = int(seg['start'] * SAMPLE_RATE)
|
||||
@ -227,22 +209,53 @@ class FasterWhisperPipeline(Pipeline):
|
||||
yield {'inputs': audio[f1:f2]}
|
||||
|
||||
vad_segments = self.vad_model({"waveform": torch.from_numpy(audio).unsqueeze(0), "sample_rate": SAMPLE_RATE})
|
||||
vad_segments = merge_chunks(vad_segments, 30)
|
||||
|
||||
del_tokenizer = False
|
||||
vad_segments = merge_chunks(
|
||||
vad_segments,
|
||||
chunk_size,
|
||||
onset=self._vad_params["vad_onset"],
|
||||
offset=self._vad_params["vad_offset"],
|
||||
)
|
||||
if self.tokenizer is None:
|
||||
language = self.detect_language(audio)
|
||||
self.tokenizer = faster_whisper.tokenizer.Tokenizer(self.model.hf_tokenizer, self.model.model.is_multilingual, task="transcribe", language=language)
|
||||
del_tokenizer = True
|
||||
language = language or self.detect_language(audio)
|
||||
task = task or "transcribe"
|
||||
self.tokenizer = Tokenizer(
|
||||
self.model.hf_tokenizer,
|
||||
self.model.model.is_multilingual,
|
||||
task=task,
|
||||
language=language,
|
||||
)
|
||||
else:
|
||||
language = self.tokenizer.language_code
|
||||
language = language or self.tokenizer.language_code
|
||||
task = task or self.tokenizer.task
|
||||
if task != self.tokenizer.task or language != self.tokenizer.language_code:
|
||||
self.tokenizer = Tokenizer(
|
||||
self.model.hf_tokenizer,
|
||||
self.model.model.is_multilingual,
|
||||
task=task,
|
||||
language=language,
|
||||
)
|
||||
|
||||
if self.suppress_numerals:
|
||||
previous_suppress_tokens = self.options.suppress_tokens
|
||||
numeral_symbol_tokens = find_numeral_symbol_tokens(self.tokenizer)
|
||||
print(f"Suppressing numeral and symbol tokens")
|
||||
new_suppressed_tokens = numeral_symbol_tokens + self.options.suppress_tokens
|
||||
new_suppressed_tokens = list(set(new_suppressed_tokens))
|
||||
self.options = replace(self.options, suppress_tokens=new_suppressed_tokens)
|
||||
|
||||
segments: List[SingleSegment] = []
|
||||
batch_size = batch_size or self._batch_size
|
||||
total_segments = len(vad_segments)
|
||||
for idx, out in enumerate(self.__call__(data(audio, vad_segments), batch_size=batch_size, num_workers=num_workers)):
|
||||
if print_progress:
|
||||
base_progress = ((idx + 1) / total_segments) * 100
|
||||
percent_complete = base_progress / 2 if combined_progress else base_progress
|
||||
print(f"Progress: {percent_complete:.2f}%...")
|
||||
text = out['text']
|
||||
if batch_size in [0, 1, None]:
|
||||
text = text[0]
|
||||
if verbose:
|
||||
print(f"Transcript: [{round(vad_segments[idx]['start'], 3)} --> {round(vad_segments[idx]['end'], 3)}] {text}")
|
||||
segments.append(
|
||||
{
|
||||
"text": text,
|
||||
@ -250,17 +263,23 @@ class FasterWhisperPipeline(Pipeline):
|
||||
"end": round(vad_segments[idx]['end'], 3)
|
||||
}
|
||||
)
|
||||
|
||||
if del_tokenizer:
|
||||
|
||||
# revert the tokenizer if multilingual inference is enabled
|
||||
if self.preset_language is None:
|
||||
self.tokenizer = None
|
||||
|
||||
# revert suppressed tokens if suppress_numerals is enabled
|
||||
if self.suppress_numerals:
|
||||
self.options = replace(self.options, suppress_tokens=previous_suppress_tokens)
|
||||
|
||||
return {"segments": segments, "language": language}
|
||||
|
||||
|
||||
def detect_language(self, audio: np.ndarray):
|
||||
def detect_language(self, audio: np.ndarray) -> str:
|
||||
if audio.shape[0] < N_SAMPLES:
|
||||
print("Warning: audio is shorter than 30s, language detection may be inaccurate.")
|
||||
model_n_mels = self.model.feat_kwargs.get("feature_size")
|
||||
segment = log_mel_spectrogram(audio[: N_SAMPLES],
|
||||
n_mels=model_n_mels if model_n_mels is not None else 80,
|
||||
padding=0 if audio.shape[0] >= N_SAMPLES else N_SAMPLES - audio.shape[0])
|
||||
encoder_output = self.model.encode(segment)
|
||||
results = self.model.model.detect_language(encoder_output)
|
||||
@ -268,3 +287,111 @@ class FasterWhisperPipeline(Pipeline):
|
||||
language = language_token[2:-2]
|
||||
print(f"Detected language: {language} ({language_probability:.2f}) in first 30s of audio...")
|
||||
return language
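A hedged sketch of running language detection explicitly when no language was passed to `load_model`, assuming `transcribe` accepts the `language` keyword introduced in this diff; model size and file name are placeholders.

import whisperx

model = whisperx.load_model("small", "cpu", compute_type="int8")  # no language given
audio = whisperx.load_audio("example.wav")
language = model.detect_language(audio)  # uses only the first 30 s of audio
result = model.transcribe(audio, batch_size=4, language=language)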
def load_model(
|
||||
whisper_arch: str,
|
||||
device: str,
|
||||
device_index=0,
|
||||
compute_type="float16",
|
||||
asr_options: Optional[dict] = None,
|
||||
language: Optional[str] = None,
|
||||
vad_model: Optional[VoiceActivitySegmentation] = None,
|
||||
vad_options: Optional[dict] = None,
|
||||
model: Optional[WhisperModel] = None,
|
||||
task="transcribe",
|
||||
download_root: Optional[str] = None,
|
||||
local_files_only=False,
|
||||
threads=4,
|
||||
) -> FasterWhisperPipeline:
|
||||
"""Load a Whisper model for inference.
|
||||
Args:
whisper_arch - The name of the Whisper model to load.
device - The device to load the model on.
compute_type - The compute type to use for the model.
asr_options - A dictionary of decoding options to pass to the model.
language - The language of the model; if None, the language is detected per audio file.
model - The WhisperModel instance to use.
download_root - The root directory to download the model to.
local_files_only - If `True`, avoid downloading the file and return the path to the local cached file if it exists.
threads - The number of CPU threads to use per worker; the total is this value multiplied by the number of workers.
Returns:
A Whisper pipeline.
"""
|
||||
|
||||
if whisper_arch.endswith(".en"):
|
||||
language = "en"
|
||||
|
||||
model = model or WhisperModel(whisper_arch,
|
||||
device=device,
|
||||
device_index=device_index,
|
||||
compute_type=compute_type,
|
||||
download_root=download_root,
|
||||
local_files_only=local_files_only,
|
||||
cpu_threads=threads)
|
||||
if language is not None:
|
||||
tokenizer = Tokenizer(model.hf_tokenizer, model.model.is_multilingual, task=task, language=language)
|
||||
else:
|
||||
print("No language specified, language will be first be detected for each audio file (increases inference time).")
|
||||
tokenizer = None
|
||||
|
||||
default_asr_options = {
|
||||
"beam_size": 5,
|
||||
"best_of": 5,
|
||||
"patience": 1,
|
||||
"length_penalty": 1,
|
||||
"repetition_penalty": 1,
|
||||
"no_repeat_ngram_size": 0,
|
||||
"temperatures": [0.0, 0.2, 0.4, 0.6, 0.8, 1.0],
|
||||
"compression_ratio_threshold": 2.4,
|
||||
"log_prob_threshold": -1.0,
|
||||
"no_speech_threshold": 0.6,
|
||||
"condition_on_previous_text": False,
|
||||
"prompt_reset_on_temperature": 0.5,
|
||||
"initial_prompt": None,
|
||||
"prefix": None,
|
||||
"suppress_blank": True,
|
||||
"suppress_tokens": [-1],
|
||||
"without_timestamps": True,
|
||||
"max_initial_timestamp": 0.0,
|
||||
"word_timestamps": False,
|
||||
"prepend_punctuations": "\"'“¿([{-",
|
||||
"append_punctuations": "\"'.。,,!!??::”)]}、",
|
||||
"multilingual": model.model.is_multilingual,
|
||||
"suppress_numerals": False,
|
||||
"max_new_tokens": None,
|
||||
"clip_timestamps": None,
|
||||
"hallucination_silence_threshold": None,
|
||||
"hotwords": None,
|
||||
}
|
||||
|
||||
if asr_options is not None:
|
||||
default_asr_options.update(asr_options)
|
||||
|
||||
suppress_numerals = default_asr_options["suppress_numerals"]
|
||||
del default_asr_options["suppress_numerals"]
|
||||
|
||||
default_asr_options = TranscriptionOptions(**default_asr_options)
|
||||
|
||||
default_vad_options = {
|
||||
"vad_onset": 0.500,
|
||||
"vad_offset": 0.363
|
||||
}
|
||||
|
||||
if vad_options is not None:
|
||||
default_vad_options.update(vad_options)
|
||||
|
||||
if vad_model is not None:
|
||||
vad_model = vad_model
|
||||
else:
|
||||
vad_model = load_vad_model(torch.device(device), use_auth_token=None, **default_vad_options)
|
||||
|
||||
return FasterWhisperPipeline(
|
||||
model=model,
|
||||
vad=vad_model,
|
||||
options=default_asr_options,
|
||||
tokenizer=tokenizer,
|
||||
language=language,
|
||||
suppress_numerals=suppress_numerals,
|
||||
vad_params=default_vad_options,
|
||||
)
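An illustrative call to `load_model` that overrides a few of the defaults listed above; the option values are arbitrary examples, not recommendations, and the module path `whisperx.asr` is assumed from this diff.

from whisperx.asr import load_model

model = load_model(
    "medium",
    device="cuda",
    compute_type="float16",
    language="en",
    asr_options={"beam_size": 10, "suppress_numerals": True},
    vad_options={"vad_onset": 0.4, "vad_offset": 0.3},
    threads=8,
)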
BIN whisperx/assets/pytorch_model.bin (new file; binary not shown)
@ -1,8 +1,8 @@
|
||||
import os
|
||||
import subprocess
|
||||
from functools import lru_cache
|
||||
from typing import Optional, Union
|
||||
|
||||
import ffmpeg
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@ -12,7 +12,6 @@ from .utils import exact_div
|
||||
# hard-coded audio hyperparameters
|
||||
SAMPLE_RATE = 16000
|
||||
N_FFT = 400
|
||||
N_MELS = 80
|
||||
HOP_LENGTH = 160
|
||||
CHUNK_LENGTH = 30
|
||||
N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000 samples in a 30-second chunk
|
||||
@ -23,7 +22,7 @@ FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH) # 10ms per audio frame
|
||||
TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN) # 20ms per audio token
|
||||
|
||||
|
||||
def load_audio(file: str, sr: int = SAMPLE_RATE):
|
||||
def load_audio(file: str, sr: int = SAMPLE_RATE) -> np.ndarray:
|
||||
"""
|
||||
Open an audio file and read as mono waveform, resampling as necessary
|
||||
|
||||
@ -40,14 +39,27 @@ def load_audio(file: str, sr: int = SAMPLE_RATE):
|
||||
A NumPy array containing the audio waveform, in float32 dtype.
|
||||
"""
|
||||
try:
|
||||
# This launches a subprocess to decode audio while down-mixing and resampling as necessary.
|
||||
# Requires the ffmpeg CLI and `ffmpeg-python` package to be installed.
|
||||
out, _ = (
|
||||
ffmpeg.input(file, threads=0)
|
||||
.output("-", format="s16le", acodec="pcm_s16le", ac=1, ar=sr)
|
||||
.run(cmd=["ffmpeg", "-nostdin"], capture_stdout=True, capture_stderr=True)
|
||||
)
|
||||
except ffmpeg.Error as e:
|
||||
# Launches a subprocess to decode audio while down-mixing and resampling as necessary.
|
||||
# Requires the ffmpeg CLI to be installed.
|
||||
cmd = [
|
||||
"ffmpeg",
|
||||
"-nostdin",
|
||||
"-threads",
|
||||
"0",
|
||||
"-i",
|
||||
file,
|
||||
"-f",
|
||||
"s16le",
|
||||
"-ac",
|
||||
"1",
|
||||
"-acodec",
|
||||
"pcm_s16le",
|
||||
"-ar",
|
||||
str(sr),
|
||||
"-",
|
||||
]
|
||||
out = subprocess.run(cmd, capture_output=True, check=True).stdout
|
||||
except subprocess.CalledProcessError as e:
|
||||
raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e
|
||||
|
||||
return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0
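A quick sanity check of the decode path above; assumes the ffmpeg CLI is installed and that "example.wav" exists (both placeholders).

from whisperx.audio import SAMPLE_RATE, load_audio

wav = load_audio("example.wav")
print(wav.dtype, wav.shape, f"{wav.shape[0] / SAMPLE_RATE:.1f}s")  # float32, mono, 16 kHz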
@ -80,7 +92,7 @@ def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def mel_filters(device, n_mels: int = N_MELS) -> torch.Tensor:
|
||||
def mel_filters(device, n_mels: int) -> torch.Tensor:
|
||||
"""
|
||||
load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
|
||||
Allows decoupling librosa dependency; saved using:
|
||||
@ -90,7 +102,7 @@ def mel_filters(device, n_mels: int = N_MELS) -> torch.Tensor:
|
||||
mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
|
||||
)
|
||||
"""
|
||||
assert n_mels == 80, f"Unsupported n_mels: {n_mels}"
|
||||
assert n_mels in [80, 128], f"Unsupported n_mels: {n_mels}"
|
||||
with np.load(
|
||||
os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz")
|
||||
) as f:
|
||||
@ -99,7 +111,7 @@ def mel_filters(device, n_mels: int = N_MELS) -> torch.Tensor:
|
||||
|
||||
def log_mel_spectrogram(
|
||||
audio: Union[str, np.ndarray, torch.Tensor],
|
||||
n_mels: int = N_MELS,
|
||||
n_mels: int,
|
||||
padding: int = 0,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
):
|
||||
|
whisperx/conjunctions.py (new file, 47 lines)
@ -0,0 +1,47 @@
|
||||
# conjunctions.py
|
||||
|
||||
from typing import Set
|
||||
|
||||
|
||||
conjunctions_by_language = {
|
||||
'en': {'and', 'whether', 'or', 'as', 'but', 'so', 'for', 'nor', 'which', 'yet', 'although', 'since', 'unless', 'when', 'while', 'because', 'if', 'how', 'that', 'than', 'who', 'where', 'what', 'near', 'before', 'after', 'across', 'through', 'until', 'once', 'whereas', 'even', 'both', 'either', 'neither', 'though'},
|
||||
'fr': {'et', 'ou', 'mais', 'parce', 'bien', 'pendant', 'quand', 'où', 'comme', 'si', 'que', 'avant', 'après', 'aussitôt', 'jusqu’à', 'à', 'malgré', 'donc', 'tant', 'puisque', 'ni', 'soit', 'bien', 'encore', 'dès', 'lorsque'},
|
||||
'de': {'und', 'oder', 'aber', 'weil', 'obwohl', 'während', 'wenn', 'wo', 'wie', 'dass', 'bevor', 'nachdem', 'sobald', 'bis', 'außer', 'trotzdem', 'also', 'sowie', 'indem', 'weder', 'sowohl', 'zwar', 'jedoch'},
|
||||
'es': {'y', 'o', 'pero', 'porque', 'aunque', 'sin', 'mientras', 'cuando', 'donde', 'como', 'si', 'que', 'antes', 'después', 'tan', 'hasta', 'a', 'a', 'por', 'ya', 'ni', 'sino'},
|
||||
'it': {'e', 'o', 'ma', 'perché', 'anche', 'mentre', 'quando', 'dove', 'come', 'se', 'che', 'prima', 'dopo', 'appena', 'fino', 'a', 'nonostante', 'quindi', 'poiché', 'né', 'ossia', 'cioè'},
|
||||
'ja': {'そして', 'または', 'しかし', 'なぜなら', 'もし', 'それとも', 'だから', 'それに', 'なのに', 'そのため', 'かつ', 'それゆえに', 'ならば', 'もしくは', 'ため'},
|
||||
'zh': {'和', '或', '但是', '因为', '任何', '也', '虽然', '而且', '所以', '如果', '除非', '尽管', '既然', '即使', '只要', '直到', '然后', '因此', '不但', '而是', '不过'},
|
||||
'nl': {'en', 'of', 'maar', 'omdat', 'hoewel', 'terwijl', 'wanneer', 'waar', 'zoals', 'als', 'dat', 'voordat', 'nadat', 'zodra', 'totdat', 'tenzij', 'ondanks', 'dus', 'zowel', 'noch', 'echter', 'toch'},
|
||||
'uk': {'та', 'або', 'але', 'тому', 'хоча', 'поки', 'бо', 'коли', 'де', 'як', 'якщо', 'що', 'перш', 'після', 'доки', 'незважаючи', 'тому', 'ані'},
|
||||
'pt': {'e', 'ou', 'mas', 'porque', 'embora', 'enquanto', 'quando', 'onde', 'como', 'se', 'que', 'antes', 'depois', 'assim', 'até', 'a', 'apesar', 'portanto', 'já', 'pois', 'nem', 'senão'},
|
||||
'ar': {'و', 'أو', 'لكن', 'لأن', 'مع', 'بينما', 'عندما', 'حيث', 'كما', 'إذا', 'الذي', 'قبل', 'بعد', 'فور', 'حتى', 'إلا', 'رغم', 'لذلك', 'بما'},
|
||||
'cs': {'a', 'nebo', 'ale', 'protože', 'ačkoli', 'zatímco', 'když', 'kde', 'jako', 'pokud', 'že', 'než', 'poté', 'jakmile', 'dokud', 'pokud ne', 'navzdory', 'tak', 'stejně', 'ani', 'tudíž'},
|
||||
'ru': {'и', 'или', 'но', 'потому', 'хотя', 'пока', 'когда', 'где', 'как', 'если', 'что', 'перед', 'после', 'несмотря', 'таким', 'также', 'ни', 'зато'},
|
||||
'pl': {'i', 'lub', 'ale', 'ponieważ', 'chociaż', 'podczas', 'kiedy', 'gdzie', 'jak', 'jeśli', 'że', 'zanim', 'po', 'jak tylko', 'dopóki', 'chyba', 'pomimo', 'więc', 'tak', 'ani', 'czyli'},
|
||||
'hu': {'és', 'vagy', 'de', 'mert', 'habár', 'míg', 'amikor', 'ahol', 'ahogy', 'ha', 'hogy', 'mielőtt', 'miután', 'amint', 'amíg', 'hacsak', 'ellenére', 'tehát', 'úgy', 'sem', 'vagyis'},
|
||||
'fi': {'ja', 'tai', 'mutta', 'koska', 'vaikka', 'kun', 'missä', 'kuten', 'jos', 'että', 'ennen', 'sen jälkeen', 'heti', 'kunnes', 'ellei', 'huolimatta', 'siis', 'sekä', 'eikä', 'vaan'},
|
||||
'fa': {'و', 'یا', 'اما', 'چون', 'اگرچه', 'در حالی', 'وقتی', 'کجا', 'چگونه', 'اگر', 'که', 'قبل', 'پس', 'به محض', 'تا زمانی', 'مگر', 'با وجود', 'پس', 'همچنین', 'نه'},
|
||||
'el': {'και', 'ή', 'αλλά', 'επειδή', 'αν', 'ενώ', 'όταν', 'όπου', 'όπως', 'αν', 'που', 'προτού', 'αφού', 'μόλις', 'μέχρι', 'εκτός', 'παρά', 'έτσι', 'όπως', 'ούτε', 'δηλαδή'},
|
||||
'tr': {'ve', 'veya', 'ama', 'çünkü', 'her ne', 'iken', 'nerede', 'nasıl', 'eğer', 'ki', 'önce', 'sonra', 'hemen', 'kadar', 'rağmen', 'hem', 'ne', 'yani'},
|
||||
'da': {'og', 'eller', 'men', 'fordi', 'selvom', 'mens', 'når', 'hvor', 'som', 'hvis', 'at', 'før', 'efter', 'indtil', 'medmindre', 'således', 'ligesom', 'hverken', 'altså'},
|
||||
'he': {'ו', 'או', 'אבל', 'כי', 'אף', 'בזמן', 'כאשר', 'היכן', 'כיצד', 'אם', 'ש', 'לפני', 'אחרי', 'ברגע', 'עד', 'אלא', 'למרות', 'לכן', 'כמו', 'לא', 'אז'},
|
||||
'vi': {'và', 'hoặc', 'nhưng', 'bởi', 'mặc', 'trong', 'khi', 'ở', 'như', 'nếu', 'rằng', 'trước', 'sau', 'ngay', 'cho', 'trừ', 'mặc', 'vì', 'giống', 'cũng', 'tức'},
|
||||
'ko': {'그리고', '또는','그런데','그래도', '이나', '결국', '마지막으로', '마찬가지로', '반면에', '아니면', '거나', '또는', '그럼에도', '그렇기', '때문에', '덧붙이자면', '게다가', '그러나', '고', '그래서', '랑', '한다면', '하지만', '무엇', '왜냐하면', '비록', '동안', '언제', '어디서', '어떻게', '만약', '그', '전에', '후에', '즉시', '까지', '아니라면', '불구하고', '따라서', '같은', '도'},
|
||||
'ur': {'اور', 'یا', 'مگر', 'کیونکہ', 'اگرچہ', 'جبکہ', 'جب', 'کہاں', 'کس طرح', 'اگر', 'کہ', 'سے پہلے', 'کے بعد', 'جیسے ہی', 'تک', 'اگر نہیں تو', 'کے باوجود', 'اس لئے', 'جیسے', 'نہ'},
|
||||
'hi': {'और', 'या', 'पर', 'तो', 'न', 'फिर', 'हालांकि', 'चूंकि', 'अगर', 'कैसे', 'वह', 'से', 'जो', 'जहां', 'क्या', 'नजदीक', 'पहले', 'बाद', 'के', 'पार', 'माध्यम', 'तक', 'एक', 'जबकि', 'यहां', 'तक', 'दोनों', 'या', 'न', 'हालांकि'}
|
||||
|
||||
}
|
||||
|
||||
commas_by_language = {
|
||||
'ja': '、',
|
||||
'zh': ',',
|
||||
'fa': '،',
|
||||
'ur': '،'
|
||||
}
|
||||
|
||||
def get_conjunctions(lang_code: str) -> Set[str]:
|
||||
return conjunctions_by_language.get(lang_code, set())
|
||||
|
||||
|
||||
def get_comma(lang_code: str) -> str:
|
||||
return commas_by_language.get(lang_code, ",")
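A hedged example of the two helpers above, e.g. when deciding where a long segment may be split; the sample sentence is illustrative only.

from whisperx.conjunctions import get_comma, get_conjunctions

words = "je pense donc je suis".split()
conjunctions = get_conjunctions("fr")
print([w for w in words if w in conjunctions])  # ['donc']
print(repr(get_comma("ja")), repr(get_comma("en")))  # '、' ','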
@ -4,10 +4,14 @@ from pyannote.audio import Pipeline
|
||||
from typing import Optional, Union
|
||||
import torch
|
||||
|
||||
from .audio import load_audio, SAMPLE_RATE
|
||||
from .types import TranscriptionResult, AlignedTranscriptionResult
|
||||
|
||||
|
||||
class DiarizationPipeline:
|
||||
def __init__(
|
||||
self,
|
||||
model_name="pyannote/speaker-diarization@2.1",
|
||||
model_name="pyannote/speaker-diarization-3.1",
|
||||
use_auth_token=None,
|
||||
device: Optional[Union[str, torch.device]] = "cpu",
|
||||
):
|
||||
@ -15,16 +19,31 @@ class DiarizationPipeline:
|
||||
device = torch.device(device)
|
||||
self.model = Pipeline.from_pretrained(model_name, use_auth_token=use_auth_token).to(device)
|
||||
|
||||
def __call__(self, audio, min_speakers=None, max_speakers=None):
|
||||
segments = self.model(audio, min_speakers=min_speakers, max_speakers=max_speakers)
|
||||
diarize_df = pd.DataFrame(segments.itertracks(yield_label=True))
|
||||
diarize_df['start'] = diarize_df[0].apply(lambda x: x.start)
|
||||
diarize_df['end'] = diarize_df[0].apply(lambda x: x.end)
|
||||
diarize_df.rename(columns={2: "speaker"}, inplace=True)
|
||||
def __call__(
|
||||
self,
|
||||
audio: Union[str, np.ndarray],
|
||||
num_speakers: Optional[int] = None,
|
||||
min_speakers: Optional[int] = None,
|
||||
max_speakers: Optional[int] = None,
|
||||
):
|
||||
if isinstance(audio, str):
|
||||
audio = load_audio(audio)
|
||||
audio_data = {
|
||||
'waveform': torch.from_numpy(audio[None, :]),
|
||||
'sample_rate': SAMPLE_RATE
|
||||
}
|
||||
segments = self.model(audio_data, num_speakers = num_speakers, min_speakers=min_speakers, max_speakers=max_speakers)
|
||||
diarize_df = pd.DataFrame(segments.itertracks(yield_label=True), columns=['segment', 'label', 'speaker'])
|
||||
diarize_df['start'] = diarize_df['segment'].apply(lambda x: x.start)
|
||||
diarize_df['end'] = diarize_df['segment'].apply(lambda x: x.end)
|
||||
return diarize_df
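A minimal sketch of the diarization flow above combined with a transcription result, assuming the package-level re-exports of `DiarizationPipeline` and `assign_word_speakers`; the Hugging Face token and file name are placeholders.

import whisperx

device = "cuda"
audio = whisperx.load_audio("example.wav")
model = whisperx.load_model("small", device, compute_type="float16")
result = model.transcribe(audio, batch_size=8)

diarize_model = whisperx.DiarizationPipeline(use_auth_token="hf_xxx", device=device)
diarize_df = diarize_model(audio, min_speakers=1, max_speakers=3)
result = whisperx.assign_word_speakers(diarize_df, result)
for seg in result["segments"]:
    print(seg.get("speaker", "UNKNOWN"), seg["text"].strip())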
def assign_word_speakers(diarize_df, transcript_result, fill_nearest=False):
|
||||
def assign_word_speakers(
|
||||
diarize_df: pd.DataFrame,
|
||||
transcript_result: Union[AlignedTranscriptionResult, TranscriptionResult],
|
||||
fill_nearest=False,
|
||||
) -> dict:
|
||||
transcript_segments = transcript_result["segments"]
|
||||
for seg in transcript_segments:
|
||||
# assign speaker to segment (if any)
|
||||
|
@ -10,8 +10,15 @@ from .alignment import align, load_align_model
|
||||
from .asr import load_model
|
||||
from .audio import load_audio
|
||||
from .diarize import DiarizationPipeline, assign_word_speakers
|
||||
from .utils import (LANGUAGES, TO_LANGUAGE_CODE, get_writer, optional_float,
|
||||
optional_int, str2bool)
|
||||
from .types import AlignedTranscriptionResult, TranscriptionResult
|
||||
from .utils import (
|
||||
LANGUAGES,
|
||||
TO_LANGUAGE_CODE,
|
||||
get_writer,
|
||||
optional_float,
|
||||
optional_int,
|
||||
str2bool,
|
||||
)
|
||||
|
||||
|
||||
def cli():
|
||||
@ -21,11 +28,12 @@ def cli():
|
||||
parser.add_argument("--model", default="small", help="name of the Whisper model to use")
|
||||
parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
|
||||
parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
|
||||
parser.add_argument("--batch_size", default=8, type=int, help="device to use for PyTorch inference")
|
||||
parser.add_argument("--device_index", default=0, type=int, help="device index to use for FasterWhisper inference")
|
||||
parser.add_argument("--batch_size", default=8, type=int, help="the preferred batch size for inference")
|
||||
parser.add_argument("--compute_type", default="float16", type=str, choices=["float16", "float32", "int8"], help="compute type for computation")
|
||||
|
||||
parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
|
||||
parser.add_argument("--output_format", "-f", type=str, default="all", choices=["all", "srt", "vtt", "txt", "tsv", "json"], help="format of the output file; if not specified, all available formats will be produced")
|
||||
parser.add_argument("--output_format", "-f", type=str, default="all", choices=["all", "srt", "vtt", "txt", "tsv", "json", "aud"], help="format of the output file; if not specified, all available formats will be produced")
|
||||
parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")
|
||||
|
||||
parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')")
|
||||
@ -40,6 +48,7 @@ def cli():
|
||||
# vad params
|
||||
parser.add_argument("--vad_onset", type=float, default=0.500, help="Onset threshold for VAD (see pyannote.audio), reduce this if speech is not being detected")
|
||||
parser.add_argument("--vad_offset", type=float, default=0.363, help="Offset threshold for VAD (see pyannote.audio), reduce this if speech is not being detected.")
|
||||
parser.add_argument("--chunk_size", type=int, default=30, help="Chunk size for merging VAD segments. Default is 30, reduce this if the chunk is too long.")
|
||||
|
||||
# diarization params
|
||||
parser.add_argument("--diarize", action="store_true", help="Apply diarization to assign speaker labels to each segment/word")
|
||||
@ -49,10 +58,12 @@ def cli():
|
||||
parser.add_argument("--temperature", type=float, default=0, help="temperature to use for sampling")
|
||||
parser.add_argument("--best_of", type=optional_int, default=5, help="number of candidates when sampling with non-zero temperature")
|
||||
parser.add_argument("--beam_size", type=optional_int, default=5, help="number of beams in beam search, only applicable when temperature is zero")
|
||||
parser.add_argument("--patience", type=float, default=None, help="optional patience value to use in beam decoding, as in https://arxiv.org/abs/2204.05424, the default (1.0) is equivalent to conventional beam search")
|
||||
parser.add_argument("--length_penalty", type=float, default=None, help="optional token length penalty coefficient (alpha) as in https://arxiv.org/abs/1609.08144, uses simple length normalization by default")
|
||||
parser.add_argument("--patience", type=float, default=1.0, help="optional patience value to use in beam decoding, as in https://arxiv.org/abs/2204.05424, the default (1.0) is equivalent to conventional beam search")
|
||||
parser.add_argument("--length_penalty", type=float, default=1.0, help="optional token length penalty coefficient (alpha) as in https://arxiv.org/abs/1609.08144, uses simple length normalization by default")
|
||||
|
||||
parser.add_argument("--suppress_tokens", type=str, default="-1", help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations")
|
||||
parser.add_argument("--suppress_numerals", action="store_true", help="whether to suppress numeric symbols and currency symbols during sampling, since wav2vec2 cannot align them correctly")
|
||||
|
||||
parser.add_argument("--initial_prompt", type=str, default=None, help="optional text to provide as a prompt for the first window.")
|
||||
parser.add_argument("--condition_on_previous_text", type=str2bool, default=False, help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop")
|
||||
parser.add_argument("--fp16", type=str2bool, default=True, help="whether to perform inference in fp16; True by default")
|
||||
@ -63,22 +74,27 @@ def cli():
|
||||
parser.add_argument("--no_speech_threshold", type=optional_float, default=0.6, help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence")
|
||||
|
||||
parser.add_argument("--max_line_width", type=optional_int, default=None, help="(not possible with --no_align) the maximum number of characters in a line before breaking the line")
|
||||
parser.add_argument("--max_line_count", type=optional_int, default=None, help="(requires --no_align) the maximum number of lines in a segment")
|
||||
parser.add_argument("--highlight_words", type=str2bool, default=False, help="(requires --word_timestamps True) underline each word as it is spoken in srt and vtt")
|
||||
parser.add_argument("--max_line_count", type=optional_int, default=None, help="(not possible with --no_align) the maximum number of lines in a segment")
|
||||
parser.add_argument("--highlight_words", type=str2bool, default=False, help="(not possible with --no_align) underline each word as it is spoken in srt and vtt")
|
||||
parser.add_argument("--segment_resolution", type=str, default="sentence", choices=["sentence", "chunk"], help="(not possible with --no_align) the maximum number of characters in a line before breaking the line")
|
||||
|
||||
parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
|
||||
|
||||
parser.add_argument("--hf_token", type=str, default=None, help="Hugging Face Access Token to access PyAnnote gated models")
|
||||
|
||||
parser.add_argument("--print_progress", type=str2bool, default = False, help = "if True, progress will be printed in transcribe() and align() methods.")
|
||||
# fmt: on
|
||||
|
||||
args = parser.parse_args().__dict__
|
||||
model_name: str = args.pop("model")
|
||||
batch_size: int = args.pop("batch_size")
|
||||
model_dir: str = args.pop("model_dir")
|
||||
output_dir: str = args.pop("output_dir")
|
||||
output_format: str = args.pop("output_format")
|
||||
device: str = args.pop("device")
|
||||
device_index: int = args.pop("device_index")
|
||||
compute_type: str = args.pop("compute_type")
|
||||
verbose: bool = args.pop("verbose")
|
||||
|
||||
# model_flush: bool = args.pop("model_flush")
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
@ -86,7 +102,7 @@ def cli():
|
||||
align_model: str = args.pop("align_model")
|
||||
interpolate_method: str = args.pop("interpolate_method")
|
||||
no_align: bool = args.pop("no_align")
|
||||
task : str = args.pop("task")
|
||||
task: str = args.pop("task")
|
||||
if task == "translate":
|
||||
# translation cannot be aligned
|
||||
no_align = True
|
||||
@ -97,17 +113,28 @@ def cli():
|
||||
vad_onset: float = args.pop("vad_onset")
|
||||
vad_offset: float = args.pop("vad_offset")
|
||||
|
||||
chunk_size: int = args.pop("chunk_size")
|
||||
|
||||
diarize: bool = args.pop("diarize")
|
||||
min_speakers: int = args.pop("min_speakers")
|
||||
max_speakers: int = args.pop("max_speakers")
|
||||
print_progress: bool = args.pop("print_progress")
|
||||
|
||||
if args["language"] is not None:
|
||||
args["language"] = args["language"].lower()
|
||||
if args["language"] not in LANGUAGES:
|
||||
if args["language"] in TO_LANGUAGE_CODE:
|
||||
args["language"] = TO_LANGUAGE_CODE[args["language"]]
|
||||
else:
|
||||
raise ValueError(f"Unsupported language: {args['language']}")
|
||||
|
||||
if model_name.endswith(".en") and args["language"] not in {"en", "English"}:
|
||||
if model_name.endswith(".en") and args["language"] != "en":
|
||||
if args["language"] is not None:
|
||||
warnings.warn(
|
||||
f"{model_name} is an English-only model but receipted '{args['language']}'; using English instead."
|
||||
f"{model_name} is an English-only model but received '{args['language']}'; using English instead."
|
||||
)
|
||||
args["language"] = "en"
|
||||
align_language = args["language"] if args["language"] is not None else "en" # default to loading english if not specified
|
||||
|
||||
temperature = args.pop("temperature")
|
||||
if (increment := args.pop("temperature_increment_on_fallback")) is not None:
|
||||
@ -115,8 +142,10 @@ def cli():
|
||||
else:
|
||||
temperature = [temperature]
|
||||
|
||||
faster_whisper_threads = 4
|
||||
if (threads := args.pop("threads")) > 0:
|
||||
torch.set_num_threads(threads)
|
||||
faster_whisper_threads = threads
|
||||
|
||||
asr_options = {
|
||||
"beam_size": args.pop("beam_size"),
|
||||
@ -128,6 +157,8 @@ def cli():
|
||||
"no_speech_threshold": args.pop("no_speech_threshold"),
|
||||
"condition_on_previous_text": False,
|
||||
"initial_prompt": args.pop("initial_prompt"),
|
||||
"suppress_tokens": [int(x) for x in args.pop("suppress_tokens").split(",")],
|
||||
"suppress_numerals": args.pop("suppress_numerals"),
|
||||
}
|
||||
|
||||
writer = get_writer(output_format, output_dir)
|
||||
@ -135,7 +166,7 @@ def cli():
|
||||
if no_align:
|
||||
for option in word_options:
|
||||
if args[option]:
|
||||
parser.error(f"--{option} requires --word_timestamps True")
|
||||
parser.error(f"--{option} not possible with --no_align")
|
||||
if args["max_line_count"] and not args["max_line_width"]:
|
||||
warnings.warn("--max_line_count has no effect without --max_line_width")
|
||||
writer_args = {arg: args.pop(arg) for arg in word_options}
|
||||
@ -144,13 +175,19 @@ def cli():
|
||||
results = []
|
||||
tmp_results = []
|
||||
# model = load_model(model_name, device=device, download_root=model_dir)
|
||||
model = load_model(model_name, device=device, compute_type=compute_type, language=args['language'], asr_options=asr_options, vad_options={"vad_onset": vad_onset, "vad_offset": vad_offset}, task=task)
|
||||
model = load_model(model_name, device=device, device_index=device_index, download_root=model_dir, compute_type=compute_type, language=args['language'], asr_options=asr_options, vad_options={"vad_onset": vad_onset, "vad_offset": vad_offset}, task=task, threads=faster_whisper_threads)
|
||||
|
||||
for audio_path in args.pop("audio"):
|
||||
audio = load_audio(audio_path)
|
||||
# >> VAD & ASR
|
||||
print(">>Performing transcription...")
|
||||
result = model.transcribe(audio, batch_size=batch_size)
|
||||
result: TranscriptionResult = model.transcribe(
|
||||
audio,
|
||||
batch_size=batch_size,
|
||||
chunk_size=chunk_size,
|
||||
print_progress=print_progress,
|
||||
verbose=verbose,
|
||||
)
|
||||
results.append((result, audio_path))
|
||||
|
||||
# Unload Whisper and VAD
|
||||
@ -162,7 +199,6 @@ def cli():
|
||||
if not no_align:
|
||||
tmp_results = results
|
||||
results = []
|
||||
align_language = args["language"] if args["language"] is not None else "en" # default to loading english if not specified
|
||||
align_model, align_metadata = load_align_model(align_language, device, model_name=align_model)
|
||||
for result, audio_path in tmp_results:
|
||||
# >> Align
|
||||
@ -178,7 +214,16 @@ def cli():
|
||||
print(f"New language found ({result['language']})! Previous was ({align_metadata['language']}), loading new alignment model for new language...")
|
||||
align_model, align_metadata = load_align_model(result["language"], device)
|
||||
print(">>Performing alignment...")
|
||||
result = align(result["segments"], align_model, align_metadata, input_audio, device, interpolate_method=interpolate_method, return_char_alignments=return_char_alignments)
|
||||
result: AlignedTranscriptionResult = align(
|
||||
result["segments"],
|
||||
align_model,
|
||||
align_metadata,
|
||||
input_audio,
|
||||
device,
|
||||
interpolate_method=interpolate_method,
|
||||
return_char_alignments=return_char_alignments,
|
||||
print_progress=print_progress,
|
||||
)
|
||||
|
||||
results.append((result, audio_path))
|
||||
|
||||
@ -201,7 +246,8 @@ def cli():
|
||||
results.append((result, input_audio_path))
|
||||
# >> Write
|
||||
for result, audio_path in results:
|
||||
result["language"] = align_language
|
||||
writer(result, audio_path, writer_args)
|
||||
|
||||
if __name__ == "__main__":
|
||||
cli()
|
||||
cli()
|
||||
|
@ -1,4 +1,4 @@
|
||||
from typing import TypedDict, Optional
|
||||
from typing import TypedDict, Optional, List
|
||||
|
||||
|
||||
class SingleWordSegment(TypedDict):
|
||||
@ -38,15 +38,15 @@ class SingleAlignedSegment(TypedDict):
|
||||
start: float
|
||||
end: float
|
||||
text: str
|
||||
words: list[SingleWordSegment]
|
||||
chars: Optional[list[SingleCharSegment]]
|
||||
words: List[SingleWordSegment]
|
||||
chars: Optional[List[SingleCharSegment]]
|
||||
|
||||
|
||||
class TranscriptionResult(TypedDict):
|
||||
"""
|
||||
A list of segments and word segments of a speech.
|
||||
"""
|
||||
segments: list[SingleSegment]
|
||||
segments: List[SingleSegment]
|
||||
language: str
|
||||
|
||||
|
||||
@ -54,5 +54,5 @@ class AlignedTranscriptionResult(TypedDict):
|
||||
"""
|
||||
A list of segments and word segments of a speech.
|
||||
"""
|
||||
segments: list[SingleAlignedSegment]
|
||||
word_segments: list[SingleWordSegment]
|
||||
segments: List[SingleAlignedSegment]
|
||||
word_segments: List[SingleWordSegment]
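These TypedDicts are plain dicts at runtime; a small hedged illustration of constructing a `TranscriptionResult` by hand (values are made up).

from whisperx.types import SingleSegment, TranscriptionResult

segment: SingleSegment = {"start": 0.0, "end": 2.5, "text": "hello world"}
result: TranscriptionResult = {"segments": [segment], "language": "en"}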
@ -105,6 +105,7 @@ LANGUAGES = {
|
||||
"ba": "bashkir",
|
||||
"jw": "javanese",
|
||||
"su": "sundanese",
|
||||
"yue": "cantonese",
|
||||
}
|
||||
|
||||
# language code lookup by name, with a few language aliases
|
||||
@ -123,6 +124,7 @@ TO_LANGUAGE_CODE = {
|
||||
"castilian": "es",
|
||||
}
|
||||
|
||||
LANGUAGES_WITHOUT_SPACES = ["ja", "zh"]
|
||||
|
||||
system_encoding = sys.getdefaultencoding()
|
||||
|
||||
@ -212,7 +214,12 @@ class WriteTXT(ResultWriter):
|
||||
|
||||
def write_result(self, result: dict, file: TextIO, options: dict):
|
||||
for segment in result["segments"]:
|
||||
print(segment["text"].strip(), file=file, flush=True)
|
||||
speaker = segment.get("speaker")
|
||||
text = segment["text"].strip()
|
||||
if speaker is not None:
|
||||
print(f"[{speaker}]: {text}", file=file, flush=True)
|
||||
else:
|
||||
print(text, file=file, flush=True)
|
||||
|
||||
|
||||
class SubtitlesWriter(ResultWriter):
|
||||
@ -226,6 +233,9 @@ class SubtitlesWriter(ResultWriter):
|
||||
max_line_width = 1000 if raw_max_line_width is None else raw_max_line_width
|
||||
preserve_segments = max_line_count is None or raw_max_line_width is None
|
||||
|
||||
if len(result["segments"]) == 0:
|
||||
return
|
||||
|
||||
def iterate_subtitles():
|
||||
line_len = 0
|
||||
line_count = 1
|
||||
@ -277,7 +287,10 @@ class SubtitlesWriter(ResultWriter):
|
||||
sstart, ssend, speaker = _[0]
|
||||
subtitle_start = self.format_timestamp(sstart)
|
||||
subtitle_end = self.format_timestamp(ssend)
|
||||
subtitle_text = " ".join([word["word"] for word in subtitle])
|
||||
if result["language"] in LANGUAGES_WITHOUT_SPACES:
|
||||
subtitle_text = "".join([word["word"] for word in subtitle])
|
||||
else:
|
||||
subtitle_text = " ".join([word["word"] for word in subtitle])
|
||||
has_timing = any(["start" in word for word in subtitle])
|
||||
|
||||
# add [$SPEAKER_ID]: to each subtitle if speaker is available
|
||||
@ -293,7 +306,7 @@ class SubtitlesWriter(ResultWriter):
|
||||
start = self.format_timestamp(this_word["start"])
|
||||
end = self.format_timestamp(this_word["end"])
|
||||
if last != start:
|
||||
yield last, start, subtitle_text
|
||||
yield last, start, prefix + subtitle_text
|
||||
|
||||
yield start, end, prefix + " ".join(
|
||||
[
|
||||
@ -365,12 +378,34 @@ class WriteTSV(ResultWriter):
|
||||
print(round(1000 * segment["end"]), file=file, end="\t")
|
||||
print(segment["text"].strip().replace("\t", " "), file=file, flush=True)
|
||||
|
||||
class WriteAudacity(ResultWriter):
|
||||
"""
Write a transcript to a text file that Audacity can import as labels.
The extension used is "aud" to distinguish it from the txt file produced by WriteTXT.
This is not an Audacity project, only a label file.

Note that Audacity expects timestamps in seconds (not milliseconds) and no header line.

If a speaker is provided, it is prepended to the text between double square brackets [[]].
"""
|
||||
|
||||
extension: str = "aud"
|
||||
|
||||
def write_result(self, result: dict, file: TextIO, options: dict):
|
||||
ARROW = "\t"
|
||||
for segment in result["segments"]:
|
||||
print(segment["start"], file=file, end=ARROW)
|
||||
print(segment["end"], file=file, end=ARROW)
|
||||
print( ( ("[[" + segment["speaker"] + "]]") if "speaker" in segment else "") + segment["text"].strip().replace("\t", " "), file=file, flush=True)
class WriteJSON(ResultWriter):
|
||||
extension: str = "json"
|
||||
|
||||
def write_result(self, result: dict, file: TextIO, options: dict):
|
||||
json.dump(result, file)
|
||||
json.dump(result, file, ensure_ascii=False)
|
||||
|
||||
|
||||
def get_writer(
|
||||
@ -383,6 +418,9 @@ def get_writer(
|
||||
"tsv": WriteTSV,
|
||||
"json": WriteJSON,
|
||||
}
|
||||
optional_writers = {
|
||||
"aud": WriteAudacity,
|
||||
}
|
||||
|
||||
if output_format == "all":
|
||||
all_writers = [writer(output_dir) for writer in writers.values()]
|
||||
@ -393,10 +431,12 @@ def get_writer(
|
||||
|
||||
return write_all
|
||||
|
||||
if output_format in optional_writers:
|
||||
return optional_writers[output_format](output_dir)
|
||||
return writers[output_format](output_dir)
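Hedged usage of `get_writer` with a toy result dict; passing "aud" instead of "txt" would select the optional Audacity writer registered above. Paths and values are placeholders.

from whisperx.utils import get_writer

result = {"segments": [{"start": 0.0, "end": 2.0, "text": "Hello world."}], "language": "en"}
writer = get_writer("txt", output_dir=".")
writer(result, "example.wav", {"highlight_words": False, "max_line_count": None, "max_line_width": None})
# writes ./example.txt with one line per segment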
def interpolate_nans(x, method='nearest'):
|
||||
if x.notnull().sum() > 1:
|
||||
return x.interpolate(method=method).ffill().bfill()
|
||||
else:
|
||||
return x.ffill().bfill()
|
||||
return x.ffill().bfill()
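`interpolate_nans` is used to fill missing word timestamps after alignment; a hedged, self-contained illustration (requires pandas, and scipy for the default 'nearest' method).

import pandas as pd
from whisperx.utils import interpolate_nans

starts = pd.Series([0.00, None, None, 3.20])
print(interpolate_nans(starts).tolist())  # gaps filled from the nearest known timestamps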
@ -15,37 +15,33 @@ from tqdm import tqdm
|
||||
|
||||
from .diarize import Segment as SegmentX
|
||||
|
||||
# deprecated
|
||||
VAD_SEGMENTATION_URL = "https://whisperx.s3.eu-west-2.amazonaws.com/model_weights/segmentation/0b5b3216d60a2d32fc086b47ea8c67589aaeb26b7e07fcbe620d6d0b83e209ea/pytorch_model.bin"
|
||||
|
||||
def load_vad_model(device, vad_onset=0.500, vad_offset=0.363, use_auth_token=None, model_fp=None):
|
||||
model_dir = torch.hub._get_torch_home()
|
||||
|
||||
vad_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
os.makedirs(model_dir, exist_ok = True)
|
||||
if model_fp is None:
|
||||
model_fp = os.path.join(model_dir, "whisperx-vad-segmentation.bin")
|
||||
# Dynamically resolve the path to the model file
|
||||
model_fp = os.path.join(vad_dir, "assets", "pytorch_model.bin")
|
||||
model_fp = os.path.abspath(model_fp) # Ensure the path is absolute
|
||||
else:
|
||||
model_fp = os.path.abspath(model_fp) # Ensure any provided path is absolute
|
||||
|
||||
# Check if the resolved model file exists
|
||||
if not os.path.exists(model_fp):
|
||||
raise FileNotFoundError(f"Model file not found at {model_fp}")
|
||||
|
||||
if os.path.exists(model_fp) and not os.path.isfile(model_fp):
|
||||
raise RuntimeError(f"{model_fp} exists and is not a regular file")
|
||||
|
||||
if not os.path.isfile(model_fp):
|
||||
with urllib.request.urlopen(VAD_SEGMENTATION_URL) as source, open(model_fp, "wb") as output:
|
||||
with tqdm(
|
||||
total=int(source.info().get("Content-Length")),
|
||||
ncols=80,
|
||||
unit="iB",
|
||||
unit_scale=True,
|
||||
unit_divisor=1024,
|
||||
) as loop:
|
||||
while True:
|
||||
buffer = source.read(8192)
|
||||
if not buffer:
|
||||
break
|
||||
|
||||
output.write(buffer)
|
||||
loop.update(len(buffer))
|
||||
|
||||
model_bytes = open(model_fp, "rb").read()
|
||||
if hashlib.sha256(model_bytes).hexdigest() != VAD_SEGMENTATION_URL.split('/')[-2]:
|
||||
raise RuntimeError(
|
||||
"Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model."
|
||||
"Model has been downloaded but the SHA256 checksum does not match. Please retry loading the model."
|
||||
)
|
||||
|
||||
vad_model = Model.from_pretrained(model_fp, use_auth_token=use_auth_token)
|
||||
@ -142,13 +138,12 @@ class Binarize:
|
||||
is_active = k_scores[0] > self.onset
|
||||
curr_scores = [k_scores[0]]
|
||||
curr_timestamps = [start]
|
||||
t = start
|
||||
for t, y in zip(timestamps[1:], k_scores[1:]):
|
||||
# currently active
|
||||
if is_active:
|
||||
curr_duration = t - start
|
||||
if curr_duration > self.max_duration:
|
||||
# if curr_duration > 15:
|
||||
# import pdb; pdb.set_trace()
|
||||
search_after = len(curr_scores) // 2
|
||||
# divide segment
|
||||
min_score_div_idx = search_after + np.argmin(curr_scores[search_after:])
|
||||
@ -166,14 +161,14 @@ class Binarize:
|
||||
is_active = False
|
||||
curr_scores = []
|
||||
curr_timestamps = []
|
||||
curr_scores.append(y)
|
||||
curr_timestamps.append(t)
|
||||
# currently inactive
|
||||
else:
|
||||
# switching from inactive to active
|
||||
if y > self.onset:
|
||||
start = t
|
||||
is_active = True
|
||||
curr_scores.append(y)
|
||||
curr_timestamps.append(t)
|
||||
|
||||
# if active at the end, add final region
|
||||
if is_active:
|
||||
@ -262,7 +257,12 @@ def merge_vad(vad_arr, pad_onset=0.0, pad_offset=0.0, min_duration_off=0.0, min_
|
||||
active_segs = pd.DataFrame([x['segment'] for x in active['content']])
|
||||
return active_segs
|
||||
|
||||
def merge_chunks(segments, chunk_size):
|
||||
def merge_chunks(
|
||||
segments,
|
||||
chunk_size,
|
||||
onset: float = 0.5,
|
||||
offset: Optional[float] = None,
|
||||
):
|
||||
"""
|
||||
Merge operation described in paper
|
||||
"""
|
||||
@ -272,7 +272,7 @@ def merge_chunks(segments, chunk_size):
|
||||
speaker_idxs = []
|
||||
|
||||
assert chunk_size > 0
|
||||
binarize = Binarize(max_duration=chunk_size)
|
||||
binarize = Binarize(max_duration=chunk_size, onset=onset, offset=offset)
|
||||
segments = binarize(segments)
|
||||
segments_list = []
|
||||
for speech_turn in segments.get_timeline():
|
||||
|