diff --git a/README.md b/README.md
index 8043e02..a660d2d 100644
--- a/README.md
+++ b/README.md
@@ -52,13 +52,6 @@ This repository provides fast automatic speech recognition (70x realtime with la
**Speaker Diarization** is the process of partitioning an audio stream containing human speech into homogeneous segments according to the identity of each speaker.
-- v3 pre-release [this branch](https://github.com/m-bain/whisperX/tree/v3) *70x speed-up open-sourced. Using batched whisper with faster-whisper backend*!
-- v2 released, code cleanup, imports whisper library. VAD filtering is now turned on by default, as in the paper.
-- Paper drop🎓👨🏫! Please see our [ArxiV preprint](https://arxiv.org/abs/2303.00747) for benchmarking and details of WhisperX. We also introduce more efficient batch inference resulting in large-v2 with *60-70x REAL TIME speed (not provided in this repo).
-- VAD filtering: Voice Activity Detection (VAD) from [Pyannote.audio](https://huggingface.co/pyannote/voice-activity-detection) is used as a preprocessing step to remove reliance on whisper timestamps and only transcribe audio segments containing speech. add `--vad_filter True` flag, increases timestamp accuracy and robustness (requires more GPU mem due to 30s inputs in wav2vec2)
-- Character level timestamps (see `*.char.ass` file output)
-- Diarization (still in beta, add `--diarize`)
-
New🚨
- v3 transcript segment-per-sentence: using nltk sent_tokenize for better subtitlting & better diarization
@@ -87,11 +80,11 @@ See other methods [here.](https://pytorch.org/get-started/previous-versions/#v20
### 3. Install this repo
-`pip install git+https://github.com/m-bain/whisperx.git@v3`
+`pip install git+https://github.com/m-bain/whisperx.git`
If already installed, update package to most recent commit
-`pip install git+https://github.com/m-bain/whisperx.git@v3 --upgrade`
+`pip install git+https://github.com/m-bain/whisperx.git --upgrade`
If wishing to modify this package, clone and install in editable mode:
```
@@ -183,10 +176,10 @@ print(result["segments"]) # after alignment
diarize_model = whisperx.DiarizationPipeline(use_auth_token=YOUR_HF_TOKEN, device=device)
# add min/max number of speakers if known
-diarize_segments = diarize_model(input_audio_path)
-# diarize_model(input_audio_path, min_speakers=min_speakers, max_speakers=max_speakers)
+diarize_segments = diarize_model(audio_file)
+# diarize_model(audio_file, min_speakers=min_speakers, max_speakers=max_speakers)
-result = assign_word_speakers(diarize_segments, result)
+result = whisperx.assign_word_speakers(diarize_segments, result)
print(diarize_segments)
print(result["segments"]) # segments are now assigned speaker IDs
```
diff --git a/whisperx/alignment.py b/whisperx/alignment.py
index aade4b4..8d088be 100644
--- a/whisperx/alignment.py
+++ b/whisperx/alignment.py
@@ -3,7 +3,7 @@ Forced Alignment with Whisper
C. Max Bain
"""
from dataclasses import dataclass
-from typing import Iterator, Union
+from typing import Iterator, Union, List
import numpy as np
import pandas as pd
@@ -13,7 +13,11 @@ from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
from .audio import SAMPLE_RATE, load_audio
from .utils import interpolate_nans
+from .types import AlignedTranscriptionResult, SingleSegment, SingleAlignedSegment, SingleWordSegment
import nltk
+from nltk.tokenize.punkt import PunktSentenceTokenizer, PunktParameters
+
+PUNKT_ABBREVIATIONS = ['dr', 'vs', 'mr', 'mrs', 'prof']
LANGUAGES_WITHOUT_SPACES = ["ja", "zh"]
@@ -32,6 +36,7 @@ DEFAULT_ALIGN_MODELS_HF = {
"uk": "Yehor/wav2vec2-xls-r-300m-uk-with-small-lm",
"pt": "jonatasgrosman/wav2vec2-large-xlsr-53-portuguese",
"ar": "jonatasgrosman/wav2vec2-large-xlsr-53-arabic",
+ "cs": "comodoro/wav2vec2-xls-r-300m-cs-250",
"ru": "jonatasgrosman/wav2vec2-large-xlsr-53-russian",
"pl": "jonatasgrosman/wav2vec2-large-xlsr-53-polish",
"hu": "jonatasgrosman/wav2vec2-large-xlsr-53-hungarian",
@@ -39,7 +44,10 @@ DEFAULT_ALIGN_MODELS_HF = {
"fa": "jonatasgrosman/wav2vec2-large-xlsr-53-persian",
"el": "jonatasgrosman/wav2vec2-large-xlsr-53-greek",
"tr": "mpoyraz/wav2vec2-xls-r-300m-cv7-turkish",
+ "da": "saattrupdan/wav2vec2-xls-r-300m-ftspeech",
"he": "imvladikon/wav2vec2-xls-r-300m-hebrew",
+ "vi": 'nguyenvulebinh/wav2vec2-base-vi',
+ "ko": "kresnik/wav2vec2-large-xlsr-korean",
}
@@ -80,14 +88,14 @@ def load_align_model(language_code, device, model_name=None, model_dir=None):
def align(
- transcript: Iterator[dict],
+ transcript: Iterator[SingleSegment],
model: torch.nn.Module,
align_model_metadata: dict,
audio: Union[str, np.ndarray, torch.Tensor],
device: str,
interpolate_method: str = "nearest",
return_char_alignments: bool = False,
-):
+) -> AlignedTranscriptionResult:
"""
Align phoneme recognition predictions to known transcription.
"""
@@ -139,14 +147,18 @@ def align(
if any([c in model_dictionary.keys() for c in wrd]):
clean_wdx.append(wdx)
- sentence_spans = list(nltk.tokenize.punkt.PunktSentenceTokenizer().span_tokenize(text))
+
+ punkt_param = PunktParameters()
+ punkt_param.abbrev_types = set(PUNKT_ABBREVIATIONS)
+ sentence_splitter = PunktSentenceTokenizer(punkt_param)
+ sentence_spans = list(sentence_splitter.span_tokenize(text))
segment["clean_char"] = clean_char
segment["clean_cdx"] = clean_cdx
segment["clean_wdx"] = clean_wdx
segment["sentence_spans"] = sentence_spans
- aligned_segments = []
+ aligned_segments: List[SingleAlignedSegment] = []
# 2. Get prediction matrix from alignment model & align
for sdx, segment in enumerate(transcript):
@@ -154,7 +166,7 @@ def align(
t2 = segment["end"]
text = segment["text"]
- aligned_seg = {
+ aligned_seg: SingleAlignedSegment = {
"start": t1,
"end": t2,
"text": text,
@@ -307,7 +319,7 @@ def align(
aligned_segments += aligned_subsegments
# create word_segments list
- word_segments = []
+ word_segments: List[SingleWordSegment] = []
for segment in aligned_segments:
word_segments += segment["words"]
diff --git a/whisperx/asr.py b/whisperx/asr.py
index 66b58ad..d0e6962 100644
--- a/whisperx/asr.py
+++ b/whisperx/asr.py
@@ -11,10 +11,18 @@ from transformers.pipelines.pt_utils import PipelineIterator
from .audio import N_SAMPLES, SAMPLE_RATE, load_audio, log_mel_spectrogram
from .vad import load_vad_model, merge_chunks
+from .types import TranscriptionResult, SingleSegment
-
-def load_model(whisper_arch, device, compute_type="float16", asr_options=None, language=None,
- vad_options=None, model=None, task="transcribe"):
+def load_model(whisper_arch,
+ device,
+ device_index=0,
+ compute_type="float16",
+ asr_options=None,
+ language=None,
+ vad_options=None,
+ model=None,
+ task="transcribe",
+ download_root=None):
'''Load a Whisper model for inference.
Args:
whisper_arch: str - The name of the Whisper model to load.
@@ -22,14 +30,19 @@ def load_model(whisper_arch, device, compute_type="float16", asr_options=None, l
compute_type: str - The compute type to use for the model.
options: dict - A dictionary of options to use for the model.
language: str - The language of the model. (use English for now)
+ download_root: Optional[str] - The root directory to download the model to.
Returns:
A Whisper pipeline.
- '''
+ '''
if whisper_arch.endswith(".en"):
language = "en"
- model = WhisperModel(whisper_arch, device=device, compute_type=compute_type)
+ model = WhisperModel(whisper_arch,
+ device=device,
+ device_index=device_index,
+ compute_type=compute_type,
+ download_root=download_root)
if language is not None:
tokenizer = faster_whisper.tokenizer.Tokenizer(model.hf_tokenizer, model.model.is_multilingual, task=task, language=language)
else:
@@ -114,7 +127,7 @@ class WhisperModel(faster_whisper.WhisperModel):
# suppress_tokens=options.suppress_tokens,
# max_initial_timestamp_index=max_initial_timestamp_index,
)
-
+
tokens_batch = [x.sequences_ids[0] for x in result]
def decode_batch(tokens: List[List[int]]) -> str:
@@ -127,7 +140,7 @@ class WhisperModel(faster_whisper.WhisperModel):
text = decode_batch(tokens_batch)
return text
-
+
def encode(self, features: np.ndarray) -> ctranslate2.StorageView:
# When the model is running on multiple GPUs, the encoder output should be moved
# to the CPU since we don't know which GPU will handle the next job.
@@ -136,9 +149,9 @@ class WhisperModel(faster_whisper.WhisperModel):
if len(features.shape) == 2:
features = np.expand_dims(features, 0)
features = faster_whisper.transcribe.get_ctranslate2_storage(features)
-
+
return self.model.encode(features, to_cpu=to_cpu)
-
+
class FasterWhisperPipeline(Pipeline):
"""
Huggingface Pipeline wrapper for FasterWhisperModel.
@@ -176,7 +189,7 @@ class FasterWhisperPipeline(Pipeline):
self.device = torch.device(f"cuda:{device}")
else:
self.device = device
-
+
super(Pipeline, self).__init__()
self.vad_model = vad
@@ -194,7 +207,7 @@ class FasterWhisperPipeline(Pipeline):
def _forward(self, model_inputs):
outputs = self.model.generate_segment_batched(model_inputs['inputs'], self.tokenizer, self.options)
return {'text': outputs}
-
+
def postprocess(self, model_outputs):
return model_outputs
@@ -214,11 +227,11 @@ class FasterWhisperPipeline(Pipeline):
return final_iterator
def transcribe(
- self, audio: Union[str, np.ndarray], batch_size=None, num_workers=0
- ):
+ self, audio: Union[str, np.ndarray], batch_size=None, num_workers=0, language=None, task=None
+ ) -> TranscriptionResult:
if isinstance(audio, str):
audio = load_audio(audio)
-
+
def data(audio, segments):
for seg in segments:
f1 = int(seg['start'] * SAMPLE_RATE)
@@ -228,16 +241,21 @@ class FasterWhisperPipeline(Pipeline):
vad_segments = self.vad_model({"waveform": torch.from_numpy(audio).unsqueeze(0), "sample_rate": SAMPLE_RATE})
vad_segments = merge_chunks(vad_segments, 30)
-
- del_tokenizer = False
if self.tokenizer is None:
- language = self.detect_language(audio)
- self.tokenizer = faster_whisper.tokenizer.Tokenizer(self.model.hf_tokenizer, self.model.model.is_multilingual, task="transcribe", language=language)
- del_tokenizer = True
+ language = language or self.detect_language(audio)
+ task = task or "transcribe"
+ self.tokenizer = faster_whisper.tokenizer.Tokenizer(self.model.hf_tokenizer,
+ self.model.model.is_multilingual, task=task,
+ language=language)
else:
- language = self.tokenizer.language_code
+ language = language or self.tokenizer.language_code
+ task = task or self.tokenizer.task
+ if task != self.tokenizer.task or language != self.tokenizer.language_code:
+ self.tokenizer = faster_whisper.tokenizer.Tokenizer(self.model.hf_tokenizer,
+ self.model.model.is_multilingual, task=task,
+ language=language)
- segments = []
+ segments: List[SingleSegment] = []
batch_size = batch_size or self._batch_size
for idx, out in enumerate(self.__call__(data(audio, vad_segments), batch_size=batch_size, num_workers=num_workers)):
text = out['text']
@@ -245,14 +263,11 @@ class FasterWhisperPipeline(Pipeline):
text = text[0]
segments.append(
{
- "text": out['text'],
+ "text": text,
"start": round(vad_segments[idx]['start'], 3),
"end": round(vad_segments[idx]['end'], 3)
}
)
-
- if del_tokenizer:
- self.tokenizer = None
return {"segments": segments, "language": language}
diff --git a/whisperx/transcribe.py b/whisperx/transcribe.py
index 3edc746..691e3f9 100644
--- a/whisperx/transcribe.py
+++ b/whisperx/transcribe.py
@@ -21,6 +21,7 @@ def cli():
parser.add_argument("--model", default="small", help="name of the Whisper model to use")
parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
+ parser.add_argument("--device_index", default=0, type=int, help="device index to use for FasterWhisper inference")
parser.add_argument("--batch_size", default=8, type=int, help="device to use for PyTorch inference")
parser.add_argument("--compute_type", default="float16", type=str, choices=["float16", "float32", "int8"], help="compute type for computation")
@@ -78,6 +79,7 @@ def cli():
output_dir: str = args.pop("output_dir")
output_format: str = args.pop("output_format")
device: str = args.pop("device")
+ device_index: int = args.pop("device_index")
compute_type: str = args.pop("compute_type")
# model_flush: bool = args.pop("model_flush")
@@ -144,7 +146,7 @@ def cli():
results = []
tmp_results = []
# model = load_model(model_name, device=device, download_root=model_dir)
- model = load_model(model_name, device=device, compute_type=compute_type, language=args['language'], asr_options=asr_options, vad_options={"vad_onset": vad_onset, "vad_offset": vad_offset}, task=task)
+ model = load_model(model_name, device=device, device_index=device_index, compute_type=compute_type, language=args['language'], asr_options=asr_options, vad_options={"vad_onset": vad_onset, "vad_offset": vad_offset}, task=task)
for audio_path in args.pop("audio"):
audio = load_audio(audio_path)
diff --git a/whisperx/types.py b/whisperx/types.py
new file mode 100644
index 0000000..75d4485
--- /dev/null
+++ b/whisperx/types.py
@@ -0,0 +1,58 @@
+from typing import TypedDict, Optional
+
+
+class SingleWordSegment(TypedDict):
+ """
+ A single word of a speech.
+ """
+ word: str
+ start: float
+ end: float
+ score: float
+
+class SingleCharSegment(TypedDict):
+ """
+ A single char of a speech.
+ """
+ char: str
+ start: float
+ end: float
+ score: float
+
+
+class SingleSegment(TypedDict):
+ """
+ A single segment (up to multiple sentences) of a speech.
+ """
+
+ start: float
+ end: float
+ text: str
+
+
+class SingleAlignedSegment(TypedDict):
+ """
+ A single segment (up to multiple sentences) of a speech with word alignment.
+ """
+
+ start: float
+ end: float
+ text: str
+ words: list[SingleWordSegment]
+ chars: Optional[list[SingleCharSegment]]
+
+
+class TranscriptionResult(TypedDict):
+ """
+ A list of segments and word segments of a speech.
+ """
+ segments: list[SingleSegment]
+ language: str
+
+
+class AlignedTranscriptionResult(TypedDict):
+ """
+ A list of segments and word segments of a speech.
+ """
+ segments: list[SingleAlignedSegment]
+ word_segments: list[SingleWordSegment]
diff --git a/whisperx/vad.py b/whisperx/vad.py
index a7a2451..15a9e5e 100644
--- a/whisperx/vad.py
+++ b/whisperx/vad.py
@@ -157,7 +157,7 @@ class Binarize:
curr_scores = curr_scores[min_score_div_idx+1:]
curr_timestamps = curr_timestamps[min_score_div_idx+1:]
# switching from active to inactive
- elif y <= self.offset:
+ elif y < self.offset:
region = Segment(start - self.pad_onset, t + self.pad_offset)
active[region, k] = label
start = t
@@ -169,7 +169,7 @@ class Binarize:
# currently inactive
else:
# switching from inactive to active
- if y >= self.onset:
+ if y > self.onset:
start = t
is_active = True