Fix the error message shown for unsupported align models

This commit is contained in:
Yasutaka Odo
2022-12-20 22:29:18 +09:00
parent f00e9cb149
commit 5d7c3b521c
2 changed files with 13 additions and 8 deletions

View File

@@ -20,7 +20,7 @@ from .utils import (exact_div, format_timestamp, optional_float, optional_int,
if TYPE_CHECKING:
from .model import Whisper
wa2vec2_on_hugginface = ["wav2vec2-large-xlsr-53-japanese"]
wa2vec2_models_on_hugginface = ["jonatasgrosman/wav2vec2-large-xlsr-53-japanese"]
def transcribe(
model: "Whisper",
@@ -320,7 +320,7 @@ def align(
segment['start'] = t1_actual
segment['end'] = t2_actual
prev_t2 = segment['end']
prev_t2 = segment['end']
# merge missing words to previous, or merge with next word ahead if idx == 0
@@ -417,19 +417,19 @@ def cli():
align_model = bundle.get_model().to(device)
labels = bundle.get_labels()
align_dictionary = {c.lower(): i for i, c in enumerate(labels)}
elif align_model == "wav2vec2-large-xlsr-53-japanese":
processor = AutoProcessor.from_pretrained("jonatasgrosman/wav2vec2-large-xlsr-53-japanese")
align_model = Wav2Vec2ForCTC.from_pretrained("jonatasgrosman/wav2vec2-large-xlsr-53-japanese")
elif align_model in wa2vec2_models_on_hugginface:
processor = AutoProcessor.from_pretrained(align_model)
align_model = Wav2Vec2ForCTC.from_pretrained(align_model).to(device)
align_model.to(device)
labels = processor.tokenizer.get_vocab()
align_dictionary = processor.tokenizer.get_vocab()
else:
print(f'Align model "{align_model}" not found in torchaudio.pipelines, choose from:\n {torchaudio.pipelines.__all__}')
raise ValueError(f'Align model "{align_model}" not found in torchaudio.pipelines')
print(f'Align model "{align_model}" is not supported, choose from:\n {torchaudio.pipelines.__all__ + wa2vec2_models_on_hugginface}')
raise ValueError(f'Align model "{align_model}" not supported')
for audio_path in args.pop("audio"):
result = transcribe(model, audio_path, temperature=temperature, **args)
result_aligned = align(result["segments"], result["language"], align_model, align_dictionary, audio_path, device,
extend_duration=align_extend, start_from_previous=align_from_prev)
extend_duration=align_extend, start_from_previous=align_from_prev)
audio_basename = os.path.basename(audio_path)
# save TXT