diff --git a/whisperx/asr.py b/whisperx/asr.py
index 88d5bf6..2fab8bc 100644
--- a/whisperx/asr.py
+++ b/whisperx/asr.py
@@ -214,7 +214,7 @@ class FasterWhisperPipeline(Pipeline):
         return final_iterator
 
     def transcribe(
-        self, audio: Union[str, np.ndarray], batch_size=None, num_workers=0
+        self, audio: Union[str, np.ndarray], batch_size=None, num_workers=0, language=None, task=None
     ) -> TranscriptionResult:
         if isinstance(audio, str):
             audio = load_audio(audio)
@@ -229,13 +229,12 @@ class FasterWhisperPipeline(Pipeline):
         vad_segments = self.vad_model({"waveform": torch.from_numpy(audio).unsqueeze(0), "sample_rate": SAMPLE_RATE})
         vad_segments = merge_chunks(vad_segments, 30)
 
-        del_tokenizer = False
-        if self.tokenizer is None:
-            language = self.detect_language(audio)
-            self.tokenizer = faster_whisper.tokenizer.Tokenizer(self.model.hf_tokenizer, self.model.model.is_multilingual, task="transcribe", language=language)
-            del_tokenizer = True
-        else:
-            language = self.tokenizer.language_code
+        language = language or self.tokenizer.language_code
+        task = task or self.tokenizer.task
+        if task != self.tokenizer.task or language != self.tokenizer.language_code:
+            self.tokenizer = faster_whisper.tokenizer.Tokenizer(self.model.hf_tokenizer,
+                                                                self.model.model.is_multilingual, task=task,
+                                                                language=language)
 
         segments: List[SingleSegment] = []
         batch_size = batch_size or self._batch_size
@@ -250,9 +249,6 @@ class FasterWhisperPipeline(Pipeline):
                     "end": round(vad_segments[idx]['end'], 3)
                 }
             )
-
-        if del_tokenizer:
-            self.tokenizer = None
 
         return {"segments": segments, "language": language}
 
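Below is a minimal usage sketch of the extended `transcribe()` signature. It assumes the pipeline is created through `whisperx.load_model` with an explicit language, so `self.tokenizer` already exists when the new `language`/`task` overrides are evaluated; the model name, device, and audio path are placeholders, not part of this change.

```python
import whisperx

# Placeholder model/device/path; loading with an explicit language ensures the
# pipeline already has a tokenizer, which the override path in this patch relies on.
model = whisperx.load_model("large-v2", device="cuda", language="en")
audio = whisperx.load_audio("example.wav")

# Per-call overrides: the tokenizer is rebuilt only when the requested task or
# language differs from the current tokenizer's settings.
result = model.transcribe(audio, batch_size=8, language="de", task="translate")
print(result["language"], len(result["segments"]))
```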