mirror of https://github.com/ChrisSewell/DECEIVE.git (synced 2025-07-01 18:47:28 -04:00)

Experimental support for changing LLM providers and models in the config file.
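The new choose_llm() function in this commit drives provider selection from an [llm] section of the config file. As a rough sketch of what that section might contain: the key names (llm_provider, model_name, aws_region, aws_credentials_profile) are taken from the diff, the values are placeholders rather than project defaults, and the use of Python's configparser is an assumption suggested by the config['llm'].getint(...) call further down.

    import configparser

    # Key names match the config['llm'].get(...) calls in this commit;
    # the values are illustrative only.
    config = configparser.ConfigParser()
    config.read_string("""
    [llm]
    llm_provider = aws
    model_name = anthropic.claude-3-haiku-20240307-v1:0
    aws_region = us-east-1
    aws_credentials_profile = default
    """)

    print(config['llm'].get("llm_provider", "openai"))  # -> aws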
@@ -1,3 +1,13 @@
 langchain
 langchain_community
+# For OpenAI models
 langchain_openai
+# For Google's Gemini models
+langchain_google_genai
+
+# For AWS
+langchain_aws
+transformers
+torch
+# For anthropic models (via AWS)
+anthropic
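The dependency list now carries one LangChain integration package per provider (langchain_openai, langchain_google_genai, langchain_aws), plus anthropic for Claude models served through Bedrock. Only the package for the provider named in the config is exercised at runtime, but installing them all (e.g. pip install -r requirements.txt, assuming the conventional file name) lets the provider be switched without touching dependencies.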
@@ -13,6 +13,9 @@ import uuid
 from base64 import b64encode
 
 from langchain_openai import ChatOpenAI
+from langchain_aws import ChatBedrock, ChatBedrockConverse
+from langchain_google_genai import ChatGoogleGenerativeAI
+
 from langchain_core.messages import HumanMessage
 from langchain_core.chat_history import (
     BaseChatMessageHistory,
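Note that ChatBedrock is imported alongside ChatBedrockConverse, but only ChatBedrockConverse is instantiated in the choose_llm() function added below; the extra import looks unused in this commit.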
@@ -40,7 +43,7 @@ async def handle_client(process: asyncssh.SSHServerProcess) -> None:
 
     llm_response = await with_message_history.ainvoke(
         {
-            "messages": [HumanMessage(content="")],
+            "messages": [HumanMessage(content="ignore this message")],
             "username": process.get_extra_info('username')
         },
         config=llm_config
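The placeholder text replacing the empty string is presumably a workaround: some chat backends (Gemini and Anthropic among them) reject messages with empty content, which the previous OpenAI-only code path happened to tolerate.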
@@ -67,7 +70,6 @@ async def handle_client(process: asyncssh.SSHServerProcess) -> None:
         },
         config=llm_config
     )
-
     process.stdout.write(f"{llm_response.content}")
     logger.info(f"OUTPUT: {b64encode(llm_response.content.encode('ascii')).decode('ascii')}")
 
@@ -160,6 +162,37 @@ def get_user_accounts() -> dict:
 
     return accounts
 
+def choose_llm():
+    # llm_model = ChatOpenAI(model=config['llm'].get("model", "NONE"))
+
+    llm_provider_name = config['llm'].get("llm_provider", "openai")
+    llm_provider_name = llm_provider_name.lower()
+    model_name = config['llm'].get("model_name", "gpt-3.5-turbo")
+
+    if llm_provider_name == 'openai':
+        print("***** Model: OpenAI")
+        print("***** Model Name: ", model_name)
+        llm_model = ChatOpenAI(
+            model=model_name
+        )
+    elif llm_provider_name == 'aws':
+        print("***** Model: AWS")
+        print("***** Model Name: ", model_name)
+        llm_model = ChatBedrockConverse(
+            model=model_name,
+            region_name=config['llm'].get("aws_region", "us-east-1"),
+            credentials_profile_name=config['llm'].get("aws_credentials_profile", "default") )
+    elif llm_provider_name == 'gemini':
+        print("***** Model: Gemini")
+        print("***** Model Name: ", model_name)
+        llm_model = ChatGoogleGenerativeAI(
+            model=model_name,
+        )
+    else:
+        raise ValueError(f"Invalid LLM provider {llm_provider_name}.")
+
+    return llm_model
+
 #### MAIN ####
 
 # Always use UTC for logging
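A hypothetical smoke test for the new function (not part of the commit), assuming the [llm] config section sketched earlier has been loaded and credentials for the selected provider (e.g. OPENAI_API_KEY) are present in the environment. invoke() is the synchronous counterpart of the ainvoke() call used in handle_client(), and all three wrapper classes share it:

    # Hypothetical usage; names follow the commit's code.
    llm = choose_llm()                # backend picked from config['llm']
    reply = llm.invoke("echo hello")  # same interface regardless of provider
    print(reply.content)              # response text from the selected model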
@@ -190,14 +223,14 @@ prompt_file = config['llm'].get("system_prompt_file", "prompt.txt")
 with open(prompt_file, "r") as f:
     llm_system_prompt = f.read()
 
-llm_model = ChatOpenAI(model=config['llm'].get("model", "NONE"))
+llm = choose_llm()
 
 llm_sessions = dict()
 
 llm_trimmer = trim_messages(
     max_tokens=config['llm'].getint("trimmer_max_tokens", 64000),
     strategy="last",
-    token_counter=llm_model,
+    token_counter=llm,
     include_system=True,
     allow_partial=False,
     start_on="human",
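Since trim_messages() accepts a chat model as its token_counter, rebinding it from the old llm_model to the provider-agnostic llm keeps token counting consistent with whichever backend choose_llm() returned.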
@@ -216,7 +249,7 @@ llm_prompt = ChatPromptTemplate.from_messages(
 llm_chain = (
     RunnablePassthrough.assign(messages=itemgetter("messages") | llm_trimmer)
     | llm_prompt
-    | llm_model
+    | llm
 )
 
 with_message_history = RunnableWithMessageHistory(