replace magic strings

This commit is contained in:
Yasutaka Odo
2022-12-21 02:11:08 +09:00
parent d7546def91
commit 93e568b3bf

View File

@ -17,7 +17,9 @@ from .utils import exact_div, format_timestamp, optional_int, optional_float, st
if TYPE_CHECKING: if TYPE_CHECKING:
from .model import Whisper from .model import Whisper
wa2vec2_models_on_hugginface = ["jonatasgrosman/wav2vec2-large-xlsr-53-japanese"] hugginface_models = ["jonatasgrosman/wav2vec2-large-xlsr-53-japanese"]
asian_languages = ["ja"]
def transcribe( def transcribe(
model: "Whisper", model: "Whisper",
@ -282,19 +284,18 @@ def align(
f2 = int(t2 * SAMPLE_RATE) f2 = int(t2 * SAMPLE_RATE)
waveform_segment = audio[:, f1:f2] waveform_segment = audio[:, f1:f2]
print(language)
with torch.inference_mode(): with torch.inference_mode():
if language != 'ja': if language not in asian_languages:
emissions, _ = model(waveform_segment.to(device)) emissions, _ = model(waveform_segment.to(device))
else: else:
emissions = model(waveform_segment.to(device)).logits emissions = model(waveform_segment.to(device)).logits
emissions = torch.log_softmax(emissions, dim=-1) emissions = torch.log_softmax(emissions, dim=-1)
emission = emissions[0].cpu().detach() emission = emissions[0].cpu().detach()
transcription = segment['text'].strip() transcription = segment['text'].strip()
if language != "ja": if language not in asian_languages:
t_words = transcription.split(' ') t_words = transcription.split(' ')
else: else:
t_words = [c for c in transcription] #FIXME: ideally, we should use a tokenizer for Japanese to extract words t_words = [c for c in transcription]
t_words_clean = [''.join([w for w in word if w.lower() in model_dictionary.keys()]) for word in t_words] t_words_clean = [''.join([w for w in word if w.lower() in model_dictionary.keys()]) for word in t_words]
t_words_nonempty = [x for x in t_words_clean if x != ""] t_words_nonempty = [x for x in t_words_clean if x != ""]
@ -434,7 +435,7 @@ def cli():
align_model = bundle.get_model().to(device) align_model = bundle.get_model().to(device)
labels = bundle.get_labels() labels = bundle.get_labels()
align_dictionary = {c.lower(): i for i, c in enumerate(labels)} align_dictionary = {c.lower(): i for i, c in enumerate(labels)}
elif align_model in wa2vec2_models_on_hugginface: elif align_model in hugginface_models:
processor = AutoProcessor.from_pretrained(align_model) processor = AutoProcessor.from_pretrained(align_model)
align_model = Wav2Vec2ForCTC.from_pretrained(align_model).to(device) align_model = Wav2Vec2ForCTC.from_pretrained(align_model).to(device)
align_model.to(device) align_model.to(device)