skeleton v2

This commit is contained in:
Max Bain
2023-03-30 05:31:57 +01:00
parent 1e7c2c337b
commit 18b63d46e2
53 changed files with 752 additions and 106949 deletions

View File

@ -4,48 +4,12 @@ from typing import Callable, TextIO, Iterator, Tuple
import pandas as pd
import numpy as np
def exact_div(x, y):
assert x % y == 0
return x // y
def str2bool(string):
str2val = {"True": True, "False": False}
if string in str2val:
return str2val[string]
def interpolate_nans(x, method='nearest'):
if x.notnull().sum() > 1:
return x.interpolate(method=method).ffill().bfill()
else:
raise ValueError(f"Expected one of {set(str2val.keys())}, got {string}")
def optional_int(string):
return None if string == "None" else int(string)
def optional_float(string):
return None if string == "None" else float(string)
def compression_ratio(text) -> float:
text_bytes = text.encode("utf-8")
return len(text_bytes) / len(zlib.compress(text_bytes))
def format_timestamp(seconds: float, always_include_hours: bool = False, decimal_marker: str = '.'):
assert seconds >= 0, "non-negative timestamp expected"
milliseconds = round(seconds * 1000.0)
hours = milliseconds // 3_600_000
milliseconds -= hours * 3_600_000
minutes = milliseconds // 60_000
milliseconds -= minutes * 60_000
seconds = milliseconds // 1_000
milliseconds -= seconds * 1_000
hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
return f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
return x.ffill().bfill()
def write_txt(transcript: Iterator[dict], file: TextIO):
for segment in transcript:
@ -250,8 +214,92 @@ def write_ass(transcript: Iterator[dict],
file.write(ass_str)
def interpolate_nans(x, method='nearest'):
if x.notnull().sum() > 1:
return x.interpolate(method=method).ffill().bfill()
else:
return x.ffill().bfill()
from whisper.utils import SubtitlesWriter, ResultWriter, WriteTXT, WriteVTT, WriteSRT, WriteTSV, WriteJSON, format_timestamp
class WriteASS(ResultWriter):
extension: str = "ass"
def write_result(self, result: dict, file: TextIO):
write_ass(result["segments"], file, resoltuion="word")
class WriteASSchar(ResultWriter):
extension: str = "ass"
def write_result(self, result: dict, file: TextIO):
write_ass(result["segments"], file, resoltuion="char")
class WritePickle(ResultWriter):
extension: str = "ass"
def write_result(self, result: dict, file: TextIO):
pd.DataFrame(result["segments"]).to_pickle(file)
class WriteSRTWord(ResultWriter):
extension: str = ".word.srt"
always_include_hours: bool = True
decimal_marker: str = ","
def iterate_result(self, result: dict):
for segment in result["word_segments"]:
segment_start = self.format_timestamp(segment["start"])
segment_end = self.format_timestamp(segment["end"])
segment_text = segment["text"].strip().replace("-->", "->")
if word_timings := segment.get("words", None):
all_words = [timing["word"] for timing in word_timings]
all_words[0] = all_words[0].strip() # remove the leading space, if any
last = segment_start
for i, this_word in enumerate(word_timings):
start = self.format_timestamp(this_word["start"])
end = self.format_timestamp(this_word["end"])
if last != start:
yield last, start, segment_text
yield start, end, "".join(
[
f"<u>{word}</u>" if j == i else word
for j, word in enumerate(all_words)
]
)
last = end
if last != segment_end:
yield last, segment_end, segment_text
else:
yield segment_start, segment_end, segment_text
def write_result(self, result: dict, file: TextIO):
for i, (start, end, text) in enumerate(self.iterate_result(result), start=1):
print(f"{i}\n{start} --> {end}\n{text}\n", file=file, flush=True)
def format_timestamp(self, seconds: float):
return format_timestamp(
seconds=seconds,
always_include_hours=self.always_include_hours,
decimal_marker=self.decimal_marker,
)
def get_writer(output_format: str, output_dir: str) -> Callable[[dict, TextIO], None]:
writers = {
"txt": WriteTXT,
"vtt": WriteVTT,
"srt": WriteSRT,
"tsv": WriteTSV,
# "json": WriteJSON,
"ass": WriteASS,
# "ass-char": WriteASSchar,
# "pickle": WritePickle,
"srt-word": WriteSRTWord,
}
if output_format == "all":
all_writers = [writer(output_dir) for writer in writers.values()]
def write_all(result: dict, file: TextIO):
for writer in all_writers:
writer(result, file)
return write_all
return writers[output_format](output_dir)