mirror of
https://github.com/jayofelony/pwnagotchi.git
synced 2025-07-01 18:37:27 -04:00
Changed for torch installation
Signed-off-by: Jeroen Oudshoorn <oudshoorn.jeroen@gmail.com>
This commit is contained in:
@ -18,45 +18,67 @@ def load(config, agent, epoch, from_disk=True):
|
||||
logging.info("[ai] bootstrapping dependencies ...")
|
||||
|
||||
start = time.time()
|
||||
from stable_baselines import A2C
|
||||
logging.debug("[ai] A2C imported in %.2fs" % (time.time() - start))
|
||||
SB_BACKEND = "stable_baselines3";
|
||||
|
||||
start = time.time()
|
||||
from stable_baselines.common.policies import MlpLstmPolicy
|
||||
logging.debug("[ai] MlpLstmPolicy imported in %.2fs" % (time.time() - start))
|
||||
try:
|
||||
from stable_baselines3 import A2C
|
||||
logging.debug("[ai] A2C imported in %.2fs" % (time.time() - start))
|
||||
|
||||
start = time.time()
|
||||
from stable_baselines.common.vec_env import DummyVecEnv
|
||||
logging.debug("[ai] DummyVecEnv imported in %.2fs" % (time.time() - start))
|
||||
|
||||
start = time.time()
|
||||
import pwnagotchi.ai.gym as wrappers
|
||||
logging.debug("[ai] gym wrapper imported in %.2fs" % (time.time() - start))
|
||||
|
||||
env = wrappers.Environment(agent, epoch)
|
||||
env = DummyVecEnv([lambda: env])
|
||||
|
||||
logging.info("[ai] creating model ...")
|
||||
|
||||
start = time.time()
|
||||
a2c = A2C(MlpLstmPolicy, env, **config['params'])
|
||||
logging.debug("[ai] A2C created in %.2fs" % (time.time() - start))
|
||||
|
||||
if from_disk and os.path.exists(config['path']):
|
||||
logging.info("[ai] loading %s ..." % config['path'])
|
||||
start = time.time()
|
||||
a2c.load(config['path'], env)
|
||||
logging.debug("[ai] A2C loaded in %.2fs" % (time.time() - start))
|
||||
else:
|
||||
logging.info("[ai] model created:")
|
||||
for key, value in config['params'].items():
|
||||
logging.info(" %s: %s" % (key, value))
|
||||
from stable_baselines3.a2c import MlpPolicy
|
||||
logging.debug("[ai] MlpPolicy imported in %.2fs" % (time.time() - start))
|
||||
SB_A2C_POLICY = MlpPolicy
|
||||
|
||||
logging.debug("[ai] total loading time is %.2fs" % (time.time() - begin))
|
||||
start = time.time()
|
||||
from stable_baselines3.common.vec_env import DummyVecEnv
|
||||
logging.debug("[ai] DummyVecEnv imported in %.2fs" % (time.time() - start))
|
||||
|
||||
return a2c
|
||||
except Exception as e:
|
||||
logging.exception("error while starting AI (%s)", e)
|
||||
except Exception as e:
|
||||
logging.debug("[ai] stable_baselines3 not accessible. Trying stable_baselines")
|
||||
|
||||
logging.warning("[ai] AI not loaded!")
|
||||
return False
|
||||
from stable_baselines import A2C
|
||||
logging.debug("[ai] A2C imported in %.2fs" % (time.time() - start))
|
||||
SB_BACKEND = "stable_baselines"
|
||||
|
||||
start = time.time()
|
||||
from stable_baselines.common.policies import MlpLstmPolicy
|
||||
logging.debug("[ai] MlpLstmPolicy imported in %.2fs" % (time.time() - start))
|
||||
SB_A2C_POLICY = MlpLstmPolicy
|
||||
|
||||
start = time.time()
|
||||
from stable_baselines.common.vec_env import DummyVecEnv
|
||||
logging.debug("[ai] DummyVecEnv imported in %.2fs" % (time.time() - start))
|
||||
|
||||
|
||||
start = time.time()
|
||||
import pwnagotchi.ai.gym as wrappers
|
||||
|
||||
logging.debug("[ai] gym wrapper imported in %.2fs" % (time.time() - start))
|
||||
|
||||
env = wrappers.Environment(agent, epoch)
|
||||
env = DummyVecEnv([lambda: env])
|
||||
|
||||
logging.info("[ai] creating model ...")
|
||||
|
||||
start = time.time()
|
||||
a2c = A2C(SB_A2C_POLICY, env, **config['params'])
|
||||
logging.debug("[ai] A2C created in %.2fs" % (time.time() - start))
|
||||
|
||||
if from_disk and os.path.exists(config['path']):
|
||||
logging.info("[ai] loading %s ..." % config['path'])
|
||||
start = time.time()
|
||||
a2c.load(config['path'], env)
|
||||
logging.debug("[ai] A2C loaded in %.2fs" % (time.time() - start))
|
||||
else:
|
||||
logging.info("[ai] model created:")
|
||||
for key, value in config['params'].items():
|
||||
logging.info(" %s: %s" % (key, value))
|
||||
|
||||
logging.debug("[ai] total loading time is %.2fs" % (time.time() - begin))
|
||||
|
||||
return a2c
|
||||
except Exception as e:
|
||||
logging.exception("error while starting AI (%s)", e)
|
||||
|
||||
logging.warning("[ai] AI not loaded!")
|
||||
return False
|
@ -140,10 +140,11 @@ class Environment(gym.Env):
|
||||
logging.info("[ai] --- training epoch %d/%d ---" % (self._epoch_num, self._agent.training_epochs()))
|
||||
logging.info("[ai] REWARD: %f" % self.last['reward'])
|
||||
|
||||
logging.debug("[ai] policy: %s" % ', '.join("%s:%s" % (name, value) for name, value in self.last['params'].items()))
|
||||
logging.debug(
|
||||
"[ai] policy: %s" % ', '.join("%s:%s" % (name, value) for name, value in self.last['params'].items()))
|
||||
|
||||
logging.info("[ai] observation:")
|
||||
for name, value in self.last['state'].items():
|
||||
if 'histogram' in name:
|
||||
logging.info(" %s" % name.replace('_histogram', ''))
|
||||
self._render_histogram(value)
|
||||
self._render_histogram(value)
|
@ -137,10 +137,7 @@ ai.params.vf_coef = 0.25
|
||||
ai.params.ent_coef = 0.01
|
||||
ai.params.max_grad_norm = 0.5
|
||||
ai.params.learning_rate = 0.001
|
||||
ai.params.alpha = 0.99
|
||||
ai.params.epsilon = 0.00001
|
||||
ai.params.verbose = 1
|
||||
ai.params.lr_schedule = "constant"
|
||||
|
||||
personality.advertise = true
|
||||
personality.deauth = true
|
||||
|
Reference in New Issue
Block a user