support timestamps for numbers.

This commit is contained in:
liupeng
2025-01-09 15:23:40 +08:00
parent 734084cdf6
commit 69281f3a29

View File

@ -2,6 +2,8 @@
Forced Alignment with Whisper Forced Alignment with Whisper
C. Max Bain C. Max Bain
""" """
import math
from dataclasses import dataclass from dataclasses import dataclass
from typing import Iterable, Optional, Union, List from typing import Iterable, Optional, Union, List
@ -163,10 +165,17 @@ def align(
elif char_ in model_dictionary.keys(): elif char_ in model_dictionary.keys():
clean_char.append(char_) clean_char.append(char_)
clean_cdx.append(cdx) clean_cdx.append(cdx)
else:
# add placeholder
clean_char.append('*')
clean_cdx.append(cdx)
clean_wdx = [] clean_wdx = []
for wdx, wrd in enumerate(per_word): for wdx, wrd in enumerate(per_word):
if any([c in model_dictionary.keys() for c in wrd]): if any([c in model_dictionary.keys() for c in wrd.lower()]):
clean_wdx.append(wdx)
else:
# index for placeholder
clean_wdx.append(wdx) clean_wdx.append(wdx)
@ -211,7 +220,7 @@ def align(
continue continue
text_clean = "".join(segment["clean_char"]) text_clean = "".join(segment["clean_char"])
tokens = [model_dictionary[c] for c in text_clean] tokens = [model_dictionary.get(c, -1) for c in text_clean]
f1 = int(t1 * SAMPLE_RATE) f1 = int(t1 * SAMPLE_RATE)
f2 = int(t2 * SAMPLE_RATE) f2 = int(t2 * SAMPLE_RATE)
@ -244,7 +253,8 @@ def align(
blank_id = code blank_id = code
trellis = get_trellis(emission, tokens, blank_id) trellis = get_trellis(emission, tokens, blank_id)
path = backtrack(trellis, emission, tokens, blank_id) # path = backtrack(trellis, emission, tokens, blank_id)
path = backtrack_beam(trellis, emission, tokens, blank_id)
if path is None: if path is None:
print(f'Failed to align segment ("{segment["text"]}"): backtrack failed, resorting to original...') print(f'Failed to align segment ("{segment["text"]}"): backtrack failed, resorting to original...')
@ -360,70 +370,180 @@ def align(
""" """
source: https://pytorch.org/tutorials/intermediate/forced_alignment_with_torchaudio_tutorial.html source: https://pytorch.org/tutorials/intermediate/forced_alignment_with_torchaudio_tutorial.html
""" """
def get_trellis(emission, tokens, blank_id=0): def get_trellis(emission, tokens, blank_id=0):
num_frame = emission.size(0) num_frame = emission.size(0)
num_tokens = len(tokens) num_tokens = len(tokens)
# Trellis has extra diemsions for both time axis and tokens. trellis = torch.zeros((num_frame, num_tokens))
# The extra dim for tokens represents <SoS> (start-of-sentence) trellis[1:, 0] = torch.cumsum(emission[1:, blank_id], 0)
# The extra dim for time axis is for simplification of the code. trellis[0, 1:] = -float("inf")
trellis = torch.empty((num_frame + 1, num_tokens + 1)) trellis[-num_tokens + 1:, 0] = float("inf")
trellis[0, 0] = 0
trellis[1:, 0] = torch.cumsum(emission[:, 0], 0)
trellis[0, -num_tokens:] = -float("inf")
trellis[-num_tokens:, 0] = float("inf")
for t in range(num_frame): for t in range(num_frame - 1):
trellis[t + 1, 1:] = torch.maximum( trellis[t + 1, 1:] = torch.maximum(
# Score for staying at the same token # Score for staying at the same token
trellis[t, 1:] + emission[t, blank_id], trellis[t, 1:] + emission[t, blank_id],
# Score for changing to the next token # Score for changing to the next token
trellis[t, :-1] + emission[t, tokens], # trellis[t, :-1] + emission[t, tokens[1:]],
trellis[t, :-1] + get_wildcard_emission(emission[t], tokens[1:], blank_id),
) )
return trellis return trellis
def get_wildcard_emission(frame_emission, tokens, blank_id):
    """Look up per-token emission scores for one frame, with wildcard support.

    A token id of ``-1`` is a wildcard: its score is the largest entry of
    ``frame_emission`` once the blank label has been excluded. Every other
    token id is used directly as an index into ``frame_emission``.

    Args:
        frame_emission: 1-D tensor with the emission scores of the current frame.
        tokens: sequence of token ids; ``-1`` marks a wildcard position.
        blank_id: index of the blank label inside ``frame_emission``.

    Returns:
        1-D tensor holding one score per entry of ``tokens``.
    """
    assert 0 <= blank_id < len(frame_emission)

    token_ids = torch.as_tensor(tokens, dtype=torch.long, device=frame_emission.device)
    if token_ids.numel() == 0:
        # Nothing to score (e.g. a one-token transcript sliced with tokens[1:]).
        return frame_emission.new_empty(0)

    is_wildcard = token_ids == -1
    # Point wildcard slots at index 0 so the gather below is valid; the
    # gathered value at those slots is discarded by torch.where further down.
    regular_scores = frame_emission[token_ids.masked_fill(is_wildcard, 0)]
    if not bool(is_wildcard.any()):
        return regular_scores

    # The wildcard score does not depend on the token position, so compute it
    # once per frame on a copy (the caller's tensor must not be modified).
    without_blank = frame_emission.clone()
    without_blank[blank_id] = float("-inf")
    return torch.where(is_wildcard, without_blank.max(), regular_scores)
@dataclass @dataclass
class Point: class Point:
token_index: int token_index: int
time_index: int time_index: int
score: float score: float
def backtrack(trellis, emission, tokens, blank_id=0): def backtrack(trellis, emission, tokens, blank_id=0):
# Note: t, j = trellis.size(0) - 1, trellis.size(1) - 1
# j and t are indices for trellis, which has extra dimensions
# for time and tokens at the beginning. path = [Point(j, t, emission[t, blank_id].exp().item())]
# When referring to time frame index `T` in trellis, while j > 0:
# the corresponding index in emission is `T-1`. # Should not happen but just in case
# Similarly, when referring to token index `J` in trellis, assert t > 0
# the corresponding index in transcript is `J-1`.
j = trellis.size(1) - 1
t_start = torch.argmax(trellis[:, j]).item()
path = []
for t in range(t_start, 0, -1):
# 1. Figure out if the current position was stay or change # 1. Figure out if the current position was stay or change
# Note (again): # Frame-wise score of stay vs change
# `emission[J-1]` is the emission at time frame `J` of trellis dimension. p_stay = emission[t - 1, blank_id]
# Score for token staying the same from time frame J-1 to T. # p_change = emission[t - 1, tokens[j]]
stayed = trellis[t - 1, j] + emission[t - 1, blank_id] p_change = get_wildcard_emission(emission[t - 1], [tokens[j]], blank_id)[0]
# Score for token changing from C-1 at T-1 to J at T.
changed = trellis[t - 1, j - 1] + emission[t - 1, tokens[j - 1]]
# 2. Store the path with frame-wise probability. # Context-aware score for stay vs change
prob = emission[t - 1, tokens[j - 1] if changed > stayed else 0].exp().item() stayed = trellis[t - 1, j] + p_stay
# Return token index and time index in non-trellis coordinate. changed = trellis[t - 1, j - 1] + p_change
path.append(Point(j - 1, t - 1, prob))
# 3. Update the token # Update position
t -= 1
if changed > stayed: if changed > stayed:
j -= 1 j -= 1
if j == 0:
break # Store the path with frame-wise probability.
else: prob = (p_change if changed > stayed else p_stay).exp().item()
# failed path.append(Point(j, t, prob))
return None
# Now j == 0, which means, it reached the SoS.
# Fill up the rest for the sake of visualization
while t > 0:
prob = emission[t - 1, blank_id].exp().item()
path.append(Point(j, t - 1, prob))
t -= 1
return path[::-1] return path[::-1]
@dataclass
class Path:
    """A sequence of alignment points together with a single score."""

    # Points that make up the path.
    points: List[Point]
    # Score attached to the path as a whole.
    score: float
@dataclass
class BeamState:
    """State of one hypothesis during beam search."""

    # Current token position.
    token_index: int
    # Current time step.
    time_index: int
    # Accumulated score.
    score: float
    # History of the points visited so far.
    path: List[Point]
def backtrack_beam(trellis, emission, tokens, blank_id=0, beam_width=5):
    """Backtrack through the trellis with a beam search to recover a path.

    The search starts at the last cell of ``trellis`` and walks backwards in
    time. At every step each hypothesis may either stay on its current token
    or move to the previous one; transitions that land on an infinite trellis
    score are dropped. After each step only the ``beam_width`` hypotheses with
    the highest trellis score are kept. The loop ends once the best hypothesis
    reaches token index 0, and the remaining frames are filled with
    blank-scored points.

    Args:
        trellis: 2-D score tensor indexed as ``trellis[time, token]``.
        emission: 2-D tensor indexed as ``emission[time, label]``; ``exp()``
            of an entry is stored as the score of the corresponding point.
        tokens: token ids of the transcript; ``-1`` entries are scored as
            wildcards by ``get_wildcard_emission``.
        blank_id: index of the blank label in ``emission``.
        beam_width: maximum number of hypotheses kept after each step.

    Returns:
        The points of the best hypothesis ordered from the first frame to the
        last, or ``None`` when no hypothesis survives. ``None`` (rather than
        an exception) lets the caller's ``if path is None`` fallback handle
        the failure.
    """
    last_t, last_j = trellis.size(0) - 1, trellis.size(1) - 1

    beams = [
        BeamState(
            token_index=last_j,
            time_index=last_t,
            score=trellis[last_t, last_j],
            path=[Point(last_j, last_t, emission[last_t, blank_id].exp().item())],
        )
    ]

    # Only the best hypothesis decides when to stop; the loop also ends when
    # every hypothesis has been pruned away.
    while beams and beams[0].token_index > 0:
        candidates = []

        for beam in beams:
            t, j = beam.time_index, beam.token_index
            if t <= 0:
                # Ran out of frames before reaching the first token.
                continue

            # Stay on the same token.
            stay_score = trellis[t - 1, j]
            if not math.isinf(stay_score):
                p_stay = emission[t - 1, blank_id]
                candidates.append(
                    BeamState(
                        token_index=j,
                        time_index=t - 1,
                        score=stay_score,
                        path=beam.path + [Point(j, t - 1, p_stay.exp().item())],
                    )
                )

            # Move to the previous token.
            if j > 0:
                change_score = trellis[t - 1, j - 1]
                if not math.isinf(change_score):
                    p_change = get_wildcard_emission(
                        emission[t - 1], [tokens[j]], blank_id
                    )[0]
                    candidates.append(
                        BeamState(
                            token_index=j - 1,
                            time_index=t - 1,
                            score=change_score,
                            path=beam.path + [Point(j - 1, t - 1, p_change.exp().item())],
                        )
                    )

        # Rank by score only; hypotheses are not de-duplicated.
        beams = sorted(candidates, key=lambda state: state.score, reverse=True)[:beam_width]

    if not beams:
        return None

    best = beams[0]
    j = best.token_index
    # Fill the frames before the first token with blank-scored points.
    for t in range(best.time_index, 0, -1):
        best.path.append(Point(j, t - 1, emission[t - 1, blank_id].exp().item()))

    return best.path[::-1]
# Merge the labels # Merge the labels
@dataclass @dataclass
class Segment: class Segment: