add back word .srt, update readme
@@ -255,6 +255,7 @@ def align(
    device: str,
    extend_duration: float = 0.0,
    start_from_previous: bool = True,
    drop_non_aligned_words: bool = False,
):
    print("Performing alignment...")
    if not torch.is_tensor(audio):
@@ -267,6 +268,7 @@ def align(
    MAX_DURATION = audio.shape[1] / SAMPLE_RATE

    prev_t2 = 0
    word_segments_list = []
    for idx, segment in enumerate(transcript):
        t1 = max(segment['start'] - extend_duration, 0)
        t2 = min(segment['end'] + extend_duration, MAX_DURATION)
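For orientation: the two additions above change what align() accepts and returns. Callers can now pass drop_non_aligned_words, and the new word_segments_list accumulates one flat entry per word across all segments, returned later as "word_segments". A minimal caller sketch, mirroring the CLI call further down in this diff (model, align_model, align_dictionary, audio_path and device are placeholders for objects the caller has already loaded, not part of the change):

    # Sketch only: the variables below stand in for objects set up elsewhere.
    result = transcribe(model, audio_path)
    result_aligned = align(
        result["segments"], align_model, align_dictionary, audio_path, device,
        extend_duration=2.0,
        start_from_previous=True,
        drop_non_aligned_words=False,  # keep unaligned words by merging them into neighbours
    )
    for word in result_aligned["word_segments"]:
        print(word["start"], word["end"], word["text"])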
@@ -313,8 +315,7 @@ def align(
            segment['end'] = t2_actual
            prev_t2 = segment['end']


            # merge missing words to previous, or merge with next word ahead if idx == 0
            # for the .ass output
            for x in range(len(t_local)):
                curr_word = t_words[x]
                curr_timestamp = t_local[x]
@@ -323,15 +324,29 @@ def align(
                else:
                    segment['word-level'].append({"text": curr_word, "start": None, "end": None})

            # for per-word .srt output
            # merge missing words to previous, or merge with next word ahead if idx == 0
            for x in range(len(t_local)):
                curr_word = t_words[x]
                curr_timestamp = t_local[x]
                if curr_timestamp is not None:
                    word_segments_list.append({"text": curr_word, "start": curr_timestamp[0], "end": curr_timestamp[1]})
                elif not drop_non_aligned_words:
                    # then we merge
                    if x == 0:
                        t_words[x+1] = " ".join([curr_word, t_words[x+1]])
                    else:
                        word_segments_list[-1]['text'] += ' ' + curr_word
        else:
            # then we resort back to original whisper timestamps
            # segment['start'] and segment['end'] are unchanged
            prev_t2 = 0
            segment['word-level'].append({"text": segment['text'], "start": segment['start'], "end": segment['end']})
            word_segments_list.append({"text": segment['text'], "start": segment['start'], "end": segment['end']})

        print(f"[{format_timestamp(segment['start'])} --> {format_timestamp(segment['end'])}] {segment['text']}")

    return {"segments": transcript}
    return {"segments": transcript, "word_segments": word_segments_list}


def cli():
    from . import available_models
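The per-word block in the hunk above implements a merge rule for words the aligner could not place: the very first word of a segment is folded into the word after it (so its text is emitted with the next word's timestamps), while any later unaligned word is appended to the text of the previous entry. A toy walk-through under those rules (illustrative values, not repository code):

    t_words = ["hello", "there", "world"]
    t_local = [None, (0.5, 0.9), (1.0, 1.4)]   # word 0 failed to align
    word_segments_list = []
    drop_non_aligned_words = False

    for x in range(len(t_local)):
        curr_word = t_words[x]
        curr_timestamp = t_local[x]
        if curr_timestamp is not None:
            word_segments_list.append({"text": curr_word, "start": curr_timestamp[0], "end": curr_timestamp[1]})
        elif not drop_non_aligned_words:
            if x == 0:
                t_words[x + 1] = " ".join([curr_word, t_words[x + 1]])   # fold into the next word
            else:
                word_segments_list[-1]['text'] += ' ' + curr_word        # fold into the previous cue

    # word_segments_list ->
    # [{'text': 'hello there', 'start': 0.5, 'end': 0.9},
    #  {'text': 'world', 'start': 1.0, 'end': 1.4}]

With drop_non_aligned_words=True, the same input would simply omit "hello".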
@@ -342,9 +357,10 @@ def cli():
    parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
    parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
    # alignment params
    parser.add_argument("--align_model", default="WAV2VEC2_ASR_LARGE_LV60K_960H", help="Name of phoneme-level ASR model to do alignment")
    parser.add_argument("--align_model", default="WAV2VEC2_ASR_BASE_960H", help="Name of phoneme-level ASR model to do alignment")
    parser.add_argument("--align_extend", default=2, type=float, help="Seconds before and after to extend the whisper segments for alignment")
    parser.add_argument("--align_from_prev", default=True, type=bool, help="Whether to clip the alignment start time of current segment to the end time of the last aligned word of the previous segment")
    parser.add_argument("--drop_non_aligned", action="store_true", help="For word .srt, whether to drop non-aligned words, or merge them into neighbouring words.")

    parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
    parser.add_argument("--output_type", default="srt", choices=['all', 'srt', 'vtt', 'txt'], help="type of output file to save")
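Taken together, the alignment flags above can be exercised roughly like this (a usage sketch only: the audio filename is illustrative, and `whisperx` is assumed to be the console entry point described in the project README):

    whisperx audio.wav --model small --align_model WAV2VEC2_ASR_BASE_960H \
        --align_extend 2 --drop_non_aligned --output_dir ./out --output_type srt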
@@ -381,7 +397,7 @@ def cli():
    align_model: str = args.pop("align_model")
    align_extend: float = args.pop("align_extend")
    align_from_prev: bool = args.pop("align_from_prev")
    # align_interpolate_missing: bool = args.pop("align_interpolate_missing")
    drop_non_aligned: bool = args.pop("drop_non_aligned")

    os.makedirs(output_dir, exist_ok=True)
@@ -409,12 +425,14 @@ def cli():
        labels = bundle.get_labels()
        align_dictionary = {c.lower(): i for i, c in enumerate(labels)}
    else:
        print(f'Align model "{align_model}" not found in torchaudio.pipelines, choose from:\n {torchaudio.pipelines.__all__}')
        print(f'Align model "{align_model}" not found in torchaudio.pipelines, choose from:\n\
            {torchaudio.pipelines.__all__}\n\
            See details here https://pytorch.org/audio/stable/pipelines.html#id14')
        raise ValueError(f'Align model "{align_model}" not found in torchaudio.pipelines')
    for audio_path in args.pop("audio"):
        result = transcribe(model, audio_path, temperature=temperature, **args)
        result_aligned = align(result["segments"], align_model, align_dictionary, audio_path, device,
                               extend_duration=align_extend, start_from_previous=align_from_prev)
                               extend_duration=align_extend, start_from_previous=align_from_prev, drop_non_aligned_words=drop_non_aligned)
        audio_basename = os.path.basename(audio_path)

        # save TXT
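For reference, the labels/align_dictionary lines above come from a torchaudio pipeline bundle. A self-contained sketch of that lookup, consistent with the error branch shown in the hunk; the surrounding if and the .to(device) call are assumptions about code outside this hunk, and the variable names are illustrative:

    import torch
    import torchaudio

    align_model_name = "WAV2VEC2_ASR_BASE_960H"
    device = "cuda" if torch.cuda.is_available() else "cpu"

    if align_model_name in torchaudio.pipelines.__all__:
        bundle = torchaudio.pipelines.__dict__[align_model_name]
        align_model = bundle.get_model().to(device)   # wav2vec2 acoustic model used for forced alignment
        labels = bundle.get_labels()                  # character vocabulary, e.g. ('-', '|', 'E', 'T', ...)
        align_dictionary = {c.lower(): i for i, c in enumerate(labels)}
    else:
        raise ValueError(f'Align model "{align_model_name}" not found in torchaudio.pipelines')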
@@ -432,6 +450,10 @@ def cli():
        with open(os.path.join(output_dir, audio_basename + ".srt"), "w", encoding="utf-8") as srt:
            write_srt(result_aligned["segments"], file=srt)

        # save per-word SRT
        with open(os.path.join(output_dir, audio_basename + ".word.srt"), "w", encoding="utf-8") as srt:
            write_srt(result_aligned["word_segments"], file=srt)

        # save ASS
        with open(os.path.join(output_dir, audio_basename + ".ass"), "w", encoding="utf-8") as srt:
            write_ass(result_aligned["segments"], file=srt)
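The new .word.srt output reuses the same write_srt helper as the segment-level file, just fed with the flat word_segments list, so each word (or merged group of words) becomes its own numbered cue. A small sketch of what that amounts to, using the toy values from the earlier walk-through; the whisper.utils import path is an assumption about where this file gets write_srt from:

    from whisper.utils import write_srt   # assumption: same helper as the segment-level .srt

    word_segments = [
        {"text": "hello there", "start": 0.5, "end": 0.9},
        {"text": "world", "start": 1.0, "end": 1.4},
    ]
    with open("audio.wav.word.srt", "w", encoding="utf-8") as srt:
        write_srt(word_segments, file=srt)   # one SRT cue per word-level entry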