add back word .srt, update readme

Max Bain
2022-12-19 19:12:50 +00:00
parent 6b64cb079a
commit 228b857597
3 changed files with 147 additions and 106 deletions


@@ -255,6 +255,7 @@ def align(
device: str,
extend_duration: float = 0.0,
start_from_previous: bool = True,
drop_non_aligned_words: bool = False,
):
print("Performing alignment...")
if not torch.is_tensor(audio):
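For reference, a sketch of how this extended signature is invoked by the CLI changes later in this commit; the audio path and argument values below are placeholders, not taken from the diff:

result_aligned = align(
    result["segments"], align_model, align_dictionary, "audio.wav", device,
    extend_duration=2.0,            # --align_extend default
    start_from_previous=True,       # --align_from_prev default
    drop_non_aligned_words=False,   # keep unaligned words by merging them into neighbours
)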
@@ -267,6 +268,7 @@ def align(
MAX_DURATION = audio.shape[1] / SAMPLE_RATE
prev_t2 = 0
word_segments_list = []
for idx, segment in enumerate(transcript):
t1 = max(segment['start'] - extend_duration, 0)
t2 = min(segment['end'] + extend_duration, MAX_DURATION)
@@ -313,8 +315,7 @@ def align(
segment['end'] = t2_actual
prev_t2 = segment['end']
# merge missing words to previous, or merge with next word ahead if idx == 0
# for the .ass output
for x in range(len(t_local)):
curr_word = t_words[x]
curr_timestamp = t_local[x]
@@ -323,15 +324,29 @@
else:
segment['word-level'].append({"text": curr_word, "start": None, "end": None})
# for per-word .srt output
# merge missing words to previous, or merge with next word ahead if idx == 0
for x in range(len(t_local)):
curr_word = t_words[x]
curr_timestamp = t_local[x]
if curr_timestamp is not None:
word_segments_list.append({"text": curr_word, "start": curr_timestamp[0], "end": curr_timestamp[1]})
elif not drop_non_aligned_words:
# then we merge
if x == 0:
t_words[x+1] = " ".join([curr_word, t_words[x+1]])
else:
word_segments_list[-1]['text'] += ' ' + curr_word
else:
# then we fall back to the original whisper timestamps
# segment['start'] and segment['end'] are unchanged
prev_t2 = 0
segment['word-level'].append({"text": segment['text'], "start": segment['start'], "end":segment['end']})
word_segments_list.append({"text": segment['text'], "start": segment['start'], "end":segment['end']})
print(f"[{format_timestamp(segment['start'])} --> {format_timestamp(segment['end'])}] {segment['text']}")
return {"segments": transcript}
return {"segments": transcript, "word_segments": word_segments_list}
def cli():
from . import available_models
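To make the merging rule above concrete, here is a minimal standalone sketch on hypothetical data (the words, timestamps and resulting list are illustrative and not taken from the commit):

t_words = ["hello", "brave", "new", "world"]
t_local = [(0.0, 0.4), None, (0.9, 1.1), (1.2, 1.6)]   # "brave" failed to align
drop_non_aligned_words = False

word_segments_list = []
for x in range(len(t_local)):
    curr_word = t_words[x]
    curr_timestamp = t_local[x]
    if curr_timestamp is not None:
        word_segments_list.append({"text": curr_word, "start": curr_timestamp[0], "end": curr_timestamp[1]})
    elif not drop_non_aligned_words:
        if x == 0:
            t_words[x + 1] = " ".join([curr_word, t_words[x + 1]])
        else:
            word_segments_list[-1]["text"] += " " + curr_word

# word_segments_list now holds "hello brave", "new", "world";
# with drop_non_aligned_words=True, "brave" would simply be omitted.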
@@ -342,9 +357,10 @@ def cli():
parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
# alignment params
parser.add_argument("--align_model", default="WAV2VEC2_ASR_LARGE_LV60K_960H", help="Name of phoneme-level ASR model to do alignment")
parser.add_argument("--align_model", default="WAV2VEC2_ASR_BASE_960H", help="Name of phoneme-level ASR model to do alignment")
parser.add_argument("--align_extend", default=2, type=float, help="Seconds before and after to extend the whisper segments for alignment")
parser.add_argument("--align_from_prev", default=True, type=bool, help="Whether to clip the alignment start time of current segment to the end time of the last aligned word of the previous segment")
parser.add_argument("--drop_non_aligned", action="store_true", help="For word .srt, whether to drop non aliged words, or merge them into neighbouring.")
parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
parser.add_argument("--output_type", default="srt", choices=['all', 'srt', 'vtt', 'txt'], help="directory to save the outputs")
@@ -381,7 +397,7 @@ def cli():
align_model: str = args.pop("align_model")
align_extend: float = args.pop("align_extend")
align_from_prev: bool = args.pop("align_from_prev")
# align_interpolate_missing: bool = args.pop("align_interpolate_missing")
drop_non_aligned: bool = args.pop("drop_non_aligned")
os.makedirs(output_dir, exist_ok=True)
@@ -409,12 +425,14 @@ def cli():
labels = bundle.get_labels()
align_dictionary = {c.lower(): i for i, c in enumerate(labels)}
else:
print(f'Align model "{align_model}" not found in torchaudio.pipelines, choose from:\n {torchaudio.pipelines.__all__}')
print(f'Align model "{align_model}" not found in torchaudio.pipelines, choose from:\n\
{torchaudio.pipelines.__all__}\n\
See details here https://pytorch.org/audio/stable/pipelines.html#id14')
raise ValueError(f'Align model "{align_model}" not found in torchaudio.pipelines')
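For context, a rough sketch of how a model name such as the new default WAV2VEC2_ASR_BASE_960H resolves to a torchaudio pipeline bundle and the character dictionary used for alignment. It relies only on the public torchaudio.pipelines API (get_model / get_labels) and is illustrative, not the commit's exact loading code:

import torchaudio

align_model = "WAV2VEC2_ASR_BASE_960H"   # the new CLI default
if align_model in torchaudio.pipelines.__all__:
    bundle = getattr(torchaudio.pipelines, align_model)
    model = bundle.get_model()       # wav2vec2 acoustic model used for forced alignment
    labels = bundle.get_labels()     # character vocabulary, e.g. ('-', '|', 'E', 'T', ...)
    align_dictionary = {c.lower(): i for i, c in enumerate(labels)}
else:
    raise ValueError(f'Align model "{align_model}" not found in torchaudio.pipelines')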
for audio_path in args.pop("audio"):
result = transcribe(model, audio_path, temperature=temperature, **args)
result_aligned = align(result["segments"], align_model, align_dictionary, audio_path, device,
extend_duration=align_extend, start_from_previous=align_from_prev)
extend_duration=align_extend, start_from_previous=align_from_prev, drop_non_aligned_words=drop_non_aligned)
audio_basename = os.path.basename(audio_path)
# save TXT
@@ -432,6 +450,10 @@ def cli():
with open(os.path.join(output_dir, audio_basename + ".srt"), "w", encoding="utf-8") as srt:
write_srt(result_aligned["segments"], file=srt)
# save per-word SRT
with open(os.path.join(output_dir, audio_basename + ".word.srt"), "w", encoding="utf-8") as srt:
write_srt(result_aligned["word_segments"], file=srt)
# save ASS
with open(os.path.join(output_dir, audio_basename + ".ass"), "w", encoding="utf-8") as srt:
write_ass(result_aligned["segments"], file=srt)
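Assuming write_srt follows upstream Whisper's utils (one cue per dict with start/end/text keys and comma decimal markers), the new per-word file would look roughly as sketched below; the words and timings are made up:

word_segments = [
    {"start": 0.0, "end": 0.4, "text": "hello"},
    {"start": 0.9, "end": 1.1, "text": "world"},
]
# audio.wav.word.srt would then contain cues along the lines of:
#
# 1
# 00:00:00,000 --> 00:00:00,400
# hello
#
# 2
# 00:00:00,900 --> 00:00:01,100
# world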