Signed-off-by: Jeroen Oudshoorn <oudshoorn.jeroen@gmail.com>
This commit is contained in:
Jeroen Oudshoorn
2023-10-07 00:46:05 +02:00
parent 74a6fa85e1
commit d9bc67000a

View File

@ -21,15 +21,18 @@ def load(config, agent, epoch, from_disk=True):
SB_BACKEND = "stable_baselines3" SB_BACKEND = "stable_baselines3"
try: try:
from stable_baselines3.common.vec_env import DummyVecEnv from stable_baselines3 import A2C
from stable_baselines3.a2c import a2c
logging.debug("[ai] A2C imported in %.2fs" % (time.time() - start)) logging.debug("[ai] A2C imported in %.2fs" % (time.time() - start))
start = time.time() start = time.time()
from stable_baselines3.a2c.policies import MlpPolicy from stable_baselines3.a2c import MlpPolicy
logging.debug("[ai] MlpPolicy imported in %.2fs" % (time.time() - start)) logging.debug("[ai] MlpPolicy imported in %.2fs" % (time.time() - start))
SB_A2C_POLICY = MlpPolicy SB_A2C_POLICY = MlpPolicy
start = time.time()
from stable_baselines3.common.vec_env import DummyVecEnv
logging.debug("[ai] DummyVecEnv imported in %.2fs" % (time.time() - start))
except Exception as e: except Exception as e:
logging.debug("[ai] stable_baselines3 not accessible. Trying stable_baselines") logging.debug("[ai] stable_baselines3 not accessible. Trying stable_baselines")