mirror of https://github.com/splunk/DECEIVE.git (synced 2025-07-01 16:47:28 -04:00)

Commit: Experimental support for changing LLM providers and models in the config file.
requirements.txt
@@ -1,3 +1,13 @@
langchain
langchain_community
# For OpenAI models
langchain_openai
# For Google's Gemini models
langchain_google_genai

# For AWS
langchain_aws
transformers
torch
# For anthropic models (via AWS)
anthropic
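(Assuming a standard Python environment, the expanded dependency set installs with: pip install -r requirements.txt)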
@@ -13,6 +13,9 @@ import uuid
from base64 import b64encode

from langchain_openai import ChatOpenAI
from langchain_aws import ChatBedrock, ChatBedrockConverse
from langchain_google_genai import ChatGoogleGenerativeAI

from langchain_core.messages import HumanMessage
from langchain_core.chat_history import (
    BaseChatMessageHistory,
@@ -40,7 +43,7 @@ async def handle_client(process: asyncssh.SSHServerProcess) -> None:
    llm_response = await with_message_history.ainvoke(
        {
-           "messages": [HumanMessage(content="")],
+           "messages": [HumanMessage(content="ignore this message")],
            "username": process.get_extra_info('username')
        },
        config=llm_config
@@ -67,7 +70,6 @@ async def handle_client(process: asyncssh.SSHServerProcess) -> None:
        },
        config=llm_config
    )

    process.stdout.write(f"{llm_response.content}")
    logger.info(f"OUTPUT: {b64encode(llm_response.content.encode('ascii')).decode('ascii')}")
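For orientation: with_message_history.ainvoke takes per-session routing data separately from the input dict. llm_config itself is not shown in this diff, but with RunnableWithMessageHistory it is conventionally a dict along these lines (a hypothetical sketch, not the repo's actual code):

# Hypothetical: RunnableWithMessageHistory looks up each session's chat history
# by session_id, which travels in the config argument, not the input dict.
llm_config = {"configurable": {"session_id": str(uuid.uuid4())}}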
@@ -160,6 +162,37 @@ def get_user_accounts() -> dict:
    return accounts

+def choose_llm():
+    # llm_model = ChatOpenAI(model=config['llm'].get("model", "NONE"))
+
+    llm_provider_name = config['llm'].get("llm_provider", "openai")
+    llm_provider_name = llm_provider_name.lower()
+    model_name = config['llm'].get("model_name", "gpt-3.5-turbo")
+
+    if llm_provider_name == 'openai':
+        print("***** Model: OpenAI")
+        print("***** Model Name: ", model_name)
+        llm_model = ChatOpenAI(
+            model=model_name
+        )
+    elif llm_provider_name == 'aws':
+        print("***** Model: AWS")
+        print("***** Model Name: ", model_name)
+        llm_model = ChatBedrockConverse(
+            model=model_name,
+            region_name=config['llm'].get("aws_region", "us-east-1"),
+            credentials_profile_name=config['llm'].get("aws_credentials_profile", "default")
+        )
+    elif llm_provider_name == 'gemini':
+        print("***** Model: Gemini")
+        print("***** Model Name: ", model_name)
+        llm_model = ChatGoogleGenerativeAI(
+            model=model_name,
+        )
+    else:
+        raise ValueError(f"Invalid LLM provider {llm_provider_name}.")
+
+    return llm_model
+

#### MAIN ####

# Always use UTC for logging
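The new code reads all of these settings from the [llm] section of the config file via configparser. A hypothetical config illustrating the keys referenced by choose_llm() (llm_provider, model_name, aws_region, aws_credentials_profile); the model identifier shown is only an example:

[llm]
# Selects the branch in choose_llm(): openai, aws, or gemini
llm_provider = aws
# Passed through unchanged as the provider's model identifier
model_name = anthropic.claude-3-5-sonnet-20240620-v1:0
# Consulted only by the aws branch (ChatBedrockConverse)
aws_region = us-east-1
aws_credentials_profile = default
# Pre-existing keys read elsewhere in the script
trimmer_max_tokens = 64000
system_prompt_file = prompt.txt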
@@ -190,14 +223,14 @@ prompt_file = config['llm'].get("system_prompt_file", "prompt.txt")
with open(prompt_file, "r") as f:
    llm_system_prompt = f.read()

-llm_model = ChatOpenAI(model=config['llm'].get("model", "NONE"))
+llm = choose_llm()

llm_sessions = dict()

llm_trimmer = trim_messages(
    max_tokens=config['llm'].getint("trimmer_max_tokens", 64000),
    strategy="last",
-   token_counter=llm_model,
+   token_counter=llm,
    include_system=True,
    allow_partial=False,
    start_on="human",
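Since llm_model no longer exists after this change, the trimmer's token_counter must point at the new llm object. trim_messages also accepts any callable that maps a message list to a count, which gives a credential-free way to see the trimming behavior (a standalone sketch, not repo code):

from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, trim_messages

# Count each message as one "token" so the demo runs without a model or API key.
demo_trimmer = trim_messages(
    max_tokens=2,
    strategy="last",        # keep the newest messages
    token_counter=len,      # the commit passes the chat model here instead
    include_system=True,    # never drop the system prompt
    start_on="human",
)

history = [
    SystemMessage(content="You are a convincing SSH honeypot."),
    HumanMessage(content="ls"),
    AIMessage(content="bin  etc  home"),
    HumanMessage(content="whoami"),
]
# Keeps the system prompt plus the most recent human message.
print(demo_trimmer.invoke(history))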
@@ -216,7 +249,7 @@ llm_prompt = ChatPromptTemplate.from_messages(
llm_chain = (
    RunnablePassthrough.assign(messages=itemgetter("messages") | llm_trimmer)
    | llm_prompt
-   | llm_model
+   | llm
)

with_message_history = RunnableWithMessageHistory(
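The hunk ends mid-call, so the remaining arguments are not visible here. For orientation, a RunnableWithMessageHistory wrapper over a chain like llm_chain is typically completed along these lines; the helper below is a hypothetical stand-in, not DECEIVE's actual code:

from langchain_core.chat_history import InMemoryChatMessageHistory
from langchain_core.runnables.history import RunnableWithMessageHistory

def get_session_history(session_id: str) -> InMemoryChatMessageHistory:
    # Hypothetical helper: return (and lazily create) the history for a session,
    # stored in the llm_sessions dict defined above.
    if session_id not in llm_sessions:
        llm_sessions[session_id] = InMemoryChatMessageHistory()
    return llm_sessions[session_id]

with_message_history = RunnableWithMessageHistory(
    llm_chain,
    get_session_history,
    input_messages_key="messages",  # matches the "messages" key used in ainvoke above
)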