Correctly handle both interactive and non-interactive SSH sessions

SSH servers can take user commands from an interactive session as normal, but users can also include commands on the ssh client command line which are executed on the server (e.g., "ssh <hostname> 'uname -a'"). We now execute these non-interactive commands properly as well.

Also added a new "interactive" flag to all user commands (true/false) to show which type of command execution this was.
This commit is contained in:
David J. Bianco
2025-02-04 12:29:12 -05:00
parent 585ee66009
commit 5f27aeeabb

View File

@ -37,6 +37,8 @@ class JSONFormatter(logging.Formatter):
"dst_port": record.dst_port, "dst_port": record.dst_port,
"message": record.getMessage() "message": record.getMessage()
} }
if hasattr(record, 'interactive'):
log_record["interactive"] = record.interactive
# Include any additional fields from the extra dictionary # Include any additional fields from the extra dictionary
for key, value in record.__dict__.items(): for key, value in record.__dict__.items():
if key not in log_record and key != 'args' and key != 'msg': if key not in log_record and key != 'args' and key != 'msg':
@ -150,7 +152,8 @@ representative examples.
llm_response = await session.ainvoke( llm_response = await session.ainvoke(
{ {
"messages": [HumanMessage(content=prompt)], "messages": [HumanMessage(content=prompt)],
"username": process.get_extra_info('username') "username": process.get_extra_info('username'),
"interactive": True # Ensure interactive flag is passed
}, },
config=llm_config config=llm_config
) )
@ -178,42 +181,57 @@ async def handle_client(process: asyncssh.SSHServerProcess, server: MySSHServer)
llm_config = {"configurable": {"session_id": task_uuid}} llm_config = {"configurable": {"session_id": task_uuid}}
llm_response = await with_message_history.ainvoke(
{
"messages": [HumanMessage(content="ignore this message")],
"username": process.get_extra_info('username')
},
config=llm_config
)
process.stdout.write(f"{llm_response.content}")
logger.info("LLM response", extra={"details": b64encode(llm_response.content.encode('utf-8')).decode('utf-8')})
# Store process, llm_config, and session in the MySSHServer instance
server._process = process
server._llm_config = llm_config
server._session = with_message_history
try: try:
async for line in process.stdin: if process.command:
line = line.rstrip('\n') # Handle non-interactive command execution
logger.info("User input", extra={"details": b64encode(line.encode('utf-8')).decode('utf-8')}) command = process.command
logger.info("User input", extra={"details": b64encode(command.encode('utf-8')).decode('utf-8'), "interactive": False})
# Send the command to the LLM and give the response to the user
llm_response = await with_message_history.ainvoke( llm_response = await with_message_history.ainvoke(
{ {
"messages": [HumanMessage(content=line)], "messages": [HumanMessage(content=command)],
"username": process.get_extra_info('username') "username": process.get_extra_info('username'),
"interactive": False
}, },
config=llm_config config=llm_config
) )
if llm_response.content == "XXX-END-OF-SESSION-XXX": process.stdout.write(f"{llm_response.content}")
await session_summary(process, llm_config, with_message_history, server) logger.info("LLM response", extra={"details": b64encode(llm_response.content.encode('utf-8')).decode('utf-8'), "interactive": False})
process.exit(0) await session_summary(process, llm_config, with_message_history, server)
return process.exit(0)
else: else:
process.stdout.write(f"{llm_response.content}") # Handle interactive session
logger.info("LLM response", extra={"details": b64encode(llm_response.content.encode('utf-8')).decode('utf-8')}) llm_response = await with_message_history.ainvoke(
{
"messages": [HumanMessage(content="ignore this message")],
"username": process.get_extra_info('username'),
"interactive": True
},
config=llm_config
)
process.stdout.write(f"{llm_response.content}")
logger.info("LLM response", extra={"details": b64encode(llm_response.content.encode('utf-8')).decode('utf-8'), "interactive": True})
async for line in process.stdin:
line = line.rstrip('\n')
logger.info("User input", extra={"details": b64encode(line.encode('utf-8')).decode('utf-8'), "interactive": True})
# Send the command to the LLM and give the response to the user
llm_response = await with_message_history.ainvoke(
{
"messages": [HumanMessage(content=line)],
"username": process.get_extra_info('username'),
"interactive": True
},
config=llm_config
)
if llm_response.content == "XXX-END-OF-SESSION-XXX":
await session_summary(process, llm_config, with_message_history, server)
process.exit(0)
return
else:
process.stdout.write(f"{llm_response.content}")
logger.info("LLM response", extra={"details": b64encode(llm_response.content.encode('utf-8')).decode('utf-8'), "interactive": True})
except asyncssh.BreakReceived: except asyncssh.BreakReceived:
pass pass