support timestamps for numbers.

This commit is contained in:
liupeng
2025-01-09 15:23:40 +08:00
parent 734084cdf6
commit 69281f3a29

View File

@ -2,6 +2,8 @@
Forced Alignment with Whisper
C. Max Bain
"""
import math
from dataclasses import dataclass
from typing import Iterable, Optional, Union, List
@ -163,10 +165,17 @@ def align(
elif char_ in model_dictionary.keys():
clean_char.append(char_)
clean_cdx.append(cdx)
else:
# add placeholder
clean_char.append('*')
clean_cdx.append(cdx)
clean_wdx = []
for wdx, wrd in enumerate(per_word):
if any([c in model_dictionary.keys() for c in wrd]):
if any([c in model_dictionary.keys() for c in wrd.lower()]):
clean_wdx.append(wdx)
else:
# index for placeholder
clean_wdx.append(wdx)
@ -211,7 +220,7 @@ def align(
continue
text_clean = "".join(segment["clean_char"])
tokens = [model_dictionary[c] for c in text_clean]
tokens = [model_dictionary.get(c, -1) for c in text_clean]
f1 = int(t1 * SAMPLE_RATE)
f2 = int(t2 * SAMPLE_RATE)
@ -244,7 +253,8 @@ def align(
blank_id = code
trellis = get_trellis(emission, tokens, blank_id)
path = backtrack(trellis, emission, tokens, blank_id)
# path = backtrack(trellis, emission, tokens, blank_id)
path = backtrack_beam(trellis, emission, tokens, blank_id)
if path is None:
print(f'Failed to align segment ("{segment["text"]}"): backtrack failed, resorting to original...')
@ -253,7 +263,7 @@ def align(
char_segments = merge_repeats(path, text_clean)
duration = t2 -t1
duration = t2 - t1
ratio = duration * waveform_segment.size(0) / (trellis.size(0) - 1)
# assign timestamps to aligned characters
@ -360,70 +370,180 @@ def align(
"""
source: https://pytorch.org/tutorials/intermediate/forced_alignment_with_torchaudio_tutorial.html
"""
def get_trellis(emission, tokens, blank_id=0):
num_frame = emission.size(0)
num_tokens = len(tokens)
# Trellis has extra diemsions for both time axis and tokens.
# The extra dim for tokens represents <SoS> (start-of-sentence)
# The extra dim for time axis is for simplification of the code.
trellis = torch.empty((num_frame + 1, num_tokens + 1))
trellis[0, 0] = 0
trellis[1:, 0] = torch.cumsum(emission[:, 0], 0)
trellis[0, -num_tokens:] = -float("inf")
trellis[-num_tokens:, 0] = float("inf")
trellis = torch.zeros((num_frame, num_tokens))
trellis[1:, 0] = torch.cumsum(emission[1:, blank_id], 0)
trellis[0, 1:] = -float("inf")
trellis[-num_tokens + 1:, 0] = float("inf")
for t in range(num_frame):
for t in range(num_frame - 1):
trellis[t + 1, 1:] = torch.maximum(
# Score for staying at the same token
trellis[t, 1:] + emission[t, blank_id],
# Score for changing to the next token
trellis[t, :-1] + emission[t, tokens],
# trellis[t, :-1] + emission[t, tokens[1:]],
trellis[t, :-1] + get_wildcard_emission(emission[t], tokens[1:], blank_id),
)
return trellis
def get_wildcard_emission(frame_emission, tokens, blank_id):
"""处理包含通配符的token emission分数
Args:
frame_emission: 当前帧的emission概率向量
tokens: token索引列表
Returns:
tensor: 每个token位置的最大概率分数
"""
assert 0 <= blank_id < len(frame_emission)
scores = []
for token in tokens:
if token == -1: # 通配符
valid_scores = torch.cat([frame_emission[:blank_id], frame_emission[blank_id + 1:]])
scores.append(torch.max(valid_scores))
else:
scores.append(frame_emission[token])
return torch.tensor(scores)
@dataclass
class Point:
token_index: int
time_index: int
score: float
def backtrack(trellis, emission, tokens, blank_id=0):
# Note:
# j and t are indices for trellis, which has extra dimensions
# for time and tokens at the beginning.
# When referring to time frame index `T` in trellis,
# the corresponding index in emission is `T-1`.
# Similarly, when referring to token index `J` in trellis,
# the corresponding index in transcript is `J-1`.
j = trellis.size(1) - 1
t_start = torch.argmax(trellis[:, j]).item()
t, j = trellis.size(0) - 1, trellis.size(1) - 1
path = [Point(j, t, emission[t, blank_id].exp().item())]
while j > 0:
# Should not happen but just in case
assert t > 0
path = []
for t in range(t_start, 0, -1):
# 1. Figure out if the current position was stay or change
# Note (again):
# `emission[J-1]` is the emission at time frame `J` of trellis dimension.
# Score for token staying the same from time frame J-1 to T.
stayed = trellis[t - 1, j] + emission[t - 1, blank_id]
# Score for token changing from C-1 at T-1 to J at T.
changed = trellis[t - 1, j - 1] + emission[t - 1, tokens[j - 1]]
# Frame-wise score of stay vs change
p_stay = emission[t - 1, blank_id]
# p_change = emission[t - 1, tokens[j]]
p_change = get_wildcard_emission(emission[t - 1], [tokens[j]], blank_id)[0]
# 2. Store the path with frame-wise probability.
prob = emission[t - 1, tokens[j - 1] if changed > stayed else 0].exp().item()
# Return token index and time index in non-trellis coordinate.
path.append(Point(j - 1, t - 1, prob))
# Context-aware score for stay vs change
stayed = trellis[t - 1, j] + p_stay
changed = trellis[t - 1, j - 1] + p_change
# 3. Update the token
# Update position
t -= 1
if changed > stayed:
j -= 1
if j == 0:
break
else:
# failed
return None
# Store the path with frame-wise probability.
prob = (p_change if changed > stayed else p_stay).exp().item()
path.append(Point(j, t, prob))
# Now j == 0, which means, it reached the SoS.
# Fill up the rest for the sake of visualization
while t > 0:
prob = emission[t - 1, blank_id].exp().item()
path.append(Point(j, t - 1, prob))
t -= 1
return path[::-1]
@dataclass
class Path:
points: List[Point]
score: float
@dataclass
class BeamState:
"""beam search中的状态"""
token_index: int # 当前token位置
time_index: int # 当前时间步
score: float # 累积分数
path: List[Point] # 路径历史
def backtrack_beam(trellis, emission, tokens, blank_id=0, beam_width=5):
"""标准CTC beam search回溯实现
"""
T, J = trellis.size(0) - 1, trellis.size(1) - 1
init_state = BeamState(
token_index=J,
time_index=T,
score=trellis[T, J],
path=[Point(J, T, emission[T, blank_id].exp().item())]
)
beams = [init_state]
while beams and beams[0].token_index > 0:
next_beams = []
for beam in beams:
t, j = beam.time_index, beam.token_index
if t <= 0:
continue
p_stay = emission[t - 1, blank_id]
p_change = get_wildcard_emission(emission[t - 1], [tokens[j]], blank_id)[0]
stay_score = trellis[t - 1, j]
change_score = trellis[t - 1, j - 1] if j > 0 else float('-inf')
# Stay路径
if not math.isinf(stay_score):
new_path = beam.path.copy()
new_path.append(Point(j, t - 1, p_stay.exp().item()))
next_beams.append(BeamState(
token_index=j,
time_index=t - 1,
score=stay_score,
path=new_path
))
# Change路径
if j > 0 and not math.isinf(change_score):
new_path = beam.path.copy()
new_path.append(Point(j - 1, t - 1, p_change.exp().item()))
next_beams.append(BeamState(
token_index=j - 1,
time_index=t - 1,
score=change_score,
path=new_path
))
# 只按分数排序,不需要去重
beams = sorted(next_beams, key=lambda x: x.score, reverse=True)[:beam_width]
if not beams:
break
if not beams:
raise ValueError("No valid path found")
best_beam = beams[0]
t = best_beam.time_index
j = best_beam.token_index
while t > 0:
prob = emission[t - 1, blank_id].exp().item()
best_beam.path.append(Point(j, t - 1, prob))
t -= 1
return best_beam.path[::-1]
# Merge the labels
@dataclass
class Segment: