Mirror of https://github.com/m-bain/whisperX.git
synced 2025-07-01 18:17:27 -04:00
Compare commits
297 Commits
Author | SHA1 | Date | |
---|---|---|---|
f5b40b5366 | |||
ac0c8bd79a | |||
cd59f21d1a | |||
0aed874589 | |||
f10dbf6ab1 | |||
a7564c2ad6 | |||
e7712f496e | |||
8e53866704 | |||
3205436d58 | |||
8c58c54635 | |||
0d9807adc5 | |||
4db839018c | |||
f8d11df727 | |||
d2f0e53f71 | |||
7489ebf876 | |||
90256cc481 | |||
b41ebd4871 | |||
63bc1903c1 | |||
272714e07d | |||
44e8bf5bb6 | |||
7b3c9ce629 | |||
36d2622e27 | |||
8bfa12193b | |||
acbeba6057 | |||
fca563a782 | |||
2117909bf6 | |||
de0d8fe313 | |||
355f8e06f7 | |||
86e2b3ee74 | |||
70c639cdb5 | |||
235536e28d | |||
12604a48ea | |||
ffbc73664c | |||
289eadfc76 | |||
22a93f2932 | |||
1027367b79 | |||
5e54b872a9 | |||
6be02cccfa | |||
2f93e029c7 | |||
024bc8481b | |||
f286e7f3de | |||
73e644559d | |||
1ec527375a | |||
6695426a85 | |||
7a98456321 | |||
aaddb83aa5 | |||
c288f4812a | |||
4ebfb078c5 | |||
65b2332e13 | |||
69281f3a29 | |||
734084cdf6 | |||
9395b0de18 | |||
d57f9dc54c | |||
a90bd1ce3f | |||
79eb8fa53d | |||
10b05fc43f | |||
26d9b46888 | |||
9a8967f27e | |||
0f7f9f9f83 | |||
c60594fa3b | |||
4916192246 | |||
cbdac53e87 | |||
940a223219 | |||
a0eb31019b | |||
b08ad67a72 | |||
c18f9f979b | |||
948b3e368b | |||
e9ac5b63bc | |||
90b45459d9 | |||
81c4af96a6 | |||
1c6d9327bc | |||
0fdb55d317 | |||
51da22771f | |||
15ad5bf7df | |||
7fdbd21fe3 | |||
3ff625c561 | |||
7307306a9d | |||
3027cc32bc | |||
9e4b1b4c49 | |||
9b9e03c4cc | |||
19eff8e79a | |||
6f3bc5b7b8 | |||
9809336db6 | |||
a898b3ba94 | |||
c141074cbd | |||
a9e50ef0af | |||
161ae1f7ad | |||
a83ddbdf9b | |||
9e3a9e0e38 | |||
3f339f9515 | |||
9a9b6171e6 | |||
59b4d88d1d | |||
6f70aa6beb | |||
912920c591 | |||
58f00339af | |||
f2da2f858e | |||
78dcfaab51 | |||
d6562c26da | |||
c313f4dd5c | |||
bbaa2f0d1a | |||
e906be9688 | |||
fbbd07bece | |||
d8c9196346 | |||
2686f74bc9 | |||
8227807fa9 | |||
59962a70be | |||
06e30b2a25 | |||
6bb2f1cd48 | |||
f8cc46c6f7 | |||
942c336b8f | |||
8ae6416594 | |||
8540ff5985 | |||
5dfbfcbdc0 | |||
1c7b1a87da | |||
9f23739f90 | |||
19ab91c5a6 | |||
089cd5ab21 | |||
2b7ab95ad6 | |||
4553e0d4ed | |||
f865dfe710 | |||
4acbdd75be | |||
e9c507ce5d | |||
a5dca2cc65 | |||
8a8eeb33ee | |||
b4d7b1a422 | |||
5a16e59217 | |||
b4e4143e3b | |||
4b05198eed | |||
71a5281bde | |||
d97cdb7bcf | |||
20161935a1 | |||
1d7f8ccbf1 | |||
5756b0fb13 | |||
aaaa3de810 | |||
ba30365344 | |||
bd3aa03b6f | |||
f5c544ff90 | |||
7c2a9a8b7b | |||
9f41c49fe5 | |||
48d651e5ea | |||
4ece2369d7 | |||
52fbe5c26f | |||
6703d2774b | |||
a2af569838 | |||
0c7f32f55c | |||
6936dd6991 | |||
6b1100a919 | |||
d4a600b568 | |||
afd5ef1d58 | |||
dbeb8617f2 | |||
c6fe379d9e | |||
e9a6385d3c | |||
b522133340 | |||
49e0130e4e | |||
d4ac9531d9 | |||
66808f6147 | |||
b69956d725 | |||
a150df4310 | |||
02c0323777 | |||
14a7cab8eb | |||
acf31b754f | |||
4cdce3b927 | |||
a5356509b6 | |||
1001a055db | |||
051047bb25 | |||
c1b821a08d | |||
78e20a16a8 | |||
be07c13f75 | |||
8049dba2f7 | |||
d077abdbdf | |||
84423ca517 | |||
a22b8b009b | |||
79801167ac | |||
07fafa37b3 | |||
a0b6459c8b | |||
2a11ce3ef0 | |||
18abcf46ee | |||
652aa24919 | |||
b17908473d | |||
f137f31de6 | |||
e94b904308 | |||
ffd6167b26 | |||
4c7ce14fed | |||
0ae0d49d1d | |||
b1a98b78c9 | |||
c6d9e6cb67 | |||
31f5233949 | |||
2ca99ce909 | |||
15d9e08d3e | |||
15451d0f1c | |||
8c4a21b66d | |||
5223de2a41 | |||
f505702dc7 | |||
adf455a97c | |||
9647f60fca | |||
a8bfac6bef | |||
6d414e20e2 | |||
3c7b03935b | |||
eb771cf56d | |||
cc81ab7db7 | |||
ef965a03ed | |||
6f2ff16aad | |||
81b12af321 | |||
c1197c490e | |||
4e28492dbd | |||
6cb7267dc2 | |||
abbb66b58e | |||
ea7bb91a56 | |||
d2d840f06c | |||
0a1137e41c | |||
0767597bff | |||
cb3ed4ab9d | |||
65688208c9 | |||
72685d0398 | |||
1bb4839b0f | |||
4acb5b3abc | |||
14e593f60b | |||
66da4b3eb7 | |||
18d5fdc995 | |||
423667f00b | |||
1b092de19a | |||
69a52b00c7 | |||
9e3145cead | |||
577db33430 | |||
da6ed83dc9 | |||
7eb9692cb9 | |||
8de0e2af51 | |||
225f6b4d69 | |||
864976af23 | |||
9d736dca1c | |||
d87f6268d0 | |||
d80b98601b | |||
aa37509362 | |||
15b4c558c2 | |||
54504a2be8 | |||
8c0fee90d3 | |||
016f0293cd | |||
44daf50501 | |||
48e7caad77 | |||
8673064658 | |||
e6ecbaa68f | |||
e92325b7eb | |||
eb712f3999 | |||
30eff5a01f | |||
734ecc2844 | |||
512ab1acf9 | |||
befe2b242e | |||
f9c5ff9f08 | |||
d39c1b2319 | |||
b13778fefd | |||
076ff96eb2 | |||
0c84c26d92 | |||
d7f1d16f19 | |||
74a00eecd7 | |||
b026407fd9 | |||
a323cff654 | |||
93ed6cfa93 | |||
9797a67391 | |||
5a4382ae4d | |||
ec6a110cdf | |||
8d8c027a92 | |||
4cbd3030cc | |||
1c528d1a3c | |||
c65e7ba9b4 | |||
5a47f458ac | |||
f1032bb40a | |||
bc8a03881a | |||
42b4909bc0 | |||
bb15d6b68e | |||
23d405e1cf | |||
17e2f7f859 | |||
1d9d630fb9 | |||
9c042c2d28 | |||
a23f2aa3f7 | |||
7c5468116f | |||
a1c705b3a7 | |||
29a5e0b236 | |||
715435db42 | |||
1fc965bc1a | |||
74b98ebfaa | |||
53396adb21 | |||
63fb5fc46f | |||
d8a2b4ffc9 | |||
9ffb7e7a23 | |||
fd8f1003cf | |||
46b416296f | |||
7642390d0a | |||
8b05ad4dae | |||
5421f1d7ca | |||
91e959ec4f | |||
eabf35dff0 | |||
4919ad21fc | |||
b50aafb17b | |||
2efa136114 | |||
0b839f3f01 | |||
d31f6e0b8a | |||
c8404d9805 |
.github/workflows/build-and-release.yml (vendored, new file, 31 lines)
@@ -0,0 +1,31 @@
name: Build and release

on:
  release:
    types: [published]

jobs:
  build:
    runs-on: ubuntu-latest
    steps:
      - name: Checkout
        uses: actions/checkout@v4

      - name: Install uv
        uses: astral-sh/setup-uv@v5
        with:
          version: "0.5.14"
          python-version: "3.9"

      - name: Build package
        run: uv build

      - name: Release to Github
        uses: softprops/action-gh-release@v2
        with:
          files: dist/*.whl

      - name: Publish package to PyPi
        run: uv publish
        env:
          UV_PUBLISH_TOKEN: ${{ secrets.PYPI_API_TOKEN }}
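For context, the release job above publishes whatever `uv build` leaves in `dist/`. A minimal local sanity check of that wheel could look like the sketch below; the `dist/` glob and the expected wheel layout are assumptions for illustration, not part of the workflow.

```python
# Illustrative sketch, not part of the repository: inspect the wheel produced
# by `uv build` before it is attached to a release and published.
import glob
import zipfile

wheel_paths = glob.glob("dist/whisperx-*.whl")  # assumed output location of `uv build`
assert wheel_paths, "run `uv build` first"

with zipfile.ZipFile(wheel_paths[0]) as wheel:  # a .whl is a zip archive
    names = wheel.namelist()
    # a usable wheel should ship the package sources and its metadata
    assert any(name.startswith("whisperx/") for name in names)
    assert any(name.endswith("METADATA") for name in names)
    print(f"{wheel_paths[0]}: {len(names)} files, metadata present")
```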
.github/workflows/python-compatibility.yml (vendored, new file, 31 lines)
@@ -0,0 +1,31 @@
name: Python Compatibility Test

on:
  push:
    branches: [main]
  pull_request:
    branches: [main]
  workflow_dispatch: # Allows manual triggering from GitHub UI

jobs:
  test:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: ["3.9", "3.10", "3.11", "3.12"]

    steps:
      - uses: actions/checkout@v4

      - name: Install uv
        uses: astral-sh/setup-uv@v5
        with:
          version: "0.5.14"
          python-version: ${{ matrix.python-version }}

      - name: Install the project
        run: uv sync --all-extras

      - name: Test import
        run: |
          uv run python -c "import whisperx; print('Successfully imported whisperx')"
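The matrix above exercises Python 3.9 through 3.12 with a bare import smoke test. The same check can be run locally; a minimal sketch (assumes whisperx is already installed in the active environment, and takes the version bounds from the `requires-python = ">=3.9, <3.13"` setting in the pyproject.toml added later in this diff):

```python
# Local equivalent of the workflow's import smoke test (illustrative sketch).
import sys

# mirror pyproject.toml's requires-python = ">=3.9, <3.13"
assert (3, 9) <= sys.version_info[:2] < (3, 13), (
    f"Python {sys.version_info.major}.{sys.version_info.minor} is outside the supported range"
)

import whisperx  # fails loudly if the installation is broken

print("Successfully imported whisperx", whisperx.__name__)
```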
.gitignore (vendored, 172 lines changed)
@@ -1,3 +1,171 @@
-whisperx.egg-info/
-**/__pycache__/
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
 .ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# UV
+# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+#uv.lock
+
+# poetry
+# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+#poetry.lock
+
+# pdm
+# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+#pdm.lock
+# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+# in version control.
+# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
+.pdm.toml
+.pdm-python
+.pdm-build/
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# PyCharm
+# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+# and can be added to the global gitignore or merged into this file. For a more nuclear
+# option (not recommended) you can uncomment the following to ignore the entire idea folder.
+#.idea/
+
+# PyPI configuration file
+.pypirc
Dockerfile (deleted, 19 lines)
@@ -1,19 +0,0 @@
FROM pytorch/pytorch:1.11.0-cuda11.3-cudnn8-devel
RUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/3bf863cc.pub && \
    apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64/7fa2af80.pub
RUN apt-get update && \
    apt-get install -y wget && \
    wget -qO - https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/7fa2af80.pub | apt-key add - && \
    apt-get update && \
    apt-get install -y git && \
    apt-get install libsndfile1 -y && \
    apt-get clean

RUN pip install --upgrade pip
RUN pip install --upgrade setuptools
RUN pip install git+https://github.com/m-bain/whisperx.git
RUN pip install jupyter ipykernel
EXPOSE 8888
# Use external volume for data
ENV NVIDIA_VISIBLE_DEVICES 1
CMD ["jupyter", "notebook", "--ip=0.0.0.0", "--port=8888", "--NotebookApp.token=''","--NotebookApp.password=''", "--allow-root"]
LICENSE (33 lines changed)
@@ -1,27 +1,24 @@
-Copyright (c) 2022, Max Bain
-All rights reserved.
+BSD 2-Clause License
+
+Copyright (c) 2024, Max Bain

 Redistribution and use in source and binary forms, with or without
 modification, are permitted provided that the following conditions are met:
-1. Redistributions of source code must retain the above copyright
-notice, this list of conditions and the following disclaimer.
-2. Redistributions in binary form must reproduce the above copyright
-notice, this list of conditions and the following disclaimer in the
-documentation and/or other materials provided with the distribution.
-3. All advertising materials mentioning features or use of this software
-must display the following acknowledgement:
-This product includes software developed by Max Bain.
-4. Neither the name of Max Bain nor the
-names of its contributors may be used to endorse or promote products
-derived from this software without specific prior written permission.
-
-THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDER ''AS IS'' AND ANY
-EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
-WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+1. Redistributions of source code must retain the above copyright notice, this
+list of conditions and the following disclaimer.
+
+2. Redistributions in binary form must reproduce the above copyright notice,
+this list of conditions and the following disclaimer in the documentation
+and/or other materials provided with the distribution.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
-OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
-USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
MANIFEST.in
@@ -1,4 +1,3 @@
 include whisperx/assets/*
-include whisperx/assets/gpt2/*
-include whisperx/assets/multilingual/*
-include whisperx/normalizers/english.json
+include LICENSE
+include requirements.txt
README.md (127 lines changed)
@@ -23,7 +23,7 @@
 </p>


-<img width="1216" align="center" alt="whisperx-arch" src="figures/pipeline.png">
+<img width="1216" align="center" alt="whisperx-arch" src="https://raw.githubusercontent.com/m-bain/whisperX/refs/heads/main/figures/pipeline.png">


 <!-- <p align="left">Whisper-Based Automatic Speech Recognition (ASR) with improved timestamp accuracy + quality via forced phoneme alignment and voice-activity based batching for fast inference.</p> -->
@@ -32,12 +32,12 @@
 <!-- <h2 align="left", id="what-is-it">What is it 🔎</h2> -->


-This repository provides fast automatic speaker recognition (70x realtime with large-v2) with word-level timestamps and speaker diarization.
+This repository provides fast automatic speech recognition (70x realtime with large-v2) with word-level timestamps and speaker diarization.

 - ⚡️ Batched inference for 70x realtime transcription using whisper large-v2
 - 🪶 [faster-whisper](https://github.com/guillaumekln/faster-whisper) backend, requires <8GB gpu memory for large-v2 with beam_size=5
 - 🎯 Accurate word-level timestamps using wav2vec2 alignment
-- 👯♂️ Multispeaker ASR using speaker diarization from [pyannote-audio](https://github.com/pyannote/pyannote-audio) (labels each segment/word with speaker ID)
+- 👯♂️ Multispeaker ASR using speaker diarization from [pyannote-audio](https://github.com/pyannote/pyannote-audio) (speaker ID labels)
 - 🗣️ VAD preprocessing, reduces hallucination & batching with no WER degradation
@@ -52,59 +52,63 @@ This repository provides fast automatic speaker recognition (70x realtime with l

 **Speaker Diarization** is the process of partitioning an audio stream containing human speech into homogeneous segments according to the identity of each speaker.

-- v3 pre-release [this branch](https://github.com/m-bain/whisperX/tree/v3) *70x speed-up open-sourced. Using batched whisper with faster-whisper backend*!
-- v2 released, code cleanup, imports whisper library. VAD filtering is now turned on by default, as in the paper.
-- Paper drop🎓👨🏫! Please see our [ArxiV preprint](https://arxiv.org/abs/2303.00747) for benchmarking and details of WhisperX. We also introduce more efficient batch inference resulting in large-v2 with *60-70x REAL TIME speed (not provided in this repo).
-- VAD filtering: Voice Activity Detection (VAD) from [Pyannote.audio](https://huggingface.co/pyannote/voice-activity-detection) is used as a preprocessing step to remove reliance on whisper timestamps and only transcribe audio segments containing speech. add `--vad_filter True` flag, increases timestamp accuracy and robustness (requires more GPU mem due to 30s inputs in wav2vec2)
-- Character level timestamps (see `*.char.ass` file output)
-- Diarization (still in beta, add `--diarize`)
-
 <h2 align="left", id="highlights">New🚨</h2>

+- 1st place at [Ego4d transcription challenge](https://eval.ai/web/challenges/challenge-page/1637/leaderboard/3931/WER) 🏆
+- _WhisperX_ accepted at INTERSPEECH 2023
 - v3 transcript segment-per-sentence: using nltk sent_tokenize for better subtitlting & better diarization
 - v3 released, 70x speed-up open-sourced. Using batched whisper with [faster-whisper](https://github.com/guillaumekln/faster-whisper) backend!
 - v2 released, code cleanup, imports whisper library VAD filtering is now turned on by default, as in the paper.
 - Paper drop🎓👨🏫! Please see our [ArxiV preprint](https://arxiv.org/abs/2303.00747) for benchmarking and details of WhisperX. We also introduce more efficient batch inference resulting in large-v2 with *60-70x REAL TIME speed.

 <h2 align="left" id="setup">Setup ⚙️</h2>
-Tested for PyTorch 2.0, Python 3.10 (use other versions at your own risk!)

-GPU execution requires the NVIDIA libraries cuBLAS 11.x and cuDNN 8.x to be installed on the system. Please refer to the [CTranslate2 documentation](https://opennmt.net/CTranslate2/installation.html).
-
-### 1. Create Python3.10 environment
-
-`conda create --name whisperx python=3.10`
-
-`conda activate whisperx`
-
-### 2. Install PyTorch2.0, e.g. for Linux and Windows CUDA11.7:
-
-`pip3 install torch torchvision torchaudio`
-
-See other methods [here.](https://pytorch.org/get-started/locally/)
-
-### 3. Install this repo
-
-`pip install git+https://github.com/m-bain/whisperx.git@v3`
-
-If already installed, update package to most recent commit
-
-`pip install git+https://github.com/m-bain/whisperx.git@v3 --upgrade`
-
-If wishing to modify this package, clone and install in editable mode:
-```
-$ git clone https://github.com/m-bain/whisperX.git
-$ cd whisperX
-$ pip install -e .
-```
+### 1. Simple Installation (Recommended)
+
+The easiest way to install WhisperX is through PyPi:
+
+```bash
+pip install whisperx
+```
+
+Or if using [uvx](https://docs.astral.sh/uv/guides/tools/#running-tools):
+
+```bash
+uvx whisperx
+```
+
+### 2. Advanced Installation Options
+
+These installation methods are for developers or users with specific needs. If you're not sure, stick with the simple installation above.
+
+#### Option A: Install from GitHub
+
+To install directly from the GitHub repository:
+
+```bash
+uvx git+https://github.com/m-bain/whisperX.git
+```
+
+#### Option B: Developer Installation
+
+If you want to modify the code or contribute to the project:
+
+```bash
+git clone https://github.com/m-bain/whisperX.git
+cd whisperX
+uv sync --all-extras --dev
+```
+
+> **Note**: The development version may contain experimental features and bugs. Use the stable PyPI release for production environments.

 You may also need to install ffmpeg, rust etc. Follow openAI instructions here https://github.com/openai/whisper#setup.

 ### Speaker Diarization
-To **enable Speaker. Diarization**, include your Hugging Face access token that you can generate from [Here](https://huggingface.co/settings/tokens) after the `--hf_token` argument and accept the user agreement for the following models: [Segmentation](https://huggingface.co/pyannote/segmentation) , [Voice Activity Detection (VAD)](https://huggingface.co/pyannote/voice-activity-detection) , and [Speaker Diarization](https://huggingface.co/pyannote/speaker-diarization)
+
+To **enable Speaker Diarization**, include your Hugging Face access token (read) that you can generate from [Here](https://huggingface.co/settings/tokens) after the `--hf_token` argument and accept the user agreement for the following models: [Segmentation](https://huggingface.co/pyannote/segmentation-3.0) and [Speaker-Diarization-3.1](https://huggingface.co/pyannote/speaker-diarization-3.1) (if you choose to use Speaker-Diarization 2.x, follow requirements [here](https://huggingface.co/pyannote/speaker-diarization) instead.)
+
+> **Note**<br>
+> As of Oct 11, 2023, there is a known issue regarding slow performance with pyannote/Speaker-Diarization-3.0 in whisperX. It is due to dependency conflicts between faster-whisper and pyannote-audio 3.0.0. Please see [this issue](https://github.com/m-bain/whisperX/issues/499) for more details and potential workarounds.

 <h2 align="left" id="example">Usage 💬 (command line)</h2>
@@ -112,7 +116,7 @@ To **enable Speaker. Diarization**, include your Hugging Face access token that

 Run whisper on example segment (using default params, whisper small) add `--highlight_words True` to visualise word timings in the .srt file.

-whisperx examples/sample01.wav
+whisperx path/to/audio.wav

 Result using *WhisperX* with forced alignment to wav2vec2.0 large:
@@ -126,23 +130,27 @@ https://user-images.githubusercontent.com/36994049/207743923-b4f0d537-29ae-4be2-

 For increased timestamp accuracy, at the cost of higher gpu mem, use bigger models (bigger alignment model not found to be that helpful, see paper) e.g.

-whisperx examples/sample01.wav --model large-v2 --align_model WAV2VEC2_ASR_LARGE_LV60K_960H --batch_size 4
+whisperx path/to/audio.wav --model large-v2 --align_model WAV2VEC2_ASR_LARGE_LV60K_960H --batch_size 4

 To label the transcript with speaker ID's (set number of speakers if known e.g. `--min_speakers 2` `--max_speakers 2`):

-whisperx examples/sample01.wav --model large-v2 --diarize --highlight_words True
+whisperx path/to/audio.wav --model large-v2 --diarize --highlight_words True

+To run on CPU instead of GPU (and for running on Mac OS X):
+
+whisperx path/to/audio.wav --compute_type int8
+
 ### Other languages

-The phoneme ASR alignment model is *language-specific*, for tested languages these models are [automatically picked from torchaudio pipelines or huggingface](https://github.com/m-bain/whisperX/blob/e909f2f766b23b2000f2d95df41f9b844ac53e49/whisperx/transcribe.py#L22).
+The phoneme ASR alignment model is *language-specific*, for tested languages these models are [automatically picked from torchaudio pipelines or huggingface](https://github.com/m-bain/whisperX/blob/f2da2f858e99e4211fe4f64b5f2938b007827e17/whisperx/alignment.py#L24-L58).
 Just pass in the `--language` code, and use the whisper `--model large`.

-Currently default models provided for `{en, fr, de, es, it, ja, zh, nl, uk, pt}`. If the detected language is not in this list, you need to find a phoneme-based ASR model from [huggingface model hub](https://huggingface.co/models) and test it on your data.
+Currently default models provided for `{en, fr, de, es, it}` via torchaudio pipelines and many other languages via Hugging Face. Please find the list of currently supported languages under `DEFAULT_ALIGN_MODELS_HF` on [alignment.py](https://github.com/m-bain/whisperX/blob/main/whisperx/alignment.py). If the detected language is not in this list, you need to find a phoneme-based ASR model from [huggingface model hub](https://huggingface.co/models) and test it on your data.

 #### E.g. German
-whisperx --model large-v2 --language de examples/sample_de_01.wav
+whisperx --model large-v2 --language de path/to/audio.wav

 https://user-images.githubusercontent.com/36994049/208298811-e36002ba-3698-4731-97d4-0aebd07e0eb3.mov
@@ -163,6 +171,10 @@ compute_type = "float16" # change to "int8" if low on GPU mem (may reduce accura
 # 1. Transcribe with original whisper (batched)
 model = whisperx.load_model("large-v2", device, compute_type=compute_type)

+# save model to local path (optional)
+# model_dir = "/path/"
+# model = whisperx.load_model("large-v2", device, compute_type=compute_type, download_root=model_dir)
+
 audio = whisperx.load_audio(audio_file)
 result = model.transcribe(audio, batch_size=batch_size)
 print(result["segments"]) # before alignment
@@ -183,14 +195,21 @@ print(result["segments"]) # after alignment
 diarize_model = whisperx.DiarizationPipeline(use_auth_token=YOUR_HF_TOKEN, device=device)

 # add min/max number of speakers if known
-diarize_segments = diarize_model(input_audio_path)
-# diarize_model(input_audio_path, min_speakers=min_speakers, max_speakers=max_speakers)
+diarize_segments = diarize_model(audio)
+# diarize_model(audio, min_speakers=min_speakers, max_speakers=max_speakers)

-result = assign_word_speakers(diarize_segments, result)
+result = whisperx.assign_word_speakers(diarize_segments, result)
 print(diarize_segments)
 print(result["segments"]) # segments are now assigned speaker IDs
 ```

+## Demos 🚀
+
+[](https://replicate.com/victor-upmeet/whisperx)
+[](https://replicate.com/daanelson/whisperx)
+[](https://replicate.com/carnifexer/whisperx)
+
+If you don't have access to your own GPUs, use the links above to try out WhisperX.
+
 <h2 align="left" id="whisper-mod">Technical Details 👷♂️</h2>
@@ -203,14 +222,14 @@ To reduce GPU memory requirements, try any of the following (2. & 3. can affect

 Transcription differences from openai's whisper:
 1. Transcription without timestamps. To enable single pass batching, whisper inference is performed `--without_timestamps True`, this ensures 1 forward pass per sample in the batch. However, this can cause discrepancies the default whisper output.
-2. VAD-based segment transcription, unlike the buffered transcription of openai's. In Wthe WhisperX paper we show this reduces WER, and enables accurate batched inference
+2. VAD-based segment transcription, unlike the buffered transcription of openai's. In the WhisperX paper we show this reduces WER, and enables accurate batched inference
 3. `--condition_on_prev_text` is set to `False` by default (reduces hallucination)

 <h2 align="left" id="limitations">Limitations ⚠️</h2>

 - Transcript words which do not contain characters in the alignment models dictionary e.g. "2014." or "£13.60" cannot be aligned and therefore are not given a timing.
 - Overlapping speech is not handled particularly well by whisper nor whisperx
-- Diarization is far from perfect (working on this with custom model v4 -- see contact me).
+- Diarization is far from perfect
 - Language specific wav2vec2 model is needed
@@ -246,7 +265,7 @@ Bug finding and pull requests are also highly appreciated to keep this project g
 * [ ] Add benchmarking code (TEDLIUM for spd/WER & word segmentation)

-* [ ] Allow silero-vad as alternative VAD option
+* [x] Allow silero-vad as alternative VAD option

 * [ ] Improve diarization (word level). *Harder than first thought...*
@@ -254,7 +273,7 @@ Bug finding and pull requests are also highly appreciated to keep this project g
 <h2 align="left" id="contact">Contact/Support 📇</h2>

-Contact maxhbain@gmail.com for queries. WhisperX v4 development is underway with with siginificantly improved diarization. To support v4 and get early access, get in touch.
+Contact maxhbain@gmail.com for queries.

 <a href="https://www.buymeacoffee.com/maxhbain" target="_blank"><img src="https://cdn.buymeacoffee.com/buttons/default-orange.png" alt="Buy Me A Coffee" height="41" width="174"></a>
@@ -268,7 +287,9 @@ Borrows important alignment code from [PyTorch tutorial on forced alignment](htt
 And uses the wonderful pyannote VAD / Diarization https://github.com/pyannote/pyannote-audio

-Valuable VAD & Diarization Models from [pyannote audio][https://github.com/pyannote/pyannote-audio]
+Valuable VAD & Diarization Models from:
+- [pyannote audio][https://github.com/pyannote/pyannote-audio]
+- [silero vad][https://github.com/snakers4/silero-vad]

 Great backend from [faster-whisper](https://github.com/guillaumekln/faster-whisper) and [CTranslate2](https://github.com/OpenNMT/CTranslate2)
@@ -283,7 +304,7 @@ If you use this in your research, please cite the paper:
 @article{bain2022whisperx,
 title={WhisperX: Time-Accurate Speech Transcription of Long-Form Audio},
 author={Bain, Max and Huh, Jaesung and Han, Tengda and Zisserman, Andrew},
-journal={arXiv preprint, arXiv:2303.00747},
+journal={INTERSPEECH 2023},
 year={2023}
 }
 ```
@@ -1,91 +0,0 @@ (deleted Jupyter notebook)
The deleted notebook's single code cell (Python 3.8.12 kernel) contained:

import whisperx
import whisper

device = "cuda"
audio_file = "audio.mp3"

# transcribe with original whisper
model = whisper.load_model("large", device)
result = model.transcribe(audio_file)

print(result["segments"]) # before alignment

# load alignment model and metadata
model_a, metadata = whisperx.load_align_model(language_code=result["language"], device=device)

# align whisper output
result_aligned = whisperx.align(result["segments"], model_a, metadata, audio_file, device)

print(result_aligned["segments"]) # after alignment
print(result_aligned["word_segments"]) # after alignment

Its stored output was a torchvision "Failed to load image Python extension" warning followed by a CUDA OutOfMemoryError raised while loading the model: "CUDA out of memory. Tried to allocate 20.00 MiB (GPU 0; 7.79 GiB total capacity; 5.76 GiB already allocated; 59.19 MiB free; 6.06 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF". A second, empty code cell followed.
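The deleted notebook loaded the full openai-whisper `large` model and ran out of GPU memory. For reference, the batched API documented in the README changes above exposes lower-memory options; a minimal sketch using only parameters shown elsewhere in this diff (the audio path is a placeholder, and a working ffmpeg/CUDA setup is assumed):

```python
import whisperx

device = "cuda"
audio = whisperx.load_audio("audio.mp3")  # placeholder path

# int8 compute and a small batch size keep GPU memory down
# (see the README "Technical Details" notes in this diff)
model = whisperx.load_model("large-v2", device, compute_type="int8")
result = model.transcribe(audio, batch_size=4)
print(result["segments"])  # before alignment

model_a, metadata = whisperx.load_align_model(language_code=result["language"], device=device)
result = whisperx.align(result["segments"], model_a, metadata, audio, device)
print(result["segments"])  # after alignment
```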
pyproject.toml (new file, 36 lines)
@@ -0,0 +1,36 @@
[project]
urls = { repository = "https://github.com/m-bain/whisperx" }
authors = [{ name = "Max Bain" }]
name = "whisperx"
version = "3.3.3"
description = "Time-Accurate Automatic Speech Recognition using Whisper."
readme = "README.md"
requires-python = ">=3.9, <3.13"
license = { text = "BSD-2-Clause" }

dependencies = [
    "ctranslate2<4.5.0",
    "faster-whisper>=1.1.1",
    "nltk>=3.9.1",
    "numpy>=2.0.2",
    "onnxruntime>=1.19",
    "pandas>=2.2.3",
    "pyannote-audio>=3.3.2",
    "torch>=2.5.1",
    "torchaudio>=2.5.1",
    "transformers>=4.48.0",
]


[project.scripts]
whisperx = "whisperx.transcribe:cli"

[build-system]
requires = ["setuptools"]

[tool.setuptools]
include-package-data = true

[tool.setuptools.packages.find]
where = ["."]
include = ["whisperx*"]
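The `[project.scripts]` table above is what makes the `whisperx` command available after installation. A small sketch of how that entry point can be inspected at runtime (illustrative only, not part of this diff; assumes whisperx is installed and Python 3.10+ for the `group=` keyword):

```python
# Look up the console script declared as whisperx = "whisperx.transcribe:cli".
from importlib.metadata import entry_points

(script,) = [ep for ep in entry_points(group="console_scripts") if ep.name == "whisperx"]
print(script.value)   # -> "whisperx.transcribe:cli"
cli = script.load()   # the same callable the `whisperx` command runs
print(callable(cli))  # True
```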
requirements.txt (deleted, 8 lines)
@@ -1,8 +0,0 @@
torch==2.0.0
torchaudio==2.0.1
faster-whisper
transformers
ffmpeg-python==0.2.0
pandas
setuptools==65.6.3
nltk
setup.py (deleted, 28 lines)
@@ -1,28 +0,0 @@
import os

import pkg_resources
from setuptools import setup, find_packages

setup(
    name="whisperx",
    py_modules=["whisperx"],
    version="3.1.0",
    description="Time-Accurate Automatic Speech Recognition.",
    readme="README.md",
    python_requires=">=3.8",
    author="Max Bain",
    url="https://github.com/m-bain/whisperx",
    license="MIT",
    packages=find_packages(exclude=["tests*"]),
    install_requires=[
        str(r)
        for r in pkg_resources.parse_requirements(
            open(os.path.join(os.path.dirname(__file__), "requirements.txt"))
        )
    ] + ["pyannote.audio @ git+https://github.com/pyannote/pyannote-audio@11b56a137a578db9335efc00298f6ec1932e6317"],
    entry_points = {
        'console_scripts': ['whisperx=whisperx.transcribe:cli'],
    },
    include_package_data=True,
    extras_require={'dev': ['pytest']},
)
whisperx/SubtitlesProcessor.py (new file, 226 lines)
@@ -0,0 +1,226 @@
import math
from whisperx.conjunctions import get_conjunctions, get_comma


def normal_round(n):
    if n - math.floor(n) < 0.5:
        return math.floor(n)
    return math.ceil(n)


def format_timestamp(seconds: float, is_vtt: bool = False):

    assert seconds >= 0, "non-negative timestamp expected"
    milliseconds = round(seconds * 1000.0)

    hours = milliseconds // 3_600_000
    milliseconds -= hours * 3_600_000

    minutes = milliseconds // 60_000
    milliseconds -= minutes * 60_000

    seconds = milliseconds // 1_000
    milliseconds -= seconds * 1_000

    separator = '.' if is_vtt else ','

    hours_marker = f"{hours:02d}:"
    return (
        f"{hours_marker}{minutes:02d}:{seconds:02d}{separator}{milliseconds:03d}"
    )


class SubtitlesProcessor:
    def __init__(self, segments, lang, max_line_length = 45, min_char_length_splitter = 30, is_vtt = False):
        self.comma = get_comma(lang)
        self.conjunctions = set(get_conjunctions(lang))
        self.segments = segments
        self.lang = lang
        self.max_line_length = max_line_length
        self.min_char_length_splitter = min_char_length_splitter
        self.is_vtt = is_vtt
        complex_script_languages = ['th', 'lo', 'my', 'km', 'am', 'ko', 'ja', 'zh', 'ti', 'ta', 'te', 'kn', 'ml', 'hi', 'ne', 'mr', 'ar', 'fa', 'ur', 'ka']
        if self.lang in complex_script_languages:
            self.max_line_length = 30
            self.min_char_length_splitter = 20

    def estimate_timestamp_for_word(self, words, i, next_segment_start_time=None):
        k = 0.25
        has_prev_end = i > 0 and 'end' in words[i - 1]
        has_next_start = i < len(words) - 1 and 'start' in words[i + 1]

        if has_prev_end:
            words[i]['start'] = words[i - 1]['end']
            if has_next_start:
                words[i]['end'] = words[i + 1]['start']
            else:
                if next_segment_start_time:
                    words[i]['end'] = next_segment_start_time if next_segment_start_time - words[i - 1]['end'] <= 1 else next_segment_start_time - 0.5
                else:
                    words[i]['end'] = words[i]['start'] + len(words[i]['word']) * k

        elif has_next_start:
            words[i]['start'] = words[i + 1]['start'] - len(words[i]['word']) * k
            words[i]['end'] = words[i + 1]['start']

        else:
            if next_segment_start_time:
                words[i]['start'] = next_segment_start_time - 1
                words[i]['end'] = next_segment_start_time - 0.5
            else:
                words[i]['start'] = 0
                words[i]['end'] = 0

    def process_segments(self, advanced_splitting=True):
        subtitles = []
        for i, segment in enumerate(self.segments):
            next_segment_start_time = self.segments[i + 1]['start'] if i + 1 < len(self.segments) else None

            if advanced_splitting:
                split_points = self.determine_advanced_split_points(segment, next_segment_start_time)
                subtitles.extend(self.generate_subtitles_from_split_points(segment, split_points, next_segment_start_time))
            else:
                words = segment['words']
                for i, word in enumerate(words):
                    if 'start' not in word or 'end' not in word:
                        self.estimate_timestamp_for_word(words, i, next_segment_start_time)

                subtitles.append({
                    'start': segment['start'],
                    'end': segment['end'],
                    'text': segment['text']
                })

        return subtitles

    def determine_advanced_split_points(self, segment, next_segment_start_time=None):
        split_points = []
        last_split_point = 0
        char_count = 0

        words = segment.get('words', segment['text'].split())
        add_space = 0 if self.lang in ['zh', 'ja'] else 1

        total_char_count = sum(len(word['word']) if isinstance(word, dict) else len(word) + add_space for word in words)
        char_count_after = total_char_count

        for i, word in enumerate(words):
            word_text = word['word'] if isinstance(word, dict) else word
            word_length = len(word_text) + add_space
            char_count += word_length
            char_count_after -= word_length

            char_count_before = char_count - word_length

            if isinstance(word, dict) and ('start' not in word or 'end' not in word):
                self.estimate_timestamp_for_word(words, i, next_segment_start_time)

            if char_count >= self.max_line_length:
                midpoint = normal_round((last_split_point + i) / 2)
                if char_count_before >= self.min_char_length_splitter:
                    split_points.append(midpoint)
                    last_split_point = midpoint + 1
                    char_count = sum(len(words[j]['word']) if isinstance(words[j], dict) else len(words[j]) + add_space for j in range(last_split_point, i + 1))

            elif word_text.endswith(self.comma) and char_count_before >= self.min_char_length_splitter and char_count_after >= self.min_char_length_splitter:
                split_points.append(i)
                last_split_point = i + 1
                char_count = 0

            elif word_text.lower() in self.conjunctions and char_count_before >= self.min_char_length_splitter and char_count_after >= self.min_char_length_splitter:
                split_points.append(i - 1)
                last_split_point = i
                char_count = word_length

        return split_points

    def generate_subtitles_from_split_points(self, segment, split_points, next_start_time=None):
        subtitles = []

        words = segment.get('words', segment['text'].split())
        total_word_count = len(words)
        total_time = segment['end'] - segment['start']
        elapsed_time = segment['start']
        prefix = ' ' if self.lang not in ['zh', 'ja'] else ''
        start_idx = 0
        for split_point in split_points:

            fragment_words = words[start_idx:split_point + 1]
            current_word_count = len(fragment_words)

            if isinstance(fragment_words[0], dict):
                start_time = fragment_words[0]['start']
                end_time = fragment_words[-1]['end']
                next_start_time_for_word = words[split_point + 1]['start'] if split_point + 1 < len(words) else None
                if next_start_time_for_word and (next_start_time_for_word - end_time) <= 0.8:
                    end_time = next_start_time_for_word
            else:
                fragment = prefix.join(fragment_words).strip()
                current_duration = (current_word_count / total_word_count) * total_time
                start_time = elapsed_time
                end_time = elapsed_time + current_duration
                elapsed_time += current_duration

            subtitles.append({
                'start': start_time,
                'end': end_time,
                'text': fragment if not isinstance(fragment_words[0], dict) else prefix.join(word['word'] for word in fragment_words)
            })

            start_idx = split_point + 1

        # Handle the last fragment
        if start_idx < len(words):
            fragment_words = words[start_idx:]
            current_word_count = len(fragment_words)

            if isinstance(fragment_words[0], dict):
                start_time = fragment_words[0]['start']
                end_time = fragment_words[-1]['end']
            else:
                fragment = prefix.join(fragment_words).strip()
                current_duration = (current_word_count / total_word_count) * total_time
                start_time = elapsed_time
                end_time = elapsed_time + current_duration

            if next_start_time and (next_start_time - end_time) <= 0.8:
                end_time = next_start_time

            subtitles.append({
                'start': start_time,
                'end': end_time if end_time is not None else segment['end'],
                'text': fragment if not isinstance(fragment_words[0], dict) else prefix.join(word['word'] for word in fragment_words)
            })

        return subtitles

    def save(self, filename="subtitles.srt", advanced_splitting=True):

        subtitles = self.process_segments(advanced_splitting)

        def write_subtitle(file, idx, start_time, end_time, text):

            file.write(f"{idx}\n")
            file.write(f"{start_time} --> {end_time}\n")
            file.write(text + "\n\n")

        with open(filename, 'w', encoding='utf-8') as file:
            if self.is_vtt:
                file.write("WEBVTT\n\n")

            if advanced_splitting:
                for idx, subtitle in enumerate(subtitles, 1):
                    start_time = format_timestamp(subtitle['start'], self.is_vtt)
                    end_time = format_timestamp(subtitle['end'], self.is_vtt)
                    text = subtitle['text'].strip()
                    write_subtitle(file, idx, start_time, end_time, text)

        return len(subtitles)
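
For orientation, a minimal usage sketch of the new SubtitlesProcessor. The segment list below is a hand-built stand-in for whisperx.align() output, the output path is a placeholder, and it assumes English is covered by whisperx.conjunctions.

from whisperx.SubtitlesProcessor import SubtitlesProcessor

# Aligned segments as produced by whisperx.align(); this toy example is hand-built.
segments = [
    {"start": 0.0, "end": 2.5, "text": "Hello world, this is a test.",
     "words": [{"word": "Hello", "start": 0.0, "end": 0.4},
               {"word": "world,", "start": 0.5, "end": 0.9},
               {"word": "this", "start": 1.2, "end": 1.4},
               {"word": "is", "start": 1.5, "end": 1.6},
               {"word": "a", "start": 1.7, "end": 1.8},
               {"word": "test.", "start": 1.9, "end": 2.4}]},
]

processor = SubtitlesProcessor(segments, lang="en", max_line_length=45, is_vtt=False)
count = processor.save("output.srt", advanced_splitting=True)  # returns number of cues written
print(f"wrote {count} subtitle cues")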
whisperx/__init__.py
@@ -1,4 +1,7 @@
-from .transcribe import load_model
-from .alignment import load_align_model, align
-from .audio import load_audio
-from .diarize import assign_word_speakers, DiarizationPipeline
+from whisperx.alignment import load_align_model as load_align_model, align as align
+from whisperx.asr import load_model as load_model
+from whisperx.audio import load_audio as load_audio
+from whisperx.diarize import (
+    assign_word_speakers as assign_word_speakers,
+    DiarizationPipeline as DiarizationPipeline,
+)
whisperx/__main__.py
@@ -1,4 +1,4 @@
-from .transcribe import cli
+from whisperx.transcribe import cli
 
 
 cli()
whisperx/alignment.py
@@ -1,9 +1,11 @@
-""""
+"""
 Forced Alignment with Whisper
 C. Max Bain
 """
+import math
+
 from dataclasses import dataclass
-from typing import Iterator, Union
+from typing import Iterable, Optional, Union, List
 
 import numpy as np
 import pandas as pd
@@ -11,9 +13,18 @@ import torch
 import torchaudio
 from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
 
-from .audio import SAMPLE_RATE, load_audio
-from .utils import interpolate_nans
-import nltk
+from whisperx.audio import SAMPLE_RATE, load_audio
+from whisperx.utils import interpolate_nans
+from whisperx.types import (
+    AlignedTranscriptionResult,
+    SingleSegment,
+    SingleAlignedSegment,
+    SingleWordSegment,
+    SegmentData,
+)
+from nltk.tokenize.punkt import PunktSentenceTokenizer, PunktParameters
 
+PUNKT_ABBREVIATIONS = ['dr', 'vs', 'mr', 'mrs', 'prof']
+
 LANGUAGES_WITHOUT_SPACES = ["ja", "zh"]
 
@@ -32,6 +43,7 @@ DEFAULT_ALIGN_MODELS_HF = {
     "uk": "Yehor/wav2vec2-xls-r-300m-uk-with-small-lm",
     "pt": "jonatasgrosman/wav2vec2-large-xlsr-53-portuguese",
     "ar": "jonatasgrosman/wav2vec2-large-xlsr-53-arabic",
+    "cs": "comodoro/wav2vec2-xls-r-300m-cs-250",
     "ru": "jonatasgrosman/wav2vec2-large-xlsr-53-russian",
     "pl": "jonatasgrosman/wav2vec2-large-xlsr-53-polish",
     "hu": "jonatasgrosman/wav2vec2-large-xlsr-53-hungarian",
@@ -39,11 +51,30 @@ DEFAULT_ALIGN_MODELS_HF = {
     "fa": "jonatasgrosman/wav2vec2-large-xlsr-53-persian",
     "el": "jonatasgrosman/wav2vec2-large-xlsr-53-greek",
     "tr": "mpoyraz/wav2vec2-xls-r-300m-cv7-turkish",
+    "da": "saattrupdan/wav2vec2-xls-r-300m-ftspeech",
     "he": "imvladikon/wav2vec2-xls-r-300m-hebrew",
+    "vi": 'nguyenvulebinh/wav2vec2-base-vi',
+    "ko": "kresnik/wav2vec2-large-xlsr-korean",
+    "ur": "kingabzpro/wav2vec2-large-xls-r-300m-Urdu",
+    "te": "anuragshas/wav2vec2-large-xlsr-53-telugu",
+    "hi": "theainerd/Wav2Vec2-large-xlsr-hindi",
+    "ca": "softcatala/wav2vec2-large-xlsr-catala",
+    "ml": "gvs/wav2vec2-large-xlsr-malayalam",
+    "no": "NbAiLab/nb-wav2vec2-1b-bokmaal-v2",
+    "nn": "NbAiLab/nb-wav2vec2-1b-nynorsk",
+    "sk": "comodoro/wav2vec2-xls-r-300m-sk-cv8",
+    "sl": "anton-l/wav2vec2-large-xlsr-53-slovenian",
+    "hr": "classla/wav2vec2-xls-r-parlaspeech-hr",
+    "ro": "gigant/romanian-wav2vec2",
+    "eu": "stefan-it/wav2vec2-large-xlsr-53-basque",
+    "gl": "ifrz/wav2vec2-large-xlsr-galician",
+    "ka": "xsway/wav2vec2-large-xlsr-georgian",
+    "lv": "jimregan/wav2vec2-large-xlsr-latvian-cv",
+    "tl": "Khalsuu/filipino-wav2vec2-l-xls-r-300m-official",
 }
 
 
-def load_align_model(language_code, device, model_name=None, model_dir=None):
+def load_align_model(language_code: str, device: str, model_name: Optional[str] = None, model_dir=None):
     if model_name is None:
         # use default model
         if language_code in DEFAULT_ALIGN_MODELS_TORCH:
@@ -63,8 +94,8 @@ def load_align_model(language_code, device, model_name=None, model_dir=None):
         align_dictionary = {c.lower(): i for i, c in enumerate(labels)}
     else:
         try:
-            processor = Wav2Vec2Processor.from_pretrained(model_name)
-            align_model = Wav2Vec2ForCTC.from_pretrained(model_name)
+            processor = Wav2Vec2Processor.from_pretrained(model_name, cache_dir=model_dir)
+            align_model = Wav2Vec2ForCTC.from_pretrained(model_name, cache_dir=model_dir)
         except Exception as e:
             print(e)
             print(f"Error loading model from huggingface, check https://huggingface.co/models for finetuned wav2vec2.0 models")
@@ -80,14 +111,16 @@ def load_align_model(language_code, device, model_name=None, model_dir=None):
 
 
 def align(
-    transcript: Iterator[dict],
+    transcript: Iterable[SingleSegment],
     model: torch.nn.Module,
     align_model_metadata: dict,
     audio: Union[str, np.ndarray, torch.Tensor],
     device: str,
     interpolate_method: str = "nearest",
     return_char_alignments: bool = False,
-):
+    print_progress: bool = False,
+    combined_progress: bool = False,
+) -> AlignedTranscriptionResult:
     """
     Align phoneme recognition predictions to known transcription.
     """
@@ -106,8 +139,16 @@ def align(
     model_type = align_model_metadata["type"]
 
     # 1. Preprocess to keep only characters in dictionary
+    total_segments = len(transcript)
+    # Store temporary processing values
+    segment_data: dict[int, SegmentData] = {}
     for sdx, segment in enumerate(transcript):
         # strip spaces at beginning / end, but keep track of the amount.
+        if print_progress:
+            base_progress = ((sdx + 1) / total_segments) * 100
+            percent_complete = (50 + base_progress / 2) if combined_progress else base_progress
+            print(f"Progress: {percent_complete:.2f}%...")
+
         num_leading = len(segment["text"]) - len(segment["text"].lstrip())
         num_trailing = len(segment["text"]) - len(segment["text"].rstrip())
         text = segment["text"]
@@ -133,60 +174,83 @@ def align(
             elif char_ in model_dictionary.keys():
                 clean_char.append(char_)
                 clean_cdx.append(cdx)
+            else:
+                # add placeholder
+                clean_char.append('*')
+                clean_cdx.append(cdx)
 
         clean_wdx = []
         for wdx, wrd in enumerate(per_word):
-            if any([c in model_dictionary.keys() for c in wrd]):
+            if any([c in model_dictionary.keys() for c in wrd.lower()]):
+                clean_wdx.append(wdx)
+            else:
+                # index for placeholder
                 clean_wdx.append(wdx)
 
-        sentence_spans = list(nltk.tokenize.punkt.PunktSentenceTokenizer().span_tokenize(text))
 
-        segment["clean_char"] = clean_char
-        segment["clean_cdx"] = clean_cdx
-        segment["clean_wdx"] = clean_wdx
-        segment["sentence_spans"] = sentence_spans
+        punkt_param = PunktParameters()
+        punkt_param.abbrev_types = set(PUNKT_ABBREVIATIONS)
+        sentence_splitter = PunktSentenceTokenizer(punkt_param)
+        sentence_spans = list(sentence_splitter.span_tokenize(text))
 
-    aligned_segments = []
+        segment_data[sdx] = {
+            "clean_char": clean_char,
+            "clean_cdx": clean_cdx,
+            "clean_wdx": clean_wdx,
+            "sentence_spans": sentence_spans
+        }
+
+    aligned_segments: List[SingleAlignedSegment] = []
 
     # 2. Get prediction matrix from alignment model & align
     for sdx, segment in enumerate(transcript):
 
         t1 = segment["start"]
         t2 = segment["end"]
         text = segment["text"]
 
-        aligned_seg = {
+        aligned_seg: SingleAlignedSegment = {
             "start": t1,
             "end": t2,
             "text": text,
             "words": [],
+            "chars": None,
         }
 
         if return_char_alignments:
             aligned_seg["chars"] = []
 
         # check we can align
-        if len(segment["clean_char"]) == 0:
+        if len(segment_data[sdx]["clean_char"]) == 0:
             print(f'Failed to align segment ("{segment["text"]}"): no characters in this segment found in model dictionary, resorting to original...')
             aligned_segments.append(aligned_seg)
             continue
 
-        if t1 >= MAX_DURATION or t2 - t1 < 0.02:
-            print("Failed to align segment: original start time longer than audio duration, skipping...")
+        if t1 >= MAX_DURATION:
+            print(f'Failed to align segment ("{segment["text"]}"): original start time longer than audio duration, skipping...')
             aligned_segments.append(aligned_seg)
             continue
 
-        text_clean = "".join(segment["clean_char"])
-        tokens = [model_dictionary[c] for c in text_clean]
+        text_clean = "".join(segment_data[sdx]["clean_char"])
+        tokens = [model_dictionary.get(c, -1) for c in text_clean]
 
         f1 = int(t1 * SAMPLE_RATE)
         f2 = int(t2 * SAMPLE_RATE)
 
         # TODO: Probably can get some speedup gain with batched inference here
         waveform_segment = audio[:, f1:f2]
+        # Handle the minimum input length for wav2vec2 models
+        if waveform_segment.shape[-1] < 400:
+            lengths = torch.as_tensor([waveform_segment.shape[-1]]).to(device)
+            waveform_segment = torch.nn.functional.pad(
+                waveform_segment, (0, 400 - waveform_segment.shape[-1])
+            )
+        else:
+            lengths = None
 
         with torch.inference_mode():
             if model_type == "torchaudio":
-                emissions, _ = model(waveform_segment.to(device))
+                emissions, _ = model(waveform_segment.to(device), lengths=lengths)
             elif model_type == "huggingface":
                 emissions = model(waveform_segment.to(device)).logits
             else:
@@ -201,7 +265,8 @@ def align(
                 blank_id = code
 
         trellis = get_trellis(emission, tokens, blank_id)
-        path = backtrack(trellis, emission, tokens, blank_id)
+        # path = backtrack(trellis, emission, tokens, blank_id)
+        path = backtrack_beam(trellis, emission, tokens, blank_id, beam_width=2)
 
         if path is None:
             print(f'Failed to align segment ("{segment["text"]}"): backtrack failed, resorting to original...')
@@ -210,7 +275,7 @@ def align(
 
         char_segments = merge_repeats(path, text_clean)
 
-        duration = t2 -t1
+        duration = t2 - t1
         ratio = duration * waveform_segment.size(0) / (trellis.size(0) - 1)
 
         # assign timestamps to aligned characters
@@ -218,8 +283,8 @@ def align(
         word_idx = 0
         for cdx, char in enumerate(text):
             start, end, score = None, None, None
-            if cdx in segment["clean_cdx"]:
-                char_seg = char_segments[segment["clean_cdx"].index(cdx)]
+            if cdx in segment_data[sdx]["clean_cdx"]:
+                char_seg = char_segments[segment_data[sdx]["clean_cdx"].index(cdx)]
                 start = round(char_seg.start * ratio + t1, 3)
                 end = round(char_seg.end * ratio + t1, 3)
                 score = round(char_seg.score, 3)
@@ -245,13 +310,14 @@ def align(
         aligned_subsegments = []
         # assign sentence_idx to each character index
         char_segments_arr["sentence-idx"] = None
-        for sdx, (sstart, send) in enumerate(segment["sentence_spans"]):
+        for sdx2, (sstart, send) in enumerate(segment_data[sdx]["sentence_spans"]):
             curr_chars = char_segments_arr.loc[(char_segments_arr.index >= sstart) & (char_segments_arr.index <= send)]
-            char_segments_arr.loc[(char_segments_arr.index >= sstart) & (char_segments_arr.index <= send), "sentence-idx"] = sdx
+            char_segments_arr.loc[(char_segments_arr.index >= sstart) & (char_segments_arr.index <= send), "sentence-idx"] = sdx2
 
             sentence_text = text[sstart:send]
             sentence_start = curr_chars["start"].min()
-            sentence_end = curr_chars["end"].max()
+            end_chars = curr_chars[curr_chars["char"] != ' ']
+            sentence_end = end_chars["end"].max()
             sentence_words = []
 
             for word_idx in curr_chars["word-idx"].unique():
@@ -259,6 +325,10 @@ def align(
                 word_text = "".join(word_chars["char"].tolist()).strip()
                 if len(word_text) == 0:
                     continue
+
+                # dont use space character for alignment
+                word_chars = word_chars[word_chars["char"] != " "]
+
                 word_start = word_chars["start"].min()
                 word_end = word_chars["end"].max()
                 word_score = round(word_chars["score"].mean(), 3)
@@ -294,6 +364,8 @@ def align(
         aligned_subsegments["end"] = interpolate_nans(aligned_subsegments["end"], method=interpolate_method)
         # concatenate sentences with same timestamps
         agg_dict = {"text": " ".join, "words": "sum"}
+        if model_lang in LANGUAGES_WITHOUT_SPACES:
+            agg_dict["text"] = "".join
         if return_char_alignments:
             agg_dict["chars"] = "sum"
         aligned_subsegments= aligned_subsegments.groupby(["start", "end"], as_index=False).agg(agg_dict)
@@ -301,7 +373,7 @@ def align(
         aligned_segments += aligned_subsegments
 
     # create word_segments list
-    word_segments = []
+    word_segments: List[SingleWordSegment] = []
    for segment in aligned_segments:
        word_segments += segment["words"]
 
@@ -310,70 +382,203 @@ def align(
 """
 source: https://pytorch.org/tutorials/intermediate/forced_alignment_with_torchaudio_tutorial.html
 """
 
 
 def get_trellis(emission, tokens, blank_id=0):
     num_frame = emission.size(0)
     num_tokens = len(tokens)
 
-    # Trellis has extra diemsions for both time axis and tokens.
-    # The extra dim for tokens represents <SoS> (start-of-sentence)
-    # The extra dim for time axis is for simplification of the code.
-    trellis = torch.empty((num_frame + 1, num_tokens + 1))
-    trellis[0, 0] = 0
-    trellis[1:, 0] = torch.cumsum(emission[:, 0], 0)
-    trellis[0, -num_tokens:] = -float("inf")
-    trellis[-num_tokens:, 0] = float("inf")
+    trellis = torch.zeros((num_frame, num_tokens))
+    trellis[1:, 0] = torch.cumsum(emission[1:, blank_id], 0)
+    trellis[0, 1:] = -float("inf")
+    trellis[-num_tokens + 1:, 0] = float("inf")
 
-    for t in range(num_frame):
+    for t in range(num_frame - 1):
         trellis[t + 1, 1:] = torch.maximum(
             # Score for staying at the same token
             trellis[t, 1:] + emission[t, blank_id],
             # Score for changing to the next token
-            trellis[t, :-1] + emission[t, tokens],
+            # trellis[t, :-1] + emission[t, tokens[1:]],
+            trellis[t, :-1] + get_wildcard_emission(emission[t], tokens[1:], blank_id),
         )
     return trellis
 
 
+def get_wildcard_emission(frame_emission, tokens, blank_id):
+    """Processing token emission scores containing wildcards (vectorized version)
+
+    Args:
+        frame_emission: Emission probability vector for the current frame
+        tokens: List of token indices
+        blank_id: ID of the blank token
+
+    Returns:
+        tensor: Maximum probability score for each token position
+    """
+    assert 0 <= blank_id < len(frame_emission)
+
+    # Convert tokens to a tensor if they are not already
+    tokens = torch.tensor(tokens) if not isinstance(tokens, torch.Tensor) else tokens
+
+    # Create a mask to identify wildcard positions
+    wildcard_mask = (tokens == -1)
+
+    # Get scores for non-wildcard positions
+    regular_scores = frame_emission[tokens.clamp(min=0)]  # clamp to avoid -1 index
+
+    # Create a mask and compute the maximum value without modifying frame_emission
+    max_valid_score = frame_emission.clone()   # Create a copy
+    max_valid_score[blank_id] = float('-inf')  # Modify the copy to exclude the blank token
+    max_valid_score = max_valid_score.max()
+
+    # Use where operation to combine results
+    result = torch.where(wildcard_mask, max_valid_score, regular_scores)
+
+    return result
+
+
 @dataclass
 class Point:
     token_index: int
     time_index: int
     score: float
 
 
 def backtrack(trellis, emission, tokens, blank_id=0):
-    # Note:
-    # j and t are indices for trellis, which has extra dimensions
-    # for time and tokens at the beginning.
-    # When referring to time frame index `T` in trellis,
-    # the corresponding index in emission is `T-1`.
-    # Similarly, when referring to token index `J` in trellis,
-    # the corresponding index in transcript is `J-1`.
-    j = trellis.size(1) - 1
-    t_start = torch.argmax(trellis[:, j]).item()
-
-    path = []
-    for t in range(t_start, 0, -1):
-        # 1. Figure out if the current position was stay or change
-        # Note (again):
-        # `emission[J-1]` is the emission at time frame `J` of trellis dimension.
-        # Score for token staying the same from time frame J-1 to T.
-        stayed = trellis[t - 1, j] + emission[t - 1, blank_id]
-        # Score for token changing from C-1 at T-1 to J at T.
-        changed = trellis[t - 1, j - 1] + emission[t - 1, tokens[j - 1]]
-
-        # 2. Store the path with frame-wise probability.
-        prob = emission[t - 1, tokens[j - 1] if changed > stayed else 0].exp().item()
-        # Return token index and time index in non-trellis coordinate.
-        path.append(Point(j - 1, t - 1, prob))
-
-        # 3. Update the token
-        if changed > stayed:
-            j -= 1
-            if j == 0:
-                break
-    else:
-        # failed
-        return None
+    t, j = trellis.size(0) - 1, trellis.size(1) - 1
+
+    path = [Point(j, t, emission[t, blank_id].exp().item())]
+    while j > 0:
+        # Should not happen but just in case
+        assert t > 0
+
+        # 1. Figure out if the current position was stay or change
+        # Frame-wise score of stay vs change
+        p_stay = emission[t - 1, blank_id]
+        # p_change = emission[t - 1, tokens[j]]
+        p_change = get_wildcard_emission(emission[t - 1], [tokens[j]], blank_id)[0]
+
+        # Context-aware score for stay vs change
+        stayed = trellis[t - 1, j] + p_stay
+        changed = trellis[t - 1, j - 1] + p_change
+
+        # Update position
+        t -= 1
+        if changed > stayed:
+            j -= 1
+
+        # Store the path with frame-wise probability.
+        prob = (p_change if changed > stayed else p_stay).exp().item()
+        path.append(Point(j, t, prob))
+
+    # Now j == 0, which means, it reached the SoS.
+    # Fill up the rest for the sake of visualization
+    while t > 0:
+        prob = emission[t - 1, blank_id].exp().item()
+        path.append(Point(j, t - 1, prob))
+        t -= 1
+
     return path[::-1]
 
 
+@dataclass
+class Path:
+    points: List[Point]
+    score: float
+
+
+@dataclass
+class BeamState:
+    """State in beam search."""
+    token_index: int   # Current token position
+    time_index: int    # Current time step
+    score: float       # Cumulative score
+    path: List[Point]  # Path history
+
+
+def backtrack_beam(trellis, emission, tokens, blank_id=0, beam_width=5):
+    """Standard CTC beam search backtracking implementation.
+
+    Args:
+        trellis (torch.Tensor): The trellis (or lattice) of shape (T, N), where T is the number of time steps
+                                and N is the number of tokens (including the blank token).
+        emission (torch.Tensor): The emission probabilities of shape (T, N).
+        tokens (List[int]): List of token indices (excluding the blank token).
+        blank_id (int, optional): The ID of the blank token. Defaults to 0.
+        beam_width (int, optional): The number of top paths to keep during beam search. Defaults to 5.
+
+    Returns:
+        List[Point]: the best path
+    """
+    T, J = trellis.size(0) - 1, trellis.size(1) - 1
+
+    init_state = BeamState(
+        token_index=J,
+        time_index=T,
+        score=trellis[T, J],
+        path=[Point(J, T, emission[T, blank_id].exp().item())]
+    )
+
+    beams = [init_state]
+
+    while beams and beams[0].token_index > 0:
+        next_beams = []
+
+        for beam in beams:
+            t, j = beam.time_index, beam.token_index
+
+            if t <= 0:
+                continue
+
+            p_stay = emission[t - 1, blank_id]
+            p_change = get_wildcard_emission(emission[t - 1], [tokens[j]], blank_id)[0]
+
+            stay_score = trellis[t - 1, j]
+            change_score = trellis[t - 1, j - 1] if j > 0 else float('-inf')
+
+            # Stay
+            if not math.isinf(stay_score):
+                new_path = beam.path.copy()
+                new_path.append(Point(j, t - 1, p_stay.exp().item()))
+                next_beams.append(BeamState(
+                    token_index=j,
+                    time_index=t - 1,
+                    score=stay_score,
+                    path=new_path
+                ))
+
+            # Change
+            if j > 0 and not math.isinf(change_score):
+                new_path = beam.path.copy()
+                new_path.append(Point(j - 1, t - 1, p_change.exp().item()))
+                next_beams.append(BeamState(
+                    token_index=j - 1,
+                    time_index=t - 1,
+                    score=change_score,
+                    path=new_path
+                ))
+
+        # sort by score
+        beams = sorted(next_beams, key=lambda x: x.score, reverse=True)[:beam_width]
+
+        if not beams:
+            break
+
+    if not beams:
+        return None
+
+    best_beam = beams[0]
+    t = best_beam.time_index
+    j = best_beam.token_index
+    while t > 0:
+        prob = emission[t - 1, blank_id].exp().item()
+        best_beam.path.append(Point(j, t - 1, prob))
+        t -= 1
+
+    return best_beam.path[::-1]
+
+
 # Merge the labels
 @dataclass
 class Segment:
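
For orientation, a minimal sketch of driving the updated load_align_model/align signatures above. The audio path is a placeholder, CUDA availability is assumed, and the stub segment stands in for real transcription output.

import whisperx

device = "cuda"
audio = whisperx.load_audio("audio.wav")  # placeholder path

# `segments` would normally come from FasterWhisperPipeline.transcribe(); a stub is shown here.
segments = [{"start": 0.0, "end": 2.0, "text": "hello world"}]

align_model, metadata = whisperx.load_align_model(language_code="en", device=device)
aligned = whisperx.align(
    segments, align_model, metadata, audio, device,
    return_char_alignments=False, print_progress=True,
)
print(aligned["word_segments"])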
346
whisperx/asr.py
346
whisperx/asr.py
@ -1,79 +1,29 @@
|
|||||||
import os
|
import os
|
||||||
import warnings
|
from typing import List, Optional, Union
|
||||||
from typing import List, Union
|
from dataclasses import replace
|
||||||
|
|
||||||
import ctranslate2
|
import ctranslate2
|
||||||
import faster_whisper
|
import faster_whisper
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
from faster_whisper.tokenizer import Tokenizer
|
||||||
|
from faster_whisper.transcribe import TranscriptionOptions, get_ctranslate2_storage
|
||||||
from transformers import Pipeline
|
from transformers import Pipeline
|
||||||
from transformers.pipelines.pt_utils import PipelineIterator
|
from transformers.pipelines.pt_utils import PipelineIterator
|
||||||
|
|
||||||
from .audio import N_SAMPLES, SAMPLE_RATE, load_audio, log_mel_spectrogram
|
from whisperx.audio import N_SAMPLES, SAMPLE_RATE, load_audio, log_mel_spectrogram
|
||||||
from .vad import load_vad_model, merge_chunks
|
from whisperx.types import SingleSegment, TranscriptionResult
|
||||||
|
from whisperx.vads import Vad, Silero, Pyannote
|
||||||
|
|
||||||
def load_model(whisper_arch, device, compute_type="float16", asr_options=None, language=None,
|
|
||||||
vad_options=None, model=None):
|
|
||||||
'''Load a Whisper model for inference.
|
|
||||||
Args:
|
|
||||||
whisper_arch: str - The name of the Whisper model to load.
|
|
||||||
device: str - The device to load the model on.
|
|
||||||
compute_type: str - The compute type to use for the model.
|
|
||||||
options: dict - A dictionary of options to use for the model.
|
|
||||||
language: str - The language of the model. (use English for now)
|
|
||||||
Returns:
|
|
||||||
A Whisper pipeline.
|
|
||||||
'''
|
|
||||||
|
|
||||||
if whisper_arch.endswith(".en"):
|
|
||||||
language = "en"
|
|
||||||
|
|
||||||
model = WhisperModel(whisper_arch, device=device, compute_type=compute_type)
|
|
||||||
if language is not None:
|
|
||||||
tokenizer = faster_whisper.tokenizer.Tokenizer(model.hf_tokenizer, model.model.is_multilingual, task="transcribe", language=language)
|
|
||||||
else:
|
|
||||||
print("No language specified, language will be first be detected for each audio file (increases inference time).")
|
|
||||||
tokenizer = None
|
|
||||||
|
|
||||||
default_asr_options = {
|
|
||||||
"beam_size": 5,
|
|
||||||
"best_of": 5,
|
|
||||||
"patience": 1,
|
|
||||||
"length_penalty": 1,
|
|
||||||
"temperatures": [0.0, 0.2, 0.4, 0.6, 0.8, 1.0],
|
|
||||||
"compression_ratio_threshold": 2.4,
|
|
||||||
"log_prob_threshold": -1.0,
|
|
||||||
"no_speech_threshold": 0.6,
|
|
||||||
"condition_on_previous_text": False,
|
|
||||||
"initial_prompt": None,
|
|
||||||
"prefix": None,
|
|
||||||
"suppress_blank": True,
|
|
||||||
"suppress_tokens": [-1],
|
|
||||||
"without_timestamps": True,
|
|
||||||
"max_initial_timestamp": 0.0,
|
|
||||||
"word_timestamps": False,
|
|
||||||
"prepend_punctuations": "\"'“¿([{-",
|
|
||||||
"append_punctuations": "\"'.。,,!!??::”)]}、"
|
|
||||||
}
|
|
||||||
|
|
||||||
if asr_options is not None:
|
|
||||||
default_asr_options.update(asr_options)
|
|
||||||
default_asr_options = faster_whisper.transcribe.TranscriptionOptions(**default_asr_options)
|
|
||||||
|
|
||||||
default_vad_options = {
|
|
||||||
"vad_onset": 0.500,
|
|
||||||
"vad_offset": 0.363
|
|
||||||
}
|
|
||||||
|
|
||||||
if vad_options is not None:
|
|
||||||
default_vad_options.update(vad_options)
|
|
||||||
|
|
||||||
vad_model = load_vad_model(torch.device(device), use_auth_token=None, **default_vad_options)
|
|
||||||
|
|
||||||
return FasterWhisperPipeline(model, vad_model, default_asr_options, tokenizer)
|
|
||||||
|
|
||||||
|
|
||||||
|
def find_numeral_symbol_tokens(tokenizer):
|
||||||
|
numeral_symbol_tokens = []
|
||||||
|
for i in range(tokenizer.eot):
|
||||||
|
token = tokenizer.decode([i]).removeprefix(" ")
|
||||||
|
has_numeral_symbol = any(c in "0123456789%$£" for c in token)
|
||||||
|
if has_numeral_symbol:
|
||||||
|
numeral_symbol_tokens.append(i)
|
||||||
|
return numeral_symbol_tokens
|
||||||
|
|
||||||
class WhisperModel(faster_whisper.WhisperModel):
|
class WhisperModel(faster_whisper.WhisperModel):
|
||||||
'''
|
'''
|
||||||
@ -81,7 +31,13 @@ class WhisperModel(faster_whisper.WhisperModel):
|
|||||||
Currently only works in non-timestamp mode and fixed prompt for all samples in batch.
|
Currently only works in non-timestamp mode and fixed prompt for all samples in batch.
|
||||||
'''
|
'''
|
||||||
|
|
||||||
def generate_segment_batched(self, features: np.ndarray, tokenizer: faster_whisper.tokenizer.Tokenizer, options: faster_whisper.transcribe.TranscriptionOptions, encoder_output = None):
|
def generate_segment_batched(
|
||||||
|
self,
|
||||||
|
features: np.ndarray,
|
||||||
|
tokenizer: Tokenizer,
|
||||||
|
options: TranscriptionOptions,
|
||||||
|
encoder_output=None,
|
||||||
|
):
|
||||||
batch_size = features.shape[0]
|
batch_size = features.shape[0]
|
||||||
all_tokens = []
|
all_tokens = []
|
||||||
prompt_reset_since = 0
|
prompt_reset_since = 0
|
||||||
@ -95,6 +51,7 @@ class WhisperModel(faster_whisper.WhisperModel):
|
|||||||
previous_tokens,
|
previous_tokens,
|
||||||
without_timestamps=options.without_timestamps,
|
without_timestamps=options.without_timestamps,
|
||||||
prefix=options.prefix,
|
prefix=options.prefix,
|
||||||
|
hotwords=options.hotwords
|
||||||
)
|
)
|
||||||
|
|
||||||
encoder_output = self.encode(features)
|
encoder_output = self.encode(features)
|
||||||
@ -106,13 +63,12 @@ class WhisperModel(faster_whisper.WhisperModel):
|
|||||||
result = self.model.generate(
|
result = self.model.generate(
|
||||||
encoder_output,
|
encoder_output,
|
||||||
[prompt] * batch_size,
|
[prompt] * batch_size,
|
||||||
# length_penalty=options.length_penalty,
|
beam_size=options.beam_size,
|
||||||
# max_length=self.max_length,
|
patience=options.patience,
|
||||||
# return_scores=True,
|
length_penalty=options.length_penalty,
|
||||||
# return_no_speech_prob=True,
|
max_length=self.max_length,
|
||||||
# suppress_blank=options.suppress_blank,
|
suppress_blank=options.suppress_blank,
|
||||||
# suppress_tokens=options.suppress_tokens,
|
suppress_tokens=options.suppress_tokens,
|
||||||
# max_initial_timestamp_index=max_initial_timestamp_index,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
tokens_batch = [x.sequences_ids[0] for x in result]
|
tokens_batch = [x.sequences_ids[0] for x in result]
|
||||||
@ -135,7 +91,7 @@ class WhisperModel(faster_whisper.WhisperModel):
|
|||||||
# unsqueeze if batch size = 1
|
# unsqueeze if batch size = 1
|
||||||
if len(features.shape) == 2:
|
if len(features.shape) == 2:
|
||||||
features = np.expand_dims(features, 0)
|
features = np.expand_dims(features, 0)
|
||||||
features = faster_whisper.transcribe.get_ctranslate2_storage(features)
|
features = get_ctranslate2_storage(features)
|
||||||
|
|
||||||
return self.model.encode(features, to_cpu=to_cpu)
|
return self.model.encode(features, to_cpu=to_cpu)
|
||||||
|
|
||||||
@ -148,18 +104,23 @@ class FasterWhisperPipeline(Pipeline):
|
|||||||
# - add support for custom inference kwargs
|
# - add support for custom inference kwargs
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model,
|
model: WhisperModel,
|
||||||
vad,
|
vad,
|
||||||
options,
|
vad_params: dict,
|
||||||
tokenizer=None,
|
options: TranscriptionOptions,
|
||||||
device: Union[int, str, "torch.device"] = -1,
|
tokenizer: Optional[Tokenizer] = None,
|
||||||
framework = "pt",
|
device: Union[int, str, "torch.device"] = -1,
|
||||||
**kwargs
|
framework="pt",
|
||||||
|
language: Optional[str] = None,
|
||||||
|
suppress_numerals: bool = False,
|
||||||
|
**kwargs,
|
||||||
):
|
):
|
||||||
self.model = model
|
self.model = model
|
||||||
self.tokenizer = tokenizer
|
self.tokenizer = tokenizer
|
||||||
self.options = options
|
self.options = options
|
||||||
|
self.preset_language = language
|
||||||
|
self.suppress_numerals = suppress_numerals
|
||||||
self._batch_size = kwargs.pop("batch_size", None)
|
self._batch_size = kwargs.pop("batch_size", None)
|
||||||
self._num_workers = 1
|
self._num_workers = 1
|
||||||
self._preprocess_params, self._forward_params, self._postprocess_params = self._sanitize_parameters(**kwargs)
|
self._preprocess_params, self._forward_params, self._postprocess_params = self._sanitize_parameters(**kwargs)
|
||||||
@ -179,6 +140,7 @@ class FasterWhisperPipeline(Pipeline):
|
|||||||
|
|
||||||
super(Pipeline, self).__init__()
|
super(Pipeline, self).__init__()
|
||||||
self.vad_model = vad
|
self.vad_model = vad
|
||||||
|
self._vad_params = vad_params
|
||||||
|
|
||||||
def _sanitize_parameters(self, **kwargs):
|
def _sanitize_parameters(self, **kwargs):
|
||||||
preprocess_kwargs = {}
|
preprocess_kwargs = {}
|
||||||
@ -188,7 +150,12 @@ class FasterWhisperPipeline(Pipeline):
|
|||||||
|
|
||||||
def preprocess(self, audio):
|
def preprocess(self, audio):
|
||||||
audio = audio['inputs']
|
audio = audio['inputs']
|
||||||
features = log_mel_spectrogram(audio, padding=N_SAMPLES - audio.shape[0])
|
model_n_mels = self.model.feat_kwargs.get("feature_size")
|
||||||
|
features = log_mel_spectrogram(
|
||||||
|
audio,
|
||||||
|
n_mels=model_n_mels if model_n_mels is not None else 80,
|
||||||
|
padding=N_SAMPLES - audio.shape[0],
|
||||||
|
)
|
||||||
return {'inputs': features}
|
return {'inputs': features}
|
||||||
|
|
||||||
def _forward(self, model_inputs):
|
def _forward(self, model_inputs):
|
||||||
@ -199,7 +166,13 @@ class FasterWhisperPipeline(Pipeline):
|
|||||||
return model_outputs
|
return model_outputs
|
||||||
|
|
||||||
def get_iterator(
|
def get_iterator(
|
||||||
self, inputs, num_workers: int, batch_size: int, preprocess_params, forward_params, postprocess_params
|
self,
|
||||||
|
inputs,
|
||||||
|
num_workers: int,
|
||||||
|
batch_size: int,
|
||||||
|
preprocess_params: dict,
|
||||||
|
forward_params: dict,
|
||||||
|
postprocess_params: dict,
|
||||||
):
|
):
|
||||||
dataset = PipelineIterator(inputs, self.preprocess, preprocess_params)
|
dataset = PipelineIterator(inputs, self.preprocess, preprocess_params)
|
||||||
if "TOKENIZERS_PARALLELISM" not in os.environ:
|
if "TOKENIZERS_PARALLELISM" not in os.environ:
|
||||||
@ -214,8 +187,17 @@ class FasterWhisperPipeline(Pipeline):
|
|||||||
return final_iterator
|
return final_iterator
|
||||||
|
|
||||||
def transcribe(
|
def transcribe(
|
||||||
self, audio: Union[str, np.ndarray], batch_size=None, num_workers=0
|
self,
|
||||||
):
|
audio: Union[str, np.ndarray],
|
||||||
|
batch_size: Optional[int] = None,
|
||||||
|
num_workers=0,
|
||||||
|
language: Optional[str] = None,
|
||||||
|
task: Optional[str] = None,
|
||||||
|
chunk_size=30,
|
||||||
|
print_progress=False,
|
||||||
|
combined_progress=False,
|
||||||
|
verbose=False,
|
||||||
|
) -> TranscriptionResult:
|
||||||
if isinstance(audio, str):
|
if isinstance(audio, str):
|
||||||
audio = load_audio(audio)
|
audio = load_audio(audio)
|
||||||
|
|
||||||
@ -226,41 +208,87 @@ class FasterWhisperPipeline(Pipeline):
|
|||||||
# print(f2-f1)
|
# print(f2-f1)
|
||||||
yield {'inputs': audio[f1:f2]}
|
yield {'inputs': audio[f1:f2]}
|
||||||
|
|
||||||
vad_segments = self.vad_model({"waveform": torch.from_numpy(audio).unsqueeze(0), "sample_rate": SAMPLE_RATE})
|
# Pre-process audio and merge chunks as defined by the respective VAD child class
|
||||||
vad_segments = merge_chunks(vad_segments, 30)
|
# In case vad_model is manually assigned (see 'load_model') follow the functionality of pyannote toolkit
|
||||||
|
if issubclass(type(self.vad_model), Vad):
|
||||||
del_tokenizer = False
|
waveform = self.vad_model.preprocess_audio(audio)
|
||||||
if self.tokenizer is None:
|
merge_chunks = self.vad_model.merge_chunks
|
||||||
language = self.detect_language(audio)
|
|
||||||
self.tokenizer = faster_whisper.tokenizer.Tokenizer(self.model.hf_tokenizer, self.model.model.is_multilingual, task="transcribe", language=language)
|
|
||||||
del_tokenizer = True
|
|
||||||
else:
|
else:
|
||||||
language = self.tokenizer.language_code
|
waveform = Pyannote.preprocess_audio(audio)
|
||||||
|
merge_chunks = Pyannote.merge_chunks
|
||||||
|
|
||||||
segments = []
|
vad_segments = self.vad_model({"waveform": waveform, "sample_rate": SAMPLE_RATE})
|
||||||
|
vad_segments = merge_chunks(
|
||||||
|
vad_segments,
|
||||||
|
chunk_size,
|
||||||
|
onset=self._vad_params["vad_onset"],
|
||||||
|
offset=self._vad_params["vad_offset"],
|
||||||
|
)
|
||||||
|
if self.tokenizer is None:
|
||||||
|
language = language or self.detect_language(audio)
|
||||||
|
task = task or "transcribe"
|
||||||
|
self.tokenizer = Tokenizer(
|
||||||
|
self.model.hf_tokenizer,
|
||||||
|
self.model.model.is_multilingual,
|
||||||
|
task=task,
|
||||||
|
language=language,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
language = language or self.tokenizer.language_code
|
||||||
|
task = task or self.tokenizer.task
|
||||||
|
if task != self.tokenizer.task or language != self.tokenizer.language_code:
|
||||||
|
self.tokenizer = Tokenizer(
|
||||||
|
self.model.hf_tokenizer,
|
||||||
|
self.model.model.is_multilingual,
|
||||||
|
task=task,
|
||||||
|
language=language,
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.suppress_numerals:
|
||||||
|
previous_suppress_tokens = self.options.suppress_tokens
|
||||||
|
numeral_symbol_tokens = find_numeral_symbol_tokens(self.tokenizer)
|
||||||
|
print(f"Suppressing numeral and symbol tokens")
|
||||||
|
new_suppressed_tokens = numeral_symbol_tokens + self.options.suppress_tokens
|
||||||
|
new_suppressed_tokens = list(set(new_suppressed_tokens))
|
||||||
|
self.options = replace(self.options, suppress_tokens=new_suppressed_tokens)
|
||||||
|
|
||||||
|
segments: List[SingleSegment] = []
|
||||||
batch_size = batch_size or self._batch_size
|
batch_size = batch_size or self._batch_size
|
||||||
|
total_segments = len(vad_segments)
|
||||||
for idx, out in enumerate(self.__call__(data(audio, vad_segments), batch_size=batch_size, num_workers=num_workers)):
|
for idx, out in enumerate(self.__call__(data(audio, vad_segments), batch_size=batch_size, num_workers=num_workers)):
|
||||||
|
if print_progress:
|
||||||
|
base_progress = ((idx + 1) / total_segments) * 100
|
||||||
|
percent_complete = base_progress / 2 if combined_progress else base_progress
|
||||||
|
print(f"Progress: {percent_complete:.2f}%...")
|
||||||
text = out['text']
|
text = out['text']
|
||||||
if batch_size in [0, 1, None]:
|
if batch_size in [0, 1, None]:
|
||||||
text = text[0]
|
text = text[0]
|
||||||
|
if verbose:
|
||||||
|
print(f"Transcript: [{round(vad_segments[idx]['start'], 3)} --> {round(vad_segments[idx]['end'], 3)}] {text}")
|
||||||
segments.append(
|
segments.append(
|
||||||
{
|
{
|
||||||
"text": out['text'],
|
"text": text,
|
||||||
"start": round(vad_segments[idx]['start'], 3),
|
"start": round(vad_segments[idx]['start'], 3),
|
||||||
"end": round(vad_segments[idx]['end'], 3)
|
"end": round(vad_segments[idx]['end'], 3)
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
if del_tokenizer:
|
# revert the tokenizer if multilingual inference is enabled
|
||||||
|
if self.preset_language is None:
|
||||||
self.tokenizer = None
|
self.tokenizer = None
|
||||||
|
|
||||||
|
# revert suppressed tokens if suppress_numerals is enabled
|
||||||
|
if self.suppress_numerals:
|
||||||
|
self.options = replace(self.options, suppress_tokens=previous_suppress_tokens)
|
||||||
|
|
||||||
return {"segments": segments, "language": language}
|
return {"segments": segments, "language": language}
|
||||||
|
|
||||||
|
def detect_language(self, audio: np.ndarray) -> str:
|
||||||
def detect_language(self, audio: np.ndarray):
|
|
||||||
if audio.shape[0] < N_SAMPLES:
|
if audio.shape[0] < N_SAMPLES:
|
||||||
print("Warning: audio is shorter than 30s, language detection may be inaccurate.")
|
print("Warning: audio is shorter than 30s, language detection may be inaccurate.")
|
||||||
|
model_n_mels = self.model.feat_kwargs.get("feature_size")
|
||||||
segment = log_mel_spectrogram(audio[: N_SAMPLES],
|
segment = log_mel_spectrogram(audio[: N_SAMPLES],
|
||||||
|
n_mels=model_n_mels if model_n_mels is not None else 80,
|
||||||
padding=0 if audio.shape[0] >= N_SAMPLES else N_SAMPLES - audio.shape[0])
|
padding=0 if audio.shape[0] >= N_SAMPLES else N_SAMPLES - audio.shape[0])
|
||||||
encoder_output = self.model.encode(segment)
|
encoder_output = self.model.encode(segment)
|
||||||
results = self.model.model.detect_language(encoder_output)
|
results = self.model.model.detect_language(encoder_output)
|
||||||
@ -268,3 +296,121 @@ class FasterWhisperPipeline(Pipeline):
|
|||||||
language = language_token[2:-2]
|
language = language_token[2:-2]
|
||||||
print(f"Detected language: {language} ({language_probability:.2f}) in first 30s of audio...")
|
print(f"Detected language: {language} ({language_probability:.2f}) in first 30s of audio...")
|
||||||
return language
|
return language
|
||||||

def load_model(
    whisper_arch: str,
    device: str,
    device_index=0,
    compute_type="float16",
    asr_options: Optional[dict] = None,
    language: Optional[str] = None,
    vad_model: Optional[Vad]= None,
    vad_method: Optional[str] = "pyannote",
    vad_options: Optional[dict] = None,
    model: Optional[WhisperModel] = None,
    task="transcribe",
    download_root: Optional[str] = None,
    local_files_only=False,
    threads=4,
) -> FasterWhisperPipeline:
    """Load a Whisper model for inference.
    Args:
        whisper_arch - The name of the Whisper model to load.
        device - The device to load the model on.
        compute_type - The compute type to use for the model.
        vad_method - The vad method to use. vad_model has higher priority if is not None.
        options - A dictionary of options to use for the model.
        language - The language of the model. (use English for now)
        model - The WhisperModel instance to use.
        download_root - The root directory to download the model to.
        local_files_only - If `True`, avoid downloading the file and return the path to the local cached file if it exists.
        threads - The number of cpu threads to use per worker, e.g. will be multiplied by num workers.
    Returns:
        A Whisper pipeline.
    """

    if whisper_arch.endswith(".en"):
        language = "en"

    model = model or WhisperModel(whisper_arch,
                                  device=device,
                                  device_index=device_index,
                                  compute_type=compute_type,
                                  download_root=download_root,
                                  local_files_only=local_files_only,
                                  cpu_threads=threads)
    if language is not None:
        tokenizer = Tokenizer(model.hf_tokenizer, model.model.is_multilingual, task=task, language=language)
    else:
        print("No language specified, language will be first be detected for each audio file (increases inference time).")
        tokenizer = None

    default_asr_options = {
        "beam_size": 5,
        "best_of": 5,
        "patience": 1,
        "length_penalty": 1,
        "repetition_penalty": 1,
        "no_repeat_ngram_size": 0,
        "temperatures": [0.0, 0.2, 0.4, 0.6, 0.8, 1.0],
        "compression_ratio_threshold": 2.4,
        "log_prob_threshold": -1.0,
        "no_speech_threshold": 0.6,
        "condition_on_previous_text": False,
        "prompt_reset_on_temperature": 0.5,
        "initial_prompt": None,
        "prefix": None,
        "suppress_blank": True,
        "suppress_tokens": [-1],
        "without_timestamps": True,
        "max_initial_timestamp": 0.0,
        "word_timestamps": False,
        "prepend_punctuations": "\"'“¿([{-",
        "append_punctuations": "\"'.。,,!!??::”)]}、",
        "multilingual": model.model.is_multilingual,
        "suppress_numerals": False,
        "max_new_tokens": None,
        "clip_timestamps": None,
        "hallucination_silence_threshold": None,
        "hotwords": None,
    }

    if asr_options is not None:
        default_asr_options.update(asr_options)

    suppress_numerals = default_asr_options["suppress_numerals"]
    del default_asr_options["suppress_numerals"]

    default_asr_options = TranscriptionOptions(**default_asr_options)

    default_vad_options = {
        "chunk_size": 30, # needed by silero since binarization happens before merge_chunks
        "vad_onset": 0.500,
        "vad_offset": 0.363
    }

    if vad_options is not None:
        default_vad_options.update(vad_options)

    # Note: manually assigned vad_model has higher priority than vad_method!
    if vad_model is not None:
        print("Use manually assigned vad_model. vad_method is ignored.")
        vad_model = vad_model
    else:
        if vad_method == "silero":
            vad_model = Silero(**default_vad_options)
        elif vad_method == "pyannote":
            vad_model = Pyannote(torch.device(device), use_auth_token=None, **default_vad_options)
        else:
            raise ValueError(f"Invalid vad_method: {vad_method}")

    return FasterWhisperPipeline(
        model=model,
        vad=vad_model,
        options=default_asr_options,
        tokenizer=tokenizer,
        language=language,
        suppress_numerals=suppress_numerals,
        vad_params=default_vad_options,
    )
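For orientation, a minimal usage sketch of the load_model entry point above (not part of the diff itself; the model name, file path, device and batch size are illustrative and assume a CUDA-capable machine with ffmpeg installed):

import whisperx

# illustrative values; "audio.wav" and "large-v3" are assumptions, not part of the change
model = whisperx.load_model(
    "large-v3",
    device="cuda",
    compute_type="float16",
    vad_method="silero",                         # or "pyannote", the default
    vad_options={"chunk_size": 20, "vad_onset": 0.5},
    language="en",                               # skip per-file language detection
)
audio = whisperx.load_audio("audio.wav")
result = model.transcribe(audio, batch_size=8)
print(result["segments"][0]["text"])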
whisperx/assets/pytorch_model.bin (new binary file, not shown)
whisperx/audio.py
@@ -1,18 +1,17 @@
 import os
+import subprocess
 from functools import lru_cache
 from typing import Optional, Union

-import ffmpeg
 import numpy as np
 import torch
 import torch.nn.functional as F

-from .utils import exact_div
+from whisperx.utils import exact_div

 # hard-coded audio hyperparameters
 SAMPLE_RATE = 16000
 N_FFT = 400
-N_MELS = 80
 HOP_LENGTH = 160
 CHUNK_LENGTH = 30
 N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE  # 480000 samples in a 30-second chunk
@@ -23,7 +22,7 @@ FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH)  # 10ms per audio frame
 TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN)  # 20ms per audio token


-def load_audio(file: str, sr: int = SAMPLE_RATE):
+def load_audio(file: str, sr: int = SAMPLE_RATE) -> np.ndarray:
     """
     Open an audio file and read as mono waveform, resampling as necessary

@@ -40,14 +39,27 @@ def load_audio(file: str, sr: int = SAMPLE_RATE):
     A NumPy array containing the audio waveform, in float32 dtype.
     """
     try:
-        # This launches a subprocess to decode audio while down-mixing and resampling as necessary.
-        # Requires the ffmpeg CLI and `ffmpeg-python` package to be installed.
-        out, _ = (
-            ffmpeg.input(file, threads=0)
-            .output("-", format="s16le", acodec="pcm_s16le", ac=1, ar=sr)
-            .run(cmd=["ffmpeg", "-nostdin"], capture_stdout=True, capture_stderr=True)
-        )
-    except ffmpeg.Error as e:
+        # Launches a subprocess to decode audio while down-mixing and resampling as necessary.
+        # Requires the ffmpeg CLI to be installed.
+        cmd = [
+            "ffmpeg",
+            "-nostdin",
+            "-threads",
+            "0",
+            "-i",
+            file,
+            "-f",
+            "s16le",
+            "-ac",
+            "1",
+            "-acodec",
+            "pcm_s16le",
+            "-ar",
+            str(sr),
+            "-",
+        ]
+        out = subprocess.run(cmd, capture_output=True, check=True).stdout
+    except subprocess.CalledProcessError as e:
         raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e

     return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0
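A small sketch of exercising the subprocess-based decoder above (the WAV path is illustrative; ffmpeg must be on PATH):

import numpy as np
from whisperx.audio import load_audio, SAMPLE_RATE

# "sample.wav" is an assumed path; any format ffmpeg can decode works
waveform = load_audio("sample.wav")           # float32 mono at 16 kHz
print(waveform.dtype, waveform.shape)
print(f"duration: {waveform.shape[0] / SAMPLE_RATE:.2f}s")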
@@ -80,7 +92,7 @@ def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):


 @lru_cache(maxsize=None)
-def mel_filters(device, n_mels: int = N_MELS) -> torch.Tensor:
+def mel_filters(device, n_mels: int) -> torch.Tensor:
     """
     load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
     Allows decoupling librosa dependency; saved using:
@@ -90,7 +102,7 @@ def mel_filters(device, n_mels: int = N_MELS) -> torch.Tensor:
         mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
     )
     """
-    assert n_mels == 80, f"Unsupported n_mels: {n_mels}"
+    assert n_mels in [80, 128], f"Unsupported n_mels: {n_mels}"
     with np.load(
         os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz")
     ) as f:
@@ -99,7 +111,7 @@ def mel_filters(device, n_mels: int = N_MELS) -> torch.Tensor:
 def log_mel_spectrogram(
     audio: Union[str, np.ndarray, torch.Tensor],
-    n_mels: int = N_MELS,
+    n_mels: int,
     padding: int = 0,
     device: Optional[Union[str, torch.device]] = None,
 ):
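Since n_mels no longer has a default, callers now pass it explicitly; a hedged sketch (it assumes the bundled mel_filters.npz provides both an 80- and a 128-bin bank, as the relaxed assert suggests; 128 mels correspond to the large-v3 style front end):

import numpy as np
from whisperx.audio import log_mel_spectrogram, N_SAMPLES

audio = np.zeros(N_SAMPLES, dtype=np.float32)    # 30 s of silence as a stand-in
mel80 = log_mel_spectrogram(audio, n_mels=80)    # classic Whisper feature size
mel128 = log_mel_spectrogram(audio, n_mels=128)  # large-v3 style feature size
print(mel80.shape, mel128.shape)                 # (80, 3000) and (128, 3000)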
whisperx/conjunctions.py (new file, 47 lines)
@@ -0,0 +1,47 @@
# conjunctions.py

from typing import Set


conjunctions_by_language = {
    'en': {'and', 'whether', 'or', 'as', 'but', 'so', 'for', 'nor', 'which', 'yet', 'although', 'since', 'unless', 'when', 'while', 'because', 'if', 'how', 'that', 'than', 'who', 'where', 'what', 'near', 'before', 'after', 'across', 'through', 'until', 'once', 'whereas', 'even', 'both', 'either', 'neither', 'though'},
    'fr': {'et', 'ou', 'mais', 'parce', 'bien', 'pendant', 'quand', 'où', 'comme', 'si', 'que', 'avant', 'après', 'aussitôt', 'jusqu’à', 'à', 'malgré', 'donc', 'tant', 'puisque', 'ni', 'soit', 'bien', 'encore', 'dès', 'lorsque'},
    'de': {'und', 'oder', 'aber', 'weil', 'obwohl', 'während', 'wenn', 'wo', 'wie', 'dass', 'bevor', 'nachdem', 'sobald', 'bis', 'außer', 'trotzdem', 'also', 'sowie', 'indem', 'weder', 'sowohl', 'zwar', 'jedoch'},
    'es': {'y', 'o', 'pero', 'porque', 'aunque', 'sin', 'mientras', 'cuando', 'donde', 'como', 'si', 'que', 'antes', 'después', 'tan', 'hasta', 'a', 'a', 'por', 'ya', 'ni', 'sino'},
    'it': {'e', 'o', 'ma', 'perché', 'anche', 'mentre', 'quando', 'dove', 'come', 'se', 'che', 'prima', 'dopo', 'appena', 'fino', 'a', 'nonostante', 'quindi', 'poiché', 'né', 'ossia', 'cioè'},
    'ja': {'そして', 'または', 'しかし', 'なぜなら', 'もし', 'それとも', 'だから', 'それに', 'なのに', 'そのため', 'かつ', 'それゆえに', 'ならば', 'もしくは', 'ため'},
    'zh': {'和', '或', '但是', '因为', '任何', '也', '虽然', '而且', '所以', '如果', '除非', '尽管', '既然', '即使', '只要', '直到', '然后', '因此', '不但', '而是', '不过'},
    'nl': {'en', 'of', 'maar', 'omdat', 'hoewel', 'terwijl', 'wanneer', 'waar', 'zoals', 'als', 'dat', 'voordat', 'nadat', 'zodra', 'totdat', 'tenzij', 'ondanks', 'dus', 'zowel', 'noch', 'echter', 'toch'},
    'uk': {'та', 'або', 'але', 'тому', 'хоча', 'поки', 'бо', 'коли', 'де', 'як', 'якщо', 'що', 'перш', 'після', 'доки', 'незважаючи', 'тому', 'ані'},
    'pt': {'e', 'ou', 'mas', 'porque', 'embora', 'enquanto', 'quando', 'onde', 'como', 'se', 'que', 'antes', 'depois', 'assim', 'até', 'a', 'apesar', 'portanto', 'já', 'pois', 'nem', 'senão'},
    'ar': {'و', 'أو', 'لكن', 'لأن', 'مع', 'بينما', 'عندما', 'حيث', 'كما', 'إذا', 'الذي', 'قبل', 'بعد', 'فور', 'حتى', 'إلا', 'رغم', 'لذلك', 'بما'},
    'cs': {'a', 'nebo', 'ale', 'protože', 'ačkoli', 'zatímco', 'když', 'kde', 'jako', 'pokud', 'že', 'než', 'poté', 'jakmile', 'dokud', 'pokud ne', 'navzdory', 'tak', 'stejně', 'ani', 'tudíž'},
    'ru': {'и', 'или', 'но', 'потому', 'хотя', 'пока', 'когда', 'где', 'как', 'если', 'что', 'перед', 'после', 'несмотря', 'таким', 'также', 'ни', 'зато'},
    'pl': {'i', 'lub', 'ale', 'ponieważ', 'chociaż', 'podczas', 'kiedy', 'gdzie', 'jak', 'jeśli', 'że', 'zanim', 'po', 'jak tylko', 'dopóki', 'chyba', 'pomimo', 'więc', 'tak', 'ani', 'czyli'},
    'hu': {'és', 'vagy', 'de', 'mert', 'habár', 'míg', 'amikor', 'ahol', 'ahogy', 'ha', 'hogy', 'mielőtt', 'miután', 'amint', 'amíg', 'hacsak', 'ellenére', 'tehát', 'úgy', 'sem', 'vagyis'},
    'fi': {'ja', 'tai', 'mutta', 'koska', 'vaikka', 'kun', 'missä', 'kuten', 'jos', 'että', 'ennen', 'sen jälkeen', 'heti', 'kunnes', 'ellei', 'huolimatta', 'siis', 'sekä', 'eikä', 'vaan'},
    'fa': {'و', 'یا', 'اما', 'چون', 'اگرچه', 'در حالی', 'وقتی', 'کجا', 'چگونه', 'اگر', 'که', 'قبل', 'پس', 'به محض', 'تا زمانی', 'مگر', 'با وجود', 'پس', 'همچنین', 'نه'},
    'el': {'και', 'ή', 'αλλά', 'επειδή', 'αν', 'ενώ', 'όταν', 'όπου', 'όπως', 'αν', 'που', 'προτού', 'αφού', 'μόλις', 'μέχρι', 'εκτός', 'παρά', 'έτσι', 'όπως', 'ούτε', 'δηλαδή'},
    'tr': {'ve', 'veya', 'ama', 'çünkü', 'her ne', 'iken', 'nerede', 'nasıl', 'eğer', 'ki', 'önce', 'sonra', 'hemen', 'kadar', 'rağmen', 'hem', 'ne', 'yani'},
    'da': {'og', 'eller', 'men', 'fordi', 'selvom', 'mens', 'når', 'hvor', 'som', 'hvis', 'at', 'før', 'efter', 'indtil', 'medmindre', 'således', 'ligesom', 'hverken', 'altså'},
    'he': {'ו', 'או', 'אבל', 'כי', 'אף', 'בזמן', 'כאשר', 'היכן', 'כיצד', 'אם', 'ש', 'לפני', 'אחרי', 'ברגע', 'עד', 'אלא', 'למרות', 'לכן', 'כמו', 'לא', 'אז'},
    'vi': {'và', 'hoặc', 'nhưng', 'bởi', 'mặc', 'trong', 'khi', 'ở', 'như', 'nếu', 'rằng', 'trước', 'sau', 'ngay', 'cho', 'trừ', 'mặc', 'vì', 'giống', 'cũng', 'tức'},
    'ko': {'그리고', '또는','그런데','그래도', '이나', '결국', '마지막으로', '마찬가지로', '반면에', '아니면', '거나', '또는', '그럼에도', '그렇기', '때문에', '덧붙이자면', '게다가', '그러나', '고', '그래서', '랑', '한다면', '하지만', '무엇', '왜냐하면', '비록', '동안', '언제', '어디서', '어떻게', '만약', '그', '전에', '후에', '즉시', '까지', '아니라면', '불구하고', '따라서', '같은', '도'},
    'ur': {'اور', 'یا', 'مگر', 'کیونکہ', 'اگرچہ', 'جبکہ', 'جب', 'کہاں', 'کس طرح', 'اگر', 'کہ', 'سے پہلے', 'کے بعد', 'جیسے ہی', 'تک', 'اگر نہیں تو', 'کے باوجود', 'اس لئے', 'جیسے', 'نہ'},
    'hi': {'और', 'या', 'पर', 'तो', 'न', 'फिर', 'हालांकि', 'चूंकि', 'अगर', 'कैसे', 'वह', 'से', 'जो', 'जहां', 'क्या', 'नजदीक', 'पहले', 'बाद', 'के', 'पार', 'माध्यम', 'तक', 'एक', 'जबकि', 'यहां', 'तक', 'दोनों', 'या', 'न', 'हालांकि'}

}

commas_by_language = {
    'ja': '、',
    'zh': ',',
    'fa': '،',
    'ur': '،'
}

def get_conjunctions(lang_code: str) -> Set[str]:
    return conjunctions_by_language.get(lang_code, set())


def get_comma(lang_code: str) -> str:
    return commas_by_language.get(lang_code, ",")
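A quick illustration of how these helpers behave, including the fallback for a language code that is not listed:

from whisperx.conjunctions import get_conjunctions, get_comma

print('and' in get_conjunctions('en'))          # True
print(get_comma('ja'))                          # '、'
print(get_conjunctions('xx'), get_comma('xx'))  # set() and ',' fallbacks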
whisperx/diarize.py
@@ -4,10 +4,14 @@ from pyannote.audio import Pipeline
 from typing import Optional, Union
 import torch

+from whisperx.audio import load_audio, SAMPLE_RATE
+from whisperx.types import TranscriptionResult, AlignedTranscriptionResult


 class DiarizationPipeline:
     def __init__(
         self,
-        model_name="pyannote/speaker-diarization@2.1",
+        model_name="pyannote/speaker-diarization-3.1",
         use_auth_token=None,
         device: Optional[Union[str, torch.device]] = "cpu",
     ):
@@ -15,16 +19,31 @@ class DiarizationPipeline:
         device = torch.device(device)
         self.model = Pipeline.from_pretrained(model_name, use_auth_token=use_auth_token).to(device)

-    def __call__(self, audio, min_speakers=None, max_speakers=None):
-        segments = self.model(audio, min_speakers=min_speakers, max_speakers=max_speakers)
-        diarize_df = pd.DataFrame(segments.itertracks(yield_label=True))
-        diarize_df['start'] = diarize_df[0].apply(lambda x: x.start)
-        diarize_df['end'] = diarize_df[0].apply(lambda x: x.end)
-        diarize_df.rename(columns={2: "speaker"}, inplace=True)
+    def __call__(
+        self,
+        audio: Union[str, np.ndarray],
+        num_speakers: Optional[int] = None,
+        min_speakers: Optional[int] = None,
+        max_speakers: Optional[int] = None,
+    ):
+        if isinstance(audio, str):
+            audio = load_audio(audio)
+        audio_data = {
+            'waveform': torch.from_numpy(audio[None, :]),
+            'sample_rate': SAMPLE_RATE
+        }
+        segments = self.model(audio_data, num_speakers = num_speakers, min_speakers=min_speakers, max_speakers=max_speakers)
+        diarize_df = pd.DataFrame(segments.itertracks(yield_label=True), columns=['segment', 'label', 'speaker'])
+        diarize_df['start'] = diarize_df['segment'].apply(lambda x: x.start)
+        diarize_df['end'] = diarize_df['segment'].apply(lambda x: x.end)
         return diarize_df


-def assign_word_speakers(diarize_df, transcript_result, fill_nearest=False):
+def assign_word_speakers(
+    diarize_df: pd.DataFrame,
+    transcript_result: Union[AlignedTranscriptionResult, TranscriptionResult],
+    fill_nearest=False,
+) -> dict:
     transcript_segments = transcript_result["segments"]
     for seg in transcript_segments:
         # assign speaker to segment (if any)
@@ -60,7 +79,7 @@ def assign_word_speakers(diarize_df, transcript_result, fill_nearest=False):


 class Segment:
-    def __init__(self, start, end, speaker=None):
+    def __init__(self, start:int, end:int, speaker:Optional[str]=None):
         self.start = start
         self.end = end
         self.speaker = speaker
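A hedged sketch of the updated diarization call path; it assumes a Hugging Face token with access to the gated pyannote/speaker-diarization-3.1 model and an illustrative audio file:

import os
import whisperx
from whisperx.diarize import DiarizationPipeline, assign_word_speakers

hf_token = os.environ.get("HF_TOKEN")        # assumed to grant access to the gated pyannote model
diarize = DiarizationPipeline(use_auth_token=hf_token, device="cuda")  # or "cpu"

audio = whisperx.load_audio("meeting.wav")   # a plain str path is now also accepted directly
diarize_df = diarize(audio, min_speakers=2, max_speakers=4)
print(diarize_df[["start", "end", "speaker"]].head())

# downstream, speakers are merged into a transcription produced earlier in the pipeline:
# result = assign_word_speakers(diarize_df, result)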
whisperx/transcribe.py
@@ -1,17 +1,27 @@
 import argparse
 import gc
 import os
+import sys
 import warnings
+import importlib.metadata
+import platform

 import numpy as np
 import torch

-from .alignment import align, load_align_model
-from .asr import load_model
-from .audio import load_audio
-from .diarize import DiarizationPipeline, assign_word_speakers
-from .utils import (LANGUAGES, TO_LANGUAGE_CODE, get_writer, optional_float,
-                    optional_int, str2bool)
+from whisperx.alignment import align, load_align_model
+from whisperx.asr import load_model
+from whisperx.audio import load_audio
+from whisperx.diarize import DiarizationPipeline, assign_word_speakers
+from whisperx.types import AlignedTranscriptionResult, TranscriptionResult
+from whisperx.utils import (
+    LANGUAGES,
+    TO_LANGUAGE_CODE,
+    get_writer,
+    optional_float,
+    optional_int,
+    str2bool,
+)


 def cli():
@@ -19,13 +29,15 @@ def cli():
     parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
     parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe")
     parser.add_argument("--model", default="small", help="name of the Whisper model to use")
+    parser.add_argument("--model_cache_only", type=str2bool, default=False, help="If True, will not attempt to download models, instead using cached models from --model_dir")
     parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
     parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
-    parser.add_argument("--batch_size", default=8, type=int, help="device to use for PyTorch inference")
+    parser.add_argument("--device_index", default=0, type=int, help="device index to use for FasterWhisper inference")
+    parser.add_argument("--batch_size", default=8, type=int, help="the preferred batch size for inference")
     parser.add_argument("--compute_type", default="float16", type=str, choices=["float16", "float32", "int8"], help="compute type for computation")

     parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
-    parser.add_argument("--output_format", "-f", type=str, default="all", choices=["all", "srt", "vtt", "txt", "tsv", "json"], help="format of the output file; if not specified, all available formats will be produced")
+    parser.add_argument("--output_format", "-f", type=str, default="all", choices=["all", "srt", "vtt", "txt", "tsv", "json", "aud"], help="format of the output file; if not specified, all available formats will be produced")
     parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")

     parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')")
@@ -38,8 +50,10 @@ def cli():
     parser.add_argument("--return_char_alignments", action='store_true', help="Return character-level alignments in the output json file")

     # vad params
+    parser.add_argument("--vad_method", type=str, default="pyannote", choices=["pyannote", "silero"], help="VAD method to be used")
     parser.add_argument("--vad_onset", type=float, default=0.500, help="Onset threshold for VAD (see pyannote.audio), reduce this if speech is not being detected")
     parser.add_argument("--vad_offset", type=float, default=0.363, help="Offset threshold for VAD (see pyannote.audio), reduce this if speech is not being detected.")
+    parser.add_argument("--chunk_size", type=int, default=30, help="Chunk size for merging VAD segments. Default is 30, reduce this if the chunk is too long.")

     # diarization params
     parser.add_argument("--diarize", action="store_true", help="Apply diarization to assign speaker labels to each segment/word")
@@ -49,10 +63,12 @@ def cli():
     parser.add_argument("--temperature", type=float, default=0, help="temperature to use for sampling")
     parser.add_argument("--best_of", type=optional_int, default=5, help="number of candidates when sampling with non-zero temperature")
     parser.add_argument("--beam_size", type=optional_int, default=5, help="number of beams in beam search, only applicable when temperature is zero")
-    parser.add_argument("--patience", type=float, default=None, help="optional patience value to use in beam decoding, as in https://arxiv.org/abs/2204.05424, the default (1.0) is equivalent to conventional beam search")
-    parser.add_argument("--length_penalty", type=float, default=None, help="optional token length penalty coefficient (alpha) as in https://arxiv.org/abs/1609.08144, uses simple length normalization by default")
+    parser.add_argument("--patience", type=float, default=1.0, help="optional patience value to use in beam decoding, as in https://arxiv.org/abs/2204.05424, the default (1.0) is equivalent to conventional beam search")
+    parser.add_argument("--length_penalty", type=float, default=1.0, help="optional token length penalty coefficient (alpha) as in https://arxiv.org/abs/1609.08144, uses simple length normalization by default")

     parser.add_argument("--suppress_tokens", type=str, default="-1", help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations")
+    parser.add_argument("--suppress_numerals", action="store_true", help="whether to suppress numeric symbols and currency symbols during sampling, since wav2vec2 cannot align them correctly")

     parser.add_argument("--initial_prompt", type=str, default=None, help="optional text to provide as a prompt for the first window.")
     parser.add_argument("--condition_on_previous_text", type=str2bool, default=False, help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop")
     parser.add_argument("--fp16", type=str2bool, default=True, help="whether to perform inference in fp16; True by default")
@@ -63,22 +79,30 @@ def cli():
     parser.add_argument("--no_speech_threshold", type=optional_float, default=0.6, help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence")

     parser.add_argument("--max_line_width", type=optional_int, default=None, help="(not possible with --no_align) the maximum number of characters in a line before breaking the line")
-    parser.add_argument("--max_line_count", type=optional_int, default=None, help="(requires --no_align) the maximum number of lines in a segment")
-    parser.add_argument("--highlight_words", type=str2bool, default=False, help="(requires --word_timestamps True) underline each word as it is spoken in srt and vtt")
+    parser.add_argument("--max_line_count", type=optional_int, default=None, help="(not possible with --no_align) the maximum number of lines in a segment")
+    parser.add_argument("--highlight_words", type=str2bool, default=False, help="(not possible with --no_align) underline each word as it is spoken in srt and vtt")
     parser.add_argument("--segment_resolution", type=str, default="sentence", choices=["sentence", "chunk"], help="(not possible with --no_align) the maximum number of characters in a line before breaking the line")

     parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")

     parser.add_argument("--hf_token", type=str, default=None, help="Hugging Face Access Token to access PyAnnote gated models")

+    parser.add_argument("--print_progress", type=str2bool, default = False, help = "if True, progress will be printed in transcribe() and align() methods.")
+    parser.add_argument("--version", "-V", action="version", version=f"%(prog)s {importlib.metadata.version('whisperx')}",help="Show whisperx version information and exit")
+    parser.add_argument("--python-version", "-P", action="version", version=f"Python {platform.python_version()} ({platform.python_implementation()})",help="Show python version information and exit")
     # fmt: on

     args = parser.parse_args().__dict__
     model_name: str = args.pop("model")
     batch_size: int = args.pop("batch_size")
+    model_dir: str = args.pop("model_dir")
+    model_cache_only: bool = args.pop("model_cache_only")
     output_dir: str = args.pop("output_dir")
     output_format: str = args.pop("output_format")
     device: str = args.pop("device")
+    device_index: int = args.pop("device_index")
     compute_type: str = args.pop("compute_type")
+    verbose: bool = args.pop("verbose")

     # model_flush: bool = args.pop("model_flush")
     os.makedirs(output_dir, exist_ok=True)
@@ -86,23 +110,42 @@ def cli():
     align_model: str = args.pop("align_model")
     interpolate_method: str = args.pop("interpolate_method")
     no_align: bool = args.pop("no_align")
+    task: str = args.pop("task")
+    if task == "translate":
+        # translation cannot be aligned
+        no_align = True
+
     return_char_alignments: bool = args.pop("return_char_alignments")

     hf_token: str = args.pop("hf_token")
+    vad_method: str = args.pop("vad_method")
     vad_onset: float = args.pop("vad_onset")
     vad_offset: float = args.pop("vad_offset")

+    chunk_size: int = args.pop("chunk_size")

     diarize: bool = args.pop("diarize")
     min_speakers: int = args.pop("min_speakers")
     max_speakers: int = args.pop("max_speakers")
+    print_progress: bool = args.pop("print_progress")

+    if args["language"] is not None:
+        args["language"] = args["language"].lower()
+        if args["language"] not in LANGUAGES:
+            if args["language"] in TO_LANGUAGE_CODE:
+                args["language"] = TO_LANGUAGE_CODE[args["language"]]
+            else:
+                raise ValueError(f"Unsupported language: {args['language']}")

-    if model_name.endswith(".en") and args["language"] not in {"en", "English"}:
+    if model_name.endswith(".en") and args["language"] != "en":
         if args["language"] is not None:
             warnings.warn(
-                f"{model_name} is an English-only model but receipted '{args['language']}'; using English instead."
+                f"{model_name} is an English-only model but received '{args['language']}'; using English instead."
             )
         args["language"] = "en"
+    align_language = (
+        args["language"] if args["language"] is not None else "en"
+    )  # default to loading english if not specified

     temperature = args.pop("temperature")
     if (increment := args.pop("temperature_increment_on_fallback")) is not None:
@@ -110,8 +153,10 @@ def cli():
     else:
         temperature = [temperature]

+    faster_whisper_threads = 4
     if (threads := args.pop("threads")) > 0:
         torch.set_num_threads(threads)
+        faster_whisper_threads = threads

     asr_options = {
         "beam_size": args.pop("beam_size"),
@@ -123,6 +168,8 @@ def cli():
         "no_speech_threshold": args.pop("no_speech_threshold"),
         "condition_on_previous_text": False,
         "initial_prompt": args.pop("initial_prompt"),
+        "suppress_tokens": [int(x) for x in args.pop("suppress_tokens").split(",")],
+        "suppress_numerals": args.pop("suppress_numerals"),
     }

     writer = get_writer(output_format, output_dir)
@@ -130,7 +177,7 @@ def cli():
     if no_align:
         for option in word_options:
             if args[option]:
-                parser.error(f"--{option} requires --word_timestamps True")
+                parser.error(f"--{option} not possible with --no_align")
     if args["max_line_count"] and not args["max_line_width"]:
         warnings.warn("--max_line_count has no effect without --max_line_width")
     writer_args = {arg: args.pop(arg) for arg in word_options}
@@ -139,13 +186,36 @@ def cli():
     results = []
     tmp_results = []
     # model = load_model(model_name, device=device, download_root=model_dir)
-    model = load_model(model_name, device=device, compute_type=compute_type, language=args['language'], asr_options=asr_options, vad_options={"vad_onset": vad_onset, "vad_offset": vad_offset},)
+    model = load_model(
+        model_name,
+        device=device,
+        device_index=device_index,
+        download_root=model_dir,
+        compute_type=compute_type,
+        language=args["language"],
+        asr_options=asr_options,
+        vad_method=vad_method,
+        vad_options={
+            "chunk_size": chunk_size,
+            "vad_onset": vad_onset,
+            "vad_offset": vad_offset,
+        },
+        task=task,
+        local_files_only=model_cache_only,
+        threads=faster_whisper_threads,
+    )

     for audio_path in args.pop("audio"):
         audio = load_audio(audio_path)
         # >> VAD & ASR
         print(">>Performing transcription...")
-        result = model.transcribe(audio, batch_size=batch_size)
+        result: TranscriptionResult = model.transcribe(
+            audio,
+            batch_size=batch_size,
+            chunk_size=chunk_size,
+            print_progress=print_progress,
+            verbose=verbose,
+        )
         results.append((result, audio_path))

     # Unload Whisper and VAD
@@ -157,8 +227,9 @@ def cli():
     if not no_align:
         tmp_results = results
         results = []
-        align_language = args["language"] if args["language"] is not None else "en" # default to loading english if not specified
-        align_model, align_metadata = load_align_model(align_language, device, model_name=align_model)
+        align_model, align_metadata = load_align_model(
+            align_language, device, model_name=align_model
+        )
         for result, audio_path in tmp_results:
             # >> Align
             if len(tmp_results) > 1:
@@ -170,10 +241,23 @@ def cli():
             if align_model is not None and len(result["segments"]) > 0:
                 if result.get("language", "en") != align_metadata["language"]:
                     # load new language
-                    print(f"New language found ({result['language']})! Previous was ({align_metadata['language']}), loading new alignment model for new language...")
-                    align_model, align_metadata = load_align_model(result["language"], device)
+                    print(
+                        f"New language found ({result['language']})! Previous was ({align_metadata['language']}), loading new alignment model for new language..."
+                    )
+                    align_model, align_metadata = load_align_model(
+                        result["language"], device
+                    )
                 print(">>Performing alignment...")
-                result = align(result["segments"], align_model, align_metadata, input_audio, device, interpolate_method=interpolate_method, return_char_alignments=return_char_alignments)
+                result: AlignedTranscriptionResult = align(
+                    result["segments"],
+                    align_model,
+                    align_metadata,
+                    input_audio,
+                    device,
+                    interpolate_method=interpolate_method,
+                    return_char_alignments=return_char_alignments,
+                    print_progress=print_progress,
+                )

             results.append((result, audio_path))

@@ -185,18 +269,24 @@ def cli():
     # >> Diarize
     if diarize:
         if hf_token is None:
-            print("Warning, no --hf_token used, needs to be saved in environment variable, otherwise will throw error loading diarization model...")
+            print(
+                "Warning, no --hf_token used, needs to be saved in environment variable, otherwise will throw error loading diarization model..."
+            )
         tmp_results = results
         print(">>Performing diarization...")
         results = []
         diarize_model = DiarizationPipeline(use_auth_token=hf_token, device=device)
         for result, input_audio_path in tmp_results:
-            diarize_segments = diarize_model(input_audio_path, min_speakers=min_speakers, max_speakers=max_speakers)
+            diarize_segments = diarize_model(
+                input_audio_path, min_speakers=min_speakers, max_speakers=max_speakers
+            )
            result = assign_word_speakers(diarize_segments, result)
            results.append((result, input_audio_path))
     # >> Write
     for result, audio_path in results:
+        result["language"] = align_language
         writer(result, audio_path, writer_args)


 if __name__ == "__main__":
     cli()
whisperx/types.py (new file, 69 lines)
@@ -0,0 +1,69 @@
from typing import TypedDict, Optional, List, Tuple


class SingleWordSegment(TypedDict):
    """
    A single word of a speech.
    """
    word: str
    start: float
    end: float
    score: float

class SingleCharSegment(TypedDict):
    """
    A single char of a speech.
    """
    char: str
    start: float
    end: float
    score: float


class SingleSegment(TypedDict):
    """
    A single segment (up to multiple sentences) of a speech.
    """

    start: float
    end: float
    text: str


class SegmentData(TypedDict):
    """
    Temporary processing data used during alignment.
    Contains cleaned and preprocessed data for each segment.
    """
    clean_char: List[str]  # Cleaned characters that exist in model dictionary
    clean_cdx: List[int]  # Original indices of cleaned characters
    clean_wdx: List[int]  # Indices of words containing valid characters
    sentence_spans: List[Tuple[int, int]]  # Start and end indices of sentences


class SingleAlignedSegment(TypedDict):
    """
    A single segment (up to multiple sentences) of a speech with word alignment.
    """

    start: float
    end: float
    text: str
    words: List[SingleWordSegment]
    chars: Optional[List[SingleCharSegment]]


class TranscriptionResult(TypedDict):
    """
    A list of segments and word segments of a speech.
    """
    segments: List[SingleSegment]
    language: str


class AlignedTranscriptionResult(TypedDict):
    """
    A list of segments and word segments of a speech.
    """
    segments: List[SingleAlignedSegment]
    word_segments: List[SingleWordSegment]
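These TypedDicts describe plain dicts, so a result can be built or type-checked without any extra machinery; a small sketch with invented timings:

from whisperx.types import SingleSegment, TranscriptionResult

segment: SingleSegment = {"start": 0.0, "end": 2.4, "text": " Hello world."}
result: TranscriptionResult = {"segments": [segment], "language": "en"}

assert result["language"] == "en"
assert all(seg["end"] >= seg["start"] for seg in result["segments"])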
whisperx/utils.py
@@ -105,6 +105,7 @@ LANGUAGES = {
     "ba": "bashkir",
     "jw": "javanese",
     "su": "sundanese",
+    "yue": "cantonese",
 }

 # language code lookup by name, with a few language aliases
@@ -123,6 +124,7 @@ TO_LANGUAGE_CODE = {
     "castilian": "es",
 }

+LANGUAGES_WITHOUT_SPACES = ["ja", "zh"]

 system_encoding = sys.getdefaultencoding()

@@ -212,7 +214,12 @@ class WriteTXT(ResultWriter):

     def write_result(self, result: dict, file: TextIO, options: dict):
         for segment in result["segments"]:
-            print(segment["text"].strip(), file=file, flush=True)
+            speaker = segment.get("speaker")
+            text = segment["text"].strip()
+            if speaker is not None:
+                print(f"[{speaker}]: {text}", file=file, flush=True)
+            else:
+                print(text, file=file, flush=True)


 class SubtitlesWriter(ResultWriter):
@@ -226,12 +233,15 @@ class SubtitlesWriter(ResultWriter):
         max_line_width = 1000 if raw_max_line_width is None else raw_max_line_width
         preserve_segments = max_line_count is None or raw_max_line_width is None

+        if len(result["segments"]) == 0:
+            return
+
         def iterate_subtitles():
             line_len = 0
             line_count = 1
             # the next subtitle to yield (a list of word timings with whitespace)
             subtitle: list[dict] = []
-            times = []
+            times: list[tuple] = []
             last = result["segments"][0]["start"]
             for segment in result["segments"]:
                 for i, original_timing in enumerate(segment["words"]):
@@ -277,7 +287,10 @@ class SubtitlesWriter(ResultWriter):
             sstart, ssend, speaker = _[0]
             subtitle_start = self.format_timestamp(sstart)
             subtitle_end = self.format_timestamp(ssend)
-            subtitle_text = " ".join([word["word"] for word in subtitle])
+            if result["language"] in LANGUAGES_WITHOUT_SPACES:
+                subtitle_text = "".join([word["word"] for word in subtitle])
+            else:
+                subtitle_text = " ".join([word["word"] for word in subtitle])
             has_timing = any(["start" in word for word in subtitle])

             # add [$SPEAKER_ID]: to each subtitle if speaker is available
@@ -293,7 +306,7 @@ class SubtitlesWriter(ResultWriter):
                         start = self.format_timestamp(this_word["start"])
                         end = self.format_timestamp(this_word["end"])
                         if last != start:
-                            yield last, start, subtitle_text
+                            yield last, start, prefix + subtitle_text

                         yield start, end, prefix + " ".join(
                             [
@@ -365,12 +378,34 @@ class WriteTSV(ResultWriter):
         print(round(1000 * segment["end"]), file=file, end="\t")
         print(segment["text"].strip().replace("\t", " "), file=file, flush=True)

+
+class WriteAudacity(ResultWriter):
+    """
+    Write a transcript to a text file that audacity can import as labels.
+    The extension used is "aud" to distinguish it from the txt file produced by WriteTXT.
+    Yet this is not an audacity project but only a label file!
+
+    Please note : Audacity uses seconds in timestamps not ms!
+    Also there is no header expected.
+
+    If speaker is provided it is prepended to the text between double square brackets [[]].
+    """
+
+    extension: str = "aud"
+
+    def write_result(self, result: dict, file: TextIO, options: dict):
+        ARROW = " "
+        for segment in result["segments"]:
+            print(segment["start"], file=file, end=ARROW)
+            print(segment["end"], file=file, end=ARROW)
+            print( ( ("[[" + segment["speaker"] + "]]") if "speaker" in segment else "") + segment["text"].strip().replace("\t", " "), file=file, flush=True)
+
+
 class WriteJSON(ResultWriter):
     extension: str = "json"

     def write_result(self, result: dict, file: TextIO, options: dict):
-        json.dump(result, file)
+        json.dump(result, file, ensure_ascii=False)


 def get_writer(
@@ -383,6 +418,9 @@ def get_writer(
         "tsv": WriteTSV,
         "json": WriteJSON,
     }
+    optional_writers = {
+        "aud": WriteAudacity,
+    }

     if output_format == "all":
         all_writers = [writer(output_dir) for writer in writers.values()]
@@ -393,6 +431,8 @@ def get_writer(

         return write_all

+    if output_format in optional_writers:
+        return optional_writers[output_format](output_dir)
     return writers[output_format](output_dir)


 def interpolate_nans(x, method='nearest'):
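A hedged sketch of the new optional "aud" writer resolved through get_writer; the segments, speaker label and file name are invented for illustration:

from whisperx.utils import get_writer

result = {
    "segments": [
        {"start": 0.0, "end": 1.5, "text": "hello there", "speaker": "SPEAKER_00"},
        {"start": 1.5, "end": 2.0, "text": "hi"},
    ],
    "language": "en",
}

writer = get_writer("aud", ".")   # looked up in optional_writers rather than writers
writer(result, "example.wav", {"max_line_width": None, "max_line_count": None, "highlight_words": False})
# writes example.aud with one label line per segment: start, end, optional [[speaker]] + text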
whisperx/vads/__init__.py (new file, 3 lines)
@@ -0,0 +1,3 @@
from whisperx.vads.pyannote import Pyannote as Pyannote
from whisperx.vads.silero import Silero as Silero
from whisperx.vads.vad import Vad as Vad
@ -1,52 +1,41 @@
|
|||||||
import hashlib
|
|
||||||
import os
|
import os
|
||||||
import urllib
|
from typing import Callable, Text, Union
|
||||||
from typing import Callable, Optional, Text, Union
|
from typing import Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
|
||||||
import torch
|
import torch
|
||||||
from pyannote.audio import Model
|
from pyannote.audio import Model
|
||||||
from pyannote.audio.core.io import AudioFile
|
from pyannote.audio.core.io import AudioFile
|
||||||
from pyannote.audio.pipelines import VoiceActivityDetection
|
from pyannote.audio.pipelines import VoiceActivityDetection
|
||||||
from pyannote.audio.pipelines.utils import PipelineModel
|
from pyannote.audio.pipelines.utils import PipelineModel
|
||||||
from pyannote.core import Annotation, Segment, SlidingWindowFeature
|
from pyannote.core import Annotation, SlidingWindowFeature
|
||||||
from tqdm import tqdm
|
from pyannote.core import Segment
|
||||||
|
|
||||||
from .diarize import Segment as SegmentX
|
from whisperx.diarize import Segment as SegmentX
|
||||||
|
from whisperx.vads.vad import Vad
|
||||||
|
|
||||||
VAD_SEGMENTATION_URL = "https://whisperx.s3.eu-west-2.amazonaws.com/model_weights/segmentation/0b5b3216d60a2d32fc086b47ea8c67589aaeb26b7e07fcbe620d6d0b83e209ea/pytorch_model.bin"
|
|
||||||
|
|
||||||
def load_vad_model(device, vad_onset=0.500, vad_offset=0.363, use_auth_token=None, model_fp=None):
|
def load_vad_model(device, vad_onset=0.500, vad_offset=0.363, use_auth_token=None, model_fp=None):
|
||||||
model_dir = torch.hub._get_torch_home()
|
model_dir = torch.hub._get_torch_home()
|
||||||
|
|
||||||
|
main_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||||
|
|
||||||
os.makedirs(model_dir, exist_ok = True)
|
os.makedirs(model_dir, exist_ok = True)
|
||||||
if model_fp is None:
|
if model_fp is None:
|
||||||
model_fp = os.path.join(model_dir, "whisperx-vad-segmentation.bin")
|
# Dynamically resolve the path to the model file
|
||||||
|
model_fp = os.path.join(main_dir, "assets", "pytorch_model.bin")
|
||||||
|
model_fp = os.path.abspath(model_fp) # Ensure the path is absolute
|
||||||
|
else:
|
||||||
|
model_fp = os.path.abspath(model_fp) # Ensure any provided path is absolute
|
||||||
|
|
||||||
|
# Check if the resolved model file exists
|
||||||
|
if not os.path.exists(model_fp):
|
||||||
|
raise FileNotFoundError(f"Model file not found at {model_fp}")
|
||||||
|
|
||||||
if os.path.exists(model_fp) and not os.path.isfile(model_fp):
|
if os.path.exists(model_fp) and not os.path.isfile(model_fp):
|
||||||
raise RuntimeError(f"{model_fp} exists and is not a regular file")
|
raise RuntimeError(f"{model_fp} exists and is not a regular file")
|
||||||
|
|
||||||
if not os.path.isfile(model_fp):
|
|
||||||
with urllib.request.urlopen(VAD_SEGMENTATION_URL) as source, open(model_fp, "wb") as output:
|
|
||||||
with tqdm(
|
|
||||||
total=int(source.info().get("Content-Length")),
|
|
||||||
ncols=80,
|
|
||||||
unit="iB",
|
|
||||||
unit_scale=True,
|
|
||||||
unit_divisor=1024,
|
|
||||||
) as loop:
|
|
||||||
while True:
|
|
||||||
buffer = source.read(8192)
|
|
||||||
if not buffer:
|
|
||||||
break
|
|
||||||
|
|
||||||
output.write(buffer)
|
|
||||||
loop.update(len(buffer))
|
|
||||||
|
|
||||||
model_bytes = open(model_fp, "rb").read()
|
model_bytes = open(model_fp, "rb").read()
|
||||||
if hashlib.sha256(model_bytes).hexdigest() != VAD_SEGMENTATION_URL.split('/')[-2]:
|
|
||||||
raise RuntimeError(
|
|
||||||
"Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model."
|
|
||||||
)
|
|
||||||
|
|
||||||
vad_model = Model.from_pretrained(model_fp, use_auth_token=use_auth_token)
|
vad_model = Model.from_pretrained(model_fp, use_auth_token=use_auth_token)
|
||||||
hyperparameters = {"onset": vad_onset,
|
hyperparameters = {"onset": vad_onset,
|
||||||
@ -92,14 +81,14 @@ class Binarize:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
onset: float = 0.5,
|
onset: float = 0.5,
|
||||||
offset: Optional[float] = None,
|
offset: Optional[float] = None,
|
||||||
min_duration_on: float = 0.0,
|
min_duration_on: float = 0.0,
|
||||||
min_duration_off: float = 0.0,
|
min_duration_off: float = 0.0,
|
||||||
pad_onset: float = 0.0,
|
pad_onset: float = 0.0,
|
||||||
pad_offset: float = 0.0,
|
pad_offset: float = 0.0,
|
||||||
max_duration: float = float('inf')
|
max_duration: float = float('inf')
|
||||||
):
|
):
|
||||||
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -142,13 +131,12 @@ class Binarize:
|
|||||||
is_active = k_scores[0] > self.onset
|
is_active = k_scores[0] > self.onset
|
||||||
curr_scores = [k_scores[0]]
|
curr_scores = [k_scores[0]]
|
||||||
curr_timestamps = [start]
|
curr_timestamps = [start]
|
||||||
|
t = start
|
||||||
for t, y in zip(timestamps[1:], k_scores[1:]):
|
for t, y in zip(timestamps[1:], k_scores[1:]):
|
||||||
# currently active
|
# currently active
|
||||||
if is_active:
|
if is_active:
|
||||||
curr_duration = t - start
|
curr_duration = t - start
|
||||||
if curr_duration > self.max_duration:
|
if curr_duration > self.max_duration:
|
||||||
# if curr_duration > 15:
|
|
||||||
# import pdb; pdb.set_trace()
|
|
||||||
search_after = len(curr_scores) // 2
|
search_after = len(curr_scores) // 2
|
||||||
# divide segment
|
# divide segment
|
||||||
min_score_div_idx = search_after + np.argmin(curr_scores[search_after:])
|
min_score_div_idx = search_after + np.argmin(curr_scores[search_after:])
|
||||||
@ -156,8 +144,8 @@ class Binarize:
|
|||||||
region = Segment(start - self.pad_onset, min_score_t + self.pad_offset)
|
region = Segment(start - self.pad_onset, min_score_t + self.pad_offset)
|
||||||
active[region, k] = label
|
active[region, k] = label
|
||||||
start = curr_timestamps[min_score_div_idx]
|
start = curr_timestamps[min_score_div_idx]
|
||||||
curr_scores = curr_scores[min_score_div_idx+1:]
|
curr_scores = curr_scores[min_score_div_idx + 1:]
|
||||||
curr_timestamps = curr_timestamps[min_score_div_idx+1:]
|
curr_timestamps = curr_timestamps[min_score_div_idx + 1:]
|
||||||
# switching from active to inactive
|
# switching from active to inactive
|
||||||
elif y < self.offset:
|
elif y < self.offset:
|
||||||
region = Segment(start - self.pad_onset, t + self.pad_offset)
|
region = Segment(start - self.pad_onset, t + self.pad_offset)
|
||||||
@@ -166,14 +154,14 @@ class Binarize:
                        is_active = False
                        curr_scores = []
                        curr_timestamps = []
+                    curr_scores.append(y)
+                    curr_timestamps.append(t)
                # currently inactive
                else:
                    # switching from inactive to active
                    if y > self.onset:
                        start = t
                        is_active = True
-                curr_scores.append(y)
-                curr_timestamps.append(t)

            # if active at the end, add final region
            if is_active:
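The hunks above amount to a hysteresis binarizer: a region opens when the score rises above onset, closes when it falls below offset, and regions longer than max_duration are split at the minimum score in their second half. A stripped-down, self-contained sketch of just the onset/offset part; this is not the pyannote-backed Binarize above, and hysteresis_regions, its arguments and the example values are illustrative:

import numpy as np

def hysteresis_regions(scores, frame_dur, onset=0.5, offset=0.35):
    # scores: 1-D array of per-frame speech probabilities; frame_dur: frame step in seconds.
    # Returns (start, end) tuples in seconds; ignores padding, min durations and max_duration splitting.
    regions, is_active, start = [], scores[0] > onset, 0.0
    for i, y in enumerate(np.asarray(scores)[1:], start=1):
        t = i * frame_dur
        if is_active and y < offset:          # active -> inactive
            regions.append((start, t))
            is_active = False
        elif not is_active and y > onset:     # inactive -> active
            start = t
            is_active = True
    if is_active:                             # close the final region
        regions.append((start, len(scores) * frame_dur))
    return regions

# e.g. hysteresis_regions([0.1, 0.8, 0.9, 0.4, 0.2, 0.7, 0.9], 0.5)
# -> [(0.5, 2.0), (2.5, 3.5)]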
@@ -198,11 +186,11 @@ class Binarize:

class VoiceActivitySegmentation(VoiceActivityDetection):
    def __init__(
        self,
        segmentation: PipelineModel = "pyannote/segmentation",
        fscore: bool = False,
        use_auth_token: Union[Text, None] = None,
        **inference_kwargs,
    ):

        super().__init__(segmentation=segmentation, fscore=fscore, use_auth_token=use_auth_token, **inference_kwargs)
@@ -241,67 +229,35 @@ class VoiceActivitySegmentation(VoiceActivityDetection):
        return segmentations


-def merge_vad(vad_arr, pad_onset=0.0, pad_offset=0.0, min_duration_off=0.0, min_duration_on=0.0):
-
-    active = Annotation()
-    for k, vad_t in enumerate(vad_arr):
-        region = Segment(vad_t[0] - pad_onset, vad_t[1] + pad_offset)
-        active[region, k] = 1
-
-    if pad_offset > 0.0 or pad_onset > 0.0 or min_duration_off > 0.0:
-        active = active.support(collar=min_duration_off)
-
-    # remove tracks shorter than min_duration_on
-    if min_duration_on > 0:
-        for segment, track in list(active.itertracks()):
-            if segment.duration < min_duration_on:
-                del active[segment, track]
-
-    active = active.for_json()
-    active_segs = pd.DataFrame([x['segment'] for x in active['content']])
-    return active_segs
-
-def merge_chunks(segments, chunk_size):
-    """
-    Merge operation described in paper
-    """
-    curr_end = 0
-    merged_segments = []
-    seg_idxs = []
-    speaker_idxs = []
-
-    assert chunk_size > 0
-    binarize = Binarize(max_duration=chunk_size)
-    segments = binarize(segments)
-    segments_list = []
-    for speech_turn in segments.get_timeline():
-        segments_list.append(SegmentX(speech_turn.start, speech_turn.end, "UNKNOWN"))
-
-    if len(segments_list) == 0:
-        print("No active speech found in audio")
-        return []
-    # assert segments_list, "segments_list is empty."
-    # Make sure the starting point is the start of the segment.
-    curr_start = segments_list[0].start
-
-    for seg in segments_list:
-        if seg.end - curr_start > chunk_size and curr_end - curr_start > 0:
-            merged_segments.append({
-                "start": curr_start,
-                "end": curr_end,
-                "segments": seg_idxs,
-            })
-            curr_start = seg.start
-            seg_idxs = []
-            speaker_idxs = []
-        curr_end = seg.end
-        seg_idxs.append((seg.start, seg.end))
-        speaker_idxs.append(seg.speaker)
-    # add final
-    merged_segments.append({
-        "start": curr_start,
-        "end": curr_end,
-        "segments": seg_idxs,
-    })
-    return merged_segments
+class Pyannote(Vad):
+
+    def __init__(self, device, use_auth_token=None, model_fp=None, **kwargs):
+        print(">>Performing voice activity detection using Pyannote...")
+        super().__init__(kwargs['vad_onset'])
+        self.vad_pipeline = load_vad_model(device, use_auth_token=use_auth_token, model_fp=model_fp)
+
+    def __call__(self, audio: AudioFile, **kwargs):
+        return self.vad_pipeline(audio)
+
+    @staticmethod
+    def preprocess_audio(audio):
+        return torch.from_numpy(audio).unsqueeze(0)
+
+    @staticmethod
+    def merge_chunks(segments,
+                     chunk_size,
+                     onset: float = 0.5,
+                     offset: Optional[float] = None,
+                     ):
+        assert chunk_size > 0
+        binarize = Binarize(max_duration=chunk_size, onset=onset, offset=offset)
+        segments = binarize(segments)
+        segments_list = []
+        for speech_turn in segments.get_timeline():
+            segments_list.append(SegmentX(speech_turn.start, speech_turn.end, "UNKNOWN"))
+
+        if len(segments_list) == 0:
+            print("No active speech found in audio")
+            return []
+        assert segments_list, "segments_list is empty."
+        return Vad.merge_chunks(segments_list, chunk_size, onset, offset)
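Taken together, the new Pyannote wrapper hides model loading, waveform preprocessing and chunk merging behind the shared Vad interface. A hedged usage sketch, assuming the class above lives in whisperx.vads.pyannote and that load_audio / SAMPLE_RATE come from whisperx.audio; the parameter values are illustrative:

import torch

from whisperx.audio import load_audio, SAMPLE_RATE
from whisperx.vads.pyannote import Pyannote   # assumed module path for the class above

device = "cuda" if torch.cuda.is_available() else "cpu"
vad = Pyannote(device, use_auth_token=None, model_fp=None, vad_onset=0.500)

waveform = Pyannote.preprocess_audio(load_audio("audio.wav"))        # numpy -> (1, n) float tensor
speech = vad({"waveform": waveform, "sample_rate": SAMPLE_RATE})     # frame-level speech scores, binarized inside merge_chunks
chunks = Pyannote.merge_chunks(speech, chunk_size=30, onset=0.500)
# chunks: [{"start": ..., "end": ..., "segments": [(s0, e0), ...]}, ...]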
whisperx/vads/silero.py (new file, 66 lines)
@@ -0,0 +1,66 @@
from io import IOBase
from pathlib import Path
from typing import Mapping, Text
from typing import Optional
from typing import Union

import torch

from whisperx.diarize import Segment as SegmentX
from whisperx.vads.vad import Vad

AudioFile = Union[Text, Path, IOBase, Mapping]


class Silero(Vad):
    # check again default values
    def __init__(self, **kwargs):
        print(">>Performing voice activity detection using Silero...")
        super().__init__(kwargs['vad_onset'])

        self.vad_onset = kwargs['vad_onset']
        self.chunk_size = kwargs['chunk_size']
        self.vad_pipeline, vad_utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',
                                                      model='silero_vad',
                                                      force_reload=False,
                                                      onnx=False,
                                                      trust_repo=True)
        (self.get_speech_timestamps, _, self.read_audio, _, _) = vad_utils

    def __call__(self, audio: AudioFile, **kwargs):
        """use silero to get segments of speech"""
        # Only accept 16000 Hz for now.
        # Note: Silero models support both 8000 and 16000 Hz. Although other values are not directly supported,
        # multiples of 16000 (e.g. 32000 or 48000) are cast to 16000 inside of the JIT model!
        sample_rate = audio["sample_rate"]
        if sample_rate != 16000:
            raise ValueError("Only 16000Hz sample rate is allowed")

        timestamps = self.get_speech_timestamps(audio["waveform"],
                                                model=self.vad_pipeline,
                                                sampling_rate=sample_rate,
                                                max_speech_duration_s=self.chunk_size,
                                                threshold=self.vad_onset
                                                # min_silence_duration_ms = self.min_duration_off/1000
                                                # min_speech_duration_ms = self.min_duration_on/1000
                                                # ...
                                                # See silero documentation for full option list
                                                )
        return [SegmentX(i['start'] / sample_rate, i['end'] / sample_rate, "UNKNOWN") for i in timestamps]

    @staticmethod
    def preprocess_audio(audio):
        return audio

    @staticmethod
    def merge_chunks(segments_list,
                     chunk_size,
                     onset: float = 0.5,
                     offset: Optional[float] = None,
                     ):
        assert chunk_size > 0
        if len(segments_list) == 0:
            print("No active speech found in audio")
            return []
        assert segments_list, "segments_list is empty."
        return Vad.merge_chunks(segments_list, chunk_size, onset, offset)
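A matching usage sketch for the Silero backend; as above, the module path whisperx.vads.silero and the whisperx.audio helpers are assumptions, and the constructor values are only examples:

from whisperx.audio import load_audio, SAMPLE_RATE
from whisperx.vads.silero import Silero   # assumed module path for the new file above

vad = Silero(vad_onset=0.500, chunk_size=30)                         # illustrative values

audio = load_audio("audio.wav")                                      # 16 kHz mono float32 numpy array
speech = vad({"waveform": Silero.preprocess_audio(audio), "sample_rate": SAMPLE_RATE})
chunks = Silero.merge_chunks(speech, chunk_size=30)                  # same merged-chunk dicts as the Pyannote path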
whisperx/vads/vad.py (new file, 74 lines)
@@ -0,0 +1,74 @@
from typing import Optional

import pandas as pd
from pyannote.core import Annotation, Segment


class Vad:
    def __init__(self, vad_onset):
        if not (0 < vad_onset < 1):
            raise ValueError(
                "vad_onset is a decimal value between 0 and 1."
            )

    @staticmethod
    def preprocess_audio(audio):
        pass

    # keep merge_chunks as static so it can be also used by manually assigned vad_model (see 'load_model')
    @staticmethod
    def merge_chunks(segments,
                     chunk_size,
                     onset: float,
                     offset: Optional[float]):
        """
        Merge operation described in paper
        """
        curr_end = 0
        merged_segments = []
        seg_idxs: list[tuple] = []
        speaker_idxs: list[Optional[str]] = []

        curr_start = segments[0].start
        for seg in segments:
            if seg.end - curr_start > chunk_size and curr_end - curr_start > 0:
                merged_segments.append({
                    "start": curr_start,
                    "end": curr_end,
                    "segments": seg_idxs,
                })
                curr_start = seg.start
                seg_idxs = []
                speaker_idxs = []
            curr_end = seg.end
            seg_idxs.append((seg.start, seg.end))
            speaker_idxs.append(seg.speaker)
        # add final
        merged_segments.append({
            "start": curr_start,
            "end": curr_end,
            "segments": seg_idxs,
        })

        return merged_segments

    # Unused function
    @staticmethod
    def merge_vad(vad_arr, pad_onset=0.0, pad_offset=0.0, min_duration_off=0.0, min_duration_on=0.0):
        active = Annotation()
        for k, vad_t in enumerate(vad_arr):
            region = Segment(vad_t[0] - pad_onset, vad_t[1] + pad_offset)
            active[region, k] = 1

        if pad_offset > 0.0 or pad_onset > 0.0 or min_duration_off > 0.0:
            active = active.support(collar=min_duration_off)

        # remove tracks shorter than min_duration_on
        if min_duration_on > 0:
            for segment, track in list(active.itertracks()):
                if segment.duration < min_duration_on:
                    del active[segment, track]

        active = active.for_json()
        active_segs = pd.DataFrame([x['segment'] for x in active['content']])
        return active_segs
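Because Vad.merge_chunks only relies on objects exposing start, end and speaker attributes, it can be exercised without any VAD backend. A small self-contained check of the chunking behaviour, using a hypothetical namedtuple in place of whisperx.diarize.Segment:

from collections import namedtuple

from whisperx.vads.vad import Vad

# hypothetical stand-in for whisperx.diarize.Segment; only .start/.end/.speaker are used
Seg = namedtuple("Seg", ["start", "end", "speaker"])

speech = [Seg(0.0, 12.0, "UNKNOWN"), Seg(14.0, 26.0, "UNKNOWN"), Seg(40.0, 55.0, "UNKNOWN")]
chunks = Vad.merge_chunks(speech, chunk_size=30, onset=0.5, offset=None)
# -> [{"start": 0.0, "end": 26.0, "segments": [(0.0, 12.0), (14.0, 26.0)]},
#     {"start": 40.0, "end": 55.0, "segments": [(40.0, 55.0)]}]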