diff --git a/prompt.txt b/prompt.txt index d38352b..6e5f4ea 100644 --- a/prompt.txt +++ b/prompt.txt @@ -5,5 +5,9 @@ On the first call, be sure to include a realistic MOTD. End all responses with a realistic shell prompt to display to the user, including a space at the end. Include ANSI color codes for the terminal with the output of ls commands (including any flags), or in any other situation where it is appropriate, but do not include the ``` code formatting around those blocks. - + +Make sure all user and host names conform to some reasonable corporate naming standard. Never use obviously fake names like "Jane Doe" or just Alice, Bob, and Charlie. + +If at any time the user's input would cause the SSH session to close (e.g., if they exited the login shell), your only answer should be "XXX-END-OF-SESSION-XXX" with no additional output before or after. Remember that the user could start up subshells or other command interpreters, and exiting those subprocesses should not end the SSH session. + Assume the username is {username}. \ No newline at end of file diff --git a/ssh_server.py b/ssh_server.py index 2e32172..42f5490 100755 --- a/ssh_server.py +++ b/ssh_server.py @@ -48,11 +48,6 @@ async def handle_client(process: asyncssh.SSHServerProcess) -> None: line = line.rstrip('\n') logger.info(f"INPUT: {line}") - # If the user is trying to log out, don't send that to the - # LLM, just exit the session. - if line in ['exit', 'quit', 'logout']: - process.exit(0) - # Send the command to the LLM and give the response to the user llm_response = await with_message_history.ainvoke( { @@ -61,8 +56,11 @@ async def handle_client(process: asyncssh.SSHServerProcess) -> None: }, config=llm_config ) - process.stdout.write(f"{llm_response.content}") - logger.info(f"OUTPUT: {b64encode(llm_response.content.encode('ascii')).decode('ascii')}") + if llm_response.content == "XXX-END-OF-SESSION-XXX": + process.exit(0) + else: + process.stdout.write(f"{llm_response.content}") + logger.info(f"OUTPUT: {b64encode(llm_response.content.encode('ascii')).decode('ascii')}") except asyncssh.BreakReceived: pass @@ -142,7 +140,7 @@ class ContextFilter(logging.Filter): if task: task_name = task.get_name() else: - task_name = "NONE" + task_name = "-" record.src_ip = thread_local.__dict__.get('src_ip', '-') record.src_port = thread_local.__dict__.get('src_port', '-') @@ -170,8 +168,6 @@ def get_user_accounts() -> dict: return accounts def choose_llm(): -# llm_model = ChatOpenAI(model=config['llm'].get("model", "NONE")) - llm_provider_name = config['llm'].get("llm_provider", "openai") llm_provider_name = llm_provider_name.lower() model_name = config['llm'].get("model_name", "gpt-3.5-turbo") @@ -213,7 +209,7 @@ logger.setLevel(logging.INFO) log_file_handler = logging.FileHandler(config['honeypot'].get("log_file", "ssh_log.log")) logger.addHandler(log_file_handler) -log_file_handler.setFormatter(logging.Formatter("%(asctime)s %(levelname)s:%(task_name)s SSH [%(src_ip)s:%(src_port)s -> %(dst_ip)s:%(dst_port)s] %(message)s", datefmt="%Y-%m-%d %H:%M:%S. %Z")) +log_file_handler.setFormatter(logging.Formatter("%(asctime)s %(levelname)s %(task_name)s SSH %(src_ip)s:%(src_port)s -> %(dst_ip)s:%(dst_port)s %(message)s", datefmt="%Y-%m-%d %H:%M:%S. %Z")) f = ContextFilter() logger.addFilter(f)