Merge pull request #290 from m-bain/main

push contributions from main
Max Bain
2023-05-29 12:55:24 +01:00
committed by GitHub
6 changed files with 127 additions and 47 deletions


@ -52,13 +52,6 @@ This repository provides fast automatic speech recognition (70x realtime with la
**Speaker Diarization** is the process of partitioning an audio stream containing human speech into homogeneous segments according to the identity of each speaker.
- v3 pre-release [this branch](https://github.com/m-bain/whisperX/tree/v3) *70x speed-up open-sourced. Using batched whisper with faster-whisper backend*!
- v2 released, code cleanup, imports whisper library. VAD filtering is now turned on by default, as in the paper.
- Paper drop🎓👨🏫! Please see our [arXiv preprint](https://arxiv.org/abs/2303.00747) for benchmarking and details of WhisperX. We also introduce more efficient batch inference, resulting in large-v2 with *60-70x REAL TIME speed* (not provided in this repo).
- VAD filtering: Voice Activity Detection (VAD) from [Pyannote.audio](https://huggingface.co/pyannote/voice-activity-detection) is used as a preprocessing step to remove reliance on whisper timestamps and only transcribe audio segments containing speech. Add the `--vad_filter True` flag; this increases timestamp accuracy and robustness (requires more GPU memory due to 30s inputs in wav2vec2)
- Character level timestamps (see `*.char.ass` file output)
- Diarization (still in beta, add `--diarize`)
<h2 align="left" id="highlights">New🚨</h2>
- v3 transcript segment-per-sentence: using nltk sent_tokenize for better subtitling & better diarization
@ -87,11 +80,11 @@ See other methods [here.](https://pytorch.org/get-started/previous-versions/#v20
### 3. Install this repo
`pip install git+https://github.com/m-bain/whisperx.git@v3`
`pip install git+https://github.com/m-bain/whisperx.git`
If already installed, update the package to the most recent commit:
`pip install git+https://github.com/m-bain/whisperx.git@v3 --upgrade`
`pip install git+https://github.com/m-bain/whisperx.git --upgrade`
If wishing to modify this package, clone and install in editable mode:
```
@ -183,10 +176,10 @@ print(result["segments"]) # after alignment
diarize_model = whisperx.DiarizationPipeline(use_auth_token=YOUR_HF_TOKEN, device=device)
# add min/max number of speakers if known
diarize_segments = diarize_model(input_audio_path)
# diarize_model(input_audio_path, min_speakers=min_speakers, max_speakers=max_speakers)
diarize_segments = diarize_model(audio_file)
# diarize_model(audio_file, min_speakers=min_speakers, max_speakers=max_speakers)
result = assign_word_speakers(diarize_segments, result)
result = whisperx.assign_word_speakers(diarize_segments, result)
print(diarize_segments)
print(result["segments"]) # segments are now assigned speaker IDs
```


@ -3,7 +3,7 @@ Forced Alignment with Whisper
C. Max Bain
"""
from dataclasses import dataclass
from typing import Iterator, Union
from typing import Iterator, Union, List
import numpy as np
import pandas as pd
@ -13,7 +13,11 @@ from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
from .audio import SAMPLE_RATE, load_audio
from .utils import interpolate_nans
from .types import AlignedTranscriptionResult, SingleSegment, SingleAlignedSegment, SingleWordSegment
import nltk
from nltk.tokenize.punkt import PunktSentenceTokenizer, PunktParameters
PUNKT_ABBREVIATIONS = ['dr', 'vs', 'mr', 'mrs', 'prof']
LANGUAGES_WITHOUT_SPACES = ["ja", "zh"]
@ -32,6 +36,7 @@ DEFAULT_ALIGN_MODELS_HF = {
"uk": "Yehor/wav2vec2-xls-r-300m-uk-with-small-lm",
"pt": "jonatasgrosman/wav2vec2-large-xlsr-53-portuguese",
"ar": "jonatasgrosman/wav2vec2-large-xlsr-53-arabic",
"cs": "comodoro/wav2vec2-xls-r-300m-cs-250",
"ru": "jonatasgrosman/wav2vec2-large-xlsr-53-russian",
"pl": "jonatasgrosman/wav2vec2-large-xlsr-53-polish",
"hu": "jonatasgrosman/wav2vec2-large-xlsr-53-hungarian",
@ -39,7 +44,10 @@ DEFAULT_ALIGN_MODELS_HF = {
"fa": "jonatasgrosman/wav2vec2-large-xlsr-53-persian",
"el": "jonatasgrosman/wav2vec2-large-xlsr-53-greek",
"tr": "mpoyraz/wav2vec2-xls-r-300m-cv7-turkish",
"da": "saattrupdan/wav2vec2-xls-r-300m-ftspeech",
"he": "imvladikon/wav2vec2-xls-r-300m-hebrew",
"vi": 'nguyenvulebinh/wav2vec2-base-vi',
"ko": "kresnik/wav2vec2-large-xlsr-korean",
}
@ -80,14 +88,14 @@ def load_align_model(language_code, device, model_name=None, model_dir=None):
def align(
transcript: Iterator[dict],
transcript: Iterator[SingleSegment],
model: torch.nn.Module,
align_model_metadata: dict,
audio: Union[str, np.ndarray, torch.Tensor],
device: str,
interpolate_method: str = "nearest",
return_char_alignments: bool = False,
):
) -> AlignedTranscriptionResult:
"""
Align phoneme recognition predictions to known transcription.
"""
@ -139,14 +147,18 @@ def align(
if any([c in model_dictionary.keys() for c in wrd]):
clean_wdx.append(wdx)
sentence_spans = list(nltk.tokenize.punkt.PunktSentenceTokenizer().span_tokenize(text))
punkt_param = PunktParameters()
punkt_param.abbrev_types = set(PUNKT_ABBREVIATIONS)
sentence_splitter = PunktSentenceTokenizer(punkt_param)
sentence_spans = list(sentence_splitter.span_tokenize(text))
segment["clean_char"] = clean_char
segment["clean_cdx"] = clean_cdx
segment["clean_wdx"] = clean_wdx
segment["sentence_spans"] = sentence_spans
aligned_segments = []
aligned_segments: List[SingleAlignedSegment] = []
# 2. Get prediction matrix from alignment model & align
for sdx, segment in enumerate(transcript):
@ -154,7 +166,7 @@ def align(
t2 = segment["end"]
text = segment["text"]
aligned_seg = {
aligned_seg: SingleAlignedSegment = {
"start": t1,
"end": t2,
"text": text,
@ -307,7 +319,7 @@ def align(
aligned_segments += aligned_subsegments
# create word_segments list
word_segments = []
word_segments: List[SingleWordSegment] = []
for segment in aligned_segments:
word_segments += segment["words"]
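
The sentence-splitting change above replaces the default Punkt tokenizer with one configured with a small abbreviation list, so segments are no longer cut at tokens like "Dr.". A minimal standalone sketch of the same NLTK API (requires `nltk` installed; the example text is illustrative, not from the repo):

```
from nltk.tokenize.punkt import PunktParameters, PunktSentenceTokenizer

# Abbreviations that should not terminate a sentence (mirrors PUNKT_ABBREVIATIONS above).
punkt_param = PunktParameters()
punkt_param.abbrev_types = {"dr", "vs", "mr", "mrs", "prof"}

sentence_splitter = PunktSentenceTokenizer(punkt_param)

text = "Dr. Smith arrived at noon. She began the lecture immediately."
# span_tokenize yields (start, end) character offsets, one pair per sentence.
for start, end in sentence_splitter.span_tokenize(text):
    print(start, end, text[start:end])
```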


@ -11,10 +11,18 @@ from transformers.pipelines.pt_utils import PipelineIterator
from .audio import N_SAMPLES, SAMPLE_RATE, load_audio, log_mel_spectrogram
from .vad import load_vad_model, merge_chunks
from .types import TranscriptionResult, SingleSegment
def load_model(whisper_arch, device, compute_type="float16", asr_options=None, language=None,
vad_options=None, model=None, task="transcribe"):
def load_model(whisper_arch,
device,
device_index=0,
compute_type="float16",
asr_options=None,
language=None,
vad_options=None,
model=None,
task="transcribe",
download_root=None):
'''Load a Whisper model for inference.
Args:
whisper_arch: str - The name of the Whisper model to load.
@ -22,14 +30,19 @@ def load_model(whisper_arch, device, compute_type="float16", asr_options=None, l
compute_type: str - The compute type to use for the model.
asr_options: dict - A dictionary of options to use for the model.
language: str - The language of the model. (use English for now)
download_root: Optional[str] - The root directory to download the model to.
Returns:
A Whisper pipeline.
'''
'''
if whisper_arch.endswith(".en"):
language = "en"
model = WhisperModel(whisper_arch, device=device, compute_type=compute_type)
model = WhisperModel(whisper_arch,
device=device,
device_index=device_index,
compute_type=compute_type,
download_root=download_root)
if language is not None:
tokenizer = faster_whisper.tokenizer.Tokenizer(model.hf_tokenizer, model.model.is_multilingual, task=task, language=language)
else:
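
For context on the expanded `load_model` signature, here is a hedged usage sketch (it assumes the package is importable as `whisperx` and that a CUDA device with index 1 exists; the checkpoint name and download path are illustrative):

```
import whisperx

# Load the batched faster-whisper backend on the second GPU and cache weights locally.
model = whisperx.load_model(
    "large-v2",
    device="cuda",
    device_index=1,            # added in this change: which GPU to run on
    compute_type="float16",
    download_root="./models",  # added in this change: where to store downloaded weights
)
```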
@ -114,7 +127,7 @@ class WhisperModel(faster_whisper.WhisperModel):
# suppress_tokens=options.suppress_tokens,
# max_initial_timestamp_index=max_initial_timestamp_index,
)
tokens_batch = [x.sequences_ids[0] for x in result]
def decode_batch(tokens: List[List[int]]) -> str:
@ -127,7 +140,7 @@ class WhisperModel(faster_whisper.WhisperModel):
text = decode_batch(tokens_batch)
return text
def encode(self, features: np.ndarray) -> ctranslate2.StorageView:
# When the model is running on multiple GPUs, the encoder output should be moved
# to the CPU since we don't know which GPU will handle the next job.
@ -136,9 +149,9 @@ class WhisperModel(faster_whisper.WhisperModel):
if len(features.shape) == 2:
features = np.expand_dims(features, 0)
features = faster_whisper.transcribe.get_ctranslate2_storage(features)
return self.model.encode(features, to_cpu=to_cpu)
class FasterWhisperPipeline(Pipeline):
"""
Huggingface Pipeline wrapper for FasterWhisperModel.
@ -176,7 +189,7 @@ class FasterWhisperPipeline(Pipeline):
self.device = torch.device(f"cuda:{device}")
else:
self.device = device
super(Pipeline, self).__init__()
self.vad_model = vad
@ -194,7 +207,7 @@ class FasterWhisperPipeline(Pipeline):
def _forward(self, model_inputs):
outputs = self.model.generate_segment_batched(model_inputs['inputs'], self.tokenizer, self.options)
return {'text': outputs}
def postprocess(self, model_outputs):
return model_outputs
@ -214,11 +227,11 @@ class FasterWhisperPipeline(Pipeline):
return final_iterator
def transcribe(
self, audio: Union[str, np.ndarray], batch_size=None, num_workers=0
):
self, audio: Union[str, np.ndarray], batch_size=None, num_workers=0, language=None, task=None
) -> TranscriptionResult:
if isinstance(audio, str):
audio = load_audio(audio)
def data(audio, segments):
for seg in segments:
f1 = int(seg['start'] * SAMPLE_RATE)
@ -228,16 +241,21 @@ class FasterWhisperPipeline(Pipeline):
vad_segments = self.vad_model({"waveform": torch.from_numpy(audio).unsqueeze(0), "sample_rate": SAMPLE_RATE})
vad_segments = merge_chunks(vad_segments, 30)
del_tokenizer = False
if self.tokenizer is None:
language = self.detect_language(audio)
self.tokenizer = faster_whisper.tokenizer.Tokenizer(self.model.hf_tokenizer, self.model.model.is_multilingual, task="transcribe", language=language)
del_tokenizer = True
language = language or self.detect_language(audio)
task = task or "transcribe"
self.tokenizer = faster_whisper.tokenizer.Tokenizer(self.model.hf_tokenizer,
self.model.model.is_multilingual, task=task,
language=language)
else:
language = self.tokenizer.language_code
language = language or self.tokenizer.language_code
task = task or self.tokenizer.task
if task != self.tokenizer.task or language != self.tokenizer.language_code:
self.tokenizer = faster_whisper.tokenizer.Tokenizer(self.model.hf_tokenizer,
self.model.model.is_multilingual, task=task,
language=language)
segments = []
segments: List[SingleSegment] = []
batch_size = batch_size or self._batch_size
for idx, out in enumerate(self.__call__(data(audio, vad_segments), batch_size=batch_size, num_workers=num_workers)):
text = out['text']
@ -245,14 +263,11 @@ class FasterWhisperPipeline(Pipeline):
text = text[0]
segments.append(
{
"text": out['text'],
"text": text,
"start": round(vad_segments[idx]['start'], 3),
"end": round(vad_segments[idx]['end'], 3)
}
)
if del_tokenizer:
self.tokenizer = None
return {"segments": segments, "language": language}


@ -21,6 +21,7 @@ def cli():
parser.add_argument("--model", default="small", help="name of the Whisper model to use")
parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
parser.add_argument("--device_index", default=0, type=int, help="device index to use for FasterWhisper inference")
parser.add_argument("--batch_size", default=8, type=int, help="device to use for PyTorch inference")
parser.add_argument("--compute_type", default="float16", type=str, choices=["float16", "float32", "int8"], help="compute type for computation")
@ -78,6 +79,7 @@ def cli():
output_dir: str = args.pop("output_dir")
output_format: str = args.pop("output_format")
device: str = args.pop("device")
device_index: int = args.pop("device_index")
compute_type: str = args.pop("compute_type")
# model_flush: bool = args.pop("model_flush")
@ -144,7 +146,7 @@ def cli():
results = []
tmp_results = []
# model = load_model(model_name, device=device, download_root=model_dir)
model = load_model(model_name, device=device, compute_type=compute_type, language=args['language'], asr_options=asr_options, vad_options={"vad_onset": vad_onset, "vad_offset": vad_offset}, task=task)
model = load_model(model_name, device=device, device_index=device_index, compute_type=compute_type, language=args['language'], asr_options=asr_options, vad_options={"vad_onset": vad_onset, "vad_offset": vad_offset}, task=task)
for audio_path in args.pop("audio"):
audio = load_audio(audio_path)

whisperx/types.py (new file, 58 lines)

@ -0,0 +1,58 @@
from typing import TypedDict, Optional
class SingleWordSegment(TypedDict):
"""
A single word of speech.
"""
word: str
start: float
end: float
score: float
class SingleCharSegment(TypedDict):
"""
A single character of speech.
"""
char: str
start: float
end: float
score: float
class SingleSegment(TypedDict):
"""
A single segment (up to multiple sentences) of speech.
"""
start: float
end: float
text: str
class SingleAlignedSegment(TypedDict):
"""
A single segment (up to multiple sentences) of speech, with word alignment.
"""
start: float
end: float
text: str
words: list[SingleWordSegment]
chars: Optional[list[SingleCharSegment]]
class TranscriptionResult(TypedDict):
"""
The transcription segments and detected language of a speech recording.
"""
segments: list[SingleSegment]
language: str
class AlignedTranscriptionResult(TypedDict):
"""
The aligned segments and word-level segments of a speech recording.
"""
segments: list[SingleAlignedSegment]
word_segments: list[SingleWordSegment]
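
These TypedDicts only describe dict shapes for type checkers; at runtime the results are plain dicts. A small hedged sketch of how downstream code might consume them (the helper name is hypothetical, not part of the package):

```
from whisperx.types import AlignedTranscriptionResult

def format_segments(result: AlignedTranscriptionResult) -> list[str]:
    """Render aligned segments as simple 'start-end: text' lines (illustrative only)."""
    lines: list[str] = []
    for segment in result["segments"]:
        lines.append(f"{segment['start']:.2f}-{segment['end']:.2f}: {segment['text']}")
    return lines
```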


@ -157,7 +157,7 @@ class Binarize:
curr_scores = curr_scores[min_score_div_idx+1:]
curr_timestamps = curr_timestamps[min_score_div_idx+1:]
# switching from active to inactive
elif y <= self.offset:
elif y < self.offset:
region = Segment(start - self.pad_onset, t + self.pad_offset)
active[region, k] = label
start = t
@ -169,7 +169,7 @@ class Binarize:
# currently inactive
else:
# switching from inactive to active
if y >= self.onset:
if y > self.onset:
start = t
is_active = True
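
The two inequality changes above make the onset/offset thresholds strict, so a frame scoring exactly at a threshold no longer toggles the state. A minimal, simplified sketch of the same hysteresis idea, detached from pyannote's Annotation/Segment types (function and variable names are illustrative):

```
from typing import List, Tuple

def binarize(times: List[float], scores: List[float],
             onset: float = 0.5, offset: float = 0.35) -> List[Tuple[float, float]]:
    """Turn frame-level speech scores into (start, end) regions using hysteresis."""
    regions: List[Tuple[float, float]] = []
    is_active = False
    start = 0.0
    for t, y in zip(times, scores):
        if is_active:
            if y < offset:            # switching from active to inactive (strict, as above)
                regions.append((start, t))
                is_active = False
        else:
            if y > onset:             # switching from inactive to active (strict, as above)
                start = t
                is_active = True
    if is_active:                     # close a region still open at the end of the stream
        regions.append((start, times[-1]))
    return regions
```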