Improved end-of-session handling

Rather than explicitly checking to see if the user
was typing a shell exit command, the LLM is now
instructed to provide a specific token starting
("XXX-END-OF-SESSION-XXX") to indicate that the
session should be closed. This allows the user to
exit the shell in any way they see fit, and the
LLM will still know when to end the session. It
also means that typing 'exit' or similar commands
to subshells or command interpreters (e.g. Python)
are less likely to cause the session to end.
This commit is contained in:
David J. Bianco
2024-08-23 15:28:42 -04:00
parent ed95eda824
commit 2461b42e40
2 changed files with 12 additions and 12 deletions

View File

@ -5,5 +5,9 @@ On the first call, be sure to include a realistic MOTD.
End all responses with a realistic shell prompt to display to the user, including a space at the end. End all responses with a realistic shell prompt to display to the user, including a space at the end.
Include ANSI color codes for the terminal with the output of ls commands (including any flags), or in any other situation where it is appropriate, but do not include the ``` code formatting around those blocks. Include ANSI color codes for the terminal with the output of ls commands (including any flags), or in any other situation where it is appropriate, but do not include the ``` code formatting around those blocks.
Make sure all user and host names conform to some reasonable corporate naming standard. Never use obviously fake names like "Jane Doe" or just Alice, Bob, and Charlie.
If at any time the user's input would cause the SSH session to close (e.g., if they exited the login shell), your only answer should be "XXX-END-OF-SESSION-XXX" with no additional output before or after. Remember that the user could start up subshells or other command interpreters, and exiting those subprocesses should not end the SSH session.
Assume the username is {username}. Assume the username is {username}.

View File

@ -48,11 +48,6 @@ async def handle_client(process: asyncssh.SSHServerProcess) -> None:
line = line.rstrip('\n') line = line.rstrip('\n')
logger.info(f"INPUT: {line}") logger.info(f"INPUT: {line}")
# If the user is trying to log out, don't send that to the
# LLM, just exit the session.
if line in ['exit', 'quit', 'logout']:
process.exit(0)
# Send the command to the LLM and give the response to the user # Send the command to the LLM and give the response to the user
llm_response = await with_message_history.ainvoke( llm_response = await with_message_history.ainvoke(
{ {
@ -61,8 +56,11 @@ async def handle_client(process: asyncssh.SSHServerProcess) -> None:
}, },
config=llm_config config=llm_config
) )
process.stdout.write(f"{llm_response.content}") if llm_response.content == "XXX-END-OF-SESSION-XXX":
logger.info(f"OUTPUT: {b64encode(llm_response.content.encode('ascii')).decode('ascii')}") process.exit(0)
else:
process.stdout.write(f"{llm_response.content}")
logger.info(f"OUTPUT: {b64encode(llm_response.content.encode('ascii')).decode('ascii')}")
except asyncssh.BreakReceived: except asyncssh.BreakReceived:
pass pass
@ -142,7 +140,7 @@ class ContextFilter(logging.Filter):
if task: if task:
task_name = task.get_name() task_name = task.get_name()
else: else:
task_name = "NONE" task_name = "-"
record.src_ip = thread_local.__dict__.get('src_ip', '-') record.src_ip = thread_local.__dict__.get('src_ip', '-')
record.src_port = thread_local.__dict__.get('src_port', '-') record.src_port = thread_local.__dict__.get('src_port', '-')
@ -170,8 +168,6 @@ def get_user_accounts() -> dict:
return accounts return accounts
def choose_llm(): def choose_llm():
# llm_model = ChatOpenAI(model=config['llm'].get("model", "NONE"))
llm_provider_name = config['llm'].get("llm_provider", "openai") llm_provider_name = config['llm'].get("llm_provider", "openai")
llm_provider_name = llm_provider_name.lower() llm_provider_name = llm_provider_name.lower()
model_name = config['llm'].get("model_name", "gpt-3.5-turbo") model_name = config['llm'].get("model_name", "gpt-3.5-turbo")
@ -213,7 +209,7 @@ logger.setLevel(logging.INFO)
log_file_handler = logging.FileHandler(config['honeypot'].get("log_file", "ssh_log.log")) log_file_handler = logging.FileHandler(config['honeypot'].get("log_file", "ssh_log.log"))
logger.addHandler(log_file_handler) logger.addHandler(log_file_handler)
log_file_handler.setFormatter(logging.Formatter("%(asctime)s %(levelname)s:%(task_name)s SSH [%(src_ip)s:%(src_port)s -> %(dst_ip)s:%(dst_port)s] %(message)s", datefmt="%Y-%m-%d %H:%M:%S. %Z")) log_file_handler.setFormatter(logging.Formatter("%(asctime)s %(levelname)s %(task_name)s SSH %(src_ip)s:%(src_port)s -> %(dst_ip)s:%(dst_port)s %(message)s", datefmt="%Y-%m-%d %H:%M:%S. %Z"))
f = ContextFilter() f = ContextFilter()
logger.addFilter(f) logger.addFilter(f)