Add initial support for HTTP request intercepting

This commit is contained in:
David Stotijn
2022-03-08 13:27:31 +01:00
parent bdd667381a
commit 71e550f0cd
11 changed files with 1184 additions and 28 deletions

View File

@ -0,0 +1,189 @@
package intercept
import (
"context"
"errors"
"net/http"
"sort"
"sync"
"github.com/oklog/ulid"
"github.com/dstotijn/hetty/pkg/log"
"github.com/dstotijn/hetty/pkg/proxy"
)
var (
ErrRequestAborted = errors.New("intercept: request was aborted")
ErrRequestNotFound = errors.New("intercept: request not found")
ErrRequestDone = errors.New("intercept: request is done")
)
// Request represents a server received HTTP request, alongside a channel for sending a modified version of it to the
// routine that's awaiting it. Also contains a channel for receiving a cancellation signal.
type Request struct {
req *http.Request
ch chan<- *http.Request
done <-chan struct{}
}
type Service struct {
mu *sync.RWMutex
requests map[ulid.ULID]Request
logger log.Logger
}
type Config struct {
Logger log.Logger
}
// RequestIDs implements sort.Interface.
type RequestIDs []ulid.ULID
func NewService(cfg Config) *Service {
s := &Service{
mu: &sync.RWMutex{},
requests: make(map[ulid.ULID]Request),
logger: cfg.Logger,
}
if s.logger == nil {
s.logger = log.NewNopLogger()
}
return s
}
// RequestModifier is a proxy.RequestModifyMiddleware for intercepting HTTP
// requests.
func (svc *Service) RequestModifier(next proxy.RequestModifyFunc) proxy.RequestModifyFunc {
return func(req *http.Request) {
// This is a blocking operation, that gets unblocked when either a modified request is returned or an error
// (typically `context.Canceled`).
modifiedReq, err := svc.Intercept(req.Context(), req)
switch {
case errors.Is(err, ErrRequestAborted):
svc.logger.Debugw("Stopping intercept, request was aborted.")
// Prevent further processing by replacing req.Context with a cancelled context value.
// This will cause the http.Roundtripper in the `proxy` package to
// handle this request as an error.
ctx, cancel := context.WithCancel(context.Background())
cancel()
*req = *req.WithContext(ctx)
case errors.Is(err, context.Canceled):
svc.logger.Debugw("Stopping intercept, context was cancelled.")
case err != nil:
svc.logger.Errorw("Failed to intercept request.",
"error", err)
default:
*req = *modifiedReq.WithContext(req.Context())
next(req)
}
}
}
// Intercept adds an HTTP request to an array of pending intercepted requests, alongside channels used for sending a
// cancellation signal and receiving a modified request. It's safe for concurrent use.
func (svc *Service) Intercept(ctx context.Context, req *http.Request) (*http.Request, error) {
reqID, ok := proxy.RequestIDFromContext(ctx)
if !ok {
svc.logger.Errorw("Failed to intercept: request doesn't have an ID.")
return req, nil
}
ch := make(chan *http.Request)
done := make(chan struct{})
svc.mu.Lock()
svc.requests[reqID] = Request{
req: req,
ch: ch,
done: done,
}
svc.mu.Unlock()
// Whatever happens next (modified request returned, or a context cancelled error), any blocked channel senders
// should be unblocked, and the request should be removed from the requests queue.
defer func() {
close(done)
svc.mu.Lock()
defer svc.mu.Unlock()
delete(svc.requests, reqID)
}()
select {
case modReq := <-ch:
if modReq == nil {
return nil, ErrRequestAborted
}
return modReq, nil
case <-ctx.Done():
return nil, ctx.Err()
}
}
// ModifyRequest sends a modified HTTP request to the related channel, or returns ErrRequestDone when the request was
// cancelled. It's safe for concurrent use.
func (svc *Service) ModifyRequest(reqID ulid.ULID, modReq *http.Request) error {
svc.mu.RLock()
req, ok := svc.requests[reqID]
svc.mu.RUnlock()
if !ok {
return ErrRequestNotFound
}
select {
case <-req.done:
return ErrRequestDone
case req.ch <- modReq:
return nil
}
}
func (svc *Service) ClearRequests() {
svc.mu.Lock()
defer svc.mu.Unlock()
for _, req := range svc.requests {
select {
case <-req.done:
case req.ch <- nil:
}
}
}
// Requests returns a list of pending intercepted requests. It's safe for concurrent use.
func (svc *Service) Requests() []*http.Request {
svc.mu.RLock()
defer svc.mu.RUnlock()
ids := make([]ulid.ULID, 0, len(svc.requests))
for id := range svc.requests {
ids = append(ids, id)
}
sort.Sort(RequestIDs(ids))
reqs := make([]*http.Request, len(ids))
for i, id := range ids {
reqs[i] = svc.requests[id].req
}
return reqs
}
func (ids RequestIDs) Len() int {
return len(ids)
}
func (ids RequestIDs) Less(i, j int) bool {
return ids[i].Compare(ids[j]) == -1
}
func (ids RequestIDs) Swap(i, j int) {
ids[i], ids[j] = ids[j], ids[i]
}

View File

@ -0,0 +1,122 @@
package intercept_test
import (
"context"
"errors"
"math/rand"
"net/http"
"net/http/httptest"
"sync"
"testing"
"time"
"github.com/oklog/ulid"
"go.uber.org/zap"
"github.com/dstotijn/hetty/pkg/proxy"
"github.com/dstotijn/hetty/pkg/proxy/intercept"
)
//nolint:gosec
var ulidEntropy = rand.New(rand.NewSource(time.Now().UnixNano()))
func TestRequestModifier(t *testing.T) {
t.Parallel()
t.Run("modify request that's not found", func(t *testing.T) {
t.Parallel()
logger, _ := zap.NewDevelopment()
svc := intercept.NewService(intercept.Config{
Logger: logger.Sugar(),
})
reqID := ulid.MustNew(ulid.Timestamp(time.Now()), ulidEntropy)
err := svc.ModifyRequest(reqID, nil)
if !errors.Is(err, intercept.ErrRequestNotFound) {
t.Fatalf("expected `intercept.ErrRequestNotFound`, got: %v", err)
}
})
t.Run("modify request that's done", func(t *testing.T) {
t.Parallel()
logger, _ := zap.NewDevelopment()
svc := intercept.NewService(intercept.Config{
Logger: logger.Sugar(),
})
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
req := httptest.NewRequest("GET", "https://example.com/foo", nil)
reqID := ulid.MustNew(ulid.Timestamp(time.Now()), ulidEntropy)
*req = *req.WithContext(ctx)
*req = *req.WithContext(proxy.WithRequestID(req.Context(), reqID))
next := func(req *http.Request) {}
go svc.RequestModifier(next)(req)
// Wait shortly, to allow the req modifier goroutine to add `req` to the
// array of intercepted reqs.
time.Sleep(10 * time.Millisecond)
cancel()
err := svc.ModifyRequest(reqID, nil)
if !errors.Is(err, intercept.ErrRequestDone) {
t.Fatalf("expected `interept.ErrRequestDone`, got: %v", err)
}
})
t.Run("modify intercepted request", func(t *testing.T) {
t.Parallel()
req := httptest.NewRequest("GET", "https://example.com/foo", nil)
req.Header.Set("X-Foo", "foo")
reqID := ulid.MustNew(ulid.Timestamp(time.Now()), ulidEntropy)
*req = *req.WithContext(proxy.WithRequestID(req.Context(), reqID))
modReq := req.Clone(context.Background())
modReq.Header.Set("X-Foo", "bar")
logger, _ := zap.NewDevelopment()
svc := intercept.NewService(intercept.Config{
Logger: logger.Sugar(),
})
var got *http.Request
next := func(req *http.Request) {
got = req.Clone(context.Background())
}
var wg sync.WaitGroup
wg.Add(1)
go func() {
svc.RequestModifier(next)(req)
wg.Done()
}()
// Wait shortly, to allow the req modifier goroutine to add `req` to the
// array of intercepted reqs.
time.Sleep(10 * time.Millisecond)
err := svc.ModifyRequest(reqID, modReq)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
wg.Wait()
if got == nil {
t.Fatal("expected `got` not to be nil")
}
if exp := "bar"; exp != got.Header.Get("X-Foo") {
t.Fatalf("incorrect modified request header value (expected: %v, got: %v)", exp, got.Header.Get("X-Foo"))
}
})
}

View File

@ -7,16 +7,22 @@ import (
"crypto/x509"
"errors"
"fmt"
"math/rand"
"net"
"net/http"
"net/http/httputil"
"time"
"github.com/dstotijn/hetty/pkg/log"
"github.com/oklog/ulid"
)
//nolint:gosec
var ulidEntropy = rand.New(rand.NewSource(time.Now().UnixNano()))
type contextKey int
const ReqLogIDKey contextKey = 0
const reqIDKey contextKey = 0
// Proxy implements http.Handler and offers MITM behaviour for modifying
// HTTP requests and responses.
@ -69,6 +75,10 @@ func (p *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return
}
reqID := ulid.MustNew(ulid.Timestamp(time.Now()), ulidEntropy)
ctx := context.WithValue(r.Context(), reqIDKey, reqID)
*r = *r.WithContext(ctx)
p.handler.ServeHTTP(w, r)
}
@ -110,6 +120,15 @@ func (p *Proxy) modifyResponse(res *http.Response) error {
return fn(res)
}
func WithRequestID(ctx context.Context, id ulid.ULID) context.Context {
return context.WithValue(ctx, reqIDKey, id)
}
func RequestIDFromContext(ctx context.Context) (ulid.ULID, bool) {
id, ok := ctx.Value(reqIDKey).(ulid.ULID)
return id, ok
}
// handleConnect hijacks the incoming HTTP request and sets up an HTTP tunnel.
// During the TLS handshake with the client, we use the proxy's CA config to
// create a certificate on-the-fly.
@ -170,13 +189,11 @@ func (p *Proxy) clientTLSConn(conn net.Conn) (*tls.Conn, error) {
}
func (p *Proxy) errorHandler(w http.ResponseWriter, r *http.Request, err error) {
if errors.Is(err, context.Canceled) {
return
if !errors.Is(err, context.Canceled) {
p.logger.Errorw("Failed to proxy request.",
"error", err)
}
p.logger.Errorw("Failed to proxy request.",
"error", err)
w.WriteHeader(http.StatusBadGateway)
}