mirror of
https://github.com/m-bain/whisperX.git
synced 2025-07-01 18:17:27 -04:00
Compare commits
322 Commits
v3.0.2
...
429658d4cc
Author | SHA1 | Date | |
---|---|---|---|
429658d4cc | |||
e0833da5dc | |||
ffedc5cdf0 | |||
b93e9b6f57 | |||
844736e4e4 | |||
220fec9aea | |||
1631c3040f | |||
d700b56c9c | |||
b343241253 | |||
6fe0a8784a | |||
5012650d0f | |||
108bd0c400 | |||
b2d50a027b | |||
36d552cad3 | |||
7d36b832f9 | |||
d2a493e910 | |||
f5b40b5366 | |||
ac0c8bd79a | |||
cd59f21d1a | |||
0aed874589 | |||
f10dbf6ab1 | |||
a7564c2ad6 | |||
e7712f496e | |||
8e53866704 | |||
3205436d58 | |||
8c58c54635 | |||
0d9807adc5 | |||
4db839018c | |||
f8d11df727 | |||
d2f0e53f71 | |||
7489ebf876 | |||
90256cc481 | |||
b41ebd4871 | |||
63bc1903c1 | |||
272714e07d | |||
44e8bf5bb6 | |||
7b3c9ce629 | |||
36d2622e27 | |||
8bfa12193b | |||
acbeba6057 | |||
fca563a782 | |||
2117909bf6 | |||
de0d8fe313 | |||
355f8e06f7 | |||
86e2b3ee74 | |||
70c639cdb5 | |||
235536e28d | |||
12604a48ea | |||
ffbc73664c | |||
289eadfc76 | |||
22a93f2932 | |||
1027367b79 | |||
5e54b872a9 | |||
6be02cccfa | |||
2f93e029c7 | |||
024bc8481b | |||
f286e7f3de | |||
73e644559d | |||
1ec527375a | |||
6695426a85 | |||
7a98456321 | |||
aaddb83aa5 | |||
c288f4812a | |||
4ebfb078c5 | |||
65b2332e13 | |||
69281f3a29 | |||
734084cdf6 | |||
9395b0de18 | |||
d57f9dc54c | |||
a90bd1ce3f | |||
79eb8fa53d | |||
10b05fc43f | |||
26d9b46888 | |||
9a8967f27e | |||
0f7f9f9f83 | |||
c60594fa3b | |||
4916192246 | |||
cbdac53e87 | |||
940a223219 | |||
a0eb31019b | |||
b08ad67a72 | |||
c18f9f979b | |||
948b3e368b | |||
e9ac5b63bc | |||
90b45459d9 | |||
81c4af96a6 | |||
1c6d9327bc | |||
0fdb55d317 | |||
51da22771f | |||
15ad5bf7df | |||
7fdbd21fe3 | |||
3ff625c561 | |||
7307306a9d | |||
3027cc32bc | |||
9e4b1b4c49 | |||
9b9e03c4cc | |||
19eff8e79a | |||
6f3bc5b7b8 | |||
9809336db6 | |||
a898b3ba94 | |||
c141074cbd | |||
a9e50ef0af | |||
161ae1f7ad | |||
a83ddbdf9b | |||
9e3a9e0e38 | |||
3f339f9515 | |||
9a9b6171e6 | |||
59b4d88d1d | |||
6f70aa6beb | |||
912920c591 | |||
58f00339af | |||
f2da2f858e | |||
78dcfaab51 | |||
d6562c26da | |||
c313f4dd5c | |||
bbaa2f0d1a | |||
e906be9688 | |||
fbbd07bece | |||
d8c9196346 | |||
2686f74bc9 | |||
8227807fa9 | |||
59962a70be | |||
06e30b2a25 | |||
6bb2f1cd48 | |||
f8cc46c6f7 | |||
942c336b8f | |||
8ae6416594 | |||
8540ff5985 | |||
5dfbfcbdc0 | |||
1c7b1a87da | |||
9f23739f90 | |||
19ab91c5a6 | |||
089cd5ab21 | |||
2b7ab95ad6 | |||
4553e0d4ed | |||
f865dfe710 | |||
4acbdd75be | |||
e9c507ce5d | |||
a5dca2cc65 | |||
8a8eeb33ee | |||
b4d7b1a422 | |||
5a16e59217 | |||
b4e4143e3b | |||
4b05198eed | |||
71a5281bde | |||
d97cdb7bcf | |||
20161935a1 | |||
1d7f8ccbf1 | |||
5756b0fb13 | |||
aaaa3de810 | |||
ba30365344 | |||
bd3aa03b6f | |||
f5c544ff90 | |||
7c2a9a8b7b | |||
9f41c49fe5 | |||
48d651e5ea | |||
4ece2369d7 | |||
52fbe5c26f | |||
6703d2774b | |||
a2af569838 | |||
0c7f32f55c | |||
6936dd6991 | |||
6b1100a919 | |||
d4a600b568 | |||
afd5ef1d58 | |||
dbeb8617f2 | |||
c6fe379d9e | |||
e9a6385d3c | |||
b522133340 | |||
49e0130e4e | |||
d4ac9531d9 | |||
66808f6147 | |||
b69956d725 | |||
a150df4310 | |||
02c0323777 | |||
14a7cab8eb | |||
acf31b754f | |||
4cdce3b927 | |||
a5356509b6 | |||
1001a055db | |||
051047bb25 | |||
c1b821a08d | |||
78e20a16a8 | |||
be07c13f75 | |||
8049dba2f7 | |||
d077abdbdf | |||
84423ca517 | |||
a22b8b009b | |||
79801167ac | |||
07fafa37b3 | |||
a0b6459c8b | |||
2a11ce3ef0 | |||
18abcf46ee | |||
652aa24919 | |||
b17908473d | |||
f137f31de6 | |||
e94b904308 | |||
ffd6167b26 | |||
4c7ce14fed | |||
0ae0d49d1d | |||
b1a98b78c9 | |||
c6d9e6cb67 | |||
31f5233949 | |||
2ca99ce909 | |||
15d9e08d3e | |||
15451d0f1c | |||
8c4a21b66d | |||
5223de2a41 | |||
f505702dc7 | |||
adf455a97c | |||
9647f60fca | |||
a8bfac6bef | |||
6d414e20e2 | |||
3c7b03935b | |||
eb771cf56d | |||
cc81ab7db7 | |||
ef965a03ed | |||
6f2ff16aad | |||
81b12af321 | |||
c1197c490e | |||
4e28492dbd | |||
6cb7267dc2 | |||
abbb66b58e | |||
ea7bb91a56 | |||
d2d840f06c | |||
0a1137e41c | |||
0767597bff | |||
cb3ed4ab9d | |||
65688208c9 | |||
72685d0398 | |||
1bb4839b0f | |||
4acb5b3abc | |||
14e593f60b | |||
66da4b3eb7 | |||
18d5fdc995 | |||
423667f00b | |||
1b092de19a | |||
69a52b00c7 | |||
9e3145cead | |||
577db33430 | |||
da6ed83dc9 | |||
7eb9692cb9 | |||
8de0e2af51 | |||
225f6b4d69 | |||
864976af23 | |||
9d736dca1c | |||
d87f6268d0 | |||
d80b98601b | |||
aa37509362 | |||
15b4c558c2 | |||
54504a2be8 | |||
8c0fee90d3 | |||
016f0293cd | |||
44daf50501 | |||
48e7caad77 | |||
8673064658 | |||
e6ecbaa68f | |||
e92325b7eb | |||
eb712f3999 | |||
30eff5a01f | |||
734ecc2844 | |||
512ab1acf9 | |||
befe2b242e | |||
f9c5ff9f08 | |||
d39c1b2319 | |||
b13778fefd | |||
076ff96eb2 | |||
0c84c26d92 | |||
d7f1d16f19 | |||
74a00eecd7 | |||
b026407fd9 | |||
a323cff654 | |||
93ed6cfa93 | |||
9797a67391 | |||
5a4382ae4d | |||
ec6a110cdf | |||
8d8c027a92 | |||
4cbd3030cc | |||
1c528d1a3c | |||
c65e7ba9b4 | |||
5a47f458ac | |||
f1032bb40a | |||
bc8a03881a | |||
42b4909bc0 | |||
bb15d6b68e | |||
23d405e1cf | |||
17e2f7f859 | |||
1d9d630fb9 | |||
9c042c2d28 | |||
a23f2aa3f7 | |||
7c5468116f | |||
a1c705b3a7 | |||
29a5e0b236 | |||
715435db42 | |||
1fc965bc1a | |||
74b98ebfaa | |||
53396adb21 | |||
63fb5fc46f | |||
d8a2b4ffc9 | |||
9ffb7e7a23 | |||
fd8f1003cf | |||
46b416296f | |||
7642390d0a | |||
8b05ad4dae | |||
5421f1d7ca | |||
91e959ec4f | |||
eabf35dff0 | |||
4919ad21fc | |||
b50aafb17b | |||
2efa136114 | |||
0b839f3f01 | |||
1caddfb564 | |||
7ad554c64f | |||
4603f010a5 | |||
24008aa1ed | |||
07361ba1d7 | |||
b666523004 | |||
69e038cbc4 | |||
a693a779fa | |||
5b85c5433f | |||
d31f6e0b8a | |||
c8404d9805 |
34
.github/workflows/build-and-release.yml
vendored
Normal file
34
.github/workflows/build-and-release.yml
vendored
Normal file
@ -0,0 +1,34 @@
|
||||
name: Build and release
|
||||
|
||||
on:
|
||||
release:
|
||||
types: [published]
|
||||
|
||||
jobs:
|
||||
build:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v5
|
||||
with:
|
||||
version: "0.5.14"
|
||||
python-version: "3.9"
|
||||
|
||||
- name: Check if lockfile is up to date
|
||||
run: uv lock --check
|
||||
|
||||
- name: Build package
|
||||
run: uv build
|
||||
|
||||
- name: Release to Github
|
||||
uses: softprops/action-gh-release@v2
|
||||
with:
|
||||
files: dist/*.whl
|
||||
|
||||
- name: Publish package to PyPi
|
||||
run: uv publish
|
||||
env:
|
||||
UV_PUBLISH_TOKEN: ${{ secrets.PYPI_API_TOKEN }}
|
34
.github/workflows/python-compatibility.yml
vendored
Normal file
34
.github/workflows/python-compatibility.yml
vendored
Normal file
@ -0,0 +1,34 @@
|
||||
name: Python Compatibility Test
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [main]
|
||||
pull_request:
|
||||
branches: [main]
|
||||
workflow_dispatch: # Allows manual triggering from GitHub UI
|
||||
|
||||
jobs:
|
||||
test:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: ["3.9", "3.10", "3.11", "3.12"]
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v5
|
||||
with:
|
||||
version: "0.5.14"
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
- name: Check if lockfile is up to date
|
||||
run: uv lock --check
|
||||
|
||||
- name: Install the project
|
||||
run: uv sync --all-extras
|
||||
|
||||
- name: Test import
|
||||
run: |
|
||||
uv run python -c "import whisperx; print('Successfully imported whisperx')"
|
173
.gitignore
vendored
173
.gitignore
vendored
@ -1,2 +1,171 @@
|
||||
whisperx.egg-info/
|
||||
**/__pycache__/
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
share/python-wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
|
||||
# PyInstaller
|
||||
# Usually these files are written by a python script from a template
|
||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
*.py,cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
cover/
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Django stuff:
|
||||
*.log
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
db.sqlite3-journal
|
||||
|
||||
# Flask stuff:
|
||||
instance/
|
||||
.webassets-cache
|
||||
|
||||
# Scrapy stuff:
|
||||
.scrapy
|
||||
|
||||
# Sphinx documentation
|
||||
docs/_build/
|
||||
|
||||
# PyBuilder
|
||||
.pybuilder/
|
||||
target/
|
||||
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints
|
||||
|
||||
# IPython
|
||||
profile_default/
|
||||
ipython_config.py
|
||||
|
||||
# pyenv
|
||||
# For a library or package, you might want to ignore these files since the code is
|
||||
# intended to run in multiple environments; otherwise, check them in:
|
||||
# .python-version
|
||||
|
||||
# pipenv
|
||||
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||
# install all needed dependencies.
|
||||
#Pipfile.lock
|
||||
|
||||
# UV
|
||||
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
|
||||
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
||||
# commonly ignored for libraries.
|
||||
#uv.lock
|
||||
|
||||
# poetry
|
||||
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
||||
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
||||
# commonly ignored for libraries.
|
||||
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
||||
#poetry.lock
|
||||
|
||||
# pdm
|
||||
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
||||
#pdm.lock
|
||||
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
||||
# in version control.
|
||||
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
|
||||
.pdm.toml
|
||||
.pdm-python
|
||||
.pdm-build/
|
||||
|
||||
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
||||
__pypackages__/
|
||||
|
||||
# Celery stuff
|
||||
celerybeat-schedule
|
||||
celerybeat.pid
|
||||
|
||||
# SageMath parsed files
|
||||
*.sage.py
|
||||
|
||||
# Environments
|
||||
.env
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
|
||||
# Rope project settings
|
||||
.ropeproject
|
||||
|
||||
# mkdocs documentation
|
||||
/site
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
.dmypy.json
|
||||
dmypy.json
|
||||
|
||||
# Pyre type checker
|
||||
.pyre/
|
||||
|
||||
# pytype static type analyzer
|
||||
.pytype/
|
||||
|
||||
# Cython debug symbols
|
||||
cython_debug/
|
||||
|
||||
# PyCharm
|
||||
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
||||
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
||||
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
||||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||
#.idea/
|
||||
|
||||
# PyPI configuration file
|
||||
.pypirc
|
39
LICENSE
39
LICENSE
@ -1,27 +1,24 @@
|
||||
Copyright (c) 2022, Max Bain
|
||||
All rights reserved.
|
||||
BSD 2-Clause License
|
||||
|
||||
Copyright (c) 2024, Max Bain
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are met:
|
||||
1. Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
2. Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
3. All advertising materials mentioning features or use of this software
|
||||
must display the following acknowledgement:
|
||||
This product includes software developed by Max Bain.
|
||||
4. Neither the name of Max Bain nor the
|
||||
names of its contributors may be used to endorse or promote products
|
||||
derived from this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDER ''AS IS'' AND ANY
|
||||
EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
||||
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
1. Redistributions of source code must retain the above copyright notice, this
|
||||
list of conditions and the following disclaimer.
|
||||
|
||||
2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
this list of conditions and the following disclaimer in the documentation
|
||||
and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
|
||||
USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
@ -1,4 +1,3 @@
|
||||
include whisperx/assets/*
|
||||
include whisperx/assets/gpt2/*
|
||||
include whisperx/assets/multilingual/*
|
||||
include whisperx/normalizers/english.json
|
||||
include LICENSE
|
||||
include requirements.txt
|
||||
|
287
README.md
287
README.md
@ -13,36 +13,30 @@
|
||||
<img src="https://img.shields.io/github/license/m-bain/whisperX.svg"
|
||||
alt="GitHub license">
|
||||
</a>
|
||||
<a href="https://arxiv.org/abs/2303.00747">
|
||||
<img src="http://img.shields.io/badge/Arxiv-2303.00747-B31B1B.svg"
|
||||
alt="ArXiv paper">
|
||||
</a>
|
||||
<a href="https://twitter.com/intent/tweet?text=&url=https%3A%2F%2Fgithub.com%2Fm-bain%2FwhisperX">
|
||||
<img src="https://img.shields.io/twitter/url/https/github.com/m-bain/whisperX.svg?style=social" alt="Twitter">
|
||||
</a>
|
||||
</p>
|
||||
|
||||
<p align="center">
|
||||
<a href="#what-is-it">What is it</a> •
|
||||
<a href="#setup">Setup</a> •
|
||||
<a href="#example">Usage</a> •
|
||||
<a href="#other-languages">Multilingual</a> •
|
||||
<a href="#contribute">Contribute</a> •
|
||||
<a href="EXAMPLES.md">More examples</a> •
|
||||
<a href="https://arxiv.org/abs/2303.00747">Paper</a>
|
||||
</p>
|
||||
<img width="1216" align="center" alt="whisperx-arch" src="https://raw.githubusercontent.com/m-bain/whisperX/refs/heads/main/figures/pipeline.png">
|
||||
|
||||
<!-- <p align="left">Whisper-Based Automatic Speech Recognition (ASR) with improved timestamp accuracy + quality via forced phoneme alignment and voice-activity based batching for fast inference.</p> -->
|
||||
|
||||
<img width="1216" align="center" alt="whisperx-arch" src="figures/pipeline.png">
|
||||
<!-- <h2 align="left", id="what-is-it">What is it 🔎</h2> -->
|
||||
|
||||
This repository provides fast automatic speech recognition (70x realtime with large-v2) with word-level timestamps and speaker diarization.
|
||||
|
||||
<p align="left">Whisper-Based Automatic Speech Recognition (ASR) with improved timestamp accuracy + quality via forced phoneme alignment and speech-activity batching.
|
||||
- ⚡️ Batched inference for 70x realtime transcription using whisper large-v2
|
||||
- 🪶 [faster-whisper](https://github.com/guillaumekln/faster-whisper) backend, requires <8GB gpu memory for large-v2 with beam_size=5
|
||||
- 🎯 Accurate word-level timestamps using wav2vec2 alignment
|
||||
- 👯♂️ Multispeaker ASR using speaker diarization from [pyannote-audio](https://github.com/pyannote/pyannote-audio) (speaker ID labels)
|
||||
- 🗣️ VAD preprocessing, reduces hallucination & batching with no WER degradation
|
||||
|
||||
</p>
|
||||
|
||||
|
||||
<h2 align="left", id="what-is-it">What is it 🔎</h2>
|
||||
|
||||
This repository refines the timestamps of openAI's Whisper model via forced aligment with phoneme-based ASR models (e.g. wav2vec2.0) and VAD preprocesssing, multilingual use-case.
|
||||
|
||||
|
||||
**Whisper** is an ASR model [developed by OpenAI](https://github.com/openai/whisper), trained on a large dataset of diverse audio. Whilst it does produces highly accurate transcriptions, the corresponding timestamps are at the utterance-level, not per word, and can be inaccurate by several seconds.
|
||||
**Whisper** is an ASR model [developed by OpenAI](https://github.com/openai/whisper), trained on a large dataset of diverse audio. Whilst it does produces highly accurate transcriptions, the corresponding timestamps are at the utterance-level, not per word, and can be inaccurate by several seconds. OpenAI's whisper does not natively support batching.
|
||||
|
||||
**Phoneme-Based ASR** A suite of models finetuned to recognise the smallest unit of speech distinguishing one word from another, e.g. the element p in "tap". A popular example model is [wav2vec2.0](https://huggingface.co/facebook/wav2vec2-large-960h-lv60-self).
|
||||
|
||||
@ -50,72 +44,94 @@ This repository refines the timestamps of openAI's Whisper model via forced alig
|
||||
|
||||
**Voice Activity Detection (VAD)** is the detection of the presence or absence of human speech.
|
||||
|
||||
**Speaker Diarization** is the process of partitioning an audio stream containing human speech into homogeneous segments according to the identity of each speaker.
|
||||
|
||||
<h2 align="left", id="highlights">New🚨</h2>
|
||||
|
||||
- 1st place at [Ego4d transcription challenge](https://eval.ai/web/challenges/challenge-page/1637/leaderboard/3931/WER) 🏆
|
||||
- _WhisperX_ accepted at INTERSPEECH 2023
|
||||
- v3 transcript segment-per-sentence: using nltk sent_tokenize for better subtitlting & better diarization
|
||||
- v3 released, 70x speed-up open-sourced. Using batched whisper with [faster-whisper](https://github.com/guillaumekln/faster-whisper) backend!
|
||||
- v2 released, code cleanup, imports whisper library, batched inference from paper not included (contact for licensing / batched model API). VAD filtering is now turned on by default, as in the paper.
|
||||
- Paper drop🎓👨🏫! Please see our [ArxiV preprint](https://arxiv.org/abs/2303.00747) for benchmarking and details of WhisperX. We also introduce more efficient batch inference resulting in large-v2 with *60-70x REAL TIME speed (not provided in this repo).
|
||||
- VAD filtering: Voice Activity Detection (VAD) from [Pyannote.audio](https://huggingface.co/pyannote/voice-activity-detection) is used as a preprocessing step to remove reliance on whisper timestamps and only transcribe audio segments containing speech. add `--vad_filter True` flag, increases timestamp accuracy and robustness (requires more GPU mem due to 30s inputs in wav2vec2)
|
||||
- Character level timestamps (see `*.char.ass` file output)
|
||||
- Diarization (still in beta, add `--diarize`)
|
||||
|
||||
- v2 released, code cleanup, imports whisper library VAD filtering is now turned on by default, as in the paper.
|
||||
- Paper drop🎓👨🏫! Please see our [ArxiV preprint](https://arxiv.org/abs/2303.00747) for benchmarking and details of WhisperX. We also introduce more efficient batch inference resulting in large-v2 with \*60-70x REAL TIME speed.
|
||||
|
||||
<h2 align="left" id="setup">Setup ⚙️</h2>
|
||||
Tested for PyTorch 2.0, Python 3.10 (use other versions at your own risk!)
|
||||
|
||||
GPU execution requires the NVIDIA libraries cuBLAS 11.x and cuDNN 8.x to be installed on the system. Please refer to the [CTranslate2 documentation](https://opennmt.net/CTranslate2/installation.html).
|
||||
### 1. Simple Installation (Recommended)
|
||||
|
||||
The easiest way to install WhisperX is through PyPi:
|
||||
|
||||
### 1. Create Python3.10 environment
|
||||
|
||||
`conda create --name whisperx python=3.10`
|
||||
|
||||
`conda activate whisperx`
|
||||
|
||||
|
||||
### 2. Install PyTorch2.0, e.g. for Linux and Windows CUDA11.7:
|
||||
|
||||
`pip3 install torch torchvision torchaudio`
|
||||
|
||||
See other methods [here.](https://pytorch.org/get-started/locally/)
|
||||
|
||||
### 3. Install this repo
|
||||
|
||||
`pip install git+https://github.com/m-bain/whisperx.git@v3`
|
||||
|
||||
If already installed, update package to most recent commit
|
||||
|
||||
`pip install git+https://github.com/m-bain/whisperx.git@v3 --upgrade`
|
||||
|
||||
If wishing to modify this package, clone and install in editable mode:
|
||||
```bash
|
||||
pip install whisperx
|
||||
```
|
||||
$ git clone https://github.com/m-bain/whisperX.git@v3
|
||||
$ cd whisperX
|
||||
$ git checkout v3
|
||||
$ pip install -e .
|
||||
|
||||
Or if using [uvx](https://docs.astral.sh/uv/guides/tools/#running-tools):
|
||||
|
||||
```bash
|
||||
uvx whisperx
|
||||
```
|
||||
|
||||
### 2. Advanced Installation Options
|
||||
|
||||
These installation methods are for developers or users with specific needs. If you're not sure, stick with the simple installation above.
|
||||
|
||||
#### Option A: Install from GitHub
|
||||
|
||||
To install directly from the GitHub repository:
|
||||
|
||||
```bash
|
||||
uvx git+https://github.com/m-bain/whisperX.git
|
||||
```
|
||||
|
||||
#### Option B: Developer Installation
|
||||
|
||||
If you want to modify the code or contribute to the project:
|
||||
|
||||
```bash
|
||||
git clone https://github.com/m-bain/whisperX.git
|
||||
cd whisperX
|
||||
uv sync --all-extras --dev
|
||||
```
|
||||
|
||||
> **Note**: The development version may contain experimental features and bugs. Use the stable PyPI release for production environments.
|
||||
|
||||
You may also need to install ffmpeg, rust etc. Follow openAI instructions here https://github.com/openai/whisper#setup.
|
||||
|
||||
### Common Issues & Troubleshooting 🔧
|
||||
|
||||
#### libcudnn Dependencies (GPU Users)
|
||||
|
||||
If you're using WhisperX with GPU support and encounter errors like:
|
||||
|
||||
- `Could not load library libcudnn_ops_infer.so.8`
|
||||
- `Unable to load any of {libcudnn_cnn.so.9.1.0, libcudnn_cnn.so.9.1, libcudnn_cnn.so.9, libcudnn_cnn.so}`
|
||||
- `libcudnn_ops_infer.so.8: cannot open shared object file: No such file or directory`
|
||||
|
||||
This means your system is missing the CUDA Deep Neural Network library (cuDNN). This library is needed for GPU acceleration but isn't always installed by default.
|
||||
|
||||
**Install cuDNN (example for apt based systems):**
|
||||
|
||||
```bash
|
||||
sudo apt update
|
||||
sudo apt install libcudnn8 libcudnn8-dev -y
|
||||
```
|
||||
|
||||
### Speaker Diarization
|
||||
To **enable Speaker. Diarization**, include your Hugging Face access token that you can generate from [Here](https://huggingface.co/settings/tokens) after the `--hf_token` argument and accept the user agreement for the following models: [Segmentation](https://huggingface.co/pyannote/segmentation) , [Voice Activity Detection (VAD)](https://huggingface.co/pyannote/voice-activity-detection) , and [Speaker Diarization](https://huggingface.co/pyannote/speaker-diarization)
|
||||
|
||||
To **enable Speaker Diarization**, include your Hugging Face access token (read) that you can generate from [Here](https://huggingface.co/settings/tokens) after the `--hf_token` argument and accept the user agreement for the following models: [Segmentation](https://huggingface.co/pyannote/segmentation-3.0) and [Speaker-Diarization-3.1](https://huggingface.co/pyannote/speaker-diarization-3.1) (if you choose to use Speaker-Diarization 2.x, follow requirements [here](https://huggingface.co/pyannote/speaker-diarization) instead.)
|
||||
|
||||
> **Note**<br>
|
||||
> As of Oct 11, 2023, there is a known issue regarding slow performance with pyannote/Speaker-Diarization-3.0 in whisperX. It is due to dependency conflicts between faster-whisper and pyannote-audio 3.0.0. Please see [this issue](https://github.com/m-bain/whisperX/issues/499) for more details and potential workarounds.
|
||||
|
||||
<h2 align="left" id="example">Usage 💬 (command line)</h2>
|
||||
|
||||
### English
|
||||
|
||||
Run whisper on example segment (using default params)
|
||||
Run whisper on example segment (using default params, whisper small) add `--highlight_words True` to visualise word timings in the .srt file.
|
||||
|
||||
whisperx examples/sample01.wav
|
||||
whisperx path/to/audio.wav
|
||||
|
||||
|
||||
For increased timestamp accuracy, at the cost of higher gpu mem, use bigger models (bigger alignment model not found to be that helpful, see paper) e.g.
|
||||
|
||||
whisperx examples/sample01.wav --model large-v2 --align_model WAV2VEC2_ASR_LARGE_LV60K_960H
|
||||
|
||||
Result using *WhisperX* with forced alignment to wav2vec2.0 large:
|
||||
Result using _WhisperX_ with forced alignment to wav2vec2.0 large:
|
||||
|
||||
https://user-images.githubusercontent.com/36994049/208253969-7e35fe2a-7541-434a-ae91-8e919540555d.mp4
|
||||
|
||||
@ -123,115 +139,170 @@ Compare this to original whisper out the box, where many transcriptions are out
|
||||
|
||||
https://user-images.githubusercontent.com/36994049/207743923-b4f0d537-29ae-4be2-b404-bb941db73652.mov
|
||||
|
||||
For increased timestamp accuracy, at the cost of higher gpu mem, use bigger models (bigger alignment model not found to be that helpful, see paper) e.g.
|
||||
|
||||
whisperx path/to/audio.wav --model large-v2 --align_model WAV2VEC2_ASR_LARGE_LV60K_960H --batch_size 4
|
||||
|
||||
To label the transcript with speaker ID's (set number of speakers if known e.g. `--min_speakers 2` `--max_speakers 2`):
|
||||
|
||||
whisperx path/to/audio.wav --model large-v2 --diarize --highlight_words True
|
||||
|
||||
To run on CPU instead of GPU (and for running on Mac OS X):
|
||||
|
||||
whisperx path/to/audio.wav --compute_type int8
|
||||
|
||||
### Other languages
|
||||
|
||||
The phoneme ASR alignment model is *language-specific*, for tested languages these models are [automatically picked from torchaudio pipelines or huggingface](https://github.com/m-bain/whisperX/blob/e909f2f766b23b2000f2d95df41f9b844ac53e49/whisperx/transcribe.py#L22).
|
||||
The phoneme ASR alignment model is _language-specific_, for tested languages these models are [automatically picked from torchaudio pipelines or huggingface](https://github.com/m-bain/whisperX/blob/f2da2f858e99e4211fe4f64b5f2938b007827e17/whisperx/alignment.py#L24-L58).
|
||||
Just pass in the `--language` code, and use the whisper `--model large`.
|
||||
|
||||
Currently default models provided for `{en, fr, de, es, it, ja, zh, nl, uk, pt}`. If the detected language is not in this list, you need to find a phoneme-based ASR model from [huggingface model hub](https://huggingface.co/models) and test it on your data.
|
||||
|
||||
Currently default models provided for `{en, fr, de, es, it}` via torchaudio pipelines and many other languages via Hugging Face. Please find the list of currently supported languages under `DEFAULT_ALIGN_MODELS_HF` on [alignment.py](https://github.com/m-bain/whisperX/blob/main/whisperx/alignment.py). If the detected language is not in this list, you need to find a phoneme-based ASR model from [huggingface model hub](https://huggingface.co/models) and test it on your data.
|
||||
|
||||
#### E.g. German
|
||||
whisperx --model large --language de examples/sample_de_01.wav
|
||||
|
||||
whisperx --model large-v2 --language de path/to/audio.wav
|
||||
|
||||
https://user-images.githubusercontent.com/36994049/208298811-e36002ba-3698-4731-97d4-0aebd07e0eb3.mov
|
||||
|
||||
|
||||
See more examples in other languages [here](EXAMPLES.md).
|
||||
|
||||
## Python usage 🐍
|
||||
## Python usage 🐍
|
||||
|
||||
```python
|
||||
import whisperx
|
||||
import gc
|
||||
|
||||
device = "cuda"
|
||||
device = "cuda"
|
||||
audio_file = "audio.mp3"
|
||||
batch_size = 16 # reduce if low on GPU mem
|
||||
compute_type = "float16" # change to "int8" if low on GPU mem (may reduce accuracy)
|
||||
|
||||
# transcribe with original whisper
|
||||
model = whisperx.load_model("large-v2", device)
|
||||
# 1. Transcribe with original whisper (batched)
|
||||
model = whisperx.load_model("large-v2", device, compute_type=compute_type)
|
||||
|
||||
# save model to local path (optional)
|
||||
# model_dir = "/path/"
|
||||
# model = whisperx.load_model("large-v2", device, compute_type=compute_type, download_root=model_dir)
|
||||
|
||||
audio = whisperx.load_audio(audio_file)
|
||||
result = model.transcribe(audio, batch_size=8)
|
||||
|
||||
result = model.transcribe(audio, batch_size=batch_size)
|
||||
print(result["segments"]) # before alignment
|
||||
|
||||
# load alignment model and metadata
|
||||
# delete model if low on GPU resources
|
||||
# import gc; import torch; gc.collect(); torch.cuda.empty_cache(); del model
|
||||
|
||||
# 2. Align whisper output
|
||||
model_a, metadata = whisperx.load_align_model(language_code=result["language"], device=device)
|
||||
result = whisperx.align(result["segments"], model_a, metadata, audio, device, return_char_alignments=False)
|
||||
|
||||
# align whisper output
|
||||
result_aligned = whisperx.align(result["segments"], model_a, metadata, audio, device)
|
||||
print(result["segments"]) # after alignment
|
||||
|
||||
print(result_aligned["segments"]) # after alignment
|
||||
print(result_aligned["word_segments"]) # after alignment
|
||||
# delete model if low on GPU resources
|
||||
# import gc; import torch; gc.collect(); torch.cuda.empty_cache(); del model_a
|
||||
|
||||
# 3. Assign speaker labels
|
||||
diarize_model = whisperx.diarize.DiarizationPipeline(use_auth_token=YOUR_HF_TOKEN, device=device)
|
||||
|
||||
# add min/max number of speakers if known
|
||||
diarize_segments = diarize_model(audio)
|
||||
# diarize_model(audio, min_speakers=min_speakers, max_speakers=max_speakers)
|
||||
|
||||
result = whisperx.assign_word_speakers(diarize_segments, result)
|
||||
print(diarize_segments)
|
||||
print(result["segments"]) # segments are now assigned speaker IDs
|
||||
```
|
||||
|
||||
## Demos 🚀
|
||||
|
||||
<h2 align="left" id="whisper-mod">Whisper Modifications</h2>
|
||||
[](https://replicate.com/victor-upmeet/whisperx)
|
||||
[](https://replicate.com/daanelson/whisperx)
|
||||
[](https://replicate.com/carnifexer/whisperx)
|
||||
|
||||
In addition to forced alignment, the following two modifications have been made to the whisper transcription method:
|
||||
If you don't have access to your own GPUs, use the links above to try out WhisperX.
|
||||
|
||||
1. `--condition_on_prev_text` is set to `False` by default (reduces hallucination)
|
||||
<h2 align="left" id="whisper-mod">Technical Details 👷♂️</h2>
|
||||
|
||||
For specific details on the batching and alignment, the effect of VAD, as well as the chosen alignment model, see the preprint [paper](https://www.robots.ox.ac.uk/~vgg/publications/2023/Bain23/bain23.pdf).
|
||||
|
||||
To reduce GPU memory requirements, try any of the following (2. & 3. can affect quality):
|
||||
|
||||
1. reduce batch size, e.g. `--batch_size 4`
|
||||
2. use a smaller ASR model `--model base`
|
||||
3. Use lighter compute type `--compute_type int8`
|
||||
|
||||
Transcription differences from openai's whisper:
|
||||
|
||||
1. Transcription without timestamps. To enable single pass batching, whisper inference is performed `--without_timestamps True`, this ensures 1 forward pass per sample in the batch. However, this can cause discrepancies the default whisper output.
|
||||
2. VAD-based segment transcription, unlike the buffered transcription of openai's. In the WhisperX paper we show this reduces WER, and enables accurate batched inference
|
||||
3. `--condition_on_prev_text` is set to `False` by default (reduces hallucination)
|
||||
|
||||
<h2 align="left" id="limitations">Limitations ⚠️</h2>
|
||||
|
||||
- Whisper normalises spoken numbers e.g. "fifty seven" to arabic numerals "57". Need to perform this normalization after alignment, so the phonemes can be aligned. Currently just ignores numbers.
|
||||
- If setting `--vad_filter False`, then whisperx assumes the initial whisper timestamps are accurate to some degree (within margin of 2 seconds, adjust if needed -- bigger margins more prone to alignment errors)
|
||||
- Transcript words which do not contain characters in the alignment models dictionary e.g. "2014." or "£13.60" cannot be aligned and therefore are not given a timing.
|
||||
- Overlapping speech is not handled particularly well by whisper nor whisperx
|
||||
- Diariazation is far from perfect.
|
||||
|
||||
- Diarization is far from perfect
|
||||
- Language specific wav2vec2 model is needed
|
||||
|
||||
<h2 align="left" id="contribute">Contribute 🧑🏫</h2>
|
||||
|
||||
If you are multilingual, a major way you can contribute to this project is to find phoneme models on huggingface (or train your own) and test them on speech for the target language. If the results look good send a merge request and some examples showing its success.
|
||||
If you are multilingual, a major way you can contribute to this project is to find phoneme models on huggingface (or train your own) and test them on speech for the target language. If the results look good send a pull request and some examples showing its success.
|
||||
|
||||
The next major upgrade we are working on is whisper with speaker diarization, so if you have any experience on this please share.
|
||||
Bug finding and pull requests are also highly appreciated to keep this project going, since it's already diverging from the original research scope.
|
||||
|
||||
<h2 align="left" id="coming-soon">Coming Soon 🗓</h2>
|
||||
<h2 align="left" id="coming-soon">TODO 🗓</h2>
|
||||
|
||||
* [x] Multilingual init
|
||||
- [x] Multilingual init
|
||||
|
||||
* [x] Subtitle .ass output
|
||||
- [x] Automatic align model selection based on language detection
|
||||
|
||||
* [x] Automatic align model selection based on language detection
|
||||
- [x] Python usage
|
||||
|
||||
* [x] Python usage
|
||||
- [x] Incorporating speaker diarization
|
||||
|
||||
* [x] Character level timestamps
|
||||
- [x] Model flush, for low gpu mem resources
|
||||
|
||||
* [x] Incorporating speaker diarization
|
||||
- [x] Faster-whisper backend
|
||||
|
||||
* [x] Model flush, for low gpu mem resources
|
||||
- [x] Add max-line etc. see (openai's whisper utils.py)
|
||||
|
||||
* [x] Faster-whisper backend
|
||||
- [x] Sentence-level segments (nltk toolbox)
|
||||
|
||||
* [ ] Add benchmarking code (TEDLIUM for spd/WER & word segmentation)
|
||||
- [x] Improve alignment logic
|
||||
|
||||
* [ ] Allow silero-vad as alternative VAD option
|
||||
- [ ] update examples with diarization and word highlighting
|
||||
|
||||
* [ ] Add max-line etc. see (openai's whisper utils.py)
|
||||
- [ ] Subtitle .ass output <- bring this back (removed in v3)
|
||||
|
||||
* [ ] Improve diarization (word level). *Harder than first thought...*
|
||||
- [ ] Add benchmarking code (TEDLIUM for spd/WER & word segmentation)
|
||||
|
||||
- [x] Allow silero-vad as alternative VAD option
|
||||
|
||||
- [ ] Improve diarization (word level). _Harder than first thought..._
|
||||
|
||||
<h2 align="left" id="contact">Contact/Support 📇</h2>
|
||||
|
||||
Contact maxhbain@gmail.com for queries and licensing / early access to a model API with batched inference (transcribe 1hr audio in under 1min).
|
||||
Contact maxhbain@gmail.com for queries.
|
||||
|
||||
<a href="https://www.buymeacoffee.com/maxhbain" target="_blank"><img src="https://cdn.buymeacoffee.com/buttons/default-orange.png" alt="Buy Me A Coffee" height="41" width="174"></a>
|
||||
|
||||
|
||||
<h2 align="left" id="acks">Acknowledgements 🙏</h2>
|
||||
|
||||
This work, and my PhD, is supported by the [VGG (Visual Geometry Group)](https://www.robots.ox.ac.uk/~vgg/) and the University of Oxford.
|
||||
|
||||
|
||||
Of course, this is builds on [openAI's whisper](https://github.com/openai/whisper).
|
||||
And borrows important alignment code from [PyTorch tutorial on forced alignment](https://pytorch.org/tutorials/intermediate/forced_alignment_with_torchaudio_tutorial.html)
|
||||
Borrows important alignment code from [PyTorch tutorial on forced alignment](https://pytorch.org/tutorials/intermediate/forced_alignment_with_torchaudio_tutorial.html)
|
||||
And uses the wonderful pyannote VAD / Diarization https://github.com/pyannote/pyannote-audio
|
||||
|
||||
Valuable VAD & Diarization Models from (pyannote.audio)[https://github.com/pyannote/pyannote-audio]
|
||||
Valuable VAD & Diarization Models from:
|
||||
|
||||
Great backend from (faster-whisper)[https://github.com/guillaumekln/faster-whisper] and (CTranslate2)[https://github.com/OpenNMT/CTranslate2]
|
||||
- [pyannote audio][https://github.com/pyannote/pyannote-audio]
|
||||
- [silero vad][https://github.com/snakers4/silero-vad]
|
||||
|
||||
Great backend from [faster-whisper](https://github.com/guillaumekln/faster-whisper) and [CTranslate2](https://github.com/OpenNMT/CTranslate2)
|
||||
|
||||
Those who have [supported this work financially](https://www.buymeacoffee.com/maxhbain) 🙏
|
||||
|
||||
Finally, thanks to the OS [contributors](https://github.com/m-bain/whisperX/graphs/contributors) of this project, keeping it going and identifying bugs.
|
||||
|
||||
<h2 align="left" id="cite">Citation</h2>
|
||||
If you use this in your research, please cite the paper:
|
||||
@ -240,7 +311,7 @@ If you use this in your research, please cite the paper:
|
||||
@article{bain2022whisperx,
|
||||
title={WhisperX: Time-Accurate Speech Transcription of Long-Form Audio},
|
||||
author={Bain, Max and Huh, Jaesung and Han, Tengda and Zisserman, Andrew},
|
||||
journal={arXiv preprint, arXiv:2303.00747},
|
||||
journal={INTERSPEECH 2023},
|
||||
year={2023}
|
||||
}
|
||||
```
|
||||
|
36
pyproject.toml
Normal file
36
pyproject.toml
Normal file
@ -0,0 +1,36 @@
|
||||
[project]
|
||||
urls = { repository = "https://github.com/m-bain/whisperx" }
|
||||
authors = [{ name = "Max Bain" }]
|
||||
name = "whisperx"
|
||||
version = "3.4.2"
|
||||
description = "Time-Accurate Automatic Speech Recognition using Whisper."
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.9, <3.13"
|
||||
license = { text = "BSD-2-Clause" }
|
||||
|
||||
dependencies = [
|
||||
"ctranslate2<4.5.0",
|
||||
"faster-whisper>=1.1.1",
|
||||
"nltk>=3.9.1",
|
||||
"numpy>=2.0.2",
|
||||
"onnxruntime>=1.19",
|
||||
"pandas>=2.2.3",
|
||||
"pyannote-audio>=3.3.2",
|
||||
"torch>=2.5.1",
|
||||
"torchaudio>=2.5.1",
|
||||
"transformers>=4.48.0",
|
||||
]
|
||||
|
||||
|
||||
[project.scripts]
|
||||
whisperx = "whisperx.__main__:cli"
|
||||
|
||||
[build-system]
|
||||
requires = ["setuptools"]
|
||||
|
||||
[tool.setuptools]
|
||||
include-package-data = true
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
where = ["."]
|
||||
include = ["whisperx*"]
|
@ -1,7 +0,0 @@
|
||||
torch==2.0.0
|
||||
torchaudio==2.0.1
|
||||
faster-whisper
|
||||
transformers
|
||||
ffmpeg-python==0.2.0
|
||||
pandas
|
||||
setuptools==65.6.3
|
28
setup.py
28
setup.py
@ -1,28 +0,0 @@
|
||||
import os
|
||||
|
||||
import pkg_resources
|
||||
from setuptools import setup, find_packages
|
||||
|
||||
setup(
|
||||
name="whisperx",
|
||||
py_modules=["whisperx"],
|
||||
version="3.0.2",
|
||||
description="Time-Accurate Automatic Speech Recognition using Whisper.",
|
||||
readme="README.md",
|
||||
python_requires=">=3.8",
|
||||
author="Max Bain",
|
||||
url="https://github.com/m-bain/whisperx",
|
||||
license="MIT",
|
||||
packages=find_packages(exclude=["tests*"]),
|
||||
install_requires=[
|
||||
str(r)
|
||||
for r in pkg_resources.parse_requirements(
|
||||
open(os.path.join(os.path.dirname(__file__), "requirements.txt"))
|
||||
)
|
||||
] + ["pyannote.audio @ git+https://github.com/pyannote/pyannote-audio@11b56a137a578db9335efc00298f6ec1932e6317"],
|
||||
entry_points = {
|
||||
'console_scripts': ['whisperx=whisperx.transcribe:cli'],
|
||||
},
|
||||
include_package_data=True,
|
||||
extras_require={'dev': ['pytest']},
|
||||
)
|
226
whisperx/SubtitlesProcessor.py
Normal file
226
whisperx/SubtitlesProcessor.py
Normal file
@ -0,0 +1,226 @@
|
||||
import math
|
||||
from whisperx.conjunctions import get_conjunctions, get_comma
|
||||
|
||||
def normal_round(n):
|
||||
if n - math.floor(n) < 0.5:
|
||||
return math.floor(n)
|
||||
return math.ceil(n)
|
||||
|
||||
|
||||
def format_timestamp(seconds: float, is_vtt: bool = False):
|
||||
|
||||
assert seconds >= 0, "non-negative timestamp expected"
|
||||
milliseconds = round(seconds * 1000.0)
|
||||
|
||||
hours = milliseconds // 3_600_000
|
||||
milliseconds -= hours * 3_600_000
|
||||
|
||||
minutes = milliseconds // 60_000
|
||||
milliseconds -= minutes * 60_000
|
||||
|
||||
seconds = milliseconds // 1_000
|
||||
milliseconds -= seconds * 1_000
|
||||
|
||||
separator = '.' if is_vtt else ','
|
||||
|
||||
hours_marker = f"{hours:02d}:"
|
||||
return (
|
||||
f"{hours_marker}{minutes:02d}:{seconds:02d}{separator}{milliseconds:03d}"
|
||||
)
|
||||
|
||||
|
||||
|
||||
class SubtitlesProcessor:
|
||||
def __init__(self, segments, lang, max_line_length = 45, min_char_length_splitter = 30, is_vtt = False):
|
||||
self.comma = get_comma(lang)
|
||||
self.conjunctions = set(get_conjunctions(lang))
|
||||
self.segments = segments
|
||||
self.lang = lang
|
||||
self.max_line_length = max_line_length
|
||||
self.min_char_length_splitter = min_char_length_splitter
|
||||
self.is_vtt = is_vtt
|
||||
complex_script_languages = ['th', 'lo', 'my', 'km', 'am', 'ko', 'ja', 'zh', 'ti', 'ta', 'te', 'kn', 'ml', 'hi', 'ne', 'mr', 'ar', 'fa', 'ur', 'ka']
|
||||
if self.lang in complex_script_languages:
|
||||
self.max_line_length = 30
|
||||
self.min_char_length_splitter = 20
|
||||
|
||||
def estimate_timestamp_for_word(self, words, i, next_segment_start_time=None):
|
||||
k = 0.25
|
||||
has_prev_end = i > 0 and 'end' in words[i - 1]
|
||||
has_next_start = i < len(words) - 1 and 'start' in words[i + 1]
|
||||
|
||||
if has_prev_end:
|
||||
words[i]['start'] = words[i - 1]['end']
|
||||
if has_next_start:
|
||||
words[i]['end'] = words[i + 1]['start']
|
||||
else:
|
||||
if next_segment_start_time:
|
||||
words[i]['end'] = next_segment_start_time if next_segment_start_time - words[i - 1]['end'] <= 1 else next_segment_start_time - 0.5
|
||||
else:
|
||||
words[i]['end'] = words[i]['start'] + len(words[i]['word']) * k
|
||||
|
||||
elif has_next_start:
|
||||
words[i]['start'] = words[i + 1]['start'] - len(words[i]['word']) * k
|
||||
words[i]['end'] = words[i + 1]['start']
|
||||
|
||||
else:
|
||||
if next_segment_start_time:
|
||||
words[i]['start'] = next_segment_start_time - 1
|
||||
words[i]['end'] = next_segment_start_time - 0.5
|
||||
else:
|
||||
words[i]['start'] = 0
|
||||
words[i]['end'] = 0
|
||||
|
||||
|
||||
|
||||
def process_segments(self, advanced_splitting=True):
|
||||
subtitles = []
|
||||
for i, segment in enumerate(self.segments):
|
||||
next_segment_start_time = self.segments[i + 1]['start'] if i + 1 < len(self.segments) else None
|
||||
|
||||
if advanced_splitting:
|
||||
|
||||
split_points = self.determine_advanced_split_points(segment, next_segment_start_time)
|
||||
subtitles.extend(self.generate_subtitles_from_split_points(segment, split_points, next_segment_start_time))
|
||||
else:
|
||||
words = segment['words']
|
||||
for i, word in enumerate(words):
|
||||
if 'start' not in word or 'end' not in word:
|
||||
self.estimate_timestamp_for_word(words, i, next_segment_start_time)
|
||||
|
||||
subtitles.append({
|
||||
'start': segment['start'],
|
||||
'end': segment['end'],
|
||||
'text': segment['text']
|
||||
})
|
||||
|
||||
return subtitles
|
||||
|
||||
def determine_advanced_split_points(self, segment, next_segment_start_time=None):
|
||||
split_points = []
|
||||
last_split_point = 0
|
||||
char_count = 0
|
||||
|
||||
words = segment.get('words', segment['text'].split())
|
||||
add_space = 0 if self.lang in ['zh', 'ja'] else 1
|
||||
|
||||
total_char_count = sum(len(word['word']) if isinstance(word, dict) else len(word) + add_space for word in words)
|
||||
char_count_after = total_char_count
|
||||
|
||||
for i, word in enumerate(words):
|
||||
word_text = word['word'] if isinstance(word, dict) else word
|
||||
word_length = len(word_text) + add_space
|
||||
char_count += word_length
|
||||
char_count_after -= word_length
|
||||
|
||||
char_count_before = char_count - word_length
|
||||
|
||||
if isinstance(word, dict) and ('start' not in word or 'end' not in word):
|
||||
self.estimate_timestamp_for_word(words, i, next_segment_start_time)
|
||||
|
||||
if char_count >= self.max_line_length:
|
||||
midpoint = normal_round((last_split_point + i) / 2)
|
||||
if char_count_before >= self.min_char_length_splitter:
|
||||
split_points.append(midpoint)
|
||||
last_split_point = midpoint + 1
|
||||
char_count = sum(len(words[j]['word']) if isinstance(words[j], dict) else len(words[j]) + add_space for j in range(last_split_point, i + 1))
|
||||
|
||||
elif word_text.endswith(self.comma) and char_count_before >= self.min_char_length_splitter and char_count_after >= self.min_char_length_splitter:
|
||||
split_points.append(i)
|
||||
last_split_point = i + 1
|
||||
char_count = 0
|
||||
|
||||
elif word_text.lower() in self.conjunctions and char_count_before >= self.min_char_length_splitter and char_count_after >= self.min_char_length_splitter:
|
||||
split_points.append(i - 1)
|
||||
last_split_point = i
|
||||
char_count = word_length
|
||||
|
||||
return split_points
|
||||
|
||||
|
||||
def generate_subtitles_from_split_points(self, segment, split_points, next_start_time=None):
|
||||
subtitles = []
|
||||
|
||||
words = segment.get('words', segment['text'].split())
|
||||
total_word_count = len(words)
|
||||
total_time = segment['end'] - segment['start']
|
||||
elapsed_time = segment['start']
|
||||
prefix = ' ' if self.lang not in ['zh', 'ja'] else ''
|
||||
start_idx = 0
|
||||
for split_point in split_points:
|
||||
|
||||
fragment_words = words[start_idx:split_point + 1]
|
||||
current_word_count = len(fragment_words)
|
||||
|
||||
|
||||
if isinstance(fragment_words[0], dict):
|
||||
start_time = fragment_words[0]['start']
|
||||
end_time = fragment_words[-1]['end']
|
||||
next_start_time_for_word = words[split_point + 1]['start'] if split_point + 1 < len(words) else None
|
||||
if next_start_time_for_word and (next_start_time_for_word - end_time) <= 0.8:
|
||||
end_time = next_start_time_for_word
|
||||
else:
|
||||
fragment = prefix.join(fragment_words).strip()
|
||||
current_duration = (current_word_count / total_word_count) * total_time
|
||||
start_time = elapsed_time
|
||||
end_time = elapsed_time + current_duration
|
||||
elapsed_time += current_duration
|
||||
|
||||
|
||||
subtitles.append({
|
||||
'start': start_time,
|
||||
'end': end_time,
|
||||
'text': fragment if not isinstance(fragment_words[0], dict) else prefix.join(word['word'] for word in fragment_words)
|
||||
})
|
||||
|
||||
start_idx = split_point + 1
|
||||
|
||||
# Handle the last fragment
|
||||
if start_idx < len(words):
|
||||
fragment_words = words[start_idx:]
|
||||
current_word_count = len(fragment_words)
|
||||
|
||||
if isinstance(fragment_words[0], dict):
|
||||
start_time = fragment_words[0]['start']
|
||||
end_time = fragment_words[-1]['end']
|
||||
else:
|
||||
fragment = prefix.join(fragment_words).strip()
|
||||
current_duration = (current_word_count / total_word_count) * total_time
|
||||
start_time = elapsed_time
|
||||
end_time = elapsed_time + current_duration
|
||||
|
||||
if next_start_time and (next_start_time - end_time) <= 0.8:
|
||||
end_time = next_start_time
|
||||
|
||||
subtitles.append({
|
||||
'start': start_time,
|
||||
'end': end_time if end_time is not None else segment['end'],
|
||||
'text': fragment if not isinstance(fragment_words[0], dict) else prefix.join(word['word'] for word in fragment_words)
|
||||
})
|
||||
|
||||
return subtitles
|
||||
|
||||
|
||||
|
||||
def save(self, filename="subtitles.srt", advanced_splitting=True):
|
||||
|
||||
subtitles = self.process_segments(advanced_splitting)
|
||||
|
||||
def write_subtitle(file, idx, start_time, end_time, text):
|
||||
|
||||
file.write(f"{idx}\n")
|
||||
file.write(f"{start_time} --> {end_time}\n")
|
||||
file.write(text + "\n\n")
|
||||
|
||||
with open(filename, 'w', encoding='utf-8') as file:
|
||||
if self.is_vtt:
|
||||
file.write("WEBVTT\n\n")
|
||||
|
||||
if advanced_splitting:
|
||||
for idx, subtitle in enumerate(subtitles, 1):
|
||||
start_time = format_timestamp(subtitle['start'], self.is_vtt)
|
||||
end_time = format_timestamp(subtitle['end'], self.is_vtt)
|
||||
text = subtitle['text'].strip()
|
||||
write_subtitle(file, idx, start_time, end_time, text)
|
||||
|
||||
return len(subtitles)
|
@ -1,3 +1,31 @@
|
||||
from .transcribe import load_model
|
||||
from .alignment import load_align_model, align
|
||||
from .audio import load_audio
|
||||
import importlib
|
||||
|
||||
|
||||
def _lazy_import(name):
|
||||
module = importlib.import_module(f"whisperx.{name}")
|
||||
return module
|
||||
|
||||
|
||||
def load_align_model(*args, **kwargs):
|
||||
alignment = _lazy_import("alignment")
|
||||
return alignment.load_align_model(*args, **kwargs)
|
||||
|
||||
|
||||
def align(*args, **kwargs):
|
||||
alignment = _lazy_import("alignment")
|
||||
return alignment.align(*args, **kwargs)
|
||||
|
||||
|
||||
def load_model(*args, **kwargs):
|
||||
asr = _lazy_import("asr")
|
||||
return asr.load_model(*args, **kwargs)
|
||||
|
||||
|
||||
def load_audio(*args, **kwargs):
|
||||
audio = _lazy_import("audio")
|
||||
return audio.load_audio(*args, **kwargs)
|
||||
|
||||
|
||||
def assign_word_speakers(*args, **kwargs):
|
||||
diarize = _lazy_import("diarize")
|
||||
return diarize.assign_word_speakers(*args, **kwargs)
|
||||
|
@ -1,4 +1,89 @@
|
||||
from .transcribe import cli
|
||||
import argparse
|
||||
import importlib.metadata
|
||||
import platform
|
||||
|
||||
import torch
|
||||
|
||||
from whisperx.utils import (LANGUAGES, TO_LANGUAGE_CODE, optional_float,
|
||||
optional_int, str2bool)
|
||||
|
||||
|
||||
cli()
|
||||
def cli():
|
||||
# fmt: off
|
||||
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||
parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe")
|
||||
parser.add_argument("--model", default="small", help="name of the Whisper model to use")
|
||||
parser.add_argument("--model_cache_only", type=str2bool, default=False, help="If True, will not attempt to download models, instead using cached models from --model_dir")
|
||||
parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
|
||||
parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
|
||||
parser.add_argument("--device_index", default=0, type=int, help="device index to use for FasterWhisper inference")
|
||||
parser.add_argument("--batch_size", default=8, type=int, help="the preferred batch size for inference")
|
||||
parser.add_argument("--compute_type", default="float16", type=str, choices=["float16", "float32", "int8"], help="compute type for computation")
|
||||
|
||||
parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
|
||||
parser.add_argument("--output_format", "-f", type=str, default="all", choices=["all", "srt", "vtt", "txt", "tsv", "json", "aud"], help="format of the output file; if not specified, all available formats will be produced")
|
||||
parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")
|
||||
|
||||
parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')")
|
||||
parser.add_argument("--language", type=str, default=None, choices=sorted(LANGUAGES.keys()) + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]), help="language spoken in the audio, specify None to perform language detection")
|
||||
|
||||
# alignment params
|
||||
parser.add_argument("--align_model", default=None, help="Name of phoneme-level ASR model to do alignment")
|
||||
parser.add_argument("--interpolate_method", default="nearest", choices=["nearest", "linear", "ignore"], help="For word .srt, method to assign timestamps to non-aligned words, or merge them into neighbouring.")
|
||||
parser.add_argument("--no_align", action='store_true', help="Do not perform phoneme alignment")
|
||||
parser.add_argument("--return_char_alignments", action='store_true', help="Return character-level alignments in the output json file")
|
||||
|
||||
# vad params
|
||||
parser.add_argument("--vad_method", type=str, default="pyannote", choices=["pyannote", "silero"], help="VAD method to be used")
|
||||
parser.add_argument("--vad_onset", type=float, default=0.500, help="Onset threshold for VAD (see pyannote.audio), reduce this if speech is not being detected")
|
||||
parser.add_argument("--vad_offset", type=float, default=0.363, help="Offset threshold for VAD (see pyannote.audio), reduce this if speech is not being detected.")
|
||||
parser.add_argument("--chunk_size", type=int, default=30, help="Chunk size for merging VAD segments. Default is 30, reduce this if the chunk is too long.")
|
||||
|
||||
# diarization params
|
||||
parser.add_argument("--diarize", action="store_true", help="Apply diarization to assign speaker labels to each segment/word")
|
||||
parser.add_argument("--min_speakers", default=None, type=int, help="Minimum number of speakers to in audio file")
|
||||
parser.add_argument("--max_speakers", default=None, type=int, help="Maximum number of speakers to in audio file")
|
||||
parser.add_argument("--diarize_model", default="pyannote/speaker-diarization-3.1", type=str, help="Name of the speaker diarization model to use")
|
||||
parser.add_argument("--speaker_embeddings", action="store_true", help="Include speaker embeddings in JSON output (only works with --diarize)")
|
||||
|
||||
parser.add_argument("--temperature", type=float, default=0, help="temperature to use for sampling")
|
||||
parser.add_argument("--best_of", type=optional_int, default=5, help="number of candidates when sampling with non-zero temperature")
|
||||
parser.add_argument("--beam_size", type=optional_int, default=5, help="number of beams in beam search, only applicable when temperature is zero")
|
||||
parser.add_argument("--patience", type=float, default=1.0, help="optional patience value to use in beam decoding, as in https://arxiv.org/abs/2204.05424, the default (1.0) is equivalent to conventional beam search")
|
||||
parser.add_argument("--length_penalty", type=float, default=1.0, help="optional token length penalty coefficient (alpha) as in https://arxiv.org/abs/1609.08144, uses simple length normalization by default")
|
||||
|
||||
parser.add_argument("--suppress_tokens", type=str, default="-1", help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations")
|
||||
parser.add_argument("--suppress_numerals", action="store_true", help="whether to suppress numeric symbols and currency symbols during sampling, since wav2vec2 cannot align them correctly")
|
||||
|
||||
parser.add_argument("--initial_prompt", type=str, default=None, help="optional text to provide as a prompt for the first window.")
|
||||
parser.add_argument("--condition_on_previous_text", type=str2bool, default=False, help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop")
|
||||
parser.add_argument("--fp16", type=str2bool, default=True, help="whether to perform inference in fp16; True by default")
|
||||
|
||||
parser.add_argument("--temperature_increment_on_fallback", type=optional_float, default=0.2, help="temperature to increase when falling back when the decoding fails to meet either of the thresholds below")
|
||||
parser.add_argument("--compression_ratio_threshold", type=optional_float, default=2.4, help="if the gzip compression ratio is higher than this value, treat the decoding as failed")
|
||||
parser.add_argument("--logprob_threshold", type=optional_float, default=-1.0, help="if the average log probability is lower than this value, treat the decoding as failed")
|
||||
parser.add_argument("--no_speech_threshold", type=optional_float, default=0.6, help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence")
|
||||
|
||||
parser.add_argument("--max_line_width", type=optional_int, default=None, help="(not possible with --no_align) the maximum number of characters in a line before breaking the line")
|
||||
parser.add_argument("--max_line_count", type=optional_int, default=None, help="(not possible with --no_align) the maximum number of lines in a segment")
|
||||
parser.add_argument("--highlight_words", type=str2bool, default=False, help="(not possible with --no_align) underline each word as it is spoken in srt and vtt")
|
||||
parser.add_argument("--segment_resolution", type=str, default="sentence", choices=["sentence", "chunk"], help="(not possible with --no_align) the maximum number of characters in a line before breaking the line")
|
||||
|
||||
parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
|
||||
|
||||
parser.add_argument("--hf_token", type=str, default=None, help="Hugging Face Access Token to access PyAnnote gated models")
|
||||
|
||||
parser.add_argument("--print_progress", type=str2bool, default = False, help = "if True, progress will be printed in transcribe() and align() methods.")
|
||||
parser.add_argument("--version", "-V", action="version", version=f"%(prog)s {importlib.metadata.version('whisperx')}",help="Show whisperx version information and exit")
|
||||
parser.add_argument("--python-version", "-P", action="version", version=f"Python {platform.python_version()} ({platform.python_implementation()})",help="Show python version information and exit")
|
||||
# fmt: on
|
||||
|
||||
args = parser.parse_args().__dict__
|
||||
|
||||
from whisperx.transcribe import transcribe_task
|
||||
|
||||
transcribe_task(args, parser)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
cli()
|
||||
|
@ -1,9 +1,11 @@
|
||||
""""
|
||||
"""
|
||||
Forced Alignment with Whisper
|
||||
C. Max Bain
|
||||
"""
|
||||
import math
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Iterator, Union
|
||||
from typing import Iterable, Optional, Union, List
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
@ -11,8 +13,18 @@ import torch
|
||||
import torchaudio
|
||||
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
|
||||
|
||||
from .audio import SAMPLE_RATE, load_audio
|
||||
from .utils import interpolate_nans
|
||||
from whisperx.audio import SAMPLE_RATE, load_audio
|
||||
from whisperx.utils import interpolate_nans
|
||||
from whisperx.types import (
|
||||
AlignedTranscriptionResult,
|
||||
SingleSegment,
|
||||
SingleAlignedSegment,
|
||||
SingleWordSegment,
|
||||
SegmentData,
|
||||
)
|
||||
from nltk.tokenize.punkt import PunktSentenceTokenizer, PunktParameters
|
||||
|
||||
PUNKT_ABBREVIATIONS = ['dr', 'vs', 'mr', 'mrs', 'prof']
|
||||
|
||||
LANGUAGES_WITHOUT_SPACES = ["ja", "zh"]
|
||||
|
||||
@ -31,6 +43,7 @@ DEFAULT_ALIGN_MODELS_HF = {
|
||||
"uk": "Yehor/wav2vec2-xls-r-300m-uk-with-small-lm",
|
||||
"pt": "jonatasgrosman/wav2vec2-large-xlsr-53-portuguese",
|
||||
"ar": "jonatasgrosman/wav2vec2-large-xlsr-53-arabic",
|
||||
"cs": "comodoro/wav2vec2-xls-r-300m-cs-250",
|
||||
"ru": "jonatasgrosman/wav2vec2-large-xlsr-53-russian",
|
||||
"pl": "jonatasgrosman/wav2vec2-large-xlsr-53-polish",
|
||||
"hu": "jonatasgrosman/wav2vec2-large-xlsr-53-hungarian",
|
||||
@ -38,11 +51,30 @@ DEFAULT_ALIGN_MODELS_HF = {
|
||||
"fa": "jonatasgrosman/wav2vec2-large-xlsr-53-persian",
|
||||
"el": "jonatasgrosman/wav2vec2-large-xlsr-53-greek",
|
||||
"tr": "mpoyraz/wav2vec2-xls-r-300m-cv7-turkish",
|
||||
"da": "saattrupdan/wav2vec2-xls-r-300m-ftspeech",
|
||||
"he": "imvladikon/wav2vec2-xls-r-300m-hebrew",
|
||||
"vi": 'nguyenvulebinh/wav2vec2-base-vi',
|
||||
"ko": "kresnik/wav2vec2-large-xlsr-korean",
|
||||
"ur": "kingabzpro/wav2vec2-large-xls-r-300m-Urdu",
|
||||
"te": "anuragshas/wav2vec2-large-xlsr-53-telugu",
|
||||
"hi": "theainerd/Wav2Vec2-large-xlsr-hindi",
|
||||
"ca": "softcatala/wav2vec2-large-xlsr-catala",
|
||||
"ml": "gvs/wav2vec2-large-xlsr-malayalam",
|
||||
"no": "NbAiLab/nb-wav2vec2-1b-bokmaal-v2",
|
||||
"nn": "NbAiLab/nb-wav2vec2-1b-nynorsk",
|
||||
"sk": "comodoro/wav2vec2-xls-r-300m-sk-cv8",
|
||||
"sl": "anton-l/wav2vec2-large-xlsr-53-slovenian",
|
||||
"hr": "classla/wav2vec2-xls-r-parlaspeech-hr",
|
||||
"ro": "gigant/romanian-wav2vec2",
|
||||
"eu": "stefan-it/wav2vec2-large-xlsr-53-basque",
|
||||
"gl": "ifrz/wav2vec2-large-xlsr-galician",
|
||||
"ka": "xsway/wav2vec2-large-xlsr-georgian",
|
||||
"lv": "jimregan/wav2vec2-large-xlsr-latvian-cv",
|
||||
"tl": "Khalsuu/filipino-wav2vec2-l-xls-r-300m-official",
|
||||
}
|
||||
|
||||
|
||||
def load_align_model(language_code, device, model_name=None, model_dir=None):
|
||||
def load_align_model(language_code: str, device: str, model_name: Optional[str] = None, model_dir=None):
|
||||
if model_name is None:
|
||||
# use default model
|
||||
if language_code in DEFAULT_ALIGN_MODELS_TORCH:
|
||||
@ -62,8 +94,8 @@ def load_align_model(language_code, device, model_name=None, model_dir=None):
|
||||
align_dictionary = {c.lower(): i for i, c in enumerate(labels)}
|
||||
else:
|
||||
try:
|
||||
processor = Wav2Vec2Processor.from_pretrained(model_name)
|
||||
align_model = Wav2Vec2ForCTC.from_pretrained(model_name)
|
||||
processor = Wav2Vec2Processor.from_pretrained(model_name, cache_dir=model_dir)
|
||||
align_model = Wav2Vec2ForCTC.from_pretrained(model_name, cache_dir=model_dir)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
print(f"Error loading model from huggingface, check https://huggingface.co/models for finetuned wav2vec2.0 models")
|
||||
@ -79,459 +111,474 @@ def load_align_model(language_code, device, model_name=None, model_dir=None):
|
||||
|
||||
|
||||
def align(
|
||||
transcript: Iterator[dict],
|
||||
transcript: Iterable[SingleSegment],
|
||||
model: torch.nn.Module,
|
||||
align_model_metadata: dict,
|
||||
audio: Union[str, np.ndarray, torch.Tensor],
|
||||
device: str,
|
||||
extend_duration: float = 0.0,
|
||||
start_from_previous: bool = True,
|
||||
interpolate_method: str = "nearest",
|
||||
):
|
||||
return_char_alignments: bool = False,
|
||||
print_progress: bool = False,
|
||||
combined_progress: bool = False,
|
||||
) -> AlignedTranscriptionResult:
|
||||
"""
|
||||
Align phoneme recognition predictions to known transcription.
|
||||
"""
|
||||
Force align phoneme recognition predictions to known transcription
|
||||
|
||||
Parameters
|
||||
----------
|
||||
transcript: Iterator[dict]
|
||||
The Whisper model instance
|
||||
|
||||
model: torch.nn.Module
|
||||
Alignment model (wav2vec2)
|
||||
|
||||
audio: Union[str, np.ndarray, torch.Tensor]
|
||||
The path to the audio file to open, or the audio waveform
|
||||
|
||||
device: str
|
||||
cuda device
|
||||
|
||||
diarization: pd.DataFrame {'start': List[float], 'end': List[float], 'speaker': List[float]}
|
||||
diarization segments with speaker labels.
|
||||
|
||||
extend_duration: float
|
||||
Amount to pad input segments by. If not using vad--filter then recommended to use 2 seconds
|
||||
|
||||
If the gzip compression ratio is above this value, treat as failed
|
||||
|
||||
interpolate_method: str ["nearest", "linear", "ignore"]
|
||||
Method to assign timestamps to non-aligned words. Words are not able to be aligned when none of the characters occur in the align model dictionary.
|
||||
"nearest" copies timestamp of nearest word within the segment. "linear" is linear interpolation. "drop" removes that word from output.
|
||||
|
||||
Returns
|
||||
-------
|
||||
A dictionary containing the resulting text ("text") and segment-level details ("segments"), and
|
||||
the spoken language ("language"), which is detected when `decode_options["language"]` is None.
|
||||
"""
|
||||
if not torch.is_tensor(audio):
|
||||
if isinstance(audio, str):
|
||||
audio = load_audio(audio)
|
||||
audio = torch.from_numpy(audio)
|
||||
if len(audio.shape) == 1:
|
||||
audio = audio.unsqueeze(0)
|
||||
|
||||
|
||||
MAX_DURATION = audio.shape[1] / SAMPLE_RATE
|
||||
|
||||
model_dictionary = align_model_metadata["dictionary"]
|
||||
model_lang = align_model_metadata["language"]
|
||||
model_type = align_model_metadata["type"]
|
||||
|
||||
aligned_segments = []
|
||||
|
||||
prev_t2 = 0
|
||||
|
||||
char_segments_arr = {
|
||||
"segment-idx": [],
|
||||
"subsegment-idx": [],
|
||||
"word-idx": [],
|
||||
"char": [],
|
||||
"start": [],
|
||||
"end": [],
|
||||
"score": [],
|
||||
}
|
||||
|
||||
# 1. Preprocess to keep only characters in dictionary
|
||||
total_segments = len(transcript)
|
||||
# Store temporary processing values
|
||||
segment_data: dict[int, SegmentData] = {}
|
||||
for sdx, segment in enumerate(transcript):
|
||||
while True:
|
||||
segment_align_success = False
|
||||
# strip spaces at beginning / end, but keep track of the amount.
|
||||
if print_progress:
|
||||
base_progress = ((sdx + 1) / total_segments) * 100
|
||||
percent_complete = (50 + base_progress / 2) if combined_progress else base_progress
|
||||
print(f"Progress: {percent_complete:.2f}%...")
|
||||
|
||||
num_leading = len(segment["text"]) - len(segment["text"].lstrip())
|
||||
num_trailing = len(segment["text"]) - len(segment["text"].rstrip())
|
||||
text = segment["text"]
|
||||
|
||||
# strip spaces at beginning / end, but keep track of the amount.
|
||||
num_leading = len(segment["text"]) - len(segment["text"].lstrip())
|
||||
num_trailing = len(segment["text"]) - len(segment["text"].rstrip())
|
||||
transcription = segment["text"]
|
||||
# split into words
|
||||
if model_lang not in LANGUAGES_WITHOUT_SPACES:
|
||||
per_word = text.split(" ")
|
||||
else:
|
||||
per_word = text
|
||||
|
||||
# TODO: convert number tokenizer / symbols to phonetic words for alignment.
|
||||
# e.g. "$300" -> "three hundred dollars"
|
||||
# currently "$300" is ignored since no characters present in the phonetic dictionary
|
||||
|
||||
# split into words
|
||||
clean_char, clean_cdx = [], []
|
||||
for cdx, char in enumerate(text):
|
||||
char_ = char.lower()
|
||||
# wav2vec2 models use "|" character to represent spaces
|
||||
if model_lang not in LANGUAGES_WITHOUT_SPACES:
|
||||
per_word = transcription.split(" ")
|
||||
char_ = char_.replace(" ", "|")
|
||||
|
||||
# ignore whitespace at beginning and end of transcript
|
||||
if cdx < num_leading:
|
||||
pass
|
||||
elif cdx > len(text) - num_trailing - 1:
|
||||
pass
|
||||
elif char_ in model_dictionary.keys():
|
||||
clean_char.append(char_)
|
||||
clean_cdx.append(cdx)
|
||||
else:
|
||||
per_word = transcription
|
||||
# add placeholder
|
||||
clean_char.append('*')
|
||||
clean_cdx.append(cdx)
|
||||
|
||||
clean_wdx = []
|
||||
for wdx, wrd in enumerate(per_word):
|
||||
if any([c in model_dictionary.keys() for c in wrd.lower()]):
|
||||
clean_wdx.append(wdx)
|
||||
else:
|
||||
# index for placeholder
|
||||
clean_wdx.append(wdx)
|
||||
|
||||
# first check that characters in transcription can be aligned (they are contained in align model"s dictionary)
|
||||
clean_char, clean_cdx = [], []
|
||||
for cdx, char in enumerate(transcription):
|
||||
char_ = char.lower()
|
||||
# wav2vec2 models use "|" character to represent spaces
|
||||
if model_lang not in LANGUAGES_WITHOUT_SPACES:
|
||||
char_ = char_.replace(" ", "|")
|
||||
|
||||
# ignore whitespace at beginning and end of transcript
|
||||
if cdx < num_leading:
|
||||
pass
|
||||
elif cdx > len(transcription) - num_trailing - 1:
|
||||
pass
|
||||
elif char_ in model_dictionary.keys():
|
||||
clean_char.append(char_)
|
||||
clean_cdx.append(cdx)
|
||||
punkt_param = PunktParameters()
|
||||
punkt_param.abbrev_types = set(PUNKT_ABBREVIATIONS)
|
||||
sentence_splitter = PunktSentenceTokenizer(punkt_param)
|
||||
sentence_spans = list(sentence_splitter.span_tokenize(text))
|
||||
|
||||
clean_wdx = []
|
||||
for wdx, wrd in enumerate(per_word):
|
||||
if any([c in model_dictionary.keys() for c in wrd]):
|
||||
clean_wdx.append(wdx)
|
||||
|
||||
# if no characters are in the dictionary, then we skip this segment...
|
||||
if len(clean_char) == 0:
|
||||
print(f'Failed to align segment ("{segment["text"]}"): no characters in this segment found in model dictionary, resorting to original...')
|
||||
break
|
||||
|
||||
transcription_cleaned = "".join(clean_char)
|
||||
tokens = [model_dictionary[c] for c in transcription_cleaned]
|
||||
|
||||
# we only pad if not using VAD filtering
|
||||
if "seg_text" not in segment:
|
||||
# pad according original timestamps
|
||||
t1 = max(segment["start"] - extend_duration, 0)
|
||||
t2 = min(segment["end"] + extend_duration, MAX_DURATION)
|
||||
|
||||
# use prev_t2 as current t1 if it"s later
|
||||
if start_from_previous and t1 < prev_t2:
|
||||
t1 = prev_t2
|
||||
|
||||
# check if timestamp range is still valid
|
||||
if t1 >= MAX_DURATION:
|
||||
print("Failed to align segment: original start time longer than audio duration, skipping...")
|
||||
break
|
||||
if t2 - t1 < 0.02:
|
||||
print("Failed to align segment: duration smaller than 0.02s time precision")
|
||||
break
|
||||
|
||||
f1 = int(t1 * SAMPLE_RATE)
|
||||
f2 = int(t2 * SAMPLE_RATE)
|
||||
|
||||
waveform_segment = audio[:, f1:f2]
|
||||
|
||||
with torch.inference_mode():
|
||||
if model_type == "torchaudio":
|
||||
emissions, _ = model(waveform_segment.to(device))
|
||||
elif model_type == "huggingface":
|
||||
emissions = model(waveform_segment.to(device)).logits
|
||||
else:
|
||||
raise NotImplementedError(f"Align model of type {model_type} not supported.")
|
||||
emissions = torch.log_softmax(emissions, dim=-1)
|
||||
|
||||
emission = emissions[0].cpu().detach()
|
||||
|
||||
blank_id = 0
|
||||
for char, code in model_dictionary.items():
|
||||
if char == '[pad]' or char == '<pad>':
|
||||
blank_id = code
|
||||
|
||||
trellis = get_trellis(emission, tokens, blank_id)
|
||||
path = backtrack(trellis, emission, tokens, blank_id)
|
||||
if path is None:
|
||||
print(f'Failed to align segment ("{segment["text"]}"): backtrack failed, resorting to original...')
|
||||
break
|
||||
char_segments = merge_repeats(path, transcription_cleaned)
|
||||
# word_segments = merge_words(char_segments)
|
||||
segment_data[sdx] = {
|
||||
"clean_char": clean_char,
|
||||
"clean_cdx": clean_cdx,
|
||||
"clean_wdx": clean_wdx,
|
||||
"sentence_spans": sentence_spans
|
||||
}
|
||||
|
||||
|
||||
# sub-segments
|
||||
if "seg-text" not in segment:
|
||||
segment["seg-text"] = [transcription]
|
||||
|
||||
seg_lens = [0] + [len(x) for x in segment["seg-text"]]
|
||||
seg_lens_cumsum = list(np.cumsum(seg_lens))
|
||||
sub_seg_idx = 0
|
||||
|
||||
wdx = 0
|
||||
duration = t2 - t1
|
||||
ratio = duration * waveform_segment.size(0) / (trellis.size(0) - 1)
|
||||
for cdx, char in enumerate(transcription + " "):
|
||||
is_last = False
|
||||
if cdx == len(transcription):
|
||||
break
|
||||
elif cdx+1 == len(transcription):
|
||||
is_last = True
|
||||
|
||||
|
||||
start, end, score = None, None, None
|
||||
if cdx in clean_cdx:
|
||||
char_seg = char_segments[clean_cdx.index(cdx)]
|
||||
start = round(char_seg.start * ratio + t1, 3)
|
||||
end = round(char_seg.end * ratio + t1, 3)
|
||||
score = char_seg.score
|
||||
|
||||
char_segments_arr["char"].append(char)
|
||||
char_segments_arr["start"].append(start)
|
||||
char_segments_arr["end"].append(end)
|
||||
char_segments_arr["score"].append(score)
|
||||
char_segments_arr["word-idx"].append(wdx)
|
||||
char_segments_arr["segment-idx"].append(sdx)
|
||||
char_segments_arr["subsegment-idx"].append(sub_seg_idx)
|
||||
|
||||
# word-level info
|
||||
if model_lang in LANGUAGES_WITHOUT_SPACES:
|
||||
# character == word
|
||||
wdx += 1
|
||||
elif is_last or transcription[cdx+1] == " " or cdx == seg_lens_cumsum[sub_seg_idx+1] - 1:
|
||||
wdx += 1
|
||||
|
||||
if is_last or cdx == seg_lens_cumsum[sub_seg_idx+1] - 1:
|
||||
wdx = 0
|
||||
sub_seg_idx += 1
|
||||
|
||||
prev_t2 = segment["end"]
|
||||
|
||||
segment_align_success = True
|
||||
# end while True loop
|
||||
break
|
||||
|
||||
# reset prev_t2 due to drifting issues
|
||||
if not segment_align_success:
|
||||
prev_t2 = 0
|
||||
aligned_segments: List[SingleAlignedSegment] = []
|
||||
|
||||
# 2. Get prediction matrix from alignment model & align
|
||||
for sdx, segment in enumerate(transcript):
|
||||
|
||||
char_segments_arr = pd.DataFrame(char_segments_arr)
|
||||
not_space = char_segments_arr["char"] != " "
|
||||
t1 = segment["start"]
|
||||
t2 = segment["end"]
|
||||
text = segment["text"]
|
||||
|
||||
per_seg_grp = char_segments_arr.groupby(["segment-idx", "subsegment-idx"], as_index = False)
|
||||
char_segments_arr = per_seg_grp.apply(lambda x: x.reset_index(drop = True)).reset_index()
|
||||
per_word_grp = char_segments_arr[not_space].groupby(["segment-idx", "subsegment-idx", "word-idx"])
|
||||
per_subseg_grp = char_segments_arr[not_space].groupby(["segment-idx", "subsegment-idx"])
|
||||
per_seg_grp = char_segments_arr[not_space].groupby(["segment-idx"])
|
||||
char_segments_arr["local-char-idx"] = char_segments_arr.groupby(["segment-idx", "subsegment-idx"]).cumcount()
|
||||
per_word_grp = char_segments_arr[not_space].groupby(["segment-idx", "subsegment-idx", "word-idx"]) # regroup
|
||||
aligned_seg: SingleAlignedSegment = {
|
||||
"start": t1,
|
||||
"end": t2,
|
||||
"text": text,
|
||||
"words": [],
|
||||
"chars": None,
|
||||
}
|
||||
|
||||
word_segments_arr = {}
|
||||
if return_char_alignments:
|
||||
aligned_seg["chars"] = []
|
||||
|
||||
# start of word is first char with a timestamp
|
||||
word_segments_arr["start"] = per_word_grp["start"].min().values
|
||||
# end of word is last char with a timestamp
|
||||
word_segments_arr["end"] = per_word_grp["end"].max().values
|
||||
# score of word is mean (excluding nan)
|
||||
word_segments_arr["score"] = per_word_grp["score"].mean().values
|
||||
# check we can align
|
||||
if len(segment_data[sdx]["clean_char"]) == 0:
|
||||
print(f'Failed to align segment ("{segment["text"]}"): no characters in this segment found in model dictionary, resorting to original...')
|
||||
aligned_segments.append(aligned_seg)
|
||||
continue
|
||||
|
||||
word_segments_arr["segment-text-start"] = per_word_grp["local-char-idx"].min().astype(int).values
|
||||
word_segments_arr["segment-text-end"] = per_word_grp["local-char-idx"].max().astype(int).values+1
|
||||
word_segments_arr = pd.DataFrame(word_segments_arr)
|
||||
if t1 >= MAX_DURATION:
|
||||
print(f'Failed to align segment ("{segment["text"]}"): original start time longer than audio duration, skipping...')
|
||||
aligned_segments.append(aligned_seg)
|
||||
continue
|
||||
|
||||
word_segments_arr[["segment-idx", "subsegment-idx", "word-idx"]] = per_word_grp["local-char-idx"].min().reset_index()[["segment-idx", "subsegment-idx", "word-idx"]].astype(int)
|
||||
segments_arr = {}
|
||||
segments_arr["start"] = per_subseg_grp["start"].min().reset_index()["start"]
|
||||
segments_arr["end"] = per_subseg_grp["end"].max().reset_index()["end"]
|
||||
segments_arr = pd.DataFrame(segments_arr)
|
||||
segments_arr[["segment-idx", "subsegment-idx-start"]] = per_subseg_grp["start"].min().reset_index()[["segment-idx", "subsegment-idx"]]
|
||||
segments_arr["subsegment-idx-end"] = segments_arr["subsegment-idx-start"] + 1
|
||||
text_clean = "".join(segment_data[sdx]["clean_char"])
|
||||
tokens = [model_dictionary.get(c, -1) for c in text_clean]
|
||||
|
||||
# interpolate missing words / sub-segments
|
||||
if interpolate_method != "ignore":
|
||||
wrd_subseg_grp = word_segments_arr.groupby(["segment-idx", "subsegment-idx"], group_keys=False)
|
||||
wrd_seg_grp = word_segments_arr.groupby(["segment-idx"], group_keys=False)
|
||||
# we still know which word timestamps are interpolated because their score == nan
|
||||
word_segments_arr["start"] = wrd_subseg_grp['start'].apply(lambda group: interpolate_nans(group, method=interpolate_method))
|
||||
word_segments_arr["end"] = wrd_subseg_grp['end'].apply(lambda group: interpolate_nans(group, method=interpolate_method))
|
||||
f1 = int(t1 * SAMPLE_RATE)
|
||||
f2 = int(t2 * SAMPLE_RATE)
|
||||
|
||||
word_segments_arr["start"] = wrd_seg_grp['start'].apply(lambda group: interpolate_nans(group, method=interpolate_method))
|
||||
word_segments_arr["end"] = wrd_seg_grp['end'].apply(lambda group: interpolate_nans(group, method=interpolate_method))
|
||||
|
||||
sub_seg_grp = segments_arr.groupby(["segment-idx"], group_keys=False)
|
||||
segments_arr['start'] = sub_seg_grp['start'].apply(lambda group: interpolate_nans(group, method=interpolate_method))
|
||||
segments_arr['end'] = sub_seg_grp['end'].apply(lambda group: interpolate_nans(group, method=interpolate_method))
|
||||
|
||||
# merge words & subsegments which are missing times
|
||||
word_grp = word_segments_arr.groupby(["segment-idx", "subsegment-idx", "end"])
|
||||
|
||||
word_segments_arr["segment-text-start"] = word_grp["segment-text-start"].transform(min)
|
||||
word_segments_arr["segment-text-end"] = word_grp["segment-text-end"].transform(max)
|
||||
word_segments_arr.drop_duplicates(subset=["segment-idx", "subsegment-idx", "end"], inplace=True)
|
||||
|
||||
seg_grp_dup = segments_arr.groupby(["segment-idx", "start", "end"])
|
||||
segments_arr["subsegment-idx-start"] = seg_grp_dup["subsegment-idx-start"].transform(min)
|
||||
segments_arr["subsegment-idx-end"] = seg_grp_dup["subsegment-idx-end"].transform(max)
|
||||
segments_arr.drop_duplicates(subset=["segment-idx", "subsegment-idx-start", "subsegment-idx-end"], inplace=True)
|
||||
else:
|
||||
word_segments_arr.dropna(inplace=True)
|
||||
segments_arr.dropna(inplace=True)
|
||||
|
||||
# if some segments still have missing timestamps (usually because all numerals / symbols), then use original timestamps...
|
||||
segments_arr['start'].fillna(pd.Series([x['start'] for x in transcript]), inplace=True)
|
||||
segments_arr['end'].fillna(pd.Series([x['end'] for x in transcript]), inplace=True)
|
||||
segments_arr['subsegment-idx-start'].fillna(0, inplace=True)
|
||||
segments_arr['subsegment-idx-end'].fillna(1, inplace=True)
|
||||
|
||||
|
||||
aligned_segments = []
|
||||
aligned_segments_word = []
|
||||
|
||||
word_segments_arr.set_index(["segment-idx", "subsegment-idx"], inplace=True)
|
||||
char_segments_arr.set_index(["segment-idx", "subsegment-idx", "word-idx"], inplace=True)
|
||||
|
||||
for sdx, srow in segments_arr.iterrows():
|
||||
|
||||
seg_idx = int(srow["segment-idx"])
|
||||
sub_start = int(srow["subsegment-idx-start"])
|
||||
sub_end = int(srow["subsegment-idx-end"])
|
||||
|
||||
seg = transcript[seg_idx]
|
||||
text = "".join(seg["seg-text"][sub_start:sub_end])
|
||||
|
||||
wseg = word_segments_arr.loc[seg_idx].loc[sub_start:sub_end-1]
|
||||
wseg["start"].fillna(srow["start"], inplace=True)
|
||||
wseg["end"].fillna(srow["end"], inplace=True)
|
||||
wseg["segment-text-start"].fillna(0, inplace=True)
|
||||
wseg["segment-text-end"].fillna(len(text)-1, inplace=True)
|
||||
|
||||
cseg = char_segments_arr.loc[seg_idx].loc[sub_start:sub_end-1]
|
||||
# fixes bug for single segment in transcript
|
||||
cseg['segment-text-start'] = cseg['level_1'] if 'level_1' in cseg else 0
|
||||
cseg['segment-text-end'] = cseg['level_1'] + 1 if 'level_1' in cseg else 1
|
||||
if 'level_1' in cseg: del cseg['level_1']
|
||||
if 'level_0' in cseg: del cseg['level_0']
|
||||
cseg.reset_index(inplace=True)
|
||||
|
||||
def get_raw_text(word_row):
|
||||
return seg["seg-text"][word_row.name][int(word_row["segment-text-start"]):int(word_row["segment-text-end"])+1]
|
||||
|
||||
word_list = []
|
||||
wdx = 0
|
||||
curr_text = get_raw_text(wseg.iloc[wdx])
|
||||
if not curr_text.startswith(" "):
|
||||
curr_text = " " + curr_text
|
||||
# TODO: Probably can get some speedup gain with batched inference here
|
||||
waveform_segment = audio[:, f1:f2]
|
||||
# Handle the minimum input length for wav2vec2 models
|
||||
if waveform_segment.shape[-1] < 400:
|
||||
lengths = torch.as_tensor([waveform_segment.shape[-1]]).to(device)
|
||||
waveform_segment = torch.nn.functional.pad(
|
||||
waveform_segment, (0, 400 - waveform_segment.shape[-1])
|
||||
)
|
||||
else:
|
||||
lengths = None
|
||||
|
||||
if len(wseg) > 1:
|
||||
for _, wrow in wseg.iloc[1:].iterrows():
|
||||
if wrow['start'] != wseg.iloc[wdx]['start']:
|
||||
word_start = wseg.iloc[wdx]['start']
|
||||
word_end = wseg.iloc[wdx]['end']
|
||||
with torch.inference_mode():
|
||||
if model_type == "torchaudio":
|
||||
emissions, _ = model(waveform_segment.to(device), lengths=lengths)
|
||||
elif model_type == "huggingface":
|
||||
emissions = model(waveform_segment.to(device)).logits
|
||||
else:
|
||||
raise NotImplementedError(f"Align model of type {model_type} not supported.")
|
||||
emissions = torch.log_softmax(emissions, dim=-1)
|
||||
|
||||
aligned_segments_word.append(
|
||||
{
|
||||
"text": curr_text.strip(),
|
||||
"start": word_start,
|
||||
"end": word_end
|
||||
}
|
||||
)
|
||||
emission = emissions[0].cpu().detach()
|
||||
|
||||
word_list.append(
|
||||
{
|
||||
"word": curr_text.rstrip(),
|
||||
"start": word_start,
|
||||
"end": word_end,
|
||||
}
|
||||
)
|
||||
blank_id = 0
|
||||
for char, code in model_dictionary.items():
|
||||
if char == '[pad]' or char == '<pad>':
|
||||
blank_id = code
|
||||
|
||||
curr_text = " "
|
||||
curr_text += get_raw_text(wrow) + " "
|
||||
wdx += 1
|
||||
trellis = get_trellis(emission, tokens, blank_id)
|
||||
# path = backtrack(trellis, emission, tokens, blank_id)
|
||||
path = backtrack_beam(trellis, emission, tokens, blank_id, beam_width=2)
|
||||
|
||||
aligned_segments_word.append(
|
||||
{
|
||||
"text": curr_text.strip(),
|
||||
"start": wseg.iloc[wdx]["start"],
|
||||
"end": wseg.iloc[wdx]["end"]
|
||||
}
|
||||
)
|
||||
if path is None:
|
||||
print(f'Failed to align segment ("{segment["text"]}"): backtrack failed, resorting to original...')
|
||||
aligned_segments.append(aligned_seg)
|
||||
continue
|
||||
|
||||
word_list.append(
|
||||
{
|
||||
"word": curr_text.rstrip(),
|
||||
"start": wseg.iloc[wdx]['start'],
|
||||
"end": wseg.iloc[wdx]['end'],
|
||||
}
|
||||
)
|
||||
char_segments = merge_repeats(path, text_clean)
|
||||
|
||||
aligned_segments.append(
|
||||
{
|
||||
"start": srow["start"],
|
||||
"end": srow["end"],
|
||||
"text": text,
|
||||
"words": word_list,
|
||||
"word-segments": wseg,
|
||||
"char-segments": cseg
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
return {"segments": aligned_segments, "word_segments": aligned_segments_word}
|
||||
duration = t2 - t1
|
||||
ratio = duration * waveform_segment.size(0) / (trellis.size(0) - 1)
|
||||
|
||||
# assign timestamps to aligned characters
|
||||
char_segments_arr = []
|
||||
word_idx = 0
|
||||
for cdx, char in enumerate(text):
|
||||
start, end, score = None, None, None
|
||||
if cdx in segment_data[sdx]["clean_cdx"]:
|
||||
char_seg = char_segments[segment_data[sdx]["clean_cdx"].index(cdx)]
|
||||
start = round(char_seg.start * ratio + t1, 3)
|
||||
end = round(char_seg.end * ratio + t1, 3)
|
||||
score = round(char_seg.score, 3)
|
||||
|
||||
char_segments_arr.append(
|
||||
{
|
||||
"char": char,
|
||||
"start": start,
|
||||
"end": end,
|
||||
"score": score,
|
||||
"word-idx": word_idx,
|
||||
}
|
||||
)
|
||||
|
||||
# increment word_idx, nltk word tokenization would probably be more robust here, but us space for now...
|
||||
if model_lang in LANGUAGES_WITHOUT_SPACES:
|
||||
word_idx += 1
|
||||
elif cdx == len(text) - 1 or text[cdx+1] == " ":
|
||||
word_idx += 1
|
||||
|
||||
char_segments_arr = pd.DataFrame(char_segments_arr)
|
||||
|
||||
aligned_subsegments = []
|
||||
# assign sentence_idx to each character index
|
||||
char_segments_arr["sentence-idx"] = None
|
||||
for sdx2, (sstart, send) in enumerate(segment_data[sdx]["sentence_spans"]):
|
||||
curr_chars = char_segments_arr.loc[(char_segments_arr.index >= sstart) & (char_segments_arr.index <= send)]
|
||||
char_segments_arr.loc[(char_segments_arr.index >= sstart) & (char_segments_arr.index <= send), "sentence-idx"] = sdx2
|
||||
|
||||
sentence_text = text[sstart:send]
|
||||
sentence_start = curr_chars["start"].min()
|
||||
end_chars = curr_chars[curr_chars["char"] != ' ']
|
||||
sentence_end = end_chars["end"].max()
|
||||
sentence_words = []
|
||||
|
||||
for word_idx in curr_chars["word-idx"].unique():
|
||||
word_chars = curr_chars.loc[curr_chars["word-idx"] == word_idx]
|
||||
word_text = "".join(word_chars["char"].tolist()).strip()
|
||||
if len(word_text) == 0:
|
||||
continue
|
||||
|
||||
# dont use space character for alignment
|
||||
word_chars = word_chars[word_chars["char"] != " "]
|
||||
|
||||
word_start = word_chars["start"].min()
|
||||
word_end = word_chars["end"].max()
|
||||
word_score = round(word_chars["score"].mean(), 3)
|
||||
|
||||
# -1 indicates unalignable
|
||||
word_segment = {"word": word_text}
|
||||
|
||||
if not np.isnan(word_start):
|
||||
word_segment["start"] = word_start
|
||||
if not np.isnan(word_end):
|
||||
word_segment["end"] = word_end
|
||||
if not np.isnan(word_score):
|
||||
word_segment["score"] = word_score
|
||||
|
||||
sentence_words.append(word_segment)
|
||||
|
||||
aligned_subsegments.append({
|
||||
"text": sentence_text,
|
||||
"start": sentence_start,
|
||||
"end": sentence_end,
|
||||
"words": sentence_words,
|
||||
})
|
||||
|
||||
if return_char_alignments:
|
||||
curr_chars = curr_chars[["char", "start", "end", "score"]]
|
||||
curr_chars.fillna(-1, inplace=True)
|
||||
curr_chars = curr_chars.to_dict("records")
|
||||
curr_chars = [{key: val for key, val in char.items() if val != -1} for char in curr_chars]
|
||||
aligned_subsegments[-1]["chars"] = curr_chars
|
||||
|
||||
aligned_subsegments = pd.DataFrame(aligned_subsegments)
|
||||
aligned_subsegments["start"] = interpolate_nans(aligned_subsegments["start"], method=interpolate_method)
|
||||
aligned_subsegments["end"] = interpolate_nans(aligned_subsegments["end"], method=interpolate_method)
|
||||
# concatenate sentences with same timestamps
|
||||
agg_dict = {"text": " ".join, "words": "sum"}
|
||||
if model_lang in LANGUAGES_WITHOUT_SPACES:
|
||||
agg_dict["text"] = "".join
|
||||
if return_char_alignments:
|
||||
agg_dict["chars"] = "sum"
|
||||
aligned_subsegments= aligned_subsegments.groupby(["start", "end"], as_index=False).agg(agg_dict)
|
||||
aligned_subsegments = aligned_subsegments.to_dict('records')
|
||||
aligned_segments += aligned_subsegments
|
||||
|
||||
# create word_segments list
|
||||
word_segments: List[SingleWordSegment] = []
|
||||
for segment in aligned_segments:
|
||||
word_segments += segment["words"]
|
||||
|
||||
return {"segments": aligned_segments, "word_segments": word_segments}
|
||||
|
||||
"""
|
||||
source: https://pytorch.org/tutorials/intermediate/forced_alignment_with_torchaudio_tutorial.html
|
||||
"""
|
||||
|
||||
|
||||
def get_trellis(emission, tokens, blank_id=0):
|
||||
num_frame = emission.size(0)
|
||||
num_tokens = len(tokens)
|
||||
|
||||
# Trellis has extra diemsions for both time axis and tokens.
|
||||
# The extra dim for tokens represents <SoS> (start-of-sentence)
|
||||
# The extra dim for time axis is for simplification of the code.
|
||||
trellis = torch.empty((num_frame + 1, num_tokens + 1))
|
||||
trellis[0, 0] = 0
|
||||
trellis[1:, 0] = torch.cumsum(emission[:, 0], 0)
|
||||
trellis[0, -num_tokens:] = -float("inf")
|
||||
trellis[-num_tokens:, 0] = float("inf")
|
||||
trellis = torch.zeros((num_frame, num_tokens))
|
||||
trellis[1:, 0] = torch.cumsum(emission[1:, blank_id], 0)
|
||||
trellis[0, 1:] = -float("inf")
|
||||
trellis[-num_tokens + 1:, 0] = float("inf")
|
||||
|
||||
for t in range(num_frame):
|
||||
for t in range(num_frame - 1):
|
||||
trellis[t + 1, 1:] = torch.maximum(
|
||||
# Score for staying at the same token
|
||||
trellis[t, 1:] + emission[t, blank_id],
|
||||
# Score for changing to the next token
|
||||
trellis[t, :-1] + emission[t, tokens],
|
||||
# trellis[t, :-1] + emission[t, tokens[1:]],
|
||||
trellis[t, :-1] + get_wildcard_emission(emission[t], tokens[1:], blank_id),
|
||||
)
|
||||
return trellis
|
||||
|
||||
|
||||
def get_wildcard_emission(frame_emission, tokens, blank_id):
|
||||
"""Processing token emission scores containing wildcards (vectorized version)
|
||||
|
||||
Args:
|
||||
frame_emission: Emission probability vector for the current frame
|
||||
tokens: List of token indices
|
||||
blank_id: ID of the blank token
|
||||
|
||||
Returns:
|
||||
tensor: Maximum probability score for each token position
|
||||
"""
|
||||
assert 0 <= blank_id < len(frame_emission)
|
||||
|
||||
# Convert tokens to a tensor if they are not already
|
||||
tokens = torch.tensor(tokens) if not isinstance(tokens, torch.Tensor) else tokens
|
||||
|
||||
# Create a mask to identify wildcard positions
|
||||
wildcard_mask = (tokens == -1)
|
||||
|
||||
# Get scores for non-wildcard positions
|
||||
regular_scores = frame_emission[tokens.clamp(min=0).long()] # clamp to avoid -1 index
|
||||
|
||||
# Create a mask and compute the maximum value without modifying frame_emission
|
||||
max_valid_score = frame_emission.clone() # Create a copy
|
||||
max_valid_score[blank_id] = float('-inf') # Modify the copy to exclude the blank token
|
||||
max_valid_score = max_valid_score.max()
|
||||
|
||||
# Use where operation to combine results
|
||||
result = torch.where(wildcard_mask, max_valid_score, regular_scores)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@dataclass
|
||||
class Point:
|
||||
token_index: int
|
||||
time_index: int
|
||||
score: float
|
||||
|
||||
|
||||
def backtrack(trellis, emission, tokens, blank_id=0):
|
||||
# Note:
|
||||
# j and t are indices for trellis, which has extra dimensions
|
||||
# for time and tokens at the beginning.
|
||||
# When referring to time frame index `T` in trellis,
|
||||
# the corresponding index in emission is `T-1`.
|
||||
# Similarly, when referring to token index `J` in trellis,
|
||||
# the corresponding index in transcript is `J-1`.
|
||||
j = trellis.size(1) - 1
|
||||
t_start = torch.argmax(trellis[:, j]).item()
|
||||
t, j = trellis.size(0) - 1, trellis.size(1) - 1
|
||||
|
||||
path = [Point(j, t, emission[t, blank_id].exp().item())]
|
||||
while j > 0:
|
||||
# Should not happen but just in case
|
||||
assert t > 0
|
||||
|
||||
path = []
|
||||
for t in range(t_start, 0, -1):
|
||||
# 1. Figure out if the current position was stay or change
|
||||
# Note (again):
|
||||
# `emission[J-1]` is the emission at time frame `J` of trellis dimension.
|
||||
# Score for token staying the same from time frame J-1 to T.
|
||||
stayed = trellis[t - 1, j] + emission[t - 1, blank_id]
|
||||
# Score for token changing from C-1 at T-1 to J at T.
|
||||
changed = trellis[t - 1, j - 1] + emission[t - 1, tokens[j - 1]]
|
||||
# Frame-wise score of stay vs change
|
||||
p_stay = emission[t - 1, blank_id]
|
||||
# p_change = emission[t - 1, tokens[j]]
|
||||
p_change = get_wildcard_emission(emission[t - 1], [tokens[j]], blank_id)[0]
|
||||
|
||||
# 2. Store the path with frame-wise probability.
|
||||
prob = emission[t - 1, tokens[j - 1] if changed > stayed else 0].exp().item()
|
||||
# Return token index and time index in non-trellis coordinate.
|
||||
path.append(Point(j - 1, t - 1, prob))
|
||||
# Context-aware score for stay vs change
|
||||
stayed = trellis[t - 1, j] + p_stay
|
||||
changed = trellis[t - 1, j - 1] + p_change
|
||||
|
||||
# 3. Update the token
|
||||
# Update position
|
||||
t -= 1
|
||||
if changed > stayed:
|
||||
j -= 1
|
||||
if j == 0:
|
||||
break
|
||||
else:
|
||||
# failed
|
||||
return None
|
||||
|
||||
# Store the path with frame-wise probability.
|
||||
prob = (p_change if changed > stayed else p_stay).exp().item()
|
||||
path.append(Point(j, t, prob))
|
||||
|
||||
# Now j == 0, which means, it reached the SoS.
|
||||
# Fill up the rest for the sake of visualization
|
||||
while t > 0:
|
||||
prob = emission[t - 1, blank_id].exp().item()
|
||||
path.append(Point(j, t - 1, prob))
|
||||
t -= 1
|
||||
|
||||
return path[::-1]
|
||||
|
||||
|
||||
|
||||
@dataclass
|
||||
class Path:
|
||||
points: List[Point]
|
||||
score: float
|
||||
|
||||
|
||||
@dataclass
|
||||
class BeamState:
|
||||
"""State in beam search."""
|
||||
token_index: int # Current token position
|
||||
time_index: int # Current time step
|
||||
score: float # Cumulative score
|
||||
path: List[Point] # Path history
|
||||
|
||||
|
||||
def backtrack_beam(trellis, emission, tokens, blank_id=0, beam_width=5):
|
||||
"""Standard CTC beam search backtracking implementation.
|
||||
|
||||
Args:
|
||||
trellis (torch.Tensor): The trellis (or lattice) of shape (T, N), where T is the number of time steps
|
||||
and N is the number of tokens (including the blank token).
|
||||
emission (torch.Tensor): The emission probabilities of shape (T, N).
|
||||
tokens (List[int]): List of token indices (excluding the blank token).
|
||||
blank_id (int, optional): The ID of the blank token. Defaults to 0.
|
||||
beam_width (int, optional): The number of top paths to keep during beam search. Defaults to 5.
|
||||
|
||||
Returns:
|
||||
List[Point]: the best path
|
||||
"""
|
||||
T, J = trellis.size(0) - 1, trellis.size(1) - 1
|
||||
|
||||
init_state = BeamState(
|
||||
token_index=J,
|
||||
time_index=T,
|
||||
score=trellis[T, J],
|
||||
path=[Point(J, T, emission[T, blank_id].exp().item())]
|
||||
)
|
||||
|
||||
beams = [init_state]
|
||||
|
||||
while beams and beams[0].token_index > 0:
|
||||
next_beams = []
|
||||
|
||||
for beam in beams:
|
||||
t, j = beam.time_index, beam.token_index
|
||||
|
||||
if t <= 0:
|
||||
continue
|
||||
|
||||
p_stay = emission[t - 1, blank_id]
|
||||
p_change = get_wildcard_emission(emission[t - 1], [tokens[j]], blank_id)[0]
|
||||
|
||||
stay_score = trellis[t - 1, j]
|
||||
change_score = trellis[t - 1, j - 1] if j > 0 else float('-inf')
|
||||
|
||||
# Stay
|
||||
if not math.isinf(stay_score):
|
||||
new_path = beam.path.copy()
|
||||
new_path.append(Point(j, t - 1, p_stay.exp().item()))
|
||||
next_beams.append(BeamState(
|
||||
token_index=j,
|
||||
time_index=t - 1,
|
||||
score=stay_score,
|
||||
path=new_path
|
||||
))
|
||||
|
||||
# Change
|
||||
if j > 0 and not math.isinf(change_score):
|
||||
new_path = beam.path.copy()
|
||||
new_path.append(Point(j - 1, t - 1, p_change.exp().item()))
|
||||
next_beams.append(BeamState(
|
||||
token_index=j - 1,
|
||||
time_index=t - 1,
|
||||
score=change_score,
|
||||
path=new_path
|
||||
))
|
||||
|
||||
# sort by score
|
||||
beams = sorted(next_beams, key=lambda x: x.score, reverse=True)[:beam_width]
|
||||
|
||||
if not beams:
|
||||
break
|
||||
|
||||
if not beams:
|
||||
return None
|
||||
|
||||
best_beam = beams[0]
|
||||
t = best_beam.time_index
|
||||
j = best_beam.token_index
|
||||
while t > 0:
|
||||
prob = emission[t - 1, blank_id].exp().item()
|
||||
best_beam.path.append(Point(j, t - 1, prob))
|
||||
t -= 1
|
||||
|
||||
return best_beam.path[::-1]
|
||||
|
||||
|
||||
# Merge the labels
|
||||
@dataclass
|
||||
class Segment:
|
||||
|
503
whisperx/asr.py
503
whisperx/asr.py
@ -1,87 +1,43 @@
|
||||
import os
|
||||
import warnings
|
||||
from typing import List, Union
|
||||
from typing import List, Optional, Union
|
||||
from dataclasses import replace
|
||||
|
||||
import ctranslate2
|
||||
import faster_whisper
|
||||
import numpy as np
|
||||
import torch
|
||||
from faster_whisper.tokenizer import Tokenizer
|
||||
from faster_whisper.transcribe import TranscriptionOptions, get_ctranslate2_storage
|
||||
from transformers import Pipeline
|
||||
from transformers.pipelines.pt_utils import PipelineIterator
|
||||
|
||||
from .audio import N_SAMPLES, SAMPLE_RATE, load_audio, log_mel_spectrogram
|
||||
from .vad import load_vad_model, merge_chunks
|
||||
|
||||
|
||||
def load_model(whisper_arch, device, compute_type="float16", asr_options=None, language=None,
|
||||
vad_options=None, model=None):
|
||||
'''Load a Whisper model for inference.
|
||||
Args:
|
||||
whisper_arch: str - The name of the Whisper model to load.
|
||||
device: str - The device to load the model on.
|
||||
compute_type: str - The compute type to use for the model.
|
||||
options: dict - A dictionary of options to use for the model.
|
||||
language: str - The language of the model. (use English for now)
|
||||
Returns:
|
||||
A Whisper pipeline.
|
||||
'''
|
||||
|
||||
if whisper_arch.endswith(".en"):
|
||||
language = "en"
|
||||
|
||||
model = WhisperModel(whisper_arch, device=device, compute_type=compute_type)
|
||||
if language is not None:
|
||||
tokenizer = faster_whisper.tokenizer.Tokenizer(model.hf_tokenizer, model.model.is_multilingual, task="transcribe", language=language)
|
||||
else:
|
||||
print("No language specified, language will be first be detected for each audio file (increases inference time).")
|
||||
tokenizer = None
|
||||
|
||||
default_asr_options = {
|
||||
"beam_size": 5,
|
||||
"best_of": 5,
|
||||
"patience": 1,
|
||||
"length_penalty": 1,
|
||||
"temperatures": [0.0, 0.2, 0.4, 0.6, 0.8, 1.0],
|
||||
"compression_ratio_threshold": 2.4,
|
||||
"log_prob_threshold": -1.0,
|
||||
"no_speech_threshold": 0.6,
|
||||
"condition_on_previous_text": False,
|
||||
"initial_prompt": None,
|
||||
"prefix": None,
|
||||
"suppress_blank": True,
|
||||
"suppress_tokens": [-1],
|
||||
"without_timestamps": True,
|
||||
"max_initial_timestamp": 0.0,
|
||||
"word_timestamps": False,
|
||||
"prepend_punctuations": "\"'“¿([{-",
|
||||
"append_punctuations": "\"'.。,,!!??::”)]}、"
|
||||
}
|
||||
|
||||
if asr_options is not None:
|
||||
default_asr_options.update(asr_options)
|
||||
default_asr_options = faster_whisper.transcribe.TranscriptionOptions(**default_asr_options)
|
||||
|
||||
default_vad_options = {
|
||||
"vad_onset": 0.500,
|
||||
"vad_offset": 0.363
|
||||
}
|
||||
|
||||
if vad_options is not None:
|
||||
default_vad_options.update(vad_options)
|
||||
|
||||
vad_model = load_vad_model(torch.device(device), use_auth_token=None, **default_vad_options)
|
||||
|
||||
return FasterWhisperPipeline(model, vad_model, default_asr_options, tokenizer)
|
||||
from whisperx.audio import N_SAMPLES, SAMPLE_RATE, load_audio, log_mel_spectrogram
|
||||
from whisperx.types import SingleSegment, TranscriptionResult
|
||||
from whisperx.vads import Vad, Silero, Pyannote
|
||||
|
||||
|
||||
def find_numeral_symbol_tokens(tokenizer):
|
||||
numeral_symbol_tokens = []
|
||||
for i in range(tokenizer.eot):
|
||||
token = tokenizer.decode([i]).removeprefix(" ")
|
||||
has_numeral_symbol = any(c in "0123456789%$£" for c in token)
|
||||
if has_numeral_symbol:
|
||||
numeral_symbol_tokens.append(i)
|
||||
return numeral_symbol_tokens
|
||||
|
||||
class WhisperModel(faster_whisper.WhisperModel):
|
||||
'''
|
||||
FasterWhisperModel provides batched inference for faster-whisper.
|
||||
Currently only works in non-timestamp mode.
|
||||
Currently only works in non-timestamp mode and fixed prompt for all samples in batch.
|
||||
'''
|
||||
|
||||
def generate_segment_batched(self, features: np.ndarray, tokenizer: faster_whisper.tokenizer.Tokenizer, options: faster_whisper.transcribe.TranscriptionOptions, encoder_output = None):
|
||||
def generate_segment_batched(
|
||||
self,
|
||||
features: np.ndarray,
|
||||
tokenizer: Tokenizer,
|
||||
options: TranscriptionOptions,
|
||||
encoder_output=None,
|
||||
):
|
||||
batch_size = features.shape[0]
|
||||
all_tokens = []
|
||||
prompt_reset_since = 0
|
||||
@ -95,6 +51,7 @@ class WhisperModel(faster_whisper.WhisperModel):
|
||||
previous_tokens,
|
||||
without_timestamps=options.without_timestamps,
|
||||
prefix=options.prefix,
|
||||
hotwords=options.hotwords
|
||||
)
|
||||
|
||||
encoder_output = self.encode(features)
|
||||
@ -106,15 +63,14 @@ class WhisperModel(faster_whisper.WhisperModel):
|
||||
result = self.model.generate(
|
||||
encoder_output,
|
||||
[prompt] * batch_size,
|
||||
# length_penalty=options.length_penalty,
|
||||
# max_length=self.max_length,
|
||||
# return_scores=True,
|
||||
# return_no_speech_prob=True,
|
||||
# suppress_blank=options.suppress_blank,
|
||||
# suppress_tokens=options.suppress_tokens,
|
||||
# max_initial_timestamp_index=max_initial_timestamp_index,
|
||||
beam_size=options.beam_size,
|
||||
patience=options.patience,
|
||||
length_penalty=options.length_penalty,
|
||||
max_length=self.max_length,
|
||||
suppress_blank=options.suppress_blank,
|
||||
suppress_tokens=options.suppress_tokens,
|
||||
)
|
||||
|
||||
|
||||
tokens_batch = [x.sequences_ids[0] for x in result]
|
||||
|
||||
def decode_batch(tokens: List[List[int]]) -> str:
|
||||
@ -127,7 +83,7 @@ class WhisperModel(faster_whisper.WhisperModel):
|
||||
text = decode_batch(tokens_batch)
|
||||
|
||||
return text
|
||||
|
||||
|
||||
def encode(self, features: np.ndarray) -> ctranslate2.StorageView:
|
||||
# When the model is running on multiple GPUs, the encoder output should be moved
|
||||
# to the CPU since we don't know which GPU will handle the next job.
|
||||
@ -135,24 +91,36 @@ class WhisperModel(faster_whisper.WhisperModel):
|
||||
# unsqueeze if batch size = 1
|
||||
if len(features.shape) == 2:
|
||||
features = np.expand_dims(features, 0)
|
||||
features = faster_whisper.transcribe.get_ctranslate2_storage(features)
|
||||
|
||||
features = get_ctranslate2_storage(features)
|
||||
|
||||
return self.model.encode(features, to_cpu=to_cpu)
|
||||
|
||||
|
||||
class FasterWhisperPipeline(Pipeline):
|
||||
"""
|
||||
Huggingface Pipeline wrapper for FasterWhisperModel.
|
||||
"""
|
||||
# TODO:
|
||||
# - add support for timestamp mode
|
||||
# - add support for custom inference kwargs
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model,
|
||||
vad,
|
||||
options,
|
||||
tokenizer=None,
|
||||
device: Union[int, str, "torch.device"] = -1,
|
||||
framework = "pt",
|
||||
**kwargs
|
||||
self,
|
||||
model: WhisperModel,
|
||||
vad,
|
||||
vad_params: dict,
|
||||
options: TranscriptionOptions,
|
||||
tokenizer: Optional[Tokenizer] = None,
|
||||
device: Union[int, str, "torch.device"] = -1,
|
||||
framework="pt",
|
||||
language: Optional[str] = None,
|
||||
suppress_numerals: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
self.model = model
|
||||
self.tokenizer = tokenizer
|
||||
self.options = options
|
||||
self.preset_language = language
|
||||
self.suppress_numerals = suppress_numerals
|
||||
self._batch_size = kwargs.pop("batch_size", None)
|
||||
self._num_workers = 1
|
||||
self._preprocess_params, self._forward_params, self._postprocess_params = self._sanitize_parameters(**kwargs)
|
||||
@ -169,9 +137,10 @@ class FasterWhisperPipeline(Pipeline):
|
||||
self.device = torch.device(f"cuda:{device}")
|
||||
else:
|
||||
self.device = device
|
||||
|
||||
|
||||
super(Pipeline, self).__init__()
|
||||
self.vad_model = vad
|
||||
self._vad_params = vad_params
|
||||
|
||||
def _sanitize_parameters(self, **kwargs):
|
||||
preprocess_kwargs = {}
|
||||
@ -181,18 +150,29 @@ class FasterWhisperPipeline(Pipeline):
|
||||
|
||||
def preprocess(self, audio):
|
||||
audio = audio['inputs']
|
||||
features = log_mel_spectrogram(audio, padding=N_SAMPLES - audio.shape[0])
|
||||
model_n_mels = self.model.feat_kwargs.get("feature_size")
|
||||
features = log_mel_spectrogram(
|
||||
audio,
|
||||
n_mels=model_n_mels if model_n_mels is not None else 80,
|
||||
padding=N_SAMPLES - audio.shape[0],
|
||||
)
|
||||
return {'inputs': features}
|
||||
|
||||
def _forward(self, model_inputs):
|
||||
outputs = self.model.generate_segment_batched(model_inputs['inputs'], self.tokenizer, self.options)
|
||||
return {'text': outputs}
|
||||
|
||||
|
||||
def postprocess(self, model_outputs):
|
||||
return model_outputs
|
||||
|
||||
def get_iterator(
|
||||
self, inputs, num_workers: int, batch_size: int, preprocess_params, forward_params, postprocess_params
|
||||
self,
|
||||
inputs,
|
||||
num_workers: int,
|
||||
batch_size: int,
|
||||
preprocess_params: dict,
|
||||
forward_params: dict,
|
||||
postprocess_params: dict,
|
||||
):
|
||||
dataset = PipelineIterator(inputs, self.preprocess, preprocess_params)
|
||||
if "TOKENIZERS_PARALLELISM" not in os.environ:
|
||||
@ -207,11 +187,20 @@ class FasterWhisperPipeline(Pipeline):
|
||||
return final_iterator
|
||||
|
||||
def transcribe(
|
||||
self, audio: Union[str, np.ndarray], batch_size=None, num_workers=0
|
||||
):
|
||||
self,
|
||||
audio: Union[str, np.ndarray],
|
||||
batch_size: Optional[int] = None,
|
||||
num_workers=0,
|
||||
language: Optional[str] = None,
|
||||
task: Optional[str] = None,
|
||||
chunk_size=30,
|
||||
print_progress=False,
|
||||
combined_progress=False,
|
||||
verbose=False,
|
||||
) -> TranscriptionResult:
|
||||
if isinstance(audio, str):
|
||||
audio = load_audio(audio)
|
||||
|
||||
|
||||
def data(audio, segments):
|
||||
for seg in segments:
|
||||
f1 = int(seg['start'] * SAMPLE_RATE)
|
||||
@ -219,41 +208,87 @@ class FasterWhisperPipeline(Pipeline):
|
||||
# print(f2-f1)
|
||||
yield {'inputs': audio[f1:f2]}
|
||||
|
||||
vad_segments = self.vad_model({"waveform": torch.from_numpy(audio).unsqueeze(0), "sample_rate": SAMPLE_RATE})
|
||||
vad_segments = merge_chunks(vad_segments, 30)
|
||||
|
||||
del_tokenizer = False
|
||||
if self.tokenizer is None:
|
||||
language = self.detect_language(audio)
|
||||
self.tokenizer = faster_whisper.tokenizer.Tokenizer(self.model.hf_tokenizer, self.model.model.is_multilingual, task="transcribe", language=language)
|
||||
del_tokenizer = True
|
||||
# Pre-process audio and merge chunks as defined by the respective VAD child class
|
||||
# In case vad_model is manually assigned (see 'load_model') follow the functionality of pyannote toolkit
|
||||
if issubclass(type(self.vad_model), Vad):
|
||||
waveform = self.vad_model.preprocess_audio(audio)
|
||||
merge_chunks = self.vad_model.merge_chunks
|
||||
else:
|
||||
language = self.tokenizer.language_code
|
||||
waveform = Pyannote.preprocess_audio(audio)
|
||||
merge_chunks = Pyannote.merge_chunks
|
||||
|
||||
segments = []
|
||||
vad_segments = self.vad_model({"waveform": waveform, "sample_rate": SAMPLE_RATE})
|
||||
vad_segments = merge_chunks(
|
||||
vad_segments,
|
||||
chunk_size,
|
||||
onset=self._vad_params["vad_onset"],
|
||||
offset=self._vad_params["vad_offset"],
|
||||
)
|
||||
if self.tokenizer is None:
|
||||
language = language or self.detect_language(audio)
|
||||
task = task or "transcribe"
|
||||
self.tokenizer = Tokenizer(
|
||||
self.model.hf_tokenizer,
|
||||
self.model.model.is_multilingual,
|
||||
task=task,
|
||||
language=language,
|
||||
)
|
||||
else:
|
||||
language = language or self.tokenizer.language_code
|
||||
task = task or self.tokenizer.task
|
||||
if task != self.tokenizer.task or language != self.tokenizer.language_code:
|
||||
self.tokenizer = Tokenizer(
|
||||
self.model.hf_tokenizer,
|
||||
self.model.model.is_multilingual,
|
||||
task=task,
|
||||
language=language,
|
||||
)
|
||||
|
||||
if self.suppress_numerals:
|
||||
previous_suppress_tokens = self.options.suppress_tokens
|
||||
numeral_symbol_tokens = find_numeral_symbol_tokens(self.tokenizer)
|
||||
print(f"Suppressing numeral and symbol tokens")
|
||||
new_suppressed_tokens = numeral_symbol_tokens + self.options.suppress_tokens
|
||||
new_suppressed_tokens = list(set(new_suppressed_tokens))
|
||||
self.options = replace(self.options, suppress_tokens=new_suppressed_tokens)
|
||||
|
||||
segments: List[SingleSegment] = []
|
||||
batch_size = batch_size or self._batch_size
|
||||
total_segments = len(vad_segments)
|
||||
for idx, out in enumerate(self.__call__(data(audio, vad_segments), batch_size=batch_size, num_workers=num_workers)):
|
||||
if print_progress:
|
||||
base_progress = ((idx + 1) / total_segments) * 100
|
||||
percent_complete = base_progress / 2 if combined_progress else base_progress
|
||||
print(f"Progress: {percent_complete:.2f}%...")
|
||||
text = out['text']
|
||||
if batch_size in [0, 1, None]:
|
||||
text = text[0]
|
||||
if verbose:
|
||||
print(f"Transcript: [{round(vad_segments[idx]['start'], 3)} --> {round(vad_segments[idx]['end'], 3)}] {text}")
|
||||
segments.append(
|
||||
{
|
||||
"text": out['text'],
|
||||
"text": text,
|
||||
"start": round(vad_segments[idx]['start'], 3),
|
||||
"end": round(vad_segments[idx]['end'], 3)
|
||||
}
|
||||
)
|
||||
|
||||
if del_tokenizer:
|
||||
|
||||
# revert the tokenizer if multilingual inference is enabled
|
||||
if self.preset_language is None:
|
||||
self.tokenizer = None
|
||||
|
||||
# revert suppressed tokens if suppress_numerals is enabled
|
||||
if self.suppress_numerals:
|
||||
self.options = replace(self.options, suppress_tokens=previous_suppress_tokens)
|
||||
|
||||
return {"segments": segments, "language": language}
|
||||
|
||||
|
||||
def detect_language(self, audio: np.ndarray):
|
||||
def detect_language(self, audio: np.ndarray) -> str:
|
||||
if audio.shape[0] < N_SAMPLES:
|
||||
print("Warning: audio is shorter than 30s, language detection may be inaccurate.")
|
||||
model_n_mels = self.model.feat_kwargs.get("feature_size")
|
||||
segment = log_mel_spectrogram(audio[: N_SAMPLES],
|
||||
n_mels=model_n_mels if model_n_mels is not None else 80,
|
||||
padding=0 if audio.shape[0] >= N_SAMPLES else N_SAMPLES - audio.shape[0])
|
||||
encoder_output = self.model.encode(segment)
|
||||
results = self.model.model.detect_language(encoder_output)
|
||||
@ -262,148 +297,120 @@ class FasterWhisperPipeline(Pipeline):
|
||||
print(f"Detected language: {language} ({language_probability:.2f}) in first 30s of audio...")
|
||||
return language
|
||||
|
||||
if __name__ == "__main__":
|
||||
main_type = "simple"
|
||||
import time
|
||||
|
||||
import jiwer
|
||||
from tqdm import tqdm
|
||||
from whisper.normalizers import EnglishTextNormalizer
|
||||
def load_model(
|
||||
whisper_arch: str,
|
||||
device: str,
|
||||
device_index=0,
|
||||
compute_type="float16",
|
||||
asr_options: Optional[dict] = None,
|
||||
language: Optional[str] = None,
|
||||
vad_model: Optional[Vad]= None,
|
||||
vad_method: Optional[str] = "pyannote",
|
||||
vad_options: Optional[dict] = None,
|
||||
model: Optional[WhisperModel] = None,
|
||||
task="transcribe",
|
||||
download_root: Optional[str] = None,
|
||||
local_files_only=False,
|
||||
threads=4,
|
||||
) -> FasterWhisperPipeline:
|
||||
"""Load a Whisper model for inference.
|
||||
Args:
|
||||
whisper_arch - The name of the Whisper model to load.
|
||||
device - The device to load the model on.
|
||||
compute_type - The compute type to use for the model.
|
||||
vad_method - The vad method to use. vad_model has higher priority if is not None.
|
||||
options - A dictionary of options to use for the model.
|
||||
language - The language of the model. (use English for now)
|
||||
model - The WhisperModel instance to use.
|
||||
download_root - The root directory to download the model to.
|
||||
local_files_only - If `True`, avoid downloading the file and return the path to the local cached file if it exists.
|
||||
threads - The number of cpu threads to use per worker, e.g. will be multiplied by num workers.
|
||||
Returns:
|
||||
A Whisper pipeline.
|
||||
"""
|
||||
|
||||
from benchmark.tedlium import parse_tedlium_annos
|
||||
if whisper_arch.endswith(".en"):
|
||||
language = "en"
|
||||
|
||||
if main_type == "complex":
|
||||
from faster_whisper.tokenizer import Tokenizer
|
||||
from faster_whisper.transcribe import TranscriptionOptions
|
||||
from faster_whisper.vad import (SpeechTimestampsMap,
|
||||
get_speech_timestamps)
|
||||
model = model or WhisperModel(whisper_arch,
|
||||
device=device,
|
||||
device_index=device_index,
|
||||
compute_type=compute_type,
|
||||
download_root=download_root,
|
||||
local_files_only=local_files_only,
|
||||
cpu_threads=threads)
|
||||
if language is not None:
|
||||
tokenizer = Tokenizer(model.hf_tokenizer, model.model.is_multilingual, task=task, language=language)
|
||||
else:
|
||||
print("No language specified, language will be first be detected for each audio file (increases inference time).")
|
||||
tokenizer = None
|
||||
|
||||
from whisperx.vad import load_vad_model, merge_chunks
|
||||
default_asr_options = {
|
||||
"beam_size": 5,
|
||||
"best_of": 5,
|
||||
"patience": 1,
|
||||
"length_penalty": 1,
|
||||
"repetition_penalty": 1,
|
||||
"no_repeat_ngram_size": 0,
|
||||
"temperatures": [0.0, 0.2, 0.4, 0.6, 0.8, 1.0],
|
||||
"compression_ratio_threshold": 2.4,
|
||||
"log_prob_threshold": -1.0,
|
||||
"no_speech_threshold": 0.6,
|
||||
"condition_on_previous_text": False,
|
||||
"prompt_reset_on_temperature": 0.5,
|
||||
"initial_prompt": None,
|
||||
"prefix": None,
|
||||
"suppress_blank": True,
|
||||
"suppress_tokens": [-1],
|
||||
"without_timestamps": True,
|
||||
"max_initial_timestamp": 0.0,
|
||||
"word_timestamps": False,
|
||||
"prepend_punctuations": "\"'“¿([{-",
|
||||
"append_punctuations": "\"'.。,,!!??::”)]}、",
|
||||
"multilingual": model.model.is_multilingual,
|
||||
"suppress_numerals": False,
|
||||
"max_new_tokens": None,
|
||||
"clip_timestamps": None,
|
||||
"hallucination_silence_threshold": None,
|
||||
"hotwords": None,
|
||||
}
|
||||
|
||||
from .audio import SAMPLE_RATE, load_audio, log_mel_spectrogram
|
||||
faster_t_options = TranscriptionOptions(
|
||||
beam_size=5,
|
||||
best_of=5,
|
||||
patience=1,
|
||||
length_penalty=1,
|
||||
temperatures=[0.0, 0.2, 0.4, 0.6, 0.8, 1.0],
|
||||
compression_ratio_threshold=2.4,
|
||||
log_prob_threshold=-1.0,
|
||||
no_speech_threshold=0.6,
|
||||
condition_on_previous_text=False,
|
||||
initial_prompt=None,
|
||||
prefix=None,
|
||||
suppress_blank=True,
|
||||
suppress_tokens=[-1],
|
||||
without_timestamps=True,
|
||||
max_initial_timestamp=0.0,
|
||||
word_timestamps=False,
|
||||
prepend_punctuations="\"'“¿([{-",
|
||||
append_punctuations="\"'.。,,!!??::”)]}、"
|
||||
if asr_options is not None:
|
||||
default_asr_options.update(asr_options)
|
||||
|
||||
suppress_numerals = default_asr_options["suppress_numerals"]
|
||||
del default_asr_options["suppress_numerals"]
|
||||
|
||||
default_asr_options = TranscriptionOptions(**default_asr_options)
|
||||
|
||||
default_vad_options = {
|
||||
"chunk_size": 30, # needed by silero since binarization happens before merge_chunks
|
||||
"vad_onset": 0.500,
|
||||
"vad_offset": 0.363
|
||||
}
|
||||
|
||||
if vad_options is not None:
|
||||
default_vad_options.update(vad_options)
|
||||
|
||||
# Note: manually assigned vad_model has higher priority than vad_method!
|
||||
if vad_model is not None:
|
||||
print("Use manually assigned vad_model. vad_method is ignored.")
|
||||
vad_model = vad_model
|
||||
else:
|
||||
if vad_method == "silero":
|
||||
vad_model = Silero(**default_vad_options)
|
||||
elif vad_method == "pyannote":
|
||||
vad_model = Pyannote(torch.device(device), use_auth_token=None, **default_vad_options)
|
||||
else:
|
||||
raise ValueError(f"Invalid vad_method: {vad_method}")
|
||||
|
||||
return FasterWhisperPipeline(
|
||||
model=model,
|
||||
vad=vad_model,
|
||||
options=default_asr_options,
|
||||
tokenizer=tokenizer,
|
||||
language=language,
|
||||
suppress_numerals=suppress_numerals,
|
||||
vad_params=default_vad_options,
|
||||
)
|
||||
whisper_arch = "large-v2"
|
||||
device = "cuda"
|
||||
batch_size = 16
|
||||
model = WhisperModel(whisper_arch, device="cuda", compute_type="float16",)
|
||||
tokenizer = Tokenizer(model.hf_tokenizer, model.model.is_multilingual, task="transcribe", language="en")
|
||||
model = FasterWhisperPipeline(model, tokenizer, faster_t_options, device=-1)
|
||||
fn = "DanielKahneman_2010.wav"
|
||||
wav_dir = f"/tmp/test/wav/"
|
||||
vad_model = load_vad_model("cuda", 0.6, 0.3)
|
||||
audio = load_audio(os.path.join(wav_dir, fn))
|
||||
vad_segments = vad_model({"waveform": torch.from_numpy(audio).unsqueeze(0), "sample_rate": SAMPLE_RATE})
|
||||
vad_segments = merge_chunks(vad_segments, 30)
|
||||
|
||||
def data(audio, segments):
|
||||
for seg in segments:
|
||||
f1 = int(seg['start'] * SAMPLE_RATE)
|
||||
f2 = int(seg['end'] * SAMPLE_RATE)
|
||||
# print(f2-f1)
|
||||
yield {'inputs': audio[f1:f2]}
|
||||
vad_method="pyannote"
|
||||
|
||||
wav_dir = f"/tmp/test/wav/"
|
||||
wer_li = []
|
||||
time_li = []
|
||||
for fn in os.listdir(wav_dir):
|
||||
if fn == "RobertGupta_2010U.wav":
|
||||
continue
|
||||
base_fn = fn.split('.')[0]
|
||||
audio_fp = os.path.join(wav_dir, fn)
|
||||
|
||||
audio = load_audio(audio_fp)
|
||||
t1 = time.time()
|
||||
if vad_method == "pyannote":
|
||||
vad_segments = vad_model({"waveform": torch.from_numpy(audio).unsqueeze(0), "sample_rate": SAMPLE_RATE})
|
||||
vad_segments = merge_chunks(vad_segments, 30)
|
||||
elif vad_method == "silero":
|
||||
vad_segments = get_speech_timestamps(audio, threshold=0.5, max_speech_duration_s=30)
|
||||
vad_segments = [{"start": x["start"] / SAMPLE_RATE, "end": x["end"] / SAMPLE_RATE} for x in vad_segments]
|
||||
new_segs = []
|
||||
curr_start = vad_segments[0]['start']
|
||||
curr_end = vad_segments[0]['end']
|
||||
for seg in vad_segments[1:]:
|
||||
if seg['end'] - curr_start > 30:
|
||||
new_segs.append({"start": curr_start, "end": curr_end})
|
||||
curr_start = seg['start']
|
||||
curr_end = seg['end']
|
||||
else:
|
||||
curr_end = seg['end']
|
||||
new_segs.append({"start": curr_start, "end": curr_end})
|
||||
vad_segments = new_segs
|
||||
text = []
|
||||
# for idx, out in tqdm(enumerate(model(data(audio_fp, vad_segments), batch_size=batch_size)), total=len(vad_segments)):
|
||||
for idx, out in enumerate(model(data(audio, vad_segments), batch_size=batch_size)):
|
||||
text.append(out['text'])
|
||||
t2 = time.time()
|
||||
if batch_size == 1:
|
||||
text = [x[0] for x in text]
|
||||
text = " ".join(text)
|
||||
|
||||
normalizer = EnglishTextNormalizer()
|
||||
text = normalizer(text)
|
||||
gt_corpus = normalizer(parse_tedlium_annos(base_fn, "/tmp/test/"))
|
||||
|
||||
wer_result = jiwer.wer(gt_corpus, text)
|
||||
print("WER: %.2f \t time: %.2f \t [%s]" % (wer_result * 100, t2-t1, fn))
|
||||
|
||||
wer_li.append(wer_result)
|
||||
time_li.append(t2-t1)
|
||||
print("# Avg Mean...")
|
||||
print("WER: %.2f" % (sum(wer_li) * 100/len(wer_li)))
|
||||
print("Time: %.2f" % (sum(time_li)/len(time_li)))
|
||||
elif main_type == "simple":
|
||||
model = load_model(
|
||||
"large-v2",
|
||||
device="cuda",
|
||||
language="en",
|
||||
)
|
||||
|
||||
wav_dir = f"/tmp/test/wav/"
|
||||
wer_li = []
|
||||
time_li = []
|
||||
for fn in os.listdir(wav_dir):
|
||||
if fn == "RobertGupta_2010U.wav":
|
||||
continue
|
||||
# fn = "DanielKahneman_2010.wav"
|
||||
base_fn = fn.split('.')[0]
|
||||
audio_fp = os.path.join(wav_dir, fn)
|
||||
|
||||
audio = load_audio(audio_fp)
|
||||
t1 = time.time()
|
||||
out = model.transcribe(audio_fp, batch_size=8)["segments"]
|
||||
t2 = time.time()
|
||||
|
||||
text = " ".join([x['text'] for x in out])
|
||||
normalizer = EnglishTextNormalizer()
|
||||
text = normalizer(text)
|
||||
gt_corpus = normalizer(parse_tedlium_annos(base_fn, "/tmp/test/"))
|
||||
|
||||
wer_result = jiwer.wer(gt_corpus, text)
|
||||
print("WER: %.2f \t time: %.2f \t [%s]" % (wer_result * 100, t2-t1, fn))
|
||||
|
||||
wer_li.append(wer_result)
|
||||
time_li.append(t2-t1)
|
||||
print("# Avg Mean...")
|
||||
print("WER: %.2f" % (sum(wer_li) * 100/len(wer_li)))
|
||||
print("Time: %.2f" % (sum(time_li)/len(time_li)))
|
||||
|
Binary file not shown.
BIN
whisperx/assets/pytorch_model.bin
Normal file
BIN
whisperx/assets/pytorch_model.bin
Normal file
Binary file not shown.
@ -1,18 +1,17 @@
|
||||
import os
|
||||
import subprocess
|
||||
from functools import lru_cache
|
||||
from typing import Optional, Union
|
||||
|
||||
import ffmpeg
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .utils import exact_div
|
||||
from whisperx.utils import exact_div
|
||||
|
||||
# hard-coded audio hyperparameters
|
||||
SAMPLE_RATE = 16000
|
||||
N_FFT = 400
|
||||
N_MELS = 80
|
||||
HOP_LENGTH = 160
|
||||
CHUNK_LENGTH = 30
|
||||
N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000 samples in a 30-second chunk
|
||||
@ -23,7 +22,7 @@ FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH) # 10ms per audio frame
|
||||
TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN) # 20ms per audio token
|
||||
|
||||
|
||||
def load_audio(file: str, sr: int = SAMPLE_RATE):
|
||||
def load_audio(file: str, sr: int = SAMPLE_RATE) -> np.ndarray:
|
||||
"""
|
||||
Open an audio file and read as mono waveform, resampling as necessary
|
||||
|
||||
@ -40,14 +39,27 @@ def load_audio(file: str, sr: int = SAMPLE_RATE):
|
||||
A NumPy array containing the audio waveform, in float32 dtype.
|
||||
"""
|
||||
try:
|
||||
# This launches a subprocess to decode audio while down-mixing and resampling as necessary.
|
||||
# Requires the ffmpeg CLI and `ffmpeg-python` package to be installed.
|
||||
out, _ = (
|
||||
ffmpeg.input(file, threads=0)
|
||||
.output("-", format="s16le", acodec="pcm_s16le", ac=1, ar=sr)
|
||||
.run(cmd=["ffmpeg", "-nostdin"], capture_stdout=True, capture_stderr=True)
|
||||
)
|
||||
except ffmpeg.Error as e:
|
||||
# Launches a subprocess to decode audio while down-mixing and resampling as necessary.
|
||||
# Requires the ffmpeg CLI to be installed.
|
||||
cmd = [
|
||||
"ffmpeg",
|
||||
"-nostdin",
|
||||
"-threads",
|
||||
"0",
|
||||
"-i",
|
||||
file,
|
||||
"-f",
|
||||
"s16le",
|
||||
"-ac",
|
||||
"1",
|
||||
"-acodec",
|
||||
"pcm_s16le",
|
||||
"-ar",
|
||||
str(sr),
|
||||
"-",
|
||||
]
|
||||
out = subprocess.run(cmd, capture_output=True, check=True).stdout
|
||||
except subprocess.CalledProcessError as e:
|
||||
raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e
|
||||
|
||||
return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0
|
||||
@ -80,7 +92,7 @@ def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def mel_filters(device, n_mels: int = N_MELS) -> torch.Tensor:
|
||||
def mel_filters(device, n_mels: int) -> torch.Tensor:
|
||||
"""
|
||||
load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
|
||||
Allows decoupling librosa dependency; saved using:
|
||||
@ -90,7 +102,7 @@ def mel_filters(device, n_mels: int = N_MELS) -> torch.Tensor:
|
||||
mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
|
||||
)
|
||||
"""
|
||||
assert n_mels == 80, f"Unsupported n_mels: {n_mels}"
|
||||
assert n_mels in [80, 128], f"Unsupported n_mels: {n_mels}"
|
||||
with np.load(
|
||||
os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz")
|
||||
) as f:
|
||||
@ -99,7 +111,7 @@ def mel_filters(device, n_mels: int = N_MELS) -> torch.Tensor:
|
||||
|
||||
def log_mel_spectrogram(
|
||||
audio: Union[str, np.ndarray, torch.Tensor],
|
||||
n_mels: int = N_MELS,
|
||||
n_mels: int,
|
||||
padding: int = 0,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
):
|
||||
|
47
whisperx/conjunctions.py
Normal file
47
whisperx/conjunctions.py
Normal file
@ -0,0 +1,47 @@
|
||||
# conjunctions.py
|
||||
|
||||
from typing import Set
|
||||
|
||||
|
||||
conjunctions_by_language = {
|
||||
'en': {'and', 'whether', 'or', 'as', 'but', 'so', 'for', 'nor', 'which', 'yet', 'although', 'since', 'unless', 'when', 'while', 'because', 'if', 'how', 'that', 'than', 'who', 'where', 'what', 'near', 'before', 'after', 'across', 'through', 'until', 'once', 'whereas', 'even', 'both', 'either', 'neither', 'though'},
|
||||
'fr': {'et', 'ou', 'mais', 'parce', 'bien', 'pendant', 'quand', 'où', 'comme', 'si', 'que', 'avant', 'après', 'aussitôt', 'jusqu’à', 'à', 'malgré', 'donc', 'tant', 'puisque', 'ni', 'soit', 'bien', 'encore', 'dès', 'lorsque'},
|
||||
'de': {'und', 'oder', 'aber', 'weil', 'obwohl', 'während', 'wenn', 'wo', 'wie', 'dass', 'bevor', 'nachdem', 'sobald', 'bis', 'außer', 'trotzdem', 'also', 'sowie', 'indem', 'weder', 'sowohl', 'zwar', 'jedoch'},
|
||||
'es': {'y', 'o', 'pero', 'porque', 'aunque', 'sin', 'mientras', 'cuando', 'donde', 'como', 'si', 'que', 'antes', 'después', 'tan', 'hasta', 'a', 'a', 'por', 'ya', 'ni', 'sino'},
|
||||
'it': {'e', 'o', 'ma', 'perché', 'anche', 'mentre', 'quando', 'dove', 'come', 'se', 'che', 'prima', 'dopo', 'appena', 'fino', 'a', 'nonostante', 'quindi', 'poiché', 'né', 'ossia', 'cioè'},
|
||||
'ja': {'そして', 'または', 'しかし', 'なぜなら', 'もし', 'それとも', 'だから', 'それに', 'なのに', 'そのため', 'かつ', 'それゆえに', 'ならば', 'もしくは', 'ため'},
|
||||
'zh': {'和', '或', '但是', '因为', '任何', '也', '虽然', '而且', '所以', '如果', '除非', '尽管', '既然', '即使', '只要', '直到', '然后', '因此', '不但', '而是', '不过'},
|
||||
'nl': {'en', 'of', 'maar', 'omdat', 'hoewel', 'terwijl', 'wanneer', 'waar', 'zoals', 'als', 'dat', 'voordat', 'nadat', 'zodra', 'totdat', 'tenzij', 'ondanks', 'dus', 'zowel', 'noch', 'echter', 'toch'},
|
||||
'uk': {'та', 'або', 'але', 'тому', 'хоча', 'поки', 'бо', 'коли', 'де', 'як', 'якщо', 'що', 'перш', 'після', 'доки', 'незважаючи', 'тому', 'ані'},
|
||||
'pt': {'e', 'ou', 'mas', 'porque', 'embora', 'enquanto', 'quando', 'onde', 'como', 'se', 'que', 'antes', 'depois', 'assim', 'até', 'a', 'apesar', 'portanto', 'já', 'pois', 'nem', 'senão'},
|
||||
'ar': {'و', 'أو', 'لكن', 'لأن', 'مع', 'بينما', 'عندما', 'حيث', 'كما', 'إذا', 'الذي', 'قبل', 'بعد', 'فور', 'حتى', 'إلا', 'رغم', 'لذلك', 'بما'},
|
||||
'cs': {'a', 'nebo', 'ale', 'protože', 'ačkoli', 'zatímco', 'když', 'kde', 'jako', 'pokud', 'že', 'než', 'poté', 'jakmile', 'dokud', 'pokud ne', 'navzdory', 'tak', 'stejně', 'ani', 'tudíž'},
|
||||
'ru': {'и', 'или', 'но', 'потому', 'хотя', 'пока', 'когда', 'где', 'как', 'если', 'что', 'перед', 'после', 'несмотря', 'таким', 'также', 'ни', 'зато'},
|
||||
'pl': {'i', 'lub', 'ale', 'ponieważ', 'chociaż', 'podczas', 'kiedy', 'gdzie', 'jak', 'jeśli', 'że', 'zanim', 'po', 'jak tylko', 'dopóki', 'chyba', 'pomimo', 'więc', 'tak', 'ani', 'czyli'},
|
||||
'hu': {'és', 'vagy', 'de', 'mert', 'habár', 'míg', 'amikor', 'ahol', 'ahogy', 'ha', 'hogy', 'mielőtt', 'miután', 'amint', 'amíg', 'hacsak', 'ellenére', 'tehát', 'úgy', 'sem', 'vagyis'},
|
||||
'fi': {'ja', 'tai', 'mutta', 'koska', 'vaikka', 'kun', 'missä', 'kuten', 'jos', 'että', 'ennen', 'sen jälkeen', 'heti', 'kunnes', 'ellei', 'huolimatta', 'siis', 'sekä', 'eikä', 'vaan'},
|
||||
'fa': {'و', 'یا', 'اما', 'چون', 'اگرچه', 'در حالی', 'وقتی', 'کجا', 'چگونه', 'اگر', 'که', 'قبل', 'پس', 'به محض', 'تا زمانی', 'مگر', 'با وجود', 'پس', 'همچنین', 'نه'},
|
||||
'el': {'και', 'ή', 'αλλά', 'επειδή', 'αν', 'ενώ', 'όταν', 'όπου', 'όπως', 'αν', 'που', 'προτού', 'αφού', 'μόλις', 'μέχρι', 'εκτός', 'παρά', 'έτσι', 'όπως', 'ούτε', 'δηλαδή'},
|
||||
'tr': {'ve', 'veya', 'ama', 'çünkü', 'her ne', 'iken', 'nerede', 'nasıl', 'eğer', 'ki', 'önce', 'sonra', 'hemen', 'kadar', 'rağmen', 'hem', 'ne', 'yani'},
|
||||
'da': {'og', 'eller', 'men', 'fordi', 'selvom', 'mens', 'når', 'hvor', 'som', 'hvis', 'at', 'før', 'efter', 'indtil', 'medmindre', 'således', 'ligesom', 'hverken', 'altså'},
|
||||
'he': {'ו', 'או', 'אבל', 'כי', 'אף', 'בזמן', 'כאשר', 'היכן', 'כיצד', 'אם', 'ש', 'לפני', 'אחרי', 'ברגע', 'עד', 'אלא', 'למרות', 'לכן', 'כמו', 'לא', 'אז'},
|
||||
'vi': {'và', 'hoặc', 'nhưng', 'bởi', 'mặc', 'trong', 'khi', 'ở', 'như', 'nếu', 'rằng', 'trước', 'sau', 'ngay', 'cho', 'trừ', 'mặc', 'vì', 'giống', 'cũng', 'tức'},
|
||||
'ko': {'그리고', '또는','그런데','그래도', '이나', '결국', '마지막으로', '마찬가지로', '반면에', '아니면', '거나', '또는', '그럼에도', '그렇기', '때문에', '덧붙이자면', '게다가', '그러나', '고', '그래서', '랑', '한다면', '하지만', '무엇', '왜냐하면', '비록', '동안', '언제', '어디서', '어떻게', '만약', '그', '전에', '후에', '즉시', '까지', '아니라면', '불구하고', '따라서', '같은', '도'},
|
||||
'ur': {'اور', 'یا', 'مگر', 'کیونکہ', 'اگرچہ', 'جبکہ', 'جب', 'کہاں', 'کس طرح', 'اگر', 'کہ', 'سے پہلے', 'کے بعد', 'جیسے ہی', 'تک', 'اگر نہیں تو', 'کے باوجود', 'اس لئے', 'جیسے', 'نہ'},
|
||||
'hi': {'और', 'या', 'पर', 'तो', 'न', 'फिर', 'हालांकि', 'चूंकि', 'अगर', 'कैसे', 'वह', 'से', 'जो', 'जहां', 'क्या', 'नजदीक', 'पहले', 'बाद', 'के', 'पार', 'माध्यम', 'तक', 'एक', 'जबकि', 'यहां', 'तक', 'दोनों', 'या', 'न', 'हालांकि'}
|
||||
|
||||
}
|
||||
|
||||
commas_by_language = {
|
||||
'ja': '、',
|
||||
'zh': ',',
|
||||
'fa': '،',
|
||||
'ur': '،'
|
||||
}
|
||||
|
||||
def get_conjunctions(lang_code: str) -> Set[str]:
|
||||
return conjunctions_by_language.get(lang_code, set())
|
||||
|
||||
|
||||
def get_comma(lang_code: str) -> str:
|
||||
return commas_by_language.get(lang_code, ",")
|
@ -4,78 +4,143 @@ from pyannote.audio import Pipeline
|
||||
from typing import Optional, Union
|
||||
import torch
|
||||
|
||||
from whisperx.audio import load_audio, SAMPLE_RATE
|
||||
from whisperx.types import TranscriptionResult, AlignedTranscriptionResult
|
||||
|
||||
|
||||
class DiarizationPipeline:
|
||||
def __init__(
|
||||
self,
|
||||
model_name="pyannote/speaker-diarization@2.1",
|
||||
model_name=None,
|
||||
use_auth_token=None,
|
||||
device: Optional[Union[str, torch.device]] = "cpu",
|
||||
):
|
||||
if isinstance(device, str):
|
||||
device = torch.device(device)
|
||||
self.model = Pipeline.from_pretrained(model_name, use_auth_token=use_auth_token).to(device)
|
||||
model_config = model_name or "pyannote/speaker-diarization-3.1"
|
||||
self.model = Pipeline.from_pretrained(model_config, use_auth_token=use_auth_token).to(device)
|
||||
|
||||
def __call__(self, audio, min_speakers=None, max_speakers=None):
|
||||
segments = self.model(audio, min_speakers=min_speakers, max_speakers=max_speakers)
|
||||
diarize_df = pd.DataFrame(segments.itertracks(yield_label=True))
|
||||
diarize_df['start'] = diarize_df[0].apply(lambda x: x.start)
|
||||
diarize_df['end'] = diarize_df[0].apply(lambda x: x.end)
|
||||
return diarize_df
|
||||
def __call__(
|
||||
self,
|
||||
audio: Union[str, np.ndarray],
|
||||
num_speakers: Optional[int] = None,
|
||||
min_speakers: Optional[int] = None,
|
||||
max_speakers: Optional[int] = None,
|
||||
return_embeddings: bool = False,
|
||||
) -> Union[tuple[pd.DataFrame, Optional[dict[str, list[float]]]], pd.DataFrame]:
|
||||
"""
|
||||
Perform speaker diarization on audio.
|
||||
|
||||
def assign_word_speakers(diarize_df, result_segments, fill_nearest=False):
|
||||
for seg in result_segments:
|
||||
wdf = seg['word-segments']
|
||||
if len(wdf['start'].dropna()) == 0:
|
||||
wdf['start'] = seg['start']
|
||||
wdf['end'] = seg['end']
|
||||
speakers = []
|
||||
for wdx, wrow in wdf.iterrows():
|
||||
if not np.isnan(wrow['start']):
|
||||
diarize_df['intersection'] = np.minimum(diarize_df['end'], wrow['end']) - np.maximum(diarize_df['start'], wrow['start'])
|
||||
diarize_df['union'] = np.maximum(diarize_df['end'], wrow['end']) - np.minimum(diarize_df['start'], wrow['start'])
|
||||
# remove no hit
|
||||
if not fill_nearest:
|
||||
dia_tmp = diarize_df[diarize_df['intersection'] > 0]
|
||||
else:
|
||||
dia_tmp = diarize_df
|
||||
if len(dia_tmp) == 0:
|
||||
speaker = None
|
||||
else:
|
||||
speaker = dia_tmp.sort_values("intersection", ascending=False).iloc[0][2]
|
||||
else:
|
||||
speaker = None
|
||||
speakers.append(speaker)
|
||||
seg['word-segments']['speaker'] = speakers
|
||||
Args:
|
||||
audio: Path to audio file or audio array
|
||||
num_speakers: Exact number of speakers (if known)
|
||||
min_speakers: Minimum number of speakers to detect
|
||||
max_speakers: Maximum number of speakers to detect
|
||||
return_embeddings: Whether to return speaker embeddings
|
||||
|
||||
speaker_count = pd.Series(speakers).value_counts()
|
||||
if len(speaker_count) == 0:
|
||||
seg["speaker"]= "UNKNOWN"
|
||||
Returns:
|
||||
If return_embeddings is True:
|
||||
Tuple of (diarization dataframe, speaker embeddings dictionary)
|
||||
Otherwise:
|
||||
Just the diarization dataframe
|
||||
"""
|
||||
if isinstance(audio, str):
|
||||
audio = load_audio(audio)
|
||||
audio_data = {
|
||||
'waveform': torch.from_numpy(audio[None, :]),
|
||||
'sample_rate': SAMPLE_RATE
|
||||
}
|
||||
|
||||
if return_embeddings:
|
||||
diarization, embeddings = self.model(
|
||||
audio_data,
|
||||
num_speakers=num_speakers,
|
||||
min_speakers=min_speakers,
|
||||
max_speakers=max_speakers,
|
||||
return_embeddings=True,
|
||||
)
|
||||
else:
|
||||
seg["speaker"] = speaker_count.index[0]
|
||||
diarization = self.model(
|
||||
audio_data,
|
||||
num_speakers=num_speakers,
|
||||
min_speakers=min_speakers,
|
||||
max_speakers=max_speakers,
|
||||
)
|
||||
embeddings = None
|
||||
|
||||
# create word level segments for .srt
|
||||
word_seg = []
|
||||
for seg in result_segments:
|
||||
wseg = pd.DataFrame(seg["word-segments"])
|
||||
for wdx, wrow in wseg.iterrows():
|
||||
if wrow["start"] is not None:
|
||||
speaker = wrow['speaker']
|
||||
if speaker is None or speaker == np.nan:
|
||||
speaker = "UNKNOWN"
|
||||
word_seg.append(
|
||||
{
|
||||
"start": wrow["start"],
|
||||
"end": wrow["end"],
|
||||
"text": f"[{speaker}]: " + seg["text"][int(wrow["segment-text-start"]):int(wrow["segment-text-end"])]
|
||||
}
|
||||
)
|
||||
diarize_df = pd.DataFrame(diarization.itertracks(yield_label=True), columns=['segment', 'label', 'speaker'])
|
||||
diarize_df['start'] = diarize_df['segment'].apply(lambda x: x.start)
|
||||
diarize_df['end'] = diarize_df['segment'].apply(lambda x: x.end)
|
||||
|
||||
# TODO: create segments but split words on new speaker
|
||||
if return_embeddings and embeddings is not None:
|
||||
speaker_embeddings = {speaker: embeddings[s].tolist() for s, speaker in enumerate(diarization.labels())}
|
||||
return diarize_df, speaker_embeddings
|
||||
|
||||
# For backwards compatibility
|
||||
if return_embeddings:
|
||||
return diarize_df, None
|
||||
else:
|
||||
return diarize_df
|
||||
|
||||
|
||||
def assign_word_speakers(
|
||||
diarize_df: pd.DataFrame,
|
||||
transcript_result: Union[AlignedTranscriptionResult, TranscriptionResult],
|
||||
speaker_embeddings: Optional[dict[str, list[float]]] = None,
|
||||
fill_nearest: bool = False,
|
||||
) -> Union[AlignedTranscriptionResult, TranscriptionResult]:
|
||||
"""
|
||||
Assign speakers to words and segments in the transcript.
|
||||
|
||||
Args:
|
||||
diarize_df: Diarization dataframe from DiarizationPipeline
|
||||
transcript_result: Transcription result to augment with speaker labels
|
||||
speaker_embeddings: Optional dictionary mapping speaker IDs to embedding vectors
|
||||
fill_nearest: If True, assign speakers even when there's no direct time overlap
|
||||
|
||||
Returns:
|
||||
Updated transcript_result with speaker assignments and optionally embeddings
|
||||
"""
|
||||
transcript_segments = transcript_result["segments"]
|
||||
for seg in transcript_segments:
|
||||
# assign speaker to segment (if any)
|
||||
diarize_df['intersection'] = np.minimum(diarize_df['end'], seg['end']) - np.maximum(diarize_df['start'], seg['start'])
|
||||
diarize_df['union'] = np.maximum(diarize_df['end'], seg['end']) - np.minimum(diarize_df['start'], seg['start'])
|
||||
# remove no hit, otherwise we look for closest (even negative intersection...)
|
||||
if not fill_nearest:
|
||||
dia_tmp = diarize_df[diarize_df['intersection'] > 0]
|
||||
else:
|
||||
dia_tmp = diarize_df
|
||||
if len(dia_tmp) > 0:
|
||||
# sum over speakers
|
||||
speaker = dia_tmp.groupby("speaker")["intersection"].sum().sort_values(ascending=False).index[0]
|
||||
seg["speaker"] = speaker
|
||||
|
||||
# assign speaker to words
|
||||
if 'words' in seg:
|
||||
for word in seg['words']:
|
||||
if 'start' in word:
|
||||
diarize_df['intersection'] = np.minimum(diarize_df['end'], word['end']) - np.maximum(diarize_df['start'], word['start'])
|
||||
diarize_df['union'] = np.maximum(diarize_df['end'], word['end']) - np.minimum(diarize_df['start'], word['start'])
|
||||
# remove no hit
|
||||
if not fill_nearest:
|
||||
dia_tmp = diarize_df[diarize_df['intersection'] > 0]
|
||||
else:
|
||||
dia_tmp = diarize_df
|
||||
if len(dia_tmp) > 0:
|
||||
# sum over speakers
|
||||
speaker = dia_tmp.groupby("speaker")["intersection"].sum().sort_values(ascending=False).index[0]
|
||||
word["speaker"] = speaker
|
||||
|
||||
# Add speaker embeddings to the result if provided
|
||||
if speaker_embeddings is not None:
|
||||
transcript_result["speaker_embeddings"] = speaker_embeddings
|
||||
|
||||
return transcript_result
|
||||
|
||||
return result_segments, word_seg
|
||||
|
||||
class Segment:
|
||||
def __init__(self, start, end, speaker=None):
|
||||
def __init__(self, start:int, end:int, speaker:Optional[str]=None):
|
||||
self.start = start
|
||||
self.end = end
|
||||
self.speaker = speaker
|
||||
|
@ -6,81 +6,33 @@ import warnings
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from .alignment import align, load_align_model
|
||||
from .asr import load_model
|
||||
from .audio import load_audio
|
||||
from .diarize import DiarizationPipeline, assign_word_speakers
|
||||
from .utils import (LANGUAGES, TO_LANGUAGE_CODE, get_writer, optional_float,
|
||||
optional_int, str2bool)
|
||||
from whisperx.alignment import align, load_align_model
|
||||
from whisperx.asr import load_model
|
||||
from whisperx.audio import load_audio
|
||||
from whisperx.diarize import DiarizationPipeline, assign_word_speakers
|
||||
from whisperx.types import AlignedTranscriptionResult, TranscriptionResult
|
||||
from whisperx.utils import LANGUAGES, TO_LANGUAGE_CODE, get_writer
|
||||
|
||||
|
||||
def cli():
|
||||
def transcribe_task(args: dict, parser: argparse.ArgumentParser):
|
||||
"""Transcription task to be called from CLI.
|
||||
|
||||
Args:
|
||||
args: Dictionary of command-line arguments.
|
||||
parser: argparse.ArgumentParser object.
|
||||
"""
|
||||
# fmt: off
|
||||
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||
parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe")
|
||||
parser.add_argument("--model", default="small", help="name of the Whisper model to use")
|
||||
parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
|
||||
parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
|
||||
parser.add_argument("--batch_size", default=8, type=int, help="device to use for PyTorch inference")
|
||||
parser.add_argument("--compute_type", default="float16", type=str, choices=["float16", "float32", "int8"], help="compute type for computation")
|
||||
|
||||
parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
|
||||
parser.add_argument("--output_format", "-f", type=str, default="all", choices=["all", "srt", "vtt", "txt", "tsv", "json"], help="format of the output file; if not specified, all available formats will be produced")
|
||||
parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")
|
||||
|
||||
parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')")
|
||||
parser.add_argument("--language", type=str, default=None, choices=sorted(LANGUAGES.keys()) + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]), help="language spoken in the audio, specify None to perform language detection")
|
||||
|
||||
# alignment params
|
||||
parser.add_argument("--align_model", default=None, help="Name of phoneme-level ASR model to do alignment")
|
||||
parser.add_argument("--interpolate_method", default="nearest", choices=["nearest", "linear", "ignore"], help="For word .srt, method to assign timestamps to non-aligned words, or merge them into neighbouring.")
|
||||
parser.add_argument("--no_align", action='store_true', help="Do not perform phoneme alignment")
|
||||
|
||||
# vad params
|
||||
parser.add_argument("--vad_onset", type=float, default=0.500, help="Onset threshold for VAD (see pyannote.audio), reduce this if speech is not being detected")
|
||||
parser.add_argument("--vad_offset", type=float, default=0.363, help="Offset threshold for VAD (see pyannote.audio), reduce this if speech is not being detected.")
|
||||
|
||||
# diarization params
|
||||
parser.add_argument("--diarize", action="store_true", help="Apply diarization to assign speaker labels to each segment/word")
|
||||
parser.add_argument("--min_speakers", default=None, type=int)
|
||||
parser.add_argument("--max_speakers", default=None, type=int)
|
||||
|
||||
parser.add_argument("--temperature", type=float, default=0, help="temperature to use for sampling")
|
||||
parser.add_argument("--best_of", type=optional_int, default=5, help="number of candidates when sampling with non-zero temperature")
|
||||
parser.add_argument("--beam_size", type=optional_int, default=5, help="number of beams in beam search, only applicable when temperature is zero")
|
||||
parser.add_argument("--patience", type=float, default=None, help="optional patience value to use in beam decoding, as in https://arxiv.org/abs/2204.05424, the default (1.0) is equivalent to conventional beam search")
|
||||
parser.add_argument("--length_penalty", type=float, default=None, help="optional token length penalty coefficient (alpha) as in https://arxiv.org/abs/1609.08144, uses simple length normalization by default")
|
||||
|
||||
parser.add_argument("--suppress_tokens", type=str, default="-1", help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations")
|
||||
parser.add_argument("--initial_prompt", type=str, default=None, help="optional text to provide as a prompt for the first window.")
|
||||
parser.add_argument("--condition_on_previous_text", type=str2bool, default=False, help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop")
|
||||
parser.add_argument("--fp16", type=str2bool, default=True, help="whether to perform inference in fp16; True by default")
|
||||
|
||||
parser.add_argument("--temperature_increment_on_fallback", type=optional_float, default=0.2, help="temperature to increase when falling back when the decoding fails to meet either of the thresholds below")
|
||||
parser.add_argument("--compression_ratio_threshold", type=optional_float, default=2.4, help="if the gzip compression ratio is higher than this value, treat the decoding as failed")
|
||||
parser.add_argument("--logprob_threshold", type=optional_float, default=-1.0, help="if the average log probability is lower than this value, treat the decoding as failed")
|
||||
parser.add_argument("--no_speech_threshold", type=optional_float, default=0.6, help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence")
|
||||
|
||||
parser.add_argument("--max_line_width", type=optional_int, default=None, help="(not possible with --no_align) the maximum number of characters in a line before breaking the line")
|
||||
parser.add_argument("--max_line_count", type=optional_int, default=None, help="(requires --no_align) the maximum number of lines in a segment")
|
||||
parser.add_argument("--highlight_words", type=str2bool, default=False, help="(requires --word_timestamps True) underline each word as it is spoken in srt and vtt")
|
||||
|
||||
# parser.add_argument("--word_timestamps", type=str2bool, default=False, help="(experimental) extract word-level timestamps and refine the results based on them")
|
||||
# parser.add_argument("--prepend_punctuations", type=str, default="\"\'“¿([{-", help="if word_timestamps is True, merge these punctuation symbols with the next word")
|
||||
# parser.add_argument("--append_punctuations", type=str, default="\"\'.。,,!!??::”)]}、", help="if word_timestamps is True, merge these punctuation symbols with the previous word")
|
||||
parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
|
||||
|
||||
parser.add_argument("--hf_token", type=str, default=None, help="Hugging Face Access Token to access PyAnnote gated models")
|
||||
# parser.add_argument("--model_flush", action="store_true", help="Flush memory from each model after use, reduces GPU requirement but slower processing >1 audio file.")
|
||||
# fmt: on
|
||||
|
||||
args = parser.parse_args().__dict__
|
||||
model_name: str = args.pop("model")
|
||||
batch_size: int = args.pop("batch_size")
|
||||
model_dir: str = args.pop("model_dir")
|
||||
model_cache_only: bool = args.pop("model_cache_only")
|
||||
output_dir: str = args.pop("output_dir")
|
||||
output_format: str = args.pop("output_format")
|
||||
device: str = args.pop("device")
|
||||
device_index: int = args.pop("device_index")
|
||||
compute_type: str = args.pop("compute_type")
|
||||
verbose: bool = args.pop("verbose")
|
||||
|
||||
# model_flush: bool = args.pop("model_flush")
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
@ -88,23 +40,47 @@ def cli():
|
||||
align_model: str = args.pop("align_model")
|
||||
interpolate_method: str = args.pop("interpolate_method")
|
||||
no_align: bool = args.pop("no_align")
|
||||
task: str = args.pop("task")
|
||||
if task == "translate":
|
||||
# translation cannot be aligned
|
||||
no_align = True
|
||||
|
||||
return_char_alignments: bool = args.pop("return_char_alignments")
|
||||
|
||||
hf_token: str = args.pop("hf_token")
|
||||
vad_method: str = args.pop("vad_method")
|
||||
vad_onset: float = args.pop("vad_onset")
|
||||
vad_offset: float = args.pop("vad_offset")
|
||||
|
||||
chunk_size: int = args.pop("chunk_size")
|
||||
|
||||
diarize: bool = args.pop("diarize")
|
||||
min_speakers: int = args.pop("min_speakers")
|
||||
max_speakers: int = args.pop("max_speakers")
|
||||
diarize_model_name: str = args.pop("diarize_model")
|
||||
print_progress: bool = args.pop("print_progress")
|
||||
return_speaker_embeddings: bool = args.pop("speaker_embeddings")
|
||||
|
||||
# TODO: check model loading works.
|
||||
if return_speaker_embeddings and not diarize:
|
||||
warnings.warn("--speaker_embeddings has no effect without --diarize")
|
||||
|
||||
if model_name.endswith(".en") and args["language"] not in {"en", "English"}:
|
||||
if args["language"] is not None:
|
||||
args["language"] = args["language"].lower()
|
||||
if args["language"] not in LANGUAGES:
|
||||
if args["language"] in TO_LANGUAGE_CODE:
|
||||
args["language"] = TO_LANGUAGE_CODE[args["language"]]
|
||||
else:
|
||||
raise ValueError(f"Unsupported language: {args['language']}")
|
||||
|
||||
if model_name.endswith(".en") and args["language"] != "en":
|
||||
if args["language"] is not None:
|
||||
warnings.warn(
|
||||
f"{model_name} is an English-only model but receipted '{args['language']}'; using English instead."
|
||||
f"{model_name} is an English-only model but received '{args['language']}'; using English instead."
|
||||
)
|
||||
args["language"] = "en"
|
||||
align_language = (
|
||||
args["language"] if args["language"] is not None else "en"
|
||||
) # default to loading english if not specified
|
||||
|
||||
temperature = args.pop("temperature")
|
||||
if (increment := args.pop("temperature_increment_on_fallback")) is not None:
|
||||
@ -112,8 +88,10 @@ def cli():
|
||||
else:
|
||||
temperature = [temperature]
|
||||
|
||||
faster_whisper_threads = 4
|
||||
if (threads := args.pop("threads")) > 0:
|
||||
torch.set_num_threads(threads)
|
||||
faster_whisper_threads = threads
|
||||
|
||||
asr_options = {
|
||||
"beam_size": args.pop("beam_size"),
|
||||
@ -125,6 +103,8 @@ def cli():
|
||||
"no_speech_threshold": args.pop("no_speech_threshold"),
|
||||
"condition_on_previous_text": False,
|
||||
"initial_prompt": args.pop("initial_prompt"),
|
||||
"suppress_tokens": [int(x) for x in args.pop("suppress_tokens").split(",")],
|
||||
"suppress_numerals": args.pop("suppress_numerals"),
|
||||
}
|
||||
|
||||
writer = get_writer(output_format, output_dir)
|
||||
@ -132,22 +112,45 @@ def cli():
|
||||
if no_align:
|
||||
for option in word_options:
|
||||
if args[option]:
|
||||
parser.error(f"--{option} requires --word_timestamps True")
|
||||
parser.error(f"--{option} not possible with --no_align")
|
||||
if args["max_line_count"] and not args["max_line_width"]:
|
||||
warnings.warn("--max_line_count has no effect without --max_line_width")
|
||||
writer_args = {arg: args.pop(arg) for arg in word_options}
|
||||
|
||||
|
||||
# Part 1: VAD & ASR Loop
|
||||
results = []
|
||||
tmp_results = []
|
||||
# model = load_model(model_name, device=device, download_root=model_dir)
|
||||
model = load_model(model_name, device=device, compute_type=compute_type, language=args['language'], asr_options=asr_options, vad_options={"vad_onset": vad_onset, "vad_offset": vad_offset},)
|
||||
model = load_model(
|
||||
model_name,
|
||||
device=device,
|
||||
device_index=device_index,
|
||||
download_root=model_dir,
|
||||
compute_type=compute_type,
|
||||
language=args["language"],
|
||||
asr_options=asr_options,
|
||||
vad_method=vad_method,
|
||||
vad_options={
|
||||
"chunk_size": chunk_size,
|
||||
"vad_onset": vad_onset,
|
||||
"vad_offset": vad_offset,
|
||||
},
|
||||
task=task,
|
||||
local_files_only=model_cache_only,
|
||||
threads=faster_whisper_threads,
|
||||
)
|
||||
|
||||
for audio_path in args.pop("audio"):
|
||||
audio = load_audio(audio_path)
|
||||
# >> VAD & ASR
|
||||
print(">>Performing transcription...")
|
||||
result = model.transcribe(audio, batch_size=batch_size)
|
||||
result: TranscriptionResult = model.transcribe(
|
||||
audio,
|
||||
batch_size=batch_size,
|
||||
chunk_size=chunk_size,
|
||||
print_progress=print_progress,
|
||||
verbose=verbose,
|
||||
)
|
||||
results.append((result, audio_path))
|
||||
|
||||
# Unload Whisper and VAD
|
||||
@ -159,8 +162,9 @@ def cli():
|
||||
if not no_align:
|
||||
tmp_results = results
|
||||
results = []
|
||||
align_language = args["language"] if args["language"] is not None else "en" # default to loading english if not specified
|
||||
align_model, align_metadata = load_align_model(align_language, device, model_name=align_model)
|
||||
align_model, align_metadata = load_align_model(
|
||||
align_language, device, model_name=align_model
|
||||
)
|
||||
for result, audio_path in tmp_results:
|
||||
# >> Align
|
||||
if len(tmp_results) > 1:
|
||||
@ -172,10 +176,24 @@ def cli():
|
||||
if align_model is not None and len(result["segments"]) > 0:
|
||||
if result.get("language", "en") != align_metadata["language"]:
|
||||
# load new language
|
||||
print(f"New language found ({result['language']})! Previous was ({align_metadata['language']}), loading new alignment model for new language...")
|
||||
align_model, align_metadata = load_align_model(result["language"], device)
|
||||
print(
|
||||
f"New language found ({result['language']})! Previous was ({align_metadata['language']}), loading new alignment model for new language..."
|
||||
)
|
||||
align_model, align_metadata = load_align_model(
|
||||
result["language"], device
|
||||
)
|
||||
print(">>Performing alignment...")
|
||||
result = align(result["segments"], align_model, align_metadata, input_audio, device, interpolate_method=interpolate_method)
|
||||
result: AlignedTranscriptionResult = align(
|
||||
result["segments"],
|
||||
align_model,
|
||||
align_metadata,
|
||||
input_audio,
|
||||
device,
|
||||
interpolate_method=interpolate_method,
|
||||
return_char_alignments=return_char_alignments,
|
||||
print_progress=print_progress,
|
||||
)
|
||||
|
||||
results.append((result, audio_path))
|
||||
|
||||
# Unload align model
|
||||
@ -186,26 +204,31 @@ def cli():
|
||||
# >> Diarize
|
||||
if diarize:
|
||||
if hf_token is None:
|
||||
print("Warning, no --hf_token used, needs to be saved in environment variable, otherwise will throw error loading diarization model...")
|
||||
print(
|
||||
"Warning, no --hf_token used, needs to be saved in environment variable, otherwise will throw error loading diarization model..."
|
||||
)
|
||||
tmp_results = results
|
||||
print(">>Performing diarization...")
|
||||
print(">>Using model:", diarize_model_name)
|
||||
results = []
|
||||
diarize_model = DiarizationPipeline(use_auth_token=hf_token)
|
||||
diarize_model = DiarizationPipeline(model_name=diarize_model_name, use_auth_token=hf_token, device=device)
|
||||
for result, input_audio_path in tmp_results:
|
||||
diarize_segments = diarize_model(input_audio_path, min_speakers=min_speakers, max_speakers=max_speakers)
|
||||
results_segments, word_segments = assign_word_speakers(diarize_segments, result["segments"])
|
||||
result = {"segments": results_segments, "word_segments": word_segments}
|
||||
results.append((result, input_audio_path))
|
||||
diarize_result = diarize_model(
|
||||
input_audio_path,
|
||||
min_speakers=min_speakers,
|
||||
max_speakers=max_speakers,
|
||||
return_embeddings=return_speaker_embeddings
|
||||
)
|
||||
|
||||
if return_speaker_embeddings:
|
||||
diarize_segments, speaker_embeddings = diarize_result
|
||||
else:
|
||||
diarize_segments = diarize_result
|
||||
speaker_embeddings = None
|
||||
|
||||
result = assign_word_speakers(diarize_segments, result, speaker_embeddings)
|
||||
results.append((result, input_audio_path))
|
||||
# >> Write
|
||||
for result, audio_path in results:
|
||||
# Remove pandas dataframes from result so that
|
||||
# we can serialize the result with json
|
||||
for seg in result["segments"]:
|
||||
seg.pop("word-segments", None)
|
||||
seg.pop("char-segments", None)
|
||||
|
||||
result["language"] = align_language
|
||||
writer(result, audio_path, writer_args)
|
||||
|
||||
if __name__ == "__main__":
|
||||
cli()
|
69
whisperx/types.py
Normal file
69
whisperx/types.py
Normal file
@ -0,0 +1,69 @@
|
||||
from typing import TypedDict, Optional, List, Tuple
|
||||
|
||||
|
||||
class SingleWordSegment(TypedDict):
|
||||
"""
|
||||
A single word of a speech.
|
||||
"""
|
||||
word: str
|
||||
start: float
|
||||
end: float
|
||||
score: float
|
||||
|
||||
class SingleCharSegment(TypedDict):
|
||||
"""
|
||||
A single char of a speech.
|
||||
"""
|
||||
char: str
|
||||
start: float
|
||||
end: float
|
||||
score: float
|
||||
|
||||
|
||||
class SingleSegment(TypedDict):
|
||||
"""
|
||||
A single segment (up to multiple sentences) of a speech.
|
||||
"""
|
||||
|
||||
start: float
|
||||
end: float
|
||||
text: str
|
||||
|
||||
|
||||
class SegmentData(TypedDict):
|
||||
"""
|
||||
Temporary processing data used during alignment.
|
||||
Contains cleaned and preprocessed data for each segment.
|
||||
"""
|
||||
clean_char: List[str] # Cleaned characters that exist in model dictionary
|
||||
clean_cdx: List[int] # Original indices of cleaned characters
|
||||
clean_wdx: List[int] # Indices of words containing valid characters
|
||||
sentence_spans: List[Tuple[int, int]] # Start and end indices of sentences
|
||||
|
||||
|
||||
class SingleAlignedSegment(TypedDict):
|
||||
"""
|
||||
A single segment (up to multiple sentences) of a speech with word alignment.
|
||||
"""
|
||||
|
||||
start: float
|
||||
end: float
|
||||
text: str
|
||||
words: List[SingleWordSegment]
|
||||
chars: Optional[List[SingleCharSegment]]
|
||||
|
||||
|
||||
class TranscriptionResult(TypedDict):
|
||||
"""
|
||||
A list of segments and word segments of a speech.
|
||||
"""
|
||||
segments: List[SingleSegment]
|
||||
language: str
|
||||
|
||||
|
||||
class AlignedTranscriptionResult(TypedDict):
|
||||
"""
|
||||
A list of segments and word segments of a speech.
|
||||
"""
|
||||
segments: List[SingleAlignedSegment]
|
||||
word_segments: List[SingleWordSegment]
|
@ -105,6 +105,7 @@ LANGUAGES = {
|
||||
"ba": "bashkir",
|
||||
"jw": "javanese",
|
||||
"su": "sundanese",
|
||||
"yue": "cantonese",
|
||||
}
|
||||
|
||||
# language code lookup by name, with a few language aliases
|
||||
@ -123,6 +124,7 @@ TO_LANGUAGE_CODE = {
|
||||
"castilian": "es",
|
||||
}
|
||||
|
||||
LANGUAGES_WITHOUT_SPACES = ["ja", "zh"]
|
||||
|
||||
system_encoding = sys.getdefaultencoding()
|
||||
|
||||
@ -212,7 +214,12 @@ class WriteTXT(ResultWriter):
|
||||
|
||||
def write_result(self, result: dict, file: TextIO, options: dict):
|
||||
for segment in result["segments"]:
|
||||
print(segment["text"].strip(), file=file, flush=True)
|
||||
speaker = segment.get("speaker")
|
||||
text = segment["text"].strip()
|
||||
if speaker is not None:
|
||||
print(f"[{speaker}]: {text}", file=file, flush=True)
|
||||
else:
|
||||
print(text, file=file, flush=True)
|
||||
|
||||
|
||||
class SubtitlesWriter(ResultWriter):
|
||||
@ -226,16 +233,24 @@ class SubtitlesWriter(ResultWriter):
|
||||
max_line_width = 1000 if raw_max_line_width is None else raw_max_line_width
|
||||
preserve_segments = max_line_count is None or raw_max_line_width is None
|
||||
|
||||
if len(result["segments"]) == 0:
|
||||
return
|
||||
|
||||
def iterate_subtitles():
|
||||
line_len = 0
|
||||
line_count = 1
|
||||
# the next subtitle to yield (a list of word timings with whitespace)
|
||||
subtitle: list[dict] = []
|
||||
last = result["segments"][0]["words"][0]["start"]
|
||||
times: list[tuple] = []
|
||||
last = result["segments"][0]["start"]
|
||||
for segment in result["segments"]:
|
||||
for i, original_timing in enumerate(segment["words"]):
|
||||
timing = original_timing.copy()
|
||||
long_pause = not preserve_segments and timing["start"] - last > 3.0
|
||||
long_pause = not preserve_segments
|
||||
if "start" in timing:
|
||||
long_pause = long_pause and timing["start"] - last > 3.0
|
||||
else:
|
||||
long_pause = False
|
||||
has_room = line_len + len(timing["word"]) <= max_line_width
|
||||
seg_break = i == 0 and len(subtitle) > 0 and preserve_segments
|
||||
if line_len > 0 and has_room and not long_pause and not seg_break:
|
||||
@ -251,8 +266,9 @@ class SubtitlesWriter(ResultWriter):
|
||||
or seg_break
|
||||
):
|
||||
# subtitle break
|
||||
yield subtitle
|
||||
yield subtitle, times
|
||||
subtitle = []
|
||||
times = []
|
||||
line_count = 1
|
||||
elif line_len > 0:
|
||||
# line break
|
||||
@ -260,40 +276,56 @@ class SubtitlesWriter(ResultWriter):
|
||||
timing["word"] = "\n" + timing["word"]
|
||||
line_len = len(timing["word"].strip())
|
||||
subtitle.append(timing)
|
||||
last = timing["start"]
|
||||
times.append((segment["start"], segment["end"], segment.get("speaker")))
|
||||
if "start" in timing:
|
||||
last = timing["start"]
|
||||
if len(subtitle) > 0:
|
||||
yield subtitle
|
||||
yield subtitle, times
|
||||
|
||||
if "words" in result["segments"][0]:
|
||||
for subtitle in iterate_subtitles():
|
||||
subtitle_start = self.format_timestamp(subtitle[0]["start"])
|
||||
subtitle_end = self.format_timestamp(subtitle[-1]["end"])
|
||||
subtitle_text = "".join([word["word"] for word in subtitle])
|
||||
if highlight_words:
|
||||
for subtitle, _ in iterate_subtitles():
|
||||
sstart, ssend, speaker = _[0]
|
||||
subtitle_start = self.format_timestamp(sstart)
|
||||
subtitle_end = self.format_timestamp(ssend)
|
||||
if result["language"] in LANGUAGES_WITHOUT_SPACES:
|
||||
subtitle_text = "".join([word["word"] for word in subtitle])
|
||||
else:
|
||||
subtitle_text = " ".join([word["word"] for word in subtitle])
|
||||
has_timing = any(["start" in word for word in subtitle])
|
||||
|
||||
# add [$SPEAKER_ID]: to each subtitle if speaker is available
|
||||
prefix = ""
|
||||
if speaker is not None:
|
||||
prefix = f"[{speaker}]: "
|
||||
|
||||
if highlight_words and has_timing:
|
||||
last = subtitle_start
|
||||
all_words = [timing["word"] for timing in subtitle]
|
||||
for i, this_word in enumerate(subtitle):
|
||||
start = self.format_timestamp(this_word["start"])
|
||||
end = self.format_timestamp(this_word["end"])
|
||||
if last != start:
|
||||
yield last, start, subtitle_text
|
||||
if "start" in this_word:
|
||||
start = self.format_timestamp(this_word["start"])
|
||||
end = self.format_timestamp(this_word["end"])
|
||||
if last != start:
|
||||
yield last, start, prefix + subtitle_text
|
||||
|
||||
yield start, end, "".join(
|
||||
[
|
||||
re.sub(r"^(\s*)(.*)$", r"\1<u>\2</u>", word)
|
||||
if j == i
|
||||
else word
|
||||
for j, word in enumerate(all_words)
|
||||
]
|
||||
)
|
||||
last = end
|
||||
yield start, end, prefix + " ".join(
|
||||
[
|
||||
re.sub(r"^(\s*)(.*)$", r"\1<u>\2</u>", word)
|
||||
if j == i
|
||||
else word
|
||||
for j, word in enumerate(all_words)
|
||||
]
|
||||
)
|
||||
last = end
|
||||
else:
|
||||
yield subtitle_start, subtitle_end, subtitle_text
|
||||
yield subtitle_start, subtitle_end, prefix + subtitle_text
|
||||
else:
|
||||
for segment in result["segments"]:
|
||||
segment_start = self.format_timestamp(segment["start"])
|
||||
segment_end = self.format_timestamp(segment["end"])
|
||||
segment_text = segment["text"].strip().replace("-->", "->")
|
||||
if "speaker" in segment:
|
||||
segment_text = f"[{segment['speaker']}]: {segment_text}"
|
||||
yield segment_start, segment_end, segment_text
|
||||
|
||||
def format_timestamp(self, seconds: float):
|
||||
@ -346,12 +378,34 @@ class WriteTSV(ResultWriter):
|
||||
print(round(1000 * segment["end"]), file=file, end="\t")
|
||||
print(segment["text"].strip().replace("\t", " "), file=file, flush=True)
|
||||
|
||||
class WriteAudacity(ResultWriter):
|
||||
"""
|
||||
Write a transcript to a text file that audacity can import as labels.
|
||||
The extension used is "aud" to distinguish it from the txt file produced by WriteTXT.
|
||||
Yet this is not an audacity project but only a label file!
|
||||
|
||||
Please note : Audacity uses seconds in timestamps not ms!
|
||||
Also there is no header expected.
|
||||
|
||||
If speaker is provided it is prepended to the text between double square brackets [[]].
|
||||
"""
|
||||
|
||||
extension: str = "aud"
|
||||
|
||||
def write_result(self, result: dict, file: TextIO, options: dict):
|
||||
ARROW = " "
|
||||
for segment in result["segments"]:
|
||||
print(segment["start"], file=file, end=ARROW)
|
||||
print(segment["end"], file=file, end=ARROW)
|
||||
print( ( ("[[" + segment["speaker"] + "]]") if "speaker" in segment else "") + segment["text"].strip().replace("\t", " "), file=file, flush=True)
|
||||
|
||||
|
||||
|
||||
class WriteJSON(ResultWriter):
|
||||
extension: str = "json"
|
||||
|
||||
def write_result(self, result: dict, file: TextIO, options: dict):
|
||||
json.dump(result, file)
|
||||
json.dump(result, file, ensure_ascii=False)
|
||||
|
||||
|
||||
def get_writer(
|
||||
@ -364,6 +418,9 @@ def get_writer(
|
||||
"tsv": WriteTSV,
|
||||
"json": WriteJSON,
|
||||
}
|
||||
optional_writers = {
|
||||
"aud": WriteAudacity,
|
||||
}
|
||||
|
||||
if output_format == "all":
|
||||
all_writers = [writer(output_dir) for writer in writers.values()]
|
||||
@ -374,10 +431,12 @@ def get_writer(
|
||||
|
||||
return write_all
|
||||
|
||||
if output_format in optional_writers:
|
||||
return optional_writers[output_format](output_dir)
|
||||
return writers[output_format](output_dir)
|
||||
|
||||
def interpolate_nans(x, method='nearest'):
|
||||
if x.notnull().sum() > 1:
|
||||
return x.interpolate(method=method).ffill().bfill()
|
||||
else:
|
||||
return x.ffill().bfill()
|
||||
return x.ffill().bfill()
|
||||
|
3
whisperx/vads/__init__.py
Normal file
3
whisperx/vads/__init__.py
Normal file
@ -0,0 +1,3 @@
|
||||
from whisperx.vads.pyannote import Pyannote as Pyannote
|
||||
from whisperx.vads.silero import Silero as Silero
|
||||
from whisperx.vads.vad import Vad as Vad
|
@ -1,55 +1,44 @@
|
||||
import hashlib
|
||||
import os
|
||||
import urllib
|
||||
from typing import Callable, Optional, Text, Union
|
||||
from typing import Callable, Text, Union
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
from pyannote.audio import Model
|
||||
from pyannote.audio.core.io import AudioFile
|
||||
from pyannote.audio.pipelines import VoiceActivityDetection
|
||||
from pyannote.audio.pipelines.utils import PipelineModel
|
||||
from pyannote.core import Annotation, Segment, SlidingWindowFeature
|
||||
from tqdm import tqdm
|
||||
from pyannote.core import Annotation, SlidingWindowFeature
|
||||
from pyannote.core import Segment
|
||||
|
||||
from .diarize import Segment as SegmentX
|
||||
from whisperx.diarize import Segment as SegmentX
|
||||
from whisperx.vads.vad import Vad
|
||||
|
||||
VAD_SEGMENTATION_URL = "https://whisperx.s3.eu-west-2.amazonaws.com/model_weights/segmentation/0b5b3216d60a2d32fc086b47ea8c67589aaeb26b7e07fcbe620d6d0b83e209ea/pytorch_model.bin"
|
||||
|
||||
def load_vad_model(device, vad_onset=0.500, vad_offset=0.363, use_auth_token=None, model_fp=None):
|
||||
model_dir = torch.hub._get_torch_home()
|
||||
|
||||
main_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
os.makedirs(model_dir, exist_ok = True)
|
||||
if model_fp is None:
|
||||
model_fp = os.path.join(model_dir, "whisperx-vad-segmentation.bin")
|
||||
# Dynamically resolve the path to the model file
|
||||
model_fp = os.path.join(main_dir, "assets", "pytorch_model.bin")
|
||||
model_fp = os.path.abspath(model_fp) # Ensure the path is absolute
|
||||
else:
|
||||
model_fp = os.path.abspath(model_fp) # Ensure any provided path is absolute
|
||||
|
||||
# Check if the resolved model file exists
|
||||
if not os.path.exists(model_fp):
|
||||
raise FileNotFoundError(f"Model file not found at {model_fp}")
|
||||
|
||||
if os.path.exists(model_fp) and not os.path.isfile(model_fp):
|
||||
raise RuntimeError(f"{model_fp} exists and is not a regular file")
|
||||
|
||||
if not os.path.isfile(model_fp):
|
||||
with urllib.request.urlopen(VAD_SEGMENTATION_URL) as source, open(model_fp, "wb") as output:
|
||||
with tqdm(
|
||||
total=int(source.info().get("Content-Length")),
|
||||
ncols=80,
|
||||
unit="iB",
|
||||
unit_scale=True,
|
||||
unit_divisor=1024,
|
||||
) as loop:
|
||||
while True:
|
||||
buffer = source.read(8192)
|
||||
if not buffer:
|
||||
break
|
||||
|
||||
output.write(buffer)
|
||||
loop.update(len(buffer))
|
||||
|
||||
model_bytes = open(model_fp, "rb").read()
|
||||
if hashlib.sha256(model_bytes).hexdigest() != VAD_SEGMENTATION_URL.split('/')[-2]:
|
||||
raise RuntimeError(
|
||||
"Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model."
|
||||
)
|
||||
|
||||
vad_model = Model.from_pretrained(model_fp, use_auth_token=use_auth_token)
|
||||
hyperparameters = {"onset": vad_onset,
|
||||
hyperparameters = {"onset": vad_onset,
|
||||
"offset": vad_offset,
|
||||
"min_duration_on": 0.1,
|
||||
"min_duration_off": 0.1}
|
||||
@ -85,21 +74,21 @@ class Binarize:
|
||||
Gregory Gelly and Jean-Luc Gauvain. "Minimum Word Error Training of
|
||||
RNN-based Voice Activity Detection", InterSpeech 2015.
|
||||
|
||||
Modified by Max Bain to include WhisperX's min-cut operation
|
||||
Modified by Max Bain to include WhisperX's min-cut operation
|
||||
https://arxiv.org/abs/2303.00747
|
||||
|
||||
|
||||
Pyannote-audio
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
onset: float = 0.5,
|
||||
offset: Optional[float] = None,
|
||||
min_duration_on: float = 0.0,
|
||||
min_duration_off: float = 0.0,
|
||||
pad_onset: float = 0.0,
|
||||
pad_offset: float = 0.0,
|
||||
max_duration: float = float('inf')
|
||||
self,
|
||||
onset: float = 0.5,
|
||||
offset: Optional[float] = None,
|
||||
min_duration_on: float = 0.0,
|
||||
min_duration_off: float = 0.0,
|
||||
pad_onset: float = 0.0,
|
||||
pad_offset: float = 0.0,
|
||||
max_duration: float = float('inf')
|
||||
):
|
||||
|
||||
super().__init__()
|
||||
@ -142,13 +131,12 @@ class Binarize:
|
||||
is_active = k_scores[0] > self.onset
|
||||
curr_scores = [k_scores[0]]
|
||||
curr_timestamps = [start]
|
||||
t = start
|
||||
for t, y in zip(timestamps[1:], k_scores[1:]):
|
||||
# currently active
|
||||
if is_active:
|
||||
if is_active:
|
||||
curr_duration = t - start
|
||||
if curr_duration > self.max_duration:
|
||||
# if curr_duration > 15:
|
||||
# import pdb; pdb.set_trace()
|
||||
search_after = len(curr_scores) // 2
|
||||
# divide segment
|
||||
min_score_div_idx = search_after + np.argmin(curr_scores[search_after:])
|
||||
@ -156,8 +144,8 @@ class Binarize:
|
||||
region = Segment(start - self.pad_onset, min_score_t + self.pad_offset)
|
||||
active[region, k] = label
|
||||
start = curr_timestamps[min_score_div_idx]
|
||||
curr_scores = curr_scores[min_score_div_idx+1:]
|
||||
curr_timestamps = curr_timestamps[min_score_div_idx+1:]
|
||||
curr_scores = curr_scores[min_score_div_idx + 1:]
|
||||
curr_timestamps = curr_timestamps[min_score_div_idx + 1:]
|
||||
# switching from active to inactive
|
||||
elif y < self.offset:
|
||||
region = Segment(start - self.pad_onset, t + self.pad_offset)
|
||||
@ -166,14 +154,14 @@ class Binarize:
|
||||
is_active = False
|
||||
curr_scores = []
|
||||
curr_timestamps = []
|
||||
curr_scores.append(y)
|
||||
curr_timestamps.append(t)
|
||||
# currently inactive
|
||||
else:
|
||||
# switching from inactive to active
|
||||
if y > self.onset:
|
||||
start = t
|
||||
is_active = True
|
||||
curr_scores.append(y)
|
||||
curr_timestamps.append(t)
|
||||
|
||||
# if active at the end, add final region
|
||||
if is_active:
|
||||
@ -198,11 +186,11 @@ class Binarize:
|
||||
|
||||
class VoiceActivitySegmentation(VoiceActivityDetection):
|
||||
def __init__(
|
||||
self,
|
||||
segmentation: PipelineModel = "pyannote/segmentation",
|
||||
fscore: bool = False,
|
||||
use_auth_token: Union[Text, None] = None,
|
||||
**inference_kwargs,
|
||||
self,
|
||||
segmentation: PipelineModel = "pyannote/segmentation",
|
||||
fscore: bool = False,
|
||||
use_auth_token: Union[Text, None] = None,
|
||||
**inference_kwargs,
|
||||
):
|
||||
|
||||
super().__init__(segmentation=segmentation, fscore=fscore, use_auth_token=use_auth_token, **inference_kwargs)
|
||||
@ -241,67 +229,35 @@ class VoiceActivitySegmentation(VoiceActivityDetection):
|
||||
return segmentations
|
||||
|
||||
|
||||
def merge_vad(vad_arr, pad_onset=0.0, pad_offset=0.0, min_duration_off=0.0, min_duration_on=0.0):
|
||||
class Pyannote(Vad):
|
||||
|
||||
active = Annotation()
|
||||
for k, vad_t in enumerate(vad_arr):
|
||||
region = Segment(vad_t[0] - pad_onset, vad_t[1] + pad_offset)
|
||||
active[region, k] = 1
|
||||
def __init__(self, device, use_auth_token=None, model_fp=None, **kwargs):
|
||||
print(">>Performing voice activity detection using Pyannote...")
|
||||
super().__init__(kwargs['vad_onset'])
|
||||
self.vad_pipeline = load_vad_model(device, use_auth_token=use_auth_token, model_fp=model_fp)
|
||||
|
||||
def __call__(self, audio: AudioFile, **kwargs):
|
||||
return self.vad_pipeline(audio)
|
||||
|
||||
if pad_offset > 0.0 or pad_onset > 0.0 or min_duration_off > 0.0:
|
||||
active = active.support(collar=min_duration_off)
|
||||
|
||||
# remove tracks shorter than min_duration_on
|
||||
if min_duration_on > 0:
|
||||
for segment, track in list(active.itertracks()):
|
||||
if segment.duration < min_duration_on:
|
||||
del active[segment, track]
|
||||
|
||||
active = active.for_json()
|
||||
active_segs = pd.DataFrame([x['segment'] for x in active['content']])
|
||||
return active_segs
|
||||
@staticmethod
|
||||
def preprocess_audio(audio):
|
||||
return torch.from_numpy(audio).unsqueeze(0)
|
||||
|
||||
def merge_chunks(segments, chunk_size):
|
||||
"""
|
||||
Merge operation described in paper
|
||||
"""
|
||||
curr_end = 0
|
||||
merged_segments = []
|
||||
seg_idxs = []
|
||||
speaker_idxs = []
|
||||
@staticmethod
|
||||
def merge_chunks(segments,
|
||||
chunk_size,
|
||||
onset: float = 0.5,
|
||||
offset: Optional[float] = None,
|
||||
):
|
||||
assert chunk_size > 0
|
||||
binarize = Binarize(max_duration=chunk_size, onset=onset, offset=offset)
|
||||
segments = binarize(segments)
|
||||
segments_list = []
|
||||
for speech_turn in segments.get_timeline():
|
||||
segments_list.append(SegmentX(speech_turn.start, speech_turn.end, "UNKNOWN"))
|
||||
|
||||
assert chunk_size > 0
|
||||
binarize = Binarize(max_duration=chunk_size)
|
||||
segments = binarize(segments)
|
||||
segments_list = []
|
||||
for speech_turn in segments.get_timeline():
|
||||
segments_list.append(SegmentX(speech_turn.start, speech_turn.end, "UNKNOWN"))
|
||||
|
||||
if len(segments_list) == 0:
|
||||
print("No active speech found in audio")
|
||||
return []
|
||||
# assert segments_list, "segments_list is empty."
|
||||
# Make sur the starting point is the start of the segment.
|
||||
curr_start = segments_list[0].start
|
||||
|
||||
for seg in segments_list:
|
||||
if seg.end - curr_start > chunk_size and curr_end-curr_start > 0:
|
||||
merged_segments.append({
|
||||
"start": curr_start,
|
||||
"end": curr_end,
|
||||
"segments": seg_idxs,
|
||||
})
|
||||
curr_start = seg.start
|
||||
seg_idxs = []
|
||||
speaker_idxs = []
|
||||
curr_end = seg.end
|
||||
seg_idxs.append((seg.start, seg.end))
|
||||
speaker_idxs.append(seg.speaker)
|
||||
# add final
|
||||
merged_segments.append({
|
||||
"start": curr_start,
|
||||
"end": curr_end,
|
||||
"segments": seg_idxs,
|
||||
})
|
||||
return merged_segments
|
||||
if len(segments_list) == 0:
|
||||
print("No active speech found in audio")
|
||||
return []
|
||||
assert segments_list, "segments_list is empty."
|
||||
return Vad.merge_chunks(segments_list, chunk_size, onset, offset)
|
66
whisperx/vads/silero.py
Normal file
66
whisperx/vads/silero.py
Normal file
@ -0,0 +1,66 @@
|
||||
from io import IOBase
|
||||
from pathlib import Path
|
||||
from typing import Mapping, Text
|
||||
from typing import Optional
|
||||
from typing import Union
|
||||
|
||||
import torch
|
||||
|
||||
from whisperx.diarize import Segment as SegmentX
|
||||
from whisperx.vads.vad import Vad
|
||||
|
||||
AudioFile = Union[Text, Path, IOBase, Mapping]
|
||||
|
||||
|
||||
class Silero(Vad):
|
||||
# check again default values
|
||||
def __init__(self, **kwargs):
|
||||
print(">>Performing voice activity detection using Silero...")
|
||||
super().__init__(kwargs['vad_onset'])
|
||||
|
||||
self.vad_onset = kwargs['vad_onset']
|
||||
self.chunk_size = kwargs['chunk_size']
|
||||
self.vad_pipeline, vad_utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',
|
||||
model='silero_vad',
|
||||
force_reload=False,
|
||||
onnx=False,
|
||||
trust_repo=True)
|
||||
(self.get_speech_timestamps, _, self.read_audio, _, _) = vad_utils
|
||||
|
||||
def __call__(self, audio: AudioFile, **kwargs):
|
||||
"""use silero to get segments of speech"""
|
||||
# Only accept 16000 Hz for now.
|
||||
# Note: Silero models support both 8000 and 16000 Hz. Although other values are not directly supported,
|
||||
# multiples of 16000 (e.g. 32000 or 48000) are cast to 16000 inside of the JIT model!
|
||||
sample_rate = audio["sample_rate"]
|
||||
if sample_rate != 16000:
|
||||
raise ValueError("Only 16000Hz sample rate is allowed")
|
||||
|
||||
timestamps = self.get_speech_timestamps(audio["waveform"],
|
||||
model=self.vad_pipeline,
|
||||
sampling_rate=sample_rate,
|
||||
max_speech_duration_s=self.chunk_size,
|
||||
threshold=self.vad_onset
|
||||
# min_silence_duration_ms = self.min_duration_off/1000
|
||||
# min_speech_duration_ms = self.min_duration_on/1000
|
||||
# ...
|
||||
# See silero documentation for full option list
|
||||
)
|
||||
return [SegmentX(i['start'] / sample_rate, i['end'] / sample_rate, "UNKNOWN") for i in timestamps]
|
||||
|
||||
@staticmethod
|
||||
def preprocess_audio(audio):
|
||||
return audio
|
||||
|
||||
@staticmethod
|
||||
def merge_chunks(segments_list,
|
||||
chunk_size,
|
||||
onset: float = 0.5,
|
||||
offset: Optional[float] = None,
|
||||
):
|
||||
assert chunk_size > 0
|
||||
if len(segments_list) == 0:
|
||||
print("No active speech found in audio")
|
||||
return []
|
||||
assert segments_list, "segments_list is empty."
|
||||
return Vad.merge_chunks(segments_list, chunk_size, onset, offset)
|
74
whisperx/vads/vad.py
Normal file
74
whisperx/vads/vad.py
Normal file
@ -0,0 +1,74 @@
|
||||
from typing import Optional
|
||||
|
||||
import pandas as pd
|
||||
from pyannote.core import Annotation, Segment
|
||||
|
||||
|
||||
class Vad:
|
||||
def __init__(self, vad_onset):
|
||||
if not (0 < vad_onset < 1):
|
||||
raise ValueError(
|
||||
"vad_onset is a decimal value between 0 and 1."
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def preprocess_audio(audio):
|
||||
pass
|
||||
|
||||
# keep merge_chunks as static so it can be also used by manually assigned vad_model (see 'load_model')
|
||||
@staticmethod
|
||||
def merge_chunks(segments,
|
||||
chunk_size,
|
||||
onset: float,
|
||||
offset: Optional[float]):
|
||||
"""
|
||||
Merge operation described in paper
|
||||
"""
|
||||
curr_end = 0
|
||||
merged_segments = []
|
||||
seg_idxs: list[tuple]= []
|
||||
speaker_idxs: list[Optional[str]] = []
|
||||
|
||||
curr_start = segments[0].start
|
||||
for seg in segments:
|
||||
if seg.end - curr_start > chunk_size and curr_end - curr_start > 0:
|
||||
merged_segments.append({
|
||||
"start": curr_start,
|
||||
"end": curr_end,
|
||||
"segments": seg_idxs,
|
||||
})
|
||||
curr_start = seg.start
|
||||
seg_idxs = []
|
||||
speaker_idxs = []
|
||||
curr_end = seg.end
|
||||
seg_idxs.append((seg.start, seg.end))
|
||||
speaker_idxs.append(seg.speaker)
|
||||
# add final
|
||||
merged_segments.append({
|
||||
"start": curr_start,
|
||||
"end": curr_end,
|
||||
"segments": seg_idxs,
|
||||
})
|
||||
|
||||
return merged_segments
|
||||
|
||||
# Unused function
|
||||
@staticmethod
|
||||
def merge_vad(vad_arr, pad_onset=0.0, pad_offset=0.0, min_duration_off=0.0, min_duration_on=0.0):
|
||||
active = Annotation()
|
||||
for k, vad_t in enumerate(vad_arr):
|
||||
region = Segment(vad_t[0] - pad_onset, vad_t[1] + pad_offset)
|
||||
active[region, k] = 1
|
||||
|
||||
if pad_offset > 0.0 or pad_onset > 0.0 or min_duration_off > 0.0:
|
||||
active = active.support(collar=min_duration_off)
|
||||
|
||||
# remove tracks shorter than min_duration_on
|
||||
if min_duration_on > 0:
|
||||
for segment, track in list(active.itertracks()):
|
||||
if segment.duration < min_duration_on:
|
||||
del active[segment, track]
|
||||
|
||||
active = active.for_json()
|
||||
active_segs = pd.DataFrame([x['segment'] for x in active['content']])
|
||||
return active_segs
|
Reference in New Issue
Block a user