diff --git a/README.md b/README.md
index f89918a..fb08e80 100644
--- a/README.md
+++ b/README.md
@@ -48,6 +48,13 @@ This repository refines the timestamps of openAI's Whisper model via forced alig
**Forced Alignment** refers to the process by which orthographic transcriptions are aligned to audio recordings to automatically generate phone level segmentation.
+
New🚨
+
+- VAD filtering: Voice Activity Detection (VAD) from [Pyannote.audio](https://huggingface.co/pyannote/voice-activity-detection) is used as a preprocessing step to remove reliance on Whisper timestamps and to transcribe only the audio segments that contain speech. Enable it with the `--vad_filter` flag; it increases timestamp accuracy and robustness, but requires more GPU memory due to the 30s inputs to wav2vec2 (see the Python sketch below)
+- Character-level timestamps (see the `*.char.ass` output file)
+- Diarization (still in beta, add `--diarize`)
+
+
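+A rough Python sketch of the new VAD-filtered pipeline (an untested outline; `whisperx.load_model` is assumed to mirror openai-whisper's loader, the remaining calls follow the API added in this release):
+
+    import whisperx
+    from pyannote.audio import Pipeline
+
+    device = "cuda"
+    audio_file = "examples/sample01.wav"
+
+    # load Whisper and the pyannote VAD pipeline
+    model = whisperx.load_model("medium.en", device)
+    vad_pipeline = Pipeline.from_pretrained("pyannote/voice-activity-detection")
+
+    # transcribe only the speech regions found by VAD
+    result = whisperx.transcribe_with_vad(model, audio_file, vad_pipeline)
+
+    # align with wav2vec2.0 for word/character-level timestamps
+    align_model, metadata = whisperx.load_align_model(result["language"], device)
+    result_aligned = whisperx.align(result["segments"], align_model, metadata,
+                                    audio_file, device, interpolate_method="nearest")
+
+    print(result_aligned["word_segments"][:5])  # word-level timestamps
+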
Setup ⚙️
Install this package using
@@ -76,9 +83,9 @@ Run whisper on example segment (using default params)
whisperx examples/sample01.wav
-For increased timestamp accuracy, at the cost of higher gpu mem, use bigger models e.g.
+For increased timestamp accuracy, at the cost of higher GPU memory usage, use bigger models and VAD filtering, e.g.
- whisperx examples/sample01.wav --model large.en --align_model WAV2VEC2_ASR_LARGE_LV60K_960H
+ whisperx examples/sample01.wav --model large.en --vad_filter --align_model WAV2VEC2_ASR_LARGE_LV60K_960H
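+
+To additionally label speakers, the new diarization flags (still in beta) can be combined with VAD filtering; a sketch of the intended usage:
+
+    whisperx examples/sample01.wav --model medium.en --vad_filter --diarize --max_speakers 2
+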
Result using *WhisperX* with forced alignment to wav2vec2.0 large:
@@ -162,7 +169,11 @@ The next major upgrade we are working on is whisper with speaker diarization, so
[x] ~~Python usage~~ done
-[ ] Incorporating word-level speaker diarization
+[x] ~~Character level timestamps~~
+
+[x] ~~Incorporating speaker diarization~~
+
+[ ] Improve diarization (word level)
[ ] Inference speedup with batch processing
diff --git a/requirements.txt b/requirements.txt
index 26a53c3..999a8d6 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -6,3 +6,4 @@ soundfile
more-itertools
transformers>=4.19.0
ffmpeg-python==0.2.0
+pyannote.audio
diff --git a/whisperx/__init__.py b/whisperx/__init__.py
index 839b29a..4f253b3 100644
--- a/whisperx/__init__.py
+++ b/whisperx/__init__.py
@@ -11,7 +11,7 @@ from tqdm import tqdm
from .audio import load_audio, log_mel_spectrogram, pad_or_trim
from .decoding import DecodingOptions, DecodingResult, decode, detect_language
from .model import Whisper, ModelDimensions
-from .transcribe import transcribe, load_align_model, align
+from .transcribe import transcribe, load_align_model, align, transcribe_with_vad
_MODELS = {
"tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",
diff --git a/whisperx/audio.py b/whisperx/audio.py
index a3d8a13..b6c7e83 100644
--- a/whisperx/audio.py
+++ b/whisperx/audio.py
@@ -113,7 +113,7 @@ def log_mel_spectrogram(audio: Union[str, np.ndarray, torch.Tensor], n_mels: int
window = torch.hann_window(N_FFT).to(audio.device)
stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True)
- magnitudes = stft[:, :-1].abs() ** 2
+ magnitudes = stft[..., :-1].abs() ** 2
filters = mel_filters(audio.device, n_mels)
mel_spec = filters @ magnitudes
@@ -121,4 +121,4 @@ def log_mel_spectrogram(audio: Union[str, np.ndarray, torch.Tensor], n_mels: int
log_spec = torch.clamp(mel_spec, min=1e-10).log10()
log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
log_spec = (log_spec + 4.0) / 4.0
- return log_spec
+ return log_spec
\ No newline at end of file
diff --git a/whisperx/model.py b/whisperx/model.py
index ca3928e..9bdd84e 100644
--- a/whisperx/model.py
+++ b/whisperx/model.py
@@ -82,8 +82,8 @@ class MultiHeadAttention(nn.Module):
k = kv_cache[self.key]
v = kv_cache[self.value]
- wv = self.qkv_attention(q, k, v, mask)
- return self.out(wv)
+ wv, qk = self.qkv_attention(q, k, v, mask)
+ return self.out(wv), qk
def qkv_attention(self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None):
n_batch, n_ctx, n_state = q.shape
@@ -95,9 +95,10 @@ class MultiHeadAttention(nn.Module):
qk = q @ k
if mask is not None:
qk = qk + mask[:n_ctx, :n_ctx]
+ qk = qk.float()
- w = F.softmax(qk.float(), dim=-1).to(q.dtype)
- return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2)
+ w = F.softmax(qk, dim=-1).to(q.dtype)
+ return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach()
class ResidualAttentionBlock(nn.Module):
@@ -121,9 +122,9 @@ class ResidualAttentionBlock(nn.Module):
mask: Optional[Tensor] = None,
kv_cache: Optional[dict] = None,
):
- x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)
+ x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)[0]
if self.cross_attn:
- x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)
+ x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)[0]
x = x + self.mlp(self.mlp_ln(x))
return x
@@ -264,4 +265,4 @@ class Whisper(nn.Module):
detect_language = detect_language_function
transcribe = transcribe_function
- decode = decode_function
+ decode = decode_function
\ No newline at end of file
diff --git a/whisperx/normalizers/english.json b/whisperx/normalizers/english.json
index bd84ae7..74a1c35 100644
--- a/whisperx/normalizers/english.json
+++ b/whisperx/normalizers/english.json
@@ -1737,6 +1737,5 @@
"yoghurt": "yogurt",
"yoghurts": "yogurts",
"mhm": "hmm",
- "mm": "hmm",
"mmm": "hmm"
}
\ No newline at end of file
diff --git a/whisperx/transcribe.py b/whisperx/transcribe.py
index 9ff5a64..ba92d3d 100644
--- a/whisperx/transcribe.py
+++ b/whisperx/transcribe.py
@@ -12,7 +12,7 @@ from .audio import SAMPLE_RATE, N_FRAMES, HOP_LENGTH, CHUNK_LENGTH, pad_or_trim,
from .alignment import get_trellis, backtrack, merge_repeats, merge_words
from .decoding import DecodingOptions, DecodingResult
from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
-from .utils import exact_div, format_timestamp, optional_int, optional_float, str2bool, write_txt, write_vtt, write_srt, write_ass
+from .utils import exact_div, format_timestamp, optional_int, optional_float, str2bool, interpolate_nans, write_txt, write_vtt, write_srt, write_ass, write_tsv
import pandas as pd
if TYPE_CHECKING:
@@ -280,8 +280,39 @@ def align(
device: str,
extend_duration: float = 0.0,
start_from_previous: bool = True,
- drop_non_aligned_words: bool = False,
+ interpolate_method: str = "nearest",
):
+ """
+ Force align phoneme recognition predictions to known transcription
+
+ Parameters
+ ----------
+ transcript: Iterator[dict]
+ The transcript segments (from Whisper) to align
+
+ model: torch.nn.Module
+ Alignment model (wav2vec2)
+
+ audio: Union[str, np.ndarray, torch.Tensor]
+ The path to the audio file to open, or the audio waveform
+
+ device: str
+ Device to run the alignment model on, e.g. "cuda"
+
+ extend_duration: float
+ Amount of seconds to pad input segments by. If not using --vad_filter, 2 seconds is recommended
+
+ interpolate_method: str ["nearest", "linear", "ignore"]
+ Method to assign timestamps to non-aligned words. Words cannot be aligned when none of their characters occur in the align model dictionary.
+ "nearest" copies the timestamp of the nearest aligned word within the segment, "linear" uses linear interpolation, "ignore" leaves the timestamps unset.
+
+ Returns
+ -------
+ A dictionary containing the aligned segments ("segments"), each with character- and word-level
+ timestamps, and a flat list of word-level segments ("word_segments").
+ """
if not torch.is_tensor(audio):
if isinstance(audio, str):
audio = load_audio(audio)
@@ -291,171 +322,266 @@ def align(
MAX_DURATION = audio.shape[1] / SAMPLE_RATE
- model_dictionary = align_model_metadata['dictionary']
- model_lang = align_model_metadata['language']
- model_type = align_model_metadata['type']
+ model_dictionary = align_model_metadata["dictionary"]
+ model_lang = align_model_metadata["language"]
+ model_type = align_model_metadata["type"]
+
+ aligned_segments = []
prev_t2 = 0
- total_word_segments_list = []
- vad_segments_list = []
- for idx, segment in enumerate(transcript):
- word_segments_list = []
- # first we pad
- t1 = max(segment['start'] - extend_duration, 0)
- t2 = min(segment['end'] + extend_duration, MAX_DURATION)
+ sdx = 0
+ for segment in transcript:
+ while True:
+ segment_align_success = False
- # use prev_t2 as current t1 if it's later
- if start_from_previous and t1 < prev_t2:
- t1 = prev_t2
+ # strip spaces at beginning / end, but keep track of the amount.
+ num_leading = len(segment["text"]) - len(segment["text"].lstrip())
+ num_trailing = len(segment["text"]) - len(segment["text"].rstrip())
+ transcription = segment["text"]
- # check if timestamp range is still valid
- if t1 >= MAX_DURATION:
- print("Failed to align segment: original start time longer than audio duration, skipping...")
- continue
- if t2 - t1 < 0.02:
- print("Failed to align segment: duration smaller than 0.02s time precision")
- continue
+ # TODO: convert number tokenizer / symbols to phonetic words for alignment.
+ # e.g. "$300" -> "three hundred dollars"
+ # currently "$300" is ignored since no characters present in the phonetic dictionary
- f1 = int(t1 * SAMPLE_RATE)
- f2 = int(t2 * SAMPLE_RATE)
-
- waveform_segment = audio[:, f1:f2]
-
- with torch.inference_mode():
- if model_type == "torchaudio":
- emissions, _ = model(waveform_segment.to(device))
- elif model_type == "huggingface":
- emissions = model(waveform_segment.to(device)).logits
+ # split into words
+ if model_lang not in LANGUAGES_WITHOUT_SPACES:
+ per_word = transcription.split(" ")
else:
- raise NotImplementedError(f"Align model of type {model_type} not supported.")
- emissions = torch.log_softmax(emissions, dim=-1)
+ per_word = transcription
- emission = emissions[0].cpu().detach()
+ # first check that characters in transcription can be aligned (i.e. they are contained in the align model's dictionary)
+ clean_char, clean_cdx = [], []
+ for cdx, char in enumerate(transcription):
+ char_ = char.lower()
+ # wav2vec2 models use "|" character to represent spaces
+ if model_lang not in LANGUAGES_WITHOUT_SPACES:
+ char_ = char_.replace(" ", "|")
+
+ # ignore whitespace at beginning and end of transcript
+ if cdx < num_leading:
+ pass
+ elif cdx > len(transcription) - num_trailing - 1:
+ pass
+ elif char_ in model_dictionary.keys():
+ clean_char.append(char_)
+ clean_cdx.append(cdx)
- if "vad" in segment and len(segment['vad']) > 1 and '|' in model_dictionary:
- ratio = waveform_segment.size(0) / emission.size(0)
- space_idx = model_dictionary['|']
- # find non-vad segments
- for i in range(1, len(segment['vad'])):
- start = segment['vad'][i-1][1]
- end = segment['vad'][i][0]
- if start < end: # check if there is a gap between intervals
- non_vad_f1 = int(start / ratio)
- non_vad_f2 = int(end / ratio)
- # non-vad should be masked, use space to do so
- emission[non_vad_f1:non_vad_f2, :] = float("-inf")
- emission[non_vad_f1:non_vad_f2, space_idx] = 0
+ clean_wdx = []
+ for wdx, wrd in enumerate(per_word):
+ if any([c in model_dictionary.keys() for c in wrd]):
+ clean_wdx.append(wdx)
-
- start = segment['vad'][i][1]
- end = segment['end']
- non_vad_f1 = int(start / ratio)
- non_vad_f2 = int(end / ratio)
- # non-vad should be masked, use space to do so
- emission[non_vad_f1:non_vad_f2, :] = float("-inf")
- emission[non_vad_f1:non_vad_f2, space_idx] = 0
-
- transcription = segment['text'].strip()
- if model_lang not in LANGUAGES_WITHOUT_SPACES:
- t_words = transcription.split(' ')
- else:
- t_words = [c for c in transcription]
-
- t_words_clean = [''.join([w for w in word if w.lower() in model_dictionary.keys()]) for word in t_words]
- t_words_nonempty = [x for x in t_words_clean if x != ""]
- t_words_nonempty_idx = [x for x in range(len(t_words_clean)) if t_words_clean[x] != ""]
- segment['word-level'] = []
-
- fail_fallback = False
- if len(t_words_nonempty) > 0:
- transcription_cleaned = "|".join(t_words_nonempty).lower()
+ # if no characters are in the dictionary, then we skip this segment...
+ if len(clean_char) == 0:
+ print("Failed to align segment: no characters in this segment found in model dictionary, resorting to original...")
+ break
+
+ transcription_cleaned = "".join(clean_char)
tokens = [model_dictionary[c] for c in transcription_cleaned]
+
+ # pad according original timestamps
+ t1 = max(segment["start"] - extend_duration, 0)
+ t2 = min(segment["end"] + extend_duration, MAX_DURATION)
+
+ # use prev_t2 as current t1 if it's later
+ if start_from_previous and t1 < prev_t2:
+ t1 = prev_t2
+
+ # check if timestamp range is still valid
+ if t1 >= MAX_DURATION:
+ print("Failed to align segment: original start time longer than audio duration, skipping...")
+ break
+ if t2 - t1 < 0.02:
+ print("Failed to align segment: duration smaller than 0.02s time precision")
+ break
+
+ f1 = int(t1 * SAMPLE_RATE)
+ f2 = int(t2 * SAMPLE_RATE)
+
+ waveform_segment = audio[:, f1:f2]
+
+ with torch.inference_mode():
+ if model_type == "torchaudio":
+ emissions, _ = model(waveform_segment.to(device))
+ elif model_type == "huggingface":
+ emissions = model(waveform_segment.to(device)).logits
+ else:
+ raise NotImplementedError(f"Align model of type {model_type} not supported.")
+ emissions = torch.log_softmax(emissions, dim=-1)
+
+ emission = emissions[0].cpu().detach()
+
trellis = get_trellis(emission, tokens)
path = backtrack(trellis, emission, tokens)
if path is None:
print("Failed to align segment: backtrack failed, resorting to original...")
- fail_fallback = True
- else:
- segments = merge_repeats(path, transcription_cleaned)
- word_segments = merge_words(segments)
- ratio = waveform_segment.size(0) / (trellis.size(0) - 1)
+ break
+ char_segments = merge_repeats(path, transcription_cleaned)
+ # word_segments = merge_words(char_segments)
+
- duration = t2 - t1
- local = []
- t_local = [None] * len(t_words)
- for wdx, word in enumerate(word_segments):
- t1_ = ratio * word.start
- t2_ = ratio * word.end
- local.append((t1_, t2_))
- t_local[t_words_nonempty_idx[wdx]] = (t1_ * duration + t1, t2_ * duration + t1)
- t1_actual = t1 + local[0][0] * duration
- t2_actual = t1 + local[-1][1] * duration
+ # sub-segments
+ if "seg-text" not in segment:
+ segment["seg-text"] = [transcription]
+
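+ # cumulative character offsets of each Whisper sub-segment within this chunk's text,
+ # used below to map character indices back to their originating sub-segment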
+ v = 0
+ seg_lens = [0] + [len(x) for x in segment["seg-text"]]
+ seg_lens_cumsum = [v := v + n for n in seg_lens]
+ sub_seg_idx = 0
- segment['start'] = t1_actual
- segment['end'] = t2_actual
- prev_t2 = segment['end']
+ char_level = {
+ "start": [],
+ "end": [],
+ "score": [],
+ "word-index": [],
+ }
- # for the .ass output
- for x in range(len(t_local)):
- curr_word = t_words[x]
- curr_timestamp = t_local[x]
- if curr_timestamp is not None:
- segment['word-level'].append({"text": curr_word, "start": curr_timestamp[0], "end": curr_timestamp[1]})
+ word_level = {
+ "start": [],
+ "end": [],
+ "score": [],
+ "segment-text-start": [],
+ "segment-text-end": []
+ }
+
+ wdx = 0
+ seg_start_actual, seg_end_actual = None, None
+ duration = t2 - t1
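+ # seconds per trellis frame (waveform_segment.size(0) is the channel dimension), used to convert alignment frames to timestamps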
+ ratio = duration * waveform_segment.size(0) / (trellis.size(0) - 1)
+ cdx_prev = 0
+ for cdx, char in enumerate(transcription + " "):
+ is_last = False
+ if cdx == len(transcription):
+ break
+ elif cdx+1 == len(transcription):
+ is_last = True
+
+
+ start, end, score = None, None, None
+ if cdx in clean_cdx:
+ char_seg = char_segments[clean_cdx.index(cdx)]
+ start = char_seg.start * ratio + t1
+ end = char_seg.end * ratio + t1
+ score = char_seg.score
+
+ char_level["start"].append(start)
+ char_level["end"].append(end)
+ char_level["score"].append(score)
+ char_level["word-index"].append(wdx)
+
+ # word-level info
+ if model_lang in LANGUAGES_WITHOUT_SPACES:
+ # character == word
+ wdx += 1
+ elif is_last or transcription[cdx+1] == " " or cdx == seg_lens_cumsum[sub_seg_idx+1] - 1:
+ wdx += 1
+ word_level["start"].append(None)
+ word_level["end"].append(None)
+ word_level["score"].append(None)
+ word_level["segment-text-start"].append(cdx_prev-seg_lens_cumsum[sub_seg_idx])
+ word_level["segment-text-end"].append(cdx+1-seg_lens_cumsum[sub_seg_idx])
+ cdx_prev = cdx+2
+
+ if is_last or cdx == seg_lens_cumsum[sub_seg_idx+1] - 1:
+ if model_lang not in LANGUAGES_WITHOUT_SPACES:
+ char_level = pd.DataFrame(char_level)
+ word_level = pd.DataFrame(word_level)
+
+ not_space = pd.Series(list(segment["seg-text"][sub_seg_idx])) != " "
+ word_level["start"] = char_level[not_space].groupby("word-index")["start"].min() # take min of all chars in a word ignoring space
+ word_level["end"] = char_level[not_space].groupby("word-index")["end"].max() # take max of all chars in a word
+
+ # fill missing
+ if interpolate_method != "ignore":
+ word_level["start"] = interpolate_nans(word_level["start"], method=interpolate_method)
+ word_level["end"] = interpolate_nans(word_level["end"], method=interpolate_method)
+ word_level["start"] = word_level["start"].values.tolist()
+ word_level["end"] = word_level["end"].values.tolist()
+ word_level["score"] = char_level.groupby("word-index")["score"].mean() # take mean of all scores
+
+ char_level = char_level.replace({np.nan:None}).to_dict("list")
+ word_level = pd.DataFrame(word_level).replace({np.nan:None}).to_dict("list")
else:
- segment['word-level'].append({"text": curr_word, "start": None, "end": None})
+ word_level = None
- # for per-word .srt ouput
- # merge missing words to previous, or merge with next word ahead if idx == 0
- found_first_ts = False
- for x in range(len(t_local)):
- curr_word = t_words[x]
- curr_timestamp = t_local[x]
- if curr_timestamp is not None:
- word_segments_list.append({"text": curr_word, "start": curr_timestamp[0], "end": curr_timestamp[1]})
- found_first_ts = True
- elif not drop_non_aligned_words:
- # then we merge
- if not found_first_ts:
- t_words[x+1] = " ".join([curr_word, t_words[x+1]])
- else:
- word_segments_list[-1]['text'] += ' ' + curr_word
- else:
- fail_fallback = True
-
- if fail_fallback:
- # then we resort back to original whisper timestamps
- # segment['start] and segment['end'] are unchanged
- prev_t2 = 0
- segment['word-level'].append({"text": segment['text'], "start": segment['start'], "end":segment['end']})
- word_segments_list.append({"text": segment['text'], "start": segment['start'], "end":segment['end']})
-
- if 'vad' in segment:
- curr_vdx = 0
- curr_text = ''
- for wrd_seg in word_segments_list:
- if wrd_seg['start'] > segment['vad'][curr_vdx][1]:
- curr_speaker = segment['speakers'][curr_vdx]
- vad_segments_list.append(
- {'start': segment['vad'][curr_vdx][0],
- 'end': segment['vad'][curr_vdx][1],
- 'text': f"[{curr_speaker}]: " + curr_text.strip()}
+ aligned_segments.append(
+ {
+ "text": segment["seg-text"][sub_seg_idx],
+ "start": seg_start_actual,
+ "end": seg_end_actual,
+ "char-segments": char_level,
+ "word-segments": word_level
+ }
)
- curr_vdx += 1
- curr_text = ''
- curr_text += ' ' + wrd_seg['text']
- if len(curr_text) > 0:
- curr_speaker = segment['speakers'][curr_vdx]
- vad_segments_list.append(
- {'start': segment['vad'][curr_vdx][0],
- 'end': segment['vad'][curr_vdx][1],
- 'text': f"[{curr_speaker}]: " + curr_text.strip()}
- )
- curr_text = ''
- total_word_segments_list += word_segments_list
- print(f"[{format_timestamp(segment['start'])} --> {format_timestamp(segment['end'])}] {segment['text']}")
+ if "language" in segment:
+ aligned_segments[-1]["language"] = segment["language"]
+
+ print(f"[{format_timestamp(aligned_segments[-1]['start'])} --> {format_timestamp(aligned_segments[-1]['end'])}] {aligned_segments[-1]['text']}")
- return {"segments": transcript, "word_segments": total_word_segments_list, "vad_segments": vad_segments_list}
+ char_level = {
+ "start": [],
+ "end": [],
+ "score": [],
+ "word-index": [],
+ }
+ word_level = {
+ "start": [],
+ "end": [],
+ "score": [],
+ "segment-text-start": [],
+ "segment-text-end": []
+ }
+ wdx = 0
+ cdx_prev = cdx + 2
+ sub_seg_idx += 1
+ seg_start_actual, seg_end_actual = None, None
+
+
+ # take min-max for actual segment-level timestamp
+ if seg_start_actual is None and start is not None:
+ seg_start_actual = start
+ if end is not None:
+ seg_end_actual = end
+
+
+ prev_t2 = segment["end"]
+
+ segment_align_success = True
+ # end while True loop
+ break
+
+ # reset prev_t2 due to drifting issues
+ if not segment_align_success:
+ prev_t2 = 0
+
+ # shift segment index by amount of sub-segments
+ if "seg-text" in segment:
+ sdx += len(segment["seg-text"])
+ else:
+ sdx += 1
+
+ # create word level segments for .srt
+ word_seg = []
+ for seg in aligned_segments:
+ if model_lang in LANGUAGES_WITHOUT_SPACES:
+ # character based
+ seg["word-segments"] = seg["char-segments"]
+ seg["word-segments"]["segment-text-start"] = range(len(seg['word-segments']['start']))
+ seg["word-segments"]["segment-text-end"] = range(1, len(seg['word-segments']['start'])+1)
+
+ wseg = pd.DataFrame(seg["word-segments"]).replace({np.nan:None})
+ for wdx, wrow in wseg.iterrows():
+ if wrow["start"] is not None:
+ word_seg.append(
+ {
+ "start": wrow["start"],
+ "end": wrow["end"],
+ "text": seg["text"][int(wrow["segment-text-start"]):int(wrow["segment-text-end"])]
+ }
+ )
+
+ return {"segments": aligned_segments, "word_segments": word_seg}
def load_align_model(language_code, device, model_name=None):
if model_name is None:
@@ -492,11 +618,11 @@ def load_align_model(language_code, device, model_name=None):
return align_model, align_metadata
-def merge_chunks(segments, chunk_size=CHUNK_LENGTH, speakers=False):
- '''
- Merge VAD segments into larger segments of size ~CHUNK_LENGTH.
- '''
+def merge_chunks(segments, chunk_size=CHUNK_LENGTH):
+ """
+ Merge VAD segments into larger segments of size ~CHUNK_LENGTH.
+ """
curr_start = 0
curr_end = 0
merged_segments = []
@@ -508,7 +634,6 @@ def merge_chunks(segments, chunk_size=CHUNK_LENGTH, speakers=False):
"start": curr_start,
"end": curr_end,
"segments": seg_idxs,
- "speakers": speaker_idxs,
})
curr_start = seg.start
seg_idxs = []
@@ -521,55 +646,107 @@ def merge_chunks(segments, chunk_size=CHUNK_LENGTH, speakers=False):
"start": curr_start,
"end": curr_end,
"segments": seg_idxs,
- "speakers": speaker_idxs
})
return merged_segments
-
-def transcribe_segments(
+def transcribe_with_vad(
model: "Whisper",
audio: Union[str, np.ndarray, torch.Tensor],
- merged_segments,
+ vad_pipeline,
mel = None,
+ verbose: Optional[bool] = None,
**kwargs
):
- '''
- Transcribe according to predefined VAD segments.
- '''
+ """
+ Transcribe per VAD segment
+ """
if mel is None:
mel = log_mel_spectrogram(audio)
prev = 0
+ output = {"segments": []}
- output = {'segments': []}
+ vad_segments_list = []
+ vad_segments = vad_pipeline(audio)
+ for speech_turn in vad_segments.get_timeline().support():
+ vad_segments_list.append(Segment(speech_turn.start, speech_turn.end, "UNKNOWN"))
+ # merge the VAD segments into chunks of ~30 seconds, the input size Whisper works best with
+ vad_segments = merge_chunks(vad_segments_list)
- for sdx, seg_t in enumerate(merged_segments):
- print(sdx, seg_t['start'], seg_t['end'], '...')
- seg_f_start, seg_f_end = int(seg_t['start'] * SAMPLE_RATE / HOP_LENGTH), int(seg_t['end'] * SAMPLE_RATE / HOP_LENGTH)
+ for sdx, seg_t in enumerate(vad_segments):
+ if verbose:
+ print(f"~~ Transcribing VAD chunk: ({format_timestamp(seg_t['start'])} --> {format_timestamp(seg_t['end'])}) ~~")
+ seg_f_start, seg_f_end = int(seg_t["start"] * SAMPLE_RATE / HOP_LENGTH), int(seg_t["end"] * SAMPLE_RATE / HOP_LENGTH)
local_f_start, local_f_end = seg_f_start - prev, seg_f_end - prev
mel = mel[:, local_f_start:] # seek forward
prev = seg_f_start
local_mel = mel[:, :local_f_end-local_f_start]
- result = transcribe(model, audio, mel=local_mel, **kwargs)
- seg_t['text'] = result['text']
- output['segments'].append(
+ result = transcribe(model, audio, mel=local_mel, verbose=verbose, **kwargs)
+ seg_t["text"] = result["text"]
+ output["segments"].append(
{
- 'start': seg_t['start'],
- 'end': seg_t['end'],
- 'language': result['language'],
- 'text': result['text'],
- 'seg-text': [x['text'] for x in result['segments']],
- 'seg-start': [x['start'] for x in result['segments']],
- 'seg-end': [x['end'] for x in result['segments']],
+ "start": seg_t["start"],
+ "end": seg_t["end"],
+ "language": result["language"],
+ "text": result["text"],
+ "seg-text": [x["text"] for x in result["segments"]],
+ "seg-start": [x["start"] for x in result["segments"]],
+ "seg-end": [x["end"] for x in result["segments"]],
}
)
- output['language'] = output['segments'][0]['language']
+ output["language"] = output["segments"][0]["language"]
return output
+
+def assign_word_speakers(diarize_df, result_segments, fill_nearest=False):
+
+ for seg in result_segments:
+ wdf = pd.DataFrame(seg['word-segments'])
+ if len(wdf['start'].dropna()) == 0:
+ wdf['start'] = seg['start']
+ wdf['end'] = seg['end']
+ speakers = []
+ for wdx, wrow in wdf.iterrows():
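+ # overlap (intersection) and total span (union), in seconds, between this word and every diarization turn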
+ diarize_df['intersection'] = np.minimum(diarize_df['end'], wrow['end']) - np.maximum(diarize_df['start'], wrow['start'])
+ diarize_df['union'] = np.maximum(diarize_df['end'], wrow['end']) - np.minimum(diarize_df['start'], wrow['start'])
+ # remove no hit
+ if not fill_nearest:
+ dia_tmp = diarize_df[diarize_df['intersection'] > 0]
+ else:
+ dia_tmp = diarize_df
+ if len(dia_tmp) == 0:
+ speaker = None
+ else:
+ speaker = dia_tmp.sort_values("intersection", ascending=False).iloc[0][2]
+ speakers.append(speaker)
+ seg['word-segments']['speaker'] = speakers
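+ # assign the segment to its most frequent word-level speaker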
+ seg["speaker"] = pd.Series(speakers).value_counts().index[0]
+
+ # create word level segments for .srt
+ word_seg = []
+ for seg in result_segments:
+ wseg = pd.DataFrame(seg["word-segments"])
+ for wdx, wrow in wseg.iterrows():
+ if wrow["start"] is not None:
+ speaker = wrow['speaker']
+ if speaker is None or pd.isna(speaker):
+ speaker = "UNKNOWN"
+ word_seg.append(
+ {
+ "start": wrow["start"],
+ "end": wrow["end"],
+ "text": f"[{speaker}]: " + seg["text"][int(wrow["segment-text-start"]):int(wrow["segment-text-end"])]
+ }
+ )
+
+ # TODO: create segments but split words on new speaker
+
+ return result_segments, word_seg
+
class Segment:
def __init__(self, start, end, speaker=None):
self.start = start
@@ -589,11 +766,17 @@ def cli():
parser.add_argument("--align_model", default=None, help="Name of phoneme-level ASR model to do alignment")
parser.add_argument("--align_extend", default=2, type=float, help="Seconds before and after to extend the whisper segments for alignment")
parser.add_argument("--align_from_prev", default=True, type=bool, help="Whether to clip the alignment start time of current segment to the end time of the last aligned word of the previous segment")
- parser.add_argument("--drop_non_aligned", action="store_true", help="For word .srt, whether to drop non aliged words, or merge them into neighbouring.")
- parser.add_argument("--vad_filter", action="store_true", help="Whether to first perform VAD filtering to target only transcribe within VAD...")
+ parser.add_argument("--interpolate_method", default="nearest", choices=["nearest", "linear", "ignore"], help="For word .srt, method to assign timestamps to non-aligned words, or merge them into neighbouring.")
+ # vad params
+ parser.add_argument("--vad_filter", action="store_true", help="Whether to first perform VAD filtering to target only transcribe within VAD. Produces more accurate alignment + timestamp, requires more GPU memory & compute.")
parser.add_argument("--vad_input", default=None, type=str)
+ # diarization params
+ parser.add_argument("--diarize", action='store_true')
+ parser.add_argument("--min_speakers", default=None, type=int)
+ parser.add_argument("--max_speakers", default=None, type=int)
+ # output save params
parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
- parser.add_argument("--output_type", default="srt", choices=['all', 'srt', 'vtt', 'txt'], help="File type for desired output save")
+ parser.add_argument("--output_type", default="all", choices=["all", "srt", "srt-word", "vtt", "txt", "tsv", "ass", "ass-char"], help="File type for desired output save")
parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")
@@ -627,24 +810,32 @@ def cli():
align_model: str = args.pop("align_model")
align_extend: float = args.pop("align_extend")
align_from_prev: bool = args.pop("align_from_prev")
- drop_non_aligned: bool = args.pop("drop_non_aligned")
+ interpolate_method: str = args.pop("interpolate_method")
vad_filter: bool = args.pop("vad_filter")
vad_input: bool = args.pop("vad_input")
+ diarize: bool = args.pop("diarize")
+ min_speakers: int = args.pop("min_speakers")
+ max_speakers: int = args.pop("max_speakers")
+
vad_pipeline = None
if vad_input is not None:
vad_input = pd.read_csv(vad_input, header=None, sep= " ")
elif vad_filter:
from pyannote.audio import Pipeline
vad_pipeline = Pipeline.from_pretrained("pyannote/voice-activity-detection")
- # vad_pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization@2.1")
+
+ diarize_pipeline = None
+ if diarize:
+ from pyannote.audio import Pipeline
+ diarize_pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization@2.1")
os.makedirs(output_dir, exist_ok=True)
if model_name.endswith(".en") and args["language"] not in {"en", "English"}:
if args["language"] is not None:
- warnings.warn(f"{model_name} is an English-only model but receipted '{args['language']}'; using English instead.")
+ warnings.warn(f'{model_name} is an English-only model but received "{args["language"]}"; using English instead.')
args["language"] = "en"
temperature = args.pop("temperature")
@@ -665,24 +856,10 @@ def cli():
align_model, align_metadata = load_align_model(align_language, device, model_name=align_model)
for audio_path in args.pop("audio"):
- if vad_filter or vad_input is not None:
- output_segments = []
- if vad_filter:
- print("Performing VAD...")
- # vad_segments = vad_pipeline(audio_path)
- # for speech_turn, track, speaker in vad_segments.itertracks(yield_label=True):
- # output_segments.append(Segment(speech_turn.start, speech_turn.end, speaker))
- vad_segments = vad_pipeline(audio_path)
- for speech_turn in vad_segments.get_timeline().support():
- output_segments.append(Segment(speech_turn.start, speech_turn.end, "UNKNOWN"))
- elif vad_input is not None:
- # rttm format
- for idx, row in vad_input.iterrows():
- output_segments.append(Segment(row[3], row[3]+row[4], f"SPEAKER {row[7]}"))
- vad_segments = merge_chunks(output_segments)
- result = transcribe_segments(model, audio_path, merged_segments=vad_segments, temperature=temperature, **args)
+ if vad_filter:
+ print("Performing VAD...")
+ result = transcribe_with_vad(model, audio_path, vad_pipeline, temperature=temperature, **args)
else:
- vad_segments = None
print("Performing transcription...")
result = transcribe(model, audio_path, temperature=temperature, **args)
@@ -693,9 +870,20 @@ def cli():
print("Performing alignment...")
result_aligned = align(result["segments"], align_model, align_metadata, audio_path, device,
- extend_duration=align_extend, start_from_previous=align_from_prev, drop_non_aligned_words=drop_non_aligned)
+ extend_duration=align_extend, start_from_previous=align_from_prev, interpolate_method=interpolate_method)
audio_basename = os.path.basename(audio_path)
+ if diarize:
+ print("Performing diarization...")
+ diarize_segments = diarize_pipeline(audio_path, min_speakers=min_speakers, max_speakers=max_speakers)
+ diarize_df = pd.DataFrame(diarize_segments.itertracks(yield_label=True))
+ diarize_df['start'] = diarize_df[0].apply(lambda x: x.start)
+ diarize_df['end'] = diarize_df[0].apply(lambda x: x.end)
+ # assumes each utterance is single speaker (needs fix)
+ result_segments, word_segments = assign_word_speakers(diarize_df, result_aligned["segments"], fill_nearest=True)
+ result_aligned["segments"] = result_segments
+ result_aligned["word_segments"] = word_segments
+
# save TXT
if output_type in ["txt", "all"]:
with open(os.path.join(output_dir, audio_basename + ".txt"), "w", encoding="utf-8") as txt:
@@ -711,19 +899,27 @@ def cli():
with open(os.path.join(output_dir, audio_basename + ".srt"), "w", encoding="utf-8") as srt:
write_srt(result_aligned["segments"], file=srt)
- # save per-word SRT
- with open(os.path.join(output_dir, audio_basename + ".word.srt"), "w", encoding="utf-8") as srt:
- write_srt(result_aligned["word_segments"], file=srt)
+ # save TSV
+ if output_type in ["tsv", "all"]:
+ with open(os.path.join(output_dir, audio_basename + ".tsv"), "w", encoding="utf-8") as tsv:
+ write_tsv(result_aligned["segments"], file=tsv)
+
+ # save SRT word-level
+ if output_type in ["srt-word", "all"]:
+ # save per-word SRT
+ with open(os.path.join(output_dir, audio_basename + ".word.srt"), "w", encoding="utf-8") as srt:
+ write_srt(result_aligned["word_segments"], file=srt)
# save ASS
- with open(os.path.join(output_dir, audio_basename + ".ass"), "w", encoding="utf-8") as ass:
- write_ass(result_aligned["segments"], file=ass)
+ if output_type in ["ass", "all"]:
+ with open(os.path.join(output_dir, audio_basename + ".ass"), "w", encoding="utf-8") as ass:
+ write_ass(result_aligned["segments"], file=ass)
- if vad_filter is not None:
- # save per-word SRT
- with open(os.path.join(output_dir, audio_basename + ".vad.srt"), "w", encoding="utf-8") as srt:
- write_srt(result_aligned["vad_segments"], file=srt)
+ # save ASS character-level
+ if output_type in ["ass-char", "all"]:
+ with open(os.path.join(output_dir, audio_basename + ".char.ass"), "w", encoding="utf-8") as ass:
+ write_ass(result_aligned["segments"], file=ass, resolution="char")
-if __name__ == '__main__':
+if __name__ == "__main__":
cli()
diff --git a/whisperx/utils.py b/whisperx/utils.py
index 56e3483..590eaab 100644
--- a/whisperx/utils.py
+++ b/whisperx/utils.py
@@ -1,6 +1,7 @@
+import os
import zlib
-from typing import Iterator, TextIO, Tuple, List
-
+from typing import Callable, TextIO, Iterator, Tuple
+import pandas as pd
def exact_div(x, y):
assert x % y == 0
@@ -60,6 +61,13 @@ def write_vtt(transcript: Iterator[dict], file: TextIO):
flush=True,
)
+def write_tsv(transcript: Iterator[dict], file: TextIO):
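+ # one row per segment: start and end in integer milliseconds, followed by the text, tab-separated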
+ print("start", "end", "text", sep="\t", file=file)
+ for segment in transcript:
+ print(round(1000 * segment['start']), file=file, end="\t")
+ print(round(1000 * segment['end']), file=file, end="\t")
+ print(segment['text'].strip().replace("\t", " "), file=file, flush=True)
+
def write_srt(transcript: Iterator[dict], file: TextIO):
"""
@@ -88,7 +96,9 @@ def write_srt(transcript: Iterator[dict], file: TextIO):
)
-def write_ass(transcript: Iterator[dict], file: TextIO,
+def write_ass(transcript: Iterator[dict],
+ file: TextIO,
+ resolution: str = "word",
color: str = None, underline=True,
prefmt: str = None, suffmt: str = None,
font: str = None, font_size: int = 24,
@@ -102,10 +112,12 @@ def write_ass(transcript: Iterator[dict], file: TextIO,
Note: ass file is used in the same way as srt, vtt, etc.
Parameters
----------
- res: dict
+ transcript: Iterator[dict]
results from modified model
- ass_path: str
- output path (e.g. caption.ass)
+ file: TextIO
+ file object to write to
+ resolution: str
+ "word" or "char", timestamp resolution to highlight.
color: str
color code for a word at its corresponding timestamp
reverse order hexadecimal RGB value (e.g. FF0000 is full intensity blue. Default: 00FF00)
@@ -176,49 +188,67 @@ def write_ass(transcript: Iterator[dict], file: TextIO,
return f'{hh:0>1.0f}:{mm:0>2.0f}:{ss:0>2.2f}'
- def dialogue(words: List[str], idx, start, end) -> str:
- text = ''.join(f' {prefmt}{word}{suffmt}'
- # if not word.startswith(' ') or word == ' ' else
- # f' {prefmt}{word.strip()}{suffmt}')
- if curr_idx == idx else
- f' {word}'
- for curr_idx, word in enumerate(words))
+ def dialogue(chars: str, start: float, end: float, idx_0: int, idx_1: int) -> str:
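+ # chars[idx_0:idx_1] is wrapped in the ASS highlight format codes; idx_0 == -1 writes the line with no highlight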
+ if idx_0 == -1:
+ text = chars
+ else:
+ text = f'{chars[:idx_0]}{prefmt}{chars[idx_0:idx_1]}{suffmt}{chars[idx_1:]}'
return f"Dialogue: 0,{secs_to_hhmmss(start)},{secs_to_hhmmss(end)}," \
f"Default,,0,0,0,,{text.strip() if strip else text}"
-
+ if resolution == "word":
+ resolution_key = "word-segments"
+ elif resolution == "char":
+ resolution_key = "char-segments"
+ else:
+ raise ValueError(".ass resolution should be 'word' or 'char', not ", resolution)
+
ass_arr = []
for segment in transcript:
- curr_words = [wrd['text'] for wrd in segment['word-level']]
- prev = segment['word-level'][0]['start']
- if prev is None:
+ if resolution_key in segment:
+ res_segs = pd.DataFrame(segment[resolution_key])
prev = segment['start']
- for wdx, word in enumerate(segment['word-level']):
- if word['start'] is not None:
- # fill gap between previous word
- if word['start'] > prev:
- filler_ts = {
- "words": curr_words,
- "start": prev,
- "end": word['start'],
- "idx": -1
+ if "speaker" in segment:
+ speaker_str = f"[{segment['speaker']}]: "
+ else:
+ speaker_str = ""
+ for cdx, crow in res_segs.iterrows():
+ if crow['start'] is not None:
+ if resolution == "char":
+ idx_0 = cdx
+ idx_1 = cdx + 1
+ elif resolution == "word":
+ idx_0 = int(crow["segment-text-start"])
+ idx_1 = int(crow["segment-text-end"])
+ # fill gap
+ if crow['start'] > prev:
+ filler_ts = {
+ "chars": speaker_str + segment['text'],
+ "start": prev,
+ "end": crow['start'],
+ "idx_0": -1,
+ "idx_1": -1
+ }
+
+ ass_arr.append(filler_ts)
+ # highlight current word
+ f_word_ts = {
+ "chars": speaker_str + segment['text'],
+ "start": crow['start'],
+ "end": crow['end'],
+ "idx_0": idx_0 + len(speaker_str),
+ "idx_1": idx_1 + len(speaker_str)
}
- ass_arr.append(filler_ts)
-
- # highlight current word
- f_word_ts = {
- "words": curr_words,
- "start": word['start'],
- "end": word['end'],
- "idx": wdx
- }
- ass_arr.append(f_word_ts)
-
- prev = word['end']
-
-
+ ass_arr.append(f_word_ts)
+ prev = crow['end']
ass_str += '\n'.join(map(lambda x: dialogue(**x), ass_arr))
file.write(ass_str)
+
+def interpolate_nans(x, method='nearest'):
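+ # x: pandas Series of word start/end times with NaN for non-aligned words;
+ # interpolation needs at least two non-null values, otherwise just forward/back-fill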
+ if x.notnull().sum() > 1:
+ return x.interpolate(method=method).ffill().bfill()
+ else:
+ return x.ffill().bfill()
\ No newline at end of file