fix long segments, break into sentences using nltk, improve align logic, improve diarize (sentence-based)

Max Bain
2023-05-07 15:32:58 +01:00
parent 07361ba1d7
commit 24008aa1ed
6 changed files with 269 additions and 574 deletions

View File

@ -5,3 +5,4 @@ transformers
ffmpeg-python==0.2.0
pandas
setuptools==65.6.3
nltk
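The new nltk dependency provides the Punkt sentence tokenizer that the alignment code now uses to split segments into sentences. A minimal sketch of the call used there; the example text is made up:

import nltk

# The pretrained "punkt" data is what nltk.sent_tokenize needs; the bare
# tokenizer below also works without it, but downloading is the common setup.
nltk.download("punkt", quiet=True)

text = "Thanks for coming. Today we will talk about memory. Let's begin."
tokenizer = nltk.tokenize.punkt.PunktSentenceTokenizer()

# span_tokenize yields (start, end) character offsets per sentence, which is
# what the alignment code stores as "sentence_spans".
for start, end in tokenizer.span_tokenize(text):
    print((start, end), text[start:end])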

View File

@ -13,6 +13,7 @@ from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
from .audio import SAMPLE_RATE, load_audio
from .utils import interpolate_nans
import nltk
LANGUAGES_WITHOUT_SPACES = ["ja", "zh"]
@ -84,44 +85,13 @@ def align(
align_model_metadata: dict,
audio: Union[str, np.ndarray, torch.Tensor],
device: str,
extend_duration: float = 0.0,
start_from_previous: bool = True,
interpolate_method: str = "nearest",
return_char_alignments: bool = False,
):
"""
Force align phoneme recognition predictions to known transcription Align phoneme recognition predictions to known transcription.
Parameters
----------
transcript: Iterator[dict]
The Whisper model instance
model: torch.nn.Module
Alignment model (wav2vec2)
audio: Union[str, np.ndarray, torch.Tensor]
The path to the audio file to open, or the audio waveform
device: str
cuda device
diarization: pd.DataFrame {'start': List[float], 'end': List[float], 'speaker': List[float]}
diarization segments with speaker labels.
extend_duration: float
Amount to pad input segments by. If not using vad--filter then recommended to use 2 seconds
If the gzip compression ratio is above this value, treat as failed
interpolate_method: str ["nearest", "linear", "ignore"]
Method to assign timestamps to non-aligned words. Words are not able to be aligned when none of the characters occur in the align model dictionary.
"nearest" copies timestamp of nearest word within the segment. "linear" is linear interpolation. "drop" removes that word from output.
Returns
-------
A dictionary containing the resulting text ("text") and segment-level details ("segments"), and
the spoken language ("language"), which is detected when `decode_options["language"]` is None.
""" """
if not torch.is_tensor(audio): if not torch.is_tensor(audio):
if isinstance(audio, str): if isinstance(audio, str):
audio = load_audio(audio) audio = load_audio(audio)
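interpolate_nans (imported from .utils) is what later backfills timestamps for words that could not be aligned; its body is not part of this diff. A plausible minimal sketch, assuming it operates on a pandas Series as the later calls suggest (method="nearest" requires scipy):

import pandas as pd

def interpolate_nans_sketch(x: pd.Series, method: str = "nearest") -> pd.Series:
    # Words that could not be aligned end up with NaN start/end times; borrow
    # timestamps from neighbours instead of dropping the words.
    if x.notnull().sum() > 1:
        return x.interpolate(method=method).ffill().bfill()
    return x.ffill().bfill()

starts = pd.Series([0.52, None, 1.87, None, 3.10])
print(interpolate_nans_sketch(starts).tolist())            # "nearest" needs scipy
print(interpolate_nans_sketch(starts, "linear").tolist())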
@ -135,335 +105,206 @@ def align(
model_lang = align_model_metadata["language"]
model_type = align_model_metadata["type"]
aligned_segments = [] # 1. Preprocess to keep only characters in dictionary
prev_t2 = 0
char_segments_arr = {
"segment-idx": [],
"subsegment-idx": [],
"word-idx": [],
"char": [],
"start": [],
"end": [],
"score": [],
}
for sdx, segment in enumerate(transcript): for sdx, segment in enumerate(transcript):
while True: # strip spaces at beginning / end, but keep track of the amount.
segment_align_success = False num_leading = len(segment["text"]) - len(segment["text"].lstrip())
num_trailing = len(segment["text"]) - len(segment["text"].rstrip())
text = segment["text"]
# strip spaces at beginning / end, but keep track of the amount. # split into words
num_leading = len(segment["text"]) - len(segment["text"].lstrip()) if model_lang not in LANGUAGES_WITHOUT_SPACES:
num_trailing = len(segment["text"]) - len(segment["text"].rstrip()) per_word = text.split(" ")
transcription = segment["text"] else:
per_word = text
# TODO: convert number tokenizer / symbols to phonetic words for alignment. clean_char, clean_cdx = [], []
# e.g. "$300" -> "three hundred dollars" for cdx, char in enumerate(text):
# currently "$300" is ignored since no characters present in the phonetic dictionary char_ = char.lower()
# wav2vec2 models use "|" character to represent spaces
# split into words
if model_lang not in LANGUAGES_WITHOUT_SPACES: if model_lang not in LANGUAGES_WITHOUT_SPACES:
per_word = transcription.split(" ") char_ = char_.replace(" ", "|")
else:
per_word = transcription
# first check that characters in transcription can be aligned (they are contained in align model"s dictionary) # ignore whitespace at beginning and end of transcript
clean_char, clean_cdx = [], [] if cdx < num_leading:
for cdx, char in enumerate(transcription): pass
char_ = char.lower() elif cdx > len(text) - num_trailing - 1:
# wav2vec2 models use "|" character to represent spaces pass
if model_lang not in LANGUAGES_WITHOUT_SPACES: elif char_ in model_dictionary.keys():
char_ = char_.replace(" ", "|") clean_char.append(char_)
clean_cdx.append(cdx)
# ignore whitespace at beginning and end of transcript clean_wdx = []
if cdx < num_leading: for wdx, wrd in enumerate(per_word):
pass if any([c in model_dictionary.keys() for c in wrd]):
elif cdx > len(transcription) - num_trailing - 1: clean_wdx.append(wdx)
pass
elif char_ in model_dictionary.keys():
clean_char.append(char_)
clean_cdx.append(cdx)
clean_wdx = [] sentence_spans = list(nltk.tokenize.punkt.PunktSentenceTokenizer().span_tokenize(text))
for wdx, wrd in enumerate(per_word):
if any([c in model_dictionary.keys() for c in wrd]):
clean_wdx.append(wdx)
# if no characters are in the dictionary, then we skip this segment...
if len(clean_char) == 0:
print(f'Failed to align segment ("{segment["text"]}"): no characters in this segment found in model dictionary, resorting to original...')
break
transcription_cleaned = "".join(clean_char)
tokens = [model_dictionary[c] for c in transcription_cleaned]
# we only pad if not using VAD filtering
if "seg_text" not in segment:
# pad according original timestamps
t1 = max(segment["start"] - extend_duration, 0)
t2 = min(segment["end"] + extend_duration, MAX_DURATION)
# use prev_t2 as current t1 if it"s later
if start_from_previous and t1 < prev_t2:
t1 = prev_t2
# check if timestamp range is still valid
if t1 >= MAX_DURATION:
print("Failed to align segment: original start time longer than audio duration, skipping...")
break
if t2 - t1 < 0.02:
print("Failed to align segment: duration smaller than 0.02s time precision")
break
f1 = int(t1 * SAMPLE_RATE)
f2 = int(t2 * SAMPLE_RATE)
waveform_segment = audio[:, f1:f2]
with torch.inference_mode():
if model_type == "torchaudio":
emissions, _ = model(waveform_segment.to(device))
elif model_type == "huggingface":
emissions = model(waveform_segment.to(device)).logits
else:
raise NotImplementedError(f"Align model of type {model_type} not supported.")
emissions = torch.log_softmax(emissions, dim=-1)
emission = emissions[0].cpu().detach()
blank_id = 0
for char, code in model_dictionary.items():
if char == '[pad]' or char == '<pad>':
blank_id = code
trellis = get_trellis(emission, tokens, blank_id)
path = backtrack(trellis, emission, tokens, blank_id)
if path is None:
print(f'Failed to align segment ("{segment["text"]}"): backtrack failed, resorting to original...')
break
char_segments = merge_repeats(path, transcription_cleaned)
# word_segments = merge_words(char_segments)
# sub-segments
if "seg-text" not in segment:
segment["seg-text"] = [transcription]
seg_lens = [0] + [len(x) for x in segment["seg-text"]]
seg_lens_cumsum = list(np.cumsum(seg_lens))
sub_seg_idx = 0
wdx = 0
duration = t2 - t1
ratio = duration * waveform_segment.size(0) / (trellis.size(0) - 1)
for cdx, char in enumerate(transcription + " "):
is_last = False
if cdx == len(transcription):
break
elif cdx+1 == len(transcription):
is_last = True
start, end, score = None, None, None
if cdx in clean_cdx:
char_seg = char_segments[clean_cdx.index(cdx)]
start = round(char_seg.start * ratio + t1, 3)
end = round(char_seg.end * ratio + t1, 3)
score = char_seg.score
char_segments_arr["char"].append(char)
char_segments_arr["start"].append(start)
char_segments_arr["end"].append(end)
char_segments_arr["score"].append(score)
char_segments_arr["word-idx"].append(wdx)
char_segments_arr["segment-idx"].append(sdx)
char_segments_arr["subsegment-idx"].append(sub_seg_idx)
# word-level info
if model_lang in LANGUAGES_WITHOUT_SPACES:
# character == word
wdx += 1
elif is_last or transcription[cdx+1] == " " or cdx == seg_lens_cumsum[sub_seg_idx+1] - 1:
wdx += 1
if is_last or cdx == seg_lens_cumsum[sub_seg_idx+1] - 1:
wdx = 0
sub_seg_idx += 1
prev_t2 = segment["end"]
segment_align_success = True
# end while True loop
break
# reset prev_t2 due to drifting issues
if not segment_align_success:
prev_t2 = 0
char_segments_arr = pd.DataFrame(char_segments_arr)
not_space = char_segments_arr["char"] != " "
per_seg_grp = char_segments_arr.groupby(["segment-idx", "subsegment-idx"], as_index = False)
char_segments_arr = per_seg_grp.apply(lambda x: x.reset_index(drop = True)).reset_index()
per_word_grp = char_segments_arr[not_space].groupby(["segment-idx", "subsegment-idx", "word-idx"])
per_subseg_grp = char_segments_arr[not_space].groupby(["segment-idx", "subsegment-idx"])
per_seg_grp = char_segments_arr[not_space].groupby(["segment-idx"])
char_segments_arr["local-char-idx"] = char_segments_arr.groupby(["segment-idx", "subsegment-idx"]).cumcount()
per_word_grp = char_segments_arr[not_space].groupby(["segment-idx", "subsegment-idx", "word-idx"]) # regroup
word_segments_arr = {}
# start of word is first char with a timestamp
word_segments_arr["start"] = per_word_grp["start"].min().values
# end of word is last char with a timestamp
word_segments_arr["end"] = per_word_grp["end"].max().values
# score of word is mean (excluding nan)
word_segments_arr["score"] = per_word_grp["score"].mean().values
word_segments_arr["segment-text-start"] = per_word_grp["local-char-idx"].min().astype(int).values
word_segments_arr["segment-text-end"] = per_word_grp["local-char-idx"].max().astype(int).values+1
word_segments_arr = pd.DataFrame(word_segments_arr)
word_segments_arr[["segment-idx", "subsegment-idx", "word-idx"]] = per_word_grp["local-char-idx"].min().reset_index()[["segment-idx", "subsegment-idx", "word-idx"]].astype(int)
segments_arr = {}
segments_arr["start"] = per_subseg_grp["start"].min().reset_index()["start"]
segments_arr["end"] = per_subseg_grp["end"].max().reset_index()["end"]
segments_arr = pd.DataFrame(segments_arr)
segments_arr[["segment-idx", "subsegment-idx-start"]] = per_subseg_grp["start"].min().reset_index()[["segment-idx", "subsegment-idx"]]
segments_arr["subsegment-idx-end"] = segments_arr["subsegment-idx-start"] + 1
# interpolate missing words / sub-segments
if interpolate_method != "ignore":
wrd_subseg_grp = word_segments_arr.groupby(["segment-idx", "subsegment-idx"], group_keys=False)
wrd_seg_grp = word_segments_arr.groupby(["segment-idx"], group_keys=False)
# we still know which word timestamps are interpolated because their score == nan
word_segments_arr["start"] = wrd_subseg_grp['start'].apply(lambda group: interpolate_nans(group, method=interpolate_method))
word_segments_arr["end"] = wrd_subseg_grp['end'].apply(lambda group: interpolate_nans(group, method=interpolate_method))
word_segments_arr["start"] = wrd_seg_grp['start'].apply(lambda group: interpolate_nans(group, method=interpolate_method))
word_segments_arr["end"] = wrd_seg_grp['end'].apply(lambda group: interpolate_nans(group, method=interpolate_method))
sub_seg_grp = segments_arr.groupby(["segment-idx"], group_keys=False)
segments_arr['start'] = sub_seg_grp['start'].apply(lambda group: interpolate_nans(group, method=interpolate_method))
segments_arr['end'] = sub_seg_grp['end'].apply(lambda group: interpolate_nans(group, method=interpolate_method))
# merge words & subsegments which are missing times
word_grp = word_segments_arr.groupby(["segment-idx", "subsegment-idx", "end"])
word_segments_arr["segment-text-start"] = word_grp["segment-text-start"].transform(min)
word_segments_arr["segment-text-end"] = word_grp["segment-text-end"].transform(max)
word_segments_arr.drop_duplicates(subset=["segment-idx", "subsegment-idx", "end"], inplace=True)
seg_grp_dup = segments_arr.groupby(["segment-idx", "start", "end"])
segments_arr["subsegment-idx-start"] = seg_grp_dup["subsegment-idx-start"].transform(min)
segments_arr["subsegment-idx-end"] = seg_grp_dup["subsegment-idx-end"].transform(max)
segments_arr.drop_duplicates(subset=["segment-idx", "subsegment-idx-start", "subsegment-idx-end"], inplace=True)
else:
word_segments_arr.dropna(inplace=True)
segments_arr.dropna(inplace=True)
# if some segments still have missing timestamps (usually because all numerals / symbols), then use original timestamps...
segments_arr['start'].fillna(pd.Series([x['start'] for x in transcript]), inplace=True)
segments_arr['end'].fillna(pd.Series([x['end'] for x in transcript]), inplace=True)
segments_arr['subsegment-idx-start'].fillna(0, inplace=True)
segments_arr['subsegment-idx-end'].fillna(1, inplace=True)
segment["clean_char"] = clean_char
segment["clean_cdx"] = clean_cdx
segment["clean_wdx"] = clean_wdx
segment["sentence_spans"] = sentence_spans
aligned_segments = []
aligned_segments_word = []
word_segments_arr.set_index(["segment-idx", "subsegment-idx"], inplace=True) # 2. Get prediction matrix from alignment model & align
char_segments_arr.set_index(["segment-idx", "subsegment-idx", "word-idx"], inplace=True) for sdx, segment in enumerate(transcript):
t1 = segment["start"]
t2 = segment["end"]
text = segment["text"]
for sdx, srow in segments_arr.iterrows(): aligned_seg = {
"start": t1,
"end": t2,
"text": text,
"words": [],
}
seg_idx = int(srow["segment-idx"]) if return_char_alignments:
sub_start = int(srow["subsegment-idx-start"]) aligned_seg["chars"] = []
sub_end = int(srow["subsegment-idx-end"])
seg = transcript[seg_idx] # check we can align
text = "".join(seg["seg-text"][sub_start:sub_end]) if len(segment["clean_char"]) == 0:
print(f'Failed to align segment ("{segment["text"]}"): no characters in this segment found in model dictionary, resorting to original...')
aligned_segments.append(aligned_seg)
continue
wseg = word_segments_arr.loc[seg_idx].loc[sub_start:sub_end-1] if t1 >= MAX_DURATION or t2 - t1 < 0.02:
wseg["start"].fillna(srow["start"], inplace=True) print("Failed to align segment: original start time longer than audio duration, skipping...")
wseg["end"].fillna(srow["end"], inplace=True) aligned_segments.append(aligned_seg)
wseg["segment-text-start"].fillna(0, inplace=True) continue
wseg["segment-text-end"].fillna(len(text)-1, inplace=True)
cseg = char_segments_arr.loc[seg_idx].loc[sub_start:sub_end-1] text_clean = "".join(segment["clean_char"])
# fixes bug for single segment in transcript tokens = [model_dictionary[c] for c in text_clean]
cseg['segment-text-start'] = cseg['level_1'] if 'level_1' in cseg else 0
cseg['segment-text-end'] = cseg['level_1'] + 1 if 'level_1' in cseg else 1
if 'level_1' in cseg: del cseg['level_1']
if 'level_0' in cseg: del cseg['level_0']
cseg.reset_index(inplace=True)
def get_raw_text(word_row): f1 = int(t1 * SAMPLE_RATE)
return seg["seg-text"][word_row.name][int(word_row["segment-text-start"]):int(word_row["segment-text-end"])+1] f2 = int(t2 * SAMPLE_RATE)
word_list = [] # TODO: Probably can get some speedup gain with batched inference here
wdx = 0 waveform_segment = audio[:, f1:f2]
curr_text = get_raw_text(wseg.iloc[wdx])
if not curr_text.startswith(" "):
curr_text = " " + curr_text
if len(wseg) > 1: with torch.inference_mode():
for _, wrow in wseg.iloc[1:].iterrows(): if model_type == "torchaudio":
if wrow['start'] != wseg.iloc[wdx]['start']: emissions, _ = model(waveform_segment.to(device))
word_start = wseg.iloc[wdx]['start'] elif model_type == "huggingface":
word_end = wseg.iloc[wdx]['end'] emissions = model(waveform_segment.to(device)).logits
else:
raise NotImplementedError(f"Align model of type {model_type} not supported.")
emissions = torch.log_softmax(emissions, dim=-1)
aligned_segments_word.append( emission = emissions[0].cpu().detach()
{
"text": curr_text.strip(),
"start": word_start,
"end": word_end
}
)
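The emission used above is simply the frame-wise log-probabilities of the alignment model. A minimal sketch of the huggingface branch; the checkpoint name and the 16 kHz, one-second input are illustrative assumptions:

import torch
from transformers import Wav2Vec2ForCTC

# Illustrative checkpoint; whisperx picks the model per language via load_align_model.
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h").eval()

waveform_segment = torch.zeros(1, 16000)  # stand-in for audio[:, f1:f2] at 16 kHz

with torch.inference_mode():
    logits = model(waveform_segment).logits            # (1, frames, vocab)
    emission = torch.log_softmax(logits, dim=-1)[0]    # frame-wise log-probabilities

print(emission.shape)  # roughly (49, 32) for one second of audio with this checkpoint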
word_list.append( blank_id = 0
{ for char, code in model_dictionary.items():
"word": curr_text.rstrip(), if char == '[pad]' or char == '<pad>':
"start": word_start, blank_id = code
"end": word_end,
}
)
curr_text = " " trellis = get_trellis(emission, tokens, blank_id)
curr_text += get_raw_text(wrow) + " " path = backtrack(trellis, emission, tokens, blank_id)
wdx += 1
aligned_segments_word.append( if path is None:
{ print(f'Failed to align segment ("{segment["text"]}"): backtrack failed, resorting to original...')
"text": curr_text.strip(), aligned_segments.append(aligned_seg)
"start": wseg.iloc[wdx]["start"], continue
"end": wseg.iloc[wdx]["end"]
}
)
word_list.append( char_segments = merge_repeats(path, text_clean)
{
"word": curr_text.rstrip(),
"start": wseg.iloc[wdx]['start'],
"end": wseg.iloc[wdx]['end'],
}
)
aligned_segments.append( duration = t2 -t1
{ ratio = duration * waveform_segment.size(0) / (trellis.size(0) - 1)
"start": srow["start"],
"end": srow["end"],
"text": text,
"words": word_list,
"word-segments": wseg,
"char-segments": cseg
}
)
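Character timings come out of the trellis as frame indices; the ratio above converts them back to seconds (waveform_segment.size(0) is the channel dimension, typically 1, so it is omitted in this made-up example):

# Made-up numbers: a segment covering t1=3.00s .. t2=7.50s whose trellis has 226 rows.
t1, t2 = 3.00, 7.50
duration = t2 - t1
num_trellis_frames = 226                   # trellis.size(0) in the code above
ratio = duration / (num_trellis_frames - 1)

# A character aligned to trellis frames 40..44 maps back to absolute seconds:
start = round(40 * ratio + t1, 3)
end = round(44 * ratio + t1, 3)
print(start, end)  # 3.8 3.88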
# assign timestamps to aligned characters
char_segments_arr = []
word_idx = 0
for cdx, char in enumerate(text):
start, end, score = None, None, None
if cdx in segment["clean_cdx"]:
char_seg = char_segments[segment["clean_cdx"].index(cdx)]
start = round(char_seg.start * ratio + t1, 3)
end = round(char_seg.end * ratio + t1, 3)
score = round(char_seg.score, 3)
return {"segments": aligned_segments, "word_segments": aligned_segments_word} char_segments_arr.append(
{
"char": char,
"start": start,
"end": end,
"score": score,
"word-idx": word_idx,
}
)
# increment word_idx, nltk word tokenization would probably be more robust here, but use spaces for now...
if model_lang in LANGUAGES_WITHOUT_SPACES:
word_idx += 1
elif cdx == len(text) - 1 or text[cdx+1] == " ":
word_idx += 1
char_segments_arr = pd.DataFrame(char_segments_arr)
aligned_subsegments = []
# assign sentence_idx to each character index
char_segments_arr["sentence-idx"] = None
for sdx, (sstart, send) in enumerate(segment["sentence_spans"]):
curr_chars = char_segments_arr.loc[(char_segments_arr.index >= sstart) & (char_segments_arr.index <= send)]
char_segments_arr.loc[(char_segments_arr.index >= sstart) & (char_segments_arr.index <= send), "sentence-idx"] = sdx
sentence_text = text[sstart:send]
sentence_start = curr_chars["start"].min()
sentence_end = curr_chars["end"].max()
sentence_words = []
for word_idx in curr_chars["word-idx"].unique():
word_chars = curr_chars.loc[curr_chars["word-idx"] == word_idx]
word_text = "".join(word_chars["char"].tolist()).strip()
if len(word_text) == 0:
continue
word_start = word_chars["start"].min()
word_end = word_chars["end"].max()
word_score = round(word_chars["score"].mean(), 3)
# -1 indicates unalignable
word_segment = {"word": word_text}
if not np.isnan(word_start):
word_segment["start"] = word_start
if not np.isnan(word_end):
word_segment["end"] = word_end
if not np.isnan(word_score):
word_segment["score"] = word_score
sentence_words.append(word_segment)
aligned_subsegments.append({
"text": sentence_text,
"start": sentence_start,
"end": sentence_end,
"words": sentence_words,
})
if return_char_alignments:
curr_chars = curr_chars[["char", "start", "end", "score"]]
curr_chars.fillna(-1, inplace=True)
curr_chars = curr_chars.to_dict("records")
curr_chars = [{key: val for key, val in char.items() if val != -1} for char in curr_chars]
aligned_subsegments = pd.DataFrame(aligned_subsegments)
aligned_subsegments["start"] = interpolate_nans(aligned_subsegments["start"], method=interpolate_method)
aligned_subsegments["end"] = interpolate_nans(aligned_subsegments["end"], method=interpolate_method)
# concatenate sentences with same timestamps
agg_dict = {"text": " ".join, "words": "sum"}
if return_char_alignments:
agg_dict["chars"] = "sum"
aligned_subsegments= aligned_subsegments.groupby(["start", "end"], as_index=False).agg(agg_dict)
aligned_subsegments = aligned_subsegments.to_dict('records')
aligned_segments += aligned_subsegments
# create word_segments list
word_segments = []
for segment in aligned_segments:
word_segments += segment["words"]
return {"segments": aligned_segments, "word_segments": word_segments}
""" """
source: https://pytorch.org/tutorials/intermediate/forced_alignment_with_torchaudio_tutorial.html source: https://pytorch.org/tutorials/intermediate/forced_alignment_with_torchaudio_tutorial.html

View File

@ -78,7 +78,7 @@ def load_model(whisper_arch, device, compute_type="float16", asr_options=None, l
class WhisperModel(faster_whisper.WhisperModel):
'''
FasterWhisperModel provides batched inference for faster-whisper.
Currently only works in non-timestamp mode. Currently only works in non-timestamp mode and fixed prompt for all samples in batch.
'''
def generate_segment_batched(self, features: np.ndarray, tokenizer: faster_whisper.tokenizer.Tokenizer, options: faster_whisper.transcribe.TranscriptionOptions, encoder_output = None):
@ -140,6 +140,13 @@ class WhisperModel(faster_whisper.WhisperModel):
return self.model.encode(features, to_cpu=to_cpu)
class FasterWhisperPipeline(Pipeline):
"""
Huggingface Pipeline wrapper for FasterWhisperModel.
"""
# TODO:
# - add support for timestamp mode
# - add support for custom inference kwargs
def __init__(
self,
model,
@ -261,149 +268,3 @@ class FasterWhisperPipeline(Pipeline):
language = language_token[2:-2]
print(f"Detected language: {language} ({language_probability:.2f}) in first 30s of audio...")
return language
if __name__ == "__main__":
main_type = "simple"
import time
import jiwer
from tqdm import tqdm
from whisper.normalizers import EnglishTextNormalizer
from benchmark.tedlium import parse_tedlium_annos
if main_type == "complex":
from faster_whisper.tokenizer import Tokenizer
from faster_whisper.transcribe import TranscriptionOptions
from faster_whisper.vad import (SpeechTimestampsMap,
get_speech_timestamps)
from whisperx.vad import load_vad_model, merge_chunks
from .audio import SAMPLE_RATE, load_audio, log_mel_spectrogram
faster_t_options = TranscriptionOptions(
beam_size=5,
best_of=5,
patience=1,
length_penalty=1,
temperatures=[0.0, 0.2, 0.4, 0.6, 0.8, 1.0],
compression_ratio_threshold=2.4,
log_prob_threshold=-1.0,
no_speech_threshold=0.6,
condition_on_previous_text=False,
initial_prompt=None,
prefix=None,
suppress_blank=True,
suppress_tokens=[-1],
without_timestamps=True,
max_initial_timestamp=0.0,
word_timestamps=False,
prepend_punctuations="\"'“¿([{-",
append_punctuations="\"'.。,!?::”)]}、"
)
whisper_arch = "large-v2"
device = "cuda"
batch_size = 16
model = WhisperModel(whisper_arch, device="cuda", compute_type="float16",)
tokenizer = Tokenizer(model.hf_tokenizer, model.model.is_multilingual, task="transcribe", language="en")
model = FasterWhisperPipeline(model, tokenizer, faster_t_options, device=-1)
fn = "DanielKahneman_2010.wav"
wav_dir = f"/tmp/test/wav/"
vad_model = load_vad_model("cuda", 0.6, 0.3)
audio = load_audio(os.path.join(wav_dir, fn))
vad_segments = vad_model({"waveform": torch.from_numpy(audio).unsqueeze(0), "sample_rate": SAMPLE_RATE})
vad_segments = merge_chunks(vad_segments, 30)
def data(audio, segments):
for seg in segments:
f1 = int(seg['start'] * SAMPLE_RATE)
f2 = int(seg['end'] * SAMPLE_RATE)
# print(f2-f1)
yield {'inputs': audio[f1:f2]}
vad_method="pyannote"
wav_dir = f"/tmp/test/wav/"
wer_li = []
time_li = []
for fn in os.listdir(wav_dir):
if fn == "RobertGupta_2010U.wav":
continue
base_fn = fn.split('.')[0]
audio_fp = os.path.join(wav_dir, fn)
audio = load_audio(audio_fp)
t1 = time.time()
if vad_method == "pyannote":
vad_segments = vad_model({"waveform": torch.from_numpy(audio).unsqueeze(0), "sample_rate": SAMPLE_RATE})
vad_segments = merge_chunks(vad_segments, 30)
elif vad_method == "silero":
vad_segments = get_speech_timestamps(audio, threshold=0.5, max_speech_duration_s=30)
vad_segments = [{"start": x["start"] / SAMPLE_RATE, "end": x["end"] / SAMPLE_RATE} for x in vad_segments]
new_segs = []
curr_start = vad_segments[0]['start']
curr_end = vad_segments[0]['end']
for seg in vad_segments[1:]:
if seg['end'] - curr_start > 30:
new_segs.append({"start": curr_start, "end": curr_end})
curr_start = seg['start']
curr_end = seg['end']
else:
curr_end = seg['end']
new_segs.append({"start": curr_start, "end": curr_end})
vad_segments = new_segs
text = []
# for idx, out in tqdm(enumerate(model(data(audio_fp, vad_segments), batch_size=batch_size)), total=len(vad_segments)):
for idx, out in enumerate(model(data(audio, vad_segments), batch_size=batch_size)):
text.append(out['text'])
t2 = time.time()
if batch_size == 1:
text = [x[0] for x in text]
text = " ".join(text)
normalizer = EnglishTextNormalizer()
text = normalizer(text)
gt_corpus = normalizer(parse_tedlium_annos(base_fn, "/tmp/test/"))
wer_result = jiwer.wer(gt_corpus, text)
print("WER: %.2f \t time: %.2f \t [%s]" % (wer_result * 100, t2-t1, fn))
wer_li.append(wer_result)
time_li.append(t2-t1)
print("# Avg Mean...")
print("WER: %.2f" % (sum(wer_li) * 100/len(wer_li)))
print("Time: %.2f" % (sum(time_li)/len(time_li)))
elif main_type == "simple":
model = load_model(
"large-v2",
device="cuda",
language="en",
)
wav_dir = f"/tmp/test/wav/"
wer_li = []
time_li = []
for fn in os.listdir(wav_dir):
if fn == "RobertGupta_2010U.wav":
continue
# fn = "DanielKahneman_2010.wav"
base_fn = fn.split('.')[0]
audio_fp = os.path.join(wav_dir, fn)
audio = load_audio(audio_fp)
t1 = time.time()
out = model.transcribe(audio_fp, batch_size=8)["segments"]
t2 = time.time()
text = " ".join([x['text'] for x in out])
normalizer = EnglishTextNormalizer()
text = normalizer(text)
gt_corpus = normalizer(parse_tedlium_annos(base_fn, "/tmp/test/"))
wer_result = jiwer.wer(gt_corpus, text)
print("WER: %.2f \t time: %.2f \t [%s]" % (wer_result * 100, t2-t1, fn))
wer_li.append(wer_result)
time_li.append(t2-t1)
print("# Avg Mean...")
print("WER: %.2f" % (sum(wer_li) * 100/len(wer_li)))
print("Time: %.2f" % (sum(time_li)/len(time_li)))

View File

@ -11,7 +11,6 @@ class DiarizationPipeline:
use_auth_token=None,
device: Optional[Union[str, torch.device]] = "cpu",
):
self.model = Pipeline.from_pretrained(model_name, use_auth_token=use_auth_token)
if isinstance(device, str):
device = torch.device(device)
self.model = Pipeline.from_pretrained(model_name, use_auth_token=use_auth_token).to(device)
@ -21,59 +20,44 @@ class DiarizationPipeline:
diarize_df = pd.DataFrame(segments.itertracks(yield_label=True))
diarize_df['start'] = diarize_df[0].apply(lambda x: x.start)
diarize_df['end'] = diarize_df[0].apply(lambda x: x.end)
diarize_df.rename(columns={2: "speaker"}, inplace=True)
return diarize_df
def assign_word_speakers(diarize_df, result_segments, fill_nearest=False):
for seg in result_segments:
wdf = seg['word-segments']
if len(wdf['start'].dropna()) == 0:
wdf['start'] = seg['start']
wdf['end'] = seg['end']
speakers = []
for wdx, wrow in wdf.iterrows():
if not np.isnan(wrow['start']):
diarize_df['intersection'] = np.minimum(diarize_df['end'], wrow['end']) - np.maximum(diarize_df['start'], wrow['start'])
diarize_df['union'] = np.maximum(diarize_df['end'], wrow['end']) - np.minimum(diarize_df['start'], wrow['start'])
# remove no hit
if not fill_nearest:
dia_tmp = diarize_df[diarize_df['intersection'] > 0]
else:
dia_tmp = diarize_df
if len(dia_tmp) == 0:
speaker = None
else:
speaker = dia_tmp.sort_values("intersection", ascending=False).iloc[0][2]
else:
speaker = None
speakers.append(speaker)
seg['word-segments']['speaker'] = speakers
speaker_count = pd.Series(speakers).value_counts() def assign_word_speakers(diarize_df, transcript_result, fill_nearest=False):
if len(speaker_count) == 0: transcript_segments = transcript_result["segments"]
seg["speaker"]= "UNKNOWN" for seg in transcript_segments:
# assign speaker to segment (if any)
diarize_df['intersection'] = np.minimum(diarize_df['end'], seg['end']) - np.maximum(diarize_df['start'], seg['start'])
diarize_df['union'] = np.maximum(diarize_df['end'], seg['end']) - np.minimum(diarize_df['start'], seg['start'])
# remove no hit, otherwise we look for closest (even negative intersection...)
if not fill_nearest:
dia_tmp = diarize_df[diarize_df['intersection'] > 0]
else: else:
seg["speaker"] = speaker_count.index[0] dia_tmp = diarize_df
if len(dia_tmp) > 0:
# sum over speakers
speaker = dia_tmp.groupby("speaker")["intersection"].sum().sort_values(ascending=False).index[0]
seg["speaker"] = speaker
# create word level segments for .srt # assign speaker to words
word_seg = [] if 'words' in seg:
for seg in result_segments: for word in seg['words']:
wseg = pd.DataFrame(seg["word-segments"]) if 'start' in word:
for wdx, wrow in wseg.iterrows(): diarize_df['intersection'] = np.minimum(diarize_df['end'], word['end']) - np.maximum(diarize_df['start'], word['start'])
if wrow["start"] is not None: diarize_df['union'] = np.maximum(diarize_df['end'], word['end']) - np.minimum(diarize_df['start'], word['start'])
speaker = wrow['speaker'] # remove no hit
if speaker is None or speaker == np.nan: if not fill_nearest:
speaker = "UNKNOWN" dia_tmp = diarize_df[diarize_df['intersection'] > 0]
word_seg.append( else:
{ dia_tmp = diarize_df
"start": wrow["start"], if len(dia_tmp) > 0:
"end": wrow["end"], # sum over speakers
"text": f"[{speaker}]: " + seg["text"][int(wrow["segment-text-start"]):int(wrow["segment-text-end"])] speaker = dia_tmp.groupby("speaker")["intersection"].sum().sort_values(ascending=False).index[0]
} word["speaker"] = speaker
)
# TODO: create segments but split words on new speaker return transcript_result
return result_segments, word_seg
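Speaker assignment in the new assign_word_speakers is purely temporal: for each segment or word, the diarization turns with the largest summed overlap win. A self-contained sketch of that overlap logic with made-up diarization turns (idxmax is used here in place of the sort in the code above):

import numpy as np
import pandas as pd

# Made-up diarization output: one row per speaker turn.
diarize_df = pd.DataFrame({
    "start":   [0.0, 5.0, 9.0],
    "end":     [5.0, 9.0, 15.0],
    "speaker": ["SPEAKER_00", "SPEAKER_01", "SPEAKER_00"],
})

def assign_speaker(start: float, end: float, fill_nearest: bool = False):
    # Overlap between [start, end] and every diarization turn.
    inter = np.minimum(diarize_df["end"], end) - np.maximum(diarize_df["start"], start)
    candidates = diarize_df.assign(intersection=inter)
    if not fill_nearest:
        candidates = candidates[candidates["intersection"] > 0]
    if len(candidates) == 0:
        return None
    # Sum per speaker so a speaker covering the span in two turns still wins.
    return candidates.groupby("speaker")["intersection"].sum().idxmax()

print(assign_speaker(4.2, 6.0))  # 0.8s of SPEAKER_00 vs 1.0s of SPEAKER_01 -> SPEAKER_01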
class Segment:
def __init__(self, start, end, speaker=None):

View File

@ -64,14 +64,11 @@ def cli():
parser.add_argument("--max_line_width", type=optional_int, default=None, help="(not possible with --no_align) the maximum number of characters in a line before breaking the line") parser.add_argument("--max_line_width", type=optional_int, default=None, help="(not possible with --no_align) the maximum number of characters in a line before breaking the line")
parser.add_argument("--max_line_count", type=optional_int, default=None, help="(requires --no_align) the maximum number of lines in a segment") parser.add_argument("--max_line_count", type=optional_int, default=None, help="(requires --no_align) the maximum number of lines in a segment")
parser.add_argument("--highlight_words", type=str2bool, default=False, help="(requires --word_timestamps True) underline each word as it is spoken in srt and vtt") parser.add_argument("--highlight_words", type=str2bool, default=False, help="(requires --word_timestamps True) underline each word as it is spoken in srt and vtt")
parser.add_argument("--segment_resolution", type=str, default="sentence", choices=["sentence", "chunk"], help="(not possible with --no_align) the maximum number of characters in a line before breaking the line")
# parser.add_argument("--word_timestamps", type=str2bool, default=False, help="(experimental) extract word-level timestamps and refine the results based on them")
# parser.add_argument("--prepend_punctuations", type=str, default="\"\'“¿([{-", help="if word_timestamps is True, merge these punctuation symbols with the next word")
# parser.add_argument("--append_punctuations", type=str, default="\"\'.。,!?::”)]}、", help="if word_timestamps is True, merge these punctuation symbols with the previous word")
parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS") parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
parser.add_argument("--hf_token", type=str, default=None, help="Hugging Face Access Token to access PyAnnote gated models") parser.add_argument("--hf_token", type=str, default=None, help="Hugging Face Access Token to access PyAnnote gated models")
# parser.add_argument("--model_flush", action="store_true", help="Flush memory from each model after use, reduces GPU requirement but slower processing >1 audio file.")
# fmt: on
args = parser.parse_args().__dict__
@ -97,7 +94,6 @@ def cli():
min_speakers: int = args.pop("min_speakers")
max_speakers: int = args.pop("max_speakers")
# TODO: check model loading works.
if model_name.endswith(".en") and args["language"] not in {"en", "English"}:
if args["language"] is not None:
@ -176,6 +172,7 @@ def cli():
align_model, align_metadata = load_align_model(result["language"], device)
print(">>Performing alignment...")
result = align(result["segments"], align_model, align_metadata, input_audio, device, interpolate_method=interpolate_method)
results.append((result, audio_path))
# Unload align model
@ -193,18 +190,10 @@ def cli():
diarize_model = DiarizationPipeline(use_auth_token=hf_token, device=device)
for result, input_audio_path in tmp_results:
diarize_segments = diarize_model(input_audio_path, min_speakers=min_speakers, max_speakers=max_speakers)
results_segments, word_segments = assign_word_speakers(diarize_segments, result["segments"]) result = assign_word_speakers(diarize_segments, result)
result = {"segments": results_segments, "word_segments": word_segments}
results.append((result, input_audio_path))
# >> Write
for result, audio_path in results:
# Remove pandas dataframes from result so that
# we can serialize the result with json
for seg in result["segments"]:
seg.pop("word-segments", None)
seg.pop("char-segments", None)
writer(result, audio_path, writer_args)
if __name__ == "__main__":
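Taken together, the CLI now runs a transcribe -> align -> diarize -> assign-speakers flow over a single result dict. A condensed sketch using only calls that appear in this diff; the whisperx. namespacing, the audio path, and the HF token are assumptions:

import whisperx  # assumed package-level exports for the functions used by the CLI

device = "cuda"
audio_path = "example.wav"   # placeholder
hf_token = "hf_xxx"          # placeholder, needed for the gated pyannote pipeline

# 1. Transcribe with the batched faster-whisper backend.
model = whisperx.load_model("large-v2", device, language="en")
result = model.transcribe(audio_path, batch_size=8)

# 2. Forced alignment; segments are now re-split into sentences via nltk.
align_model, align_metadata = whisperx.load_align_model(result["language"], device)
result = whisperx.align(result["segments"], align_model, align_metadata,
                        audio_path, device, interpolate_method="nearest")

# 3. Diarize, then attach speakers to segments and words in place.
diarize_model = whisperx.DiarizationPipeline(use_auth_token=hf_token, device=device)
diarize_segments = diarize_model(audio_path, min_speakers=None, max_speakers=None)
result = whisperx.assign_word_speakers(diarize_segments, result)

print(result["segments"][0])  # each segment now carries "words" and possibly "speaker"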

View File

@ -231,11 +231,16 @@ class SubtitlesWriter(ResultWriter):
line_count = 1
# the next subtitle to yield (a list of word timings with whitespace)
subtitle: list[dict] = []
last = result["segments"][0]["words"][0]["start"] times = []
last = result["segments"][0]["start"]
for segment in result["segments"]: for segment in result["segments"]:
for i, original_timing in enumerate(segment["words"]): for i, original_timing in enumerate(segment["words"]):
timing = original_timing.copy() timing = original_timing.copy()
long_pause = not preserve_segments and timing["start"] - last > 3.0 long_pause = not preserve_segments
if "start" in timing:
long_pause = long_pause and timing["start"] - last > 3.0
else:
long_pause = False
has_room = line_len + len(timing["word"]) <= max_line_width
seg_break = i == 0 and len(subtitle) > 0 and preserve_segments
if line_len > 0 and has_room and not long_pause and not seg_break:
@ -251,8 +256,9 @@ class SubtitlesWriter(ResultWriter):
or seg_break
):
# subtitle break
yield subtitle yield subtitle, times
subtitle = []
times = []
line_count = 1
elif line_len > 0:
# line break
@ -260,40 +266,53 @@ class SubtitlesWriter(ResultWriter):
timing["word"] = "\n" + timing["word"] timing["word"] = "\n" + timing["word"]
line_len = len(timing["word"].strip()) line_len = len(timing["word"].strip())
subtitle.append(timing) subtitle.append(timing)
last = timing["start"] times.append((segment["start"], segment["end"], segment.get("speaker")))
if "start" in timing:
last = timing["start"]
if len(subtitle) > 0:
yield subtitle yield subtitle, times
if "words" in result["segments"][0]: if "words" in result["segments"][0]:
for subtitle in iterate_subtitles(): for subtitle, _ in iterate_subtitles():
subtitle_start = self.format_timestamp(subtitle[0]["start"]) sstart, ssend, speaker = _[0]
subtitle_end = self.format_timestamp(subtitle[-1]["end"]) subtitle_start = self.format_timestamp(sstart)
subtitle_text = "".join([word["word"] for word in subtitle]) subtitle_end = self.format_timestamp(ssend)
if highlight_words: subtitle_text = " ".join([word["word"] for word in subtitle])
has_timing = any(["start" in word for word in subtitle])
# add [$SPEAKER_ID]: to each subtitle if speaker is available
prefix = ""
if speaker is not None:
prefix = f"[{speaker}]: "
if highlight_words and has_timing:
last = subtitle_start
all_words = [timing["word"] for timing in subtitle]
for i, this_word in enumerate(subtitle):
start = self.format_timestamp(this_word["start"]) if "start" in this_word:
end = self.format_timestamp(this_word["end"]) start = self.format_timestamp(this_word["start"])
if last != start: end = self.format_timestamp(this_word["end"])
yield last, start, subtitle_text if last != start:
yield last, start, subtitle_text
yield start, end, "".join( yield start, end, prefix + " ".join(
[ [
re.sub(r"^(\s*)(.*)$", r"\1<u>\2</u>", word) re.sub(r"^(\s*)(.*)$", r"\1<u>\2</u>", word)
if j == i if j == i
else word else word
for j, word in enumerate(all_words)
]
)
last = end
else: else:
yield subtitle_start, subtitle_end, subtitle_text yield subtitle_start, subtitle_end, prefix + subtitle_text
else:
for segment in result["segments"]:
segment_start = self.format_timestamp(segment["start"])
segment_end = self.format_timestamp(segment["end"])
segment_text = segment["text"].strip().replace("-->", "->")
if "speaker" in segment:
segment_text = f"[{segment['speaker']}]: {segment_text}"
yield segment_start, segment_end, segment_text
def format_timestamp(self, seconds: float):
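For reference, the regex in the highlight_words branch above underlines one word at a time while preserving its leading whitespace; a quick standalone illustration:

import re

words = [" Hello", " world"]   # word timings keep their leading whitespace
highlighted = "".join(
    re.sub(r"^(\s*)(.*)$", r"\1<u>\2</u>", w) if i == 1 else w
    for i, w in enumerate(words)
)
print(highlighted)  # " Hello <u>world</u>"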