From 93e568b3bf24b531b119914c93e04a5d3c62f05c Mon Sep 17 00:00:00 2001 From: Yasutaka Odo Date: Wed, 21 Dec 2022 02:11:08 +0900 Subject: [PATCH] replace magic strings --- whisperx/transcribe.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/whisperx/transcribe.py b/whisperx/transcribe.py index 3c5b0a7..7a0e401 100644 --- a/whisperx/transcribe.py +++ b/whisperx/transcribe.py @@ -17,7 +17,9 @@ from .utils import exact_div, format_timestamp, optional_int, optional_float, st if TYPE_CHECKING: from .model import Whisper -wa2vec2_models_on_hugginface = ["jonatasgrosman/wav2vec2-large-xlsr-53-japanese"] +hugginface_models = ["jonatasgrosman/wav2vec2-large-xlsr-53-japanese"] +asian_languages = ["ja"] + def transcribe( model: "Whisper", @@ -282,19 +284,18 @@ def align( f2 = int(t2 * SAMPLE_RATE) waveform_segment = audio[:, f1:f2] - print(language) with torch.inference_mode(): - if language != 'ja': + if language not in asian_languages: emissions, _ = model(waveform_segment.to(device)) else: emissions = model(waveform_segment.to(device)).logits emissions = torch.log_softmax(emissions, dim=-1) emission = emissions[0].cpu().detach() transcription = segment['text'].strip() - if language != "ja": + if language not in asian_languages: t_words = transcription.split(' ') else: - t_words = [c for c in transcription] #FIXME: ideally, we should use a tokenizer for Japanese to extract words + t_words = [c for c in transcription] t_words_clean = [''.join([w for w in word if w.lower() in model_dictionary.keys()]) for word in t_words] t_words_nonempty = [x for x in t_words_clean if x != ""] @@ -346,7 +347,7 @@ def align( if x == 0: t_words[x+1] = " ".join([curr_word, t_words[x+1]]) else: - word_segments_list[-1]['text'] += ' ' + curr_word + word_segments_list[-1]['text'] += ' ' + curr_word else: # then we resort back to original whisper timestamps # segment['start] and segment['end'] are unchanged @@ -434,7 +435,7 @@ def cli(): align_model = bundle.get_model().to(device) labels = bundle.get_labels() align_dictionary = {c.lower(): i for i, c in enumerate(labels)} - elif align_model in wa2vec2_models_on_hugginface: + elif align_model in hugginface_models: processor = AutoProcessor.from_pretrained(align_model) align_model = Wav2Vec2ForCTC.from_pretrained(align_model).to(device) align_model.to(device)