Custom result types

This commit is contained in:
Simon
2023-05-08 20:45:34 +02:00
parent b50aafb17b
commit eabf35dff0
3 changed files with 68 additions and 9 deletions

View File

@@ -11,7 +11,7 @@ from transformers.pipelines.pt_utils import PipelineIterator
from .audio import N_SAMPLES, SAMPLE_RATE, load_audio, log_mel_spectrogram
from .vad import load_vad_model, merge_chunks
+from .types import TranscriptionResult, SingleSegment
def load_model(whisper_arch, device, compute_type="float16", asr_options=None, language=None,
vad_options=None, model=None):
@@ -215,7 +215,7 @@ class FasterWhisperPipeline(Pipeline):
def transcribe(
self, audio: Union[str, np.ndarray], batch_size=None, num_workers=0
-):
+) -> TranscriptionResult:
if isinstance(audio, str):
audio = load_audio(audio)
@@ -237,7 +237,7 @@ class FasterWhisperPipeline(Pipeline):
else:
language = self.tokenizer.language_code
-segments = []
+segments: List[SingleSegment] = []
batch_size = batch_size or self._batch_size
for idx, out in enumerate(self.__call__(data(audio, vad_segments), batch_size=batch_size, num_workers=num_workers)):
text = out['text']