343 Commits

Author SHA1 Message Date
b94778fd60 Merge 399010fd12 into d700b56c9c 2025-06-13 16:23:14 +00:00
399010fd12 Revert "docs: add troubleshooting section for libcudnn dependencies in README"
This reverts commit 6fe0a8784a.

Revert the commit now that the issue is fixed.

Signed-off-by: CHEN, CHUN <jim60105@gmail.com>
2025-06-14 00:22:57 +08:00
d3dcb1175f chore: restrict onnxruntime to version 1.19 for python 3.9 compatibility
- Restrict the onnxruntime dependency to versions >=1.19 and <1.20.0 to avoid potential compatibility issues.

Signed-off-by: CHEN, CHUN <jim60105@gmail.com>
2025-06-14 00:21:53 +08:00
4f99f1f67c chore: restrict torch version to below 2.4 in dependencies
torch depends on libcudnn9 from version 2.4.0 onward.
If we restrict torch to <2.4.0, there is no need to manually install libcudnn8, and we also save about 1 GB of disk space.

- Update torch dependency to be below version 2.4.0 instead of at least 2.5.1
- Change torchaudio dependency to have no minimum version specified

Signed-off-by: CHEN, CHUN <jim60105@gmail.com>
2025-06-14 00:21:53 +08:00
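A quick way to confirm what this pin buys in a given environment is to check the installed torch build and the cuDNN version it bundles — a minimal sketch, assuming a CUDA-enabled torch wheel is installed:

```python
import torch

# torch wheels below 2.4.0 bundle cuDNN 8; from 2.4.0 onward they pull in libcudnn9.
print("torch:", torch.__version__)

# Version of the cuDNN build torch was compiled against (e.g. 8902 for 8.9.2),
# or None for CPU-only wheels.
print("cuDNN:", torch.backends.cudnn.version())
```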
d700b56c9c docs: add missing torch import to Python usage example in README 2025-06-08 03:34:49 -06:00
b343241253 feat: add diarize_model arg to CLI (#1101) 2025-05-31 13:32:31 +02:00
6fe0a8784a docs: add troubleshooting section for libcudnn dependencies in README 2025-05-31 05:20:06 -06:00
5012650d0f chore: update lockfile 2025-05-03 16:25:43 +02:00
108bd0c400 chore: add lockfile check step to CI workflows 2025-05-03 16:25:43 +02:00
b2d50a027b chore: bump version 2025-05-03 11:38:54 +02:00
36d552cad3 fix: remove DiarizationPipeline from public API 2025-05-03 09:25:59 +02:00
7d36b832f9 refactor: update CLI entry point 2025-05-03 09:25:59 +02:00
d2a493e910 refactor: implement lazy loading for module imports in whisperx 2025-05-03 09:25:59 +02:00
f5b40b5366 chore: update version to 3.3.3 in pyproject.toml and uv.lock 2025-05-01 11:08:54 +02:00
ac0c8bd79a feat: add version and Python version arguments to CLI 2025-05-01 11:08:54 +02:00
cd59f21d1a fix: downgrade ctranslate2 dependency version 2025-05-01 11:08:54 +02:00
0aed874589 Remove duplicated item
"lv": "latvian"
2025-04-12 11:08:15 +02:00
f10dbf6ab1 fix: update setuptools configuration to include package discovery for whisperx 2025-03-25 18:49:44 +01:00
a7564c2ad6 docs: update installation instructions 2025-03-25 17:02:41 +01:00
e7712f496e refactor: update import statements to use explicit module paths across multiple files 2025-03-25 16:24:21 +01:00
8e53866704 feat: pass hotwords argument to get_prompt (#1073)
Co-authored-by: Jade Moillic <jade.moillic@radiofrance.com>
2025-03-24 10:47:47 +01:00
3205436d58 Merge pull request #1002 from Barabazs/feat/uv 2025-03-23 12:59:46 +00:00
8c58c54635 Revert "feat: add Basque alignment model (#1074)" (#1077)
This reverts commit 0d9807adc5.
2025-03-05 15:19:23 +01:00
0d9807adc5 feat: add Basque alignment model (#1074) 2025-03-04 14:55:30 +01:00
4db839018c feat: add Tagalog (tl - Filipino) Phoneme-based ASR Model (#1067) 2025-02-23 09:59:48 +01:00
f8d11df727 docs: Update README example commands with generic audio path 2025-02-19 08:24:04 +01:00
d2f0e53f71 chore: remove tmp workflow 2025-02-12 08:23:23 +01:00
7489ebf876 feat: update build and release workflow to use uv for package installation and publishing 2025-02-12 08:23:23 +01:00
90256cc481 feat: use uv recommended setup 2025-02-12 08:23:23 +01:00
b41ebd4871 chore: add numpy to deps 2025-02-12 08:23:23 +01:00
63bc1903c1 feat: update Python compatibility workflow to use uv 2025-02-12 08:23:23 +01:00
272714e07d feat: use uv for building package 2025-02-12 08:23:23 +01:00
44e8bf5bb6 Merge pull request #1024 from philmcmahon/local-files-only-param
Add models_cache_only param
2025-01-27 14:26:19 +00:00
7b3c9ce629 Add models_cache_only param 2025-01-27 12:16:37 +00:00
36d2622e27 feat: add Latvian align model 2025-01-25 09:45:17 +01:00
8bfa12193b Merge pull request #1006 from tan90xx/main
chore: fix variable naming inconsistency from `segments` to `segments_list`
2025-01-20 14:05:34 +00:00
acbeba6057 Update silero.py 2025-01-20 20:01:21 +08:00
fca563a782 Update silero.py 2025-01-20 19:52:37 +08:00
2117909bf6 Merge pull request #1005 from tan90xx/main
chore: handle empty segments_list case in silero
2025-01-19 13:51:34 +00:00
de0d8fe313 chore: handle empty segments_list case in silero
prevent errors
2025-01-19 21:20:56 +08:00
355f8e06f7 Merge pull request #1003 from Barabazs/chore/remove-aws-url
chore: remove deprecated VAD_SEGMENTATION_URL
2025-01-17 15:28:24 +00:00
86e2b3ee74 chore: remove deprecated VAD_SEGMENTATION_URL 2025-01-17 09:12:05 +01:00
70c639cdb5 doc: refer to DEFAULT_ALIGN_MODELS_HF for other langs 2025-01-17 08:47:44 +01:00
235536e28d Update links to language models in README 2025-01-17 08:47:44 +01:00
12604a48ea Merge pull request #986 from bfs18/main
support timestamp for numbers.
2025-01-14 21:03:51 +00:00
ffbc73664c change the docstrings and comments to English 2025-01-13 22:56:48 +08:00
289eadfc76 fix a merge error. 2025-01-13 20:26:27 +08:00
22a93f2932 Merge branch 'main' into main 2025-01-13 19:34:21 +08:00
1027367b79 Merge pull request #995 from winking324/main
fix vad_method is none
2025-01-13 10:10:29 +00:00
5e54b872a9 Merge branch 'main' into main 2025-01-13 10:09:20 +00:00
6be02cccfa Update asr.py 2025-01-13 10:08:09 +00:00
2f93e029c7 feat: add SegmentData type for temporary processing during alignment 2025-01-13 10:45:50 +01:00
024bc8481b refactor: consolidate segment data handling in alignment function 2025-01-13 10:45:50 +01:00
f286e7f3de refactor: improve type hints and clean up imports 2025-01-13 10:45:50 +01:00
73e644559d refactor: remove namespace for consistency 2025-01-13 10:45:50 +01:00
1ec527375a fix vad_method is none 2025-01-13 13:53:35 +08:00
6695426a85 fix new vad paths 2025-01-12 12:50:15 +00:00
7a98456321 Merge pull request #888 from 3manifold/silero-vad
Silero VAD support
2025-01-11 17:15:27 +00:00
aaddb83aa5 switch from case to ifelse 2025-01-11 17:11:21 +00:00
c288f4812a Merge branch 'main' into silero-vad 2025-01-11 17:05:53 +00:00
4ebfb078c5 make no beam consistent with backtrack. 2025-01-09 23:13:11 +08:00
65b2332e13 make align a bit faster. 2025-01-09 19:33:26 +08:00
69281f3a29 support timestamps for numbers. 2025-01-09 15:23:40 +08:00
734084cdf6 bump: update version to 3.3.1 2025-01-08 18:00:34 +01:00
9395b0de18 Update tmp.yml 2025-01-08 17:59:28 +01:00
d57f9dc54c Create tmp.yml 2025-01-08 17:59:28 +01:00
a90bd1ce3f dataclasses replace method 2025-01-08 17:59:13 +01:00
79eb8fa53d Accept alternative VAD methods. Extend to use Silero VAD. 2025-01-06 13:41:46 +01:00
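A sketch of how the alternative VAD backend might be selected from Python; the `vad_method` name comes from the commits above, while the accepted value `"silero"` and the exact `load_model` signature are assumptions:

```python
import whisperx

device = "cuda"

# vad_method="silero" is assumed to select the Silero VAD backend instead of
# the default pyannote-based one.
model = whisperx.load_model("large-v2", device, compute_type="float16", vad_method="silero")

audio = whisperx.load_audio("audio.wav")
result = model.transcribe(audio, batch_size=16)
print(result["segments"])
```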
10b05fc43f refactor: replace NamedTuple with TranscriptionOptions in FasterWhisperPipeline 2025-01-05 18:56:19 +01:00
26d9b46888 feat: include speaker information in WriteTXT when diarizing 2025-01-05 18:21:34 +01:00
9a8967f27e refactor: add type hints 2025-01-05 11:48:24 +01:00
0f7f9f9f83 refactor: simplify imports for better type inference 2025-01-05 11:48:24 +01:00
c60594fa3b fix: update import statement for conjunctions module 2025-01-05 11:48:24 +01:00
4916192246 chore: bump whisperX to 3.3.0 2025-01-02 14:09:10 +01:00
cbdac53e87 chore: update ctranslate2 version to restrict <4.5.0 2025-01-02 14:09:10 +01:00
940a223219 fix: add UTF-8 encoding when reading README.md 2025-01-02 12:43:59 +01:00
a0eb31019b chore: update license in setup.py 2025-01-02 08:41:04 +01:00
b08ad67a72 docs: update installation instructions in README 2025-01-02 08:35:45 +01:00
c18f9f979b fix: update README image source and enhance setup.py for long description 2025-01-02 08:30:04 +01:00
948b3e368b chore: update gitignore 2025-01-01 18:47:40 +01:00
e9ac5b63bc chore: clean up MANIFEST.in by removing unnecessary asset inclusions 2025-01-01 18:47:40 +01:00
90b45459d9 feat: add build and release workflow 2025-01-01 18:47:40 +01:00
81c4af96a6 feat: add Python compatibility testing workflow
feat: restrict Python versions to 3.9 - 3.12
2025-01-01 15:29:03 +01:00
1c6d9327bc feat: use model_dir as cache_dir for wav2vec2 (#681) 2025-01-01 13:22:27 +01:00
0fdb55d317 feat: add local_files_only option on whisperx.load_model for offline mode (#867)
Adds the parameter local_files_only (default False for consistency) to whisperx.load_model so that the user can avoid downloading the file and return the path to the local cached file if it exists.

---------

Co-authored-by: Barabazs <31799121+Barabazs@users.noreply.github.com>
2025-01-01 13:16:45 +01:00
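Based on the commit description, a minimal usage sketch of the new parameter for offline use (the model name and device are placeholders):

```python
import whisperx

device = "cuda"

# local_files_only=True (default False) skips the download and resolves the
# model from the local cache if it is already present.
model = whisperx.load_model(
    "large-v2",
    device,
    compute_type="float16",
    local_files_only=True,
)
```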
51da22771f feat: add verbose output (#759)
---------

Co-authored-by: Abhishek Sharma <abhishek@zipteams.com>
Co-authored-by: Barabazs <31799121+Barabazs@users.noreply.github.com>
2025-01-01 13:07:52 +01:00
15ad5bf7df feat: update versions for pyannote:3.3.2 and faster-whisper:1.1.0 (#936)
* chore: bump faster-whisper to 1.1.0

* chore: bump pyannote to 3.3.2

* feat: add multilingual option in load_model function

---------

Co-authored-by: Barabazs <31799121+Barabazs@users.noreply.github.com>
2024-12-31 10:41:09 +01:00
7fdbd21fe3 feat: add support for faster-whisper 1.0.3 (#875)
---------

Co-authored-by: Barabazs <31799121+Barabazs@users.noreply.github.com>
2024-12-31 10:07:42 +01:00
3ff625c561 feat: update faster-whisper to 1.0.2 (#814)
* Update faster-whisper to 1.0.2 to enable model distil-large-v3

* feat: add hotwords option to default_asr_options

---------

Co-authored-by: Barabazs <31799121+Barabazs@users.noreply.github.com>
2024-12-31 09:41:22 +01:00
7307306a9d chore: bump version 2024-12-18 09:03:04 +01:00
3027cc32bc Update MANIFEST.in to include necessary files 2024-12-17 08:11:49 +01:00
9e4b1b4c49 fix: Force ctranslate to version 4.4.0
Force ctranslate to version 4.4.0 due to libcudnn_ops_infer.so.8:
https://github.com/SYSTRAN/faster-whisper/issues/729

Co-authored-by: Icaro Bombonato <ibombonatosites@gmail.com>
2024-12-16 13:30:08 +01:00
9b9e03c4cc feat: update Norwegian models (#687)
Updated Norwegian Bokmål and Norwegian Nynorsk models

Co-authored-by: Barabazs <31799121+Barabazs@users.noreply.github.com>
2024-12-16 11:08:48 +01:00
19eff8e79a feat: add new align models (#922)
Co-authored-by: Barabazs <31799121+Barabazs@users.noreply.github.com>
2024-12-16 11:06:43 +01:00
6f3bc5b7b8 Added Romanian phoneme-based ASR model (#791)
Co-authored-by: Barabazs <31799121+Barabazs@users.noreply.github.com>
2024-12-16 08:09:53 +01:00
9809336db6 Fix link in README.md 2024-12-16 08:04:59 +01:00
a898b3ba94 Remove typo in error message 2024-12-16 08:02:42 +01:00
c141074cbd Merge pull request #945 from m-bain/m-bain/local_model
move model to assets
2024-12-14 22:54:56 -06:00
a9e50ef0af move model to assets 2024-12-14 22:53:53 -06:00
161ae1f7ad Merge pull request #944 from m-bain/m-bain/local_model
local vad model
2024-12-14 22:34:38 -06:00
a83ddbdf9b local vad model 2024-12-14 22:16:43 -06:00
9e3a9e0e38 Merge pull request #852 from jan-panoch/main
Update alignment.py - added alignment for sk and sl languages
2024-08-20 00:05:56 +08:00
3f339f9515 Update alignment.py - remove commented-out alignment modules for hr language 2024-08-09 13:00:12 +02:00
9a9b6171e6 Update alignment.py - trying another hr alignment 2024-08-08 08:37:55 +02:00
59b4d88d1d Update alignment.py - trying another hr alignment file 2024-08-08 08:29:11 +02:00
6f70aa6beb Update alignment.py - added croatian (hr) language 2024-08-08 08:10:55 +02:00
912920c591 Update alignment.py - added alignment for sk and sl languages 2024-08-07 10:05:17 +02:00
58f00339af BSD 2 LICENSE 2024-07-11 13:01:15 +04:00
f2da2f858e Update README.md 2024-03-20 15:47:18 +00:00
78dcfaab51 upgrade faster-whisper 2024-02-23 09:30:12 +00:00
d6562c26da Merge pull request #716 from cococig/fix/faster-whisper-from-pypi
fix: update faster-whisper dependencies
2024-02-22 16:51:06 +00:00
c313f4dd5c fix: update faster-whisper dependencies 2024-02-23 01:42:22 +09:00
bbaa2f0d1a update kwargs 2024-02-22 15:59:14 +00:00
e906be9688 Merge pull request #703 from victor-upmeet/large-v3-demo
Add Replicate large-v3 demo
2024-02-18 15:43:51 +00:00
fbbd07bece Merge pull request #669 from KossaiSbai/ks/supress-numeral-symbol-tokens-message
Get rid of numeral_symbol_tokens variable in printed message
2024-02-18 15:43:23 +00:00
d8c9196346 Add Replicate large-v3 demo 2024-02-18 12:17:11 +01:00
2686f74bc9 Get rid of numeral_symbol_tokens variable in printed message 2024-01-19 22:25:21 +00:00
8227807fa9 Delete build/lib/whisperx directory 2024-01-02 19:36:36 -07:00
59962a70be Merge pull request #646 from santialferez/diarize-patch-1
Update pyannote to v3.1.1 to fix a diarization problem (and diarize.py)
2024-01-03 02:35:53 +00:00
06e30b2a25 Merge pull request #654 from Swami-Abhinav/provide-custom-load-vad
Added option to load Custom VAD model to load model method
2024-01-01 17:38:30 +00:00
6bb2f1cd48 Added Vad custom option 2024-01-01 14:56:51 +05:30
f8cc46c6f7 Merge pull request #648 from canoalberto/main
Fixes --model_dir path
2023-12-28 21:23:42 +00:00
942c336b8f Fixes --model_dir path 2023-12-27 14:03:54 -05:00
8ae6416594 update setup.py to install pyannote.audio==3.1.1, update diarize.py to include num_speakers; to fix Issue #592 2023-12-26 13:01:49 +01:00
8540ff5985 Merge pull request #636 from NbAiLab/peregilk-patch-1
Adding Norwegian Bokmål and Norwegian Nynorsk
2023-12-19 15:55:20 +00:00
5dfbfcbdc0 Adding Norwegian Bokmål and Norwegian Nynorsk
Adding Wav2Vec2 models for Norwegian Bokmål and Norwegian Nynorsk. The models have been tested together with WhisperX and work great. For Bokmål I have added the 1B model, even if I see fairly little difference between that and the 300M model. For Norwegian Nynorsk only a 300M model exists. The quality of the Wav2Vec models is also reported here: https://arxiv.org/abs/2307.01672
2023-12-19 08:48:21 +01:00
1c7b1a87da Merge pull request #630 from mlopsengr/patch-1
Update README.md
2023-12-17 15:53:44 +00:00
9f23739f90 Update README.md
Demonstrates use of argument to save model to local path.
2023-12-15 13:46:32 +00:00
19ab91c5a6 Merge pull request #618 from gillens/main
Update README to correct speaker diarization version link
2023-12-10 17:35:42 -06:00
089cd5ab21 Merge pull request #585 from kurianbenoy/ml-asr
Add alignment model for Malayalam
2023-12-10 17:35:14 -06:00
2b7ab95ad6 Update README to Correct Speaker Diarization Version Link
Currently errors if the user only accepts the terms for the version 3.0 model linked in the README. Version 3.1 was introduced in pull request #586
2023-12-07 12:48:21 -08:00
4553e0d4ed Merge pull request #617 from MahmoudAshraf97/main 2023-12-04 16:15:48 +00:00
f865dfe710 fix typo 2023-12-04 17:38:50 +03:00
4acbdd75be add "yue" to supported languages that was added along with Large-V3 2023-12-04 17:27:54 +03:00
e9c507ce5d Merge pull request #605 from M0HID/patch-1
fix link
2023-11-28 11:56:29 +00:00
a5dca2cc65 Merge pull request #603 from spbisc97/patch-1
pip compliance for git+ installs
2023-11-28 01:24:35 +00:00
8a8eeb33ee Update README.md 2023-11-27 17:15:28 +00:00
b4d7b1a422 pip compliance for git+ installs
Minimal change to let pip install requirements
2023-11-26 18:37:04 +01:00
5a16e59217 Merge pull request #599 from MahmoudAshraf97/main
support for `large-v3`
2023-11-26 12:34:16 +00:00
b4e4143e3b install faster-whisper using git as pypi is not updated anymore 2023-11-25 17:42:36 +00:00
4b05198eed bump faster-whisper to 0.10 2023-11-25 12:11:08 +00:00
71a5281bde support for large-v3 2023-11-25 12:09:00 +00:00
d97cdb7bcf Merge pull request #586 from remic33/main 2023-11-17 10:48:57 +00:00
20161935a1 feat: pass model to 3.1 in code 2023-11-17 11:12:16 +01:00
1d7f8ccbf1 feat: get rid of pyannote versioning and go to 3.1 2023-11-17 11:03:23 +01:00
5756b0fb13 Update alignment.py 2023-11-17 05:21:23 +05:30
aaaa3de810 Update alignment.py 2023-11-17 05:18:19 +05:30
ba30365344 Merge pull request #584 from DougTrajano/patch-1
Move load_model after WhisperModel
2023-11-16 12:09:21 +00:00
bd3aa03b6f Move load_model after WhisperModel 2023-11-16 08:59:28 -03:00
f5c544ff90 Merge pull request #581 from davidmartinrius/catalan_align_model
Add align model for catalan language.
2023-11-16 10:54:24 +00:00
7c2a9a8b7b Merge pull request #580 from kaka1909/main
Update asr.py and make the model parameter be used
2023-11-16 10:54:02 +00:00
9f41c49fe5 Add align model for catalan language. 2023-11-16 11:43:36 +01:00
48d651e5ea Update asr.py and make the model parameter be used 2023-11-16 15:29:24 +08:00
4ece2369d7 Merge pull request #556 from sorgfresser/remove-space-segment-align
no align based on space
2023-11-11 02:03:56 +00:00
52fbe5c26f Merge pull request #570 from hidenori-endo/main
Drop ffmpeg-python dependency and call ffmpeg directly.
2023-11-09 18:39:53 +00:00
6703d2774b Drop ffmpeg-python dependency 2023-11-10 03:26:47 +09:00
a2af569838 Merge pull request #554 from sorgfresser/fix-binarize-unbound
fix unboundlocalerror
2023-11-07 10:54:24 +00:00
0c7f32f55c no align based on space 2023-11-03 19:47:00 +01:00
6936dd6991 default t 2023-11-03 18:50:15 +01:00
6b1100a919 Merge pull request #549 from amolinasalazar/minor_fixes
Minor fixes for word options and subtitles
2023-10-31 12:26:47 -07:00
d4a600b568 REMOVE duplicated code 2023-10-31 18:55:50 +01:00
afd5ef1d58 FIX warnings for word options 2023-10-31 18:55:35 +01:00
dbeb8617f2 Merge pull request #521 from kaihe-stori/update-readme
Add a special note about Speaker-Diarization-3.0 in readme
2023-10-25 11:18:47 -07:00
c6fe379d9e Merge pull request #517 from jkukul/support-language-names-as-parameters
Support language names in `--language` parameter.
2023-10-25 11:16:30 -07:00
e9a6385d3c Merge pull request #541 from justinwlin/main
Update setup.py to download pyannote depending on platform
2023-10-25 11:14:11 -07:00
b522133340 Update setup.py to be adaptive to platform 2023-10-24 18:42:14 -04:00
49e0130e4e Merge pull request #531 from accessful-ai/main 2023-10-17 06:54:22 -07:00
d4ac9531d9 Update setup.py 2023-10-17 15:23:38 +02:00
66808f6147 Merge pull request #529 from MahmoudAshraf97/main 2023-10-16 10:53:18 -07:00
b69956d725 . 2023-10-16 20:43:37 +03:00
a150df4310 Merge pull request #527 from jkukul/pass-beam-size-to-fast-whisper 2023-10-15 07:15:13 -07:00
02c0323777 fix 2023-10-15 16:25:15 +03:00
14a7cab8eb Pass patience and beam_size to faster-whisper. 2023-10-14 13:51:29 +02:00
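A sketch of how these decoding options might be supplied from the Python API; only the option names `beam_size` and `patience` come from the commit, and the `asr_options` dict shown here is an assumption about how whisperx forwards them to faster-whisper:

```python
import whisperx

device = "cuda"

# beam_size and patience are faster-whisper decoding options; passing them via
# an asr_options dict is an assumed, illustrative whisperx-side mechanism.
model = whisperx.load_model(
    "large-v2",
    device,
    compute_type="float16",
    asr_options={"beam_size": 5, "patience": 1.0},
)
```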
acf31b754f update readme 2023-10-11 22:56:38 -04:00
4cdce3b927 Merge pull request #518 from characat0/main
fix(diarize): key error on empty track
2023-10-10 12:54:35 -07:00
a5356509b6 fix(diarize): key error on empty track 2023-10-10 14:50:41 -05:00
1001a055db Support language names in --language. 2023-10-10 13:55:47 +02:00
051047bb25 Merge pull request #510 from MahmoudAshraf97/main
fix minimum input length for torch wav2vec2 models
2023-10-05 15:31:08 -07:00
c1b821a08d fix list markdown 2023-10-05 15:14:29 -07:00
78e20a16a8 update links 2023-10-05 15:14:03 -07:00
be07c13f75 read does actually work... 2023-10-05 14:48:39 -07:00
8049dba2f7 fix minimum input length for torch wav2vec2 models 2023-10-06 00:41:23 +03:00
d077abdbdf Merge pull request #509 from valentt/patch-1
Update README.md
2023-10-05 14:13:20 -07:00
84423ca517 Update README.md
Added info that the Hugging Face token has to be a write token because a read token doesn't work.
2023-10-05 19:14:28 +02:00
a22b8b009b Merge pull request #507 from compasspathways/fix/pass-vad-options
Fix: Allow vad options to be configurable by passing to FasterWhisperPipeline and merge_chunks.
2023-10-05 07:48:19 -07:00
79801167ac Fix: Allow vad options to be configurable by correctly passing down to FasterWhisperPipeline. 2023-10-05 10:06:34 -04:00
07fafa37b3 Merge pull request #494 from mvoggu/main
fix: ZeroDivisionError when --print_progress True
2023-09-27 07:46:06 -07:00
a0b6459c8b fix: ZeroDivisionError when --print_progress True 2023-09-27 20:10:43 +05:30
2a11ce3ef0 Merge pull request #487 from piuy11/main
Update alignment.py
2023-09-26 14:17:46 -07:00
18abcf46ee Merge pull request #492 from remic33/pyannote3
Pyannote3
2023-09-26 14:16:57 -07:00
652aa24919 change pyannote version 2023-09-26 23:04:28 +02:00
b17908473d correct 3.0 pyannote weights 2023-09-26 17:18:20 +02:00
f137f31de6 Update alignment.py 2023-09-25 15:33:06 +09:00
e94b904308 Merge pull request #474 from sorgfresser/pin-faster-whisper 2023-09-19 16:53:42 -07:00
ffd6167b26 Merge pull request #473 from sorgfresser/fix-faster-whisper-threads 2023-09-19 16:53:34 -07:00
4c7ce14fed pin faster whisper 2023-09-14 13:19:11 +02:00
0ae0d49d1d add faster whisper threading 2023-09-14 11:47:51 +02:00
b1a98b78c9 Merge pull request #472 from darwintree/main
chore(writer): improve text display(ja etc) in json file
2023-09-10 08:37:39 -06:00
c6d9e6cb67 chore(writer): improve text display(ja etc) in json file 2023-09-10 22:02:47 +08:00
31f5233949 Merge pull request #459 from awerks/main
A solution to long subtitles and words without timestamps
2023-09-06 10:09:27 -06:00
2ca99ce909 A solution to long subtitles
Example usage:
subtitles_processor = SubtitlesProcessor(output["segments"], detected_language, max_line_length = 50, min_char_length_splitter = 35)
subtitles_processor.save("subtitles.srt", advanced_splitting = True)
2023-09-04 21:49:34 +02:00
15d9e08d3e Merge pull request #458 from remic33/correct_default_asr_options
fix: correct default_asr_options with new options (patch 0.8)
2023-09-04 09:22:16 -06:00
15451d0f1c fix: correct default_asr_options with new options (patch 0.8) 2023-09-04 17:08:19 +02:00
8c4a21b66d Merge pull request #440 from jim60105/main
chore(writer): Join words without spaces for ja, zh
2023-08-29 11:22:30 -06:00
5223de2a41 fix: UnboundLocalError: local variable 'align_language' referenced before assignment 2023-08-30 01:11:09 +08:00
f505702dc7 chore(writer): Join words without spaces for ja, zh
fix #248, fix #310
2023-08-30 01:11:09 +08:00
adf455a97c Merge pull request #445 from jim60105/add-merge-chunk-size-as-argument
feat: Add merge chunks chunk_size as arguments.
2023-08-29 10:05:14 -06:00
9647f60fca Merge branch 'main' into add-merge-chunk-size-as-argument 2023-08-29 10:05:05 -06:00
a8bfac6bef Merge pull request #427 from awerks/main
Update alignment.py
2023-08-29 10:03:46 -06:00
6d414e20e2 Merge pull request #438 from invisprints/fix-speaker-missing
fix missing speaker prefix
2023-08-29 10:03:06 -06:00
3c7b03935b Merge pull request #430 from dotgrid/dotgrid-docs-patch
Document --compute_type command line option
2023-08-29 10:02:51 -06:00
eb771cf56d feat: Add merge chunks chunk_size as arguments.
Suggested in https://github.com/m-bain/whisperX/issues/200#issuecomment-1666507780
2023-08-29 23:09:02 +08:00
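A sketch of the new argument in use, assuming `chunk_size` (seconds of speech merged into one chunk) is exposed on the batched `transcribe` call as the commit title suggests:

```python
import whisperx

device = "cuda"
model = whisperx.load_model("large-v2", device, compute_type="float16")
audio = whisperx.load_audio("audio.wav")

# chunk_size controls how much audio is merged into each chunk before batched
# inference; smaller values trade throughput for finer-grained segments.
result = model.transcribe(audio, batch_size=16, chunk_size=10)
```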
cc81ab7db7 fix missing prefix
Fixed the missing speaker prefix when enabling --highlight_words
2023-08-25 12:08:16 +08:00
ef965a03ed Merge pull request #431 from CaRniFeXeR/main
adds link to whisperX medium on replicate.com
2023-08-21 17:25:15 +01:00
6f2ff16aad Merge pull request #1 from CaRniFeXeR/CaRniFeXeR-replicate-models
adds link to whisperX medium on replicate and updates replicate badges…
2023-08-21 08:20:25 +08:00
81b12af321 adds link to whisperX medium on replicate and updates replicate badges in README.md 2023-08-21 08:16:46 +08:00
c1197c490e Document --compute_type command line option 2023-08-19 08:19:49 +01:00
4e28492dbd Update alignment.py 2023-08-17 14:57:53 +02:00
6cb7267dc2 Update alignment.py 2023-08-17 14:56:54 +02:00
abbb66b58e Update alignment.py 2023-08-17 14:53:53 +02:00
ea7bb91a56 Update asr.py 2023-08-17 14:49:57 +02:00
d2d840f06c Update utils.py 2023-08-17 14:45:23 +02:00
0a1137e41c Merge pull request #429 from sorgfresser/no-segments-writer
fix writer fail on segments 0
2023-08-17 13:20:38 +01:00
0767597bff fix writer fail on segments 0 2023-08-17 14:18:16 +02:00
cb3ed4ab9d Update transcribe.py 2023-08-16 16:22:29 +02:00
65688208c9 Update alignment.py 2023-08-16 16:18:00 +02:00
72685d0398 Update asr.py 2023-08-16 16:15:24 +02:00
1bb4839b0f Update alignment.py 2023-08-16 16:13:28 +02:00
4acb5b3abc Update asr.py 2023-08-16 16:11:46 +02:00
14e593f60b Update alignment.py 2023-08-16 16:08:25 +02:00
66da4b3eb7 Merge pull request #418 from Ayushi-Desynova/main-1
Update alignment.py
2023-08-10 12:14:08 +01:00
18d5fdc995 Add telugu language to alignment.py 2023-08-10 12:13:52 +01:00
423667f00b Update alignment.py 2023-08-09 17:08:56 +05:30
1b092de19a Merge pull request #395 from Joemgu7/main
Fix repeat transcription on different languages and proper suppress_numerals use
2023-08-02 13:44:27 +01:00
69a52b00c7 Merge pull request #400 from davidas1/fast-diarize
make diarization faster
2023-08-02 13:43:20 +01:00
9e3145cead more 2023-08-02 10:36:56 +03:00
577db33430 more 2023-08-02 10:35:20 +03:00
da6ed83dc9 more 2023-08-02 10:34:42 +03:00
7eb9692cb9 more 2023-08-02 10:32:02 +03:00
8de0e2af51 make diarization faster 2023-08-02 10:11:43 +03:00
225f6b4d69 fix suppress_numerals 2023-07-29 19:34:51 +02:00
864976af23 fix issue by resetting tokenizer 2023-07-29 18:56:33 +02:00
9d736dca1c add some warning if languages do not match 2023-07-29 18:20:59 +02:00
d87f6268d0 fix preset language 2023-07-29 18:13:36 +02:00
d80b98601b Merge pull request #255 from tijszwinkels/cuda-11.8
Suggest using pytorch-cuda 11.8 instead of 11.7
2023-07-25 00:29:08 +01:00
aa37509362 Merge branch 'main' into cuda-11.8 2023-07-25 00:28:53 +01:00
15b4c558c2 Merge pull request #352 from daanelson/replicate-demo
adding link to Replicate demo
2023-07-24 10:48:24 +01:00
54504a2be8 Merge pull request #374 from abCods/main
Add Urdu model support for alignment
2023-07-24 10:47:52 +01:00
8c0fee90d3 Update alignment.py 2023-07-24 10:47:41 +01:00
016f0293cd Merge pull request #378 from baer/patch-1
Remove torchvision from README
2023-07-24 10:47:14 +01:00
44daf50501 Merge pull request #382 from mabergerx/patch-1
Update transcribe.py -> small change in `batch_size` description
2023-07-24 10:46:55 +01:00
48e7caad77 Update transcribe.py -> small change in batch_size description
Changed the description of the `batch_size` parameter.
2023-07-24 11:45:38 +02:00
8673064658 Remove torchvision from README 2023-07-20 17:02:34 -07:00
e6ecbaa68f Remove spacing 2023-07-20 03:20:47 +05:00
e92325b7eb Remove the fix 2023-07-20 03:19:37 +05:00
eb712f3999 Rectify reference to the word 2023-07-20 02:54:06 +05:00
30eff5a01f Replace double quotes to single for JSON parsing 2023-07-20 02:32:37 +05:00
734ecc2844 Add Urdu model support for alignment 2023-07-17 19:29:41 +05:00
512ab1acf9 adding Replicate demo 2023-06-30 18:22:10 -07:00
befe2b242e torch 2+ 2023-06-07 22:43:29 +01:00
f9c5ff9f08 Merge pull request #309 from Ca-ressemble-a-du-fake/patch-1
Add Audacity export
2023-06-07 11:50:05 +01:00
d39c1b2319 add "aud" to output_format 2023-06-07 11:48:49 +01:00
b13778fefd make aud optional 2023-06-07 11:47:49 +01:00
076ff96eb2 Add Audacity export
This exports the transcript to a text file that can be directly imported into Audacity as a label file. This is useful for quickly checking the transcript-audio alignment.
2023-06-07 05:49:49 +02:00
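Audacity label files are plain tab-separated text (start, end, label per line), so an "aud" writer boils down to something like the following sketch; the field layout comes from the Audacity label format, the segment structure follows the whisperx output shown elsewhere on this page, and the example data is hypothetical:

```python
def write_audacity_labels(segments, path):
    """Write transcript segments as an Audacity label track (tab-separated)."""
    with open(path, "w", encoding="utf-8") as f:
        for seg in segments:
            # One label per line: start<TAB>end<TAB>text, times in seconds.
            f.write(f"{seg['start']:.3f}\t{seg['end']:.3f}\t{seg['text'].strip()}\n")

# Example with made-up data:
write_audacity_labels(
    [{"start": 0.0, "end": 2.5, "text": " Hello world."}],
    "transcript.aud",
)
```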
0c84c26d92 Merge pull request #303 from m-bain/v3
Suppress numerals
2023-06-05 15:46:26 +01:00
d7f1d16f19 suppress numerals change logic 2023-06-05 15:44:17 +01:00
74a00eecd7 suppress numerals fix 2023-06-05 15:33:04 +01:00
b026407fd9 Merge branch 'v3' of https://github.com/m-bain/whisperX into v3
Conflicts:
	whisperx/asr.py
2023-06-05 15:30:02 +01:00
a323cff654 --suppress_numerals option, ensures non-numerical words, for wav2vec2 alignment 2023-06-05 15:27:42 +01:00
93ed6cfa93 interspeech 2023-06-01 16:54:16 +01:00
9797a67391 Merge pull request #294 from SohaibAnwaar/fix/typehint-bug-fix
fix: Bug in type hinting
2023-05-30 11:13:22 +01:00
5a4382ae4d fix: Bug in type hinting 2023-05-30 15:11:07 +05:00
ec6a110cdf Merge pull request #290 from m-bain/main
push contributions from main
2023-05-29 12:55:24 +01:00
8d8c027a92 Merge pull request #278 from Mr-Turtleeeee/add_align_for_vi
Add wav2vec model for Vietnamese
2023-05-29 12:54:37 +01:00
4cbd3030cc no sentence split on mr. mrs. dr... 2023-05-29 12:48:14 +01:00
1c528d1a3c Merge pull request #284 from prameshbajra/main 2023-05-27 11:19:13 +01:00
c65e7ba9b4 Merge pull request #280 from Thebys/patch-1 2023-05-27 11:18:27 +01:00
5a47f458ac Added download path parameter. 2023-05-27 11:38:54 +02:00
f1032bb40a VAD unequal stack size, remove debug change 2023-05-26 20:39:19 +01:00
bc8a03881a Merge pull request #281 from m-bain/v3
fix Unequal Stack Size VAD error
2023-05-26 20:37:57 +01:00
42b4909bc0 fix Unequal Stack Size VAD error 2023-05-26 20:36:03 +01:00
bb15d6b68e Add Czech alignment model
This PR adds the following Czech alignment model: https://huggingface.co/comodoro/wav2vec2-xls-r-300m-cs-250.

I have successfully tested this with several Czech audio recordings with length of up to 3 hours, and the results are satisfactory.

However, I have received the following warnings and I am not sure how relevant they are:
```
Lightning automatically upgraded your loaded checkpoint from v1.5.4 to v2.0.2. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint --file C:\Users\Thebys\.cache\torch\whisperx-vad-segmentation.bin`
Model was trained with pyannote.audio 0.0.1, yours is 2.1.1. Bad things might happen unless you revert pyannote.audio to 0.x.
Model was trained with torch 1.10.0+cu102, yours is 2.0.0. Bad things might happen unless you revert torch to 1.x.
```
2023-05-26 21:17:01 +02:00
23d405e1cf Merge branch 'main' into add_align_for_vi 2023-05-26 17:14:09 +01:00
17e2f7f859 Merge pull request #277 from Boulaouaney/add-Korean-alignment-model
added Korean wav2vec2 model
2023-05-26 17:12:47 +01:00
1d9d630fb9 added Korean wav2vec2 model 2023-05-26 20:33:16 +09:00
9c042c2d28 Add wav2vec model for Vietnamese 2023-05-26 16:46:55 +07:00
a23f2aa3f7 Merge pull request #269 from sorgfresser/transcribe_keywords
Add transcribe keywords
2023-05-21 12:08:44 +01:00
7c5468116f Merge branch 'm-bain:main' into transcribe_keywords 2023-05-20 16:03:40 +02:00
a1c705b3a7 fix tokenizer is None 2023-05-20 15:52:45 +02:00
29a5e0b236 Merge pull request #266 from sorgfresser/main
Add device_index option
2023-05-20 14:45:34 +01:00
715435db42 add tokenizer is None case 2023-05-20 15:42:21 +02:00
1fc965bc1a add task, language keyword to transcribe 2023-05-20 15:30:25 +02:00
74b98ebfaa ensure device_index not None 2023-05-20 13:11:30 +02:00
53396adb21 add device_index 2023-05-20 13:02:46 +02:00
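A sketch of the new option, assuming `device_index` on `load_model` selects which GPU the CTranslate2 backend runs on (the parameter name comes from the commits; the rest is illustrative):

```python
import whisperx

# device_index picks the GPU to load the model on, e.g. the second CUDA device.
model = whisperx.load_model("large-v2", device="cuda", device_index=1, compute_type="float16")
```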
63fb5fc46f Suggest using pytorch-cuda 11.8 instead of 11.7
This prevents CuFFT errors on newer cards such as the RTX 4090 and RTX 6000 Ada.

fixes #254
2023-05-16 12:07:09 +02:00
d8a2b4ffc9 Merge pull request #246 from m-bain/v3
V3
2023-05-13 12:18:09 +01:00
9ffb7e7a23 Merge branch 'v3' of https://github.com/m-bain/whisperX into v3
Conflicts:
	setup.py
2023-05-13 12:16:33 +01:00
fd8f1003cf add translate, fix word_timestamp error 2023-05-13 12:14:06 +01:00
46b416296f Merge pull request #123 from koldbrandt/danish_alignment
Danish alignment model
2023-05-09 23:10:24 +01:00
7642390d0a Merge branch 'main' into danish_alignment 2023-05-09 23:10:13 +01:00
8b05ad4dae Merge pull request #235 from sorgfresser/main
Add custom typing for results
2023-05-09 23:05:02 +01:00
5421f1d7ca remove v3 tag on pip install 2023-05-09 13:42:50 +01:00
91e959ec4f Merge branch 'm-bain:main' into main 2023-05-08 20:46:25 +02:00
eabf35dff0 Custom result types 2023-05-08 20:45:34 +02:00
4919ad21fc Merge pull request #233 from sorgfresser/main
Fix tuple unpacking
2023-05-08 19:05:47 +01:00
b50aafb17b Fix tuple unpacking 2023-05-08 20:03:42 +02:00
2efa136114 update python usage example 2023-05-08 17:20:38 +01:00
0b839f3f01 Update README.md 2023-05-07 20:36:08 +01:00
1caddfb564 Merge pull request #225 from m-bain/v3
V3
2023-05-07 20:31:16 +01:00
7ad554c64f Merge branch 'main' into v3 2023-05-07 20:30:57 +01:00
4603f010a5 update readme, setup, add option to return char_timestamps 2023-05-07 20:28:33 +01:00
24008aa1ed fix long segments, break into sentences using nltk, improve align logic, improve diarize (sentence-based) 2023-05-07 15:32:58 +01:00
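The sentence splitting referenced here can be reproduced with nltk's `sent_tokenize` — a minimal illustration with made-up example text:

```python
import nltk
from nltk.tokenize import sent_tokenize

nltk.download("punkt", quiet=True)  # one-time tokenizer download (newer nltk uses "punkt_tab")

text = "This is the first sentence. And here is the second one."
print(sent_tokenize(text))
# ['This is the first sentence.', 'And here is the second one.']
```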
07361ba1d7 add device to dia pipeline @sorgfresser 2023-05-05 11:53:51 +01:00
4e2ac4e4e9 torch2.0, remove compile for now, round to times to 3 decimal 2023-05-04 20:38:13 +01:00
d2116b98ca Merge pull request #210 from sorgfresser/v3
Update pyannote and torch version
2023-05-04 20:32:06 +01:00
d8f0ef4a19 Set diarization device manually 2023-05-04 16:25:34 +02:00
1b62c61c71 Merge pull request #216 from aramlang/blank_id-fix
Enable Hebrew support
2023-05-04 01:13:23 +01:00
2d59eb9726 Add torch compile to log mel spectrogram 2023-05-03 23:17:44 +02:00
cb53661070 Enable Hebrew support 2023-05-03 11:26:12 -05:00
2a6830492c Fix pyannote to specific commit 2023-05-02 20:25:56 +02:00
da3aabe181 Merge branch 'm-bain:v3' into v3 2023-05-02 18:55:43 +02:00
067189248f Use pyannote develop branch and torch version 2 2023-05-02 18:44:43 +02:00
b666523004 add v3 pre-release comment, and v4 progress update 2023-05-02 15:10:40 +01:00
69e038cbc4 Merge pull request #209 from SohaibAnwaar/feat-dockerfile
feat: adding the docker file
2023-05-02 14:55:30 +01:00
9fb51412c0 Merge pull request #208 from arnavmehta7/patch-1 2023-05-02 10:55:13 +01:00
a693a779fa feat: adding the docker file 2023-05-02 13:28:20 +05:00
64ca208cc8 Fixed a bug where the word_start variable was not initialized. 2023-05-02 13:13:02 +05:30
5becc99e56 Version bump pyannote, pytorch 2023-05-01 13:47:41 +02:00
e24ca9e0a2 Merge pull request #205 from prashanthellina/v3-fix-diarization 2023-04-30 21:08:45 +01:00
601c91140f references #202, attempt to fix speaker diarization failing in v3 2023-04-30 17:33:24 +00:00
31a9ec7466 Merge pull request #204 from sorgfresser/v3 2023-04-30 18:29:46 +01:00
b9c8c5072b Pad language detection if audio is too short 2023-04-30 18:34:18 +02:00
a903e57cf1 Merge pull request #199 from thomasmol/v3 2023-04-29 23:35:42 +01:00
cb176a186e added num_workers to fix pickling error 2023-04-29 19:51:05 +02:00
5b85c5433f Update setup.py 2023-04-28 16:47:04 +01:00
cc7e168d2b add checkout command 2023-04-25 12:14:23 +01:00
db97f29678 update pip install 2023-04-25 11:19:23 +01:00
25be8210e5 add v3 tag for install 2023-04-25 10:07:34 +01:00
0efad26066 pass compute_type 2023-04-24 21:26:44 +01:00
2a29f0ec6a add compute types 2023-04-24 21:24:22 +01:00
558d980535 v3 init 2023-04-24 21:08:43 +01:00
d31f6e0b8a Merge branch 'm-bain:main' into danish_alignment 2023-03-06 10:52:47 +01:00
c8404d9805 added a danish alignment model 2023-03-04 13:20:40 +01:00
27 changed files with 5716 additions and 1557 deletions

.github/workflows/build-and-release.yml

@@ -0,0 +1,34 @@
name: Build and release
on:
  release:
    types: [published]
jobs:
  build:
    runs-on: ubuntu-latest
    steps:
      - name: Checkout
        uses: actions/checkout@v4
      - name: Install uv
        uses: astral-sh/setup-uv@v5
        with:
          version: "0.5.14"
          python-version: "3.9"
      - name: Check if lockfile is up to date
        run: uv lock --check
      - name: Build package
        run: uv build
      - name: Release to Github
        uses: softprops/action-gh-release@v2
        with:
          files: dist/*.whl
      - name: Publish package to PyPi
        run: uv publish
        env:
          UV_PUBLISH_TOKEN: ${{ secrets.PYPI_API_TOKEN }}


@@ -0,0 +1,34 @@
name: Python Compatibility Test
on:
  push:
    branches: [main]
  pull_request:
    branches: [main]
  workflow_dispatch: # Allows manual triggering from GitHub UI
jobs:
  test:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: ["3.9", "3.10", "3.11", "3.12"]
    steps:
      - uses: actions/checkout@v4
      - name: Install uv
        uses: astral-sh/setup-uv@v5
        with:
          version: "0.5.14"
          python-version: ${{ matrix.python-version }}
      - name: Check if lockfile is up to date
        run: uv lock --check
      - name: Install the project
        run: uv sync --all-extras
      - name: Test import
        run: |
          uv run python -c "import whisperx; print('Successfully imported whisperx')"

.gitignore

@@ -1,2 +1,171 @@
whisperx.egg-info/
**/__pycache__/
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# UV
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
#uv.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
.pdm.toml
.pdm-python
.pdm-build/
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
# PyPI configuration file
.pypirc

LICENSE

@@ -1,27 +1,24 @@
Copyright (c) 2022, Max Bain
All rights reserved.
BSD 2-Clause License
Copyright (c) 2024, Max Bain
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
3. All advertising materials mentioning features or use of this software
must display the following acknowledgement:
This product includes software developed by Max Bain.
4. Neither the name of Max Bain nor the
names of its contributors may be used to endorse or promote products
derived from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDER ''AS IS'' AND ANY
EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
1. Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.


@@ -1,4 +1,3 @@
include whisperx/assets/*
include whisperx/assets/gpt2/*
include whisperx/assets/multilingual/*
include whisperx/normalizers/english.json
include LICENSE
include requirements.txt

README.md

@@ -13,36 +13,30 @@
<img src="https://img.shields.io/github/license/m-bain/whisperX.svg"
alt="GitHub license">
</a>
<a href="https://arxiv.org/abs/2303.00747">
<img src="http://img.shields.io/badge/Arxiv-2303.00747-B31B1B.svg"
alt="ArXiv paper">
</a>
<a href="https://twitter.com/intent/tweet?text=&url=https%3A%2F%2Fgithub.com%2Fm-bain%2FwhisperX">
<img src="https://img.shields.io/twitter/url/https/github.com/m-bain/whisperX.svg?style=social" alt="Twitter">
</a>
</p>
<p align="center">
<a href="#what-is-it">What is it</a>
<a href="#setup">Setup</a>
<a href="#example">Usage</a>
<a href="#other-languages">Multilingual</a>
<a href="#contribute">Contribute</a>
<a href="EXAMPLES.md">More examples</a>
<a href="https://arxiv.org/abs/2303.00747">Paper</a>
</p>
<img width="1216" align="center" alt="whisperx-arch" src="https://raw.githubusercontent.com/m-bain/whisperX/refs/heads/main/figures/pipeline.png">
<!-- <p align="left">Whisper-Based Automatic Speech Recognition (ASR) with improved timestamp accuracy + quality via forced phoneme alignment and voice-activity based batching for fast inference.</p> -->
<img width="1216" align="center" alt="whisperx-arch" src="figures/pipeline.png">
<!-- <h2 align="left", id="what-is-it">What is it 🔎</h2> -->
This repository provides fast automatic speech recognition (70x realtime with large-v2) with word-level timestamps and speaker diarization.
<p align="left">Whisper-Based Automatic Speech Recognition (ASR) with improved timestamp accuracy using forced alignment.
- ⚡️ Batched inference for 70x realtime transcription using whisper large-v2
- 🪶 [faster-whisper](https://github.com/guillaumekln/faster-whisper) backend, requires <8GB gpu memory for large-v2 with beam_size=5
- 🎯 Accurate word-level timestamps using wav2vec2 alignment
- 👯 Multispeaker ASR using speaker diarization from [pyannote-audio](https://github.com/pyannote/pyannote-audio) (speaker ID labels)
- 🗣 VAD preprocessing, reduces hallucination & batching with no WER degradation
</p>
<h2 align="left", id="what-is-it">What is it 🔎</h2>
This repository refines the timestamps of openAI's Whisper model via forced alignment with phoneme-based ASR models (e.g. wav2vec2.0) and VAD preprocessing, multilingual use-case.
**Whisper** is an ASR model [developed by OpenAI](https://github.com/openai/whisper), trained on a large dataset of diverse audio. Whilst it does produce highly accurate transcriptions, the corresponding timestamps are at the utterance-level, not per word, and can be inaccurate by several seconds.
**Whisper** is an ASR model [developed by OpenAI](https://github.com/openai/whisper), trained on a large dataset of diverse audio. Whilst it does produce highly accurate transcriptions, the corresponding timestamps are at the utterance-level, not per word, and can be inaccurate by several seconds. OpenAI's whisper does not natively support batching.
**Phoneme-Based ASR** A suite of models finetuned to recognise the smallest unit of speech distinguishing one word from another, e.g. the element p in "tap". A popular example model is [wav2vec2.0](https://huggingface.co/facebook/wav2vec2-large-960h-lv60-self).
@@ -50,59 +44,75 @@ This repository refines the timestamps of openAI's Whisper model via forced alig
**Voice Activity Detection (VAD)** is the detection of the presence or absence of human speech.
**Speaker Diarization** is the process of partitioning an audio stream containing human speech into homogeneous segments according to the identity of each speaker.
<h2 align="left", id="highlights">New🚨</h2>
- v2 released, code cleanup, imports whisper library, batched inference from paper not included (contact for licensing / batched model API). VAD filtering is now turned on by default, as in the paper.
- Paper drop🎓👨🏫! Please see our [ArxiV preprint](https://arxiv.org/abs/2303.00747) for benchmarking and details of WhisperX. We also introduce more efficient batch inference resulting in large-v2 with *60-70x REAL TIME speed (not provided in this repo).
- VAD filtering: Voice Activity Detection (VAD) from [Pyannote.audio](https://huggingface.co/pyannote/voice-activity-detection) is used as a preprocessing step to remove reliance on whisper timestamps and only transcribe audio segments containing speech. add `--vad_filter True` flag, increases timestamp accuracy and robustness (requires more GPU mem due to 30s inputs in wav2vec2)
- Character level timestamps (see `*.char.ass` file output)
- Diarization (still in beta, add `--diarize`)
- 1st place at [Ego4d transcription challenge](https://eval.ai/web/challenges/challenge-page/1637/leaderboard/3931/WER) 🏆
- _WhisperX_ accepted at INTERSPEECH 2023
- v3 transcript segment-per-sentence: using nltk sent_tokenize for better subtitling & better diarization
- v3 released, 70x speed-up open-sourced. Using batched whisper with [faster-whisper](https://github.com/guillaumekln/faster-whisper) backend!
- v2 released, code cleanup, imports whisper library VAD filtering is now turned on by default, as in the paper.
- Paper drop🎓👨🏫! Please see our [ArxiV preprint](https://arxiv.org/abs/2303.00747) for benchmarking and details of WhisperX. We also introduce more efficient batch inference resulting in large-v2 with \*60-70x REAL TIME speed.
<h2 align="left" id="setup">Setup ⚙️</h2>
Install this package using
`pip install git+https://github.com/m-bain/whisperx.git`
### 1. Simple Installation (Recommended)
If already installed, update package to most recent commit
The easiest way to install WhisperX is through PyPi:
`pip install git+https://github.com/m-bain/whisperx.git --upgrade`
If wishing to modify this package, clone and install in editable mode:
```bash
pip install whisperx
```
$ git clone https://github.com/m-bain/whisperX.git
$ cd whisperX
$ pip install -e .
Or if using [uvx](https://docs.astral.sh/uv/guides/tools/#running-tools):
```bash
uvx whisperx
```
### 2. Advanced Installation Options
These installation methods are for developers or users with specific needs. If you're not sure, stick with the simple installation above.
#### Option A: Install from GitHub
To install directly from the GitHub repository:
```bash
uvx git+https://github.com/m-bain/whisperX.git
```
#### Option B: Developer Installation
If you want to modify the code or contribute to the project:
```bash
git clone https://github.com/m-bain/whisperX.git
cd whisperX
uv sync --all-extras --dev
```
> **Note**: The development version may contain experimental features and bugs. Use the stable PyPI release for production environments.
You may also need to install ffmpeg, rust etc. Follow openAI instructions here https://github.com/openai/whisper#setup.
### Setup not working???
Safest to use install pytorch as follows (for gpu)
`conda install pytorch==1.11.0 torchvision==0.12.0 torchaudio==0.11.0 -c pytorch
`
### Speaker Diarization
To **enable Speaker. Diarization**, include your Hugging Face access token that you can generate from [Here](https://huggingface.co/settings/tokens) after the `--hf_token` argument and accept the user agreement for the following models: [Segmentation](https://huggingface.co/pyannote/segmentation) , [Voice Activity Detection (VAD)](https://huggingface.co/pyannote/voice-activity-detection) , and [Speaker Diarization](https://huggingface.co/pyannote/speaker-diarization)
To **enable Speaker Diarization**, include your Hugging Face access token (read) that you can generate from [Here](https://huggingface.co/settings/tokens) after the `--hf_token` argument and accept the user agreement for the following models: [Segmentation](https://huggingface.co/pyannote/segmentation-3.0) and [Speaker-Diarization-3.1](https://huggingface.co/pyannote/speaker-diarization-3.1) (if you choose to use Speaker-Diarization 2.x, follow requirements [here](https://huggingface.co/pyannote/speaker-diarization) instead.)
> **Note**<br>
> As of Oct 11, 2023, there is a known issue regarding slow performance with pyannote/Speaker-Diarization-3.0 in whisperX. It is due to dependency conflicts between faster-whisper and pyannote-audio 3.0.0. Please see [this issue](https://github.com/m-bain/whisperX/issues/499) for more details and potential workarounds.
<h2 align="left" id="example">Usage 💬 (command line)</h2>
### English
Run whisper on example segment (using default params)
Run whisper on example segment (using default params, whisper small) add `--highlight_words True` to visualise word timings in the .srt file.
whisperx examples/sample01.wav
whisperx path/to/audio.wav
For increased timestamp accuracy, at the cost of higher gpu mem, use bigger models (bigger alignment model not found to be that helpful, see paper) e.g.
whisperx examples/sample01.wav --model large-v2 --align_model WAV2VEC2_ASR_LARGE_LV60K_960H
Result using *WhisperX* with forced alignment to wav2vec2.0 large:
Result using _WhisperX_ with forced alignment to wav2vec2.0 large:
https://user-images.githubusercontent.com/36994049/208253969-7e35fe2a-7541-434a-ae91-8e919540555d.mp4
@@ -110,105 +120,170 @@ Compare this to original whisper out the box, where many transcriptions are out
https://user-images.githubusercontent.com/36994049/207743923-b4f0d537-29ae-4be2-b404-bb941db73652.mov
For increased timestamp accuracy, at the cost of higher gpu mem, use bigger models (bigger alignment model not found to be that helpful, see paper) e.g.
whisperx path/to/audio.wav --model large-v2 --align_model WAV2VEC2_ASR_LARGE_LV60K_960H --batch_size 4
To label the transcript with speaker ID's (set number of speakers if known e.g. `--min_speakers 2` `--max_speakers 2`):
whisperx path/to/audio.wav --model large-v2 --diarize --highlight_words True
To run on CPU instead of GPU (and for running on Mac OS X):
whisperx path/to/audio.wav --compute_type int8
### Other languages
The phoneme ASR alignment model is *language-specific*, for tested languages these models are [automatically picked from torchaudio pipelines or huggingface](https://github.com/m-bain/whisperX/blob/e909f2f766b23b2000f2d95df41f9b844ac53e49/whisperx/transcribe.py#L22).
The phoneme ASR alignment model is _language-specific_, for tested languages these models are [automatically picked from torchaudio pipelines or huggingface](https://github.com/m-bain/whisperX/blob/f2da2f858e99e4211fe4f64b5f2938b007827e17/whisperx/alignment.py#L24-L58).
Just pass in the `--language` code, and use the whisper `--model large`.
Currently default models provided for `{en, fr, de, es, it, ja, zh, nl, uk, pt}`. If the detected language is not in this list, you need to find a phoneme-based ASR model from [huggingface model hub](https://huggingface.co/models) and test it on your data.
Currently default models provided for `{en, fr, de, es, it}` via torchaudio pipelines and many other languages via Hugging Face. Please find the list of currently supported languages under `DEFAULT_ALIGN_MODELS_HF` on [alignment.py](https://github.com/m-bain/whisperX/blob/main/whisperx/alignment.py). If the detected language is not in this list, you need to find a phoneme-based ASR model from [huggingface model hub](https://huggingface.co/models) and test it on your data.
#### E.g. German
whisperx --model large --language de examples/sample_de_01.wav
whisperx --model large-v2 --language de path/to/audio.wav
https://user-images.githubusercontent.com/36994049/208298811-e36002ba-3698-4731-97d4-0aebd07e0eb3.mov
See more examples in other languages [here](EXAMPLES.md).
## Python usage 🐍
## Python usage 🐍
```python
import whisperx
import whisper
import gc
device = "cuda"
device = "cuda"
audio_file = "audio.mp3"
batch_size = 16 # reduce if low on GPU mem
compute_type = "float16" # change to "int8" if low on GPU mem (may reduce accuracy)
# transcribe with original whisper
model = whisper.load_model("large", device)
result = model.transcribe(audio_file)
# 1. Transcribe with original whisper (batched)
model = whisperx.load_model("large-v2", device, compute_type=compute_type)
# save model to local path (optional)
# model_dir = "/path/"
# model = whisperx.load_model("large-v2", device, compute_type=compute_type, download_root=model_dir)
audio = whisperx.load_audio(audio_file)
result = model.transcribe(audio, batch_size=batch_size)
print(result["segments"]) # before alignment
# load alignment model and metadata
# delete model if low on GPU resources
# import gc; import torch; gc.collect(); torch.cuda.empty_cache(); del model
# 2. Align whisper output
model_a, metadata = whisperx.load_align_model(language_code=result["language"], device=device)
result = whisperx.align(result["segments"], model_a, metadata, audio, device, return_char_alignments=False)
# align whisper output
result_aligned = whisperx.align(result["segments"], model_a, metadata, audio_file, device)
print(result["segments"]) # after alignment
print(result_aligned["segments"]) # after alignment
print(result_aligned["word_segments"]) # after alignment
# delete model if low on GPU resources
# import gc; import torch; gc.collect(); torch.cuda.empty_cache(); del model_a
# 3. Assign speaker labels
diarize_model = whisperx.diarize.DiarizationPipeline(use_auth_token=YOUR_HF_TOKEN, device=device)
# add min/max number of speakers if known
diarize_segments = diarize_model(audio)
# diarize_model(audio, min_speakers=min_speakers, max_speakers=max_speakers)
result = whisperx.assign_word_speakers(diarize_segments, result)
print(diarize_segments)
print(result["segments"]) # segments are now assigned speaker IDs
```
## Demos 🚀
<h2 align="left" id="whisper-mod">Whisper Modifications</h2>
[![Replicate (large-v3](https://img.shields.io/static/v1?label=Replicate+WhisperX+large-v3&message=Demo+%26+Cloud+API&color=blue)](https://replicate.com/victor-upmeet/whisperx)
[![Replicate (large-v2](https://img.shields.io/static/v1?label=Replicate+WhisperX+large-v2&message=Demo+%26+Cloud+API&color=blue)](https://replicate.com/daanelson/whisperx)
[![Replicate (medium)](https://img.shields.io/static/v1?label=Replicate+WhisperX+medium&message=Demo+%26+Cloud+API&color=blue)](https://replicate.com/carnifexer/whisperx)
In addition to forced alignment, the following two modifications have been made to the whisper transcription method:
If you don't have access to your own GPUs, use the links above to try out WhisperX.
1. `--condition_on_prev_text` is set to `False` by default (reduces hallucination)
<h2 align="left" id="whisper-mod">Technical Details 👷‍♂️</h2>
For specific details on the batching and alignment, the effect of VAD, as well as the chosen alignment model, see the preprint [paper](https://www.robots.ox.ac.uk/~vgg/publications/2023/Bain23/bain23.pdf).
To reduce GPU memory requirements, try any of the following (2. & 3. can affect quality):
1. reduce batch size, e.g. `--batch_size 4`
2. use a smaller ASR model `--model base`
3. Use lighter compute type `--compute_type int8`
Transcription differences from openai's whisper:
1. Transcription without timestamps. To enable single pass batching, whisper inference is performed `--without_timestamps True`, this ensures 1 forward pass per sample in the batch. However, this can cause discrepancies with the default whisper output.
2. VAD-based segment transcription, unlike the buffered transcription of openai's. In the WhisperX paper we show this reduces WER, and enables accurate batched inference
3. `--condition_on_prev_text` is set to `False` by default (reduces hallucination)
<h2 align="left" id="limitations">Limitations ⚠️</h2>
- Whisper normalises spoken numbers, e.g. "fifty seven" to the Arabic numeral "57", which the phoneme-based aligner cannot align; this normalization would need to be deferred until after alignment. Currently such numbers are simply ignored.
- If setting `--vad_filter False`, whisperx assumes the initial whisper timestamps are accurate to some degree (within a margin of 2 seconds, adjust if needed; bigger margins are more prone to alignment errors).
- Transcript words which do not contain characters in the alignment model's dictionary, e.g. "2014." or "£13.60", cannot be aligned and therefore are not given a timing (see the sketch after this list for detecting such words).
- Overlapping speech is not handled particularly well by either whisper or whisperx.
- Diarization is far from perfect.
- A language-specific wav2vec2 model is needed for alignment.
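Unaligned words can be detected in the Python output, since they are returned without timing keys. A minimal sketch, assuming the aligned `result` from step 2 of the usage example above:

```python
# Words that could not be aligned are returned without "start"/"end" keys.
unaligned = [w["word"] for w in result["word_segments"] if "start" not in w]
print(f"{len(unaligned)} unaligned words:", unaligned[:10])
```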
<h2 align="left" id="contribute">Contribute 🧑‍🏫</h2>
If you are multilingual, a major way you can contribute to this project is to find phoneme models on huggingface (or train your own) and test them on speech for the target language. If the results look good, send a pull request and some examples showing its success.
The next major upgrade we are working on is whisper with speaker diarization, so if you have any experience on this please share.
Bug finding and pull requests are also highly appreciated to keep this project going, since it's already diverging from the original research scope.
<h2 align="left" id="coming-soon">Coming Soon 🗓</h2>
<h2 align="left" id="coming-soon">TODO 🗓</h2>
* [x] Multilingual init
- [x] Multilingual init
* [x] Subtitle .ass output
- [x] Automatic align model selection based on language detection
* [x] Automatic align model selection based on language detection
- [x] Python usage
* [x] Python usage
- [x] Incorporating speaker diarization
* [x] Character level timestamps
- [x] Model flush, for low gpu mem resources
* [x] Incorporating speaker diarization
- [x] Faster-whisper backend
* [ ] Automatic .wav conversion to make VAD compatible
- [x] Add max-line etc. see (openai's whisper utils.py)
* [ ] Model flush, for low gpu mem resources
- [x] Sentence-level segments (nltk toolbox)
* [ ] Improve diarization (word level). *Harder than first thought...*
- [x] Improve alignment logic
- [ ] update examples with diarization and word highlighting
- [ ] Subtitle .ass output <- bring this back (removed in v3)
- [ ] Add benchmarking code (TEDLIUM for spd/WER & word segmentation)
- [x] Allow silero-vad as alternative VAD option
- [ ] Improve diarization (word level). _Harder than first thought..._
<h2 align="left" id="contact">Contact/Support 📇</h2>
Contact maxhbain@gmail.com for queries.
<a href="https://www.buymeacoffee.com/maxhbain" target="_blank"><img src="https://cdn.buymeacoffee.com/buttons/default-orange.png" alt="Buy Me A Coffee" height="41" width="174"></a>
<h2 align="left" id="acks">Acknowledgements 🙏</h2>
This work, and my PhD, is supported by the [VGG (Visual Geometry Group)](https://www.robots.ox.ac.uk/~vgg/) and the University of Oxford.
Of course, this builds on [openAI's whisper](https://github.com/openai/whisper).
Borrows important alignment code from [PyTorch tutorial on forced alignment](https://pytorch.org/tutorials/intermediate/forced_alignment_with_torchaudio_tutorial.html)
Valuable VAD & Diarization Models from:
- [pyannote audio](https://github.com/pyannote/pyannote-audio)
- [silero vad](https://github.com/snakers4/silero-vad)
Great backend from [faster-whisper](https://github.com/guillaumekln/faster-whisper) and [CTranslate2](https://github.com/OpenNMT/CTranslate2)
Those who have [supported this work financially](https://www.buymeacoffee.com/maxhbain) 🙏
Finally, thanks to the OS [contributors](https://github.com/m-bain/whisperX/graphs/contributors) of this project, keeping it going and identifying bugs.
<h2 align="left" id="cite">Citation</h2>
If you use this in your research, please cite the paper:
```bibtex
@article{bain2022whisperx,
title={WhisperX: Time-Accurate Speech Transcription of Long-Form Audio},
author={Bain, Max and Huh, Jaesung and Han, Tengda and Zisserman, Andrew},
journal={INTERSPEECH 2023},
year={2023}
}
```
as well as the following works, used in each stage of the pipeline:
```bibtex
@article{radford2022robust,
title={Robust speech recognition via large-scale weak supervision},
author={Radford, Alec and Kim, Jong Wook and Xu, Tao and Brockman, Greg and McLeavey, Christine and Sutskever, Ilya},
journal={arXiv preprint arXiv:2212.04356},
year={2022}
}
```
```bibtex
@article{baevski2020wav2vec,
title={wav2vec 2.0: A framework for self-supervised learning of speech representations},
author={Baevski, Alexei and Zhou, Yuhao and Mohamed, Abdelrahman and Auli, Michael},
journal={Advances in neural information processing systems},
volume={33},
pages={12449--12460},
year={2020}
}
```
```bibtex
@inproceedings{bredin2020pyannote,
title={Pyannote. audio: neural building blocks for speaker diarization},
author={Bredin, Herv{\'e} and Yin, Ruiqing and Coria, Juan Manuel and Gelly, Gregory and Korshunov, Pavel and Lavechin, Marvin and Fustes, Diego and Titeux, Hadrien and Bouaziz, Wassim and Gill, Marie-Philippe},
booktitle={ICASSP 2020-2020 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)},
pages={7124--7128},
year={2020},
organization={IEEE}
}
```

pyproject.toml (new file)
@@ -0,0 +1,36 @@
[project]
urls = { repository = "https://github.com/m-bain/whisperx" }
authors = [{ name = "Max Bain" }]
name = "whisperx"
version = "3.3.4"
description = "Time-Accurate Automatic Speech Recognition using Whisper."
readme = "README.md"
requires-python = ">=3.9, <3.13"
license = { text = "BSD-2-Clause" }
dependencies = [
"ctranslate2<4.5.0",
"faster-whisper>=1.1.1",
"nltk>=3.9.1",
"numpy>=2.0.2",
"onnxruntime>=1.19,<1.20.0",
"pandas>=2.2.3",
"pyannote-audio>=3.3.2",
"torch<2.4.0",
"torchaudio",
"transformers>=4.48.0",
]
[project.scripts]
whisperx = "whisperx.__main__:cli"
[build-system]
requires = ["setuptools"]
[tool.setuptools]
include-package-data = true
[tool.setuptools.packages.find]
where = ["."]
include = ["whisperx*"]

@@ -1,10 +0,0 @@
numpy
pandas
torch >=1.9
torchaudio >=0.10,<1.0
tqdm
more-itertools
transformers>=4.19.0
ffmpeg-python==0.2.0
pyannote.audio
openai-whisper==20230314

@@ -1,28 +0,0 @@
import os
import pkg_resources
from setuptools import setup, find_packages
setup(
name="whisperx",
py_modules=["whisperx"],
version="2.0",
description="Time-Accurate Automatic Speech Recognition using Whisper.",
readme="README.md",
python_requires=">=3.8",
author="Max Bain",
url="https://github.com/m-bain/whisperx",
license="MIT",
packages=find_packages(exclude=["tests*"]),
install_requires=[
str(r)
for r in pkg_resources.parse_requirements(
open(os.path.join(os.path.dirname(__file__), "requirements.txt"))
)
],
entry_points = {
'console_scripts': ['whisperx=whisperx.transcribe:cli'],
},
include_package_data=True,
extras_require={'dev': ['pytest']},
)

uv.lock (generated, new file): diff suppressed because it is too large.
@@ -0,0 +1,226 @@
import math
from whisperx.conjunctions import get_conjunctions, get_comma
def normal_round(n):
if n - math.floor(n) < 0.5:
return math.floor(n)
return math.ceil(n)
def format_timestamp(seconds: float, is_vtt: bool = False):
assert seconds >= 0, "non-negative timestamp expected"
milliseconds = round(seconds * 1000.0)
hours = milliseconds // 3_600_000
milliseconds -= hours * 3_600_000
minutes = milliseconds // 60_000
milliseconds -= minutes * 60_000
seconds = milliseconds // 1_000
milliseconds -= seconds * 1_000
separator = '.' if is_vtt else ','
hours_marker = f"{hours:02d}:"
return (
f"{hours_marker}{minutes:02d}:{seconds:02d}{separator}{milliseconds:03d}"
)
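# e.g. format_timestamp(3723.5) -> "01:02:03,500"; with is_vtt=True -> "01:02:03.500"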
class SubtitlesProcessor:
def __init__(self, segments, lang, max_line_length = 45, min_char_length_splitter = 30, is_vtt = False):
self.comma = get_comma(lang)
self.conjunctions = set(get_conjunctions(lang))
self.segments = segments
self.lang = lang
self.max_line_length = max_line_length
self.min_char_length_splitter = min_char_length_splitter
self.is_vtt = is_vtt
complex_script_languages = ['th', 'lo', 'my', 'km', 'am', 'ko', 'ja', 'zh', 'ti', 'ta', 'te', 'kn', 'ml', 'hi', 'ne', 'mr', 'ar', 'fa', 'ur', 'ka']
if self.lang in complex_script_languages:
self.max_line_length = 30
self.min_char_length_splitter = 20
def estimate_timestamp_for_word(self, words, i, next_segment_start_time=None):
k = 0.25
has_prev_end = i > 0 and 'end' in words[i - 1]
has_next_start = i < len(words) - 1 and 'start' in words[i + 1]
if has_prev_end:
words[i]['start'] = words[i - 1]['end']
if has_next_start:
words[i]['end'] = words[i + 1]['start']
else:
if next_segment_start_time:
words[i]['end'] = next_segment_start_time if next_segment_start_time - words[i - 1]['end'] <= 1 else next_segment_start_time - 0.5
else:
words[i]['end'] = words[i]['start'] + len(words[i]['word']) * k
elif has_next_start:
words[i]['start'] = words[i + 1]['start'] - len(words[i]['word']) * k
words[i]['end'] = words[i + 1]['start']
else:
if next_segment_start_time:
words[i]['start'] = next_segment_start_time - 1
words[i]['end'] = next_segment_start_time - 0.5
else:
words[i]['start'] = 0
words[i]['end'] = 0
def process_segments(self, advanced_splitting=True):
subtitles = []
for i, segment in enumerate(self.segments):
next_segment_start_time = self.segments[i + 1]['start'] if i + 1 < len(self.segments) else None
if advanced_splitting:
split_points = self.determine_advanced_split_points(segment, next_segment_start_time)
subtitles.extend(self.generate_subtitles_from_split_points(segment, split_points, next_segment_start_time))
else:
words = segment['words']
for i, word in enumerate(words):
if 'start' not in word or 'end' not in word:
self.estimate_timestamp_for_word(words, i, next_segment_start_time)
subtitles.append({
'start': segment['start'],
'end': segment['end'],
'text': segment['text']
})
return subtitles
def determine_advanced_split_points(self, segment, next_segment_start_time=None):
split_points = []
last_split_point = 0
char_count = 0
words = segment.get('words', segment['text'].split())
add_space = 0 if self.lang in ['zh', 'ja'] else 1
total_char_count = sum(len(word['word']) if isinstance(word, dict) else len(word) + add_space for word in words)
char_count_after = total_char_count
for i, word in enumerate(words):
word_text = word['word'] if isinstance(word, dict) else word
word_length = len(word_text) + add_space
char_count += word_length
char_count_after -= word_length
char_count_before = char_count - word_length
if isinstance(word, dict) and ('start' not in word or 'end' not in word):
self.estimate_timestamp_for_word(words, i, next_segment_start_time)
if char_count >= self.max_line_length:
midpoint = normal_round((last_split_point + i) / 2)
if char_count_before >= self.min_char_length_splitter:
split_points.append(midpoint)
last_split_point = midpoint + 1
char_count = sum(len(words[j]['word']) if isinstance(words[j], dict) else len(words[j]) + add_space for j in range(last_split_point, i + 1))
elif word_text.endswith(self.comma) and char_count_before >= self.min_char_length_splitter and char_count_after >= self.min_char_length_splitter:
split_points.append(i)
last_split_point = i + 1
char_count = 0
elif word_text.lower() in self.conjunctions and char_count_before >= self.min_char_length_splitter and char_count_after >= self.min_char_length_splitter:
split_points.append(i - 1)
last_split_point = i
char_count = word_length
return split_points
def generate_subtitles_from_split_points(self, segment, split_points, next_start_time=None):
subtitles = []
words = segment.get('words', segment['text'].split())
total_word_count = len(words)
total_time = segment['end'] - segment['start']
elapsed_time = segment['start']
prefix = ' ' if self.lang not in ['zh', 'ja'] else ''
start_idx = 0
for split_point in split_points:
fragment_words = words[start_idx:split_point + 1]
current_word_count = len(fragment_words)
if isinstance(fragment_words[0], dict):
start_time = fragment_words[0]['start']
end_time = fragment_words[-1]['end']
next_start_time_for_word = words[split_point + 1]['start'] if split_point + 1 < len(words) else None
if next_start_time_for_word and (next_start_time_for_word - end_time) <= 0.8:
end_time = next_start_time_for_word
else:
fragment = prefix.join(fragment_words).strip()
current_duration = (current_word_count / total_word_count) * total_time
start_time = elapsed_time
end_time = elapsed_time + current_duration
elapsed_time += current_duration
subtitles.append({
'start': start_time,
'end': end_time,
'text': fragment if not isinstance(fragment_words[0], dict) else prefix.join(word['word'] for word in fragment_words)
})
start_idx = split_point + 1
# Handle the last fragment
if start_idx < len(words):
fragment_words = words[start_idx:]
current_word_count = len(fragment_words)
if isinstance(fragment_words[0], dict):
start_time = fragment_words[0]['start']
end_time = fragment_words[-1]['end']
else:
fragment = prefix.join(fragment_words).strip()
current_duration = (current_word_count / total_word_count) * total_time
start_time = elapsed_time
end_time = elapsed_time + current_duration
if next_start_time and (next_start_time - end_time) <= 0.8:
end_time = next_start_time
subtitles.append({
'start': start_time,
'end': end_time if end_time is not None else segment['end'],
'text': fragment if not isinstance(fragment_words[0], dict) else prefix.join(word['word'] for word in fragment_words)
})
return subtitles
def save(self, filename="subtitles.srt", advanced_splitting=True):
subtitles = self.process_segments(advanced_splitting)
def write_subtitle(file, idx, start_time, end_time, text):
file.write(f"{idx}\n")
file.write(f"{start_time} --> {end_time}\n")
file.write(text + "\n\n")
with open(filename, 'w', encoding='utf-8') as file:
if self.is_vtt:
file.write("WEBVTT\n\n")
if advanced_splitting:
for idx, subtitle in enumerate(subtitles, 1):
start_time = format_timestamp(subtitle['start'], self.is_vtt)
end_time = format_timestamp(subtitle['end'], self.is_vtt)
text = subtitle['text'].strip()
write_subtitle(file, idx, start_time, end_time, text)
return len(subtitles)
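# Usage sketch (illustrative, not part of this file): given aligned segments from whisperx.align,
#   SubtitlesProcessor(result["segments"], lang="en", is_vtt=False).save("subtitles.srt")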

@@ -1,3 +1,31 @@
from .transcribe import transcribe, transcribe_with_vad
from .alignment import load_align_model, align
from .vad import load_vad_model
import importlib
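# Submodules are imported on first use (lazy loading) rather than at package import time.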
def _lazy_import(name):
module = importlib.import_module(f"whisperx.{name}")
return module
def load_align_model(*args, **kwargs):
alignment = _lazy_import("alignment")
return alignment.load_align_model(*args, **kwargs)
def align(*args, **kwargs):
alignment = _lazy_import("alignment")
return alignment.align(*args, **kwargs)
def load_model(*args, **kwargs):
asr = _lazy_import("asr")
return asr.load_model(*args, **kwargs)
def load_audio(*args, **kwargs):
audio = _lazy_import("audio")
return audio.load_audio(*args, **kwargs)
def assign_word_speakers(*args, **kwargs):
diarize = _lazy_import("diarize")
return diarize.assign_word_speakers(*args, **kwargs)

@@ -1,4 +1,88 @@
from .transcribe import cli
import argparse
import importlib.metadata
import platform
import torch
from whisperx.utils import (LANGUAGES, TO_LANGUAGE_CODE, optional_float,
optional_int, str2bool)
cli()
def cli():
# fmt: off
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe")
parser.add_argument("--model", default="small", help="name of the Whisper model to use")
parser.add_argument("--model_cache_only", type=str2bool, default=False, help="If True, will not attempt to download models, instead using cached models from --model_dir")
parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
parser.add_argument("--device_index", default=0, type=int, help="device index to use for FasterWhisper inference")
parser.add_argument("--batch_size", default=8, type=int, help="the preferred batch size for inference")
parser.add_argument("--compute_type", default="float16", type=str, choices=["float16", "float32", "int8"], help="compute type for computation")
parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
parser.add_argument("--output_format", "-f", type=str, default="all", choices=["all", "srt", "vtt", "txt", "tsv", "json", "aud"], help="format of the output file; if not specified, all available formats will be produced")
parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")
parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')")
parser.add_argument("--language", type=str, default=None, choices=sorted(LANGUAGES.keys()) + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]), help="language spoken in the audio, specify None to perform language detection")
# alignment params
parser.add_argument("--align_model", default=None, help="Name of phoneme-level ASR model to do alignment")
parser.add_argument("--interpolate_method", default="nearest", choices=["nearest", "linear", "ignore"], help="For word .srt, method to assign timestamps to non-aligned words, or merge them into neighbouring.")
parser.add_argument("--no_align", action='store_true', help="Do not perform phoneme alignment")
parser.add_argument("--return_char_alignments", action='store_true', help="Return character-level alignments in the output json file")
# vad params
parser.add_argument("--vad_method", type=str, default="pyannote", choices=["pyannote", "silero"], help="VAD method to be used")
parser.add_argument("--vad_onset", type=float, default=0.500, help="Onset threshold for VAD (see pyannote.audio), reduce this if speech is not being detected")
parser.add_argument("--vad_offset", type=float, default=0.363, help="Offset threshold for VAD (see pyannote.audio), reduce this if speech is not being detected.")
parser.add_argument("--chunk_size", type=int, default=30, help="Chunk size for merging VAD segments. Default is 30, reduce this if the chunk is too long.")
# diarization params
parser.add_argument("--diarize", action="store_true", help="Apply diarization to assign speaker labels to each segment/word")
parser.add_argument("--min_speakers", default=None, type=int, help="Minimum number of speakers to in audio file")
parser.add_argument("--max_speakers", default=None, type=int, help="Maximum number of speakers to in audio file")
parser.add_argument("--diarize_model", default="pyannote/speaker-diarization-3.1", type=str, help="Name of the speaker diarization model to use")
parser.add_argument("--temperature", type=float, default=0, help="temperature to use for sampling")
parser.add_argument("--best_of", type=optional_int, default=5, help="number of candidates when sampling with non-zero temperature")
parser.add_argument("--beam_size", type=optional_int, default=5, help="number of beams in beam search, only applicable when temperature is zero")
parser.add_argument("--patience", type=float, default=1.0, help="optional patience value to use in beam decoding, as in https://arxiv.org/abs/2204.05424, the default (1.0) is equivalent to conventional beam search")
parser.add_argument("--length_penalty", type=float, default=1.0, help="optional token length penalty coefficient (alpha) as in https://arxiv.org/abs/1609.08144, uses simple length normalization by default")
parser.add_argument("--suppress_tokens", type=str, default="-1", help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations")
parser.add_argument("--suppress_numerals", action="store_true", help="whether to suppress numeric symbols and currency symbols during sampling, since wav2vec2 cannot align them correctly")
parser.add_argument("--initial_prompt", type=str, default=None, help="optional text to provide as a prompt for the first window.")
parser.add_argument("--condition_on_previous_text", type=str2bool, default=False, help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop")
parser.add_argument("--fp16", type=str2bool, default=True, help="whether to perform inference in fp16; True by default")
parser.add_argument("--temperature_increment_on_fallback", type=optional_float, default=0.2, help="temperature to increase when falling back when the decoding fails to meet either of the thresholds below")
parser.add_argument("--compression_ratio_threshold", type=optional_float, default=2.4, help="if the gzip compression ratio is higher than this value, treat the decoding as failed")
parser.add_argument("--logprob_threshold", type=optional_float, default=-1.0, help="if the average log probability is lower than this value, treat the decoding as failed")
parser.add_argument("--no_speech_threshold", type=optional_float, default=0.6, help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence")
parser.add_argument("--max_line_width", type=optional_int, default=None, help="(not possible with --no_align) the maximum number of characters in a line before breaking the line")
parser.add_argument("--max_line_count", type=optional_int, default=None, help="(not possible with --no_align) the maximum number of lines in a segment")
parser.add_argument("--highlight_words", type=str2bool, default=False, help="(not possible with --no_align) underline each word as it is spoken in srt and vtt")
parser.add_argument("--segment_resolution", type=str, default="sentence", choices=["sentence", "chunk"], help="(not possible with --no_align) the maximum number of characters in a line before breaking the line")
parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
parser.add_argument("--hf_token", type=str, default=None, help="Hugging Face Access Token to access PyAnnote gated models")
parser.add_argument("--print_progress", type=str2bool, default = False, help = "if True, progress will be printed in transcribe() and align() methods.")
parser.add_argument("--version", "-V", action="version", version=f"%(prog)s {importlib.metadata.version('whisperx')}",help="Show whisperx version information and exit")
parser.add_argument("--python-version", "-P", action="version", version=f"Python {platform.python_version()} ({platform.python_implementation()})",help="Show python version information and exit")
# fmt: on
args = parser.parse_args().__dict__
from whisperx.transcribe import transcribe_task
transcribe_task(args, parser)
if __name__ == "__main__":
cli()

@@ -1,17 +1,30 @@
""""
"""
Forced Alignment with Whisper
C. Max Bain
"""
import math
from dataclasses import dataclass
from typing import Iterable, Optional, Union, List
import numpy as np
import pandas as pd
from typing import List, Union, Iterator, TYPE_CHECKING
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
import torchaudio
import torch
from dataclasses import dataclass
from whisper.audio import SAMPLE_RATE, load_audio
from .utils import interpolate_nans
import torchaudio
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
from whisperx.audio import SAMPLE_RATE, load_audio
from whisperx.utils import interpolate_nans
from whisperx.types import (
AlignedTranscriptionResult,
SingleSegment,
SingleAlignedSegment,
SingleWordSegment,
SegmentData,
)
from nltk.tokenize.punkt import PunktSentenceTokenizer, PunktParameters
PUNKT_ABBREVIATIONS = ['dr', 'vs', 'mr', 'mrs', 'prof']
LANGUAGES_WITHOUT_SPACES = ["ja", "zh"]
@@ -30,6 +43,7 @@ DEFAULT_ALIGN_MODELS_HF = {
"uk": "Yehor/wav2vec2-xls-r-300m-uk-with-small-lm",
"pt": "jonatasgrosman/wav2vec2-large-xlsr-53-portuguese",
"ar": "jonatasgrosman/wav2vec2-large-xlsr-53-arabic",
"cs": "comodoro/wav2vec2-xls-r-300m-cs-250",
"ru": "jonatasgrosman/wav2vec2-large-xlsr-53-russian",
"pl": "jonatasgrosman/wav2vec2-large-xlsr-53-polish",
"hu": "jonatasgrosman/wav2vec2-large-xlsr-53-hungarian",
@@ -37,10 +51,30 @@ DEFAULT_ALIGN_MODELS_HF = {
"fa": "jonatasgrosman/wav2vec2-large-xlsr-53-persian",
"el": "jonatasgrosman/wav2vec2-large-xlsr-53-greek",
"tr": "mpoyraz/wav2vec2-xls-r-300m-cv7-turkish",
"da": "saattrupdan/wav2vec2-xls-r-300m-ftspeech",
"he": "imvladikon/wav2vec2-xls-r-300m-hebrew",
"vi": 'nguyenvulebinh/wav2vec2-base-vi',
"ko": "kresnik/wav2vec2-large-xlsr-korean",
"ur": "kingabzpro/wav2vec2-large-xls-r-300m-Urdu",
"te": "anuragshas/wav2vec2-large-xlsr-53-telugu",
"hi": "theainerd/Wav2Vec2-large-xlsr-hindi",
"ca": "softcatala/wav2vec2-large-xlsr-catala",
"ml": "gvs/wav2vec2-large-xlsr-malayalam",
"no": "NbAiLab/nb-wav2vec2-1b-bokmaal-v2",
"nn": "NbAiLab/nb-wav2vec2-1b-nynorsk",
"sk": "comodoro/wav2vec2-xls-r-300m-sk-cv8",
"sl": "anton-l/wav2vec2-large-xlsr-53-slovenian",
"hr": "classla/wav2vec2-xls-r-parlaspeech-hr",
"ro": "gigant/romanian-wav2vec2",
"eu": "stefan-it/wav2vec2-large-xlsr-53-basque",
"gl": "ifrz/wav2vec2-large-xlsr-galician",
"ka": "xsway/wav2vec2-large-xlsr-georgian",
"lv": "jimregan/wav2vec2-large-xlsr-latvian-cv",
"tl": "Khalsuu/filipino-wav2vec2-l-xls-r-300m-official",
}
def load_align_model(language_code, device, model_name=None, model_dir=None):
def load_align_model(language_code: str, device: str, model_name: Optional[str] = None, model_dir=None):
if model_name is None:
# use default model
if language_code in DEFAULT_ALIGN_MODELS_TORCH:
@@ -60,8 +94,8 @@ def load_align_model(language_code, device, model_name=None, model_dir=None):
align_dictionary = {c.lower(): i for i, c in enumerate(labels)}
else:
try:
processor = Wav2Vec2Processor.from_pretrained(model_name)
align_model = Wav2Vec2ForCTC.from_pretrained(model_name)
processor = Wav2Vec2Processor.from_pretrained(model_name, cache_dir=model_dir)
align_model = Wav2Vec2ForCTC.from_pretrained(model_name, cache_dir=model_dir)
except Exception as e:
print(e)
print(f"Error loading model from huggingface, check https://huggingface.co/models for finetuned wav2vec2.0 models")
@@ -77,427 +111,474 @@ def load_align_model(language_code, device, model_name=None, model_dir=None):
def align(
transcript: Iterator[dict],
transcript: Iterable[SingleSegment],
model: torch.nn.Module,
align_model_metadata: dict,
audio: Union[str, np.ndarray, torch.Tensor],
device: str,
extend_duration: float = 0.0,
start_from_previous: bool = True,
interpolate_method: str = "nearest",
):
return_char_alignments: bool = False,
print_progress: bool = False,
combined_progress: bool = False,
) -> AlignedTranscriptionResult:
"""
Align phoneme recognition predictions to known transcription.
"""
Force align phoneme recognition predictions to known transcription
Parameters
----------
transcript: Iterator[dict]
The Whisper model instance
model: torch.nn.Module
Alignment model (wav2vec2)
audio: Union[str, np.ndarray, torch.Tensor]
The path to the audio file to open, or the audio waveform
device: str
cuda device
diarization: pd.DataFrame {'start': List[float], 'end': List[float], 'speaker': List[float]}
diarization segments with speaker labels.
extend_duration: float
Amount to pad input segments by. If not using vad--filter then recommended to use 2 seconds
If the gzip compression ratio is above this value, treat as failed
interpolate_method: str ["nearest", "linear", "ignore"]
Method to assign timestamps to non-aligned words. Words are not able to be aligned when none of the characters occur in the align model dictionary.
"nearest" copies timestamp of nearest word within the segment. "linear" is linear interpolation. "drop" removes that word from output.
Returns
-------
A dictionary containing the resulting text ("text") and segment-level details ("segments"), and
the spoken language ("language"), which is detected when `decode_options["language"]` is None.
"""
if not torch.is_tensor(audio):
if isinstance(audio, str):
audio = load_audio(audio)
audio = torch.from_numpy(audio)
if len(audio.shape) == 1:
audio = audio.unsqueeze(0)
MAX_DURATION = audio.shape[1] / SAMPLE_RATE
model_dictionary = align_model_metadata["dictionary"]
model_lang = align_model_metadata["language"]
model_type = align_model_metadata["type"]
aligned_segments = []
prev_t2 = 0
char_segments_arr = {
"segment-idx": [],
"subsegment-idx": [],
"word-idx": [],
"char": [],
"start": [],
"end": [],
"score": [],
}
# 1. Preprocess to keep only characters in dictionary
total_segments = len(transcript)
# Store temporary processing values
segment_data: dict[int, SegmentData] = {}
for sdx, segment in enumerate(transcript):
while True:
segment_align_success = False
# strip spaces at beginning / end, but keep track of the amount.
num_leading = len(segment["text"]) - len(segment["text"].lstrip())
num_trailing = len(segment["text"]) - len(segment["text"].rstrip())
transcription = segment["text"]
# TODO: convert number tokenizer / symbols to phonetic words for alignment.
# e.g. "$300" -> "three hundred dollars"
# currently "$300" is ignored since no characters present in the phonetic dictionary
# split into words
if model_lang not in LANGUAGES_WITHOUT_SPACES:
per_word = transcription.split(" ")
else:
per_word = transcription
# first check that characters in transcription can be aligned (they are contained in align model"s dictionary)
clean_char, clean_cdx = [], []
for cdx, char in enumerate(transcription):
char_ = char.lower()
# wav2vec2 models use "|" character to represent spaces
if model_lang not in LANGUAGES_WITHOUT_SPACES:
char_ = char_.replace(" ", "|")
# ignore whitespace at beginning and end of transcript
if cdx < num_leading:
pass
elif cdx > len(transcription) - num_trailing - 1:
pass
elif char_ in model_dictionary.keys():
clean_char.append(char_)
clean_cdx.append(cdx)
clean_wdx = []
for wdx, wrd in enumerate(per_word):
if any([c in model_dictionary.keys() for c in wrd]):
clean_wdx.append(wdx)
# if no characters are in the dictionary, then we skip this segment...
if len(clean_char) == 0:
print(f'Failed to align segment ("{segment["text"]}"): no characters in this segment found in model dictionary, resorting to original...')
break
transcription_cleaned = "".join(clean_char)
tokens = [model_dictionary[c] for c in transcription_cleaned]
# we only pad if not using VAD filtering
if "seg_text" not in segment:
# pad according original timestamps
t1 = max(segment["start"] - extend_duration, 0)
t2 = min(segment["end"] + extend_duration, MAX_DURATION)
# use prev_t2 as current t1 if it"s later
if start_from_previous and t1 < prev_t2:
t1 = prev_t2
# check if timestamp range is still valid
if t1 >= MAX_DURATION:
print("Failed to align segment: original start time longer than audio duration, skipping...")
break
if t2 - t1 < 0.02:
print("Failed to align segment: duration smaller than 0.02s time precision")
break
f1 = int(t1 * SAMPLE_RATE)
f2 = int(t2 * SAMPLE_RATE)
waveform_segment = audio[:, f1:f2]
with torch.inference_mode():
if model_type == "torchaudio":
emissions, _ = model(waveform_segment.to(device))
elif model_type == "huggingface":
emissions = model(waveform_segment.to(device)).logits
else:
raise NotImplementedError(f"Align model of type {model_type} not supported.")
emissions = torch.log_softmax(emissions, dim=-1)
emission = emissions[0].cpu().detach()
trellis = get_trellis(emission, tokens)
path = backtrack(trellis, emission, tokens)
if path is None:
print(f'Failed to align segment ("{segment["text"]}"): backtrack failed, resorting to original...')
break
char_segments = merge_repeats(path, transcription_cleaned)
# word_segments = merge_words(char_segments)
# strip spaces at beginning / end, but keep track of the amount.
if print_progress:
base_progress = ((sdx + 1) / total_segments) * 100
percent_complete = (50 + base_progress / 2) if combined_progress else base_progress
print(f"Progress: {percent_complete:.2f}%...")
num_leading = len(segment["text"]) - len(segment["text"].lstrip())
num_trailing = len(segment["text"]) - len(segment["text"].rstrip())
text = segment["text"]
# sub-segments
if "seg-text" not in segment:
segment["seg-text"] = [transcription]
seg_lens = [0] + [len(x) for x in segment["seg-text"]]
seg_lens_cumsum = list(np.cumsum(seg_lens))
sub_seg_idx = 0
# split into words
if model_lang not in LANGUAGES_WITHOUT_SPACES:
per_word = text.split(" ")
else:
per_word = text
wdx = 0
duration = t2 - t1
ratio = duration * waveform_segment.size(0) / (trellis.size(0) - 1)
for cdx, char in enumerate(transcription + " "):
is_last = False
if cdx == len(transcription):
break
elif cdx+1 == len(transcription):
is_last = True
clean_char, clean_cdx = [], []
for cdx, char in enumerate(text):
char_ = char.lower()
# wav2vec2 models use "|" character to represent spaces
if model_lang not in LANGUAGES_WITHOUT_SPACES:
char_ = char_.replace(" ", "|")
# ignore whitespace at beginning and end of transcript
if cdx < num_leading:
pass
elif cdx > len(text) - num_trailing - 1:
pass
elif char_ in model_dictionary.keys():
clean_char.append(char_)
clean_cdx.append(cdx)
else:
# add placeholder
clean_char.append('*')
clean_cdx.append(cdx)
clean_wdx = []
for wdx, wrd in enumerate(per_word):
if any([c in model_dictionary.keys() for c in wrd.lower()]):
clean_wdx.append(wdx)
else:
# index for placeholder
clean_wdx.append(wdx)
start, end, score = None, None, None
if cdx in clean_cdx:
char_seg = char_segments[clean_cdx.index(cdx)]
start = char_seg.start * ratio + t1
end = char_seg.end * ratio + t1
score = char_seg.score
char_segments_arr["char"].append(char)
char_segments_arr["start"].append(start)
char_segments_arr["end"].append(end)
char_segments_arr["score"].append(score)
char_segments_arr["word-idx"].append(wdx)
char_segments_arr["segment-idx"].append(sdx)
char_segments_arr["subsegment-idx"].append(sub_seg_idx)
# word-level info
if model_lang in LANGUAGES_WITHOUT_SPACES:
# character == word
wdx += 1
elif is_last or transcription[cdx+1] == " " or cdx == seg_lens_cumsum[sub_seg_idx+1] - 1:
wdx += 1
if is_last or cdx == seg_lens_cumsum[sub_seg_idx+1] - 1:
wdx = 0
sub_seg_idx += 1
prev_t2 = segment["end"]
segment_align_success = True
# end while True loop
break
# reset prev_t2 due to drifting issues
if not segment_align_success:
prev_t2 = 0
char_segments_arr = pd.DataFrame(char_segments_arr)
not_space = char_segments_arr["char"] != " "
per_seg_grp = char_segments_arr.groupby(["segment-idx", "subsegment-idx"], as_index = False)
char_segments_arr = per_seg_grp.apply(lambda x: x.reset_index(drop = True)).reset_index()
per_word_grp = char_segments_arr[not_space].groupby(["segment-idx", "subsegment-idx", "word-idx"])
per_subseg_grp = char_segments_arr[not_space].groupby(["segment-idx", "subsegment-idx"])
per_seg_grp = char_segments_arr[not_space].groupby(["segment-idx"])
char_segments_arr["local-char-idx"] = char_segments_arr.groupby(["segment-idx", "subsegment-idx"]).cumcount()
per_word_grp = char_segments_arr[not_space].groupby(["segment-idx", "subsegment-idx", "word-idx"]) # regroup
word_segments_arr = {}
# start of word is first char with a timestamp
word_segments_arr["start"] = per_word_grp["start"].min().values
# end of word is last char with a timestamp
word_segments_arr["end"] = per_word_grp["end"].max().values
# score of word is mean (excluding nan)
word_segments_arr["score"] = per_word_grp["score"].mean().values
word_segments_arr["segment-text-start"] = per_word_grp["local-char-idx"].min().astype(int).values
word_segments_arr["segment-text-end"] = per_word_grp["local-char-idx"].max().astype(int).values+1
word_segments_arr = pd.DataFrame(word_segments_arr)
word_segments_arr[["segment-idx", "subsegment-idx", "word-idx"]] = per_word_grp["local-char-idx"].min().reset_index()[["segment-idx", "subsegment-idx", "word-idx"]].astype(int)
segments_arr = {}
segments_arr["start"] = per_subseg_grp["start"].min().reset_index()["start"]
segments_arr["end"] = per_subseg_grp["end"].max().reset_index()["end"]
segments_arr = pd.DataFrame(segments_arr)
segments_arr[["segment-idx", "subsegment-idx-start"]] = per_subseg_grp["start"].min().reset_index()[["segment-idx", "subsegment-idx"]]
segments_arr["subsegment-idx-end"] = segments_arr["subsegment-idx-start"] + 1
# interpolate missing words / sub-segments
if interpolate_method != "ignore":
wrd_subseg_grp = word_segments_arr.groupby(["segment-idx", "subsegment-idx"], group_keys=False)
wrd_seg_grp = word_segments_arr.groupby(["segment-idx"], group_keys=False)
# we still know which word timestamps are interpolated because their score == nan
word_segments_arr["start"] = wrd_subseg_grp['start'].apply(lambda group: interpolate_nans(group, method=interpolate_method))
word_segments_arr["end"] = wrd_subseg_grp['end'].apply(lambda group: interpolate_nans(group, method=interpolate_method))
word_segments_arr["start"] = wrd_seg_grp['start'].apply(lambda group: interpolate_nans(group, method=interpolate_method))
word_segments_arr["end"] = wrd_seg_grp['end'].apply(lambda group: interpolate_nans(group, method=interpolate_method))
sub_seg_grp = segments_arr.groupby(["segment-idx"], group_keys=False)
segments_arr['start'] = sub_seg_grp['start'].apply(lambda group: interpolate_nans(group, method=interpolate_method))
segments_arr['end'] = sub_seg_grp['end'].apply(lambda group: interpolate_nans(group, method=interpolate_method))
# merge words & subsegments which are missing times
word_grp = word_segments_arr.groupby(["segment-idx", "subsegment-idx", "end"])
word_segments_arr["segment-text-start"] = word_grp["segment-text-start"].transform(min)
word_segments_arr["segment-text-end"] = word_grp["segment-text-end"].transform(max)
word_segments_arr.drop_duplicates(subset=["segment-idx", "subsegment-idx", "end"], inplace=True)
seg_grp_dup = segments_arr.groupby(["segment-idx", "start", "end"])
segments_arr["subsegment-idx-start"] = seg_grp_dup["subsegment-idx-start"].transform(min)
segments_arr["subsegment-idx-end"] = seg_grp_dup["subsegment-idx-end"].transform(max)
segments_arr.drop_duplicates(subset=["segment-idx", "subsegment-idx-start", "subsegment-idx-end"], inplace=True)
else:
word_segments_arr.dropna(inplace=True)
segments_arr.dropna(inplace=True)
# if some segments still have missing timestamps (usually because all numerals / symbols), then use original timestamps...
segments_arr['start'].fillna(pd.Series([x['start'] for x in transcript]), inplace=True)
segments_arr['end'].fillna(pd.Series([x['end'] for x in transcript]), inplace=True)
segments_arr['subsegment-idx-start'].fillna(0, inplace=True)
segments_arr['subsegment-idx-end'].fillna(1, inplace=True)
aligned_segments = []
aligned_segments_word = []
word_segments_arr.set_index(["segment-idx", "subsegment-idx"], inplace=True)
char_segments_arr.set_index(["segment-idx", "subsegment-idx", "word-idx"], inplace=True)
for sdx, srow in segments_arr.iterrows():
seg_idx = int(srow["segment-idx"])
sub_start = int(srow["subsegment-idx-start"])
sub_end = int(srow["subsegment-idx-end"])
seg = transcript[seg_idx]
text = "".join(seg["seg-text"][sub_start:sub_end])
wseg = word_segments_arr.loc[seg_idx].loc[sub_start:sub_end-1]
wseg["start"].fillna(srow["start"], inplace=True)
wseg["end"].fillna(srow["end"], inplace=True)
wseg["segment-text-start"].fillna(0, inplace=True)
wseg["segment-text-end"].fillna(len(text)-1, inplace=True)
cseg = char_segments_arr.loc[seg_idx].loc[sub_start:sub_end-1]
# fixes bug for single segment in transcript
cseg['segment-text-start'] = cseg['level_1'] if 'level_1' in cseg else 0
cseg['segment-text-end'] = cseg['level_1'] + 1 if 'level_1' in cseg else 1
if 'level_1' in cseg: del cseg['level_1']
if 'level_0' in cseg: del cseg['level_0']
cseg.reset_index(inplace=True)
aligned_segments.append(
{
"start": srow["start"],
"end": srow["end"],
"text": text,
"word-segments": wseg,
"char-segments": cseg
}
)
def get_raw_text(word_row):
return seg["seg-text"][word_row.name][int(word_row["segment-text-start"]):int(word_row["segment-text-end"])+1]
wdx = 0
curr_text = get_raw_text(wseg.iloc[wdx])
if len(wseg) > 1:
for _, wrow in wseg.iloc[1:].iterrows():
if wrow['start'] != wseg.iloc[wdx]['start']:
aligned_segments_word.append(
{
"text": curr_text.strip(),
"start": wseg.iloc[wdx]["start"],
"end": wseg.iloc[wdx]["end"],
}
)
curr_text = ""
curr_text += " " + get_raw_text(wrow)
wdx += 1
aligned_segments_word.append(
{
"text": curr_text.strip(),
"start": wseg.iloc[wdx]["start"],
"end": wseg.iloc[wdx]["end"]
}
)
punkt_param = PunktParameters()
punkt_param.abbrev_types = set(PUNKT_ABBREVIATIONS)
sentence_splitter = PunktSentenceTokenizer(punkt_param)
sentence_spans = list(sentence_splitter.span_tokenize(text))
segment_data[sdx] = {
"clean_char": clean_char,
"clean_cdx": clean_cdx,
"clean_wdx": clean_wdx,
"sentence_spans": sentence_spans
}
aligned_segments: List[SingleAlignedSegment] = []
return {"segments": aligned_segments, "word_segments": aligned_segments_word}
# 2. Get prediction matrix from alignment model & align
for sdx, segment in enumerate(transcript):
t1 = segment["start"]
t2 = segment["end"]
text = segment["text"]
aligned_seg: SingleAlignedSegment = {
"start": t1,
"end": t2,
"text": text,
"words": [],
"chars": None,
}
if return_char_alignments:
aligned_seg["chars"] = []
# check we can align
if len(segment_data[sdx]["clean_char"]) == 0:
print(f'Failed to align segment ("{segment["text"]}"): no characters in this segment found in model dictionary, resorting to original...')
aligned_segments.append(aligned_seg)
continue
if t1 >= MAX_DURATION:
print(f'Failed to align segment ("{segment["text"]}"): original start time longer than audio duration, skipping...')
aligned_segments.append(aligned_seg)
continue
text_clean = "".join(segment_data[sdx]["clean_char"])
tokens = [model_dictionary.get(c, -1) for c in text_clean]
f1 = int(t1 * SAMPLE_RATE)
f2 = int(t2 * SAMPLE_RATE)
# TODO: Probably can get some speedup gain with batched inference here
waveform_segment = audio[:, f1:f2]
# Handle the minimum input length for wav2vec2 models
if waveform_segment.shape[-1] < 400:
lengths = torch.as_tensor([waveform_segment.shape[-1]]).to(device)
waveform_segment = torch.nn.functional.pad(
waveform_segment, (0, 400 - waveform_segment.shape[-1])
)
else:
lengths = None
with torch.inference_mode():
if model_type == "torchaudio":
emissions, _ = model(waveform_segment.to(device), lengths=lengths)
elif model_type == "huggingface":
emissions = model(waveform_segment.to(device)).logits
else:
raise NotImplementedError(f"Align model of type {model_type} not supported.")
emissions = torch.log_softmax(emissions, dim=-1)
emission = emissions[0].cpu().detach()
blank_id = 0
for char, code in model_dictionary.items():
if char == '[pad]' or char == '<pad>':
blank_id = code
trellis = get_trellis(emission, tokens, blank_id)
# path = backtrack(trellis, emission, tokens, blank_id)
path = backtrack_beam(trellis, emission, tokens, blank_id, beam_width=2)
if path is None:
print(f'Failed to align segment ("{segment["text"]}"): backtrack failed, resorting to original...')
aligned_segments.append(aligned_seg)
continue
char_segments = merge_repeats(path, text_clean)
duration = t2 - t1
ratio = duration * waveform_segment.size(0) / (trellis.size(0) - 1)
# assign timestamps to aligned characters
char_segments_arr = []
word_idx = 0
for cdx, char in enumerate(text):
start, end, score = None, None, None
if cdx in segment_data[sdx]["clean_cdx"]:
char_seg = char_segments[segment_data[sdx]["clean_cdx"].index(cdx)]
start = round(char_seg.start * ratio + t1, 3)
end = round(char_seg.end * ratio + t1, 3)
score = round(char_seg.score, 3)
char_segments_arr.append(
{
"char": char,
"start": start,
"end": end,
"score": score,
"word-idx": word_idx,
}
)
# increment word_idx, nltk word tokenization would probably be more robust here, but us space for now...
if model_lang in LANGUAGES_WITHOUT_SPACES:
word_idx += 1
elif cdx == len(text) - 1 or text[cdx+1] == " ":
word_idx += 1
char_segments_arr = pd.DataFrame(char_segments_arr)
aligned_subsegments = []
# assign sentence_idx to each character index
char_segments_arr["sentence-idx"] = None
for sdx2, (sstart, send) in enumerate(segment_data[sdx]["sentence_spans"]):
curr_chars = char_segments_arr.loc[(char_segments_arr.index >= sstart) & (char_segments_arr.index <= send)]
char_segments_arr.loc[(char_segments_arr.index >= sstart) & (char_segments_arr.index <= send), "sentence-idx"] = sdx2
sentence_text = text[sstart:send]
sentence_start = curr_chars["start"].min()
end_chars = curr_chars[curr_chars["char"] != ' ']
sentence_end = end_chars["end"].max()
sentence_words = []
for word_idx in curr_chars["word-idx"].unique():
word_chars = curr_chars.loc[curr_chars["word-idx"] == word_idx]
word_text = "".join(word_chars["char"].tolist()).strip()
if len(word_text) == 0:
continue
# dont use space character for alignment
word_chars = word_chars[word_chars["char"] != " "]
word_start = word_chars["start"].min()
word_end = word_chars["end"].max()
word_score = round(word_chars["score"].mean(), 3)
# -1 indicates unalignable
word_segment = {"word": word_text}
if not np.isnan(word_start):
word_segment["start"] = word_start
if not np.isnan(word_end):
word_segment["end"] = word_end
if not np.isnan(word_score):
word_segment["score"] = word_score
sentence_words.append(word_segment)
aligned_subsegments.append({
"text": sentence_text,
"start": sentence_start,
"end": sentence_end,
"words": sentence_words,
})
if return_char_alignments:
curr_chars = curr_chars[["char", "start", "end", "score"]]
curr_chars.fillna(-1, inplace=True)
curr_chars = curr_chars.to_dict("records")
curr_chars = [{key: val for key, val in char.items() if val != -1} for char in curr_chars]
aligned_subsegments[-1]["chars"] = curr_chars
aligned_subsegments = pd.DataFrame(aligned_subsegments)
aligned_subsegments["start"] = interpolate_nans(aligned_subsegments["start"], method=interpolate_method)
aligned_subsegments["end"] = interpolate_nans(aligned_subsegments["end"], method=interpolate_method)
# concatenate sentences with same timestamps
agg_dict = {"text": " ".join, "words": "sum"}
if model_lang in LANGUAGES_WITHOUT_SPACES:
agg_dict["text"] = "".join
if return_char_alignments:
agg_dict["chars"] = "sum"
aligned_subsegments= aligned_subsegments.groupby(["start", "end"], as_index=False).agg(agg_dict)
aligned_subsegments = aligned_subsegments.to_dict('records')
aligned_segments += aligned_subsegments
# create word_segments list
word_segments: List[SingleWordSegment] = []
for segment in aligned_segments:
word_segments += segment["words"]
return {"segments": aligned_segments, "word_segments": word_segments}
"""
source: https://pytorch.org/tutorials/intermediate/forced_alignment_with_torchaudio_tutorial.html
"""
def get_trellis(emission, tokens, blank_id=0):
num_frame = emission.size(0)
num_tokens = len(tokens)
# Trellis has extra diemsions for both time axis and tokens.
# The extra dim for tokens represents <SoS> (start-of-sentence)
# The extra dim for time axis is for simplification of the code.
trellis = torch.empty((num_frame + 1, num_tokens + 1))
trellis[0, 0] = 0
trellis[1:, 0] = torch.cumsum(emission[:, 0], 0)
trellis[0, -num_tokens:] = -float("inf")
trellis[-num_tokens:, 0] = float("inf")
trellis = torch.zeros((num_frame, num_tokens))
trellis[1:, 0] = torch.cumsum(emission[1:, blank_id], 0)
trellis[0, 1:] = -float("inf")
trellis[-num_tokens + 1:, 0] = float("inf")
for t in range(num_frame):
for t in range(num_frame - 1):
trellis[t + 1, 1:] = torch.maximum(
# Score for staying at the same token
trellis[t, 1:] + emission[t, blank_id],
# Score for changing to the next token
trellis[t, :-1] + emission[t, tokens],
# trellis[t, :-1] + emission[t, tokens[1:]],
trellis[t, :-1] + get_wildcard_emission(emission[t], tokens[1:], blank_id),
)
return trellis
def get_wildcard_emission(frame_emission, tokens, blank_id):
"""Processing token emission scores containing wildcards (vectorized version)
Args:
frame_emission: Emission probability vector for the current frame
tokens: List of token indices
blank_id: ID of the blank token
Returns:
tensor: Maximum probability score for each token position
"""
assert 0 <= blank_id < len(frame_emission)
# Convert tokens to a tensor if they are not already
tokens = torch.tensor(tokens) if not isinstance(tokens, torch.Tensor) else tokens
# Create a mask to identify wildcard positions
wildcard_mask = (tokens == -1)
# Get scores for non-wildcard positions
regular_scores = frame_emission[tokens.clamp(min=0)] # clamp to avoid -1 index
# Create a mask and compute the maximum value without modifying frame_emission
max_valid_score = frame_emission.clone() # Create a copy
max_valid_score[blank_id] = float('-inf') # Modify the copy to exclude the blank token
max_valid_score = max_valid_score.max()
# Use where operation to combine results
result = torch.where(wildcard_mask, max_valid_score, regular_scores)
return result
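# e.g. tokens=[3, -1, 5]: the wildcard position (-1) receives the frame's best non-blank score,
# while regular positions simply receive frame_emission[token].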
@dataclass
class Point:
token_index: int
time_index: int
score: float
def backtrack(trellis, emission, tokens, blank_id=0):
# Note:
# j and t are indices for trellis, which has extra dimensions
# for time and tokens at the beginning.
# When referring to time frame index `T` in trellis,
# the corresponding index in emission is `T-1`.
# Similarly, when referring to token index `J` in trellis,
# the corresponding index in transcript is `J-1`.
j = trellis.size(1) - 1
t_start = torch.argmax(trellis[:, j]).item()
t, j = trellis.size(0) - 1, trellis.size(1) - 1
path = [Point(j, t, emission[t, blank_id].exp().item())]
while j > 0:
# Should not happen but just in case
assert t > 0
path = []
for t in range(t_start, 0, -1):
# 1. Figure out if the current position was stay or change
# Note (again):
# `emission[J-1]` is the emission at time frame `J` of trellis dimension.
# Score for token staying the same from time frame J-1 to T.
stayed = trellis[t - 1, j] + emission[t - 1, blank_id]
# Score for token changing from C-1 at T-1 to J at T.
changed = trellis[t - 1, j - 1] + emission[t - 1, tokens[j - 1]]
# Frame-wise score of stay vs change
p_stay = emission[t - 1, blank_id]
# p_change = emission[t - 1, tokens[j]]
p_change = get_wildcard_emission(emission[t - 1], [tokens[j]], blank_id)[0]
# 2. Store the path with frame-wise probability.
prob = emission[t - 1, tokens[j - 1] if changed > stayed else 0].exp().item()
# Return token index and time index in non-trellis coordinate.
path.append(Point(j - 1, t - 1, prob))
# Context-aware score for stay vs change
stayed = trellis[t - 1, j] + p_stay
changed = trellis[t - 1, j - 1] + p_change
# 3. Update the token
# Update position
t -= 1
if changed > stayed:
j -= 1
if j == 0:
break
else:
# failed
return None
# Store the path with frame-wise probability.
prob = (p_change if changed > stayed else p_stay).exp().item()
path.append(Point(j, t, prob))
# Now j == 0, which means, it reached the SoS.
# Fill up the rest for the sake of visualization
while t > 0:
prob = emission[t - 1, blank_id].exp().item()
path.append(Point(j, t - 1, prob))
t -= 1
return path[::-1]
@dataclass
class Path:
points: List[Point]
score: float
@dataclass
class BeamState:
"""State in beam search."""
token_index: int # Current token position
time_index: int # Current time step
score: float # Cumulative score
path: List[Point] # Path history
def backtrack_beam(trellis, emission, tokens, blank_id=0, beam_width=5):
"""Standard CTC beam search backtracking implementation.
Args:
trellis (torch.Tensor): The trellis (or lattice) of shape (T, N), where T is the number of time steps
and N is the number of tokens (including the blank token).
emission (torch.Tensor): The emission probabilities of shape (T, N).
tokens (List[int]): List of token indices (excluding the blank token).
blank_id (int, optional): The ID of the blank token. Defaults to 0.
beam_width (int, optional): The number of top paths to keep during beam search. Defaults to 5.
Returns:
List[Point]: the best path
"""
T, J = trellis.size(0) - 1, trellis.size(1) - 1
init_state = BeamState(
token_index=J,
time_index=T,
score=trellis[T, J],
path=[Point(J, T, emission[T, blank_id].exp().item())]
)
beams = [init_state]
while beams and beams[0].token_index > 0:
next_beams = []
for beam in beams:
t, j = beam.time_index, beam.token_index
if t <= 0:
continue
p_stay = emission[t - 1, blank_id]
p_change = get_wildcard_emission(emission[t - 1], [tokens[j]], blank_id)[0]
stay_score = trellis[t - 1, j]
change_score = trellis[t - 1, j - 1] if j > 0 else float('-inf')
# Stay
if not math.isinf(stay_score):
new_path = beam.path.copy()
new_path.append(Point(j, t - 1, p_stay.exp().item()))
next_beams.append(BeamState(
token_index=j,
time_index=t - 1,
score=stay_score,
path=new_path
))
# Change
if j > 0 and not math.isinf(change_score):
new_path = beam.path.copy()
new_path.append(Point(j - 1, t - 1, p_change.exp().item()))
next_beams.append(BeamState(
token_index=j - 1,
time_index=t - 1,
score=change_score,
path=new_path
))
# sort by score
beams = sorted(next_beams, key=lambda x: x.score, reverse=True)[:beam_width]
if not beams:
break
if not beams:
return None
best_beam = beams[0]
t = best_beam.time_index
j = best_beam.token_index
while t > 0:
prob = emission[t - 1, blank_id].exp().item()
best_beam.path.append(Point(j, t - 1, prob))
t -= 1
return best_beam.path[::-1]
# Merge the labels
@dataclass
class Segment:

@@ -1,433 +1,416 @@
import warnings
from typing import TYPE_CHECKING, Optional, Tuple, Union
import os
from typing import List, Optional, Union
from dataclasses import replace
import ctranslate2
import faster_whisper
import numpy as np
import torch
import tqdm
import ffmpeg
from whisper.audio import (
FRAMES_PER_SECOND,
HOP_LENGTH,
N_FRAMES,
N_SAMPLES,
SAMPLE_RATE,
CHUNK_LENGTH,
log_mel_spectrogram,
pad_or_trim,
load_audio
)
from whisper.decoding import DecodingOptions, DecodingResult
from whisper.timing import add_word_timestamps
from whisper.tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
from whisper.utils import (
exact_div,
format_timestamp,
make_safe,
)
from faster_whisper.tokenizer import Tokenizer
from faster_whisper.transcribe import TranscriptionOptions, get_ctranslate2_storage
from transformers import Pipeline
from transformers.pipelines.pt_utils import PipelineIterator
if TYPE_CHECKING:
from whisper.model import Whisper
from whisperx.audio import N_SAMPLES, SAMPLE_RATE, load_audio, log_mel_spectrogram
from whisperx.types import SingleSegment, TranscriptionResult
from whisperx.vads import Vad, Silero, Pyannote
from .vad import merge_chunks
def transcribe(
model: "Whisper",
audio: Union[str, np.ndarray, torch.Tensor] = None,
mel: np.ndarray = None,
verbose: Optional[bool] = None,
temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
compression_ratio_threshold: Optional[float] = 2.4,
logprob_threshold: Optional[float] = -1.0,
no_speech_threshold: Optional[float] = 0.6,
condition_on_previous_text: bool = True,
initial_prompt: Optional[str] = None,
word_timestamps: bool = False,
prepend_punctuations: str = "\"'“¿([{-",
append_punctuations: str = "\"'.。,!?::”)]}、",
**decode_options,
):
"""
Transcribe an audio file using Whisper.
We redefine the Whisper transcribe function to allow mel input (for sequential slicing of audio)
def find_numeral_symbol_tokens(tokenizer):
numeral_symbol_tokens = []
for i in range(tokenizer.eot):
token = tokenizer.decode([i]).removeprefix(" ")
has_numeral_symbol = any(c in "0123456789%" for c in token)
if has_numeral_symbol:
numeral_symbol_tokens.append(i)
return numeral_symbol_tokens
Parameters
----------
model: Whisper
The Whisper model instance
class WhisperModel(faster_whisper.WhisperModel):
'''
FasterWhisperModel provides batched inference for faster-whisper.
Currently only works in non-timestamp mode and fixed prompt for all samples in batch.
'''
audio: Union[str, np.ndarray, torch.Tensor]
The path to the audio file to open, or the audio waveform
mel: np.ndarray
Mel spectrogram of audio segment.
verbose: bool
Whether to display the text being decoded to the console. If True, displays all the details,
If False, displays minimal details. If None, does not display anything
temperature: Union[float, Tuple[float, ...]]
Temperature for sampling. It can be a tuple of temperatures, which will be successively used
upon failures according to either `compression_ratio_threshold` or `logprob_threshold`.
compression_ratio_threshold: float
If the gzip compression ratio is above this value, treat as failed
logprob_threshold: float
If the average log probability over sampled tokens is below this value, treat as failed
no_speech_threshold: float
If the no_speech probability is higher than this value AND the average log probability
over sampled tokens is below `logprob_threshold`, consider the segment as silent
condition_on_previous_text: bool
if True, the previous output of the model is provided as a prompt for the next window;
disabling may make the text inconsistent across windows, but the model becomes less prone to
getting stuck in a failure loop, such as repetition looping or timestamps going out of sync.
word_timestamps: bool
Extract word-level timestamps using the cross-attention pattern and dynamic time warping,
and include the timestamps for each word in each segment.
prepend_punctuations: str
If word_timestamps is True, merge these punctuation symbols with the next word
append_punctuations: str
If word_timestamps is True, merge these punctuation symbols with the previous word
initial_prompt: Optional[str]
Optional text to provide as a prompt for the first window. This can be used to provide, or
"prompt-engineer" a context for transcription, e.g. custom vocabularies or proper nouns
to make it more likely to predict those words correctly.
decode_options: dict
Keyword arguments to construct `DecodingOptions` instances
Returns
-------
A dictionary containing the resulting text ("text") and segment-level details ("segments"), and
the spoken language ("language"), which is detected when `decode_options["language"]` is None.
"""
dtype = torch.float16 if decode_options.get("fp16", True) else torch.float32
if model.device == torch.device("cpu"):
if torch.cuda.is_available():
warnings.warn("Performing inference on CPU when CUDA is available")
if dtype == torch.float16:
warnings.warn("FP16 is not supported on CPU; using FP32 instead")
dtype = torch.float32
if dtype == torch.float32:
decode_options["fp16"] = False
# Pad 30-seconds of silence to the input audio, for slicing
if mel is None:
if audio is None:
raise ValueError("Transcribe needs either audio or mel as input, currently both are none.")
mel = log_mel_spectrogram(audio, padding=N_SAMPLES)
content_frames = mel.shape[-1] - N_FRAMES
if decode_options.get("language", None) is None:
if not model.is_multilingual:
decode_options["language"] = "en"
else:
if verbose:
print(
"Detecting language using up to the first 30 seconds. Use `--language` to specify the language"
)
mel_segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype)
_, probs = model.detect_language(mel_segment)
decode_options["language"] = max(probs, key=probs.get)
if verbose is not None:
print(
f"Detected language: {LANGUAGES[decode_options['language']].title()}"
)
language: str = decode_options["language"]
task: str = decode_options.get("task", "transcribe")
tokenizer = get_tokenizer(model.is_multilingual, language=language, task=task)
if word_timestamps and task == "translate":
warnings.warn("Word-level timestamps on translations may not be reliable.")
def decode_with_fallback(segment: torch.Tensor) -> DecodingResult:
temperatures = (
[temperature] if isinstance(temperature, (int, float)) else temperature
)
decode_result = None
for t in temperatures:
kwargs = {**decode_options}
if t > 0:
# disable beam_size and patience when t > 0
kwargs.pop("beam_size", None)
kwargs.pop("patience", None)
else:
# disable best_of when t == 0
kwargs.pop("best_of", None)
options = DecodingOptions(**kwargs, temperature=t)
decode_result = model.decode(segment, options)
needs_fallback = False
if (
compression_ratio_threshold is not None
and decode_result.compression_ratio > compression_ratio_threshold
):
needs_fallback = True # too repetitive
if (
logprob_threshold is not None
and decode_result.avg_logprob < logprob_threshold
):
needs_fallback = True # average log probability is too low
if not needs_fallback:
break
return decode_result
seek = 0
input_stride = exact_div(
N_FRAMES, model.dims.n_audio_ctx
) # mel frames per output token: 2
time_precision = (
input_stride * HOP_LENGTH / SAMPLE_RATE
) # time per output token: 0.02 (seconds)
all_tokens = []
all_segments = []
prompt_reset_since = 0
if initial_prompt is not None:
initial_prompt_tokens = tokenizer.encode(" " + initial_prompt.strip())
all_tokens.extend(initial_prompt_tokens)
else:
initial_prompt_tokens = []
def new_segment(
*, start: float, end: float, tokens: torch.Tensor, result: DecodingResult
def generate_segment_batched(
self,
features: np.ndarray,
tokenizer: Tokenizer,
options: TranscriptionOptions,
encoder_output=None,
):
tokens = tokens.tolist()
text_tokens = [token for token in tokens if token < tokenizer.eot]
return {
"seek": seek,
"start": start,
"end": end,
"text": tokenizer.decode(text_tokens),
"tokens": tokens,
"temperature": result.temperature,
"avg_logprob": result.avg_logprob,
"compression_ratio": result.compression_ratio,
"no_speech_prob": result.no_speech_prob,
}
batch_size = features.shape[0]
all_tokens = []
prompt_reset_since = 0
if options.initial_prompt is not None:
initial_prompt = " " + options.initial_prompt.strip()
initial_prompt_tokens = tokenizer.encode(initial_prompt)
all_tokens.extend(initial_prompt_tokens)
previous_tokens = all_tokens[prompt_reset_since:]
prompt = self.get_prompt(
tokenizer,
previous_tokens,
without_timestamps=options.without_timestamps,
prefix=options.prefix,
hotwords=options.hotwords
)
encoder_output = self.encode(features)
# show the progress bar when verbose is False (if True, transcribed text will be printed)
with tqdm.tqdm(
total=content_frames, unit="frames", disable=verbose is not False
) as pbar:
while seek < content_frames:
time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
mel_segment = mel[:, seek : seek + N_FRAMES]
segment_size = min(N_FRAMES, content_frames - seek)
segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE
mel_segment = pad_or_trim(mel_segment, N_FRAMES).to(model.device).to(dtype)
max_initial_timestamp_index = int(
round(options.max_initial_timestamp / self.time_precision)
)
decode_options["prompt"] = all_tokens[prompt_reset_since:]
result: DecodingResult = decode_with_fallback(mel_segment)
tokens = torch.tensor(result.tokens)
if no_speech_threshold is not None:
# no voice activity check
should_skip = result.no_speech_prob > no_speech_threshold
if (
logprob_threshold is not None
and result.avg_logprob > logprob_threshold
):
# don't skip if the logprob is high enough, despite the no_speech_prob
should_skip = False
result = self.model.generate(
encoder_output,
[prompt] * batch_size,
beam_size=options.beam_size,
patience=options.patience,
length_penalty=options.length_penalty,
max_length=self.max_length,
suppress_blank=options.suppress_blank,
suppress_tokens=options.suppress_tokens,
)
if should_skip:
seek += segment_size # fast-forward to the next segment boundary
continue
tokens_batch = [x.sequences_ids[0] for x in result]
previous_seek = seek
current_segments = []
def decode_batch(tokens: List[List[int]]) -> str:
res = []
for tk in tokens:
res.append([token for token in tk if token < tokenizer.eot])
# text_tokens = [token for token in tokens if token < self.eot]
return tokenizer.tokenizer.decode_batch(res)
timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin)
single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True]
text = decode_batch(tokens_batch)
consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0]
consecutive.add_(1)
if len(consecutive) > 0:
# if the output contains two consecutive timestamp tokens
slices = consecutive.tolist()
if single_timestamp_ending:
slices.append(len(tokens))
return text
last_slice = 0
for current_slice in slices:
sliced_tokens = tokens[last_slice:current_slice]
start_timestamp_pos = (
sliced_tokens[0].item() - tokenizer.timestamp_begin
)
end_timestamp_pos = (
sliced_tokens[-1].item() - tokenizer.timestamp_begin
)
# clamp end-time to at least be 1 frame after start-time
end_timestamp_pos = max(end_timestamp_pos, start_timestamp_pos + time_precision)
def encode(self, features: np.ndarray) -> ctranslate2.StorageView:
# When the model is running on multiple GPUs, the encoder output should be moved
# to the CPU since we don't know which GPU will handle the next job.
to_cpu = self.model.device == "cuda" and len(self.model.device_index) > 1
# unsqueeze if batch size = 1
if len(features.shape) == 2:
features = np.expand_dims(features, 0)
features = get_ctranslate2_storage(features)
current_segments.append(
new_segment(
start=time_offset + start_timestamp_pos * time_precision,
end=time_offset + end_timestamp_pos * time_precision,
tokens=sliced_tokens,
result=result,
)
)
last_slice = current_slice
return self.model.encode(features, to_cpu=to_cpu)
if single_timestamp_ending:
# single timestamp at the end means no speech after the last timestamp.
seek += segment_size
else:
# otherwise, ignore the unfinished segment and seek to the last timestamp
last_timestamp_pos = (
tokens[last_slice - 1].item() - tokenizer.timestamp_begin
)
seek += last_timestamp_pos * input_stride
class FasterWhisperPipeline(Pipeline):
"""
Huggingface Pipeline wrapper for FasterWhisperModel.
"""
# TODO:
# - add support for timestamp mode
# - add support for custom inference kwargs
def __init__(
self,
model: WhisperModel,
vad,
vad_params: dict,
options: TranscriptionOptions,
tokenizer: Optional[Tokenizer] = None,
device: Union[int, str, "torch.device"] = -1,
framework="pt",
language: Optional[str] = None,
suppress_numerals: bool = False,
**kwargs,
):
self.model = model
self.tokenizer = tokenizer
self.options = options
self.preset_language = language
self.suppress_numerals = suppress_numerals
self._batch_size = kwargs.pop("batch_size", None)
self._num_workers = 1
self._preprocess_params, self._forward_params, self._postprocess_params = self._sanitize_parameters(**kwargs)
self.call_count = 0
self.framework = framework
if self.framework == "pt":
if isinstance(device, torch.device):
self.device = device
elif isinstance(device, str):
self.device = torch.device(device)
elif device < 0:
self.device = torch.device("cpu")
else:
duration = segment_duration
timestamps = tokens[timestamp_tokens.nonzero().flatten()]
if (
len(timestamps) > 0
and timestamps[-1].item() != tokenizer.timestamp_begin
):
# no consecutive timestamps but it has a timestamp; use the last one.
last_timestamp_pos = (
timestamps[-1].item() - tokenizer.timestamp_begin
)
duration = last_timestamp_pos * time_precision
self.device = torch.device(f"cuda:{device}")
else:
self.device = device
current_segments.append(
new_segment(
start=time_offset,
end=time_offset + duration,
tokens=tokens,
result=result,
)
super(Pipeline, self).__init__()
self.vad_model = vad
self._vad_params = vad_params
def _sanitize_parameters(self, **kwargs):
preprocess_kwargs = {}
if "tokenizer" in kwargs:
preprocess_kwargs["maybe_arg"] = kwargs["maybe_arg"]
return preprocess_kwargs, {}, {}
def preprocess(self, audio):
audio = audio['inputs']
model_n_mels = self.model.feat_kwargs.get("feature_size")
features = log_mel_spectrogram(
audio,
n_mels=model_n_mels if model_n_mels is not None else 80,
padding=N_SAMPLES - audio.shape[0],
)
return {'inputs': features}
def _forward(self, model_inputs):
outputs = self.model.generate_segment_batched(model_inputs['inputs'], self.tokenizer, self.options)
return {'text': outputs}
def postprocess(self, model_outputs):
return model_outputs
def get_iterator(
self,
inputs,
num_workers: int,
batch_size: int,
preprocess_params: dict,
forward_params: dict,
postprocess_params: dict,
):
dataset = PipelineIterator(inputs, self.preprocess, preprocess_params)
if "TOKENIZERS_PARALLELISM" not in os.environ:
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# TODO hack by collating feature_extractor and image_processor
def stack(items):
return {'inputs': torch.stack([x['inputs'] for x in items])}
dataloader = torch.utils.data.DataLoader(dataset, num_workers=num_workers, batch_size=batch_size, collate_fn=stack)
model_iterator = PipelineIterator(dataloader, self.forward, forward_params, loader_batch_size=batch_size)
final_iterator = PipelineIterator(model_iterator, self.postprocess, postprocess_params)
return final_iterator
def transcribe(
self,
audio: Union[str, np.ndarray],
batch_size: Optional[int] = None,
num_workers=0,
language: Optional[str] = None,
task: Optional[str] = None,
chunk_size=30,
print_progress=False,
combined_progress=False,
verbose=False,
) -> TranscriptionResult:
if isinstance(audio, str):
audio = load_audio(audio)
def data(audio, segments):
for seg in segments:
f1 = int(seg['start'] * SAMPLE_RATE)
f2 = int(seg['end'] * SAMPLE_RATE)
# print(f2-f1)
yield {'inputs': audio[f1:f2]}
# Pre-process audio and merge chunks as defined by the respective VAD child class
# In case vad_model is manually assigned (see 'load_model') follow the functionality of pyannote toolkit
if issubclass(type(self.vad_model), Vad):
waveform = self.vad_model.preprocess_audio(audio)
merge_chunks = self.vad_model.merge_chunks
else:
waveform = Pyannote.preprocess_audio(audio)
merge_chunks = Pyannote.merge_chunks
vad_segments = self.vad_model({"waveform": waveform, "sample_rate": SAMPLE_RATE})
vad_segments = merge_chunks(
vad_segments,
chunk_size,
onset=self._vad_params["vad_onset"],
offset=self._vad_params["vad_offset"],
)
if self.tokenizer is None:
language = language or self.detect_language(audio)
task = task or "transcribe"
self.tokenizer = Tokenizer(
self.model.hf_tokenizer,
self.model.model.is_multilingual,
task=task,
language=language,
)
else:
language = language or self.tokenizer.language_code
task = task or self.tokenizer.task
if task != self.tokenizer.task or language != self.tokenizer.language_code:
self.tokenizer = Tokenizer(
self.model.hf_tokenizer,
self.model.model.is_multilingual,
task=task,
language=language,
)
seek += segment_size
if not condition_on_previous_text or result.temperature > 0.5:
# do not feed the prompt tokens if a high temperature was used
prompt_reset_since = len(all_tokens)
if word_timestamps:
add_word_timestamps(
segments=current_segments,
model=model,
tokenizer=tokenizer,
mel=mel_segment,
num_frames=segment_size,
prepend_punctuations=prepend_punctuations,
append_punctuations=append_punctuations,
)
word_end_timestamps = [
w["end"] for s in current_segments for w in s["words"]
]
if not single_timestamp_ending and len(word_end_timestamps) > 0:
seek_shift = round(
(word_end_timestamps[-1] - time_offset) * FRAMES_PER_SECOND
)
if seek_shift > 0:
seek = previous_seek + seek_shift
if self.suppress_numerals:
previous_suppress_tokens = self.options.suppress_tokens
numeral_symbol_tokens = find_numeral_symbol_tokens(self.tokenizer)
print(f"Suppressing numeral and symbol tokens")
new_suppressed_tokens = numeral_symbol_tokens + self.options.suppress_tokens
new_suppressed_tokens = list(set(new_suppressed_tokens))
self.options = replace(self.options, suppress_tokens=new_suppressed_tokens)
segments: List[SingleSegment] = []
batch_size = batch_size or self._batch_size
total_segments = len(vad_segments)
for idx, out in enumerate(self.__call__(data(audio, vad_segments), batch_size=batch_size, num_workers=num_workers)):
if print_progress:
base_progress = ((idx + 1) / total_segments) * 100
percent_complete = base_progress / 2 if combined_progress else base_progress
print(f"Progress: {percent_complete:.2f}%...")
text = out['text']
if batch_size in [0, 1, None]:
text = text[0]
if verbose:
for segment in current_segments:
start, end, text = segment["start"], segment["end"], segment["text"]
line = f"[{format_timestamp(start)} --> {format_timestamp(end)}] {text}"
print(make_safe(line))
# if a segment is instantaneous or does not contain text, clear it
for i, segment in enumerate(current_segments):
if segment["start"] == segment["end"] or segment["text"].strip() == "":
segment["text"] = ""
segment["tokens"] = []
segment["words"] = []
all_segments.extend(
[
{"id": i, **segment}
for i, segment in enumerate(
current_segments, start=len(all_segments)
)
]
)
all_tokens.extend(
[token for segment in current_segments for token in segment["tokens"]]
)
# update progress bar
pbar.update(min(content_frames, seek) - previous_seek)
return dict(
text=tokenizer.decode(all_tokens[len(initial_prompt_tokens) :]),
segments=all_segments,
language=language,
)
def transcribe_with_vad(
model: "Whisper",
audio: str,
vad_pipeline,
mel = None,
verbose: Optional[bool] = None,
**kwargs
):
"""
Transcribe per VAD segment
"""
vad_segments = vad_pipeline(audio)
# if not torch.is_tensor(audio):
# if isinstance(audio, str):
audio = load_audio(audio)
audio = torch.from_numpy(audio)
prev = 0
output = {"segments": []}
# merge segments to approx 30s inputs to make whisper most appropriate
vad_segments = merge_chunks(vad_segments, chunk_size=CHUNK_LENGTH)
if len(vad_segments) == 0:
return output
print(">>Performing transcription...")
for sdx, seg_t in enumerate(vad_segments):
if verbose:
print(f"~~ Transcribing VAD chunk: ({format_timestamp(seg_t['start'])} --> {format_timestamp(seg_t['end'])}) ~~")
seg_f_start, seg_f_end = int(seg_t["start"] * SAMPLE_RATE), int(seg_t["end"] * SAMPLE_RATE)
local_f_start, local_f_end = seg_f_start - prev, seg_f_end - prev
audio = audio[local_f_start:] # seek forward
seg_audio = audio[:local_f_end-local_f_start] # seek forward
prev = seg_f_start
local_mel = log_mel_spectrogram(seg_audio, padding=N_SAMPLES)
# need to pad
result = transcribe(model, audio, mel=local_mel, verbose=verbose, **kwargs)
seg_t["text"] = result["text"]
output["segments"].append(
{
"start": seg_t["start"],
"end": seg_t["end"],
"language": result["language"],
"text": result["text"],
"seg-text": [x["text"] for x in result["segments"]],
"seg-start": [x["start"] for x in result["segments"]],
"seg-end": [x["end"] for x in result["segments"]],
print(f"Transcript: [{round(vad_segments[idx]['start'], 3)} --> {round(vad_segments[idx]['end'], 3)}] {text}")
segments.append(
{
"text": text,
"start": round(vad_segments[idx]['start'], 3),
"end": round(vad_segments[idx]['end'], 3)
}
)
output["language"] = output["segments"][0]["language"]
# revert the tokenizer if multilingual inference is enabled
if self.preset_language is None:
self.tokenizer = None
return output
# revert suppressed tokens if suppress_numerals is enabled
if self.suppress_numerals:
self.options = replace(self.options, suppress_tokens=previous_suppress_tokens)
return {"segments": segments, "language": language}
def detect_language(self, audio: np.ndarray) -> str:
if audio.shape[0] < N_SAMPLES:
print("Warning: audio is shorter than 30s, language detection may be inaccurate.")
model_n_mels = self.model.feat_kwargs.get("feature_size")
segment = log_mel_spectrogram(audio[: N_SAMPLES],
n_mels=model_n_mels if model_n_mels is not None else 80,
padding=0 if audio.shape[0] >= N_SAMPLES else N_SAMPLES - audio.shape[0])
encoder_output = self.model.encode(segment)
results = self.model.model.detect_language(encoder_output)
language_token, language_probability = results[0][0]
language = language_token[2:-2]
print(f"Detected language: {language} ({language_probability:.2f}) in first 30s of audio...")
return language
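When no language is fixed at load time, FasterWhisperPipeline.transcribe falls back to this method; it can also be called directly. An illustrative sketch (not part of the diff; model size and audio path are placeholders, and load_model is the helper defined below):
from whisperx.asr import load_model
from whisperx.audio import load_audio
model = load_model("small", device="cpu", compute_type="int8")  # no language given, so the tokenizer stays None
audio = load_audio("audio.wav")
print(model.detect_language(audio))                             # prints and returns a code such as "en"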
def load_model(
whisper_arch: str,
device: str,
device_index=0,
compute_type="float16",
asr_options: Optional[dict] = None,
language: Optional[str] = None,
vad_model: Optional[Vad]= None,
vad_method: Optional[str] = "pyannote",
vad_options: Optional[dict] = None,
model: Optional[WhisperModel] = None,
task="transcribe",
download_root: Optional[str] = None,
local_files_only=False,
threads=4,
) -> FasterWhisperPipeline:
"""Load a Whisper model for inference.
Args:
whisper_arch - The name of the Whisper model to load.
device - The device to load the model on.
compute_type - The compute type to use for the model.
vad_method - The vad method to use. vad_model has higher priority if it is not None.
options - A dictionary of options to use for the model.
language - The language of the model. (use English for now)
model - The WhisperModel instance to use.
download_root - The root directory to download the model to.
local_files_only - If `True`, avoid downloading the file and return the path to the local cached file if it exists.
threads - The number of cpu threads to use per worker, e.g. will be multiplied by num workers.
Returns:
A Whisper pipeline.
"""
if whisper_arch.endswith(".en"):
language = "en"
model = model or WhisperModel(whisper_arch,
device=device,
device_index=device_index,
compute_type=compute_type,
download_root=download_root,
local_files_only=local_files_only,
cpu_threads=threads)
if language is not None:
tokenizer = Tokenizer(model.hf_tokenizer, model.model.is_multilingual, task=task, language=language)
else:
print("No language specified, language will be first be detected for each audio file (increases inference time).")
tokenizer = None
default_asr_options = {
"beam_size": 5,
"best_of": 5,
"patience": 1,
"length_penalty": 1,
"repetition_penalty": 1,
"no_repeat_ngram_size": 0,
"temperatures": [0.0, 0.2, 0.4, 0.6, 0.8, 1.0],
"compression_ratio_threshold": 2.4,
"log_prob_threshold": -1.0,
"no_speech_threshold": 0.6,
"condition_on_previous_text": False,
"prompt_reset_on_temperature": 0.5,
"initial_prompt": None,
"prefix": None,
"suppress_blank": True,
"suppress_tokens": [-1],
"without_timestamps": True,
"max_initial_timestamp": 0.0,
"word_timestamps": False,
"prepend_punctuations": "\"'“¿([{-",
"append_punctuations": "\"'.。,!?::”)]}、",
"multilingual": model.model.is_multilingual,
"suppress_numerals": False,
"max_new_tokens": None,
"clip_timestamps": None,
"hallucination_silence_threshold": None,
"hotwords": None,
}
if asr_options is not None:
default_asr_options.update(asr_options)
suppress_numerals = default_asr_options["suppress_numerals"]
del default_asr_options["suppress_numerals"]
default_asr_options = TranscriptionOptions(**default_asr_options)
default_vad_options = {
"chunk_size": 30, # needed by silero since binarization happens before merge_chunks
"vad_onset": 0.500,
"vad_offset": 0.363
}
if vad_options is not None:
default_vad_options.update(vad_options)
# Note: manually assigned vad_model has higher priority than vad_method!
if vad_model is not None:
print("Use manually assigned vad_model. vad_method is ignored.")
vad_model = vad_model
else:
if vad_method == "silero":
vad_model = Silero(**default_vad_options)
elif vad_method == "pyannote":
vad_model = Pyannote(torch.device(device), use_auth_token=None, **default_vad_options)
else:
raise ValueError(f"Invalid vad_method: {vad_method}")
return FasterWhisperPipeline(
model=model,
vad=vad_model,
options=default_asr_options,
tokenizer=tokenizer,
language=language,
suppress_numerals=suppress_numerals,
vad_params=default_vad_options,
)
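A minimal usage sketch of the pipeline assembled by load_model above (illustrative, not part of the diff; model size, device, and the audio path are placeholders):
from whisperx.asr import load_model
from whisperx.audio import load_audio
device = "cuda"
audio = load_audio("audio.wav")
model = load_model("large-v2", device=device, compute_type="float16", language="en")
result = model.transcribe(audio, batch_size=16)
for segment in result["segments"]:
    print(f"[{segment['start']:.2f} --> {segment['end']:.2f}] {segment['text']}")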


whisperx/audio.py Normal file

@@ -0,0 +1,159 @@
import os
import subprocess
from functools import lru_cache
from typing import Optional, Union
import numpy as np
import torch
import torch.nn.functional as F
from whisperx.utils import exact_div
# hard-coded audio hyperparameters
SAMPLE_RATE = 16000
N_FFT = 400
HOP_LENGTH = 160
CHUNK_LENGTH = 30
N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000 samples in a 30-second chunk
N_FRAMES = exact_div(N_SAMPLES, HOP_LENGTH) # 3000 frames in a mel spectrogram input
N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2 # the initial convolution has stride 2
FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH) # 10ms per audio frame
TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN) # 20ms per audio token
def load_audio(file: str, sr: int = SAMPLE_RATE) -> np.ndarray:
"""
Open an audio file and read as mono waveform, resampling as necessary
Parameters
----------
file: str
The audio file to open
sr: int
The sample rate to resample the audio if necessary
Returns
-------
A NumPy array containing the audio waveform, in float32 dtype.
"""
try:
# Launches a subprocess to decode audio while down-mixing and resampling as necessary.
# Requires the ffmpeg CLI to be installed.
cmd = [
"ffmpeg",
"-nostdin",
"-threads",
"0",
"-i",
file,
"-f",
"s16le",
"-ac",
"1",
"-acodec",
"pcm_s16le",
"-ar",
str(sr),
"-",
]
out = subprocess.run(cmd, capture_output=True, check=True).stdout
except subprocess.CalledProcessError as e:
raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e
return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0
def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
"""
Pad or trim the audio array to N_SAMPLES, as expected by the encoder.
"""
if torch.is_tensor(array):
if array.shape[axis] > length:
array = array.index_select(
dim=axis, index=torch.arange(length, device=array.device)
)
if array.shape[axis] < length:
pad_widths = [(0, 0)] * array.ndim
pad_widths[axis] = (0, length - array.shape[axis])
array = F.pad(array, [pad for sizes in pad_widths[::-1] for pad in sizes])
else:
if array.shape[axis] > length:
array = array.take(indices=range(length), axis=axis)
if array.shape[axis] < length:
pad_widths = [(0, 0)] * array.ndim
pad_widths[axis] = (0, length - array.shape[axis])
array = np.pad(array, pad_widths)
return array
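A quick illustrative check of pad_or_trim (not part of the diff):
import numpy as np
from whisperx.audio import N_SAMPLES, pad_or_trim
x = np.zeros(10 * 16000, dtype=np.float32)  # 10 seconds of silence at 16 kHz
print(pad_or_trim(x).shape)                 # (480000,): padded up to the 30-second window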
@lru_cache(maxsize=None)
def mel_filters(device, n_mels: int) -> torch.Tensor:
"""
load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
Allows decoupling librosa dependency; saved using:
np.savez_compressed(
"mel_filters.npz",
mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
)
"""
assert n_mels in [80, 128], f"Unsupported n_mels: {n_mels}"
with np.load(
os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz")
) as f:
return torch.from_numpy(f[f"mel_{n_mels}"]).to(device)
def log_mel_spectrogram(
audio: Union[str, np.ndarray, torch.Tensor],
n_mels: int,
padding: int = 0,
device: Optional[Union[str, torch.device]] = None,
):
"""
Compute the log-Mel spectrogram of the audio
Parameters
----------
audio: Union[str, np.ndarray, torch.Tensor], shape = (*)
The path to the audio file, or a NumPy array or Tensor containing the audio waveform sampled at 16 kHz
n_mels: int
The number of Mel-frequency filters; only 80 and 128 are supported
padding: int
Number of zero samples to pad to the right
device: Optional[Union[str, torch.device]]
If given, the audio tensor is moved to this device before STFT
Returns
-------
torch.Tensor, shape = (n_mels, n_frames)
A Tensor that contains the Mel spectrogram
"""
if not torch.is_tensor(audio):
if isinstance(audio, str):
audio = load_audio(audio)
audio = torch.from_numpy(audio)
if device is not None:
audio = audio.to(device)
if padding > 0:
audio = F.pad(audio, (0, padding))
window = torch.hann_window(N_FFT).to(audio.device)
stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True)
magnitudes = stft[..., :-1].abs() ** 2
filters = mel_filters(audio.device, n_mels)
mel_spec = filters @ magnitudes
log_spec = torch.clamp(mel_spec, min=1e-10).log10()
log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
log_spec = (log_spec + 4.0) / 4.0
return log_spec
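A short sketch of how these helpers fit together (illustrative, not part of the diff; assumes the ffmpeg CLI is installed and 'sample.wav' exists):
from whisperx.audio import N_SAMPLES, load_audio, log_mel_spectrogram
audio = load_audio("sample.wav")                                # float32 mono waveform at 16 kHz
mel = log_mel_spectrogram(audio, n_mels=80, padding=N_SAMPLES)  # pads 30 s of zeros for slicing
print(audio.shape, mel.shape)                                   # (n_samples,) and (80, n_frames)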

whisperx/conjunctions.py Normal file

@@ -0,0 +1,47 @@
# conjunctions.py
from typing import Set
conjunctions_by_language = {
'en': {'and', 'whether', 'or', 'as', 'but', 'so', 'for', 'nor', 'which', 'yet', 'although', 'since', 'unless', 'when', 'while', 'because', 'if', 'how', 'that', 'than', 'who', 'where', 'what', 'near', 'before', 'after', 'across', 'through', 'until', 'once', 'whereas', 'even', 'both', 'either', 'neither', 'though'},
'fr': {'et', 'ou', 'mais', 'parce', 'bien', 'pendant', 'quand', '', 'comme', 'si', 'que', 'avant', 'après', 'aussitôt', 'jusquà', 'à', 'malgré', 'donc', 'tant', 'puisque', 'ni', 'soit', 'bien', 'encore', 'dès', 'lorsque'},
'de': {'und', 'oder', 'aber', 'weil', 'obwohl', 'während', 'wenn', 'wo', 'wie', 'dass', 'bevor', 'nachdem', 'sobald', 'bis', 'außer', 'trotzdem', 'also', 'sowie', 'indem', 'weder', 'sowohl', 'zwar', 'jedoch'},
'es': {'y', 'o', 'pero', 'porque', 'aunque', 'sin', 'mientras', 'cuando', 'donde', 'como', 'si', 'que', 'antes', 'después', 'tan', 'hasta', 'a', 'a', 'por', 'ya', 'ni', 'sino'},
'it': {'e', 'o', 'ma', 'perché', 'anche', 'mentre', 'quando', 'dove', 'come', 'se', 'che', 'prima', 'dopo', 'appena', 'fino', 'a', 'nonostante', 'quindi', 'poiché', '', 'ossia', 'cioè'},
'ja': {'そして', 'または', 'しかし', 'なぜなら', 'もし', 'それとも', 'だから', 'それに', 'なのに', 'そのため', 'かつ', 'それゆえに', 'ならば', 'もしくは', 'ため'},
'zh': {'', '', '但是', '因为', '任何', '', '虽然', '而且', '所以', '如果', '除非', '尽管', '既然', '即使', '只要', '直到', '然后', '因此', '不但', '而是', '不过'},
'nl': {'en', 'of', 'maar', 'omdat', 'hoewel', 'terwijl', 'wanneer', 'waar', 'zoals', 'als', 'dat', 'voordat', 'nadat', 'zodra', 'totdat', 'tenzij', 'ondanks', 'dus', 'zowel', 'noch', 'echter', 'toch'},
'uk': {'та', 'або', 'але', 'тому', 'хоча', 'поки', 'бо', 'коли', 'де', 'як', 'якщо', 'що', 'перш', 'після', 'доки', 'незважаючи', 'тому', 'ані'},
'pt': {'e', 'ou', 'mas', 'porque', 'embora', 'enquanto', 'quando', 'onde', 'como', 'se', 'que', 'antes', 'depois', 'assim', 'até', 'a', 'apesar', 'portanto', '', 'pois', 'nem', 'senão'},
'ar': {'و', 'أو', 'لكن', 'لأن', 'مع', 'بينما', 'عندما', 'حيث', 'كما', 'إذا', 'الذي', 'قبل', 'بعد', 'فور', 'حتى', 'إلا', 'رغم', 'لذلك', 'بما'},
'cs': {'a', 'nebo', 'ale', 'protože', 'ačkoli', 'zatímco', 'když', 'kde', 'jako', 'pokud', 'že', 'než', 'poté', 'jakmile', 'dokud', 'pokud ne', 'navzdory', 'tak', 'stejně', 'ani', 'tudíž'},
'ru': {'и', 'или', 'но', 'потому', 'хотя', 'пока', 'когда', 'где', 'как', 'если', 'что', 'перед', 'после', 'несмотря', 'таким', 'также', 'ни', 'зато'},
'pl': {'i', 'lub', 'ale', 'ponieważ', 'chociaż', 'podczas', 'kiedy', 'gdzie', 'jak', 'jeśli', 'że', 'zanim', 'po', 'jak tylko', 'dopóki', 'chyba', 'pomimo', 'więc', 'tak', 'ani', 'czyli'},
'hu': {'és', 'vagy', 'de', 'mert', 'habár', 'míg', 'amikor', 'ahol', 'ahogy', 'ha', 'hogy', 'mielőtt', 'miután', 'amint', 'amíg', 'hacsak', 'ellenére', 'tehát', 'úgy', 'sem', 'vagyis'},
'fi': {'ja', 'tai', 'mutta', 'koska', 'vaikka', 'kun', 'missä', 'kuten', 'jos', 'että', 'ennen', 'sen jälkeen', 'heti', 'kunnes', 'ellei', 'huolimatta', 'siis', 'sekä', 'eikä', 'vaan'},
'fa': {'و', 'یا', 'اما', 'چون', 'اگرچه', 'در حالی', 'وقتی', 'کجا', 'چگونه', 'اگر', 'که', 'قبل', 'پس', 'به محض', 'تا زمانی', 'مگر', 'با وجود', 'پس', 'همچنین', 'نه'},
'el': {'και', 'ή', 'αλλά', 'επειδή', 'αν', 'ενώ', 'όταν', 'όπου', 'όπως', 'αν', 'που', 'προτού', 'αφού', 'μόλις', 'μέχρι', 'εκτός', 'παρά', 'έτσι', 'όπως', 'ούτε', 'δηλαδή'},
'tr': {'ve', 'veya', 'ama', 'çünkü', 'her ne', 'iken', 'nerede', 'nasıl', 'eğer', 'ki', 'önce', 'sonra', 'hemen', 'kadar', 'rağmen', 'hem', 'ne', 'yani'},
'da': {'og', 'eller', 'men', 'fordi', 'selvom', 'mens', 'når', 'hvor', 'som', 'hvis', 'at', 'før', 'efter', 'indtil', 'medmindre', 'således', 'ligesom', 'hverken', 'altså'},
'he': {'ו', 'או', 'אבל', 'כי', 'אף', 'בזמן', 'כאשר', 'היכן', 'כיצד', 'אם', 'ש', 'לפני', 'אחרי', 'ברגע', 'עד', 'אלא', 'למרות', 'לכן', 'כמו', 'לא', 'אז'},
'vi': {'', 'hoặc', 'nhưng', 'bởi', 'mặc', 'trong', 'khi', '', 'như', 'nếu', 'rằng', 'trước', 'sau', 'ngay', 'cho', 'trừ', 'mặc', '', 'giống', 'cũng', 'tức'},
'ko': {'그리고', '또는','그런데','그래도', '이나', '결국', '마지막으로', '마찬가지로', '반면에', '아니면', '거나', '또는', '그럼에도', '그렇기', '때문에', '덧붙이자면', '게다가', '그러나', '', '그래서', '', '한다면', '하지만', '무엇', '왜냐하면', '비록', '동안', '언제', '어디서', '어떻게', '만약', '', '전에', '후에', '즉시', '까지', '아니라면', '불구하고', '따라서', '같은', ''},
'ur': {'اور', 'یا', 'مگر', 'کیونکہ', 'اگرچہ', 'جبکہ', 'جب', 'کہاں', 'کس طرح', 'اگر', 'کہ', 'سے پہلے', 'کے بعد', 'جیسے ہی', 'تک', 'اگر نہیں تو', 'کے باوجود', 'اس لئے', 'جیسے', 'نہ'},
'hi': {'और', 'या', 'पर', 'तो', '', 'फिर', 'हालांकि', 'चूंकि', 'अगर', 'कैसे', 'वह', 'से', 'जो', 'जहां', 'क्या', 'नजदीक', 'पहले', 'बाद', 'के', 'पार', 'माध्यम', 'तक', 'एक', 'जबकि', 'यहां', 'तक', 'दोनों', 'या', '', 'हालांकि'}
}
commas_by_language = {
'ja': '、',
'zh': '，',
'fa': '،',
'ur': '،'
}
def get_conjunctions(lang_code: str) -> Set[str]:
return conjunctions_by_language.get(lang_code, set())
def get_comma(lang_code: str) -> str:
return commas_by_language.get(lang_code, ",")
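A quick usage sketch (illustrative, not part of the diff):
from whisperx.conjunctions import get_comma, get_conjunctions
print('because' in get_conjunctions('en'))  # True
print(get_conjunctions('xx'))               # set() for unknown language codes
print(get_comma('fa'))                      # '،'
print(get_comma('en'))                      # ',' (the default)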

whisperx/diarize.py

@@ -1,76 +1,86 @@
import numpy as np
import pandas as pd
from pyannote.audio import Pipeline
from typing import Optional, Union
import torch
from whisperx.audio import load_audio, SAMPLE_RATE
from whisperx.types import TranscriptionResult, AlignedTranscriptionResult
class DiarizationPipeline:
def __init__(
self,
model_name="pyannote/speaker-diarization@2.1",
model_name=None,
use_auth_token=None,
device: Optional[Union[str, torch.device]] = "cpu",
):
self.model = Pipeline.from_pretrained(model_name, use_auth_token=use_auth_token)
if isinstance(device, str):
device = torch.device(device)
model_config = model_name or "pyannote/speaker-diarization-3.1"
self.model = Pipeline.from_pretrained(model_config, use_auth_token=use_auth_token).to(device)
def __call__(self, audio, min_speakers=None, max_speakers=None):
segments = self.model(audio, min_speakers=min_speakers, max_speakers=max_speakers)
diarize_df = pd.DataFrame(segments.itertracks(yield_label=True))
diarize_df['start'] = diarize_df[0].apply(lambda x: x.start)
diarize_df['end'] = diarize_df[0].apply(lambda x: x.end)
def __call__(
self,
audio: Union[str, np.ndarray],
num_speakers: Optional[int] = None,
min_speakers: Optional[int] = None,
max_speakers: Optional[int] = None,
):
if isinstance(audio, str):
audio = load_audio(audio)
audio_data = {
'waveform': torch.from_numpy(audio[None, :]),
'sample_rate': SAMPLE_RATE
}
segments = self.model(audio_data, num_speakers = num_speakers, min_speakers=min_speakers, max_speakers=max_speakers)
diarize_df = pd.DataFrame(segments.itertracks(yield_label=True), columns=['segment', 'label', 'speaker'])
diarize_df['start'] = diarize_df['segment'].apply(lambda x: x.start)
diarize_df['end'] = diarize_df['segment'].apply(lambda x: x.end)
return diarize_df
def assign_word_speakers(diarize_df, result_segments, fill_nearest=False):
for seg in result_segments:
wdf = seg['word-segments']
if len(wdf['start'].dropna()) == 0:
wdf['start'] = seg['start']
wdf['end'] = seg['end']
speakers = []
for wdx, wrow in wdf.iterrows():
if not np.isnan(wrow['start']):
diarize_df['intersection'] = np.minimum(diarize_df['end'], wrow['end']) - np.maximum(diarize_df['start'], wrow['start'])
diarize_df['union'] = np.maximum(diarize_df['end'], wrow['end']) - np.minimum(diarize_df['start'], wrow['start'])
# remove no hit
if not fill_nearest:
dia_tmp = diarize_df[diarize_df['intersection'] > 0]
else:
dia_tmp = diarize_df
if len(dia_tmp) == 0:
speaker = None
else:
speaker = dia_tmp.sort_values("intersection", ascending=False).iloc[0][2]
else:
speaker = None
speakers.append(speaker)
seg['word-segments']['speaker'] = speakers
speaker_count = pd.Series(speakers).value_counts()
if len(speaker_count) == 0:
seg["speaker"]= "UNKNOWN"
def assign_word_speakers(
diarize_df: pd.DataFrame,
transcript_result: Union[AlignedTranscriptionResult, TranscriptionResult],
fill_nearest=False,
) -> dict:
transcript_segments = transcript_result["segments"]
for seg in transcript_segments:
# assign speaker to segment (if any)
diarize_df['intersection'] = np.minimum(diarize_df['end'], seg['end']) - np.maximum(diarize_df['start'], seg['start'])
diarize_df['union'] = np.maximum(diarize_df['end'], seg['end']) - np.minimum(diarize_df['start'], seg['start'])
# remove no hit, otherwise we look for closest (even negative intersection...)
if not fill_nearest:
dia_tmp = diarize_df[diarize_df['intersection'] > 0]
else:
seg["speaker"] = speaker_count.index[0]
dia_tmp = diarize_df
if len(dia_tmp) > 0:
# sum over speakers
speaker = dia_tmp.groupby("speaker")["intersection"].sum().sort_values(ascending=False).index[0]
seg["speaker"] = speaker
# assign speaker to words
if 'words' in seg:
for word in seg['words']:
if 'start' in word:
diarize_df['intersection'] = np.minimum(diarize_df['end'], word['end']) - np.maximum(diarize_df['start'], word['start'])
diarize_df['union'] = np.maximum(diarize_df['end'], word['end']) - np.minimum(diarize_df['start'], word['start'])
# remove no hit
if not fill_nearest:
dia_tmp = diarize_df[diarize_df['intersection'] > 0]
else:
dia_tmp = diarize_df
if len(dia_tmp) > 0:
# sum over speakers
speaker = dia_tmp.groupby("speaker")["intersection"].sum().sort_values(ascending=False).index[0]
word["speaker"] = speaker
return transcript_result
# create word level segments for .srt
word_seg = []
for seg in result_segments:
wseg = pd.DataFrame(seg["word-segments"])
for wdx, wrow in wseg.iterrows():
if wrow["start"] is not None:
speaker = wrow['speaker']
if speaker is None or speaker == np.nan:
speaker = "UNKNOWN"
word_seg.append(
{
"start": wrow["start"],
"end": wrow["end"],
"text": f"[{speaker}]: " + seg["text"][int(wrow["segment-text-start"]):int(wrow["segment-text-end"])]
}
)
# TODO: create segments but split words on new speaker
return result_segments, word_seg
class Segment:
def __init__(self, start, end, speaker=None):
def __init__(self, start:int, end:int, speaker:Optional[str]=None):
self.start = start
self.end = end
self.speaker = speaker
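A minimal usage sketch for the updated diarization API (illustrative, not part of the diff; the token, audio path, and speaker counts are placeholders, and in practice `result` is the output of a prior transcribe/align step):
from whisperx.audio import load_audio
from whisperx.diarize import DiarizationPipeline, assign_word_speakers
audio = load_audio("audio.wav")
# Placeholder stand-in for a real transcription result.
result = {"segments": [{"start": 0.0, "end": 2.0, "text": "hello"}], "language": "en"}
diarize_model = DiarizationPipeline(use_auth_token="YOUR_HF_TOKEN", device="cuda")
diarize_segments = diarize_model(audio, min_speakers=1, max_speakers=2)
result = assign_word_speakers(diarize_segments, result)  # adds "speaker" keys where overlap exists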

whisperx/transcribe.py

@@ -1,131 +1,82 @@
import argparse
import os
import gc
import os
import warnings
from typing import TYPE_CHECKING, Optional, Tuple, Union
import numpy as np
import torch
import tempfile
import ffmpeg
from whisper.tokenizer import LANGUAGES, TO_LANGUAGE_CODE
from whisper.audio import SAMPLE_RATE
from whisper.utils import (
optional_float,
optional_int,
str2bool,
)
from .alignment import load_align_model, align
from .asr import transcribe, transcribe_with_vad
from .diarize import DiarizationPipeline, assign_word_speakers
from .utils import get_writer
from .vad import load_vad_model
from whisperx.alignment import align, load_align_model
from whisperx.asr import load_model
from whisperx.audio import load_audio
from whisperx.diarize import DiarizationPipeline, assign_word_speakers
from whisperx.types import AlignedTranscriptionResult, TranscriptionResult
from whisperx.utils import LANGUAGES, TO_LANGUAGE_CODE, get_writer
def cli():
from whisper import available_models
def transcribe_task(args: dict, parser: argparse.ArgumentParser):
"""Transcription task to be called from CLI.
Args:
args: Dictionary of command-line arguments.
parser: argparse.ArgumentParser object.
"""
# fmt: off
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe")
parser.add_argument("--model", default="small", choices=available_models(), help="name of the Whisper model to use")
parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
parser.add_argument("--output_format", "-f", type=str, default="all", choices=["all", "srt", "srt-word", "vtt", "txt", "tsv", "ass", "ass-char", "pickle", "vad"], help="format of the output file; if not specified, all available formats will be produced")
parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")
parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')")
parser.add_argument("--language", type=str, default=None, choices=sorted(LANGUAGES.keys()) + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]), help="language spoken in the audio, specify None to perform language detection")
# alignment params
parser.add_argument("--align_model", default=None, help="Name of phoneme-level ASR model to do alignment")
parser.add_argument("--align_extend", default=2, type=float, help="Seconds before and after to extend the whisper segments for alignment (if not using VAD).")
parser.add_argument("--align_from_prev", default=True, type=bool, help="Whether to clip the alignment start time of current segment to the end time of the last aligned word of the previous segment (if not using VAD)")
parser.add_argument("--interpolate_method", default="nearest", choices=["nearest", "linear", "ignore"], help="For word .srt, method to assign timestamps to non-aligned words, or merge them into neighbouring.")
parser.add_argument("--no_align", action='store_true', help="Do not perform phoneme alignment")
# vad params
parser.add_argument("--vad_filter", type=str2bool, default=True, help="Whether to pre-segment audio with VAD, highly recommended! Produces more accurate alignment + timestamp see WhisperX paper https://arxiv.org/abs/2303.00747")
parser.add_argument("--vad_onset", type=float, default=0.500, help="Onset threshold for VAD (see pyannote.audio), reduce this if speech is not being detected")
parser.add_argument("--vad_offset", type=float, default=0.363, help="Offset threshold for VAD (see pyannote.audio), reduce this if speech is not being detected.")
# diarization params
parser.add_argument("--diarize", action="store_true", help="Apply diarization to assign speaker labels to each segment/word")
parser.add_argument("--min_speakers", default=None, type=int)
parser.add_argument("--max_speakers", default=None, type=int)
parser.add_argument("--temperature", type=float, default=0, help="temperature to use for sampling")
parser.add_argument("--best_of", type=optional_int, default=5, help="number of candidates when sampling with non-zero temperature")
parser.add_argument("--beam_size", type=optional_int, default=5, help="number of beams in beam search, only applicable when temperature is zero")
parser.add_argument("--patience", type=float, default=None, help="optional patience value to use in beam decoding, as in https://arxiv.org/abs/2204.05424, the default (1.0) is equivalent to conventional beam search")
parser.add_argument("--length_penalty", type=float, default=None, help="optional token length penalty coefficient (alpha) as in https://arxiv.org/abs/1609.08144, uses simple length normalization by default")
parser.add_argument("--suppress_tokens", type=str, default="-1", help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations")
parser.add_argument("--initial_prompt", type=str, default=None, help="optional text to provide as a prompt for the first window.")
parser.add_argument("--condition_on_previous_text", type=str2bool, default=False, help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop")
parser.add_argument("--fp16", type=str2bool, default=True, help="whether to perform inference in fp16; True by default")
parser.add_argument("--temperature_increment_on_fallback", type=optional_float, default=0.2, help="temperature to increase when falling back when the decoding fails to meet either of the thresholds below")
parser.add_argument("--compression_ratio_threshold", type=optional_float, default=2.4, help="if the gzip compression ratio is higher than this value, treat the decoding as failed")
parser.add_argument("--logprob_threshold", type=optional_float, default=-1.0, help="if the average log probability is lower than this value, treat the decoding as failed")
parser.add_argument("--no_speech_threshold", type=optional_float, default=0.6, help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence")
parser.add_argument("--word_timestamps", type=str2bool, default=False, help="(experimental) extract word-level timestamps and refine the results based on them")
parser.add_argument("--prepend_punctuations", type=str, default="\"\'“¿([{-", help="if word_timestamps is True, merge these punctuation symbols with the next word")
parser.add_argument("--append_punctuations", type=str, default="\"\'.。,!?::”)]}、", help="if word_timestamps is True, merge these punctuation symbols with the previous word")
parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
parser.add_argument("--hf_token", type=str, default=None, help="Hugging Face Access Token to access PyAnnote gated models")
# parser.add_argument("--model_flush", action="store_true", help="Flush memory from each model after use, reduces GPU requirement but slower processing >1 audio file.")
parser.add_argument("--tmp_dir", default=None, help="Temporary directory to write audio file if input if not .wav format (only for VAD).")
# fmt: on
args = parser.parse_args().__dict__
model_name: str = args.pop("model")
batch_size: int = args.pop("batch_size")
model_dir: str = args.pop("model_dir")
model_cache_only: bool = args.pop("model_cache_only")
output_dir: str = args.pop("output_dir")
output_format: str = args.pop("output_format")
device: str = args.pop("device")
device_index: int = args.pop("device_index")
compute_type: str = args.pop("compute_type")
verbose: bool = args.pop("verbose")
# model_flush: bool = args.pop("model_flush")
os.makedirs(output_dir, exist_ok=True)
tmp_dir: str = args.pop("tmp_dir")
if tmp_dir is not None:
os.makedirs(tmp_dir, exist_ok=True)
align_model: str = args.pop("align_model")
align_extend: float = args.pop("align_extend")
align_from_prev: bool = args.pop("align_from_prev")
interpolate_method: str = args.pop("interpolate_method")
no_align: bool = args.pop("no_align")
task: str = args.pop("task")
if task == "translate":
# translation cannot be aligned
no_align = True
return_char_alignments: bool = args.pop("return_char_alignments")
hf_token: str = args.pop("hf_token")
vad_filter: bool = args.pop("vad_filter")
vad_method: str = args.pop("vad_method")
vad_onset: float = args.pop("vad_onset")
vad_offset: float = args.pop("vad_offset")
chunk_size: int = args.pop("chunk_size")
diarize: bool = args.pop("diarize")
min_speakers: int = args.pop("min_speakers")
max_speakers: int = args.pop("max_speakers")
diarize_model_name: str = args.pop("diarize_model")
print_progress: bool = args.pop("print_progress")
if vad_filter:
from pyannote.audio import Pipeline
from pyannote.audio import Model, Pipeline
vad_model = load_vad_model(torch.device(device), vad_onset, vad_offset, use_auth_token=hf_token)
else:
vad_model = None
if args["language"] is not None:
args["language"] = args["language"].lower()
if args["language"] not in LANGUAGES:
if args["language"] in TO_LANGUAGE_CODE:
args["language"] = TO_LANGUAGE_CODE[args["language"]]
else:
raise ValueError(f"Unsupported language: {args['language']}")
# if model_flush:
# print(">>Model flushing activated... Only loading model after ASR stage")
# del align_model
# align_model = ""
if model_name.endswith(".en") and args["language"] not in {"en", "English"}:
if model_name.endswith(".en") and args["language"] != "en":
if args["language"] is not None:
warnings.warn(
f"{model_name} is an English-only model but receipted '{args['language']}'; using English instead."
f"{model_name} is an English-only model but received '{args['language']}'; using English instead."
)
args["language"] = "en"
align_language = (
args["language"] if args["language"] is not None else "en"
) # default to loading english if not specified
temperature = args.pop("temperature")
if (increment := args.pop("temperature_increment_on_fallback")) is not None:
@@ -133,42 +84,73 @@ def cli():
else:
temperature = [temperature]
faster_whisper_threads = 4
if (threads := args.pop("threads")) > 0:
torch.set_num_threads(threads)
faster_whisper_threads = threads
from whisper import load_model
asr_options = {
"beam_size": args.pop("beam_size"),
"patience": args.pop("patience"),
"length_penalty": args.pop("length_penalty"),
"temperatures": temperature,
"compression_ratio_threshold": args.pop("compression_ratio_threshold"),
"log_prob_threshold": args.pop("logprob_threshold"),
"no_speech_threshold": args.pop("no_speech_threshold"),
"condition_on_previous_text": False,
"initial_prompt": args.pop("initial_prompt"),
"suppress_tokens": [int(x) for x in args.pop("suppress_tokens").split(",")],
"suppress_numerals": args.pop("suppress_numerals"),
}
writer = get_writer(output_format, output_dir)
word_options = ["highlight_words", "max_line_count", "max_line_width"]
if no_align:
for option in word_options:
if args[option]:
parser.error(f"--{option} not possible with --no_align")
if args["max_line_count"] and not args["max_line_width"]:
warnings.warn("--max_line_count has no effect without --max_line_width")
writer_args = {arg: args.pop(arg) for arg in word_options}
# Part 1: VAD & ASR Loop
results = []
tmp_results = []
model = load_model(model_name, device=device, download_root=model_dir)
for audio_path in args.pop("audio"):
input_audio_path = audio_path
tfile = None
# model = load_model(model_name, device=device, download_root=model_dir)
model = load_model(
model_name,
device=device,
device_index=device_index,
download_root=model_dir,
compute_type=compute_type,
language=args["language"],
asr_options=asr_options,
vad_method=vad_method,
vad_options={
"chunk_size": chunk_size,
"vad_onset": vad_onset,
"vad_offset": vad_offset,
},
task=task,
local_files_only=model_cache_only,
threads=faster_whisper_threads,
)
for audio_path in args.pop("audio"):
audio = load_audio(audio_path)
# >> VAD & ASR
if vad_model is not None:
if not audio_path.endswith(".wav"):
print(">>VAD requires .wav format, converting to wav as a tempfile...")
audio_basename = os.path.splitext(os.path.basename(audio_path))[0]
if tmp_dir is not None:
input_audio_path = os.path.join(tmp_dir, audio_basename + ".wav")
else:
input_audio_path = os.path.join(os.path.dirname(audio_path), audio_basename + ".wav")
ffmpeg.input(audio_path, threads=0).output(input_audio_path, ac=1, ar=SAMPLE_RATE).run(cmd=["ffmpeg"])
print(">>Performing VAD...")
result = transcribe_with_vad(model, input_audio_path, vad_model, temperature=temperature, **args)
else:
print(">>Performing transcription...")
result = transcribe(model, input_audio_path, temperature=temperature, **args)
results.append((result, input_audio_path))
print(">>Performing transcription...")
result: TranscriptionResult = model.transcribe(
audio,
batch_size=batch_size,
chunk_size=chunk_size,
print_progress=print_progress,
verbose=verbose,
)
results.append((result, audio_path))
# Unload Whisper and VAD
del model
del vad_model
gc.collect()
torch.cuda.empty_cache()
@@ -176,19 +158,39 @@ def cli():
if not no_align:
tmp_results = results
results = []
align_language = args["language"] if args["language"] is not None else "en" # default to loading english if not specified
align_model, align_metadata = load_align_model(align_language, device, model_name=align_model)
for result, input_audio_path in tmp_results:
align_model, align_metadata = load_align_model(
align_language, device, model_name=align_model
)
for result, audio_path in tmp_results:
# >> Align
if len(tmp_results) > 1:
input_audio = audio_path
else:
# lazily load audio from part 1
input_audio = audio
if align_model is not None and len(result["segments"]) > 0:
if result.get("language", "en") != align_metadata["language"]:
# load new language
print(f"New language found ({result['language']})! Previous was ({align_metadata['language']}), loading new alignment model for new language...")
align_model, align_metadata = load_align_model(result["language"], device)
print(
f"New language found ({result['language']})! Previous was ({align_metadata['language']}), loading new alignment model for new language..."
)
align_model, align_metadata = load_align_model(
result["language"], device
)
print(">>Performing alignment...")
result = align(result["segments"], align_model, align_metadata, input_audio_path, device,
extend_duration=align_extend, start_from_previous=align_from_prev, interpolate_method=interpolate_method)
results.append((result, input_audio_path))
result: AlignedTranscriptionResult = align(
result["segments"],
align_model,
align_metadata,
input_audio,
device,
interpolate_method=interpolate_method,
return_char_alignments=return_char_alignments,
print_progress=print_progress,
)
results.append((result, audio_path))
# Unload align model
del align_model
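For reference, the alignment calls above correspond to this library-level usage (illustrative sketch, not part of the diff; `result` and `audio` are assumed to come from the ASR loop, and the language code and device are placeholders):
from whisperx.alignment import align, load_align_model
align_model, align_metadata = load_align_model("en", "cuda")
aligned = align(
    result["segments"], align_model, align_metadata, audio, "cuda",
    interpolate_method="nearest",
    return_char_alignments=False,
    print_progress=False,
)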
@@ -198,23 +200,21 @@ def cli():
# >> Diarize
if diarize:
if hf_token is None:
print("Warning, no --hf_token used, needs to be saved in environment variable, otherwise will throw error loading diarization model...")
print(
"Warning, no --hf_token used, needs to be saved in environment variable, otherwise will throw error loading diarization model..."
)
tmp_results = results
print(">>Performing diarization...")
print(">>Using model:", diarize_model_name)
results = []
diarize_model = DiarizationPipeline(use_auth_token=hf_token)
diarize_model = DiarizationPipeline(model_name=diarize_model_name, use_auth_token=hf_token, device=device)
for result, input_audio_path in tmp_results:
diarize_segments = diarize_model(input_audio_path, min_speakers=min_speakers, max_speakers=max_speakers)
results_segments, word_segments = assign_word_speakers(diarize_segments, result["segments"])
result = {"segments": results_segments, "word_segments": word_segments}
diarize_segments = diarize_model(
input_audio_path, min_speakers=min_speakers, max_speakers=max_speakers
)
result = assign_word_speakers(diarize_segments, result)
results.append((result, input_audio_path))
# >> Write
for result, audio_path in results:
writer(result, audio_path)
# cleanup
if input_audio_path != audio_path:
os.remove(input_audio_path)
if __name__ == "__main__":
cli()
result["language"] = align_language
writer(result, audio_path, writer_args)

whisperx/types.py Normal file

@@ -0,0 +1,69 @@
from typing import TypedDict, Optional, List, Tuple
class SingleWordSegment(TypedDict):
"""
A single word of a speech.
"""
word: str
start: float
end: float
score: float
class SingleCharSegment(TypedDict):
"""
A single char of a speech.
"""
char: str
start: float
end: float
score: float
class SingleSegment(TypedDict):
"""
A single segment (up to multiple sentences) of a speech.
"""
start: float
end: float
text: str
class SegmentData(TypedDict):
"""
Temporary processing data used during alignment.
Contains cleaned and preprocessed data for each segment.
"""
clean_char: List[str] # Cleaned characters that exist in model dictionary
clean_cdx: List[int] # Original indices of cleaned characters
clean_wdx: List[int] # Indices of words containing valid characters
sentence_spans: List[Tuple[int, int]] # Start and end indices of sentences
class SingleAlignedSegment(TypedDict):
"""
A single segment (up to multiple sentences) of a speech with word alignment.
"""
start: float
end: float
text: str
words: List[SingleWordSegment]
chars: Optional[List[SingleCharSegment]]
class TranscriptionResult(TypedDict):
"""
A list of segments and word segments of a speech.
"""
segments: List[SingleSegment]
language: str
class AlignedTranscriptionResult(TypedDict):
"""
A list of segments and word segments of a speech.
"""
segments: List[SingleAlignedSegment]
word_segments: List[SingleWordSegment]
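A small sketch showing that these TypedDicts are plain dicts at runtime (illustrative, not part of the diff):
from whisperx.types import SingleSegment, SingleWordSegment, TranscriptionResult
seg: SingleSegment = {"start": 0.0, "end": 2.5, "text": "hello world"}
word: SingleWordSegment = {"word": "hello", "start": 0.0, "end": 0.8, "score": 0.97}
result: TranscriptionResult = {"segments": [seg], "language": "en"}
print(result["segments"][0]["text"])  # "hello world"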

whisperx/utils.py

@@ -1,279 +1,332 @@
import json
import os
import re
import sys
import zlib
from typing import Callable, TextIO, Iterator, Tuple
import pandas as pd
import numpy as np
from typing import Callable, Optional, TextIO
def interpolate_nans(x, method='nearest'):
if x.notnull().sum() > 1:
return x.interpolate(method=method).ffill().bfill()
LANGUAGES = {
"en": "english",
"zh": "chinese",
"de": "german",
"es": "spanish",
"ru": "russian",
"ko": "korean",
"fr": "french",
"ja": "japanese",
"pt": "portuguese",
"tr": "turkish",
"pl": "polish",
"ca": "catalan",
"nl": "dutch",
"ar": "arabic",
"sv": "swedish",
"it": "italian",
"id": "indonesian",
"hi": "hindi",
"fi": "finnish",
"vi": "vietnamese",
"he": "hebrew",
"uk": "ukrainian",
"el": "greek",
"ms": "malay",
"cs": "czech",
"ro": "romanian",
"da": "danish",
"hu": "hungarian",
"ta": "tamil",
"no": "norwegian",
"th": "thai",
"ur": "urdu",
"hr": "croatian",
"bg": "bulgarian",
"lt": "lithuanian",
"la": "latin",
"mi": "maori",
"ml": "malayalam",
"cy": "welsh",
"sk": "slovak",
"te": "telugu",
"fa": "persian",
"lv": "latvian",
"bn": "bengali",
"sr": "serbian",
"az": "azerbaijani",
"sl": "slovenian",
"kn": "kannada",
"et": "estonian",
"mk": "macedonian",
"br": "breton",
"eu": "basque",
"is": "icelandic",
"hy": "armenian",
"ne": "nepali",
"mn": "mongolian",
"bs": "bosnian",
"kk": "kazakh",
"sq": "albanian",
"sw": "swahili",
"gl": "galician",
"mr": "marathi",
"pa": "punjabi",
"si": "sinhala",
"km": "khmer",
"sn": "shona",
"yo": "yoruba",
"so": "somali",
"af": "afrikaans",
"oc": "occitan",
"ka": "georgian",
"be": "belarusian",
"tg": "tajik",
"sd": "sindhi",
"gu": "gujarati",
"am": "amharic",
"yi": "yiddish",
"lo": "lao",
"uz": "uzbek",
"fo": "faroese",
"ht": "haitian creole",
"ps": "pashto",
"tk": "turkmen",
"nn": "nynorsk",
"mt": "maltese",
"sa": "sanskrit",
"lb": "luxembourgish",
"my": "myanmar",
"bo": "tibetan",
"tl": "tagalog",
"mg": "malagasy",
"as": "assamese",
"tt": "tatar",
"haw": "hawaiian",
"ln": "lingala",
"ha": "hausa",
"ba": "bashkir",
"jw": "javanese",
"su": "sundanese",
"yue": "cantonese",
}
# language code lookup by name, with a few language aliases
TO_LANGUAGE_CODE = {
**{language: code for code, language in LANGUAGES.items()},
"burmese": "my",
"valencian": "ca",
"flemish": "nl",
"haitian": "ht",
"letzeburgesch": "lb",
"pushto": "ps",
"panjabi": "pa",
"moldavian": "ro",
"moldovan": "ro",
"sinhalese": "si",
"castilian": "es",
}
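LANGUAGES and TO_LANGUAGE_CODE give a round-trip between codes and names, with the extra aliases mapping onto the canonical codes. A quick sketch, assuming the module path whisperx.utils as in the file above:
from whisperx.utils import LANGUAGES, TO_LANGUAGE_CODE

assert LANGUAGES["yue"] == "cantonese"
assert TO_LANGUAGE_CODE["castilian"] == "es"        # alias resolves to the canonical code
assert TO_LANGUAGE_CODE[LANGUAGES["pt"]] == "pt"    # name -> code inverts code -> name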
LANGUAGES_WITHOUT_SPACES = ["ja", "zh"]
system_encoding = sys.getdefaultencoding()
if system_encoding != "utf-8":
def make_safe(string):
# replaces any character not representable using the system default encoding with a '?',
# avoiding UnicodeEncodeError (https://github.com/openai/whisper/discussions/729).
return string.encode(system_encoding, errors="replace").decode(system_encoding)
else:
def make_safe(string):
# utf-8 can encode any Unicode code point, so no need to do the round-trip encoding
return string
def exact_div(x, y):
assert x % y == 0
return x // y
def str2bool(string):
str2val = {"True": True, "False": False}
if string in str2val:
return str2val[string]
else:
return x.ffill().bfill()
def write_txt(transcript: Iterator[dict], file: TextIO):
for segment in transcript:
print(segment['text'].strip(), file=file, flush=True)
raise ValueError(f"Expected one of {set(str2val.keys())}, got {string}")
def write_vtt(transcript: Iterator[dict], file: TextIO):
print("WEBVTT\n", file=file)
for segment in transcript:
print(
f"{format_timestamp(segment['start'])} --> {format_timestamp(segment['end'])}\n"
f"{segment['text'].strip().replace('-->', '->')}\n",
file=file,
flush=True,
def optional_int(string):
return None if string == "None" else int(string)
def optional_float(string):
return None if string == "None" else float(string)
def compression_ratio(text) -> float:
text_bytes = text.encode("utf-8")
return len(text_bytes) / len(zlib.compress(text_bytes))
def format_timestamp(
seconds: float, always_include_hours: bool = False, decimal_marker: str = "."
):
assert seconds >= 0, "non-negative timestamp expected"
milliseconds = round(seconds * 1000.0)
hours = milliseconds // 3_600_000
milliseconds -= hours * 3_600_000
minutes = milliseconds // 60_000
milliseconds -= minutes * 60_000
seconds = milliseconds // 1_000
milliseconds -= seconds * 1_000
hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
return (
f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
)
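format_timestamp renders seconds as [hh:]mm:ss<marker>mmm; for example (expected outputs shown in comments):
from whisperx.utils import format_timestamp

print(format_timestamp(3.5))                    # 00:03.500
print(format_timestamp(3661.007,
                       always_include_hours=True,
                       decimal_marker=","))     # 01:01:01,007 (SRT-style)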
class ResultWriter:
extension: str
def __init__(self, output_dir: str):
self.output_dir = output_dir
def __call__(self, result: dict, audio_path: str, options: dict):
audio_basename = os.path.basename(audio_path)
audio_basename = os.path.splitext(audio_basename)[0]
output_path = os.path.join(
self.output_dir, audio_basename + "." + self.extension
)
def write_tsv(transcript: Iterator[dict], file: TextIO):
print("start", "end", "text", sep="\t", file=file)
for segment in transcript:
print(segment['start'], file=file, end="\t")
print(segment['end'], file=file, end="\t")
print(segment['text'].strip().replace("\t", " "), file=file, flush=True)
with open(output_path, "w", encoding="utf-8") as f:
self.write_result(result, file=f, options=options)
def write_result(self, result: dict, file: TextIO, options: dict):
raise NotImplementedError
def write_srt(transcript: Iterator[dict], file: TextIO):
"""
Write a transcript to a file in SRT format.
class WriteTXT(ResultWriter):
extension: str = "txt"
Example usage:
from pathlib import Path
from whisper.utils import write_srt
result = transcribe(model, audio_path, temperature=temperature, **args)
# save SRT
audio_basename = Path(audio_path).stem
with open(Path(output_dir) / (audio_basename + ".srt"), "w", encoding="utf-8") as srt:
write_srt(result["segments"], file=srt)
"""
for i, segment in enumerate(transcript, start=1):
# write srt lines
print(
f"{i}\n"
f"{format_timestamp(segment['start'], always_include_hours=True, decimal_marker=',')} --> "
f"{format_timestamp(segment['end'], always_include_hours=True, decimal_marker=',')}\n"
f"{segment['text'].strip().replace('-->', '->')}\n",
file=file,
flush=True,
)
def write_ass(transcript: Iterator[dict],
file: TextIO,
resolution: str = "word",
color: str = None, underline=True,
prefmt: str = None, suffmt: str = None,
font: str = None, font_size: int = 24,
strip=True, **kwargs):
"""
Credit: https://github.com/jianfch/stable-ts/blob/ff79549bd01f764427879f07ecd626c46a9a430a/stable_whisper/text_output.py
Generate Advanced SubStation Alpha (ass) file from results to
display both phrase-level & word-level timestamp simultaneously by:
-using segment-level timestamps display phrases as usual
-using word-level timestamps change formats (e.g. color/underline) of the word in the displayed segment
Note: ass file is used in the same way as srt, vtt, etc.
Parameters
----------
transcript: dict
results from modified model
file: TextIO
file object to write to
resolution: str
"word" or "char", timestamp resolution to highlight.
color: str
color code for a word at its corresponding timestamp
<bbggrr> reverse order hexadecimal RGB value (e.g. FF0000 is full intensity blue. Default: 00FF00)
underline: bool
whether to underline a word at its corresponding timestamp
prefmt: str
used to specify format for word-level timestamps (must be used with 'suffmt' and overrides 'color' & 'underline')
appears as such in the .ass file:
Hi, {<prefmt>}how{<suffmt>} are you?
reference [Appendix A: Style override codes] in http://www.tcax.org/docs/ass-specs.htm
suffmt: str
used to specify format for word-level timestamps (must be used with 'prefmt' and overrides 'color' & 'underline')
appears as such in the .ass file:
Hi, {<prefmt>}how{<suffmt>} are you?
reference [Appendix A: Style override codes] in http://www.tcax.org/docs/ass-specs.htm
font: str
word font (default: Arial)
font_size: int
word font size (default: 48)
kwargs:
used for format styles:
'Name', 'Fontname', 'Fontsize', 'PrimaryColour', 'SecondaryColour', 'OutlineColour', 'BackColour', 'Bold',
'Italic', 'Underline', 'StrikeOut', 'ScaleX', 'ScaleY', 'Spacing', 'Angle', 'BorderStyle', 'Outline',
'Shadow', 'Alignment', 'MarginL', 'MarginR', 'MarginV', 'Encoding'
"""
fmt_style_dict = {'Name': 'Default', 'Fontname': 'Arial', 'Fontsize': '48', 'PrimaryColour': '&Hffffff',
'SecondaryColour': '&Hffffff', 'OutlineColour': '&H0', 'BackColour': '&H0', 'Bold': '0',
'Italic': '0', 'Underline': '0', 'StrikeOut': '0', 'ScaleX': '100', 'ScaleY': '100',
'Spacing': '0', 'Angle': '0', 'BorderStyle': '1', 'Outline': '1', 'Shadow': '0',
'Alignment': '2', 'MarginL': '10', 'MarginR': '10', 'MarginV': '10', 'Encoding': '0'}
for k, v in filter(lambda x: 'colour' in x[0].lower() and not str(x[1]).startswith('&H'), kwargs.items()):
kwargs[k] = f'&H{kwargs[k]}'
fmt_style_dict.update((k, v) for k, v in kwargs.items() if k in fmt_style_dict)
if font:
fmt_style_dict.update(Fontname=font)
if font_size:
fmt_style_dict.update(Fontsize=font_size)
fmts = f'Format: {", ".join(map(str, fmt_style_dict.keys()))}'
styles = f'Style: {",".join(map(str, fmt_style_dict.values()))}'
ass_str = f'[Script Info]\nScriptType: v4.00+\nPlayResX: 384\nPlayResY: 288\nScaledBorderAndShadow: yes\n\n' \
f'[V4+ Styles]\n{fmts}\n{styles}\n\n' \
f'[Events]\nFormat: Layer, Start, End, Style, Name, MarginL, MarginR, MarginV, Effect, Text\n\n'
if prefmt or suffmt:
if suffmt:
assert prefmt, 'prefmt must be used along with suffmt'
else:
suffmt = r'\r'
else:
if not color:
color = 'HFF00'
underline_code = r'\u1' if underline else ''
prefmt = r'{\1c&' + f'{color.upper()}&{underline_code}' + '}'
suffmt = r'{\r}'
def secs_to_hhmmss(secs: Tuple[float, int]):
mm, ss = divmod(secs, 60)
hh, mm = divmod(mm, 60)
return f'{hh:0>1.0f}:{mm:0>2.0f}:{ss:0>2.2f}'
def dialogue(chars: str, start: float, end: float, idx_0: int, idx_1: int) -> str:
if idx_0 == -1:
text = chars
else:
text = f'{chars[:idx_0]}{prefmt}{chars[idx_0:idx_1]}{suffmt}{chars[idx_1:]}'
return f"Dialogue: 0,{secs_to_hhmmss(start)},{secs_to_hhmmss(end)}," \
f"Default,,0,0,0,,{text.strip() if strip else text}"
if resolution == "word":
resolution_key = "word-segments"
elif resolution == "char":
resolution_key = "char-segments"
else:
raise ValueError(".ass resolution should be 'word' or 'char', not ", resolution)
ass_arr = []
for segment in transcript:
# if "12" in segment['text']:
# import pdb; pdb.set_trace()
if resolution_key in segment:
res_segs = pd.DataFrame(segment[resolution_key])
prev = segment['start']
if "speaker" in segment:
speaker_str = f"[{segment['speaker']}]: "
def write_result(self, result: dict, file: TextIO, options: dict):
for segment in result["segments"]:
speaker = segment.get("speaker")
text = segment["text"].strip()
if speaker is not None:
print(f"[{speaker}]: {text}", file=file, flush=True)
else:
speaker_str = ""
for cdx, crow in res_segs.iterrows():
if not np.isnan(crow['start']):
if resolution == "char":
idx_0 = cdx
idx_1 = cdx + 1
elif resolution == "word":
idx_0 = int(crow["segment-text-start"])
idx_1 = int(crow["segment-text-end"])
# fill gap
if crow['start'] > prev:
filler_ts = {
"chars": speaker_str + segment['text'],
"start": prev,
"end": crow['start'],
"idx_0": -1,
"idx_1": -1
}
ass_arr.append(filler_ts)
# highlight current word
f_word_ts = {
"chars": speaker_str + segment['text'],
"start": crow['start'],
"end": crow['end'],
"idx_0": idx_0 + len(speaker_str),
"idx_1": idx_1 + len(speaker_str)
}
ass_arr.append(f_word_ts)
prev = crow['end']
ass_str += '\n'.join(map(lambda x: dialogue(**x), ass_arr))
file.write(ass_str)
print(text, file=file, flush=True)
from whisper.utils import SubtitlesWriter, ResultWriter, WriteTXT, WriteVTT, WriteSRT, WriteTSV, WriteJSON, format_timestamp
class SubtitlesWriter(ResultWriter):
always_include_hours: bool
decimal_marker: str
class WriteASS(ResultWriter):
extension: str = "ass"
def iterate_result(self, result: dict, options: dict):
raw_max_line_width: Optional[int] = options["max_line_width"]
max_line_count: Optional[int] = options["max_line_count"]
highlight_words: bool = options["highlight_words"]
max_line_width = 1000 if raw_max_line_width is None else raw_max_line_width
preserve_segments = max_line_count is None or raw_max_line_width is None
def write_result(self, result: dict, file: TextIO):
write_ass(result["segments"], file, resolution="word")
class WriteASSchar(ResultWriter):
extension: str = "ass"
def write_result(self, result: dict, file: TextIO):
write_ass(result["segments"], file, resolution="char")
class WritePickle(ResultWriter):
extension: str = "ass"
def write_result(self, result: dict, file: TextIO):
pd.DataFrame(result["segments"]).to_pickle(file)
class WriteSRTWord(ResultWriter):
extension: str = "word.srt"
always_include_hours: bool = True
decimal_marker: str = ","
def iterate_result(self, result: dict):
for segment in result["word_segments"]:
segment_start = self.format_timestamp(segment["start"])
segment_end = self.format_timestamp(segment["end"])
segment_text = segment["text"].strip().replace("-->", "->")
if word_timings := segment.get("words", None):
all_words = [timing["word"] for timing in word_timings]
all_words[0] = all_words[0].strip() # remove the leading space, if any
last = segment_start
for i, this_word in enumerate(word_timings):
start = self.format_timestamp(this_word["start"])
end = self.format_timestamp(this_word["end"])
if last != start:
yield last, start, segment_text
yield start, end, "".join(
[
f"<u>{word}</u>" if j == i else word
for j, word in enumerate(all_words)
]
)
last = end
if last != segment_end:
yield last, segment_end, segment_text
else:
yield segment_start, segment_end, segment_text
def write_result(self, result: dict, file: TextIO):
if "word_segments" not in result:
if len(result["segments"]) == 0:
return
for i, (start, end, text) in enumerate(self.iterate_result(result), start=1):
print(f"{i}\n{start} --> {end}\n{text}\n", file=file, flush=True)
def iterate_subtitles():
line_len = 0
line_count = 1
# the next subtitle to yield (a list of word timings with whitespace)
subtitle: list[dict] = []
times: list[tuple] = []
last = result["segments"][0]["start"]
for segment in result["segments"]:
for i, original_timing in enumerate(segment["words"]):
timing = original_timing.copy()
long_pause = not preserve_segments
if "start" in timing:
long_pause = long_pause and timing["start"] - last > 3.0
else:
long_pause = False
has_room = line_len + len(timing["word"]) <= max_line_width
seg_break = i == 0 and len(subtitle) > 0 and preserve_segments
if line_len > 0 and has_room and not long_pause and not seg_break:
# line continuation
line_len += len(timing["word"])
else:
# new line
timing["word"] = timing["word"].strip()
if (
len(subtitle) > 0
and max_line_count is not None
and (long_pause or line_count >= max_line_count)
or seg_break
):
# subtitle break
yield subtitle, times
subtitle = []
times = []
line_count = 1
elif line_len > 0:
# line break
line_count += 1
timing["word"] = "\n" + timing["word"]
line_len = len(timing["word"].strip())
subtitle.append(timing)
times.append((segment["start"], segment["end"], segment.get("speaker")))
if "start" in timing:
last = timing["start"]
if len(subtitle) > 0:
yield subtitle, times
if "words" in result["segments"][0]:
for subtitle, _ in iterate_subtitles():
sstart, ssend, speaker = _[0]
subtitle_start = self.format_timestamp(sstart)
subtitle_end = self.format_timestamp(ssend)
if result["language"] in LANGUAGES_WITHOUT_SPACES:
subtitle_text = "".join([word["word"] for word in subtitle])
else:
subtitle_text = " ".join([word["word"] for word in subtitle])
has_timing = any(["start" in word for word in subtitle])
# add [$SPEAKER_ID]: to each subtitle if speaker is available
prefix = ""
if speaker is not None:
prefix = f"[{speaker}]: "
if highlight_words and has_timing:
last = subtitle_start
all_words = [timing["word"] for timing in subtitle]
for i, this_word in enumerate(subtitle):
if "start" in this_word:
start = self.format_timestamp(this_word["start"])
end = self.format_timestamp(this_word["end"])
if last != start:
yield last, start, prefix + subtitle_text
yield start, end, prefix + " ".join(
[
re.sub(r"^(\s*)(.*)$", r"\1<u>\2</u>", word)
if j == i
else word
for j, word in enumerate(all_words)
]
)
last = end
else:
yield subtitle_start, subtitle_end, prefix + subtitle_text
else:
for segment in result["segments"]:
segment_start = self.format_timestamp(segment["start"])
segment_end = self.format_timestamp(segment["end"])
segment_text = segment["text"].strip().replace("-->", "->")
if "speaker" in segment:
segment_text = f"[{segment['speaker']}]: {segment_text}"
yield segment_start, segment_end, segment_text
def format_timestamp(self, seconds: float):
return format_timestamp(
@ -282,36 +335,108 @@ class WriteSRTWord(ResultWriter):
decimal_marker=self.decimal_marker,
)
def get_writer(output_format: str, output_dir: str) -> Callable[[dict, TextIO], None]:
class WriteVTT(SubtitlesWriter):
extension: str = "vtt"
always_include_hours: bool = False
decimal_marker: str = "."
def write_result(self, result: dict, file: TextIO, options: dict):
print("WEBVTT\n", file=file)
for start, end, text in self.iterate_result(result, options):
print(f"{start} --> {end}\n{text}\n", file=file, flush=True)
class WriteSRT(SubtitlesWriter):
extension: str = "srt"
always_include_hours: bool = True
decimal_marker: str = ","
def write_result(self, result: dict, file: TextIO, options: dict):
for i, (start, end, text) in enumerate(
self.iterate_result(result, options), start=1
):
print(f"{i}\n{start} --> {end}\n{text}\n", file=file, flush=True)
class WriteTSV(ResultWriter):
"""
Write a transcript to a file in TSV (tab-separated values) format containing lines like:
<start time in integer milliseconds>\t<end time in integer milliseconds>\t<transcript text>
Using integer milliseconds as start and end times means there's no chance of interference from
an environment setting a language encoding that causes the decimal in a floating point number
to appear as a comma; it is also faster and more efficient to parse & store, e.g., in C++.
"""
extension: str = "tsv"
def write_result(self, result: dict, file: TextIO, options: dict):
print("start", "end", "text", sep="\t", file=file)
for segment in result["segments"]:
print(round(1000 * segment["start"]), file=file, end="\t")
print(round(1000 * segment["end"]), file=file, end="\t")
print(segment["text"].strip().replace("\t", " "), file=file, flush=True)
class WriteAudacity(ResultWriter):
"""
Write a transcript to a text file that Audacity can import as labels.
The extension used is "aud" to distinguish it from the txt file produced by WriteTXT.
This is not an Audacity project, only a label file.
Note: Audacity timestamps are in seconds, not milliseconds, and no header is expected.
If a speaker is provided, it is prepended to the text between double square brackets [[]].
"""
extension: str = "aud"
def write_result(self, result: dict, file: TextIO, options: dict):
ARROW = " "
for segment in result["segments"]:
print(segment["start"], file=file, end=ARROW)
print(segment["end"], file=file, end=ARROW)
print( ( ("[[" + segment["speaker"] + "]]") if "speaker" in segment else "") + segment["text"].strip().replace("\t", " "), file=file, flush=True)
class WriteJSON(ResultWriter):
extension: str = "json"
def write_result(self, result: dict, file: TextIO, options: dict):
json.dump(result, file, ensure_ascii=False)
def get_writer(
output_format: str, output_dir: str
) -> Callable[[dict, TextIO, dict], None]:
writers = {
"txt": WriteTXT,
"vtt": WriteVTT,
"srt": WriteSRT,
"tsv": WriteTSV,
"ass": WriteASS,
"srt-word": WriteSRTWord,
# "ass-char": WriteASSchar,
# "pickle": WritePickle,
# "json": WriteJSON,
"json": WriteJSON,
}
writers_other = {
"pkl": WritePickle,
"ass-char": WriteASSchar
optional_writers = {
"aud": WriteAudacity,
}
if output_format == "all":
all_writers = [writer(output_dir) for writer in writers.values()]
def write_all(result: dict, file: TextIO):
def write_all(result: dict, file: TextIO, options: dict):
for writer in all_writers:
writer(result, file)
writer(result, file, options)
return write_all
if output_format in writers:
return writers[output_format](output_dir)
elif output_format in writers_other:
return writers_other[output_format](output_dir)
if output_format in optional_writers:
return optional_writers[output_format](output_dir)
return writers[output_format](output_dir)
def interpolate_nans(x, method='nearest'):
if x.notnull().sum() > 1:
return x.interpolate(method=method).ffill().bfill()
else:
raise ValueError(f"Output format '{output_format}' not supported, choose from {writers.keys()} and {writers_other.keys()}")
return x.ffill().bfill()
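After this change, the callable returned by get_writer takes (result, audio_path, options) rather than a file handle. A short usage sketch; the result dict is a stand-in for a real transcribe/align output and the option values shown are just common defaults:
from whisperx.utils import get_writer

aligned_result = {"segments": [{"start": 0.0, "end": 2.0, "text": "Hello world."}],
                  "language": "en"}                       # stand-in for a real align() output
writer = get_writer("srt", output_dir=".")
options = {"max_line_width": None, "max_line_count": None, "highlight_words": False}
writer(aligned_result, "audio.wav", options)              # writes ./audio.srt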

whisperx/vads/__init__.py

@ -0,0 +1,3 @@
from whisperx.vads.pyannote import Pyannote as Pyannote
from whisperx.vads.silero import Silero as Silero
from whisperx.vads.vad import Vad as Vad

whisperx/vads/pyannote.py

@ -1,54 +1,44 @@
import os
import urllib
import pandas as pd
from typing import Callable, Text, Union
from typing import Optional
import numpy as np
import torch
import hashlib
from tqdm import tqdm
from typing import Optional, Callable, Union, Text
from pyannote.audio.core.io import AudioFile
from pyannote.core import Annotation, Segment, SlidingWindowFeature
from pyannote.audio.pipelines.utils import PipelineModel
from pyannote.audio import Model
from pyannote.audio.core.io import AudioFile
from pyannote.audio.pipelines import VoiceActivityDetection
from .diarize import Segment as SegmentX
from typing import List, Tuple, Optional
from pyannote.audio.pipelines.utils import PipelineModel
from pyannote.core import Annotation, SlidingWindowFeature
from pyannote.core import Segment
VAD_SEGMENTATION_URL = "https://whisperx.s3.eu-west-2.amazonaws.com/model_weights/segmentation/0b5b3216d60a2d32fc086b47ea8c67589aaeb26b7e07fcbe620d6d0b83e209ea/pytorch_model.bin"
from whisperx.diarize import Segment as SegmentX
from whisperx.vads.vad import Vad
def load_vad_model(device, vad_onset, vad_offset, use_auth_token=None, model_fp=None):
def load_vad_model(device, vad_onset=0.500, vad_offset=0.363, use_auth_token=None, model_fp=None):
model_dir = torch.hub._get_torch_home()
main_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
os.makedirs(model_dir, exist_ok = True)
if model_fp is None:
model_fp = os.path.join(model_dir, "whisperx-vad-segmentation.bin")
# Dynamically resolve the path to the model file
model_fp = os.path.join(main_dir, "assets", "pytorch_model.bin")
model_fp = os.path.abspath(model_fp) # Ensure the path is absolute
else:
model_fp = os.path.abspath(model_fp) # Ensure any provided path is absolute
# Check if the resolved model file exists
if not os.path.exists(model_fp):
raise FileNotFoundError(f"Model file not found at {model_fp}")
if os.path.exists(model_fp) and not os.path.isfile(model_fp):
raise RuntimeError(f"{model_fp} exists and is not a regular file")
if not os.path.isfile(model_fp):
with urllib.request.urlopen(VAD_SEGMENTATION_URL) as source, open(model_fp, "wb") as output:
with tqdm(
total=int(source.info().get("Content-Length")),
ncols=80,
unit="iB",
unit_scale=True,
unit_divisor=1024,
) as loop:
while True:
buffer = source.read(8192)
if not buffer:
break
output.write(buffer)
loop.update(len(buffer))
model_bytes = open(model_fp, "rb").read()
if hashlib.sha256(model_bytes).hexdigest() != VAD_SEGMENTATION_URL.split('/')[-2]:
raise RuntimeError(
"Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model."
)
vad_model = Model.from_pretrained(model_fp, use_auth_token=use_auth_token)
hyperparameters = {"onset": vad_onset,
hyperparameters = {"onset": vad_onset,
"offset": vad_offset,
"min_duration_on": 0.1,
"min_duration_off": 0.1}
@ -84,21 +74,21 @@ class Binarize:
Gregory Gelly and Jean-Luc Gauvain. "Minimum Word Error Training of
RNN-based Voice Activity Detection", InterSpeech 2015.
Modified by Max Bain to include WhisperX's min-cut operation
Modified by Max Bain to include WhisperX's min-cut operation
https://arxiv.org/abs/2303.00747
Pyannote-audio
"""
def __init__(
self,
onset: float = 0.5,
offset: Optional[float] = None,
min_duration_on: float = 0.0,
min_duration_off: float = 0.0,
pad_onset: float = 0.0,
pad_offset: float = 0.0,
max_duration: float = float('inf')
self,
onset: float = 0.5,
offset: Optional[float] = None,
min_duration_on: float = 0.0,
min_duration_off: float = 0.0,
pad_onset: float = 0.0,
pad_offset: float = 0.0,
max_duration: float = float('inf')
):
super().__init__()
@ -141,13 +131,12 @@ class Binarize:
is_active = k_scores[0] > self.onset
curr_scores = [k_scores[0]]
curr_timestamps = [start]
t = start
for t, y in zip(timestamps[1:], k_scores[1:]):
# currently active
if is_active:
if is_active:
curr_duration = t - start
if curr_duration > self.max_duration:
# if curr_duration > 15:
# import pdb; pdb.set_trace()
search_after = len(curr_scores) // 2
# divide segment
min_score_div_idx = search_after + np.argmin(curr_scores[search_after:])
@ -155,8 +144,8 @@ class Binarize:
region = Segment(start - self.pad_onset, min_score_t + self.pad_offset)
active[region, k] = label
start = curr_timestamps[min_score_div_idx]
curr_scores = curr_scores[min_score_div_idx+1:]
curr_timestamps = curr_timestamps[min_score_div_idx+1:]
curr_scores = curr_scores[min_score_div_idx + 1:]
curr_timestamps = curr_timestamps[min_score_div_idx + 1:]
# switching from active to inactive
elif y < self.offset:
region = Segment(start - self.pad_onset, t + self.pad_offset)
@ -165,14 +154,14 @@ class Binarize:
is_active = False
curr_scores = []
curr_timestamps = []
curr_scores.append(y)
curr_timestamps.append(t)
# currently inactive
else:
# switching from inactive to active
if y > self.onset:
start = t
is_active = True
curr_scores.append(y)
curr_timestamps.append(t)
# if active at the end, add final region
if is_active:
@ -197,11 +186,11 @@ class Binarize:
class VoiceActivitySegmentation(VoiceActivityDetection):
def __init__(
self,
segmentation: PipelineModel = "pyannote/segmentation",
fscore: bool = False,
use_auth_token: Union[Text, None] = None,
**inference_kwargs,
self,
segmentation: PipelineModel = "pyannote/segmentation",
fscore: bool = False,
use_auth_token: Union[Text, None] = None,
**inference_kwargs,
):
super().__init__(segmentation=segmentation, fscore=fscore, use_auth_token=use_auth_token, **inference_kwargs)
@ -240,67 +229,35 @@ class VoiceActivitySegmentation(VoiceActivityDetection):
return segmentations
def merge_vad(vad_arr, pad_onset=0.0, pad_offset=0.0, min_duration_off=0.0, min_duration_on=0.0):
class Pyannote(Vad):
active = Annotation()
for k, vad_t in enumerate(vad_arr):
region = Segment(vad_t[0] - pad_onset, vad_t[1] + pad_offset)
active[region, k] = 1
def __init__(self, device, use_auth_token=None, model_fp=None, **kwargs):
print(">>Performing voice activity detection using Pyannote...")
super().__init__(kwargs['vad_onset'])
self.vad_pipeline = load_vad_model(device, use_auth_token=use_auth_token, model_fp=model_fp)
def __call__(self, audio: AudioFile, **kwargs):
return self.vad_pipeline(audio)
if pad_offset > 0.0 or pad_onset > 0.0 or min_duration_off > 0.0:
active = active.support(collar=min_duration_off)
# remove tracks shorter than min_duration_on
if min_duration_on > 0:
for segment, track in list(active.itertracks()):
if segment.duration < min_duration_on:
del active[segment, track]
active = active.for_json()
active_segs = pd.DataFrame([x['segment'] for x in active['content']])
return active_segs
@staticmethod
def preprocess_audio(audio):
return torch.from_numpy(audio).unsqueeze(0)
def merge_chunks(segments, chunk_size):
"""
Merge operation described in paper
"""
curr_end = 0
merged_segments = []
seg_idxs = []
speaker_idxs = []
@staticmethod
def merge_chunks(segments,
chunk_size,
onset: float = 0.5,
offset: Optional[float] = None,
):
assert chunk_size > 0
binarize = Binarize(max_duration=chunk_size, onset=onset, offset=offset)
segments = binarize(segments)
segments_list = []
for speech_turn in segments.get_timeline():
segments_list.append(SegmentX(speech_turn.start, speech_turn.end, "UNKNOWN"))
assert chunk_size > 0
binarize = Binarize(max_duration=chunk_size)
segments = binarize(segments)
segments_list = []
for speech_turn in segments.get_timeline():
segments_list.append(SegmentX(speech_turn.start, speech_turn.end, "UNKNOWN"))
if len(segments_list) == 0:
print("No active speech found in audio")
return []
# assert segments_list, "segments_list is empty."
# Make sure the starting point is the start of the segment.
curr_start = segments_list[0].start
for seg in segments_list:
if seg.end - curr_start > chunk_size and curr_end-curr_start > 0:
merged_segments.append({
"start": curr_start,
"end": curr_end,
"segments": seg_idxs,
})
curr_start = seg.start
seg_idxs = []
speaker_idxs = []
curr_end = seg.end
seg_idxs.append((seg.start, seg.end))
speaker_idxs.append(seg.speaker)
# add final
merged_segments.append({
"start": curr_start,
"end": curr_end,
"segments": seg_idxs,
})
return merged_segments
if len(segments_list) == 0:
print("No active speech found in audio")
return []
assert segments_list, "segments_list is empty."
return Vad.merge_chunks(segments_list, chunk_size, onset, offset)
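The Pyannote class wraps load_vad_model behind the shared Vad interface, so callers construct it with the usual vad_onset/vad_offset kwargs and let merge_chunks handle binarization. A hedged sketch of that flow, using one second of silence as a stand-in for real audio:
import numpy as np
from whisperx.vads import Pyannote

vad = Pyannote("cpu", use_auth_token=None, vad_onset=0.500, vad_offset=0.363)
audio = np.zeros(16000, dtype=np.float32)                 # one second of silence as a stand-in
scores = vad({"waveform": Pyannote.preprocess_audio(audio), "sample_rate": 16000})
chunks = Pyannote.merge_chunks(scores, chunk_size=30, onset=0.500)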

whisperx/vads/silero.py

@ -0,0 +1,66 @@
from io import IOBase
from pathlib import Path
from typing import Mapping, Text
from typing import Optional
from typing import Union
import torch
from whisperx.diarize import Segment as SegmentX
from whisperx.vads.vad import Vad
AudioFile = Union[Text, Path, IOBase, Mapping]
class Silero(Vad):
# TODO: check the default values again
def __init__(self, **kwargs):
print(">>Performing voice activity detection using Silero...")
super().__init__(kwargs['vad_onset'])
self.vad_onset = kwargs['vad_onset']
self.chunk_size = kwargs['chunk_size']
self.vad_pipeline, vad_utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',
model='silero_vad',
force_reload=False,
onnx=False,
trust_repo=True)
(self.get_speech_timestamps, _, self.read_audio, _, _) = vad_utils
def __call__(self, audio: AudioFile, **kwargs):
"""use silero to get segments of speech"""
# Only accept 16000 Hz for now.
# Note: Silero models support both 8000 and 16000 Hz. Although other values are not directly supported,
# multiples of 16000 (e.g. 32000 or 48000) are cast to 16000 inside of the JIT model!
sample_rate = audio["sample_rate"]
if sample_rate != 16000:
raise ValueError("Only 16000Hz sample rate is allowed")
timestamps = self.get_speech_timestamps(audio["waveform"],
model=self.vad_pipeline,
sampling_rate=sample_rate,
max_speech_duration_s=self.chunk_size,
threshold=self.vad_onset
# min_silence_duration_ms = self.min_duration_off/1000
# min_speech_duration_ms = self.min_duration_on/1000
# ...
# See silero documentation for full option list
)
return [SegmentX(i['start'] / sample_rate, i['end'] / sample_rate, "UNKNOWN") for i in timestamps]
@staticmethod
def preprocess_audio(audio):
return audio
@staticmethod
def merge_chunks(segments_list,
chunk_size,
onset: float = 0.5,
offset: Optional[float] = None,
):
assert chunk_size > 0
if len(segments_list) == 0:
print("No active speech found in audio")
return []
assert segments_list, "segments_list is empty."
return Vad.merge_chunks(segments_list, chunk_size, onset, offset)
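The Silero backend follows the same Vad interface but only accepts 16 kHz input and returns SegmentX objects directly, so no binarization step is needed before merge_chunks. A minimal sketch, again with silence as a stand-in for real audio:
import numpy as np
import torch
from whisperx.vads import Silero

vad = Silero(vad_onset=0.5, chunk_size=30)               # kwargs mirror what load_model passes
audio = torch.from_numpy(np.zeros(16000, dtype=np.float32))
segments = vad({"waveform": audio, "sample_rate": 16000})
chunks = Silero.merge_chunks(segments, chunk_size=30, onset=0.5)   # [] for pure silence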

whisperx/vads/vad.py

@ -0,0 +1,74 @@
from typing import Optional
import pandas as pd
from pyannote.core import Annotation, Segment
class Vad:
def __init__(self, vad_onset):
if not (0 < vad_onset < 1):
raise ValueError(
"vad_onset is a decimal value between 0 and 1."
)
@staticmethod
def preprocess_audio(audio):
pass
# keep merge_chunks as static so it can be also used by manually assigned vad_model (see 'load_model')
@staticmethod
def merge_chunks(segments,
chunk_size,
onset: float,
offset: Optional[float]):
"""
Merge operation described in paper
"""
curr_end = 0
merged_segments = []
seg_idxs: list[tuple] = []
speaker_idxs: list[Optional[str]] = []
curr_start = segments[0].start
for seg in segments:
if seg.end - curr_start > chunk_size and curr_end - curr_start > 0:
merged_segments.append({
"start": curr_start,
"end": curr_end,
"segments": seg_idxs,
})
curr_start = seg.start
seg_idxs = []
speaker_idxs = []
curr_end = seg.end
seg_idxs.append((seg.start, seg.end))
speaker_idxs.append(seg.speaker)
# add final
merged_segments.append({
"start": curr_start,
"end": curr_end,
"segments": seg_idxs,
})
return merged_segments
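merge_chunks greedily packs consecutive speech segments into windows of at most chunk_size seconds, flushing whenever the next segment would overflow the current window. A tiny worked example using the same Segment class the diff imports as SegmentX:
from whisperx.diarize import Segment as SegmentX
from whisperx.vads.vad import Vad

segs = [SegmentX(0.0, 10.0, "UNKNOWN"),
        SegmentX(12.0, 25.0, "UNKNOWN"),
        SegmentX(26.0, 40.0, "UNKNOWN")]
chunks = Vad.merge_chunks(segs, chunk_size=30, onset=0.5, offset=None)
# -> [{"start": 0.0, "end": 25.0, ...}, {"start": 26.0, "end": 40.0, ...}]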
# Unused function
@staticmethod
def merge_vad(vad_arr, pad_onset=0.0, pad_offset=0.0, min_duration_off=0.0, min_duration_on=0.0):
active = Annotation()
for k, vad_t in enumerate(vad_arr):
region = Segment(vad_t[0] - pad_onset, vad_t[1] + pad_offset)
active[region, k] = 1
if pad_offset > 0.0 or pad_onset > 0.0 or min_duration_off > 0.0:
active = active.support(collar=min_duration_off)
# remove tracks shorter than min_duration_on
if min_duration_on > 0:
for segment, track in list(active.itertracks()):
if segment.duration < min_duration_on:
del active[segment, track]
active = active.for_json()
active_segs = pd.DataFrame([x['segment'] for x in active['content']])
return active_segs
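Although flagged as unused, merge_vad still documents the padding/merging behaviour: raw (start, end) pairs are padded, nearby regions are merged with a collar, and the result comes back as a DataFrame. A short sketch of what it produces:
from whisperx.vads.vad import Vad

raw = [(0.0, 1.0), (1.05, 2.0), (5.0, 6.0)]
regions = Vad.merge_vad(raw, pad_onset=0.1, pad_offset=0.1, min_duration_off=0.2)
# -> pandas DataFrame with 'start'/'end' columns; the first two pairs merge into one region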