Now a function prototype with an LLM backend.

* Added langchain support (OpenAI's gpt-4o model)
* Created a system prompt that gives functional results
* Initial integration of logging for LLM responses (needs improvement)
This commit is contained in:
David J. Bianco
2024-08-15 15:44:54 -04:00
parent 759814f8c9
commit 092ac94b05
2 changed files with 82 additions and 7 deletions

View File

@ -1,3 +1,3 @@
langchain
langchain_community
langchain_openai

View File

@ -11,22 +11,52 @@ import logging
import datetime
import uuid
from langchain_openai import ChatOpenAI
from langchain_core.messages import HumanMessage
from langchain_core.chat_history import (
BaseChatMessageHistory,
InMemoryChatMessageHistory,
)
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
async def handle_client(process: asyncssh.SSHServerProcess) -> None:
# This is the main loop for handling SSH client connections.
# Any user interaction should be done here.
# Give each session a unique name
task_uuid = f"session-{uuid.uuid4()}"
current_task = asyncio.current_task()
current_task.set_name(f"session-{uuid.uuid4()}")
current_task.set_name(task_uuid)
llm_config = {"configurable": {"session_id": task_uuid}}
llm_response = await with_message_history.ainvoke(
{
"messages": [HumanMessage(content="")],
"username": process.get_extra_info('username')
},
config=llm_config
)
process.stdout.write(f"{llm_response.content}")
logger.info(f"OUTPUT: {llm_response.content}")
process.stdout.write('Welcome to my SSH server, %s!\n' %
process.get_extra_info('username'))
try:
async for line in process.stdin:
line = line.rstrip('\n')
logger.info(f"INPUT: {line}")
process.stdout.write('You entered: %s\n' % line)
llm_response = await with_message_history.ainvoke(
{
"messages": [HumanMessage(content=line)],
"username": process.get_extra_info('username')
},
config=llm_config
)
process.stdout.write(f"{llm_response.content}")
logger.info(f"OUTPUT: {llm_response.content}")
except asyncssh.BreakReceived:
pass
@ -42,7 +72,6 @@ class MySSHServer(asyncssh.SSHServer):
logger.error('SSH connection error: ' + str(exc), file=sys.stderr)
else:
print('SSH connection closed.')
logger.info("SSH connection closed.")
def begin_auth(self, username: str) -> bool:
@ -87,6 +116,11 @@ def read_accounts() -> dict:
return accounts
def llm_get_session_history(session_id: str) -> BaseChatMessageHistory:
if session_id not in llm_sessions:
llm_sessions[session_id] = InMemoryChatMessageHistory()
return llm_sessions[session_id]
#### MAIN ####
# Always use UTC for logging
@ -104,6 +138,47 @@ log_file_handler.setFormatter(logging.Formatter("%(asctime)s %(levelname)s:%(tas
f = ContextFilter()
logger.addFilter(f)
# Now get access to the LLM
llm_model = ChatOpenAI(model="gpt-4o")
llm_sessions = dict()
llm_prompt = ChatPromptTemplate.from_messages(
[
(
"system",
'''
You are a linux operating system accepting commands from a user via SSH.
Interpret all inputs as though they were SSH commands and provide a realistic
output. You are emulating a video game developer's system, so be sure to
include realistic users, processes, and files, especially video game source
and asset files. Do not include extraneous formatting in your responses.
On the first call, be sure to include a realistic MOTD.
End all responses with a realistic shell prompt to display to the user,
including a space at the end.
Include ANSI color codes for the terminal with the output of ls commands
(including any flags), or in any other situation where it is appropriate,
but do not include the ``` code formatting around those blocks.
Assume the username is {username}.
'''
),
MessagesPlaceholder(variable_name="messages"),
]
)
llm_chain = llm_prompt | llm_model
with_message_history = RunnableWithMessageHistory(
llm_chain,
llm_get_session_history,
input_messages_key="messages"
)
# Read the valid accounts
accounts = read_accounts()