Merge pull request #445 from jim60105/add-merge-chunk-size-as-argument

feat: Add merge chunks chunk_size as arguments.
Max Bain committed 2023-08-29 10:05:14 -06:00 (committed by GitHub)
2 changed files with 6 additions and 3 deletions


@@ -247,7 +247,7 @@ class FasterWhisperPipeline(Pipeline):
return final_iterator
def transcribe(
- self, audio: Union[str, np.ndarray], batch_size=None, num_workers=0, language=None, task=None, print_progress = False, combined_progress=False
+ self, audio: Union[str, np.ndarray], batch_size=None, num_workers=0, language=None, task=None, chunk_size=30, print_progress = False, combined_progress=False
) -> TranscriptionResult:
if isinstance(audio, str):
audio = load_audio(audio)
@@ -260,7 +260,7 @@ class FasterWhisperPipeline(Pipeline):
yield {'inputs': audio[f1:f2]}
vad_segments = self.vad_model({"waveform": torch.from_numpy(audio).unsqueeze(0), "sample_rate": SAMPLE_RATE})
- vad_segments = merge_chunks(vad_segments, 30)
+ vad_segments = merge_chunks(vad_segments, chunk_size)
if self.tokenizer is None:
language = language or self.detect_language(audio)
task = task or "transcribe"

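For context on what the new parameter controls: merge_chunks groups consecutive VAD speech segments into windows that are transcribed in one pass, and chunk_size caps the length of each merged window in seconds (previously hard-coded to 30). The sketch below is a simplified illustration of that idea, not the library's actual implementation; the segment dicts with "start"/"end" keys and the helper name are assumptions for the example.

```python
# Simplified illustration (assumption): merge consecutive VAD segments into
# windows no longer than chunk_size seconds. whisperx's real merge_chunks
# tracks additional per-segment metadata and edge cases not shown here.
def merge_chunks_sketch(segments, chunk_size=30):
    merged = []
    current = None
    for seg in segments:  # each seg assumed to be {"start": float, "end": float}
        if current is None:
            current = {"start": seg["start"], "end": seg["end"]}
        elif seg["end"] - current["start"] <= chunk_size:
            current["end"] = seg["end"]          # extend the open window
        else:
            merged.append(current)               # close it, start a new window
            current = {"start": seg["start"], "end": seg["end"]}
    if current is not None:
        merged.append(current)
    return merged

# With chunk_size=30 the three segments collapse into two windows;
# a smaller chunk_size produces shorter, more numerous windows.
print(merge_chunks_sketch([{"start": 0.0, "end": 12.0},
                           {"start": 13.0, "end": 28.0},
                           {"start": 29.0, "end": 45.0}], chunk_size=30))
```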

@@ -41,6 +41,7 @@ def cli():
# vad params
parser.add_argument("--vad_onset", type=float, default=0.500, help="Onset threshold for VAD (see pyannote.audio), reduce this if speech is not being detected")
parser.add_argument("--vad_offset", type=float, default=0.363, help="Offset threshold for VAD (see pyannote.audio), reduce this if speech is not being detected.")
parser.add_argument("--chunk_size", type=int, default=30, help="Chunk size for merging VAD segments. Default is 30, reduce this if the chunk is too long.")
# diarization params
parser.add_argument("--diarize", action="store_true", help="Apply diarization to assign speaker labels to each segment/word")
@@ -103,6 +104,8 @@ def cli():
vad_onset: float = args.pop("vad_onset")
vad_offset: float = args.pop("vad_offset")
chunk_size: int = args.pop("chunk_size")
diarize: bool = args.pop("diarize")
min_speakers: int = args.pop("min_speakers")
max_speakers: int = args.pop("max_speakers")
@@ -159,7 +162,7 @@ def cli():
audio = load_audio(audio_path)
# >> VAD & ASR
print(">>Performing transcription...")
- result = model.transcribe(audio, batch_size=batch_size, print_progress=print_progress)
+ result = model.transcribe(audio, batch_size=batch_size, chunk_size=chunk_size, print_progress=print_progress)
results.append((result, audio_path))
# Unload Whisper and VAD
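Taken together, the change also lets library users pass chunk_size straight through transcribe() instead of relying on the hard-coded 30-second window. A minimal sketch, assuming whisperx's public load_model/load_audio helpers; the file name, device, and model size are placeholders:

```python
import whisperx

# Minimal sketch; "audio.wav", the device, and the model size are placeholders.
device = "cuda"
model = whisperx.load_model("large-v2", device, compute_type="float16")

audio = whisperx.load_audio("audio.wav")

# New in this change: chunk_size caps the merged VAD window in seconds.
result = model.transcribe(audio, batch_size=16, chunk_size=20, print_progress=True)

for segment in result["segments"]:
    print(segment["start"], segment["end"], segment["text"])
```

Smaller chunk_size values yield shorter transcription windows (and shorter segments), at the cost of more batches per file.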