new logic, diarization, vad filtering

This commit is contained in:
Max Bain
2023-01-24 15:02:08 +00:00
parent ba102feb7f
commit d395c21b83
8 changed files with 498 additions and 260 deletions

View File

@ -1,6 +1,7 @@
import os
import zlib
from typing import Iterator, TextIO, Tuple, List
from typing import Callable, TextIO, Iterator, Tuple
import pandas as pd
def exact_div(x, y):
assert x % y == 0
@ -60,6 +61,13 @@ def write_vtt(transcript: Iterator[dict], file: TextIO):
flush=True,
)
def write_tsv(transcript: Iterator[dict], file: TextIO):
    """
    Write a transcript to ``file`` as tab-separated values.

    A header row ``start<TAB>end<TAB>text`` is written first, followed by one
    row per segment. Times are converted from seconds to integer
    milliseconds with ``round``.

    Parameters
    ----------
    transcript: Iterator[dict]
        segments, each providing ``'start'`` and ``'end'`` (numbers, in
        seconds) and ``'text'`` (str)
    file: TextIO
        file object to write to
    """
    # Tabs delimit columns and line breaks delimit rows, so any of them left
    # inside the text would corrupt the table; each one becomes a space.
    separators_to_space = str.maketrans({"\t": " ", "\r": " ", "\n": " "})

    print("start", "end", "text", sep="\t", file=file)
    for segment in transcript:
        text = segment['text'].strip().translate(separators_to_space)
        # Flush per row so partial output is visible while segments stream in.
        print(
            round(1000 * segment['start']),
            round(1000 * segment['end']),
            text,
            sep="\t",
            file=file,
            flush=True,
        )
def write_srt(transcript: Iterator[dict], file: TextIO):
"""
@ -88,7 +96,9 @@ def write_srt(transcript: Iterator[dict], file: TextIO):
)
def write_ass(transcript: Iterator[dict], file: TextIO,
def write_ass(transcript: Iterator[dict],
file: TextIO,
resolution: str = "word",
color: str = None, underline=True,
prefmt: str = None, suffmt: str = None,
font: str = None, font_size: int = 24,
@ -102,10 +112,12 @@ def write_ass(transcript: Iterator[dict], file: TextIO,
Note: ass file is used in the same way as srt, vtt, etc.
Parameters
----------
res: dict
transcript: dict
results from modified model
ass_path: str
output path (e.g. caption.ass)
file: TextIO
file object to write to
resolution: str
"word" or "char", timestamp resolution to highlight.
color: str
color code for a word at its corresponding timestamp
<bbggrr> reverse order hexadecimal RGB value (e.g. FF0000 is full intensity blue. Default: 00FF00)
@ -176,49 +188,67 @@ def write_ass(transcript: Iterator[dict], file: TextIO,
return f'{hh:0>1.0f}:{mm:0>2.0f}:{ss:0>2.2f}'
def dialogue(words: List[str], idx, start, end) -> str:
text = ''.join(f' {prefmt}{word}{suffmt}'
# if not word.startswith(' ') or word == ' ' else
# f' {prefmt}{word.strip()}{suffmt}')
if curr_idx == idx else
f' {word}'
for curr_idx, word in enumerate(words))
def dialogue(chars: str, start: float, end: float, idx_0: int, idx_1: int) -> str:
if idx_0 == -1:
text = chars
else:
text = f'{chars[:idx_0]}{prefmt}{chars[idx_0:idx_1]}{suffmt}{chars[idx_1:]}'
return f"Dialogue: 0,{secs_to_hhmmss(start)},{secs_to_hhmmss(end)}," \
f"Default,,0,0,0,,{text.strip() if strip else text}"
if resolution == "word":
resolution_key = "word-segments"
elif resolution == "char":
resolution_key = "char-segments"
else:
raise ValueError(".ass resolution should be 'word' or 'char', not ", resolution)
ass_arr = []
for segment in transcript:
curr_words = [wrd['text'] for wrd in segment['word-level']]
prev = segment['word-level'][0]['start']
if prev is None:
if resolution_key in segment:
res_segs = pd.DataFrame(segment[resolution_key])
prev = segment['start']
for wdx, word in enumerate(segment['word-level']):
if word['start'] is not None:
# fill gap between previous word
if word['start'] > prev:
filler_ts = {
"words": curr_words,
"start": prev,
"end": word['start'],
"idx": -1
if "speaker" in segment:
speaker_str = f"[{segment['speaker']}]: "
else:
speaker_str = ""
for cdx, crow in res_segs.iterrows():
if crow['start'] is not None:
if resolution == "char":
idx_0 = cdx
idx_1 = cdx + 1
elif resolution == "word":
idx_0 = int(crow["segment-text-start"])
idx_1 = int(crow["segment-text-end"])
# fill gap
if crow['start'] > prev:
filler_ts = {
"chars": speaker_str + segment['text'],
"start": prev,
"end": crow['start'],
"idx_0": -1,
"idx_1": -1
}
ass_arr.append(filler_ts)
# highlight current word
f_word_ts = {
"chars": speaker_str + segment['text'],
"start": crow['start'],
"end": crow['end'],
"idx_0": idx_0 + len(speaker_str),
"idx_1": idx_1 + len(speaker_str)
}
ass_arr.append(filler_ts)
# highlight current word
f_word_ts = {
"words": curr_words,
"start": word['start'],
"end": word['end'],
"idx": wdx
}
ass_arr.append(f_word_ts)
prev = word['end']
ass_arr.append(f_word_ts)
prev = crow['end']
ass_str += '\n'.join(map(lambda x: dialogue(**x), ass_arr))
file.write(ass_str)
def interpolate_nans(x, method='nearest'):
    """
    Fill the missing values of ``x``.

    When ``x`` holds more than one non-null value it is first interpolated
    with ``method``; otherwise interpolation is skipped. In both cases any
    gaps that remain are forward-filled and then backward-filled.

    Parameters
    ----------
    x:
        object exposing ``notnull``, ``interpolate``, ``ffill`` and ``bfill``
        (presumably a pandas Series -- confirm against callers)
    method: str
        interpolation method forwarded to ``x.interpolate``

    Returns
    -------
    the filled result of the ``ffill().bfill()`` chain
    """
    has_multiple_known = x.notnull().sum() > 1
    candidate = x.interpolate(method=method) if has_multiple_known else x
    return candidate.ffill().bfill()