Preliminary support for Azure OpenAI models, plus "porn fix"

This version adds support for Azure OpenAI models. I'm not entirely happy that each LLM provider has its own set of parameters, and I'm investigating how to make them feel a little more unified, so this support may change in the future.
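
To illustrate the mismatch with a minimal sketch (the constructor parameters below are just the ones langchain exposes, as used in the diff; the endpoint is a placeholder):

    from langchain_openai import ChatOpenAI, AzureChatOpenAI

    # Plain OpenAI needs only a model name...
    openai_llm = ChatOpenAI(model="gpt-4o")

    # ...while Azure OpenAI also wants a deployment, an endpoint, and an API version.
    azure_llm = AzureChatOpenAI(
        azure_deployment="gpt-4o",
        azure_endpoint="<your endpoint url>",  # placeholder
        api_version="2025-01-01-preview",
        model="gpt-4o",
    )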

Also, Azure's content filter flags the "XXX-END-OF-SESSION-XXX" token as "sexual content", so I changed it to "YYY-END-OF-SESSION-YYY" instead. I feel so protected!
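
One possible follow-up (a sketch only, not part of this commit; the constant and helper names are hypothetical) is to name the sentinel once so the next rename touches a single line:

    # Hypothetical refactor: keep the end-of-session sentinel in one place.
    SESSION_END_TOKEN = "YYY-END-OF-SESSION-YYY"

    def is_end_of_session(content: str) -> bool:
        # The LLM emits this exact token when the simulated session should
        # end, per the check in handle_client() below.
        return content == SESSION_END_TOKEN
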
Author: David J. Bianco
Date: 2025-03-20 15:21:07 -04:00
parent e2e47c4e6c
commit a3c14bbf15
2 changed files with 17 additions and 3 deletions

@@ -30,6 +30,13 @@ server_version_string = OpenSSH_8.2p1 Ubuntu-4ubuntu0.3
 llm_provider = openai
 model_name = gpt-4o
+##### Azure OpenAI
+#llm_provider = azure
+#azure_deployment = gpt-4o
+#azure_api_version = 2025-01-01-preview
+#azure_endpoint = <your endpoint url>
+#model_name = gpt-4o
 ##### ollama llama3
 #llm_provider = ollama
 #model_name = llama3.3
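
Note that there's no API key entry here: langchain's AzureChatOpenAI reads it from the AZURE_OPENAI_API_KEY environment variable by default. A minimal standalone sketch of these settings in use (endpoint and key are placeholders):

    import os
    from langchain_openai import AzureChatOpenAI

    os.environ.setdefault("AZURE_OPENAI_API_KEY", "<your api key>")  # placeholder

    llm = AzureChatOpenAI(
        azure_deployment="gpt-4o",
        azure_endpoint="<your endpoint url>",  # placeholder
        api_version="2025-01-01-preview",
        model="gpt-4o",
    )
    print(llm.invoke("hello").content)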


@@ -15,7 +15,7 @@ import datetime
 import uuid
 from base64 import b64encode
 from operator import itemgetter
-from langchain_openai import ChatOpenAI
+from langchain_openai import ChatOpenAI, AzureChatOpenAI
 from langchain_aws import ChatBedrock, ChatBedrockConverse
 from langchain_google_genai import ChatGoogleGenerativeAI
 from langchain_ollama import ChatOllama
@@ -234,7 +234,7 @@ async def handle_client(process: asyncssh.SSHServerProcess, server: MySSHServer):
             },
             config=llm_config
         )
-        if llm_response.content == "XXX-END-OF-SESSION-XXX":
+        if llm_response.content == "YYY-END-OF-SESSION-YYY":
             await session_summary(process, llm_config, with_message_history, server)
             process.exit(0)
             return
@@ -314,8 +314,15 @@ def choose_llm(llm_provider: Optional[str] = None, model_name: Optional[str] = None):
         llm_model = ChatOpenAI(
             model=model_name
         )
+    elif llm_provider_name == 'azure':
+        llm_model = AzureChatOpenAI(
+            azure_deployment=config['llm'].get("azure_deployment"),
+            azure_endpoint=config['llm'].get("azure_endpoint"),
+            api_version=config['llm'].get("azure_api_version"),
+            model=config['llm'].get("model_name") # Ensure model_name is passed here
+        )
     elif llm_provider_name == 'ollama':
-        llm_model = ChatOllama(
+        llm_model = ChatOllama(
             model=model_name
         )
     elif llm_provider_name == 'aws':
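
As for the unification mentioned above, one possible direction (purely a sketch, not in this commit; the registry and helper names are hypothetical) would be a table mapping each provider's config keys to its constructor kwargs, so the if/elif chain in choose_llm() collapses into a lookup:

    from langchain_openai import ChatOpenAI, AzureChatOpenAI
    from langchain_ollama import ChatOllama

    # Hypothetical registry: provider name -> (constructor, config-key -> kwarg).
    PROVIDERS = {
        "openai": (ChatOpenAI, {"model_name": "model"}),
        "azure": (AzureChatOpenAI, {
            "model_name": "model",
            "azure_deployment": "azure_deployment",
            "azure_endpoint": "azure_endpoint",
            "azure_api_version": "api_version",
        }),
        "ollama": (ChatOllama, {"model_name": "model"}),
    }

    def build_llm(llm_section: dict):
        # llm_section stands in for config['llm']; only keys the provider
        # actually declares are passed through to its constructor.
        cls, key_map = PROVIDERS[llm_section["llm_provider"].lower()]
        kwargs = {kwarg: llm_section[key]
                  for key, kwarg in key_map.items() if key in llm_section}
        return cls(**kwargs)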