mirror of
https://github.com/m-bain/whisperX.git
synced 2025-07-01 18:17:27 -04:00
add device to dia pipeline @sorgfresser
This commit is contained in:
@ -11,6 +11,7 @@ class DiarizationPipeline:
|
||||
use_auth_token=None,
|
||||
device: Optional[Union[str, torch.device]] = "cpu",
|
||||
):
|
||||
self.model = Pipeline.from_pretrained(model_name, use_auth_token=use_auth_token)
|
||||
if isinstance(device, str):
|
||||
device = torch.device(device)
|
||||
self.model = Pipeline.from_pretrained(model_name, use_auth_token=use_auth_token).to(device)
|
||||
|
Reference in New Issue
Block a user