multilingual init

This commit is contained in:
Max Bain
2022-12-18 12:21:24 +00:00
parent 59a390d868
commit 45e9509227
16 changed files with 973 additions and 17 deletions

View File

@@ -284,14 +284,13 @@ def align(
transcription = segment['text'].strip()
t_words = transcription.split(' ')
t_words_clean = [''.join([w for w in word if w.upper() in model_dictionary.keys()]) for word in t_words]
t_words_clean = [''.join([w for w in word if w in model_dictionary.keys()]) for word in t_words]
t_words_clean = [''.join([w for w in word if w.lower() in model_dictionary.keys()]) for word in t_words]
t_words_nonempty = [x for x in t_words_clean if x != ""]
t_words_nonempty_idx = [x for x in range(len(t_words_clean)) if t_words_clean[x] != ""]
segment['word-level'] = []
if len(t_words_nonempty) > 0:
transcription_cleaned = "|".join(t_words_nonempty).upper()
transcription_cleaned = "|".join(t_words_nonempty).lower()
tokens = [model_dictionary[c] for c in transcription_cleaned]
trellis = get_trellis(emission, tokens)
path = backtrack(trellis, emission, tokens)
@@ -404,12 +403,14 @@ def cli():
from . import load_model
model = load_model(model_name, device=device, download_root=model_dir)
bundle = torchaudio.pipelines.WAV2VEC2_ASR_LARGE_LV60K_960H
align_model = bundle.get_model().to(device)
labels = bundle.get_labels()
align_dictionary = {c: i for i, c in enumerate(labels)}
if align_model in torchaudio.pipelines.__all__:
bundle = torchaudio.pipelines.__dict__[align_model]
align_model = bundle.get_model().to(device)
labels = bundle.get_labels()
align_dictionary = {c.lower(): i for i, c in enumerate(labels)}
else:
print(f'Align model "{align_model}" not found in torchaudio.pipelines, choose from:\n {torchaudio.pipelines.__all__}')
raise ValueError(f'Align model "{align_model}" not found in torchaudio.pipelines')
for audio_path in args.pop("audio"):
result = transcribe(model, audio_path, temperature=temperature, **args)
result_aligned = align(result["segments"], align_model, align_dictionary, audio_path, device,