Mirror of https://github.com/m-bain/whisperX.git (synced 2025-07-01 18:17:27 -04:00)

Compare commits (15 commits)
| Author | SHA1 | Date |
|---|---|---|
| | da458863d7 | |
| | cf252a8592 | |
| | 6a72b61564 | |
| | 48ed89834e | |
| | bb15c9428f | |
| | 9482d324d0 | |
| | 4146e56d5b | |
| | 118e7deedb | |
| | 70a4a0a25c | |
| | 40948a3d00 | |
| | c8be6ac94d | |
| | a582a59493 | |
| | 861379edc3 | |
| | 4af345434a | |
| | 634799b3be | |
README.md
@@ -85,8 +85,8 @@ Safest to use install pytorch as follows (for gpu)
### Voice Activity Detection Filtering & Diarization

To **enable VAD filtering and Diarization**, include the Hugging Face access token that you can generate from [here](https://huggingface.co/settings/tokens) after the `--hf_token` argument, and accept the user agreement for the following models: [Segmentation](https://huggingface.co/pyannote/segmentation), [Voice Activity Detection (VAD)](https://huggingface.co/pyannote/voice-activity-detection), and [Speaker Diarization](https://huggingface.co/pyannote/speaker-diarization).

### Speaker Diarization

To **enable Speaker Diarization**, include the Hugging Face access token that you can generate from [here](https://huggingface.co/settings/tokens) after the `--hf_token` argument, and accept the user agreement for the following models: [Segmentation](https://huggingface.co/pyannote/segmentation), [Voice Activity Detection (VAD)](https://huggingface.co/pyannote/voice-activity-detection), and [Speaker Diarization](https://huggingface.co/pyannote/speaker-diarization).
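A minimal Python sketch of the diarization step these models enable, based on the `DiarizationPipeline` and `assign_word_speakers` helpers that appear in the CLI changes further down; the import path and the placeholder token, audio file, and speaker counts are assumptions, not documented API:

```python
# Hypothetical sketch; the import path and placeholder values are assumptions.
from whisperx.diarize import DiarizationPipeline, assign_word_speakers

result = {"segments": []}  # stand-in for the output of a prior transcribe + align pass

diarize_model = DiarizationPipeline(use_auth_token="YOUR_HF_TOKEN")
diarize_segments = diarize_model("audio.wav", min_speakers=1, max_speakers=2)
segments, word_segments = assign_word_speakers(diarize_segments, result["segments"])
```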
<h2 align="left" id="example">Usage 💬 (command line)</h2>
@@ -130,12 +130,13 @@ See more examples in other languages [here](EXAMPLES.md).

```python
import whisperx
import whisper

device = "cuda"
audio_file = "audio.mp3"

# transcribe with original whisper
model = whisperx.load_model("large", device)
model = whisper.load_model("large", device)
result = model.transcribe(audio_file)

print(result["segments"]) # before alignment
```
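The README example continues beyond this hunk with the alignment step; here is a sketch of that continuation using the `load_align_model` and `align` functions shown later in this compare (the package-level `whisperx.` access and the printed keys are assumptions):

```python
# Sketch of the alignment step that follows the snippet above (not part of this hunk);
# argument order is taken from the load_align_model / align calls shown further down.
model_a, metadata = whisperx.load_align_model(language_code=result["language"], device=device)
result_aligned = whisperx.align(result["segments"], model_a, metadata, audio_file, device)

print(result_aligned["segments"])       # after alignment
print(result_aligned["word_segments"])  # word-level timing
```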
@@ -157,9 +158,6 @@ In addition to forced alignment, the following two modifications have been made:

1. `--condition_on_prev_text` is set to `False` by default (reduces hallucination)

2. Clamping segment `end_time` to be at least 0.02s (one time-precision step) later than `start_time`, which prevents segments with negative duration; see the short sketch below.
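A minimal sketch of that clamping rule (0.02 s corresponds to one Whisper timestamp-token step):

```python
def clamp_segment_end(start_time: float, end_time: float, time_precision: float = 0.02) -> float:
    # Ensure the segment ends at least one time-precision step after it starts,
    # so no segment can have zero or negative duration.
    return max(end_time, start_time + time_precision)
```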
<h2 align="left" id="limitations">Limitations ⚠️</h2>
- Whisper normalises spoken numbers, e.g. "fifty seven", to Arabic numerals ("57"). This normalization would need to be performed after alignment so that the phonemes can be aligned; currently such numbers are simply ignored.
@@ -255,4 +253,4 @@ as well the following works, used in each stage of the pipeline:

```
  year={2020},
  organization={IEEE}
}
```
@@ -7,4 +7,4 @@ more-itertools

transformers>=4.19.0
ffmpeg-python==0.2.0
pyannote.audio
whisper
openai-whisper==20230314
@@ -40,7 +40,7 @@ DEFAULT_ALIGN_MODELS_HF = {
}


def load_align_model(language_code, device, model_name=None):
def load_align_model(language_code, device, model_name=None, model_dir=None):
    if model_name is None:
        # use default model
        if language_code in DEFAULT_ALIGN_MODELS_TORCH:
@@ -55,7 +55,7 @@ def load_align_model(language_code, device, model_name=None):
    if model_name in torchaudio.pipelines.__all__:
        pipeline_type = "torchaudio"
        bundle = torchaudio.pipelines.__dict__[model_name]
        align_model = bundle.get_model().to(device)
        align_model = bundle.get_model(dl_kwargs={"model_dir": model_dir}).to(device)
        labels = bundle.get_labels()
        align_dictionary = {c.lower(): i for i, c in enumerate(labels)}
    else:
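A hypothetical call illustrating the new `model_dir` parameter; the cache path is a placeholder, and package-level access to `load_align_model` is assumed from the README example:

```python
import whisperx  # assuming load_align_model is exposed at package level, as in the README example

# Hypothetical usage: "/data/align_models" is a placeholder cache directory.
# For torchaudio pipelines the weights are then fetched with dl_kwargs={"model_dir": ...}.
align_model, align_metadata = whisperx.load_align_model("en", "cuda", model_dir="/data/align_models")
```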
@@ -269,6 +269,10 @@ def transcribe(
                    end_timestamp_pos = (
                        sliced_tokens[-1].item() - tokenizer.timestamp_begin
                    )

                    # clamp end-time to at least be 1 frame after start-time
                    end_timestamp_pos = max(end_timestamp_pos, start_timestamp_pos + time_precision)

                    current_segments.append(
                        new_segment(
                            start=time_offset + start_timestamp_pos * time_precision,
@@ -426,4 +430,4 @@ def transcribe_with_vad(

    output["language"] = output["segments"][0]["language"]

    return output
@@ -1,5 +1,6 @@
import argparse
import os
import gc
import warnings
from typing import TYPE_CHECKING, Optional, Tuple, Union
import numpy as np
@@ -44,7 +45,7 @@ def cli():
    parser.add_argument("--no_align", action='store_true', help="Do not perform phoneme alignment")

    # vad params
    parser.add_argument("--vad_filter", default=True, help="Whether to pre-segment audio with VAD, highly recommended! Produces more accurate alignment + timestamp see WhisperX paper https://arxiv.org/abs/2303.00747")
    parser.add_argument("--vad_filter", type=str2bool, default=True, help="Whether to pre-segment audio with VAD, highly recommended! Produces more accurate alignment + timestamp see WhisperX paper https://arxiv.org/abs/2303.00747")
    parser.add_argument("--vad_onset", type=float, default=0.500, help="Onset threshold for VAD (see pyannote.audio), reduce this if speech is not being detected")
    parser.add_argument("--vad_offset", type=float, default=0.363, help="Offset threshold for VAD (see pyannote.audio), reduce this if speech is not being detected.")
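The `type=str2bool` addition matters because argparse would otherwise store `--vad_filter False` as the string "False", which is truthy; `str2bool` is presumably the same boolean-parsing helper used by Whisper's own CLI. A minimal version of such a helper, shown only as an illustration, looks like this:

```python
def str2bool(string: str) -> bool:
    # Map the literal strings "True"/"False" to booleans; anything else is an error.
    str2val = {"True": True, "False": False}
    if string in str2val:
        return str2val[string]
    raise ValueError(f"Expected one of {set(str2val.keys())}, got {string}")
```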
@@ -113,19 +114,6 @@ def cli():
    else:
        vad_model = None

    if diarize:
        if hf_token is None:
            print("Warning, no --hf_token used, needs to be saved in environment variable, otherwise will throw error loading diarization model...")
        diarize_model = DiarizationPipeline(use_auth_token=hf_token)
    else:
        diarize_model = None

    if no_align:
        align_model, align_metadata = None, None
    else:
        align_language = args["language"] if args["language"] is not None else "en" # default to loading english if not specified
        align_model, align_metadata = load_align_model(align_language, device, model_name=align_model)

    # if model_flush:
    #     print(">>Model flushing activated... Only loading model after ASR stage")
    #     del align_model
@@ -150,9 +138,12 @@ def cli():

    from whisper import load_model

    model = load_model(model_name, device=device, download_root=model_dir)

    writer = get_writer(output_format, output_dir)

    # Part 1: VAD & ASR Loop
    results = []
    tmp_results = []
    model = load_model(model_name, device=device, download_root=model_dir)
    for audio_path in args.pop("audio"):
        input_audio_path = audio_path
        tfile = None
@@ -161,7 +152,6 @@ def cli():
        if vad_model is not None:
            if not audio_path.endswith(".wav"):
                print(">>VAD requires .wav format, converting to wav as a tempfile...")
                # tfile = tempfile.NamedTemporaryFile(delete=True, suffix=".wav")
                audio_basename = os.path.splitext(os.path.basename(audio_path))[0]
                if tmp_dir is not None:
                    input_audio_path = os.path.join(tmp_dir, audio_basename + ".wav")
@@ -173,24 +163,53 @@ def cli():
        else:
            print(">>Performing transcription...")
        result = transcribe(model, input_audio_path, temperature=temperature, **args)

        results.append((result, input_audio_path))

        # >> Align
        if align_model is not None and len(result["segments"]) > 0:
            if result.get("language", "en") != align_metadata["language"]:
                # load new language
                print(f"New language found ({result['language']})! Previous was ({align_metadata['language']}), loading new alignment model for new language...")
                align_model, align_metadata = load_align_model(result["language"], device)
            print(">>Performing alignment...")
            result = align(result["segments"], align_model, align_metadata, input_audio_path, device,
                extend_duration=align_extend, start_from_previous=align_from_prev, interpolate_method=interpolate_method)

    # Unload Whisper and VAD
    del model
    del vad_model
    gc.collect()
    torch.cuda.empty_cache()

        # >> Diarize
        if diarize_model is not None:

    # Part 2: Align Loop
    if not no_align:
        tmp_results = results
        results = []
        align_language = args["language"] if args["language"] is not None else "en" # default to loading english if not specified
        align_model, align_metadata = load_align_model(align_language, device, model_name=align_model)
        for result, input_audio_path in tmp_results:
            # >> Align
            if align_model is not None and len(result["segments"]) > 0:
                if result.get("language", "en") != align_metadata["language"]:
                    # load new language
                    print(f"New language found ({result['language']})! Previous was ({align_metadata['language']}), loading new alignment model for new language...")
                    align_model, align_metadata = load_align_model(result["language"], device)
                print(">>Performing alignment...")
                result = align(result["segments"], align_model, align_metadata, input_audio_path, device,
                    extend_duration=align_extend, start_from_previous=align_from_prev, interpolate_method=interpolate_method)
                results.append((result, input_audio_path))

        # Unload align model
        del align_model
        gc.collect()
        torch.cuda.empty_cache()

    # >> Diarize
    if diarize:
        if hf_token is None:
            print("Warning, no --hf_token used, needs to be saved in environment variable, otherwise will throw error loading diarization model...")
        tmp_results = results
        results = []
        diarize_model = DiarizationPipeline(use_auth_token=hf_token)
        for result, input_audio_path in tmp_results:
            diarize_segments = diarize_model(input_audio_path, min_speakers=min_speakers, max_speakers=max_speakers)
            results_segments, word_segments = assign_word_speakers(diarize_segments, result["segments"])
            result = {"segments": results_segments, "word_segments": word_segments}
            results.append((result, input_audio_path))

    # >> Write
    for result, audio_path in results:
        writer(result, audio_path)

    # cleanup
@@ -221,13 +221,13 @@ class WriteASS(ResultWriter):
    extension: str = "ass"

    def write_result(self, result: dict, file: TextIO):
        write_ass(result["segments"], file, resoltuion="word")
        write_ass(result["segments"], file, resolution="word")


class WriteASSchar(ResultWriter):
    extension: str = "ass"

    def write_result(self, result: dict, file: TextIO):
        write_ass(result["segments"], file, resoltuion="char")
        write_ass(result["segments"], file, resolution="char")


class WritePickle(ResultWriter):
    extension: str = "ass"
@@ -16,9 +16,11 @@ from typing import List, Tuple, Optional

VAD_SEGMENTATION_URL = "https://whisperx.s3.eu-west-2.amazonaws.com/model_weights/segmentation/0b5b3216d60a2d32fc086b47ea8c67589aaeb26b7e07fcbe620d6d0b83e209ea/pytorch_model.bin"

def load_vad_model(device, vad_onset, vad_offset, use_auth_token=None):
def load_vad_model(device, vad_onset, vad_offset, use_auth_token=None, model_fp=None):
    model_dir = torch.hub._get_torch_home()
    model_fp = os.path.join(model_dir, "whisperx-vad-segmentation.bin")
    os.makedirs(model_dir, exist_ok = True)
    if model_fp is None:
        model_fp = os.path.join(model_dir, "whisperx-vad-segmentation.bin")
    if os.path.exists(model_fp) and not os.path.isfile(model_fp):
        raise RuntimeError(f"{model_fp} exists and is not a regular file")
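A hypothetical call showing the new `model_fp` override; the import path and weights path are placeholders, and the onset/offset values are simply the CLI defaults from the argparse changes above:

```python
from whisperx.vad import load_vad_model  # assumed import path

# Hypothetical usage: point model_fp at an already-downloaded copy of the VAD weights
# instead of the default file under torch.hub's cache directory.
vad_model = load_vad_model(
    "cuda",
    vad_onset=0.500,
    vad_offset=0.363,
    model_fp="/models/whisperx-vad-segmentation.bin",  # placeholder path
)
```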
@@ -301,4 +303,4 @@ def merge_chunks(segments, chunk_size):
            "end": curr_end,
            "segments": seg_idxs,
        })
    return merged_segments