diff --git a/README.md b/README.md index 69d5e6c..43e63e5 100644 --- a/README.md +++ b/README.md @@ -2,6 +2,7 @@ [![CI](https://github.com/mariocandela/beelzebub/actions/workflows/ci.yml/badge.svg)](https://github.com/mariocandela/beelzebub/actions/workflows/ci.yml) [![Docker](https://github.com/mariocandela/beelzebub/actions/workflows/docker-image.yml/badge.svg)](https://github.com/mariocandela/beelzebub/actions/workflows/docker-image.yml) [![codeql](https://github.com/mariocandela/beelzebub/actions/workflows/codeql.yml/badge.svg)](https://github.com/mariocandela/beelzebub/actions/workflows/codeql.yml) [![Go Report Card](https://goreportcard.com/badge/github.com/mariocandela/beelzebub)](https://goreportcard.com/report/github.com/mariocandela/beelzebub) +[![codecov](https://codecov.io/gh/mariocandela/beelzebub/graph/badge.svg?token=8XTK7D4WHE)](https://codecov.io/gh/mariocandela/beelzebub) ## Overview @@ -91,7 +92,9 @@ $ make test.unit To run integration tests: ```bash +$ make test.dependencies.start $ make test.integration +$ make test.dependencies.down ``` ## Key Features diff --git a/parser/configurations_parser.go b/parser/configurations_parser.go index d38c45b..07c7efa 100644 --- a/parser/configurations_parser.go +++ b/parser/configurations_parser.go @@ -2,7 +2,6 @@ package parser import ( "fmt" - "io/ioutil" "os" "path/filepath" "strings" @@ -127,7 +126,7 @@ func (bp configurationsParser) ReadConfigurationsServices() ([]BeelzebubServiceC } func gelAllFilesNameByDirName(dirName string) ([]string, error) { - files, err := ioutil.ReadDir(dirName) + files, err := os.ReadDir(dirName) if err != nil { return nil, err } diff --git a/parser/configurations_parser_test.go b/parser/configurations_parser_test.go index 368ca1b..3e150f4 100644 --- a/parser/configurations_parser_test.go +++ b/parser/configurations_parser_test.go @@ -2,6 +2,7 @@ package parser import ( "errors" + "os" "testing" "github.com/stretchr/testify/assert" @@ -118,3 +119,54 @@ func TestReadConfigurationsServicesValid(t *testing.T) { assert.Equal(t, len(firstBeelzebubServiceConfiguration.Commands[0].Headers), 1) assert.Equal(t, firstBeelzebubServiceConfiguration.Commands[0].Headers[0], "Content-Type: text/html") } + +func TestGelAllFilesNameByDirName(t *testing.T) { + + var dir = t.TempDir() + + files, err := gelAllFilesNameByDirName(dir) + + assert.Nil(t, err) + assert.Equal(t, 0, len(files)) +} + +func TestGelAllFilesNameByDirNameFiles(t *testing.T) { + + var dir = t.TempDir() + + testFiles := []string{"file1.yaml", "file2.yaml", "file3.txt", "subdir", "file4.yaml"} + for _, filename := range testFiles { + filePath := dir + "/" + filename + file, err := os.Create(filePath) + assert.NoError(t, err) + file.Close() + } + + files, err := gelAllFilesNameByDirName(dir) + + assert.Nil(t, err) + assert.Equal(t, 3, len(files)) +} + +func TestGelAllFilesNameByDirNameError(t *testing.T) { + + files, err := gelAllFilesNameByDirName("nosuchfile") + + assert.Nil(t, files) + assert.Equal(t, "open nosuchfile: no such file or directory", err.Error()) +} + +func TestReadFileBytesByFilePath(t *testing.T) { + + var dir = t.TempDir() + filePath := dir + "/test.yaml" + + f, err := os.Create(filePath) + assert.NoError(t, err) + f.Close() + + bytes, err := readFileBytesByFilePath(filePath) + assert.NoError(t, err) + + assert.Equal(t, "", string(bytes)) +} diff --git a/plugins/openai-gpt.go b/plugins/openai-gpt.go index b0a0983..e6f0564 100644 --- a/plugins/openai-gpt.go +++ b/plugins/openai-gpt.go @@ -5,33 +5,27 @@ import ( "errors" "fmt" "strings" - + log "github.com/sirupsen/logrus" - + "github.com/go-resty/resty/v2" ) const ( // Reference: https://www.engraved.blog/building-a-virtual-machine-inside/ promptVirtualizeLinuxTerminal = "I want you to act as a Linux terminal. I will type commands and you will reply with what the terminal should show. I want you to only reply with the terminal output inside one unique code block, and nothing else. Do no write explanations. Do not type commands unless I instruct you to do so.\n\nA:pwd\n\nQ:/home/user\n\n" - ChatGPTPluginName = "OpenAIGPTLinuxTerminal" - openAIGPTEndpoint = "https://api.openai.com/v1/completions" -) + ChatGPTPluginName = "OpenAIGPTLinuxTerminal" + openAIGPTEndpoint = "https://api.openai.com/v1/completions" +) type History struct { Input, Output string } - -type OpenAIGPTVirtualTerminal struct { - Histories []History - OpenAPIChatGPTSecretKey string - client *resty.Client -} -func (openAIGPTVirtualTerminal *OpenAIGPTVirtualTerminal) InjectDependency() { - if openAIGPTVirtualTerminal.client == nil { - openAIGPTVirtualTerminal.client = resty.New() - } +type openAIGPTVirtualTerminal struct { + Histories []History + openAIKey string + client *resty.Client } type Choice struct { @@ -65,6 +59,14 @@ type gptRequest struct { Stop []string `json:"stop"` } +func Init(history []History, openAIKey string) *openAIGPTVirtualTerminal { + return &openAIGPTVirtualTerminal{ + Histories: history, + openAIKey: openAIKey, + client: resty.New(), + } +} + func buildPrompt(histories []History, command string) string { var sb strings.Builder @@ -79,7 +81,7 @@ func buildPrompt(histories []History, command string) string { return sb.String() } -func (openAIGPTVirtualTerminal *OpenAIGPTVirtualTerminal) GetCompletions(command string) (string, error) { +func (openAIGPTVirtualTerminal *openAIGPTVirtualTerminal) GetCompletions(command string) (string, error) { requestJson, err := json.Marshal(gptRequest{ Model: "text-davinci-003", Prompt: buildPrompt(openAIGPTVirtualTerminal.Histories, command), @@ -94,14 +96,14 @@ func (openAIGPTVirtualTerminal *OpenAIGPTVirtualTerminal) GetCompletions(command return "", err } - if openAIGPTVirtualTerminal.OpenAPIChatGPTSecretKey == "" { - return "", errors.New("OpenAPIChatGPTSecretKey is empty") + if openAIGPTVirtualTerminal.openAIKey == "" { + return "", errors.New("openAIKey is empty") } response, err := openAIGPTVirtualTerminal.client.R(). SetHeader("Content-Type", "application/json"). SetBody(requestJson). - SetAuthToken(openAIGPTVirtualTerminal.OpenAPIChatGPTSecretKey). + SetAuthToken(openAIGPTVirtualTerminal.openAIKey). SetResult(&gptResponse{}). Post(openAIGPTEndpoint) diff --git a/plugins/openai-gpt_test.go b/plugins/openai-gpt_test.go index b41b864..d987edf 100644 --- a/plugins/openai-gpt_test.go +++ b/plugins/openai-gpt_test.go @@ -46,7 +46,15 @@ func TestBuildPromptWithHistory(t *testing.T) { prompt) } -func TestBuildGetCompletions(t *testing.T) { +func TestBuildGetCompletionsFailValidation(t *testing.T) { + openAIGPTVirtualTerminal := Init(make([]History, 0), "") + + _, err := openAIGPTVirtualTerminal.GetCompletions("test") + + assert.Equal(t, "openAIKey is empty", err.Error()) +} + +func TestBuildGetCompletionsWithResults(t *testing.T) { client := resty.New() httpmock.ActivateNonDefault(client.GetClient()) defer httpmock.DeactivateAndReset() @@ -68,10 +76,8 @@ func TestBuildGetCompletions(t *testing.T) { }, ) - openAIGPTVirtualTerminal := OpenAIGPTVirtualTerminal{ - OpenAPIChatGPTSecretKey: "sdjdnklfjndslkjanfk", - client: client, - } + openAIGPTVirtualTerminal := Init(make([]History, 0), "sdjdnklfjndslkjanfk") + openAIGPTVirtualTerminal.client = client //When str, err := openAIGPTVirtualTerminal.GetCompletions("ls") @@ -80,3 +86,31 @@ func TestBuildGetCompletions(t *testing.T) { assert.Nil(t, err) assert.Equal(t, "prova.txt", str) } + +func TestBuildGetCompletionsWithoutResults(t *testing.T) { + client := resty.New() + httpmock.ActivateNonDefault(client.GetClient()) + defer httpmock.DeactivateAndReset() + + // Given + httpmock.RegisterResponder("POST", openAIGPTEndpoint, + func(req *http.Request) (*http.Response, error) { + resp, err := httpmock.NewJsonResponse(200, &gptResponse{ + Choices: []Choice{}, + }) + if err != nil { + return httpmock.NewStringResponse(500, ""), nil + } + return resp, nil + }, + ) + + openAIGPTVirtualTerminal := Init(make([]History, 0), "sdjdnklfjndslkjanfk") + openAIGPTVirtualTerminal.client = client + + //When + _, err := openAIGPTVirtualTerminal.GetCompletions("ls") + + //Then + assert.Equal(t, "no choices", err.Error()) +} diff --git a/protocols/protocol_manager.go b/protocols/protocol_manager.go index 00044fe..f6c6525 100644 --- a/protocols/protocol_manager.go +++ b/protocols/protocol_manager.go @@ -16,7 +16,7 @@ type ProtocolManager struct { func InitProtocolManager(tracerStrategy tracer.Strategy, strategy ServiceStrategy) *ProtocolManager { return &ProtocolManager{ - tracer: tracer.Init(tracerStrategy), + tracer: tracer.GetInstance(tracerStrategy), strategy: strategy, } } diff --git a/protocols/strategies/ssh.go b/protocols/strategies/ssh.go index 590c8b7..779d43d 100644 --- a/protocols/strategies/ssh.go +++ b/protocols/strategies/ssh.go @@ -63,8 +63,7 @@ func (sshStrategy *SSHStrategy) Init(beelzebubServiceConfiguration parser.Beelze commandOutput := command.Handler if command.Plugin == plugins.ChatGPTPluginName { - openAIGPTVirtualTerminal := plugins.OpenAIGPTVirtualTerminal{Histories: histories, OpenAPIChatGPTSecretKey: beelzebubServiceConfiguration.Plugin.OpenAPIChatGPTSecretKey} - openAIGPTVirtualTerminal.InjectDependency() + openAIGPTVirtualTerminal := plugins.Init(histories, beelzebubServiceConfiguration.Plugin.OpenAPIChatGPTSecretKey) if commandOutput, err = openAIGPTVirtualTerminal.GetCompletions(commandInput); err != nil { log.Errorf("Error GetCompletions: %s, %s", commandInput, err.Error()) @@ -125,7 +124,7 @@ func (sshStrategy *SSHStrategy) Init(beelzebubServiceConfiguration parser.Beelze log.WithFields(log.Fields{ "port": beelzebubServiceConfiguration.Address, "commands": len(beelzebubServiceConfiguration.Commands), - }).Infof("Init service %s", beelzebubServiceConfiguration.Protocol) + }).Infof("GetInstance service %s", beelzebubServiceConfiguration.Protocol) return nil } diff --git a/tracer/tracer.go b/tracer/tracer.go index 8dda6fc..e511bd7 100644 --- a/tracer/tracer.go +++ b/tracer/tracer.go @@ -1,11 +1,12 @@ package tracer import ( + log "github.com/sirupsen/logrus" + "sync" "time" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promauto" - log "github.com/sirupsen/logrus" ) const Workers = 5 @@ -44,8 +45,8 @@ const ( TCP ) -func (status Protocol) String() string { - return [...]string{"HTTP", "SSH", "TCP"}[status] +func (protocol Protocol) String() string { + return [...]string{"HTTP", "SSH", "TCP"}[protocol] } const ( @@ -66,49 +67,60 @@ type Tracer interface { } type tracer struct { - strategy Strategy - eventsChan chan Event + strategy Strategy + eventsChan chan Event + eventsTotal prometheus.Counter + eventsSSHTotal prometheus.Counter + eventsTCPTotal prometheus.Counter + eventsHTTPTotal prometheus.Counter } -var ( - eventsTotal = promauto.NewCounter(prometheus.CounterOpts{ - Namespace: "beelzebub", - Name: "events_total", - Help: "The total number of events", - }) - eventsSSHTotal = promauto.NewCounter(prometheus.CounterOpts{ - Namespace: "beelzebub", - Name: "ssh_events_total", - Help: "The total number of SSH events", - }) - eventsTCPTotal = promauto.NewCounter(prometheus.CounterOpts{ - Namespace: "beelzebub", - Name: "tcp_events_total", - Help: "The total number of TCP events", - }) - eventsHTTPTotal = promauto.NewCounter(prometheus.CounterOpts{ - Namespace: "beelzebub", - Name: "http_events_total", - Help: "The total number of HTTP events", - }) -) +var lock = &sync.Mutex{} +var singleton *tracer -func Init(strategy Strategy) *tracer { - tracer := &tracer{ - strategy: strategy, - eventsChan: make(chan Event, Workers), - } - - for i := 0; i < Workers; i++ { - go func(i int) { - log.Debug("Init trace worker: ", i) - for event := range tracer.eventsChan { - tracer.strategy(event) +func GetInstance(strategy Strategy) *tracer { + if singleton == nil { + lock.Lock() + defer lock.Unlock() + // This is to prevent expensive lock operations every time the GetInstance method is called + if singleton == nil { + singleton = &tracer{ + strategy: strategy, + eventsChan: make(chan Event, Workers), + eventsTotal: promauto.NewCounter(prometheus.CounterOpts{ + Namespace: "beelzebub", + Name: "events_total", + Help: "The total number of events", + }), + eventsSSHTotal: promauto.NewCounter(prometheus.CounterOpts{ + Namespace: "beelzebub", + Name: "ssh_events_total", + Help: "The total number of SSH events", + }), + eventsTCPTotal: promauto.NewCounter(prometheus.CounterOpts{ + Namespace: "beelzebub", + Name: "tcp_events_total", + Help: "The total number of TCP events", + }), + eventsHTTPTotal: promauto.NewCounter(prometheus.CounterOpts{ + Namespace: "beelzebub", + Name: "http_events_total", + Help: "The total number of HTTP events", + }), } - }(i) + + for i := 0; i < Workers; i++ { + go func(i int) { + log.Debug("GetInstance trace worker: ", i) + for event := range singleton.eventsChan { + singleton.strategy(event) + } + }(i) + } + } } - return tracer + return singleton } func (tracer *tracer) setStrategy(strategy Strategy) { @@ -120,14 +132,17 @@ func (tracer *tracer) TraceEvent(event Event) { tracer.eventsChan <- event - eventsTotal.Inc() - - switch event.Protocol { - case HTTP.String(): - eventsHTTPTotal.Inc() - case SSH.String(): - eventsSSHTotal.Inc() - case TCP.String(): - eventsTCPTotal.Inc() - } + tracer.updatePrometheusCounters(event.Protocol) +} + +func (tracer *tracer) updatePrometheusCounters(protocol string) { + switch protocol { + case HTTP.String(): + tracer.eventsHTTPTotal.Inc() + case SSH.String(): + tracer.eventsSSHTotal.Inc() + case TCP.String(): + tracer.eventsTCPTotal.Inc() + } + tracer.eventsTotal.Inc() } diff --git a/tracer/tracer_test.go b/tracer/tracer_test.go index dd03339..2abf99a 100644 --- a/tracer/tracer_test.go +++ b/tracer/tracer_test.go @@ -1,6 +1,7 @@ package tracer import ( + "github.com/prometheus/client_golang/prometheus" "sync" "testing" @@ -10,7 +11,7 @@ import ( func TestInit(t *testing.T) { mockStrategy := func(event Event) {} - tracer := Init(mockStrategy) + tracer := GetInstance(mockStrategy) assert.NotNil(t, tracer.strategy) } @@ -25,7 +26,9 @@ func TestTraceEvent(t *testing.T) { eventCalled = event } - tracer := Init(mockStrategy) + tracer := GetInstance(mockStrategy) + + tracer.strategy = mockStrategy wg.Add(1) tracer.TraceEvent(Event{ @@ -51,7 +54,7 @@ func TestSetStrategy(t *testing.T) { eventCalled = event } - tracer := Init(mockStrategy) + tracer := GetInstance(mockStrategy) tracer.setStrategy(mockStrategy) @@ -75,3 +78,42 @@ func TestStringStatus(t *testing.T) { assert.Equal(t, Stateless.String(), "Stateless") assert.Equal(t, Interaction.String(), "Interaction") } + +type mockCounter struct { + prometheus.Metric + prometheus.Collector + inc func() + add func(float64) +} + +var counter = 0 + +func (m mockCounter) Inc() { + counter += 1 +} + +func (m mockCounter) Add(f float64) { + counter = int(f) +} + +func TestUpdatePrometheusCounters(t *testing.T) { + mockStrategy := func(event Event) {} + + tracer := &tracer{ + strategy: mockStrategy, + eventsChan: make(chan Event, Workers), + eventsTotal: mockCounter{}, + eventsSSHTotal: mockCounter{}, + eventsTCPTotal: mockCounter{}, + eventsHTTPTotal: mockCounter{}, + } + + tracer.updatePrometheusCounters(SSH.String()) + assert.Equal(t, 2, counter) + + tracer.updatePrometheusCounters(HTTP.String()) + assert.Equal(t, 4, counter) + + tracer.updatePrometheusCounters(TCP.String()) + assert.Equal(t, 6, counter) +}