3 Commits

Author SHA1 Message Date
036b5b0717 Merge c89b4f898f into d700b56c9c 2025-06-13 15:33:03 +02:00
d700b56c9c docs: add missing torch import to Python usage example in README 2025-06-08 03:34:49 -06:00
c89b4f898f fix: incorrect type annotation in get_writer return value
The audio_path parameter that the __call__ method of the ResultWriter class takes is a str, not a TextIO.
2025-05-13 02:45:33 +02:00
2 changed files with 4 additions and 4 deletions

View File

@@ -189,7 +189,7 @@ result = model.transcribe(audio, batch_size=batch_size)
print(result["segments"]) # before alignment
# delete model if low on GPU resources
-# import gc; gc.collect(); torch.cuda.empty_cache(); del model
+# import gc; import torch; gc.collect(); torch.cuda.empty_cache(); del model
# 2. Align whisper output
model_a, metadata = whisperx.load_align_model(language_code=result["language"], device=device)
@@ -198,7 +198,7 @@ result = whisperx.align(result["segments"], model_a, metadata, audio, device, re
print(result["segments"]) # after alignment
# delete model if low on GPU resources
-# import gc; gc.collect(); torch.cuda.empty_cache(); del model_a
+# import gc; import torch; gc.collect(); torch.cuda.empty_cache(); del model_a
# 3. Assign speaker labels
diarize_model = whisperx.diarize.DiarizationPipeline(use_auth_token=YOUR_HF_TOKEN, device=device)

View File

@@ -410,7 +410,7 @@ class WriteJSON(ResultWriter):
def get_writer(
output_format: str, output_dir: str
-) -> Callable[[dict, TextIO, dict], None]:
+) -> Callable[[dict, str, dict], None]:
writers = {
"txt": WriteTXT,
"vtt": WriteVTT,
@@ -425,7 +425,7 @@ def get_writer(
if output_format == "all":
all_writers = [writer(output_dir) for writer in writers.values()]
-def write_all(result: dict, file: TextIO, options: dict):
+def write_all(result: dict, file: str, options: dict):
for writer in all_writers:
writer(result, file, options)