Experimental support for changing LLM providers and models in the config file.

David J. Bianco
2024-08-22 14:39:47 -04:00
parent df203a7a55
commit 7e38c43dee
2 changed files with 48 additions and 5 deletions


@@ -1,3 +1,13 @@
langchain
langchain_community
# For OpenAI models
langchain_openai
# For Google's Gemini models
langchain_google_genai
# For AWS
langchain_aws
transformers
torch
# For Anthropic models (via AWS)
anthropic
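Note that the server imports all three provider integrations unconditionally (see the import hunk below), so every package listed here must be installed even when only one provider is configured. A sanity check along these lines could flag a missing module at startup; the module names are taken from this diff, but the check itself is only a sketch:

import importlib

# Integration modules this commit imports at the top of the server file.
for module in ("langchain_openai", "langchain_aws", "langchain_google_genai"):
    try:
        importlib.import_module(module)
    except ImportError:
        print(f"Missing dependency: {module}")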


@@ -13,6 +13,9 @@ import uuid
from base64 import b64encode
from langchain_openai import ChatOpenAI
from langchain_aws import ChatBedrock, ChatBedrockConverse
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_core.messages import HumanMessage
from langchain_core.chat_history import (
BaseChatMessageHistory,
@@ -40,7 +43,7 @@ async def handle_client(process: asyncssh.SSHServerProcess) -> None:
        llm_response = await with_message_history.ainvoke(
            {
                "messages": [HumanMessage(content="")],
                "messages": [HumanMessage(content="ignore this message")],
                "username": process.get_extra_info('username')
            },
            config=llm_config
@@ -67,7 +70,6 @@ async def handle_client(process: asyncssh.SSHServerProcess) -> None:
            },
            config=llm_config
        )
        process.stdout.write(f"{llm_response.content}")
        logger.info(f"OUTPUT: {b64encode(llm_response.content.encode('ascii')).decode('ascii')}")
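Model output is logged base64-encoded, which keeps multi-line responses on a single OUTPUT log line. Recovering the plaintext is the mirror image of the logger.info() call above; the sample text here is illustrative only:

from base64 import b64encode, b64decode

# Encode the way the server logs it, then invert it.
entry = b64encode("total 0\n".encode('ascii')).decode('ascii')
print(b64decode(entry).decode('ascii'))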
@@ -160,6 +162,37 @@ def get_user_accounts() -> dict:
    return accounts

def choose_llm():
    # Instantiate the chat model named in the config file's [llm] section.
    llm_provider_name = config['llm'].get("llm_provider", "openai").lower()
    model_name = config['llm'].get("model_name", "gpt-3.5-turbo")

    if llm_provider_name == 'openai':
        print("***** Model: OpenAI")
        print("***** Model Name: ", model_name)
        llm_model = ChatOpenAI(
            model=model_name
        )
    elif llm_provider_name == 'aws':
        print("***** Model: AWS")
        print("***** Model Name: ", model_name)
        llm_model = ChatBedrockConverse(
            model=model_name,
            region_name=config['llm'].get("aws_region", "us-east-1"),
            credentials_profile_name=config['llm'].get("aws_credentials_profile", "default")
        )
    elif llm_provider_name == 'gemini':
        print("***** Model: Gemini")
        print("***** Model Name: ", model_name)
        llm_model = ChatGoogleGenerativeAI(
            model=model_name,
        )
    else:
        raise ValueError(f"Invalid LLM provider {llm_provider_name}.")

    return llm_model
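For reference, a config file exercising the new keys might look like the sketch below. The section and key names ([llm], llm_provider, model_name, aws_region, aws_credentials_profile) come straight from choose_llm(); the values, including the Bedrock model ID, are illustrative assumptions:

import configparser

config = configparser.ConfigParser()
config.read_string("""
[llm]
llm_provider = aws
model_name = anthropic.claude-3-sonnet-20240229-v1:0
aws_region = us-east-1
aws_credentials_profile = default
""")

print(config['llm'].get("llm_provider", "openai").lower())  # -> aws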
#### MAIN ####
# Always use UTC for logging
@@ -190,14 +223,14 @@ prompt_file = config['llm'].get("system_prompt_file", "prompt.txt")
with open(prompt_file, "r") as f:
    llm_system_prompt = f.read()
llm_model = ChatOpenAI(model=config['llm'].get("model", "NONE"))
llm = choose_llm()
llm_sessions = dict()
llm_trimmer = trim_messages(
    max_tokens=config['llm'].getint("trimmer_max_tokens", 64000),
    strategy="last",
    token_counter=llm_model,
    token_counter=llm,
    include_system=True,
    allow_partial=False,
    start_on="human",
@@ -216,7 +249,7 @@ llm_prompt = ChatPromptTemplate.from_messages(
llm_chain = (
    RunnablePassthrough.assign(messages=itemgetter("messages") | llm_trimmer)
    | llm_prompt
    | llm_model
    | llm
)
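The assembled chain can also be exercised outside the SSH server; a rough smoke-test sketch, assuming provider credentials are present in the environment and reusing the HumanMessage import from the top of this file:

import asyncio

async def smoke_test():
    # "messages" and "username" mirror the inputs handle_client() passes in.
    reply = await llm_chain.ainvoke({
        "messages": [HumanMessage(content="uname -a")],
        "username": "guest",
    })
    print(reply.content)

asyncio.run(smoke_test())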
with_message_history = RunnableWithMessageHistory(