mirror of
https://github.com/jayofelony/pwnagotchi.git
synced 2025-07-01 18:37:27 -04:00
Minor edits
Signed-off-by: Jeroen Oudshoorn <oudshoorn.jeroen@gmail.com>
This commit is contained in:
@ -38,45 +38,45 @@ def load(config, agent, epoch, from_disk=True):
|
|||||||
logging.debug("[ai] A2C imported in %.2fs" % (time.time() - start))
|
logging.debug("[ai] A2C imported in %.2fs" % (time.time() - start))
|
||||||
SB_BACKEND = "stable_baselines"
|
SB_BACKEND = "stable_baselines"
|
||||||
|
|
||||||
start = time.time()
|
start = time.time()
|
||||||
from stable_baselines.common.policies import MlpLstmPolicy
|
from stable_baselines.common.policies import MlpLstmPolicy
|
||||||
logging.debug("[ai] MlpLstmPolicy imported in %.2fs" % (time.time() - start))
|
logging.debug("[ai] MlpLstmPolicy imported in %.2fs" % (time.time() - start))
|
||||||
SB_A2C_POLICY = MlpLstmPolicy
|
SB_A2C_POLICY = MlpLstmPolicy
|
||||||
|
|
||||||
start = time.time()
|
start = time.time()
|
||||||
from stable_baselines.common.vec_env import DummyVecEnv
|
from stable_baselines.common.vec_env import DummyVecEnv
|
||||||
logging.debug("[ai] DummyVecEnv imported in %.2fs" % (time.time() - start))
|
logging.debug("[ai] DummyVecEnv imported in %.2fs" % (time.time() - start))
|
||||||
|
|
||||||
|
|
||||||
start = time.time()
|
start = time.time()
|
||||||
import pwnagotchi.ai.gym as wrappers
|
import pwnagotchi.ai.gym as wrappers
|
||||||
|
|
||||||
logging.debug("[ai] gym wrapper imported in %.2fs" % (time.time() - start))
|
logging.debug("[ai] gym wrapper imported in %.2fs" % (time.time() - start))
|
||||||
|
|
||||||
env = wrappers.Environment(agent, epoch)
|
env = wrappers.Environment(agent, epoch)
|
||||||
env = DummyVecEnv([lambda: env])
|
env = DummyVecEnv([lambda: env])
|
||||||
|
|
||||||
logging.info("[ai] creating model ...")
|
logging.info("[ai] creating model ...")
|
||||||
|
|
||||||
start = time.time()
|
start = time.time()
|
||||||
a2c = A2C(SB_A2C_POLICY, env, **config['params'])
|
a2c = A2C(SB_A2C_POLICY, env, **config['params'])
|
||||||
logging.debug("[ai] A2C created in %.2fs" % (time.time() - start))
|
logging.debug("[ai] A2C created in %.2fs" % (time.time() - start))
|
||||||
|
|
||||||
if from_disk and os.path.exists(config['path']):
|
if from_disk and os.path.exists(config['path']):
|
||||||
logging.info("[ai] loading %s ..." % config['path'])
|
logging.info("[ai] loading %s ..." % config['path'])
|
||||||
start = time.time()
|
start = time.time()
|
||||||
a2c.load(config['path'], env)
|
a2c.load(config['path'], env)
|
||||||
logging.debug("[ai] A2C loaded in %.2fs" % (time.time() - start))
|
logging.debug("[ai] A2C loaded in %.2fs" % (time.time() - start))
|
||||||
else:
|
else:
|
||||||
logging.info("[ai] model created:")
|
logging.info("[ai] model created:")
|
||||||
for key, value in config['params'].items():
|
for key, value in config['params'].items():
|
||||||
logging.info(" %s: %s" % (key, value))
|
logging.info(" %s: %s" % (key, value))
|
||||||
|
|
||||||
logging.debug("[ai] total loading time is %.2fs" % (time.time() - begin))
|
logging.debug("[ai] total loading time is %.2fs" % (time.time() - begin))
|
||||||
|
|
||||||
return a2c
|
return a2c
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.exception("error while starting AI (%s)", e)
|
logging.exception("error while starting AI (%s)", e)
|
||||||
|
|
||||||
logging.warning("[ai] AI not loaded!")
|
logging.warning("[ai] AI not loaded!")
|
||||||
return False
|
return False
|
@ -19,3 +19,4 @@ websockets>=8.1
|
|||||||
torch>=2.0.1
|
torch>=2.0.1
|
||||||
torchvision>=0.15.2
|
torchvision>=0.15.2
|
||||||
stable-baselines3>=1.4.0
|
stable-baselines3>=1.4.0
|
||||||
|
RPi.GPIO
|
Reference in New Issue
Block a user