From 50cda426ba980a53df690b92a1b1c6c812fccca6 Mon Sep 17 00:00:00 2001
From: Yasutaka Odo
Date: Mon, 19 Dec 2022 22:28:28 +0900
Subject: [PATCH] add preliminary japanese support

---
 whisperx/transcribe.py | 33 ++++++++++++++++++++++++---------
 1 file changed, 24 insertions(+), 9 deletions(-)

diff --git a/whisperx/transcribe.py b/whisperx/transcribe.py
index c915aca..3e699a5 100644
--- a/whisperx/transcribe.py
+++ b/whisperx/transcribe.py
@@ -1,21 +1,26 @@
 import argparse
 import os
 import warnings
-from typing import List, Optional, Tuple, Union, Iterator, TYPE_CHECKING
+from typing import TYPE_CHECKING, Iterator, List, Optional, Tuple, Union
 
 import numpy as np
 import torch
 import torchaudio
 import tqdm
 
-from .audio import SAMPLE_RATE, N_FRAMES, HOP_LENGTH, pad_or_trim, log_mel_spectrogram, load_audio
-from .alignment import get_trellis, backtrack, merge_repeats, merge_words
+from transformers import AutoProcessor, Wav2Vec2ForCTC
+
+from .alignment import backtrack, get_trellis, merge_repeats, merge_words
+from .audio import (HOP_LENGTH, N_FRAMES, SAMPLE_RATE, load_audio,
+                    log_mel_spectrogram, pad_or_trim)
 from .decoding import DecodingOptions, DecodingResult
 from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
-from .utils import exact_div, format_timestamp, optional_int, optional_float, str2bool, write_txt, write_vtt, write_srt, write_ass
+from .utils import (exact_div, format_timestamp, optional_float, optional_int,
+                    str2bool, write_ass, write_srt, write_txt, write_vtt)
 
 if TYPE_CHECKING:
     from .model import Whisper
 
+wa2vec2_on_hugginface = ["wav2vec2-large-xlsr-53-japanese"]
 def transcribe(
     model: "Whisper",
@@ -249,6 +254,7 @@ def transcribe(
 
 def align(
     transcript: Iterator[dict],
+    language: str,
     model: torch.nn.Module,
     model_dictionary: dict,
     audio: Union[str, np.ndarray, torch.Tensor],
@@ -278,12 +284,15 @@ def align(
         waveform_segment = audio[:, f1:f2]
 
         with torch.inference_mode():
-            emissions, _ = model(waveform_segment.to(device))
+            emissions = model(waveform_segment.to(device)).logits
             emissions = torch.log_softmax(emissions, dim=-1)
         emission = emissions[0].cpu().detach()
-        transcription = segment['text'].strip()
-        t_words = transcription.split(' ')
+        transcription = segment['text'].strip()
+        if language != "ja":
+            t_words = transcription.split(' ')
+        else:
+            t_words = [c for c in transcription] #FIXME: ideally, we should use a tokenizer for Japanese to extract words
 
         t_words_clean = [''.join([w for w in word if w.lower() in model_dictionary.keys()]) for word in t_words]
         t_words_nonempty = [x for x in t_words_clean if x != ""]
         t_words_nonempty_idx = [x for x in range(len(t_words_clean)) if t_words_clean[x] != ""]
@@ -408,13 +417,19 @@ def cli():
         align_model = bundle.get_model().to(device)
         labels = bundle.get_labels()
         align_dictionary = {c.lower(): i for i, c in enumerate(labels)}
+    elif align_model == "wav2vec2-large-xlsr-53-japanese":
+        processor = AutoProcessor.from_pretrained("jonatasgrosman/wav2vec2-large-xlsr-53-japanese")
+        align_model = Wav2Vec2ForCTC.from_pretrained("jonatasgrosman/wav2vec2-large-xlsr-53-japanese")
+        align_model.to(device)
+        labels = processor.tokenizer.get_vocab()
+        align_dictionary = processor.tokenizer.get_vocab()
     else:
         print(f'Align model "{align_model}" not found in torchaudio.pipelines, choose from:\n {torchaudio.pipelines.__all__}')
         raise ValueError(f'Align model "{align_model}" not found in torchaudio.pipelines')
 
     for audio_path in args.pop("audio"):
         result = transcribe(model, audio_path, temperature=temperature, **args)
-        result_aligned = align(result["segments"], align_model, align_dictionary, audio_path, device,
-                               extend_duration=align_extend, start_from_previous=align_from_prev)
+        result_aligned = align(result["segments"], result["language"], align_model, align_dictionary, audio_path, device,
+                               extend_duration=align_extend, start_from_previous=align_from_prev)
         audio_basename = os.path.basename(audio_path)
         # save TXT
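
Note on the emissions hunk above: torchaudio pipeline models return an (emissions, lengths) tuple from forward(), while HuggingFace's Wav2Vec2ForCTC returns an output object that carries the emissions in its .logits attribute. The patch switches align() to .logits unconditionally, which matches the new HuggingFace aligner; if the torchaudio bundles stay supported, a backend check is needed. A minimal sketch of that idea (the get_emissions helper and is_hf_model flag are illustrative, not part of this patch):

import torch

def get_emissions(model, waveform_segment, device, is_hf_model):
    # Frame-level log-probabilities from either aligner backend.
    with torch.inference_mode():
        if is_hf_model:
            # HuggingFace Wav2Vec2ForCTC: emissions live in .logits
            emissions = model(waveform_segment.to(device)).logits
        else:
            # torchaudio pipeline model: forward() returns (emissions, lengths)
            emissions, _ = model(waveform_segment.to(device))
        emissions = torch.log_softmax(emissions, dim=-1)
    return emissions[0].cpu().detach()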
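
On the FIXME in align(): Japanese is written without spaces, so str.split(' ') cannot segment it, and the patch falls back to one pseudo-word per character. One way to resolve the FIXME later is a morphological analyzer such as fugashi (a MeCab wrapper); the sketch below assumes that extra dependency, which this patch does not add:

from fugashi import Tagger  # hypothetical extra dependency: pip install fugashi[unidic-lite]

def split_japanese(transcription: str) -> list:
    # Segment unspaced Japanese text into words with MeCab.
    tagger = Tagger()
    return [token.surface for token in tagger(transcription)]

# Character-level fallback the patch uses today:
# t_words = [c for c in transcription]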
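
Since align() now takes the detected language as its second positional argument, callers outside cli() must pass result["language"] through as well. A usage sketch mirroring the updated cli() (it assumes whisperx still exposes whisper's load_model, and the argument values are illustrative):

import whisperx
from transformers import AutoProcessor, Wav2Vec2ForCTC
from whisperx.transcribe import align, transcribe

processor = AutoProcessor.from_pretrained("jonatasgrosman/wav2vec2-large-xlsr-53-japanese")
align_model = Wav2Vec2ForCTC.from_pretrained("jonatasgrosman/wav2vec2-large-xlsr-53-japanese").to("cuda")
align_dictionary = processor.tokenizer.get_vocab()

model = whisperx.load_model("medium", device="cuda")
result = transcribe(model, "speech_ja.wav")
result_aligned = align(result["segments"], result["language"], align_model,
                       align_dictionary, "speech_ja.wav", "cuda",
                       extend_duration=2.0, start_from_previous=True)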