diff --git a/pwnagotchi/ai/reward.py b/pwnagotchi/ai/reward.py index daaf75f6..0767a46b 100644 --- a/pwnagotchi/ai/reward.py +++ b/pwnagotchi/ai/reward.py @@ -1,27 +1,28 @@ import pwnagotchi.mesh.wifi as wifi -range = (-.7, 1.02) -fuck_zero = 1e-20 +range: tuple[float, float] = (-.7, 1.02) +fuck_zero: float = 1e-20 class RewardFunction(object): - def __call__(self, epoch_n, state): - tot_epochs = epoch_n + fuck_zero - tot_interactions = max(state['num_deauths'] + state['num_associations'], state['num_handshakes']) + fuck_zero - tot_channels = wifi.NumChannels + def __call__(self, epoch_n: float, state: dict[str, float]) -> float: - h = state['num_handshakes'] / tot_interactions - a = .2 * (state['active_for_epochs'] / tot_epochs) - c = .1 * (state['num_hops'] / tot_channels) + tot_epochs: float = epoch_n + fuck_zero + tot_interactions: float = max(state['num_deauths'] + state['num_associations'], state['num_handshakes']) + fuck_zero + tot_channels: int = wifi.NumChannels - b = -.3 * (state['blind_for_epochs'] / tot_epochs) - m = -.3 * (state['missed_interactions'] / tot_interactions) - i = -.2 * (state['inactive_for_epochs'] / tot_epochs) + h: float = state['num_handshakes'] / tot_interactions + a: float = .2 * (state['active_for_epochs'] / tot_epochs) + c: float = .1 * (state['num_hops'] / tot_channels) + + b: float = -.3 * (state['blind_for_epochs'] / tot_epochs) + m: float = -.3 * (state['missed_interactions'] / tot_interactions) + i: float = -.2 * (state['inactive_for_epochs'] / tot_epochs) # include emotions if state >= 5 epochs - _sad = state['sad_for_epochs'] if state['sad_for_epochs'] >= 5 else 0 - _bored = state['bored_for_epochs'] if state['bored_for_epochs'] >= 5 else 0 - s = -.2 * (_sad / tot_epochs) - l = -.1 * (_bored / tot_epochs) + _sad: float = state['sad_for_epochs'] if state['sad_for_epochs'] >= 5 else 0 + _bored: float = state['bored_for_epochs'] if state['bored_for_epochs'] >= 5 else 0 + s: float = -.2 * (_sad / tot_epochs) + l: float = -.1 * (_bored / tot_epochs) return h + a + c + b + i + m + s + l diff --git a/pwnagotchi/mesh/wifi.py b/pwnagotchi/mesh/wifi.py index be5d329c..f5b23008 100644 --- a/pwnagotchi/mesh/wifi.py +++ b/pwnagotchi/mesh/wifi.py @@ -1,6 +1,7 @@ -NumChannels = 233 -def freq_to_channel(freq): +NumChannels: int = 233 + +def freq_to_channel(freq: int) -> int: if 2412 <= freq <= 2472: # 2.4ghz wifi return int(((freq - 2412) / 5) + 1) elif freq == 2484: # channel 14 special