Add intercept module

This commit is contained in:
David Stotijn
2022-03-23 14:31:27 +01:00
parent 6ffc55cde3
commit 02408b5196
51 changed files with 5779 additions and 304 deletions

View File

@ -2,16 +2,13 @@ package reqlog
import (
"bytes"
"compress/gzip"
"context"
"errors"
"fmt"
"io"
"io/ioutil"
"math/rand"
"net/http"
"net/url"
"time"
"github.com/oklog/ulid"
@ -23,16 +20,16 @@ import (
type contextKey int
const LogBypassedKey contextKey = 0
const (
LogBypassedKey contextKey = iota
ReqLogIDKey
)
var (
ErrRequestNotFound = errors.New("reqlog: request not found")
ErrProjectIDMustBeSet = errors.New("reqlog: project ID must be set")
)
//nolint:gosec
var ulidEntropy = rand.New(rand.NewSource(time.Now().UnixNano()))
type RequestLog struct {
ID ulid.ULID
ProjectID ulid.ULID
@ -170,8 +167,14 @@ func (svc *service) RequestModifier(next proxy.RequestModifyFunc) proxy.RequestM
return
}
reqID, ok := proxy.RequestIDFromContext(req.Context())
if !ok {
svc.logger.Errorw("Bypassed logging: request doesn't have an ID.")
return
}
reqLog := RequestLog{
ID: ulid.MustNew(ulid.Timestamp(time.Now()), ulidEntropy),
ID: reqID,
ProjectID: svc.activeProjectID,
Method: clone.Method,
URL: clone.URL,
@ -191,7 +194,7 @@ func (svc *service) RequestModifier(next proxy.RequestModifyFunc) proxy.RequestM
"reqLogID", reqLog.ID.String(),
"url", reqLog.URL.String())
ctx := context.WithValue(req.Context(), proxy.ReqLogIDKey, reqLog.ID)
ctx := context.WithValue(req.Context(), ReqLogIDKey, reqLog.ID)
*req = *req.WithContext(ctx)
}
}
@ -206,21 +209,23 @@ func (svc *service) ResponseModifier(next proxy.ResponseModifyFunc) proxy.Respon
return nil
}
reqLogID, ok := res.Request.Context().Value(proxy.ReqLogIDKey).(ulid.ULID)
reqLogID, ok := res.Request.Context().Value(ReqLogIDKey).(ulid.ULID)
if !ok {
return errors.New("reqlog: request is missing ID")
}
clone := *res
// TODO: Use io.LimitReader.
body, err := ioutil.ReadAll(res.Body)
if err != nil {
return fmt.Errorf("reqlog: could not read response body: %w", err)
}
if res.Body != nil {
// TODO: Use io.LimitReader.
body, err := io.ReadAll(res.Body)
if err != nil {
return fmt.Errorf("reqlog: could not read response body: %w", err)
}
res.Body = ioutil.NopCloser(bytes.NewBuffer(body))
clone.Body = ioutil.NopCloser(bytes.NewBuffer(body))
res.Body = io.NopCloser(bytes.NewBuffer(body))
clone.Body = io.NopCloser(bytes.NewBuffer(body))
}
go func() {
if err := svc.storeResponse(context.Background(), reqLogID, &clone); err != nil {
@ -261,23 +266,6 @@ func (svc *service) BypassOutOfScopeRequests() bool {
}
func ParseHTTPResponse(res *http.Response) (ResponseLog, error) {
if res.Header.Get("Content-Encoding") == "gzip" {
gzipReader, err := gzip.NewReader(res.Body)
if err != nil {
return ResponseLog{}, fmt.Errorf("reqlog: could not create gzip reader: %w", err)
}
defer gzipReader.Close()
buf := &bytes.Buffer{}
//nolint:gosec
if _, err := io.Copy(buf, gzipReader); err != nil {
return ResponseLog{}, fmt.Errorf("reqlog: could not read gzipped response body: %w", err)
}
res.Body = io.NopCloser(buf)
}
body, err := io.ReadAll(res.Body)
if err != nil {
return ResponseLog{}, fmt.Errorf("reqlog: could not read body: %w", err)

View File

@ -41,6 +41,8 @@ func TestRequestModifier(t *testing.T) {
}
reqModFn := svc.RequestModifier(next)
req := httptest.NewRequest("GET", "https://example.com/", strings.NewReader("bar"))
reqID := ulid.MustNew(ulid.Timestamp(time.Now()), ulidEntropy)
req = req.WithContext(proxy.WithRequestID(req.Context(), reqID))
reqModFn(req)
@ -88,7 +90,7 @@ func TestResponseModifier(t *testing.T) {
req := httptest.NewRequest("GET", "https://example.com/", strings.NewReader("bar"))
reqLogID := ulid.MustNew(ulid.Timestamp(time.Now()), ulidEntropy)
req = req.WithContext(context.WithValue(req.Context(), proxy.ReqLogIDKey, reqLogID))
req = req.WithContext(context.WithValue(req.Context(), reqlog.ReqLogIDKey, reqLogID))
res := &http.Response{
Request: req,

View File

@ -3,7 +3,6 @@ package reqlog
import (
"errors"
"fmt"
"regexp"
"strconv"
"strings"
@ -100,7 +99,7 @@ func (reqLog RequestLog) matchInfixExpr(expr search.InfixExpression) (bool, erro
leftVal := reqLog.getMappedStringLiteral(left.Value)
if expr.Operator == search.TokOpRe || expr.Operator == search.TokOpNotRe {
right, ok := expr.Right.(*regexp.Regexp)
right, ok := expr.Right.(search.RegexpLiteral)
if !ok {
return false, errors.New("right operand must be a regular expression")
}