Add Chinese and Dutch; add Python usage; update README

This commit is contained in:
Max Bain
2022-12-23 00:41:12 +00:00
parent e909f2f766
commit c6fa7df3cc
3 changed files with 55 additions and 41 deletions

View File

@ -29,6 +29,8 @@ DEFAULT_ALIGN_MODELS_TORCH = {
DEFAULT_ALIGN_MODELS_HF = {
"ja": "jonatasgrosman/wav2vec2-large-xlsr-53-japanese",
"zh": "jonatasgrosman/wav2vec2-large-xlsr-53-chinese-zh-cn",
"nl": "jonatasgrosman/wav2vec2-large-xlsr-53-dutch",
}
@ -264,7 +266,6 @@ def transcribe(
def align(
transcript: Iterator[dict],
language: str,
model: torch.nn.Module,
align_model_metadata: dict,
audio: Union[str, np.ndarray, torch.Tensor],
@ -309,7 +310,7 @@ def align(
emissions = torch.log_softmax(emissions, dim=-1)
emission = emissions[0].cpu().detach()
transcription = segment['text'].strip()
if language not in LANGUAGES_WITHOUT_SPACES:
if model_lang not in LANGUAGES_WITHOUT_SPACES:
t_words = transcription.split(' ')
else:
t_words = [c for c in transcription]
@ -426,7 +427,7 @@ def cli():
parser.add_argument("--drop_non_aligned", action="store_true", help="For word .srt, whether to drop non aliged words, or merge them into neighbouring.")
parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
parser.add_argument("--output_type", default="srt", choices=['all', 'srt', 'vtt', 'txt'], help="directory to save the outputs")
parser.add_argument("--output_type", default="srt", choices=['all', 'srt', 'vtt', 'txt'], help="File type for desired output save")
parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")
@ -494,7 +495,7 @@ def cli():
print(f"New language found ({result['language']})! Previous was ({align_metadata['language']}), loading new alignment model for new language...")
align_model, align_metadata = load_align_model(result["language"], device)
result_aligned = align(result["segments"], result["language"], align_model, align_metadata, audio_path, device,
result_aligned = align(result["segments"], align_model, align_metadata, audio_path, device,
extend_duration=align_extend, start_from_previous=align_from_prev, drop_non_aligned_words=drop_non_aligned)
audio_basename = os.path.basename(audio_path)
@ -518,8 +519,8 @@ def cli():
write_srt(result_aligned["word_segments"], file=srt)
# save ASS
with open(os.path.join(output_dir, audio_basename + ".ass"), "w", encoding="utf-8") as srt:
write_ass(result_aligned["segments"], file=srt)
with open(os.path.join(output_dir, audio_basename + ".ass"), "w", encoding="utf-8") as ass:
write_ass(result_aligned["segments"], file=ass)
if __name__ == '__main__':