handle tmp wav file better

Max Bain
2023-04-01 00:06:40 +01:00
parent b9ca701d69
commit 11a78d7ced
4 changed files with 71 additions and 24 deletions

View File

@ -29,7 +29,7 @@
</p>
<img width="1216" align="center" alt="whisperx-arch" src="https://user-images.githubusercontent.com/36994049/211200186-8b779e26-0bfd-4127-aee2-5a9238b95e1f.png">
<img width="1216" align="center" alt="whisperx-arch" src="figures/pipeline.png">
<p align="left">Whisper-Based Automatic Speech Recognition (ASR) with improved timestamp accuracy using forced alignment.
@ -39,7 +39,7 @@
<h2 align="left", id="what-is-it">What is it 🔎</h2>
This repository refines the timestamps of OpenAI's Whisper model via forced alignment with phoneme-based ASR models (e.g. wav2vec2.0), supporting multilingual use-cases.
This repository refines the timestamps of OpenAI's Whisper model via forced alignment with phoneme-based ASR models (e.g. wav2vec2.0) and VAD preprocessing, supporting multilingual use-cases.
**Whisper** is an ASR model [developed by OpenAI](https://github.com/openai/whisper), trained on a large dataset of diverse audio. Whilst it does produce highly accurate transcriptions, the corresponding timestamps are at the utterance level, not per word, and can be inaccurate by several seconds.
@ -48,11 +48,13 @@ This repository refines the timestamps of openAI's Whisper model via forced alig
**Forced Alignment** refers to the process by which orthographic transcriptions are aligned to audio recordings to automatically generate phone level segmentation.
**Voice Activity Detection (VAD)** is the detection of the presence or absence of human speech.
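As a rough illustration of what VAD preprocessing looks like, the sketch below runs the Pyannote.audio voice-activity pipeline on a file and prints the detected speech regions. This is not WhisperX's internal `load_vad_model` path; the model name follows the Pyannote model card, and the token handling is an assumption.

```python
# Illustrative only: standalone VAD with pyannote.audio, not WhisperX's
# internal load_vad_model(). Assumes pyannote.audio 2.x; the gated model
# may require use_auth_token=<your HF token> in from_pretrained().
from pyannote.audio import Pipeline

vad_pipeline = Pipeline.from_pretrained("pyannote/voice-activity-detection")
speech = vad_pipeline("examples/sample01.wav")

# Only these speech regions would then be passed to Whisper, instead of
# relying on Whisper's own utterance timestamps.
for region in speech.get_timeline().support():
    print(f"speech {region.start:.2f}s -> {region.end:.2f}s")
```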
<h2 align="left", id="highlights">New🚨</h2>
- Paper drop🎓👨‍🏫! Please see our [arXiv preprint](https://arxiv.org/abs/2303.00747) for benchmarking and details of WhisperX. We also introduce more efficient batch inference, resulting in large-v2 with *60-70x REAL TIME speed*. The repo will be updated soon with this efficient batch inference.
- Batch processing: Add `--vad_filter --parallel_bs [int]` for transcribing long audio files in batches (only supported with VAD filtering). Replace `[int]` with a batch size that fits your GPU memory, e.g. `--parallel_bs 16`.
- VAD filtering: Voice Activity Detection (VAD) from [Pyannote.audio](https://huggingface.co/pyannote/voice-activity-detection) is used as a preprocessing step to remove reliance on whisper timestamps and only transcribe audio segments containing speech. Add the `--vad_filter` flag; this increases timestamp accuracy and robustness (requires more GPU memory due to 30s inputs to wav2vec2)
- v2 released: code cleanup, now imports the whisper library; batched inference from the paper is not included (contact for licensing / a batched model API). VAD filtering is now turned on by default, as in the paper.
- Paper drop🎓👨‍🏫! Please see our [arXiv preprint](https://arxiv.org/abs/2303.00747) for benchmarking and details of WhisperX. We also introduce more efficient batch inference, resulting in large-v2 with *60-70x REAL TIME speed* (not provided in this repo).
- VAD filtering: Voice Activity Detection (VAD) from [Pyannote.audio](https://huggingface.co/pyannote/voice-activity-detection) is used as a preprocessing step to remove reliance on whisper timestamps and only transcribe audio segments containing speech. Add the `--vad_filter True` flag; this increases timestamp accuracy and robustness (requires more GPU memory due to 30s inputs to wav2vec2)
- Character level timestamps (see `*.char.ass` file output)
- Diarization (still in beta, add `--diarize`)
@ -89,9 +91,9 @@ Run whisper on example segment (using default params)
whisperx examples/sample01.wav
For increased timestamp accuracy, at the cost of higher GPU memory, use bigger models and VAD filtering e.g.
For increased timestamp accuracy, at the cost of higher GPU memory, use bigger models (a bigger alignment model was not found to be particularly helpful, see paper) e.g.
whisperx examples/sample01.wav --model large-v2 --vad_filter --align_model WAV2VEC2_ASR_LARGE_LV60K_960H
whisperx examples/sample01.wav --model large-v2 --align_model WAV2VEC2_ASR_LARGE_LV60K_960H
Result using *WhisperX* with forced alignment to wav2vec2.0 large:
@ -153,9 +155,8 @@ In addition to forced alignment, the following two modifications have been made
<h2 align="left" id="limitations">Limitations ⚠️</h2>
- Not thoroughly tested, especially for non-English languages; results may vary -- please post an issue to let me know the results on your data
- Whisper normalises spoken numbers, e.g. "fifty seven" to the Arabic numeral "57". This normalisation would need to be performed after alignment so that the phonemes can be aligned; currently numbers are simply ignored.
- If not using the VAD filter, whisperx assumes the initial whisper timestamps are accurate to some degree (within a margin of 2 seconds; adjust if needed -- bigger margins are more prone to alignment errors)
- If setting `--vad_filter False`, whisperx assumes the initial whisper timestamps are accurate to some degree (within a margin of 2 seconds; adjust if needed -- bigger margins are more prone to alignment errors)
- Overlapping speech is not handled particularly well by either whisper or whisperx
- Diarization is far from perfect.
@ -180,21 +181,23 @@ The next major upgrade we are working on is whisper with speaker diarization, so
* [x] Incorporating speaker diarization
* [x] Inference speedup with batch processing
* [ ] Automatic .wav conversion to make inputs VAD-compatible
* [ ] Model flush, for low GPU memory resources
* [ ] Improve diarization (word level). *Harder than first thought...*
<h2 align="left" id="contact">Contact/Support 📇</h2>
Contact maxbain[at]robots[dot]ox[dot]ac[dot]uk for queries
Contact maxhbain@gmail.com for queries and licensing / early access to a model API with batched inference (transcribe 1hr audio in under 1min).
<a href="https://www.buymeacoffee.com/maxhbain" target="_blank"><img src="https://cdn.buymeacoffee.com/buttons/default-orange.png" alt="Buy Me A Coffee" height="41" width="174"></a>
<h2 align="left" id="acks">Acknowledgements 🙏</h2>
This work, and my PhD, is supported by the [VGG (Visual Geometry Group)](https://www.robots.ox.ac.uk/~vgg/) and University of Oxford.
This work, and my PhD, is supported by the [VGG (Visual Geometry Group)](https://www.robots.ox.ac.uk/~vgg/) and the University of Oxford.
@ -214,3 +217,35 @@ If you use this in your research, please cite the paper:
}
```
as well as the following works, used in each stage of the pipeline:
```bibtex
@article{radford2022robust,
title={Robust speech recognition via large-scale weak supervision},
author={Radford, Alec and Kim, Jong Wook and Xu, Tao and Brockman, Greg and McLeavey, Christine and Sutskever, Ilya},
journal={arXiv preprint arXiv:2212.04356},
year={2022}
}
```
```bibtex
@article{baevski2020wav2vec,
title={wav2vec 2.0: A framework for self-supervised learning of speech representations},
author={Baevski, Alexei and Zhou, Yuhao and Mohamed, Abdelrahman and Auli, Michael},
journal={Advances in neural information processing systems},
volume={33},
pages={12449--12460},
year={2020}
}
```
```bibtex
@inproceedings{bredin2020pyannote,
title={pyannote.audio: neural building blocks for speaker diarization},
author={Bredin, Herv{\'e} and Yin, Ruiqing and Coria, Juan Manuel and Gelly, Gregory and Korshunov, Pavel and Lavechin, Marvin and Fustes, Diego and Titeux, Hadrien and Bouaziz, Wassim and Gill, Marie-Philippe},
booktitle={ICASSP 2020-2020 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)},
pages={7124--7128},
year={2020},
organization={IEEE}
}
```

View File

@ -41,7 +41,12 @@ def assign_word_speakers(diarize_df, result_segments, fill_nearest=False):
speaker = None
speakers.append(speaker)
seg['word-segments']['speaker'] = speakers
seg["speaker"] = pd.Series(speakers).value_counts().index[0]
speaker_count = pd.Series(speakers).value_counts()
if len(speaker_count) == 0:
seg["speaker"]= "UNKNOWN"
else:
seg["speaker"] = speaker_count.index[0]
# create word level segments for .srt
word_seg = []
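For context, here is a small self-contained sketch of the majority-vote speaker assignment with the `UNKNOWN` fallback this hunk adds; the `seg` and `speakers` values are hypothetical stand-ins for the variables inside `assign_word_speakers`.

```python
# Hypothetical minimal reproduction of the fallback added in this hunk.
import pandas as pd

seg = {"word-segments": pd.DataFrame({"word": ["hello", "world"]})}
speakers = [None, None]  # e.g. no diarization turn overlapped these words

seg["word-segments"]["speaker"] = speakers
speaker_count = pd.Series(speakers).value_counts()
if len(speaker_count) == 0:
    # every word is unassigned, so value_counts() is empty -- avoid an IndexError
    seg["speaker"] = "UNKNOWN"
else:
    # otherwise take the most frequent speaker across the segment's words
    seg["speaker"] = speaker_count.index[0]

print(seg["speaker"])  # -> UNKNOWN
```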

View File

@ -107,8 +107,6 @@ def cli():
max_speakers: int = args.pop("max_speakers")
if vad_filter:
if hf_token is None:
print("Warning, no --hf_token used, needs to be saved in environment variable, otherwise will throw error loading diarization model...")
from pyannote.audio import Pipeline
from pyannote.audio import Model, Pipeline
vad_model = load_vad_model(torch.device(device), vad_onset, vad_offset, use_auth_token=hf_token)
@ -158,18 +156,25 @@ def cli():
for audio_path in args.pop("audio"):
input_audio_path = audio_path
tfile = None
# >> VAD & ASR
if vad_model is not None:
if not audio_path.endswith(".wav"):
print(">>VAD requires .wav format, converting to wav as a tempfile...")
tfile = tempfile.NamedTemporaryFile(delete=True, suffix=".wav")
ffmpeg.input(audio_path, threads=0).output(tfile.name, ac=1, ar=SAMPLE_RATE).run(cmd=["ffmpeg"])
input_audio_path = tfile.name
# tfile = tempfile.NamedTemporaryFile(delete=True, suffix=".wav")
audio_basename = os.path.splitext(os.path.basename(audio_path))[0]
if tmp_dir is not None:
input_audio_path = os.path.join(tmp_dir, audio_basename + ".wav")
else:
input_audio_path = os.path.join(os.path.dirname(audio_path), audio_basename + ".wav")
ffmpeg.input(audio_path, threads=0).output(input_audio_path, ac=1, ar=SAMPLE_RATE).run(cmd=["ffmpeg"])
print(">>Performing VAD...")
result = transcribe_with_vad(model, input_audio_path, vad_model, temperature=temperature, **args)
else:
print(">>Performing transcription...")
result = transcribe(model, input_audio_path, temperature=temperature, **args)
# >> Align
if align_model is not None and len(result["segments"]) > 0:
if result.get("language", "en") != align_metadata["language"]:
# load new language
@ -179,16 +184,18 @@ def cli():
result = align(result["segments"], align_model, align_metadata, input_audio_path, device,
extend_duration=align_extend, start_from_previous=align_from_prev, interpolate_method=interpolate_method)
# >> Diarize
if diarize_model is not None:
diarize_segments = diarize_model(input_audio_path, min_speakers=min_speakers, max_speakers=max_speakers)
results_segments, word_segments = assign_word_speakers(diarize_segments)
results_segments, word_segments = assign_word_speakers(diarize_segments, result["segments"])
result = {"segments": results_segments, "word_segments": word_segments}
if tfile is not None:
tfile.close()
writer(result, audio_path)
# cleanup
if input_audio_path != audio_path:
os.remove(input_audio_path)
if __name__ == "__main__":
cli()
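Taken together, the new flow above amounts to: convert non-wav input to a mono wav (into the user-supplied tmp_dir if given, otherwise next to the original file), transcribe that file, then delete the converted copy. A rough standalone sketch, assuming ffmpeg-python and a 16 kHz sample rate; `prepare_wav` is a hypothetical helper, not part of the WhisperX CLI:

```python
# Sketch of the temp-wav handling introduced in this commit (illustration only).
import os
from typing import Optional

import ffmpeg

SAMPLE_RATE = 16000  # assumed; whisperx uses whisper's 16 kHz sample rate

def prepare_wav(audio_path: str, tmp_dir: Optional[str] = None) -> str:
    """Return a .wav path for audio_path, converting with ffmpeg if needed."""
    if audio_path.endswith(".wav"):
        return audio_path
    audio_basename = os.path.splitext(os.path.basename(audio_path))[0]
    out_dir = tmp_dir if tmp_dir is not None else os.path.dirname(audio_path)
    input_audio_path = os.path.join(out_dir, audio_basename + ".wav")
    # same ffmpeg call as the CLI: mono, resampled wav
    ffmpeg.input(audio_path, threads=0).output(
        input_audio_path, ac=1, ar=SAMPLE_RATE
    ).run(cmd=["ffmpeg"])
    return input_audio_path

audio_path = "examples/sample01.mp3"
input_audio_path = prepare_wav(audio_path)
# ... run VAD + transcription + alignment on input_audio_path ...
if input_audio_path != audio_path:
    os.remove(input_audio_path)  # clean up the converted copy
```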

View File

@ -236,7 +236,7 @@ class WritePickle(ResultWriter):
pd.DataFrame(result["segments"]).to_pickle(file)
class WriteSRTWord(ResultWriter):
extension: str = ".word.srt"
extension: str = "word.srt"
always_include_hours: bool = True
decimal_marker: str = ","