Mirror of https://github.com/m-bain/whisperX.git
--- a/whisperx/alignment.py
+++ b/whisperx/alignment.py
@@ -50,6 +50,7 @@ DEFAULT_ALIGN_MODELS_HF = {
     "ko": "kresnik/wav2vec2-large-xlsr-korean",
     "ur": "kingabzpro/wav2vec2-large-xls-r-300m-Urdu",
     "te": "anuragshas/wav2vec2-large-xlsr-53-telugu",
+    "hi": "theainerd/Wav2Vec2-large-xlsr-hindi"
 }
 
 
@@ -97,6 +98,9 @@ def align(
     device: str,
     interpolate_method: str = "nearest",
     return_char_alignments: bool = False,
+    print_progress: bool = False,
+    combined_progress: bool = False,
+    total_segments: int = 0
 ) -> AlignedTranscriptionResult:
     """
     Align phoneme recognition predictions to known transcription.
@@ -118,6 +122,11 @@ def align(
     # 1. Preprocess to keep only characters in dictionary
     for sdx, segment in enumerate(transcript):
         # strip spaces at beginning / end, but keep track of the amount.
+        if print_progress:
+            base_progress = ((sdx + 1) / total_segments) * 100
+            percent_complete = (50 + base_progress / 2) if combined_progress else base_progress
+            print(f"Progress: {percent_complete:.2f}%...")
+
         num_leading = len(segment["text"]) - len(segment["text"].lstrip())
         num_trailing = len(segment["text"]) - len(segment["text"].rstrip())
         text = segment["text"]
@@ -161,9 +170,10 @@ def align(
         segment["sentence_spans"] = sentence_spans
 
     aligned_segments: List[SingleAlignedSegment] = []
 
     # 2. Get prediction matrix from alignment model & align
     for sdx, segment in enumerate(transcript):
+
         t1 = segment["start"]
         t2 = segment["end"]
         text = segment["text"]
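A minimal sketch of the progress arithmetic these alignment hunks introduce (the helper name percent_complete and the example numbers are illustrative, not part of the commit): with combined_progress set, align() assumes transcription already reported 0-50% and maps its own progress into the 50-100% band. Note that total_segments defaults to 0 in the new signature, so a caller enabling print_progress must pass the real segment count or the division below raises ZeroDivisionError.

    # Hypothetical helper mirroring the inserted lines above.
    def percent_complete(sdx: int, total_segments: int, combined_progress: bool) -> float:
        base_progress = ((sdx + 1) / total_segments) * 100
        # Combined mode: squeeze alignment progress into the upper half,
        # since transcription is assumed to have covered 0-50%.
        return (50 + base_progress / 2) if combined_progress else base_progress

    assert percent_complete(4, 10, combined_progress=False) == 50.0  # 5 of 10 segments
    assert percent_complete(4, 10, combined_progress=True) == 75.0   # same point, second phase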
--- a/whisperx/asr.py
+++ b/whisperx/asr.py
@@ -247,7 +247,7 @@ class FasterWhisperPipeline(Pipeline):
         return final_iterator
 
     def transcribe(
-        self, audio: Union[str, np.ndarray], batch_size=None, num_workers=0, language=None, task=None
+        self, audio: Union[str, np.ndarray], batch_size=None, num_workers=0, language=None, task=None, print_progress = False, combined_progress=False
     ) -> TranscriptionResult:
         if isinstance(audio, str):
             audio = load_audio(audio)
@@ -285,7 +285,12 @@ class FasterWhisperPipeline(Pipeline):
 
         segments: List[SingleSegment] = []
         batch_size = batch_size or self._batch_size
+        total_segments = len(vad_segments)
         for idx, out in enumerate(self.__call__(data(audio, vad_segments), batch_size=batch_size, num_workers=num_workers)):
+            if print_progress:
+                base_progress = ((idx + 1) / total_segments) * 100
+                percent_complete = base_progress / 2 if combined_progress else base_progress
+                print(f"Progress: {percent_complete:.2f}%...")
             text = out['text']
             if batch_size in [0, 1, None]:
                 text = text[0]
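A usage sketch for the extended transcribe() signature; the load_model/load_audio calls follow the standard whisperX entry points, and the model name, device, and file path are placeholder assumptions rather than anything prescribed by this diff.

    import whisperx

    model = whisperx.load_model("large-v2", "cuda", compute_type="float16")  # placeholder setup
    audio = whisperx.load_audio("audio.wav")  # placeholder path
    result = model.transcribe(audio, batch_size=16, print_progress=True)
    # Prints a line like "Progress: 42.86%..." after each batch; with
    # combined_progress=True the values are halved to 0-50% so a following
    # align(..., combined_progress=True) call can report the remaining 50-100%.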
--- a/whisperx/transcribe.py
+++ b/whisperx/transcribe.py
@@ -73,6 +73,8 @@ def cli():
     parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
 
     parser.add_argument("--hf_token", type=str, default=None, help="Hugging Face Access Token to access PyAnnote gated models")
+
+    parser.add_argument("--print_progress", type=str2bool, default = False, help = "if True, progress will be printed in transcribe() and align() methods.")
     # fmt: on
 
     args = parser.parse_args().__dict__
@@ -104,6 +106,7 @@ def cli():
     diarize: bool = args.pop("diarize")
     min_speakers: int = args.pop("min_speakers")
     max_speakers: int = args.pop("max_speakers")
+    print_progress: bool = args.pop("print_progress")
 
 
     if model_name.endswith(".en") and args["language"] not in {"en", "English"}:
@@ -156,7 +159,7 @@ def cli():
         audio = load_audio(audio_path)
         # >> VAD & ASR
         print(">>Performing transcription...")
-        result = model.transcribe(audio, batch_size=batch_size)
+        result = model.transcribe(audio, batch_size=batch_size, print_progress=print_progress)
         results.append((result, audio_path))
 
     # Unload Whisper and VAD
@@ -184,7 +187,7 @@ def cli():
                 print(f"New language found ({result['language']})! Previous was ({align_metadata['language']}), loading new alignment model for new language...")
                 align_model, align_metadata = load_align_model(result["language"], device)
             print(">>Performing alignment...")
-            result = align(result["segments"], align_model, align_metadata, input_audio, device, interpolate_method=interpolate_method, return_char_alignments=return_char_alignments)
+            result = align(result["segments"], align_model, align_metadata, input_audio, device, interpolate_method=interpolate_method, return_char_alignments=return_char_alignments, print_progress=print_progress)
 
         results.append((result, audio_path))
 
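With the flag wired through cli(), progress printing can be switched on from the command line; a sketch assuming the standard whisperx console entry point and a placeholder audio file:

    whisperx audio.wav --print_progress True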
--- a/whisperx/utils.py
+++ b/whisperx/utils.py
@@ -225,6 +225,9 @@ class SubtitlesWriter(ResultWriter):
         highlight_words: bool = options["highlight_words"]
         max_line_width = 1000 if raw_max_line_width is None else raw_max_line_width
         preserve_segments = max_line_count is None or raw_max_line_width is None
 
+        if len(result["segments"]) == 0:
+            return
+
         if len(result["segments"]) == 0:
             return
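The utils.py hunk adds an early-return guard so the subtitle writer skips empty results. A standalone illustration of the same pattern, using a hypothetical iterate_subtitles() rather than whisperX's actual writer:

    def iterate_subtitles(result: dict):
        # Early return: an empty transcription yields nothing instead of
        # running the subtitle-building loop over zero segments.
        if len(result["segments"]) == 0:
            return
        for segment in result["segments"]:
            yield segment["start"], segment["end"], segment["text"]

    print(list(iterate_subtitles({"segments": []})))  # -> []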