From 7185c7f5c7b394fa1b9d72affb6f97e0346f9310 Mon Sep 17 00:00:00 2001 From: "David J. Bianco" Date: Fri, 10 Jan 2025 12:33:36 -0500 Subject: [PATCH] Fix session summary and ascii errors * Session summaries now occur both at normal session termination (e.g., the user gracefully logs out) or abnormal termination, such as if the client disconnects suddenly. * Now encode the AI results as UTF-8 instead of ASCII, because it would ocassionally send back non-ASCII characters which caused the server to throw errors --- SSH/ssh_server.py | 120 ++++++++++++++++++++++++++-------------------- 1 file changed, 67 insertions(+), 53 deletions(-) diff --git a/SSH/ssh_server.py b/SSH/ssh_server.py index 71ceb31..3885d22 100755 --- a/SSH/ssh_server.py +++ b/SSH/ssh_server.py @@ -21,6 +21,57 @@ from langchain_core.runnables.history import RunnableWithMessageHistory from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder from langchain_core.runnables import RunnablePassthrough +class MySSHServer(asyncssh.SSHServer): + def connection_made(self, conn: asyncssh.SSHServerConnection) -> None: + # Get the source and destination IPs and ports + (src_ip, src_port, _, _) = conn.get_extra_info('peername') + (dst_ip, dst_port, _, _) = conn.get_extra_info('sockname') + + # Store the connection details in thread-local storage + thread_local.src_ip = src_ip + thread_local.src_port = src_port + thread_local.dst_ip = dst_ip + thread_local.dst_port = dst_port + + # Log the connection details + logger.info(f"SSH connection received from {src_ip}/{src_port} to {dst_ip}/{dst_port}.") + + def connection_lost(self, exc: Optional[Exception]) -> None: + if exc: + logger.error('SSH connection error: ' + str(exc)) + else: + logger.info("SSH connection closed.") + # Ensure session summary is called on connection loss if attributes are set + if hasattr(self, '_process') and hasattr(self, '_llm_config') and hasattr(self, '_session'): + asyncio.create_task(session_summary(self._process, self._llm_config, self._session)) + + def begin_auth(self, username: str) -> bool: + if accounts.get(username) != '': + logger.info(f"AUTH: User {username} attempting to authenticate.") + return True + else: + logger.info(f"AUTH: SUCCESS for user {username} with password ''.") + return False + + def password_auth_supported(self) -> bool: + return True + def host_based_auth_supported(self) -> bool: + return False + def public_key_auth_supported(self) -> bool: + return False + def kbdinit_auth_supported(self) -> bool: + return False + + def validate_password(self, username: str, password: str) -> bool: + pw = accounts.get(username, '*') + + if ((pw != '*') and (password == pw)): + logger.info(f"AUTH: SUCCESS for user {username} with password '{password}'.") + return True + else: + logger.info(f"AUTH: FAILED for user {username} with password '{password}'.") + return False + async def session_summary(process: asyncssh.SSHServerProcess, llm_config: dict, session: RunnableWithMessageHistory): # When the SSH session ends, ask the LLM to give a nice # summary of the attacker's actions and probable intent, @@ -65,7 +116,7 @@ representative examples. logger.info(f"---SESSION SUMMARY---\n{llm_response.content}\n") process.exit(0) -async def handle_client(process: asyncssh.SSHServerProcess) -> None: +async def handle_client(process: asyncssh.SSHServerProcess, server: MySSHServer) -> None: # This is the main loop for handling SSH client connections. # Any user interaction should be done here. @@ -85,7 +136,12 @@ async def handle_client(process: asyncssh.SSHServerProcess) -> None: ) process.stdout.write(f"{llm_response.content}") - logger.info(f"OUTPUT: {b64encode(llm_response.content.encode('ascii')).decode('ascii')}") + logger.info(f"OUTPUT: {b64encode(llm_response.content.encode('utf-8')).decode('utf-8')}") + + # Store process, llm_config, and session in the MySSHServer instance + server._process = process + server._llm_config = llm_config + server._session = with_message_history try: async for line in process.stdin: @@ -102,73 +158,31 @@ async def handle_client(process: asyncssh.SSHServerProcess) -> None: ) if llm_response.content == "XXX-END-OF-SESSION-XXX": await session_summary(process, llm_config, with_message_history) + return else: process.stdout.write(f"{llm_response.content}") - logger.info(f"OUTPUT: {b64encode(llm_response.content.encode('ascii')).decode('ascii')}") + logger.info(f"OUTPUT: {b64encode(llm_response.content.encode('utf-8')).decode('utf-8')}") except asyncssh.BreakReceived: pass + finally: + await session_summary(process, llm_config, with_message_history) # Just in case we ever get here, which we probably shouldn't # process.exit(0) -class MySSHServer(asyncssh.SSHServer): - def connection_made(self, conn: asyncssh.SSHServerConnection) -> None: - # Get the source and destination IPs and ports - (src_ip, src_port, _, _) = conn.get_extra_info('peername') - (dst_ip, dst_port, _, _) = conn.get_extra_info('sockname') - - # Store the connection details in thread-local storage - thread_local.src_ip = src_ip - thread_local.src_port = src_port - thread_local.dst_ip = dst_ip - thread_local.dst_port = dst_port - - # Log the connection details - logger.info(f"SSH connection received from {src_ip}/{src_port} to {dst_ip}/{dst_port}.") - - def connection_lost(self, exc: Optional[Exception]) -> None: - if exc: - logger.error('SSH connection error: ' + str(exc)) - - else: - logger.info("SSH connection closed.") - - def begin_auth(self, username: str) -> bool: - if accounts.get(username) != '': - logger.info(f"AUTH: User {username} attempting to authenticate.") - return True - else: - logger.info(f"AUTH: SUCCESS for user {username} with password ''.") - return False - - def password_auth_supported(self) -> bool: - return True - def host_based_auth_supported(self) -> bool: - return False - def public_key_auth_supported(self) -> bool: - return False - def kbdinit_auth_supported(self) -> bool: - return False - - def validate_password(self, username: str, password: str) -> bool: - pw = accounts.get(username, '*') - - if ((pw != '*') and (password == pw)): - logger.info(f"AUTH: SUCCESS for user {username} with password '{password}'.") - return True - else: - logger.info(f"AUTH: FAILED for user {username} with password '{password}'.") - return False - async def start_server() -> None: + async def process_factory(process: asyncssh.SSHServerProcess) -> None: + server = process.get_server() + await handle_client(process, server) + await asyncssh.listen( port=config['ssh'].getint("port", 8022), reuse_address=True, reuse_port=True, server_factory=MySSHServer, server_host_keys=config['ssh'].get("host_priv_key", "ssh_host_key"), - process_factory=handle_client, + process_factory=lambda process: handle_client(process, MySSHServer()), server_version=config['ssh'].get("server_version_string", "SSH-2.0-OpenSSH_8.2p1 Ubuntu-4ubuntu0.3") )