mirror of
https://github.com/m-bain/whisperX.git
synced 2025-07-01 18:17:27 -04:00
Merge pull request #169 from invisprints/v2-opt-load-model
Optimize the inference process and reduce the memory usage
This commit is contained in:
@ -1,5 +1,6 @@
|
||||
import argparse
|
||||
import os
|
||||
import gc
|
||||
import warnings
|
||||
from typing import TYPE_CHECKING, Optional, Tuple, Union
|
||||
import numpy as np
|
||||
@ -113,19 +114,6 @@ def cli():
|
||||
else:
|
||||
vad_model = None
|
||||
|
||||
if diarize:
|
||||
if hf_token is None:
|
||||
print("Warning, no --hf_token used, needs to be saved in environment variable, otherwise will throw error loading diarization model...")
|
||||
diarize_model = DiarizationPipeline(use_auth_token=hf_token)
|
||||
else:
|
||||
diarize_model = None
|
||||
|
||||
if no_align:
|
||||
align_model, align_metadata = None, None
|
||||
else:
|
||||
align_language = args["language"] if args["language"] is not None else "en" # default to loading english if not specified
|
||||
align_model, align_metadata = load_align_model(align_language, device, model_name=align_model)
|
||||
|
||||
# if model_flush:
|
||||
# print(">>Model flushing activated... Only loading model after ASR stage")
|
||||
# del align_model
|
||||
@ -150,9 +138,12 @@ def cli():
|
||||
|
||||
from whisper import load_model
|
||||
|
||||
model = load_model(model_name, device=device, download_root=model_dir)
|
||||
|
||||
writer = get_writer(output_format, output_dir)
|
||||
|
||||
# Part 1: VAD & ASR Loop
|
||||
results = []
|
||||
tmp_results = []
|
||||
model = load_model(model_name, device=device, download_root=model_dir)
|
||||
for audio_path in args.pop("audio"):
|
||||
input_audio_path = audio_path
|
||||
tfile = None
|
||||
@ -161,7 +152,6 @@ def cli():
|
||||
if vad_model is not None:
|
||||
if not audio_path.endswith(".wav"):
|
||||
print(">>VAD requires .wav format, converting to wav as a tempfile...")
|
||||
# tfile = tempfile.NamedTemporaryFile(delete=True, suffix=".wav")
|
||||
audio_basename = os.path.splitext(os.path.basename(audio_path))[0]
|
||||
if tmp_dir is not None:
|
||||
input_audio_path = os.path.join(tmp_dir, audio_basename + ".wav")
|
||||
@ -173,24 +163,53 @@ def cli():
|
||||
else:
|
||||
print(">>Performing transcription...")
|
||||
result = transcribe(model, input_audio_path, temperature=temperature, **args)
|
||||
|
||||
results.append((result, input_audio_path))
|
||||
|
||||
# >> Align
|
||||
if align_model is not None and len(result["segments"]) > 0:
|
||||
if result.get("language", "en") != align_metadata["language"]:
|
||||
# load new language
|
||||
print(f"New language found ({result['language']})! Previous was ({align_metadata['language']}), loading new alignment model for new language...")
|
||||
align_model, align_metadata = load_align_model(result["language"], device)
|
||||
print(">>Performing alignment...")
|
||||
result = align(result["segments"], align_model, align_metadata, input_audio_path, device,
|
||||
extend_duration=align_extend, start_from_previous=align_from_prev, interpolate_method=interpolate_method)
|
||||
# Unload Whisper and VAD
|
||||
del model
|
||||
del vad_model
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# >> Diarize
|
||||
if diarize_model is not None:
|
||||
# Part 2: Align Loop
|
||||
if not no_align:
|
||||
tmp_results = results
|
||||
results = []
|
||||
align_language = args["language"] if args["language"] is not None else "en" # default to loading english if not specified
|
||||
align_model, align_metadata = load_align_model(align_language, device, model_name=align_model)
|
||||
for result, input_audio_path in tmp_results:
|
||||
# >> Align
|
||||
if align_model is not None and len(result["segments"]) > 0:
|
||||
if result.get("language", "en") != align_metadata["language"]:
|
||||
# load new language
|
||||
print(f"New language found ({result['language']})! Previous was ({align_metadata['language']}), loading new alignment model for new language...")
|
||||
align_model, align_metadata = load_align_model(result["language"], device)
|
||||
print(">>Performing alignment...")
|
||||
result = align(result["segments"], align_model, align_metadata, input_audio_path, device,
|
||||
extend_duration=align_extend, start_from_previous=align_from_prev, interpolate_method=interpolate_method)
|
||||
results.append((result, input_audio_path))
|
||||
|
||||
# Unload align model
|
||||
del align_model
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# >> Diarize
|
||||
if diarize:
|
||||
if hf_token is None:
|
||||
print("Warning, no --hf_token used, needs to be saved in environment variable, otherwise will throw error loading diarization model...")
|
||||
tmp_results = results
|
||||
results = []
|
||||
diarize_model = DiarizationPipeline(use_auth_token=hf_token)
|
||||
for result, input_audio_path in tmp_results:
|
||||
diarize_segments = diarize_model(input_audio_path, min_speakers=min_speakers, max_speakers=max_speakers)
|
||||
results_segments, word_segments = assign_word_speakers(diarize_segments, result["segments"])
|
||||
result = {"segments": results_segments, "word_segments": word_segments}
|
||||
results.append((result, input_audio_path))
|
||||
|
||||
|
||||
# >> Write
|
||||
for result, audio_path in results:
|
||||
writer(result, audio_path)
|
||||
|
||||
# cleanup
|
||||
|
Reference in New Issue
Block a user