Signed-off-by: Jeroen Oudshoorn <oudshoorn.jeroen@gmail.com>
This commit is contained in:
Jeroen Oudshoorn
2023-10-07 00:07:22 +02:00
parent f9b46fbd6b
commit 633e9087ad

View File

@ -21,18 +21,15 @@ def load(config, agent, epoch, from_disk=True):
SB_BACKEND = "stable_baselines3" SB_BACKEND = "stable_baselines3"
try: try:
from stable_baselines3 import A2C from baselines.common.vec_env import DummyVecEnv
from baselines.a2c import a2c
logging.debug("[ai] A2C imported in %.2fs" % (time.time() - start)) logging.debug("[ai] A2C imported in %.2fs" % (time.time() - start))
start = time.time() start = time.time()
from stable_baselines3.a2c import MlpPolicy from baselines.a2c.policies import MlpPolicy
logging.debug("[ai] MlpPolicy imported in %.2fs" % (time.time() - start)) logging.debug("[ai] MlpPolicy imported in %.2fs" % (time.time() - start))
SB_A2C_POLICY = MlpPolicy SB_A2C_POLICY = MlpPolicy
start = time.time()
from stable_baselines3.common.vec_env import DummyVecEnv
logging.debug("[ai] DummyVecEnv imported in %.2fs" % (time.time() - start))
except Exception as e: except Exception as e:
logging.debug("[ai] stable_baselines3 not accessible. Trying stable_baselines") logging.debug("[ai] stable_baselines3 not accessible. Trying stable_baselines")