From 1ee940c7985ea83e52fa506c0cec01db721c7a61 Mon Sep 17 00:00:00 2001 From: XxKingsxX-Pinu <58925163+rai68@users.noreply.github.com> Date: Tue, 9 Jul 2024 07:35:53 +1000 Subject: [PATCH] adds daemonise and plugin as threads --- pwnagotchi/agent.py | 4 +- pwnagotchi/ai/train.py | 2 +- pwnagotchi/fs/__init__.py | 2 +- pwnagotchi/mesh/utils.py | 2 +- pwnagotchi/plugins/__init__.py | 180 +++++++++++++++++++++++++-------- pwnagotchi/ui/web/server.py | 2 +- 6 files changed, 144 insertions(+), 48 deletions(-) diff --git a/pwnagotchi/agent.py b/pwnagotchi/agent.py index 07f3665e..29bc5db8 100644 --- a/pwnagotchi/agent.py +++ b/pwnagotchi/agent.py @@ -306,7 +306,7 @@ class Agent(Client, Automata, AsyncAdvertiser, AsyncTrainer): def start_session_fetcher(self): #_thread.start_new_thread(self._fetch_stats, ()) - threading.Thread(target=self._fetch_stats, args=(), name="Session Fetcher").start() + threading.Thread(target=self._fetch_stats, args=(), name="Session Fetcher", daemon=True).start() def _fetch_stats(self): while True: @@ -390,7 +390,7 @@ class Agent(Client, Automata, AsyncAdvertiser, AsyncTrainer): def start_event_polling(self): # start a thread and pass in the mainloop #_thread.start_new_thread(self._event_poller, (asyncio.get_event_loop(),)) - threading.Thread(target=self._event_poller, args=(asyncio.get_event_loop(),), name="Event Polling") + threading.Thread(target=self._event_poller, args=(asyncio.get_event_loop(),), name="Event Polling", daemon=True) def is_module_running(self, module): s = self.session() diff --git a/pwnagotchi/ai/train.py b/pwnagotchi/ai/train.py index 035ccfd0..a0d7056b 100644 --- a/pwnagotchi/ai/train.py +++ b/pwnagotchi/ai/train.py @@ -112,7 +112,7 @@ class AsyncTrainer(object): def start_ai(self): #_thread.start_new_thread(self._ai_worker, ()) - threading.Thread(target=self._ai_worker, args=(), name="AI Worker").start() + threading.Thread(target=self._ai_worker, args=(), name="AI Worker" daemon=True).start() def _save_ai(self): logging.info("[AI] saving model to %s ..." % self._nn_path) diff --git a/pwnagotchi/fs/__init__.py b/pwnagotchi/fs/__init__.py index f0df1f6a..5205ff92 100644 --- a/pwnagotchi/fs/__init__.py +++ b/pwnagotchi/fs/__init__.py @@ -86,7 +86,7 @@ def setup_mounts(config): if interval: logging.debug("[FS] Starting thread to sync %s (interval: %d)", options['mount'], interval) - threading.Thread(target=m.daemonize, args=(interval,),name="File Sys").start() + threading.Thread(target=m.daemonize, args=(interval,),name="File Sys", daemon=True).start() #_thread.start_new_thread(m.daemonize, (interval,)) else: logging.debug("[FS] Not syncing %s, because interval is 0", diff --git a/pwnagotchi/mesh/utils.py b/pwnagotchi/mesh/utils.py index 27d7dee4..7f90f2c4 100644 --- a/pwnagotchi/mesh/utils.py +++ b/pwnagotchi/mesh/utils.py @@ -43,7 +43,7 @@ class AsyncAdvertiser(object): def start_advertising(self): if self._config['personality']['advertise']: #_thread.start_new_thread(self._adv_poller, ()) - threading.Thread(target=self._adv_poller,args=(), name="Grid").start() + threading.Thread(target=self._adv_poller,args=(), name="Grid", daemon=True).start() grid.set_advertisement_data(self._advertisement) grid.advertise(True) diff --git a/pwnagotchi/plugins/__init__.py b/pwnagotchi/plugins/__init__.py index d0212db7..a4e3039e 100644 --- a/pwnagotchi/plugins/__init__.py +++ b/pwnagotchi/plugins/__init__.py @@ -1,19 +1,107 @@ +import os +import queue +import glob import _thread import threading -import glob -import importlib -import importlib.util +import importlib, importlib.util import logging -import os -import threading -import pwnagotchi.grid +import time +import prctl + + +#Idea and base code from NurseJackass default_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "default") loaded = {} database = {} locks = {} +exitFlag = 0 +plugin_event_queues = {} +plugin_thread_workers = {} +class PluginHandler(): + def __init__(self, plugin_name): + try: + self._worker_thread = threading.Thread(target=self.doWork, daemon=True, name = "%s.sleeping" % plugin_name) + self._loop_thread = None + self.plugin_name = plugin_name + self.work_queue = queue.Queue() + self.queue_lock = threading.Lock() + self.load_handler = None + self.keep_going = True + logging.debug("Starting worker for %s" % plugin_name) + self._worker_thread.start() + except Exception as e: + logging.exception(e) + + def __del__(self): + self.keep_going = False + self._worker_thread.join() + if self.load_handler: + self.load_handler.join() + + def AddWork(self, event_name, *args, **kwargs): + if event_name == "loaded" or event_name == "loop": + # spawn separate thread, because many plugins use on_load as a "main" loop + # this way on_load can continue if it needs, while other events get processed + # for future use, use `on_loop` + try: + self._loop_thread = threading.Thread(target=self.doLoop, args = (self, event_name, *args), daemon=True, name = "%s.loop" % (self.plugin_name)).start() + except Exception as e: + logging.exception(e) + else: + self.work_queue.put([event_name, args, kwargs]) + + def run(self): + logging.debug("Worker thread starting for %s"%(self.plugin_name)) + self._worker_thread.start() + logging.info("Worker thread exited for %s"%(self.plugin_name)) + + def process_event(self, event_name, *args, **kwargs): + cb_name = 'on_%s' % event_name + callback = getattr(loaded[self.plugin_name], cb_name, None) + if callback: + callback(*args, **kwargs) + + def doWork(self): + global exitFlag + plugin_name = self.plugin_name + work_queue = self.work_queue + while not exitFlag and self.keep_going: + try: + data = work_queue.get(timeout=2) + (event_name, args, kwargs) = data + prctl.set_name("pwnagotchi.%s.%s" % (self.plugin_name, event_name )) + self._worker_thread.name = "%s.%s" % (self.plugin_name, event_name) + logging.debug("") + self.process_event(event_name, *args, **kwargs) + except queue.Empty as e: + self._worker_thread.name = "%s.sleeping" + prctl.set_name("pwnagotchi.%s.sleeping" % (self.plugin_name)) + pass + except Exception as e: + logging.exception(repr(e)) + + def doLoop(self, loopCB, event_name, *args, **kwargs): + global exitFlag + plugin_name = self.plugin_name + prctl.set_name("pwnagotchi.%s" % self.plugin_name) + + while not exitFlag and self.keep_going: + try: + self.process_event(event_name, *args, **kwargs) + self.keep_going = False + except Exception as e: + #error in plugin loop kill plugin + self.keep_going = False + logging.exception(repr(e)) + + def killLoop(self): + self._loop_thread.stop() + + + class Plugin: @classmethod def __init_subclass__(cls, **kwargs): @@ -45,16 +133,19 @@ def toggle_plugin(name, enable=True): global loaded, database if pwnagotchi.config: - if name not in pwnagotchi.config['main']['plugins']: + if not name in pwnagotchi.config['main']['plugins']: pwnagotchi.config['main']['plugins'][name] = dict() pwnagotchi.config['main']['plugins'][name]['enabled'] = enable - save_config(pwnagotchi.config, '/etc/pwnagotchi/config.toml') if not enable and name in loaded: if getattr(loaded[name], 'on_unload', None): loaded[name].on_unload(view.ROOT) del loaded[name] - + if name in plugin_event_queues: + plugin_event_queues[name].keep_going = False + del plugin_event_queues[name] + if pwnagotchi.config: + save_config(pwnagotchi.config, '/etc/pwnagotchi/config.toml') return True if enable and name in database and name not in loaded: @@ -62,47 +153,44 @@ def toggle_plugin(name, enable=True): if name in loaded and pwnagotchi.config and name in pwnagotchi.config['main']['plugins']: loaded[name].options = pwnagotchi.config['main']['plugins'][name] one(name, 'loaded') + time.sleep(3) if pwnagotchi.config: one(name, 'config_changed', pwnagotchi.config) one(name, 'ui_setup', view.ROOT) one(name, 'ready', view.ROOT._agent) + if pwnagotchi.config: + save_config(pwnagotchi.config, '/etc/pwnagotchi/config.toml') return True return False def on(event_name, *args, **kwargs): + global loaded, plugin_event_queues + cb_name = 'on_%s' % event_name for plugin_name in loaded.keys(): - one(plugin_name, event_name, *args, **kwargs) + plugin = loaded[plugin_name] + callback = getattr(plugin, cb_name, None) + if callback is None or not callable(callback): + continue -def locked_cb(lock_name, cb, *args, **kwargs): - global locks - - if lock_name not in locks: - locks[lock_name] = threading.Lock() - - with locks[lock_name]: - cb(*args, *kwargs) + if plugin_name not in plugin_event_queues: + plugin_event_queues[plugin_name] = PluginHandler(plugin_name) + plugin_event_queues[plugin_name].AddWork(event_name, *args, **kwargs) def one(plugin_name, event_name, *args, **kwargs): - global loaded - + global loaded, plugin_event_queues if plugin_name in loaded: plugin = loaded[plugin_name] cb_name = 'on_%s' % event_name callback = getattr(plugin, cb_name, None) if callback is not None and callable(callback): - try: - lock_name = "%s::%s" % (plugin_name, cb_name) - loggingFormat = "%s.%s" % (plugin_name, cb_name) - locked_cb_args = (lock_name, callback, *args, *kwargs) - #_thread.start_new_thread(locked_cb, locked_cb_args) - threading.Thread(target=locked_cb, args=locked_cb_args, name=loggingFormat).start() - except Exception as e: - logging.error("error while running %s.%s : %s" % (plugin_name, cb_name, e)) - logging.error(e, exc_info=True) + if plugin_name not in plugin_event_queues: + plugin_event_queues[plugin_name] = PluginHandler(plugin_name) + + plugin_event_queues[plugin_name].AddWork(event_name, *args, **kwargs) def load_from_file(filename): @@ -111,6 +199,8 @@ def load_from_file(filename): spec = importlib.util.spec_from_file_location(plugin_name, filename) instance = importlib.util.module_from_spec(spec) spec.loader.exec_module(instance) + if plugin_name not in plugin_event_queues: + plugin_event_queues[plugin_name] = PluginHandler(plugin_name) return plugin_name, instance @@ -131,20 +221,26 @@ def load_from_path(path, enabled=()): def load(config): - enabled = [name for name, options in config['main']['plugins'].items() if - 'enabled' in options and options['enabled']] + try: + enabled = [name for name, options in config['main']['plugins'].items() if + 'enabled' in options and options['enabled']] - # load default plugins - load_from_path(default_path, enabled=enabled) + # load default plugins + load_from_path(default_path, enabled=enabled) - # load custom ones - custom_path = config['main']['custom_plugins'] if 'custom_plugins' in config['main'] else None - if custom_path is not None: - load_from_path(custom_path, enabled=enabled) + # load custom ones + custom_path = config['main']['custom_plugins'] if 'custom_plugins' in config['main'] else None + if custom_path is not None: + load_from_path(custom_path, enabled=enabled) - # propagate options - for name, plugin in loaded.items(): - plugin.options = config['main']['plugins'][name] + # propagate options + for name, plugin in loaded.items(): + if name in config['main']['plugins']: + plugin.options = config['main']['plugins'][name] + else: + plugin.options = {} - on('loaded') - on('config_changed', config) + on('loaded') + on('config_changed', config) + except Exception as e: + logging.exception(repr(e)) \ No newline at end of file diff --git a/pwnagotchi/ui/web/server.py b/pwnagotchi/ui/web/server.py index 8f4c9a2e..32aeef8b 100644 --- a/pwnagotchi/ui/web/server.py +++ b/pwnagotchi/ui/web/server.py @@ -28,7 +28,7 @@ class Server: if self._enabled: #_thread.start_new_thread(self._http_serve, ()) logging.info("Starting WebServer thread") - self._thread = threading.Thread(target=self._http_serve, name="WebServer").start() + self._thread = threading.Thread(target=self._http_serve, name="WebServer", daemon = True).start() def _http_serve(self): if self._address is not None: