Signed-off-by: Jeroen Oudshoorn <oudshoorn.jeroen@gmail.com>
This commit is contained in:
Jeroen Oudshoorn
2023-10-07 00:42:01 +02:00
parent eae6917c08
commit 74a6fa85e1

View File

@ -21,12 +21,12 @@ def load(config, agent, epoch, from_disk=True):
SB_BACKEND = "stable_baselines3" SB_BACKEND = "stable_baselines3"
try: try:
from baselines.common.vec_env import DummyVecEnv from stable_baselines3.common.vec_env import DummyVecEnv
from baselines.a2c import a2c from stable_baselines3.a2c import a2c
logging.debug("[ai] A2C imported in %.2fs" % (time.time() - start)) logging.debug("[ai] A2C imported in %.2fs" % (time.time() - start))
start = time.time() start = time.time()
from baselines.a2c.policies import MlpPolicy from stable_baselines3.a2c.policies import MlpPolicy
logging.debug("[ai] MlpPolicy imported in %.2fs" % (time.time() - start)) logging.debug("[ai] MlpPolicy imported in %.2fs" % (time.time() - start))
SB_A2C_POLICY = MlpPolicy SB_A2C_POLICY = MlpPolicy