mirror of
https://github.com/jayofelony/pwnagotchi.git
synced 2025-07-01 18:37:27 -04:00
added support for stable_baselines3 AI backend. stable_baselines3 uses
pytorch instead of tensorflow
This commit is contained in:
@ -18,17 +18,38 @@ def load(config, agent, epoch, from_disk=True):
|
|||||||
logging.info("[ai] bootstrapping dependencies ...")
|
logging.info("[ai] bootstrapping dependencies ...")
|
||||||
|
|
||||||
start = time.time()
|
start = time.time()
|
||||||
|
SB_BACKEND = "stable_baselines3";
|
||||||
|
|
||||||
|
try:
|
||||||
|
from stable_baselines3 import A2C
|
||||||
|
logging.debug("[ai] A2C imported in %.2fs" % (time.time() - start))
|
||||||
|
|
||||||
|
start = time.time()
|
||||||
|
from stable_baselines3.a2c import MlpPolicy
|
||||||
|
logging.debug("[ai] MlpPolicy imported in %.2fs" % (time.time() - start))
|
||||||
|
SB_A2C_POLICY = MlpPolicy
|
||||||
|
|
||||||
|
start = time.time()
|
||||||
|
from stable_baselines3.common.vec_env import DummyVecEnv
|
||||||
|
logging.debug("[ai] DummyVecEnv imported in %.2fs" % (time.time() - start))
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logging.debug("[ai] stable_baselines3 not accessible. Trying stable_baselines")
|
||||||
|
|
||||||
from stable_baselines import A2C
|
from stable_baselines import A2C
|
||||||
logging.debug("[ai] A2C imported in %.2fs" % (time.time() - start))
|
logging.debug("[ai] A2C imported in %.2fs" % (time.time() - start))
|
||||||
|
SB_BACKEND = "stable_baselines"
|
||||||
|
|
||||||
start = time.time()
|
start = time.time()
|
||||||
from stable_baselines.common.policies import MlpLstmPolicy
|
from stable_baselines.common.policies import MlpLstmPolicy
|
||||||
logging.debug("[ai] MlpLstmPolicy imported in %.2fs" % (time.time() - start))
|
logging.debug("[ai] MlpLstmPolicy imported in %.2fs" % (time.time() - start))
|
||||||
|
SB_A2C_POLICY = MlpLstmPolicy
|
||||||
|
|
||||||
start = time.time()
|
start = time.time()
|
||||||
from stable_baselines.common.vec_env import DummyVecEnv
|
from stable_baselines.common.vec_env import DummyVecEnv
|
||||||
logging.debug("[ai] DummyVecEnv imported in %.2fs" % (time.time() - start))
|
logging.debug("[ai] DummyVecEnv imported in %.2fs" % (time.time() - start))
|
||||||
|
|
||||||
|
|
||||||
start = time.time()
|
start = time.time()
|
||||||
import pwnagotchi.ai.gym as wrappers
|
import pwnagotchi.ai.gym as wrappers
|
||||||
logging.debug("[ai] gym wrapper imported in %.2fs" % (time.time() - start))
|
logging.debug("[ai] gym wrapper imported in %.2fs" % (time.time() - start))
|
||||||
@ -39,7 +60,7 @@ def load(config, agent, epoch, from_disk=True):
|
|||||||
logging.info("[ai] creating model ...")
|
logging.info("[ai] creating model ...")
|
||||||
|
|
||||||
start = time.time()
|
start = time.time()
|
||||||
a2c = A2C(MlpLstmPolicy, env, **config['params'])
|
a2c = A2C(SB_A2C_POLICY, env, **config['params'])
|
||||||
logging.debug("[ai] A2C created in %.2fs" % (time.time() - start))
|
logging.debug("[ai] A2C created in %.2fs" % (time.time() - start))
|
||||||
|
|
||||||
if from_disk and os.path.exists(config['path']):
|
if from_disk and os.path.exists(config['path']):
|
||||||
|
Reference in New Issue
Block a user