mirror of
https://github.com/m-bain/whisperX.git
synced 2025-07-01 18:17:27 -04:00
multilingual init
This commit is contained in:
@ -284,14 +284,13 @@ def align(
|
||||
|
||||
transcription = segment['text'].strip()
|
||||
t_words = transcription.split(' ')
|
||||
t_words_clean = [''.join([w for w in word if w.upper() in model_dictionary.keys()]) for word in t_words]
|
||||
t_words_clean = [''.join([w for w in word if w in model_dictionary.keys()]) for word in t_words]
|
||||
t_words_clean = [''.join([w for w in word if w.lower() in model_dictionary.keys()]) for word in t_words]
|
||||
t_words_nonempty = [x for x in t_words_clean if x != ""]
|
||||
t_words_nonempty_idx = [x for x in range(len(t_words_clean)) if t_words_clean[x] != ""]
|
||||
segment['word-level'] = []
|
||||
|
||||
if len(t_words_nonempty) > 0:
|
||||
transcription_cleaned = "|".join(t_words_nonempty).upper()
|
||||
transcription_cleaned = "|".join(t_words_nonempty).lower()
|
||||
tokens = [model_dictionary[c] for c in transcription_cleaned]
|
||||
trellis = get_trellis(emission, tokens)
|
||||
path = backtrack(trellis, emission, tokens)
|
||||
@ -404,12 +403,14 @@ def cli():
|
||||
|
||||
from . import load_model
|
||||
model = load_model(model_name, device=device, download_root=model_dir)
|
||||
|
||||
bundle = torchaudio.pipelines.WAV2VEC2_ASR_LARGE_LV60K_960H
|
||||
align_model = bundle.get_model().to(device)
|
||||
labels = bundle.get_labels()
|
||||
align_dictionary = {c: i for i, c in enumerate(labels)}
|
||||
|
||||
if align_model in torchaudio.pipelines.__all__:
|
||||
bundle = torchaudio.pipelines.__dict__[align_model]
|
||||
align_model = bundle.get_model().to(device)
|
||||
labels = bundle.get_labels()
|
||||
align_dictionary = {c.lower(): i for i, c in enumerate(labels)}
|
||||
else:
|
||||
print(f'Align model "{align_model}" not found in torchaudio.pipelines, choose from:\n {torchaudio.pipelines.__all__}')
|
||||
raise ValueError(f'Align model "{align_model}" not found in torchaudio.pipelines')
|
||||
for audio_path in args.pop("audio"):
|
||||
result = transcribe(model, audio_path, temperature=temperature, **args)
|
||||
result_aligned = align(result["segments"], align_model, align_dictionary, audio_path, device,
|
||||
|
Reference in New Issue
Block a user