Implement chat message history trimming to avoid overflowing the LLM context window.

David J. Bianco
2024-08-16 11:34:29 -04:00
parent c40444a6cc
commit 0f5c4d1f69


@@ -20,7 +20,10 @@ from langchain_core.chat_history import (
 )
 from langchain_core.runnables.history import RunnableWithMessageHistory
 from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
+from langchain_core.messages import SystemMessage, trim_messages
+from langchain_core.runnables import RunnablePassthrough
+from operator import itemgetter
 async def handle_client(process: asyncssh.SSHServerProcess) -> None:
     # This is the main loop for handling SSH client connections.
@ -157,6 +160,15 @@ llm_model = ChatOpenAI(model="gpt-4o")
llm_sessions = dict()
llm_trimmer = trim_messages(
max_tokens=64000,
strategy="last",
token_counter=llm_model,
include_system=True,
allow_partial=False,
start_on="human",
)
llm_prompt = ChatPromptTemplate.from_messages(
[
(
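The new llm_trimmer keeps the most recent turns that fit under max_tokens, counting tokens with the model itself (token_counter=llm_model), always retaining the system prompt and starting the kept window on a human message. A minimal sketch of the same call on a toy history, not part of this commit and using token_counter=len so each message counts as one "token" and no model is needed:

from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, trim_messages

history = [
    SystemMessage(content="You are a Linux terminal."),
    HumanMessage(content="ls"),
    AIMessage(content="bin  etc  home  usr"),
    HumanMessage(content="whoami"),
    AIMessage(content="guest"),
]

# With a budget of 3 message-"tokens", the system prompt is kept and the rest
# of the window is filled from the end, starting on a human message.
trimmed = trim_messages(
    history,
    max_tokens=3,
    strategy="last",
    token_counter=len,
    include_system=True,
    allow_partial=False,
    start_on="human",
)
# trimmed == [SystemMessage(...), HumanMessage("whoami"), AIMessage("guest")]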
@@ -167,7 +179,11 @@ llm_prompt = ChatPromptTemplate.from_messages(
     ]
 )
-llm_chain = llm_prompt | llm_model
+llm_chain = (
+    RunnablePassthrough.assign(messages=itemgetter("messages") | llm_trimmer)
+    | llm_prompt
+    | llm_model
+)
 with_message_history = RunnableWithMessageHistory(
     llm_chain,
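RunnablePassthrough.assign is what lets the trimmer slot in without disturbing the rest of the chain input: only the "messages" key is rewritten with the trimmed list, and any other keys pass through to the prompt unchanged. A small standalone sketch of that pattern, with a stand-in lambda in place of llm_trimmer and a made-up extra key purely for illustration:

from operator import itemgetter

from langchain_core.runnables import RunnableLambda, RunnablePassthrough

# Stand-in for llm_trimmer: keep only the last two messages.
fake_trimmer = RunnableLambda(lambda msgs: msgs[-2:])

pipeline = RunnablePassthrough.assign(messages=itemgetter("messages") | fake_trimmer)

print(pipeline.invoke({"messages": ["m1", "m2", "m3", "m4"], "username": "guest"}))
# {'messages': ['m3', 'm4'], 'username': 'guest'}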
@@ -175,7 +191,6 @@ with_message_history = RunnableWithMessageHistory(
     input_messages_key="messages"
 )
 # Read the valid accounts
 accounts = read_accounts()
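Nothing changes about how the chain is called: RunnableWithMessageHistory still replays each session's stored history into the "messages" slot, and the trimmer now bounds that replay before it reaches the prompt, so long-running sessions can no longer overflow the context window. A hypothetical invocation for one SSH session (the session-id key, the message, and any other prompt inputs are assumptions, not shown in this diff):

from langchain_core.messages import HumanMessage

# "session_id" is RunnableWithMessageHistory's default config key; the real
# prompt may also require additional input variables not visible in this diff.
response = with_message_history.invoke(
    {"messages": [HumanMessage(content="uname -a")]},
    config={"configurable": {"session_id": "ssh-session-1"}},
)
print(response.content)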