From ffbc73664c67aa76c3a6c5ab108d2c2905ea1803 Mon Sep 17 00:00:00 2001 From: liupeng Date: Mon, 13 Jan 2025 22:56:48 +0800 Subject: [PATCH] change the docstrings and comments to English --- whisperx/alignment.py | 55 ++++++++++++++++++++++++++----------------- 1 file changed, 33 insertions(+), 22 deletions(-) diff --git a/whisperx/alignment.py b/whisperx/alignment.py index 39e48e5..0c4e2b0 100644 --- a/whisperx/alignment.py +++ b/whisperx/alignment.py @@ -403,33 +403,33 @@ def get_trellis(emission, tokens, blank_id=0): def get_wildcard_emission(frame_emission, tokens, blank_id): - """处理包含通配符的token emission分数(向量化版本) + """Processing token emission scores containing wildcards (vectorized version) Args: - frame_emission: 当前帧的emission概率向量 - tokens: token索引列表 - blank_id: blank token的ID + frame_emission: Emission probability vector for the current frame + tokens: List of token indices + blank_id: ID of the blank token Returns: - tensor: 每个token位置的最大概率分数 + tensor: Maximum probability score for each token position """ assert 0 <= blank_id < len(frame_emission) - # 将tokens转换为tensor(如果还不是的话) + # Convert tokens to a tensor if they are not already tokens = torch.tensor(tokens) if not isinstance(tokens, torch.Tensor) else tokens - # 创建mask来标识通配符位置 + # Create a mask to identify wildcard positions wildcard_mask = (tokens == -1) - # 为非通配符位置获取分数 - regular_scores = frame_emission[tokens.clamp(min=0)] # clamp避免-1索引 + # Get scores for non-wildcard positions + regular_scores = frame_emission[tokens.clamp(min=0)] # clamp to avoid -1 index - # 创建掩码并计算最大值,不会修改frame_emission - max_valid_score = frame_emission.clone() # 创建副本 - max_valid_score[blank_id] = float('-inf') # 在副本上操作 + # Create a mask and compute the maximum value without modifying frame_emission + max_valid_score = frame_emission.clone() # Create a copy + max_valid_score[blank_id] = float('-inf') # Modify the copy to exclude the blank token max_valid_score = max_valid_score.max() - # 使用where操作来组合结果 + # Use where operation to combine results result = torch.where(wildcard_mask, max_valid_score, regular_scores) return result @@ -488,15 +488,26 @@ class Path: @dataclass class BeamState: - """beam search中的状态""" - token_index: int # 当前token位置 - time_index: int # 当前时间步 - score: float # 累积分数 - path: List[Point] # 路径历史 + """State in beam search.""" + token_index: int # Current token position + time_index: int # Current time step + score: float # Cumulative score + path: List[Point] # Path history def backtrack_beam(trellis, emission, tokens, blank_id=0, beam_width=5): - """标准CTC beam search回溯实现 + """Standard CTC beam search backtracking implementation. + + Args: + trellis (torch.Tensor): The trellis (or lattice) of shape (T, N), where T is the number of time steps + and N is the number of tokens (including the blank token). + emission (torch.Tensor): The emission probabilities of shape (T, N). + tokens (List[int]): List of token indices (excluding the blank token). + blank_id (int, optional): The ID of the blank token. Defaults to 0. + beam_width (int, optional): The number of top paths to keep during beam search. Defaults to 5. + + Returns: + List[Point]: the best path """ T, J = trellis.size(0) - 1, trellis.size(1) - 1 @@ -524,7 +535,7 @@ def backtrack_beam(trellis, emission, tokens, blank_id=0, beam_width=5): stay_score = trellis[t - 1, j] change_score = trellis[t - 1, j - 1] if j > 0 else float('-inf') - # Stay路径 + # Stay if not math.isinf(stay_score): new_path = beam.path.copy() new_path.append(Point(j, t - 1, p_stay.exp().item())) @@ -535,7 +546,7 @@ def backtrack_beam(trellis, emission, tokens, blank_id=0, beam_width=5): path=new_path )) - # Change路径 + # Change if j > 0 and not math.isinf(change_score): new_path = beam.path.copy() new_path.append(Point(j - 1, t - 1, p_change.exp().item())) @@ -546,7 +557,7 @@ def backtrack_beam(trellis, emission, tokens, blank_id=0, beam_width=5): path=new_path )) - # 只按分数排序,不需要去重 + # sort by score beams = sorted(next_beams, key=lambda x: x.score, reverse=True)[:beam_width] if not beams: