2025-01-13 08:28:27 +01:00
"""
2023-01-25 18:42:52 +00:00
Forced Alignment with Whisper
C . Max Bain
2022-12-14 18:59:12 +00:00
"""
2025-01-09 15:23:40 +08:00
import math
2023-04-24 21:08:43 +01:00
from dataclasses import dataclass
2025-01-05 11:26:18 +01:00
from typing import Iterable , Optional , Union , List
2023-04-24 21:08:43 +01:00
2023-01-25 18:42:52 +00:00
import numpy as np
import pandas as pd
2022-12-14 18:59:12 +00:00
import torch
2023-04-24 21:08:43 +01:00
import torchaudio
from transformers import Wav2Vec2ForCTC , Wav2Vec2Processor
2023-01-25 18:42:52 +00:00
2025-03-25 16:13:55 +01:00
from whisperx . audio import SAMPLE_RATE , load_audio
from whisperx . utils import interpolate_nans
from whisperx . types import (
2025-01-13 09:27:33 +01:00
AlignedTranscriptionResult ,
SingleSegment ,
SingleAlignedSegment ,
SingleWordSegment ,
SegmentData ,
)
2023-05-29 12:48:14 +01:00
from nltk . tokenize . punkt import PunktSentenceTokenizer , PunktParameters
# Abbreviations the Punkt sentence splitter must not treat as sentence ends.
PUNKT_ABBREVIATIONS = "dr vs mr mrs prof".split()

# Languages whose text is not space-delimited: `align` counts every character
# as its own word for these and does not map spaces to "|".
LANGUAGES_WITHOUT_SPACES = ["ja", "zh"]

# Default alignment models looked up in `torchaudio.pipelines` (bundle names).
DEFAULT_ALIGN_MODELS_TORCH = {
    "en": "WAV2VEC2_ASR_BASE_960H",
    "fr": "VOXPOPULI_ASR_BASE_10K_FR",
    "de": "VOXPOPULI_ASR_BASE_10K_DE",
    "es": "VOXPOPULI_ASR_BASE_10K_ES",
    "it": "VOXPOPULI_ASR_BASE_10K_IT",
}

# Default alignment models loaded from Hugging Face (repository ids).
DEFAULT_ALIGN_MODELS_HF = {
    "ja": "jonatasgrosman/wav2vec2-large-xlsr-53-japanese",
    "zh": "jonatasgrosman/wav2vec2-large-xlsr-53-chinese-zh-cn",
    "nl": "jonatasgrosman/wav2vec2-large-xlsr-53-dutch",
    "uk": "Yehor/wav2vec2-xls-r-300m-uk-with-small-lm",
    "pt": "jonatasgrosman/wav2vec2-large-xlsr-53-portuguese",
    "ar": "jonatasgrosman/wav2vec2-large-xlsr-53-arabic",
    "cs": "comodoro/wav2vec2-xls-r-300m-cs-250",
    "ru": "jonatasgrosman/wav2vec2-large-xlsr-53-russian",
    "pl": "jonatasgrosman/wav2vec2-large-xlsr-53-polish",
    "hu": "jonatasgrosman/wav2vec2-large-xlsr-53-hungarian",
    "fi": "jonatasgrosman/wav2vec2-large-xlsr-53-finnish",
    "fa": "jonatasgrosman/wav2vec2-large-xlsr-53-persian",
    "el": "jonatasgrosman/wav2vec2-large-xlsr-53-greek",
    "tr": "mpoyraz/wav2vec2-xls-r-300m-cv7-turkish",
    "da": "saattrupdan/wav2vec2-xls-r-300m-ftspeech",
    "he": "imvladikon/wav2vec2-xls-r-300m-hebrew",
    "vi": "nguyenvulebinh/wav2vec2-base-vi",
    "ko": "kresnik/wav2vec2-large-xlsr-korean",
    "ur": "kingabzpro/wav2vec2-large-xls-r-300m-Urdu",
    "te": "anuragshas/wav2vec2-large-xlsr-53-telugu",
    "hi": "theainerd/Wav2Vec2-large-xlsr-hindi",
    "ca": "softcatala/wav2vec2-large-xlsr-catala",
    "ml": "gvs/wav2vec2-large-xlsr-malayalam",
    "no": "NbAiLab/nb-wav2vec2-1b-bokmaal-v2",
    "nn": "NbAiLab/nb-wav2vec2-1b-nynorsk",
    "sk": "comodoro/wav2vec2-xls-r-300m-sk-cv8",
    "sl": "anton-l/wav2vec2-large-xlsr-53-slovenian",
    "hr": "classla/wav2vec2-xls-r-parlaspeech-hr",
    "ro": "gigant/romanian-wav2vec2",
    "eu": "stefan-it/wav2vec2-large-xlsr-53-basque",
    "gl": "ifrz/wav2vec2-large-xlsr-galician",
    "ka": "xsway/wav2vec2-large-xlsr-georgian",
    "lv": "jimregan/wav2vec2-large-xlsr-latvian-cv",
    "tl": "Khalsuu/filipino-wav2vec2-l-xls-r-300m-official",
}
2025-01-05 11:26:18 +01:00
def load_align_model(language_code: str, device: str, model_name: Optional[str] = None, model_dir=None):
    """Load a CTC alignment model plus the metadata `align` needs to use it.

    Args:
        language_code: Language of the transcript. Selects the default model
            when `model_name` is not given.
        device: Device the model is moved to via `.to(device)`.
        model_name: A `torchaudio.pipelines` bundle name or a Hugging Face
            model id. Defaults to the entry for `language_code` in
            `DEFAULT_ALIGN_MODELS_TORCH` or `DEFAULT_ALIGN_MODELS_HF`.
        model_dir: Download/cache directory. It is forwarded to torchaudio as
            `dl_kwargs={"model_dir": ...}` and to Hugging Face as `cache_dir`.

    Returns:
        `(align_model, align_metadata)`. `align_metadata` is a dict with keys:
        - `"language"`: the language code.
        - `"dictionary"`: lower-cased label -> token id.
        - `"type"`: `"torchaudio"` or `"huggingface"`.

    Raises:
        ValueError: if there is no default model for `language_code`, or if
            the model cannot be loaded from Hugging Face. The original loading
            error is chained as `__cause__`.
    """
    if model_name is None:
        # use default model
        if language_code in DEFAULT_ALIGN_MODELS_TORCH:
            model_name = DEFAULT_ALIGN_MODELS_TORCH[language_code]
        elif language_code in DEFAULT_ALIGN_MODELS_HF:
            model_name = DEFAULT_ALIGN_MODELS_HF[language_code]
        else:
            # Implicit concatenation keeps source indentation out of the printed message.
            print(
                f"There is no default alignment model set for this language ({language_code}). "
                "Please find a wav2vec2.0 model finetuned on this language in https://huggingface.co/models, "
                "then pass the model name in --align_model [MODEL_NAME]"
            )
            raise ValueError(f"No default align-model for language: {language_code}")

    if model_name in torchaudio.pipelines.__all__:
        pipeline_type = "torchaudio"
        bundle = torchaudio.pipelines.__dict__[model_name]
        align_model = bundle.get_model(dl_kwargs={"model_dir": model_dir}).to(device)
        labels = bundle.get_labels()
        align_dictionary = {c.lower(): i for i, c in enumerate(labels)}
    else:
        try:
            processor = Wav2Vec2Processor.from_pretrained(model_name, cache_dir=model_dir)
            align_model = Wav2Vec2ForCTC.from_pretrained(model_name, cache_dir=model_dir)
        except Exception as e:
            print(e)
            print("Error loading model from huggingface, check https://huggingface.co/models for finetuned wav2vec2.0 models")
            raise ValueError(f'The chosen align_model "{model_name}" could not be found in huggingface (https://huggingface.co/models) or torchaudio (https://pytorch.org/audio/stable/pipelines.html#id14)') from e
        pipeline_type = "huggingface"
        align_model = align_model.to(device)
        align_dictionary = {char.lower(): code for char, code in processor.tokenizer.get_vocab().items()}

    align_metadata = {"language": language_code, "dictionary": align_dictionary, "type": pipeline_type}

    return align_model, align_metadata
def align(
    transcript: Iterable[SingleSegment],
    model: torch.nn.Module,
    align_model_metadata: dict,
    audio: Union[str, np.ndarray, torch.Tensor],
    device: str,
    interpolate_method: str = "nearest",
    return_char_alignments: bool = False,
    print_progress: bool = False,
    combined_progress: bool = False,
) -> AlignedTranscriptionResult:
    """
    Align phoneme recognition predictions to known transcription.

    Each segment's text is force-aligned against the CTC emissions of `model`
    computed over the segment's `[start, end]` audio window. This yields
    per-word (and optionally per-character) timestamps. Segments are re-split
    into sentences with NLTK's Punkt tokenizer.

    Args:
        transcript: Segments with `"start"`, `"end"` (seconds) and `"text"`.
            Any iterable is accepted. It is materialised once because it is
            traversed twice.
        model: Alignment model returned by `load_align_model`.
        align_model_metadata: Metadata dict returned by `load_align_model`
            (`"language"`, `"dictionary"`, `"type"`).
        audio: Path to an audio file, or a waveform array/tensor. It is
            expected to be sampled at `SAMPLE_RATE`. A 1-D waveform is treated
            as a single channel.
        device: Device the waveform is moved to before inference.
        interpolate_method: Method passed to `interpolate_nans` to fill
            sentence start/end times that could not be aligned.
        return_char_alignments: Also return per-character timings under
            `"chars"`.
        print_progress: Print progress once per segment while preprocessing.
        combined_progress: Report that progress in the 50-100% range instead
            of 0-100%.

    Returns:
        `{"segments": [...], "word_segments": [...]}`. Segments that cannot be
        aligned are returned with their original timing and empty `"words"`.

    Raises:
        NotImplementedError: if `align_model_metadata["type"]` is unknown.
    """

    if not torch.is_tensor(audio):
        if isinstance(audio, str):
            audio = load_audio(audio)
        audio = torch.from_numpy(audio)
    if len(audio.shape) == 1:
        audio = audio.unsqueeze(0)

    MAX_DURATION = audio.shape[1] / SAMPLE_RATE

    model_dictionary = align_model_metadata["dictionary"]
    model_lang = align_model_metadata["language"]
    model_type = align_model_metadata["type"]

    # The transcript is len()-ed and traversed twice below. A one-shot iterable
    # such as a generator would otherwise be exhausted by the first pass.
    transcript = list(transcript)

    # Neither the sentence splitter nor the CTC blank id depends on the segment,
    # so build them once.
    punkt_param = PunktParameters()
    punkt_param.abbrev_types = set(PUNKT_ABBREVIATIONS)
    sentence_splitter = PunktSentenceTokenizer(punkt_param)

    blank_id = 0
    for char, code in model_dictionary.items():
        if char == '[pad]' or char == '<pad>':
            blank_id = code

    # 1. Preprocess to keep only characters in dictionary
    total_segments = len(transcript)
    # Store temporary processing values
    segment_data: dict[int, SegmentData] = {}
    for sdx, segment in enumerate(transcript):
        if print_progress:
            base_progress = ((sdx + 1) / total_segments) * 100
            percent_complete = (50 + base_progress / 2) if combined_progress else base_progress
            print(f"Progress: {percent_complete:.2f}%...")

        text = segment["text"]
        # strip spaces at beginning / end, but keep track of the amount.
        num_leading = len(text) - len(text.lstrip())
        num_trailing = len(text) - len(text.rstrip())

        # split into words
        if model_lang not in LANGUAGES_WITHOUT_SPACES:
            per_word = text.split(" ")
        else:
            per_word = text

        clean_char, clean_cdx = [], []
        for cdx, char in enumerate(text):
            # ignore whitespace at beginning and end of transcript
            if cdx < num_leading or cdx > len(text) - num_trailing - 1:
                continue
            char_ = char.lower()
            # wav2vec2 models use "|" character to represent spaces
            if model_lang not in LANGUAGES_WITHOUT_SPACES:
                char_ = char_.replace(" ", "|")
            # characters unknown to the model become a "*" placeholder
            clean_char.append(char_ if char_ in model_dictionary else "*")
            clean_cdx.append(cdx)

        # Every word keeps its index. Words with no known characters are covered
        # by the placeholder rather than dropped.
        clean_wdx = list(range(len(per_word)))

        segment_data[sdx] = {
            "clean_char": clean_char,
            "clean_cdx": clean_cdx,
            "clean_wdx": clean_wdx,
            "sentence_spans": list(sentence_splitter.span_tokenize(text)),
        }

    aligned_segments: List[SingleAlignedSegment] = []

    # 2. Get prediction matrix from alignment model & align
    for sdx, segment in enumerate(transcript):
        t1 = segment["start"]
        t2 = segment["end"]
        text = segment["text"]
        seg_info = segment_data[sdx]

        aligned_seg: SingleAlignedSegment = {
            "start": t1,
            "end": t2,
            "text": text,
            "words": [],
            "chars": None,
        }

        if return_char_alignments:
            aligned_seg["chars"] = []

        # check we can align
        if len(seg_info["clean_char"]) == 0:
            print(f'Failed to align segment ("{segment["text"]}"): no characters in this segment found in model dictionary, resorting to original...')
            aligned_segments.append(aligned_seg)
            continue

        if t1 >= MAX_DURATION:
            print(f'Failed to align segment ("{segment["text"]}"): original start time longer than audio duration, skipping...')
            aligned_segments.append(aligned_seg)
            continue

        text_clean = "".join(seg_info["clean_char"])
        # Characters missing from the dictionary (e.g. the "*" placeholder) map
        # to -1, which get_wildcard_emission treats as a wildcard token.
        tokens = [model_dictionary.get(c, -1) for c in text_clean]

        f1 = int(t1 * SAMPLE_RATE)
        f2 = int(t2 * SAMPLE_RATE)

        # TODO: Probably can get some speedup gain with batched inference here
        waveform_segment = audio[:, f1:f2]
        # Handle the minimum input length for wav2vec2 models
        if waveform_segment.shape[-1] < 400:
            lengths = torch.as_tensor([waveform_segment.shape[-1]]).to(device)
            waveform_segment = torch.nn.functional.pad(
                waveform_segment, (0, 400 - waveform_segment.shape[-1])
            )
        else:
            lengths = None

        with torch.inference_mode():
            if model_type == "torchaudio":
                emissions, _ = model(waveform_segment.to(device), lengths=lengths)
            elif model_type == "huggingface":
                emissions = model(waveform_segment.to(device)).logits
            else:
                raise NotImplementedError(f"Align model of type {model_type} not supported.")
            emissions = torch.log_softmax(emissions, dim=-1)

        emission = emissions[0].cpu().detach()

        trellis = get_trellis(emission, tokens, blank_id)
        path = backtrack_beam(trellis, emission, tokens, blank_id, beam_width=2)

        if path is None:
            print(f'Failed to align segment ("{segment["text"]}"): backtrack failed, resorting to original...')
            aligned_segments.append(aligned_seg)
            continue

        char_segments = merge_repeats(path, text_clean)

        # Seconds per emission frame.
        # - The audio slice above stops at the end of the audio, so only the
        #   covered part of [t1, t2] is spread over the frames.
        # - A single-frame emission (very short segment) would otherwise divide by zero.
        duration = min(t2, MAX_DURATION) - t1
        ratio = duration / max(trellis.size(0) - 1, 1)

        # original character index -> position in the cleaned sequence
        clean_pos = {cdx: pos for pos, cdx in enumerate(seg_info["clean_cdx"])}

        # assign timestamps to aligned characters
        char_segments_arr = []
        word_idx = 0
        for cdx, char in enumerate(text):
            start, end, score = None, None, None
            pos = clean_pos.get(cdx)
            if pos is not None:
                char_seg = char_segments[pos]
                start = round(char_seg.start * ratio + t1, 3)
                end = round(char_seg.end * ratio + t1, 3)
                score = round(char_seg.score, 3)

            char_segments_arr.append(
                {
                    "char": char,
                    "start": start,
                    "end": end,
                    "score": score,
                    "word-idx": word_idx,
                }
            )

            # increment word_idx, nltk word tokenization would probably be more robust here, but use space for now...
            if model_lang in LANGUAGES_WITHOUT_SPACES:
                word_idx += 1
            elif cdx == len(text) - 1 or text[cdx + 1] == " ":
                word_idx += 1

        char_segments_arr = pd.DataFrame(char_segments_arr)

        aligned_subsegments = []
        # assign sentence_idx to each character index
        char_segments_arr["sentence-idx"] = None
        for sdx2, (sstart, send) in enumerate(seg_info["sentence_spans"]):
            # The text slice below is half-open, but this mask uses `<= send`.
            # It can therefore include the character right after the sentence,
            # which is why spaces are excluded from the end-time computation.
            in_sentence = (char_segments_arr.index >= sstart) & (char_segments_arr.index <= send)
            curr_chars = char_segments_arr.loc[in_sentence]
            char_segments_arr.loc[in_sentence, "sentence-idx"] = sdx2

            sentence_text = text[sstart:send]
            sentence_start = curr_chars["start"].min()
            end_chars = curr_chars[curr_chars["char"] != ' ']
            sentence_end = end_chars["end"].max()
            sentence_words = []

            for widx in curr_chars["word-idx"].unique():
                word_chars = curr_chars.loc[curr_chars["word-idx"] == widx]
                word_text = "".join(word_chars["char"].tolist()).strip()
                if len(word_text) == 0:
                    continue

                # dont use space character for alignment
                word_chars = word_chars[word_chars["char"] != " "]

                word_start = word_chars["start"].min()
                word_end = word_chars["end"].max()
                word_score = round(word_chars["score"].mean(), 3)

                # keys whose value could not be aligned (NaN) are omitted
                word_segment = {"word": word_text}

                if not np.isnan(word_start):
                    word_segment["start"] = word_start
                if not np.isnan(word_end):
                    word_segment["end"] = word_end
                if not np.isnan(word_score):
                    word_segment["score"] = word_score

                sentence_words.append(word_segment)

            aligned_subsegments.append({
                "text": sentence_text,
                "start": sentence_start,
                "end": sentence_end,
                "words": sentence_words,
            })

            if return_char_alignments:
                # Missing values are filled with -1 and then dropped per character.
                # Non-inplace fillna avoids writing into a frame derived from a slice.
                char_records = curr_chars[["char", "start", "end", "score"]].fillna(-1).to_dict("records")
                aligned_subsegments[-1]["chars"] = [
                    {key: val for key, val in rec.items() if val != -1} for rec in char_records
                ]

        aligned_subsegments = pd.DataFrame(aligned_subsegments)
        aligned_subsegments["start"] = interpolate_nans(aligned_subsegments["start"], method=interpolate_method)
        aligned_subsegments["end"] = interpolate_nans(aligned_subsegments["end"], method=interpolate_method)
        # concatenate sentences with same timestamps
        agg_dict = {"text": " ".join, "words": "sum"}
        if model_lang in LANGUAGES_WITHOUT_SPACES:
            agg_dict["text"] = "".join
        if return_char_alignments:
            agg_dict["chars"] = "sum"
        aligned_subsegments = aligned_subsegments.groupby(["start", "end"], as_index=False).agg(agg_dict)
        aligned_subsegments = aligned_subsegments.to_dict('records')
        aligned_segments += aligned_subsegments

    # create word_segments list
    word_segments: List[SingleWordSegment] = []
    for segment in aligned_segments:
        word_segments += segment["words"]

    return {"segments": aligned_segments, "word_segments": word_segments}
2023-01-25 18:42:52 +00:00
"""
source : https : / / pytorch . org / tutorials / intermediate / forced_alignment_with_torchaudio_tutorial . html
"""
2025-01-09 15:23:40 +08:00
2022-12-14 18:59:12 +00:00
def get_trellis(emission, tokens, blank_id=0):
    """Build the CTC forced-alignment trellis.

    Adapted from the torchaudio forced-alignment tutorial. Row `t + 1` is the
    element-wise maximum of two options:
    - "stay": `trellis[t, j] + emission[t, blank]`.
    - "change": `trellis[t, j - 1]` plus the emission of token `j`.
    Tokens equal to `-1` are wildcards, scored by `get_wildcard_emission`.

    Args:
        emission: Frame scores indexed as `emission[frame, class]`. `align`
            passes log-softmax output.
        tokens: Token ids of the transcript. `-1` marks a wildcard.
        blank_id: Index of the CTC blank class.

    Returns:
        Float tensor of shape `(num_frames, len(tokens))`.
    """
    num_frame = emission.size(0)
    num_tokens = len(tokens)

    trellis = torch.zeros((num_frame, num_tokens))
    trellis[1:, 0] = torch.cumsum(emission[1:, blank_id], 0)
    trellis[0, 1:] = -float("inf")
    # The last `num_tokens - 1` rows of column 0 are set to +inf, as in the
    # tutorial. With a single token there are no such rows: `-num_tokens + 1`
    # would be 0, and the slice would overwrite the whole first column.
    if num_tokens > 1:
        trellis[-num_tokens + 1:, 0] = float("inf")

    # Convert the "next token" ids to a tensor once instead of once per frame.
    next_tokens = torch.as_tensor(tokens[1:], dtype=torch.long)
    for t in range(num_frame - 1):
        trellis[t + 1, 1:] = torch.maximum(
            # Score for staying at the same token
            trellis[t, 1:] + emission[t, blank_id],
            # Score for changing to the next token (wildcards take the best non-blank score)
            trellis[t, :-1] + get_wildcard_emission(emission[t], next_tokens, blank_id),
        )
    return trellis
2025-01-09 15:23:40 +08:00
def get_wildcard_emission(frame_emission, tokens, blank_id):
    """Look up one frame's emission score for each token, resolving wildcards.

    A token id of `-1` is a wildcard. It takes the highest score among all
    classes of the frame except the blank. Every other id takes
    `frame_emission[id]`.

    Args:
        frame_emission: Scores for a single frame, one per class. It is
            expected to be 1-D and is not modified.
        tokens: Token ids, as a list or a tensor. `-1` marks a wildcard.
        blank_id: Index of the blank class, excluded from the wildcard maximum.

    Returns:
        tensor: One score per entry of `tokens`.
    """
    assert 0 <= blank_id < len(frame_emission)

    if isinstance(tokens, torch.Tensor):
        token_ids = tokens
    else:
        token_ids = torch.tensor(tokens)

    # Best non-blank score. It is computed on a copy so the caller's tensor is untouched.
    without_blank = frame_emission.clone()
    without_blank[blank_id] = float('-inf')
    best_non_blank = without_blank.max()

    # -1 would wrap around to the last class, so wildcards are gathered at index 0.
    # Those positions are then replaced with best_non_blank.
    gathered = frame_emission[token_ids.clamp(min=0).long()]
    return torch.where(token_ids == -1, best_non_blank, gathered)
2025-01-09 15:23:40 +08:00
2022-12-14 18:59:12 +00:00
@dataclass
class Point:
    """One cell on an alignment path through the trellis."""

    token_index: int  # transcript token (trellis column) at this step
    time_index: int  # emission frame (trellis row) at this step
    score: float  # frame-wise probability recorded for this step
2025-01-09 15:23:40 +08:00
2022-12-14 18:59:12 +00:00
def backtrack(trellis, emission, tokens, blank_id=0):
    """Greedily backtrack the best path through `trellis`.

    The walk starts at the last frame and the last token and moves back one
    frame at a time. At each step it chooses "change" (move to the previous
    token) when that scores higher than "stay". Wildcard tokens (`-1`) are
    scored with `get_wildcard_emission`.

    Args:
        trellis: Trellis from `get_trellis`, indexed `trellis[frame, token]`.
        emission: Frame log-scores indexed `emission[frame, class]`.
            `.exp()` is applied when recording point scores.
        tokens: Token ids of the transcript. `-1` marks a wildcard.
        blank_id: Index of the CTC blank class.

    Returns:
        A list of `Point` ordered from the first frame to the last. Returns
        `None` if the frames run out before the first token is reached; this
        is the same failure signal `backtrack_beam` uses.
    """
    t, j = trellis.size(0) - 1, trellis.size(1) - 1

    path = [Point(j, t, emission[t, blank_id].exp().item())]
    while j > 0:
        # Out of frames before reaching the first token: alignment failed.
        # This is an explicit check rather than an `assert`. Asserts are
        # stripped under `python -O`, and `t` would then go negative and
        # silently index from the end of the tensors.
        if t <= 0:
            return None

        # 1. Figure out if the current position was stay or change
        # Frame-wise score of stay vs change
        p_stay = emission[t - 1, blank_id]
        p_change = get_wildcard_emission(emission[t - 1], [tokens[j]], blank_id)[0]

        # Context-aware score for stay vs change
        stayed = trellis[t - 1, j] + p_stay
        changed = trellis[t - 1, j - 1] + p_change

        # Update position
        t -= 1
        if changed > stayed:
            j -= 1

        # Store the path with frame-wise probability.
        prob = (p_change if changed > stayed else p_stay).exp().item()
        path.append(Point(j, t, prob))

    # Now j == 0, which means, it reached the SoS.
    # Fill up the rest for the sake of visualization
    while t > 0:
        prob = emission[t - 1, blank_id].exp().item()
        path.append(Point(j, t - 1, prob))
        t -= 1

    return path[::-1]
2025-01-09 15:23:40 +08:00
@dataclass
class Path:
    """A candidate alignment: its points together with an overall score."""

    points: List[Point]  # cells visited by the path
    score: float  # overall score of the path
@dataclass
class BeamState:
    """One hypothesis kept alive by `backtrack_beam`."""

    token_index: int  # trellis column (transcript token) the hypothesis is at
    time_index: int  # trellis row (frame) the hypothesis is at
    score: float  # trellis value at (time_index, token_index); used to rank beams
    path: List[Point]  # points collected so far, newest frame last (reverse time order)
2025-01-09 15:23:40 +08:00
def backtrack_beam(trellis, emission, tokens, blank_id=0, beam_width=5):
    """Beam-search backtracking of an alignment path through `trellis`.

    The search starts from the last frame and the last token. Each hypothesis
    is extended one frame backwards, either staying on its token or changing
    to the previous token. After each step only the `beam_width` hypotheses
    with the highest trellis value are kept.

    Args:
        trellis (torch.Tensor): Trellis from `get_trellis`, of shape
            `(num_frames, len(tokens))`. It has one column per transcript
            token and no separate blank column.
        emission (torch.Tensor): Frame log-scores indexed
            `emission[frame, class]`. `.exp()` is applied when recording
            point scores.
        tokens (List[int]): Token ids of the transcript. `-1` marks a wildcard.
        blank_id (int, optional): The ID of the blank token. Defaults to 0.
        beam_width (int, optional): The number of top paths to keep during beam search. Defaults to 5.

    Returns:
        List[Point]: the best path, ordered from the first frame to the last.
        Returns `None` if every hypothesis runs out of frames before reaching
        the first token.
    """
    T, J = trellis.size(0) - 1, trellis.size(1) - 1

    init_state = BeamState(
        token_index=J,
        time_index=T,
        score=trellis[T, J],
        path=[Point(J, T, emission[T, blank_id].exp().item())]
    )

    beams = [init_state]

    # Keep extending until the best-ranked hypothesis has reached token 0.
    while beams and beams[0].token_index > 0:
        next_beams = []

        for beam in beams:
            t, j = beam.time_index, beam.token_index

            # Out of frames: this hypothesis cannot be extended and is dropped.
            if t <= 0:
                continue

            p_stay = emission[t - 1, blank_id]
            p_change = get_wildcard_emission(emission[t - 1], [tokens[j]], blank_id)[0]

            # NOTE(review): candidates are ranked by the destination cell's
            # trellis value alone. Unlike `backtrack`, the transition score
            # (p_stay / p_change) is not added. Separate beams can also reach
            # the same (t, j) cell and take up several beam slots. Confirm
            # both are intended.
            stay_score = trellis[t - 1, j]
            change_score = trellis[t - 1, j - 1] if j > 0 else float('-inf')

            # Stay (infinite trellis cells are never entered)
            if not math.isinf(stay_score):
                new_path = beam.path.copy()
                new_path.append(Point(j, t - 1, p_stay.exp().item()))
                next_beams.append(BeamState(
                    token_index=j,
                    time_index=t - 1,
                    score=stay_score,
                    path=new_path
                ))

            # Change
            if j > 0 and not math.isinf(change_score):
                new_path = beam.path.copy()
                new_path.append(Point(j - 1, t - 1, p_change.exp().item()))
                next_beams.append(BeamState(
                    token_index=j - 1,
                    time_index=t - 1,
                    score=change_score,
                    path=new_path
                ))

        # sort by score
        # (sorted() is stable: on equal scores, candidates appended earlier rank first)
        beams = sorted(next_beams, key=lambda x: x.score, reverse=True)[:beam_width]

        if not beams:
            break

    if not beams:
        return None

    best_beam = beams[0]
    t = best_beam.time_index
    j = best_beam.token_index
    # The best beam is at token 0 here. Pad the remaining earlier frames with
    # blank-probability points, as the tail of `backtrack` does.
    while t > 0:
        prob = emission[t - 1, blank_id].exp().item()
        best_beam.path.append(Point(j, t - 1, prob))
        t -= 1

    # Points were collected backwards in time; reverse to chronological order.
    return best_beam.path[::-1]
2022-12-14 18:59:12 +00:00
# Merge the labels
@dataclass
class Segment:
    """A run of consecutive frames assigned to one label (a character or a word)."""

    label: str  # text covered by the run
    start: int  # first frame of the run (inclusive)
    end: int  # one past the last frame of the run (exclusive)
    score: float  # averaged score of the run

    def __repr__(self):
        return "{}\t({:4.2f}): [{:5d}, {:5d})".format(self.label, self.score, self.start, self.end)

    @property
    def length(self):
        """Number of frames covered by the run."""
        return self.end - self.start
def merge_repeats(path, transcript):
    """Collapse consecutive path points that share a token into `Segment` objects.

    Args:
        path: Sequence of `Point` in time order.
        transcript: Sequence indexed by `Point.token_index` to get each label.

    Returns:
        A list of `Segment`. Each one spans
        `[first time_index, last time_index + 1)` and is scored with the mean
        of its points' scores.
    """
    merged = []
    total = len(path)
    run_start = 0
    while run_start < total:
        token = path[run_start].token_index
        run_end = run_start + 1
        while run_end < total and path[run_end].token_index == token:
            run_end += 1
        run = path[run_start:run_end]
        mean_score = sum(point.score for point in run) / len(run)
        merged.append(Segment(transcript[token], run[0].time_index, run[-1].time_index + 1, mean_score))
        run_start = run_end
    return merged
def merge_words(segments, separator="|"):
    """Group character segments into word segments, splitting on `separator`.

    Runs of consecutive separators produce no empty words. Each word's score
    is the length-weighted mean of its characters' scores.

    Args:
        segments: Sequence of character `Segment` objects in time order.
        separator: Label that marks a word boundary.

    Returns:
        A list of word `Segment` objects spanning from the first character's
        start to the last character's end.
    """
    words = []
    pending = []

    def close_word():
        # Emit the pending characters (if any) as one word segment.
        if not pending:
            return
        weighted = sum(seg.score * seg.length for seg in pending)
        frames = sum(seg.length for seg in pending)
        label = "".join(seg.label for seg in pending)
        words.append(Segment(label, pending[0].start, pending[-1].end, weighted / frames))
        pending.clear()

    for seg in segments:
        if seg.label == separator:
            close_word()
        else:
            pending.append(seg)
    close_word()
    return words