diff --git a/.gitignore b/.gitignore index 4c607cb..871ceec 100644 --- a/.gitignore +++ b/.gitignore @@ -165,3 +165,6 @@ cython_debug/ ssh_host_key *.key *.pub + +# config files +*.ini diff --git a/accounts.json b/accounts.json deleted file mode 100644 index f60de16..0000000 --- a/accounts.json +++ /dev/null @@ -1,4 +0,0 @@ -{ - "guest":"", - "user1":"secretpw" -} \ No newline at end of file diff --git a/ssh_server.py b/ssh_server.py index 637e041..9611cf4 100644 --- a/ssh_server.py +++ b/ssh_server.py @@ -25,6 +25,8 @@ from langchain_core.runnables import RunnablePassthrough from operator import itemgetter +from configparser import ConfigParser + async def handle_client(process: asyncssh.SSHServerProcess) -> None: # This is the main loop for handling SSH client connections. # Any user interaction should be done here. @@ -81,8 +83,7 @@ class MySSHServer(asyncssh.SSHServer): def connection_lost(self, exc: Optional[Exception]) -> None: if exc: - print('SSH connection error: ' + str(exc), file=sys.stderr) - logger.error('SSH connection error: ' + str(exc), file=sys.stderr) + logger.error('SSH connection error: ' + str(exc)) else: logger.info("SSH connection closed.") @@ -100,13 +101,13 @@ class MySSHServer(asyncssh.SSHServer): async def start_server() -> None: await asyncssh.listen( - port=8022, + port=config['ssh'].getint("port", 8022), reuse_address=True, reuse_port=True, server_factory=MySSHServer, - server_host_keys=['ssh_host_key'], + server_host_keys=config['ssh'].get("host_priv_key", "ssh_host_key"), process_factory=handle_client, - server_version="SSH-2.0-OpenSSH_8.2p1 Ubuntu-4ubuntu0.3" + server_version=config['ssh'].get("server_version_string", "SSH-2.0-OpenSSH_8.2p1 Ubuntu-4ubuntu0.3") ) class ContextFilter(logging.Filter): @@ -127,29 +128,39 @@ class ContextFilter(logging.Filter): return True -def read_accounts() -> dict: - accounts = dict() - - with open('accounts.json', 'r') as f: - accounts = json.loads(f.read()) - - return accounts - def llm_get_session_history(session_id: str) -> BaseChatMessageHistory: if session_id not in llm_sessions: llm_sessions[session_id] = InMemoryChatMessageHistory() return llm_sessions[session_id] +def get_user_accounts() -> dict: + if (not 'user_accounts' in config) or (len(config.items('user_accounts')) == 0): + raise ValueError("No user accounts found in configuration file.") + + accounts = dict() + + for k, v in config.items('user_accounts'): + accounts[k] = v + + return accounts + #### MAIN #### # Always use UTC for logging logging.Formatter.formatTime = (lambda self, record, datefmt=None: datetime.datetime.fromtimestamp(record.created, datetime.timezone.utc).astimezone().isoformat(sep="T",timespec="milliseconds")) +# Read our configuration file +config = ConfigParser() +config.read("config.ini") + +# Read the user accounts from the configuration file +accounts = get_user_accounts() + # Set up the honeypot logger logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) -log_file_handler = logging.FileHandler("ssh_log.log") +log_file_handler = logging.FileHandler(config['honeypot'].get("log_file", "ssh_log.log")) logger.addHandler(log_file_handler) log_file_handler.setFormatter(logging.Formatter("%(asctime)s %(levelname)s:%(task_name)s %(message)s", datefmt="%Y-%m-%d %H:%M:%S. %Z")) @@ -159,15 +170,16 @@ logger.addFilter(f) # Now get access to the LLM -with open("prompt.txt", "r") as f: +prompt_file = config['llm'].get("system_prompt_file", "prompt.txt") +with open(prompt_file, "r") as f: llm_system_prompt = f.read() -llm_model = ChatOpenAI(model="gpt-4o") +llm_model = ChatOpenAI(model=config['llm'].get("model", "NONE")) llm_sessions = dict() llm_trimmer = trim_messages( - max_tokens=64000, + max_tokens=config['llm'].getint("trimmer_max_tokens", 64000), strategy="last", token_counter=llm_model, include_system=True, @@ -197,9 +209,6 @@ with_message_history = RunnableWithMessageHistory( input_messages_key="messages" ) -# Read the valid accounts -accounts = read_accounts() - # Kick off the server! loop = asyncio.new_event_loop() asyncio.set_event_loop(loop)