mirror of
https://github.com/mariocandela/beelzebub.git
synced 2025-07-01 18:47:26 -04:00
Feat: Add FallbackCommand for HTTP Strategy, refactor packages strategies (#175)
Add FallbackCommand for HTTP Strategy, refactor packages strategies, improve histories implementations.
This commit is contained in:
242
protocols/strategies/SSH/ssh.go
Normal file
242
protocols/strategies/SSH/ssh.go
Normal file
@ -0,0 +1,242 @@
|
||||
package SSH
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/mariocandela/beelzebub/v3/historystore"
|
||||
"github.com/mariocandela/beelzebub/v3/parser"
|
||||
"github.com/mariocandela/beelzebub/v3/plugins"
|
||||
"github.com/mariocandela/beelzebub/v3/tracer"
|
||||
|
||||
"github.com/gliderlabs/ssh"
|
||||
"github.com/google/uuid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/term"
|
||||
)
|
||||
|
||||
type SSHStrategy struct {
|
||||
Sessions *historystore.HistoryStore
|
||||
}
|
||||
|
||||
func (sshStrategy *SSHStrategy) Init(servConf parser.BeelzebubServiceConfiguration, tr tracer.Tracer) error {
|
||||
if sshStrategy.Sessions == nil {
|
||||
sshStrategy.Sessions = historystore.NewHistoryStore()
|
||||
}
|
||||
go func() {
|
||||
server := &ssh.Server{
|
||||
Addr: servConf.Address,
|
||||
MaxTimeout: time.Duration(servConf.DeadlineTimeoutSeconds) * time.Second,
|
||||
IdleTimeout: time.Duration(servConf.DeadlineTimeoutSeconds) * time.Second,
|
||||
Version: servConf.ServerVersion,
|
||||
Handler: func(sess ssh.Session) {
|
||||
uuidSession := uuid.New()
|
||||
|
||||
host, port, _ := net.SplitHostPort(sess.RemoteAddr().String())
|
||||
sessionKey := "SSH" + host + sess.User()
|
||||
|
||||
// Inline SSH command
|
||||
if sess.RawCommand() != "" {
|
||||
for _, command := range servConf.Commands {
|
||||
matched, err := regexp.MatchString(command.Regex, sess.RawCommand())
|
||||
if err != nil {
|
||||
log.Errorf("error regex: %s, %s", command.Regex, err.Error())
|
||||
continue
|
||||
}
|
||||
|
||||
if matched {
|
||||
commandOutput := command.Handler
|
||||
if command.Plugin == plugins.LLMPluginName {
|
||||
llmProvider, err := plugins.FromStringToLLMProvider(servConf.Plugin.LLMProvider)
|
||||
if err != nil {
|
||||
log.Errorf("error: %s", err.Error())
|
||||
commandOutput = "command not found"
|
||||
llmProvider = plugins.OpenAI
|
||||
}
|
||||
|
||||
var histories []plugins.Message
|
||||
if sshStrategy.Sessions.HasKey(sessionKey) {
|
||||
histories = sshStrategy.Sessions.Query(sessionKey)
|
||||
}
|
||||
llmHoneypot := plugins.LLMHoneypot{
|
||||
Histories: histories,
|
||||
OpenAIKey: servConf.Plugin.OpenAISecretKey,
|
||||
Protocol: tracer.SSH,
|
||||
Host: servConf.Plugin.Host,
|
||||
Model: servConf.Plugin.LLMModel,
|
||||
Provider: llmProvider,
|
||||
CustomPrompt: servConf.Plugin.Prompt,
|
||||
}
|
||||
llmHoneypotInstance := plugins.InitLLMHoneypot(llmHoneypot)
|
||||
if commandOutput, err = llmHoneypotInstance.ExecuteModel(sess.RawCommand()); err != nil {
|
||||
log.Errorf("error ExecuteModel: %s, %s", sess.RawCommand(), err.Error())
|
||||
commandOutput = "command not found"
|
||||
}
|
||||
}
|
||||
|
||||
sess.Write(append([]byte(commandOutput), '\n'))
|
||||
|
||||
tr.TraceEvent(tracer.Event{
|
||||
Msg: "New SSH Raw Command Session",
|
||||
Protocol: tracer.SSH.String(),
|
||||
RemoteAddr: sess.RemoteAddr().String(),
|
||||
SourceIp: host,
|
||||
SourcePort: port,
|
||||
Status: tracer.Start.String(),
|
||||
ID: uuidSession.String(),
|
||||
Environ: strings.Join(sess.Environ(), ","),
|
||||
User: sess.User(),
|
||||
Description: servConf.Description,
|
||||
Command: sess.RawCommand(),
|
||||
CommandOutput: commandOutput,
|
||||
})
|
||||
|
||||
var histories []plugins.Message
|
||||
if sshStrategy.Sessions.HasKey(sessionKey) {
|
||||
histories = sshStrategy.Sessions.Query(sessionKey)
|
||||
}
|
||||
histories = append(histories, plugins.Message{Role: plugins.USER.String(), Content: sess.RawCommand()})
|
||||
histories = append(histories, plugins.Message{Role: plugins.ASSISTANT.String(), Content: commandOutput})
|
||||
sshStrategy.Sessions.Append(sessionKey, histories...)
|
||||
tr.TraceEvent(tracer.Event{
|
||||
Msg: "End SSH Raw Command Session",
|
||||
Status: tracer.End.String(),
|
||||
ID: uuidSession.String(),
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tr.TraceEvent(tracer.Event{
|
||||
Msg: "New SSH Terminal Session",
|
||||
Protocol: tracer.SSH.String(),
|
||||
RemoteAddr: sess.RemoteAddr().String(),
|
||||
SourceIp: host,
|
||||
SourcePort: port,
|
||||
Status: tracer.Start.String(),
|
||||
ID: uuidSession.String(),
|
||||
Environ: strings.Join(sess.Environ(), ","),
|
||||
User: sess.User(),
|
||||
Description: servConf.Description,
|
||||
})
|
||||
|
||||
terminal := term.NewTerminal(sess, buildPrompt(sess.User(), servConf.ServerName))
|
||||
var histories []plugins.Message
|
||||
if sshStrategy.Sessions.HasKey(sessionKey) {
|
||||
histories = sshStrategy.Sessions.Query(sessionKey)
|
||||
}
|
||||
|
||||
for {
|
||||
commandInput, err := terminal.ReadLine()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
if commandInput == "exit" {
|
||||
break
|
||||
}
|
||||
for _, command := range servConf.Commands {
|
||||
matched, err := regexp.MatchString(command.Regex, commandInput)
|
||||
if err != nil {
|
||||
log.Errorf("error regex: %s, %s", command.Regex, err.Error())
|
||||
continue
|
||||
}
|
||||
|
||||
if matched {
|
||||
commandOutput := command.Handler
|
||||
if command.Plugin == plugins.LLMPluginName {
|
||||
llmProvider, err := plugins.FromStringToLLMProvider(servConf.Plugin.LLMProvider)
|
||||
if err != nil {
|
||||
log.Errorf("error: %s, fallback OpenAI", err.Error())
|
||||
llmProvider = plugins.OpenAI
|
||||
}
|
||||
llmHoneypot := plugins.LLMHoneypot{
|
||||
Histories: histories,
|
||||
OpenAIKey: servConf.Plugin.OpenAISecretKey,
|
||||
Protocol: tracer.SSH,
|
||||
Host: servConf.Plugin.Host,
|
||||
Model: servConf.Plugin.LLMModel,
|
||||
Provider: llmProvider,
|
||||
CustomPrompt: servConf.Plugin.Prompt,
|
||||
}
|
||||
llmHoneypotInstance := plugins.InitLLMHoneypot(llmHoneypot)
|
||||
if commandOutput, err = llmHoneypotInstance.ExecuteModel(commandInput); err != nil {
|
||||
log.Errorf("error ExecuteModel: %s, %s", commandInput, err.Error())
|
||||
commandOutput = "command not found"
|
||||
}
|
||||
}
|
||||
|
||||
histories = append(histories, plugins.Message{Role: plugins.USER.String(), Content: commandInput})
|
||||
histories = append(histories, plugins.Message{Role: plugins.ASSISTANT.String(), Content: commandOutput})
|
||||
|
||||
terminal.Write(append([]byte(commandOutput), '\n'))
|
||||
|
||||
tr.TraceEvent(tracer.Event{
|
||||
Msg: "New SSH Terminal Session",
|
||||
RemoteAddr: sess.RemoteAddr().String(),
|
||||
SourceIp: host,
|
||||
SourcePort: port,
|
||||
Status: tracer.Interaction.String(),
|
||||
Command: commandInput,
|
||||
CommandOutput: commandOutput,
|
||||
ID: uuidSession.String(),
|
||||
Protocol: tracer.SSH.String(),
|
||||
Description: servConf.Description,
|
||||
})
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Add all history events for the terminal session to the store.
|
||||
// This is done at the end of the session to avoid excess lock operations.
|
||||
sshStrategy.Sessions.Append(sessionKey, histories...)
|
||||
tr.TraceEvent(tracer.Event{
|
||||
Msg: "End SSH Session",
|
||||
Status: tracer.End.String(),
|
||||
ID: uuidSession.String(),
|
||||
})
|
||||
},
|
||||
PasswordHandler: func(ctx ssh.Context, password string) bool {
|
||||
host, port, _ := net.SplitHostPort(ctx.RemoteAddr().String())
|
||||
|
||||
tr.TraceEvent(tracer.Event{
|
||||
Msg: "New SSH Login Attempt",
|
||||
Protocol: tracer.SSH.String(),
|
||||
Status: tracer.Stateless.String(),
|
||||
User: ctx.User(),
|
||||
Password: password,
|
||||
Client: ctx.ClientVersion(),
|
||||
RemoteAddr: ctx.RemoteAddr().String(),
|
||||
SourceIp: host,
|
||||
SourcePort: port,
|
||||
ID: uuid.New().String(),
|
||||
Description: servConf.Description,
|
||||
})
|
||||
matched, err := regexp.MatchString(servConf.PasswordRegex, password)
|
||||
if err != nil {
|
||||
log.Errorf("error regex: %s, %s", servConf.PasswordRegex, err.Error())
|
||||
return false
|
||||
}
|
||||
return matched
|
||||
},
|
||||
}
|
||||
err := server.ListenAndServe()
|
||||
if err != nil {
|
||||
log.Errorf("error during init SSH Protocol: %s", err.Error())
|
||||
}
|
||||
}()
|
||||
|
||||
log.WithFields(log.Fields{
|
||||
"port": servConf.Address,
|
||||
"commands": len(servConf.Commands),
|
||||
}).Infof("GetInstance service %s", servConf.Protocol)
|
||||
return nil
|
||||
}
|
||||
|
||||
func buildPrompt(user string, serverName string) string {
|
||||
return fmt.Sprintf("%s@%s:~$ ", user, serverName)
|
||||
}
|
Reference in New Issue
Block a user