From 8703d1afdaac70086aee9ff9579cc9bfd2b22401 Mon Sep 17 00:00:00 2001 From: Mario Candela Date: Sun, 16 Feb 2025 16:27:10 +0100 Subject: [PATCH] Fix: llm plugin OpenAI generates random plaintext (hallucinations) (#163) * Add tests to adopt TDD. * Fix bug, LLM hallucinations --- plugins/llm-integration.go | 13 +++-- plugins/llm-integration_test.go | 90 +++++++++++++++++++++++++++++++++ 2 files changed, 99 insertions(+), 4 deletions(-) diff --git a/plugins/llm-integration.go b/plugins/llm-integration.go index 99ee29b..712825c 100644 --- a/plugins/llm-integration.go +++ b/plugins/llm-integration.go @@ -6,12 +6,12 @@ import ( "fmt" "github.com/go-resty/resty/v2" "github.com/mariocandela/beelzebub/v3/tracer" - log "github.com/sirupsen/logrus" + "regexp" ) const ( - systemPromptVirtualizeLinuxTerminal = "You will act as an Ubuntu Linux terminal. The user will type commands, and you are to reply with what the terminal should show. Your responses must be contained within a single code block. Do not provide explanations or type commands unless explicitly instructed by the user. Your entire response/output is going to consist of a simple text with \n for new line, and you will NOT wrap it within string md markers" + systemPromptVirtualizeLinuxTerminal = "You will act as an Ubuntu Linux terminal. The user will type commands, and you are to reply with what the terminal should show. Your responses must be contained within a single code block. Do not provide note. Do not provide explanations or type commands unless explicitly instructed by the user. Your entire response/output is going to consist of a simple text with \n for new line, and you will NOT wrap it within string md markers" systemPromptVirtualizeHTTPServer = "You will act as an unsecure HTTP Server with multiple vulnerability like aws and git credentials stored into root http directory. The user will send HTTP requests, and you are to reply with what the server should show. Do not provide explanations or type commands unless explicitly instructed by the user." LLMPluginName = "LLMHoneypot" openAIGPTEndpoint = "https://api.openai.com/v1/chat/completions" @@ -185,7 +185,7 @@ func (llmHoneypot *LLMHoneypot) openAICaller(messages []Message) (string, error) return "", errors.New("no choices") } - return response.Result().(*Response).Choices[0].Message.Content, nil + return removeQuotes(response.Result().(*Response).Choices[0].Message.Content), nil } func (llmHoneypot *LLMHoneypot) ollamaCaller(messages []Message) (string, error) { @@ -216,7 +216,7 @@ func (llmHoneypot *LLMHoneypot) ollamaCaller(messages []Message) (string, error) } log.Debug(response) - return response.Result().(*Response).Message.Content, nil + return removeQuotes(response.Result().(*Response).Message.Content), nil } func (llmHoneypot *LLMHoneypot) ExecuteModel(command string) (string, error) { @@ -238,3 +238,8 @@ func (llmHoneypot *LLMHoneypot) ExecuteModel(command string) (string, error) { return "", errors.New("no model selected") } } + +func removeQuotes(content string) string { + regex := regexp.MustCompile("(```( *)?([a-z]*)?(\\n)?)") + return regex.ReplaceAllString(content, "") +} diff --git a/plugins/llm-integration_test.go b/plugins/llm-integration_test.go index 15e2516..332ee2f 100644 --- a/plugins/llm-integration_test.go +++ b/plugins/llm-integration_test.go @@ -379,3 +379,93 @@ func TestFromString(t *testing.T) { model, err = FromStringToLLMModel("beelzebub-model") assert.Errorf(t, err, "model beelzebub-model not found") } + +func TestBuildExecuteModelSSHWithoutPlaintextSection(t *testing.T) { + client := resty.New() + httpmock.ActivateNonDefault(client.GetClient()) + defer httpmock.DeactivateAndReset() + + // Given + httpmock.RegisterResponder("POST", ollamaEndpoint, + func(req *http.Request) (*http.Response, error) { + resp, err := httpmock.NewJsonResponse(200, &Response{ + Message: Message{ + Role: SYSTEM.String(), + Content: "```plaintext\n```\n", + }, + }) + if err != nil { + return httpmock.NewStringResponse(500, ""), nil + } + return resp, nil + }, + ) + + llmHoneypot := LLMHoneypot{ + Histories: make([]Message, 0), + Protocol: tracer.SSH, + Model: LLAMA3, + } + + openAIGPTVirtualTerminal := InitLLMHoneypot(llmHoneypot) + openAIGPTVirtualTerminal.client = client + + //When + str, err := openAIGPTVirtualTerminal.ExecuteModel("ls") + + //Then + assert.Nil(t, err) + assert.Equal(t, "", str) +} + +func TestBuildExecuteModelSSHWithoutQuotesSection(t *testing.T) { + client := resty.New() + httpmock.ActivateNonDefault(client.GetClient()) + defer httpmock.DeactivateAndReset() + + // Given + httpmock.RegisterResponder("POST", ollamaEndpoint, + func(req *http.Request) (*http.Response, error) { + resp, err := httpmock.NewJsonResponse(200, &Response{ + Message: Message{ + Role: SYSTEM.String(), + Content: "```\n```\n", + }, + }) + if err != nil { + return httpmock.NewStringResponse(500, ""), nil + } + return resp, nil + }, + ) + + llmHoneypot := LLMHoneypot{ + Histories: make([]Message, 0), + Protocol: tracer.SSH, + Model: LLAMA3, + } + + openAIGPTVirtualTerminal := InitLLMHoneypot(llmHoneypot) + openAIGPTVirtualTerminal.client = client + + //When + str, err := openAIGPTVirtualTerminal.ExecuteModel("ls") + + //Then + assert.Nil(t, err) + assert.Equal(t, "", str) +} + +func TestRemoveQuotes(t *testing.T) { + plaintext := "```plaintext\n```" + bash := "```bash\n```" + onlyQuotes := "```\n```" + complexText := "```plaintext\ntop - 10:30:48 up 1 day, 4:30, 2 users, load average: 0.15, 0.10, 0.08\nTasks: 198 total, 1 running, 197 sleeping, 0 stopped, 0 zombie\n```" + complexText2 := "```\ntop - 15:06:59 up 10 days, 3:17, 1 user, load average: 0.10, 0.09, 0.08\nTasks: 285 total\n```" + + assert.Equal(t, "", removeQuotes(plaintext)) + assert.Equal(t, "", removeQuotes(bash)) + assert.Equal(t, "", removeQuotes(onlyQuotes)) + assert.Equal(t, "top - 10:30:48 up 1 day, 4:30, 2 users, load average: 0.15, 0.10, 0.08\nTasks: 198 total, 1 running, 197 sleeping, 0 stopped, 0 zombie\n", removeQuotes(complexText)) + assert.Equal(t, "top - 15:06:59 up 10 days, 3:17, 1 user, load average: 0.10, 0.09, 0.08\nTasks: 285 total\n", removeQuotes(complexText2)) +}