From ba102feb7ff30e6f8345f00470955f5632e767e2 Mon Sep 17 00:00:00 2001
From: Max Bain
Date: Fri, 20 Jan 2023 12:54:20 +0000
Subject: [PATCH] vad filter

---
 README.md              |   4 +-
 whisperx/transcribe.py | 191 +++++++++++++++++++++++++++++++++++++++--
 whisperx/utils.py      |   2 +-
 3 files changed, 185 insertions(+), 12 deletions(-)

diff --git a/README.md b/README.md
index 6af17b3..f89918a 100644
--- a/README.md
+++ b/README.md
@@ -116,14 +116,14 @@ audio_file = "audio.mp3"
 model = whisperx.load_model("large", device)
 result = model.transcribe(audio_file)
 
+print(result["segments"]) # before alignment
+
 # load alignment model and metadata
 model_a, metadata = whisperx.load_align_model(language_code=result["language"], device=device)
 
 # align whisper output
 result_aligned = whisperx.align(result["segments"], model_a, metadata, audio_file, device)
 
-print(result["segments"]) # before alignment
-
 print(result_aligned["segments"]) # after alignment
 print(result_aligned["word_segments"]) # after alignment
 ```

diff --git a/whisperx/transcribe.py b/whisperx/transcribe.py
index 772143d..9ff5a64 100644
--- a/whisperx/transcribe.py
+++ b/whisperx/transcribe.py
@@ -8,11 +8,12 @@ import torch
 import torchaudio
 from transformers import AutoProcessor, Wav2Vec2ForCTC
 import tqdm
-from .audio import SAMPLE_RATE, N_FRAMES, HOP_LENGTH, pad_or_trim, log_mel_spectrogram, load_audio
+from .audio import SAMPLE_RATE, N_FRAMES, HOP_LENGTH, CHUNK_LENGTH, pad_or_trim, log_mel_spectrogram, load_audio
 from .alignment import get_trellis, backtrack, merge_repeats, merge_words
 from .decoding import DecodingOptions, DecodingResult
 from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
 from .utils import exact_div, format_timestamp, optional_int, optional_float, str2bool, write_txt, write_vtt, write_srt, write_ass
+import pandas as pd
 
 if TYPE_CHECKING:
     from .model import Whisper
@@ -45,6 +46,7 @@ def transcribe(
     logprob_threshold: Optional[float] = -1.0,
     no_speech_threshold: Optional[float] = 0.6,
     condition_on_previous_text: bool = False, # turn off by default due to errors it causes
+    mel: np.ndarray = None,
     **decode_options,
 ):
     """
@@ -100,7 +102,8 @@ def transcribe(
     if dtype == torch.float32:
         decode_options["fp16"] = False
 
-    mel = log_mel_spectrogram(audio)
+    if mel is None:
+        mel = log_mel_spectrogram(audio)
 
     if decode_options.get("language", None) is None:
         if not model.is_multilingual:
@@ -293,8 +296,10 @@ def align(
     model_type = align_model_metadata['type']
 
     prev_t2 = 0
-    word_segments_list = []
+    total_word_segments_list = []
+    vad_segments_list = []
     for idx, segment in enumerate(transcript):
+        word_segments_list = []
         # first we pad
         t1 = max(segment['start'] - extend_duration, 0)
         t2 = min(segment['end'] + extend_duration, MAX_DURATION)
@@ -326,6 +331,30 @@ def align(
             emissions = torch.log_softmax(emissions, dim=-1)
         emission = emissions[0].cpu().detach()
 
+
+        if "vad" in segment and len(segment['vad']) > 1 and '|' in model_dictionary:
+            ratio = waveform_segment.size(0) / emission.size(0)
+            space_idx = model_dictionary['|']
+            # find non-vad segments
+            for i in range(1, len(segment['vad'])):
+                start = segment['vad'][i-1][1]
+                end = segment['vad'][i][0]
+                if start < end: # check if there is a gap between intervals
+                    non_vad_f1 = int(start / ratio)
+                    non_vad_f2 = int(end / ratio)
+                    # non-vad should be masked, use space to do so
+                    emission[non_vad_f1:non_vad_f2, :] = float("-inf")
+                    emission[non_vad_f1:non_vad_f2, space_idx] = 0
+
+
+            start = segment['vad'][i][1]
+            end = segment['end']
+            non_vad_f1 = int(start / ratio)
+            non_vad_f2 = int(end / ratio)
+            # non-vad should be masked, use space to do so
+            emission[non_vad_f1:non_vad_f2, :] = float("-inf")
+            emission[non_vad_f1:non_vad_f2, space_idx] = 0
+
         transcription = segment['text'].strip()
         if model_lang not in LANGUAGES_WITHOUT_SPACES:
             t_words = transcription.split(' ')
@@ -400,9 +429,33 @@ def align(
             segment['word-level'].append({"text": segment['text'], "start": segment['start'], "end":segment['end']})
             word_segments_list.append({"text": segment['text'], "start": segment['start'], "end":segment['end']})
 
+        if 'vad' in segment:
+            curr_vdx = 0
+            curr_text = ''
+            for wrd_seg in word_segments_list:
+                if wrd_seg['start'] > segment['vad'][curr_vdx][1]:
+                    curr_speaker = segment['speakers'][curr_vdx]
+                    vad_segments_list.append(
+                        {'start': segment['vad'][curr_vdx][0],
+                         'end': segment['vad'][curr_vdx][1],
+                         'text': f"[{curr_speaker}]: " + curr_text.strip()}
+                    )
+                    curr_vdx += 1
+                    curr_text = ''
+                curr_text += ' ' + wrd_seg['text']
+            if len(curr_text) > 0:
+                curr_speaker = segment['speakers'][curr_vdx]
+                vad_segments_list.append(
+                    {'start': segment['vad'][curr_vdx][0],
+                     'end': segment['vad'][curr_vdx][1],
+                     'text': f"[{curr_speaker}]: " + curr_text.strip()}
+                )
+                curr_text = ''
+        total_word_segments_list += word_segments_list
         print(f"[{format_timestamp(segment['start'])} --> {format_timestamp(segment['end'])}] {segment['text']}")
 
-    return {"segments": transcript, "word_segments": word_segments_list}
+
+    return {"segments": transcript, "word_segments": total_word_segments_list, "vad_segments": vad_segments_list}
 
 def load_align_model(language_code, device, model_name=None):
     if model_name is None:
@@ -439,6 +492,91 @@ def load_align_model(language_code, device, model_name=None):
     return align_model, align_metadata
 
+def merge_chunks(segments, chunk_size=CHUNK_LENGTH, speakers=False):
+    '''
+    Merge VAD segments into larger segments of size ~CHUNK_LENGTH.
+    '''
+
+    curr_start = 0
+    curr_end = 0
+    merged_segments = []
+    seg_idxs = []
+    speaker_idxs = []
+    for sdx, seg in enumerate(segments):
+        if seg.end - curr_start > chunk_size and curr_end-curr_start > 0:
+            merged_segments.append({
+                "start": curr_start,
+                "end": curr_end,
+                "segments": seg_idxs,
+                "speakers": speaker_idxs,
+            })
+            curr_start = seg.start
+            seg_idxs = []
+            speaker_idxs = []
+        curr_end = seg.end
+        seg_idxs.append((seg.start, seg.end))
+        speaker_idxs.append(seg.speaker)
+    # add final
+    merged_segments.append({
+        "start": curr_start,
+        "end": curr_end,
+        "segments": seg_idxs,
+        "speakers": speaker_idxs
+    })
+    return merged_segments
+
+
+
+def transcribe_segments(
+    model: "Whisper",
+    audio: Union[str, np.ndarray, torch.Tensor],
+    merged_segments,
+    mel = None,
+    **kwargs
+):
+    '''
+    Transcribe according to predefined VAD segments.
+    '''
+
+    if mel is None:
+        mel = log_mel_spectrogram(audio)
+
+    prev = 0
+
+    output = {'segments': []}
+
+    for sdx, seg_t in enumerate(merged_segments):
+        print(sdx, seg_t['start'], seg_t['end'], '...')
+        seg_f_start, seg_f_end = int(seg_t['start'] * SAMPLE_RATE / HOP_LENGTH), int(seg_t['end'] * SAMPLE_RATE / HOP_LENGTH)
+        local_f_start, local_f_end = seg_f_start - prev, seg_f_end - prev
+        mel = mel[:, local_f_start:] # seek forward
+        prev = seg_f_start
+        local_mel = mel[:, :local_f_end-local_f_start]
+        result = transcribe(model, audio, mel=local_mel, **kwargs)
+        seg_t['text'] = result['text']
+        output['segments'].append(
+            {
+                'start': seg_t['start'],
+                'end': seg_t['end'],
+                'language': result['language'],
+                'text': result['text'],
+                'seg-text': [x['text'] for x in result['segments']],
+                'seg-start': [x['start'] for x in result['segments']],
+                'seg-end': [x['end'] for x in result['segments']],
+            }
+        )
+
+    output['language'] = output['segments'][0]['language']
+
+    return output
+
+class Segment:
+    def __init__(self, start, end, speaker=None):
+        self.start = start
+        self.end = end
+        self.speaker = speaker
+
+
 
 def cli():
     from . import available_models
 
@@ -452,7 +590,8 @@ def cli():
     parser.add_argument("--align_extend", default=2, type=float, help="Seconds before and after to extend the whisper segments for alignment")
     parser.add_argument("--align_from_prev", default=True, type=bool, help="Whether to clip the alignment start time of current segment to the end time of the last aligned word of the previous segment")
     parser.add_argument("--drop_non_aligned", action="store_true", help="For word .srt, whether to drop non aliged words, or merge them into neighbouring.")
-
+    parser.add_argument("--vad_filter", action="store_true", help="Whether to first perform VAD filtering, so that transcription is only applied within VAD segments")
+    parser.add_argument("--vad_input", default=None, type=str, help="Path to a precomputed VAD/diarization output in RTTM format")
 
     parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
     parser.add_argument("--output_type", default="srt", choices=['all', 'srt', 'vtt', 'txt'], help="File type for desired output save")
@@ -490,6 +629,17 @@ def cli():
     align_from_prev: bool = args.pop("align_from_prev")
     drop_non_aligned: bool = args.pop("drop_non_aligned")
 
+    vad_filter: bool = args.pop("vad_filter")
+    vad_input: bool = args.pop("vad_input")
+
+    vad_pipeline = None
+    if vad_input is not None:
+        vad_input = pd.read_csv(vad_input, header=None, sep= " ")
+    elif vad_filter:
+        from pyannote.audio import Pipeline
+        vad_pipeline = Pipeline.from_pretrained("pyannote/voice-activity-detection")
+        # vad_pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization@2.1")
+
     os.makedirs(output_dir, exist_ok=True)
 
     if model_name.endswith(".en") and args["language"] not in {"en", "English"}:
@@ -515,8 +665,26 @@ def cli():
     align_model, align_metadata = load_align_model(align_language, device, model_name=align_model)
 
     for audio_path in args.pop("audio"):
-        print("Performing transcription...")
-        result = transcribe(model, audio_path, temperature=temperature, **args)
+        if vad_filter or vad_input is not None:
+            output_segments = []
+            if vad_filter:
+                print("Performing VAD...")
+                # vad_segments = vad_pipeline(audio_path)
+                # for speech_turn, track, speaker in vad_segments.itertracks(yield_label=True):
+                #     output_segments.append(Segment(speech_turn.start, speech_turn.end, speaker))
+                vad_segments = vad_pipeline(audio_path)
+                for speech_turn in vad_segments.get_timeline().support():
+                    output_segments.append(Segment(speech_turn.start, speech_turn.end, "UNKNOWN"))
+            elif vad_input is not None:
+                # rttm format
+                for idx, row in vad_input.iterrows():
+                    output_segments.append(Segment(row[3], row[3]+row[4], f"SPEAKER {row[7]}"))
+            vad_segments = merge_chunks(output_segments)
+            result = transcribe_segments(model, audio_path, merged_segments=vad_segments, temperature=temperature, **args)
+        else:
+            vad_segments = None
+            print("Performing transcription...")
+            result = transcribe(model, audio_path, temperature=temperature, **args)
 
         if result["language"] != align_metadata["language"]:
             # load new language
@@ -548,8 +716,13 @@ def cli():
             write_srt(result_aligned["word_segments"], file=srt)
 
         # save ASS
-        # with open(os.path.join(output_dir, audio_basename + ".ass"), "w", encoding="utf-8") as ass:
-        #     write_ass(result_aligned["segments"], file=ass)
+        with open(os.path.join(output_dir, audio_basename + ".ass"), "w", encoding="utf-8") as ass:
+            write_ass(result_aligned["segments"], file=ass)
+
+        if vad_filter or vad_input is not None:
+            # save VAD-segment SRT
+            with open(os.path.join(output_dir, audio_basename + ".vad.srt"), "w", encoding="utf-8") as srt:
+                write_srt(result_aligned["vad_segments"], file=srt)
 
 
 if __name__ == '__main__':
diff --git a/whisperx/utils.py b/whisperx/utils.py
index e7c8c33..56e3483 100644
--- a/whisperx/utils.py
+++ b/whisperx/utils.py
@@ -193,7 +193,7 @@ def write_ass(transcript: Iterator[dict], file: TextIO,
         curr_words = [wrd['text'] for wrd in segment['word-level']]
         prev = segment['word-level'][0]['start']
         if prev is None:
-            prev = 0
+            prev = segment['start']
        for wdx, word in enumerate(segment['word-level']):
             if word['start'] is not None:
                 # fill gap between previous word
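For context, here is a rough sketch (not part of the diff) of how the pieces added in this patch (`Segment`, `merge_chunks`, `transcribe_segments`) fit together outside of `cli()`. It assumes the helpers stay module-level in `whisperx/transcribe.py` and that `pyannote.audio` with the `pyannote/voice-activity-detection` pipeline named in the patch is available:

```python
import torch
import whisperx
from whisperx.transcribe import Segment, merge_chunks, transcribe_segments
from pyannote.audio import Pipeline

device = "cuda" if torch.cuda.is_available() else "cpu"
audio_path = "audio.mp3"
model = whisperx.load_model("large", device)

# 1. Voice-activity detection: collect speech turns as Segment objects,
#    mirroring the --vad_filter branch in cli().
vad_pipeline = Pipeline.from_pretrained("pyannote/voice-activity-detection")
speech = vad_pipeline(audio_path)
segments = [Segment(turn.start, turn.end, "UNKNOWN")
            for turn in speech.get_timeline().support()]

# 2. Greedily merge speech turns into larger ~CHUNK_LENGTH chunks.
merged = merge_chunks(segments)

# 3. Transcribe chunk by chunk, seeking through the mel spectrogram.
result = transcribe_segments(model, audio_path, merged_segments=merged)
print(result["segments"])
```

`merge_chunks` accumulates speech turns until the next one would push a chunk past `chunk_size`, so each `transcribe` call inside `transcribe_segments` only sees a mel slice of roughly `chunk_size` seconds.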