mirror of
https://github.com/dstotijn/hetty.git
synced 2025-07-01 18:47:29 -04:00
Add request filter for intercept
This commit is contained in:
229
pkg/proxy/intercept/filter.go
Normal file
229
pkg/proxy/intercept/filter.go
Normal file
@ -0,0 +1,229 @@
|
||||
package intercept
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/dstotijn/hetty/pkg/scope"
|
||||
"github.com/dstotijn/hetty/pkg/search"
|
||||
)
|
||||
|
||||
var reqFilterKeyFns = map[string]func(req *http.Request) (string, error){
|
||||
"proto": func(req *http.Request) (string, error) { return req.Proto, nil },
|
||||
"url": func(req *http.Request) (string, error) {
|
||||
if req.URL == nil {
|
||||
return "", nil
|
||||
}
|
||||
return req.URL.String(), nil
|
||||
},
|
||||
"method": func(req *http.Request) (string, error) { return req.Method, nil },
|
||||
"body": func(req *http.Request) (string, error) {
|
||||
if req.Body == nil {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(req.Body)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
req.Body = ioutil.NopCloser(bytes.NewBuffer(body))
|
||||
return string(body), nil
|
||||
},
|
||||
}
|
||||
|
||||
// MatchRequestFilter returns true if an HTTP request matches the request filter expression.
|
||||
func MatchRequestFilter(req *http.Request, expr search.Expression) (bool, error) {
|
||||
switch e := expr.(type) {
|
||||
case search.PrefixExpression:
|
||||
return matchReqPrefixExpr(req, e)
|
||||
case search.InfixExpression:
|
||||
return matchReqInfixExpr(req, e)
|
||||
case search.StringLiteral:
|
||||
return matchReqStringLiteral(req, e)
|
||||
default:
|
||||
return false, fmt.Errorf("expression type (%T) not supported", expr)
|
||||
}
|
||||
}
|
||||
|
||||
func matchReqPrefixExpr(req *http.Request, expr search.PrefixExpression) (bool, error) {
|
||||
switch expr.Operator {
|
||||
case search.TokOpNot:
|
||||
match, err := MatchRequestFilter(req, expr.Right)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return !match, nil
|
||||
default:
|
||||
return false, errors.New("operator is not supported")
|
||||
}
|
||||
}
|
||||
|
||||
func matchReqInfixExpr(req *http.Request, expr search.InfixExpression) (bool, error) {
|
||||
switch expr.Operator {
|
||||
case search.TokOpAnd:
|
||||
left, err := MatchRequestFilter(req, expr.Left)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
right, err := MatchRequestFilter(req, expr.Right)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return left && right, nil
|
||||
case search.TokOpOr:
|
||||
left, err := MatchRequestFilter(req, expr.Left)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
right, err := MatchRequestFilter(req, expr.Right)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return left || right, nil
|
||||
}
|
||||
|
||||
left, ok := expr.Left.(search.StringLiteral)
|
||||
if !ok {
|
||||
return false, errors.New("left operand must be a string literal")
|
||||
}
|
||||
|
||||
leftVal, err := getMappedStringLiteralFromReq(req, left.Value)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to get string literal from request for left operand: %w", err)
|
||||
}
|
||||
|
||||
if expr.Operator == search.TokOpRe || expr.Operator == search.TokOpNotRe {
|
||||
right, ok := expr.Right.(search.RegexpLiteral)
|
||||
if !ok {
|
||||
return false, errors.New("right operand must be a regular expression")
|
||||
}
|
||||
|
||||
switch expr.Operator {
|
||||
case search.TokOpRe:
|
||||
return right.MatchString(leftVal), nil
|
||||
case search.TokOpNotRe:
|
||||
return !right.MatchString(leftVal), nil
|
||||
}
|
||||
}
|
||||
|
||||
right, ok := expr.Right.(search.StringLiteral)
|
||||
if !ok {
|
||||
return false, errors.New("right operand must be a string literal")
|
||||
}
|
||||
|
||||
rightVal, err := getMappedStringLiteralFromReq(req, right.Value)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to get string literal from request for right operand: %w", err)
|
||||
}
|
||||
|
||||
switch expr.Operator {
|
||||
case search.TokOpEq:
|
||||
return leftVal == rightVal, nil
|
||||
case search.TokOpNotEq:
|
||||
return leftVal != rightVal, nil
|
||||
case search.TokOpGt:
|
||||
// TODO(?) attempt to parse as int.
|
||||
return leftVal > rightVal, nil
|
||||
case search.TokOpLt:
|
||||
// TODO(?) attempt to parse as int.
|
||||
return leftVal < rightVal, nil
|
||||
case search.TokOpGtEq:
|
||||
// TODO(?) attempt to parse as int.
|
||||
return leftVal >= rightVal, nil
|
||||
case search.TokOpLtEq:
|
||||
// TODO(?) attempt to parse as int.
|
||||
return leftVal <= rightVal, nil
|
||||
default:
|
||||
return false, errors.New("unsupported operator")
|
||||
}
|
||||
}
|
||||
|
||||
func getMappedStringLiteralFromReq(req *http.Request, s string) (string, error) {
|
||||
fn, ok := reqFilterKeyFns[s]
|
||||
if ok {
|
||||
return fn(req)
|
||||
}
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func matchReqStringLiteral(req *http.Request, strLiteral search.StringLiteral) (bool, error) {
|
||||
for _, fn := range reqFilterKeyFns {
|
||||
value, err := fn(req)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if strings.Contains(strings.ToLower(value), strings.ToLower(strLiteral.Value)) {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func MatchRequestScope(req *http.Request, s *scope.Scope) (bool, error) {
|
||||
for _, rule := range s.Rules() {
|
||||
if rule.URL != nil && req.URL != nil {
|
||||
if matches := rule.URL.MatchString(req.URL.String()); matches {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
|
||||
for key, values := range req.Header {
|
||||
var keyMatches, valueMatches bool
|
||||
|
||||
if rule.Header.Key != nil {
|
||||
if matches := rule.Header.Key.MatchString(key); matches {
|
||||
keyMatches = true
|
||||
}
|
||||
}
|
||||
|
||||
if rule.Header.Value != nil {
|
||||
for _, value := range values {
|
||||
if matches := rule.Header.Value.MatchString(value); matches {
|
||||
valueMatches = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// When only key or value is set, match on whatever is set.
|
||||
// When both are set, both must match.
|
||||
switch {
|
||||
case rule.Header.Key != nil && rule.Header.Value == nil && keyMatches:
|
||||
return true, nil
|
||||
case rule.Header.Key == nil && rule.Header.Value != nil && valueMatches:
|
||||
return true, nil
|
||||
case rule.Header.Key != nil && rule.Header.Value != nil && keyMatches && valueMatches:
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
|
||||
if rule.Body != nil {
|
||||
body, err := io.ReadAll(req.Body)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to read request body: %w", err)
|
||||
}
|
||||
|
||||
req.Body = ioutil.NopCloser(bytes.NewBuffer(body))
|
||||
|
||||
if matches := rule.Body.Match(body); matches {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
@ -3,6 +3,7 @@ package intercept
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sort"
|
||||
"sync"
|
||||
@ -11,6 +12,7 @@ import (
|
||||
|
||||
"github.com/dstotijn/hetty/pkg/log"
|
||||
"github.com/dstotijn/hetty/pkg/proxy"
|
||||
"github.com/dstotijn/hetty/pkg/search"
|
||||
)
|
||||
|
||||
var (
|
||||
@ -28,15 +30,17 @@ type Request struct {
|
||||
}
|
||||
|
||||
type Service struct {
|
||||
mu *sync.RWMutex
|
||||
requests map[ulid.ULID]Request
|
||||
logger log.Logger
|
||||
enabled bool
|
||||
mu *sync.RWMutex
|
||||
requests map[ulid.ULID]Request
|
||||
logger log.Logger
|
||||
enabled bool
|
||||
reqFilter search.Expression
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
Logger log.Logger
|
||||
Enabled bool
|
||||
Logger log.Logger
|
||||
Enabled bool
|
||||
RequestFilter search.Expression
|
||||
}
|
||||
|
||||
// RequestIDs implements sort.Interface.
|
||||
@ -44,10 +48,11 @@ type RequestIDs []ulid.ULID
|
||||
|
||||
func NewService(cfg Config) *Service {
|
||||
s := &Service{
|
||||
mu: &sync.RWMutex{},
|
||||
requests: make(map[ulid.ULID]Request),
|
||||
logger: cfg.Logger,
|
||||
enabled: cfg.Enabled,
|
||||
mu: &sync.RWMutex{},
|
||||
requests: make(map[ulid.ULID]Request),
|
||||
logger: cfg.Logger,
|
||||
enabled: cfg.Enabled,
|
||||
reqFilter: cfg.RequestFilter,
|
||||
}
|
||||
|
||||
if s.logger == nil {
|
||||
@ -102,6 +107,20 @@ func (svc *Service) Intercept(ctx context.Context, req *http.Request) (*http.Req
|
||||
return req, nil
|
||||
}
|
||||
|
||||
if svc.reqFilter != nil {
|
||||
match, err := MatchRequestFilter(req, svc.reqFilter)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("intercept: failed to match request rules for request (id: %v): %w",
|
||||
reqID.String(), err,
|
||||
)
|
||||
}
|
||||
|
||||
if !match {
|
||||
svc.logger.Debugw("Bypassed interception: request rules don't match.")
|
||||
return req, nil
|
||||
}
|
||||
}
|
||||
|
||||
ch := make(chan *http.Request)
|
||||
done := make(chan struct{})
|
||||
|
||||
@ -197,6 +216,7 @@ func (svc *Service) UpdateSettings(settings Settings) {
|
||||
}
|
||||
|
||||
svc.enabled = settings.Enabled
|
||||
svc.reqFilter = settings.RequestFilter
|
||||
}
|
||||
|
||||
// Request returns an intercepted request by ID. It's safe for concurrent use.
|
||||
|
@ -1,5 +1,8 @@
|
||||
package intercept
|
||||
|
||||
import "github.com/dstotijn/hetty/pkg/search"
|
||||
|
||||
type Settings struct {
|
||||
Enabled bool
|
||||
Enabled bool
|
||||
RequestFilter search.Expression
|
||||
}
|
||||
|
Reference in New Issue
Block a user