mirror of https://github.com/m-bain/whisperX.git
synced 2025-07-01 18:17:27 -04:00
handle tmp wav file better
@@ -41,7 +41,12 @@ def assign_word_speakers(diarize_df, result_segments, fill_nearest=False):
                 speaker = None
             speakers.append(speaker)
         seg['word-segments']['speaker'] = speakers
-        seg["speaker"] = pd.Series(speakers).value_counts().index[0]
+
+        speaker_count = pd.Series(speakers).value_counts()
+        if len(speaker_count) == 0:
+            seg["speaker"] = "UNKNOWN"
+        else:
+            seg["speaker"] = speaker_count.index[0]
 
     # create word level segments for .srt
     word_seg = []
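Note on this hunk: pd.Series(speakers).value_counts() silently drops None entries, so a segment in which no word overlapped any diarization turn produced an empty Series, and the old unconditional .index[0] raised IndexError. A minimal standalone sketch of the new fallback (the helper name is ours, for illustration only):

import pandas as pd

def majority_speaker(speakers):
    # value_counts() ignores None, so an all-None list yields an empty Series
    counts = pd.Series(speakers).value_counts()
    if len(counts) == 0:
        return "UNKNOWN"      # no word overlapped any speaker turn
    return counts.index[0]    # most frequent speaker wins the segment

print(majority_speaker([None, None]))           # UNKNOWN
print(majority_speaker(["A", "B", "A", None]))  # A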
@@ -107,8 +107,6 @@ def cli():
     max_speakers: int = args.pop("max_speakers")
 
     if vad_filter:
         if hf_token is None:
             print("Warning, no --hf_token used, needs to be saved in environment variable, otherwise will throw error loading diarization model...")
-        from pyannote.audio import Pipeline
-        from pyannote.audio import Model, Pipeline
         vad_model = load_vad_model(torch.device(device), vad_onset, vad_offset, use_auth_token=hf_token)
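Note on this hunk: the two duplicated pyannote.audio imports were unused at this point, since load_vad_model handles the model loading itself. The warning advises saving the token in an environment variable; as a hedged illustration only (the variable name is an assumption, not something whisperX reads here), a caller could resolve the token like this:

import os

def resolve_hf_token(cli_token=None):
    # Hypothetical helper: prefer an explicit --hf_token value, otherwise
    # fall back to an environment variable such as HF_TOKEN, which the
    # Hugging Face client libraries can pick up for authentication.
    return cli_token if cli_token is not None else os.environ.get("HF_TOKEN")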
@@ -158,18 +156,25 @@ def cli():
     for audio_path in args.pop("audio"):
         input_audio_path = audio_path
         tfile = None
 
         # >> VAD & ASR
         if vad_model is not None:
             if not audio_path.endswith(".wav"):
                 print(">>VAD requires .wav format, converting to wav as a tempfile...")
-                tfile = tempfile.NamedTemporaryFile(delete=True, suffix=".wav")
-                ffmpeg.input(audio_path, threads=0).output(tfile.name, ac=1, ar=SAMPLE_RATE).run(cmd=["ffmpeg"])
-                input_audio_path = tfile.name
+                # tfile = tempfile.NamedTemporaryFile(delete=True, suffix=".wav")
+                audio_basename = os.path.splitext(os.path.basename(audio_path))[0]
+                if tmp_dir is not None:
+                    input_audio_path = os.path.join(tmp_dir, audio_basename + ".wav")
+                else:
+                    input_audio_path = os.path.join(os.path.dirname(audio_path), audio_basename + ".wav")
+                ffmpeg.input(audio_path, threads=0).output(input_audio_path, ac=1, ar=SAMPLE_RATE).run(cmd=["ffmpeg"])
             print(">>Performing VAD...")
             result = transcribe_with_vad(model, input_audio_path, vad_model, temperature=temperature, **args)
         else:
             print(">>Performing transcription...")
             result = transcribe(model, input_audio_path, temperature=temperature, **args)
 
         # >> Align
         if align_model is not None and len(result["segments"]) > 0:
             if result.get("language", "en") != align_metadata["language"]:
                 # load new language
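Note on this hunk: the NamedTemporaryFile is replaced by a deterministic path next to the input (or under --tmp_dir). A likely motivation: with delete=True the handle stays open in Python, and on Windows ffmpeg cannot open such a file a second time; a plain named .wav avoids that and can be removed explicitly once transcription finishes. A runnable sketch of the path derivation (example paths are hypothetical):

import os

audio_path = "/data/podcast/episode01.mp3"  # hypothetical input
tmp_dir = None                              # or e.g. "/tmp/whisperx"

# Same derivation as above: strip directory and extension, then rebuild
# "<basename>.wav" either under tmp_dir or next to the original input.
audio_basename = os.path.splitext(os.path.basename(audio_path))[0]
if tmp_dir is not None:
    input_audio_path = os.path.join(tmp_dir, audio_basename + ".wav")
else:
    input_audio_path = os.path.join(os.path.dirname(audio_path), audio_basename + ".wav")

print(input_audio_path)  # /data/podcast/episode01.wav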
@@ -179,16 +184,18 @@ def cli():
             result = align(result["segments"], align_model, align_metadata, input_audio_path, device,
                            extend_duration=align_extend, start_from_previous=align_from_prev, interpolate_method=interpolate_method)
 
         # >> Diarize
         if diarize_model is not None:
             diarize_segments = diarize_model(input_audio_path, min_speakers=min_speakers, max_speakers=max_speakers)
-            results_segments, word_segments = assign_word_speakers(diarize_segments)
+            results_segments, word_segments = assign_word_speakers(diarize_segments, result["segments"])
             result = {"segments": results_segments, "word_segments": word_segments}
 
-        if tfile is not None:
-            tfile.close()
         writer(result, audio_path)
 
+        # cleanup
+        if input_audio_path != audio_path:
+            os.remove(input_audio_path)
+
 if __name__ == "__main__":
     cli()
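Two things happen in this hunk. First, the assign_word_speakers call gains the result_segments argument its signature (see the first hunk) requires; the old single-argument call would have raised a TypeError as soon as diarization ran. Second, closing the temp file is replaced by explicit cleanup, keyed on the invariant that a conversion happened exactly when the working path differs from the original input. A sketch of that invariant (function name is ours):

import os

def cleanup_converted_wav(input_audio_path: str, audio_path: str) -> None:
    # The two paths are equal iff the input was already a .wav and no
    # conversion took place; only a converted copy is ever removed.
    if input_audio_path != audio_path:
        os.remove(input_audio_path)

Unlike NamedTemporaryFile(delete=True), nothing is deleted automatically on exit, so this branch is what keeps converted files from accumulating next to the inputs.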
@@ -236,7 +236,7 @@ class WritePickle(ResultWriter):
         pd.DataFrame(result["segments"]).to_pickle(file)
 
 class WriteSRTWord(ResultWriter):
-    extension: str = ".word.srt"
+    extension: str = "word.srt"
     always_include_hours: bool = True
     decimal_marker: str = ","
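Note on this hunk: assuming WriteSRTWord inherits the upstream whisper ResultWriter behavior of joining the audio basename and the extension with a dot (as the plain srt and txt writers do), the old leading dot doubled up in output names. A one-line check of that assumption:

audio_basename = "episode01"
print(audio_basename + "." + ".word.srt")  # episode01..word.srt  (old: doubled dot)
print(audio_basename + "." + "word.srt")   # episode01.word.srt   (new)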