mirror of https://github.com/m-bain/whisperX.git, synced 2025-07-01 18:17:27 -04:00
add preliminary japanese support
@@ -1,21 +1,26 @@
 import argparse
 import os
 import warnings
-from typing import List, Optional, Tuple, Union, Iterator, TYPE_CHECKING
+from typing import TYPE_CHECKING, Iterator, List, Optional, Tuple, Union
 
 import numpy as np
 import torch
 import torchaudio
 import tqdm
+from transformers import AutoProcessor, Wav2Vec2ForCTC
 
-from .audio import SAMPLE_RATE, N_FRAMES, HOP_LENGTH, pad_or_trim, log_mel_spectrogram, load_audio
-from .alignment import get_trellis, backtrack, merge_repeats, merge_words
+from .alignment import backtrack, get_trellis, merge_repeats, merge_words
+from .audio import (HOP_LENGTH, N_FRAMES, SAMPLE_RATE, load_audio,
+                    log_mel_spectrogram, pad_or_trim)
 from .decoding import DecodingOptions, DecodingResult
 from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
-from .utils import exact_div, format_timestamp, optional_int, optional_float, str2bool, write_txt, write_vtt, write_srt, write_ass
+from .utils import (exact_div, format_timestamp, optional_float, optional_int,
+                    str2bool, write_ass, write_srt, write_txt, write_vtt)
 
 if TYPE_CHECKING:
     from .model import Whisper
 
+wa2vec2_on_hugginface = ["wav2vec2-large-xlsr-53-japanese"]
+
 def transcribe(
     model: "Whisper",
@@ -249,6 +254,7 @@ def transcribe(
 
 def align(
     transcript: Iterator[dict],
+    language: str,
     model: torch.nn.Module,
     model_dictionary: dict,
     audio: Union[str, np.ndarray, torch.Tensor],
@@ -278,12 +284,15 @@ def align(
         waveform_segment = audio[:, f1:f2]
         with torch.inference_mode():
-            emissions, _ = model(waveform_segment.to(device))
+            emissions = model(waveform_segment.to(device)).logits
             emissions = torch.log_softmax(emissions, dim=-1)
         emission = emissions[0].cpu().detach()
 
         transcription = segment['text'].strip()
-        t_words = transcription.split(' ')
+        if language != "ja":
+            t_words = transcription.split(' ')
+        else:
+            t_words = [c for c in transcription] #FIXME: ideally, we should use a tokenizer for Japanese to extract words
 
         t_words_clean = [''.join([w for w in word if w.lower() in model_dictionary.keys()]) for word in t_words]
         t_words_nonempty = [x for x in t_words_clean if x != ""]
         t_words_nonempty_idx = [x for x in range(len(t_words_clean)) if t_words_clean[x] != ""]
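
The character-level fallback above is the core of the preliminary Japanese support: Japanese text carries no whitespace, so splitting on ' ' would hand the aligner one giant "word". A minimal sketch of the new branching follows; the commented-out tokenizer line is only one possible answer to the FIXME (fugashi is an assumption, not part of this commit):

    def split_words(transcription: str, language: str) -> list:
        # Whitespace-delimited languages keep the original behaviour
        if language != "ja":
            return transcription.split(' ')
        # Japanese has no spaces, so fall back to one unit per character
        return [c for c in transcription]

    print(split_words("hello world", "en"))    # ['hello', 'world']
    print(split_words("こんにちは世界", "ja"))  # ['こ', 'ん', 'に', 'ち', 'は', '世', '界']

    # One way to address the FIXME, assuming the fugashi MeCab wrapper is installed:
    # from fugashi import Tagger
    # t_words = [w.surface for w in Tagger()(transcription)]  # ['こんにちは', '世界']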
@@ -408,13 +417,19 @@ def cli():
         align_model = bundle.get_model().to(device)
         labels = bundle.get_labels()
         align_dictionary = {c.lower(): i for i, c in enumerate(labels)}
+    elif align_model == "wav2vec2-large-xlsr-53-japanese":
+        processor = AutoProcessor.from_pretrained("jonatasgrosman/wav2vec2-large-xlsr-53-japanese")
+        align_model = Wav2Vec2ForCTC.from_pretrained("jonatasgrosman/wav2vec2-large-xlsr-53-japanese")
+        align_model.to(device)
+        labels = processor.tokenizer.get_vocab()
+        align_dictionary = processor.tokenizer.get_vocab()
     else:
         print(f'Align model "{align_model}" not found in torchaudio.pipelines, choose from:\n {torchaudio.pipelines.__all__}')
         raise ValueError(f'Align model "{align_model}" not found in torchaudio.pipelines')
     for audio_path in args.pop("audio"):
         result = transcribe(model, audio_path, temperature=temperature, **args)
-        result_aligned = align(result["segments"], align_model, align_dictionary, audio_path, device,
+        result_aligned = align(result["segments"], result["language"], align_model, align_dictionary, audio_path, device,
                                extend_duration=align_extend, start_from_previous=align_from_prev)
         audio_basename = os.path.basename(audio_path)
 
         # save TXT
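
Supporting both back ends is also what motivates the emissions change in align(): a torchaudio pipeline model's forward() returns an (emissions, lengths) tuple, while a Hugging Face Wav2Vec2ForCTC returns a ModelOutput whose frame scores live on .logits. A minimal sketch of the two conventions, assuming network access to download the checkpoints and using a dummy one-second waveform for illustration:

    import torch
    import torchaudio
    from transformers import Wav2Vec2ForCTC

    device = "cuda" if torch.cuda.is_available() else "cpu"
    waveform = torch.zeros(1, 16000)  # dummy input: one second of silence at 16 kHz

    # torchaudio pipeline backend: forward() returns an (emissions, lengths) tuple
    bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H
    ta_model = bundle.get_model().to(device)
    with torch.inference_mode():
        emissions, _ = ta_model(waveform.to(device))

    # Hugging Face backend: forward() returns a ModelOutput; take .logits instead
    hf_model = Wav2Vec2ForCTC.from_pretrained("jonatasgrosman/wav2vec2-large-xlsr-53-japanese").to(device)
    with torch.inference_mode():
        emissions = hf_model(waveform.to(device)).logits

    # Either way, align() then log-normalizes before building the trellis
    emissions = torch.log_softmax(emissions, dim=-1)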