mirror of
https://github.com/m-bain/whisperX.git
synced 2025-07-01 18:17:27 -04:00
@ -2,6 +2,7 @@
|
|||||||
Forced Alignment with Whisper
|
Forced Alignment with Whisper
|
||||||
C. Max Bain
|
C. Max Bain
|
||||||
"""
|
"""
|
||||||
|
import math
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Iterable, Optional, Union, List
|
from typing import Iterable, Optional, Union, List
|
||||||
@ -171,10 +172,17 @@ def align(
|
|||||||
elif char_ in model_dictionary.keys():
|
elif char_ in model_dictionary.keys():
|
||||||
clean_char.append(char_)
|
clean_char.append(char_)
|
||||||
clean_cdx.append(cdx)
|
clean_cdx.append(cdx)
|
||||||
|
else:
|
||||||
|
# add placeholder
|
||||||
|
clean_char.append('*')
|
||||||
|
clean_cdx.append(cdx)
|
||||||
|
|
||||||
clean_wdx = []
|
clean_wdx = []
|
||||||
for wdx, wrd in enumerate(per_word):
|
for wdx, wrd in enumerate(per_word):
|
||||||
if any([c in model_dictionary.keys() for c in wrd]):
|
if any([c in model_dictionary.keys() for c in wrd.lower()]):
|
||||||
|
clean_wdx.append(wdx)
|
||||||
|
else:
|
||||||
|
# index for placeholder
|
||||||
clean_wdx.append(wdx)
|
clean_wdx.append(wdx)
|
||||||
|
|
||||||
|
|
||||||
@ -222,7 +230,7 @@ def align(
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
text_clean = "".join(segment_data[sdx]["clean_char"])
|
text_clean = "".join(segment_data[sdx]["clean_char"])
|
||||||
tokens = [model_dictionary[c] for c in text_clean]
|
tokens = [model_dictionary.get(c, -1) for c in text_clean]
|
||||||
|
|
||||||
f1 = int(t1 * SAMPLE_RATE)
|
f1 = int(t1 * SAMPLE_RATE)
|
||||||
f2 = int(t2 * SAMPLE_RATE)
|
f2 = int(t2 * SAMPLE_RATE)
|
||||||
@ -255,7 +263,8 @@ def align(
|
|||||||
blank_id = code
|
blank_id = code
|
||||||
|
|
||||||
trellis = get_trellis(emission, tokens, blank_id)
|
trellis = get_trellis(emission, tokens, blank_id)
|
||||||
path = backtrack(trellis, emission, tokens, blank_id)
|
# path = backtrack(trellis, emission, tokens, blank_id)
|
||||||
|
path = backtrack_beam(trellis, emission, tokens, blank_id, beam_width=2)
|
||||||
|
|
||||||
if path is None:
|
if path is None:
|
||||||
print(f'Failed to align segment ("{segment["text"]}"): backtrack failed, resorting to original...')
|
print(f'Failed to align segment ("{segment["text"]}"): backtrack failed, resorting to original...')
|
||||||
@ -264,7 +273,7 @@ def align(
|
|||||||
|
|
||||||
char_segments = merge_repeats(path, text_clean)
|
char_segments = merge_repeats(path, text_clean)
|
||||||
|
|
||||||
duration = t2 -t1
|
duration = t2 - t1
|
||||||
ratio = duration * waveform_segment.size(0) / (trellis.size(0) - 1)
|
ratio = duration * waveform_segment.size(0) / (trellis.size(0) - 1)
|
||||||
|
|
||||||
# assign timestamps to aligned characters
|
# assign timestamps to aligned characters
|
||||||
@ -371,70 +380,203 @@ def align(
|
|||||||
"""
|
"""
|
||||||
source: https://pytorch.org/tutorials/intermediate/forced_alignment_with_torchaudio_tutorial.html
|
source: https://pytorch.org/tutorials/intermediate/forced_alignment_with_torchaudio_tutorial.html
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
def get_trellis(emission, tokens, blank_id=0):
|
def get_trellis(emission, tokens, blank_id=0):
|
||||||
num_frame = emission.size(0)
|
num_frame = emission.size(0)
|
||||||
num_tokens = len(tokens)
|
num_tokens = len(tokens)
|
||||||
|
|
||||||
# Trellis has extra diemsions for both time axis and tokens.
|
trellis = torch.zeros((num_frame, num_tokens))
|
||||||
# The extra dim for tokens represents <SoS> (start-of-sentence)
|
trellis[1:, 0] = torch.cumsum(emission[1:, blank_id], 0)
|
||||||
# The extra dim for time axis is for simplification of the code.
|
trellis[0, 1:] = -float("inf")
|
||||||
trellis = torch.empty((num_frame + 1, num_tokens + 1))
|
trellis[-num_tokens + 1:, 0] = float("inf")
|
||||||
trellis[0, 0] = 0
|
|
||||||
trellis[1:, 0] = torch.cumsum(emission[:, 0], 0)
|
|
||||||
trellis[0, -num_tokens:] = -float("inf")
|
|
||||||
trellis[-num_tokens:, 0] = float("inf")
|
|
||||||
|
|
||||||
for t in range(num_frame):
|
for t in range(num_frame - 1):
|
||||||
trellis[t + 1, 1:] = torch.maximum(
|
trellis[t + 1, 1:] = torch.maximum(
|
||||||
# Score for staying at the same token
|
# Score for staying at the same token
|
||||||
trellis[t, 1:] + emission[t, blank_id],
|
trellis[t, 1:] + emission[t, blank_id],
|
||||||
# Score for changing to the next token
|
# Score for changing to the next token
|
||||||
trellis[t, :-1] + emission[t, tokens],
|
# trellis[t, :-1] + emission[t, tokens[1:]],
|
||||||
|
trellis[t, :-1] + get_wildcard_emission(emission[t], tokens[1:], blank_id),
|
||||||
)
|
)
|
||||||
return trellis
|
return trellis
|
||||||
|
|
||||||
|
|
||||||
|
def get_wildcard_emission(frame_emission, tokens, blank_id):
|
||||||
|
"""Processing token emission scores containing wildcards (vectorized version)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
frame_emission: Emission probability vector for the current frame
|
||||||
|
tokens: List of token indices
|
||||||
|
blank_id: ID of the blank token
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tensor: Maximum probability score for each token position
|
||||||
|
"""
|
||||||
|
assert 0 <= blank_id < len(frame_emission)
|
||||||
|
|
||||||
|
# Convert tokens to a tensor if they are not already
|
||||||
|
tokens = torch.tensor(tokens) if not isinstance(tokens, torch.Tensor) else tokens
|
||||||
|
|
||||||
|
# Create a mask to identify wildcard positions
|
||||||
|
wildcard_mask = (tokens == -1)
|
||||||
|
|
||||||
|
# Get scores for non-wildcard positions
|
||||||
|
regular_scores = frame_emission[tokens.clamp(min=0)] # clamp to avoid -1 index
|
||||||
|
|
||||||
|
# Create a mask and compute the maximum value without modifying frame_emission
|
||||||
|
max_valid_score = frame_emission.clone() # Create a copy
|
||||||
|
max_valid_score[blank_id] = float('-inf') # Modify the copy to exclude the blank token
|
||||||
|
max_valid_score = max_valid_score.max()
|
||||||
|
|
||||||
|
# Use where operation to combine results
|
||||||
|
result = torch.where(wildcard_mask, max_valid_score, regular_scores)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Point:
|
class Point:
|
||||||
token_index: int
|
token_index: int
|
||||||
time_index: int
|
time_index: int
|
||||||
score: float
|
score: float
|
||||||
|
|
||||||
|
|
||||||
def backtrack(trellis, emission, tokens, blank_id=0):
|
def backtrack(trellis, emission, tokens, blank_id=0):
|
||||||
# Note:
|
t, j = trellis.size(0) - 1, trellis.size(1) - 1
|
||||||
# j and t are indices for trellis, which has extra dimensions
|
|
||||||
# for time and tokens at the beginning.
|
path = [Point(j, t, emission[t, blank_id].exp().item())]
|
||||||
# When referring to time frame index `T` in trellis,
|
while j > 0:
|
||||||
# the corresponding index in emission is `T-1`.
|
# Should not happen but just in case
|
||||||
# Similarly, when referring to token index `J` in trellis,
|
assert t > 0
|
||||||
# the corresponding index in transcript is `J-1`.
|
|
||||||
j = trellis.size(1) - 1
|
|
||||||
t_start = torch.argmax(trellis[:, j]).item()
|
|
||||||
|
|
||||||
path = []
|
|
||||||
for t in range(t_start, 0, -1):
|
|
||||||
# 1. Figure out if the current position was stay or change
|
# 1. Figure out if the current position was stay or change
|
||||||
# Note (again):
|
# Frame-wise score of stay vs change
|
||||||
# `emission[J-1]` is the emission at time frame `J` of trellis dimension.
|
p_stay = emission[t - 1, blank_id]
|
||||||
# Score for token staying the same from time frame J-1 to T.
|
# p_change = emission[t - 1, tokens[j]]
|
||||||
stayed = trellis[t - 1, j] + emission[t - 1, blank_id]
|
p_change = get_wildcard_emission(emission[t - 1], [tokens[j]], blank_id)[0]
|
||||||
# Score for token changing from C-1 at T-1 to J at T.
|
|
||||||
changed = trellis[t - 1, j - 1] + emission[t - 1, tokens[j - 1]]
|
|
||||||
|
|
||||||
# 2. Store the path with frame-wise probability.
|
# Context-aware score for stay vs change
|
||||||
prob = emission[t - 1, tokens[j - 1] if changed > stayed else 0].exp().item()
|
stayed = trellis[t - 1, j] + p_stay
|
||||||
# Return token index and time index in non-trellis coordinate.
|
changed = trellis[t - 1, j - 1] + p_change
|
||||||
path.append(Point(j - 1, t - 1, prob))
|
|
||||||
|
|
||||||
# 3. Update the token
|
# Update position
|
||||||
|
t -= 1
|
||||||
if changed > stayed:
|
if changed > stayed:
|
||||||
j -= 1
|
j -= 1
|
||||||
if j == 0:
|
|
||||||
break
|
# Store the path with frame-wise probability.
|
||||||
else:
|
prob = (p_change if changed > stayed else p_stay).exp().item()
|
||||||
# failed
|
path.append(Point(j, t, prob))
|
||||||
return None
|
|
||||||
|
# Now j == 0, which means, it reached the SoS.
|
||||||
|
# Fill up the rest for the sake of visualization
|
||||||
|
while t > 0:
|
||||||
|
prob = emission[t - 1, blank_id].exp().item()
|
||||||
|
path.append(Point(j, t - 1, prob))
|
||||||
|
t -= 1
|
||||||
|
|
||||||
return path[::-1]
|
return path[::-1]
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Path:
|
||||||
|
points: List[Point]
|
||||||
|
score: float
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class BeamState:
|
||||||
|
"""State in beam search."""
|
||||||
|
token_index: int # Current token position
|
||||||
|
time_index: int # Current time step
|
||||||
|
score: float # Cumulative score
|
||||||
|
path: List[Point] # Path history
|
||||||
|
|
||||||
|
|
||||||
|
def backtrack_beam(trellis, emission, tokens, blank_id=0, beam_width=5):
|
||||||
|
"""Standard CTC beam search backtracking implementation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
trellis (torch.Tensor): The trellis (or lattice) of shape (T, N), where T is the number of time steps
|
||||||
|
and N is the number of tokens (including the blank token).
|
||||||
|
emission (torch.Tensor): The emission probabilities of shape (T, N).
|
||||||
|
tokens (List[int]): List of token indices (excluding the blank token).
|
||||||
|
blank_id (int, optional): The ID of the blank token. Defaults to 0.
|
||||||
|
beam_width (int, optional): The number of top paths to keep during beam search. Defaults to 5.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[Point]: the best path
|
||||||
|
"""
|
||||||
|
T, J = trellis.size(0) - 1, trellis.size(1) - 1
|
||||||
|
|
||||||
|
init_state = BeamState(
|
||||||
|
token_index=J,
|
||||||
|
time_index=T,
|
||||||
|
score=trellis[T, J],
|
||||||
|
path=[Point(J, T, emission[T, blank_id].exp().item())]
|
||||||
|
)
|
||||||
|
|
||||||
|
beams = [init_state]
|
||||||
|
|
||||||
|
while beams and beams[0].token_index > 0:
|
||||||
|
next_beams = []
|
||||||
|
|
||||||
|
for beam in beams:
|
||||||
|
t, j = beam.time_index, beam.token_index
|
||||||
|
|
||||||
|
if t <= 0:
|
||||||
|
continue
|
||||||
|
|
||||||
|
p_stay = emission[t - 1, blank_id]
|
||||||
|
p_change = get_wildcard_emission(emission[t - 1], [tokens[j]], blank_id)[0]
|
||||||
|
|
||||||
|
stay_score = trellis[t - 1, j]
|
||||||
|
change_score = trellis[t - 1, j - 1] if j > 0 else float('-inf')
|
||||||
|
|
||||||
|
# Stay
|
||||||
|
if not math.isinf(stay_score):
|
||||||
|
new_path = beam.path.copy()
|
||||||
|
new_path.append(Point(j, t - 1, p_stay.exp().item()))
|
||||||
|
next_beams.append(BeamState(
|
||||||
|
token_index=j,
|
||||||
|
time_index=t - 1,
|
||||||
|
score=stay_score,
|
||||||
|
path=new_path
|
||||||
|
))
|
||||||
|
|
||||||
|
# Change
|
||||||
|
if j > 0 and not math.isinf(change_score):
|
||||||
|
new_path = beam.path.copy()
|
||||||
|
new_path.append(Point(j - 1, t - 1, p_change.exp().item()))
|
||||||
|
next_beams.append(BeamState(
|
||||||
|
token_index=j - 1,
|
||||||
|
time_index=t - 1,
|
||||||
|
score=change_score,
|
||||||
|
path=new_path
|
||||||
|
))
|
||||||
|
|
||||||
|
# sort by score
|
||||||
|
beams = sorted(next_beams, key=lambda x: x.score, reverse=True)[:beam_width]
|
||||||
|
|
||||||
|
if not beams:
|
||||||
|
break
|
||||||
|
|
||||||
|
if not beams:
|
||||||
|
return None
|
||||||
|
|
||||||
|
best_beam = beams[0]
|
||||||
|
t = best_beam.time_index
|
||||||
|
j = best_beam.token_index
|
||||||
|
while t > 0:
|
||||||
|
prob = emission[t - 1, blank_id].exp().item()
|
||||||
|
best_beam.path.append(Point(j, t - 1, prob))
|
||||||
|
t -= 1
|
||||||
|
|
||||||
|
return best_beam.path[::-1]
|
||||||
|
|
||||||
|
|
||||||
# Merge the labels
|
# Merge the labels
|
||||||
@dataclass
|
@dataclass
|
||||||
class Segment:
|
class Segment:
|
||||||
|
Reference in New Issue
Block a user