new logic, diarization, vad filtering

This commit is contained in:
Max Bain
2023-01-24 15:02:08 +00:00
parent ba102feb7f
commit d395c21b83
8 changed files with 498 additions and 260 deletions

View File

@ -12,7 +12,7 @@ from .audio import SAMPLE_RATE, N_FRAMES, HOP_LENGTH, CHUNK_LENGTH, pad_or_trim,
from .alignment import get_trellis, backtrack, merge_repeats, merge_words
from .decoding import DecodingOptions, DecodingResult
from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
from .utils import exact_div, format_timestamp, optional_int, optional_float, str2bool, write_txt, write_vtt, write_srt, write_ass
from .utils import exact_div, format_timestamp, optional_int, optional_float, str2bool, interpolate_nans, write_txt, write_vtt, write_srt, write_ass, write_tsv
import pandas as pd
if TYPE_CHECKING:
@ -280,8 +280,39 @@ def align(
device: str,
extend_duration: float = 0.0,
start_from_previous: bool = True,
drop_non_aligned_words: bool = False,
interpolate_method: str = "nearest",
):
"""
Force align phoneme recognition predictions to known transcription
Parameters
----------
transcript: Iterator[dict]
The Whisper model instance
model: torch.nn.Module
Alignment model (wav2vec2)
audio: Union[str, np.ndarray, torch.Tensor]
The path to the audio file to open, or the audio waveform
device: str
cuda device
extend_duration: float
Amount to pad input segments by. If not using vad--filter then recommended to use 2 seconds
If the gzip compression ratio is above this value, treat as failed
interpolate_method: str ["nearest", "linear", "ignore"]
Method to assign timestamps to non-aligned words. Words are not able to be aligned when none of the characters occur in the align model dictionary.
"nearest" copies timestamp of nearest word within the segment. "linear" is linear interpolation. "drop" removes that word from output.
Returns
-------
A dictionary containing the resulting text ("text") and segment-level details ("segments"), and
the spoken language ("language"), which is detected when `decode_options["language"]` is None.
"""
if not torch.is_tensor(audio):
if isinstance(audio, str):
audio = load_audio(audio)
@ -291,171 +322,266 @@ def align(
MAX_DURATION = audio.shape[1] / SAMPLE_RATE
model_dictionary = align_model_metadata['dictionary']
model_lang = align_model_metadata['language']
model_type = align_model_metadata['type']
model_dictionary = align_model_metadata["dictionary"]
model_lang = align_model_metadata["language"]
model_type = align_model_metadata["type"]
aligned_segments = []
prev_t2 = 0
total_word_segments_list = []
vad_segments_list = []
for idx, segment in enumerate(transcript):
word_segments_list = []
# first we pad
t1 = max(segment['start'] - extend_duration, 0)
t2 = min(segment['end'] + extend_duration, MAX_DURATION)
sdx = 0
for segment in transcript:
while True:
segment_align_success = False
# use prev_t2 as current t1 if it's later
if start_from_previous and t1 < prev_t2:
t1 = prev_t2
# strip spaces at beginning / end, but keep track of the amount.
num_leading = len(segment["text"]) - len(segment["text"].lstrip())
num_trailing = len(segment["text"]) - len(segment["text"].rstrip())
transcription = segment["text"]
# check if timestamp range is still valid
if t1 >= MAX_DURATION:
print("Failed to align segment: original start time longer than audio duration, skipping...")
continue
if t2 - t1 < 0.02:
print("Failed to align segment: duration smaller than 0.02s time precision")
continue
# TODO: convert number tokenizer / symbols to phonetic words for alignment.
# e.g. "$300" -> "three hundred dollars"
# currently "$300" is ignored since no characters present in the phonetic dictionary
f1 = int(t1 * SAMPLE_RATE)
f2 = int(t2 * SAMPLE_RATE)
waveform_segment = audio[:, f1:f2]
with torch.inference_mode():
if model_type == "torchaudio":
emissions, _ = model(waveform_segment.to(device))
elif model_type == "huggingface":
emissions = model(waveform_segment.to(device)).logits
# split into words
if model_lang not in LANGUAGES_WITHOUT_SPACES:
per_word = transcription.split(" ")
else:
raise NotImplementedError(f"Align model of type {model_type} not supported.")
emissions = torch.log_softmax(emissions, dim=-1)
per_word = transcription
emission = emissions[0].cpu().detach()
# first check that characters in transcription can be aligned (they are contained in align model"s dictionary)
clean_char, clean_cdx = [], []
for cdx, char in enumerate(transcription):
char_ = char.lower()
# wav2vec2 models use "|" character to represent spaces
if model_lang not in LANGUAGES_WITHOUT_SPACES:
char_ = char_.replace(" ", "|")
# ignore whitespace at beginning and end of transcript
if cdx < num_leading:
pass
elif cdx > len(transcription) - num_trailing - 1:
pass
elif char_ in model_dictionary.keys():
clean_char.append(char_)
clean_cdx.append(cdx)
if "vad" in segment and len(segment['vad']) > 1 and '|' in model_dictionary:
ratio = waveform_segment.size(0) / emission.size(0)
space_idx = model_dictionary['|']
# find non-vad segments
for i in range(1, len(segment['vad'])):
start = segment['vad'][i-1][1]
end = segment['vad'][i][0]
if start < end: # check if there is a gap between intervals
non_vad_f1 = int(start / ratio)
non_vad_f2 = int(end / ratio)
# non-vad should be masked, use space to do so
emission[non_vad_f1:non_vad_f2, :] = float("-inf")
emission[non_vad_f1:non_vad_f2, space_idx] = 0
clean_wdx = []
for wdx, wrd in enumerate(per_word):
if any([c in model_dictionary.keys() for c in wrd]):
clean_wdx.append(wdx)
start = segment['vad'][i][1]
end = segment['end']
non_vad_f1 = int(start / ratio)
non_vad_f2 = int(end / ratio)
# non-vad should be masked, use space to do so
emission[non_vad_f1:non_vad_f2, :] = float("-inf")
emission[non_vad_f1:non_vad_f2, space_idx] = 0
transcription = segment['text'].strip()
if model_lang not in LANGUAGES_WITHOUT_SPACES:
t_words = transcription.split(' ')
else:
t_words = [c for c in transcription]
t_words_clean = [''.join([w for w in word if w.lower() in model_dictionary.keys()]) for word in t_words]
t_words_nonempty = [x for x in t_words_clean if x != ""]
t_words_nonempty_idx = [x for x in range(len(t_words_clean)) if t_words_clean[x] != ""]
segment['word-level'] = []
fail_fallback = False
if len(t_words_nonempty) > 0:
transcription_cleaned = "|".join(t_words_nonempty).lower()
# if no characters are in the dictionary, then we skip this segment...
if len(clean_char) == 0:
print("Failed to align segment: no characters in this segment found in model dictionary, resorting to original...")
break
transcription_cleaned = "".join(clean_char)
tokens = [model_dictionary[c] for c in transcription_cleaned]
# pad according original timestamps
t1 = max(segment["start"] - extend_duration, 0)
t2 = min(segment["end"] + extend_duration, MAX_DURATION)
# use prev_t2 as current t1 if it"s later
if start_from_previous and t1 < prev_t2:
t1 = prev_t2
# check if timestamp range is still valid
if t1 >= MAX_DURATION:
print("Failed to align segment: original start time longer than audio duration, skipping...")
break
if t2 - t1 < 0.02:
print("Failed to align segment: duration smaller than 0.02s time precision")
break
f1 = int(t1 * SAMPLE_RATE)
f2 = int(t2 * SAMPLE_RATE)
waveform_segment = audio[:, f1:f2]
with torch.inference_mode():
if model_type == "torchaudio":
emissions, _ = model(waveform_segment.to(device))
elif model_type == "huggingface":
emissions = model(waveform_segment.to(device)).logits
else:
raise NotImplementedError(f"Align model of type {model_type} not supported.")
emissions = torch.log_softmax(emissions, dim=-1)
emission = emissions[0].cpu().detach()
trellis = get_trellis(emission, tokens)
path = backtrack(trellis, emission, tokens)
if path is None:
print("Failed to align segment: backtrack failed, resorting to original...")
fail_fallback = True
else:
segments = merge_repeats(path, transcription_cleaned)
word_segments = merge_words(segments)
ratio = waveform_segment.size(0) / (trellis.size(0) - 1)
break
char_segments = merge_repeats(path, transcription_cleaned)
# word_segments = merge_words(char_segments)
duration = t2 - t1
local = []
t_local = [None] * len(t_words)
for wdx, word in enumerate(word_segments):
t1_ = ratio * word.start
t2_ = ratio * word.end
local.append((t1_, t2_))
t_local[t_words_nonempty_idx[wdx]] = (t1_ * duration + t1, t2_ * duration + t1)
t1_actual = t1 + local[0][0] * duration
t2_actual = t1 + local[-1][1] * duration
# sub-segments
if "seg-text" not in segment:
segment["seg-text"] = [transcription]
v = 0
seg_lens = [0] + [len(x) for x in segment["seg-text"]]
seg_lens_cumsum = [v := v + n for n in seg_lens]
sub_seg_idx = 0
segment['start'] = t1_actual
segment['end'] = t2_actual
prev_t2 = segment['end']
char_level = {
"start": [],
"end": [],
"score": [],
"word-index": [],
}
# for the .ass output
for x in range(len(t_local)):
curr_word = t_words[x]
curr_timestamp = t_local[x]
if curr_timestamp is not None:
segment['word-level'].append({"text": curr_word, "start": curr_timestamp[0], "end": curr_timestamp[1]})
word_level = {
"start": [],
"end": [],
"score": [],
"segment-text-start": [],
"segment-text-end": []
}
wdx = 0
seg_start_actual, seg_end_actual = None, None
duration = t2 - t1
ratio = duration * waveform_segment.size(0) / (trellis.size(0) - 1)
cdx_prev = 0
for cdx, char in enumerate(transcription + " "):
is_last = False
if cdx == len(transcription):
break
elif cdx+1 == len(transcription):
is_last = True
start, end, score = None, None, None
if cdx in clean_cdx:
char_seg = char_segments[clean_cdx.index(cdx)]
start = char_seg.start * ratio + t1
end = char_seg.end * ratio + t1
score = char_seg.score
char_level["start"].append(start)
char_level["end"].append(end)
char_level["score"].append(score)
char_level["word-index"].append(wdx)
# word-level info
if model_lang in LANGUAGES_WITHOUT_SPACES:
# character == word
wdx += 1
elif is_last or transcription[cdx+1] == " " or cdx == seg_lens_cumsum[sub_seg_idx+1] - 1:
wdx += 1
word_level["start"].append(None)
word_level["end"].append(None)
word_level["score"].append(None)
word_level["segment-text-start"].append(cdx_prev-seg_lens_cumsum[sub_seg_idx])
word_level["segment-text-end"].append(cdx+1-seg_lens_cumsum[sub_seg_idx])
cdx_prev = cdx+2
if is_last or cdx == seg_lens_cumsum[sub_seg_idx+1] - 1:
if model_lang not in LANGUAGES_WITHOUT_SPACES:
char_level = pd.DataFrame(char_level)
word_level = pd.DataFrame(word_level)
not_space = pd.Series(list(segment["seg-text"][sub_seg_idx])) != " "
word_level["start"] = char_level[not_space].groupby("word-index")["start"].min() # take min of all chars in a word ignoring space
word_level["end"] = char_level[not_space].groupby("word-index")["end"].max() # take max of all chars in a word
# fill missing
if interpolate_method != "ignore":
word_level["start"] = interpolate_nans(word_level["start"], method=interpolate_method)
word_level["end"] = interpolate_nans(word_level["end"], method=interpolate_method)
word_level["start"] = word_level["start"].values.tolist()
word_level["end"] = word_level["end"].values.tolist()
word_level["score"] = char_level.groupby("word-index")["score"].mean() # take mean of all scores
char_level = char_level.replace({np.nan:None}).to_dict("list")
word_level = pd.DataFrame(word_level).replace({np.nan:None}).to_dict("list")
else:
segment['word-level'].append({"text": curr_word, "start": None, "end": None})
word_level = None
# for per-word .srt ouput
# merge missing words to previous, or merge with next word ahead if idx == 0
found_first_ts = False
for x in range(len(t_local)):
curr_word = t_words[x]
curr_timestamp = t_local[x]
if curr_timestamp is not None:
word_segments_list.append({"text": curr_word, "start": curr_timestamp[0], "end": curr_timestamp[1]})
found_first_ts = True
elif not drop_non_aligned_words:
# then we merge
if not found_first_ts:
t_words[x+1] = " ".join([curr_word, t_words[x+1]])
else:
word_segments_list[-1]['text'] += ' ' + curr_word
else:
fail_fallback = True
if fail_fallback:
# then we resort back to original whisper timestamps
# segment['start] and segment['end'] are unchanged
prev_t2 = 0
segment['word-level'].append({"text": segment['text'], "start": segment['start'], "end":segment['end']})
word_segments_list.append({"text": segment['text'], "start": segment['start'], "end":segment['end']})
if 'vad' in segment:
curr_vdx = 0
curr_text = ''
for wrd_seg in word_segments_list:
if wrd_seg['start'] > segment['vad'][curr_vdx][1]:
curr_speaker = segment['speakers'][curr_vdx]
vad_segments_list.append(
{'start': segment['vad'][curr_vdx][0],
'end': segment['vad'][curr_vdx][1],
'text': f"[{curr_speaker}]: " + curr_text.strip()}
aligned_segments.append(
{
"text": segment["seg-text"][sub_seg_idx],
"start": seg_start_actual,
"end": seg_end_actual,
"char-segments": char_level,
"word-segments": word_level
}
)
curr_vdx += 1
curr_text = ''
curr_text += ' ' + wrd_seg['text']
if len(curr_text) > 0:
curr_speaker = segment['speakers'][curr_vdx]
vad_segments_list.append(
{'start': segment['vad'][curr_vdx][0],
'end': segment['vad'][curr_vdx][1],
'text': f"[{curr_speaker}]: " + curr_text.strip()}
)
curr_text = ''
total_word_segments_list += word_segments_list
print(f"[{format_timestamp(segment['start'])} --> {format_timestamp(segment['end'])}] {segment['text']}")
if "language" in segment:
aligned_segments[-1]["language"] = segment["language"]
print(f"[{format_timestamp(aligned_segments[-1]['start'])} --> {format_timestamp(aligned_segments[-1]['end'])}] {aligned_segments[-1]['text']}")
return {"segments": transcript, "word_segments": total_word_segments_list, "vad_segments": vad_segments_list}
char_level = {
"start": [],
"end": [],
"score": [],
"word-index": [],
}
word_level = {
"start": [],
"end": [],
"score": [],
"segment-text-start": [],
"segment-text-end": []
}
wdx = 0
cdx_prev = cdx + 2
sub_seg_idx += 1
seg_start_actual, seg_end_actual = None, None
# take min-max for actual segment-level timestamp
if seg_start_actual is None and start is not None:
seg_start_actual = start
if end is not None:
seg_end_actual = end
prev_t2 = segment["end"]
segment_align_success = True
# end while True loop
break
# reset prev_t2 due to drifting issues
if not segment_align_success:
prev_t2 = 0
# shift segment index by amount of sub-segments
if "seg-text" in segment:
sdx += len(segment["seg-text"])
else:
sdx += 1
# create word level segments for .srt
word_seg = []
for seg in aligned_segments:
if model_lang in LANGUAGES_WITHOUT_SPACES:
# character based
seg["word-segments"] = seg["char-segments"]
seg["word-segments"]["segment-text-start"] = range(len(seg['word-segments']['start']))
seg["word-segments"]["segment-text-end"] = range(1, len(seg['word-segments']['start'])+1)
wseg = pd.DataFrame(seg["word-segments"]).replace({np.nan:None})
for wdx, wrow in wseg.iterrows():
if wrow["start"] is not None:
word_seg.append(
{
"start": wrow["start"],
"end": wrow["end"],
"text": seg["text"][int(wrow["segment-text-start"]):int(wrow["segment-text-end"])]
}
)
return {"segments": aligned_segments, "word_segments": word_seg}
def load_align_model(language_code, device, model_name=None):
if model_name is None:
@ -492,11 +618,11 @@ def load_align_model(language_code, device, model_name=None):
return align_model, align_metadata
def merge_chunks(segments, chunk_size=CHUNK_LENGTH, speakers=False):
'''
Merge VAD segments into larger segments of size ~CHUNK_LENGTH.
'''
def merge_chunks(segments, chunk_size=CHUNK_LENGTH):
"""
Merge VAD segments into larger segments of size ~CHUNK_LENGTH.
"""
curr_start = 0
curr_end = 0
merged_segments = []
@ -508,7 +634,6 @@ def merge_chunks(segments, chunk_size=CHUNK_LENGTH, speakers=False):
"start": curr_start,
"end": curr_end,
"segments": seg_idxs,
"speakers": speaker_idxs,
})
curr_start = seg.start
seg_idxs = []
@ -521,55 +646,107 @@ def merge_chunks(segments, chunk_size=CHUNK_LENGTH, speakers=False):
"start": curr_start,
"end": curr_end,
"segments": seg_idxs,
"speakers": speaker_idxs
})
return merged_segments
def transcribe_segments(
def transcribe_with_vad(
model: "Whisper",
audio: Union[str, np.ndarray, torch.Tensor],
merged_segments,
vad_pipeline,
mel = None,
verbose: Optional[bool] = None,
**kwargs
):
'''
Transcribe according to predefined VAD segments.
'''
"""
Transcribe per VAD segment
"""
if mel is None:
mel = log_mel_spectrogram(audio)
prev = 0
output = {"segments": []}
output = {'segments': []}
vad_segments_list = []
vad_segments = vad_pipeline(audio)
for speech_turn in vad_segments.get_timeline().support():
vad_segments_list.append(Segment(speech_turn.start, speech_turn.end, "UNKNOWN"))
# merge segments to approx 30s inputs to make whisper most appropraite
vad_segments = merge_chunks(vad_segments_list)
for sdx, seg_t in enumerate(merged_segments):
print(sdx, seg_t['start'], seg_t['end'], '...')
seg_f_start, seg_f_end = int(seg_t['start'] * SAMPLE_RATE / HOP_LENGTH), int(seg_t['end'] * SAMPLE_RATE / HOP_LENGTH)
for sdx, seg_t in enumerate(vad_segments):
if verbose:
print(f"~~ Transcribing VAD chunk: ({format_timestamp(seg_t['start'])} --> {format_timestamp(seg_t['end'])}) ~~")
seg_f_start, seg_f_end = int(seg_t["start"] * SAMPLE_RATE / HOP_LENGTH), int(seg_t["end"] * SAMPLE_RATE / HOP_LENGTH)
local_f_start, local_f_end = seg_f_start - prev, seg_f_end - prev
mel = mel[:, local_f_start:] # seek forward
prev = seg_f_start
local_mel = mel[:, :local_f_end-local_f_start]
result = transcribe(model, audio, mel=local_mel, **kwargs)
seg_t['text'] = result['text']
output['segments'].append(
result = transcribe(model, audio, mel=local_mel, verbose=verbose, **kwargs)
seg_t["text"] = result["text"]
output["segments"].append(
{
'start': seg_t['start'],
'end': seg_t['end'],
'language': result['language'],
'text': result['text'],
'seg-text': [x['text'] for x in result['segments']],
'seg-start': [x['start'] for x in result['segments']],
'seg-end': [x['end'] for x in result['segments']],
"start": seg_t["start"],
"end": seg_t["end"],
"language": result["language"],
"text": result["text"],
"seg-text": [x["text"] for x in result["segments"]],
"seg-start": [x["start"] for x in result["segments"]],
"seg-end": [x["end"] for x in result["segments"]],
}
)
output['language'] = output['segments'][0]['language']
output["language"] = output["segments"][0]["language"]
return output
def assign_word_speakers(diarize_df, result_segments, fill_nearest=False):
for seg in result_segments:
wdf = pd.DataFrame(seg['word-segments'])
if len(wdf['start'].dropna()) == 0:
wdf['start'] = seg['start']
wdf['end'] = seg['end']
speakers = []
for wdx, wrow in wdf.iterrows():
diarize_df['intersection'] = np.minimum(diarize_df['end'], wrow['end']) - np.maximum(diarize_df['start'], wrow['start'])
diarize_df['union'] = np.maximum(diarize_df['end'], wrow['end']) - np.minimum(diarize_df['start'], wrow['start'])
# remove no hit
if not fill_nearest:
dia_tmp = diarize_df[diarize_df['intersection'] > 0]
else:
dia_tmp = diarize_df
if len(dia_tmp) == 0:
speaker = None
else:
speaker = dia_tmp.sort_values("intersection", ascending=False).iloc[0][2]
speakers.append(speaker)
seg['word-segments']['speaker'] = speakers
seg["speaker"] = pd.Series(speakers).value_counts().index[0]
# create word level segments for .srt
word_seg = []
for seg in result_segments:
wseg = pd.DataFrame(seg["word-segments"])
for wdx, wrow in wseg.iterrows():
if wrow["start"] is not None:
speaker = wrow['speaker']
if speaker is None or speaker == np.nan:
speaker = "UNKNOWN"
word_seg.append(
{
"start": wrow["start"],
"end": wrow["end"],
"text": f"[{speaker}]: " + seg["text"][int(wrow["segment-text-start"]):int(wrow["segment-text-end"])]
}
)
# TODO: create segments but split words on new speaker
return result_segments, word_seg
class Segment:
def __init__(self, start, end, speaker=None):
self.start = start
@ -589,11 +766,17 @@ def cli():
parser.add_argument("--align_model", default=None, help="Name of phoneme-level ASR model to do alignment")
parser.add_argument("--align_extend", default=2, type=float, help="Seconds before and after to extend the whisper segments for alignment")
parser.add_argument("--align_from_prev", default=True, type=bool, help="Whether to clip the alignment start time of current segment to the end time of the last aligned word of the previous segment")
parser.add_argument("--drop_non_aligned", action="store_true", help="For word .srt, whether to drop non aliged words, or merge them into neighbouring.")
parser.add_argument("--vad_filter", action="store_true", help="Whether to first perform VAD filtering to target only transcribe within VAD...")
parser.add_argument("--interpolate_method", default="nearest", choices=["nearest", "linear", "ignore"], help="For word .srt, method to assign timestamps to non-aligned words, or merge them into neighbouring.")
# vad params
parser.add_argument("--vad_filter", action="store_true", help="Whether to first perform VAD filtering to target only transcribe within VAD. Produces more accurate alignment + timestamp, requires more GPU memory & compute.")
parser.add_argument("--vad_input", default=None, type=str)
# diarization params
parser.add_argument("--diarize", action='store_true')
parser.add_argument("--min_speakers", default=None, type=int)
parser.add_argument("--max_speakers", default=None, type=int)
# output save params
parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
parser.add_argument("--output_type", default="srt", choices=['all', 'srt', 'vtt', 'txt'], help="File type for desired output save")
parser.add_argument("--output_type", default="all", choices=["all", "srt", "srt-word", "vtt", "txt", "tsv", "ass", "ass-char"], help="File type for desired output save")
parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")
@ -627,24 +810,32 @@ def cli():
align_model: str = args.pop("align_model")
align_extend: float = args.pop("align_extend")
align_from_prev: bool = args.pop("align_from_prev")
drop_non_aligned: bool = args.pop("drop_non_aligned")
interpolate_method: bool = args.pop("interpolate_method")
vad_filter: bool = args.pop("vad_filter")
vad_input: bool = args.pop("vad_input")
diarize: bool = args.pop("diarize")
min_speakers: int = args.pop("min_speakers")
max_speakers: int = args.pop("max_speakers")
vad_pipeline = None
if vad_input is not None:
vad_input = pd.read_csv(vad_input, header=None, sep= " ")
elif vad_filter:
from pyannote.audio import Pipeline
vad_pipeline = Pipeline.from_pretrained("pyannote/voice-activity-detection")
# vad_pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization@2.1")
diarize_pipeline = None
if diarize:
from pyannote.audio import Pipeline
diarize_pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization@2.1")
os.makedirs(output_dir, exist_ok=True)
if model_name.endswith(".en") and args["language"] not in {"en", "English"}:
if args["language"] is not None:
warnings.warn(f"{model_name} is an English-only model but receipted '{args['language']}'; using English instead.")
warnings.warn(f'{model_name} is an English-only model but receipted "{args["language"]}"; using English instead.')
args["language"] = "en"
temperature = args.pop("temperature")
@ -665,24 +856,10 @@ def cli():
align_model, align_metadata = load_align_model(align_language, device, model_name=align_model)
for audio_path in args.pop("audio"):
if vad_filter or vad_input is not None:
output_segments = []
if vad_filter:
print("Performing VAD...")
# vad_segments = vad_pipeline(audio_path)
# for speech_turn, track, speaker in vad_segments.itertracks(yield_label=True):
# output_segments.append(Segment(speech_turn.start, speech_turn.end, speaker))
vad_segments = vad_pipeline(audio_path)
for speech_turn in vad_segments.get_timeline().support():
output_segments.append(Segment(speech_turn.start, speech_turn.end, "UNKNOWN"))
elif vad_input is not None:
# rttm format
for idx, row in vad_input.iterrows():
output_segments.append(Segment(row[3], row[3]+row[4], f"SPEAKER {row[7]}"))
vad_segments = merge_chunks(output_segments)
result = transcribe_segments(model, audio_path, merged_segments=vad_segments, temperature=temperature, **args)
if vad_filter:
print("Performing VAD...")
result = transcribe_with_vad(model, audio_path, vad_pipeline, temperature=temperature, **args)
else:
vad_segments = None
print("Performing transcription...")
result = transcribe(model, audio_path, temperature=temperature, **args)
@ -693,9 +870,20 @@ def cli():
print("Performing alignment...")
result_aligned = align(result["segments"], align_model, align_metadata, audio_path, device,
extend_duration=align_extend, start_from_previous=align_from_prev, drop_non_aligned_words=drop_non_aligned)
extend_duration=align_extend, start_from_previous=align_from_prev, interpolate_method=interpolate_method)
audio_basename = os.path.basename(audio_path)
if diarize:
print("Performing diarization...")
diarize_segments = diarize_pipeline(audio_path, min_speakers=min_speakers, max_speakers=max_speakers)
diarize_df = pd.DataFrame(diarize_segments.itertracks(yield_label=True))
diarize_df['start'] = diarize_df[0].apply(lambda x: x.start)
diarize_df['end'] = diarize_df[0].apply(lambda x: x.end)
# assumes each utterance is single speaker (needs fix)
result_segments, word_segments = assign_word_speakers(diarize_df, result_aligned["segments"], fill_nearest=True)
result_aligned["segments"] = result_segments
result_aligned["word_segments"] = word_segments
# save TXT
if output_type in ["txt", "all"]:
with open(os.path.join(output_dir, audio_basename + ".txt"), "w", encoding="utf-8") as txt:
@ -711,19 +899,27 @@ def cli():
with open(os.path.join(output_dir, audio_basename + ".srt"), "w", encoding="utf-8") as srt:
write_srt(result_aligned["segments"], file=srt)
# save per-word SRT
with open(os.path.join(output_dir, audio_basename + ".word.srt"), "w", encoding="utf-8") as srt:
write_srt(result_aligned["word_segments"], file=srt)
# save TSV
if output_type in ["tsv", "all"]:
with open(os.path.join(output_dir, audio_basename + ".srt"), "w", encoding="utf-8") as srt:
write_tsv(result_aligned["segments"], file=srt)
# save SRT word-level
if output_type in ["srt-word", "all"]:
# save per-word SRT
with open(os.path.join(output_dir, audio_basename + ".word.srt"), "w", encoding="utf-8") as srt:
write_srt(result_aligned["word_segments"], file=srt)
# save ASS
with open(os.path.join(output_dir, audio_basename + ".ass"), "w", encoding="utf-8") as ass:
write_ass(result_aligned["segments"], file=ass)
if output_type in ["ass", "all"]:
with open(os.path.join(output_dir, audio_basename + ".ass"), "w", encoding="utf-8") as ass:
write_ass(result_aligned["segments"], file=ass)
if vad_filter is not None:
# save per-word SRT
with open(os.path.join(output_dir, audio_basename + ".vad.srt"), "w", encoding="utf-8") as srt:
write_srt(result_aligned["vad_segments"], file=srt)
# save ASS character-level
if output_type in ["ass-char", "all"]:
with open(os.path.join(output_dir, audio_basename + ".char.ass"), "w", encoding="utf-8") as ass:
write_ass(result_aligned["segments"], file=ass, resolution="char")
if __name__ == '__main__':
if __name__ == "__main__":
cli()