mirror of
https://github.com/m-bain/whisperX.git
synced 2025-07-01 18:17:27 -04:00
add chinese, dutch. python usage. readme update
This commit is contained in:
@ -29,6 +29,8 @@ DEFAULT_ALIGN_MODELS_TORCH = {
|
||||
|
||||
DEFAULT_ALIGN_MODELS_HF = {
|
||||
"ja": "jonatasgrosman/wav2vec2-large-xlsr-53-japanese",
|
||||
"zh": "jonatasgrosman/wav2vec2-large-xlsr-53-chinese-zh-cn",
|
||||
"nl": "jonatasgrosman/wav2vec2-large-xlsr-53-dutch",
|
||||
}
|
||||
|
||||
|
||||
@ -264,7 +266,6 @@ def transcribe(
|
||||
|
||||
def align(
|
||||
transcript: Iterator[dict],
|
||||
language: str,
|
||||
model: torch.nn.Module,
|
||||
align_model_metadata: dict,
|
||||
audio: Union[str, np.ndarray, torch.Tensor],
|
||||
@ -309,7 +310,7 @@ def align(
|
||||
emissions = torch.log_softmax(emissions, dim=-1)
|
||||
emission = emissions[0].cpu().detach()
|
||||
transcription = segment['text'].strip()
|
||||
if language not in LANGUAGES_WITHOUT_SPACES:
|
||||
if model_lang not in LANGUAGES_WITHOUT_SPACES:
|
||||
t_words = transcription.split(' ')
|
||||
else:
|
||||
t_words = [c for c in transcription]
|
||||
@ -426,7 +427,7 @@ def cli():
|
||||
parser.add_argument("--drop_non_aligned", action="store_true", help="For word .srt, whether to drop non aliged words, or merge them into neighbouring.")
|
||||
|
||||
parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
|
||||
parser.add_argument("--output_type", default="srt", choices=['all', 'srt', 'vtt', 'txt'], help="directory to save the outputs")
|
||||
parser.add_argument("--output_type", default="srt", choices=['all', 'srt', 'vtt', 'txt'], help="File type for desired output save")
|
||||
|
||||
parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")
|
||||
|
||||
@ -494,7 +495,7 @@ def cli():
|
||||
print(f"New language found ({result['language']})! Previous was ({align_metadata['language']}), loading new alignment model for new language...")
|
||||
align_model, align_metadata = load_align_model(result["language"], device)
|
||||
|
||||
result_aligned = align(result["segments"], result["language"], align_model, align_metadata, audio_path, device,
|
||||
result_aligned = align(result["segments"], align_model, align_metadata, audio_path, device,
|
||||
extend_duration=align_extend, start_from_previous=align_from_prev, drop_non_aligned_words=drop_non_aligned)
|
||||
audio_basename = os.path.basename(audio_path)
|
||||
|
||||
@ -518,8 +519,8 @@ def cli():
|
||||
write_srt(result_aligned["word_segments"], file=srt)
|
||||
|
||||
# save ASS
|
||||
with open(os.path.join(output_dir, audio_basename + ".ass"), "w", encoding="utf-8") as srt:
|
||||
write_ass(result_aligned["segments"], file=srt)
|
||||
with open(os.path.join(output_dir, audio_basename + ".ass"), "w", encoding="utf-8") as ass:
|
||||
write_ass(result_aligned["segments"], file=ass)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
Reference in New Issue
Block a user