3 Commits

Author SHA1 Message Date
0197b8b1df Update config.ini.TEMPLATE to support non-interactive commands
Oops, I forgot to commit this before.
2025-02-04 14:55:02 -05:00
5f27aeeabb Correctly handle both interactive and non-interactive SSH sessions
SSH servers can take user commands from an interactive session as normal, but users can also include commands on the ssh client command line which are executed on the server (e.g., "ssh <hostname> 'uname -a'"). We now execute these non-interactive commands properly as well.

Also added a new "interactive" flag to all user commands (true/false) to show which type of command execution this was.
2025-02-04 12:29:12 -05:00
585ee66009 Don't print ConnectionLost exceptions to the console
These are far too frequent. We still log them, though, we just don't print them.
2025-01-28 10:48:29 -05:00
2 changed files with 56 additions and 35 deletions

View File

@ -51,10 +51,11 @@ system_prompt = Interpret all inputs as though they were SSH commands and provid
responses do not indicate that the user is on a virtual machine, container,
or emulation at all.
On the first call, be sure to include a realistic MOTD.
If {interactive} is True, make sure to include a realistic MOTD on the
first call, and to end all responses with a realistic shell prompt to display to
the user, including a space at the end.
End all responses with a realistic shell prompt to display to the user,
including a space at the end.
If {interactive} is False, do not include a shell prompt or MOTD.
Include ANSI color codes for the terminal with the output of ls commands
(including any flags), or in any other situation where it is appropriate, but

View File

@ -23,6 +23,7 @@ from langchain_core.chat_history import BaseChatMessageHistory, InMemoryChatMess
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables import RunnablePassthrough
from asyncssh.misc import ConnectionLost
class JSONFormatter(logging.Formatter):
def format(self, record):
@ -36,6 +37,8 @@ class JSONFormatter(logging.Formatter):
"dst_port": record.dst_port,
"message": record.getMessage()
}
if hasattr(record, 'interactive'):
log_record["interactive"] = record.interactive
# Include any additional fields from the extra dictionary
for key, value in record.__dict__.items():
if key not in log_record and key != 'args' and key != 'msg':
@ -74,7 +77,8 @@ class MySSHServer(asyncssh.SSHServer):
def connection_lost(self, exc: Optional[Exception]) -> None:
if exc:
logger.error('SSH connection error', extra={"error": str(exc)})
traceback.print_exception(exc)
if not isinstance(exc, ConnectionLost):
traceback.print_exception(exc)
else:
logger.info("SSH connection closed")
# Ensure session summary is called on connection loss if attributes are set
@ -148,7 +152,8 @@ representative examples.
llm_response = await session.ainvoke(
{
"messages": [HumanMessage(content=prompt)],
"username": process.get_extra_info('username')
"username": process.get_extra_info('username'),
"interactive": True # Ensure interactive flag is passed
},
config=llm_config
)
@ -176,42 +181,57 @@ async def handle_client(process: asyncssh.SSHServerProcess, server: MySSHServer)
llm_config = {"configurable": {"session_id": task_uuid}}
llm_response = await with_message_history.ainvoke(
{
"messages": [HumanMessage(content="ignore this message")],
"username": process.get_extra_info('username')
},
config=llm_config
)
process.stdout.write(f"{llm_response.content}")
logger.info("LLM response", extra={"details": b64encode(llm_response.content.encode('utf-8')).decode('utf-8')})
# Store process, llm_config, and session in the MySSHServer instance
server._process = process
server._llm_config = llm_config
server._session = with_message_history
try:
async for line in process.stdin:
line = line.rstrip('\n')
logger.info("User input", extra={"details": b64encode(line.encode('utf-8')).decode('utf-8')})
# Send the command to the LLM and give the response to the user
if process.command:
# Handle non-interactive command execution
command = process.command
logger.info("User input", extra={"details": b64encode(command.encode('utf-8')).decode('utf-8'), "interactive": False})
llm_response = await with_message_history.ainvoke(
{
"messages": [HumanMessage(content=line)],
"username": process.get_extra_info('username')
"messages": [HumanMessage(content=command)],
"username": process.get_extra_info('username'),
"interactive": False
},
config=llm_config
)
if llm_response.content == "XXX-END-OF-SESSION-XXX":
await session_summary(process, llm_config, with_message_history, server)
process.exit(0)
return
else:
process.stdout.write(f"{llm_response.content}")
logger.info("LLM response", extra={"details": b64encode(llm_response.content.encode('utf-8')).decode('utf-8')})
process.stdout.write(f"{llm_response.content}")
logger.info("LLM response", extra={"details": b64encode(llm_response.content.encode('utf-8')).decode('utf-8'), "interactive": False})
await session_summary(process, llm_config, with_message_history, server)
process.exit(0)
else:
# Handle interactive session
llm_response = await with_message_history.ainvoke(
{
"messages": [HumanMessage(content="ignore this message")],
"username": process.get_extra_info('username'),
"interactive": True
},
config=llm_config
)
process.stdout.write(f"{llm_response.content}")
logger.info("LLM response", extra={"details": b64encode(llm_response.content.encode('utf-8')).decode('utf-8'), "interactive": True})
async for line in process.stdin:
line = line.rstrip('\n')
logger.info("User input", extra={"details": b64encode(line.encode('utf-8')).decode('utf-8'), "interactive": True})
# Send the command to the LLM and give the response to the user
llm_response = await with_message_history.ainvoke(
{
"messages": [HumanMessage(content=line)],
"username": process.get_extra_info('username'),
"interactive": True
},
config=llm_config
)
if llm_response.content == "XXX-END-OF-SESSION-XXX":
await session_summary(process, llm_config, with_message_history, server)
process.exit(0)
return
else:
process.stdout.write(f"{llm_response.content}")
logger.info("LLM response", extra={"details": b64encode(llm_response.content.encode('utf-8')).decode('utf-8'), "interactive": True})
except asyncssh.BreakReceived:
pass