diff --git a/whisperx/transcribe.py b/whisperx/transcribe.py
index e8fa47b..ed918e0 100644
--- a/whisperx/transcribe.py
+++ b/whisperx/transcribe.py
@@ -1,5 +1,6 @@
 import argparse
 import os
+import gc
 import warnings
 from typing import TYPE_CHECKING, Optional, Tuple, Union
 import numpy as np
@@ -113,19 +114,6 @@ def cli():
     else:
         vad_model = None
 
-    if diarize:
-        if hf_token is None:
-            print("Warning, no --hf_token used, needs to be saved in environment variable, otherwise will throw error loading diarization model...")
-        diarize_model = DiarizationPipeline(use_auth_token=hf_token)
-    else:
-        diarize_model = None
-
-    if no_align:
-        align_model, align_metadata = None, None
-    else:
-        align_language = args["language"] if args["language"] is not None else "en" # default to loading english if not specified
-        align_model, align_metadata = load_align_model(align_language, device, model_name=align_model)
-
     # if model_flush:
     #     print(">>Model flushing activated... Only loading model after ASR stage")
     #     del align_model
@@ -150,9 +138,12 @@ def cli():
 
     from whisper import load_model
 
-    model = load_model(model_name, device=device, download_root=model_dir)
-
     writer = get_writer(output_format, output_dir)
+
+    # Part 1: VAD & ASR Loop
+    results = []
+    tmp_results = []
+    model = load_model(model_name, device=device, download_root=model_dir)
     for audio_path in args.pop("audio"):
         input_audio_path = audio_path
         tfile = None
@@ -161,7 +152,6 @@ def cli():
         if vad_model is not None:
             if not audio_path.endswith(".wav"):
                 print(">>VAD requires .wav format, converting to wav as a tempfile...")
-                # tfile = tempfile.NamedTemporaryFile(delete=True, suffix=".wav")
                 audio_basename = os.path.splitext(os.path.basename(audio_path))[0]
                 if tmp_dir is not None:
                     input_audio_path = os.path.join(tmp_dir, audio_basename + ".wav")
@@ -173,24 +163,53 @@ def cli():
         else:
             print(">>Performing transcription...")
             result = transcribe(model, input_audio_path, temperature=temperature, **args)
+
+        results.append((result, input_audio_path))
 
-        # >> Align
-        if align_model is not None and len(result["segments"]) > 0:
-            if result.get("language", "en") != align_metadata["language"]:
-                # load new language
-                print(f"New language found ({result['language']})! Previous was ({align_metadata['language']}), loading new alignment model for new language...")
-                align_model, align_metadata = load_align_model(result["language"], device)
-            print(">>Performing alignment...")
-            result = align(result["segments"], align_model, align_metadata, input_audio_path, device,
-                extend_duration=align_extend, start_from_previous=align_from_prev, interpolate_method=interpolate_method)
+    # Unload Whisper and VAD
+    del model
+    del vad_model
+    gc.collect()
+    torch.cuda.empty_cache()
 
-        # >> Diarize
-        if diarize_model is not None:
+    # Part 2: Align Loop
+    if not no_align:
+        tmp_results = results
+        results = []
+        align_language = args["language"] if args["language"] is not None else "en" # default to loading english if not specified
+        align_model, align_metadata = load_align_model(align_language, device, model_name=align_model)
+        for result, input_audio_path in tmp_results:
+            # >> Align
+            if align_model is not None and len(result["segments"]) > 0:
+                if result.get("language", "en") != align_metadata["language"]:
+                    # load new language
+                    print(f"New language found ({result['language']})! Previous was ({align_metadata['language']}), loading new alignment model for new language...")
+                    align_model, align_metadata = load_align_model(result["language"], device)
+                print(">>Performing alignment...")
+                result = align(result["segments"], align_model, align_metadata, input_audio_path, device,
+                    extend_duration=align_extend, start_from_previous=align_from_prev, interpolate_method=interpolate_method)
+            results.append((result, input_audio_path))
+
+    # Unload align model
+    del align_model
+    gc.collect()
+    torch.cuda.empty_cache()
+
+    # >> Diarize
+    if diarize:
+        if hf_token is None:
+            print("Warning, no --hf_token used, needs to be saved in environment variable, otherwise will throw error loading diarization model...")
+        tmp_results = results
+        results = []
+        diarize_model = DiarizationPipeline(use_auth_token=hf_token)
+        for result, input_audio_path in tmp_results:
             diarize_segments = diarize_model(input_audio_path, min_speakers=min_speakers, max_speakers=max_speakers)
             results_segments, word_segments = assign_word_speakers(diarize_segments, result["segments"])
             result = {"segments": results_segments, "word_segments": word_segments}
+            results.append((result, input_audio_path))
 
-
+    # >> Write
+    for result, audio_path in results:
         writer(result, audio_path)
 
         # cleanup
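
The memory trick this patch applies between each of the three stages (ASR, align, diarize) is worth spelling out: drop every Python reference to the stage's model, run a garbage-collection pass so the weight tensors actually become unreachable, then ask PyTorch to hand its cached CUDA blocks back to the allocator so the next model fits on the device. A minimal standalone sketch of that pattern follows; run_stage and load_fn are illustrative names, not whisperx API — only gc.collect() and torch.cuda.empty_cache() come from the diff itself.

    import gc
    import torch

    def run_stage(load_fn, inputs):
        """Load one stage's model, run it over all inputs, then free its
        GPU memory before the next stage loads a different model."""
        model = load_fn()
        outputs = [model(x) for x in inputs]
        # Delete the only reference so the weights are collectible,
        # collect them, then release CUDA's cached allocations.
        del model
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        return outputs

Note that empty_cache() only returns memory the allocator has already cached; without the preceding del and gc.collect() the weights would still be referenced and nothing would be freed, which is why the patch always performs all three steps together.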