mirror of
https://github.com/m-bain/whisperX.git
synced 2025-07-01 18:17:27 -04:00
change the docstrings and comments to English
This commit is contained in:
@ -403,33 +403,33 @@ def get_trellis(emission, tokens, blank_id=0):
|
||||
|
||||
|
||||
def get_wildcard_emission(frame_emission, tokens, blank_id):
|
||||
"""处理包含通配符的token emission分数(向量化版本)
|
||||
"""Processing token emission scores containing wildcards (vectorized version)
|
||||
|
||||
Args:
|
||||
frame_emission: 当前帧的emission概率向量
|
||||
tokens: token索引列表
|
||||
blank_id: blank token的ID
|
||||
frame_emission: Emission probability vector for the current frame
|
||||
tokens: List of token indices
|
||||
blank_id: ID of the blank token
|
||||
|
||||
Returns:
|
||||
tensor: 每个token位置的最大概率分数
|
||||
tensor: Maximum probability score for each token position
|
||||
"""
|
||||
assert 0 <= blank_id < len(frame_emission)
|
||||
|
||||
# 将tokens转换为tensor(如果还不是的话)
|
||||
# Convert tokens to a tensor if they are not already
|
||||
tokens = torch.tensor(tokens) if not isinstance(tokens, torch.Tensor) else tokens
|
||||
|
||||
# 创建mask来标识通配符位置
|
||||
# Create a mask to identify wildcard positions
|
||||
wildcard_mask = (tokens == -1)
|
||||
|
||||
# 为非通配符位置获取分数
|
||||
regular_scores = frame_emission[tokens.clamp(min=0)] # clamp避免-1索引
|
||||
# Get scores for non-wildcard positions
|
||||
regular_scores = frame_emission[tokens.clamp(min=0)] # clamp to avoid -1 index
|
||||
|
||||
# 创建掩码并计算最大值,不会修改frame_emission
|
||||
max_valid_score = frame_emission.clone() # 创建副本
|
||||
max_valid_score[blank_id] = float('-inf') # 在副本上操作
|
||||
# Create a mask and compute the maximum value without modifying frame_emission
|
||||
max_valid_score = frame_emission.clone() # Create a copy
|
||||
max_valid_score[blank_id] = float('-inf') # Modify the copy to exclude the blank token
|
||||
max_valid_score = max_valid_score.max()
|
||||
|
||||
# 使用where操作来组合结果
|
||||
# Use where operation to combine results
|
||||
result = torch.where(wildcard_mask, max_valid_score, regular_scores)
|
||||
|
||||
return result
|
||||
@ -488,15 +488,26 @@ class Path:
|
||||
|
||||
@dataclass
|
||||
class BeamState:
|
||||
"""beam search中的状态"""
|
||||
token_index: int # 当前token位置
|
||||
time_index: int # 当前时间步
|
||||
score: float # 累积分数
|
||||
path: List[Point] # 路径历史
|
||||
"""State in beam search."""
|
||||
token_index: int # Current token position
|
||||
time_index: int # Current time step
|
||||
score: float # Cumulative score
|
||||
path: List[Point] # Path history
|
||||
|
||||
|
||||
def backtrack_beam(trellis, emission, tokens, blank_id=0, beam_width=5):
|
||||
"""标准CTC beam search回溯实现
|
||||
"""Standard CTC beam search backtracking implementation.
|
||||
|
||||
Args:
|
||||
trellis (torch.Tensor): The trellis (or lattice) of shape (T, N), where T is the number of time steps
|
||||
and N is the number of tokens (including the blank token).
|
||||
emission (torch.Tensor): The emission probabilities of shape (T, N).
|
||||
tokens (List[int]): List of token indices (excluding the blank token).
|
||||
blank_id (int, optional): The ID of the blank token. Defaults to 0.
|
||||
beam_width (int, optional): The number of top paths to keep during beam search. Defaults to 5.
|
||||
|
||||
Returns:
|
||||
List[Point]: the best path
|
||||
"""
|
||||
T, J = trellis.size(0) - 1, trellis.size(1) - 1
|
||||
|
||||
@ -524,7 +535,7 @@ def backtrack_beam(trellis, emission, tokens, blank_id=0, beam_width=5):
|
||||
stay_score = trellis[t - 1, j]
|
||||
change_score = trellis[t - 1, j - 1] if j > 0 else float('-inf')
|
||||
|
||||
# Stay路径
|
||||
# Stay
|
||||
if not math.isinf(stay_score):
|
||||
new_path = beam.path.copy()
|
||||
new_path.append(Point(j, t - 1, p_stay.exp().item()))
|
||||
@ -535,7 +546,7 @@ def backtrack_beam(trellis, emission, tokens, blank_id=0, beam_width=5):
|
||||
path=new_path
|
||||
))
|
||||
|
||||
# Change路径
|
||||
# Change
|
||||
if j > 0 and not math.isinf(change_score):
|
||||
new_path = beam.path.copy()
|
||||
new_path.append(Point(j - 1, t - 1, p_change.exp().item()))
|
||||
@ -546,7 +557,7 @@ def backtrack_beam(trellis, emission, tokens, blank_id=0, beam_width=5):
|
||||
path=new_path
|
||||
))
|
||||
|
||||
# 只按分数排序,不需要去重
|
||||
# sort by score
|
||||
beams = sorted(next_beams, key=lambda x: x.score, reverse=True)[:beam_width]
|
||||
|
||||
if not beams:
|
||||
|
Reference in New Issue
Block a user