new logic, diarization, vad filtering

This commit is contained in:
Max Bain
2023-01-24 15:02:08 +00:00
parent ba102feb7f
commit d395c21b83
8 changed files with 498 additions and 260 deletions

View File

@ -48,6 +48,13 @@ This repository refines the timestamps of openAI's Whisper model via forced alig
**Forced Alignment** refers to the process by which orthographic transcriptions are aligned to audio recordings to automatically generate phone level segmentation.
<h2 align="left", id="highlights">New🚨</h2>
- VAD filtering: Voice Activity Detection (VAD) from [Pyannote.audio](https://huggingface.co/pyannote/voice-activity-detection) is used as a preprocessing step to remove reliance on whisper timestamps and only transcribe audio segments containing speech. add `--vad_filter` flag, increases timestamp accuracy and robustness (requires more GPU mem due to 30s inputs in wav2vec2)
- Character level timestamps (see `*.char.ass` file output)
- Diarization (still in beta, add `--diarization`)
<h2 align="left" id="setup">Setup ⚙️</h2>
Install this package using
@ -76,9 +83,9 @@ Run whisper on example segment (using default params)
whisperx examples/sample01.wav
For increased timestamp accuracy, at the cost of higher gpu mem, use bigger models e.g.
For increased timestamp accuracy, at the cost of higher gpu mem, use bigger models and VAD filtering e.g.
whisperx examples/sample01.wav --model large.en --align_model WAV2VEC2_ASR_LARGE_LV60K_960H
whisperx examples/sample01.wav --model large.en --vad_filter --align_model WAV2VEC2_ASR_LARGE_LV60K_960H
Result using *WhisperX* with forced alignment to wav2vec2.0 large:
@ -162,7 +169,11 @@ The next major upgrade we are working on is whisper with speaker diarization, so
[x] ~~Python usage~~ done
[ ] Incorporating word-level speaker diarization
[x] ~~Character level timestamps~~
[x] ~~Incorporating speaker diarization~~
[ ] Improve diarization (word level)
[ ] Inference speedup with batch processing

View File

@ -6,3 +6,4 @@ soundfile
more-itertools
transformers>=4.19.0
ffmpeg-python==0.2.0
pyannote.audio

View File

@ -11,7 +11,7 @@ from tqdm import tqdm
from .audio import load_audio, log_mel_spectrogram, pad_or_trim
from .decoding import DecodingOptions, DecodingResult, decode, detect_language
from .model import Whisper, ModelDimensions
from .transcribe import transcribe, load_align_model, align
from .transcribe import transcribe, load_align_model, align, transcribe_with_vad
_MODELS = {
"tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",

View File

@ -113,7 +113,7 @@ def log_mel_spectrogram(audio: Union[str, np.ndarray, torch.Tensor], n_mels: int
window = torch.hann_window(N_FFT).to(audio.device)
stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True)
magnitudes = stft[:, :-1].abs() ** 2
magnitudes = stft[..., :-1].abs() ** 2
filters = mel_filters(audio.device, n_mels)
mel_spec = filters @ magnitudes

View File

@ -82,8 +82,8 @@ class MultiHeadAttention(nn.Module):
k = kv_cache[self.key]
v = kv_cache[self.value]
wv = self.qkv_attention(q, k, v, mask)
return self.out(wv)
wv, qk = self.qkv_attention(q, k, v, mask)
return self.out(wv), qk
def qkv_attention(self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None):
n_batch, n_ctx, n_state = q.shape
@ -95,9 +95,10 @@ class MultiHeadAttention(nn.Module):
qk = q @ k
if mask is not None:
qk = qk + mask[:n_ctx, :n_ctx]
qk = qk.float()
w = F.softmax(qk.float(), dim=-1).to(q.dtype)
return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2)
w = F.softmax(qk, dim=-1).to(q.dtype)
return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach()
class ResidualAttentionBlock(nn.Module):
@ -121,9 +122,9 @@ class ResidualAttentionBlock(nn.Module):
mask: Optional[Tensor] = None,
kv_cache: Optional[dict] = None,
):
x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)
x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)[0]
if self.cross_attn:
x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)
x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)[0]
x = x + self.mlp(self.mlp_ln(x))
return x

View File

@ -1737,6 +1737,5 @@
"yoghurt": "yogurt",
"yoghurts": "yogurts",
"mhm": "hmm",
"mm": "hmm",
"mmm": "hmm"
}

View File

@ -12,7 +12,7 @@ from .audio import SAMPLE_RATE, N_FRAMES, HOP_LENGTH, CHUNK_LENGTH, pad_or_trim,
from .alignment import get_trellis, backtrack, merge_repeats, merge_words
from .decoding import DecodingOptions, DecodingResult
from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
from .utils import exact_div, format_timestamp, optional_int, optional_float, str2bool, write_txt, write_vtt, write_srt, write_ass
from .utils import exact_div, format_timestamp, optional_int, optional_float, str2bool, interpolate_nans, write_txt, write_vtt, write_srt, write_ass, write_tsv
import pandas as pd
if TYPE_CHECKING:
@ -280,8 +280,39 @@ def align(
device: str,
extend_duration: float = 0.0,
start_from_previous: bool = True,
drop_non_aligned_words: bool = False,
interpolate_method: str = "nearest",
):
"""
Force align phoneme recognition predictions to known transcription
Parameters
----------
transcript: Iterator[dict]
The Whisper model instance
model: torch.nn.Module
Alignment model (wav2vec2)
audio: Union[str, np.ndarray, torch.Tensor]
The path to the audio file to open, or the audio waveform
device: str
cuda device
extend_duration: float
Amount to pad input segments by. If not using vad--filter then recommended to use 2 seconds
If the gzip compression ratio is above this value, treat as failed
interpolate_method: str ["nearest", "linear", "ignore"]
Method to assign timestamps to non-aligned words. Words are not able to be aligned when none of the characters occur in the align model dictionary.
"nearest" copies timestamp of nearest word within the segment. "linear" is linear interpolation. "drop" removes that word from output.
Returns
-------
A dictionary containing the resulting text ("text") and segment-level details ("segments"), and
the spoken language ("language"), which is detected when `decode_options["language"]` is None.
"""
if not torch.is_tensor(audio):
if isinstance(audio, str):
audio = load_audio(audio)
@ -291,30 +322,78 @@ def align(
MAX_DURATION = audio.shape[1] / SAMPLE_RATE
model_dictionary = align_model_metadata['dictionary']
model_lang = align_model_metadata['language']
model_type = align_model_metadata['type']
model_dictionary = align_model_metadata["dictionary"]
model_lang = align_model_metadata["language"]
model_type = align_model_metadata["type"]
aligned_segments = []
prev_t2 = 0
total_word_segments_list = []
vad_segments_list = []
for idx, segment in enumerate(transcript):
word_segments_list = []
# first we pad
t1 = max(segment['start'] - extend_duration, 0)
t2 = min(segment['end'] + extend_duration, MAX_DURATION)
sdx = 0
for segment in transcript:
while True:
segment_align_success = False
# use prev_t2 as current t1 if it's later
# strip spaces at beginning / end, but keep track of the amount.
num_leading = len(segment["text"]) - len(segment["text"].lstrip())
num_trailing = len(segment["text"]) - len(segment["text"].rstrip())
transcription = segment["text"]
# TODO: convert number tokenizer / symbols to phonetic words for alignment.
# e.g. "$300" -> "three hundred dollars"
# currently "$300" is ignored since no characters present in the phonetic dictionary
# split into words
if model_lang not in LANGUAGES_WITHOUT_SPACES:
per_word = transcription.split(" ")
else:
per_word = transcription
# first check that characters in transcription can be aligned (they are contained in align model"s dictionary)
clean_char, clean_cdx = [], []
for cdx, char in enumerate(transcription):
char_ = char.lower()
# wav2vec2 models use "|" character to represent spaces
if model_lang not in LANGUAGES_WITHOUT_SPACES:
char_ = char_.replace(" ", "|")
# ignore whitespace at beginning and end of transcript
if cdx < num_leading:
pass
elif cdx > len(transcription) - num_trailing - 1:
pass
elif char_ in model_dictionary.keys():
clean_char.append(char_)
clean_cdx.append(cdx)
clean_wdx = []
for wdx, wrd in enumerate(per_word):
if any([c in model_dictionary.keys() for c in wrd]):
clean_wdx.append(wdx)
# if no characters are in the dictionary, then we skip this segment...
if len(clean_char) == 0:
print("Failed to align segment: no characters in this segment found in model dictionary, resorting to original...")
break
transcription_cleaned = "".join(clean_char)
tokens = [model_dictionary[c] for c in transcription_cleaned]
# pad according original timestamps
t1 = max(segment["start"] - extend_duration, 0)
t2 = min(segment["end"] + extend_duration, MAX_DURATION)
# use prev_t2 as current t1 if it"s later
if start_from_previous and t1 < prev_t2:
t1 = prev_t2
# check if timestamp range is still valid
if t1 >= MAX_DURATION:
print("Failed to align segment: original start time longer than audio duration, skipping...")
continue
break
if t2 - t1 < 0.02:
print("Failed to align segment: duration smaller than 0.02s time precision")
continue
break
f1 = int(t1 * SAMPLE_RATE)
f2 = int(t2 * SAMPLE_RATE)
@ -332,130 +411,177 @@ def align(
emission = emissions[0].cpu().detach()
if "vad" in segment and len(segment['vad']) > 1 and '|' in model_dictionary:
ratio = waveform_segment.size(0) / emission.size(0)
space_idx = model_dictionary['|']
# find non-vad segments
for i in range(1, len(segment['vad'])):
start = segment['vad'][i-1][1]
end = segment['vad'][i][0]
if start < end: # check if there is a gap between intervals
non_vad_f1 = int(start / ratio)
non_vad_f2 = int(end / ratio)
# non-vad should be masked, use space to do so
emission[non_vad_f1:non_vad_f2, :] = float("-inf")
emission[non_vad_f1:non_vad_f2, space_idx] = 0
start = segment['vad'][i][1]
end = segment['end']
non_vad_f1 = int(start / ratio)
non_vad_f2 = int(end / ratio)
# non-vad should be masked, use space to do so
emission[non_vad_f1:non_vad_f2, :] = float("-inf")
emission[non_vad_f1:non_vad_f2, space_idx] = 0
transcription = segment['text'].strip()
if model_lang not in LANGUAGES_WITHOUT_SPACES:
t_words = transcription.split(' ')
else:
t_words = [c for c in transcription]
t_words_clean = [''.join([w for w in word if w.lower() in model_dictionary.keys()]) for word in t_words]
t_words_nonempty = [x for x in t_words_clean if x != ""]
t_words_nonempty_idx = [x for x in range(len(t_words_clean)) if t_words_clean[x] != ""]
segment['word-level'] = []
fail_fallback = False
if len(t_words_nonempty) > 0:
transcription_cleaned = "|".join(t_words_nonempty).lower()
tokens = [model_dictionary[c] for c in transcription_cleaned]
trellis = get_trellis(emission, tokens)
path = backtrack(trellis, emission, tokens)
if path is None:
print("Failed to align segment: backtrack failed, resorting to original...")
fail_fallback = True
else:
segments = merge_repeats(path, transcription_cleaned)
word_segments = merge_words(segments)
ratio = waveform_segment.size(0) / (trellis.size(0) - 1)
break
char_segments = merge_repeats(path, transcription_cleaned)
# word_segments = merge_words(char_segments)
# sub-segments
if "seg-text" not in segment:
segment["seg-text"] = [transcription]
v = 0
seg_lens = [0] + [len(x) for x in segment["seg-text"]]
seg_lens_cumsum = [v := v + n for n in seg_lens]
sub_seg_idx = 0
char_level = {
"start": [],
"end": [],
"score": [],
"word-index": [],
}
word_level = {
"start": [],
"end": [],
"score": [],
"segment-text-start": [],
"segment-text-end": []
}
wdx = 0
seg_start_actual, seg_end_actual = None, None
duration = t2 - t1
local = []
t_local = [None] * len(t_words)
for wdx, word in enumerate(word_segments):
t1_ = ratio * word.start
t2_ = ratio * word.end
local.append((t1_, t2_))
t_local[t_words_nonempty_idx[wdx]] = (t1_ * duration + t1, t2_ * duration + t1)
t1_actual = t1 + local[0][0] * duration
t2_actual = t1 + local[-1][1] * duration
ratio = duration * waveform_segment.size(0) / (trellis.size(0) - 1)
cdx_prev = 0
for cdx, char in enumerate(transcription + " "):
is_last = False
if cdx == len(transcription):
break
elif cdx+1 == len(transcription):
is_last = True
segment['start'] = t1_actual
segment['end'] = t2_actual
prev_t2 = segment['end']
# for the .ass output
for x in range(len(t_local)):
curr_word = t_words[x]
curr_timestamp = t_local[x]
if curr_timestamp is not None:
segment['word-level'].append({"text": curr_word, "start": curr_timestamp[0], "end": curr_timestamp[1]})
start, end, score = None, None, None
if cdx in clean_cdx:
char_seg = char_segments[clean_cdx.index(cdx)]
start = char_seg.start * ratio + t1
end = char_seg.end * ratio + t1
score = char_seg.score
char_level["start"].append(start)
char_level["end"].append(end)
char_level["score"].append(score)
char_level["word-index"].append(wdx)
# word-level info
if model_lang in LANGUAGES_WITHOUT_SPACES:
# character == word
wdx += 1
elif is_last or transcription[cdx+1] == " " or cdx == seg_lens_cumsum[sub_seg_idx+1] - 1:
wdx += 1
word_level["start"].append(None)
word_level["end"].append(None)
word_level["score"].append(None)
word_level["segment-text-start"].append(cdx_prev-seg_lens_cumsum[sub_seg_idx])
word_level["segment-text-end"].append(cdx+1-seg_lens_cumsum[sub_seg_idx])
cdx_prev = cdx+2
if is_last or cdx == seg_lens_cumsum[sub_seg_idx+1] - 1:
if model_lang not in LANGUAGES_WITHOUT_SPACES:
char_level = pd.DataFrame(char_level)
word_level = pd.DataFrame(word_level)
not_space = pd.Series(list(segment["seg-text"][sub_seg_idx])) != " "
word_level["start"] = char_level[not_space].groupby("word-index")["start"].min() # take min of all chars in a word ignoring space
word_level["end"] = char_level[not_space].groupby("word-index")["end"].max() # take max of all chars in a word
# fill missing
if interpolate_method != "ignore":
word_level["start"] = interpolate_nans(word_level["start"], method=interpolate_method)
word_level["end"] = interpolate_nans(word_level["end"], method=interpolate_method)
word_level["start"] = word_level["start"].values.tolist()
word_level["end"] = word_level["end"].values.tolist()
word_level["score"] = char_level.groupby("word-index")["score"].mean() # take mean of all scores
char_level = char_level.replace({np.nan:None}).to_dict("list")
word_level = pd.DataFrame(word_level).replace({np.nan:None}).to_dict("list")
else:
segment['word-level'].append({"text": curr_word, "start": None, "end": None})
word_level = None
# for per-word .srt ouput
# merge missing words to previous, or merge with next word ahead if idx == 0
found_first_ts = False
for x in range(len(t_local)):
curr_word = t_words[x]
curr_timestamp = t_local[x]
if curr_timestamp is not None:
word_segments_list.append({"text": curr_word, "start": curr_timestamp[0], "end": curr_timestamp[1]})
found_first_ts = True
elif not drop_non_aligned_words:
# then we merge
if not found_first_ts:
t_words[x+1] = " ".join([curr_word, t_words[x+1]])
else:
word_segments_list[-1]['text'] += ' ' + curr_word
else:
fail_fallback = True
aligned_segments.append(
{
"text": segment["seg-text"][sub_seg_idx],
"start": seg_start_actual,
"end": seg_end_actual,
"char-segments": char_level,
"word-segments": word_level
}
)
if "language" in segment:
aligned_segments[-1]["language"] = segment["language"]
if fail_fallback:
# then we resort back to original whisper timestamps
# segment['start] and segment['end'] are unchanged
print(f"[{format_timestamp(aligned_segments[-1]['start'])} --> {format_timestamp(aligned_segments[-1]['end'])}] {aligned_segments[-1]['text']}")
char_level = {
"start": [],
"end": [],
"score": [],
"word-index": [],
}
word_level = {
"start": [],
"end": [],
"score": [],
"segment-text-start": [],
"segment-text-end": []
}
wdx = 0
cdx_prev = cdx + 2
sub_seg_idx += 1
seg_start_actual, seg_end_actual = None, None
# take min-max for actual segment-level timestamp
if seg_start_actual is None and start is not None:
seg_start_actual = start
if end is not None:
seg_end_actual = end
prev_t2 = segment["end"]
segment_align_success = True
# end while True loop
break
# reset prev_t2 due to drifting issues
if not segment_align_success:
prev_t2 = 0
segment['word-level'].append({"text": segment['text'], "start": segment['start'], "end":segment['end']})
word_segments_list.append({"text": segment['text'], "start": segment['start'], "end":segment['end']})
if 'vad' in segment:
curr_vdx = 0
curr_text = ''
for wrd_seg in word_segments_list:
if wrd_seg['start'] > segment['vad'][curr_vdx][1]:
curr_speaker = segment['speakers'][curr_vdx]
vad_segments_list.append(
{'start': segment['vad'][curr_vdx][0],
'end': segment['vad'][curr_vdx][1],
'text': f"[{curr_speaker}]: " + curr_text.strip()}
# shift segment index by amount of sub-segments
if "seg-text" in segment:
sdx += len(segment["seg-text"])
else:
sdx += 1
# create word level segments for .srt
word_seg = []
for seg in aligned_segments:
if model_lang in LANGUAGES_WITHOUT_SPACES:
# character based
seg["word-segments"] = seg["char-segments"]
seg["word-segments"]["segment-text-start"] = range(len(seg['word-segments']['start']))
seg["word-segments"]["segment-text-end"] = range(1, len(seg['word-segments']['start'])+1)
wseg = pd.DataFrame(seg["word-segments"]).replace({np.nan:None})
for wdx, wrow in wseg.iterrows():
if wrow["start"] is not None:
word_seg.append(
{
"start": wrow["start"],
"end": wrow["end"],
"text": seg["text"][int(wrow["segment-text-start"]):int(wrow["segment-text-end"])]
}
)
curr_vdx += 1
curr_text = ''
curr_text += ' ' + wrd_seg['text']
if len(curr_text) > 0:
curr_speaker = segment['speakers'][curr_vdx]
vad_segments_list.append(
{'start': segment['vad'][curr_vdx][0],
'end': segment['vad'][curr_vdx][1],
'text': f"[{curr_speaker}]: " + curr_text.strip()}
)
curr_text = ''
total_word_segments_list += word_segments_list
print(f"[{format_timestamp(segment['start'])} --> {format_timestamp(segment['end'])}] {segment['text']}")
return {"segments": transcript, "word_segments": total_word_segments_list, "vad_segments": vad_segments_list}
return {"segments": aligned_segments, "word_segments": word_seg}
def load_align_model(language_code, device, model_name=None):
if model_name is None:
@ -492,11 +618,11 @@ def load_align_model(language_code, device, model_name=None):
return align_model, align_metadata
def merge_chunks(segments, chunk_size=CHUNK_LENGTH, speakers=False):
'''
Merge VAD segments into larger segments of size ~CHUNK_LENGTH.
'''
def merge_chunks(segments, chunk_size=CHUNK_LENGTH):
"""
Merge VAD segments into larger segments of size ~CHUNK_LENGTH.
"""
curr_start = 0
curr_end = 0
merged_segments = []
@ -508,7 +634,6 @@ def merge_chunks(segments, chunk_size=CHUNK_LENGTH, speakers=False):
"start": curr_start,
"end": curr_end,
"segments": seg_idxs,
"speakers": speaker_idxs,
})
curr_start = seg.start
seg_idxs = []
@ -521,55 +646,107 @@ def merge_chunks(segments, chunk_size=CHUNK_LENGTH, speakers=False):
"start": curr_start,
"end": curr_end,
"segments": seg_idxs,
"speakers": speaker_idxs
})
return merged_segments
def transcribe_segments(
def transcribe_with_vad(
model: "Whisper",
audio: Union[str, np.ndarray, torch.Tensor],
merged_segments,
vad_pipeline,
mel = None,
verbose: Optional[bool] = None,
**kwargs
):
'''
Transcribe according to predefined VAD segments.
'''
"""
Transcribe per VAD segment
"""
if mel is None:
mel = log_mel_spectrogram(audio)
prev = 0
output = {"segments": []}
output = {'segments': []}
vad_segments_list = []
vad_segments = vad_pipeline(audio)
for speech_turn in vad_segments.get_timeline().support():
vad_segments_list.append(Segment(speech_turn.start, speech_turn.end, "UNKNOWN"))
# merge segments to approx 30s inputs to make whisper most appropraite
vad_segments = merge_chunks(vad_segments_list)
for sdx, seg_t in enumerate(merged_segments):
print(sdx, seg_t['start'], seg_t['end'], '...')
seg_f_start, seg_f_end = int(seg_t['start'] * SAMPLE_RATE / HOP_LENGTH), int(seg_t['end'] * SAMPLE_RATE / HOP_LENGTH)
for sdx, seg_t in enumerate(vad_segments):
if verbose:
print(f"~~ Transcribing VAD chunk: ({format_timestamp(seg_t['start'])} --> {format_timestamp(seg_t['end'])}) ~~")
seg_f_start, seg_f_end = int(seg_t["start"] * SAMPLE_RATE / HOP_LENGTH), int(seg_t["end"] * SAMPLE_RATE / HOP_LENGTH)
local_f_start, local_f_end = seg_f_start - prev, seg_f_end - prev
mel = mel[:, local_f_start:] # seek forward
prev = seg_f_start
local_mel = mel[:, :local_f_end-local_f_start]
result = transcribe(model, audio, mel=local_mel, **kwargs)
seg_t['text'] = result['text']
output['segments'].append(
result = transcribe(model, audio, mel=local_mel, verbose=verbose, **kwargs)
seg_t["text"] = result["text"]
output["segments"].append(
{
'start': seg_t['start'],
'end': seg_t['end'],
'language': result['language'],
'text': result['text'],
'seg-text': [x['text'] for x in result['segments']],
'seg-start': [x['start'] for x in result['segments']],
'seg-end': [x['end'] for x in result['segments']],
"start": seg_t["start"],
"end": seg_t["end"],
"language": result["language"],
"text": result["text"],
"seg-text": [x["text"] for x in result["segments"]],
"seg-start": [x["start"] for x in result["segments"]],
"seg-end": [x["end"] for x in result["segments"]],
}
)
output['language'] = output['segments'][0]['language']
output["language"] = output["segments"][0]["language"]
return output
def assign_word_speakers(diarize_df, result_segments, fill_nearest=False):
for seg in result_segments:
wdf = pd.DataFrame(seg['word-segments'])
if len(wdf['start'].dropna()) == 0:
wdf['start'] = seg['start']
wdf['end'] = seg['end']
speakers = []
for wdx, wrow in wdf.iterrows():
diarize_df['intersection'] = np.minimum(diarize_df['end'], wrow['end']) - np.maximum(diarize_df['start'], wrow['start'])
diarize_df['union'] = np.maximum(diarize_df['end'], wrow['end']) - np.minimum(diarize_df['start'], wrow['start'])
# remove no hit
if not fill_nearest:
dia_tmp = diarize_df[diarize_df['intersection'] > 0]
else:
dia_tmp = diarize_df
if len(dia_tmp) == 0:
speaker = None
else:
speaker = dia_tmp.sort_values("intersection", ascending=False).iloc[0][2]
speakers.append(speaker)
seg['word-segments']['speaker'] = speakers
seg["speaker"] = pd.Series(speakers).value_counts().index[0]
# create word level segments for .srt
word_seg = []
for seg in result_segments:
wseg = pd.DataFrame(seg["word-segments"])
for wdx, wrow in wseg.iterrows():
if wrow["start"] is not None:
speaker = wrow['speaker']
if speaker is None or speaker == np.nan:
speaker = "UNKNOWN"
word_seg.append(
{
"start": wrow["start"],
"end": wrow["end"],
"text": f"[{speaker}]: " + seg["text"][int(wrow["segment-text-start"]):int(wrow["segment-text-end"])]
}
)
# TODO: create segments but split words on new speaker
return result_segments, word_seg
class Segment:
def __init__(self, start, end, speaker=None):
self.start = start
@ -589,11 +766,17 @@ def cli():
parser.add_argument("--align_model", default=None, help="Name of phoneme-level ASR model to do alignment")
parser.add_argument("--align_extend", default=2, type=float, help="Seconds before and after to extend the whisper segments for alignment")
parser.add_argument("--align_from_prev", default=True, type=bool, help="Whether to clip the alignment start time of current segment to the end time of the last aligned word of the previous segment")
parser.add_argument("--drop_non_aligned", action="store_true", help="For word .srt, whether to drop non aliged words, or merge them into neighbouring.")
parser.add_argument("--vad_filter", action="store_true", help="Whether to first perform VAD filtering to target only transcribe within VAD...")
parser.add_argument("--interpolate_method", default="nearest", choices=["nearest", "linear", "ignore"], help="For word .srt, method to assign timestamps to non-aligned words, or merge them into neighbouring.")
# vad params
parser.add_argument("--vad_filter", action="store_true", help="Whether to first perform VAD filtering to target only transcribe within VAD. Produces more accurate alignment + timestamp, requires more GPU memory & compute.")
parser.add_argument("--vad_input", default=None, type=str)
# diarization params
parser.add_argument("--diarize", action='store_true')
parser.add_argument("--min_speakers", default=None, type=int)
parser.add_argument("--max_speakers", default=None, type=int)
# output save params
parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
parser.add_argument("--output_type", default="srt", choices=['all', 'srt', 'vtt', 'txt'], help="File type for desired output save")
parser.add_argument("--output_type", default="all", choices=["all", "srt", "srt-word", "vtt", "txt", "tsv", "ass", "ass-char"], help="File type for desired output save")
parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")
@ -627,24 +810,32 @@ def cli():
align_model: str = args.pop("align_model")
align_extend: float = args.pop("align_extend")
align_from_prev: bool = args.pop("align_from_prev")
drop_non_aligned: bool = args.pop("drop_non_aligned")
interpolate_method: bool = args.pop("interpolate_method")
vad_filter: bool = args.pop("vad_filter")
vad_input: bool = args.pop("vad_input")
diarize: bool = args.pop("diarize")
min_speakers: int = args.pop("min_speakers")
max_speakers: int = args.pop("max_speakers")
vad_pipeline = None
if vad_input is not None:
vad_input = pd.read_csv(vad_input, header=None, sep= " ")
elif vad_filter:
from pyannote.audio import Pipeline
vad_pipeline = Pipeline.from_pretrained("pyannote/voice-activity-detection")
# vad_pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization@2.1")
diarize_pipeline = None
if diarize:
from pyannote.audio import Pipeline
diarize_pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization@2.1")
os.makedirs(output_dir, exist_ok=True)
if model_name.endswith(".en") and args["language"] not in {"en", "English"}:
if args["language"] is not None:
warnings.warn(f"{model_name} is an English-only model but receipted '{args['language']}'; using English instead.")
warnings.warn(f'{model_name} is an English-only model but receipted "{args["language"]}"; using English instead.')
args["language"] = "en"
temperature = args.pop("temperature")
@ -665,24 +856,10 @@ def cli():
align_model, align_metadata = load_align_model(align_language, device, model_name=align_model)
for audio_path in args.pop("audio"):
if vad_filter or vad_input is not None:
output_segments = []
if vad_filter:
print("Performing VAD...")
# vad_segments = vad_pipeline(audio_path)
# for speech_turn, track, speaker in vad_segments.itertracks(yield_label=True):
# output_segments.append(Segment(speech_turn.start, speech_turn.end, speaker))
vad_segments = vad_pipeline(audio_path)
for speech_turn in vad_segments.get_timeline().support():
output_segments.append(Segment(speech_turn.start, speech_turn.end, "UNKNOWN"))
elif vad_input is not None:
# rttm format
for idx, row in vad_input.iterrows():
output_segments.append(Segment(row[3], row[3]+row[4], f"SPEAKER {row[7]}"))
vad_segments = merge_chunks(output_segments)
result = transcribe_segments(model, audio_path, merged_segments=vad_segments, temperature=temperature, **args)
result = transcribe_with_vad(model, audio_path, vad_pipeline, temperature=temperature, **args)
else:
vad_segments = None
print("Performing transcription...")
result = transcribe(model, audio_path, temperature=temperature, **args)
@ -693,9 +870,20 @@ def cli():
print("Performing alignment...")
result_aligned = align(result["segments"], align_model, align_metadata, audio_path, device,
extend_duration=align_extend, start_from_previous=align_from_prev, drop_non_aligned_words=drop_non_aligned)
extend_duration=align_extend, start_from_previous=align_from_prev, interpolate_method=interpolate_method)
audio_basename = os.path.basename(audio_path)
if diarize:
print("Performing diarization...")
diarize_segments = diarize_pipeline(audio_path, min_speakers=min_speakers, max_speakers=max_speakers)
diarize_df = pd.DataFrame(diarize_segments.itertracks(yield_label=True))
diarize_df['start'] = diarize_df[0].apply(lambda x: x.start)
diarize_df['end'] = diarize_df[0].apply(lambda x: x.end)
# assumes each utterance is single speaker (needs fix)
result_segments, word_segments = assign_word_speakers(diarize_df, result_aligned["segments"], fill_nearest=True)
result_aligned["segments"] = result_segments
result_aligned["word_segments"] = word_segments
# save TXT
if output_type in ["txt", "all"]:
with open(os.path.join(output_dir, audio_basename + ".txt"), "w", encoding="utf-8") as txt:
@ -711,19 +899,27 @@ def cli():
with open(os.path.join(output_dir, audio_basename + ".srt"), "w", encoding="utf-8") as srt:
write_srt(result_aligned["segments"], file=srt)
# save TSV
if output_type in ["tsv", "all"]:
with open(os.path.join(output_dir, audio_basename + ".srt"), "w", encoding="utf-8") as srt:
write_tsv(result_aligned["segments"], file=srt)
# save SRT word-level
if output_type in ["srt-word", "all"]:
# save per-word SRT
with open(os.path.join(output_dir, audio_basename + ".word.srt"), "w", encoding="utf-8") as srt:
write_srt(result_aligned["word_segments"], file=srt)
# save ASS
if output_type in ["ass", "all"]:
with open(os.path.join(output_dir, audio_basename + ".ass"), "w", encoding="utf-8") as ass:
write_ass(result_aligned["segments"], file=ass)
if vad_filter is not None:
# save per-word SRT
with open(os.path.join(output_dir, audio_basename + ".vad.srt"), "w", encoding="utf-8") as srt:
write_srt(result_aligned["vad_segments"], file=srt)
# save ASS character-level
if output_type in ["ass-char", "all"]:
with open(os.path.join(output_dir, audio_basename + ".char.ass"), "w", encoding="utf-8") as ass:
write_ass(result_aligned["segments"], file=ass, resolution="char")
if __name__ == '__main__':
if __name__ == "__main__":
cli()

View File

@ -1,6 +1,7 @@
import os
import zlib
from typing import Iterator, TextIO, Tuple, List
from typing import Callable, TextIO, Iterator, Tuple
import pandas as pd
def exact_div(x, y):
assert x % y == 0
@ -60,6 +61,13 @@ def write_vtt(transcript: Iterator[dict], file: TextIO):
flush=True,
)
def write_tsv(transcript: Iterator[dict], file: TextIO):
print("start", "end", "text", sep="\t", file=file)
for segment in transcript:
print(round(1000 * segment['start']), file=file, end="\t")
print(round(1000 * segment['end']), file=file, end="\t")
print(segment['text'].strip().replace("\t", " "), file=file, flush=True)
def write_srt(transcript: Iterator[dict], file: TextIO):
"""
@ -88,7 +96,9 @@ def write_srt(transcript: Iterator[dict], file: TextIO):
)
def write_ass(transcript: Iterator[dict], file: TextIO,
def write_ass(transcript: Iterator[dict],
file: TextIO,
resolution: str = "word",
color: str = None, underline=True,
prefmt: str = None, suffmt: str = None,
font: str = None, font_size: int = 24,
@ -102,10 +112,12 @@ def write_ass(transcript: Iterator[dict], file: TextIO,
Note: ass file is used in the same way as srt, vtt, etc.
Parameters
----------
res: dict
transcript: dict
results from modified model
ass_path: str
output path (e.g. caption.ass)
file: TextIO
file object to write to
resolution: str
"word" or "char", timestamp resolution to highlight.
color: str
color code for a word at its corresponding timestamp
<bbggrr> reverse order hexadecimal RGB value (e.g. FF0000 is full intensity blue. Default: 00FF00)
@ -176,49 +188,67 @@ def write_ass(transcript: Iterator[dict], file: TextIO,
return f'{hh:0>1.0f}:{mm:0>2.0f}:{ss:0>2.2f}'
def dialogue(words: List[str], idx, start, end) -> str:
text = ''.join(f' {prefmt}{word}{suffmt}'
# if not word.startswith(' ') or word == ' ' else
# f' {prefmt}{word.strip()}{suffmt}')
if curr_idx == idx else
f' {word}'
for curr_idx, word in enumerate(words))
def dialogue(chars: str, start: float, end: float, idx_0: int, idx_1: int) -> str:
if idx_0 == -1:
text = chars
else:
text = f'{chars[:idx_0]}{prefmt}{chars[idx_0:idx_1]}{suffmt}{chars[idx_1:]}'
return f"Dialogue: 0,{secs_to_hhmmss(start)},{secs_to_hhmmss(end)}," \
f"Default,,0,0,0,,{text.strip() if strip else text}"
if resolution == "word":
resolution_key = "word-segments"
elif resolution == "char":
resolution_key = "char-segments"
else:
raise ValueError(".ass resolution should be 'word' or 'char', not ", resolution)
ass_arr = []
for segment in transcript:
curr_words = [wrd['text'] for wrd in segment['word-level']]
prev = segment['word-level'][0]['start']
if prev is None:
if resolution_key in segment:
res_segs = pd.DataFrame(segment[resolution_key])
prev = segment['start']
for wdx, word in enumerate(segment['word-level']):
if word['start'] is not None:
# fill gap between previous word
if word['start'] > prev:
if "speaker" in segment:
speaker_str = f"[{segment['speaker']}]: "
else:
speaker_str = ""
for cdx, crow in res_segs.iterrows():
if crow['start'] is not None:
if resolution == "char":
idx_0 = cdx
idx_1 = cdx + 1
elif resolution == "word":
idx_0 = int(crow["segment-text-start"])
idx_1 = int(crow["segment-text-end"])
# fill gap
if crow['start'] > prev:
filler_ts = {
"words": curr_words,
"chars": speaker_str + segment['text'],
"start": prev,
"end": word['start'],
"idx": -1
"end": crow['start'],
"idx_0": -1,
"idx_1": -1
}
ass_arr.append(filler_ts)
ass_arr.append(filler_ts)
# highlight current word
f_word_ts = {
"words": curr_words,
"start": word['start'],
"end": word['end'],
"idx": wdx
"chars": speaker_str + segment['text'],
"start": crow['start'],
"end": crow['end'],
"idx_0": idx_0 + len(speaker_str),
"idx_1": idx_1 + len(speaker_str)
}
ass_arr.append(f_word_ts)
prev = word['end']
prev = crow['end']
ass_str += '\n'.join(map(lambda x: dialogue(**x), ass_arr))
file.write(ass_str)
def interpolate_nans(x, method='nearest'):
if x.notnull().sum() > 1:
return x.interpolate(method=method).ffill().bfill()
else:
return x.ffill().bfill()