Signed-off-by: Jeroen Oudshoorn <oudshoorn.jeroen@gmail.com>
This commit is contained in:
Jeroen Oudshoorn
2023-10-08 11:53:15 +02:00
parent fa80ca9863
commit 05da800567
3 changed files with 13 additions and 29 deletions

View File

@ -79,7 +79,7 @@ is_auto_mode() {
# if usb0 is up, we're in MANU # if usb0 is up, we're in MANU
if is_interface_up usb0; then if is_interface_up usb0; then
return 0 return 1
fi fi
# if eth0 is up (for other boards), we're in MANU # if eth0 is up (for other boards), we're in MANU
@ -105,7 +105,7 @@ is_auto_mode_no_delete() {
# if usb0 is up, we're in MANU # if usb0 is up, we're in MANU
if is_interface_up usb0; then if is_interface_up usb0; then
return 0 return 1
fi fi
# if eth0 is up (for other boards), we're in MANU # if eth0 is up (for other boards), we're in MANU

View File

@ -20,34 +20,17 @@ def load(config, agent, epoch, from_disk=True):
start = time.time() start = time.time()
SB_BACKEND = "stable_baselines3" SB_BACKEND = "stable_baselines3"
try: from stable_baselines3 import A2C
from stable_baselines3 import A2C logging.debug("[ai] A2C imported in %.2fs" % (time.time() - start))
logging.debug("[ai] A2C imported in %.2fs" % (time.time() - start))
start = time.time() start = time.time()
from stable_baselines3.a2c import MlpPolicy from stable_baselines3.a2c import MlpPolicy
logging.debug("[ai] MlpPolicy imported in %.2fs" % (time.time() - start)) logging.debug("[ai] MlpPolicy imported in %.2fs" % (time.time() - start))
SB_A2C_POLICY = MlpPolicy SB_A2C_POLICY = MlpPolicy
start = time.time() start = time.time()
from stable_baselines3.common.vec_env import DummyVecEnv from stable_baselines3.common.vec_env import DummyVecEnv
logging.debug("[ai] DummyVecEnv imported in %.2fs" % (time.time() - start)) logging.debug("[ai] DummyVecEnv imported in %.2fs" % (time.time() - start))
except Exception as e:
logging.debug("[ai] stable_baselines3 not accessible. Trying stable_baselines")
from stable_baselines import A2C
logging.debug("[ai] A2C imported in %.2fs" % (time.time() - start))
SB_BACKEND = "stable_baselines"
start = time.time()
from stable_baselines.common.policies import MlpLstmPolicy
logging.debug("[ai] MlpLstmPolicy imported in %.2fs" % (time.time() - start))
SB_A2C_POLICY = MlpLstmPolicy
start = time.time()
from stable_baselines.common.vec_env import DummyVecEnv
logging.debug("[ai] DummyVecEnv imported in %.2fs" % (time.time() - start))
start = time.time() start = time.time()
import pwnagotchi.ai.gym as wrappers import pwnagotchi.ai.gym as wrappers

View File

@ -9,7 +9,8 @@ from pwnagotchi.ai.parameter import Parameter
class Environment(gym.Env): class Environment(gym.Env):
metadata = {'render.modes': ['human']} render_mode = "human"
metadata = {'render_modes': ['human']}
params = [ params = [
Parameter('min_rssi', min_value=-200, max_value=-50), Parameter('min_rssi', min_value=-200, max_value=-50),
Parameter('ap_ttl', min_value=30, max_value=600), Parameter('ap_ttl', min_value=30, max_value=600),