mirror of
https://github.com/m-bain/whisperX.git
synced 2025-07-01 18:17:27 -04:00
support timestamps for numbers.
This commit is contained in:
@ -2,6 +2,8 @@
|
||||
Forced Alignment with Whisper
|
||||
C. Max Bain
|
||||
"""
|
||||
import math
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Iterable, Optional, Union, List
|
||||
|
||||
@ -163,10 +165,17 @@ def align(
|
||||
elif char_ in model_dictionary.keys():
|
||||
clean_char.append(char_)
|
||||
clean_cdx.append(cdx)
|
||||
else:
|
||||
# add placeholder
|
||||
clean_char.append('*')
|
||||
clean_cdx.append(cdx)
|
||||
|
||||
clean_wdx = []
|
||||
for wdx, wrd in enumerate(per_word):
|
||||
if any([c in model_dictionary.keys() for c in wrd]):
|
||||
if any([c in model_dictionary.keys() for c in wrd.lower()]):
|
||||
clean_wdx.append(wdx)
|
||||
else:
|
||||
# index for placeholder
|
||||
clean_wdx.append(wdx)
|
||||
|
||||
|
||||
@ -211,7 +220,7 @@ def align(
|
||||
continue
|
||||
|
||||
text_clean = "".join(segment["clean_char"])
|
||||
tokens = [model_dictionary[c] for c in text_clean]
|
||||
tokens = [model_dictionary.get(c, -1) for c in text_clean]
|
||||
|
||||
f1 = int(t1 * SAMPLE_RATE)
|
||||
f2 = int(t2 * SAMPLE_RATE)
|
||||
@ -244,7 +253,8 @@ def align(
|
||||
blank_id = code
|
||||
|
||||
trellis = get_trellis(emission, tokens, blank_id)
|
||||
path = backtrack(trellis, emission, tokens, blank_id)
|
||||
# path = backtrack(trellis, emission, tokens, blank_id)
|
||||
path = backtrack_beam(trellis, emission, tokens, blank_id)
|
||||
|
||||
if path is None:
|
||||
print(f'Failed to align segment ("{segment["text"]}"): backtrack failed, resorting to original...')
|
||||
@ -253,7 +263,7 @@ def align(
|
||||
|
||||
char_segments = merge_repeats(path, text_clean)
|
||||
|
||||
duration = t2 -t1
|
||||
duration = t2 - t1
|
||||
ratio = duration * waveform_segment.size(0) / (trellis.size(0) - 1)
|
||||
|
||||
# assign timestamps to aligned characters
|
||||
@ -360,70 +370,180 @@ def align(
|
||||
"""
|
||||
source: https://pytorch.org/tutorials/intermediate/forced_alignment_with_torchaudio_tutorial.html
|
||||
"""
|
||||
|
||||
|
||||
def get_trellis(emission, tokens, blank_id=0):
|
||||
num_frame = emission.size(0)
|
||||
num_tokens = len(tokens)
|
||||
|
||||
# Trellis has extra diemsions for both time axis and tokens.
|
||||
# The extra dim for tokens represents <SoS> (start-of-sentence)
|
||||
# The extra dim for time axis is for simplification of the code.
|
||||
trellis = torch.empty((num_frame + 1, num_tokens + 1))
|
||||
trellis[0, 0] = 0
|
||||
trellis[1:, 0] = torch.cumsum(emission[:, 0], 0)
|
||||
trellis[0, -num_tokens:] = -float("inf")
|
||||
trellis[-num_tokens:, 0] = float("inf")
|
||||
trellis = torch.zeros((num_frame, num_tokens))
|
||||
trellis[1:, 0] = torch.cumsum(emission[1:, blank_id], 0)
|
||||
trellis[0, 1:] = -float("inf")
|
||||
trellis[-num_tokens + 1:, 0] = float("inf")
|
||||
|
||||
for t in range(num_frame):
|
||||
for t in range(num_frame - 1):
|
||||
trellis[t + 1, 1:] = torch.maximum(
|
||||
# Score for staying at the same token
|
||||
trellis[t, 1:] + emission[t, blank_id],
|
||||
# Score for changing to the next token
|
||||
trellis[t, :-1] + emission[t, tokens],
|
||||
# trellis[t, :-1] + emission[t, tokens[1:]],
|
||||
trellis[t, :-1] + get_wildcard_emission(emission[t], tokens[1:], blank_id),
|
||||
)
|
||||
return trellis
|
||||
|
||||
|
||||
def get_wildcard_emission(frame_emission, tokens, blank_id):
|
||||
"""处理包含通配符的token emission分数
|
||||
|
||||
Args:
|
||||
frame_emission: 当前帧的emission概率向量
|
||||
tokens: token索引列表
|
||||
|
||||
Returns:
|
||||
tensor: 每个token位置的最大概率分数
|
||||
"""
|
||||
assert 0 <= blank_id < len(frame_emission)
|
||||
scores = []
|
||||
for token in tokens:
|
||||
if token == -1: # 通配符
|
||||
valid_scores = torch.cat([frame_emission[:blank_id], frame_emission[blank_id + 1:]])
|
||||
scores.append(torch.max(valid_scores))
|
||||
else:
|
||||
scores.append(frame_emission[token])
|
||||
return torch.tensor(scores)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Point:
|
||||
token_index: int
|
||||
time_index: int
|
||||
score: float
|
||||
|
||||
|
||||
def backtrack(trellis, emission, tokens, blank_id=0):
|
||||
# Note:
|
||||
# j and t are indices for trellis, which has extra dimensions
|
||||
# for time and tokens at the beginning.
|
||||
# When referring to time frame index `T` in trellis,
|
||||
# the corresponding index in emission is `T-1`.
|
||||
# Similarly, when referring to token index `J` in trellis,
|
||||
# the corresponding index in transcript is `J-1`.
|
||||
j = trellis.size(1) - 1
|
||||
t_start = torch.argmax(trellis[:, j]).item()
|
||||
t, j = trellis.size(0) - 1, trellis.size(1) - 1
|
||||
|
||||
path = [Point(j, t, emission[t, blank_id].exp().item())]
|
||||
while j > 0:
|
||||
# Should not happen but just in case
|
||||
assert t > 0
|
||||
|
||||
path = []
|
||||
for t in range(t_start, 0, -1):
|
||||
# 1. Figure out if the current position was stay or change
|
||||
# Note (again):
|
||||
# `emission[J-1]` is the emission at time frame `J` of trellis dimension.
|
||||
# Score for token staying the same from time frame J-1 to T.
|
||||
stayed = trellis[t - 1, j] + emission[t - 1, blank_id]
|
||||
# Score for token changing from C-1 at T-1 to J at T.
|
||||
changed = trellis[t - 1, j - 1] + emission[t - 1, tokens[j - 1]]
|
||||
# Frame-wise score of stay vs change
|
||||
p_stay = emission[t - 1, blank_id]
|
||||
# p_change = emission[t - 1, tokens[j]]
|
||||
p_change = get_wildcard_emission(emission[t - 1], [tokens[j]], blank_id)[0]
|
||||
|
||||
# 2. Store the path with frame-wise probability.
|
||||
prob = emission[t - 1, tokens[j - 1] if changed > stayed else 0].exp().item()
|
||||
# Return token index and time index in non-trellis coordinate.
|
||||
path.append(Point(j - 1, t - 1, prob))
|
||||
# Context-aware score for stay vs change
|
||||
stayed = trellis[t - 1, j] + p_stay
|
||||
changed = trellis[t - 1, j - 1] + p_change
|
||||
|
||||
# 3. Update the token
|
||||
# Update position
|
||||
t -= 1
|
||||
if changed > stayed:
|
||||
j -= 1
|
||||
if j == 0:
|
||||
break
|
||||
else:
|
||||
# failed
|
||||
return None
|
||||
|
||||
# Store the path with frame-wise probability.
|
||||
prob = (p_change if changed > stayed else p_stay).exp().item()
|
||||
path.append(Point(j, t, prob))
|
||||
|
||||
# Now j == 0, which means, it reached the SoS.
|
||||
# Fill up the rest for the sake of visualization
|
||||
while t > 0:
|
||||
prob = emission[t - 1, blank_id].exp().item()
|
||||
path.append(Point(j, t - 1, prob))
|
||||
t -= 1
|
||||
|
||||
return path[::-1]
|
||||
|
||||
|
||||
|
||||
@dataclass
|
||||
class Path:
|
||||
points: List[Point]
|
||||
score: float
|
||||
|
||||
|
||||
@dataclass
|
||||
class BeamState:
|
||||
"""beam search中的状态"""
|
||||
token_index: int # 当前token位置
|
||||
time_index: int # 当前时间步
|
||||
score: float # 累积分数
|
||||
path: List[Point] # 路径历史
|
||||
|
||||
|
||||
def backtrack_beam(trellis, emission, tokens, blank_id=0, beam_width=5):
|
||||
"""标准CTC beam search回溯实现
|
||||
"""
|
||||
T, J = trellis.size(0) - 1, trellis.size(1) - 1
|
||||
|
||||
init_state = BeamState(
|
||||
token_index=J,
|
||||
time_index=T,
|
||||
score=trellis[T, J],
|
||||
path=[Point(J, T, emission[T, blank_id].exp().item())]
|
||||
)
|
||||
|
||||
beams = [init_state]
|
||||
|
||||
while beams and beams[0].token_index > 0:
|
||||
next_beams = []
|
||||
|
||||
for beam in beams:
|
||||
t, j = beam.time_index, beam.token_index
|
||||
|
||||
if t <= 0:
|
||||
continue
|
||||
|
||||
p_stay = emission[t - 1, blank_id]
|
||||
p_change = get_wildcard_emission(emission[t - 1], [tokens[j]], blank_id)[0]
|
||||
|
||||
stay_score = trellis[t - 1, j]
|
||||
change_score = trellis[t - 1, j - 1] if j > 0 else float('-inf')
|
||||
|
||||
# Stay路径
|
||||
if not math.isinf(stay_score):
|
||||
new_path = beam.path.copy()
|
||||
new_path.append(Point(j, t - 1, p_stay.exp().item()))
|
||||
next_beams.append(BeamState(
|
||||
token_index=j,
|
||||
time_index=t - 1,
|
||||
score=stay_score,
|
||||
path=new_path
|
||||
))
|
||||
|
||||
# Change路径
|
||||
if j > 0 and not math.isinf(change_score):
|
||||
new_path = beam.path.copy()
|
||||
new_path.append(Point(j - 1, t - 1, p_change.exp().item()))
|
||||
next_beams.append(BeamState(
|
||||
token_index=j - 1,
|
||||
time_index=t - 1,
|
||||
score=change_score,
|
||||
path=new_path
|
||||
))
|
||||
|
||||
# 只按分数排序,不需要去重
|
||||
beams = sorted(next_beams, key=lambda x: x.score, reverse=True)[:beam_width]
|
||||
|
||||
if not beams:
|
||||
break
|
||||
|
||||
if not beams:
|
||||
raise ValueError("No valid path found")
|
||||
|
||||
best_beam = beams[0]
|
||||
t = best_beam.time_index
|
||||
j = best_beam.token_index
|
||||
while t > 0:
|
||||
prob = emission[t - 1, blank_id].exp().item()
|
||||
best_beam.path.append(Point(j, t - 1, prob))
|
||||
t -= 1
|
||||
|
||||
return best_beam.path[::-1]
|
||||
|
||||
|
||||
# Merge the labels
|
||||
@dataclass
|
||||
class Segment:
|
||||
|
Reference in New Issue
Block a user