Signed-off-by: Jeroen Oudshoorn <oudshoorn.jeroen@gmail.com>
This commit is contained in:
Jeroen Oudshoorn
2023-10-11 10:27:39 +02:00
parent 05da800567
commit bdb5f78670
5 changed files with 21 additions and 21 deletions

View File

@ -3,7 +3,7 @@ import time
import logging
# https://stackoverflow.com/questions/40426502/is-there-a-way-to-suppress-the-messages-tensorflow-prints/40426709
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' # or any {'0', '1', '2'}
# os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' # or any {'0', '1', '2'}
def load(config, agent, epoch, from_disk=True):
@ -15,47 +15,47 @@ def load(config, agent, epoch, from_disk=True):
try:
begin = time.time()
logging.info("[ai] bootstrapping dependencies ...")
logging.info("[AI] bootstrapping dependencies ...")
start = time.time()
SB_BACKEND = "stable_baselines3"
from stable_baselines3 import A2C
logging.debug("[ai] A2C imported in %.2fs" % (time.time() - start))
logging.debug("[AI] A2C imported in %.2fs" % (time.time() - start))
start = time.time()
from stable_baselines3.a2c import MlpPolicy
logging.debug("[ai] MlpPolicy imported in %.2fs" % (time.time() - start))
logging.debug("[AI] MlpPolicy imported in %.2fs" % (time.time() - start))
SB_A2C_POLICY = MlpPolicy
start = time.time()
from stable_baselines3.common.vec_env import DummyVecEnv
logging.debug("[ai] DummyVecEnv imported in %.2fs" % (time.time() - start))
logging.debug("[AI] DummyVecEnv imported in %.2fs" % (time.time() - start))
start = time.time()
import pwnagotchi.ai.gym as wrappers
logging.debug("[ai] gym wrapper imported in %.2fs" % (time.time() - start))
logging.debug("[AI] gym wrapper imported in %.2fs" % (time.time() - start))
env = wrappers.Environment(agent, epoch)
env = DummyVecEnv([lambda: env])
logging.info("[ai] creating model ...")
logging.info("[AI] creating model ...")
start = time.time()
a2c = A2C(SB_A2C_POLICY, env, **config['params'])
logging.debug("[ai] A2C created in %.2fs" % (time.time() - start))
logging.debug("[AI] A2C created in %.2fs" % (time.time() - start))
if from_disk and os.path.exists(config['path']):
logging.info("[ai] loading %s ..." % config['path'])
logging.info("[AI] loading %s ..." % config['path'])
start = time.time()
a2c.load(config['path'], env)
logging.debug("[ai] A2C loaded in %.2fs" % (time.time() - start))
logging.debug("[AI] A2C loaded in %.2fs" % (time.time() - start))
else:
logging.info("[ai] model created:")
logging.info("[AI] model created:")
for key, value in config['params'].items():
logging.info(" %s: %s" % (key, value))
logging.debug("[ai] total loading time is %.2fs" % (time.time() - begin))
logging.debug("[AI] total loading time is %.2fs" % (time.time() - begin))
return a2c
except Exception as e:
@ -63,5 +63,5 @@ def load(config, agent, epoch, from_disk=True):
logging.info("[AI] Deleting brain and restarting.")
os.system("rm /root/brain.nn && service pwnagotchi restart")
logging.warning("[ai] AI not loaded!")
logging.warning("[AI] AI not loaded!")
return False

View File

@ -138,13 +138,13 @@ class Environment(gym.Env):
self._last_render = self._epoch_num
logging.info("[ai] --- training epoch %d/%d ---" % (self._epoch_num, self._agent.training_epochs()))
logging.info("[ai] REWARD: %f" % self.last['reward'])
logging.info("[AI] --- training epoch %d/%d ---" % (self._epoch_num, self._agent.training_epochs()))
logging.info("[AI] REWARD: %f" % self.last['reward'])
logging.debug(
"[ai] policy: %s" % ', '.join("%s:%s" % (name, value) for name, value in self.last['params'].items()))
"[AI] policy: %s" % ', '.join("%s:%s" % (name, value) for name, value in self.last['params'].items()))
logging.info("[ai] observation:")
logging.info("[AI] observation:")
for name, value in self.last['state'].items():
if 'histogram' in name:
logging.info(" %s" % name.replace('_histogram', ''))