mirror of
https://github.com/m-bain/whisperX.git
synced 2025-07-01 18:17:27 -04:00
Merge branch 'main' into danish_alignment
This commit is contained in:
@ -1,115 +1,4 @@
|
||||
import hashlib
|
||||
import io
|
||||
import os
|
||||
import urllib
|
||||
import warnings
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from .audio import load_audio, log_mel_spectrogram, pad_or_trim
|
||||
from .decoding import DecodingOptions, DecodingResult, decode, detect_language
|
||||
from .model import Whisper, ModelDimensions
|
||||
from .transcribe import transcribe, transcribe_with_vad, transcribe_with_vad_parallel
|
||||
from .transcribe import load_model
|
||||
from .alignment import load_align_model, align
|
||||
_MODELS = {
|
||||
"tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",
|
||||
"tiny": "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt",
|
||||
"base.en": "https://openaipublic.azureedge.net/main/whisper/models/25a8566e1d0c1e2231d1c762132cd20e0f96a85d16145c3a00adf5d1ac670ead/base.en.pt",
|
||||
"base": "https://openaipublic.azureedge.net/main/whisper/models/ed3a0b6b1c0edf879ad9b11b1af5a0e6ab5db9205f891f668f8b0e6c6326e34e/base.pt",
|
||||
"small.en": "https://openaipublic.azureedge.net/main/whisper/models/f953ad0fd29cacd07d5a9eda5624af0f6bcf2258be67c92b79389873d91e0872/small.en.pt",
|
||||
"small": "https://openaipublic.azureedge.net/main/whisper/models/9ecf779972d90ba49c06d968637d720dd632c55bbf19d441fb42bf17a411e794/small.pt",
|
||||
"medium.en": "https://openaipublic.azureedge.net/main/whisper/models/d7440d1dc186f76616474e0ff0b3b6b879abc9d1a4926b7adfa41db2d497ab4f/medium.en.pt",
|
||||
"medium": "https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt",
|
||||
"large-v1": "https://openaipublic.azureedge.net/main/whisper/models/e4b87e7e0bf463eb8e6956e646f1e277e901512310def2c24bf0e11bd3c28e9a/large-v1.pt",
|
||||
"large-v2": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
|
||||
"large": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
|
||||
}
|
||||
|
||||
|
||||
def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
|
||||
os.makedirs(root, exist_ok=True)
|
||||
|
||||
expected_sha256 = url.split("/")[-2]
|
||||
download_target = os.path.join(root, os.path.basename(url))
|
||||
|
||||
if os.path.exists(download_target) and not os.path.isfile(download_target):
|
||||
raise RuntimeError(f"{download_target} exists and is not a regular file")
|
||||
|
||||
if os.path.isfile(download_target):
|
||||
with open(download_target, "rb") as f:
|
||||
model_bytes = f.read()
|
||||
if hashlib.sha256(model_bytes).hexdigest() == expected_sha256:
|
||||
return model_bytes if in_memory else download_target
|
||||
else:
|
||||
warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
|
||||
|
||||
with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
|
||||
with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop:
|
||||
while True:
|
||||
buffer = source.read(8192)
|
||||
if not buffer:
|
||||
break
|
||||
|
||||
output.write(buffer)
|
||||
loop.update(len(buffer))
|
||||
|
||||
model_bytes = open(download_target, "rb").read()
|
||||
if hashlib.sha256(model_bytes).hexdigest() != expected_sha256:
|
||||
raise RuntimeError("Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model.")
|
||||
|
||||
return model_bytes if in_memory else download_target
|
||||
|
||||
|
||||
def available_models() -> List[str]:
|
||||
"""Returns the names of available models"""
|
||||
return list(_MODELS.keys())
|
||||
|
||||
|
||||
def load_model(name: str, device: Optional[Union[str, torch.device]] = None, download_root: str = None, in_memory: bool = False) -> Whisper:
|
||||
"""
|
||||
Load a Whisper ASR model
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name : str
|
||||
one of the official model names listed by `whisper.available_models()`, or
|
||||
path to a model checkpoint containing the model dimensions and the model state_dict.
|
||||
device : Union[str, torch.device]
|
||||
the PyTorch device to put the model into
|
||||
download_root: str
|
||||
path to download the model files; by default, it uses "~/.cache/whisper"
|
||||
in_memory: bool
|
||||
whether to preload the model weights into host memory
|
||||
|
||||
Returns
|
||||
-------
|
||||
model : Whisper
|
||||
The Whisper ASR model instance
|
||||
"""
|
||||
|
||||
if device is None:
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
if download_root is None:
|
||||
download_root = os.getenv(
|
||||
"XDG_CACHE_HOME",
|
||||
os.path.join(os.path.expanduser("~"), ".cache", "whisper")
|
||||
)
|
||||
|
||||
if name in _MODELS:
|
||||
checkpoint_file = _download(_MODELS[name], download_root, in_memory)
|
||||
elif os.path.isfile(name):
|
||||
checkpoint_file = open(name, "rb").read() if in_memory else name
|
||||
else:
|
||||
raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
|
||||
|
||||
with (io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")) as fp:
|
||||
checkpoint = torch.load(fp, map_location=device)
|
||||
del checkpoint_file
|
||||
|
||||
dims = ModelDimensions(**checkpoint["dims"])
|
||||
model = Whisper(dims)
|
||||
model.load_state_dict(checkpoint["model_state_dict"])
|
||||
|
||||
return model.to(device)
|
||||
from .audio import load_audio
|
||||
from .diarize import assign_word_speakers, DiarizationPipeline
|
@ -2,16 +2,19 @@
|
||||
Forced Alignment with Whisper
|
||||
C. Max Bain
|
||||
"""
|
||||
from dataclasses import dataclass
|
||||
from typing import Iterator, Union, List
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from typing import List, Union, Iterator, TYPE_CHECKING
|
||||
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
|
||||
import torchaudio
|
||||
import torch
|
||||
from dataclasses import dataclass
|
||||
import torchaudio
|
||||
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
|
||||
|
||||
from .audio import SAMPLE_RATE, load_audio
|
||||
from .utils import interpolate_nans
|
||||
|
||||
from .types import AlignedTranscriptionResult, SingleSegment, SingleAlignedSegment, SingleWordSegment
|
||||
import nltk
|
||||
|
||||
LANGUAGES_WITHOUT_SPACES = ["ja", "zh"]
|
||||
|
||||
@ -37,11 +40,12 @@ DEFAULT_ALIGN_MODELS_HF = {
|
||||
"fa": "jonatasgrosman/wav2vec2-large-xlsr-53-persian",
|
||||
"el": "jonatasgrosman/wav2vec2-large-xlsr-53-greek",
|
||||
"tr": "mpoyraz/wav2vec2-xls-r-300m-cv7-turkish",
|
||||
"da": "saattrupdan/wav2vec2-xls-r-300m-ftspeech"
|
||||
"da": "saattrupdan/wav2vec2-xls-r-300m-ftspeech",
|
||||
"he": "imvladikon/wav2vec2-xls-r-300m-hebrew",
|
||||
}
|
||||
|
||||
|
||||
def load_align_model(language_code, device, model_name=None):
|
||||
def load_align_model(language_code, device, model_name=None, model_dir=None):
|
||||
if model_name is None:
|
||||
# use default model
|
||||
if language_code in DEFAULT_ALIGN_MODELS_TORCH:
|
||||
@ -56,7 +60,7 @@ def load_align_model(language_code, device, model_name=None):
|
||||
if model_name in torchaudio.pipelines.__all__:
|
||||
pipeline_type = "torchaudio"
|
||||
bundle = torchaudio.pipelines.__dict__[model_name]
|
||||
align_model = bundle.get_model().to(device)
|
||||
align_model = bundle.get_model(dl_kwargs={"model_dir": model_dir}).to(device)
|
||||
labels = bundle.get_labels()
|
||||
align_dictionary = {c.lower(): i for i, c in enumerate(labels)}
|
||||
else:
|
||||
@ -78,362 +82,232 @@ def load_align_model(language_code, device, model_name=None):
|
||||
|
||||
|
||||
def align(
|
||||
transcript: Iterator[dict],
|
||||
transcript: Iterator[SingleSegment],
|
||||
model: torch.nn.Module,
|
||||
align_model_metadata: dict,
|
||||
audio: Union[str, np.ndarray, torch.Tensor],
|
||||
device: str,
|
||||
extend_duration: float = 0.0,
|
||||
start_from_previous: bool = True,
|
||||
interpolate_method: str = "nearest",
|
||||
):
|
||||
return_char_alignments: bool = False,
|
||||
) -> AlignedTranscriptionResult:
|
||||
"""
|
||||
Align phoneme recognition predictions to known transcription.
|
||||
"""
|
||||
Force align phoneme recognition predictions to known transcription
|
||||
|
||||
Parameters
|
||||
----------
|
||||
transcript: Iterator[dict]
|
||||
The Whisper model instance
|
||||
|
||||
model: torch.nn.Module
|
||||
Alignment model (wav2vec2)
|
||||
|
||||
audio: Union[str, np.ndarray, torch.Tensor]
|
||||
The path to the audio file to open, or the audio waveform
|
||||
|
||||
device: str
|
||||
cuda device
|
||||
|
||||
diarization: pd.DataFrame {'start': List[float], 'end': List[float], 'speaker': List[float]}
|
||||
diarization segments with speaker labels.
|
||||
|
||||
extend_duration: float
|
||||
Amount to pad input segments by. If not using vad--filter then recommended to use 2 seconds
|
||||
|
||||
If the gzip compression ratio is above this value, treat as failed
|
||||
|
||||
interpolate_method: str ["nearest", "linear", "ignore"]
|
||||
Method to assign timestamps to non-aligned words. Words are not able to be aligned when none of the characters occur in the align model dictionary.
|
||||
"nearest" copies timestamp of nearest word within the segment. "linear" is linear interpolation. "drop" removes that word from output.
|
||||
|
||||
Returns
|
||||
-------
|
||||
A dictionary containing the resulting text ("text") and segment-level details ("segments"), and
|
||||
the spoken language ("language"), which is detected when `decode_options["language"]` is None.
|
||||
"""
|
||||
if not torch.is_tensor(audio):
|
||||
if isinstance(audio, str):
|
||||
audio = load_audio(audio)
|
||||
audio = torch.from_numpy(audio)
|
||||
if len(audio.shape) == 1:
|
||||
audio = audio.unsqueeze(0)
|
||||
|
||||
|
||||
MAX_DURATION = audio.shape[1] / SAMPLE_RATE
|
||||
|
||||
model_dictionary = align_model_metadata["dictionary"]
|
||||
model_lang = align_model_metadata["language"]
|
||||
model_type = align_model_metadata["type"]
|
||||
|
||||
aligned_segments = []
|
||||
|
||||
prev_t2 = 0
|
||||
|
||||
char_segments_arr = {
|
||||
"segment-idx": [],
|
||||
"subsegment-idx": [],
|
||||
"word-idx": [],
|
||||
"char": [],
|
||||
"start": [],
|
||||
"end": [],
|
||||
"score": [],
|
||||
}
|
||||
|
||||
# 1. Preprocess to keep only characters in dictionary
|
||||
for sdx, segment in enumerate(transcript):
|
||||
while True:
|
||||
segment_align_success = False
|
||||
# strip spaces at beginning / end, but keep track of the amount.
|
||||
num_leading = len(segment["text"]) - len(segment["text"].lstrip())
|
||||
num_trailing = len(segment["text"]) - len(segment["text"].rstrip())
|
||||
text = segment["text"]
|
||||
|
||||
# strip spaces at beginning / end, but keep track of the amount.
|
||||
num_leading = len(segment["text"]) - len(segment["text"].lstrip())
|
||||
num_trailing = len(segment["text"]) - len(segment["text"].rstrip())
|
||||
transcription = segment["text"]
|
||||
# split into words
|
||||
if model_lang not in LANGUAGES_WITHOUT_SPACES:
|
||||
per_word = text.split(" ")
|
||||
else:
|
||||
per_word = text
|
||||
|
||||
# TODO: convert number tokenizer / symbols to phonetic words for alignment.
|
||||
# e.g. "$300" -> "three hundred dollars"
|
||||
# currently "$300" is ignored since no characters present in the phonetic dictionary
|
||||
|
||||
# split into words
|
||||
clean_char, clean_cdx = [], []
|
||||
for cdx, char in enumerate(text):
|
||||
char_ = char.lower()
|
||||
# wav2vec2 models use "|" character to represent spaces
|
||||
if model_lang not in LANGUAGES_WITHOUT_SPACES:
|
||||
per_word = transcription.split(" ")
|
||||
else:
|
||||
per_word = transcription
|
||||
|
||||
# first check that characters in transcription can be aligned (they are contained in align model"s dictionary)
|
||||
clean_char, clean_cdx = [], []
|
||||
for cdx, char in enumerate(transcription):
|
||||
char_ = char.lower()
|
||||
# wav2vec2 models use "|" character to represent spaces
|
||||
if model_lang not in LANGUAGES_WITHOUT_SPACES:
|
||||
char_ = char_.replace(" ", "|")
|
||||
|
||||
# ignore whitespace at beginning and end of transcript
|
||||
if cdx < num_leading:
|
||||
pass
|
||||
elif cdx > len(transcription) - num_trailing - 1:
|
||||
pass
|
||||
elif char_ in model_dictionary.keys():
|
||||
clean_char.append(char_)
|
||||
clean_cdx.append(cdx)
|
||||
|
||||
clean_wdx = []
|
||||
for wdx, wrd in enumerate(per_word):
|
||||
if any([c in model_dictionary.keys() for c in wrd]):
|
||||
clean_wdx.append(wdx)
|
||||
|
||||
# if no characters are in the dictionary, then we skip this segment...
|
||||
if len(clean_char) == 0:
|
||||
print(f'Failed to align segment ("{segment["text"]}"): no characters in this segment found in model dictionary, resorting to original...')
|
||||
break
|
||||
|
||||
transcription_cleaned = "".join(clean_char)
|
||||
tokens = [model_dictionary[c] for c in transcription_cleaned]
|
||||
|
||||
# we only pad if not using VAD filtering
|
||||
if "seg_text" not in segment:
|
||||
# pad according original timestamps
|
||||
t1 = max(segment["start"] - extend_duration, 0)
|
||||
t2 = min(segment["end"] + extend_duration, MAX_DURATION)
|
||||
|
||||
# use prev_t2 as current t1 if it"s later
|
||||
if start_from_previous and t1 < prev_t2:
|
||||
t1 = prev_t2
|
||||
|
||||
# check if timestamp range is still valid
|
||||
if t1 >= MAX_DURATION:
|
||||
print("Failed to align segment: original start time longer than audio duration, skipping...")
|
||||
break
|
||||
if t2 - t1 < 0.02:
|
||||
print("Failed to align segment: duration smaller than 0.02s time precision")
|
||||
break
|
||||
|
||||
f1 = int(t1 * SAMPLE_RATE)
|
||||
f2 = int(t2 * SAMPLE_RATE)
|
||||
|
||||
waveform_segment = audio[:, f1:f2]
|
||||
|
||||
with torch.inference_mode():
|
||||
if model_type == "torchaudio":
|
||||
emissions, _ = model(waveform_segment.to(device))
|
||||
elif model_type == "huggingface":
|
||||
emissions = model(waveform_segment.to(device)).logits
|
||||
else:
|
||||
raise NotImplementedError(f"Align model of type {model_type} not supported.")
|
||||
emissions = torch.log_softmax(emissions, dim=-1)
|
||||
|
||||
emission = emissions[0].cpu().detach()
|
||||
|
||||
trellis = get_trellis(emission, tokens)
|
||||
path = backtrack(trellis, emission, tokens)
|
||||
if path is None:
|
||||
print(f'Failed to align segment ("{segment["text"]}"): backtrack failed, resorting to original...')
|
||||
break
|
||||
char_segments = merge_repeats(path, transcription_cleaned)
|
||||
# word_segments = merge_words(char_segments)
|
||||
char_ = char_.replace(" ", "|")
|
||||
|
||||
# ignore whitespace at beginning and end of transcript
|
||||
if cdx < num_leading:
|
||||
pass
|
||||
elif cdx > len(text) - num_trailing - 1:
|
||||
pass
|
||||
elif char_ in model_dictionary.keys():
|
||||
clean_char.append(char_)
|
||||
clean_cdx.append(cdx)
|
||||
|
||||
# sub-segments
|
||||
if "seg-text" not in segment:
|
||||
segment["seg-text"] = [transcription]
|
||||
|
||||
seg_lens = [0] + [len(x) for x in segment["seg-text"]]
|
||||
seg_lens_cumsum = list(np.cumsum(seg_lens))
|
||||
sub_seg_idx = 0
|
||||
clean_wdx = []
|
||||
for wdx, wrd in enumerate(per_word):
|
||||
if any([c in model_dictionary.keys() for c in wrd]):
|
||||
clean_wdx.append(wdx)
|
||||
|
||||
wdx = 0
|
||||
duration = t2 - t1
|
||||
ratio = duration * waveform_segment.size(0) / (trellis.size(0) - 1)
|
||||
for cdx, char in enumerate(transcription + " "):
|
||||
is_last = False
|
||||
if cdx == len(transcription):
|
||||
break
|
||||
elif cdx+1 == len(transcription):
|
||||
is_last = True
|
||||
|
||||
|
||||
start, end, score = None, None, None
|
||||
if cdx in clean_cdx:
|
||||
char_seg = char_segments[clean_cdx.index(cdx)]
|
||||
start = char_seg.start * ratio + t1
|
||||
end = char_seg.end * ratio + t1
|
||||
score = char_seg.score
|
||||
|
||||
char_segments_arr["char"].append(char)
|
||||
char_segments_arr["start"].append(start)
|
||||
char_segments_arr["end"].append(end)
|
||||
char_segments_arr["score"].append(score)
|
||||
char_segments_arr["word-idx"].append(wdx)
|
||||
char_segments_arr["segment-idx"].append(sdx)
|
||||
char_segments_arr["subsegment-idx"].append(sub_seg_idx)
|
||||
|
||||
# word-level info
|
||||
if model_lang in LANGUAGES_WITHOUT_SPACES:
|
||||
# character == word
|
||||
wdx += 1
|
||||
elif is_last or transcription[cdx+1] == " " or cdx == seg_lens_cumsum[sub_seg_idx+1] - 1:
|
||||
wdx += 1
|
||||
|
||||
if is_last or cdx == seg_lens_cumsum[sub_seg_idx+1] - 1:
|
||||
wdx = 0
|
||||
sub_seg_idx += 1
|
||||
|
||||
prev_t2 = segment["end"]
|
||||
|
||||
segment_align_success = True
|
||||
# end while True loop
|
||||
break
|
||||
|
||||
# reset prev_t2 due to drifting issues
|
||||
if not segment_align_success:
|
||||
prev_t2 = 0
|
||||
|
||||
char_segments_arr = pd.DataFrame(char_segments_arr)
|
||||
not_space = char_segments_arr["char"] != " "
|
||||
|
||||
per_seg_grp = char_segments_arr.groupby(["segment-idx", "subsegment-idx"], as_index = False)
|
||||
char_segments_arr = per_seg_grp.apply(lambda x: x.reset_index(drop = True)).reset_index()
|
||||
per_word_grp = char_segments_arr[not_space].groupby(["segment-idx", "subsegment-idx", "word-idx"])
|
||||
per_subseg_grp = char_segments_arr[not_space].groupby(["segment-idx", "subsegment-idx"])
|
||||
per_seg_grp = char_segments_arr[not_space].groupby(["segment-idx"])
|
||||
char_segments_arr["local-char-idx"] = char_segments_arr.groupby(["segment-idx", "subsegment-idx"]).cumcount()
|
||||
per_word_grp = char_segments_arr[not_space].groupby(["segment-idx", "subsegment-idx", "word-idx"]) # regroup
|
||||
|
||||
word_segments_arr = {}
|
||||
|
||||
# start of word is first char with a timestamp
|
||||
word_segments_arr["start"] = per_word_grp["start"].min().values
|
||||
# end of word is last char with a timestamp
|
||||
word_segments_arr["end"] = per_word_grp["end"].max().values
|
||||
# score of word is mean (excluding nan)
|
||||
word_segments_arr["score"] = per_word_grp["score"].mean().values
|
||||
|
||||
word_segments_arr["segment-text-start"] = per_word_grp["local-char-idx"].min().astype(int).values
|
||||
word_segments_arr["segment-text-end"] = per_word_grp["local-char-idx"].max().astype(int).values+1
|
||||
word_segments_arr = pd.DataFrame(word_segments_arr)
|
||||
|
||||
word_segments_arr[["segment-idx", "subsegment-idx", "word-idx"]] = per_word_grp["local-char-idx"].min().reset_index()[["segment-idx", "subsegment-idx", "word-idx"]].astype(int)
|
||||
segments_arr = {}
|
||||
segments_arr["start"] = per_subseg_grp["start"].min().reset_index()["start"]
|
||||
segments_arr["end"] = per_subseg_grp["end"].max().reset_index()["end"]
|
||||
segments_arr = pd.DataFrame(segments_arr)
|
||||
segments_arr[["segment-idx", "subsegment-idx-start"]] = per_subseg_grp["start"].min().reset_index()[["segment-idx", "subsegment-idx"]]
|
||||
segments_arr["subsegment-idx-end"] = segments_arr["subsegment-idx-start"] + 1
|
||||
|
||||
# interpolate missing words / sub-segments
|
||||
if interpolate_method != "ignore":
|
||||
wrd_subseg_grp = word_segments_arr.groupby(["segment-idx", "subsegment-idx"], group_keys=False)
|
||||
wrd_seg_grp = word_segments_arr.groupby(["segment-idx"], group_keys=False)
|
||||
# we still know which word timestamps are interpolated because their score == nan
|
||||
word_segments_arr["start"] = wrd_subseg_grp['start'].apply(lambda group: interpolate_nans(group, method=interpolate_method))
|
||||
word_segments_arr["end"] = wrd_subseg_grp['end'].apply(lambda group: interpolate_nans(group, method=interpolate_method))
|
||||
|
||||
word_segments_arr["start"] = wrd_seg_grp['start'].apply(lambda group: interpolate_nans(group, method=interpolate_method))
|
||||
word_segments_arr["end"] = wrd_seg_grp['end'].apply(lambda group: interpolate_nans(group, method=interpolate_method))
|
||||
|
||||
sub_seg_grp = segments_arr.groupby(["segment-idx"], group_keys=False)
|
||||
segments_arr['start'] = sub_seg_grp['start'].apply(lambda group: interpolate_nans(group, method=interpolate_method))
|
||||
segments_arr['end'] = sub_seg_grp['end'].apply(lambda group: interpolate_nans(group, method=interpolate_method))
|
||||
|
||||
# merge words & subsegments which are missing times
|
||||
word_grp = word_segments_arr.groupby(["segment-idx", "subsegment-idx", "end"])
|
||||
|
||||
word_segments_arr["segment-text-start"] = word_grp["segment-text-start"].transform(min)
|
||||
word_segments_arr["segment-text-end"] = word_grp["segment-text-end"].transform(max)
|
||||
word_segments_arr.drop_duplicates(subset=["segment-idx", "subsegment-idx", "end"], inplace=True)
|
||||
|
||||
seg_grp_dup = segments_arr.groupby(["segment-idx", "start", "end"])
|
||||
segments_arr["subsegment-idx-start"] = seg_grp_dup["subsegment-idx-start"].transform(min)
|
||||
segments_arr["subsegment-idx-end"] = seg_grp_dup["subsegment-idx-end"].transform(max)
|
||||
segments_arr.drop_duplicates(subset=["segment-idx", "subsegment-idx-start", "subsegment-idx-end"], inplace=True)
|
||||
else:
|
||||
word_segments_arr.dropna(inplace=True)
|
||||
segments_arr.dropna(inplace=True)
|
||||
|
||||
# if some segments still have missing timestamps (usually because all numerals / symbols), then use original timestamps...
|
||||
segments_arr['start'].fillna(pd.Series([x['start'] for x in transcript]), inplace=True)
|
||||
segments_arr['end'].fillna(pd.Series([x['end'] for x in transcript]), inplace=True)
|
||||
segments_arr['subsegment-idx-start'].fillna(0, inplace=True)
|
||||
segments_arr['subsegment-idx-end'].fillna(1, inplace=True)
|
||||
|
||||
|
||||
aligned_segments = []
|
||||
aligned_segments_word = []
|
||||
|
||||
word_segments_arr.set_index(["segment-idx", "subsegment-idx"], inplace=True)
|
||||
char_segments_arr.set_index(["segment-idx", "subsegment-idx", "word-idx"], inplace=True)
|
||||
|
||||
for sdx, srow in segments_arr.iterrows():
|
||||
|
||||
seg_idx = int(srow["segment-idx"])
|
||||
try:
|
||||
sub_start = int(srow["subsegment-idx-start"])
|
||||
except:
|
||||
import pdb; pdb.set_trace()
|
||||
sub_end = int(srow["subsegment-idx-end"])
|
||||
|
||||
seg = transcript[seg_idx]
|
||||
text = "".join(seg["seg-text"][sub_start:sub_end])
|
||||
|
||||
wseg = word_segments_arr.loc[seg_idx].loc[sub_start:sub_end-1]
|
||||
wseg["start"].fillna(srow["start"], inplace=True)
|
||||
wseg["end"].fillna(srow["end"], inplace=True)
|
||||
wseg["segment-text-start"].fillna(0, inplace=True)
|
||||
wseg["segment-text-end"].fillna(len(text)-1, inplace=True)
|
||||
|
||||
cseg = char_segments_arr.loc[seg_idx].loc[sub_start:sub_end-1]
|
||||
# fixes bug for single segment in transcript
|
||||
cseg['segment-text-start'] = cseg['level_1'] if 'level_1' in cseg else 0
|
||||
cseg['segment-text-end'] = cseg['level_1'] + 1 if 'level_1' in cseg else 1
|
||||
if 'level_1' in cseg: del cseg['level_1']
|
||||
if 'level_0' in cseg: del cseg['level_0']
|
||||
cseg.reset_index(inplace=True)
|
||||
aligned_segments.append(
|
||||
{
|
||||
"start": srow["start"],
|
||||
"end": srow["end"],
|
||||
"text": text,
|
||||
"word-segments": wseg,
|
||||
"char-segments": cseg
|
||||
}
|
||||
)
|
||||
|
||||
def get_raw_text(word_row):
|
||||
return seg["seg-text"][word_row.name][int(word_row["segment-text-start"]):int(word_row["segment-text-end"])+1]
|
||||
|
||||
wdx = 0
|
||||
curr_text = get_raw_text(wseg.iloc[wdx])
|
||||
if len(wseg) > 1:
|
||||
for _, wrow in wseg.iloc[1:].iterrows():
|
||||
if wrow['start'] != wseg.iloc[wdx]['start']:
|
||||
aligned_segments_word.append(
|
||||
{
|
||||
"text": curr_text.strip(),
|
||||
"start": wseg.iloc[wdx]["start"],
|
||||
"end": wseg.iloc[wdx]["end"],
|
||||
}
|
||||
)
|
||||
curr_text = ""
|
||||
curr_text += " " + get_raw_text(wrow)
|
||||
wdx += 1
|
||||
aligned_segments_word.append(
|
||||
{
|
||||
"text": curr_text.strip(),
|
||||
"start": wseg.iloc[wdx]["start"],
|
||||
"end": wseg.iloc[wdx]["end"]
|
||||
}
|
||||
)
|
||||
sentence_spans = list(nltk.tokenize.punkt.PunktSentenceTokenizer().span_tokenize(text))
|
||||
|
||||
segment["clean_char"] = clean_char
|
||||
segment["clean_cdx"] = clean_cdx
|
||||
segment["clean_wdx"] = clean_wdx
|
||||
segment["sentence_spans"] = sentence_spans
|
||||
|
||||
return {"segments": aligned_segments, "word_segments": aligned_segments_word}
|
||||
aligned_segments: List[SingleAlignedSegment] = []
|
||||
|
||||
# 2. Get prediction matrix from alignment model & align
|
||||
for sdx, segment in enumerate(transcript):
|
||||
t1 = segment["start"]
|
||||
t2 = segment["end"]
|
||||
text = segment["text"]
|
||||
|
||||
aligned_seg: SingleAlignedSegment = {
|
||||
"start": t1,
|
||||
"end": t2,
|
||||
"text": text,
|
||||
"words": [],
|
||||
}
|
||||
|
||||
if return_char_alignments:
|
||||
aligned_seg["chars"] = []
|
||||
|
||||
# check we can align
|
||||
if len(segment["clean_char"]) == 0:
|
||||
print(f'Failed to align segment ("{segment["text"]}"): no characters in this segment found in model dictionary, resorting to original...')
|
||||
aligned_segments.append(aligned_seg)
|
||||
continue
|
||||
|
||||
if t1 >= MAX_DURATION or t2 - t1 < 0.02:
|
||||
print("Failed to align segment: original start time longer than audio duration, skipping...")
|
||||
aligned_segments.append(aligned_seg)
|
||||
continue
|
||||
|
||||
text_clean = "".join(segment["clean_char"])
|
||||
tokens = [model_dictionary[c] for c in text_clean]
|
||||
|
||||
f1 = int(t1 * SAMPLE_RATE)
|
||||
f2 = int(t2 * SAMPLE_RATE)
|
||||
|
||||
# TODO: Probably can get some speedup gain with batched inference here
|
||||
waveform_segment = audio[:, f1:f2]
|
||||
|
||||
with torch.inference_mode():
|
||||
if model_type == "torchaudio":
|
||||
emissions, _ = model(waveform_segment.to(device))
|
||||
elif model_type == "huggingface":
|
||||
emissions = model(waveform_segment.to(device)).logits
|
||||
else:
|
||||
raise NotImplementedError(f"Align model of type {model_type} not supported.")
|
||||
emissions = torch.log_softmax(emissions, dim=-1)
|
||||
|
||||
emission = emissions[0].cpu().detach()
|
||||
|
||||
blank_id = 0
|
||||
for char, code in model_dictionary.items():
|
||||
if char == '[pad]' or char == '<pad>':
|
||||
blank_id = code
|
||||
|
||||
trellis = get_trellis(emission, tokens, blank_id)
|
||||
path = backtrack(trellis, emission, tokens, blank_id)
|
||||
|
||||
if path is None:
|
||||
print(f'Failed to align segment ("{segment["text"]}"): backtrack failed, resorting to original...')
|
||||
aligned_segments.append(aligned_seg)
|
||||
continue
|
||||
|
||||
char_segments = merge_repeats(path, text_clean)
|
||||
|
||||
duration = t2 -t1
|
||||
ratio = duration * waveform_segment.size(0) / (trellis.size(0) - 1)
|
||||
|
||||
# assign timestamps to aligned characters
|
||||
char_segments_arr = []
|
||||
word_idx = 0
|
||||
for cdx, char in enumerate(text):
|
||||
start, end, score = None, None, None
|
||||
if cdx in segment["clean_cdx"]:
|
||||
char_seg = char_segments[segment["clean_cdx"].index(cdx)]
|
||||
start = round(char_seg.start * ratio + t1, 3)
|
||||
end = round(char_seg.end * ratio + t1, 3)
|
||||
score = round(char_seg.score, 3)
|
||||
|
||||
char_segments_arr.append(
|
||||
{
|
||||
"char": char,
|
||||
"start": start,
|
||||
"end": end,
|
||||
"score": score,
|
||||
"word-idx": word_idx,
|
||||
}
|
||||
)
|
||||
|
||||
# increment word_idx, nltk word tokenization would probably be more robust here, but us space for now...
|
||||
if model_lang in LANGUAGES_WITHOUT_SPACES:
|
||||
word_idx += 1
|
||||
elif cdx == len(text) - 1 or text[cdx+1] == " ":
|
||||
word_idx += 1
|
||||
|
||||
char_segments_arr = pd.DataFrame(char_segments_arr)
|
||||
|
||||
aligned_subsegments = []
|
||||
# assign sentence_idx to each character index
|
||||
char_segments_arr["sentence-idx"] = None
|
||||
for sdx, (sstart, send) in enumerate(segment["sentence_spans"]):
|
||||
curr_chars = char_segments_arr.loc[(char_segments_arr.index >= sstart) & (char_segments_arr.index <= send)]
|
||||
char_segments_arr.loc[(char_segments_arr.index >= sstart) & (char_segments_arr.index <= send), "sentence-idx"] = sdx
|
||||
|
||||
sentence_text = text[sstart:send]
|
||||
sentence_start = curr_chars["start"].min()
|
||||
sentence_end = curr_chars["end"].max()
|
||||
sentence_words = []
|
||||
|
||||
for word_idx in curr_chars["word-idx"].unique():
|
||||
word_chars = curr_chars.loc[curr_chars["word-idx"] == word_idx]
|
||||
word_text = "".join(word_chars["char"].tolist()).strip()
|
||||
if len(word_text) == 0:
|
||||
continue
|
||||
word_start = word_chars["start"].min()
|
||||
word_end = word_chars["end"].max()
|
||||
word_score = round(word_chars["score"].mean(), 3)
|
||||
|
||||
# -1 indicates unalignable
|
||||
word_segment = {"word": word_text}
|
||||
|
||||
if not np.isnan(word_start):
|
||||
word_segment["start"] = word_start
|
||||
if not np.isnan(word_end):
|
||||
word_segment["end"] = word_end
|
||||
if not np.isnan(word_score):
|
||||
word_segment["score"] = word_score
|
||||
|
||||
sentence_words.append(word_segment)
|
||||
|
||||
aligned_subsegments.append({
|
||||
"text": sentence_text,
|
||||
"start": sentence_start,
|
||||
"end": sentence_end,
|
||||
"words": sentence_words,
|
||||
})
|
||||
|
||||
if return_char_alignments:
|
||||
curr_chars = curr_chars[["char", "start", "end", "score"]]
|
||||
curr_chars.fillna(-1, inplace=True)
|
||||
curr_chars = curr_chars.to_dict("records")
|
||||
curr_chars = [{key: val for key, val in char.items() if val != -1} for char in curr_chars]
|
||||
aligned_subsegments[-1]["chars"] = curr_chars
|
||||
|
||||
aligned_subsegments = pd.DataFrame(aligned_subsegments)
|
||||
aligned_subsegments["start"] = interpolate_nans(aligned_subsegments["start"], method=interpolate_method)
|
||||
aligned_subsegments["end"] = interpolate_nans(aligned_subsegments["end"], method=interpolate_method)
|
||||
# concatenate sentences with same timestamps
|
||||
agg_dict = {"text": " ".join, "words": "sum"}
|
||||
if return_char_alignments:
|
||||
agg_dict["chars"] = "sum"
|
||||
aligned_subsegments= aligned_subsegments.groupby(["start", "end"], as_index=False).agg(agg_dict)
|
||||
aligned_subsegments = aligned_subsegments.to_dict('records')
|
||||
aligned_segments += aligned_subsegments
|
||||
|
||||
# create word_segments list
|
||||
word_segments: List[SingleWordSegment] = []
|
||||
for segment in aligned_segments:
|
||||
word_segments += segment["words"]
|
||||
|
||||
return {"segments": aligned_segments, "word_segments": word_segments}
|
||||
|
||||
"""
|
||||
source: https://pytorch.org/tutorials/intermediate/forced_alignment_with_torchaudio_tutorial.html
|
||||
|
270
whisperx/asr.py
Normal file
270
whisperx/asr.py
Normal file
@ -0,0 +1,270 @@
|
||||
import os
|
||||
import warnings
|
||||
from typing import List, Union
|
||||
|
||||
import ctranslate2
|
||||
import faster_whisper
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import Pipeline
|
||||
from transformers.pipelines.pt_utils import PipelineIterator
|
||||
|
||||
from .audio import N_SAMPLES, SAMPLE_RATE, load_audio, log_mel_spectrogram
|
||||
from .vad import load_vad_model, merge_chunks
|
||||
from .types import TranscriptionResult, SingleSegment
|
||||
|
||||
def load_model(whisper_arch, device, compute_type="float16", asr_options=None, language=None,
|
||||
vad_options=None, model=None):
|
||||
'''Load a Whisper model for inference.
|
||||
Args:
|
||||
whisper_arch: str - The name of the Whisper model to load.
|
||||
device: str - The device to load the model on.
|
||||
compute_type: str - The compute type to use for the model.
|
||||
options: dict - A dictionary of options to use for the model.
|
||||
language: str - The language of the model. (use English for now)
|
||||
Returns:
|
||||
A Whisper pipeline.
|
||||
'''
|
||||
|
||||
if whisper_arch.endswith(".en"):
|
||||
language = "en"
|
||||
|
||||
model = WhisperModel(whisper_arch, device=device, compute_type=compute_type)
|
||||
if language is not None:
|
||||
tokenizer = faster_whisper.tokenizer.Tokenizer(model.hf_tokenizer, model.model.is_multilingual, task="transcribe", language=language)
|
||||
else:
|
||||
print("No language specified, language will be first be detected for each audio file (increases inference time).")
|
||||
tokenizer = None
|
||||
|
||||
default_asr_options = {
|
||||
"beam_size": 5,
|
||||
"best_of": 5,
|
||||
"patience": 1,
|
||||
"length_penalty": 1,
|
||||
"temperatures": [0.0, 0.2, 0.4, 0.6, 0.8, 1.0],
|
||||
"compression_ratio_threshold": 2.4,
|
||||
"log_prob_threshold": -1.0,
|
||||
"no_speech_threshold": 0.6,
|
||||
"condition_on_previous_text": False,
|
||||
"initial_prompt": None,
|
||||
"prefix": None,
|
||||
"suppress_blank": True,
|
||||
"suppress_tokens": [-1],
|
||||
"without_timestamps": True,
|
||||
"max_initial_timestamp": 0.0,
|
||||
"word_timestamps": False,
|
||||
"prepend_punctuations": "\"'“¿([{-",
|
||||
"append_punctuations": "\"'.。,,!!??::”)]}、"
|
||||
}
|
||||
|
||||
if asr_options is not None:
|
||||
default_asr_options.update(asr_options)
|
||||
default_asr_options = faster_whisper.transcribe.TranscriptionOptions(**default_asr_options)
|
||||
|
||||
default_vad_options = {
|
||||
"vad_onset": 0.500,
|
||||
"vad_offset": 0.363
|
||||
}
|
||||
|
||||
if vad_options is not None:
|
||||
default_vad_options.update(vad_options)
|
||||
|
||||
vad_model = load_vad_model(torch.device(device), use_auth_token=None, **default_vad_options)
|
||||
|
||||
return FasterWhisperPipeline(model, vad_model, default_asr_options, tokenizer)
|
||||
|
||||
|
||||
|
||||
class WhisperModel(faster_whisper.WhisperModel):
|
||||
'''
|
||||
FasterWhisperModel provides batched inference for faster-whisper.
|
||||
Currently only works in non-timestamp mode and fixed prompt for all samples in batch.
|
||||
'''
|
||||
|
||||
def generate_segment_batched(self, features: np.ndarray, tokenizer: faster_whisper.tokenizer.Tokenizer, options: faster_whisper.transcribe.TranscriptionOptions, encoder_output = None):
|
||||
batch_size = features.shape[0]
|
||||
all_tokens = []
|
||||
prompt_reset_since = 0
|
||||
if options.initial_prompt is not None:
|
||||
initial_prompt = " " + options.initial_prompt.strip()
|
||||
initial_prompt_tokens = tokenizer.encode(initial_prompt)
|
||||
all_tokens.extend(initial_prompt_tokens)
|
||||
previous_tokens = all_tokens[prompt_reset_since:]
|
||||
prompt = self.get_prompt(
|
||||
tokenizer,
|
||||
previous_tokens,
|
||||
without_timestamps=options.without_timestamps,
|
||||
prefix=options.prefix,
|
||||
)
|
||||
|
||||
encoder_output = self.encode(features)
|
||||
|
||||
max_initial_timestamp_index = int(
|
||||
round(options.max_initial_timestamp / self.time_precision)
|
||||
)
|
||||
|
||||
result = self.model.generate(
|
||||
encoder_output,
|
||||
[prompt] * batch_size,
|
||||
# length_penalty=options.length_penalty,
|
||||
# max_length=self.max_length,
|
||||
# return_scores=True,
|
||||
# return_no_speech_prob=True,
|
||||
# suppress_blank=options.suppress_blank,
|
||||
# suppress_tokens=options.suppress_tokens,
|
||||
# max_initial_timestamp_index=max_initial_timestamp_index,
|
||||
)
|
||||
|
||||
tokens_batch = [x.sequences_ids[0] for x in result]
|
||||
|
||||
def decode_batch(tokens: List[List[int]]) -> str:
|
||||
res = []
|
||||
for tk in tokens:
|
||||
res.append([token for token in tk if token < tokenizer.eot])
|
||||
# text_tokens = [token for token in tokens if token < self.eot]
|
||||
return tokenizer.tokenizer.decode_batch(res)
|
||||
|
||||
text = decode_batch(tokens_batch)
|
||||
|
||||
return text
|
||||
|
||||
def encode(self, features: np.ndarray) -> ctranslate2.StorageView:
|
||||
# When the model is running on multiple GPUs, the encoder output should be moved
|
||||
# to the CPU since we don't know which GPU will handle the next job.
|
||||
to_cpu = self.model.device == "cuda" and len(self.model.device_index) > 1
|
||||
# unsqueeze if batch size = 1
|
||||
if len(features.shape) == 2:
|
||||
features = np.expand_dims(features, 0)
|
||||
features = faster_whisper.transcribe.get_ctranslate2_storage(features)
|
||||
|
||||
return self.model.encode(features, to_cpu=to_cpu)
|
||||
|
||||
class FasterWhisperPipeline(Pipeline):
|
||||
"""
|
||||
Huggingface Pipeline wrapper for FasterWhisperModel.
|
||||
"""
|
||||
# TODO:
|
||||
# - add support for timestamp mode
|
||||
# - add support for custom inference kwargs
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model,
|
||||
vad,
|
||||
options,
|
||||
tokenizer=None,
|
||||
device: Union[int, str, "torch.device"] = -1,
|
||||
framework = "pt",
|
||||
**kwargs
|
||||
):
|
||||
self.model = model
|
||||
self.tokenizer = tokenizer
|
||||
self.options = options
|
||||
self._batch_size = kwargs.pop("batch_size", None)
|
||||
self._num_workers = 1
|
||||
self._preprocess_params, self._forward_params, self._postprocess_params = self._sanitize_parameters(**kwargs)
|
||||
self.call_count = 0
|
||||
self.framework = framework
|
||||
if self.framework == "pt":
|
||||
if isinstance(device, torch.device):
|
||||
self.device = device
|
||||
elif isinstance(device, str):
|
||||
self.device = torch.device(device)
|
||||
elif device < 0:
|
||||
self.device = torch.device("cpu")
|
||||
else:
|
||||
self.device = torch.device(f"cuda:{device}")
|
||||
else:
|
||||
self.device = device
|
||||
|
||||
super(Pipeline, self).__init__()
|
||||
self.vad_model = vad
|
||||
|
||||
def _sanitize_parameters(self, **kwargs):
|
||||
preprocess_kwargs = {}
|
||||
if "tokenizer" in kwargs:
|
||||
preprocess_kwargs["maybe_arg"] = kwargs["maybe_arg"]
|
||||
return preprocess_kwargs, {}, {}
|
||||
|
||||
def preprocess(self, audio):
|
||||
audio = audio['inputs']
|
||||
features = log_mel_spectrogram(audio, padding=N_SAMPLES - audio.shape[0])
|
||||
return {'inputs': features}
|
||||
|
||||
def _forward(self, model_inputs):
|
||||
outputs = self.model.generate_segment_batched(model_inputs['inputs'], self.tokenizer, self.options)
|
||||
return {'text': outputs}
|
||||
|
||||
def postprocess(self, model_outputs):
|
||||
return model_outputs
|
||||
|
||||
def get_iterator(
|
||||
self, inputs, num_workers: int, batch_size: int, preprocess_params, forward_params, postprocess_params
|
||||
):
|
||||
dataset = PipelineIterator(inputs, self.preprocess, preprocess_params)
|
||||
if "TOKENIZERS_PARALLELISM" not in os.environ:
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
# TODO hack by collating feature_extractor and image_processor
|
||||
|
||||
def stack(items):
|
||||
return {'inputs': torch.stack([x['inputs'] for x in items])}
|
||||
dataloader = torch.utils.data.DataLoader(dataset, num_workers=num_workers, batch_size=batch_size, collate_fn=stack)
|
||||
model_iterator = PipelineIterator(dataloader, self.forward, forward_params, loader_batch_size=batch_size)
|
||||
final_iterator = PipelineIterator(model_iterator, self.postprocess, postprocess_params)
|
||||
return final_iterator
|
||||
|
||||
def transcribe(
|
||||
self, audio: Union[str, np.ndarray], batch_size=None, num_workers=0
|
||||
) -> TranscriptionResult:
|
||||
if isinstance(audio, str):
|
||||
audio = load_audio(audio)
|
||||
|
||||
def data(audio, segments):
|
||||
for seg in segments:
|
||||
f1 = int(seg['start'] * SAMPLE_RATE)
|
||||
f2 = int(seg['end'] * SAMPLE_RATE)
|
||||
# print(f2-f1)
|
||||
yield {'inputs': audio[f1:f2]}
|
||||
|
||||
vad_segments = self.vad_model({"waveform": torch.from_numpy(audio).unsqueeze(0), "sample_rate": SAMPLE_RATE})
|
||||
vad_segments = merge_chunks(vad_segments, 30)
|
||||
|
||||
del_tokenizer = False
|
||||
if self.tokenizer is None:
|
||||
language = self.detect_language(audio)
|
||||
self.tokenizer = faster_whisper.tokenizer.Tokenizer(self.model.hf_tokenizer, self.model.model.is_multilingual, task="transcribe", language=language)
|
||||
del_tokenizer = True
|
||||
else:
|
||||
language = self.tokenizer.language_code
|
||||
|
||||
segments: List[SingleSegment] = []
|
||||
batch_size = batch_size or self._batch_size
|
||||
for idx, out in enumerate(self.__call__(data(audio, vad_segments), batch_size=batch_size, num_workers=num_workers)):
|
||||
text = out['text']
|
||||
if batch_size in [0, 1, None]:
|
||||
text = text[0]
|
||||
segments.append(
|
||||
{
|
||||
"text": text,
|
||||
"start": round(vad_segments[idx]['start'], 3),
|
||||
"end": round(vad_segments[idx]['end'], 3)
|
||||
}
|
||||
)
|
||||
|
||||
if del_tokenizer:
|
||||
self.tokenizer = None
|
||||
|
||||
return {"segments": segments, "language": language}
|
||||
|
||||
|
||||
def detect_language(self, audio: np.ndarray):
|
||||
if audio.shape[0] < N_SAMPLES:
|
||||
print("Warning: audio is shorter than 30s, language detection may be inaccurate.")
|
||||
segment = log_mel_spectrogram(audio[: N_SAMPLES],
|
||||
padding=0 if audio.shape[0] >= N_SAMPLES else N_SAMPLES - audio.shape[0])
|
||||
encoder_output = self.model.encode(segment)
|
||||
results = self.model.model.detect_language(encoder_output)
|
||||
language_token, language_probability = results[0][0]
|
||||
language = language_token[2:-2]
|
||||
print(f"Detected language: {language} ({language_probability:.2f}) in first 30s of audio...")
|
||||
return language
|
File diff suppressed because it is too large
Load Diff
@ -1 +0,0 @@
|
||||
{"bos_token": "<|endoftext|>", "eos_token": "<|endoftext|>", "unk_token": "<|endoftext|>"}
|
@ -1 +0,0 @@
|
||||
{"unk_token": "<|endoftext|>", "bos_token": "<|endoftext|>", "eos_token": "<|endoftext|>", "add_prefix_space": false, "model_max_length": 1024, "special_tokens_map_file": null, "name_or_path": "gpt2", "tokenizer_class": "GPT2Tokenizer"}
|
File diff suppressed because one or more lines are too long
0
whisperx/assets/mel_filters.npz
Executable file → Normal file
0
whisperx/assets/mel_filters.npz
Executable file → Normal file
@ -1 +0,0 @@
|
||||
{"<|endoftext|>": 50257}
|
File diff suppressed because it is too large
Load Diff
@ -1 +0,0 @@
|
||||
{"bos_token": "<|endoftext|>", "eos_token": "<|endoftext|>", "unk_token": "<|endoftext|>"}
|
@ -1 +0,0 @@
|
||||
{"unk_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "bos_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "eos_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "add_prefix_space": false, "model_max_length": 1024, "special_tokens_map_file": null, "name_or_path": "multilingual", "errors": "replace", "tokenizer_class": "GPT2Tokenizer"}
|
File diff suppressed because one or more lines are too long
@ -1,6 +1,6 @@
|
||||
import os
|
||||
from functools import lru_cache
|
||||
from typing import Union
|
||||
from typing import Optional, Union
|
||||
|
||||
import ffmpeg
|
||||
import numpy as np
|
||||
@ -15,8 +15,12 @@ N_FFT = 400
|
||||
N_MELS = 80
|
||||
HOP_LENGTH = 160
|
||||
CHUNK_LENGTH = 30
|
||||
N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000: number of samples in a chunk
|
||||
N_FRAMES = exact_div(N_SAMPLES, HOP_LENGTH) # 3000: number of frames in a mel spectrogram input
|
||||
N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000 samples in a 30-second chunk
|
||||
N_FRAMES = exact_div(N_SAMPLES, HOP_LENGTH) # 3000 frames in a mel spectrogram input
|
||||
|
||||
N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2 # the initial convolutions has stride 2
|
||||
FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH) # 10ms per audio frame
|
||||
TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN) # 20ms per audio token
|
||||
|
||||
|
||||
def load_audio(file: str, sr: int = SAMPLE_RATE):
|
||||
@ -55,7 +59,9 @@ def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
|
||||
"""
|
||||
if torch.is_tensor(array):
|
||||
if array.shape[axis] > length:
|
||||
array = array.index_select(dim=axis, index=torch.arange(length, device=array.device))
|
||||
array = array.index_select(
|
||||
dim=axis, index=torch.arange(length, device=array.device)
|
||||
)
|
||||
|
||||
if array.shape[axis] < length:
|
||||
pad_widths = [(0, 0)] * array.ndim
|
||||
@ -85,11 +91,18 @@ def mel_filters(device, n_mels: int = N_MELS) -> torch.Tensor:
|
||||
)
|
||||
"""
|
||||
assert n_mels == 80, f"Unsupported n_mels: {n_mels}"
|
||||
with np.load(os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz")) as f:
|
||||
with np.load(
|
||||
os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz")
|
||||
) as f:
|
||||
return torch.from_numpy(f[f"mel_{n_mels}"]).to(device)
|
||||
|
||||
|
||||
def log_mel_spectrogram(audio: Union[str, np.ndarray, torch.Tensor], n_mels: int = N_MELS):
|
||||
def log_mel_spectrogram(
|
||||
audio: Union[str, np.ndarray, torch.Tensor],
|
||||
n_mels: int = N_MELS,
|
||||
padding: int = 0,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
):
|
||||
"""
|
||||
Compute the log-Mel spectrogram of
|
||||
|
||||
@ -101,6 +114,12 @@ def log_mel_spectrogram(audio: Union[str, np.ndarray, torch.Tensor], n_mels: int
|
||||
n_mels: int
|
||||
The number of Mel-frequency filters, only 80 is supported
|
||||
|
||||
padding: int
|
||||
Number of zero samples to pad to the right
|
||||
|
||||
device: Optional[Union[str, torch.device]]
|
||||
If given, the audio tensor is moved to this device before STFT
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.Tensor, shape = (80, n_frames)
|
||||
@ -111,6 +130,10 @@ def log_mel_spectrogram(audio: Union[str, np.ndarray, torch.Tensor], n_mels: int
|
||||
audio = load_audio(audio)
|
||||
audio = torch.from_numpy(audio)
|
||||
|
||||
if device is not None:
|
||||
audio = audio.to(device)
|
||||
if padding > 0:
|
||||
audio = F.pad(audio, (0, padding))
|
||||
window = torch.hann_window(N_FFT).to(audio.device)
|
||||
stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True)
|
||||
magnitudes = stft[..., :-1].abs() ** 2
|
||||
@ -121,4 +144,4 @@ def log_mel_spectrogram(audio: Union[str, np.ndarray, torch.Tensor], n_mels: int
|
||||
log_spec = torch.clamp(mel_spec, min=1e-10).log10()
|
||||
log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
|
||||
log_spec = (log_spec + 4.0) / 4.0
|
||||
return log_spec
|
||||
return log_spec
|
||||
|
@ -1,710 +0,0 @@
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, List, Tuple, Iterable, Optional, Sequence, Union, TYPE_CHECKING
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor
|
||||
from torch.distributions import Categorical
|
||||
|
||||
from .audio import CHUNK_LENGTH
|
||||
from .tokenizer import Tokenizer, get_tokenizer
|
||||
from .utils import compression_ratio
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .model import Whisper
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def detect_language(model: "Whisper", mel: Tensor, tokenizer: Tokenizer = None) -> Tuple[Tensor, List[dict]]:
|
||||
"""
|
||||
Detect the spoken language in the audio, and return them as list of strings, along with the ids
|
||||
of the most probable language tokens and the probability distribution over all language tokens.
|
||||
This is performed outside the main decode loop in order to not interfere with kv-caching.
|
||||
|
||||
Returns
|
||||
-------
|
||||
language_tokens : Tensor, shape = (n_audio,)
|
||||
ids of the most probable language tokens, which appears after the startoftranscript token.
|
||||
language_probs : List[Dict[str, float]], length = n_audio
|
||||
list of dictionaries containing the probability distribution over all languages.
|
||||
"""
|
||||
if tokenizer is None:
|
||||
tokenizer = get_tokenizer(model.is_multilingual)
|
||||
if tokenizer.language is None or tokenizer.language_token not in tokenizer.sot_sequence:
|
||||
raise ValueError(f"This model doesn't have language tokens so it can't perform lang id")
|
||||
|
||||
single = mel.ndim == 2
|
||||
if single:
|
||||
mel = mel.unsqueeze(0)
|
||||
|
||||
# skip encoder forward pass if already-encoded audio features were given
|
||||
if mel.shape[-2:] != (model.dims.n_audio_ctx, model.dims.n_audio_state):
|
||||
mel = model.encoder(mel)
|
||||
|
||||
# forward pass using a single token, startoftranscript
|
||||
n_audio = mel.shape[0]
|
||||
x = torch.tensor([[tokenizer.sot]] * n_audio).to(mel.device) # [n_audio, 1]
|
||||
logits = model.logits(x, mel)[:, 0]
|
||||
|
||||
# collect detected languages; suppress all non-language tokens
|
||||
mask = torch.ones(logits.shape[-1], dtype=torch.bool)
|
||||
mask[list(tokenizer.all_language_tokens)] = False
|
||||
logits[:, mask] = -np.inf
|
||||
language_tokens = logits.argmax(dim=-1)
|
||||
language_token_probs = logits.softmax(dim=-1).cpu()
|
||||
language_probs = [
|
||||
{
|
||||
c: language_token_probs[i, j].item()
|
||||
for j, c in zip(tokenizer.all_language_tokens, tokenizer.all_language_codes)
|
||||
}
|
||||
for i in range(n_audio)
|
||||
]
|
||||
|
||||
if single:
|
||||
language_tokens = language_tokens[0]
|
||||
language_probs = language_probs[0]
|
||||
|
||||
return language_tokens, language_probs
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class DecodingOptions:
|
||||
task: str = "transcribe" # whether to perform X->X "transcribe" or X->English "translate"
|
||||
language: Optional[str] = None # language that the audio is in; uses detected language if None
|
||||
|
||||
# sampling-related options
|
||||
temperature: float = 0.0
|
||||
sample_len: Optional[int] = None # maximum number of tokens to sample
|
||||
best_of: Optional[int] = None # number of independent samples to collect, when t > 0
|
||||
beam_size: Optional[int] = None # number of beams in beam search, when t == 0
|
||||
patience: Optional[float] = None # patience in beam search (https://arxiv.org/abs/2204.05424)
|
||||
|
||||
# options for ranking generations (either beams or best-of-N samples)
|
||||
length_penalty: Optional[float] = None # "alpha" in Google NMT, None defaults to length norm
|
||||
|
||||
# prompt, prefix, and token suppression
|
||||
prompt: Optional[Union[str, List[int]]] = None # text or tokens for the previous context
|
||||
prefix: Optional[Union[str, List[int]]] = None # text or tokens to prefix the current context
|
||||
suppress_blank: bool = True # this will suppress blank outputs
|
||||
|
||||
# list of tokens ids (or comma-separated token ids) to suppress
|
||||
# "-1" will suppress a set of symbols as defined in `tokenizer.non_speech_tokens()`
|
||||
suppress_tokens: Optional[Union[str, Iterable[int]]] = "-1"
|
||||
|
||||
# timestamp sampling options
|
||||
without_timestamps: bool = False # use <|notimestamps|> to sample text tokens only
|
||||
max_initial_timestamp: Optional[float] = 1.0 # the initial timestamp cannot be later than this
|
||||
|
||||
# implementation details
|
||||
fp16: bool = True # use fp16 for most of the calculation
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class DecodingResult:
|
||||
audio_features: Tensor
|
||||
language: str
|
||||
language_probs: Optional[Dict[str, float]] = None
|
||||
tokens: List[int] = field(default_factory=list)
|
||||
text: str = ""
|
||||
avg_logprob: float = np.nan
|
||||
no_speech_prob: float = np.nan
|
||||
temperature: float = np.nan
|
||||
compression_ratio: float = np.nan
|
||||
|
||||
|
||||
class Inference:
|
||||
def logits(self, tokens: Tensor, audio_features: Tensor) -> Tensor:
|
||||
"""Perform a forward pass on the decoder and return per-token logits"""
|
||||
raise NotImplementedError
|
||||
|
||||
def rearrange_kv_cache(self, source_indices) -> None:
|
||||
"""Update the key-value cache according to the updated beams"""
|
||||
raise NotImplementedError
|
||||
|
||||
def cleanup_caching(self) -> None:
|
||||
"""Clean up any resources or hooks after decoding is finished"""
|
||||
pass
|
||||
|
||||
|
||||
class PyTorchInference(Inference):
|
||||
def __init__(self, model: "Whisper", initial_token_length: int):
|
||||
self.model: "Whisper" = model
|
||||
self.initial_token_length = initial_token_length
|
||||
self.kv_cache = {}
|
||||
self.hooks = []
|
||||
|
||||
def logits(self, tokens: Tensor, audio_features: Tensor) -> Tensor:
|
||||
if not self.kv_cache:
|
||||
self.kv_cache, self.hooks = self.model.install_kv_cache_hooks()
|
||||
|
||||
if tokens.shape[-1] > self.initial_token_length:
|
||||
# only need to use the last token except in the first forward pass
|
||||
tokens = tokens[:, -1:]
|
||||
|
||||
return self.model.decoder(tokens, audio_features, kv_cache=self.kv_cache)
|
||||
|
||||
def cleanup_caching(self):
|
||||
for hook in self.hooks:
|
||||
hook.remove()
|
||||
|
||||
self.kv_cache = {}
|
||||
self.hooks = []
|
||||
|
||||
def rearrange_kv_cache(self, source_indices):
|
||||
for module, tensor in self.kv_cache.items():
|
||||
# update the key/value cache to contain the selected sequences
|
||||
self.kv_cache[module] = tensor[source_indices].detach()
|
||||
|
||||
|
||||
class SequenceRanker:
|
||||
def rank(self, tokens: List[List[Tensor]], sum_logprobs: List[List[float]]) -> List[int]:
|
||||
"""
|
||||
Given a list of groups of samples and their cumulative log probabilities,
|
||||
return the indices of the samples in each group to select as the final result
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class MaximumLikelihoodRanker(SequenceRanker):
|
||||
"""
|
||||
Select the sample with the highest log probabilities, penalized using either
|
||||
a simple length normalization or Google NMT paper's length penalty
|
||||
"""
|
||||
|
||||
def __init__(self, length_penalty: Optional[float]):
|
||||
self.length_penalty = length_penalty
|
||||
|
||||
def rank(self, tokens: List[List[Tensor]], sum_logprobs: List[List[float]]):
|
||||
def scores(logprobs, lengths):
|
||||
result = []
|
||||
for logprob, length in zip(logprobs, lengths):
|
||||
if self.length_penalty is None:
|
||||
penalty = length
|
||||
else:
|
||||
# from the Google NMT paper
|
||||
penalty = ((5 + length) / 6) ** self.length_penalty
|
||||
result.append(logprob / penalty)
|
||||
return result
|
||||
|
||||
# get the sequence with the highest score
|
||||
lengths = [[len(t) for t in s] for s in tokens]
|
||||
return [np.argmax(scores(p, l)) for p, l in zip(sum_logprobs, lengths)]
|
||||
|
||||
|
||||
class TokenDecoder:
|
||||
def reset(self):
|
||||
"""Initialize any stateful variables for decoding a new sequence"""
|
||||
|
||||
def update(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor) -> Tuple[Tensor, bool]:
|
||||
"""Specify how to select the next token, based on the current trace and logits
|
||||
|
||||
Parameters
|
||||
----------
|
||||
tokens : Tensor, shape = (n_batch, current_sequence_length)
|
||||
all tokens in the context so far, including the prefix and sot_sequence tokens
|
||||
|
||||
logits : Tensor, shape = (n_batch, vocab_size)
|
||||
per-token logits of the probability distribution at the current step
|
||||
|
||||
sum_logprobs : Tensor, shape = (n_batch)
|
||||
cumulative log probabilities for each sequence
|
||||
|
||||
Returns
|
||||
-------
|
||||
tokens : Tensor, shape = (n_batch, current_sequence_length + 1)
|
||||
the tokens, appended with the selected next token
|
||||
|
||||
completed : bool
|
||||
True if all sequences has reached the end of text
|
||||
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def finalize(
|
||||
self, tokens: Tensor, sum_logprobs: Tensor
|
||||
) -> Tuple[Sequence[Sequence[Tensor]], List[List[float]]]:
|
||||
"""Finalize search and return the final candidate sequences
|
||||
|
||||
Parameters
|
||||
----------
|
||||
tokens : Tensor, shape = (n_audio, n_group, current_sequence_length)
|
||||
all tokens in the context so far, including the prefix and sot_sequence
|
||||
|
||||
sum_logprobs : Tensor, shape = (n_audio, n_group)
|
||||
cumulative log probabilities for each sequence
|
||||
|
||||
Returns
|
||||
-------
|
||||
tokens : Sequence[Sequence[Tensor]], length = n_audio
|
||||
sequence of Tensors containing candidate token sequences, for each audio input
|
||||
|
||||
sum_logprobs : List[List[float]], length = n_audio
|
||||
sequence of cumulative log probabilities corresponding to the above
|
||||
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class GreedyDecoder(TokenDecoder):
|
||||
def __init__(self, temperature: float, eot: int):
|
||||
self.temperature = temperature
|
||||
self.eot = eot
|
||||
|
||||
def update(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor) -> Tuple[Tensor, bool]:
|
||||
temperature = self.temperature
|
||||
if temperature == 0:
|
||||
next_tokens = logits.argmax(dim=-1)
|
||||
else:
|
||||
next_tokens = Categorical(logits=logits / temperature).sample()
|
||||
|
||||
logprobs = F.log_softmax(logits.float(), dim=-1)
|
||||
current_logprobs = logprobs[torch.arange(logprobs.shape[0]), next_tokens]
|
||||
sum_logprobs += current_logprobs * (tokens[:, -1] != self.eot)
|
||||
|
||||
next_tokens[tokens[:, -1] == self.eot] = self.eot
|
||||
tokens = torch.cat([tokens, next_tokens[:, None]], dim=-1)
|
||||
|
||||
completed = (tokens[:, -1] == self.eot).all()
|
||||
return tokens, completed
|
||||
|
||||
def finalize(self, tokens: Tensor, sum_logprobs: Tensor):
|
||||
# make sure each sequence has at least one EOT token at the end
|
||||
tokens = F.pad(tokens, (0, 1), value=self.eot)
|
||||
return tokens, sum_logprobs.tolist()
|
||||
|
||||
|
||||
class BeamSearchDecoder(TokenDecoder):
|
||||
def __init__(self, beam_size: int, eot: int, inference: Inference, patience: Optional[float] = None):
|
||||
self.beam_size = beam_size
|
||||
self.eot = eot
|
||||
self.inference = inference
|
||||
self.patience = patience or 1.0
|
||||
self.max_candidates: int = round(beam_size * self.patience)
|
||||
self.finished_sequences = None
|
||||
|
||||
assert self.max_candidates > 0, f"Invalid beam size ({beam_size}) or patience ({patience})"
|
||||
|
||||
def reset(self):
|
||||
self.finished_sequences = None
|
||||
|
||||
def update(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor) -> Tuple[Tensor, bool]:
|
||||
if tokens.shape[0] % self.beam_size != 0:
|
||||
raise ValueError(f"{tokens.shape}[0] % {self.beam_size} != 0")
|
||||
|
||||
n_audio = tokens.shape[0] // self.beam_size
|
||||
if self.finished_sequences is None: # for the first update
|
||||
self.finished_sequences = [{} for _ in range(n_audio)]
|
||||
|
||||
logprobs = F.log_softmax(logits.float(), dim=-1)
|
||||
next_tokens, source_indices, finished_sequences = [], [], []
|
||||
for i in range(n_audio):
|
||||
scores, sources, finished = {}, {}, {}
|
||||
|
||||
# STEP 1: calculate the cumulative log probabilities for possible candidates
|
||||
for j in range(self.beam_size):
|
||||
idx = i * self.beam_size + j
|
||||
prefix = tokens[idx].tolist()
|
||||
for logprob, token in zip(*logprobs[idx].topk(self.beam_size + 1)):
|
||||
new_logprob = (sum_logprobs[idx] + logprob).item()
|
||||
sequence = tuple(prefix + [token.item()])
|
||||
scores[sequence] = new_logprob
|
||||
sources[sequence] = idx
|
||||
|
||||
# STEP 2: rank the candidates and keep the top beam_size sequences for each audio
|
||||
saved = 0
|
||||
for sequence in sorted(scores, key=scores.get, reverse=True):
|
||||
if sequence[-1] == self.eot:
|
||||
finished[sequence] = scores[sequence]
|
||||
else:
|
||||
sum_logprobs[len(next_tokens)] = scores[sequence]
|
||||
next_tokens.append(sequence)
|
||||
source_indices.append(sources[sequence])
|
||||
|
||||
saved += 1
|
||||
if saved == self.beam_size:
|
||||
break
|
||||
|
||||
finished_sequences.append(finished)
|
||||
|
||||
tokens = torch.tensor(next_tokens, device=tokens.device)
|
||||
self.inference.rearrange_kv_cache(source_indices)
|
||||
|
||||
# add newly finished sequences to self.finished_sequences
|
||||
assert len(self.finished_sequences) == len(finished_sequences)
|
||||
for previously_finished, newly_finished in zip(self.finished_sequences, finished_sequences):
|
||||
for seq in sorted(newly_finished, key=newly_finished.get, reverse=True):
|
||||
if len(previously_finished) >= self.max_candidates:
|
||||
break # the candidate list is full
|
||||
previously_finished[seq] = newly_finished[seq]
|
||||
|
||||
# mark as completed if all audio has enough number of samples
|
||||
completed = all(
|
||||
len(sequences) >= self.max_candidates for sequences in self.finished_sequences
|
||||
)
|
||||
return tokens, completed
|
||||
|
||||
def finalize(self, preceding_tokens: Tensor, sum_logprobs: Tensor):
|
||||
# collect all finished sequences, including patience, and add unfinished ones if not enough
|
||||
sum_logprobs = sum_logprobs.cpu()
|
||||
for i, sequences in enumerate(self.finished_sequences):
|
||||
if len(sequences) < self.beam_size: # when not enough sequences are finished
|
||||
for j in list(np.argsort(sum_logprobs[i]))[::-1]:
|
||||
sequence = preceding_tokens[i, j].tolist() + [self.eot]
|
||||
sequences[tuple(sequence)] = sum_logprobs[i][j].item()
|
||||
if len(sequences) >= self.beam_size:
|
||||
break
|
||||
|
||||
tokens: List[List[Tensor]] = [
|
||||
[torch.tensor(seq) for seq in sequences.keys()] for sequences in self.finished_sequences
|
||||
]
|
||||
sum_logprobs: List[List[float]] = [
|
||||
list(sequences.values()) for sequences in self.finished_sequences
|
||||
]
|
||||
return tokens, sum_logprobs
|
||||
|
||||
|
||||
class LogitFilter:
|
||||
def apply(self, logits: Tensor, tokens: Tensor) -> None:
|
||||
"""Apply any filtering or masking to logits in-place
|
||||
|
||||
Parameters
|
||||
----------
|
||||
logits : Tensor, shape = (n_batch, vocab_size)
|
||||
per-token logits of the probability distribution at the current step
|
||||
|
||||
tokens : Tensor, shape = (n_batch, current_sequence_length)
|
||||
all tokens in the context so far, including the prefix and sot_sequence tokens
|
||||
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class SuppressBlank(LogitFilter):
|
||||
def __init__(self, tokenizer: Tokenizer, sample_begin: int):
|
||||
self.tokenizer = tokenizer
|
||||
self.sample_begin = sample_begin
|
||||
|
||||
def apply(self, logits: Tensor, tokens: Tensor):
|
||||
if tokens.shape[1] == self.sample_begin:
|
||||
logits[:, self.tokenizer.encode(" ") + [self.tokenizer.eot]] = -np.inf
|
||||
|
||||
|
||||
class SuppressTokens(LogitFilter):
|
||||
def __init__(self, suppress_tokens: Sequence[int]):
|
||||
self.suppress_tokens = list(suppress_tokens)
|
||||
|
||||
def apply(self, logits: Tensor, tokens: Tensor):
|
||||
logits[:, self.suppress_tokens] = -np.inf
|
||||
|
||||
|
||||
class ApplyTimestampRules(LogitFilter):
|
||||
def __init__(
|
||||
self, tokenizer: Tokenizer, sample_begin: int, max_initial_timestamp_index: Optional[int]
|
||||
):
|
||||
self.tokenizer = tokenizer
|
||||
self.sample_begin = sample_begin
|
||||
self.max_initial_timestamp_index = max_initial_timestamp_index
|
||||
|
||||
def apply(self, logits: Tensor, tokens: Tensor):
|
||||
# suppress <|notimestamps|> which is handled by without_timestamps
|
||||
if self.tokenizer.no_timestamps is not None:
|
||||
logits[:, self.tokenizer.no_timestamps] = -np.inf
|
||||
|
||||
# timestamps have to appear in pairs, except directly before EOT; mask logits accordingly
|
||||
for k in range(tokens.shape[0]):
|
||||
seq = [t for t in tokens[k, self.sample_begin :].tolist()]
|
||||
last_was_timestamp = len(seq) >= 1 and seq[-1] >= self.tokenizer.timestamp_begin
|
||||
penultimate_was_timestamp = len(seq) < 2 or seq[-2] >= self.tokenizer.timestamp_begin
|
||||
|
||||
if last_was_timestamp:
|
||||
if penultimate_was_timestamp: # has to be non-timestamp
|
||||
logits[k, self.tokenizer.timestamp_begin :] = -np.inf
|
||||
else: # cannot be normal text tokens
|
||||
logits[k, : self.tokenizer.eot] = -np.inf
|
||||
|
||||
if tokens.shape[1] == self.sample_begin:
|
||||
# suppress generating non-timestamp tokens at the beginning
|
||||
logits[:, : self.tokenizer.timestamp_begin] = -np.inf
|
||||
|
||||
# apply the `max_initial_timestamp` option
|
||||
if self.max_initial_timestamp_index is not None:
|
||||
last_allowed = self.tokenizer.timestamp_begin + self.max_initial_timestamp_index
|
||||
logits[:, last_allowed + 1 :] = -np.inf
|
||||
|
||||
# if sum of probability over timestamps is above any other token, sample timestamp
|
||||
logprobs = F.log_softmax(logits.float(), dim=-1)
|
||||
for k in range(tokens.shape[0]):
|
||||
timestamp_logprob = logprobs[k, self.tokenizer.timestamp_begin :].logsumexp(dim=-1)
|
||||
max_text_token_logprob = logprobs[k, : self.tokenizer.timestamp_begin].max()
|
||||
if timestamp_logprob > max_text_token_logprob:
|
||||
logits[k, : self.tokenizer.timestamp_begin] = -np.inf
|
||||
|
||||
|
||||
class DecodingTask:
|
||||
inference: Inference
|
||||
sequence_ranker: SequenceRanker
|
||||
decoder: TokenDecoder
|
||||
logit_filters: List[LogitFilter]
|
||||
|
||||
def __init__(self, model: "Whisper", options: DecodingOptions):
|
||||
self.model = model
|
||||
|
||||
language = options.language or "en"
|
||||
tokenizer = get_tokenizer(model.is_multilingual, language=language, task=options.task)
|
||||
self.tokenizer: Tokenizer = tokenizer
|
||||
self.options: DecodingOptions = self._verify_options(options)
|
||||
|
||||
self.n_group: int = options.beam_size or options.best_of or 1
|
||||
self.n_ctx: int = model.dims.n_text_ctx
|
||||
self.sample_len: int = options.sample_len or model.dims.n_text_ctx // 2
|
||||
|
||||
self.sot_sequence: Tuple[int] = tokenizer.sot_sequence
|
||||
if self.options.without_timestamps:
|
||||
self.sot_sequence = tokenizer.sot_sequence_including_notimestamps
|
||||
|
||||
self.initial_tokens: Tuple[int] = self._get_initial_tokens()
|
||||
self.sample_begin: int = len(self.initial_tokens)
|
||||
self.sot_index: int = self.initial_tokens.index(tokenizer.sot)
|
||||
|
||||
# inference: implements the forward pass through the decoder, including kv caching
|
||||
self.inference = PyTorchInference(model, len(self.initial_tokens))
|
||||
|
||||
# sequence ranker: implements how to rank a group of sampled sequences
|
||||
self.sequence_ranker = MaximumLikelihoodRanker(options.length_penalty)
|
||||
|
||||
# decoder: implements how to select the next tokens, given the autoregressive distribution
|
||||
if options.beam_size is not None:
|
||||
self.decoder = BeamSearchDecoder(
|
||||
options.beam_size, tokenizer.eot, self.inference, options.patience
|
||||
)
|
||||
else:
|
||||
self.decoder = GreedyDecoder(options.temperature, tokenizer.eot)
|
||||
|
||||
# logit filters: applies various rules to suppress or penalize certain tokens
|
||||
self.logit_filters = []
|
||||
if self.options.suppress_blank:
|
||||
self.logit_filters.append(SuppressBlank(self.tokenizer, self.sample_begin))
|
||||
if self.options.suppress_tokens:
|
||||
self.logit_filters.append(SuppressTokens(self._get_suppress_tokens()))
|
||||
if not options.without_timestamps:
|
||||
precision = CHUNK_LENGTH / model.dims.n_audio_ctx # usually 0.02 seconds
|
||||
max_initial_timestamp_index = None
|
||||
if options.max_initial_timestamp:
|
||||
max_initial_timestamp_index = round(self.options.max_initial_timestamp / precision)
|
||||
self.logit_filters.append(
|
||||
ApplyTimestampRules(tokenizer, self.sample_begin, max_initial_timestamp_index)
|
||||
)
|
||||
|
||||
def _verify_options(self, options: DecodingOptions) -> DecodingOptions:
|
||||
if options.beam_size is not None and options.best_of is not None:
|
||||
raise ValueError("beam_size and best_of can't be given together")
|
||||
if options.temperature == 0:
|
||||
if options.best_of is not None:
|
||||
raise ValueError("best_of with greedy sampling (T=0) is not compatible")
|
||||
if options.patience is not None and options.beam_size is None:
|
||||
raise ValueError("patience requires beam_size to be given")
|
||||
if options.length_penalty is not None and not (0 <= options.length_penalty <= 1):
|
||||
raise ValueError("length_penalty (alpha) should be a value between 0 and 1")
|
||||
|
||||
return options
|
||||
|
||||
def _get_initial_tokens(self) -> Tuple[int]:
|
||||
tokens = list(self.sot_sequence)
|
||||
prefix = self.options.prefix
|
||||
prompt = self.options.prompt
|
||||
|
||||
if prefix:
|
||||
prefix_tokens = (
|
||||
self.tokenizer.encode(" " + prefix.strip()) if isinstance(prefix, str) else prefix
|
||||
)
|
||||
if self.sample_len is not None:
|
||||
max_prefix_len = self.n_ctx // 2 - self.sample_len
|
||||
prefix_tokens = prefix_tokens[-max_prefix_len:]
|
||||
tokens = tokens + prefix_tokens
|
||||
|
||||
if prompt:
|
||||
prompt_tokens = (
|
||||
self.tokenizer.encode(" " + prompt.strip()) if isinstance(prompt, str) else prompt
|
||||
)
|
||||
tokens = [self.tokenizer.sot_prev] + prompt_tokens[-(self.n_ctx // 2 - 1) :] + tokens
|
||||
|
||||
return tuple(tokens)
|
||||
|
||||
def _get_suppress_tokens(self) -> Tuple[int]:
|
||||
suppress_tokens = self.options.suppress_tokens
|
||||
|
||||
if isinstance(suppress_tokens, str):
|
||||
suppress_tokens = [int(t) for t in suppress_tokens.split(",")]
|
||||
|
||||
if -1 in suppress_tokens:
|
||||
suppress_tokens = [t for t in suppress_tokens if t >= 0]
|
||||
suppress_tokens.extend(self.tokenizer.non_speech_tokens)
|
||||
elif suppress_tokens is None or len(suppress_tokens) == 0:
|
||||
suppress_tokens = [] # interpret empty string as an empty list
|
||||
else:
|
||||
assert isinstance(suppress_tokens, list), "suppress_tokens must be a list"
|
||||
|
||||
suppress_tokens.extend(
|
||||
[self.tokenizer.sot, self.tokenizer.sot_prev, self.tokenizer.sot_lm]
|
||||
)
|
||||
if self.tokenizer.no_speech is not None:
|
||||
# no-speech probability is collected separately
|
||||
suppress_tokens.append(self.tokenizer.no_speech)
|
||||
|
||||
return tuple(sorted(set(suppress_tokens)))
|
||||
|
||||
def _get_audio_features(self, mel: Tensor):
|
||||
if self.options.fp16:
|
||||
mel = mel.half()
|
||||
|
||||
if mel.shape[-2:] == (self.model.dims.n_audio_ctx, self.model.dims.n_audio_state):
|
||||
# encoded audio features are given; skip audio encoding
|
||||
audio_features = mel
|
||||
else:
|
||||
audio_features = self.model.encoder(mel)
|
||||
|
||||
if audio_features.dtype != (torch.float16 if self.options.fp16 else torch.float32):
|
||||
return TypeError(f"audio_features has an incorrect dtype: {audio_features.dtype}")
|
||||
|
||||
return audio_features
|
||||
|
||||
def _detect_language(self, audio_features: Tensor, tokens: Tensor):
|
||||
languages = [self.options.language] * audio_features.shape[0]
|
||||
lang_probs = None
|
||||
|
||||
if self.options.language is None or self.options.task == "lang_id":
|
||||
lang_tokens, lang_probs = self.model.detect_language(audio_features, self.tokenizer)
|
||||
languages = [max(probs, key=probs.get) for probs in lang_probs]
|
||||
if self.options.language is None:
|
||||
tokens[:, self.sot_index + 1] = lang_tokens # write language tokens
|
||||
|
||||
return languages, lang_probs
|
||||
|
||||
def _main_loop(self, audio_features: Tensor, tokens: Tensor):
|
||||
assert audio_features.shape[0] == tokens.shape[0]
|
||||
n_batch = tokens.shape[0]
|
||||
sum_logprobs: Tensor = torch.zeros(n_batch, device=audio_features.device)
|
||||
no_speech_probs = [np.nan] * n_batch
|
||||
|
||||
try:
|
||||
for i in range(self.sample_len):
|
||||
logits = self.inference.logits(tokens, audio_features)
|
||||
|
||||
if i == 0 and self.tokenizer.no_speech is not None: # save no_speech_probs
|
||||
probs_at_sot = logits[:, self.sot_index].float().softmax(dim=-1)
|
||||
no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist()
|
||||
|
||||
# now we need to consider the logits at the last token only
|
||||
logits = logits[:, -1]
|
||||
|
||||
# apply the logit filters, e.g. for suppressing or applying penalty to
|
||||
for logit_filter in self.logit_filters:
|
||||
logit_filter.apply(logits, tokens)
|
||||
|
||||
# expand the tokens tensor with the selected next tokens
|
||||
tokens, completed = self.decoder.update(tokens, logits, sum_logprobs)
|
||||
|
||||
if completed or tokens.shape[-1] > self.n_ctx:
|
||||
break
|
||||
finally:
|
||||
self.inference.cleanup_caching()
|
||||
|
||||
return tokens, sum_logprobs, no_speech_probs
|
||||
|
||||
@torch.no_grad()
|
||||
def run(self, mel: Tensor) -> List[DecodingResult]:
|
||||
self.decoder.reset()
|
||||
tokenizer: Tokenizer = self.tokenizer
|
||||
n_audio: int = mel.shape[0]
|
||||
|
||||
audio_features: Tensor = self._get_audio_features(mel) # encoder forward pass
|
||||
tokens: Tensor = torch.tensor([self.initial_tokens]).repeat(n_audio, 1)
|
||||
|
||||
# detect language if requested, overwriting the language token
|
||||
languages, language_probs = self._detect_language(audio_features, tokens)
|
||||
if self.options.task == "lang_id":
|
||||
return [
|
||||
DecodingResult(audio_features=features, language=language, language_probs=probs)
|
||||
for features, language, probs in zip(audio_features, languages, language_probs)
|
||||
]
|
||||
|
||||
# repeat the audio & text tensors by the group size, for beam search or best-of-n sampling
|
||||
audio_features = audio_features.repeat_interleave(self.n_group, dim=0)
|
||||
tokens = tokens.repeat_interleave(self.n_group, dim=0).to(audio_features.device)
|
||||
|
||||
# call the main sampling loop
|
||||
tokens, sum_logprobs, no_speech_probs = self._main_loop(audio_features, tokens)
|
||||
|
||||
# reshape the tensors to have (n_audio, n_group) as the first two dimensions
|
||||
audio_features = audio_features[:: self.n_group]
|
||||
no_speech_probs = no_speech_probs[:: self.n_group]
|
||||
assert audio_features.shape[0] == len(no_speech_probs) == n_audio
|
||||
|
||||
tokens = tokens.reshape(n_audio, self.n_group, -1)
|
||||
sum_logprobs = sum_logprobs.reshape(n_audio, self.n_group)
|
||||
|
||||
# get the final candidates for each group, and slice between the first sampled token and EOT
|
||||
tokens, sum_logprobs = self.decoder.finalize(tokens, sum_logprobs)
|
||||
tokens: List[List[Tensor]] = [
|
||||
[t[self.sample_begin : (t == tokenizer.eot).nonzero()[0, 0]] for t in s] for s in tokens
|
||||
]
|
||||
|
||||
# select the top-ranked sample in each group
|
||||
selected = self.sequence_ranker.rank(tokens, sum_logprobs)
|
||||
tokens: List[List[int]] = [t[i].tolist() for i, t in zip(selected, tokens)]
|
||||
texts: List[str] = [tokenizer.decode(t).strip() for t in tokens]
|
||||
|
||||
sum_logprobs: List[float] = [lp[i] for i, lp in zip(selected, sum_logprobs)]
|
||||
avg_logprobs: List[float] = [lp / (len(t) + 1) for t, lp in zip(tokens, sum_logprobs)]
|
||||
|
||||
fields = (texts, languages, tokens, audio_features, avg_logprobs, no_speech_probs)
|
||||
if len(set(map(len, fields))) != 1:
|
||||
raise RuntimeError(f"inconsistent result lengths: {list(map(len, fields))}")
|
||||
|
||||
return [
|
||||
DecodingResult(
|
||||
audio_features=features,
|
||||
language=language,
|
||||
tokens=tokens,
|
||||
text=text,
|
||||
avg_logprob=avg_logprob,
|
||||
no_speech_prob=no_speech_prob,
|
||||
temperature=self.options.temperature,
|
||||
compression_ratio=compression_ratio(text),
|
||||
)
|
||||
for text, language, tokens, features, avg_logprob, no_speech_prob in zip(*fields)
|
||||
]
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def decode(model: "Whisper", mel: Tensor, options: DecodingOptions = DecodingOptions()) -> Union[DecodingResult, List[DecodingResult]]:
|
||||
"""
|
||||
Performs decoding of 30-second audio segment(s), provided as Mel spectrogram(s).
|
||||
|
||||
Parameters
|
||||
----------
|
||||
model: Whisper
|
||||
the Whisper model instance
|
||||
|
||||
mel: torch.Tensor, shape = (80, 3000) or (*, 80, 3000)
|
||||
A tensor containing the Mel spectrogram(s)
|
||||
|
||||
options: DecodingOptions
|
||||
A dataclass that contains all necessary options for decoding 30-second segments
|
||||
|
||||
Returns
|
||||
-------
|
||||
result: Union[DecodingResult, List[DecodingResult]]
|
||||
The result(s) of decoding contained in `DecodingResult` dataclass instance(s)
|
||||
"""
|
||||
single = mel.ndim == 2
|
||||
if single:
|
||||
mel = mel.unsqueeze(0)
|
||||
|
||||
result = DecodingTask(model, options).run(mel)
|
||||
|
||||
if single:
|
||||
result = result[0]
|
||||
|
||||
return result
|
@ -1,53 +1,63 @@
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from pyannote.audio import Pipeline
|
||||
from typing import Optional, Union
|
||||
import torch
|
||||
|
||||
def assign_word_speakers(diarize_df, result_segments, fill_nearest=False):
|
||||
|
||||
for seg in result_segments:
|
||||
wdf = seg['word-segments']
|
||||
if len(wdf['start'].dropna()) == 0:
|
||||
wdf['start'] = seg['start']
|
||||
wdf['end'] = seg['end']
|
||||
speakers = []
|
||||
for wdx, wrow in wdf.iterrows():
|
||||
if not np.isnan(wrow['start']):
|
||||
diarize_df['intersection'] = np.minimum(diarize_df['end'], wrow['end']) - np.maximum(diarize_df['start'], wrow['start'])
|
||||
diarize_df['union'] = np.maximum(diarize_df['end'], wrow['end']) - np.minimum(diarize_df['start'], wrow['start'])
|
||||
# remove no hit
|
||||
if not fill_nearest:
|
||||
dia_tmp = diarize_df[diarize_df['intersection'] > 0]
|
||||
else:
|
||||
dia_tmp = diarize_df
|
||||
if len(dia_tmp) == 0:
|
||||
speaker = None
|
||||
else:
|
||||
speaker = dia_tmp.sort_values("intersection", ascending=False).iloc[0][2]
|
||||
else:
|
||||
speaker = None
|
||||
speakers.append(speaker)
|
||||
seg['word-segments']['speaker'] = speakers
|
||||
seg["speaker"] = pd.Series(speakers).value_counts().index[0]
|
||||
class DiarizationPipeline:
|
||||
def __init__(
|
||||
self,
|
||||
model_name="pyannote/speaker-diarization@2.1",
|
||||
use_auth_token=None,
|
||||
device: Optional[Union[str, torch.device]] = "cpu",
|
||||
):
|
||||
if isinstance(device, str):
|
||||
device = torch.device(device)
|
||||
self.model = Pipeline.from_pretrained(model_name, use_auth_token=use_auth_token).to(device)
|
||||
|
||||
# create word level segments for .srt
|
||||
word_seg = []
|
||||
for seg in result_segments:
|
||||
wseg = pd.DataFrame(seg["word-segments"])
|
||||
for wdx, wrow in wseg.iterrows():
|
||||
if wrow["start"] is not None:
|
||||
speaker = wrow['speaker']
|
||||
if speaker is None or speaker == np.nan:
|
||||
speaker = "UNKNOWN"
|
||||
word_seg.append(
|
||||
{
|
||||
"start": wrow["start"],
|
||||
"end": wrow["end"],
|
||||
"text": f"[{speaker}]: " + seg["text"][int(wrow["segment-text-start"]):int(wrow["segment-text-end"])]
|
||||
}
|
||||
)
|
||||
def __call__(self, audio, min_speakers=None, max_speakers=None):
|
||||
segments = self.model(audio, min_speakers=min_speakers, max_speakers=max_speakers)
|
||||
diarize_df = pd.DataFrame(segments.itertracks(yield_label=True))
|
||||
diarize_df['start'] = diarize_df[0].apply(lambda x: x.start)
|
||||
diarize_df['end'] = diarize_df[0].apply(lambda x: x.end)
|
||||
diarize_df.rename(columns={2: "speaker"}, inplace=True)
|
||||
return diarize_df
|
||||
|
||||
# TODO: create segments but split words on new speaker
|
||||
|
||||
return result_segments, word_seg
|
||||
def assign_word_speakers(diarize_df, transcript_result, fill_nearest=False):
|
||||
transcript_segments = transcript_result["segments"]
|
||||
for seg in transcript_segments:
|
||||
# assign speaker to segment (if any)
|
||||
diarize_df['intersection'] = np.minimum(diarize_df['end'], seg['end']) - np.maximum(diarize_df['start'], seg['start'])
|
||||
diarize_df['union'] = np.maximum(diarize_df['end'], seg['end']) - np.minimum(diarize_df['start'], seg['start'])
|
||||
# remove no hit, otherwise we look for closest (even negative intersection...)
|
||||
if not fill_nearest:
|
||||
dia_tmp = diarize_df[diarize_df['intersection'] > 0]
|
||||
else:
|
||||
dia_tmp = diarize_df
|
||||
if len(dia_tmp) > 0:
|
||||
# sum over speakers
|
||||
speaker = dia_tmp.groupby("speaker")["intersection"].sum().sort_values(ascending=False).index[0]
|
||||
seg["speaker"] = speaker
|
||||
|
||||
# assign speaker to words
|
||||
if 'words' in seg:
|
||||
for word in seg['words']:
|
||||
if 'start' in word:
|
||||
diarize_df['intersection'] = np.minimum(diarize_df['end'], word['end']) - np.maximum(diarize_df['start'], word['start'])
|
||||
diarize_df['union'] = np.maximum(diarize_df['end'], word['end']) - np.minimum(diarize_df['start'], word['start'])
|
||||
# remove no hit
|
||||
if not fill_nearest:
|
||||
dia_tmp = diarize_df[diarize_df['intersection'] > 0]
|
||||
else:
|
||||
dia_tmp = diarize_df
|
||||
if len(dia_tmp) > 0:
|
||||
# sum over speakers
|
||||
speaker = dia_tmp.groupby("speaker")["intersection"].sum().sort_values(ascending=False).index[0]
|
||||
word["speaker"] = speaker
|
||||
|
||||
return transcript_result
|
||||
|
||||
|
||||
class Segment:
|
||||
def __init__(self, start, end, speaker=None):
|
||||
|
@ -1,268 +0,0 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict
|
||||
from typing import Iterable, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor
|
||||
from torch import nn
|
||||
|
||||
from .transcribe import transcribe as transcribe_function
|
||||
from .decoding import detect_language as detect_language_function, decode as decode_function
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelDimensions:
|
||||
n_mels: int
|
||||
n_audio_ctx: int
|
||||
n_audio_state: int
|
||||
n_audio_head: int
|
||||
n_audio_layer: int
|
||||
n_vocab: int
|
||||
n_text_ctx: int
|
||||
n_text_state: int
|
||||
n_text_head: int
|
||||
n_text_layer: int
|
||||
|
||||
|
||||
class LayerNorm(nn.LayerNorm):
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
return super().forward(x.float()).type(x.dtype)
|
||||
|
||||
|
||||
class Linear(nn.Linear):
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
return F.linear(
|
||||
x, self.weight.to(x.dtype), None if self.bias is None else self.bias.to(x.dtype)
|
||||
)
|
||||
|
||||
|
||||
class Conv1d(nn.Conv1d):
|
||||
def _conv_forward(self, x: Tensor, weight: Tensor, bias: Optional[Tensor]) -> Tensor:
|
||||
return super()._conv_forward(
|
||||
x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype)
|
||||
)
|
||||
|
||||
|
||||
def sinusoids(length, channels, max_timescale=10000):
|
||||
"""Returns sinusoids for positional embedding"""
|
||||
assert channels % 2 == 0
|
||||
log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
|
||||
inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
|
||||
scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
|
||||
return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
|
||||
|
||||
|
||||
class MultiHeadAttention(nn.Module):
|
||||
def __init__(self, n_state: int, n_head: int):
|
||||
super().__init__()
|
||||
self.n_head = n_head
|
||||
self.query = Linear(n_state, n_state)
|
||||
self.key = Linear(n_state, n_state, bias=False)
|
||||
self.value = Linear(n_state, n_state)
|
||||
self.out = Linear(n_state, n_state)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: Tensor,
|
||||
xa: Optional[Tensor] = None,
|
||||
mask: Optional[Tensor] = None,
|
||||
kv_cache: Optional[dict] = None,
|
||||
):
|
||||
q = self.query(x)
|
||||
|
||||
if kv_cache is None or xa is None or self.key not in kv_cache:
|
||||
# hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors;
|
||||
# otherwise, perform key/value projections for self- or cross-attention as usual.
|
||||
k = self.key(x if xa is None else xa)
|
||||
v = self.value(x if xa is None else xa)
|
||||
else:
|
||||
# for cross-attention, calculate keys and values once and reuse in subsequent calls.
|
||||
k = kv_cache[self.key]
|
||||
v = kv_cache[self.value]
|
||||
|
||||
wv, qk = self.qkv_attention(q, k, v, mask)
|
||||
return self.out(wv), qk
|
||||
|
||||
def qkv_attention(self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None):
|
||||
n_batch, n_ctx, n_state = q.shape
|
||||
scale = (n_state // self.n_head) ** -0.25
|
||||
q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale
|
||||
k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scale
|
||||
v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
|
||||
|
||||
qk = q @ k
|
||||
if mask is not None:
|
||||
qk = qk + mask[:n_ctx, :n_ctx]
|
||||
qk = qk.float()
|
||||
|
||||
w = F.softmax(qk, dim=-1).to(q.dtype)
|
||||
return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach()
|
||||
|
||||
|
||||
class ResidualAttentionBlock(nn.Module):
|
||||
def __init__(self, n_state: int, n_head: int, cross_attention: bool = False):
|
||||
super().__init__()
|
||||
|
||||
self.attn = MultiHeadAttention(n_state, n_head)
|
||||
self.attn_ln = LayerNorm(n_state)
|
||||
|
||||
self.cross_attn = MultiHeadAttention(n_state, n_head) if cross_attention else None
|
||||
self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None
|
||||
|
||||
n_mlp = n_state * 4
|
||||
self.mlp = nn.Sequential(Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state))
|
||||
self.mlp_ln = LayerNorm(n_state)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: Tensor,
|
||||
xa: Optional[Tensor] = None,
|
||||
mask: Optional[Tensor] = None,
|
||||
kv_cache: Optional[dict] = None,
|
||||
):
|
||||
x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)[0]
|
||||
if self.cross_attn:
|
||||
x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)[0]
|
||||
x = x + self.mlp(self.mlp_ln(x))
|
||||
return x
|
||||
|
||||
|
||||
class AudioEncoder(nn.Module):
|
||||
def __init__(self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int):
|
||||
super().__init__()
|
||||
self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1)
|
||||
self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
|
||||
self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))
|
||||
|
||||
self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
|
||||
[ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)]
|
||||
)
|
||||
self.ln_post = LayerNorm(n_state)
|
||||
|
||||
def forward(self, x: Tensor):
|
||||
"""
|
||||
x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
|
||||
the mel spectrogram of the audio
|
||||
"""
|
||||
x = F.gelu(self.conv1(x))
|
||||
x = F.gelu(self.conv2(x))
|
||||
x = x.permute(0, 2, 1)
|
||||
|
||||
assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape"
|
||||
x = (x + self.positional_embedding).to(x.dtype)
|
||||
|
||||
for block in self.blocks:
|
||||
x = block(x)
|
||||
|
||||
x = self.ln_post(x)
|
||||
return x
|
||||
|
||||
|
||||
class TextDecoder(nn.Module):
|
||||
def __init__(self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int):
|
||||
super().__init__()
|
||||
|
||||
self.token_embedding = nn.Embedding(n_vocab, n_state)
|
||||
self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state))
|
||||
|
||||
self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
|
||||
[ResidualAttentionBlock(n_state, n_head, cross_attention=True) for _ in range(n_layer)]
|
||||
)
|
||||
self.ln = LayerNorm(n_state)
|
||||
|
||||
mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)
|
||||
self.register_buffer("mask", mask, persistent=False)
|
||||
|
||||
def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
|
||||
"""
|
||||
x : torch.LongTensor, shape = (batch_size, <= n_ctx)
|
||||
the text tokens
|
||||
xa : torch.Tensor, shape = (batch_size, n_mels, n_audio_ctx)
|
||||
the encoded audio features to be attended on
|
||||
"""
|
||||
offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
|
||||
x = self.token_embedding(x) + self.positional_embedding[offset : offset + x.shape[-1]]
|
||||
x = x.to(xa.dtype)
|
||||
|
||||
for block in self.blocks:
|
||||
x = block(x, xa, mask=self.mask, kv_cache=kv_cache)
|
||||
|
||||
x = self.ln(x)
|
||||
logits = (x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)).float()
|
||||
|
||||
return logits
|
||||
|
||||
|
||||
class Whisper(nn.Module):
|
||||
def __init__(self, dims: ModelDimensions):
|
||||
super().__init__()
|
||||
self.dims = dims
|
||||
self.encoder = AudioEncoder(
|
||||
self.dims.n_mels,
|
||||
self.dims.n_audio_ctx,
|
||||
self.dims.n_audio_state,
|
||||
self.dims.n_audio_head,
|
||||
self.dims.n_audio_layer,
|
||||
)
|
||||
self.decoder = TextDecoder(
|
||||
self.dims.n_vocab,
|
||||
self.dims.n_text_ctx,
|
||||
self.dims.n_text_state,
|
||||
self.dims.n_text_head,
|
||||
self.dims.n_text_layer,
|
||||
)
|
||||
|
||||
def embed_audio(self, mel: torch.Tensor):
|
||||
return self.encoder(mel)
|
||||
|
||||
def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor):
|
||||
return self.decoder(tokens, audio_features)
|
||||
|
||||
def forward(self, mel: torch.Tensor, tokens: torch.Tensor) -> Dict[str, torch.Tensor]:
|
||||
return self.decoder(tokens, self.encoder(mel))
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return next(self.parameters()).device
|
||||
|
||||
@property
|
||||
def is_multilingual(self):
|
||||
return self.dims.n_vocab == 51865
|
||||
|
||||
def install_kv_cache_hooks(self, cache: Optional[dict] = None):
|
||||
"""
|
||||
The `MultiHeadAttention` module optionally accepts `kv_cache` which stores the key and value
|
||||
tensors calculated for the previous positions. This method returns a dictionary that stores
|
||||
all caches, and the necessary hooks for the key and value projection modules that save the
|
||||
intermediate tensors to be reused during later calculations.
|
||||
|
||||
Returns
|
||||
-------
|
||||
cache : Dict[nn.Module, torch.Tensor]
|
||||
A dictionary object mapping the key/value projection modules to its cache
|
||||
hooks : List[RemovableHandle]
|
||||
List of PyTorch RemovableHandle objects to stop the hooks to be called
|
||||
"""
|
||||
cache = {**cache} if cache is not None else {}
|
||||
hooks = []
|
||||
|
||||
def save_to_cache(module, _, output):
|
||||
if module not in cache or output.shape[1] > self.decoder.positional_embedding.shape[0]:
|
||||
cache[module] = output # save as-is, for the first token or cross attention
|
||||
else:
|
||||
cache[module] = torch.cat([cache[module], output], dim=1).detach()
|
||||
return cache[module]
|
||||
|
||||
def install_hooks(layer: nn.Module):
|
||||
if isinstance(layer, MultiHeadAttention):
|
||||
hooks.append(layer.key.register_forward_hook(save_to_cache))
|
||||
hooks.append(layer.value.register_forward_hook(save_to_cache))
|
||||
|
||||
self.decoder.apply(install_hooks)
|
||||
return cache, hooks
|
||||
|
||||
detect_language = detect_language_function
|
||||
transcribe = transcribe_function
|
||||
decode = decode_function
|
@ -1,2 +0,0 @@
|
||||
from .basic import BasicTextNormalizer
|
||||
from .english import EnglishTextNormalizer
|
@ -1,71 +0,0 @@
|
||||
import re
|
||||
import unicodedata
|
||||
|
||||
import regex
|
||||
|
||||
# non-ASCII letters that are not separated by "NFKD" normalization
|
||||
ADDITIONAL_DIACRITICS = {
|
||||
"œ": "oe",
|
||||
"Œ": "OE",
|
||||
"ø": "o",
|
||||
"Ø": "O",
|
||||
"æ": "ae",
|
||||
"Æ": "AE",
|
||||
"ß": "ss",
|
||||
"ẞ": "SS",
|
||||
"đ": "d",
|
||||
"Đ": "D",
|
||||
"ð": "d",
|
||||
"Ð": "D",
|
||||
"þ": "th",
|
||||
"Þ": "th",
|
||||
"ł": "l",
|
||||
"Ł": "L",
|
||||
}
|
||||
|
||||
|
||||
def remove_symbols_and_diacritics(s: str, keep=""):
|
||||
"""
|
||||
Replace any other markers, symbols, and punctuations with a space,
|
||||
and drop any diacritics (category 'Mn' and some manual mappings)
|
||||
"""
|
||||
return "".join(
|
||||
c
|
||||
if c in keep
|
||||
else ADDITIONAL_DIACRITICS[c]
|
||||
if c in ADDITIONAL_DIACRITICS
|
||||
else ""
|
||||
if unicodedata.category(c) == "Mn"
|
||||
else " "
|
||||
if unicodedata.category(c)[0] in "MSP"
|
||||
else c
|
||||
for c in unicodedata.normalize("NFKD", s)
|
||||
)
|
||||
|
||||
|
||||
def remove_symbols(s: str):
|
||||
"""
|
||||
Replace any other markers, symbols, punctuations with a space, keeping diacritics
|
||||
"""
|
||||
return "".join(
|
||||
" " if unicodedata.category(c)[0] in "MSP" else c for c in unicodedata.normalize("NFKC", s)
|
||||
)
|
||||
|
||||
|
||||
class BasicTextNormalizer:
|
||||
def __init__(self, remove_diacritics: bool = False, split_letters: bool = False):
|
||||
self.clean = remove_symbols_and_diacritics if remove_diacritics else remove_symbols
|
||||
self.split_letters = split_letters
|
||||
|
||||
def __call__(self, s: str):
|
||||
s = s.lower()
|
||||
s = re.sub(r"[<\[][^>\]]*[>\]]", "", s) # remove words between brackets
|
||||
s = re.sub(r"\(([^)]+?)\)", "", s) # remove words between parenthesis
|
||||
s = self.clean(s).lower()
|
||||
|
||||
if self.split_letters:
|
||||
s = " ".join(regex.findall(r"\X", s, regex.U))
|
||||
|
||||
s = re.sub(r"\s+", " ", s) # replace any successive whitespace characters with a space
|
||||
|
||||
return s
|
File diff suppressed because it is too large
Load Diff
@ -1,543 +0,0 @@
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from fractions import Fraction
|
||||
from typing import Iterator, List, Match, Optional, Union
|
||||
|
||||
from more_itertools import windowed
|
||||
|
||||
from .basic import remove_symbols_and_diacritics
|
||||
|
||||
|
||||
class EnglishNumberNormalizer:
|
||||
"""
|
||||
Convert any spelled-out numbers into arabic numbers, while handling:
|
||||
|
||||
- remove any commas
|
||||
- keep the suffixes such as: `1960s`, `274th`, `32nd`, etc.
|
||||
- spell out currency symbols after the number. e.g. `$20 million` -> `20000000 dollars`
|
||||
- spell out `one` and `ones`
|
||||
- interpret successive single-digit numbers as nominal: `one oh one` -> `101`
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
self.zeros = {"o", "oh", "zero"}
|
||||
self.ones = {
|
||||
name: i
|
||||
for i, name in enumerate(
|
||||
[
|
||||
"one",
|
||||
"two",
|
||||
"three",
|
||||
"four",
|
||||
"five",
|
||||
"six",
|
||||
"seven",
|
||||
"eight",
|
||||
"nine",
|
||||
"ten",
|
||||
"eleven",
|
||||
"twelve",
|
||||
"thirteen",
|
||||
"fourteen",
|
||||
"fifteen",
|
||||
"sixteen",
|
||||
"seventeen",
|
||||
"eighteen",
|
||||
"nineteen",
|
||||
],
|
||||
start=1,
|
||||
)
|
||||
}
|
||||
self.ones_plural = {
|
||||
"sixes" if name == "six" else name + "s": (value, "s")
|
||||
for name, value in self.ones.items()
|
||||
}
|
||||
self.ones_ordinal = {
|
||||
"zeroth": (0, "th"),
|
||||
"first": (1, "st"),
|
||||
"second": (2, "nd"),
|
||||
"third": (3, "rd"),
|
||||
"fifth": (5, "th"),
|
||||
"twelfth": (12, "th"),
|
||||
**{
|
||||
name + ("h" if name.endswith("t") else "th"): (value, "th")
|
||||
for name, value in self.ones.items()
|
||||
if value > 3 and value != 5 and value != 12
|
||||
},
|
||||
}
|
||||
self.ones_suffixed = {**self.ones_plural, **self.ones_ordinal}
|
||||
|
||||
self.tens = {
|
||||
"twenty": 20,
|
||||
"thirty": 30,
|
||||
"forty": 40,
|
||||
"fifty": 50,
|
||||
"sixty": 60,
|
||||
"seventy": 70,
|
||||
"eighty": 80,
|
||||
"ninety": 90,
|
||||
}
|
||||
self.tens_plural = {
|
||||
name.replace("y", "ies"): (value, "s") for name, value in self.tens.items()
|
||||
}
|
||||
self.tens_ordinal = {
|
||||
name.replace("y", "ieth"): (value, "th") for name, value in self.tens.items()
|
||||
}
|
||||
self.tens_suffixed = {**self.tens_plural, **self.tens_ordinal}
|
||||
|
||||
self.multipliers = {
|
||||
"hundred": 100,
|
||||
"thousand": 1_000,
|
||||
"million": 1_000_000,
|
||||
"billion": 1_000_000_000,
|
||||
"trillion": 1_000_000_000_000,
|
||||
"quadrillion": 1_000_000_000_000_000,
|
||||
"quintillion": 1_000_000_000_000_000_000,
|
||||
"sextillion": 1_000_000_000_000_000_000_000,
|
||||
"septillion": 1_000_000_000_000_000_000_000_000,
|
||||
"octillion": 1_000_000_000_000_000_000_000_000_000,
|
||||
"nonillion": 1_000_000_000_000_000_000_000_000_000_000,
|
||||
"decillion": 1_000_000_000_000_000_000_000_000_000_000_000,
|
||||
}
|
||||
self.multipliers_plural = {
|
||||
name + "s": (value, "s") for name, value in self.multipliers.items()
|
||||
}
|
||||
self.multipliers_ordinal = {
|
||||
name + "th": (value, "th") for name, value in self.multipliers.items()
|
||||
}
|
||||
self.multipliers_suffixed = {**self.multipliers_plural, **self.multipliers_ordinal}
|
||||
self.decimals = {*self.ones, *self.tens, *self.zeros}
|
||||
|
||||
self.preceding_prefixers = {
|
||||
"minus": "-",
|
||||
"negative": "-",
|
||||
"plus": "+",
|
||||
"positive": "+",
|
||||
}
|
||||
self.following_prefixers = {
|
||||
"pound": "£",
|
||||
"pounds": "£",
|
||||
"euro": "€",
|
||||
"euros": "€",
|
||||
"dollar": "$",
|
||||
"dollars": "$",
|
||||
"cent": "¢",
|
||||
"cents": "¢",
|
||||
}
|
||||
self.prefixes = set(
|
||||
list(self.preceding_prefixers.values()) + list(self.following_prefixers.values())
|
||||
)
|
||||
self.suffixers = {
|
||||
"per": {"cent": "%"},
|
||||
"percent": "%",
|
||||
}
|
||||
self.specials = {"and", "double", "triple", "point"}
|
||||
|
||||
self.words = set(
|
||||
[
|
||||
key
|
||||
for mapping in [
|
||||
self.zeros,
|
||||
self.ones,
|
||||
self.ones_suffixed,
|
||||
self.tens,
|
||||
self.tens_suffixed,
|
||||
self.multipliers,
|
||||
self.multipliers_suffixed,
|
||||
self.preceding_prefixers,
|
||||
self.following_prefixers,
|
||||
self.suffixers,
|
||||
self.specials,
|
||||
]
|
||||
for key in mapping
|
||||
]
|
||||
)
|
||||
self.literal_words = {"one", "ones"}
|
||||
|
||||
def process_words(self, words: List[str]) -> Iterator[str]:
|
||||
prefix: Optional[str] = None
|
||||
value: Optional[Union[str, int]] = None
|
||||
skip = False
|
||||
|
||||
def to_fraction(s: str):
|
||||
try:
|
||||
return Fraction(s)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
def output(result: Union[str, int]):
|
||||
nonlocal prefix, value
|
||||
result = str(result)
|
||||
if prefix is not None:
|
||||
result = prefix + result
|
||||
value = None
|
||||
prefix = None
|
||||
return result
|
||||
|
||||
if len(words) == 0:
|
||||
return
|
||||
|
||||
for prev, current, next in windowed([None] + words + [None], 3):
|
||||
if skip:
|
||||
skip = False
|
||||
continue
|
||||
|
||||
next_is_numeric = next is not None and re.match(r"^\d+(\.\d+)?$", next)
|
||||
has_prefix = current[0] in self.prefixes
|
||||
current_without_prefix = current[1:] if has_prefix else current
|
||||
if re.match(r"^\d+(\.\d+)?$", current_without_prefix):
|
||||
# arabic numbers (potentially with signs and fractions)
|
||||
f = to_fraction(current_without_prefix)
|
||||
assert f is not None
|
||||
if value is not None:
|
||||
if isinstance(value, str) and value.endswith("."):
|
||||
# concatenate decimals / ip address components
|
||||
value = str(value) + str(current)
|
||||
continue
|
||||
else:
|
||||
yield output(value)
|
||||
|
||||
prefix = current[0] if has_prefix else prefix
|
||||
if f.denominator == 1:
|
||||
value = f.numerator # store integers as int
|
||||
else:
|
||||
value = current_without_prefix
|
||||
elif current not in self.words:
|
||||
# non-numeric words
|
||||
if value is not None:
|
||||
yield output(value)
|
||||
yield output(current)
|
||||
elif current in self.zeros:
|
||||
value = str(value or "") + "0"
|
||||
elif current in self.ones:
|
||||
ones = self.ones[current]
|
||||
|
||||
if value is None:
|
||||
value = ones
|
||||
elif isinstance(value, str) or prev in self.ones:
|
||||
if prev in self.tens and ones < 10: # replace the last zero with the digit
|
||||
assert value[-1] == "0"
|
||||
value = value[:-1] + str(ones)
|
||||
else:
|
||||
value = str(value) + str(ones)
|
||||
elif ones < 10:
|
||||
if value % 10 == 0:
|
||||
value += ones
|
||||
else:
|
||||
value = str(value) + str(ones)
|
||||
else: # eleven to nineteen
|
||||
if value % 100 == 0:
|
||||
value += ones
|
||||
else:
|
||||
value = str(value) + str(ones)
|
||||
elif current in self.ones_suffixed:
|
||||
# ordinal or cardinal; yield the number right away
|
||||
ones, suffix = self.ones_suffixed[current]
|
||||
if value is None:
|
||||
yield output(str(ones) + suffix)
|
||||
elif isinstance(value, str) or prev in self.ones:
|
||||
if prev in self.tens and ones < 10:
|
||||
assert value[-1] == "0"
|
||||
yield output(value[:-1] + str(ones) + suffix)
|
||||
else:
|
||||
yield output(str(value) + str(ones) + suffix)
|
||||
elif ones < 10:
|
||||
if value % 10 == 0:
|
||||
yield output(str(value + ones) + suffix)
|
||||
else:
|
||||
yield output(str(value) + str(ones) + suffix)
|
||||
else: # eleven to nineteen
|
||||
if value % 100 == 0:
|
||||
yield output(str(value + ones) + suffix)
|
||||
else:
|
||||
yield output(str(value) + str(ones) + suffix)
|
||||
value = None
|
||||
elif current in self.tens:
|
||||
tens = self.tens[current]
|
||||
if value is None:
|
||||
value = tens
|
||||
elif isinstance(value, str):
|
||||
value = str(value) + str(tens)
|
||||
else:
|
||||
if value % 100 == 0:
|
||||
value += tens
|
||||
else:
|
||||
value = str(value) + str(tens)
|
||||
elif current in self.tens_suffixed:
|
||||
# ordinal or cardinal; yield the number right away
|
||||
tens, suffix = self.tens_suffixed[current]
|
||||
if value is None:
|
||||
yield output(str(tens) + suffix)
|
||||
elif isinstance(value, str):
|
||||
yield output(str(value) + str(tens) + suffix)
|
||||
else:
|
||||
if value % 100 == 0:
|
||||
yield output(str(value + tens) + suffix)
|
||||
else:
|
||||
yield output(str(value) + str(tens) + suffix)
|
||||
elif current in self.multipliers:
|
||||
multiplier = self.multipliers[current]
|
||||
if value is None:
|
||||
value = multiplier
|
||||
elif isinstance(value, str) or value == 0:
|
||||
f = to_fraction(value)
|
||||
p = f * multiplier if f is not None else None
|
||||
if f is not None and p.denominator == 1:
|
||||
value = p.numerator
|
||||
else:
|
||||
yield output(value)
|
||||
value = multiplier
|
||||
else:
|
||||
before = value // 1000 * 1000
|
||||
residual = value % 1000
|
||||
value = before + residual * multiplier
|
||||
elif current in self.multipliers_suffixed:
|
||||
multiplier, suffix = self.multipliers_suffixed[current]
|
||||
if value is None:
|
||||
yield output(str(multiplier) + suffix)
|
||||
elif isinstance(value, str):
|
||||
f = to_fraction(value)
|
||||
p = f * multiplier if f is not None else None
|
||||
if f is not None and p.denominator == 1:
|
||||
yield output(str(p.numerator) + suffix)
|
||||
else:
|
||||
yield output(value)
|
||||
yield output(str(multiplier) + suffix)
|
||||
else: # int
|
||||
before = value // 1000 * 1000
|
||||
residual = value % 1000
|
||||
value = before + residual * multiplier
|
||||
yield output(str(value) + suffix)
|
||||
value = None
|
||||
elif current in self.preceding_prefixers:
|
||||
# apply prefix (positive, minus, etc.) if it precedes a number
|
||||
if value is not None:
|
||||
yield output(value)
|
||||
|
||||
if next in self.words or next_is_numeric:
|
||||
prefix = self.preceding_prefixers[current]
|
||||
else:
|
||||
yield output(current)
|
||||
elif current in self.following_prefixers:
|
||||
# apply prefix (dollars, cents, etc.) only after a number
|
||||
if value is not None:
|
||||
prefix = self.following_prefixers[current]
|
||||
yield output(value)
|
||||
else:
|
||||
yield output(current)
|
||||
elif current in self.suffixers:
|
||||
# apply suffix symbols (percent -> '%')
|
||||
if value is not None:
|
||||
suffix = self.suffixers[current]
|
||||
if isinstance(suffix, dict):
|
||||
if next in suffix:
|
||||
yield output(str(value) + suffix[next])
|
||||
skip = True
|
||||
else:
|
||||
yield output(value)
|
||||
yield output(current)
|
||||
else:
|
||||
yield output(str(value) + suffix)
|
||||
else:
|
||||
yield output(current)
|
||||
elif current in self.specials:
|
||||
if next not in self.words and not next_is_numeric:
|
||||
# apply special handling only if the next word can be numeric
|
||||
if value is not None:
|
||||
yield output(value)
|
||||
yield output(current)
|
||||
elif current == "and":
|
||||
# ignore "and" after hundreds, thousands, etc.
|
||||
if prev not in self.multipliers:
|
||||
if value is not None:
|
||||
yield output(value)
|
||||
yield output(current)
|
||||
elif current == "double" or current == "triple":
|
||||
if next in self.ones or next in self.zeros:
|
||||
repeats = 2 if current == "double" else 3
|
||||
ones = self.ones.get(next, 0)
|
||||
value = str(value or "") + str(ones) * repeats
|
||||
skip = True
|
||||
else:
|
||||
if value is not None:
|
||||
yield output(value)
|
||||
yield output(current)
|
||||
elif current == "point":
|
||||
if next in self.decimals or next_is_numeric:
|
||||
value = str(value or "") + "."
|
||||
else:
|
||||
# should all have been covered at this point
|
||||
raise ValueError(f"Unexpected token: {current}")
|
||||
else:
|
||||
# all should have been covered at this point
|
||||
raise ValueError(f"Unexpected token: {current}")
|
||||
|
||||
if value is not None:
|
||||
yield output(value)
|
||||
|
||||
def preprocess(self, s: str):
|
||||
# replace "<number> and a half" with "<number> point five"
|
||||
results = []
|
||||
|
||||
segments = re.split(r"\band\s+a\s+half\b", s)
|
||||
for i, segment in enumerate(segments):
|
||||
if len(segment.strip()) == 0:
|
||||
continue
|
||||
if i == len(segments) - 1:
|
||||
results.append(segment)
|
||||
else:
|
||||
results.append(segment)
|
||||
last_word = segment.rsplit(maxsplit=2)[-1]
|
||||
if last_word in self.decimals or last_word in self.multipliers:
|
||||
results.append("point five")
|
||||
else:
|
||||
results.append("and a half")
|
||||
|
||||
s = " ".join(results)
|
||||
|
||||
# put a space at number/letter boundary
|
||||
s = re.sub(r"([a-z])([0-9])", r"\1 \2", s)
|
||||
s = re.sub(r"([0-9])([a-z])", r"\1 \2", s)
|
||||
|
||||
# but remove spaces which could be a suffix
|
||||
s = re.sub(r"([0-9])\s+(st|nd|rd|th|s)\b", r"\1\2", s)
|
||||
|
||||
return s
|
||||
|
||||
def postprocess(self, s: str):
|
||||
def combine_cents(m: Match):
|
||||
try:
|
||||
currency = m.group(1)
|
||||
integer = m.group(2)
|
||||
cents = int(m.group(3))
|
||||
return f"{currency}{integer}.{cents:02d}"
|
||||
except ValueError:
|
||||
return m.string
|
||||
|
||||
def extract_cents(m: Match):
|
||||
try:
|
||||
return f"¢{int(m.group(1))}"
|
||||
except ValueError:
|
||||
return m.string
|
||||
|
||||
# apply currency postprocessing; "$2 and ¢7" -> "$2.07"
|
||||
s = re.sub(r"([€£$])([0-9]+) (?:and )?¢([0-9]{1,2})\b", combine_cents, s)
|
||||
s = re.sub(r"[€£$]0.([0-9]{1,2})\b", extract_cents, s)
|
||||
|
||||
# write "one(s)" instead of "1(s)", just for the readability
|
||||
s = re.sub(r"\b1(s?)\b", r"one\1", s)
|
||||
|
||||
return s
|
||||
|
||||
def __call__(self, s: str):
|
||||
s = self.preprocess(s)
|
||||
s = " ".join(word for word in self.process_words(s.split()) if word is not None)
|
||||
s = self.postprocess(s)
|
||||
|
||||
return s
|
||||
|
||||
|
||||
class EnglishSpellingNormalizer:
|
||||
"""
|
||||
Applies British-American spelling mappings as listed in [1].
|
||||
|
||||
[1] https://www.tysto.com/uk-us-spelling-list.html
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
mapping_path = os.path.join(os.path.dirname(__file__), "english.json")
|
||||
self.mapping = json.load(open(mapping_path))
|
||||
|
||||
def __call__(self, s: str):
|
||||
return " ".join(self.mapping.get(word, word) for word in s.split())
|
||||
|
||||
|
||||
class EnglishTextNormalizer:
|
||||
def __init__(self):
|
||||
self.ignore_patterns = r"\b(hmm|mm|mhm|mmm|uh|um)\b"
|
||||
self.replacers = {
|
||||
# common contractions
|
||||
r"\bwon't\b": "will not",
|
||||
r"\bcan't\b": "can not",
|
||||
r"\blet's\b": "let us",
|
||||
r"\bain't\b": "aint",
|
||||
r"\by'all\b": "you all",
|
||||
r"\bwanna\b": "want to",
|
||||
r"\bgotta\b": "got to",
|
||||
r"\bgonna\b": "going to",
|
||||
r"\bi'ma\b": "i am going to",
|
||||
r"\bimma\b": "i am going to",
|
||||
r"\bwoulda\b": "would have",
|
||||
r"\bcoulda\b": "could have",
|
||||
r"\bshoulda\b": "should have",
|
||||
r"\bma'am\b": "madam",
|
||||
# contractions in titles/prefixes
|
||||
r"\bmr\b": "mister ",
|
||||
r"\bmrs\b": "missus ",
|
||||
r"\bst\b": "saint ",
|
||||
r"\bdr\b": "doctor ",
|
||||
r"\bprof\b": "professor ",
|
||||
r"\bcapt\b": "captain ",
|
||||
r"\bgov\b": "governor ",
|
||||
r"\bald\b": "alderman ",
|
||||
r"\bgen\b": "general ",
|
||||
r"\bsen\b": "senator ",
|
||||
r"\brep\b": "representative ",
|
||||
r"\bpres\b": "president ",
|
||||
r"\brev\b": "reverend ",
|
||||
r"\bhon\b": "honorable ",
|
||||
r"\basst\b": "assistant ",
|
||||
r"\bassoc\b": "associate ",
|
||||
r"\blt\b": "lieutenant ",
|
||||
r"\bcol\b": "colonel ",
|
||||
r"\bjr\b": "junior ",
|
||||
r"\bsr\b": "senior ",
|
||||
r"\besq\b": "esquire ",
|
||||
# prefect tenses, ideally it should be any past participles, but it's harder..
|
||||
r"'d been\b": " had been",
|
||||
r"'s been\b": " has been",
|
||||
r"'d gone\b": " had gone",
|
||||
r"'s gone\b": " has gone",
|
||||
r"'d done\b": " had done", # "'s done" is ambiguous
|
||||
r"'s got\b": " has got",
|
||||
# general contractions
|
||||
r"n't\b": " not",
|
||||
r"'re\b": " are",
|
||||
r"'s\b": " is",
|
||||
r"'d\b": " would",
|
||||
r"'ll\b": " will",
|
||||
r"'t\b": " not",
|
||||
r"'ve\b": " have",
|
||||
r"'m\b": " am",
|
||||
}
|
||||
self.standardize_numbers = EnglishNumberNormalizer()
|
||||
self.standardize_spellings = EnglishSpellingNormalizer()
|
||||
|
||||
def __call__(self, s: str):
|
||||
s = s.lower()
|
||||
|
||||
s = re.sub(r"[<\[][^>\]]*[>\]]", "", s) # remove words between brackets
|
||||
s = re.sub(r"\(([^)]+?)\)", "", s) # remove words between parenthesis
|
||||
s = re.sub(self.ignore_patterns, "", s)
|
||||
s = re.sub(r"\s+'", "'", s) # standardize when there's a space before an apostrophe
|
||||
|
||||
for pattern, replacement in self.replacers.items():
|
||||
s = re.sub(pattern, replacement, s)
|
||||
|
||||
s = re.sub(r"(\d),(\d)", r"\1\2", s) # remove commas between digits
|
||||
s = re.sub(r"\.([^0-9]|$)", r" \1", s) # remove periods not followed by numbers
|
||||
s = remove_symbols_and_diacritics(s, keep=".%$¢€£") # keep some symbols for numerics
|
||||
|
||||
s = self.standardize_numbers(s)
|
||||
s = self.standardize_spellings(s)
|
||||
|
||||
# now remove prefix/suffix symbols that are not preceded/followed by numbers
|
||||
s = re.sub(r"[.$¢€£]([^0-9])", r" \1", s)
|
||||
s = re.sub(r"([^0-9])%", r"\1 ", s)
|
||||
|
||||
s = re.sub(r"\s+", " ", s) # replace any successive whitespace characters with a space
|
||||
|
||||
return s
|
@ -1,331 +0,0 @@
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from functools import lru_cache
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import GPT2TokenizerFast
|
||||
|
||||
LANGUAGES = {
|
||||
"en": "english",
|
||||
"zh": "chinese",
|
||||
"de": "german",
|
||||
"es": "spanish",
|
||||
"ru": "russian",
|
||||
"ko": "korean",
|
||||
"fr": "french",
|
||||
"ja": "japanese",
|
||||
"pt": "portuguese",
|
||||
"tr": "turkish",
|
||||
"pl": "polish",
|
||||
"ca": "catalan",
|
||||
"nl": "dutch",
|
||||
"ar": "arabic",
|
||||
"sv": "swedish",
|
||||
"it": "italian",
|
||||
"id": "indonesian",
|
||||
"hi": "hindi",
|
||||
"fi": "finnish",
|
||||
"vi": "vietnamese",
|
||||
"he": "hebrew",
|
||||
"uk": "ukrainian",
|
||||
"el": "greek",
|
||||
"ms": "malay",
|
||||
"cs": "czech",
|
||||
"ro": "romanian",
|
||||
"da": "danish",
|
||||
"hu": "hungarian",
|
||||
"ta": "tamil",
|
||||
"no": "norwegian",
|
||||
"th": "thai",
|
||||
"ur": "urdu",
|
||||
"hr": "croatian",
|
||||
"bg": "bulgarian",
|
||||
"lt": "lithuanian",
|
||||
"la": "latin",
|
||||
"mi": "maori",
|
||||
"ml": "malayalam",
|
||||
"cy": "welsh",
|
||||
"sk": "slovak",
|
||||
"te": "telugu",
|
||||
"fa": "persian",
|
||||
"lv": "latvian",
|
||||
"bn": "bengali",
|
||||
"sr": "serbian",
|
||||
"az": "azerbaijani",
|
||||
"sl": "slovenian",
|
||||
"kn": "kannada",
|
||||
"et": "estonian",
|
||||
"mk": "macedonian",
|
||||
"br": "breton",
|
||||
"eu": "basque",
|
||||
"is": "icelandic",
|
||||
"hy": "armenian",
|
||||
"ne": "nepali",
|
||||
"mn": "mongolian",
|
||||
"bs": "bosnian",
|
||||
"kk": "kazakh",
|
||||
"sq": "albanian",
|
||||
"sw": "swahili",
|
||||
"gl": "galician",
|
||||
"mr": "marathi",
|
||||
"pa": "punjabi",
|
||||
"si": "sinhala",
|
||||
"km": "khmer",
|
||||
"sn": "shona",
|
||||
"yo": "yoruba",
|
||||
"so": "somali",
|
||||
"af": "afrikaans",
|
||||
"oc": "occitan",
|
||||
"ka": "georgian",
|
||||
"be": "belarusian",
|
||||
"tg": "tajik",
|
||||
"sd": "sindhi",
|
||||
"gu": "gujarati",
|
||||
"am": "amharic",
|
||||
"yi": "yiddish",
|
||||
"lo": "lao",
|
||||
"uz": "uzbek",
|
||||
"fo": "faroese",
|
||||
"ht": "haitian creole",
|
||||
"ps": "pashto",
|
||||
"tk": "turkmen",
|
||||
"nn": "nynorsk",
|
||||
"mt": "maltese",
|
||||
"sa": "sanskrit",
|
||||
"lb": "luxembourgish",
|
||||
"my": "myanmar",
|
||||
"bo": "tibetan",
|
||||
"tl": "tagalog",
|
||||
"mg": "malagasy",
|
||||
"as": "assamese",
|
||||
"tt": "tatar",
|
||||
"haw": "hawaiian",
|
||||
"ln": "lingala",
|
||||
"ha": "hausa",
|
||||
"ba": "bashkir",
|
||||
"jw": "javanese",
|
||||
"su": "sundanese",
|
||||
}
|
||||
|
||||
# language code lookup by name, with a few language aliases
|
||||
TO_LANGUAGE_CODE = {
|
||||
**{language: code for code, language in LANGUAGES.items()},
|
||||
"burmese": "my",
|
||||
"valencian": "ca",
|
||||
"flemish": "nl",
|
||||
"haitian": "ht",
|
||||
"letzeburgesch": "lb",
|
||||
"pushto": "ps",
|
||||
"panjabi": "pa",
|
||||
"moldavian": "ro",
|
||||
"moldovan": "ro",
|
||||
"sinhalese": "si",
|
||||
"castilian": "es",
|
||||
}
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Tokenizer:
|
||||
"""A thin wrapper around `GPT2TokenizerFast` providing quick access to special tokens"""
|
||||
|
||||
tokenizer: "GPT2TokenizerFast"
|
||||
language: Optional[str]
|
||||
sot_sequence: Tuple[int]
|
||||
|
||||
def encode(self, text, **kwargs):
|
||||
return self.tokenizer.encode(text, **kwargs)
|
||||
|
||||
def decode(self, token_ids: Union[int, List[int], np.ndarray, torch.Tensor], **kwargs):
|
||||
return self.tokenizer.decode(token_ids, **kwargs)
|
||||
|
||||
def decode_with_timestamps(self, tokens) -> str:
|
||||
"""
|
||||
Timestamp tokens are above the special tokens' id range and are ignored by `decode()`.
|
||||
This method decodes given tokens with timestamps tokens annotated, e.g. "<|1.08|>".
|
||||
"""
|
||||
outputs = [[]]
|
||||
for token in tokens:
|
||||
if token >= self.timestamp_begin:
|
||||
timestamp = f"<|{(token - self.timestamp_begin) * 0.02:.2f}|>"
|
||||
outputs.append(timestamp)
|
||||
outputs.append([])
|
||||
else:
|
||||
outputs[-1].append(token)
|
||||
outputs = [s if isinstance(s, str) else self.tokenizer.decode(s) for s in outputs]
|
||||
return "".join(outputs)
|
||||
|
||||
@property
|
||||
@lru_cache()
|
||||
def eot(self) -> int:
|
||||
return self.tokenizer.eos_token_id
|
||||
|
||||
@property
|
||||
@lru_cache()
|
||||
def sot(self) -> int:
|
||||
return self._get_single_token_id("<|startoftranscript|>")
|
||||
|
||||
@property
|
||||
@lru_cache()
|
||||
def sot_lm(self) -> int:
|
||||
return self._get_single_token_id("<|startoflm|>")
|
||||
|
||||
@property
|
||||
@lru_cache()
|
||||
def sot_prev(self) -> int:
|
||||
return self._get_single_token_id("<|startofprev|>")
|
||||
|
||||
@property
|
||||
@lru_cache()
|
||||
def no_speech(self) -> int:
|
||||
return self._get_single_token_id("<|nospeech|>")
|
||||
|
||||
@property
|
||||
@lru_cache()
|
||||
def no_timestamps(self) -> int:
|
||||
return self._get_single_token_id("<|notimestamps|>")
|
||||
|
||||
@property
|
||||
@lru_cache()
|
||||
def timestamp_begin(self) -> int:
|
||||
return self.tokenizer.all_special_ids[-1] + 1
|
||||
|
||||
@property
|
||||
@lru_cache()
|
||||
def language_token(self) -> int:
|
||||
"""Returns the token id corresponding to the value of the `language` field"""
|
||||
if self.language is None:
|
||||
raise ValueError(f"This tokenizer does not have language token configured")
|
||||
|
||||
additional_tokens = dict(
|
||||
zip(
|
||||
self.tokenizer.additional_special_tokens,
|
||||
self.tokenizer.additional_special_tokens_ids,
|
||||
)
|
||||
)
|
||||
candidate = f"<|{self.language}|>"
|
||||
if candidate in additional_tokens:
|
||||
return additional_tokens[candidate]
|
||||
|
||||
raise KeyError(f"Language {self.language} not found in tokenizer.")
|
||||
|
||||
@property
|
||||
@lru_cache()
|
||||
def all_language_tokens(self) -> Tuple[int]:
|
||||
result = []
|
||||
for token, token_id in zip(
|
||||
self.tokenizer.additional_special_tokens,
|
||||
self.tokenizer.additional_special_tokens_ids,
|
||||
):
|
||||
if token.strip("<|>") in LANGUAGES:
|
||||
result.append(token_id)
|
||||
return tuple(result)
|
||||
|
||||
@property
|
||||
@lru_cache()
|
||||
def all_language_codes(self) -> Tuple[str]:
|
||||
return tuple(self.decode([l]).strip("<|>") for l in self.all_language_tokens)
|
||||
|
||||
@property
|
||||
@lru_cache()
|
||||
def sot_sequence_including_notimestamps(self) -> Tuple[int]:
|
||||
return tuple(list(self.sot_sequence) + [self.no_timestamps])
|
||||
|
||||
@property
|
||||
@lru_cache()
|
||||
def non_speech_tokens(self) -> Tuple[int]:
|
||||
"""
|
||||
Returns the list of tokens to suppress in order to avoid any speaker tags or non-speech
|
||||
annotations, to prevent sampling texts that are not actually spoken in the audio, e.g.
|
||||
|
||||
- ♪♪♪
|
||||
- ( SPEAKING FOREIGN LANGUAGE )
|
||||
- [DAVID] Hey there,
|
||||
|
||||
keeping basic punctuations like commas, periods, question marks, exclamation points, etc.
|
||||
"""
|
||||
symbols = list("\"#()*+/:;<=>@[\\]^_`{|}~「」『』")
|
||||
symbols += "<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split()
|
||||
|
||||
# symbols that may be a single token or multiple tokens depending on the tokenizer.
|
||||
# In case they're multiple tokens, suppress the first token, which is safe because:
|
||||
# These are between U+2640 and U+267F miscellaneous symbols that are okay to suppress
|
||||
# in generations, and in the 3-byte UTF-8 representation they share the first two bytes.
|
||||
miscellaneous = set("♩♪♫♬♭♮♯")
|
||||
assert all(0x2640 <= ord(c) <= 0x267F for c in miscellaneous)
|
||||
|
||||
# allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word
|
||||
result = {self.tokenizer.encode(" -")[0], self.tokenizer.encode(" '")[0]}
|
||||
for symbol in symbols + list(miscellaneous):
|
||||
for tokens in [self.tokenizer.encode(symbol), self.tokenizer.encode(" " + symbol)]:
|
||||
if len(tokens) == 1 or symbol in miscellaneous:
|
||||
result.add(tokens[0])
|
||||
|
||||
return tuple(sorted(result))
|
||||
|
||||
def _get_single_token_id(self, text) -> int:
|
||||
tokens = self.tokenizer.encode(text)
|
||||
assert len(tokens) == 1, f"{text} is not encoded as a single token"
|
||||
return tokens[0]
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def build_tokenizer(name: str = "gpt2"):
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
path = os.path.join(os.path.dirname(__file__), "assets", name)
|
||||
tokenizer = GPT2TokenizerFast.from_pretrained(path)
|
||||
|
||||
specials = [
|
||||
"<|startoftranscript|>",
|
||||
*[f"<|{lang}|>" for lang in LANGUAGES.keys()],
|
||||
"<|translate|>",
|
||||
"<|transcribe|>",
|
||||
"<|startoflm|>",
|
||||
"<|startofprev|>",
|
||||
"<|nospeech|>",
|
||||
"<|notimestamps|>",
|
||||
]
|
||||
|
||||
tokenizer.add_special_tokens(dict(additional_special_tokens=specials))
|
||||
return tokenizer
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def get_tokenizer(
|
||||
multilingual: bool,
|
||||
*,
|
||||
task: Optional[str] = None, # Literal["transcribe", "translate", None]
|
||||
language: Optional[str] = None,
|
||||
) -> Tokenizer:
|
||||
if language is not None:
|
||||
language = language.lower()
|
||||
if language not in LANGUAGES:
|
||||
if language in TO_LANGUAGE_CODE:
|
||||
language = TO_LANGUAGE_CODE[language]
|
||||
else:
|
||||
raise ValueError(f"Unsupported language: {language}")
|
||||
|
||||
if multilingual:
|
||||
tokenizer_name = "multilingual"
|
||||
task = task or "transcribe"
|
||||
language = language or "en"
|
||||
else:
|
||||
tokenizer_name = "gpt2"
|
||||
task = None
|
||||
language = None
|
||||
|
||||
tokenizer = build_tokenizer(name=tokenizer_name)
|
||||
all_special_ids: List[int] = tokenizer.all_special_ids
|
||||
sot: int = all_special_ids[1]
|
||||
translate: int = all_special_ids[-6]
|
||||
transcribe: int = all_special_ids[-5]
|
||||
|
||||
langs = tuple(LANGUAGES.keys())
|
||||
sot_sequence = [sot]
|
||||
if language is not None:
|
||||
sot_sequence.append(sot + 1 + langs.index(language))
|
||||
if task is not None:
|
||||
sot_sequence.append(transcribe if task == "transcribe" else translate)
|
||||
|
||||
return Tokenizer(tokenizer=tokenizer, language=language, sot_sequence=tuple(sot_sequence))
|
@ -1,602 +1,51 @@
|
||||
import argparse
|
||||
import gc
|
||||
import os
|
||||
import warnings
|
||||
from typing import List, Optional, Tuple, Union, Iterator, TYPE_CHECKING
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import tqdm
|
||||
from .audio import SAMPLE_RATE, N_FRAMES, HOP_LENGTH, CHUNK_LENGTH, pad_or_trim, log_mel_spectrogram, load_audio
|
||||
from .alignment import load_align_model, align, get_trellis, backtrack, merge_repeats, merge_words
|
||||
from .decoding import DecodingOptions, DecodingResult
|
||||
from .diarize import assign_word_speakers, Segment
|
||||
from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
|
||||
from .utils import exact_div, format_timestamp, optional_int, optional_float, str2bool, interpolate_nans, write_txt, write_vtt, write_srt, write_ass, write_tsv
|
||||
from .vad import Binarize
|
||||
import pandas as pd
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .model import Whisper
|
||||
|
||||
|
||||
|
||||
def transcribe(
|
||||
model: "Whisper",
|
||||
audio: Union[str, np.ndarray, torch.Tensor],
|
||||
*,
|
||||
verbose: Optional[bool] = None,
|
||||
temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
|
||||
compression_ratio_threshold: Optional[float] = 2.4,
|
||||
logprob_threshold: Optional[float] = -1.0,
|
||||
no_speech_threshold: Optional[float] = 0.6,
|
||||
condition_on_previous_text: bool = False, # turn off by default due to errors it causes
|
||||
mel: np.ndarray = None,
|
||||
**decode_options,
|
||||
):
|
||||
"""
|
||||
Transcribe an audio file using Whisper
|
||||
|
||||
Parameters
|
||||
----------
|
||||
model: Whisper
|
||||
The Whisper model instance
|
||||
|
||||
audio: Union[str, np.ndarray, torch.Tensor]
|
||||
The path to the audio file to open, or the audio waveform
|
||||
|
||||
verbose: bool
|
||||
Whether to display the text being decoded to the console. If True, displays all the details,
|
||||
If False, displays minimal details. If None, does not display anything
|
||||
|
||||
temperature: Union[float, Tuple[float, ...]]
|
||||
Temperature for sampling. It can be a tuple of temperatures, which will be successfully used
|
||||
upon failures according to either `compression_ratio_threshold` or `logprob_threshold`.
|
||||
|
||||
compression_ratio_threshold: float
|
||||
If the gzip compression ratio is above this value, treat as failed
|
||||
|
||||
logprob_threshold: float
|
||||
If the average log probability over sampled tokens is below this value, treat as failed
|
||||
|
||||
no_speech_threshold: float
|
||||
If the no_speech probability is higher than this value AND the average log probability
|
||||
over sampled tokens is below `logprob_threshold`, consider the segment as silent
|
||||
|
||||
condition_on_previous_text: bool
|
||||
if True, the previous output of the model is provided as a prompt for the next window;
|
||||
disabling may make the text inconsistent across windows, but the model becomes less prone to
|
||||
getting stuck in a failure loop, such as repetition looping or timestamps going out of sync.
|
||||
|
||||
decode_options: dict
|
||||
Keyword arguments to construct `DecodingOptions` instances
|
||||
|
||||
Returns
|
||||
-------
|
||||
A dictionary containing the resulting text ("text") and segment-level details ("segments"), and
|
||||
the spoken language ("language"), which is detected when `decode_options["language"]` is None.
|
||||
"""
|
||||
dtype = torch.float16 if decode_options.get("fp16", True) else torch.float32
|
||||
if model.device == torch.device("cpu"):
|
||||
if torch.cuda.is_available():
|
||||
warnings.warn("Performing inference on CPU when CUDA is available")
|
||||
if dtype == torch.float16:
|
||||
warnings.warn("FP16 is not supported on CPU; using FP32 instead")
|
||||
dtype = torch.float32
|
||||
|
||||
if dtype == torch.float32:
|
||||
decode_options["fp16"] = False
|
||||
|
||||
if mel is None:
|
||||
mel = log_mel_spectrogram(audio)
|
||||
|
||||
if decode_options.get("language", None) is None:
|
||||
if not model.is_multilingual:
|
||||
decode_options["language"] = "en"
|
||||
else:
|
||||
if verbose:
|
||||
print("Detecting language using up to the first 30 seconds. Use `--language` to specify the language")
|
||||
segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype)
|
||||
_, probs = model.detect_language(segment)
|
||||
decode_options["language"] = max(probs, key=probs.get)
|
||||
if verbose is not None:
|
||||
print(f"Detected language: {LANGUAGES[decode_options['language']].title()}")
|
||||
|
||||
language = decode_options["language"]
|
||||
task = decode_options.get("task", "transcribe")
|
||||
tokenizer = get_tokenizer(model.is_multilingual, language=language, task=task)
|
||||
|
||||
def decode_with_fallback(segment: torch.Tensor) -> DecodingResult:
|
||||
temperatures = [temperature] if isinstance(temperature, (int, float)) else temperature
|
||||
decode_result = None
|
||||
|
||||
for t in temperatures:
|
||||
kwargs = {**decode_options}
|
||||
if t > 0:
|
||||
# disable beam_size and patience when t > 0
|
||||
kwargs.pop("beam_size", None)
|
||||
kwargs.pop("patience", None)
|
||||
else:
|
||||
# disable best_of when t == 0
|
||||
kwargs.pop("best_of", None)
|
||||
|
||||
options = DecodingOptions(**kwargs, temperature=t)
|
||||
decode_result = model.decode(segment, options)
|
||||
|
||||
needs_fallback = False
|
||||
if compression_ratio_threshold is not None and decode_result.compression_ratio > compression_ratio_threshold:
|
||||
needs_fallback = True # too repetitive
|
||||
if logprob_threshold is not None and decode_result.avg_logprob < logprob_threshold:
|
||||
needs_fallback = True # average log probability is too low
|
||||
|
||||
if not needs_fallback:
|
||||
break
|
||||
|
||||
return decode_result
|
||||
|
||||
seek = 0
|
||||
input_stride = exact_div(
|
||||
N_FRAMES, model.dims.n_audio_ctx
|
||||
) # mel frames per output token: 2
|
||||
time_precision = (
|
||||
input_stride * HOP_LENGTH / SAMPLE_RATE
|
||||
) # time per output token: 0.02 (seconds)
|
||||
all_tokens = []
|
||||
all_segments = []
|
||||
prompt_reset_since = 0
|
||||
|
||||
initial_prompt = decode_options.pop("initial_prompt", None) or []
|
||||
if initial_prompt:
|
||||
initial_prompt = tokenizer.encode(" " + initial_prompt.strip())
|
||||
all_tokens.extend(initial_prompt)
|
||||
|
||||
def add_segment(
|
||||
*, start: float, end: float, text_tokens: torch.Tensor, result: DecodingResult
|
||||
):
|
||||
text = tokenizer.decode([token for token in text_tokens if token < tokenizer.eot])
|
||||
if len(text.strip()) == 0: # skip empty text output
|
||||
return
|
||||
|
||||
all_segments.append(
|
||||
{
|
||||
"id": len(all_segments),
|
||||
"seek": seek,
|
||||
"start": start,
|
||||
"end": end,
|
||||
"text": text,
|
||||
"tokens": text_tokens.tolist(),
|
||||
"temperature": result.temperature,
|
||||
"avg_logprob": result.avg_logprob,
|
||||
"compression_ratio": result.compression_ratio,
|
||||
"no_speech_prob": result.no_speech_prob,
|
||||
}
|
||||
)
|
||||
if verbose:
|
||||
print(f"[{format_timestamp(start)} --> {format_timestamp(end)}] {text}")
|
||||
|
||||
# show the progress bar when verbose is False (otherwise the transcribed text will be printed)
|
||||
num_frames = mel.shape[-1]
|
||||
previous_seek_value = seek
|
||||
|
||||
with tqdm.tqdm(total=num_frames, unit='frames', disable=verbose is not False) as pbar:
|
||||
while seek < num_frames:
|
||||
timestamp_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
|
||||
segment = pad_or_trim(mel[:, seek:], N_FRAMES).to(model.device).to(dtype)
|
||||
segment_duration = segment.shape[-1] * HOP_LENGTH / SAMPLE_RATE
|
||||
|
||||
decode_options["prompt"] = all_tokens[prompt_reset_since:]
|
||||
result: DecodingResult = decode_with_fallback(segment)
|
||||
tokens = torch.tensor(result.tokens)
|
||||
|
||||
if no_speech_threshold is not None:
|
||||
# no voice activity check
|
||||
should_skip = result.no_speech_prob > no_speech_threshold
|
||||
if logprob_threshold is not None and result.avg_logprob > logprob_threshold:
|
||||
# don't skip if the logprob is high enough, despite the no_speech_prob
|
||||
should_skip = False
|
||||
|
||||
if should_skip:
|
||||
seek += segment.shape[-1] # fast-forward to the next segment boundary
|
||||
continue
|
||||
|
||||
timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin)
|
||||
consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0].add_(1)
|
||||
if len(consecutive) > 0: # if the output contains two consecutive timestamp tokens
|
||||
last_slice = 0
|
||||
for current_slice in consecutive:
|
||||
sliced_tokens = tokens[last_slice:current_slice]
|
||||
start_timestamp_position = (
|
||||
sliced_tokens[0].item() - tokenizer.timestamp_begin
|
||||
)
|
||||
end_timestamp_position = (
|
||||
sliced_tokens[-1].item() - tokenizer.timestamp_begin
|
||||
)
|
||||
|
||||
# clamp end-time to at least be 1 frame after start-time
|
||||
end_timestamp_position = max(end_timestamp_position, start_timestamp_position + time_precision)
|
||||
|
||||
add_segment(
|
||||
start=timestamp_offset + start_timestamp_position * time_precision,
|
||||
end=timestamp_offset + end_timestamp_position * time_precision,
|
||||
text_tokens=sliced_tokens[1:-1],
|
||||
result=result,
|
||||
)
|
||||
last_slice = current_slice
|
||||
last_timestamp_position = (
|
||||
tokens[last_slice - 1].item() - tokenizer.timestamp_begin
|
||||
)
|
||||
seek += last_timestamp_position * input_stride
|
||||
all_tokens.extend(tokens[: last_slice + 1].tolist())
|
||||
else:
|
||||
duration = segment_duration
|
||||
timestamps = tokens[timestamp_tokens.nonzero().flatten()]
|
||||
if len(timestamps) > 0 and timestamps[-1].item() != tokenizer.timestamp_begin:
|
||||
# no consecutive timestamps but it has a timestamp; use the last one.
|
||||
# single timestamp at the end means no speech after the last timestamp.
|
||||
last_timestamp_position = timestamps[-1].item() - tokenizer.timestamp_begin
|
||||
duration = last_timestamp_position * time_precision
|
||||
|
||||
add_segment(
|
||||
start=timestamp_offset,
|
||||
end=timestamp_offset + duration,
|
||||
text_tokens=tokens,
|
||||
result=result,
|
||||
)
|
||||
|
||||
seek += segment.shape[-1]
|
||||
all_tokens.extend(tokens.tolist())
|
||||
|
||||
if not condition_on_previous_text or result.temperature > 0.5:
|
||||
# do not feed the prompt tokens if a high temperature was used
|
||||
prompt_reset_since = len(all_tokens)
|
||||
|
||||
# update progress bar
|
||||
pbar.update(min(num_frames, seek) - previous_seek_value)
|
||||
previous_seek_value = seek
|
||||
|
||||
return dict(text=tokenizer.decode(all_tokens[len(initial_prompt):]), segments=all_segments, language=language)
|
||||
|
||||
|
||||
def merge_chunks(segments, chunk_size=CHUNK_LENGTH):
|
||||
"""
|
||||
Merge VAD segments into larger segments of approximately size ~CHUNK_LENGTH.
|
||||
TODO: Make sure VAD segment isn't too long, otherwise it will cause OOM when input to alignment model
|
||||
TODO: Or sliding window alignment model over long segment.
|
||||
"""
|
||||
curr_end = 0
|
||||
merged_segments = []
|
||||
seg_idxs = []
|
||||
speaker_idxs = []
|
||||
|
||||
assert chunk_size > 0
|
||||
binarize = Binarize(max_duration=chunk_size)
|
||||
segments = binarize(segments)
|
||||
segments_list = []
|
||||
for speech_turn in segments.get_timeline():
|
||||
segments_list.append(Segment(speech_turn.start, speech_turn.end, "UNKNOWN"))
|
||||
|
||||
assert segments_list, "segments_list is empty."
|
||||
# Make sur the starting point is the start of the segment.
|
||||
curr_start = segments_list[0].start
|
||||
|
||||
for seg in segments_list:
|
||||
if seg.end - curr_start > chunk_size and curr_end-curr_start > 0:
|
||||
merged_segments.append({
|
||||
"start": curr_start,
|
||||
"end": curr_end,
|
||||
"segments": seg_idxs,
|
||||
})
|
||||
curr_start = seg.start
|
||||
seg_idxs = []
|
||||
speaker_idxs = []
|
||||
curr_end = seg.end
|
||||
seg_idxs.append((seg.start, seg.end))
|
||||
speaker_idxs.append(seg.speaker)
|
||||
# add final
|
||||
merged_segments.append({
|
||||
"start": curr_start,
|
||||
"end": curr_end,
|
||||
"segments": seg_idxs,
|
||||
})
|
||||
return merged_segments
|
||||
|
||||
|
||||
def transcribe_with_vad(
|
||||
model: "Whisper",
|
||||
audio: Union[str, np.ndarray, torch.Tensor],
|
||||
vad_pipeline,
|
||||
mel = None,
|
||||
verbose: Optional[bool] = None,
|
||||
**kwargs
|
||||
):
|
||||
"""
|
||||
Transcribe per VAD segment
|
||||
"""
|
||||
|
||||
if mel is None:
|
||||
mel = log_mel_spectrogram(audio)
|
||||
|
||||
prev = 0
|
||||
output = {"segments": []}
|
||||
|
||||
vad_segments = vad_pipeline(audio)
|
||||
# merge segments to approx 30s inputs to make whisper most appropraite
|
||||
vad_segments = merge_chunks(vad_segments)
|
||||
|
||||
for sdx, seg_t in enumerate(vad_segments):
|
||||
if verbose:
|
||||
print(f"~~ Transcribing VAD chunk: ({format_timestamp(seg_t['start'])} --> {format_timestamp(seg_t['end'])}) ~~")
|
||||
seg_f_start, seg_f_end = int(seg_t["start"] * SAMPLE_RATE / HOP_LENGTH), int(seg_t["end"] * SAMPLE_RATE / HOP_LENGTH)
|
||||
local_f_start, local_f_end = seg_f_start - prev, seg_f_end - prev
|
||||
mel = mel[:, local_f_start:] # seek forward
|
||||
prev = seg_f_start
|
||||
local_mel = mel[:, :local_f_end-local_f_start]
|
||||
result = transcribe(model, audio, mel=local_mel, verbose=verbose, **kwargs)
|
||||
seg_t["text"] = result["text"]
|
||||
output["segments"].append(
|
||||
{
|
||||
"start": seg_t["start"],
|
||||
"end": seg_t["end"],
|
||||
"language": result["language"],
|
||||
"text": result["text"],
|
||||
"seg-text": [x["text"] for x in result["segments"]],
|
||||
"seg-start": [x["start"] for x in result["segments"]],
|
||||
"seg-end": [x["end"] for x in result["segments"]],
|
||||
}
|
||||
)
|
||||
|
||||
output["language"] = output["segments"][0]["language"]
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def transcribe_with_vad_parallel(
|
||||
model: "Whisper",
|
||||
audio: Union[str, np.ndarray, torch.Tensor],
|
||||
vad_pipeline,
|
||||
mel = None,
|
||||
verbose: Optional[bool] = None,
|
||||
batch_size = -1,
|
||||
**kwargs
|
||||
):
|
||||
"""
|
||||
Transcribe per VAD segment
|
||||
"""
|
||||
|
||||
if mel is None:
|
||||
mel = log_mel_spectrogram(audio)
|
||||
|
||||
vad_segments = vad_pipeline(audio)
|
||||
# merge segments to approx 30s inputs to make whisper most appropraite
|
||||
vad_segments = merge_chunks(vad_segments)
|
||||
|
||||
################################
|
||||
### START of parallelization ###
|
||||
################################
|
||||
|
||||
# pad mel to a same length
|
||||
start_seconds = [i['start'] for i in vad_segments]
|
||||
end_seconds = [i['end'] for i in vad_segments]
|
||||
duration_list = np.array(end_seconds) - np.array(start_seconds)
|
||||
max_length = round(30 / (HOP_LENGTH / SAMPLE_RATE))
|
||||
offset_list = np.array(start_seconds)
|
||||
chunks = []
|
||||
|
||||
for start_ts, end_ts in zip(start_seconds, end_seconds):
|
||||
start_ts = round(start_ts / (HOP_LENGTH / SAMPLE_RATE))
|
||||
end_ts = round(end_ts / (HOP_LENGTH / SAMPLE_RATE))
|
||||
chunk = mel[:, start_ts:end_ts]
|
||||
chunk = torch.nn.functional.pad(chunk, (0, max_length-chunk.shape[-1]))
|
||||
chunks.append(chunk)
|
||||
|
||||
mel_chunk = torch.stack(chunks, dim=0).to(model.device)
|
||||
# using 'decode_options1': only support single temperature decoding (no fallbacks)
|
||||
# result_list2 = model.decode(mel_chunk, decode_options1)
|
||||
|
||||
# prepare DecodingOptions
|
||||
temperatures = kwargs.pop("temperature", None)
|
||||
compression_ratio_threshold = kwargs.pop("compression_ratio_threshold", None)
|
||||
logprob_threshold = kwargs.pop("logprob_threshold", None)
|
||||
no_speech_threshold = kwargs.pop("no_speech_threshold", None)
|
||||
condition_on_previous_text = kwargs.pop("condition_on_previous_text", None)
|
||||
initial_prompt = kwargs.pop("initial_prompt", None)
|
||||
|
||||
t = 0 # TODO: does not upport temperature sweeping
|
||||
if t > 0:
|
||||
# disable beam_size and patience when t > 0
|
||||
kwargs.pop("beam_size", None)
|
||||
kwargs.pop("patience", None)
|
||||
else:
|
||||
# disable best_of when t == 0
|
||||
kwargs.pop("best_of", None)
|
||||
|
||||
options = DecodingOptions(**kwargs, temperature=t)
|
||||
mel_chunk_batches = torch.split(mel_chunk, split_size_or_sections=batch_size)
|
||||
decode_result = []
|
||||
for mel_chunk_batch in mel_chunk_batches:
|
||||
decode_result.extend(model.decode(mel_chunk_batch, options))
|
||||
|
||||
##############################
|
||||
### END of parallelization ###
|
||||
##############################
|
||||
|
||||
# post processing: get segments rfom batch-decoded results
|
||||
input_stride = exact_div(
|
||||
N_FRAMES, model.dims.n_audio_ctx
|
||||
) # mel frames per output token: 2
|
||||
language = kwargs["language"]
|
||||
task = kwargs["task"]
|
||||
tokenizer = get_tokenizer(model.is_multilingual, language=language, task=task)
|
||||
|
||||
output = post_process_results(
|
||||
vad_segments,
|
||||
decode_result,
|
||||
duration_list,
|
||||
offset_list,
|
||||
input_stride,
|
||||
language,
|
||||
tokenizer,
|
||||
no_speech_threshold=no_speech_threshold,
|
||||
logprob_threshold=logprob_threshold,
|
||||
verbose=verbose)
|
||||
return output
|
||||
|
||||
|
||||
def post_process_results(
|
||||
vad_segments,
|
||||
result_list,
|
||||
duration_list,
|
||||
offset_list,
|
||||
input_stride,
|
||||
language,
|
||||
tokenizer,
|
||||
no_speech_threshold = None,
|
||||
logprob_threshold = None,
|
||||
verbose: Optional[bool] = None,
|
||||
):
|
||||
|
||||
seek = 0
|
||||
time_precision = (
|
||||
input_stride * HOP_LENGTH / SAMPLE_RATE
|
||||
) # time per output token: 0.02 (seconds)
|
||||
all_tokens = []
|
||||
all_segments = []
|
||||
output = {"segments": []}
|
||||
|
||||
def add_segment(
|
||||
*, start: float, end: float, text_tokens: torch.Tensor, result: DecodingResult
|
||||
):
|
||||
text = tokenizer.decode([token for token in text_tokens if token < tokenizer.eot])
|
||||
if len(text.strip()) == 0: # skip empty text output
|
||||
return
|
||||
|
||||
all_segments.append(
|
||||
{
|
||||
"id": len(all_segments),
|
||||
"seek": seek,
|
||||
"start": start,
|
||||
"end": end,
|
||||
"text": text,
|
||||
"tokens": text_tokens.tolist(),
|
||||
"temperature": result.temperature,
|
||||
"avg_logprob": result.avg_logprob,
|
||||
"compression_ratio": result.compression_ratio,
|
||||
"no_speech_prob": result.no_speech_prob,
|
||||
}
|
||||
)
|
||||
if verbose:
|
||||
print(f"[{format_timestamp(start)} --> {format_timestamp(end)}] {text}")
|
||||
|
||||
# process the output
|
||||
for seg_t, result, segment_duration, timestamp_offset in zip(vad_segments, result_list, duration_list, offset_list):
|
||||
all_tokens = []
|
||||
all_segments = []
|
||||
|
||||
# segment_duration = segment.shape[-1] * HOP_LENGTH / SAMPLE_RATE
|
||||
segment_shape = int(segment_duration / (HOP_LENGTH / SAMPLE_RATE))
|
||||
tokens = torch.tensor(result.tokens)
|
||||
|
||||
if no_speech_threshold is not None:
|
||||
# no voice activity check
|
||||
should_skip = result.no_speech_prob > no_speech_threshold
|
||||
if logprob_threshold is not None and result.avg_logprob > logprob_threshold:
|
||||
# don't skip if the logprob is high enough, despite the no_speech_prob
|
||||
should_skip = False
|
||||
|
||||
if should_skip:
|
||||
seek += segment_shape # fast-forward to the next segment boundary
|
||||
continue
|
||||
|
||||
timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin)
|
||||
consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0].add_(1)
|
||||
|
||||
if len(consecutive) > 0: # if the output contains two consecutive timestamp tokens
|
||||
last_slice = 0
|
||||
for current_slice in consecutive:
|
||||
sliced_tokens = tokens[last_slice:current_slice]
|
||||
start_timestamp_position = (
|
||||
sliced_tokens[0].item() - tokenizer.timestamp_begin
|
||||
)
|
||||
end_timestamp_position = (
|
||||
sliced_tokens[-1].item() - tokenizer.timestamp_begin
|
||||
)
|
||||
add_segment(
|
||||
start=timestamp_offset + start_timestamp_position * time_precision,
|
||||
end=timestamp_offset + end_timestamp_position * time_precision,
|
||||
text_tokens=sliced_tokens[1:-1],
|
||||
result=result,
|
||||
)
|
||||
last_slice = current_slice
|
||||
last_timestamp_position = (
|
||||
tokens[last_slice - 1].item() - tokenizer.timestamp_begin
|
||||
)
|
||||
seek += last_timestamp_position * input_stride
|
||||
all_tokens.extend(tokens[: last_slice + 1].tolist())
|
||||
else:
|
||||
duration = segment_duration
|
||||
timestamps = tokens[timestamp_tokens.nonzero().flatten()]
|
||||
if len(timestamps) > 0 and timestamps[-1].item() != tokenizer.timestamp_begin:
|
||||
# no consecutive timestamps but it has a timestamp; use the last one.
|
||||
# single timestamp at the end means no speech after the last timestamp.
|
||||
last_timestamp_position = timestamps[-1].item() - tokenizer.timestamp_begin
|
||||
duration = last_timestamp_position * time_precision
|
||||
|
||||
add_segment(
|
||||
start=timestamp_offset,
|
||||
end=timestamp_offset + duration,
|
||||
text_tokens=tokens,
|
||||
result=result,
|
||||
)
|
||||
|
||||
seek += segment_shape
|
||||
all_tokens.extend(tokens.tolist())
|
||||
|
||||
result = dict(text=tokenizer.decode(all_tokens), segments=all_segments, language=language)
|
||||
output["segments"].append(
|
||||
{
|
||||
"start": seg_t["start"],
|
||||
"end": seg_t["end"],
|
||||
"language": result["language"],
|
||||
"text": result["text"],
|
||||
"seg-text": [x["text"] for x in result["segments"]],
|
||||
"seg-start": [x["start"] for x in result["segments"]],
|
||||
"seg-end": [x["end"] for x in result["segments"]],
|
||||
}
|
||||
)
|
||||
|
||||
output["language"] = output["segments"][0]["language"]
|
||||
|
||||
return output
|
||||
from .alignment import align, load_align_model
|
||||
from .asr import load_model
|
||||
from .audio import load_audio
|
||||
from .diarize import DiarizationPipeline, assign_word_speakers
|
||||
from .utils import (LANGUAGES, TO_LANGUAGE_CODE, get_writer, optional_float,
|
||||
optional_int, str2bool)
|
||||
|
||||
|
||||
def cli():
|
||||
from . import available_models
|
||||
|
||||
# fmt: off
|
||||
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||
parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe")
|
||||
parser.add_argument("--model", default="small", choices=available_models(), help="name of the Whisper model to use")
|
||||
parser.add_argument("--model", default="small", help="name of the Whisper model to use")
|
||||
parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
|
||||
parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
|
||||
# alignment params
|
||||
parser.add_argument("--align_model", default=None, help="Name of phoneme-level ASR model to do alignment")
|
||||
parser.add_argument("--align_extend", default=2, type=float, help="Seconds before and after to extend the whisper segments for alignment")
|
||||
parser.add_argument("--align_from_prev", default=True, type=bool, help="Whether to clip the alignment start time of current segment to the end time of the last aligned word of the previous segment")
|
||||
parser.add_argument("--interpolate_method", default="nearest", choices=["nearest", "linear", "ignore"], help="For word .srt, method to assign timestamps to non-aligned words, or merge them into neighbouring.")
|
||||
# vad params
|
||||
parser.add_argument("--vad_filter", action="store_true", help="Whether to first perform VAD filtering to target only transcribe within VAD. Produces more accurate alignment + timestamp, requires more GPU memory & compute.")
|
||||
parser.add_argument("--parallel_bs", default=-1, type=int, help="Enable parallel transcribing if > 1")
|
||||
# diarization params
|
||||
parser.add_argument("--diarize", action="store_true", help="Apply diarization to assign speaker labels to each segment/word")
|
||||
parser.add_argument("--min_speakers", default=None, type=int)
|
||||
parser.add_argument("--max_speakers", default=None, type=int)
|
||||
# output save params
|
||||
parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
|
||||
parser.add_argument("--output_type", default="all", choices=["all", "srt", "srt-word", "vtt", "txt", "tsv", "ass", "ass-char", "pickle", "vad"], help="File type for desired output save")
|
||||
parser.add_argument("--batch_size", default=8, type=int, help="device to use for PyTorch inference")
|
||||
parser.add_argument("--compute_type", default="float16", type=str, choices=["float16", "float32", "int8"], help="compute type for computation")
|
||||
|
||||
parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
|
||||
parser.add_argument("--output_format", "-f", type=str, default="all", choices=["all", "srt", "vtt", "txt", "tsv", "json"], help="format of the output file; if not specified, all available formats will be produced")
|
||||
parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")
|
||||
|
||||
parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')")
|
||||
parser.add_argument("--language", type=str, default=None, choices=sorted(LANGUAGES.keys()) + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]), help="language spoken in the audio, specify None to perform language detection")
|
||||
|
||||
# alignment params
|
||||
parser.add_argument("--align_model", default=None, help="Name of phoneme-level ASR model to do alignment")
|
||||
parser.add_argument("--interpolate_method", default="nearest", choices=["nearest", "linear", "ignore"], help="For word .srt, method to assign timestamps to non-aligned words, or merge them into neighbouring.")
|
||||
parser.add_argument("--no_align", action='store_true', help="Do not perform phoneme alignment")
|
||||
parser.add_argument("--return_char_alignments", action='store_true', help="Return character-level alignments in the output json file")
|
||||
|
||||
# vad params
|
||||
parser.add_argument("--vad_onset", type=float, default=0.500, help="Onset threshold for VAD (see pyannote.audio), reduce this if speech is not being detected")
|
||||
parser.add_argument("--vad_offset", type=float, default=0.363, help="Offset threshold for VAD (see pyannote.audio), reduce this if speech is not being detected.")
|
||||
|
||||
# diarization params
|
||||
parser.add_argument("--diarize", action="store_true", help="Apply diarization to assign speaker labels to each segment/word")
|
||||
parser.add_argument("--min_speakers", default=None, type=int, help="Minimum number of speakers to in audio file")
|
||||
parser.add_argument("--max_speakers", default=None, type=int, help="Maximum number of speakers to in audio file")
|
||||
|
||||
parser.add_argument("--temperature", type=float, default=0, help="temperature to use for sampling")
|
||||
parser.add_argument("--best_of", type=optional_int, default=5, help="number of candidates when sampling with non-zero temperature")
|
||||
parser.add_argument("--beam_size", type=optional_int, default=5, help="number of beams in beam search, only applicable when temperature is zero")
|
||||
@ -612,152 +61,142 @@ def cli():
|
||||
parser.add_argument("--compression_ratio_threshold", type=optional_float, default=2.4, help="if the gzip compression ratio is higher than this value, treat the decoding as failed")
|
||||
parser.add_argument("--logprob_threshold", type=optional_float, default=-1.0, help="if the average log probability is lower than this value, treat the decoding as failed")
|
||||
parser.add_argument("--no_speech_threshold", type=optional_float, default=0.6, help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence")
|
||||
|
||||
parser.add_argument("--max_line_width", type=optional_int, default=None, help="(not possible with --no_align) the maximum number of characters in a line before breaking the line")
|
||||
parser.add_argument("--max_line_count", type=optional_int, default=None, help="(requires --no_align) the maximum number of lines in a segment")
|
||||
parser.add_argument("--highlight_words", type=str2bool, default=False, help="(requires --word_timestamps True) underline each word as it is spoken in srt and vtt")
|
||||
parser.add_argument("--segment_resolution", type=str, default="sentence", choices=["sentence", "chunk"], help="(not possible with --no_align) the maximum number of characters in a line before breaking the line")
|
||||
|
||||
parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
|
||||
|
||||
parser.add_argument("--hf_token", type=str, default=None, help="Hugging Face Access Token to access PyAnnote gated models")
|
||||
|
||||
# fmt: on
|
||||
|
||||
args = parser.parse_args().__dict__
|
||||
model_name: str = args.pop("model")
|
||||
model_dir: str = args.pop("model_dir")
|
||||
batch_size: int = args.pop("batch_size")
|
||||
output_dir: str = args.pop("output_dir")
|
||||
output_type: str = args.pop("output_type")
|
||||
output_format: str = args.pop("output_format")
|
||||
device: str = args.pop("device")
|
||||
compute_type: str = args.pop("compute_type")
|
||||
|
||||
# model_flush: bool = args.pop("model_flush")
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
align_model: str = args.pop("align_model")
|
||||
align_extend: float = args.pop("align_extend")
|
||||
align_from_prev: bool = args.pop("align_from_prev")
|
||||
interpolate_method: bool = args.pop("interpolate_method")
|
||||
|
||||
interpolate_method: str = args.pop("interpolate_method")
|
||||
no_align: bool = args.pop("no_align")
|
||||
return_char_alignments: bool = args.pop("return_char_alignments")
|
||||
|
||||
hf_token: str = args.pop("hf_token")
|
||||
vad_filter: bool = args.pop("vad_filter")
|
||||
parallel_bs: int = args.pop("parallel_bs")
|
||||
vad_onset: float = args.pop("vad_onset")
|
||||
vad_offset: float = args.pop("vad_offset")
|
||||
|
||||
diarize: bool = args.pop("diarize")
|
||||
min_speakers: int = args.pop("min_speakers")
|
||||
max_speakers: int = args.pop("max_speakers")
|
||||
|
||||
vad_pipeline = None
|
||||
if vad_filter:
|
||||
if hf_token is None:
|
||||
print("Warning, no huggingface token used, needs to be saved in environment variable, otherwise will throw error loading VAD model...")
|
||||
from pyannote.audio import Inference
|
||||
vad_pipeline = Inference(
|
||||
"pyannote/segmentation",
|
||||
pre_aggregation_hook=lambda segmentation: segmentation,
|
||||
use_auth_token=hf_token,
|
||||
device=torch.device(device),
|
||||
)
|
||||
|
||||
diarize_pipeline = None
|
||||
if diarize:
|
||||
if hf_token is None:
|
||||
print("Warning, no --hf_token used, needs to be saved in environment variable, otherwise will throw error loading diarization model...")
|
||||
from pyannote.audio import Pipeline
|
||||
diarize_pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization@2.1",
|
||||
use_auth_token=hf_token)
|
||||
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
if model_name.endswith(".en") and args["language"] not in {"en", "English"}:
|
||||
if args["language"] is not None:
|
||||
warnings.warn(f'{model_name} is an English-only model but receipted "{args["language"]}"; using English instead.')
|
||||
warnings.warn(
|
||||
f"{model_name} is an English-only model but receipted '{args['language']}'; using English instead."
|
||||
)
|
||||
args["language"] = "en"
|
||||
|
||||
temperature = args.pop("temperature")
|
||||
temperature_increment_on_fallback = args.pop("temperature_increment_on_fallback")
|
||||
if temperature_increment_on_fallback is not None:
|
||||
temperature = tuple(np.arange(temperature, 1.0 + 1e-6, temperature_increment_on_fallback))
|
||||
if (increment := args.pop("temperature_increment_on_fallback")) is not None:
|
||||
temperature = tuple(np.arange(temperature, 1.0 + 1e-6, increment))
|
||||
else:
|
||||
temperature = [temperature]
|
||||
|
||||
threads = args.pop("threads")
|
||||
if threads > 0:
|
||||
if (threads := args.pop("threads")) > 0:
|
||||
torch.set_num_threads(threads)
|
||||
|
||||
from . import load_model
|
||||
model = load_model(model_name, device=device, download_root=model_dir)
|
||||
asr_options = {
|
||||
"beam_size": args.pop("beam_size"),
|
||||
"patience": args.pop("patience"),
|
||||
"length_penalty": args.pop("length_penalty"),
|
||||
"temperatures": temperature,
|
||||
"compression_ratio_threshold": args.pop("compression_ratio_threshold"),
|
||||
"log_prob_threshold": args.pop("logprob_threshold"),
|
||||
"no_speech_threshold": args.pop("no_speech_threshold"),
|
||||
"condition_on_previous_text": False,
|
||||
"initial_prompt": args.pop("initial_prompt"),
|
||||
}
|
||||
|
||||
align_language = args["language"] if args["language"] is not None else "en" # default to loading english if not specified
|
||||
align_model, align_metadata = load_align_model(align_language, device, model_name=align_model)
|
||||
writer = get_writer(output_format, output_dir)
|
||||
word_options = ["highlight_words", "max_line_count", "max_line_width"]
|
||||
if no_align:
|
||||
for option in word_options:
|
||||
if args[option]:
|
||||
parser.error(f"--{option} requires --word_timestamps True")
|
||||
if args["max_line_count"] and not args["max_line_width"]:
|
||||
warnings.warn("--max_line_count has no effect without --max_line_width")
|
||||
writer_args = {arg: args.pop(arg) for arg in word_options}
|
||||
|
||||
# Part 1: VAD & ASR Loop
|
||||
results = []
|
||||
tmp_results = []
|
||||
# model = load_model(model_name, device=device, download_root=model_dir)
|
||||
model = load_model(model_name, device=device, compute_type=compute_type, language=args['language'], asr_options=asr_options, vad_options={"vad_onset": vad_onset, "vad_offset": vad_offset},)
|
||||
|
||||
for audio_path in args.pop("audio"):
|
||||
if vad_filter:
|
||||
if parallel_bs > 1:
|
||||
print("Performing VAD and parallel transcribing ...")
|
||||
result = transcribe_with_vad_parallel(model, audio_path, vad_pipeline, temperature=temperature, batch_size=parallel_bs, **args)
|
||||
audio = load_audio(audio_path)
|
||||
# >> VAD & ASR
|
||||
print(">>Performing transcription...")
|
||||
result = model.transcribe(audio, batch_size=batch_size)
|
||||
results.append((result, audio_path))
|
||||
|
||||
# Unload Whisper and VAD
|
||||
del model
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# Part 2: Align Loop
|
||||
if not no_align:
|
||||
tmp_results = results
|
||||
results = []
|
||||
align_language = args["language"] if args["language"] is not None else "en" # default to loading english if not specified
|
||||
align_model, align_metadata = load_align_model(align_language, device, model_name=align_model)
|
||||
for result, audio_path in tmp_results:
|
||||
# >> Align
|
||||
if len(tmp_results) > 1:
|
||||
input_audio = audio_path
|
||||
else:
|
||||
print("Performing VAD...")
|
||||
result = transcribe_with_vad(model, audio_path, vad_pipeline, temperature=temperature, **args)
|
||||
else:
|
||||
print("Performing transcription...")
|
||||
result = transcribe(model, audio_path, temperature=temperature, **args)
|
||||
# lazily load audio from part 1
|
||||
input_audio = audio
|
||||
|
||||
if result["language"] != align_metadata["language"]:
|
||||
# load new language
|
||||
print(f"New language found ({result['language']})! Previous was ({align_metadata['language']}), loading new alignment model for new language...")
|
||||
align_model, align_metadata = load_align_model(result["language"], device)
|
||||
if align_model is not None and len(result["segments"]) > 0:
|
||||
if result.get("language", "en") != align_metadata["language"]:
|
||||
# load new language
|
||||
print(f"New language found ({result['language']})! Previous was ({align_metadata['language']}), loading new alignment model for new language...")
|
||||
align_model, align_metadata = load_align_model(result["language"], device)
|
||||
print(">>Performing alignment...")
|
||||
result = align(result["segments"], align_model, align_metadata, input_audio, device, interpolate_method=interpolate_method, return_char_alignments=return_char_alignments)
|
||||
|
||||
results.append((result, audio_path))
|
||||
|
||||
print("Performing alignment...")
|
||||
result_aligned = align(result["segments"], align_model, align_metadata, audio_path, device,
|
||||
extend_duration=align_extend, start_from_previous=align_from_prev, interpolate_method=interpolate_method)
|
||||
audio_basename = os.path.basename(audio_path)
|
||||
# Unload align model
|
||||
del align_model
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
if diarize:
|
||||
print("Performing diarization...")
|
||||
diarize_segments = diarize_pipeline(audio_path, min_speakers=min_speakers, max_speakers=max_speakers)
|
||||
diarize_df = pd.DataFrame(diarize_segments.itertracks(yield_label=True))
|
||||
diarize_df['start'] = diarize_df[0].apply(lambda x: x.start)
|
||||
diarize_df['end'] = diarize_df[0].apply(lambda x: x.end)
|
||||
# assumes each utterance is single speaker (needs fix)
|
||||
result_segments, word_segments = assign_word_speakers(diarize_df, result_aligned["segments"], fill_nearest=True)
|
||||
result_aligned["segments"] = result_segments
|
||||
result_aligned["word_segments"] = word_segments
|
||||
# >> Diarize
|
||||
if diarize:
|
||||
if hf_token is None:
|
||||
print("Warning, no --hf_token used, needs to be saved in environment variable, otherwise will throw error loading diarization model...")
|
||||
tmp_results = results
|
||||
print(">>Performing diarization...")
|
||||
results = []
|
||||
diarize_model = DiarizationPipeline(use_auth_token=hf_token, device=device)
|
||||
for result, input_audio_path in tmp_results:
|
||||
diarize_segments = diarize_model(input_audio_path, min_speakers=min_speakers, max_speakers=max_speakers)
|
||||
result = assign_word_speakers(diarize_segments, result)
|
||||
results.append((result, input_audio_path))
|
||||
# >> Write
|
||||
for result, audio_path in results:
|
||||
writer(result, audio_path, writer_args)
|
||||
|
||||
# save TXT
|
||||
if output_type in ["txt", "all"]:
|
||||
with open(os.path.join(output_dir, audio_basename + ".txt"), "w", encoding="utf-8") as txt:
|
||||
write_txt(result_aligned["segments"], file=txt)
|
||||
|
||||
# save VTT
|
||||
if output_type in ["vtt", "all"]:
|
||||
with open(os.path.join(output_dir, audio_basename + ".vtt"), "w", encoding="utf-8") as vtt:
|
||||
write_vtt(result_aligned["segments"], file=vtt)
|
||||
|
||||
# save SRT
|
||||
if output_type in ["srt", "all"]:
|
||||
with open(os.path.join(output_dir, audio_basename + ".srt"), "w", encoding="utf-8") as srt:
|
||||
write_srt(result_aligned["segments"], file=srt)
|
||||
|
||||
# save TSV
|
||||
if output_type in ["tsv", "all"]:
|
||||
with open(os.path.join(output_dir, audio_basename + ".tsv"), "w", encoding="utf-8") as srt:
|
||||
write_tsv(result_aligned["segments"], file=srt)
|
||||
|
||||
# save SRT word-level
|
||||
if output_type in ["srt-word", "all"]:
|
||||
# save per-word SRT
|
||||
with open(os.path.join(output_dir, audio_basename + ".word.srt"), "w", encoding="utf-8") as srt:
|
||||
write_srt(result_aligned["word_segments"], file=srt)
|
||||
|
||||
# save ASS
|
||||
if output_type in ["ass", "all"]:
|
||||
with open(os.path.join(output_dir, audio_basename + ".ass"), "w", encoding="utf-8") as ass:
|
||||
write_ass(result_aligned["segments"], file=ass)
|
||||
|
||||
# # save ASS character-level
|
||||
if output_type in ["ass-char"]:
|
||||
with open(os.path.join(output_dir, audio_basename + ".char.ass"), "w", encoding="utf-8") as ass:
|
||||
write_ass(result_aligned["segments"], file=ass, resolution="char")
|
||||
|
||||
# save word tsv
|
||||
if output_type in ["pickle"]:
|
||||
exp_fp = os.path.join(output_dir, audio_basename + ".pkl")
|
||||
pd.DataFrame(result_aligned["segments"]).to_pickle(exp_fp)
|
||||
|
||||
# save word tsv
|
||||
if output_type in ["vad"]:
|
||||
exp_fp = os.path.join(output_dir, audio_basename + ".sad")
|
||||
wrd_segs = pd.concat([x["word-segments"] for x in result_aligned["segments"]])[['start','end']]
|
||||
wrd_segs.to_csv(exp_fp, sep='\t', header=None, index=False)
|
||||
if __name__ == "__main__":
|
||||
cli()
|
||||
cli()
|
58
whisperx/types.py
Normal file
58
whisperx/types.py
Normal file
@ -0,0 +1,58 @@
|
||||
from typing import TypedDict, Optional
|
||||
|
||||
|
||||
class SingleWordSegment(TypedDict):
|
||||
"""
|
||||
A single word of a speech.
|
||||
"""
|
||||
word: str
|
||||
start: float
|
||||
end: float
|
||||
score: float
|
||||
|
||||
class SingleCharSegment(TypedDict):
|
||||
"""
|
||||
A single char of a speech.
|
||||
"""
|
||||
char: str
|
||||
start: float
|
||||
end: float
|
||||
score: float
|
||||
|
||||
|
||||
class SingleSegment(TypedDict):
|
||||
"""
|
||||
A single segment (up to multiple sentences) of a speech.
|
||||
"""
|
||||
|
||||
start: float
|
||||
end: float
|
||||
text: str
|
||||
|
||||
|
||||
class SingleAlignedSegment(TypedDict):
|
||||
"""
|
||||
A single segment (up to multiple sentences) of a speech with word alignment.
|
||||
"""
|
||||
|
||||
start: float
|
||||
end: float
|
||||
text: str
|
||||
words: list[SingleWordSegment]
|
||||
chars: Optional[list[SingleCharSegment]]
|
||||
|
||||
|
||||
class TranscriptionResult(TypedDict):
|
||||
"""
|
||||
A list of segments and word segments of a speech.
|
||||
"""
|
||||
segments: list[SingleSegment]
|
||||
language: str
|
||||
|
||||
|
||||
class AlignedTranscriptionResult(TypedDict):
|
||||
"""
|
||||
A list of segments and word segments of a speech.
|
||||
"""
|
||||
segments: list[SingleAlignedSegment]
|
||||
word_segments: list[SingleWordSegment]
|
@ -1,8 +1,144 @@
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import zlib
|
||||
from typing import Callable, TextIO, Iterator, Tuple
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from typing import Callable, Optional, TextIO
|
||||
|
||||
LANGUAGES = {
|
||||
"en": "english",
|
||||
"zh": "chinese",
|
||||
"de": "german",
|
||||
"es": "spanish",
|
||||
"ru": "russian",
|
||||
"ko": "korean",
|
||||
"fr": "french",
|
||||
"ja": "japanese",
|
||||
"pt": "portuguese",
|
||||
"tr": "turkish",
|
||||
"pl": "polish",
|
||||
"ca": "catalan",
|
||||
"nl": "dutch",
|
||||
"ar": "arabic",
|
||||
"sv": "swedish",
|
||||
"it": "italian",
|
||||
"id": "indonesian",
|
||||
"hi": "hindi",
|
||||
"fi": "finnish",
|
||||
"vi": "vietnamese",
|
||||
"he": "hebrew",
|
||||
"uk": "ukrainian",
|
||||
"el": "greek",
|
||||
"ms": "malay",
|
||||
"cs": "czech",
|
||||
"ro": "romanian",
|
||||
"da": "danish",
|
||||
"hu": "hungarian",
|
||||
"ta": "tamil",
|
||||
"no": "norwegian",
|
||||
"th": "thai",
|
||||
"ur": "urdu",
|
||||
"hr": "croatian",
|
||||
"bg": "bulgarian",
|
||||
"lt": "lithuanian",
|
||||
"la": "latin",
|
||||
"mi": "maori",
|
||||
"ml": "malayalam",
|
||||
"cy": "welsh",
|
||||
"sk": "slovak",
|
||||
"te": "telugu",
|
||||
"fa": "persian",
|
||||
"lv": "latvian",
|
||||
"bn": "bengali",
|
||||
"sr": "serbian",
|
||||
"az": "azerbaijani",
|
||||
"sl": "slovenian",
|
||||
"kn": "kannada",
|
||||
"et": "estonian",
|
||||
"mk": "macedonian",
|
||||
"br": "breton",
|
||||
"eu": "basque",
|
||||
"is": "icelandic",
|
||||
"hy": "armenian",
|
||||
"ne": "nepali",
|
||||
"mn": "mongolian",
|
||||
"bs": "bosnian",
|
||||
"kk": "kazakh",
|
||||
"sq": "albanian",
|
||||
"sw": "swahili",
|
||||
"gl": "galician",
|
||||
"mr": "marathi",
|
||||
"pa": "punjabi",
|
||||
"si": "sinhala",
|
||||
"km": "khmer",
|
||||
"sn": "shona",
|
||||
"yo": "yoruba",
|
||||
"so": "somali",
|
||||
"af": "afrikaans",
|
||||
"oc": "occitan",
|
||||
"ka": "georgian",
|
||||
"be": "belarusian",
|
||||
"tg": "tajik",
|
||||
"sd": "sindhi",
|
||||
"gu": "gujarati",
|
||||
"am": "amharic",
|
||||
"yi": "yiddish",
|
||||
"lo": "lao",
|
||||
"uz": "uzbek",
|
||||
"fo": "faroese",
|
||||
"ht": "haitian creole",
|
||||
"ps": "pashto",
|
||||
"tk": "turkmen",
|
||||
"nn": "nynorsk",
|
||||
"mt": "maltese",
|
||||
"sa": "sanskrit",
|
||||
"lb": "luxembourgish",
|
||||
"my": "myanmar",
|
||||
"bo": "tibetan",
|
||||
"tl": "tagalog",
|
||||
"mg": "malagasy",
|
||||
"as": "assamese",
|
||||
"tt": "tatar",
|
||||
"haw": "hawaiian",
|
||||
"ln": "lingala",
|
||||
"ha": "hausa",
|
||||
"ba": "bashkir",
|
||||
"jw": "javanese",
|
||||
"su": "sundanese",
|
||||
}
|
||||
|
||||
# language code lookup by name, with a few language aliases
|
||||
TO_LANGUAGE_CODE = {
|
||||
**{language: code for code, language in LANGUAGES.items()},
|
||||
"burmese": "my",
|
||||
"valencian": "ca",
|
||||
"flemish": "nl",
|
||||
"haitian": "ht",
|
||||
"letzeburgesch": "lb",
|
||||
"pushto": "ps",
|
||||
"panjabi": "pa",
|
||||
"moldavian": "ro",
|
||||
"moldovan": "ro",
|
||||
"sinhalese": "si",
|
||||
"castilian": "es",
|
||||
}
|
||||
|
||||
|
||||
system_encoding = sys.getdefaultencoding()
|
||||
|
||||
if system_encoding != "utf-8":
|
||||
|
||||
def make_safe(string):
|
||||
# replaces any character not representable using the system default encoding with an '?',
|
||||
# avoiding UnicodeEncodeError (https://github.com/openai/whisper/discussions/729).
|
||||
return string.encode(system_encoding, errors="replace").decode(system_encoding)
|
||||
|
||||
else:
|
||||
|
||||
def make_safe(string):
|
||||
# utf-8 can encode any Unicode code point, so no need to do the round-trip encoding
|
||||
return string
|
||||
|
||||
|
||||
def exact_div(x, y):
|
||||
assert x % y == 0
|
||||
@ -30,7 +166,9 @@ def compression_ratio(text) -> float:
|
||||
return len(text_bytes) / len(zlib.compress(text_bytes))
|
||||
|
||||
|
||||
def format_timestamp(seconds: float, always_include_hours: bool = False, decimal_marker: str = '.'):
|
||||
def format_timestamp(
|
||||
seconds: float, always_include_hours: bool = False, decimal_marker: str = "."
|
||||
):
|
||||
assert seconds >= 0, "non-negative timestamp expected"
|
||||
milliseconds = round(seconds * 1000.0)
|
||||
|
||||
@ -44,211 +182,218 @@ def format_timestamp(seconds: float, always_include_hours: bool = False, decimal
|
||||
milliseconds -= seconds * 1_000
|
||||
|
||||
hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
|
||||
return f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
|
||||
return (
|
||||
f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
|
||||
)
|
||||
|
||||
|
||||
def write_txt(transcript: Iterator[dict], file: TextIO):
|
||||
for segment in transcript:
|
||||
print(segment['text'].strip(), file=file, flush=True)
|
||||
class ResultWriter:
|
||||
extension: str
|
||||
|
||||
def __init__(self, output_dir: str):
|
||||
self.output_dir = output_dir
|
||||
|
||||
def write_vtt(transcript: Iterator[dict], file: TextIO):
|
||||
print("WEBVTT\n", file=file)
|
||||
for segment in transcript:
|
||||
print(
|
||||
f"{format_timestamp(segment['start'])} --> {format_timestamp(segment['end'])}\n"
|
||||
f"{segment['text'].strip().replace('-->', '->')}\n",
|
||||
file=file,
|
||||
flush=True,
|
||||
def __call__(self, result: dict, audio_path: str, options: dict):
|
||||
audio_basename = os.path.basename(audio_path)
|
||||
audio_basename = os.path.splitext(audio_basename)[0]
|
||||
output_path = os.path.join(
|
||||
self.output_dir, audio_basename + "." + self.extension
|
||||
)
|
||||
|
||||
def write_tsv(transcript: Iterator[dict], file: TextIO):
|
||||
print("start", "end", "text", sep="\t", file=file)
|
||||
for segment in transcript:
|
||||
print(segment['start'], file=file, end="\t")
|
||||
print(segment['end'], file=file, end="\t")
|
||||
print(segment['text'].strip().replace("\t", " "), file=file, flush=True)
|
||||
with open(output_path, "w", encoding="utf-8") as f:
|
||||
self.write_result(result, file=f, options=options)
|
||||
|
||||
def write_result(self, result: dict, file: TextIO, options: dict):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def write_srt(transcript: Iterator[dict], file: TextIO):
|
||||
"""
|
||||
Write a transcript to a file in SRT format.
|
||||
class WriteTXT(ResultWriter):
|
||||
extension: str = "txt"
|
||||
|
||||
Example usage:
|
||||
from pathlib import Path
|
||||
from whisper.utils import write_srt
|
||||
def write_result(self, result: dict, file: TextIO, options: dict):
|
||||
for segment in result["segments"]:
|
||||
print(segment["text"].strip(), file=file, flush=True)
|
||||
|
||||
result = transcribe(model, audio_path, temperature=temperature, **args)
|
||||
|
||||
# save SRT
|
||||
audio_basename = Path(audio_path).stem
|
||||
with open(Path(output_dir) / (audio_basename + ".srt"), "w", encoding="utf-8") as srt:
|
||||
write_srt(result["segments"], file=srt)
|
||||
"""
|
||||
for i, segment in enumerate(transcript, start=1):
|
||||
# write srt lines
|
||||
print(
|
||||
f"{i}\n"
|
||||
f"{format_timestamp(segment['start'], always_include_hours=True, decimal_marker=',')} --> "
|
||||
f"{format_timestamp(segment['end'], always_include_hours=True, decimal_marker=',')}\n"
|
||||
f"{segment['text'].strip().replace('-->', '->')}\n",
|
||||
file=file,
|
||||
flush=True,
|
||||
class SubtitlesWriter(ResultWriter):
|
||||
always_include_hours: bool
|
||||
decimal_marker: str
|
||||
|
||||
def iterate_result(self, result: dict, options: dict):
|
||||
raw_max_line_width: Optional[int] = options["max_line_width"]
|
||||
max_line_count: Optional[int] = options["max_line_count"]
|
||||
highlight_words: bool = options["highlight_words"]
|
||||
max_line_width = 1000 if raw_max_line_width is None else raw_max_line_width
|
||||
preserve_segments = max_line_count is None or raw_max_line_width is None
|
||||
|
||||
def iterate_subtitles():
|
||||
line_len = 0
|
||||
line_count = 1
|
||||
# the next subtitle to yield (a list of word timings with whitespace)
|
||||
subtitle: list[dict] = []
|
||||
times = []
|
||||
last = result["segments"][0]["start"]
|
||||
for segment in result["segments"]:
|
||||
for i, original_timing in enumerate(segment["words"]):
|
||||
timing = original_timing.copy()
|
||||
long_pause = not preserve_segments
|
||||
if "start" in timing:
|
||||
long_pause = long_pause and timing["start"] - last > 3.0
|
||||
else:
|
||||
long_pause = False
|
||||
has_room = line_len + len(timing["word"]) <= max_line_width
|
||||
seg_break = i == 0 and len(subtitle) > 0 and preserve_segments
|
||||
if line_len > 0 and has_room and not long_pause and not seg_break:
|
||||
# line continuation
|
||||
line_len += len(timing["word"])
|
||||
else:
|
||||
# new line
|
||||
timing["word"] = timing["word"].strip()
|
||||
if (
|
||||
len(subtitle) > 0
|
||||
and max_line_count is not None
|
||||
and (long_pause or line_count >= max_line_count)
|
||||
or seg_break
|
||||
):
|
||||
# subtitle break
|
||||
yield subtitle, times
|
||||
subtitle = []
|
||||
times = []
|
||||
line_count = 1
|
||||
elif line_len > 0:
|
||||
# line break
|
||||
line_count += 1
|
||||
timing["word"] = "\n" + timing["word"]
|
||||
line_len = len(timing["word"].strip())
|
||||
subtitle.append(timing)
|
||||
times.append((segment["start"], segment["end"], segment.get("speaker")))
|
||||
if "start" in timing:
|
||||
last = timing["start"]
|
||||
if len(subtitle) > 0:
|
||||
yield subtitle, times
|
||||
|
||||
if "words" in result["segments"][0]:
|
||||
for subtitle, _ in iterate_subtitles():
|
||||
sstart, ssend, speaker = _[0]
|
||||
subtitle_start = self.format_timestamp(sstart)
|
||||
subtitle_end = self.format_timestamp(ssend)
|
||||
subtitle_text = " ".join([word["word"] for word in subtitle])
|
||||
has_timing = any(["start" in word for word in subtitle])
|
||||
|
||||
# add [$SPEAKER_ID]: to each subtitle if speaker is available
|
||||
prefix = ""
|
||||
if speaker is not None:
|
||||
prefix = f"[{speaker}]: "
|
||||
|
||||
if highlight_words and has_timing:
|
||||
last = subtitle_start
|
||||
all_words = [timing["word"] for timing in subtitle]
|
||||
for i, this_word in enumerate(subtitle):
|
||||
if "start" in this_word:
|
||||
start = self.format_timestamp(this_word["start"])
|
||||
end = self.format_timestamp(this_word["end"])
|
||||
if last != start:
|
||||
yield last, start, subtitle_text
|
||||
|
||||
yield start, end, prefix + " ".join(
|
||||
[
|
||||
re.sub(r"^(\s*)(.*)$", r"\1<u>\2</u>", word)
|
||||
if j == i
|
||||
else word
|
||||
for j, word in enumerate(all_words)
|
||||
]
|
||||
)
|
||||
last = end
|
||||
else:
|
||||
yield subtitle_start, subtitle_end, prefix + subtitle_text
|
||||
else:
|
||||
for segment in result["segments"]:
|
||||
segment_start = self.format_timestamp(segment["start"])
|
||||
segment_end = self.format_timestamp(segment["end"])
|
||||
segment_text = segment["text"].strip().replace("-->", "->")
|
||||
if "speaker" in segment:
|
||||
segment_text = f"[{segment['speaker']}]: {segment_text}"
|
||||
yield segment_start, segment_end, segment_text
|
||||
|
||||
def format_timestamp(self, seconds: float):
|
||||
return format_timestamp(
|
||||
seconds=seconds,
|
||||
always_include_hours=self.always_include_hours,
|
||||
decimal_marker=self.decimal_marker,
|
||||
)
|
||||
|
||||
|
||||
def write_ass(transcript: Iterator[dict],
|
||||
file: TextIO,
|
||||
resolution: str = "word",
|
||||
color: str = None, underline=True,
|
||||
prefmt: str = None, suffmt: str = None,
|
||||
font: str = None, font_size: int = 24,
|
||||
strip=True, **kwargs):
|
||||
"""
|
||||
Credit: https://github.com/jianfch/stable-ts/blob/ff79549bd01f764427879f07ecd626c46a9a430a/stable_whisper/text_output.py
|
||||
Generate Advanced SubStation Alpha (ass) file from results to
|
||||
display both phrase-level & word-level timestamp simultaneously by:
|
||||
-using segment-level timestamps display phrases as usual
|
||||
-using word-level timestamps change formats (e.g. color/underline) of the word in the displayed segment
|
||||
Note: ass file is used in the same way as srt, vtt, etc.
|
||||
Parameters
|
||||
----------
|
||||
transcript: dict
|
||||
results from modified model
|
||||
file: TextIO
|
||||
file object to write to
|
||||
resolution: str
|
||||
"word" or "char", timestamp resolution to highlight.
|
||||
color: str
|
||||
color code for a word at its corresponding timestamp
|
||||
<bbggrr> reverse order hexadecimal RGB value (e.g. FF0000 is full intensity blue. Default: 00FF00)
|
||||
underline: bool
|
||||
whether to underline a word at its corresponding timestamp
|
||||
prefmt: str
|
||||
used to specify format for word-level timestamps (must be use with 'suffmt' and overrides 'color'&'underline')
|
||||
appears as such in the .ass file:
|
||||
Hi, {<prefmt>}how{<suffmt>} are you?
|
||||
reference [Appendix A: Style override codes] in http://www.tcax.org/docs/ass-specs.htm
|
||||
suffmt: str
|
||||
used to specify format for word-level timestamps (must be use with 'prefmt' and overrides 'color'&'underline')
|
||||
appears as such in the .ass file:
|
||||
Hi, {<prefmt>}how{<suffmt>} are you?
|
||||
reference [Appendix A: Style override codes] in http://www.tcax.org/docs/ass-specs.htm
|
||||
font: str
|
||||
word font (default: Arial)
|
||||
font_size: int
|
||||
word font size (default: 48)
|
||||
kwargs:
|
||||
used for format styles:
|
||||
'Name', 'Fontname', 'Fontsize', 'PrimaryColour', 'SecondaryColour', 'OutlineColour', 'BackColour', 'Bold',
|
||||
'Italic', 'Underline', 'StrikeOut', 'ScaleX', 'ScaleY', 'Spacing', 'Angle', 'BorderStyle', 'Outline',
|
||||
'Shadow', 'Alignment', 'MarginL', 'MarginR', 'MarginV', 'Encoding'
|
||||
class WriteVTT(SubtitlesWriter):
|
||||
extension: str = "vtt"
|
||||
always_include_hours: bool = False
|
||||
decimal_marker: str = "."
|
||||
|
||||
def write_result(self, result: dict, file: TextIO, options: dict):
|
||||
print("WEBVTT\n", file=file)
|
||||
for start, end, text in self.iterate_result(result, options):
|
||||
print(f"{start} --> {end}\n{text}\n", file=file, flush=True)
|
||||
|
||||
|
||||
class WriteSRT(SubtitlesWriter):
|
||||
extension: str = "srt"
|
||||
always_include_hours: bool = True
|
||||
decimal_marker: str = ","
|
||||
|
||||
def write_result(self, result: dict, file: TextIO, options: dict):
|
||||
for i, (start, end, text) in enumerate(
|
||||
self.iterate_result(result, options), start=1
|
||||
):
|
||||
print(f"{i}\n{start} --> {end}\n{text}\n", file=file, flush=True)
|
||||
|
||||
|
||||
class WriteTSV(ResultWriter):
|
||||
"""
|
||||
Write a transcript to a file in TSV (tab-separated values) format containing lines like:
|
||||
<start time in integer milliseconds>\t<end time in integer milliseconds>\t<transcript text>
|
||||
|
||||
Using integer milliseconds as start and end times means there's no chance of interference from
|
||||
an environment setting a language encoding that causes the decimal in a floating point number
|
||||
to appear as a comma; also is faster and more efficient to parse & store, e.g., in C++.
|
||||
"""
|
||||
|
||||
fmt_style_dict = {'Name': 'Default', 'Fontname': 'Arial', 'Fontsize': '48', 'PrimaryColour': '&Hffffff',
|
||||
'SecondaryColour': '&Hffffff', 'OutlineColour': '&H0', 'BackColour': '&H0', 'Bold': '0',
|
||||
'Italic': '0', 'Underline': '0', 'StrikeOut': '0', 'ScaleX': '100', 'ScaleY': '100',
|
||||
'Spacing': '0', 'Angle': '0', 'BorderStyle': '1', 'Outline': '1', 'Shadow': '0',
|
||||
'Alignment': '2', 'MarginL': '10', 'MarginR': '10', 'MarginV': '10', 'Encoding': '0'}
|
||||
extension: str = "tsv"
|
||||
|
||||
for k, v in filter(lambda x: 'colour' in x[0].lower() and not str(x[1]).startswith('&H'), kwargs.items()):
|
||||
kwargs[k] = f'&H{kwargs[k]}'
|
||||
|
||||
fmt_style_dict.update((k, v) for k, v in kwargs.items() if k in fmt_style_dict)
|
||||
|
||||
if font:
|
||||
fmt_style_dict.update(Fontname=font)
|
||||
if font_size:
|
||||
fmt_style_dict.update(Fontsize=font_size)
|
||||
|
||||
fmts = f'Format: {", ".join(map(str, fmt_style_dict.keys()))}'
|
||||
|
||||
styles = f'Style: {",".join(map(str, fmt_style_dict.values()))}'
|
||||
|
||||
ass_str = f'[Script Info]\nScriptType: v4.00+\nPlayResX: 384\nPlayResY: 288\nScaledBorderAndShadow: yes\n\n' \
|
||||
f'[V4+ Styles]\n{fmts}\n{styles}\n\n' \
|
||||
f'[Events]\nFormat: Layer, Start, End, Style, Name, MarginL, MarginR, MarginV, Effect, Text\n\n'
|
||||
|
||||
if prefmt or suffmt:
|
||||
if suffmt:
|
||||
assert prefmt, 'prefmt must be used along with suffmt'
|
||||
else:
|
||||
suffmt = r'\r'
|
||||
else:
|
||||
if not color:
|
||||
color = 'HFF00'
|
||||
underline_code = r'\u1' if underline else ''
|
||||
|
||||
prefmt = r'{\1c&' + f'{color.upper()}&{underline_code}' + '}'
|
||||
suffmt = r'{\r}'
|
||||
|
||||
def secs_to_hhmmss(secs: Tuple[float, int]):
|
||||
mm, ss = divmod(secs, 60)
|
||||
hh, mm = divmod(mm, 60)
|
||||
return f'{hh:0>1.0f}:{mm:0>2.0f}:{ss:0>2.2f}'
|
||||
def write_result(self, result: dict, file: TextIO, options: dict):
|
||||
print("start", "end", "text", sep="\t", file=file)
|
||||
for segment in result["segments"]:
|
||||
print(round(1000 * segment["start"]), file=file, end="\t")
|
||||
print(round(1000 * segment["end"]), file=file, end="\t")
|
||||
print(segment["text"].strip().replace("\t", " "), file=file, flush=True)
|
||||
|
||||
|
||||
def dialogue(chars: str, start: float, end: float, idx_0: int, idx_1: int) -> str:
|
||||
if idx_0 == -1:
|
||||
text = chars
|
||||
else:
|
||||
text = f'{chars[:idx_0]}{prefmt}{chars[idx_0:idx_1]}{suffmt}{chars[idx_1:]}'
|
||||
return f"Dialogue: 0,{secs_to_hhmmss(start)},{secs_to_hhmmss(end)}," \
|
||||
f"Default,,0,0,0,,{text.strip() if strip else text}"
|
||||
class WriteJSON(ResultWriter):
|
||||
extension: str = "json"
|
||||
|
||||
if resolution == "word":
|
||||
resolution_key = "word-segments"
|
||||
elif resolution == "char":
|
||||
resolution_key = "char-segments"
|
||||
else:
|
||||
raise ValueError(".ass resolution should be 'word' or 'char', not ", resolution)
|
||||
|
||||
ass_arr = []
|
||||
def write_result(self, result: dict, file: TextIO, options: dict):
|
||||
json.dump(result, file)
|
||||
|
||||
for segment in transcript:
|
||||
# if "12" in segment['text']:
|
||||
# import pdb; pdb.set_trace()
|
||||
if resolution_key in segment:
|
||||
res_segs = pd.DataFrame(segment[resolution_key])
|
||||
prev = segment['start']
|
||||
if "speaker" in segment:
|
||||
speaker_str = f"[{segment['speaker']}]: "
|
||||
else:
|
||||
speaker_str = ""
|
||||
for cdx, crow in res_segs.iterrows():
|
||||
if not np.isnan(crow['start']):
|
||||
if resolution == "char":
|
||||
idx_0 = cdx
|
||||
idx_1 = cdx + 1
|
||||
elif resolution == "word":
|
||||
idx_0 = int(crow["segment-text-start"])
|
||||
idx_1 = int(crow["segment-text-end"])
|
||||
# fill gap
|
||||
if crow['start'] > prev:
|
||||
filler_ts = {
|
||||
"chars": speaker_str + segment['text'],
|
||||
"start": prev,
|
||||
"end": crow['start'],
|
||||
"idx_0": -1,
|
||||
"idx_1": -1
|
||||
}
|
||||
|
||||
ass_arr.append(filler_ts)
|
||||
# highlight current word
|
||||
f_word_ts = {
|
||||
"chars": speaker_str + segment['text'],
|
||||
"start": crow['start'],
|
||||
"end": crow['end'],
|
||||
"idx_0": idx_0 + len(speaker_str),
|
||||
"idx_1": idx_1 + len(speaker_str)
|
||||
}
|
||||
ass_arr.append(f_word_ts)
|
||||
prev = crow['end']
|
||||
def get_writer(
|
||||
output_format: str, output_dir: str
|
||||
) -> Callable[[dict, TextIO, dict], None]:
|
||||
writers = {
|
||||
"txt": WriteTXT,
|
||||
"vtt": WriteVTT,
|
||||
"srt": WriteSRT,
|
||||
"tsv": WriteTSV,
|
||||
"json": WriteJSON,
|
||||
}
|
||||
|
||||
ass_str += '\n'.join(map(lambda x: dialogue(**x), ass_arr))
|
||||
if output_format == "all":
|
||||
all_writers = [writer(output_dir) for writer in writers.values()]
|
||||
|
||||
file.write(ass_str)
|
||||
def write_all(result: dict, file: TextIO, options: dict):
|
||||
for writer in all_writers:
|
||||
writer(result, file, options)
|
||||
|
||||
return write_all
|
||||
|
||||
return writers[output_format](output_dir)
|
||||
|
||||
def interpolate_nans(x, method='nearest'):
|
||||
if x.notnull().sum() > 1:
|
||||
|
178
whisperx/vad.py
178
whisperx/vad.py
@ -1,10 +1,67 @@
|
||||
import pandas as pd
|
||||
import hashlib
|
||||
import os
|
||||
import urllib
|
||||
from typing import Callable, Optional, Text, Union
|
||||
|
||||
import numpy as np
|
||||
from pyannote.core import Annotation, Segment, SlidingWindowFeature, Timeline
|
||||
from typing import List, Tuple, Optional
|
||||
import pandas as pd
|
||||
import torch
|
||||
from pyannote.audio import Model
|
||||
from pyannote.audio.core.io import AudioFile
|
||||
from pyannote.audio.pipelines import VoiceActivityDetection
|
||||
from pyannote.audio.pipelines.utils import PipelineModel
|
||||
from pyannote.core import Annotation, Segment, SlidingWindowFeature
|
||||
from tqdm import tqdm
|
||||
|
||||
from .diarize import Segment as SegmentX
|
||||
|
||||
VAD_SEGMENTATION_URL = "https://whisperx.s3.eu-west-2.amazonaws.com/model_weights/segmentation/0b5b3216d60a2d32fc086b47ea8c67589aaeb26b7e07fcbe620d6d0b83e209ea/pytorch_model.bin"
|
||||
|
||||
def load_vad_model(device, vad_onset=0.500, vad_offset=0.363, use_auth_token=None, model_fp=None):
|
||||
model_dir = torch.hub._get_torch_home()
|
||||
os.makedirs(model_dir, exist_ok = True)
|
||||
if model_fp is None:
|
||||
model_fp = os.path.join(model_dir, "whisperx-vad-segmentation.bin")
|
||||
if os.path.exists(model_fp) and not os.path.isfile(model_fp):
|
||||
raise RuntimeError(f"{model_fp} exists and is not a regular file")
|
||||
|
||||
if not os.path.isfile(model_fp):
|
||||
with urllib.request.urlopen(VAD_SEGMENTATION_URL) as source, open(model_fp, "wb") as output:
|
||||
with tqdm(
|
||||
total=int(source.info().get("Content-Length")),
|
||||
ncols=80,
|
||||
unit="iB",
|
||||
unit_scale=True,
|
||||
unit_divisor=1024,
|
||||
) as loop:
|
||||
while True:
|
||||
buffer = source.read(8192)
|
||||
if not buffer:
|
||||
break
|
||||
|
||||
output.write(buffer)
|
||||
loop.update(len(buffer))
|
||||
|
||||
model_bytes = open(model_fp, "rb").read()
|
||||
if hashlib.sha256(model_bytes).hexdigest() != VAD_SEGMENTATION_URL.split('/')[-2]:
|
||||
raise RuntimeError(
|
||||
"Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model."
|
||||
)
|
||||
|
||||
vad_model = Model.from_pretrained(model_fp, use_auth_token=use_auth_token)
|
||||
hyperparameters = {"onset": vad_onset,
|
||||
"offset": vad_offset,
|
||||
"min_duration_on": 0.1,
|
||||
"min_duration_off": 0.1}
|
||||
vad_pipeline = VoiceActivitySegmentation(segmentation=vad_model, device=torch.device(device))
|
||||
vad_pipeline.instantiate(hyperparameters)
|
||||
|
||||
return vad_pipeline
|
||||
|
||||
class Binarize:
|
||||
"""Binarize detection scores using hysteresis thresholding
|
||||
"""Binarize detection scores using hysteresis thresholding, with min-cut operation
|
||||
to ensure not segments are longer than max_duration.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
onset : float, optional
|
||||
@ -28,6 +85,9 @@ class Binarize:
|
||||
Gregory Gelly and Jean-Luc Gauvain. "Minimum Word Error Training of
|
||||
RNN-based Voice Activity Detection", InterSpeech 2015.
|
||||
|
||||
Modified by Max Bain to include WhisperX's min-cut operation
|
||||
https://arxiv.org/abs/2303.00747
|
||||
|
||||
Pyannote-audio
|
||||
"""
|
||||
|
||||
@ -136,6 +196,51 @@ class Binarize:
|
||||
return active
|
||||
|
||||
|
||||
class VoiceActivitySegmentation(VoiceActivityDetection):
|
||||
def __init__(
|
||||
self,
|
||||
segmentation: PipelineModel = "pyannote/segmentation",
|
||||
fscore: bool = False,
|
||||
use_auth_token: Union[Text, None] = None,
|
||||
**inference_kwargs,
|
||||
):
|
||||
|
||||
super().__init__(segmentation=segmentation, fscore=fscore, use_auth_token=use_auth_token, **inference_kwargs)
|
||||
|
||||
def apply(self, file: AudioFile, hook: Optional[Callable] = None) -> Annotation:
|
||||
"""Apply voice activity detection
|
||||
|
||||
Parameters
|
||||
----------
|
||||
file : AudioFile
|
||||
Processed file.
|
||||
hook : callable, optional
|
||||
Hook called after each major step of the pipeline with the following
|
||||
signature: hook("step_name", step_artefact, file=file)
|
||||
|
||||
Returns
|
||||
-------
|
||||
speech : Annotation
|
||||
Speech regions.
|
||||
"""
|
||||
|
||||
# setup hook (e.g. for debugging purposes)
|
||||
hook = self.setup_hook(file, hook=hook)
|
||||
|
||||
# apply segmentation model (only if needed)
|
||||
# output shape is (num_chunks, num_frames, 1)
|
||||
if self.training:
|
||||
if self.CACHED_SEGMENTATION in file:
|
||||
segmentations = file[self.CACHED_SEGMENTATION]
|
||||
else:
|
||||
segmentations = self._segmentation(file)
|
||||
file[self.CACHED_SEGMENTATION] = segmentations
|
||||
else:
|
||||
segmentations: SlidingWindowFeature = self._segmentation(file)
|
||||
|
||||
return segmentations
|
||||
|
||||
|
||||
def merge_vad(vad_arr, pad_onset=0.0, pad_offset=0.0, min_duration_off=0.0, min_duration_on=0.0):
|
||||
|
||||
active = Annotation()
|
||||
@ -157,29 +262,46 @@ def merge_vad(vad_arr, pad_onset=0.0, pad_offset=0.0, min_duration_off=0.0, min_
|
||||
active_segs = pd.DataFrame([x['segment'] for x in active['content']])
|
||||
return active_segs
|
||||
|
||||
def merge_chunks(segments, chunk_size):
|
||||
"""
|
||||
Merge operation described in paper
|
||||
"""
|
||||
curr_end = 0
|
||||
merged_segments = []
|
||||
seg_idxs = []
|
||||
speaker_idxs = []
|
||||
|
||||
if __name__ == "__main__":
|
||||
# from pyannote.audio import Inference
|
||||
# hook = lambda segmentation: segmentation
|
||||
# inference = Inference("pyannote/segmentation", pre_aggregation_hook=hook)
|
||||
# audio = "/tmp/11962.wav"
|
||||
# scores = inference(audio)
|
||||
# binarize = Binarize(max_duration=15)
|
||||
# anno = binarize(scores)
|
||||
# res = []
|
||||
# for ann in anno.get_timeline():
|
||||
# res.append((ann.start, ann.end))
|
||||
assert chunk_size > 0
|
||||
binarize = Binarize(max_duration=chunk_size)
|
||||
segments = binarize(segments)
|
||||
segments_list = []
|
||||
for speech_turn in segments.get_timeline():
|
||||
segments_list.append(SegmentX(speech_turn.start, speech_turn.end, "UNKNOWN"))
|
||||
|
||||
# res = pd.DataFrame(res)
|
||||
# res[2] = res[1] - res[0]
|
||||
import pandas as pd
|
||||
input_fp = "tt298650_sync.wav"
|
||||
df = pd.read_csv(f"/work/maxbain/tmp/{input_fp}.sad", sep=" ", header=None)
|
||||
print(len(df))
|
||||
N = 0.15
|
||||
g = df[0].sub(df[1].shift())
|
||||
input_base = input_fp.split('.')[0]
|
||||
df = df.groupby(g.gt(N).cumsum()).agg({0:'min', 1:'max'})
|
||||
df.to_csv(f"/work/maxbain/tmp/{input_base}.lab", header=None, index=False, sep=" ")
|
||||
print(df)
|
||||
import pdb; pdb.set_trace()
|
||||
if len(segments_list) == 0:
|
||||
print("No active speech found in audio")
|
||||
return []
|
||||
# assert segments_list, "segments_list is empty."
|
||||
# Make sur the starting point is the start of the segment.
|
||||
curr_start = segments_list[0].start
|
||||
|
||||
for seg in segments_list:
|
||||
if seg.end - curr_start > chunk_size and curr_end-curr_start > 0:
|
||||
merged_segments.append({
|
||||
"start": curr_start,
|
||||
"end": curr_end,
|
||||
"segments": seg_idxs,
|
||||
})
|
||||
curr_start = seg.start
|
||||
seg_idxs = []
|
||||
speaker_idxs = []
|
||||
curr_end = seg.end
|
||||
seg_idxs.append((seg.start, seg.end))
|
||||
speaker_idxs.append(seg.speaker)
|
||||
# add final
|
||||
merged_segments.append({
|
||||
"start": curr_start,
|
||||
"end": curr_end,
|
||||
"segments": seg_idxs,
|
||||
})
|
||||
return merged_segments
|
||||
|
Reference in New Issue
Block a user