Mirror of https://github.com/m-bain/whisperX.git
Synced 2025-07-01 18:17:27 -04:00

Compare commits (15 commits)

Commits in this comparison:

- da458863d7
- cf252a8592
- 6a72b61564
- 48ed89834e
- bb15c9428f
- 9482d324d0
- 4146e56d5b
- 118e7deedb
- 70a4a0a25c
- 40948a3d00
- c8be6ac94d
- a582a59493
- 861379edc3
- 4af345434a
- 634799b3be

README.md (12 changed lines)

@@ -85,8 +85,8 @@ Safest to use install pytorch as follows (for gpu)

-### Voice Activity Detection Filtering & Diarization
+### Speaker Diarization

-To **enable VAD filtering and Diarization**, include your Hugging Face access token that you can generate from [Here](https://huggingface.co/settings/tokens) after the `--hf_token` argument and accept the user agreement for the following models: [Segmentation](https://huggingface.co/pyannote/segmentation), [Voice Activity Detection (VAD)](https://huggingface.co/pyannote/voice-activity-detection), and [Speaker Diarization](https://huggingface.co/pyannote/speaker-diarization)
+To **enable Speaker Diarization**, include your Hugging Face access token that you can generate from [Here](https://huggingface.co/settings/tokens) after the `--hf_token` argument and accept the user agreement for the following models: [Segmentation](https://huggingface.co/pyannote/segmentation), [Voice Activity Detection (VAD)](https://huggingface.co/pyannote/voice-activity-detection), and [Speaker Diarization](https://huggingface.co/pyannote/speaker-diarization)

 <h2 align="left" id="example">Usage 💬 (command line)</h2>
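
The diarization entry points touched later in this changeset, `DiarizationPipeline` and `assign_word_speakers`, can also be driven directly from Python. The snippet below is a rough sketch based only on the calls visible in this diff; it assumes both names are importable from the `whisperx` package (they appear unqualified in the CLI hunks), and the token, audio path, and speaker counts are placeholders.

```python
import whisperx

# sketch only: mirrors the calls made in the updated CLI further down this diff
hf_token = "hf_..."        # placeholder Hugging Face access token
audio_path = "audio.wav"   # placeholder audio file

diarize_model = whisperx.DiarizationPipeline(use_auth_token=hf_token)
# diarize one file, optionally bounding the number of speakers
diarize_segments = diarize_model(audio_path, min_speakers=1, max_speakers=4)

# `aligned_result` would come from a prior transcribe + align step
results_segments, word_segments = whisperx.assign_word_speakers(
    diarize_segments, aligned_result["segments"]
)
```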

@@ -130,12 +130,13 @@ See more examples in other languages [here](EXAMPLES.md).

 ```python
 import whisperx
+import whisper

 device = "cuda"
 audio_file = "audio.mp3"

 # transcribe with original whisper
-model = whisperx.load_model("large", device)
+model = whisper.load_model("large", device)
 result = model.transcribe(audio_file)

 print(result["segments"]) # before alignment
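
For context, the README example continues past this hunk with the alignment step. A plausible completion, reconstructed from the `load_align_model` and `align` signatures shown later in this diff (the exact README wording may differ), looks like:

```python
# sketch of the alignment step that follows in the README example;
# argument names are taken from the signatures in this diff, values are illustrative
model_a, metadata = whisperx.load_align_model(language_code=result["language"], device=device)
result_aligned = whisperx.align(result["segments"], model_a, metadata, audio_file, device)

print(result_aligned["segments"])       # after alignment
print(result_aligned["word_segments"])  # word-level timestamps
```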

@@ -157,9 +158,6 @@ In addition to forced alignment, the following two modifications have been made

 1. `--condition_on_prev_text` is set to `False` by default (reduces hallucination)

-2. Clamping segment `end_time` to be at least 0.02s (one time precision) later than `start_time` (prevents segments with negative duration)
-
-
 <h2 align="left" id="limitations">Limitations ⚠️</h2>

 - Whisper normalises spoken numbers e.g. "fifty seven" to arabic numerals "57". Need to perform this normalization after alignment, so the phonemes can be aligned. Currently just ignores numbers.

@@ -255,4 +253,4 @@ as well the following works, used in each stage of the pipeline:
 year={2020},
 organization={IEEE}
 }
 ```

@@ -7,4 +7,4 @@ more-itertools
 transformers>=4.19.0
 ffmpeg-python==0.2.0
 pyannote.audio
-whisper
+openai-whisper==20230314

@@ -40,7 +40,7 @@ DEFAULT_ALIGN_MODELS_HF = {
 }


-def load_align_model(language_code, device, model_name=None):
+def load_align_model(language_code, device, model_name=None, model_dir=None):
     if model_name is None:
         # use default model
         if language_code in DEFAULT_ALIGN_MODELS_TORCH:

@@ -55,7 +55,7 @@ def load_align_model(language_code, device, model_name=None):
     if model_name in torchaudio.pipelines.__all__:
         pipeline_type = "torchaudio"
         bundle = torchaudio.pipelines.__dict__[model_name]
-        align_model = bundle.get_model().to(device)
+        align_model = bundle.get_model(dl_kwargs={"model_dir": model_dir}).to(device)
         labels = bundle.get_labels()
         align_dictionary = {c.lower(): i for i, c in enumerate(labels)}
     else:
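
The new `model_dir` argument is forwarded to torchaudio's bundle download via `dl_kwargs`, so callers can control where the wav2vec2 alignment weights are cached. A hedged usage sketch (the path is a placeholder, not part of this changeset):

```python
# sketch: cache the torchaudio alignment weights under a chosen directory
align_model, align_metadata = load_align_model(
    language_code="en",
    device="cuda",
    model_dir="/data/torchaudio_models",  # placeholder path; the new optional argument
)
```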

@@ -269,6 +269,10 @@ def transcribe(
                 end_timestamp_pos = (
                     sliced_tokens[-1].item() - tokenizer.timestamp_begin
                 )
+
+                # clamp end-time to at least be 1 frame after start-time
+                end_timestamp_pos = max(end_timestamp_pos, start_timestamp_pos + time_precision)
+
                 current_segments.append(
                     new_segment(
                         start=time_offset + start_timestamp_pos * time_precision,
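
The added clamp guarantees `end_timestamp_pos` is strictly greater than `start_timestamp_pos`, which prevents the zero or negative duration segments mentioned in the README. A tiny illustration with invented values:

```python
# illustrative values only: a degenerate segment whose end token equals its start token
time_precision = 0.02                     # Whisper's seconds-per-timestamp-token
start_timestamp_pos = 153
end_timestamp_pos = 153                   # would produce a non-positive duration

end_timestamp_pos = max(end_timestamp_pos, start_timestamp_pos + time_precision)
print(end_timestamp_pos > start_timestamp_pos)   # True: end now strictly follows start
```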

@@ -426,4 +430,4 @@ def transcribe_with_vad(

     output["language"] = output["segments"][0]["language"]

     return output

@@ -1,5 +1,6 @@
 import argparse
 import os
+import gc
 import warnings
 from typing import TYPE_CHECKING, Optional, Tuple, Union
 import numpy as np

@@ -44,7 +45,7 @@ def cli():
     parser.add_argument("--no_align", action='store_true', help="Do not perform phoneme alignment")

     # vad params
-    parser.add_argument("--vad_filter", default=True, help="Whether to pre-segment audio with VAD, highly recommended! Produces more accurate alignment + timestamp see WhisperX paper https://arxiv.org/abs/2303.00747")
+    parser.add_argument("--vad_filter", type=str2bool, default=True, help="Whether to pre-segment audio with VAD, highly recommended! Produces more accurate alignment + timestamp see WhisperX paper https://arxiv.org/abs/2303.00747")
     parser.add_argument("--vad_onset", type=float, default=0.500, help="Onset threshold for VAD (see pyannote.audio), reduce this if speech is not being detected")
     parser.add_argument("--vad_offset", type=float, default=0.363, help="Offset threshold for VAD (see pyannote.audio), reduce this if speech is not being detected.")

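
Without `type=str2bool`, passing `--vad_filter False` would store the truthy string `"False"`. The helper itself is not shown in this hunk; a typical implementation (sketch only, the project may import an equivalent from elsewhere) looks like:

```python
import argparse

def str2bool(value: str) -> bool:
    # accept the usual command-line spellings of booleans
    if value.lower() in ("yes", "true", "t", "1"):
        return True
    if value.lower() in ("no", "false", "f", "0"):
        return False
    raise argparse.ArgumentTypeError(f"boolean value expected, got {value!r}")
```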

@@ -113,19 +114,6 @@ def cli():
     else:
         vad_model = None

-    if diarize:
-        if hf_token is None:
-            print("Warning, no --hf_token used, needs to be saved in environment variable, otherwise will throw error loading diarization model...")
-        diarize_model = DiarizationPipeline(use_auth_token=hf_token)
-    else:
-        diarize_model = None
-
-    if no_align:
-        align_model, align_metadata = None, None
-    else:
-        align_language = args["language"] if args["language"] is not None else "en" # default to loading english if not specified
-        align_model, align_metadata = load_align_model(align_language, device, model_name=align_model)
-
     # if model_flush:
     #     print(">>Model flushing activated... Only loading model after ASR stage")
     #     del align_model

@@ -150,9 +138,12 @@ def cli():

     from whisper import load_model

-    model = load_model(model_name, device=device, download_root=model_dir)
-
     writer = get_writer(output_format, output_dir)

+    # Part 1: VAD & ASR Loop
+    results = []
+    tmp_results = []
+    model = load_model(model_name, device=device, download_root=model_dir)
     for audio_path in args.pop("audio"):
         input_audio_path = audio_path
         tfile = None

@@ -161,7 +152,6 @@ def cli():
         if vad_model is not None:
             if not audio_path.endswith(".wav"):
                 print(">>VAD requires .wav format, converting to wav as a tempfile...")
-                # tfile = tempfile.NamedTemporaryFile(delete=True, suffix=".wav")
                 audio_basename = os.path.splitext(os.path.basename(audio_path))[0]
                 if tmp_dir is not None:
                     input_audio_path = os.path.join(tmp_dir, audio_basename + ".wav")

@@ -173,24 +163,53 @@ def cli():
         else:
             print(">>Performing transcription...")
             result = transcribe(model, input_audio_path, temperature=temperature, **args)

+        results.append((result, input_audio_path))

-        # >> Align
-        if align_model is not None and len(result["segments"]) > 0:
-            if result.get("language", "en") != align_metadata["language"]:
-                # load new language
-                print(f"New language found ({result['language']})! Previous was ({align_metadata['language']}), loading new alignment model for new language...")
-                align_model, align_metadata = load_align_model(result["language"], device)
-            print(">>Performing alignment...")
-            result = align(result["segments"], align_model, align_metadata, input_audio_path, device,
-                extend_duration=align_extend, start_from_previous=align_from_prev, interpolate_method=interpolate_method)
+    # Unload Whisper and VAD
+    del model
+    del vad_model
+    gc.collect()
+    torch.cuda.empty_cache()

-        # >> Diarize
-        if diarize_model is not None:
+    # Part 2: Align Loop
+    if not no_align:
+        tmp_results = results
+        results = []
+        align_language = args["language"] if args["language"] is not None else "en" # default to loading english if not specified
+        align_model, align_metadata = load_align_model(align_language, device, model_name=align_model)
+        for result, input_audio_path in tmp_results:
+            # >> Align
+            if align_model is not None and len(result["segments"]) > 0:
+                if result.get("language", "en") != align_metadata["language"]:
+                    # load new language
+                    print(f"New language found ({result['language']})! Previous was ({align_metadata['language']}), loading new alignment model for new language...")
+                    align_model, align_metadata = load_align_model(result["language"], device)
+                print(">>Performing alignment...")
+                result = align(result["segments"], align_model, align_metadata, input_audio_path, device,
+                    extend_duration=align_extend, start_from_previous=align_from_prev, interpolate_method=interpolate_method)
+            results.append((result, input_audio_path))
+
+        # Unload align model
+        del align_model
+        gc.collect()
+        torch.cuda.empty_cache()
+
+    # >> Diarize
+    if diarize:
+        if hf_token is None:
+            print("Warning, no --hf_token used, needs to be saved in environment variable, otherwise will throw error loading diarization model...")
+        tmp_results = results
+        results = []
+        diarize_model = DiarizationPipeline(use_auth_token=hf_token)
+        for result, input_audio_path in tmp_results:
             diarize_segments = diarize_model(input_audio_path, min_speakers=min_speakers, max_speakers=max_speakers)
             results_segments, word_segments = assign_word_speakers(diarize_segments, result["segments"])
             result = {"segments": results_segments, "word_segments": word_segments}
+            results.append((result, input_audio_path))

+    # >> Write
+    for result, audio_path in results:
         writer(result, audio_path)

     # cleanup
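
The restructured CLI now runs every input through one model before the next model is loaded, and frees GPU memory between stages with `del`, `gc.collect()`, and `torch.cuda.empty_cache()`. A stripped-down sketch of that staged pattern (not the project's actual code; names are simplified for illustration):

```python
import gc
import torch

def run_stage(load_model_fn, process_fn, items):
    """Run one pipeline stage over every input, then free its GPU memory."""
    model = load_model_fn()
    outputs = [process_fn(model, item) for item in items]
    del model
    gc.collect()
    torch.cuda.empty_cache()  # release cached CUDA memory before the next stage loads
    return outputs
```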

@@ -221,13 +221,13 @@ class WriteASS(ResultWriter):
     extension: str = "ass"

     def write_result(self, result: dict, file: TextIO):
-        write_ass(result["segments"], file, resoltuion="word")
+        write_ass(result["segments"], file, resolution="word")

 class WriteASSchar(ResultWriter):
     extension: str = "ass"

     def write_result(self, result: dict, file: TextIO):
-        write_ass(result["segments"], file, resoltuion="char")
+        write_ass(result["segments"], file, resolution="char")

 class WritePickle(ResultWriter):
     extension: str = "ass"
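
A hedged usage sketch of the corrected writers, based on the `get_writer(output_format, output_dir)` and `writer(result, audio_path)` calls visible in the CLI hunks above; the `"ass"` format string is an assumption, not quoted from this diff:

```python
# sketch: obtain a writer for word-level ASS subtitles and emit one result
writer = get_writer("ass", ".")   # output_format (assumed), output_dir
writer(result, "audio.mp3")       # ends up in write_ass(..., resolution="word")
```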

@@ -16,9 +16,11 @@ from typing import List, Tuple, Optional

 VAD_SEGMENTATION_URL = "https://whisperx.s3.eu-west-2.amazonaws.com/model_weights/segmentation/0b5b3216d60a2d32fc086b47ea8c67589aaeb26b7e07fcbe620d6d0b83e209ea/pytorch_model.bin"

-def load_vad_model(device, vad_onset, vad_offset, use_auth_token=None):
+def load_vad_model(device, vad_onset, vad_offset, use_auth_token=None, model_fp=None):
     model_dir = torch.hub._get_torch_home()
-    model_fp = os.path.join(model_dir, "whisperx-vad-segmentation.bin")
+    os.makedirs(model_dir, exist_ok = True)
+    if model_fp is None:
+        model_fp = os.path.join(model_dir, "whisperx-vad-segmentation.bin")
     if os.path.exists(model_fp) and not os.path.isfile(model_fp):
         raise RuntimeError(f"{model_fp} exists and is not a regular file")

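
The new `model_fp` argument lets callers point the VAD loader at a checkpoint they have already downloaded instead of the default location under the torch hub cache. A hedged usage sketch (the path is a placeholder):

```python
# sketch: load the VAD model from a pre-downloaded checkpoint
vad_model = load_vad_model(
    device="cuda",
    vad_onset=0.500,
    vad_offset=0.363,
    use_auth_token=None,
    model_fp="/models/whisperx-vad-segmentation.bin",  # placeholder; the new optional argument
)
```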

@@ -301,4 +303,4 @@ def merge_chunks(segments, chunk_size):
         "end": curr_end,
         "segments": seg_idxs,
     })
     return merged_segments