Add intercept module

This commit is contained in:
David Stotijn
2022-03-23 14:31:27 +01:00
parent 6ffc55cde3
commit 02408b5196
51 changed files with 5779 additions and 304 deletions

35
pkg/proxy/gzip.go Normal file
View File

@ -0,0 +1,35 @@
package proxy
import (
"bytes"
"compress/gzip"
"fmt"
"io"
"net/http"
)
func gunzipResponseBody(res *http.Response) error {
if res.Header.Get("Content-Encoding") != "gzip" {
return nil
}
gzipReader, err := gzip.NewReader(res.Body)
if err != nil {
return fmt.Errorf("proxy: could not create gzip reader: %w", err)
}
defer gzipReader.Close()
buf := &bytes.Buffer{}
//nolint:gosec
if _, err := io.Copy(buf, gzipReader); err != nil {
return fmt.Errorf("proxy: could not read gzipped response body: %w", err)
}
res.Body = io.NopCloser(buf)
res.Header.Del("Content-Encoding")
res.Header.Set("Content-Length", fmt.Sprint(buf.Len()))
res.ContentLength = int64(buf.Len())
return nil
}

View File

@ -0,0 +1,395 @@
package intercept
import (
"bytes"
"errors"
"fmt"
"io"
"io/ioutil"
"net/http"
"strconv"
"strings"
"github.com/dstotijn/hetty/pkg/scope"
"github.com/dstotijn/hetty/pkg/search"
)
//nolint:unparam
var reqFilterKeyFns = map[string]func(req *http.Request) (string, error){
"proto": func(req *http.Request) (string, error) { return req.Proto, nil },
"url": func(req *http.Request) (string, error) {
if req.URL == nil {
return "", nil
}
return req.URL.String(), nil
},
"method": func(req *http.Request) (string, error) { return req.Method, nil },
"body": func(req *http.Request) (string, error) {
if req.Body == nil {
return "", nil
}
body, err := io.ReadAll(req.Body)
if err != nil {
return "", err
}
req.Body = ioutil.NopCloser(bytes.NewBuffer(body))
return string(body), nil
},
}
//nolint:unparam
var resFilterKeyFns = map[string]func(res *http.Response) (string, error){
"proto": func(res *http.Response) (string, error) { return res.Proto, nil },
"statusCode": func(res *http.Response) (string, error) { return strconv.Itoa(res.StatusCode), nil },
"statusReason": func(res *http.Response) (string, error) {
statusReasonSubs := strings.SplitN(res.Status, " ", 2)
if len(statusReasonSubs) != 2 {
return "", fmt.Errorf("invalid response status %q", res.Status)
}
return statusReasonSubs[1], nil
},
"body": func(res *http.Response) (string, error) {
if res.Body == nil {
return "", nil
}
body, err := io.ReadAll(res.Body)
if err != nil {
return "", err
}
res.Body = ioutil.NopCloser(bytes.NewBuffer(body))
return string(body), nil
},
}
// MatchRequestFilter returns true if an HTTP request matches the request filter expression.
func MatchRequestFilter(req *http.Request, expr search.Expression) (bool, error) {
switch e := expr.(type) {
case search.PrefixExpression:
return matchReqPrefixExpr(req, e)
case search.InfixExpression:
return matchReqInfixExpr(req, e)
case search.StringLiteral:
return matchReqStringLiteral(req, e)
default:
return false, fmt.Errorf("expression type (%T) not supported", expr)
}
}
func matchReqPrefixExpr(req *http.Request, expr search.PrefixExpression) (bool, error) {
switch expr.Operator {
case search.TokOpNot:
match, err := MatchRequestFilter(req, expr.Right)
if err != nil {
return false, err
}
return !match, nil
default:
return false, errors.New("operator is not supported")
}
}
func matchReqInfixExpr(req *http.Request, expr search.InfixExpression) (bool, error) {
switch expr.Operator {
case search.TokOpAnd:
left, err := MatchRequestFilter(req, expr.Left)
if err != nil {
return false, err
}
right, err := MatchRequestFilter(req, expr.Right)
if err != nil {
return false, err
}
return left && right, nil
case search.TokOpOr:
left, err := MatchRequestFilter(req, expr.Left)
if err != nil {
return false, err
}
right, err := MatchRequestFilter(req, expr.Right)
if err != nil {
return false, err
}
return left || right, nil
}
left, ok := expr.Left.(search.StringLiteral)
if !ok {
return false, errors.New("left operand must be a string literal")
}
leftVal, err := getMappedStringLiteralFromReq(req, left.Value)
if err != nil {
return false, fmt.Errorf("failed to get string literal from request for left operand: %w", err)
}
if expr.Operator == search.TokOpRe || expr.Operator == search.TokOpNotRe {
right, ok := expr.Right.(search.RegexpLiteral)
if !ok {
return false, errors.New("right operand must be a regular expression")
}
switch expr.Operator {
case search.TokOpRe:
return right.MatchString(leftVal), nil
case search.TokOpNotRe:
return !right.MatchString(leftVal), nil
}
}
right, ok := expr.Right.(search.StringLiteral)
if !ok {
return false, errors.New("right operand must be a string literal")
}
rightVal, err := getMappedStringLiteralFromReq(req, right.Value)
if err != nil {
return false, fmt.Errorf("failed to get string literal from request for right operand: %w", err)
}
switch expr.Operator {
case search.TokOpEq:
return leftVal == rightVal, nil
case search.TokOpNotEq:
return leftVal != rightVal, nil
case search.TokOpGt:
// TODO(?) attempt to parse as int.
return leftVal > rightVal, nil
case search.TokOpLt:
// TODO(?) attempt to parse as int.
return leftVal < rightVal, nil
case search.TokOpGtEq:
// TODO(?) attempt to parse as int.
return leftVal >= rightVal, nil
case search.TokOpLtEq:
// TODO(?) attempt to parse as int.
return leftVal <= rightVal, nil
default:
return false, errors.New("unsupported operator")
}
}
func getMappedStringLiteralFromReq(req *http.Request, s string) (string, error) {
fn, ok := reqFilterKeyFns[s]
if ok {
return fn(req)
}
return s, nil
}
func matchReqStringLiteral(req *http.Request, strLiteral search.StringLiteral) (bool, error) {
for _, fn := range reqFilterKeyFns {
value, err := fn(req)
if err != nil {
return false, err
}
if strings.Contains(strings.ToLower(value), strings.ToLower(strLiteral.Value)) {
return true, nil
}
}
return false, nil
}
func MatchRequestScope(req *http.Request, s *scope.Scope) (bool, error) {
for _, rule := range s.Rules() {
if rule.URL != nil && req.URL != nil {
if matches := rule.URL.MatchString(req.URL.String()); matches {
return true, nil
}
}
for key, values := range req.Header {
var keyMatches, valueMatches bool
if rule.Header.Key != nil {
if matches := rule.Header.Key.MatchString(key); matches {
keyMatches = true
}
}
if rule.Header.Value != nil {
for _, value := range values {
if matches := rule.Header.Value.MatchString(value); matches {
valueMatches = true
break
}
}
}
// When only key or value is set, match on whatever is set.
// When both are set, both must match.
switch {
case rule.Header.Key != nil && rule.Header.Value == nil && keyMatches:
return true, nil
case rule.Header.Key == nil && rule.Header.Value != nil && valueMatches:
return true, nil
case rule.Header.Key != nil && rule.Header.Value != nil && keyMatches && valueMatches:
return true, nil
}
}
if rule.Body != nil {
body, err := io.ReadAll(req.Body)
if err != nil {
return false, fmt.Errorf("failed to read request body: %w", err)
}
req.Body = ioutil.NopCloser(bytes.NewBuffer(body))
if matches := rule.Body.Match(body); matches {
return true, nil
}
}
}
return false, nil
}
// MatchResponseFilter returns true if an HTTP response matches the response filter expression.
func MatchResponseFilter(res *http.Response, expr search.Expression) (bool, error) {
switch e := expr.(type) {
case search.PrefixExpression:
return matchResPrefixExpr(res, e)
case search.InfixExpression:
return matchResInfixExpr(res, e)
case search.StringLiteral:
return matchResStringLiteral(res, e)
default:
return false, fmt.Errorf("expression type (%T) not supported", expr)
}
}
func matchResPrefixExpr(res *http.Response, expr search.PrefixExpression) (bool, error) {
switch expr.Operator {
case search.TokOpNot:
match, err := MatchResponseFilter(res, expr.Right)
if err != nil {
return false, err
}
return !match, nil
default:
return false, errors.New("operator is not supported")
}
}
func matchResInfixExpr(res *http.Response, expr search.InfixExpression) (bool, error) {
switch expr.Operator {
case search.TokOpAnd:
left, err := MatchResponseFilter(res, expr.Left)
if err != nil {
return false, err
}
right, err := MatchResponseFilter(res, expr.Right)
if err != nil {
return false, err
}
return left && right, nil
case search.TokOpOr:
left, err := MatchResponseFilter(res, expr.Left)
if err != nil {
return false, err
}
right, err := MatchResponseFilter(res, expr.Right)
if err != nil {
return false, err
}
return left || right, nil
}
left, ok := expr.Left.(search.StringLiteral)
if !ok {
return false, errors.New("left operand must be a string literal")
}
leftVal, err := getMappedStringLiteralFromRes(res, left.Value)
if err != nil {
return false, fmt.Errorf("failed to get string literal from response for left operand: %w", err)
}
if expr.Operator == search.TokOpRe || expr.Operator == search.TokOpNotRe {
right, ok := expr.Right.(search.RegexpLiteral)
if !ok {
return false, errors.New("right operand must be a regular expression")
}
switch expr.Operator {
case search.TokOpRe:
return right.MatchString(leftVal), nil
case search.TokOpNotRe:
return !right.MatchString(leftVal), nil
}
}
right, ok := expr.Right.(search.StringLiteral)
if !ok {
return false, errors.New("right operand must be a string literal")
}
rightVal, err := getMappedStringLiteralFromRes(res, right.Value)
if err != nil {
return false, fmt.Errorf("failed to get string literal from response for right operand: %w", err)
}
switch expr.Operator {
case search.TokOpEq:
return leftVal == rightVal, nil
case search.TokOpNotEq:
return leftVal != rightVal, nil
case search.TokOpGt:
// TODO(?) attempt to parse as int.
return leftVal > rightVal, nil
case search.TokOpLt:
// TODO(?) attempt to parse as int.
return leftVal < rightVal, nil
case search.TokOpGtEq:
// TODO(?) attempt to parse as int.
return leftVal >= rightVal, nil
case search.TokOpLtEq:
// TODO(?) attempt to parse as int.
return leftVal <= rightVal, nil
default:
return false, errors.New("unsupported operator")
}
}
func getMappedStringLiteralFromRes(res *http.Response, s string) (string, error) {
fn, ok := resFilterKeyFns[s]
if ok {
return fn(res)
}
return s, nil
}
func matchResStringLiteral(res *http.Response, strLiteral search.StringLiteral) (bool, error) {
for _, fn := range resFilterKeyFns {
value, err := fn(res)
if err != nil {
return false, err
}
if strings.Contains(strings.ToLower(value), strings.ToLower(strLiteral.Value)) {
return true, nil
}
}
return false, nil
}

View File

@ -0,0 +1,452 @@
package intercept
import (
"context"
"errors"
"fmt"
"net/http"
"sort"
"sync"
"github.com/oklog/ulid"
"github.com/dstotijn/hetty/pkg/log"
"github.com/dstotijn/hetty/pkg/proxy"
"github.com/dstotijn/hetty/pkg/search"
)
var (
ErrRequestAborted = errors.New("intercept: request was aborted")
ErrRequestNotFound = errors.New("intercept: request not found")
ErrRequestDone = errors.New("intercept: request is done")
ErrResponseNotFound = errors.New("intercept: response not found")
)
type contextKey int
const interceptResponseKey contextKey = 0
// Request represents a server received HTTP request, alongside a channel for sending a modified version of it to the
// routine that's awaiting it. Also contains a channel for receiving a cancellation signal.
type Request struct {
req *http.Request
ch chan<- *http.Request
done <-chan struct{}
}
// Response represents an HTTP response from a proxied request, alongside a channel for sending a modified version of it
// to the routine that's awaiting it. Also contains a channel for receiving a cancellation signal.
type Response struct {
res *http.Response
ch chan<- *http.Response
done <-chan struct{}
}
type Item struct {
Request *http.Request
Response *http.Response
}
type Service struct {
reqMu *sync.RWMutex
resMu *sync.RWMutex
requests map[ulid.ULID]Request
responses map[ulid.ULID]Response
logger log.Logger
requestsEnabled bool
responsesEnabled bool
reqFilter search.Expression
resFilter search.Expression
}
type Config struct {
Logger log.Logger
RequestsEnabled bool
ResponsesEnabled bool
RequestFilter search.Expression
ResponseFilter search.Expression
}
// RequestIDs implements sort.Interface.
type RequestIDs []ulid.ULID
func NewService(cfg Config) *Service {
s := &Service{
reqMu: &sync.RWMutex{},
resMu: &sync.RWMutex{},
requests: make(map[ulid.ULID]Request),
responses: make(map[ulid.ULID]Response),
logger: cfg.Logger,
requestsEnabled: cfg.RequestsEnabled,
responsesEnabled: cfg.ResponsesEnabled,
reqFilter: cfg.RequestFilter,
resFilter: cfg.ResponseFilter,
}
if s.logger == nil {
s.logger = log.NewNopLogger()
}
return s
}
// RequestModifier is a proxy.RequestModifyMiddleware for intercepting HTTP requests.
func (svc *Service) RequestModifier(next proxy.RequestModifyFunc) proxy.RequestModifyFunc {
return func(req *http.Request) {
// This is a blocking operation, that gets unblocked when either a modified request is returned or an error
// (typically `context.Canceled`).
modifiedReq, err := svc.InterceptRequest(req.Context(), req)
switch {
case errors.Is(err, ErrRequestAborted):
svc.logger.Debugw("Stopping intercept, request was aborted.")
// Prevent further processing by replacing req.Context with a cancelled context value.
// This will cause the http.Roundtripper in the `proxy` package to
// handle this request as an error.
ctx, cancel := context.WithCancel(context.Background())
cancel()
*req = *req.WithContext(ctx)
case errors.Is(err, context.Canceled):
svc.logger.Debugw("Stopping intercept, context was cancelled.")
case err != nil:
svc.logger.Errorw("Failed to intercept request.",
"error", err)
default:
*req = *modifiedReq
next(req)
}
}
}
// InterceptRequest adds an HTTP request to an array of pending intercepted requests, alongside channels used for
// sending a cancellation signal and receiving a modified request. It's safe for concurrent use.
func (svc *Service) InterceptRequest(ctx context.Context, req *http.Request) (*http.Request, error) {
reqID, ok := proxy.RequestIDFromContext(ctx)
if !ok {
svc.logger.Errorw("Failed to intercept: context doesn't have an ID.")
return req, nil
}
if !svc.requestsEnabled {
// If request intercept is disabled, return the incoming request as-is.
svc.logger.Debugw("Bypassed request interception: feature disabled.")
return req, nil
}
if svc.reqFilter != nil {
match, err := MatchRequestFilter(req, svc.reqFilter)
if err != nil {
return nil, fmt.Errorf("intercept: failed to match request rules for request (id: %v): %w",
reqID.String(), err,
)
}
if !match {
svc.logger.Debugw("Bypassed request interception: request rules don't match.")
return req, nil
}
}
ch := make(chan *http.Request)
done := make(chan struct{})
svc.reqMu.Lock()
svc.requests[reqID] = Request{
req: req,
ch: ch,
done: done,
}
svc.reqMu.Unlock()
// Whatever happens next (modified request returned, or a context cancelled error), any blocked channel senders
// should be unblocked, and the request should be removed from the requests queue.
defer func() {
close(done)
svc.reqMu.Lock()
defer svc.reqMu.Unlock()
delete(svc.requests, reqID)
}()
select {
case modReq := <-ch:
if modReq == nil {
return nil, ErrRequestAborted
}
return modReq, nil
case <-ctx.Done():
return nil, ctx.Err()
}
}
// ModifyRequest sends a modified HTTP request to the related channel, or returns ErrRequestDone when the request was
// cancelled. It's safe for concurrent use.
func (svc *Service) ModifyRequest(reqID ulid.ULID, modReq *http.Request, modifyResponse *bool) error {
svc.reqMu.RLock()
req, ok := svc.requests[reqID]
svc.reqMu.RUnlock()
if !ok {
return ErrRequestNotFound
}
*modReq = *modReq.WithContext(req.req.Context())
if modifyResponse != nil {
*modReq = *modReq.WithContext(WithInterceptResponse(modReq.Context(), *modifyResponse))
}
select {
case <-req.done:
return ErrRequestDone
case req.ch <- modReq:
return nil
}
}
// CancelRequest ensures an intercepted request is dropped.
func (svc *Service) CancelRequest(reqID ulid.ULID) error {
return svc.ModifyRequest(reqID, nil, nil)
}
func (svc *Service) ClearRequests() {
svc.reqMu.Lock()
defer svc.reqMu.Unlock()
for _, req := range svc.requests {
select {
case <-req.done:
case req.ch <- nil:
}
}
}
func (svc *Service) ClearResponses() {
svc.resMu.Lock()
defer svc.resMu.Unlock()
for _, res := range svc.responses {
select {
case <-res.done:
case res.ch <- nil:
}
}
}
// Items returns a list of pending items (requests and responses). It's safe for concurrent use.
func (svc *Service) Items() []Item {
svc.reqMu.RLock()
defer svc.reqMu.RUnlock()
svc.resMu.RLock()
defer svc.resMu.RUnlock()
reqIDs := make([]ulid.ULID, 0, len(svc.requests)+len(svc.responses))
for id := range svc.requests {
reqIDs = append(reqIDs, id)
}
for id := range svc.responses {
reqIDs = append(reqIDs, id)
}
sort.Sort(RequestIDs(reqIDs))
items := make([]Item, len(reqIDs))
for i, id := range reqIDs {
item := Item{}
if req, ok := svc.requests[id]; ok {
item.Request = req.req
}
if res, ok := svc.responses[id]; ok {
item.Response = res.res
}
items[i] = item
}
return items
}
func (svc *Service) UpdateSettings(settings Settings) {
// When updating from requests `enabled` -> `disabled`, clear any pending reqs.
if svc.requestsEnabled && !settings.RequestsEnabled {
svc.ClearRequests()
}
// When updating from responses `enabled` -> `disabled`, clear any pending responses.
if svc.responsesEnabled && !settings.ResponsesEnabled {
svc.ClearResponses()
}
svc.requestsEnabled = settings.RequestsEnabled
svc.responsesEnabled = settings.ResponsesEnabled
svc.reqFilter = settings.RequestFilter
svc.resFilter = settings.ResponseFilter
}
// ItemByID returns an intercepted item (request and possible response) by ID. It's safe for concurrent use.
func (svc *Service) ItemByID(id ulid.ULID) (Item, error) {
svc.reqMu.RLock()
defer svc.reqMu.RUnlock()
svc.resMu.RLock()
defer svc.resMu.RUnlock()
item := Item{}
found := false
if req, ok := svc.requests[id]; ok {
item.Request = req.req
found = true
}
if res, ok := svc.responses[id]; ok {
item.Response = res.res
found = true
}
if !found {
return Item{}, ErrRequestNotFound
}
return item, nil
}
func (ids RequestIDs) Len() int {
return len(ids)
}
func (ids RequestIDs) Less(i, j int) bool {
return ids[i].Compare(ids[j]) == -1
}
func (ids RequestIDs) Swap(i, j int) {
ids[i], ids[j] = ids[j], ids[i]
}
func WithInterceptResponse(ctx context.Context, value bool) context.Context {
return context.WithValue(ctx, interceptResponseKey, value)
}
func ShouldInterceptResponseFromContext(ctx context.Context) (bool, bool) {
shouldIntercept, ok := ctx.Value(interceptResponseKey).(bool)
return shouldIntercept, ok
}
// ResponseModifier is a proxy.ResponseModifyMiddleware for intercepting HTTP responses.
func (svc *Service) ResponseModifier(next proxy.ResponseModifyFunc) proxy.ResponseModifyFunc {
return func(res *http.Response) error {
// This is a blocking operation, that gets unblocked when either a modified response is returned or an error.
//nolint:bodyclose
modifiedRes, err := svc.InterceptResponse(res.Request.Context(), res)
if err != nil {
return fmt.Errorf("failed to intercept response: %w", err)
}
*res = *modifiedRes
return next(res)
}
}
// InterceptResponse adds an HTTP response to an array of pending intercepted responses, alongside channels used for
// sending a cancellation signal and receiving a modified response. It's safe for concurrent use.
func (svc *Service) InterceptResponse(ctx context.Context, res *http.Response) (*http.Response, error) {
reqID, ok := proxy.RequestIDFromContext(ctx)
if !ok {
svc.logger.Errorw("Failed to intercept: context doesn't have an ID.")
return res, nil
}
shouldIntercept, ok := ShouldInterceptResponseFromContext(ctx)
if ok && !shouldIntercept {
// If the related request explicitly disabled response intercept, return the response as-is.
svc.logger.Debugw("Bypassed response interception: related request explicitly disabled response intercept.")
return res, nil
}
// If global response intercept is disabled and interception is *not* explicitly enabled for this response: bypass.
if !svc.responsesEnabled && !(ok && shouldIntercept) {
svc.logger.Debugw("Bypassed response interception: feature disabled.")
return res, nil
}
if svc.resFilter != nil {
match, err := MatchResponseFilter(res, svc.resFilter)
if err != nil {
return nil, fmt.Errorf("intercept: failed to match response rules for response (id: %v): %w",
reqID.String(), err,
)
}
if !match {
svc.logger.Debugw("Bypassed response interception: response rules don't match.")
return res, nil
}
}
ch := make(chan *http.Response)
done := make(chan struct{})
svc.resMu.Lock()
svc.responses[reqID] = Response{
res: res,
ch: ch,
done: done,
}
svc.resMu.Unlock()
// Whatever happens next (modified response returned, or a context cancelled error), any blocked channel senders
// should be unblocked, and the response should be removed from the responses queue.
defer func() {
close(done)
svc.resMu.Lock()
defer svc.resMu.Unlock()
delete(svc.responses, reqID)
}()
select {
case modRes := <-ch:
if modRes == nil {
return nil, ErrRequestAborted
}
return modRes, nil
case <-ctx.Done():
return nil, ctx.Err()
}
}
// ModifyResponse sends a modified HTTP response to the related channel, or returns ErrRequestDone when the related
// request was cancelled. It's safe for concurrent use.
func (svc *Service) ModifyResponse(reqID ulid.ULID, modRes *http.Response) error {
svc.resMu.RLock()
res, ok := svc.responses[reqID]
svc.resMu.RUnlock()
if !ok {
return ErrRequestNotFound
}
if modRes != nil {
modRes.Request = res.res.Request
}
select {
case <-res.done:
return ErrRequestDone
case res.ch <- modRes:
return nil
}
}
// CancelResponse ensures an intercepted response is dropped.
func (svc *Service) CancelResponse(reqID ulid.ULID) error {
return svc.ModifyResponse(reqID, nil)
}

View File

@ -0,0 +1,270 @@
package intercept_test
import (
"context"
"errors"
"math/rand"
"net/http"
"net/http/httptest"
"sync"
"testing"
"time"
"github.com/oklog/ulid"
"go.uber.org/zap"
"github.com/dstotijn/hetty/pkg/proxy"
"github.com/dstotijn/hetty/pkg/proxy/intercept"
)
//nolint:gosec
var ulidEntropy = rand.New(rand.NewSource(time.Now().UnixNano()))
func TestRequestModifier(t *testing.T) {
t.Parallel()
t.Run("modify request that's not found", func(t *testing.T) {
t.Parallel()
logger, _ := zap.NewDevelopment()
svc := intercept.NewService(intercept.Config{
Logger: logger.Sugar(),
RequestsEnabled: true,
ResponsesEnabled: false,
})
reqID := ulid.MustNew(ulid.Timestamp(time.Now()), ulidEntropy)
err := svc.ModifyRequest(reqID, nil, nil)
if !errors.Is(err, intercept.ErrRequestNotFound) {
t.Fatalf("expected `intercept.ErrRequestNotFound`, got: %v", err)
}
})
t.Run("modify request that's done", func(t *testing.T) {
t.Parallel()
logger, _ := zap.NewDevelopment()
svc := intercept.NewService(intercept.Config{
Logger: logger.Sugar(),
RequestsEnabled: true,
ResponsesEnabled: false,
})
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
req := httptest.NewRequest("GET", "https://example.com/foo", nil)
reqID := ulid.MustNew(ulid.Timestamp(time.Now()), ulidEntropy)
*req = *req.WithContext(ctx)
*req = *req.WithContext(proxy.WithRequestID(req.Context(), reqID))
next := func(req *http.Request) {}
go svc.RequestModifier(next)(req)
// Wait shortly, to allow the req modifier goroutine to add `req` to the
// array of intercepted reqs.
time.Sleep(10 * time.Millisecond)
cancel()
modReq := req.Clone(req.Context())
modReq.Header.Set("X-Foo", "bar")
err := svc.ModifyRequest(reqID, modReq, nil)
if !errors.Is(err, intercept.ErrRequestDone) {
t.Fatalf("expected `intercept.ErrRequestDone`, got: %v", err)
}
})
t.Run("modify intercepted request", func(t *testing.T) {
t.Parallel()
req := httptest.NewRequest("GET", "https://example.com/foo", nil)
req.Header.Set("X-Foo", "foo")
reqID := ulid.MustNew(ulid.Timestamp(time.Now()), ulidEntropy)
*req = *req.WithContext(proxy.WithRequestID(req.Context(), reqID))
modReq := req.Clone(context.Background())
modReq.Header.Set("X-Foo", "bar")
logger, _ := zap.NewDevelopment()
svc := intercept.NewService(intercept.Config{
Logger: logger.Sugar(),
RequestsEnabled: true,
ResponsesEnabled: false,
})
var got *http.Request
next := func(req *http.Request) {
got = req.Clone(context.Background())
}
var wg sync.WaitGroup
wg.Add(1)
go func() {
svc.RequestModifier(next)(req)
wg.Done()
}()
// Wait shortly, to allow the req modifier goroutine to add `req` to the
// array of intercepted reqs.
time.Sleep(10 * time.Millisecond)
err := svc.ModifyRequest(reqID, modReq, nil)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
wg.Wait()
if got == nil {
t.Fatal("expected `got` not to be nil")
}
if exp := "bar"; exp != got.Header.Get("X-Foo") {
t.Fatalf("incorrect modified request header value (expected: %v, got: %v)", exp, got.Header.Get("X-Foo"))
}
})
}
func TestResponseModifier(t *testing.T) {
t.Parallel()
t.Run("modify response that's not found", func(t *testing.T) {
t.Parallel()
logger, _ := zap.NewDevelopment()
svc := intercept.NewService(intercept.Config{
Logger: logger.Sugar(),
RequestsEnabled: false,
ResponsesEnabled: true,
})
reqID := ulid.MustNew(ulid.Timestamp(time.Now()), ulidEntropy)
err := svc.ModifyResponse(reqID, nil)
if !errors.Is(err, intercept.ErrRequestNotFound) {
t.Fatalf("expected `intercept.ErrRequestNotFound`, got: %v", err)
}
})
t.Run("modify response of request that's done", func(t *testing.T) {
t.Parallel()
logger, _ := zap.NewDevelopment()
svc := intercept.NewService(intercept.Config{
Logger: logger.Sugar(),
RequestsEnabled: false,
ResponsesEnabled: true,
})
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
req := httptest.NewRequest("GET", "https://example.com/foo", nil)
reqID := ulid.MustNew(ulid.Timestamp(time.Now()), ulidEntropy)
*req = *req.WithContext(ctx)
*req = *req.WithContext(proxy.WithRequestID(req.Context(), reqID))
res := &http.Response{
Request: req,
Header: make(http.Header),
}
res.Header.Add("X-Foo", "foo")
var modErr error
var wg sync.WaitGroup
wg.Add(1)
next := func(res *http.Response) error { return nil }
go func() {
defer wg.Done()
modErr = svc.ResponseModifier(next)(res)
}()
// Wait shortly, to allow the res modifier goroutine to add `res` to the
// array of intercepted responses.
time.Sleep(10 * time.Millisecond)
cancel()
modRes := *res
modRes.Header = make(http.Header)
modRes.Header.Set("X-Foo", "bar")
err := svc.ModifyResponse(reqID, &modRes)
if !errors.Is(err, intercept.ErrRequestDone) {
t.Fatalf("expected `intercept.ErrRequestDone`, got: %v", err)
}
wg.Wait()
if !errors.Is(modErr, context.Canceled) {
t.Fatalf("expected `context.Canceled`, got: %v", modErr)
}
})
t.Run("modify intercepted response", func(t *testing.T) {
t.Parallel()
req := httptest.NewRequest("GET", "https://example.com/foo", nil)
req.Header.Set("X-Foo", "foo")
reqID := ulid.MustNew(ulid.Timestamp(time.Now()), ulidEntropy)
*req = *req.WithContext(proxy.WithRequestID(req.Context(), reqID))
res := &http.Response{
Request: req,
Header: make(http.Header),
}
res.Header.Add("X-Foo", "foo")
modRes := *res
modRes.Header = make(http.Header)
modRes.Header.Set("X-Foo", "bar")
logger, _ := zap.NewDevelopment()
svc := intercept.NewService(intercept.Config{
Logger: logger.Sugar(),
RequestsEnabled: false,
ResponsesEnabled: true,
})
var gotHeader string
var next proxy.ResponseModifyFunc = func(res *http.Response) error {
gotHeader = res.Header.Get("X-Foo")
return nil
}
var modErr error
var wg sync.WaitGroup
wg.Add(1)
go func() {
modErr = svc.ResponseModifier(next)(res)
wg.Done()
}()
// Wait shortly, to allow the res modifier goroutine to add `req` to the
// array of intercepted reqs.
time.Sleep(10 * time.Millisecond)
err := svc.ModifyResponse(reqID, &modRes)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
wg.Wait()
if modErr != nil {
t.Fatalf("unexpected error: %v", modErr)
}
if exp := "bar"; exp != gotHeader {
t.Fatalf("incorrect modified request header value (expected: %v, got: %v)", exp, gotHeader)
}
})
}

View File

@ -0,0 +1,10 @@
package intercept
import "github.com/dstotijn/hetty/pkg/search"
type Settings struct {
RequestsEnabled bool
ResponsesEnabled bool
RequestFilter search.Expression
ResponseFilter search.Expression
}

View File

@ -7,16 +7,24 @@ import (
"crypto/x509"
"errors"
"fmt"
"math/rand"
"net"
"net/http"
"net/http/httputil"
"strings"
"time"
"github.com/oklog/ulid"
"github.com/dstotijn/hetty/pkg/log"
)
//nolint:gosec
var ulidEntropy = rand.New(rand.NewSource(time.Now().UnixNano()))
type contextKey int
const ReqLogIDKey contextKey = 0
const reqIDKey contextKey = 0
// Proxy implements http.Handler and offers MITM behaviour for modifying
// HTTP requests and responses.
@ -54,7 +62,25 @@ func NewProxy(cfg Config) (*Proxy, error) {
p.logger = log.NewNopLogger()
}
transport := &http.Transport{
// Values taken from `http.DefaultTransport`.
Proxy: http.ProxyFromEnvironment,
DialContext: (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}).DialContext,
ForceAttemptHTTP2: true,
MaxIdleConns: 100,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
// Non-default transport values.
DisableCompression: true,
}
p.handler = &httputil.ReverseProxy{
Transport: transport,
Director: p.modifyRequest,
ModifyResponse: p.modifyResponse,
ErrorHandler: p.errorHandler,
@ -69,6 +95,10 @@ func (p *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return
}
reqID := ulid.MustNew(ulid.Timestamp(time.Now()), ulidEntropy)
ctx := context.WithValue(r.Context(), reqIDKey, reqID)
*r = *r.WithContext(ctx)
p.handler.ServeHTTP(w, r)
}
@ -91,6 +121,25 @@ func (p *Proxy) modifyRequest(r *http.Request) {
// set this header.
r.Header["X-Forwarded-For"] = nil
// Strip unsupported encodings.
if acceptEncs := r.Header.Get("Accept-Encoding"); acceptEncs != "" {
directives := strings.Split(acceptEncs, ",")
updated := make([]string, 0, len(directives))
for _, directive := range directives {
stripped := strings.TrimSpace(directive)
if strings.HasPrefix(stripped, "*") || strings.HasPrefix(stripped, "gzip") {
updated = append(updated, stripped)
}
}
if len(updated) == 0 {
r.Header.Del("Accept-Encoding")
} else {
r.Header.Set("Accept-Encoding", strings.Join(updated, ", "))
}
}
fn := nopReqModifier
for i := len(p.reqModifiers) - 1; i >= 0; i-- {
@ -103,6 +152,11 @@ func (p *Proxy) modifyRequest(r *http.Request) {
func (p *Proxy) modifyResponse(res *http.Response) error {
fn := nopResModifier
// TODO: Make decompressing gzip formatted response bodies a configurable project setting.
if err := gunzipResponseBody(res); err != nil {
return fmt.Errorf("proxy: failed to gunzip response body: %w", err)
}
for i := len(p.resModifiers) - 1; i >= 0; i-- {
fn = p.resModifiers[i](fn)
}
@ -110,6 +164,15 @@ func (p *Proxy) modifyResponse(res *http.Response) error {
return fn(res)
}
func WithRequestID(ctx context.Context, id ulid.ULID) context.Context {
return context.WithValue(ctx, reqIDKey, id)
}
func RequestIDFromContext(ctx context.Context) (ulid.ULID, bool) {
id, ok := ctx.Value(reqIDKey).(ulid.ULID)
return id, ok
}
// handleConnect hijacks the incoming HTTP request and sets up an HTTP tunnel.
// During the TLS handshake with the client, we use the proxy's CA config to
// create a certificate on-the-fly.
@ -170,13 +233,14 @@ func (p *Proxy) clientTLSConn(conn net.Conn) (*tls.Conn, error) {
}
func (p *Proxy) errorHandler(w http.ResponseWriter, r *http.Request, err error) {
if errors.Is(err, context.Canceled) {
return
switch {
case !errors.Is(err, context.Canceled):
p.logger.Errorw("Failed to proxy request.",
"error", err)
case errors.Is(err, context.Canceled):
p.logger.Debugw("Proxy request was cancelled.")
}
p.logger.Errorw("Failed to proxy request.",
"error", err)
w.WriteHeader(http.StatusBadGateway)
}