From d395c21b8399cb2f29643a75f91469917cdbb991 Mon Sep 17 00:00:00 2001 From: Max Bain Date: Tue, 24 Jan 2023 15:02:08 +0000 Subject: [PATCH] new logic, diarization, vad filtering --- README.md | 17 +- requirements.txt | 1 + whisperx/__init__.py | 2 +- whisperx/audio.py | 4 +- whisperx/model.py | 15 +- whisperx/normalizers/english.json | 1 - whisperx/transcribe.py | 608 ++++++++++++++++++++---------- whisperx/utils.py | 110 ++++-- 8 files changed, 498 insertions(+), 260 deletions(-) diff --git a/README.md b/README.md index f89918a..fb08e80 100644 --- a/README.md +++ b/README.md @@ -48,6 +48,13 @@ This repository refines the timestamps of openAI's Whisper model via forced alig **Forced Alignment** refers to the process by which orthographic transcriptions are aligned to audio recordings to automatically generate phone level segmentation. +

New 🚨

+
+- VAD filtering: Voice Activity Detection (VAD) from [Pyannote.audio](https://huggingface.co/pyannote/voice-activity-detection) is used as a preprocessing step to remove reliance on Whisper timestamps and to transcribe only the audio segments that contain speech. Add the `--vad_filter` flag; it increases timestamp accuracy and robustness (requires more GPU memory, due to the 30s inputs fed to wav2vec2).
+- Character-level timestamps (see the `*.char.ass` file output)
+- Diarization (still in beta, add `--diarize`); see the combined example below.
+
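+For example, to run VAD filtering and diarization in a single pass (`--min_speakers`/`--max_speakers` are optional hints for the diarization pipeline; the audio path is the bundled sample):
+
+    whisperx examples/sample01.wav --model medium.en --vad_filter --diarize --min_speakers 2 --max_speakers 2
+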

Setup ⚙️

Install this package using @@ -76,9 +83,9 @@ Run whisper on example segment (using default params) whisperx examples/sample01.wav -For increased timestamp accuracy, at the cost of higher gpu mem, use bigger models e.g. +For increased timestamp accuracy, at the cost of higher gpu mem, use bigger models and VAD filtering e.g. - whisperx examples/sample01.wav --model large.en --align_model WAV2VEC2_ASR_LARGE_LV60K_960H + whisperx examples/sample01.wav --model large.en --vad_filter --align_model WAV2VEC2_ASR_LARGE_LV60K_960H Result using *WhisperX* with forced alignment to wav2vec2.0 large: @@ -162,7 +169,11 @@ The next major upgrade we are working on is whisper with speaker diarization, so [x] ~~Python usage~~ done -[ ] Incorporating word-level speaker diarization +[x] ~~Character level timestamps~~ + +[x] ~~Incorporating speaker diarization~~ + +[ ] Improve diarization (word level) [ ] Inference speedup with batch processing diff --git a/requirements.txt b/requirements.txt index 26a53c3..999a8d6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,3 +6,4 @@ soundfile more-itertools transformers>=4.19.0 ffmpeg-python==0.2.0 +pyannote.audio diff --git a/whisperx/__init__.py b/whisperx/__init__.py index 839b29a..4f253b3 100644 --- a/whisperx/__init__.py +++ b/whisperx/__init__.py @@ -11,7 +11,7 @@ from tqdm import tqdm from .audio import load_audio, log_mel_spectrogram, pad_or_trim from .decoding import DecodingOptions, DecodingResult, decode, detect_language from .model import Whisper, ModelDimensions -from .transcribe import transcribe, load_align_model, align +from .transcribe import transcribe, load_align_model, align, transcribe_with_vad _MODELS = { "tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt", diff --git a/whisperx/audio.py b/whisperx/audio.py index a3d8a13..b6c7e83 100644 --- a/whisperx/audio.py +++ b/whisperx/audio.py @@ -113,7 +113,7 @@ def log_mel_spectrogram(audio: Union[str, np.ndarray, torch.Tensor], n_mels: int window = torch.hann_window(N_FFT).to(audio.device) stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True) - magnitudes = stft[:, :-1].abs() ** 2 + magnitudes = stft[..., :-1].abs() ** 2 filters = mel_filters(audio.device, n_mels) mel_spec = filters @ magnitudes @@ -121,4 +121,4 @@ def log_mel_spectrogram(audio: Union[str, np.ndarray, torch.Tensor], n_mels: int log_spec = torch.clamp(mel_spec, min=1e-10).log10() log_spec = torch.maximum(log_spec, log_spec.max() - 8.0) log_spec = (log_spec + 4.0) / 4.0 - return log_spec + return log_spec \ No newline at end of file diff --git a/whisperx/model.py b/whisperx/model.py index ca3928e..9bdd84e 100644 --- a/whisperx/model.py +++ b/whisperx/model.py @@ -82,8 +82,8 @@ class MultiHeadAttention(nn.Module): k = kv_cache[self.key] v = kv_cache[self.value] - wv = self.qkv_attention(q, k, v, mask) - return self.out(wv) + wv, qk = self.qkv_attention(q, k, v, mask) + return self.out(wv), qk def qkv_attention(self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None): n_batch, n_ctx, n_state = q.shape @@ -95,9 +95,10 @@ class MultiHeadAttention(nn.Module): qk = q @ k if mask is not None: qk = qk + mask[:n_ctx, :n_ctx] + qk = qk.float() - w = F.softmax(qk.float(), dim=-1).to(q.dtype) - return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2) + w = F.softmax(qk, dim=-1).to(q.dtype) + return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach() class ResidualAttentionBlock(nn.Module): @@ -121,9 +122,9 
@@ class ResidualAttentionBlock(nn.Module): mask: Optional[Tensor] = None, kv_cache: Optional[dict] = None, ): - x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache) + x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)[0] if self.cross_attn: - x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache) + x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)[0] x = x + self.mlp(self.mlp_ln(x)) return x @@ -264,4 +265,4 @@ class Whisper(nn.Module): detect_language = detect_language_function transcribe = transcribe_function - decode = decode_function + decode = decode_function \ No newline at end of file diff --git a/whisperx/normalizers/english.json b/whisperx/normalizers/english.json index bd84ae7..74a1c35 100644 --- a/whisperx/normalizers/english.json +++ b/whisperx/normalizers/english.json @@ -1737,6 +1737,5 @@ "yoghurt": "yogurt", "yoghurts": "yogurts", "mhm": "hmm", - "mm": "hmm", "mmm": "hmm" } \ No newline at end of file diff --git a/whisperx/transcribe.py b/whisperx/transcribe.py index 9ff5a64..ba92d3d 100644 --- a/whisperx/transcribe.py +++ b/whisperx/transcribe.py @@ -12,7 +12,7 @@ from .audio import SAMPLE_RATE, N_FRAMES, HOP_LENGTH, CHUNK_LENGTH, pad_or_trim, from .alignment import get_trellis, backtrack, merge_repeats, merge_words from .decoding import DecodingOptions, DecodingResult from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer -from .utils import exact_div, format_timestamp, optional_int, optional_float, str2bool, write_txt, write_vtt, write_srt, write_ass +from .utils import exact_div, format_timestamp, optional_int, optional_float, str2bool, interpolate_nans, write_txt, write_vtt, write_srt, write_ass, write_tsv import pandas as pd if TYPE_CHECKING: @@ -280,8 +280,39 @@ def align( device: str, extend_duration: float = 0.0, start_from_previous: bool = True, - drop_non_aligned_words: bool = False, + interpolate_method: str = "nearest", ): + """ + Force align phoneme recognition predictions to known transcription + + Parameters + ---------- + transcript: Iterator[dict] + The Whisper model instance + + model: torch.nn.Module + Alignment model (wav2vec2) + + audio: Union[str, np.ndarray, torch.Tensor] + The path to the audio file to open, or the audio waveform + + device: str + cuda device + + extend_duration: float + Amount to pad input segments by. If not using vad--filter then recommended to use 2 seconds + + If the gzip compression ratio is above this value, treat as failed + + interpolate_method: str ["nearest", "linear", "ignore"] + Method to assign timestamps to non-aligned words. Words are not able to be aligned when none of the characters occur in the align model dictionary. + "nearest" copies timestamp of nearest word within the segment. "linear" is linear interpolation. "drop" removes that word from output. + + Returns + ------- + A dictionary containing the resulting text ("text") and segment-level details ("segments"), and + the spoken language ("language"), which is detected when `decode_options["language"]` is None. 
+ """ if not torch.is_tensor(audio): if isinstance(audio, str): audio = load_audio(audio) @@ -291,171 +322,266 @@ def align( MAX_DURATION = audio.shape[1] / SAMPLE_RATE - model_dictionary = align_model_metadata['dictionary'] - model_lang = align_model_metadata['language'] - model_type = align_model_metadata['type'] + model_dictionary = align_model_metadata["dictionary"] + model_lang = align_model_metadata["language"] + model_type = align_model_metadata["type"] + + aligned_segments = [] prev_t2 = 0 - total_word_segments_list = [] - vad_segments_list = [] - for idx, segment in enumerate(transcript): - word_segments_list = [] - # first we pad - t1 = max(segment['start'] - extend_duration, 0) - t2 = min(segment['end'] + extend_duration, MAX_DURATION) + sdx = 0 + for segment in transcript: + while True: + segment_align_success = False - # use prev_t2 as current t1 if it's later - if start_from_previous and t1 < prev_t2: - t1 = prev_t2 + # strip spaces at beginning / end, but keep track of the amount. + num_leading = len(segment["text"]) - len(segment["text"].lstrip()) + num_trailing = len(segment["text"]) - len(segment["text"].rstrip()) + transcription = segment["text"] - # check if timestamp range is still valid - if t1 >= MAX_DURATION: - print("Failed to align segment: original start time longer than audio duration, skipping...") - continue - if t2 - t1 < 0.02: - print("Failed to align segment: duration smaller than 0.02s time precision") - continue + # TODO: convert number tokenizer / symbols to phonetic words for alignment. + # e.g. "$300" -> "three hundred dollars" + # currently "$300" is ignored since no characters present in the phonetic dictionary - f1 = int(t1 * SAMPLE_RATE) - f2 = int(t2 * SAMPLE_RATE) - - waveform_segment = audio[:, f1:f2] - - with torch.inference_mode(): - if model_type == "torchaudio": - emissions, _ = model(waveform_segment.to(device)) - elif model_type == "huggingface": - emissions = model(waveform_segment.to(device)).logits + # split into words + if model_lang not in LANGUAGES_WITHOUT_SPACES: + per_word = transcription.split(" ") else: - raise NotImplementedError(f"Align model of type {model_type} not supported.") - emissions = torch.log_softmax(emissions, dim=-1) + per_word = transcription - emission = emissions[0].cpu().detach() + # first check that characters in transcription can be aligned (they are contained in align model"s dictionary) + clean_char, clean_cdx = [], [] + for cdx, char in enumerate(transcription): + char_ = char.lower() + # wav2vec2 models use "|" character to represent spaces + if model_lang not in LANGUAGES_WITHOUT_SPACES: + char_ = char_.replace(" ", "|") + + # ignore whitespace at beginning and end of transcript + if cdx < num_leading: + pass + elif cdx > len(transcription) - num_trailing - 1: + pass + elif char_ in model_dictionary.keys(): + clean_char.append(char_) + clean_cdx.append(cdx) - if "vad" in segment and len(segment['vad']) > 1 and '|' in model_dictionary: - ratio = waveform_segment.size(0) / emission.size(0) - space_idx = model_dictionary['|'] - # find non-vad segments - for i in range(1, len(segment['vad'])): - start = segment['vad'][i-1][1] - end = segment['vad'][i][0] - if start < end: # check if there is a gap between intervals - non_vad_f1 = int(start / ratio) - non_vad_f2 = int(end / ratio) - # non-vad should be masked, use space to do so - emission[non_vad_f1:non_vad_f2, :] = float("-inf") - emission[non_vad_f1:non_vad_f2, space_idx] = 0 + clean_wdx = [] + for wdx, wrd in enumerate(per_word): + if any([c in 
model_dictionary.keys() for c in wrd]): + clean_wdx.append(wdx) - - start = segment['vad'][i][1] - end = segment['end'] - non_vad_f1 = int(start / ratio) - non_vad_f2 = int(end / ratio) - # non-vad should be masked, use space to do so - emission[non_vad_f1:non_vad_f2, :] = float("-inf") - emission[non_vad_f1:non_vad_f2, space_idx] = 0 - - transcription = segment['text'].strip() - if model_lang not in LANGUAGES_WITHOUT_SPACES: - t_words = transcription.split(' ') - else: - t_words = [c for c in transcription] - - t_words_clean = [''.join([w for w in word if w.lower() in model_dictionary.keys()]) for word in t_words] - t_words_nonempty = [x for x in t_words_clean if x != ""] - t_words_nonempty_idx = [x for x in range(len(t_words_clean)) if t_words_clean[x] != ""] - segment['word-level'] = [] - - fail_fallback = False - if len(t_words_nonempty) > 0: - transcription_cleaned = "|".join(t_words_nonempty).lower() + # if no characters are in the dictionary, then we skip this segment... + if len(clean_char) == 0: + print("Failed to align segment: no characters in this segment found in model dictionary, resorting to original...") + break + + transcription_cleaned = "".join(clean_char) tokens = [model_dictionary[c] for c in transcription_cleaned] + + # pad according original timestamps + t1 = max(segment["start"] - extend_duration, 0) + t2 = min(segment["end"] + extend_duration, MAX_DURATION) + + # use prev_t2 as current t1 if it"s later + if start_from_previous and t1 < prev_t2: + t1 = prev_t2 + + # check if timestamp range is still valid + if t1 >= MAX_DURATION: + print("Failed to align segment: original start time longer than audio duration, skipping...") + break + if t2 - t1 < 0.02: + print("Failed to align segment: duration smaller than 0.02s time precision") + break + + f1 = int(t1 * SAMPLE_RATE) + f2 = int(t2 * SAMPLE_RATE) + + waveform_segment = audio[:, f1:f2] + + with torch.inference_mode(): + if model_type == "torchaudio": + emissions, _ = model(waveform_segment.to(device)) + elif model_type == "huggingface": + emissions = model(waveform_segment.to(device)).logits + else: + raise NotImplementedError(f"Align model of type {model_type} not supported.") + emissions = torch.log_softmax(emissions, dim=-1) + + emission = emissions[0].cpu().detach() + trellis = get_trellis(emission, tokens) path = backtrack(trellis, emission, tokens) if path is None: print("Failed to align segment: backtrack failed, resorting to original...") - fail_fallback = True - else: - segments = merge_repeats(path, transcription_cleaned) - word_segments = merge_words(segments) - ratio = waveform_segment.size(0) / (trellis.size(0) - 1) + break + char_segments = merge_repeats(path, transcription_cleaned) + # word_segments = merge_words(char_segments) + - duration = t2 - t1 - local = [] - t_local = [None] * len(t_words) - for wdx, word in enumerate(word_segments): - t1_ = ratio * word.start - t2_ = ratio * word.end - local.append((t1_, t2_)) - t_local[t_words_nonempty_idx[wdx]] = (t1_ * duration + t1, t2_ * duration + t1) - t1_actual = t1 + local[0][0] * duration - t2_actual = t1 + local[-1][1] * duration + # sub-segments + if "seg-text" not in segment: + segment["seg-text"] = [transcription] + + v = 0 + seg_lens = [0] + [len(x) for x in segment["seg-text"]] + seg_lens_cumsum = [v := v + n for n in seg_lens] + sub_seg_idx = 0 - segment['start'] = t1_actual - segment['end'] = t2_actual - prev_t2 = segment['end'] + char_level = { + "start": [], + "end": [], + "score": [], + "word-index": [], + } - # for the .ass output - for x 
in range(len(t_local)): - curr_word = t_words[x] - curr_timestamp = t_local[x] - if curr_timestamp is not None: - segment['word-level'].append({"text": curr_word, "start": curr_timestamp[0], "end": curr_timestamp[1]}) + word_level = { + "start": [], + "end": [], + "score": [], + "segment-text-start": [], + "segment-text-end": [] + } + + wdx = 0 + seg_start_actual, seg_end_actual = None, None + duration = t2 - t1 + ratio = duration * waveform_segment.size(0) / (trellis.size(0) - 1) + cdx_prev = 0 + for cdx, char in enumerate(transcription + " "): + is_last = False + if cdx == len(transcription): + break + elif cdx+1 == len(transcription): + is_last = True + + + start, end, score = None, None, None + if cdx in clean_cdx: + char_seg = char_segments[clean_cdx.index(cdx)] + start = char_seg.start * ratio + t1 + end = char_seg.end * ratio + t1 + score = char_seg.score + + char_level["start"].append(start) + char_level["end"].append(end) + char_level["score"].append(score) + char_level["word-index"].append(wdx) + + # word-level info + if model_lang in LANGUAGES_WITHOUT_SPACES: + # character == word + wdx += 1 + elif is_last or transcription[cdx+1] == " " or cdx == seg_lens_cumsum[sub_seg_idx+1] - 1: + wdx += 1 + word_level["start"].append(None) + word_level["end"].append(None) + word_level["score"].append(None) + word_level["segment-text-start"].append(cdx_prev-seg_lens_cumsum[sub_seg_idx]) + word_level["segment-text-end"].append(cdx+1-seg_lens_cumsum[sub_seg_idx]) + cdx_prev = cdx+2 + + if is_last or cdx == seg_lens_cumsum[sub_seg_idx+1] - 1: + if model_lang not in LANGUAGES_WITHOUT_SPACES: + char_level = pd.DataFrame(char_level) + word_level = pd.DataFrame(word_level) + + not_space = pd.Series(list(segment["seg-text"][sub_seg_idx])) != " " + word_level["start"] = char_level[not_space].groupby("word-index")["start"].min() # take min of all chars in a word ignoring space + word_level["end"] = char_level[not_space].groupby("word-index")["end"].max() # take max of all chars in a word + + # fill missing + if interpolate_method != "ignore": + word_level["start"] = interpolate_nans(word_level["start"], method=interpolate_method) + word_level["end"] = interpolate_nans(word_level["end"], method=interpolate_method) + word_level["start"] = word_level["start"].values.tolist() + word_level["end"] = word_level["end"].values.tolist() + word_level["score"] = char_level.groupby("word-index")["score"].mean() # take mean of all scores + + char_level = char_level.replace({np.nan:None}).to_dict("list") + word_level = pd.DataFrame(word_level).replace({np.nan:None}).to_dict("list") else: - segment['word-level'].append({"text": curr_word, "start": None, "end": None}) + word_level = None - # for per-word .srt ouput - # merge missing words to previous, or merge with next word ahead if idx == 0 - found_first_ts = False - for x in range(len(t_local)): - curr_word = t_words[x] - curr_timestamp = t_local[x] - if curr_timestamp is not None: - word_segments_list.append({"text": curr_word, "start": curr_timestamp[0], "end": curr_timestamp[1]}) - found_first_ts = True - elif not drop_non_aligned_words: - # then we merge - if not found_first_ts: - t_words[x+1] = " ".join([curr_word, t_words[x+1]]) - else: - word_segments_list[-1]['text'] += ' ' + curr_word - else: - fail_fallback = True - - if fail_fallback: - # then we resort back to original whisper timestamps - # segment['start] and segment['end'] are unchanged - prev_t2 = 0 - segment['word-level'].append({"text": segment['text'], "start": segment['start'], 
"end":segment['end']}) - word_segments_list.append({"text": segment['text'], "start": segment['start'], "end":segment['end']}) - - if 'vad' in segment: - curr_vdx = 0 - curr_text = '' - for wrd_seg in word_segments_list: - if wrd_seg['start'] > segment['vad'][curr_vdx][1]: - curr_speaker = segment['speakers'][curr_vdx] - vad_segments_list.append( - {'start': segment['vad'][curr_vdx][0], - 'end': segment['vad'][curr_vdx][1], - 'text': f"[{curr_speaker}]: " + curr_text.strip()} + aligned_segments.append( + { + "text": segment["seg-text"][sub_seg_idx], + "start": seg_start_actual, + "end": seg_end_actual, + "char-segments": char_level, + "word-segments": word_level + } ) - curr_vdx += 1 - curr_text = '' - curr_text += ' ' + wrd_seg['text'] - if len(curr_text) > 0: - curr_speaker = segment['speakers'][curr_vdx] - vad_segments_list.append( - {'start': segment['vad'][curr_vdx][0], - 'end': segment['vad'][curr_vdx][1], - 'text': f"[{curr_speaker}]: " + curr_text.strip()} - ) - curr_text = '' - total_word_segments_list += word_segments_list - print(f"[{format_timestamp(segment['start'])} --> {format_timestamp(segment['end'])}] {segment['text']}") + if "language" in segment: + aligned_segments[-1]["language"] = segment["language"] + + print(f"[{format_timestamp(aligned_segments[-1]['start'])} --> {format_timestamp(aligned_segments[-1]['end'])}] {aligned_segments[-1]['text']}") - return {"segments": transcript, "word_segments": total_word_segments_list, "vad_segments": vad_segments_list} + char_level = { + "start": [], + "end": [], + "score": [], + "word-index": [], + } + word_level = { + "start": [], + "end": [], + "score": [], + "segment-text-start": [], + "segment-text-end": [] + } + wdx = 0 + cdx_prev = cdx + 2 + sub_seg_idx += 1 + seg_start_actual, seg_end_actual = None, None + + + # take min-max for actual segment-level timestamp + if seg_start_actual is None and start is not None: + seg_start_actual = start + if end is not None: + seg_end_actual = end + + + prev_t2 = segment["end"] + + segment_align_success = True + # end while True loop + break + + # reset prev_t2 due to drifting issues + if not segment_align_success: + prev_t2 = 0 + + # shift segment index by amount of sub-segments + if "seg-text" in segment: + sdx += len(segment["seg-text"]) + else: + sdx += 1 + + # create word level segments for .srt + word_seg = [] + for seg in aligned_segments: + if model_lang in LANGUAGES_WITHOUT_SPACES: + # character based + seg["word-segments"] = seg["char-segments"] + seg["word-segments"]["segment-text-start"] = range(len(seg['word-segments']['start'])) + seg["word-segments"]["segment-text-end"] = range(1, len(seg['word-segments']['start'])+1) + + wseg = pd.DataFrame(seg["word-segments"]).replace({np.nan:None}) + for wdx, wrow in wseg.iterrows(): + if wrow["start"] is not None: + word_seg.append( + { + "start": wrow["start"], + "end": wrow["end"], + "text": seg["text"][int(wrow["segment-text-start"]):int(wrow["segment-text-end"])] + } + ) + + return {"segments": aligned_segments, "word_segments": word_seg} def load_align_model(language_code, device, model_name=None): if model_name is None: @@ -492,11 +618,11 @@ def load_align_model(language_code, device, model_name=None): return align_model, align_metadata -def merge_chunks(segments, chunk_size=CHUNK_LENGTH, speakers=False): - ''' - Merge VAD segments into larger segments of size ~CHUNK_LENGTH. - ''' +def merge_chunks(segments, chunk_size=CHUNK_LENGTH): + """ + Merge VAD segments into larger segments of size ~CHUNK_LENGTH. 
+ """ curr_start = 0 curr_end = 0 merged_segments = [] @@ -508,7 +634,6 @@ def merge_chunks(segments, chunk_size=CHUNK_LENGTH, speakers=False): "start": curr_start, "end": curr_end, "segments": seg_idxs, - "speakers": speaker_idxs, }) curr_start = seg.start seg_idxs = [] @@ -521,55 +646,107 @@ def merge_chunks(segments, chunk_size=CHUNK_LENGTH, speakers=False): "start": curr_start, "end": curr_end, "segments": seg_idxs, - "speakers": speaker_idxs }) return merged_segments - -def transcribe_segments( +def transcribe_with_vad( model: "Whisper", audio: Union[str, np.ndarray, torch.Tensor], - merged_segments, + vad_pipeline, mel = None, + verbose: Optional[bool] = None, **kwargs ): - ''' - Transcribe according to predefined VAD segments. - ''' + """ + Transcribe per VAD segment + """ if mel is None: mel = log_mel_spectrogram(audio) prev = 0 + output = {"segments": []} - output = {'segments': []} + vad_segments_list = [] + vad_segments = vad_pipeline(audio) + for speech_turn in vad_segments.get_timeline().support(): + vad_segments_list.append(Segment(speech_turn.start, speech_turn.end, "UNKNOWN")) + # merge segments to approx 30s inputs to make whisper most appropraite + vad_segments = merge_chunks(vad_segments_list) - for sdx, seg_t in enumerate(merged_segments): - print(sdx, seg_t['start'], seg_t['end'], '...') - seg_f_start, seg_f_end = int(seg_t['start'] * SAMPLE_RATE / HOP_LENGTH), int(seg_t['end'] * SAMPLE_RATE / HOP_LENGTH) + for sdx, seg_t in enumerate(vad_segments): + if verbose: + print(f"~~ Transcribing VAD chunk: ({format_timestamp(seg_t['start'])} --> {format_timestamp(seg_t['end'])}) ~~") + seg_f_start, seg_f_end = int(seg_t["start"] * SAMPLE_RATE / HOP_LENGTH), int(seg_t["end"] * SAMPLE_RATE / HOP_LENGTH) local_f_start, local_f_end = seg_f_start - prev, seg_f_end - prev mel = mel[:, local_f_start:] # seek forward prev = seg_f_start local_mel = mel[:, :local_f_end-local_f_start] - result = transcribe(model, audio, mel=local_mel, **kwargs) - seg_t['text'] = result['text'] - output['segments'].append( + result = transcribe(model, audio, mel=local_mel, verbose=verbose, **kwargs) + seg_t["text"] = result["text"] + output["segments"].append( { - 'start': seg_t['start'], - 'end': seg_t['end'], - 'language': result['language'], - 'text': result['text'], - 'seg-text': [x['text'] for x in result['segments']], - 'seg-start': [x['start'] for x in result['segments']], - 'seg-end': [x['end'] for x in result['segments']], + "start": seg_t["start"], + "end": seg_t["end"], + "language": result["language"], + "text": result["text"], + "seg-text": [x["text"] for x in result["segments"]], + "seg-start": [x["start"] for x in result["segments"]], + "seg-end": [x["end"] for x in result["segments"]], } ) - output['language'] = output['segments'][0]['language'] + output["language"] = output["segments"][0]["language"] return output + +def assign_word_speakers(diarize_df, result_segments, fill_nearest=False): + + for seg in result_segments: + wdf = pd.DataFrame(seg['word-segments']) + if len(wdf['start'].dropna()) == 0: + wdf['start'] = seg['start'] + wdf['end'] = seg['end'] + speakers = [] + for wdx, wrow in wdf.iterrows(): + diarize_df['intersection'] = np.minimum(diarize_df['end'], wrow['end']) - np.maximum(diarize_df['start'], wrow['start']) + diarize_df['union'] = np.maximum(diarize_df['end'], wrow['end']) - np.minimum(diarize_df['start'], wrow['start']) + # remove no hit + if not fill_nearest: + dia_tmp = diarize_df[diarize_df['intersection'] > 0] + else: + dia_tmp = diarize_df + if len(dia_tmp) == 
0: + speaker = None + else: + speaker = dia_tmp.sort_values("intersection", ascending=False).iloc[0][2] + speakers.append(speaker) + seg['word-segments']['speaker'] = speakers + seg["speaker"] = pd.Series(speakers).value_counts().index[0] + + # create word level segments for .srt + word_seg = [] + for seg in result_segments: + wseg = pd.DataFrame(seg["word-segments"]) + for wdx, wrow in wseg.iterrows(): + if wrow["start"] is not None: + speaker = wrow['speaker'] + if speaker is None or speaker == np.nan: + speaker = "UNKNOWN" + word_seg.append( + { + "start": wrow["start"], + "end": wrow["end"], + "text": f"[{speaker}]: " + seg["text"][int(wrow["segment-text-start"]):int(wrow["segment-text-end"])] + } + ) + + # TODO: create segments but split words on new speaker + + return result_segments, word_seg + class Segment: def __init__(self, start, end, speaker=None): self.start = start @@ -589,11 +766,17 @@ def cli(): parser.add_argument("--align_model", default=None, help="Name of phoneme-level ASR model to do alignment") parser.add_argument("--align_extend", default=2, type=float, help="Seconds before and after to extend the whisper segments for alignment") parser.add_argument("--align_from_prev", default=True, type=bool, help="Whether to clip the alignment start time of current segment to the end time of the last aligned word of the previous segment") - parser.add_argument("--drop_non_aligned", action="store_true", help="For word .srt, whether to drop non aliged words, or merge them into neighbouring.") - parser.add_argument("--vad_filter", action="store_true", help="Whether to first perform VAD filtering to target only transcribe within VAD...") + parser.add_argument("--interpolate_method", default="nearest", choices=["nearest", "linear", "ignore"], help="For word .srt, method to assign timestamps to non-aligned words, or merge them into neighbouring.") + # vad params + parser.add_argument("--vad_filter", action="store_true", help="Whether to first perform VAD filtering to target only transcribe within VAD. 
Produces more accurate alignment + timestamp, requires more GPU memory & compute.") parser.add_argument("--vad_input", default=None, type=str) + # diarization params + parser.add_argument("--diarize", action='store_true') + parser.add_argument("--min_speakers", default=None, type=int) + parser.add_argument("--max_speakers", default=None, type=int) + # output save params parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs") - parser.add_argument("--output_type", default="srt", choices=['all', 'srt', 'vtt', 'txt'], help="File type for desired output save") + parser.add_argument("--output_type", default="all", choices=["all", "srt", "srt-word", "vtt", "txt", "tsv", "ass", "ass-char"], help="File type for desired output save") parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages") @@ -627,24 +810,32 @@ def cli(): align_model: str = args.pop("align_model") align_extend: float = args.pop("align_extend") align_from_prev: bool = args.pop("align_from_prev") - drop_non_aligned: bool = args.pop("drop_non_aligned") + interpolate_method: bool = args.pop("interpolate_method") vad_filter: bool = args.pop("vad_filter") vad_input: bool = args.pop("vad_input") + diarize: bool = args.pop("diarize") + min_speakers: int = args.pop("min_speakers") + max_speakers: int = args.pop("max_speakers") + vad_pipeline = None if vad_input is not None: vad_input = pd.read_csv(vad_input, header=None, sep= " ") elif vad_filter: from pyannote.audio import Pipeline vad_pipeline = Pipeline.from_pretrained("pyannote/voice-activity-detection") - # vad_pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization@2.1") + + diarize_pipeline = None + if diarize: + from pyannote.audio import Pipeline + diarize_pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization@2.1") os.makedirs(output_dir, exist_ok=True) if model_name.endswith(".en") and args["language"] not in {"en", "English"}: if args["language"] is not None: - warnings.warn(f"{model_name} is an English-only model but receipted '{args['language']}'; using English instead.") + warnings.warn(f'{model_name} is an English-only model but receipted "{args["language"]}"; using English instead.') args["language"] = "en" temperature = args.pop("temperature") @@ -665,24 +856,10 @@ def cli(): align_model, align_metadata = load_align_model(align_language, device, model_name=align_model) for audio_path in args.pop("audio"): - if vad_filter or vad_input is not None: - output_segments = [] - if vad_filter: - print("Performing VAD...") - # vad_segments = vad_pipeline(audio_path) - # for speech_turn, track, speaker in vad_segments.itertracks(yield_label=True): - # output_segments.append(Segment(speech_turn.start, speech_turn.end, speaker)) - vad_segments = vad_pipeline(audio_path) - for speech_turn in vad_segments.get_timeline().support(): - output_segments.append(Segment(speech_turn.start, speech_turn.end, "UNKNOWN")) - elif vad_input is not None: - # rttm format - for idx, row in vad_input.iterrows(): - output_segments.append(Segment(row[3], row[3]+row[4], f"SPEAKER {row[7]}")) - vad_segments = merge_chunks(output_segments) - result = transcribe_segments(model, audio_path, merged_segments=vad_segments, temperature=temperature, **args) + if vad_filter: + print("Performing VAD...") + result = transcribe_with_vad(model, audio_path, vad_pipeline, temperature=temperature, **args) else: - vad_segments = None print("Performing transcription...") result = 
transcribe(model, audio_path, temperature=temperature, **args) @@ -693,9 +870,20 @@ def cli(): print("Performing alignment...") result_aligned = align(result["segments"], align_model, align_metadata, audio_path, device, - extend_duration=align_extend, start_from_previous=align_from_prev, drop_non_aligned_words=drop_non_aligned) + extend_duration=align_extend, start_from_previous=align_from_prev, interpolate_method=interpolate_method) audio_basename = os.path.basename(audio_path) + if diarize: + print("Performing diarization...") + diarize_segments = diarize_pipeline(audio_path, min_speakers=min_speakers, max_speakers=max_speakers) + diarize_df = pd.DataFrame(diarize_segments.itertracks(yield_label=True)) + diarize_df['start'] = diarize_df[0].apply(lambda x: x.start) + diarize_df['end'] = diarize_df[0].apply(lambda x: x.end) + # assumes each utterance is single speaker (needs fix) + result_segments, word_segments = assign_word_speakers(diarize_df, result_aligned["segments"], fill_nearest=True) + result_aligned["segments"] = result_segments + result_aligned["word_segments"] = word_segments + # save TXT if output_type in ["txt", "all"]: with open(os.path.join(output_dir, audio_basename + ".txt"), "w", encoding="utf-8") as txt: @@ -711,19 +899,27 @@ def cli(): with open(os.path.join(output_dir, audio_basename + ".srt"), "w", encoding="utf-8") as srt: write_srt(result_aligned["segments"], file=srt) - # save per-word SRT - with open(os.path.join(output_dir, audio_basename + ".word.srt"), "w", encoding="utf-8") as srt: - write_srt(result_aligned["word_segments"], file=srt) + # save TSV + if output_type in ["tsv", "all"]: + with open(os.path.join(output_dir, audio_basename + ".srt"), "w", encoding="utf-8") as srt: + write_tsv(result_aligned["segments"], file=srt) + + # save SRT word-level + if output_type in ["srt-word", "all"]: + # save per-word SRT + with open(os.path.join(output_dir, audio_basename + ".word.srt"), "w", encoding="utf-8") as srt: + write_srt(result_aligned["word_segments"], file=srt) # save ASS - with open(os.path.join(output_dir, audio_basename + ".ass"), "w", encoding="utf-8") as ass: - write_ass(result_aligned["segments"], file=ass) + if output_type in ["ass", "all"]: + with open(os.path.join(output_dir, audio_basename + ".ass"), "w", encoding="utf-8") as ass: + write_ass(result_aligned["segments"], file=ass) - if vad_filter is not None: - # save per-word SRT - with open(os.path.join(output_dir, audio_basename + ".vad.srt"), "w", encoding="utf-8") as srt: - write_srt(result_aligned["vad_segments"], file=srt) + # save ASS character-level + if output_type in ["ass-char", "all"]: + with open(os.path.join(output_dir, audio_basename + ".char.ass"), "w", encoding="utf-8") as ass: + write_ass(result_aligned["segments"], file=ass, resolution="char") -if __name__ == '__main__': +if __name__ == "__main__": cli() diff --git a/whisperx/utils.py b/whisperx/utils.py index 56e3483..590eaab 100644 --- a/whisperx/utils.py +++ b/whisperx/utils.py @@ -1,6 +1,7 @@ +import os import zlib -from typing import Iterator, TextIO, Tuple, List - +from typing import Callable, TextIO, Iterator, Tuple +import pandas as pd def exact_div(x, y): assert x % y == 0 @@ -60,6 +61,13 @@ def write_vtt(transcript: Iterator[dict], file: TextIO): flush=True, ) +def write_tsv(transcript: Iterator[dict], file: TextIO): + print("start", "end", "text", sep="\t", file=file) + for segment in transcript: + print(round(1000 * segment['start']), file=file, end="\t") + print(round(1000 * segment['end']), file=file, end="\t") + 
print(segment['text'].strip().replace("\t", " "), file=file, flush=True) + def write_srt(transcript: Iterator[dict], file: TextIO): """ @@ -88,7 +96,9 @@ def write_srt(transcript: Iterator[dict], file: TextIO): ) -def write_ass(transcript: Iterator[dict], file: TextIO, +def write_ass(transcript: Iterator[dict], + file: TextIO, + resolution: str = "word", color: str = None, underline=True, prefmt: str = None, suffmt: str = None, font: str = None, font_size: int = 24, @@ -102,10 +112,12 @@ def write_ass(transcript: Iterator[dict], file: TextIO, Note: ass file is used in the same way as srt, vtt, etc. Parameters ---------- - res: dict + transcript: dict results from modified model - ass_path: str - output path (e.g. caption.ass) + file: TextIO + file object to write to + resolution: str + "word" or "char", timestamp resolution to highlight. color: str color code for a word at its corresponding timestamp reverse order hexadecimal RGB value (e.g. FF0000 is full intensity blue. Default: 00FF00) @@ -176,49 +188,67 @@ def write_ass(transcript: Iterator[dict], file: TextIO, return f'{hh:0>1.0f}:{mm:0>2.0f}:{ss:0>2.2f}' - def dialogue(words: List[str], idx, start, end) -> str: - text = ''.join(f' {prefmt}{word}{suffmt}' - # if not word.startswith(' ') or word == ' ' else - # f' {prefmt}{word.strip()}{suffmt}') - if curr_idx == idx else - f' {word}' - for curr_idx, word in enumerate(words)) + def dialogue(chars: str, start: float, end: float, idx_0: int, idx_1: int) -> str: + if idx_0 == -1: + text = chars + else: + text = f'{chars[:idx_0]}{prefmt}{chars[idx_0:idx_1]}{suffmt}{chars[idx_1:]}' return f"Dialogue: 0,{secs_to_hhmmss(start)},{secs_to_hhmmss(end)}," \ f"Default,,0,0,0,,{text.strip() if strip else text}" - + if resolution == "word": + resolution_key = "word-segments" + elif resolution == "char": + resolution_key = "char-segments" + else: + raise ValueError(".ass resolution should be 'word' or 'char', not ", resolution) + ass_arr = [] for segment in transcript: - curr_words = [wrd['text'] for wrd in segment['word-level']] - prev = segment['word-level'][0]['start'] - if prev is None: + if resolution_key in segment: + res_segs = pd.DataFrame(segment[resolution_key]) prev = segment['start'] - for wdx, word in enumerate(segment['word-level']): - if word['start'] is not None: - # fill gap between previous word - if word['start'] > prev: - filler_ts = { - "words": curr_words, - "start": prev, - "end": word['start'], - "idx": -1 + if "speaker" in segment: + speaker_str = f"[{segment['speaker']}]: " + else: + speaker_str = "" + for cdx, crow in res_segs.iterrows(): + if crow['start'] is not None: + if resolution == "char": + idx_0 = cdx + idx_1 = cdx + 1 + elif resolution == "word": + idx_0 = int(crow["segment-text-start"]) + idx_1 = int(crow["segment-text-end"]) + # fill gap + if crow['start'] > prev: + filler_ts = { + "chars": speaker_str + segment['text'], + "start": prev, + "end": crow['start'], + "idx_0": -1, + "idx_1": -1 + } + + ass_arr.append(filler_ts) + # highlight current word + f_word_ts = { + "chars": speaker_str + segment['text'], + "start": crow['start'], + "end": crow['end'], + "idx_0": idx_0 + len(speaker_str), + "idx_1": idx_1 + len(speaker_str) } - ass_arr.append(filler_ts) - - # highlight current word - f_word_ts = { - "words": curr_words, - "start": word['start'], - "end": word['end'], - "idx": wdx - } - ass_arr.append(f_word_ts) - - prev = word['end'] - - + ass_arr.append(f_word_ts) + prev = crow['end'] ass_str += '\n'.join(map(lambda x: dialogue(**x), ass_arr)) 
file.write(ass_str) + +def interpolate_nans(x, method='nearest'): + if x.notnull().sum() > 1: + return x.interpolate(method=method).ffill().bfill() + else: + return x.ffill().bfill() \ No newline at end of file
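
Taken together, the new functions are meant to be chained from Python roughly as follows. This is a minimal sketch built only from what this patch exports (`transcribe_with_vad`, `load_align_model`, `align`); the model size, device and audio path are illustrative, and the pyannote pipelines may require a Hugging Face access token depending on your environment.

    import whisperx
    from pyannote.audio import Pipeline

    device = "cuda"
    audio_file = "examples/sample01.wav"

    # 1. Whisper transcription restricted to the speech regions found by the VAD.
    model = whisperx.load_model("medium.en", device)
    vad_pipeline = Pipeline.from_pretrained("pyannote/voice-activity-detection")
    result = whisperx.transcribe_with_vad(model, audio_file, vad_pipeline)

    # 2. Forced alignment with wav2vec2 for character/word-level timestamps.
    align_model, metadata = whisperx.load_align_model(result["language"], device)
    result_aligned = whisperx.align(result["segments"], align_model, metadata,
                                    audio_file, device)

    # result_aligned["segments"] carries "char-segments"/"word-segments" per segment;
    # result_aligned["word_segments"] is the flat word-level list used for the .word.srt output.
    for word in result_aligned["word_segments"]:
        print(word["start"], word["end"], word["text"])

The CLI additionally runs the `pyannote/speaker-diarization@2.1` pipeline and maps speakers onto these word segments with `assign_word_speakers`.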