mirror of https://github.com/m-bain/whisperX.git

clean up logic, use pandas where possible
@@ -2,7 +2,7 @@

## Other Languages

For non-english ASR, it is best to use the `large` whisper model. Alignment models are automatically picked by the chosen language from the default [lists](https://github.com/m-bain/whisperX/blob/e909f2f766b23b2000f2d95df41f9b844ac53e49/whisperx/transcribe.py#L22).
For non-english ASR, it is best to use the `large` whisper model. Alignment models are automatically picked by the chosen language from the default [lists](https://github.com/m-bain/whisperX/blob/main/whisperx/alignment.py#L18).

Default models are currently provided and tested for {en, fr, de, es, it, ja, zh, nl}.
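The commit below also re-exports `load_align_model` and `align` from the new `whisperx.alignment` module, so alignment can be driven from Python as well as from the CLI. A minimal usage sketch (the model size, device, and file name are placeholders, and `whisperx.load_model` is assumed to be exported as in upstream Whisper):

```python
import whisperx

device = "cuda"
audio_file = "audio.mp3"

# transcribe with whisperX's copy of Whisper
model = whisperx.load_model("large", device)          # assumption: exported as in upstream Whisper
result = whisperx.transcribe(model, audio_file)

# load the alignment model picked for the detected language, then force-align
model_a, metadata = whisperx.load_align_model(result["language"], device)
result_aligned = whisperx.align(result["segments"], model_a, metadata, audio_file, device)

print(result_aligned["word_segments"][:5])            # word-level timestamps
```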

@@ -11,8 +11,8 @@ from tqdm import tqdm

from .audio import load_audio, log_mel_spectrogram, pad_or_trim
from .decoding import DecodingOptions, DecodingResult, decode, detect_language
from .model import Whisper, ModelDimensions
from .transcribe import transcribe, load_align_model, align, transcribe_with_vad
from .transcribe import transcribe, transcribe_with_vad
from .alignment import load_align_model, align

_MODELS = {
    "tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",
    "tiny": "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt",
@@ -1,9 +1,412 @@

"""
Forced Alignment with Whisper
C. Max Bain
"""
import numpy as np
import pandas as pd
from typing import List, Union, Iterator, TYPE_CHECKING
from transformers import AutoProcessor, Wav2Vec2ForCTC
import torchaudio
import torch
from dataclasses import dataclass
from .audio import SAMPLE_RATE, load_audio
from .utils import interpolate_nans


LANGUAGES_WITHOUT_SPACES = ["ja", "zh"]

DEFAULT_ALIGN_MODELS_TORCH = {
    "en": "WAV2VEC2_ASR_BASE_960H",
    "fr": "VOXPOPULI_ASR_BASE_10K_FR",
    "de": "VOXPOPULI_ASR_BASE_10K_DE",
    "es": "VOXPOPULI_ASR_BASE_10K_ES",
    "it": "VOXPOPULI_ASR_BASE_10K_IT",
}

DEFAULT_ALIGN_MODELS_HF = {
    "ja": "jonatasgrosman/wav2vec2-large-xlsr-53-japanese",
    "zh": "jonatasgrosman/wav2vec2-large-xlsr-53-chinese-zh-cn",
    "nl": "jonatasgrosman/wav2vec2-large-xlsr-53-dutch",
    "uk": "Yehor/wav2vec2-xls-r-300m-uk-with-small-lm",
    "pt": "jonatasgrosman/wav2vec2-large-xlsr-53-portuguese",
}

def load_align_model(language_code, device, model_name=None):
    if model_name is None:
        # use default model
        if language_code in DEFAULT_ALIGN_MODELS_TORCH:
            model_name = DEFAULT_ALIGN_MODELS_TORCH[language_code]
        elif language_code in DEFAULT_ALIGN_MODELS_HF:
            model_name = DEFAULT_ALIGN_MODELS_HF[language_code]
        else:
            print(f"There is no default alignment model set for this language ({language_code}).\
                Please find a wav2vec2.0 model finetuned on this language in https://huggingface.co/models, then pass the model name in --align_model [MODEL_NAME]")
            raise ValueError(f"No default align-model for language: {language_code}")

    if model_name in torchaudio.pipelines.__all__:
        pipeline_type = "torchaudio"
        bundle = torchaudio.pipelines.__dict__[model_name]
        align_model = bundle.get_model().to(device)
        labels = bundle.get_labels()
        align_dictionary = {c.lower(): i for i, c in enumerate(labels)}
    else:
        try:
            processor = AutoProcessor.from_pretrained(model_name)
            align_model = Wav2Vec2ForCTC.from_pretrained(model_name)
        except Exception as e:
            print(e)
            print(f"Error loading model from huggingface, check https://huggingface.co/models for finetuned wav2vec2.0 models")
            raise ValueError(f'The chosen align_model "{model_name}" could not be found in huggingface (https://huggingface.co/models) or torchaudio (https://pytorch.org/audio/stable/pipelines.html#id14)')
        pipeline_type = "huggingface"
        align_model = align_model.to(device)
        labels = processor.tokenizer.get_vocab()
        align_dictionary = {char.lower(): code for char,code in processor.tokenizer.get_vocab().items()}

    align_metadata = {"language": language_code, "dictionary": align_dictionary, "type": pipeline_type}

    return align_model, align_metadata
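
# Illustrative usage (a sketch, not part of this module): on a CUDA machine,
#   model_a, metadata = load_align_model("en", "cuda")
# returns the wav2vec2 model plus metadata whose "dictionary" maps lowercase
# characters (with "|" standing in for spaces) to the token ids used below.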


def align(
    transcript: Iterator[dict],
    model: torch.nn.Module,
    align_model_metadata: dict,
    audio: Union[str, np.ndarray, torch.Tensor],
    device: str,
    extend_duration: float = 0.0,
    start_from_previous: bool = True,
    interpolate_method: str = "nearest",
):
"""
|
||||
Force align phoneme recognition predictions to known transcription
|
||||
|
||||
Parameters
|
||||
----------
|
||||
transcript: Iterator[dict]
|
||||
The Whisper model instance
|
||||
|
||||
model: torch.nn.Module
|
||||
Alignment model (wav2vec2)
|
||||
|
||||
audio: Union[str, np.ndarray, torch.Tensor]
|
||||
The path to the audio file to open, or the audio waveform
|
||||
|
||||
device: str
|
||||
cuda device
|
||||
|
||||
diarization: pd.DataFrame {'start': List[float], 'end': List[float], 'speaker': List[float]}
|
||||
diarization segments with speaker labels.
|
||||
|
||||
extend_duration: float
|
||||
Amount to pad input segments by. If not using vad--filter then recommended to use 2 seconds
|
||||
|
||||
If the gzip compression ratio is above this value, treat as failed
|
||||
|
||||
interpolate_method: str ["nearest", "linear", "ignore"]
|
||||
Method to assign timestamps to non-aligned words. Words are not able to be aligned when none of the characters occur in the align model dictionary.
|
||||
"nearest" copies timestamp of nearest word within the segment. "linear" is linear interpolation. "drop" removes that word from output.
|
||||
|
||||
Returns
|
||||
-------
|
||||
A dictionary containing the resulting text ("text") and segment-level details ("segments"), and
|
||||
the spoken language ("language"), which is detected when `decode_options["language"]` is None.
|
||||
"""
    if not torch.is_tensor(audio):
        if isinstance(audio, str):
            audio = load_audio(audio)
        audio = torch.from_numpy(audio)
    if len(audio.shape) == 1:
        audio = audio.unsqueeze(0)

    MAX_DURATION = audio.shape[1] / SAMPLE_RATE

    model_dictionary = align_model_metadata["dictionary"]
    model_lang = align_model_metadata["language"]
    model_type = align_model_metadata["type"]

    aligned_segments = []

    prev_t2 = 0

    char_segments_arr = {
        "segment-idx": [],
        "subsegment-idx": [],
        "word-idx": [],
        "char": [],
        "start": [],
        "end": [],
        "score": [],
    }

    for sdx, segment in enumerate(transcript):
        while True:
            segment_align_success = False

            # strip spaces at beginning / end, but keep track of the amount.
            num_leading = len(segment["text"]) - len(segment["text"].lstrip())
            num_trailing = len(segment["text"]) - len(segment["text"].rstrip())
            transcription = segment["text"]

            # TODO: convert number tokenizer / symbols to phonetic words for alignment.
            # e.g. "$300" -> "three hundred dollars"
            # currently "$300" is ignored since no characters present in the phonetic dictionary

            # split into words
            if model_lang not in LANGUAGES_WITHOUT_SPACES:
                per_word = transcription.split(" ")
            else:
                per_word = transcription

            # first check that characters in transcription can be aligned (they are contained in the align model's dictionary)
            clean_char, clean_cdx = [], []
            for cdx, char in enumerate(transcription):
                char_ = char.lower()
                # wav2vec2 models use "|" character to represent spaces
                if model_lang not in LANGUAGES_WITHOUT_SPACES:
                    char_ = char_.replace(" ", "|")

                # ignore whitespace at beginning and end of transcript
                if cdx < num_leading:
                    pass
                elif cdx > len(transcription) - num_trailing - 1:
                    pass
                elif char_ in model_dictionary.keys():
                    clean_char.append(char_)
                    clean_cdx.append(cdx)

            clean_wdx = []
            for wdx, wrd in enumerate(per_word):
                if any([c in model_dictionary.keys() for c in wrd]):
                    clean_wdx.append(wdx)

            # if no characters are in the dictionary, then we skip this segment...
            if len(clean_char) == 0:
                print("Failed to align segment: no characters in this segment found in model dictionary, resorting to original...")
                break

            transcription_cleaned = "".join(clean_char)
            tokens = [model_dictionary[c] for c in transcription_cleaned]

            # pad according to original timestamps
            t1 = max(segment["start"] - extend_duration, 0)
            t2 = min(segment["end"] + extend_duration, MAX_DURATION)

            # use prev_t2 as current t1 if it's later
            if start_from_previous and t1 < prev_t2:
                t1 = prev_t2

            # check if timestamp range is still valid
            if t1 >= MAX_DURATION:
                print("Failed to align segment: original start time longer than audio duration, skipping...")
                break
            if t2 - t1 < 0.02:
                print("Failed to align segment: duration smaller than 0.02s time precision")
                break

            f1 = int(t1 * SAMPLE_RATE)
            f2 = int(t2 * SAMPLE_RATE)

            waveform_segment = audio[:, f1:f2]

            with torch.inference_mode():
                if model_type == "torchaudio":
                    emissions, _ = model(waveform_segment.to(device))
                elif model_type == "huggingface":
                    emissions = model(waveform_segment.to(device)).logits
                else:
                    raise NotImplementedError(f"Align model of type {model_type} not supported.")
                emissions = torch.log_softmax(emissions, dim=-1)

            emission = emissions[0].cpu().detach()

            trellis = get_trellis(emission, tokens)
            path = backtrack(trellis, emission, tokens)
            if path is None:
                print("Failed to align segment: backtrack failed, resorting to original...")
                break
            char_segments = merge_repeats(path, transcription_cleaned)
            # word_segments = merge_words(char_segments)

            # sub-segments
            if "seg-text" not in segment:
                segment["seg-text"] = [transcription]

            v = 0
            seg_lens = [0] + [len(x) for x in segment["seg-text"]]
            seg_lens_cumsum = [v := v + n for n in seg_lens]
            sub_seg_idx = 0

            wdx = 0
            duration = t2 - t1
            ratio = duration * waveform_segment.size(0) / (trellis.size(0) - 1)
            for cdx, char in enumerate(transcription + " "):
                is_last = False
                if cdx == len(transcription):
                    break
                elif cdx+1 == len(transcription):
                    is_last = True

                start, end, score = None, None, None
                if cdx in clean_cdx:
                    char_seg = char_segments[clean_cdx.index(cdx)]
                    start = char_seg.start * ratio + t1
                    end = char_seg.end * ratio + t1
                    score = char_seg.score

                char_segments_arr["char"].append(char)
                char_segments_arr["start"].append(start)
                char_segments_arr["end"].append(end)
                char_segments_arr["score"].append(score)
                char_segments_arr["word-idx"].append(wdx)
                char_segments_arr["segment-idx"].append(sdx)
                char_segments_arr["subsegment-idx"].append(sub_seg_idx)

                # word-level info
                if model_lang in LANGUAGES_WITHOUT_SPACES:
                    # character == word
                    wdx += 1
                elif is_last or transcription[cdx+1] == " " or cdx == seg_lens_cumsum[sub_seg_idx+1] - 1:
                    wdx += 1

                if is_last or cdx == seg_lens_cumsum[sub_seg_idx+1] - 1:
                    wdx = 0
                    sub_seg_idx += 1

            prev_t2 = segment["end"]

            segment_align_success = True
            # end while True loop
            break

        # reset prev_t2 due to drifting issues
        if not segment_align_success:
            prev_t2 = 0

    char_segments_arr = pd.DataFrame(char_segments_arr)
    not_space = char_segments_arr["char"] != " "

    per_seg_grp = char_segments_arr.groupby(["segment-idx", "subsegment-idx"], as_index = False)
    char_segments_arr = per_seg_grp.apply(lambda x: x.reset_index(drop = True)).reset_index()
    per_word_grp = char_segments_arr[not_space].groupby(["segment-idx", "subsegment-idx", "word-idx"])
    per_subseg_grp = char_segments_arr[not_space].groupby(["segment-idx", "subsegment-idx"])
    per_seg_grp = char_segments_arr[not_space].groupby(["segment-idx"])

    word_segments_arr = {}

    # start of word is first char with a timestamp
    word_segments_arr["start"] = per_word_grp["start"].min().reset_index()["start"]
    # end of word is last char with a timestamp
    word_segments_arr["end"] = per_word_grp["end"].max().reset_index()["end"]
    # score of word is mean (excluding nan)
    word_segments_arr["score"] = per_word_grp["score"].mean().reset_index()["score"]

    word_segments_arr["segment-text-start"] = per_word_grp["level_1"].min().reset_index()["level_1"]
    word_segments_arr["segment-text-end"] = per_word_grp["level_1"].max().reset_index()["level_1"] + 1
    word_segments_arr["segment-idx"] = per_word_grp["level_1"].min().reset_index()["segment-idx"]

    word_segments_arr = pd.DataFrame(word_segments_arr)
    word_segments_arr[["segment-idx", "subsegment-idx", "word-idx"]] = per_word_grp["level_1"].min().reset_index()[["segment-idx", "subsegment-idx", "word-idx"]]

    segments_arr = {}
    segments_arr["start"] = per_subseg_grp["start"].min().reset_index()["start"]
    segments_arr["end"] = per_subseg_grp["end"].min().reset_index()["end"]
    segments_arr = pd.DataFrame(segments_arr)
    segments_arr[["segment-idx", "subsegment-idx-start"]] = per_subseg_grp["start"].min().reset_index()[["segment-idx", "subsegment-idx"]]
    segments_arr["subsegment-idx-end"] = segments_arr["subsegment-idx-start"] + 1

    # interpolate missing words / sub-segments
    if interpolate_method != "ignore":
        wrd_subseg_grp = word_segments_arr.groupby(["segment-idx", "subsegment-idx"])
        wrd_seg_grp = word_segments_arr.groupby(["segment-idx"])
        # we still know which word timestamps are interpolated because their score == nan
        word_segments_arr["start"] = wrd_subseg_grp['start'].apply(lambda group: interpolate_nans(group, method=interpolate_method))
        word_segments_arr["end"] = wrd_subseg_grp['end'].apply(lambda group: interpolate_nans(group, method=interpolate_method))

        word_segments_arr["start"] = wrd_seg_grp['start'].apply(lambda group: interpolate_nans(group, method=interpolate_method))
        word_segments_arr["end"] = wrd_seg_grp['end'].apply(lambda group: interpolate_nans(group, method=interpolate_method))

        sub_seg_grp = segments_arr.groupby(["segment-idx"])
        segments_arr['start'] = sub_seg_grp['start'].apply(lambda group: interpolate_nans(group, method=interpolate_method))
        segments_arr['end'] = sub_seg_grp['end'].apply(lambda group: interpolate_nans(group, method=interpolate_method))
        # merge subsegments which are missing times
        # group by sub seg and time.
        seg_grp_dup = segments_arr.groupby(["segment-idx", "start", "end"])
        segments_arr["subsegment-idx-start"] = seg_grp_dup["subsegment-idx-start"].transform(min)
        segments_arr["subsegment-idx-end"] = seg_grp_dup["subsegment-idx-end"].transform(max)
        segments_arr.drop_duplicates(subset=["segment-idx", "subsegment-idx-start", "subsegment-idx-end"], inplace=True)
    else:
        word_segments_arr.dropna(inplace=True)
        segments_arr.dropna(inplace=True)

    aligned_segments = []
    aligned_segments_word = []

    word_segments_arr.set_index(["segment-idx", "subsegment-idx"], inplace=True)
    char_segments_arr.set_index(["segment-idx", "subsegment-idx", "word-idx"], inplace=True)

    for sdx, srow in segments_arr.iterrows():

        seg_idx = int(srow["segment-idx"])
        sub_start = int(srow["subsegment-idx-start"])
        sub_end = int(srow["subsegment-idx-end"])

        seg = transcript[seg_idx]
        text = "".join(seg["seg-text"][sub_start:sub_end])

        wseg = word_segments_arr.loc[seg_idx].loc[sub_start:sub_end-1]
        cseg = char_segments_arr.loc[seg_idx].loc[sub_start:sub_end-1]
        cseg['segment-text-start'] = cseg['level_1']
        cseg['segment-text-end'] = cseg['level_1'] + 1
        del cseg['level_1']
        del cseg['level_0']
        cseg.reset_index(inplace=True)
        aligned_segments.append(
            {
                "start": srow["start"],
                "end": srow["end"],
                "text": text,
                "word-segments": wseg,
                "char-segments": cseg
            }
        )

        def get_raw_text(word_row):
            return seg["seg-text"][word_row.name][int(word_row["segment-text-start"]):int(word_row["segment-text-end"])+1]

        wdx = 0
        curr_text = get_raw_text(wseg.iloc[wdx])
        if len(wseg) > 1:
            for _, wrow in wseg.iloc[1:].iterrows():
                if wrow['start'] != wseg.iloc[wdx]['start']:
                    aligned_segments_word.append(
                        {
                            "text": curr_text.strip(),
                            "start": wseg.iloc[wdx]["start"],
                            "end": wseg.iloc[wdx]["end"],
                        }
                    )
                    curr_text = ""
                curr_text += " " + get_raw_text(wrow)
                wdx += 1
        aligned_segments_word.append(
            {
                "text": curr_text.strip(),
                "start": wseg.iloc[wdx]["start"],
                "end": wseg.iloc[wdx]["end"]
            }
        )

    return {"segments": aligned_segments, "word_segments": aligned_segments_word}
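

# `interpolate_nans` is imported from .utils above. A minimal sketch of such a
# helper is given here for illustration only (an assumption about its behaviour,
# not the repository's exact implementation): fill missing word timestamps from
# neighbouring values, falling back to forward/back fill when too few values are
# known to interpolate.
def _interpolate_nans_sketch(series: pd.Series, method: str = "nearest") -> pd.Series:
    if series.notnull().sum() > 1:
        return series.interpolate(method=method).ffill().bfill()
    return series.ffill().bfill()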


"""
source: https://pytorch.org/tutorials/intermediate/forced_alignment_with_torchaudio_tutorial.html
"""
import torch
from dataclasses import dataclass

def get_trellis(emission, tokens, blank_id=0):
    num_frame = emission.size(0)
    num_tokens = len(tokens)

@@ -5,12 +5,11 @@ from typing import List, Optional, Tuple, Union, Iterator, TYPE_CHECKING

import numpy as np
import torch
import torchaudio
from transformers import AutoProcessor, Wav2Vec2ForCTC
import tqdm
from .audio import SAMPLE_RATE, N_FRAMES, HOP_LENGTH, CHUNK_LENGTH, pad_or_trim, log_mel_spectrogram, load_audio
from .alignment import get_trellis, backtrack, merge_repeats, merge_words
from .alignment import load_align_model, align, get_trellis, backtrack, merge_repeats, merge_words
from .decoding import DecodingOptions, DecodingResult
from .diarize import assign_word_speakers, Segment
from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
from .utils import exact_div, format_timestamp, optional_int, optional_float, str2bool, interpolate_nans, write_txt, write_vtt, write_srt, write_ass, write_tsv
import pandas as pd

@@ -18,23 +17,6 @@ import pandas as pd

if TYPE_CHECKING:
    from .model import Whisper

LANGUAGES_WITHOUT_SPACES = ["ja", "zh"]

DEFAULT_ALIGN_MODELS_TORCH = {
    "en": "WAV2VEC2_ASR_BASE_960H",
    "fr": "VOXPOPULI_ASR_BASE_10K_FR",
    "de": "VOXPOPULI_ASR_BASE_10K_DE",
    "es": "VOXPOPULI_ASR_BASE_10K_ES",
    "it": "VOXPOPULI_ASR_BASE_10K_IT",
}

DEFAULT_ALIGN_MODELS_HF = {
    "ja": "jonatasgrosman/wav2vec2-large-xlsr-53-japanese",
    "zh": "jonatasgrosman/wav2vec2-large-xlsr-53-chinese-zh-cn",
    "nl": "jonatasgrosman/wav2vec2-large-xlsr-53-dutch",
    "uk": "Yehor/wav2vec2-xls-r-300m-uk-with-small-lm",
    "pt": "jonatasgrosman/wav2vec2-large-xlsr-53-portuguese",
}


def transcribe(

@@ -273,355 +255,11 @@ def transcribe(
    return dict(text=tokenizer.decode(all_tokens[len(initial_prompt):]), segments=all_segments, language=language)


def align(
    transcript: Iterator[dict],
    model: torch.nn.Module,
    align_model_metadata: dict,
    audio: Union[str, np.ndarray, torch.Tensor],
    device: str,
    extend_duration: float = 0.0,
    start_from_previous: bool = True,
    interpolate_method: str = "nearest",
):
    """
    Force align phoneme recognition predictions to known transcription

    Parameters
    ----------
    transcript: Iterator[dict]
        The Whisper model instance

    model: torch.nn.Module
        Alignment model (wav2vec2)

    audio: Union[str, np.ndarray, torch.Tensor]
        The path to the audio file to open, or the audio waveform

    device: str
        cuda device

    extend_duration: float
        Amount to pad input segments by. If not using vad--filter then recommended to use 2 seconds

        If the gzip compression ratio is above this value, treat as failed

    interpolate_method: str ["nearest", "linear", "ignore"]
        Method to assign timestamps to non-aligned words. Words are not able to be aligned when none of the characters occur in the align model dictionary.
        "nearest" copies timestamp of nearest word within the segment. "linear" is linear interpolation. "drop" removes that word from output.

    Returns
    -------
    A dictionary containing the resulting text ("text") and segment-level details ("segments"), and
    the spoken language ("language"), which is detected when `decode_options["language"]` is None.
    """
    if not torch.is_tensor(audio):
        if isinstance(audio, str):
            audio = load_audio(audio)
        audio = torch.from_numpy(audio)
    if len(audio.shape) == 1:
        audio = audio.unsqueeze(0)

    MAX_DURATION = audio.shape[1] / SAMPLE_RATE

    model_dictionary = align_model_metadata["dictionary"]
    model_lang = align_model_metadata["language"]
    model_type = align_model_metadata["type"]

    aligned_segments = []

    prev_t2 = 0
    for segment in transcript:
        aligned_subsegments = []
        while True:
            segment_align_success = False

            # strip spaces at beginning / end, but keep track of the amount.
            num_leading = len(segment["text"]) - len(segment["text"].lstrip())
            num_trailing = len(segment["text"]) - len(segment["text"].rstrip())
            transcription = segment["text"]

            # TODO: convert number tokenizer / symbols to phonetic words for alignment.
            # e.g. "$300" -> "three hundred dollars"
            # currently "$300" is ignored since no characters present in the phonetic dictionary

            # split into words
            if model_lang not in LANGUAGES_WITHOUT_SPACES:
                per_word = transcription.split(" ")
            else:
                per_word = transcription

            # first check that characters in transcription can be aligned (they are contained in align model"s dictionary)
            clean_char, clean_cdx = [], []
            for cdx, char in enumerate(transcription):
                char_ = char.lower()
                # wav2vec2 models use "|" character to represent spaces
                if model_lang not in LANGUAGES_WITHOUT_SPACES:
                    char_ = char_.replace(" ", "|")

                # ignore whitespace at beginning and end of transcript
                if cdx < num_leading:
                    pass
                elif cdx > len(transcription) - num_trailing - 1:
                    pass
                elif char_ in model_dictionary.keys():
                    clean_char.append(char_)
                    clean_cdx.append(cdx)

            clean_wdx = []
            for wdx, wrd in enumerate(per_word):
                if any([c in model_dictionary.keys() for c in wrd]):
                    clean_wdx.append(wdx)

            # if no characters are in the dictionary, then we skip this segment...
            if len(clean_char) == 0:
                print("Failed to align segment: no characters in this segment found in model dictionary, resorting to original...")
                break

            transcription_cleaned = "".join(clean_char)
            tokens = [model_dictionary[c] for c in transcription_cleaned]

            # pad according original timestamps
            t1 = max(segment["start"] - extend_duration, 0)
            t2 = min(segment["end"] + extend_duration, MAX_DURATION)

            # use prev_t2 as current t1 if it"s later
            if start_from_previous and t1 < prev_t2:
                t1 = prev_t2

            # check if timestamp range is still valid
            if t1 >= MAX_DURATION:
                print("Failed to align segment: original start time longer than audio duration, skipping...")
                break
            if t2 - t1 < 0.02:
                print("Failed to align segment: duration smaller than 0.02s time precision")
                break

            f1 = int(t1 * SAMPLE_RATE)
            f2 = int(t2 * SAMPLE_RATE)

            waveform_segment = audio[:, f1:f2]

            with torch.inference_mode():
                if model_type == "torchaudio":
                    emissions, _ = model(waveform_segment.to(device))
                elif model_type == "huggingface":
                    emissions = model(waveform_segment.to(device)).logits
                else:
                    raise NotImplementedError(f"Align model of type {model_type} not supported.")
                emissions = torch.log_softmax(emissions, dim=-1)

            emission = emissions[0].cpu().detach()

            trellis = get_trellis(emission, tokens)
            path = backtrack(trellis, emission, tokens)
            if path is None:
                print("Failed to align segment: backtrack failed, resorting to original...")
                break
            char_segments = merge_repeats(path, transcription_cleaned)
            # word_segments = merge_words(char_segments)

            # sub-segments
            if "seg-text" not in segment:
                segment["seg-text"] = [transcription]

            v = 0
            seg_lens = [0] + [len(x) for x in segment["seg-text"]]
            seg_lens_cumsum = [v := v + n for n in seg_lens]
            sub_seg_idx = 0

            char_level = {
                "start": [],
                "end": [],
                "score": [],
                "word-index": [],
            }

            word_level = {
                "start": [],
                "end": [],
                "score": [],
                "segment-text-start": [],
                "segment-text-end": []
            }

            wdx = 0
            seg_start_actual, seg_end_actual = None, None
            duration = t2 - t1
            ratio = duration * waveform_segment.size(0) / (trellis.size(0) - 1)
            cdx_prev = 0
            for cdx, char in enumerate(transcription + " "):
                is_last = False
                if cdx == len(transcription):
                    break
                elif cdx+1 == len(transcription):
                    is_last = True

                start, end, score = None, None, None
                if cdx in clean_cdx:
                    char_seg = char_segments[clean_cdx.index(cdx)]
                    start = char_seg.start * ratio + t1
                    end = char_seg.end * ratio + t1
                    score = char_seg.score

                char_level["start"].append(start)
                char_level["end"].append(end)
                char_level["score"].append(score)
                char_level["word-index"].append(wdx)

                # word-level info
                if model_lang in LANGUAGES_WITHOUT_SPACES:
                    # character == word
                    wdx += 1
                elif is_last or transcription[cdx+1] == " " or cdx == seg_lens_cumsum[sub_seg_idx+1] - 1:
                    wdx += 1
                    word_level["start"].append(None)
                    word_level["end"].append(None)
                    word_level["score"].append(None)
                    word_level["segment-text-start"].append(cdx_prev-seg_lens_cumsum[sub_seg_idx])
                    word_level["segment-text-end"].append(cdx+1-seg_lens_cumsum[sub_seg_idx])
                    cdx_prev = cdx+2

                if is_last or cdx == seg_lens_cumsum[sub_seg_idx+1] - 1:
                    if model_lang not in LANGUAGES_WITHOUT_SPACES:
                        char_level = pd.DataFrame(char_level)
                        word_level = pd.DataFrame(word_level)

                        not_space = pd.Series(list(segment["seg-text"][sub_seg_idx])) != " "
                        word_level["start"] = char_level[not_space].groupby("word-index")["start"].min() # take min of all chars in a word ignoring space
                        word_level["end"] = char_level[not_space].groupby("word-index")["end"].max() # take max of all chars in a word

                        # fill missing
                        if interpolate_method != "ignore":
                            word_level["start"] = interpolate_nans(word_level["start"], method=interpolate_method)
                            word_level["end"] = interpolate_nans(word_level["end"], method=interpolate_method)
                        word_level["start"] = word_level["start"].values.tolist()
                        word_level["end"] = word_level["end"].values.tolist()
                        word_level["score"] = char_level.groupby("word-index")["score"].mean() # take mean of all scores

                        char_level = char_level.replace({np.nan:None}).to_dict("list")
                        word_level = pd.DataFrame(word_level).replace({np.nan:None}).to_dict("list")
                    else:
                        word_level = None

                    aligned_subsegments.append(
                        {
                            "text": segment["seg-text"][sub_seg_idx],
                            "start": seg_start_actual,
                            "end": seg_end_actual,
                            "char-segments": char_level,
                            "word-segments": word_level
                        }
                    )
                    if "language" in segment:
                        aligned_subsegments[-1]["language"] = segment["language"]

                    char_level = {
                        "start": [],
                        "end": [],
                        "score": [],
                        "word-index": [],
                    }
                    word_level = {
                        "start": [],
                        "end": [],
                        "score": [],
                        "segment-text-start": [],
                        "segment-text-end": []
                    }
                    wdx = 0
                    cdx_prev = cdx + 2
                    sub_seg_idx += 1
                    seg_start_actual, seg_end_actual = None, None

                # take min-max for actual segment-level timestamp
                if seg_start_actual is None and start is not None:
                    seg_start_actual = start
                if end is not None:
                    seg_end_actual = end

            prev_t2 = segment["end"]

            segment_align_success = True
            # end while True loop
            break

        # reset prev_t2 due to drifting issues
        if not segment_align_success:
            prev_t2 = 0

        start = interpolate_nans(pd.DataFrame(aligned_subsegments)["start"], method=interpolate_method)
        end = interpolate_nans(pd.DataFrame(aligned_subsegments)["end"], method=interpolate_method)
        for idx, seg in enumerate(aligned_subsegments):
            seg['start'] = start.iloc[idx]
            seg['end'] = end.iloc[idx]

        aligned_segments += aligned_subsegments

    # create word level segments for .srt
    word_seg = []
    for seg in aligned_segments:
        if model_lang in LANGUAGES_WITHOUT_SPACES:
            # character based
            seg["word-segments"] = seg["char-segments"]
            seg["word-segments"]["segment-text-start"] = range(len(seg['word-segments']['start']))
            seg["word-segments"]["segment-text-end"] = range(1, len(seg['word-segments']['start'])+1)

        wseg = pd.DataFrame(seg["word-segments"]).replace({np.nan:None})
        for wdx, wrow in wseg.iterrows():
            if wrow["start"] is not None:
                word_seg.append(
                    {
                        "start": wrow["start"],
                        "end": wrow["end"],
                        "text": seg["text"][int(wrow["segment-text-start"]):int(wrow["segment-text-end"])]
                    }
                )

    return {"segments": aligned_segments, "word_segments": word_seg}

def load_align_model(language_code, device, model_name=None):
    if model_name is None:
        # use default model
        if language_code in DEFAULT_ALIGN_MODELS_TORCH:
            model_name = DEFAULT_ALIGN_MODELS_TORCH[language_code]
        elif language_code in DEFAULT_ALIGN_MODELS_HF:
            model_name = DEFAULT_ALIGN_MODELS_HF[language_code]
        else:
            print(f"There is no default alignment model set for this language ({language_code}).\
                Please find a wav2vec2.0 model finetuned on this language in https://huggingface.co/models, then pass the model name in --align_model [MODEL_NAME]")
            raise ValueError(f"No default align-model for language: {language_code}")

    if model_name in torchaudio.pipelines.__all__:
        pipeline_type = "torchaudio"
        bundle = torchaudio.pipelines.__dict__[model_name]
        align_model = bundle.get_model().to(device)
        labels = bundle.get_labels()
        align_dictionary = {c.lower(): i for i, c in enumerate(labels)}
    else:
        try:
            processor = AutoProcessor.from_pretrained(model_name)
            align_model = Wav2Vec2ForCTC.from_pretrained(model_name)
        except Exception as e:
            print(e)
            print(f"Error loading model from huggingface, check https://huggingface.co/models for finetuned wav2vec2.0 models")
            raise ValueError(f'The chosen align_model "{model_name}" could not be found in huggingface (https://huggingface.co/models) or torchaudio (https://pytorch.org/audio/stable/pipelines.html#id14)')
        pipeline_type = "huggingface"
        align_model = align_model.to(device)
        labels = processor.tokenizer.get_vocab()
        align_dictionary = {char.lower(): code for char,code in processor.tokenizer.get_vocab().items()}

    align_metadata = {"language": language_code, "dictionary": align_dictionary, "type": pipeline_type}

    return align_model, align_metadata


def merge_chunks(segments, chunk_size=CHUNK_LENGTH):
    """
    Merge VAD segments into larger segments of size ~CHUNK_LENGTH.
    Merge VAD segments into larger segments of approximately size ~CHUNK_LENGTH.
    TODO: Make sure VAD segment isn't too long, otherwise it will cause OOM when input to alignment model
    TODO: Or sliding window alignment model over long segment.
    """
    curr_start = 0
    curr_end = 0

@@ -702,58 +340,6 @@ def transcribe_with_vad(
    return output


def assign_word_speakers(diarize_df, result_segments, fill_nearest=False):

    for seg in result_segments:
        wdf = pd.DataFrame(seg['word-segments'])
        if len(wdf['start'].dropna()) == 0:
            wdf['start'] = seg['start']
            wdf['end'] = seg['end']
        speakers = []
        for wdx, wrow in wdf.iterrows():
            diarize_df['intersection'] = np.minimum(diarize_df['end'], wrow['end']) - np.maximum(diarize_df['start'], wrow['start'])
            diarize_df['union'] = np.maximum(diarize_df['end'], wrow['end']) - np.minimum(diarize_df['start'], wrow['start'])
            # remove no hit
            if not fill_nearest:
                dia_tmp = diarize_df[diarize_df['intersection'] > 0]
            else:
                dia_tmp = diarize_df
            if len(dia_tmp) == 0:
                speaker = None
            else:
                speaker = dia_tmp.sort_values("intersection", ascending=False).iloc[0][2]
            speakers.append(speaker)
        seg['word-segments']['speaker'] = speakers
        seg["speaker"] = pd.Series(speakers).value_counts().index[0]

    # create word level segments for .srt
    word_seg = []
    for seg in result_segments:
        wseg = pd.DataFrame(seg["word-segments"])
        for wdx, wrow in wseg.iterrows():
            if wrow["start"] is not None:
                speaker = wrow['speaker']
                if speaker is None or speaker == np.nan:
                    speaker = "UNKNOWN"
                word_seg.append(
                    {
                        "start": wrow["start"],
                        "end": wrow["end"],
                        "text": f"[{speaker}]: " + seg["text"][int(wrow["segment-text-start"]):int(wrow["segment-text-end"])]
                    }
                )

    # TODO: create segments but split words on new speaker

    return result_segments, word_seg
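
# Illustrative call (a sketch, not part of the module): the diarization frame is
# assumed to have "start", "end", and "speaker" columns, e.g. as built from a
# pyannote.audio diarization result.
#
#   diarize_df = pd.DataFrame({
#       "start": [0.0, 4.2],
#       "end": [4.1, 9.8],
#       "speaker": ["SPEAKER_00", "SPEAKER_01"],
#   })
#   segments, word_segments = assign_word_speakers(diarize_df, result_aligned["segments"])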


class Segment:
    def __init__(self, start, end, speaker=None):
        self.start = start
        self.end = end
        self.speaker = speaker


def cli():
    from . import available_models

@@ -776,7 +362,7 @@ def cli():
    parser.add_argument("--max_speakers", default=None, type=int)
    # output save params
    parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
    parser.add_argument("--output_type", default="all", choices=["all", "srt", "srt-word", "vtt", "txt", "tsv", "ass", "ass-char"], help="File type for desired output save")
    parser.add_argument("--output_type", default="all", choices=["all", "srt", "srt-word", "vtt", "txt", "tsv", "ass", "ass-char", "pickle"], help="File type for desired output save")

    parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")

@@ -868,6 +454,7 @@ def cli():
            print(f"New language found ({result['language']})! Previous was ({align_metadata['language']}), loading new alignment model for new language...")
            align_model, align_metadata = load_align_model(result["language"], device)


        print("Performing alignment...")
        result_aligned = align(result["segments"], align_model, align_metadata, audio_path, device,
                               extend_duration=align_extend, start_from_previous=align_from_prev, interpolate_method=interpolate_method)

@@ -915,10 +502,16 @@ def cli():
        with open(os.path.join(output_dir, audio_basename + ".ass"), "w", encoding="utf-8") as ass:
            write_ass(result_aligned["segments"], file=ass)

        # save ASS character-level
        if output_type in ["ass-char", "all"]:
        # # save ASS character-level
        if output_type in ["ass-char"]:
            with open(os.path.join(output_dir, audio_basename + ".char.ass"), "w", encoding="utf-8") as ass:
                write_ass(result_aligned["segments"], file=ass, resolution="char")

        # save word tsv
        if output_type in ["pickle"]:
            exp_fp = os.path.join(output_dir, audio_basename + ".pkl")
            pd.DataFrame(result_aligned["segments"]).to_pickle(exp_fp)
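            # The pickle can be reloaded later with pandas, e.g. pd.read_pickle(exp_fp)
            # returns the segments as a DataFrame.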


if __name__ == "__main__":
    cli()

@@ -2,6 +2,7 @@ import os
import zlib
from typing import Callable, TextIO, Iterator, Tuple
import pandas as pd
import numpy as np

def exact_div(x, y):
    assert x % y == 0

@@ -214,7 +215,7 @@ def write_ass(transcript: Iterator[dict],
        else:
            speaker_str = ""
        for cdx, crow in res_segs.iterrows():
            if crow['start'] is not None:
            if not np.isnan(crow['start']):
                if resolution == "char":
                    idx_0 = cdx
                    idx_1 = cdx + 1