diff --git a/SSH/ssh_server.py b/SSH/ssh_server.py index c452559..e2550f6 100755 --- a/SSH/ssh_server.py +++ b/SSH/ssh_server.py @@ -37,6 +37,8 @@ class JSONFormatter(logging.Formatter): "dst_port": record.dst_port, "message": record.getMessage() } + if hasattr(record, 'interactive'): + log_record["interactive"] = record.interactive # Include any additional fields from the extra dictionary for key, value in record.__dict__.items(): if key not in log_record and key != 'args' and key != 'msg': @@ -150,7 +152,8 @@ representative examples. llm_response = await session.ainvoke( { "messages": [HumanMessage(content=prompt)], - "username": process.get_extra_info('username') + "username": process.get_extra_info('username'), + "interactive": True # Ensure interactive flag is passed }, config=llm_config ) @@ -178,42 +181,57 @@ async def handle_client(process: asyncssh.SSHServerProcess, server: MySSHServer) llm_config = {"configurable": {"session_id": task_uuid}} - llm_response = await with_message_history.ainvoke( - { - "messages": [HumanMessage(content="ignore this message")], - "username": process.get_extra_info('username') - }, - config=llm_config - ) - - process.stdout.write(f"{llm_response.content}") - logger.info("LLM response", extra={"details": b64encode(llm_response.content.encode('utf-8')).decode('utf-8')}) - - # Store process, llm_config, and session in the MySSHServer instance - server._process = process - server._llm_config = llm_config - server._session = with_message_history - try: - async for line in process.stdin: - line = line.rstrip('\n') - logger.info("User input", extra={"details": b64encode(line.encode('utf-8')).decode('utf-8')}) - - # Send the command to the LLM and give the response to the user + if process.command: + # Handle non-interactive command execution + command = process.command + logger.info("User input", extra={"details": b64encode(command.encode('utf-8')).decode('utf-8'), "interactive": False}) llm_response = await with_message_history.ainvoke( { - "messages": [HumanMessage(content=line)], - "username": process.get_extra_info('username') + "messages": [HumanMessage(content=command)], + "username": process.get_extra_info('username'), + "interactive": False }, config=llm_config ) - if llm_response.content == "XXX-END-OF-SESSION-XXX": - await session_summary(process, llm_config, with_message_history, server) - process.exit(0) - return - else: - process.stdout.write(f"{llm_response.content}") - logger.info("LLM response", extra={"details": b64encode(llm_response.content.encode('utf-8')).decode('utf-8')}) + process.stdout.write(f"{llm_response.content}") + logger.info("LLM response", extra={"details": b64encode(llm_response.content.encode('utf-8')).decode('utf-8'), "interactive": False}) + await session_summary(process, llm_config, with_message_history, server) + process.exit(0) + else: + # Handle interactive session + llm_response = await with_message_history.ainvoke( + { + "messages": [HumanMessage(content="ignore this message")], + "username": process.get_extra_info('username'), + "interactive": True + }, + config=llm_config + ) + + process.stdout.write(f"{llm_response.content}") + logger.info("LLM response", extra={"details": b64encode(llm_response.content.encode('utf-8')).decode('utf-8'), "interactive": True}) + + async for line in process.stdin: + line = line.rstrip('\n') + logger.info("User input", extra={"details": b64encode(line.encode('utf-8')).decode('utf-8'), "interactive": True}) + + # Send the command to the LLM and give the response to the user + llm_response = await with_message_history.ainvoke( + { + "messages": [HumanMessage(content=line)], + "username": process.get_extra_info('username'), + "interactive": True + }, + config=llm_config + ) + if llm_response.content == "XXX-END-OF-SESSION-XXX": + await session_summary(process, llm_config, with_message_history, server) + process.exit(0) + return + else: + process.stdout.write(f"{llm_response.content}") + logger.info("LLM response", extra={"details": b64encode(llm_response.content.encode('utf-8')).decode('utf-8'), "interactive": True}) except asyncssh.BreakReceived: pass