resolve conflicts

This commit is contained in:
Yasutaka Odo
2022-12-21 01:20:35 +09:00
3 changed files with 168 additions and 122 deletions

View File

@ -28,7 +28,7 @@ def transcribe(
compression_ratio_threshold: Optional[float] = 2.4,
logprob_threshold: Optional[float] = -1.0,
no_speech_threshold: Optional[float] = 0.6,
condition_on_previous_text: bool = True,
condition_on_previous_text: bool = False, # turn off by default due to errors it causes
**decode_options,
):
"""
@ -258,6 +258,7 @@ def align(
device: str,
extend_duration: float = 0.0,
start_from_previous: bool = True,
drop_non_aligned_words: bool = False,
):
print("Performing alignment...")
if not torch.is_tensor(audio):
@ -270,6 +271,7 @@ def align(
MAX_DURATION = audio.shape[1] / SAMPLE_RATE
prev_t2 = 0
word_segments_list = []
for idx, segment in enumerate(transcript):
t1 = max(segment['start'] - extend_duration, 0)
t2 = min(segment['end'] + extend_duration, MAX_DURATION)
@ -319,8 +321,7 @@ def align(
segment['end'] = t2_actual
prev_t2 = segment['end']
# merge missing words to previous, or merge with next word ahead if idx == 0
# for the .ass output
for x in range(len(t_local)):
curr_word = t_words[x]
curr_timestamp = t_local[x]
@ -329,15 +330,29 @@ def align(
else:
segment['word-level'].append({"text": curr_word, "start": None, "end": None})
# for per-word .srt ouput
# merge missing words to previous, or merge with next word ahead if idx == 0
for x in range(len(t_local)):
curr_word = t_words[x]
curr_timestamp = t_local[x]
if curr_timestamp is not None:
word_segments_list.append({"text": curr_word, "start": curr_timestamp[0], "end": curr_timestamp[1]})
elif not drop_non_aligned_words:
# then we merge
if x == 0:
t_words[x+1] = " ".join([curr_word, t_words[x+1]])
else:
word_segments_list[-1]['text'] += ' ' + curr_word
else:
# then we resort back to original whisper timestamps
# segment['start] and segment['end'] are unchanged
prev_t2 = 0
segment['word-level'].append({"text": segment['text'], "start": segment['start'], "end":segment['end']})
word_segments_list.append({"text": segment['text'], "start": segment['start'], "end":segment['end']})
print(f"[{format_timestamp(segment['start'])} --> {format_timestamp(segment['end'])}] {segment['text']}")
return {"segments": transcript}
return {"segments": transcript, "word_segments": word_segments_list}
def cli():
from . import available_models
@ -348,9 +363,10 @@ def cli():
parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
# alignment params
parser.add_argument("--align_model", default="WAV2VEC2_ASR_LARGE_LV60K_960H", help="Name of phoneme-level ASR model to do alignment")
parser.add_argument("--align_model", default="WAV2VEC2_ASR_BASE_960H", help="Name of phoneme-level ASR model to do alignment")
parser.add_argument("--align_extend", default=2, type=float, help="Seconds before and after to extend the whisper segments for alignment")
parser.add_argument("--align_from_prev", default=True, type=bool, help="Whether to clip the alignment start time of current segment to the end time of the last aligned word of the previous segment")
parser.add_argument("--drop_non_aligned", action="store_true", help="For word .srt, whether to drop non aliged words, or merge them into neighbouring.")
parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
parser.add_argument("--output_type", default="srt", choices=['all', 'srt', 'vtt', 'txt'], help="directory to save the outputs")
@ -387,7 +403,7 @@ def cli():
align_model: str = args.pop("align_model")
align_extend: float = args.pop("align_extend")
align_from_prev: bool = args.pop("align_from_prev")
# align_interpolate_missing: bool = args.pop("align_interpolate_missing")
drop_non_aligned: bool = args.pop("drop_non_aligned")
os.makedirs(output_dir, exist_ok=True)
@ -421,12 +437,13 @@ def cli():
labels = processor.tokenizer.get_vocab()
align_dictionary = processor.tokenizer.get_vocab()
else:
print(f'Align model "{align_model}" is not supported, choose from:\n {torchaudio.pipelines.__all__ + wa2vec2_models_on_hugginface}')
print(f'Align model "{align_model}" is not supported, choose from:\n {torchaudio.pipelines.__all__ + wa2vec2_models_on_hugginface} \n\
See details here https://pytorch.org/audio/stable/pipelines.html#id14')
raise ValueError(f'Align model "{align_model}" not supported')
for audio_path in args.pop("audio"):
result = transcribe(model, audio_path, temperature=temperature, **args)
result_aligned = align(result["segments"], result["language"], align_model, align_dictionary, audio_path, device,
extend_duration=align_extend, start_from_previous=align_from_prev)
extend_duration=align_extend, start_from_previous=align_from_prev, drop_non_aligned_words=drop_non_aligned)
audio_basename = os.path.basename(audio_path)
# save TXT
@ -444,6 +461,10 @@ def cli():
with open(os.path.join(output_dir, audio_basename + ".srt"), "w", encoding="utf-8") as srt:
write_srt(result_aligned["segments"], file=srt)
# save per-word SRT
with open(os.path.join(output_dir, audio_basename + ".word.srt"), "w", encoding="utf-8") as srt:
write_srt(result_aligned["word_segments"], file=srt)
# save ASS
with open(os.path.join(output_dir, audio_basename + ".ass"), "w", encoding="utf-8") as srt:
write_ass(result_aligned["segments"], file=srt)