diff --git a/README.md b/README.md
index 28345f1..b52401b 100644
--- a/README.md
+++ b/README.md
@@ -54,6 +54,7 @@ This repository provides fast automatic speech recognition (70x realtime with la
New🚨
+- _WhisperX_ accepted at INTERSPEECH 2023
- v3 transcript segment-per-sentence: using nltk sent_tokenize for better subtitling & better diarization
- v3 released, 70x speed-up open-sourced. Using batched whisper with [faster-whisper](https://github.com/guillaumekln/faster-whisper) backend!
- v2 released: code cleanup, imports whisper library. VAD filtering is now turned on by default, as in the paper.
@@ -74,7 +75,7 @@ GPU execution requires the NVIDIA libraries cuBLAS 11.x and cuDNN 8.x to be inst
### 2. Install PyTorch 2.0, e.g. for Linux and Windows CUDA 11.8:
-`conda install pytorch==2.0.0 torchvision==0.15.0 torchaudio==2.0.0 pytorch-cuda=11.8 -c pytorch -c nvidia`
+`conda install pytorch==2.0.0 torchaudio==2.0.0 pytorch-cuda=11.8 -c pytorch -c nvidia`
See other methods [here.](https://pytorch.org/get-started/previous-versions/#v200)
@@ -184,6 +185,11 @@ print(diarize_segments)
print(result["segments"]) # segments are now assigned speaker IDs
```
+## Demos 🚀
+
+[Replicate demo](https://replicate.com/daanelson/whisperx)
+
+If you don't have access to your own GPUs, use the link above to try out WhisperX.
Technical Details 👷‍♂️
@@ -276,7 +282,7 @@ If you use this in your research, please cite the paper:
@article{bain2022whisperx,
title={WhisperX: Time-Accurate Speech Transcription of Long-Form Audio},
author={Bain, Max and Huh, Jaesung and Han, Tengda and Zisserman, Andrew},
- journal={arXiv preprint, arXiv:2303.00747},
+ journal={INTERSPEECH 2023},
year={2023}
}
```
diff --git a/requirements.txt b/requirements.txt
index ec90a07..ddfa28a 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,8 +1,8 @@
-torch==2.0.0
-torchaudio==2.0.1
+torch>=2
+torchaudio>=2
faster-whisper
transformers
-ffmpeg-python==0.2.0
+ffmpeg-python>=0.2
pandas
-setuptools==65.6.3
-nltk
\ No newline at end of file
+setuptools>=65
+nltk
diff --git a/whisperx/alignment.py b/whisperx/alignment.py
index 13dfddc..2717bc4 100644
--- a/whisperx/alignment.py
+++ b/whisperx/alignment.py
@@ -15,6 +15,9 @@ from .audio import SAMPLE_RATE, load_audio
from .utils import interpolate_nans
from .types import AlignedTranscriptionResult, SingleSegment, SingleAlignedSegment, SingleWordSegment
import nltk
+from nltk.tokenize.punkt import PunktSentenceTokenizer, PunktParameters
+
+PUNKT_ABBREVIATIONS = ['dr', 'vs', 'mr', 'mrs', 'prof']
LANGUAGES_WITHOUT_SPACES = ["ja", "zh"]
@@ -33,6 +36,7 @@ DEFAULT_ALIGN_MODELS_HF = {
"uk": "Yehor/wav2vec2-xls-r-300m-uk-with-small-lm",
"pt": "jonatasgrosman/wav2vec2-large-xlsr-53-portuguese",
"ar": "jonatasgrosman/wav2vec2-large-xlsr-53-arabic",
+ "cs": "comodoro/wav2vec2-xls-r-300m-cs-250",
"ru": "jonatasgrosman/wav2vec2-large-xlsr-53-russian",
"pl": "jonatasgrosman/wav2vec2-large-xlsr-53-polish",
"hu": "jonatasgrosman/wav2vec2-large-xlsr-53-hungarian",
@@ -42,6 +46,9 @@ DEFAULT_ALIGN_MODELS_HF = {
"tr": "mpoyraz/wav2vec2-xls-r-300m-cv7-turkish",
"da": "saattrupdan/wav2vec2-xls-r-300m-ftspeech",
"he": "imvladikon/wav2vec2-xls-r-300m-hebrew",
+ "vi": 'nguyenvulebinh/wav2vec2-base-vi',
+ "ko": "kresnik/wav2vec2-large-xlsr-korean",
+ "ur": "kingabzpro/wav2vec2-large-xls-r-300m-Urdu",
}
@@ -141,7 +148,11 @@ def align(
if any([c in model_dictionary.keys() for c in wrd]):
clean_wdx.append(wdx)
- sentence_spans = list(nltk.tokenize.punkt.PunktSentenceTokenizer().span_tokenize(text))
+
+ punkt_param = PunktParameters()
+ punkt_param.abbrev_types = set(PUNKT_ABBREVIATIONS)
+ sentence_splitter = PunktSentenceTokenizer(punkt_param)
+ sentence_spans = list(sentence_splitter.span_tokenize(text))
segment["clean_char"] = clean_char
segment["clean_cdx"] = clean_cdx
@@ -300,6 +311,8 @@ def align(
aligned_subsegments["end"] = interpolate_nans(aligned_subsegments["end"], method=interpolate_method)
# concatenate sentences with same timestamps
agg_dict = {"text": " ".join, "words": "sum"}
+ if model_lang in LANGUAGES_WITHOUT_SPACES:
+ agg_dict["text"] = "".join
if return_char_alignments:
agg_dict["chars"] = "sum"
aligned_subsegments= aligned_subsegments.groupby(["start", "end"], as_index=False).agg(agg_dict)
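
A side note on the aggregator switch: pandas accepts any callable in `agg`, so swapping `" ".join` for `"".join` is all it takes to avoid injecting spurious spaces when merging Japanese or Chinese sub-sentences that share timestamps. A toy sketch:

```python
import pandas as pd

# Two sub-sentences that ended up with identical timestamps.
df = pd.DataFrame({
    "start": [0.0, 0.0],
    "end": [1.5, 1.5],
    "text": ["こんにちは", "世界"],
})

# " ".join would yield "こんにちは 世界"; for ja/zh the patch uses "".join.
merged = df.groupby(["start", "end"], as_index=False).agg({"text": "".join})
print(merged["text"].iloc[0])  # こんにちは世界 (no stray space)
```
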
diff --git a/whisperx/asr.py b/whisperx/asr.py
index 88d5bf6..09454c9 100644
--- a/whisperx/asr.py
+++ b/whisperx/asr.py
@@ -13,8 +13,25 @@ from .audio import N_SAMPLES, SAMPLE_RATE, load_audio, log_mel_spectrogram
from .vad import load_vad_model, merge_chunks
from .types import TranscriptionResult, SingleSegment
-def load_model(whisper_arch, device, compute_type="float16", asr_options=None, language=None,
- vad_options=None, model=None, task="transcribe"):
+def find_numeral_symbol_tokens(tokenizer):
+ numeral_symbol_tokens = []
+ for i in range(tokenizer.eot):
+ token = tokenizer.decode([i]).removeprefix(" ")
+ has_numeral_symbol = any(c in "0123456789%$£" for c in token)
+ if has_numeral_symbol:
+ numeral_symbol_tokens.append(i)
+ return numeral_symbol_tokens
+
+def load_model(whisper_arch,
+ device,
+ device_index=0,
+ compute_type="float16",
+ asr_options=None,
+ language=None,
+ vad_options=None,
+ model=None,
+ task="transcribe",
+ download_root=None):
'''Load a Whisper model for inference.
Args:
whisper_arch: str - The name of the Whisper model to load.
@@ -22,14 +39,19 @@ def load_model(whisper_arch, device, compute_type="float16", asr_options=None, l
compute_type: str - The compute type to use for the model.
options: dict - A dictionary of options to use for the model.
language: str - The language of the model. (use English for now)
+ download_root: Optional[str] - The root directory to download the model to.
Returns:
A Whisper pipeline.
- '''
+ '''
if whisper_arch.endswith(".en"):
language = "en"
- model = WhisperModel(whisper_arch, device=device, compute_type=compute_type)
+ model = WhisperModel(whisper_arch,
+ device=device,
+ device_index=device_index,
+ compute_type=compute_type,
+ download_root=download_root)
if language is not None:
tokenizer = faster_whisper.tokenizer.Tokenizer(model.hf_tokenizer, model.model.is_multilingual, task=task, language=language)
else:
@@ -54,11 +76,22 @@ def load_model(whisper_arch, device, compute_type="float16", asr_options=None, l
"max_initial_timestamp": 0.0,
"word_timestamps": False,
"prepend_punctuations": "\"'“¿([{-",
- "append_punctuations": "\"'.。,,!!??::”)]}、"
+ "append_punctuations": "\"'.。,,!!??::”)]}、",
+ "suppress_numerals": False,
}
if asr_options is not None:
default_asr_options.update(asr_options)
+
+ if default_asr_options["suppress_numerals"]:
+ if tokenizer is None:
+ tokenizer = faster_whisper.tokenizer.Tokenizer(model.hf_tokenizer, model.model.is_multilingual, task=task, language="en")
+ numeral_symbol_tokens = find_numeral_symbol_tokens(tokenizer)
+ print(f"Suppressing numeral and symbol tokens: {numeral_symbol_tokens}")
+ default_asr_options["suppress_tokens"] += numeral_symbol_tokens
+ default_asr_options["suppress_tokens"] = list(set(default_asr_options["suppress_tokens"]))
+ del default_asr_options["suppress_numerals"]
+
default_asr_options = faster_whisper.transcribe.TranscriptionOptions(**default_asr_options)
default_vad_options = {
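
To see the new options end to end, a hedged usage sketch (assuming the top-level `whisperx` package re-exports `load_model`/`load_audio`; the audio path is a placeholder): `suppress_numerals` is consumed before `TranscriptionOptions` is built, expanding `suppress_tokens` with every token containing a digit or currency symbol.

```python
import whisperx

model = whisperx.load_model(
    "large-v2",
    device="cuda",
    device_index=0,                           # new: select a specific GPU
    compute_type="float16",
    asr_options={"suppress_numerals": True},  # new: prefer "seven" over "7"
    download_root="./models",                 # new: local model cache
)
audio = whisperx.load_audio("audio.wav")      # placeholder path
result = model.transcribe(audio, batch_size=8)
```
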
@@ -106,15 +139,12 @@ class WhisperModel(faster_whisper.WhisperModel):
result = self.model.generate(
encoder_output,
[prompt] * batch_size,
- # length_penalty=options.length_penalty,
- # max_length=self.max_length,
- # return_scores=True,
- # return_no_speech_prob=True,
- # suppress_blank=options.suppress_blank,
- # suppress_tokens=options.suppress_tokens,
- # max_initial_timestamp_index=max_initial_timestamp_index,
+ length_penalty=options.length_penalty,
+ max_length=self.max_length,
+ suppress_blank=options.suppress_blank,
+ suppress_tokens=options.suppress_tokens,
)
-
+
tokens_batch = [x.sequences_ids[0] for x in result]
def decode_batch(tokens: List[List[int]]) -> str:
@@ -127,7 +157,7 @@ class WhisperModel(faster_whisper.WhisperModel):
text = decode_batch(tokens_batch)
return text
-
+
def encode(self, features: np.ndarray) -> ctranslate2.StorageView:
# When the model is running on multiple GPUs, the encoder output should be moved
# to the CPU since we don't know which GPU will handle the next job.
@@ -136,9 +166,9 @@ class WhisperModel(faster_whisper.WhisperModel):
if len(features.shape) == 2:
features = np.expand_dims(features, 0)
features = faster_whisper.transcribe.get_ctranslate2_storage(features)
-
+
return self.model.encode(features, to_cpu=to_cpu)
-
+
class FasterWhisperPipeline(Pipeline):
"""
Huggingface Pipeline wrapper for FasterWhisperModel.
@@ -176,7 +206,7 @@ class FasterWhisperPipeline(Pipeline):
self.device = torch.device(f"cuda:{device}")
else:
self.device = device
-
+
super(Pipeline, self).__init__()
self.vad_model = vad
@@ -194,7 +224,7 @@ class FasterWhisperPipeline(Pipeline):
def _forward(self, model_inputs):
outputs = self.model.generate_segment_batched(model_inputs['inputs'], self.tokenizer, self.options)
return {'text': outputs}
-
+
def postprocess(self, model_outputs):
return model_outputs
@@ -214,11 +244,11 @@ class FasterWhisperPipeline(Pipeline):
return final_iterator
def transcribe(
- self, audio: Union[str, np.ndarray], batch_size=None, num_workers=0
+ self, audio: Union[str, np.ndarray], batch_size=None, num_workers=0, language=None, task=None
) -> TranscriptionResult:
if isinstance(audio, str):
audio = load_audio(audio)
-
+
def data(audio, segments):
for seg in segments:
f1 = int(seg['start'] * SAMPLE_RATE)
@@ -228,14 +258,19 @@ class FasterWhisperPipeline(Pipeline):
vad_segments = self.vad_model({"waveform": torch.from_numpy(audio).unsqueeze(0), "sample_rate": SAMPLE_RATE})
vad_segments = merge_chunks(vad_segments, 30)
-
- del_tokenizer = False
if self.tokenizer is None:
- language = self.detect_language(audio)
- self.tokenizer = faster_whisper.tokenizer.Tokenizer(self.model.hf_tokenizer, self.model.model.is_multilingual, task="transcribe", language=language)
- del_tokenizer = True
+ language = language or self.detect_language(audio)
+ task = task or "transcribe"
+ self.tokenizer = faster_whisper.tokenizer.Tokenizer(self.model.hf_tokenizer,
+ self.model.model.is_multilingual, task=task,
+ language=language)
else:
- language = self.tokenizer.language_code
+ language = language or self.tokenizer.language_code
+ task = task or self.tokenizer.task
+ if task != self.tokenizer.task or language != self.tokenizer.language_code:
+ self.tokenizer = faster_whisper.tokenizer.Tokenizer(self.model.hf_tokenizer,
+ self.model.model.is_multilingual, task=task,
+ language=language)
segments: List[SingleSegment] = []
batch_size = batch_size or self._batch_size
@@ -250,9 +285,6 @@ class FasterWhisperPipeline(Pipeline):
"end": round(vad_segments[idx]['end'], 3)
}
)
-
- if del_tokenizer:
- self.tokenizer = None
return {"segments": segments, "language": language}
diff --git a/whisperx/transcribe.py b/whisperx/transcribe.py
index 3edc746..1cc144e 100644
--- a/whisperx/transcribe.py
+++ b/whisperx/transcribe.py
@@ -21,11 +21,12 @@ def cli():
parser.add_argument("--model", default="small", help="name of the Whisper model to use")
parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
- parser.add_argument("--batch_size", default=8, type=int, help="device to use for PyTorch inference")
+ parser.add_argument("--device_index", default=0, type=int, help="device index to use for FasterWhisper inference")
+ parser.add_argument("--batch_size", default=8, type=int, help="the preferred batch size for inference")
parser.add_argument("--compute_type", default="float16", type=str, choices=["float16", "float32", "int8"], help="compute type for computation")
parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
- parser.add_argument("--output_format", "-f", type=str, default="all", choices=["all", "srt", "vtt", "txt", "tsv", "json"], help="format of the output file; if not specified, all available formats will be produced")
+ parser.add_argument("--output_format", "-f", type=str, default="all", choices=["all", "srt", "vtt", "txt", "tsv", "json", "aud"], help="format of the output file; if not specified, all available formats will be produced")
parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")
parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')")
@@ -50,9 +51,11 @@ def cli():
parser.add_argument("--best_of", type=optional_int, default=5, help="number of candidates when sampling with non-zero temperature")
parser.add_argument("--beam_size", type=optional_int, default=5, help="number of beams in beam search, only applicable when temperature is zero")
parser.add_argument("--patience", type=float, default=None, help="optional patience value to use in beam decoding, as in https://arxiv.org/abs/2204.05424, the default (1.0) is equivalent to conventional beam search")
- parser.add_argument("--length_penalty", type=float, default=None, help="optional token length penalty coefficient (alpha) as in https://arxiv.org/abs/1609.08144, uses simple length normalization by default")
+ parser.add_argument("--length_penalty", type=float, default=1.0, help="optional token length penalty coefficient (alpha) as in https://arxiv.org/abs/1609.08144, uses simple length normalization by default")
parser.add_argument("--suppress_tokens", type=str, default="-1", help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations")
+ parser.add_argument("--suppress_numerals", action="store_true", help="whether to suppress numeric symbols and currency symbols during sampling, since wav2vec2 cannot align them correctly")
+
parser.add_argument("--initial_prompt", type=str, default=None, help="optional text to provide as a prompt for the first window.")
parser.add_argument("--condition_on_previous_text", type=str2bool, default=False, help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop")
parser.add_argument("--fp16", type=str2bool, default=True, help="whether to perform inference in fp16; True by default")
@@ -78,6 +81,7 @@ def cli():
output_dir: str = args.pop("output_dir")
output_format: str = args.pop("output_format")
device: str = args.pop("device")
+ device_index: int = args.pop("device_index")
compute_type: str = args.pop("compute_type")
# model_flush: bool = args.pop("model_flush")
@@ -128,6 +132,8 @@ def cli():
"no_speech_threshold": args.pop("no_speech_threshold"),
"condition_on_previous_text": False,
"initial_prompt": args.pop("initial_prompt"),
+ "suppress_tokens": [int(x) for x in args.pop("suppress_tokens").split(",")],
+ "suppress_numerals": args.pop("suppress_numerals"),
}
writer = get_writer(output_format, output_dir)
@@ -144,7 +150,7 @@ def cli():
results = []
tmp_results = []
# model = load_model(model_name, device=device, download_root=model_dir)
- model = load_model(model_name, device=device, compute_type=compute_type, language=args['language'], asr_options=asr_options, vad_options={"vad_onset": vad_onset, "vad_offset": vad_offset}, task=task)
+ model = load_model(model_name, device=device, device_index=device_index, compute_type=compute_type, language=args['language'], asr_options=asr_options, vad_options={"vad_onset": vad_onset, "vad_offset": vad_offset}, task=task)
for audio_path in args.pop("audio"):
audio = load_audio(audio_path)
@@ -204,4 +210,4 @@ def cli():
writer(result, audio_path, writer_args)
if __name__ == "__main__":
- cli()
\ No newline at end of file
+ cli()
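
Taken together, the new flags compose naturally on the command line, e.g. `whisperx audio.wav --device_index 1 --suppress_numerals --output_format aud` (the filename is a placeholder): run on the second GPU, bias decoding away from digits and currency symbols, and emit an Audacity label file.
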
diff --git a/whisperx/types.py b/whisperx/types.py
index 75d4485..68f2d78 100644
--- a/whisperx/types.py
+++ b/whisperx/types.py
@@ -1,4 +1,4 @@
-from typing import TypedDict, Optional
+from typing import TypedDict, Optional, List
class SingleWordSegment(TypedDict):
@@ -38,15 +38,15 @@ class SingleAlignedSegment(TypedDict):
start: float
end: float
text: str
- words: list[SingleWordSegment]
- chars: Optional[list[SingleCharSegment]]
+ words: List[SingleWordSegment]
+ chars: Optional[List[SingleCharSegment]]
class TranscriptionResult(TypedDict):
"""
A list of segments and word segments of a speech.
"""
- segments: list[SingleSegment]
+ segments: List[SingleSegment]
language: str
@@ -54,5 +54,5 @@ class AlignedTranscriptionResult(TypedDict):
"""
A list of segments and word segments of a speech.
"""
- segments: list[SingleAlignedSegment]
- word_segments: list[SingleWordSegment]
+ segments: List[SingleAlignedSegment]
+ word_segments: List[SingleWordSegment]
diff --git a/whisperx/utils.py b/whisperx/utils.py
index d042bb7..36c7543 100644
--- a/whisperx/utils.py
+++ b/whisperx/utils.py
@@ -365,6 +365,28 @@ class WriteTSV(ResultWriter):
print(round(1000 * segment["end"]), file=file, end="\t")
print(segment["text"].strip().replace("\t", " "), file=file, flush=True)
+class WriteAudacity(ResultWriter):
+ """
+    Write a transcript to a text file that Audacity can import as labels.
+    The extension used is "aud" to distinguish it from the txt file produced by WriteTXT.
+    Note that this is not an Audacity project, only a label file!
+
+    Please note: Audacity uses seconds in timestamps, not ms!
+    Also, no header line is expected.
+
+    If a speaker is provided, it is prepended to the text between double square brackets [[]].
+ """
+
+ extension: str = "aud"
+
+ def write_result(self, result: dict, file: TextIO, options: dict):
+        ARROW = "\t"  # Audacity label tracks are tab-separated
+ for segment in result["segments"]:
+ print(segment["start"], file=file, end=ARROW)
+ print(segment["end"], file=file, end=ARROW)
+            print((("[[" + segment["speaker"] + "]]") if "speaker" in segment else "") + segment["text"].strip().replace("\t", " "), file=file, flush=True)
+
+
class WriteJSON(ResultWriter):
extension: str = "json"
@@ -383,6 +405,9 @@ def get_writer(
"tsv": WriteTSV,
"json": WriteJSON,
}
+ optional_writers = {
+ "aud": WriteAudacity,
+ }
if output_format == "all":
all_writers = [writer(output_dir) for writer in writers.values()]
@@ -393,10 +418,12 @@ def get_writer(
return write_all
+ if output_format in optional_writers:
+ return optional_writers[output_format](output_dir)
return writers[output_format](output_dir)
def interpolate_nans(x, method='nearest'):
if x.notnull().sum() > 1:
return x.interpolate(method=method).ffill().bfill()
else:
- return x.ffill().bfill()
\ No newline at end of file
+ return x.ffill().bfill()
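
A usage sketch for the new writer (paths hypothetical; `result` is the dict returned by `transcribe`/`align` earlier): `get_writer` now falls through to `optional_writers`, and the `.aud` output is a plain tab-separated label track that Audacity loads via File > Import > Labels.

```python
from whisperx.utils import get_writer

writer = get_writer("aud", ".")    # resolved through optional_writers
writer(result, "audio.wav", {})    # writes ./audio.aud

# audio.aud then contains tab-separated lines in seconds, e.g.:
# 0.009	2.567	[[SPEAKER_00]]Hello there.
# 3.102	5.431	[[SPEAKER_01]]Hi!
```
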
diff --git a/whisperx/vad.py b/whisperx/vad.py
index 42b0bfb..15a9e5e 100644
--- a/whisperx/vad.py
+++ b/whisperx/vad.py
@@ -147,8 +147,6 @@ class Binarize:
if is_active:
curr_duration = t - start
if curr_duration > self.max_duration:
- # if curr_duration > 15:
- # import pdb; pdb.set_trace()
search_after = len(curr_scores) // 2
# divide segment
min_score_div_idx = search_after + np.argmin(curr_scores[search_after:])
@@ -166,14 +164,14 @@ class Binarize:
is_active = False
curr_scores = []
curr_timestamps = []
+ curr_scores.append(y)
+ curr_timestamps.append(t)
# currently inactive
else:
# switching from inactive to active
if y > self.onset:
start = t
is_active = True
- curr_scores.append(y)
- curr_timestamps.append(t)
# if active at the end, add final region
if is_active:
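
Finally, the Binarize change is a bug fix, not a refactor: `curr_scores`/`curr_timestamps` previously kept growing even while the state machine was inactive, so the argmin used to split over-long segments could land outside the active region. With scores collected only while active, the division rule looks like this in isolation (toy numbers):

```python
import numpy as np

# Scores/timestamps accumulated only while a speech region is active.
curr_scores = [0.9, 0.8, 0.4, 0.7, 0.2, 0.6]
curr_timestamps = [0.0, 0.5, 1.0, 1.5, 2.0, 2.5]

# Split the over-long segment at the quietest point in its second half.
search_after = len(curr_scores) // 2
min_score_div_idx = search_after + int(np.argmin(curr_scores[search_after:]))
print(curr_timestamps[min_score_div_idx])  # 2.0
```
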