support huggingface + model select based on lang.

This commit is contained in:
Max Bain
2022-12-20 19:54:55 +00:00
parent 8b2f40d02a
commit e909f2f766
2 changed files with 71 additions and 25 deletions

View File

@ -122,8 +122,7 @@ https://user-images.githubusercontent.com/19920981/208731743-311f2360-b73b-4c60-
[x] Subtitle .ass output
[ ] Automatic align model selection based on language detection
[x] Automatic align model selection based on language detection
[ ] Incorporating word-level speaker diarization

View File

@ -17,8 +17,19 @@ from .utils import exact_div, format_timestamp, optional_int, optional_float, st
if TYPE_CHECKING:
from .model import Whisper
hugginface_models = ["jonatasgrosman/wav2vec2-large-xlsr-53-japanese"]
asian_languages = ["ja"]
LANGUAGES_WITHOUT_SPACES = ["ja", "zh"]
DEFAULT_ALIGN_MODELS_TORCH = {
"en": "WAV2VEC2_ASR_BASE_960H",
"fr": "VOXPOPULI_ASR_BASE_10K_FR",
"de": "VOXPOPULI_ASR_BASE_10K_DE",
"es": "VOXPOPULI_ASR_BASE_10K_ES",
"it": "VOXPOPULI_ASR_BASE_10K_IT",
}
DEFAULT_ALIGN_MODELS_HF = {
"ja": "jonatasgrosman/wav2vec2-large-xlsr-53-japanese",
}
def transcribe(
@ -255,7 +266,7 @@ def align(
transcript: Iterator[dict],
language: str,
model: torch.nn.Module,
model_dictionary: dict,
align_model_metadata: dict,
audio: Union[str, np.ndarray, torch.Tensor],
device: str,
extend_duration: float = 0.0,
@ -272,6 +283,10 @@ def align(
MAX_DURATION = audio.shape[1] / SAMPLE_RATE
model_dictionary = align_model_metadata['dictionary']
model_lang = align_model_metadata['language']
model_type = align_model_metadata['type']
prev_t2 = 0
word_segments_list = []
for idx, segment in enumerate(transcript):
@ -285,14 +300,16 @@ def align(
waveform_segment = audio[:, f1:f2]
with torch.inference_mode():
if language not in asian_languages:
if model_type == "torchaudio":
emissions, _ = model(waveform_segment.to(device))
else:
elif model_type == "huggingface":
emissions = model(waveform_segment.to(device)).logits
else:
raise NotImplementedError(f"Align model of type {model_type} not supported.")
emissions = torch.log_softmax(emissions, dim=-1)
emission = emissions[0].cpu().detach()
transcription = segment['text'].strip()
if language not in asian_languages:
if language not in LANGUAGES_WITHOUT_SPACES:
t_words = transcription.split(' ')
else:
t_words = [c for c in transcription]
@ -359,6 +376,41 @@ def align(
return {"segments": transcript, "word_segments": word_segments_list}
def load_align_model(language_code, device, model_name=None):
if model_name is None:
# use default model
if language_code in DEFAULT_ALIGN_MODELS_TORCH:
model_name = DEFAULT_ALIGN_MODELS_TORCH[language_code]
elif language_code in DEFAULT_ALIGN_MODELS_HF:
model_name = DEFAULT_ALIGN_MODELS_HF[language_code]
else:
print(f"There is no default alignment model set for this language ({language_code}).\
Please find a wav2vec2.0 model finetuned on this language in https://huggingface.co/models, then pass the model name in --align_model [MODEL_NAME]")
raise ValueError(f"No default align-model for language: {language_code}")
if model_name in torchaudio.pipelines.__all__:
pipeline_type = "torchaudio"
bundle = torchaudio.pipelines.__dict__[model_name]
align_model = bundle.get_model().to(device)
labels = bundle.get_labels()
align_dictionary = {c.lower(): i for i, c in enumerate(labels)}
else:
try:
processor = AutoProcessor.from_pretrained(model_name)
align_model = Wav2Vec2ForCTC.from_pretrained(model_name)
except Exception as e:
print(e)
print(f"Error loading model from huggingface, check https://huggingface.co/models for finetuned wav2vec2.0 models")
raise ValueError(f'The chosen align_model "{model_name}" could not be found in huggingface (https://huggingface.co/models) or torchaudio (https://pytorch.org/audio/stable/pipelines.html#id14)')
pipeline_type = "huggingface"
align_model = align_model.to(device)
labels = processor.tokenizer.get_vocab()
align_dictionary = {char.lower(): code for char,code in processor.tokenizer.get_vocab().items()}
align_metadata = {"language": language_code, "dictionary": align_dictionary, "type": pipeline_type}
return align_model, align_metadata
def cli():
from . import available_models
@ -368,7 +420,7 @@ def cli():
parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
# alignment params
parser.add_argument("--align_model", default="WAV2VEC2_ASR_BASE_960H", help="Name of phoneme-level ASR model to do alignment")
parser.add_argument("--align_model", default=None, help="Name of phoneme-level ASR model to do alignment")
parser.add_argument("--align_extend", default=2, type=float, help="Seconds before and after to extend the whisper segments for alignment")
parser.add_argument("--align_from_prev", default=True, type=bool, help="Whether to clip the alignment start time of current segment to the end time of the last aligned word of the previous segment")
parser.add_argument("--drop_non_aligned", action="store_true", help="For word .srt, whether to drop non aliged words, or merge them into neighbouring.")
@ -430,24 +482,19 @@ def cli():
from . import load_model
model = load_model(model_name, device=device, download_root=model_dir)
if align_model in torchaudio.pipelines.__all__:
bundle = torchaudio.pipelines.__dict__[align_model]
align_model = bundle.get_model().to(device)
labels = bundle.get_labels()
align_dictionary = {c.lower(): i for i, c in enumerate(labels)}
elif align_model in hugginface_models:
processor = AutoProcessor.from_pretrained(align_model)
align_model = Wav2Vec2ForCTC.from_pretrained(align_model).to(device)
align_model.to(device)
labels = processor.tokenizer.get_vocab()
align_dictionary = processor.tokenizer.get_vocab()
else:
print(f'Align model "{align_model}" is not supported, choose from:\n {torchaudio.pipelines.__all__ + wa2vec2_models_on_hugginface} \n\
See details here https://pytorch.org/audio/stable/pipelines.html#id14')
raise ValueError(f'Align model "{align_model}" not supported')
align_language = args["language"] if args["language"] is not None else "en" # default to loading english if not specified
align_model, align_metadata = load_align_model(align_language, device, model_name=align_model)
for audio_path in args.pop("audio"):
result = transcribe(model, audio_path, temperature=temperature, **args)
result_aligned = align(result["segments"], result["language"], align_model, align_dictionary, audio_path, device,
if result["language"] != align_metadata["language"]:
# load new language
print(f"New language found ({result['language']})! Previous was ({align_metadata['language']}), loading new alignment model for new language...")
align_model, align_metadata = load_align_model(result["language"], device)
result_aligned = align(result["segments"], result["language"], align_model, align_metadata, audio_path, device,
extend_duration=align_extend, start_from_previous=align_from_prev, drop_non_aligned_words=drop_non_aligned)
audio_basename = os.path.basename(audio_path)