Merge branch 'main' into cuda-11.8

Max Bain, 2023-07-25 00:28:53 +01:00, committed by GitHub
8 changed files with 135 additions and 53 deletions

README.md

@@ -54,6 +54,7 @@ This repository provides fast automatic speech recognition (70x realtime with la
 <h2 align="left", id="highlights">New🚨</h2>
+- _WhisperX_ accepted at INTERSPEECH 2023
 - v3 transcript segment-per-sentence: using nltk sent_tokenize for better subtitlting & better diarization
 - v3 released, 70x speed-up open-sourced. Using batched whisper with [faster-whisper](https://github.com/guillaumekln/faster-whisper) backend!
 - v2 released, code cleanup, imports whisper library VAD filtering is now turned on by default, as in the paper.
@@ -74,7 +75,7 @@ GPU execution requires the NVIDIA libraries cuBLAS 11.x and cuDNN 8.x to be inst
 ### 2. Install PyTorch2.0, e.g. for Linux and Windows CUDA11.7:
-`conda install pytorch==2.0.0 torchvision==0.15.0 torchaudio==2.0.0 pytorch-cuda=11.8 -c pytorch -c nvidia`
+`conda install pytorch==2.0.0 torchaudio==2.0.0 pytorch-cuda=11.8 -c pytorch -c nvidia`
 See other methods [here.](https://pytorch.org/get-started/previous-versions/#v200)
@@ -184,6 +185,11 @@ print(diarize_segments)
 print(result["segments"]) # segments are now assigned speaker IDs
 ```
+## Demos 🚀
+[![Replicate](https://replicate.com/daanelson/whisperx/badge)](https://replicate.com/daanelson/whisperx)
+If you don't have access to your own GPUs, use the link above to try out WhisperX.
 <h2 align="left" id="whisper-mod">Technical Details 👷‍♂️</h2>
@@ -276,7 +282,7 @@ If you use this in your research, please cite the paper:
 @article{bain2022whisperx,
 title={WhisperX: Time-Accurate Speech Transcription of Long-Form Audio},
 author={Bain, Max and Huh, Jaesung and Han, Tengda and Zisserman, Andrew},
-journal={arXiv preprint, arXiv:2303.00747},
+journal={INTERSPEECH 2023},
 year={2023}
 }
 ```

requirements.txt

@@ -1,8 +1,8 @@
-torch==2.0.0
-torchaudio==2.0.1
+torch>=2
+torchaudio>=2
 faster-whisper
 transformers
-ffmpeg-python==0.2.0
+ffmpeg-python>=0.2
 pandas
-setuptools==65.6.3
+setuptools>=65
 nltk

whisperx/alignment.py

@@ -15,6 +15,9 @@ from .audio import SAMPLE_RATE, load_audio
 from .utils import interpolate_nans
 from .types import AlignedTranscriptionResult, SingleSegment, SingleAlignedSegment, SingleWordSegment
 import nltk
+from nltk.tokenize.punkt import PunktSentenceTokenizer, PunktParameters
+PUNKT_ABBREVIATIONS = ['dr', 'vs', 'mr', 'mrs', 'prof']
 LANGUAGES_WITHOUT_SPACES = ["ja", "zh"]
@@ -33,6 +36,7 @@ DEFAULT_ALIGN_MODELS_HF = {
 "uk": "Yehor/wav2vec2-xls-r-300m-uk-with-small-lm",
 "pt": "jonatasgrosman/wav2vec2-large-xlsr-53-portuguese",
 "ar": "jonatasgrosman/wav2vec2-large-xlsr-53-arabic",
+"cs": "comodoro/wav2vec2-xls-r-300m-cs-250",
 "ru": "jonatasgrosman/wav2vec2-large-xlsr-53-russian",
 "pl": "jonatasgrosman/wav2vec2-large-xlsr-53-polish",
 "hu": "jonatasgrosman/wav2vec2-large-xlsr-53-hungarian",
@@ -42,6 +46,9 @@ DEFAULT_ALIGN_MODELS_HF = {
 "tr": "mpoyraz/wav2vec2-xls-r-300m-cv7-turkish",
 "da": "saattrupdan/wav2vec2-xls-r-300m-ftspeech",
 "he": "imvladikon/wav2vec2-xls-r-300m-hebrew",
+"vi": 'nguyenvulebinh/wav2vec2-base-vi',
+"ko": "kresnik/wav2vec2-large-xlsr-korean",
+"ur": "kingabzpro/wav2vec2-large-xls-r-300m-Urdu",
 }
@@ -141,7 +148,11 @@ def align(
 if any([c in model_dictionary.keys() for c in wrd]):
 clean_wdx.append(wdx)
-sentence_spans = list(nltk.tokenize.punkt.PunktSentenceTokenizer().span_tokenize(text))
+punkt_param = PunktParameters()
+punkt_param.abbrev_types = set(PUNKT_ABBREVIATIONS)
+sentence_splitter = PunktSentenceTokenizer(punkt_param)
+sentence_spans = list(sentence_splitter.span_tokenize(text))
 segment["clean_char"] = clean_char
 segment["clean_cdx"] = clean_cdx
@@ -300,6 +311,8 @@ def align(
 aligned_subsegments["end"] = interpolate_nans(aligned_subsegments["end"], method=interpolate_method)
 # concatenate sentences with same timestamps
 agg_dict = {"text": " ".join, "words": "sum"}
+if model_lang in LANGUAGES_WITHOUT_SPACES:
+agg_dict["text"] = "".join
 if return_char_alignments:
 agg_dict["chars"] = "sum"
 aligned_subsegments= aligned_subsegments.groupby(["start", "end"], as_index=False).agg(agg_dict)
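
The Punkt changes above replace the stock sentence splitter with one that knows a small abbreviation list, so titles such as "Dr." no longer end a sentence, and for Japanese and Chinese (`LANGUAGES_WITHOUT_SPACES`) merged sentences are joined without inserting spaces. A minimal sketch of the splitter behaviour, using the abbreviation list from the diff and an invented example sentence:

```python
from nltk.tokenize.punkt import PunktSentenceTokenizer, PunktParameters

# Same abbreviation list as in the diff (lowercase, no trailing period).
punkt_param = PunktParameters()
punkt_param.abbrev_types = set(["dr", "vs", "mr", "mrs", "prof"])
splitter = PunktSentenceTokenizer(punkt_param)

text = "Dr. Smith arrived late. The meeting had already started."
for start, end in splitter.span_tokenize(text):
    print(text[start:end])
# Expected: two sentences; without the abbreviation entry Punkt would
# typically also break after "Dr.".
```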

whisperx/asr.py

@@ -13,8 +13,25 @@ from .audio import N_SAMPLES, SAMPLE_RATE, load_audio, log_mel_spectrogram
 from .vad import load_vad_model, merge_chunks
 from .types import TranscriptionResult, SingleSegment
-def load_model(whisper_arch, device, compute_type="float16", asr_options=None, language=None,
-vad_options=None, model=None, task="transcribe"):
+def find_numeral_symbol_tokens(tokenizer):
+numeral_symbol_tokens = []
+for i in range(tokenizer.eot):
+token = tokenizer.decode([i]).removeprefix(" ")
+has_numeral_symbol = any(c in "0123456789%" for c in token)
+if has_numeral_symbol:
+numeral_symbol_tokens.append(i)
+return numeral_symbol_tokens
+def load_model(whisper_arch,
+device,
+device_index=0,
+compute_type="float16",
+asr_options=None,
+language=None,
+vad_options=None,
+model=None,
+task="transcribe",
+download_root=None):
 '''Load a Whisper model for inference.
 Args:
 whisper_arch: str - The name of the Whisper model to load.
@@ -22,14 +39,19 @@ def load_model(whisper_arch, device, compute_type="float16", asr_options=None, l
 compute_type: str - The compute type to use for the model.
 options: dict - A dictionary of options to use for the model.
 language: str - The language of the model. (use English for now)
+download_root: Optional[str] - The root directory to download the model to.
 Returns:
 A Whisper pipeline.
 '''
 if whisper_arch.endswith(".en"):
 language = "en"
-model = WhisperModel(whisper_arch, device=device, compute_type=compute_type)
+model = WhisperModel(whisper_arch,
+device=device,
+device_index=device_index,
+compute_type=compute_type,
+download_root=download_root)
 if language is not None:
 tokenizer = faster_whisper.tokenizer.Tokenizer(model.hf_tokenizer, model.model.is_multilingual, task=task, language=language)
 else:
@@ -54,11 +76,22 @@ def load_model(whisper_arch, device, compute_type="float16", asr_options=None, l
 "max_initial_timestamp": 0.0,
 "word_timestamps": False,
 "prepend_punctuations": "\"'“¿([{-",
-"append_punctuations": "\"'.。,!?::”)]}、"
+"append_punctuations": "\"'.。,!?::”)]}、",
+"suppress_numerals": False,
 }
 if asr_options is not None:
 default_asr_options.update(asr_options)
+if default_asr_options["suppress_numerals"]:
+if tokenizer is None:
+tokenizer = faster_whisper.tokenizer.Tokenizer(model.hf_tokenizer, model.model.is_multilingual, task=task, language="en")
+numeral_symbol_tokens = find_numeral_symbol_tokens(tokenizer)
+print(f"Suppressing numeral and symbol tokens: {numeral_symbol_tokens}")
+default_asr_options["suppress_tokens"] += numeral_symbol_tokens
+default_asr_options["suppress_tokens"] = list(set(default_asr_options["suppress_tokens"]))
+del default_asr_options["suppress_numerals"]
 default_asr_options = faster_whisper.transcribe.TranscriptionOptions(**default_asr_options)
 default_vad_options = {
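
A minimal sketch of what `suppress_numerals` does: every token whose text contains a digit or `%` is added to `suppress_tokens`, so the decoder can never emit it and numbers come out as words that wav2vec2 can align. The toy vocabulary below is invented; the real code walks the faster-whisper tokenizer with `tokenizer.decode([i])` for `i` up to `tokenizer.eot`:

```python
# Hypothetical stand-in for the tokenizer vocabulary: token id -> token text.
toy_vocab = {0: " the", 1: " 42", 2: "%", 3: " cat", 4: "7th"}

# Same test as in find_numeral_symbol_tokens: keep ids whose text contains
# a digit or a percent sign.
numeral_symbol_tokens = [
    i for i, tok in toy_vocab.items()
    if any(c in "0123456789%" for c in tok.removeprefix(" "))
]

suppress_tokens = [-1]  # the CLI default for --suppress_tokens
suppress_tokens = sorted(set(suppress_tokens + numeral_symbol_tokens))
print(suppress_tokens)  # [-1, 1, 2, 4]: these ids are never sampled
```
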
@@ -106,15 +139,12 @@ class WhisperModel(faster_whisper.WhisperModel):
 result = self.model.generate(
 encoder_output,
 [prompt] * batch_size,
-# length_penalty=options.length_penalty,
-# max_length=self.max_length,
-# return_scores=True,
-# return_no_speech_prob=True,
-# suppress_blank=options.suppress_blank,
-# suppress_tokens=options.suppress_tokens,
-# max_initial_timestamp_index=max_initial_timestamp_index,
+length_penalty=options.length_penalty,
+max_length=self.max_length,
+suppress_blank=options.suppress_blank,
+suppress_tokens=options.suppress_tokens,
 )
 tokens_batch = [x.sequences_ids[0] for x in result]
 def decode_batch(tokens: List[List[int]]) -> str:
@@ -127,7 +157,7 @@ class WhisperModel(faster_whisper.WhisperModel):
 text = decode_batch(tokens_batch)
 return text
 def encode(self, features: np.ndarray) -> ctranslate2.StorageView:
 # When the model is running on multiple GPUs, the encoder output should be moved
 # to the CPU since we don't know which GPU will handle the next job.
@@ -136,9 +166,9 @@ class WhisperModel(faster_whisper.WhisperModel):
 if len(features.shape) == 2:
 features = np.expand_dims(features, 0)
 features = faster_whisper.transcribe.get_ctranslate2_storage(features)
 return self.model.encode(features, to_cpu=to_cpu)
 class FasterWhisperPipeline(Pipeline):
 """
 Huggingface Pipeline wrapper for FasterWhisperModel.
@@ -176,7 +206,7 @@ class FasterWhisperPipeline(Pipeline):
 self.device = torch.device(f"cuda:{device}")
 else:
 self.device = device
 super(Pipeline, self).__init__()
 self.vad_model = vad
@@ -194,7 +224,7 @@ class FasterWhisperPipeline(Pipeline):
 def _forward(self, model_inputs):
 outputs = self.model.generate_segment_batched(model_inputs['inputs'], self.tokenizer, self.options)
 return {'text': outputs}
 def postprocess(self, model_outputs):
 return model_outputs
@@ -214,11 +244,11 @@ class FasterWhisperPipeline(Pipeline):
 return final_iterator
 def transcribe(
-self, audio: Union[str, np.ndarray], batch_size=None, num_workers=0
+self, audio: Union[str, np.ndarray], batch_size=None, num_workers=0, language=None, task=None
 ) -> TranscriptionResult:
 if isinstance(audio, str):
 audio = load_audio(audio)
 def data(audio, segments):
 for seg in segments:
 f1 = int(seg['start'] * SAMPLE_RATE)
@@ -228,14 +258,19 @@ class FasterWhisperPipeline(Pipeline):
 vad_segments = self.vad_model({"waveform": torch.from_numpy(audio).unsqueeze(0), "sample_rate": SAMPLE_RATE})
 vad_segments = merge_chunks(vad_segments, 30)
-del_tokenizer = False
 if self.tokenizer is None:
-language = self.detect_language(audio)
-self.tokenizer = faster_whisper.tokenizer.Tokenizer(self.model.hf_tokenizer, self.model.model.is_multilingual, task="transcribe", language=language)
-del_tokenizer = True
+language = language or self.detect_language(audio)
+task = task or "transcribe"
+self.tokenizer = faster_whisper.tokenizer.Tokenizer(self.model.hf_tokenizer,
+self.model.model.is_multilingual, task=task,
+language=language)
 else:
-language = self.tokenizer.language_code
+language = language or self.tokenizer.language_code
+task = task or self.tokenizer.task
+if task != self.tokenizer.task or language != self.tokenizer.language_code:
+self.tokenizer = faster_whisper.tokenizer.Tokenizer(self.model.hf_tokenizer,
+self.model.model.is_multilingual, task=task,
+language=language)
 segments: List[SingleSegment] = []
 batch_size = batch_size or self._batch_size
@@ -250,9 +285,6 @@ class FasterWhisperPipeline(Pipeline):
 "end": round(vad_segments[idx]['end'], 3)
 }
 )
-if del_tokenizer:
-self.tokenizer = None
 return {"segments": segments, "language": language}

whisperx/transcribe.py

@@ -21,11 +21,12 @@ def cli():
 parser.add_argument("--model", default="small", help="name of the Whisper model to use")
 parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
 parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
-parser.add_argument("--batch_size", default=8, type=int, help="device to use for PyTorch inference")
+parser.add_argument("--device_index", default=0, type=int, help="device index to use for FasterWhisper inference")
+parser.add_argument("--batch_size", default=8, type=int, help="the preferred batch size for inference")
 parser.add_argument("--compute_type", default="float16", type=str, choices=["float16", "float32", "int8"], help="compute type for computation")
 parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
-parser.add_argument("--output_format", "-f", type=str, default="all", choices=["all", "srt", "vtt", "txt", "tsv", "json"], help="format of the output file; if not specified, all available formats will be produced")
+parser.add_argument("--output_format", "-f", type=str, default="all", choices=["all", "srt", "vtt", "txt", "tsv", "json", "aud"], help="format of the output file; if not specified, all available formats will be produced")
 parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")
 parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')")
@@ -50,9 +51,11 @@ def cli():
 parser.add_argument("--best_of", type=optional_int, default=5, help="number of candidates when sampling with non-zero temperature")
 parser.add_argument("--beam_size", type=optional_int, default=5, help="number of beams in beam search, only applicable when temperature is zero")
 parser.add_argument("--patience", type=float, default=None, help="optional patience value to use in beam decoding, as in https://arxiv.org/abs/2204.05424, the default (1.0) is equivalent to conventional beam search")
-parser.add_argument("--length_penalty", type=float, default=None, help="optional token length penalty coefficient (alpha) as in https://arxiv.org/abs/1609.08144, uses simple length normalization by default")
+parser.add_argument("--length_penalty", type=float, default=1.0, help="optional token length penalty coefficient (alpha) as in https://arxiv.org/abs/1609.08144, uses simple length normalization by default")
 parser.add_argument("--suppress_tokens", type=str, default="-1", help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations")
+parser.add_argument("--suppress_numerals", action="store_true", help="whether to suppress numeric symbols and currency symbols during sampling, since wav2vec2 cannot align them correctly")
 parser.add_argument("--initial_prompt", type=str, default=None, help="optional text to provide as a prompt for the first window.")
 parser.add_argument("--condition_on_previous_text", type=str2bool, default=False, help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop")
 parser.add_argument("--fp16", type=str2bool, default=True, help="whether to perform inference in fp16; True by default")
@@ -78,6 +81,7 @@ def cli():
 output_dir: str = args.pop("output_dir")
 output_format: str = args.pop("output_format")
 device: str = args.pop("device")
+device_index: int = args.pop("device_index")
 compute_type: str = args.pop("compute_type")
 # model_flush: bool = args.pop("model_flush")
@@ -128,6 +132,8 @@ def cli():
 "no_speech_threshold": args.pop("no_speech_threshold"),
 "condition_on_previous_text": False,
 "initial_prompt": args.pop("initial_prompt"),
+"suppress_tokens": [int(x) for x in args.pop("suppress_tokens").split(",")],
+"suppress_numerals": args.pop("suppress_numerals"),
 }
 writer = get_writer(output_format, output_dir)
@@ -144,7 +150,7 @@ def cli():
 results = []
 tmp_results = []
 # model = load_model(model_name, device=device, download_root=model_dir)
-model = load_model(model_name, device=device, compute_type=compute_type, language=args['language'], asr_options=asr_options, vad_options={"vad_onset": vad_onset, "vad_offset": vad_offset}, task=task)
+model = load_model(model_name, device=device, device_index=device_index, compute_type=compute_type, language=args['language'], asr_options=asr_options, vad_options={"vad_onset": vad_onset, "vad_offset": vad_offset}, task=task)
 for audio_path in args.pop("audio"):
 audio = load_audio(audio_path)
@@ -204,4 +210,4 @@ def cli():
 writer(result, audio_path, writer_args)
 if __name__ == "__main__":
 cli()
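
`--device_index` is passed through to CTranslate2 and selects which GPU the model is placed on, so multi-GPU machines can pin different runs to different cards (for example `whisperx audio.wav --device_index 1` on the command line). An equivalent Python sketch, with the model name and index as examples only:

```python
import whisperx

# Place the model on the second GPU (cuda:1) instead of the default cuda:0.
model = whisperx.load_model(
    "small",
    device="cuda",
    device_index=1,
    compute_type="float16",
)
```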

whisperx/types.py

@@ -1,4 +1,4 @@
-from typing import TypedDict, Optional
+from typing import TypedDict, Optional, List
 class SingleWordSegment(TypedDict):
@@ -38,15 +38,15 @@ class SingleAlignedSegment(TypedDict):
 start: float
 end: float
 text: str
-words: list[SingleWordSegment]
-chars: Optional[list[SingleCharSegment]]
+words: List[SingleWordSegment]
+chars: Optional[List[SingleCharSegment]]
 class TranscriptionResult(TypedDict):
 """
 A list of segments and word segments of a speech.
 """
-segments: list[SingleSegment]
+segments: List[SingleSegment]
 language: str
@@ -54,5 +54,5 @@ class AlignedTranscriptionResult(TypedDict):
 """
 A list of segments and word segments of a speech.
 """
-segments: list[SingleAlignedSegment]
-word_segments: list[SingleWordSegment]
+segments: List[SingleAlignedSegment]
+word_segments: List[SingleWordSegment]
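
The move from built-in `list[...]` to `typing.List[...]` restores compatibility with Python 3.8, where subscripting `list` in an annotation raises a `TypeError` at class-definition time. A small illustration (the class is hypothetical):

```python
from typing import List, TypedDict

class Example(TypedDict):
    # Fine on Python 3.8; writing `values: list[float]` here would raise
    # "TypeError: 'type' object is not subscriptable" on 3.8 and earlier.
    values: List[float]
```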

whisperx/utils.py

@@ -365,6 +365,28 @@ class WriteTSV(ResultWriter):
 print(round(1000 * segment["end"]), file=file, end="\t")
 print(segment["text"].strip().replace("\t", " "), file=file, flush=True)
+class WriteAudacity(ResultWriter):
+"""
+Write a transcript to a text file that audacity can import as labels.
+The extension used is "aud" to distinguish it from the txt file produced by WriteTXT.
+Yet this is not an audacity project but only a label file!
+Please note : Audacity uses seconds in timestamps not ms!
+Also there is no header expected.
+If speaker is provided it is prepended to the text between double square brackets [[]].
+"""
+extension: str = "aud"
+def write_result(self, result: dict, file: TextIO, options: dict):
+ARROW = " "
+for segment in result["segments"]:
+print(segment["start"], file=file, end=ARROW)
+print(segment["end"], file=file, end=ARROW)
+print( ( ("[[" + segment["speaker"] + "]]") if "speaker" in segment else "") + segment["text"].strip().replace("\t", " "), file=file, flush=True)
 class WriteJSON(ResultWriter):
 extension: str = "json"
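
For reference, Audacity's label import format is one row per segment with start time, end time, and label text, separated by tabs and with times in seconds; when a speaker is present the label carries the `[[SPEAKER]]` prefix shown above. A sketch of the expected output for an invented result (select this writer with `--output_format aud` on the CLI):

```python
# Toy result in the shape the writer receives; values are invented.
result = {"segments": [
    {"start": 0.0, "end": 2.5, "text": " Hello there.", "speaker": "SPEAKER_00"},
    {"start": 2.5, "end": 4.1, "text": " How are you?"},
]}

with open("sample.aud", "w") as f:
    for seg in result["segments"]:
        label = (f"[[{seg['speaker']}]]" if "speaker" in seg else "") + seg["text"].strip()
        print(seg["start"], seg["end"], label, sep="\t", file=f)

# sample.aud then contains (tab-separated):
# 0.0   2.5   [[SPEAKER_00]]Hello there.
# 2.5   4.1   How are you?
```
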
@@ -383,6 +405,9 @@ def get_writer(
 "tsv": WriteTSV,
 "json": WriteJSON,
 }
+optional_writers = {
+"aud": WriteAudacity,
+}
 if output_format == "all":
 all_writers = [writer(output_dir) for writer in writers.values()]
@@ -393,10 +418,12 @@
 return write_all
+if output_format in optional_writers:
+return optional_writers[output_format](output_dir)
 return writers[output_format](output_dir)
 def interpolate_nans(x, method='nearest'):
 if x.notnull().sum() > 1:
 return x.interpolate(method=method).ffill().bfill()
 else:
 return x.ffill().bfill()

whisperx/vad.py

@@ -147,8 +147,6 @@ class Binarize:
 if is_active:
 curr_duration = t - start
 if curr_duration > self.max_duration:
-# if curr_duration > 15:
-# import pdb; pdb.set_trace()
 search_after = len(curr_scores) // 2
 # divide segment
 min_score_div_idx = search_after + np.argmin(curr_scores[search_after:])
@@ -166,14 +164,14 @@ class Binarize:
 is_active = False
 curr_scores = []
 curr_timestamps = []
-curr_scores.append(y)
-curr_timestamps.append(t)
 # currently inactive
 else:
 # switching from inactive to active
 if y > self.onset:
 start = t
 is_active = True
+curr_scores.append(y)
+curr_timestamps.append(t)
 # if active at the end, add final region
 if is_active: