diff --git a/requirements.txt b/requirements.txt
index 73e41fc..afa45fb 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,3 +1,13 @@
 langchain
 langchain_community
+# For OpenAI models
 langchain_openai
+# For Google's Gemini models
+langchain_google_genai
+
+# For AWS Bedrock models
+langchain_aws
+transformers
+torch
+# For Anthropic models (via AWS)
+anthropic
\ No newline at end of file
diff --git a/ssh_server.py b/ssh_server.py
index 6ea480d..fffe88b 100644
--- a/ssh_server.py
+++ b/ssh_server.py
@@ -13,6 +13,9 @@ import uuid
 from base64 import b64encode
 
 from langchain_openai import ChatOpenAI
+from langchain_aws import ChatBedrock, ChatBedrockConverse
+from langchain_google_genai import ChatGoogleGenerativeAI
+
 from langchain_core.messages import HumanMessage
 from langchain_core.chat_history import (
     BaseChatMessageHistory,
@@ -40,7 +43,7 @@ async def handle_client(process: asyncssh.SSHServerProcess) -> None:
 
         llm_response = await with_message_history.ainvoke(
             {
-                "messages": [HumanMessage(content="")],
+                "messages": [HumanMessage(content="ignore this message")],
                 "username": process.get_extra_info('username')
             },
             config=llm_config
@@ -67,7 +70,6 @@ async def handle_client(process: asyncssh.SSHServerProcess) -> None:
             },
             config=llm_config
         )
-        process.stdout.write(f"{llm_response.content}")
 
         logger.info(f"OUTPUT: {b64encode(llm_response.content.encode('ascii')).decode('ascii')}")
 
@@ -160,6 +162,37 @@ def get_user_accounts() -> dict:
 
     return accounts
 
+def choose_llm():
+    """Instantiate the chat model named in the [llm] section of the config."""
+    llm_provider_name = config['llm'].get("llm_provider", "openai")
+    llm_provider_name = llm_provider_name.lower()
+    model_name = config['llm'].get("model_name", "gpt-3.5-turbo")
+
+    if llm_provider_name == 'openai':
+        print("***** Model: OpenAI")
+        print("***** Model Name: ", model_name)
+        llm_model = ChatOpenAI(
+            model=model_name
+        )
+    elif llm_provider_name == 'aws':
+        print("***** Model: AWS")
+        print("***** Model Name: ", model_name)
+        llm_model = ChatBedrockConverse(
+            model=model_name,
+            region_name=config['llm'].get("aws_region", "us-east-1"),
+            credentials_profile_name=config['llm'].get("aws_credentials_profile", "default")
+        )
+    elif llm_provider_name == 'gemini':
+        print("***** Model: Gemini")
+        print("***** Model Name: ", model_name)
+        llm_model = ChatGoogleGenerativeAI(
+            model=model_name,
+        )
+    else:
+        raise ValueError(f"Invalid LLM provider {llm_provider_name}.")
+
+    return llm_model
+
 #### MAIN ####
 
 # Always use UTC for logging
@@ -190,14 +223,14 @@ prompt_file = config['llm'].get("system_prompt_file", "prompt.txt")
 with open(prompt_file, "r") as f:
     llm_system_prompt = f.read()
 
-llm_model = ChatOpenAI(model=config['llm'].get("model", "NONE"))
+llm = choose_llm()
 
 llm_sessions = dict()
 
 llm_trimmer = trim_messages(
     max_tokens=config['llm'].getint("trimmer_max_tokens", 64000),
     strategy="last",
-    token_counter=llm_model,
+    token_counter=llm,
     include_system=True,
     allow_partial=False,
     start_on="human",
@@ -216,7 +249,7 @@ llm_prompt = ChatPromptTemplate.from_messages(
 llm_chain = (
     RunnablePassthrough.assign(messages=itemgetter("messages") | llm_trimmer)
     | llm_prompt
-    | llm_model
+    | llm
 )
 
 with_message_history = RunnableWithMessageHistory(
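
Note: `choose_llm()` reads its settings through `configparser`-style lookups (`config['llm'].get(...)`, `config['llm'].getint(...)`), so the `[llm]` section of the config file now drives provider selection. A minimal sketch of that section, assuming an INI-format config file; the key names (`llm_provider`, `model_name`, `aws_region`, `aws_credentials_profile`, `system_prompt_file`, `trimmer_max_tokens`) come from the diff, while the values shown are illustrative only:

    [llm]
    ; one of: openai, aws, gemini
    llm_provider = aws
    ; provider-specific model identifier, e.g. a Bedrock model ID when llm_provider = aws
    model_name = anthropic.claude-3-5-sonnet-20240620-v1:0
    aws_region = us-east-1
    aws_credentials_profile = default
    system_prompt_file = prompt.txt
    trimmer_max_tokens = 64000

API credentials are not read from this file: `ChatOpenAI` and `ChatGoogleGenerativeAI` pick up the `OPENAI_API_KEY` and `GOOGLE_API_KEY` environment variables respectively, while `ChatBedrockConverse` resolves AWS credentials via the named boto3 profile.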