Mirror of https://github.com/m-bain/whisperX.git (synced 2025-07-01 18:17:27 -04:00)

Compare commits (41 commits)
| SHA1 |
| --- |
| 847a3cd85b |
| 2b1ffa12b8 |
| 57f5957e0e |
| 27fe502344 |
| f7093e60d3 |
| a1d2229416 |
| 4cb167a225 |
| 2e307814dd |
| d687cf3358 |
| 0a3fd11562 |
| 29e95b746b |
| 039af89a86 |
| 9f26112d5c |
| fd2a093754 |
| 31f069752f |
| 4cdf7ef856 |
| d294e29ad9 |
| 0eae9e1f50 |
| 1b08661e42 |
| a49799294b |
| d83c74a79f |
| acaefa09a1 |
| 76f79f600a |
| 33073f9bba |
| 50f3965fdb |
| df2b1b70cb |
| c19cf407d8 |
| 8081ef2dcd |
| c6dbac76c8 |
| 69673eb39b |
| 5b8c8a7bd3 |
| 7f2159a953 |
| 16d24b1c96 |
| d20a2a4ea2 |
| 312f1cc50c |
| 99b6e79fbf |
| e7773358a3 |
| 6b2aa4ff3e |
| c3de5e9580 |
| 58d7191949 |
| 286a2f2c14 |
.github/FUNDING.yml (vendored, new file, +1)

```diff
@@ -0,0 +1 @@
+custom: https://www.buymeacoffee.com/maxhbain
```
```diff
@@ -2,7 +2,7 @@
 ## Other Languages

-For non-english ASR, it is best to use the `large` whisper model. Alignment models are automatically picked by the chosen language from the default [lists](https://github.com/m-bain/whisperX/blob/e909f2f766b23b2000f2d95df41f9b844ac53e49/whisperx/transcribe.py#L22).
+For non-english ASR, it is best to use the `large` whisper model. Alignment models are automatically picked by the chosen language from the default [lists](https://github.com/m-bain/whisperX/blob/main/whisperx/alignment.py#L18).

 Currently support default models tested for {en, fr, de, es, it, ja, zh, nl}
```
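As a concrete illustration of that automatic selection, the language code is simply looked up in small default tables of wav2vec2 checkpoints; the sketch below condenses the logic this comparison adds in the alignment module linked above (the table entries are copied from the diff further down, the helper name is made up):

```python
# Condensed sketch of the default alignment-model lookup; entries copied from the diff below.
DEFAULT_ALIGN_MODELS_TORCH = {"en": "WAV2VEC2_ASR_BASE_960H", "fr": "VOXPOPULI_ASR_BASE_10K_FR"}
DEFAULT_ALIGN_MODELS_HF = {"ja": "jonatasgrosman/wav2vec2-large-xlsr-53-japanese"}

def pick_align_model(language_code: str) -> str:
    # torchaudio pipelines are preferred; Hugging Face checkpoints are the fallback
    if language_code in DEFAULT_ALIGN_MODELS_TORCH:
        return DEFAULT_ALIGN_MODELS_TORCH[language_code]
    if language_code in DEFAULT_ALIGN_MODELS_HF:
        return DEFAULT_ALIGN_MODELS_HF[language_code]
    raise ValueError(f"No default align-model for language: {language_code}; pass --align_model explicitly")

print(pick_align_model("ja"))  # jonatasgrosman/wav2vec2-large-xlsr-53-japanese
```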
README.md (53 changes)

```diff
@@ -27,7 +27,6 @@
 <a href="EXAMPLES.md">More examples</a>
 </p>

-<h6 align="center">Made by Max Bain • :globe_with_meridians: <a href="https://www.maxbain.com">https://www.maxbain.com</a></h6>

 <img width="1216" align="center" alt="whisperx-arch" src="https://user-images.githubusercontent.com/36994049/211200186-8b779e26-0bfd-4127-aee2-5a9238b95e1f.png">

@@ -50,9 +49,10 @@ This repository refines the timestamps of openAI's Whisper model via forced alig
 <h2 align="left", id="highlights">New🚨</h2>

+- Batch processing: Add `--vad_filter --parallel_bs [int]` for transcribing long audio file in batches (only supported with VAD filtering). Replace `[int]` with a batch size that fits your GPU memory, e.g. `--parallel_bs 16`.
 - VAD filtering: Voice Activity Detection (VAD) from [Pyannote.audio](https://huggingface.co/pyannote/voice-activity-detection) is used as a preprocessing step to remove reliance on whisper timestamps and only transcribe audio segments containing speech. add `--vad_filter` flag, increases timestamp accuracy and robustness (requires more GPU mem due to 30s inputs in wav2vec2)
 - Character level timestamps (see `*.char.ass` file output)
-- Diarization (still in beta, add `--diarization`)
+- Diarization (still in beta, add `--diarize`)

 <h2 align="left" id="setup">Setup ⚙️</h2>
```
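An illustrative invocation that combines the new batch-processing flags with the example file used elsewhere in this README diff would be `whisperx examples/sample01.wav --model large-v2 --vad_filter --parallel_bs 16`; pick a `--parallel_bs` value that fits your GPU memory.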
````diff
@@ -71,9 +71,13 @@ $ cd whisperX
 $ pip install -e .
 ```

 You may also need to install ffmpeg, rust etc. Follow openAI instructions here https://github.com/openai/whisper#setup.

+
+### Voice Activity Detection Filtering & Diarization
+To **enable VAD filtering and Diarization**, include your Hugging Face access token that you can generate from [Here](https://huggingface.co/settings/tokens) after the `--hf_token` argument and accept the user agreement for the following models: [Segmentation](https://huggingface.co/pyannote/segmentation) , [Voice Activity Detection (VAD)](https://huggingface.co/pyannote/voice-activity-detection) , and [Speaker Diarization](https://huggingface.co/pyannote/speaker-diarization)
+
 <h2 align="left" id="example">Usage 💬 (command line)</h2>

 ### English
````
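Putting the token together with the new flags, an illustrative command (placeholder token, not from the diff) would be `whisperx examples/sample01.wav --vad_filter --diarize --hf_token YOUR_HF_TOKEN`.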
```diff
@@ -85,7 +89,7 @@ Run whisper on example segment (using default params)
 For increased timestamp accuracy, at the cost of higher gpu mem, use bigger models and VAD filtering e.g.

-whisperx examples/sample01.wav --model large.en --vad_filter --align_model WAV2VEC2_ASR_LARGE_LV60K_960H
+whisperx examples/sample01.wav --model large-v2 --vad_filter --align_model WAV2VEC2_ASR_LARGE_LV60K_960H

 Result using *WhisperX* with forced alignment to wav2vec2.0 large:
```
```diff
@@ -149,8 +153,9 @@ In addition to forced alignment, the following two modifications have been made
 - Not thoroughly tested, especially for non-english, results may vary -- please post issue to let me know the results on your data
 - Whisper normalises spoken numbers e.g. "fifty seven" to arabic numerals "57". Need to perform this normalization after alignment, so the phonemes can be aligned. Currently just ignores numbers.
-- Assumes the initial whisper timestamps are accurate to some degree (within margin of 2 seconds, adjust if needed -- bigger margins more prone to alignment errors)
+- If not using VAD filter, whisperx assumes the initial whisper timestamps are accurate to some degree (within margin of 2 seconds, adjust if needed -- bigger margins more prone to alignment errors)
-- Hacked this up quite quickly, there might be some errors, please raise an issue if you encounter any.
+- Overlapping speech is not handled particularly well by whisper nor whisperx
+- Diariazation is far from perfect.

 <h2 align="left" id="contribute">Contribute 🧑🏫</h2>
```
````diff
@@ -161,38 +166,46 @@ The next major upgrade we are working on is whisper with speaker diarization, so
 <h2 align="left" id="coming-soon">Coming Soon 🗓</h2>

-[x] ~~Multilingual init~~ done
+* [x] Multilingual init

-[x] ~~Subtitle .ass output~~ done
+* [x] Subtitle .ass output

-[x] ~~Automatic align model selection based on language detection~~ done
+* [x] Automatic align model selection based on language detection

-[x] ~~Python usage~~ done
+* [x] Python usage

-[x] ~~Character level timestamps~~
+* [x] Character level timestamps

-[x] ~~Incorporating speaker diarization~~
+* [x] Incorporating speaker diarization

-[ ] Improve diarization (word level)
+* [x] Inference speedup with batch processing

-[ ] Inference speedup with batch processing
+* [ ] Improve diarization (word level). *Harder than first thought...*

-<h2 align="left" id="contact">Contact 📇</h2>
+<h2 align="left" id="contact">Contact/Support 📇</h2>

-Contact maxbain[at]robots[dot]ox[dot]ac[dot]uk for business things.
+Contact maxbain[at]robots[dot]ox[dot]ac[dot]uk for queries
+
+<a href="https://www.buymeacoffee.com/maxhbain" target="_blank"><img src="https://cdn.buymeacoffee.com/buttons/default-orange.png" alt="Buy Me A Coffee" height="41" width="174"></a>

 <h2 align="left" id="acks">Acknowledgements 🙏</h2>

-Of course, this is mostly just a modification to [openAI's whisper](https://github.com/openai/whisper).
-As well as accreditation to this [PyTorch tutorial on forced alignment](https://pytorch.org/tutorials/intermediate/forced_alignment_with_torchaudio_tutorial.html)
+This work, and my PhD, is supported by the [VGG (Visual Geometry Group)](https://www.robots.ox.ac.uk/~vgg/) and University of Oxford.
+
+Of course, this is builds on [openAI's whisper](https://github.com/openai/whisper).
+And borrows important alignment code from [PyTorch tutorial on forced alignment](https://pytorch.org/tutorials/intermediate/forced_alignment_with_torchaudio_tutorial.html)

 <h2 align="left" id="cite">Citation</h2>
-If you use this in your research, just cite the repo,
+If you use this in your research, for now just cite the repo,

 ```bibtex
 @misc{bain2022whisperx,
-author = {Bain, Max},
+author = {Bain, Max and Han, Tengda},
 title = {WhisperX},
 year = {2022},
 publisher = {GitHub},
````
```diff
@@ -11,8 +11,8 @@ from tqdm import tqdm
 from .audio import load_audio, log_mel_spectrogram, pad_or_trim
 from .decoding import DecodingOptions, DecodingResult, decode, detect_language
 from .model import Whisper, ModelDimensions
-from .transcribe import transcribe, load_align_model, align, transcribe_with_vad
+from .transcribe import transcribe, transcribe_with_vad, transcribe_with_vad_parallel
+from .alignment import load_align_model, align

 _MODELS = {
     "tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",
     "tiny": "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt",
```
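A minimal sketch of how the reorganized top-level API fits together after this change. It assumes the package still exposes a Whisper-style `load_model` (suggested by the `_MODELS` table above but not shown in this hunk) and that an example audio file exists at the given path; everything else uses names visible in the diff.

```python
# Sketch only: `load_model` and the example path are assumptions, the rest mirrors the diff.
import torch
import whisperx

device = "cuda" if torch.cuda.is_available() else "cpu"
audio_file = "examples/sample01.wav"

model = whisperx.load_model("base", device)        # assumed Whisper-style loader
result = whisperx.transcribe(model, audio_file)    # plain Whisper pass: segments + language

# alignment now lives in whisperx.alignment but is re-exported at the top level
align_model, metadata = whisperx.load_align_model(result["language"], device)
result_aligned = whisperx.align(result["segments"], align_model, metadata, audio_file, device)
```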
```diff
@@ -1,9 +1,443 @@
+""""
+Forced Alignment with Whisper
+C. Max Bain
+"""
+import numpy as np
+import pandas as pd
+from typing import List, Union, Iterator, TYPE_CHECKING
+from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
+import torchaudio
+import torch
+from dataclasses import dataclass
+from .audio import SAMPLE_RATE, load_audio
+from .utils import interpolate_nans
+
+LANGUAGES_WITHOUT_SPACES = ["ja", "zh"]
+
+DEFAULT_ALIGN_MODELS_TORCH = {
+    "en": "WAV2VEC2_ASR_BASE_960H",
+    "fr": "VOXPOPULI_ASR_BASE_10K_FR",
+    "de": "VOXPOPULI_ASR_BASE_10K_DE",
+    "es": "VOXPOPULI_ASR_BASE_10K_ES",
+    "it": "VOXPOPULI_ASR_BASE_10K_IT",
+}
+
+DEFAULT_ALIGN_MODELS_HF = {
+    "ja": "jonatasgrosman/wav2vec2-large-xlsr-53-japanese",
+    "zh": "jonatasgrosman/wav2vec2-large-xlsr-53-chinese-zh-cn",
+    "nl": "jonatasgrosman/wav2vec2-large-xlsr-53-dutch",
+    "uk": "Yehor/wav2vec2-xls-r-300m-uk-with-small-lm",
+    "pt": "jonatasgrosman/wav2vec2-large-xlsr-53-portuguese",
+    "ar": "jonatasgrosman/wav2vec2-large-xlsr-53-arabic",
+    "ru": "jonatasgrosman/wav2vec2-large-xlsr-53-russian",
+    "pl": "jonatasgrosman/wav2vec2-large-xlsr-53-polish",
+    "hu": "jonatasgrosman/wav2vec2-large-xlsr-53-hungarian",
+    "fi": "jonatasgrosman/wav2vec2-large-xlsr-53-finnish",
+    "fa": "jonatasgrosman/wav2vec2-large-xlsr-53-persian",
+    "el": "jonatasgrosman/wav2vec2-large-xlsr-53-greek",
+    "tr": "mpoyraz/wav2vec2-xls-r-300m-cv7-turkish",
+}
+
+
+def load_align_model(language_code, device, model_name=None):
+    if model_name is None:
+        # use default model
+        if language_code in DEFAULT_ALIGN_MODELS_TORCH:
+            model_name = DEFAULT_ALIGN_MODELS_TORCH[language_code]
+        elif language_code in DEFAULT_ALIGN_MODELS_HF:
+            model_name = DEFAULT_ALIGN_MODELS_HF[language_code]
+        else:
+            print(f"There is no default alignment model set for this language ({language_code}).\
+                Please find a wav2vec2.0 model finetuned on this language in https://huggingface.co/models, then pass the model name in --align_model [MODEL_NAME]")
+            raise ValueError(f"No default align-model for language: {language_code}")
+
+    if model_name in torchaudio.pipelines.__all__:
+        pipeline_type = "torchaudio"
+        bundle = torchaudio.pipelines.__dict__[model_name]
+        align_model = bundle.get_model().to(device)
+        labels = bundle.get_labels()
+        align_dictionary = {c.lower(): i for i, c in enumerate(labels)}
+    else:
+        try:
+            processor = Wav2Vec2Processor.from_pretrained(model_name)
+            align_model = Wav2Vec2ForCTC.from_pretrained(model_name)
+        except Exception as e:
+            print(e)
+            print(f"Error loading model from huggingface, check https://huggingface.co/models for finetuned wav2vec2.0 models")
+            raise ValueError(f'The chosen align_model "{model_name}" could not be found in huggingface (https://huggingface.co/models) or torchaudio (https://pytorch.org/audio/stable/pipelines.html#id14)')
+        pipeline_type = "huggingface"
+        align_model = align_model.to(device)
+        labels = processor.tokenizer.get_vocab()
+        align_dictionary = {char.lower(): code for char,code in processor.tokenizer.get_vocab().items()}
+
+    align_metadata = {"language": language_code, "dictionary": align_dictionary, "type": pipeline_type}
+
+    return align_model, align_metadata
+
+
+def align(
+    transcript: Iterator[dict],
+    model: torch.nn.Module,
+    align_model_metadata: dict,
+    audio: Union[str, np.ndarray, torch.Tensor],
+    device: str,
+    extend_duration: float = 0.0,
+    start_from_previous: bool = True,
+    interpolate_method: str = "nearest",
+):
+    """
+    Force align phoneme recognition predictions to known transcription
+
+    Parameters
+    ----------
+    transcript: Iterator[dict]
+        The Whisper model instance
+
+    model: torch.nn.Module
+        Alignment model (wav2vec2)
+
+    audio: Union[str, np.ndarray, torch.Tensor]
+        The path to the audio file to open, or the audio waveform
+
+    device: str
+        cuda device
+
+    diarization: pd.DataFrame {'start': List[float], 'end': List[float], 'speaker': List[float]}
+        diarization segments with speaker labels.
+
+    extend_duration: float
+        Amount to pad input segments by. If not using vad--filter then recommended to use 2 seconds
+
+        If the gzip compression ratio is above this value, treat as failed
+
+    interpolate_method: str ["nearest", "linear", "ignore"]
+        Method to assign timestamps to non-aligned words. Words are not able to be aligned when none of the characters occur in the align model dictionary.
+        "nearest" copies timestamp of nearest word within the segment. "linear" is linear interpolation. "drop" removes that word from output.
+
+    Returns
+    -------
+    A dictionary containing the resulting text ("text") and segment-level details ("segments"), and
+    the spoken language ("language"), which is detected when `decode_options["language"]` is None.
+    """
+    if not torch.is_tensor(audio):
+        if isinstance(audio, str):
+            audio = load_audio(audio)
+        audio = torch.from_numpy(audio)
+    if len(audio.shape) == 1:
+        audio = audio.unsqueeze(0)
+
+    MAX_DURATION = audio.shape[1] / SAMPLE_RATE
+
+    model_dictionary = align_model_metadata["dictionary"]
+    model_lang = align_model_metadata["language"]
+    model_type = align_model_metadata["type"]
+
+    aligned_segments = []
+
+    prev_t2 = 0
+
+    char_segments_arr = {
+        "segment-idx": [],
+        "subsegment-idx": [],
+        "word-idx": [],
+        "char": [],
+        "start": [],
+        "end": [],
+        "score": [],
+    }
+
+    for sdx, segment in enumerate(transcript):
+        while True:
+            segment_align_success = False
+
+            # strip spaces at beginning / end, but keep track of the amount.
+            num_leading = len(segment["text"]) - len(segment["text"].lstrip())
+            num_trailing = len(segment["text"]) - len(segment["text"].rstrip())
+            transcription = segment["text"]
+
+            # TODO: convert number tokenizer / symbols to phonetic words for alignment.
+            # e.g. "$300" -> "three hundred dollars"
+            # currently "$300" is ignored since no characters present in the phonetic dictionary
+
+            # split into words
+            if model_lang not in LANGUAGES_WITHOUT_SPACES:
+                per_word = transcription.split(" ")
+            else:
+                per_word = transcription
+
+            # first check that characters in transcription can be aligned (they are contained in align model"s dictionary)
+            clean_char, clean_cdx = [], []
+            for cdx, char in enumerate(transcription):
+                char_ = char.lower()
+                # wav2vec2 models use "|" character to represent spaces
+                if model_lang not in LANGUAGES_WITHOUT_SPACES:
+                    char_ = char_.replace(" ", "|")
+
+                # ignore whitespace at beginning and end of transcript
+                if cdx < num_leading:
+                    pass
+                elif cdx > len(transcription) - num_trailing - 1:
+                    pass
+                elif char_ in model_dictionary.keys():
+                    clean_char.append(char_)
+                    clean_cdx.append(cdx)
+
+            clean_wdx = []
+            for wdx, wrd in enumerate(per_word):
+                if any([c in model_dictionary.keys() for c in wrd]):
+                    clean_wdx.append(wdx)
+
+            # if no characters are in the dictionary, then we skip this segment...
+            if len(clean_char) == 0:
+                print(f'Failed to align segment ("{segment["text"]}"): no characters in this segment found in model dictionary, resorting to original...')
+                break
+
+            transcription_cleaned = "".join(clean_char)
+            tokens = [model_dictionary[c] for c in transcription_cleaned]
+
+            # we only pad if not using VAD filtering
+            if "seg_text" not in segment:
+                # pad according original timestamps
+                t1 = max(segment["start"] - extend_duration, 0)
+                t2 = min(segment["end"] + extend_duration, MAX_DURATION)
+
+                # use prev_t2 as current t1 if it"s later
+                if start_from_previous and t1 < prev_t2:
+                    t1 = prev_t2
+
+                # check if timestamp range is still valid
+                if t1 >= MAX_DURATION:
+                    print("Failed to align segment: original start time longer than audio duration, skipping...")
+                    break
+                if t2 - t1 < 0.02:
+                    print("Failed to align segment: duration smaller than 0.02s time precision")
+                    break
+
+            f1 = int(t1 * SAMPLE_RATE)
+            f2 = int(t2 * SAMPLE_RATE)
+
+            waveform_segment = audio[:, f1:f2]
+
+            with torch.inference_mode():
+                if model_type == "torchaudio":
+                    emissions, _ = model(waveform_segment.to(device))
+                elif model_type == "huggingface":
+                    emissions = model(waveform_segment.to(device)).logits
+                else:
+                    raise NotImplementedError(f"Align model of type {model_type} not supported.")
+                emissions = torch.log_softmax(emissions, dim=-1)
+
+            emission = emissions[0].cpu().detach()
+
+            trellis = get_trellis(emission, tokens)
+            path = backtrack(trellis, emission, tokens)
+            if path is None:
+                print(f'Failed to align segment ("{segment["text"]}"): backtrack failed, resorting to original...')
+                break
+            char_segments = merge_repeats(path, transcription_cleaned)
+            # word_segments = merge_words(char_segments)
+
+            # sub-segments
+            if "seg-text" not in segment:
+                segment["seg-text"] = [transcription]
+
+            v = 0
+            seg_lens = [0] + [len(x) for x in segment["seg-text"]]
+            seg_lens_cumsum = [v := v + n for n in seg_lens]
+            sub_seg_idx = 0
+
+            wdx = 0
+            duration = t2 - t1
+            ratio = duration * waveform_segment.size(0) / (trellis.size(0) - 1)
+            for cdx, char in enumerate(transcription + " "):
+                is_last = False
+                if cdx == len(transcription):
+                    break
+                elif cdx+1 == len(transcription):
+                    is_last = True
+
+                start, end, score = None, None, None
+                if cdx in clean_cdx:
+                    char_seg = char_segments[clean_cdx.index(cdx)]
+                    start = char_seg.start * ratio + t1
+                    end = char_seg.end * ratio + t1
+                    score = char_seg.score
+
+                char_segments_arr["char"].append(char)
+                char_segments_arr["start"].append(start)
+                char_segments_arr["end"].append(end)
+                char_segments_arr["score"].append(score)
+                char_segments_arr["word-idx"].append(wdx)
+                char_segments_arr["segment-idx"].append(sdx)
+                char_segments_arr["subsegment-idx"].append(sub_seg_idx)
+
+                # word-level info
+                if model_lang in LANGUAGES_WITHOUT_SPACES:
+                    # character == word
+                    wdx += 1
+                elif is_last or transcription[cdx+1] == " " or cdx == seg_lens_cumsum[sub_seg_idx+1] - 1:
+                    wdx += 1
+
+                if is_last or cdx == seg_lens_cumsum[sub_seg_idx+1] - 1:
+                    wdx = 0
+                    sub_seg_idx += 1
+
+            prev_t2 = segment["end"]
+
+            segment_align_success = True
+            # end while True loop
+            break
+
+        # reset prev_t2 due to drifting issues
+        if not segment_align_success:
+            prev_t2 = 0
+
+    char_segments_arr = pd.DataFrame(char_segments_arr)
+    not_space = char_segments_arr["char"] != " "
+
+    per_seg_grp = char_segments_arr.groupby(["segment-idx", "subsegment-idx"], as_index = False)
+    char_segments_arr = per_seg_grp.apply(lambda x: x.reset_index(drop = True)).reset_index()
+    per_word_grp = char_segments_arr[not_space].groupby(["segment-idx", "subsegment-idx", "word-idx"])
+    per_subseg_grp = char_segments_arr[not_space].groupby(["segment-idx", "subsegment-idx"])
+    per_seg_grp = char_segments_arr[not_space].groupby(["segment-idx"])
+    char_segments_arr["local-char-idx"] = char_segments_arr.groupby(["segment-idx", "subsegment-idx"]).cumcount()
+    per_word_grp = char_segments_arr[not_space].groupby(["segment-idx", "subsegment-idx", "word-idx"]) # regroup
+
+    word_segments_arr = {}
+
+    # start of word is first char with a timestamp
+    word_segments_arr["start"] = per_word_grp["start"].min().values
+    # end of word is last char with a timestamp
+    word_segments_arr["end"] = per_word_grp["end"].max().values
+    # score of word is mean (excluding nan)
+    word_segments_arr["score"] = per_word_grp["score"].mean().values
+
+    word_segments_arr["segment-text-start"] = per_word_grp["local-char-idx"].min().astype(int).values
+    word_segments_arr["segment-text-end"] = per_word_grp["local-char-idx"].max().astype(int).values+1
+    word_segments_arr = pd.DataFrame(word_segments_arr)
+
+    word_segments_arr[["segment-idx", "subsegment-idx", "word-idx"]] = per_word_grp["local-char-idx"].min().reset_index()[["segment-idx", "subsegment-idx", "word-idx"]].astype(int)
+    segments_arr = {}
+    segments_arr["start"] = per_subseg_grp["start"].min().reset_index()["start"]
+    segments_arr["end"] = per_subseg_grp["end"].max().reset_index()["end"]
+    segments_arr = pd.DataFrame(segments_arr)
+    segments_arr[["segment-idx", "subsegment-idx-start"]] = per_subseg_grp["start"].min().reset_index()[["segment-idx", "subsegment-idx"]]
+    segments_arr["subsegment-idx-end"] = segments_arr["subsegment-idx-start"] + 1
+
+    # interpolate missing words / sub-segments
+    if interpolate_method != "ignore":
+        wrd_subseg_grp = word_segments_arr.groupby(["segment-idx", "subsegment-idx"], group_keys=False)
+        wrd_seg_grp = word_segments_arr.groupby(["segment-idx"], group_keys=False)
+        # we still know which word timestamps are interpolated because their score == nan
+        word_segments_arr["start"] = wrd_subseg_grp['start'].apply(lambda group: interpolate_nans(group, method=interpolate_method))
+        word_segments_arr["end"] = wrd_subseg_grp['end'].apply(lambda group: interpolate_nans(group, method=interpolate_method))
+
+        word_segments_arr["start"] = wrd_seg_grp['start'].apply(lambda group: interpolate_nans(group, method=interpolate_method))
+        word_segments_arr["end"] = wrd_seg_grp['end'].apply(lambda group: interpolate_nans(group, method=interpolate_method))
+
+        sub_seg_grp = segments_arr.groupby(["segment-idx"], group_keys=False)
+        segments_arr['start'] = sub_seg_grp['start'].apply(lambda group: interpolate_nans(group, method=interpolate_method))
+        segments_arr['end'] = sub_seg_grp['end'].apply(lambda group: interpolate_nans(group, method=interpolate_method))
+
+        # merge words & subsegments which are missing times
+        word_grp = word_segments_arr.groupby(["segment-idx", "subsegment-idx", "end"])
+
+        word_segments_arr["segment-text-start"] = word_grp["segment-text-start"].transform(min)
+        word_segments_arr["segment-text-end"] = word_grp["segment-text-end"].transform(max)
+        word_segments_arr.drop_duplicates(subset=["segment-idx", "subsegment-idx", "end"], inplace=True)
+
+        seg_grp_dup = segments_arr.groupby(["segment-idx", "start", "end"])
+        segments_arr["subsegment-idx-start"] = seg_grp_dup["subsegment-idx-start"].transform(min)
+        segments_arr["subsegment-idx-end"] = seg_grp_dup["subsegment-idx-end"].transform(max)
+        segments_arr.drop_duplicates(subset=["segment-idx", "subsegment-idx-start", "subsegment-idx-end"], inplace=True)
+    else:
+        word_segments_arr.dropna(inplace=True)
+        segments_arr.dropna(inplace=True)
+
+    # if some segments still have missing timestamps (usually because all numerals / symbols), then use original timestamps...
+    segments_arr['start'].fillna(pd.Series([x['start'] for x in transcript]), inplace=True)
+    segments_arr['end'].fillna(pd.Series([x['end'] for x in transcript]), inplace=True)
+    segments_arr['subsegment-idx-start'].fillna(0, inplace=True)
+    segments_arr['subsegment-idx-end'].fillna(1, inplace=True)
+
+    aligned_segments = []
+    aligned_segments_word = []
+
+    word_segments_arr.set_index(["segment-idx", "subsegment-idx"], inplace=True)
+    char_segments_arr.set_index(["segment-idx", "subsegment-idx", "word-idx"], inplace=True)
+
+    for sdx, srow in segments_arr.iterrows():
+
+        seg_idx = int(srow["segment-idx"])
+        try:
+            sub_start = int(srow["subsegment-idx-start"])
+        except:
+            import pdb; pdb.set_trace()
+        sub_end = int(srow["subsegment-idx-end"])
+
+        seg = transcript[seg_idx]
+        text = "".join(seg["seg-text"][sub_start:sub_end])
+
+        wseg = word_segments_arr.loc[seg_idx].loc[sub_start:sub_end-1]
+        wseg["start"].fillna(srow["start"], inplace=True)
+        wseg["end"].fillna(srow["end"], inplace=True)
+        wseg["segment-text-start"].fillna(0, inplace=True)
+        wseg["segment-text-end"].fillna(len(text)-1, inplace=True)
+
+        cseg = char_segments_arr.loc[seg_idx].loc[sub_start:sub_end-1]
+        # fixes bug for single segment in transcript
+        cseg['segment-text-start'] = cseg['level_1'] if 'level_1' in cseg else 0
+        cseg['segment-text-end'] = cseg['level_1'] + 1 if 'level_1' in cseg else 1
+        if 'level_1' in cseg: del cseg['level_1']
+        if 'level_0' in cseg: del cseg['level_0']
+        cseg.reset_index(inplace=True)
+        aligned_segments.append(
+            {
+                "start": srow["start"],
+                "end": srow["end"],
+                "text": text,
+                "word-segments": wseg,
+                "char-segments": cseg
+            }
+        )
+
+        def get_raw_text(word_row):
+            return seg["seg-text"][word_row.name][int(word_row["segment-text-start"]):int(word_row["segment-text-end"])+1]
+
+        wdx = 0
+        curr_text = get_raw_text(wseg.iloc[wdx])
+        if len(wseg) > 1:
+            for _, wrow in wseg.iloc[1:].iterrows():
+                if wrow['start'] != wseg.iloc[wdx]['start']:
+                    aligned_segments_word.append(
+                        {
+                            "text": curr_text.strip(),
+                            "start": wseg.iloc[wdx]["start"],
+                            "end": wseg.iloc[wdx]["end"],
+                        }
+                    )
+                    curr_text = ""
+                curr_text += " " + get_raw_text(wrow)
+                wdx += 1
+        aligned_segments_word.append(
+            {
+                "text": curr_text.strip(),
+                "start": wseg.iloc[wdx]["start"],
+                "end": wseg.iloc[wdx]["end"]
+            }
+        )
+
+    return {"segments": aligned_segments, "word_segments": aligned_segments_word}
+
+
 """
 source: https://pytorch.org/tutorials/intermediate/forced_alignment_with_torchaudio_tutorial.html
 """
-import torch
-from dataclasses import dataclass

 def get_trellis(emission, tokens, blank_id=0):
     num_frame = emission.size(0)
     num_tokens = len(tokens)
```
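For orientation, a small consumer sketch (not part of the diff) of the dictionary `align` returns; the `segments` and `word_segments` keys and their fields come from the return statement and the append calls above.

```python
# `result_aligned` is the dict returned by align(); illustrative printing only.
def dump_alignment(result_aligned):
    # flat word-level entries: {"text", "start", "end"}, times in seconds (may be missing)
    for word in result_aligned["word_segments"]:
        print(word["start"], word["end"], word["text"])

    # segment-level entries also carry the per-word and per-character tables
    for seg in result_aligned["segments"]:
        print(seg["start"], seg["end"], seg["text"])
        print(seg["word-segments"].head())
```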
whisperx/diarize.py (new file, +56)

```diff
@@ -0,0 +1,56 @@
+import numpy as np
+import pandas as pd
+
+def assign_word_speakers(diarize_df, result_segments, fill_nearest=False):
+
+    for seg in result_segments:
+        wdf = seg['word-segments']
+        if len(wdf['start'].dropna()) == 0:
+            wdf['start'] = seg['start']
+            wdf['end'] = seg['end']
+        speakers = []
+        for wdx, wrow in wdf.iterrows():
+            if not np.isnan(wrow['start']):
+                diarize_df['intersection'] = np.minimum(diarize_df['end'], wrow['end']) - np.maximum(diarize_df['start'], wrow['start'])
+                diarize_df['union'] = np.maximum(diarize_df['end'], wrow['end']) - np.minimum(diarize_df['start'], wrow['start'])
+                # remove no hit
+                if not fill_nearest:
+                    dia_tmp = diarize_df[diarize_df['intersection'] > 0]
+                else:
+                    dia_tmp = diarize_df
+                if len(dia_tmp) == 0:
+                    speaker = None
+                else:
+                    speaker = dia_tmp.sort_values("intersection", ascending=False).iloc[0][2]
+            else:
+                speaker = None
+            speakers.append(speaker)
+        seg['word-segments']['speaker'] = speakers
+        seg["speaker"] = pd.Series(speakers).value_counts().index[0]
+
+    # create word level segments for .srt
+    word_seg = []
+    for seg in result_segments:
+        wseg = pd.DataFrame(seg["word-segments"])
+        for wdx, wrow in wseg.iterrows():
+            if wrow["start"] is not None:
+                speaker = wrow['speaker']
+                if speaker is None or speaker == np.nan:
+                    speaker = "UNKNOWN"
+                word_seg.append(
+                    {
+                        "start": wrow["start"],
+                        "end": wrow["end"],
+                        "text": f"[{speaker}]: " + seg["text"][int(wrow["segment-text-start"]):int(wrow["segment-text-end"])]
+                    }
+                )
+
+    # TODO: create segments but split words on new speaker
+
+    return result_segments, word_seg
+
+
+class Segment:
+    def __init__(self, start, end, speaker=None):
+        self.start = start
+        self.end = end
+        self.speaker = speaker
```
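A self-contained toy of how `assign_word_speakers` expects its inputs, with made-up numbers: a diarization table whose third column holds the speaker label (that is what `.iloc[0][2]` reads above), and aligned segments carrying a `word-segments` table with per-word times and character offsets.

```python
import pandas as pd
from whisperx.diarize import assign_word_speakers

# Hypothetical diarization output: three columns, speaker label in the third.
diarize_df = pd.DataFrame({
    "start": [0.0, 4.2],
    "end": [4.2, 9.0],
    "speaker": ["SPEAKER_00", "SPEAKER_01"],
})

# Hypothetical aligned segment, shaped like the output of whisperx.alignment.align().
segments = [{
    "start": 0.5, "end": 8.5, "text": "hello there general kenobi",
    "word-segments": pd.DataFrame({
        "start": [0.5, 1.0, 2.0, 6.0],
        "end": [0.9, 1.4, 2.8, 6.9],
        "segment-text-start": [0, 6, 12, 20],
        "segment-text-end": [5, 11, 19, 26],
    }),
}]

segments, word_seg = assign_word_speakers(diarize_df, segments)
print(segments[0]["speaker"])   # majority speaker for the segment, here SPEAKER_00
print(word_seg[0]["text"])      # "[SPEAKER_00]: hello"
```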
```diff
@@ -5,36 +5,19 @@ from typing import List, Optional, Tuple, Union, Iterator, TYPE_CHECKING

 import numpy as np
 import torch
-import torchaudio
-from transformers import AutoProcessor, Wav2Vec2ForCTC
 import tqdm
 from .audio import SAMPLE_RATE, N_FRAMES, HOP_LENGTH, CHUNK_LENGTH, pad_or_trim, log_mel_spectrogram, load_audio
-from .alignment import get_trellis, backtrack, merge_repeats, merge_words
+from .alignment import load_align_model, align, get_trellis, backtrack, merge_repeats, merge_words
 from .decoding import DecodingOptions, DecodingResult
+from .diarize import assign_word_speakers, Segment
 from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
 from .utils import exact_div, format_timestamp, optional_int, optional_float, str2bool, interpolate_nans, write_txt, write_vtt, write_srt, write_ass, write_tsv
+from .vad import Binarize
 import pandas as pd

 if TYPE_CHECKING:
     from .model import Whisper

-LANGUAGES_WITHOUT_SPACES = ["ja", "zh"]
-
-DEFAULT_ALIGN_MODELS_TORCH = {
-    "en": "WAV2VEC2_ASR_BASE_960H",
-    "fr": "VOXPOPULI_ASR_BASE_10K_FR",
-    "de": "VOXPOPULI_ASR_BASE_10K_DE",
-    "es": "VOXPOPULI_ASR_BASE_10K_ES",
-    "it": "VOXPOPULI_ASR_BASE_10K_IT",
-}
-
-DEFAULT_ALIGN_MODELS_HF = {
-    "ja": "jonatasgrosman/wav2vec2-large-xlsr-53-japanese",
-    "zh": "jonatasgrosman/wav2vec2-large-xlsr-53-chinese-zh-cn",
-    "nl": "jonatasgrosman/wav2vec2-large-xlsr-53-dutch",
-    "uk": "Yehor/wav2vec2-xls-r-300m-uk-with-small-lm",
-    "pt": "jonatasgrosman/wav2vec2-large-xlsr-53-portuguese",
-}
-
 def transcribe(
```
```diff
@@ -273,362 +256,29 @@ def transcribe
     return dict(text=tokenizer.decode(all_tokens[len(initial_prompt):]), segments=all_segments, language=language)


-def align(
-    transcript: Iterator[dict],
-    model: torch.nn.Module,
-    align_model_metadata: dict,
-    audio: Union[str, np.ndarray, torch.Tensor],
-    device: str,
-    extend_duration: float = 0.0,
-    start_from_previous: bool = True,
-    interpolate_method: str = "nearest",
-):
-    """
-    Force align phoneme recognition predictions to known transcription
-
-    Parameters
-    ----------
-    transcript: Iterator[dict]
-        The Whisper model instance
-
-    model: torch.nn.Module
-        Alignment model (wav2vec2)
-
-    audio: Union[str, np.ndarray, torch.Tensor]
-        The path to the audio file to open, or the audio waveform
-
-    device: str
-        cuda device
-
-    extend_duration: float
-        Amount to pad input segments by. If not using vad--filter then recommended to use 2 seconds
-
-        If the gzip compression ratio is above this value, treat as failed
-
-    interpolate_method: str ["nearest", "linear", "ignore"]
-        Method to assign timestamps to non-aligned words. Words are not able to be aligned when none of the characters occur in the align model dictionary.
-        "nearest" copies timestamp of nearest word within the segment. "linear" is linear interpolation. "drop" removes that word from output.
-
-    Returns
-    -------
-    A dictionary containing the resulting text ("text") and segment-level details ("segments"), and
-    the spoken language ("language"), which is detected when `decode_options["language"]` is None.
-    """
-    if not torch.is_tensor(audio):
-        if isinstance(audio, str):
-            audio = load_audio(audio)
-        audio = torch.from_numpy(audio)
-    if len(audio.shape) == 1:
-        audio = audio.unsqueeze(0)
-
-    MAX_DURATION = audio.shape[1] / SAMPLE_RATE
-
-    model_dictionary = align_model_metadata["dictionary"]
-    model_lang = align_model_metadata["language"]
-    model_type = align_model_metadata["type"]
-
-    aligned_segments = []
-
-    prev_t2 = 0
-    for segment in transcript:
-        aligned_subsegments = []
-        while True:
-            segment_align_success = False
-
-            # strip spaces at beginning / end, but keep track of the amount.
-            num_leading = len(segment["text"]) - len(segment["text"].lstrip())
-            num_trailing = len(segment["text"]) - len(segment["text"].rstrip())
-            transcription = segment["text"]
-
-            # TODO: convert number tokenizer / symbols to phonetic words for alignment.
-            # e.g. "$300" -> "three hundred dollars"
-            # currently "$300" is ignored since no characters present in the phonetic dictionary
-
-            # split into words
-            if model_lang not in LANGUAGES_WITHOUT_SPACES:
-                per_word = transcription.split(" ")
-            else:
-                per_word = transcription
-
-            # first check that characters in transcription can be aligned (they are contained in align model"s dictionary)
-            clean_char, clean_cdx = [], []
-            for cdx, char in enumerate(transcription):
-                char_ = char.lower()
-                # wav2vec2 models use "|" character to represent spaces
-                if model_lang not in LANGUAGES_WITHOUT_SPACES:
-                    char_ = char_.replace(" ", "|")
-
-                # ignore whitespace at beginning and end of transcript
-                if cdx < num_leading:
-                    pass
-                elif cdx > len(transcription) - num_trailing - 1:
-                    pass
-                elif char_ in model_dictionary.keys():
-                    clean_char.append(char_)
-                    clean_cdx.append(cdx)
-
-            clean_wdx = []
-            for wdx, wrd in enumerate(per_word):
-                if any([c in model_dictionary.keys() for c in wrd]):
-                    clean_wdx.append(wdx)
-
-            # if no characters are in the dictionary, then we skip this segment...
-            if len(clean_char) == 0:
-                print("Failed to align segment: no characters in this segment found in model dictionary, resorting to original...")
-                break
-
-            transcription_cleaned = "".join(clean_char)
-            tokens = [model_dictionary[c] for c in transcription_cleaned]
-
-            # pad according original timestamps
-            t1 = max(segment["start"] - extend_duration, 0)
-            t2 = min(segment["end"] + extend_duration, MAX_DURATION)
-
-            # use prev_t2 as current t1 if it"s later
-            if start_from_previous and t1 < prev_t2:
-                t1 = prev_t2
-
-            # check if timestamp range is still valid
-            if t1 >= MAX_DURATION:
-                print("Failed to align segment: original start time longer than audio duration, skipping...")
-                break
-            if t2 - t1 < 0.02:
-                print("Failed to align segment: duration smaller than 0.02s time precision")
-                break
-
-            f1 = int(t1 * SAMPLE_RATE)
-            f2 = int(t2 * SAMPLE_RATE)
-
-            waveform_segment = audio[:, f1:f2]
-
-            with torch.inference_mode():
-                if model_type == "torchaudio":
-                    emissions, _ = model(waveform_segment.to(device))
-                elif model_type == "huggingface":
-                    emissions = model(waveform_segment.to(device)).logits
-                else:
-                    raise NotImplementedError(f"Align model of type {model_type} not supported.")
-                emissions = torch.log_softmax(emissions, dim=-1)
-
-            emission = emissions[0].cpu().detach()
-
-            trellis = get_trellis(emission, tokens)
-            path = backtrack(trellis, emission, tokens)
-            if path is None:
-                print("Failed to align segment: backtrack failed, resorting to original...")
-                break
-            char_segments = merge_repeats(path, transcription_cleaned)
-            # word_segments = merge_words(char_segments)
-
-            # sub-segments
-            if "seg-text" not in segment:
-                segment["seg-text"] = [transcription]
-
-            v = 0
-            seg_lens = [0] + [len(x) for x in segment["seg-text"]]
-            seg_lens_cumsum = [v := v + n for n in seg_lens]
-            sub_seg_idx = 0
-
-            char_level = {
-                "start": [],
-                "end": [],
-                "score": [],
-                "word-index": [],
-            }
-
-            word_level = {
-                "start": [],
-                "end": [],
-                "score": [],
-                "segment-text-start": [],
-                "segment-text-end": []
-            }
-
-            wdx = 0
-            seg_start_actual, seg_end_actual = None, None
-            duration = t2 - t1
-            ratio = duration * waveform_segment.size(0) / (trellis.size(0) - 1)
-            cdx_prev = 0
-            for cdx, char in enumerate(transcription + " "):
-                is_last = False
-                if cdx == len(transcription):
-                    break
-                elif cdx+1 == len(transcription):
-                    is_last = True
-
-                start, end, score = None, None, None
-                if cdx in clean_cdx:
-                    char_seg = char_segments[clean_cdx.index(cdx)]
-                    start = char_seg.start * ratio + t1
-                    end = char_seg.end * ratio + t1
-                    score = char_seg.score
-
-                char_level["start"].append(start)
-                char_level["end"].append(end)
-                char_level["score"].append(score)
-                char_level["word-index"].append(wdx)
-
-                # word-level info
-                if model_lang in LANGUAGES_WITHOUT_SPACES:
-                    # character == word
-                    wdx += 1
-                elif is_last or transcription[cdx+1] == " " or cdx == seg_lens_cumsum[sub_seg_idx+1] - 1:
-                    wdx += 1
-                    word_level["start"].append(None)
-                    word_level["end"].append(None)
-                    word_level["score"].append(None)
-                    word_level["segment-text-start"].append(cdx_prev-seg_lens_cumsum[sub_seg_idx])
-                    word_level["segment-text-end"].append(cdx+1-seg_lens_cumsum[sub_seg_idx])
-                    cdx_prev = cdx+2
-
-                if is_last or cdx == seg_lens_cumsum[sub_seg_idx+1] - 1:
-                    if model_lang not in LANGUAGES_WITHOUT_SPACES:
-                        char_level = pd.DataFrame(char_level)
-                        word_level = pd.DataFrame(word_level)
-
-                        not_space = pd.Series(list(segment["seg-text"][sub_seg_idx])) != " "
-                        word_level["start"] = char_level[not_space].groupby("word-index")["start"].min() # take min of all chars in a word ignoring space
-                        word_level["end"] = char_level[not_space].groupby("word-index")["end"].max() # take max of all chars in a word
-
-                        # fill missing
-                        if interpolate_method != "ignore":
-                            word_level["start"] = interpolate_nans(word_level["start"], method=interpolate_method)
-                            word_level["end"] = interpolate_nans(word_level["end"], method=interpolate_method)
-                        word_level["start"] = word_level["start"].values.tolist()
-                        word_level["end"] = word_level["end"].values.tolist()
-                        word_level["score"] = char_level.groupby("word-index")["score"].mean() # take mean of all scores
-
-                        char_level = char_level.replace({np.nan:None}).to_dict("list")
-                        word_level = pd.DataFrame(word_level).replace({np.nan:None}).to_dict("list")
-                    else:
-                        word_level = None
-
-                    aligned_subsegments.append(
-                        {
-                            "text": segment["seg-text"][sub_seg_idx],
-                            "start": seg_start_actual,
-                            "end": seg_end_actual,
-                            "char-segments": char_level,
-                            "word-segments": word_level
-                        }
-                    )
-                    if "language" in segment:
-                        aligned_subsegments[-1]["language"] = segment["language"]
-
-                    char_level = {
-                        "start": [],
-                        "end": [],
-                        "score": [],
-                        "word-index": [],
-                    }
-                    word_level = {
-                        "start": [],
-                        "end": [],
-                        "score": [],
-                        "segment-text-start": [],
-                        "segment-text-end": []
-                    }
-                    wdx = 0
-                    cdx_prev = cdx + 2
-                    sub_seg_idx += 1
-                    seg_start_actual, seg_end_actual = None, None
-
-                # take min-max for actual segment-level timestamp
-                if seg_start_actual is None and start is not None:
-                    seg_start_actual = start
-                if end is not None:
-                    seg_end_actual = end
-
-            prev_t2 = segment["end"]
-
-            segment_align_success = True
-            # end while True loop
-            break
-
-        # reset prev_t2 due to drifting issues
-        if not segment_align_success:
-            prev_t2 = 0
-
-        start = interpolate_nans(pd.DataFrame(aligned_subsegments)["start"], method=interpolate_method)
-        end = interpolate_nans(pd.DataFrame(aligned_subsegments)["end"], method=interpolate_method)
-        for idx, seg in enumerate(aligned_subsegments):
-            seg['start'] = start.iloc[idx]
-            seg['end'] = end.iloc[idx]
-
-        aligned_segments += aligned_subsegments
-
-    # create word level segments for .srt
-    word_seg = []
-    for seg in aligned_segments:
-        if model_lang in LANGUAGES_WITHOUT_SPACES:
-            # character based
-            seg["word-segments"] = seg["char-segments"]
-            seg["word-segments"]["segment-text-start"] = range(len(seg['word-segments']['start']))
-            seg["word-segments"]["segment-text-end"] = range(1, len(seg['word-segments']['start'])+1)
-
-        wseg = pd.DataFrame(seg["word-segments"]).replace({np.nan:None})
-        for wdx, wrow in wseg.iterrows():
-            if wrow["start"] is not None:
-                word_seg.append(
-                    {
-                        "start": wrow["start"],
-                        "end": wrow["end"],
-                        "text": seg["text"][int(wrow["segment-text-start"]):int(wrow["segment-text-end"])]
-                    }
-                )
-
-    return {"segments": aligned_segments, "word_segments": word_seg}
-
-
-def load_align_model(language_code, device, model_name=None):
-    if model_name is None:
-        # use default model
-        if language_code in DEFAULT_ALIGN_MODELS_TORCH:
-            model_name = DEFAULT_ALIGN_MODELS_TORCH[language_code]
-        elif language_code in DEFAULT_ALIGN_MODELS_HF:
-            model_name = DEFAULT_ALIGN_MODELS_HF[language_code]
-        else:
-            print(f"There is no default alignment model set for this language ({language_code}).\
-                Please find a wav2vec2.0 model finetuned on this language in https://huggingface.co/models, then pass the model name in --align_model [MODEL_NAME]")
-            raise ValueError(f"No default align-model for language: {language_code}")
-
-    if model_name in torchaudio.pipelines.__all__:
-        pipeline_type = "torchaudio"
-        bundle = torchaudio.pipelines.__dict__[model_name]
-        align_model = bundle.get_model().to(device)
-        labels = bundle.get_labels()
-        align_dictionary = {c.lower(): i for i, c in enumerate(labels)}
-    else:
-        try:
-            processor = AutoProcessor.from_pretrained(model_name)
-            align_model = Wav2Vec2ForCTC.from_pretrained(model_name)
-        except Exception as e:
-            print(e)
-            print(f"Error loading model from huggingface, check https://huggingface.co/models for finetuned wav2vec2.0 models")
-            raise ValueError(f'The chosen align_model "{model_name}" could not be found in huggingface (https://huggingface.co/models) or torchaudio (https://pytorch.org/audio/stable/pipelines.html#id14)')
-        pipeline_type = "huggingface"
-        align_model = align_model.to(device)
-        labels = processor.tokenizer.get_vocab()
-        align_dictionary = {char.lower(): code for char,code in processor.tokenizer.get_vocab().items()}
-
-    align_metadata = {"language": language_code, "dictionary": align_dictionary, "type": pipeline_type}
-
-    return align_model, align_metadata
-
-
 def merge_chunks(segments, chunk_size=CHUNK_LENGTH):
     """
-    Merge VAD segments into larger segments of size ~CHUNK_LENGTH.
+    Merge VAD segments into larger segments of approximately size ~CHUNK_LENGTH.
+    TODO: Make sure VAD segment isn't too long, otherwise it will cause OOM when input to alignment model
+    TODO: Or sliding window alignment model over long segment.
     """
-    curr_start = 0
     curr_end = 0
     merged_segments = []
     seg_idxs = []
     speaker_idxs = []
-    for sdx, seg in enumerate(segments):
+    assert chunk_size > 0
+    binarize = Binarize(max_duration=chunk_size)
+    segments = binarize(segments)
+    segments_list = []
+    for speech_turn in segments.get_timeline():
+        segments_list.append(Segment(speech_turn.start, speech_turn.end, "UNKNOWN"))
+
+    assert segments_list, "segments_list is empty."
+    # Make sur the starting point is the start of the segment.
+    curr_start = segments_list[0].start
+
+    for seg in segments_list:
         if seg.end - curr_start > chunk_size and curr_end-curr_start > 0:
             merged_segments.append({
                 "start": curr_start,
```
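The merging rule itself is compact; here is a standalone sketch of the same greedy idea on plain (start, end) pairs, illustrative only (the real function above works on a pyannote timeline and `Segment` objects):

```python
def merge_intervals(intervals, chunk_size=30.0):
    """Greedily pack consecutive speech intervals into chunks of at most ~chunk_size seconds."""
    merged = []
    curr_start, curr_end = intervals[0]
    for start, end in intervals[1:]:
        if end - curr_start > chunk_size and curr_end - curr_start > 0:
            merged.append({"start": curr_start, "end": curr_end})
            curr_start = start
        curr_end = end
    merged.append({"start": curr_start, "end": curr_end})
    return merged

print(merge_intervals([(0.0, 8.0), (9.0, 21.5), (22.0, 29.0), (30.5, 41.0)]))
# -> [{'start': 0.0, 'end': 29.0}, {'start': 30.5, 'end': 41.0}]
```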
```diff
@@ -668,12 +318,9 @@ def transcribe_with_vad
     prev = 0
     output = {"segments": []}

-    vad_segments_list = []
     vad_segments = vad_pipeline(audio)
-    for speech_turn in vad_segments.get_timeline().support():
-        vad_segments_list.append(Segment(speech_turn.start, speech_turn.end, "UNKNOWN"))
     # merge segments to approx 30s inputs to make whisper most appropraite
-    vad_segments = merge_chunks(vad_segments_list)
+    vad_segments = merge_chunks(vad_segments)

     for sdx, seg_t in enumerate(vad_segments):
         if verbose:
```
@@ -702,56 +349,223 @@ def transcribe_with_vad(
     return output


-def assign_word_speakers(diarize_df, result_segments, fill_nearest=False):
-    for seg in result_segments:
-        wdf = pd.DataFrame(seg['word-segments'])
-        if len(wdf['start'].dropna()) == 0:
-            wdf['start'] = seg['start']
-            wdf['end'] = seg['end']
-        speakers = []
-        for wdx, wrow in wdf.iterrows():
-            diarize_df['intersection'] = np.minimum(diarize_df['end'], wrow['end']) - np.maximum(diarize_df['start'], wrow['start'])
-            diarize_df['union'] = np.maximum(diarize_df['end'], wrow['end']) - np.minimum(diarize_df['start'], wrow['start'])
-            # remove no hit
-            if not fill_nearest:
-                dia_tmp = diarize_df[diarize_df['intersection'] > 0]
-            else:
-                dia_tmp = diarize_df
-            if len(dia_tmp) == 0:
-                speaker = None
-            else:
-                speaker = dia_tmp.sort_values("intersection", ascending=False).iloc[0][2]
-            speakers.append(speaker)
-        seg['word-segments']['speaker'] = speakers
-        seg["speaker"] = pd.Series(speakers).value_counts().index[0]
-
-    # create word level segments for .srt
-    word_seg = []
-    for seg in result_segments:
-        wseg = pd.DataFrame(seg["word-segments"])
-        for wdx, wrow in wseg.iterrows():
-            if wrow["start"] is not None:
-                speaker = wrow['speaker']
-                if speaker is None or speaker == np.nan:
-                    speaker = "UNKNOWN"
-                word_seg.append(
-                    {
-                        "start": wrow["start"],
-                        "end": wrow["end"],
-                        "text": f"[{speaker}]: " + seg["text"][int(wrow["segment-text-start"]):int(wrow["segment-text-end"])]
-                    }
-                )
-
-    # TODO: create segments but split words on new speaker
-    return result_segments, word_seg
-
-
-class Segment:
-    def __init__(self, start, end, speaker=None):
-        self.start = start
-        self.end = end
-        self.speaker = speaker
+def transcribe_with_vad_parallel(
+    model: "Whisper",
+    audio: Union[str, np.ndarray, torch.Tensor],
+    vad_pipeline,
+    mel = None,
+    verbose: Optional[bool] = None,
+    batch_size = -1,
+    **kwargs
+):
+    """
+    Transcribe per VAD segment
+    """
+
+    if mel is None:
+        mel = log_mel_spectrogram(audio)
+
+    vad_segments = vad_pipeline(audio)
+    # merge segments to approx 30s inputs to make whisper most appropriate
+    vad_segments = merge_chunks(vad_segments)
+
+    ################################
+    ### START of parallelization ###
+    ################################
+
+    # pad mel to a same length
+    start_seconds = [i['start'] for i in vad_segments]
+    end_seconds = [i['end'] for i in vad_segments]
+    duration_list = np.array(end_seconds) - np.array(start_seconds)
+    max_length = round(30 / (HOP_LENGTH / SAMPLE_RATE))
+    offset_list = np.array(start_seconds)
+    chunks = []
+    for start_ts, end_ts in zip(start_seconds, end_seconds):
+        start_ts = round(start_ts / (HOP_LENGTH / SAMPLE_RATE))
+        end_ts = round(end_ts / (HOP_LENGTH / SAMPLE_RATE))
+        chunk = mel[:, start_ts:end_ts]
+        chunk = torch.nn.functional.pad(chunk, (0, max_length-chunk.shape[-1]))
+        chunks.append(chunk)
+
+    mel_chunk = torch.stack(chunks, dim=0).to(model.device)
+    # using 'decode_options1': only support single temperature decoding (no fallbacks)
+    # result_list2 = model.decode(mel_chunk, decode_options1)
+
+    # prepare DecodingOptions
+    temperatures = kwargs.pop("temperature", None)
+    compression_ratio_threshold = kwargs.pop("compression_ratio_threshold", None)
+    logprob_threshold = kwargs.pop("logprob_threshold", None)
+    no_speech_threshold = kwargs.pop("no_speech_threshold", None)
+    condition_on_previous_text = kwargs.pop("condition_on_previous_text", None)
+    initial_prompt = kwargs.pop("initial_prompt", None)
+
+    t = 0  # TODO: does not support temperature sweeping
+    if t > 0:
+        # disable beam_size and patience when t > 0
+        kwargs.pop("beam_size", None)
+        kwargs.pop("patience", None)
+    else:
+        # disable best_of when t == 0
+        kwargs.pop("best_of", None)
+
+    options = DecodingOptions(**kwargs, temperature=t)
+    mel_chunk_batches = torch.split(mel_chunk, split_size_or_sections=batch_size)
+    decode_result = []
+    for mel_chunk_batch in mel_chunk_batches:
+        decode_result.extend(model.decode(mel_chunk_batch, options))
+
+    ##############################
+    ### END of parallelization ###
+    ##############################
+
+    # post processing: get segments from batch-decoded results
+    input_stride = exact_div(
+        N_FRAMES, model.dims.n_audio_ctx
+    )  # mel frames per output token: 2
+    language = kwargs["language"]
+    task = kwargs["task"]
+    tokenizer = get_tokenizer(model.is_multilingual, language=language, task=task)
+
+    output = post_process_results(
+        vad_segments,
+        decode_result,
+        duration_list,
+        offset_list,
+        input_stride,
+        language,
+        tokenizer,
+        no_speech_threshold=no_speech_threshold,
+        logprob_threshold=logprob_threshold,
+        verbose=verbose)
+    return output
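A quick check on the padding and batching above: with Whisper's audio constants (`HOP_LENGTH = 160`, `SAMPLE_RATE = 16000`), `max_length = round(30 / (160 / 16000)) = 3000` mel frames, i.e. one full 30-second encoder window, so every merged VAD chunk is right-padded to the same shape before `torch.stack`. The `batch_size` argument comes from `--parallel_bs`; `torch.split(mel_chunk, batch_size)` then yields `ceil(num_chunks / batch_size)` mini-batches, each decoded with a single `DecodingOptions` (one temperature, no fallback decoding).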
+
+
+def post_process_results(
+    vad_segments,
+    result_list,
+    duration_list,
+    offset_list,
+    input_stride,
+    language,
+    tokenizer,
+    no_speech_threshold = None,
+    logprob_threshold = None,
+    verbose: Optional[bool] = None,
+):
+
+    seek = 0
+    time_precision = (
+        input_stride * HOP_LENGTH / SAMPLE_RATE
+    )  # time per output token: 0.02 (seconds)
+    all_tokens = []
+    all_segments = []
+    output = {"segments": []}
+
+    def add_segment(
+        *, start: float, end: float, text_tokens: torch.Tensor, result: DecodingResult
+    ):
+        text = tokenizer.decode([token for token in text_tokens if token < tokenizer.eot])
+        if len(text.strip()) == 0:  # skip empty text output
+            return
+
+        all_segments.append(
+            {
+                "id": len(all_segments),
+                "seek": seek,
+                "start": start,
+                "end": end,
+                "text": text,
+                "tokens": text_tokens.tolist(),
+                "temperature": result.temperature,
+                "avg_logprob": result.avg_logprob,
+                "compression_ratio": result.compression_ratio,
+                "no_speech_prob": result.no_speech_prob,
+            }
+        )
+        if verbose:
+            print(f"[{format_timestamp(start)} --> {format_timestamp(end)}] {text}")
+
+    # process the output
+    for seg_t, result, segment_duration, timestamp_offset in zip(vad_segments, result_list, duration_list, offset_list):
+        all_tokens = []
+        all_segments = []
+
+        # segment_duration = segment.shape[-1] * HOP_LENGTH / SAMPLE_RATE
+        segment_shape = int(segment_duration / (HOP_LENGTH / SAMPLE_RATE))
+        tokens = torch.tensor(result.tokens)
+
+        if no_speech_threshold is not None:
+            # no voice activity check
+            should_skip = result.no_speech_prob > no_speech_threshold
+            if logprob_threshold is not None and result.avg_logprob > logprob_threshold:
+                # don't skip if the logprob is high enough, despite the no_speech_prob
+                should_skip = False
+
+            if should_skip:
+                seek += segment_shape  # fast-forward to the next segment boundary
+                continue
+
+        timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin)
+        consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0].add_(1)
+
+        if len(consecutive) > 0:  # if the output contains two consecutive timestamp tokens
+            last_slice = 0
+            for current_slice in consecutive:
+                sliced_tokens = tokens[last_slice:current_slice]
+                start_timestamp_position = (
+                    sliced_tokens[0].item() - tokenizer.timestamp_begin
+                )
+                end_timestamp_position = (
+                    sliced_tokens[-1].item() - tokenizer.timestamp_begin
+                )
+                add_segment(
+                    start=timestamp_offset + start_timestamp_position * time_precision,
+                    end=timestamp_offset + end_timestamp_position * time_precision,
+                    text_tokens=sliced_tokens[1:-1],
+                    result=result,
+                )
+                last_slice = current_slice
+            last_timestamp_position = (
+                tokens[last_slice - 1].item() - tokenizer.timestamp_begin
+            )
+            seek += last_timestamp_position * input_stride
+            all_tokens.extend(tokens[: last_slice + 1].tolist())
+        else:
+            duration = segment_duration
+            timestamps = tokens[timestamp_tokens.nonzero().flatten()]
+            if len(timestamps) > 0 and timestamps[-1].item() != tokenizer.timestamp_begin:
+                # no consecutive timestamps but it has a timestamp; use the last one.
+                # single timestamp at the end means no speech after the last timestamp.
+                last_timestamp_position = timestamps[-1].item() - tokenizer.timestamp_begin
+                duration = last_timestamp_position * time_precision
+
+            add_segment(
+                start=timestamp_offset,
+                end=timestamp_offset + duration,
+                text_tokens=tokens,
+                result=result,
+            )
+            seek += segment_shape
+            all_tokens.extend(tokens.tolist())
+
+        result = dict(text=tokenizer.decode(all_tokens), segments=all_segments, language=language)
+        output["segments"].append(
+            {
+                "start": seg_t["start"],
+                "end": seg_t["end"],
+                "language": result["language"],
+                "text": result["text"],
+                "seg-text": [x["text"] for x in result["segments"]],
+                "seg-start": [x["start"] for x in result["segments"]],
+                "seg-end": [x["end"] for x in result["segments"]],
+            }
+        )
+
+    output["language"] = output["segments"][0]["language"]
+
+    return output


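As a sanity check on the timestamp bookkeeping in `post_process_results`: each Whisper timestamp token encodes a position in steps of `time_precision = input_stride * HOP_LENGTH / SAMPLE_RATE = 2 * 160 / 16000 = 0.02` s, and `timestamp_offset` is simply the chunk's start time in the original audio. A minimal sketch of that conversion (the token ids below are made up for illustration, this is not whisperX code):

```python
HOP_LENGTH, SAMPLE_RATE, INPUT_STRIDE = 160, 16000, 2
TIME_PRECISION = INPUT_STRIDE * HOP_LENGTH / SAMPLE_RATE  # 0.02 s per timestamp token

def token_to_seconds(token_id: int, timestamp_begin: int, chunk_offset: float) -> float:
    """Map a Whisper timestamp token id to absolute seconds in the source audio."""
    return chunk_offset + (token_id - timestamp_begin) * TIME_PRECISION

# a token 150 steps past timestamp_begin, in a chunk that starts at 60.0 s:
print(token_to_seconds(1150, 1000, 60.0))  # 63.0
```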
 def cli():
@@ -769,14 +583,14 @@ def cli():
     parser.add_argument("--interpolate_method", default="nearest", choices=["nearest", "linear", "ignore"], help="For word .srt, method to assign timestamps to non-aligned words, or merge them into neighbouring.")
     # vad params
     parser.add_argument("--vad_filter", action="store_true", help="Whether to first perform VAD filtering to target only transcribe within VAD. Produces more accurate alignment + timestamp, requires more GPU memory & compute.")
-    parser.add_argument("--vad_input", default=None, type=str)
+    parser.add_argument("--parallel_bs", default=-1, type=int, help="Enable parallel transcribing if > 1")
     # diarization params
-    parser.add_argument("--diarize", action='store_true')
+    parser.add_argument("--diarize", action="store_true", help="Apply diarization to assign speaker labels to each segment/word")
     parser.add_argument("--min_speakers", default=None, type=int)
     parser.add_argument("--max_speakers", default=None, type=int)
     # output save params
     parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
-    parser.add_argument("--output_type", default="all", choices=["all", "srt", "srt-word", "vtt", "txt", "tsv", "ass", "ass-char"], help="File type for desired output save")
+    parser.add_argument("--output_type", default="all", choices=["all", "srt", "srt-word", "vtt", "txt", "tsv", "ass", "ass-char", "pickle", "vad"], help="File type for desired output save")

     parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")

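With `--vad_input` gone, batching is driven entirely from the command line: an invocation along the lines of `whisperx audio.wav --model large --vad_filter --parallel_bs 16 --hf_token <token>` would transcribe the merged VAD chunks in batches of 16 (illustrative command only; `--hf_token` is added in the next hunk).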
@@ -799,7 +613,8 @@ def cli():
     parser.add_argument("--logprob_threshold", type=optional_float, default=-1.0, help="if the average log probability is lower than this value, treat the decoding as failed")
     parser.add_argument("--no_speech_threshold", type=optional_float, default=0.6, help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence")
     parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supersedes MKL_NUM_THREADS/OMP_NUM_THREADS")
+    parser.add_argument("--hf_token", type=str, default=None, help="Hugging Face Access Token to access PyAnnote gated models")

     args = parser.parse_args().__dict__
     model_name: str = args.pop("model")
     model_dir: str = args.pop("model_dir")
@@ -811,25 +626,34 @@ def cli():
     align_extend: float = args.pop("align_extend")
     align_from_prev: bool = args.pop("align_from_prev")
     interpolate_method: bool = args.pop("interpolate_method")

+    hf_token: str = args.pop("hf_token")
     vad_filter: bool = args.pop("vad_filter")
-    vad_input: bool = args.pop("vad_input")
+    parallel_bs: int = args.pop("parallel_bs")

     diarize: bool = args.pop("diarize")
     min_speakers: int = args.pop("min_speakers")
     max_speakers: int = args.pop("max_speakers")

     vad_pipeline = None
-    if vad_input is not None:
-        vad_input = pd.read_csv(vad_input, header=None, sep= " ")
-    elif vad_filter:
-        from pyannote.audio import Pipeline
-        vad_pipeline = Pipeline.from_pretrained("pyannote/voice-activity-detection")
+    if vad_filter:
+        if hf_token is None:
+            print("Warning, no huggingface token used, needs to be saved in environment variable, otherwise will throw error loading VAD model...")
+        from pyannote.audio import Inference
+        vad_pipeline = Inference(
+            "pyannote/segmentation",
+            pre_aggregation_hook=lambda segmentation: segmentation,
+            use_auth_token=hf_token,
+            device=torch.device(device),
+        )

     diarize_pipeline = None
     if diarize:
+        if hf_token is None:
+            print("Warning, no --hf_token used, needs to be saved in environment variable, otherwise will throw error loading diarization model...")
         from pyannote.audio import Pipeline
-        diarize_pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization@2.1")
+        diarize_pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization@2.1",
+                                                    use_auth_token=hf_token)

     os.makedirs(output_dir, exist_ok=True)

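The `pre_aggregation_hook=lambda segmentation: segmentation` keeps the raw frame-level segmentation scores rather than pyannote's aggregated VAD output; those scores are what `Binarize` in the new `whisperx/vad.py` turns into speech regions. A minimal sketch of that flow, mirroring the commented-out example at the bottom of `vad.py` (audio path and token are hypothetical):

```python
import torch
from pyannote.audio import Inference
from whisperx.vad import Binarize

vad = Inference(
    "pyannote/segmentation",
    pre_aggregation_hook=lambda segmentation: segmentation,  # keep raw per-frame scores
    use_auth_token="hf_xxx",              # hypothetical Hugging Face token
    device=torch.device("cuda"),
)
scores = vad("example.wav")               # SlidingWindowFeature of per-frame scores
speech = Binarize(max_duration=30)(scores)  # Annotation of speech regions, each capped at 30 s
for turn in speech.get_timeline():
    print(f"{turn.start:.2f} -> {turn.end:.2f}")
```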
@@ -857,8 +681,12 @@ def cli():

     for audio_path in args.pop("audio"):
         if vad_filter:
-            print("Performing VAD...")
-            result = transcribe_with_vad(model, audio_path, vad_pipeline, temperature=temperature, **args)
+            if parallel_bs > 1:
+                print("Performing VAD and parallel transcribing ...")
+                result = transcribe_with_vad_parallel(model, audio_path, vad_pipeline, temperature=temperature, batch_size=parallel_bs, **args)
+            else:
+                print("Performing VAD...")
+                result = transcribe_with_vad(model, audio_path, vad_pipeline, temperature=temperature, **args)
         else:
             print("Performing transcription...")
             result = transcribe(model, audio_path, temperature=temperature, **args)
@@ -868,6 +696,7 @@ def cli():
             print(f"New language found ({result['language']})! Previous was ({align_metadata['language']}), loading new alignment model for new language...")
             align_model, align_metadata = load_align_model(result["language"], device)

+
         print("Performing alignment...")
         result_aligned = align(result["segments"], align_model, align_metadata, audio_path, device,
                                 extend_duration=align_extend, start_from_previous=align_from_prev, interpolate_method=interpolate_method)
@@ -901,7 +730,7 @@ def cli():

         # save TSV
         if output_type in ["tsv", "all"]:
-            with open(os.path.join(output_dir, audio_basename + ".srt"), "w", encoding="utf-8") as srt:
+            with open(os.path.join(output_dir, audio_basename + ".tsv"), "w", encoding="utf-8") as srt:
                 write_tsv(result_aligned["segments"], file=srt)

         # save SRT word-level
@@ -915,10 +744,20 @@ def cli():
             with open(os.path.join(output_dir, audio_basename + ".ass"), "w", encoding="utf-8") as ass:
                 write_ass(result_aligned["segments"], file=ass)

-        # save ASS character-level
-        if output_type in ["ass-char", "all"]:
+        # # save ASS character-level
+        if output_type in ["ass-char"]:
             with open(os.path.join(output_dir, audio_basename + ".char.ass"), "w", encoding="utf-8") as ass:
                 write_ass(result_aligned["segments"], file=ass, resolution="char")

+        # save pickle
+        if output_type in ["pickle"]:
+            exp_fp = os.path.join(output_dir, audio_basename + ".pkl")
+            pd.DataFrame(result_aligned["segments"]).to_pickle(exp_fp)
+
+        # save word tsv
+        if output_type in ["vad"]:
+            exp_fp = os.path.join(output_dir, audio_basename + ".sad")
+            wrd_segs = pd.concat([x["word-segments"] for x in result_aligned["segments"]])[['start','end']]
+            wrd_segs.to_csv(exp_fp, sep='\t', header=None, index=False)
 if __name__ == "__main__":
     cli()
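The two new `--output_type` choices map onto the branches above: `pickle` dumps the aligned segments as a pandas DataFrame to `<audio_basename>.pkl`, and `vad` writes word-level start/end times to a tab-separated `<audio_basename>.sad` file with no header.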
@@ -2,6 +2,7 @@ import os
 import zlib
 from typing import Callable, TextIO, Iterator, Tuple
 import pandas as pd
+import numpy as np

 def exact_div(x, y):
     assert x % y == 0
@@ -64,8 +65,8 @@ def write_vtt(transcript: Iterator[dict], file: TextIO):
 def write_tsv(transcript: Iterator[dict], file: TextIO):
     print("start", "end", "text", sep="\t", file=file)
     for segment in transcript:
-        print(round(1000 * segment['start']), file=file, end="\t")
-        print(round(1000 * segment['end']), file=file, end="\t")
+        print(segment['start'], file=file, end="\t")
+        print(segment['end'], file=file, end="\t")
         print(segment['text'].strip().replace("\t", " "), file=file, flush=True)


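With this change `write_tsv` emits start and end as raw seconds (e.g. `1.23` / `4.56`) rather than integer milliseconds (`1230` / `4560`).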
@@ -206,6 +207,8 @@ def write_ass(transcript: Iterator[dict],
     ass_arr = []

     for segment in transcript:
+        # if "12" in segment['text']:
+        #     import pdb; pdb.set_trace()
         if resolution_key in segment:
             res_segs = pd.DataFrame(segment[resolution_key])
             prev = segment['start']
@@ -214,7 +217,7 @@ def write_ass(transcript: Iterator[dict],
             else:
                 speaker_str = ""
             for cdx, crow in res_segs.iterrows():
-                if crow['start'] is not None:
+                if not np.isnan(crow['start']):
                     if resolution == "char":
                         idx_0 = cdx
                         idx_1 = cdx + 1
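The switch to `np.isnan` is also why `import numpy as np` was added above: characters or words that failed alignment come back from pandas as `NaN`, which passes an `is not None` test, so the old guard let empty timestamps through to the .ass writer.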
whisperx/vad.py (new file, 185 lines)
@@ -0,0 +1,185 @@
import pandas as pd
import numpy as np
from pyannote.core import Annotation, Segment, SlidingWindowFeature, Timeline
from typing import List, Tuple, Optional


class Binarize:
    """Binarize detection scores using hysteresis thresholding

    Parameters
    ----------
    onset : float, optional
        Onset threshold. Defaults to 0.5.
    offset : float, optional
        Offset threshold. Defaults to `onset`.
    min_duration_on : float, optional
        Remove active regions shorter than that many seconds. Defaults to 0s.
    min_duration_off : float, optional
        Fill inactive regions shorter than that many seconds. Defaults to 0s.
    pad_onset : float, optional
        Extend active regions by moving their start time by that many seconds.
        Defaults to 0s.
    pad_offset : float, optional
        Extend active regions by moving their end time by that many seconds.
        Defaults to 0s.
    max_duration: float
        The maximum length of an active segment, divides segment at timestamp with lowest score.

    Reference
    ---------
    Gregory Gelly and Jean-Luc Gauvain. "Minimum Word Error Training of
    RNN-based Voice Activity Detection", InterSpeech 2015.

    Pyannote-audio
    """

    def __init__(
        self,
        onset: float = 0.5,
        offset: Optional[float] = None,
        min_duration_on: float = 0.0,
        min_duration_off: float = 0.0,
        pad_onset: float = 0.0,
        pad_offset: float = 0.0,
        max_duration: float = float('inf')
    ):

        super().__init__()

        self.onset = onset
        self.offset = offset or onset

        self.pad_onset = pad_onset
        self.pad_offset = pad_offset

        self.min_duration_on = min_duration_on
        self.min_duration_off = min_duration_off

        self.max_duration = max_duration

    def __call__(self, scores: SlidingWindowFeature) -> Annotation:
        """Binarize detection scores

        Parameters
        ----------
        scores : SlidingWindowFeature
            Detection scores.

        Returns
        -------
        active : Annotation
            Binarized scores.
        """

        num_frames, num_classes = scores.data.shape
        frames = scores.sliding_window
        timestamps = [frames[i].middle for i in range(num_frames)]

        # annotation meant to store 'active' regions
        active = Annotation()
        for k, k_scores in enumerate(scores.data.T):

            label = k if scores.labels is None else scores.labels[k]

            # initial state
            start = timestamps[0]
            is_active = k_scores[0] > self.onset
            curr_scores = [k_scores[0]]
            curr_timestamps = [start]
            for t, y in zip(timestamps[1:], k_scores[1:]):
                # currently active
                if is_active:
                    curr_duration = t - start
                    if curr_duration > self.max_duration:
                        # if curr_duration > 15:
                        #     import pdb; pdb.set_trace()
                        search_after = len(curr_scores) // 2
                        # divide segment
                        min_score_div_idx = search_after + np.argmin(curr_scores[search_after:])
                        min_score_t = curr_timestamps[min_score_div_idx]
                        region = Segment(start - self.pad_onset, min_score_t + self.pad_offset)
                        active[region, k] = label
                        start = curr_timestamps[min_score_div_idx]
                        curr_scores = curr_scores[min_score_div_idx+1:]
                        curr_timestamps = curr_timestamps[min_score_div_idx+1:]
                    # switching from active to inactive
                    elif y < self.offset:
                        region = Segment(start - self.pad_onset, t + self.pad_offset)
                        active[region, k] = label
                        start = t
                        is_active = False
                        curr_scores = []
                        curr_timestamps = []
                # currently inactive
                else:
                    # switching from inactive to active
                    if y > self.onset:
                        start = t
                        is_active = True
                curr_scores.append(y)
                curr_timestamps.append(t)

            # if active at the end, add final region
            if is_active:
                region = Segment(start - self.pad_onset, t + self.pad_offset)
                active[region, k] = label

        # because of padding, some active regions might be overlapping: merge them.
        # also: fill same speaker gaps shorter than min_duration_off
        if self.pad_offset > 0.0 or self.pad_onset > 0.0 or self.min_duration_off > 0.0:
            if self.max_duration < float("inf"):
                raise NotImplementedError(f"This would break current max_duration param")
            active = active.support(collar=self.min_duration_off)

        # remove tracks shorter than min_duration_on
        if self.min_duration_on > 0:
            for segment, track in list(active.itertracks()):
                if segment.duration < self.min_duration_on:
                    del active[segment, track]

        return active


def merge_vad(vad_arr, pad_onset=0.0, pad_offset=0.0, min_duration_off=0.0, min_duration_on=0.0):

    active = Annotation()
    for k, vad_t in enumerate(vad_arr):
        region = Segment(vad_t[0] - pad_onset, vad_t[1] + pad_offset)
        active[region, k] = 1

    if pad_offset > 0.0 or pad_onset > 0.0 or min_duration_off > 0.0:
        active = active.support(collar=min_duration_off)

    # remove tracks shorter than min_duration_on
    if min_duration_on > 0:
        for segment, track in list(active.itertracks()):
            if segment.duration < min_duration_on:
                del active[segment, track]

    active = active.for_json()
    active_segs = pd.DataFrame([x['segment'] for x in active['content']])
    return active_segs


if __name__ == "__main__":
    # from pyannote.audio import Inference
    # hook = lambda segmentation: segmentation
    # inference = Inference("pyannote/segmentation", pre_aggregation_hook=hook)
    # audio = "/tmp/11962.wav"
    # scores = inference(audio)
    # binarize = Binarize(max_duration=15)
    # anno = binarize(scores)
    # res = []
    # for ann in anno.get_timeline():
    #     res.append((ann.start, ann.end))

    # res = pd.DataFrame(res)
    # res[2] = res[1] - res[0]
    import pandas as pd
    input_fp = "tt298650_sync.wav"
    df = pd.read_csv(f"/work/maxbain/tmp/{input_fp}.sad", sep=" ", header=None)
    print(len(df))
    N = 0.15
    g = df[0].sub(df[1].shift())
    input_base = input_fp.split('.')[0]
    df = df.groupby(g.gt(N).cumsum()).agg({0:'min', 1:'max'})
    df.to_csv(f"/work/maxbain/tmp/{input_base}.lab", header=None, index=False, sep=" ")
    print(df)
    import pdb; pdb.set_trace()