New logic, diarization, VAD filtering

This commit is contained in:
Max Bain
2023-01-24 15:02:08 +00:00
parent ba102feb7f
commit d395c21b83
8 changed files with 498 additions and 260 deletions

View File

@ -48,6 +48,13 @@ This repository refines the timestamps of openAI's Whisper model via forced alig
**Forced Alignment** refers to the process by which orthographic transcriptions are aligned to audio recordings to automatically generate phone level segmentation.
<h2 align="left", id="highlights">New🚨</h2>
- VAD filtering: Voice Activity Detection (VAD) from [Pyannote.audio](https://huggingface.co/pyannote/voice-activity-detection) is used as a preprocessing step, removing the reliance on Whisper timestamps and transcribing only the audio segments that contain speech. Add the `--vad_filter` flag; this increases timestamp accuracy and robustness (but requires more GPU memory, since 30s inputs are passed to wav2vec2)
- Character-level timestamps (see `*.char.ass` file output)
- Diarization (still in beta, add `--diarize`); a Python usage sketch follows this list
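A rough Python sketch of how the new pieces fit together, using functions added in this commit (`transcribe_with_vad`, `load_align_model`, `align`, `assign_word_speakers`); treat it as an illustration rather than a stable API, and note that the pyannote pipelines may require a Hugging Face access token:

import whisperx
import pandas as pd
from pyannote.audio import Pipeline
from whisperx.transcribe import assign_word_speakers  # not re-exported at package level in this commit

device = "cuda"
model = whisperx.load_model("large", device)

# 1. VAD-filtered transcription: only detected speech regions (merged to ~30s chunks) are sent to whisper
vad_pipeline = Pipeline.from_pretrained("pyannote/voice-activity-detection")
result = whisperx.transcribe_with_vad(model, "examples/sample01.wav", vad_pipeline)

# 2. forced alignment with wav2vec2 for word- and character-level timestamps
align_model, metadata = whisperx.load_align_model(result["language"], device)
result_aligned = whisperx.align(result["segments"], align_model, metadata,
                                "examples/sample01.wav", device, interpolate_method="nearest")

# 3. diarization (beta): assign a speaker to each aligned word
diarize_pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization@2.1")
diarize_segments = diarize_pipeline("examples/sample01.wav")
diarize_df = pd.DataFrame(diarize_segments.itertracks(yield_label=True))
diarize_df["start"] = diarize_df[0].apply(lambda x: x.start)
diarize_df["end"] = diarize_df[0].apply(lambda x: x.end)
segments, word_segments = assign_word_speakers(diarize_df, result_aligned["segments"], fill_nearest=True)

The CLI equivalents are `--vad_filter` and `--diarize` (with optional `--min_speakers` / `--max_speakers`).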
<h2 align="left" id="setup">Setup ⚙️</h2> <h2 align="left" id="setup">Setup ⚙️</h2>
Install this package using
@ -76,9 +83,9 @@ Run whisper on example segment (using default params)
whisperx examples/sample01.wav
For increased timestamp accuracy, at the cost of higher gpu mem, use bigger models e.g. For increased timestamp accuracy, at the cost of higher gpu mem, use bigger models and VAD filtering e.g.
whisperx examples/sample01.wav --model large.en --align_model WAV2VEC2_ASR_LARGE_LV60K_960H whisperx examples/sample01.wav --model large.en --vad_filter --align_model WAV2VEC2_ASR_LARGE_LV60K_960H
Result using *WhisperX* with forced alignment to wav2vec2.0 large:
@ -162,7 +169,11 @@ The next major upgrade we are working on is whisper with speaker diarization, so
[x] ~~Python usage~~ done
[ ] Incorporating word-level speaker diarization [x] ~~Character level timestamps~~
[x] ~~Incorporating speaker diarization~~
[ ] Improve diarization (word level)
[ ] Inference speedup with batch processing

View File

@ -6,3 +6,4 @@ soundfile
more-itertools
transformers>=4.19.0
ffmpeg-python==0.2.0
pyannote.audio

View File

@ -11,7 +11,7 @@ from tqdm import tqdm
from .audio import load_audio, log_mel_spectrogram, pad_or_trim from .audio import load_audio, log_mel_spectrogram, pad_or_trim
from .decoding import DecodingOptions, DecodingResult, decode, detect_language from .decoding import DecodingOptions, DecodingResult, decode, detect_language
from .model import Whisper, ModelDimensions from .model import Whisper, ModelDimensions
from .transcribe import transcribe, load_align_model, align from .transcribe import transcribe, load_align_model, align, transcribe_with_vad
_MODELS = { _MODELS = {
"tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt", "tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",

View File

@ -113,7 +113,7 @@ def log_mel_spectrogram(audio: Union[str, np.ndarray, torch.Tensor], n_mels: int
window = torch.hann_window(N_FFT).to(audio.device) window = torch.hann_window(N_FFT).to(audio.device)
stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True) stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True)
magnitudes = stft[:, :-1].abs() ** 2 magnitudes = stft[..., :-1].abs() ** 2
filters = mel_filters(audio.device, n_mels) filters = mel_filters(audio.device, n_mels)
mel_spec = filters @ magnitudes mel_spec = filters @ magnitudes
@ -121,4 +121,4 @@ def log_mel_spectrogram(audio: Union[str, np.ndarray, torch.Tensor], n_mels: int
log_spec = torch.clamp(mel_spec, min=1e-10).log10() log_spec = torch.clamp(mel_spec, min=1e-10).log10()
log_spec = torch.maximum(log_spec, log_spec.max() - 8.0) log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
log_spec = (log_spec + 4.0) / 4.0 log_spec = (log_spec + 4.0) / 4.0
return log_spec return log_spec

View File

@ -82,8 +82,8 @@ class MultiHeadAttention(nn.Module):
k = kv_cache[self.key] k = kv_cache[self.key]
v = kv_cache[self.value] v = kv_cache[self.value]
wv = self.qkv_attention(q, k, v, mask) wv, qk = self.qkv_attention(q, k, v, mask)
return self.out(wv) return self.out(wv), qk
def qkv_attention(self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None): def qkv_attention(self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None):
n_batch, n_ctx, n_state = q.shape n_batch, n_ctx, n_state = q.shape
@ -95,9 +95,10 @@ class MultiHeadAttention(nn.Module):
qk = q @ k qk = q @ k
if mask is not None: if mask is not None:
qk = qk + mask[:n_ctx, :n_ctx] qk = qk + mask[:n_ctx, :n_ctx]
qk = qk.float()
w = F.softmax(qk.float(), dim=-1).to(q.dtype) w = F.softmax(qk, dim=-1).to(q.dtype)
return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2) return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach()
class ResidualAttentionBlock(nn.Module): class ResidualAttentionBlock(nn.Module):
@ -121,9 +122,9 @@ class ResidualAttentionBlock(nn.Module):
mask: Optional[Tensor] = None, mask: Optional[Tensor] = None,
kv_cache: Optional[dict] = None, kv_cache: Optional[dict] = None,
): ):
x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache) x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)[0]
if self.cross_attn: if self.cross_attn:
x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache) x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)[0]
x = x + self.mlp(self.mlp_ln(x)) x = x + self.mlp(self.mlp_ln(x))
return x return x
@ -264,4 +265,4 @@ class Whisper(nn.Module):
detect_language = detect_language_function detect_language = detect_language_function
transcribe = transcribe_function transcribe = transcribe_function
decode = decode_function decode = decode_function

View File

@ -1737,6 +1737,5 @@
"yoghurt": "yogurt", "yoghurt": "yogurt",
"yoghurts": "yogurts", "yoghurts": "yogurts",
"mhm": "hmm", "mhm": "hmm",
"mm": "hmm",
"mmm": "hmm" "mmm": "hmm"
} }

View File

@ -12,7 +12,7 @@ from .audio import SAMPLE_RATE, N_FRAMES, HOP_LENGTH, CHUNK_LENGTH, pad_or_trim,
from .alignment import get_trellis, backtrack, merge_repeats, merge_words from .alignment import get_trellis, backtrack, merge_repeats, merge_words
from .decoding import DecodingOptions, DecodingResult from .decoding import DecodingOptions, DecodingResult
from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
from .utils import exact_div, format_timestamp, optional_int, optional_float, str2bool, write_txt, write_vtt, write_srt, write_ass from .utils import exact_div, format_timestamp, optional_int, optional_float, str2bool, interpolate_nans, write_txt, write_vtt, write_srt, write_ass, write_tsv
import pandas as pd import pandas as pd
if TYPE_CHECKING: if TYPE_CHECKING:
@ -280,8 +280,39 @@ def align(
device: str, device: str,
extend_duration: float = 0.0, extend_duration: float = 0.0,
start_from_previous: bool = True, start_from_previous: bool = True,
drop_non_aligned_words: bool = False, interpolate_method: str = "nearest",
): ):
"""
Force align phoneme recognition predictions to known transcription
Parameters
----------
transcript: Iterator[dict]
The Whisper model instance
model: torch.nn.Module
Alignment model (wav2vec2)
audio: Union[str, np.ndarray, torch.Tensor]
The path to the audio file to open, or the audio waveform
device: str
cuda device
extend_duration: float
Amount to pad input segments by. If not using vad--filter then recommended to use 2 seconds
If the gzip compression ratio is above this value, treat as failed
interpolate_method: str ["nearest", "linear", "ignore"]
Method to assign timestamps to non-aligned words. Words are not able to be aligned when none of the characters occur in the align model dictionary.
"nearest" copies timestamp of nearest word within the segment. "linear" is linear interpolation. "drop" removes that word from output.
Returns
-------
A dictionary containing the resulting text ("text") and segment-level details ("segments"), and
the spoken language ("language"), which is detected when `decode_options["language"]` is None.
"""
if not torch.is_tensor(audio): if not torch.is_tensor(audio):
if isinstance(audio, str): if isinstance(audio, str):
audio = load_audio(audio) audio = load_audio(audio)
@ -291,171 +322,266 @@ def align(
MAX_DURATION = audio.shape[1] / SAMPLE_RATE MAX_DURATION = audio.shape[1] / SAMPLE_RATE
model_dictionary = align_model_metadata['dictionary'] model_dictionary = align_model_metadata["dictionary"]
model_lang = align_model_metadata['language'] model_lang = align_model_metadata["language"]
model_type = align_model_metadata['type'] model_type = align_model_metadata["type"]
aligned_segments = []
prev_t2 = 0 prev_t2 = 0
total_word_segments_list = [] sdx = 0
vad_segments_list = [] for segment in transcript:
for idx, segment in enumerate(transcript): while True:
word_segments_list = [] segment_align_success = False
# first we pad
t1 = max(segment['start'] - extend_duration, 0)
t2 = min(segment['end'] + extend_duration, MAX_DURATION)
# use prev_t2 as current t1 if it's later # strip spaces at beginning / end, but keep track of the amount.
if start_from_previous and t1 < prev_t2: num_leading = len(segment["text"]) - len(segment["text"].lstrip())
t1 = prev_t2 num_trailing = len(segment["text"]) - len(segment["text"].rstrip())
transcription = segment["text"]
# check if timestamp range is still valid # TODO: convert number tokenizer / symbols to phonetic words for alignment.
if t1 >= MAX_DURATION: # e.g. "$300" -> "three hundred dollars"
print("Failed to align segment: original start time longer than audio duration, skipping...") # currently "$300" is ignored since no characters present in the phonetic dictionary
continue
if t2 - t1 < 0.02:
print("Failed to align segment: duration smaller than 0.02s time precision")
continue
f1 = int(t1 * SAMPLE_RATE) # split into words
f2 = int(t2 * SAMPLE_RATE) if model_lang not in LANGUAGES_WITHOUT_SPACES:
per_word = transcription.split(" ")
waveform_segment = audio[:, f1:f2]
with torch.inference_mode():
if model_type == "torchaudio":
emissions, _ = model(waveform_segment.to(device))
elif model_type == "huggingface":
emissions = model(waveform_segment.to(device)).logits
else: else:
raise NotImplementedError(f"Align model of type {model_type} not supported.") per_word = transcription
emissions = torch.log_softmax(emissions, dim=-1)
emission = emissions[0].cpu().detach() # first check that characters in transcription can be aligned (they are contained in the align model's dictionary)
clean_char, clean_cdx = [], []
for cdx, char in enumerate(transcription):
char_ = char.lower()
# wav2vec2 models use "|" character to represent spaces
if model_lang not in LANGUAGES_WITHOUT_SPACES:
char_ = char_.replace(" ", "|")
# ignore whitespace at beginning and end of transcript
if cdx < num_leading:
pass
elif cdx > len(transcription) - num_trailing - 1:
pass
elif char_ in model_dictionary.keys():
clean_char.append(char_)
clean_cdx.append(cdx)
if "vad" in segment and len(segment['vad']) > 1 and '|' in model_dictionary: clean_wdx = []
ratio = waveform_segment.size(0) / emission.size(0) for wdx, wrd in enumerate(per_word):
space_idx = model_dictionary['|'] if any([c in model_dictionary.keys() for c in wrd]):
# find non-vad segments clean_wdx.append(wdx)
for i in range(1, len(segment['vad'])):
start = segment['vad'][i-1][1]
end = segment['vad'][i][0]
if start < end: # check if there is a gap between intervals
non_vad_f1 = int(start / ratio)
non_vad_f2 = int(end / ratio)
# non-vad should be masked, use space to do so
emission[non_vad_f1:non_vad_f2, :] = float("-inf")
emission[non_vad_f1:non_vad_f2, space_idx] = 0
# if no characters are in the dictionary, then we skip this segment...
start = segment['vad'][i][1] if len(clean_char) == 0:
end = segment['end'] print("Failed to align segment: no characters in this segment found in model dictionary, resorting to original...")
non_vad_f1 = int(start / ratio) break
non_vad_f2 = int(end / ratio)
# non-vad should be masked, use space to do so transcription_cleaned = "".join(clean_char)
emission[non_vad_f1:non_vad_f2, :] = float("-inf")
emission[non_vad_f1:non_vad_f2, space_idx] = 0
transcription = segment['text'].strip()
if model_lang not in LANGUAGES_WITHOUT_SPACES:
t_words = transcription.split(' ')
else:
t_words = [c for c in transcription]
t_words_clean = [''.join([w for w in word if w.lower() in model_dictionary.keys()]) for word in t_words]
t_words_nonempty = [x for x in t_words_clean if x != ""]
t_words_nonempty_idx = [x for x in range(len(t_words_clean)) if t_words_clean[x] != ""]
segment['word-level'] = []
fail_fallback = False
if len(t_words_nonempty) > 0:
transcription_cleaned = "|".join(t_words_nonempty).lower()
tokens = [model_dictionary[c] for c in transcription_cleaned] tokens = [model_dictionary[c] for c in transcription_cleaned]
# pad according original timestamps
t1 = max(segment["start"] - extend_duration, 0)
t2 = min(segment["end"] + extend_duration, MAX_DURATION)
# use prev_t2 as current t1 if it's later
if start_from_previous and t1 < prev_t2:
t1 = prev_t2
# check if timestamp range is still valid
if t1 >= MAX_DURATION:
print("Failed to align segment: original start time longer than audio duration, skipping...")
break
if t2 - t1 < 0.02:
print("Failed to align segment: duration smaller than 0.02s time precision")
break
f1 = int(t1 * SAMPLE_RATE)
f2 = int(t2 * SAMPLE_RATE)
waveform_segment = audio[:, f1:f2]
with torch.inference_mode():
if model_type == "torchaudio":
emissions, _ = model(waveform_segment.to(device))
elif model_type == "huggingface":
emissions = model(waveform_segment.to(device)).logits
else:
raise NotImplementedError(f"Align model of type {model_type} not supported.")
emissions = torch.log_softmax(emissions, dim=-1)
emission = emissions[0].cpu().detach()
trellis = get_trellis(emission, tokens) trellis = get_trellis(emission, tokens)
path = backtrack(trellis, emission, tokens) path = backtrack(trellis, emission, tokens)
if path is None: if path is None:
print("Failed to align segment: backtrack failed, resorting to original...") print("Failed to align segment: backtrack failed, resorting to original...")
fail_fallback = True break
else: char_segments = merge_repeats(path, transcription_cleaned)
segments = merge_repeats(path, transcription_cleaned) # word_segments = merge_words(char_segments)
word_segments = merge_words(segments)
ratio = waveform_segment.size(0) / (trellis.size(0) - 1)
duration = t2 - t1 # sub-segments
local = [] if "seg-text" not in segment:
t_local = [None] * len(t_words) segment["seg-text"] = [transcription]
for wdx, word in enumerate(word_segments):
t1_ = ratio * word.start v = 0
t2_ = ratio * word.end seg_lens = [0] + [len(x) for x in segment["seg-text"]]
local.append((t1_, t2_)) seg_lens_cumsum = [v := v + n for n in seg_lens]
t_local[t_words_nonempty_idx[wdx]] = (t1_ * duration + t1, t2_ * duration + t1) sub_seg_idx = 0
t1_actual = t1 + local[0][0] * duration
t2_actual = t1 + local[-1][1] * duration
segment['start'] = t1_actual char_level = {
segment['end'] = t2_actual "start": [],
prev_t2 = segment['end'] "end": [],
"score": [],
"word-index": [],
}
# for the .ass output word_level = {
for x in range(len(t_local)): "start": [],
curr_word = t_words[x] "end": [],
curr_timestamp = t_local[x] "score": [],
if curr_timestamp is not None: "segment-text-start": [],
segment['word-level'].append({"text": curr_word, "start": curr_timestamp[0], "end": curr_timestamp[1]}) "segment-text-end": []
}
wdx = 0
seg_start_actual, seg_end_actual = None, None
duration = t2 - t1
ratio = duration * waveform_segment.size(0) / (trellis.size(0) - 1)
cdx_prev = 0
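# walk the transcription character by character: characters that were aligned (tracked in clean_cdx)
# pick up start/end/score from char_segments; a space, the last character, or a sub-segment boundary
# closes the current word, and at each sub-segment boundary the per-character data is aggregated
# into word-level timings below.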
for cdx, char in enumerate(transcription + " "):
is_last = False
if cdx == len(transcription):
break
elif cdx+1 == len(transcription):
is_last = True
start, end, score = None, None, None
if cdx in clean_cdx:
char_seg = char_segments[clean_cdx.index(cdx)]
start = char_seg.start * ratio + t1
end = char_seg.end * ratio + t1
score = char_seg.score
char_level["start"].append(start)
char_level["end"].append(end)
char_level["score"].append(score)
char_level["word-index"].append(wdx)
# word-level info
if model_lang in LANGUAGES_WITHOUT_SPACES:
# character == word
wdx += 1
elif is_last or transcription[cdx+1] == " " or cdx == seg_lens_cumsum[sub_seg_idx+1] - 1:
wdx += 1
word_level["start"].append(None)
word_level["end"].append(None)
word_level["score"].append(None)
word_level["segment-text-start"].append(cdx_prev-seg_lens_cumsum[sub_seg_idx])
word_level["segment-text-end"].append(cdx+1-seg_lens_cumsum[sub_seg_idx])
cdx_prev = cdx+2
if is_last or cdx == seg_lens_cumsum[sub_seg_idx+1] - 1:
if model_lang not in LANGUAGES_WITHOUT_SPACES:
char_level = pd.DataFrame(char_level)
word_level = pd.DataFrame(word_level)
not_space = pd.Series(list(segment["seg-text"][sub_seg_idx])) != " "
word_level["start"] = char_level[not_space].groupby("word-index")["start"].min() # take min of all chars in a word ignoring space
word_level["end"] = char_level[not_space].groupby("word-index")["end"].max() # take max of all chars in a word
# fill missing
if interpolate_method != "ignore":
word_level["start"] = interpolate_nans(word_level["start"], method=interpolate_method)
word_level["end"] = interpolate_nans(word_level["end"], method=interpolate_method)
word_level["start"] = word_level["start"].values.tolist()
word_level["end"] = word_level["end"].values.tolist()
word_level["score"] = char_level.groupby("word-index")["score"].mean() # take mean of all scores
char_level = char_level.replace({np.nan:None}).to_dict("list")
word_level = pd.DataFrame(word_level).replace({np.nan:None}).to_dict("list")
else: else:
segment['word-level'].append({"text": curr_word, "start": None, "end": None}) word_level = None
# for per-word .srt ouput aligned_segments.append(
# merge missing words to previous, or merge with next word ahead if idx == 0 {
found_first_ts = False "text": segment["seg-text"][sub_seg_idx],
for x in range(len(t_local)): "start": seg_start_actual,
curr_word = t_words[x] "end": seg_end_actual,
curr_timestamp = t_local[x] "char-segments": char_level,
if curr_timestamp is not None: "word-segments": word_level
word_segments_list.append({"text": curr_word, "start": curr_timestamp[0], "end": curr_timestamp[1]}) }
found_first_ts = True
elif not drop_non_aligned_words:
# then we merge
if not found_first_ts:
t_words[x+1] = " ".join([curr_word, t_words[x+1]])
else:
word_segments_list[-1]['text'] += ' ' + curr_word
else:
fail_fallback = True
if fail_fallback:
# then we resort back to original whisper timestamps
# segment['start] and segment['end'] are unchanged
prev_t2 = 0
segment['word-level'].append({"text": segment['text'], "start": segment['start'], "end":segment['end']})
word_segments_list.append({"text": segment['text'], "start": segment['start'], "end":segment['end']})
if 'vad' in segment:
curr_vdx = 0
curr_text = ''
for wrd_seg in word_segments_list:
if wrd_seg['start'] > segment['vad'][curr_vdx][1]:
curr_speaker = segment['speakers'][curr_vdx]
vad_segments_list.append(
{'start': segment['vad'][curr_vdx][0],
'end': segment['vad'][curr_vdx][1],
'text': f"[{curr_speaker}]: " + curr_text.strip()}
) )
curr_vdx += 1 if "language" in segment:
curr_text = '' aligned_segments[-1]["language"] = segment["language"]
curr_text += ' ' + wrd_seg['text']
if len(curr_text) > 0: print(f"[{format_timestamp(aligned_segments[-1]['start'])} --> {format_timestamp(aligned_segments[-1]['end'])}] {aligned_segments[-1]['text']}")
curr_speaker = segment['speakers'][curr_vdx]
vad_segments_list.append(
{'start': segment['vad'][curr_vdx][0],
'end': segment['vad'][curr_vdx][1],
'text': f"[{curr_speaker}]: " + curr_text.strip()}
)
curr_text = ''
total_word_segments_list += word_segments_list
print(f"[{format_timestamp(segment['start'])} --> {format_timestamp(segment['end'])}] {segment['text']}")
return {"segments": transcript, "word_segments": total_word_segments_list, "vad_segments": vad_segments_list} char_level = {
"start": [],
"end": [],
"score": [],
"word-index": [],
}
word_level = {
"start": [],
"end": [],
"score": [],
"segment-text-start": [],
"segment-text-end": []
}
wdx = 0
cdx_prev = cdx + 2
sub_seg_idx += 1
seg_start_actual, seg_end_actual = None, None
# take min-max for actual segment-level timestamp
if seg_start_actual is None and start is not None:
seg_start_actual = start
if end is not None:
seg_end_actual = end
prev_t2 = segment["end"]
segment_align_success = True
# end while True loop
break
# reset prev_t2 due to drifting issues
if not segment_align_success:
prev_t2 = 0
# shift segment index by amount of sub-segments
if "seg-text" in segment:
sdx += len(segment["seg-text"])
else:
sdx += 1
# create word level segments for .srt
word_seg = []
for seg in aligned_segments:
if model_lang in LANGUAGES_WITHOUT_SPACES:
# character based
seg["word-segments"] = seg["char-segments"]
seg["word-segments"]["segment-text-start"] = range(len(seg['word-segments']['start']))
seg["word-segments"]["segment-text-end"] = range(1, len(seg['word-segments']['start'])+1)
wseg = pd.DataFrame(seg["word-segments"]).replace({np.nan:None})
for wdx, wrow in wseg.iterrows():
if wrow["start"] is not None:
word_seg.append(
{
"start": wrow["start"],
"end": wrow["end"],
"text": seg["text"][int(wrow["segment-text-start"]):int(wrow["segment-text-end"])]
}
)
return {"segments": aligned_segments, "word_segments": word_seg}
def load_align_model(language_code, device, model_name=None): def load_align_model(language_code, device, model_name=None):
if model_name is None: if model_name is None:
@ -492,11 +618,11 @@ def load_align_model(language_code, device, model_name=None):
return align_model, align_metadata return align_model, align_metadata
def merge_chunks(segments, chunk_size=CHUNK_LENGTH, speakers=False):
'''
Merge VAD segments into larger segments of size ~CHUNK_LENGTH.
'''
def merge_chunks(segments, chunk_size=CHUNK_LENGTH):
"""
Merge VAD segments into larger segments of size ~CHUNK_LENGTH.
"""
curr_start = 0 curr_start = 0
curr_end = 0 curr_end = 0
merged_segments = [] merged_segments = []
@ -508,7 +634,6 @@ def merge_chunks(segments, chunk_size=CHUNK_LENGTH, speakers=False):
"start": curr_start, "start": curr_start,
"end": curr_end, "end": curr_end,
"segments": seg_idxs, "segments": seg_idxs,
"speakers": speaker_idxs,
}) })
curr_start = seg.start curr_start = seg.start
seg_idxs = [] seg_idxs = []
@ -521,55 +646,107 @@ def merge_chunks(segments, chunk_size=CHUNK_LENGTH, speakers=False):
"start": curr_start, "start": curr_start,
"end": curr_end, "end": curr_end,
"segments": seg_idxs, "segments": seg_idxs,
"speakers": speaker_idxs
}) })
return merged_segments return merged_segments
def transcribe_with_vad(
def transcribe_segments(
model: "Whisper", model: "Whisper",
audio: Union[str, np.ndarray, torch.Tensor], audio: Union[str, np.ndarray, torch.Tensor],
merged_segments, vad_pipeline,
mel = None, mel = None,
verbose: Optional[bool] = None,
**kwargs **kwargs
): ):
''' """
Transcribe according to predefined VAD segments. Transcribe per VAD segment
''' """
if mel is None: if mel is None:
mel = log_mel_spectrogram(audio) mel = log_mel_spectrogram(audio)
prev = 0 prev = 0
output = {"segments": []}
output = {'segments': []} vad_segments_list = []
vad_segments = vad_pipeline(audio)
for speech_turn in vad_segments.get_timeline().support():
vad_segments_list.append(Segment(speech_turn.start, speech_turn.end, "UNKNOWN"))
# merge VAD segments into approx 30s chunks, the input length whisper works best with
vad_segments = merge_chunks(vad_segments_list)
for sdx, seg_t in enumerate(merged_segments): for sdx, seg_t in enumerate(vad_segments):
print(sdx, seg_t['start'], seg_t['end'], '...') if verbose:
seg_f_start, seg_f_end = int(seg_t['start'] * SAMPLE_RATE / HOP_LENGTH), int(seg_t['end'] * SAMPLE_RATE / HOP_LENGTH) print(f"~~ Transcribing VAD chunk: ({format_timestamp(seg_t['start'])} --> {format_timestamp(seg_t['end'])}) ~~")
seg_f_start, seg_f_end = int(seg_t["start"] * SAMPLE_RATE / HOP_LENGTH), int(seg_t["end"] * SAMPLE_RATE / HOP_LENGTH)
local_f_start, local_f_end = seg_f_start - prev, seg_f_end - prev local_f_start, local_f_end = seg_f_start - prev, seg_f_end - prev
mel = mel[:, local_f_start:] # seek forward mel = mel[:, local_f_start:] # seek forward
prev = seg_f_start prev = seg_f_start
local_mel = mel[:, :local_f_end-local_f_start] local_mel = mel[:, :local_f_end-local_f_start]
result = transcribe(model, audio, mel=local_mel, **kwargs) result = transcribe(model, audio, mel=local_mel, verbose=verbose, **kwargs)
seg_t['text'] = result['text'] seg_t["text"] = result["text"]
output['segments'].append( output["segments"].append(
{ {
'start': seg_t['start'], "start": seg_t["start"],
'end': seg_t['end'], "end": seg_t["end"],
'language': result['language'], "language": result["language"],
'text': result['text'], "text": result["text"],
'seg-text': [x['text'] for x in result['segments']], "seg-text": [x["text"] for x in result["segments"]],
'seg-start': [x['start'] for x in result['segments']], "seg-start": [x["start"] for x in result["segments"]],
'seg-end': [x['end'] for x in result['segments']], "seg-end": [x["end"] for x in result["segments"]],
} }
) )
output['language'] = output['segments'][0]['language'] output["language"] = output["segments"][0]["language"]
return output return output
def assign_word_speakers(diarize_df, result_segments, fill_nearest=False):
for seg in result_segments:
wdf = pd.DataFrame(seg['word-segments'])
if len(wdf['start'].dropna()) == 0:
wdf['start'] = seg['start']
wdf['end'] = seg['end']
speakers = []
for wdx, wrow in wdf.iterrows():
diarize_df['intersection'] = np.minimum(diarize_df['end'], wrow['end']) - np.maximum(diarize_df['start'], wrow['start'])
diarize_df['union'] = np.maximum(diarize_df['end'], wrow['end']) - np.minimum(diarize_df['start'], wrow['start'])
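# 'intersection' is the temporal overlap (in seconds) between this word and each diarization turn;
# the turn with the largest overlap is picked as the word's speaker below ('union' is computed
# but not used for the choice).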
# remove no hit
if not fill_nearest:
dia_tmp = diarize_df[diarize_df['intersection'] > 0]
else:
dia_tmp = diarize_df
if len(dia_tmp) == 0:
speaker = None
else:
speaker = dia_tmp.sort_values("intersection", ascending=False).iloc[0][2]
speakers.append(speaker)
seg['word-segments']['speaker'] = speakers
seg["speaker"] = pd.Series(speakers).value_counts().index[0]
# create word level segments for .srt
word_seg = []
for seg in result_segments:
wseg = pd.DataFrame(seg["word-segments"])
for wdx, wrow in wseg.iterrows():
if wrow["start"] is not None:
speaker = wrow['speaker']
if speaker is None or pd.isna(speaker):
speaker = "UNKNOWN"
word_seg.append(
{
"start": wrow["start"],
"end": wrow["end"],
"text": f"[{speaker}]: " + seg["text"][int(wrow["segment-text-start"]):int(wrow["segment-text-end"])]
}
)
# TODO: create segments but split words on new speaker
return result_segments, word_seg
class Segment: class Segment:
def __init__(self, start, end, speaker=None): def __init__(self, start, end, speaker=None):
self.start = start self.start = start
@ -589,11 +766,17 @@ def cli():
parser.add_argument("--align_model", default=None, help="Name of phoneme-level ASR model to do alignment") parser.add_argument("--align_model", default=None, help="Name of phoneme-level ASR model to do alignment")
parser.add_argument("--align_extend", default=2, type=float, help="Seconds before and after to extend the whisper segments for alignment") parser.add_argument("--align_extend", default=2, type=float, help="Seconds before and after to extend the whisper segments for alignment")
parser.add_argument("--align_from_prev", default=True, type=bool, help="Whether to clip the alignment start time of current segment to the end time of the last aligned word of the previous segment") parser.add_argument("--align_from_prev", default=True, type=bool, help="Whether to clip the alignment start time of current segment to the end time of the last aligned word of the previous segment")
parser.add_argument("--drop_non_aligned", action="store_true", help="For word .srt, whether to drop non aliged words, or merge them into neighbouring.") parser.add_argument("--interpolate_method", default="nearest", choices=["nearest", "linear", "ignore"], help="For word .srt, method to assign timestamps to non-aligned words, or merge them into neighbouring.")
parser.add_argument("--vad_filter", action="store_true", help="Whether to first perform VAD filtering to target only transcribe within VAD...") # vad params
parser.add_argument("--vad_filter", action="store_true", help="Whether to first perform VAD filtering to target only transcribe within VAD. Produces more accurate alignment + timestamp, requires more GPU memory & compute.")
parser.add_argument("--vad_input", default=None, type=str) parser.add_argument("--vad_input", default=None, type=str)
# diarization params
parser.add_argument("--diarize", action='store_true')
parser.add_argument("--min_speakers", default=None, type=int)
parser.add_argument("--max_speakers", default=None, type=int)
# output save params
parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs") parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
parser.add_argument("--output_type", default="srt", choices=['all', 'srt', 'vtt', 'txt'], help="File type for desired output save") parser.add_argument("--output_type", default="all", choices=["all", "srt", "srt-word", "vtt", "txt", "tsv", "ass", "ass-char"], help="File type for desired output save")
parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages") parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")
@ -627,24 +810,32 @@ def cli():
align_model: str = args.pop("align_model") align_model: str = args.pop("align_model")
align_extend: float = args.pop("align_extend") align_extend: float = args.pop("align_extend")
align_from_prev: bool = args.pop("align_from_prev") align_from_prev: bool = args.pop("align_from_prev")
drop_non_aligned: bool = args.pop("drop_non_aligned") interpolate_method: bool = args.pop("interpolate_method")
vad_filter: bool = args.pop("vad_filter") vad_filter: bool = args.pop("vad_filter")
vad_input: bool = args.pop("vad_input") vad_input: bool = args.pop("vad_input")
diarize: bool = args.pop("diarize")
min_speakers: int = args.pop("min_speakers")
max_speakers: int = args.pop("max_speakers")
vad_pipeline = None vad_pipeline = None
if vad_input is not None: if vad_input is not None:
vad_input = pd.read_csv(vad_input, header=None, sep= " ") vad_input = pd.read_csv(vad_input, header=None, sep= " ")
elif vad_filter: elif vad_filter:
from pyannote.audio import Pipeline from pyannote.audio import Pipeline
vad_pipeline = Pipeline.from_pretrained("pyannote/voice-activity-detection") vad_pipeline = Pipeline.from_pretrained("pyannote/voice-activity-detection")
# vad_pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization@2.1")
diarize_pipeline = None
if diarize:
from pyannote.audio import Pipeline
diarize_pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization@2.1")
os.makedirs(output_dir, exist_ok=True) os.makedirs(output_dir, exist_ok=True)
if model_name.endswith(".en") and args["language"] not in {"en", "English"}: if model_name.endswith(".en") and args["language"] not in {"en", "English"}:
if args["language"] is not None: if args["language"] is not None:
warnings.warn(f"{model_name} is an English-only model but receipted '{args['language']}'; using English instead.") warnings.warn(f'{model_name} is an English-only model but receipted "{args["language"]}"; using English instead.')
args["language"] = "en" args["language"] = "en"
temperature = args.pop("temperature") temperature = args.pop("temperature")
@ -665,24 +856,10 @@ def cli():
align_model, align_metadata = load_align_model(align_language, device, model_name=align_model) align_model, align_metadata = load_align_model(align_language, device, model_name=align_model)
for audio_path in args.pop("audio"): for audio_path in args.pop("audio"):
if vad_filter or vad_input is not None: if vad_filter:
output_segments = [] print("Performing VAD...")
if vad_filter: result = transcribe_with_vad(model, audio_path, vad_pipeline, temperature=temperature, **args)
print("Performing VAD...")
# vad_segments = vad_pipeline(audio_path)
# for speech_turn, track, speaker in vad_segments.itertracks(yield_label=True):
# output_segments.append(Segment(speech_turn.start, speech_turn.end, speaker))
vad_segments = vad_pipeline(audio_path)
for speech_turn in vad_segments.get_timeline().support():
output_segments.append(Segment(speech_turn.start, speech_turn.end, "UNKNOWN"))
elif vad_input is not None:
# rttm format
for idx, row in vad_input.iterrows():
output_segments.append(Segment(row[3], row[3]+row[4], f"SPEAKER {row[7]}"))
vad_segments = merge_chunks(output_segments)
result = transcribe_segments(model, audio_path, merged_segments=vad_segments, temperature=temperature, **args)
else: else:
vad_segments = None
print("Performing transcription...") print("Performing transcription...")
result = transcribe(model, audio_path, temperature=temperature, **args) result = transcribe(model, audio_path, temperature=temperature, **args)
@ -693,9 +870,20 @@ def cli():
print("Performing alignment...") print("Performing alignment...")
result_aligned = align(result["segments"], align_model, align_metadata, audio_path, device, result_aligned = align(result["segments"], align_model, align_metadata, audio_path, device,
extend_duration=align_extend, start_from_previous=align_from_prev, drop_non_aligned_words=drop_non_aligned) extend_duration=align_extend, start_from_previous=align_from_prev, interpolate_method=interpolate_method)
audio_basename = os.path.basename(audio_path) audio_basename = os.path.basename(audio_path)
if diarize:
print("Performing diarization...")
diarize_segments = diarize_pipeline(audio_path, min_speakers=min_speakers, max_speakers=max_speakers)
diarize_df = pd.DataFrame(diarize_segments.itertracks(yield_label=True))
diarize_df['start'] = diarize_df[0].apply(lambda x: x.start)
diarize_df['end'] = diarize_df[0].apply(lambda x: x.end)
# assumes each utterance is single speaker (needs fix)
result_segments, word_segments = assign_word_speakers(diarize_df, result_aligned["segments"], fill_nearest=True)
result_aligned["segments"] = result_segments
result_aligned["word_segments"] = word_segments
# save TXT # save TXT
if output_type in ["txt", "all"]: if output_type in ["txt", "all"]:
with open(os.path.join(output_dir, audio_basename + ".txt"), "w", encoding="utf-8") as txt: with open(os.path.join(output_dir, audio_basename + ".txt"), "w", encoding="utf-8") as txt:
@ -711,19 +899,27 @@ def cli():
with open(os.path.join(output_dir, audio_basename + ".srt"), "w", encoding="utf-8") as srt: with open(os.path.join(output_dir, audio_basename + ".srt"), "w", encoding="utf-8") as srt:
write_srt(result_aligned["segments"], file=srt) write_srt(result_aligned["segments"], file=srt)
# save per-word SRT # save TSV
with open(os.path.join(output_dir, audio_basename + ".word.srt"), "w", encoding="utf-8") as srt: if output_type in ["tsv", "all"]:
write_srt(result_aligned["word_segments"], file=srt) with open(os.path.join(output_dir, audio_basename + ".srt"), "w", encoding="utf-8") as srt:
write_tsv(result_aligned["segments"], file=srt)
# save SRT word-level
if output_type in ["srt-word", "all"]:
# save per-word SRT
with open(os.path.join(output_dir, audio_basename + ".word.srt"), "w", encoding="utf-8") as srt:
write_srt(result_aligned["word_segments"], file=srt)
# save ASS # save ASS
with open(os.path.join(output_dir, audio_basename + ".ass"), "w", encoding="utf-8") as ass: if output_type in ["ass", "all"]:
write_ass(result_aligned["segments"], file=ass) with open(os.path.join(output_dir, audio_basename + ".ass"), "w", encoding="utf-8") as ass:
write_ass(result_aligned["segments"], file=ass)
if vad_filter is not None: # save ASS character-level
# save per-word SRT if output_type in ["ass-char", "all"]:
with open(os.path.join(output_dir, audio_basename + ".vad.srt"), "w", encoding="utf-8") as srt: with open(os.path.join(output_dir, audio_basename + ".char.ass"), "w", encoding="utf-8") as ass:
write_srt(result_aligned["vad_segments"], file=srt) write_ass(result_aligned["segments"], file=ass, resolution="char")
if __name__ == '__main__': if __name__ == "__main__":
cli() cli()

View File

@ -1,6 +1,7 @@
import os
import zlib import zlib
from typing import Iterator, TextIO, Tuple, List from typing import Callable, TextIO, Iterator, Tuple
import pandas as pd
def exact_div(x, y): def exact_div(x, y):
assert x % y == 0 assert x % y == 0
@ -60,6 +61,13 @@ def write_vtt(transcript: Iterator[dict], file: TextIO):
flush=True, flush=True,
) )
def write_tsv(transcript: Iterator[dict], file: TextIO):
print("start", "end", "text", sep="\t", file=file)
for segment in transcript:
print(round(1000 * segment['start']), file=file, end="\t")
print(round(1000 * segment['end']), file=file, end="\t")
print(segment['text'].strip().replace("\t", " "), file=file, flush=True)
def write_srt(transcript: Iterator[dict], file: TextIO): def write_srt(transcript: Iterator[dict], file: TextIO):
""" """
@ -88,7 +96,9 @@ def write_srt(transcript: Iterator[dict], file: TextIO):
) )
def write_ass(transcript: Iterator[dict], file: TextIO, def write_ass(transcript: Iterator[dict],
file: TextIO,
resolution: str = "word",
color: str = None, underline=True, color: str = None, underline=True,
prefmt: str = None, suffmt: str = None, prefmt: str = None, suffmt: str = None,
font: str = None, font_size: int = 24, font: str = None, font_size: int = 24,
@ -102,10 +112,12 @@ def write_ass(transcript: Iterator[dict], file: TextIO,
Note: ass file is used in the same way as srt, vtt, etc. Note: ass file is used in the same way as srt, vtt, etc.
Parameters Parameters
---------- ----------
res: dict transcript: dict
results from modified model results from modified model
ass_path: str file: TextIO
output path (e.g. caption.ass) file object to write to
resolution: str
"word" or "char", timestamp resolution to highlight.
color: str color: str
color code for a word at its corresponding timestamp color code for a word at its corresponding timestamp
<bbggrr> reverse order hexadecimal RGB value (e.g. FF0000 is full intensity blue. Default: 00FF00) <bbggrr> reverse order hexadecimal RGB value (e.g. FF0000 is full intensity blue. Default: 00FF00)
@ -176,49 +188,67 @@ def write_ass(transcript: Iterator[dict], file: TextIO,
return f'{hh:0>1.0f}:{mm:0>2.0f}:{ss:0>2.2f}' return f'{hh:0>1.0f}:{mm:0>2.0f}:{ss:0>2.2f}'
def dialogue(words: List[str], idx, start, end) -> str: def dialogue(chars: str, start: float, end: float, idx_0: int, idx_1: int) -> str:
text = ''.join(f' {prefmt}{word}{suffmt}' if idx_0 == -1:
# if not word.startswith(' ') or word == ' ' else text = chars
# f' {prefmt}{word.strip()}{suffmt}') else:
if curr_idx == idx else text = f'{chars[:idx_0]}{prefmt}{chars[idx_0:idx_1]}{suffmt}{chars[idx_1:]}'
f' {word}'
for curr_idx, word in enumerate(words))
return f"Dialogue: 0,{secs_to_hhmmss(start)},{secs_to_hhmmss(end)}," \ return f"Dialogue: 0,{secs_to_hhmmss(start)},{secs_to_hhmmss(end)}," \
f"Default,,0,0,0,,{text.strip() if strip else text}" f"Default,,0,0,0,,{text.strip() if strip else text}"
if resolution == "word":
resolution_key = "word-segments"
elif resolution == "char":
resolution_key = "char-segments"
else:
raise ValueError(".ass resolution should be 'word' or 'char', not ", resolution)
ass_arr = [] ass_arr = []
for segment in transcript: for segment in transcript:
curr_words = [wrd['text'] for wrd in segment['word-level']] if resolution_key in segment:
prev = segment['word-level'][0]['start'] res_segs = pd.DataFrame(segment[resolution_key])
if prev is None:
prev = segment['start'] prev = segment['start']
for wdx, word in enumerate(segment['word-level']): if "speaker" in segment:
if word['start'] is not None: speaker_str = f"[{segment['speaker']}]: "
# fill gap between previous word else:
if word['start'] > prev: speaker_str = ""
filler_ts = { for cdx, crow in res_segs.iterrows():
"words": curr_words, if crow['start'] is not None:
"start": prev, if resolution == "char":
"end": word['start'], idx_0 = cdx
"idx": -1 idx_1 = cdx + 1
elif resolution == "word":
idx_0 = int(crow["segment-text-start"])
idx_1 = int(crow["segment-text-end"])
# fill gap
if crow['start'] > prev:
filler_ts = {
"chars": speaker_str + segment['text'],
"start": prev,
"end": crow['start'],
"idx_0": -1,
"idx_1": -1
}
ass_arr.append(filler_ts)
# highlight current word
f_word_ts = {
"chars": speaker_str + segment['text'],
"start": crow['start'],
"end": crow['end'],
"idx_0": idx_0 + len(speaker_str),
"idx_1": idx_1 + len(speaker_str)
} }
ass_arr.append(filler_ts) ass_arr.append(f_word_ts)
prev = crow['end']
# highlight current word
f_word_ts = {
"words": curr_words,
"start": word['start'],
"end": word['end'],
"idx": wdx
}
ass_arr.append(f_word_ts)
prev = word['end']
ass_str += '\n'.join(map(lambda x: dialogue(**x), ass_arr)) ass_str += '\n'.join(map(lambda x: dialogue(**x), ass_arr))
file.write(ass_str) file.write(ass_str)
def interpolate_nans(x, method='nearest'):
if x.notnull().sum() > 1:
return x.interpolate(method=method).ffill().bfill()
else:
return x.ffill().bfill()
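As a quick illustration of the new `interpolate_nans` helper (which backs `--interpolate_method`): word start/end times that could not be aligned are filled in from neighbouring aligned words. A minimal sketch with illustrative values (the "nearest" method goes through pandas/scipy interpolation):

import pandas as pd

starts = pd.Series([1.0, None, None, 4.0, None])  # None = word could not be aligned
print(interpolate_nans(starts, method="nearest").tolist())  # ~ [1.0, 1.0, 4.0, 4.0, 4.0]
print(interpolate_nans(starts, method="linear").tolist())   # ~ [1.0, 2.0, 3.0, 4.0, 4.0]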