mirror of
https://github.com/m-bain/whisperX.git
synced 2025-07-01 18:17:27 -04:00
make align a bit faster.
This commit is contained in:
@ -254,7 +254,7 @@ def align(
|
||||
|
||||
trellis = get_trellis(emission, tokens, blank_id)
|
||||
# path = backtrack(trellis, emission, tokens, blank_id)
|
||||
path = backtrack_beam(trellis, emission, tokens, blank_id)
|
||||
path = backtrack_beam(trellis, emission, tokens, blank_id, beam_width=2)
|
||||
|
||||
if path is None:
|
||||
print(f'Failed to align segment ("{segment["text"]}"): backtrack failed, resorting to original...')
|
||||
@ -393,24 +393,36 @@ def get_trellis(emission, tokens, blank_id=0):
|
||||
|
||||
|
||||
def get_wildcard_emission(frame_emission, tokens, blank_id):
|
||||
"""处理包含通配符的token emission分数
|
||||
"""处理包含通配符的token emission分数(向量化版本)
|
||||
|
||||
Args:
|
||||
frame_emission: 当前帧的emission概率向量
|
||||
tokens: token索引列表
|
||||
blank_id: blank token的ID
|
||||
|
||||
Returns:
|
||||
tensor: 每个token位置的最大概率分数
|
||||
"""
|
||||
assert 0 <= blank_id < len(frame_emission)
|
||||
scores = []
|
||||
for token in tokens:
|
||||
if token == -1: # 通配符
|
||||
valid_scores = torch.cat([frame_emission[:blank_id], frame_emission[blank_id + 1:]])
|
||||
scores.append(torch.max(valid_scores))
|
||||
else:
|
||||
scores.append(frame_emission[token])
|
||||
return torch.tensor(scores)
|
||||
|
||||
# 将tokens转换为tensor(如果还不是的话)
|
||||
tokens = torch.tensor(tokens) if not isinstance(tokens, torch.Tensor) else tokens
|
||||
|
||||
# 创建mask来标识通配符位置
|
||||
wildcard_mask = (tokens == -1)
|
||||
|
||||
# 为非通配符位置获取分数
|
||||
regular_scores = frame_emission[tokens.clamp(min=0)] # clamp避免-1索引
|
||||
|
||||
# 创建掩码并计算最大值,不会修改frame_emission
|
||||
max_valid_score = frame_emission.clone() # 创建副本
|
||||
max_valid_score[blank_id] = float('-inf') # 在副本上操作
|
||||
max_valid_score = max_valid_score.max()
|
||||
|
||||
# 使用where操作来组合结果
|
||||
result = torch.where(wildcard_mask, max_valid_score, regular_scores)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@dataclass
|
||||
|
Reference in New Issue
Block a user