Improved end-of-session handling

Rather than explicitly checking to see if the user
was typing a shell exit command, the LLM is now
instructed to provide a specific token starting
("XXX-END-OF-SESSION-XXX") to indicate that the
session should be closed. This allows the user to
exit the shell in any way they see fit, and the
LLM will still know when to end the session. It
also means that typing 'exit' or similar commands
to subshells or command interpreters (e.g. Python)
are less likely to cause the session to end.
This commit is contained in:
David J. Bianco
2024-08-23 15:28:42 -04:00
parent ed95eda824
commit 2461b42e40
2 changed files with 12 additions and 12 deletions

View File

@ -48,11 +48,6 @@ async def handle_client(process: asyncssh.SSHServerProcess) -> None:
line = line.rstrip('\n')
logger.info(f"INPUT: {line}")
# If the user is trying to log out, don't send that to the
# LLM, just exit the session.
if line in ['exit', 'quit', 'logout']:
process.exit(0)
# Send the command to the LLM and give the response to the user
llm_response = await with_message_history.ainvoke(
{
@ -61,8 +56,11 @@ async def handle_client(process: asyncssh.SSHServerProcess) -> None:
},
config=llm_config
)
process.stdout.write(f"{llm_response.content}")
logger.info(f"OUTPUT: {b64encode(llm_response.content.encode('ascii')).decode('ascii')}")
if llm_response.content == "XXX-END-OF-SESSION-XXX":
process.exit(0)
else:
process.stdout.write(f"{llm_response.content}")
logger.info(f"OUTPUT: {b64encode(llm_response.content.encode('ascii')).decode('ascii')}")
except asyncssh.BreakReceived:
pass
@ -142,7 +140,7 @@ class ContextFilter(logging.Filter):
if task:
task_name = task.get_name()
else:
task_name = "NONE"
task_name = "-"
record.src_ip = thread_local.__dict__.get('src_ip', '-')
record.src_port = thread_local.__dict__.get('src_port', '-')
@ -170,8 +168,6 @@ def get_user_accounts() -> dict:
return accounts
def choose_llm():
# llm_model = ChatOpenAI(model=config['llm'].get("model", "NONE"))
llm_provider_name = config['llm'].get("llm_provider", "openai")
llm_provider_name = llm_provider_name.lower()
model_name = config['llm'].get("model_name", "gpt-3.5-turbo")
@ -213,7 +209,7 @@ logger.setLevel(logging.INFO)
log_file_handler = logging.FileHandler(config['honeypot'].get("log_file", "ssh_log.log"))
logger.addHandler(log_file_handler)
log_file_handler.setFormatter(logging.Formatter("%(asctime)s %(levelname)s:%(task_name)s SSH [%(src_ip)s:%(src_port)s -> %(dst_ip)s:%(dst_port)s] %(message)s", datefmt="%Y-%m-%d %H:%M:%S. %Z"))
log_file_handler.setFormatter(logging.Formatter("%(asctime)s %(levelname)s %(task_name)s SSH %(src_ip)s:%(src_port)s -> %(dst_ip)s:%(dst_port)s %(message)s", datefmt="%Y-%m-%d %H:%M:%S. %Z"))
f = ContextFilter()
logger.addFilter(f)