add preliminary Japanese support

commit 50cda426ba
parent 6b64cb079a
Author: Yasutaka Odo
Date:   2022-12-19 22:28:28 +09:00


@@ -1,21 +1,26 @@
 import argparse
 import os
 import warnings
-from typing import List, Optional, Tuple, Union, Iterator, TYPE_CHECKING
+from typing import TYPE_CHECKING, Iterator, List, Optional, Tuple, Union
 
 import numpy as np
 import torch
 import torchaudio
 import tqdm
-
-from .audio import SAMPLE_RATE, N_FRAMES, HOP_LENGTH, pad_or_trim, log_mel_spectrogram, load_audio
-from .alignment import get_trellis, backtrack, merge_repeats, merge_words
+from transformers import AutoProcessor, Wav2Vec2ForCTC
+
+from .alignment import backtrack, get_trellis, merge_repeats, merge_words
+from .audio import (HOP_LENGTH, N_FRAMES, SAMPLE_RATE, load_audio,
+                    log_mel_spectrogram, pad_or_trim)
 from .decoding import DecodingOptions, DecodingResult
 from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
-from .utils import exact_div, format_timestamp, optional_int, optional_float, str2bool, write_txt, write_vtt, write_srt, write_ass
+from .utils import (exact_div, format_timestamp, optional_float, optional_int,
+                    str2bool, write_ass, write_srt, write_txt, write_vtt)
 
 if TYPE_CHECKING:
     from .model import Whisper
 
+wa2vec2_on_hugginface = ["wav2vec2-large-xlsr-53-japanese"]
+
 def transcribe(
     model: "Whisper",
@@ -249,6 +254,7 @@ def transcribe(
 
 def align(
     transcript: Iterator[dict],
+    language: str,
     model: torch.nn.Module,
     model_dictionary: dict,
     audio: Union[str, np.ndarray, torch.Tensor],
@@ -278,12 +284,15 @@ def align(
         waveform_segment = audio[:, f1:f2]
 
         with torch.inference_mode():
-            emissions, _ = model(waveform_segment.to(device))
+            emissions = model(waveform_segment.to(device)).logits
             emissions = torch.log_softmax(emissions, dim=-1)
         emission = emissions[0].cpu().detach()
 
         transcription = segment['text'].strip()
-        t_words = transcription.split(' ')
+        if language != "ja":
+            t_words = transcription.split(' ')
+        else:
+            t_words = [c for c in transcription] #FIXME: ideally, we should use a tokenizer for Japanese to extract words
         t_words_clean = [''.join([w for w in word if w.lower() in model_dictionary.keys()]) for word in t_words]
         t_words_nonempty = [x for x in t_words_clean if x != ""]
         t_words_nonempty_idx = [x for x in range(len(t_words_clean)) if t_words_clean[x] != ""]
@@ -408,12 +417,18 @@ def cli():
         align_model = bundle.get_model().to(device)
         labels = bundle.get_labels()
         align_dictionary = {c.lower(): i for i, c in enumerate(labels)}
+    elif align_model == "wav2vec2-large-xlsr-53-japanese":
+        processor = AutoProcessor.from_pretrained("jonatasgrosman/wav2vec2-large-xlsr-53-japanese")
+        align_model = Wav2Vec2ForCTC.from_pretrained("jonatasgrosman/wav2vec2-large-xlsr-53-japanese")
+        align_model.to(device)
+        labels = processor.tokenizer.get_vocab()
+        align_dictionary = processor.tokenizer.get_vocab()
     else:
         print(f'Align model "{align_model}" not found in torchaudio.pipelines, choose from:\n {torchaudio.pipelines.__all__}')
         raise ValueError(f'Align model "{align_model}" not found in torchaudio.pipelines')
     for audio_path in args.pop("audio"):
         result = transcribe(model, audio_path, temperature=temperature, **args)
-        result_aligned = align(result["segments"], align_model, align_dictionary, audio_path, device,
+        result_aligned = align(result["segments"], result["language"], align_model, align_dictionary, audio_path, device,
                                extend_duration=align_extend, start_from_previous=align_from_prev)
         audio_basename = os.path.basename(audio_path)
 