From 07e7da6ddc8b66798411a33b3160bddd3cd786e2 Mon Sep 17 00:00:00 2001 From: Jeroen Oudshoorn Date: Thu, 4 Jan 2024 23:02:53 +0100 Subject: [PATCH] Remove AI --- pwnagotchi/agent.py | 3 - pwnagotchi/ai/__init__.py | 67 ---------- pwnagotchi/ai/epoch.py | 249 ------------------------------------ pwnagotchi/ai/featurizer.py | 66 ---------- pwnagotchi/ai/gym.py | 151 ---------------------- pwnagotchi/ai/parameter.py | 30 ----- pwnagotchi/ai/reward.py | 27 ---- pwnagotchi/ai/train.py | 197 ---------------------------- pwnagotchi/ai/utils.py | 16 --- 9 files changed, 806 deletions(-) delete mode 100644 pwnagotchi/ai/__init__.py delete mode 100644 pwnagotchi/ai/epoch.py delete mode 100644 pwnagotchi/ai/featurizer.py delete mode 100644 pwnagotchi/ai/gym.py delete mode 100644 pwnagotchi/ai/parameter.py delete mode 100644 pwnagotchi/ai/reward.py delete mode 100644 pwnagotchi/ai/train.py delete mode 100644 pwnagotchi/ai/utils.py diff --git a/pwnagotchi/agent.py b/pwnagotchi/agent.py index e5715f4d..becea5f5 100644 --- a/pwnagotchi/agent.py +++ b/pwnagotchi/agent.py @@ -14,7 +14,6 @@ from pwnagotchi.automata import Automata from pwnagotchi.log import LastSession from pwnagotchi.bettercap import Client from pwnagotchi.mesh.utils import AsyncAdvertiser -# from pwnagotchi.ai.train import AsyncTrainer RECOVERY_DATA_FILE = '/root/.pwnagotchi-recovery' @@ -28,7 +27,6 @@ class Agent(Client, Automata, AsyncAdvertiser): config['bettercap']['password']) Automata.__init__(self, config, view) AsyncAdvertiser.__init__(self, config, view, keypair) - # AsyncTrainer.__init__(self, config) self._started_at = time.time() self._filter = None if not config['main']['filter'] else re.compile(config['main']['filter']) @@ -129,7 +127,6 @@ class Agent(Client, Automata, AsyncAdvertiser): time.sleep(1) def start(self): - # self.start_ai() self._wait_bettercap() self.setup_events() self.set_starting() diff --git a/pwnagotchi/ai/__init__.py b/pwnagotchi/ai/__init__.py deleted file mode 100644 index ba458651..00000000 --- a/pwnagotchi/ai/__init__.py +++ /dev/null @@ -1,67 +0,0 @@ -import os -import time -import logging - -# https://stackoverflow.com/questions/40426502/is-there-a-way-to-suppress-the-messages-tensorflow-prints/40426709 -# os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' # or any {'0', '1', '2'} - - -def load(config, agent, epoch, from_disk=True): - config = config['ai'] - if not config['enabled']: - logging.info("ai disabled") - return False - - try: - begin = time.time() - - logging.info("[AI] bootstrapping dependencies ...") - - start = time.time() - SB_BACKEND = "stable_baselines3" - - from stable_baselines3 import A2C - logging.debug("[AI] A2C imported in %.2fs" % (time.time() - start)) - - start = time.time() - from stable_baselines3.a2c import MlpPolicy - logging.debug("[AI] MlpPolicy imported in %.2fs" % (time.time() - start)) - SB_A2C_POLICY = MlpPolicy - - start = time.time() - from stable_baselines3.common.vec_env import DummyVecEnv - logging.debug("[AI] DummyVecEnv imported in %.2fs" % (time.time() - start)) - - start = time.time() - import pwnagotchi.ai.gym as wrappers - logging.debug("[AI] gym wrapper imported in %.2fs" % (time.time() - start)) - - env = wrappers.Environment(agent, epoch) - env = DummyVecEnv([lambda: env]) - - logging.info("[AI] creating model ...") - - start = time.time() - a2c = A2C(SB_A2C_POLICY, env, **config['params']) - logging.debug("[AI] A2C created in %.2fs" % (time.time() - start)) - - if from_disk and os.path.exists(config['path']): - logging.info("[AI] loading %s ..." % config['path']) - start = time.time() - a2c.load(config['path'], env) - logging.debug("[AI] A2C loaded in %.2fs" % (time.time() - start)) - else: - logging.info("[AI] model created:") - for key, value in config['params'].items(): - logging.info(" %s: %s" % (key, value)) - - logging.debug("[AI] total loading time is %.2fs" % (time.time() - begin)) - - return a2c - except Exception as e: - logging.exception("[AI] error while starting AI (%s)", e) - logging.info("[AI] Deleting brain and restarting.") - os.system("rm /root/brain.nn && service pwnagotchi restart") - - logging.warning("[AI] AI not loaded!") - return False diff --git a/pwnagotchi/ai/epoch.py b/pwnagotchi/ai/epoch.py deleted file mode 100644 index 2ba47212..00000000 --- a/pwnagotchi/ai/epoch.py +++ /dev/null @@ -1,249 +0,0 @@ -import time -import threading -import logging - -import pwnagotchi -import pwnagotchi.utils as utils -import pwnagotchi.mesh.wifi as wifi - -from pwnagotchi.ai.reward import RewardFunction - - -class Epoch(object): - def __init__(self, config): - self.epoch = 0 - self.config = config - # how many consecutive epochs with no activity - self.inactive_for = 0 - # how many consecutive epochs with activity - self.active_for = 0 - # number of epochs with no visible access points - self.blind_for = 0 - # number of epochs in sad state - self.sad_for = 0 - # number of epochs in bored state - self.bored_for = 0 - # did deauth in this epoch in the current channel? - self.did_deauth = False - # number of deauths in this epoch - self.num_deauths = 0 - # did associate in this epoch in the current channel? - self.did_associate = False - # number of associations in this epoch - self.num_assocs = 0 - # number of assocs or deauths missed - self.num_missed = 0 - # did get any handshake in this epoch? - self.did_handshakes = False - # number of handshakes captured in this epoch - self.num_shakes = 0 - # number of channels hops - self.num_hops = 0 - # number of seconds sleeping - self.num_slept = 0 - # number of peers seen during this epoch - self.num_peers = 0 - # cumulative bond factor - self.tot_bond_factor = 0.0 # cum_bond_factor sounded worse ... - # average bond factor - self.avg_bond_factor = 0.0 - # any activity at all during this epoch? - self.any_activity = False - # when the current epoch started - self.epoch_started = time.time() - # last epoch duration - self.epoch_duration = 0 - # https://www.metageek.com/training/resources/why-channels-1-6-11.html - self.non_overlapping_channels = {1: 0, 6: 0, 11: 0} - # observation vectors - self._observation = { - 'aps_histogram': [0.0] * wifi.NumChannels, - 'sta_histogram': [0.0] * wifi.NumChannels, - 'peers_histogram': [0.0] * wifi.NumChannels - } - self._observation_ready = threading.Event() - self._epoch_data = {} - self._epoch_data_ready = threading.Event() - self._reward = RewardFunction() - - def wait_for_epoch_data(self, with_observation=True, timeout=None): - # if with_observation: - # self._observation_ready.wait(timeout) - # self._observation_ready.clear() - self._epoch_data_ready.wait(timeout) - self._epoch_data_ready.clear() - return self._epoch_data if with_observation is False else {**self._observation, **self._epoch_data} - - def data(self): - return self._epoch_data - - def observe(self, aps, peers): - num_aps = len(aps) - if num_aps == 0: - self.blind_for += 1 - else: - self.blind_for = 0 - - bond_unit_scale = self.config['personality']['bond_encounters_factor'] - - self.num_peers = len(peers) - num_peers = self.num_peers + 1e-10 # avoid division by 0 - - self.tot_bond_factor = sum((peer.encounters for peer in peers)) / bond_unit_scale - self.avg_bond_factor = self.tot_bond_factor / num_peers - - num_aps = len(aps) + 1e-10 - num_sta = sum(len(ap['clients']) for ap in aps) + 1e-10 - aps_per_chan = [0.0] * wifi.NumChannels - sta_per_chan = [0.0] * wifi.NumChannels - peers_per_chan = [0.0] * wifi.NumChannels - - for ap in aps: - ch_idx = ap['channel'] - 1 - try: - aps_per_chan[ch_idx] += 1.0 - sta_per_chan[ch_idx] += len(ap['clients']) - except IndexError: - logging.error("got data on channel %d, we can store %d channels" % (ap['channel'], wifi.NumChannels)) - - for peer in peers: - try: - peers_per_chan[peer.last_channel - 1] += 1.0 - except IndexError: - logging.error( - "got peer data on channel %d, we can store %d channels" % (peer.last_channel, wifi.NumChannels)) - - # normalize - aps_per_chan = [e / num_aps for e in aps_per_chan] - sta_per_chan = [e / num_sta for e in sta_per_chan] - peers_per_chan = [e / num_peers for e in peers_per_chan] - - self._observation = { - 'aps_histogram': aps_per_chan, - 'sta_histogram': sta_per_chan, - 'peers_histogram': peers_per_chan - } - self._observation_ready.set() - - def track(self, deauth=False, assoc=False, handshake=False, hop=False, sleep=False, miss=False, inc=1): - if deauth: - self.num_deauths += inc - self.did_deauth = True - self.any_activity = True - - if assoc: - self.num_assocs += inc - self.did_associate = True - self.any_activity = True - - if miss: - self.num_missed += inc - - if hop: - self.num_hops += inc - # these two are used in order to determine the sleep time in seconds - # before switching to a new channel ... if nothing happened so far - # during this epoch on the current channel, we will sleep less - self.did_deauth = False - self.did_associate = False - - if handshake: - self.num_shakes += inc - self.did_handshakes = True - - if sleep: - self.num_slept += inc - - def next(self): - if self.any_activity is False and self.did_handshakes is False: - self.inactive_for += 1 - self.active_for = 0 - else: - self.active_for += 1 - self.inactive_for = 0 - self.sad_for = 0 - self.bored_for = 0 - - if self.inactive_for >= self.config['personality']['sad_num_epochs']: - # sad > bored; cant be sad and bored - self.bored_for = 0 - self.sad_for += 1 - elif self.inactive_for >= self.config['personality']['bored_num_epochs']: - # sad_treshhold > inactive > bored_treshhold; cant be sad and bored - self.sad_for = 0 - self.bored_for += 1 - else: - self.sad_for = 0 - self.bored_for = 0 - - now = time.time() - cpu = pwnagotchi.cpu_load("epoch") - mem = pwnagotchi.mem_usage() - temp = pwnagotchi.temperature() - - self.epoch_duration = now - self.epoch_started - - # cache the state of this epoch for other threads to read - self._epoch_data = { - 'duration_secs': self.epoch_duration, - 'slept_for_secs': self.num_slept, - 'blind_for_epochs': self.blind_for, - 'inactive_for_epochs': self.inactive_for, - 'active_for_epochs': self.active_for, - 'sad_for_epochs': self.sad_for, - 'bored_for_epochs': self.bored_for, - 'missed_interactions': self.num_missed, - 'num_hops': self.num_hops, - 'num_peers': self.num_peers, - 'tot_bond': self.tot_bond_factor, - 'avg_bond': self.avg_bond_factor, - 'num_deauths': self.num_deauths, - 'num_associations': self.num_assocs, - 'num_handshakes': self.num_shakes, - 'cpu_load': cpu, - 'mem_usage': mem, - 'temperature': temp - } - - self._epoch_data['reward'] = self._reward(self.epoch + 1, self._epoch_data) - self._epoch_data_ready.set() - - logging.info("[epoch %d] duration=%s slept_for=%s blind=%d sad=%d bored=%d inactive=%d active=%d peers=%d tot_bond=%.2f " - "avg_bond=%.2f hops=%d missed=%d deauths=%d assocs=%d handshakes=%d cpu=%d%% mem=%d%% " - "temperature=%dC reward=%s" % ( - self.epoch, - utils.secs_to_hhmmss(self.epoch_duration), - utils.secs_to_hhmmss(self.num_slept), - self.blind_for, - self.sad_for, - self.bored_for, - self.inactive_for, - self.active_for, - self.num_peers, - self.tot_bond_factor, - self.avg_bond_factor, - self.num_hops, - self.num_missed, - self.num_deauths, - self.num_assocs, - self.num_shakes, - cpu * 100, - mem * 100, - temp, - self._epoch_data['reward'])) - - self.epoch += 1 - self.epoch_started = now - self.did_deauth = False - self.num_deauths = 0 - self.num_peers = 0 - self.tot_bond_factor = 0.0 - self.avg_bond_factor = 0.0 - self.did_associate = False - self.num_assocs = 0 - self.num_missed = 0 - self.did_handshakes = False - self.num_shakes = 0 - self.num_hops = 0 - self.num_slept = 0 - self.any_activity = False diff --git a/pwnagotchi/ai/featurizer.py b/pwnagotchi/ai/featurizer.py deleted file mode 100644 index d8f27a74..00000000 --- a/pwnagotchi/ai/featurizer.py +++ /dev/null @@ -1,66 +0,0 @@ -import numpy as np - -import pwnagotchi.mesh.wifi as wifi - -MAX_EPOCH_DURATION = 1024 - - -def describe(extended=False): - if not extended: - histogram_size = wifi.NumChannels - else: - # see https://github.com/evilsocket/pwnagotchi/issues/583 - histogram_size = wifi.NumChannelsExt - - return histogram_size, (1, - # aps per channel - histogram_size + - # clients per channel - histogram_size + - # peers per channel - histogram_size + - # duration - 1 + - # inactive - 1 + - # active - 1 + - # missed - 1 + - # hops - 1 + - # deauths - 1 + - # assocs - 1 + - # handshakes - 1) - - -def featurize(state, step): - tot_epochs = step + 1e-10 - tot_interactions = (state['num_deauths'] + state['num_associations']) + 1e-10 - return np.concatenate(( - # aps per channel - state['aps_histogram'], - # clients per channel - state['sta_histogram'], - # peers per channel - state['peers_histogram'], - # duration - [np.clip(state['duration_secs'] / MAX_EPOCH_DURATION, 0.0, 1.0)], - # inactive - [state['inactive_for_epochs'] / tot_epochs], - # active - [state['active_for_epochs'] / tot_epochs], - # missed - [state['missed_interactions'] / tot_interactions], - # hops - [state['num_hops'] / wifi.NumChannels], - # deauths - [state['num_deauths'] / tot_interactions], - # assocs - [state['num_associations'] / tot_interactions], - # handshakes - [state['num_handshakes'] / tot_interactions], - )) diff --git a/pwnagotchi/ai/gym.py b/pwnagotchi/ai/gym.py deleted file mode 100644 index c188d31d..00000000 --- a/pwnagotchi/ai/gym.py +++ /dev/null @@ -1,151 +0,0 @@ -import logging -import gym -from gym import spaces -import numpy as np - -import pwnagotchi.ai.featurizer as featurizer -import pwnagotchi.ai.reward as reward -from pwnagotchi.ai.parameter import Parameter - - -class Environment(gym.Env): - render_mode = "human" - metadata = {'render_modes': ['human']} - params = [ - Parameter('min_rssi', min_value=-200, max_value=-50), - Parameter('ap_ttl', min_value=30, max_value=600), - Parameter('sta_ttl', min_value=60, max_value=300), - - Parameter('recon_time', min_value=5, max_value=60), - Parameter('max_inactive_scale', min_value=3, max_value=10), - Parameter('recon_inactive_multiplier', min_value=1, max_value=3), - Parameter('hop_recon_time', min_value=5, max_value=60), - Parameter('min_recon_time', min_value=1, max_value=30), - Parameter('max_interactions', min_value=1, max_value=25), - Parameter('max_misses_for_recon', min_value=3, max_value=10), - Parameter('excited_num_epochs', min_value=5, max_value=30), - Parameter('bored_num_epochs', min_value=5, max_value=30), - Parameter('sad_num_epochs', min_value=5, max_value=30), - ] - - def __init__(self, agent, epoch): - super(Environment, self).__init__() - self._agent = agent - self._epoch = epoch - self._epoch_num = 0 - self._last_render = None - - # see https://github.com/evilsocket/pwnagotchi/issues/583 - self._supported_channels = agent.supported_channels() - self._extended_spectrum = any(ch > 150 for ch in self._supported_channels) - self._histogram_size, self._observation_shape = featurizer.describe(self._extended_spectrum) - - Environment.params += [ - Parameter('_channel_%d' % ch, min_value=0, max_value=1, meta=ch + 1) for ch in - range(self._histogram_size) if ch + 1 in self._supported_channels - ] - - self.last = { - 'reward': 0.0, - 'observation': None, - 'policy': None, - 'params': {}, - 'state': None, - 'state_v': None - } - - self.action_space = spaces.MultiDiscrete([p.space_size() for p in Environment.params if p.trainable]) - self.observation_space = spaces.Box(low=0, high=1, shape=self._observation_shape, dtype=np.float32) - self.reward_range = reward.range - - @staticmethod - def policy_size(): - return len(list(p for p in Environment.params if p.trainable)) - - @staticmethod - def policy_to_params(policy): - num = len(policy) - params = {} - - assert len(Environment.params) == num - - channels = [] - - for i in range(num): - param = Environment.params[i] - - if '_channel' not in param.name: - params[param.name] = param.to_param_value(policy[i]) - else: - has_chan = param.to_param_value(policy[i]) - # print("%s policy:%s bool:%s" % (param.name, policy[i], has_chan)) - chan = param.meta - if has_chan: - channels.append(chan) - - params['channels'] = channels - - return params - - def _next_epoch(self): - logging.debug("[ai] waiting for epoch to finish ...") - return self._epoch.wait_for_epoch_data() - - def _apply_policy(self, policy): - new_params = Environment.policy_to_params(policy) - self.last['policy'] = policy - self.last['params'] = new_params - self._agent.on_ai_policy(new_params) - - def step(self, policy): - # create the parameters from the policy and update - # them in the algorithm - self._apply_policy(policy) - self._epoch_num += 1 - - # wait for the algorithm to run with the new parameters - state = self._next_epoch() - - self.last['reward'] = state['reward'] - self.last['state'] = state - self.last['state_v'] = featurizer.featurize(state, self._epoch_num) - - self._agent.on_ai_step() - - return self.last['state_v'], self.last['reward'], not self._agent.is_training(), {} - - def reset(self): - # logging.info("[ai] resetting environment ...") - self._epoch_num = 0 - state = self._next_epoch() - self.last['state'] = state - self.last['state_v'] = featurizer.featurize(state, 1) - return self.last['state_v'] - - def _render_histogram(self, hist): - for ch in range(self._histogram_size): - if hist[ch]: - logging.info(" CH %d: %s" % (ch + 1, hist[ch])) - - def render(self, mode='human', close=False, force=False): - # when using a vectorialized environment, render gets called twice - # avoid rendering the same data - if self._last_render == self._epoch_num: - return - - if not self._agent.is_training() and not force: - return - - self._last_render = self._epoch_num - - logging.info("[AI] --- training epoch %d/%d ---" % (self._epoch_num, self._agent.training_epochs())) - logging.info("[AI] REWARD: %f" % self.last['reward']) - - logging.debug( - "[AI] policy: %s" % ', '.join("%s:%s" % (name, value) for name, value in self.last['params'].items())) - - logging.info("[AI] observation:") - for name, value in self.last['state'].items(): - if 'histogram' in name: - logging.info(" %s" % name.replace('_histogram', '')) - self._render_histogram(value) diff --git a/pwnagotchi/ai/parameter.py b/pwnagotchi/ai/parameter.py deleted file mode 100644 index 414129b7..00000000 --- a/pwnagotchi/ai/parameter.py +++ /dev/null @@ -1,30 +0,0 @@ -from gym import spaces - - -class Parameter(object): - def __init__(self, name, value=0.0, min_value=0, max_value=2, meta=None, trainable=True): - self.name = name - self.trainable = trainable - self.meta = meta - self.value = value - self.min_value = min_value - self.max_value = max_value + 1 - - # gymnasium.space.Discrete is within [0, 1, 2, ..., n-1] - if self.min_value < 0: - self.scale_factor = abs(self.min_value) - elif self.min_value > 0: - self.scale_factor = -self.min_value - else: - self.scale_factor = 0 - - def space_size(self): - return self.max_value + self.scale_factor - - def space(self): - return spaces.Discrete(self.max_value + self.scale_factor) - - def to_param_value(self, policy_v): - self.value = policy_v - self.scale_factor - assert self.min_value <= self.value <= self.max_value - return int(self.value) diff --git a/pwnagotchi/ai/reward.py b/pwnagotchi/ai/reward.py deleted file mode 100644 index daaf75f6..00000000 --- a/pwnagotchi/ai/reward.py +++ /dev/null @@ -1,27 +0,0 @@ -import pwnagotchi.mesh.wifi as wifi - -range = (-.7, 1.02) -fuck_zero = 1e-20 - - -class RewardFunction(object): - def __call__(self, epoch_n, state): - tot_epochs = epoch_n + fuck_zero - tot_interactions = max(state['num_deauths'] + state['num_associations'], state['num_handshakes']) + fuck_zero - tot_channels = wifi.NumChannels - - h = state['num_handshakes'] / tot_interactions - a = .2 * (state['active_for_epochs'] / tot_epochs) - c = .1 * (state['num_hops'] / tot_channels) - - b = -.3 * (state['blind_for_epochs'] / tot_epochs) - m = -.3 * (state['missed_interactions'] / tot_interactions) - i = -.2 * (state['inactive_for_epochs'] / tot_epochs) - - # include emotions if state >= 5 epochs - _sad = state['sad_for_epochs'] if state['sad_for_epochs'] >= 5 else 0 - _bored = state['bored_for_epochs'] if state['bored_for_epochs'] >= 5 else 0 - s = -.2 * (_sad / tot_epochs) - l = -.1 * (_bored / tot_epochs) - - return h + a + c + b + i + m + s + l diff --git a/pwnagotchi/ai/train.py b/pwnagotchi/ai/train.py deleted file mode 100644 index 58758d6f..00000000 --- a/pwnagotchi/ai/train.py +++ /dev/null @@ -1,197 +0,0 @@ -import _thread -import threading -import time -import random -import os -import json -import logging - -import pwnagotchi.plugins as plugins -import pwnagotchi.ai as ai - - -class Stats(object): - def __init__(self, path, events_receiver): - self._lock = threading.Lock() - self._receiver = events_receiver - - self.path = path - self.born_at = time.time() - # total epochs lived (trained + just eval) - self.epochs_lived = 0 - # total training epochs - self.epochs_trained = 0 - - self.worst_reward = 0.0 - self.best_reward = 0.0 - - self.load() - - def on_epoch(self, data, training): - best_r = False - worst_r = False - with self._lock: - reward = data['reward'] - if reward < self.worst_reward: - self.worst_reward = reward - worst_r = True - - elif reward > self.best_reward: - best_r = True - self.best_reward = reward - - self.epochs_lived += 1 - if training: - self.epochs_trained += 1 - - self.save() - - if best_r: - self._receiver.on_ai_best_reward(reward) - elif worst_r: - self._receiver.on_ai_worst_reward(reward) - - def load(self): - with self._lock: - if os.path.exists(self.path) and os.path.getsize(self.path) > 0: - logging.info("[AI] loading %s" % self.path) - with open(self.path, 'rt') as fp: - obj = json.load(fp) - - self.born_at = obj['born_at'] - self.epochs_lived, self.epochs_trained = obj['epochs_lived'], obj['epochs_trained'] - self.best_reward, self.worst_reward = obj['rewards']['best'], obj['rewards']['worst'] - - def save(self): - with self._lock: - logging.info("[AI] saving %s" % self.path) - - data = json.dumps({ - 'born_at': self.born_at, - 'epochs_lived': self.epochs_lived, - 'epochs_trained': self.epochs_trained, - 'rewards': { - 'best': self.best_reward, - 'worst': self.worst_reward - } - }) - - temp = "%s.tmp" % self.path - back = "%s.bak" % self.path - with open(temp, 'wt') as fp: - fp.write(data) - - if os.path.isfile(self.path): - os.replace(self.path, back) - os.replace(temp, self.path) - - -class AsyncTrainer(object): - def __init__(self, config): - self._config = config - self._model = None - self._is_training = False - self._training_epochs = 0 - self._nn_path = self._config['ai']['path'] - self._stats = Stats("%s.json" % os.path.splitext(self._nn_path)[0], self) - - def set_training(self, training, for_epochs=0): - self._is_training = training - self._training_epochs = for_epochs - - if training: - plugins.on('ai_training_start', self, for_epochs) - else: - plugins.on('ai_training_end', self) - - def is_training(self): - return self._is_training - - def training_epochs(self): - return self._training_epochs - - def start_ai(self): - _thread.start_new_thread(self._ai_worker, ()) - - def _save_ai(self): - logging.info("[AI] saving model to %s ..." % self._nn_path) - temp = "%s.tmp" % self._nn_path - self._model.save(temp) - os.replace(temp, self._nn_path) - - def on_ai_step(self): - self._model.env.render() - - if self._is_training: - self._save_ai() - - self._stats.on_epoch(self._epoch.data(), self._is_training) - - def on_ai_training_step(self, _locals, _globals): - self._model.env.render() - plugins.on('ai_training_step', self, _locals, _globals) - - def on_ai_policy(self, new_params): - plugins.on('ai_policy', self, new_params) - logging.info("[AI] setting new policy:") - for name, value in new_params.items(): - if name in self._config['personality']: - curr_value = self._config['personality'][name] - if curr_value != value: - logging.info("[AI] ! %s: %s -> %s" % (name, curr_value, value)) - self._config['personality'][name] = value - else: - logging.error("[AI] param %s not in personality configuration!" % name) - - self.run('set wifi.ap.ttl %d' % self._config['personality']['ap_ttl']) - self.run('set wifi.sta.ttl %d' % self._config['personality']['sta_ttl']) - self.run('set wifi.rssi.min %d' % self._config['personality']['min_rssi']) - - def on_ai_ready(self): - self._view.on_ai_ready() - plugins.on('ai_ready', self) - - def on_ai_best_reward(self, r): - logging.info("[AI] best reward so far: %s" % r) - self._view.on_motivated(r) - plugins.on('ai_best_reward', self, r) - - def on_ai_worst_reward(self, r): - logging.info("[AI] worst reward so far: %s" % r) - self._view.on_demotivated(r) - plugins.on('ai_worst_reward', self, r) - - def _ai_worker(self): - self._model = ai.load(self._config, self, self._epoch) - - if self._model: - self.on_ai_ready() - - epochs_per_episode = self._config['ai']['epochs_per_episode'] - - obs = None - while True: - self._model.env.render() - # enter in training mode? - if random.random() > self._config['ai']['laziness']: - logging.info("[AI] learning for %d epochs ..." % epochs_per_episode) - try: - self.set_training(True, epochs_per_episode) - # back up brain file before starting new training set - if os.path.isfile(self._nn_path): - back = "%s.bak" % self._nn_path - os.replace(self._nn_path, back) - self._view.set("mode", " AI") - self._model.learn(total_timesteps=epochs_per_episode, callback=self.on_ai_training_step) - except Exception as e: - logging.exception("[AI] error while training (%s)", e) - finally: - self.set_training(False) - obs = self._model.env.reset() - # init the first time - elif obs is None: - obs = self._model.env.reset() - - # run the inference - action, _ = self._model.predict(obs) - obs, _, _, _ = self._model.env.step(action) diff --git a/pwnagotchi/ai/utils.py b/pwnagotchi/ai/utils.py deleted file mode 100644 index e6284d05..00000000 --- a/pwnagotchi/ai/utils.py +++ /dev/null @@ -1,16 +0,0 @@ -import numpy as np - - -def normalize(v, min_v, max_v): - return (v - min_v) / (max_v - min_v) - - -def as_batches(x, y, batch_size, shuffle=True): - x_size = len(x) - assert x_size == len(y) - - indices = np.random.permutation(x_size) if shuffle else None - - for offset in range(0, x_size - batch_size + 1, batch_size): - excerpt = indices[offset:offset + batch_size] if shuffle else slice(offset, offset + batch_size) - yield x[excerpt], y[excerpt]