mirror of
https://github.com/m-bain/whisperX.git
synced 2025-07-01 18:17:27 -04:00
Compare commits
37 Commits
v3.3.2
...
improve-co
Author | SHA1 | Date | |
---|---|---|---|
88939b9e8a | |||
8c58c54635 | |||
0d9807adc5 | |||
4db839018c | |||
f8d11df727 | |||
44e8bf5bb6 | |||
7b3c9ce629 | |||
36d2622e27 | |||
8bfa12193b | |||
acbeba6057 | |||
fca563a782 | |||
2117909bf6 | |||
de0d8fe313 | |||
355f8e06f7 | |||
86e2b3ee74 | |||
70c639cdb5 | |||
235536e28d | |||
12604a48ea | |||
ffbc73664c | |||
289eadfc76 | |||
22a93f2932 | |||
1027367b79 | |||
5e54b872a9 | |||
6be02cccfa | |||
2f93e029c7 | |||
024bc8481b | |||
f286e7f3de | |||
73e644559d | |||
1ec527375a | |||
6695426a85 | |||
7a98456321 | |||
aaddb83aa5 | |||
c288f4812a | |||
4ebfb078c5 | |||
65b2332e13 | |||
69281f3a29 | |||
79eb8fa53d |
20
README.md
20
README.md
@ -129,7 +129,7 @@ To **enable Speaker Diarization**, include your Hugging Face access token (read)
|
||||
|
||||
Run whisper on example segment (using default params, whisper small) add `--highlight_words True` to visualise word timings in the .srt file.
|
||||
|
||||
whisperx examples/sample01.wav
|
||||
whisperx path/to/audio.wav
|
||||
|
||||
|
||||
Result using *WhisperX* with forced alignment to wav2vec2.0 large:
|
||||
@ -143,27 +143,27 @@ https://user-images.githubusercontent.com/36994049/207743923-b4f0d537-29ae-4be2-
|
||||
|
||||
For increased timestamp accuracy, at the cost of higher gpu mem, use bigger models (bigger alignment model not found to be that helpful, see paper) e.g.
|
||||
|
||||
whisperx examples/sample01.wav --model large-v2 --align_model WAV2VEC2_ASR_LARGE_LV60K_960H --batch_size 4
|
||||
whisperx path/to/audio.wav --model large-v2 --align_model WAV2VEC2_ASR_LARGE_LV60K_960H --batch_size 4
|
||||
|
||||
|
||||
To label the transcript with speaker ID's (set number of speakers if known e.g. `--min_speakers 2` `--max_speakers 2`):
|
||||
|
||||
whisperx examples/sample01.wav --model large-v2 --diarize --highlight_words True
|
||||
whisperx path/to/audio.wav --model large-v2 --diarize --highlight_words True
|
||||
|
||||
To run on CPU instead of GPU (and for running on Mac OS X):
|
||||
|
||||
whisperx examples/sample01.wav --compute_type int8
|
||||
whisperx path/to/audio.wav --compute_type int8
|
||||
|
||||
### Other languages
|
||||
|
||||
The phoneme ASR alignment model is *language-specific*, for tested languages these models are [automatically picked from torchaudio pipelines or huggingface](https://github.com/m-bain/whisperX/blob/e909f2f766b23b2000f2d95df41f9b844ac53e49/whisperx/transcribe.py#L22).
|
||||
The phoneme ASR alignment model is *language-specific*, for tested languages these models are [automatically picked from torchaudio pipelines or huggingface](https://github.com/m-bain/whisperX/blob/f2da2f858e99e4211fe4f64b5f2938b007827e17/whisperx/alignment.py#L24-L58).
|
||||
Just pass in the `--language` code, and use the whisper `--model large`.
|
||||
|
||||
Currently default models provided for `{en, fr, de, es, it, ja, zh, nl, uk, pt}`. If the detected language is not in this list, you need to find a phoneme-based ASR model from [huggingface model hub](https://huggingface.co/models) and test it on your data.
|
||||
Currently default models provided for `{en, fr, de, es, it}` via torchaudio pipelines and many other languages via Hugging Face. Please find the list of currently supported languages under `DEFAULT_ALIGN_MODELS_HF` on [alignment.py](https://github.com/m-bain/whisperX/blob/main/whisperx/alignment.py). If the detected language is not in this list, you need to find a phoneme-based ASR model from [huggingface model hub](https://huggingface.co/models) and test it on your data.
|
||||
|
||||
|
||||
#### E.g. German
|
||||
whisperx --model large-v2 --language de examples/sample_de_01.wav
|
||||
whisperx --model large-v2 --language de path/to/audio.wav
|
||||
|
||||
https://user-images.githubusercontent.com/36994049/208298811-e36002ba-3698-4731-97d4-0aebd07e0eb3.mov
|
||||
|
||||
@ -278,7 +278,7 @@ Bug finding and pull requests are also highly appreciated to keep this project g
|
||||
|
||||
* [ ] Add benchmarking code (TEDLIUM for spd/WER & word segmentation)
|
||||
|
||||
* [ ] Allow silero-vad as alternative VAD option
|
||||
* [x] Allow silero-vad as alternative VAD option
|
||||
|
||||
* [ ] Improve diarization (word level). *Harder than first thought...*
|
||||
|
||||
@ -300,7 +300,9 @@ Borrows important alignment code from [PyTorch tutorial on forced alignment](htt
|
||||
And uses the wonderful pyannote VAD / Diarization https://github.com/pyannote/pyannote-audio
|
||||
|
||||
|
||||
Valuable VAD & Diarization Models from [pyannote audio](https://github.com/pyannote/pyannote-audio)
|
||||
Valuable VAD & Diarization Models from:
|
||||
- [pyannote audio][https://github.com/pyannote/pyannote-audio]
|
||||
- [silero vad][https://github.com/snakers4/silero-vad]
|
||||
|
||||
Great backend from [faster-whisper](https://github.com/guillaumekln/faster-whisper) and [CTranslate2](https://github.com/OpenNMT/CTranslate2)
|
||||
|
||||
|
@ -1,4 +1,4 @@
|
||||
from .transcribe import load_model
|
||||
from .alignment import load_align_model, align
|
||||
from .audio import load_audio
|
||||
from .diarize import assign_word_speakers, DiarizationPipeline
|
||||
from .diarize import assign_word_speakers, DiarizationPipeline
|
||||
from .asr import load_model
|
||||
|
@ -1,7 +1,9 @@
|
||||
""""
|
||||
"""
|
||||
Forced Alignment with Whisper
|
||||
C. Max Bain
|
||||
"""
|
||||
import math
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Iterable, Optional, Union, List
|
||||
|
||||
@ -13,8 +15,13 @@ from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
|
||||
|
||||
from .audio import SAMPLE_RATE, load_audio
|
||||
from .utils import interpolate_nans
|
||||
from .types import AlignedTranscriptionResult, SingleSegment, SingleAlignedSegment, SingleWordSegment
|
||||
import nltk
|
||||
from .types import (
|
||||
AlignedTranscriptionResult,
|
||||
SingleSegment,
|
||||
SingleAlignedSegment,
|
||||
SingleWordSegment,
|
||||
SegmentData,
|
||||
)
|
||||
from nltk.tokenize.punkt import PunktSentenceTokenizer, PunktParameters
|
||||
|
||||
PUNKT_ABBREVIATIONS = ['dr', 'vs', 'mr', 'mrs', 'prof']
|
||||
@ -62,6 +69,8 @@ DEFAULT_ALIGN_MODELS_HF = {
|
||||
"eu": "stefan-it/wav2vec2-large-xlsr-53-basque",
|
||||
"gl": "ifrz/wav2vec2-large-xlsr-galician",
|
||||
"ka": "xsway/wav2vec2-large-xlsr-georgian",
|
||||
"lv": "jimregan/wav2vec2-large-xlsr-latvian-cv",
|
||||
"tl": "Khalsuu/filipino-wav2vec2-l-xls-r-300m-official",
|
||||
}
|
||||
|
||||
|
||||
@ -131,6 +140,8 @@ def align(
|
||||
|
||||
# 1. Preprocess to keep only characters in dictionary
|
||||
total_segments = len(transcript)
|
||||
# Store temporary processing values
|
||||
segment_data: dict[int, SegmentData] = {}
|
||||
for sdx, segment in enumerate(transcript):
|
||||
# strip spaces at beginning / end, but keep track of the amount.
|
||||
if print_progress:
|
||||
@ -163,10 +174,17 @@ def align(
|
||||
elif char_ in model_dictionary.keys():
|
||||
clean_char.append(char_)
|
||||
clean_cdx.append(cdx)
|
||||
else:
|
||||
# add placeholder
|
||||
clean_char.append('*')
|
||||
clean_cdx.append(cdx)
|
||||
|
||||
clean_wdx = []
|
||||
for wdx, wrd in enumerate(per_word):
|
||||
if any([c in model_dictionary.keys() for c in wrd]):
|
||||
if any([c in model_dictionary.keys() for c in wrd.lower()]):
|
||||
clean_wdx.append(wdx)
|
||||
else:
|
||||
# index for placeholder
|
||||
clean_wdx.append(wdx)
|
||||
|
||||
|
||||
@ -175,11 +193,13 @@ def align(
|
||||
sentence_splitter = PunktSentenceTokenizer(punkt_param)
|
||||
sentence_spans = list(sentence_splitter.span_tokenize(text))
|
||||
|
||||
segment["clean_char"] = clean_char
|
||||
segment["clean_cdx"] = clean_cdx
|
||||
segment["clean_wdx"] = clean_wdx
|
||||
segment["sentence_spans"] = sentence_spans
|
||||
|
||||
segment_data[sdx] = {
|
||||
"clean_char": clean_char,
|
||||
"clean_cdx": clean_cdx,
|
||||
"clean_wdx": clean_wdx,
|
||||
"sentence_spans": sentence_spans
|
||||
}
|
||||
|
||||
aligned_segments: List[SingleAlignedSegment] = []
|
||||
|
||||
# 2. Get prediction matrix from alignment model & align
|
||||
@ -194,13 +214,14 @@ def align(
|
||||
"end": t2,
|
||||
"text": text,
|
||||
"words": [],
|
||||
"chars": None,
|
||||
}
|
||||
|
||||
if return_char_alignments:
|
||||
aligned_seg["chars"] = []
|
||||
|
||||
# check we can align
|
||||
if len(segment["clean_char"]) == 0:
|
||||
if len(segment_data[sdx]["clean_char"]) == 0:
|
||||
print(f'Failed to align segment ("{segment["text"]}"): no characters in this segment found in model dictionary, resorting to original...')
|
||||
aligned_segments.append(aligned_seg)
|
||||
continue
|
||||
@ -210,8 +231,8 @@ def align(
|
||||
aligned_segments.append(aligned_seg)
|
||||
continue
|
||||
|
||||
text_clean = "".join(segment["clean_char"])
|
||||
tokens = [model_dictionary[c] for c in text_clean]
|
||||
text_clean = "".join(segment_data[sdx]["clean_char"])
|
||||
tokens = [model_dictionary.get(c, -1) for c in text_clean]
|
||||
|
||||
f1 = int(t1 * SAMPLE_RATE)
|
||||
f2 = int(t2 * SAMPLE_RATE)
|
||||
@ -244,7 +265,8 @@ def align(
|
||||
blank_id = code
|
||||
|
||||
trellis = get_trellis(emission, tokens, blank_id)
|
||||
path = backtrack(trellis, emission, tokens, blank_id)
|
||||
# path = backtrack(trellis, emission, tokens, blank_id)
|
||||
path = backtrack_beam(trellis, emission, tokens, blank_id, beam_width=2)
|
||||
|
||||
if path is None:
|
||||
print(f'Failed to align segment ("{segment["text"]}"): backtrack failed, resorting to original...')
|
||||
@ -253,7 +275,7 @@ def align(
|
||||
|
||||
char_segments = merge_repeats(path, text_clean)
|
||||
|
||||
duration = t2 -t1
|
||||
duration = t2 - t1
|
||||
ratio = duration * waveform_segment.size(0) / (trellis.size(0) - 1)
|
||||
|
||||
# assign timestamps to aligned characters
|
||||
@ -261,8 +283,8 @@ def align(
|
||||
word_idx = 0
|
||||
for cdx, char in enumerate(text):
|
||||
start, end, score = None, None, None
|
||||
if cdx in segment["clean_cdx"]:
|
||||
char_seg = char_segments[segment["clean_cdx"].index(cdx)]
|
||||
if cdx in segment_data[sdx]["clean_cdx"]:
|
||||
char_seg = char_segments[segment_data[sdx]["clean_cdx"].index(cdx)]
|
||||
start = round(char_seg.start * ratio + t1, 3)
|
||||
end = round(char_seg.end * ratio + t1, 3)
|
||||
score = round(char_seg.score, 3)
|
||||
@ -288,10 +310,10 @@ def align(
|
||||
aligned_subsegments = []
|
||||
# assign sentence_idx to each character index
|
||||
char_segments_arr["sentence-idx"] = None
|
||||
for sdx, (sstart, send) in enumerate(segment["sentence_spans"]):
|
||||
for sdx2, (sstart, send) in enumerate(segment_data[sdx]["sentence_spans"]):
|
||||
curr_chars = char_segments_arr.loc[(char_segments_arr.index >= sstart) & (char_segments_arr.index <= send)]
|
||||
char_segments_arr.loc[(char_segments_arr.index >= sstart) & (char_segments_arr.index <= send), "sentence-idx"] = sdx
|
||||
|
||||
char_segments_arr.loc[(char_segments_arr.index >= sstart) & (char_segments_arr.index <= send), "sentence-idx"] = sdx2
|
||||
|
||||
sentence_text = text[sstart:send]
|
||||
sentence_start = curr_chars["start"].min()
|
||||
end_chars = curr_chars[curr_chars["char"] != ' ']
|
||||
@ -360,70 +382,203 @@ def align(
|
||||
"""
|
||||
source: https://pytorch.org/tutorials/intermediate/forced_alignment_with_torchaudio_tutorial.html
|
||||
"""
|
||||
|
||||
|
||||
def get_trellis(emission, tokens, blank_id=0):
|
||||
num_frame = emission.size(0)
|
||||
num_tokens = len(tokens)
|
||||
|
||||
# Trellis has extra diemsions for both time axis and tokens.
|
||||
# The extra dim for tokens represents <SoS> (start-of-sentence)
|
||||
# The extra dim for time axis is for simplification of the code.
|
||||
trellis = torch.empty((num_frame + 1, num_tokens + 1))
|
||||
trellis[0, 0] = 0
|
||||
trellis[1:, 0] = torch.cumsum(emission[:, 0], 0)
|
||||
trellis[0, -num_tokens:] = -float("inf")
|
||||
trellis[-num_tokens:, 0] = float("inf")
|
||||
trellis = torch.zeros((num_frame, num_tokens))
|
||||
trellis[1:, 0] = torch.cumsum(emission[1:, blank_id], 0)
|
||||
trellis[0, 1:] = -float("inf")
|
||||
trellis[-num_tokens + 1:, 0] = float("inf")
|
||||
|
||||
for t in range(num_frame):
|
||||
for t in range(num_frame - 1):
|
||||
trellis[t + 1, 1:] = torch.maximum(
|
||||
# Score for staying at the same token
|
||||
trellis[t, 1:] + emission[t, blank_id],
|
||||
# Score for changing to the next token
|
||||
trellis[t, :-1] + emission[t, tokens],
|
||||
# trellis[t, :-1] + emission[t, tokens[1:]],
|
||||
trellis[t, :-1] + get_wildcard_emission(emission[t], tokens[1:], blank_id),
|
||||
)
|
||||
return trellis
|
||||
|
||||
|
||||
def get_wildcard_emission(frame_emission, tokens, blank_id):
|
||||
"""Processing token emission scores containing wildcards (vectorized version)
|
||||
|
||||
Args:
|
||||
frame_emission: Emission probability vector for the current frame
|
||||
tokens: List of token indices
|
||||
blank_id: ID of the blank token
|
||||
|
||||
Returns:
|
||||
tensor: Maximum probability score for each token position
|
||||
"""
|
||||
assert 0 <= blank_id < len(frame_emission)
|
||||
|
||||
# Convert tokens to a tensor if they are not already
|
||||
tokens = torch.tensor(tokens) if not isinstance(tokens, torch.Tensor) else tokens
|
||||
|
||||
# Create a mask to identify wildcard positions
|
||||
wildcard_mask = (tokens == -1)
|
||||
|
||||
# Get scores for non-wildcard positions
|
||||
regular_scores = frame_emission[tokens.clamp(min=0)] # clamp to avoid -1 index
|
||||
|
||||
# Create a mask and compute the maximum value without modifying frame_emission
|
||||
max_valid_score = frame_emission.clone() # Create a copy
|
||||
max_valid_score[blank_id] = float('-inf') # Modify the copy to exclude the blank token
|
||||
max_valid_score = max_valid_score.max()
|
||||
|
||||
# Use where operation to combine results
|
||||
result = torch.where(wildcard_mask, max_valid_score, regular_scores)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@dataclass
|
||||
class Point:
|
||||
token_index: int
|
||||
time_index: int
|
||||
score: float
|
||||
|
||||
|
||||
def backtrack(trellis, emission, tokens, blank_id=0):
|
||||
# Note:
|
||||
# j and t are indices for trellis, which has extra dimensions
|
||||
# for time and tokens at the beginning.
|
||||
# When referring to time frame index `T` in trellis,
|
||||
# the corresponding index in emission is `T-1`.
|
||||
# Similarly, when referring to token index `J` in trellis,
|
||||
# the corresponding index in transcript is `J-1`.
|
||||
j = trellis.size(1) - 1
|
||||
t_start = torch.argmax(trellis[:, j]).item()
|
||||
t, j = trellis.size(0) - 1, trellis.size(1) - 1
|
||||
|
||||
path = [Point(j, t, emission[t, blank_id].exp().item())]
|
||||
while j > 0:
|
||||
# Should not happen but just in case
|
||||
assert t > 0
|
||||
|
||||
path = []
|
||||
for t in range(t_start, 0, -1):
|
||||
# 1. Figure out if the current position was stay or change
|
||||
# Note (again):
|
||||
# `emission[J-1]` is the emission at time frame `J` of trellis dimension.
|
||||
# Score for token staying the same from time frame J-1 to T.
|
||||
stayed = trellis[t - 1, j] + emission[t - 1, blank_id]
|
||||
# Score for token changing from C-1 at T-1 to J at T.
|
||||
changed = trellis[t - 1, j - 1] + emission[t - 1, tokens[j - 1]]
|
||||
# Frame-wise score of stay vs change
|
||||
p_stay = emission[t - 1, blank_id]
|
||||
# p_change = emission[t - 1, tokens[j]]
|
||||
p_change = get_wildcard_emission(emission[t - 1], [tokens[j]], blank_id)[0]
|
||||
|
||||
# 2. Store the path with frame-wise probability.
|
||||
prob = emission[t - 1, tokens[j - 1] if changed > stayed else 0].exp().item()
|
||||
# Return token index and time index in non-trellis coordinate.
|
||||
path.append(Point(j - 1, t - 1, prob))
|
||||
# Context-aware score for stay vs change
|
||||
stayed = trellis[t - 1, j] + p_stay
|
||||
changed = trellis[t - 1, j - 1] + p_change
|
||||
|
||||
# 3. Update the token
|
||||
# Update position
|
||||
t -= 1
|
||||
if changed > stayed:
|
||||
j -= 1
|
||||
if j == 0:
|
||||
break
|
||||
else:
|
||||
# failed
|
||||
return None
|
||||
|
||||
# Store the path with frame-wise probability.
|
||||
prob = (p_change if changed > stayed else p_stay).exp().item()
|
||||
path.append(Point(j, t, prob))
|
||||
|
||||
# Now j == 0, which means, it reached the SoS.
|
||||
# Fill up the rest for the sake of visualization
|
||||
while t > 0:
|
||||
prob = emission[t - 1, blank_id].exp().item()
|
||||
path.append(Point(j, t - 1, prob))
|
||||
t -= 1
|
||||
|
||||
return path[::-1]
|
||||
|
||||
|
||||
|
||||
@dataclass
|
||||
class Path:
|
||||
points: List[Point]
|
||||
score: float
|
||||
|
||||
|
||||
@dataclass
|
||||
class BeamState:
|
||||
"""State in beam search."""
|
||||
token_index: int # Current token position
|
||||
time_index: int # Current time step
|
||||
score: float # Cumulative score
|
||||
path: List[Point] # Path history
|
||||
|
||||
|
||||
def backtrack_beam(trellis, emission, tokens, blank_id=0, beam_width=5):
|
||||
"""Standard CTC beam search backtracking implementation.
|
||||
|
||||
Args:
|
||||
trellis (torch.Tensor): The trellis (or lattice) of shape (T, N), where T is the number of time steps
|
||||
and N is the number of tokens (including the blank token).
|
||||
emission (torch.Tensor): The emission probabilities of shape (T, N).
|
||||
tokens (List[int]): List of token indices (excluding the blank token).
|
||||
blank_id (int, optional): The ID of the blank token. Defaults to 0.
|
||||
beam_width (int, optional): The number of top paths to keep during beam search. Defaults to 5.
|
||||
|
||||
Returns:
|
||||
List[Point]: the best path
|
||||
"""
|
||||
T, J = trellis.size(0) - 1, trellis.size(1) - 1
|
||||
|
||||
init_state = BeamState(
|
||||
token_index=J,
|
||||
time_index=T,
|
||||
score=trellis[T, J],
|
||||
path=[Point(J, T, emission[T, blank_id].exp().item())]
|
||||
)
|
||||
|
||||
beams = [init_state]
|
||||
|
||||
while beams and beams[0].token_index > 0:
|
||||
next_beams = []
|
||||
|
||||
for beam in beams:
|
||||
t, j = beam.time_index, beam.token_index
|
||||
|
||||
if t <= 0:
|
||||
continue
|
||||
|
||||
p_stay = emission[t - 1, blank_id]
|
||||
p_change = get_wildcard_emission(emission[t - 1], [tokens[j]], blank_id)[0]
|
||||
|
||||
stay_score = trellis[t - 1, j]
|
||||
change_score = trellis[t - 1, j - 1] if j > 0 else float('-inf')
|
||||
|
||||
# Stay
|
||||
if not math.isinf(stay_score):
|
||||
new_path = beam.path.copy()
|
||||
new_path.append(Point(j, t - 1, p_stay.exp().item()))
|
||||
next_beams.append(BeamState(
|
||||
token_index=j,
|
||||
time_index=t - 1,
|
||||
score=stay_score,
|
||||
path=new_path
|
||||
))
|
||||
|
||||
# Change
|
||||
if j > 0 and not math.isinf(change_score):
|
||||
new_path = beam.path.copy()
|
||||
new_path.append(Point(j - 1, t - 1, p_change.exp().item()))
|
||||
next_beams.append(BeamState(
|
||||
token_index=j - 1,
|
||||
time_index=t - 1,
|
||||
score=change_score,
|
||||
path=new_path
|
||||
))
|
||||
|
||||
# sort by score
|
||||
beams = sorted(next_beams, key=lambda x: x.score, reverse=True)[:beam_width]
|
||||
|
||||
if not beams:
|
||||
break
|
||||
|
||||
if not beams:
|
||||
return None
|
||||
|
||||
best_beam = beams[0]
|
||||
t = best_beam.time_index
|
||||
j = best_beam.token_index
|
||||
while t > 0:
|
||||
prob = emission[t - 1, blank_id].exp().item()
|
||||
best_beam.path.append(Point(j, t - 1, prob))
|
||||
t -= 1
|
||||
|
||||
return best_beam.path[::-1]
|
||||
|
||||
|
||||
# Merge the labels
|
||||
@dataclass
|
||||
class Segment:
|
||||
|
169
whisperx/asr.py
169
whisperx/asr.py
@ -1,6 +1,5 @@
|
||||
import os
|
||||
import warnings
|
||||
from typing import List, NamedTuple, Optional, Union
|
||||
from typing import List, Optional, Union
|
||||
from dataclasses import replace
|
||||
|
||||
import ctranslate2
|
||||
@ -14,10 +13,12 @@ from transformers.pipelines.pt_utils import PipelineIterator
|
||||
|
||||
from .audio import N_SAMPLES, SAMPLE_RATE, load_audio, log_mel_spectrogram
|
||||
from .types import SingleSegment, TranscriptionResult
|
||||
from .vad import VoiceActivitySegmentation, load_vad_model, merge_chunks
|
||||
|
||||
from .vads import Vad, Silero, Pyannote
|
||||
|
||||
def find_numeral_symbol_tokens(tokenizer):
|
||||
"""
|
||||
Finds tokens that represent numeral and symbols.
|
||||
"""
|
||||
numeral_symbol_tokens = []
|
||||
for i in range(tokenizer.eot):
|
||||
token = tokenizer.decode([i]).removeprefix(" ")
|
||||
@ -27,10 +28,10 @@ def find_numeral_symbol_tokens(tokenizer):
|
||||
return numeral_symbol_tokens
|
||||
|
||||
class WhisperModel(faster_whisper.WhisperModel):
|
||||
'''
|
||||
FasterWhisperModel provides batched inference for faster-whisper.
|
||||
Currently only works in non-timestamp mode and fixed prompt for all samples in batch.
|
||||
'''
|
||||
"""
|
||||
Wrapper around faster-whisper's WhisperModel to enable batched inference.
|
||||
Currently, it only supports non-timestamp mode and a fixed prompt for all samples in a batch.
|
||||
"""
|
||||
|
||||
def generate_segment_batched(
|
||||
self,
|
||||
@ -39,13 +40,28 @@ class WhisperModel(faster_whisper.WhisperModel):
|
||||
options: TranscriptionOptions,
|
||||
encoder_output=None,
|
||||
):
|
||||
"""
|
||||
Generates transcription for a batch of audio segments.
|
||||
|
||||
Args:
|
||||
features: The input audio features.
|
||||
tokenizer: The tokenizer used to decode the generated tokens.
|
||||
options: Transcription options.
|
||||
encoder_output: Output from the encoder model.
|
||||
|
||||
Returns:
|
||||
The decoded transcription text.
|
||||
"""
|
||||
batch_size = features.shape[0]
|
||||
# Initialize tokens and prompt for the generation process.
|
||||
all_tokens = []
|
||||
prompt_reset_since = 0
|
||||
# Check if an initial prompt is provided and handle it.
|
||||
if options.initial_prompt is not None:
|
||||
initial_prompt = " " + options.initial_prompt.strip()
|
||||
initial_prompt_tokens = tokenizer.encode(initial_prompt)
|
||||
all_tokens.extend(initial_prompt_tokens)
|
||||
# Prepare the prompt for the current batch.
|
||||
previous_tokens = all_tokens[prompt_reset_since:]
|
||||
prompt = self.get_prompt(
|
||||
tokenizer,
|
||||
@ -53,118 +69,58 @@ class WhisperModel(faster_whisper.WhisperModel):
|
||||
without_timestamps=options.without_timestamps,
|
||||
prefix=options.prefix,
|
||||
)
|
||||
|
||||
|
||||
# Encode the features to obtain the encoder output.
|
||||
encoder_output = self.encode(features)
|
||||
|
||||
# Determine the maximum initial timestamp index based on the options.
|
||||
max_initial_timestamp_index = int(
|
||||
round(options.max_initial_timestamp / self.time_precision)
|
||||
)
|
||||
|
||||
# Generate the transcription result for the batch.
|
||||
result = self.model.generate(
|
||||
encoder_output,
|
||||
[prompt] * batch_size,
|
||||
beam_size=options.beam_size,
|
||||
patience=options.patience,
|
||||
length_penalty=options.length_penalty,
|
||||
max_length=self.max_length,
|
||||
suppress_blank=options.suppress_blank,
|
||||
suppress_tokens=options.suppress_tokens,
|
||||
)
|
||||
encoder_output,
|
||||
[prompt] * batch_size,
|
||||
beam_size=options.beam_size,
|
||||
patience=options.patience,
|
||||
length_penalty=options.length_penalty,
|
||||
max_length=self.max_length,
|
||||
suppress_blank=options.suppress_blank,
|
||||
suppress_tokens=options.suppress_tokens,
|
||||
)
|
||||
|
||||
# Extract the token sequences from the result.
|
||||
tokens_batch = [x.sequences_ids[0] for x in result]
|
||||
|
||||
# Define an inner function to decode the tokens for each batch.
|
||||
def decode_batch(tokens: List[List[int]]) -> str:
|
||||
res = []
|
||||
for tk in tokens:
|
||||
res.append([token for token in tk if token < tokenizer.eot])
|
||||
# text_tokens = [token for token in tokens if token < self.eot]
|
||||
return tokenizer.tokenizer.decode_batch(res)
|
||||
|
||||
# Decode the tokens to get the transcription text.
|
||||
text = decode_batch(tokens_batch)
|
||||
|
||||
return text
|
||||
|
||||
def encode(self, features: np.ndarray) -> ctranslate2.StorageView:
|
||||
# When the model is running on multiple GPUs, the encoder output should be moved
|
||||
# to the CPU since we don't know which GPU will handle the next job.
|
||||
"""
|
||||
Encodes the audio features using the CTranslate2 storage.
|
||||
|
||||
When the model is running on multiple GPUs, the encoder output should be moved
|
||||
to the CPU since we don't know which GPU will handle the next job.
|
||||
"""
|
||||
# When the model is running on multiple GPUs, the encoder output should be moved to the CPU.
|
||||
to_cpu = self.model.device == "cuda" and len(self.model.device_index) > 1
|
||||
# unsqueeze if batch size = 1
|
||||
# If the batch size is 1, unsqueeze the features to ensure it is a 3D array.
|
||||
if len(features.shape) == 2:
|
||||
features = np.expand_dims(features, 0)
|
||||
features = get_ctranslate2_storage(features)
|
||||
|
||||
# call the model
|
||||
return self.model.encode(features, to_cpu=to_cpu)
|
||||
|
||||
class FasterWhisperPipeline(Pipeline):
|
||||
"""
|
||||
Huggingface Pipeline wrapper for FasterWhisperModel.
|
||||
"""
|
||||
# TODO:
|
||||
# - add support for timestamp mode
|
||||
# - add support for custom inference kwargs
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: WhisperModel,
|
||||
vad: VoiceActivitySegmentation,
|
||||
vad_params: dict,
|
||||
options: TranscriptionOptions,
|
||||
tokenizer: Optional[Tokenizer] = None,
|
||||
device: Union[int, str, "torch.device"] = -1,
|
||||
framework="pt",
|
||||
language: Optional[str] = None,
|
||||
suppress_numerals: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
self.model = model
|
||||
self.tokenizer = tokenizer
|
||||
self.options = options
|
||||
self.preset_language = language
|
||||
self.suppress_numerals = suppress_numerals
|
||||
self._batch_size = kwargs.pop("batch_size", None)
|
||||
self._num_workers = 1
|
||||
self._preprocess_params, self._forward_params, self._postprocess_params = self._sanitize_parameters(**kwargs)
|
||||
self.call_count = 0
|
||||
self.framework = framework
|
||||
if self.framework == "pt":
|
||||
if isinstance(device, torch.device):
|
||||
self.device = device
|
||||
elif isinstance(device, str):
|
||||
self.device = torch.device(device)
|
||||
elif device < 0:
|
||||
self.device = torch.device("cpu")
|
||||
else:
|
||||
self.device = torch.device(f"cuda:{device}")
|
||||
else:
|
||||
self.device = device
|
||||
|
||||
super(Pipeline, self).__init__()
|
||||
self.vad_model = vad
|
||||
self._vad_params = vad_params
|
||||
|
||||
def _sanitize_parameters(self, **kwargs):
|
||||
preprocess_kwargs = {}
|
||||
if "tokenizer" in kwargs:
|
||||
preprocess_kwargs["maybe_arg"] = kwargs["maybe_arg"]
|
||||
return preprocess_kwargs, {}, {}
|
||||
|
||||
def preprocess(self, audio):
|
||||
audio = audio['inputs']
|
||||
model_n_mels = self.model.feat_kwargs.get("feature_size")
|
||||
features = log_mel_spectrogram(
|
||||
audio,
|
||||
n_mels=model_n_mels if model_n_mels is not None else 80,
|
||||
padding=N_SAMPLES - audio.shape[0],
|
||||
)
|
||||
return {'inputs': features}
|
||||
|
||||
def _forward(self, model_inputs):
|
||||
outputs = self.model.generate_segment_batched(model_inputs['inputs'], self.tokenizer, self.options)
|
||||
return {'text': outputs}
|
||||
|
||||
def postprocess(self, model_outputs):
|
||||
return model_outputs
|
||||
|
||||
def get_iterator(
|
||||
self,
|
||||
inputs,
|
||||
@ -208,7 +164,16 @@ class FasterWhisperPipeline(Pipeline):
|
||||
# print(f2-f1)
|
||||
yield {'inputs': audio[f1:f2]}
|
||||
|
||||
vad_segments = self.vad_model({"waveform": torch.from_numpy(audio).unsqueeze(0), "sample_rate": SAMPLE_RATE})
|
||||
# Pre-process audio and merge chunks as defined by the respective VAD child class
|
||||
# In case vad_model is manually assigned (see 'load_model') follow the functionality of pyannote toolkit
|
||||
if issubclass(type(self.vad_model), Vad):
|
||||
waveform = self.vad_model.preprocess_audio(audio)
|
||||
merge_chunks = self.vad_model.merge_chunks
|
||||
else:
|
||||
waveform = Pyannote.preprocess_audio(audio)
|
||||
merge_chunks = Pyannote.merge_chunks
|
||||
|
||||
vad_segments = self.vad_model({"waveform": waveform, "sample_rate": SAMPLE_RATE})
|
||||
vad_segments = merge_chunks(
|
||||
vad_segments,
|
||||
chunk_size,
|
||||
@ -296,7 +261,8 @@ def load_model(
|
||||
compute_type="float16",
|
||||
asr_options: Optional[dict] = None,
|
||||
language: Optional[str] = None,
|
||||
vad_model: Optional[VoiceActivitySegmentation] = None,
|
||||
vad_model: Optional[Vad]= None,
|
||||
vad_method: Optional[str] = "pyannote",
|
||||
vad_options: Optional[dict] = None,
|
||||
model: Optional[WhisperModel] = None,
|
||||
task="transcribe",
|
||||
@ -309,6 +275,7 @@ def load_model(
|
||||
whisper_arch - The name of the Whisper model to load.
|
||||
device - The device to load the model on.
|
||||
compute_type - The compute type to use for the model.
|
||||
vad_method - The vad method to use. vad_model has higher priority if is not None.
|
||||
options - A dictionary of options to use for the model.
|
||||
language - The language of the model. (use English for now)
|
||||
model - The WhisperModel instance to use.
|
||||
@ -374,6 +341,7 @@ def load_model(
|
||||
default_asr_options = TranscriptionOptions(**default_asr_options)
|
||||
|
||||
default_vad_options = {
|
||||
"chunk_size": 30, # needed by silero since binarization happens before merge_chunks
|
||||
"vad_onset": 0.500,
|
||||
"vad_offset": 0.363
|
||||
}
|
||||
@ -381,10 +349,17 @@ def load_model(
|
||||
if vad_options is not None:
|
||||
default_vad_options.update(vad_options)
|
||||
|
||||
# Note: manually assigned vad_model has higher priority than vad_method!
|
||||
if vad_model is not None:
|
||||
print("Use manually assigned vad_model. vad_method is ignored.")
|
||||
vad_model = vad_model
|
||||
else:
|
||||
vad_model = load_vad_model(torch.device(device), use_auth_token=None, **default_vad_options)
|
||||
if vad_method == "silero":
|
||||
vad_model = Silero(**default_vad_options)
|
||||
elif vad_method == "pyannote":
|
||||
vad_model = Pyannote(torch.device(device), use_auth_token=None, **default_vad_options)
|
||||
else:
|
||||
raise ValueError(f"Invalid vad_method: {vad_method}")
|
||||
|
||||
return FasterWhisperPipeline(
|
||||
model=model,
|
||||
@ -394,4 +369,4 @@ def load_model(
|
||||
language=language,
|
||||
suppress_numerals=suppress_numerals,
|
||||
vad_params=default_vad_options,
|
||||
)
|
||||
)
|
||||
|
@ -79,7 +79,7 @@ def assign_word_speakers(
|
||||
|
||||
|
||||
class Segment:
|
||||
def __init__(self, start, end, speaker=None):
|
||||
def __init__(self, start:int, end:int, speaker:Optional[str]=None):
|
||||
self.start = start
|
||||
self.end = end
|
||||
self.speaker = speaker
|
||||
|
@ -26,6 +26,7 @@ def cli():
|
||||
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||
parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe")
|
||||
parser.add_argument("--model", default="small", help="name of the Whisper model to use")
|
||||
parser.add_argument("--model_cache_only", type=str2bool, default=False, help="If True, will not attempt to download models, instead using cached models from --model_dir")
|
||||
parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
|
||||
parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
|
||||
parser.add_argument("--device_index", default=0, type=int, help="device index to use for FasterWhisper inference")
|
||||
@ -46,6 +47,7 @@ def cli():
|
||||
parser.add_argument("--return_char_alignments", action='store_true', help="Return character-level alignments in the output json file")
|
||||
|
||||
# vad params
|
||||
parser.add_argument("--vad_method", type=str, default="pyannote", choices=["pyannote", "silero"], help="VAD method to be used")
|
||||
parser.add_argument("--vad_onset", type=float, default=0.500, help="Onset threshold for VAD (see pyannote.audio), reduce this if speech is not being detected")
|
||||
parser.add_argument("--vad_offset", type=float, default=0.363, help="Offset threshold for VAD (see pyannote.audio), reduce this if speech is not being detected.")
|
||||
parser.add_argument("--chunk_size", type=int, default=30, help="Chunk size for merging VAD segments. Default is 30, reduce this if the chunk is too long.")
|
||||
@ -89,6 +91,7 @@ def cli():
|
||||
model_name: str = args.pop("model")
|
||||
batch_size: int = args.pop("batch_size")
|
||||
model_dir: str = args.pop("model_dir")
|
||||
model_cache_only: bool = args.pop("model_cache_only")
|
||||
output_dir: str = args.pop("output_dir")
|
||||
output_format: str = args.pop("output_format")
|
||||
device: str = args.pop("device")
|
||||
@ -110,6 +113,7 @@ def cli():
|
||||
return_char_alignments: bool = args.pop("return_char_alignments")
|
||||
|
||||
hf_token: str = args.pop("hf_token")
|
||||
vad_method: str = args.pop("vad_method")
|
||||
vad_onset: float = args.pop("vad_onset")
|
||||
vad_offset: float = args.pop("vad_offset")
|
||||
|
||||
@ -175,7 +179,7 @@ def cli():
|
||||
results = []
|
||||
tmp_results = []
|
||||
# model = load_model(model_name, device=device, download_root=model_dir)
|
||||
model = load_model(model_name, device=device, device_index=device_index, download_root=model_dir, compute_type=compute_type, language=args['language'], asr_options=asr_options, vad_options={"vad_onset": vad_onset, "vad_offset": vad_offset}, task=task, threads=faster_whisper_threads)
|
||||
model = load_model(model_name, device=device, device_index=device_index, download_root=model_dir, compute_type=compute_type, language=args['language'], asr_options=asr_options, vad_method=vad_method, vad_options={"chunk_size":chunk_size, "vad_onset": vad_onset, "vad_offset": vad_offset}, task=task, local_files_only=model_cache_only, threads=faster_whisper_threads)
|
||||
|
||||
for audio_path in args.pop("audio"):
|
||||
audio = load_audio(audio_path)
|
||||
|
@ -1,4 +1,4 @@
|
||||
from typing import TypedDict, Optional, List
|
||||
from typing import TypedDict, Optional, List, Tuple
|
||||
|
||||
|
||||
class SingleWordSegment(TypedDict):
|
||||
@ -30,6 +30,17 @@ class SingleSegment(TypedDict):
|
||||
text: str
|
||||
|
||||
|
||||
class SegmentData(TypedDict):
|
||||
"""
|
||||
Temporary processing data used during alignment.
|
||||
Contains cleaned and preprocessed data for each segment.
|
||||
"""
|
||||
clean_char: List[str] # Cleaned characters that exist in model dictionary
|
||||
clean_cdx: List[int] # Original indices of cleaned characters
|
||||
clean_wdx: List[int] # Indices of words containing valid characters
|
||||
sentence_spans: List[Tuple[int, int]] # Start and end indices of sentences
|
||||
|
||||
|
||||
class SingleAlignedSegment(TypedDict):
|
||||
"""
|
||||
A single segment (up to multiple sentences) of a speech with word alignment.
|
||||
|
@ -106,6 +106,7 @@ LANGUAGES = {
|
||||
"jw": "javanese",
|
||||
"su": "sundanese",
|
||||
"yue": "cantonese",
|
||||
"lv": "latvian",
|
||||
}
|
||||
|
||||
# language code lookup by name, with a few language aliases
|
||||
@ -241,7 +242,7 @@ class SubtitlesWriter(ResultWriter):
|
||||
line_count = 1
|
||||
# the next subtitle to yield (a list of word timings with whitespace)
|
||||
subtitle: list[dict] = []
|
||||
times = []
|
||||
times: list[tuple] = []
|
||||
last = result["segments"][0]["start"]
|
||||
for segment in result["segments"]:
|
||||
for i, original_timing in enumerate(segment["words"]):
|
||||
|
3
whisperx/vads/__init__.py
Normal file
3
whisperx/vads/__init__.py
Normal file
@ -0,0 +1,3 @@
|
||||
from whisperx.vads.pyannote import Pyannote
|
||||
from whisperx.vads.silero import Silero
|
||||
from whisperx.vads.vad import Vad
|
@ -1,51 +1,46 @@
|
||||
import hashlib
|
||||
import os
|
||||
import urllib
|
||||
from typing import Callable, Optional, Text, Union
|
||||
from typing import Callable, Text, Union
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
from pyannote.audio import Model
|
||||
from pyannote.audio.core.io import AudioFile
|
||||
from pyannote.audio.pipelines import VoiceActivityDetection
|
||||
from pyannote.audio.pipelines.utils import PipelineModel
|
||||
from pyannote.core import Annotation, Segment, SlidingWindowFeature
|
||||
from pyannote.core import Annotation, SlidingWindowFeature
|
||||
from pyannote.core import Segment
|
||||
from tqdm import tqdm
|
||||
|
||||
from .diarize import Segment as SegmentX
|
||||
|
||||
# deprecated
|
||||
VAD_SEGMENTATION_URL = "https://whisperx.s3.eu-west-2.amazonaws.com/model_weights/segmentation/0b5b3216d60a2d32fc086b47ea8c67589aaeb26b7e07fcbe620d6d0b83e209ea/pytorch_model.bin"
|
||||
from whisperx.diarize import Segment as SegmentX
|
||||
from whisperx.vads.vad import Vad
|
||||
|
||||
def load_vad_model(device, vad_onset=0.500, vad_offset=0.363, use_auth_token=None, model_fp=None):
|
||||
model_dir = torch.hub._get_torch_home()
|
||||
|
||||
vad_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
main_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
os.makedirs(model_dir, exist_ok = True)
|
||||
if model_fp is None:
|
||||
# Dynamically resolve the path to the model file
|
||||
model_fp = os.path.join(vad_dir, "assets", "pytorch_model.bin")
|
||||
model_fp = os.path.join(main_dir, "assets", "pytorch_model.bin")
|
||||
model_fp = os.path.abspath(model_fp) # Ensure the path is absolute
|
||||
else:
|
||||
model_fp = os.path.abspath(model_fp) # Ensure any provided path is absolute
|
||||
|
||||
|
||||
# Check if the resolved model file exists
|
||||
if not os.path.exists(model_fp):
|
||||
raise FileNotFoundError(f"Model file not found at {model_fp}")
|
||||
|
||||
|
||||
if os.path.exists(model_fp) and not os.path.isfile(model_fp):
|
||||
raise RuntimeError(f"{model_fp} exists and is not a regular file")
|
||||
|
||||
model_bytes = open(model_fp, "rb").read()
|
||||
if hashlib.sha256(model_bytes).hexdigest() != VAD_SEGMENTATION_URL.split('/')[-2]:
|
||||
raise RuntimeError(
|
||||
"Model has been downloaded but the SHA256 checksum does not match. Please retry loading the model."
|
||||
)
|
||||
|
||||
vad_model = Model.from_pretrained(model_fp, use_auth_token=use_auth_token)
|
||||
hyperparameters = {"onset": vad_onset,
|
||||
hyperparameters = {"onset": vad_onset,
|
||||
"offset": vad_offset,
|
||||
"min_duration_on": 0.1,
|
||||
"min_duration_off": 0.1}
|
||||
@ -81,21 +76,21 @@ class Binarize:
|
||||
Gregory Gelly and Jean-Luc Gauvain. "Minimum Word Error Training of
|
||||
RNN-based Voice Activity Detection", InterSpeech 2015.
|
||||
|
||||
Modified by Max Bain to include WhisperX's min-cut operation
|
||||
Modified by Max Bain to include WhisperX's min-cut operation
|
||||
https://arxiv.org/abs/2303.00747
|
||||
|
||||
|
||||
Pyannote-audio
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
onset: float = 0.5,
|
||||
offset: Optional[float] = None,
|
||||
min_duration_on: float = 0.0,
|
||||
min_duration_off: float = 0.0,
|
||||
pad_onset: float = 0.0,
|
||||
pad_offset: float = 0.0,
|
||||
max_duration: float = float('inf')
|
||||
self,
|
||||
onset: float = 0.5,
|
||||
offset: Optional[float] = None,
|
||||
min_duration_on: float = 0.0,
|
||||
min_duration_off: float = 0.0,
|
||||
pad_onset: float = 0.0,
|
||||
pad_offset: float = 0.0,
|
||||
max_duration: float = float('inf')
|
||||
):
|
||||
|
||||
super().__init__()
|
||||
@ -141,7 +136,7 @@ class Binarize:
|
||||
t = start
|
||||
for t, y in zip(timestamps[1:], k_scores[1:]):
|
||||
# currently active
|
||||
if is_active:
|
||||
if is_active:
|
||||
curr_duration = t - start
|
||||
if curr_duration > self.max_duration:
|
||||
search_after = len(curr_scores) // 2
|
||||
@ -151,8 +146,8 @@ class Binarize:
|
||||
region = Segment(start - self.pad_onset, min_score_t + self.pad_offset)
|
||||
active[region, k] = label
|
||||
start = curr_timestamps[min_score_div_idx]
|
||||
curr_scores = curr_scores[min_score_div_idx+1:]
|
||||
curr_timestamps = curr_timestamps[min_score_div_idx+1:]
|
||||
curr_scores = curr_scores[min_score_div_idx + 1:]
|
||||
curr_timestamps = curr_timestamps[min_score_div_idx + 1:]
|
||||
# switching from active to inactive
|
||||
elif y < self.offset:
|
||||
region = Segment(start - self.pad_onset, t + self.pad_offset)
|
||||
@ -193,11 +188,11 @@ class Binarize:
|
||||
|
||||
class VoiceActivitySegmentation(VoiceActivityDetection):
|
||||
def __init__(
|
||||
self,
|
||||
segmentation: PipelineModel = "pyannote/segmentation",
|
||||
fscore: bool = False,
|
||||
use_auth_token: Union[Text, None] = None,
|
||||
**inference_kwargs,
|
||||
self,
|
||||
segmentation: PipelineModel = "pyannote/segmentation",
|
||||
fscore: bool = False,
|
||||
use_auth_token: Union[Text, None] = None,
|
||||
**inference_kwargs,
|
||||
):
|
||||
|
||||
super().__init__(segmentation=segmentation, fscore=fscore, use_auth_token=use_auth_token, **inference_kwargs)
|
||||
@ -236,72 +231,35 @@ class VoiceActivitySegmentation(VoiceActivityDetection):
|
||||
return segmentations
|
||||
|
||||
|
||||
def merge_vad(vad_arr, pad_onset=0.0, pad_offset=0.0, min_duration_off=0.0, min_duration_on=0.0):
|
||||
class Pyannote(Vad):
|
||||
|
||||
active = Annotation()
|
||||
for k, vad_t in enumerate(vad_arr):
|
||||
region = Segment(vad_t[0] - pad_onset, vad_t[1] + pad_offset)
|
||||
active[region, k] = 1
|
||||
def __init__(self, device, use_auth_token=None, model_fp=None, **kwargs):
|
||||
print(">>Performing voice activity detection using Pyannote...")
|
||||
super().__init__(kwargs['vad_onset'])
|
||||
self.vad_pipeline = load_vad_model(device, use_auth_token=use_auth_token, model_fp=model_fp)
|
||||
|
||||
def __call__(self, audio: AudioFile, **kwargs):
|
||||
return self.vad_pipeline(audio)
|
||||
|
||||
if pad_offset > 0.0 or pad_onset > 0.0 or min_duration_off > 0.0:
|
||||
active = active.support(collar=min_duration_off)
|
||||
|
||||
# remove tracks shorter than min_duration_on
|
||||
if min_duration_on > 0:
|
||||
for segment, track in list(active.itertracks()):
|
||||
if segment.duration < min_duration_on:
|
||||
del active[segment, track]
|
||||
|
||||
active = active.for_json()
|
||||
active_segs = pd.DataFrame([x['segment'] for x in active['content']])
|
||||
return active_segs
|
||||
@staticmethod
|
||||
def preprocess_audio(audio):
|
||||
return torch.from_numpy(audio).unsqueeze(0)
|
||||
|
||||
def merge_chunks(
|
||||
segments,
|
||||
chunk_size,
|
||||
onset: float = 0.5,
|
||||
offset: Optional[float] = None,
|
||||
):
|
||||
"""
|
||||
Merge operation described in paper
|
||||
"""
|
||||
curr_end = 0
|
||||
merged_segments = []
|
||||
seg_idxs = []
|
||||
speaker_idxs = []
|
||||
@staticmethod
|
||||
def merge_chunks(segments,
|
||||
chunk_size,
|
||||
onset: float = 0.5,
|
||||
offset: Optional[float] = None,
|
||||
):
|
||||
assert chunk_size > 0
|
||||
binarize = Binarize(max_duration=chunk_size, onset=onset, offset=offset)
|
||||
segments = binarize(segments)
|
||||
segments_list = []
|
||||
for speech_turn in segments.get_timeline():
|
||||
segments_list.append(SegmentX(speech_turn.start, speech_turn.end, "UNKNOWN"))
|
||||
|
||||
assert chunk_size > 0
|
||||
binarize = Binarize(max_duration=chunk_size, onset=onset, offset=offset)
|
||||
segments = binarize(segments)
|
||||
segments_list = []
|
||||
for speech_turn in segments.get_timeline():
|
||||
segments_list.append(SegmentX(speech_turn.start, speech_turn.end, "UNKNOWN"))
|
||||
|
||||
if len(segments_list) == 0:
|
||||
print("No active speech found in audio")
|
||||
return []
|
||||
# assert segments_list, "segments_list is empty."
|
||||
# Make sur the starting point is the start of the segment.
|
||||
curr_start = segments_list[0].start
|
||||
|
||||
for seg in segments_list:
|
||||
if seg.end - curr_start > chunk_size and curr_end-curr_start > 0:
|
||||
merged_segments.append({
|
||||
"start": curr_start,
|
||||
"end": curr_end,
|
||||
"segments": seg_idxs,
|
||||
})
|
||||
curr_start = seg.start
|
||||
seg_idxs = []
|
||||
speaker_idxs = []
|
||||
curr_end = seg.end
|
||||
seg_idxs.append((seg.start, seg.end))
|
||||
speaker_idxs.append(seg.speaker)
|
||||
# add final
|
||||
merged_segments.append({
|
||||
"start": curr_start,
|
||||
"end": curr_end,
|
||||
"segments": seg_idxs,
|
||||
})
|
||||
return merged_segments
|
||||
if len(segments_list) == 0:
|
||||
print("No active speech found in audio")
|
||||
return []
|
||||
assert segments_list, "segments_list is empty."
|
||||
return Vad.merge_chunks(segments_list, chunk_size, onset, offset)
|
66
whisperx/vads/silero.py
Normal file
66
whisperx/vads/silero.py
Normal file
@ -0,0 +1,66 @@
|
||||
from io import IOBase
|
||||
from pathlib import Path
|
||||
from typing import Mapping, Text
|
||||
from typing import Optional
|
||||
from typing import Union
|
||||
|
||||
import torch
|
||||
|
||||
from whisperx.diarize import Segment as SegmentX
|
||||
from whisperx.vads.vad import Vad
|
||||
|
||||
AudioFile = Union[Text, Path, IOBase, Mapping]
|
||||
|
||||
|
||||
class Silero(Vad):
|
||||
# check again default values
|
||||
def __init__(self, **kwargs):
|
||||
print(">>Performing voice activity detection using Silero...")
|
||||
super().__init__(kwargs['vad_onset'])
|
||||
|
||||
self.vad_onset = kwargs['vad_onset']
|
||||
self.chunk_size = kwargs['chunk_size']
|
||||
self.vad_pipeline, vad_utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',
|
||||
model='silero_vad',
|
||||
force_reload=False,
|
||||
onnx=False,
|
||||
trust_repo=True)
|
||||
(self.get_speech_timestamps, _, self.read_audio, _, _) = vad_utils
|
||||
|
||||
def __call__(self, audio: AudioFile, **kwargs):
|
||||
"""use silero to get segments of speech"""
|
||||
# Only accept 16000 Hz for now.
|
||||
# Note: Silero models support both 8000 and 16000 Hz. Although other values are not directly supported,
|
||||
# multiples of 16000 (e.g. 32000 or 48000) are cast to 16000 inside of the JIT model!
|
||||
sample_rate = audio["sample_rate"]
|
||||
if sample_rate != 16000:
|
||||
raise ValueError("Only 16000Hz sample rate is allowed")
|
||||
|
||||
timestamps = self.get_speech_timestamps(audio["waveform"],
|
||||
model=self.vad_pipeline,
|
||||
sampling_rate=sample_rate,
|
||||
max_speech_duration_s=self.chunk_size,
|
||||
threshold=self.vad_onset
|
||||
# min_silence_duration_ms = self.min_duration_off/1000
|
||||
# min_speech_duration_ms = self.min_duration_on/1000
|
||||
# ...
|
||||
# See silero documentation for full option list
|
||||
)
|
||||
return [SegmentX(i['start'] / sample_rate, i['end'] / sample_rate, "UNKNOWN") for i in timestamps]
|
||||
|
||||
@staticmethod
|
||||
def preprocess_audio(audio):
|
||||
return audio
|
||||
|
||||
@staticmethod
|
||||
def merge_chunks(segments_list,
|
||||
chunk_size,
|
||||
onset: float = 0.5,
|
||||
offset: Optional[float] = None,
|
||||
):
|
||||
assert chunk_size > 0
|
||||
if len(segments_list) == 0:
|
||||
print("No active speech found in audio")
|
||||
return []
|
||||
assert segments_list, "segments_list is empty."
|
||||
return Vad.merge_chunks(segments_list, chunk_size, onset, offset)
|
74
whisperx/vads/vad.py
Normal file
74
whisperx/vads/vad.py
Normal file
@ -0,0 +1,74 @@
|
||||
from typing import Optional
|
||||
|
||||
import pandas as pd
|
||||
from pyannote.core import Annotation, Segment
|
||||
|
||||
|
||||
class Vad:
|
||||
def __init__(self, vad_onset):
|
||||
if not (0 < vad_onset < 1):
|
||||
raise ValueError(
|
||||
"vad_onset is a decimal value between 0 and 1."
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def preprocess_audio(audio):
|
||||
pass
|
||||
|
||||
# keep merge_chunks as static so it can be also used by manually assigned vad_model (see 'load_model')
|
||||
@staticmethod
|
||||
def merge_chunks(segments,
|
||||
chunk_size,
|
||||
onset: float,
|
||||
offset: Optional[float]):
|
||||
"""
|
||||
Merge operation described in paper
|
||||
"""
|
||||
curr_end = 0
|
||||
merged_segments = []
|
||||
seg_idxs: list[tuple]= []
|
||||
speaker_idxs: list[Optional[str]] = []
|
||||
|
||||
curr_start = segments[0].start
|
||||
for seg in segments:
|
||||
if seg.end - curr_start > chunk_size and curr_end - curr_start > 0:
|
||||
merged_segments.append({
|
||||
"start": curr_start,
|
||||
"end": curr_end,
|
||||
"segments": seg_idxs,
|
||||
})
|
||||
curr_start = seg.start
|
||||
seg_idxs = []
|
||||
speaker_idxs = []
|
||||
curr_end = seg.end
|
||||
seg_idxs.append((seg.start, seg.end))
|
||||
speaker_idxs.append(seg.speaker)
|
||||
# add final
|
||||
merged_segments.append({
|
||||
"start": curr_start,
|
||||
"end": curr_end,
|
||||
"segments": seg_idxs,
|
||||
})
|
||||
|
||||
return merged_segments
|
||||
|
||||
# Unused function
|
||||
@staticmethod
|
||||
def merge_vad(vad_arr, pad_onset=0.0, pad_offset=0.0, min_duration_off=0.0, min_duration_on=0.0):
|
||||
active = Annotation()
|
||||
for k, vad_t in enumerate(vad_arr):
|
||||
region = Segment(vad_t[0] - pad_onset, vad_t[1] + pad_offset)
|
||||
active[region, k] = 1
|
||||
|
||||
if pad_offset > 0.0 or pad_onset > 0.0 or min_duration_off > 0.0:
|
||||
active = active.support(collar=min_duration_off)
|
||||
|
||||
# remove tracks shorter than min_duration_on
|
||||
if min_duration_on > 0:
|
||||
for segment, track in list(active.itertracks()):
|
||||
if segment.duration < min_duration_on:
|
||||
del active[segment, track]
|
||||
|
||||
active = active.for_json()
|
||||
active_segs = pd.DataFrame([x['segment'] for x in active['content']])
|
||||
return active_segs
|
Reference in New Issue
Block a user