vad filter

Max Bain
2023-01-20 12:54:20 +00:00
parent 78c87d3bfd
commit ba102feb7f
3 changed files with 185 additions and 12 deletions

View File

@@ -116,14 +116,14 @@ audio_file = "audio.mp3"
model = whisperx.load_model("large", device)
result = model.transcribe(audio_file)
print(result["segments"]) # before alignment
# load alignment model and metadata
model_a, metadata = whisperx.load_align_model(language_code=result["language"], device=device)
# align whisper output
result_aligned = whisperx.align(result["segments"], model_a, metadata, audio_file, device)
print(result["segments"]) # before alignment
print(result_aligned["segments"]) # after alignment
print(result_aligned["word_segments"]) # after alignment
```

View File

@@ -8,11 +8,12 @@ import torch
import torchaudio
from transformers import AutoProcessor, Wav2Vec2ForCTC
import tqdm
from .audio import SAMPLE_RATE, N_FRAMES, HOP_LENGTH, pad_or_trim, log_mel_spectrogram, load_audio
from .audio import SAMPLE_RATE, N_FRAMES, HOP_LENGTH, CHUNK_LENGTH, pad_or_trim, log_mel_spectrogram, load_audio
from .alignment import get_trellis, backtrack, merge_repeats, merge_words
from .decoding import DecodingOptions, DecodingResult
from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
from .utils import exact_div, format_timestamp, optional_int, optional_float, str2bool, write_txt, write_vtt, write_srt, write_ass
import pandas as pd
if TYPE_CHECKING:
from .model import Whisper
@@ -45,6 +46,7 @@ def transcribe(
logprob_threshold: Optional[float] = -1.0,
no_speech_threshold: Optional[float] = 0.6,
condition_on_previous_text: bool = False, # turn off by default due to errors it causes
mel: np.ndarray = None,
**decode_options,
):
"""
@@ -100,7 +102,8 @@
if dtype == torch.float32:
decode_options["fp16"] = False
mel = log_mel_spectrogram(audio)
if mel is None:
mel = log_mel_spectrogram(audio)
if decode_options.get("language", None) is None:
if not model.is_multilingual:
@@ -293,8 +296,10 @@ def align(
model_type = align_model_metadata['type']
prev_t2 = 0
word_segments_list = []
total_word_segments_list = []
vad_segments_list = []
for idx, segment in enumerate(transcript):
word_segments_list = []
# first we pad
t1 = max(segment['start'] - extend_duration, 0)
t2 = min(segment['end'] + extend_duration, MAX_DURATION)
@@ -326,6 +331,30 @@
emissions = torch.log_softmax(emissions, dim=-1)
emission = emissions[0].cpu().detach()
if "vad" in segment and len(segment['vad']) > 1 and '|' in model_dictionary:
ratio = waveform_segment.size(0) / emission.size(0)
space_idx = model_dictionary['|']
# find non-vad segments
for i in range(1, len(segment['vad'])):
start = segment['vad'][i-1][1]
end = segment['vad'][i][0]
if start < end: # check if there is a gap between intervals
non_vad_f1 = int(start / ratio)
non_vad_f2 = int(end / ratio)
# non-vad should be masked, use space to do so
emission[non_vad_f1:non_vad_f2, :] = float("-inf")
emission[non_vad_f1:non_vad_f2, space_idx] = 0
start = segment['vad'][i][1]
end = segment['end']
non_vad_f1 = int(start / ratio)
non_vad_f2 = int(end / ratio)
# non-vad should be masked, use space to do so
emission[non_vad_f1:non_vad_f2, :] = float("-inf")
emission[non_vad_f1:non_vad_f2, space_idx] = 0
transcription = segment['text'].strip()
if model_lang not in LANGUAGES_WITHOUT_SPACES:
t_words = transcription.split(' ')
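The block above masks emission frames that fall between VAD intervals so the CTC alignment cannot place characters inside non-speech gaps: every token gets log-probability -inf there except the word separator '|', which gets 0. A minimal, self-contained sketch of the same masking idea (tensor sizes, sample counts and interval values are invented for illustration):

```python
import torch

# toy CTC emission matrix: 100 frames x 5 tokens, token 0 playing the role of '|'
emission = torch.randn(100, 5).log_softmax(dim=-1)
space_idx = 0

# suppose the audio segment is 16000 samples long -> 160 samples per emission frame
ratio = 16000 / emission.size(0)

# VAD says speech occupies [0, 4000] and [9600, 16000] (in samples)
vad = [(0, 4000), (9600, 16000)]

# mask every gap between consecutive speech intervals
for i in range(1, len(vad)):
    gap_start, gap_end = vad[i - 1][1], vad[i][0]
    if gap_start < gap_end:
        f1, f2 = int(gap_start / ratio), int(gap_end / ratio)
        emission[f1:f2, :] = float("-inf")  # forbid all tokens in the gap...
        emission[f1:f2, space_idx] = 0      # ...except the separator (log-prob 0)

print(emission[int(5000 / ratio)])  # a frame inside the gap: only '|' survives
```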
@@ -400,9 +429,33 @@ def align(
segment['word-level'].append({"text": segment['text'], "start": segment['start'], "end":segment['end']})
word_segments_list.append({"text": segment['text'], "start": segment['start'], "end":segment['end']})
if 'vad' in segment:
curr_vdx = 0
curr_text = ''
for wrd_seg in word_segments_list:
if wrd_seg['start'] > segment['vad'][curr_vdx][1]:
curr_speaker = segment['speakers'][curr_vdx]
vad_segments_list.append(
{'start': segment['vad'][curr_vdx][0],
'end': segment['vad'][curr_vdx][1],
'text': f"[{curr_speaker}]: " + curr_text.strip()}
)
curr_vdx += 1
curr_text = ''
curr_text += ' ' + wrd_seg['text']
if len(curr_text) > 0:
curr_speaker = segment['speakers'][curr_vdx]
vad_segments_list.append(
{'start': segment['vad'][curr_vdx][0],
'end': segment['vad'][curr_vdx][1],
'text': f"[{curr_speaker}]: " + curr_text.strip()}
)
curr_text = ''
total_word_segments_list += word_segments_list
print(f"[{format_timestamp(segment['start'])} --> {format_timestamp(segment['end'])}] {segment['text']}")
return {"segments": transcript, "word_segments": word_segments_list}
return {"segments": transcript, "word_segments": total_word_segments_list, "vad_segments": vad_segments_list}
def load_align_model(language_code, device, model_name=None):
if model_name is None:
@@ -439,6 +492,91 @@ def load_align_model(language_code, device, model_name=None):
return align_model, align_metadata
def merge_chunks(segments, chunk_size=CHUNK_LENGTH, speakers=False):
'''
Merge VAD segments into larger segments of size ~CHUNK_LENGTH.
'''
curr_start = 0
curr_end = 0
merged_segments = []
seg_idxs = []
speaker_idxs = []
for sdx, seg in enumerate(segments):
if seg.end - curr_start > chunk_size and curr_end-curr_start > 0:
merged_segments.append({
"start": curr_start,
"end": curr_end,
"segments": seg_idxs,
"speakers": speaker_idxs,
})
curr_start = seg.start
seg_idxs = []
speaker_idxs = []
curr_end = seg.end
seg_idxs.append((seg.start, seg.end))
speaker_idxs.append(seg.speaker)
# add final
merged_segments.append({
"start": curr_start,
"end": curr_end,
"segments": seg_idxs,
"speakers": speaker_idxs
})
return merged_segments
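merge_chunks greedily packs consecutive VAD segments into windows of roughly chunk_size seconds (CHUNK_LENGTH, i.e. the ~30 s receptive field Whisper transcribes at once), remembering the original segment boundaries and speakers so they can be reattached after alignment. A quick usage sketch, assuming merge_chunks and the Segment class defined further down are importable; the timings are invented:

```python
# three hypothetical speech turns; the third starts well past the 30 s window
turns = [Segment(0.0, 12.0, "SPEAKER_00"),
         Segment(13.0, 27.0, "SPEAKER_01"),
         Segment(40.0, 55.0, "SPEAKER_00")]

for chunk in merge_chunks(turns, chunk_size=30):
    print(chunk["start"], chunk["end"], chunk["speakers"])
# 0.0 27.0 ['SPEAKER_00', 'SPEAKER_01']   <- first two turns share one window
# 40.0 55.0 ['SPEAKER_00']                <- the late turn opens a new window
```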
def transcribe_segments(
model: "Whisper",
audio: Union[str, np.ndarray, torch.Tensor],
merged_segments,
mel = None,
**kwargs
):
'''
Transcribe according to predefined VAD segments.
'''
if mel is None:
mel = log_mel_spectrogram(audio)
prev = 0
output = {'segments': []}
for sdx, seg_t in enumerate(merged_segments):
print(sdx, seg_t['start'], seg_t['end'], '...')
seg_f_start, seg_f_end = int(seg_t['start'] * SAMPLE_RATE / HOP_LENGTH), int(seg_t['end'] * SAMPLE_RATE / HOP_LENGTH)
local_f_start, local_f_end = seg_f_start - prev, seg_f_end - prev
mel = mel[:, local_f_start:] # seek forward
prev = seg_f_start
local_mel = mel[:, :local_f_end-local_f_start]
result = transcribe(model, audio, mel=local_mel, **kwargs)
seg_t['text'] = result['text']
output['segments'].append(
{
'start': seg_t['start'],
'end': seg_t['end'],
'language': result['language'],
'text': result['text'],
'seg-text': [x['text'] for x in result['segments']],
'seg-start': [x['start'] for x in result['segments']],
'seg-end': [x['end'] for x in result['segments']],
}
)
output['language'] = output['segments'][0]['language']
return output
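transcribe_segments converts each chunk's start and end times into mel-frame indices and slices the precomputed spectrogram forward, so only the chunk's own features are handed to transcribe(..., mel=...). With Whisper's standard SAMPLE_RATE of 16000 and HOP_LENGTH of 160, one second of audio corresponds to 100 mel frames; a small numeric sketch of that arithmetic:

```python
SAMPLE_RATE, HOP_LENGTH = 16000, 160          # standard Whisper constants

def to_frames(seconds: float) -> int:
    return int(seconds * SAMPLE_RATE / HOP_LENGTH)   # 100 frames per second

# a merged VAD chunk covering 12.5 s .. 41.0 s of the original audio
seg_f_start, seg_f_end = to_frames(12.5), to_frames(41.0)
print(seg_f_start, seg_f_end)                 # 1250 4100

# after seeking (prev = seg_f_start) the chunk occupies the first
# seg_f_end - seg_f_start frames of the remaining spectrogram
print(seg_f_end - seg_f_start)                # 2850 frames, i.e. 28.5 s of mel features
```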
class Segment:
def __init__(self, start, end, speaker=None):
self.start = start
self.end = end
self.speaker = speaker
def cli():
from . import available_models
@@ -452,7 +590,8 @@ def cli():
parser.add_argument("--align_extend", default=2, type=float, help="Seconds before and after to extend the whisper segments for alignment")
parser.add_argument("--align_from_prev", default=True, type=bool, help="Whether to clip the alignment start time of current segment to the end time of the last aligned word of the previous segment")
parser.add_argument("--drop_non_aligned", action="store_true", help="For word .srt, whether to drop non aliged words, or merge them into neighbouring.")
parser.add_argument("--vad_filter", action="store_true", help="Whether to first perform VAD filtering to target only transcribe within VAD...")
parser.add_argument("--vad_input", default=None, type=str)
parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
parser.add_argument("--output_type", default="srt", choices=['all', 'srt', 'vtt', 'txt'], help="File type for desired output save")
@@ -490,6 +629,17 @@ def cli():
align_from_prev: bool = args.pop("align_from_prev")
drop_non_aligned: bool = args.pop("drop_non_aligned")
vad_filter: bool = args.pop("vad_filter")
vad_input: Optional[str] = args.pop("vad_input")
vad_pipeline = None
if vad_input is not None:
vad_input = pd.read_csv(vad_input, header=None, sep= " ")
elif vad_filter:
from pyannote.audio import Pipeline
vad_pipeline = Pipeline.from_pretrained("pyannote/voice-activity-detection")
# vad_pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization@2.1")
os.makedirs(output_dir, exist_ok=True)
if model_name.endswith(".en") and args["language"] not in {"en", "English"}:
@@ -515,8 +665,26 @@ def cli():
align_model, align_metadata = load_align_model(align_language, device, model_name=align_model)
for audio_path in args.pop("audio"):
print("Performing transcription...")
result = transcribe(model, audio_path, temperature=temperature, **args)
if vad_filter or vad_input is not None:
output_segments = []
if vad_filter:
print("Performing VAD...")
# vad_segments = vad_pipeline(audio_path)
# for speech_turn, track, speaker in vad_segments.itertracks(yield_label=True):
# output_segments.append(Segment(speech_turn.start, speech_turn.end, speaker))
vad_segments = vad_pipeline(audio_path)
for speech_turn in vad_segments.get_timeline().support():
output_segments.append(Segment(speech_turn.start, speech_turn.end, "UNKNOWN"))
elif vad_input is not None:
# rttm format
for idx, row in vad_input.iterrows():
output_segments.append(Segment(row[3], row[3]+row[4], f"SPEAKER {row[7]}"))
vad_segments = merge_chunks(output_segments)
result = transcribe_segments(model, audio_path, merged_segments=vad_segments, temperature=temperature, **args)
else:
vad_segments = None
print("Performing transcription...")
result = transcribe(model, audio_path, temperature=temperature, **args)
if result["language"] != align_metadata["language"]:
# load new language
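End to end, the new CLI path runs pyannote VAD, collapses the result into non-overlapping speech regions via get_timeline().support(), wraps them in Segment objects, merges them into ~30 s chunks, and transcribes chunk by chunk. A hedged sketch of that flow, assuming pyannote.audio is installed (recent releases may additionally require a Hugging Face auth token via use_auth_token), model is an already-loaded Whisper model, and Segment, merge_chunks and transcribe_segments come from this file:

```python
from pyannote.audio import Pipeline

vad_pipeline = Pipeline.from_pretrained("pyannote/voice-activity-detection")

annotation = vad_pipeline("audio.mp3")
# support() merges overlapping/adjacent speech turns into plain speech regions
speech = [Segment(turn.start, turn.end, "UNKNOWN")
          for turn in annotation.get_timeline().support()]

chunks = merge_chunks(speech)                              # ~30 s windows of speech
result = transcribe_segments(model, "audio.mp3",
                             merged_segments=chunks, temperature=0)
for seg in result["segments"]:
    print(seg["start"], seg["end"], seg["text"])
```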
@@ -548,8 +716,13 @@ def cli():
write_srt(result_aligned["word_segments"], file=srt)
# save ASS
# with open(os.path.join(output_dir, audio_basename + ".ass"), "w", encoding="utf-8") as ass:
# write_ass(result_aligned["segments"], file=ass)
with open(os.path.join(output_dir, audio_basename + ".ass"), "w", encoding="utf-8") as ass:
write_ass(result_aligned["segments"], file=ass)
if vad_filter or vad_input is not None:
# save per-VAD-segment SRT
with open(os.path.join(output_dir, audio_basename + ".vad.srt"), "w", encoding="utf-8") as srt:
write_srt(result_aligned["vad_segments"], file=srt)
if __name__ == '__main__':

View File

@@ -193,7 +193,7 @@ def write_ass(transcript: Iterator[dict], file: TextIO,
curr_words = [wrd['text'] for wrd in segment['word-level']]
prev = segment['word-level'][0]['start']
if prev is None:
prev = 0
prev = segment['start']
for wdx, word in enumerate(segment['word-level']):
if word['start'] is not None:
# fill gap between previous word