This commit is contained in:
Max Bain
2023-04-24 21:08:43 +01:00
parent da458863d7
commit 558d980535
11 changed files with 1034 additions and 846 deletions

View File

@ -2,16 +2,17 @@
Forced Alignment with Whisper
C. Max Bain
"""
from dataclasses import dataclass
from typing import Iterator, Union
import numpy as np
import pandas as pd
from typing import List, Union, Iterator, TYPE_CHECKING
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
import torchaudio
import torch
from dataclasses import dataclass
from whisper.audio import SAMPLE_RATE, load_audio
from .utils import interpolate_nans
import torchaudio
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
from .audio import SAMPLE_RATE, load_audio
from .utils import interpolate_nans
LANGUAGES_WITHOUT_SPACES = ["ja", "zh"]
@ -391,34 +392,42 @@ def align(
if 'level_1' in cseg: del cseg['level_1']
if 'level_0' in cseg: del cseg['level_0']
cseg.reset_index(inplace=True)
aligned_segments.append(
{
"start": srow["start"],
"end": srow["end"],
"text": text,
"word-segments": wseg,
"char-segments": cseg
}
)
def get_raw_text(word_row):
return seg["seg-text"][word_row.name][int(word_row["segment-text-start"]):int(word_row["segment-text-end"])+1]
word_list = []
wdx = 0
curr_text = get_raw_text(wseg.iloc[wdx])
if not curr_text.startswith(" "):
curr_text = " " + curr_text
if len(wseg) > 1:
for _, wrow in wseg.iloc[1:].iterrows():
if wrow['start'] != wseg.iloc[wdx]['start']:
word_start = wseg.iloc[wdx]['start']
word_end = wseg.iloc[wdx]['end']
aligned_segments_word.append(
{
"text": curr_text.strip(),
"start": wseg.iloc[wdx]["start"],
"end": wseg.iloc[wdx]["end"],
"start": word_start,
"end": word_end
}
)
curr_text = ""
curr_text += " " + get_raw_text(wrow)
word_list.append(
{
"word": curr_text.rstrip(),
"start": word_start,
"end": word_end,
}
)
curr_text = " "
curr_text += get_raw_text(wrow) + " "
wdx += 1
aligned_segments_word.append(
{
"text": curr_text.strip(),
@ -427,6 +436,25 @@ def align(
}
)
word_list.append(
{
"word": curr_text.rstrip(),
"start": word_start,
"end": word_end,
}
)
aligned_segments.append(
{
"start": srow["start"],
"end": srow["end"],
"text": text,
"words": word_list,
# "word-segments": wseg,
# "char-segments": cseg
}
)
return {"segments": aligned_segments, "word_segments": aligned_segments_word}