diff --git a/pwnagotchi/ai/__init__.py b/pwnagotchi/ai/__init__.py index 190025fc..7390193e 100644 --- a/pwnagotchi/ai/__init__.py +++ b/pwnagotchi/ai/__init__.py @@ -1,14 +1,13 @@ import os +import time +import warnings +import logging # https://stackoverflow.com/questions/40426502/is-there-a-way-to-suppress-the-messages-tensorflow-prints/40426709 os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' # or any {'0', '1', '2'} -import warnings - # https://stackoverflow.com/questions/15777951/how-to-suppress-pandas-future-warning warnings.simplefilter(action='ignore', category=FutureWarning) -import logging - def load(config, agent, epoch, from_disk=True): config = config['ai'] @@ -18,25 +17,39 @@ def load(config, agent, epoch, from_disk=True): logging.info("[ai] bootstrapping dependencies ...") + start = time.time() from stable_baselines import A2C - from stable_baselines.common.policies import MlpLstmPolicy - from stable_baselines.common.vec_env import DummyVecEnv + logging.debug("[ai] A2C imported in %.2fs" % (time.time() - start)) + start = time.time() + from stable_baselines.common.policies import MlpLstmPolicy + logging.debug("[ai] MlpLstmPolicy imported in %.2fs" % (time.time() - start)) + + start = time.time() + from stable_baselines.common.vec_env import DummyVecEnv + logging.debug("[ai] DummyVecEnv imported in %.2fs" % (time.time() - start)) + + start = time.time() import pwnagotchi.ai.gym as wrappers + logging.debug("[ai] gym wrapper imported in %.2fs" % (time.time() - start)) env = wrappers.Environment(agent, epoch) env = DummyVecEnv([lambda: env]) - logging.info("[ai] bootstrapping model ...") + logging.info("[ai] creating model ...") + start = time.time() a2c = A2C(MlpLstmPolicy, env, **config['params']) + logging.debug("[ai] A2C crated in %.2fs" % (time.time() - start)) if from_disk and os.path.exists(config['path']): logging.info("[ai] loading %s ..." % config['path']) + start = time.time() a2c.load(config['path'], env) + logging.debug("[ai] A2C loaded in %.2fs" % (time.time() - start)) else: logging.info("[ai] model created:") for key, value in config['params'].items(): logging.info(" %s: %s" % (key, value)) - return a2c \ No newline at end of file + return a2c