Make peername and sockname calls more robust across platforms

For whatever reason, MacOS returns 4 values from conn.get_extra_info('peername') and conn.get_extra_info('sockname'), but Linux systems only return 2.  On the Mac, it's only the first two that we need anyway. Now we retrieve them all, no matter how many there are, and just use the first two so it will work on both platforms.
This commit is contained in:
David J. Bianco
2025-01-28 10:39:12 -05:00
parent 788bd26845
commit 7be73a7dff

View File

@ -49,8 +49,18 @@ class MySSHServer(asyncssh.SSHServer):
def connection_made(self, conn: asyncssh.SSHServerConnection) -> None: def connection_made(self, conn: asyncssh.SSHServerConnection) -> None:
# Get the source and destination IPs and ports # Get the source and destination IPs and ports
(src_ip, src_port, _, _) = conn.get_extra_info('peername') peername = conn.get_extra_info('peername')
(dst_ip, dst_port, _, _) = conn.get_extra_info('sockname') sockname = conn.get_extra_info('sockname')
if peername is not None:
src_ip, src_port = peername[:2]
else:
src_ip, src_port = '-', '-'
if sockname is not None:
dst_ip, dst_port = sockname[:2]
else:
dst_ip, dst_port = '-', '-'
# Store the connection details in thread-local storage # Store the connection details in thread-local storage
thread_local.src_ip = src_ip thread_local.src_ip = src_ip
@ -314,90 +324,96 @@ def get_prompts(prompt: Optional[str], prompt_file: Optional[str]) -> dict:
#### MAIN #### #### MAIN ####
# Parse command line arguments try:
parser = argparse.ArgumentParser(description='Start the SSH honeypot server.') # Parse command line arguments
parser.add_argument('-c', '--config', type=str, default='config.ini', help='Path to the configuration file') parser = argparse.ArgumentParser(description='Start the SSH honeypot server.')
parser.add_argument('-p', '--prompt', type=str, help='The entire text of the prompt') parser.add_argument('-c', '--config', type=str, default='config.ini', help='Path to the configuration file')
parser.add_argument('-f', '--prompt-file', type=str, default='prompt.txt', help='Path to the prompt file') parser.add_argument('-p', '--prompt', type=str, help='The entire text of the prompt')
args = parser.parse_args() parser.add_argument('-f', '--prompt-file', type=str, default='prompt.txt', help='Path to the prompt file')
args = parser.parse_args()
# Check if the config file exists # Check if the config file exists
if not os.path.exists(args.config): if not os.path.exists(args.config):
print(f"Error: The specified config file '{args.config}' does not exist.", file=sys.stderr) print(f"Error: The specified config file '{args.config}' does not exist.", file=sys.stderr)
sys.exit(1)
# Always use UTC for logging
logging.Formatter.formatTime = (lambda self, record, datefmt=None: datetime.datetime.fromtimestamp(record.created, datetime.timezone.utc).isoformat(sep="T",timespec="milliseconds"))
# Read our configuration file
config = ConfigParser()
config.read(args.config)
# Read the user accounts from the configuration file
accounts = get_user_accounts()
# Set up the honeypot logger
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
log_file_handler = logging.FileHandler(config['honeypot'].get("log_file", "ssh_log.log"))
logger.addHandler(log_file_handler)
log_file_handler.setFormatter(JSONFormatter())
f = ContextFilter()
logger.addFilter(f)
# Now get access to the LLM
prompts = get_prompts(args.prompt, args.prompt_file)
llm_system_prompt = prompts["system_prompt"]
llm_user_prompt = prompts["user_prompt"]
llm = choose_llm()
llm_sessions = dict()
llm_trimmer = trim_messages(
max_tokens=config['llm'].getint("trimmer_max_tokens", 64000),
strategy="last",
token_counter=llm,
include_system=True,
allow_partial=False,
start_on="human",
)
llm_prompt = ChatPromptTemplate.from_messages(
[
(
"system",
llm_system_prompt
),
(
"system",
llm_user_prompt
),
MessagesPlaceholder(variable_name="messages"),
]
)
llm_chain = (
RunnablePassthrough.assign(messages=itemgetter("messages") | llm_trimmer)
| llm_prompt
| llm
)
with_message_history = RunnableWithMessageHistory(
llm_chain,
llm_get_session_history,
input_messages_key="messages"
)
# Thread-local storage for connection details
thread_local = threading.local()
# Kick off the server!
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop.run_until_complete(start_server())
loop.run_forever()
except Exception as e:
print(f"Error: {e}", file=sys.stderr)
traceback.print_exc()
sys.exit(1) sys.exit(1)
# Always use UTC for logging
logging.Formatter.formatTime = (lambda self, record, datefmt=None: datetime.datetime.fromtimestamp(record.created, datetime.timezone.utc).isoformat(sep="T",timespec="milliseconds"))
# Read our configuration file
config = ConfigParser()
config.read(args.config)
# Read the user accounts from the configuration file
accounts = get_user_accounts()
# Set up the honeypot logger
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
log_file_handler = logging.FileHandler(config['honeypot'].get("log_file", "ssh_log.log"))
logger.addHandler(log_file_handler)
log_file_handler.setFormatter(JSONFormatter())
f = ContextFilter()
logger.addFilter(f)
# Now get access to the LLM
prompts = get_prompts(args.prompt, args.prompt_file)
llm_system_prompt = prompts["system_prompt"]
llm_user_prompt = prompts["user_prompt"]
llm = choose_llm()
llm_sessions = dict()
llm_trimmer = trim_messages(
max_tokens=config['llm'].getint("trimmer_max_tokens", 64000),
strategy="last",
token_counter=llm,
include_system=True,
allow_partial=False,
start_on="human",
)
llm_prompt = ChatPromptTemplate.from_messages(
[
(
"system",
llm_system_prompt
),
(
"system",
llm_user_prompt
),
MessagesPlaceholder(variable_name="messages"),
]
)
llm_chain = (
RunnablePassthrough.assign(messages=itemgetter("messages") | llm_trimmer)
| llm_prompt
| llm
)
with_message_history = RunnableWithMessageHistory(
llm_chain,
llm_get_session_history,
input_messages_key="messages"
)
# Thread-local storage for connection details
thread_local = threading.local()
# Kick off the server!
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop.run_until_complete(start_server())
loop.run_forever()