mirror of https://github.com/m-bain/whisperX.git
synced 2025-07-01 18:17:27 -04:00

Commit: update readme
@@ -585,10 +585,9 @@ def cli():
 parser.add_argument("--interpolate_method", default="nearest", choices=["nearest", "linear", "ignore"], help="For word .srt, method to assign timestamps to non-aligned words, or merge them into neighbouring.")
 # vad params
 parser.add_argument("--vad_filter", action="store_true", help="Whether to first perform VAD filtering to target only transcribe within VAD. Produces more accurate alignment + timestamp, requires more GPU memory & compute.")
 parser.add_argument("--vad_input", default=None, type=str)
 parser.add_argument("--parallel_bs", default=-1, type=int, help="Enable parallel transcribing if > 1")
 # diarization params
-parser.add_argument("--diarize", action='store_true')
+parser.add_argument("--diarize", action="store_true", help="Apply diarization to assign speaker labels to each segment/word")
 parser.add_argument("--min_speakers", default=None, type=int)
 parser.add_argument("--max_speakers", default=None, type=int)
 # output save params
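As a quick illustration of how these flags behave once parsed, here is a minimal self-contained sketch (not the repo's full CLI; the sample argv is hypothetical):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--vad_filter", action="store_true")
parser.add_argument("--diarize", action="store_true", help="Apply diarization to assign speaker labels to each segment/word")
parser.add_argument("--min_speakers", default=None, type=int)
parser.add_argument("--max_speakers", default=None, type=int)

# hypothetical command line, just to show the parsed shapes
args = vars(parser.parse_args(["--diarize", "--min_speakers", "1", "--max_speakers", "3"]))
assert args["diarize"] is True and args["min_speakers"] == 1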
@@ -632,7 +631,6 @@ def cli():
 hf_token: str = args.pop("hf_token")
 vad_filter: bool = args.pop("vad_filter")
 vad_input: bool = args.pop("vad_input")
 parallel_bs: int = args.pop("parallel_bs")

 diarize: bool = args.pop("diarize")
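The pattern here pops CLI-only options out of the parsed-args dict so the remainder can be forwarded straight to the transcription call; a toy sketch with assumed dict contents:

args = {"hf_token": None, "vad_filter": True, "diarize": False, "language": "en"}
hf_token = args.pop("hf_token")
vad_filter: bool = args.pop("vad_filter")
diarize: bool = args.pop("diarize")
# args now holds only transcription kwargs: {"language": "en"}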
@@ -640,9 +638,9 @@ def cli():
 max_speakers: int = args.pop("max_speakers")

 vad_pipeline = None
-if vad_input is not None:
-    vad_input = pd.read_csv(vad_input, header=None, sep= " ")
-elif vad_filter:
+if vad_filter:
     if hf_token is None:
         print("Warning, no huggingface token used, needs to be saved in environment variable, otherwise will throw error loading VAD model...")
     from pyannote.audio import Inference
     vad_pipeline = Inference("pyannote/segmentation",
                              pre_aggregation_hook=lambda segmentation: segmentation,
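For context, a sketch of how this segmentation model is typically loaded and run (assumes pyannote.audio 2.x, a valid Hugging Face token, and a hypothetical audio path):

from pyannote.audio import Inference

vad_pipeline = Inference(
    "pyannote/segmentation",
    pre_aggregation_hook=lambda segmentation: segmentation,  # keep raw frame-level scores
    use_auth_token="hf_...",  # hypothetical token
)
scores = vad_pipeline("audio.wav")  # per-frame speech activity scores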
@@ -650,6 +648,8 @@ def cli():

 diarize_pipeline = None
 if diarize:
     if hf_token is None:
         print("Warning, no --hf_token used, needs to be saved in environment variable, otherwise will throw error loading diarization model...")
     from pyannote.audio import Pipeline
     diarize_pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization@2.1",
                                                 use_auth_token=hf_token)
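Once loaded, the pipeline can be applied directly to an audio file with the speaker bounds passed through; a sketch assuming pyannote.audio 2.x and hypothetical token/file names:

from pyannote.audio import Pipeline

diarize_pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization@2.1",
                                            use_auth_token="hf_...")  # hypothetical token
diarization = diarize_pipeline("audio.wav", min_speakers=1, max_speakers=3)
for turn, _, speaker in diarization.itertracks(yield_label=True):
    print(f"{turn.start:.2f} {turn.end:.2f} {speaker}")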
@@ -756,7 +756,7 @@ def cli():
 # save word tsv
 if output_type in ["vad"]:
     exp_fp = os.path.join(output_dir, audio_basename + ".sad")
-    wrd_segs = pd.concat([x["word-segments"] for x in result_aligned["segments"]])
+    wrd_segs = pd.concat([x["word-segments"] for x in result_aligned["segments"]])[['start','end']]
     wrd_segs.to_csv(exp_fp, sep='\t', header=None, index=False)

 if __name__ == "__main__":
     cli()
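The change above restricts the exported columns to start/end. A self-contained sketch of that export (the result_aligned structure is assumed from context, and the output path is hypothetical):

import pandas as pd

result_aligned = {"segments": [
    {"word-segments": pd.DataFrame({"start": [0.0, 0.5], "end": [0.4, 0.9]})},
    {"word-segments": pd.DataFrame({"start": [1.0], "end": [1.3]})},
]}
wrd_segs = pd.concat([x["word-segments"] for x in result_aligned["segments"]])[['start', 'end']]
wrd_segs.to_csv("audio.sad", sep='\t', header=None, index=False)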
@@ -65,8 +65,8 @@ def write_vtt(transcript: Iterator[dict], file: TextIO):
 def write_tsv(transcript: Iterator[dict], file: TextIO):
     print("start", "end", "text", sep="\t", file=file)
     for segment in transcript:
-        print(round(1000 * segment['start']), file=file, end="\t")
-        print(round(1000 * segment['end']), file=file, end="\t")
+        print(segment['start'], file=file, end="\t")
+        print(segment['end'], file=file, end="\t")
         print(segment['text'].strip().replace("\t", " "), file=file, flush=True)
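A quick usage sketch for write_tsv as defined above (toy transcript; output goes to an in-memory buffer):

import io

transcript = [{"start": 0.0, "end": 1.2, "text": " Hello world. "}]
buf = io.StringIO()
write_tsv(transcript, file=buf)
print(buf.getvalue())  # header row, then: 0.0<TAB>1.2<TAB>Hello world.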
@@ -137,8 +137,6 @@ class Binarize:

 def merge_vad(vad_arr, pad_onset=0.0, pad_offset=0.0, min_duration_off=0.0, min_duration_on=0.0):
     # because of padding, some active regions might be overlapping: merge them.
     # also: fill same speaker gaps shorter than min_duration_off

     active = Annotation()
     for k, vad_t in enumerate(vad_arr):
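The merging itself can be expressed with pyannote.core primitives; a minimal sketch with a toy vad_arr, where support() stands in for the padding/merge logic the hunk truncates:

from pyannote.core import Annotation, Segment

vad_arr = [(0.0, 1.0), (0.9, 2.0), (2.4, 3.0)]  # overlapping and nearby regions
active = Annotation()
for k, (start, end) in enumerate(vad_arr):
    active[Segment(start, end), k] = "speech"
merged = active.support(collar=0.5)  # merges overlaps and gaps shorter than collar
print(list(merged.get_timeline()))   # a single merged region spanning 0.0-3.0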
@@ -161,16 +159,27 @@ def merge_vad(vad_arr, pad_onset=0.0, pad_offset=0.0, min_duration_off=0.0, min_duration_on=0.0):
 if __name__ == "__main__":
-    from pyannote.audio import Inference
-    hook = lambda segmentation: segmentation
-    inference = Inference("pyannote/segmentation", pre_aggregation_hook=hook)
-    audio = "/tmp/11962.wav"
-    scores = inference(audio)
-    binarize = Binarize(max_duration=15)
-    anno = binarize(scores)
-    res = []
-    for ann in anno.get_timeline():
-        res.append((ann.start, ann.end))
+    # from pyannote.audio import Inference
+    # hook = lambda segmentation: segmentation
+    # inference = Inference("pyannote/segmentation", pre_aggregation_hook=hook)
+    # audio = "/tmp/11962.wav"
+    # scores = inference(audio)
+    # binarize = Binarize(max_duration=15)
+    # anno = binarize(scores)
+    # res = []
+    # for ann in anno.get_timeline():
+    #     res.append((ann.start, ann.end))

-    res = pd.DataFrame(res)
-    res[2] = res[1] - res[0]
+    # res = pd.DataFrame(res)
+    # res[2] = res[1] - res[0]
+    import pandas as pd
+    input_fp = "tt298650_sync.wav"
+    df = pd.read_csv(f"/work/maxbain/tmp/{input_fp}.sad", sep=" ", header=None)
+    print(len(df))
+    N = 0.15
+    g = df[0].sub(df[1].shift())
+    input_base = input_fp.split('.')[0]
+    df = df.groupby(g.gt(N).cumsum()).agg({0:'min', 1:'max'})
+    df.to_csv(f"/work/maxbain/tmp/{input_base}.lab", header=None, index=False, sep=" ")
+    print(df)
+    import pdb; pdb.set_trace()
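The groupby/cumsum line is the core trick in the added code: a new group starts whenever the gap from the previous segment's end exceeds N, and agg takes each group's overall span. A toy demonstration with assumed start/end rows:

import pandas as pd

df = pd.DataFrame([(0.0, 1.0), (1.05, 2.0), (3.0, 4.0)])  # columns 0=start, 1=end
N = 0.15
g = df[0].sub(df[1].shift())               # gap between this start and previous end
merged = df.groupby(g.gt(N).cumsum()).agg({0: 'min', 1: 'max'})
print(merged)  # first two rows merge (gap 0.05 <= N); the third stays separate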