2 Commits

15 changed files with 270 additions and 517 deletions

@ -10,6 +10,8 @@ jobs:
steps:
- name: Checkout
uses: actions/checkout@v4
with:
ref: ${{ github.ref_name }}
- name: Set up Python
uses: actions/setup-python@v5

@ -129,7 +129,7 @@ To **enable Speaker Diarization**, include your Hugging Face access token (read)
Run whisper on example segment (using default params, whisper small) add `--highlight_words True` to visualise word timings in the .srt file.
whisperx path/to/audio.wav
whisperx examples/sample01.wav
Result using *WhisperX* with forced alignment to wav2vec2.0 large:
@ -143,27 +143,27 @@ https://user-images.githubusercontent.com/36994049/207743923-b4f0d537-29ae-4be2-
For increased timestamp accuracy, at the cost of higher gpu mem, use bigger models (bigger alignment model not found to be that helpful, see paper) e.g.
whisperx path/to/audio.wav --model large-v2 --align_model WAV2VEC2_ASR_LARGE_LV60K_960H --batch_size 4
whisperx examples/sample01.wav --model large-v2 --align_model WAV2VEC2_ASR_LARGE_LV60K_960H --batch_size 4
To label the transcript with speaker ID's (set number of speakers if known e.g. `--min_speakers 2` `--max_speakers 2`):
whisperx path/to/audio.wav --model large-v2 --diarize --highlight_words True
whisperx examples/sample01.wav --model large-v2 --diarize --highlight_words True
To run on CPU instead of GPU (and for running on Mac OS X):
whisperx path/to/audio.wav --compute_type int8
whisperx examples/sample01.wav --compute_type int8
### Other languages
The phoneme ASR alignment model is *language-specific*, for tested languages these models are [automatically picked from torchaudio pipelines or huggingface](https://github.com/m-bain/whisperX/blob/f2da2f858e99e4211fe4f64b5f2938b007827e17/whisperx/alignment.py#L24-L58).
The phoneme ASR alignment model is *language-specific*, for tested languages these models are [automatically picked from torchaudio pipelines or huggingface](https://github.com/m-bain/whisperX/blob/e909f2f766b23b2000f2d95df41f9b844ac53e49/whisperx/transcribe.py#L22).
Just pass in the `--language` code, and use the whisper `--model large`.
Currently default models provided for `{en, fr, de, es, it}` via torchaudio pipelines and many other languages via Hugging Face. Please find the list of currently supported languages under `DEFAULT_ALIGN_MODELS_HF` on [alignment.py](https://github.com/m-bain/whisperX/blob/main/whisperx/alignment.py). If the detected language is not in this list, you need to find a phoneme-based ASR model from [huggingface model hub](https://huggingface.co/models) and test it on your data.
Currently default models provided for `{en, fr, de, es, it, ja, zh, nl, uk, pt}`. If the detected language is not in this list, you need to find a phoneme-based ASR model from [huggingface model hub](https://huggingface.co/models) and test it on your data.
#### E.g. German
whisperx --model large-v2 --language de path/to/audio.wav
whisperx --model large-v2 --language de examples/sample_de_01.wav
https://user-images.githubusercontent.com/36994049/208298811-e36002ba-3698-4731-97d4-0aebd07e0eb3.mov
@ -278,7 +278,7 @@ Bug finding and pull requests are also highly appreciated to keep this project g
* [ ] Add benchmarking code (TEDLIUM for spd/WER & word segmentation)
* [x] Allow silero-vad as alternative VAD option
* [ ] Allow silero-vad as alternative VAD option
* [ ] Improve diarization (word level). *Harder than first thought...*
@ -300,9 +300,7 @@ Borrows important alignment code from [PyTorch tutorial on forced alignment](htt
And uses the wonderful pyannote VAD / Diarization https://github.com/pyannote/pyannote-audio
Valuable VAD & Diarization Models from:
- [pyannote audio][https://github.com/pyannote/pyannote-audio]
- [silero vad][https://github.com/snakers4/silero-vad]
Valuable VAD & Diarization Models from [pyannote audio](https://github.com/pyannote/pyannote-audio)
Great backend from [faster-whisper](https://github.com/guillaumekln/faster-whisper) and [CTranslate2](https://github.com/OpenNMT/CTranslate2)

@ -1,7 +1,7 @@
torch>=2
torchaudio>=2
faster-whisper==1.1.0
ctranslate2<4.5.0
ctranslate2>=4.5.0
transformers
pandas
setuptools>=65

@ -9,7 +9,7 @@ with open("README.md", "r", encoding="utf-8") as f:
setup(
name="whisperx",
py_modules=["whisperx"],
version="3.3.1",
version="3.3.2",
description="Time-Accurate Automatic Speech Recognition using Whisper.",
long_description=long_description,
long_description_content_type="text/markdown",

@ -1,4 +1,4 @@
from .transcribe import load_model
from .alignment import load_align_model, align
from .audio import load_audio
from .diarize import assign_word_speakers, DiarizationPipeline
from .asr import load_model
from .diarize import assign_word_speakers, DiarizationPipeline

@ -1,9 +1,7 @@
"""
""""
Forced Alignment with Whisper
C. Max Bain
"""
import math
from dataclasses import dataclass
from typing import Iterable, Optional, Union, List
@ -15,13 +13,8 @@ from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
from .audio import SAMPLE_RATE, load_audio
from .utils import interpolate_nans
from .types import (
AlignedTranscriptionResult,
SingleSegment,
SingleAlignedSegment,
SingleWordSegment,
SegmentData,
)
from .types import AlignedTranscriptionResult, SingleSegment, SingleAlignedSegment, SingleWordSegment
import nltk
from nltk.tokenize.punkt import PunktSentenceTokenizer, PunktParameters
PUNKT_ABBREVIATIONS = ['dr', 'vs', 'mr', 'mrs', 'prof']
@ -69,8 +62,6 @@ DEFAULT_ALIGN_MODELS_HF = {
"eu": "stefan-it/wav2vec2-large-xlsr-53-basque",
"gl": "ifrz/wav2vec2-large-xlsr-galician",
"ka": "xsway/wav2vec2-large-xlsr-georgian",
"lv": "jimregan/wav2vec2-large-xlsr-latvian-cv",
"tl": "Khalsuu/filipino-wav2vec2-l-xls-r-300m-official",
}
@ -140,8 +131,6 @@ def align(
# 1. Preprocess to keep only characters in dictionary
total_segments = len(transcript)
# Store temporary processing values
segment_data: dict[int, SegmentData] = {}
for sdx, segment in enumerate(transcript):
# strip spaces at beginning / end, but keep track of the amount.
if print_progress:
@ -174,17 +163,10 @@ def align(
elif char_ in model_dictionary.keys():
clean_char.append(char_)
clean_cdx.append(cdx)
else:
# add placeholder
clean_char.append('*')
clean_cdx.append(cdx)
clean_wdx = []
for wdx, wrd in enumerate(per_word):
if any([c in model_dictionary.keys() for c in wrd.lower()]):
clean_wdx.append(wdx)
else:
# index for placeholder
if any([c in model_dictionary.keys() for c in wrd]):
clean_wdx.append(wdx)
@ -193,13 +175,11 @@ def align(
sentence_splitter = PunktSentenceTokenizer(punkt_param)
sentence_spans = list(sentence_splitter.span_tokenize(text))
segment_data[sdx] = {
"clean_char": clean_char,
"clean_cdx": clean_cdx,
"clean_wdx": clean_wdx,
"sentence_spans": sentence_spans
}
segment["clean_char"] = clean_char
segment["clean_cdx"] = clean_cdx
segment["clean_wdx"] = clean_wdx
segment["sentence_spans"] = sentence_spans
aligned_segments: List[SingleAlignedSegment] = []
# 2. Get prediction matrix from alignment model & align
@ -214,14 +194,13 @@ def align(
"end": t2,
"text": text,
"words": [],
"chars": None,
}
if return_char_alignments:
aligned_seg["chars"] = []
# check we can align
if len(segment_data[sdx]["clean_char"]) == 0:
if len(segment["clean_char"]) == 0:
print(f'Failed to align segment ("{segment["text"]}"): no characters in this segment found in model dictionary, resorting to original...')
aligned_segments.append(aligned_seg)
continue
@ -231,8 +210,8 @@ def align(
aligned_segments.append(aligned_seg)
continue
text_clean = "".join(segment_data[sdx]["clean_char"])
tokens = [model_dictionary.get(c, -1) for c in text_clean]
text_clean = "".join(segment["clean_char"])
tokens = [model_dictionary[c] for c in text_clean]
f1 = int(t1 * SAMPLE_RATE)
f2 = int(t2 * SAMPLE_RATE)
@ -265,8 +244,7 @@ def align(
blank_id = code
trellis = get_trellis(emission, tokens, blank_id)
# path = backtrack(trellis, emission, tokens, blank_id)
path = backtrack_beam(trellis, emission, tokens, blank_id, beam_width=2)
path = backtrack(trellis, emission, tokens, blank_id)
if path is None:
print(f'Failed to align segment ("{segment["text"]}"): backtrack failed, resorting to original...')
@ -275,7 +253,7 @@ def align(
char_segments = merge_repeats(path, text_clean)
duration = t2 - t1
duration = t2 -t1
ratio = duration * waveform_segment.size(0) / (trellis.size(0) - 1)
# assign timestamps to aligned characters
@ -283,8 +261,8 @@ def align(
word_idx = 0
for cdx, char in enumerate(text):
start, end, score = None, None, None
if cdx in segment_data[sdx]["clean_cdx"]:
char_seg = char_segments[segment_data[sdx]["clean_cdx"].index(cdx)]
if cdx in segment["clean_cdx"]:
char_seg = char_segments[segment["clean_cdx"].index(cdx)]
start = round(char_seg.start * ratio + t1, 3)
end = round(char_seg.end * ratio + t1, 3)
score = round(char_seg.score, 3)
@ -310,10 +288,10 @@ def align(
aligned_subsegments = []
# assign sentence_idx to each character index
char_segments_arr["sentence-idx"] = None
for sdx2, (sstart, send) in enumerate(segment_data[sdx]["sentence_spans"]):
for sdx, (sstart, send) in enumerate(segment["sentence_spans"]):
curr_chars = char_segments_arr.loc[(char_segments_arr.index >= sstart) & (char_segments_arr.index <= send)]
char_segments_arr.loc[(char_segments_arr.index >= sstart) & (char_segments_arr.index <= send), "sentence-idx"] = sdx2
char_segments_arr.loc[(char_segments_arr.index >= sstart) & (char_segments_arr.index <= send), "sentence-idx"] = sdx
sentence_text = text[sstart:send]
sentence_start = curr_chars["start"].min()
end_chars = curr_chars[curr_chars["char"] != ' ']
@ -382,202 +360,69 @@ def align(
"""
source: https://pytorch.org/tutorials/intermediate/forced_alignment_with_torchaudio_tutorial.html
"""
def get_trellis(emission, tokens, blank_id=0):
num_frame = emission.size(0)
num_tokens = len(tokens)
trellis = torch.zeros((num_frame, num_tokens))
trellis[1:, 0] = torch.cumsum(emission[1:, blank_id], 0)
trellis[0, 1:] = -float("inf")
trellis[-num_tokens + 1:, 0] = float("inf")
# Trellis has extra dimensions for both time axis and tokens.
# The extra dim for tokens represents <SoS> (start-of-sentence)
# The extra dim for time axis is for simplification of the code.
trellis = torch.empty((num_frame + 1, num_tokens + 1))
trellis[0, 0] = 0
trellis[1:, 0] = torch.cumsum(emission[:, 0], 0)
trellis[0, -num_tokens:] = -float("inf")
trellis[-num_tokens:, 0] = float("inf")
for t in range(num_frame - 1):
for t in range(num_frame):
trellis[t + 1, 1:] = torch.maximum(
# Score for staying at the same token
trellis[t, 1:] + emission[t, blank_id],
# Score for changing to the next token
# trellis[t, :-1] + emission[t, tokens[1:]],
trellis[t, :-1] + get_wildcard_emission(emission[t], tokens[1:], blank_id),
trellis[t, :-1] + emission[t, tokens],
)
return trellis
def get_wildcard_emission(frame_emission, tokens, blank_id):
"""Processing token emission scores containing wildcards (vectorized version)
Args:
frame_emission: Emission probability vector for the current frame
tokens: List of token indices
blank_id: ID of the blank token
Returns:
tensor: Maximum probability score for each token position
"""
assert 0 <= blank_id < len(frame_emission)
# Convert tokens to a tensor if they are not already
tokens = torch.tensor(tokens) if not isinstance(tokens, torch.Tensor) else tokens
# Create a mask to identify wildcard positions
wildcard_mask = (tokens == -1)
# Get scores for non-wildcard positions
regular_scores = frame_emission[tokens.clamp(min=0)] # clamp to avoid -1 index
# Create a mask and compute the maximum value without modifying frame_emission
max_valid_score = frame_emission.clone() # Create a copy
max_valid_score[blank_id] = float('-inf') # Modify the copy to exclude the blank token
max_valid_score = max_valid_score.max()
# Use where operation to combine results
result = torch.where(wildcard_mask, max_valid_score, regular_scores)
return result
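The wildcard scoring in `get_wildcard_emission` can be illustrated without torch. The following is a minimal sketch in plain Python, assuming (as in the hunk above) that a token index of `-1` marks a wildcard and that scores are log-probabilities; the name `wildcard_scores` is hypothetical and not part of the code base.

```python
from typing import List


def wildcard_scores(frame_emission: List[float], tokens: List[int], blank_id: int) -> List[float]:
    """Score each token for a single frame; a token of -1 is treated as a wildcard."""
    assert 0 <= blank_id < len(frame_emission)
    # A wildcard receives the best score over every label except blank.
    best_non_blank = max(
        score for idx, score in enumerate(frame_emission) if idx != blank_id
    )
    # Regular tokens simply look up their own emission score.
    return [best_non_blank if tok == -1 else frame_emission[tok] for tok in tokens]


frame = [-0.1, -2.0, -3.0, -1.5]  # illustrative log-probabilities; index 0 is blank
scores = wildcard_scores(frame, [1, -1, 3], blank_id=0)
```

Here the wildcard in the middle position picks up the best non-blank score in the frame, so a character missing from the model dictionary can still occupy time in the alignment.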
@dataclass
class Point:
token_index: int
time_index: int
score: float
def backtrack(trellis, emission, tokens, blank_id=0):
t, j = trellis.size(0) - 1, trellis.size(1) - 1
path = [Point(j, t, emission[t, blank_id].exp().item())]
while j > 0:
# Should not happen but just in case
assert t > 0
# Note:
# j and t are indices for trellis, which has extra dimensions
# for time and tokens at the beginning.
# When referring to time frame index `T` in trellis,
# the corresponding index in emission is `T-1`.
# Similarly, when referring to token index `J` in trellis,
# the corresponding index in transcript is `J-1`.
j = trellis.size(1) - 1
t_start = torch.argmax(trellis[:, j]).item()
path = []
for t in range(t_start, 0, -1):
# 1. Figure out if the current position was stay or change
# Frame-wise score of stay vs change
p_stay = emission[t - 1, blank_id]
# p_change = emission[t - 1, tokens[j]]
p_change = get_wildcard_emission(emission[t - 1], [tokens[j]], blank_id)[0]
# Note (again):
# `emission[J-1]` is the emission at time frame `J` of trellis dimension.
# Score for token staying the same from time frame J-1 to T.
stayed = trellis[t - 1, j] + emission[t - 1, blank_id]
# Score for token changing from C-1 at T-1 to J at T.
changed = trellis[t - 1, j - 1] + emission[t - 1, tokens[j - 1]]
# Context-aware score for stay vs change
stayed = trellis[t - 1, j] + p_stay
changed = trellis[t - 1, j - 1] + p_change
# 2. Store the path with frame-wise probability.
prob = emission[t - 1, tokens[j - 1] if changed > stayed else 0].exp().item()
# Return token index and time index in non-trellis coordinate.
path.append(Point(j - 1, t - 1, prob))
# Update position
t -= 1
# 3. Update the token
if changed > stayed:
j -= 1
# Store the path with frame-wise probability.
prob = (p_change if changed > stayed else p_stay).exp().item()
path.append(Point(j, t, prob))
# Now j == 0, which means, it reached the SoS.
# Fill up the rest for the sake of visualization
while t > 0:
prob = emission[t - 1, blank_id].exp().item()
path.append(Point(j, t - 1, prob))
t -= 1
return path[::-1]
@dataclass
class Path:
points: List[Point]
score: float
@dataclass
class BeamState:
"""State in beam search."""
token_index: int # Current token position
time_index: int # Current time step
score: float # Cumulative score
path: List[Point] # Path history
def backtrack_beam(trellis, emission, tokens, blank_id=0, beam_width=5):
"""Standard CTC beam search backtracking implementation.
Args:
trellis (torch.Tensor): The trellis (or lattice) of shape (T, N), where T is the number of time steps
and N is the number of tokens (including the blank token).
emission (torch.Tensor): The emission probabilities of shape (T, N).
tokens (List[int]): List of token indices (excluding the blank token).
blank_id (int, optional): The ID of the blank token. Defaults to 0.
beam_width (int, optional): The number of top paths to keep during beam search. Defaults to 5.
Returns:
List[Point]: the best path
"""
T, J = trellis.size(0) - 1, trellis.size(1) - 1
init_state = BeamState(
token_index=J,
time_index=T,
score=trellis[T, J],
path=[Point(J, T, emission[T, blank_id].exp().item())]
)
beams = [init_state]
while beams and beams[0].token_index > 0:
next_beams = []
for beam in beams:
t, j = beam.time_index, beam.token_index
if t <= 0:
continue
p_stay = emission[t - 1, blank_id]
p_change = get_wildcard_emission(emission[t - 1], [tokens[j]], blank_id)[0]
stay_score = trellis[t - 1, j]
change_score = trellis[t - 1, j - 1] if j > 0 else float('-inf')
# Stay
if not math.isinf(stay_score):
new_path = beam.path.copy()
new_path.append(Point(j, t - 1, p_stay.exp().item()))
next_beams.append(BeamState(
token_index=j,
time_index=t - 1,
score=stay_score,
path=new_path
))
# Change
if j > 0 and not math.isinf(change_score):
new_path = beam.path.copy()
new_path.append(Point(j - 1, t - 1, p_change.exp().item()))
next_beams.append(BeamState(
token_index=j - 1,
time_index=t - 1,
score=change_score,
path=new_path
))
# sort by score
beams = sorted(next_beams, key=lambda x: x.score, reverse=True)[:beam_width]
if not beams:
break
if not beams:
if j == 0:
break
else:
# failed
return None
best_beam = beams[0]
t = best_beam.time_index
j = best_beam.token_index
while t > 0:
prob = emission[t - 1, blank_id].exp().item()
best_beam.path.append(Point(j, t - 1, prob))
t -= 1
return best_beam.path[::-1]
return path[::-1]
# Merge the labels
@dataclass
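For readers comparing the two backtracking variants in this hunk, the trellis-plus-backtrack idea can be sketched in plain Python. This is a simplified illustration and not the project's implementation: it uses nested lists instead of tensors, returns `(token_index, frame_index)` pairs instead of `Point` objects, and omits the `float("inf")` guard rows. The names `build_trellis` and `backtrack_path` are hypothetical.

```python
import math
from typing import List, Optional, Tuple

NEG_INF = -math.inf


def build_trellis(emission: List[List[float]], tokens: List[int], blank_id: int = 0) -> List[List[float]]:
    """Trellis with one extra row (time) and one extra column (start of sentence)."""
    num_frames, num_tokens = len(emission), len(tokens)
    trellis = [[NEG_INF] * (num_tokens + 1) for _ in range(num_frames + 1)]
    trellis[0][0] = 0.0
    for t in range(num_frames):
        # Column 0: nothing emitted yet, only blanks accumulated.
        trellis[t + 1][0] = trellis[t][0] + emission[t][blank_id]
        for j in range(1, num_tokens + 1):
            stay = trellis[t][j] + emission[t][blank_id]
            change = trellis[t][j - 1] + emission[t][tokens[j - 1]]
            trellis[t + 1][j] = max(stay, change)
    return trellis


def backtrack_path(trellis, emission, tokens, blank_id: int = 0) -> Optional[List[Tuple[int, int]]]:
    """Walk back from the best-scoring frame of the last token column."""
    j = len(tokens)
    t_start = max(range(len(trellis)), key=lambda t: trellis[t][j])
    path = []
    for t in range(t_start, 0, -1):
        stayed = trellis[t - 1][j] + emission[t - 1][blank_id]
        changed = trellis[t - 1][j - 1] + emission[t - 1][tokens[j - 1]]
        path.append((j - 1, t - 1))  # non-trellis coordinates
        if changed > stayed:
            j -= 1
            if j == 0:
                break
    else:
        return None  # ran out of frames before reaching the first token
    return path[::-1]


hi, lo = math.log(0.8), math.log(0.1)
# Four frames over labels [blank, 'a', 'b']: 'a', blank, 'b', blank.
emission = [[lo, hi, lo], [hi, lo, lo], [lo, lo, hi], [hi, lo, lo]]
tokens = [1, 2]
path = backtrack_path(build_trellis(emission, tokens), emission, tokens)
```

In this toy input the first token covers frames 0 and 1 and the second token lands on frame 2, so `path` is `[(0, 0), (0, 1), (1, 2)]`. A beam variant would keep several candidate `(t, j)` states per step instead of the single greedy choice shown here.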

@ -1,5 +1,6 @@
import os
from typing import List, Optional, Union
import warnings
from typing import List, NamedTuple, Optional, Union
from dataclasses import replace
import ctranslate2
@ -13,12 +14,10 @@ from transformers.pipelines.pt_utils import PipelineIterator
from .audio import N_SAMPLES, SAMPLE_RATE, load_audio, log_mel_spectrogram
from .types import SingleSegment, TranscriptionResult
from .vads import Vad, Silero, Pyannote
from .vad import VoiceActivitySegmentation, load_vad_model, merge_chunks
def find_numeral_symbol_tokens(tokenizer):
"""
Finds tokens that represent numeral and symbols.
"""
numeral_symbol_tokens = []
for i in range(tokenizer.eot):
token = tokenizer.decode([i]).removeprefix(" ")
@ -28,10 +27,10 @@ def find_numeral_symbol_tokens(tokenizer):
return numeral_symbol_tokens
class WhisperModel(faster_whisper.WhisperModel):
"""
Wrapper around faster-whisper's WhisperModel to enable batched inference.
Currently, it only supports non-timestamp mode and a fixed prompt for all samples in a batch.
"""
'''
FasterWhisperModel provides batched inference for faster-whisper.
Currently only works in non-timestamp mode and fixed prompt for all samples in batch.
'''
def generate_segment_batched(
self,
@ -40,28 +39,13 @@ class WhisperModel(faster_whisper.WhisperModel):
options: TranscriptionOptions,
encoder_output=None,
):
"""
Generates transcription for a batch of audio segments.
Args:
features: The input audio features.
tokenizer: The tokenizer used to decode the generated tokens.
options: Transcription options.
encoder_output: Output from the encoder model.
Returns:
The decoded transcription text.
"""
batch_size = features.shape[0]
# Initialize tokens and prompt for the generation process.
all_tokens = []
prompt_reset_since = 0
# Check if an initial prompt is provided and handle it.
if options.initial_prompt is not None:
initial_prompt = " " + options.initial_prompt.strip()
initial_prompt_tokens = tokenizer.encode(initial_prompt)
all_tokens.extend(initial_prompt_tokens)
# Prepare the prompt for the current batch.
previous_tokens = all_tokens[prompt_reset_since:]
prompt = self.get_prompt(
tokenizer,
@ -69,58 +53,118 @@ class WhisperModel(faster_whisper.WhisperModel):
without_timestamps=options.without_timestamps,
prefix=options.prefix,
)
# Encode the features to obtain the encoder output.
encoder_output = self.encode(features)
# Determine the maximum initial timestamp index based on the options.
max_initial_timestamp_index = int(
round(options.max_initial_timestamp / self.time_precision)
)
# Generate the transcription result for the batch.
result = self.model.generate(
encoder_output,
[prompt] * batch_size,
beam_size=options.beam_size,
patience=options.patience,
length_penalty=options.length_penalty,
max_length=self.max_length,
suppress_blank=options.suppress_blank,
suppress_tokens=options.suppress_tokens,
)
encoder_output,
[prompt] * batch_size,
beam_size=options.beam_size,
patience=options.patience,
length_penalty=options.length_penalty,
max_length=self.max_length,
suppress_blank=options.suppress_blank,
suppress_tokens=options.suppress_tokens,
)
# Extract the token sequences from the result.
tokens_batch = [x.sequences_ids[0] for x in result]
# Define an inner function to decode the tokens for each batch.
def decode_batch(tokens: List[List[int]]) -> str:
res = []
for tk in tokens:
res.append([token for token in tk if token < tokenizer.eot])
# text_tokens = [token for token in tokens if token < self.eot]
return tokenizer.tokenizer.decode_batch(res)
# Decode the tokens to get the transcription text.
text = decode_batch(tokens_batch)
return text
def encode(self, features: np.ndarray) -> ctranslate2.StorageView:
"""
Encodes the audio features using the CTranslate2 storage.
When the model is running on multiple GPUs, the encoder output should be moved
to the CPU since we don't know which GPU will handle the next job.
"""
# When the model is running on multiple GPUs, the encoder output should be moved to the CPU.
# When the model is running on multiple GPUs, the encoder output should be moved
# to the CPU since we don't know which GPU will handle the next job.
to_cpu = self.model.device == "cuda" and len(self.model.device_index) > 1
# If the batch size is 1, unsqueeze the features to ensure it is a 3D array.
# unsqueeze if batch size = 1
if len(features.shape) == 2:
features = np.expand_dims(features, 0)
features = get_ctranslate2_storage(features)
# call the model
return self.model.encode(features, to_cpu=to_cpu)
class FasterWhisperPipeline(Pipeline):
"""
Huggingface Pipeline wrapper for FasterWhisperModel.
"""
# TODO:
# - add support for timestamp mode
# - add support for custom inference kwargs
def __init__(
self,
model: WhisperModel,
vad: VoiceActivitySegmentation,
vad_params: dict,
options: TranscriptionOptions,
tokenizer: Optional[Tokenizer] = None,
device: Union[int, str, "torch.device"] = -1,
framework="pt",
language: Optional[str] = None,
suppress_numerals: bool = False,
**kwargs,
):
self.model = model
self.tokenizer = tokenizer
self.options = options
self.preset_language = language
self.suppress_numerals = suppress_numerals
self._batch_size = kwargs.pop("batch_size", None)
self._num_workers = 1
self._preprocess_params, self._forward_params, self._postprocess_params = self._sanitize_parameters(**kwargs)
self.call_count = 0
self.framework = framework
if self.framework == "pt":
if isinstance(device, torch.device):
self.device = device
elif isinstance(device, str):
self.device = torch.device(device)
elif device < 0:
self.device = torch.device("cpu")
else:
self.device = torch.device(f"cuda:{device}")
else:
self.device = device
super(Pipeline, self).__init__()
self.vad_model = vad
self._vad_params = vad_params
def _sanitize_parameters(self, **kwargs):
preprocess_kwargs = {}
if "tokenizer" in kwargs:
preprocess_kwargs["maybe_arg"] = kwargs["maybe_arg"]
return preprocess_kwargs, {}, {}
def preprocess(self, audio):
audio = audio['inputs']
model_n_mels = self.model.feat_kwargs.get("feature_size")
features = log_mel_spectrogram(
audio,
n_mels=model_n_mels if model_n_mels is not None else 80,
padding=N_SAMPLES - audio.shape[0],
)
return {'inputs': features}
def _forward(self, model_inputs):
outputs = self.model.generate_segment_batched(model_inputs['inputs'], self.tokenizer, self.options)
return {'text': outputs}
def postprocess(self, model_outputs):
return model_outputs
def get_iterator(
self,
inputs,
@ -164,16 +208,7 @@ class WhisperModel(faster_whisper.WhisperModel):
# print(f2-f1)
yield {'inputs': audio[f1:f2]}
# Pre-process audio and merge chunks as defined by the respective VAD child class
# In case vad_model is manually assigned (see 'load_model') follow the functionality of pyannote toolkit
if issubclass(type(self.vad_model), Vad):
waveform = self.vad_model.preprocess_audio(audio)
merge_chunks = self.vad_model.merge_chunks
else:
waveform = Pyannote.preprocess_audio(audio)
merge_chunks = Pyannote.merge_chunks
vad_segments = self.vad_model({"waveform": waveform, "sample_rate": SAMPLE_RATE})
vad_segments = self.vad_model({"waveform": torch.from_numpy(audio).unsqueeze(0), "sample_rate": SAMPLE_RATE})
vad_segments = merge_chunks(
vad_segments,
chunk_size,
@ -261,8 +296,7 @@ def load_model(
compute_type="float16",
asr_options: Optional[dict] = None,
language: Optional[str] = None,
vad_model: Optional[Vad]= None,
vad_method: Optional[str] = "pyannote",
vad_model: Optional[VoiceActivitySegmentation] = None,
vad_options: Optional[dict] = None,
model: Optional[WhisperModel] = None,
task="transcribe",
@ -275,7 +309,6 @@ def load_model(
whisper_arch - The name of the Whisper model to load.
device - The device to load the model on.
compute_type - The compute type to use for the model.
vad_method - The vad method to use. vad_model has higher priority if is not None.
options - A dictionary of options to use for the model.
language - The language of the model. (use English for now)
model - The WhisperModel instance to use.
@ -341,7 +374,6 @@ def load_model(
default_asr_options = TranscriptionOptions(**default_asr_options)
default_vad_options = {
"chunk_size": 30, # needed by silero since binarization happens before merge_chunks
"vad_onset": 0.500,
"vad_offset": 0.363
}
@ -349,17 +381,10 @@ def load_model(
if vad_options is not None:
default_vad_options.update(vad_options)
# Note: manually assigned vad_model has higher priority than vad_method!
if vad_model is not None:
print("Use manually assigned vad_model. vad_method is ignored.")
vad_model = vad_model
else:
if vad_method == "silero":
vad_model = Silero(**default_vad_options)
elif vad_method == "pyannote":
vad_model = Pyannote(torch.device(device), use_auth_token=None, **default_vad_options)
else:
raise ValueError(f"Invalid vad_method: {vad_method}")
vad_model = load_vad_model(torch.device(device), use_auth_token=None, **default_vad_options)
return FasterWhisperPipeline(
model=model,
@ -369,4 +394,4 @@ def load_model(
language=language,
suppress_numerals=suppress_numerals,
vad_params=default_vad_options,
)
)
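The branch removed from `load_model` gave an explicitly supplied `vad_model` priority over `vad_method` and rejected unknown method names. That selection rule can be sketched in isolation as follows; `StubVad`, `VAD_FACTORIES`, and `resolve_vad` are hypothetical stand-ins for illustration, not the real `Silero` or `Pyannote` classes.

```python
from typing import Callable, Dict, Optional


class StubVad:
    """Hypothetical stand-in for a VAD backend; it only records its configuration."""

    def __init__(self, name: str, **options):
        self.name = name
        self.options = options


# Hypothetical registry mapping a method name to a backend constructor.
VAD_FACTORIES: Dict[str, Callable[..., StubVad]] = {
    "pyannote": lambda **opts: StubVad("pyannote", **opts),
    "silero": lambda **opts: StubVad("silero", **opts),
}


def resolve_vad(
    vad_model: Optional[StubVad] = None,
    vad_method: str = "pyannote",
    vad_options: Optional[dict] = None,
) -> StubVad:
    # Defaults mirror the values shown in the diff; caller options override them.
    options = {"chunk_size": 30, "vad_onset": 0.500, "vad_offset": 0.363}
    if vad_options is not None:
        options.update(vad_options)
    # A manually assigned model wins; vad_method is ignored in that case.
    if vad_model is not None:
        return vad_model
    if vad_method not in VAD_FACTORIES:
        raise ValueError(f"Invalid vad_method: {vad_method}")
    return VAD_FACTORIES[vad_method](**options)


chosen = resolve_vad(vad_method="silero", vad_options={"vad_onset": 0.4})
```

With the branch gone, the only remaining way to swap the VAD in this version is to pass a ready-made `vad_model`.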

@ -79,7 +79,7 @@ def assign_word_speakers(
class Segment:
def __init__(self, start:int, end:int, speaker:Optional[str]=None):
def __init__(self, start, end, speaker=None):
self.start = start
self.end = end
self.speaker = speaker

@ -26,7 +26,6 @@ def cli():
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe")
parser.add_argument("--model", default="small", help="name of the Whisper model to use")
parser.add_argument("--model_cache_only", type=str2bool, default=False, help="If True, will not attempt to download models, instead using cached models from --model_dir")
parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
parser.add_argument("--device_index", default=0, type=int, help="device index to use for FasterWhisper inference")
@ -47,7 +46,6 @@ def cli():
parser.add_argument("--return_char_alignments", action='store_true', help="Return character-level alignments in the output json file")
# vad params
parser.add_argument("--vad_method", type=str, default="pyannote", choices=["pyannote", "silero"], help="VAD method to be used")
parser.add_argument("--vad_onset", type=float, default=0.500, help="Onset threshold for VAD (see pyannote.audio), reduce this if speech is not being detected")
parser.add_argument("--vad_offset", type=float, default=0.363, help="Offset threshold for VAD (see pyannote.audio), reduce this if speech is not being detected.")
parser.add_argument("--chunk_size", type=int, default=30, help="Chunk size for merging VAD segments. Default is 30, reduce this if the chunk is too long.")
@ -91,7 +89,6 @@ def cli():
model_name: str = args.pop("model")
batch_size: int = args.pop("batch_size")
model_dir: str = args.pop("model_dir")
model_cache_only: bool = args.pop("model_cache_only")
output_dir: str = args.pop("output_dir")
output_format: str = args.pop("output_format")
device: str = args.pop("device")
@ -113,7 +110,6 @@ def cli():
return_char_alignments: bool = args.pop("return_char_alignments")
hf_token: str = args.pop("hf_token")
vad_method: str = args.pop("vad_method")
vad_onset: float = args.pop("vad_onset")
vad_offset: float = args.pop("vad_offset")
@ -179,7 +175,7 @@ def cli():
results = []
tmp_results = []
# model = load_model(model_name, device=device, download_root=model_dir)
model = load_model(model_name, device=device, device_index=device_index, download_root=model_dir, compute_type=compute_type, language=args['language'], asr_options=asr_options, vad_method=vad_method, vad_options={"chunk_size":chunk_size, "vad_onset": vad_onset, "vad_offset": vad_offset}, task=task, local_files_only=model_cache_only, threads=faster_whisper_threads)
model = load_model(model_name, device=device, device_index=device_index, download_root=model_dir, compute_type=compute_type, language=args['language'], asr_options=asr_options, vad_options={"vad_onset": vad_onset, "vad_offset": vad_offset}, task=task, threads=faster_whisper_threads)
for audio_path in args.pop("audio"):
audio = load_audio(audio_path)
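To see what the two dropped CLI flags did at the parsing level, here is a small standalone sketch. It reuses the argument definitions shown in the hunks above; the `str2bool` body is an assumption written for this example, since the project's own helper is not part of the diff.

```python
import argparse


def str2bool(value: str) -> bool:
    """Assumed helper: accept only the literal strings 'True' and 'False'."""
    mapping = {"True": True, "False": False}
    if value not in mapping:
        raise argparse.ArgumentTypeError(f"expected 'True' or 'False', got {value!r}")
    return mapping[value]


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe")
    parser.add_argument("--model_cache_only", type=str2bool, default=False,
                        help="If True, will not attempt to download models")
    parser.add_argument("--vad_method", type=str, default="pyannote",
                        choices=["pyannote", "silero"], help="VAD method to be used")
    return parser


args = vars(build_parser().parse_args(
    ["sample.wav", "--vad_method", "silero", "--model_cache_only", "True"]
))
```

Once both `add_argument` calls are removed, argparse rejects `--vad_method` and `--model_cache_only` as unrecognized arguments, so scripts that pass them need updating.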

@ -1,4 +1,4 @@
from typing import TypedDict, Optional, List, Tuple
from typing import TypedDict, Optional, List
class SingleWordSegment(TypedDict):
@ -30,17 +30,6 @@ class SingleSegment(TypedDict):
text: str
class SegmentData(TypedDict):
"""
Temporary processing data used during alignment.
Contains cleaned and preprocessed data for each segment.
"""
clean_char: List[str] # Cleaned characters that exist in model dictionary
clean_cdx: List[int] # Original indices of cleaned characters
clean_wdx: List[int] # Indices of words containing valid characters
sentence_spans: List[Tuple[int, int]] # Start and end indices of sentences
class SingleAlignedSegment(TypedDict):
"""
A single segment (up to multiple sentences) of a speech with word alignment.

@ -106,7 +106,6 @@ LANGUAGES = {
"jw": "javanese",
"su": "sundanese",
"yue": "cantonese",
"lv": "latvian",
}
# language code lookup by name, with a few language aliases
@ -242,7 +241,7 @@ class SubtitlesWriter(ResultWriter):
line_count = 1
# the next subtitle to yield (a list of word timings with whitespace)
subtitle: list[dict] = []
times: list[tuple] = []
times = []
last = result["segments"][0]["start"]
for segment in result["segments"]:
for i, original_timing in enumerate(segment["words"]):

@ -1,46 +1,51 @@
import hashlib
import os
import urllib
from typing import Callable, Text, Union
from typing import Optional
from typing import Callable, Optional, Text, Union
import numpy as np
import pandas as pd
import torch
from pyannote.audio import Model
from pyannote.audio.core.io import AudioFile
from pyannote.audio.pipelines import VoiceActivityDetection
from pyannote.audio.pipelines.utils import PipelineModel
from pyannote.core import Annotation, SlidingWindowFeature
from pyannote.core import Segment
from pyannote.core import Annotation, Segment, SlidingWindowFeature
from tqdm import tqdm
from whisperx.diarize import Segment as SegmentX
from whisperx.vads.vad import Vad
from .diarize import Segment as SegmentX
# deprecated
VAD_SEGMENTATION_URL = "https://whisperx.s3.eu-west-2.amazonaws.com/model_weights/segmentation/0b5b3216d60a2d32fc086b47ea8c67589aaeb26b7e07fcbe620d6d0b83e209ea/pytorch_model.bin"
def load_vad_model(device, vad_onset=0.500, vad_offset=0.363, use_auth_token=None, model_fp=None):
model_dir = torch.hub._get_torch_home()
main_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
vad_dir = os.path.dirname(os.path.abspath(__file__))
os.makedirs(model_dir, exist_ok = True)
if model_fp is None:
# Dynamically resolve the path to the model file
model_fp = os.path.join(main_dir, "assets", "pytorch_model.bin")
model_fp = os.path.join(vad_dir, "assets", "pytorch_model.bin")
model_fp = os.path.abspath(model_fp) # Ensure the path is absolute
else:
model_fp = os.path.abspath(model_fp) # Ensure any provided path is absolute
# Check if the resolved model file exists
if not os.path.exists(model_fp):
raise FileNotFoundError(f"Model file not found at {model_fp}")
if os.path.exists(model_fp) and not os.path.isfile(model_fp):
raise RuntimeError(f"{model_fp} exists and is not a regular file")
model_bytes = open(model_fp, "rb").read()
if hashlib.sha256(model_bytes).hexdigest() != VAD_SEGMENTATION_URL.split('/')[-2]:
raise RuntimeError(
"Model has been downloaded but the SHA256 checksum does not match. Please retry loading the model."
)
vad_model = Model.from_pretrained(model_fp, use_auth_token=use_auth_token)
hyperparameters = {"onset": vad_onset,
hyperparameters = {"onset": vad_onset,
"offset": vad_offset,
"min_duration_on": 0.1,
"min_duration_off": 0.1}
@ -76,21 +81,21 @@ class Binarize:
Gregory Gelly and Jean-Luc Gauvain. "Minimum Word Error Training of
RNN-based Voice Activity Detection", InterSpeech 2015.
Modified by Max Bain to include WhisperX's min-cut operation
Modified by Max Bain to include WhisperX's min-cut operation
https://arxiv.org/abs/2303.00747
Pyannote-audio
"""
def __init__(
self,
onset: float = 0.5,
offset: Optional[float] = None,
min_duration_on: float = 0.0,
min_duration_off: float = 0.0,
pad_onset: float = 0.0,
pad_offset: float = 0.0,
max_duration: float = float('inf')
self,
onset: float = 0.5,
offset: Optional[float] = None,
min_duration_on: float = 0.0,
min_duration_off: float = 0.0,
pad_onset: float = 0.0,
pad_offset: float = 0.0,
max_duration: float = float('inf')
):
super().__init__()
@ -136,7 +141,7 @@ class Binarize:
t = start
for t, y in zip(timestamps[1:], k_scores[1:]):
# currently active
if is_active:
if is_active:
curr_duration = t - start
if curr_duration > self.max_duration:
search_after = len(curr_scores) // 2
@ -146,8 +151,8 @@ class Binarize:
region = Segment(start - self.pad_onset, min_score_t + self.pad_offset)
active[region, k] = label
start = curr_timestamps[min_score_div_idx]
curr_scores = curr_scores[min_score_div_idx + 1:]
curr_timestamps = curr_timestamps[min_score_div_idx + 1:]
curr_scores = curr_scores[min_score_div_idx+1:]
curr_timestamps = curr_timestamps[min_score_div_idx+1:]
# switching from active to inactive
elif y < self.offset:
region = Segment(start - self.pad_onset, t + self.pad_offset)
@ -188,11 +193,11 @@ class Binarize:
class VoiceActivitySegmentation(VoiceActivityDetection):
def __init__(
self,
segmentation: PipelineModel = "pyannote/segmentation",
fscore: bool = False,
use_auth_token: Union[Text, None] = None,
**inference_kwargs,
self,
segmentation: PipelineModel = "pyannote/segmentation",
fscore: bool = False,
use_auth_token: Union[Text, None] = None,
**inference_kwargs,
):
super().__init__(segmentation=segmentation, fscore=fscore, use_auth_token=use_auth_token, **inference_kwargs)
@ -231,35 +236,72 @@ class VoiceActivitySegmentation(VoiceActivityDetection):
return segmentations
class Pyannote(Vad):
def merge_vad(vad_arr, pad_onset=0.0, pad_offset=0.0, min_duration_off=0.0, min_duration_on=0.0):
def __init__(self, device, use_auth_token=None, model_fp=None, **kwargs):
print(">>Performing voice activity detection using Pyannote...")
super().__init__(kwargs['vad_onset'])
self.vad_pipeline = load_vad_model(device, use_auth_token=use_auth_token, model_fp=model_fp)
active = Annotation()
for k, vad_t in enumerate(vad_arr):
region = Segment(vad_t[0] - pad_onset, vad_t[1] + pad_offset)
active[region, k] = 1
def __call__(self, audio: AudioFile, **kwargs):
return self.vad_pipeline(audio)
@staticmethod
def preprocess_audio(audio):
return torch.from_numpy(audio).unsqueeze(0)
if pad_offset > 0.0 or pad_onset > 0.0 or min_duration_off > 0.0:
active = active.support(collar=min_duration_off)
# remove tracks shorter than min_duration_on
if min_duration_on > 0:
for segment, track in list(active.itertracks()):
if segment.duration < min_duration_on:
del active[segment, track]
active = active.for_json()
active_segs = pd.DataFrame([x['segment'] for x in active['content']])
return active_segs
@staticmethod
def merge_chunks(segments,
chunk_size,
onset: float = 0.5,
offset: Optional[float] = None,
):
assert chunk_size > 0
binarize = Binarize(max_duration=chunk_size, onset=onset, offset=offset)
segments = binarize(segments)
segments_list = []
for speech_turn in segments.get_timeline():
segments_list.append(SegmentX(speech_turn.start, speech_turn.end, "UNKNOWN"))
def merge_chunks(
segments,
chunk_size,
onset: float = 0.5,
offset: Optional[float] = None,
):
"""
Merge operation described in paper
"""
curr_end = 0
merged_segments = []
seg_idxs = []
speaker_idxs = []
if len(segments_list) == 0:
print("No active speech found in audio")
return []
assert segments_list, "segments_list is empty."
return Vad.merge_chunks(segments_list, chunk_size, onset, offset)
assert chunk_size > 0
binarize = Binarize(max_duration=chunk_size, onset=onset, offset=offset)
segments = binarize(segments)
segments_list = []
for speech_turn in segments.get_timeline():
segments_list.append(SegmentX(speech_turn.start, speech_turn.end, "UNKNOWN"))
if len(segments_list) == 0:
print("No active speech found in audio")
return []
# assert segments_list, "segments_list is empty."
# Make sure the starting point is the start of the segment.
curr_start = segments_list[0].start
for seg in segments_list:
if seg.end - curr_start > chunk_size and curr_end-curr_start > 0:
merged_segments.append({
"start": curr_start,
"end": curr_end,
"segments": seg_idxs,
})
curr_start = seg.start
seg_idxs = []
speaker_idxs = []
curr_end = seg.end
seg_idxs.append((seg.start, seg.end))
speaker_idxs.append(seg.speaker)
# add final
merged_segments.append({
"start": curr_start,
"end": curr_end,
"segments": seg_idxs,
})
return merged_segments
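The greedy merge loop that closes `merge_chunks` can be read in isolation. The following is a minimal sketch, assuming a hypothetical `Seg` dataclass in place of `SegmentX` and a hypothetical function name `greedy_merge`; it mirrors the loop shown in the diff but leaves out the `Binarize` step and speaker bookkeeping, and it is not the project's code.

```python
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class Seg:
    """Hypothetical stand-in for SegmentX: a speech span in seconds."""
    start: float
    end: float


def greedy_merge(segments: List[Seg], chunk_size: float) -> List[Dict]:
    """Pack consecutive speech segments into chunks of roughly chunk_size seconds."""
    assert chunk_size > 0
    if not segments:
        return []

    merged: List[Dict] = []
    seg_idxs: List[tuple] = []
    curr_start = segments[0].start
    curr_end = 0.0

    for seg in segments:
        # Close the current chunk when adding this segment would exceed
        # chunk_size, provided the current chunk already holds some speech.
        if seg.end - curr_start > chunk_size and curr_end - curr_start > 0:
            merged.append({"start": curr_start, "end": curr_end, "segments": seg_idxs})
            curr_start = seg.start
            seg_idxs = []
        curr_end = seg.end
        seg_idxs.append((seg.start, seg.end))

    # Flush the final, still-open chunk.
    merged.append({"start": curr_start, "end": curr_end, "segments": seg_idxs})
    return merged


chunks = greedy_merge([Seg(0, 10), Seg(12, 20), Seg(25, 40), Seg(41, 45)], chunk_size=30)
```

With a 30-second limit, the first two segments share a chunk, and the third segment opens a new chunk because it would push the first one past the limit.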

View File

@ -1,3 +0,0 @@
from whisperx.vads.pyannote import Pyannote
from whisperx.vads.silero import Silero
from whisperx.vads.vad import Vad

View File

@ -1,66 +0,0 @@
from io import IOBase
from pathlib import Path
from typing import Mapping, Text
from typing import Optional
from typing import Union
import torch
from whisperx.diarize import Segment as SegmentX
from whisperx.vads.vad import Vad
AudioFile = Union[Text, Path, IOBase, Mapping]
class Silero(Vad):
# check again default values
def __init__(self, **kwargs):
print(">>Performing voice activity detection using Silero...")
super().__init__(kwargs['vad_onset'])
self.vad_onset = kwargs['vad_onset']
self.chunk_size = kwargs['chunk_size']
self.vad_pipeline, vad_utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',
model='silero_vad',
force_reload=False,
onnx=False,
trust_repo=True)
(self.get_speech_timestamps, _, self.read_audio, _, _) = vad_utils
def __call__(self, audio: AudioFile, **kwargs):
"""use silero to get segments of speech"""
# Only accept 16000 Hz for now.
# Note: Silero models support both 8000 and 16000 Hz. Although other values are not directly supported,
# multiples of 16000 (e.g. 32000 or 48000) are cast to 16000 inside of the JIT model!
sample_rate = audio["sample_rate"]
if sample_rate != 16000:
raise ValueError("Only 16000Hz sample rate is allowed")
timestamps = self.get_speech_timestamps(audio["waveform"],
model=self.vad_pipeline,
sampling_rate=sample_rate,
max_speech_duration_s=self.chunk_size,
threshold=self.vad_onset
# min_silence_duration_ms = self.min_duration_off/1000
# min_speech_duration_ms = self.min_duration_on/1000
# ...
# See silero documentation for full option list
)
return [SegmentX(i['start'] / sample_rate, i['end'] / sample_rate, "UNKNOWN") for i in timestamps]
@staticmethod
def preprocess_audio(audio):
return audio
@staticmethod
def merge_chunks(segments_list,
chunk_size,
onset: float = 0.5,
offset: Optional[float] = None,
):
assert chunk_size > 0
if len(segments_list) == 0:
print("No active speech found in audio")
return []
assert segments_list, "segments_list is empty."
return Vad.merge_chunks(segments_list, chunk_size, onset, offset)
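The removed `Silero.__call__` converts timestamps expressed in samples into seconds by dividing by the sample rate. A minimal sketch of that conversion follows, assuming the timestamps are dictionaries with integer `start` and `end` sample indices, as the indexing in the diff suggests. The helper name `samples_to_seconds` is hypothetical, and the sketch returns plain tuples rather than `SegmentX` objects.

```python
from typing import Dict, List, Tuple


def samples_to_seconds(
    timestamps: List[Dict[str, int]], sample_rate: int = 16000
) -> List[Tuple[float, float]]:
    """Convert sample-index speech spans to (start, end) pairs in seconds."""
    # Mirror the guard in the diff: only 16 kHz input is accepted.
    if sample_rate != 16000:
        raise ValueError("Only 16000Hz sample rate is allowed")
    return [(t["start"] / sample_rate, t["end"] / sample_rate) for t in timestamps]


spans = samples_to_seconds([{"start": 16000, "end": 48000}, {"start": 64000, "end": 72000}])
```

Here the first span covers seconds 1.0 to 3.0 and the second covers 4.0 to 4.5.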

View File

@ -1,74 +0,0 @@
from typing import Optional
import pandas as pd
from pyannote.core import Annotation, Segment
class Vad:
def __init__(self, vad_onset):
if not (0 < vad_onset < 1):
raise ValueError(
"vad_onset is a decimal value between 0 and 1."
)
@staticmethod
def preprocess_audio(audio):
pass
# keep merge_chunks as static so it can be also used by manually assigned vad_model (see 'load_model')
@staticmethod
def merge_chunks(segments,
chunk_size,
onset: float,
offset: Optional[float]):
"""
Merge operation described in paper
"""
curr_end = 0
merged_segments = []
seg_idxs: list[tuple]= []
speaker_idxs: list[Optional[str]] = []
curr_start = segments[0].start
for seg in segments:
if seg.end - curr_start > chunk_size and curr_end - curr_start > 0:
merged_segments.append({
"start": curr_start,
"end": curr_end,
"segments": seg_idxs,
})
curr_start = seg.start
seg_idxs = []
speaker_idxs = []
curr_end = seg.end
seg_idxs.append((seg.start, seg.end))
speaker_idxs.append(seg.speaker)
# add final
merged_segments.append({
"start": curr_start,
"end": curr_end,
"segments": seg_idxs,
})
return merged_segments
# Unused function
@staticmethod
def merge_vad(vad_arr, pad_onset=0.0, pad_offset=0.0, min_duration_off=0.0, min_duration_on=0.0):
active = Annotation()
for k, vad_t in enumerate(vad_arr):
region = Segment(vad_t[0] - pad_onset, vad_t[1] + pad_offset)
active[region, k] = 1
if pad_offset > 0.0 or pad_onset > 0.0 or min_duration_off > 0.0:
active = active.support(collar=min_duration_off)
# remove tracks shorter than min_duration_on
if min_duration_on > 0:
for segment, track in list(active.itertracks()):
if segment.duration < min_duration_on:
del active[segment, track]
active = active.for_json()
active_segs = pd.DataFrame([x['segment'] for x in active['content']])
return active_segs