clean up logic, use pandas where possible

Max Bain
2023-01-25 18:42:52 +00:00
parent eec6d1f8d8
commit 286a2f2c14
5 changed files with 426 additions and 429 deletions

README.md (View File)

@@ -2,7 +2,7 @@
## Other Languages
For non-english ASR, it is best to use the `large` whisper model. Alignment models are automatically picked by the chosen language from the default [lists](https://github.com/m-bain/whisperX/blob/e909f2f766b23b2000f2d95df41f9b844ac53e49/whisperx/transcribe.py#L22).
For non-english ASR, it is best to use the `large` whisper model. Alignment models are automatically picked by the chosen language from the default [lists](https://github.com/m-bain/whisperX/blob/main/whisperx/alignment.py#L18).
Currently support default models tested for {en, fr, de, es, it, ja, zh, nl}
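As a rough sketch of what that model selection looks like through the Python API added in this commit (the Hungarian checkpoint below is only a hypothetical example of a user-supplied model name, not a tested default):

import whisperx

device = "cuda"
# languages with a default entry resolve automatically from DEFAULT_ALIGN_MODELS_TORCH / _HF
align_model, align_metadata = whisperx.load_align_model("de", device)
# other languages need an explicit wav2vec 2.0 checkpoint, mirroring --align_model on the CLI
align_model, align_metadata = whisperx.load_align_model(
    "hu", device, model_name="jonatasgrosman/wav2vec2-large-xlsr-53-hungarian"
)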

whisperx/__init__.py (View File)

@@ -11,8 +11,8 @@ from tqdm import tqdm
from .audio import load_audio, log_mel_spectrogram, pad_or_trim
from .decoding import DecodingOptions, DecodingResult, decode, detect_language
from .model import Whisper, ModelDimensions
from .transcribe import transcribe, load_align_model, align, transcribe_with_vad
from .transcribe import transcribe, transcribe_with_vad
from .alignment import load_align_model, align
_MODELS = {
"tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",
"tiny": "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt",

whisperx/alignment.py (View File)

@@ -1,9 +1,412 @@
""""
Forced Alignment with Whisper
C. Max Bain
"""
import numpy as np
import pandas as pd
from typing import List, Union, Iterator, TYPE_CHECKING
from transformers import AutoProcessor, Wav2Vec2ForCTC
import torchaudio
import torch
from dataclasses import dataclass
from .audio import SAMPLE_RATE, load_audio
from .utils import interpolate_nans
LANGUAGES_WITHOUT_SPACES = ["ja", "zh"]
DEFAULT_ALIGN_MODELS_TORCH = {
"en": "WAV2VEC2_ASR_BASE_960H",
"fr": "VOXPOPULI_ASR_BASE_10K_FR",
"de": "VOXPOPULI_ASR_BASE_10K_DE",
"es": "VOXPOPULI_ASR_BASE_10K_ES",
"it": "VOXPOPULI_ASR_BASE_10K_IT",
}
DEFAULT_ALIGN_MODELS_HF = {
"ja": "jonatasgrosman/wav2vec2-large-xlsr-53-japanese",
"zh": "jonatasgrosman/wav2vec2-large-xlsr-53-chinese-zh-cn",
"nl": "jonatasgrosman/wav2vec2-large-xlsr-53-dutch",
"uk": "Yehor/wav2vec2-xls-r-300m-uk-with-small-lm",
"pt": "jonatasgrosman/wav2vec2-large-xlsr-53-portuguese",
}
def load_align_model(language_code, device, model_name=None):
if model_name is None:
# use default model
if language_code in DEFAULT_ALIGN_MODELS_TORCH:
model_name = DEFAULT_ALIGN_MODELS_TORCH[language_code]
elif language_code in DEFAULT_ALIGN_MODELS_HF:
model_name = DEFAULT_ALIGN_MODELS_HF[language_code]
else:
print(f"There is no default alignment model set for this language ({language_code}).\
Please find a wav2vec2.0 model finetuned on this language in https://huggingface.co/models, then pass the model name in --align_model [MODEL_NAME]")
raise ValueError(f"No default align-model for language: {language_code}")
if model_name in torchaudio.pipelines.__all__:
pipeline_type = "torchaudio"
bundle = torchaudio.pipelines.__dict__[model_name]
align_model = bundle.get_model().to(device)
labels = bundle.get_labels()
align_dictionary = {c.lower(): i for i, c in enumerate(labels)}
else:
try:
processor = AutoProcessor.from_pretrained(model_name)
align_model = Wav2Vec2ForCTC.from_pretrained(model_name)
except Exception as e:
print(e)
print(f"Error loading model from huggingface, check https://huggingface.co/models for finetuned wav2vec2.0 models")
raise ValueError(f'The chosen align_model "{model_name}" could not be found in huggingface (https://huggingface.co/models) or torchaudio (https://pytorch.org/audio/stable/pipelines.html#id14)')
pipeline_type = "huggingface"
align_model = align_model.to(device)
labels = processor.tokenizer.get_vocab()
align_dictionary = {char.lower(): code for char,code in processor.tokenizer.get_vocab().items()}
align_metadata = {"language": language_code, "dictionary": align_dictionary, "type": pipeline_type}
return align_model, align_metadata
def align(
transcript: Iterator[dict],
model: torch.nn.Module,
align_model_metadata: dict,
audio: Union[str, np.ndarray, torch.Tensor],
device: str,
extend_duration: float = 0.0,
start_from_previous: bool = True,
interpolate_method: str = "nearest",
):
"""
Force align phoneme recognition predictions to known transcription
Parameters
----------
transcript: Iterator[dict]
The Whisper model instance
model: torch.nn.Module
Alignment model (wav2vec2)
audio: Union[str, np.ndarray, torch.Tensor]
The path to the audio file to open, or the audio waveform
device: str
cuda device
diarization: pd.DataFrame {'start': List[float], 'end': List[float], 'speaker': List[float]}
diarization segments with speaker labels.
extend_duration: float
Amount to pad input segments by. If not using vad--filter then recommended to use 2 seconds
If the gzip compression ratio is above this value, treat as failed
interpolate_method: str ["nearest", "linear", "ignore"]
Method to assign timestamps to non-aligned words. Words are not able to be aligned when none of the characters occur in the align model dictionary.
"nearest" copies timestamp of nearest word within the segment. "linear" is linear interpolation. "drop" removes that word from output.
Returns
-------
A dictionary containing the resulting text ("text") and segment-level details ("segments"), and
the spoken language ("language"), which is detected when `decode_options["language"]` is None.
"""
if not torch.is_tensor(audio):
if isinstance(audio, str):
audio = load_audio(audio)
audio = torch.from_numpy(audio)
if len(audio.shape) == 1:
audio = audio.unsqueeze(0)
MAX_DURATION = audio.shape[1] / SAMPLE_RATE
model_dictionary = align_model_metadata["dictionary"]
model_lang = align_model_metadata["language"]
model_type = align_model_metadata["type"]
aligned_segments = []
prev_t2 = 0
char_segments_arr = {
"segment-idx": [],
"subsegment-idx": [],
"word-idx": [],
"char": [],
"start": [],
"end": [],
"score": [],
}
for sdx, segment in enumerate(transcript):
while True:
segment_align_success = False
# strip spaces at beginning / end, but keep track of the amount.
num_leading = len(segment["text"]) - len(segment["text"].lstrip())
num_trailing = len(segment["text"]) - len(segment["text"].rstrip())
transcription = segment["text"]
# TODO: convert number tokenizer / symbols to phonetic words for alignment.
# e.g. "$300" -> "three hundred dollars"
# currently "$300" is ignored since no characters present in the phonetic dictionary
# split into words
if model_lang not in LANGUAGES_WITHOUT_SPACES:
per_word = transcription.split(" ")
else:
per_word = transcription
# first check that characters in transcription can be aligned (they are contained in the align model's dictionary)
clean_char, clean_cdx = [], []
for cdx, char in enumerate(transcription):
char_ = char.lower()
# wav2vec2 models use "|" character to represent spaces
if model_lang not in LANGUAGES_WITHOUT_SPACES:
char_ = char_.replace(" ", "|")
# ignore whitespace at beginning and end of transcript
if cdx < num_leading:
pass
elif cdx > len(transcription) - num_trailing - 1:
pass
elif char_ in model_dictionary.keys():
clean_char.append(char_)
clean_cdx.append(cdx)
clean_wdx = []
for wdx, wrd in enumerate(per_word):
if any([c in model_dictionary.keys() for c in wrd]):
clean_wdx.append(wdx)
# if no characters are in the dictionary, then we skip this segment...
if len(clean_char) == 0:
print("Failed to align segment: no characters in this segment found in model dictionary, resorting to original...")
break
transcription_cleaned = "".join(clean_char)
tokens = [model_dictionary[c] for c in transcription_cleaned]
# pad according original timestamps
t1 = max(segment["start"] - extend_duration, 0)
t2 = min(segment["end"] + extend_duration, MAX_DURATION)
# use prev_t2 as the current t1 if it's later
if start_from_previous and t1 < prev_t2:
t1 = prev_t2
# check if timestamp range is still valid
if t1 >= MAX_DURATION:
print("Failed to align segment: original start time longer than audio duration, skipping...")
break
if t2 - t1 < 0.02:
print("Failed to align segment: duration smaller than 0.02s time precision")
break
f1 = int(t1 * SAMPLE_RATE)
f2 = int(t2 * SAMPLE_RATE)
waveform_segment = audio[:, f1:f2]
with torch.inference_mode():
if model_type == "torchaudio":
emissions, _ = model(waveform_segment.to(device))
elif model_type == "huggingface":
emissions = model(waveform_segment.to(device)).logits
else:
raise NotImplementedError(f"Align model of type {model_type} not supported.")
emissions = torch.log_softmax(emissions, dim=-1)
emission = emissions[0].cpu().detach()
trellis = get_trellis(emission, tokens)
path = backtrack(trellis, emission, tokens)
if path is None:
print("Failed to align segment: backtrack failed, resorting to original...")
break
char_segments = merge_repeats(path, transcription_cleaned)
# word_segments = merge_words(char_segments)
# sub-segments
if "seg-text" not in segment:
segment["seg-text"] = [transcription]
v = 0
seg_lens = [0] + [len(x) for x in segment["seg-text"]]
seg_lens_cumsum = [v := v + n for n in seg_lens]
sub_seg_idx = 0
wdx = 0
duration = t2 - t1
ratio = duration * waveform_segment.size(0) / (trellis.size(0) - 1)
for cdx, char in enumerate(transcription + " "):
is_last = False
if cdx == len(transcription):
break
elif cdx+1 == len(transcription):
is_last = True
start, end, score = None, None, None
if cdx in clean_cdx:
char_seg = char_segments[clean_cdx.index(cdx)]
start = char_seg.start * ratio + t1
end = char_seg.end * ratio + t1
score = char_seg.score
char_segments_arr["char"].append(char)
char_segments_arr["start"].append(start)
char_segments_arr["end"].append(end)
char_segments_arr["score"].append(score)
char_segments_arr["word-idx"].append(wdx)
char_segments_arr["segment-idx"].append(sdx)
char_segments_arr["subsegment-idx"].append(sub_seg_idx)
# word-level info
if model_lang in LANGUAGES_WITHOUT_SPACES:
# character == word
wdx += 1
elif is_last or transcription[cdx+1] == " " or cdx == seg_lens_cumsum[sub_seg_idx+1] - 1:
wdx += 1
if is_last or cdx == seg_lens_cumsum[sub_seg_idx+1] - 1:
wdx = 0
sub_seg_idx += 1
prev_t2 = segment["end"]
segment_align_success = True
# end while True loop
break
# reset prev_t2 due to drifting issues
if not segment_align_success:
prev_t2 = 0
char_segments_arr = pd.DataFrame(char_segments_arr)
not_space = char_segments_arr["char"] != " "
per_seg_grp = char_segments_arr.groupby(["segment-idx", "subsegment-idx"], as_index = False)
char_segments_arr = per_seg_grp.apply(lambda x: x.reset_index(drop = True)).reset_index()
per_word_grp = char_segments_arr[not_space].groupby(["segment-idx", "subsegment-idx", "word-idx"])
per_subseg_grp = char_segments_arr[not_space].groupby(["segment-idx", "subsegment-idx"])
per_seg_grp = char_segments_arr[not_space].groupby(["segment-idx"])
word_segments_arr = {}
# start of word is first char with a timestamp
word_segments_arr["start"] = per_word_grp["start"].min().reset_index()["start"]
# end of word is last char with a timestamp
word_segments_arr["end"] = per_word_grp["end"].max().reset_index()["end"]
# score of word is mean (excluding nan)
word_segments_arr["score"] = per_word_grp["score"].mean().reset_index()["score"]
word_segments_arr["segment-text-start"] = per_word_grp["level_1"].min().reset_index()["level_1"]
word_segments_arr["segment-text-end"] = per_word_grp["level_1"].max().reset_index()["level_1"] + 1
word_segments_arr["segment-idx"] = per_word_grp["level_1"].min().reset_index()["segment-idx"]
word_segments_arr = pd.DataFrame(word_segments_arr)
word_segments_arr[["segment-idx", "subsegment-idx", "word-idx"]] = per_word_grp["level_1"].min().reset_index()[["segment-idx", "subsegment-idx", "word-idx"]]
segments_arr = {}
segments_arr["start"] = per_subseg_grp["start"].min().reset_index()["start"]
segments_arr["end"] = per_subseg_grp["end"].min().reset_index()["end"]
segments_arr = pd.DataFrame(segments_arr)
segments_arr[["segment-idx", "subsegment-idx-start"]] = per_subseg_grp["start"].min().reset_index()[["segment-idx", "subsegment-idx"]]
segments_arr["subsegment-idx-end"] = segments_arr["subsegment-idx-start"] + 1
# interpolate missing words / sub-segments
if interpolate_method != "ignore":
wrd_subseg_grp = word_segments_arr.groupby(["segment-idx", "subsegment-idx"])
wrd_seg_grp = word_segments_arr.groupby(["segment-idx"])
# we still know which word timestamps are interpolated because their score == nan
word_segments_arr["start"] = wrd_subseg_grp['start'].apply(lambda group: interpolate_nans(group, method=interpolate_method))
word_segments_arr["end"] = wrd_subseg_grp['end'].apply(lambda group: interpolate_nans(group, method=interpolate_method))
word_segments_arr["start"] = wrd_seg_grp['start'].apply(lambda group: interpolate_nans(group, method=interpolate_method))
word_segments_arr["end"] = wrd_seg_grp['end'].apply(lambda group: interpolate_nans(group, method=interpolate_method))
sub_seg_grp = segments_arr.groupby(["segment-idx"])
segments_arr['start'] = sub_seg_grp['start'].apply(lambda group: interpolate_nans(group, method=interpolate_method))
segments_arr['end'] = sub_seg_grp['end'].apply(lambda group: interpolate_nans(group, method=interpolate_method))
# merge subsegments which are missing times
# group by sub seg and time.
seg_grp_dup = segments_arr.groupby(["segment-idx", "start", "end"])
segments_arr["subsegment-idx-start"] = seg_grp_dup["subsegment-idx-start"].transform(min)
segments_arr["subsegment-idx-end"] = seg_grp_dup["subsegment-idx-end"].transform(max)
segments_arr.drop_duplicates(subset=["segment-idx", "subsegment-idx-start", "subsegment-idx-end"], inplace=True)
else:
word_segments_arr.dropna(inplace=True)
segments_arr.dropna(inplace=True)
aligned_segments = []
aligned_segments_word = []
word_segments_arr.set_index(["segment-idx", "subsegment-idx"], inplace=True)
char_segments_arr.set_index(["segment-idx", "subsegment-idx", "word-idx"], inplace=True)
for sdx, srow in segments_arr.iterrows():
seg_idx = int(srow["segment-idx"])
sub_start = int(srow["subsegment-idx-start"])
sub_end = int(srow["subsegment-idx-end"])
seg = transcript[seg_idx]
text = "".join(seg["seg-text"][sub_start:sub_end])
wseg = word_segments_arr.loc[seg_idx].loc[sub_start:sub_end-1]
cseg = char_segments_arr.loc[seg_idx].loc[sub_start:sub_end-1]
cseg['segment-text-start'] = cseg['level_1']
cseg['segment-text-end'] = cseg['level_1'] + 1
del cseg['level_1']
del cseg['level_0']
cseg.reset_index(inplace=True)
aligned_segments.append(
{
"start": srow["start"],
"end": srow["end"],
"text": text,
"word-segments": wseg,
"char-segments": cseg
}
)
def get_raw_text(word_row):
return seg["seg-text"][word_row.name][int(word_row["segment-text-start"]):int(word_row["segment-text-end"])+1]
wdx = 0
curr_text = get_raw_text(wseg.iloc[wdx])
if len(wseg) > 1:
for _, wrow in wseg.iloc[1:].iterrows():
if wrow['start'] != wseg.iloc[wdx]['start']:
aligned_segments_word.append(
{
"text": curr_text.strip(),
"start": wseg.iloc[wdx]["start"],
"end": wseg.iloc[wdx]["end"],
}
)
curr_text = ""
curr_text += " " + get_raw_text(wrow)
wdx += 1
aligned_segments_word.append(
{
"text": curr_text.strip(),
"start": wseg.iloc[wdx]["start"],
"end": wseg.iloc[wdx]["end"]
}
)
return {"segments": aligned_segments, "word_segments": aligned_segments_word}
""" """
source: https://pytorch.org/tutorials/intermediate/forced_alignment_with_torchaudio_tutorial.html source: https://pytorch.org/tutorials/intermediate/forced_alignment_with_torchaudio_tutorial.html
""" """
import torch
from dataclasses import dataclass
def get_trellis(emission, tokens, blank_id=0): def get_trellis(emission, tokens, blank_id=0):
num_frame = emission.size(0) num_frame = emission.size(0)
num_tokens = len(tokens) num_tokens = len(tokens)
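For orientation, a minimal end-to-end sketch of how this new module is meant to be driven (keyword names follow the signatures in this diff; the model size, device and audio path are placeholders, and this is an illustrative usage outline rather than part of the commit):

import whisperx

device = "cuda"
audio_file = "audio.wav"  # placeholder

model = whisperx.load_model("large", device)
result = model.transcribe(audio_file)

# pick / load the wav2vec2 alignment model for the detected language
align_model, align_metadata = whisperx.load_align_model(language_code=result["language"], device=device)

# force-align the Whisper segments to get word-level timestamps
result_aligned = whisperx.align(result["segments"], align_model, align_metadata, audio_file, device)
print(result_aligned["word_segments"][:5])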

whisperx/transcribe.py (View File)

@@ -5,12 +5,11 @@ from typing import List, Optional, Tuple, Union, Iterator, TYPE_CHECKING
import numpy as np
import torch
import torchaudio
from transformers import AutoProcessor, Wav2Vec2ForCTC
import tqdm
from .audio import SAMPLE_RATE, N_FRAMES, HOP_LENGTH, CHUNK_LENGTH, pad_or_trim, log_mel_spectrogram, load_audio
from .alignment import get_trellis, backtrack, merge_repeats, merge_words
from .alignment import load_align_model, align, get_trellis, backtrack, merge_repeats, merge_words
from .decoding import DecodingOptions, DecodingResult
from .diarize import assign_word_speakers, Segment
from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
from .utils import exact_div, format_timestamp, optional_int, optional_float, str2bool, interpolate_nans, write_txt, write_vtt, write_srt, write_ass, write_tsv
import pandas as pd
@@ -18,23 +17,6 @@ import pandas as pd
if TYPE_CHECKING:
from .model import Whisper
LANGUAGES_WITHOUT_SPACES = ["ja", "zh"]
DEFAULT_ALIGN_MODELS_TORCH = {
"en": "WAV2VEC2_ASR_BASE_960H",
"fr": "VOXPOPULI_ASR_BASE_10K_FR",
"de": "VOXPOPULI_ASR_BASE_10K_DE",
"es": "VOXPOPULI_ASR_BASE_10K_ES",
"it": "VOXPOPULI_ASR_BASE_10K_IT",
}
DEFAULT_ALIGN_MODELS_HF = {
"ja": "jonatasgrosman/wav2vec2-large-xlsr-53-japanese",
"zh": "jonatasgrosman/wav2vec2-large-xlsr-53-chinese-zh-cn",
"nl": "jonatasgrosman/wav2vec2-large-xlsr-53-dutch",
"uk": "Yehor/wav2vec2-xls-r-300m-uk-with-small-lm",
"pt": "jonatasgrosman/wav2vec2-large-xlsr-53-portuguese",
}
def transcribe(
@@ -273,355 +255,11 @@ def transcribe(
return dict(text=tokenizer.decode(all_tokens[len(initial_prompt):]), segments=all_segments, language=language)
def align(
transcript: Iterator[dict],
model: torch.nn.Module,
align_model_metadata: dict,
audio: Union[str, np.ndarray, torch.Tensor],
device: str,
extend_duration: float = 0.0,
start_from_previous: bool = True,
interpolate_method: str = "nearest",
):
"""
Force align phoneme recognition predictions to known transcription
Parameters
----------
transcript: Iterator[dict]
The Whisper model instance
model: torch.nn.Module
Alignment model (wav2vec2)
audio: Union[str, np.ndarray, torch.Tensor]
The path to the audio file to open, or the audio waveform
device: str
cuda device
extend_duration: float
Amount to pad input segments by. If not using vad--filter then recommended to use 2 seconds
If the gzip compression ratio is above this value, treat as failed
interpolate_method: str ["nearest", "linear", "ignore"]
Method to assign timestamps to non-aligned words. Words are not able to be aligned when none of the characters occur in the align model dictionary.
"nearest" copies timestamp of nearest word within the segment. "linear" is linear interpolation. "drop" removes that word from output.
Returns
-------
A dictionary containing the resulting text ("text") and segment-level details ("segments"), and
the spoken language ("language"), which is detected when `decode_options["language"]` is None.
"""
if not torch.is_tensor(audio):
if isinstance(audio, str):
audio = load_audio(audio)
audio = torch.from_numpy(audio)
if len(audio.shape) == 1:
audio = audio.unsqueeze(0)
MAX_DURATION = audio.shape[1] / SAMPLE_RATE
model_dictionary = align_model_metadata["dictionary"]
model_lang = align_model_metadata["language"]
model_type = align_model_metadata["type"]
aligned_segments = []
prev_t2 = 0
for segment in transcript:
aligned_subsegments = []
while True:
segment_align_success = False
# strip spaces at beginning / end, but keep track of the amount.
num_leading = len(segment["text"]) - len(segment["text"].lstrip())
num_trailing = len(segment["text"]) - len(segment["text"].rstrip())
transcription = segment["text"]
# TODO: convert number tokenizer / symbols to phonetic words for alignment.
# e.g. "$300" -> "three hundred dollars"
# currently "$300" is ignored since no characters present in the phonetic dictionary
# split into words
if model_lang not in LANGUAGES_WITHOUT_SPACES:
per_word = transcription.split(" ")
else:
per_word = transcription
# first check that characters in transcription can be aligned (they are contained in align model"s dictionary)
clean_char, clean_cdx = [], []
for cdx, char in enumerate(transcription):
char_ = char.lower()
# wav2vec2 models use "|" character to represent spaces
if model_lang not in LANGUAGES_WITHOUT_SPACES:
char_ = char_.replace(" ", "|")
# ignore whitespace at beginning and end of transcript
if cdx < num_leading:
pass
elif cdx > len(transcription) - num_trailing - 1:
pass
elif char_ in model_dictionary.keys():
clean_char.append(char_)
clean_cdx.append(cdx)
clean_wdx = []
for wdx, wrd in enumerate(per_word):
if any([c in model_dictionary.keys() for c in wrd]):
clean_wdx.append(wdx)
# if no characters are in the dictionary, then we skip this segment...
if len(clean_char) == 0:
print("Failed to align segment: no characters in this segment found in model dictionary, resorting to original...")
break
transcription_cleaned = "".join(clean_char)
tokens = [model_dictionary[c] for c in transcription_cleaned]
# pad according original timestamps
t1 = max(segment["start"] - extend_duration, 0)
t2 = min(segment["end"] + extend_duration, MAX_DURATION)
# use prev_t2 as current t1 if it"s later
if start_from_previous and t1 < prev_t2:
t1 = prev_t2
# check if timestamp range is still valid
if t1 >= MAX_DURATION:
print("Failed to align segment: original start time longer than audio duration, skipping...")
break
if t2 - t1 < 0.02:
print("Failed to align segment: duration smaller than 0.02s time precision")
break
f1 = int(t1 * SAMPLE_RATE)
f2 = int(t2 * SAMPLE_RATE)
waveform_segment = audio[:, f1:f2]
with torch.inference_mode():
if model_type == "torchaudio":
emissions, _ = model(waveform_segment.to(device))
elif model_type == "huggingface":
emissions = model(waveform_segment.to(device)).logits
else:
raise NotImplementedError(f"Align model of type {model_type} not supported.")
emissions = torch.log_softmax(emissions, dim=-1)
emission = emissions[0].cpu().detach()
trellis = get_trellis(emission, tokens)
path = backtrack(trellis, emission, tokens)
if path is None:
print("Failed to align segment: backtrack failed, resorting to original...")
break
char_segments = merge_repeats(path, transcription_cleaned)
# word_segments = merge_words(char_segments)
# sub-segments
if "seg-text" not in segment:
segment["seg-text"] = [transcription]
v = 0
seg_lens = [0] + [len(x) for x in segment["seg-text"]]
seg_lens_cumsum = [v := v + n for n in seg_lens]
sub_seg_idx = 0
char_level = {
"start": [],
"end": [],
"score": [],
"word-index": [],
}
word_level = {
"start": [],
"end": [],
"score": [],
"segment-text-start": [],
"segment-text-end": []
}
wdx = 0
seg_start_actual, seg_end_actual = None, None
duration = t2 - t1
ratio = duration * waveform_segment.size(0) / (trellis.size(0) - 1)
cdx_prev = 0
for cdx, char in enumerate(transcription + " "):
is_last = False
if cdx == len(transcription):
break
elif cdx+1 == len(transcription):
is_last = True
start, end, score = None, None, None
if cdx in clean_cdx:
char_seg = char_segments[clean_cdx.index(cdx)]
start = char_seg.start * ratio + t1
end = char_seg.end * ratio + t1
score = char_seg.score
char_level["start"].append(start)
char_level["end"].append(end)
char_level["score"].append(score)
char_level["word-index"].append(wdx)
# word-level info
if model_lang in LANGUAGES_WITHOUT_SPACES:
# character == word
wdx += 1
elif is_last or transcription[cdx+1] == " " or cdx == seg_lens_cumsum[sub_seg_idx+1] - 1:
wdx += 1
word_level["start"].append(None)
word_level["end"].append(None)
word_level["score"].append(None)
word_level["segment-text-start"].append(cdx_prev-seg_lens_cumsum[sub_seg_idx])
word_level["segment-text-end"].append(cdx+1-seg_lens_cumsum[sub_seg_idx])
cdx_prev = cdx+2
if is_last or cdx == seg_lens_cumsum[sub_seg_idx+1] - 1:
if model_lang not in LANGUAGES_WITHOUT_SPACES:
char_level = pd.DataFrame(char_level)
word_level = pd.DataFrame(word_level)
not_space = pd.Series(list(segment["seg-text"][sub_seg_idx])) != " "
word_level["start"] = char_level[not_space].groupby("word-index")["start"].min() # take min of all chars in a word ignoring space
word_level["end"] = char_level[not_space].groupby("word-index")["end"].max() # take max of all chars in a word
# fill missing
if interpolate_method != "ignore":
word_level["start"] = interpolate_nans(word_level["start"], method=interpolate_method)
word_level["end"] = interpolate_nans(word_level["end"], method=interpolate_method)
word_level["start"] = word_level["start"].values.tolist()
word_level["end"] = word_level["end"].values.tolist()
word_level["score"] = char_level.groupby("word-index")["score"].mean() # take mean of all scores
char_level = char_level.replace({np.nan:None}).to_dict("list")
word_level = pd.DataFrame(word_level).replace({np.nan:None}).to_dict("list")
else:
word_level = None
aligned_subsegments.append(
{
"text": segment["seg-text"][sub_seg_idx],
"start": seg_start_actual,
"end": seg_end_actual,
"char-segments": char_level,
"word-segments": word_level
}
)
if "language" in segment:
aligned_subsegments[-1]["language"] = segment["language"]
char_level = {
"start": [],
"end": [],
"score": [],
"word-index": [],
}
word_level = {
"start": [],
"end": [],
"score": [],
"segment-text-start": [],
"segment-text-end": []
}
wdx = 0
cdx_prev = cdx + 2
sub_seg_idx += 1
seg_start_actual, seg_end_actual = None, None
# take min-max for actual segment-level timestamp
if seg_start_actual is None and start is not None:
seg_start_actual = start
if end is not None:
seg_end_actual = end
prev_t2 = segment["end"]
segment_align_success = True
# end while True loop
break
# reset prev_t2 due to drifting issues
if not segment_align_success:
prev_t2 = 0
start = interpolate_nans(pd.DataFrame(aligned_subsegments)["start"], method=interpolate_method)
end = interpolate_nans(pd.DataFrame(aligned_subsegments)["end"], method=interpolate_method)
for idx, seg in enumerate(aligned_subsegments):
seg['start'] = start.iloc[idx]
seg['end'] = end.iloc[idx]
aligned_segments += aligned_subsegments
# create word level segments for .srt
word_seg = []
for seg in aligned_segments:
if model_lang in LANGUAGES_WITHOUT_SPACES:
# character based
seg["word-segments"] = seg["char-segments"]
seg["word-segments"]["segment-text-start"] = range(len(seg['word-segments']['start']))
seg["word-segments"]["segment-text-end"] = range(1, len(seg['word-segments']['start'])+1)
wseg = pd.DataFrame(seg["word-segments"]).replace({np.nan:None})
for wdx, wrow in wseg.iterrows():
if wrow["start"] is not None:
word_seg.append(
{
"start": wrow["start"],
"end": wrow["end"],
"text": seg["text"][int(wrow["segment-text-start"]):int(wrow["segment-text-end"])]
}
)
return {"segments": aligned_segments, "word_segments": word_seg}
def load_align_model(language_code, device, model_name=None):
if model_name is None:
# use default model
if language_code in DEFAULT_ALIGN_MODELS_TORCH:
model_name = DEFAULT_ALIGN_MODELS_TORCH[language_code]
elif language_code in DEFAULT_ALIGN_MODELS_HF:
model_name = DEFAULT_ALIGN_MODELS_HF[language_code]
else:
print(f"There is no default alignment model set for this language ({language_code}).\
Please find a wav2vec2.0 model finetuned on this language in https://huggingface.co/models, then pass the model name in --align_model [MODEL_NAME]")
raise ValueError(f"No default align-model for language: {language_code}")
if model_name in torchaudio.pipelines.__all__:
pipeline_type = "torchaudio"
bundle = torchaudio.pipelines.__dict__[model_name]
align_model = bundle.get_model().to(device)
labels = bundle.get_labels()
align_dictionary = {c.lower(): i for i, c in enumerate(labels)}
else:
try:
processor = AutoProcessor.from_pretrained(model_name)
align_model = Wav2Vec2ForCTC.from_pretrained(model_name)
except Exception as e:
print(e)
print(f"Error loading model from huggingface, check https://huggingface.co/models for finetuned wav2vec2.0 models")
raise ValueError(f'The chosen align_model "{model_name}" could not be found in huggingface (https://huggingface.co/models) or torchaudio (https://pytorch.org/audio/stable/pipelines.html#id14)')
pipeline_type = "huggingface"
align_model = align_model.to(device)
labels = processor.tokenizer.get_vocab()
align_dictionary = {char.lower(): code for char,code in processor.tokenizer.get_vocab().items()}
align_metadata = {"language": language_code, "dictionary": align_dictionary, "type": pipeline_type}
return align_model, align_metadata
def merge_chunks(segments, chunk_size=CHUNK_LENGTH):
"""
Merge VAD segments into larger segments of size ~CHUNK_LENGTH.
Merge VAD segments into larger segments of approximately size ~CHUNK_LENGTH.
TODO: Make sure VAD segment isn't too long, otherwise it will cause OOM when input to alignment model
TODO: Or sliding window alignment model over long segment.
"""
curr_start = 0
curr_end = 0
@@ -702,58 +340,6 @@ def transcribe_with_vad(
return output
def assign_word_speakers(diarize_df, result_segments, fill_nearest=False):
for seg in result_segments:
wdf = pd.DataFrame(seg['word-segments'])
if len(wdf['start'].dropna()) == 0:
wdf['start'] = seg['start']
wdf['end'] = seg['end']
speakers = []
for wdx, wrow in wdf.iterrows():
diarize_df['intersection'] = np.minimum(diarize_df['end'], wrow['end']) - np.maximum(diarize_df['start'], wrow['start'])
diarize_df['union'] = np.maximum(diarize_df['end'], wrow['end']) - np.minimum(diarize_df['start'], wrow['start'])
# remove no hit
if not fill_nearest:
dia_tmp = diarize_df[diarize_df['intersection'] > 0]
else:
dia_tmp = diarize_df
if len(dia_tmp) == 0:
speaker = None
else:
speaker = dia_tmp.sort_values("intersection", ascending=False).iloc[0][2]
speakers.append(speaker)
seg['word-segments']['speaker'] = speakers
seg["speaker"] = pd.Series(speakers).value_counts().index[0]
# create word level segments for .srt
word_seg = []
for seg in result_segments:
wseg = pd.DataFrame(seg["word-segments"])
for wdx, wrow in wseg.iterrows():
if wrow["start"] is not None:
speaker = wrow['speaker']
if speaker is None or speaker == np.nan:
speaker = "UNKNOWN"
word_seg.append(
{
"start": wrow["start"],
"end": wrow["end"],
"text": f"[{speaker}]: " + seg["text"][int(wrow["segment-text-start"]):int(wrow["segment-text-end"])]
}
)
# TODO: create segments but split words on new speaker
return result_segments, word_seg
class Segment:
def __init__(self, start, end, speaker=None):
self.start = start
self.end = end
self.speaker = speaker
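These diarization helpers are removed from transcribe.py here; judging by the `from .diarize import assign_word_speakers, Segment` line added at the top of the file, they now live in whisperx/diarize.py. Callers would then use them roughly like this (module path inferred from that import; diarize_df is assumed to be a DataFrame with start/end/speaker columns, as in the removed code):

from whisperx.diarize import assign_word_speakers

result_segments, word_segments = assign_word_speakers(diarize_df, result_aligned["segments"])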
def cli():
from . import available_models
@@ -776,7 +362,7 @@ def cli():
parser.add_argument("--max_speakers", default=None, type=int)
# output save params
parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
parser.add_argument("--output_type", default="all", choices=["all", "srt", "srt-word", "vtt", "txt", "tsv", "ass", "ass-char"], help="File type for desired output save")
parser.add_argument("--output_type", default="all", choices=["all", "srt", "srt-word", "vtt", "txt", "tsv", "ass", "ass-char", "pickle"], help="File type for desired output save")
parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")
@@ -868,6 +454,7 @@ def cli():
print(f"New language found ({result['language']})! Previous was ({align_metadata['language']}), loading new alignment model for new language...")
align_model, align_metadata = load_align_model(result["language"], device)
print("Performing alignment...")
result_aligned = align(result["segments"], align_model, align_metadata, audio_path, device,
extend_duration=align_extend, start_from_previous=align_from_prev, interpolate_method=interpolate_method)
@@ -915,10 +502,16 @@ def cli():
with open(os.path.join(output_dir, audio_basename + ".ass"), "w", encoding="utf-8") as ass:
write_ass(result_aligned["segments"], file=ass)
# save ASS character-level
# # save ASS character-level
if output_type in ["ass-char", "all"]:
if output_type in ["ass-char"]:
with open(os.path.join(output_dir, audio_basename + ".char.ass"), "w", encoding="utf-8") as ass:
write_ass(result_aligned["segments"], file=ass, resolution="char")
# save aligned segments as pickle
if output_type in ["pickle"]:
exp_fp = os.path.join(output_dir, audio_basename + ".pkl")
pd.DataFrame(result_aligned["segments"]).to_pickle(exp_fp)
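The new pickle output writes the aligned segments straight to a pandas pickle, so reading them back is a one-liner (the path below is a placeholder matching the naming scheme above):

import pandas as pd

segments_df = pd.read_pickle("audio.pkl")  # one row per aligned segment
print(segments_df[["start", "end", "text"]].head())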
if __name__ == "__main__":
cli()

whisperx/utils.py (View File)

@@ -2,6 +2,7 @@ import os
import zlib
from typing import Callable, TextIO, Iterator, Tuple
import pandas as pd
import numpy as np
def exact_div(x, y):
assert x % y == 0
@@ -214,7 +215,7 @@ def write_ass(transcript: Iterator[dict],
else:
speaker_str = ""
for cdx, crow in res_segs.iterrows():
if crow['start'] is not None:
if not np.isnan(crow['start']):
if resolution == "char":
idx_0 = cdx
idx_1 = cdx + 1
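The switch from an `is not None` check to `np.isnan` follows from the pandas rewrite: missing character timestamps are now float NaN in the DataFrame rather than None, and NaN still passes an `is not None` test. A tiny standalone illustration (not part of the diff):

import numpy as np
import pandas as pd

row = pd.Series({"start": np.nan})
print(row["start"] is not None)    # True  -> the old check would treat a missing time as valid
print(not np.isnan(row["start"]))  # False -> the new check correctly skips it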