change the docstrings and comments to English

This commit is contained in:
liupeng
2025-01-13 22:56:48 +08:00
parent 289eadfc76
commit ffbc73664c

View File

@ -403,33 +403,33 @@ def get_trellis(emission, tokens, blank_id=0):
def get_wildcard_emission(frame_emission, tokens, blank_id):
    """Return the emission score of each token for one frame, with wildcard support.

    Vectorized: a token index of ``-1`` is treated as a wildcard and receives
    the highest score in the frame among all non-blank entries. Every other
    token receives its own entry of ``frame_emission``.

    Args:
        frame_emission: Emission score vector for the current frame.
        tokens: Token indices (list or tensor); ``-1`` marks a wildcard.
            May be empty, in which case an empty tensor is returned.
        blank_id: Index of the blank token within ``frame_emission``.

    Returns:
        torch.Tensor: One score per position in ``tokens``.

    Raises:
        AssertionError: If ``blank_id`` is not a valid index into
            ``frame_emission``.
    """
    assert 0 <= blank_id < len(frame_emission)

    # Build the index as int64 on the same device as the emissions.
    # A bare torch.tensor([]) is float32 (not usable as an index), and a CPU
    # mask cannot be combined with accelerator tensors in torch.where.
    tokens = torch.as_tensor(tokens, dtype=torch.long, device=frame_emission.device)

    # Positions that hold the wildcard marker.
    wildcard_mask = tokens == -1

    # Scores for regular tokens; clamp so the -1 marker never wraps around to
    # the last entry (those positions are overwritten below anyway).
    regular_scores = frame_emission[tokens.clamp(min=0)]

    # Best non-blank score in the frame, computed on a copy so the caller's
    # tensor is left unmodified.
    max_valid_score = frame_emission.clone()
    max_valid_score[blank_id] = float("-inf")
    max_valid_score = max_valid_score.max()

    # Wildcards take the best non-blank score; everything else keeps its own.
    return torch.where(wildcard_mask, max_valid_score, regular_scores)
@ -488,15 +488,26 @@ class Path:
@dataclass
class BeamState:
    """A single hypothesis tracked during beam search.

    Attributes:
        token_index: Current token position.
        time_index: Current time step.
        score: Cumulative score.
        path: Path history.
    """

    token_index: int
    time_index: int
    score: float
    path: List[Point]
def backtrack_beam(trellis, emission, tokens, blank_id=0, beam_width=5): def backtrack_beam(trellis, emission, tokens, blank_id=0, beam_width=5):
"""标准CTC beam search回溯实现 """Standard CTC beam search backtracking implementation.
Args:
trellis (torch.Tensor): The trellis (or lattice) of shape (T, N), where T is the number of time steps
and N is the number of tokens (including the blank token).
emission (torch.Tensor): The emission probabilities of shape (T, N).
tokens (List[int]): List of token indices (excluding the blank token).
blank_id (int, optional): The ID of the blank token. Defaults to 0.
beam_width (int, optional): The number of top paths to keep during beam search. Defaults to 5.
Returns:
List[Point]: the best path
""" """
T, J = trellis.size(0) - 1, trellis.size(1) - 1 T, J = trellis.size(0) - 1, trellis.size(1) - 1
@ -524,7 +535,7 @@ def backtrack_beam(trellis, emission, tokens, blank_id=0, beam_width=5):
stay_score = trellis[t - 1, j] stay_score = trellis[t - 1, j]
change_score = trellis[t - 1, j - 1] if j > 0 else float('-inf') change_score = trellis[t - 1, j - 1] if j > 0 else float('-inf')
# Stay路径 # Stay
if not math.isinf(stay_score): if not math.isinf(stay_score):
new_path = beam.path.copy() new_path = beam.path.copy()
new_path.append(Point(j, t - 1, p_stay.exp().item())) new_path.append(Point(j, t - 1, p_stay.exp().item()))
@ -535,7 +546,7 @@ def backtrack_beam(trellis, emission, tokens, blank_id=0, beam_width=5):
path=new_path path=new_path
)) ))
# Change路径 # Change
if j > 0 and not math.isinf(change_score): if j > 0 and not math.isinf(change_score):
new_path = beam.path.copy() new_path = beam.path.copy()
new_path.append(Point(j - 1, t - 1, p_change.exp().item())) new_path.append(Point(j - 1, t - 1, p_change.exp().item()))
@ -546,7 +557,7 @@ def backtrack_beam(trellis, emission, tokens, blank_id=0, beam_width=5):
path=new_path path=new_path
)) ))
# 只按分数排序,不需要去重 # sort by score
beams = sorted(next_beams, key=lambda x: x.score, reverse=True)[:beam_width] beams = sorted(next_beams, key=lambda x: x.score, reverse=True)[:beam_width]
if not beams: if not beams: