diff --git a/whisperx/alignment.py b/whisperx/alignment.py index c4750ca..0c4e2b0 100644 --- a/whisperx/alignment.py +++ b/whisperx/alignment.py @@ -2,6 +2,7 @@ Forced Alignment with Whisper C. Max Bain """ +import math from dataclasses import dataclass from typing import Iterable, Optional, Union, List @@ -171,10 +172,17 @@ def align( elif char_ in model_dictionary.keys(): clean_char.append(char_) clean_cdx.append(cdx) + else: + # add placeholder + clean_char.append('*') + clean_cdx.append(cdx) clean_wdx = [] for wdx, wrd in enumerate(per_word): - if any([c in model_dictionary.keys() for c in wrd]): + if any([c in model_dictionary.keys() for c in wrd.lower()]): + clean_wdx.append(wdx) + else: + # index for placeholder clean_wdx.append(wdx) @@ -222,7 +230,7 @@ def align( continue text_clean = "".join(segment_data[sdx]["clean_char"]) - tokens = [model_dictionary[c] for c in text_clean] + tokens = [model_dictionary.get(c, -1) for c in text_clean] f1 = int(t1 * SAMPLE_RATE) f2 = int(t2 * SAMPLE_RATE) @@ -255,7 +263,8 @@ def align( blank_id = code trellis = get_trellis(emission, tokens, blank_id) - path = backtrack(trellis, emission, tokens, blank_id) + # path = backtrack(trellis, emission, tokens, blank_id) + path = backtrack_beam(trellis, emission, tokens, blank_id, beam_width=2) if path is None: print(f'Failed to align segment ("{segment["text"]}"): backtrack failed, resorting to original...') @@ -264,7 +273,7 @@ def align( char_segments = merge_repeats(path, text_clean) - duration = t2 -t1 + duration = t2 - t1 ratio = duration * waveform_segment.size(0) / (trellis.size(0) - 1) # assign timestamps to aligned characters @@ -371,70 +380,203 @@ def align( """ source: https://pytorch.org/tutorials/intermediate/forced_alignment_with_torchaudio_tutorial.html """ + + def get_trellis(emission, tokens, blank_id=0): num_frame = emission.size(0) num_tokens = len(tokens) - # Trellis has extra diemsions for both time axis and tokens. - # The extra dim for tokens represents (start-of-sentence) - # The extra dim for time axis is for simplification of the code. - trellis = torch.empty((num_frame + 1, num_tokens + 1)) - trellis[0, 0] = 0 - trellis[1:, 0] = torch.cumsum(emission[:, 0], 0) - trellis[0, -num_tokens:] = -float("inf") - trellis[-num_tokens:, 0] = float("inf") + trellis = torch.zeros((num_frame, num_tokens)) + trellis[1:, 0] = torch.cumsum(emission[1:, blank_id], 0) + trellis[0, 1:] = -float("inf") + trellis[-num_tokens + 1:, 0] = float("inf") - for t in range(num_frame): + for t in range(num_frame - 1): trellis[t + 1, 1:] = torch.maximum( # Score for staying at the same token trellis[t, 1:] + emission[t, blank_id], # Score for changing to the next token - trellis[t, :-1] + emission[t, tokens], + # trellis[t, :-1] + emission[t, tokens[1:]], + trellis[t, :-1] + get_wildcard_emission(emission[t], tokens[1:], blank_id), ) return trellis + +def get_wildcard_emission(frame_emission, tokens, blank_id): + """Processing token emission scores containing wildcards (vectorized version) + + Args: + frame_emission: Emission probability vector for the current frame + tokens: List of token indices + blank_id: ID of the blank token + + Returns: + tensor: Maximum probability score for each token position + """ + assert 0 <= blank_id < len(frame_emission) + + # Convert tokens to a tensor if they are not already + tokens = torch.tensor(tokens) if not isinstance(tokens, torch.Tensor) else tokens + + # Create a mask to identify wildcard positions + wildcard_mask = (tokens == -1) + + # Get scores for non-wildcard positions + regular_scores = frame_emission[tokens.clamp(min=0)] # clamp to avoid -1 index + + # Create a mask and compute the maximum value without modifying frame_emission + max_valid_score = frame_emission.clone() # Create a copy + max_valid_score[blank_id] = float('-inf') # Modify the copy to exclude the blank token + max_valid_score = max_valid_score.max() + + # Use where operation to combine results + result = torch.where(wildcard_mask, max_valid_score, regular_scores) + + return result + + @dataclass class Point: token_index: int time_index: int score: float + def backtrack(trellis, emission, tokens, blank_id=0): - # Note: - # j and t are indices for trellis, which has extra dimensions - # for time and tokens at the beginning. - # When referring to time frame index `T` in trellis, - # the corresponding index in emission is `T-1`. - # Similarly, when referring to token index `J` in trellis, - # the corresponding index in transcript is `J-1`. - j = trellis.size(1) - 1 - t_start = torch.argmax(trellis[:, j]).item() + t, j = trellis.size(0) - 1, trellis.size(1) - 1 + + path = [Point(j, t, emission[t, blank_id].exp().item())] + while j > 0: + # Should not happen but just in case + assert t > 0 - path = [] - for t in range(t_start, 0, -1): # 1. Figure out if the current position was stay or change - # Note (again): - # `emission[J-1]` is the emission at time frame `J` of trellis dimension. - # Score for token staying the same from time frame J-1 to T. - stayed = trellis[t - 1, j] + emission[t - 1, blank_id] - # Score for token changing from C-1 at T-1 to J at T. - changed = trellis[t - 1, j - 1] + emission[t - 1, tokens[j - 1]] + # Frame-wise score of stay vs change + p_stay = emission[t - 1, blank_id] + # p_change = emission[t - 1, tokens[j]] + p_change = get_wildcard_emission(emission[t - 1], [tokens[j]], blank_id)[0] - # 2. Store the path with frame-wise probability. - prob = emission[t - 1, tokens[j - 1] if changed > stayed else 0].exp().item() - # Return token index and time index in non-trellis coordinate. - path.append(Point(j - 1, t - 1, prob)) + # Context-aware score for stay vs change + stayed = trellis[t - 1, j] + p_stay + changed = trellis[t - 1, j - 1] + p_change - # 3. Update the token + # Update position + t -= 1 if changed > stayed: j -= 1 - if j == 0: - break - else: - # failed - return None + + # Store the path with frame-wise probability. + prob = (p_change if changed > stayed else p_stay).exp().item() + path.append(Point(j, t, prob)) + + # Now j == 0, which means, it reached the SoS. + # Fill up the rest for the sake of visualization + while t > 0: + prob = emission[t - 1, blank_id].exp().item() + path.append(Point(j, t - 1, prob)) + t -= 1 + return path[::-1] + + +@dataclass +class Path: + points: List[Point] + score: float + + +@dataclass +class BeamState: + """State in beam search.""" + token_index: int # Current token position + time_index: int # Current time step + score: float # Cumulative score + path: List[Point] # Path history + + +def backtrack_beam(trellis, emission, tokens, blank_id=0, beam_width=5): + """Standard CTC beam search backtracking implementation. + + Args: + trellis (torch.Tensor): The trellis (or lattice) of shape (T, N), where T is the number of time steps + and N is the number of tokens (including the blank token). + emission (torch.Tensor): The emission probabilities of shape (T, N). + tokens (List[int]): List of token indices (excluding the blank token). + blank_id (int, optional): The ID of the blank token. Defaults to 0. + beam_width (int, optional): The number of top paths to keep during beam search. Defaults to 5. + + Returns: + List[Point]: the best path + """ + T, J = trellis.size(0) - 1, trellis.size(1) - 1 + + init_state = BeamState( + token_index=J, + time_index=T, + score=trellis[T, J], + path=[Point(J, T, emission[T, blank_id].exp().item())] + ) + + beams = [init_state] + + while beams and beams[0].token_index > 0: + next_beams = [] + + for beam in beams: + t, j = beam.time_index, beam.token_index + + if t <= 0: + continue + + p_stay = emission[t - 1, blank_id] + p_change = get_wildcard_emission(emission[t - 1], [tokens[j]], blank_id)[0] + + stay_score = trellis[t - 1, j] + change_score = trellis[t - 1, j - 1] if j > 0 else float('-inf') + + # Stay + if not math.isinf(stay_score): + new_path = beam.path.copy() + new_path.append(Point(j, t - 1, p_stay.exp().item())) + next_beams.append(BeamState( + token_index=j, + time_index=t - 1, + score=stay_score, + path=new_path + )) + + # Change + if j > 0 and not math.isinf(change_score): + new_path = beam.path.copy() + new_path.append(Point(j - 1, t - 1, p_change.exp().item())) + next_beams.append(BeamState( + token_index=j - 1, + time_index=t - 1, + score=change_score, + path=new_path + )) + + # sort by score + beams = sorted(next_beams, key=lambda x: x.score, reverse=True)[:beam_width] + + if not beams: + break + + if not beams: + return None + + best_beam = beams[0] + t = best_beam.time_index + j = best_beam.token_index + while t > 0: + prob = emission[t - 1, blank_id].exp().item() + best_beam.path.append(Point(j, t - 1, prob)) + t -= 1 + + return best_beam.path[::-1] + + # Merge the labels @dataclass class Segment: