15 Commits

SHA1 Message Date
da458863d7 allow custom model_dir for torchaudio models 2023-04-14 21:40:36 +01:00
cf252a8592 allow custom path for vad model 2023-04-14 15:02:58 +01:00
6a72b61564 clamp end_timestamp to prevent infinite loop 2023-04-11 20:15:37 +01:00
48ed89834e Merge pull request #169 from invisprints/v2-opt-load-model
Optimize the inference process and reduce the memory usage
2023-04-09 13:39:13 +01:00
bb15c9428f opti the inference loop 2023-04-09 15:58:55 +08:00
9482d324d0 Merge pull request #162 from dev-nomi/cli_argument_type
Added vad_filter type
2023-04-05 13:40:04 -07:00
4146e56d5b Added vad_filter type 2023-04-05 17:11:29 +05:00
118e7deedb Merge pull request #161 from diasks2/fix_typo
Fix typo in utils.py
2023-04-04 19:00:18 -07:00
70a4a0a25c Fix typo 2023-04-05 10:50:49 +09:00
40948a3d00 fix whisper version to 20230314 for no breaking 2023-04-04 12:42:34 -07:00
c8be6ac94d update python example 2023-04-03 12:18:31 -07:00
a582a59493 mkdir for torch cache in case it doesnt exist 2023-04-01 13:05:40 -07:00
861379edc3 Merge pull request #157 from Ryan5453/fix/whisper-req
Fix Requirements
2023-03-31 16:40:19 -07:00
4af345434a Update requirements.txt 2023-03-31 19:36:38 -04:00
634799b3be hf token only for diarization 2023-03-31 16:15:40 -07:00
7 changed files with 68 additions and 45 deletions

View File

@@ -85,8 +85,8 @@ Safest to use install pytorch as follows (for gpu)
`
### Voice Activity Detection Filtering & Diarization
To **enable VAD filtering and Diarization**, include your Hugging Face access token that you can generate from [Here](https://huggingface.co/settings/tokens) after the `--hf_token` argument and accept the user agreement for the following models: [Segmentation](https://huggingface.co/pyannote/segmentation) , [Voice Activity Detection (VAD)](https://huggingface.co/pyannote/voice-activity-detection) , and [Speaker Diarization](https://huggingface.co/pyannote/speaker-diarization)
### Speaker Diarization
To **enable Speaker Diarization**, include your Hugging Face access token that you can generate from [Here](https://huggingface.co/settings/tokens) after the `--hf_token` argument and accept the user agreement for the following models: [Segmentation](https://huggingface.co/pyannote/segmentation) , [Voice Activity Detection (VAD)](https://huggingface.co/pyannote/voice-activity-detection) , and [Speaker Diarization](https://huggingface.co/pyannote/speaker-diarization)
<h2 align="left" id="example">Usage 💬 (command line)</h2>
@@ -130,12 +130,13 @@ See more examples in other languages [here](EXAMPLES.md).
```python
import whisperx
import whisper
device = "cuda"
audio_file = "audio.mp3"
# transcribe with original whisper
model = whisperx.load_model("large", device)
model = whisper.load_model("large", device)
result = model.transcribe(audio_file)
print(result["segments"]) # before alignment
@@ -157,9 +158,6 @@ In addition to forced alignment, the following two modifications have been made
1. `--condition_on_prev_text` is set to `False` by default (reduces hallucination)
2. Clamping segment `end_time` to be at least 0.02s (one time precision) later than `start_time` (prevents segments with negative duration)
<h2 align="left" id="limitations">Limitations ⚠️</h2>
- Whisper normalises spoken numbers e.g. "fifty seven" to arabic numerals "57". Need to perform this normalization after alignment, so the phonemes can be aligned. Currently just ignores numbers.
@@ -255,4 +253,4 @@ as well the following works, used in each stage of the pipeline:
year={2020},
organization={IEEE}
}
```
```

View File

@@ -7,4 +7,4 @@ more-itertools
transformers>=4.19.0
ffmpeg-python==0.2.0
pyannote.audio
whisper
openai-whisper==20230314
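The pin fixes `openai-whisper` to the 20230314 release so upstream changes cannot break whisperx (per commit 40948a3d00). A small, hedged check that the pinned release is what is actually installed (standard library only; the distribution name is taken from the line above):

```python
# Hedged sketch: confirm the installed openai-whisper matches the pin above.
from importlib.metadata import PackageNotFoundError, version

try:
    print(version("openai-whisper"))  # expected: 20230314 per requirements.txt
except PackageNotFoundError:
    print("openai-whisper is not installed")
```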

View File

@@ -40,7 +40,7 @@ DEFAULT_ALIGN_MODELS_HF = {
}
def load_align_model(language_code, device, model_name=None):
def load_align_model(language_code, device, model_name=None, model_dir=None):
if model_name is None:
# use default model
if language_code in DEFAULT_ALIGN_MODELS_TORCH:
@@ -55,7 +55,7 @@ def load_align_model(language_code, device, model_name=None):
if model_name in torchaudio.pipelines.__all__:
pipeline_type = "torchaudio"
bundle = torchaudio.pipelines.__dict__[model_name]
align_model = bundle.get_model().to(device)
align_model = bundle.get_model(dl_kwargs={"model_dir": model_dir}).to(device)
labels = bundle.get_labels()
align_dictionary = {c.lower(): i for i, c in enumerate(labels)}
else:
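A hedged usage sketch of the new `model_dir` parameter: for torchaudio pipeline models it is forwarded as `dl_kwargs={"model_dir": model_dir}`, so the bundle weights are downloaded to (or read from) that directory instead of the default torch hub cache. The import path and directory below are assumptions:

```python
# Minimal usage sketch (assumed import path; placeholder cache directory).
from whisperx import load_align_model

align_model, align_metadata = load_align_model(
    "en",                            # language_code
    "cuda",                          # device
    model_dir="/data/align_models",  # custom torchaudio download/cache location
)
```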

View File

@@ -269,6 +269,10 @@ def transcribe(
end_timestamp_pos = (
sliced_tokens[-1].item() - tokenizer.timestamp_begin
)
# clamp end-time to at least be 1 frame after start-time
end_timestamp_pos = max(end_timestamp_pos, start_timestamp_pos + time_precision)
current_segments.append(
new_segment(
start=time_offset + start_timestamp_pos * time_precision,
@@ -426,4 +430,4 @@ def transcribe_with_vad(
output["language"] = output["segments"][0]["language"]
return output
return output
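The clamp added in the first hunk above operates on timestamp-token positions, before they are multiplied by `time_precision`; a minimal sketch (placeholder values, not the repository's code) of the guarantee it provides:

```python
# Minimal sketch: the clamp keeps every segment's end strictly after its start,
# so a degenerate timestamp prediction can no longer yield a zero- or
# negative-duration segment (which could stall the seek loop).
time_precision = 0.02     # seconds per Whisper timestamp token (assumed)
time_offset = 30.0        # absolute offset of the current window (placeholder)

start_timestamp_pos = 150
end_timestamp_pos = 150   # degenerate prediction: same position as the start

end_timestamp_pos = max(end_timestamp_pos, start_timestamp_pos + time_precision)

start = time_offset + start_timestamp_pos * time_precision
end = time_offset + end_timestamp_pos * time_precision
assert end > start        # positive duration is now guaranteed
```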

View File

@@ -1,5 +1,6 @@
import argparse
import os
import gc
import warnings
from typing import TYPE_CHECKING, Optional, Tuple, Union
import numpy as np
@@ -44,7 +45,7 @@ def cli():
parser.add_argument("--no_align", action='store_true', help="Do not perform phoneme alignment")
# vad params
parser.add_argument("--vad_filter", default=True, help="Whether to pre-segment audio with VAD, highly recommended! Produces more accurate alignment + timestamp see WhisperX paper https://arxiv.org/abs/2303.00747")
parser.add_argument("--vad_filter", type=str2bool, default=True, help="Whether to pre-segment audio with VAD, highly recommended! Produces more accurate alignment + timestamp see WhisperX paper https://arxiv.org/abs/2303.00747")
parser.add_argument("--vad_onset", type=float, default=0.500, help="Onset threshold for VAD (see pyannote.audio), reduce this if speech is not being detected")
parser.add_argument("--vad_offset", type=float, default=0.363, help="Offset threshold for VAD (see pyannote.audio), reduce this if speech is not being detected.")
@@ -113,19 +114,6 @@ def cli():
else:
vad_model = None
if diarize:
if hf_token is None:
print("Warning, no --hf_token used, needs to be saved in environment variable, otherwise will throw error loading diarization model...")
diarize_model = DiarizationPipeline(use_auth_token=hf_token)
else:
diarize_model = None
if no_align:
align_model, align_metadata = None, None
else:
align_language = args["language"] if args["language"] is not None else "en" # default to loading english if not specified
align_model, align_metadata = load_align_model(align_language, device, model_name=align_model)
# if model_flush:
# print(">>Model flushing activated... Only loading model after ASR stage")
# del align_model
@@ -150,9 +138,12 @@ def cli():
from whisper import load_model
model = load_model(model_name, device=device, download_root=model_dir)
writer = get_writer(output_format, output_dir)
# Part 1: VAD & ASR Loop
results = []
tmp_results = []
model = load_model(model_name, device=device, download_root=model_dir)
for audio_path in args.pop("audio"):
input_audio_path = audio_path
tfile = None
@@ -161,7 +152,6 @@
if vad_model is not None:
if not audio_path.endswith(".wav"):
print(">>VAD requires .wav format, converting to wav as a tempfile...")
# tfile = tempfile.NamedTemporaryFile(delete=True, suffix=".wav")
audio_basename = os.path.splitext(os.path.basename(audio_path))[0]
if tmp_dir is not None:
input_audio_path = os.path.join(tmp_dir, audio_basename + ".wav")
@@ -173,24 +163,53 @@
else:
print(">>Performing transcription...")
result = transcribe(model, input_audio_path, temperature=temperature, **args)
results.append((result, input_audio_path))
# >> Align
if align_model is not None and len(result["segments"]) > 0:
if result.get("language", "en") != align_metadata["language"]:
# load new language
print(f"New language found ({result['language']})! Previous was ({align_metadata['language']}), loading new alignment model for new language...")
align_model, align_metadata = load_align_model(result["language"], device)
print(">>Performing alignment...")
result = align(result["segments"], align_model, align_metadata, input_audio_path, device,
extend_duration=align_extend, start_from_previous=align_from_prev, interpolate_method=interpolate_method)
# Unload Whisper and VAD
del model
del vad_model
gc.collect()
torch.cuda.empty_cache()
# >> Diarize
if diarize_model is not None:
# Part 2: Align Loop
if not no_align:
tmp_results = results
results = []
align_language = args["language"] if args["language"] is not None else "en" # default to loading english if not specified
align_model, align_metadata = load_align_model(align_language, device, model_name=align_model)
for result, input_audio_path in tmp_results:
# >> Align
if align_model is not None and len(result["segments"]) > 0:
if result.get("language", "en") != align_metadata["language"]:
# load new language
print(f"New language found ({result['language']})! Previous was ({align_metadata['language']}), loading new alignment model for new language...")
align_model, align_metadata = load_align_model(result["language"], device)
print(">>Performing alignment...")
result = align(result["segments"], align_model, align_metadata, input_audio_path, device,
extend_duration=align_extend, start_from_previous=align_from_prev, interpolate_method=interpolate_method)
results.append((result, input_audio_path))
# Unload align model
del align_model
gc.collect()
torch.cuda.empty_cache()
# >> Diarize
if diarize:
if hf_token is None:
print("Warning, no --hf_token used, needs to be saved in environment variable, otherwise will throw error loading diarization model...")
tmp_results = results
results = []
diarize_model = DiarizationPipeline(use_auth_token=hf_token)
for result, input_audio_path in tmp_results:
diarize_segments = diarize_model(input_audio_path, min_speakers=min_speakers, max_speakers=max_speakers)
results_segments, word_segments = assign_word_speakers(diarize_segments, result["segments"])
result = {"segments": results_segments, "word_segments": word_segments}
results.append((result, input_audio_path))
# >> Write
for result, audio_path in results:
writer(result, audio_path)
# cleanup
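The restructuring above is the memory optimisation from #169: rather than keeping Whisper, the alignment model, and the diarization pipeline resident at once, the CLI now runs one stage over every input and frees that model before loading the next. A minimal sketch of the pattern (generic names, not the repository's code):

```python
# Minimal sketch: load one model at a time, run it over all inputs, then release
# it so peak (GPU) memory stays at roughly a single model per stage.
import gc
import torch

def run_stage(load_model, step, inputs):
    model = load_model()
    outputs = [step(model, item) for item in inputs]
    del model
    gc.collect()
    torch.cuda.empty_cache()  # no-op if CUDA was never initialised
    return outputs

# Usage sketch (hypothetical loaders/steps):
# transcripts = run_stage(load_asr_model,   transcribe_one, audio_paths)
# aligned     = run_stage(load_align_stage, align_one,      transcripts)
# diarized    = run_stage(load_diarizer,    diarize_one,    aligned)
```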

View File

@@ -221,13 +221,13 @@ class WriteASS(ResultWriter):
extension: str = "ass"
def write_result(self, result: dict, file: TextIO):
write_ass(result["segments"], file, resoltuion="word")
write_ass(result["segments"], file, resolution="word")
class WriteASSchar(ResultWriter):
extension: str = "ass"
def write_result(self, result: dict, file: TextIO):
write_ass(result["segments"], file, resoltuion="char")
write_ass(result["segments"], file, resolution="char")
class WritePickle(ResultWriter):
extension: str = "ass"

View File

@@ -16,9 +16,11 @@ from typing import List, Tuple, Optional
VAD_SEGMENTATION_URL = "https://whisperx.s3.eu-west-2.amazonaws.com/model_weights/segmentation/0b5b3216d60a2d32fc086b47ea8c67589aaeb26b7e07fcbe620d6d0b83e209ea/pytorch_model.bin"
def load_vad_model(device, vad_onset, vad_offset, use_auth_token=None):
def load_vad_model(device, vad_onset, vad_offset, use_auth_token=None, model_fp=None):
model_dir = torch.hub._get_torch_home()
model_fp = os.path.join(model_dir, "whisperx-vad-segmentation.bin")
os.makedirs(model_dir, exist_ok = True)
if model_fp is None:
model_fp = os.path.join(model_dir, "whisperx-vad-segmentation.bin")
if os.path.exists(model_fp) and not os.path.isfile(model_fp):
raise RuntimeError(f"{model_fp} exists and is not a regular file")
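With the new `model_fp` parameter the VAD checkpoint no longer has to live in the torch hub cache; when it is `None` the old default path is used, and the existing file checks still apply to whatever path is passed. A hedged usage sketch (placeholder path and example thresholds):

```python
# Hedged usage sketch: load the VAD model from a pre-downloaded checkpoint.
vad_model = load_vad_model(
    device="cuda",
    vad_onset=0.500,
    vad_offset=0.363,
    model_fp="/models/whisperx-vad-segmentation.bin",  # custom path; None falls back to the torch hub cache
)
```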
@@ -301,4 +303,4 @@ def merge_chunks(segments, chunk_size):
"end": curr_end,
"segments": seg_idxs,
})
return merged_segments
return merged_segments