"""Transcribe speech with Whisper, then force-align the transcript to the audio
with a phoneme-level wav2vec 2.0 model to recover word-level timestamps."""

import argparse
import os
import warnings
from typing import List, Optional, Tuple, Union, Iterator, TYPE_CHECKING

import numpy as np
import torch
import torchaudio
import tqdm
from transformers import AutoProcessor, Wav2Vec2ForCTC

from .audio import SAMPLE_RATE, N_FRAMES, HOP_LENGTH, pad_or_trim, log_mel_spectrogram, load_audio
from .alignment import get_trellis, backtrack, merge_repeats, merge_words
from .decoding import DecodingOptions, DecodingResult
from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
from .utils import (
    exact_div,
    format_timestamp,
    optional_int,
    optional_float,
    str2bool,
    write_txt,
    write_vtt,
    write_srt,
    write_ass,
)

if TYPE_CHECKING:
    from .model import Whisper

LANGUAGES_WITHOUT_SPACES = ["ja", "zh"]

DEFAULT_ALIGN_MODELS_TORCH = {
    "en": "WAV2VEC2_ASR_BASE_960H",
    "fr": "VOXPOPULI_ASR_BASE_10K_FR",
    "de": "VOXPOPULI_ASR_BASE_10K_DE",
    "es": "VOXPOPULI_ASR_BASE_10K_ES",
    "it": "VOXPOPULI_ASR_BASE_10K_IT",
}

DEFAULT_ALIGN_MODELS_HF = {
    "ja": "jonatasgrosman/wav2vec2-large-xlsr-53-japanese",
    "zh": "jonatasgrosman/wav2vec2-large-xlsr-53-chinese-zh-cn",
    "nl": "jonatasgrosman/wav2vec2-large-xlsr-53-dutch",
    "uk": "Yehor/wav2vec2-xls-r-300m-uk-with-small-lm",
}


def transcribe(
    model: "Whisper",
    audio: Union[str, np.ndarray, torch.Tensor],
    *,
    verbose: Optional[bool] = None,
    temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
    compression_ratio_threshold: Optional[float] = 2.4,
    logprob_threshold: Optional[float] = -1.0,
    no_speech_threshold: Optional[float] = 0.6,
    condition_on_previous_text: bool = False,  # turned off by default due to errors it causes
    **decode_options,
):
"""
Transcribe an audio file using Whisper
Parameters
- - - - - - - - - -
model : Whisper
The Whisper model instance
audio : Union [ str , np . ndarray , torch . Tensor ]
The path to the audio file to open , or the audio waveform
verbose : bool
Whether to display the text being decoded to the console . If True , displays all the details ,
If False , displays minimal details . If None , does not display anything
temperature : Union [ float , Tuple [ float , . . . ] ]
Temperature for sampling . It can be a tuple of temperatures , which will be successfully used
upon failures according to either ` compression_ratio_threshold ` or ` logprob_threshold ` .
compression_ratio_threshold : float
If the gzip compression ratio is above this value , treat as failed
logprob_threshold : float
If the average log probability over sampled tokens is below this value , treat as failed
no_speech_threshold : float
If the no_speech probability is higher than this value AND the average log probability
over sampled tokens is below ` logprob_threshold ` , consider the segment as silent
condition_on_previous_text : bool
if True , the previous output of the model is provided as a prompt for the next window ;
disabling may make the text inconsistent across windows , but the model becomes less prone to
getting stuck in a failure loop , such as repetition looping or timestamps going out of sync .
decode_options : dict
Keyword arguments to construct ` DecodingOptions ` instances
Returns
- - - - - - -
A dictionary containing the resulting text ( " text " ) and segment - level details ( " segments " ) , and
the spoken language ( " language " ) , which is detected when ` decode_options [ " language " ] ` is None .
"""
    dtype = torch.float16 if decode_options.get("fp16", True) else torch.float32
    if model.device == torch.device("cpu"):
        if torch.cuda.is_available():
            warnings.warn("Performing inference on CPU when CUDA is available")
        if dtype == torch.float16:
            warnings.warn("FP16 is not supported on CPU; using FP32 instead")
            dtype = torch.float32

    if dtype == torch.float32:
        decode_options["fp16"] = False

    mel = log_mel_spectrogram(audio)

    if decode_options.get("language", None) is None:
        if not model.is_multilingual:
            decode_options["language"] = "en"
        else:
            if verbose:
                print("Detecting language using up to the first 30 seconds. Use `--language` to specify the language")
            segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype)
            _, probs = model.detect_language(segment)
            decode_options["language"] = max(probs, key=probs.get)
            if verbose is not None:
                print(f"Detected language: {LANGUAGES[decode_options['language']].title()}")

    language = decode_options["language"]
    task = decode_options.get("task", "transcribe")
    tokenizer = get_tokenizer(model.is_multilingual, language=language, task=task)

    def decode_with_fallback(segment: torch.Tensor) -> DecodingResult:
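        """Decode a single mel segment, retrying at successively higher temperatures
        until the result passes the compression-ratio and log-probability checks."""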
        temperatures = [temperature] if isinstance(temperature, (int, float)) else temperature
        decode_result = None

        for t in temperatures:
            kwargs = {**decode_options}
            if t > 0:
                # disable beam_size and patience when t > 0
                kwargs.pop("beam_size", None)
                kwargs.pop("patience", None)
            else:
                # disable best_of when t == 0
                kwargs.pop("best_of", None)

            options = DecodingOptions(**kwargs, temperature=t)
            decode_result = model.decode(segment, options)

            needs_fallback = False
            if compression_ratio_threshold is not None and decode_result.compression_ratio > compression_ratio_threshold:
                needs_fallback = True  # too repetitive
            if logprob_threshold is not None and decode_result.avg_logprob < logprob_threshold:
                needs_fallback = True  # average log probability is too low

            if not needs_fallback:
                break

        return decode_result

    seek = 0
    input_stride = exact_div(
        N_FRAMES, model.dims.n_audio_ctx
    )  # mel frames per output token: 2
    time_precision = (
        input_stride * HOP_LENGTH / SAMPLE_RATE
    )  # time per output token: 0.02 (seconds)
    all_tokens = []
    all_segments = []
    prompt_reset_since = 0

    initial_prompt = decode_options.pop("initial_prompt", None) or []
    if initial_prompt:
        initial_prompt = tokenizer.encode(" " + initial_prompt.strip())
        all_tokens.extend(initial_prompt)

    def add_segment(
        *, start: float, end: float, text_tokens: torch.Tensor, result: DecodingResult
    ):
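        """Append a decoded segment, with its timing and decoding statistics, to `all_segments`."""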
        text = tokenizer.decode([token for token in text_tokens if token < tokenizer.eot])
        if len(text.strip()) == 0:  # skip empty text output
            return

        all_segments.append(
            {
                "id": len(all_segments),
                "seek": seek,
                "start": start,
                "end": end,
                "text": text,
                "tokens": text_tokens.tolist(),
                "temperature": result.temperature,
                "avg_logprob": result.avg_logprob,
                "compression_ratio": result.compression_ratio,
                "no_speech_prob": result.no_speech_prob,
            }
        )
        if verbose:
            print(f"[{format_timestamp(start)} --> {format_timestamp(end)}] {text}")

    # show the progress bar when verbose is False (otherwise the transcribed text will be printed)
    num_frames = mel.shape[-1]
    previous_seek_value = seek

    with tqdm.tqdm(total=num_frames, unit="frames", disable=verbose is not False) as pbar:
        while seek < num_frames:
            timestamp_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
            segment = pad_or_trim(mel[:, seek:], N_FRAMES).to(model.device).to(dtype)
            segment_duration = segment.shape[-1] * HOP_LENGTH / SAMPLE_RATE

            decode_options["prompt"] = all_tokens[prompt_reset_since:]
            result: DecodingResult = decode_with_fallback(segment)
            tokens = torch.tensor(result.tokens)

            if no_speech_threshold is not None:
                # voice activity check: skip the segment when no speech is detected
                should_skip = result.no_speech_prob > no_speech_threshold
                if logprob_threshold is not None and result.avg_logprob > logprob_threshold:
                    # don't skip if the logprob is high enough, despite the no_speech_prob
                    should_skip = False

                if should_skip:
                    seek += segment.shape[-1]  # fast-forward to the next segment boundary
                    continue

            timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin)
            consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0].add_(1)
            if len(consecutive) > 0:  # if the output contains two consecutive timestamp tokens
                last_slice = 0
                for current_slice in consecutive:
                    sliced_tokens = tokens[last_slice:current_slice]
                    start_timestamp_position = (
                        sliced_tokens[0].item() - tokenizer.timestamp_begin
                    )
                    end_timestamp_position = (
                        sliced_tokens[-1].item() - tokenizer.timestamp_begin
                    )
                    add_segment(
                        start=timestamp_offset + start_timestamp_position * time_precision,
                        end=timestamp_offset + end_timestamp_position * time_precision,
                        text_tokens=sliced_tokens[1:-1],
                        result=result,
                    )
                    last_slice = current_slice
                last_timestamp_position = (
                    tokens[last_slice - 1].item() - tokenizer.timestamp_begin
                )
                seek += last_timestamp_position * input_stride
                all_tokens.extend(tokens[: last_slice + 1].tolist())
            else:
                duration = segment_duration
                timestamps = tokens[timestamp_tokens.nonzero().flatten()]
                if len(timestamps) > 0 and timestamps[-1].item() != tokenizer.timestamp_begin:
                    # no consecutive timestamps but it has a timestamp; use the last one.
                    # a single timestamp at the end means there is no speech after it.
                    last_timestamp_position = timestamps[-1].item() - tokenizer.timestamp_begin
                    duration = last_timestamp_position * time_precision

                add_segment(
                    start=timestamp_offset,
                    end=timestamp_offset + duration,
                    text_tokens=tokens,
                    result=result,
                )

                seek += segment.shape[-1]
                all_tokens.extend(tokens.tolist())

            if not condition_on_previous_text or result.temperature > 0.5:
                # do not feed the prompt tokens if a high temperature was used
                prompt_reset_since = len(all_tokens)

            # update progress bar
            pbar.update(min(num_frames, seek) - previous_seek_value)
            previous_seek_value = seek

    return dict(
        text=tokenizer.decode(all_tokens[len(initial_prompt):]),
        segments=all_segments,
        language=language,
    )


def align(
    transcript: Iterator[dict],
    model: torch.nn.Module,
    align_model_metadata: dict,
    audio: Union[str, np.ndarray, torch.Tensor],
    device: str,
    extend_duration: float = 0.0,
    start_from_previous: bool = True,
    drop_non_aligned_words: bool = False,
):
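    """
    Force-align a Whisper transcript to the audio with a phoneme-level wav2vec 2.0
    model, producing word-level timestamps.

    Parameters
    ----------
    transcript: Iterator[dict]
        Segment dicts produced by `transcribe` (each with "start", "end", "text")
    model: torch.nn.Module
        Alignment model returned by `load_align_model`
    align_model_metadata: dict
        The "language" / "dictionary" / "type" metadata returned by `load_align_model`
    audio: Union[str, np.ndarray, torch.Tensor]
        The path to the audio file, or the audio waveform
    device: str
        Device to run the alignment model on
    extend_duration: float
        Seconds by which to extend each segment on both sides before aligning
    start_from_previous: bool
        Clip each segment's start time to the end of the previous aligned segment
    drop_non_aligned_words: bool
        Drop words that could not be aligned instead of merging them into neighbours

    Returns
    -------
    A dictionary with the updated segments ("segments") and per-word timings ("word_segments").
    """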
print ( " Performing alignment... " )
if not torch . is_tensor ( audio ) :
if isinstance ( audio , str ) :
audio = load_audio ( audio )
audio = torch . from_numpy ( audio )
if len ( audio . shape ) == 1 :
audio = audio . unsqueeze ( 0 )
MAX_DURATION = audio . shape [ 1 ] / SAMPLE_RATE
2022-12-20 19:54:55 +00:00
model_dictionary = align_model_metadata [ ' dictionary ' ]
model_lang = align_model_metadata [ ' language ' ]
model_type = align_model_metadata [ ' type ' ]
2022-12-14 18:59:12 +00:00
prev_t2 = 0
2022-12-19 19:12:50 +00:00
word_segments_list = [ ]
2022-12-14 18:59:12 +00:00

    for idx, segment in enumerate(transcript):
        if int(segment["start"] * SAMPLE_RATE) >= audio.shape[1]:
            # the Whisper timestamp lies beyond the end of the audio (an upstream
            # Whisper quirk); skip this segment
            continue

        t1 = max(segment["start"] - extend_duration, 0)
        t2 = min(segment["end"] + extend_duration, MAX_DURATION)
        if start_from_previous and t1 < prev_t2:
            t1 = prev_t2

        f1 = int(t1 * SAMPLE_RATE)
        f2 = int(t2 * SAMPLE_RATE)

        waveform_segment = audio[:, f1:f2]

        with torch.inference_mode():
            if model_type == "torchaudio":
                emissions, _ = model(waveform_segment.to(device))
            elif model_type == "huggingface":
                emissions = model(waveform_segment.to(device)).logits
            else:
                raise NotImplementedError(f"Align model of type {model_type} not supported.")
            emissions = torch.log_softmax(emissions, dim=-1)
        emission = emissions[0].cpu().detach()

        transcription = segment["text"].strip()
        if model_lang not in LANGUAGES_WITHOUT_SPACES:
            t_words = transcription.split(" ")
        else:
            # languages such as Japanese and Chinese have no word separators; align per character
            t_words = [c for c in transcription]

        t_words_clean = ["".join([w for w in word if w.lower() in model_dictionary.keys()]) for word in t_words]
        t_words_nonempty = [x for x in t_words_clean if x != ""]
        t_words_nonempty_idx = [x for x in range(len(t_words_clean)) if t_words_clean[x] != ""]

        segment["word-level"] = []

        if len(t_words_nonempty) > 0:
            transcription_cleaned = "|".join(t_words_nonempty).lower()
            tokens = [model_dictionary[c] for c in transcription_cleaned]
            trellis = get_trellis(emission, tokens)
            path = backtrack(trellis, emission, tokens)
            segments = merge_repeats(path, transcription_cleaned)
            word_segments = merge_words(segments)
            # waveform_segment has shape (1, num_samples), so this is
            # 1 / (num_trellis_frames - 1): the fraction of the segment per trellis frame
            ratio = waveform_segment.size(0) / (trellis.size(0) - 1)

            duration = t2 - t1
            local = []
            t_local = [None] * len(t_words)
            for wdx, word in enumerate(word_segments):
                t1_ = ratio * word.start
                t2_ = ratio * word.end
                local.append((t1_, t2_))
                t_local[t_words_nonempty_idx[wdx]] = (t1_ * duration + t1, t2_ * duration + t1)
            t1_actual = t1 + local[0][0] * duration
            t2_actual = t1 + local[-1][1] * duration

            segment["start"] = t1_actual
            segment["end"] = t2_actual
            prev_t2 = segment["end"]

            # for the .ass output
            for x in range(len(t_local)):
                curr_word = t_words[x]
                curr_timestamp = t_local[x]
                if curr_timestamp is not None:
                    segment["word-level"].append({"text": curr_word, "start": curr_timestamp[0], "end": curr_timestamp[1]})
                else:
                    segment["word-level"].append({"text": curr_word, "start": None, "end": None})

            # for the per-word .srt output:
            # merge missing words into the previous word, or into the next word when x == 0
            for x in range(len(t_local)):
                curr_word = t_words[x]
                curr_timestamp = t_local[x]
                if curr_timestamp is not None:
                    word_segments_list.append({"text": curr_word, "start": curr_timestamp[0], "end": curr_timestamp[1]})
                elif not drop_non_aligned_words:
                    # then we merge
                    if x == 0:
                        t_words[x + 1] = " ".join([curr_word, t_words[x + 1]])
                    else:
                        word_segments_list[-1]["text"] += " " + curr_word
        else:
            # fall back to the original Whisper timestamps;
            # segment["start"] and segment["end"] are unchanged
            prev_t2 = 0
            segment["word-level"].append({"text": segment["text"], "start": segment["start"], "end": segment["end"]})
            word_segments_list.append({"text": segment["text"], "start": segment["start"], "end": segment["end"]})

        print(f"[{format_timestamp(segment['start'])} --> {format_timestamp(segment['end'])}] {segment['text']}")

    return {"segments": transcript, "word_segments": word_segments_list}


def load_align_model(language_code, device, model_name=None):
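    """
    Load a wav2vec 2.0 alignment model for the given language, either from the
    torchaudio pipelines or from Hugging Face. Returns the model together with a
    metadata dict ("language", "dictionary", "type") as expected by `align`.
    """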
    if model_name is None:
        # use the default model for this language
        if language_code in DEFAULT_ALIGN_MODELS_TORCH:
            model_name = DEFAULT_ALIGN_MODELS_TORCH[language_code]
        elif language_code in DEFAULT_ALIGN_MODELS_HF:
            model_name = DEFAULT_ALIGN_MODELS_HF[language_code]
        else:
            print(f"There is no default alignment model set for this language ({language_code}). \
                Please find a wav2vec 2.0 model finetuned on this language at https://huggingface.co/models and pass its name via --align_model [MODEL_NAME]")
            raise ValueError(f"No default align-model for language: {language_code}")

    if model_name in torchaudio.pipelines.__all__:
        pipeline_type = "torchaudio"
        bundle = torchaudio.pipelines.__dict__[model_name]
        align_model = bundle.get_model().to(device)
        labels = bundle.get_labels()
        align_dictionary = {c.lower(): i for i, c in enumerate(labels)}
    else:
        try:
            processor = AutoProcessor.from_pretrained(model_name)
            align_model = Wav2Vec2ForCTC.from_pretrained(model_name)
        except Exception as e:
            print(e)
            print("Error loading model from Hugging Face; check https://huggingface.co/models for finetuned wav2vec 2.0 models")
            raise ValueError(f'The chosen align_model "{model_name}" could not be found on Hugging Face (https://huggingface.co/models) or in torchaudio (https://pytorch.org/audio/stable/pipelines.html#id14)')
        pipeline_type = "huggingface"
        align_model = align_model.to(device)
        labels = processor.tokenizer.get_vocab()
        align_dictionary = {char.lower(): code for char, code in labels.items()}

    align_metadata = {"language": language_code, "dictionary": align_dictionary, "type": pipeline_type}

    return align_model, align_metadata
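

# A minimal end-to-end sketch of how the pieces above fit together (assumes this
# module ships in the whisperx package; the file name and device are placeholders):
#
#     from whisperx import load_model
#     model = load_model("small", device="cuda")
#     result = transcribe(model, "audio.mp3")
#     align_model, align_metadata = load_align_model(result["language"], "cuda")
#     result_aligned = align(result["segments"], align_model, align_metadata,
#                            "audio.mp3", "cuda")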


def cli():
    from . import available_models

    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe")
    parser.add_argument("--model", default="small", choices=available_models(), help="name of the Whisper model to use")
    parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
    parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
    # alignment params
    parser.add_argument("--align_model", default=None, help="name of the phoneme-level ASR model to use for alignment")
    parser.add_argument("--align_extend", default=2, type=float, help="seconds before and after to extend the Whisper segments for alignment")
    parser.add_argument("--align_from_prev", default=True, type=str2bool, help="whether to clip the alignment start time of the current segment to the end time of the last aligned word of the previous segment")
    parser.add_argument("--drop_non_aligned", action="store_true", help="for the word-level .srt output, whether to drop non-aligned words instead of merging them into neighbouring words")
    parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
    parser.add_argument("--output_type", default="srt", choices=["all", "srt", "vtt", "txt"], help="file type to save the output as")
    parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")
    parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')")
    parser.add_argument("--language", type=str, default=None, choices=sorted(LANGUAGES.keys()) + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]), help="language spoken in the audio, specify None to perform language detection")
    parser.add_argument("--temperature", type=float, default=0, help="temperature to use for sampling")
    parser.add_argument("--best_of", type=optional_int, default=5, help="number of candidates when sampling with non-zero temperature")
    parser.add_argument("--beam_size", type=optional_int, default=5, help="number of beams in beam search, only applicable when temperature is zero")
    parser.add_argument("--patience", type=float, default=None, help="optional patience value to use in beam decoding, as in https://arxiv.org/abs/2204.05424, the default (1.0) is equivalent to conventional beam search")
    parser.add_argument("--length_penalty", type=float, default=None, help="optional token length penalty coefficient (alpha) as in https://arxiv.org/abs/1609.08144, uses simple length normalization by default")
    parser.add_argument("--suppress_tokens", type=str, default="-1", help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuation")
    parser.add_argument("--initial_prompt", type=str, default=None, help="optional text to provide as a prompt for the first window")
    parser.add_argument("--condition_on_previous_text", type=str2bool, default=False, help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop")
    parser.add_argument("--fp16", type=str2bool, default=True, help="whether to perform inference in fp16; True by default")
    parser.add_argument("--temperature_increment_on_fallback", type=optional_float, default=0.2, help="temperature to increase when falling back when the decoding fails to meet either of the thresholds below")
    parser.add_argument("--compression_ratio_threshold", type=optional_float, default=2.4, help="if the gzip compression ratio is higher than this value, treat the decoding as failed")
    parser.add_argument("--logprob_threshold", type=optional_float, default=-1.0, help="if the average log probability is lower than this value, treat the decoding as failed")
    parser.add_argument("--no_speech_threshold", type=optional_float, default=0.6, help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence")
    parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supersedes MKL_NUM_THREADS/OMP_NUM_THREADS")

    args = parser.parse_args().__dict__
    model_name: str = args.pop("model")
    model_dir: str = args.pop("model_dir")
    output_dir: str = args.pop("output_dir")
    output_type: str = args.pop("output_type")
    device: str = args.pop("device")

    align_model: str = args.pop("align_model")
    align_extend: float = args.pop("align_extend")
    align_from_prev: bool = args.pop("align_from_prev")
    drop_non_aligned: bool = args.pop("drop_non_aligned")

    os.makedirs(output_dir, exist_ok=True)

    if model_name.endswith(".en") and args["language"] not in {"en", "English"}:
        if args["language"] is not None:
            warnings.warn(f"{model_name} is an English-only model but received '{args['language']}'; using English instead.")
        args["language"] = "en"

    temperature = args.pop("temperature")
    temperature_increment_on_fallback = args.pop("temperature_increment_on_fallback")
    if temperature_increment_on_fallback is not None:
        temperature = tuple(np.arange(temperature, 1.0 + 1e-6, temperature_increment_on_fallback))
    else:
        temperature = [temperature]

    threads = args.pop("threads")
    if threads > 0:
        torch.set_num_threads(threads)

    from . import load_model
    model = load_model(model_name, device=device, download_root=model_dir)

    align_language = args["language"] if args["language"] is not None else "en"  # default to English if the language is not specified
    align_model, align_metadata = load_align_model(align_language, device, model_name=align_model)
2022-12-14 18:59:12 +00:00
for audio_path in args . pop ( " audio " ) :
result = transcribe ( model , audio_path , temperature = temperature , * * args )
2022-12-20 19:54:55 +00:00
if result [ " language " ] != align_metadata [ " language " ] :
# load new language
print ( f " New language found ( { result [ ' language ' ] } )! Previous was ( { align_metadata [ ' language ' ] } ), loading new alignment model for new language... " )
align_model , align_metadata = load_align_model ( result [ " language " ] , device )
2022-12-23 00:41:12 +00:00
result_aligned = align ( result [ " segments " ] , align_model , align_metadata , audio_path , device ,
2022-12-21 01:20:35 +09:00
extend_duration = align_extend , start_from_previous = align_from_prev , drop_non_aligned_words = drop_non_aligned )
2022-12-14 18:59:12 +00:00
audio_basename = os . path . basename ( audio_path )
# save TXT
if output_type in [ " txt " , " all " ] :
with open ( os . path . join ( output_dir , audio_basename + " .txt " ) , " w " , encoding = " utf-8 " ) as txt :
write_txt ( result_aligned [ " segments " ] , file = txt )
# save VTT
if output_type in [ " vtt " , " all " ] :
with open ( os . path . join ( output_dir , audio_basename + " .vtt " ) , " w " , encoding = " utf-8 " ) as vtt :
write_vtt ( result_aligned [ " segments " ] , file = vtt )
# save SRT
if output_type in [ " srt " , " all " ] :
with open ( os . path . join ( output_dir , audio_basename + " .srt " ) , " w " , encoding = " utf-8 " ) as srt :
write_srt ( result_aligned [ " segments " ] , file = srt )
2022-12-17 17:24:48 +00:00
2022-12-19 19:12:50 +00:00
# save per-word SRT
with open ( os . path . join ( output_dir , audio_basename + " .word.srt " ) , " w " , encoding = " utf-8 " ) as srt :
write_srt ( result_aligned [ " word_segments " ] , file = srt )
2022-12-17 17:24:48 +00:00
# save ASS
2022-12-23 00:41:12 +00:00
with open ( os . path . join ( output_dir , audio_basename + " .ass " ) , " w " , encoding = " utf-8 " ) as ass :
write_ass ( result_aligned [ " segments " ] , file = ass )


if __name__ == "__main__":
    cli()