mirror of
https://github.com/dstotijn/hetty.git
synced 2025-07-01 18:47:29 -04:00
Add initial support for HTTP request intercepting
This commit is contained in:
189
pkg/proxy/intercept/intercept.go
Normal file
189
pkg/proxy/intercept/intercept.go
Normal file
@ -0,0 +1,189 @@
|
||||
package intercept
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"sort"
|
||||
"sync"
|
||||
|
||||
"github.com/oklog/ulid"
|
||||
|
||||
"github.com/dstotijn/hetty/pkg/log"
|
||||
"github.com/dstotijn/hetty/pkg/proxy"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrRequestAborted = errors.New("intercept: request was aborted")
|
||||
ErrRequestNotFound = errors.New("intercept: request not found")
|
||||
ErrRequestDone = errors.New("intercept: request is done")
|
||||
)
|
||||
|
||||
// Request represents a server received HTTP request, alongside a channel for sending a modified version of it to the
|
||||
// routine that's awaiting it. Also contains a channel for receiving a cancellation signal.
|
||||
type Request struct {
|
||||
req *http.Request
|
||||
ch chan<- *http.Request
|
||||
done <-chan struct{}
|
||||
}
|
||||
|
||||
type Service struct {
|
||||
mu *sync.RWMutex
|
||||
requests map[ulid.ULID]Request
|
||||
logger log.Logger
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
Logger log.Logger
|
||||
}
|
||||
|
||||
// RequestIDs implements sort.Interface.
|
||||
type RequestIDs []ulid.ULID
|
||||
|
||||
func NewService(cfg Config) *Service {
|
||||
s := &Service{
|
||||
mu: &sync.RWMutex{},
|
||||
requests: make(map[ulid.ULID]Request),
|
||||
logger: cfg.Logger,
|
||||
}
|
||||
|
||||
if s.logger == nil {
|
||||
s.logger = log.NewNopLogger()
|
||||
}
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
// RequestModifier is a proxy.RequestModifyMiddleware for intercepting HTTP
|
||||
// requests.
|
||||
func (svc *Service) RequestModifier(next proxy.RequestModifyFunc) proxy.RequestModifyFunc {
|
||||
return func(req *http.Request) {
|
||||
// This is a blocking operation, that gets unblocked when either a modified request is returned or an error
|
||||
// (typically `context.Canceled`).
|
||||
modifiedReq, err := svc.Intercept(req.Context(), req)
|
||||
|
||||
switch {
|
||||
case errors.Is(err, ErrRequestAborted):
|
||||
svc.logger.Debugw("Stopping intercept, request was aborted.")
|
||||
// Prevent further processing by replacing req.Context with a cancelled context value.
|
||||
// This will cause the http.Roundtripper in the `proxy` package to
|
||||
// handle this request as an error.
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
*req = *req.WithContext(ctx)
|
||||
case errors.Is(err, context.Canceled):
|
||||
svc.logger.Debugw("Stopping intercept, context was cancelled.")
|
||||
case err != nil:
|
||||
svc.logger.Errorw("Failed to intercept request.",
|
||||
"error", err)
|
||||
default:
|
||||
*req = *modifiedReq.WithContext(req.Context())
|
||||
next(req)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Intercept adds an HTTP request to an array of pending intercepted requests, alongside channels used for sending a
|
||||
// cancellation signal and receiving a modified request. It's safe for concurrent use.
|
||||
func (svc *Service) Intercept(ctx context.Context, req *http.Request) (*http.Request, error) {
|
||||
reqID, ok := proxy.RequestIDFromContext(ctx)
|
||||
if !ok {
|
||||
svc.logger.Errorw("Failed to intercept: request doesn't have an ID.")
|
||||
return req, nil
|
||||
}
|
||||
|
||||
ch := make(chan *http.Request)
|
||||
done := make(chan struct{})
|
||||
|
||||
svc.mu.Lock()
|
||||
svc.requests[reqID] = Request{
|
||||
req: req,
|
||||
ch: ch,
|
||||
done: done,
|
||||
}
|
||||
svc.mu.Unlock()
|
||||
|
||||
// Whatever happens next (modified request returned, or a context cancelled error), any blocked channel senders
|
||||
// should be unblocked, and the request should be removed from the requests queue.
|
||||
defer func() {
|
||||
close(done)
|
||||
svc.mu.Lock()
|
||||
defer svc.mu.Unlock()
|
||||
delete(svc.requests, reqID)
|
||||
}()
|
||||
|
||||
select {
|
||||
case modReq := <-ch:
|
||||
if modReq == nil {
|
||||
return nil, ErrRequestAborted
|
||||
}
|
||||
|
||||
return modReq, nil
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
// ModifyRequest sends a modified HTTP request to the related channel, or returns ErrRequestDone when the request was
|
||||
// cancelled. It's safe for concurrent use.
|
||||
func (svc *Service) ModifyRequest(reqID ulid.ULID, modReq *http.Request) error {
|
||||
svc.mu.RLock()
|
||||
req, ok := svc.requests[reqID]
|
||||
svc.mu.RUnlock()
|
||||
|
||||
if !ok {
|
||||
return ErrRequestNotFound
|
||||
}
|
||||
|
||||
select {
|
||||
case <-req.done:
|
||||
return ErrRequestDone
|
||||
case req.ch <- modReq:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (svc *Service) ClearRequests() {
|
||||
svc.mu.Lock()
|
||||
defer svc.mu.Unlock()
|
||||
|
||||
for _, req := range svc.requests {
|
||||
select {
|
||||
case <-req.done:
|
||||
case req.ch <- nil:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Requests returns a list of pending intercepted requests. It's safe for concurrent use.
|
||||
func (svc *Service) Requests() []*http.Request {
|
||||
svc.mu.RLock()
|
||||
defer svc.mu.RUnlock()
|
||||
|
||||
ids := make([]ulid.ULID, 0, len(svc.requests))
|
||||
for id := range svc.requests {
|
||||
ids = append(ids, id)
|
||||
}
|
||||
|
||||
sort.Sort(RequestIDs(ids))
|
||||
|
||||
reqs := make([]*http.Request, len(ids))
|
||||
for i, id := range ids {
|
||||
reqs[i] = svc.requests[id].req
|
||||
}
|
||||
|
||||
return reqs
|
||||
}
|
||||
|
||||
func (ids RequestIDs) Len() int {
|
||||
return len(ids)
|
||||
}
|
||||
|
||||
func (ids RequestIDs) Less(i, j int) bool {
|
||||
return ids[i].Compare(ids[j]) == -1
|
||||
}
|
||||
|
||||
func (ids RequestIDs) Swap(i, j int) {
|
||||
ids[i], ids[j] = ids[j], ids[i]
|
||||
}
|
122
pkg/proxy/intercept/intercept_test.go
Normal file
122
pkg/proxy/intercept/intercept_test.go
Normal file
@ -0,0 +1,122 @@
|
||||
package intercept_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/oklog/ulid"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/dstotijn/hetty/pkg/proxy"
|
||||
"github.com/dstotijn/hetty/pkg/proxy/intercept"
|
||||
)
|
||||
|
||||
//nolint:gosec
|
||||
var ulidEntropy = rand.New(rand.NewSource(time.Now().UnixNano()))
|
||||
|
||||
func TestRequestModifier(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("modify request that's not found", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger, _ := zap.NewDevelopment()
|
||||
svc := intercept.NewService(intercept.Config{
|
||||
Logger: logger.Sugar(),
|
||||
})
|
||||
|
||||
reqID := ulid.MustNew(ulid.Timestamp(time.Now()), ulidEntropy)
|
||||
|
||||
err := svc.ModifyRequest(reqID, nil)
|
||||
if !errors.Is(err, intercept.ErrRequestNotFound) {
|
||||
t.Fatalf("expected `intercept.ErrRequestNotFound`, got: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("modify request that's done", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger, _ := zap.NewDevelopment()
|
||||
svc := intercept.NewService(intercept.Config{
|
||||
Logger: logger.Sugar(),
|
||||
})
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
req := httptest.NewRequest("GET", "https://example.com/foo", nil)
|
||||
reqID := ulid.MustNew(ulid.Timestamp(time.Now()), ulidEntropy)
|
||||
*req = *req.WithContext(ctx)
|
||||
*req = *req.WithContext(proxy.WithRequestID(req.Context(), reqID))
|
||||
|
||||
next := func(req *http.Request) {}
|
||||
go svc.RequestModifier(next)(req)
|
||||
|
||||
// Wait shortly, to allow the req modifier goroutine to add `req` to the
|
||||
// array of intercepted reqs.
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
cancel()
|
||||
|
||||
err := svc.ModifyRequest(reqID, nil)
|
||||
if !errors.Is(err, intercept.ErrRequestDone) {
|
||||
t.Fatalf("expected `interept.ErrRequestDone`, got: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("modify intercepted request", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
req := httptest.NewRequest("GET", "https://example.com/foo", nil)
|
||||
req.Header.Set("X-Foo", "foo")
|
||||
|
||||
reqID := ulid.MustNew(ulid.Timestamp(time.Now()), ulidEntropy)
|
||||
*req = *req.WithContext(proxy.WithRequestID(req.Context(), reqID))
|
||||
|
||||
modReq := req.Clone(context.Background())
|
||||
modReq.Header.Set("X-Foo", "bar")
|
||||
|
||||
logger, _ := zap.NewDevelopment()
|
||||
svc := intercept.NewService(intercept.Config{
|
||||
Logger: logger.Sugar(),
|
||||
})
|
||||
|
||||
var got *http.Request
|
||||
|
||||
next := func(req *http.Request) {
|
||||
got = req.Clone(context.Background())
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
|
||||
go func() {
|
||||
svc.RequestModifier(next)(req)
|
||||
wg.Done()
|
||||
}()
|
||||
|
||||
// Wait shortly, to allow the req modifier goroutine to add `req` to the
|
||||
// array of intercepted reqs.
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
err := svc.ModifyRequest(reqID, modReq)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
if got == nil {
|
||||
t.Fatal("expected `got` not to be nil")
|
||||
}
|
||||
|
||||
if exp := "bar"; exp != got.Header.Get("X-Foo") {
|
||||
t.Fatalf("incorrect modified request header value (expected: %v, got: %v)", exp, got.Header.Get("X-Foo"))
|
||||
}
|
||||
})
|
||||
}
|
Reference in New Issue
Block a user