mirror of
https://github.com/m-bain/whisperX.git
synced 2025-07-01 18:17:27 -04:00
fix long segments, break into sentences using nltk, improve align logic, improve diarize (sentence-based)
This commit is contained in:
@ -4,4 +4,5 @@ faster-whisper
|
||||
transformers
|
||||
ffmpeg-python==0.2.0
|
||||
pandas
|
||||
setuptools==65.6.3
|
||||
setuptools==65.6.3
|
||||
nltk
|
@ -13,6 +13,7 @@ from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
|
||||
|
||||
from .audio import SAMPLE_RATE, load_audio
|
||||
from .utils import interpolate_nans
|
||||
import nltk
|
||||
|
||||
LANGUAGES_WITHOUT_SPACES = ["ja", "zh"]
|
||||
|
||||
@ -84,386 +85,226 @@ def align(
|
||||
align_model_metadata: dict,
|
||||
audio: Union[str, np.ndarray, torch.Tensor],
|
||||
device: str,
|
||||
extend_duration: float = 0.0,
|
||||
start_from_previous: bool = True,
|
||||
interpolate_method: str = "nearest",
|
||||
return_char_alignments: bool = False,
|
||||
):
|
||||
"""
|
||||
Force align phoneme recognition predictions to known transcription
|
||||
|
||||
Parameters
|
||||
----------
|
||||
transcript: Iterator[dict]
|
||||
The Whisper model instance
|
||||
|
||||
model: torch.nn.Module
|
||||
Alignment model (wav2vec2)
|
||||
|
||||
audio: Union[str, np.ndarray, torch.Tensor]
|
||||
The path to the audio file to open, or the audio waveform
|
||||
|
||||
device: str
|
||||
cuda device
|
||||
|
||||
diarization: pd.DataFrame {'start': List[float], 'end': List[float], 'speaker': List[float]}
|
||||
diarization segments with speaker labels.
|
||||
|
||||
extend_duration: float
|
||||
Amount to pad input segments by. If not using vad--filter then recommended to use 2 seconds
|
||||
|
||||
If the gzip compression ratio is above this value, treat as failed
|
||||
|
||||
interpolate_method: str ["nearest", "linear", "ignore"]
|
||||
Method to assign timestamps to non-aligned words. Words are not able to be aligned when none of the characters occur in the align model dictionary.
|
||||
"nearest" copies timestamp of nearest word within the segment. "linear" is linear interpolation. "drop" removes that word from output.
|
||||
|
||||
Returns
|
||||
-------
|
||||
A dictionary containing the resulting text ("text") and segment-level details ("segments"), and
|
||||
the spoken language ("language"), which is detected when `decode_options["language"]` is None.
|
||||
Align phoneme recognition predictions to known transcription.
|
||||
"""
|
||||
|
||||
if not torch.is_tensor(audio):
|
||||
if isinstance(audio, str):
|
||||
audio = load_audio(audio)
|
||||
audio = torch.from_numpy(audio)
|
||||
if len(audio.shape) == 1:
|
||||
audio = audio.unsqueeze(0)
|
||||
|
||||
|
||||
MAX_DURATION = audio.shape[1] / SAMPLE_RATE
|
||||
|
||||
model_dictionary = align_model_metadata["dictionary"]
|
||||
model_lang = align_model_metadata["language"]
|
||||
model_type = align_model_metadata["type"]
|
||||
|
||||
aligned_segments = []
|
||||
|
||||
prev_t2 = 0
|
||||
|
||||
char_segments_arr = {
|
||||
"segment-idx": [],
|
||||
"subsegment-idx": [],
|
||||
"word-idx": [],
|
||||
"char": [],
|
||||
"start": [],
|
||||
"end": [],
|
||||
"score": [],
|
||||
}
|
||||
|
||||
# 1. Preprocess to keep only characters in dictionary
|
||||
for sdx, segment in enumerate(transcript):
|
||||
while True:
|
||||
segment_align_success = False
|
||||
# strip spaces at beginning / end, but keep track of the amount.
|
||||
num_leading = len(segment["text"]) - len(segment["text"].lstrip())
|
||||
num_trailing = len(segment["text"]) - len(segment["text"].rstrip())
|
||||
text = segment["text"]
|
||||
|
||||
# strip spaces at beginning / end, but keep track of the amount.
|
||||
num_leading = len(segment["text"]) - len(segment["text"].lstrip())
|
||||
num_trailing = len(segment["text"]) - len(segment["text"].rstrip())
|
||||
transcription = segment["text"]
|
||||
# split into words
|
||||
if model_lang not in LANGUAGES_WITHOUT_SPACES:
|
||||
per_word = text.split(" ")
|
||||
else:
|
||||
per_word = text
|
||||
|
||||
# TODO: convert number tokenizer / symbols to phonetic words for alignment.
|
||||
# e.g. "$300" -> "three hundred dollars"
|
||||
# currently "$300" is ignored since no characters present in the phonetic dictionary
|
||||
|
||||
# split into words
|
||||
clean_char, clean_cdx = [], []
|
||||
for cdx, char in enumerate(text):
|
||||
char_ = char.lower()
|
||||
# wav2vec2 models use "|" character to represent spaces
|
||||
if model_lang not in LANGUAGES_WITHOUT_SPACES:
|
||||
per_word = transcription.split(" ")
|
||||
else:
|
||||
per_word = transcription
|
||||
|
||||
# first check that characters in transcription can be aligned (they are contained in align model"s dictionary)
|
||||
clean_char, clean_cdx = [], []
|
||||
for cdx, char in enumerate(transcription):
|
||||
char_ = char.lower()
|
||||
# wav2vec2 models use "|" character to represent spaces
|
||||
if model_lang not in LANGUAGES_WITHOUT_SPACES:
|
||||
char_ = char_.replace(" ", "|")
|
||||
|
||||
# ignore whitespace at beginning and end of transcript
|
||||
if cdx < num_leading:
|
||||
pass
|
||||
elif cdx > len(transcription) - num_trailing - 1:
|
||||
pass
|
||||
elif char_ in model_dictionary.keys():
|
||||
clean_char.append(char_)
|
||||
clean_cdx.append(cdx)
|
||||
|
||||
clean_wdx = []
|
||||
for wdx, wrd in enumerate(per_word):
|
||||
if any([c in model_dictionary.keys() for c in wrd]):
|
||||
clean_wdx.append(wdx)
|
||||
|
||||
# if no characters are in the dictionary, then we skip this segment...
|
||||
if len(clean_char) == 0:
|
||||
print(f'Failed to align segment ("{segment["text"]}"): no characters in this segment found in model dictionary, resorting to original...')
|
||||
break
|
||||
|
||||
transcription_cleaned = "".join(clean_char)
|
||||
tokens = [model_dictionary[c] for c in transcription_cleaned]
|
||||
|
||||
# we only pad if not using VAD filtering
|
||||
if "seg_text" not in segment:
|
||||
# pad according original timestamps
|
||||
t1 = max(segment["start"] - extend_duration, 0)
|
||||
t2 = min(segment["end"] + extend_duration, MAX_DURATION)
|
||||
|
||||
# use prev_t2 as current t1 if it"s later
|
||||
if start_from_previous and t1 < prev_t2:
|
||||
t1 = prev_t2
|
||||
|
||||
# check if timestamp range is still valid
|
||||
if t1 >= MAX_DURATION:
|
||||
print("Failed to align segment: original start time longer than audio duration, skipping...")
|
||||
break
|
||||
if t2 - t1 < 0.02:
|
||||
print("Failed to align segment: duration smaller than 0.02s time precision")
|
||||
break
|
||||
|
||||
f1 = int(t1 * SAMPLE_RATE)
|
||||
f2 = int(t2 * SAMPLE_RATE)
|
||||
|
||||
waveform_segment = audio[:, f1:f2]
|
||||
|
||||
with torch.inference_mode():
|
||||
if model_type == "torchaudio":
|
||||
emissions, _ = model(waveform_segment.to(device))
|
||||
elif model_type == "huggingface":
|
||||
emissions = model(waveform_segment.to(device)).logits
|
||||
else:
|
||||
raise NotImplementedError(f"Align model of type {model_type} not supported.")
|
||||
emissions = torch.log_softmax(emissions, dim=-1)
|
||||
|
||||
emission = emissions[0].cpu().detach()
|
||||
|
||||
blank_id = 0
|
||||
for char, code in model_dictionary.items():
|
||||
if char == '[pad]' or char == '<pad>':
|
||||
blank_id = code
|
||||
|
||||
trellis = get_trellis(emission, tokens, blank_id)
|
||||
path = backtrack(trellis, emission, tokens, blank_id)
|
||||
if path is None:
|
||||
print(f'Failed to align segment ("{segment["text"]}"): backtrack failed, resorting to original...')
|
||||
break
|
||||
char_segments = merge_repeats(path, transcription_cleaned)
|
||||
# word_segments = merge_words(char_segments)
|
||||
char_ = char_.replace(" ", "|")
|
||||
|
||||
# ignore whitespace at beginning and end of transcript
|
||||
if cdx < num_leading:
|
||||
pass
|
||||
elif cdx > len(text) - num_trailing - 1:
|
||||
pass
|
||||
elif char_ in model_dictionary.keys():
|
||||
clean_char.append(char_)
|
||||
clean_cdx.append(cdx)
|
||||
|
||||
# sub-segments
|
||||
if "seg-text" not in segment:
|
||||
segment["seg-text"] = [transcription]
|
||||
|
||||
seg_lens = [0] + [len(x) for x in segment["seg-text"]]
|
||||
seg_lens_cumsum = list(np.cumsum(seg_lens))
|
||||
sub_seg_idx = 0
|
||||
|
||||
wdx = 0
|
||||
duration = t2 - t1
|
||||
ratio = duration * waveform_segment.size(0) / (trellis.size(0) - 1)
|
||||
for cdx, char in enumerate(transcription + " "):
|
||||
is_last = False
|
||||
if cdx == len(transcription):
|
||||
break
|
||||
elif cdx+1 == len(transcription):
|
||||
is_last = True
|
||||
|
||||
|
||||
start, end, score = None, None, None
|
||||
if cdx in clean_cdx:
|
||||
char_seg = char_segments[clean_cdx.index(cdx)]
|
||||
start = round(char_seg.start * ratio + t1, 3)
|
||||
end = round(char_seg.end * ratio + t1, 3)
|
||||
score = char_seg.score
|
||||
|
||||
char_segments_arr["char"].append(char)
|
||||
char_segments_arr["start"].append(start)
|
||||
char_segments_arr["end"].append(end)
|
||||
char_segments_arr["score"].append(score)
|
||||
char_segments_arr["word-idx"].append(wdx)
|
||||
char_segments_arr["segment-idx"].append(sdx)
|
||||
char_segments_arr["subsegment-idx"].append(sub_seg_idx)
|
||||
|
||||
# word-level info
|
||||
if model_lang in LANGUAGES_WITHOUT_SPACES:
|
||||
# character == word
|
||||
wdx += 1
|
||||
elif is_last or transcription[cdx+1] == " " or cdx == seg_lens_cumsum[sub_seg_idx+1] - 1:
|
||||
wdx += 1
|
||||
|
||||
if is_last or cdx == seg_lens_cumsum[sub_seg_idx+1] - 1:
|
||||
wdx = 0
|
||||
sub_seg_idx += 1
|
||||
|
||||
prev_t2 = segment["end"]
|
||||
|
||||
segment_align_success = True
|
||||
# end while True loop
|
||||
break
|
||||
|
||||
# reset prev_t2 due to drifting issues
|
||||
if not segment_align_success:
|
||||
prev_t2 = 0
|
||||
|
||||
char_segments_arr = pd.DataFrame(char_segments_arr)
|
||||
not_space = char_segments_arr["char"] != " "
|
||||
|
||||
per_seg_grp = char_segments_arr.groupby(["segment-idx", "subsegment-idx"], as_index = False)
|
||||
char_segments_arr = per_seg_grp.apply(lambda x: x.reset_index(drop = True)).reset_index()
|
||||
per_word_grp = char_segments_arr[not_space].groupby(["segment-idx", "subsegment-idx", "word-idx"])
|
||||
per_subseg_grp = char_segments_arr[not_space].groupby(["segment-idx", "subsegment-idx"])
|
||||
per_seg_grp = char_segments_arr[not_space].groupby(["segment-idx"])
|
||||
char_segments_arr["local-char-idx"] = char_segments_arr.groupby(["segment-idx", "subsegment-idx"]).cumcount()
|
||||
per_word_grp = char_segments_arr[not_space].groupby(["segment-idx", "subsegment-idx", "word-idx"]) # regroup
|
||||
|
||||
word_segments_arr = {}
|
||||
|
||||
# start of word is first char with a timestamp
|
||||
word_segments_arr["start"] = per_word_grp["start"].min().values
|
||||
# end of word is last char with a timestamp
|
||||
word_segments_arr["end"] = per_word_grp["end"].max().values
|
||||
# score of word is mean (excluding nan)
|
||||
word_segments_arr["score"] = per_word_grp["score"].mean().values
|
||||
|
||||
word_segments_arr["segment-text-start"] = per_word_grp["local-char-idx"].min().astype(int).values
|
||||
word_segments_arr["segment-text-end"] = per_word_grp["local-char-idx"].max().astype(int).values+1
|
||||
word_segments_arr = pd.DataFrame(word_segments_arr)
|
||||
|
||||
word_segments_arr[["segment-idx", "subsegment-idx", "word-idx"]] = per_word_grp["local-char-idx"].min().reset_index()[["segment-idx", "subsegment-idx", "word-idx"]].astype(int)
|
||||
segments_arr = {}
|
||||
segments_arr["start"] = per_subseg_grp["start"].min().reset_index()["start"]
|
||||
segments_arr["end"] = per_subseg_grp["end"].max().reset_index()["end"]
|
||||
segments_arr = pd.DataFrame(segments_arr)
|
||||
segments_arr[["segment-idx", "subsegment-idx-start"]] = per_subseg_grp["start"].min().reset_index()[["segment-idx", "subsegment-idx"]]
|
||||
segments_arr["subsegment-idx-end"] = segments_arr["subsegment-idx-start"] + 1
|
||||
|
||||
# interpolate missing words / sub-segments
|
||||
if interpolate_method != "ignore":
|
||||
wrd_subseg_grp = word_segments_arr.groupby(["segment-idx", "subsegment-idx"], group_keys=False)
|
||||
wrd_seg_grp = word_segments_arr.groupby(["segment-idx"], group_keys=False)
|
||||
# we still know which word timestamps are interpolated because their score == nan
|
||||
word_segments_arr["start"] = wrd_subseg_grp['start'].apply(lambda group: interpolate_nans(group, method=interpolate_method))
|
||||
word_segments_arr["end"] = wrd_subseg_grp['end'].apply(lambda group: interpolate_nans(group, method=interpolate_method))
|
||||
|
||||
word_segments_arr["start"] = wrd_seg_grp['start'].apply(lambda group: interpolate_nans(group, method=interpolate_method))
|
||||
word_segments_arr["end"] = wrd_seg_grp['end'].apply(lambda group: interpolate_nans(group, method=interpolate_method))
|
||||
|
||||
sub_seg_grp = segments_arr.groupby(["segment-idx"], group_keys=False)
|
||||
segments_arr['start'] = sub_seg_grp['start'].apply(lambda group: interpolate_nans(group, method=interpolate_method))
|
||||
segments_arr['end'] = sub_seg_grp['end'].apply(lambda group: interpolate_nans(group, method=interpolate_method))
|
||||
|
||||
# merge words & subsegments which are missing times
|
||||
word_grp = word_segments_arr.groupby(["segment-idx", "subsegment-idx", "end"])
|
||||
|
||||
word_segments_arr["segment-text-start"] = word_grp["segment-text-start"].transform(min)
|
||||
word_segments_arr["segment-text-end"] = word_grp["segment-text-end"].transform(max)
|
||||
word_segments_arr.drop_duplicates(subset=["segment-idx", "subsegment-idx", "end"], inplace=True)
|
||||
|
||||
seg_grp_dup = segments_arr.groupby(["segment-idx", "start", "end"])
|
||||
segments_arr["subsegment-idx-start"] = seg_grp_dup["subsegment-idx-start"].transform(min)
|
||||
segments_arr["subsegment-idx-end"] = seg_grp_dup["subsegment-idx-end"].transform(max)
|
||||
segments_arr.drop_duplicates(subset=["segment-idx", "subsegment-idx-start", "subsegment-idx-end"], inplace=True)
|
||||
else:
|
||||
word_segments_arr.dropna(inplace=True)
|
||||
segments_arr.dropna(inplace=True)
|
||||
|
||||
# if some segments still have missing timestamps (usually because all numerals / symbols), then use original timestamps...
|
||||
segments_arr['start'].fillna(pd.Series([x['start'] for x in transcript]), inplace=True)
|
||||
segments_arr['end'].fillna(pd.Series([x['end'] for x in transcript]), inplace=True)
|
||||
segments_arr['subsegment-idx-start'].fillna(0, inplace=True)
|
||||
segments_arr['subsegment-idx-end'].fillna(1, inplace=True)
|
||||
clean_wdx = []
|
||||
for wdx, wrd in enumerate(per_word):
|
||||
if any([c in model_dictionary.keys() for c in wrd]):
|
||||
clean_wdx.append(wdx)
|
||||
|
||||
sentence_spans = list(nltk.tokenize.punkt.PunktSentenceTokenizer().span_tokenize(text))
|
||||
|
||||
segment["clean_char"] = clean_char
|
||||
segment["clean_cdx"] = clean_cdx
|
||||
segment["clean_wdx"] = clean_wdx
|
||||
segment["sentence_spans"] = sentence_spans
|
||||
|
||||
aligned_segments = []
|
||||
aligned_segments_word = []
|
||||
|
||||
word_segments_arr.set_index(["segment-idx", "subsegment-idx"], inplace=True)
|
||||
char_segments_arr.set_index(["segment-idx", "subsegment-idx", "word-idx"], inplace=True)
|
||||
# 2. Get prediction matrix from alignment model & align
|
||||
for sdx, segment in enumerate(transcript):
|
||||
t1 = segment["start"]
|
||||
t2 = segment["end"]
|
||||
text = segment["text"]
|
||||
|
||||
for sdx, srow in segments_arr.iterrows():
|
||||
aligned_seg = {
|
||||
"start": t1,
|
||||
"end": t2,
|
||||
"text": text,
|
||||
"words": [],
|
||||
}
|
||||
|
||||
seg_idx = int(srow["segment-idx"])
|
||||
sub_start = int(srow["subsegment-idx-start"])
|
||||
sub_end = int(srow["subsegment-idx-end"])
|
||||
if return_char_alignments:
|
||||
aligned_seg["chars"] = []
|
||||
|
||||
seg = transcript[seg_idx]
|
||||
text = "".join(seg["seg-text"][sub_start:sub_end])
|
||||
# check we can align
|
||||
if len(segment["clean_char"]) == 0:
|
||||
print(f'Failed to align segment ("{segment["text"]}"): no characters in this segment found in model dictionary, resorting to original...')
|
||||
aligned_segments.append(aligned_seg)
|
||||
continue
|
||||
|
||||
wseg = word_segments_arr.loc[seg_idx].loc[sub_start:sub_end-1]
|
||||
wseg["start"].fillna(srow["start"], inplace=True)
|
||||
wseg["end"].fillna(srow["end"], inplace=True)
|
||||
wseg["segment-text-start"].fillna(0, inplace=True)
|
||||
wseg["segment-text-end"].fillna(len(text)-1, inplace=True)
|
||||
if t1 >= MAX_DURATION or t2 - t1 < 0.02:
|
||||
print("Failed to align segment: original start time longer than audio duration, skipping...")
|
||||
aligned_segments.append(aligned_seg)
|
||||
continue
|
||||
|
||||
cseg = char_segments_arr.loc[seg_idx].loc[sub_start:sub_end-1]
|
||||
# fixes bug for single segment in transcript
|
||||
cseg['segment-text-start'] = cseg['level_1'] if 'level_1' in cseg else 0
|
||||
cseg['segment-text-end'] = cseg['level_1'] + 1 if 'level_1' in cseg else 1
|
||||
if 'level_1' in cseg: del cseg['level_1']
|
||||
if 'level_0' in cseg: del cseg['level_0']
|
||||
cseg.reset_index(inplace=True)
|
||||
text_clean = "".join(segment["clean_char"])
|
||||
tokens = [model_dictionary[c] for c in text_clean]
|
||||
|
||||
def get_raw_text(word_row):
|
||||
return seg["seg-text"][word_row.name][int(word_row["segment-text-start"]):int(word_row["segment-text-end"])+1]
|
||||
f1 = int(t1 * SAMPLE_RATE)
|
||||
f2 = int(t2 * SAMPLE_RATE)
|
||||
|
||||
word_list = []
|
||||
wdx = 0
|
||||
curr_text = get_raw_text(wseg.iloc[wdx])
|
||||
if not curr_text.startswith(" "):
|
||||
curr_text = " " + curr_text
|
||||
# TODO: Probably can get some speedup gain with batched inference here
|
||||
waveform_segment = audio[:, f1:f2]
|
||||
|
||||
with torch.inference_mode():
|
||||
if model_type == "torchaudio":
|
||||
emissions, _ = model(waveform_segment.to(device))
|
||||
elif model_type == "huggingface":
|
||||
emissions = model(waveform_segment.to(device)).logits
|
||||
else:
|
||||
raise NotImplementedError(f"Align model of type {model_type} not supported.")
|
||||
emissions = torch.log_softmax(emissions, dim=-1)
|
||||
|
||||
emission = emissions[0].cpu().detach()
|
||||
|
||||
blank_id = 0
|
||||
for char, code in model_dictionary.items():
|
||||
if char == '[pad]' or char == '<pad>':
|
||||
blank_id = code
|
||||
|
||||
trellis = get_trellis(emission, tokens, blank_id)
|
||||
path = backtrack(trellis, emission, tokens, blank_id)
|
||||
|
||||
if path is None:
|
||||
print(f'Failed to align segment ("{segment["text"]}"): backtrack failed, resorting to original...')
|
||||
aligned_segments.append(aligned_seg)
|
||||
continue
|
||||
|
||||
char_segments = merge_repeats(path, text_clean)
|
||||
|
||||
duration = t2 -t1
|
||||
ratio = duration * waveform_segment.size(0) / (trellis.size(0) - 1)
|
||||
|
||||
# assign timestamps to aligned characters
|
||||
char_segments_arr = []
|
||||
word_idx = 0
|
||||
for cdx, char in enumerate(text):
|
||||
start, end, score = None, None, None
|
||||
if cdx in segment["clean_cdx"]:
|
||||
char_seg = char_segments[segment["clean_cdx"].index(cdx)]
|
||||
start = round(char_seg.start * ratio + t1, 3)
|
||||
end = round(char_seg.end * ratio + t1, 3)
|
||||
score = round(char_seg.score, 3)
|
||||
|
||||
char_segments_arr.append(
|
||||
{
|
||||
"char": char,
|
||||
"start": start,
|
||||
"end": end,
|
||||
"score": score,
|
||||
"word-idx": word_idx,
|
||||
}
|
||||
)
|
||||
|
||||
# increment word_idx, nltk word tokenization would probably be more robust here, but us space for now...
|
||||
if model_lang in LANGUAGES_WITHOUT_SPACES:
|
||||
word_idx += 1
|
||||
elif cdx == len(text) - 1 or text[cdx+1] == " ":
|
||||
word_idx += 1
|
||||
|
||||
if len(wseg) > 1:
|
||||
for _, wrow in wseg.iloc[1:].iterrows():
|
||||
if wrow['start'] != wseg.iloc[wdx]['start']:
|
||||
word_start = wseg.iloc[wdx]['start']
|
||||
word_end = wseg.iloc[wdx]['end']
|
||||
char_segments_arr = pd.DataFrame(char_segments_arr)
|
||||
|
||||
aligned_segments_word.append(
|
||||
{
|
||||
"text": curr_text.strip(),
|
||||
"start": word_start,
|
||||
"end": word_end
|
||||
}
|
||||
)
|
||||
aligned_subsegments = []
|
||||
# assign sentence_idx to each character index
|
||||
char_segments_arr["sentence-idx"] = None
|
||||
for sdx, (sstart, send) in enumerate(segment["sentence_spans"]):
|
||||
curr_chars = char_segments_arr.loc[(char_segments_arr.index >= sstart) & (char_segments_arr.index <= send)]
|
||||
char_segments_arr.loc[(char_segments_arr.index >= sstart) & (char_segments_arr.index <= send), "sentence-idx"] = sdx
|
||||
|
||||
sentence_text = text[sstart:send]
|
||||
sentence_start = curr_chars["start"].min()
|
||||
sentence_end = curr_chars["end"].max()
|
||||
sentence_words = []
|
||||
|
||||
word_list.append(
|
||||
{
|
||||
"word": curr_text.rstrip(),
|
||||
"start": word_start,
|
||||
"end": word_end,
|
||||
}
|
||||
)
|
||||
for word_idx in curr_chars["word-idx"].unique():
|
||||
word_chars = curr_chars.loc[curr_chars["word-idx"] == word_idx]
|
||||
word_text = "".join(word_chars["char"].tolist()).strip()
|
||||
if len(word_text) == 0:
|
||||
continue
|
||||
word_start = word_chars["start"].min()
|
||||
word_end = word_chars["end"].max()
|
||||
word_score = round(word_chars["score"].mean(), 3)
|
||||
|
||||
curr_text = " "
|
||||
curr_text += get_raw_text(wrow) + " "
|
||||
wdx += 1
|
||||
# -1 indicates unalignable
|
||||
word_segment = {"word": word_text}
|
||||
|
||||
aligned_segments_word.append(
|
||||
{
|
||||
"text": curr_text.strip(),
|
||||
"start": wseg.iloc[wdx]["start"],
|
||||
"end": wseg.iloc[wdx]["end"]
|
||||
}
|
||||
)
|
||||
if not np.isnan(word_start):
|
||||
word_segment["start"] = word_start
|
||||
if not np.isnan(word_end):
|
||||
word_segment["end"] = word_end
|
||||
if not np.isnan(word_score):
|
||||
word_segment["score"] = word_score
|
||||
|
||||
word_list.append(
|
||||
{
|
||||
"word": curr_text.rstrip(),
|
||||
"start": wseg.iloc[wdx]['start'],
|
||||
"end": wseg.iloc[wdx]['end'],
|
||||
}
|
||||
)
|
||||
sentence_words.append(word_segment)
|
||||
|
||||
aligned_subsegments.append({
|
||||
"text": sentence_text,
|
||||
"start": sentence_start,
|
||||
"end": sentence_end,
|
||||
"words": sentence_words,
|
||||
})
|
||||
|
||||
aligned_segments.append(
|
||||
{
|
||||
"start": srow["start"],
|
||||
"end": srow["end"],
|
||||
"text": text,
|
||||
"words": word_list,
|
||||
"word-segments": wseg,
|
||||
"char-segments": cseg
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
return {"segments": aligned_segments, "word_segments": aligned_segments_word}
|
||||
if return_char_alignments:
|
||||
curr_chars = curr_chars[["char", "start", "end", "score"]]
|
||||
curr_chars.fillna(-1, inplace=True)
|
||||
curr_chars = curr_chars.to_dict("records")
|
||||
curr_chars = [{key: val for key, val in char.items() if val != -1} for char in curr_chars]
|
||||
|
||||
aligned_subsegments = pd.DataFrame(aligned_subsegments)
|
||||
aligned_subsegments["start"] = interpolate_nans(aligned_subsegments["start"], method=interpolate_method)
|
||||
aligned_subsegments["end"] = interpolate_nans(aligned_subsegments["end"], method=interpolate_method)
|
||||
# concatenate sentences with same timestamps
|
||||
agg_dict = {"text": " ".join, "words": "sum"}
|
||||
if return_char_alignments:
|
||||
agg_dict["chars"] = "sum"
|
||||
aligned_subsegments= aligned_subsegments.groupby(["start", "end"], as_index=False).agg(agg_dict)
|
||||
aligned_subsegments = aligned_subsegments.to_dict('records')
|
||||
aligned_segments += aligned_subsegments
|
||||
|
||||
# create word_segments list
|
||||
word_segments = []
|
||||
for segment in aligned_segments:
|
||||
word_segments += segment["words"]
|
||||
|
||||
return {"segments": aligned_segments, "word_segments": word_segments}
|
||||
|
||||
"""
|
||||
source: https://pytorch.org/tutorials/intermediate/forced_alignment_with_torchaudio_tutorial.html
|
||||
|
155
whisperx/asr.py
155
whisperx/asr.py
@ -78,7 +78,7 @@ def load_model(whisper_arch, device, compute_type="float16", asr_options=None, l
|
||||
class WhisperModel(faster_whisper.WhisperModel):
|
||||
'''
|
||||
FasterWhisperModel provides batched inference for faster-whisper.
|
||||
Currently only works in non-timestamp mode.
|
||||
Currently only works in non-timestamp mode and fixed prompt for all samples in batch.
|
||||
'''
|
||||
|
||||
def generate_segment_batched(self, features: np.ndarray, tokenizer: faster_whisper.tokenizer.Tokenizer, options: faster_whisper.transcribe.TranscriptionOptions, encoder_output = None):
|
||||
@ -140,6 +140,13 @@ class WhisperModel(faster_whisper.WhisperModel):
|
||||
return self.model.encode(features, to_cpu=to_cpu)
|
||||
|
||||
class FasterWhisperPipeline(Pipeline):
|
||||
"""
|
||||
Huggingface Pipeline wrapper for FasterWhisperModel.
|
||||
"""
|
||||
# TODO:
|
||||
# - add support for timestamp mode
|
||||
# - add support for custom inference kwargs
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model,
|
||||
@ -261,149 +268,3 @@ class FasterWhisperPipeline(Pipeline):
|
||||
language = language_token[2:-2]
|
||||
print(f"Detected language: {language} ({language_probability:.2f}) in first 30s of audio...")
|
||||
return language
|
||||
|
||||
if __name__ == "__main__":
|
||||
main_type = "simple"
|
||||
import time
|
||||
|
||||
import jiwer
|
||||
from tqdm import tqdm
|
||||
from whisper.normalizers import EnglishTextNormalizer
|
||||
|
||||
from benchmark.tedlium import parse_tedlium_annos
|
||||
|
||||
if main_type == "complex":
|
||||
from faster_whisper.tokenizer import Tokenizer
|
||||
from faster_whisper.transcribe import TranscriptionOptions
|
||||
from faster_whisper.vad import (SpeechTimestampsMap,
|
||||
get_speech_timestamps)
|
||||
|
||||
from whisperx.vad import load_vad_model, merge_chunks
|
||||
|
||||
from .audio import SAMPLE_RATE, load_audio, log_mel_spectrogram
|
||||
faster_t_options = TranscriptionOptions(
|
||||
beam_size=5,
|
||||
best_of=5,
|
||||
patience=1,
|
||||
length_penalty=1,
|
||||
temperatures=[0.0, 0.2, 0.4, 0.6, 0.8, 1.0],
|
||||
compression_ratio_threshold=2.4,
|
||||
log_prob_threshold=-1.0,
|
||||
no_speech_threshold=0.6,
|
||||
condition_on_previous_text=False,
|
||||
initial_prompt=None,
|
||||
prefix=None,
|
||||
suppress_blank=True,
|
||||
suppress_tokens=[-1],
|
||||
without_timestamps=True,
|
||||
max_initial_timestamp=0.0,
|
||||
word_timestamps=False,
|
||||
prepend_punctuations="\"'“¿([{-",
|
||||
append_punctuations="\"'.。,,!!??::”)]}、"
|
||||
)
|
||||
whisper_arch = "large-v2"
|
||||
device = "cuda"
|
||||
batch_size = 16
|
||||
model = WhisperModel(whisper_arch, device="cuda", compute_type="float16",)
|
||||
tokenizer = Tokenizer(model.hf_tokenizer, model.model.is_multilingual, task="transcribe", language="en")
|
||||
model = FasterWhisperPipeline(model, tokenizer, faster_t_options, device=-1)
|
||||
fn = "DanielKahneman_2010.wav"
|
||||
wav_dir = f"/tmp/test/wav/"
|
||||
vad_model = load_vad_model("cuda", 0.6, 0.3)
|
||||
audio = load_audio(os.path.join(wav_dir, fn))
|
||||
vad_segments = vad_model({"waveform": torch.from_numpy(audio).unsqueeze(0), "sample_rate": SAMPLE_RATE})
|
||||
vad_segments = merge_chunks(vad_segments, 30)
|
||||
|
||||
def data(audio, segments):
|
||||
for seg in segments:
|
||||
f1 = int(seg['start'] * SAMPLE_RATE)
|
||||
f2 = int(seg['end'] * SAMPLE_RATE)
|
||||
# print(f2-f1)
|
||||
yield {'inputs': audio[f1:f2]}
|
||||
vad_method="pyannote"
|
||||
|
||||
wav_dir = f"/tmp/test/wav/"
|
||||
wer_li = []
|
||||
time_li = []
|
||||
for fn in os.listdir(wav_dir):
|
||||
if fn == "RobertGupta_2010U.wav":
|
||||
continue
|
||||
base_fn = fn.split('.')[0]
|
||||
audio_fp = os.path.join(wav_dir, fn)
|
||||
|
||||
audio = load_audio(audio_fp)
|
||||
t1 = time.time()
|
||||
if vad_method == "pyannote":
|
||||
vad_segments = vad_model({"waveform": torch.from_numpy(audio).unsqueeze(0), "sample_rate": SAMPLE_RATE})
|
||||
vad_segments = merge_chunks(vad_segments, 30)
|
||||
elif vad_method == "silero":
|
||||
vad_segments = get_speech_timestamps(audio, threshold=0.5, max_speech_duration_s=30)
|
||||
vad_segments = [{"start": x["start"] / SAMPLE_RATE, "end": x["end"] / SAMPLE_RATE} for x in vad_segments]
|
||||
new_segs = []
|
||||
curr_start = vad_segments[0]['start']
|
||||
curr_end = vad_segments[0]['end']
|
||||
for seg in vad_segments[1:]:
|
||||
if seg['end'] - curr_start > 30:
|
||||
new_segs.append({"start": curr_start, "end": curr_end})
|
||||
curr_start = seg['start']
|
||||
curr_end = seg['end']
|
||||
else:
|
||||
curr_end = seg['end']
|
||||
new_segs.append({"start": curr_start, "end": curr_end})
|
||||
vad_segments = new_segs
|
||||
text = []
|
||||
# for idx, out in tqdm(enumerate(model(data(audio_fp, vad_segments), batch_size=batch_size)), total=len(vad_segments)):
|
||||
for idx, out in enumerate(model(data(audio, vad_segments), batch_size=batch_size)):
|
||||
text.append(out['text'])
|
||||
t2 = time.time()
|
||||
if batch_size == 1:
|
||||
text = [x[0] for x in text]
|
||||
text = " ".join(text)
|
||||
|
||||
normalizer = EnglishTextNormalizer()
|
||||
text = normalizer(text)
|
||||
gt_corpus = normalizer(parse_tedlium_annos(base_fn, "/tmp/test/"))
|
||||
|
||||
wer_result = jiwer.wer(gt_corpus, text)
|
||||
print("WER: %.2f \t time: %.2f \t [%s]" % (wer_result * 100, t2-t1, fn))
|
||||
|
||||
wer_li.append(wer_result)
|
||||
time_li.append(t2-t1)
|
||||
print("# Avg Mean...")
|
||||
print("WER: %.2f" % (sum(wer_li) * 100/len(wer_li)))
|
||||
print("Time: %.2f" % (sum(time_li)/len(time_li)))
|
||||
elif main_type == "simple":
|
||||
model = load_model(
|
||||
"large-v2",
|
||||
device="cuda",
|
||||
language="en",
|
||||
)
|
||||
|
||||
wav_dir = f"/tmp/test/wav/"
|
||||
wer_li = []
|
||||
time_li = []
|
||||
for fn in os.listdir(wav_dir):
|
||||
if fn == "RobertGupta_2010U.wav":
|
||||
continue
|
||||
# fn = "DanielKahneman_2010.wav"
|
||||
base_fn = fn.split('.')[0]
|
||||
audio_fp = os.path.join(wav_dir, fn)
|
||||
|
||||
audio = load_audio(audio_fp)
|
||||
t1 = time.time()
|
||||
out = model.transcribe(audio_fp, batch_size=8)["segments"]
|
||||
t2 = time.time()
|
||||
|
||||
text = " ".join([x['text'] for x in out])
|
||||
normalizer = EnglishTextNormalizer()
|
||||
text = normalizer(text)
|
||||
gt_corpus = normalizer(parse_tedlium_annos(base_fn, "/tmp/test/"))
|
||||
|
||||
wer_result = jiwer.wer(gt_corpus, text)
|
||||
print("WER: %.2f \t time: %.2f \t [%s]" % (wer_result * 100, t2-t1, fn))
|
||||
|
||||
wer_li.append(wer_result)
|
||||
time_li.append(t2-t1)
|
||||
print("# Avg Mean...")
|
||||
print("WER: %.2f" % (sum(wer_li) * 100/len(wer_li)))
|
||||
print("Time: %.2f" % (sum(time_li)/len(time_li)))
|
||||
|
@ -11,7 +11,6 @@ class DiarizationPipeline:
|
||||
use_auth_token=None,
|
||||
device: Optional[Union[str, torch.device]] = "cpu",
|
||||
):
|
||||
self.model = Pipeline.from_pretrained(model_name, use_auth_token=use_auth_token)
|
||||
if isinstance(device, str):
|
||||
device = torch.device(device)
|
||||
self.model = Pipeline.from_pretrained(model_name, use_auth_token=use_auth_token).to(device)
|
||||
@ -21,59 +20,44 @@ class DiarizationPipeline:
|
||||
diarize_df = pd.DataFrame(segments.itertracks(yield_label=True))
|
||||
diarize_df['start'] = diarize_df[0].apply(lambda x: x.start)
|
||||
diarize_df['end'] = diarize_df[0].apply(lambda x: x.end)
|
||||
diarize_df.rename(columns={2: "speaker"}, inplace=True)
|
||||
return diarize_df
|
||||
|
||||
def assign_word_speakers(diarize_df, result_segments, fill_nearest=False):
|
||||
for seg in result_segments:
|
||||
wdf = seg['word-segments']
|
||||
if len(wdf['start'].dropna()) == 0:
|
||||
wdf['start'] = seg['start']
|
||||
wdf['end'] = seg['end']
|
||||
speakers = []
|
||||
for wdx, wrow in wdf.iterrows():
|
||||
if not np.isnan(wrow['start']):
|
||||
diarize_df['intersection'] = np.minimum(diarize_df['end'], wrow['end']) - np.maximum(diarize_df['start'], wrow['start'])
|
||||
diarize_df['union'] = np.maximum(diarize_df['end'], wrow['end']) - np.minimum(diarize_df['start'], wrow['start'])
|
||||
# remove no hit
|
||||
if not fill_nearest:
|
||||
dia_tmp = diarize_df[diarize_df['intersection'] > 0]
|
||||
else:
|
||||
dia_tmp = diarize_df
|
||||
if len(dia_tmp) == 0:
|
||||
speaker = None
|
||||
else:
|
||||
speaker = dia_tmp.sort_values("intersection", ascending=False).iloc[0][2]
|
||||
else:
|
||||
speaker = None
|
||||
speakers.append(speaker)
|
||||
seg['word-segments']['speaker'] = speakers
|
||||
|
||||
speaker_count = pd.Series(speakers).value_counts()
|
||||
if len(speaker_count) == 0:
|
||||
seg["speaker"]= "UNKNOWN"
|
||||
def assign_word_speakers(diarize_df, transcript_result, fill_nearest=False):
|
||||
transcript_segments = transcript_result["segments"]
|
||||
for seg in transcript_segments:
|
||||
# assign speaker to segment (if any)
|
||||
diarize_df['intersection'] = np.minimum(diarize_df['end'], seg['end']) - np.maximum(diarize_df['start'], seg['start'])
|
||||
diarize_df['union'] = np.maximum(diarize_df['end'], seg['end']) - np.minimum(diarize_df['start'], seg['start'])
|
||||
# remove no hit, otherwise we look for closest (even negative intersection...)
|
||||
if not fill_nearest:
|
||||
dia_tmp = diarize_df[diarize_df['intersection'] > 0]
|
||||
else:
|
||||
seg["speaker"] = speaker_count.index[0]
|
||||
dia_tmp = diarize_df
|
||||
if len(dia_tmp) > 0:
|
||||
# sum over speakers
|
||||
speaker = dia_tmp.groupby("speaker")["intersection"].sum().sort_values(ascending=False).index[0]
|
||||
seg["speaker"] = speaker
|
||||
|
||||
# assign speaker to words
|
||||
if 'words' in seg:
|
||||
for word in seg['words']:
|
||||
if 'start' in word:
|
||||
diarize_df['intersection'] = np.minimum(diarize_df['end'], word['end']) - np.maximum(diarize_df['start'], word['start'])
|
||||
diarize_df['union'] = np.maximum(diarize_df['end'], word['end']) - np.minimum(diarize_df['start'], word['start'])
|
||||
# remove no hit
|
||||
if not fill_nearest:
|
||||
dia_tmp = diarize_df[diarize_df['intersection'] > 0]
|
||||
else:
|
||||
dia_tmp = diarize_df
|
||||
if len(dia_tmp) > 0:
|
||||
# sum over speakers
|
||||
speaker = dia_tmp.groupby("speaker")["intersection"].sum().sort_values(ascending=False).index[0]
|
||||
word["speaker"] = speaker
|
||||
|
||||
return transcript_result
|
||||
|
||||
# create word level segments for .srt
|
||||
word_seg = []
|
||||
for seg in result_segments:
|
||||
wseg = pd.DataFrame(seg["word-segments"])
|
||||
for wdx, wrow in wseg.iterrows():
|
||||
if wrow["start"] is not None:
|
||||
speaker = wrow['speaker']
|
||||
if speaker is None or speaker == np.nan:
|
||||
speaker = "UNKNOWN"
|
||||
word_seg.append(
|
||||
{
|
||||
"start": wrow["start"],
|
||||
"end": wrow["end"],
|
||||
"text": f"[{speaker}]: " + seg["text"][int(wrow["segment-text-start"]):int(wrow["segment-text-end"])]
|
||||
}
|
||||
)
|
||||
|
||||
# TODO: create segments but split words on new speaker
|
||||
|
||||
return result_segments, word_seg
|
||||
|
||||
class Segment:
|
||||
def __init__(self, start, end, speaker=None):
|
||||
|
@ -64,14 +64,11 @@ def cli():
|
||||
parser.add_argument("--max_line_width", type=optional_int, default=None, help="(not possible with --no_align) the maximum number of characters in a line before breaking the line")
|
||||
parser.add_argument("--max_line_count", type=optional_int, default=None, help="(requires --no_align) the maximum number of lines in a segment")
|
||||
parser.add_argument("--highlight_words", type=str2bool, default=False, help="(requires --word_timestamps True) underline each word as it is spoken in srt and vtt")
|
||||
parser.add_argument("--segment_resolution", type=str, default="sentence", choices=["sentence", "chunk"], help="(not possible with --no_align) the maximum number of characters in a line before breaking the line")
|
||||
|
||||
# parser.add_argument("--word_timestamps", type=str2bool, default=False, help="(experimental) extract word-level timestamps and refine the results based on them")
|
||||
# parser.add_argument("--prepend_punctuations", type=str, default="\"\'“¿([{-", help="if word_timestamps is True, merge these punctuation symbols with the next word")
|
||||
# parser.add_argument("--append_punctuations", type=str, default="\"\'.。,,!!??::”)]}、", help="if word_timestamps is True, merge these punctuation symbols with the previous word")
|
||||
parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
|
||||
|
||||
parser.add_argument("--hf_token", type=str, default=None, help="Hugging Face Access Token to access PyAnnote gated models")
|
||||
# parser.add_argument("--model_flush", action="store_true", help="Flush memory from each model after use, reduces GPU requirement but slower processing >1 audio file.")
|
||||
# fmt: on
|
||||
|
||||
args = parser.parse_args().__dict__
|
||||
@ -97,7 +94,6 @@ def cli():
|
||||
min_speakers: int = args.pop("min_speakers")
|
||||
max_speakers: int = args.pop("max_speakers")
|
||||
|
||||
# TODO: check model loading works.
|
||||
|
||||
if model_name.endswith(".en") and args["language"] not in {"en", "English"}:
|
||||
if args["language"] is not None:
|
||||
@ -176,6 +172,7 @@ def cli():
|
||||
align_model, align_metadata = load_align_model(result["language"], device)
|
||||
print(">>Performing alignment...")
|
||||
result = align(result["segments"], align_model, align_metadata, input_audio, device, interpolate_method=interpolate_method)
|
||||
|
||||
results.append((result, audio_path))
|
||||
|
||||
# Unload align model
|
||||
@ -193,18 +190,10 @@ def cli():
|
||||
diarize_model = DiarizationPipeline(use_auth_token=hf_token, device=device)
|
||||
for result, input_audio_path in tmp_results:
|
||||
diarize_segments = diarize_model(input_audio_path, min_speakers=min_speakers, max_speakers=max_speakers)
|
||||
results_segments, word_segments = assign_word_speakers(diarize_segments, result["segments"])
|
||||
result = {"segments": results_segments, "word_segments": word_segments}
|
||||
result = assign_word_speakers(diarize_segments, result)
|
||||
results.append((result, input_audio_path))
|
||||
|
||||
# >> Write
|
||||
for result, audio_path in results:
|
||||
# Remove pandas dataframes from result so that
|
||||
# we can serialize the result with json
|
||||
for seg in result["segments"]:
|
||||
seg.pop("word-segments", None)
|
||||
seg.pop("char-segments", None)
|
||||
|
||||
writer(result, audio_path, writer_args)
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -231,11 +231,16 @@ class SubtitlesWriter(ResultWriter):
|
||||
line_count = 1
|
||||
# the next subtitle to yield (a list of word timings with whitespace)
|
||||
subtitle: list[dict] = []
|
||||
last = result["segments"][0]["words"][0]["start"]
|
||||
times = []
|
||||
last = result["segments"][0]["start"]
|
||||
for segment in result["segments"]:
|
||||
for i, original_timing in enumerate(segment["words"]):
|
||||
timing = original_timing.copy()
|
||||
long_pause = not preserve_segments and timing["start"] - last > 3.0
|
||||
long_pause = not preserve_segments
|
||||
if "start" in timing:
|
||||
long_pause = long_pause and timing["start"] - last > 3.0
|
||||
else:
|
||||
long_pause = False
|
||||
has_room = line_len + len(timing["word"]) <= max_line_width
|
||||
seg_break = i == 0 and len(subtitle) > 0 and preserve_segments
|
||||
if line_len > 0 and has_room and not long_pause and not seg_break:
|
||||
@ -251,8 +256,9 @@ class SubtitlesWriter(ResultWriter):
|
||||
or seg_break
|
||||
):
|
||||
# subtitle break
|
||||
yield subtitle
|
||||
yield subtitle, times
|
||||
subtitle = []
|
||||
times = []
|
||||
line_count = 1
|
||||
elif line_len > 0:
|
||||
# line break
|
||||
@ -260,40 +266,53 @@ class SubtitlesWriter(ResultWriter):
|
||||
timing["word"] = "\n" + timing["word"]
|
||||
line_len = len(timing["word"].strip())
|
||||
subtitle.append(timing)
|
||||
last = timing["start"]
|
||||
times.append((segment["start"], segment["end"], segment.get("speaker")))
|
||||
if "start" in timing:
|
||||
last = timing["start"]
|
||||
if len(subtitle) > 0:
|
||||
yield subtitle
|
||||
yield subtitle, times
|
||||
|
||||
if "words" in result["segments"][0]:
|
||||
for subtitle in iterate_subtitles():
|
||||
subtitle_start = self.format_timestamp(subtitle[0]["start"])
|
||||
subtitle_end = self.format_timestamp(subtitle[-1]["end"])
|
||||
subtitle_text = "".join([word["word"] for word in subtitle])
|
||||
if highlight_words:
|
||||
for subtitle, _ in iterate_subtitles():
|
||||
sstart, ssend, speaker = _[0]
|
||||
subtitle_start = self.format_timestamp(sstart)
|
||||
subtitle_end = self.format_timestamp(ssend)
|
||||
subtitle_text = " ".join([word["word"] for word in subtitle])
|
||||
has_timing = any(["start" in word for word in subtitle])
|
||||
|
||||
# add [$SPEAKER_ID]: to each subtitle if speaker is available
|
||||
prefix = ""
|
||||
if speaker is not None:
|
||||
prefix = f"[{speaker}]: "
|
||||
|
||||
if highlight_words and has_timing:
|
||||
last = subtitle_start
|
||||
all_words = [timing["word"] for timing in subtitle]
|
||||
for i, this_word in enumerate(subtitle):
|
||||
start = self.format_timestamp(this_word["start"])
|
||||
end = self.format_timestamp(this_word["end"])
|
||||
if last != start:
|
||||
yield last, start, subtitle_text
|
||||
if "start" in this_word:
|
||||
start = self.format_timestamp(this_word["start"])
|
||||
end = self.format_timestamp(this_word["end"])
|
||||
if last != start:
|
||||
yield last, start, subtitle_text
|
||||
|
||||
yield start, end, "".join(
|
||||
[
|
||||
re.sub(r"^(\s*)(.*)$", r"\1<u>\2</u>", word)
|
||||
if j == i
|
||||
else word
|
||||
for j, word in enumerate(all_words)
|
||||
]
|
||||
)
|
||||
last = end
|
||||
yield start, end, prefix + " ".join(
|
||||
[
|
||||
re.sub(r"^(\s*)(.*)$", r"\1<u>\2</u>", word)
|
||||
if j == i
|
||||
else word
|
||||
for j, word in enumerate(all_words)
|
||||
]
|
||||
)
|
||||
last = end
|
||||
else:
|
||||
yield subtitle_start, subtitle_end, subtitle_text
|
||||
yield subtitle_start, subtitle_end, prefix + subtitle_text
|
||||
else:
|
||||
for segment in result["segments"]:
|
||||
segment_start = self.format_timestamp(segment["start"])
|
||||
segment_end = self.format_timestamp(segment["end"])
|
||||
segment_text = segment["text"].strip().replace("-->", "->")
|
||||
if "speaker" in segment:
|
||||
segment_text = f"[{segment['speaker']}]: {segment_text}"
|
||||
yield segment_start, segment_end, segment_text
|
||||
|
||||
def format_timestamp(self, seconds: float):
|
||||
|
Reference in New Issue
Block a user