mirror of
https://github.com/m-bain/whisperX.git
synced 2025-07-01 18:17:27 -04:00
Merge branch 'main' into add-merge-chunk-size-as-argument
This commit is contained in:
@ -127,6 +127,10 @@ To label the transcript with speaker ID's (set number of speakers if known e.g.
|
||||
|
||||
whisperx examples/sample01.wav --model large-v2 --diarize --highlight_words True
|
||||
|
||||
To run on CPU instead of GPU (and for running on Mac OS X):
|
||||
|
||||
whisperx examples/sample01.wav --compute_type int8
|
||||
|
||||
### Other languages
|
||||
|
||||
The phoneme ASR alignment model is *language-specific*, for tested languages these models are [automatically picked from torchaudio pipelines or huggingface](https://github.com/m-bain/whisperX/blob/e909f2f766b23b2000f2d95df41f9b844ac53e49/whisperx/transcribe.py#L22).
|
||||
|
@ -50,6 +50,7 @@ DEFAULT_ALIGN_MODELS_HF = {
|
||||
"ko": "kresnik/wav2vec2-large-xlsr-korean",
|
||||
"ur": "kingabzpro/wav2vec2-large-xls-r-300m-Urdu",
|
||||
"te": "anuragshas/wav2vec2-large-xlsr-53-telugu",
|
||||
"hi": "theainerd/Wav2Vec2-large-xlsr-hindi"
|
||||
}
|
||||
|
||||
|
||||
@ -97,6 +98,9 @@ def align(
|
||||
device: str,
|
||||
interpolate_method: str = "nearest",
|
||||
return_char_alignments: bool = False,
|
||||
print_progress: bool = False,
|
||||
combined_progress: bool = False,
|
||||
total_segments: int = 0
|
||||
) -> AlignedTranscriptionResult:
|
||||
"""
|
||||
Align phoneme recognition predictions to known transcription.
|
||||
@ -118,6 +122,11 @@ def align(
|
||||
# 1. Preprocess to keep only characters in dictionary
|
||||
for sdx, segment in enumerate(transcript):
|
||||
# strip spaces at beginning / end, but keep track of the amount.
|
||||
if print_progress:
|
||||
base_progress = ((sdx + 1) / total_segments) * 100
|
||||
percent_complete = (50 + base_progress / 2) if combined_progress else base_progress
|
||||
print(f"Progress: {percent_complete:.2f}%...")
|
||||
|
||||
num_leading = len(segment["text"]) - len(segment["text"].lstrip())
|
||||
num_trailing = len(segment["text"]) - len(segment["text"].rstrip())
|
||||
text = segment["text"]
|
||||
@ -161,9 +170,10 @@ def align(
|
||||
segment["sentence_spans"] = sentence_spans
|
||||
|
||||
aligned_segments: List[SingleAlignedSegment] = []
|
||||
|
||||
|
||||
# 2. Get prediction matrix from alignment model & align
|
||||
for sdx, segment in enumerate(transcript):
|
||||
|
||||
t1 = segment["start"]
|
||||
t2 = segment["end"]
|
||||
text = segment["text"]
|
||||
|
@ -247,7 +247,7 @@ class FasterWhisperPipeline(Pipeline):
|
||||
return final_iterator
|
||||
|
||||
def transcribe(
|
||||
self, audio: Union[str, np.ndarray], batch_size=None, num_workers=0, language=None, task=None, chunk_size=30
|
||||
self, audio: Union[str, np.ndarray], batch_size=None, num_workers=0, language=None, task=None, chunk_size=30, print_progress = False, combined_progress=False
|
||||
) -> TranscriptionResult:
|
||||
if isinstance(audio, str):
|
||||
audio = load_audio(audio)
|
||||
@ -285,7 +285,12 @@ class FasterWhisperPipeline(Pipeline):
|
||||
|
||||
segments: List[SingleSegment] = []
|
||||
batch_size = batch_size or self._batch_size
|
||||
total_segments = len(vad_segments)
|
||||
for idx, out in enumerate(self.__call__(data(audio, vad_segments), batch_size=batch_size, num_workers=num_workers)):
|
||||
if print_progress:
|
||||
base_progress = ((idx + 1) / total_segments) * 100
|
||||
percent_complete = base_progress / 2 if combined_progress else base_progress
|
||||
print(f"Progress: {percent_complete:.2f}%...")
|
||||
text = out['text']
|
||||
if batch_size in [0, 1, None]:
|
||||
text = text[0]
|
||||
|
@ -74,6 +74,8 @@ def cli():
|
||||
parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
|
||||
|
||||
parser.add_argument("--hf_token", type=str, default=None, help="Hugging Face Access Token to access PyAnnote gated models")
|
||||
|
||||
parser.add_argument("--print_progress", type=str2bool, default = False, help = "if True, progress will be printed in transcribe() and align() methods.")
|
||||
# fmt: on
|
||||
|
||||
args = parser.parse_args().__dict__
|
||||
@ -107,6 +109,7 @@ def cli():
|
||||
diarize: bool = args.pop("diarize")
|
||||
min_speakers: int = args.pop("min_speakers")
|
||||
max_speakers: int = args.pop("max_speakers")
|
||||
print_progress: bool = args.pop("print_progress")
|
||||
|
||||
|
||||
if model_name.endswith(".en") and args["language"] not in {"en", "English"}:
|
||||
@ -159,7 +162,7 @@ def cli():
|
||||
audio = load_audio(audio_path)
|
||||
# >> VAD & ASR
|
||||
print(">>Performing transcription...")
|
||||
result = model.transcribe(audio, batch_size=batch_size, chunk_size=chunk_size)
|
||||
result = model.transcribe(audio, batch_size=batch_size, chunk_size=chunk_size, print_progress=print_progress)
|
||||
results.append((result, audio_path))
|
||||
|
||||
# Unload Whisper and VAD
|
||||
@ -187,7 +190,7 @@ def cli():
|
||||
print(f"New language found ({result['language']})! Previous was ({align_metadata['language']}), loading new alignment model for new language...")
|
||||
align_model, align_metadata = load_align_model(result["language"], device)
|
||||
print(">>Performing alignment...")
|
||||
result = align(result["segments"], align_model, align_metadata, input_audio, device, interpolate_method=interpolate_method, return_char_alignments=return_char_alignments)
|
||||
result = align(result["segments"], align_model, align_metadata, input_audio, device, interpolate_method=interpolate_method, return_char_alignments=return_char_alignments, print_progress=print_progress)
|
||||
|
||||
results.append((result, audio_path))
|
||||
|
||||
|
@ -225,6 +225,9 @@ class SubtitlesWriter(ResultWriter):
|
||||
highlight_words: bool = options["highlight_words"]
|
||||
max_line_width = 1000 if raw_max_line_width is None else raw_max_line_width
|
||||
preserve_segments = max_line_count is None or raw_max_line_width is None
|
||||
|
||||
if len(result["segments"]) == 0:
|
||||
return
|
||||
|
||||
if len(result["segments"]) == 0:
|
||||
return
|
||||
@ -296,7 +299,7 @@ class SubtitlesWriter(ResultWriter):
|
||||
start = self.format_timestamp(this_word["start"])
|
||||
end = self.format_timestamp(this_word["end"])
|
||||
if last != start:
|
||||
yield last, start, subtitle_text
|
||||
yield last, start, prefix + subtitle_text
|
||||
|
||||
yield start, end, prefix + " ".join(
|
||||
[
|
||||
|
Reference in New Issue
Block a user