diff --git a/README.md b/README.md
index 7762e1a..6af17b3 100644
--- a/README.md
+++ b/README.md
@@ -29,7 +29,7 @@
 Made by Max Bain • :globe_with_meridians: https://www.maxbain.com
 
-whisperx-arch
+whisperx-arch

 Whisper-Based Automatic Speech Recognition (ASR) with improved timestamp accuracy using forced alignment.
@@ -64,6 +64,7 @@
 $ cd whisperX
 $ pip install -e .
 ```
+You may also need to install ffmpeg, rust, etc. Follow the OpenAI setup instructions here: https://github.com/openai/whisper#setup.

 Usage 💬 (command line)

@@ -101,7 +102,7 @@ Currently default models provided for `{en, fr, de, es, it, ja, zh, nl, uk}`. If
 https://user-images.githubusercontent.com/36994049/208298811-e36002ba-3698-4731-97d4-0aebd07e0eb3.mov
 
-See more exac
+See more examples in other languages [here](EXAMPLES.md).
 
 ## Python usage 🐍
 
@@ -127,6 +128,16 @@
 print(result_aligned["segments"]) # after alignment
 print(result_aligned["word_segments"]) # after alignment
 ```
+
+Whisper Modifications
+
+In addition to forced alignment, the following two modifications have been made to the whisper transcription method:
+
+1. `--condition_on_prev_text` is set to `False` by default (reduces hallucination)
+
+2. Clamping segment `end_time` to be at least 0.02s (one time precision) later than `start_time` (prevents segments with negative duration)
+
+
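A minimal sketch of modification 2, for readers of this diff: Whisper decodes timestamps as integer token positions that are later scaled by the 0.02s time precision, so clamping the end position to land past the start position guarantees a positive duration. The helper name `clamp_segment` and its standalone form are illustrative, not part of the patch.

```python
# Illustrative sketch (not the patched code): clamp a decoded segment's end
# position so the segment can never have zero or negative duration.
TIME_PRECISION = 0.02  # seconds per Whisper timestamp token position

def clamp_segment(start_position: int, end_position: int) -> tuple[float, float]:
    # force the end to land at least one position (0.02 s) after the start
    end_position = max(end_position, start_position + 1)
    return start_position * TIME_PRECISION, end_position * TIME_PRECISION

print(clamp_segment(100, 98))   # (2.0, 2.02) -- negative duration repaired
print(clamp_segment(100, 250))  # (2.0, 5.0)  -- valid segment unchanged
```

The sketch clamps by one integer position, the 0.02s floor the README text describes; the transcribe.py hunk below applies its clamp inline via `max(...)` before the positions are scaled by `time_precision`.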

 Limitations ⚠️

 - Not thoroughly tested, especially for non-English; results may vary -- please post an issue to let me know the results on your data
diff --git a/whisperx/transcribe.py b/whisperx/transcribe.py
index d2303aa..772143d 100644
--- a/whisperx/transcribe.py
+++ b/whisperx/transcribe.py
@@ -223,6 +223,10 @@ def transcribe(
                     end_timestamp_position = (
                         sliced_tokens[-1].item() - tokenizer.timestamp_begin
                     )
+
+                    # clamp end-time to be at least 1 frame after start-time
+                    end_timestamp_position = max(end_timestamp_position, start_timestamp_position + time_precision)
+
                     add_segment(
                         start=timestamp_offset + start_timestamp_position * time_precision,
                         end=timestamp_offset + end_timestamp_position * time_precision,
@@ -291,28 +295,27 @@ def align(
     prev_t2 = 0
     word_segments_list = []
     for idx, segment in enumerate(transcript):
-        if int(segment['start'] * SAMPLE_RATE) >= audio.shape[1]:
-            print("Failed to align segment: original start time longer than audio duration, skipping...")
-            continue
-
-        if int(segment['start']) >= int(segment['end']):
-            print("Failed to align segment: original end time is not after start time, skipping...")
-            continue
-
+        # first we pad
         t1 = max(segment['start'] - extend_duration, 0)
         t2 = min(segment['end'] + extend_duration, MAX_DURATION)
+
+        # use prev_t2 as current t1 if it's later
         if start_from_previous and t1 < prev_t2:
             t1 = prev_t2
+
+        # check if timestamp range is still valid
+        if t1 >= MAX_DURATION:
+            print("Failed to align segment: original start time longer than audio duration, skipping...")
+            continue
+        if t2 - t1 < 0.02:
+            print("Failed to align segment: duration smaller than 0.02s time precision")
+            continue
+
         f1 = int(t1 * SAMPLE_RATE)
         f2 = int(t2 * SAMPLE_RATE)
-
         waveform_segment = audio[:, f1:f2]
-        if waveform_segment.shape[1] < 10:
-            print("Failed to align segment: too short in duration, %.3f" % waveform_segment.shape[1]/SAMPLE_RATE)
-            continue
         with torch.inference_mode():
             if model_type == "torchaudio":
                 emissions, _ = model(waveform_segment.to(device))
@@ -321,6 +324,7 @@
             else:
                 raise NotImplementedError(f"Align model of type {model_type} not supported.")
             emissions = torch.log_softmax(emissions, dim=-1)
+
         emission = emissions[0].cpu().detach()
         transcription = segment['text'].strip()
         if model_lang not in LANGUAGES_WITHOUT_SPACES:
@@ -519,6 +523,7 @@ def cli():
             print(f"New language found ({result['language']})! Previous was ({align_metadata['language']}), loading new alignment model for new language...")
             align_model, align_metadata = load_align_model(result["language"], device)
 
+        print("Performing alignment...")
         result_aligned = align(result["segments"], align_model, align_metadata, audio_path, device, extend_duration=align_extend, start_from_previous=align_from_prev, drop_non_aligned_words=drop_non_aligned)
 
         audio_basename = os.path.basename(audio_path)
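To make the new control flow in `align()` easier to scan, here is a self-contained sketch of the windowing logic this patch introduces. The function `segment_window`, its default values, and `SAMPLE_RATE = 16000` are assumptions for illustration; only the sequence of steps mirrors the diff.

```python
# Hypothetical standalone rendering of the windowing logic added in align();
# names and defaults are illustrative, the steps follow the diff above.
SAMPLE_RATE = 16000  # assumed: Whisper's 16 kHz input rate

def segment_window(start: float, end: float, max_duration: float,
                   prev_t2: float = 0.0, extend_duration: float = 2.0,
                   start_from_previous: bool = True):
    """Return (f1, f2) sample indices for an alignment window, or None to skip."""
    # first we pad the segment on both sides, clamped to the audio bounds
    t1 = max(start - extend_duration, 0)
    t2 = min(end + extend_duration, max_duration)

    # use the previous window's end as t1 if it's later
    if start_from_previous and t1 < prev_t2:
        t1 = prev_t2

    # the checks now run on the padded window rather than the raw timestamps:
    # skip when the window starts past the end of the audio, or is shorter
    # than the 0.02 s time precision
    if t1 >= max_duration:
        return None
    if t2 - t1 < 0.02:
        return None

    return int(t1 * SAMPLE_RATE), int(t2 * SAMPLE_RATE)
```

The design choice visible in the hunk: the old code validated the raw segment timestamps before any padding, while the patched code validates the final padded window, which is exactly the slice `audio[:, f1:f2]` fed to the alignment model, so the skip conditions can no longer disagree with the data actually aligned.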