Fix: Ensure integer tensor indexing in get_wildcard_emission()

This commit is contained in:
Howard
2025-05-15 14:30:55 +08:00
committed by Barabazs
parent ffedc5cdf0
commit e0833da5dc

View File

@ -424,7 +424,7 @@ def get_wildcard_emission(frame_emission, tokens, blank_id):
wildcard_mask = (tokens == -1)
# Get scores for non-wildcard positions
regular_scores = frame_emission[tokens.clamp(min=0)] # clamp to avoid -1 index
regular_scores = frame_emission[tokens.clamp(min=0).long()] # clamp to avoid -1 index
# Create a mask and compute the maximum value without modifying frame_emission
max_valid_score = frame_emission.clone() # Create a copy