From c52220d98e0ec096968f824dcba6978d82919e22 Mon Sep 17 00:00:00 2001 From: Sniffleupagus Date: Sun, 18 Jun 2023 18:56:06 -0700 Subject: [PATCH] added support for stable_baselines3 AI backend. stable_baselines3 uses pytorch instead of tensorflow --- pwnagotchi/ai/__init__.py | 39 ++++++++++++++++++++++++++++++--------- 1 file changed, 30 insertions(+), 9 deletions(-) diff --git a/pwnagotchi/ai/__init__.py b/pwnagotchi/ai/__init__.py index 54933423..fedcc80a 100644 --- a/pwnagotchi/ai/__init__.py +++ b/pwnagotchi/ai/__init__.py @@ -18,16 +18,37 @@ def load(config, agent, epoch, from_disk=True): logging.info("[ai] bootstrapping dependencies ...") start = time.time() - from stable_baselines import A2C - logging.debug("[ai] A2C imported in %.2fs" % (time.time() - start)) + SB_BACKEND = "stable_baselines3"; - start = time.time() - from stable_baselines.common.policies import MlpLstmPolicy - logging.debug("[ai] MlpLstmPolicy imported in %.2fs" % (time.time() - start)) + try: + from stable_baselines3 import A2C + logging.debug("[ai] A2C imported in %.2fs" % (time.time() - start)) + + start = time.time() + from stable_baselines3.a2c import MlpPolicy + logging.debug("[ai] MlpPolicy imported in %.2fs" % (time.time() - start)) + SB_A2C_POLICY = MlpPolicy + + start = time.time() + from stable_baselines3.common.vec_env import DummyVecEnv + logging.debug("[ai] DummyVecEnv imported in %.2fs" % (time.time() - start)) + + except Exception as e: + logging.debug("[ai] stable_baselines3 not accessible. Trying stable_baselines") + + from stable_baselines import A2C + logging.debug("[ai] A2C imported in %.2fs" % (time.time() - start)) + SB_BACKEND = "stable_baselines" + + start = time.time() + from stable_baselines.common.policies import MlpLstmPolicy + logging.debug("[ai] MlpLstmPolicy imported in %.2fs" % (time.time() - start)) + SB_A2C_POLICY = MlpLstmPolicy + + start = time.time() + from stable_baselines.common.vec_env import DummyVecEnv + logging.debug("[ai] DummyVecEnv imported in %.2fs" % (time.time() - start)) - start = time.time() - from stable_baselines.common.vec_env import DummyVecEnv - logging.debug("[ai] DummyVecEnv imported in %.2fs" % (time.time() - start)) start = time.time() import pwnagotchi.ai.gym as wrappers @@ -39,7 +60,7 @@ def load(config, agent, epoch, from_disk=True): logging.info("[ai] creating model ...") start = time.time() - a2c = A2C(MlpLstmPolicy, env, **config['params']) + a2c = A2C(SB_A2C_POLICY, env, **config['params']) logging.debug("[ai] A2C created in %.2fs" % (time.time() - start)) if from_disk and os.path.exists(config['path']):