Merge pull request #986 from bfs18/main

support timestamp for numbers.
Max Bain
2025-01-14 21:03:51 +00:00
committed by GitHub


@@ -2,6 +2,7 @@
 Forced Alignment with Whisper
 C. Max Bain
 """
+import math
 from dataclasses import dataclass
 from typing import Iterable, Optional, Union, List
@ -171,10 +172,17 @@ def align(
elif char_ in model_dictionary.keys(): elif char_ in model_dictionary.keys():
clean_char.append(char_) clean_char.append(char_)
clean_cdx.append(cdx) clean_cdx.append(cdx)
else:
# add placeholder
clean_char.append('*')
clean_cdx.append(cdx)
clean_wdx = [] clean_wdx = []
for wdx, wrd in enumerate(per_word): for wdx, wrd in enumerate(per_word):
if any([c in model_dictionary.keys() for c in wrd]): if any([c in model_dictionary.keys() for c in wrd.lower()]):
clean_wdx.append(wdx)
else:
# index for placeholder
clean_wdx.append(wdx) clean_wdx.append(wdx)
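
A minimal sketch of the new placeholder behavior, using a made-up toy dictionary (the real `model_dictionary` comes from the alignment model). Characters the model cannot emit, such as digits, are now kept as `'*'` placeholders instead of being dropped, so numeric words still receive character and word indices:

    # Toy dictionary for illustration only; branches other than the
    # dictionary check are omitted for brevity.
    model_dictionary = {"a": 1, "b": 2, "c": 3}

    clean_char, clean_cdx = [], []
    for cdx, char_ in enumerate("ab3"):
        if char_ in model_dictionary.keys():
            clean_char.append(char_)
            clean_cdx.append(cdx)
        else:
            # add placeholder, as in the patched align()
            clean_char.append('*')
            clean_cdx.append(cdx)

    print(clean_char)  # ['a', 'b', '*'] -- the digit survives as a wildcard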
@ -222,7 +230,7 @@ def align(
continue continue
text_clean = "".join(segment_data[sdx]["clean_char"]) text_clean = "".join(segment_data[sdx]["clean_char"])
tokens = [model_dictionary[c] for c in text_clean] tokens = [model_dictionary.get(c, -1) for c in text_clean]
f1 = int(t1 * SAMPLE_RATE) f1 = int(t1 * SAMPLE_RATE)
f2 = int(t2 * SAMPLE_RATE) f2 = int(t2 * SAMPLE_RATE)
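
Combined with the placeholder above, switching from `model_dictionary[c]` to `model_dictionary.get(c, -1)` maps every `'*'` (presumably absent from the dictionary) to the sentinel token id -1, which the wildcard-aware functions below treat specially. Continuing the toy sketch:

    text_clean = "ab*"
    tokens = [model_dictionary.get(c, -1) for c in text_clean]
    print(tokens)  # [1, 2, -1] -- -1 marks the wildcard position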
@ -255,7 +263,8 @@ def align(
blank_id = code blank_id = code
trellis = get_trellis(emission, tokens, blank_id) trellis = get_trellis(emission, tokens, blank_id)
path = backtrack(trellis, emission, tokens, blank_id) # path = backtrack(trellis, emission, tokens, blank_id)
path = backtrack_beam(trellis, emission, tokens, blank_id, beam_width=2)
if path is None: if path is None:
print(f'Failed to align segment ("{segment["text"]}"): backtrack failed, resorting to original...') print(f'Failed to align segment ("{segment["text"]}"): backtrack failed, resorting to original...')
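
Alignment now backtracks with a width-2 beam search instead of the greedy `backtrack`; the `path is None` fallback to the original timestamps is unchanged. If the greedy path were still wanted as a cheap second attempt before that fallback, one hypothetical combination (not in this commit) could be:

    path = backtrack_beam(trellis, emission, tokens, blank_id, beam_width=2)
    if path is None:
        # hypothetical extra fallback before resorting to original timestamps
        path = backtrack(trellis, emission, tokens, blank_id)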
@ -264,7 +273,7 @@ def align(
char_segments = merge_repeats(path, text_clean) char_segments = merge_repeats(path, text_clean)
duration = t2 -t1 duration = t2 - t1
ratio = duration * waveform_segment.size(0) / (trellis.size(0) - 1) ratio = duration * waveform_segment.size(0) / (trellis.size(0) - 1)
# assign timestamps to aligned characters # assign timestamps to aligned characters
@@ -371,70 +380,203 @@ def align(
 """
 source: https://pytorch.org/tutorials/intermediate/forced_alignment_with_torchaudio_tutorial.html
 """
 def get_trellis(emission, tokens, blank_id=0):
     num_frame = emission.size(0)
     num_tokens = len(tokens)
 
-    # Trellis has extra diemsions for both time axis and tokens.
-    # The extra dim for tokens represents <SoS> (start-of-sentence)
-    # The extra dim for time axis is for simplification of the code.
-    trellis = torch.empty((num_frame + 1, num_tokens + 1))
-    trellis[0, 0] = 0
-    trellis[1:, 0] = torch.cumsum(emission[:, 0], 0)
-    trellis[0, -num_tokens:] = -float("inf")
-    trellis[-num_tokens:, 0] = float("inf")
+    trellis = torch.zeros((num_frame, num_tokens))
+    trellis[1:, 0] = torch.cumsum(emission[1:, blank_id], 0)
+    trellis[0, 1:] = -float("inf")
+    trellis[-num_tokens + 1:, 0] = float("inf")
 
-    for t in range(num_frame):
+    for t in range(num_frame - 1):
         trellis[t + 1, 1:] = torch.maximum(
             # Score for staying at the same token
             trellis[t, 1:] + emission[t, blank_id],
             # Score for changing to the next token
-            trellis[t, :-1] + emission[t, tokens],
+            # trellis[t, :-1] + emission[t, tokens[1:]],
+            trellis[t, :-1] + get_wildcard_emission(emission[t], tokens[1:], blank_id),
         )
     return trellis
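
The rewritten `get_trellis` drops the tutorial version's extra start-of-sentence row and column, so the trellis is now `(num_frame, num_tokens)` rather than `(num_frame + 1, num_tokens + 1)`, and the change-token term scores wildcards via `get_wildcard_emission`. A small shape check, assuming random log-softmax emissions and both new helpers in scope:

    import torch
    emission = torch.randn(6, 5).log_softmax(-1)  # 6 frames, 5-symbol vocab (made up)
    tokens = [1, 2, -1, 3]                        # -1 is a wildcard token id
    trellis = get_trellis(emission, tokens, blank_id=0)
    print(trellis.shape)  # torch.Size([6, 4]) -- (num_frame, num_tokens), no extra row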
+
+
+def get_wildcard_emission(frame_emission, tokens, blank_id):
+    """Process token emission scores containing wildcards (vectorized version).
+
+    Args:
+        frame_emission: Emission probability vector for the current frame
+        tokens: List of token indices
+        blank_id: ID of the blank token
+
+    Returns:
+        tensor: Maximum probability score for each token position
+    """
+    assert 0 <= blank_id < len(frame_emission)
+
+    # Convert tokens to a tensor if they are not already
+    tokens = torch.tensor(tokens) if not isinstance(tokens, torch.Tensor) else tokens
+
+    # Create a mask to identify wildcard positions
+    wildcard_mask = (tokens == -1)
+
+    # Get scores for non-wildcard positions
+    regular_scores = frame_emission[tokens.clamp(min=0)]  # clamp to avoid -1 index
+
+    # Compute the maximum non-blank score without modifying frame_emission
+    max_valid_score = frame_emission.clone()   # create a copy
+    max_valid_score[blank_id] = float('-inf')  # exclude the blank token from the max
+    max_valid_score = max_valid_score.max()
+
+    # Use where to combine wildcard and regular scores
+    result = torch.where(wildcard_mask, max_valid_score, regular_scores)
+
+    return result
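
In effect, a wildcard position (token id -1) receives the best non-blank score in the frame, while regular tokens keep their own emission score. A tiny check with made-up numbers:

    import torch
    frame = torch.tensor([0.1, 0.5, 2.0, 0.3])  # made-up frame scores; index 0 is blank
    print(get_wildcard_emission(frame, [1, -1], blank_id=0))
    # tensor([0.5000, 2.0000]) -- the wildcard takes the max over non-blank entries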
 
 
 @dataclass
 class Point:
     token_index: int
     time_index: int
     score: float
 
 
 def backtrack(trellis, emission, tokens, blank_id=0):
-    # Note:
-    # j and t are indices for trellis, which has extra dimensions
-    # for time and tokens at the beginning.
-    # When referring to time frame index `T` in trellis,
-    # the corresponding index in emission is `T-1`.
-    # Similarly, when referring to token index `J` in trellis,
-    # the corresponding index in transcript is `J-1`.
-    j = trellis.size(1) - 1
-    t_start = torch.argmax(trellis[:, j]).item()
+    t, j = trellis.size(0) - 1, trellis.size(1) - 1
 
-    path = []
-    for t in range(t_start, 0, -1):
+    path = [Point(j, t, emission[t, blank_id].exp().item())]
+    while j > 0:
+        # Should not happen but just in case
+        assert t > 0
+
         # 1. Figure out if the current position was stay or change
-        # Note (again):
-        # `emission[J-1]` is the emission at time frame `J` of trellis dimension.
-        # Score for token staying the same from time frame J-1 to T.
-        stayed = trellis[t - 1, j] + emission[t - 1, blank_id]
-        # Score for token changing from C-1 at T-1 to J at T.
-        changed = trellis[t - 1, j - 1] + emission[t - 1, tokens[j - 1]]
+        # Frame-wise score of stay vs change
+        p_stay = emission[t - 1, blank_id]
+        # p_change = emission[t - 1, tokens[j]]
+        p_change = get_wildcard_emission(emission[t - 1], [tokens[j]], blank_id)[0]
 
-        # 2. Store the path with frame-wise probability.
-        prob = emission[t - 1, tokens[j - 1] if changed > stayed else 0].exp().item()
-        # Return token index and time index in non-trellis coordinate.
-        path.append(Point(j - 1, t - 1, prob))
+        # Context-aware score for stay vs change
+        stayed = trellis[t - 1, j] + p_stay
+        changed = trellis[t - 1, j - 1] + p_change
 
-        # 3. Update the token
+        # Update position
+        t -= 1
+
         if changed > stayed:
             j -= 1
-            if j == 0:
-                break
-    else:
-        # failed
-        return None
+
+        # Store the path with frame-wise probability.
+        prob = (p_change if changed > stayed else p_stay).exp().item()
+        path.append(Point(j, t, prob))
+
+    # Now j == 0, which means it reached the SoS.
+    # Fill up the rest for the sake of visualization.
+    while t > 0:
+        prob = emission[t - 1, blank_id].exp().item()
+        path.append(Point(j, t - 1, prob))
+        t -= 1
+
     return path[::-1]
+
+
+@dataclass
+class Path:
+    points: List[Point]
+    score: float
+
+
+@dataclass
+class BeamState:
+    """State in beam search."""
+    token_index: int   # Current token position
+    time_index: int    # Current time step
+    score: float       # Cumulative score
+    path: List[Point]  # Path history
+
+
+def backtrack_beam(trellis, emission, tokens, blank_id=0, beam_width=5):
+    """Standard CTC beam search backtracking implementation.
+
+    Args:
+        trellis (torch.Tensor): The trellis (or lattice) of shape (T, N), where T is the number
+            of time steps and N is the number of tokens (including the blank token).
+        emission (torch.Tensor): The emission probabilities of shape (T, N).
+        tokens (List[int]): List of token indices (excluding the blank token).
+        blank_id (int, optional): The ID of the blank token. Defaults to 0.
+        beam_width (int, optional): The number of top paths to keep during beam search. Defaults to 5.
+
+    Returns:
+        List[Point]: the best path
+    """
+    T, J = trellis.size(0) - 1, trellis.size(1) - 1
+
+    init_state = BeamState(
+        token_index=J,
+        time_index=T,
+        score=trellis[T, J],
+        path=[Point(J, T, emission[T, blank_id].exp().item())]
+    )
+
+    beams = [init_state]
+
+    while beams and beams[0].token_index > 0:
+        next_beams = []
+
+        for beam in beams:
+            t, j = beam.time_index, beam.token_index
+
+            if t <= 0:
+                continue
+
+            p_stay = emission[t - 1, blank_id]
+            p_change = get_wildcard_emission(emission[t - 1], [tokens[j]], blank_id)[0]
+
+            stay_score = trellis[t - 1, j]
+            change_score = trellis[t - 1, j - 1] if j > 0 else float('-inf')
+
+            # Stay at the same token
+            if not math.isinf(stay_score):
+                new_path = beam.path.copy()
+                new_path.append(Point(j, t - 1, p_stay.exp().item()))
+                next_beams.append(BeamState(
+                    token_index=j,
+                    time_index=t - 1,
+                    score=stay_score,
+                    path=new_path
+                ))
+
+            # Change to the previous token
+            if j > 0 and not math.isinf(change_score):
+                new_path = beam.path.copy()
+                new_path.append(Point(j - 1, t - 1, p_change.exp().item()))
+                next_beams.append(BeamState(
+                    token_index=j - 1,
+                    time_index=t - 1,
+                    score=change_score,
+                    path=new_path
+                ))
+
+        # Sort by score and keep only the top beam_width candidates
+        beams = sorted(next_beams, key=lambda x: x.score, reverse=True)[:beam_width]
+
+        if not beams:
+            break
+
+    if not beams:
+        return None
+
+    best_beam = beams[0]
+    t = best_beam.time_index
+    j = best_beam.token_index
+    while t > 0:
+        prob = emission[t - 1, blank_id].exp().item()
+        best_beam.path.append(Point(j, t - 1, prob))
+        t -= 1
+
+    return best_beam.path[::-1]
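
A toy end-to-end run under the same assumptions (random log-softmax emissions, one wildcard in the token sequence), showing that the beam backtracker returns one `Point` per frame just like the greedy one:

    import torch
    torch.manual_seed(0)  # made-up inputs, just to exercise the functions
    emission = torch.randn(20, 5).log_softmax(-1)
    tokens = [1, -1, 3]   # middle token is a wildcard, e.g. an unmodeled digit
    trellis = get_trellis(emission, tokens, blank_id=0)
    path = backtrack_beam(trellis, emission, tokens, blank_id=0, beam_width=2)
    if path is not None:
        print(len(path), path[0])  # 20 points, e.g. Point(token_index=..., ...)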
 
 
 # Merge the labels
 @dataclass
 class Segment: