Optimize the inference loop

invisprints
2023-04-09 15:58:55 +08:00
parent 9482d324d0
commit bb15c9428f
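
This change restructures cli() from a single per-file loop that kept the Whisper, alignment, and diarization models loaded at the same time into three sequential passes (transcription, alignment, diarization), unloading each model before the next pass starts, so peak GPU memory is bounded by the largest single model rather than the sum of all of them. A minimal sketch of the pattern, assuming PyTorch; the loader and stage function names are placeholders, not the whisperx API:

    import gc
    import torch

    def run_stage(load_model_fn, stage_fn, items):
        # Load one model, push every item through it, then release the GPU
        # before the caller loads the next stage's model.
        model = load_model_fn()
        outputs = [stage_fn(model, item) for item in items]
        del model
        gc.collect()
        torch.cuda.empty_cache()
        return outputs

The trade-off is that intermediate results for every input file are held in host memory between passes.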


@@ -1,5 +1,6 @@
import argparse
import os
import gc
import warnings
from typing import TYPE_CHECKING, Optional, Tuple, Union
import numpy as np
@@ -113,19 +114,6 @@ def cli():
    else:
        vad_model = None
    if diarize:
        if hf_token is None:
            print("Warning: no --hf_token provided. The token must be passed via --hf_token or be available in an environment variable, otherwise loading the diarization model will fail...")
        diarize_model = DiarizationPipeline(use_auth_token=hf_token)
    else:
        diarize_model = None
    if no_align:
        align_model, align_metadata = None, None
    else:
        align_language = args["language"] if args["language"] is not None else "en" # default to loading english if not specified
        align_model, align_metadata = load_align_model(align_language, device, model_name=align_model)
    # if model_flush:
    #     print(">>Model flushing activated... Only loading model after ASR stage")
    #     del align_model
@@ -150,9 +138,12 @@ def cli():
    from whisper import load_model
    model = load_model(model_name, device=device, download_root=model_dir)
    writer = get_writer(output_format, output_dir)
    # Part 1: VAD & ASR Loop
    results = []
    tmp_results = []
    model = load_model(model_name, device=device, download_root=model_dir)
    for audio_path in args.pop("audio"):
        input_audio_path = audio_path
        tfile = None
@@ -161,7 +152,6 @@ def cli():
        if vad_model is not None:
            if not audio_path.endswith(".wav"):
                print(">>VAD requires .wav format, converting to wav as a tempfile...")
                # tfile = tempfile.NamedTemporaryFile(delete=True, suffix=".wav")
                audio_basename = os.path.splitext(os.path.basename(audio_path))[0]
                if tmp_dir is not None:
                    input_audio_path = os.path.join(tmp_dir, audio_basename + ".wav")
@@ -173,24 +163,53 @@ def cli():
        else:
            print(">>Performing transcription...")
            result = transcribe(model, input_audio_path, temperature=temperature, **args)
        results.append((result, input_audio_path))
        # >> Align
        if align_model is not None and len(result["segments"]) > 0:
            if result.get("language", "en") != align_metadata["language"]:
                # load new language
                print(f"New language found ({result['language']})! Previous was ({align_metadata['language']}), loading new alignment model for new language...")
                align_model, align_metadata = load_align_model(result["language"], device)
            print(">>Performing alignment...")
            result = align(result["segments"], align_model, align_metadata, input_audio_path, device,
                           extend_duration=align_extend, start_from_previous=align_from_prev, interpolate_method=interpolate_method)
    # Unload Whisper and VAD
    del model
    del vad_model
    gc.collect()
    torch.cuda.empty_cache()
        # >> Diarize
        if diarize_model is not None:
    # Part 2: Align Loop
    if not no_align:
        tmp_results = results
        results = []
        align_language = args["language"] if args["language"] is not None else "en" # default to loading english if not specified
        align_model, align_metadata = load_align_model(align_language, device, model_name=align_model)
        for result, input_audio_path in tmp_results:
            # >> Align
            if align_model is not None and len(result["segments"]) > 0:
                if result.get("language", "en") != align_metadata["language"]:
                    # load new language
                    print(f"New language found ({result['language']})! Previous was ({align_metadata['language']}), loading new alignment model for new language...")
                    align_model, align_metadata = load_align_model(result["language"], device)
                print(">>Performing alignment...")
                result = align(result["segments"], align_model, align_metadata, input_audio_path, device,
                               extend_duration=align_extend, start_from_previous=align_from_prev, interpolate_method=interpolate_method)
            results.append((result, input_audio_path))
    # Unload align model
    del align_model
    gc.collect()
    torch.cuda.empty_cache()
    # >> Diarize
    if diarize:
        if hf_token is None:
            print("Warning: no --hf_token provided. The token must be passed via --hf_token or be available in an environment variable, otherwise loading the diarization model will fail...")
        tmp_results = results
        results = []
        diarize_model = DiarizationPipeline(use_auth_token=hf_token)
        for result, input_audio_path in tmp_results:
            diarize_segments = diarize_model(input_audio_path, min_speakers=min_speakers, max_speakers=max_speakers)
            results_segments, word_segments = assign_word_speakers(diarize_segments, result["segments"])
            result = {"segments": results_segments, "word_segments": word_segments}
            results.append((result, input_audio_path))
    # >> Write
    for result, audio_path in results:
        writer(result, audio_path)
    # cleanup
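
A note on the unload idiom repeated above (del, gc.collect(), torch.cuda.empty_cache()): del only removes a name binding, the tensors are freed once no references remain, and empty_cache() then returns PyTorch's cached but unused CUDA blocks to the driver. A sketch of the idiom as a helper, assuming PyTorch (release_cuda_memory is a hypothetical name, not part of this codebase):

    import gc
    import torch

    def release_cuda_memory():
        # Collect unreachable objects first so their CUDA tensors are freed,
        # then hand the cached-but-free blocks back to the driver.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    # Callers must drop their own references first; a `del` inside the helper
    # would only unbind the helper's local name:
    #     del model
    #     release_cuda_memory()

Whether the staging actually lowers the peak can be verified per pass with torch.cuda.reset_peak_memory_stats() and torch.cuda.max_memory_allocated().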