diff --git a/pwnagotchi/agent.py b/pwnagotchi/agent.py
index becea5f5..e5715f4d 100644
--- a/pwnagotchi/agent.py
+++ b/pwnagotchi/agent.py
@@ -14,6 +14,7 @@ from pwnagotchi.automata import Automata
 from pwnagotchi.log import LastSession
 from pwnagotchi.bettercap import Client
 from pwnagotchi.mesh.utils import AsyncAdvertiser
+# from pwnagotchi.ai.train import AsyncTrainer
 
 RECOVERY_DATA_FILE = '/root/.pwnagotchi-recovery'
 
@@ -27,6 +28,7 @@ class Agent(Client, Automata, AsyncAdvertiser):
                          config['bettercap']['password'])
         Automata.__init__(self, config, view)
         AsyncAdvertiser.__init__(self, config, view, keypair)
+        # AsyncTrainer.__init__(self, config)
 
         self._started_at = time.time()
         self._filter = None if not config['main']['filter'] else re.compile(config['main']['filter'])
@@ -127,6 +129,7 @@ class Agent(Client, Automata, AsyncAdvertiser):
             time.sleep(1)
 
     def start(self):
+        # self.start_ai()
         self._wait_bettercap()
         self.setup_events()
         self.set_starting()
diff --git a/pwnagotchi/ai/__init__.py b/pwnagotchi/ai/__init__.py
new file mode 100644
index 00000000..ba458651
--- /dev/null
+++ b/pwnagotchi/ai/__init__.py
@@ -0,0 +1,67 @@
+import os
+import time
+import logging
+
+# https://stackoverflow.com/questions/40426502/is-there-a-way-to-suppress-the-messages-tensorflow-prints/40426709
+# os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'  # or any {'0', '1', '2'}
+
+
+def load(config, agent, epoch, from_disk=True):
+    config = config['ai']
+    if not config['enabled']:
+        logging.info("ai disabled")
+        return False
+
+    try:
+        begin = time.time()
+
+        logging.info("[AI] bootstrapping dependencies ...")
+
+        start = time.time()
+        SB_BACKEND = "stable_baselines3"
+
+        from stable_baselines3 import A2C
+        logging.debug("[AI] A2C imported in %.2fs" % (time.time() - start))
+
+        start = time.time()
+        from stable_baselines3.a2c import MlpPolicy
+        logging.debug("[AI] MlpPolicy imported in %.2fs" % (time.time() - start))
+        SB_A2C_POLICY = MlpPolicy
+
+        start = time.time()
+        from stable_baselines3.common.vec_env import DummyVecEnv
+        logging.debug("[AI] DummyVecEnv imported in %.2fs" % (time.time() - start))
+
+        start = time.time()
+        import pwnagotchi.ai.gym as wrappers
+        logging.debug("[AI] gym wrapper imported in %.2fs" % (time.time() - start))
+
+        env = wrappers.Environment(agent, epoch)
+        env = DummyVecEnv([lambda: env])
+
+        logging.info("[AI] creating model ...")
+
+        start = time.time()
+        a2c = A2C(SB_A2C_POLICY, env, **config['params'])
+        logging.debug("[AI] A2C created in %.2fs" % (time.time() - start))
+
+        if from_disk and os.path.exists(config['path']):
+            logging.info("[AI] loading %s ..." % config['path'])
+            start = time.time()
+            a2c = A2C.load(config['path'], env)
+            logging.debug("[AI] A2C loaded in %.2fs" % (time.time() - start))
+        else:
+            logging.info("[AI] model created:")
+            for key, value in config['params'].items():
+                logging.info("      %s: %s" % (key, value))
+
+        logging.debug("[AI] total loading time is %.2fs" % (time.time() - begin))
+
+        return a2c
+    except Exception as e:
+        logging.exception("[AI] error while starting AI (%s)", e)
+        logging.info("[AI] Deleting brain and restarting.")
+        os.system("rm /root/brain.nn && service pwnagotchi restart")
+
+    logging.warning("[AI] AI not loaded!")
+    return False
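Note: load() only reads the 'ai' section of the main configuration; the keys it relies on here are 'enabled', 'path' and 'params' (forwarded verbatim as stable_baselines3 A2C keyword arguments). A minimal hand-built sketch of that section, for illustration only; the concrete values are assumptions, and on a device this normally comes from the pwnagotchi configuration file rather than an inline dict:

import pwnagotchi.ai as ai

config = {
    'ai': {
        'enabled': True,
        'path': '/root/brain.nn',
        # any keyword arguments accepted by stable_baselines3.A2C
        'params': {
            'gamma': 0.99,
            'n_steps': 1,
            'learning_rate': 1e-3,
        },
    },
}
model = ai.load(config, agent, epoch)  # agent and epoch come from the running Agent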
diff --git a/pwnagotchi/ai/epoch.py b/pwnagotchi/ai/epoch.py
new file mode 100644
index 00000000..2ba47212
--- /dev/null
+++ b/pwnagotchi/ai/epoch.py
@@ -0,0 +1,249 @@
+import time
+import threading
+import logging
+
+import pwnagotchi
+import pwnagotchi.utils as utils
+import pwnagotchi.mesh.wifi as wifi
+
+from pwnagotchi.ai.reward import RewardFunction
+
+
+class Epoch(object):
+    def __init__(self, config):
+        self.epoch = 0
+        self.config = config
+        # how many consecutive epochs with no activity
+        self.inactive_for = 0
+        # how many consecutive epochs with activity
+        self.active_for = 0
+        # number of epochs with no visible access points
+        self.blind_for = 0
+        # number of epochs in sad state
+        self.sad_for = 0
+        # number of epochs in bored state
+        self.bored_for = 0
+        # did deauth in this epoch in the current channel?
+        self.did_deauth = False
+        # number of deauths in this epoch
+        self.num_deauths = 0
+        # did associate in this epoch in the current channel?
+        self.did_associate = False
+        # number of associations in this epoch
+        self.num_assocs = 0
+        # number of assocs or deauths missed
+        self.num_missed = 0
+        # did get any handshake in this epoch?
+        self.did_handshakes = False
+        # number of handshakes captured in this epoch
+        self.num_shakes = 0
+        # number of channel hops
+        self.num_hops = 0
+        # number of seconds sleeping
+        self.num_slept = 0
+        # number of peers seen during this epoch
+        self.num_peers = 0
+        # cumulative bond factor
+        self.tot_bond_factor = 0.0  # cum_bond_factor sounded worse ...
+        # average bond factor
+        self.avg_bond_factor = 0.0
+        # any activity at all during this epoch?
+        self.any_activity = False
+        # when the current epoch started
+        self.epoch_started = time.time()
+        # last epoch duration
+        self.epoch_duration = 0
+        # https://www.metageek.com/training/resources/why-channels-1-6-11.html
+        self.non_overlapping_channels = {1: 0, 6: 0, 11: 0}
+        # observation vectors
+        self._observation = {
+            'aps_histogram': [0.0] * wifi.NumChannels,
+            'sta_histogram': [0.0] * wifi.NumChannels,
+            'peers_histogram': [0.0] * wifi.NumChannels
+        }
+        self._observation_ready = threading.Event()
+        self._epoch_data = {}
+        self._epoch_data_ready = threading.Event()
+        self._reward = RewardFunction()
+
+    def wait_for_epoch_data(self, with_observation=True, timeout=None):
+        # if with_observation:
+        #     self._observation_ready.wait(timeout)
+        #     self._observation_ready.clear()
+        self._epoch_data_ready.wait(timeout)
+        self._epoch_data_ready.clear()
+        return self._epoch_data if with_observation is False else {**self._observation, **self._epoch_data}
+
+    def data(self):
+        return self._epoch_data
+
+    def observe(self, aps, peers):
+        num_aps = len(aps)
+        if num_aps == 0:
+            self.blind_for += 1
+        else:
+            self.blind_for = 0
+
+        bond_unit_scale = self.config['personality']['bond_encounters_factor']
+
+        self.num_peers = len(peers)
+        num_peers = self.num_peers + 1e-10  # avoid division by 0
+
+        self.tot_bond_factor = sum((peer.encounters for peer in peers)) / bond_unit_scale
+        self.avg_bond_factor = self.tot_bond_factor / num_peers
+
+        num_aps = len(aps) + 1e-10
+        num_sta = sum(len(ap['clients']) for ap in aps) + 1e-10
+        aps_per_chan = [0.0] * wifi.NumChannels
+        sta_per_chan = [0.0] * wifi.NumChannels
+        peers_per_chan = [0.0] * wifi.NumChannels
+
+        for ap in aps:
+            ch_idx = ap['channel'] - 1
+            try:
+                aps_per_chan[ch_idx] += 1.0
+                sta_per_chan[ch_idx] += len(ap['clients'])
+            except IndexError:
+                logging.error("got data on channel %d, we can store %d channels" % (ap['channel'], wifi.NumChannels))
+
+        for peer in peers:
+            try:
+                peers_per_chan[peer.last_channel - 1] += 1.0
+            except IndexError:
+                logging.error(
+                    "got peer data on channel %d, we can store %d channels" % (peer.last_channel, wifi.NumChannels))
+
+        # normalize
+        aps_per_chan = [e / num_aps for e in aps_per_chan]
+        sta_per_chan = [e / num_sta for e in sta_per_chan]
+        peers_per_chan = [e / num_peers for e in peers_per_chan]
+
+        self._observation = {
+            'aps_histogram': aps_per_chan,
+            'sta_histogram': sta_per_chan,
+            'peers_histogram': peers_per_chan
+        }
+        self._observation_ready.set()
+
+    def track(self, deauth=False, assoc=False, handshake=False, hop=False, sleep=False, miss=False, inc=1):
+        if deauth:
+            self.num_deauths += inc
+            self.did_deauth = True
+            self.any_activity = True
+
+        if assoc:
+            self.num_assocs += inc
+            self.did_associate = True
+            self.any_activity = True
+
+        if miss:
+            self.num_missed += inc
+
+        if hop:
+            self.num_hops += inc
+            # these two are used in order to determine the sleep time in seconds
+            # before switching to a new channel ... if nothing happened so far
+            # during this epoch on the current channel, we will sleep less
+            self.did_deauth = False
+            self.did_associate = False
+
+        if handshake:
+            self.num_shakes += inc
+            self.did_handshakes = True
+
+        if sleep:
+            self.num_slept += inc
+
+    def next(self):
+        if self.any_activity is False and self.did_handshakes is False:
+            self.inactive_for += 1
+            self.active_for = 0
+        else:
+            self.active_for += 1
+            self.inactive_for = 0
+            self.sad_for = 0
+            self.bored_for = 0
+
+        if self.inactive_for >= self.config['personality']['sad_num_epochs']:
+            # sad > bored; can't be sad and bored
+            self.bored_for = 0
+            self.sad_for += 1
+        elif self.inactive_for >= self.config['personality']['bored_num_epochs']:
+            # sad_threshold > inactive > bored_threshold; can't be sad and bored
+            self.sad_for = 0
+            self.bored_for += 1
+        else:
+            self.sad_for = 0
+            self.bored_for = 0
+
+        now = time.time()
+        cpu = pwnagotchi.cpu_load("epoch")
+        mem = pwnagotchi.mem_usage()
+        temp = pwnagotchi.temperature()
+
+        self.epoch_duration = now - self.epoch_started
+
+        # cache the state of this epoch for other threads to read
+        self._epoch_data = {
+            'duration_secs': self.epoch_duration,
+            'slept_for_secs': self.num_slept,
+            'blind_for_epochs': self.blind_for,
+            'inactive_for_epochs': self.inactive_for,
+            'active_for_epochs': self.active_for,
+            'sad_for_epochs': self.sad_for,
+            'bored_for_epochs': self.bored_for,
+            'missed_interactions': self.num_missed,
+            'num_hops': self.num_hops,
+            'num_peers': self.num_peers,
+            'tot_bond': self.tot_bond_factor,
+            'avg_bond': self.avg_bond_factor,
+            'num_deauths': self.num_deauths,
+            'num_associations': self.num_assocs,
+            'num_handshakes': self.num_shakes,
+            'cpu_load': cpu,
+            'mem_usage': mem,
+            'temperature': temp
+        }
+
+        self._epoch_data['reward'] = self._reward(self.epoch + 1, self._epoch_data)
+        self._epoch_data_ready.set()
+
+        logging.info("[epoch %d] duration=%s slept_for=%s blind=%d sad=%d bored=%d inactive=%d active=%d peers=%d tot_bond=%.2f "
+                     "avg_bond=%.2f hops=%d missed=%d deauths=%d assocs=%d handshakes=%d cpu=%d%% mem=%d%% "
+                     "temperature=%dC reward=%s" % (
+                         self.epoch,
+                         utils.secs_to_hhmmss(self.epoch_duration),
+                         utils.secs_to_hhmmss(self.num_slept),
+                         self.blind_for,
+                         self.sad_for,
+                         self.bored_for,
+                         self.inactive_for,
+                         self.active_for,
+                         self.num_peers,
+                         self.tot_bond_factor,
+                         self.avg_bond_factor,
+                         self.num_hops,
+                         self.num_missed,
+                         self.num_deauths,
+                         self.num_assocs,
+                         self.num_shakes,
+                         cpu * 100,
+                         mem * 100,
+                         temp,
+                         self._epoch_data['reward']))
+
+        self.epoch += 1
+        self.epoch_started = now
+        self.did_deauth = False
+        self.num_deauths = 0
+        self.num_peers = 0
+        self.tot_bond_factor = 0.0
+        self.avg_bond_factor = 0.0
+        self.did_associate = False
+        self.num_assocs = 0
+        self.num_missed = 0
+        self.did_handshakes = False
+        self.num_shakes = 0
+        self.num_hops = 0
+        self.num_slept = 0
+        self.any_activity = False
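Note: the Epoch object is written to by the agent's main loop and read by the AI thread, which blocks on the internal Event until next() publishes the epoch dictionary. A minimal sketch of that handoff, assuming it runs on a device where pwnagotchi.cpu_load(), mem_usage() and temperature() are available; the personality values and the call sites shown are invented for illustration and are not part of this patch:

import threading
from pwnagotchi.ai.epoch import Epoch

config = {'personality': {'bond_encounters_factor': 20000,
                          'bored_num_epochs': 5, 'sad_num_epochs': 15}}
epoch = Epoch(config)

def consumer():
    while True:
        # blocks until Epoch.next() sets the internal Event
        data = epoch.wait_for_epoch_data(timeout=120)
        print(data['reward'], data['num_handshakes'])

threading.Thread(target=consumer, daemon=True).start()

# producer side, e.g. once per recon/attack cycle in the agent:
epoch.observe(aps=[], peers=[])      # no APs visible -> blind_for += 1
epoch.track(hop=True)                # hopped to a new channel
epoch.track(assoc=True)              # sent an association frame
epoch.track(handshake=True, inc=2)   # captured two handshakes
epoch.next()                         # closes the epoch and wakes the consumer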
diff --git a/pwnagotchi/ai/featurizer.py b/pwnagotchi/ai/featurizer.py
new file mode 100644
index 00000000..d8f27a74
--- /dev/null
+++ b/pwnagotchi/ai/featurizer.py
@@ -0,0 +1,66 @@
+import numpy as np
+
+import pwnagotchi.mesh.wifi as wifi
+
+MAX_EPOCH_DURATION = 1024
+
+
+def describe(extended=False):
+    if not extended:
+        histogram_size = wifi.NumChannels
+    else:
+        # see https://github.com/evilsocket/pwnagotchi/issues/583
+        histogram_size = wifi.NumChannelsExt
+
+    return histogram_size, (1,
+                            # aps per channel
+                            histogram_size +
+                            # clients per channel
+                            histogram_size +
+                            # peers per channel
+                            histogram_size +
+                            # duration
+                            1 +
+                            # inactive
+                            1 +
+                            # active
+                            1 +
+                            # missed
+                            1 +
+                            # hops
+                            1 +
+                            # deauths
+                            1 +
+                            # assocs
+                            1 +
+                            # handshakes
+                            1)
+
+
+def featurize(state, step):
+    tot_epochs = step + 1e-10
+    tot_interactions = (state['num_deauths'] + state['num_associations']) + 1e-10
+    return np.concatenate((
+        # aps per channel
+        state['aps_histogram'],
+        # clients per channel
+        state['sta_histogram'],
+        # peers per channel
+        state['peers_histogram'],
+        # duration
+        [np.clip(state['duration_secs'] / MAX_EPOCH_DURATION, 0.0, 1.0)],
+        # inactive
+        [state['inactive_for_epochs'] / tot_epochs],
+        # active
+        [state['active_for_epochs'] / tot_epochs],
+        # missed
+        [state['missed_interactions'] / tot_interactions],
+        # hops
+        [state['num_hops'] / wifi.NumChannels],
+        # deauths
+        [state['num_deauths'] / tot_interactions],
+        # assocs
+        [state['num_associations'] / tot_interactions],
+        # handshakes
+        [state['num_handshakes'] / tot_interactions],
+    ))
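Note: describe() and featurize() have to agree on the vector layout, three channel histograms plus eight scalars. A small self-consistency check with a fabricated state whose keys mirror Epoch._epoch_data plus the histograms (the numbers are arbitrary):

from pwnagotchi.ai.featurizer import describe, featurize

hist_size, obs_shape = describe(extended=False)
state = {
    'aps_histogram': [0.0] * hist_size,
    'sta_histogram': [0.0] * hist_size,
    'peers_histogram': [0.0] * hist_size,
    'duration_secs': 300, 'inactive_for_epochs': 1, 'active_for_epochs': 4,
    'missed_interactions': 1, 'num_hops': 6, 'num_deauths': 3,
    'num_associations': 5, 'num_handshakes': 2,
}
vec = featurize(state, step=10)
assert vec.shape[0] == obs_shape[1]  # 3 * hist_size + 8 scalar features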
diff --git a/pwnagotchi/ai/gym.py b/pwnagotchi/ai/gym.py
new file mode 100644
index 00000000..c188d31d
--- /dev/null
+++ b/pwnagotchi/ai/gym.py
@@ -0,0 +1,151 @@
+import logging
+import gym
+from gym import spaces
+import numpy as np
+
+import pwnagotchi.ai.featurizer as featurizer
+import pwnagotchi.ai.reward as reward
+from pwnagotchi.ai.parameter import Parameter
+
+
+class Environment(gym.Env):
+    render_mode = "human"
+    metadata = {'render_modes': ['human']}
+    params = [
+        Parameter('min_rssi', min_value=-200, max_value=-50),
+        Parameter('ap_ttl', min_value=30, max_value=600),
+        Parameter('sta_ttl', min_value=60, max_value=300),
+
+        Parameter('recon_time', min_value=5, max_value=60),
+        Parameter('max_inactive_scale', min_value=3, max_value=10),
+        Parameter('recon_inactive_multiplier', min_value=1, max_value=3),
+        Parameter('hop_recon_time', min_value=5, max_value=60),
+        Parameter('min_recon_time', min_value=1, max_value=30),
+        Parameter('max_interactions', min_value=1, max_value=25),
+        Parameter('max_misses_for_recon', min_value=3, max_value=10),
+        Parameter('excited_num_epochs', min_value=5, max_value=30),
+        Parameter('bored_num_epochs', min_value=5, max_value=30),
+        Parameter('sad_num_epochs', min_value=5, max_value=30),
+    ]
+
+    def __init__(self, agent, epoch):
+        super(Environment, self).__init__()
+        self._agent = agent
+        self._epoch = epoch
+        self._epoch_num = 0
+        self._last_render = None
+
+        # see https://github.com/evilsocket/pwnagotchi/issues/583
+        self._supported_channels = agent.supported_channels()
+        self._extended_spectrum = any(ch > 150 for ch in self._supported_channels)
+        self._histogram_size, self._observation_shape = featurizer.describe(self._extended_spectrum)
+
+        Environment.params += [
+            Parameter('_channel_%d' % ch, min_value=0, max_value=1, meta=ch + 1) for ch in
+            range(self._histogram_size) if ch + 1 in self._supported_channels
+        ]
+
+        self.last = {
+            'reward': 0.0,
+            'observation': None,
+            'policy': None,
+            'params': {},
+            'state': None,
+            'state_v': None
+        }
+
+        self.action_space = spaces.MultiDiscrete([p.space_size() for p in Environment.params if p.trainable])
+        self.observation_space = spaces.Box(low=0, high=1, shape=self._observation_shape, dtype=np.float32)
+        self.reward_range = reward.range
+
+    @staticmethod
+    def policy_size():
+        return len(list(p for p in Environment.params if p.trainable))
+
+    @staticmethod
+    def policy_to_params(policy):
+        num = len(policy)
+        params = {}
+
+        assert len(Environment.params) == num
+
+        channels = []
+
+        for i in range(num):
+            param = Environment.params[i]
+
+            if '_channel' not in param.name:
+                params[param.name] = param.to_param_value(policy[i])
+            else:
+                has_chan = param.to_param_value(policy[i])
+                # print("%s policy:%s bool:%s" % (param.name, policy[i], has_chan))
+                chan = param.meta
+                if has_chan:
+                    channels.append(chan)
+
+        params['channels'] = channels
+
+        return params
+
+    def _next_epoch(self):
+        logging.debug("[ai] waiting for epoch to finish ...")
+        return self._epoch.wait_for_epoch_data()
+
+    def _apply_policy(self, policy):
+        new_params = Environment.policy_to_params(policy)
+        self.last['policy'] = policy
+        self.last['params'] = new_params
+        self._agent.on_ai_policy(new_params)
+
+    def step(self, policy):
+        # create the parameters from the policy and update
+        # them in the algorithm
+        self._apply_policy(policy)
+        self._epoch_num += 1
+
+        # wait for the algorithm to run with the new parameters
+        state = self._next_epoch()
+
+        self.last['reward'] = state['reward']
+        self.last['state'] = state
+        self.last['state_v'] = featurizer.featurize(state, self._epoch_num)
+
+        self._agent.on_ai_step()
+
+        return self.last['state_v'], self.last['reward'], not self._agent.is_training(), {}
+
+    def reset(self):
+        # logging.info("[ai] resetting environment ...")
+        self._epoch_num = 0
+        state = self._next_epoch()
+        self.last['state'] = state
+        self.last['state_v'] = featurizer.featurize(state, 1)
+        return self.last['state_v']
+
+    def _render_histogram(self, hist):
+        for ch in range(self._histogram_size):
+            if hist[ch]:
+                logging.info("      CH %d: %s" % (ch + 1, hist[ch]))
+
+    def render(self, mode='human', close=False, force=False):
+        # when using a vectorized environment, render gets called twice
+        # avoid rendering the same data
+        if self._last_render == self._epoch_num:
+            return
+
+        if not self._agent.is_training() and not force:
+            return
+
+        self._last_render = self._epoch_num
+
+        logging.info("[AI] --- training epoch %d/%d ---" % (self._epoch_num, self._agent.training_epochs()))
+        logging.info("[AI] REWARD: %f" % self.last['reward'])
+
+        logging.debug(
+            "[AI] policy: %s" % ', '.join("%s:%s" % (name, value) for name, value in self.last['params'].items()))
+
+        logging.info("[AI] observation:")
+        for name, value in self.last['state'].items():
+            if 'histogram' in name:
+                logging.info("    %s" % name.replace('_histogram', ''))
+                self._render_histogram(value)
diff --git a/pwnagotchi/ai/parameter.py b/pwnagotchi/ai/parameter.py
new file mode 100644
index 00000000..414129b7
--- /dev/null
+++ b/pwnagotchi/ai/parameter.py
@@ -0,0 +1,30 @@
+from gym import spaces
+
+
+class Parameter(object):
+    def __init__(self, name, value=0.0, min_value=0, max_value=2, meta=None, trainable=True):
+        self.name = name
+        self.trainable = trainable
+        self.meta = meta
+        self.value = value
+        self.min_value = min_value
+        self.max_value = max_value + 1
+
+        # gymnasium.space.Discrete is within [0, 1, 2, ..., n-1]
+        if self.min_value < 0:
+            self.scale_factor = abs(self.min_value)
+        elif self.min_value > 0:
+            self.scale_factor = -self.min_value
+        else:
+            self.scale_factor = 0
+
+    def space_size(self):
+        return self.max_value + self.scale_factor
+
+    def space(self):
+        return spaces.Discrete(self.max_value + self.scale_factor)
+
+    def to_param_value(self, policy_v):
+        self.value = policy_v - self.scale_factor
+        assert self.min_value <= self.value <= self.max_value
+        return int(self.value)
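Note: Parameter.scale_factor shifts an arbitrary [min_value, max_value] interval onto the non-negative indices produced by a Discrete space; Environment.policy_to_params() applies the same mapping element-wise to the MultiDiscrete action vector and collects the _channel_* toggles into params['channels']. A worked round trip with the min_rssi parameter defined above:

from pwnagotchi.ai.parameter import Parameter

p = Parameter('min_rssi', min_value=-200, max_value=-50)
print(p.space_size())         # 151 discrete actions
print(p.to_param_value(0))    # -200: the lowest action maps to min_value
print(p.to_param_value(150))  # -50: the highest action maps to the original max_value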
diff --git a/pwnagotchi/ai/reward.py b/pwnagotchi/ai/reward.py
new file mode 100644
index 00000000..daaf75f6
--- /dev/null
+++ b/pwnagotchi/ai/reward.py
@@ -0,0 +1,27 @@
+import pwnagotchi.mesh.wifi as wifi
+
+range = (-.7, 1.02)
+fuck_zero = 1e-20
+
+
+class RewardFunction(object):
+    def __call__(self, epoch_n, state):
+        tot_epochs = epoch_n + fuck_zero
+        tot_interactions = max(state['num_deauths'] + state['num_associations'], state['num_handshakes']) + fuck_zero
+        tot_channels = wifi.NumChannels
+
+        h = state['num_handshakes'] / tot_interactions
+        a = .2 * (state['active_for_epochs'] / tot_epochs)
+        c = .1 * (state['num_hops'] / tot_channels)
+
+        b = -.3 * (state['blind_for_epochs'] / tot_epochs)
+        m = -.3 * (state['missed_interactions'] / tot_interactions)
+        i = -.2 * (state['inactive_for_epochs'] / tot_epochs)
+
+        # include emotions if state >= 5 epochs
+        _sad = state['sad_for_epochs'] if state['sad_for_epochs'] >= 5 else 0
+        _bored = state['bored_for_epochs'] if state['bored_for_epochs'] >= 5 else 0
+        s = -.2 * (_sad / tot_epochs)
+        l = -.1 * (_bored / tot_epochs)
+
+        return h + a + c + b + i + m + s + l
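Note: the reward is a weighted sum that favors handshakes, activity and channel hopping, and penalizes blindness, misses and inactivity; the sad/bored terms only kick in once they reach five epochs. A standalone invocation with a fabricated epoch state (all values invented for illustration):

from pwnagotchi.ai.reward import RewardFunction

rf = RewardFunction()
state = {
    'num_handshakes': 2, 'num_deauths': 3, 'num_associations': 5,
    'active_for_epochs': 4, 'inactive_for_epochs': 1, 'blind_for_epochs': 0,
    'missed_interactions': 1, 'sad_for_epochs': 0, 'bored_for_epochs': 0,
    'num_hops': 6,
}
# handshakes dominate the positive terms here, so the result is well above zero
print(rf(epoch_n=10, state=state))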
diff --git a/pwnagotchi/ai/train.py b/pwnagotchi/ai/train.py
new file mode 100644
index 00000000..58758d6f
--- /dev/null
+++ b/pwnagotchi/ai/train.py
@@ -0,0 +1,197 @@
+import _thread
+import threading
+import time
+import random
+import os
+import json
+import logging
+
+import pwnagotchi.plugins as plugins
+import pwnagotchi.ai as ai
+
+
+class Stats(object):
+    def __init__(self, path, events_receiver):
+        self._lock = threading.Lock()
+        self._receiver = events_receiver
+
+        self.path = path
+        self.born_at = time.time()
+        # total epochs lived (trained + just eval)
+        self.epochs_lived = 0
+        # total training epochs
+        self.epochs_trained = 0
+
+        self.worst_reward = 0.0
+        self.best_reward = 0.0
+
+        self.load()
+
+    def on_epoch(self, data, training):
+        best_r = False
+        worst_r = False
+        with self._lock:
+            reward = data['reward']
+            if reward < self.worst_reward:
+                self.worst_reward = reward
+                worst_r = True
+
+            elif reward > self.best_reward:
+                best_r = True
+                self.best_reward = reward
+
+            self.epochs_lived += 1
+            if training:
+                self.epochs_trained += 1
+
+        self.save()
+
+        if best_r:
+            self._receiver.on_ai_best_reward(reward)
+        elif worst_r:
+            self._receiver.on_ai_worst_reward(reward)
+
+    def load(self):
+        with self._lock:
+            if os.path.exists(self.path) and os.path.getsize(self.path) > 0:
+                logging.info("[AI] loading %s" % self.path)
+                with open(self.path, 'rt') as fp:
+                    obj = json.load(fp)
+
+                self.born_at = obj['born_at']
+                self.epochs_lived, self.epochs_trained = obj['epochs_lived'], obj['epochs_trained']
+                self.best_reward, self.worst_reward = obj['rewards']['best'], obj['rewards']['worst']
+
+    def save(self):
+        with self._lock:
+            logging.info("[AI] saving %s" % self.path)
+
+            data = json.dumps({
+                'born_at': self.born_at,
+                'epochs_lived': self.epochs_lived,
+                'epochs_trained': self.epochs_trained,
+                'rewards': {
+                    'best': self.best_reward,
+                    'worst': self.worst_reward
+                }
+            })
+
+            temp = "%s.tmp" % self.path
+            back = "%s.bak" % self.path
+            with open(temp, 'wt') as fp:
+                fp.write(data)
+
+            if os.path.isfile(self.path):
+                os.replace(self.path, back)
+            os.replace(temp, self.path)
+
+
+class AsyncTrainer(object):
+    def __init__(self, config):
+        self._config = config
+        self._model = None
+        self._is_training = False
+        self._training_epochs = 0
+        self._nn_path = self._config['ai']['path']
+        self._stats = Stats("%s.json" % os.path.splitext(self._nn_path)[0], self)
+
+    def set_training(self, training, for_epochs=0):
+        self._is_training = training
+        self._training_epochs = for_epochs
+
+        if training:
+            plugins.on('ai_training_start', self, for_epochs)
+        else:
+            plugins.on('ai_training_end', self)
+
+    def is_training(self):
+        return self._is_training
+
+    def training_epochs(self):
+        return self._training_epochs
+
+    def start_ai(self):
+        _thread.start_new_thread(self._ai_worker, ())
+
+    def _save_ai(self):
+        logging.info("[AI] saving model to %s ..." % self._nn_path)
+        temp = "%s.tmp" % self._nn_path
+        self._model.save(temp)
+        os.replace(temp, self._nn_path)
+
+    def on_ai_step(self):
+        self._model.env.render()
+
+        if self._is_training:
+            self._save_ai()
+
+        self._stats.on_epoch(self._epoch.data(), self._is_training)
+
+    def on_ai_training_step(self, _locals, _globals):
+        self._model.env.render()
+        plugins.on('ai_training_step', self, _locals, _globals)
+
+    def on_ai_policy(self, new_params):
+        plugins.on('ai_policy', self, new_params)
+        logging.info("[AI] setting new policy:")
+        for name, value in new_params.items():
+            if name in self._config['personality']:
+                curr_value = self._config['personality'][name]
+                if curr_value != value:
+                    logging.info("[AI] ! %s: %s -> %s" % (name, curr_value, value))
+                    self._config['personality'][name] = value
+            else:
+                logging.error("[AI] param %s not in personality configuration!" % name)
+
+        self.run('set wifi.ap.ttl %d' % self._config['personality']['ap_ttl'])
+        self.run('set wifi.sta.ttl %d' % self._config['personality']['sta_ttl'])
+        self.run('set wifi.rssi.min %d' % self._config['personality']['min_rssi'])
+
+    def on_ai_ready(self):
+        self._view.on_ai_ready()
+        plugins.on('ai_ready', self)
+
+    def on_ai_best_reward(self, r):
+        logging.info("[AI] best reward so far: %s" % r)
+        self._view.on_motivated(r)
+        plugins.on('ai_best_reward', self, r)
+
+    def on_ai_worst_reward(self, r):
+        logging.info("[AI] worst reward so far: %s" % r)
+        self._view.on_demotivated(r)
+        plugins.on('ai_worst_reward', self, r)
+
+    def _ai_worker(self):
+        self._model = ai.load(self._config, self, self._epoch)
+
+        if self._model:
+            self.on_ai_ready()
+
+            epochs_per_episode = self._config['ai']['epochs_per_episode']
+
+            obs = None
+            while True:
+                self._model.env.render()
+                # enter training mode?
+                if random.random() > self._config['ai']['laziness']:
+                    logging.info("[AI] learning for %d epochs ..." % epochs_per_episode)
+                    try:
+                        self.set_training(True, epochs_per_episode)
+                        # back up brain file before starting new training set
+                        if os.path.isfile(self._nn_path):
+                            back = "%s.bak" % self._nn_path
+                            os.replace(self._nn_path, back)
+                        self._view.set("mode", "  AI")
+                        self._model.learn(total_timesteps=epochs_per_episode, callback=self.on_ai_training_step)
+                    except Exception as e:
+                        logging.exception("[AI] error while training (%s)", e)
+                    finally:
+                        self.set_training(False)
+                        obs = self._model.env.reset()
+                # init the first time
+                elif obs is None:
+                    obs = self._model.env.reset()
+
+                # run the inference
+                action, _ = self._model.predict(obs)
+                obs, _, _, _ = self._model.env.step(action)
diff --git a/pwnagotchi/ai/utils.py b/pwnagotchi/ai/utils.py
new file mode 100644
index 00000000..e6284d05
--- /dev/null
+++ b/pwnagotchi/ai/utils.py
@@ -0,0 +1,16 @@
+import numpy as np
+
+
+def normalize(v, min_v, max_v):
+    return (v - min_v) / (max_v - min_v)
+
+
+def as_batches(x, y, batch_size, shuffle=True):
+    x_size = len(x)
+    assert x_size == len(y)
+
+    indices = np.random.permutation(x_size) if shuffle else None
+
+    for offset in range(0, x_size - batch_size + 1, batch_size):
+        excerpt = indices[offset:offset + batch_size] if shuffle else slice(offset, offset + batch_size)
+        yield x[excerpt], y[excerpt]
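Note: as_batches() silently drops the tail samples that do not fill a whole batch, because the loop stops at x_size - batch_size + 1. A small usage sketch with arbitrary data:

import numpy as np
from pwnagotchi.ai.utils import as_batches

x = np.arange(10).reshape(10, 1)
y = np.arange(10)
for bx, by in as_batches(x, y, batch_size=4, shuffle=False):
    print(bx.ravel(), by)
# two batches are produced (rows 0-3 and 4-7); the last two samples are skipped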