import argparse
import os
import warnings
from typing import TYPE_CHECKING, Iterator, List, Optional, Tuple, Union

import numpy as np
import pandas as pd
import torch
import torchaudio
import tqdm
from transformers import AutoProcessor, Wav2Vec2ForCTC

from .audio import SAMPLE_RATE, N_FRAMES, HOP_LENGTH, CHUNK_LENGTH, pad_or_trim, log_mel_spectrogram, load_audio
from .alignment import get_trellis, backtrack, merge_repeats, merge_words
from .decoding import DecodingOptions, DecodingResult
from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
from .utils import (
    exact_div,
    format_timestamp,
    interpolate_nans,
    optional_float,
    optional_int,
    str2bool,
    write_ass,
    write_srt,
    write_tsv,
    write_txt,
    write_vtt,
)

if TYPE_CHECKING:
    from .model import Whisper

LANGUAGES_WITHOUT_SPACES = ["ja", "zh"]

DEFAULT_ALIGN_MODELS_TORCH = {
    "en": "WAV2VEC2_ASR_BASE_960H",
    "fr": "VOXPOPULI_ASR_BASE_10K_FR",
    "de": "VOXPOPULI_ASR_BASE_10K_DE",
    "es": "VOXPOPULI_ASR_BASE_10K_ES",
    "it": "VOXPOPULI_ASR_BASE_10K_IT",
}

DEFAULT_ALIGN_MODELS_HF = {
    "ja": "jonatasgrosman/wav2vec2-large-xlsr-53-japanese",
    "zh": "jonatasgrosman/wav2vec2-large-xlsr-53-chinese-zh-cn",
    "nl": "jonatasgrosman/wav2vec2-large-xlsr-53-dutch",
    "uk": "Yehor/wav2vec2-xls-r-300m-uk-with-small-lm",
    "pt": "jonatasgrosman/wav2vec2-large-xlsr-53-portuguese",
}

def transcribe(
    model: "Whisper",
    audio: Union[str, np.ndarray, torch.Tensor],
    *,
    verbose: Optional[bool] = None,
    temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
    compression_ratio_threshold: Optional[float] = 2.4,
    logprob_threshold: Optional[float] = -1.0,
    no_speech_threshold: Optional[float] = 0.6,
    condition_on_previous_text: bool = False,  # turned off by default due to the errors it causes
    mel: Optional[np.ndarray] = None,
    **decode_options,
):
"""
Transcribe an audio file using Whisper
Parameters
- - - - - - - - - -
model : Whisper
The Whisper model instance
audio : Union [ str , np . ndarray , torch . Tensor ]
The path to the audio file to open , or the audio waveform
verbose : bool
Whether to display the text being decoded to the console . If True , displays all the details ,
If False , displays minimal details . If None , does not display anything
temperature : Union [ float , Tuple [ float , . . . ] ]
Temperature for sampling . It can be a tuple of temperatures , which will be successfully used
upon failures according to either ` compression_ratio_threshold ` or ` logprob_threshold ` .
compression_ratio_threshold : float
If the gzip compression ratio is above this value , treat as failed
logprob_threshold : float
If the average log probability over sampled tokens is below this value , treat as failed
no_speech_threshold : float
If the no_speech probability is higher than this value AND the average log probability
over sampled tokens is below ` logprob_threshold ` , consider the segment as silent
condition_on_previous_text : bool
if True , the previous output of the model is provided as a prompt for the next window ;
disabling may make the text inconsistent across windows , but the model becomes less prone to
getting stuck in a failure loop , such as repetition looping or timestamps going out of sync .
decode_options : dict
Keyword arguments to construct ` DecodingOptions ` instances
Returns
- - - - - - -
A dictionary containing the resulting text ( " text " ) and segment - level details ( " segments " ) , and
the spoken language ( " language " ) , which is detected when ` decode_options [ " language " ] ` is None .
"""
    dtype = torch.float16 if decode_options.get("fp16", True) else torch.float32
    if model.device == torch.device("cpu"):
        if torch.cuda.is_available():
            warnings.warn("Performing inference on CPU when CUDA is available")
        if dtype == torch.float16:
            warnings.warn("FP16 is not supported on CPU; using FP32 instead")
            dtype = torch.float32

    if dtype == torch.float32:
        decode_options["fp16"] = False

    if mel is None:
        mel = log_mel_spectrogram(audio)
    if decode_options.get("language", None) is None:
        if not model.is_multilingual:
            decode_options["language"] = "en"
        else:
            if verbose:
                print("Detecting language using up to the first 30 seconds. Use `--language` to specify the language")
            segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype)
            _, probs = model.detect_language(segment)
            decode_options["language"] = max(probs, key=probs.get)
            if verbose is not None:
                print(f"Detected language: {LANGUAGES[decode_options['language']].title()}")

    language = decode_options["language"]
    task = decode_options.get("task", "transcribe")
    tokenizer = get_tokenizer(model.is_multilingual, language=language, task=task)
    def decode_with_fallback(segment: torch.Tensor) -> DecodingResult:
        temperatures = [temperature] if isinstance(temperature, (int, float)) else temperature
        decode_result = None

        for t in temperatures:
            kwargs = {**decode_options}
            if t > 0:
                # disable beam_size and patience when t > 0
                kwargs.pop("beam_size", None)
                kwargs.pop("patience", None)
            else:
                # disable best_of when t == 0
                kwargs.pop("best_of", None)

            options = DecodingOptions(**kwargs, temperature=t)
            decode_result = model.decode(segment, options)

            needs_fallback = False
            if compression_ratio_threshold is not None and decode_result.compression_ratio > compression_ratio_threshold:
                needs_fallback = True  # too repetitive
            if logprob_threshold is not None and decode_result.avg_logprob < logprob_threshold:
                needs_fallback = True  # average log probability is too low

            if not needs_fallback:
                break

        return decode_result
    seek = 0
    input_stride = exact_div(
        N_FRAMES, model.dims.n_audio_ctx
    )  # mel frames per output token: 2
    time_precision = (
        input_stride * HOP_LENGTH / SAMPLE_RATE
    )  # time per output token: 0.02 (seconds)
    all_tokens = []
    all_segments = []
    prompt_reset_since = 0

    initial_prompt = decode_options.pop("initial_prompt", None) or []
    if initial_prompt:
        initial_prompt = tokenizer.encode(" " + initial_prompt.strip())
        all_tokens.extend(initial_prompt)
    def add_segment(
        *, start: float, end: float, text_tokens: torch.Tensor, result: DecodingResult
    ):
        text = tokenizer.decode([token for token in text_tokens if token < tokenizer.eot])
        if len(text.strip()) == 0:  # skip empty text output
            return

        all_segments.append(
            {
                "id": len(all_segments),
                "seek": seek,
                "start": start,
                "end": end,
                "text": text,
                "tokens": text_tokens.tolist(),
                "temperature": result.temperature,
                "avg_logprob": result.avg_logprob,
                "compression_ratio": result.compression_ratio,
                "no_speech_prob": result.no_speech_prob,
            }
        )
        if verbose:
            print(f"[{format_timestamp(start)} --> {format_timestamp(end)}] {text}")
    # show the progress bar when verbose is False (otherwise the transcribed text will be printed)
    num_frames = mel.shape[-1]
    previous_seek_value = seek

    with tqdm.tqdm(total=num_frames, unit="frames", disable=verbose is not False) as pbar:
        while seek < num_frames:
            timestamp_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
            segment = pad_or_trim(mel[:, seek:], N_FRAMES).to(model.device).to(dtype)
            segment_duration = segment.shape[-1] * HOP_LENGTH / SAMPLE_RATE

            decode_options["prompt"] = all_tokens[prompt_reset_since:]
            result: DecodingResult = decode_with_fallback(segment)
            tokens = torch.tensor(result.tokens)

            if no_speech_threshold is not None:
                # no voice activity check
                should_skip = result.no_speech_prob > no_speech_threshold
                if logprob_threshold is not None and result.avg_logprob > logprob_threshold:
                    # don't skip if the logprob is high enough, despite the no_speech_prob
                    should_skip = False

                if should_skip:
                    seek += segment.shape[-1]  # fast-forward to the next segment boundary
                    continue

            timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin)
            consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0].add_(1)
            if len(consecutive) > 0:  # if the output contains two consecutive timestamp tokens
                last_slice = 0
                for current_slice in consecutive:
                    sliced_tokens = tokens[last_slice:current_slice]
                    start_timestamp_position = (
                        sliced_tokens[0].item() - tokenizer.timestamp_begin
                    )
                    end_timestamp_position = (
                        sliced_tokens[-1].item() - tokenizer.timestamp_begin
                    )

                    # clamp the end time to be at least one time step after the start time
                    end_timestamp_position = max(end_timestamp_position, start_timestamp_position + time_precision)

                    add_segment(
                        start=timestamp_offset + start_timestamp_position * time_precision,
                        end=timestamp_offset + end_timestamp_position * time_precision,
                        text_tokens=sliced_tokens[1:-1],
                        result=result,
                    )
                    last_slice = current_slice
                last_timestamp_position = (
                    tokens[last_slice - 1].item() - tokenizer.timestamp_begin
                )
                seek += last_timestamp_position * input_stride
                all_tokens.extend(tokens[: last_slice + 1].tolist())
            else:
                duration = segment_duration
                timestamps = tokens[timestamp_tokens.nonzero().flatten()]
                if len(timestamps) > 0 and timestamps[-1].item() != tokenizer.timestamp_begin:
                    # no consecutive timestamps but it has a timestamp; use the last one.
                    # a single timestamp at the end means there is no speech after the last timestamp.
                    last_timestamp_position = timestamps[-1].item() - tokenizer.timestamp_begin
                    duration = last_timestamp_position * time_precision

                add_segment(
                    start=timestamp_offset,
                    end=timestamp_offset + duration,
                    text_tokens=tokens,
                    result=result,
                )

                seek += segment.shape[-1]
                all_tokens.extend(tokens.tolist())

            if not condition_on_previous_text or result.temperature > 0.5:
                # do not feed the prompt tokens if a high temperature was used
                prompt_reset_since = len(all_tokens)

            # update the progress bar
            pbar.update(min(num_frames, seek) - previous_seek_value)
            previous_seek_value = seek

    return dict(text=tokenizer.decode(all_tokens[len(initial_prompt):]), segments=all_segments, language=language)
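
# Illustrative usage sketch (not called anywhere in this module): transcribing a
# local file directly with `transcribe`. The file name "audio.wav" is a
# placeholder; `load_model` is the package-level loader also used in `cli()` below.
def _example_transcribe_usage():
    from . import load_model
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = load_model("small", device=device)
    result = transcribe(model, "audio.wav", verbose=True)
    print(result["language"])
    for seg in result["segments"]:
        print(f"[{format_timestamp(seg['start'])} --> {format_timestamp(seg['end'])}] {seg['text']}")
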
def align(
    transcript: Iterator[dict],
    model: torch.nn.Module,
    align_model_metadata: dict,
    audio: Union[str, np.ndarray, torch.Tensor],
    device: str,
    extend_duration: float = 0.0,
    start_from_previous: bool = True,
    interpolate_method: str = "nearest",
):
    """
    Force align phoneme recognition predictions to a known transcription

    Parameters
    ----------
    transcript: Iterator[dict]
        Segments from a Whisper transcription, each with "text", "start", and "end" keys

    model: torch.nn.Module
        Alignment model (wav2vec2)

    align_model_metadata: dict
        Metadata returned by `load_align_model`: the model's character dictionary, language, and pipeline type

    audio: Union[str, np.ndarray, torch.Tensor]
        The path to the audio file to open, or the audio waveform

    device: str
        The device to run alignment on, e.g. "cuda"

    extend_duration: float
        Seconds to pad input segments by. If not using --vad_filter, around 2 seconds is recommended.

    start_from_previous: bool
        Whether to clip a segment's start time to the end time of the previously aligned segment

    interpolate_method: str ["nearest", "linear", "ignore"]
        Method to assign timestamps to non-aligned words. Words cannot be aligned when none of their
        characters occur in the align model's dictionary. "nearest" copies the timestamp of the nearest
        aligned word within the segment, "linear" uses linear interpolation, and "ignore" leaves the
        word-level timestamps unfilled.

    Returns
    -------
    A dictionary containing the aligned segment-level details ("segments") and
    word-level details ("word_segments").
    """
    if not torch.is_tensor(audio):
        if isinstance(audio, str):
            audio = load_audio(audio)
        audio = torch.from_numpy(audio)
    if len(audio.shape) == 1:
        audio = audio.unsqueeze(0)

    MAX_DURATION = audio.shape[1] / SAMPLE_RATE

    model_dictionary = align_model_metadata["dictionary"]
    model_lang = align_model_metadata["language"]
    model_type = align_model_metadata["type"]

    aligned_segments = []
    prev_t2 = 0
    for segment in transcript:
        aligned_subsegments = []
        while True:
            segment_align_success = False

            # strip spaces at the beginning / end, but keep track of the amount.
            num_leading = len(segment["text"]) - len(segment["text"].lstrip())
            num_trailing = len(segment["text"]) - len(segment["text"].rstrip())
            transcription = segment["text"]

            # TODO: convert number tokenizer / symbols to phonetic words for alignment.
            # e.g. "$300" -> "three hundred dollars"
            # currently "$300" is ignored since no characters are present in the phonetic dictionary

            # split into words
            if model_lang not in LANGUAGES_WITHOUT_SPACES:
                per_word = transcription.split(" ")
            else:
                per_word = transcription

            # first check that characters in the transcription can be aligned
            # (i.e. they are contained in the align model's dictionary)
            clean_char, clean_cdx = [], []
            for cdx, char in enumerate(transcription):
                char_ = char.lower()
                # wav2vec2 models use the "|" character to represent spaces
                if model_lang not in LANGUAGES_WITHOUT_SPACES:
                    char_ = char_.replace(" ", "|")

                # ignore whitespace at the beginning and end of the transcript
                if cdx < num_leading:
                    pass
                elif cdx > len(transcription) - num_trailing - 1:
                    pass
                elif char_ in model_dictionary.keys():
                    clean_char.append(char_)
                    clean_cdx.append(cdx)

            clean_wdx = []
            for wdx, wrd in enumerate(per_word):
                if any([c in model_dictionary.keys() for c in wrd]):
                    clean_wdx.append(wdx)

            # if no characters are in the dictionary, then we skip this segment...
            if len(clean_char) == 0:
                print("Failed to align segment: no characters in this segment found in model dictionary, resorting to original...")
                break

            transcription_cleaned = "".join(clean_char)
            tokens = [model_dictionary[c] for c in transcription_cleaned]
            # pad according to the original timestamps
            t1 = max(segment["start"] - extend_duration, 0)
            t2 = min(segment["end"] + extend_duration, MAX_DURATION)

            # use prev_t2 as the current t1 if it's later
            if start_from_previous and t1 < prev_t2:
                t1 = prev_t2

            # check if the timestamp range is still valid
            if t1 >= MAX_DURATION:
                print("Failed to align segment: original start time longer than audio duration, skipping...")
                break
            if t2 - t1 < 0.02:
                print("Failed to align segment: duration smaller than 0.02s time precision")
                break

            f1 = int(t1 * SAMPLE_RATE)
            f2 = int(t2 * SAMPLE_RATE)

            waveform_segment = audio[:, f1:f2]

            with torch.inference_mode():
                if model_type == "torchaudio":
                    emissions, _ = model(waveform_segment.to(device))
                elif model_type == "huggingface":
                    emissions = model(waveform_segment.to(device)).logits
                else:
                    raise NotImplementedError(f"Align model of type {model_type} not supported.")
                emissions = torch.log_softmax(emissions, dim=-1)
            emission = emissions[0].cpu().detach()
            trellis = get_trellis(emission, tokens)
            path = backtrack(trellis, emission, tokens)
            if path is None:
                print("Failed to align segment: backtrack failed, resorting to original...")
                break

            char_segments = merge_repeats(path, transcription_cleaned)
            # word_segments = merge_words(char_segments)

            # sub-segments
            if "seg-text" not in segment:
                segment["seg-text"] = [transcription]

            v = 0
            seg_lens = [0] + [len(x) for x in segment["seg-text"]]
            seg_lens_cumsum = [v := v + n for n in seg_lens]
            sub_seg_idx = 0

            char_level = {
                "start": [],
                "end": [],
                "score": [],
                "word-index": [],
            }

            word_level = {
                "start": [],
                "end": [],
                "score": [],
                "segment-text-start": [],
                "segment-text-end": [],
            }

            wdx = 0
            seg_start_actual, seg_end_actual = None, None
            duration = t2 - t1
            ratio = duration * waveform_segment.size(0) / (trellis.size(0) - 1)
            cdx_prev = 0
            for cdx, char in enumerate(transcription + " "):
                is_last = False
                if cdx == len(transcription):
                    break
                elif cdx + 1 == len(transcription):
                    is_last = True

                start, end, score = None, None, None
                if cdx in clean_cdx:
                    char_seg = char_segments[clean_cdx.index(cdx)]
                    start = char_seg.start * ratio + t1
                    end = char_seg.end * ratio + t1
                    score = char_seg.score

                char_level["start"].append(start)
                char_level["end"].append(end)
                char_level["score"].append(score)
                char_level["word-index"].append(wdx)

                # word-level info
                if model_lang in LANGUAGES_WITHOUT_SPACES:
                    # character == word
                    wdx += 1
                elif is_last or transcription[cdx + 1] == " " or cdx == seg_lens_cumsum[sub_seg_idx + 1] - 1:
                    wdx += 1
                    word_level["start"].append(None)
                    word_level["end"].append(None)
                    word_level["score"].append(None)
                    word_level["segment-text-start"].append(cdx_prev - seg_lens_cumsum[sub_seg_idx])
                    word_level["segment-text-end"].append(cdx + 1 - seg_lens_cumsum[sub_seg_idx])
                    cdx_prev = cdx + 2

                if is_last or cdx == seg_lens_cumsum[sub_seg_idx + 1] - 1:
                    if model_lang not in LANGUAGES_WITHOUT_SPACES:
                        char_level = pd.DataFrame(char_level)
                        word_level = pd.DataFrame(word_level)

                        not_space = pd.Series(list(segment["seg-text"][sub_seg_idx])) != " "

                        word_level["start"] = char_level[not_space].groupby("word-index")["start"].min()  # take the min of all chars in a word, ignoring spaces
                        word_level["end"] = char_level[not_space].groupby("word-index")["end"].max()  # take the max of all chars in a word

                        # fill missing timestamps
                        if interpolate_method != "ignore":
                            word_level["start"] = interpolate_nans(word_level["start"], method=interpolate_method)
                            word_level["end"] = interpolate_nans(word_level["end"], method=interpolate_method)

                        word_level["start"] = word_level["start"].values.tolist()
                        word_level["end"] = word_level["end"].values.tolist()
                        word_level["score"] = char_level.groupby("word-index")["score"].mean()  # take the mean of all char scores

                        char_level = char_level.replace({np.nan: None}).to_dict("list")
                        word_level = pd.DataFrame(word_level).replace({np.nan: None}).to_dict("list")
                    else:
                        word_level = None

                    aligned_subsegments.append(
                        {
                            "text": segment["seg-text"][sub_seg_idx],
                            "start": seg_start_actual,
                            "end": seg_end_actual,
                            "char-segments": char_level,
                            "word-segments": word_level,
                        }
                    )

                    if "language" in segment:
                        aligned_subsegments[-1]["language"] = segment["language"]

                    char_level = {
                        "start": [],
                        "end": [],
                        "score": [],
                        "word-index": [],
                    }
                    word_level = {
                        "start": [],
                        "end": [],
                        "score": [],
                        "segment-text-start": [],
                        "segment-text-end": [],
                    }
                    wdx = 0
                    cdx_prev = cdx + 2
                    sub_seg_idx += 1
                    seg_start_actual, seg_end_actual = None, None

                # take the min-max for the actual segment-level timestamps
                if seg_start_actual is None and start is not None:
                    seg_start_actual = start
                if end is not None:
                    seg_end_actual = end

            prev_t2 = segment["end"]

            segment_align_success = True
            # end while True loop
            break

        # reset prev_t2 due to drifting issues
        if not segment_align_success:
            prev_t2 = 0
        # fill in missing segment-level timestamps
        # (skipped when alignment failed before producing any subsegments)
        if len(aligned_subsegments) > 0:
            start = interpolate_nans(pd.DataFrame(aligned_subsegments)["start"], method=interpolate_method)
            end = interpolate_nans(pd.DataFrame(aligned_subsegments)["end"], method=interpolate_method)
            for idx, seg in enumerate(aligned_subsegments):
                seg["start"] = start.iloc[idx]
                seg["end"] = end.iloc[idx]

        aligned_segments += aligned_subsegments

    # create word-level segments for .srt
    word_seg = []
    for seg in aligned_segments:
        if model_lang in LANGUAGES_WITHOUT_SPACES:
            # character-based languages: treat each character as a word
            seg["word-segments"] = seg["char-segments"]
            seg["word-segments"]["segment-text-start"] = range(len(seg["word-segments"]["start"]))
            seg["word-segments"]["segment-text-end"] = range(1, len(seg["word-segments"]["start"]) + 1)

        wseg = pd.DataFrame(seg["word-segments"]).replace({np.nan: None})
        for wdx, wrow in wseg.iterrows():
            if wrow["start"] is not None:
                word_seg.append(
                    {
                        "start": wrow["start"],
                        "end": wrow["end"],
                        "text": seg["text"][int(wrow["segment-text-start"]):int(wrow["segment-text-end"])],
                    }
                )

    return {"segments": aligned_segments, "word_segments": word_seg}
def load_align_model(language_code, device, model_name=None):
    if model_name is None:
        # use the default model for this language, if one exists
        if language_code in DEFAULT_ALIGN_MODELS_TORCH:
            model_name = DEFAULT_ALIGN_MODELS_TORCH[language_code]
        elif language_code in DEFAULT_ALIGN_MODELS_HF:
            model_name = DEFAULT_ALIGN_MODELS_HF[language_code]
        else:
            print(f"There is no default alignment model set for this language ({language_code}). \
                Please find a wav2vec 2.0 model finetuned on this language at https://huggingface.co/models, then pass the model name with --align_model [MODEL_NAME]")
            raise ValueError(f"No default align-model for language: {language_code}")

    if model_name in torchaudio.pipelines.__all__:
        pipeline_type = "torchaudio"
        bundle = torchaudio.pipelines.__dict__[model_name]
        align_model = bundle.get_model().to(device)
        labels = bundle.get_labels()
        align_dictionary = {c.lower(): i for i, c in enumerate(labels)}
    else:
        try:
            processor = AutoProcessor.from_pretrained(model_name)
            align_model = Wav2Vec2ForCTC.from_pretrained(model_name)
        except Exception as e:
            print(e)
            print("Error loading model from huggingface; check https://huggingface.co/models for finetuned wav2vec 2.0 models")
            raise ValueError(f'The chosen align_model "{model_name}" could not be found in huggingface (https://huggingface.co/models) or torchaudio (https://pytorch.org/audio/stable/pipelines.html#id14)')
        pipeline_type = "huggingface"
        align_model = align_model.to(device)
        align_dictionary = {char.lower(): code for char, code in processor.tokenizer.get_vocab().items()}

    align_metadata = {"language": language_code, "dictionary": align_dictionary, "type": pipeline_type}

    return align_model, align_metadata
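
# Illustrative usage sketch: loading the default English alignment model and
# inspecting the returned metadata.
def _example_load_align_model_usage():
    align_model, align_metadata = load_align_model("en", "cpu")
    print(align_metadata["type"])      # "torchaudio" for the default English model
    print(align_metadata["language"])  # "en"
    print(sorted(align_metadata["dictionary"])[:10])  # characters the model can align
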
def merge_chunks(segments, chunk_size=CHUNK_LENGTH):
    """
    Merge VAD segments into larger segments of approximately `chunk_size` (~30s).
    """
    curr_start = 0
    curr_end = 0
    merged_segments = []
    seg_idxs = []
    speaker_idxs = []
    for sdx, seg in enumerate(segments):
        if seg.end - curr_start > chunk_size and curr_end - curr_start > 0:
            merged_segments.append({
                "start": curr_start,
                "end": curr_end,
                "segments": seg_idxs,
            })
            curr_start = seg.start
            seg_idxs = []
            speaker_idxs = []
        curr_end = seg.end
        seg_idxs.append((seg.start, seg.end))
        speaker_idxs.append(seg.speaker)
    # add the final chunk
    merged_segments.append({
        "start": curr_start,
        "end": curr_end,
        "segments": seg_idxs,
    })
    return merged_segments
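
# Illustrative usage sketch with synthetic VAD segments: the first two segments
# fit within one ~30s chunk, while the third would exceed it and starts a new one.
def _example_merge_chunks_usage():
    segs = [
        Segment(0.0, 12.0, "UNKNOWN"),
        Segment(13.0, 24.0, "UNKNOWN"),
        Segment(35.0, 50.0, "UNKNOWN"),
    ]
    for chunk in merge_chunks(segs):
        print(chunk["start"], chunk["end"], len(chunk["segments"]))
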
def transcribe_with_vad(
    model: "Whisper",
    audio: Union[str, np.ndarray, torch.Tensor],
    vad_pipeline,
    mel=None,
    verbose: Optional[bool] = None,
    **kwargs,
):
    """
    Transcribe audio per VAD segment.
    """
    if mel is None:
        mel = log_mel_spectrogram(audio)

    prev = 0
    output = {"segments": []}

    vad_segments_list = []
    vad_segments = vad_pipeline(audio)
    for speech_turn in vad_segments.get_timeline().support():
        vad_segments_list.append(Segment(speech_turn.start, speech_turn.end, "UNKNOWN"))
    # merge segments into ~30s inputs, the size most appropriate for whisper
    vad_segments = merge_chunks(vad_segments_list)

    for sdx, seg_t in enumerate(vad_segments):
        if verbose:
            print(f"~~ Transcribing VAD chunk: ({format_timestamp(seg_t['start'])} --> {format_timestamp(seg_t['end'])}) ~~")
        seg_f_start, seg_f_end = int(seg_t["start"] * SAMPLE_RATE / HOP_LENGTH), int(seg_t["end"] * SAMPLE_RATE / HOP_LENGTH)
        local_f_start, local_f_end = seg_f_start - prev, seg_f_end - prev
        mel = mel[:, local_f_start:]  # seek forward
        prev = seg_f_start
        local_mel = mel[:, :local_f_end - local_f_start]
        result = transcribe(model, audio, mel=local_mel, verbose=verbose, **kwargs)
        seg_t["text"] = result["text"]
        output["segments"].append(
            {
                "start": seg_t["start"],
                "end": seg_t["end"],
                "language": result["language"],
                "text": result["text"],
                "seg-text": [x["text"] for x in result["segments"]],
                "seg-start": [x["start"] for x in result["segments"]],
                "seg-end": [x["end"] for x in result["segments"]],
            }
        )

    output["language"] = output["segments"][0]["language"]

    return output
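
# Illustrative usage sketch mirroring the `--vad_filter` path in `cli()`:
# requires pyannote.audio; "audio.wav" is a placeholder path.
def _example_transcribe_with_vad_usage():
    from pyannote.audio import Pipeline
    from . import load_model
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = load_model("small", device=device)
    vad_pipeline = Pipeline.from_pretrained("pyannote/voice-activity-detection")
    result = transcribe_with_vad(model, "audio.wav", vad_pipeline, verbose=True)
    for seg in result["segments"]:
        print(seg["start"], seg["end"], seg["text"])
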
def assign_word_speakers(diarize_df, result_segments, fill_nearest=False):
    for seg in result_segments:
        wdf = pd.DataFrame(seg["word-segments"])
        if len(wdf["start"].dropna()) == 0:
            wdf["start"] = seg["start"]
            wdf["end"] = seg["end"]
        speakers = []
        for wdx, wrow in wdf.iterrows():
            diarize_df["intersection"] = np.minimum(diarize_df["end"], wrow["end"]) - np.maximum(diarize_df["start"], wrow["start"])
            diarize_df["union"] = np.maximum(diarize_df["end"], wrow["end"]) - np.minimum(diarize_df["start"], wrow["start"])
            # remove diarization rows with no overlap, unless filling from the nearest row
            if not fill_nearest:
                dia_tmp = diarize_df[diarize_df["intersection"] > 0]
            else:
                dia_tmp = diarize_df
            if len(dia_tmp) == 0:
                speaker = None
            else:
                # assign the speaker with the largest intersection
                speaker = dia_tmp.sort_values("intersection", ascending=False).iloc[0][2]
            speakers.append(speaker)
        seg["word-segments"]["speaker"] = speakers
        seg["speaker"] = pd.Series(speakers).value_counts().index[0]

    # create word-level segments for .srt
    word_seg = []
    for seg in result_segments:
        wseg = pd.DataFrame(seg["word-segments"])
        for wdx, wrow in wseg.iterrows():
            if wrow["start"] is not None:
                speaker = wrow["speaker"]
                # NaN never compares equal to itself, so it must be checked explicitly
                if speaker is None or (isinstance(speaker, float) and np.isnan(speaker)):
                    speaker = "UNKNOWN"
                word_seg.append(
                    {
                        "start": wrow["start"],
                        "end": wrow["end"],
                        "text": f"[{speaker}]: " + seg["text"][int(wrow["segment-text-start"]):int(wrow["segment-text-end"])],
                    }
                )

    # TODO: create segments but split words on new speaker
    return result_segments, word_seg
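
# Illustrative usage sketch mirroring the `--diarize` path in `cli()`: build the
# diarization DataFrame from a pyannote pipeline, then attach speaker labels.
# `result_aligned` is assumed to be the output of `align`.
def _example_assign_word_speakers_usage(result_aligned, audio_path):
    from pyannote.audio import Pipeline
    diarize_pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization@2.1")
    diarize_segments = diarize_pipeline(audio_path)
    diarize_df = pd.DataFrame(diarize_segments.itertracks(yield_label=True))
    diarize_df["start"] = diarize_df[0].apply(lambda x: x.start)
    diarize_df["end"] = diarize_df[0].apply(lambda x: x.end)
    segments, word_segments = assign_word_speakers(diarize_df, result_aligned["segments"], fill_nearest=True)
    return segments, word_segments
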
class Segment:
    def __init__(self, start, end, speaker=None):
        self.start = start
        self.end = end
        self.speaker = speaker

def cli():
    from . import available_models

    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe")
    parser.add_argument("--model", default="small", choices=available_models(), help="name of the Whisper model to use")
    parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
    parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
    # alignment params
    parser.add_argument("--align_model", default=None, help="name of the phoneme-level ASR model to use for alignment")
    parser.add_argument("--align_extend", default=2, type=float, help="seconds before and after to extend the whisper segments for alignment")
    # str2bool is used here because argparse's plain `bool` treats any non-empty string as True
    parser.add_argument("--align_from_prev", default=True, type=str2bool, help="whether to clip the alignment start time of the current segment to the end time of the last aligned word of the previous segment")
    parser.add_argument("--interpolate_method", default="nearest", choices=["nearest", "linear", "ignore"], help="for the word .srt output, method to assign timestamps to non-aligned words, or to merge them into neighbouring words")
    # vad params
    parser.add_argument("--vad_filter", action="store_true", help="whether to first perform VAD filtering and transcribe only within detected speech; produces more accurate alignment and timestamps, but requires more GPU memory & compute")
    parser.add_argument("--vad_input", default=None, type=str)
    # diarization params
    parser.add_argument("--diarize", action="store_true", help="whether to perform speaker diarization")
    parser.add_argument("--min_speakers", default=None, type=int, help="minimum number of speakers for diarization")
    parser.add_argument("--max_speakers", default=None, type=int, help="maximum number of speakers for diarization")
    # output save params
    parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
    parser.add_argument("--output_type", default="all", choices=["all", "srt", "srt-word", "vtt", "txt", "tsv", "ass", "ass-char"], help="file type(s) to save the output as")
    parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")
    parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')")
    parser.add_argument("--language", type=str, default=None, choices=sorted(LANGUAGES.keys()) + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]), help="language spoken in the audio; specify None to perform language detection")
    parser.add_argument("--temperature", type=float, default=0, help="temperature to use for sampling")
    parser.add_argument("--best_of", type=optional_int, default=5, help="number of candidates when sampling with non-zero temperature")
    parser.add_argument("--beam_size", type=optional_int, default=5, help="number of beams in beam search, only applicable when temperature is zero")
    parser.add_argument("--patience", type=float, default=None, help="optional patience value to use in beam decoding, as in https://arxiv.org/abs/2204.05424; the default (1.0) is equivalent to conventional beam search")
    parser.add_argument("--length_penalty", type=float, default=None, help="optional token length penalty coefficient (alpha) as in https://arxiv.org/abs/1609.08144; uses simple length normalization by default")
    parser.add_argument("--suppress_tokens", type=str, default="-1", help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations")
    parser.add_argument("--initial_prompt", type=str, default=None, help="optional text to provide as a prompt for the first window")
    parser.add_argument("--condition_on_previous_text", type=str2bool, default=False, help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop")
    parser.add_argument("--fp16", type=str2bool, default=True, help="whether to perform inference in fp16; True by default")
    parser.add_argument("--temperature_increment_on_fallback", type=optional_float, default=0.2, help="temperature to increase by when falling back because the decoding fails to meet either of the thresholds below")
    parser.add_argument("--compression_ratio_threshold", type=optional_float, default=2.4, help="if the gzip compression ratio is higher than this value, treat the decoding as failed")
    parser.add_argument("--logprob_threshold", type=optional_float, default=-1.0, help="if the average log probability is lower than this value, treat the decoding as failed")
    parser.add_argument("--no_speech_threshold", type=optional_float, default=0.6, help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence")
    parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supersedes MKL_NUM_THREADS/OMP_NUM_THREADS")
    args = parser.parse_args().__dict__
    model_name: str = args.pop("model")
    model_dir: str = args.pop("model_dir")
    output_dir: str = args.pop("output_dir")
    output_type: str = args.pop("output_type")
    device: str = args.pop("device")

    align_model: str = args.pop("align_model")
    align_extend: float = args.pop("align_extend")
    align_from_prev: bool = args.pop("align_from_prev")
    interpolate_method: str = args.pop("interpolate_method")

    vad_filter: bool = args.pop("vad_filter")
    vad_input: str = args.pop("vad_input")

    diarize: bool = args.pop("diarize")
    min_speakers: int = args.pop("min_speakers")
    max_speakers: int = args.pop("max_speakers")

    vad_pipeline = None
    if vad_input is not None:
        vad_input = pd.read_csv(vad_input, header=None, sep=" ")
    elif vad_filter:
        from pyannote.audio import Pipeline
        vad_pipeline = Pipeline.from_pretrained("pyannote/voice-activity-detection")

    diarize_pipeline = None
    if diarize:
        from pyannote.audio import Pipeline
        diarize_pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization@2.1")
2022-12-14 18:59:12 +00:00
os . makedirs ( output_dir , exist_ok = True )
if model_name . endswith ( " .en " ) and args [ " language " ] not in { " en " , " English " } :
if args [ " language " ] is not None :
2023-01-24 15:02:08 +00:00
warnings . warn ( f ' { model_name } is an English-only model but receipted " { args [ " language " ] } " ; using English instead. ' )
2022-12-14 18:59:12 +00:00
args [ " language " ] = " en "
temperature = args . pop ( " temperature " )
temperature_increment_on_fallback = args . pop ( " temperature_increment_on_fallback " )
if temperature_increment_on_fallback is not None :
temperature = tuple ( np . arange ( temperature , 1.0 + 1e-6 , temperature_increment_on_fallback ) )
else :
temperature = [ temperature ]
threads = args . pop ( " threads " )
if threads > 0 :
torch . set_num_threads ( threads )
from . import load_model
model = load_model ( model_name , device = device , download_root = model_dir )
2022-12-20 19:54:55 +00:00
align_language = args [ " language " ] if args [ " language " ] is not None else " en " # default to loading english if not specified
align_model , align_metadata = load_align_model ( align_language , device , model_name = align_model )
2022-12-14 18:59:12 +00:00
    for audio_path in args.pop("audio"):
        if vad_filter:
            print("Performing VAD...")
            result = transcribe_with_vad(model, audio_path, vad_pipeline, temperature=temperature, **args)
        else:
            print("Performing transcription...")
            result = transcribe(model, audio_path, temperature=temperature, **args)

        if result["language"] != align_metadata["language"]:
            # load a new alignment model for the new language
            print(f"New language found ({result['language']})! Previous was ({align_metadata['language']}), loading new alignment model for new language...")
            align_model, align_metadata = load_align_model(result["language"], device)

        print("Performing alignment...")
        result_aligned = align(result["segments"], align_model, align_metadata, audio_path, device,
                               extend_duration=align_extend, start_from_previous=align_from_prev, interpolate_method=interpolate_method)
        audio_basename = os.path.basename(audio_path)

        if diarize:
            print("Performing diarization...")
            diarize_segments = diarize_pipeline(audio_path, min_speakers=min_speakers, max_speakers=max_speakers)
            diarize_df = pd.DataFrame(diarize_segments.itertracks(yield_label=True))
            diarize_df["start"] = diarize_df[0].apply(lambda x: x.start)
            diarize_df["end"] = diarize_df[0].apply(lambda x: x.end)
            # assumes each utterance is single speaker (needs fix)
            result_segments, word_segments = assign_word_speakers(diarize_df, result_aligned["segments"], fill_nearest=True)
            result_aligned["segments"] = result_segments
            result_aligned["word_segments"] = word_segments
        # save TXT
        if output_type in ["txt", "all"]:
            with open(os.path.join(output_dir, audio_basename + ".txt"), "w", encoding="utf-8") as txt:
                write_txt(result_aligned["segments"], file=txt)

        # save VTT
        if output_type in ["vtt", "all"]:
            with open(os.path.join(output_dir, audio_basename + ".vtt"), "w", encoding="utf-8") as vtt:
                write_vtt(result_aligned["segments"], file=vtt)

        # save SRT
        if output_type in ["srt", "all"]:
            with open(os.path.join(output_dir, audio_basename + ".srt"), "w", encoding="utf-8") as srt:
                write_srt(result_aligned["segments"], file=srt)

        # save TSV
        if output_type in ["tsv", "all"]:
            with open(os.path.join(output_dir, audio_basename + ".tsv"), "w", encoding="utf-8") as tsv:
                write_tsv(result_aligned["segments"], file=tsv)

        # save word-level SRT
        if output_type in ["srt-word", "all"]:
            with open(os.path.join(output_dir, audio_basename + ".word.srt"), "w", encoding="utf-8") as srt:
                write_srt(result_aligned["word_segments"], file=srt)

        # save ASS
        if output_type in ["ass", "all"]:
            with open(os.path.join(output_dir, audio_basename + ".ass"), "w", encoding="utf-8") as ass:
                write_ass(result_aligned["segments"], file=ass)

        # save character-level ASS
        if output_type in ["ass-char", "all"]:
            with open(os.path.join(output_dir, audio_basename + ".char.ass"), "w", encoding="utf-8") as ass:
                write_ass(result_aligned["segments"], file=ass, resolution="char")


if __name__ == "__main__":
    cli()