mirror of
https://github.com/m-bain/whisperX.git
synced 2025-07-01 18:17:27 -04:00
handle tmp wav file better
This commit is contained in:
@ -107,8 +107,6 @@ def cli():
|
||||
max_speakers: int = args.pop("max_speakers")
|
||||
|
||||
if vad_filter:
|
||||
if hf_token is None:
|
||||
print("Warning, no --hf_token used, needs to be saved in environment variable, otherwise will throw error loading diarization model...")
|
||||
from pyannote.audio import Pipeline
|
||||
from pyannote.audio import Model, Pipeline
|
||||
vad_model = load_vad_model(torch.device(device), vad_onset, vad_offset, use_auth_token=hf_token)
|
||||
@ -158,18 +156,25 @@ def cli():
|
||||
for audio_path in args.pop("audio"):
|
||||
input_audio_path = audio_path
|
||||
tfile = None
|
||||
|
||||
# >> VAD & ASR
|
||||
if vad_model is not None:
|
||||
if not audio_path.endswith(".wav"):
|
||||
print(">>VAD requires .wav format, converting to wav as a tempfile...")
|
||||
tfile = tempfile.NamedTemporaryFile(delete=True, suffix=".wav")
|
||||
ffmpeg.input(audio_path, threads=0).output(tfile.name, ac=1, ar=SAMPLE_RATE).run(cmd=["ffmpeg"])
|
||||
input_audio_path = tfile.name
|
||||
# tfile = tempfile.NamedTemporaryFile(delete=True, suffix=".wav")
|
||||
audio_basename = os.path.splitext(os.path.basename(audio_path))[0]
|
||||
if tmp_dir is not None:
|
||||
input_audio_path = os.path.join(tmp_dir, audio_basename + ".wav")
|
||||
else:
|
||||
input_audio_path = os.path.join(os.path.dirname(audio_path), audio_basename + ".wav")
|
||||
ffmpeg.input(audio_path, threads=0).output(input_audio_path, ac=1, ar=SAMPLE_RATE).run(cmd=["ffmpeg"])
|
||||
print(">>Performing VAD...")
|
||||
result = transcribe_with_vad(model, input_audio_path, vad_model, temperature=temperature, **args)
|
||||
else:
|
||||
print(">>Performing transcription...")
|
||||
result = transcribe(model, input_audio_path, temperature=temperature, **args)
|
||||
|
||||
# >> Align
|
||||
if align_model is not None and len(result["segments"]) > 0:
|
||||
if result.get("language", "en") != align_metadata["language"]:
|
||||
# load new language
|
||||
@ -179,16 +184,18 @@ def cli():
|
||||
result = align(result["segments"], align_model, align_metadata, input_audio_path, device,
|
||||
extend_duration=align_extend, start_from_previous=align_from_prev, interpolate_method=interpolate_method)
|
||||
|
||||
|
||||
|
||||
# >> Diarize
|
||||
if diarize_model is not None:
|
||||
diarize_segments = diarize_model(input_audio_path, min_speakers=min_speakers, max_speakers=max_speakers)
|
||||
results_segments, word_segments = assign_word_speakers(diarize_segments)
|
||||
results_segments, word_segments = assign_word_speakers(diarize_segments, result["segments"])
|
||||
result = {"segments": results_segments, "word_segments": word_segments}
|
||||
|
||||
|
||||
if tfile is not None:
|
||||
tfile.close()
|
||||
writer(result, audio_path)
|
||||
|
||||
# cleanup
|
||||
if input_audio_path != audio_path:
|
||||
os.remove(input_audio_path)
|
||||
|
||||
if __name__ == "__main__":
|
||||
cli()
|
Reference in New Issue
Block a user