From 0f5c4d1f698a4b80d5e5816313f063a470e41d4d Mon Sep 17 00:00:00 2001
From: "David J. Bianco"
Date: Fri, 16 Aug 2024 11:34:29 -0400
Subject: [PATCH] Implement chat message history trimming to avoid overflowing
 the LLM context window.

---
 ssh_server.py | 19 +++++++++++++++++--
 1 file changed, 17 insertions(+), 2 deletions(-)

diff --git a/ssh_server.py b/ssh_server.py
index ea3f0c8..aebafbc 100644
--- a/ssh_server.py
+++ b/ssh_server.py
@@ -20,7 +20,10 @@ from langchain_core.chat_history import (
 )
 from langchain_core.runnables.history import RunnableWithMessageHistory
 from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
+from langchain_core.messages import SystemMessage, trim_messages
+from langchain_core.runnables import RunnablePassthrough
+from operator import itemgetter
 
 
 async def handle_client(process: asyncssh.SSHServerProcess) -> None:
     # This is the main loop for handling SSH client connections.
@@ -157,6 +160,15 @@ llm_model = ChatOpenAI(model="gpt-4o")
 
 llm_sessions = dict()
 
+llm_trimmer = trim_messages(
+    max_tokens=64000,
+    strategy="last",
+    token_counter=llm_model,
+    include_system=True,
+    allow_partial=False,
+    start_on="human",
+)
+
 llm_prompt = ChatPromptTemplate.from_messages(
     [
         (
@@ -167,7 +179,11 @@ llm_prompt = ChatPromptTemplate.from_messages(
     ]
 )
 
-llm_chain = llm_prompt | llm_model
+llm_chain = (
+    RunnablePassthrough.assign(messages=itemgetter("messages") | llm_trimmer)
+    | llm_prompt
+    | llm_model
+)
 
 with_message_history = RunnableWithMessageHistory(
     llm_chain,
@@ -175,7 +191,6 @@ with_message_history = RunnableWithMessageHistory(
     input_messages_key="messages"
 )
 
-
 # Read the valid accounts
 accounts = read_accounts()
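
Note (not part of the patch): the trimmer configured above can also be exercised on its own, which may help when reviewing the 64k-token budget. The sketch below mirrors the patch's trim_messages settings and applies them directly to a made-up message list; the sample history and system prompt are hypothetical, and it assumes langchain-core and langchain-openai are installed and OPENAI_API_KEY is set so the model can count tokens.

# Illustrative sketch only; mirrors the trimmer configuration from the patch.
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, trim_messages
from langchain_openai import ChatOpenAI

llm_model = ChatOpenAI(model="gpt-4o")

# Hypothetical session history, standing in for whatever get_session_history returns.
history = [
    SystemMessage(content="You are a Linux server; respond to shell commands."),
    HumanMessage(content="ls -la"),
    AIMessage(content="total 8\ndrwxr-xr-x 2 root root 4096 Aug 16 11:34 ."),
    HumanMessage(content="uname -a"),
]

# Keep only the most recent messages that fit within 64k tokens, always retaining
# the system message and starting the retained window on a human message.
trimmed = trim_messages(
    history,
    max_tokens=64000,
    strategy="last",
    token_counter=llm_model,
    include_system=True,
    allow_partial=False,
    start_on="human",
)
print(len(trimmed))  # with this short history, nothing is dropped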