Mirror of https://github.com/m-bain/whisperX.git

Commit: fix long segments by breaking them into sentences using nltk, improve align logic, improve diarization (sentence-based)
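The sentence splitting this commit introduces is NLTK's Punkt tokenizer; a minimal sketch (the sample text is illustrative, not from the repo) of how span_tokenize yields the character offsets that the new alignment code stores as sentence_spans:

import nltk

# span_tokenize returns (start, end) character offsets per sentence, so
# sentence boundaries can be mapped back onto character-level alignments.
text = "Thank you. So I grew up in a small town."
tokenizer = nltk.tokenize.punkt.PunktSentenceTokenizer()
sentence_spans = list(tokenizer.span_tokenize(text))
print(sentence_spans)          # e.g. [(0, 10), (11, 40)]
for start, end in sentence_spans:
    print(text[start:end])     # each sentence recovered by slicing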
requirements.txt
@@ -5,3 +5,4 @@ transformers
 ffmpeg-python==0.2.0
 pandas
 setuptools==65.6.3
+nltk
whisperx/alignment.py
@@ -13,6 +13,7 @@ from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

 from .audio import SAMPLE_RATE, load_audio
 from .utils import interpolate_nans
+import nltk

 LANGUAGES_WITHOUT_SPACES = ["ja", "zh"]

@@ -84,44 +85,13 @@ def align(
     align_model_metadata: dict,
     audio: Union[str, np.ndarray, torch.Tensor],
     device: str,
-    extend_duration: float = 0.0,
-    start_from_previous: bool = True,
     interpolate_method: str = "nearest",
+    return_char_alignments: bool = False,
 ):
     """
-    Force align phoneme recognition predictions to known transcription
-
-    Parameters
-    ----------
-    transcript: Iterator[dict]
-        The Whisper model instance
-
-    model: torch.nn.Module
-        Alignment model (wav2vec2)
-
-    audio: Union[str, np.ndarray, torch.Tensor]
-        The path to the audio file to open, or the audio waveform
-
-    device: str
-        cuda device
-
-    diarization: pd.DataFrame {'start': List[float], 'end': List[float], 'speaker': List[float]}
-        diarization segments with speaker labels.
-
-    extend_duration: float
-        Amount to pad input segments by. If not using vad--filter then recommended to use 2 seconds
-
-        If the gzip compression ratio is above this value, treat as failed
-
-    interpolate_method: str ["nearest", "linear", "ignore"]
-        Method to assign timestamps to non-aligned words. Words are not able to be aligned when none of the characters occur in the align model dictionary.
-        "nearest" copies timestamp of nearest word within the segment. "linear" is linear interpolation. "drop" removes that word from output.
-
-    Returns
-    -------
-    A dictionary containing the resulting text ("text") and segment-level details ("segments"), and
-    the spoken language ("language"), which is detected when `decode_options["language"]` is None.
+    Align phoneme recognition predictions to known transcription.
     """

     if not torch.is_tensor(audio):
         if isinstance(audio, str):
             audio = load_audio(audio)
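For orientation, a minimal sketch of calling the updated align() signature (it mirrors the call site in whisperx/transcribe.py later in this diff; the toy segment, audio path and device are illustrative assumptions):

from whisperx.alignment import align, load_align_model
from whisperx.audio import load_audio

# Illustrative inputs: one whisper segment plus a local audio file.
segments = [{"start": 0.0, "end": 2.5, "text": "Thank you."}]
audio = load_audio("audio.wav")

model_a, metadata = load_align_model("en", "cuda")
result = align(segments, model_a, metadata, audio, "cuda",
               interpolate_method="nearest", return_char_alignments=False)
# result["segments"] holds sentence-level segments with per-word timings,
# result["word_segments"] the flattened word list.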
@@ -135,42 +105,21 @@ def align(
     model_lang = align_model_metadata["language"]
     model_type = align_model_metadata["type"]

-    aligned_segments = []
-    prev_t2 = 0
-
-    char_segments_arr = {
-        "segment-idx": [],
-        "subsegment-idx": [],
-        "word-idx": [],
-        "char": [],
-        "start": [],
-        "end": [],
-        "score": [],
-    }
+    # 1. Preprocess to keep only characters in dictionary

     for sdx, segment in enumerate(transcript):
-        while True:
-            segment_align_success = False
-
         # strip spaces at beginning / end, but keep track of the amount.
         num_leading = len(segment["text"]) - len(segment["text"].lstrip())
         num_trailing = len(segment["text"]) - len(segment["text"].rstrip())
-        transcription = segment["text"]
+        text = segment["text"]

-        # TODO: convert number tokenizer / symbols to phonetic words for alignment.
-        # e.g. "$300" -> "three hundred dollars"
-        # currently "$300" is ignored since no characters present in the phonetic dictionary

         # split into words
         if model_lang not in LANGUAGES_WITHOUT_SPACES:
-            per_word = transcription.split(" ")
+            per_word = text.split(" ")
         else:
-            per_word = transcription
+            per_word = text

-        # first check that characters in transcription can be aligned (they are contained in align model"s dictionary)
         clean_char, clean_cdx = [], []
-        for cdx, char in enumerate(transcription):
+        for cdx, char in enumerate(text):
             char_ = char.lower()
             # wav2vec2 models use "|" character to represent spaces
             if model_lang not in LANGUAGES_WITHOUT_SPACES:
@@ -179,7 +128,7 @@ def align(
             # ignore whitespace at beginning and end of transcript
             if cdx < num_leading:
                 pass
-            elif cdx > len(transcription) - num_trailing - 1:
+            elif cdx > len(text) - num_trailing - 1:
                 pass
             elif char_ in model_dictionary.keys():
                 clean_char.append(char_)
@@ -190,35 +139,49 @@ def align(
             if any([c in model_dictionary.keys() for c in wrd]):
                 clean_wdx.append(wdx)

-        # if no characters are in the dictionary, then we skip this segment...
-        if len(clean_char) == 0:
+        sentence_spans = list(nltk.tokenize.punkt.PunktSentenceTokenizer().span_tokenize(text))
+
+        segment["clean_char"] = clean_char
+        segment["clean_cdx"] = clean_cdx
+        segment["clean_wdx"] = clean_wdx
+        segment["sentence_spans"] = sentence_spans
+
+    aligned_segments = []
+
+    # 2. Get prediction matrix from alignment model & align
+    for sdx, segment in enumerate(transcript):
+        t1 = segment["start"]
+        t2 = segment["end"]
+        text = segment["text"]
+
+        aligned_seg = {
+            "start": t1,
+            "end": t2,
+            "text": text,
+            "words": [],
+        }
+
+        if return_char_alignments:
+            aligned_seg["chars"] = []
+
+        # check we can align
+        if len(segment["clean_char"]) == 0:
             print(f'Failed to align segment ("{segment["text"]}"): no characters in this segment found in model dictionary, resorting to original...')
-            break
+            aligned_segments.append(aligned_seg)
+            continue

-        transcription_cleaned = "".join(clean_char)
-        tokens = [model_dictionary[c] for c in transcription_cleaned]
-
-        # we only pad if not using VAD filtering
-        if "seg_text" not in segment:
-            # pad according original timestamps
-            t1 = max(segment["start"] - extend_duration, 0)
-            t2 = min(segment["end"] + extend_duration, MAX_DURATION)
-
-        # use prev_t2 as current t1 if it"s later
-        if start_from_previous and t1 < prev_t2:
-            t1 = prev_t2
-
-        # check if timestamp range is still valid
-        if t1 >= MAX_DURATION:
+        if t1 >= MAX_DURATION or t2 - t1 < 0.02:
             print("Failed to align segment: original start time longer than audio duration, skipping...")
-            break
-        if t2 - t1 < 0.02:
-            print("Failed to align segment: duration smaller than 0.02s time precision")
-            break
+            aligned_segments.append(aligned_seg)
+            continue
+
+        text_clean = "".join(segment["clean_char"])
+        tokens = [model_dictionary[c] for c in text_clean]

         f1 = int(t1 * SAMPLE_RATE)
         f2 = int(t2 * SAMPLE_RATE)

+        # TODO: Probably can get some speedup gain with batched inference here
         waveform_segment = audio[:, f1:f2]

         with torch.inference_mode():
@@ -239,231 +202,109 @@ def align(

         trellis = get_trellis(emission, tokens, blank_id)
         path = backtrack(trellis, emission, tokens, blank_id)

         if path is None:
             print(f'Failed to align segment ("{segment["text"]}"): backtrack failed, resorting to original...')
-            break
+            aligned_segments.append(aligned_seg)
+            continue

-        char_segments = merge_repeats(path, transcription_cleaned)
-        # word_segments = merge_words(char_segments)
+        char_segments = merge_repeats(path, text_clean)

-        # sub-segments
-        if "seg-text" not in segment:
-            segment["seg-text"] = [transcription]
-
-        seg_lens = [0] + [len(x) for x in segment["seg-text"]]
-        seg_lens_cumsum = list(np.cumsum(seg_lens))
-        sub_seg_idx = 0
-
-        wdx = 0
-        duration = t2 - t1
+        duration = t2 -t1
         ratio = duration * waveform_segment.size(0) / (trellis.size(0) - 1)
-        for cdx, char in enumerate(transcription + " "):
-            is_last = False
-            if cdx == len(transcription):
-                break
-            elif cdx+1 == len(transcription):
-                is_last = True
+
+        # assign timestamps to aligned characters
+        char_segments_arr = []
+        word_idx = 0
+        for cdx, char in enumerate(text):
             start, end, score = None, None, None
-            if cdx in clean_cdx:
-                char_seg = char_segments[clean_cdx.index(cdx)]
+            if cdx in segment["clean_cdx"]:
+                char_seg = char_segments[segment["clean_cdx"].index(cdx)]
                 start = round(char_seg.start * ratio + t1, 3)
                 end = round(char_seg.end * ratio + t1, 3)
-                score = char_seg.score
+                score = round(char_seg.score, 3)

-            char_segments_arr["char"].append(char)
-            char_segments_arr["start"].append(start)
-            char_segments_arr["end"].append(end)
-            char_segments_arr["score"].append(score)
-            char_segments_arr["word-idx"].append(wdx)
-            char_segments_arr["segment-idx"].append(sdx)
-            char_segments_arr["subsegment-idx"].append(sub_seg_idx)
+            char_segments_arr.append(
+                {
+                    "char": char,
+                    "start": start,
+                    "end": end,
+                    "score": score,
+                    "word-idx": word_idx,
+                }
+            )

-            # word-level info
+            # increment word_idx, nltk word tokenization would probably be more robust here, but us space for now...
             if model_lang in LANGUAGES_WITHOUT_SPACES:
-                # character == word
-                wdx += 1
-            elif is_last or transcription[cdx+1] == " " or cdx == seg_lens_cumsum[sub_seg_idx+1] - 1:
-                wdx += 1
-
-            if is_last or cdx == seg_lens_cumsum[sub_seg_idx+1] - 1:
-                wdx = 0
-                sub_seg_idx += 1
-
-            prev_t2 = segment["end"]
-
-            segment_align_success = True
-            # end while True loop
-            break
-
-        # reset prev_t2 due to drifting issues
-        if not segment_align_success:
-            prev_t2 = 0
+                word_idx += 1
+            elif cdx == len(text) - 1 or text[cdx+1] == " ":
+                word_idx += 1

         char_segments_arr = pd.DataFrame(char_segments_arr)
-        not_space = char_segments_arr["char"] != " "
-
-        per_seg_grp = char_segments_arr.groupby(["segment-idx", "subsegment-idx"], as_index = False)
-        char_segments_arr = per_seg_grp.apply(lambda x: x.reset_index(drop = True)).reset_index()
-        per_word_grp = char_segments_arr[not_space].groupby(["segment-idx", "subsegment-idx", "word-idx"])
-        per_subseg_grp = char_segments_arr[not_space].groupby(["segment-idx", "subsegment-idx"])
-        per_seg_grp = char_segments_arr[not_space].groupby(["segment-idx"])
-        char_segments_arr["local-char-idx"] = char_segments_arr.groupby(["segment-idx", "subsegment-idx"]).cumcount()
-        per_word_grp = char_segments_arr[not_space].groupby(["segment-idx", "subsegment-idx", "word-idx"]) # regroup
-
-        word_segments_arr = {}
-
-        # start of word is first char with a timestamp
-        word_segments_arr["start"] = per_word_grp["start"].min().values
-        # end of word is last char with a timestamp
-        word_segments_arr["end"] = per_word_grp["end"].max().values
-        # score of word is mean (excluding nan)
-        word_segments_arr["score"] = per_word_grp["score"].mean().values
-
-        word_segments_arr["segment-text-start"] = per_word_grp["local-char-idx"].min().astype(int).values
-        word_segments_arr["segment-text-end"] = per_word_grp["local-char-idx"].max().astype(int).values+1
-        word_segments_arr = pd.DataFrame(word_segments_arr)
-
-        word_segments_arr[["segment-idx", "subsegment-idx", "word-idx"]] = per_word_grp["local-char-idx"].min().reset_index()[["segment-idx", "subsegment-idx", "word-idx"]].astype(int)
-        segments_arr = {}
-        segments_arr["start"] = per_subseg_grp["start"].min().reset_index()["start"]
-        segments_arr["end"] = per_subseg_grp["end"].max().reset_index()["end"]
-        segments_arr = pd.DataFrame(segments_arr)
-        segments_arr[["segment-idx", "subsegment-idx-start"]] = per_subseg_grp["start"].min().reset_index()[["segment-idx", "subsegment-idx"]]
-        segments_arr["subsegment-idx-end"] = segments_arr["subsegment-idx-start"] + 1
-
-        # interpolate missing words / sub-segments
-        if interpolate_method != "ignore":
-            wrd_subseg_grp = word_segments_arr.groupby(["segment-idx", "subsegment-idx"], group_keys=False)
-            wrd_seg_grp = word_segments_arr.groupby(["segment-idx"], group_keys=False)
-            # we still know which word timestamps are interpolated because their score == nan
-            word_segments_arr["start"] = wrd_subseg_grp['start'].apply(lambda group: interpolate_nans(group, method=interpolate_method))
-            word_segments_arr["end"] = wrd_subseg_grp['end'].apply(lambda group: interpolate_nans(group, method=interpolate_method))
-
-            word_segments_arr["start"] = wrd_seg_grp['start'].apply(lambda group: interpolate_nans(group, method=interpolate_method))
-            word_segments_arr["end"] = wrd_seg_grp['end'].apply(lambda group: interpolate_nans(group, method=interpolate_method))
-
-            sub_seg_grp = segments_arr.groupby(["segment-idx"], group_keys=False)
-            segments_arr['start'] = sub_seg_grp['start'].apply(lambda group: interpolate_nans(group, method=interpolate_method))
-            segments_arr['end'] = sub_seg_grp['end'].apply(lambda group: interpolate_nans(group, method=interpolate_method))
-
-            # merge words & subsegments which are missing times
-            word_grp = word_segments_arr.groupby(["segment-idx", "subsegment-idx", "end"])
-
-            word_segments_arr["segment-text-start"] = word_grp["segment-text-start"].transform(min)
-            word_segments_arr["segment-text-end"] = word_grp["segment-text-end"].transform(max)
-            word_segments_arr.drop_duplicates(subset=["segment-idx", "subsegment-idx", "end"], inplace=True)
-
-            seg_grp_dup = segments_arr.groupby(["segment-idx", "start", "end"])
-            segments_arr["subsegment-idx-start"] = seg_grp_dup["subsegment-idx-start"].transform(min)
-            segments_arr["subsegment-idx-end"] = seg_grp_dup["subsegment-idx-end"].transform(max)
-            segments_arr.drop_duplicates(subset=["segment-idx", "subsegment-idx-start", "subsegment-idx-end"], inplace=True)
-        else:
-            word_segments_arr.dropna(inplace=True)
-            segments_arr.dropna(inplace=True)
-
-        # if some segments still have missing timestamps (usually because all numerals / symbols), then use original timestamps...
-        segments_arr['start'].fillna(pd.Series([x['start'] for x in transcript]), inplace=True)
-        segments_arr['end'].fillna(pd.Series([x['end'] for x in transcript]), inplace=True)
-        segments_arr['subsegment-idx-start'].fillna(0, inplace=True)
-        segments_arr['subsegment-idx-end'].fillna(1, inplace=True)
-
-        aligned_segments = []
-        aligned_segments_word = []
-
-        word_segments_arr.set_index(["segment-idx", "subsegment-idx"], inplace=True)
-        char_segments_arr.set_index(["segment-idx", "subsegment-idx", "word-idx"], inplace=True)
-
-        for sdx, srow in segments_arr.iterrows():
-
-            seg_idx = int(srow["segment-idx"])
-            sub_start = int(srow["subsegment-idx-start"])
-            sub_end = int(srow["subsegment-idx-end"])
-
-            seg = transcript[seg_idx]
-            text = "".join(seg["seg-text"][sub_start:sub_end])
-
-            wseg = word_segments_arr.loc[seg_idx].loc[sub_start:sub_end-1]
-            wseg["start"].fillna(srow["start"], inplace=True)
-            wseg["end"].fillna(srow["end"], inplace=True)
-            wseg["segment-text-start"].fillna(0, inplace=True)
-            wseg["segment-text-end"].fillna(len(text)-1, inplace=True)
-
-            cseg = char_segments_arr.loc[seg_idx].loc[sub_start:sub_end-1]
-            # fixes bug for single segment in transcript
-            cseg['segment-text-start'] = cseg['level_1'] if 'level_1' in cseg else 0
-            cseg['segment-text-end'] = cseg['level_1'] + 1 if 'level_1' in cseg else 1
-            if 'level_1' in cseg: del cseg['level_1']
-            if 'level_0' in cseg: del cseg['level_0']
-            cseg.reset_index(inplace=True)
-
-            def get_raw_text(word_row):
-                return seg["seg-text"][word_row.name][int(word_row["segment-text-start"]):int(word_row["segment-text-end"])+1]
-
-            word_list = []
-            wdx = 0
-            curr_text = get_raw_text(wseg.iloc[wdx])
-            if not curr_text.startswith(" "):
-                curr_text = " " + curr_text
-
-            if len(wseg) > 1:
-                for _, wrow in wseg.iloc[1:].iterrows():
-                    if wrow['start'] != wseg.iloc[wdx]['start']:
-                        word_start = wseg.iloc[wdx]['start']
-                        word_end = wseg.iloc[wdx]['end']
-
-                        aligned_segments_word.append(
-                            {
-                                "text": curr_text.strip(),
-                                "start": word_start,
-                                "end": word_end
-                            }
-                        )
-
-                        word_list.append(
-                            {
-                                "word": curr_text.rstrip(),
-                                "start": word_start,
-                                "end": word_end,
-                            }
-                        )
-
-                        curr_text = " "
-                    curr_text += get_raw_text(wrow) + " "
-                    wdx += 1
-
-            aligned_segments_word.append(
-                {
-                    "text": curr_text.strip(),
-                    "start": wseg.iloc[wdx]["start"],
-                    "end": wseg.iloc[wdx]["end"]
-                }
-            )
-
-            word_list.append(
-                {
-                    "word": curr_text.rstrip(),
-                    "start": wseg.iloc[wdx]['start'],
-                    "end": wseg.iloc[wdx]['end'],
-                }
-            )
-
-            aligned_segments.append(
-                {
-                    "start": srow["start"],
-                    "end": srow["end"],
-                    "text": text,
-                    "words": word_list,
-                    "word-segments": wseg,
-                    "char-segments": cseg
-                }
-            )
-
-    return {"segments": aligned_segments, "word_segments": aligned_segments_word}
+        aligned_subsegments = []
+        # assign sentence_idx to each character index
+        char_segments_arr["sentence-idx"] = None
+        for sdx, (sstart, send) in enumerate(segment["sentence_spans"]):
+            curr_chars = char_segments_arr.loc[(char_segments_arr.index >= sstart) & (char_segments_arr.index <= send)]
+            char_segments_arr.loc[(char_segments_arr.index >= sstart) & (char_segments_arr.index <= send), "sentence-idx"] = sdx
+
+            sentence_text = text[sstart:send]
+            sentence_start = curr_chars["start"].min()
+            sentence_end = curr_chars["end"].max()
+            sentence_words = []
+
+            for word_idx in curr_chars["word-idx"].unique():
+                word_chars = curr_chars.loc[curr_chars["word-idx"] == word_idx]
+                word_text = "".join(word_chars["char"].tolist()).strip()
+                if len(word_text) == 0:
+                    continue
+                word_start = word_chars["start"].min()
+                word_end = word_chars["end"].max()
+                word_score = round(word_chars["score"].mean(), 3)
+
+                # -1 indicates unalignable
+                word_segment = {"word": word_text}
+
+                if not np.isnan(word_start):
+                    word_segment["start"] = word_start
+                if not np.isnan(word_end):
+                    word_segment["end"] = word_end
+                if not np.isnan(word_score):
+                    word_segment["score"] = word_score
+
+                sentence_words.append(word_segment)
+
+            aligned_subsegments.append({
+                "text": sentence_text,
+                "start": sentence_start,
+                "end": sentence_end,
+                "words": sentence_words,
+            })
+
+            if return_char_alignments:
+                curr_chars = curr_chars[["char", "start", "end", "score"]]
+                curr_chars.fillna(-1, inplace=True)
+                curr_chars = curr_chars.to_dict("records")
+                curr_chars = [{key: val for key, val in char.items() if val != -1} for char in curr_chars]
+
+        aligned_subsegments = pd.DataFrame(aligned_subsegments)
+        aligned_subsegments["start"] = interpolate_nans(aligned_subsegments["start"], method=interpolate_method)
+        aligned_subsegments["end"] = interpolate_nans(aligned_subsegments["end"], method=interpolate_method)
+        # concatenate sentences with same timestamps
+        agg_dict = {"text": " ".join, "words": "sum"}
+        if return_char_alignments:
+            agg_dict["chars"] = "sum"
+        aligned_subsegments= aligned_subsegments.groupby(["start", "end"], as_index=False).agg(agg_dict)
+        aligned_subsegments = aligned_subsegments.to_dict('records')
+        aligned_segments += aligned_subsegments
+
+    # create word_segments list
+    word_segments = []
+    for segment in aligned_segments:
+        word_segments += segment["words"]
+
+    return {"segments": aligned_segments, "word_segments": word_segments}

 """
 source: https://pytorch.org/tutorials/intermediate/forced_alignment_with_torchaudio_tutorial.html
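As a worked example of the timestamp math kept above (illustrative numbers, and assuming a mono waveform so waveform_segment.size(0) == 1): a segment spanning t1 = 3.00 s to t2 = 5.00 s has duration 2.00 s; if its trellis has 101 rows, then ratio = 2.00 * 1 / (101 - 1) = 0.02 s per trellis frame, so a character aligned to frames 10..14 gets start = round(10 * 0.02 + 3.00, 3) = 3.20 s and end = round(14 * 0.02 + 3.00, 3) = 3.28 s.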
whisperx/asr.py (155 changed lines)
@@ -78,7 +78,7 @@ def load_model(whisper_arch, device, compute_type="float16", asr_options=None, l
 class WhisperModel(faster_whisper.WhisperModel):
     '''
     FasterWhisperModel provides batched inference for faster-whisper.
-    Currently only works in non-timestamp mode.
+    Currently only works in non-timestamp mode and fixed prompt for all samples in batch.
     '''

     def generate_segment_batched(self, features: np.ndarray, tokenizer: faster_whisper.tokenizer.Tokenizer, options: faster_whisper.transcribe.TranscriptionOptions, encoder_output = None):
@@ -140,6 +140,13 @@ class WhisperModel(faster_whisper.WhisperModel):
         return self.model.encode(features, to_cpu=to_cpu)

 class FasterWhisperPipeline(Pipeline):
+    """
+    Huggingface Pipeline wrapper for FasterWhisperModel.
+    """
+    # TODO:
+    # - add support for timestamp mode
+    # - add support for custom inference kwargs
+
     def __init__(
         self,
         model,
@@ -261,149 +268,3 @@ class FasterWhisperPipeline(Pipeline):
         language = language_token[2:-2]
         print(f"Detected language: {language} ({language_probability:.2f}) in first 30s of audio...")
         return language
-
-if __name__ == "__main__":
-    main_type = "simple"
-    import time
-
-    import jiwer
-    from tqdm import tqdm
-    from whisper.normalizers import EnglishTextNormalizer
-
-    from benchmark.tedlium import parse_tedlium_annos
-
-    if main_type == "complex":
-        from faster_whisper.tokenizer import Tokenizer
-        from faster_whisper.transcribe import TranscriptionOptions
-        from faster_whisper.vad import (SpeechTimestampsMap,
-                                        get_speech_timestamps)
-
-        from whisperx.vad import load_vad_model, merge_chunks
-
-        from .audio import SAMPLE_RATE, load_audio, log_mel_spectrogram
-        faster_t_options = TranscriptionOptions(
-            beam_size=5,
-            best_of=5,
-            patience=1,
-            length_penalty=1,
-            temperatures=[0.0, 0.2, 0.4, 0.6, 0.8, 1.0],
-            compression_ratio_threshold=2.4,
-            log_prob_threshold=-1.0,
-            no_speech_threshold=0.6,
-            condition_on_previous_text=False,
-            initial_prompt=None,
-            prefix=None,
-            suppress_blank=True,
-            suppress_tokens=[-1],
-            without_timestamps=True,
-            max_initial_timestamp=0.0,
-            word_timestamps=False,
-            prepend_punctuations="\"'“¿([{-",
-            append_punctuations="\"'.。,,!!??::”)]}、"
-        )
-        whisper_arch = "large-v2"
-        device = "cuda"
-        batch_size = 16
-        model = WhisperModel(whisper_arch, device="cuda", compute_type="float16",)
-        tokenizer = Tokenizer(model.hf_tokenizer, model.model.is_multilingual, task="transcribe", language="en")
-        model = FasterWhisperPipeline(model, tokenizer, faster_t_options, device=-1)
-        fn = "DanielKahneman_2010.wav"
-        wav_dir = f"/tmp/test/wav/"
-        vad_model = load_vad_model("cuda", 0.6, 0.3)
-        audio = load_audio(os.path.join(wav_dir, fn))
-        vad_segments = vad_model({"waveform": torch.from_numpy(audio).unsqueeze(0), "sample_rate": SAMPLE_RATE})
-        vad_segments = merge_chunks(vad_segments, 30)
-
-        def data(audio, segments):
-            for seg in segments:
-                f1 = int(seg['start'] * SAMPLE_RATE)
-                f2 = int(seg['end'] * SAMPLE_RATE)
-                # print(f2-f1)
-                yield {'inputs': audio[f1:f2]}
-        vad_method="pyannote"
-
-        wav_dir = f"/tmp/test/wav/"
-        wer_li = []
-        time_li = []
-        for fn in os.listdir(wav_dir):
-            if fn == "RobertGupta_2010U.wav":
-                continue
-            base_fn = fn.split('.')[0]
-            audio_fp = os.path.join(wav_dir, fn)
-
-            audio = load_audio(audio_fp)
-            t1 = time.time()
-            if vad_method == "pyannote":
-                vad_segments = vad_model({"waveform": torch.from_numpy(audio).unsqueeze(0), "sample_rate": SAMPLE_RATE})
-                vad_segments = merge_chunks(vad_segments, 30)
-            elif vad_method == "silero":
-                vad_segments = get_speech_timestamps(audio, threshold=0.5, max_speech_duration_s=30)
-                vad_segments = [{"start": x["start"] / SAMPLE_RATE, "end": x["end"] / SAMPLE_RATE} for x in vad_segments]
-                new_segs = []
-                curr_start = vad_segments[0]['start']
-                curr_end = vad_segments[0]['end']
-                for seg in vad_segments[1:]:
-                    if seg['end'] - curr_start > 30:
-                        new_segs.append({"start": curr_start, "end": curr_end})
-                        curr_start = seg['start']
-                        curr_end = seg['end']
-                    else:
-                        curr_end = seg['end']
-                new_segs.append({"start": curr_start, "end": curr_end})
-                vad_segments = new_segs
-            text = []
-            # for idx, out in tqdm(enumerate(model(data(audio_fp, vad_segments), batch_size=batch_size)), total=len(vad_segments)):
-            for idx, out in enumerate(model(data(audio, vad_segments), batch_size=batch_size)):
-                text.append(out['text'])
-            t2 = time.time()
-            if batch_size == 1:
-                text = [x[0] for x in text]
-            text = " ".join(text)
-
-            normalizer = EnglishTextNormalizer()
-            text = normalizer(text)
-            gt_corpus = normalizer(parse_tedlium_annos(base_fn, "/tmp/test/"))
-
-            wer_result = jiwer.wer(gt_corpus, text)
-            print("WER: %.2f \t time: %.2f \t [%s]" % (wer_result * 100, t2-t1, fn))
-
-            wer_li.append(wer_result)
-            time_li.append(t2-t1)
-        print("# Avg Mean...")
-        print("WER: %.2f" % (sum(wer_li) * 100/len(wer_li)))
-        print("Time: %.2f" % (sum(time_li)/len(time_li)))
-    elif main_type == "simple":
-        model = load_model(
-            "large-v2",
-            device="cuda",
-            language="en",
-        )
-
-        wav_dir = f"/tmp/test/wav/"
-        wer_li = []
-        time_li = []
-        for fn in os.listdir(wav_dir):
-            if fn == "RobertGupta_2010U.wav":
-                continue
-            # fn = "DanielKahneman_2010.wav"
-            base_fn = fn.split('.')[0]
-            audio_fp = os.path.join(wav_dir, fn)
-
-            audio = load_audio(audio_fp)
-            t1 = time.time()
-            out = model.transcribe(audio_fp, batch_size=8)["segments"]
-            t2 = time.time()
-
-            text = " ".join([x['text'] for x in out])
-            normalizer = EnglishTextNormalizer()
-            text = normalizer(text)
-            gt_corpus = normalizer(parse_tedlium_annos(base_fn, "/tmp/test/"))
-
-            wer_result = jiwer.wer(gt_corpus, text)
-            print("WER: %.2f \t time: %.2f \t [%s]" % (wer_result * 100, t2-t1, fn))
-
-            wer_li.append(wer_result)
-            time_li.append(t2-t1)
-        print("# Avg Mean...")
-        print("WER: %.2f" % (sum(wer_li) * 100/len(wer_li)))
-        print("Time: %.2f" % (sum(time_li)/len(time_li)))
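The deleted __main__ block above was the only usage example in this file; a minimal sketch of driving the batched pipeline through the same helpers it used (the audio path is a placeholder, and only the fields the removed code itself read are printed):

from whisperx.asr import load_model

# Load the faster-whisper backend wrapped in FasterWhisperPipeline,
# then run VAD-chunked, batched transcription over one file.
model = load_model("large-v2", device="cuda", language="en")
result = model.transcribe("/path/to/audio.wav", batch_size=8)  # placeholder path
for seg in result["segments"]:
    print(seg["text"])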
whisperx/diarize.py
@@ -11,7 +11,6 @@ class DiarizationPipeline:
         use_auth_token=None,
         device: Optional[Union[str, torch.device]] = "cpu",
     ):
-        self.model = Pipeline.from_pretrained(model_name, use_auth_token=use_auth_token)
         if isinstance(device, str):
             device = torch.device(device)
         self.model = Pipeline.from_pretrained(model_name, use_auth_token=use_auth_token).to(device)
@@ -21,59 +20,44 @@ class DiarizationPipeline:
         diarize_df = pd.DataFrame(segments.itertracks(yield_label=True))
         diarize_df['start'] = diarize_df[0].apply(lambda x: x.start)
         diarize_df['end'] = diarize_df[0].apply(lambda x: x.end)
+        diarize_df.rename(columns={2: "speaker"}, inplace=True)
         return diarize_df

-def assign_word_speakers(diarize_df, result_segments, fill_nearest=False):
-    for seg in result_segments:
-        wdf = seg['word-segments']
-        if len(wdf['start'].dropna()) == 0:
-            wdf['start'] = seg['start']
-            wdf['end'] = seg['end']
-        speakers = []
-        for wdx, wrow in wdf.iterrows():
-            if not np.isnan(wrow['start']):
-                diarize_df['intersection'] = np.minimum(diarize_df['end'], wrow['end']) - np.maximum(diarize_df['start'], wrow['start'])
-                diarize_df['union'] = np.maximum(diarize_df['end'], wrow['end']) - np.minimum(diarize_df['start'], wrow['start'])
+def assign_word_speakers(diarize_df, transcript_result, fill_nearest=False):
+    transcript_segments = transcript_result["segments"]
+    for seg in transcript_segments:
+        # assign speaker to segment (if any)
+        diarize_df['intersection'] = np.minimum(diarize_df['end'], seg['end']) - np.maximum(diarize_df['start'], seg['start'])
+        diarize_df['union'] = np.maximum(diarize_df['end'], seg['end']) - np.minimum(diarize_df['start'], seg['start'])
+        # remove no hit, otherwise we look for closest (even negative intersection...)
+        if not fill_nearest:
+            dia_tmp = diarize_df[diarize_df['intersection'] > 0]
+        else:
+            dia_tmp = diarize_df
+        if len(dia_tmp) > 0:
+            # sum over speakers
+            speaker = dia_tmp.groupby("speaker")["intersection"].sum().sort_values(ascending=False).index[0]
+            seg["speaker"] = speaker
+
+        # assign speaker to words
+        if 'words' in seg:
+            for word in seg['words']:
+                if 'start' in word:
+                    diarize_df['intersection'] = np.minimum(diarize_df['end'], word['end']) - np.maximum(diarize_df['start'], word['start'])
+                    diarize_df['union'] = np.maximum(diarize_df['end'], word['end']) - np.minimum(diarize_df['start'], word['start'])
                     # remove no hit
                     if not fill_nearest:
                         dia_tmp = diarize_df[diarize_df['intersection'] > 0]
                     else:
                         dia_tmp = diarize_df
-                if len(dia_tmp) == 0:
-                    speaker = None
-                else:
-                    speaker = dia_tmp.sort_values("intersection", ascending=False).iloc[0][2]
-            else:
-                speaker = None
-            speakers.append(speaker)
-        seg['word-segments']['speaker'] = speakers
-
-        speaker_count = pd.Series(speakers).value_counts()
-        if len(speaker_count) == 0:
-            seg["speaker"]= "UNKNOWN"
-        else:
-            seg["speaker"] = speaker_count.index[0]
-
-    # create word level segments for .srt
-    word_seg = []
-    for seg in result_segments:
-        wseg = pd.DataFrame(seg["word-segments"])
-        for wdx, wrow in wseg.iterrows():
-            if wrow["start"] is not None:
-                speaker = wrow['speaker']
-                if speaker is None or speaker == np.nan:
-                    speaker = "UNKNOWN"
-                word_seg.append(
-                    {
-                        "start": wrow["start"],
-                        "end": wrow["end"],
-                        "text": f"[{speaker}]: " + seg["text"][int(wrow["segment-text-start"]):int(wrow["segment-text-end"])]
-                    }
-                )
-
-    # TODO: create segments but split words on new speaker
-
-    return result_segments, word_seg
+                    if len(dia_tmp) > 0:
+                        # sum over speakers
+                        speaker = dia_tmp.groupby("speaker")["intersection"].sum().sort_values(ascending=False).index[0]
+                        word["speaker"] = speaker
+
+    return transcript_result

 class Segment:
     def __init__(self, start, end, speaker=None):
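The rewritten assign_word_speakers picks, for each segment and word, the speaker whose diarization turns overlap it the most (summed intersection). A toy sketch of that selection rule on hand-made data (values are illustrative):

import numpy as np
import pandas as pd

# Two diarization turns and one transcript segment spanning 4.0-7.0 s.
diarize_df = pd.DataFrame({
    "start":   [0.0, 5.0],
    "end":     [5.0, 9.0],
    "speaker": ["SPEAKER_00", "SPEAKER_01"],
})
seg = {"start": 4.0, "end": 7.0}

# Overlap between the segment and every diarization turn.
diarize_df["intersection"] = np.minimum(diarize_df["end"], seg["end"]) - np.maximum(diarize_df["start"], seg["start"])
dia_tmp = diarize_df[diarize_df["intersection"] > 0]

# SPEAKER_01 overlaps by 2.0 s vs 1.0 s for SPEAKER_00, so it wins.
speaker = dia_tmp.groupby("speaker")["intersection"].sum().sort_values(ascending=False).index[0]
print(speaker)  # SPEAKER_01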
whisperx/transcribe.py
@@ -64,14 +64,11 @@ def cli():
     parser.add_argument("--max_line_width", type=optional_int, default=None, help="(not possible with --no_align) the maximum number of characters in a line before breaking the line")
     parser.add_argument("--max_line_count", type=optional_int, default=None, help="(requires --no_align) the maximum number of lines in a segment")
     parser.add_argument("--highlight_words", type=str2bool, default=False, help="(requires --word_timestamps True) underline each word as it is spoken in srt and vtt")
+    parser.add_argument("--segment_resolution", type=str, default="sentence", choices=["sentence", "chunk"], help="(not possible with --no_align) the maximum number of characters in a line before breaking the line")

-    # parser.add_argument("--word_timestamps", type=str2bool, default=False, help="(experimental) extract word-level timestamps and refine the results based on them")
-    # parser.add_argument("--prepend_punctuations", type=str, default="\"\'“¿([{-", help="if word_timestamps is True, merge these punctuation symbols with the next word")
-    # parser.add_argument("--append_punctuations", type=str, default="\"\'.。,,!!??::”)]}、", help="if word_timestamps is True, merge these punctuation symbols with the previous word")
     parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")

     parser.add_argument("--hf_token", type=str, default=None, help="Hugging Face Access Token to access PyAnnote gated models")
-    # parser.add_argument("--model_flush", action="store_true", help="Flush memory from each model after use, reduces GPU requirement but slower processing >1 audio file.")
     # fmt: on

     args = parser.parse_args().__dict__
@@ -97,7 +94,6 @@ def cli():
     min_speakers: int = args.pop("min_speakers")
     max_speakers: int = args.pop("max_speakers")

-    # TODO: check model loading works.

     if model_name.endswith(".en") and args["language"] not in {"en", "English"}:
         if args["language"] is not None:
@@ -176,6 +172,7 @@ def cli():
             align_model, align_metadata = load_align_model(result["language"], device)
             print(">>Performing alignment...")
             result = align(result["segments"], align_model, align_metadata, input_audio, device, interpolate_method=interpolate_method)

             results.append((result, audio_path))

         # Unload align model
@@ -193,18 +190,10 @@ def cli():
         diarize_model = DiarizationPipeline(use_auth_token=hf_token, device=device)
         for result, input_audio_path in tmp_results:
             diarize_segments = diarize_model(input_audio_path, min_speakers=min_speakers, max_speakers=max_speakers)
-            results_segments, word_segments = assign_word_speakers(diarize_segments, result["segments"])
-            result = {"segments": results_segments, "word_segments": word_segments}
+            result = assign_word_speakers(diarize_segments, result)
             results.append((result, input_audio_path))

     # >> Write
     for result, audio_path in results:
-        # Remove pandas dataframes from result so that
-        # we can serialize the result with json
-        for seg in result["segments"]:
-            seg.pop("word-segments", None)
-            seg.pop("char-segments", None)
         writer(result, audio_path, writer_args)

 if __name__ == "__main__":
whisperx/utils.py
@@ -231,11 +231,16 @@ class SubtitlesWriter(ResultWriter):
             line_count = 1
             # the next subtitle to yield (a list of word timings with whitespace)
             subtitle: list[dict] = []
-            last = result["segments"][0]["words"][0]["start"]
+            times = []
+            last = result["segments"][0]["start"]
             for segment in result["segments"]:
                 for i, original_timing in enumerate(segment["words"]):
                     timing = original_timing.copy()
-                    long_pause = not preserve_segments and timing["start"] - last > 3.0
+                    long_pause = not preserve_segments
+                    if "start" in timing:
+                        long_pause = long_pause and timing["start"] - last > 3.0
+                    else:
+                        long_pause = False
                     has_room = line_len + len(timing["word"]) <= max_line_width
                     seg_break = i == 0 and len(subtitle) > 0 and preserve_segments
                     if line_len > 0 and has_room and not long_pause and not seg_break:
@@ -251,8 +256,9 @@ class SubtitlesWriter(ResultWriter):
                         or seg_break
                     ):
                         # subtitle break
-                        yield subtitle
+                        yield subtitle, times
                         subtitle = []
+                        times = []
                         line_count = 1
                     elif line_len > 0:
                         # line break
@@ -260,25 +266,36 @@ class SubtitlesWriter(ResultWriter):
                         timing["word"] = "\n" + timing["word"]
                         line_len = len(timing["word"].strip())
                     subtitle.append(timing)
+                    times.append((segment["start"], segment["end"], segment.get("speaker")))
+                    if "start" in timing:
                         last = timing["start"]
             if len(subtitle) > 0:
-                yield subtitle
+                yield subtitle, times

         if "words" in result["segments"][0]:
-            for subtitle in iterate_subtitles():
-                subtitle_start = self.format_timestamp(subtitle[0]["start"])
-                subtitle_end = self.format_timestamp(subtitle[-1]["end"])
-                subtitle_text = "".join([word["word"] for word in subtitle])
-                if highlight_words:
+            for subtitle, _ in iterate_subtitles():
+                sstart, ssend, speaker = _[0]
+                subtitle_start = self.format_timestamp(sstart)
+                subtitle_end = self.format_timestamp(ssend)
+                subtitle_text = " ".join([word["word"] for word in subtitle])
+                has_timing = any(["start" in word for word in subtitle])
+
+                # add [$SPEAKER_ID]: to each subtitle if speaker is available
+                prefix = ""
+                if speaker is not None:
+                    prefix = f"[{speaker}]: "
+
+                if highlight_words and has_timing:
                     last = subtitle_start
                     all_words = [timing["word"] for timing in subtitle]
                     for i, this_word in enumerate(subtitle):
+                        if "start" in this_word:
                             start = self.format_timestamp(this_word["start"])
                             end = self.format_timestamp(this_word["end"])
                             if last != start:
                                 yield last, start, subtitle_text

-                            yield start, end, "".join(
+                            yield start, end, prefix + " ".join(
                                 [
                                     re.sub(r"^(\s*)(.*)$", r"\1<u>\2</u>", word)
                                     if j == i
@@ -288,12 +305,14 @@ class SubtitlesWriter(ResultWriter):
                                 )
                             last = end
                 else:
-                    yield subtitle_start, subtitle_end, subtitle_text
+                    yield subtitle_start, subtitle_end, prefix + subtitle_text
         else:
             for segment in result["segments"]:
                 segment_start = self.format_timestamp(segment["start"])
                 segment_end = self.format_timestamp(segment["end"])
                 segment_text = segment["text"].strip().replace("-->", "->")
+                if "speaker" in segment:
+                    segment_text = f"[{segment['speaker']}]: {segment_text}"
                 yield segment_start, segment_end, segment_text

     def format_timestamp(self, seconds: float):
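With the change above, subtitle cues written for diarized output gain a speaker prefix; an illustrative SRT cue (timings and speaker label are made up):

1
00:00:03,200 --> 00:00:05,940
[SPEAKER_01]: So I grew up in a small town.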