Mirror of https://github.com/m-bain/whisperX.git (synced 2025-07-01 18:17:27 -04:00)
vad filter

This commit adds optional VAD-based filtering to the pipeline: a --vad_filter flag (pyannote voice-activity detection) and a --vad_input option (precomputed RTTM speech/speaker turns), chunked transcription over merged VAD segments, masking of non-speech frames during alignment, a speaker-tagged .vad.srt output, and a fix to the ASS writer's fallback start time. The README usage example is updated accordingly.
@@ -116,14 +116,14 @@ audio_file = "audio.mp3"
 model = whisperx.load_model("large", device)
 result = model.transcribe(audio_file)
 
+print(result["segments"]) # before alignment
+
 # load alignment model and metadata
 model_a, metadata = whisperx.load_align_model(language_code=result["language"], device=device)
 
 # align whisper output
 result_aligned = whisperx.align(result["segments"], model_a, metadata, audio_file, device)
 
-print(result["segments"]) # before alignment
-
 print(result_aligned["segments"]) # after alignment
 print(result_aligned["word_segments"]) # after alignment
 ```
@@ -8,11 +8,12 @@ import torch
 import torchaudio
 from transformers import AutoProcessor, Wav2Vec2ForCTC
 import tqdm
-from .audio import SAMPLE_RATE, N_FRAMES, HOP_LENGTH, pad_or_trim, log_mel_spectrogram, load_audio
+from .audio import SAMPLE_RATE, N_FRAMES, HOP_LENGTH, CHUNK_LENGTH, pad_or_trim, log_mel_spectrogram, load_audio
 from .alignment import get_trellis, backtrack, merge_repeats, merge_words
 from .decoding import DecodingOptions, DecodingResult
 from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
 from .utils import exact_div, format_timestamp, optional_int, optional_float, str2bool, write_txt, write_vtt, write_srt, write_ass
+import pandas as pd
 
 if TYPE_CHECKING:
     from .model import Whisper
@@ -45,6 +46,7 @@ def transcribe(
     logprob_threshold: Optional[float] = -1.0,
     no_speech_threshold: Optional[float] = 0.6,
     condition_on_previous_text: bool = False, # turn off by default due to errors it causes
+    mel: np.ndarray = None,
     **decode_options,
 ):
     """
@@ -100,6 +102,7 @@ def transcribe(
     if dtype == torch.float32:
         decode_options["fp16"] = False
 
-    mel = log_mel_spectrogram(audio)
+    if mel is None:
+        mel = log_mel_spectrogram(audio)
 
     if decode_options.get("language", None) is None:
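The new mel argument lets a caller compute the log-Mel spectrogram once per file and pass a pre-sliced window for each VAD chunk, which is what transcribe_segments() later in this diff relies on. Below is a minimal, self-contained sketch of the frame arithmetic involved; the constants match Whisper's defaults (16 kHz audio, hop length 160, so 100 mel frames per second), and the random tensor is only a stand-in for log_mel_spectrogram(audio):

```python
import torch

SAMPLE_RATE = 16000   # Whisper's audio sample rate
HOP_LENGTH = 160      # STFT hop -> 100 mel frames per second

def seconds_to_frames(t: float) -> int:
    return int(t * SAMPLE_RATE / HOP_LENGTH)

# stand-in for log_mel_spectrogram(audio): 80 mel bins x 60 s worth of frames
mel = torch.randn(80, seconds_to_frames(60.0))

# slice out one VAD chunk, e.g. 10 s - 40 s, and hand it to transcribe(..., mel=local_mel)
local_mel = mel[:, seconds_to_frames(10.0):seconds_to_frames(40.0)]
print(local_mel.shape)  # torch.Size([80, 3000]) -> exactly one 30 s Whisper window
```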
@@ -293,8 +296,10 @@ def align(
     model_type = align_model_metadata['type']
 
     prev_t2 = 0
-    word_segments_list = []
+    total_word_segments_list = []
+    vad_segments_list = []
     for idx, segment in enumerate(transcript):
+        word_segments_list = []
         # first we pad
         t1 = max(segment['start'] - extend_duration, 0)
         t2 = min(segment['end'] + extend_duration, MAX_DURATION)
@@ -326,6 +331,30 @@ def align(
         emissions = torch.log_softmax(emissions, dim=-1)
 
         emission = emissions[0].cpu().detach()
 
+        if "vad" in segment and len(segment['vad']) > 1 and '|' in model_dictionary:
+            ratio = waveform_segment.size(0) / emission.size(0)
+            space_idx = model_dictionary['|']
+            # find non-vad segments
+            for i in range(1, len(segment['vad'])):
+                start = segment['vad'][i-1][1]
+                end = segment['vad'][i][0]
+                if start < end: # check if there is a gap between intervals
+                    non_vad_f1 = int(start / ratio)
+                    non_vad_f2 = int(end / ratio)
+                    # non-vad should be masked, use space to do so
+                    emission[non_vad_f1:non_vad_f2, :] = float("-inf")
+                    emission[non_vad_f1:non_vad_f2, space_idx] = 0
+
+            start = segment['vad'][i][1]
+            end = segment['end']
+            non_vad_f1 = int(start / ratio)
+            non_vad_f2 = int(end / ratio)
+            # non-vad should be masked, use space to do so
+            emission[non_vad_f1:non_vad_f2, :] = float("-inf")
+            emission[non_vad_f1:non_vad_f2, space_idx] = 0
+
         transcription = segment['text'].strip()
         if model_lang not in LANGUAGES_WITHOUT_SPACES:
             t_words = transcription.split(' ')
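The masking added above steers the CTC alignment away from non-speech regions by collapsing the emission distribution onto the word-separator token '|': every other label gets log-probability -inf, and the separator gets log-probability 0 (probability 1). A self-contained sketch of the same trick, using random logits and a made-up vocabulary size and separator index:

```python
import torch

# emission: (num_frames, vocab_size) log-probabilities from the alignment model
emission = torch.log_softmax(torch.randn(100, 32), dim=-1)
space_idx = 4  # made-up index of the '|' word-separator token

# pretend frames 40-60 fall outside every VAD interval
non_vad_f1, non_vad_f2 = 40, 60
emission[non_vad_f1:non_vad_f2, :] = float("-inf")  # rule out every label...
emission[non_vad_f1:non_vad_f2, space_idx] = 0      # ...except the separator (log 1 = 0)

# any alignment path through these frames must now emit '|', so no word
# can be placed inside the masked (non-speech) region
print(emission[45].exp().sum())  # tensor(1.) -> all probability mass on the separator
```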
@@ -400,9 +429,33 @@ def align(
             segment['word-level'].append({"text": segment['text'], "start": segment['start'], "end":segment['end']})
             word_segments_list.append({"text": segment['text'], "start": segment['start'], "end":segment['end']})
 
+        if 'vad' in segment:
+            curr_vdx = 0
+            curr_text = ''
+            for wrd_seg in word_segments_list:
+                if wrd_seg['start'] > segment['vad'][curr_vdx][1]:
+                    curr_speaker = segment['speakers'][curr_vdx]
+                    vad_segments_list.append(
+                        {'start': segment['vad'][curr_vdx][0],
+                         'end': segment['vad'][curr_vdx][1],
+                         'text': f"[{curr_speaker}]: " + curr_text.strip()}
+                    )
+                    curr_vdx += 1
+                    curr_text = ''
+                curr_text += ' ' + wrd_seg['text']
+            if len(curr_text) > 0:
+                curr_speaker = segment['speakers'][curr_vdx]
+                vad_segments_list.append(
+                    {'start': segment['vad'][curr_vdx][0],
+                     'end': segment['vad'][curr_vdx][1],
+                     'text': f"[{curr_speaker}]: " + curr_text.strip()}
+                )
+                curr_text = ''
+        total_word_segments_list += word_segments_list
+
         print(f"[{format_timestamp(segment['start'])} --> {format_timestamp(segment['end'])}] {segment['text']}")
 
-    return {"segments": transcript, "word_segments": word_segments_list}
+    return {"segments": transcript, "word_segments": total_word_segments_list, "vad_segments": vad_segments_list}
 
 def load_align_model(language_code, device, model_name=None):
     if model_name is None:
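For readers following the grouping loop above: each aligned word is folded into the VAD interval it falls in, and every interval becomes one speaker-tagged text chunk. The sketch below is a standalone paraphrase of that logic; the function name and the toy data are illustrative, not part of the diff:

```python
def group_words_by_vad(word_segments, vad_intervals, speakers):
    """Fold aligned words into their VAD interval; one '[speaker]: text' chunk per interval."""
    grouped, curr_vdx, curr_text = [], 0, ''
    for wrd in word_segments:
        if wrd['start'] > vad_intervals[curr_vdx][1]:   # word starts after the current interval ends
            grouped.append({'start': vad_intervals[curr_vdx][0],
                            'end': vad_intervals[curr_vdx][1],
                            'text': f"[{speakers[curr_vdx]}]: " + curr_text.strip()})
            curr_vdx += 1
            curr_text = ''
        curr_text += ' ' + wrd['text']
    if len(curr_text) > 0:                              # flush the last interval
        grouped.append({'start': vad_intervals[curr_vdx][0],
                        'end': vad_intervals[curr_vdx][1],
                        'text': f"[{speakers[curr_vdx]}]: " + curr_text.strip()})
    return grouped

words = [{'start': 0.4, 'text': 'hello'},
         {'start': 0.9, 'text': 'there'},
         {'start': 4.1, 'text': 'goodbye'}]
print(group_words_by_vad(words, [(0.0, 2.0), (4.0, 5.0)], ['SPEAKER 0', 'SPEAKER 1']))
# [{'start': 0.0, 'end': 2.0, 'text': '[SPEAKER 0]: hello there'},
#  {'start': 4.0, 'end': 5.0, 'text': '[SPEAKER 1]: goodbye'}]
```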
@@ -439,6 +492,91 @@ def load_align_model(language_code, device, model_name=None):
 
     return align_model, align_metadata
 
+def merge_chunks(segments, chunk_size=CHUNK_LENGTH, speakers=False):
+    '''
+    Merge VAD segments into larger segments of size ~CHUNK_LENGTH.
+    '''
+    curr_start = 0
+    curr_end = 0
+    merged_segments = []
+    seg_idxs = []
+    speaker_idxs = []
+    for sdx, seg in enumerate(segments):
+        if seg.end - curr_start > chunk_size and curr_end-curr_start > 0:
+            merged_segments.append({
+                "start": curr_start,
+                "end": curr_end,
+                "segments": seg_idxs,
+                "speakers": speaker_idxs,
+            })
+            curr_start = seg.start
+            seg_idxs = []
+            speaker_idxs = []
+        curr_end = seg.end
+        seg_idxs.append((seg.start, seg.end))
+        speaker_idxs.append(seg.speaker)
+    # add final
+    merged_segments.append({
+        "start": curr_start,
+        "end": curr_end,
+        "segments": seg_idxs,
+        "speakers": speaker_idxs
+    })
+    return merged_segments
+
+
+def transcribe_segments(
+    model: "Whisper",
+    audio: Union[str, np.ndarray, torch.Tensor],
+    merged_segments,
+    mel = None,
+    **kwargs
+):
+    '''
+    Transcribe according to predefined VAD segments.
+    '''
+    if mel is None:
+        mel = log_mel_spectrogram(audio)
+
+    prev = 0
+    output = {'segments': []}
+
+    for sdx, seg_t in enumerate(merged_segments):
+        print(sdx, seg_t['start'], seg_t['end'], '...')
+        seg_f_start, seg_f_end = int(seg_t['start'] * SAMPLE_RATE / HOP_LENGTH), int(seg_t['end'] * SAMPLE_RATE / HOP_LENGTH)
+        local_f_start, local_f_end = seg_f_start - prev, seg_f_end - prev
+        mel = mel[:, local_f_start:] # seek forward
+        prev = seg_f_start
+        local_mel = mel[:, :local_f_end-local_f_start]
+        result = transcribe(model, audio, mel=local_mel, **kwargs)
+        seg_t['text'] = result['text']
+        output['segments'].append(
+            {
+                'start': seg_t['start'],
+                'end': seg_t['end'],
+                'language': result['language'],
+                'text': result['text'],
+                'seg-text': [x['text'] for x in result['segments']],
+                'seg-start': [x['start'] for x in result['segments']],
+                'seg-end': [x['end'] for x in result['segments']],
+            }
+        )
+
+    output['language'] = output['segments'][0]['language']
+
+    return output
+
+
+class Segment:
+    def __init__(self, start, end, speaker=None):
+        self.start = start
+        self.end = end
+        self.speaker = speaker
+
+
 def cli():
     from . import available_models
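A quick illustration of how Segment and merge_chunks() behave: consecutive VAD turns are packed into windows of roughly chunk_size seconds, each keeping the (start, end) pairs and speaker labels of the turns it contains. The import path and the concrete numbers below are assumptions for the sake of the example (CHUNK_LENGTH itself comes from .audio and corresponds to Whisper's 30 s chunk size):

```python
# hypothetical import path; the diff only shows these helpers being defined in this module
from whisperx.transcribe import Segment, merge_chunks

vad_turns = [
    Segment(0.0, 12.5, "SPEAKER 0"),
    Segment(13.0, 21.0, "SPEAKER 1"),
    Segment(22.0, 35.0, "SPEAKER 0"),
    Segment(36.0, 41.0, "SPEAKER 1"),
]

for chunk in merge_chunks(vad_turns, chunk_size=30):
    print(chunk["start"], chunk["end"], chunk["speakers"])
# -> 0.0 21.0 ['SPEAKER 0', 'SPEAKER 1']   (window closes before it would exceed ~30 s)
# -> 22.0 41.0 ['SPEAKER 0', 'SPEAKER 1']  (final window flushed after the loop)
```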
@@ -452,7 +590,8 @@ def cli():
     parser.add_argument("--align_extend", default=2, type=float, help="Seconds before and after to extend the whisper segments for alignment")
     parser.add_argument("--align_from_prev", default=True, type=bool, help="Whether to clip the alignment start time of current segment to the end time of the last aligned word of the previous segment")
     parser.add_argument("--drop_non_aligned", action="store_true", help="For word .srt, whether to drop non aliged words, or merge them into neighbouring.")
+    parser.add_argument("--vad_filter", action="store_true", help="Whether to first perform VAD filtering to target only transcribe within VAD...")
+    parser.add_argument("--vad_input", default=None, type=str)
     parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
     parser.add_argument("--output_type", default="srt", choices=['all', 'srt', 'vtt', 'txt'], help="File type for desired output save")
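With these two flags the CLI can be driven either by on-the-fly VAD or by precomputed speech regions. An illustrative invocation (the whisperx entry point and the pre-existing flags are taken from the rest of the CLI; file names are placeholders):

```
whisperx audio.wav --model large --vad_filter --output_dir out --output_type srt

# or, with speech/speaker turns precomputed in RTTM format:
whisperx audio.wav --model large --vad_input audio.rttm --output_dir out
```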
@@ -490,6 +629,17 @@ def cli():
     align_from_prev: bool = args.pop("align_from_prev")
     drop_non_aligned: bool = args.pop("drop_non_aligned")
 
+    vad_filter: bool = args.pop("vad_filter")
+    vad_input: bool = args.pop("vad_input")
+
+    vad_pipeline = None
+    if vad_input is not None:
+        vad_input = pd.read_csv(vad_input, header=None, sep= " ")
+    elif vad_filter:
+        from pyannote.audio import Pipeline
+        vad_pipeline = Pipeline.from_pretrained("pyannote/voice-activity-detection")
+        # vad_pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization@2.1")
+
     os.makedirs(output_dir, exist_ok=True)
 
     if model_name.endswith(".en") and args["language"] not in {"en", "English"}:
@@ -515,6 +665,24 @@ def cli():
         align_model, align_metadata = load_align_model(align_language, device, model_name=align_model)
 
     for audio_path in args.pop("audio"):
-        print("Performing transcription...")
-        result = transcribe(model, audio_path, temperature=temperature, **args)
+        if vad_filter or vad_input is not None:
+            output_segments = []
+            if vad_filter:
+                print("Performing VAD...")
+                # vad_segments = vad_pipeline(audio_path)
+                # for speech_turn, track, speaker in vad_segments.itertracks(yield_label=True):
+                #     output_segments.append(Segment(speech_turn.start, speech_turn.end, speaker))
+                vad_segments = vad_pipeline(audio_path)
+                for speech_turn in vad_segments.get_timeline().support():
+                    output_segments.append(Segment(speech_turn.start, speech_turn.end, "UNKNOWN"))
+            elif vad_input is not None:
+                # rttm format
+                for idx, row in vad_input.iterrows():
+                    output_segments.append(Segment(row[3], row[3]+row[4], f"SPEAKER {row[7]}"))
+            vad_segments = merge_chunks(output_segments)
+            result = transcribe_segments(model, audio_path, merged_segments=vad_segments, temperature=temperature, **args)
+        else:
+            vad_segments = None
+            print("Performing transcription...")
+            result = transcribe(model, audio_path, temperature=temperature, **args)
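For the --vad_input branch: a standard RTTM line is whitespace-separated with the onset in column 3, the duration in column 4, and the speaker label in column 7, which is what the row[3] / row[4] / row[7] indexing above relies on. A self-contained example of that parsing (the sample line and values are made up):

```python
import io
import pandas as pd

# columns: type, file, channel, onset, duration, <NA>, <NA>, speaker, <NA>, <NA>
rttm = "SPEAKER demo 1 2.25 1.50 <NA> <NA> spk0 <NA> <NA>\n"
vad_input = pd.read_csv(io.StringIO(rttm), header=None, sep=" ")

for _, row in vad_input.iterrows():
    start, end, speaker = row[3], row[3] + row[4], f"SPEAKER {row[7]}"
    print(start, end, speaker)  # 2.25 3.75 SPEAKER spk0
```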
@@ -548,8 +716,13 @@ def cli():
                 write_srt(result_aligned["word_segments"], file=srt)
 
         # save ASS
-        # with open(os.path.join(output_dir, audio_basename + ".ass"), "w", encoding="utf-8") as ass:
-        #     write_ass(result_aligned["segments"], file=ass)
+        with open(os.path.join(output_dir, audio_basename + ".ass"), "w", encoding="utf-8") as ass:
+            write_ass(result_aligned["segments"], file=ass)
+
+        if vad_filter is not None:
+            # save per-word SRT
+            with open(os.path.join(output_dir, audio_basename + ".vad.srt"), "w", encoding="utf-8") as srt:
+                write_srt(result_aligned["vad_segments"], file=srt)
 
 if __name__ == '__main__':
@@ -193,7 +193,7 @@ def write_ass(transcript: Iterator[dict], file: TextIO,
         curr_words = [wrd['text'] for wrd in segment['word-level']]
         prev = segment['word-level'][0]['start']
         if prev is None:
-            prev = 0
+            prev = segment['start']
         for wdx, word in enumerate(segment['word-level']):
             if word['start'] is not None:
                 # fill gap between previous word