Feat: Refactoring plugin:LLM honeypot custom prompt (#154)

refactoring LLM honeypot custom prompt
This commit is contained in:
Mario Candela
2025-01-16 08:46:13 +01:00
committed by GitHub
parent c3d2ff885d
commit 99c7287c02
2 changed files with 53 additions and 19 deletions

View File

@ -96,14 +96,19 @@ func InitLLMHoneypot(config LLMHoneypot) *LLMHoneypot {
return &config return &config
} }
func buildPrompt(histories []Message, protocol tracer.Protocol, command string) ([]Message, error) { func (llmHoneypot *LLMHoneypot) buildPrompt(command string) ([]Message, error) {
var messages []Message var messages []Message
var prompt string
switch protocol { switch llmHoneypot.Protocol {
case tracer.SSH: case tracer.SSH:
prompt = systemPromptVirtualizeLinuxTerminal
if llmHoneypot.CustomPrompt != "" {
prompt = llmHoneypot.CustomPrompt
}
messages = append(messages, Message{ messages = append(messages, Message{
Role: SYSTEM.String(), Role: SYSTEM.String(),
Content: systemPromptVirtualizeLinuxTerminal, Content: prompt,
}) })
messages = append(messages, Message{ messages = append(messages, Message{
Role: USER.String(), Role: USER.String(),
@ -113,13 +118,17 @@ func buildPrompt(histories []Message, protocol tracer.Protocol, command string)
Role: ASSISTANT.String(), Role: ASSISTANT.String(),
Content: "/home/user", Content: "/home/user",
}) })
for _, history := range histories { for _, history := range llmHoneypot.Histories {
messages = append(messages, history) messages = append(messages, history)
} }
case tracer.HTTP: case tracer.HTTP:
prompt = systemPromptVirtualizeHTTPServer
if llmHoneypot.CustomPrompt != "" {
prompt = llmHoneypot.CustomPrompt
}
messages = append(messages, Message{ messages = append(messages, Message{
Role: SYSTEM.String(), Role: SYSTEM.String(),
Content: systemPromptVirtualizeHTTPServer, Content: prompt,
}) })
messages = append(messages, Message{ messages = append(messages, Message{
Role: USER.String(), Role: USER.String(),
@ -214,18 +223,7 @@ func (llmHoneypot *LLMHoneypot) ExecuteModel(command string) (string, error) {
var err error var err error
var prompt []Message var prompt []Message
if llmHoneypot.CustomPrompt != "" { prompt, err = llmHoneypot.buildPrompt(command)
prompt = append(prompt, Message{
Role: SYSTEM.String(),
Content: llmHoneypot.CustomPrompt,
})
prompt = append(prompt, Message{
Role: USER.String(),
Content: command,
})
} else {
prompt, err = buildPrompt(llmHoneypot.Histories, llmHoneypot.Protocol, command)
}
if err != nil { if err != nil {
return "", err return "", err

View File

@ -16,8 +16,13 @@ func TestBuildPromptEmptyHistory(t *testing.T) {
var histories []Message var histories []Message
command := "pwd" command := "pwd"
honeypot := LLMHoneypot{
Histories: histories,
Protocol: tracer.SSH,
}
//When //When
prompt, err := buildPrompt(histories, tracer.SSH, command) prompt, err := honeypot.buildPrompt(command)
//Then //Then
assert.Nil(t, err) assert.Nil(t, err)
@ -35,14 +40,45 @@ func TestBuildPromptWithHistory(t *testing.T) {
command := "pwd" command := "pwd"
honeypot := LLMHoneypot{
Histories: histories,
Protocol: tracer.SSH,
}
//When //When
prompt, err := buildPrompt(histories, tracer.SSH, command) prompt, err := honeypot.buildPrompt(command)
//Then //Then
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, SystemPromptLen+1, len(prompt)) assert.Equal(t, SystemPromptLen+1, len(prompt))
} }
func TestBuildPromptWithCustomPrompt(t *testing.T) {
//Given
var histories = []Message{
{
Role: "cat hello.txt",
Content: "world",
},
}
command := "pwd"
honeypot := LLMHoneypot{
Histories: histories,
Protocol: tracer.SSH,
CustomPrompt: "act as calculator",
}
//When
prompt, err := honeypot.buildPrompt(command)
//Then
assert.Nil(t, err)
assert.Equal(t, prompt[0].Content, "act as calculator")
assert.Equal(t, prompt[0].Role, SYSTEM.String())
}
func TestBuildExecuteModelFailValidation(t *testing.T) { func TestBuildExecuteModelFailValidation(t *testing.T) {
llmHoneypot := LLMHoneypot{ llmHoneypot := LLMHoneypot{