add on_progress callback

Matheus Bach
2025-01-25 22:29:55 -03:00
parent 36d2622e27
commit c72c627d10
2 changed files with 31 additions and 4 deletions

whisperx/alignment.py

@@ -5,7 +5,7 @@ C. Max Bain
 import math
 from dataclasses import dataclass
-from typing import Iterable, Optional, Union, List
+from typing import Iterable, Union, List, Callable, Optional
 import numpy as np
 import pandas as pd
@@ -119,6 +119,7 @@ def align(
     return_char_alignments: bool = False,
     print_progress: bool = False,
     combined_progress: bool = False,
+    on_progress: Callable[[int, int], None] = None
 ) -> AlignedTranscriptionResult:
     """
     Align phoneme recognition predictions to known transcription.
@@ -147,6 +148,9 @@ def align(
             base_progress = ((sdx + 1) / total_segments) * 100
             percent_complete = (50 + base_progress / 2) if combined_progress else base_progress
             print(f"Progress: {percent_complete:.2f}%...")
+        if on_progress:
+            on_progress(sdx + 1, total_segments)
         num_leading = len(segment["text"]) - len(segment["text"].lstrip())
         num_trailing = len(segment["text"]) - len(segment["text"].rstrip())
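With this change a caller can observe alignment progress directly instead of parsing the printed percentages; the callback receives (sdx + 1, total_segments) once per segment. A minimal usage sketch, assuming the usual whisperx entry points (load_audio, load_model, load_align_model, align) and a hypothetical input file audio.wav:

import whisperx

def on_align_progress(done: int, total: int) -> None:
    # called once per segment with (sdx + 1, total_segments)
    print(f"aligned {done}/{total} segments")

device = "cpu"
audio = whisperx.load_audio("audio.wav")  # hypothetical input file
model = whisperx.load_model("small", device, compute_type="int8")
result = model.transcribe(audio)
align_model, metadata = whisperx.load_align_model(result["language"], device)
aligned = whisperx.align(
    result["segments"], align_model, metadata, audio, device,
    on_progress=on_align_progress,
)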

whisperx/asr.py

@@ -1,6 +1,8 @@
 import os
-from typing import List, Optional, Union
 from dataclasses import replace
+import warnings
+from typing import List, Union, Optional, NamedTuple, Callable
+from enum import Enum

 import ctranslate2
 import faster_whisper
@@ -101,6 +103,12 @@ class FasterWhisperPipeline(Pipeline):
     # - add support for timestamp mode
     # - add support for custom inference kwargs
+    class TranscriptionState(Enum):
+        LOADING_AUDIO = "loading_audio"
+        GENERATING_VAD_SEGMENTS = "generating_vad_segments"
+        TRANSCRIBING = "transcribing"
+        FINISHED = "finished"
+
     def __init__(
         self,
         model: WhisperModel,
@@ -195,8 +203,12 @@ class FasterWhisperPipeline(Pipeline):
         print_progress=False,
         combined_progress=False,
         verbose=False,
+        on_progress: Callable[[TranscriptionState, Optional[int], Optional[int]], None] = None,
     ) -> TranscriptionResult:
         if isinstance(audio, str):
+            if on_progress:
+                on_progress(self.__class__.TranscriptionState.LOADING_AUDIO)
             audio = load_audio(audio)

         def data(audio, segments):
@@ -214,6 +226,8 @@
         else:
             waveform = Pyannote.preprocess_audio(audio)
             merge_chunks = Pyannote.merge_chunks

+        if on_progress:
+            on_progress(self.__class__.TranscriptionState.GENERATING_VAD_SEGMENTS)
         vad_segments = self.vad_model({"waveform": waveform, "sample_rate": SAMPLE_RATE})
         vad_segments = merge_chunks(
@@ -253,16 +267,22 @@
         segments: List[SingleSegment] = []
         batch_size = batch_size or self._batch_size
         total_segments = len(vad_segments)
+        if on_progress:
+            on_progress(self.__class__.TranscriptionState.TRANSCRIBING, 0, total_segments)
         for idx, out in enumerate(self.__call__(data(audio, vad_segments), batch_size=batch_size, num_workers=num_workers)):
             if print_progress:
                 base_progress = ((idx + 1) / total_segments) * 100
                 percent_complete = base_progress / 2 if combined_progress else base_progress
                 print(f"Progress: {percent_complete:.2f}%...")
+            if on_progress:
+                on_progress(self.__class__.TranscriptionState.TRANSCRIBING, idx + 1, total_segments)
             text = out['text']
             if batch_size in [0, 1, None]:
                 text = text[0]
+            if verbose:
+                print(f"Transcript: [{round(vad_segments[idx]['start'], 3)} --> {round(vad_segments[idx]['end'], 3)}] {text}")
             segments.append(
                 {
                     "text": text,
@@ -271,6 +291,9 @@
                 }
             )

+        if on_progress:
+            on_progress(self.__class__.TranscriptionState.FINISHED)
+
         # revert the tokenizer if multilingual inference is enabled
         if self.preset_language is None:
             self.tokenizer = None
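
For transcribe(), the callback also carries the new TranscriptionState, and the pipeline invokes it with only the state argument for the LOADING_AUDIO, GENERATING_VAD_SEGMENTS, and FINISHED phases, so a compatible callback must give the count parameters defaults. A minimal sketch, reusing the model from the previous example (the batch_size value is illustrative):

from typing import Optional

def on_transcribe_progress(state, current: Optional[int] = None,
                           total: Optional[int] = None) -> None:
    # current/total are only supplied while state is TRANSCRIBING
    if current is not None:
        print(f"{state.value}: segment {current}/{total}")
    else:
        print(state.value)

# passing a path (rather than an ndarray) exercises the LOADING_AUDIO phase
result = model.transcribe("audio.wav", batch_size=8,
                          on_progress=on_transcribe_progress)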