Files
hetty/pkg/proxy/intercept/intercept_test.go

123 lines
2.9 KiB
Go
Raw Normal View History

package intercept_test
import (
"context"
"errors"
"math/rand"
"net/http"
"net/http/httptest"
"sync"
"testing"
"time"
"github.com/oklog/ulid"
"go.uber.org/zap"
"github.com/dstotijn/hetty/pkg/proxy"
"github.com/dstotijn/hetty/pkg/proxy/intercept"
)
// ulidEntropy is the pseudo-random entropy source passed to ulid.MustNew when
// the tests in this file generate request IDs. It is seeded once from the
// current time when the package is initialized.
//
// NOTE(review): per the math/rand documentation, a *rand.Rand created with
// rand.New(rand.NewSource(...)) is not safe for concurrent use, so tests that
// run in parallel and share this value need to serialize access to it.
//
//nolint:gosec
var ulidEntropy = rand.New(rand.NewSource(time.Now().UnixNano()))
func TestRequestModifier(t *testing.T) {
t.Parallel()
t.Run("modify request that's not found", func(t *testing.T) {
t.Parallel()
logger, _ := zap.NewDevelopment()
svc := intercept.NewService(intercept.Config{
Logger: logger.Sugar(),
})
reqID := ulid.MustNew(ulid.Timestamp(time.Now()), ulidEntropy)
err := svc.ModifyRequest(reqID, nil)
if !errors.Is(err, intercept.ErrRequestNotFound) {
t.Fatalf("expected `intercept.ErrRequestNotFound`, got: %v", err)
}
})
t.Run("modify request that's done", func(t *testing.T) {
t.Parallel()
logger, _ := zap.NewDevelopment()
svc := intercept.NewService(intercept.Config{
Logger: logger.Sugar(),
})
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
req := httptest.NewRequest("GET", "https://example.com/foo", nil)
reqID := ulid.MustNew(ulid.Timestamp(time.Now()), ulidEntropy)
*req = *req.WithContext(ctx)
*req = *req.WithContext(proxy.WithRequestID(req.Context(), reqID))
next := func(req *http.Request) {}
go svc.RequestModifier(next)(req)
// Wait shortly, to allow the req modifier goroutine to add `req` to the
// array of intercepted reqs.
time.Sleep(10 * time.Millisecond)
cancel()
err := svc.ModifyRequest(reqID, nil)
if !errors.Is(err, intercept.ErrRequestDone) {
t.Fatalf("expected `interept.ErrRequestDone`, got: %v", err)
}
})
t.Run("modify intercepted request", func(t *testing.T) {
t.Parallel()
req := httptest.NewRequest("GET", "https://example.com/foo", nil)
req.Header.Set("X-Foo", "foo")
reqID := ulid.MustNew(ulid.Timestamp(time.Now()), ulidEntropy)
*req = *req.WithContext(proxy.WithRequestID(req.Context(), reqID))
modReq := req.Clone(context.Background())
modReq.Header.Set("X-Foo", "bar")
logger, _ := zap.NewDevelopment()
svc := intercept.NewService(intercept.Config{
Logger: logger.Sugar(),
})
var got *http.Request
next := func(req *http.Request) {
got = req.Clone(context.Background())
}
var wg sync.WaitGroup
wg.Add(1)
go func() {
svc.RequestModifier(next)(req)
wg.Done()
}()
// Wait shortly, to allow the req modifier goroutine to add `req` to the
// array of intercepted reqs.
time.Sleep(10 * time.Millisecond)
err := svc.ModifyRequest(reqID, modReq)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
wg.Wait()
if got == nil {
t.Fatal("expected `got` not to be nil")
}
if exp := "bar"; exp != got.Header.Get("X-Foo") {
t.Fatalf("incorrect modified request header value (expected: %v, got: %v)", exp, got.Header.Get("X-Foo"))
}
})
}