mirror of
https://github.com/m-bain/whisperX.git
synced 2025-07-01 18:17:27 -04:00
@ -2,6 +2,7 @@
|
||||
Forced Alignment with Whisper
|
||||
C. Max Bain
|
||||
"""
|
||||
import math
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Iterable, Optional, Union, List
|
||||
@ -171,10 +172,17 @@ def align(
|
||||
elif char_ in model_dictionary.keys():
|
||||
clean_char.append(char_)
|
||||
clean_cdx.append(cdx)
|
||||
else:
|
||||
# add placeholder
|
||||
clean_char.append('*')
|
||||
clean_cdx.append(cdx)
|
||||
|
||||
clean_wdx = []
|
||||
for wdx, wrd in enumerate(per_word):
|
||||
if any([c in model_dictionary.keys() for c in wrd]):
|
||||
if any([c in model_dictionary.keys() for c in wrd.lower()]):
|
||||
clean_wdx.append(wdx)
|
||||
else:
|
||||
# index for placeholder
|
||||
clean_wdx.append(wdx)
|
||||
|
||||
|
||||
@ -222,7 +230,7 @@ def align(
|
||||
continue
|
||||
|
||||
text_clean = "".join(segment_data[sdx]["clean_char"])
|
||||
tokens = [model_dictionary[c] for c in text_clean]
|
||||
tokens = [model_dictionary.get(c, -1) for c in text_clean]
|
||||
|
||||
f1 = int(t1 * SAMPLE_RATE)
|
||||
f2 = int(t2 * SAMPLE_RATE)
|
||||
@ -255,7 +263,8 @@ def align(
|
||||
blank_id = code
|
||||
|
||||
trellis = get_trellis(emission, tokens, blank_id)
|
||||
path = backtrack(trellis, emission, tokens, blank_id)
|
||||
# path = backtrack(trellis, emission, tokens, blank_id)
|
||||
path = backtrack_beam(trellis, emission, tokens, blank_id, beam_width=2)
|
||||
|
||||
if path is None:
|
||||
print(f'Failed to align segment ("{segment["text"]}"): backtrack failed, resorting to original...')
|
||||
@ -264,7 +273,7 @@ def align(
|
||||
|
||||
char_segments = merge_repeats(path, text_clean)
|
||||
|
||||
duration = t2 -t1
|
||||
duration = t2 - t1
|
||||
ratio = duration * waveform_segment.size(0) / (trellis.size(0) - 1)
|
||||
|
||||
# assign timestamps to aligned characters
|
||||
@ -371,70 +380,203 @@ def align(
|
||||
"""
|
||||
source: https://pytorch.org/tutorials/intermediate/forced_alignment_with_torchaudio_tutorial.html
|
||||
"""
|
||||
|
||||
|
||||
def get_trellis(emission, tokens, blank_id=0):
|
||||
num_frame = emission.size(0)
|
||||
num_tokens = len(tokens)
|
||||
|
||||
# Trellis has extra diemsions for both time axis and tokens.
|
||||
# The extra dim for tokens represents <SoS> (start-of-sentence)
|
||||
# The extra dim for time axis is for simplification of the code.
|
||||
trellis = torch.empty((num_frame + 1, num_tokens + 1))
|
||||
trellis[0, 0] = 0
|
||||
trellis[1:, 0] = torch.cumsum(emission[:, 0], 0)
|
||||
trellis[0, -num_tokens:] = -float("inf")
|
||||
trellis[-num_tokens:, 0] = float("inf")
|
||||
trellis = torch.zeros((num_frame, num_tokens))
|
||||
trellis[1:, 0] = torch.cumsum(emission[1:, blank_id], 0)
|
||||
trellis[0, 1:] = -float("inf")
|
||||
trellis[-num_tokens + 1:, 0] = float("inf")
|
||||
|
||||
for t in range(num_frame):
|
||||
for t in range(num_frame - 1):
|
||||
trellis[t + 1, 1:] = torch.maximum(
|
||||
# Score for staying at the same token
|
||||
trellis[t, 1:] + emission[t, blank_id],
|
||||
# Score for changing to the next token
|
||||
trellis[t, :-1] + emission[t, tokens],
|
||||
# trellis[t, :-1] + emission[t, tokens[1:]],
|
||||
trellis[t, :-1] + get_wildcard_emission(emission[t], tokens[1:], blank_id),
|
||||
)
|
||||
return trellis
|
||||
|
||||
|
||||
def get_wildcard_emission(frame_emission, tokens, blank_id):
|
||||
"""Processing token emission scores containing wildcards (vectorized version)
|
||||
|
||||
Args:
|
||||
frame_emission: Emission probability vector for the current frame
|
||||
tokens: List of token indices
|
||||
blank_id: ID of the blank token
|
||||
|
||||
Returns:
|
||||
tensor: Maximum probability score for each token position
|
||||
"""
|
||||
assert 0 <= blank_id < len(frame_emission)
|
||||
|
||||
# Convert tokens to a tensor if they are not already
|
||||
tokens = torch.tensor(tokens) if not isinstance(tokens, torch.Tensor) else tokens
|
||||
|
||||
# Create a mask to identify wildcard positions
|
||||
wildcard_mask = (tokens == -1)
|
||||
|
||||
# Get scores for non-wildcard positions
|
||||
regular_scores = frame_emission[tokens.clamp(min=0)] # clamp to avoid -1 index
|
||||
|
||||
# Create a mask and compute the maximum value without modifying frame_emission
|
||||
max_valid_score = frame_emission.clone() # Create a copy
|
||||
max_valid_score[blank_id] = float('-inf') # Modify the copy to exclude the blank token
|
||||
max_valid_score = max_valid_score.max()
|
||||
|
||||
# Use where operation to combine results
|
||||
result = torch.where(wildcard_mask, max_valid_score, regular_scores)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@dataclass
|
||||
class Point:
|
||||
token_index: int
|
||||
time_index: int
|
||||
score: float
|
||||
|
||||
|
||||
def backtrack(trellis, emission, tokens, blank_id=0):
|
||||
# Note:
|
||||
# j and t are indices for trellis, which has extra dimensions
|
||||
# for time and tokens at the beginning.
|
||||
# When referring to time frame index `T` in trellis,
|
||||
# the corresponding index in emission is `T-1`.
|
||||
# Similarly, when referring to token index `J` in trellis,
|
||||
# the corresponding index in transcript is `J-1`.
|
||||
j = trellis.size(1) - 1
|
||||
t_start = torch.argmax(trellis[:, j]).item()
|
||||
t, j = trellis.size(0) - 1, trellis.size(1) - 1
|
||||
|
||||
path = [Point(j, t, emission[t, blank_id].exp().item())]
|
||||
while j > 0:
|
||||
# Should not happen but just in case
|
||||
assert t > 0
|
||||
|
||||
path = []
|
||||
for t in range(t_start, 0, -1):
|
||||
# 1. Figure out if the current position was stay or change
|
||||
# Note (again):
|
||||
# `emission[J-1]` is the emission at time frame `J` of trellis dimension.
|
||||
# Score for token staying the same from time frame J-1 to T.
|
||||
stayed = trellis[t - 1, j] + emission[t - 1, blank_id]
|
||||
# Score for token changing from C-1 at T-1 to J at T.
|
||||
changed = trellis[t - 1, j - 1] + emission[t - 1, tokens[j - 1]]
|
||||
# Frame-wise score of stay vs change
|
||||
p_stay = emission[t - 1, blank_id]
|
||||
# p_change = emission[t - 1, tokens[j]]
|
||||
p_change = get_wildcard_emission(emission[t - 1], [tokens[j]], blank_id)[0]
|
||||
|
||||
# 2. Store the path with frame-wise probability.
|
||||
prob = emission[t - 1, tokens[j - 1] if changed > stayed else 0].exp().item()
|
||||
# Return token index and time index in non-trellis coordinate.
|
||||
path.append(Point(j - 1, t - 1, prob))
|
||||
# Context-aware score for stay vs change
|
||||
stayed = trellis[t - 1, j] + p_stay
|
||||
changed = trellis[t - 1, j - 1] + p_change
|
||||
|
||||
# 3. Update the token
|
||||
# Update position
|
||||
t -= 1
|
||||
if changed > stayed:
|
||||
j -= 1
|
||||
if j == 0:
|
||||
break
|
||||
else:
|
||||
# failed
|
||||
return None
|
||||
|
||||
# Store the path with frame-wise probability.
|
||||
prob = (p_change if changed > stayed else p_stay).exp().item()
|
||||
path.append(Point(j, t, prob))
|
||||
|
||||
# Now j == 0, which means, it reached the SoS.
|
||||
# Fill up the rest for the sake of visualization
|
||||
while t > 0:
|
||||
prob = emission[t - 1, blank_id].exp().item()
|
||||
path.append(Point(j, t - 1, prob))
|
||||
t -= 1
|
||||
|
||||
return path[::-1]
|
||||
|
||||
|
||||
|
||||
@dataclass
|
||||
class Path:
|
||||
points: List[Point]
|
||||
score: float
|
||||
|
||||
|
||||
@dataclass
|
||||
class BeamState:
|
||||
"""State in beam search."""
|
||||
token_index: int # Current token position
|
||||
time_index: int # Current time step
|
||||
score: float # Cumulative score
|
||||
path: List[Point] # Path history
|
||||
|
||||
|
||||
def backtrack_beam(trellis, emission, tokens, blank_id=0, beam_width=5):
|
||||
"""Standard CTC beam search backtracking implementation.
|
||||
|
||||
Args:
|
||||
trellis (torch.Tensor): The trellis (or lattice) of shape (T, N), where T is the number of time steps
|
||||
and N is the number of tokens (including the blank token).
|
||||
emission (torch.Tensor): The emission probabilities of shape (T, N).
|
||||
tokens (List[int]): List of token indices (excluding the blank token).
|
||||
blank_id (int, optional): The ID of the blank token. Defaults to 0.
|
||||
beam_width (int, optional): The number of top paths to keep during beam search. Defaults to 5.
|
||||
|
||||
Returns:
|
||||
List[Point]: the best path
|
||||
"""
|
||||
T, J = trellis.size(0) - 1, trellis.size(1) - 1
|
||||
|
||||
init_state = BeamState(
|
||||
token_index=J,
|
||||
time_index=T,
|
||||
score=trellis[T, J],
|
||||
path=[Point(J, T, emission[T, blank_id].exp().item())]
|
||||
)
|
||||
|
||||
beams = [init_state]
|
||||
|
||||
while beams and beams[0].token_index > 0:
|
||||
next_beams = []
|
||||
|
||||
for beam in beams:
|
||||
t, j = beam.time_index, beam.token_index
|
||||
|
||||
if t <= 0:
|
||||
continue
|
||||
|
||||
p_stay = emission[t - 1, blank_id]
|
||||
p_change = get_wildcard_emission(emission[t - 1], [tokens[j]], blank_id)[0]
|
||||
|
||||
stay_score = trellis[t - 1, j]
|
||||
change_score = trellis[t - 1, j - 1] if j > 0 else float('-inf')
|
||||
|
||||
# Stay
|
||||
if not math.isinf(stay_score):
|
||||
new_path = beam.path.copy()
|
||||
new_path.append(Point(j, t - 1, p_stay.exp().item()))
|
||||
next_beams.append(BeamState(
|
||||
token_index=j,
|
||||
time_index=t - 1,
|
||||
score=stay_score,
|
||||
path=new_path
|
||||
))
|
||||
|
||||
# Change
|
||||
if j > 0 and not math.isinf(change_score):
|
||||
new_path = beam.path.copy()
|
||||
new_path.append(Point(j - 1, t - 1, p_change.exp().item()))
|
||||
next_beams.append(BeamState(
|
||||
token_index=j - 1,
|
||||
time_index=t - 1,
|
||||
score=change_score,
|
||||
path=new_path
|
||||
))
|
||||
|
||||
# sort by score
|
||||
beams = sorted(next_beams, key=lambda x: x.score, reverse=True)[:beam_width]
|
||||
|
||||
if not beams:
|
||||
break
|
||||
|
||||
if not beams:
|
||||
return None
|
||||
|
||||
best_beam = beams[0]
|
||||
t = best_beam.time_index
|
||||
j = best_beam.token_index
|
||||
while t > 0:
|
||||
prob = emission[t - 1, blank_id].exp().item()
|
||||
best_beam.path.append(Point(j, t - 1, prob))
|
||||
t -= 1
|
||||
|
||||
return best_beam.path[::-1]
|
||||
|
||||
|
||||
# Merge the labels
|
||||
@dataclass
|
||||
class Segment:
|
||||
|
Reference in New Issue
Block a user