mirror of
https://github.com/m-bain/whisperX.git
synced 2025-07-01 18:17:27 -04:00
new logic, diarization, vad filtering
This commit is contained in:
@ -1,6 +1,7 @@
|
||||
import os
|
||||
import zlib
|
||||
from typing import Iterator, TextIO, Tuple, List
|
||||
|
||||
from typing import Callable, TextIO, Iterator, Tuple
|
||||
import pandas as pd
|
||||
|
||||
def exact_div(x, y):
|
||||
assert x % y == 0
|
||||
@ -60,6 +61,13 @@ def write_vtt(transcript: Iterator[dict], file: TextIO):
|
||||
flush=True,
|
||||
)
|
||||
|
||||
def write_tsv(transcript: Iterator[dict], file: TextIO):
    """
    Write a transcript to `file` as tab-separated values.

    A header row ``start<TAB>end<TAB>text`` is written first, followed by one
    row per segment.

    Parameters
    ----------
    transcript: Iterator[dict]
        segments; each must provide 'start', 'end' and 'text'.
        'start' and 'end' are multiplied by 1000 and rounded to an integer
        (i.e. milliseconds, if the inputs are in seconds).
    file: TextIO
        file object to write to

    Raises
    ------
    KeyError
        if a segment lacks 'start', 'end' or 'text'. The row is built in full
        before anything is written, so no partial row reaches the file.
    """
    # A tab inside the text would add a column and a line break would add a
    # row, so both are flattened to a single space to keep one row per segment.
    unsafe_chars = str.maketrans({"\t": " ", "\n": " ", "\r": " "})

    print("start", "end", "text", sep="\t", file=file)
    for segment in transcript:
        print(
            round(1000 * segment['start']),
            round(1000 * segment['end']),
            segment['text'].strip().translate(unsafe_chars),
            sep="\t",
            file=file,
            flush=True,  # still flushed once per row, as before
        )
|
||||
|
||||
|
||||
def write_srt(transcript: Iterator[dict], file: TextIO):
|
||||
"""
|
||||
@ -88,7 +96,9 @@ def write_srt(transcript: Iterator[dict], file: TextIO):
|
||||
)
|
||||
|
||||
|
||||
def write_ass(transcript: Iterator[dict], file: TextIO,
|
||||
def write_ass(transcript: Iterator[dict],
|
||||
file: TextIO,
|
||||
resolution: str = "word",
|
||||
color: str = None, underline=True,
|
||||
prefmt: str = None, suffmt: str = None,
|
||||
font: str = None, font_size: int = 24,
|
||||
@ -102,10 +112,12 @@ def write_ass(transcript: Iterator[dict], file: TextIO,
|
||||
Note: ass file is used in the same way as srt, vtt, etc.
|
||||
Parameters
|
||||
----------
|
||||
res: dict
|
||||
transcript: dict
|
||||
results from modified model
|
||||
ass_path: str
|
||||
output path (e.g. caption.ass)
|
||||
file: TextIO
|
||||
file object to write to
|
||||
resolution: str
|
||||
"word" or "char", timestamp resolution to highlight.
|
||||
color: str
|
||||
color code for a word at its corresponding timestamp
|
||||
<bbggrr> reverse order hexadecimal RGB value (e.g. FF0000 is full intensity blue. Default: 00FF00)
|
||||
@ -176,49 +188,67 @@ def write_ass(transcript: Iterator[dict], file: TextIO,
|
||||
return f'{hh:0>1.0f}:{mm:0>2.0f}:{ss:0>2.2f}'
|
||||
|
||||
|
||||
def dialogue(words: List[str], idx, start, end) -> str:
|
||||
text = ''.join(f' {prefmt}{word}{suffmt}'
|
||||
# if not word.startswith(' ') or word == ' ' else
|
||||
# f' {prefmt}{word.strip()}{suffmt}')
|
||||
if curr_idx == idx else
|
||||
f' {word}'
|
||||
for curr_idx, word in enumerate(words))
|
||||
def dialogue(chars: str, start: float, end: float, idx_0: int, idx_1: int) -> str:
|
||||
if idx_0 == -1:
|
||||
text = chars
|
||||
else:
|
||||
text = f'{chars[:idx_0]}{prefmt}{chars[idx_0:idx_1]}{suffmt}{chars[idx_1:]}'
|
||||
return f"Dialogue: 0,{secs_to_hhmmss(start)},{secs_to_hhmmss(end)}," \
|
||||
f"Default,,0,0,0,,{text.strip() if strip else text}"
|
||||
|
||||
|
||||
if resolution == "word":
|
||||
resolution_key = "word-segments"
|
||||
elif resolution == "char":
|
||||
resolution_key = "char-segments"
|
||||
else:
|
||||
raise ValueError(".ass resolution should be 'word' or 'char', not ", resolution)
|
||||
|
||||
ass_arr = []
|
||||
|
||||
for segment in transcript:
|
||||
curr_words = [wrd['text'] for wrd in segment['word-level']]
|
||||
prev = segment['word-level'][0]['start']
|
||||
if prev is None:
|
||||
if resolution_key in segment:
|
||||
res_segs = pd.DataFrame(segment[resolution_key])
|
||||
prev = segment['start']
|
||||
for wdx, word in enumerate(segment['word-level']):
|
||||
if word['start'] is not None:
|
||||
# fill gap between previous word
|
||||
if word['start'] > prev:
|
||||
filler_ts = {
|
||||
"words": curr_words,
|
||||
"start": prev,
|
||||
"end": word['start'],
|
||||
"idx": -1
|
||||
if "speaker" in segment:
|
||||
speaker_str = f"[{segment['speaker']}]: "
|
||||
else:
|
||||
speaker_str = ""
|
||||
for cdx, crow in res_segs.iterrows():
|
||||
if crow['start'] is not None:
|
||||
if resolution == "char":
|
||||
idx_0 = cdx
|
||||
idx_1 = cdx + 1
|
||||
elif resolution == "word":
|
||||
idx_0 = int(crow["segment-text-start"])
|
||||
idx_1 = int(crow["segment-text-end"])
|
||||
# fill gap
|
||||
if crow['start'] > prev:
|
||||
filler_ts = {
|
||||
"chars": speaker_str + segment['text'],
|
||||
"start": prev,
|
||||
"end": crow['start'],
|
||||
"idx_0": -1,
|
||||
"idx_1": -1
|
||||
}
|
||||
|
||||
ass_arr.append(filler_ts)
|
||||
# highlight current word
|
||||
f_word_ts = {
|
||||
"chars": speaker_str + segment['text'],
|
||||
"start": crow['start'],
|
||||
"end": crow['end'],
|
||||
"idx_0": idx_0 + len(speaker_str),
|
||||
"idx_1": idx_1 + len(speaker_str)
|
||||
}
|
||||
ass_arr.append(filler_ts)
|
||||
|
||||
# highlight current word
|
||||
f_word_ts = {
|
||||
"words": curr_words,
|
||||
"start": word['start'],
|
||||
"end": word['end'],
|
||||
"idx": wdx
|
||||
}
|
||||
ass_arr.append(f_word_ts)
|
||||
|
||||
prev = word['end']
|
||||
|
||||
|
||||
ass_arr.append(f_word_ts)
|
||||
prev = crow['end']
|
||||
|
||||
ass_str += '\n'.join(map(lambda x: dialogue(**x), ass_arr))
|
||||
|
||||
file.write(ass_str)
|
||||
|
||||
def interpolate_nans(x, method='nearest'):
    """
    Fill the null entries of `x`.

    When `x` holds more than one non-null value, the gaps are first
    interpolated with `method`; otherwise interpolation is skipped. In both
    cases any nulls that remain are filled forward and then backward.

    Parameters
    ----------
    x:
        object exposing `notnull`, `interpolate`, `ffill` and `bfill`
        (presumably a pandas Series -- confirm against callers)
    method: str
        forwarded unchanged to `x.interpolate(method=...)`

    Returns
    -------
    the result of the final `.ffill().bfill()` chain
    """
    # Interpolation is only attempted with at least two known points.
    enough_known_points = x.notnull().sum() > 1
    candidate = x.interpolate(method=method) if enough_known_points else x
    return candidate.ffill().bfill()
|
Reference in New Issue
Block a user