mirror of
https://github.com/m-bain/whisperX.git
synced 2025-07-01 18:17:27 -04:00
118 lines
3.8 KiB
Python
118 lines
3.8 KiB
Python
![]() |
"""
|
||
|
source: https://pytorch.org/tutorials/intermediate/forced_alignment_with_torchaudio_tutorial.html
|
||
|
"""
|
||
|
import torch
|
||
|
from dataclasses import dataclass
|
||
|
|
||
|
def get_trellis(emission, tokens, blank_id=0):
    """Build the forced-alignment trellis for ``tokens`` over ``emission``.

    ``trellis[t, j]`` is the best cumulative score after ``t`` frames with the
    first ``j`` tokens entered. At every frame the path either stays on the
    current token (scored with the blank emission) or advances to the next
    token (scored with that token's emission).

    Args:
        emission: 2-D tensor indexed as ``[frame, label]``; presumably
            per-frame log-probabilities (TODO confirm with caller).
        tokens: Sequence of label indices, one per transcript token, used to
            index the label axis of ``emission``.
        blank_id: Label index of the blank symbol.

    Returns:
        Tensor of shape ``(num_frame + 1, num_tokens + 1)``.
    """
    num_frame = emission.size(0)
    num_tokens = len(tokens)

    # Trellis has extra dimensions for both the time axis and the tokens.
    # The extra column (index 0) represents <SoS> (start-of-sentence);
    # the extra row (index 0) is the state before any frame was consumed.
    trellis = torch.empty((num_frame + 1, num_tokens + 1))
    trellis[0, 0] = 0
    # Staying in <SoS> means emitting blanks, so accumulate the blank column.
    # This must follow `blank_id`; a hard-coded 0 is wrong when blank_id != 0.
    trellis[1:, 0] = torch.cumsum(emission[:, blank_id], 0)
    # No token can be entered before the first frame.
    trellis[0, -num_tokens:] = -float("inf")
    # Sentinel kept from the referenced torchaudio tutorial: the last
    # `num_tokens` rows of the <SoS> column are set to +inf.
    # NOTE(review): with an empty `tokens`, `-num_tokens:` is `-0:` and selects
    # the whole row/column -- confirm callers never pass an empty transcript.
    trellis[-num_tokens:, 0] = float("inf")

    for t in range(num_frame):
        trellis[t + 1, 1:] = torch.maximum(
            # Score for staying at the same token
            trellis[t, 1:] + emission[t, blank_id],
            # Score for changing to the next token
            trellis[t, :-1] + emission[t, tokens],
        )
    return trellis
|
||
|
|
||
|
@dataclass
class Point:
    """One step of an alignment path, in non-trellis coordinates."""

    token_index: int  # position in the token sequence
    time_index: int  # frame index into the emission
    score: float  # exp() of the emission entry chosen at this frame


def backtrack(trellis, emission, tokens, blank_id=0):
    """Trace the best alignment path back through ``trellis``.

    Starting at the frame where the last token's trellis score peaks, walk
    backwards one frame at a time, deciding at each step whether the path
    stayed on the current token or had just changed to it.

    Args:
        trellis: Score table with one extra leading row and column relative to
            ``emission``/``tokens`` (trellis index ``T`` <-> emission frame
            ``T - 1``; trellis index ``J`` <-> token ``J - 1``).
        emission: 2-D tensor indexed as ``[frame, label]``; ``exp()`` is
            applied to its entries to obtain the per-frame score.
        tokens: Sequence of label indices, one per transcript token.
        blank_id: Label index of the blank symbol.

    Returns:
        List of ``Point`` in increasing time order.

    Raises:
        ValueError: If frame 0 is reached before every token was consumed
            (including when the starting frame is already 0).
    """
    j = trellis.size(1) - 1
    # Start where the final token's score is highest, not necessarily at the
    # last frame.
    t_start = torch.argmax(trellis[:, j]).item()

    path = []
    for t in range(t_start, 0, -1):
        # 1. Figure out if the current position was stay or change.
        # Score for staying on token J from trellis frame T-1 to T.
        stayed = trellis[t - 1, j] + emission[t - 1, blank_id]
        # Score for changing from token J-1 at frame T-1 to token J at T.
        changed = trellis[t - 1, j - 1] + emission[t - 1, tokens[j - 1]]
        moved = changed > stayed

        # 2. Store the path with frame-wise probability.
        # A "stay" frame is scored with the blank label, so it must follow
        # `blank_id`; a hard-coded 0 is wrong when blank_id != 0.
        label = tokens[j - 1] if moved else blank_id
        prob = emission[t - 1, label].exp().item()
        # Return token index and time index in non-trellis coordinate.
        path.append(Point(j - 1, t - 1, prob))

        # 3. Update the token.
        if moved:
            j -= 1
            if j == 0:
                break
    else:
        # The loop ran out of frames without consuming every token.
        raise ValueError("Failed to align")
    return path[::-1]
|
||
|
|
||
|
# Merge the labels
|
||
|
@dataclass
class Segment:
    """A labelled span with a score.

    ``start`` and ``end`` delimit the span; it is printed as the half-open
    interval ``[start, end)`` and ``length`` is their difference.
    """

    label: str
    start: int
    end: int
    score: float

    @property
    def length(self):
        """Number of positions covered by the span (``end - start``)."""
        return self.end - self.start

    def __repr__(self):
        # Custom one-line form: label, score, then the half-open interval.
        return f"{self.label}\t({self.score:4.2f}): [{self.start:5d}, {self.end:5d})"
|
||
|
|
||
|
def merge_repeats(path, transcript):
    """Collapse consecutive path points that share a token into segments.

    Each run of adjacent points with the same ``token_index`` becomes one
    ``Segment`` labelled ``transcript[token_index]``. The segment starts at
    the run's first ``time_index``, ends one past its last ``time_index``,
    and is scored with the mean of the run's point scores.
    """
    merged = []
    total = len(path)
    run_start = 0
    while run_start < total:
        token = path[run_start].token_index

        # Extend the run while the following points carry the same token.
        run_end = run_start + 1
        while run_end < total and path[run_end].token_index == token:
            run_end += 1

        run = path[run_start:run_end]
        mean_score = sum(point.score for point in run) / len(run)
        merged.append(
            Segment(
                transcript[token],
                run[0].time_index,
                run[-1].time_index + 1,  # exclusive end
                mean_score,
            )
        )
        run_start = run_end
    return merged
|
||
|
|
||
|
def merge_words(segments, separator="|"):
    """Join the segments between separators into word-level segments.

    Segments whose label equals ``separator`` act as word boundaries and are
    dropped. Each non-empty group of segments between boundaries becomes one
    ``Segment``: its label is the concatenated labels, it spans from the
    first member's ``start`` to the last member's ``end``, and its score is
    the length-weighted mean of the member scores.
    """
    merged = []
    total = len(segments)
    word_start = 0
    # Position `total` is treated as a final boundary so that a trailing
    # word without a closing separator is still emitted.
    for pos in range(total + 1):
        at_boundary = pos == total or segments[pos].label == separator
        if not at_boundary:
            continue
        if pos > word_start:
            chunk = segments[word_start:pos]
            text = "".join(seg.label for seg in chunk)
            weighted = sum(seg.score * seg.length for seg in chunk)
            span = sum(seg.length for seg in chunk)
            merged.append(Segment(text, chunk[0].start, chunk[-1].end, weighted / span))
        # Skip past the separator itself.
        word_start = pos + 1
    return merged
|