Fix session summary and ascii errors

* Session summaries now occur both at normal session termination (e.g., the user gracefully logs out) or abnormal termination, such as if the client disconnects suddenly.
* Now encode the AI results as UTF-8 instead of ASCII, because it would ocassionally send back non-ASCII characters which caused the server to throw errors
This commit is contained in:
David J. Bianco
2025-01-10 12:33:36 -05:00
parent 3b546126b6
commit 7185c7f5c7

View File

@ -21,6 +21,57 @@ from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables import RunnablePassthrough from langchain_core.runnables import RunnablePassthrough
class MySSHServer(asyncssh.SSHServer):
def connection_made(self, conn: asyncssh.SSHServerConnection) -> None:
# Get the source and destination IPs and ports
(src_ip, src_port, _, _) = conn.get_extra_info('peername')
(dst_ip, dst_port, _, _) = conn.get_extra_info('sockname')
# Store the connection details in thread-local storage
thread_local.src_ip = src_ip
thread_local.src_port = src_port
thread_local.dst_ip = dst_ip
thread_local.dst_port = dst_port
# Log the connection details
logger.info(f"SSH connection received from {src_ip}/{src_port} to {dst_ip}/{dst_port}.")
def connection_lost(self, exc: Optional[Exception]) -> None:
if exc:
logger.error('SSH connection error: ' + str(exc))
else:
logger.info("SSH connection closed.")
# Ensure session summary is called on connection loss if attributes are set
if hasattr(self, '_process') and hasattr(self, '_llm_config') and hasattr(self, '_session'):
asyncio.create_task(session_summary(self._process, self._llm_config, self._session))
def begin_auth(self, username: str) -> bool:
if accounts.get(username) != '':
logger.info(f"AUTH: User {username} attempting to authenticate.")
return True
else:
logger.info(f"AUTH: SUCCESS for user {username} with password ''.")
return False
def password_auth_supported(self) -> bool:
return True
def host_based_auth_supported(self) -> bool:
return False
def public_key_auth_supported(self) -> bool:
return False
def kbdinit_auth_supported(self) -> bool:
return False
def validate_password(self, username: str, password: str) -> bool:
pw = accounts.get(username, '*')
if ((pw != '*') and (password == pw)):
logger.info(f"AUTH: SUCCESS for user {username} with password '{password}'.")
return True
else:
logger.info(f"AUTH: FAILED for user {username} with password '{password}'.")
return False
async def session_summary(process: asyncssh.SSHServerProcess, llm_config: dict, session: RunnableWithMessageHistory): async def session_summary(process: asyncssh.SSHServerProcess, llm_config: dict, session: RunnableWithMessageHistory):
# When the SSH session ends, ask the LLM to give a nice # When the SSH session ends, ask the LLM to give a nice
# summary of the attacker's actions and probable intent, # summary of the attacker's actions and probable intent,
@ -65,7 +116,7 @@ representative examples.
logger.info(f"---SESSION SUMMARY---\n{llm_response.content}\n") logger.info(f"---SESSION SUMMARY---\n{llm_response.content}\n")
process.exit(0) process.exit(0)
async def handle_client(process: asyncssh.SSHServerProcess) -> None: async def handle_client(process: asyncssh.SSHServerProcess, server: MySSHServer) -> None:
# This is the main loop for handling SSH client connections. # This is the main loop for handling SSH client connections.
# Any user interaction should be done here. # Any user interaction should be done here.
@ -85,7 +136,12 @@ async def handle_client(process: asyncssh.SSHServerProcess) -> None:
) )
process.stdout.write(f"{llm_response.content}") process.stdout.write(f"{llm_response.content}")
logger.info(f"OUTPUT: {b64encode(llm_response.content.encode('ascii')).decode('ascii')}") logger.info(f"OUTPUT: {b64encode(llm_response.content.encode('utf-8')).decode('utf-8')}")
# Store process, llm_config, and session in the MySSHServer instance
server._process = process
server._llm_config = llm_config
server._session = with_message_history
try: try:
async for line in process.stdin: async for line in process.stdin:
@ -102,73 +158,31 @@ async def handle_client(process: asyncssh.SSHServerProcess) -> None:
) )
if llm_response.content == "XXX-END-OF-SESSION-XXX": if llm_response.content == "XXX-END-OF-SESSION-XXX":
await session_summary(process, llm_config, with_message_history) await session_summary(process, llm_config, with_message_history)
return
else: else:
process.stdout.write(f"{llm_response.content}") process.stdout.write(f"{llm_response.content}")
logger.info(f"OUTPUT: {b64encode(llm_response.content.encode('ascii')).decode('ascii')}") logger.info(f"OUTPUT: {b64encode(llm_response.content.encode('utf-8')).decode('utf-8')}")
except asyncssh.BreakReceived: except asyncssh.BreakReceived:
pass pass
finally:
await session_summary(process, llm_config, with_message_history)
# Just in case we ever get here, which we probably shouldn't # Just in case we ever get here, which we probably shouldn't
# process.exit(0) # process.exit(0)
class MySSHServer(asyncssh.SSHServer):
def connection_made(self, conn: asyncssh.SSHServerConnection) -> None:
# Get the source and destination IPs and ports
(src_ip, src_port, _, _) = conn.get_extra_info('peername')
(dst_ip, dst_port, _, _) = conn.get_extra_info('sockname')
# Store the connection details in thread-local storage
thread_local.src_ip = src_ip
thread_local.src_port = src_port
thread_local.dst_ip = dst_ip
thread_local.dst_port = dst_port
# Log the connection details
logger.info(f"SSH connection received from {src_ip}/{src_port} to {dst_ip}/{dst_port}.")
def connection_lost(self, exc: Optional[Exception]) -> None:
if exc:
logger.error('SSH connection error: ' + str(exc))
else:
logger.info("SSH connection closed.")
def begin_auth(self, username: str) -> bool:
if accounts.get(username) != '':
logger.info(f"AUTH: User {username} attempting to authenticate.")
return True
else:
logger.info(f"AUTH: SUCCESS for user {username} with password ''.")
return False
def password_auth_supported(self) -> bool:
return True
def host_based_auth_supported(self) -> bool:
return False
def public_key_auth_supported(self) -> bool:
return False
def kbdinit_auth_supported(self) -> bool:
return False
def validate_password(self, username: str, password: str) -> bool:
pw = accounts.get(username, '*')
if ((pw != '*') and (password == pw)):
logger.info(f"AUTH: SUCCESS for user {username} with password '{password}'.")
return True
else:
logger.info(f"AUTH: FAILED for user {username} with password '{password}'.")
return False
async def start_server() -> None: async def start_server() -> None:
async def process_factory(process: asyncssh.SSHServerProcess) -> None:
server = process.get_server()
await handle_client(process, server)
await asyncssh.listen( await asyncssh.listen(
port=config['ssh'].getint("port", 8022), port=config['ssh'].getint("port", 8022),
reuse_address=True, reuse_address=True,
reuse_port=True, reuse_port=True,
server_factory=MySSHServer, server_factory=MySSHServer,
server_host_keys=config['ssh'].get("host_priv_key", "ssh_host_key"), server_host_keys=config['ssh'].get("host_priv_key", "ssh_host_key"),
process_factory=handle_client, process_factory=lambda process: handle_client(process, MySSHServer()),
server_version=config['ssh'].get("server_version_string", "SSH-2.0-OpenSSH_8.2p1 Ubuntu-4ubuntu0.3") server_version=config['ssh'].get("server_version_string", "SSH-2.0-OpenSSH_8.2p1 Ubuntu-4ubuntu0.3")
) )