diff --git a/whisperx/asr.py b/whisperx/asr.py
index bef3cd8..da30774 100644
--- a/whisperx/asr.py
+++ b/whisperx/asr.py
@@ -247,7 +247,7 @@ class FasterWhisperPipeline(Pipeline):
         return final_iterator

     def transcribe(
-        self, audio: Union[str, np.ndarray], batch_size=None, num_workers=0, language=None, task=None, print_progress = False, combined_progress=False
+        self, audio: Union[str, np.ndarray], batch_size=None, num_workers=0, language=None, task=None, chunk_size=30, print_progress = False, combined_progress=False
     ) -> TranscriptionResult:
         if isinstance(audio, str):
             audio = load_audio(audio)
@@ -260,7 +260,7 @@ class FasterWhisperPipeline(Pipeline):
                 yield {'inputs': audio[f1:f2]}

         vad_segments = self.vad_model({"waveform": torch.from_numpy(audio).unsqueeze(0), "sample_rate": SAMPLE_RATE})
-        vad_segments = merge_chunks(vad_segments, 30)
+        vad_segments = merge_chunks(vad_segments, chunk_size)
         if self.tokenizer is None:
             language = language or self.detect_language(audio)
             task = task or "transcribe"
diff --git a/whisperx/transcribe.py b/whisperx/transcribe.py
index 49788bd..6a2dcb6 100644
--- a/whisperx/transcribe.py
+++ b/whisperx/transcribe.py
@@ -41,6 +41,7 @@ def cli():
     # vad params
     parser.add_argument("--vad_onset", type=float, default=0.500, help="Onset threshold for VAD (see pyannote.audio), reduce this if speech is not being detected")
     parser.add_argument("--vad_offset", type=float, default=0.363, help="Offset threshold for VAD (see pyannote.audio), reduce this if speech is not being detected.")
+    parser.add_argument("--chunk_size", type=int, default=30, help="Chunk size for merging VAD segments. Default is 30, reduce this if the chunk is too long.")

     # diarization params
     parser.add_argument("--diarize", action="store_true", help="Apply diarization to assign speaker labels to each segment/word")
@@ -103,6 +104,8 @@ def cli():
     vad_onset: float = args.pop("vad_onset")
     vad_offset: float = args.pop("vad_offset")

+    chunk_size: int = args.pop("chunk_size")
+
     diarize: bool = args.pop("diarize")
     min_speakers: int = args.pop("min_speakers")
     max_speakers: int = args.pop("max_speakers")
@@ -159,7 +162,7 @@ def cli():
         audio = load_audio(audio_path)
         # >> VAD & ASR
         print(">>Performing transcription...")
-        result = model.transcribe(audio, batch_size=batch_size, print_progress=print_progress)
+        result = model.transcribe(audio, batch_size=batch_size, chunk_size=chunk_size, print_progress=print_progress)
         results.append((result, audio_path))

         # Unload Whisper and VAD
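
For reference, a minimal sketch of how the new parameter would be used from Python, assuming the `whisperx.load_model` / `whisperx.load_audio` entry points from the project's README (the model name, device, and file path below are placeholder values, not part of this diff):

```python
import whisperx

# Load a faster-whisper backed pipeline; device/compute_type are illustrative.
model = whisperx.load_model("large-v2", device="cuda", compute_type="float16")

# load_audio returns the mono 16 kHz waveform the pipeline expects.
audio = whisperx.load_audio("example.wav")

# chunk_size caps the duration (in seconds) of merged VAD segments; the default
# of 30 matches the previously hard-coded merge_chunks(vad_segments, 30).
result = model.transcribe(audio, batch_size=16, chunk_size=15)
```

The equivalent CLI invocation would be `whisperx example.wav --chunk_size 15`; leaving the flag at its default of 30 keeps the previous behaviour.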