Update alignment.py

This commit is contained in:
awerks
2023-08-16 16:18:00 +02:00
committed by GitHub
parent 72685d0398
commit 65688208c9

View File

@ -98,6 +98,7 @@ def align(
device: str, device: str,
interpolate_method: str = "nearest", interpolate_method: str = "nearest",
return_char_alignments: bool = False, return_char_alignments: bool = False,
print_progress = False
) -> AlignedTranscriptionResult: ) -> AlignedTranscriptionResult:
""" """
Align phoneme recognition predictions to known transcription. Align phoneme recognition predictions to known transcription.
@ -116,9 +117,16 @@ def align(
model_lang = align_model_metadata["language"] model_lang = align_model_metadata["language"]
model_type = align_model_metadata["type"] model_type = align_model_metadata["type"]
total_segments = len(list(transcript))
transcript = iter(transcript)
# 1. Preprocess to keep only characters in dictionary # 1. Preprocess to keep only characters in dictionary
for sdx, segment in enumerate(transcript): for sdx, segment in enumerate(transcript):
# strip spaces at beginning / end, but keep track of the amount. # strip spaces at beginning / end, but keep track of the amount.
if print_progress:
percent_complete = ((sdx + 1) / total_segments) * 100
print(f"Progress: {percent_complete:.2f}%...")
num_leading = len(segment["text"]) - len(segment["text"].lstrip()) num_leading = len(segment["text"]) - len(segment["text"].lstrip())
num_trailing = len(segment["text"]) - len(segment["text"].rstrip()) num_trailing = len(segment["text"]) - len(segment["text"].rstrip())
text = segment["text"] text = segment["text"]
@ -162,15 +170,10 @@ def align(
segment["sentence_spans"] = sentence_spans segment["sentence_spans"] = sentence_spans
aligned_segments: List[SingleAlignedSegment] = [] aligned_segments: List[SingleAlignedSegment] = []
total_segments = len(list(transcript))
transcript = iter(transcript)
# 2. Get prediction matrix from alignment model & align # 2. Get prediction matrix from alignment model & align
for sdx, segment in enumerate(transcript): for sdx, segment in enumerate(transcript):
percent_complete = ((sdx + 1) / total_segments) * 100
print(f"Progress: {percent_complete:.2f}%...")
t1 = segment["start"] t1 = segment["start"]
t2 = segment["end"] t2 = segment["end"]
text = segment["text"] text = segment["text"]