Changed for torch installation

Signed-off-by: Jeroen Oudshoorn <oudshoorn.jeroen@gmail.com>
This commit is contained in:
Jeroen Oudshoorn
2023-07-19 12:02:11 +02:00
parent 045759685c
commit 5da9f3598a
6 changed files with 222 additions and 129 deletions

View File

@ -18,45 +18,67 @@ def load(config, agent, epoch, from_disk=True):
logging.info("[ai] bootstrapping dependencies ...")
start = time.time()
from stable_baselines import A2C
logging.debug("[ai] A2C imported in %.2fs" % (time.time() - start))
SB_BACKEND = "stable_baselines3";
start = time.time()
from stable_baselines.common.policies import MlpLstmPolicy
logging.debug("[ai] MlpLstmPolicy imported in %.2fs" % (time.time() - start))
try:
from stable_baselines3 import A2C
logging.debug("[ai] A2C imported in %.2fs" % (time.time() - start))
start = time.time()
from stable_baselines.common.vec_env import DummyVecEnv
logging.debug("[ai] DummyVecEnv imported in %.2fs" % (time.time() - start))
start = time.time()
import pwnagotchi.ai.gym as wrappers
logging.debug("[ai] gym wrapper imported in %.2fs" % (time.time() - start))
env = wrappers.Environment(agent, epoch)
env = DummyVecEnv([lambda: env])
logging.info("[ai] creating model ...")
start = time.time()
a2c = A2C(MlpLstmPolicy, env, **config['params'])
logging.debug("[ai] A2C created in %.2fs" % (time.time() - start))
if from_disk and os.path.exists(config['path']):
logging.info("[ai] loading %s ..." % config['path'])
start = time.time()
a2c.load(config['path'], env)
logging.debug("[ai] A2C loaded in %.2fs" % (time.time() - start))
else:
logging.info("[ai] model created:")
for key, value in config['params'].items():
logging.info(" %s: %s" % (key, value))
from stable_baselines3.a2c import MlpPolicy
logging.debug("[ai] MlpPolicy imported in %.2fs" % (time.time() - start))
SB_A2C_POLICY = MlpPolicy
logging.debug("[ai] total loading time is %.2fs" % (time.time() - begin))
start = time.time()
from stable_baselines3.common.vec_env import DummyVecEnv
logging.debug("[ai] DummyVecEnv imported in %.2fs" % (time.time() - start))
return a2c
except Exception as e:
logging.exception("error while starting AI (%s)", e)
except Exception as e:
logging.debug("[ai] stable_baselines3 not accessible. Trying stable_baselines")
logging.warning("[ai] AI not loaded!")
return False
from stable_baselines import A2C
logging.debug("[ai] A2C imported in %.2fs" % (time.time() - start))
SB_BACKEND = "stable_baselines"
start = time.time()
from stable_baselines.common.policies import MlpLstmPolicy
logging.debug("[ai] MlpLstmPolicy imported in %.2fs" % (time.time() - start))
SB_A2C_POLICY = MlpLstmPolicy
start = time.time()
from stable_baselines.common.vec_env import DummyVecEnv
logging.debug("[ai] DummyVecEnv imported in %.2fs" % (time.time() - start))
start = time.time()
import pwnagotchi.ai.gym as wrappers
logging.debug("[ai] gym wrapper imported in %.2fs" % (time.time() - start))
env = wrappers.Environment(agent, epoch)
env = DummyVecEnv([lambda: env])
logging.info("[ai] creating model ...")
start = time.time()
a2c = A2C(SB_A2C_POLICY, env, **config['params'])
logging.debug("[ai] A2C created in %.2fs" % (time.time() - start))
if from_disk and os.path.exists(config['path']):
logging.info("[ai] loading %s ..." % config['path'])
start = time.time()
a2c.load(config['path'], env)
logging.debug("[ai] A2C loaded in %.2fs" % (time.time() - start))
else:
logging.info("[ai] model created:")
for key, value in config['params'].items():
logging.info(" %s: %s" % (key, value))
logging.debug("[ai] total loading time is %.2fs" % (time.time() - begin))
return a2c
except Exception as e:
logging.exception("error while starting AI (%s)", e)
logging.warning("[ai] AI not loaded!")
return False

View File

@ -140,10 +140,11 @@ class Environment(gym.Env):
logging.info("[ai] --- training epoch %d/%d ---" % (self._epoch_num, self._agent.training_epochs()))
logging.info("[ai] REWARD: %f" % self.last['reward'])
logging.debug("[ai] policy: %s" % ', '.join("%s:%s" % (name, value) for name, value in self.last['params'].items()))
logging.debug(
"[ai] policy: %s" % ', '.join("%s:%s" % (name, value) for name, value in self.last['params'].items()))
logging.info("[ai] observation:")
for name, value in self.last['state'].items():
if 'histogram' in name:
logging.info(" %s" % name.replace('_histogram', ''))
self._render_histogram(value)
self._render_histogram(value)

View File

@ -137,10 +137,7 @@ ai.params.vf_coef = 0.25
ai.params.ent_coef = 0.01
ai.params.max_grad_norm = 0.5
ai.params.learning_rate = 0.001
ai.params.alpha = 0.99
ai.params.epsilon = 0.00001
ai.params.verbose = 1
ai.params.lr_schedule = "constant"
personality.advertise = true
personality.deauth = true