@@ -12,7 +12,7 @@ from .audio import SAMPLE_RATE, N_FRAMES, HOP_LENGTH, CHUNK_LENGTH, pad_or_trim,
from .alignment import get_trellis, backtrack, merge_repeats, merge_words
from .decoding import DecodingOptions, DecodingResult
from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
from .utils import exact_div, format_timestamp, optional_int, optional_float, str2bool, write_txt, write_vtt, write_srt, write_ass
from .utils import exact_div, format_timestamp, optional_int, optional_float, str2bool, interpolate_nans, write_txt, write_vtt, write_srt, write_ass, write_tsv
import pandas as pd

if TYPE_CHECKING:
@@ -280,8 +280,39 @@ def align(
    device: str,
    extend_duration: float = 0.0,
    start_from_previous: bool = True,
    drop_non_aligned_words: bool = False,
    interpolate_method: str = "nearest",
):
"""
|
|
|
|
|
Force align phoneme recognition predictions to known transcription
|
|
|
|
|
|
|
|
|
|
Parameters
|
|
|
|
|
----------
|
|
|
|
|
transcript: Iterator[dict]
|
|
|
|
|
The Whisper model instance
|
|
|
|
|
|
|
|
|
|
model: torch.nn.Module
|
|
|
|
|
Alignment model (wav2vec2)
|
|
|
|
|
|
|
|
|
|
audio: Union[str, np.ndarray, torch.Tensor]
|
|
|
|
|
The path to the audio file to open, or the audio waveform
|
|
|
|
|
|
|
|
|
|
device: str
|
|
|
|
|
cuda device
|
|
|
|
|
|
|
|
|
|
extend_duration: float
|
|
|
|
|
Amount to pad input segments by. If not using vad--filter then recommended to use 2 seconds
|
|
|
|
|
|
|
|
|
|
If the gzip compression ratio is above this value, treat as failed
|
|
|
|
|
|
|
|
|
|
interpolate_method: str ["nearest", "linear", "ignore"]
|
|
|
|
|
Method to assign timestamps to non-aligned words. Words are not able to be aligned when none of the characters occur in the align model dictionary.
|
|
|
|
|
"nearest" copies timestamp of nearest word within the segment. "linear" is linear interpolation. "drop" removes that word from output.
|
|
|
|
|
|
|
|
|
|
Returns
|
|
|
|
|
-------
|
|
|
|
|
A dictionary containing the resulting text ("text") and segment-level details ("segments"), and
|
|
|
|
|
the spoken language ("language"), which is detected when `decode_options["language"]` is None.
|
|
|
|
|
"""
    if not torch.is_tensor(audio):
        if isinstance(audio, str):
            audio = load_audio(audio)
@@ -291,171 +322,266 @@ def align(

    MAX_DURATION = audio.shape[1] / SAMPLE_RATE

    model_dictionary = align_model_metadata['dictionary']
    model_lang = align_model_metadata['language']
    model_type = align_model_metadata['type']
    model_dictionary = align_model_metadata["dictionary"]
    model_lang = align_model_metadata["language"]
    model_type = align_model_metadata["type"]

    aligned_segments = []

    prev_t2 = 0
    total_word_segments_list = []
    vad_segments_list = []
    for idx, segment in enumerate(transcript):
        word_segments_list = []
        # first we pad
        t1 = max(segment['start'] - extend_duration, 0)
        t2 = min(segment['end'] + extend_duration, MAX_DURATION)
    sdx = 0
    for segment in transcript:
        while True:
            segment_align_success = False

        # use prev_t2 as current t1 if it's later
        if start_from_previous and t1 < prev_t2:
            t1 = prev_t2
            # strip spaces at beginning / end, but keep track of the amount.
            num_leading = len(segment["text"]) - len(segment["text"].lstrip())
            num_trailing = len(segment["text"]) - len(segment["text"].rstrip())
            transcription = segment["text"]

        # check if timestamp range is still valid
        if t1 >= MAX_DURATION:
            print("Failed to align segment: original start time longer than audio duration, skipping...")
            continue
        if t2 - t1 < 0.02:
            print("Failed to align segment: duration smaller than 0.02s time precision")
            continue
            # TODO: convert number tokenizer / symbols to phonetic words for alignment.
            # e.g. "$300" -> "three hundred dollars"
            # currently "$300" is ignored since no characters present in the phonetic dictionary

        f1 = int(t1 * SAMPLE_RATE)
        f2 = int(t2 * SAMPLE_RATE)

        waveform_segment = audio[:, f1:f2]

        with torch.inference_mode():
            if model_type == "torchaudio":
                emissions, _ = model(waveform_segment.to(device))
            elif model_type == "huggingface":
                emissions = model(waveform_segment.to(device)).logits
            # split into words
            if model_lang not in LANGUAGES_WITHOUT_SPACES:
                per_word = transcription.split(" ")
            else:
                raise NotImplementedError(f"Align model of type {model_type} not supported.")
            emissions = torch.log_softmax(emissions, dim=-1)
                per_word = transcription

        emission = emissions[0].cpu().detach()
            # first check that characters in transcription can be aligned (they are contained in align model's dictionary)
            clean_char, clean_cdx = [], []
            for cdx, char in enumerate(transcription):
                char_ = char.lower()
                # wav2vec2 models use "|" character to represent spaces
                if model_lang not in LANGUAGES_WITHOUT_SPACES:
                    char_ = char_.replace(" ", "|")

                # ignore whitespace at beginning and end of transcript
                if cdx < num_leading:
                    pass
                elif cdx > len(transcription) - num_trailing - 1:
                    pass
                elif char_ in model_dictionary.keys():
                    clean_char.append(char_)
                    clean_cdx.append(cdx)

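            # Only characters present in the align model's vocabulary can be matched to audio
            # frames; anything else (digits, punctuation, symbols) gets no timestamp here and is
            # handled later at the word level via the interpolate_method fallback.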
if "vad" in segment and len(segment['vad']) > 1 and '|' in model_dictionary:
|
|
|
|
|
ratio = waveform_segment.size(0) / emission.size(0)
|
|
|
|
|
space_idx = model_dictionary['|']
|
|
|
|
|
# find non-vad segments
|
|
|
|
|
for i in range(1, len(segment['vad'])):
|
|
|
|
|
start = segment['vad'][i-1][1]
|
|
|
|
|
end = segment['vad'][i][0]
|
|
|
|
|
if start < end: # check if there is a gap between intervals
|
|
|
|
|
non_vad_f1 = int(start / ratio)
|
|
|
|
|
non_vad_f2 = int(end / ratio)
|
|
|
|
|
# non-vad should be masked, use space to do so
|
|
|
|
|
emission[non_vad_f1:non_vad_f2, :] = float("-inf")
|
|
|
|
|
emission[non_vad_f1:non_vad_f2, space_idx] = 0
|
|
|
|
|
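        # Masking trick: in frames the VAD marked as non-speech, every class gets -inf
        # log-probability except the "|" (space) token, which gets 0, so the aligner can only
        # place word boundaries (never characters) inside silence.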
            clean_wdx = []
            for wdx, wrd in enumerate(per_word):
                if any([c in model_dictionary.keys() for c in wrd]):
                    clean_wdx.append(wdx)

            start = segment['vad'][i][1]
            end = segment['end']
            non_vad_f1 = int(start / ratio)
            non_vad_f2 = int(end / ratio)
            # non-vad should be masked, use space to do so
            emission[non_vad_f1:non_vad_f2, :] = float("-inf")
            emission[non_vad_f1:non_vad_f2, space_idx] = 0

        transcription = segment['text'].strip()
        if model_lang not in LANGUAGES_WITHOUT_SPACES:
            t_words = transcription.split(' ')
        else:
            t_words = [c for c in transcription]

        t_words_clean = [''.join([w for w in word if w.lower() in model_dictionary.keys()]) for word in t_words]
        t_words_nonempty = [x for x in t_words_clean if x != ""]
        t_words_nonempty_idx = [x for x in range(len(t_words_clean)) if t_words_clean[x] != ""]
        segment['word-level'] = []

        fail_fallback = False
        if len(t_words_nonempty) > 0:
            transcription_cleaned = "|".join(t_words_nonempty).lower()
            # if no characters are in the dictionary, then we skip this segment...
            if len(clean_char) == 0:
                print("Failed to align segment: no characters in this segment found in model dictionary, resorting to original...")
                break

            transcription_cleaned = "".join(clean_char)
            tokens = [model_dictionary[c] for c in transcription_cleaned]

            # pad according to original timestamps
            t1 = max(segment["start"] - extend_duration, 0)
            t2 = min(segment["end"] + extend_duration, MAX_DURATION)

            # use prev_t2 as current t1 if it's later
            if start_from_previous and t1 < prev_t2:
                t1 = prev_t2

            # check if timestamp range is still valid
            if t1 >= MAX_DURATION:
                print("Failed to align segment: original start time longer than audio duration, skipping...")
                break
            if t2 - t1 < 0.02:
                print("Failed to align segment: duration smaller than 0.02s time precision")
                break

            f1 = int(t1 * SAMPLE_RATE)
            f2 = int(t2 * SAMPLE_RATE)

            waveform_segment = audio[:, f1:f2]

            with torch.inference_mode():
                if model_type == "torchaudio":
                    emissions, _ = model(waveform_segment.to(device))
                elif model_type == "huggingface":
                    emissions = model(waveform_segment.to(device)).logits
                else:
                    raise NotImplementedError(f"Align model of type {model_type} not supported.")
                emissions = torch.log_softmax(emissions, dim=-1)

            emission = emissions[0].cpu().detach()

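            # "emission" now holds frame-wise log-probabilities over the character vocabulary of
            # the wav2vec2 model; these feed the CTC-style trellis below, which finds the most
            # likely monotonic mapping between audio frames and transcript tokens.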
            trellis = get_trellis(emission, tokens)
            path = backtrack(trellis, emission, tokens)
            if path is None:
                print("Failed to align segment: backtrack failed, resorting to original...")
                fail_fallback = True
            else:
                segments = merge_repeats(path, transcription_cleaned)
                word_segments = merge_words(segments)
                ratio = waveform_segment.size(0) / (trellis.size(0) - 1)
                break
            char_segments = merge_repeats(path, transcription_cleaned)
            # word_segments = merge_words(char_segments)

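            # get_trellis builds the dynamic-programming matrix of cumulative log-probabilities
            # (frames x tokens) and backtrack recovers the best path through it; merge_repeats then
            # collapses repeated frame emissions into one timed span per character.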
                duration = t2 - t1
                local = []
                t_local = [None] * len(t_words)
                for wdx, word in enumerate(word_segments):
                    t1_ = ratio * word.start
                    t2_ = ratio * word.end
                    local.append((t1_, t2_))
                    t_local[t_words_nonempty_idx[wdx]] = (t1_ * duration + t1, t2_ * duration + t1)
                t1_actual = t1 + local[0][0] * duration
                t2_actual = t1 + local[-1][1] * duration
            # sub-segments
            if "seg-text" not in segment:
                segment["seg-text"] = [transcription]

            v = 0
            seg_lens = [0] + [len(x) for x in segment["seg-text"]]
            seg_lens_cumsum = [v := v + n for n in seg_lens]
            sub_seg_idx = 0

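            # seg_lens_cumsum holds the running character offset at which each sub-segment starts
            # (the walrus expression is a cumulative sum), so a character index into the merged
            # transcription can be mapped back to a position inside its own sub-segment.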
                segment['start'] = t1_actual
                segment['end'] = t2_actual
                prev_t2 = segment['end']
            char_level = {
                "start": [],
                "end": [],
                "score": [],
                "word-index": [],
            }

                # for the .ass output
                for x in range(len(t_local)):
                    curr_word = t_words[x]
                    curr_timestamp = t_local[x]
                    if curr_timestamp is not None:
                        segment['word-level'].append({"text": curr_word, "start": curr_timestamp[0], "end": curr_timestamp[1]})
            word_level = {
                "start": [],
                "end": [],
                "score": [],
                "segment-text-start": [],
                "segment-text-end": []
            }

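            # char_level / word_level are column-oriented accumulators: one entry per character
            # (with the index of the word it belongs to) and one entry per word (with its character
            # span inside the sub-segment text); they are converted to DataFrames further down.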
            wdx = 0
            seg_start_actual, seg_end_actual = None, None
            duration = t2 - t1
            ratio = duration * waveform_segment.size(0) / (trellis.size(0) - 1)
            cdx_prev = 0
            for cdx, char in enumerate(transcription + " "):
                is_last = False
                if cdx == len(transcription):
                    break
                elif cdx+1 == len(transcription):
                    is_last = True

                start, end, score = None, None, None
                if cdx in clean_cdx:
                    char_seg = char_segments[clean_cdx.index(cdx)]
                    start = char_seg.start * ratio + t1
                    end = char_seg.end * ratio + t1
                    score = char_seg.score

                char_level["start"].append(start)
                char_level["end"].append(end)
                char_level["score"].append(score)
                char_level["word-index"].append(wdx)

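                # char_seg.start/end are indices along the alignment trellis; multiplying by
                # "ratio" rescales them into seconds within this segment, and adding t1 shifts
                # them to absolute time in the full audio.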
                # word-level info
                if model_lang in LANGUAGES_WITHOUT_SPACES:
                    # character == word
                    wdx += 1
                elif is_last or transcription[cdx+1] == " " or cdx == seg_lens_cumsum[sub_seg_idx+1] - 1:
                    wdx += 1
                    word_level["start"].append(None)
                    word_level["end"].append(None)
                    word_level["score"].append(None)
                    word_level["segment-text-start"].append(cdx_prev-seg_lens_cumsum[sub_seg_idx])
                    word_level["segment-text-end"].append(cdx+1-seg_lens_cumsum[sub_seg_idx])
                    cdx_prev = cdx+2

                if is_last or cdx == seg_lens_cumsum[sub_seg_idx+1] - 1:
                    if model_lang not in LANGUAGES_WITHOUT_SPACES:
                        char_level = pd.DataFrame(char_level)
                        word_level = pd.DataFrame(word_level)

                        not_space = pd.Series(list(segment["seg-text"][sub_seg_idx])) != " "
                        word_level["start"] = char_level[not_space].groupby("word-index")["start"].min()  # take min of all chars in a word ignoring space
                        word_level["end"] = char_level[not_space].groupby("word-index")["end"].max()  # take max of all chars in a word

                        # fill missing
                        if interpolate_method != "ignore":
                            word_level["start"] = interpolate_nans(word_level["start"], method=interpolate_method)
                            word_level["end"] = interpolate_nans(word_level["end"], method=interpolate_method)
                        word_level["start"] = word_level["start"].values.tolist()
                        word_level["end"] = word_level["end"].values.tolist()
                        word_level["score"] = char_level.groupby("word-index")["score"].mean()  # take mean of all scores

                        char_level = char_level.replace({np.nan:None}).to_dict("list")
                        word_level = pd.DataFrame(word_level).replace({np.nan:None}).to_dict("list")
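                        # Word timestamps are derived purely from their characters (min start, max
                        # end); words with no aligned characters stay NaN and are filled here by
                        # interpolate_nans using the chosen interpolate_method, unless "ignore".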
                    else:
                        segment['word-level'].append({"text": curr_word, "start": None, "end": None})
                        word_level = None

                # for per-word .srt output
                # merge missing words to previous, or merge with next word ahead if idx == 0
                found_first_ts = False
                for x in range(len(t_local)):
                    curr_word = t_words[x]
                    curr_timestamp = t_local[x]
                    if curr_timestamp is not None:
                        word_segments_list.append({"text": curr_word, "start": curr_timestamp[0], "end": curr_timestamp[1]})
                        found_first_ts = True
                    elif not drop_non_aligned_words:
                        # then we merge
                        if not found_first_ts:
                            t_words[x+1] = " ".join([curr_word, t_words[x+1]])
                        else:
                            word_segments_list[-1]['text'] += ' ' + curr_word
        else:
            fail_fallback = True

        if fail_fallback:
            # then we resort back to original whisper timestamps
            # segment['start'] and segment['end'] are unchanged
            prev_t2 = 0
            segment['word-level'].append({"text": segment['text'], "start": segment['start'], "end": segment['end']})
            word_segments_list.append({"text": segment['text'], "start": segment['start'], "end": segment['end']})

        if 'vad' in segment:
            curr_vdx = 0
            curr_text = ''
            for wrd_seg in word_segments_list:
                if wrd_seg['start'] > segment['vad'][curr_vdx][1]:
                    curr_speaker = segment['speakers'][curr_vdx]
                    vad_segments_list.append(
                        {'start': segment['vad'][curr_vdx][0],
                         'end': segment['vad'][curr_vdx][1],
                         'text': f"[{curr_speaker}]: " + curr_text.strip()}
                    aligned_segments.append(
                        {
                            "text": segment["seg-text"][sub_seg_idx],
                            "start": seg_start_actual,
                            "end": seg_end_actual,
                            "char-segments": char_level,
                            "word-segments": word_level
                        }
                    )
                    curr_vdx += 1
                    curr_text = ''
                curr_text += ' ' + wrd_seg['text']
            if len(curr_text) > 0:
                curr_speaker = segment['speakers'][curr_vdx]
                vad_segments_list.append(
                    {'start': segment['vad'][curr_vdx][0],
                     'end': segment['vad'][curr_vdx][1],
                     'text': f"[{curr_speaker}]: " + curr_text.strip()}
                )
                curr_text = ''
        total_word_segments_list += word_segments_list
        print(f"[{format_timestamp(segment['start'])} --> {format_timestamp(segment['end'])}] {segment['text']}")
                    if "language" in segment:
                        aligned_segments[-1]["language"] = segment["language"]

                    print(f"[{format_timestamp(aligned_segments[-1]['start'])} --> {format_timestamp(aligned_segments[-1]['end'])}] {aligned_segments[-1]['text']}")

    return {"segments": transcript, "word_segments": total_word_segments_list, "vad_segments": vad_segments_list}
                    char_level = {
                        "start": [],
                        "end": [],
                        "score": [],
                        "word-index": [],
                    }
                    word_level = {
                        "start": [],
                        "end": [],
                        "score": [],
                        "segment-text-start": [],
                        "segment-text-end": []
                    }
                    wdx = 0
                    cdx_prev = cdx + 2
                    sub_seg_idx += 1
                    seg_start_actual, seg_end_actual = None, None

                # take min-max for actual segment-level timestamp
                if seg_start_actual is None and start is not None:
                    seg_start_actual = start
                if end is not None:
                    seg_end_actual = end

            prev_t2 = segment["end"]

            segment_align_success = True
            # end while True loop
            break

        # reset prev_t2 due to drifting issues
        if not segment_align_success:
            prev_t2 = 0

        # shift segment index by amount of sub-segments
        if "seg-text" in segment:
            sdx += len(segment["seg-text"])
        else:
            sdx += 1

    # create word level segments for .srt
    word_seg = []
    for seg in aligned_segments:
        if model_lang in LANGUAGES_WITHOUT_SPACES:
            # character based
            seg["word-segments"] = seg["char-segments"]
            seg["word-segments"]["segment-text-start"] = range(len(seg['word-segments']['start']))
            seg["word-segments"]["segment-text-end"] = range(1, len(seg['word-segments']['start'])+1)

        wseg = pd.DataFrame(seg["word-segments"]).replace({np.nan:None})
        for wdx, wrow in wseg.iterrows():
            if wrow["start"] is not None:
                word_seg.append(
                    {
                        "start": wrow["start"],
                        "end": wrow["end"],
                        "text": seg["text"][int(wrow["segment-text-start"]):int(wrow["segment-text-end"])]
                    }
                )

return {"segments": aligned_segments, "word_segments": word_seg}
|
|
|
|
|
|
|
|
|
|
def load_align_model(language_code, device, model_name=None):
|
|
|
|
|
if model_name is None:
|
|
|
|
@@ -492,11 +618,11 @@ def load_align_model(language_code, device, model_name=None):
    return align_model, align_metadata

def merge_chunks(segments, chunk_size=CHUNK_LENGTH, speakers=False):
    '''
    Merge VAD segments into larger segments of size ~CHUNK_LENGTH.
    '''

def merge_chunks(segments, chunk_size=CHUNK_LENGTH):
    """
    Merge VAD segments into larger segments of size ~CHUNK_LENGTH.
    """
    curr_start = 0
    curr_end = 0
    merged_segments = []
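    # CHUNK_LENGTH (~30 s) matches Whisper's fixed input window, so merging short VAD spans up to
    # this size keeps each transcription call close to a full window without crossing silence.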
@@ -508,7 +634,6 @@ def merge_chunks(segments, chunk_size=CHUNK_LENGTH, speakers=False):
            "start": curr_start,
            "end": curr_end,
            "segments": seg_idxs,
            "speakers": speaker_idxs,
        })
        curr_start = seg.start
        seg_idxs = []
@@ -521,55 +646,107 @@ def merge_chunks(segments, chunk_size=CHUNK_LENGTH, speakers=False):
        "start": curr_start,
        "end": curr_end,
        "segments": seg_idxs,
        "speakers": speaker_idxs
    })
    return merged_segments


def transcribe_segments(
def transcribe_with_vad(
    model: "Whisper",
    audio: Union[str, np.ndarray, torch.Tensor],
    merged_segments,
    vad_pipeline,
    mel = None,
    verbose: Optional[bool] = None,
    **kwargs
):
    '''
    Transcribe according to predefined VAD segments.
    '''
    """
    Transcribe per VAD segment
    """
    if mel is None:
        mel = log_mel_spectrogram(audio)

    prev = 0
    output = {"segments": []}

    output = {'segments': []}
    vad_segments_list = []
    vad_segments = vad_pipeline(audio)
    for speech_turn in vad_segments.get_timeline().support():
        vad_segments_list.append(Segment(speech_turn.start, speech_turn.end, "UNKNOWN"))
    # merge segments to approx 30s inputs to make whisper most appropriate
    vad_segments = merge_chunks(vad_segments_list)

    for sdx, seg_t in enumerate(merged_segments):
        print(sdx, seg_t['start'], seg_t['end'], '...')
        seg_f_start, seg_f_end = int(seg_t['start'] * SAMPLE_RATE / HOP_LENGTH), int(seg_t['end'] * SAMPLE_RATE / HOP_LENGTH)
    for sdx, seg_t in enumerate(vad_segments):
        if verbose:
            print(f"~~ Transcribing VAD chunk: ({format_timestamp(seg_t['start'])} --> {format_timestamp(seg_t['end'])}) ~~")
        seg_f_start, seg_f_end = int(seg_t["start"] * SAMPLE_RATE / HOP_LENGTH), int(seg_t["end"] * SAMPLE_RATE / HOP_LENGTH)
        local_f_start, local_f_end = seg_f_start - prev, seg_f_end - prev
        mel = mel[:, local_f_start:]  # seek forward
        prev = seg_f_start
        local_mel = mel[:, :local_f_end-local_f_start]
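        # Timestamps are converted to mel-frame indices (SAMPLE_RATE / HOP_LENGTH frames per
        # second); the mel is then sliced forward in place so each chunk is cut relative to the
        # previous seek position rather than recomputed from the start of the audio.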
        result = transcribe(model, audio, mel=local_mel, **kwargs)
        seg_t['text'] = result['text']
        output['segments'].append(
        result = transcribe(model, audio, mel=local_mel, verbose=verbose, **kwargs)
        seg_t["text"] = result["text"]
        output["segments"].append(
            {
                'start': seg_t['start'],
                'end': seg_t['end'],
                'language': result['language'],
                'text': result['text'],
                'seg-text': [x['text'] for x in result['segments']],
                'seg-start': [x['start'] for x in result['segments']],
                'seg-end': [x['end'] for x in result['segments']],
                "start": seg_t["start"],
                "end": seg_t["end"],
                "language": result["language"],
                "text": result["text"],
                "seg-text": [x["text"] for x in result["segments"]],
                "seg-start": [x["start"] for x in result["segments"]],
                "seg-end": [x["end"] for x in result["segments"]],
            }
        )

    output['language'] = output['segments'][0]['language']
    output["language"] = output["segments"][0]["language"]

    return output


def assign_word_speakers(diarize_df, result_segments, fill_nearest=False):

    for seg in result_segments:
        wdf = pd.DataFrame(seg['word-segments'])
        if len(wdf['start'].dropna()) == 0:
            wdf['start'] = seg['start']
            wdf['end'] = seg['end']
        speakers = []
        for wdx, wrow in wdf.iterrows():
            diarize_df['intersection'] = np.minimum(diarize_df['end'], wrow['end']) - np.maximum(diarize_df['start'], wrow['start'])
            diarize_df['union'] = np.maximum(diarize_df['end'], wrow['end']) - np.minimum(diarize_df['start'], wrow['start'])
            # remove no hit
            if not fill_nearest:
                dia_tmp = diarize_df[diarize_df['intersection'] > 0]
            else:
                dia_tmp = diarize_df
            if len(dia_tmp) == 0:
                speaker = None
            else:
                speaker = dia_tmp.sort_values("intersection", ascending=False).iloc[0][2]
            speakers.append(speaker)
        seg['word-segments']['speaker'] = speakers
        seg["speaker"] = pd.Series(speakers).value_counts().index[0]

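    # Each word takes the diarization turn with the largest temporal intersection; the
    # segment-level speaker is then the most frequent word-level speaker. With fill_nearest=True,
    # a word with no overlapping turn still takes the turn with the largest (least negative)
    # intersection, i.e. the nearest one.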
    # create word level segments for .srt
    word_seg = []
    for seg in result_segments:
        wseg = pd.DataFrame(seg["word-segments"])
        for wdx, wrow in wseg.iterrows():
            if wrow["start"] is not None:
                speaker = wrow['speaker']
                if speaker is None or speaker == np.nan:
                    speaker = "UNKNOWN"
                word_seg.append(
                    {
                        "start": wrow["start"],
                        "end": wrow["end"],
                        "text": f"[{speaker}]: " + seg["text"][int(wrow["segment-text-start"]):int(wrow["segment-text-end"])]
                    }
                )

    # TODO: create segments but split words on new speaker

    return result_segments, word_seg

class Segment:
    def __init__(self, start, end, speaker=None):
        self.start = start
@@ -589,11 +766,17 @@ def cli():
    parser.add_argument("--align_model", default=None, help="Name of phoneme-level ASR model to do alignment")
    parser.add_argument("--align_extend", default=2, type=float, help="Seconds before and after to extend the whisper segments for alignment")
    parser.add_argument("--align_from_prev", default=True, type=bool, help="Whether to clip the alignment start time of current segment to the end time of the last aligned word of the previous segment")
    parser.add_argument("--drop_non_aligned", action="store_true", help="For word .srt, whether to drop non-aligned words, or merge them into neighbouring.")
    parser.add_argument("--vad_filter", action="store_true", help="Whether to first perform VAD filtering to target only transcribe within VAD...")
    parser.add_argument("--interpolate_method", default="nearest", choices=["nearest", "linear", "ignore"], help="For word .srt, method to assign timestamps to non-aligned words, or merge them into neighbouring.")
    # vad params
    parser.add_argument("--vad_filter", action="store_true", help="Whether to first perform VAD filtering to target only transcribe within VAD. Produces more accurate alignment + timestamp, requires more GPU memory & compute.")
    parser.add_argument("--vad_input", default=None, type=str)
    # diarization params
    parser.add_argument("--diarize", action='store_true')
    parser.add_argument("--min_speakers", default=None, type=int)
    parser.add_argument("--max_speakers", default=None, type=int)
    # output save params
    parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
    parser.add_argument("--output_type", default="srt", choices=['all', 'srt', 'vtt', 'txt'], help="File type for desired output save")
    parser.add_argument("--output_type", default="all", choices=["all", "srt", "srt-word", "vtt", "txt", "tsv", "ass", "ass-char"], help="File type for desired output save")

    parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")
@@ -627,24 +810,32 @@ def cli():
    align_model: str = args.pop("align_model")
    align_extend: float = args.pop("align_extend")
    align_from_prev: bool = args.pop("align_from_prev")
    drop_non_aligned: bool = args.pop("drop_non_aligned")
    interpolate_method: bool = args.pop("interpolate_method")

    vad_filter: bool = args.pop("vad_filter")
    vad_input: bool = args.pop("vad_input")

    diarize: bool = args.pop("diarize")
    min_speakers: int = args.pop("min_speakers")
    max_speakers: int = args.pop("max_speakers")

    vad_pipeline = None
    if vad_input is not None:
        vad_input = pd.read_csv(vad_input, header=None, sep=" ")
    elif vad_filter:
        from pyannote.audio import Pipeline
        vad_pipeline = Pipeline.from_pretrained("pyannote/voice-activity-detection")
        # vad_pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization@2.1")

    diarize_pipeline = None
    if diarize:
        from pyannote.audio import Pipeline
        diarize_pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization@2.1")

    os.makedirs(output_dir, exist_ok=True)

    if model_name.endswith(".en") and args["language"] not in {"en", "English"}:
        if args["language"] is not None:
            warnings.warn(f"{model_name} is an English-only model but received '{args['language']}'; using English instead.")
            warnings.warn(f'{model_name} is an English-only model but received "{args["language"]}"; using English instead.')
        args["language"] = "en"

    temperature = args.pop("temperature")
@@ -665,24 +856,10 @@ def cli():
    align_model, align_metadata = load_align_model(align_language, device, model_name=align_model)

    for audio_path in args.pop("audio"):
        if vad_filter or vad_input is not None:
            output_segments = []
            if vad_filter:
                print("Performing VAD...")
                # vad_segments = vad_pipeline(audio_path)
                # for speech_turn, track, speaker in vad_segments.itertracks(yield_label=True):
                # output_segments.append(Segment(speech_turn.start, speech_turn.end, speaker))
                vad_segments = vad_pipeline(audio_path)
                for speech_turn in vad_segments.get_timeline().support():
                    output_segments.append(Segment(speech_turn.start, speech_turn.end, "UNKNOWN"))
            elif vad_input is not None:
                # rttm format
                for idx, row in vad_input.iterrows():
                    output_segments.append(Segment(row[3], row[3]+row[4], f"SPEAKER {row[7]}"))
            vad_segments = merge_chunks(output_segments)
            result = transcribe_segments(model, audio_path, merged_segments=vad_segments, temperature=temperature, **args)
        if vad_filter:
            print("Performing VAD...")
            result = transcribe_with_vad(model, audio_path, vad_pipeline, temperature=temperature, **args)
        else:
            vad_segments = None
            print("Performing transcription...")
            result = transcribe(model, audio_path, temperature=temperature, **args)
@@ -693,9 +870,20 @@ def cli():

        print("Performing alignment...")
        result_aligned = align(result["segments"], align_model, align_metadata, audio_path, device,
                               extend_duration=align_extend, start_from_previous=align_from_prev, drop_non_aligned_words=drop_non_aligned)
                               extend_duration=align_extend, start_from_previous=align_from_prev, interpolate_method=interpolate_method)
        audio_basename = os.path.basename(audio_path)

        if diarize:
            print("Performing diarization...")
            diarize_segments = diarize_pipeline(audio_path, min_speakers=min_speakers, max_speakers=max_speakers)
            diarize_df = pd.DataFrame(diarize_segments.itertracks(yield_label=True))
            diarize_df['start'] = diarize_df[0].apply(lambda x: x.start)
            diarize_df['end'] = diarize_df[0].apply(lambda x: x.end)
            # assumes each utterance is single speaker (needs fix)
            result_segments, word_segments = assign_word_speakers(diarize_df, result_aligned["segments"], fill_nearest=True)
            result_aligned["segments"] = result_segments
            result_aligned["word_segments"] = word_segments

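        # itertracks(yield_label=True) yields (segment, track, label) tuples, so column 0 holds
        # the pyannote Segment (its start/end are pulled out above) and column 2 holds the speaker
        # label that assign_word_speakers reads via iloc[0][2].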
        # save TXT
        if output_type in ["txt", "all"]:
            with open(os.path.join(output_dir, audio_basename + ".txt"), "w", encoding="utf-8") as txt:
@@ -711,19 +899,27 @@ def cli():
            with open(os.path.join(output_dir, audio_basename + ".srt"), "w", encoding="utf-8") as srt:
                write_srt(result_aligned["segments"], file=srt)

        # save per-word SRT
        with open(os.path.join(output_dir, audio_basename + ".word.srt"), "w", encoding="utf-8") as srt:
            write_srt(result_aligned["word_segments"], file=srt)
        # save TSV
        if output_type in ["tsv", "all"]:
            with open(os.path.join(output_dir, audio_basename + ".tsv"), "w", encoding="utf-8") as tsv:
                write_tsv(result_aligned["segments"], file=tsv)

        # save SRT word-level
        if output_type in ["srt-word", "all"]:
            # save per-word SRT
            with open(os.path.join(output_dir, audio_basename + ".word.srt"), "w", encoding="utf-8") as srt:
                write_srt(result_aligned["word_segments"], file=srt)

        # save ASS
        with open(os.path.join(output_dir, audio_basename + ".ass"), "w", encoding="utf-8") as ass:
            write_ass(result_aligned["segments"], file=ass)
        if output_type in ["ass", "all"]:
            with open(os.path.join(output_dir, audio_basename + ".ass"), "w", encoding="utf-8") as ass:
                write_ass(result_aligned["segments"], file=ass)

        if vad_filter is not None:
            # save per-word SRT
            with open(os.path.join(output_dir, audio_basename + ".vad.srt"), "w", encoding="utf-8") as srt:
                write_srt(result_aligned["vad_segments"], file=srt)
        # save ASS character-level
        if output_type in ["ass-char", "all"]:
            with open(os.path.join(output_dir, audio_basename + ".char.ass"), "w", encoding="utf-8") as ass:
                write_ass(result_aligned["segments"], file=ass, resolution="char")


if __name__ == '__main__':
if __name__ == "__main__":
    cli()