"""
|
|
Forced Alignment with Whisper
|
|
C. Max Bain
|
|
"""
|
|
import math
|
|
|
|
from dataclasses import dataclass
|
|
from typing import Iterable, Optional, Union, List
|
|
|
|
import numpy as np
|
|
import pandas as pd
|
|
import torch
|
|
import torchaudio
|
|
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
|
|
|
|
from whisperx.audio import SAMPLE_RATE, load_audio
|
|
from whisperx.utils import interpolate_nans
|
|
from whisperx.types import (
|
|
AlignedTranscriptionResult,
|
|
SingleSegment,
|
|
SingleAlignedSegment,
|
|
SingleWordSegment,
|
|
SegmentData,
|
|
)
|
|
from nltk.tokenize.punkt import PunktSentenceTokenizer, PunktParameters
|
|
|
|
PUNKT_ABBREVIATIONS = ['dr', 'vs', 'mr', 'mrs', 'prof']
|
|
|
|
LANGUAGES_WITHOUT_SPACES = ["ja", "zh"]
|
|
|
|
DEFAULT_ALIGN_MODELS_TORCH = {
    "en": "WAV2VEC2_ASR_BASE_960H",
    "fr": "VOXPOPULI_ASR_BASE_10K_FR",
    "de": "VOXPOPULI_ASR_BASE_10K_DE",
    "es": "VOXPOPULI_ASR_BASE_10K_ES",
    "it": "VOXPOPULI_ASR_BASE_10K_IT",
}

DEFAULT_ALIGN_MODELS_HF = {
    "ja": "jonatasgrosman/wav2vec2-large-xlsr-53-japanese",
    "zh": "jonatasgrosman/wav2vec2-large-xlsr-53-chinese-zh-cn",
    "nl": "jonatasgrosman/wav2vec2-large-xlsr-53-dutch",
    "uk": "Yehor/wav2vec2-xls-r-300m-uk-with-small-lm",
    "pt": "jonatasgrosman/wav2vec2-large-xlsr-53-portuguese",
    "ar": "jonatasgrosman/wav2vec2-large-xlsr-53-arabic",
    "cs": "comodoro/wav2vec2-xls-r-300m-cs-250",
    "ru": "jonatasgrosman/wav2vec2-large-xlsr-53-russian",
    "pl": "jonatasgrosman/wav2vec2-large-xlsr-53-polish",
    "hu": "jonatasgrosman/wav2vec2-large-xlsr-53-hungarian",
    "fi": "jonatasgrosman/wav2vec2-large-xlsr-53-finnish",
    "fa": "jonatasgrosman/wav2vec2-large-xlsr-53-persian",
    "el": "jonatasgrosman/wav2vec2-large-xlsr-53-greek",
    "tr": "mpoyraz/wav2vec2-xls-r-300m-cv7-turkish",
    "da": "saattrupdan/wav2vec2-xls-r-300m-ftspeech",
    "he": "imvladikon/wav2vec2-xls-r-300m-hebrew",
    "vi": "nguyenvulebinh/wav2vec2-base-vi",
    "ko": "kresnik/wav2vec2-large-xlsr-korean",
    "ur": "kingabzpro/wav2vec2-large-xls-r-300m-Urdu",
    "te": "anuragshas/wav2vec2-large-xlsr-53-telugu",
    "hi": "theainerd/Wav2Vec2-large-xlsr-hindi",
    "ca": "softcatala/wav2vec2-large-xlsr-catala",
    "ml": "gvs/wav2vec2-large-xlsr-malayalam",
    "no": "NbAiLab/nb-wav2vec2-1b-bokmaal-v2",
    "nn": "NbAiLab/nb-wav2vec2-1b-nynorsk",
    "sk": "comodoro/wav2vec2-xls-r-300m-sk-cv8",
    "sl": "anton-l/wav2vec2-large-xlsr-53-slovenian",
    "hr": "classla/wav2vec2-xls-r-parlaspeech-hr",
    "ro": "gigant/romanian-wav2vec2",
    "eu": "stefan-it/wav2vec2-large-xlsr-53-basque",
    "gl": "ifrz/wav2vec2-large-xlsr-galician",
    "ka": "xsway/wav2vec2-large-xlsr-georgian",
    "lv": "jimregan/wav2vec2-large-xlsr-latvian-cv",
    "tl": "Khalsuu/filipino-wav2vec2-l-xls-r-300m-official",
}


def load_align_model(language_code: str, device: str, model_name: Optional[str] = None, model_dir=None):
    """Load a wav2vec2 alignment model for the given language.

    Returns the model (moved to `device`) together with metadata holding the
    language code, the character-to-token dictionary, and the pipeline type
    ("torchaudio" or "huggingface").
    """
    if model_name is None:
        # use default model
        if language_code in DEFAULT_ALIGN_MODELS_TORCH:
            model_name = DEFAULT_ALIGN_MODELS_TORCH[language_code]
        elif language_code in DEFAULT_ALIGN_MODELS_HF:
            model_name = DEFAULT_ALIGN_MODELS_HF[language_code]
        else:
            print(f"There is no default alignment model set for this language ({language_code}).\
                Please find a wav2vec2.0 model finetuned on this language in https://huggingface.co/models, then pass the model name in --align_model [MODEL_NAME]")
            raise ValueError(f"No default align-model for language: {language_code}")

    if model_name in torchaudio.pipelines.__all__:
        pipeline_type = "torchaudio"
        bundle = torchaudio.pipelines.__dict__[model_name]
        align_model = bundle.get_model(dl_kwargs={"model_dir": model_dir}).to(device)
        labels = bundle.get_labels()
        align_dictionary = {c.lower(): i for i, c in enumerate(labels)}
    else:
        try:
            processor = Wav2Vec2Processor.from_pretrained(model_name, cache_dir=model_dir)
            align_model = Wav2Vec2ForCTC.from_pretrained(model_name, cache_dir=model_dir)
        except Exception as e:
            print(e)
            print("Error loading model from huggingface, check https://huggingface.co/models for finetuned wav2vec2.0 models")
            raise ValueError(f'The chosen align_model "{model_name}" could not be found in huggingface (https://huggingface.co/models) or torchaudio (https://pytorch.org/audio/stable/pipelines.html#id14)')
        pipeline_type = "huggingface"
        align_model = align_model.to(device)
        labels = processor.tokenizer.get_vocab()
        align_dictionary = {char.lower(): code for char, code in labels.items()}

    align_metadata = {"language": language_code, "dictionary": align_dictionary, "type": pipeline_type}

    return align_model, align_metadata
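
# Example (illustrative sketch, not executed on import): loading the default
# German alignment model onto a GPU. "de" maps to a torchaudio VoxPopuli
# bundle above, so the returned metadata reports pipeline type "torchaudio".
#
#     align_model, align_metadata = load_align_model("de", device="cuda")
#     align_metadata["type"]        # "torchaudio"
#     align_metadata["dictionary"]  # lower-cased character -> token id, used by align()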


def align(
    transcript: Iterable[SingleSegment],
    model: torch.nn.Module,
    align_model_metadata: dict,
    audio: Union[str, np.ndarray, torch.Tensor],
    device: str,
    interpolate_method: str = "nearest",
    return_char_alignments: bool = False,
    print_progress: bool = False,
    combined_progress: bool = False,
) -> AlignedTranscriptionResult:
    """
    Align phoneme recognition predictions to known transcription.
    """

    if not torch.is_tensor(audio):
        if isinstance(audio, str):
            audio = load_audio(audio)
        audio = torch.from_numpy(audio)
    if len(audio.shape) == 1:
        audio = audio.unsqueeze(0)

    MAX_DURATION = audio.shape[1] / SAMPLE_RATE

    model_dictionary = align_model_metadata["dictionary"]
    model_lang = align_model_metadata["language"]
    model_type = align_model_metadata["type"]

    # 1. Preprocess to keep only characters in dictionary
    total_segments = len(transcript)
    # Store temporary processing values
    segment_data: dict[int, SegmentData] = {}
    for sdx, segment in enumerate(transcript):
        # strip spaces at beginning / end, but keep track of the amount.
        if print_progress:
            base_progress = ((sdx + 1) / total_segments) * 100
            percent_complete = (50 + base_progress / 2) if combined_progress else base_progress
            print(f"Progress: {percent_complete:.2f}%...")

        num_leading = len(segment["text"]) - len(segment["text"].lstrip())
        num_trailing = len(segment["text"]) - len(segment["text"].rstrip())
        text = segment["text"]

        # split into words
        if model_lang not in LANGUAGES_WITHOUT_SPACES:
            per_word = text.split(" ")
        else:
            per_word = text

        clean_char, clean_cdx = [], []
        for cdx, char in enumerate(text):
            char_ = char.lower()
            # wav2vec2 models use "|" character to represent spaces
            if model_lang not in LANGUAGES_WITHOUT_SPACES:
                char_ = char_.replace(" ", "|")

            # ignore whitespace at beginning and end of transcript
            if cdx < num_leading:
                pass
            elif cdx > len(text) - num_trailing - 1:
                pass
            elif char_ in model_dictionary.keys():
                clean_char.append(char_)
                clean_cdx.append(cdx)
            else:
                # add placeholder
                clean_char.append('*')
                clean_cdx.append(cdx)

        clean_wdx = []
        for wdx, wrd in enumerate(per_word):
            if any([c in model_dictionary.keys() for c in wrd.lower()]):
                clean_wdx.append(wdx)
            else:
                # index for placeholder
                clean_wdx.append(wdx)

        punkt_param = PunktParameters()
        punkt_param.abbrev_types = set(PUNKT_ABBREVIATIONS)
        sentence_splitter = PunktSentenceTokenizer(punkt_param)
        sentence_spans = list(sentence_splitter.span_tokenize(text))

        segment_data[sdx] = {
            "clean_char": clean_char,
            "clean_cdx": clean_cdx,
            "clean_wdx": clean_wdx,
            "sentence_spans": sentence_spans
        }

    aligned_segments: List[SingleAlignedSegment] = []

    # 2. Get prediction matrix from alignment model & align
    for sdx, segment in enumerate(transcript):

        t1 = segment["start"]
        t2 = segment["end"]
        text = segment["text"]

        aligned_seg: SingleAlignedSegment = {
            "start": t1,
            "end": t2,
            "text": text,
            "words": [],
            "chars": None,
        }

        if return_char_alignments:
            aligned_seg["chars"] = []

        # check we can align
        if len(segment_data[sdx]["clean_char"]) == 0:
            print(f'Failed to align segment ("{segment["text"]}"): no characters in this segment found in model dictionary, resorting to original...')
            aligned_segments.append(aligned_seg)
            continue

        if t1 >= MAX_DURATION:
            print(f'Failed to align segment ("{segment["text"]}"): original start time longer than audio duration, skipping...')
            aligned_segments.append(aligned_seg)
            continue

        text_clean = "".join(segment_data[sdx]["clean_char"])
        tokens = [model_dictionary.get(c, -1) for c in text_clean]

        f1 = int(t1 * SAMPLE_RATE)
        f2 = int(t2 * SAMPLE_RATE)

        # TODO: Probably can get some speedup gain with batched inference here
        waveform_segment = audio[:, f1:f2]
        # Handle the minimum input length for wav2vec2 models
        if waveform_segment.shape[-1] < 400:
            lengths = torch.as_tensor([waveform_segment.shape[-1]]).to(device)
            waveform_segment = torch.nn.functional.pad(
                waveform_segment, (0, 400 - waveform_segment.shape[-1])
            )
        else:
            lengths = None

        with torch.inference_mode():
            if model_type == "torchaudio":
                emissions, _ = model(waveform_segment.to(device), lengths=lengths)
            elif model_type == "huggingface":
                emissions = model(waveform_segment.to(device)).logits
            else:
                raise NotImplementedError(f"Align model of type {model_type} not supported.")
            emissions = torch.log_softmax(emissions, dim=-1)

        emission = emissions[0].cpu().detach()

        blank_id = 0
        for char, code in model_dictionary.items():
            if char == '[pad]' or char == '<pad>':
                blank_id = code

        trellis = get_trellis(emission, tokens, blank_id)
        # path = backtrack(trellis, emission, tokens, blank_id)
        path = backtrack_beam(trellis, emission, tokens, blank_id, beam_width=2)

        if path is None:
            print(f'Failed to align segment ("{segment["text"]}"): backtrack failed, resorting to original...')
            aligned_segments.append(aligned_seg)
            continue

        char_segments = merge_repeats(path, text_clean)

        duration = t2 - t1
        ratio = duration * waveform_segment.size(0) / (trellis.size(0) - 1)

        # assign timestamps to aligned characters
        char_segments_arr = []
        word_idx = 0
        for cdx, char in enumerate(text):
            start, end, score = None, None, None
            if cdx in segment_data[sdx]["clean_cdx"]:
                char_seg = char_segments[segment_data[sdx]["clean_cdx"].index(cdx)]
                start = round(char_seg.start * ratio + t1, 3)
                end = round(char_seg.end * ratio + t1, 3)
                score = round(char_seg.score, 3)

            char_segments_arr.append(
                {
                    "char": char,
                    "start": start,
                    "end": end,
                    "score": score,
                    "word-idx": word_idx,
                }
            )

            # increment word_idx; nltk word tokenization would probably be more robust here, but use spaces for now...
            if model_lang in LANGUAGES_WITHOUT_SPACES:
                word_idx += 1
            elif cdx == len(text) - 1 or text[cdx + 1] == " ":
                word_idx += 1

        char_segments_arr = pd.DataFrame(char_segments_arr)

        aligned_subsegments = []
        # assign sentence_idx to each character index
        char_segments_arr["sentence-idx"] = None
        for sdx2, (sstart, send) in enumerate(segment_data[sdx]["sentence_spans"]):
            curr_chars = char_segments_arr.loc[(char_segments_arr.index >= sstart) & (char_segments_arr.index <= send)]
            char_segments_arr.loc[(char_segments_arr.index >= sstart) & (char_segments_arr.index <= send), "sentence-idx"] = sdx2

            sentence_text = text[sstart:send]
            sentence_start = curr_chars["start"].min()
            end_chars = curr_chars[curr_chars["char"] != ' ']
            sentence_end = end_chars["end"].max()
            sentence_words = []

            for word_idx in curr_chars["word-idx"].unique():
                word_chars = curr_chars.loc[curr_chars["word-idx"] == word_idx]
                word_text = "".join(word_chars["char"].tolist()).strip()
                if len(word_text) == 0:
                    continue

                # don't use the space character for alignment
                word_chars = word_chars[word_chars["char"] != " "]

                word_start = word_chars["start"].min()
                word_end = word_chars["end"].max()
                word_score = round(word_chars["score"].mean(), 3)

                # -1 indicates unalignable
                word_segment = {"word": word_text}

                if not np.isnan(word_start):
                    word_segment["start"] = word_start
                if not np.isnan(word_end):
                    word_segment["end"] = word_end
                if not np.isnan(word_score):
                    word_segment["score"] = word_score

                sentence_words.append(word_segment)

            aligned_subsegments.append({
                "text": sentence_text,
                "start": sentence_start,
                "end": sentence_end,
                "words": sentence_words,
            })

            if return_char_alignments:
                curr_chars = curr_chars[["char", "start", "end", "score"]]
                curr_chars.fillna(-1, inplace=True)
                curr_chars = curr_chars.to_dict("records")
                curr_chars = [{key: val for key, val in char.items() if val != -1} for char in curr_chars]
                aligned_subsegments[-1]["chars"] = curr_chars

        aligned_subsegments = pd.DataFrame(aligned_subsegments)
        aligned_subsegments["start"] = interpolate_nans(aligned_subsegments["start"], method=interpolate_method)
        aligned_subsegments["end"] = interpolate_nans(aligned_subsegments["end"], method=interpolate_method)
        # concatenate sentences with same timestamps
        agg_dict = {"text": " ".join, "words": "sum"}
        if model_lang in LANGUAGES_WITHOUT_SPACES:
            agg_dict["text"] = "".join
        if return_char_alignments:
            agg_dict["chars"] = "sum"
        aligned_subsegments = aligned_subsegments.groupby(["start", "end"], as_index=False).agg(agg_dict)
        aligned_subsegments = aligned_subsegments.to_dict('records')
        aligned_segments += aligned_subsegments

    # create word_segments list
    word_segments: List[SingleWordSegment] = []
    for segment in aligned_segments:
        word_segments += segment["words"]

    return {"segments": aligned_segments, "word_segments": word_segments}

"""
source: https://pytorch.org/tutorials/intermediate/forced_alignment_with_torchaudio_tutorial.html
"""


def get_trellis(emission, tokens, blank_id=0):
    num_frame = emission.size(0)
    num_tokens = len(tokens)

    trellis = torch.zeros((num_frame, num_tokens))
    trellis[1:, 0] = torch.cumsum(emission[1:, blank_id], 0)
    trellis[0, 1:] = -float("inf")
    trellis[-num_tokens + 1:, 0] = float("inf")

    for t in range(num_frame - 1):
        trellis[t + 1, 1:] = torch.maximum(
            # Score for staying at the same token
            trellis[t, 1:] + emission[t, blank_id],
            # Score for changing to the next token
            # trellis[t, :-1] + emission[t, tokens[1:]],
            trellis[t, :-1] + get_wildcard_emission(emission[t], tokens[1:], blank_id),
        )
    return trellis
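
# The trellis above encodes the forced-alignment dynamic program from the
# torchaudio tutorial: with log-probabilities in `emission`, for j >= 1
#
#     trellis[t + 1, j] = max(trellis[t, j]     + emission[t, blank_id],   # stay (emit blank)
#                             trellis[t, j - 1] + emission[t, tokens[j]])  # advance to token j
#
# Wildcard tokens (index -1, inserted for out-of-dictionary characters) take
# the best non-blank score for the frame via get_wildcard_emission below.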


def get_wildcard_emission(frame_emission, tokens, blank_id):
    """Compute frame emission scores for a sequence of tokens, treating -1 entries as wildcards (vectorized).

    Args:
        frame_emission: Emission probability vector for the current frame
        tokens: List of token indices
        blank_id: ID of the blank token

    Returns:
        tensor: Maximum probability score for each token position
    """
    assert 0 <= blank_id < len(frame_emission)

    # Convert tokens to a tensor if they are not already
    tokens = torch.tensor(tokens) if not isinstance(tokens, torch.Tensor) else tokens

    # Create a mask to identify wildcard positions
    wildcard_mask = (tokens == -1)

    # Get scores for non-wildcard positions
    regular_scores = frame_emission[tokens.clamp(min=0)]  # clamp to avoid -1 index

    # Compute the maximum non-blank value without modifying frame_emission
    max_valid_score = frame_emission.clone()  # Create a copy
    max_valid_score[blank_id] = float('-inf')  # Modify the copy to exclude the blank token
    max_valid_score = max_valid_score.max()

    # Use where operation to combine results
    result = torch.where(wildcard_mask, max_valid_score, regular_scores)

    return result
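
# Illustrative sketch (hypothetical values): with a 4-symbol vocabulary
# [blank, "a", "b", "c"],
#
#     frame = torch.log(torch.tensor([0.5, 0.2, 0.2, 0.1]))
#     get_wildcard_emission(frame, [1, -1, 3], blank_id=0)
#     # -> log of [0.2, 0.2, 0.1]; the wildcard (-1) takes the best non-blank score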


@dataclass
class Point:
    token_index: int
    time_index: int
    score: float


def backtrack(trellis, emission, tokens, blank_id=0):
    t, j = trellis.size(0) - 1, trellis.size(1) - 1

    path = [Point(j, t, emission[t, blank_id].exp().item())]
    while j > 0:
        # Should not happen but just in case
        assert t > 0

        # 1. Figure out if the current position was stay or change
        # Frame-wise score of stay vs change
        p_stay = emission[t - 1, blank_id]
        # p_change = emission[t - 1, tokens[j]]
        p_change = get_wildcard_emission(emission[t - 1], [tokens[j]], blank_id)[0]

        # Context-aware score for stay vs change
        stayed = trellis[t - 1, j] + p_stay
        changed = trellis[t - 1, j - 1] + p_change

        # Update position
        t -= 1
        if changed > stayed:
            j -= 1

        # Store the path with frame-wise probability.
        prob = (p_change if changed > stayed else p_stay).exp().item()
        path.append(Point(j, t, prob))

    # Now j == 0, which means it reached the SoS.
    # Fill up the rest for the sake of visualization
    while t > 0:
        prob = emission[t - 1, blank_id].exp().item()
        path.append(Point(j, t - 1, prob))
        t -= 1

    return path[::-1]


@dataclass
class Path:
    points: List[Point]
    score: float


@dataclass
class BeamState:
    """State in beam search."""
    token_index: int   # Current token position
    time_index: int    # Current time step
    score: float       # Cumulative score
    path: List[Point]  # Path history


def backtrack_beam(trellis, emission, tokens, blank_id=0, beam_width=5):
    """Standard CTC beam search backtracking implementation.

    Args:
        trellis (torch.Tensor): The trellis (or lattice) of shape (T, N), where T is the number of time steps
            and N is the number of tokens (including the blank token).
        emission (torch.Tensor): The emission probabilities of shape (T, N).
        tokens (List[int]): List of token indices (excluding the blank token).
        blank_id (int, optional): The ID of the blank token. Defaults to 0.
        beam_width (int, optional): The number of top paths to keep during beam search. Defaults to 5.

    Returns:
        List[Point]: the best path
    """
    T, J = trellis.size(0) - 1, trellis.size(1) - 1

    init_state = BeamState(
        token_index=J,
        time_index=T,
        score=trellis[T, J],
        path=[Point(J, T, emission[T, blank_id].exp().item())]
    )

    beams = [init_state]

    while beams and beams[0].token_index > 0:
        next_beams = []

        for beam in beams:
            t, j = beam.time_index, beam.token_index

            if t <= 0:
                continue

            p_stay = emission[t - 1, blank_id]
            p_change = get_wildcard_emission(emission[t - 1], [tokens[j]], blank_id)[0]

            stay_score = trellis[t - 1, j]
            change_score = trellis[t - 1, j - 1] if j > 0 else float('-inf')

            # Stay
            if not math.isinf(stay_score):
                new_path = beam.path.copy()
                new_path.append(Point(j, t - 1, p_stay.exp().item()))
                next_beams.append(BeamState(
                    token_index=j,
                    time_index=t - 1,
                    score=stay_score,
                    path=new_path
                ))

            # Change
            if j > 0 and not math.isinf(change_score):
                new_path = beam.path.copy()
                new_path.append(Point(j - 1, t - 1, p_change.exp().item()))
                next_beams.append(BeamState(
                    token_index=j - 1,
                    time_index=t - 1,
                    score=change_score,
                    path=new_path
                ))

        # sort by score
        beams = sorted(next_beams, key=lambda x: x.score, reverse=True)[:beam_width]

        if not beams:
            break

    if not beams:
        return None

    best_beam = beams[0]
    t = best_beam.time_index
    j = best_beam.token_index
    while t > 0:
        prob = emission[t - 1, blank_id].exp().item()
        best_beam.path.append(Point(j, t - 1, prob))
        t -= 1

    return best_beam.path[::-1]


# Merge the labels
@dataclass
class Segment:
    label: str
    start: int
    end: int
    score: float

    def __repr__(self):
        return f"{self.label}\t({self.score:4.2f}): [{self.start:5d}, {self.end:5d})"

    @property
    def length(self):
        return self.end - self.start


def merge_repeats(path, transcript):
    i1, i2 = 0, 0
    segments = []
    while i1 < len(path):
        while i2 < len(path) and path[i1].token_index == path[i2].token_index:
            i2 += 1
        score = sum(path[k].score for k in range(i1, i2)) / (i2 - i1)
        segments.append(
            Segment(
                transcript[path[i1].token_index],
                path[i1].time_index,
                path[i2 - 1].time_index + 1,
                score,
            )
        )
        i1 = i2
    return segments


def merge_words(segments, separator="|"):
    words = []
    i1, i2 = 0, 0
    while i1 < len(segments):
        if i2 >= len(segments) or segments[i2].label == separator:
            if i1 != i2:
                segs = segments[i1:i2]
                word = "".join([seg.label for seg in segs])
                score = sum(seg.score * seg.length for seg in segs) / sum(seg.length for seg in segs)
                words.append(Segment(word, segments[i1].start, segments[i2 - 1].end, score))
            i1 = i2 + 1
            i2 = i1
        else:
            i2 += 1
    return words
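
# Illustrative sketch (hypothetical frame indices): merge_repeats collapses the
# per-frame path into character Segments, and merge_words joins those on the
# "|" space symbol into word Segments with length-weighted scores:
#
#     chars = [Segment("c", 0, 2, 0.9), Segment("a", 2, 4, 0.8), Segment("t", 4, 6, 0.9),
#              Segment("|", 6, 7, 1.0), Segment("g", 7, 9, 0.7), Segment("o", 9, 11, 0.8)]
#     merge_words(chars)
#     # -> [Segment("cat", 0, 6, ...), Segment("go", 7, 11, ...)]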