mirror of
https://github.com/dstotijn/hetty.git
synced 2025-07-01 18:47:29 -04:00
Add intercept module
This commit is contained in:
@ -2,16 +2,13 @@ package reqlog
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"github.com/oklog/ulid"
|
||||
|
||||
@ -23,16 +20,16 @@ import (
|
||||
|
||||
type contextKey int
|
||||
|
||||
const LogBypassedKey contextKey = 0
|
||||
const (
|
||||
LogBypassedKey contextKey = iota
|
||||
ReqLogIDKey
|
||||
)
|
||||
|
||||
var (
|
||||
ErrRequestNotFound = errors.New("reqlog: request not found")
|
||||
ErrProjectIDMustBeSet = errors.New("reqlog: project ID must be set")
|
||||
)
|
||||
|
||||
//nolint:gosec
|
||||
var ulidEntropy = rand.New(rand.NewSource(time.Now().UnixNano()))
|
||||
|
||||
type RequestLog struct {
|
||||
ID ulid.ULID
|
||||
ProjectID ulid.ULID
|
||||
@ -170,8 +167,14 @@ func (svc *service) RequestModifier(next proxy.RequestModifyFunc) proxy.RequestM
|
||||
return
|
||||
}
|
||||
|
||||
reqID, ok := proxy.RequestIDFromContext(req.Context())
|
||||
if !ok {
|
||||
svc.logger.Errorw("Bypassed logging: request doesn't have an ID.")
|
||||
return
|
||||
}
|
||||
|
||||
reqLog := RequestLog{
|
||||
ID: ulid.MustNew(ulid.Timestamp(time.Now()), ulidEntropy),
|
||||
ID: reqID,
|
||||
ProjectID: svc.activeProjectID,
|
||||
Method: clone.Method,
|
||||
URL: clone.URL,
|
||||
@ -191,7 +194,7 @@ func (svc *service) RequestModifier(next proxy.RequestModifyFunc) proxy.RequestM
|
||||
"reqLogID", reqLog.ID.String(),
|
||||
"url", reqLog.URL.String())
|
||||
|
||||
ctx := context.WithValue(req.Context(), proxy.ReqLogIDKey, reqLog.ID)
|
||||
ctx := context.WithValue(req.Context(), ReqLogIDKey, reqLog.ID)
|
||||
*req = *req.WithContext(ctx)
|
||||
}
|
||||
}
|
||||
@ -206,21 +209,23 @@ func (svc *service) ResponseModifier(next proxy.ResponseModifyFunc) proxy.Respon
|
||||
return nil
|
||||
}
|
||||
|
||||
reqLogID, ok := res.Request.Context().Value(proxy.ReqLogIDKey).(ulid.ULID)
|
||||
reqLogID, ok := res.Request.Context().Value(ReqLogIDKey).(ulid.ULID)
|
||||
if !ok {
|
||||
return errors.New("reqlog: request is missing ID")
|
||||
}
|
||||
|
||||
clone := *res
|
||||
|
||||
// TODO: Use io.LimitReader.
|
||||
body, err := ioutil.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("reqlog: could not read response body: %w", err)
|
||||
}
|
||||
if res.Body != nil {
|
||||
// TODO: Use io.LimitReader.
|
||||
body, err := io.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("reqlog: could not read response body: %w", err)
|
||||
}
|
||||
|
||||
res.Body = ioutil.NopCloser(bytes.NewBuffer(body))
|
||||
clone.Body = ioutil.NopCloser(bytes.NewBuffer(body))
|
||||
res.Body = io.NopCloser(bytes.NewBuffer(body))
|
||||
clone.Body = io.NopCloser(bytes.NewBuffer(body))
|
||||
}
|
||||
|
||||
go func() {
|
||||
if err := svc.storeResponse(context.Background(), reqLogID, &clone); err != nil {
|
||||
@ -261,23 +266,6 @@ func (svc *service) BypassOutOfScopeRequests() bool {
|
||||
}
|
||||
|
||||
func ParseHTTPResponse(res *http.Response) (ResponseLog, error) {
|
||||
if res.Header.Get("Content-Encoding") == "gzip" {
|
||||
gzipReader, err := gzip.NewReader(res.Body)
|
||||
if err != nil {
|
||||
return ResponseLog{}, fmt.Errorf("reqlog: could not create gzip reader: %w", err)
|
||||
}
|
||||
defer gzipReader.Close()
|
||||
|
||||
buf := &bytes.Buffer{}
|
||||
|
||||
//nolint:gosec
|
||||
if _, err := io.Copy(buf, gzipReader); err != nil {
|
||||
return ResponseLog{}, fmt.Errorf("reqlog: could not read gzipped response body: %w", err)
|
||||
}
|
||||
|
||||
res.Body = io.NopCloser(buf)
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
return ResponseLog{}, fmt.Errorf("reqlog: could not read body: %w", err)
|
||||
|
@ -41,6 +41,8 @@ func TestRequestModifier(t *testing.T) {
|
||||
}
|
||||
reqModFn := svc.RequestModifier(next)
|
||||
req := httptest.NewRequest("GET", "https://example.com/", strings.NewReader("bar"))
|
||||
reqID := ulid.MustNew(ulid.Timestamp(time.Now()), ulidEntropy)
|
||||
req = req.WithContext(proxy.WithRequestID(req.Context(), reqID))
|
||||
|
||||
reqModFn(req)
|
||||
|
||||
@ -88,7 +90,7 @@ func TestResponseModifier(t *testing.T) {
|
||||
|
||||
req := httptest.NewRequest("GET", "https://example.com/", strings.NewReader("bar"))
|
||||
reqLogID := ulid.MustNew(ulid.Timestamp(time.Now()), ulidEntropy)
|
||||
req = req.WithContext(context.WithValue(req.Context(), proxy.ReqLogIDKey, reqLogID))
|
||||
req = req.WithContext(context.WithValue(req.Context(), reqlog.ReqLogIDKey, reqLogID))
|
||||
|
||||
res := &http.Response{
|
||||
Request: req,
|
||||
|
@ -3,7 +3,6 @@ package reqlog
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
@ -100,7 +99,7 @@ func (reqLog RequestLog) matchInfixExpr(expr search.InfixExpression) (bool, erro
|
||||
leftVal := reqLog.getMappedStringLiteral(left.Value)
|
||||
|
||||
if expr.Operator == search.TokOpRe || expr.Operator == search.TokOpNotRe {
|
||||
right, ok := expr.Right.(*regexp.Regexp)
|
||||
right, ok := expr.Right.(search.RegexpLiteral)
|
||||
if !ok {
|
||||
return false, errors.New("right operand must be a regular expression")
|
||||
}
|
||||
|
Reference in New Issue
Block a user