Added support for the stable_baselines3 AI backend. stable_baselines3 uses

PyTorch instead of TensorFlow.
This commit is contained in:
Sniffleupagus
2023-06-18 18:56:06 -07:00
parent d2227b939d
commit c52220d98e

View File

@ -18,16 +18,37 @@ def load(config, agent, epoch, from_disk=True):
logging.info("[ai] bootstrapping dependencies ...") logging.info("[ai] bootstrapping dependencies ...")
start = time.time() start = time.time()
from stable_baselines import A2C SB_BACKEND = "stable_baselines3";
logging.debug("[ai] A2C imported in %.2fs" % (time.time() - start))
start = time.time() try:
from stable_baselines.common.policies import MlpLstmPolicy from stable_baselines3 import A2C
logging.debug("[ai] MlpLstmPolicy imported in %.2fs" % (time.time() - start)) logging.debug("[ai] A2C imported in %.2fs" % (time.time() - start))
start = time.time()
from stable_baselines3.a2c import MlpPolicy
logging.debug("[ai] MlpPolicy imported in %.2fs" % (time.time() - start))
SB_A2C_POLICY = MlpPolicy
start = time.time()
from stable_baselines3.common.vec_env import DummyVecEnv
logging.debug("[ai] DummyVecEnv imported in %.2fs" % (time.time() - start))
except Exception as e:
logging.debug("[ai] stable_baselines3 not accessible. Trying stable_baselines")
from stable_baselines import A2C
logging.debug("[ai] A2C imported in %.2fs" % (time.time() - start))
SB_BACKEND = "stable_baselines"
start = time.time()
from stable_baselines.common.policies import MlpLstmPolicy
logging.debug("[ai] MlpLstmPolicy imported in %.2fs" % (time.time() - start))
SB_A2C_POLICY = MlpLstmPolicy
start = time.time()
from stable_baselines.common.vec_env import DummyVecEnv
logging.debug("[ai] DummyVecEnv imported in %.2fs" % (time.time() - start))
start = time.time()
from stable_baselines.common.vec_env import DummyVecEnv
logging.debug("[ai] DummyVecEnv imported in %.2fs" % (time.time() - start))
start = time.time() start = time.time()
import pwnagotchi.ai.gym as wrappers import pwnagotchi.ai.gym as wrappers
@ -39,7 +60,7 @@ def load(config, agent, epoch, from_disk=True):
logging.info("[ai] creating model ...") logging.info("[ai] creating model ...")
start = time.time() start = time.time()
a2c = A2C(MlpLstmPolicy, env, **config['params']) a2c = A2C(SB_A2C_POLICY, env, **config['params'])
logging.debug("[ai] A2C created in %.2fs" % (time.time() - start)) logging.debug("[ai] A2C created in %.2fs" % (time.time() - start))
if from_disk and os.path.exists(config['path']): if from_disk and os.path.exists(config['path']):