Files
whisperX/whisperx/alignment.py
2023-01-25 18:42:52 +00:00

522 lines
21 KiB
Python

""""
Forced Alignment with Whisper
C. Max Bain
"""
import numpy as np
import pandas as pd
from typing import List, Union, Iterator, TYPE_CHECKING
from transformers import AutoProcessor, Wav2Vec2ForCTC
import torchaudio
import torch
from dataclasses import dataclass
from .audio import SAMPLE_RATE, load_audio
from .utils import interpolate_nans
LANGUAGES_WITHOUT_SPACES = ["ja", "zh"]
DEFAULT_ALIGN_MODELS_TORCH = {
"en": "WAV2VEC2_ASR_BASE_960H",
"fr": "VOXPOPULI_ASR_BASE_10K_FR",
"de": "VOXPOPULI_ASR_BASE_10K_DE",
"es": "VOXPOPULI_ASR_BASE_10K_ES",
"it": "VOXPOPULI_ASR_BASE_10K_IT",
}
DEFAULT_ALIGN_MODELS_HF = {
"ja": "jonatasgrosman/wav2vec2-large-xlsr-53-japanese",
"zh": "jonatasgrosman/wav2vec2-large-xlsr-53-chinese-zh-cn",
"nl": "jonatasgrosman/wav2vec2-large-xlsr-53-dutch",
"uk": "Yehor/wav2vec2-xls-r-300m-uk-with-small-lm",
"pt": "jonatasgrosman/wav2vec2-large-xlsr-53-portuguese",
}
def load_align_model(language_code, device, model_name=None):
if model_name is None:
# use default model
if language_code in DEFAULT_ALIGN_MODELS_TORCH:
model_name = DEFAULT_ALIGN_MODELS_TORCH[language_code]
elif language_code in DEFAULT_ALIGN_MODELS_HF:
model_name = DEFAULT_ALIGN_MODELS_HF[language_code]
else:
print(f"There is no default alignment model set for this language ({language_code}).\
Please find a wav2vec2.0 model finetuned on this language in https://huggingface.co/models, then pass the model name in --align_model [MODEL_NAME]")
raise ValueError(f"No default align-model for language: {language_code}")
if model_name in torchaudio.pipelines.__all__:
pipeline_type = "torchaudio"
bundle = torchaudio.pipelines.__dict__[model_name]
align_model = bundle.get_model().to(device)
labels = bundle.get_labels()
align_dictionary = {c.lower(): i for i, c in enumerate(labels)}
else:
try:
processor = AutoProcessor.from_pretrained(model_name)
align_model = Wav2Vec2ForCTC.from_pretrained(model_name)
except Exception as e:
print(e)
print(f"Error loading model from huggingface, check https://huggingface.co/models for finetuned wav2vec2.0 models")
raise ValueError(f'The chosen align_model "{model_name}" could not be found in huggingface (https://huggingface.co/models) or torchaudio (https://pytorch.org/audio/stable/pipelines.html#id14)')
pipeline_type = "huggingface"
align_model = align_model.to(device)
labels = processor.tokenizer.get_vocab()
align_dictionary = {char.lower(): code for char,code in processor.tokenizer.get_vocab().items()}
align_metadata = {"language": language_code, "dictionary": align_dictionary, "type": pipeline_type}
return align_model, align_metadata
def align(
transcript: Iterator[dict],
model: torch.nn.Module,
align_model_metadata: dict,
audio: Union[str, np.ndarray, torch.Tensor],
device: str,
extend_duration: float = 0.0,
start_from_previous: bool = True,
interpolate_method: str = "nearest",
):
"""
Force align phoneme recognition predictions to known transcription
Parameters
----------
transcript: Iterator[dict]
The Whisper model instance
model: torch.nn.Module
Alignment model (wav2vec2)
audio: Union[str, np.ndarray, torch.Tensor]
The path to the audio file to open, or the audio waveform
device: str
cuda device
diarization: pd.DataFrame {'start': List[float], 'end': List[float], 'speaker': List[float]}
diarization segments with speaker labels.
extend_duration: float
Amount to pad input segments by. If not using vad--filter then recommended to use 2 seconds
If the gzip compression ratio is above this value, treat as failed
interpolate_method: str ["nearest", "linear", "ignore"]
Method to assign timestamps to non-aligned words. Words are not able to be aligned when none of the characters occur in the align model dictionary.
"nearest" copies timestamp of nearest word within the segment. "linear" is linear interpolation. "drop" removes that word from output.
Returns
-------
A dictionary containing the resulting text ("text") and segment-level details ("segments"), and
the spoken language ("language"), which is detected when `decode_options["language"]` is None.
"""
if not torch.is_tensor(audio):
if isinstance(audio, str):
audio = load_audio(audio)
audio = torch.from_numpy(audio)
if len(audio.shape) == 1:
audio = audio.unsqueeze(0)
MAX_DURATION = audio.shape[1] / SAMPLE_RATE
model_dictionary = align_model_metadata["dictionary"]
model_lang = align_model_metadata["language"]
model_type = align_model_metadata["type"]
aligned_segments = []
prev_t2 = 0
char_segments_arr = {
"segment-idx": [],
"subsegment-idx": [],
"word-idx": [],
"char": [],
"start": [],
"end": [],
"score": [],
}
for sdx, segment in enumerate(transcript):
while True:
segment_align_success = False
# strip spaces at beginning / end, but keep track of the amount.
num_leading = len(segment["text"]) - len(segment["text"].lstrip())
num_trailing = len(segment["text"]) - len(segment["text"].rstrip())
transcription = segment["text"]
# TODO: convert number tokenizer / symbols to phonetic words for alignment.
# e.g. "$300" -> "three hundred dollars"
# currently "$300" is ignored since no characters present in the phonetic dictionary
# split into words
if model_lang not in LANGUAGES_WITHOUT_SPACES:
per_word = transcription.split(" ")
else:
per_word = transcription
# first check that characters in transcription can be aligned (they are contained in align model"s dictionary)
clean_char, clean_cdx = [], []
for cdx, char in enumerate(transcription):
char_ = char.lower()
# wav2vec2 models use "|" character to represent spaces
if model_lang not in LANGUAGES_WITHOUT_SPACES:
char_ = char_.replace(" ", "|")
# ignore whitespace at beginning and end of transcript
if cdx < num_leading:
pass
elif cdx > len(transcription) - num_trailing - 1:
pass
elif char_ in model_dictionary.keys():
clean_char.append(char_)
clean_cdx.append(cdx)
clean_wdx = []
for wdx, wrd in enumerate(per_word):
if any([c in model_dictionary.keys() for c in wrd]):
clean_wdx.append(wdx)
# if no characters are in the dictionary, then we skip this segment...
if len(clean_char) == 0:
print("Failed to align segment: no characters in this segment found in model dictionary, resorting to original...")
break
transcription_cleaned = "".join(clean_char)
tokens = [model_dictionary[c] for c in transcription_cleaned]
# pad according original timestamps
t1 = max(segment["start"] - extend_duration, 0)
t2 = min(segment["end"] + extend_duration, MAX_DURATION)
# use prev_t2 as current t1 if it"s later
if start_from_previous and t1 < prev_t2:
t1 = prev_t2
# check if timestamp range is still valid
if t1 >= MAX_DURATION:
print("Failed to align segment: original start time longer than audio duration, skipping...")
break
if t2 - t1 < 0.02:
print("Failed to align segment: duration smaller than 0.02s time precision")
break
f1 = int(t1 * SAMPLE_RATE)
f2 = int(t2 * SAMPLE_RATE)
waveform_segment = audio[:, f1:f2]
with torch.inference_mode():
if model_type == "torchaudio":
emissions, _ = model(waveform_segment.to(device))
elif model_type == "huggingface":
emissions = model(waveform_segment.to(device)).logits
else:
raise NotImplementedError(f"Align model of type {model_type} not supported.")
emissions = torch.log_softmax(emissions, dim=-1)
emission = emissions[0].cpu().detach()
trellis = get_trellis(emission, tokens)
path = backtrack(trellis, emission, tokens)
if path is None:
print("Failed to align segment: backtrack failed, resorting to original...")
break
char_segments = merge_repeats(path, transcription_cleaned)
# word_segments = merge_words(char_segments)
# sub-segments
if "seg-text" not in segment:
segment["seg-text"] = [transcription]
v = 0
seg_lens = [0] + [len(x) for x in segment["seg-text"]]
seg_lens_cumsum = [v := v + n for n in seg_lens]
sub_seg_idx = 0
wdx = 0
duration = t2 - t1
ratio = duration * waveform_segment.size(0) / (trellis.size(0) - 1)
for cdx, char in enumerate(transcription + " "):
is_last = False
if cdx == len(transcription):
break
elif cdx+1 == len(transcription):
is_last = True
start, end, score = None, None, None
if cdx in clean_cdx:
char_seg = char_segments[clean_cdx.index(cdx)]
start = char_seg.start * ratio + t1
end = char_seg.end * ratio + t1
score = char_seg.score
char_segments_arr["char"].append(char)
char_segments_arr["start"].append(start)
char_segments_arr["end"].append(end)
char_segments_arr["score"].append(score)
char_segments_arr["word-idx"].append(wdx)
char_segments_arr["segment-idx"].append(sdx)
char_segments_arr["subsegment-idx"].append(sub_seg_idx)
# word-level info
if model_lang in LANGUAGES_WITHOUT_SPACES:
# character == word
wdx += 1
elif is_last or transcription[cdx+1] == " " or cdx == seg_lens_cumsum[sub_seg_idx+1] - 1:
wdx += 1
if is_last or cdx == seg_lens_cumsum[sub_seg_idx+1] - 1:
wdx = 0
sub_seg_idx += 1
prev_t2 = segment["end"]
segment_align_success = True
# end while True loop
break
# reset prev_t2 due to drifting issues
if not segment_align_success:
prev_t2 = 0
char_segments_arr = pd.DataFrame(char_segments_arr)
not_space = char_segments_arr["char"] != " "
per_seg_grp = char_segments_arr.groupby(["segment-idx", "subsegment-idx"], as_index = False)
char_segments_arr = per_seg_grp.apply(lambda x: x.reset_index(drop = True)).reset_index()
per_word_grp = char_segments_arr[not_space].groupby(["segment-idx", "subsegment-idx", "word-idx"])
per_subseg_grp = char_segments_arr[not_space].groupby(["segment-idx", "subsegment-idx"])
per_seg_grp = char_segments_arr[not_space].groupby(["segment-idx"])
word_segments_arr = {}
# start of word is first char with a timestamp
word_segments_arr["start"] = per_word_grp["start"].min().reset_index()["start"]
# end of word is last char with a timestamp
word_segments_arr["end"] = per_word_grp["end"].max().reset_index()["end"]
# score of word is mean (excluding nan)
word_segments_arr["score"] = per_word_grp["score"].mean().reset_index()["score"]
word_segments_arr["segment-text-start"] = per_word_grp["level_1"].min().reset_index()["level_1"]
word_segments_arr["segment-text-end"] = per_word_grp["level_1"].max().reset_index()["level_1"] + 1
word_segments_arr["segment-idx"] = per_word_grp["level_1"].min().reset_index()["segment-idx"]
word_segments_arr = pd.DataFrame(word_segments_arr)
word_segments_arr[["segment-idx", "subsegment-idx", "word-idx"]] = per_word_grp["level_1"].min().reset_index()[["segment-idx", "subsegment-idx", "word-idx"]]
segments_arr = {}
segments_arr["start"] = per_subseg_grp["start"].min().reset_index()["start"]
segments_arr["end"] = per_subseg_grp["end"].min().reset_index()["end"]
segments_arr = pd.DataFrame(segments_arr)
segments_arr[["segment-idx", "subsegment-idx-start"]] = per_subseg_grp["start"].min().reset_index()[["segment-idx", "subsegment-idx"]]
segments_arr["subsegment-idx-end"] = segments_arr["subsegment-idx-start"] + 1
# interpolate missing words / sub-segments
if interpolate_method != "ignore":
wrd_subseg_grp = word_segments_arr.groupby(["segment-idx", "subsegment-idx"])
wrd_seg_grp = word_segments_arr.groupby(["segment-idx"])
# we still know which word timestamps are interpolated because their score == nan
word_segments_arr["start"] = wrd_subseg_grp['start'].apply(lambda group: interpolate_nans(group, method=interpolate_method))
word_segments_arr["end"] = wrd_subseg_grp['end'].apply(lambda group: interpolate_nans(group, method=interpolate_method))
word_segments_arr["start"] = wrd_seg_grp['start'].apply(lambda group: interpolate_nans(group, method=interpolate_method))
word_segments_arr["end"] = wrd_seg_grp['end'].apply(lambda group: interpolate_nans(group, method=interpolate_method))
sub_seg_grp = segments_arr.groupby(["segment-idx"])
segments_arr['start'] = sub_seg_grp['start'].apply(lambda group: interpolate_nans(group, method=interpolate_method))
segments_arr['end'] = sub_seg_grp['end'].apply(lambda group: interpolate_nans(group, method=interpolate_method))
# merge subsegments which are missing times
# group by sub seg and time.
seg_grp_dup = segments_arr.groupby(["segment-idx", "start", "end"])
segments_arr["subsegment-idx-start"] = seg_grp_dup["subsegment-idx-start"].transform(min)
segments_arr["subsegment-idx-end"] = seg_grp_dup["subsegment-idx-end"].transform(max)
segments_arr.drop_duplicates(subset=["segment-idx", "subsegment-idx-start", "subsegment-idx-end"], inplace=True)
else:
word_segments_arr.dropna(inplace=True)
segments_arr.dropna(inplace=True)
aligned_segments = []
aligned_segments_word = []
word_segments_arr.set_index(["segment-idx", "subsegment-idx"], inplace=True)
char_segments_arr.set_index(["segment-idx", "subsegment-idx", "word-idx"], inplace=True)
for sdx, srow in segments_arr.iterrows():
seg_idx = int(srow["segment-idx"])
sub_start = int(srow["subsegment-idx-start"])
sub_end = int(srow["subsegment-idx-end"])
seg = transcript[seg_idx]
text = "".join(seg["seg-text"][sub_start:sub_end])
wseg = word_segments_arr.loc[seg_idx].loc[sub_start:sub_end-1]
cseg = char_segments_arr.loc[seg_idx].loc[sub_start:sub_end-1]
cseg['segment-text-start'] = cseg['level_1']
cseg['segment-text-end'] = cseg['level_1'] + 1
del cseg['level_1']
del cseg['level_0']
cseg.reset_index(inplace=True)
aligned_segments.append(
{
"start": srow["start"],
"end": srow["end"],
"text": text,
"word-segments": wseg,
"char-segments": cseg
}
)
def get_raw_text(word_row):
return seg["seg-text"][word_row.name][int(word_row["segment-text-start"]):int(word_row["segment-text-end"])+1]
wdx = 0
curr_text = get_raw_text(wseg.iloc[wdx])
if len(wseg) > 1:
for _, wrow in wseg.iloc[1:].iterrows():
if wrow['start'] != wseg.iloc[wdx]['start']:
aligned_segments_word.append(
{
"text": curr_text.strip(),
"start": wseg.iloc[wdx]["start"],
"end": wseg.iloc[wdx]["end"],
}
)
curr_text = ""
curr_text += " " + get_raw_text(wrow)
wdx += 1
aligned_segments_word.append(
{
"text": curr_text.strip(),
"start": wseg.iloc[wdx]["start"],
"end": wseg.iloc[wdx]["end"]
}
)
return {"segments": aligned_segments, "word_segments": aligned_segments_word}
"""
source: https://pytorch.org/tutorials/intermediate/forced_alignment_with_torchaudio_tutorial.html
"""
def get_trellis(emission, tokens, blank_id=0):
num_frame = emission.size(0)
num_tokens = len(tokens)
# Trellis has extra diemsions for both time axis and tokens.
# The extra dim for tokens represents <SoS> (start-of-sentence)
# The extra dim for time axis is for simplification of the code.
trellis = torch.empty((num_frame + 1, num_tokens + 1))
trellis[0, 0] = 0
trellis[1:, 0] = torch.cumsum(emission[:, 0], 0)
trellis[0, -num_tokens:] = -float("inf")
trellis[-num_tokens:, 0] = float("inf")
for t in range(num_frame):
trellis[t + 1, 1:] = torch.maximum(
# Score for staying at the same token
trellis[t, 1:] + emission[t, blank_id],
# Score for changing to the next token
trellis[t, :-1] + emission[t, tokens],
)
return trellis
@dataclass
class Point:
token_index: int
time_index: int
score: float
def backtrack(trellis, emission, tokens, blank_id=0):
# Note:
# j and t are indices for trellis, which has extra dimensions
# for time and tokens at the beginning.
# When referring to time frame index `T` in trellis,
# the corresponding index in emission is `T-1`.
# Similarly, when referring to token index `J` in trellis,
# the corresponding index in transcript is `J-1`.
j = trellis.size(1) - 1
t_start = torch.argmax(trellis[:, j]).item()
path = []
for t in range(t_start, 0, -1):
# 1. Figure out if the current position was stay or change
# Note (again):
# `emission[J-1]` is the emission at time frame `J` of trellis dimension.
# Score for token staying the same from time frame J-1 to T.
stayed = trellis[t - 1, j] + emission[t - 1, blank_id]
# Score for token changing from C-1 at T-1 to J at T.
changed = trellis[t - 1, j - 1] + emission[t - 1, tokens[j - 1]]
# 2. Store the path with frame-wise probability.
prob = emission[t - 1, tokens[j - 1] if changed > stayed else 0].exp().item()
# Return token index and time index in non-trellis coordinate.
path.append(Point(j - 1, t - 1, prob))
# 3. Update the token
if changed > stayed:
j -= 1
if j == 0:
break
else:
# failed
return None
return path[::-1]
# Merge the labels
@dataclass
class Segment:
label: str
start: int
end: int
score: float
def __repr__(self):
return f"{self.label}\t({self.score:4.2f}): [{self.start:5d}, {self.end:5d})"
@property
def length(self):
return self.end - self.start
def merge_repeats(path, transcript):
i1, i2 = 0, 0
segments = []
while i1 < len(path):
while i2 < len(path) and path[i1].token_index == path[i2].token_index:
i2 += 1
score = sum(path[k].score for k in range(i1, i2)) / (i2 - i1)
segments.append(
Segment(
transcript[path[i1].token_index],
path[i1].time_index,
path[i2 - 1].time_index + 1,
score,
)
)
i1 = i2
return segments
def merge_words(segments, separator="|"):
words = []
i1, i2 = 0, 0
while i1 < len(segments):
if i2 >= len(segments) or segments[i2].label == separator:
if i1 != i2:
segs = segments[i1:i2]
word = "".join([seg.label for seg in segs])
score = sum(seg.score * seg.length for seg in segs) / sum(seg.length for seg in segs)
words.append(Segment(word, segments[i1].start, segments[i2 - 1].end, score))
i1 = i2 + 1
i2 = i1
else:
i2 += 1
return words