mirror of https://github.com/m-bain/whisperX.git
replace magic strings
@@ -17,7 +17,9 @@ from .utils import exact_div, format_timestamp, optional_int, optional_float, st
 if TYPE_CHECKING:
     from .model import Whisper
 
-wa2vec2_models_on_hugginface = ["jonatasgrosman/wav2vec2-large-xlsr-53-japanese"]
+hugginface_models = ["jonatasgrosman/wav2vec2-large-xlsr-53-japanese"]
+
+asian_languages = ["ja"]
 
 
 def transcribe(
     model: "Whisper",
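
As the commit title says, the change replaces repeated string literals with named module-level constants: hugginface_models lists the wav2vec2 alignment checkpoints loaded from Hugging Face rather than torchaudio, and asian_languages lists the language codes that get character-level word splitting. A minimal sketch of the pattern (the checks below stand in for the real ones later in the file):

    # Before: the same literal repeated at every check
    if language != "ja":
        ...

    # After: membership in a named, easily extended list
    asian_languages = ["ja"]
    if language not in asian_languages:
        ...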
@@ -282,19 +284,18 @@ def align(
         f2 = int(t2 * SAMPLE_RATE)
 
         waveform_segment = audio[:, f1:f2]
-        print(language)
         with torch.inference_mode():
-            if language != 'ja':
+            if language not in asian_languages:
                 emissions, _ = model(waveform_segment.to(device))
             else:
                 emissions = model(waveform_segment.to(device)).logits
             emissions = torch.log_softmax(emissions, dim=-1)
         emission = emissions[0].cpu().detach()
         transcription = segment['text'].strip()
-        if language != "ja":
+        if language not in asian_languages:
             t_words = transcription.split(' ')
         else:
-            t_words = [c for c in transcription] #FIXME: ideally, we should use a tokenizer for Japanese to extract words
+            t_words = [c for c in transcription]
 
         t_words_clean = [''.join([w for w in word if w.lower() in model_dictionary.keys()]) for word in t_words]
         t_words_nonempty = [x for x in t_words_clean if x != ""]
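
For readers outside the diff context: in this version of align(), the language check doubles as a backend check, since the only listed Asian language (Japanese) is served by a Hugging Face checkpoint whose forward pass returns an object with .logits, while the torchaudio bundle models return an (emissions, lengths) tuple. A self-contained sketch of the emission step and the word-splitting fallback, with a hypothetical is_hf_model flag standing in for the language test:

    import torch

    def get_char_emission(model, waveform_segment, device, is_hf_model):
        # Frame-level log-probabilities over the model's character vocabulary.
        with torch.inference_mode():
            if not is_hf_model:
                # torchaudio pipeline models return an (emissions, lengths) tuple
                emissions, _ = model(waveform_segment.to(device))
            else:
                # transformers Wav2Vec2ForCTC returns an output object with .logits
                emissions = model(waveform_segment.to(device)).logits
            emissions = torch.log_softmax(emissions, dim=-1)
        return emissions[0].cpu().detach()

    def split_words(text, language, asian_languages=("ja",)):
        # Space-delimited languages split on whitespace; languages written without
        # spaces fall back to per-character "words" (a proper tokenizer would do better,
        # as the removed FIXME comment noted).
        text = text.strip()
        return text.split(' ') if language not in asian_languages else list(text)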
@@ -434,7 +435,7 @@ def cli():
         align_model = bundle.get_model().to(device)
         labels = bundle.get_labels()
         align_dictionary = {c.lower(): i for i, c in enumerate(labels)}
-    elif align_model in wa2vec2_models_on_hugginface:
+    elif align_model in hugginface_models:
         processor = AutoProcessor.from_pretrained(align_model)
         align_model = Wav2Vec2ForCTC.from_pretrained(align_model).to(device)
         align_model.to(device)
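
In cli(), the renamed constant decides how the alignment model is loaded: names listed in hugginface_models go through transformers, anything else is treated as a torchaudio pipeline bundle. A minimal sketch of that selection, assuming torchaudio and transformers are installed; the helper function and the processor.tokenizer.get_vocab() dictionary are illustrative, not necessarily what this version of whisperX does:

    import torchaudio
    from transformers import AutoProcessor, Wav2Vec2ForCTC

    hugginface_models = ["jonatasgrosman/wav2vec2-large-xlsr-53-japanese"]

    def load_align_model(name, device="cpu"):
        if name in hugginface_models:
            processor = AutoProcessor.from_pretrained(name)
            model = Wav2Vec2ForCTC.from_pretrained(name).to(device)
            # The HF tokenizer exposes its character vocabulary directly.
            align_dictionary = {c.lower(): i for c, i in processor.tokenizer.get_vocab().items()}
        else:
            # Otherwise treat `name` as a torchaudio pipeline bundle,
            # e.g. "WAV2VEC2_ASR_BASE_960H".
            bundle = getattr(torchaudio.pipelines, name)
            model = bundle.get_model().to(device)
            align_dictionary = {c.lower(): i for i, c in enumerate(bundle.get_labels())}
        return model, align_dictionary

The duplicated align_model.to(device) in the hunk above is harmless, since nn.Module.to() returns the module and moving an already-moved model is a no-op.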