mirror of
https://github.com/m-bain/whisperX.git
synced 2025-07-01 18:17:27 -04:00
Compare commits
6 Commits
Author | SHA1 | Date | |
---|---|---|---|
e24ca9e0a2 | |||
601c91140f | |||
31a9ec7466 | |||
b9c8c5072b | |||
a903e57cf1 | |||
cb176a186e |
@ -450,8 +450,8 @@ def align(
|
|||||||
"end": srow["end"],
|
"end": srow["end"],
|
||||||
"text": text,
|
"text": text,
|
||||||
"words": word_list,
|
"words": word_list,
|
||||||
# "word-segments": wseg,
|
"word-segments": wseg,
|
||||||
# "char-segments": cseg
|
"char-segments": cseg
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -207,7 +207,7 @@ class FasterWhisperPipeline(Pipeline):
|
|||||||
return final_iterator
|
return final_iterator
|
||||||
|
|
||||||
def transcribe(
|
def transcribe(
|
||||||
self, audio: Union[str, np.ndarray], batch_size=None
|
self, audio: Union[str, np.ndarray], batch_size=None, num_workers=0
|
||||||
):
|
):
|
||||||
if isinstance(audio, str):
|
if isinstance(audio, str):
|
||||||
audio = load_audio(audio)
|
audio = load_audio(audio)
|
||||||
@ -232,7 +232,7 @@ class FasterWhisperPipeline(Pipeline):
|
|||||||
|
|
||||||
segments = []
|
segments = []
|
||||||
batch_size = batch_size or self._batch_size
|
batch_size = batch_size or self._batch_size
|
||||||
for idx, out in enumerate(self.__call__(data(audio, vad_segments), batch_size=batch_size)):
|
for idx, out in enumerate(self.__call__(data(audio, vad_segments), batch_size=batch_size, num_workers=num_workers)):
|
||||||
text = out['text']
|
text = out['text']
|
||||||
if batch_size in [0, 1, None]:
|
if batch_size in [0, 1, None]:
|
||||||
text = text[0]
|
text = text[0]
|
||||||
@ -251,7 +251,10 @@ class FasterWhisperPipeline(Pipeline):
|
|||||||
|
|
||||||
|
|
||||||
def detect_language(self, audio: np.ndarray):
|
def detect_language(self, audio: np.ndarray):
|
||||||
segment = log_mel_spectrogram(audio[: N_SAMPLES], padding=0)
|
if audio.shape[0] < N_SAMPLES:
|
||||||
|
print("Warning: audio is shorter than 30s, language detection may be inaccurate.")
|
||||||
|
segment = log_mel_spectrogram(audio[: N_SAMPLES],
|
||||||
|
padding=0 if audio.shape[0] >= N_SAMPLES else N_SAMPLES - audio.shape[0])
|
||||||
encoder_output = self.model.encode(segment)
|
encoder_output = self.model.encode(segment)
|
||||||
results = self.model.model.detect_language(encoder_output)
|
results = self.model.model.detect_language(encoder_output)
|
||||||
language_token, language_probability = results[0][0]
|
language_token, language_probability = results[0][0]
|
||||||
|
@ -203,6 +203,12 @@ def cli():
|
|||||||
|
|
||||||
# >> Write
|
# >> Write
|
||||||
for result, audio_path in results:
|
for result, audio_path in results:
|
||||||
|
# Remove pandas dataframes from result so that
|
||||||
|
# we can serialize the result with json
|
||||||
|
for seg in result["segments"]:
|
||||||
|
seg.pop("word-segments", None)
|
||||||
|
seg.pop("char-segments", None)
|
||||||
|
|
||||||
writer(result, audio_path, writer_args)
|
writer(result, audio_path, writer_args)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
Reference in New Issue
Block a user