diff --git a/pkg/proxy/intercept/intercept_test.go b/pkg/proxy/intercept/intercept_test.go index bb5e91d..2652a93 100644 --- a/pkg/proxy/intercept/intercept_test.go +++ b/pkg/proxy/intercept/intercept_test.go @@ -129,3 +129,142 @@ func TestRequestModifier(t *testing.T) { } }) } + +func TestResponseModifier(t *testing.T) { + t.Parallel() + + t.Run("modify response that's not found", func(t *testing.T) { + t.Parallel() + + logger, _ := zap.NewDevelopment() + svc := intercept.NewService(intercept.Config{ + Logger: logger.Sugar(), + RequestsEnabled: false, + ResponsesEnabled: true, + }) + + reqID := ulid.MustNew(ulid.Timestamp(time.Now()), ulidEntropy) + + err := svc.ModifyResponse(reqID, nil) + if !errors.Is(err, intercept.ErrRequestNotFound) { + t.Fatalf("expected `intercept.ErrRequestNotFound`, got: %v", err) + } + }) + + t.Run("modify response of request that's done", func(t *testing.T) { + t.Parallel() + + logger, _ := zap.NewDevelopment() + svc := intercept.NewService(intercept.Config{ + Logger: logger.Sugar(), + RequestsEnabled: false, + ResponsesEnabled: true, + }) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + req := httptest.NewRequest("GET", "https://example.com/foo", nil) + reqID := ulid.MustNew(ulid.Timestamp(time.Now()), ulidEntropy) + *req = *req.WithContext(ctx) + *req = *req.WithContext(proxy.WithRequestID(req.Context(), reqID)) + + res := &http.Response{ + Request: req, + Header: make(http.Header), + } + res.Header.Add("X-Foo", "foo") + + var modErr error + var wg sync.WaitGroup + wg.Add(1) + + next := func(res *http.Response) error { return nil } + go func() { + defer wg.Done() + modErr = svc.ResponseModifier(next)(res) + }() + + // Wait shortly, to allow the res modifier goroutine to add `res` to the + // array of intercepted responses. + time.Sleep(10 * time.Millisecond) + cancel() + + modRes := *res + modRes.Header = make(http.Header) + modRes.Header.Set("X-Foo", "bar") + + err := svc.ModifyResponse(reqID, &modRes) + if !errors.Is(err, intercept.ErrRequestDone) { + t.Fatalf("expected `intercept.ErrRequestDone`, got: %v", err) + } + + wg.Wait() + + if !errors.Is(modErr, context.Canceled) { + t.Fatalf("expected `context.Canceled`, got: %v", modErr) + } + }) + + t.Run("modify intercepted response", func(t *testing.T) { + t.Parallel() + + req := httptest.NewRequest("GET", "https://example.com/foo", nil) + req.Header.Set("X-Foo", "foo") + + reqID := ulid.MustNew(ulid.Timestamp(time.Now()), ulidEntropy) + *req = *req.WithContext(proxy.WithRequestID(req.Context(), reqID)) + + res := &http.Response{ + Request: req, + Header: make(http.Header), + } + res.Header.Add("X-Foo", "foo") + + modRes := *res + modRes.Header = make(http.Header) + modRes.Header.Set("X-Foo", "bar") + + logger, _ := zap.NewDevelopment() + svc := intercept.NewService(intercept.Config{ + Logger: logger.Sugar(), + RequestsEnabled: false, + ResponsesEnabled: true, + }) + + var gotHeader string + + var next proxy.ResponseModifyFunc = func(res *http.Response) error { + gotHeader = res.Header.Get("X-Foo") + return nil + } + + var modErr error + var wg sync.WaitGroup + wg.Add(1) + + go func() { + modErr = svc.ResponseModifier(next)(res) + wg.Done() + }() + + // Wait shortly, to allow the res modifier goroutine to add `req` to the + // array of intercepted reqs. + time.Sleep(10 * time.Millisecond) + + err := svc.ModifyResponse(reqID, &modRes) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + wg.Wait() + + if modErr != nil { + t.Fatalf("unexpected error: %v", modErr) + } + + if exp := "bar"; exp != gotHeader { + t.Fatalf("incorrect modified request header value (expected: %v, got: %v)", exp, gotHeader) + } + }) +}