mirror of
https://github.com/m-bain/whisperX.git
synced 2025-07-01 18:17:27 -04:00
Merge branch 'v3' of https://github.com/m-bain/whisperX into v3
Conflicts: whisperx/asr.py
README.md (17 lines changed)
@@ -52,13 +52,6 @@ This repository provides fast automatic speech recognition (70x realtime with large-v2)

**Speaker Diarization** is the process of partitioning an audio stream containing human speech into homogeneous segments according to the identity of each speaker.

- v3 pre-release [this branch](https://github.com/m-bain/whisperX/tree/v3) *70x speed-up open-sourced. Using batched whisper with faster-whisper backend*!
- v2 released, code cleanup, imports whisper library. VAD filtering is now turned on by default, as in the paper.
- Paper drop🎓👨‍🏫! Please see our [ArXiv preprint](https://arxiv.org/abs/2303.00747) for benchmarking and details of WhisperX. We also introduce more efficient batch inference resulting in large-v2 with *60-70x REAL TIME speed* (not provided in this repo).
- VAD filtering: Voice Activity Detection (VAD) from [Pyannote.audio](https://huggingface.co/pyannote/voice-activity-detection) is used as a preprocessing step to remove reliance on Whisper timestamps and to transcribe only audio segments containing speech. Add the `--vad_filter True` flag; this increases timestamp accuracy and robustness (requires more GPU memory due to 30s inputs in wav2vec2).
- Character-level timestamps (see `*.char.ass` file output)
- Diarization (still in beta, add `--diarize`)

<h2 align="left", id="highlights">New🚨</h2>

- v3 transcript segment-per-sentence: using nltk sent_tokenize for better subtitling & better diarization
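As an editorial aside, here is a minimal sketch of the batched v3 flow the first bullet describes, following the package-level usage shown later in this diff's README section (the audio path and model size are illustrative):

```python
import whisperx

device = "cuda"
audio = whisperx.load_audio("audio.wav")          # illustrative input path
model = whisperx.load_model("large-v2", device, compute_type="float16")

# batched inference over VAD-merged chunks: the faster-whisper backed speed-up path
result = model.transcribe(audio, batch_size=16)
print(result["segments"])
```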
@@ -87,11 +80,11 @@ See other methods [here.](https://pytorch.org/get-started/previous-versions/#v200)

### 3. Install this repo

-`pip install git+https://github.com/m-bain/whisperx.git@v3`
+`pip install git+https://github.com/m-bain/whisperx.git`

If already installed, update the package to the most recent commit:

-`pip install git+https://github.com/m-bain/whisperx.git@v3 --upgrade`
+`pip install git+https://github.com/m-bain/whisperx.git --upgrade`

If wishing to modify this package, clone and install in editable mode:
```
@@ -183,10 +176,10 @@ print(result["segments"]) # after alignment

diarize_model = whisperx.DiarizationPipeline(use_auth_token=YOUR_HF_TOKEN, device=device)

# add min/max number of speakers if known
-diarize_segments = diarize_model(input_audio_path)
-# diarize_model(input_audio_path, min_speakers=min_speakers, max_speakers=max_speakers)
+diarize_segments = diarize_model(audio_file)
+# diarize_model(audio_file, min_speakers=min_speakers, max_speakers=max_speakers)

-result = assign_word_speakers(diarize_segments, result)
+result = whisperx.assign_word_speakers(diarize_segments, result)
print(diarize_segments)
print(result["segments"]) # segments are now assigned speaker IDs
```
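A short, hedged follow-up to the README snippet above: after `whisperx.assign_word_speakers`, each segment should carry a speaker label; the `"speaker"` key is an assumption inferred from this usage, and segments without an identified speaker may lack it.

```python
# print speaker-labelled segments; the "speaker" key is assumed from the snippet above
for seg in result["segments"]:
    print(seg.get("speaker", "UNKNOWN"), seg["text"])
```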
whisperx/alignment.py

@@ -3,7 +3,7 @@ Forced Alignment with Whisper
C. Max Bain
"""
from dataclasses import dataclass
-from typing import Iterator, Union
+from typing import Iterator, Union, List

import numpy as np
import pandas as pd
@@ -13,7 +13,11 @@ from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

from .audio import SAMPLE_RATE, load_audio
from .utils import interpolate_nans
from .types import AlignedTranscriptionResult, SingleSegment, SingleAlignedSegment, SingleWordSegment
import nltk
from nltk.tokenize.punkt import PunktSentenceTokenizer, PunktParameters

PUNKT_ABBREVIATIONS = ['dr', 'vs', 'mr', 'mrs', 'prof']

LANGUAGES_WITHOUT_SPACES = ["ja", "zh"]

@@ -32,6 +36,7 @@ DEFAULT_ALIGN_MODELS_HF = {
    "uk": "Yehor/wav2vec2-xls-r-300m-uk-with-small-lm",
    "pt": "jonatasgrosman/wav2vec2-large-xlsr-53-portuguese",
    "ar": "jonatasgrosman/wav2vec2-large-xlsr-53-arabic",
    "cs": "comodoro/wav2vec2-xls-r-300m-cs-250",
    "ru": "jonatasgrosman/wav2vec2-large-xlsr-53-russian",
    "pl": "jonatasgrosman/wav2vec2-large-xlsr-53-polish",
    "hu": "jonatasgrosman/wav2vec2-large-xlsr-53-hungarian",
@@ -39,7 +44,10 @@ DEFAULT_ALIGN_MODELS_HF = {
    "fa": "jonatasgrosman/wav2vec2-large-xlsr-53-persian",
    "el": "jonatasgrosman/wav2vec2-large-xlsr-53-greek",
    "tr": "mpoyraz/wav2vec2-xls-r-300m-cv7-turkish",
    "da": "saattrupdan/wav2vec2-xls-r-300m-ftspeech",
    "he": "imvladikon/wav2vec2-xls-r-300m-hebrew",
    "vi": 'nguyenvulebinh/wav2vec2-base-vi',
    "ko": "kresnik/wav2vec2-large-xlsr-korean",
}


@@ -80,14 +88,14 @@ def load_align_model(language_code, device, model_name=None, model_dir=None):


def align(
-    transcript: Iterator[dict],
+    transcript: Iterator[SingleSegment],
    model: torch.nn.Module,
    align_model_metadata: dict,
    audio: Union[str, np.ndarray, torch.Tensor],
    device: str,
    interpolate_method: str = "nearest",
    return_char_alignments: bool = False,
-):
+) -> AlignedTranscriptionResult:
    """
    Align phoneme recognition predictions to known transcription.
    """
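A hedged usage sketch of the updated `align()` signature, mirroring the README usage: `result` is a `TranscriptionResult` from `model.transcribe`, and the alignment model is picked from the language table above via `load_align_model` (path and model size are illustrative):

```python
import whisperx

device = "cuda"
audio = whisperx.load_audio("audio.wav")     # illustrative path
model = whisperx.load_model("large-v2", device)
result = model.transcribe(audio, batch_size=16)

# pick a wav2vec2 alignment model for the detected language
model_a, metadata = whisperx.load_align_model(language_code=result["language"], device=device)

# align() now returns an AlignedTranscriptionResult with per-word timings
result = whisperx.align(result["segments"], model_a, metadata, audio, device,
                        return_char_alignments=False)
print(result["word_segments"][:5])
```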
@@ -139,14 +147,18 @@ def align(
            if any([c in model_dictionary.keys() for c in wrd]):
                clean_wdx.append(wdx)

-        sentence_spans = list(nltk.tokenize.punkt.PunktSentenceTokenizer().span_tokenize(text))
+        punkt_param = PunktParameters()
+        punkt_param.abbrev_types = set(PUNKT_ABBREVIATIONS)
+        sentence_splitter = PunktSentenceTokenizer(punkt_param)
+        sentence_spans = list(sentence_splitter.span_tokenize(text))

        segment["clean_char"] = clean_char
        segment["clean_cdx"] = clean_cdx
        segment["clean_wdx"] = clean_wdx
        segment["sentence_spans"] = sentence_spans

-    aligned_segments = []
+    aligned_segments: List[SingleAlignedSegment] = []

    # 2. Get prediction matrix from alignment model & align
    for sdx, segment in enumerate(transcript):
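A hedged aside on why the change above matters: registering abbreviations with `PunktParameters` keeps periods in titles like "Mr." from being treated as sentence boundaries, which keeps sentence spans (and thus subtitle and diarization segments) intact. The example text below is illustrative:

```python
from nltk.tokenize.punkt import PunktSentenceTokenizer, PunktParameters

PUNKT_ABBREVIATIONS = ['dr', 'vs', 'mr', 'mrs', 'prof']

punkt_param = PunktParameters()
punkt_param.abbrev_types = set(PUNKT_ABBREVIATIONS)
splitter = PunktSentenceTokenizer(punkt_param)

text = "Mr. Bain presented the results. The audience applauded."
# the period after "Mr." is not treated as a sentence boundary, so two spans remain
print(list(splitter.span_tokenize(text)))
```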
@@ -154,7 +166,7 @@ def align(
        t2 = segment["end"]
        text = segment["text"]

-        aligned_seg = {
+        aligned_seg: SingleAlignedSegment = {
            "start": t1,
            "end": t2,
            "text": text,
@@ -307,7 +319,7 @@ def align(
        aligned_segments += aligned_subsegments

    # create word_segments list
-    word_segments = []
+    word_segments: List[SingleWordSegment] = []
    for segment in aligned_segments:
        word_segments += segment["words"]
whisperx/asr.py

@@ -11,6 +11,7 @@ from transformers.pipelines.pt_utils import PipelineIterator

from .audio import N_SAMPLES, SAMPLE_RATE, load_audio, log_mel_spectrogram
from .vad import load_vad_model, merge_chunks
<<<<<<< HEAD

def find_numeral_symbol_tokens(tokenizer):
    numeral_symbol_tokens = []
@@ -23,6 +24,20 @@ def find_numeral_symbol_tokens(tokenizer):

def load_model(whisper_arch, device, compute_type="float16", asr_options=None, language=None,
               vad_options=None, model=None, task="transcribe"):
=======
from .types import TranscriptionResult, SingleSegment

def load_model(whisper_arch,
               device,
               device_index=0,
               compute_type="float16",
               asr_options=None,
               language=None,
               vad_options=None,
               model=None,
               task="transcribe",
               download_root=None):
>>>>>>> ec6a110cdf2616919cfd0a616f9ae2fbdd44903f
    '''Load a Whisper model for inference.
    Args:
        whisper_arch: str - The name of the Whisper model to load.
@@ -30,14 +45,19 @@ def load_model(whisper_arch, device, compute_type="float16", asr_options=None, l
        compute_type: str - The compute type to use for the model.
        options: dict - A dictionary of options to use for the model.
        language: str - The language of the model. (use English for now)
        download_root: Optional[str] - The root directory to download the model to.
    Returns:
        A Whisper pipeline.
    '''

    if whisper_arch.endswith(".en"):
        language = "en"

-    model = WhisperModel(whisper_arch, device=device, compute_type=compute_type)
+    model = WhisperModel(whisper_arch,
+                         device=device,
+                         device_index=device_index,
+                         compute_type=compute_type,
+                         download_root=download_root)
    if language is not None:
        tokenizer = faster_whisper.tokenizer.Tokenizer(model.hf_tokenizer, model.model.is_multilingual, task=task, language=language)
    else:
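A hedged usage sketch of the extended `load_model` signature: `device_index` and `download_root` are the newly added parameters from the diff above; the values and the package-level import are illustrative.

```python
import whisperx

model = whisperx.load_model(
    "large-v2",
    device="cuda",
    device_index=1,              # select a specific GPU
    compute_type="float16",
    download_root="./models",    # illustrative cache directory for model weights
)
```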
@@ -133,7 +153,7 @@ class WhisperModel(faster_whisper.WhisperModel):
                suppress_blank=options.suppress_blank,
                suppress_tokens=options.suppress_tokens,
            )

        tokens_batch = [x.sequences_ids[0] for x in result]

        def decode_batch(tokens: List[List[int]]) -> str:
@@ -146,7 +166,7 @@ class WhisperModel(faster_whisper.WhisperModel):
        text = decode_batch(tokens_batch)

        return text

    def encode(self, features: np.ndarray) -> ctranslate2.StorageView:
        # When the model is running on multiple GPUs, the encoder output should be moved
        # to the CPU since we don't know which GPU will handle the next job.
@@ -155,9 +175,9 @@ class WhisperModel(faster_whisper.WhisperModel):
        if len(features.shape) == 2:
            features = np.expand_dims(features, 0)
        features = faster_whisper.transcribe.get_ctranslate2_storage(features)

        return self.model.encode(features, to_cpu=to_cpu)


class FasterWhisperPipeline(Pipeline):
    """
    Huggingface Pipeline wrapper for FasterWhisperModel.
@@ -195,7 +215,7 @@ class FasterWhisperPipeline(Pipeline):
            self.device = torch.device(f"cuda:{device}")
        else:
            self.device = device

        super(Pipeline, self).__init__()
        self.vad_model = vad

@@ -213,7 +233,7 @@ class FasterWhisperPipeline(Pipeline):
    def _forward(self, model_inputs):
        outputs = self.model.generate_segment_batched(model_inputs['inputs'], self.tokenizer, self.options)
        return {'text': outputs}

    def postprocess(self, model_outputs):
        return model_outputs
@@ -233,11 +253,11 @@ class FasterWhisperPipeline(Pipeline):
        return final_iterator

    def transcribe(
-        self, audio: Union[str, np.ndarray], batch_size=None, num_workers=0
-    ):
+        self, audio: Union[str, np.ndarray], batch_size=None, num_workers=0, language=None, task=None
+    ) -> TranscriptionResult:
        if isinstance(audio, str):
            audio = load_audio(audio)

        def data(audio, segments):
            for seg in segments:
                f1 = int(seg['start'] * SAMPLE_RATE)
@@ -247,16 +267,21 @@ class FasterWhisperPipeline(Pipeline):

        vad_segments = self.vad_model({"waveform": torch.from_numpy(audio).unsqueeze(0), "sample_rate": SAMPLE_RATE})
        vad_segments = merge_chunks(vad_segments, 30)

-        del_tokenizer = False
        if self.tokenizer is None:
-            language = self.detect_language(audio)
-            self.tokenizer = faster_whisper.tokenizer.Tokenizer(self.model.hf_tokenizer, self.model.model.is_multilingual, task="transcribe", language=language)
-            del_tokenizer = True
+            language = language or self.detect_language(audio)
+            task = task or "transcribe"
+            self.tokenizer = faster_whisper.tokenizer.Tokenizer(self.model.hf_tokenizer,
+                                                                self.model.model.is_multilingual, task=task,
+                                                                language=language)
        else:
-            language = self.tokenizer.language_code
+            language = language or self.tokenizer.language_code
+            task = task or self.tokenizer.task
+            if task != self.tokenizer.task or language != self.tokenizer.language_code:
+                self.tokenizer = faster_whisper.tokenizer.Tokenizer(self.model.hf_tokenizer,
+                                                                    self.model.model.is_multilingual, task=task,
+                                                                    language=language)

-        segments = []
+        segments: List[SingleSegment] = []
        batch_size = batch_size or self._batch_size
        for idx, out in enumerate(self.__call__(data(audio, vad_segments), batch_size=batch_size, num_workers=num_workers)):
            text = out['text']
@@ -264,14 +289,11 @@ class FasterWhisperPipeline(Pipeline):
                text = text[0]
            segments.append(
                {
-                    "text": out['text'],
+                    "text": text,
                    "start": round(vad_segments[idx]['start'], 3),
                    "end": round(vad_segments[idx]['end'], 3)
                }
            )

-        if del_tokenizer:
-            self.tokenizer = None

        return {"segments": segments, "language": language}
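A hedged sketch of the per-call overrides added to `transcribe()` above: `language` and `task` can now be supplied at call time, and the tokenizer is rebuilt if they differ from the loaded one (paths and values are illustrative):

```python
import whisperx

model = whisperx.load_model("large-v2", device="cuda", compute_type="float16")
audio = whisperx.load_audio("audio.wav")

# override language detection and request translation for this call only
result = model.transcribe(audio, batch_size=16, language="de", task="translate")
print(result["language"], len(result["segments"]))
```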
whisperx/transcribe.py

@@ -21,6 +21,7 @@ def cli():
    parser.add_argument("--model", default="small", help="name of the Whisper model to use")
    parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
    parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
+    parser.add_argument("--device_index", default=0, type=int, help="device index to use for FasterWhisper inference")
    parser.add_argument("--batch_size", default=8, type=int, help="batch size for batched inference")
    parser.add_argument("--compute_type", default="float16", type=str, choices=["float16", "float32", "int8"], help="compute type for computation")

@@ -80,6 +81,7 @@ def cli():
    output_dir: str = args.pop("output_dir")
    output_format: str = args.pop("output_format")
    device: str = args.pop("device")
+    device_index: int = args.pop("device_index")
    compute_type: str = args.pop("compute_type")

    # model_flush: bool = args.pop("model_flush")
@@ -148,7 +150,7 @@ def cli():
    results = []
    tmp_results = []
    # model = load_model(model_name, device=device, download_root=model_dir)
-    model = load_model(model_name, device=device, compute_type=compute_type, language=args['language'], asr_options=asr_options, vad_options={"vad_onset": vad_onset, "vad_offset": vad_offset}, task=task)
+    model = load_model(model_name, device=device, device_index=device_index, compute_type=compute_type, language=args['language'], asr_options=asr_options, vad_options={"vad_onset": vad_onset, "vad_offset": vad_offset}, task=task)

    for audio_path in args.pop("audio"):
        audio = load_audio(audio_path)
whisperx/types.py (new file, 58 lines)
@@ -0,0 +1,58 @@
from typing import TypedDict, Optional


class SingleWordSegment(TypedDict):
    """
    A single word of a speech.
    """
    word: str
    start: float
    end: float
    score: float

class SingleCharSegment(TypedDict):
    """
    A single char of a speech.
    """
    char: str
    start: float
    end: float
    score: float


class SingleSegment(TypedDict):
    """
    A single segment (up to multiple sentences) of a speech.
    """

    start: float
    end: float
    text: str


class SingleAlignedSegment(TypedDict):
    """
    A single segment (up to multiple sentences) of a speech with word alignment.
    """

    start: float
    end: float
    text: str
    words: list[SingleWordSegment]
    chars: Optional[list[SingleCharSegment]]


class TranscriptionResult(TypedDict):
    """
    A list of segments and word segments of a speech.
    """
    segments: list[SingleSegment]
    language: str


class AlignedTranscriptionResult(TypedDict):
    """
    A list of segments and word segments of a speech.
    """
    segments: list[SingleAlignedSegment]
    word_segments: list[SingleWordSegment]
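As an editorial illustration, a minimal sketch of values that satisfy the new TypedDicts above (type checkers such as mypy validate the keys; at runtime these are plain dicts):

```python
from whisperx.types import SingleSegment, SingleWordSegment, TranscriptionResult

word: SingleWordSegment = {"word": "hello", "start": 0.0, "end": 0.4, "score": 0.98}
seg: SingleSegment = {"start": 0.0, "end": 2.5, "text": "hello world"}
result: TranscriptionResult = {"segments": [seg], "language": "en"}
```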
whisperx/vad.py

@@ -157,7 +157,7 @@ class Binarize:
                    curr_scores = curr_scores[min_score_div_idx+1:]
                    curr_timestamps = curr_timestamps[min_score_div_idx+1:]
                # switching from active to inactive
-                elif y <= self.offset:
+                elif y < self.offset:
                    region = Segment(start - self.pad_onset, t + self.pad_offset)
                    active[region, k] = label
                    start = t
@@ -169,7 +169,7 @@ class Binarize:
            # currently inactive
            else:
                # switching from inactive to active
-                if y >= self.onset:
+                if y > self.onset:
                    start = t
                    is_active = True
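A hedged note on the change above: `Binarize` applies hysteresis thresholding to frame-level VAD scores, and switching to strict comparisons means a score exactly equal to `onset`/`offset` no longer flips the state. A simplified sketch of the idea (not the pyannote-derived implementation itself):

```python
def binarize(scores, onset=0.5, offset=0.35):
    """Toy hysteresis thresholding over a list of frame scores."""
    regions, active, start = [], False, None
    for t, y in enumerate(scores):
        if not active and y > onset:        # switching from inactive to active
            active, start = True, t
        elif active and y < offset:         # switching from active to inactive
            active = False
            regions.append((start, t))
    if active:
        regions.append((start, len(scores)))
    return regions

print(binarize([0.1, 0.6, 0.7, 0.4, 0.2, 0.1]))  # [(1, 4)]
```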