Added support for the stable_baselines3 AI backend. stable_baselines3 uses

PyTorch instead of TensorFlow.
This commit is contained in:
Sniffleupagus
2023-06-18 18:56:06 -07:00
parent d2227b939d
commit c52220d98e

View File

@ -18,16 +18,37 @@ def load(config, agent, epoch, from_disk=True):
logging.info("[ai] bootstrapping dependencies ...")
start = time.time()
from stable_baselines import A2C
logging.debug("[ai] A2C imported in %.2fs" % (time.time() - start))
SB_BACKEND = "stable_baselines3";
start = time.time()
from stable_baselines.common.policies import MlpLstmPolicy
logging.debug("[ai] MlpLstmPolicy imported in %.2fs" % (time.time() - start))
try:
from stable_baselines3 import A2C
logging.debug("[ai] A2C imported in %.2fs" % (time.time() - start))
start = time.time()
from stable_baselines3.a2c import MlpPolicy
logging.debug("[ai] MlpPolicy imported in %.2fs" % (time.time() - start))
SB_A2C_POLICY = MlpPolicy
start = time.time()
from stable_baselines3.common.vec_env import DummyVecEnv
logging.debug("[ai] DummyVecEnv imported in %.2fs" % (time.time() - start))
except Exception as e:
logging.debug("[ai] stable_baselines3 not accessible. Trying stable_baselines")
from stable_baselines import A2C
logging.debug("[ai] A2C imported in %.2fs" % (time.time() - start))
SB_BACKEND = "stable_baselines"
start = time.time()
from stable_baselines.common.policies import MlpLstmPolicy
logging.debug("[ai] MlpLstmPolicy imported in %.2fs" % (time.time() - start))
SB_A2C_POLICY = MlpLstmPolicy
start = time.time()
from stable_baselines.common.vec_env import DummyVecEnv
logging.debug("[ai] DummyVecEnv imported in %.2fs" % (time.time() - start))
start = time.time()
from stable_baselines.common.vec_env import DummyVecEnv
logging.debug("[ai] DummyVecEnv imported in %.2fs" % (time.time() - start))
start = time.time()
import pwnagotchi.ai.gym as wrappers
@ -39,7 +60,7 @@ def load(config, agent, epoch, from_disk=True):
logging.info("[ai] creating model ...")
start = time.time()
a2c = A2C(MlpLstmPolicy, env, **config['params'])
a2c = A2C(SB_A2C_POLICY, env, **config['params'])
logging.debug("[ai] A2C created in %.2fs" % (time.time() - start))
if from_disk and os.path.exists(config['path']):