make align a bit faster.

This commit is contained in:
liupeng
2025-01-09 19:33:26 +08:00
parent 69281f3a29
commit 65b2332e13

View File

@ -254,7 +254,7 @@ def align(
trellis = get_trellis(emission, tokens, blank_id)
# path = backtrack(trellis, emission, tokens, blank_id)
path = backtrack_beam(trellis, emission, tokens, blank_id)
path = backtrack_beam(trellis, emission, tokens, blank_id, beam_width=2)
if path is None:
print(f'Failed to align segment ("{segment["text"]}"): backtrack failed, resorting to original...')
@ -393,24 +393,36 @@ def get_trellis(emission, tokens, blank_id=0):
def get_wildcard_emission(frame_emission, tokens, blank_id):
"""处理包含通配符的token emission分数
"""处理包含通配符的token emission分数(向量化版本)
Args:
frame_emission: 当前帧的emission概率向量
tokens: token索引列表
blank_id: blank token的ID
Returns:
tensor: 每个token位置的最大概率分数
"""
assert 0 <= blank_id < len(frame_emission)
scores = []
for token in tokens:
if token == -1: # 通配符
valid_scores = torch.cat([frame_emission[:blank_id], frame_emission[blank_id + 1:]])
scores.append(torch.max(valid_scores))
else:
scores.append(frame_emission[token])
return torch.tensor(scores)
# 将tokens转换为tensor如果还不是的话
tokens = torch.tensor(tokens) if not isinstance(tokens, torch.Tensor) else tokens
# 创建mask来标识通配符位置
wildcard_mask = (tokens == -1)
# 为非通配符位置获取分数
regular_scores = frame_emission[tokens.clamp(min=0)] # clamp避免-1索引
# 创建掩码并计算最大值不会修改frame_emission
max_valid_score = frame_emission.clone() # 创建副本
max_valid_score[blank_id] = float('-inf') # 在副本上操作
max_valid_score = max_valid_score.max()
# 使用where操作来组合结果
result = torch.where(wildcard_mask, max_valid_score, regular_scores)
return result
@dataclass