mirror of https://github.com/m-bain/whisperX.git

add vad model external dl
@@ -1,10 +1,9 @@
 import warnings
 from typing import TYPE_CHECKING, Optional, Tuple, Union
-import tempfile
 import numpy as np
 import torch
 import tqdm
+import ffmpeg
 from whisper.audio import (
     FRAMES_PER_SECOND,
     HOP_LENGTH,
@@ -2,22 +2,22 @@ import argparse
 import os
 import warnings
 from typing import TYPE_CHECKING, Optional, Tuple, Union
 
 import numpy as np
 import torch
+import tempfile
+import ffmpeg
 from whisper.tokenizer import LANGUAGES, TO_LANGUAGE_CODE
+from whisper.audio import SAMPLE_RATE
 from whisper.utils import (
     optional_float,
     optional_int,
     str2bool,
 )
 
-from .utils import get_writer
 
-from .asr import transcribe, transcribe_with_vad
 from .alignment import load_align_model, align
+from .asr import transcribe, transcribe_with_vad
 from .diarize import DiarizationPipeline
+from .utils import get_writer
 from .vad import load_vad_model
 
 def cli():
@@ -74,7 +74,7 @@ def cli():
     parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
 
     parser.add_argument("--hf_token", type=str, default=None, help="Hugging Face Access Token to access PyAnnote gated models")
-    parser.add_argument("--model_flush", action="store_true", help="Flush memory of each stage after use, more GPU memory efficient, but slower when there are multiple audio files")
+    parser.add_argument("--model_flush", action="store_true", help="Flush memory from each model after use, reduces GPU requirement but slower processing >1 audio file.")
     # fmt: on
 
     args = parser.parse_args().__dict__
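The parse_args().__dict__ idiom in the last context line converts the argparse Namespace into a plain dict, so cli() can pop() the options it consumes (such as the audio paths and the hf_token) and forward whatever remains to transcribe() as **args. A minimal illustration of the pattern, with the argument set truncated to two entries for brevity:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("audio", nargs="+")
parser.add_argument("--hf_token", type=str, default=None)

# __dict__ turns the Namespace into a dict: consumed options are popped
# off, and the remainder can be forwarded as keyword arguments.
args = parser.parse_args(["a.mp3", "--hf_token", "TOKEN"]).__dict__
hf_token = args.pop("hf_token")
for audio_path in args.pop("audio"):
    print(audio_path, args)  # what is left would be passed along as **args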
@@ -148,8 +148,18 @@ def cli():
     for audio_path in args.pop("audio"):
 
         if vad_model is not None:
+            if not audio_path.endswith(".wav"):
+                print("VAD requires .wav format, converting to wav as a tempfile...")
+                tfile = tempfile.NamedTemporaryFile(delete=True, suffix=".wav")
+                ffmpeg.input(audio_path, threads=0).output(tfile.name, ac=1, ar=SAMPLE_RATE).run(cmd=["ffmpeg"])
+                vad_audio_path = tfile.name
+            else:
+                vad_audio_path = audio_path
             print("Performing VAD...")
-            result = transcribe_with_vad(model, audio_path, vad_model, temperature=temperature, **args)
+            result = transcribe_with_vad(model, vad_audio_path, vad_model, temperature=temperature, **args)
+
+            if tfile is not None:
+                tfile.close()
         else:
             print("Performing transcription...")
             result = transcribe(model, audio_path, temperature=temperature, **args)
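The new branch above converts any non-wav input to a temporary mono wav at Whisper's 16 kHz sample rate before running VAD, since the pyannote-based VAD pipeline reads wav audio from disk. A minimal sketch of the same conversion factored into a helper; as_wav_for_vad is a hypothetical name, not part of the commit, and returning the tempfile handle explicitly avoids depending on a tfile variable that the inline version only binds on the conversion path:

import tempfile

import ffmpeg  # ffmpeg-python bindings, as used in the diff
from whisper.audio import SAMPLE_RATE  # 16000 in OpenAI Whisper


def as_wav_for_vad(audio_path):
    """Return (wav_path, tempfile_or_None) suitable for the VAD stage."""
    if audio_path.endswith(".wav"):
        return audio_path, None
    tfile = tempfile.NamedTemporaryFile(delete=True, suffix=".wav")
    # ac=1 -> mono, ar=SAMPLE_RATE -> resample to 16 kHz; overwrite_output
    # because NamedTemporaryFile has already created the target file.
    (ffmpeg.input(audio_path, threads=0)
           .output(tfile.name, ac=1, ar=SAMPLE_RATE)
           .run(cmd=["ffmpeg"], overwrite_output=True))
    return tfile.name, tfile

The caller closes the returned handle once transcription is done, which deletes the temporary file (delete=True).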
@@ -1,19 +1,51 @@
 import os
+import urllib.request
 import pandas as pd
 import numpy as np
 import torch
+import hashlib
+from tqdm import tqdm
 from typing import Optional, Callable, Union, Text
 from pyannote.audio.core.io import AudioFile
 from pyannote.core import Annotation, Segment, SlidingWindowFeature
 from pyannote.audio.pipelines.utils import PipelineModel
-from pyannote.audio import Model, Pipeline
+from pyannote.audio import Model
 from pyannote.audio.pipelines import VoiceActivityDetection
 from .diarize import Segment as SegmentX
 
 from typing import List, Tuple, Optional
+
+VAD_SEGMENTATION_URL = "https://whisperx.s3.eu-west-2.amazonaws.com/model_weights/segmentation/0b5b3216d60a2d32fc086b47ea8c67589aaeb26b7e07fcbe620d6d0b83e209ea/pytorch_model.bin"
 
 def load_vad_model(device, vad_onset, vad_offset, use_auth_token=None):
-    vad_model = Model.from_pretrained("pyannote/segmentation", use_auth_token=use_auth_token)
+    model_dir = torch.hub._get_torch_home()
+    model_fp = os.path.join(model_dir, "whisperx-vad-segmentation.bin")
+    if os.path.exists(model_fp) and not os.path.isfile(model_fp):
+        raise RuntimeError(f"{model_fp} exists and is not a regular file")
+
+    if not os.path.isfile(model_fp):
+        with urllib.request.urlopen(VAD_SEGMENTATION_URL) as source, open(model_fp, "wb") as output:
+            with tqdm(
+                total=int(source.info().get("Content-Length")),
+                ncols=80,
+                unit="iB",
+                unit_scale=True,
+                unit_divisor=1024,
+            ) as loop:
+                while True:
+                    buffer = source.read(8192)
+                    if not buffer:
+                        break
+
+                    output.write(buffer)
+                    loop.update(len(buffer))
+
+    model_bytes = open(model_fp, "rb").read()
+    if hashlib.sha256(model_bytes).hexdigest() != VAD_SEGMENTATION_URL.split('/')[-2]:
+        raise RuntimeError(
+            "Model has been downloaded but the SHA256 checksum does not match. Please retry loading the model."
+        )
+
+    vad_model = Model.from_pretrained(model_fp, use_auth_token=use_auth_token)
     hyperparameters = {"onset": vad_onset,
                        "offset": vad_offset,
                        "min_duration_on": 0.1,
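Note the checksum convention in the new download path: the expected SHA-256 digest is embedded in VAD_SEGMENTATION_URL as the second-to-last path segment, which is why the code compares against VAD_SEGMENTATION_URL.split('/')[-2] after downloading; the same scheme and tqdm progress loop appear in OpenAI Whisper's model downloader. A minimal standalone sketch of the pattern; download_with_sha256 is a hypothetical helper, not part of the commit:

import hashlib
import os
import urllib.request

from tqdm import tqdm


def download_with_sha256(url, dest):
    """Download url to dest (if missing) and verify the SHA-256 digest
    embedded as the parent directory name in the URL."""
    expected = url.split("/")[-2]
    if not os.path.isfile(dest):
        with urllib.request.urlopen(url) as source, open(dest, "wb") as output:
            with tqdm(total=int(source.info().get("Content-Length")), ncols=80,
                      unit="iB", unit_scale=True, unit_divisor=1024) as bar:
                while True:
                    buffer = source.read(8192)
                    if not buffer:
                        break
                    output.write(buffer)
                    bar.update(len(buffer))
    with open(dest, "rb") as f:
        if hashlib.sha256(f.read()).hexdigest() != expected:
            raise RuntimeError("SHA256 checksum mismatch; delete the file and retry.")
    return dest

Hosting the weights on S3 this way lets load_vad_model point Model.from_pretrained at a local file instead of the gated pyannote/segmentation hub ID, so a Hugging Face token is no longer strictly required for the VAD stage.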