import math

from dataclasses import dataclass
from typing import Iterable, Optional, Union, List


"""
source: https://pytorch.org/tutorials/intermediate/forced_alignment_with_torchaudio_tutorial.html
"""


def get_trellis(emission, tokens, blank_id=0):
    """Build the forward-score trellis used for forced alignment.

    Cell ``[t, j]`` holds the best accumulated score of being on token ``j``
    at frame ``t``. Staying on a token is scored with the blank emission,
    advancing to the next token with that token's emission (or the best
    non-blank emission when the token is the ``-1`` wildcard).

    Args:
        emission: 2-D tensor indexed ``[frame, label]``. The values are
            added together here and exponentiated by the backtracking
            functions, i.e. they are treated as log-probabilities.
        tokens: Sequence of label indices; ``-1`` marks a wildcard.
        blank_id: Index of the blank label in ``emission``.

    Returns:
        A ``(num_frame, num_tokens)`` tensor of accumulated scores.
    """
    num_frame = emission.size(0)
    num_tokens = len(tokens)

    # Allocate on the emission's device so the slice assignments below do not
    # mix devices.
    trellis = torch.zeros((num_frame, num_tokens), device=emission.device)
    trellis[1:, 0] = torch.cumsum(emission[1:, blank_id], 0)
    trellis[0, 1:] = -float("inf")
    if num_tokens > 1:
        # The last ``num_tokens - 1`` frames of the first column are marked
        # +inf (same boundary condition as the referenced tutorial). With a
        # single token the slice would start at 0 and wipe out the whole
        # column, so it is skipped in that case.
        trellis[-num_tokens + 1:, 0] = float("inf")

    # Convert once instead of on every frame.
    next_tokens = torch.as_tensor(
        tokens[1:], dtype=torch.long, device=emission.device
    )
    for t in range(num_frame - 1):
        trellis[t + 1, 1:] = torch.maximum(
            # Score for staying at the same token
            trellis[t, 1:] + emission[t, blank_id],
            # Score for changing to the next token
            trellis[t, :-1]
            + get_wildcard_emission(emission[t], next_tokens, blank_id),
        )
    return trellis


def get_wildcard_emission(frame_emission, tokens, blank_id):
    """Look up per-token emission scores for one frame, honouring wildcards.

    Args:
        frame_emission: 1-D emission vector of the current frame.
        tokens: Label indices (list or tensor). ``-1`` is a wildcard that
            matches any non-blank label.
        blank_id: Index of the blank label; excluded from wildcard matching.

    Returns:
        1-D tensor with one score per entry of ``tokens``: the emission of
        that label, or the maximum non-blank emission for a wildcard.

    Raises:
        ValueError: If ``blank_id`` is not a valid index into
            ``frame_emission``.
    """
    if not 0 <= blank_id < len(frame_emission):
        raise ValueError(
            f"blank_id {blank_id} is out of range for an emission vector "
            f"of length {len(frame_emission)}"
        )

    token_ids = torch.as_tensor(
        tokens, dtype=torch.long, device=frame_emission.device
    )
    is_wildcard = token_ids == -1

    # Replace the -1 sentinel before indexing so it does not pick the last
    # label; the looked-up value at those positions is discarded below.
    regular_scores = frame_emission[token_ids.masked_fill(is_wildcard, 0)]

    # Best non-blank score, computed on a copy so the caller's tensor is
    # left untouched.
    non_blank = frame_emission.clone()
    non_blank[blank_id] = float("-inf")

    return torch.where(is_wildcard, non_blank.max(), regular_scores)


@dataclass
class Point:
    # Index of the token this frame is aligned to.
    token_index: int
    # Index of the frame.
    time_index: int
    # Frame-wise probability (exponentiated emission) of the chosen step.
    score: float


def backtrack(trellis, emission, tokens, blank_id=0) -> Optional[List[Point]]:
    """Recover the best alignment path by walking the trellis backwards.

    Args:
        trellis: Score matrix produced by ``get_trellis``.
        emission: The emission matrix the trellis was built from.
        tokens: Label indices; ``-1`` marks a wildcard.
        blank_id: Index of the blank label.

    Returns:
        One ``Point`` per frame in chronological order, or ``None`` when the
        first frame is reached before the first token (no valid alignment).
    """
    t, j = trellis.size(0) - 1, trellis.size(1) - 1

    path = [Point(j, t, emission[t, blank_id].exp().item())]
    while j > 0:
        if t <= 0:
            # Out of frames with tokens still unassigned. Report the failure
            # as None so the caller can fall back, instead of crashing.
            return None

        # Frame-wise score of stay vs change
        p_stay = emission[t - 1, blank_id]
        p_change = get_wildcard_emission(
            emission[t - 1], [tokens[j]], blank_id
        )[0]

        # Context-aware score for stay vs change
        stayed = trellis[t - 1, j] + p_stay
        changed = trellis[t - 1, j - 1] + p_change
        took_change = bool(changed > stayed)

        # Update position
        t -= 1
        if took_change:
            j -= 1

        # Store the path with frame-wise probability.
        prob = (p_change if took_change else p_stay).exp().item()
        path.append(Point(j, t, prob))

    # Now j == 0, which means, it reached the SoS.
    # Fill up the rest for the sake of visualization
    while t > 0:
        prob = emission[t - 1, blank_id].exp().item()
        path.append(Point(j, t - 1, prob))
        t -= 1

    return path[::-1]


@dataclass
class Path:
    # Points making up the path.
    points: List[Point]
    # Score associated with the path.
    score: float


@dataclass
class BeamState:
    """A single hypothesis tracked during beam-search backtracking."""
    token_index: int  # current token position
    time_index: int  # current frame
    score: float  # trellis score of the current cell
    path: List[Point]  # points visited so far, latest frame first


def backtrack_beam(
    trellis, emission, tokens, blank_id=0, beam_width=5
) -> Optional[List[Point]]:
    """Recover an alignment path with a beam search over the trellis.

    Starting from the last frame and last token, every hypothesis is extended
    one frame backwards by either staying on its token or moving to the
    previous one. Only the ``beam_width`` best hypotheses are kept. The
    search stops once the best hypothesis reaches the first token.

    Args:
        trellis: Score matrix produced by ``get_trellis``.
        emission: The emission matrix the trellis was built from.
        tokens: Label indices; ``-1`` marks a wildcard.
        blank_id: Index of the blank label.
        beam_width: Maximum number of hypotheses kept per step.

    Returns:
        One ``Point`` per frame in chronological order, or ``None`` when no
        hypothesis survives (no valid alignment).
    """
    T, J = trellis.size(0) - 1, trellis.size(1) - 1

    beams = [
        BeamState(
            token_index=J,
            time_index=T,
            score=trellis[T, J].item(),
            path=[Point(J, T, emission[T, blank_id].exp().item())],
        )
    ]

    while beams and beams[0].token_index > 0:
        next_beams = []

        for beam in beams:
            t, j = beam.time_index, beam.token_index

            if t <= 0:
                # No frames left for this hypothesis.
                continue

            # NOTE(review): hypotheses are ranked by the forward trellis
            # value of the cell they move to; the emission of the step
            # itself is only recorded in the Point. The result can therefore
            # differ from `backtrack` - confirm this is intended.

            # Stay on the current token.
            stay_score = trellis[t - 1, j].item()
            if not math.isinf(stay_score):
                p_stay = emission[t - 1, blank_id]
                next_beams.append(BeamState(
                    token_index=j,
                    time_index=t - 1,
                    score=stay_score,
                    path=beam.path + [Point(j, t - 1, p_stay.exp().item())],
                ))

            # Move to the previous token.
            if j > 0:
                change_score = trellis[t - 1, j - 1].item()
                if not math.isinf(change_score):
                    p_change = get_wildcard_emission(
                        emission[t - 1], [tokens[j]], blank_id
                    )[0]
                    next_beams.append(BeamState(
                        token_index=j - 1,
                        time_index=t - 1,
                        score=change_score,
                        path=beam.path
                        + [Point(j - 1, t - 1, p_change.exp().item())],
                    ))

        # Rank by score only; hypotheses are not de-duplicated.
        beams = sorted(
            next_beams, key=lambda state: state.score, reverse=True
        )[:beam_width]

    if not beams:
        # Report failure as None so the caller can fall back, instead of
        # raising.
        return None

    best_beam = beams[0]
    t = best_beam.time_index
    j = best_beam.token_index
    # Pad the remaining leading frames with the current token.
    while t > 0:
        prob = emission[t - 1, blank_id].exp().item()
        best_beam.path.append(Point(j, t - 1, prob))
        t -= 1

    return best_beam.path[::-1]