Minor edits

Signed-off-by: Jeroen Oudshoorn <oudshoorn.jeroen@gmail.com>
This commit is contained in:
Jeroen Oudshoorn
2023-07-21 00:06:35 +02:00
parent 8a00c1835b
commit ce0d99c46d
2 changed files with 33 additions and 32 deletions

View File

@ -38,45 +38,45 @@ def load(config, agent, epoch, from_disk=True):
logging.debug("[ai] A2C imported in %.2fs" % (time.time() - start)) logging.debug("[ai] A2C imported in %.2fs" % (time.time() - start))
SB_BACKEND = "stable_baselines" SB_BACKEND = "stable_baselines"
start = time.time() start = time.time()
from stable_baselines.common.policies import MlpLstmPolicy from stable_baselines.common.policies import MlpLstmPolicy
logging.debug("[ai] MlpLstmPolicy imported in %.2fs" % (time.time() - start)) logging.debug("[ai] MlpLstmPolicy imported in %.2fs" % (time.time() - start))
SB_A2C_POLICY = MlpLstmPolicy SB_A2C_POLICY = MlpLstmPolicy
start = time.time() start = time.time()
from stable_baselines.common.vec_env import DummyVecEnv from stable_baselines.common.vec_env import DummyVecEnv
logging.debug("[ai] DummyVecEnv imported in %.2fs" % (time.time() - start)) logging.debug("[ai] DummyVecEnv imported in %.2fs" % (time.time() - start))
start = time.time() start = time.time()
import pwnagotchi.ai.gym as wrappers import pwnagotchi.ai.gym as wrappers
logging.debug("[ai] gym wrapper imported in %.2fs" % (time.time() - start)) logging.debug("[ai] gym wrapper imported in %.2fs" % (time.time() - start))
env = wrappers.Environment(agent, epoch) env = wrappers.Environment(agent, epoch)
env = DummyVecEnv([lambda: env]) env = DummyVecEnv([lambda: env])
logging.info("[ai] creating model ...") logging.info("[ai] creating model ...")
start = time.time() start = time.time()
a2c = A2C(SB_A2C_POLICY, env, **config['params']) a2c = A2C(SB_A2C_POLICY, env, **config['params'])
logging.debug("[ai] A2C created in %.2fs" % (time.time() - start)) logging.debug("[ai] A2C created in %.2fs" % (time.time() - start))
if from_disk and os.path.exists(config['path']): if from_disk and os.path.exists(config['path']):
logging.info("[ai] loading %s ..." % config['path']) logging.info("[ai] loading %s ..." % config['path'])
start = time.time() start = time.time()
a2c.load(config['path'], env) a2c.load(config['path'], env)
logging.debug("[ai] A2C loaded in %.2fs" % (time.time() - start)) logging.debug("[ai] A2C loaded in %.2fs" % (time.time() - start))
else: else:
logging.info("[ai] model created:") logging.info("[ai] model created:")
for key, value in config['params'].items(): for key, value in config['params'].items():
logging.info(" %s: %s" % (key, value)) logging.info(" %s: %s" % (key, value))
logging.debug("[ai] total loading time is %.2fs" % (time.time() - begin)) logging.debug("[ai] total loading time is %.2fs" % (time.time() - begin))
return a2c return a2c
except Exception as e: except Exception as e:
logging.exception("error while starting AI (%s)", e) logging.exception("error while starting AI (%s)", e)
logging.warning("[ai] AI not loaded!") logging.warning("[ai] AI not loaded!")
return False return False

View File

@ -18,4 +18,5 @@ python-dateutil>=2.8.1
websockets>=8.1 websockets>=8.1
torch>=2.0.1 torch>=2.0.1
torchvision>=0.15.2 torchvision>=0.15.2
stable-baselines3>=1.4.0 stable-baselines3>=1.4.0
RPi.GPIO