diff --git a/pwnagotchi/ai/__init__.py b/pwnagotchi/ai/__init__.py index ad41fc10..b5821819 100644 --- a/pwnagotchi/ai/__init__.py +++ b/pwnagotchi/ai/__init__.py @@ -21,18 +21,15 @@ def load(config, agent, epoch, from_disk=True): SB_BACKEND = "stable_baselines3" try: - from stable_baselines3 import A2C + from baselines.common.vec_env import DummyVecEnv + from baselines.a2c import a2c logging.debug("[ai] A2C imported in %.2fs" % (time.time() - start)) start = time.time() - from stable_baselines3.a2c import MlpPolicy + from baselines.a2c.policies import MlpPolicy logging.debug("[ai] MlpPolicy imported in %.2fs" % (time.time() - start)) SB_A2C_POLICY = MlpPolicy - start = time.time() - from stable_baselines3.common.vec_env import DummyVecEnv - logging.debug("[ai] DummyVecEnv imported in %.2fs" % (time.time() - start)) - except Exception as e: logging.debug("[ai] stable_baselines3 not accessible. Trying stable_baselines")