diff --git a/pwnagotchi/ai/__init__.py b/pwnagotchi/ai/__init__.py index dc3ed603..c84a7fe2 100644 --- a/pwnagotchi/ai/__init__.py +++ b/pwnagotchi/ai/__init__.py @@ -38,45 +38,45 @@ def load(config, agent, epoch, from_disk=True): logging.debug("[ai] A2C imported in %.2fs" % (time.time() - start)) SB_BACKEND = "stable_baselines" - start = time.time() - from stable_baselines.common.policies import MlpLstmPolicy - logging.debug("[ai] MlpLstmPolicy imported in %.2fs" % (time.time() - start)) - SB_A2C_POLICY = MlpLstmPolicy + start = time.time() + from stable_baselines.common.policies import MlpLstmPolicy + logging.debug("[ai] MlpLstmPolicy imported in %.2fs" % (time.time() - start)) + SB_A2C_POLICY = MlpLstmPolicy - start = time.time() - from stable_baselines.common.vec_env import DummyVecEnv - logging.debug("[ai] DummyVecEnv imported in %.2fs" % (time.time() - start)) + start = time.time() + from stable_baselines.common.vec_env import DummyVecEnv + logging.debug("[ai] DummyVecEnv imported in %.2fs" % (time.time() - start)) -start = time.time() -import pwnagotchi.ai.gym as wrappers + start = time.time() + import pwnagotchi.ai.gym as wrappers -logging.debug("[ai] gym wrapper imported in %.2fs" % (time.time() - start)) + logging.debug("[ai] gym wrapper imported in %.2fs" % (time.time() - start)) -env = wrappers.Environment(agent, epoch) -env = DummyVecEnv([lambda: env]) + env = wrappers.Environment(agent, epoch) + env = DummyVecEnv([lambda: env]) -logging.info("[ai] creating model ...") + logging.info("[ai] creating model ...") -start = time.time() -a2c = A2C(SB_A2C_POLICY, env, **config['params']) -logging.debug("[ai] A2C created in %.2fs" % (time.time() - start)) + start = time.time() + a2c = A2C(SB_A2C_POLICY, env, **config['params']) + logging.debug("[ai] A2C created in %.2fs" % (time.time() - start)) -if from_disk and os.path.exists(config['path']): - logging.info("[ai] loading %s ..." % config['path']) - start = time.time() - a2c.load(config['path'], env) - logging.debug("[ai] A2C loaded in %.2fs" % (time.time() - start)) -else: - logging.info("[ai] model created:") - for key, value in config['params'].items(): - logging.info(" %s: %s" % (key, value)) + if from_disk and os.path.exists(config['path']): + logging.info("[ai] loading %s ..." % config['path']) + start = time.time() + a2c.load(config['path'], env) + logging.debug("[ai] A2C loaded in %.2fs" % (time.time() - start)) + else: + logging.info("[ai] model created:") + for key, value in config['params'].items(): + logging.info(" %s: %s" % (key, value)) -logging.debug("[ai] total loading time is %.2fs" % (time.time() - begin)) + logging.debug("[ai] total loading time is %.2fs" % (time.time() - begin)) -return a2c -except Exception as e: -logging.exception("error while starting AI (%s)", e) + return a2c + except Exception as e: + logging.exception("error while starting AI (%s)", e) -logging.warning("[ai] AI not loaded!") -return False \ No newline at end of file + logging.warning("[ai] AI not loaded!") + return False \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 600a4e08..b428e399 100644 --- a/requirements.txt +++ b/requirements.txt @@ -18,4 +18,5 @@ python-dateutil>=2.8.1 websockets>=8.1 torch>=2.0.1 torchvision>=0.15.2 -stable-baselines3>=1.4.0 \ No newline at end of file +stable-baselines3>=1.4.0 +RPi.GPIO \ No newline at end of file