From 65b2332e139843cc553f589dc54e004e3c25ba39 Mon Sep 17 00:00:00 2001 From: liupeng Date: Thu, 9 Jan 2025 19:33:26 +0800 Subject: [PATCH] make align a bit faster. --- whisperx/alignment.py | 32 ++++++++++++++++++++++---------- 1 file changed, 22 insertions(+), 10 deletions(-) diff --git a/whisperx/alignment.py b/whisperx/alignment.py index 03b5a49..872859f 100644 --- a/whisperx/alignment.py +++ b/whisperx/alignment.py @@ -254,7 +254,7 @@ def align( trellis = get_trellis(emission, tokens, blank_id) # path = backtrack(trellis, emission, tokens, blank_id) - path = backtrack_beam(trellis, emission, tokens, blank_id) + path = backtrack_beam(trellis, emission, tokens, blank_id, beam_width=2) if path is None: print(f'Failed to align segment ("{segment["text"]}"): backtrack failed, resorting to original...') @@ -393,24 +393,36 @@ def get_trellis(emission, tokens, blank_id=0): def get_wildcard_emission(frame_emission, tokens, blank_id): - """处理包含通配符的token emission分数 + """处理包含通配符的token emission分数(向量化版本) Args: frame_emission: 当前帧的emission概率向量 tokens: token索引列表 + blank_id: blank token的ID Returns: tensor: 每个token位置的最大概率分数 """ assert 0 <= blank_id < len(frame_emission) - scores = [] - for token in tokens: - if token == -1: # 通配符 - valid_scores = torch.cat([frame_emission[:blank_id], frame_emission[blank_id + 1:]]) - scores.append(torch.max(valid_scores)) - else: - scores.append(frame_emission[token]) - return torch.tensor(scores) + + # 将tokens转换为tensor(如果还不是的话) + tokens = torch.tensor(tokens) if not isinstance(tokens, torch.Tensor) else tokens + + # 创建mask来标识通配符位置 + wildcard_mask = (tokens == -1) + + # 为非通配符位置获取分数 + regular_scores = frame_emission[tokens.clamp(min=0)] # clamp避免-1索引 + + # 创建掩码并计算最大值,不会修改frame_emission + max_valid_score = frame_emission.clone() # 创建副本 + max_valid_score[blank_id] = float('-inf') # 在副本上操作 + max_valid_score = max_valid_score.max() + + # 使用where操作来组合结果 + result = torch.where(wildcard_mask, max_valid_score, regular_scores) + + return result @dataclass