"""
Forced Alignment with Whisper
C. Max Bain
"""
from dataclasses import dataclass
from typing import Iterable, Union, List

import numpy as np
import pandas as pd
import torch
import torchaudio
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

from .audio import SAMPLE_RATE, load_audio
from .utils import interpolate_nans
from .types import AlignedTranscriptionResult, SingleSegment, SingleAlignedSegment, SingleWordSegment
import nltk
from nltk.tokenize.punkt import PunktSentenceTokenizer, PunktParameters

PUNKT_ABBREVIATIONS = ['dr', 'vs', 'mr', 'mrs', 'prof']
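
# The Punkt splitter is seeded with these abbreviations so that e.g. "Dr."
# does not end a sentence. A minimal sketch of the expected behaviour
# (hedged; exact spans depend on nltk's punkt implementation):
#
#     punkt_param = PunktParameters()
#     punkt_param.abbrev_types = set(PUNKT_ABBREVIATIONS)
#     splitter = PunktSentenceTokenizer(punkt_param)
#     splitter.span_tokenize("Dr. Smith arrived. He sat down.")
#     # -> two spans, e.g. [(0, 18), (19, 31)]; "Dr." does not split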

LANGUAGES_WITHOUT_SPACES = ["ja", "zh"]

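# Default per-language alignment models. Keys are ISO 639-1 language codes;
# the first table maps to torchaudio pipeline bundles, the second to Hugging
# Face hub model ids (see load_align_model below).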
DEFAULT_ALIGN_MODELS_TORCH = {
    "en": "WAV2VEC2_ASR_BASE_960H",
    "fr": "VOXPOPULI_ASR_BASE_10K_FR",
    "de": "VOXPOPULI_ASR_BASE_10K_DE",
    "es": "VOXPOPULI_ASR_BASE_10K_ES",
    "it": "VOXPOPULI_ASR_BASE_10K_IT",
}

DEFAULT_ALIGN_MODELS_HF = {
    "ja": "jonatasgrosman/wav2vec2-large-xlsr-53-japanese",
    "zh": "jonatasgrosman/wav2vec2-large-xlsr-53-chinese-zh-cn",
    "nl": "jonatasgrosman/wav2vec2-large-xlsr-53-dutch",
    "uk": "Yehor/wav2vec2-xls-r-300m-uk-with-small-lm",
    "pt": "jonatasgrosman/wav2vec2-large-xlsr-53-portuguese",
    "ar": "jonatasgrosman/wav2vec2-large-xlsr-53-arabic",
    "cs": "comodoro/wav2vec2-xls-r-300m-cs-250",
    "ru": "jonatasgrosman/wav2vec2-large-xlsr-53-russian",
    "pl": "jonatasgrosman/wav2vec2-large-xlsr-53-polish",
    "hu": "jonatasgrosman/wav2vec2-large-xlsr-53-hungarian",
    "fi": "jonatasgrosman/wav2vec2-large-xlsr-53-finnish",
    "fa": "jonatasgrosman/wav2vec2-large-xlsr-53-persian",
    "el": "jonatasgrosman/wav2vec2-large-xlsr-53-greek",
    "tr": "mpoyraz/wav2vec2-xls-r-300m-cv7-turkish",
    "da": "saattrupdan/wav2vec2-xls-r-300m-ftspeech",
    "he": "imvladikon/wav2vec2-xls-r-300m-hebrew",
    "vi": "nguyenvulebinh/wav2vec2-base-vi",
    "ko": "kresnik/wav2vec2-large-xlsr-korean",
    "ur": "kingabzpro/wav2vec2-large-xls-r-300m-Urdu",
    "te": "anuragshas/wav2vec2-large-xlsr-53-telugu",
    "hi": "theainerd/Wav2Vec2-large-xlsr-hindi",
    "ca": "softcatala/wav2vec2-large-xlsr-catala",
}


def load_align_model(language_code, device, model_name=None, model_dir=None):
    if model_name is None:
        # use default model
        if language_code in DEFAULT_ALIGN_MODELS_TORCH:
            model_name = DEFAULT_ALIGN_MODELS_TORCH[language_code]
        elif language_code in DEFAULT_ALIGN_MODELS_HF:
            model_name = DEFAULT_ALIGN_MODELS_HF[language_code]
        else:
            print(f"There is no default alignment model set for this language ({language_code}).\
                Please find a wav2vec2.0 model finetuned on this language in https://huggingface.co/models, then pass the model name in --align_model [MODEL_NAME]")
            raise ValueError(f"No default align-model for language: {language_code}")

    if model_name in torchaudio.pipelines.__all__:
        pipeline_type = "torchaudio"
        bundle = torchaudio.pipelines.__dict__[model_name]
        align_model = bundle.get_model(dl_kwargs={"model_dir": model_dir}).to(device)
        labels = bundle.get_labels()
        align_dictionary = {c.lower(): i for i, c in enumerate(labels)}
    else:
        try:
            processor = Wav2Vec2Processor.from_pretrained(model_name)
            align_model = Wav2Vec2ForCTC.from_pretrained(model_name)
        except Exception as e:
            print(e)
            print(f"Error loading model from huggingface, check https://huggingface.co/models for finetuned wav2vec2.0 models")
            raise ValueError(f'The chosen align_model "{model_name}" could not be found in huggingface (https://huggingface.co/models) or torchaudio (https://pytorch.org/audio/stable/pipelines.html#id14)')
        pipeline_type = "huggingface"
        align_model = align_model.to(device)
        labels = processor.tokenizer.get_vocab()
        align_dictionary = {char.lower(): code for char, code in processor.tokenizer.get_vocab().items()}

    align_metadata = {"language": language_code, "dictionary": align_dictionary, "type": pipeline_type}

    return align_model, align_metadata
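
# A minimal usage sketch (hedged: assumes English and a CUDA device; pass
# device="cpu" otherwise):
#
#     align_model, align_metadata = load_align_model("en", "cuda")
#     # align_metadata == {"language": "en", "dictionary": {...}, "type": "torchaudio"}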


def align(
    transcript: Iterable[SingleSegment],
    model: torch.nn.Module,
    align_model_metadata: dict,
    audio: Union[str, np.ndarray, torch.Tensor],
    device: str,
    interpolate_method: str = "nearest",
    return_char_alignments: bool = False,
    print_progress: bool = False,
    combined_progress: bool = False,
) -> AlignedTranscriptionResult:
    """
    Align phoneme recognition predictions to known transcription.
    """
    if not torch.is_tensor(audio):
        if isinstance(audio, str):
            audio = load_audio(audio)
        audio = torch.from_numpy(audio)
    if len(audio.shape) == 1:
        audio = audio.unsqueeze(0)

    MAX_DURATION = audio.shape[1] / SAMPLE_RATE

    model_dictionary = align_model_metadata["dictionary"]
    model_lang = align_model_metadata["language"]
    model_type = align_model_metadata["type"]

    # 1. Preprocess to keep only characters in dictionary
    total_segments = len(transcript)
    for sdx, segment in enumerate(transcript):
        if print_progress:
            base_progress = ((sdx + 1) / total_segments) * 100
            percent_complete = (50 + base_progress / 2) if combined_progress else base_progress
            print(f"Progress: {percent_complete:.2f}%...")

        # strip spaces at beginning / end, but keep track of the amount.
        num_leading = len(segment["text"]) - len(segment["text"].lstrip())
        num_trailing = len(segment["text"]) - len(segment["text"].rstrip())
        text = segment["text"]

        # split into words
        if model_lang not in LANGUAGES_WITHOUT_SPACES:
            per_word = text.split(" ")
        else:
            per_word = text

        clean_char, clean_cdx = [], []
        for cdx, char in enumerate(text):
            char_ = char.lower()
            # wav2vec2 models use "|" character to represent spaces
            if model_lang not in LANGUAGES_WITHOUT_SPACES:
                char_ = char_.replace(" ", "|")

            # ignore whitespace at beginning and end of transcript
            if cdx < num_leading:
                pass
            elif cdx > len(text) - num_trailing - 1:
                pass
            elif char_ in model_dictionary.keys():
                clean_char.append(char_)
                clean_cdx.append(cdx)

        clean_wdx = []
        for wdx, wrd in enumerate(per_word):
            if any([c in model_dictionary.keys() for c in wrd]):
                clean_wdx.append(wdx)

        punkt_param = PunktParameters()
        punkt_param.abbrev_types = set(PUNKT_ABBREVIATIONS)
        sentence_splitter = PunktSentenceTokenizer(punkt_param)
        sentence_spans = list(sentence_splitter.span_tokenize(text))

        segment["clean_char"] = clean_char
        segment["clean_cdx"] = clean_cdx
        segment["clean_wdx"] = clean_wdx
        segment["sentence_spans"] = sentence_spans

    aligned_segments: List[SingleAlignedSegment] = []

    # 2. Get prediction matrix from alignment model & align
    for sdx, segment in enumerate(transcript):
        t1 = segment["start"]
        t2 = segment["end"]
        text = segment["text"]

        aligned_seg: SingleAlignedSegment = {
            "start": t1,
            "end": t2,
            "text": text,
            "words": [],
        }

        if return_char_alignments:
            aligned_seg["chars"] = []

        # check we can align
        if len(segment["clean_char"]) == 0:
            print(f'Failed to align segment ("{segment["text"]}"): no characters in this segment found in model dictionary, resorting to original...')
            aligned_segments.append(aligned_seg)
            continue

        if t1 >= MAX_DURATION:
            print(f'Failed to align segment ("{segment["text"]}"): original start time longer than audio duration, skipping...')
            aligned_segments.append(aligned_seg)
            continue
text_clean = " " . join ( segment [ " clean_char " ] )
tokens = [ model_dictionary [ c ] for c in text_clean ]
f1 = int ( t1 * SAMPLE_RATE )
f2 = int ( t2 * SAMPLE_RATE )
# TODO: Probably can get some speedup gain with batched inference here
waveform_segment = audio [ : , f1 : f2 ]
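        # (hedged note: 400 samples is 25 ms at 16 kHz, roughly the minimum
        # input from which the wav2vec2 feature encoder can emit one frame)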
        # Handle the minimum input length for wav2vec2 models
        if waveform_segment.shape[-1] < 400:
            lengths = torch.as_tensor([waveform_segment.shape[-1]]).to(device)
            waveform_segment = torch.nn.functional.pad(
                waveform_segment, (0, 400 - waveform_segment.shape[-1])
            )
        else:
            lengths = None

        with torch.inference_mode():
            if model_type == "torchaudio":
                emissions, _ = model(waveform_segment.to(device), lengths=lengths)
            elif model_type == "huggingface":
                emissions = model(waveform_segment.to(device)).logits
            else:
                raise NotImplementedError(f"Align model of type {model_type} not supported.")
            emissions = torch.log_softmax(emissions, dim=-1)

        emission = emissions[0].cpu().detach()

        blank_id = 0
        for char, code in model_dictionary.items():
            if char == '[pad]' or char == '<pad>':
                blank_id = code

        trellis = get_trellis(emission, tokens, blank_id)
        path = backtrack(trellis, emission, tokens, blank_id)

        if path is None:
            print(f'Failed to align segment ("{segment["text"]}"): backtrack failed, resorting to original...')
            aligned_segments.append(aligned_seg)
            continue

        char_segments = merge_repeats(path, text_clean)

        duration = t2 - t1
        ratio = duration * waveform_segment.size(0) / (trellis.size(0) - 1)

        # assign timestamps to aligned characters
        char_segments_arr = []
        word_idx = 0
        for cdx, char in enumerate(text):
            start, end, score = None, None, None
            if cdx in segment["clean_cdx"]:
                char_seg = char_segments[segment["clean_cdx"].index(cdx)]
                start = round(char_seg.start * ratio + t1, 3)
                end = round(char_seg.end * ratio + t1, 3)
                score = round(char_seg.score, 3)

            char_segments_arr.append(
                {
                    "char": char,
                    "start": start,
                    "end": end,
                    "score": score,
                    "word-idx": word_idx,
                }
            )

            # increment word_idx; nltk word tokenization would probably be more robust here, but use space for now...
            if model_lang in LANGUAGES_WITHOUT_SPACES:
                word_idx += 1
            elif cdx == len(text) - 1 or text[cdx + 1] == " ":
                word_idx += 1

        char_segments_arr = pd.DataFrame(char_segments_arr)

        aligned_subsegments = []
        # assign sentence_idx to each character index
        char_segments_arr["sentence-idx"] = None
        for sdx, (sstart, send) in enumerate(segment["sentence_spans"]):
            curr_chars = char_segments_arr.loc[(char_segments_arr.index >= sstart) & (char_segments_arr.index <= send)]
            char_segments_arr.loc[(char_segments_arr.index >= sstart) & (char_segments_arr.index <= send), "sentence-idx"] = sdx

            sentence_text = text[sstart:send]
            sentence_start = curr_chars["start"].min()
            end_chars = curr_chars[curr_chars["char"] != " "]
            sentence_end = end_chars["end"].max()
            sentence_words = []

            for word_idx in curr_chars["word-idx"].unique():
                word_chars = curr_chars.loc[curr_chars["word-idx"] == word_idx]
                word_text = "".join(word_chars["char"].tolist()).strip()
                if len(word_text) == 0:
                    continue

                # don't use the space character for alignment
                word_chars = word_chars[word_chars["char"] != " "]

                word_start = word_chars["start"].min()
                word_end = word_chars["end"].max()
                word_score = round(word_chars["score"].mean(), 3)

                # -1 indicates unalignable
                word_segment = {"word": word_text}

                if not np.isnan(word_start):
                    word_segment["start"] = word_start
                if not np.isnan(word_end):
                    word_segment["end"] = word_end
                if not np.isnan(word_score):
                    word_segment["score"] = word_score

                sentence_words.append(word_segment)

            aligned_subsegments.append({
                "text": sentence_text,
                "start": sentence_start,
                "end": sentence_end,
                "words": sentence_words,
            })

            if return_char_alignments:
                curr_chars = curr_chars[["char", "start", "end", "score"]]
                curr_chars.fillna(-1, inplace=True)
                curr_chars = curr_chars.to_dict("records")
                curr_chars = [{key: val for key, val in char.items() if val != -1} for char in curr_chars]
                aligned_subsegments[-1]["chars"] = curr_chars

        aligned_subsegments = pd.DataFrame(aligned_subsegments)
        aligned_subsegments["start"] = interpolate_nans(aligned_subsegments["start"], method=interpolate_method)
        aligned_subsegments["end"] = interpolate_nans(aligned_subsegments["end"], method=interpolate_method)
        # concatenate sentences with same timestamps
        agg_dict = {"text": " ".join, "words": "sum"}
        if model_lang in LANGUAGES_WITHOUT_SPACES:
            agg_dict["text"] = "".join
        if return_char_alignments:
            agg_dict["chars"] = "sum"
        aligned_subsegments = aligned_subsegments.groupby(["start", "end"], as_index=False).agg(agg_dict)
        aligned_subsegments = aligned_subsegments.to_dict("records")
        aligned_segments += aligned_subsegments

    # create word_segments list
    word_segments: List[SingleWordSegment] = []
    for segment in aligned_segments:
        word_segments += segment["words"]

    return {"segments": aligned_segments, "word_segments": word_segments}


"""
source: https://pytorch.org/tutorials/intermediate/forced_alignment_with_torchaudio_tutorial.html
"""

def get_trellis(emission, tokens, blank_id=0):
    num_frame = emission.size(0)
    num_tokens = len(tokens)

    # Trellis has extra dimensions for both time axis and tokens.
    # The extra dim for tokens represents <SoS> (start-of-sentence)
    # The extra dim for time axis is for simplification of the code.
    trellis = torch.empty((num_frame + 1, num_tokens + 1))
    trellis[0, 0] = 0
    trellis[1:, 0] = torch.cumsum(emission[:, 0], 0)
    trellis[0, -num_tokens:] = -float("inf")
    trellis[-num_tokens:, 0] = float("inf")
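
    # The DP recurrence computed by the loop below, stated once for reference:
    #   trellis[t+1, j] = max(trellis[t, j]   + emission[t, blank_id],      # stay on token j
    #                         trellis[t, j-1] + emission[t, tokens[j-1]])   # advance to token j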
    for t in range(num_frame):
        trellis[t + 1, 1:] = torch.maximum(
            # Score for staying at the same token
            trellis[t, 1:] + emission[t, blank_id],
            # Score for changing to the next token
            trellis[t, :-1] + emission[t, tokens],
        )
    return trellis


@dataclass
class Point:
    token_index: int
    time_index: int
    score: float


def backtrack(trellis, emission, tokens, blank_id=0):
    # Note:
    # j and t are indices for trellis, which has extra dimensions
    # for time and tokens at the beginning.
    # When referring to time frame index `T` in trellis,
    # the corresponding index in emission is `T-1`.
    # Similarly, when referring to token index `J` in trellis,
    # the corresponding index in transcript is `J-1`.
    j = trellis.size(1) - 1
    t_start = torch.argmax(trellis[:, j]).item()

    path = []
    for t in range(t_start, 0, -1):
        # 1. Figure out if the current position was stay or change
        # Note (again):
        # `emission[T-1]` is the emission at time frame `T` of trellis dimension.
        # Score for token staying the same from time frame T-1 to T.
        stayed = trellis[t - 1, j] + emission[t - 1, blank_id]
        # Score for token changing from J-1 at T-1 to J at T.
        changed = trellis[t - 1, j - 1] + emission[t - 1, tokens[j - 1]]

        # 2. Store the path with frame-wise probability.
        prob = emission[t - 1, tokens[j - 1] if changed > stayed else 0].exp().item()
        # Return token index and time index in non-trellis coordinate.
        path.append(Point(j - 1, t - 1, prob))

        # 3. Update the token
        if changed > stayed:
            j -= 1
            if j == 0:
                break
    else:
        # failed
        return None
    return path[::-1]
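
# How the pieces above fit together (a hedged sketch; `emission` is a (T, V)
# matrix of log-probabilities and `tokens` the ids of the cleaned transcript,
# as prepared inside align()):
#
#     trellis = get_trellis(emission, tokens, blank_id)
#     path = backtrack(trellis, emission, tokens, blank_id)  # None on failure
#     char_segments = merge_repeats(path, text_clean)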


# Merge the labels
@dataclass
class Segment:
    label: str
    start: int
    end: int
    score: float

    def __repr__(self):
        return f"{self.label}\t({self.score:4.2f}): [{self.start:5d}, {self.end:5d})"

    @property
    def length(self):
        return self.end - self.start


def merge_repeats(path, transcript):
    i1, i2 = 0, 0
    segments = []
    while i1 < len(path):
        while i2 < len(path) and path[i1].token_index == path[i2].token_index:
            i2 += 1
        score = sum(path[k].score for k in range(i1, i2)) / (i2 - i1)
        segments.append(
            Segment(
                transcript[path[i1].token_index],
                path[i1].time_index,
                path[i2 - 1].time_index + 1,
                score,
            )
        )
        i1 = i2
    return segments
def merge_words ( segments , separator = " | " ) :
words = [ ]
i1 , i2 = 0 , 0
while i1 < len ( segments ) :
if i2 > = len ( segments ) or segments [ i2 ] . label == separator :
if i1 != i2 :
segs = segments [ i1 : i2 ]
word = " " . join ( [ seg . label for seg in segs ] )
score = sum ( seg . score * seg . length for seg in segs ) / sum ( seg . length for seg in segs )
words . append ( Segment ( word , segments [ i1 ] . start , segments [ i2 - 1 ] . end , score ) )
i1 = i2 + 1
i2 = i1
else :
i2 + = 1
return words
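
# Example (hedged): if the character Segments spell "i|had" with separator "|",
# merge_words returns two word Segments, "i" and "had", each scored by the
# length-weighted mean of its characters' scores.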