mirror of https://github.com/m-bain/whisperX.git
synced 2025-07-01 18:17:27 -04:00

vad filter
@@ -116,14 +116,14 @@ audio_file = "audio.mp3"

model = whisperx.load_model("large", device)
result = model.transcribe(audio_file)

print(result["segments"]) # before alignment

# load alignment model and metadata
model_a, metadata = whisperx.load_align_model(language_code=result["language"], device=device)

# align whisper output
result_aligned = whisperx.align(result["segments"], model_a, metadata, audio_file, device)

print(result["segments"]) # before alignment
print(result_aligned["segments"]) # after alignment
print(result_aligned["word_segments"]) # after alignment
```
@@ -8,11 +8,12 @@ import torch
import torchaudio
from transformers import AutoProcessor, Wav2Vec2ForCTC
import tqdm
-from .audio import SAMPLE_RATE, N_FRAMES, HOP_LENGTH, pad_or_trim, log_mel_spectrogram, load_audio
+from .audio import SAMPLE_RATE, N_FRAMES, HOP_LENGTH, CHUNK_LENGTH, pad_or_trim, log_mel_spectrogram, load_audio
from .alignment import get_trellis, backtrack, merge_repeats, merge_words
from .decoding import DecodingOptions, DecodingResult
from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
from .utils import exact_div, format_timestamp, optional_int, optional_float, str2bool, write_txt, write_vtt, write_srt, write_ass
import pandas as pd

if TYPE_CHECKING:
    from .model import Whisper
@@ -45,6 +46,7 @@ def transcribe(
    logprob_threshold: Optional[float] = -1.0,
    no_speech_threshold: Optional[float] = 0.6,
    condition_on_previous_text: bool = False,  # turn off by default due to errors it causes
    mel: np.ndarray = None,
    **decode_options,
):
    """
@@ -100,7 +102,8 @@ def transcribe(
    if dtype == torch.float32:
        decode_options["fp16"] = False

-    mel = log_mel_spectrogram(audio)
+    if mel is None:
+        mel = log_mel_spectrogram(audio)
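    # note: computing the mel only when the caller has not supplied one lets
    # transcribe_segments() below pass a pre-sliced spectrogram chunk instead
    # of re-deriving it from the full audio file.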

    if decode_options.get("language", None) is None:
        if not model.is_multilingual:
@@ -293,8 +296,10 @@ def align(
    model_type = align_model_metadata['type']

    prev_t2 = 0
-    word_segments_list = []
+    total_word_segments_list = []
+    vad_segments_list = []
    for idx, segment in enumerate(transcript):
        word_segments_list = []
        # first we pad
        t1 = max(segment['start'] - extend_duration, 0)
        t2 = min(segment['end'] + extend_duration, MAX_DURATION)
@@ -326,6 +331,30 @@ def align(
        emissions = torch.log_softmax(emissions, dim=-1)

        emission = emissions[0].cpu().detach()

        if "vad" in segment and len(segment['vad']) > 1 and '|' in model_dictionary:
            ratio = waveform_segment.size(0) / emission.size(0)
            space_idx = model_dictionary['|']
            # find non-vad segments
            for i in range(1, len(segment['vad'])):
                start = segment['vad'][i-1][1]
                end = segment['vad'][i][0]
                if start < end:  # check if there is a gap between intervals
                    non_vad_f1 = int(start / ratio)
                    non_vad_f2 = int(end / ratio)
                    # non-vad should be masked, use space to do so
                    emission[non_vad_f1:non_vad_f2, :] = float("-inf")
                    emission[non_vad_f1:non_vad_f2, space_idx] = 0

            start = segment['vad'][i][1]
            end = segment['end']
            non_vad_f1 = int(start / ratio)
            non_vad_f2 = int(end / ratio)
            # non-vad should be masked, use space to do so
            emission[non_vad_f1:non_vad_f2, :] = float("-inf")
            emission[non_vad_f1:non_vad_f2, space_idx] = 0

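        # Illustration of the masking above (assumed numbers; treats the 'vad'
        # boundaries as sample indices into waveform_segment): a 10 s segment
        # at 16 kHz has 160000 samples, and a wav2vec2 emission of 500 frames
        # gives ratio = 320 samples per frame. A non-speech gap from sample
        # 48000 to 64000 then maps to frames 150-200; every label there is
        # forced to -inf except '|' (the CTC word separator), whose log-prob
        # is set to 0, so the alignment trellis cannot place word tokens
        # inside the gap.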
        transcription = segment['text'].strip()
        if model_lang not in LANGUAGES_WITHOUT_SPACES:
            t_words = transcription.split(' ')
@@ -400,9 +429,33 @@ def align(
            segment['word-level'].append({"text": segment['text'], "start": segment['start'], "end": segment['end']})
            word_segments_list.append({"text": segment['text'], "start": segment['start'], "end": segment['end']})

        if 'vad' in segment:
            curr_vdx = 0
            curr_text = ''
            for wrd_seg in word_segments_list:
                if wrd_seg['start'] > segment['vad'][curr_vdx][1]:
                    curr_speaker = segment['speakers'][curr_vdx]
                    vad_segments_list.append(
                        {'start': segment['vad'][curr_vdx][0],
                         'end': segment['vad'][curr_vdx][1],
                         'text': f"[{curr_speaker}]: " + curr_text.strip()}
                    )
                    curr_vdx += 1
                    curr_text = ''
                curr_text += ' ' + wrd_seg['text']
            if len(curr_text) > 0:
                curr_speaker = segment['speakers'][curr_vdx]
                vad_segments_list.append(
                    {'start': segment['vad'][curr_vdx][0],
                     'end': segment['vad'][curr_vdx][1],
                     'text': f"[{curr_speaker}]: " + curr_text.strip()}
                )
                curr_text = ''
        total_word_segments_list += word_segments_list
        print(f"[{format_timestamp(segment['start'])} --> {format_timestamp(segment['end'])}] {segment['text']}")

-    return {"segments": transcript, "word_segments": word_segments_list}
+    return {"segments": transcript, "word_segments": total_word_segments_list, "vad_segments": vad_segments_list}

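# Sketch of the speaker grouping above, with assumed numbers: given
# segment['vad'] = [(0.0, 5.0), (6.0, 9.0)], segment['speakers'] =
# ['SPEAKER 00', 'SPEAKER 01'], and aligned words starting at 1.2 s, 3.4 s
# and 7.0 s, the first two words accumulate under the 0.0-5.0 span; the word
# at 7.0 s starts past that span's end, so "[SPEAKER 00]: ..." is flushed to
# vad_segments_list, and accumulation continues under 6.0-9.0, which the
# final `if len(curr_text) > 0` check flushes after the loop.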
def load_align_model(language_code, device, model_name=None):
    if model_name is None:
@@ -439,6 +492,91 @@ def load_align_model(language_code, device, model_name=None):

    return align_model, align_metadata

def merge_chunks(segments, chunk_size=CHUNK_LENGTH, speakers=False):
    '''
    Merge VAD segments into larger segments of size ~CHUNK_LENGTH.
    '''

    curr_start = 0
    curr_end = 0
    merged_segments = []
    seg_idxs = []
    speaker_idxs = []
    for sdx, seg in enumerate(segments):
        if seg.end - curr_start > chunk_size and curr_end - curr_start > 0:
            merged_segments.append({
                "start": curr_start,
                "end": curr_end,
                "segments": seg_idxs,
                "speakers": speaker_idxs,
            })
            curr_start = seg.start
            seg_idxs = []
            speaker_idxs = []
        curr_end = seg.end
        seg_idxs.append((seg.start, seg.end))
        speaker_idxs.append(seg.speaker)
    # add final
    merged_segments.append({
        "start": curr_start,
        "end": curr_end,
        "segments": seg_idxs,
        "speakers": speaker_idxs
    })
    return merged_segments
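# Worked example with assumed numbers: speech spans (0, 12), (13, 20) and
# (21, 40) with chunk_size=30 merge into two chunks, 0-20 and 21-40; the
# third span would stretch the first chunk past 30 s, so it starts a new one.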

def transcribe_segments(
    model: "Whisper",
    audio: Union[str, np.ndarray, torch.Tensor],
    merged_segments,
    mel=None,
    **kwargs
):
    '''
    Transcribe according to predefined VAD segments.
    '''

    if mel is None:
        mel = log_mel_spectrogram(audio)

    prev = 0

    output = {'segments': []}

    for sdx, seg_t in enumerate(merged_segments):
        print(sdx, seg_t['start'], seg_t['end'], '...')
        seg_f_start, seg_f_end = int(seg_t['start'] * SAMPLE_RATE / HOP_LENGTH), int(seg_t['end'] * SAMPLE_RATE / HOP_LENGTH)
        local_f_start, local_f_end = seg_f_start - prev, seg_f_end - prev
        mel = mel[:, local_f_start:]  # seek forward
        prev = seg_f_start
        local_mel = mel[:, :local_f_end-local_f_start]
        result = transcribe(model, audio, mel=local_mel, **kwargs)
        seg_t['text'] = result['text']
        output['segments'].append(
            {
                'start': seg_t['start'],
                'end': seg_t['end'],
                'language': result['language'],
                'text': result['text'],
                'seg-text': [x['text'] for x in result['segments']],
                'seg-start': [x['start'] for x in result['segments']],
                'seg-end': [x['end'] for x in result['segments']],
            }
        )

    output['language'] = output['segments'][0]['language']

    return output
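# Frame arithmetic in the loop above, for reference: with whisper's
# SAMPLE_RATE = 16000 and HOP_LENGTH = 160, one second of audio corresponds
# to 100 mel frames, so a chunk from 3.2 s to 28.7 s spans frames 320-2870.
# Because `mel = mel[:, local_f_start:]` drops everything before the current
# chunk, `prev` tracks the absolute frame offset of the retained spectrogram.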

class Segment:
    def __init__(self, start, end, speaker=None):
        self.start = start
        self.end = end
        self.speaker = speaker


def cli():
    from . import available_models

@@ -452,7 +590,8 @@ def cli():
    parser.add_argument("--align_extend", default=2, type=float, help="Seconds before and after to extend the whisper segments for alignment")
    parser.add_argument("--align_from_prev", default=True, type=bool, help="Whether to clip the alignment start time of the current segment to the end time of the last aligned word of the previous segment")
    parser.add_argument("--drop_non_aligned", action="store_true", help="For word .srt, whether to drop non-aligned words, or merge them into neighbouring words.")
    parser.add_argument("--vad_filter", action="store_true", help="Whether to first perform VAD filtering, so that transcription is only applied within detected speech regions")
    parser.add_argument("--vad_input", default=None, type=str, help="Path to an RTTM file with precomputed VAD/diarization segments")
    parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
    parser.add_argument("--output_type", default="srt", choices=['all', 'srt', 'vtt', 'txt'], help="File type for desired output save")
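A plausible end-to-end invocation of the new flags (the file names, and the `whisperx` console entry point, are assumptions rather than part of this commit):

```
whisperx audio.mp3 --model large --vad_filter --output_dir out
whisperx audio.mp3 --vad_input diarization.rttm --output_type srt
```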
@@ -490,6 +629,17 @@ def cli():
    align_from_prev: bool = args.pop("align_from_prev")
    drop_non_aligned: bool = args.pop("drop_non_aligned")

    vad_filter: bool = args.pop("vad_filter")
    vad_input: str = args.pop("vad_input")

    vad_pipeline = None
    if vad_input is not None:
        vad_input = pd.read_csv(vad_input, header=None, sep=" ")
    elif vad_filter:
        from pyannote.audio import Pipeline
        vad_pipeline = Pipeline.from_pretrained("pyannote/voice-activity-detection")
        # vad_pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization@2.1")

    os.makedirs(output_dir, exist_ok=True)

    if model_name.endswith(".en") and args["language"] not in {"en", "English"}:
@@ -515,8 +665,26 @@ def cli():
        align_model, align_metadata = load_align_model(align_language, device, model_name=align_model)

    for audio_path in args.pop("audio"):
-        print("Performing transcription...")
-        result = transcribe(model, audio_path, temperature=temperature, **args)
        if vad_filter or vad_input is not None:
            output_segments = []
            if vad_filter:
                print("Performing VAD...")
                # vad_segments = vad_pipeline(audio_path)
                # for speech_turn, track, speaker in vad_segments.itertracks(yield_label=True):
                #     output_segments.append(Segment(speech_turn.start, speech_turn.end, speaker))
                vad_segments = vad_pipeline(audio_path)
                for speech_turn in vad_segments.get_timeline().support():
                    output_segments.append(Segment(speech_turn.start, speech_turn.end, "UNKNOWN"))
            elif vad_input is not None:
                # rttm format
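                # RTTM fields per line: type, file, channel, onset (s),
                # duration (s), orthography, speaker type, speaker name,
                # confidence - hence row[3] = onset, row[3] + row[4] = offset,
                # and row[7] = speaker name.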
                for idx, row in vad_input.iterrows():
                    output_segments.append(Segment(row[3], row[3]+row[4], f"SPEAKER {row[7]}"))
            vad_segments = merge_chunks(output_segments)
            result = transcribe_segments(model, audio_path, merged_segments=vad_segments, temperature=temperature, **args)
        else:
            vad_segments = None
            print("Performing transcription...")
            result = transcribe(model, audio_path, temperature=temperature, **args)

        if result["language"] != align_metadata["language"]:
            # load new language
@@ -548,8 +716,13 @@ def cli():
            write_srt(result_aligned["word_segments"], file=srt)

        # save ASS
-        # with open(os.path.join(output_dir, audio_basename + ".ass"), "w", encoding="utf-8") as ass:
-        #     write_ass(result_aligned["segments"], file=ass)
+        with open(os.path.join(output_dir, audio_basename + ".ass"), "w", encoding="utf-8") as ass:
+            write_ass(result_aligned["segments"], file=ass)

        if vad_filter or vad_input is not None:
            # save VAD-segment SRT
            with open(os.path.join(output_dir, audio_basename + ".vad.srt"), "w", encoding="utf-8") as srt:
                write_srt(result_aligned["vad_segments"], file=srt)


if __name__ == '__main__':
@@ -193,7 +193,7 @@ def write_ass(transcript: Iterator[dict], file: TextIO,
        curr_words = [wrd['text'] for wrd in segment['word-level']]
-        prev = segment['word-level'][0]['start']
-        if prev is None:
-            prev = 0
+        prev = segment['start']
        for wdx, word in enumerate(segment['word-level']):
            if word['start'] is not None:
                # fill gap between previous word