""" Forced Alignment with Whisper C. Max Bain """ import math from dataclasses import dataclass from typing import Iterable, Optional, Union, List import numpy as np import pandas as pd import torch import torchaudio from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor from .audio import SAMPLE_RATE, load_audio from .utils import interpolate_nans from .types import ( AlignedTranscriptionResult, SingleSegment, SingleAlignedSegment, SingleWordSegment, SegmentData, ) from nltk.tokenize.punkt import PunktSentenceTokenizer, PunktParameters PUNKT_ABBREVIATIONS = ['dr', 'vs', 'mr', 'mrs', 'prof'] LANGUAGES_WITHOUT_SPACES = ["ja", "zh"] DEFAULT_ALIGN_MODELS_TORCH = { "en": "WAV2VEC2_ASR_BASE_960H", "fr": "VOXPOPULI_ASR_BASE_10K_FR", "de": "VOXPOPULI_ASR_BASE_10K_DE", "es": "VOXPOPULI_ASR_BASE_10K_ES", "it": "VOXPOPULI_ASR_BASE_10K_IT", } DEFAULT_ALIGN_MODELS_HF = { "ja": "jonatasgrosman/wav2vec2-large-xlsr-53-japanese", "zh": "jonatasgrosman/wav2vec2-large-xlsr-53-chinese-zh-cn", "nl": "jonatasgrosman/wav2vec2-large-xlsr-53-dutch", "uk": "Yehor/wav2vec2-xls-r-300m-uk-with-small-lm", "pt": "jonatasgrosman/wav2vec2-large-xlsr-53-portuguese", "ar": "jonatasgrosman/wav2vec2-large-xlsr-53-arabic", "cs": "comodoro/wav2vec2-xls-r-300m-cs-250", "ru": "jonatasgrosman/wav2vec2-large-xlsr-53-russian", "pl": "jonatasgrosman/wav2vec2-large-xlsr-53-polish", "hu": "jonatasgrosman/wav2vec2-large-xlsr-53-hungarian", "fi": "jonatasgrosman/wav2vec2-large-xlsr-53-finnish", "fa": "jonatasgrosman/wav2vec2-large-xlsr-53-persian", "el": "jonatasgrosman/wav2vec2-large-xlsr-53-greek", "tr": "mpoyraz/wav2vec2-xls-r-300m-cv7-turkish", "da": "saattrupdan/wav2vec2-xls-r-300m-ftspeech", "he": "imvladikon/wav2vec2-xls-r-300m-hebrew", "vi": 'nguyenvulebinh/wav2vec2-base-vi', "ko": "kresnik/wav2vec2-large-xlsr-korean", "ur": "kingabzpro/wav2vec2-large-xls-r-300m-Urdu", "te": "anuragshas/wav2vec2-large-xlsr-53-telugu", "hi": "theainerd/Wav2Vec2-large-xlsr-hindi", "ca": "softcatala/wav2vec2-large-xlsr-catala", "ml": "gvs/wav2vec2-large-xlsr-malayalam", "no": "NbAiLab/nb-wav2vec2-1b-bokmaal-v2", "nn": "NbAiLab/nb-wav2vec2-1b-nynorsk", "sk": "comodoro/wav2vec2-xls-r-300m-sk-cv8", "sl": "anton-l/wav2vec2-large-xlsr-53-slovenian", "hr": "classla/wav2vec2-xls-r-parlaspeech-hr", "ro": "gigant/romanian-wav2vec2", "eu": "stefan-it/wav2vec2-large-xlsr-53-basque", "gl": "ifrz/wav2vec2-large-xlsr-galician", "ka": "xsway/wav2vec2-large-xlsr-georgian", } def load_align_model(language_code: str, device: str, model_name: Optional[str] = None, model_dir=None): if model_name is None: # use default model if language_code in DEFAULT_ALIGN_MODELS_TORCH: model_name = DEFAULT_ALIGN_MODELS_TORCH[language_code] elif language_code in DEFAULT_ALIGN_MODELS_HF: model_name = DEFAULT_ALIGN_MODELS_HF[language_code] else: print(f"There is no default alignment model set for this language ({language_code}).\ Please find a wav2vec2.0 model finetuned on this language in https://huggingface.co/models, then pass the model name in --align_model [MODEL_NAME]") raise ValueError(f"No default align-model for language: {language_code}") if model_name in torchaudio.pipelines.__all__: pipeline_type = "torchaudio" bundle = torchaudio.pipelines.__dict__[model_name] align_model = bundle.get_model(dl_kwargs={"model_dir": model_dir}).to(device) labels = bundle.get_labels() align_dictionary = {c.lower(): i for i, c in enumerate(labels)} else: try: processor = Wav2Vec2Processor.from_pretrained(model_name, cache_dir=model_dir) align_model = 
            align_model = Wav2Vec2ForCTC.from_pretrained(model_name, cache_dir=model_dir)
        except Exception as e:
            print(e)
            print(f"Error loading model from huggingface, check https://huggingface.co/models for finetuned wav2vec2.0 models")
            raise ValueError(f'The chosen align_model "{model_name}" could not be found in huggingface (https://huggingface.co/models) or torchaudio (https://pytorch.org/audio/stable/pipelines.html#id14)')
        pipeline_type = "huggingface"
        align_model = align_model.to(device)
        labels = processor.tokenizer.get_vocab()
        align_dictionary = {char.lower(): code for char, code in processor.tokenizer.get_vocab().items()}

    align_metadata = {"language": language_code, "dictionary": align_dictionary, "type": pipeline_type}

    return align_model, align_metadata


def align(
    transcript: Iterable[SingleSegment],
    model: torch.nn.Module,
    align_model_metadata: dict,
    audio: Union[str, np.ndarray, torch.Tensor],
    device: str,
    interpolate_method: str = "nearest",
    return_char_alignments: bool = False,
    print_progress: bool = False,
    combined_progress: bool = False,
) -> AlignedTranscriptionResult:
    """
    Align phoneme recognition predictions to known transcription.
    """

    if not torch.is_tensor(audio):
        if isinstance(audio, str):
            audio = load_audio(audio)
        audio = torch.from_numpy(audio)
    if len(audio.shape) == 1:
        audio = audio.unsqueeze(0)

    MAX_DURATION = audio.shape[1] / SAMPLE_RATE

    model_dictionary = align_model_metadata["dictionary"]
    model_lang = align_model_metadata["language"]
    model_type = align_model_metadata["type"]

    # 1. Preprocess to keep only characters in dictionary
    total_segments = len(transcript)
    # Store temporary processing values
    segment_data: dict[int, SegmentData] = {}
    for sdx, segment in enumerate(transcript):
        # strip spaces at beginning / end, but keep track of the amount.
        if print_progress:
            base_progress = ((sdx + 1) / total_segments) * 100
            percent_complete = (50 + base_progress / 2) if combined_progress else base_progress
            print(f"Progress: {percent_complete:.2f}%...")

        num_leading = len(segment["text"]) - len(segment["text"].lstrip())
        num_trailing = len(segment["text"]) - len(segment["text"].rstrip())
        text = segment["text"]

        # split into words
        if model_lang not in LANGUAGES_WITHOUT_SPACES:
            per_word = text.split(" ")
        else:
            per_word = text

        clean_char, clean_cdx = [], []
        for cdx, char in enumerate(text):
            char_ = char.lower()
            # wav2vec2 models use "|" character to represent spaces
            if model_lang not in LANGUAGES_WITHOUT_SPACES:
                char_ = char_.replace(" ", "|")

            # ignore whitespace at beginning and end of transcript
            if cdx < num_leading:
                pass
            elif cdx > len(text) - num_trailing - 1:
                pass
            elif char_ in model_dictionary.keys():
                clean_char.append(char_)
                clean_cdx.append(cdx)
            else:
                # add placeholder
                clean_char.append('*')
                clean_cdx.append(cdx)

        clean_wdx = []
        for wdx, wrd in enumerate(per_word):
            if any([c in model_dictionary.keys() for c in wrd.lower()]):
                clean_wdx.append(wdx)
            else:
                # index for placeholder
                clean_wdx.append(wdx)

        punkt_param = PunktParameters()
        punkt_param.abbrev_types = set(PUNKT_ABBREVIATIONS)
        sentence_splitter = PunktSentenceTokenizer(punkt_param)
        sentence_spans = list(sentence_splitter.span_tokenize(text))

        segment_data[sdx] = {
            "clean_char": clean_char,
            "clean_cdx": clean_cdx,
            "clean_wdx": clean_wdx,
            "sentence_spans": sentence_spans
        }

    aligned_segments: List[SingleAlignedSegment] = []
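    # For each segment, step 2 below:
    #   (a) runs the wav2vec2 model on the segment's waveform to get per-frame
    #       log-probabilities over the model's character vocabulary,
    #   (b) builds a trellis of cumulative scores over (frame, character) pairs,
    #   (c) backtracks through the trellis to find a monotonic character-to-frame
    #       path, and
    #   (d) converts frame indices back to seconds using the segment duration,
    #       then aggregates characters into words and sentences.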
    # 2. Get prediction matrix from alignment model & align
    for sdx, segment in enumerate(transcript):
        t1 = segment["start"]
        t2 = segment["end"]
        text = segment["text"]

        aligned_seg: SingleAlignedSegment = {
            "start": t1,
            "end": t2,
            "text": text,
            "words": [],
            "chars": None,
        }

        if return_char_alignments:
            aligned_seg["chars"] = []

        # check we can align
        if len(segment_data[sdx]["clean_char"]) == 0:
            print(f'Failed to align segment ("{segment["text"]}"): no characters in this segment found in model dictionary, resorting to original...')
            aligned_segments.append(aligned_seg)
            continue

        if t1 >= MAX_DURATION:
            print(f'Failed to align segment ("{segment["text"]}"): original start time longer than audio duration, skipping...')
            aligned_segments.append(aligned_seg)
            continue

        text_clean = "".join(segment_data[sdx]["clean_char"])
        tokens = [model_dictionary.get(c, -1) for c in text_clean]

        f1 = int(t1 * SAMPLE_RATE)
        f2 = int(t2 * SAMPLE_RATE)

        # TODO: Probably can get some speedup gain with batched inference here
        waveform_segment = audio[:, f1:f2]
        # Handle the minimum input length for wav2vec2 models
        if waveform_segment.shape[-1] < 400:
            lengths = torch.as_tensor([waveform_segment.shape[-1]]).to(device)
            waveform_segment = torch.nn.functional.pad(
                waveform_segment, (0, 400 - waveform_segment.shape[-1])
            )
        else:
            lengths = None

        with torch.inference_mode():
            if model_type == "torchaudio":
                emissions, _ = model(waveform_segment.to(device), lengths=lengths)
            elif model_type == "huggingface":
                emissions = model(waveform_segment.to(device)).logits
            else:
                raise NotImplementedError(f"Align model of type {model_type} not supported.")
            emissions = torch.log_softmax(emissions, dim=-1)

        emission = emissions[0].cpu().detach()

        blank_id = 0
        for char, code in model_dictionary.items():
            if char == '[pad]' or char == '<pad>':
                blank_id = code

        trellis = get_trellis(emission, tokens, blank_id)
        # path = backtrack(trellis, emission, tokens, blank_id)
        path = backtrack_beam(trellis, emission, tokens, blank_id, beam_width=2)

        if path is None:
            print(f'Failed to align segment ("{segment["text"]}"): backtrack failed, resorting to original...')
            aligned_segments.append(aligned_seg)
            continue

        char_segments = merge_repeats(path, text_clean)

        duration = t2 - t1
        ratio = duration * waveform_segment.size(0) / (trellis.size(0) - 1)

        # assign timestamps to aligned characters
        char_segments_arr = []
        word_idx = 0
        for cdx, char in enumerate(text):
            start, end, score = None, None, None
            if cdx in segment_data[sdx]["clean_cdx"]:
                char_seg = char_segments[segment_data[sdx]["clean_cdx"].index(cdx)]
                start = round(char_seg.start * ratio + t1, 3)
                end = round(char_seg.end * ratio + t1, 3)
                score = round(char_seg.score, 3)

            char_segments_arr.append(
                {
                    "char": char,
                    "start": start,
                    "end": end,
                    "score": score,
                    "word-idx": word_idx,
                }
            )

            # increment word_idx, nltk word tokenization would probably be more robust here, but use space for now...
            if model_lang in LANGUAGES_WITHOUT_SPACES:
                word_idx += 1
            elif cdx == len(text) - 1 or text[cdx+1] == " ":
                word_idx += 1

        char_segments_arr = pd.DataFrame(char_segments_arr)

        aligned_subsegments = []
        # assign sentence_idx to each character index
        char_segments_arr["sentence-idx"] = None
        for sdx2, (sstart, send) in enumerate(segment_data[sdx]["sentence_spans"]):
            curr_chars = char_segments_arr.loc[(char_segments_arr.index >= sstart) & (char_segments_arr.index <= send)]
            char_segments_arr.loc[(char_segments_arr.index >= sstart) & (char_segments_arr.index <= send), "sentence-idx"] = sdx2

            sentence_text = text[sstart:send]
            sentence_start = curr_chars["start"].min()
            end_chars = curr_chars[curr_chars["char"] != ' ']
            sentence_end = end_chars["end"].max()
            sentence_words = []

            for word_idx in curr_chars["word-idx"].unique():
                word_chars = curr_chars.loc[curr_chars["word-idx"] == word_idx]
                word_text = "".join(word_chars["char"].tolist()).strip()
                if len(word_text) == 0:
                    continue

                # don't use space character for alignment
                word_chars = word_chars[word_chars["char"] != " "]

                word_start = word_chars["start"].min()
                word_end = word_chars["end"].max()
                word_score = round(word_chars["score"].mean(), 3)

                # -1 indicates unalignable
                word_segment = {"word": word_text}

                if not np.isnan(word_start):
                    word_segment["start"] = word_start
                if not np.isnan(word_end):
                    word_segment["end"] = word_end
                if not np.isnan(word_score):
                    word_segment["score"] = word_score

                sentence_words.append(word_segment)

            aligned_subsegments.append({
                "text": sentence_text,
                "start": sentence_start,
                "end": sentence_end,
                "words": sentence_words,
            })

            if return_char_alignments:
                curr_chars = curr_chars[["char", "start", "end", "score"]]
                curr_chars.fillna(-1, inplace=True)
                curr_chars = curr_chars.to_dict("records")
                curr_chars = [{key: val for key, val in char.items() if val != -1} for char in curr_chars]
                aligned_subsegments[-1]["chars"] = curr_chars

        aligned_subsegments = pd.DataFrame(aligned_subsegments)
        aligned_subsegments["start"] = interpolate_nans(aligned_subsegments["start"], method=interpolate_method)
        aligned_subsegments["end"] = interpolate_nans(aligned_subsegments["end"], method=interpolate_method)
        # concatenate sentences with same timestamps
        agg_dict = {"text": " ".join, "words": "sum"}
        if model_lang in LANGUAGES_WITHOUT_SPACES:
            agg_dict["text"] = "".join
        if return_char_alignments:
            agg_dict["chars"] = "sum"
        aligned_subsegments = aligned_subsegments.groupby(["start", "end"], as_index=False).agg(agg_dict)
        aligned_subsegments = aligned_subsegments.to_dict('records')
        aligned_segments += aligned_subsegments

    # create word_segments list
    word_segments: List[SingleWordSegment] = []
    for segment in aligned_segments:
        word_segments += segment["words"]

    return {"segments": aligned_segments, "word_segments": word_segments}


"""
source: https://pytorch.org/tutorials/intermediate/forced_alignment_with_torchaudio_tutorial.html
"""


def get_trellis(emission, tokens, blank_id=0):
    num_frame = emission.size(0)
    num_tokens = len(tokens)

    trellis = torch.zeros((num_frame, num_tokens))
    trellis[1:, 0] = torch.cumsum(emission[1:, blank_id], 0)
    trellis[0, 1:] = -float("inf")
    trellis[-num_tokens + 1:, 0] = float("inf")

    for t in range(num_frame - 1):
        trellis[t + 1, 1:] = torch.maximum(
            # Score for staying at the same token
            trellis[t, 1:] + emission[t, blank_id],
            # Score for changing to the next token
            # trellis[t, :-1] + emission[t, tokens[1:]],
            trellis[t, :-1] + get_wildcard_emission(emission[t], tokens[1:], blank_id),
        )
    return trellis
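# Illustrative sketch (not part of the module): the trellis is a lattice of
# cumulative log-probabilities over (frame, character-position) pairs, where
# each step either stays on the current character (emitting blank) or advances
# to the next character. A toy call, with made-up shapes and token ids:
#
#   emission = torch.log_softmax(torch.randn(50, 32), dim=-1)  # (frames, vocab)
#   tokens = [7, 4, -1, 9]        # character ids; -1 marks the '*' wildcard
#   trellis = get_trellis(emission, tokens, blank_id=0)        # shape (50, 4)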
"""处理包含通配符的token emission分数(向量化版本) Args: frame_emission: 当前帧的emission概率向量 tokens: token索引列表 blank_id: blank token的ID Returns: tensor: 每个token位置的最大概率分数 """ assert 0 <= blank_id < len(frame_emission) # 将tokens转换为tensor(如果还不是的话) tokens = torch.tensor(tokens) if not isinstance(tokens, torch.Tensor) else tokens # 创建mask来标识通配符位置 wildcard_mask = (tokens == -1) # 为非通配符位置获取分数 regular_scores = frame_emission[tokens.clamp(min=0)] # clamp避免-1索引 # 创建掩码并计算最大值,不会修改frame_emission max_valid_score = frame_emission.clone() # 创建副本 max_valid_score[blank_id] = float('-inf') # 在副本上操作 max_valid_score = max_valid_score.max() # 使用where操作来组合结果 result = torch.where(wildcard_mask, max_valid_score, regular_scores) return result @dataclass class Point: token_index: int time_index: int score: float def backtrack(trellis, emission, tokens, blank_id=0): t, j = trellis.size(0) - 1, trellis.size(1) - 1 path = [Point(j, t, emission[t, blank_id].exp().item())] while j > 0: # Should not happen but just in case assert t > 0 # 1. Figure out if the current position was stay or change # Frame-wise score of stay vs change p_stay = emission[t - 1, blank_id] # p_change = emission[t - 1, tokens[j]] p_change = get_wildcard_emission(emission[t - 1], [tokens[j]], blank_id)[0] # Context-aware score for stay vs change stayed = trellis[t - 1, j] + p_stay changed = trellis[t - 1, j - 1] + p_change # Update position t -= 1 if changed > stayed: j -= 1 # Store the path with frame-wise probability. prob = (p_change if changed > stayed else p_stay).exp().item() path.append(Point(j, t, prob)) # Now j == 0, which means, it reached the SoS. # Fill up the rest for the sake of visualization while t > 0: prob = emission[t - 1, blank_id].exp().item() path.append(Point(j, t - 1, prob)) t -= 1 return path[::-1] @dataclass class Path: points: List[Point] score: float @dataclass class BeamState: """beam search中的状态""" token_index: int # 当前token位置 time_index: int # 当前时间步 score: float # 累积分数 path: List[Point] # 路径历史 def backtrack_beam(trellis, emission, tokens, blank_id=0, beam_width=5): """标准CTC beam search回溯实现 """ T, J = trellis.size(0) - 1, trellis.size(1) - 1 init_state = BeamState( token_index=J, time_index=T, score=trellis[T, J], path=[Point(J, T, emission[T, blank_id].exp().item())] ) beams = [init_state] while beams and beams[0].token_index > 0: next_beams = [] for beam in beams: t, j = beam.time_index, beam.token_index if t <= 0: continue p_stay = emission[t - 1, blank_id] p_change = get_wildcard_emission(emission[t - 1], [tokens[j]], blank_id)[0] stay_score = trellis[t - 1, j] change_score = trellis[t - 1, j - 1] if j > 0 else float('-inf') # Stay路径 if not math.isinf(stay_score): new_path = beam.path.copy() new_path.append(Point(j, t - 1, p_stay.exp().item())) next_beams.append(BeamState( token_index=j, time_index=t - 1, score=stay_score, path=new_path )) # Change路径 if j > 0 and not math.isinf(change_score): new_path = beam.path.copy() new_path.append(Point(j - 1, t - 1, p_change.exp().item())) next_beams.append(BeamState( token_index=j - 1, time_index=t - 1, score=change_score, path=new_path )) # 只按分数排序,不需要去重 beams = sorted(next_beams, key=lambda x: x.score, reverse=True)[:beam_width] if not beams: break if not beams: return None best_beam = beams[0] t = best_beam.time_index j = best_beam.token_index while t > 0: prob = emission[t - 1, blank_id].exp().item() best_beam.path.append(Point(j, t - 1, prob)) t -= 1 return best_beam.path[::-1] # Merge the labels @dataclass class Segment: label: str start: int end: int score: float def __repr__(self): return 
f"{self.label}\t({self.score:4.2f}): [{self.start:5d}, {self.end:5d})" @property def length(self): return self.end - self.start def merge_repeats(path, transcript): i1, i2 = 0, 0 segments = [] while i1 < len(path): while i2 < len(path) and path[i1].token_index == path[i2].token_index: i2 += 1 score = sum(path[k].score for k in range(i1, i2)) / (i2 - i1) segments.append( Segment( transcript[path[i1].token_index], path[i1].time_index, path[i2 - 1].time_index + 1, score, ) ) i1 = i2 return segments def merge_words(segments, separator="|"): words = [] i1, i2 = 0, 0 while i1 < len(segments): if i2 >= len(segments) or segments[i2].label == separator: if i1 != i2: segs = segments[i1:i2] word = "".join([seg.label for seg in segs]) score = sum(seg.score * seg.length for seg in segs) / sum(seg.length for seg in segs) words.append(Segment(word, segments[i1].start, segments[i2 - 1].end, score)) i1 = i2 + 1 i2 = i1 else: i2 += 1 return words