mirror of
https://github.com/m-bain/whisperX.git
synced 2025-07-01 18:17:27 -04:00
opti the inference loop
This commit is contained in:
@ -1,5 +1,6 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import os
|
import os
|
||||||
|
import gc
|
||||||
import warnings
|
import warnings
|
||||||
from typing import TYPE_CHECKING, Optional, Tuple, Union
|
from typing import TYPE_CHECKING, Optional, Tuple, Union
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -113,19 +114,6 @@ def cli():
|
|||||||
else:
|
else:
|
||||||
vad_model = None
|
vad_model = None
|
||||||
|
|
||||||
if diarize:
|
|
||||||
if hf_token is None:
|
|
||||||
print("Warning, no --hf_token used, needs to be saved in environment variable, otherwise will throw error loading diarization model...")
|
|
||||||
diarize_model = DiarizationPipeline(use_auth_token=hf_token)
|
|
||||||
else:
|
|
||||||
diarize_model = None
|
|
||||||
|
|
||||||
if no_align:
|
|
||||||
align_model, align_metadata = None, None
|
|
||||||
else:
|
|
||||||
align_language = args["language"] if args["language"] is not None else "en" # default to loading english if not specified
|
|
||||||
align_model, align_metadata = load_align_model(align_language, device, model_name=align_model)
|
|
||||||
|
|
||||||
# if model_flush:
|
# if model_flush:
|
||||||
# print(">>Model flushing activated... Only loading model after ASR stage")
|
# print(">>Model flushing activated... Only loading model after ASR stage")
|
||||||
# del align_model
|
# del align_model
|
||||||
@ -150,9 +138,12 @@ def cli():
|
|||||||
|
|
||||||
from whisper import load_model
|
from whisper import load_model
|
||||||
|
|
||||||
model = load_model(model_name, device=device, download_root=model_dir)
|
|
||||||
|
|
||||||
writer = get_writer(output_format, output_dir)
|
writer = get_writer(output_format, output_dir)
|
||||||
|
|
||||||
|
# Part 1: VAD & ASR Loop
|
||||||
|
results = []
|
||||||
|
tmp_results = []
|
||||||
|
model = load_model(model_name, device=device, download_root=model_dir)
|
||||||
for audio_path in args.pop("audio"):
|
for audio_path in args.pop("audio"):
|
||||||
input_audio_path = audio_path
|
input_audio_path = audio_path
|
||||||
tfile = None
|
tfile = None
|
||||||
@ -161,7 +152,6 @@ def cli():
|
|||||||
if vad_model is not None:
|
if vad_model is not None:
|
||||||
if not audio_path.endswith(".wav"):
|
if not audio_path.endswith(".wav"):
|
||||||
print(">>VAD requires .wav format, converting to wav as a tempfile...")
|
print(">>VAD requires .wav format, converting to wav as a tempfile...")
|
||||||
# tfile = tempfile.NamedTemporaryFile(delete=True, suffix=".wav")
|
|
||||||
audio_basename = os.path.splitext(os.path.basename(audio_path))[0]
|
audio_basename = os.path.splitext(os.path.basename(audio_path))[0]
|
||||||
if tmp_dir is not None:
|
if tmp_dir is not None:
|
||||||
input_audio_path = os.path.join(tmp_dir, audio_basename + ".wav")
|
input_audio_path = os.path.join(tmp_dir, audio_basename + ".wav")
|
||||||
@ -174,23 +164,52 @@ def cli():
|
|||||||
print(">>Performing transcription...")
|
print(">>Performing transcription...")
|
||||||
result = transcribe(model, input_audio_path, temperature=temperature, **args)
|
result = transcribe(model, input_audio_path, temperature=temperature, **args)
|
||||||
|
|
||||||
# >> Align
|
results.append((result, input_audio_path))
|
||||||
if align_model is not None and len(result["segments"]) > 0:
|
|
||||||
if result.get("language", "en") != align_metadata["language"]:
|
|
||||||
# load new language
|
|
||||||
print(f"New language found ({result['language']})! Previous was ({align_metadata['language']}), loading new alignment model for new language...")
|
|
||||||
align_model, align_metadata = load_align_model(result["language"], device)
|
|
||||||
print(">>Performing alignment...")
|
|
||||||
result = align(result["segments"], align_model, align_metadata, input_audio_path, device,
|
|
||||||
extend_duration=align_extend, start_from_previous=align_from_prev, interpolate_method=interpolate_method)
|
|
||||||
|
|
||||||
# >> Diarize
|
# Unload Whisper and VAD
|
||||||
if diarize_model is not None:
|
del model
|
||||||
|
del vad_model
|
||||||
|
gc.collect()
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
# Part 2: Align Loop
|
||||||
|
if not no_align:
|
||||||
|
tmp_results = results
|
||||||
|
results = []
|
||||||
|
align_language = args["language"] if args["language"] is not None else "en" # default to loading english if not specified
|
||||||
|
align_model, align_metadata = load_align_model(align_language, device, model_name=align_model)
|
||||||
|
for result, input_audio_path in tmp_results:
|
||||||
|
# >> Align
|
||||||
|
if align_model is not None and len(result["segments"]) > 0:
|
||||||
|
if result.get("language", "en") != align_metadata["language"]:
|
||||||
|
# load new language
|
||||||
|
print(f"New language found ({result['language']})! Previous was ({align_metadata['language']}), loading new alignment model for new language...")
|
||||||
|
align_model, align_metadata = load_align_model(result["language"], device)
|
||||||
|
print(">>Performing alignment...")
|
||||||
|
result = align(result["segments"], align_model, align_metadata, input_audio_path, device,
|
||||||
|
extend_duration=align_extend, start_from_previous=align_from_prev, interpolate_method=interpolate_method)
|
||||||
|
results.append((result, input_audio_path))
|
||||||
|
|
||||||
|
# Unload align model
|
||||||
|
del align_model
|
||||||
|
gc.collect()
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
# >> Diarize
|
||||||
|
if diarize:
|
||||||
|
if hf_token is None:
|
||||||
|
print("Warning, no --hf_token used, needs to be saved in environment variable, otherwise will throw error loading diarization model...")
|
||||||
|
tmp_results = results
|
||||||
|
results = []
|
||||||
|
diarize_model = DiarizationPipeline(use_auth_token=hf_token)
|
||||||
|
for result, input_audio_path in tmp_results:
|
||||||
diarize_segments = diarize_model(input_audio_path, min_speakers=min_speakers, max_speakers=max_speakers)
|
diarize_segments = diarize_model(input_audio_path, min_speakers=min_speakers, max_speakers=max_speakers)
|
||||||
results_segments, word_segments = assign_word_speakers(diarize_segments, result["segments"])
|
results_segments, word_segments = assign_word_speakers(diarize_segments, result["segments"])
|
||||||
result = {"segments": results_segments, "word_segments": word_segments}
|
result = {"segments": results_segments, "word_segments": word_segments}
|
||||||
|
results.append((result, input_audio_path))
|
||||||
|
|
||||||
|
# >> Write
|
||||||
|
for result, audio_path in results:
|
||||||
writer(result, audio_path)
|
writer(result, audio_path)
|
||||||
|
|
||||||
# cleanup
|
# cleanup
|
||||||
|
Reference in New Issue
Block a user