41 Commits

Author SHA1 Message Date
847a3cd85b Merge pull request #96 from smly/fix-batch-processing
FIX: Assertion error in batch processing
2023-02-22 12:11:01 +00:00
2b1ffa12b8 Merge pull request #97 from smly/gpu-vad-filter
GPU acceleration when using VAD filters
2023-02-21 18:57:14 +00:00
57f5957e0e Pass device to pyannote.audio.Inference 2023-02-22 03:48:20 +09:00
27fe502344 Fix assertion error in batch processing 2023-02-22 02:45:13 +09:00
f7093e60d3 Merge pull request #90 from Pikauba/translation_starting_point_improvement
Improvement to transcription starting point with VAD
2023-02-18 21:59:57 +00:00
a1d2229416 Improvement to transcription starting point with VAD 2023-02-18 11:12:23 -05:00
4cb167a225 Merge pull request #74 from Camb-ai/level-bug-fix
added if clause for checking 'level-1'
2023-02-14 19:22:22 +00:00
2e307814dd added if clause for checking 2023-02-10 14:48:51 +05:30
d687cf3358 Merge pull request #58 from MahmoudAshraf97/main
added turkish wav2vec2 model
2023-02-01 22:11:51 +00:00
0a3fd11562 update readme 2023-02-01 22:09:11 +00:00
29e95b746b Merge pull request #57 from TengdaHan/main
support batch processing
2023-02-01 20:37:54 +00:00
039af89a86 support batch processing 2023-02-01 19:41:20 +00:00
9f26112d5c added turkish wav2vec2 model 2023-02-01 21:38:50 +02:00
fd2a093754 Merge pull request #55 from jonatasgrosman/main
FIX: Error when loading Hugging Face's models with embedded LM
2023-02-01 10:27:45 +00:00
31f069752f Merge pull request #53 from MahmoudAshraf97/main
Add more languages to models list
2023-02-01 10:27:25 +00:00
4cdf7ef856 Merge pull request #48 from Barabazs/main
doc: format checklist
2023-02-01 10:26:58 +00:00
d294e29ad9 fix: error when loading huggingface model with embedded language model 2023-01-31 23:24:26 -03:00
0eae9e1f50 added several wav2vec2 models by jonatasgrosman
since his models were used in other languages before and I tested the arabic model myself, I assumed it's safe to include all the available models
2023-02-01 03:02:10 +02:00
1b08661e42 change arabic model to jonatasgrosman 2023-01-31 19:32:31 +02:00
a49799294b add arabic wav2vec2 model form elgeish 2023-01-31 19:07:48 +02:00
d83c74a79f doc: format checklist 2023-01-29 16:07:58 +01:00
acaefa09a1 Merge pull request #46 from Barabazs/main
Add sponsor link to sidebar
2023-01-28 19:05:36 +00:00
76f79f600a fix short seg timestamps bug 2023-01-28 19:04:19 +00:00
33073f9bba Create FUNDING.yml 2023-01-28 19:43:27 +01:00
50f3965fdb fix tsv file ext 2023-01-28 17:39:07 +00:00
df2b1b70cb increase vad cut default 2023-01-28 14:49:53 +00:00
c19cf407d8 handle non-alignable whole segments 2023-01-28 13:53:03 +00:00
8081ef2dcd add custom vad binarization for vad cut 2023-01-28 00:22:33 +00:00
c6dbac76c8 cut up vad segments when too long to prevent OOM 2023-01-28 00:01:39 +00:00
69673eb39b buy-me-a-coffee 2023-01-27 15:12:49 +00:00
5b8c8a7bd3 pandas fix 2023-01-27 15:05:08 +00:00
7f2159a953 Merge branch 'main' of https://github.com/m-bain/whisperX into main 2023-01-26 10:46:36 +00:00
16d24b1c96 only pad timestamps if not using VAD 2023-01-26 10:46:13 +00:00
d20a2a4ea2 typo in --diarize flag 2023-01-26 10:28:54 +00:00
312f1cc50c Merge pull request #40 from MahmoudAshraf97/main
Added arguments and instructions to enable the usage VAD and Diarization
2023-01-26 00:34:03 +00:00
99b6e79fbf Update README.md
added additional instructions to use PyAnnote modules
2023-01-26 00:56:10 +02:00
e7773358a3 Update transcribe.py
added the ability to include HF access token in order to use PyAnnote models
2023-01-26 00:42:35 +02:00
6b2aa4ff3e Merge pull request #1 from MahmoudAshraf97/patch-1
Update README.md
2023-01-26 00:37:38 +02:00
c3de5e9580 Update README.md
fixed model name
2023-01-26 00:36:29 +02:00
58d7191949 add diarize 2023-01-25 19:40:41 +00:00
286a2f2c14 clean up logic, use pandas where possibl 2023-01-25 18:42:52 +00:00
9 changed files with 994 additions and 463 deletions

.github/FUNDING.yml (vendored, new file)

@@ -0,0 +1 @@
custom: https://www.buymeacoffee.com/maxhbain


@@ -2,7 +2,7 @@
 ## Other Languages
-For non-english ASR, it is best to use the `large` whisper model. Alignment models are automatically picked by the chosen language from the default [lists](https://github.com/m-bain/whisperX/blob/e909f2f766b23b2000f2d95df41f9b844ac53e49/whisperx/transcribe.py#L22).
+For non-english ASR, it is best to use the `large` whisper model. Alignment models are automatically picked by the chosen language from the default [lists](https://github.com/m-bain/whisperX/blob/main/whisperx/alignment.py#L18).
 Currently support default models tested for {en, fr, de, es, it, ja, zh, nl}


@@ -27,7 +27,6 @@
 <a href="EXAMPLES.md">More examples</a>
 </p>
-<h6 align="center">Made by Max Bain • :globe_with_meridians: <a href="https://www.maxbain.com">https://www.maxbain.com</a></h6>
 <img width="1216" align="center" alt="whisperx-arch" src="https://user-images.githubusercontent.com/36994049/211200186-8b779e26-0bfd-4127-aee2-5a9238b95e1f.png">
@@ -50,9 +49,10 @@ This repository refines the timestamps of openAI's Whisper model via forced alignment
 <h2 align="left", id="highlights">New🚨</h2>
+- Batch processing: Add `--vad_filter --parallel_bs [int]` for transcribing long audio file in batches (only supported with VAD filtering). Replace `[int]` with a batch size that fits your GPU memory, e.g. `--parallel_bs 16`.
 - VAD filtering: Voice Activity Detection (VAD) from [Pyannote.audio](https://huggingface.co/pyannote/voice-activity-detection) is used as a preprocessing step to remove reliance on whisper timestamps and only transcribe audio segments containing speech. add `--vad_filter` flag, increases timestamp accuracy and robustness (requires more GPU mem due to 30s inputs in wav2vec2)
 - Character level timestamps (see `*.char.ass` file output)
-- Diarization (still in beta, add `--diarization`)
+- Diarization (still in beta, add `--diarize`)
 <h2 align="left" id="setup">Setup ⚙️</h2>
@@ -71,9 +71,13 @@ $ cd whisperX
 $ pip install -e .
 ```
 You may also need to install ffmpeg, rust etc. Follow openAI instructions here https://github.com/openai/whisper#setup.
+### Voice Activity Detection Filtering & Diarization
+To **enable VAD filtering and Diarization**, include your Hugging Face access token that you can generate from [Here](https://huggingface.co/settings/tokens) after the `--hf_token` argument and accept the user agreement for the following models: [Segmentation](https://huggingface.co/pyannote/segmentation) , [Voice Activity Detection (VAD)](https://huggingface.co/pyannote/voice-activity-detection) , and [Speaker Diarization](https://huggingface.co/pyannote/speaker-diarization)
 <h2 align="left" id="example">Usage 💬 (command line)</h2>
 ### English
@@ -85,7 +89,7 @@ Run whisper on example segment (using default params)
 For increased timestamp accuracy, at the cost of higher gpu mem, use bigger models and VAD filtering e.g.
-whisperx examples/sample01.wav --model large.en --vad_filter --align_model WAV2VEC2_ASR_LARGE_LV60K_960H
+whisperx examples/sample01.wav --model large-v2 --vad_filter --align_model WAV2VEC2_ASR_LARGE_LV60K_960H
 Result using *WhisperX* with forced alignment to wav2vec2.0 large:
@@ -149,8 +153,9 @@ In addition to forced alignment, the following two modifications have been made
 - Not thoroughly tested, especially for non-english, results may vary -- please post issue to let me know the results on your data
 - Whisper normalises spoken numbers e.g. "fifty seven" to arabic numerals "57". Need to perform this normalization after alignment, so the phonemes can be aligned. Currently just ignores numbers.
-- Assumes the initial whisper timestamps are accurate to some degree (within margin of 2 seconds, adjust if needed -- bigger margins more prone to alignment errors)
-- Hacked this up quite quickly, there might be some errors, please raise an issue if you encounter any.
+- If not using VAD filter, whisperx assumes the initial whisper timestamps are accurate to some degree (within margin of 2 seconds, adjust if needed -- bigger margins more prone to alignment errors)
+- Overlapping speech is not handled particularly well by whisper nor whisperx
+- Diariazation is far from perfect.
 <h2 align="left" id="contribute">Contribute 🧑‍🏫</h2>
@@ -161,38 +166,46 @@ The next major upgrade we are working on is whisper with speaker diarization, so
 <h2 align="left" id="coming-soon">Coming Soon 🗓</h2>
-[x] ~~Multilingual init~~ done
-[x] ~~Subtitle .ass output~~ done
-[x] ~~Automatic align model selection based on language detection~~ done
-[x] ~~Python usage~~ done
-[x] ~~Character level timestamps~~
-[x] ~~Incorporating speaker diarization~~
-[ ] Improve diarization (word level)
-[ ] Inference speedup with batch processing
+* [x] Multilingual init
+* [x] Subtitle .ass output
+* [x] Automatic align model selection based on language detection
+* [x] Python usage
+* [x] Character level timestamps
+* [x] Incorporating speaker diarization
+* [x] Inference speedup with batch processing
+* [ ] Improve diarization (word level). *Harder than first thought...*
-<h2 align="left" id="contact">Contact 📇</h2>
-Contact maxbain[at]robots[dot]ox[dot]ac[dot]uk for business things.
+<h2 align="left" id="contact">Contact/Support 📇</h2>
+Contact maxbain[at]robots[dot]ox[dot]ac[dot]uk for queries
+<a href="https://www.buymeacoffee.com/maxhbain" target="_blank"><img src="https://cdn.buymeacoffee.com/buttons/default-orange.png" alt="Buy Me A Coffee" height="41" width="174"></a>
 <h2 align="left" id="acks">Acknowledgements 🙏</h2>
-Of course, this is mostly just a modification to [openAI's whisper](https://github.com/openai/whisper).
-As well as accreditation to this [PyTorch tutorial on forced alignment](https://pytorch.org/tutorials/intermediate/forced_alignment_with_torchaudio_tutorial.html)
+This work, and my PhD, is supported by the [VGG (Visual Geometry Group)](https://www.robots.ox.ac.uk/~vgg/) and University of Oxford.
+Of course, this is builds on [openAI's whisper](https://github.com/openai/whisper).
+And borrows important alignment code from [PyTorch tutorial on forced alignment](https://pytorch.org/tutorials/intermediate/forced_alignment_with_torchaudio_tutorial.html)
 <h2 align="left" id="cite">Citation</h2>
-If you use this in your research, just cite the repo,
+If you use this in your research, for now just cite the repo,
 ```bibtex
 @misc{bain2022whisperx,
-author = {Bain, Max},
+author = {Bain, Max and Han, Tengda},
 title = {WhisperX},
 year = {2022},
 publisher = {GitHub},
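The new flags above (`--vad_filter`, `--parallel_bs`, `--hf_token`, `--diarize`) have Python-level counterparts exported by the package (see the import hunk below). The following is a rough illustrative sketch, not code from this PR: it assumes `whisperx.load_model` behaves like upstream whisper's loader, that `transcribe_with_vad` accepts an audio path directly, and the file path and token are placeholders.

```python
import torch
import whisperx
from pyannote.audio import Inference

device = "cuda" if torch.cuda.is_available() else "cpu"
audio_file = "examples/sample01.wav"   # placeholder path
hf_token = "YOUR_HF_TOKEN"             # placeholder token for the gated pyannote models

# Whisper model plus a pyannote VAD pipeline, mirroring the construction in cli() further down.
model = whisperx.load_model("large-v2", device)
vad_pipeline = Inference(
    "pyannote/segmentation",
    pre_aggregation_hook=lambda segmentation: segmentation,
    use_auth_token=hf_token,
    device=torch.device(device),
)

# Transcribe only the detected speech regions; transcribe_with_vad_parallel is the batched
# variant added in this PR (roughly what --parallel_bs enables on the CLI).
result = whisperx.transcribe_with_vad(model, audio_file, vad_pipeline)

# Refine timestamps with a wav2vec2 alignment model for the transcript language.
align_model, metadata = whisperx.load_align_model("en", device)
aligned = whisperx.align(result["segments"], align_model, metadata, audio_file, device)

for word in aligned["word_segments"][:5]:
    print(word["start"], word["end"], word["text"])
```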


@@ -11,8 +11,8 @@ from tqdm import tqdm
 from .audio import load_audio, log_mel_spectrogram, pad_or_trim
 from .decoding import DecodingOptions, DecodingResult, decode, detect_language
 from .model import Whisper, ModelDimensions
-from .transcribe import transcribe, load_align_model, align, transcribe_with_vad
+from .transcribe import transcribe, transcribe_with_vad, transcribe_with_vad_parallel
+from .alignment import load_align_model, align
 _MODELS = {
     "tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",
     "tiny": "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt",


@@ -1,9 +1,443 @@
""""
Forced Alignment with Whisper
C. Max Bain
"""
import numpy as np
import pandas as pd
from typing import List, Union, Iterator, TYPE_CHECKING
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
import torchaudio
import torch
from dataclasses import dataclass
from .audio import SAMPLE_RATE, load_audio
from .utils import interpolate_nans
LANGUAGES_WITHOUT_SPACES = ["ja", "zh"]
DEFAULT_ALIGN_MODELS_TORCH = {
"en": "WAV2VEC2_ASR_BASE_960H",
"fr": "VOXPOPULI_ASR_BASE_10K_FR",
"de": "VOXPOPULI_ASR_BASE_10K_DE",
"es": "VOXPOPULI_ASR_BASE_10K_ES",
"it": "VOXPOPULI_ASR_BASE_10K_IT",
}
DEFAULT_ALIGN_MODELS_HF = {
"ja": "jonatasgrosman/wav2vec2-large-xlsr-53-japanese",
"zh": "jonatasgrosman/wav2vec2-large-xlsr-53-chinese-zh-cn",
"nl": "jonatasgrosman/wav2vec2-large-xlsr-53-dutch",
"uk": "Yehor/wav2vec2-xls-r-300m-uk-with-small-lm",
"pt": "jonatasgrosman/wav2vec2-large-xlsr-53-portuguese",
"ar": "jonatasgrosman/wav2vec2-large-xlsr-53-arabic",
"ru": "jonatasgrosman/wav2vec2-large-xlsr-53-russian",
"pl": "jonatasgrosman/wav2vec2-large-xlsr-53-polish",
"hu": "jonatasgrosman/wav2vec2-large-xlsr-53-hungarian",
"fi": "jonatasgrosman/wav2vec2-large-xlsr-53-finnish",
"fa": "jonatasgrosman/wav2vec2-large-xlsr-53-persian",
"el": "jonatasgrosman/wav2vec2-large-xlsr-53-greek",
"tr": "mpoyraz/wav2vec2-xls-r-300m-cv7-turkish",
}
def load_align_model(language_code, device, model_name=None):
    if model_name is None:
        # use default model
        if language_code in DEFAULT_ALIGN_MODELS_TORCH:
            model_name = DEFAULT_ALIGN_MODELS_TORCH[language_code]
        elif language_code in DEFAULT_ALIGN_MODELS_HF:
            model_name = DEFAULT_ALIGN_MODELS_HF[language_code]
        else:
            print(f"There is no default alignment model set for this language ({language_code}).\
                Please find a wav2vec2.0 model finetuned on this language in https://huggingface.co/models, then pass the model name in --align_model [MODEL_NAME]")
            raise ValueError(f"No default align-model for language: {language_code}")

    if model_name in torchaudio.pipelines.__all__:
        pipeline_type = "torchaudio"
        bundle = torchaudio.pipelines.__dict__[model_name]
        align_model = bundle.get_model().to(device)
        labels = bundle.get_labels()
        align_dictionary = {c.lower(): i for i, c in enumerate(labels)}
    else:
        try:
            processor = Wav2Vec2Processor.from_pretrained(model_name)
            align_model = Wav2Vec2ForCTC.from_pretrained(model_name)
        except Exception as e:
            print(e)
            print(f"Error loading model from huggingface, check https://huggingface.co/models for finetuned wav2vec2.0 models")
            raise ValueError(f'The chosen align_model "{model_name}" could not be found in huggingface (https://huggingface.co/models) or torchaudio (https://pytorch.org/audio/stable/pipelines.html#id14)')
        pipeline_type = "huggingface"
        align_model = align_model.to(device)
        labels = processor.tokenizer.get_vocab()
        align_dictionary = {char.lower(): code for char, code in processor.tokenizer.get_vocab().items()}

    align_metadata = {"language": language_code, "dictionary": align_dictionary, "type": pipeline_type}

    return align_model, align_metadata
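To make the contract of `load_align_model` concrete, a minimal sketch follows (not part of the file above); the language code and device are placeholders, and the function is assumed to be imported from `whisperx.alignment`.

```python
import torch
from whisperx.alignment import load_align_model

device = "cuda" if torch.cuda.is_available() else "cpu"

# "de" resolves to the torchaudio bundle VOXPOPULI_ASR_BASE_10K_DE from the defaults above;
# an unsupported language raises ValueError unless model_name is given explicitly.
align_model, align_metadata = load_align_model("de", device)

# The metadata dict is what align() consumes: the character -> token-id dictionary,
# the pipeline type ("torchaudio" or "huggingface") and the language code.
print(align_metadata["type"], align_metadata["language"], len(align_metadata["dictionary"]))
```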
def align(
transcript: Iterator[dict],
model: torch.nn.Module,
align_model_metadata: dict,
audio: Union[str, np.ndarray, torch.Tensor],
device: str,
extend_duration: float = 0.0,
start_from_previous: bool = True,
interpolate_method: str = "nearest",
):
"""
Force align phoneme recognition predictions to known transcription
Parameters
----------
transcript: Iterator[dict]
The transcript segments (dicts with "text", "start", "end" timestamps) to force-align
model: torch.nn.Module
Alignment model (wav2vec2)
audio: Union[str, np.ndarray, torch.Tensor]
The path to the audio file to open, or the audio waveform
device: str
cuda device
diarization: pd.DataFrame {'start': List[float], 'end': List[float], 'speaker': List[float]}
diarization segments with speaker labels.
extend_duration: float
Amount to pad input segments by. If not using --vad_filter then recommended to use 2 seconds
interpolate_method: str ["nearest", "linear", "ignore"]
Method to assign timestamps to non-aligned words. Words are not able to be aligned when none of the characters occur in the align model dictionary.
"nearest" copies timestamp of nearest word within the segment. "linear" is linear interpolation. "drop" removes that word from output.
Returns
-------
A dictionary containing the aligned segment-level details ("segments") and a flat list of
word-level timings ("word_segments").
"""
if not torch.is_tensor(audio):
if isinstance(audio, str):
audio = load_audio(audio)
audio = torch.from_numpy(audio)
if len(audio.shape) == 1:
audio = audio.unsqueeze(0)
MAX_DURATION = audio.shape[1] / SAMPLE_RATE
model_dictionary = align_model_metadata["dictionary"]
model_lang = align_model_metadata["language"]
model_type = align_model_metadata["type"]
aligned_segments = []
prev_t2 = 0
char_segments_arr = {
"segment-idx": [],
"subsegment-idx": [],
"word-idx": [],
"char": [],
"start": [],
"end": [],
"score": [],
}
for sdx, segment in enumerate(transcript):
while True:
segment_align_success = False
# strip spaces at beginning / end, but keep track of the amount.
num_leading = len(segment["text"]) - len(segment["text"].lstrip())
num_trailing = len(segment["text"]) - len(segment["text"].rstrip())
transcription = segment["text"]
# TODO: convert number tokenizer / symbols to phonetic words for alignment.
# e.g. "$300" -> "three hundred dollars"
# currently "$300" is ignored since no characters present in the phonetic dictionary
# split into words
if model_lang not in LANGUAGES_WITHOUT_SPACES:
per_word = transcription.split(" ")
else:
per_word = transcription
# first check that characters in transcription can be aligned (they are contained in align model"s dictionary)
clean_char, clean_cdx = [], []
for cdx, char in enumerate(transcription):
char_ = char.lower()
# wav2vec2 models use "|" character to represent spaces
if model_lang not in LANGUAGES_WITHOUT_SPACES:
char_ = char_.replace(" ", "|")
# ignore whitespace at beginning and end of transcript
if cdx < num_leading:
pass
elif cdx > len(transcription) - num_trailing - 1:
pass
elif char_ in model_dictionary.keys():
clean_char.append(char_)
clean_cdx.append(cdx)
clean_wdx = []
for wdx, wrd in enumerate(per_word):
if any([c in model_dictionary.keys() for c in wrd]):
clean_wdx.append(wdx)
# if no characters are in the dictionary, then we skip this segment...
if len(clean_char) == 0:
print(f'Failed to align segment ("{segment["text"]}"): no characters in this segment found in model dictionary, resorting to original...')
break
transcription_cleaned = "".join(clean_char)
tokens = [model_dictionary[c] for c in transcription_cleaned]
# we only pad if not using VAD filtering
if "seg_text" not in segment:
# pad according original timestamps
t1 = max(segment["start"] - extend_duration, 0)
t2 = min(segment["end"] + extend_duration, MAX_DURATION)
# use prev_t2 as current t1 if it"s later
if start_from_previous and t1 < prev_t2:
t1 = prev_t2
# check if timestamp range is still valid
if t1 >= MAX_DURATION:
print("Failed to align segment: original start time longer than audio duration, skipping...")
break
if t2 - t1 < 0.02:
print("Failed to align segment: duration smaller than 0.02s time precision")
break
f1 = int(t1 * SAMPLE_RATE)
f2 = int(t2 * SAMPLE_RATE)
waveform_segment = audio[:, f1:f2]
with torch.inference_mode():
if model_type == "torchaudio":
emissions, _ = model(waveform_segment.to(device))
elif model_type == "huggingface":
emissions = model(waveform_segment.to(device)).logits
else:
raise NotImplementedError(f"Align model of type {model_type} not supported.")
emissions = torch.log_softmax(emissions, dim=-1)
emission = emissions[0].cpu().detach()
trellis = get_trellis(emission, tokens)
path = backtrack(trellis, emission, tokens)
if path is None:
print(f'Failed to align segment ("{segment["text"]}"): backtrack failed, resorting to original...')
break
char_segments = merge_repeats(path, transcription_cleaned)
# word_segments = merge_words(char_segments)
# sub-segments
if "seg-text" not in segment:
segment["seg-text"] = [transcription]
v = 0
seg_lens = [0] + [len(x) for x in segment["seg-text"]]
seg_lens_cumsum = [v := v + n for n in seg_lens]
sub_seg_idx = 0
wdx = 0
duration = t2 - t1
ratio = duration * waveform_segment.size(0) / (trellis.size(0) - 1)
for cdx, char in enumerate(transcription + " "):
is_last = False
if cdx == len(transcription):
break
elif cdx+1 == len(transcription):
is_last = True
start, end, score = None, None, None
if cdx in clean_cdx:
char_seg = char_segments[clean_cdx.index(cdx)]
start = char_seg.start * ratio + t1
end = char_seg.end * ratio + t1
score = char_seg.score
char_segments_arr["char"].append(char)
char_segments_arr["start"].append(start)
char_segments_arr["end"].append(end)
char_segments_arr["score"].append(score)
char_segments_arr["word-idx"].append(wdx)
char_segments_arr["segment-idx"].append(sdx)
char_segments_arr["subsegment-idx"].append(sub_seg_idx)
# word-level info
if model_lang in LANGUAGES_WITHOUT_SPACES:
# character == word
wdx += 1
elif is_last or transcription[cdx+1] == " " or cdx == seg_lens_cumsum[sub_seg_idx+1] - 1:
wdx += 1
if is_last or cdx == seg_lens_cumsum[sub_seg_idx+1] - 1:
wdx = 0
sub_seg_idx += 1
prev_t2 = segment["end"]
segment_align_success = True
# end while True loop
break
# reset prev_t2 due to drifting issues
if not segment_align_success:
prev_t2 = 0
char_segments_arr = pd.DataFrame(char_segments_arr)
not_space = char_segments_arr["char"] != " "
per_seg_grp = char_segments_arr.groupby(["segment-idx", "subsegment-idx"], as_index = False)
char_segments_arr = per_seg_grp.apply(lambda x: x.reset_index(drop = True)).reset_index()
per_word_grp = char_segments_arr[not_space].groupby(["segment-idx", "subsegment-idx", "word-idx"])
per_subseg_grp = char_segments_arr[not_space].groupby(["segment-idx", "subsegment-idx"])
per_seg_grp = char_segments_arr[not_space].groupby(["segment-idx"])
char_segments_arr["local-char-idx"] = char_segments_arr.groupby(["segment-idx", "subsegment-idx"]).cumcount()
per_word_grp = char_segments_arr[not_space].groupby(["segment-idx", "subsegment-idx", "word-idx"]) # regroup
word_segments_arr = {}
# start of word is first char with a timestamp
word_segments_arr["start"] = per_word_grp["start"].min().values
# end of word is last char with a timestamp
word_segments_arr["end"] = per_word_grp["end"].max().values
# score of word is mean (excluding nan)
word_segments_arr["score"] = per_word_grp["score"].mean().values
word_segments_arr["segment-text-start"] = per_word_grp["local-char-idx"].min().astype(int).values
word_segments_arr["segment-text-end"] = per_word_grp["local-char-idx"].max().astype(int).values+1
word_segments_arr = pd.DataFrame(word_segments_arr)
word_segments_arr[["segment-idx", "subsegment-idx", "word-idx"]] = per_word_grp["local-char-idx"].min().reset_index()[["segment-idx", "subsegment-idx", "word-idx"]].astype(int)
segments_arr = {}
segments_arr["start"] = per_subseg_grp["start"].min().reset_index()["start"]
segments_arr["end"] = per_subseg_grp["end"].max().reset_index()["end"]
segments_arr = pd.DataFrame(segments_arr)
segments_arr[["segment-idx", "subsegment-idx-start"]] = per_subseg_grp["start"].min().reset_index()[["segment-idx", "subsegment-idx"]]
segments_arr["subsegment-idx-end"] = segments_arr["subsegment-idx-start"] + 1
# interpolate missing words / sub-segments
if interpolate_method != "ignore":
wrd_subseg_grp = word_segments_arr.groupby(["segment-idx", "subsegment-idx"], group_keys=False)
wrd_seg_grp = word_segments_arr.groupby(["segment-idx"], group_keys=False)
# we still know which word timestamps are interpolated because their score == nan
word_segments_arr["start"] = wrd_subseg_grp['start'].apply(lambda group: interpolate_nans(group, method=interpolate_method))
word_segments_arr["end"] = wrd_subseg_grp['end'].apply(lambda group: interpolate_nans(group, method=interpolate_method))
word_segments_arr["start"] = wrd_seg_grp['start'].apply(lambda group: interpolate_nans(group, method=interpolate_method))
word_segments_arr["end"] = wrd_seg_grp['end'].apply(lambda group: interpolate_nans(group, method=interpolate_method))
sub_seg_grp = segments_arr.groupby(["segment-idx"], group_keys=False)
segments_arr['start'] = sub_seg_grp['start'].apply(lambda group: interpolate_nans(group, method=interpolate_method))
segments_arr['end'] = sub_seg_grp['end'].apply(lambda group: interpolate_nans(group, method=interpolate_method))
# merge words & subsegments which are missing times
word_grp = word_segments_arr.groupby(["segment-idx", "subsegment-idx", "end"])
word_segments_arr["segment-text-start"] = word_grp["segment-text-start"].transform(min)
word_segments_arr["segment-text-end"] = word_grp["segment-text-end"].transform(max)
word_segments_arr.drop_duplicates(subset=["segment-idx", "subsegment-idx", "end"], inplace=True)
seg_grp_dup = segments_arr.groupby(["segment-idx", "start", "end"])
segments_arr["subsegment-idx-start"] = seg_grp_dup["subsegment-idx-start"].transform(min)
segments_arr["subsegment-idx-end"] = seg_grp_dup["subsegment-idx-end"].transform(max)
segments_arr.drop_duplicates(subset=["segment-idx", "subsegment-idx-start", "subsegment-idx-end"], inplace=True)
else:
word_segments_arr.dropna(inplace=True)
segments_arr.dropna(inplace=True)
# if some segments still have missing timestamps (usually because all numerals / symbols), then use original timestamps...
segments_arr['start'].fillna(pd.Series([x['start'] for x in transcript]), inplace=True)
segments_arr['end'].fillna(pd.Series([x['end'] for x in transcript]), inplace=True)
segments_arr['subsegment-idx-start'].fillna(0, inplace=True)
segments_arr['subsegment-idx-end'].fillna(1, inplace=True)
aligned_segments = []
aligned_segments_word = []
word_segments_arr.set_index(["segment-idx", "subsegment-idx"], inplace=True)
char_segments_arr.set_index(["segment-idx", "subsegment-idx", "word-idx"], inplace=True)
for sdx, srow in segments_arr.iterrows():
seg_idx = int(srow["segment-idx"])
try:
sub_start = int(srow["subsegment-idx-start"])
except:
import pdb; pdb.set_trace()
sub_end = int(srow["subsegment-idx-end"])
seg = transcript[seg_idx]
text = "".join(seg["seg-text"][sub_start:sub_end])
wseg = word_segments_arr.loc[seg_idx].loc[sub_start:sub_end-1]
wseg["start"].fillna(srow["start"], inplace=True)
wseg["end"].fillna(srow["end"], inplace=True)
wseg["segment-text-start"].fillna(0, inplace=True)
wseg["segment-text-end"].fillna(len(text)-1, inplace=True)
cseg = char_segments_arr.loc[seg_idx].loc[sub_start:sub_end-1]
# fixes bug for single segment in transcript
cseg['segment-text-start'] = cseg['level_1'] if 'level_1' in cseg else 0
cseg['segment-text-end'] = cseg['level_1'] + 1 if 'level_1' in cseg else 1
if 'level_1' in cseg: del cseg['level_1']
if 'level_0' in cseg: del cseg['level_0']
cseg.reset_index(inplace=True)
aligned_segments.append(
{
"start": srow["start"],
"end": srow["end"],
"text": text,
"word-segments": wseg,
"char-segments": cseg
}
)
def get_raw_text(word_row):
return seg["seg-text"][word_row.name][int(word_row["segment-text-start"]):int(word_row["segment-text-end"])+1]
wdx = 0
curr_text = get_raw_text(wseg.iloc[wdx])
if len(wseg) > 1:
for _, wrow in wseg.iloc[1:].iterrows():
if wrow['start'] != wseg.iloc[wdx]['start']:
aligned_segments_word.append(
{
"text": curr_text.strip(),
"start": wseg.iloc[wdx]["start"],
"end": wseg.iloc[wdx]["end"],
}
)
curr_text = ""
curr_text += " " + get_raw_text(wrow)
wdx += 1
aligned_segments_word.append(
{
"text": curr_text.strip(),
"start": wseg.iloc[wdx]["start"],
"end": wseg.iloc[wdx]["end"]
}
)
return {"segments": aligned_segments, "word_segments": aligned_segments_word}
""" """
source: https://pytorch.org/tutorials/intermediate/forced_alignment_with_torchaudio_tutorial.html source: https://pytorch.org/tutorials/intermediate/forced_alignment_with_torchaudio_tutorial.html
""" """
import torch
from dataclasses import dataclass
def get_trellis(emission, tokens, blank_id=0): def get_trellis(emission, tokens, blank_id=0):
num_frame = emission.size(0) num_frame = emission.size(0)
num_tokens = len(tokens) num_tokens = len(tokens)

whisperx/diarize.py (new file)

@@ -0,0 +1,56 @@
import numpy as np
import pandas as pd
def assign_word_speakers(diarize_df, result_segments, fill_nearest=False):
    for seg in result_segments:
        wdf = seg['word-segments']
        if len(wdf['start'].dropna()) == 0:
            wdf['start'] = seg['start']
            wdf['end'] = seg['end']
        speakers = []
        for wdx, wrow in wdf.iterrows():
            if not np.isnan(wrow['start']):
                diarize_df['intersection'] = np.minimum(diarize_df['end'], wrow['end']) - np.maximum(diarize_df['start'], wrow['start'])
                diarize_df['union'] = np.maximum(diarize_df['end'], wrow['end']) - np.minimum(diarize_df['start'], wrow['start'])
                # remove no hit
                if not fill_nearest:
                    dia_tmp = diarize_df[diarize_df['intersection'] > 0]
                else:
                    dia_tmp = diarize_df
                if len(dia_tmp) == 0:
                    speaker = None
                else:
                    speaker = dia_tmp.sort_values("intersection", ascending=False).iloc[0][2]
            else:
                speaker = None
            speakers.append(speaker)
        seg['word-segments']['speaker'] = speakers
        seg["speaker"] = pd.Series(speakers).value_counts().index[0]

    # create word level segments for .srt
    word_seg = []
    for seg in result_segments:
        wseg = pd.DataFrame(seg["word-segments"])
        for wdx, wrow in wseg.iterrows():
            if wrow["start"] is not None:
                speaker = wrow['speaker']
                if speaker is None or speaker == np.nan:
                    speaker = "UNKNOWN"
                word_seg.append(
                    {
                        "start": wrow["start"],
                        "end": wrow["end"],
                        "text": f"[{speaker}]: " + seg["text"][int(wrow["segment-text-start"]):int(wrow["segment-text-end"])]
                    }
                )

    # TODO: create segments but split words on new speaker
    return result_segments, word_seg


class Segment:
    def __init__(self, start, end, speaker=None):
        self.start = start
        self.end = end
        self.speaker = speaker
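A hedged sketch of driving `assign_word_speakers` from the new file above (illustration, not part of the diff): the diarization pipeline and token follow the README instructions, `diarize_df` is laid out so the speaker label sits in the third column (read via `.iloc[0][2]`), and `aligned` is assumed to be the output of `whisperx.align`.

```python
import pandas as pd
from pyannote.audio import Pipeline
from whisperx.diarize import assign_word_speakers

diarize_pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization@2.1",
                                            use_auth_token="YOUR_HF_TOKEN")  # placeholder token
diarization = diarize_pipeline("examples/sample01.wav")                      # placeholder path

# Columns ordered (start, end, speaker) so the speaker label is at positional index 2.
diarize_df = pd.DataFrame(
    [(turn.start, turn.end, speaker) for turn, _, speaker in diarization.itertracks(yield_label=True)],
    columns=["start", "end", "speaker"],
)

result_segments, word_seg = assign_word_speakers(diarize_df, aligned["segments"])
for w in word_seg[:5]:
    print(w["start"], w["end"], w["text"])   # text is prefixed like "[SPEAKER_00]: ..."
```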


@@ -5,36 +5,19 @@ from typing import List, Optional, Tuple, Union, Iterator, TYPE_CHECKING
 import numpy as np
 import torch
-import torchaudio
-from transformers import AutoProcessor, Wav2Vec2ForCTC
 import tqdm
 from .audio import SAMPLE_RATE, N_FRAMES, HOP_LENGTH, CHUNK_LENGTH, pad_or_trim, log_mel_spectrogram, load_audio
-from .alignment import get_trellis, backtrack, merge_repeats, merge_words
+from .alignment import load_align_model, align, get_trellis, backtrack, merge_repeats, merge_words
 from .decoding import DecodingOptions, DecodingResult
+from .diarize import assign_word_speakers, Segment
 from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
 from .utils import exact_div, format_timestamp, optional_int, optional_float, str2bool, interpolate_nans, write_txt, write_vtt, write_srt, write_ass, write_tsv
+from .vad import Binarize
 import pandas as pd

 if TYPE_CHECKING:
     from .model import Whisper

-LANGUAGES_WITHOUT_SPACES = ["ja", "zh"]
-DEFAULT_ALIGN_MODELS_TORCH = {
-    "en": "WAV2VEC2_ASR_BASE_960H",
-    "fr": "VOXPOPULI_ASR_BASE_10K_FR",
-    "de": "VOXPOPULI_ASR_BASE_10K_DE",
-    "es": "VOXPOPULI_ASR_BASE_10K_ES",
-    "it": "VOXPOPULI_ASR_BASE_10K_IT",
-}
-DEFAULT_ALIGN_MODELS_HF = {
-    "ja": "jonatasgrosman/wav2vec2-large-xlsr-53-japanese",
-    "zh": "jonatasgrosman/wav2vec2-large-xlsr-53-chinese-zh-cn",
-    "nl": "jonatasgrosman/wav2vec2-large-xlsr-53-dutch",
-    "uk": "Yehor/wav2vec2-xls-r-300m-uk-with-small-lm",
-    "pt": "jonatasgrosman/wav2vec2-large-xlsr-53-portuguese",
-}

 def transcribe(
@@ -273,362 +256,29 @@ def transcribe(
     return dict(text=tokenizer.decode(all_tokens[len(initial_prompt):]), segments=all_segments, language=language)
def align(
transcript: Iterator[dict],
model: torch.nn.Module,
align_model_metadata: dict,
audio: Union[str, np.ndarray, torch.Tensor],
device: str,
extend_duration: float = 0.0,
start_from_previous: bool = True,
interpolate_method: str = "nearest",
):
"""
Force align phoneme recognition predictions to known transcription
Parameters
----------
transcript: Iterator[dict]
The Whisper model instance
model: torch.nn.Module
Alignment model (wav2vec2)
audio: Union[str, np.ndarray, torch.Tensor]
The path to the audio file to open, or the audio waveform
device: str
cuda device
extend_duration: float
Amount to pad input segments by. If not using vad--filter then recommended to use 2 seconds
If the gzip compression ratio is above this value, treat as failed
interpolate_method: str ["nearest", "linear", "ignore"]
Method to assign timestamps to non-aligned words. Words are not able to be aligned when none of the characters occur in the align model dictionary.
"nearest" copies timestamp of nearest word within the segment. "linear" is linear interpolation. "drop" removes that word from output.
Returns
-------
A dictionary containing the resulting text ("text") and segment-level details ("segments"), and
the spoken language ("language"), which is detected when `decode_options["language"]` is None.
"""
if not torch.is_tensor(audio):
if isinstance(audio, str):
audio = load_audio(audio)
audio = torch.from_numpy(audio)
if len(audio.shape) == 1:
audio = audio.unsqueeze(0)
MAX_DURATION = audio.shape[1] / SAMPLE_RATE
model_dictionary = align_model_metadata["dictionary"]
model_lang = align_model_metadata["language"]
model_type = align_model_metadata["type"]
aligned_segments = []
prev_t2 = 0
for segment in transcript:
aligned_subsegments = []
while True:
segment_align_success = False
# strip spaces at beginning / end, but keep track of the amount.
num_leading = len(segment["text"]) - len(segment["text"].lstrip())
num_trailing = len(segment["text"]) - len(segment["text"].rstrip())
transcription = segment["text"]
# TODO: convert number tokenizer / symbols to phonetic words for alignment.
# e.g. "$300" -> "three hundred dollars"
# currently "$300" is ignored since no characters present in the phonetic dictionary
# split into words
if model_lang not in LANGUAGES_WITHOUT_SPACES:
per_word = transcription.split(" ")
else:
per_word = transcription
# first check that characters in transcription can be aligned (they are contained in align model"s dictionary)
clean_char, clean_cdx = [], []
for cdx, char in enumerate(transcription):
char_ = char.lower()
# wav2vec2 models use "|" character to represent spaces
if model_lang not in LANGUAGES_WITHOUT_SPACES:
char_ = char_.replace(" ", "|")
# ignore whitespace at beginning and end of transcript
if cdx < num_leading:
pass
elif cdx > len(transcription) - num_trailing - 1:
pass
elif char_ in model_dictionary.keys():
clean_char.append(char_)
clean_cdx.append(cdx)
clean_wdx = []
for wdx, wrd in enumerate(per_word):
if any([c in model_dictionary.keys() for c in wrd]):
clean_wdx.append(wdx)
# if no characters are in the dictionary, then we skip this segment...
if len(clean_char) == 0:
print("Failed to align segment: no characters in this segment found in model dictionary, resorting to original...")
break
transcription_cleaned = "".join(clean_char)
tokens = [model_dictionary[c] for c in transcription_cleaned]
# pad according original timestamps
t1 = max(segment["start"] - extend_duration, 0)
t2 = min(segment["end"] + extend_duration, MAX_DURATION)
# use prev_t2 as current t1 if it"s later
if start_from_previous and t1 < prev_t2:
t1 = prev_t2
# check if timestamp range is still valid
if t1 >= MAX_DURATION:
print("Failed to align segment: original start time longer than audio duration, skipping...")
break
if t2 - t1 < 0.02:
print("Failed to align segment: duration smaller than 0.02s time precision")
break
f1 = int(t1 * SAMPLE_RATE)
f2 = int(t2 * SAMPLE_RATE)
waveform_segment = audio[:, f1:f2]
with torch.inference_mode():
if model_type == "torchaudio":
emissions, _ = model(waveform_segment.to(device))
elif model_type == "huggingface":
emissions = model(waveform_segment.to(device)).logits
else:
raise NotImplementedError(f"Align model of type {model_type} not supported.")
emissions = torch.log_softmax(emissions, dim=-1)
emission = emissions[0].cpu().detach()
trellis = get_trellis(emission, tokens)
path = backtrack(trellis, emission, tokens)
if path is None:
print("Failed to align segment: backtrack failed, resorting to original...")
break
char_segments = merge_repeats(path, transcription_cleaned)
# word_segments = merge_words(char_segments)
# sub-segments
if "seg-text" not in segment:
segment["seg-text"] = [transcription]
v = 0
seg_lens = [0] + [len(x) for x in segment["seg-text"]]
seg_lens_cumsum = [v := v + n for n in seg_lens]
sub_seg_idx = 0
char_level = {
"start": [],
"end": [],
"score": [],
"word-index": [],
}
word_level = {
"start": [],
"end": [],
"score": [],
"segment-text-start": [],
"segment-text-end": []
}
wdx = 0
seg_start_actual, seg_end_actual = None, None
duration = t2 - t1
ratio = duration * waveform_segment.size(0) / (trellis.size(0) - 1)
cdx_prev = 0
for cdx, char in enumerate(transcription + " "):
is_last = False
if cdx == len(transcription):
break
elif cdx+1 == len(transcription):
is_last = True
start, end, score = None, None, None
if cdx in clean_cdx:
char_seg = char_segments[clean_cdx.index(cdx)]
start = char_seg.start * ratio + t1
end = char_seg.end * ratio + t1
score = char_seg.score
char_level["start"].append(start)
char_level["end"].append(end)
char_level["score"].append(score)
char_level["word-index"].append(wdx)
# word-level info
if model_lang in LANGUAGES_WITHOUT_SPACES:
# character == word
wdx += 1
elif is_last or transcription[cdx+1] == " " or cdx == seg_lens_cumsum[sub_seg_idx+1] - 1:
wdx += 1
word_level["start"].append(None)
word_level["end"].append(None)
word_level["score"].append(None)
word_level["segment-text-start"].append(cdx_prev-seg_lens_cumsum[sub_seg_idx])
word_level["segment-text-end"].append(cdx+1-seg_lens_cumsum[sub_seg_idx])
cdx_prev = cdx+2
if is_last or cdx == seg_lens_cumsum[sub_seg_idx+1] - 1:
if model_lang not in LANGUAGES_WITHOUT_SPACES:
char_level = pd.DataFrame(char_level)
word_level = pd.DataFrame(word_level)
not_space = pd.Series(list(segment["seg-text"][sub_seg_idx])) != " "
word_level["start"] = char_level[not_space].groupby("word-index")["start"].min() # take min of all chars in a word ignoring space
word_level["end"] = char_level[not_space].groupby("word-index")["end"].max() # take max of all chars in a word
# fill missing
if interpolate_method != "ignore":
word_level["start"] = interpolate_nans(word_level["start"], method=interpolate_method)
word_level["end"] = interpolate_nans(word_level["end"], method=interpolate_method)
word_level["start"] = word_level["start"].values.tolist()
word_level["end"] = word_level["end"].values.tolist()
word_level["score"] = char_level.groupby("word-index")["score"].mean() # take mean of all scores
char_level = char_level.replace({np.nan:None}).to_dict("list")
word_level = pd.DataFrame(word_level).replace({np.nan:None}).to_dict("list")
else:
word_level = None
aligned_subsegments.append(
{
"text": segment["seg-text"][sub_seg_idx],
"start": seg_start_actual,
"end": seg_end_actual,
"char-segments": char_level,
"word-segments": word_level
}
)
if "language" in segment:
aligned_subsegments[-1]["language"] = segment["language"]
char_level = {
"start": [],
"end": [],
"score": [],
"word-index": [],
}
word_level = {
"start": [],
"end": [],
"score": [],
"segment-text-start": [],
"segment-text-end": []
}
wdx = 0
cdx_prev = cdx + 2
sub_seg_idx += 1
seg_start_actual, seg_end_actual = None, None
# take min-max for actual segment-level timestamp
if seg_start_actual is None and start is not None:
seg_start_actual = start
if end is not None:
seg_end_actual = end
prev_t2 = segment["end"]
segment_align_success = True
# end while True loop
break
# reset prev_t2 due to drifting issues
if not segment_align_success:
prev_t2 = 0
start = interpolate_nans(pd.DataFrame(aligned_subsegments)["start"], method=interpolate_method)
end = interpolate_nans(pd.DataFrame(aligned_subsegments)["end"], method=interpolate_method)
for idx, seg in enumerate(aligned_subsegments):
seg['start'] = start.iloc[idx]
seg['end'] = end.iloc[idx]
aligned_segments += aligned_subsegments
# create word level segments for .srt
word_seg = []
for seg in aligned_segments:
if model_lang in LANGUAGES_WITHOUT_SPACES:
# character based
seg["word-segments"] = seg["char-segments"]
seg["word-segments"]["segment-text-start"] = range(len(seg['word-segments']['start']))
seg["word-segments"]["segment-text-end"] = range(1, len(seg['word-segments']['start'])+1)
wseg = pd.DataFrame(seg["word-segments"]).replace({np.nan:None})
for wdx, wrow in wseg.iterrows():
if wrow["start"] is not None:
word_seg.append(
{
"start": wrow["start"],
"end": wrow["end"],
"text": seg["text"][int(wrow["segment-text-start"]):int(wrow["segment-text-end"])]
}
)
return {"segments": aligned_segments, "word_segments": word_seg}
def load_align_model(language_code, device, model_name=None):
if model_name is None:
# use default model
if language_code in DEFAULT_ALIGN_MODELS_TORCH:
model_name = DEFAULT_ALIGN_MODELS_TORCH[language_code]
elif language_code in DEFAULT_ALIGN_MODELS_HF:
model_name = DEFAULT_ALIGN_MODELS_HF[language_code]
else:
print(f"There is no default alignment model set for this language ({language_code}).\
Please find a wav2vec2.0 model finetuned on this language in https://huggingface.co/models, then pass the model name in --align_model [MODEL_NAME]")
raise ValueError(f"No default align-model for language: {language_code}")
if model_name in torchaudio.pipelines.__all__:
pipeline_type = "torchaudio"
bundle = torchaudio.pipelines.__dict__[model_name]
align_model = bundle.get_model().to(device)
labels = bundle.get_labels()
align_dictionary = {c.lower(): i for i, c in enumerate(labels)}
else:
try:
processor = AutoProcessor.from_pretrained(model_name)
align_model = Wav2Vec2ForCTC.from_pretrained(model_name)
except Exception as e:
print(e)
print(f"Error loading model from huggingface, check https://huggingface.co/models for finetuned wav2vec2.0 models")
raise ValueError(f'The chosen align_model "{model_name}" could not be found in huggingface (https://huggingface.co/models) or torchaudio (https://pytorch.org/audio/stable/pipelines.html#id14)')
pipeline_type = "huggingface"
align_model = align_model.to(device)
labels = processor.tokenizer.get_vocab()
align_dictionary = {char.lower(): code for char,code in processor.tokenizer.get_vocab().items()}
align_metadata = {"language": language_code, "dictionary": align_dictionary, "type": pipeline_type}
return align_model, align_metadata
def merge_chunks(segments, chunk_size=CHUNK_LENGTH): def merge_chunks(segments, chunk_size=CHUNK_LENGTH):
""" """
Merge VAD segments into larger segments of size ~CHUNK_LENGTH. Merge VAD segments into larger segments of approximately size ~CHUNK_LENGTH.
TODO: Make sure VAD segment isn't too long, otherwise it will cause OOM when input to alignment model
TODO: Or sliding window alignment model over long segment.
""" """
curr_start = 0
curr_end = 0 curr_end = 0
merged_segments = [] merged_segments = []
seg_idxs = [] seg_idxs = []
speaker_idxs = [] speaker_idxs = []
for sdx, seg in enumerate(segments):
assert chunk_size > 0
binarize = Binarize(max_duration=chunk_size)
segments = binarize(segments)
segments_list = []
for speech_turn in segments.get_timeline():
segments_list.append(Segment(speech_turn.start, speech_turn.end, "UNKNOWN"))
assert segments_list, "segments_list is empty."
# Make sur the starting point is the start of the segment.
curr_start = segments_list[0].start
for seg in segments_list:
if seg.end - curr_start > chunk_size and curr_end-curr_start > 0: if seg.end - curr_start > chunk_size and curr_end-curr_start > 0:
merged_segments.append({ merged_segments.append({
"start": curr_start, "start": curr_start,
@@ -668,12 +318,9 @@ def transcribe_with_vad(
     prev = 0
     output = {"segments": []}

-    vad_segments_list = []
     vad_segments = vad_pipeline(audio)
-    for speech_turn in vad_segments.get_timeline().support():
-        vad_segments_list.append(Segment(speech_turn.start, speech_turn.end, "UNKNOWN"))
     # merge segments to approx 30s inputs to make whisper most appropraite
-    vad_segments = merge_chunks(vad_segments_list)
+    vad_segments = merge_chunks(vad_segments)

     for sdx, seg_t in enumerate(vad_segments):
         if verbose:
@@ -702,56 +349,223 @@
     return output
def assign_word_speakers(diarize_df, result_segments, fill_nearest=False): def transcribe_with_vad_parallel(
model: "Whisper",
audio: Union[str, np.ndarray, torch.Tensor],
vad_pipeline,
mel = None,
verbose: Optional[bool] = None,
batch_size = -1,
**kwargs
):
"""
Transcribe per VAD segment
"""
for seg in result_segments: if mel is None:
wdf = pd.DataFrame(seg['word-segments']) mel = log_mel_spectrogram(audio)
if len(wdf['start'].dropna()) == 0:
wdf['start'] = seg['start']
wdf['end'] = seg['end']
speakers = []
for wdx, wrow in wdf.iterrows():
diarize_df['intersection'] = np.minimum(diarize_df['end'], wrow['end']) - np.maximum(diarize_df['start'], wrow['start'])
diarize_df['union'] = np.maximum(diarize_df['end'], wrow['end']) - np.minimum(diarize_df['start'], wrow['start'])
# remove no hit
if not fill_nearest:
dia_tmp = diarize_df[diarize_df['intersection'] > 0]
else:
dia_tmp = diarize_df
if len(dia_tmp) == 0:
speaker = None
else:
speaker = dia_tmp.sort_values("intersection", ascending=False).iloc[0][2]
speakers.append(speaker)
seg['word-segments']['speaker'] = speakers
seg["speaker"] = pd.Series(speakers).value_counts().index[0]
# create word level segments for .srt vad_segments = vad_pipeline(audio)
word_seg = [] # merge segments to approx 30s inputs to make whisper most appropraite
for seg in result_segments: vad_segments = merge_chunks(vad_segments)
wseg = pd.DataFrame(seg["word-segments"])
for wdx, wrow in wseg.iterrows(): ################################
if wrow["start"] is not None: ### START of parallelization ###
speaker = wrow['speaker'] ################################
if speaker is None or speaker == np.nan:
speaker = "UNKNOWN" # pad mel to a same length
word_seg.append( start_seconds = [i['start'] for i in vad_segments]
{ end_seconds = [i['end'] for i in vad_segments]
"start": wrow["start"], duration_list = np.array(end_seconds) - np.array(start_seconds)
"end": wrow["end"], max_length = round(30 / (HOP_LENGTH / SAMPLE_RATE))
"text": f"[{speaker}]: " + seg["text"][int(wrow["segment-text-start"]):int(wrow["segment-text-end"])] offset_list = np.array(start_seconds)
} chunks = []
for start_ts, end_ts in zip(start_seconds, end_seconds):
start_ts = round(start_ts / (HOP_LENGTH / SAMPLE_RATE))
end_ts = round(end_ts / (HOP_LENGTH / SAMPLE_RATE))
chunk = mel[:, start_ts:end_ts]
chunk = torch.nn.functional.pad(chunk, (0, max_length-chunk.shape[-1]))
chunks.append(chunk)
mel_chunk = torch.stack(chunks, dim=0).to(model.device)
# using 'decode_options1': only support single temperature decoding (no fallbacks)
# result_list2 = model.decode(mel_chunk, decode_options1)
# prepare DecodingOptions
temperatures = kwargs.pop("temperature", None)
compression_ratio_threshold = kwargs.pop("compression_ratio_threshold", None)
logprob_threshold = kwargs.pop("logprob_threshold", None)
no_speech_threshold = kwargs.pop("no_speech_threshold", None)
condition_on_previous_text = kwargs.pop("condition_on_previous_text", None)
initial_prompt = kwargs.pop("initial_prompt", None)
t = 0 # TODO: does not upport temperature sweeping
if t > 0:
# disable beam_size and patience when t > 0
kwargs.pop("beam_size", None)
kwargs.pop("patience", None)
else:
# disable best_of when t == 0
kwargs.pop("best_of", None)
options = DecodingOptions(**kwargs, temperature=t)
mel_chunk_batches = torch.split(mel_chunk, split_size_or_sections=batch_size)
decode_result = []
for mel_chunk_batch in mel_chunk_batches:
decode_result.extend(model.decode(mel_chunk_batch, options))
##############################
### END of parallelization ###
##############################
# post processing: get segments rfom batch-decoded results
input_stride = exact_div(
N_FRAMES, model.dims.n_audio_ctx
) # mel frames per output token: 2
language = kwargs["language"]
task = kwargs["task"]
tokenizer = get_tokenizer(model.is_multilingual, language=language, task=task)
output = post_process_results(
vad_segments,
decode_result,
duration_list,
offset_list,
input_stride,
language,
tokenizer,
no_speech_threshold=no_speech_threshold,
logprob_threshold=logprob_threshold,
verbose=verbose)
return output
def post_process_results(
vad_segments,
result_list,
duration_list,
offset_list,
input_stride,
language,
tokenizer,
no_speech_threshold = None,
logprob_threshold = None,
verbose: Optional[bool] = None,
):
seek = 0
time_precision = (
input_stride * HOP_LENGTH / SAMPLE_RATE
) # time per output token: 0.02 (seconds)
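    # NOTE (illustration, not in the original diff): with whisper's default constants this
    # evaluates to input_stride * HOP_LENGTH / SAMPLE_RATE = 2 * 160 / 16000 = 0.02 s,
    # i.e. each decoded timestamp token advances the clock by 20 ms.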
all_tokens = []
all_segments = []
output = {"segments": []}
def add_segment(
*, start: float, end: float, text_tokens: torch.Tensor, result: DecodingResult
):
text = tokenizer.decode([token for token in text_tokens if token < tokenizer.eot])
if len(text.strip()) == 0: # skip empty text output
return
all_segments.append(
{
"id": len(all_segments),
"seek": seek,
"start": start,
"end": end,
"text": text,
"tokens": text_tokens.tolist(),
"temperature": result.temperature,
"avg_logprob": result.avg_logprob,
"compression_ratio": result.compression_ratio,
"no_speech_prob": result.no_speech_prob,
}
)
if verbose:
print(f"[{format_timestamp(start)} --> {format_timestamp(end)}] {text}")
# process the output
for seg_t, result, segment_duration, timestamp_offset in zip(vad_segments, result_list, duration_list, offset_list):
all_tokens = []
all_segments = []
# segment_duration = segment.shape[-1] * HOP_LENGTH / SAMPLE_RATE
segment_shape = int(segment_duration / (HOP_LENGTH / SAMPLE_RATE))
tokens = torch.tensor(result.tokens)
if no_speech_threshold is not None:
# no voice activity check
should_skip = result.no_speech_prob > no_speech_threshold
if logprob_threshold is not None and result.avg_logprob > logprob_threshold:
# don't skip if the logprob is high enough, despite the no_speech_prob
should_skip = False
if should_skip:
seek += segment_shape # fast-forward to the next segment boundary
continue
timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin)
consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0].add_(1)
if len(consecutive) > 0: # if the output contains two consecutive timestamp tokens
last_slice = 0
for current_slice in consecutive:
sliced_tokens = tokens[last_slice:current_slice]
start_timestamp_position = (
sliced_tokens[0].item() - tokenizer.timestamp_begin
) )
end_timestamp_position = (
sliced_tokens[-1].item() - tokenizer.timestamp_begin
)
add_segment(
start=timestamp_offset + start_timestamp_position * time_precision,
end=timestamp_offset + end_timestamp_position * time_precision,
text_tokens=sliced_tokens[1:-1],
result=result,
)
last_slice = current_slice
last_timestamp_position = (
tokens[last_slice - 1].item() - tokenizer.timestamp_begin
)
seek += last_timestamp_position * input_stride
all_tokens.extend(tokens[: last_slice + 1].tolist())
else:
duration = segment_duration
timestamps = tokens[timestamp_tokens.nonzero().flatten()]
if len(timestamps) > 0 and timestamps[-1].item() != tokenizer.timestamp_begin:
# no consecutive timestamps but it has a timestamp; use the last one.
# single timestamp at the end means no speech after the last timestamp.
last_timestamp_position = timestamps[-1].item() - tokenizer.timestamp_begin
duration = last_timestamp_position * time_precision
# TODO: create segments but split words on new speaker add_segment(
start=timestamp_offset,
end=timestamp_offset + duration,
text_tokens=tokens,
result=result,
)
return result_segments, word_seg seek += segment_shape
all_tokens.extend(tokens.tolist())
class Segment: result = dict(text=tokenizer.decode(all_tokens), segments=all_segments, language=language)
def __init__(self, start, end, speaker=None): output["segments"].append(
self.start = start {
self.end = end "start": seg_t["start"],
self.speaker = speaker "end": seg_t["end"],
"language": result["language"],
"text": result["text"],
"seg-text": [x["text"] for x in result["segments"]],
"seg-start": [x["start"] for x in result["segments"]],
"seg-end": [x["end"] for x in result["segments"]],
}
)
output["language"] = output["segments"][0]["language"]
return output
def cli():
@@ -769,14 +583,14 @@ def cli():
     parser.add_argument("--interpolate_method", default="nearest", choices=["nearest", "linear", "ignore"], help="For word .srt, method to assign timestamps to non-aligned words, or merge them into neighbouring.")
     # vad params
     parser.add_argument("--vad_filter", action="store_true", help="Whether to first perform VAD filtering to target only transcribe within VAD. Produces more accurate alignment + timestamp, requires more GPU memory & compute.")
-    parser.add_argument("--vad_input", default=None, type=str)
+    parser.add_argument("--parallel_bs", default=-1, type=int, help="Enable parallel transcribing if > 1")
     # diarization params
-    parser.add_argument("--diarize", action='store_true')
+    parser.add_argument("--diarize", action="store_true", help="Apply diarization to assign speaker labels to each segment/word")
     parser.add_argument("--min_speakers", default=None, type=int)
     parser.add_argument("--max_speakers", default=None, type=int)
     # output save params
     parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
-    parser.add_argument("--output_type", default="all", choices=["all", "srt", "srt-word", "vtt", "txt", "tsv", "ass", "ass-char"], help="File type for desired output save")
+    parser.add_argument("--output_type", default="all", choices=["all", "srt", "srt-word", "vtt", "txt", "tsv", "ass", "ass-char", "pickle", "vad"], help="File type for desired output save")
     parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")
@@ -799,6 +613,7 @@ def cli():
     parser.add_argument("--logprob_threshold", type=optional_float, default=-1.0, help="if the average log probability is lower than this value, treat the decoding as failed")
     parser.add_argument("--no_speech_threshold", type=optional_float, default=0.6, help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence")
     parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
+    parser.add_argument("--hf_token", type=str, default=None, help="Hugging Face Access Token to access PyAnnote gated models")

     args = parser.parse_args().__dict__
     model_name: str = args.pop("model")
@@ -812,24 +627,33 @@ def cli():
align_from_prev: bool = args.pop("align_from_prev")
interpolate_method: bool = args.pop("interpolate_method")
+ hf_token: str = args.pop("hf_token")
vad_filter: bool = args.pop("vad_filter")
- vad_input: bool = args.pop("vad_input")
+ parallel_bs: int = args.pop("parallel_bs")

diarize: bool = args.pop("diarize")
min_speakers: int = args.pop("min_speakers")
max_speakers: int = args.pop("max_speakers")

vad_pipeline = None
- if vad_input is not None:
-     vad_input = pd.read_csv(vad_input, header=None, sep= " ")
- elif vad_filter:
-     from pyannote.audio import Pipeline
-     vad_pipeline = Pipeline.from_pretrained("pyannote/voice-activity-detection")
+ if vad_filter:
+     if hf_token is None:
+         print("Warning, no huggingface token used, needs to be saved in environment variable, otherwise will throw error loading VAD model...")
+     from pyannote.audio import Inference
+     vad_pipeline = Inference(
+         "pyannote/segmentation",
+         pre_aggregation_hook=lambda segmentation: segmentation,
+         use_auth_token=hf_token,
+         device=torch.device(device),
+     )

diarize_pipeline = None
if diarize:
+     if hf_token is None:
+         print("Warning, no --hf_token used, needs to be saved in environment variable, otherwise will throw error loading diarization model...")
    from pyannote.audio import Pipeline
-     diarize_pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization@2.1")
+     diarize_pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization@2.1",
+         use_auth_token=hf_token)

os.makedirs(output_dir, exist_ok=True)
@@ -857,8 +681,12 @@ def cli():
for audio_path in args.pop("audio"):
    if vad_filter:
-         print("Performing VAD...")
-         result = transcribe_with_vad(model, audio_path, vad_pipeline, temperature=temperature, **args)
+         if parallel_bs > 1:
+             print("Performing VAD and parallel transcribing ...")
+             result = transcribe_with_vad_parallel(model, audio_path, vad_pipeline, temperature=temperature, batch_size=parallel_bs, **args)
+         else:
+             print("Performing VAD...")
+             result = transcribe_with_vad(model, audio_path, vad_pipeline, temperature=temperature, **args)
    else:
        print("Performing transcription...")
        result = transcribe(model, audio_path, temperature=temperature, **args)
@@ -868,6 +696,7 @@ def cli():
    print(f"New language found ({result['language']})! Previous was ({align_metadata['language']}), loading new alignment model for new language...")
    align_model, align_metadata = load_align_model(result["language"], device)

+ print("Performing alignment...")
result_aligned = align(result["segments"], align_model, align_metadata, audio_path, device,
    extend_duration=align_extend, start_from_previous=align_from_prev, interpolate_method=interpolate_method)
@@ -901,7 +730,7 @@ def cli():
# save TSV
if output_type in ["tsv", "all"]:
-     with open(os.path.join(output_dir, audio_basename + ".srt"), "w", encoding="utf-8") as srt:
+     with open(os.path.join(output_dir, audio_basename + ".tsv"), "w", encoding="utf-8") as srt:
        write_tsv(result_aligned["segments"], file=srt)

# save SRT word-level
@@ -915,10 +744,20 @@ def cli():
    with open(os.path.join(output_dir, audio_basename + ".ass"), "w", encoding="utf-8") as ass:
        write_ass(result_aligned["segments"], file=ass)

- # save ASS character-level
- if output_type in ["ass-char", "all"]:
+ # # save ASS character-level
+ if output_type in ["ass-char"]:
    with open(os.path.join(output_dir, audio_basename + ".char.ass"), "w", encoding="utf-8") as ass:
        write_ass(result_aligned["segments"], file=ass, resolution="char")
+ # save pickle
+ if output_type in ["pickle"]:
+     exp_fp = os.path.join(output_dir, audio_basename + ".pkl")
+     pd.DataFrame(result_aligned["segments"]).to_pickle(exp_fp)
+ # save word-level start/end times (.sad)
+ if output_type in ["vad"]:
+     exp_fp = os.path.join(output_dir, audio_basename + ".sad")
+     wrd_segs = pd.concat([x["word-segments"] for x in result_aligned["segments"]])[['start','end']]
+     wrd_segs.to_csv(exp_fp, sep='\t', header=None, index=False)
if __name__ == "__main__":
    cli()
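
To make the new control flow above easier to follow, here is a minimal sketch of driving the same pieces from Python instead of the CLI. It assumes `model` is an already-loaded Whisper model, that `transcribe_with_vad` and `transcribe_with_vad_parallel` are imported from the transcribe module shown above, and that the token, audio path and batch size are placeholder values.

    # Minimal sketch (not the project's CLI): wiring the new VAD / parallel options by hand.
    import torch
    from pyannote.audio import Inference

    hf_token = "hf_..."                                      # token for the gated pyannote models
    device = "cuda" if torch.cuda.is_available() else "cpu"
    parallel_bs = 8                                          # > 1 selects the batched/parallel path

    vad_pipeline = Inference(
        "pyannote/segmentation",
        pre_aggregation_hook=lambda segmentation: segmentation,  # keep raw frame-level scores
        use_auth_token=hf_token,
        device=torch.device(device),
    )

    if parallel_bs > 1:
        result = transcribe_with_vad_parallel(model, "audio.wav", vad_pipeline,
                                              temperature=0.0, batch_size=parallel_bs)
    else:
        result = transcribe_with_vad(model, "audio.wav", vad_pipeline, temperature=0.0)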

whisperx/utils.py

@@ -2,6 +2,7 @@ import os
import zlib
from typing import Callable, TextIO, Iterator, Tuple
import pandas as pd
+ import numpy as np

def exact_div(x, y):
    assert x % y == 0
@@ -64,8 +65,8 @@ def write_vtt(transcript: Iterator[dict], file: TextIO):

def write_tsv(transcript: Iterator[dict], file: TextIO):
    print("start", "end", "text", sep="\t", file=file)
    for segment in transcript:
-         print(round(1000 * segment['start']), file=file, end="\t")
-         print(round(1000 * segment['end']), file=file, end="\t")
+         print(segment['start'], file=file, end="\t")
+         print(segment['end'], file=file, end="\t")
        print(segment['text'].strip().replace("\t", " "), file=file, flush=True)
@@ -206,6 +207,8 @@ def write_ass(transcript: Iterator[dict],
    ass_arr = []

    for segment in transcript:
+         # if "12" in segment['text']:
+         #     import pdb; pdb.set_trace()
        if resolution_key in segment:
            res_segs = pd.DataFrame(segment[resolution_key])
            prev = segment['start']
@@ -214,7 +217,7 @@ def write_ass(transcript: Iterator[dict],
        else:
            speaker_str = ""
        for cdx, crow in res_segs.iterrows():
-             if crow['start'] is not None:
+             if not np.isnan(crow['start']):
                if resolution == "char":
                    idx_0 = cdx
                    idx_1 = cdx + 1
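
The `np.isnan` change above matters because non-aligned words end up as NaN floats in the pandas rows, and `NaN is not None` is still True, so the old guard never skipped them. A small illustration with hypothetical values:

    import numpy as np
    import pandas as pd

    row = pd.Series({"start": np.nan, "end": np.nan})   # a word segment that failed alignment
    print(row["start"] is not None)                      # True  -> the old check lets NaN through
    print(not np.isnan(row["start"]))                    # False -> the new check skips it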

whisperx/vad.py (new file, 185 lines)

@@ -0,0 +1,185 @@
import pandas as pd
import numpy as np
from pyannote.core import Annotation, Segment, SlidingWindowFeature, Timeline
from typing import List, Tuple, Optional
class Binarize:
"""Binarize detection scores using hysteresis thresholding
Parameters
----------
onset : float, optional
Onset threshold. Defaults to 0.5.
offset : float, optional
Offset threshold. Defaults to `onset`.
min_duration_on : float, optional
Remove active regions shorter than that many seconds. Defaults to 0s.
min_duration_off : float, optional
Fill inactive regions shorter than that many seconds. Defaults to 0s.
pad_onset : float, optional
Extend active regions by moving their start time by that many seconds.
Defaults to 0s.
pad_offset : float, optional
Extend active regions by moving their end time by that many seconds.
Defaults to 0s.
max_duration: float
The maximum length of an active segment, divides segment at timestamp with lowest score.
Reference
---------
Gregory Gelly and Jean-Luc Gauvain. "Minimum Word Error Training of
RNN-based Voice Activity Detection", InterSpeech 2015.
Pyannote-audio
"""
def __init__(
self,
onset: float = 0.5,
offset: Optional[float] = None,
min_duration_on: float = 0.0,
min_duration_off: float = 0.0,
pad_onset: float = 0.0,
pad_offset: float = 0.0,
max_duration: float = float('inf')
):
super().__init__()
self.onset = onset
self.offset = offset or onset
self.pad_onset = pad_onset
self.pad_offset = pad_offset
self.min_duration_on = min_duration_on
self.min_duration_off = min_duration_off
self.max_duration = max_duration
def __call__(self, scores: SlidingWindowFeature) -> Annotation:
"""Binarize detection scores
Parameters
----------
scores : SlidingWindowFeature
Detection scores.
Returns
-------
active : Annotation
Binarized scores.
"""
num_frames, num_classes = scores.data.shape
frames = scores.sliding_window
timestamps = [frames[i].middle for i in range(num_frames)]
# annotation meant to store 'active' regions
active = Annotation()
for k, k_scores in enumerate(scores.data.T):
label = k if scores.labels is None else scores.labels[k]
# initial state
start = timestamps[0]
is_active = k_scores[0] > self.onset
curr_scores = [k_scores[0]]
curr_timestamps = [start]
for t, y in zip(timestamps[1:], k_scores[1:]):
# currently active
if is_active:
curr_duration = t - start
if curr_duration > self.max_duration:
# if curr_duration > 15:
# import pdb; pdb.set_trace()
search_after = len(curr_scores) // 2
# divide segment
min_score_div_idx = search_after + np.argmin(curr_scores[search_after:])
min_score_t = curr_timestamps[min_score_div_idx]
region = Segment(start - self.pad_onset, min_score_t + self.pad_offset)
active[region, k] = label
start = curr_timestamps[min_score_div_idx]
curr_scores = curr_scores[min_score_div_idx+1:]
curr_timestamps = curr_timestamps[min_score_div_idx+1:]
# switching from active to inactive
elif y < self.offset:
region = Segment(start - self.pad_onset, t + self.pad_offset)
active[region, k] = label
start = t
is_active = False
curr_scores = []
curr_timestamps = []
# currently inactive
else:
# switching from inactive to active
if y > self.onset:
start = t
is_active = True
curr_scores.append(y)
curr_timestamps.append(t)
# if active at the end, add final region
if is_active:
region = Segment(start - self.pad_onset, t + self.pad_offset)
active[region, k] = label
# because of padding, some active regions might be overlapping: merge them.
# also: fill same speaker gaps shorter than min_duration_off
if self.pad_offset > 0.0 or self.pad_onset > 0.0 or self.min_duration_off > 0.0:
if self.max_duration < float("inf"):
raise NotImplementedError(f"This would break current max_duration param")
active = active.support(collar=self.min_duration_off)
# remove tracks shorter than min_duration_on
if self.min_duration_on > 0:
for segment, track in list(active.itertracks()):
if segment.duration < self.min_duration_on:
del active[segment, track]
return active
def merge_vad(vad_arr, pad_onset=0.0, pad_offset=0.0, min_duration_off=0.0, min_duration_on=0.0):
active = Annotation()
for k, vad_t in enumerate(vad_arr):
region = Segment(vad_t[0] - pad_onset, vad_t[1] + pad_offset)
active[region, k] = 1
if pad_offset > 0.0 or pad_onset > 0.0 or min_duration_off > 0.0:
active = active.support(collar=min_duration_off)
# remove tracks shorter than min_duration_on
if min_duration_on > 0:
for segment, track in list(active.itertracks()):
if segment.duration < min_duration_on:
del active[segment, track]
active = active.for_json()
active_segs = pd.DataFrame([x['segment'] for x in active['content']])
return active_segs
if __name__ == "__main__":
# from pyannote.audio import Inference
# hook = lambda segmentation: segmentation
# inference = Inference("pyannote/segmentation", pre_aggregation_hook=hook)
# audio = "/tmp/11962.wav"
# scores = inference(audio)
# binarize = Binarize(max_duration=15)
# anno = binarize(scores)
# res = []
# for ann in anno.get_timeline():
# res.append((ann.start, ann.end))
# res = pd.DataFrame(res)
# res[2] = res[1] - res[0]
import pandas as pd
input_fp = "tt298650_sync.wav"
df = pd.read_csv(f"/work/maxbain/tmp/{input_fp}.sad", sep=" ", header=None)
print(len(df))
N = 0.15
g = df[0].sub(df[1].shift())
input_base = input_fp.split('.')[0]
df = df.groupby(g.gt(N).cumsum()).agg({0:'min', 1:'max'})
df.to_csv(f"/work/maxbain/tmp/{input_base}.lab", header=None, index=False, sep=" ")
print(df)
import pdb; pdb.set_trace()
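
For reference, a runnable variant of the commented-out example at the top of this `__main__` block, showing how `Binarize` is meant to be applied to raw pyannote segmentation scores (the audio path and token below are placeholders, and pyannote.audio must be installed with access to the gated `pyannote/segmentation` model):

    import pandas as pd
    from pyannote.audio import Inference
    from whisperx.vad import Binarize

    inference = Inference(
        "pyannote/segmentation",
        pre_aggregation_hook=lambda segmentation: segmentation,  # keep raw per-frame scores
        use_auth_token="hf_...",                                 # placeholder token
    )
    scores = inference("audio.wav")           # SlidingWindowFeature of segmentation scores
    binarize = Binarize(max_duration=15)      # split any active region longer than 15 s
    speech = binarize(scores)                 # pyannote Annotation of speech regions

    segments = pd.DataFrame(
        [(seg.start, seg.end) for seg in speech.get_timeline()],
        columns=["start", "end"],
    )
    print(segments)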