mirror of
https://github.com/m-bain/whisperX.git
synced 2025-07-01 18:17:27 -04:00
make align a bit faster.
This commit is contained in:
@ -254,7 +254,7 @@ def align(
|
|||||||
|
|
||||||
trellis = get_trellis(emission, tokens, blank_id)
|
trellis = get_trellis(emission, tokens, blank_id)
|
||||||
# path = backtrack(trellis, emission, tokens, blank_id)
|
# path = backtrack(trellis, emission, tokens, blank_id)
|
||||||
path = backtrack_beam(trellis, emission, tokens, blank_id)
|
path = backtrack_beam(trellis, emission, tokens, blank_id, beam_width=2)
|
||||||
|
|
||||||
if path is None:
|
if path is None:
|
||||||
print(f'Failed to align segment ("{segment["text"]}"): backtrack failed, resorting to original...')
|
print(f'Failed to align segment ("{segment["text"]}"): backtrack failed, resorting to original...')
|
||||||
@ -393,24 +393,36 @@ def get_trellis(emission, tokens, blank_id=0):
|
|||||||
|
|
||||||
|
|
||||||
def get_wildcard_emission(frame_emission, tokens, blank_id):
|
def get_wildcard_emission(frame_emission, tokens, blank_id):
|
||||||
"""处理包含通配符的token emission分数
|
"""处理包含通配符的token emission分数(向量化版本)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
frame_emission: 当前帧的emission概率向量
|
frame_emission: 当前帧的emission概率向量
|
||||||
tokens: token索引列表
|
tokens: token索引列表
|
||||||
|
blank_id: blank token的ID
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
tensor: 每个token位置的最大概率分数
|
tensor: 每个token位置的最大概率分数
|
||||||
"""
|
"""
|
||||||
assert 0 <= blank_id < len(frame_emission)
|
assert 0 <= blank_id < len(frame_emission)
|
||||||
scores = []
|
|
||||||
for token in tokens:
|
# 将tokens转换为tensor(如果还不是的话)
|
||||||
if token == -1: # 通配符
|
tokens = torch.tensor(tokens) if not isinstance(tokens, torch.Tensor) else tokens
|
||||||
valid_scores = torch.cat([frame_emission[:blank_id], frame_emission[blank_id + 1:]])
|
|
||||||
scores.append(torch.max(valid_scores))
|
# 创建mask来标识通配符位置
|
||||||
else:
|
wildcard_mask = (tokens == -1)
|
||||||
scores.append(frame_emission[token])
|
|
||||||
return torch.tensor(scores)
|
# 为非通配符位置获取分数
|
||||||
|
regular_scores = frame_emission[tokens.clamp(min=0)] # clamp避免-1索引
|
||||||
|
|
||||||
|
# 创建掩码并计算最大值,不会修改frame_emission
|
||||||
|
max_valid_score = frame_emission.clone() # 创建副本
|
||||||
|
max_valid_score[blank_id] = float('-inf') # 在副本上操作
|
||||||
|
max_valid_score = max_valid_score.max()
|
||||||
|
|
||||||
|
# 使用where操作来组合结果
|
||||||
|
result = torch.where(wildcard_mask, max_valid_score, regular_scores)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
Reference in New Issue
Block a user