mirror of
https://github.com/m-bain/whisperX.git
synced 2025-07-01 18:17:27 -04:00
update setup.py to install pyannote.audio==3.1.1, update diarize.py to include num_speakers; to fix Issue #592
This commit is contained in:
227
build/lib/whisperx/SubtitlesProcessor.py
Normal file
227
build/lib/whisperx/SubtitlesProcessor.py
Normal file
@ -0,0 +1,227 @@
|
||||
import math
|
||||
from conjunctions import get_conjunctions, get_comma
|
||||
from typing import TextIO
|
||||
|
||||
def normal_round(n):
    """Round *n* to the nearest integer, sending exact halves upward.

    Unlike the built-in ``round`` (which rounds halves to the nearest even
    integer), a fractional part of exactly 0.5 always goes to the ceiling,
    e.g. ``2.5 -> 3`` and ``-1.5 -> -1``.
    """
    lower = math.floor(n)
    # Fractional distance above the floor decides the direction.
    return lower if n - lower < 0.5 else math.ceil(n)
|
||||
|
||||
|
||||
def format_timestamp(seconds: float, is_vtt: bool = False):
    """Render a non-negative duration in seconds as a subtitle timestamp.

    The result has the form ``HH:MM:SS,mmm``; when *is_vtt* is true the
    fractional separator is ``.`` instead of ``,``. The value is rounded to
    the nearest millisecond first. Hours are zero-padded to at least two
    digits and are not wrapped, so durations of 100 hours or more simply
    produce a wider hours field.

    Raises:
        AssertionError: if *seconds* is negative.
    """
    assert seconds >= 0, "non-negative timestamp expected"

    total_ms = round(seconds * 1000.0)
    # Peel off hours, minutes and seconds; what is left is milliseconds.
    hh, remainder = divmod(total_ms, 3_600_000)
    mm, remainder = divmod(remainder, 60_000)
    ss, ms = divmod(remainder, 1_000)

    fraction_separator = '.' if is_vtt else ','
    return f"{hh:02d}:{mm:02d}:{ss:02d}{fraction_separator}{ms:03d}"
|
||||
|
||||
|
||||
|
||||
class SubtitlesProcessor:
    """Turn transcription segments into subtitle cues and write them to disk.

    Each segment is accessed as a mapping with ``'start'``, ``'end'`` and
    ``'text'`` keys and, optionally, a ``'words'`` list. Entries of
    ``'words'`` are dicts carrying a ``'word'`` string and, when known,
    ``'start'`` / ``'end'`` times. When ``'words'`` is absent the segment text
    is split on whitespace and timings are spread evenly over the words.

    Note that word dicts are mutated in place: missing timestamps are filled
    in by :meth:`estimate_timestamp_for_word`.
    """

    # Languages listed by the original author as "complex script"; they get
    # shorter line limits regardless of the constructor arguments.
    _COMPLEX_SCRIPT_LANGUAGES = frozenset([
        'th', 'lo', 'my', 'km', 'am', 'ko', 'ja', 'zh', 'ti', 'ta',
        'te', 'kn', 'ml', 'hi', 'ne', 'mr', 'ar', 'fa', 'ur', 'ka',
    ])
    # Languages whose words are joined (and counted) without a separating space.
    _UNSPACED_LANGUAGES = frozenset(['zh', 'ja'])
    # Guessed duration, in seconds, per character of a word with no timing.
    _SECONDS_PER_CHAR = 0.25
    # A cue is stretched to the following start time when the gap between
    # them is at most this many seconds.
    _MAX_GAP_TO_CLOSE = 0.8

    def __init__(self, segments, lang, max_line_length=45, min_char_length_splitter=30, is_vtt=False):
        """Store the segments and the splitting configuration.

        Args:
            segments: sequence of segment mappings (see the class docstring).
            lang: language code; selects comma/conjunction data and spacing.
            max_line_length: character count at which a line is force-split.
            min_char_length_splitter: minimum characters required on the
                relevant side(s) of a split for it to be taken.
            is_vtt: write WebVTT (``.`` separator, ``WEBVTT`` header) rather
                than SRT.

        For languages in ``_COMPLEX_SCRIPT_LANGUAGES`` the two length limits
        are overridden with 30 and 20, even if the caller passed other values.
        """
        # NOTE(review): get_comma / get_conjunctions come from the project's
        # `conjunctions` module; presumably they return the language's comma
        # string and an iterable of conjunction words - confirm there.
        self.comma = get_comma(lang)
        self.conjunctions = set(get_conjunctions(lang))
        self.segments = segments
        self.lang = lang
        self.max_line_length = max_line_length
        self.min_char_length_splitter = min_char_length_splitter
        self.is_vtt = is_vtt
        if self.lang in self._COMPLEX_SCRIPT_LANGUAGES:
            self.max_line_length = 30
            self.min_char_length_splitter = 20

    @staticmethod
    def _segment_words(segment):
        """Return the segment's word list, falling back to splitting its text."""
        words = segment.get('words')
        if words is None:
            # Only touch 'text' when it is really needed, so a segment that
            # has 'words' but no 'text' does not raise KeyError.
            words = segment['text'].split()
        return words

    @staticmethod
    def _word_text(word):
        """Return the text of a word given either a word dict or a plain string."""
        return word['word'] if isinstance(word, dict) else word

    def _close_gap(self, end_time, following_start):
        """Stretch *end_time* to *following_start* when the gap is small enough.

        A falsy *following_start* (``None`` or ``0``) is treated as unknown
        and leaves *end_time* untouched.
        """
        if following_start and (following_start - end_time) <= self._MAX_GAP_TO_CLOSE:
            return following_start
        return end_time

    def estimate_timestamp_for_word(self, words, i, next_segment_start_time=None):
        """Fill in ``'start'`` and ``'end'`` of ``words[i]`` in place.

        Both keys are (re)written. The estimate uses, in order of preference,
        the end of the previous word, the start of the next word, the start
        of the next segment, and finally a per-character duration guess. With
        no neighbour and no next segment, both times become ``0``.

        Args:
            words: list of word dicts belonging to one segment.
            i: index of the word to estimate.
            next_segment_start_time: start of the following segment; a falsy
                value is treated as unknown.
        """
        word = words[i]
        prev_has_end = i > 0 and 'end' in words[i - 1]
        next_has_start = i < len(words) - 1 and 'start' in words[i + 1]

        if prev_has_end:
            prev_end = words[i - 1]['end']
            word['start'] = prev_end
            if next_has_start:
                word['end'] = words[i + 1]['start']
            elif next_segment_start_time:
                # Run right up to the next segment if it is within a second
                # of the previous word; otherwise stop half a second short.
                if next_segment_start_time - prev_end <= 1:
                    word['end'] = next_segment_start_time
                else:
                    word['end'] = next_segment_start_time - 0.5
            else:
                word['end'] = word['start'] + len(word['word']) * self._SECONDS_PER_CHAR
        elif next_has_start:
            next_start = words[i + 1]['start']
            word['start'] = next_start - len(word['word']) * self._SECONDS_PER_CHAR
            word['end'] = next_start
        elif next_segment_start_time:
            word['start'] = next_segment_start_time - 1
            word['end'] = next_segment_start_time - 0.5
        else:
            word['start'] = 0
            word['end'] = 0

    def process_segments(self, advanced_splitting=True):
        """Convert every segment into one or more subtitle dicts.

        Args:
            advanced_splitting: when true, long segments are split into
                several cues; when false, each segment becomes exactly one
                cue (missing word timestamps are still estimated in place).

        Returns:
            A list of ``{'start', 'end', 'text'}`` dicts in segment order.
        """
        subtitles = []
        segment_count = len(self.segments)
        for seg_idx, segment in enumerate(self.segments):
            if seg_idx + 1 < segment_count:
                next_segment_start_time = self.segments[seg_idx + 1]['start']
            else:
                next_segment_start_time = None

            if advanced_splitting:
                split_points = self.determine_advanced_split_points(segment, next_segment_start_time)
                subtitles.extend(
                    self.generate_subtitles_from_split_points(segment, split_points, next_segment_start_time)
                )
                continue

            # A distinct index name avoids shadowing the segment index; a
            # missing 'words' list or non-dict words are tolerated.
            words = segment.get('words') or []
            for word_idx, word in enumerate(words):
                if isinstance(word, dict) and ('start' not in word or 'end' not in word):
                    self.estimate_timestamp_for_word(words, word_idx, next_segment_start_time)

            subtitles.append({
                'start': segment['start'],
                'end': segment['end'],
                'text': segment['text'],
            })

        return subtitles

    def determine_advanced_split_points(self, segment, next_segment_start_time=None):
        """Choose the word indices after which a segment should be split.

        Walks the words while tracking the characters accumulated since the
        last split. A split is recorded when:

        * the running count reaches ``max_line_length`` and at least
          ``min_char_length_splitter`` characters precede the current word
          (split at the midpoint between the last split and this word);
        * the word ends with the language's comma and both sides have at
          least ``min_char_length_splitter`` characters (split after it);
        * the word is a conjunction under the same two-sided condition
          (split before it).

        Word dicts lacking timestamps are estimated in place along the way.

        Returns:
            A list of indices; each is the last word of a fragment.
        """
        words = self._segment_words(segment)
        add_space = 0 if self.lang in self._UNSPACED_LANGUAGES else 1

        def width(word):
            # Every word - dict or plain string - is counted the same way:
            # its text plus the optional separating space.
            return len(self._word_text(word)) + add_space

        split_points = []
        last_split_point = 0
        char_count = 0
        char_count_after = sum(width(word) for word in words)

        for i, word in enumerate(words):
            word_text = self._word_text(word)
            word_length = len(word_text) + add_space
            char_count_before = char_count
            char_count += word_length
            char_count_after -= word_length

            if isinstance(word, dict) and ('start' not in word or 'end' not in word):
                self.estimate_timestamp_for_word(words, i, next_segment_start_time)

            if char_count >= self.max_line_length:
                if char_count_before >= self.min_char_length_splitter:
                    midpoint = normal_round((last_split_point + i) / 2)
                    split_points.append(midpoint)
                    last_split_point = midpoint + 1
                    # Restart the count from the words already consumed after
                    # the new split, up to and including the current word.
                    char_count = sum(width(words[j]) for j in range(last_split_point, i + 1))

            elif (word_text.endswith(self.comma)
                    and char_count_before >= self.min_char_length_splitter
                    and char_count_after >= self.min_char_length_splitter):
                split_points.append(i)
                last_split_point = i + 1
                char_count = 0

            elif (word_text.lower() in self.conjunctions
                    and char_count_before >= self.min_char_length_splitter
                    and char_count_after >= self.min_char_length_splitter):
                # The conjunction starts the next fragment.
                split_points.append(i - 1)
                last_split_point = i
                char_count = word_length

        return split_points

    def generate_subtitles_from_split_points(self, segment, split_points, next_start_time=None):
        """Build subtitle dicts for one segment from its split points.

        For word dicts the cue times come from the first and last word of
        each fragment. For plain-string words the segment duration is shared
        out in proportion to the number of words per fragment. A cue's end is
        stretched to the following word/segment start when the gap is at most
        ``_MAX_GAP_TO_CLOSE`` seconds.

        Args:
            segment: the segment mapping.
            split_points: indices of the last word of each fragment, as
                produced by :meth:`determine_advanced_split_points`.
            next_start_time: start of the following segment, if any.

        Returns:
            A list of ``{'start', 'end', 'text'}`` dicts.
        """
        words = self._segment_words(segment)
        total_word_count = len(words)
        total_time = segment['end'] - segment['start']
        elapsed_time = segment['start']
        joiner = '' if self.lang in self._UNSPACED_LANGUAGES else ' '

        subtitles = []
        start_idx = 0
        for split_point in split_points:
            fragment_words = words[start_idx:split_point + 1]
            if not fragment_words:
                # Duplicate or out-of-order split point: nothing to emit, and
                # start_idx must not move backwards.
                continue

            if isinstance(fragment_words[0], dict):
                start_time = fragment_words[0]['start']
                end_time = fragment_words[-1]['end']
                if split_point + 1 < total_word_count:
                    following_start = words[split_point + 1]['start']
                else:
                    following_start = None
                end_time = self._close_gap(end_time, following_start)
                text = joiner.join(w['word'] for w in fragment_words)
            else:
                text = joiner.join(fragment_words).strip()
                current_duration = (len(fragment_words) / total_word_count) * total_time
                start_time = elapsed_time
                end_time = elapsed_time + current_duration
                elapsed_time = end_time

            subtitles.append({'start': start_time, 'end': end_time, 'text': text})
            start_idx = split_point + 1

        # Whatever follows the final split point forms the last fragment.
        if start_idx < total_word_count:
            fragment_words = words[start_idx:]

            if isinstance(fragment_words[0], dict):
                start_time = fragment_words[0]['start']
                end_time = fragment_words[-1]['end']
                text = joiner.join(w['word'] for w in fragment_words)
            else:
                text = joiner.join(fragment_words).strip()
                current_duration = (len(fragment_words) / total_word_count) * total_time
                start_time = elapsed_time
                end_time = elapsed_time + current_duration

            # Fall back to the segment end before doing arithmetic on it.
            if end_time is None:
                end_time = segment['end']
            end_time = self._close_gap(end_time, next_start_time)

            subtitles.append({'start': start_time, 'end': end_time, 'text': text})

        return subtitles

    def save(self, filename="subtitles.srt", advanced_splitting=True):
        """Process all segments and write the cues to *filename* as UTF-8.

        Each cue is written as its 1-based index, a ``start --> end`` line and
        the stripped text, followed by a blank line. A ``WEBVTT`` header is
        written first when ``is_vtt`` is set. Cues are written for both
        splitting modes, so the file content always matches the returned
        count.

        Args:
            filename: output path; an existing file is overwritten.
            advanced_splitting: forwarded to :meth:`process_segments`.

        Returns:
            The number of subtitle cues produced.
        """
        subtitles = self.process_segments(advanced_splitting)

        with open(filename, 'w', encoding='utf-8') as file:
            if self.is_vtt:
                file.write("WEBVTT\n\n")

            for idx, subtitle in enumerate(subtitles, 1):
                start_time = format_timestamp(subtitle['start'], self.is_vtt)
                end_time = format_timestamp(subtitle['end'], self.is_vtt)
                text = subtitle['text'].strip()
                file.write(f"{idx}\n{start_time} --> {end_time}\n{text}\n\n")

        return len(subtitles)
|
Reference in New Issue
Block a user