2 Commits

SHA1         Message                            Date
1843f3553a   Merge c72c627d10 into d700b56c9c   2025-06-17 19:02:36 +02:00
c72c627d10   add on_progress callback           2025-01-25 22:29:55 -03:00
2 changed files with 31 additions and 4 deletions

View File

@@ -5,7 +5,7 @@ C. Max Bain
 import math
 
 from dataclasses import dataclass
-from typing import Iterable, Optional, Union, List
+from typing import Iterable, Union, List, Callable, Optional
 
 import numpy as np
 import pandas as pd
@@ -120,6 +120,7 @@ def align(
     return_char_alignments: bool = False,
     print_progress: bool = False,
     combined_progress: bool = False,
+    on_progress: Callable[[int, int], None] = None
 ) -> AlignedTranscriptionResult:
     """
     Align phoneme recognition predictions to known transcription.
@@ -148,6 +149,9 @@ def align(
             base_progress = ((sdx + 1) / total_segments) * 100
             percent_complete = (50 + base_progress / 2) if combined_progress else base_progress
             print(f"Progress: {percent_complete:.2f}%...")
+
+        if on_progress:
+            on_progress(sdx + 1, total_segments)
 
         num_leading = len(segment["text"]) - len(segment["text"].lstrip())
         num_trailing = len(segment["text"]) - len(segment["text"].rstrip())
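
A minimal usage sketch of the new align() callback, not part of this diff: the
input path and the transcript_segments variable are placeholders, and the loader
calls follow the public whisperx API.

    import whisperx

    device = "cuda"
    audio = whisperx.load_audio("audio.wav")  # placeholder path

    # transcript_segments: the "segments" list produced by a prior transcribe() pass
    model_a, metadata = whisperx.load_align_model(language_code="en", device=device)

    def report(done: int, total: int) -> None:
        # Matches the Callable[[int, int], None] signature added above:
        # called once per segment with (segments_done, total_segments).
        print(f"aligned {done}/{total} segments")

    result = whisperx.align(
        transcript_segments,
        model_a,
        metadata,
        audio,
        device,
        on_progress=report,
    )

Unlike print_progress, the callback reports raw counts and leaves the formatting
(percentages, progress bars) to the caller.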

View File

@@ -1,6 +1,8 @@
 import os
-from typing import List, Optional, Union
 from dataclasses import replace
+import warnings
+from typing import List, Union, Optional, NamedTuple, Callable
+from enum import Enum
 
 import ctranslate2
 import faster_whisper
@@ -103,6 +105,12 @@ class FasterWhisperPipeline(Pipeline):
     # - add support for timestamp mode
     # - add support for custom inference kwargs
 
+    class TranscriptionState(Enum):
+        LOADING_AUDIO = "loading_audio"
+        GENERATING_VAD_SEGMENTS = "generating_vad_segments"
+        TRANSCRIBING = "transcribing"
+        FINISHED = "finished"
+
     def __init__(
         self,
         model: WhisperModel,
@@ -197,8 +205,12 @@ class FasterWhisperPipeline(Pipeline):
         print_progress=False,
         combined_progress=False,
         verbose=False,
+        on_progress: Callable[[TranscriptionState, Optional[int], Optional[int]], None] = None,
     ) -> TranscriptionResult:
         if isinstance(audio, str):
+            if on_progress:
+                on_progress(self.__class__.TranscriptionState.LOADING_AUDIO)
+
             audio = load_audio(audio)
 
         def data(audio, segments):
@@ -216,6 +228,8 @@ class FasterWhisperPipeline(Pipeline):
         else:
             waveform = Pyannote.preprocess_audio(audio)
             merge_chunks = Pyannote.merge_chunks
+        if on_progress:
+            on_progress(self.__class__.TranscriptionState.GENERATING_VAD_SEGMENTS)
 
         vad_segments = self.vad_model({"waveform": waveform, "sample_rate": SAMPLE_RATE})
         vad_segments = merge_chunks(
@@ -255,16 +269,22 @@ class FasterWhisperPipeline(Pipeline):
         segments: List[SingleSegment] = []
         batch_size = batch_size or self._batch_size
         total_segments = len(vad_segments)
+        if on_progress:
+            on_progress(self.__class__.TranscriptionState.TRANSCRIBING, 0, total_segments)
 
         for idx, out in enumerate(self.__call__(data(audio, vad_segments), batch_size=batch_size, num_workers=num_workers)):
             if print_progress:
                 base_progress = ((idx + 1) / total_segments) * 100
                 percent_complete = base_progress / 2 if combined_progress else base_progress
                 print(f"Progress: {percent_complete:.2f}%...")
+            if on_progress:
+                on_progress(self.__class__.TranscriptionState.TRANSCRIBING, idx + 1, total_segments)
 
             text = out['text']
             if batch_size in [0, 1, None]:
                 text = text[0]
+            if verbose:
+                print(f"Transcript: [{round(vad_segments[idx]['start'], 3)} --> {round(vad_segments[idx]['end'], 3)}] {text}")
             segments.append(
                 {
                     "text": text,
@@ -273,6 +293,9 @@
                 }
             )
 
+        if on_progress:
+            on_progress(self.__class__.TranscriptionState.FINISHED)
+
         # revert the tokenizer if multilingual inference is enabled
         if self.preset_language is None:
             self.tokenizer = None
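
A minimal usage sketch of the transcribe() callback, not part of this diff: the
model name and audio path are placeholders. The state-only events (LOADING_AUDIO,
GENERATING_VAD_SEGMENTS, FINISHED) are emitted with a single positional argument,
so the handler needs optional current/total parameters.

    from typing import Optional

    import whisperx
    from whisperx.asr import FasterWhisperPipeline

    State = FasterWhisperPipeline.TranscriptionState

    def report(state: State, current: Optional[int] = None,
               total: Optional[int] = None) -> None:
        # TRANSCRIBING events always carry (current, total); the others carry neither.
        if state is State.TRANSCRIBING and total:
            print(f"{state.value}: {current}/{total} ({100 * current / total:.0f}%)")
        else:
            print(state.value)

    model = whisperx.load_model("large-v2", device="cuda")  # a FasterWhisperPipeline
    result = model.transcribe("audio.wav", batch_size=16, on_progress=report)

Defaulting current and total to None keeps one handler compatible with every call
site in the diff, from the one-argument state notifications through the
three-argument TRANSCRIBING updates.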