"""
Forced Alignment with Whisper
C. Max Bain
"""
import math
from dataclasses import dataclass
from typing import Iterable, Optional, Union, List

import numpy as np
import pandas as pd
import torch
import torchaudio
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

from whisperx.audio import SAMPLE_RATE, load_audio
from whisperx.utils import interpolate_nans
from whisperx.types import (
    AlignedTranscriptionResult,
    SingleSegment,
    SingleAlignedSegment,
    SingleWordSegment,
    SegmentData,
)
from nltk.tokenize.punkt import PunktSentenceTokenizer, PunktParameters

# Abbreviations Punkt should not treat as sentence boundaries.
PUNKT_ABBREVIATIONS = ['dr', 'vs', 'mr', 'mrs', 'prof']

# Languages whose text is aligned character-by-character instead of by
# space-separated words.
LANGUAGES_WITHOUT_SPACES = ["ja", "zh"]

# Default alignment models loaded through torchaudio pipeline bundles.
DEFAULT_ALIGN_MODELS_TORCH = {
    "en": "WAV2VEC2_ASR_BASE_960H",
    "fr": "VOXPOPULI_ASR_BASE_10K_FR",
    "de": "VOXPOPULI_ASR_BASE_10K_DE",
    "es": "VOXPOPULI_ASR_BASE_10K_ES",
    "it": "VOXPOPULI_ASR_BASE_10K_IT",
}

# Default alignment models loaded from the Hugging Face hub.
DEFAULT_ALIGN_MODELS_HF = {
    "ja": "jonatasgrosman/wav2vec2-large-xlsr-53-japanese",
    "zh": "jonatasgrosman/wav2vec2-large-xlsr-53-chinese-zh-cn",
    "nl": "jonatasgrosman/wav2vec2-large-xlsr-53-dutch",
    "uk": "Yehor/wav2vec2-xls-r-300m-uk-with-small-lm",
    "pt": "jonatasgrosman/wav2vec2-large-xlsr-53-portuguese",
    "ar": "jonatasgrosman/wav2vec2-large-xlsr-53-arabic",
    "cs": "comodoro/wav2vec2-xls-r-300m-cs-250",
    "ru": "jonatasgrosman/wav2vec2-large-xlsr-53-russian",
    "pl": "jonatasgrosman/wav2vec2-large-xlsr-53-polish",
    "hu": "jonatasgrosman/wav2vec2-large-xlsr-53-hungarian",
    "fi": "jonatasgrosman/wav2vec2-large-xlsr-53-finnish",
    "fa": "jonatasgrosman/wav2vec2-large-xlsr-53-persian",
    "el": "jonatasgrosman/wav2vec2-large-xlsr-53-greek",
    "tr": "mpoyraz/wav2vec2-xls-r-300m-cv7-turkish",
    "da": "saattrupdan/wav2vec2-xls-r-300m-ftspeech",
    "he": "imvladikon/wav2vec2-xls-r-300m-hebrew",
    "vi": 'nguyenvulebinh/wav2vec2-base-vi',
    "ko": "kresnik/wav2vec2-large-xlsr-korean",
    "ur": "kingabzpro/wav2vec2-large-xls-r-300m-Urdu",
    "te": "anuragshas/wav2vec2-large-xlsr-53-telugu",
    "hi": "theainerd/Wav2Vec2-large-xlsr-hindi",
    "ca": "softcatala/wav2vec2-large-xlsr-catala",
    "ml": "gvs/wav2vec2-large-xlsr-malayalam",
    "no": "NbAiLab/nb-wav2vec2-1b-bokmaal-v2",
    "nn": "NbAiLab/nb-wav2vec2-1b-nynorsk",
    "sk": "comodoro/wav2vec2-xls-r-300m-sk-cv8",
    "sl": "anton-l/wav2vec2-large-xlsr-53-slovenian",
    "hr": "classla/wav2vec2-xls-r-parlaspeech-hr",
    "ro": "gigant/romanian-wav2vec2",
    "eu": "stefan-it/wav2vec2-large-xlsr-53-basque",
    "gl": "ifrz/wav2vec2-large-xlsr-galician",
    "ka": "xsway/wav2vec2-large-xlsr-georgian",
    "lv": "jimregan/wav2vec2-large-xlsr-latvian-cv",
    "tl": "Khalsuu/filipino-wav2vec2-l-xls-r-300m-official",
}


def load_align_model(language_code: str, device: str, model_name: Optional[str] = None, model_dir=None):
    """Load a wav2vec2 CTC model and its character dictionary for alignment.

    Args:
        language_code: Language code used to pick a default model when
            ``model_name`` is not given.
        device: Torch device the model is moved to.
        model_name: Either the name of a ``torchaudio.pipelines`` bundle or a
            Hugging Face model id. Defaults to the per-language entry in
            ``DEFAULT_ALIGN_MODELS_TORCH`` / ``DEFAULT_ALIGN_MODELS_HF``.
        model_dir: Download directory, passed as ``dl_kwargs["model_dir"]`` to
            torchaudio or as ``cache_dir`` to Hugging Face.

    Returns:
        ``(align_model, align_metadata)`` where ``align_metadata`` has keys
        ``"language"``, ``"dictionary"`` (lower-cased label -> token id) and
        ``"type"`` (``"torchaudio"`` or ``"huggingface"``).

    Raises:
        ValueError: no default model exists for ``language_code``, or the
            Hugging Face model could not be loaded.
    """
    if model_name is None:
        # use default model
        if language_code in DEFAULT_ALIGN_MODELS_TORCH:
            model_name = DEFAULT_ALIGN_MODELS_TORCH[language_code]
        elif language_code in DEFAULT_ALIGN_MODELS_HF:
            model_name = DEFAULT_ALIGN_MODELS_HF[language_code]
        else:
            print(
                f"There is no default alignment model set for this language ({language_code}). "
                "Please find a wav2vec2.0 model finetuned on this language in https://huggingface.co/models, "
                "then pass the model name in --align_model [MODEL_NAME]"
            )
            raise ValueError(f"No default align-model for language: {language_code}")

    if model_name in torchaudio.pipelines.__all__:
        pipeline_type = "torchaudio"
        bundle = torchaudio.pipelines.__dict__[model_name]
        align_model = bundle.get_model(dl_kwargs={"model_dir": model_dir}).to(device)
        labels = bundle.get_labels()
        align_dictionary = {c.lower(): i for i, c in enumerate(labels)}
    else:
        try:
            processor = Wav2Vec2Processor.from_pretrained(model_name, cache_dir=model_dir)
            align_model = Wav2Vec2ForCTC.from_pretrained(model_name, cache_dir=model_dir)
        except Exception as e:
            print(e)
            print("Error loading model from huggingface, check https://huggingface.co/models for finetuned wav2vec2.0 models")
            raise ValueError(
                f'The chosen align_model "{model_name}" could not be found in huggingface '
                '(https://huggingface.co/models) or torchaudio (https://pytorch.org/audio/stable/pipelines.html#id14)'
            ) from e
        pipeline_type = "huggingface"
        align_model = align_model.to(device)
        align_dictionary = {char.lower(): code for char, code in processor.tokenizer.get_vocab().items()}

    align_metadata = {"language": language_code, "dictionary": align_dictionary, "type": pipeline_type}

    return align_model, align_metadata


def align(
    transcript: Iterable[SingleSegment],
    model: torch.nn.Module,
    align_model_metadata: dict,
    audio: Union[str, np.ndarray, torch.Tensor],
    device: str,
    interpolate_method: str = "nearest",
    return_char_alignments: bool = False,
    print_progress: bool = False,
    combined_progress: bool = False,
) -> AlignedTranscriptionResult:
    """Align phoneme recognition predictions to a known transcription.

    Each transcript segment is cut out of the audio, run through the CTC
    alignment model, and force-aligned character by character. The
    characters are then grouped into words and Punkt sentences.

    Args:
        transcript: Segments with ``"start"``, ``"end"`` (seconds) and
            ``"text"`` keys. Any iterable is accepted.
        model: Alignment model returned by ``load_align_model``.
        align_model_metadata: Metadata dict returned by ``load_align_model``.
        audio: Path to an audio file, or the decoded waveform as a numpy
            array or tensor (1-D, or 2-D with time on the last axis).
        device: Torch device used for model inference.
        interpolate_method: Method passed to ``interpolate_nans`` to fill
            missing sentence start/end times.
        return_char_alignments: Also return per-character timings.
        print_progress: Print a progress percentage per preprocessed segment.
        combined_progress: Report progress in the 50-100% range, for use
            after a transcription step reported 0-50%.

    Returns:
        A dict with ``"segments"`` (aligned sentence-level segments) and
        ``"word_segments"`` (all words, flattened).
    """
    if not torch.is_tensor(audio):
        if isinstance(audio, str):
            audio = load_audio(audio)
        audio = torch.from_numpy(audio)
    # Apply to tensor inputs too, so a 1-D tensor gets its batch axis.
    if len(audio.shape) == 1:
        audio = audio.unsqueeze(0)

    MAX_DURATION = audio.shape[1] / SAMPLE_RATE

    model_dictionary = align_model_metadata["dictionary"]
    model_lang = align_model_metadata["language"]
    model_type = align_model_metadata["type"]

    # The transcript is measured with len() and iterated twice below, so
    # materialize one-shot iterables (e.g. generators) up front.
    transcript = list(transcript)

    # Loop-invariant helpers: the sentence splitter and the CTC blank id do
    # not depend on the segment being processed.
    punkt_param = PunktParameters()
    punkt_param.abbrev_types = set(PUNKT_ABBREVIATIONS)
    sentence_splitter = PunktSentenceTokenizer(punkt_param)

    blank_id = 0
    for char, code in model_dictionary.items():
        if char == '[pad]' or char == '':
            blank_id = code

    # 1. Preprocess to keep only characters in dictionary
    total_segments = len(transcript)
    # Store temporary processing values
    segment_data: dict[int, SegmentData] = {}
    for sdx, segment in enumerate(transcript):
        if print_progress:
            base_progress = ((sdx + 1) / total_segments) * 100
            percent_complete = (50 + base_progress / 2) if combined_progress else base_progress
            print(f"Progress: {percent_complete:.2f}%...")

        text = segment["text"]
        # strip spaces at beginning / end, but keep track of the amount.
        num_leading = len(text) - len(text.lstrip())
        num_trailing = len(text) - len(text.rstrip())

        # split into words
        if model_lang not in LANGUAGES_WITHOUT_SPACES:
            per_word = text.split(" ")
        else:
            per_word = text

        clean_char, clean_cdx = [], []
        for cdx, char in enumerate(text):
            # ignore whitespace at beginning and end of transcript
            if cdx < num_leading or cdx > len(text) - num_trailing - 1:
                continue
            char_ = char.lower()
            # wav2vec2 models use "|" character to represent spaces
            if model_lang not in LANGUAGES_WITHOUT_SPACES:
                char_ = char_.replace(" ", "|")
            # characters unknown to the model become a '*' wildcard placeholder
            clean_char.append(char_ if char_ in model_dictionary else '*')
            clean_cdx.append(cdx)

        # Every word keeps its own index; words with no known characters are
        # still aligned through their '*' wildcards.
        clean_wdx = list(range(len(per_word)))

        sentence_spans = list(sentence_splitter.span_tokenize(text))

        segment_data[sdx] = {
            "clean_char": clean_char,
            "clean_cdx": clean_cdx,
            "clean_wdx": clean_wdx,
            "sentence_spans": sentence_spans,
        }

    aligned_segments: List[SingleAlignedSegment] = []

    # 2. Get prediction matrix from alignment model & align
    for sdx, segment in enumerate(transcript):
        t1 = segment["start"]
        t2 = segment["end"]
        text = segment["text"]

        # Fallback result used whenever the segment cannot be aligned.
        aligned_seg: SingleAlignedSegment = {
            "start": t1,
            "end": t2,
            "text": text,
            "words": [],
            "chars": None,
        }

        if return_char_alignments:
            aligned_seg["chars"] = []

        # check we can align
        if len(segment_data[sdx]["clean_char"]) == 0:
            print(f'Failed to align segment ("{segment["text"]}"): no characters in this segment found in model dictionary, resorting to original...')
            aligned_segments.append(aligned_seg)
            continue

        if t1 >= MAX_DURATION:
            print(f'Failed to align segment ("{segment["text"]}"): original start time longer than audio duration, skipping...')
            aligned_segments.append(aligned_seg)
            continue

        text_clean = "".join(segment_data[sdx]["clean_char"])
        # -1 marks a wildcard token (see get_wildcard_emission)
        tokens = [model_dictionary.get(c, -1) for c in text_clean]

        f1 = int(t1 * SAMPLE_RATE)
        f2 = int(t2 * SAMPLE_RATE)

        # TODO: Probably can get some speedup gain with batched inference here
        waveform_segment = audio[:, f1:f2]
        # Handle the minimum input length for wav2vec2 models
        if waveform_segment.shape[-1] < 400:
            lengths = torch.as_tensor([waveform_segment.shape[-1]]).to(device)
            waveform_segment = torch.nn.functional.pad(
                waveform_segment, (0, 400 - waveform_segment.shape[-1])
            )
        else:
            lengths = None

        with torch.inference_mode():
            if model_type == "torchaudio":
                emissions, _ = model(waveform_segment.to(device), lengths=lengths)
            elif model_type == "huggingface":
                emissions = model(waveform_segment.to(device)).logits
            else:
                raise NotImplementedError(f"Align model of type {model_type} not supported.")
            emissions = torch.log_softmax(emissions, dim=-1)

        emission = emissions[0].cpu().detach()

        trellis = get_trellis(emission, tokens, blank_id)
        path = backtrack_beam(trellis, emission, tokens, blank_id, beam_width=2)

        if path is None:
            print(f'Failed to align segment ("{segment["text"]}"): backtrack failed, resorting to original...')
            aligned_segments.append(aligned_seg)
            continue

        char_segments = merge_repeats(path, text_clean)

        duration = t2 - t1
        # Seconds per emission frame. The original factor
        # waveform_segment.size(0) was the batch/channel axis (1 for mono),
        # not a time quantity. max(..., 1) avoids dividing by zero when the
        # model emits a single frame.
        ratio = duration / max(trellis.size(0) - 1, 1)

        # Original character index -> position in the cleaned sequence
        # (clean_cdx is strictly increasing, so this matches list.index).
        clean_pos = {cdx: pos for pos, cdx in enumerate(segment_data[sdx]["clean_cdx"])}

        # assign timestamps to aligned characters
        char_segments_arr = []
        word_idx = 0
        for cdx, char in enumerate(text):
            start, end, score = None, None, None
            pos = clean_pos.get(cdx)
            if pos is not None:
                char_seg = char_segments[pos]
                start = round(char_seg.start * ratio + t1, 3)
                end = round(char_seg.end * ratio + t1, 3)
                score = round(char_seg.score, 3)

            char_segments_arr.append(
                {
                    "char": char,
                    "start": start,
                    "end": end,
                    "score": score,
                    "word-idx": word_idx,
                }
            )

            # increment word_idx; nltk word tokenization would probably be more robust here, but use spaces for now...
            if model_lang in LANGUAGES_WITHOUT_SPACES:
                word_idx += 1
            elif cdx == len(text) - 1 or text[cdx + 1] == " ":
                word_idx += 1

        char_segments_arr = pd.DataFrame(char_segments_arr)

        aligned_subsegments = []
        # assign sentence_idx to each character index
        char_segments_arr["sentence-idx"] = None
        for sdx2, (sstart, send) in enumerate(segment_data[sdx]["sentence_spans"]):
            in_sentence = (char_segments_arr.index >= sstart) & (char_segments_arr.index <= send)
            curr_chars = char_segments_arr.loc[in_sentence]
            char_segments_arr.loc[in_sentence, "sentence-idx"] = sdx2

            sentence_text = text[sstart:send]
            sentence_start = curr_chars["start"].min()
            end_chars = curr_chars[curr_chars["char"] != ' ']
            sentence_end = end_chars["end"].max()
            sentence_words = []

            for word_idx in curr_chars["word-idx"].unique():
                word_chars = curr_chars.loc[curr_chars["word-idx"] == word_idx]
                word_text = "".join(word_chars["char"].tolist()).strip()
                if len(word_text) == 0:
                    continue

                # dont use space character for alignment
                word_chars = word_chars[word_chars["char"] != " "]

                word_start = word_chars["start"].min()
                word_end = word_chars["end"].max()
                word_score = round(word_chars["score"].mean(), 3)

                # keys are omitted for values that could not be aligned
                word_segment = {"word": word_text}

                if not np.isnan(word_start):
                    word_segment["start"] = word_start
                if not np.isnan(word_end):
                    word_segment["end"] = word_end
                if not np.isnan(word_score):
                    word_segment["score"] = word_score

                sentence_words.append(word_segment)

            aligned_subsegments.append({
                "text": sentence_text,
                "start": sentence_start,
                "end": sentence_end,
                "words": sentence_words,
            })

            if return_char_alignments:
                curr_chars = curr_chars[["char", "start", "end", "score"]]
                # non-inplace: curr_chars is a slice of char_segments_arr
                curr_chars = curr_chars.fillna(-1)
                curr_chars = curr_chars.to_dict("records")
                curr_chars = [{key: val for key, val in char.items() if val != -1} for char in curr_chars]
                aligned_subsegments[-1]["chars"] = curr_chars

        if not aligned_subsegments:
            # No sentence spans: keep the unaligned segment rather than
            # failing on an empty DataFrame below.
            aligned_segments.append(aligned_seg)
            continue

        aligned_subsegments = pd.DataFrame(aligned_subsegments)
        aligned_subsegments["start"] = interpolate_nans(aligned_subsegments["start"], method=interpolate_method)
        aligned_subsegments["end"] = interpolate_nans(aligned_subsegments["end"], method=interpolate_method)
        # concatenate sentences with same timestamps
        agg_dict = {"text": " ".join, "words": "sum"}
        if model_lang in LANGUAGES_WITHOUT_SPACES:
            agg_dict["text"] = "".join
        if return_char_alignments:
            agg_dict["chars"] = "sum"
        aligned_subsegments = aligned_subsegments.groupby(["start", "end"], as_index=False).agg(agg_dict)
        aligned_subsegments = aligned_subsegments.to_dict('records')
        aligned_segments += aligned_subsegments

    # create word_segments list
    word_segments: List[SingleWordSegment] = []
    for segment in aligned_segments:
        word_segments += segment["words"]

    return {"segments": aligned_segments, "word_segments": word_segments}


# source: https://pytorch.org/tutorials/intermediate/forced_alignment_with_torchaudio_tutorial.html
def get_trellis(emission, tokens, blank_id=0):
    """Build the forced-alignment trellis of shape (num_frames, len(tokens)).

    Follows the torchaudio forced-alignment tutorial. The difference is that
    token ids of -1 are wildcards, scored via ``get_wildcard_emission``.

    Args:
        emission: Per-frame log-probabilities; indexed as [frame, label].
        tokens: Token ids of the transcript; -1 marks a wildcard.
        blank_id: Id of the CTC blank label.

    Returns:
        torch.Tensor: The trellis.
    """
    num_frame = emission.size(0)
    num_tokens = len(tokens)

    trellis = torch.zeros((num_frame, num_tokens))
    trellis[1:, 0] = torch.cumsum(emission[1:, blank_id], 0)
    trellis[0, 1:] = -float("inf")
    # With num_tokens == 1 the slice would be [0:], i.e. the whole column;
    # the tutorial only intends to touch the last num_tokens - 1 frames.
    if num_tokens > 1:
        trellis[-num_tokens + 1:, 0] = float("inf")

    # Convert the "next token" ids once instead of once per frame.
    next_tokens = torch.as_tensor(tokens[1:], dtype=torch.long)

    for t in range(num_frame - 1):
        trellis[t + 1, 1:] = torch.maximum(
            # Score for staying at the same token
            trellis[t, 1:] + emission[t, blank_id],
            # Score for changing to the next token
            trellis[t, :-1] + get_wildcard_emission(emission[t], next_tokens, blank_id),
        )
    return trellis


def get_wildcard_emission(frame_emission, tokens, blank_id):
    """Processing token emission scores containing wildcards (vectorized version)

    Args:
        frame_emission: Emission probability vector for the current frame
        tokens: List of token indices; -1 marks a wildcard
        blank_id: ID of the blank token

    Returns:
        tensor: Maximum probability score for each token position. A wildcard
        position scores the best non-blank label of the frame.
    """
    assert 0 <= blank_id < len(frame_emission)

    # Convert tokens to a tensor if they are not already
    tokens = torch.as_tensor(tokens)

    # Create a mask to identify wildcard positions
    wildcard_mask = (tokens == -1)

    # Get scores for non-wildcard positions
    regular_scores = frame_emission[tokens.clamp(min=0).long()]  # clamp to avoid -1 index

    # Best non-blank score, computed on a copy so frame_emission is untouched
    max_valid_score = frame_emission.clone()
    max_valid_score[blank_id] = float('-inf')
    max_valid_score = max_valid_score.max()

    # Use where operation to combine results
    result = torch.where(wildcard_mask, max_valid_score, regular_scores)

    return result


@dataclass
class Point:
    """One step of an alignment path."""
    token_index: int  # index into the transcript token sequence
    time_index: int  # emission frame index
    score: float  # frame-wise probability (exp of the log-prob)


def backtrack(trellis, emission, tokens, blank_id=0):
    """Greedy backtracking through the trellis (torchaudio tutorial version).

    Returns:
        List[Point]: the path, ordered by time.
    """
    t, j = trellis.size(0) - 1, trellis.size(1) - 1

    path = [Point(j, t, emission[t, blank_id].exp().item())]
    while j > 0:
        # Should not happen but just in case
        assert t > 0

        # 1. Figure out if the current position was stay or change
        # Frame-wise score of stay vs change
        p_stay = emission[t - 1, blank_id]
        p_change = get_wildcard_emission(emission[t - 1], [tokens[j]], blank_id)[0]

        # Context-aware score for stay vs change
        stayed = trellis[t - 1, j] + p_stay
        changed = trellis[t - 1, j - 1] + p_change

        # Update position
        t -= 1
        if changed > stayed:
            j -= 1

        # Store the path with frame-wise probability.
        prob = (p_change if changed > stayed else p_stay).exp().item()
        path.append(Point(j, t, prob))

    # Now j == 0, which means, it reached the SoS.
    # Fill up the rest for the sake of visualization
    while t > 0:
        prob = emission[t - 1, blank_id].exp().item()
        path.append(Point(j, t - 1, prob))
        t -= 1

    return path[::-1]


@dataclass
class Path:
    points: List[Point]
    score: float


@dataclass
class BeamState:
    """State in beam search."""
    token_index: int  # Current token position
    time_index: int  # Current time step
    score: float  # Trellis value at (time_index, token_index), used for ranking
    path: List[Point]  # Path history


def backtrack_beam(trellis, emission, tokens, blank_id=0, beam_width=5):
    """Standard CTC beam search backtracking implementation.

    Args:
        trellis (torch.Tensor): The trellis (or lattice) of shape (T, N), where T is the number of time steps
            and N is the number of tokens (including the blank token).
        emission (torch.Tensor): The emission probabilities of shape (T, N).
        tokens (List[int]): List of token indices (excluding the blank token).
        blank_id (int, optional): The ID of the blank token. Defaults to 0.
        beam_width (int, optional): The number of top paths to keep during beam search. Defaults to 5.

    Returns:
        List[Point]: the best path, ordered by time, or None when every beam
        runs out of frames before reaching the first token.
    """
    T, J = trellis.size(0) - 1, trellis.size(1) - 1

    init_state = BeamState(
        token_index=J,
        time_index=T,
        score=trellis[T, J],
        path=[Point(J, T, emission[T, blank_id].exp().item())]
    )

    beams = [init_state]

    while beams and beams[0].token_index > 0:
        next_beams = []

        for beam in beams:
            t, j = beam.time_index, beam.token_index

            if t <= 0:
                continue

            p_stay = emission[t - 1, blank_id]
            p_change = get_wildcard_emission(emission[t - 1], [tokens[j]], blank_id)[0]

            stay_score = trellis[t - 1, j]
            change_score = trellis[t - 1, j - 1] if j > 0 else float('-inf')

            # Stay
            if not math.isinf(stay_score):
                new_path = beam.path.copy()
                new_path.append(Point(j, t - 1, p_stay.exp().item()))
                next_beams.append(BeamState(
                    token_index=j,
                    time_index=t - 1,
                    score=stay_score,
                    path=new_path
                ))

            # Change
            if j > 0 and not math.isinf(change_score):
                new_path = beam.path.copy()
                new_path.append(Point(j - 1, t - 1, p_change.exp().item()))
                next_beams.append(BeamState(
                    token_index=j - 1,
                    time_index=t - 1,
                    score=change_score,
                    path=new_path
                ))

        # sort by score
        beams = sorted(next_beams, key=lambda x: x.score, reverse=True)[:beam_width]

        if not beams:
            break

    if not beams:
        return None

    best_beam = beams[0]
    t = best_beam.time_index
    j = best_beam.token_index
    # Pad the remaining leading frames with blank steps on the first token.
    while t > 0:
        prob = emission[t - 1, blank_id].exp().item()
        best_beam.path.append(Point(j, t - 1, prob))
        t -= 1

    return best_beam.path[::-1]


# Merge the labels
@dataclass
class Segment:
    """A run of frames assigned to one label; ``end`` is exclusive."""
    label: str
    start: int
    end: int
    score: float

    def __repr__(self):
        return f"{self.label}\t({self.score:4.2f}): [{self.start:5d}, {self.end:5d})"

    @property
    def length(self):
        return self.end - self.start


def merge_repeats(path, transcript):
    """Collapse consecutive path points with the same token into Segments.

    Each Segment's score is the mean point score over its frames.
    """
    i1, i2 = 0, 0
    segments = []
    while i1 < len(path):
        while i2 < len(path) and path[i1].token_index == path[i2].token_index:
            i2 += 1
        score = sum(path[k].score for k in range(i1, i2)) / (i2 - i1)
        segments.append(
            Segment(
                transcript[path[i1].token_index],
                path[i1].time_index,
                path[i2 - 1].time_index + 1,
                score,
            )
        )
        i1 = i2
    return segments


def merge_words(segments, separator="|"):
    """Group character Segments into word Segments split on ``separator``.

    Word scores are length-weighted averages of the character scores.
    """
    words = []
    i1, i2 = 0, 0
    while i1 < len(segments):
        if i2 >= len(segments) or segments[i2].label == separator:
            if i1 != i2:
                segs = segments[i1:i2]
                word = "".join([seg.label for seg in segs])
                score = sum(seg.score * seg.length for seg in segs) / sum(seg.length for seg in segs)
                words.append(Segment(word, segments[i1].start, segments[i2 - 1].end, score))
            i1 = i2 + 1
            i2 = i1
        else:
            i2 += 1
    return words