Added 'temperature' parameter to control randomness in LLM responses.

Lower temps mean less randomness in the  responses, which increase the chances of consistency between sessions. Not a guarantee, though.
This commit is contained in:
David J. Bianco
2025-04-25 09:12:40 -04:00
parent a3c14bbf15
commit 10e2f11599
2 changed files with 23 additions and 5 deletions

View File

@ -53,6 +53,12 @@ model_name = gpt-4o
#llm_provider = gemini
#model_name = gemini-1.5-pro
# Temperature controls randomness in LLM responses. Values usually range from 0.0 to 2.0.
# Lower values (e.g., 0.2) make responses more focused and deterministic.
# Higher values (e.g., 0.8) make responses more creative and variable.
# Default is 0.2.
temperature = 0.2
# The maximum number of tokens to send to the LLM backend in a single
# request. This includes the message history for the session, so should
# be fairly high. Not all models support large token counts, so be sure

View File

@ -310,29 +310,38 @@ def choose_llm(llm_provider: Optional[str] = None, model_name: Optional[str] = N
llm_provider_name = llm_provider_name.lower()
model_name = model_name or config['llm'].get("model_name", "gpt-3.5-turbo")
# Get temperature parameter from config, default to 0.7 if not specified
temperature = config['llm'].getfloat("temperature", 0.7)
if llm_provider_name == 'openai':
llm_model = ChatOpenAI(
model=model_name
model=model_name,
temperature=temperature
)
elif llm_provider_name == 'azure':
llm_model = AzureChatOpenAI(
azure_deployment=config['llm'].get("azure_deployment"),
azure_endpoint=config['llm'].get("azure_endpoint"),
api_version=config['llm'].get("azure_api_version"),
model=config['llm'].get("model_name") # Ensure model_name is passed here
model=config['llm'].get("model_name"), # Ensure model_name is passed here
temperature=temperature
)
elif llm_provider_name == 'ollama':
llm_model = ChatOllama(
model=model_name
model=model_name,
temperature=temperature
)
elif llm_provider_name == 'aws':
llm_model = ChatBedrockConverse(
model=model_name,
region_name=config['llm'].get("aws_region", "us-east-1"),
credentials_profile_name=config['llm'].get("aws_credentials_profile", "default") )
credentials_profile_name=config['llm'].get("aws_credentials_profile", "default"),
temperature=temperature
)
elif llm_provider_name == 'gemini':
llm_model = ChatGoogleGenerativeAI(
model=model_name,
temperature=temperature
)
else:
raise ValueError(f"Invalid LLM provider {llm_provider_name}.")
@ -374,6 +383,7 @@ try:
parser.add_argument('-m', '--model-name', type=str, help='The model name to use')
parser.add_argument('-t', '--trimmer-max-tokens', type=int, help='The maximum number of tokens to send to the LLM backend in a single request')
parser.add_argument('-s', '--system-prompt', type=str, help='System prompt for the LLM')
parser.add_argument('-r', '--temperature', type=float, help='Temperature parameter for controlling randomness in LLM responses (0.0-2.0)')
parser.add_argument('-P', '--port', type=int, help='The port the SSH honeypot will listen on')
parser.add_argument('-k', '--host-priv-key', type=str, help='The host key to use for the SSH server')
parser.add_argument('-v', '--server-version-string', type=str, help='The server version string to send to clients')
@ -398,7 +408,7 @@ try:
# Use defaults when no config file found.
config['honeypot'] = {'log_file': 'ssh_log.log', 'sensor_name': socket.gethostname()}
config['ssh'] = {'port': '8022', 'host_priv_key': 'ssh_host_key', 'server_version_string': 'SSH-2.0-OpenSSH_8.2p1 Ubuntu-4ubuntu0.3'}
config['llm'] = {'llm_provider': 'openai', 'model_name': 'gpt-3.5-turbo', 'trimmer_max_tokens': '64000', 'system_prompt': ''}
config['llm'] = {'llm_provider': 'openai', 'model_name': 'gpt-3.5-turbo', 'trimmer_max_tokens': '64000', 'temperature': '0.7', 'system_prompt': ''}
config['user_accounts'] = {}
# Override config values with command line arguments if provided
@ -410,6 +420,8 @@ try:
config['llm']['trimmer_max_tokens'] = str(args.trimmer_max_tokens)
if args.system_prompt:
config['llm']['system_prompt'] = args.system_prompt
if args.temperature is not None:
config['llm']['temperature'] = str(args.temperature)
if args.port:
config['ssh']['port'] = str(args.port)
if args.host_priv_key: