Merge pull request #40 from MahmoudAshraf97/main

Added arguments and instructions to enable the usage of VAD and Diarization
m-bain committed on 2023-01-26 00:34:03 +00:00 (via GitHub)
2 changed files with 11 additions and 5 deletions

README.md

@@ -54,6 +54,8 @@ This repository refines the timestamps of openAI's Whisper model via forced alignment
 - Character level timestamps (see `*.char.ass` file output)
 - Diarization (still in beta, add `--diarization`)
+
+To enable VAD filtering and Diarization, include your Hugging Face access token (which you can generate [here](https://huggingface.co/settings/tokens)) after the `--hf_token` argument and accept the user agreement for the following models: [Segmentation](https://huggingface.co/pyannote/segmentation), [Voice Activity Detection (VAD)](https://huggingface.co/pyannote/voice-activity-detection), and [Speaker Diarization](https://huggingface.co/pyannote/speaker-diarization)

 <h2 align="left" id="setup">Setup ⚙️</h2>
 Install this package using
@@ -85,7 +87,7 @@ Run whisper on example segment (using default params)
 For increased timestamp accuracy, at the cost of higher gpu mem, use bigger models and VAD filtering e.g.
-whisperx examples/sample01.wav --model large.en --vad_filter --align_model WAV2VEC2_ASR_LARGE_LV60K_960H
+whisperx examples/sample01.wav --model large-v2 --vad_filter --align_model WAV2VEC2_ASR_LARGE_LV60K_960H
 Result using *WhisperX* with forced alignment to wav2vec2.0 large:
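
As a concrete illustration of the README instructions above, a run that exercises the new flag could look like the following (the token value is a placeholder you generate on Hugging Face):

whisperx examples/sample01.wav --model large-v2 --vad_filter --hf_token YOUR_HF_TOKEN

Adding the diarization flag from the feature list above would additionally attach speaker labels; both features only work once the three pyannote user agreements have been accepted.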

whisperx/transcribe.py

@@ -385,7 +385,8 @@ def cli():
     parser.add_argument("--logprob_threshold", type=optional_float, default=-1.0, help="if the average log probability is lower than this value, treat the decoding as failed")
     parser.add_argument("--no_speech_threshold", type=optional_float, default=0.6, help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence")
     parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
+    parser.add_argument("--hf_token", type=str, default=None, help="Hugging Face Access Token to access PyAnnote gated models")

     args = parser.parse_args().__dict__
     model_name: str = args.pop("model")
     model_dir: str = args.pop("model_dir")
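
For readers unfamiliar with the whisper-style CLI plumbing, the new argument flows through argparse into a plain dict and is popped out before the remaining options are forwarded to the transcription call. A minimal sketch under that assumption (the token string is a placeholder):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--hf_token", type=str, default=None,
                    help="Hugging Face Access Token to access PyAnnote gated models")

# Mirrors `args = parser.parse_args().__dict__` followed by the pop added in the next hunk.
args = parser.parse_args(["--hf_token", "hf_xxx"]).__dict__  # placeholder token
hf_token: str = args.pop("hf_token")
print(hf_token)  # -> hf_xxx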
@@ -397,7 +398,8 @@ def cli():
     align_extend: float = args.pop("align_extend")
     align_from_prev: bool = args.pop("align_from_prev")
     interpolate_method: bool = args.pop("interpolate_method")
+    hf_token: str = args.pop("hf_token")

     vad_filter: bool = args.pop("vad_filter")
     vad_input: bool = args.pop("vad_input")
@@ -410,12 +412,14 @@ def cli():
         vad_input = pd.read_csv(vad_input, header=None, sep= " ")
     elif vad_filter:
         from pyannote.audio import Pipeline
-        vad_pipeline = Pipeline.from_pretrained("pyannote/voice-activity-detection")
+        vad_pipeline = Pipeline.from_pretrained("pyannote/voice-activity-detection",
+                                                use_auth_token=hf_token)

     diarize_pipeline = None
     if diarize:
         from pyannote.audio import Pipeline
-        diarize_pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization@2.1")
+        diarize_pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization@2.1",
+                                                    use_auth_token=hf_token)

     os.makedirs(output_dir, exist_ok=True)
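
Taken together, the token plumbing this PR introduces boils down to the pattern below. This is a minimal standalone sketch, not the full whisperx CLI; the token string and audio path are placeholders, and the final two calls assume the standard pyannote.audio usage of invoking a pipeline directly on an audio file:

from pyannote.audio import Pipeline

hf_token = "hf_xxx"  # placeholder; generate at https://huggingface.co/settings/tokens

# Both models are gated on the Hub, so the same token is forwarded to each download.
vad_pipeline = Pipeline.from_pretrained("pyannote/voice-activity-detection",
                                        use_auth_token=hf_token)
diarize_pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization@2.1",
                                            use_auth_token=hf_token)

# Assumed usage: a pyannote pipeline is callable on an audio file and returns
# an Annotation (speech regions for VAD, speaker turns for diarization).
speech_regions = vad_pipeline("examples/sample01.wav")
speaker_turns = diarize_pipeline("examples/sample01.wav")

Keeping `default=None` for `--hf_token` preserves the previous behaviour for users who never enable VAD filtering or diarization, since the token is only consumed inside those two branches.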