Add intercept response filter

This commit is contained in:
David Stotijn
2022-03-21 10:33:11 +01:00
parent cf55456c42
commit 89141afd3b
15 changed files with 556 additions and 113 deletions

View File

@ -7,6 +7,7 @@ import (
"io"
"io/ioutil"
"net/http"
"strconv"
"strings"
"github.com/dstotijn/hetty/pkg/scope"
@ -38,6 +39,34 @@ var reqFilterKeyFns = map[string]func(req *http.Request) (string, error){
},
}
//nolint:unparam
var resFilterKeyFns = map[string]func(res *http.Response) (string, error){
"proto": func(res *http.Response) (string, error) { return res.Proto, nil },
"statusCode": func(res *http.Response) (string, error) { return strconv.Itoa(res.StatusCode), nil },
"statusReason": func(res *http.Response) (string, error) {
statusReasonSubs := strings.SplitN(res.Status, " ", 2)
if len(statusReasonSubs) != 2 {
return "", fmt.Errorf("invalid response status %q", res.Status)
}
return statusReasonSubs[1], nil
},
"body": func(res *http.Response) (string, error) {
if res.Body == nil {
return "", nil
}
body, err := io.ReadAll(res.Body)
if err != nil {
return "", err
}
res.Body = ioutil.NopCloser(bytes.NewBuffer(body))
return string(body), nil
},
}
// MatchRequestFilter returns true if an HTTP request matches the request filter expression.
func MatchRequestFilter(req *http.Request, expr search.Expression) (bool, error) {
switch e := expr.(type) {
@ -228,3 +257,139 @@ func MatchRequestScope(req *http.Request, s *scope.Scope) (bool, error) {
return false, nil
}
// MatchResponseFilter returns true if an HTTP response matches the response filter expression.
func MatchResponseFilter(res *http.Response, expr search.Expression) (bool, error) {
switch e := expr.(type) {
case search.PrefixExpression:
return matchResPrefixExpr(res, e)
case search.InfixExpression:
return matchResInfixExpr(res, e)
case search.StringLiteral:
return matchResStringLiteral(res, e)
default:
return false, fmt.Errorf("expression type (%T) not supported", expr)
}
}
func matchResPrefixExpr(res *http.Response, expr search.PrefixExpression) (bool, error) {
switch expr.Operator {
case search.TokOpNot:
match, err := MatchResponseFilter(res, expr.Right)
if err != nil {
return false, err
}
return !match, nil
default:
return false, errors.New("operator is not supported")
}
}
func matchResInfixExpr(res *http.Response, expr search.InfixExpression) (bool, error) {
switch expr.Operator {
case search.TokOpAnd:
left, err := MatchResponseFilter(res, expr.Left)
if err != nil {
return false, err
}
right, err := MatchResponseFilter(res, expr.Right)
if err != nil {
return false, err
}
return left && right, nil
case search.TokOpOr:
left, err := MatchResponseFilter(res, expr.Left)
if err != nil {
return false, err
}
right, err := MatchResponseFilter(res, expr.Right)
if err != nil {
return false, err
}
return left || right, nil
}
left, ok := expr.Left.(search.StringLiteral)
if !ok {
return false, errors.New("left operand must be a string literal")
}
leftVal, err := getMappedStringLiteralFromRes(res, left.Value)
if err != nil {
return false, fmt.Errorf("failed to get string literal from response for left operand: %w", err)
}
if expr.Operator == search.TokOpRe || expr.Operator == search.TokOpNotRe {
right, ok := expr.Right.(search.RegexpLiteral)
if !ok {
return false, errors.New("right operand must be a regular expression")
}
switch expr.Operator {
case search.TokOpRe:
return right.MatchString(leftVal), nil
case search.TokOpNotRe:
return !right.MatchString(leftVal), nil
}
}
right, ok := expr.Right.(search.StringLiteral)
if !ok {
return false, errors.New("right operand must be a string literal")
}
rightVal, err := getMappedStringLiteralFromRes(res, right.Value)
if err != nil {
return false, fmt.Errorf("failed to get string literal from response for right operand: %w", err)
}
switch expr.Operator {
case search.TokOpEq:
return leftVal == rightVal, nil
case search.TokOpNotEq:
return leftVal != rightVal, nil
case search.TokOpGt:
// TODO(?) attempt to parse as int.
return leftVal > rightVal, nil
case search.TokOpLt:
// TODO(?) attempt to parse as int.
return leftVal < rightVal, nil
case search.TokOpGtEq:
// TODO(?) attempt to parse as int.
return leftVal >= rightVal, nil
case search.TokOpLtEq:
// TODO(?) attempt to parse as int.
return leftVal <= rightVal, nil
default:
return false, errors.New("unsupported operator")
}
}
func getMappedStringLiteralFromRes(res *http.Response, s string) (string, error) {
fn, ok := resFilterKeyFns[s]
if ok {
return fn(res)
}
return s, nil
}
func matchResStringLiteral(res *http.Response, strLiteral search.StringLiteral) (bool, error) {
for _, fn := range resFilterKeyFns {
value, err := fn(res)
if err != nil {
return false, err
}
if strings.Contains(strings.ToLower(value), strings.ToLower(strLiteral.Value)) {
return true, nil
}
}
return false, nil
}

View File

@ -53,14 +53,19 @@ type Service struct {
requests map[ulid.ULID]Request
responses map[ulid.ULID]Response
logger log.Logger
enabled bool
reqFilter search.Expression
requestsEnabled bool
responsesEnabled bool
reqFilter search.Expression
resFilter search.Expression
}
type Config struct {
Logger log.Logger
Enabled bool
RequestFilter search.Expression
Logger log.Logger
RequestsEnabled bool
ResponsesEnabled bool
RequestFilter search.Expression
ResponseFilter search.Expression
}
// RequestIDs implements sort.Interface.
@ -68,13 +73,15 @@ type RequestIDs []ulid.ULID
func NewService(cfg Config) *Service {
s := &Service{
reqMu: &sync.RWMutex{},
resMu: &sync.RWMutex{},
requests: make(map[ulid.ULID]Request),
responses: make(map[ulid.ULID]Response),
logger: cfg.Logger,
enabled: cfg.Enabled,
reqFilter: cfg.RequestFilter,
reqMu: &sync.RWMutex{},
resMu: &sync.RWMutex{},
requests: make(map[ulid.ULID]Request),
responses: make(map[ulid.ULID]Response),
logger: cfg.Logger,
requestsEnabled: cfg.RequestsEnabled,
responsesEnabled: cfg.ResponsesEnabled,
reqFilter: cfg.RequestFilter,
resFilter: cfg.ResponseFilter,
}
if s.logger == nil {
@ -122,7 +129,7 @@ func (svc *Service) InterceptRequest(ctx context.Context, req *http.Request) (*h
return req, nil
}
if !svc.enabled {
if !svc.requestsEnabled {
// If request intercept is disabled, return the incoming request as-is.
svc.logger.Debugw("Bypassed request interception: feature disabled.")
return req, nil
@ -267,14 +274,20 @@ func (svc *Service) Items() []Item {
}
func (svc *Service) UpdateSettings(settings Settings) {
// When updating from `enabled` -> `disabled`, clear any pending reqs.
if svc.enabled && !settings.Enabled {
// When updating from requests `enabled` -> `disabled`, clear any pending reqs.
if svc.requestsEnabled && !settings.RequestsEnabled {
svc.ClearRequests()
}
// When updating from responses `enabled` -> `disabled`, clear any pending responses.
if svc.responsesEnabled && !settings.ResponsesEnabled {
svc.ClearResponses()
}
svc.enabled = settings.Enabled
svc.requestsEnabled = settings.RequestsEnabled
svc.responsesEnabled = settings.ResponsesEnabled
svc.reqFilter = settings.RequestFilter
svc.resFilter = settings.ResponseFilter
}
// ItemByID returns an intercepted item (request and possible response) by ID. It's safe for concurrent use.
@ -358,25 +371,25 @@ func (svc *Service) InterceptResponse(ctx context.Context, res *http.Response) (
return res, nil
}
if !svc.enabled {
// If the feature is disabled, return the response as-is.
// If global response intercept is disabled and interception is *not* explicitly enabled for this response: bypass.
if !svc.responsesEnabled && !(ok && shouldIntercept) {
svc.logger.Debugw("Bypassed response interception: feature disabled.")
return res, nil
}
// if svc.reqFilter != nil {
// match, err := MatchRequestFilter(req, svc.reqFilter)
// if err != nil {
// return nil, fmt.Errorf("intercept: failed to match request rules for request (id: %v): %w",
// reqID.String(), err,
// )
// }
if svc.resFilter != nil {
match, err := MatchResponseFilter(res, svc.resFilter)
if err != nil {
return nil, fmt.Errorf("intercept: failed to match response rules for response (id: %v): %w",
reqID.String(), err,
)
}
// if !match {
// svc.logger.Debugw("Bypassed interception: request rules don't match.")
// return req, nil
// }
// }
if !match {
svc.logger.Debugw("Bypassed response interception: response rules don't match.")
return res, nil
}
}
ch := make(chan *http.Response)
done := make(chan struct{})

View File

@ -28,8 +28,9 @@ func TestRequestModifier(t *testing.T) {
logger, _ := zap.NewDevelopment()
svc := intercept.NewService(intercept.Config{
Logger: logger.Sugar(),
Enabled: true,
Logger: logger.Sugar(),
RequestsEnabled: true,
ResponsesEnabled: false,
})
reqID := ulid.MustNew(ulid.Timestamp(time.Now()), ulidEntropy)
@ -45,8 +46,9 @@ func TestRequestModifier(t *testing.T) {
logger, _ := zap.NewDevelopment()
svc := intercept.NewService(intercept.Config{
Logger: logger.Sugar(),
Enabled: true,
Logger: logger.Sugar(),
RequestsEnabled: true,
ResponsesEnabled: false,
})
ctx, cancel := context.WithCancel(context.Background())
@ -88,8 +90,9 @@ func TestRequestModifier(t *testing.T) {
logger, _ := zap.NewDevelopment()
svc := intercept.NewService(intercept.Config{
Logger: logger.Sugar(),
Enabled: true,
Logger: logger.Sugar(),
RequestsEnabled: true,
ResponsesEnabled: false,
})
var got *http.Request

View File

@ -3,6 +3,8 @@ package intercept
import "github.com/dstotijn/hetty/pkg/search"
type Settings struct {
Enabled bool
RequestFilter search.Expression
RequestsEnabled bool
ResponsesEnabled bool
RequestFilter search.Expression
ResponseFilter search.Expression
}