change the docstrings and comments to English

This commit is contained in:
liupeng
2025-01-13 22:56:48 +08:00
parent 289eadfc76
commit ffbc73664c

View File

@ -403,33 +403,33 @@ def get_trellis(emission, tokens, blank_id=0):
def get_wildcard_emission(frame_emission, tokens, blank_id):
"""处理包含通配符的token emission分数(向量化版本)
"""Processing token emission scores containing wildcards (vectorized version)
Args:
frame_emission: 当前帧的emission概率向量
tokens: token索引列表
blank_id: blank token的ID
frame_emission: Emission probability vector for the current frame
tokens: List of token indices
blank_id: ID of the blank token
Returns:
tensor: 每个token位置的最大概率分数
tensor: Maximum probability score for each token position
"""
assert 0 <= blank_id < len(frame_emission)
# 将tokens转换为tensor如果还不是的话
# Convert tokens to a tensor if they are not already
tokens = torch.tensor(tokens) if not isinstance(tokens, torch.Tensor) else tokens
# 创建mask来标识通配符位置
# Create a mask to identify wildcard positions
wildcard_mask = (tokens == -1)
# 为非通配符位置获取分数
regular_scores = frame_emission[tokens.clamp(min=0)] # clamp避免-1索引
# Get scores for non-wildcard positions
regular_scores = frame_emission[tokens.clamp(min=0)] # clamp to avoid -1 index
# 创建掩码并计算最大值,不会修改frame_emission
max_valid_score = frame_emission.clone() # 创建副本
max_valid_score[blank_id] = float('-inf') # 在副本上操作
# Create a mask and compute the maximum value without modifying frame_emission
max_valid_score = frame_emission.clone() # Create a copy
max_valid_score[blank_id] = float('-inf') # Modify the copy to exclude the blank token
max_valid_score = max_valid_score.max()
# 使用where操作来组合结果
# Use where operation to combine results
result = torch.where(wildcard_mask, max_valid_score, regular_scores)
return result
@ -488,15 +488,26 @@ class Path:
@dataclass
class BeamState:
"""beam search中的状态"""
token_index: int # 当前token位置
time_index: int # 当前时间步
score: float # 累积分数
path: List[Point] # 路径历史
"""State in beam search."""
token_index: int # Current token position
time_index: int # Current time step
score: float # Cumulative score
path: List[Point] # Path history
def backtrack_beam(trellis, emission, tokens, blank_id=0, beam_width=5):
"""标准CTC beam search回溯实现
"""Standard CTC beam search backtracking implementation.
Args:
trellis (torch.Tensor): The trellis (or lattice) of shape (T, N), where T is the number of time steps
and N is the number of tokens (including the blank token).
emission (torch.Tensor): The emission probabilities of shape (T, N).
tokens (List[int]): List of token indices (excluding the blank token).
blank_id (int, optional): The ID of the blank token. Defaults to 0.
beam_width (int, optional): The number of top paths to keep during beam search. Defaults to 5.
Returns:
List[Point]: the best path
"""
T, J = trellis.size(0) - 1, trellis.size(1) - 1
@ -524,7 +535,7 @@ def backtrack_beam(trellis, emission, tokens, blank_id=0, beam_width=5):
stay_score = trellis[t - 1, j]
change_score = trellis[t - 1, j - 1] if j > 0 else float('-inf')
# Stay路径
# Stay
if not math.isinf(stay_score):
new_path = beam.path.copy()
new_path.append(Point(j, t - 1, p_stay.exp().item()))
@ -535,7 +546,7 @@ def backtrack_beam(trellis, emission, tokens, blank_id=0, beam_width=5):
path=new_path
))
# Change路径
# Change
if j > 0 and not math.isinf(change_score):
new_path = beam.path.copy()
new_path.append(Point(j - 1, t - 1, p_change.exp().item()))
@ -546,7 +557,7 @@ def backtrack_beam(trellis, emission, tokens, blank_id=0, beam_width=5):
path=new_path
))
# 只按分数排序,不需要去重
# sort by score
beams = sorted(next_beams, key=lambda x: x.score, reverse=True)[:beam_width]
if not beams: