Merge branch 'main' into danish_alignment

This commit is contained in:
Max Bain
2023-05-09 23:10:13 +01:00
committed by GitHub
59 changed files with 1487 additions and 107436 deletions

View File

@ -1,115 +1,4 @@
import hashlib
import io
import os
import urllib
import warnings
from typing import List, Optional, Union
import torch
from tqdm import tqdm
from .audio import load_audio, log_mel_spectrogram, pad_or_trim
from .decoding import DecodingOptions, DecodingResult, decode, detect_language
from .model import Whisper, ModelDimensions
from .transcribe import transcribe, transcribe_with_vad, transcribe_with_vad_parallel
from .transcribe import load_model
from .alignment import load_align_model, align
_MODELS = {
"tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",
"tiny": "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt",
"base.en": "https://openaipublic.azureedge.net/main/whisper/models/25a8566e1d0c1e2231d1c762132cd20e0f96a85d16145c3a00adf5d1ac670ead/base.en.pt",
"base": "https://openaipublic.azureedge.net/main/whisper/models/ed3a0b6b1c0edf879ad9b11b1af5a0e6ab5db9205f891f668f8b0e6c6326e34e/base.pt",
"small.en": "https://openaipublic.azureedge.net/main/whisper/models/f953ad0fd29cacd07d5a9eda5624af0f6bcf2258be67c92b79389873d91e0872/small.en.pt",
"small": "https://openaipublic.azureedge.net/main/whisper/models/9ecf779972d90ba49c06d968637d720dd632c55bbf19d441fb42bf17a411e794/small.pt",
"medium.en": "https://openaipublic.azureedge.net/main/whisper/models/d7440d1dc186f76616474e0ff0b3b6b879abc9d1a4926b7adfa41db2d497ab4f/medium.en.pt",
"medium": "https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt",
"large-v1": "https://openaipublic.azureedge.net/main/whisper/models/e4b87e7e0bf463eb8e6956e646f1e277e901512310def2c24bf0e11bd3c28e9a/large-v1.pt",
"large-v2": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
"large": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
}
def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
os.makedirs(root, exist_ok=True)
expected_sha256 = url.split("/")[-2]
download_target = os.path.join(root, os.path.basename(url))
if os.path.exists(download_target) and not os.path.isfile(download_target):
raise RuntimeError(f"{download_target} exists and is not a regular file")
if os.path.isfile(download_target):
with open(download_target, "rb") as f:
model_bytes = f.read()
if hashlib.sha256(model_bytes).hexdigest() == expected_sha256:
return model_bytes if in_memory else download_target
else:
warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop:
while True:
buffer = source.read(8192)
if not buffer:
break
output.write(buffer)
loop.update(len(buffer))
model_bytes = open(download_target, "rb").read()
if hashlib.sha256(model_bytes).hexdigest() != expected_sha256:
raise RuntimeError("Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model.")
return model_bytes if in_memory else download_target
def available_models() -> List[str]:
"""Returns the names of available models"""
return list(_MODELS.keys())
def load_model(name: str, device: Optional[Union[str, torch.device]] = None, download_root: str = None, in_memory: bool = False) -> Whisper:
"""
Load a Whisper ASR model
Parameters
----------
name : str
one of the official model names listed by `whisper.available_models()`, or
path to a model checkpoint containing the model dimensions and the model state_dict.
device : Union[str, torch.device]
the PyTorch device to put the model into
download_root: str
path to download the model files; by default, it uses "~/.cache/whisper"
in_memory: bool
whether to preload the model weights into host memory
Returns
-------
model : Whisper
The Whisper ASR model instance
"""
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
if download_root is None:
download_root = os.getenv(
"XDG_CACHE_HOME",
os.path.join(os.path.expanduser("~"), ".cache", "whisper")
)
if name in _MODELS:
checkpoint_file = _download(_MODELS[name], download_root, in_memory)
elif os.path.isfile(name):
checkpoint_file = open(name, "rb").read() if in_memory else name
else:
raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
with (io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")) as fp:
checkpoint = torch.load(fp, map_location=device)
del checkpoint_file
dims = ModelDimensions(**checkpoint["dims"])
model = Whisper(dims)
model.load_state_dict(checkpoint["model_state_dict"])
return model.to(device)
from .audio import load_audio
from .diarize import assign_word_speakers, DiarizationPipeline

View File

@ -2,16 +2,19 @@
Forced Alignment with Whisper
C. Max Bain
"""
from dataclasses import dataclass
from typing import Iterator, Union, List
import numpy as np
import pandas as pd
from typing import List, Union, Iterator, TYPE_CHECKING
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
import torchaudio
import torch
from dataclasses import dataclass
import torchaudio
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
from .audio import SAMPLE_RATE, load_audio
from .utils import interpolate_nans
from .types import AlignedTranscriptionResult, SingleSegment, SingleAlignedSegment, SingleWordSegment
import nltk
LANGUAGES_WITHOUT_SPACES = ["ja", "zh"]
@ -37,11 +40,12 @@ DEFAULT_ALIGN_MODELS_HF = {
"fa": "jonatasgrosman/wav2vec2-large-xlsr-53-persian",
"el": "jonatasgrosman/wav2vec2-large-xlsr-53-greek",
"tr": "mpoyraz/wav2vec2-xls-r-300m-cv7-turkish",
"da": "saattrupdan/wav2vec2-xls-r-300m-ftspeech"
"da": "saattrupdan/wav2vec2-xls-r-300m-ftspeech",
"he": "imvladikon/wav2vec2-xls-r-300m-hebrew",
}
def load_align_model(language_code, device, model_name=None):
def load_align_model(language_code, device, model_name=None, model_dir=None):
if model_name is None:
# use default model
if language_code in DEFAULT_ALIGN_MODELS_TORCH:
@ -56,7 +60,7 @@ def load_align_model(language_code, device, model_name=None):
if model_name in torchaudio.pipelines.__all__:
pipeline_type = "torchaudio"
bundle = torchaudio.pipelines.__dict__[model_name]
align_model = bundle.get_model().to(device)
align_model = bundle.get_model(dl_kwargs={"model_dir": model_dir}).to(device)
labels = bundle.get_labels()
align_dictionary = {c.lower(): i for i, c in enumerate(labels)}
else:
@ -78,362 +82,232 @@ def load_align_model(language_code, device, model_name=None):
def align(
transcript: Iterator[dict],
transcript: Iterator[SingleSegment],
model: torch.nn.Module,
align_model_metadata: dict,
audio: Union[str, np.ndarray, torch.Tensor],
device: str,
extend_duration: float = 0.0,
start_from_previous: bool = True,
interpolate_method: str = "nearest",
):
return_char_alignments: bool = False,
) -> AlignedTranscriptionResult:
"""
Align phoneme recognition predictions to known transcription.
"""
Force align phoneme recognition predictions to known transcription
Parameters
----------
transcript: Iterator[dict]
The Whisper model instance
model: torch.nn.Module
Alignment model (wav2vec2)
audio: Union[str, np.ndarray, torch.Tensor]
The path to the audio file to open, or the audio waveform
device: str
cuda device
diarization: pd.DataFrame {'start': List[float], 'end': List[float], 'speaker': List[float]}
diarization segments with speaker labels.
extend_duration: float
Amount to pad input segments by. If not using vad--filter then recommended to use 2 seconds
If the gzip compression ratio is above this value, treat as failed
interpolate_method: str ["nearest", "linear", "ignore"]
Method to assign timestamps to non-aligned words. Words are not able to be aligned when none of the characters occur in the align model dictionary.
"nearest" copies timestamp of nearest word within the segment. "linear" is linear interpolation. "drop" removes that word from output.
Returns
-------
A dictionary containing the resulting text ("text") and segment-level details ("segments"), and
the spoken language ("language"), which is detected when `decode_options["language"]` is None.
"""
if not torch.is_tensor(audio):
if isinstance(audio, str):
audio = load_audio(audio)
audio = torch.from_numpy(audio)
if len(audio.shape) == 1:
audio = audio.unsqueeze(0)
MAX_DURATION = audio.shape[1] / SAMPLE_RATE
model_dictionary = align_model_metadata["dictionary"]
model_lang = align_model_metadata["language"]
model_type = align_model_metadata["type"]
aligned_segments = []
prev_t2 = 0
char_segments_arr = {
"segment-idx": [],
"subsegment-idx": [],
"word-idx": [],
"char": [],
"start": [],
"end": [],
"score": [],
}
# 1. Preprocess to keep only characters in dictionary
for sdx, segment in enumerate(transcript):
while True:
segment_align_success = False
# strip spaces at beginning / end, but keep track of the amount.
num_leading = len(segment["text"]) - len(segment["text"].lstrip())
num_trailing = len(segment["text"]) - len(segment["text"].rstrip())
text = segment["text"]
# strip spaces at beginning / end, but keep track of the amount.
num_leading = len(segment["text"]) - len(segment["text"].lstrip())
num_trailing = len(segment["text"]) - len(segment["text"].rstrip())
transcription = segment["text"]
# split into words
if model_lang not in LANGUAGES_WITHOUT_SPACES:
per_word = text.split(" ")
else:
per_word = text
# TODO: convert number tokenizer / symbols to phonetic words for alignment.
# e.g. "$300" -> "three hundred dollars"
# currently "$300" is ignored since no characters present in the phonetic dictionary
# split into words
clean_char, clean_cdx = [], []
for cdx, char in enumerate(text):
char_ = char.lower()
# wav2vec2 models use "|" character to represent spaces
if model_lang not in LANGUAGES_WITHOUT_SPACES:
per_word = transcription.split(" ")
else:
per_word = transcription
# first check that characters in transcription can be aligned (they are contained in align model"s dictionary)
clean_char, clean_cdx = [], []
for cdx, char in enumerate(transcription):
char_ = char.lower()
# wav2vec2 models use "|" character to represent spaces
if model_lang not in LANGUAGES_WITHOUT_SPACES:
char_ = char_.replace(" ", "|")
# ignore whitespace at beginning and end of transcript
if cdx < num_leading:
pass
elif cdx > len(transcription) - num_trailing - 1:
pass
elif char_ in model_dictionary.keys():
clean_char.append(char_)
clean_cdx.append(cdx)
clean_wdx = []
for wdx, wrd in enumerate(per_word):
if any([c in model_dictionary.keys() for c in wrd]):
clean_wdx.append(wdx)
# if no characters are in the dictionary, then we skip this segment...
if len(clean_char) == 0:
print(f'Failed to align segment ("{segment["text"]}"): no characters in this segment found in model dictionary, resorting to original...')
break
transcription_cleaned = "".join(clean_char)
tokens = [model_dictionary[c] for c in transcription_cleaned]
# we only pad if not using VAD filtering
if "seg_text" not in segment:
# pad according original timestamps
t1 = max(segment["start"] - extend_duration, 0)
t2 = min(segment["end"] + extend_duration, MAX_DURATION)
# use prev_t2 as current t1 if it"s later
if start_from_previous and t1 < prev_t2:
t1 = prev_t2
# check if timestamp range is still valid
if t1 >= MAX_DURATION:
print("Failed to align segment: original start time longer than audio duration, skipping...")
break
if t2 - t1 < 0.02:
print("Failed to align segment: duration smaller than 0.02s time precision")
break
f1 = int(t1 * SAMPLE_RATE)
f2 = int(t2 * SAMPLE_RATE)
waveform_segment = audio[:, f1:f2]
with torch.inference_mode():
if model_type == "torchaudio":
emissions, _ = model(waveform_segment.to(device))
elif model_type == "huggingface":
emissions = model(waveform_segment.to(device)).logits
else:
raise NotImplementedError(f"Align model of type {model_type} not supported.")
emissions = torch.log_softmax(emissions, dim=-1)
emission = emissions[0].cpu().detach()
trellis = get_trellis(emission, tokens)
path = backtrack(trellis, emission, tokens)
if path is None:
print(f'Failed to align segment ("{segment["text"]}"): backtrack failed, resorting to original...')
break
char_segments = merge_repeats(path, transcription_cleaned)
# word_segments = merge_words(char_segments)
char_ = char_.replace(" ", "|")
# ignore whitespace at beginning and end of transcript
if cdx < num_leading:
pass
elif cdx > len(text) - num_trailing - 1:
pass
elif char_ in model_dictionary.keys():
clean_char.append(char_)
clean_cdx.append(cdx)
# sub-segments
if "seg-text" not in segment:
segment["seg-text"] = [transcription]
seg_lens = [0] + [len(x) for x in segment["seg-text"]]
seg_lens_cumsum = list(np.cumsum(seg_lens))
sub_seg_idx = 0
clean_wdx = []
for wdx, wrd in enumerate(per_word):
if any([c in model_dictionary.keys() for c in wrd]):
clean_wdx.append(wdx)
wdx = 0
duration = t2 - t1
ratio = duration * waveform_segment.size(0) / (trellis.size(0) - 1)
for cdx, char in enumerate(transcription + " "):
is_last = False
if cdx == len(transcription):
break
elif cdx+1 == len(transcription):
is_last = True
start, end, score = None, None, None
if cdx in clean_cdx:
char_seg = char_segments[clean_cdx.index(cdx)]
start = char_seg.start * ratio + t1
end = char_seg.end * ratio + t1
score = char_seg.score
char_segments_arr["char"].append(char)
char_segments_arr["start"].append(start)
char_segments_arr["end"].append(end)
char_segments_arr["score"].append(score)
char_segments_arr["word-idx"].append(wdx)
char_segments_arr["segment-idx"].append(sdx)
char_segments_arr["subsegment-idx"].append(sub_seg_idx)
# word-level info
if model_lang in LANGUAGES_WITHOUT_SPACES:
# character == word
wdx += 1
elif is_last or transcription[cdx+1] == " " or cdx == seg_lens_cumsum[sub_seg_idx+1] - 1:
wdx += 1
if is_last or cdx == seg_lens_cumsum[sub_seg_idx+1] - 1:
wdx = 0
sub_seg_idx += 1
prev_t2 = segment["end"]
segment_align_success = True
# end while True loop
break
# reset prev_t2 due to drifting issues
if not segment_align_success:
prev_t2 = 0
char_segments_arr = pd.DataFrame(char_segments_arr)
not_space = char_segments_arr["char"] != " "
per_seg_grp = char_segments_arr.groupby(["segment-idx", "subsegment-idx"], as_index = False)
char_segments_arr = per_seg_grp.apply(lambda x: x.reset_index(drop = True)).reset_index()
per_word_grp = char_segments_arr[not_space].groupby(["segment-idx", "subsegment-idx", "word-idx"])
per_subseg_grp = char_segments_arr[not_space].groupby(["segment-idx", "subsegment-idx"])
per_seg_grp = char_segments_arr[not_space].groupby(["segment-idx"])
char_segments_arr["local-char-idx"] = char_segments_arr.groupby(["segment-idx", "subsegment-idx"]).cumcount()
per_word_grp = char_segments_arr[not_space].groupby(["segment-idx", "subsegment-idx", "word-idx"]) # regroup
word_segments_arr = {}
# start of word is first char with a timestamp
word_segments_arr["start"] = per_word_grp["start"].min().values
# end of word is last char with a timestamp
word_segments_arr["end"] = per_word_grp["end"].max().values
# score of word is mean (excluding nan)
word_segments_arr["score"] = per_word_grp["score"].mean().values
word_segments_arr["segment-text-start"] = per_word_grp["local-char-idx"].min().astype(int).values
word_segments_arr["segment-text-end"] = per_word_grp["local-char-idx"].max().astype(int).values+1
word_segments_arr = pd.DataFrame(word_segments_arr)
word_segments_arr[["segment-idx", "subsegment-idx", "word-idx"]] = per_word_grp["local-char-idx"].min().reset_index()[["segment-idx", "subsegment-idx", "word-idx"]].astype(int)
segments_arr = {}
segments_arr["start"] = per_subseg_grp["start"].min().reset_index()["start"]
segments_arr["end"] = per_subseg_grp["end"].max().reset_index()["end"]
segments_arr = pd.DataFrame(segments_arr)
segments_arr[["segment-idx", "subsegment-idx-start"]] = per_subseg_grp["start"].min().reset_index()[["segment-idx", "subsegment-idx"]]
segments_arr["subsegment-idx-end"] = segments_arr["subsegment-idx-start"] + 1
# interpolate missing words / sub-segments
if interpolate_method != "ignore":
wrd_subseg_grp = word_segments_arr.groupby(["segment-idx", "subsegment-idx"], group_keys=False)
wrd_seg_grp = word_segments_arr.groupby(["segment-idx"], group_keys=False)
# we still know which word timestamps are interpolated because their score == nan
word_segments_arr["start"] = wrd_subseg_grp['start'].apply(lambda group: interpolate_nans(group, method=interpolate_method))
word_segments_arr["end"] = wrd_subseg_grp['end'].apply(lambda group: interpolate_nans(group, method=interpolate_method))
word_segments_arr["start"] = wrd_seg_grp['start'].apply(lambda group: interpolate_nans(group, method=interpolate_method))
word_segments_arr["end"] = wrd_seg_grp['end'].apply(lambda group: interpolate_nans(group, method=interpolate_method))
sub_seg_grp = segments_arr.groupby(["segment-idx"], group_keys=False)
segments_arr['start'] = sub_seg_grp['start'].apply(lambda group: interpolate_nans(group, method=interpolate_method))
segments_arr['end'] = sub_seg_grp['end'].apply(lambda group: interpolate_nans(group, method=interpolate_method))
# merge words & subsegments which are missing times
word_grp = word_segments_arr.groupby(["segment-idx", "subsegment-idx", "end"])
word_segments_arr["segment-text-start"] = word_grp["segment-text-start"].transform(min)
word_segments_arr["segment-text-end"] = word_grp["segment-text-end"].transform(max)
word_segments_arr.drop_duplicates(subset=["segment-idx", "subsegment-idx", "end"], inplace=True)
seg_grp_dup = segments_arr.groupby(["segment-idx", "start", "end"])
segments_arr["subsegment-idx-start"] = seg_grp_dup["subsegment-idx-start"].transform(min)
segments_arr["subsegment-idx-end"] = seg_grp_dup["subsegment-idx-end"].transform(max)
segments_arr.drop_duplicates(subset=["segment-idx", "subsegment-idx-start", "subsegment-idx-end"], inplace=True)
else:
word_segments_arr.dropna(inplace=True)
segments_arr.dropna(inplace=True)
# if some segments still have missing timestamps (usually because all numerals / symbols), then use original timestamps...
segments_arr['start'].fillna(pd.Series([x['start'] for x in transcript]), inplace=True)
segments_arr['end'].fillna(pd.Series([x['end'] for x in transcript]), inplace=True)
segments_arr['subsegment-idx-start'].fillna(0, inplace=True)
segments_arr['subsegment-idx-end'].fillna(1, inplace=True)
aligned_segments = []
aligned_segments_word = []
word_segments_arr.set_index(["segment-idx", "subsegment-idx"], inplace=True)
char_segments_arr.set_index(["segment-idx", "subsegment-idx", "word-idx"], inplace=True)
for sdx, srow in segments_arr.iterrows():
seg_idx = int(srow["segment-idx"])
try:
sub_start = int(srow["subsegment-idx-start"])
except:
import pdb; pdb.set_trace()
sub_end = int(srow["subsegment-idx-end"])
seg = transcript[seg_idx]
text = "".join(seg["seg-text"][sub_start:sub_end])
wseg = word_segments_arr.loc[seg_idx].loc[sub_start:sub_end-1]
wseg["start"].fillna(srow["start"], inplace=True)
wseg["end"].fillna(srow["end"], inplace=True)
wseg["segment-text-start"].fillna(0, inplace=True)
wseg["segment-text-end"].fillna(len(text)-1, inplace=True)
cseg = char_segments_arr.loc[seg_idx].loc[sub_start:sub_end-1]
# fixes bug for single segment in transcript
cseg['segment-text-start'] = cseg['level_1'] if 'level_1' in cseg else 0
cseg['segment-text-end'] = cseg['level_1'] + 1 if 'level_1' in cseg else 1
if 'level_1' in cseg: del cseg['level_1']
if 'level_0' in cseg: del cseg['level_0']
cseg.reset_index(inplace=True)
aligned_segments.append(
{
"start": srow["start"],
"end": srow["end"],
"text": text,
"word-segments": wseg,
"char-segments": cseg
}
)
def get_raw_text(word_row):
return seg["seg-text"][word_row.name][int(word_row["segment-text-start"]):int(word_row["segment-text-end"])+1]
wdx = 0
curr_text = get_raw_text(wseg.iloc[wdx])
if len(wseg) > 1:
for _, wrow in wseg.iloc[1:].iterrows():
if wrow['start'] != wseg.iloc[wdx]['start']:
aligned_segments_word.append(
{
"text": curr_text.strip(),
"start": wseg.iloc[wdx]["start"],
"end": wseg.iloc[wdx]["end"],
}
)
curr_text = ""
curr_text += " " + get_raw_text(wrow)
wdx += 1
aligned_segments_word.append(
{
"text": curr_text.strip(),
"start": wseg.iloc[wdx]["start"],
"end": wseg.iloc[wdx]["end"]
}
)
sentence_spans = list(nltk.tokenize.punkt.PunktSentenceTokenizer().span_tokenize(text))
segment["clean_char"] = clean_char
segment["clean_cdx"] = clean_cdx
segment["clean_wdx"] = clean_wdx
segment["sentence_spans"] = sentence_spans
return {"segments": aligned_segments, "word_segments": aligned_segments_word}
aligned_segments: List[SingleAlignedSegment] = []
# 2. Get prediction matrix from alignment model & align
for sdx, segment in enumerate(transcript):
t1 = segment["start"]
t2 = segment["end"]
text = segment["text"]
aligned_seg: SingleAlignedSegment = {
"start": t1,
"end": t2,
"text": text,
"words": [],
}
if return_char_alignments:
aligned_seg["chars"] = []
# check we can align
if len(segment["clean_char"]) == 0:
print(f'Failed to align segment ("{segment["text"]}"): no characters in this segment found in model dictionary, resorting to original...')
aligned_segments.append(aligned_seg)
continue
if t1 >= MAX_DURATION or t2 - t1 < 0.02:
print("Failed to align segment: original start time longer than audio duration, skipping...")
aligned_segments.append(aligned_seg)
continue
text_clean = "".join(segment["clean_char"])
tokens = [model_dictionary[c] for c in text_clean]
f1 = int(t1 * SAMPLE_RATE)
f2 = int(t2 * SAMPLE_RATE)
# TODO: Probably can get some speedup gain with batched inference here
waveform_segment = audio[:, f1:f2]
with torch.inference_mode():
if model_type == "torchaudio":
emissions, _ = model(waveform_segment.to(device))
elif model_type == "huggingface":
emissions = model(waveform_segment.to(device)).logits
else:
raise NotImplementedError(f"Align model of type {model_type} not supported.")
emissions = torch.log_softmax(emissions, dim=-1)
emission = emissions[0].cpu().detach()
blank_id = 0
for char, code in model_dictionary.items():
if char == '[pad]' or char == '<pad>':
blank_id = code
trellis = get_trellis(emission, tokens, blank_id)
path = backtrack(trellis, emission, tokens, blank_id)
if path is None:
print(f'Failed to align segment ("{segment["text"]}"): backtrack failed, resorting to original...')
aligned_segments.append(aligned_seg)
continue
char_segments = merge_repeats(path, text_clean)
duration = t2 -t1
ratio = duration * waveform_segment.size(0) / (trellis.size(0) - 1)
# assign timestamps to aligned characters
char_segments_arr = []
word_idx = 0
for cdx, char in enumerate(text):
start, end, score = None, None, None
if cdx in segment["clean_cdx"]:
char_seg = char_segments[segment["clean_cdx"].index(cdx)]
start = round(char_seg.start * ratio + t1, 3)
end = round(char_seg.end * ratio + t1, 3)
score = round(char_seg.score, 3)
char_segments_arr.append(
{
"char": char,
"start": start,
"end": end,
"score": score,
"word-idx": word_idx,
}
)
# increment word_idx, nltk word tokenization would probably be more robust here, but us space for now...
if model_lang in LANGUAGES_WITHOUT_SPACES:
word_idx += 1
elif cdx == len(text) - 1 or text[cdx+1] == " ":
word_idx += 1
char_segments_arr = pd.DataFrame(char_segments_arr)
aligned_subsegments = []
# assign sentence_idx to each character index
char_segments_arr["sentence-idx"] = None
for sdx, (sstart, send) in enumerate(segment["sentence_spans"]):
curr_chars = char_segments_arr.loc[(char_segments_arr.index >= sstart) & (char_segments_arr.index <= send)]
char_segments_arr.loc[(char_segments_arr.index >= sstart) & (char_segments_arr.index <= send), "sentence-idx"] = sdx
sentence_text = text[sstart:send]
sentence_start = curr_chars["start"].min()
sentence_end = curr_chars["end"].max()
sentence_words = []
for word_idx in curr_chars["word-idx"].unique():
word_chars = curr_chars.loc[curr_chars["word-idx"] == word_idx]
word_text = "".join(word_chars["char"].tolist()).strip()
if len(word_text) == 0:
continue
word_start = word_chars["start"].min()
word_end = word_chars["end"].max()
word_score = round(word_chars["score"].mean(), 3)
# -1 indicates unalignable
word_segment = {"word": word_text}
if not np.isnan(word_start):
word_segment["start"] = word_start
if not np.isnan(word_end):
word_segment["end"] = word_end
if not np.isnan(word_score):
word_segment["score"] = word_score
sentence_words.append(word_segment)
aligned_subsegments.append({
"text": sentence_text,
"start": sentence_start,
"end": sentence_end,
"words": sentence_words,
})
if return_char_alignments:
curr_chars = curr_chars[["char", "start", "end", "score"]]
curr_chars.fillna(-1, inplace=True)
curr_chars = curr_chars.to_dict("records")
curr_chars = [{key: val for key, val in char.items() if val != -1} for char in curr_chars]
aligned_subsegments[-1]["chars"] = curr_chars
aligned_subsegments = pd.DataFrame(aligned_subsegments)
aligned_subsegments["start"] = interpolate_nans(aligned_subsegments["start"], method=interpolate_method)
aligned_subsegments["end"] = interpolate_nans(aligned_subsegments["end"], method=interpolate_method)
# concatenate sentences with same timestamps
agg_dict = {"text": " ".join, "words": "sum"}
if return_char_alignments:
agg_dict["chars"] = "sum"
aligned_subsegments= aligned_subsegments.groupby(["start", "end"], as_index=False).agg(agg_dict)
aligned_subsegments = aligned_subsegments.to_dict('records')
aligned_segments += aligned_subsegments
# create word_segments list
word_segments: List[SingleWordSegment] = []
for segment in aligned_segments:
word_segments += segment["words"]
return {"segments": aligned_segments, "word_segments": word_segments}
"""
source: https://pytorch.org/tutorials/intermediate/forced_alignment_with_torchaudio_tutorial.html

270
whisperx/asr.py Normal file
View File

@ -0,0 +1,270 @@
import os
import warnings
from typing import List, Union
import ctranslate2
import faster_whisper
import numpy as np
import torch
from transformers import Pipeline
from transformers.pipelines.pt_utils import PipelineIterator
from .audio import N_SAMPLES, SAMPLE_RATE, load_audio, log_mel_spectrogram
from .vad import load_vad_model, merge_chunks
from .types import TranscriptionResult, SingleSegment
def load_model(whisper_arch, device, compute_type="float16", asr_options=None, language=None,
vad_options=None, model=None):
'''Load a Whisper model for inference.
Args:
whisper_arch: str - The name of the Whisper model to load.
device: str - The device to load the model on.
compute_type: str - The compute type to use for the model.
options: dict - A dictionary of options to use for the model.
language: str - The language of the model. (use English for now)
Returns:
A Whisper pipeline.
'''
if whisper_arch.endswith(".en"):
language = "en"
model = WhisperModel(whisper_arch, device=device, compute_type=compute_type)
if language is not None:
tokenizer = faster_whisper.tokenizer.Tokenizer(model.hf_tokenizer, model.model.is_multilingual, task="transcribe", language=language)
else:
print("No language specified, language will be first be detected for each audio file (increases inference time).")
tokenizer = None
default_asr_options = {
"beam_size": 5,
"best_of": 5,
"patience": 1,
"length_penalty": 1,
"temperatures": [0.0, 0.2, 0.4, 0.6, 0.8, 1.0],
"compression_ratio_threshold": 2.4,
"log_prob_threshold": -1.0,
"no_speech_threshold": 0.6,
"condition_on_previous_text": False,
"initial_prompt": None,
"prefix": None,
"suppress_blank": True,
"suppress_tokens": [-1],
"without_timestamps": True,
"max_initial_timestamp": 0.0,
"word_timestamps": False,
"prepend_punctuations": "\"'“¿([{-",
"append_punctuations": "\"'.。,!?::”)]}、"
}
if asr_options is not None:
default_asr_options.update(asr_options)
default_asr_options = faster_whisper.transcribe.TranscriptionOptions(**default_asr_options)
default_vad_options = {
"vad_onset": 0.500,
"vad_offset": 0.363
}
if vad_options is not None:
default_vad_options.update(vad_options)
vad_model = load_vad_model(torch.device(device), use_auth_token=None, **default_vad_options)
return FasterWhisperPipeline(model, vad_model, default_asr_options, tokenizer)
class WhisperModel(faster_whisper.WhisperModel):
'''
FasterWhisperModel provides batched inference for faster-whisper.
Currently only works in non-timestamp mode and fixed prompt for all samples in batch.
'''
def generate_segment_batched(self, features: np.ndarray, tokenizer: faster_whisper.tokenizer.Tokenizer, options: faster_whisper.transcribe.TranscriptionOptions, encoder_output = None):
batch_size = features.shape[0]
all_tokens = []
prompt_reset_since = 0
if options.initial_prompt is not None:
initial_prompt = " " + options.initial_prompt.strip()
initial_prompt_tokens = tokenizer.encode(initial_prompt)
all_tokens.extend(initial_prompt_tokens)
previous_tokens = all_tokens[prompt_reset_since:]
prompt = self.get_prompt(
tokenizer,
previous_tokens,
without_timestamps=options.without_timestamps,
prefix=options.prefix,
)
encoder_output = self.encode(features)
max_initial_timestamp_index = int(
round(options.max_initial_timestamp / self.time_precision)
)
result = self.model.generate(
encoder_output,
[prompt] * batch_size,
# length_penalty=options.length_penalty,
# max_length=self.max_length,
# return_scores=True,
# return_no_speech_prob=True,
# suppress_blank=options.suppress_blank,
# suppress_tokens=options.suppress_tokens,
# max_initial_timestamp_index=max_initial_timestamp_index,
)
tokens_batch = [x.sequences_ids[0] for x in result]
def decode_batch(tokens: List[List[int]]) -> str:
res = []
for tk in tokens:
res.append([token for token in tk if token < tokenizer.eot])
# text_tokens = [token for token in tokens if token < self.eot]
return tokenizer.tokenizer.decode_batch(res)
text = decode_batch(tokens_batch)
return text
def encode(self, features: np.ndarray) -> ctranslate2.StorageView:
# When the model is running on multiple GPUs, the encoder output should be moved
# to the CPU since we don't know which GPU will handle the next job.
to_cpu = self.model.device == "cuda" and len(self.model.device_index) > 1
# unsqueeze if batch size = 1
if len(features.shape) == 2:
features = np.expand_dims(features, 0)
features = faster_whisper.transcribe.get_ctranslate2_storage(features)
return self.model.encode(features, to_cpu=to_cpu)
class FasterWhisperPipeline(Pipeline):
"""
Huggingface Pipeline wrapper for FasterWhisperModel.
"""
# TODO:
# - add support for timestamp mode
# - add support for custom inference kwargs
def __init__(
self,
model,
vad,
options,
tokenizer=None,
device: Union[int, str, "torch.device"] = -1,
framework = "pt",
**kwargs
):
self.model = model
self.tokenizer = tokenizer
self.options = options
self._batch_size = kwargs.pop("batch_size", None)
self._num_workers = 1
self._preprocess_params, self._forward_params, self._postprocess_params = self._sanitize_parameters(**kwargs)
self.call_count = 0
self.framework = framework
if self.framework == "pt":
if isinstance(device, torch.device):
self.device = device
elif isinstance(device, str):
self.device = torch.device(device)
elif device < 0:
self.device = torch.device("cpu")
else:
self.device = torch.device(f"cuda:{device}")
else:
self.device = device
super(Pipeline, self).__init__()
self.vad_model = vad
def _sanitize_parameters(self, **kwargs):
preprocess_kwargs = {}
if "tokenizer" in kwargs:
preprocess_kwargs["maybe_arg"] = kwargs["maybe_arg"]
return preprocess_kwargs, {}, {}
def preprocess(self, audio):
audio = audio['inputs']
features = log_mel_spectrogram(audio, padding=N_SAMPLES - audio.shape[0])
return {'inputs': features}
def _forward(self, model_inputs):
outputs = self.model.generate_segment_batched(model_inputs['inputs'], self.tokenizer, self.options)
return {'text': outputs}
def postprocess(self, model_outputs):
return model_outputs
def get_iterator(
self, inputs, num_workers: int, batch_size: int, preprocess_params, forward_params, postprocess_params
):
dataset = PipelineIterator(inputs, self.preprocess, preprocess_params)
if "TOKENIZERS_PARALLELISM" not in os.environ:
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# TODO hack by collating feature_extractor and image_processor
def stack(items):
return {'inputs': torch.stack([x['inputs'] for x in items])}
dataloader = torch.utils.data.DataLoader(dataset, num_workers=num_workers, batch_size=batch_size, collate_fn=stack)
model_iterator = PipelineIterator(dataloader, self.forward, forward_params, loader_batch_size=batch_size)
final_iterator = PipelineIterator(model_iterator, self.postprocess, postprocess_params)
return final_iterator
def transcribe(
self, audio: Union[str, np.ndarray], batch_size=None, num_workers=0
) -> TranscriptionResult:
if isinstance(audio, str):
audio = load_audio(audio)
def data(audio, segments):
for seg in segments:
f1 = int(seg['start'] * SAMPLE_RATE)
f2 = int(seg['end'] * SAMPLE_RATE)
# print(f2-f1)
yield {'inputs': audio[f1:f2]}
vad_segments = self.vad_model({"waveform": torch.from_numpy(audio).unsqueeze(0), "sample_rate": SAMPLE_RATE})
vad_segments = merge_chunks(vad_segments, 30)
del_tokenizer = False
if self.tokenizer is None:
language = self.detect_language(audio)
self.tokenizer = faster_whisper.tokenizer.Tokenizer(self.model.hf_tokenizer, self.model.model.is_multilingual, task="transcribe", language=language)
del_tokenizer = True
else:
language = self.tokenizer.language_code
segments: List[SingleSegment] = []
batch_size = batch_size or self._batch_size
for idx, out in enumerate(self.__call__(data(audio, vad_segments), batch_size=batch_size, num_workers=num_workers)):
text = out['text']
if batch_size in [0, 1, None]:
text = text[0]
segments.append(
{
"text": text,
"start": round(vad_segments[idx]['start'], 3),
"end": round(vad_segments[idx]['end'], 3)
}
)
if del_tokenizer:
self.tokenizer = None
return {"segments": segments, "language": language}
def detect_language(self, audio: np.ndarray):
if audio.shape[0] < N_SAMPLES:
print("Warning: audio is shorter than 30s, language detection may be inaccurate.")
segment = log_mel_spectrogram(audio[: N_SAMPLES],
padding=0 if audio.shape[0] >= N_SAMPLES else N_SAMPLES - audio.shape[0])
encoder_output = self.model.encode(segment)
results = self.model.model.detect_language(encoder_output)
language_token, language_probability = results[0][0]
language = language_token[2:-2]
print(f"Detected language: {language} ({language_probability:.2f}) in first 30s of audio...")
return language

File diff suppressed because it is too large Load Diff

View File

@ -1 +0,0 @@
{"bos_token": "<|endoftext|>", "eos_token": "<|endoftext|>", "unk_token": "<|endoftext|>"}

View File

@ -1 +0,0 @@
{"unk_token": "<|endoftext|>", "bos_token": "<|endoftext|>", "eos_token": "<|endoftext|>", "add_prefix_space": false, "model_max_length": 1024, "special_tokens_map_file": null, "name_or_path": "gpt2", "tokenizer_class": "GPT2Tokenizer"}

File diff suppressed because one or more lines are too long

0
whisperx/assets/mel_filters.npz Executable file → Normal file
View File

View File

@ -1 +0,0 @@
{"<|endoftext|>": 50257}

File diff suppressed because it is too large Load Diff

View File

@ -1 +0,0 @@
{"bos_token": "<|endoftext|>", "eos_token": "<|endoftext|>", "unk_token": "<|endoftext|>"}

View File

@ -1 +0,0 @@
{"unk_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "bos_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "eos_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "add_prefix_space": false, "model_max_length": 1024, "special_tokens_map_file": null, "name_or_path": "multilingual", "errors": "replace", "tokenizer_class": "GPT2Tokenizer"}

File diff suppressed because one or more lines are too long

View File

@ -1,6 +1,6 @@
import os
from functools import lru_cache
from typing import Union
from typing import Optional, Union
import ffmpeg
import numpy as np
@ -15,8 +15,12 @@ N_FFT = 400
N_MELS = 80
HOP_LENGTH = 160
CHUNK_LENGTH = 30
N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000: number of samples in a chunk
N_FRAMES = exact_div(N_SAMPLES, HOP_LENGTH) # 3000: number of frames in a mel spectrogram input
N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000 samples in a 30-second chunk
N_FRAMES = exact_div(N_SAMPLES, HOP_LENGTH) # 3000 frames in a mel spectrogram input
N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2 # the initial convolutions has stride 2
FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH) # 10ms per audio frame
TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN) # 20ms per audio token
def load_audio(file: str, sr: int = SAMPLE_RATE):
@ -55,7 +59,9 @@ def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
"""
if torch.is_tensor(array):
if array.shape[axis] > length:
array = array.index_select(dim=axis, index=torch.arange(length, device=array.device))
array = array.index_select(
dim=axis, index=torch.arange(length, device=array.device)
)
if array.shape[axis] < length:
pad_widths = [(0, 0)] * array.ndim
@ -85,11 +91,18 @@ def mel_filters(device, n_mels: int = N_MELS) -> torch.Tensor:
)
"""
assert n_mels == 80, f"Unsupported n_mels: {n_mels}"
with np.load(os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz")) as f:
with np.load(
os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz")
) as f:
return torch.from_numpy(f[f"mel_{n_mels}"]).to(device)
def log_mel_spectrogram(audio: Union[str, np.ndarray, torch.Tensor], n_mels: int = N_MELS):
def log_mel_spectrogram(
audio: Union[str, np.ndarray, torch.Tensor],
n_mels: int = N_MELS,
padding: int = 0,
device: Optional[Union[str, torch.device]] = None,
):
"""
Compute the log-Mel spectrogram of
@ -101,6 +114,12 @@ def log_mel_spectrogram(audio: Union[str, np.ndarray, torch.Tensor], n_mels: int
n_mels: int
The number of Mel-frequency filters, only 80 is supported
padding: int
Number of zero samples to pad to the right
device: Optional[Union[str, torch.device]]
If given, the audio tensor is moved to this device before STFT
Returns
-------
torch.Tensor, shape = (80, n_frames)
@ -111,6 +130,10 @@ def log_mel_spectrogram(audio: Union[str, np.ndarray, torch.Tensor], n_mels: int
audio = load_audio(audio)
audio = torch.from_numpy(audio)
if device is not None:
audio = audio.to(device)
if padding > 0:
audio = F.pad(audio, (0, padding))
window = torch.hann_window(N_FFT).to(audio.device)
stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True)
magnitudes = stft[..., :-1].abs() ** 2
@ -121,4 +144,4 @@ def log_mel_spectrogram(audio: Union[str, np.ndarray, torch.Tensor], n_mels: int
log_spec = torch.clamp(mel_spec, min=1e-10).log10()
log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
log_spec = (log_spec + 4.0) / 4.0
return log_spec
return log_spec

View File

@ -1,710 +0,0 @@
from dataclasses import dataclass, field
from typing import Dict, List, Tuple, Iterable, Optional, Sequence, Union, TYPE_CHECKING
import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor
from torch.distributions import Categorical
from .audio import CHUNK_LENGTH
from .tokenizer import Tokenizer, get_tokenizer
from .utils import compression_ratio
if TYPE_CHECKING:
from .model import Whisper
@torch.no_grad()
def detect_language(model: "Whisper", mel: Tensor, tokenizer: Tokenizer = None) -> Tuple[Tensor, List[dict]]:
"""
Detect the spoken language in the audio, and return them as list of strings, along with the ids
of the most probable language tokens and the probability distribution over all language tokens.
This is performed outside the main decode loop in order to not interfere with kv-caching.
Returns
-------
language_tokens : Tensor, shape = (n_audio,)
ids of the most probable language tokens, which appears after the startoftranscript token.
language_probs : List[Dict[str, float]], length = n_audio
list of dictionaries containing the probability distribution over all languages.
"""
if tokenizer is None:
tokenizer = get_tokenizer(model.is_multilingual)
if tokenizer.language is None or tokenizer.language_token not in tokenizer.sot_sequence:
raise ValueError(f"This model doesn't have language tokens so it can't perform lang id")
single = mel.ndim == 2
if single:
mel = mel.unsqueeze(0)
# skip encoder forward pass if already-encoded audio features were given
if mel.shape[-2:] != (model.dims.n_audio_ctx, model.dims.n_audio_state):
mel = model.encoder(mel)
# forward pass using a single token, startoftranscript
n_audio = mel.shape[0]
x = torch.tensor([[tokenizer.sot]] * n_audio).to(mel.device) # [n_audio, 1]
logits = model.logits(x, mel)[:, 0]
# collect detected languages; suppress all non-language tokens
mask = torch.ones(logits.shape[-1], dtype=torch.bool)
mask[list(tokenizer.all_language_tokens)] = False
logits[:, mask] = -np.inf
language_tokens = logits.argmax(dim=-1)
language_token_probs = logits.softmax(dim=-1).cpu()
language_probs = [
{
c: language_token_probs[i, j].item()
for j, c in zip(tokenizer.all_language_tokens, tokenizer.all_language_codes)
}
for i in range(n_audio)
]
if single:
language_tokens = language_tokens[0]
language_probs = language_probs[0]
return language_tokens, language_probs
@dataclass(frozen=True)
class DecodingOptions:
task: str = "transcribe" # whether to perform X->X "transcribe" or X->English "translate"
language: Optional[str] = None # language that the audio is in; uses detected language if None
# sampling-related options
temperature: float = 0.0
sample_len: Optional[int] = None # maximum number of tokens to sample
best_of: Optional[int] = None # number of independent samples to collect, when t > 0
beam_size: Optional[int] = None # number of beams in beam search, when t == 0
patience: Optional[float] = None # patience in beam search (https://arxiv.org/abs/2204.05424)
# options for ranking generations (either beams or best-of-N samples)
length_penalty: Optional[float] = None # "alpha" in Google NMT, None defaults to length norm
# prompt, prefix, and token suppression
prompt: Optional[Union[str, List[int]]] = None # text or tokens for the previous context
prefix: Optional[Union[str, List[int]]] = None # text or tokens to prefix the current context
suppress_blank: bool = True # this will suppress blank outputs
# list of tokens ids (or comma-separated token ids) to suppress
# "-1" will suppress a set of symbols as defined in `tokenizer.non_speech_tokens()`
suppress_tokens: Optional[Union[str, Iterable[int]]] = "-1"
# timestamp sampling options
without_timestamps: bool = False # use <|notimestamps|> to sample text tokens only
max_initial_timestamp: Optional[float] = 1.0 # the initial timestamp cannot be later than this
# implementation details
fp16: bool = True # use fp16 for most of the calculation
@dataclass(frozen=True)
class DecodingResult:
audio_features: Tensor
language: str
language_probs: Optional[Dict[str, float]] = None
tokens: List[int] = field(default_factory=list)
text: str = ""
avg_logprob: float = np.nan
no_speech_prob: float = np.nan
temperature: float = np.nan
compression_ratio: float = np.nan
class Inference:
def logits(self, tokens: Tensor, audio_features: Tensor) -> Tensor:
"""Perform a forward pass on the decoder and return per-token logits"""
raise NotImplementedError
def rearrange_kv_cache(self, source_indices) -> None:
"""Update the key-value cache according to the updated beams"""
raise NotImplementedError
def cleanup_caching(self) -> None:
"""Clean up any resources or hooks after decoding is finished"""
pass
class PyTorchInference(Inference):
def __init__(self, model: "Whisper", initial_token_length: int):
self.model: "Whisper" = model
self.initial_token_length = initial_token_length
self.kv_cache = {}
self.hooks = []
def logits(self, tokens: Tensor, audio_features: Tensor) -> Tensor:
if not self.kv_cache:
self.kv_cache, self.hooks = self.model.install_kv_cache_hooks()
if tokens.shape[-1] > self.initial_token_length:
# only need to use the last token except in the first forward pass
tokens = tokens[:, -1:]
return self.model.decoder(tokens, audio_features, kv_cache=self.kv_cache)
def cleanup_caching(self):
for hook in self.hooks:
hook.remove()
self.kv_cache = {}
self.hooks = []
def rearrange_kv_cache(self, source_indices):
for module, tensor in self.kv_cache.items():
# update the key/value cache to contain the selected sequences
self.kv_cache[module] = tensor[source_indices].detach()
class SequenceRanker:
def rank(self, tokens: List[List[Tensor]], sum_logprobs: List[List[float]]) -> List[int]:
"""
Given a list of groups of samples and their cumulative log probabilities,
return the indices of the samples in each group to select as the final result
"""
raise NotImplementedError
class MaximumLikelihoodRanker(SequenceRanker):
"""
Select the sample with the highest log probabilities, penalized using either
a simple length normalization or Google NMT paper's length penalty
"""
def __init__(self, length_penalty: Optional[float]):
self.length_penalty = length_penalty
def rank(self, tokens: List[List[Tensor]], sum_logprobs: List[List[float]]):
def scores(logprobs, lengths):
result = []
for logprob, length in zip(logprobs, lengths):
if self.length_penalty is None:
penalty = length
else:
# from the Google NMT paper
penalty = ((5 + length) / 6) ** self.length_penalty
result.append(logprob / penalty)
return result
# get the sequence with the highest score
lengths = [[len(t) for t in s] for s in tokens]
return [np.argmax(scores(p, l)) for p, l in zip(sum_logprobs, lengths)]
class TokenDecoder:
def reset(self):
"""Initialize any stateful variables for decoding a new sequence"""
def update(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor) -> Tuple[Tensor, bool]:
"""Specify how to select the next token, based on the current trace and logits
Parameters
----------
tokens : Tensor, shape = (n_batch, current_sequence_length)
all tokens in the context so far, including the prefix and sot_sequence tokens
logits : Tensor, shape = (n_batch, vocab_size)
per-token logits of the probability distribution at the current step
sum_logprobs : Tensor, shape = (n_batch)
cumulative log probabilities for each sequence
Returns
-------
tokens : Tensor, shape = (n_batch, current_sequence_length + 1)
the tokens, appended with the selected next token
completed : bool
True if all sequences has reached the end of text
"""
raise NotImplementedError
def finalize(
self, tokens: Tensor, sum_logprobs: Tensor
) -> Tuple[Sequence[Sequence[Tensor]], List[List[float]]]:
"""Finalize search and return the final candidate sequences
Parameters
----------
tokens : Tensor, shape = (n_audio, n_group, current_sequence_length)
all tokens in the context so far, including the prefix and sot_sequence
sum_logprobs : Tensor, shape = (n_audio, n_group)
cumulative log probabilities for each sequence
Returns
-------
tokens : Sequence[Sequence[Tensor]], length = n_audio
sequence of Tensors containing candidate token sequences, for each audio input
sum_logprobs : List[List[float]], length = n_audio
sequence of cumulative log probabilities corresponding to the above
"""
raise NotImplementedError
class GreedyDecoder(TokenDecoder):
def __init__(self, temperature: float, eot: int):
self.temperature = temperature
self.eot = eot
def update(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor) -> Tuple[Tensor, bool]:
temperature = self.temperature
if temperature == 0:
next_tokens = logits.argmax(dim=-1)
else:
next_tokens = Categorical(logits=logits / temperature).sample()
logprobs = F.log_softmax(logits.float(), dim=-1)
current_logprobs = logprobs[torch.arange(logprobs.shape[0]), next_tokens]
sum_logprobs += current_logprobs * (tokens[:, -1] != self.eot)
next_tokens[tokens[:, -1] == self.eot] = self.eot
tokens = torch.cat([tokens, next_tokens[:, None]], dim=-1)
completed = (tokens[:, -1] == self.eot).all()
return tokens, completed
def finalize(self, tokens: Tensor, sum_logprobs: Tensor):
# make sure each sequence has at least one EOT token at the end
tokens = F.pad(tokens, (0, 1), value=self.eot)
return tokens, sum_logprobs.tolist()
class BeamSearchDecoder(TokenDecoder):
def __init__(self, beam_size: int, eot: int, inference: Inference, patience: Optional[float] = None):
self.beam_size = beam_size
self.eot = eot
self.inference = inference
self.patience = patience or 1.0
self.max_candidates: int = round(beam_size * self.patience)
self.finished_sequences = None
assert self.max_candidates > 0, f"Invalid beam size ({beam_size}) or patience ({patience})"
def reset(self):
self.finished_sequences = None
def update(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor) -> Tuple[Tensor, bool]:
if tokens.shape[0] % self.beam_size != 0:
raise ValueError(f"{tokens.shape}[0] % {self.beam_size} != 0")
n_audio = tokens.shape[0] // self.beam_size
if self.finished_sequences is None: # for the first update
self.finished_sequences = [{} for _ in range(n_audio)]
logprobs = F.log_softmax(logits.float(), dim=-1)
next_tokens, source_indices, finished_sequences = [], [], []
for i in range(n_audio):
scores, sources, finished = {}, {}, {}
# STEP 1: calculate the cumulative log probabilities for possible candidates
for j in range(self.beam_size):
idx = i * self.beam_size + j
prefix = tokens[idx].tolist()
for logprob, token in zip(*logprobs[idx].topk(self.beam_size + 1)):
new_logprob = (sum_logprobs[idx] + logprob).item()
sequence = tuple(prefix + [token.item()])
scores[sequence] = new_logprob
sources[sequence] = idx
# STEP 2: rank the candidates and keep the top beam_size sequences for each audio
saved = 0
for sequence in sorted(scores, key=scores.get, reverse=True):
if sequence[-1] == self.eot:
finished[sequence] = scores[sequence]
else:
sum_logprobs[len(next_tokens)] = scores[sequence]
next_tokens.append(sequence)
source_indices.append(sources[sequence])
saved += 1
if saved == self.beam_size:
break
finished_sequences.append(finished)
tokens = torch.tensor(next_tokens, device=tokens.device)
self.inference.rearrange_kv_cache(source_indices)
# add newly finished sequences to self.finished_sequences
assert len(self.finished_sequences) == len(finished_sequences)
for previously_finished, newly_finished in zip(self.finished_sequences, finished_sequences):
for seq in sorted(newly_finished, key=newly_finished.get, reverse=True):
if len(previously_finished) >= self.max_candidates:
break # the candidate list is full
previously_finished[seq] = newly_finished[seq]
# mark as completed if all audio has enough number of samples
completed = all(
len(sequences) >= self.max_candidates for sequences in self.finished_sequences
)
return tokens, completed
def finalize(self, preceding_tokens: Tensor, sum_logprobs: Tensor):
# collect all finished sequences, including patience, and add unfinished ones if not enough
sum_logprobs = sum_logprobs.cpu()
for i, sequences in enumerate(self.finished_sequences):
if len(sequences) < self.beam_size: # when not enough sequences are finished
for j in list(np.argsort(sum_logprobs[i]))[::-1]:
sequence = preceding_tokens[i, j].tolist() + [self.eot]
sequences[tuple(sequence)] = sum_logprobs[i][j].item()
if len(sequences) >= self.beam_size:
break
tokens: List[List[Tensor]] = [
[torch.tensor(seq) for seq in sequences.keys()] for sequences in self.finished_sequences
]
sum_logprobs: List[List[float]] = [
list(sequences.values()) for sequences in self.finished_sequences
]
return tokens, sum_logprobs
class LogitFilter:
def apply(self, logits: Tensor, tokens: Tensor) -> None:
"""Apply any filtering or masking to logits in-place
Parameters
----------
logits : Tensor, shape = (n_batch, vocab_size)
per-token logits of the probability distribution at the current step
tokens : Tensor, shape = (n_batch, current_sequence_length)
all tokens in the context so far, including the prefix and sot_sequence tokens
"""
raise NotImplementedError
class SuppressBlank(LogitFilter):
def __init__(self, tokenizer: Tokenizer, sample_begin: int):
self.tokenizer = tokenizer
self.sample_begin = sample_begin
def apply(self, logits: Tensor, tokens: Tensor):
if tokens.shape[1] == self.sample_begin:
logits[:, self.tokenizer.encode(" ") + [self.tokenizer.eot]] = -np.inf
class SuppressTokens(LogitFilter):
def __init__(self, suppress_tokens: Sequence[int]):
self.suppress_tokens = list(suppress_tokens)
def apply(self, logits: Tensor, tokens: Tensor):
logits[:, self.suppress_tokens] = -np.inf
class ApplyTimestampRules(LogitFilter):
def __init__(
self, tokenizer: Tokenizer, sample_begin: int, max_initial_timestamp_index: Optional[int]
):
self.tokenizer = tokenizer
self.sample_begin = sample_begin
self.max_initial_timestamp_index = max_initial_timestamp_index
def apply(self, logits: Tensor, tokens: Tensor):
# suppress <|notimestamps|> which is handled by without_timestamps
if self.tokenizer.no_timestamps is not None:
logits[:, self.tokenizer.no_timestamps] = -np.inf
# timestamps have to appear in pairs, except directly before EOT; mask logits accordingly
for k in range(tokens.shape[0]):
seq = [t for t in tokens[k, self.sample_begin :].tolist()]
last_was_timestamp = len(seq) >= 1 and seq[-1] >= self.tokenizer.timestamp_begin
penultimate_was_timestamp = len(seq) < 2 or seq[-2] >= self.tokenizer.timestamp_begin
if last_was_timestamp:
if penultimate_was_timestamp: # has to be non-timestamp
logits[k, self.tokenizer.timestamp_begin :] = -np.inf
else: # cannot be normal text tokens
logits[k, : self.tokenizer.eot] = -np.inf
if tokens.shape[1] == self.sample_begin:
# suppress generating non-timestamp tokens at the beginning
logits[:, : self.tokenizer.timestamp_begin] = -np.inf
# apply the `max_initial_timestamp` option
if self.max_initial_timestamp_index is not None:
last_allowed = self.tokenizer.timestamp_begin + self.max_initial_timestamp_index
logits[:, last_allowed + 1 :] = -np.inf
# if sum of probability over timestamps is above any other token, sample timestamp
logprobs = F.log_softmax(logits.float(), dim=-1)
for k in range(tokens.shape[0]):
timestamp_logprob = logprobs[k, self.tokenizer.timestamp_begin :].logsumexp(dim=-1)
max_text_token_logprob = logprobs[k, : self.tokenizer.timestamp_begin].max()
if timestamp_logprob > max_text_token_logprob:
logits[k, : self.tokenizer.timestamp_begin] = -np.inf
class DecodingTask:
inference: Inference
sequence_ranker: SequenceRanker
decoder: TokenDecoder
logit_filters: List[LogitFilter]
def __init__(self, model: "Whisper", options: DecodingOptions):
self.model = model
language = options.language or "en"
tokenizer = get_tokenizer(model.is_multilingual, language=language, task=options.task)
self.tokenizer: Tokenizer = tokenizer
self.options: DecodingOptions = self._verify_options(options)
self.n_group: int = options.beam_size or options.best_of or 1
self.n_ctx: int = model.dims.n_text_ctx
self.sample_len: int = options.sample_len or model.dims.n_text_ctx // 2
self.sot_sequence: Tuple[int] = tokenizer.sot_sequence
if self.options.without_timestamps:
self.sot_sequence = tokenizer.sot_sequence_including_notimestamps
self.initial_tokens: Tuple[int] = self._get_initial_tokens()
self.sample_begin: int = len(self.initial_tokens)
self.sot_index: int = self.initial_tokens.index(tokenizer.sot)
# inference: implements the forward pass through the decoder, including kv caching
self.inference = PyTorchInference(model, len(self.initial_tokens))
# sequence ranker: implements how to rank a group of sampled sequences
self.sequence_ranker = MaximumLikelihoodRanker(options.length_penalty)
# decoder: implements how to select the next tokens, given the autoregressive distribution
if options.beam_size is not None:
self.decoder = BeamSearchDecoder(
options.beam_size, tokenizer.eot, self.inference, options.patience
)
else:
self.decoder = GreedyDecoder(options.temperature, tokenizer.eot)
# logit filters: applies various rules to suppress or penalize certain tokens
self.logit_filters = []
if self.options.suppress_blank:
self.logit_filters.append(SuppressBlank(self.tokenizer, self.sample_begin))
if self.options.suppress_tokens:
self.logit_filters.append(SuppressTokens(self._get_suppress_tokens()))
if not options.without_timestamps:
precision = CHUNK_LENGTH / model.dims.n_audio_ctx # usually 0.02 seconds
max_initial_timestamp_index = None
if options.max_initial_timestamp:
max_initial_timestamp_index = round(self.options.max_initial_timestamp / precision)
self.logit_filters.append(
ApplyTimestampRules(tokenizer, self.sample_begin, max_initial_timestamp_index)
)
def _verify_options(self, options: DecodingOptions) -> DecodingOptions:
if options.beam_size is not None and options.best_of is not None:
raise ValueError("beam_size and best_of can't be given together")
if options.temperature == 0:
if options.best_of is not None:
raise ValueError("best_of with greedy sampling (T=0) is not compatible")
if options.patience is not None and options.beam_size is None:
raise ValueError("patience requires beam_size to be given")
if options.length_penalty is not None and not (0 <= options.length_penalty <= 1):
raise ValueError("length_penalty (alpha) should be a value between 0 and 1")
return options
def _get_initial_tokens(self) -> Tuple[int]:
tokens = list(self.sot_sequence)
prefix = self.options.prefix
prompt = self.options.prompt
if prefix:
prefix_tokens = (
self.tokenizer.encode(" " + prefix.strip()) if isinstance(prefix, str) else prefix
)
if self.sample_len is not None:
max_prefix_len = self.n_ctx // 2 - self.sample_len
prefix_tokens = prefix_tokens[-max_prefix_len:]
tokens = tokens + prefix_tokens
if prompt:
prompt_tokens = (
self.tokenizer.encode(" " + prompt.strip()) if isinstance(prompt, str) else prompt
)
tokens = [self.tokenizer.sot_prev] + prompt_tokens[-(self.n_ctx // 2 - 1) :] + tokens
return tuple(tokens)
def _get_suppress_tokens(self) -> Tuple[int]:
suppress_tokens = self.options.suppress_tokens
if isinstance(suppress_tokens, str):
suppress_tokens = [int(t) for t in suppress_tokens.split(",")]
if -1 in suppress_tokens:
suppress_tokens = [t for t in suppress_tokens if t >= 0]
suppress_tokens.extend(self.tokenizer.non_speech_tokens)
elif suppress_tokens is None or len(suppress_tokens) == 0:
suppress_tokens = [] # interpret empty string as an empty list
else:
assert isinstance(suppress_tokens, list), "suppress_tokens must be a list"
suppress_tokens.extend(
[self.tokenizer.sot, self.tokenizer.sot_prev, self.tokenizer.sot_lm]
)
if self.tokenizer.no_speech is not None:
# no-speech probability is collected separately
suppress_tokens.append(self.tokenizer.no_speech)
return tuple(sorted(set(suppress_tokens)))
def _get_audio_features(self, mel: Tensor):
if self.options.fp16:
mel = mel.half()
if mel.shape[-2:] == (self.model.dims.n_audio_ctx, self.model.dims.n_audio_state):
# encoded audio features are given; skip audio encoding
audio_features = mel
else:
audio_features = self.model.encoder(mel)
if audio_features.dtype != (torch.float16 if self.options.fp16 else torch.float32):
return TypeError(f"audio_features has an incorrect dtype: {audio_features.dtype}")
return audio_features
def _detect_language(self, audio_features: Tensor, tokens: Tensor):
languages = [self.options.language] * audio_features.shape[0]
lang_probs = None
if self.options.language is None or self.options.task == "lang_id":
lang_tokens, lang_probs = self.model.detect_language(audio_features, self.tokenizer)
languages = [max(probs, key=probs.get) for probs in lang_probs]
if self.options.language is None:
tokens[:, self.sot_index + 1] = lang_tokens # write language tokens
return languages, lang_probs
def _main_loop(self, audio_features: Tensor, tokens: Tensor):
assert audio_features.shape[0] == tokens.shape[0]
n_batch = tokens.shape[0]
sum_logprobs: Tensor = torch.zeros(n_batch, device=audio_features.device)
no_speech_probs = [np.nan] * n_batch
try:
for i in range(self.sample_len):
logits = self.inference.logits(tokens, audio_features)
if i == 0 and self.tokenizer.no_speech is not None: # save no_speech_probs
probs_at_sot = logits[:, self.sot_index].float().softmax(dim=-1)
no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist()
# now we need to consider the logits at the last token only
logits = logits[:, -1]
# apply the logit filters, e.g. for suppressing or applying penalty to
for logit_filter in self.logit_filters:
logit_filter.apply(logits, tokens)
# expand the tokens tensor with the selected next tokens
tokens, completed = self.decoder.update(tokens, logits, sum_logprobs)
if completed or tokens.shape[-1] > self.n_ctx:
break
finally:
self.inference.cleanup_caching()
return tokens, sum_logprobs, no_speech_probs
@torch.no_grad()
def run(self, mel: Tensor) -> List[DecodingResult]:
self.decoder.reset()
tokenizer: Tokenizer = self.tokenizer
n_audio: int = mel.shape[0]
audio_features: Tensor = self._get_audio_features(mel) # encoder forward pass
tokens: Tensor = torch.tensor([self.initial_tokens]).repeat(n_audio, 1)
# detect language if requested, overwriting the language token
languages, language_probs = self._detect_language(audio_features, tokens)
if self.options.task == "lang_id":
return [
DecodingResult(audio_features=features, language=language, language_probs=probs)
for features, language, probs in zip(audio_features, languages, language_probs)
]
# repeat the audio & text tensors by the group size, for beam search or best-of-n sampling
audio_features = audio_features.repeat_interleave(self.n_group, dim=0)
tokens = tokens.repeat_interleave(self.n_group, dim=0).to(audio_features.device)
# call the main sampling loop
tokens, sum_logprobs, no_speech_probs = self._main_loop(audio_features, tokens)
# reshape the tensors to have (n_audio, n_group) as the first two dimensions
audio_features = audio_features[:: self.n_group]
no_speech_probs = no_speech_probs[:: self.n_group]
assert audio_features.shape[0] == len(no_speech_probs) == n_audio
tokens = tokens.reshape(n_audio, self.n_group, -1)
sum_logprobs = sum_logprobs.reshape(n_audio, self.n_group)
# get the final candidates for each group, and slice between the first sampled token and EOT
tokens, sum_logprobs = self.decoder.finalize(tokens, sum_logprobs)
tokens: List[List[Tensor]] = [
[t[self.sample_begin : (t == tokenizer.eot).nonzero()[0, 0]] for t in s] for s in tokens
]
# select the top-ranked sample in each group
selected = self.sequence_ranker.rank(tokens, sum_logprobs)
tokens: List[List[int]] = [t[i].tolist() for i, t in zip(selected, tokens)]
texts: List[str] = [tokenizer.decode(t).strip() for t in tokens]
sum_logprobs: List[float] = [lp[i] for i, lp in zip(selected, sum_logprobs)]
avg_logprobs: List[float] = [lp / (len(t) + 1) for t, lp in zip(tokens, sum_logprobs)]
fields = (texts, languages, tokens, audio_features, avg_logprobs, no_speech_probs)
if len(set(map(len, fields))) != 1:
raise RuntimeError(f"inconsistent result lengths: {list(map(len, fields))}")
return [
DecodingResult(
audio_features=features,
language=language,
tokens=tokens,
text=text,
avg_logprob=avg_logprob,
no_speech_prob=no_speech_prob,
temperature=self.options.temperature,
compression_ratio=compression_ratio(text),
)
for text, language, tokens, features, avg_logprob, no_speech_prob in zip(*fields)
]
@torch.no_grad()
def decode(model: "Whisper", mel: Tensor, options: DecodingOptions = DecodingOptions()) -> Union[DecodingResult, List[DecodingResult]]:
"""
Performs decoding of 30-second audio segment(s), provided as Mel spectrogram(s).
Parameters
----------
model: Whisper
the Whisper model instance
mel: torch.Tensor, shape = (80, 3000) or (*, 80, 3000)
A tensor containing the Mel spectrogram(s)
options: DecodingOptions
A dataclass that contains all necessary options for decoding 30-second segments
Returns
-------
result: Union[DecodingResult, List[DecodingResult]]
The result(s) of decoding contained in `DecodingResult` dataclass instance(s)
"""
single = mel.ndim == 2
if single:
mel = mel.unsqueeze(0)
result = DecodingTask(model, options).run(mel)
if single:
result = result[0]
return result

View File

@ -1,53 +1,63 @@
import numpy as np
import pandas as pd
from pyannote.audio import Pipeline
from typing import Optional, Union
import torch
def assign_word_speakers(diarize_df, result_segments, fill_nearest=False):
for seg in result_segments:
wdf = seg['word-segments']
if len(wdf['start'].dropna()) == 0:
wdf['start'] = seg['start']
wdf['end'] = seg['end']
speakers = []
for wdx, wrow in wdf.iterrows():
if not np.isnan(wrow['start']):
diarize_df['intersection'] = np.minimum(diarize_df['end'], wrow['end']) - np.maximum(diarize_df['start'], wrow['start'])
diarize_df['union'] = np.maximum(diarize_df['end'], wrow['end']) - np.minimum(diarize_df['start'], wrow['start'])
# remove no hit
if not fill_nearest:
dia_tmp = diarize_df[diarize_df['intersection'] > 0]
else:
dia_tmp = diarize_df
if len(dia_tmp) == 0:
speaker = None
else:
speaker = dia_tmp.sort_values("intersection", ascending=False).iloc[0][2]
else:
speaker = None
speakers.append(speaker)
seg['word-segments']['speaker'] = speakers
seg["speaker"] = pd.Series(speakers).value_counts().index[0]
class DiarizationPipeline:
def __init__(
self,
model_name="pyannote/speaker-diarization@2.1",
use_auth_token=None,
device: Optional[Union[str, torch.device]] = "cpu",
):
if isinstance(device, str):
device = torch.device(device)
self.model = Pipeline.from_pretrained(model_name, use_auth_token=use_auth_token).to(device)
# create word level segments for .srt
word_seg = []
for seg in result_segments:
wseg = pd.DataFrame(seg["word-segments"])
for wdx, wrow in wseg.iterrows():
if wrow["start"] is not None:
speaker = wrow['speaker']
if speaker is None or speaker == np.nan:
speaker = "UNKNOWN"
word_seg.append(
{
"start": wrow["start"],
"end": wrow["end"],
"text": f"[{speaker}]: " + seg["text"][int(wrow["segment-text-start"]):int(wrow["segment-text-end"])]
}
)
def __call__(self, audio, min_speakers=None, max_speakers=None):
segments = self.model(audio, min_speakers=min_speakers, max_speakers=max_speakers)
diarize_df = pd.DataFrame(segments.itertracks(yield_label=True))
diarize_df['start'] = diarize_df[0].apply(lambda x: x.start)
diarize_df['end'] = diarize_df[0].apply(lambda x: x.end)
diarize_df.rename(columns={2: "speaker"}, inplace=True)
return diarize_df
# TODO: create segments but split words on new speaker
return result_segments, word_seg
def assign_word_speakers(diarize_df, transcript_result, fill_nearest=False):
transcript_segments = transcript_result["segments"]
for seg in transcript_segments:
# assign speaker to segment (if any)
diarize_df['intersection'] = np.minimum(diarize_df['end'], seg['end']) - np.maximum(diarize_df['start'], seg['start'])
diarize_df['union'] = np.maximum(diarize_df['end'], seg['end']) - np.minimum(diarize_df['start'], seg['start'])
# remove no hit, otherwise we look for closest (even negative intersection...)
if not fill_nearest:
dia_tmp = diarize_df[diarize_df['intersection'] > 0]
else:
dia_tmp = diarize_df
if len(dia_tmp) > 0:
# sum over speakers
speaker = dia_tmp.groupby("speaker")["intersection"].sum().sort_values(ascending=False).index[0]
seg["speaker"] = speaker
# assign speaker to words
if 'words' in seg:
for word in seg['words']:
if 'start' in word:
diarize_df['intersection'] = np.minimum(diarize_df['end'], word['end']) - np.maximum(diarize_df['start'], word['start'])
diarize_df['union'] = np.maximum(diarize_df['end'], word['end']) - np.minimum(diarize_df['start'], word['start'])
# remove no hit
if not fill_nearest:
dia_tmp = diarize_df[diarize_df['intersection'] > 0]
else:
dia_tmp = diarize_df
if len(dia_tmp) > 0:
# sum over speakers
speaker = dia_tmp.groupby("speaker")["intersection"].sum().sort_values(ascending=False).index[0]
word["speaker"] = speaker
return transcript_result
class Segment:
def __init__(self, start, end, speaker=None):

View File

@ -1,268 +0,0 @@
from dataclasses import dataclass
from typing import Dict
from typing import Iterable, Optional
import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor
from torch import nn
from .transcribe import transcribe as transcribe_function
from .decoding import detect_language as detect_language_function, decode as decode_function
@dataclass
class ModelDimensions:
n_mels: int
n_audio_ctx: int
n_audio_state: int
n_audio_head: int
n_audio_layer: int
n_vocab: int
n_text_ctx: int
n_text_state: int
n_text_head: int
n_text_layer: int
class LayerNorm(nn.LayerNorm):
def forward(self, x: Tensor) -> Tensor:
return super().forward(x.float()).type(x.dtype)
class Linear(nn.Linear):
def forward(self, x: Tensor) -> Tensor:
return F.linear(
x, self.weight.to(x.dtype), None if self.bias is None else self.bias.to(x.dtype)
)
class Conv1d(nn.Conv1d):
def _conv_forward(self, x: Tensor, weight: Tensor, bias: Optional[Tensor]) -> Tensor:
return super()._conv_forward(
x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype)
)
def sinusoids(length, channels, max_timescale=10000):
"""Returns sinusoids for positional embedding"""
assert channels % 2 == 0
log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
class MultiHeadAttention(nn.Module):
def __init__(self, n_state: int, n_head: int):
super().__init__()
self.n_head = n_head
self.query = Linear(n_state, n_state)
self.key = Linear(n_state, n_state, bias=False)
self.value = Linear(n_state, n_state)
self.out = Linear(n_state, n_state)
def forward(
self,
x: Tensor,
xa: Optional[Tensor] = None,
mask: Optional[Tensor] = None,
kv_cache: Optional[dict] = None,
):
q = self.query(x)
if kv_cache is None or xa is None or self.key not in kv_cache:
# hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors;
# otherwise, perform key/value projections for self- or cross-attention as usual.
k = self.key(x if xa is None else xa)
v = self.value(x if xa is None else xa)
else:
# for cross-attention, calculate keys and values once and reuse in subsequent calls.
k = kv_cache[self.key]
v = kv_cache[self.value]
wv, qk = self.qkv_attention(q, k, v, mask)
return self.out(wv), qk
def qkv_attention(self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None):
n_batch, n_ctx, n_state = q.shape
scale = (n_state // self.n_head) ** -0.25
q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale
k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scale
v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
qk = q @ k
if mask is not None:
qk = qk + mask[:n_ctx, :n_ctx]
qk = qk.float()
w = F.softmax(qk, dim=-1).to(q.dtype)
return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach()
class ResidualAttentionBlock(nn.Module):
def __init__(self, n_state: int, n_head: int, cross_attention: bool = False):
super().__init__()
self.attn = MultiHeadAttention(n_state, n_head)
self.attn_ln = LayerNorm(n_state)
self.cross_attn = MultiHeadAttention(n_state, n_head) if cross_attention else None
self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None
n_mlp = n_state * 4
self.mlp = nn.Sequential(Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state))
self.mlp_ln = LayerNorm(n_state)
def forward(
self,
x: Tensor,
xa: Optional[Tensor] = None,
mask: Optional[Tensor] = None,
kv_cache: Optional[dict] = None,
):
x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)[0]
if self.cross_attn:
x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)[0]
x = x + self.mlp(self.mlp_ln(x))
return x
class AudioEncoder(nn.Module):
def __init__(self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int):
super().__init__()
self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1)
self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))
self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
[ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)]
)
self.ln_post = LayerNorm(n_state)
def forward(self, x: Tensor):
"""
x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
the mel spectrogram of the audio
"""
x = F.gelu(self.conv1(x))
x = F.gelu(self.conv2(x))
x = x.permute(0, 2, 1)
assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape"
x = (x + self.positional_embedding).to(x.dtype)
for block in self.blocks:
x = block(x)
x = self.ln_post(x)
return x
class TextDecoder(nn.Module):
def __init__(self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int):
super().__init__()
self.token_embedding = nn.Embedding(n_vocab, n_state)
self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state))
self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
[ResidualAttentionBlock(n_state, n_head, cross_attention=True) for _ in range(n_layer)]
)
self.ln = LayerNorm(n_state)
mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)
self.register_buffer("mask", mask, persistent=False)
def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
"""
x : torch.LongTensor, shape = (batch_size, <= n_ctx)
the text tokens
xa : torch.Tensor, shape = (batch_size, n_mels, n_audio_ctx)
the encoded audio features to be attended on
"""
offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
x = self.token_embedding(x) + self.positional_embedding[offset : offset + x.shape[-1]]
x = x.to(xa.dtype)
for block in self.blocks:
x = block(x, xa, mask=self.mask, kv_cache=kv_cache)
x = self.ln(x)
logits = (x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)).float()
return logits
class Whisper(nn.Module):
def __init__(self, dims: ModelDimensions):
super().__init__()
self.dims = dims
self.encoder = AudioEncoder(
self.dims.n_mels,
self.dims.n_audio_ctx,
self.dims.n_audio_state,
self.dims.n_audio_head,
self.dims.n_audio_layer,
)
self.decoder = TextDecoder(
self.dims.n_vocab,
self.dims.n_text_ctx,
self.dims.n_text_state,
self.dims.n_text_head,
self.dims.n_text_layer,
)
def embed_audio(self, mel: torch.Tensor):
return self.encoder(mel)
def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor):
return self.decoder(tokens, audio_features)
def forward(self, mel: torch.Tensor, tokens: torch.Tensor) -> Dict[str, torch.Tensor]:
return self.decoder(tokens, self.encoder(mel))
@property
def device(self):
return next(self.parameters()).device
@property
def is_multilingual(self):
return self.dims.n_vocab == 51865
def install_kv_cache_hooks(self, cache: Optional[dict] = None):
"""
The `MultiHeadAttention` module optionally accepts `kv_cache` which stores the key and value
tensors calculated for the previous positions. This method returns a dictionary that stores
all caches, and the necessary hooks for the key and value projection modules that save the
intermediate tensors to be reused during later calculations.
Returns
-------
cache : Dict[nn.Module, torch.Tensor]
A dictionary object mapping the key/value projection modules to its cache
hooks : List[RemovableHandle]
List of PyTorch RemovableHandle objects to stop the hooks to be called
"""
cache = {**cache} if cache is not None else {}
hooks = []
def save_to_cache(module, _, output):
if module not in cache or output.shape[1] > self.decoder.positional_embedding.shape[0]:
cache[module] = output # save as-is, for the first token or cross attention
else:
cache[module] = torch.cat([cache[module], output], dim=1).detach()
return cache[module]
def install_hooks(layer: nn.Module):
if isinstance(layer, MultiHeadAttention):
hooks.append(layer.key.register_forward_hook(save_to_cache))
hooks.append(layer.value.register_forward_hook(save_to_cache))
self.decoder.apply(install_hooks)
return cache, hooks
detect_language = detect_language_function
transcribe = transcribe_function
decode = decode_function

View File

@ -1,2 +0,0 @@
from .basic import BasicTextNormalizer
from .english import EnglishTextNormalizer

View File

@ -1,71 +0,0 @@
import re
import unicodedata
import regex
# non-ASCII letters that are not separated by "NFKD" normalization
ADDITIONAL_DIACRITICS = {
"œ": "oe",
"Œ": "OE",
"ø": "o",
"Ø": "O",
"æ": "ae",
"Æ": "AE",
"ß": "ss",
"": "SS",
"đ": "d",
"Đ": "D",
"ð": "d",
"Ð": "D",
"þ": "th",
"Þ": "th",
"ł": "l",
"Ł": "L",
}
def remove_symbols_and_diacritics(s: str, keep=""):
"""
Replace any other markers, symbols, and punctuations with a space,
and drop any diacritics (category 'Mn' and some manual mappings)
"""
return "".join(
c
if c in keep
else ADDITIONAL_DIACRITICS[c]
if c in ADDITIONAL_DIACRITICS
else ""
if unicodedata.category(c) == "Mn"
else " "
if unicodedata.category(c)[0] in "MSP"
else c
for c in unicodedata.normalize("NFKD", s)
)
def remove_symbols(s: str):
"""
Replace any other markers, symbols, punctuations with a space, keeping diacritics
"""
return "".join(
" " if unicodedata.category(c)[0] in "MSP" else c for c in unicodedata.normalize("NFKC", s)
)
class BasicTextNormalizer:
def __init__(self, remove_diacritics: bool = False, split_letters: bool = False):
self.clean = remove_symbols_and_diacritics if remove_diacritics else remove_symbols
self.split_letters = split_letters
def __call__(self, s: str):
s = s.lower()
s = re.sub(r"[<\[][^>\]]*[>\]]", "", s) # remove words between brackets
s = re.sub(r"\(([^)]+?)\)", "", s) # remove words between parenthesis
s = self.clean(s).lower()
if self.split_letters:
s = " ".join(regex.findall(r"\X", s, regex.U))
s = re.sub(r"\s+", " ", s) # replace any successive whitespace characters with a space
return s

File diff suppressed because it is too large Load Diff

View File

@ -1,543 +0,0 @@
import json
import os
import re
from fractions import Fraction
from typing import Iterator, List, Match, Optional, Union
from more_itertools import windowed
from .basic import remove_symbols_and_diacritics
class EnglishNumberNormalizer:
"""
Convert any spelled-out numbers into arabic numbers, while handling:
- remove any commas
- keep the suffixes such as: `1960s`, `274th`, `32nd`, etc.
- spell out currency symbols after the number. e.g. `$20 million` -> `20000000 dollars`
- spell out `one` and `ones`
- interpret successive single-digit numbers as nominal: `one oh one` -> `101`
"""
def __init__(self):
super().__init__()
self.zeros = {"o", "oh", "zero"}
self.ones = {
name: i
for i, name in enumerate(
[
"one",
"two",
"three",
"four",
"five",
"six",
"seven",
"eight",
"nine",
"ten",
"eleven",
"twelve",
"thirteen",
"fourteen",
"fifteen",
"sixteen",
"seventeen",
"eighteen",
"nineteen",
],
start=1,
)
}
self.ones_plural = {
"sixes" if name == "six" else name + "s": (value, "s")
for name, value in self.ones.items()
}
self.ones_ordinal = {
"zeroth": (0, "th"),
"first": (1, "st"),
"second": (2, "nd"),
"third": (3, "rd"),
"fifth": (5, "th"),
"twelfth": (12, "th"),
**{
name + ("h" if name.endswith("t") else "th"): (value, "th")
for name, value in self.ones.items()
if value > 3 and value != 5 and value != 12
},
}
self.ones_suffixed = {**self.ones_plural, **self.ones_ordinal}
self.tens = {
"twenty": 20,
"thirty": 30,
"forty": 40,
"fifty": 50,
"sixty": 60,
"seventy": 70,
"eighty": 80,
"ninety": 90,
}
self.tens_plural = {
name.replace("y", "ies"): (value, "s") for name, value in self.tens.items()
}
self.tens_ordinal = {
name.replace("y", "ieth"): (value, "th") for name, value in self.tens.items()
}
self.tens_suffixed = {**self.tens_plural, **self.tens_ordinal}
self.multipliers = {
"hundred": 100,
"thousand": 1_000,
"million": 1_000_000,
"billion": 1_000_000_000,
"trillion": 1_000_000_000_000,
"quadrillion": 1_000_000_000_000_000,
"quintillion": 1_000_000_000_000_000_000,
"sextillion": 1_000_000_000_000_000_000_000,
"septillion": 1_000_000_000_000_000_000_000_000,
"octillion": 1_000_000_000_000_000_000_000_000_000,
"nonillion": 1_000_000_000_000_000_000_000_000_000_000,
"decillion": 1_000_000_000_000_000_000_000_000_000_000_000,
}
self.multipliers_plural = {
name + "s": (value, "s") for name, value in self.multipliers.items()
}
self.multipliers_ordinal = {
name + "th": (value, "th") for name, value in self.multipliers.items()
}
self.multipliers_suffixed = {**self.multipliers_plural, **self.multipliers_ordinal}
self.decimals = {*self.ones, *self.tens, *self.zeros}
self.preceding_prefixers = {
"minus": "-",
"negative": "-",
"plus": "+",
"positive": "+",
}
self.following_prefixers = {
"pound": "£",
"pounds": "£",
"euro": "",
"euros": "",
"dollar": "$",
"dollars": "$",
"cent": "¢",
"cents": "¢",
}
self.prefixes = set(
list(self.preceding_prefixers.values()) + list(self.following_prefixers.values())
)
self.suffixers = {
"per": {"cent": "%"},
"percent": "%",
}
self.specials = {"and", "double", "triple", "point"}
self.words = set(
[
key
for mapping in [
self.zeros,
self.ones,
self.ones_suffixed,
self.tens,
self.tens_suffixed,
self.multipliers,
self.multipliers_suffixed,
self.preceding_prefixers,
self.following_prefixers,
self.suffixers,
self.specials,
]
for key in mapping
]
)
self.literal_words = {"one", "ones"}
def process_words(self, words: List[str]) -> Iterator[str]:
prefix: Optional[str] = None
value: Optional[Union[str, int]] = None
skip = False
def to_fraction(s: str):
try:
return Fraction(s)
except ValueError:
return None
def output(result: Union[str, int]):
nonlocal prefix, value
result = str(result)
if prefix is not None:
result = prefix + result
value = None
prefix = None
return result
if len(words) == 0:
return
for prev, current, next in windowed([None] + words + [None], 3):
if skip:
skip = False
continue
next_is_numeric = next is not None and re.match(r"^\d+(\.\d+)?$", next)
has_prefix = current[0] in self.prefixes
current_without_prefix = current[1:] if has_prefix else current
if re.match(r"^\d+(\.\d+)?$", current_without_prefix):
# arabic numbers (potentially with signs and fractions)
f = to_fraction(current_without_prefix)
assert f is not None
if value is not None:
if isinstance(value, str) and value.endswith("."):
# concatenate decimals / ip address components
value = str(value) + str(current)
continue
else:
yield output(value)
prefix = current[0] if has_prefix else prefix
if f.denominator == 1:
value = f.numerator # store integers as int
else:
value = current_without_prefix
elif current not in self.words:
# non-numeric words
if value is not None:
yield output(value)
yield output(current)
elif current in self.zeros:
value = str(value or "") + "0"
elif current in self.ones:
ones = self.ones[current]
if value is None:
value = ones
elif isinstance(value, str) or prev in self.ones:
if prev in self.tens and ones < 10: # replace the last zero with the digit
assert value[-1] == "0"
value = value[:-1] + str(ones)
else:
value = str(value) + str(ones)
elif ones < 10:
if value % 10 == 0:
value += ones
else:
value = str(value) + str(ones)
else: # eleven to nineteen
if value % 100 == 0:
value += ones
else:
value = str(value) + str(ones)
elif current in self.ones_suffixed:
# ordinal or cardinal; yield the number right away
ones, suffix = self.ones_suffixed[current]
if value is None:
yield output(str(ones) + suffix)
elif isinstance(value, str) or prev in self.ones:
if prev in self.tens and ones < 10:
assert value[-1] == "0"
yield output(value[:-1] + str(ones) + suffix)
else:
yield output(str(value) + str(ones) + suffix)
elif ones < 10:
if value % 10 == 0:
yield output(str(value + ones) + suffix)
else:
yield output(str(value) + str(ones) + suffix)
else: # eleven to nineteen
if value % 100 == 0:
yield output(str(value + ones) + suffix)
else:
yield output(str(value) + str(ones) + suffix)
value = None
elif current in self.tens:
tens = self.tens[current]
if value is None:
value = tens
elif isinstance(value, str):
value = str(value) + str(tens)
else:
if value % 100 == 0:
value += tens
else:
value = str(value) + str(tens)
elif current in self.tens_suffixed:
# ordinal or cardinal; yield the number right away
tens, suffix = self.tens_suffixed[current]
if value is None:
yield output(str(tens) + suffix)
elif isinstance(value, str):
yield output(str(value) + str(tens) + suffix)
else:
if value % 100 == 0:
yield output(str(value + tens) + suffix)
else:
yield output(str(value) + str(tens) + suffix)
elif current in self.multipliers:
multiplier = self.multipliers[current]
if value is None:
value = multiplier
elif isinstance(value, str) or value == 0:
f = to_fraction(value)
p = f * multiplier if f is not None else None
if f is not None and p.denominator == 1:
value = p.numerator
else:
yield output(value)
value = multiplier
else:
before = value // 1000 * 1000
residual = value % 1000
value = before + residual * multiplier
elif current in self.multipliers_suffixed:
multiplier, suffix = self.multipliers_suffixed[current]
if value is None:
yield output(str(multiplier) + suffix)
elif isinstance(value, str):
f = to_fraction(value)
p = f * multiplier if f is not None else None
if f is not None and p.denominator == 1:
yield output(str(p.numerator) + suffix)
else:
yield output(value)
yield output(str(multiplier) + suffix)
else: # int
before = value // 1000 * 1000
residual = value % 1000
value = before + residual * multiplier
yield output(str(value) + suffix)
value = None
elif current in self.preceding_prefixers:
# apply prefix (positive, minus, etc.) if it precedes a number
if value is not None:
yield output(value)
if next in self.words or next_is_numeric:
prefix = self.preceding_prefixers[current]
else:
yield output(current)
elif current in self.following_prefixers:
# apply prefix (dollars, cents, etc.) only after a number
if value is not None:
prefix = self.following_prefixers[current]
yield output(value)
else:
yield output(current)
elif current in self.suffixers:
# apply suffix symbols (percent -> '%')
if value is not None:
suffix = self.suffixers[current]
if isinstance(suffix, dict):
if next in suffix:
yield output(str(value) + suffix[next])
skip = True
else:
yield output(value)
yield output(current)
else:
yield output(str(value) + suffix)
else:
yield output(current)
elif current in self.specials:
if next not in self.words and not next_is_numeric:
# apply special handling only if the next word can be numeric
if value is not None:
yield output(value)
yield output(current)
elif current == "and":
# ignore "and" after hundreds, thousands, etc.
if prev not in self.multipliers:
if value is not None:
yield output(value)
yield output(current)
elif current == "double" or current == "triple":
if next in self.ones or next in self.zeros:
repeats = 2 if current == "double" else 3
ones = self.ones.get(next, 0)
value = str(value or "") + str(ones) * repeats
skip = True
else:
if value is not None:
yield output(value)
yield output(current)
elif current == "point":
if next in self.decimals or next_is_numeric:
value = str(value or "") + "."
else:
# should all have been covered at this point
raise ValueError(f"Unexpected token: {current}")
else:
# all should have been covered at this point
raise ValueError(f"Unexpected token: {current}")
if value is not None:
yield output(value)
def preprocess(self, s: str):
# replace "<number> and a half" with "<number> point five"
results = []
segments = re.split(r"\band\s+a\s+half\b", s)
for i, segment in enumerate(segments):
if len(segment.strip()) == 0:
continue
if i == len(segments) - 1:
results.append(segment)
else:
results.append(segment)
last_word = segment.rsplit(maxsplit=2)[-1]
if last_word in self.decimals or last_word in self.multipliers:
results.append("point five")
else:
results.append("and a half")
s = " ".join(results)
# put a space at number/letter boundary
s = re.sub(r"([a-z])([0-9])", r"\1 \2", s)
s = re.sub(r"([0-9])([a-z])", r"\1 \2", s)
# but remove spaces which could be a suffix
s = re.sub(r"([0-9])\s+(st|nd|rd|th|s)\b", r"\1\2", s)
return s
def postprocess(self, s: str):
def combine_cents(m: Match):
try:
currency = m.group(1)
integer = m.group(2)
cents = int(m.group(3))
return f"{currency}{integer}.{cents:02d}"
except ValueError:
return m.string
def extract_cents(m: Match):
try:
return f"¢{int(m.group(1))}"
except ValueError:
return m.string
# apply currency postprocessing; "$2 and ¢7" -> "$2.07"
s = re.sub(r"([€£$])([0-9]+) (?:and )?¢([0-9]{1,2})\b", combine_cents, s)
s = re.sub(r"[€£$]0.([0-9]{1,2})\b", extract_cents, s)
# write "one(s)" instead of "1(s)", just for the readability
s = re.sub(r"\b1(s?)\b", r"one\1", s)
return s
def __call__(self, s: str):
s = self.preprocess(s)
s = " ".join(word for word in self.process_words(s.split()) if word is not None)
s = self.postprocess(s)
return s
class EnglishSpellingNormalizer:
"""
Applies British-American spelling mappings as listed in [1].
[1] https://www.tysto.com/uk-us-spelling-list.html
"""
def __init__(self):
mapping_path = os.path.join(os.path.dirname(__file__), "english.json")
self.mapping = json.load(open(mapping_path))
def __call__(self, s: str):
return " ".join(self.mapping.get(word, word) for word in s.split())
class EnglishTextNormalizer:
def __init__(self):
self.ignore_patterns = r"\b(hmm|mm|mhm|mmm|uh|um)\b"
self.replacers = {
# common contractions
r"\bwon't\b": "will not",
r"\bcan't\b": "can not",
r"\blet's\b": "let us",
r"\bain't\b": "aint",
r"\by'all\b": "you all",
r"\bwanna\b": "want to",
r"\bgotta\b": "got to",
r"\bgonna\b": "going to",
r"\bi'ma\b": "i am going to",
r"\bimma\b": "i am going to",
r"\bwoulda\b": "would have",
r"\bcoulda\b": "could have",
r"\bshoulda\b": "should have",
r"\bma'am\b": "madam",
# contractions in titles/prefixes
r"\bmr\b": "mister ",
r"\bmrs\b": "missus ",
r"\bst\b": "saint ",
r"\bdr\b": "doctor ",
r"\bprof\b": "professor ",
r"\bcapt\b": "captain ",
r"\bgov\b": "governor ",
r"\bald\b": "alderman ",
r"\bgen\b": "general ",
r"\bsen\b": "senator ",
r"\brep\b": "representative ",
r"\bpres\b": "president ",
r"\brev\b": "reverend ",
r"\bhon\b": "honorable ",
r"\basst\b": "assistant ",
r"\bassoc\b": "associate ",
r"\blt\b": "lieutenant ",
r"\bcol\b": "colonel ",
r"\bjr\b": "junior ",
r"\bsr\b": "senior ",
r"\besq\b": "esquire ",
# prefect tenses, ideally it should be any past participles, but it's harder..
r"'d been\b": " had been",
r"'s been\b": " has been",
r"'d gone\b": " had gone",
r"'s gone\b": " has gone",
r"'d done\b": " had done", # "'s done" is ambiguous
r"'s got\b": " has got",
# general contractions
r"n't\b": " not",
r"'re\b": " are",
r"'s\b": " is",
r"'d\b": " would",
r"'ll\b": " will",
r"'t\b": " not",
r"'ve\b": " have",
r"'m\b": " am",
}
self.standardize_numbers = EnglishNumberNormalizer()
self.standardize_spellings = EnglishSpellingNormalizer()
def __call__(self, s: str):
s = s.lower()
s = re.sub(r"[<\[][^>\]]*[>\]]", "", s) # remove words between brackets
s = re.sub(r"\(([^)]+?)\)", "", s) # remove words between parenthesis
s = re.sub(self.ignore_patterns, "", s)
s = re.sub(r"\s+'", "'", s) # standardize when there's a space before an apostrophe
for pattern, replacement in self.replacers.items():
s = re.sub(pattern, replacement, s)
s = re.sub(r"(\d),(\d)", r"\1\2", s) # remove commas between digits
s = re.sub(r"\.([^0-9]|$)", r" \1", s) # remove periods not followed by numbers
s = remove_symbols_and_diacritics(s, keep=".%$¢€£") # keep some symbols for numerics
s = self.standardize_numbers(s)
s = self.standardize_spellings(s)
# now remove prefix/suffix symbols that are not preceded/followed by numbers
s = re.sub(r"[.$¢€£]([^0-9])", r" \1", s)
s = re.sub(r"([^0-9])%", r"\1 ", s)
s = re.sub(r"\s+", " ", s) # replace any successive whitespace characters with a space
return s

View File

@ -1,331 +0,0 @@
import os
from dataclasses import dataclass
from functools import lru_cache
from typing import List, Optional, Tuple, Union
import numpy as np
import torch
from transformers import GPT2TokenizerFast
LANGUAGES = {
"en": "english",
"zh": "chinese",
"de": "german",
"es": "spanish",
"ru": "russian",
"ko": "korean",
"fr": "french",
"ja": "japanese",
"pt": "portuguese",
"tr": "turkish",
"pl": "polish",
"ca": "catalan",
"nl": "dutch",
"ar": "arabic",
"sv": "swedish",
"it": "italian",
"id": "indonesian",
"hi": "hindi",
"fi": "finnish",
"vi": "vietnamese",
"he": "hebrew",
"uk": "ukrainian",
"el": "greek",
"ms": "malay",
"cs": "czech",
"ro": "romanian",
"da": "danish",
"hu": "hungarian",
"ta": "tamil",
"no": "norwegian",
"th": "thai",
"ur": "urdu",
"hr": "croatian",
"bg": "bulgarian",
"lt": "lithuanian",
"la": "latin",
"mi": "maori",
"ml": "malayalam",
"cy": "welsh",
"sk": "slovak",
"te": "telugu",
"fa": "persian",
"lv": "latvian",
"bn": "bengali",
"sr": "serbian",
"az": "azerbaijani",
"sl": "slovenian",
"kn": "kannada",
"et": "estonian",
"mk": "macedonian",
"br": "breton",
"eu": "basque",
"is": "icelandic",
"hy": "armenian",
"ne": "nepali",
"mn": "mongolian",
"bs": "bosnian",
"kk": "kazakh",
"sq": "albanian",
"sw": "swahili",
"gl": "galician",
"mr": "marathi",
"pa": "punjabi",
"si": "sinhala",
"km": "khmer",
"sn": "shona",
"yo": "yoruba",
"so": "somali",
"af": "afrikaans",
"oc": "occitan",
"ka": "georgian",
"be": "belarusian",
"tg": "tajik",
"sd": "sindhi",
"gu": "gujarati",
"am": "amharic",
"yi": "yiddish",
"lo": "lao",
"uz": "uzbek",
"fo": "faroese",
"ht": "haitian creole",
"ps": "pashto",
"tk": "turkmen",
"nn": "nynorsk",
"mt": "maltese",
"sa": "sanskrit",
"lb": "luxembourgish",
"my": "myanmar",
"bo": "tibetan",
"tl": "tagalog",
"mg": "malagasy",
"as": "assamese",
"tt": "tatar",
"haw": "hawaiian",
"ln": "lingala",
"ha": "hausa",
"ba": "bashkir",
"jw": "javanese",
"su": "sundanese",
}
# language code lookup by name, with a few language aliases
TO_LANGUAGE_CODE = {
**{language: code for code, language in LANGUAGES.items()},
"burmese": "my",
"valencian": "ca",
"flemish": "nl",
"haitian": "ht",
"letzeburgesch": "lb",
"pushto": "ps",
"panjabi": "pa",
"moldavian": "ro",
"moldovan": "ro",
"sinhalese": "si",
"castilian": "es",
}
@dataclass(frozen=True)
class Tokenizer:
"""A thin wrapper around `GPT2TokenizerFast` providing quick access to special tokens"""
tokenizer: "GPT2TokenizerFast"
language: Optional[str]
sot_sequence: Tuple[int]
def encode(self, text, **kwargs):
return self.tokenizer.encode(text, **kwargs)
def decode(self, token_ids: Union[int, List[int], np.ndarray, torch.Tensor], **kwargs):
return self.tokenizer.decode(token_ids, **kwargs)
def decode_with_timestamps(self, tokens) -> str:
"""
Timestamp tokens are above the special tokens' id range and are ignored by `decode()`.
This method decodes given tokens with timestamps tokens annotated, e.g. "<|1.08|>".
"""
outputs = [[]]
for token in tokens:
if token >= self.timestamp_begin:
timestamp = f"<|{(token - self.timestamp_begin) * 0.02:.2f}|>"
outputs.append(timestamp)
outputs.append([])
else:
outputs[-1].append(token)
outputs = [s if isinstance(s, str) else self.tokenizer.decode(s) for s in outputs]
return "".join(outputs)
@property
@lru_cache()
def eot(self) -> int:
return self.tokenizer.eos_token_id
@property
@lru_cache()
def sot(self) -> int:
return self._get_single_token_id("<|startoftranscript|>")
@property
@lru_cache()
def sot_lm(self) -> int:
return self._get_single_token_id("<|startoflm|>")
@property
@lru_cache()
def sot_prev(self) -> int:
return self._get_single_token_id("<|startofprev|>")
@property
@lru_cache()
def no_speech(self) -> int:
return self._get_single_token_id("<|nospeech|>")
@property
@lru_cache()
def no_timestamps(self) -> int:
return self._get_single_token_id("<|notimestamps|>")
@property
@lru_cache()
def timestamp_begin(self) -> int:
return self.tokenizer.all_special_ids[-1] + 1
@property
@lru_cache()
def language_token(self) -> int:
"""Returns the token id corresponding to the value of the `language` field"""
if self.language is None:
raise ValueError(f"This tokenizer does not have language token configured")
additional_tokens = dict(
zip(
self.tokenizer.additional_special_tokens,
self.tokenizer.additional_special_tokens_ids,
)
)
candidate = f"<|{self.language}|>"
if candidate in additional_tokens:
return additional_tokens[candidate]
raise KeyError(f"Language {self.language} not found in tokenizer.")
@property
@lru_cache()
def all_language_tokens(self) -> Tuple[int]:
result = []
for token, token_id in zip(
self.tokenizer.additional_special_tokens,
self.tokenizer.additional_special_tokens_ids,
):
if token.strip("<|>") in LANGUAGES:
result.append(token_id)
return tuple(result)
@property
@lru_cache()
def all_language_codes(self) -> Tuple[str]:
return tuple(self.decode([l]).strip("<|>") for l in self.all_language_tokens)
@property
@lru_cache()
def sot_sequence_including_notimestamps(self) -> Tuple[int]:
return tuple(list(self.sot_sequence) + [self.no_timestamps])
@property
@lru_cache()
def non_speech_tokens(self) -> Tuple[int]:
"""
Returns the list of tokens to suppress in order to avoid any speaker tags or non-speech
annotations, to prevent sampling texts that are not actually spoken in the audio, e.g.
- ♪♪♪
- ( SPEAKING FOREIGN LANGUAGE )
- [DAVID] Hey there,
keeping basic punctuations like commas, periods, question marks, exclamation points, etc.
"""
symbols = list("\"#()*+/:;<=>@[\\]^_`{|}~「」『』")
symbols += "<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split()
# symbols that may be a single token or multiple tokens depending on the tokenizer.
# In case they're multiple tokens, suppress the first token, which is safe because:
# These are between U+2640 and U+267F miscellaneous symbols that are okay to suppress
# in generations, and in the 3-byte UTF-8 representation they share the first two bytes.
miscellaneous = set("♩♪♫♬♭♮♯")
assert all(0x2640 <= ord(c) <= 0x267F for c in miscellaneous)
# allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word
result = {self.tokenizer.encode(" -")[0], self.tokenizer.encode(" '")[0]}
for symbol in symbols + list(miscellaneous):
for tokens in [self.tokenizer.encode(symbol), self.tokenizer.encode(" " + symbol)]:
if len(tokens) == 1 or symbol in miscellaneous:
result.add(tokens[0])
return tuple(sorted(result))
def _get_single_token_id(self, text) -> int:
tokens = self.tokenizer.encode(text)
assert len(tokens) == 1, f"{text} is not encoded as a single token"
return tokens[0]
@lru_cache(maxsize=None)
def build_tokenizer(name: str = "gpt2"):
os.environ["TOKENIZERS_PARALLELISM"] = "false"
path = os.path.join(os.path.dirname(__file__), "assets", name)
tokenizer = GPT2TokenizerFast.from_pretrained(path)
specials = [
"<|startoftranscript|>",
*[f"<|{lang}|>" for lang in LANGUAGES.keys()],
"<|translate|>",
"<|transcribe|>",
"<|startoflm|>",
"<|startofprev|>",
"<|nospeech|>",
"<|notimestamps|>",
]
tokenizer.add_special_tokens(dict(additional_special_tokens=specials))
return tokenizer
@lru_cache(maxsize=None)
def get_tokenizer(
multilingual: bool,
*,
task: Optional[str] = None, # Literal["transcribe", "translate", None]
language: Optional[str] = None,
) -> Tokenizer:
if language is not None:
language = language.lower()
if language not in LANGUAGES:
if language in TO_LANGUAGE_CODE:
language = TO_LANGUAGE_CODE[language]
else:
raise ValueError(f"Unsupported language: {language}")
if multilingual:
tokenizer_name = "multilingual"
task = task or "transcribe"
language = language or "en"
else:
tokenizer_name = "gpt2"
task = None
language = None
tokenizer = build_tokenizer(name=tokenizer_name)
all_special_ids: List[int] = tokenizer.all_special_ids
sot: int = all_special_ids[1]
translate: int = all_special_ids[-6]
transcribe: int = all_special_ids[-5]
langs = tuple(LANGUAGES.keys())
sot_sequence = [sot]
if language is not None:
sot_sequence.append(sot + 1 + langs.index(language))
if task is not None:
sot_sequence.append(transcribe if task == "transcribe" else translate)
return Tokenizer(tokenizer=tokenizer, language=language, sot_sequence=tuple(sot_sequence))

View File

@ -1,602 +1,51 @@
import argparse
import gc
import os
import warnings
from typing import List, Optional, Tuple, Union, Iterator, TYPE_CHECKING
import numpy as np
import torch
import tqdm
from .audio import SAMPLE_RATE, N_FRAMES, HOP_LENGTH, CHUNK_LENGTH, pad_or_trim, log_mel_spectrogram, load_audio
from .alignment import load_align_model, align, get_trellis, backtrack, merge_repeats, merge_words
from .decoding import DecodingOptions, DecodingResult
from .diarize import assign_word_speakers, Segment
from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
from .utils import exact_div, format_timestamp, optional_int, optional_float, str2bool, interpolate_nans, write_txt, write_vtt, write_srt, write_ass, write_tsv
from .vad import Binarize
import pandas as pd
if TYPE_CHECKING:
from .model import Whisper
def transcribe(
model: "Whisper",
audio: Union[str, np.ndarray, torch.Tensor],
*,
verbose: Optional[bool] = None,
temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
compression_ratio_threshold: Optional[float] = 2.4,
logprob_threshold: Optional[float] = -1.0,
no_speech_threshold: Optional[float] = 0.6,
condition_on_previous_text: bool = False, # turn off by default due to errors it causes
mel: np.ndarray = None,
**decode_options,
):
"""
Transcribe an audio file using Whisper
Parameters
----------
model: Whisper
The Whisper model instance
audio: Union[str, np.ndarray, torch.Tensor]
The path to the audio file to open, or the audio waveform
verbose: bool
Whether to display the text being decoded to the console. If True, displays all the details,
If False, displays minimal details. If None, does not display anything
temperature: Union[float, Tuple[float, ...]]
Temperature for sampling. It can be a tuple of temperatures, which will be successfully used
upon failures according to either `compression_ratio_threshold` or `logprob_threshold`.
compression_ratio_threshold: float
If the gzip compression ratio is above this value, treat as failed
logprob_threshold: float
If the average log probability over sampled tokens is below this value, treat as failed
no_speech_threshold: float
If the no_speech probability is higher than this value AND the average log probability
over sampled tokens is below `logprob_threshold`, consider the segment as silent
condition_on_previous_text: bool
if True, the previous output of the model is provided as a prompt for the next window;
disabling may make the text inconsistent across windows, but the model becomes less prone to
getting stuck in a failure loop, such as repetition looping or timestamps going out of sync.
decode_options: dict
Keyword arguments to construct `DecodingOptions` instances
Returns
-------
A dictionary containing the resulting text ("text") and segment-level details ("segments"), and
the spoken language ("language"), which is detected when `decode_options["language"]` is None.
"""
dtype = torch.float16 if decode_options.get("fp16", True) else torch.float32
if model.device == torch.device("cpu"):
if torch.cuda.is_available():
warnings.warn("Performing inference on CPU when CUDA is available")
if dtype == torch.float16:
warnings.warn("FP16 is not supported on CPU; using FP32 instead")
dtype = torch.float32
if dtype == torch.float32:
decode_options["fp16"] = False
if mel is None:
mel = log_mel_spectrogram(audio)
if decode_options.get("language", None) is None:
if not model.is_multilingual:
decode_options["language"] = "en"
else:
if verbose:
print("Detecting language using up to the first 30 seconds. Use `--language` to specify the language")
segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype)
_, probs = model.detect_language(segment)
decode_options["language"] = max(probs, key=probs.get)
if verbose is not None:
print(f"Detected language: {LANGUAGES[decode_options['language']].title()}")
language = decode_options["language"]
task = decode_options.get("task", "transcribe")
tokenizer = get_tokenizer(model.is_multilingual, language=language, task=task)
def decode_with_fallback(segment: torch.Tensor) -> DecodingResult:
temperatures = [temperature] if isinstance(temperature, (int, float)) else temperature
decode_result = None
for t in temperatures:
kwargs = {**decode_options}
if t > 0:
# disable beam_size and patience when t > 0
kwargs.pop("beam_size", None)
kwargs.pop("patience", None)
else:
# disable best_of when t == 0
kwargs.pop("best_of", None)
options = DecodingOptions(**kwargs, temperature=t)
decode_result = model.decode(segment, options)
needs_fallback = False
if compression_ratio_threshold is not None and decode_result.compression_ratio > compression_ratio_threshold:
needs_fallback = True # too repetitive
if logprob_threshold is not None and decode_result.avg_logprob < logprob_threshold:
needs_fallback = True # average log probability is too low
if not needs_fallback:
break
return decode_result
seek = 0
input_stride = exact_div(
N_FRAMES, model.dims.n_audio_ctx
) # mel frames per output token: 2
time_precision = (
input_stride * HOP_LENGTH / SAMPLE_RATE
) # time per output token: 0.02 (seconds)
all_tokens = []
all_segments = []
prompt_reset_since = 0
initial_prompt = decode_options.pop("initial_prompt", None) or []
if initial_prompt:
initial_prompt = tokenizer.encode(" " + initial_prompt.strip())
all_tokens.extend(initial_prompt)
def add_segment(
*, start: float, end: float, text_tokens: torch.Tensor, result: DecodingResult
):
text = tokenizer.decode([token for token in text_tokens if token < tokenizer.eot])
if len(text.strip()) == 0: # skip empty text output
return
all_segments.append(
{
"id": len(all_segments),
"seek": seek,
"start": start,
"end": end,
"text": text,
"tokens": text_tokens.tolist(),
"temperature": result.temperature,
"avg_logprob": result.avg_logprob,
"compression_ratio": result.compression_ratio,
"no_speech_prob": result.no_speech_prob,
}
)
if verbose:
print(f"[{format_timestamp(start)} --> {format_timestamp(end)}] {text}")
# show the progress bar when verbose is False (otherwise the transcribed text will be printed)
num_frames = mel.shape[-1]
previous_seek_value = seek
with tqdm.tqdm(total=num_frames, unit='frames', disable=verbose is not False) as pbar:
while seek < num_frames:
timestamp_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
segment = pad_or_trim(mel[:, seek:], N_FRAMES).to(model.device).to(dtype)
segment_duration = segment.shape[-1] * HOP_LENGTH / SAMPLE_RATE
decode_options["prompt"] = all_tokens[prompt_reset_since:]
result: DecodingResult = decode_with_fallback(segment)
tokens = torch.tensor(result.tokens)
if no_speech_threshold is not None:
# no voice activity check
should_skip = result.no_speech_prob > no_speech_threshold
if logprob_threshold is not None and result.avg_logprob > logprob_threshold:
# don't skip if the logprob is high enough, despite the no_speech_prob
should_skip = False
if should_skip:
seek += segment.shape[-1] # fast-forward to the next segment boundary
continue
timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin)
consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0].add_(1)
if len(consecutive) > 0: # if the output contains two consecutive timestamp tokens
last_slice = 0
for current_slice in consecutive:
sliced_tokens = tokens[last_slice:current_slice]
start_timestamp_position = (
sliced_tokens[0].item() - tokenizer.timestamp_begin
)
end_timestamp_position = (
sliced_tokens[-1].item() - tokenizer.timestamp_begin
)
# clamp end-time to at least be 1 frame after start-time
end_timestamp_position = max(end_timestamp_position, start_timestamp_position + time_precision)
add_segment(
start=timestamp_offset + start_timestamp_position * time_precision,
end=timestamp_offset + end_timestamp_position * time_precision,
text_tokens=sliced_tokens[1:-1],
result=result,
)
last_slice = current_slice
last_timestamp_position = (
tokens[last_slice - 1].item() - tokenizer.timestamp_begin
)
seek += last_timestamp_position * input_stride
all_tokens.extend(tokens[: last_slice + 1].tolist())
else:
duration = segment_duration
timestamps = tokens[timestamp_tokens.nonzero().flatten()]
if len(timestamps) > 0 and timestamps[-1].item() != tokenizer.timestamp_begin:
# no consecutive timestamps but it has a timestamp; use the last one.
# single timestamp at the end means no speech after the last timestamp.
last_timestamp_position = timestamps[-1].item() - tokenizer.timestamp_begin
duration = last_timestamp_position * time_precision
add_segment(
start=timestamp_offset,
end=timestamp_offset + duration,
text_tokens=tokens,
result=result,
)
seek += segment.shape[-1]
all_tokens.extend(tokens.tolist())
if not condition_on_previous_text or result.temperature > 0.5:
# do not feed the prompt tokens if a high temperature was used
prompt_reset_since = len(all_tokens)
# update progress bar
pbar.update(min(num_frames, seek) - previous_seek_value)
previous_seek_value = seek
return dict(text=tokenizer.decode(all_tokens[len(initial_prompt):]), segments=all_segments, language=language)
def merge_chunks(segments, chunk_size=CHUNK_LENGTH):
"""
Merge VAD segments into larger segments of approximately size ~CHUNK_LENGTH.
TODO: Make sure VAD segment isn't too long, otherwise it will cause OOM when input to alignment model
TODO: Or sliding window alignment model over long segment.
"""
curr_end = 0
merged_segments = []
seg_idxs = []
speaker_idxs = []
assert chunk_size > 0
binarize = Binarize(max_duration=chunk_size)
segments = binarize(segments)
segments_list = []
for speech_turn in segments.get_timeline():
segments_list.append(Segment(speech_turn.start, speech_turn.end, "UNKNOWN"))
assert segments_list, "segments_list is empty."
# Make sur the starting point is the start of the segment.
curr_start = segments_list[0].start
for seg in segments_list:
if seg.end - curr_start > chunk_size and curr_end-curr_start > 0:
merged_segments.append({
"start": curr_start,
"end": curr_end,
"segments": seg_idxs,
})
curr_start = seg.start
seg_idxs = []
speaker_idxs = []
curr_end = seg.end
seg_idxs.append((seg.start, seg.end))
speaker_idxs.append(seg.speaker)
# add final
merged_segments.append({
"start": curr_start,
"end": curr_end,
"segments": seg_idxs,
})
return merged_segments
def transcribe_with_vad(
model: "Whisper",
audio: Union[str, np.ndarray, torch.Tensor],
vad_pipeline,
mel = None,
verbose: Optional[bool] = None,
**kwargs
):
"""
Transcribe per VAD segment
"""
if mel is None:
mel = log_mel_spectrogram(audio)
prev = 0
output = {"segments": []}
vad_segments = vad_pipeline(audio)
# merge segments to approx 30s inputs to make whisper most appropraite
vad_segments = merge_chunks(vad_segments)
for sdx, seg_t in enumerate(vad_segments):
if verbose:
print(f"~~ Transcribing VAD chunk: ({format_timestamp(seg_t['start'])} --> {format_timestamp(seg_t['end'])}) ~~")
seg_f_start, seg_f_end = int(seg_t["start"] * SAMPLE_RATE / HOP_LENGTH), int(seg_t["end"] * SAMPLE_RATE / HOP_LENGTH)
local_f_start, local_f_end = seg_f_start - prev, seg_f_end - prev
mel = mel[:, local_f_start:] # seek forward
prev = seg_f_start
local_mel = mel[:, :local_f_end-local_f_start]
result = transcribe(model, audio, mel=local_mel, verbose=verbose, **kwargs)
seg_t["text"] = result["text"]
output["segments"].append(
{
"start": seg_t["start"],
"end": seg_t["end"],
"language": result["language"],
"text": result["text"],
"seg-text": [x["text"] for x in result["segments"]],
"seg-start": [x["start"] for x in result["segments"]],
"seg-end": [x["end"] for x in result["segments"]],
}
)
output["language"] = output["segments"][0]["language"]
return output
def transcribe_with_vad_parallel(
model: "Whisper",
audio: Union[str, np.ndarray, torch.Tensor],
vad_pipeline,
mel = None,
verbose: Optional[bool] = None,
batch_size = -1,
**kwargs
):
"""
Transcribe per VAD segment
"""
if mel is None:
mel = log_mel_spectrogram(audio)
vad_segments = vad_pipeline(audio)
# merge segments to approx 30s inputs to make whisper most appropraite
vad_segments = merge_chunks(vad_segments)
################################
### START of parallelization ###
################################
# pad mel to a same length
start_seconds = [i['start'] for i in vad_segments]
end_seconds = [i['end'] for i in vad_segments]
duration_list = np.array(end_seconds) - np.array(start_seconds)
max_length = round(30 / (HOP_LENGTH / SAMPLE_RATE))
offset_list = np.array(start_seconds)
chunks = []
for start_ts, end_ts in zip(start_seconds, end_seconds):
start_ts = round(start_ts / (HOP_LENGTH / SAMPLE_RATE))
end_ts = round(end_ts / (HOP_LENGTH / SAMPLE_RATE))
chunk = mel[:, start_ts:end_ts]
chunk = torch.nn.functional.pad(chunk, (0, max_length-chunk.shape[-1]))
chunks.append(chunk)
mel_chunk = torch.stack(chunks, dim=0).to(model.device)
# using 'decode_options1': only support single temperature decoding (no fallbacks)
# result_list2 = model.decode(mel_chunk, decode_options1)
# prepare DecodingOptions
temperatures = kwargs.pop("temperature", None)
compression_ratio_threshold = kwargs.pop("compression_ratio_threshold", None)
logprob_threshold = kwargs.pop("logprob_threshold", None)
no_speech_threshold = kwargs.pop("no_speech_threshold", None)
condition_on_previous_text = kwargs.pop("condition_on_previous_text", None)
initial_prompt = kwargs.pop("initial_prompt", None)
t = 0 # TODO: does not upport temperature sweeping
if t > 0:
# disable beam_size and patience when t > 0
kwargs.pop("beam_size", None)
kwargs.pop("patience", None)
else:
# disable best_of when t == 0
kwargs.pop("best_of", None)
options = DecodingOptions(**kwargs, temperature=t)
mel_chunk_batches = torch.split(mel_chunk, split_size_or_sections=batch_size)
decode_result = []
for mel_chunk_batch in mel_chunk_batches:
decode_result.extend(model.decode(mel_chunk_batch, options))
##############################
### END of parallelization ###
##############################
# post processing: get segments rfom batch-decoded results
input_stride = exact_div(
N_FRAMES, model.dims.n_audio_ctx
) # mel frames per output token: 2
language = kwargs["language"]
task = kwargs["task"]
tokenizer = get_tokenizer(model.is_multilingual, language=language, task=task)
output = post_process_results(
vad_segments,
decode_result,
duration_list,
offset_list,
input_stride,
language,
tokenizer,
no_speech_threshold=no_speech_threshold,
logprob_threshold=logprob_threshold,
verbose=verbose)
return output
def post_process_results(
vad_segments,
result_list,
duration_list,
offset_list,
input_stride,
language,
tokenizer,
no_speech_threshold = None,
logprob_threshold = None,
verbose: Optional[bool] = None,
):
seek = 0
time_precision = (
input_stride * HOP_LENGTH / SAMPLE_RATE
) # time per output token: 0.02 (seconds)
all_tokens = []
all_segments = []
output = {"segments": []}
def add_segment(
*, start: float, end: float, text_tokens: torch.Tensor, result: DecodingResult
):
text = tokenizer.decode([token for token in text_tokens if token < tokenizer.eot])
if len(text.strip()) == 0: # skip empty text output
return
all_segments.append(
{
"id": len(all_segments),
"seek": seek,
"start": start,
"end": end,
"text": text,
"tokens": text_tokens.tolist(),
"temperature": result.temperature,
"avg_logprob": result.avg_logprob,
"compression_ratio": result.compression_ratio,
"no_speech_prob": result.no_speech_prob,
}
)
if verbose:
print(f"[{format_timestamp(start)} --> {format_timestamp(end)}] {text}")
# process the output
for seg_t, result, segment_duration, timestamp_offset in zip(vad_segments, result_list, duration_list, offset_list):
all_tokens = []
all_segments = []
# segment_duration = segment.shape[-1] * HOP_LENGTH / SAMPLE_RATE
segment_shape = int(segment_duration / (HOP_LENGTH / SAMPLE_RATE))
tokens = torch.tensor(result.tokens)
if no_speech_threshold is not None:
# no voice activity check
should_skip = result.no_speech_prob > no_speech_threshold
if logprob_threshold is not None and result.avg_logprob > logprob_threshold:
# don't skip if the logprob is high enough, despite the no_speech_prob
should_skip = False
if should_skip:
seek += segment_shape # fast-forward to the next segment boundary
continue
timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin)
consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0].add_(1)
if len(consecutive) > 0: # if the output contains two consecutive timestamp tokens
last_slice = 0
for current_slice in consecutive:
sliced_tokens = tokens[last_slice:current_slice]
start_timestamp_position = (
sliced_tokens[0].item() - tokenizer.timestamp_begin
)
end_timestamp_position = (
sliced_tokens[-1].item() - tokenizer.timestamp_begin
)
add_segment(
start=timestamp_offset + start_timestamp_position * time_precision,
end=timestamp_offset + end_timestamp_position * time_precision,
text_tokens=sliced_tokens[1:-1],
result=result,
)
last_slice = current_slice
last_timestamp_position = (
tokens[last_slice - 1].item() - tokenizer.timestamp_begin
)
seek += last_timestamp_position * input_stride
all_tokens.extend(tokens[: last_slice + 1].tolist())
else:
duration = segment_duration
timestamps = tokens[timestamp_tokens.nonzero().flatten()]
if len(timestamps) > 0 and timestamps[-1].item() != tokenizer.timestamp_begin:
# no consecutive timestamps but it has a timestamp; use the last one.
# single timestamp at the end means no speech after the last timestamp.
last_timestamp_position = timestamps[-1].item() - tokenizer.timestamp_begin
duration = last_timestamp_position * time_precision
add_segment(
start=timestamp_offset,
end=timestamp_offset + duration,
text_tokens=tokens,
result=result,
)
seek += segment_shape
all_tokens.extend(tokens.tolist())
result = dict(text=tokenizer.decode(all_tokens), segments=all_segments, language=language)
output["segments"].append(
{
"start": seg_t["start"],
"end": seg_t["end"],
"language": result["language"],
"text": result["text"],
"seg-text": [x["text"] for x in result["segments"]],
"seg-start": [x["start"] for x in result["segments"]],
"seg-end": [x["end"] for x in result["segments"]],
}
)
output["language"] = output["segments"][0]["language"]
return output
from .alignment import align, load_align_model
from .asr import load_model
from .audio import load_audio
from .diarize import DiarizationPipeline, assign_word_speakers
from .utils import (LANGUAGES, TO_LANGUAGE_CODE, get_writer, optional_float,
optional_int, str2bool)
def cli():
from . import available_models
# fmt: off
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe")
parser.add_argument("--model", default="small", choices=available_models(), help="name of the Whisper model to use")
parser.add_argument("--model", default="small", help="name of the Whisper model to use")
parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
# alignment params
parser.add_argument("--align_model", default=None, help="Name of phoneme-level ASR model to do alignment")
parser.add_argument("--align_extend", default=2, type=float, help="Seconds before and after to extend the whisper segments for alignment")
parser.add_argument("--align_from_prev", default=True, type=bool, help="Whether to clip the alignment start time of current segment to the end time of the last aligned word of the previous segment")
parser.add_argument("--interpolate_method", default="nearest", choices=["nearest", "linear", "ignore"], help="For word .srt, method to assign timestamps to non-aligned words, or merge them into neighbouring.")
# vad params
parser.add_argument("--vad_filter", action="store_true", help="Whether to first perform VAD filtering to target only transcribe within VAD. Produces more accurate alignment + timestamp, requires more GPU memory & compute.")
parser.add_argument("--parallel_bs", default=-1, type=int, help="Enable parallel transcribing if > 1")
# diarization params
parser.add_argument("--diarize", action="store_true", help="Apply diarization to assign speaker labels to each segment/word")
parser.add_argument("--min_speakers", default=None, type=int)
parser.add_argument("--max_speakers", default=None, type=int)
# output save params
parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
parser.add_argument("--output_type", default="all", choices=["all", "srt", "srt-word", "vtt", "txt", "tsv", "ass", "ass-char", "pickle", "vad"], help="File type for desired output save")
parser.add_argument("--batch_size", default=8, type=int, help="device to use for PyTorch inference")
parser.add_argument("--compute_type", default="float16", type=str, choices=["float16", "float32", "int8"], help="compute type for computation")
parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
parser.add_argument("--output_format", "-f", type=str, default="all", choices=["all", "srt", "vtt", "txt", "tsv", "json"], help="format of the output file; if not specified, all available formats will be produced")
parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")
parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')")
parser.add_argument("--language", type=str, default=None, choices=sorted(LANGUAGES.keys()) + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]), help="language spoken in the audio, specify None to perform language detection")
# alignment params
parser.add_argument("--align_model", default=None, help="Name of phoneme-level ASR model to do alignment")
parser.add_argument("--interpolate_method", default="nearest", choices=["nearest", "linear", "ignore"], help="For word .srt, method to assign timestamps to non-aligned words, or merge them into neighbouring.")
parser.add_argument("--no_align", action='store_true', help="Do not perform phoneme alignment")
parser.add_argument("--return_char_alignments", action='store_true', help="Return character-level alignments in the output json file")
# vad params
parser.add_argument("--vad_onset", type=float, default=0.500, help="Onset threshold for VAD (see pyannote.audio), reduce this if speech is not being detected")
parser.add_argument("--vad_offset", type=float, default=0.363, help="Offset threshold for VAD (see pyannote.audio), reduce this if speech is not being detected.")
# diarization params
parser.add_argument("--diarize", action="store_true", help="Apply diarization to assign speaker labels to each segment/word")
parser.add_argument("--min_speakers", default=None, type=int, help="Minimum number of speakers to in audio file")
parser.add_argument("--max_speakers", default=None, type=int, help="Maximum number of speakers to in audio file")
parser.add_argument("--temperature", type=float, default=0, help="temperature to use for sampling")
parser.add_argument("--best_of", type=optional_int, default=5, help="number of candidates when sampling with non-zero temperature")
parser.add_argument("--beam_size", type=optional_int, default=5, help="number of beams in beam search, only applicable when temperature is zero")
@ -612,152 +61,142 @@ def cli():
parser.add_argument("--compression_ratio_threshold", type=optional_float, default=2.4, help="if the gzip compression ratio is higher than this value, treat the decoding as failed")
parser.add_argument("--logprob_threshold", type=optional_float, default=-1.0, help="if the average log probability is lower than this value, treat the decoding as failed")
parser.add_argument("--no_speech_threshold", type=optional_float, default=0.6, help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence")
parser.add_argument("--max_line_width", type=optional_int, default=None, help="(not possible with --no_align) the maximum number of characters in a line before breaking the line")
parser.add_argument("--max_line_count", type=optional_int, default=None, help="(requires --no_align) the maximum number of lines in a segment")
parser.add_argument("--highlight_words", type=str2bool, default=False, help="(requires --word_timestamps True) underline each word as it is spoken in srt and vtt")
parser.add_argument("--segment_resolution", type=str, default="sentence", choices=["sentence", "chunk"], help="(not possible with --no_align) the maximum number of characters in a line before breaking the line")
parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
parser.add_argument("--hf_token", type=str, default=None, help="Hugging Face Access Token to access PyAnnote gated models")
# fmt: on
args = parser.parse_args().__dict__
model_name: str = args.pop("model")
model_dir: str = args.pop("model_dir")
batch_size: int = args.pop("batch_size")
output_dir: str = args.pop("output_dir")
output_type: str = args.pop("output_type")
output_format: str = args.pop("output_format")
device: str = args.pop("device")
compute_type: str = args.pop("compute_type")
# model_flush: bool = args.pop("model_flush")
os.makedirs(output_dir, exist_ok=True)
align_model: str = args.pop("align_model")
align_extend: float = args.pop("align_extend")
align_from_prev: bool = args.pop("align_from_prev")
interpolate_method: bool = args.pop("interpolate_method")
interpolate_method: str = args.pop("interpolate_method")
no_align: bool = args.pop("no_align")
return_char_alignments: bool = args.pop("return_char_alignments")
hf_token: str = args.pop("hf_token")
vad_filter: bool = args.pop("vad_filter")
parallel_bs: int = args.pop("parallel_bs")
vad_onset: float = args.pop("vad_onset")
vad_offset: float = args.pop("vad_offset")
diarize: bool = args.pop("diarize")
min_speakers: int = args.pop("min_speakers")
max_speakers: int = args.pop("max_speakers")
vad_pipeline = None
if vad_filter:
if hf_token is None:
print("Warning, no huggingface token used, needs to be saved in environment variable, otherwise will throw error loading VAD model...")
from pyannote.audio import Inference
vad_pipeline = Inference(
"pyannote/segmentation",
pre_aggregation_hook=lambda segmentation: segmentation,
use_auth_token=hf_token,
device=torch.device(device),
)
diarize_pipeline = None
if diarize:
if hf_token is None:
print("Warning, no --hf_token used, needs to be saved in environment variable, otherwise will throw error loading diarization model...")
from pyannote.audio import Pipeline
diarize_pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization@2.1",
use_auth_token=hf_token)
os.makedirs(output_dir, exist_ok=True)
if model_name.endswith(".en") and args["language"] not in {"en", "English"}:
if args["language"] is not None:
warnings.warn(f'{model_name} is an English-only model but receipted "{args["language"]}"; using English instead.')
warnings.warn(
f"{model_name} is an English-only model but receipted '{args['language']}'; using English instead."
)
args["language"] = "en"
temperature = args.pop("temperature")
temperature_increment_on_fallback = args.pop("temperature_increment_on_fallback")
if temperature_increment_on_fallback is not None:
temperature = tuple(np.arange(temperature, 1.0 + 1e-6, temperature_increment_on_fallback))
if (increment := args.pop("temperature_increment_on_fallback")) is not None:
temperature = tuple(np.arange(temperature, 1.0 + 1e-6, increment))
else:
temperature = [temperature]
threads = args.pop("threads")
if threads > 0:
if (threads := args.pop("threads")) > 0:
torch.set_num_threads(threads)
from . import load_model
model = load_model(model_name, device=device, download_root=model_dir)
asr_options = {
"beam_size": args.pop("beam_size"),
"patience": args.pop("patience"),
"length_penalty": args.pop("length_penalty"),
"temperatures": temperature,
"compression_ratio_threshold": args.pop("compression_ratio_threshold"),
"log_prob_threshold": args.pop("logprob_threshold"),
"no_speech_threshold": args.pop("no_speech_threshold"),
"condition_on_previous_text": False,
"initial_prompt": args.pop("initial_prompt"),
}
align_language = args["language"] if args["language"] is not None else "en" # default to loading english if not specified
align_model, align_metadata = load_align_model(align_language, device, model_name=align_model)
writer = get_writer(output_format, output_dir)
word_options = ["highlight_words", "max_line_count", "max_line_width"]
if no_align:
for option in word_options:
if args[option]:
parser.error(f"--{option} requires --word_timestamps True")
if args["max_line_count"] and not args["max_line_width"]:
warnings.warn("--max_line_count has no effect without --max_line_width")
writer_args = {arg: args.pop(arg) for arg in word_options}
# Part 1: VAD & ASR Loop
results = []
tmp_results = []
# model = load_model(model_name, device=device, download_root=model_dir)
model = load_model(model_name, device=device, compute_type=compute_type, language=args['language'], asr_options=asr_options, vad_options={"vad_onset": vad_onset, "vad_offset": vad_offset},)
for audio_path in args.pop("audio"):
if vad_filter:
if parallel_bs > 1:
print("Performing VAD and parallel transcribing ...")
result = transcribe_with_vad_parallel(model, audio_path, vad_pipeline, temperature=temperature, batch_size=parallel_bs, **args)
audio = load_audio(audio_path)
# >> VAD & ASR
print(">>Performing transcription...")
result = model.transcribe(audio, batch_size=batch_size)
results.append((result, audio_path))
# Unload Whisper and VAD
del model
gc.collect()
torch.cuda.empty_cache()
# Part 2: Align Loop
if not no_align:
tmp_results = results
results = []
align_language = args["language"] if args["language"] is not None else "en" # default to loading english if not specified
align_model, align_metadata = load_align_model(align_language, device, model_name=align_model)
for result, audio_path in tmp_results:
# >> Align
if len(tmp_results) > 1:
input_audio = audio_path
else:
print("Performing VAD...")
result = transcribe_with_vad(model, audio_path, vad_pipeline, temperature=temperature, **args)
else:
print("Performing transcription...")
result = transcribe(model, audio_path, temperature=temperature, **args)
# lazily load audio from part 1
input_audio = audio
if result["language"] != align_metadata["language"]:
# load new language
print(f"New language found ({result['language']})! Previous was ({align_metadata['language']}), loading new alignment model for new language...")
align_model, align_metadata = load_align_model(result["language"], device)
if align_model is not None and len(result["segments"]) > 0:
if result.get("language", "en") != align_metadata["language"]:
# load new language
print(f"New language found ({result['language']})! Previous was ({align_metadata['language']}), loading new alignment model for new language...")
align_model, align_metadata = load_align_model(result["language"], device)
print(">>Performing alignment...")
result = align(result["segments"], align_model, align_metadata, input_audio, device, interpolate_method=interpolate_method, return_char_alignments=return_char_alignments)
results.append((result, audio_path))
print("Performing alignment...")
result_aligned = align(result["segments"], align_model, align_metadata, audio_path, device,
extend_duration=align_extend, start_from_previous=align_from_prev, interpolate_method=interpolate_method)
audio_basename = os.path.basename(audio_path)
# Unload align model
del align_model
gc.collect()
torch.cuda.empty_cache()
if diarize:
print("Performing diarization...")
diarize_segments = diarize_pipeline(audio_path, min_speakers=min_speakers, max_speakers=max_speakers)
diarize_df = pd.DataFrame(diarize_segments.itertracks(yield_label=True))
diarize_df['start'] = diarize_df[0].apply(lambda x: x.start)
diarize_df['end'] = diarize_df[0].apply(lambda x: x.end)
# assumes each utterance is single speaker (needs fix)
result_segments, word_segments = assign_word_speakers(diarize_df, result_aligned["segments"], fill_nearest=True)
result_aligned["segments"] = result_segments
result_aligned["word_segments"] = word_segments
# >> Diarize
if diarize:
if hf_token is None:
print("Warning, no --hf_token used, needs to be saved in environment variable, otherwise will throw error loading diarization model...")
tmp_results = results
print(">>Performing diarization...")
results = []
diarize_model = DiarizationPipeline(use_auth_token=hf_token, device=device)
for result, input_audio_path in tmp_results:
diarize_segments = diarize_model(input_audio_path, min_speakers=min_speakers, max_speakers=max_speakers)
result = assign_word_speakers(diarize_segments, result)
results.append((result, input_audio_path))
# >> Write
for result, audio_path in results:
writer(result, audio_path, writer_args)
# save TXT
if output_type in ["txt", "all"]:
with open(os.path.join(output_dir, audio_basename + ".txt"), "w", encoding="utf-8") as txt:
write_txt(result_aligned["segments"], file=txt)
# save VTT
if output_type in ["vtt", "all"]:
with open(os.path.join(output_dir, audio_basename + ".vtt"), "w", encoding="utf-8") as vtt:
write_vtt(result_aligned["segments"], file=vtt)
# save SRT
if output_type in ["srt", "all"]:
with open(os.path.join(output_dir, audio_basename + ".srt"), "w", encoding="utf-8") as srt:
write_srt(result_aligned["segments"], file=srt)
# save TSV
if output_type in ["tsv", "all"]:
with open(os.path.join(output_dir, audio_basename + ".tsv"), "w", encoding="utf-8") as srt:
write_tsv(result_aligned["segments"], file=srt)
# save SRT word-level
if output_type in ["srt-word", "all"]:
# save per-word SRT
with open(os.path.join(output_dir, audio_basename + ".word.srt"), "w", encoding="utf-8") as srt:
write_srt(result_aligned["word_segments"], file=srt)
# save ASS
if output_type in ["ass", "all"]:
with open(os.path.join(output_dir, audio_basename + ".ass"), "w", encoding="utf-8") as ass:
write_ass(result_aligned["segments"], file=ass)
# # save ASS character-level
if output_type in ["ass-char"]:
with open(os.path.join(output_dir, audio_basename + ".char.ass"), "w", encoding="utf-8") as ass:
write_ass(result_aligned["segments"], file=ass, resolution="char")
# save word tsv
if output_type in ["pickle"]:
exp_fp = os.path.join(output_dir, audio_basename + ".pkl")
pd.DataFrame(result_aligned["segments"]).to_pickle(exp_fp)
# save word tsv
if output_type in ["vad"]:
exp_fp = os.path.join(output_dir, audio_basename + ".sad")
wrd_segs = pd.concat([x["word-segments"] for x in result_aligned["segments"]])[['start','end']]
wrd_segs.to_csv(exp_fp, sep='\t', header=None, index=False)
if __name__ == "__main__":
cli()
cli()

58
whisperx/types.py Normal file
View File

@ -0,0 +1,58 @@
from typing import TypedDict, Optional
class SingleWordSegment(TypedDict):
"""
A single word of a speech.
"""
word: str
start: float
end: float
score: float
class SingleCharSegment(TypedDict):
"""
A single char of a speech.
"""
char: str
start: float
end: float
score: float
class SingleSegment(TypedDict):
"""
A single segment (up to multiple sentences) of a speech.
"""
start: float
end: float
text: str
class SingleAlignedSegment(TypedDict):
"""
A single segment (up to multiple sentences) of a speech with word alignment.
"""
start: float
end: float
text: str
words: list[SingleWordSegment]
chars: Optional[list[SingleCharSegment]]
class TranscriptionResult(TypedDict):
"""
A list of segments and word segments of a speech.
"""
segments: list[SingleSegment]
language: str
class AlignedTranscriptionResult(TypedDict):
"""
A list of segments and word segments of a speech.
"""
segments: list[SingleAlignedSegment]
word_segments: list[SingleWordSegment]

View File

@ -1,8 +1,144 @@
import json
import os
import re
import sys
import zlib
from typing import Callable, TextIO, Iterator, Tuple
import pandas as pd
import numpy as np
from typing import Callable, Optional, TextIO
LANGUAGES = {
"en": "english",
"zh": "chinese",
"de": "german",
"es": "spanish",
"ru": "russian",
"ko": "korean",
"fr": "french",
"ja": "japanese",
"pt": "portuguese",
"tr": "turkish",
"pl": "polish",
"ca": "catalan",
"nl": "dutch",
"ar": "arabic",
"sv": "swedish",
"it": "italian",
"id": "indonesian",
"hi": "hindi",
"fi": "finnish",
"vi": "vietnamese",
"he": "hebrew",
"uk": "ukrainian",
"el": "greek",
"ms": "malay",
"cs": "czech",
"ro": "romanian",
"da": "danish",
"hu": "hungarian",
"ta": "tamil",
"no": "norwegian",
"th": "thai",
"ur": "urdu",
"hr": "croatian",
"bg": "bulgarian",
"lt": "lithuanian",
"la": "latin",
"mi": "maori",
"ml": "malayalam",
"cy": "welsh",
"sk": "slovak",
"te": "telugu",
"fa": "persian",
"lv": "latvian",
"bn": "bengali",
"sr": "serbian",
"az": "azerbaijani",
"sl": "slovenian",
"kn": "kannada",
"et": "estonian",
"mk": "macedonian",
"br": "breton",
"eu": "basque",
"is": "icelandic",
"hy": "armenian",
"ne": "nepali",
"mn": "mongolian",
"bs": "bosnian",
"kk": "kazakh",
"sq": "albanian",
"sw": "swahili",
"gl": "galician",
"mr": "marathi",
"pa": "punjabi",
"si": "sinhala",
"km": "khmer",
"sn": "shona",
"yo": "yoruba",
"so": "somali",
"af": "afrikaans",
"oc": "occitan",
"ka": "georgian",
"be": "belarusian",
"tg": "tajik",
"sd": "sindhi",
"gu": "gujarati",
"am": "amharic",
"yi": "yiddish",
"lo": "lao",
"uz": "uzbek",
"fo": "faroese",
"ht": "haitian creole",
"ps": "pashto",
"tk": "turkmen",
"nn": "nynorsk",
"mt": "maltese",
"sa": "sanskrit",
"lb": "luxembourgish",
"my": "myanmar",
"bo": "tibetan",
"tl": "tagalog",
"mg": "malagasy",
"as": "assamese",
"tt": "tatar",
"haw": "hawaiian",
"ln": "lingala",
"ha": "hausa",
"ba": "bashkir",
"jw": "javanese",
"su": "sundanese",
}
# language code lookup by name, with a few language aliases
TO_LANGUAGE_CODE = {
**{language: code for code, language in LANGUAGES.items()},
"burmese": "my",
"valencian": "ca",
"flemish": "nl",
"haitian": "ht",
"letzeburgesch": "lb",
"pushto": "ps",
"panjabi": "pa",
"moldavian": "ro",
"moldovan": "ro",
"sinhalese": "si",
"castilian": "es",
}
system_encoding = sys.getdefaultencoding()
if system_encoding != "utf-8":
def make_safe(string):
# replaces any character not representable using the system default encoding with an '?',
# avoiding UnicodeEncodeError (https://github.com/openai/whisper/discussions/729).
return string.encode(system_encoding, errors="replace").decode(system_encoding)
else:
def make_safe(string):
# utf-8 can encode any Unicode code point, so no need to do the round-trip encoding
return string
def exact_div(x, y):
assert x % y == 0
@ -30,7 +166,9 @@ def compression_ratio(text) -> float:
return len(text_bytes) / len(zlib.compress(text_bytes))
def format_timestamp(seconds: float, always_include_hours: bool = False, decimal_marker: str = '.'):
def format_timestamp(
seconds: float, always_include_hours: bool = False, decimal_marker: str = "."
):
assert seconds >= 0, "non-negative timestamp expected"
milliseconds = round(seconds * 1000.0)
@ -44,211 +182,218 @@ def format_timestamp(seconds: float, always_include_hours: bool = False, decimal
milliseconds -= seconds * 1_000
hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
return f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
return (
f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
)
def write_txt(transcript: Iterator[dict], file: TextIO):
for segment in transcript:
print(segment['text'].strip(), file=file, flush=True)
class ResultWriter:
extension: str
def __init__(self, output_dir: str):
self.output_dir = output_dir
def write_vtt(transcript: Iterator[dict], file: TextIO):
print("WEBVTT\n", file=file)
for segment in transcript:
print(
f"{format_timestamp(segment['start'])} --> {format_timestamp(segment['end'])}\n"
f"{segment['text'].strip().replace('-->', '->')}\n",
file=file,
flush=True,
def __call__(self, result: dict, audio_path: str, options: dict):
audio_basename = os.path.basename(audio_path)
audio_basename = os.path.splitext(audio_basename)[0]
output_path = os.path.join(
self.output_dir, audio_basename + "." + self.extension
)
def write_tsv(transcript: Iterator[dict], file: TextIO):
print("start", "end", "text", sep="\t", file=file)
for segment in transcript:
print(segment['start'], file=file, end="\t")
print(segment['end'], file=file, end="\t")
print(segment['text'].strip().replace("\t", " "), file=file, flush=True)
with open(output_path, "w", encoding="utf-8") as f:
self.write_result(result, file=f, options=options)
def write_result(self, result: dict, file: TextIO, options: dict):
raise NotImplementedError
def write_srt(transcript: Iterator[dict], file: TextIO):
"""
Write a transcript to a file in SRT format.
class WriteTXT(ResultWriter):
extension: str = "txt"
Example usage:
from pathlib import Path
from whisper.utils import write_srt
def write_result(self, result: dict, file: TextIO, options: dict):
for segment in result["segments"]:
print(segment["text"].strip(), file=file, flush=True)
result = transcribe(model, audio_path, temperature=temperature, **args)
# save SRT
audio_basename = Path(audio_path).stem
with open(Path(output_dir) / (audio_basename + ".srt"), "w", encoding="utf-8") as srt:
write_srt(result["segments"], file=srt)
"""
for i, segment in enumerate(transcript, start=1):
# write srt lines
print(
f"{i}\n"
f"{format_timestamp(segment['start'], always_include_hours=True, decimal_marker=',')} --> "
f"{format_timestamp(segment['end'], always_include_hours=True, decimal_marker=',')}\n"
f"{segment['text'].strip().replace('-->', '->')}\n",
file=file,
flush=True,
class SubtitlesWriter(ResultWriter):
always_include_hours: bool
decimal_marker: str
def iterate_result(self, result: dict, options: dict):
raw_max_line_width: Optional[int] = options["max_line_width"]
max_line_count: Optional[int] = options["max_line_count"]
highlight_words: bool = options["highlight_words"]
max_line_width = 1000 if raw_max_line_width is None else raw_max_line_width
preserve_segments = max_line_count is None or raw_max_line_width is None
def iterate_subtitles():
line_len = 0
line_count = 1
# the next subtitle to yield (a list of word timings with whitespace)
subtitle: list[dict] = []
times = []
last = result["segments"][0]["start"]
for segment in result["segments"]:
for i, original_timing in enumerate(segment["words"]):
timing = original_timing.copy()
long_pause = not preserve_segments
if "start" in timing:
long_pause = long_pause and timing["start"] - last > 3.0
else:
long_pause = False
has_room = line_len + len(timing["word"]) <= max_line_width
seg_break = i == 0 and len(subtitle) > 0 and preserve_segments
if line_len > 0 and has_room and not long_pause and not seg_break:
# line continuation
line_len += len(timing["word"])
else:
# new line
timing["word"] = timing["word"].strip()
if (
len(subtitle) > 0
and max_line_count is not None
and (long_pause or line_count >= max_line_count)
or seg_break
):
# subtitle break
yield subtitle, times
subtitle = []
times = []
line_count = 1
elif line_len > 0:
# line break
line_count += 1
timing["word"] = "\n" + timing["word"]
line_len = len(timing["word"].strip())
subtitle.append(timing)
times.append((segment["start"], segment["end"], segment.get("speaker")))
if "start" in timing:
last = timing["start"]
if len(subtitle) > 0:
yield subtitle, times
if "words" in result["segments"][0]:
for subtitle, _ in iterate_subtitles():
sstart, ssend, speaker = _[0]
subtitle_start = self.format_timestamp(sstart)
subtitle_end = self.format_timestamp(ssend)
subtitle_text = " ".join([word["word"] for word in subtitle])
has_timing = any(["start" in word for word in subtitle])
# add [$SPEAKER_ID]: to each subtitle if speaker is available
prefix = ""
if speaker is not None:
prefix = f"[{speaker}]: "
if highlight_words and has_timing:
last = subtitle_start
all_words = [timing["word"] for timing in subtitle]
for i, this_word in enumerate(subtitle):
if "start" in this_word:
start = self.format_timestamp(this_word["start"])
end = self.format_timestamp(this_word["end"])
if last != start:
yield last, start, subtitle_text
yield start, end, prefix + " ".join(
[
re.sub(r"^(\s*)(.*)$", r"\1<u>\2</u>", word)
if j == i
else word
for j, word in enumerate(all_words)
]
)
last = end
else:
yield subtitle_start, subtitle_end, prefix + subtitle_text
else:
for segment in result["segments"]:
segment_start = self.format_timestamp(segment["start"])
segment_end = self.format_timestamp(segment["end"])
segment_text = segment["text"].strip().replace("-->", "->")
if "speaker" in segment:
segment_text = f"[{segment['speaker']}]: {segment_text}"
yield segment_start, segment_end, segment_text
def format_timestamp(self, seconds: float):
return format_timestamp(
seconds=seconds,
always_include_hours=self.always_include_hours,
decimal_marker=self.decimal_marker,
)
def write_ass(transcript: Iterator[dict],
file: TextIO,
resolution: str = "word",
color: str = None, underline=True,
prefmt: str = None, suffmt: str = None,
font: str = None, font_size: int = 24,
strip=True, **kwargs):
"""
Credit: https://github.com/jianfch/stable-ts/blob/ff79549bd01f764427879f07ecd626c46a9a430a/stable_whisper/text_output.py
Generate Advanced SubStation Alpha (ass) file from results to
display both phrase-level & word-level timestamp simultaneously by:
-using segment-level timestamps display phrases as usual
-using word-level timestamps change formats (e.g. color/underline) of the word in the displayed segment
Note: ass file is used in the same way as srt, vtt, etc.
Parameters
----------
transcript: dict
results from modified model
file: TextIO
file object to write to
resolution: str
"word" or "char", timestamp resolution to highlight.
color: str
color code for a word at its corresponding timestamp
<bbggrr> reverse order hexadecimal RGB value (e.g. FF0000 is full intensity blue. Default: 00FF00)
underline: bool
whether to underline a word at its corresponding timestamp
prefmt: str
used to specify format for word-level timestamps (must be use with 'suffmt' and overrides 'color'&'underline')
appears as such in the .ass file:
Hi, {<prefmt>}how{<suffmt>} are you?
reference [Appendix A: Style override codes] in http://www.tcax.org/docs/ass-specs.htm
suffmt: str
used to specify format for word-level timestamps (must be use with 'prefmt' and overrides 'color'&'underline')
appears as such in the .ass file:
Hi, {<prefmt>}how{<suffmt>} are you?
reference [Appendix A: Style override codes] in http://www.tcax.org/docs/ass-specs.htm
font: str
word font (default: Arial)
font_size: int
word font size (default: 48)
kwargs:
used for format styles:
'Name', 'Fontname', 'Fontsize', 'PrimaryColour', 'SecondaryColour', 'OutlineColour', 'BackColour', 'Bold',
'Italic', 'Underline', 'StrikeOut', 'ScaleX', 'ScaleY', 'Spacing', 'Angle', 'BorderStyle', 'Outline',
'Shadow', 'Alignment', 'MarginL', 'MarginR', 'MarginV', 'Encoding'
class WriteVTT(SubtitlesWriter):
extension: str = "vtt"
always_include_hours: bool = False
decimal_marker: str = "."
def write_result(self, result: dict, file: TextIO, options: dict):
print("WEBVTT\n", file=file)
for start, end, text in self.iterate_result(result, options):
print(f"{start} --> {end}\n{text}\n", file=file, flush=True)
class WriteSRT(SubtitlesWriter):
extension: str = "srt"
always_include_hours: bool = True
decimal_marker: str = ","
def write_result(self, result: dict, file: TextIO, options: dict):
for i, (start, end, text) in enumerate(
self.iterate_result(result, options), start=1
):
print(f"{i}\n{start} --> {end}\n{text}\n", file=file, flush=True)
class WriteTSV(ResultWriter):
"""
Write a transcript to a file in TSV (tab-separated values) format containing lines like:
<start time in integer milliseconds>\t<end time in integer milliseconds>\t<transcript text>
Using integer milliseconds as start and end times means there's no chance of interference from
an environment setting a language encoding that causes the decimal in a floating point number
to appear as a comma; also is faster and more efficient to parse & store, e.g., in C++.
"""
fmt_style_dict = {'Name': 'Default', 'Fontname': 'Arial', 'Fontsize': '48', 'PrimaryColour': '&Hffffff',
'SecondaryColour': '&Hffffff', 'OutlineColour': '&H0', 'BackColour': '&H0', 'Bold': '0',
'Italic': '0', 'Underline': '0', 'StrikeOut': '0', 'ScaleX': '100', 'ScaleY': '100',
'Spacing': '0', 'Angle': '0', 'BorderStyle': '1', 'Outline': '1', 'Shadow': '0',
'Alignment': '2', 'MarginL': '10', 'MarginR': '10', 'MarginV': '10', 'Encoding': '0'}
extension: str = "tsv"
for k, v in filter(lambda x: 'colour' in x[0].lower() and not str(x[1]).startswith('&H'), kwargs.items()):
kwargs[k] = f'&H{kwargs[k]}'
fmt_style_dict.update((k, v) for k, v in kwargs.items() if k in fmt_style_dict)
if font:
fmt_style_dict.update(Fontname=font)
if font_size:
fmt_style_dict.update(Fontsize=font_size)
fmts = f'Format: {", ".join(map(str, fmt_style_dict.keys()))}'
styles = f'Style: {",".join(map(str, fmt_style_dict.values()))}'
ass_str = f'[Script Info]\nScriptType: v4.00+\nPlayResX: 384\nPlayResY: 288\nScaledBorderAndShadow: yes\n\n' \
f'[V4+ Styles]\n{fmts}\n{styles}\n\n' \
f'[Events]\nFormat: Layer, Start, End, Style, Name, MarginL, MarginR, MarginV, Effect, Text\n\n'
if prefmt or suffmt:
if suffmt:
assert prefmt, 'prefmt must be used along with suffmt'
else:
suffmt = r'\r'
else:
if not color:
color = 'HFF00'
underline_code = r'\u1' if underline else ''
prefmt = r'{\1c&' + f'{color.upper()}&{underline_code}' + '}'
suffmt = r'{\r}'
def secs_to_hhmmss(secs: Tuple[float, int]):
mm, ss = divmod(secs, 60)
hh, mm = divmod(mm, 60)
return f'{hh:0>1.0f}:{mm:0>2.0f}:{ss:0>2.2f}'
def write_result(self, result: dict, file: TextIO, options: dict):
print("start", "end", "text", sep="\t", file=file)
for segment in result["segments"]:
print(round(1000 * segment["start"]), file=file, end="\t")
print(round(1000 * segment["end"]), file=file, end="\t")
print(segment["text"].strip().replace("\t", " "), file=file, flush=True)
def dialogue(chars: str, start: float, end: float, idx_0: int, idx_1: int) -> str:
if idx_0 == -1:
text = chars
else:
text = f'{chars[:idx_0]}{prefmt}{chars[idx_0:idx_1]}{suffmt}{chars[idx_1:]}'
return f"Dialogue: 0,{secs_to_hhmmss(start)},{secs_to_hhmmss(end)}," \
f"Default,,0,0,0,,{text.strip() if strip else text}"
class WriteJSON(ResultWriter):
extension: str = "json"
if resolution == "word":
resolution_key = "word-segments"
elif resolution == "char":
resolution_key = "char-segments"
else:
raise ValueError(".ass resolution should be 'word' or 'char', not ", resolution)
ass_arr = []
def write_result(self, result: dict, file: TextIO, options: dict):
json.dump(result, file)
for segment in transcript:
# if "12" in segment['text']:
# import pdb; pdb.set_trace()
if resolution_key in segment:
res_segs = pd.DataFrame(segment[resolution_key])
prev = segment['start']
if "speaker" in segment:
speaker_str = f"[{segment['speaker']}]: "
else:
speaker_str = ""
for cdx, crow in res_segs.iterrows():
if not np.isnan(crow['start']):
if resolution == "char":
idx_0 = cdx
idx_1 = cdx + 1
elif resolution == "word":
idx_0 = int(crow["segment-text-start"])
idx_1 = int(crow["segment-text-end"])
# fill gap
if crow['start'] > prev:
filler_ts = {
"chars": speaker_str + segment['text'],
"start": prev,
"end": crow['start'],
"idx_0": -1,
"idx_1": -1
}
ass_arr.append(filler_ts)
# highlight current word
f_word_ts = {
"chars": speaker_str + segment['text'],
"start": crow['start'],
"end": crow['end'],
"idx_0": idx_0 + len(speaker_str),
"idx_1": idx_1 + len(speaker_str)
}
ass_arr.append(f_word_ts)
prev = crow['end']
def get_writer(
output_format: str, output_dir: str
) -> Callable[[dict, TextIO, dict], None]:
writers = {
"txt": WriteTXT,
"vtt": WriteVTT,
"srt": WriteSRT,
"tsv": WriteTSV,
"json": WriteJSON,
}
ass_str += '\n'.join(map(lambda x: dialogue(**x), ass_arr))
if output_format == "all":
all_writers = [writer(output_dir) for writer in writers.values()]
file.write(ass_str)
def write_all(result: dict, file: TextIO, options: dict):
for writer in all_writers:
writer(result, file, options)
return write_all
return writers[output_format](output_dir)
def interpolate_nans(x, method='nearest'):
if x.notnull().sum() > 1:

View File

@ -1,10 +1,67 @@
import pandas as pd
import hashlib
import os
import urllib
from typing import Callable, Optional, Text, Union
import numpy as np
from pyannote.core import Annotation, Segment, SlidingWindowFeature, Timeline
from typing import List, Tuple, Optional
import pandas as pd
import torch
from pyannote.audio import Model
from pyannote.audio.core.io import AudioFile
from pyannote.audio.pipelines import VoiceActivityDetection
from pyannote.audio.pipelines.utils import PipelineModel
from pyannote.core import Annotation, Segment, SlidingWindowFeature
from tqdm import tqdm
from .diarize import Segment as SegmentX
VAD_SEGMENTATION_URL = "https://whisperx.s3.eu-west-2.amazonaws.com/model_weights/segmentation/0b5b3216d60a2d32fc086b47ea8c67589aaeb26b7e07fcbe620d6d0b83e209ea/pytorch_model.bin"
def load_vad_model(device, vad_onset=0.500, vad_offset=0.363, use_auth_token=None, model_fp=None):
model_dir = torch.hub._get_torch_home()
os.makedirs(model_dir, exist_ok = True)
if model_fp is None:
model_fp = os.path.join(model_dir, "whisperx-vad-segmentation.bin")
if os.path.exists(model_fp) and not os.path.isfile(model_fp):
raise RuntimeError(f"{model_fp} exists and is not a regular file")
if not os.path.isfile(model_fp):
with urllib.request.urlopen(VAD_SEGMENTATION_URL) as source, open(model_fp, "wb") as output:
with tqdm(
total=int(source.info().get("Content-Length")),
ncols=80,
unit="iB",
unit_scale=True,
unit_divisor=1024,
) as loop:
while True:
buffer = source.read(8192)
if not buffer:
break
output.write(buffer)
loop.update(len(buffer))
model_bytes = open(model_fp, "rb").read()
if hashlib.sha256(model_bytes).hexdigest() != VAD_SEGMENTATION_URL.split('/')[-2]:
raise RuntimeError(
"Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model."
)
vad_model = Model.from_pretrained(model_fp, use_auth_token=use_auth_token)
hyperparameters = {"onset": vad_onset,
"offset": vad_offset,
"min_duration_on": 0.1,
"min_duration_off": 0.1}
vad_pipeline = VoiceActivitySegmentation(segmentation=vad_model, device=torch.device(device))
vad_pipeline.instantiate(hyperparameters)
return vad_pipeline
class Binarize:
"""Binarize detection scores using hysteresis thresholding
"""Binarize detection scores using hysteresis thresholding, with min-cut operation
to ensure not segments are longer than max_duration.
Parameters
----------
onset : float, optional
@ -28,6 +85,9 @@ class Binarize:
Gregory Gelly and Jean-Luc Gauvain. "Minimum Word Error Training of
RNN-based Voice Activity Detection", InterSpeech 2015.
Modified by Max Bain to include WhisperX's min-cut operation
https://arxiv.org/abs/2303.00747
Pyannote-audio
"""
@ -136,6 +196,51 @@ class Binarize:
return active
class VoiceActivitySegmentation(VoiceActivityDetection):
def __init__(
self,
segmentation: PipelineModel = "pyannote/segmentation",
fscore: bool = False,
use_auth_token: Union[Text, None] = None,
**inference_kwargs,
):
super().__init__(segmentation=segmentation, fscore=fscore, use_auth_token=use_auth_token, **inference_kwargs)
def apply(self, file: AudioFile, hook: Optional[Callable] = None) -> Annotation:
"""Apply voice activity detection
Parameters
----------
file : AudioFile
Processed file.
hook : callable, optional
Hook called after each major step of the pipeline with the following
signature: hook("step_name", step_artefact, file=file)
Returns
-------
speech : Annotation
Speech regions.
"""
# setup hook (e.g. for debugging purposes)
hook = self.setup_hook(file, hook=hook)
# apply segmentation model (only if needed)
# output shape is (num_chunks, num_frames, 1)
if self.training:
if self.CACHED_SEGMENTATION in file:
segmentations = file[self.CACHED_SEGMENTATION]
else:
segmentations = self._segmentation(file)
file[self.CACHED_SEGMENTATION] = segmentations
else:
segmentations: SlidingWindowFeature = self._segmentation(file)
return segmentations
def merge_vad(vad_arr, pad_onset=0.0, pad_offset=0.0, min_duration_off=0.0, min_duration_on=0.0):
active = Annotation()
@ -157,29 +262,46 @@ def merge_vad(vad_arr, pad_onset=0.0, pad_offset=0.0, min_duration_off=0.0, min_
active_segs = pd.DataFrame([x['segment'] for x in active['content']])
return active_segs
def merge_chunks(segments, chunk_size):
"""
Merge operation described in paper
"""
curr_end = 0
merged_segments = []
seg_idxs = []
speaker_idxs = []
if __name__ == "__main__":
# from pyannote.audio import Inference
# hook = lambda segmentation: segmentation
# inference = Inference("pyannote/segmentation", pre_aggregation_hook=hook)
# audio = "/tmp/11962.wav"
# scores = inference(audio)
# binarize = Binarize(max_duration=15)
# anno = binarize(scores)
# res = []
# for ann in anno.get_timeline():
# res.append((ann.start, ann.end))
assert chunk_size > 0
binarize = Binarize(max_duration=chunk_size)
segments = binarize(segments)
segments_list = []
for speech_turn in segments.get_timeline():
segments_list.append(SegmentX(speech_turn.start, speech_turn.end, "UNKNOWN"))
# res = pd.DataFrame(res)
# res[2] = res[1] - res[0]
import pandas as pd
input_fp = "tt298650_sync.wav"
df = pd.read_csv(f"/work/maxbain/tmp/{input_fp}.sad", sep=" ", header=None)
print(len(df))
N = 0.15
g = df[0].sub(df[1].shift())
input_base = input_fp.split('.')[0]
df = df.groupby(g.gt(N).cumsum()).agg({0:'min', 1:'max'})
df.to_csv(f"/work/maxbain/tmp/{input_base}.lab", header=None, index=False, sep=" ")
print(df)
import pdb; pdb.set_trace()
if len(segments_list) == 0:
print("No active speech found in audio")
return []
# assert segments_list, "segments_list is empty."
# Make sur the starting point is the start of the segment.
curr_start = segments_list[0].start
for seg in segments_list:
if seg.end - curr_start > chunk_size and curr_end-curr_start > 0:
merged_segments.append({
"start": curr_start,
"end": curr_end,
"segments": seg_idxs,
})
curr_start = seg.start
seg_idxs = []
speaker_idxs = []
curr_end = seg.end
seg_idxs.append((seg.start, seg.end))
speaker_idxs.append(seg.speaker)
# add final
merged_segments.append({
"start": curr_start,
"end": curr_end,
"segments": seg_idxs,
})
return merged_segments