Update decoding.py

Changes from https://github.com/openai/whisper/pull/914/
This commit is contained in:
Fernando O. Gallego
2023-03-24 11:56:41 +01:00
committed by GitHub
parent d1b4ff8228
commit 33dd3b9bcd

View File

@ -413,7 +413,8 @@ class ApplyTimestampRules(LogitFilter):
# timestamps have to appear in pairs, except directly before EOT; mask logits accordingly
for k in range(tokens.shape[0]):
seq = [t for t in tokens[k, self.sample_begin :].tolist()]
sampled_tokens = tokens[k, self.sample_begin :]
seq = [t for t in sampled_tokens.tolist()]
last_was_timestamp = len(seq) >= 1 and seq[-1] >= self.tokenizer.timestamp_begin
penultimate_was_timestamp = len(seq) < 2 or seq[-2] >= self.tokenizer.timestamp_begin
@ -422,6 +423,11 @@ class ApplyTimestampRules(LogitFilter):
logits[k, self.tokenizer.timestamp_begin :] = -np.inf
else: # cannot be normal text tokens
logits[k, : self.tokenizer.eot] = -np.inf
timestamps = sampled_tokens[sampled_tokens.ge(self.tokenizer.timestamp_begin)]
if timestamps.numel() > 0:
# timestamps shouldn't decrease; forbid timestamp tokens smaller than the last
logits[k, self.tokenizer.timestamp_begin : timestamps[-1]] = -np.inf
if tokens.shape[1] == self.sample_begin:
# suppress generating non-timestamp tokens at the beginning