From 286a2f2c14c159b6aa6fb96e1f1acff309d32714 Mon Sep 17 00:00:00 2001
From: Max Bain
Date: Wed, 25 Jan 2023 18:42:52 +0000
Subject: [PATCH] clean up logic, use pandas where possible

---
 EXAMPLES.md            |   2 +-
 whisperx/__init__.py   |   4 +-
 whisperx/alignment.py  | 409 +++++++++++++++++++++++++++++++++++++-
 whisperx/transcribe.py | 437 ++---------------------------------------
 whisperx/utils.py      |   3 +-
 5 files changed, 426 insertions(+), 429 deletions(-)

diff --git a/EXAMPLES.md b/EXAMPLES.md
index c39c0c4..d9dc8e4 100644
--- a/EXAMPLES.md
+++ b/EXAMPLES.md
@@ -2,7 +2,7 @@

 ## Other Languages

-For non-english ASR, it is best to use the `large` whisper model. Alignment models are automatically picked by the chosen language from the default [lists](https://github.com/m-bain/whisperX/blob/e909f2f766b23b2000f2d95df41f9b844ac53e49/whisperx/transcribe.py#L22).
+For non-english ASR, it is best to use the `large` whisper model. Alignment models are automatically picked by the chosen language from the default [lists](https://github.com/m-bain/whisperX/blob/main/whisperx/alignment.py#L18).

 Currently support default models tested for {en, fr, de, es, it, ja, zh, nl}

diff --git a/whisperx/__init__.py b/whisperx/__init__.py
index 4f253b3..b897f01 100644
--- a/whisperx/__init__.py
+++ b/whisperx/__init__.py
@@ -11,8 +11,8 @@ from tqdm import tqdm
 from .audio import load_audio, log_mel_spectrogram, pad_or_trim
 from .decoding import DecodingOptions, DecodingResult, decode, detect_language
 from .model import Whisper, ModelDimensions
-from .transcribe import transcribe, load_align_model, align, transcribe_with_vad
-
+from .transcribe import transcribe, transcribe_with_vad
+from .alignment import load_align_model, align
 _MODELS = {
    "tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",
    "tiny": "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt",
diff --git a/whisperx/alignment.py b/whisperx/alignment.py
index 7d59231..ebf084d 100644
--- a/whisperx/alignment.py
+++ b/whisperx/alignment.py
@@ -1,9 +1,412 @@
+"""
+Forced Alignment with Whisper
+C. Max Bain
+"""
+import numpy as np
+import pandas as pd
+from typing import List, Union, Iterator, TYPE_CHECKING
+from transformers import AutoProcessor, Wav2Vec2ForCTC
+import torchaudio
+import torch
+from dataclasses import dataclass
+from .audio import SAMPLE_RATE, load_audio
+from .utils import interpolate_nans
+
+
+LANGUAGES_WITHOUT_SPACES = ["ja", "zh"]
+
+DEFAULT_ALIGN_MODELS_TORCH = {
+    "en": "WAV2VEC2_ASR_BASE_960H",
+    "fr": "VOXPOPULI_ASR_BASE_10K_FR",
+    "de": "VOXPOPULI_ASR_BASE_10K_DE",
+    "es": "VOXPOPULI_ASR_BASE_10K_ES",
+    "it": "VOXPOPULI_ASR_BASE_10K_IT",
+}
+
+DEFAULT_ALIGN_MODELS_HF = {
+    "ja": "jonatasgrosman/wav2vec2-large-xlsr-53-japanese",
+    "zh": "jonatasgrosman/wav2vec2-large-xlsr-53-chinese-zh-cn",
+    "nl": "jonatasgrosman/wav2vec2-large-xlsr-53-dutch",
+    "uk": "Yehor/wav2vec2-xls-r-300m-uk-with-small-lm",
+    "pt": "jonatasgrosman/wav2vec2-large-xlsr-53-portuguese",
+}
+
+
+def load_align_model(language_code, device, model_name=None):
+    if model_name is None:
+        # use default model
+        if language_code in DEFAULT_ALIGN_MODELS_TORCH:
+            model_name = DEFAULT_ALIGN_MODELS_TORCH[language_code]
+        elif language_code in DEFAULT_ALIGN_MODELS_HF:
+            model_name = DEFAULT_ALIGN_MODELS_HF[language_code]
+        else:
+            print(f"There is no default alignment model set for this language ({language_code}).\
+                Please find a wav2vec2.0 model finetuned on this language in https://huggingface.co/models, then pass the model name in --align_model [MODEL_NAME]")
+            raise ValueError(f"No default align-model for language: {language_code}")
+
+    if model_name in torchaudio.pipelines.__all__:
+        pipeline_type = "torchaudio"
+        bundle = torchaudio.pipelines.__dict__[model_name]
+        align_model = bundle.get_model().to(device)
+        labels = bundle.get_labels()
+        align_dictionary = {c.lower(): i for i, c in enumerate(labels)}
+    else:
+        try:
+            processor = AutoProcessor.from_pretrained(model_name)
+            align_model = Wav2Vec2ForCTC.from_pretrained(model_name)
+        except Exception as e:
+            print(e)
+            print("Error loading model from huggingface, check https://huggingface.co/models for finetuned wav2vec2.0 models")
+            raise ValueError(f'The chosen align_model "{model_name}" could not be found in huggingface (https://huggingface.co/models) or torchaudio (https://pytorch.org/audio/stable/pipelines.html#id14)')
+        pipeline_type = "huggingface"
+        align_model = align_model.to(device)
+        labels = processor.tokenizer.get_vocab()
+        align_dictionary = {char.lower(): code for char, code in labels.items()}
+
+    align_metadata = {"language": language_code, "dictionary": align_dictionary, "type": pipeline_type}
+
+    return align_model, align_metadata
+
+
+def align(
+    transcript: Iterator[dict],
+    model: torch.nn.Module,
+    align_model_metadata: dict,
+    audio: Union[str, np.ndarray, torch.Tensor],
+    device: str,
+    extend_duration: float = 0.0,
+    start_from_previous: bool = True,
+    interpolate_method: str = "nearest",
+):
+    """
+    Force align phoneme recognition predictions to known transcription
+
+    Parameters
+    ----------
+    transcript: Iterator[dict]
+        The transcript segments to align, each with "text", "start" and "end" keys
+
+    model: torch.nn.Module
+        Alignment model (wav2vec2)
+
+    align_model_metadata: dict
+        Metadata for the alignment model ("language", "dictionary", "type"), as returned by load_align_model
+
+    audio: Union[str, np.ndarray, torch.Tensor]
+        The path to the audio file to open, or the audio waveform
+
+    device: str
+        cuda device
+
+    extend_duration: float
+        Amount to pad input segments by. If VAD filtering (--vad_filter) is not used, padding by 2 seconds is recommended
+
+    interpolate_method: str ["nearest", "linear", "ignore"]
+        Method to assign timestamps to non-aligned words. Words cannot be aligned when none of their characters occur in the align model dictionary.
+        "nearest" copies the timestamp of the nearest aligned word within the segment. "linear" uses linear interpolation. "ignore" drops the word from the output.
+
+    Returns
+    -------
+    A dictionary containing the aligned segment-level details ("segments") and the
+    aligned word-level details ("word_segments").
+    """
+    if not torch.is_tensor(audio):
+        if isinstance(audio, str):
+            audio = load_audio(audio)
+        audio = torch.from_numpy(audio)
+    if len(audio.shape) == 1:
+        audio = audio.unsqueeze(0)
+
+    MAX_DURATION = audio.shape[1] / SAMPLE_RATE
+
+    model_dictionary = align_model_metadata["dictionary"]
+    model_lang = align_model_metadata["language"]
+    model_type = align_model_metadata["type"]
+
+    aligned_segments = []
+
+    prev_t2 = 0
+
+    char_segments_arr = {
+        "segment-idx": [],
+        "subsegment-idx": [],
+        "word-idx": [],
+        "char": [],
+        "start": [],
+        "end": [],
+        "score": [],
+    }
+
+    for sdx, segment in enumerate(transcript):
+        while True:
+            segment_align_success = False
+
+            # strip spaces at beginning / end, but keep track of the amount.
+            num_leading = len(segment["text"]) - len(segment["text"].lstrip())
+            num_trailing = len(segment["text"]) - len(segment["text"].rstrip())
+            transcription = segment["text"]
+
+            # TODO: convert number tokenizer / symbols to phonetic words for alignment.
+            # e.g. "$300" -> "three hundred dollars"
+            # currently "$300" is ignored since no characters present in the phonetic dictionary
+
+            # split into words
+            if model_lang not in LANGUAGES_WITHOUT_SPACES:
+                per_word = transcription.split(" ")
+            else:
+                per_word = transcription
+
+            # first check that characters in transcription can be aligned (i.e. they are contained in the align model's dictionary)
+            clean_char, clean_cdx = [], []
+            for cdx, char in enumerate(transcription):
+                char_ = char.lower()
+                # wav2vec2 models use "|" character to represent spaces
+                if model_lang not in LANGUAGES_WITHOUT_SPACES:
+                    char_ = char_.replace(" ", "|")
+
+                # ignore whitespace at beginning and end of transcript
+                if cdx < num_leading:
+                    pass
+                elif cdx > len(transcription) - num_trailing - 1:
+                    pass
+                elif char_ in model_dictionary.keys():
+                    clean_char.append(char_)
+                    clean_cdx.append(cdx)
+
+            clean_wdx = []
+            for wdx, wrd in enumerate(per_word):
+                if any(c in model_dictionary.keys() for c in wrd):
+                    clean_wdx.append(wdx)
+
+            # if no characters are in the dictionary, then we skip this segment...
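+            # (illustrative: a segment consisting solely of "$300" would have no
+            # characters in a letter-only dictionary, so it is skipped here and
+            # its original segment timestamps are kept)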
+            if len(clean_char) == 0:
+                print("Failed to align segment: no characters in this segment found in model dictionary, resorting to original...")
+                break
+
+            transcription_cleaned = "".join(clean_char)
+            tokens = [model_dictionary[c] for c in transcription_cleaned]
+
+            # pad according to original timestamps
+            t1 = max(segment["start"] - extend_duration, 0)
+            t2 = min(segment["end"] + extend_duration, MAX_DURATION)
+
+            # use prev_t2 as current t1 if it's later
+            if start_from_previous and t1 < prev_t2:
+                t1 = prev_t2
+
+            # check if timestamp range is still valid
+            if t1 >= MAX_DURATION:
+                print("Failed to align segment: original start time is beyond the audio duration, skipping...")
+                break
+            if t2 - t1 < 0.02:
+                print("Failed to align segment: duration smaller than 0.02s time precision")
+                break
+
+            f1 = int(t1 * SAMPLE_RATE)
+            f2 = int(t2 * SAMPLE_RATE)
+
+            waveform_segment = audio[:, f1:f2]
+
+            with torch.inference_mode():
+                if model_type == "torchaudio":
+                    emissions, _ = model(waveform_segment.to(device))
+                elif model_type == "huggingface":
+                    emissions = model(waveform_segment.to(device)).logits
+                else:
+                    raise NotImplementedError(f"Align model of type {model_type} not supported.")
+                emissions = torch.log_softmax(emissions, dim=-1)
+
+            emission = emissions[0].cpu().detach()
+
+            trellis = get_trellis(emission, tokens)
+            path = backtrack(trellis, emission, tokens)
+            if path is None:
+                print("Failed to align segment: backtrack failed, resorting to original...")
+                break
+            char_segments = merge_repeats(path, transcription_cleaned)
+            # word_segments = merge_words(char_segments)
+
+            # sub-segments
+            if "seg-text" not in segment:
+                segment["seg-text"] = [transcription]
+
+            v = 0
+            seg_lens = [0] + [len(x) for x in segment["seg-text"]]
+            seg_lens_cumsum = [v := v + n for n in seg_lens]
+            sub_seg_idx = 0
+
+            wdx = 0
+            duration = t2 - t1
+            ratio = duration * waveform_segment.size(0) / (trellis.size(0) - 1)
+            for cdx, char in enumerate(transcription + " "):
+                is_last = False
+                if cdx == len(transcription):
+                    break
+                elif cdx + 1 == len(transcription):
+                    is_last = True
+
+                start, end, score = None, None, None
+                if cdx in clean_cdx:
+                    char_seg = char_segments[clean_cdx.index(cdx)]
+                    start = char_seg.start * ratio + t1
+                    end = char_seg.end * ratio + t1
+                    score = char_seg.score
+
+                char_segments_arr["char"].append(char)
+                char_segments_arr["start"].append(start)
+                char_segments_arr["end"].append(end)
+                char_segments_arr["score"].append(score)
+                char_segments_arr["word-idx"].append(wdx)
+                char_segments_arr["segment-idx"].append(sdx)
+                char_segments_arr["subsegment-idx"].append(sub_seg_idx)
+
+                # word-level info
+                if model_lang in LANGUAGES_WITHOUT_SPACES:
+                    # character == word
+                    wdx += 1
+                elif is_last or transcription[cdx + 1] == " " or cdx == seg_lens_cumsum[sub_seg_idx + 1] - 1:
+                    wdx += 1
+
+                if is_last or cdx == seg_lens_cumsum[sub_seg_idx + 1] - 1:
+                    wdx = 0
+                    sub_seg_idx += 1
+
+            prev_t2 = segment["end"]
+
+            segment_align_success = True
+            # end while True loop
+            break
+
+        # reset prev_t2 due to drifting issues
+        if not segment_align_success:
+            prev_t2 = 0
+
+    char_segments_arr = pd.DataFrame(char_segments_arr)
+    not_space = char_segments_arr["char"] != " "
+
+    per_seg_grp = char_segments_arr.groupby(["segment-idx", "subsegment-idx"], as_index=False)
+    char_segments_arr = per_seg_grp.apply(lambda x: x.reset_index(drop=True)).reset_index()
+    per_word_grp = char_segments_arr[not_space].groupby(["segment-idx", "subsegment-idx", "word-idx"])
+    per_subseg_grp = char_segments_arr[not_space].groupby(["segment-idx", "subsegment-idx"])
+    per_seg_grp = char_segments_arr[not_space].groupby(["segment-idx"])
+
+    word_segments_arr = {}
+
+    # start of word is first char with a timestamp
+    word_segments_arr["start"] = per_word_grp["start"].min().reset_index()["start"]
+    # end of word is last char with a timestamp
+    word_segments_arr["end"] = per_word_grp["end"].max().reset_index()["end"]
+    # score of word is mean (excluding nan)
+    word_segments_arr["score"] = per_word_grp["score"].mean().reset_index()["score"]
+
+    word_segments_arr["segment-text-start"] = per_word_grp["level_1"].min().reset_index()["level_1"]
+    word_segments_arr["segment-text-end"] = per_word_grp["level_1"].max().reset_index()["level_1"] + 1
+    word_segments_arr["segment-idx"] = per_word_grp["level_1"].min().reset_index()["segment-idx"]
+
+    word_segments_arr = pd.DataFrame(word_segments_arr)
+    word_segments_arr[["segment-idx", "subsegment-idx", "word-idx"]] = per_word_grp["level_1"].min().reset_index()[["segment-idx", "subsegment-idx", "word-idx"]]
+
+    segments_arr = {}
+    segments_arr["start"] = per_subseg_grp["start"].min().reset_index()["start"]
+    # end of sub-segment is the latest char end (max), not the earliest
+    segments_arr["end"] = per_subseg_grp["end"].max().reset_index()["end"]
+    segments_arr = pd.DataFrame(segments_arr)
+    segments_arr[["segment-idx", "subsegment-idx-start"]] = per_subseg_grp["start"].min().reset_index()[["segment-idx", "subsegment-idx"]]
+    segments_arr["subsegment-idx-end"] = segments_arr["subsegment-idx-start"] + 1
+
+    # interpolate missing words / sub-segments
+    if interpolate_method != "ignore":
+        wrd_subseg_grp = word_segments_arr.groupby(["segment-idx", "subsegment-idx"])
+        wrd_seg_grp = word_segments_arr.groupby(["segment-idx"])
+        # we still know which word timestamps are interpolated because their score == nan
+        word_segments_arr["start"] = wrd_subseg_grp['start'].apply(lambda group: interpolate_nans(group, method=interpolate_method))
+        word_segments_arr["end"] = wrd_subseg_grp['end'].apply(lambda group: interpolate_nans(group, method=interpolate_method))
+
+        word_segments_arr["start"] = wrd_seg_grp['start'].apply(lambda group: interpolate_nans(group, method=interpolate_method))
+        word_segments_arr["end"] = wrd_seg_grp['end'].apply(lambda group: interpolate_nans(group, method=interpolate_method))
+
+        sub_seg_grp = segments_arr.groupby(["segment-idx"])
+        segments_arr['start'] = sub_seg_grp['start'].apply(lambda group: interpolate_nans(group, method=interpolate_method))
+        segments_arr['end'] = sub_seg_grp['end'].apply(lambda group: interpolate_nans(group, method=interpolate_method))
+        # merge subsegments which are missing times
+        # group by sub-segment and time.
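+        # (illustrative: if sub-segments 0 and 1 of a segment both end up with
+        # start=3.20 and end=4.10 after interpolation, the groupby below merges
+        # them into one row with subsegment-idx-start=0, subsegment-idx-end=2)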
+ seg_grp_dup = segments_arr.groupby(["segment-idx", "start", "end"]) + segments_arr["subsegment-idx-start"] = seg_grp_dup["subsegment-idx-start"].transform(min) + segments_arr["subsegment-idx-end"] = seg_grp_dup["subsegment-idx-end"].transform(max) + segments_arr.drop_duplicates(subset=["segment-idx", "subsegment-idx-start", "subsegment-idx-end"], inplace=True) + else: + word_segments_arr.dropna(inplace=True) + segments_arr.dropna(inplace=True) + + aligned_segments = [] + aligned_segments_word = [] + + word_segments_arr.set_index(["segment-idx", "subsegment-idx"], inplace=True) + char_segments_arr.set_index(["segment-idx", "subsegment-idx", "word-idx"], inplace=True) + + for sdx, srow in segments_arr.iterrows(): + + seg_idx = int(srow["segment-idx"]) + sub_start = int(srow["subsegment-idx-start"]) + sub_end = int(srow["subsegment-idx-end"]) + + seg = transcript[seg_idx] + text = "".join(seg["seg-text"][sub_start:sub_end]) + + wseg = word_segments_arr.loc[seg_idx].loc[sub_start:sub_end-1] + cseg = char_segments_arr.loc[seg_idx].loc[sub_start:sub_end-1] + cseg['segment-text-start'] = cseg['level_1'] + cseg['segment-text-end'] = cseg['level_1'] + 1 + del cseg['level_1'] + del cseg['level_0'] + cseg.reset_index(inplace=True) + aligned_segments.append( + { + "start": srow["start"], + "end": srow["end"], + "text": text, + "word-segments": wseg, + "char-segments": cseg + } + ) + + def get_raw_text(word_row): + return seg["seg-text"][word_row.name][int(word_row["segment-text-start"]):int(word_row["segment-text-end"])+1] + + wdx = 0 + curr_text = get_raw_text(wseg.iloc[wdx]) + if len(wseg) > 1: + for _, wrow in wseg.iloc[1:].iterrows(): + if wrow['start'] != wseg.iloc[wdx]['start']: + aligned_segments_word.append( + { + "text": curr_text.strip(), + "start": wseg.iloc[wdx]["start"], + "end": wseg.iloc[wdx]["end"], + } + ) + curr_text = "" + curr_text += " " + get_raw_text(wrow) + wdx += 1 + aligned_segments_word.append( + { + "text": curr_text.strip(), + "start": wseg.iloc[wdx]["start"], + "end": wseg.iloc[wdx]["end"] + } + ) + + + return {"segments": aligned_segments, "word_segments": aligned_segments_word} + + """ source: https://pytorch.org/tutorials/intermediate/forced_alignment_with_torchaudio_tutorial.html """ -import torch -from dataclasses import dataclass - def get_trellis(emission, tokens, blank_id=0): num_frame = emission.size(0) num_tokens = len(tokens) diff --git a/whisperx/transcribe.py b/whisperx/transcribe.py index 5cfc6aa..7f07f3c 100644 --- a/whisperx/transcribe.py +++ b/whisperx/transcribe.py @@ -5,12 +5,11 @@ from typing import List, Optional, Tuple, Union, Iterator, TYPE_CHECKING import numpy as np import torch -import torchaudio -from transformers import AutoProcessor, Wav2Vec2ForCTC import tqdm from .audio import SAMPLE_RATE, N_FRAMES, HOP_LENGTH, CHUNK_LENGTH, pad_or_trim, log_mel_spectrogram, load_audio -from .alignment import get_trellis, backtrack, merge_repeats, merge_words +from .alignment import load_align_model, align, get_trellis, backtrack, merge_repeats, merge_words from .decoding import DecodingOptions, DecodingResult +from .diarize import assign_word_speakers, Segment from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer from .utils import exact_div, format_timestamp, optional_int, optional_float, str2bool, interpolate_nans, write_txt, write_vtt, write_srt, write_ass, write_tsv import pandas as pd @@ -18,23 +17,6 @@ import pandas as pd if TYPE_CHECKING: from .model import Whisper -LANGUAGES_WITHOUT_SPACES = ["ja", "zh"] - -DEFAULT_ALIGN_MODELS_TORCH = 
{ - "en": "WAV2VEC2_ASR_BASE_960H", - "fr": "VOXPOPULI_ASR_BASE_10K_FR", - "de": "VOXPOPULI_ASR_BASE_10K_DE", - "es": "VOXPOPULI_ASR_BASE_10K_ES", - "it": "VOXPOPULI_ASR_BASE_10K_IT", -} - -DEFAULT_ALIGN_MODELS_HF = { - "ja": "jonatasgrosman/wav2vec2-large-xlsr-53-japanese", - "zh": "jonatasgrosman/wav2vec2-large-xlsr-53-chinese-zh-cn", - "nl": "jonatasgrosman/wav2vec2-large-xlsr-53-dutch", - "uk": "Yehor/wav2vec2-xls-r-300m-uk-with-small-lm", - "pt": "jonatasgrosman/wav2vec2-large-xlsr-53-portuguese", -} def transcribe( @@ -273,355 +255,11 @@ def transcribe( return dict(text=tokenizer.decode(all_tokens[len(initial_prompt):]), segments=all_segments, language=language) -def align( - transcript: Iterator[dict], - model: torch.nn.Module, - align_model_metadata: dict, - audio: Union[str, np.ndarray, torch.Tensor], - device: str, - extend_duration: float = 0.0, - start_from_previous: bool = True, - interpolate_method: str = "nearest", -): - """ - Force align phoneme recognition predictions to known transcription - - Parameters - ---------- - transcript: Iterator[dict] - The Whisper model instance - - model: torch.nn.Module - Alignment model (wav2vec2) - - audio: Union[str, np.ndarray, torch.Tensor] - The path to the audio file to open, or the audio waveform - - device: str - cuda device - - extend_duration: float - Amount to pad input segments by. If not using vad--filter then recommended to use 2 seconds - - If the gzip compression ratio is above this value, treat as failed - - interpolate_method: str ["nearest", "linear", "ignore"] - Method to assign timestamps to non-aligned words. Words are not able to be aligned when none of the characters occur in the align model dictionary. - "nearest" copies timestamp of nearest word within the segment. "linear" is linear interpolation. "drop" removes that word from output. - - Returns - ------- - A dictionary containing the resulting text ("text") and segment-level details ("segments"), and - the spoken language ("language"), which is detected when `decode_options["language"]` is None. - """ - if not torch.is_tensor(audio): - if isinstance(audio, str): - audio = load_audio(audio) - audio = torch.from_numpy(audio) - if len(audio.shape) == 1: - audio = audio.unsqueeze(0) - - MAX_DURATION = audio.shape[1] / SAMPLE_RATE - - model_dictionary = align_model_metadata["dictionary"] - model_lang = align_model_metadata["language"] - model_type = align_model_metadata["type"] - - aligned_segments = [] - - prev_t2 = 0 - for segment in transcript: - aligned_subsegments = [] - while True: - segment_align_success = False - - # strip spaces at beginning / end, but keep track of the amount. - num_leading = len(segment["text"]) - len(segment["text"].lstrip()) - num_trailing = len(segment["text"]) - len(segment["text"].rstrip()) - transcription = segment["text"] - - # TODO: convert number tokenizer / symbols to phonetic words for alignment. - # e.g. 
"$300" -> "three hundred dollars" - # currently "$300" is ignored since no characters present in the phonetic dictionary - - # split into words - if model_lang not in LANGUAGES_WITHOUT_SPACES: - per_word = transcription.split(" ") - else: - per_word = transcription - - # first check that characters in transcription can be aligned (they are contained in align model"s dictionary) - clean_char, clean_cdx = [], [] - for cdx, char in enumerate(transcription): - char_ = char.lower() - # wav2vec2 models use "|" character to represent spaces - if model_lang not in LANGUAGES_WITHOUT_SPACES: - char_ = char_.replace(" ", "|") - - # ignore whitespace at beginning and end of transcript - if cdx < num_leading: - pass - elif cdx > len(transcription) - num_trailing - 1: - pass - elif char_ in model_dictionary.keys(): - clean_char.append(char_) - clean_cdx.append(cdx) - - clean_wdx = [] - for wdx, wrd in enumerate(per_word): - if any([c in model_dictionary.keys() for c in wrd]): - clean_wdx.append(wdx) - - # if no characters are in the dictionary, then we skip this segment... - if len(clean_char) == 0: - print("Failed to align segment: no characters in this segment found in model dictionary, resorting to original...") - break - - transcription_cleaned = "".join(clean_char) - tokens = [model_dictionary[c] for c in transcription_cleaned] - - # pad according original timestamps - t1 = max(segment["start"] - extend_duration, 0) - t2 = min(segment["end"] + extend_duration, MAX_DURATION) - - # use prev_t2 as current t1 if it"s later - if start_from_previous and t1 < prev_t2: - t1 = prev_t2 - - # check if timestamp range is still valid - if t1 >= MAX_DURATION: - print("Failed to align segment: original start time longer than audio duration, skipping...") - break - if t2 - t1 < 0.02: - print("Failed to align segment: duration smaller than 0.02s time precision") - break - - f1 = int(t1 * SAMPLE_RATE) - f2 = int(t2 * SAMPLE_RATE) - - waveform_segment = audio[:, f1:f2] - - with torch.inference_mode(): - if model_type == "torchaudio": - emissions, _ = model(waveform_segment.to(device)) - elif model_type == "huggingface": - emissions = model(waveform_segment.to(device)).logits - else: - raise NotImplementedError(f"Align model of type {model_type} not supported.") - emissions = torch.log_softmax(emissions, dim=-1) - - emission = emissions[0].cpu().detach() - - trellis = get_trellis(emission, tokens) - path = backtrack(trellis, emission, tokens) - if path is None: - print("Failed to align segment: backtrack failed, resorting to original...") - break - char_segments = merge_repeats(path, transcription_cleaned) - # word_segments = merge_words(char_segments) - - - # sub-segments - if "seg-text" not in segment: - segment["seg-text"] = [transcription] - - v = 0 - seg_lens = [0] + [len(x) for x in segment["seg-text"]] - seg_lens_cumsum = [v := v + n for n in seg_lens] - sub_seg_idx = 0 - - char_level = { - "start": [], - "end": [], - "score": [], - "word-index": [], - } - - word_level = { - "start": [], - "end": [], - "score": [], - "segment-text-start": [], - "segment-text-end": [] - } - - wdx = 0 - seg_start_actual, seg_end_actual = None, None - duration = t2 - t1 - ratio = duration * waveform_segment.size(0) / (trellis.size(0) - 1) - cdx_prev = 0 - for cdx, char in enumerate(transcription + " "): - is_last = False - if cdx == len(transcription): - break - elif cdx+1 == len(transcription): - is_last = True - - - start, end, score = None, None, None - if cdx in clean_cdx: - char_seg = char_segments[clean_cdx.index(cdx)] - 
start = char_seg.start * ratio + t1 - end = char_seg.end * ratio + t1 - score = char_seg.score - - char_level["start"].append(start) - char_level["end"].append(end) - char_level["score"].append(score) - char_level["word-index"].append(wdx) - - # word-level info - if model_lang in LANGUAGES_WITHOUT_SPACES: - # character == word - wdx += 1 - elif is_last or transcription[cdx+1] == " " or cdx == seg_lens_cumsum[sub_seg_idx+1] - 1: - wdx += 1 - word_level["start"].append(None) - word_level["end"].append(None) - word_level["score"].append(None) - word_level["segment-text-start"].append(cdx_prev-seg_lens_cumsum[sub_seg_idx]) - word_level["segment-text-end"].append(cdx+1-seg_lens_cumsum[sub_seg_idx]) - cdx_prev = cdx+2 - - if is_last or cdx == seg_lens_cumsum[sub_seg_idx+1] - 1: - if model_lang not in LANGUAGES_WITHOUT_SPACES: - char_level = pd.DataFrame(char_level) - word_level = pd.DataFrame(word_level) - - not_space = pd.Series(list(segment["seg-text"][sub_seg_idx])) != " " - word_level["start"] = char_level[not_space].groupby("word-index")["start"].min() # take min of all chars in a word ignoring space - word_level["end"] = char_level[not_space].groupby("word-index")["end"].max() # take max of all chars in a word - - # fill missing - if interpolate_method != "ignore": - word_level["start"] = interpolate_nans(word_level["start"], method=interpolate_method) - word_level["end"] = interpolate_nans(word_level["end"], method=interpolate_method) - word_level["start"] = word_level["start"].values.tolist() - word_level["end"] = word_level["end"].values.tolist() - word_level["score"] = char_level.groupby("word-index")["score"].mean() # take mean of all scores - - char_level = char_level.replace({np.nan:None}).to_dict("list") - word_level = pd.DataFrame(word_level).replace({np.nan:None}).to_dict("list") - else: - word_level = None - - aligned_subsegments.append( - { - "text": segment["seg-text"][sub_seg_idx], - "start": seg_start_actual, - "end": seg_end_actual, - "char-segments": char_level, - "word-segments": word_level - } - ) - if "language" in segment: - aligned_subsegments[-1]["language"] = segment["language"] - - char_level = { - "start": [], - "end": [], - "score": [], - "word-index": [], - } - word_level = { - "start": [], - "end": [], - "score": [], - "segment-text-start": [], - "segment-text-end": [] - } - wdx = 0 - cdx_prev = cdx + 2 - sub_seg_idx += 1 - seg_start_actual, seg_end_actual = None, None - - - # take min-max for actual segment-level timestamp - if seg_start_actual is None and start is not None: - seg_start_actual = start - if end is not None: - seg_end_actual = end - - - prev_t2 = segment["end"] - - segment_align_success = True - # end while True loop - break - - # reset prev_t2 due to drifting issues - if not segment_align_success: - prev_t2 = 0 - - start = interpolate_nans(pd.DataFrame(aligned_subsegments)["start"], method=interpolate_method) - end = interpolate_nans(pd.DataFrame(aligned_subsegments)["end"], method=interpolate_method) - for idx, seg in enumerate(aligned_subsegments): - seg['start'] = start.iloc[idx] - seg['end'] = end.iloc[idx] - - aligned_segments += aligned_subsegments - - # create word level segments for .srt - word_seg = [] - for seg in aligned_segments: - if model_lang in LANGUAGES_WITHOUT_SPACES: - # character based - seg["word-segments"] = seg["char-segments"] - seg["word-segments"]["segment-text-start"] = range(len(seg['word-segments']['start'])) - seg["word-segments"]["segment-text-end"] = range(1, len(seg['word-segments']['start'])+1) - - wseg = 
pd.DataFrame(seg["word-segments"]).replace({np.nan:None}) - for wdx, wrow in wseg.iterrows(): - if wrow["start"] is not None: - word_seg.append( - { - "start": wrow["start"], - "end": wrow["end"], - "text": seg["text"][int(wrow["segment-text-start"]):int(wrow["segment-text-end"])] - } - ) - - return {"segments": aligned_segments, "word_segments": word_seg} - -def load_align_model(language_code, device, model_name=None): - if model_name is None: - # use default model - if language_code in DEFAULT_ALIGN_MODELS_TORCH: - model_name = DEFAULT_ALIGN_MODELS_TORCH[language_code] - elif language_code in DEFAULT_ALIGN_MODELS_HF: - model_name = DEFAULT_ALIGN_MODELS_HF[language_code] - else: - print(f"There is no default alignment model set for this language ({language_code}).\ - Please find a wav2vec2.0 model finetuned on this language in https://huggingface.co/models, then pass the model name in --align_model [MODEL_NAME]") - raise ValueError(f"No default align-model for language: {language_code}") - - if model_name in torchaudio.pipelines.__all__: - pipeline_type = "torchaudio" - bundle = torchaudio.pipelines.__dict__[model_name] - align_model = bundle.get_model().to(device) - labels = bundle.get_labels() - align_dictionary = {c.lower(): i for i, c in enumerate(labels)} - else: - try: - processor = AutoProcessor.from_pretrained(model_name) - align_model = Wav2Vec2ForCTC.from_pretrained(model_name) - except Exception as e: - print(e) - print(f"Error loading model from huggingface, check https://huggingface.co/models for finetuned wav2vec2.0 models") - raise ValueError(f'The chosen align_model "{model_name}" could not be found in huggingface (https://huggingface.co/models) or torchaudio (https://pytorch.org/audio/stable/pipelines.html#id14)') - pipeline_type = "huggingface" - align_model = align_model.to(device) - labels = processor.tokenizer.get_vocab() - align_dictionary = {char.lower(): code for char,code in processor.tokenizer.get_vocab().items()} - - align_metadata = {"language": language_code, "dictionary": align_dictionary, "type": pipeline_type} - - return align_model, align_metadata - - def merge_chunks(segments, chunk_size=CHUNK_LENGTH): """ - Merge VAD segments into larger segments of size ~CHUNK_LENGTH. + Merge VAD segments into larger segments of approximately size ~CHUNK_LENGTH. + TODO: Make sure VAD segment isn't too long, otherwise it will cause OOM when input to alignment model + TODO: Or sliding window alignment model over long segment. 
""" curr_start = 0 curr_end = 0 @@ -702,58 +340,6 @@ def transcribe_with_vad( return output -def assign_word_speakers(diarize_df, result_segments, fill_nearest=False): - - for seg in result_segments: - wdf = pd.DataFrame(seg['word-segments']) - if len(wdf['start'].dropna()) == 0: - wdf['start'] = seg['start'] - wdf['end'] = seg['end'] - speakers = [] - for wdx, wrow in wdf.iterrows(): - diarize_df['intersection'] = np.minimum(diarize_df['end'], wrow['end']) - np.maximum(diarize_df['start'], wrow['start']) - diarize_df['union'] = np.maximum(diarize_df['end'], wrow['end']) - np.minimum(diarize_df['start'], wrow['start']) - # remove no hit - if not fill_nearest: - dia_tmp = diarize_df[diarize_df['intersection'] > 0] - else: - dia_tmp = diarize_df - if len(dia_tmp) == 0: - speaker = None - else: - speaker = dia_tmp.sort_values("intersection", ascending=False).iloc[0][2] - speakers.append(speaker) - seg['word-segments']['speaker'] = speakers - seg["speaker"] = pd.Series(speakers).value_counts().index[0] - - # create word level segments for .srt - word_seg = [] - for seg in result_segments: - wseg = pd.DataFrame(seg["word-segments"]) - for wdx, wrow in wseg.iterrows(): - if wrow["start"] is not None: - speaker = wrow['speaker'] - if speaker is None or speaker == np.nan: - speaker = "UNKNOWN" - word_seg.append( - { - "start": wrow["start"], - "end": wrow["end"], - "text": f"[{speaker}]: " + seg["text"][int(wrow["segment-text-start"]):int(wrow["segment-text-end"])] - } - ) - - # TODO: create segments but split words on new speaker - - return result_segments, word_seg - -class Segment: - def __init__(self, start, end, speaker=None): - self.start = start - self.end = end - self.speaker = speaker - - def cli(): from . import available_models @@ -776,7 +362,7 @@ def cli(): parser.add_argument("--max_speakers", default=None, type=int) # output save params parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs") - parser.add_argument("--output_type", default="all", choices=["all", "srt", "srt-word", "vtt", "txt", "tsv", "ass", "ass-char"], help="File type for desired output save") + parser.add_argument("--output_type", default="all", choices=["all", "srt", "srt-word", "vtt", "txt", "tsv", "ass", "ass-char", "pickle"], help="File type for desired output save") parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages") @@ -868,6 +454,7 @@ def cli(): print(f"New language found ({result['language']})! 
Previous was ({align_metadata['language']}), loading new alignment model for new language...")
             align_model, align_metadata = load_align_model(result["language"], device)
 
+        print("Performing alignment...")
         result_aligned = align(result["segments"], align_model, align_metadata, audio_path, device,
             extend_duration=align_extend, start_from_previous=align_from_prev, interpolate_method=interpolate_method)
@@ -915,10 +502,16 @@ def cli():
             with open(os.path.join(output_dir, audio_basename + ".ass"), "w", encoding="utf-8") as ass:
                 write_ass(result_aligned["segments"], file=ass)
 
         # save ASS character-level
-        if output_type in ["ass-char", "all"]:
+        if output_type in ["ass-char"]:
             with open(os.path.join(output_dir, audio_basename + ".char.ass"), "w", encoding="utf-8") as ass:
                 write_ass(result_aligned["segments"], file=ass, resolution="char")
+        # save segments as pickle
+        if output_type in ["pickle"]:
+            exp_fp = os.path.join(output_dir, audio_basename + ".pkl")
+            pd.DataFrame(result_aligned["segments"]).to_pickle(exp_fp)
+
+
 if __name__ == "__main__":
     cli()
diff --git a/whisperx/utils.py b/whisperx/utils.py
index 77243d6..6f46514 100644
--- a/whisperx/utils.py
+++ b/whisperx/utils.py
@@ -2,6 +2,7 @@ import os
 import zlib
 from typing import Callable, TextIO, Iterator, Tuple
 import pandas as pd
+import numpy as np
 
 def exact_div(x, y):
     assert x % y == 0
@@ -214,7 +215,7 @@ def write_ass(transcript: Iterator[dict],
         else:
             speaker_str = ""
         for cdx, crow in res_segs.iterrows():
-            if crow['start'] is not None:
+            if not np.isnan(crow['start']):
                 if resolution == "char":
                     idx_0 = cdx
                     idx_1 = cdx + 1
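
Note for reviewers: a minimal usage sketch of the relocated API after this patch
(illustrative only; the audio path, model size, and device below are placeholder
assumptions, not part of the diff):

    import whisperx

    device = "cuda"
    audio_file = "audio.wav"  # placeholder path

    # transcription is unchanged
    model = whisperx.load_model("large", device)
    result = model.transcribe(audio_file)

    # load_align_model / align now live in whisperx/alignment.py,
    # but remain re-exported at package level
    align_model, metadata = whisperx.load_align_model(result["language"], device)

    # returns {"segments": [...], "word_segments": [...]}
    result_aligned = whisperx.align(result["segments"], align_model, metadata,
                                    audio_file, device)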