Add Sender module

This commit is contained in:
David Stotijn
2022-02-22 14:10:39 +01:00
parent afa211d0ec
commit efc20564c1
58 changed files with 9131 additions and 282 deletions

File diff suppressed because it is too large Load Diff

View File

@ -3,29 +3,49 @@ package api
import (
"fmt"
"io"
"net/url"
"strconv"
"github.com/99designs/gqlgen/graphql"
"github.com/oklog/ulid"
)
type ULID ulid.ULID
func MarshalULID(u ulid.ULID) graphql.Marshaler {
return graphql.WriterFunc(func(w io.Writer) {
fmt.Fprint(w, strconv.Quote(u.String()))
})
}
func (u *ULID) UnmarshalGQL(v interface{}) (err error) {
str, ok := v.(string)
func UnmarshalULID(v interface{}) (ulid.ULID, error) {
rawULID, ok := v.(string)
if !ok {
return fmt.Errorf("ulid must be a string")
return ulid.ULID{}, fmt.Errorf("ulid must be a string")
}
id, err := ulid.Parse(str)
u, err := ulid.Parse(rawULID)
if err != nil {
return fmt.Errorf("failed to parse ULID: %w", err)
return ulid.ULID{}, fmt.Errorf("failed to parse ULID: %w", err)
}
*u = ULID(id)
return nil
return u, nil
}
func (u ULID) MarshalGQL(w io.Writer) {
fmt.Fprint(w, strconv.Quote(ulid.ULID(u).String()))
func MarshalURL(u *url.URL) graphql.Marshaler {
return graphql.WriterFunc(func(w io.Writer) {
fmt.Fprint(w, strconv.Quote(u.String()))
})
}
func UnmarshalURL(v interface{}) (*url.URL, error) {
rawURL, ok := v.(string)
if !ok {
return nil, fmt.Errorf("url must be a string")
}
u, err := url.Parse(rawURL)
if err != nil {
return nil, fmt.Errorf("failed to parse URL: %w", err)
}
return u, nil
}

View File

@ -5,8 +5,11 @@ package api
import (
"fmt"
"io"
"net/url"
"strconv"
"time"
"github.com/oklog/ulid"
)
type ClearHTTPRequestLogResult struct {
@ -21,13 +24,22 @@ type DeleteProjectResult struct {
Success bool `json:"success"`
}
type DeleteSenderRequestsResult struct {
Success bool `json:"success"`
}
type HTTPHeader struct {
Key string `json:"key"`
Value string `json:"value"`
}
type HTTPHeaderInput struct {
Key string `json:"key"`
Value string `json:"value"`
}
type HTTPRequestLog struct {
ID ULID `json:"id"`
ID ulid.ULID `json:"id"`
URL string `json:"url"`
Method HTTPMethod `json:"method"`
Proto string `json:"proto"`
@ -48,7 +60,9 @@ type HTTPRequestLogFilterInput struct {
}
type HTTPResponseLog struct {
Proto string `json:"proto"`
// Will be the same ID as its related request ID.
ID ulid.ULID `json:"id"`
Proto HTTPProtocol `json:"proto"`
StatusCode int `json:"statusCode"`
StatusReason string `json:"statusReason"`
Body *string `json:"body"`
@ -56,9 +70,9 @@ type HTTPResponseLog struct {
}
type Project struct {
ID ULID `json:"id"`
Name string `json:"name"`
IsActive bool `json:"isActive"`
ID ulid.ULID `json:"id"`
Name string `json:"name"`
IsActive bool `json:"isActive"`
}
type ScopeHeader struct {
@ -83,6 +97,37 @@ type ScopeRuleInput struct {
Body *string `json:"body"`
}
type SenderRequest struct {
ID ulid.ULID `json:"id"`
SourceRequestLogID *ulid.ULID `json:"sourceRequestLogID"`
URL *url.URL `json:"url"`
Method HTTPMethod `json:"method"`
Proto HTTPProtocol `json:"proto"`
Headers []HTTPHeader `json:"headers"`
Body *string `json:"body"`
Timestamp time.Time `json:"timestamp"`
Response *HTTPResponseLog `json:"response"`
}
type SenderRequestFilter struct {
OnlyInScope bool `json:"onlyInScope"`
SearchExpression *string `json:"searchExpression"`
}
type SenderRequestFilterInput struct {
OnlyInScope *bool `json:"onlyInScope"`
SearchExpression *string `json:"searchExpression"`
}
type SenderRequestInput struct {
ID *ulid.ULID `json:"id"`
URL *url.URL `json:"url"`
Method *HTTPMethod `json:"method"`
Proto *HTTPProtocol `json:"proto"`
Headers []HTTPHeaderInput `json:"headers"`
Body *string `json:"body"`
}
type HTTPMethod string
const (
@ -137,3 +182,44 @@ func (e *HTTPMethod) UnmarshalGQL(v interface{}) error {
func (e HTTPMethod) MarshalGQL(w io.Writer) {
fmt.Fprint(w, strconv.Quote(e.String()))
}
type HTTPProtocol string
const (
HTTPProtocolHTTP1 HTTPProtocol = "HTTP1"
HTTPProtocolHTTP2 HTTPProtocol = "HTTP2"
)
var AllHTTPProtocol = []HTTPProtocol{
HTTPProtocolHTTP1,
HTTPProtocolHTTP2,
}
func (e HTTPProtocol) IsValid() bool {
switch e {
case HTTPProtocolHTTP1, HTTPProtocolHTTP2:
return true
}
return false
}
func (e HTTPProtocol) String() string {
return string(e)
}
func (e *HTTPProtocol) UnmarshalGQL(v interface{}) error {
str, ok := v.(string)
if !ok {
return fmt.Errorf("enums must be strings")
}
*e = HTTPProtocol(str)
if !e.IsValid() {
return fmt.Errorf("%s is not a valid HttpProtocol", str)
}
return nil
}
func (e HTTPProtocol) MarshalGQL(w io.Writer) {
fmt.Fprint(w, strconv.Quote(e.String()))
}

View File

@ -6,6 +6,7 @@ import (
"context"
"errors"
"fmt"
"net/http"
"regexp"
"strings"
@ -17,11 +18,23 @@ import (
"github.com/dstotijn/hetty/pkg/reqlog"
"github.com/dstotijn/hetty/pkg/scope"
"github.com/dstotijn/hetty/pkg/search"
"github.com/dstotijn/hetty/pkg/sender"
)
var httpProtocolMap = map[string]HTTPProtocol{
sender.HTTPProto1: HTTPProtocolHTTP1,
sender.HTTPProto2: HTTPProtocolHTTP2,
}
var revHTTPProtocolMap = map[HTTPProtocol]string{
HTTPProtocolHTTP1: sender.HTTPProto1,
HTTPProtocolHTTP2: sender.HTTPProto2,
}
type Resolver struct {
ProjectService proj.Service
RequestLogService *reqlog.Service
RequestLogService reqlog.Service
SenderService sender.Service
}
type (
@ -54,8 +67,8 @@ func (r *queryResolver) HTTPRequestLogs(ctx context.Context) ([]HTTPRequestLog,
return logs, nil
}
func (r *queryResolver) HTTPRequestLog(ctx context.Context, id ULID) (*HTTPRequestLog, error) {
log, err := r.RequestLogService.FindRequestLogByID(ctx, ulid.ULID(id))
func (r *queryResolver) HTTPRequestLog(ctx context.Context, id ulid.ULID) (*HTTPRequestLog, error) {
log, err := r.RequestLogService.FindRequestLogByID(ctx, id)
if errors.Is(err, reqlog.ErrRequestNotFound) {
return nil, nil
} else if err != nil {
@ -77,7 +90,7 @@ func parseRequestLog(reqLog reqlog.RequestLog) (HTTPRequestLog, error) {
}
log := HTTPRequestLog{
ID: ULID(reqLog.ID),
ID: reqLog.ID,
Proto: reqLog.Proto,
Method: method,
Timestamp: ulid.Time(reqLog.ID.Time()),
@ -106,36 +119,54 @@ func parseRequestLog(reqLog reqlog.RequestLog) (HTTPRequestLog, error) {
}
if reqLog.Response != nil {
log.Response = &HTTPResponseLog{
Proto: reqLog.Response.Proto,
StatusCode: reqLog.Response.StatusCode,
}
statusReasonSubs := strings.SplitN(reqLog.Response.Status, " ", 2)
if len(statusReasonSubs) == 2 {
log.Response.StatusReason = statusReasonSubs[1]
resLog, err := parseResponseLog(*reqLog.Response)
if err != nil {
return HTTPRequestLog{}, err
}
if len(reqLog.Response.Body) > 0 {
bodyStr := string(reqLog.Response.Body)
log.Response.Body = &bodyStr
}
resLog.ID = reqLog.ID
if reqLog.Response.Header != nil {
log.Response.Headers = make([]HTTPHeader, 0)
log.Response = &resLog
}
for key, values := range reqLog.Response.Header {
for _, value := range values {
log.Response.Headers = append(log.Response.Headers, HTTPHeader{
Key: key,
Value: value,
})
}
return log, nil
}
func parseResponseLog(resLog reqlog.ResponseLog) (HTTPResponseLog, error) {
proto := httpProtocolMap[resLog.Proto]
if !proto.IsValid() {
return HTTPResponseLog{}, fmt.Errorf("sender response has invalid protocol: %v", resLog.Proto)
}
httpResLog := HTTPResponseLog{
Proto: proto,
StatusCode: resLog.StatusCode,
}
statusReasonSubs := strings.SplitN(resLog.Status, " ", 2)
if len(statusReasonSubs) == 2 {
httpResLog.StatusReason = statusReasonSubs[1]
}
if len(resLog.Body) > 0 {
bodyStr := string(resLog.Body)
httpResLog.Body = &bodyStr
}
if resLog.Header != nil {
httpResLog.Headers = make([]HTTPHeader, 0)
for key, values := range resLog.Header {
for _, value := range values {
httpResLog.Headers = append(httpResLog.Headers, HTTPHeader{
Key: key,
Value: value,
})
}
}
}
return log, nil
return httpResLog, nil
}
func (r *mutationResolver) CreateProject(ctx context.Context, name string) (*Project, error) {
@ -147,14 +178,14 @@ func (r *mutationResolver) CreateProject(ctx context.Context, name string) (*Pro
}
return &Project{
ID: ULID(p.ID),
ID: p.ID,
Name: p.Name,
IsActive: r.ProjectService.IsProjectActive(p.ID),
}, nil
}
func (r *mutationResolver) OpenProject(ctx context.Context, id ULID) (*Project, error) {
p, err := r.ProjectService.OpenProject(ctx, ulid.ULID(id))
func (r *mutationResolver) OpenProject(ctx context.Context, id ulid.ULID) (*Project, error) {
p, err := r.ProjectService.OpenProject(ctx, id)
if errors.Is(err, proj.ErrInvalidName) {
return nil, gqlerror.Errorf("Project name must only contain alphanumeric or space chars.")
} else if err != nil {
@ -162,7 +193,7 @@ func (r *mutationResolver) OpenProject(ctx context.Context, id ULID) (*Project,
}
return &Project{
ID: ULID(p.ID),
ID: p.ID,
Name: p.Name,
IsActive: r.ProjectService.IsProjectActive(p.ID),
}, nil
@ -177,7 +208,7 @@ func (r *queryResolver) ActiveProject(ctx context.Context) (*Project, error) {
}
return &Project{
ID: ULID(p.ID),
ID: p.ID,
Name: p.Name,
IsActive: r.ProjectService.IsProjectActive(p.ID),
}, nil
@ -192,7 +223,7 @@ func (r *queryResolver) Projects(ctx context.Context) ([]Project, error) {
projects := make([]Project, len(p))
for i, proj := range p {
projects[i] = Project{
ID: ULID(proj.ID),
ID: proj.ID,
Name: proj.Name,
IsActive: r.ProjectService.IsProjectActive(proj.ID),
}
@ -224,8 +255,8 @@ func (r *mutationResolver) CloseProject(ctx context.Context) (*CloseProjectResul
return &CloseProjectResult{true}, nil
}
func (r *mutationResolver) DeleteProject(ctx context.Context, id ULID) (*DeleteProjectResult, error) {
if err := r.ProjectService.DeleteProject(ctx, ulid.ULID(id)); err != nil {
func (r *mutationResolver) DeleteProject(ctx context.Context, id ulid.ULID) (*DeleteProjectResult, error) {
if err := r.ProjectService.DeleteProject(ctx, id); err != nil {
return nil, fmt.Errorf("could not delete project: %w", err)
}
@ -296,7 +327,7 @@ func (r *mutationResolver) SetScope(ctx context.Context, input []ScopeRuleInput)
}
func (r *queryResolver) HTTPRequestLogFilter(ctx context.Context) (*HTTPRequestLogFilter, error) {
return findReqFilterToHTTPReqLogFilter(r.RequestLogService.FindReqsFilter), nil
return findReqFilterToHTTPReqLogFilter(r.RequestLogService.FindReqsFilter()), nil
}
func (r *mutationResolver) SetHTTPRequestLogFilter(
@ -318,6 +349,221 @@ func (r *mutationResolver) SetHTTPRequestLogFilter(
return findReqFilterToHTTPReqLogFilter(filter), nil
}
func (r *queryResolver) SenderRequest(ctx context.Context, id ulid.ULID) (*SenderRequest, error) {
senderReq, err := r.SenderService.FindRequestByID(ctx, id)
if errors.Is(err, sender.ErrRequestNotFound) {
return nil, nil
} else if err != nil {
return nil, fmt.Errorf("could not get request by ID: %w", err)
}
req, err := parseSenderRequest(senderReq)
if err != nil {
return nil, err
}
return &req, nil
}
func (r *queryResolver) SenderRequests(ctx context.Context) ([]SenderRequest, error) {
reqs, err := r.SenderService.FindRequests(ctx)
if errors.Is(err, proj.ErrNoProject) {
return nil, noActiveProjectErr(ctx)
} else if err != nil {
return nil, fmt.Errorf("failed to find sender requests: %w", err)
}
senderReqs := make([]SenderRequest, len(reqs))
for i, req := range reqs {
req, err := parseSenderRequest(req)
if err != nil {
return nil, err
}
senderReqs[i] = req
}
return senderReqs, nil
}
func (r *mutationResolver) SetSenderRequestFilter(
ctx context.Context,
input *SenderRequestFilterInput,
) (*SenderRequestFilter, error) {
filter, err := findSenderRequestsFilterFromInput(input)
if err != nil {
return nil, fmt.Errorf("could not parse request log filter: %w", err)
}
err = r.ProjectService.SetSenderRequestFindFilter(ctx, filter)
if errors.Is(err, proj.ErrNoProject) {
return nil, noActiveProjectErr(ctx)
} else if err != nil {
return nil, fmt.Errorf("could not set request log filter: %w", err)
}
return findReqFilterToSenderReqFilter(filter), nil
}
func (r *mutationResolver) CreateOrUpdateSenderRequest(ctx context.Context, input SenderRequestInput) (*SenderRequest, error) {
req := sender.Request{
URL: input.URL,
Header: make(http.Header),
}
if input.ID != nil {
req.ID = *input.ID
}
if input.Method != nil {
req.Method = input.Method.String()
}
if input.Proto != nil {
req.Proto = revHTTPProtocolMap[*input.Proto]
}
for _, header := range input.Headers {
req.Header.Add(header.Key, header.Value)
}
if input.Body != nil {
req.Body = []byte(*input.Body)
}
req, err := r.SenderService.CreateOrUpdateRequest(ctx, req)
if errors.Is(err, proj.ErrNoProject) {
return nil, noActiveProjectErr(ctx)
} else if err != nil {
return nil, fmt.Errorf("could not create sender request: %w", err)
}
senderReq, err := parseSenderRequest(req)
if err != nil {
return nil, err
}
return &senderReq, nil
}
func (r *mutationResolver) CreateSenderRequestFromHTTPRequestLog(ctx context.Context, id ulid.ULID) (*SenderRequest, error) {
req, err := r.SenderService.CloneFromRequestLog(ctx, id)
if errors.Is(err, proj.ErrNoProject) {
return nil, noActiveProjectErr(ctx)
} else if err != nil {
return nil, fmt.Errorf("could not create sender request from http request log: %w", err)
}
senderReq, err := parseSenderRequest(req)
if err != nil {
return nil, err
}
return &senderReq, nil
}
func (r *mutationResolver) SendRequest(ctx context.Context, id ulid.ULID) (*SenderRequest, error) {
// Use new context, because we don't want to risk interrupting sending the request
// or the subsequent storing of the response, e.g. if ctx gets cancelled or
// times out.
ctx2 := context.Background()
var sendErr *sender.SendError
req, err := r.SenderService.SendRequest(ctx2, id)
if errors.Is(err, proj.ErrNoProject) {
return nil, noActiveProjectErr(ctx)
} else if errors.As(err, &sendErr) {
return nil, &gqlerror.Error{
Path: graphql.GetPath(ctx),
Message: fmt.Sprintf("Sending request failed: %v", sendErr.Unwrap()),
Extensions: map[string]interface{}{
"code": "send_request_failed",
},
}
} else if err != nil {
return nil, fmt.Errorf("could not send request: %w", err)
}
senderReq, err := parseSenderRequest(req)
if err != nil {
return nil, err
}
return &senderReq, nil
}
func (r *mutationResolver) DeleteSenderRequests(ctx context.Context) (*DeleteSenderRequestsResult, error) {
project, err := r.ProjectService.ActiveProject(ctx)
if errors.Is(err, proj.ErrNoProject) {
return nil, noActiveProjectErr(ctx)
} else if err != nil {
return nil, fmt.Errorf("could not get active project: %w", err)
}
if err := r.SenderService.DeleteRequests(ctx, project.ID); err != nil {
return nil, fmt.Errorf("could not clear request log: %w", err)
}
return &DeleteSenderRequestsResult{true}, nil
}
func parseSenderRequest(req sender.Request) (SenderRequest, error) {
method := HTTPMethod(req.Method)
if method != "" && !method.IsValid() {
return SenderRequest{}, fmt.Errorf("sender request has invalid method: %v", method)
}
reqProto := httpProtocolMap[req.Proto]
if !reqProto.IsValid() {
return SenderRequest{}, fmt.Errorf("sender request has invalid protocol: %v", req.Proto)
}
senderReq := SenderRequest{
ID: req.ID,
URL: req.URL,
Method: method,
Proto: HTTPProtocol(req.Proto),
Timestamp: ulid.Time(req.ID.Time()),
}
if req.SourceRequestLogID.Compare(ulid.ULID{}) != 0 {
senderReq.SourceRequestLogID = &req.SourceRequestLogID
}
if req.Header != nil {
senderReq.Headers = make([]HTTPHeader, 0)
for key, values := range req.Header {
for _, value := range values {
senderReq.Headers = append(senderReq.Headers, HTTPHeader{
Key: key,
Value: value,
})
}
}
}
if len(req.Body) > 0 {
bodyStr := string(req.Body)
senderReq.Body = &bodyStr
}
if req.Response != nil {
resLog, err := parseResponseLog(*req.Response)
if err != nil {
return SenderRequest{}, err
}
resLog.ID = req.ID
senderReq.Response = &resLog
}
return senderReq, nil
}
func stringPtrToRegexp(s *string) (*regexp.Regexp, error) {
if s == nil {
return nil, nil
@ -364,6 +610,27 @@ func findRequestsFilterFromInput(input *HTTPRequestLogFilterInput) (filter reqlo
return
}
func findSenderRequestsFilterFromInput(input *SenderRequestFilterInput) (filter sender.FindRequestsFilter, err error) {
if input == nil {
return
}
if input.OnlyInScope != nil {
filter.OnlyInScope = *input.OnlyInScope
}
if input.SearchExpression != nil && *input.SearchExpression != "" {
expr, err := search.ParseQuery(*input.SearchExpression)
if err != nil {
return sender.FindRequestsFilter{}, fmt.Errorf("could not parse search query: %w", err)
}
filter.SearchExpr = expr
}
return
}
func findReqFilterToHTTPReqLogFilter(findReqFilter reqlog.FindRequestsFilter) *HTTPRequestLogFilter {
empty := reqlog.FindRequestsFilter{}
if findReqFilter == empty {
@ -382,6 +649,24 @@ func findReqFilterToHTTPReqLogFilter(findReqFilter reqlog.FindRequestsFilter) *H
return httpReqLogFilter
}
func findReqFilterToSenderReqFilter(findReqFilter sender.FindRequestsFilter) *SenderRequestFilter {
empty := sender.FindRequestsFilter{}
if findReqFilter == empty {
return nil
}
senderReqFilter := &SenderRequestFilter{
OnlyInScope: findReqFilter.OnlyInScope,
}
if findReqFilter.SearchExpr != nil {
searchExpr := findReqFilter.SearchExpr.String()
senderReqFilter.SearchExpression = &searchExpr
}
return senderReqFilter
}
func noActiveProjectErr(ctx context.Context) error {
return &gqlerror.Error{
Path: graphql.GetPath(ctx),

View File

@ -10,7 +10,11 @@ type HttpRequestLog {
}
type HttpResponseLog {
proto: String!
"""
Will be the same ID as its related request ID.
"""
id: ID!
proto: HttpProtocol!
statusCode: Int!
statusReason: String!
body: String
@ -62,6 +66,10 @@ type ClearHTTPRequestLogResult {
success: Boolean!
}
type DeleteSenderRequestsResult {
success: Boolean!
}
input HttpRequestLogFilterInput {
onlyInScope: Boolean
searchExpression: String
@ -72,6 +80,42 @@ type HttpRequestLogFilter {
searchExpression: String
}
input SenderRequestInput {
id: ID
url: URL!
method: HttpMethod
proto: HttpProtocol
headers: [HttpHeaderInput!]
body: String
}
input HttpHeaderInput {
key: String!
value: String!
}
type SenderRequest {
id: ID!
sourceRequestLogID: ID
url: URL!
method: HttpMethod!
proto: HttpProtocol!
headers: [HttpHeader!]
body: String
timestamp: Time!
response: HttpResponseLog
}
input SenderRequestFilterInput {
onlyInScope: Boolean
searchExpression: String
}
type SenderRequestFilter {
onlyInScope: Boolean!
searchExpression: String
}
type Query {
httpRequestLog(id: ID!): HttpRequestLog
httpRequestLogs: [HttpRequestLog!]!
@ -79,6 +123,8 @@ type Query {
activeProject: Project
projects: [Project!]!
scope: [ScopeRule!]!
senderRequest(id: ID!): SenderRequest
senderRequests: [SenderRequest!]!
}
type Mutation {
@ -91,6 +137,11 @@ type Mutation {
setHttpRequestLogFilter(
filter: HttpRequestLogFilterInput
): HttpRequestLogFilter
setSenderRequestFilter(filter: SenderRequestFilterInput): SenderRequestFilter
createOrUpdateSenderRequest(request: SenderRequestInput!): SenderRequest!
createSenderRequestFromHttpRequestLog(id: ID!): SenderRequest!
sendRequest(id: ID!): SenderRequest!
deleteSenderRequests: DeleteSenderRequestsResult!
}
enum HttpMethod {
@ -105,5 +156,11 @@ enum HttpMethod {
PATCH
}
enum HttpProtocol {
HTTP1
HTTP2
}
scalar Time
scalar Regexp
scalar URL

View File

@ -8,12 +8,16 @@ import (
const (
// Key prefixes. Each prefix value should be unique.
projectPrefix = 0x00
reqLogPrefix = 0x01
resLogPrefix = 0x02
projectPrefix = 0x00
reqLogPrefix = 0x01
resLogPrefix = 0x02
senderReqPrefix = 0x03
// Request log indices.
reqLogProjectIDIndex = 0x00
// Sender request indices.
senderReqProjectIDIndex = 0x00
)
// Database is used to store and retrieve data from an underlying Badger database.

View File

@ -65,6 +65,11 @@ func (db *Database) DeleteProject(ctx context.Context, projectID ulid.ULID) erro
return fmt.Errorf("badger: failed to delete project request logs: %w", err)
}
err = db.DeleteSenderRequests(ctx, projectID)
if err != nil {
return fmt.Errorf("badger: failed to delete project sender requests: %w", err)
}
err = db.badger.Update(func(txn *badger.Txn) error {
return txn.Delete(entryKey(projectPrefix, 0, projectID[:]))
})

View File

@ -56,6 +56,7 @@ func TestUpsertProject(t *testing.T) {
Settings: proj.Settings{
ReqLogBypassOutOfScope: true,
ReqLogOnlyFindInScope: true,
ReqLogSearchExpr: searchExpr,
ScopeRules: []scope.Rule{
{
URL: regexp.MustCompile("^https://(.*)example.com(.*)$"),
@ -66,7 +67,6 @@ func TestUpsertProject(t *testing.T) {
Body: regexp.MustCompile("^foo(.*)"),
},
},
SearchExpr: searchExpr,
},
}
@ -178,21 +178,38 @@ func TestDeleteProject(t *testing.T) {
// Store fixtures.
projectID := ulid.MustNew(ulid.Timestamp(time.Now()), ulidEntropy)
reqLogID := ulid.MustNew(ulid.Timestamp(time.Now()), ulidEntropy)
senderReqID := ulid.MustNew(ulid.Timestamp(time.Now()), ulidEntropy)
err = badgerDB.Update(func(txn *badgerdb.Txn) error {
// Project item.
if err := txn.Set(entryKey(projectPrefix, 0, projectID[:]), nil); err != nil {
return err
}
// Sender request items.
if err := txn.Set(entryKey(senderReqPrefix, 0, senderReqID[:]), nil); err != nil {
return err
}
if err := txn.Set(entryKey(resLogPrefix, 0, senderReqID[:]), nil); err != nil {
return err
}
err := txn.Set(entryKey(senderReqPrefix, senderReqProjectIDIndex, append(projectID[:], senderReqID[:]...)), nil)
if err != nil {
return err
}
// Request log items.
if err := txn.Set(entryKey(reqLogPrefix, 0, reqLogID[:]), nil); err != nil {
return err
}
if err := txn.Set(entryKey(resLogPrefix, 0, reqLogID[:]), nil); err != nil {
return err
}
err := txn.Set(entryKey(reqLogPrefix, reqLogProjectIDIndex, append(projectID[:], reqLogID[:]...)), nil)
err = txn.Set(entryKey(reqLogPrefix, reqLogProjectIDIndex, append(projectID[:], reqLogID[:]...)), nil)
if err != nil {
return err
}
return nil
})
if err != nil {
@ -222,7 +239,7 @@ func TestDeleteProject(t *testing.T) {
t.Fatalf("expected `badger.ErrKeyNotFound`, got: %v", err)
}
// Assert response log item was deleted.
// Assert response log item related to request log was deleted.
err = badgerDB.View(func(txn *badgerdb.Txn) error {
_, err := txn.Get(entryKey(resLogPrefix, 0, reqLogID[:]))
return err
@ -239,6 +256,33 @@ func TestDeleteProject(t *testing.T) {
if !errors.Is(err, badgerdb.ErrKeyNotFound) {
t.Fatalf("expected `badger.ErrKeyNotFound`, got: %v", err)
}
// Assert sender request item was deleted.
err = badgerDB.View(func(txn *badgerdb.Txn) error {
_, err := txn.Get(entryKey(senderReqPrefix, 0, senderReqID[:]))
return err
})
if !errors.Is(err, badgerdb.ErrKeyNotFound) {
t.Fatalf("expected `badger.ErrKeyNotFound`, got: %v", err)
}
// Assert response log item related to sender request was deleted.
err = badgerDB.View(func(txn *badgerdb.Txn) error {
_, err := txn.Get(entryKey(resLogPrefix, 0, senderReqID[:]))
return err
})
if !errors.Is(err, badgerdb.ErrKeyNotFound) {
t.Fatalf("expected `badger.ErrKeyNotFound`, got: %v", err)
}
// Assert sender request project ID index key was deleted.
err = badgerDB.View(func(txn *badgerdb.Txn) error {
_, err := txn.Get(entryKey(senderReqPrefix, senderReqProjectIDIndex, append(projectID[:], senderReqID[:]...)))
return err
})
if !errors.Is(err, badgerdb.ErrKeyNotFound) {
t.Fatalf("expected `badger.ErrKeyNotFound`, got: %v", err)
}
}
func TestProjects(t *testing.T) {

View File

@ -66,7 +66,11 @@ func (db *Database) FindRequestLogs(ctx context.Context, filter reqlog.FindReque
func getRequestLogWithResponse(txn *badger.Txn, reqLogID ulid.ULID) (reqlog.RequestLog, error) {
item, err := txn.Get(entryKey(reqLogPrefix, 0, reqLogID[:]))
if err != nil {
switch {
case errors.Is(err, badger.ErrKeyNotFound):
return reqlog.RequestLog{}, reqlog.ErrRequestNotFound
case err != nil:
return reqlog.RequestLog{}, fmt.Errorf("failed to lookup request log item: %w", err)
}

236
pkg/db/badger/sender.go Normal file
View File

@ -0,0 +1,236 @@
package badger
import (
"bytes"
"context"
"encoding/gob"
"errors"
"fmt"
"github.com/dgraph-io/badger/v3"
"github.com/oklog/ulid"
"github.com/dstotijn/hetty/pkg/reqlog"
"github.com/dstotijn/hetty/pkg/scope"
"github.com/dstotijn/hetty/pkg/sender"
)
func (db *Database) StoreSenderRequest(ctx context.Context, req sender.Request) error {
buf := bytes.Buffer{}
err := gob.NewEncoder(&buf).Encode(req)
if err != nil {
return fmt.Errorf("badger: failed to encode sender request: %w", err)
}
entries := []*badger.Entry{
// Sender request itself.
{
Key: entryKey(senderReqPrefix, 0, req.ID[:]),
Value: buf.Bytes(),
},
// Index by project ID.
{
Key: entryKey(senderReqPrefix, senderReqProjectIDIndex, append(req.ProjectID[:], req.ID[:]...)),
},
}
err = db.badger.Update(func(txn *badger.Txn) error {
for i := range entries {
err := txn.SetEntry(entries[i])
if err != nil {
return err
}
}
return nil
})
if err != nil {
return fmt.Errorf("badger: failed to commit transaction: %w", err)
}
return nil
}
func (db *Database) FindSenderRequestByID(ctx context.Context, senderReqID ulid.ULID) (sender.Request, error) {
txn := db.badger.NewTransaction(false)
defer txn.Discard()
req, err := getSenderRequestWithResponseLog(txn, senderReqID)
if err != nil {
return sender.Request{}, fmt.Errorf("badger: failed to get sender request: %w", err)
}
return req, nil
}
func (db *Database) FindSenderRequests(ctx context.Context, filter sender.FindRequestsFilter, scope *scope.Scope) ([]sender.Request, error) {
if filter.ProjectID.Compare(ulid.ULID{}) == 0 {
return nil, sender.ErrProjectIDMustBeSet
}
txn := db.badger.NewTransaction(false)
defer txn.Discard()
senderReqIDs, err := findSenderRequestIDsByProjectID(txn, filter.ProjectID)
if err != nil {
return nil, fmt.Errorf("badger: failed to find sender request IDs: %w", err)
}
senderReqs := make([]sender.Request, 0, len(senderReqIDs))
for _, id := range senderReqIDs {
senderReq, err := getSenderRequestWithResponseLog(txn, id)
if err != nil {
return nil, fmt.Errorf("badger: failed to get sender request (id: %v): %w", id.String(), err)
}
if filter.OnlyInScope {
if !senderReq.MatchScope(scope) {
continue
}
}
// Filter by search expression.
// TODO: Once pagination is introduced, this filter logic should be done
// as items are retrieved (e.g. when using a `badger.Iterator`).
if filter.SearchExpr != nil {
match, err := senderReq.Matches(filter.SearchExpr)
if err != nil {
return nil, fmt.Errorf(
"badger: failed to match search expression for sender request (id: %v): %w",
id.String(), err,
)
}
if !match {
continue
}
}
senderReqs = append(senderReqs, senderReq)
}
return senderReqs, nil
}
func (db *Database) DeleteSenderRequests(ctx context.Context, projectID ulid.ULID) error {
// Note: this transaction is used just for reading; we use the `badger.WriteBatch`
// API to bulk delete items.
txn := db.badger.NewTransaction(false)
defer txn.Discard()
senderReqIDs, err := findSenderRequestIDsByProjectID(txn, projectID)
if err != nil {
return fmt.Errorf("badger: failed to find sender request IDs: %w", err)
}
writeBatch := db.badger.NewWriteBatch()
defer writeBatch.Cancel()
for _, senderReqID := range senderReqIDs {
// Delete sender requests.
err := writeBatch.Delete(entryKey(senderReqPrefix, 0, senderReqID[:]))
if err != nil {
return fmt.Errorf("badger: failed to delete sender requests: %w", err)
}
// Delete related response log.
err = writeBatch.Delete(entryKey(resLogPrefix, 0, senderReqID[:]))
if err != nil {
return fmt.Errorf("badger: failed to delete request log: %w", err)
}
}
if err := writeBatch.Flush(); err != nil {
return fmt.Errorf("badger: failed to commit batch write: %w", err)
}
err = db.badger.DropPrefix(entryKey(senderReqPrefix, senderReqProjectIDIndex, projectID[:]))
if err != nil {
return fmt.Errorf("badger: failed to drop sender request project ID index items: %w", err)
}
return nil
}
func getSenderRequestWithResponseLog(txn *badger.Txn, senderReqID ulid.ULID) (sender.Request, error) {
item, err := txn.Get(entryKey(senderReqPrefix, 0, senderReqID[:]))
switch {
case errors.Is(err, badger.ErrKeyNotFound):
return sender.Request{}, sender.ErrRequestNotFound
case err != nil:
return sender.Request{}, fmt.Errorf("failed to lookup sender request item: %w", err)
}
req := sender.Request{
ID: senderReqID,
}
err = item.Value(func(rawSenderReq []byte) error {
err = gob.NewDecoder(bytes.NewReader(rawSenderReq)).Decode(&req)
if err != nil {
return fmt.Errorf("failed to decode sender request: %w", err)
}
return nil
})
if err != nil {
return sender.Request{}, fmt.Errorf("failed to retrieve or parse sender request value: %w", err)
}
item, err = txn.Get(entryKey(resLogPrefix, 0, senderReqID[:]))
if errors.Is(err, badger.ErrKeyNotFound) {
return req, nil
}
if err != nil {
return sender.Request{}, fmt.Errorf("failed to get response log: %w", err)
}
err = item.Value(func(rawReslog []byte) error {
var resLog reqlog.ResponseLog
err = gob.NewDecoder(bytes.NewReader(rawReslog)).Decode(&resLog)
if err != nil {
return fmt.Errorf("failed to decode response log: %w", err)
}
req.Response = &resLog
return nil
})
if err != nil {
return sender.Request{}, fmt.Errorf("failed to retrieve or parse response log value: %w", err)
}
return req, nil
}
func findSenderRequestIDsByProjectID(txn *badger.Txn, projectID ulid.ULID) ([]ulid.ULID, error) {
senderReqIDs := make([]ulid.ULID, 0)
opts := badger.DefaultIteratorOptions
opts.PrefetchValues = false
opts.Reverse = true
iterator := txn.NewIterator(opts)
defer iterator.Close()
var projectIndexKey []byte
prefix := entryKey(senderReqPrefix, senderReqProjectIDIndex, projectID[:])
for iterator.Seek(append(prefix, 255)); iterator.ValidForPrefix(prefix); iterator.Next() {
projectIndexKey = iterator.Item().KeyCopy(projectIndexKey)
var id ulid.ULID
// The request log ID starts *after* the first 2 prefix and index bytes
// and the 16 byte project ID.
if err := id.UnmarshalBinary(projectIndexKey[18:]); err != nil {
return nil, fmt.Errorf("failed to parse sender request ID: %w", err)
}
senderReqIDs = append(senderReqIDs, id)
}
return senderReqIDs, nil
}

View File

@ -0,0 +1,204 @@
package badger_test
import (
"context"
"errors"
"math/rand"
"net/http"
"net/url"
"testing"
"time"
badgerdb "github.com/dgraph-io/badger/v3"
"github.com/google/go-cmp/cmp"
"github.com/oklog/ulid"
"github.com/dstotijn/hetty/pkg/db/badger"
"github.com/dstotijn/hetty/pkg/reqlog"
"github.com/dstotijn/hetty/pkg/sender"
)
//nolint:gosec
var ulidEntropy = rand.New(rand.NewSource(time.Now().UnixNano()))
var exampleURL = func() *url.URL {
u, err := url.Parse("https://example.com/foobar")
if err != nil {
panic(err)
}
return u
}()
func TestFindRequestByID(t *testing.T) {
t.Parallel()
database, err := badger.OpenDatabase(badgerdb.DefaultOptions("").WithInMemory(true))
if err != nil {
t.Fatalf("failed to open badger database: %v", err)
}
defer database.Close()
// See: https://go.dev/blog/subtests#cleaning-up-after-a-group-of-parallel-tests
t.Run("group", func(t *testing.T) {
t.Run("sender request not found", func(t *testing.T) {
t.Parallel()
_, err := database.FindSenderRequestByID(context.Background(), ulid.ULID{})
if !errors.Is(err, sender.ErrRequestNotFound) {
t.Fatalf("expected `sender.ErrRequestNotFound`, got: %v", err)
}
})
t.Run("sender request found", func(t *testing.T) {
t.Parallel()
exp := sender.Request{
ID: ulid.MustNew(ulid.Timestamp(time.Now()), ulidEntropy),
ProjectID: ulid.MustNew(ulid.Timestamp(time.Now()), ulidEntropy),
SourceRequestLogID: ulid.MustNew(ulid.Timestamp(time.Now()), ulidEntropy),
URL: exampleURL,
Method: http.MethodGet,
Proto: sender.HTTPProto2,
Header: http.Header{
"X-Foo": []string{"bar"},
},
Body: []byte("foo"),
}
err := database.StoreSenderRequest(context.Background(), exp)
if err != nil {
t.Fatalf("unexpected error (expected: nil, got: %v)", err)
}
resLog := reqlog.ResponseLog{
Proto: "HTTP/2.0",
Status: "200 OK",
StatusCode: 200,
Header: http.Header{
"X-Yolo": []string{"swag"},
},
Body: []byte("bar"),
}
err = database.StoreResponseLog(context.Background(), exp.ID, resLog)
if err != nil {
t.Fatalf("unexpected error (expected: nil, got: %v)", err)
}
exp.Response = &resLog
got, err := database.FindSenderRequestByID(context.Background(), exp.ID)
if err != nil {
t.Fatalf("unexpected error (expected: nil, got: %v)", err)
}
if diff := cmp.Diff(exp, got); diff != "" {
t.Fatalf("sender request not equal (-exp, +got):\n%v", diff)
}
})
})
}
func TestFindSenderRequests(t *testing.T) {
t.Parallel()
t.Run("without project ID in filter", func(t *testing.T) {
t.Parallel()
database, err := badger.OpenDatabase(badgerdb.DefaultOptions("").WithInMemory(true))
if err != nil {
t.Fatalf("failed to open badger database: %v", err)
}
defer database.Close()
filter := sender.FindRequestsFilter{}
_, err = database.FindSenderRequests(context.Background(), filter, nil)
if !errors.Is(err, sender.ErrProjectIDMustBeSet) {
t.Fatalf("expected `sender.ErrProjectIDMustBeSet`, got: %v", err)
}
})
t.Run("returns sender requests and related response logs", func(t *testing.T) {
t.Parallel()
database, err := badger.OpenDatabase(badgerdb.DefaultOptions("").WithInMemory(true))
if err != nil {
t.Fatalf("failed to open badger database: %v", err)
}
defer database.Close()
projectID := ulid.MustNew(ulid.Timestamp(time.Now()), ulidEntropy)
fixtures := []sender.Request{
{
ID: ulid.MustNew(ulid.Timestamp(time.Now()), ulidEntropy),
ProjectID: projectID,
SourceRequestLogID: ulid.MustNew(ulid.Timestamp(time.Now()), ulidEntropy),
URL: exampleURL,
Method: http.MethodPost,
Proto: "HTTP/1.1",
Header: http.Header{
"X-Foo": []string{"baz"},
},
Body: []byte("foo"),
Response: &reqlog.ResponseLog{
Proto: "HTTP/1.1",
Status: "200 OK",
StatusCode: 200,
Header: http.Header{
"X-Yolo": []string{"swag"},
},
Body: []byte("bar"),
},
},
{
ID: ulid.MustNew(ulid.Timestamp(time.Now())+100, ulidEntropy),
ProjectID: projectID,
SourceRequestLogID: ulid.MustNew(ulid.Timestamp(time.Now()), ulidEntropy),
URL: exampleURL,
Method: http.MethodGet,
Proto: "HTTP/1.1",
Header: http.Header{
"X-Foo": []string{"baz"},
},
},
}
// Store fixtures.
for _, senderReq := range fixtures {
err = database.StoreSenderRequest(context.Background(), senderReq)
if err != nil {
t.Fatalf("unexpected error creating request log fixture: %v", err)
}
if senderReq.Response != nil {
err = database.StoreResponseLog(context.Background(), senderReq.ID, *senderReq.Response)
if err != nil {
t.Fatalf("unexpected error creating response log fixture: %v", err)
}
}
}
filter := sender.FindRequestsFilter{
ProjectID: projectID,
}
got, err := database.FindSenderRequests(context.Background(), filter, nil)
if err != nil {
t.Fatalf("unexpected error finding sender requests: %v", err)
}
// We expect the found sender requests are *reversed*, e.g. newest first.
exp := make([]sender.Request, len(fixtures))
for i, j := 0, len(fixtures)-1; i < j; i, j = i+1, j-1 {
exp[i], exp[j] = fixtures[j], fixtures[i]
}
if diff := cmp.Diff(exp, got); diff != "" {
t.Fatalf("sender requests not equal (-exp, +got):\n%v", diff)
}
})
}

View File

@ -15,6 +15,7 @@ import (
"github.com/dstotijn/hetty/pkg/reqlog"
"github.com/dstotijn/hetty/pkg/scope"
"github.com/dstotijn/hetty/pkg/search"
"github.com/dstotijn/hetty/pkg/sender"
)
//nolint:gosec
@ -37,13 +38,15 @@ type Service interface {
Scope() *scope.Scope
SetScopeRules(ctx context.Context, rules []scope.Rule) error
SetRequestLogFindFilter(ctx context.Context, filter reqlog.FindRequestsFilter) error
SetSenderRequestFindFilter(ctx context.Context, filter sender.FindRequestsFilter) error
OnProjectOpen(fn OnProjectOpenFn)
OnProjectClose(fn OnProjectCloseFn)
}
type service struct {
repo Repository
reqLogSvc *reqlog.Service
reqLogSvc reqlog.Service
senderSvc sender.Service
scope *scope.Scope
activeProjectID ulid.ULID
onProjectOpenFns []OnProjectOpenFn
@ -62,8 +65,12 @@ type Project struct {
type Settings struct {
ReqLogBypassOutOfScope bool
ReqLogOnlyFindInScope bool
ScopeRules []scope.Rule
SearchExpr search.Expression
ReqLogSearchExpr search.Expression
SenderOnlyFindInScope bool
SenderSearchExpr search.Expression
ScopeRules []scope.Rule
}
var (
@ -77,7 +84,8 @@ var nameRegexp = regexp.MustCompile(`^[\w\d\s]+$`)
type Config struct {
Repository Repository
ReqLogService *reqlog.Service
ReqLogService reqlog.Service
SenderService sender.Service
Scope *scope.Scope
}
@ -86,6 +94,7 @@ func NewService(cfg Config) (Service, error) {
return &service{
repo: cfg.Repository,
reqLogSvc: cfg.ReqLogService,
senderSvc: cfg.SenderService,
scope: cfg.Scope,
}, nil
}
@ -120,9 +129,11 @@ func (svc *service) CloseProject() error {
closedProjectID := svc.activeProjectID
svc.activeProjectID = ulid.ULID{}
svc.reqLogSvc.ActiveProjectID = ulid.ULID{}
svc.reqLogSvc.BypassOutOfScopeRequests = false
svc.reqLogSvc.FindReqsFilter = reqlog.FindRequestsFilter{}
svc.reqLogSvc.SetActiveProjectID(ulid.ULID{})
svc.reqLogSvc.SetBypassOutOfScopeRequests(false)
svc.reqLogSvc.SetFindReqsFilter(reqlog.FindRequestsFilter{})
svc.senderSvc.SetActiveProjectID(ulid.ULID{})
svc.senderSvc.SetFindReqsFilter(sender.FindRequestsFilter{})
svc.scope.SetRules(nil)
svc.emitProjectClosed(closedProjectID)
@ -154,13 +165,21 @@ func (svc *service) OpenProject(ctx context.Context, projectID ulid.ULID) (Proje
}
svc.activeProjectID = project.ID
svc.reqLogSvc.FindReqsFilter = reqlog.FindRequestsFilter{
svc.reqLogSvc.SetFindReqsFilter(reqlog.FindRequestsFilter{
ProjectID: project.ID,
OnlyInScope: project.Settings.ReqLogOnlyFindInScope,
SearchExpr: project.Settings.SearchExpr,
}
svc.reqLogSvc.BypassOutOfScopeRequests = project.Settings.ReqLogBypassOutOfScope
svc.reqLogSvc.ActiveProjectID = project.ID
SearchExpr: project.Settings.ReqLogSearchExpr,
})
svc.reqLogSvc.SetBypassOutOfScopeRequests(project.Settings.ReqLogBypassOutOfScope)
svc.reqLogSvc.SetActiveProjectID(project.ID)
svc.senderSvc.SetActiveProjectID(project.ID)
svc.senderSvc.SetFindReqsFilter(sender.FindRequestsFilter{
ProjectID: project.ID,
OnlyInScope: project.Settings.SenderOnlyFindInScope,
SearchExpr: project.Settings.SenderSearchExpr,
})
svc.scope.SetRules(project.Settings.ScopeRules)
@ -255,14 +274,35 @@ func (svc *service) SetRequestLogFindFilter(ctx context.Context, filter reqlog.F
filter.ProjectID = project.ID
project.Settings.ReqLogOnlyFindInScope = filter.OnlyInScope
project.Settings.SearchExpr = filter.SearchExpr
project.Settings.ReqLogSearchExpr = filter.SearchExpr
err = svc.repo.UpsertProject(ctx, project)
if err != nil {
return fmt.Errorf("proj: failed to update project: %w", err)
}
svc.reqLogSvc.FindReqsFilter = filter
svc.reqLogSvc.SetFindReqsFilter(filter)
return nil
}
func (svc *service) SetSenderRequestFindFilter(ctx context.Context, filter sender.FindRequestsFilter) error {
project, err := svc.ActiveProject(ctx)
if err != nil {
return err
}
filter.ProjectID = project.ID
project.Settings.SenderOnlyFindInScope = filter.OnlyInScope
project.Settings.SenderSearchExpr = filter.SearchExpr
err = svc.repo.UpsertProject(ctx, project)
if err != nil {
return fmt.Errorf("proj: failed to update project: %w", err)
}
svc.senderSvc.SetFindReqsFilter(filter)
return nil
}

View File

@ -54,13 +54,26 @@ type ResponseLog struct {
Body []byte
}
type Service struct {
BypassOutOfScopeRequests bool
FindReqsFilter FindRequestsFilter
ActiveProjectID ulid.ULID
type Service interface {
FindRequests(ctx context.Context) ([]RequestLog, error)
FindRequestLogByID(ctx context.Context, id ulid.ULID) (RequestLog, error)
ClearRequests(ctx context.Context, projectID ulid.ULID) error
RequestModifier(next proxy.RequestModifyFunc) proxy.RequestModifyFunc
ResponseModifier(next proxy.ResponseModifyFunc) proxy.ResponseModifyFunc
SetActiveProjectID(id ulid.ULID)
ActiveProjectID() ulid.ULID
SetBypassOutOfScopeRequests(bool)
BypassOutOfScopeRequests() bool
SetFindReqsFilter(filter FindRequestsFilter)
FindReqsFilter() FindRequestsFilter
}
scope *scope.Scope
repo Repository
type service struct {
bypassOutOfScopeRequests bool
findReqsFilter FindRequestsFilter
activeProjectID ulid.ULID
scope *scope.Scope
repo Repository
}
type FindRequestsFilter struct {
@ -74,59 +87,35 @@ type Config struct {
Repository Repository
}
func NewService(cfg Config) *Service {
return &Service{
func NewService(cfg Config) Service {
return &service{
repo: cfg.Repository,
scope: cfg.Scope,
}
}
func (svc *Service) FindRequests(ctx context.Context) ([]RequestLog, error) {
return svc.repo.FindRequestLogs(ctx, svc.FindReqsFilter, svc.scope)
func (svc *service) FindRequests(ctx context.Context) ([]RequestLog, error) {
return svc.repo.FindRequestLogs(ctx, svc.findReqsFilter, svc.scope)
}
func (svc *Service) FindRequestLogByID(ctx context.Context, id ulid.ULID) (RequestLog, error) {
func (svc *service) FindRequestLogByID(ctx context.Context, id ulid.ULID) (RequestLog, error) {
return svc.repo.FindRequestLogByID(ctx, id)
}
func (svc *Service) ClearRequests(ctx context.Context, projectID ulid.ULID) error {
func (svc *service) ClearRequests(ctx context.Context, projectID ulid.ULID) error {
return svc.repo.ClearRequestLogs(ctx, projectID)
}
func (svc *Service) storeResponse(ctx context.Context, reqLogID ulid.ULID, res *http.Response) error {
if res.Header.Get("Content-Encoding") == "gzip" {
gzipReader, err := gzip.NewReader(res.Body)
if err != nil {
return fmt.Errorf("could not create gzip reader: %w", err)
}
defer gzipReader.Close()
buf := &bytes.Buffer{}
if _, err := io.Copy(buf, gzipReader); err != nil {
return fmt.Errorf("could not read gzipped response body: %w", err)
}
res.Body = io.NopCloser(buf)
}
body, err := io.ReadAll(res.Body)
func (svc *service) storeResponse(ctx context.Context, reqLogID ulid.ULID, res *http.Response) error {
resLog, err := ParseHTTPResponse(res)
if err != nil {
return fmt.Errorf("could not read body: %w", err)
}
resLog := ResponseLog{
Proto: res.Proto,
StatusCode: res.StatusCode,
Status: res.Status,
Header: res.Header,
Body: body,
return err
}
return svc.repo.StoreResponseLog(ctx, reqLogID, resLog)
}
func (svc *Service) RequestModifier(next proxy.RequestModifyFunc) proxy.RequestModifyFunc {
func (svc *service) RequestModifier(next proxy.RequestModifyFunc) proxy.RequestModifyFunc {
return func(req *http.Request) {
next(req)
@ -149,7 +138,7 @@ func (svc *Service) RequestModifier(next proxy.RequestModifyFunc) proxy.RequestM
}
// Bypass logging if no project is active.
if svc.ActiveProjectID.Compare(ulid.ULID{}) == 0 {
if svc.activeProjectID.Compare(ulid.ULID{}) == 0 {
ctx := context.WithValue(req.Context(), LogBypassedKey, true)
*req = *req.WithContext(ctx)
@ -158,7 +147,7 @@ func (svc *Service) RequestModifier(next proxy.RequestModifyFunc) proxy.RequestM
// Bypass logging if this setting is enabled and the incoming request
// doesn't match any scope rules.
if svc.BypassOutOfScopeRequests && !svc.scope.Match(clone, body) {
if svc.bypassOutOfScopeRequests && !svc.scope.Match(clone, body) {
ctx := context.WithValue(req.Context(), LogBypassedKey, true)
*req = *req.WithContext(ctx)
@ -167,7 +156,7 @@ func (svc *Service) RequestModifier(next proxy.RequestModifyFunc) proxy.RequestM
reqLog := RequestLog{
ID: ulid.MustNew(ulid.Timestamp(time.Now()), ulidEntropy),
ProjectID: svc.ActiveProjectID,
ProjectID: svc.activeProjectID,
Method: clone.Method,
URL: clone.URL,
Proto: clone.Proto,
@ -186,7 +175,7 @@ func (svc *Service) RequestModifier(next proxy.RequestModifyFunc) proxy.RequestM
}
}
func (svc *Service) ResponseModifier(next proxy.ResponseModifyFunc) proxy.ResponseModifyFunc {
func (svc *service) ResponseModifier(next proxy.ResponseModifyFunc) proxy.ResponseModifyFunc {
return func(res *http.Response) error {
if err := next(res); err != nil {
return err
@ -221,3 +210,58 @@ func (svc *Service) ResponseModifier(next proxy.ResponseModifyFunc) proxy.Respon
return nil
}
}
func (svc *service) SetActiveProjectID(id ulid.ULID) {
svc.activeProjectID = id
}
func (svc *service) ActiveProjectID() ulid.ULID {
return svc.activeProjectID
}
func (svc *service) SetFindReqsFilter(filter FindRequestsFilter) {
svc.findReqsFilter = filter
}
func (svc *service) FindReqsFilter() FindRequestsFilter {
return svc.findReqsFilter
}
func (svc *service) SetBypassOutOfScopeRequests(bypass bool) {
svc.bypassOutOfScopeRequests = bypass
}
func (svc *service) BypassOutOfScopeRequests() bool {
return svc.bypassOutOfScopeRequests
}
func ParseHTTPResponse(res *http.Response) (ResponseLog, error) {
if res.Header.Get("Content-Encoding") == "gzip" {
gzipReader, err := gzip.NewReader(res.Body)
if err != nil {
return ResponseLog{}, fmt.Errorf("reqlog: could not create gzip reader: %w", err)
}
defer gzipReader.Close()
buf := &bytes.Buffer{}
if _, err := io.Copy(buf, gzipReader); err != nil {
return ResponseLog{}, fmt.Errorf("reqlog: could not read gzipped response body: %w", err)
}
res.Body = io.NopCloser(buf)
}
body, err := io.ReadAll(res.Body)
if err != nil {
return ResponseLog{}, fmt.Errorf("reqlog: could not read body: %w", err)
}
return ResponseLog{
Proto: res.Proto,
StatusCode: res.StatusCode,
Status: res.Status,
Header: res.Header,
Body: body,
}, nil
}

View File

@ -34,7 +34,7 @@ func TestRequestModifier(t *testing.T) {
Repository: repoMock,
Scope: &scope.Scope{},
})
svc.ActiveProjectID = ulid.MustNew(ulid.Timestamp(time.Now()), ulidEntropy)
svc.SetActiveProjectID(ulid.MustNew(ulid.Timestamp(time.Now()), ulidEntropy))
next := func(req *http.Request) {
req.Body = io.NopCloser(strings.NewReader("modified body"))
@ -52,7 +52,7 @@ func TestRequestModifier(t *testing.T) {
exp := reqlog.RequestLog{
ID: ulid.ULID{}, // Empty value
ProjectID: svc.ActiveProjectID,
ProjectID: svc.ActiveProjectID(),
Method: req.Method,
URL: req.URL,
Proto: req.Proto,
@ -78,7 +78,7 @@ func TestResponseModifier(t *testing.T) {
svc := reqlog.NewService(reqlog.Config{
Repository: repoMock,
})
svc.ActiveProjectID = ulid.MustNew(ulid.Timestamp(time.Now()), ulidEntropy)
svc.SetActiveProjectID(ulid.MustNew(ulid.Timestamp(time.Now()), ulidEntropy))
next := func(res *http.Response) error {
res.Body = io.NopCloser(strings.NewReader("modified body"))

View File

@ -27,7 +27,7 @@ var reqLogSearchKeyFns = map[string]func(rl RequestLog) string{
"req.timestamp": func(rl RequestLog) string { return ulid.Time(rl.ID.Time()).String() },
}
var resLogSearchKeyFns = map[string]func(rl ResponseLog) string{
var ResLogSearchKeyFns = map[string]func(rl ResponseLog) string{
"res.proto": func(rl ResponseLog) string { return rl.Proto },
"res.statusCode": func(rl ResponseLog) string { return strconv.Itoa(rl.StatusCode) },
"res.statusReason": func(rl ResponseLog) string { return rl.Status },
@ -154,7 +154,7 @@ func (reqLog RequestLog) getMappedStringLiteral(s string) string {
return ""
}
fn, ok := resLogSearchKeyFns[s]
fn, ok := ResLogSearchKeyFns[s]
if ok {
return fn(*reqLog.Response)
}
@ -174,7 +174,7 @@ func (reqLog RequestLog) matchStringLiteral(strLiteral search.StringLiteral) (bo
}
if reqLog.Response != nil {
for _, fn := range resLogSearchKeyFns {
for _, fn := range ResLogSearchKeyFns {
if strings.Contains(
strings.ToLower(fn(*reqLog.Response)),
strings.ToLower(strLiteral.Value),

18
pkg/sender/repo.go Normal file
View File

@ -0,0 +1,18 @@
package sender
import (
"context"
"github.com/oklog/ulid"
"github.com/dstotijn/hetty/pkg/reqlog"
"github.com/dstotijn/hetty/pkg/scope"
)
type Repository interface {
FindSenderRequestByID(ctx context.Context, id ulid.ULID) (Request, error)
FindSenderRequests(ctx context.Context, filter FindRequestsFilter, scope *scope.Scope) ([]Request, error)
StoreSenderRequest(ctx context.Context, req Request) error
StoreResponseLog(ctx context.Context, reqLogID ulid.ULID, resLog reqlog.ResponseLog) error
DeleteSenderRequests(ctx context.Context, projectID ulid.ULID) error
}

View File

@ -0,0 +1,292 @@
// Code generated by moq; DO NOT EDIT.
// github.com/matryer/moq
package sender_test
import (
"context"
"github.com/dstotijn/hetty/pkg/reqlog"
"github.com/dstotijn/hetty/pkg/scope"
"github.com/dstotijn/hetty/pkg/sender"
"github.com/oklog/ulid"
"sync"
)
// Ensure, that RepoMock does implement sender.Repository.
// If this is not the case, regenerate this file with moq.
var _ sender.Repository = &RepoMock{}
// RepoMock is a mock implementation of sender.Repository.
//
// func TestSomethingThatUsesRepository(t *testing.T) {
//
// // make and configure a mocked sender.Repository
// mockedRepository := &RepoMock{
// DeleteSenderRequestsFunc: func(ctx context.Context, projectID ulid.ULID) error {
// panic("mock out the DeleteSenderRequests method")
// },
// FindSenderRequestByIDFunc: func(ctx context.Context, id ulid.ULID) (sender.Request, error) {
// panic("mock out the FindSenderRequestByID method")
// },
// FindSenderRequestsFunc: func(ctx context.Context, filter sender.FindRequestsFilter, scopeMoqParam *scope.Scope) ([]sender.Request, error) {
// panic("mock out the FindSenderRequests method")
// },
// StoreResponseLogFunc: func(ctx context.Context, reqLogID ulid.ULID, resLog reqlog.ResponseLog) error {
// panic("mock out the StoreResponseLog method")
// },
// StoreSenderRequestFunc: func(ctx context.Context, req sender.Request) error {
// panic("mock out the StoreSenderRequest method")
// },
// }
//
// // use mockedRepository in code that requires sender.Repository
// // and then make assertions.
//
// }
type RepoMock struct {
// DeleteSenderRequestsFunc mocks the DeleteSenderRequests method.
DeleteSenderRequestsFunc func(ctx context.Context, projectID ulid.ULID) error
// FindSenderRequestByIDFunc mocks the FindSenderRequestByID method.
FindSenderRequestByIDFunc func(ctx context.Context, id ulid.ULID) (sender.Request, error)
// FindSenderRequestsFunc mocks the FindSenderRequests method.
FindSenderRequestsFunc func(ctx context.Context, filter sender.FindRequestsFilter, scopeMoqParam *scope.Scope) ([]sender.Request, error)
// StoreResponseLogFunc mocks the StoreResponseLog method.
StoreResponseLogFunc func(ctx context.Context, reqLogID ulid.ULID, resLog reqlog.ResponseLog) error
// StoreSenderRequestFunc mocks the StoreSenderRequest method.
StoreSenderRequestFunc func(ctx context.Context, req sender.Request) error
// calls tracks calls to the methods.
calls struct {
// DeleteSenderRequests holds details about calls to the DeleteSenderRequests method.
DeleteSenderRequests []struct {
// Ctx is the ctx argument value.
Ctx context.Context
// ProjectID is the projectID argument value.
ProjectID ulid.ULID
}
// FindSenderRequestByID holds details about calls to the FindSenderRequestByID method.
FindSenderRequestByID []struct {
// Ctx is the ctx argument value.
Ctx context.Context
// ID is the id argument value.
ID ulid.ULID
}
// FindSenderRequests holds details about calls to the FindSenderRequests method.
FindSenderRequests []struct {
// Ctx is the ctx argument value.
Ctx context.Context
// Filter is the filter argument value.
Filter sender.FindRequestsFilter
// ScopeMoqParam is the scopeMoqParam argument value.
ScopeMoqParam *scope.Scope
}
// StoreResponseLog holds details about calls to the StoreResponseLog method.
StoreResponseLog []struct {
// Ctx is the ctx argument value.
Ctx context.Context
// ReqLogID is the reqLogID argument value.
ReqLogID ulid.ULID
// ResLog is the resLog argument value.
ResLog reqlog.ResponseLog
}
// StoreSenderRequest holds details about calls to the StoreSenderRequest method.
StoreSenderRequest []struct {
// Ctx is the ctx argument value.
Ctx context.Context
// Req is the req argument value.
Req sender.Request
}
}
lockDeleteSenderRequests sync.RWMutex
lockFindSenderRequestByID sync.RWMutex
lockFindSenderRequests sync.RWMutex
lockStoreResponseLog sync.RWMutex
lockStoreSenderRequest sync.RWMutex
}
// DeleteSenderRequests calls DeleteSenderRequestsFunc.
func (mock *RepoMock) DeleteSenderRequests(ctx context.Context, projectID ulid.ULID) error {
if mock.DeleteSenderRequestsFunc == nil {
panic("RepoMock.DeleteSenderRequestsFunc: method is nil but Repository.DeleteSenderRequests was just called")
}
callInfo := struct {
Ctx context.Context
ProjectID ulid.ULID
}{
Ctx: ctx,
ProjectID: projectID,
}
mock.lockDeleteSenderRequests.Lock()
mock.calls.DeleteSenderRequests = append(mock.calls.DeleteSenderRequests, callInfo)
mock.lockDeleteSenderRequests.Unlock()
return mock.DeleteSenderRequestsFunc(ctx, projectID)
}
// DeleteSenderRequestsCalls gets all the calls that were made to DeleteSenderRequests.
// Check the length with:
// len(mockedRepository.DeleteSenderRequestsCalls())
func (mock *RepoMock) DeleteSenderRequestsCalls() []struct {
Ctx context.Context
ProjectID ulid.ULID
} {
var calls []struct {
Ctx context.Context
ProjectID ulid.ULID
}
mock.lockDeleteSenderRequests.RLock()
calls = mock.calls.DeleteSenderRequests
mock.lockDeleteSenderRequests.RUnlock()
return calls
}
// FindSenderRequestByID calls FindSenderRequestByIDFunc.
func (mock *RepoMock) FindSenderRequestByID(ctx context.Context, id ulid.ULID) (sender.Request, error) {
if mock.FindSenderRequestByIDFunc == nil {
panic("RepoMock.FindSenderRequestByIDFunc: method is nil but Repository.FindSenderRequestByID was just called")
}
callInfo := struct {
Ctx context.Context
ID ulid.ULID
}{
Ctx: ctx,
ID: id,
}
mock.lockFindSenderRequestByID.Lock()
mock.calls.FindSenderRequestByID = append(mock.calls.FindSenderRequestByID, callInfo)
mock.lockFindSenderRequestByID.Unlock()
return mock.FindSenderRequestByIDFunc(ctx, id)
}
// FindSenderRequestByIDCalls gets all the calls that were made to FindSenderRequestByID.
// Check the length with:
// len(mockedRepository.FindSenderRequestByIDCalls())
func (mock *RepoMock) FindSenderRequestByIDCalls() []struct {
Ctx context.Context
ID ulid.ULID
} {
var calls []struct {
Ctx context.Context
ID ulid.ULID
}
mock.lockFindSenderRequestByID.RLock()
calls = mock.calls.FindSenderRequestByID
mock.lockFindSenderRequestByID.RUnlock()
return calls
}
// FindSenderRequests calls FindSenderRequestsFunc.
func (mock *RepoMock) FindSenderRequests(ctx context.Context, filter sender.FindRequestsFilter, scopeMoqParam *scope.Scope) ([]sender.Request, error) {
if mock.FindSenderRequestsFunc == nil {
panic("RepoMock.FindSenderRequestsFunc: method is nil but Repository.FindSenderRequests was just called")
}
callInfo := struct {
Ctx context.Context
Filter sender.FindRequestsFilter
ScopeMoqParam *scope.Scope
}{
Ctx: ctx,
Filter: filter,
ScopeMoqParam: scopeMoqParam,
}
mock.lockFindSenderRequests.Lock()
mock.calls.FindSenderRequests = append(mock.calls.FindSenderRequests, callInfo)
mock.lockFindSenderRequests.Unlock()
return mock.FindSenderRequestsFunc(ctx, filter, scopeMoqParam)
}
// FindSenderRequestsCalls gets all the calls that were made to FindSenderRequests.
// Check the length with:
// len(mockedRepository.FindSenderRequestsCalls())
func (mock *RepoMock) FindSenderRequestsCalls() []struct {
Ctx context.Context
Filter sender.FindRequestsFilter
ScopeMoqParam *scope.Scope
} {
var calls []struct {
Ctx context.Context
Filter sender.FindRequestsFilter
ScopeMoqParam *scope.Scope
}
mock.lockFindSenderRequests.RLock()
calls = mock.calls.FindSenderRequests
mock.lockFindSenderRequests.RUnlock()
return calls
}
// StoreResponseLog calls StoreResponseLogFunc.
func (mock *RepoMock) StoreResponseLog(ctx context.Context, reqLogID ulid.ULID, resLog reqlog.ResponseLog) error {
if mock.StoreResponseLogFunc == nil {
panic("RepoMock.StoreResponseLogFunc: method is nil but Repository.StoreResponseLog was just called")
}
callInfo := struct {
Ctx context.Context
ReqLogID ulid.ULID
ResLog reqlog.ResponseLog
}{
Ctx: ctx,
ReqLogID: reqLogID,
ResLog: resLog,
}
mock.lockStoreResponseLog.Lock()
mock.calls.StoreResponseLog = append(mock.calls.StoreResponseLog, callInfo)
mock.lockStoreResponseLog.Unlock()
return mock.StoreResponseLogFunc(ctx, reqLogID, resLog)
}
// StoreResponseLogCalls gets all the calls that were made to StoreResponseLog.
// Check the length with:
// len(mockedRepository.StoreResponseLogCalls())
func (mock *RepoMock) StoreResponseLogCalls() []struct {
Ctx context.Context
ReqLogID ulid.ULID
ResLog reqlog.ResponseLog
} {
var calls []struct {
Ctx context.Context
ReqLogID ulid.ULID
ResLog reqlog.ResponseLog
}
mock.lockStoreResponseLog.RLock()
calls = mock.calls.StoreResponseLog
mock.lockStoreResponseLog.RUnlock()
return calls
}
// StoreSenderRequest calls StoreSenderRequestFunc.
func (mock *RepoMock) StoreSenderRequest(ctx context.Context, req sender.Request) error {
if mock.StoreSenderRequestFunc == nil {
panic("RepoMock.StoreSenderRequestFunc: method is nil but Repository.StoreSenderRequest was just called")
}
callInfo := struct {
Ctx context.Context
Req sender.Request
}{
Ctx: ctx,
Req: req,
}
mock.lockStoreSenderRequest.Lock()
mock.calls.StoreSenderRequest = append(mock.calls.StoreSenderRequest, callInfo)
mock.lockStoreSenderRequest.Unlock()
return mock.StoreSenderRequestFunc(ctx, req)
}
// StoreSenderRequestCalls gets all the calls that were made to StoreSenderRequest.
// Check the length with:
// len(mockedRepository.StoreSenderRequestCalls())
func (mock *RepoMock) StoreSenderRequestCalls() []struct {
Ctx context.Context
Req sender.Request
} {
var calls []struct {
Ctx context.Context
Req sender.Request
}
mock.lockStoreSenderRequest.RLock()
calls = mock.calls.StoreSenderRequest
mock.lockStoreSenderRequest.RUnlock()
return calls
}

View File

@ -0,0 +1,498 @@
// Code generated by moq; DO NOT EDIT.
// github.com/matryer/moq
package sender_test
import (
"context"
"github.com/dstotijn/hetty/pkg/proxy"
"github.com/dstotijn/hetty/pkg/reqlog"
"github.com/oklog/ulid"
"sync"
)
// Ensure, that ReqLogServiceMock does implement reqlog.Service.
// If this is not the case, regenerate this file with moq.
var _ reqlog.Service = &ReqLogServiceMock{}
// ReqLogServiceMock is a mock implementation of reqlog.Service.
//
// func TestSomethingThatUsesService(t *testing.T) {
//
// // make and configure a mocked reqlog.Service
// mockedService := &ReqLogServiceMock{
// ActiveProjectIDFunc: func() ulid.ULID {
// panic("mock out the ActiveProjectID method")
// },
// BypassOutOfScopeRequestsFunc: func() bool {
// panic("mock out the BypassOutOfScopeRequests method")
// },
// ClearRequestsFunc: func(ctx context.Context, projectID ulid.ULID) error {
// panic("mock out the ClearRequests method")
// },
// FindReqsFilterFunc: func() reqlog.FindRequestsFilter {
// panic("mock out the FindReqsFilter method")
// },
// FindRequestLogByIDFunc: func(ctx context.Context, id ulid.ULID) (reqlog.RequestLog, error) {
// panic("mock out the FindRequestLogByID method")
// },
// FindRequestsFunc: func(ctx context.Context) ([]reqlog.RequestLog, error) {
// panic("mock out the FindRequests method")
// },
// RequestModifierFunc: func(next proxy.RequestModifyFunc) proxy.RequestModifyFunc {
// panic("mock out the RequestModifier method")
// },
// ResponseModifierFunc: func(next proxy.ResponseModifyFunc) proxy.ResponseModifyFunc {
// panic("mock out the ResponseModifier method")
// },
// SetActiveProjectIDFunc: func(id ulid.ULID) {
// panic("mock out the SetActiveProjectID method")
// },
// SetBypassOutOfScopeRequestsFunc: func(b bool) {
// panic("mock out the SetBypassOutOfScopeRequests method")
// },
// SetFindReqsFilterFunc: func(filter reqlog.FindRequestsFilter) {
// panic("mock out the SetFindReqsFilter method")
// },
// }
//
// // use mockedService in code that requires reqlog.Service
// // and then make assertions.
//
// }
type ReqLogServiceMock struct {
// ActiveProjectIDFunc mocks the ActiveProjectID method.
ActiveProjectIDFunc func() ulid.ULID
// BypassOutOfScopeRequestsFunc mocks the BypassOutOfScopeRequests method.
BypassOutOfScopeRequestsFunc func() bool
// ClearRequestsFunc mocks the ClearRequests method.
ClearRequestsFunc func(ctx context.Context, projectID ulid.ULID) error
// FindReqsFilterFunc mocks the FindReqsFilter method.
FindReqsFilterFunc func() reqlog.FindRequestsFilter
// FindRequestLogByIDFunc mocks the FindRequestLogByID method.
FindRequestLogByIDFunc func(ctx context.Context, id ulid.ULID) (reqlog.RequestLog, error)
// FindRequestsFunc mocks the FindRequests method.
FindRequestsFunc func(ctx context.Context) ([]reqlog.RequestLog, error)
// RequestModifierFunc mocks the RequestModifier method.
RequestModifierFunc func(next proxy.RequestModifyFunc) proxy.RequestModifyFunc
// ResponseModifierFunc mocks the ResponseModifier method.
ResponseModifierFunc func(next proxy.ResponseModifyFunc) proxy.ResponseModifyFunc
// SetActiveProjectIDFunc mocks the SetActiveProjectID method.
SetActiveProjectIDFunc func(id ulid.ULID)
// SetBypassOutOfScopeRequestsFunc mocks the SetBypassOutOfScopeRequests method.
SetBypassOutOfScopeRequestsFunc func(b bool)
// SetFindReqsFilterFunc mocks the SetFindReqsFilter method.
SetFindReqsFilterFunc func(filter reqlog.FindRequestsFilter)
// calls tracks calls to the methods.
calls struct {
// ActiveProjectID holds details about calls to the ActiveProjectID method.
ActiveProjectID []struct {
}
// BypassOutOfScopeRequests holds details about calls to the BypassOutOfScopeRequests method.
BypassOutOfScopeRequests []struct {
}
// ClearRequests holds details about calls to the ClearRequests method.
ClearRequests []struct {
// Ctx is the ctx argument value.
Ctx context.Context
// ProjectID is the projectID argument value.
ProjectID ulid.ULID
}
// FindReqsFilter holds details about calls to the FindReqsFilter method.
FindReqsFilter []struct {
}
// FindRequestLogByID holds details about calls to the FindRequestLogByID method.
FindRequestLogByID []struct {
// Ctx is the ctx argument value.
Ctx context.Context
// ID is the id argument value.
ID ulid.ULID
}
// FindRequests holds details about calls to the FindRequests method.
FindRequests []struct {
// Ctx is the ctx argument value.
Ctx context.Context
}
// RequestModifier holds details about calls to the RequestModifier method.
RequestModifier []struct {
// Next is the next argument value.
Next proxy.RequestModifyFunc
}
// ResponseModifier holds details about calls to the ResponseModifier method.
ResponseModifier []struct {
// Next is the next argument value.
Next proxy.ResponseModifyFunc
}
// SetActiveProjectID holds details about calls to the SetActiveProjectID method.
SetActiveProjectID []struct {
// ID is the id argument value.
ID ulid.ULID
}
// SetBypassOutOfScopeRequests holds details about calls to the SetBypassOutOfScopeRequests method.
SetBypassOutOfScopeRequests []struct {
// B is the b argument value.
B bool
}
// SetFindReqsFilter holds details about calls to the SetFindReqsFilter method.
SetFindReqsFilter []struct {
// Filter is the filter argument value.
Filter reqlog.FindRequestsFilter
}
}
lockActiveProjectID sync.RWMutex
lockBypassOutOfScopeRequests sync.RWMutex
lockClearRequests sync.RWMutex
lockFindReqsFilter sync.RWMutex
lockFindRequestLogByID sync.RWMutex
lockFindRequests sync.RWMutex
lockRequestModifier sync.RWMutex
lockResponseModifier sync.RWMutex
lockSetActiveProjectID sync.RWMutex
lockSetBypassOutOfScopeRequests sync.RWMutex
lockSetFindReqsFilter sync.RWMutex
}
// ActiveProjectID calls ActiveProjectIDFunc.
func (mock *ReqLogServiceMock) ActiveProjectID() ulid.ULID {
if mock.ActiveProjectIDFunc == nil {
panic("ReqLogServiceMock.ActiveProjectIDFunc: method is nil but Service.ActiveProjectID was just called")
}
callInfo := struct {
}{}
mock.lockActiveProjectID.Lock()
mock.calls.ActiveProjectID = append(mock.calls.ActiveProjectID, callInfo)
mock.lockActiveProjectID.Unlock()
return mock.ActiveProjectIDFunc()
}
// ActiveProjectIDCalls gets all the calls that were made to ActiveProjectID.
// Check the length with:
// len(mockedService.ActiveProjectIDCalls())
func (mock *ReqLogServiceMock) ActiveProjectIDCalls() []struct {
} {
var calls []struct {
}
mock.lockActiveProjectID.RLock()
calls = mock.calls.ActiveProjectID
mock.lockActiveProjectID.RUnlock()
return calls
}
// BypassOutOfScopeRequests calls BypassOutOfScopeRequestsFunc.
func (mock *ReqLogServiceMock) BypassOutOfScopeRequests() bool {
if mock.BypassOutOfScopeRequestsFunc == nil {
panic("ReqLogServiceMock.BypassOutOfScopeRequestsFunc: method is nil but Service.BypassOutOfScopeRequests was just called")
}
callInfo := struct {
}{}
mock.lockBypassOutOfScopeRequests.Lock()
mock.calls.BypassOutOfScopeRequests = append(mock.calls.BypassOutOfScopeRequests, callInfo)
mock.lockBypassOutOfScopeRequests.Unlock()
return mock.BypassOutOfScopeRequestsFunc()
}
// BypassOutOfScopeRequestsCalls gets all the calls that were made to BypassOutOfScopeRequests.
// Check the length with:
// len(mockedService.BypassOutOfScopeRequestsCalls())
func (mock *ReqLogServiceMock) BypassOutOfScopeRequestsCalls() []struct {
} {
var calls []struct {
}
mock.lockBypassOutOfScopeRequests.RLock()
calls = mock.calls.BypassOutOfScopeRequests
mock.lockBypassOutOfScopeRequests.RUnlock()
return calls
}
// ClearRequests calls ClearRequestsFunc.
func (mock *ReqLogServiceMock) ClearRequests(ctx context.Context, projectID ulid.ULID) error {
if mock.ClearRequestsFunc == nil {
panic("ReqLogServiceMock.ClearRequestsFunc: method is nil but Service.ClearRequests was just called")
}
callInfo := struct {
Ctx context.Context
ProjectID ulid.ULID
}{
Ctx: ctx,
ProjectID: projectID,
}
mock.lockClearRequests.Lock()
mock.calls.ClearRequests = append(mock.calls.ClearRequests, callInfo)
mock.lockClearRequests.Unlock()
return mock.ClearRequestsFunc(ctx, projectID)
}
// ClearRequestsCalls gets all the calls that were made to ClearRequests.
// Check the length with:
// len(mockedService.ClearRequestsCalls())
func (mock *ReqLogServiceMock) ClearRequestsCalls() []struct {
Ctx context.Context
ProjectID ulid.ULID
} {
var calls []struct {
Ctx context.Context
ProjectID ulid.ULID
}
mock.lockClearRequests.RLock()
calls = mock.calls.ClearRequests
mock.lockClearRequests.RUnlock()
return calls
}
// FindReqsFilter calls FindReqsFilterFunc.
func (mock *ReqLogServiceMock) FindReqsFilter() reqlog.FindRequestsFilter {
if mock.FindReqsFilterFunc == nil {
panic("ReqLogServiceMock.FindReqsFilterFunc: method is nil but Service.FindReqsFilter was just called")
}
callInfo := struct {
}{}
mock.lockFindReqsFilter.Lock()
mock.calls.FindReqsFilter = append(mock.calls.FindReqsFilter, callInfo)
mock.lockFindReqsFilter.Unlock()
return mock.FindReqsFilterFunc()
}
// FindReqsFilterCalls gets all the calls that were made to FindReqsFilter.
// Check the length with:
// len(mockedService.FindReqsFilterCalls())
func (mock *ReqLogServiceMock) FindReqsFilterCalls() []struct {
} {
var calls []struct {
}
mock.lockFindReqsFilter.RLock()
calls = mock.calls.FindReqsFilter
mock.lockFindReqsFilter.RUnlock()
return calls
}
// FindRequestLogByID calls FindRequestLogByIDFunc.
func (mock *ReqLogServiceMock) FindRequestLogByID(ctx context.Context, id ulid.ULID) (reqlog.RequestLog, error) {
if mock.FindRequestLogByIDFunc == nil {
panic("ReqLogServiceMock.FindRequestLogByIDFunc: method is nil but Service.FindRequestLogByID was just called")
}
callInfo := struct {
Ctx context.Context
ID ulid.ULID
}{
Ctx: ctx,
ID: id,
}
mock.lockFindRequestLogByID.Lock()
mock.calls.FindRequestLogByID = append(mock.calls.FindRequestLogByID, callInfo)
mock.lockFindRequestLogByID.Unlock()
return mock.FindRequestLogByIDFunc(ctx, id)
}
// FindRequestLogByIDCalls gets all the calls that were made to FindRequestLogByID.
// Check the length with:
// len(mockedService.FindRequestLogByIDCalls())
func (mock *ReqLogServiceMock) FindRequestLogByIDCalls() []struct {
Ctx context.Context
ID ulid.ULID
} {
var calls []struct {
Ctx context.Context
ID ulid.ULID
}
mock.lockFindRequestLogByID.RLock()
calls = mock.calls.FindRequestLogByID
mock.lockFindRequestLogByID.RUnlock()
return calls
}
// FindRequests calls FindRequestsFunc.
func (mock *ReqLogServiceMock) FindRequests(ctx context.Context) ([]reqlog.RequestLog, error) {
if mock.FindRequestsFunc == nil {
panic("ReqLogServiceMock.FindRequestsFunc: method is nil but Service.FindRequests was just called")
}
callInfo := struct {
Ctx context.Context
}{
Ctx: ctx,
}
mock.lockFindRequests.Lock()
mock.calls.FindRequests = append(mock.calls.FindRequests, callInfo)
mock.lockFindRequests.Unlock()
return mock.FindRequestsFunc(ctx)
}
// FindRequestsCalls gets all the calls that were made to FindRequests.
// Check the length with:
// len(mockedService.FindRequestsCalls())
func (mock *ReqLogServiceMock) FindRequestsCalls() []struct {
Ctx context.Context
} {
var calls []struct {
Ctx context.Context
}
mock.lockFindRequests.RLock()
calls = mock.calls.FindRequests
mock.lockFindRequests.RUnlock()
return calls
}
// RequestModifier calls RequestModifierFunc.
func (mock *ReqLogServiceMock) RequestModifier(next proxy.RequestModifyFunc) proxy.RequestModifyFunc {
if mock.RequestModifierFunc == nil {
panic("ReqLogServiceMock.RequestModifierFunc: method is nil but Service.RequestModifier was just called")
}
callInfo := struct {
Next proxy.RequestModifyFunc
}{
Next: next,
}
mock.lockRequestModifier.Lock()
mock.calls.RequestModifier = append(mock.calls.RequestModifier, callInfo)
mock.lockRequestModifier.Unlock()
return mock.RequestModifierFunc(next)
}
// RequestModifierCalls gets all the calls that were made to RequestModifier.
// Check the length with:
// len(mockedService.RequestModifierCalls())
func (mock *ReqLogServiceMock) RequestModifierCalls() []struct {
Next proxy.RequestModifyFunc
} {
var calls []struct {
Next proxy.RequestModifyFunc
}
mock.lockRequestModifier.RLock()
calls = mock.calls.RequestModifier
mock.lockRequestModifier.RUnlock()
return calls
}
// ResponseModifier calls ResponseModifierFunc.
func (mock *ReqLogServiceMock) ResponseModifier(next proxy.ResponseModifyFunc) proxy.ResponseModifyFunc {
if mock.ResponseModifierFunc == nil {
panic("ReqLogServiceMock.ResponseModifierFunc: method is nil but Service.ResponseModifier was just called")
}
callInfo := struct {
Next proxy.ResponseModifyFunc
}{
Next: next,
}
mock.lockResponseModifier.Lock()
mock.calls.ResponseModifier = append(mock.calls.ResponseModifier, callInfo)
mock.lockResponseModifier.Unlock()
return mock.ResponseModifierFunc(next)
}
// ResponseModifierCalls gets all the calls that were made to ResponseModifier.
// Check the length with:
// len(mockedService.ResponseModifierCalls())
func (mock *ReqLogServiceMock) ResponseModifierCalls() []struct {
Next proxy.ResponseModifyFunc
} {
var calls []struct {
Next proxy.ResponseModifyFunc
}
mock.lockResponseModifier.RLock()
calls = mock.calls.ResponseModifier
mock.lockResponseModifier.RUnlock()
return calls
}
// SetActiveProjectID calls SetActiveProjectIDFunc.
func (mock *ReqLogServiceMock) SetActiveProjectID(id ulid.ULID) {
if mock.SetActiveProjectIDFunc == nil {
panic("ReqLogServiceMock.SetActiveProjectIDFunc: method is nil but Service.SetActiveProjectID was just called")
}
callInfo := struct {
ID ulid.ULID
}{
ID: id,
}
mock.lockSetActiveProjectID.Lock()
mock.calls.SetActiveProjectID = append(mock.calls.SetActiveProjectID, callInfo)
mock.lockSetActiveProjectID.Unlock()
mock.SetActiveProjectIDFunc(id)
}
// SetActiveProjectIDCalls gets all the calls that were made to SetActiveProjectID.
// Check the length with:
// len(mockedService.SetActiveProjectIDCalls())
func (mock *ReqLogServiceMock) SetActiveProjectIDCalls() []struct {
ID ulid.ULID
} {
var calls []struct {
ID ulid.ULID
}
mock.lockSetActiveProjectID.RLock()
calls = mock.calls.SetActiveProjectID
mock.lockSetActiveProjectID.RUnlock()
return calls
}
// SetBypassOutOfScopeRequests calls SetBypassOutOfScopeRequestsFunc.
func (mock *ReqLogServiceMock) SetBypassOutOfScopeRequests(b bool) {
if mock.SetBypassOutOfScopeRequestsFunc == nil {
panic("ReqLogServiceMock.SetBypassOutOfScopeRequestsFunc: method is nil but Service.SetBypassOutOfScopeRequests was just called")
}
callInfo := struct {
B bool
}{
B: b,
}
mock.lockSetBypassOutOfScopeRequests.Lock()
mock.calls.SetBypassOutOfScopeRequests = append(mock.calls.SetBypassOutOfScopeRequests, callInfo)
mock.lockSetBypassOutOfScopeRequests.Unlock()
mock.SetBypassOutOfScopeRequestsFunc(b)
}
// SetBypassOutOfScopeRequestsCalls gets all the calls that were made to SetBypassOutOfScopeRequests.
// Check the length with:
// len(mockedService.SetBypassOutOfScopeRequestsCalls())
func (mock *ReqLogServiceMock) SetBypassOutOfScopeRequestsCalls() []struct {
B bool
} {
var calls []struct {
B bool
}
mock.lockSetBypassOutOfScopeRequests.RLock()
calls = mock.calls.SetBypassOutOfScopeRequests
mock.lockSetBypassOutOfScopeRequests.RUnlock()
return calls
}
// SetFindReqsFilter calls SetFindReqsFilterFunc.
func (mock *ReqLogServiceMock) SetFindReqsFilter(filter reqlog.FindRequestsFilter) {
if mock.SetFindReqsFilterFunc == nil {
panic("ReqLogServiceMock.SetFindReqsFilterFunc: method is nil but Service.SetFindReqsFilter was just called")
}
callInfo := struct {
Filter reqlog.FindRequestsFilter
}{
Filter: filter,
}
mock.lockSetFindReqsFilter.Lock()
mock.calls.SetFindReqsFilter = append(mock.calls.SetFindReqsFilter, callInfo)
mock.lockSetFindReqsFilter.Unlock()
mock.SetFindReqsFilterFunc(filter)
}
// SetFindReqsFilterCalls gets all the calls that were made to SetFindReqsFilter.
// Check the length with:
// len(mockedService.SetFindReqsFilterCalls())
func (mock *ReqLogServiceMock) SetFindReqsFilterCalls() []struct {
Filter reqlog.FindRequestsFilter
} {
var calls []struct {
Filter reqlog.FindRequestsFilter
}
mock.lockSetFindReqsFilter.RLock()
calls = mock.calls.SetFindReqsFilter
mock.lockSetFindReqsFilter.RUnlock()
return calls
}

228
pkg/sender/search.go Normal file
View File

@ -0,0 +1,228 @@
package sender
import (
"errors"
"fmt"
"regexp"
"strings"
"github.com/oklog/ulid"
"github.com/dstotijn/hetty/pkg/reqlog"
"github.com/dstotijn/hetty/pkg/scope"
"github.com/dstotijn/hetty/pkg/search"
)
var senderReqSearchKeyFns = map[string]func(req Request) string{
"req.id": func(req Request) string { return req.ID.String() },
"req.proto": func(req Request) string { return req.Proto },
"req.url": func(req Request) string {
if req.URL == nil {
return ""
}
return req.URL.String()
},
"req.method": func(req Request) string { return req.Method },
"req.body": func(req Request) string { return string(req.Body) },
"req.timestamp": func(req Request) string { return ulid.Time(req.ID.Time()).String() },
}
// TODO: Request and response headers search key functions.
// Matches returns true if the supplied search expression evaluates to true.
func (req Request) Matches(expr search.Expression) (bool, error) {
switch e := expr.(type) {
case search.PrefixExpression:
return req.matchPrefixExpr(e)
case search.InfixExpression:
return req.matchInfixExpr(e)
case search.StringLiteral:
return req.matchStringLiteral(e)
default:
return false, fmt.Errorf("expression type (%T) not supported", expr)
}
}
func (req Request) matchPrefixExpr(expr search.PrefixExpression) (bool, error) {
switch expr.Operator {
case search.TokOpNot:
match, err := req.Matches(expr.Right)
if err != nil {
return false, err
}
return !match, nil
default:
return false, errors.New("operator is not supported")
}
}
func (req Request) matchInfixExpr(expr search.InfixExpression) (bool, error) {
switch expr.Operator {
case search.TokOpAnd:
left, err := req.Matches(expr.Left)
if err != nil {
return false, err
}
right, err := req.Matches(expr.Right)
if err != nil {
return false, err
}
return left && right, nil
case search.TokOpOr:
left, err := req.Matches(expr.Left)
if err != nil {
return false, err
}
right, err := req.Matches(expr.Right)
if err != nil {
return false, err
}
return left || right, nil
}
left, ok := expr.Left.(search.StringLiteral)
if !ok {
return false, errors.New("left operand must be a string literal")
}
leftVal := req.getMappedStringLiteral(left.Value)
if expr.Operator == search.TokOpRe || expr.Operator == search.TokOpNotRe {
right, ok := expr.Right.(*regexp.Regexp)
if !ok {
return false, errors.New("right operand must be a regular expression")
}
switch expr.Operator {
case search.TokOpRe:
return right.MatchString(leftVal), nil
case search.TokOpNotRe:
return !right.MatchString(leftVal), nil
}
}
right, ok := expr.Right.(search.StringLiteral)
if !ok {
return false, errors.New("right operand must be a string literal")
}
rightVal := req.getMappedStringLiteral(right.Value)
switch expr.Operator {
case search.TokOpEq:
return leftVal == rightVal, nil
case search.TokOpNotEq:
return leftVal != rightVal, nil
case search.TokOpGt:
// TODO(?) attempt to parse as int.
return leftVal > rightVal, nil
case search.TokOpLt:
// TODO(?) attempt to parse as int.
return leftVal < rightVal, nil
case search.TokOpGtEq:
// TODO(?) attempt to parse as int.
return leftVal >= rightVal, nil
case search.TokOpLtEq:
// TODO(?) attempt to parse as int.
return leftVal <= rightVal, nil
default:
return false, errors.New("unsupported operator")
}
}
func (req Request) getMappedStringLiteral(s string) string {
switch {
case strings.HasPrefix(s, "req."):
fn, ok := senderReqSearchKeyFns[s]
if ok {
return fn(req)
}
case strings.HasPrefix(s, "res."):
if req.Response == nil {
return ""
}
fn, ok := reqlog.ResLogSearchKeyFns[s]
if ok {
return fn(*req.Response)
}
}
return s
}
func (req Request) matchStringLiteral(strLiteral search.StringLiteral) (bool, error) {
for _, fn := range senderReqSearchKeyFns {
if strings.Contains(
strings.ToLower(fn(req)),
strings.ToLower(strLiteral.Value),
) {
return true, nil
}
}
if req.Response != nil {
for _, fn := range reqlog.ResLogSearchKeyFns {
if strings.Contains(
strings.ToLower(fn(*req.Response)),
strings.ToLower(strLiteral.Value),
) {
return true, nil
}
}
}
return false, nil
}
func (req Request) MatchScope(s *scope.Scope) bool {
for _, rule := range s.Rules() {
if rule.URL != nil && req.URL != nil {
if matches := rule.URL.MatchString(req.URL.String()); matches {
return true
}
}
for key, values := range req.Header {
var keyMatches, valueMatches bool
if rule.Header.Key != nil {
if matches := rule.Header.Key.MatchString(key); matches {
keyMatches = true
}
}
if rule.Header.Value != nil {
for _, value := range values {
if matches := rule.Header.Value.MatchString(value); matches {
valueMatches = true
break
}
}
}
// When only key or value is set, match on whatever is set.
// When both are set, both must match.
switch {
case rule.Header.Key != nil && rule.Header.Value == nil && keyMatches:
return true
case rule.Header.Key == nil && rule.Header.Value != nil && valueMatches:
return true
case rule.Header.Key != nil && rule.Header.Value != nil && keyMatches && valueMatches:
return true
}
}
if rule.Body != nil {
if matches := rule.Body.Match(req.Body); matches {
return true
}
}
}
return false
}

204
pkg/sender/search_test.go Normal file
View File

@ -0,0 +1,204 @@
package sender_test
import (
"testing"
"github.com/dstotijn/hetty/pkg/reqlog"
"github.com/dstotijn/hetty/pkg/search"
"github.com/dstotijn/hetty/pkg/sender"
)
func TestRequestLogMatch(t *testing.T) {
t.Parallel()
tests := []struct {
name string
query string
senderReq sender.Request
expectedMatch bool
expectedError error
}{
{
name: "infix expression, equal operator, match",
query: "req.body = foo",
senderReq: sender.Request{
Body: []byte("foo"),
},
expectedMatch: true,
expectedError: nil,
},
{
name: "infix expression, not equal operator, match",
query: "req.body != bar",
senderReq: sender.Request{
Body: []byte("foo"),
},
expectedMatch: true,
expectedError: nil,
},
{
name: "infix expression, greater than operator, match",
query: "req.body > a",
senderReq: sender.Request{
Body: []byte("b"),
},
expectedMatch: true,
expectedError: nil,
},
{
name: "infix expression, less than operator, match",
query: "req.body < b",
senderReq: sender.Request{
Body: []byte("a"),
},
expectedMatch: true,
expectedError: nil,
},
{
name: "infix expression, greater than or equal operator, match greater than",
query: "req.body >= a",
senderReq: sender.Request{
Body: []byte("b"),
},
expectedMatch: true,
expectedError: nil,
},
{
name: "infix expression, greater than or equal operator, match equal",
query: "req.body >= a",
senderReq: sender.Request{
Body: []byte("a"),
},
expectedMatch: true,
expectedError: nil,
},
{
name: "infix expression, less than or equal operator, match less than",
query: "req.body <= b",
senderReq: sender.Request{
Body: []byte("a"),
},
expectedMatch: true,
expectedError: nil,
},
{
name: "infix expression, less than or equal operator, match equal",
query: "req.body <= b",
senderReq: sender.Request{
Body: []byte("b"),
},
expectedMatch: true,
expectedError: nil,
},
{
name: "infix expression, regular expression operator, match",
query: `req.body =~ "^foo(.*)$"`,
senderReq: sender.Request{
Body: []byte("foobar"),
},
expectedMatch: true,
expectedError: nil,
},
{
name: "infix expression, negate regular expression operator, match",
query: `req.body !~ "^foo(.*)$"`,
senderReq: sender.Request{
Body: []byte("xoobar"),
},
expectedMatch: true,
expectedError: nil,
},
{
name: "infix expression, and operator, match",
query: "req.body = bar AND res.body = yolo",
senderReq: sender.Request{
Body: []byte("bar"),
Response: &reqlog.ResponseLog{
Body: []byte("yolo"),
},
},
expectedMatch: true,
expectedError: nil,
},
{
name: "infix expression, or operator, match",
query: "req.body = bar OR res.body = yolo",
senderReq: sender.Request{
Body: []byte("foo"),
Response: &reqlog.ResponseLog{
Body: []byte("yolo"),
},
},
expectedMatch: true,
expectedError: nil,
},
{
name: "prefix expression, not operator, match",
query: "NOT (req.body = bar)",
senderReq: sender.Request{
Body: []byte("foo"),
},
expectedMatch: true,
expectedError: nil,
},
{
name: "string literal expression, match in request log",
query: "foo",
senderReq: sender.Request{
Body: []byte("foo"),
},
expectedMatch: true,
expectedError: nil,
},
{
name: "string literal expression, no match",
query: "foo",
senderReq: sender.Request{
Body: []byte("bar"),
},
expectedMatch: false,
expectedError: nil,
},
{
name: "string literal expression, match in response log",
query: "foo",
senderReq: sender.Request{
Response: &reqlog.ResponseLog{
Body: []byte("foo"),
},
},
expectedMatch: true,
expectedError: nil,
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
searchExpr, err := search.ParseQuery(tt.query)
assertError(t, nil, err)
got, err := tt.senderReq.Matches(searchExpr)
assertError(t, tt.expectedError, err)
if tt.expectedMatch != got {
t.Errorf("expected match result: %v, got: %v", tt.expectedMatch, got)
}
})
}
}
func assertError(t *testing.T, exp, got error) {
t.Helper()
switch {
case exp == nil && got != nil:
t.Fatalf("expected: nil, got: %v", got)
case exp != nil && got == nil:
t.Fatalf("expected: %v, got: nil", exp.Error())
case exp != nil && got != nil && exp.Error() != got.Error():
t.Fatalf("expected: %v, got: %v", exp.Error(), got.Error())
}
}

251
pkg/sender/sender.go Normal file
View File

@ -0,0 +1,251 @@
package sender
import (
"bytes"
"context"
"errors"
"fmt"
"math/rand"
"net/http"
"net/url"
"time"
"github.com/oklog/ulid"
"github.com/dstotijn/hetty/pkg/reqlog"
"github.com/dstotijn/hetty/pkg/scope"
"github.com/dstotijn/hetty/pkg/search"
)
//nolint:gosec
var ulidEntropy = rand.New(rand.NewSource(time.Now().UnixNano()))
var defaultHTTPClient = &http.Client{
Transport: &HTTPTransport{},
Timeout: 30 * time.Second,
}
var (
ErrProjectIDMustBeSet = errors.New("sender: project ID must be set")
ErrRequestNotFound = errors.New("sender: request not found")
)
type Service interface {
FindRequestByID(ctx context.Context, id ulid.ULID) (Request, error)
FindRequests(ctx context.Context) ([]Request, error)
CreateOrUpdateRequest(ctx context.Context, req Request) (Request, error)
CloneFromRequestLog(ctx context.Context, reqLogID ulid.ULID) (Request, error)
DeleteRequests(ctx context.Context, projectID ulid.ULID) error
SendRequest(ctx context.Context, id ulid.ULID) (Request, error)
SetActiveProjectID(ulid.ULID)
SetFindReqsFilter(filter FindRequestsFilter)
FindReqsFilter() FindRequestsFilter
}
type service struct {
activeProjectID ulid.ULID
findReqsFilter FindRequestsFilter
scope *scope.Scope
repo Repository
reqLogSvc reqlog.Service
httpClient *http.Client
}
type FindRequestsFilter struct {
ProjectID ulid.ULID
OnlyInScope bool
SearchExpr search.Expression
}
type Config struct {
Scope *scope.Scope
Repository Repository
ReqLogService reqlog.Service
HTTPClient *http.Client
}
type SendError struct {
err error
}
func NewService(cfg Config) Service {
svc := &service{
repo: cfg.Repository,
reqLogSvc: cfg.ReqLogService,
httpClient: defaultHTTPClient,
scope: cfg.Scope,
}
if cfg.HTTPClient != nil {
svc.httpClient = cfg.HTTPClient
}
return svc
}
type Request struct {
ID ulid.ULID
ProjectID ulid.ULID
SourceRequestLogID ulid.ULID
URL *url.URL
Method string
Proto string
Header http.Header
Body []byte
Response *reqlog.ResponseLog
}
func (svc *service) FindRequestByID(ctx context.Context, id ulid.ULID) (Request, error) {
req, err := svc.repo.FindSenderRequestByID(ctx, id)
if err != nil {
return Request{}, fmt.Errorf("sender: failed to find request: %w", err)
}
return req, nil
}
func (svc *service) FindRequests(ctx context.Context) ([]Request, error) {
return svc.repo.FindSenderRequests(ctx, svc.findReqsFilter, svc.scope)
}
func (svc *service) CreateOrUpdateRequest(ctx context.Context, req Request) (Request, error) {
if svc.activeProjectID.Compare(ulid.ULID{}) == 0 {
return Request{}, ErrProjectIDMustBeSet
}
if req.ID.Compare(ulid.ULID{}) == 0 {
req.ID = ulid.MustNew(ulid.Timestamp(time.Now()), ulidEntropy)
}
req.ProjectID = svc.activeProjectID
if req.Method == "" {
req.Method = http.MethodGet
}
if req.Proto == "" {
req.Proto = HTTPProto2
}
if !isValidProto(req.Proto) {
return Request{}, fmt.Errorf("sender: unsupported HTTP protocol: %v", req.Proto)
}
err := svc.repo.StoreSenderRequest(ctx, req)
if err != nil {
return Request{}, fmt.Errorf("sender: failed to store request: %w", err)
}
return req, nil
}
func (svc *service) CloneFromRequestLog(ctx context.Context, reqLogID ulid.ULID) (Request, error) {
if svc.activeProjectID.Compare(ulid.ULID{}) == 0 {
return Request{}, ErrProjectIDMustBeSet
}
reqLog, err := svc.reqLogSvc.FindRequestLogByID(ctx, reqLogID)
if err != nil {
return Request{}, fmt.Errorf("sender: failed to find request log: %w", err)
}
req := Request{
ID: ulid.MustNew(ulid.Timestamp(time.Now()), ulidEntropy),
ProjectID: svc.activeProjectID,
SourceRequestLogID: reqLogID,
Method: reqLog.Method,
URL: reqLog.URL,
Proto: HTTPProto2, // Attempt HTTP/2.
Header: reqLog.Header,
Body: reqLog.Body,
}
err = svc.repo.StoreSenderRequest(ctx, req)
if err != nil {
return Request{}, fmt.Errorf("sender: failed to store request: %w", err)
}
return req, nil
}
func (svc *service) SetFindReqsFilter(filter FindRequestsFilter) {
svc.findReqsFilter = filter
}
func (svc *service) FindReqsFilter() FindRequestsFilter {
return svc.findReqsFilter
}
func (svc *service) SendRequest(ctx context.Context, id ulid.ULID) (Request, error) {
req, err := svc.repo.FindSenderRequestByID(ctx, id)
if err != nil {
return Request{}, fmt.Errorf("sender: failed to find request: %w", err)
}
httpReq, err := parseHTTPRequest(ctx, req)
if err != nil {
return Request{}, fmt.Errorf("sender: failed to parse HTTP request: %w", err)
}
resLog, err := svc.sendHTTPRequest(httpReq)
if err != nil {
return Request{}, fmt.Errorf("sender: could not send HTTP request: %w", err)
}
err = svc.repo.StoreResponseLog(ctx, id, resLog)
if err != nil {
return Request{}, fmt.Errorf("sender: failed to store sender response log: %w", err)
}
req.Response = &resLog
return req, nil
}
func parseHTTPRequest(ctx context.Context, req Request) (*http.Request, error) {
ctx = context.WithValue(ctx, protoCtxKey{}, req.Proto)
httpReq, err := http.NewRequestWithContext(ctx, req.Method, req.URL.String(), bytes.NewReader(req.Body))
if err != nil {
return nil, fmt.Errorf("failed to construct HTTP request: %w", err)
}
if req.Header != nil {
httpReq.Header = req.Header
}
return httpReq, nil
}
func (svc *service) sendHTTPRequest(httpReq *http.Request) (reqlog.ResponseLog, error) {
res, err := svc.httpClient.Do(httpReq)
if err != nil {
return reqlog.ResponseLog{}, &SendError{err}
}
defer res.Body.Close()
resLog, err := reqlog.ParseHTTPResponse(res)
if err != nil {
return reqlog.ResponseLog{}, fmt.Errorf("failed to parse http response: %w", err)
}
return resLog, err
}
func (svc *service) SetActiveProjectID(id ulid.ULID) {
svc.activeProjectID = id
}
func (svc *service) DeleteRequests(ctx context.Context, projectID ulid.ULID) error {
return svc.repo.DeleteSenderRequests(ctx, projectID)
}
func (e SendError) Error() string {
return fmt.Sprintf("failed to send HTTP request: %v", e.err)
}
func (e SendError) Unwrap() error {
return e.err
}

289
pkg/sender/sender_test.go Normal file
View File

@ -0,0 +1,289 @@
package sender_test
//go:generate go run github.com/matryer/moq -out reqlog_mock_test.go -pkg sender_test ../reqlog Service:ReqLogServiceMock
//go:generate go run github.com/matryer/moq -out repo_mock_test.go -pkg sender_test . Repository:RepoMock
import (
"context"
"errors"
"fmt"
"math/rand"
"net/http"
"net/http/httptest"
"net/url"
"testing"
"time"
"github.com/google/go-cmp/cmp"
"github.com/oklog/ulid"
"github.com/dstotijn/hetty/pkg/reqlog"
"github.com/dstotijn/hetty/pkg/sender"
)
//nolint:gosec
var ulidEntropy = rand.New(rand.NewSource(time.Now().UnixNano()))
var exampleURL = func() *url.URL {
u, err := url.Parse("https://example.com/foobar")
if err != nil {
panic(err)
}
return u
}()
func TestStoreRequest(t *testing.T) {
t.Parallel()
t.Run("without active project", func(t *testing.T) {
t.Parallel()
svc := sender.NewService(sender.Config{})
_, err := svc.CreateOrUpdateRequest(context.Background(), sender.Request{
URL: exampleURL,
Method: http.MethodPost,
Body: []byte("foobar"),
})
if !errors.Is(err, sender.ErrProjectIDMustBeSet) {
t.Fatalf("expected `sender.ErrProjectIDMustBeSet`, got: %v", err)
}
})
t.Run("with active project", func(t *testing.T) {
t.Parallel()
repoMock := &RepoMock{
StoreSenderRequestFunc: func(ctx context.Context, req sender.Request) error {
return nil
},
}
svc := sender.NewService(sender.Config{
Repository: repoMock,
})
projectID := ulid.MustNew(ulid.Timestamp(time.Now()), ulidEntropy)
svc.SetActiveProjectID(projectID)
exp := sender.Request{
ProjectID: projectID,
URL: exampleURL,
Method: http.MethodPost,
Proto: "HTTP/1.1",
Header: http.Header{
"X-Foo": []string{"bar"},
},
Body: []byte("foobar"),
}
got, err := svc.CreateOrUpdateRequest(context.Background(), sender.Request{
URL: exampleURL,
Method: http.MethodPost,
Proto: "HTTP/1.1",
Header: http.Header{
"X-Foo": []string{"bar"},
},
Body: []byte("foobar"),
})
if err != nil {
t.Fatalf("unexpected error storing request: %v", err)
}
if got.ID.Compare(ulid.ULID{}) == 0 {
t.Fatal("expected request ID to be non-empty value")
}
if len(repoMock.StoreSenderRequestCalls()) != 1 {
t.Fatal("expected `svc.repo.StoreSenderRequest()` to have been called 1 time")
}
if diff := cmp.Diff(got, repoMock.StoreSenderRequestCalls()[0].Req); diff != "" {
t.Fatalf("repo call arg not equal (-exp, +got):\n%v", diff)
}
// Reset ID to make comparison with expected easier.
got.ID = ulid.ULID{}
if diff := cmp.Diff(exp, got); diff != "" {
t.Fatalf("request not equal (-exp, +got):\n%v", diff)
}
})
}
func TestCloneFromRequestLog(t *testing.T) {
t.Parallel()
reqLogID := ulid.MustNew(ulid.Timestamp(time.Now()), ulidEntropy)
t.Run("without active project", func(t *testing.T) {
t.Parallel()
svc := sender.NewService(sender.Config{})
_, err := svc.CloneFromRequestLog(context.Background(), reqLogID)
if !errors.Is(err, sender.ErrProjectIDMustBeSet) {
t.Fatalf("expected `sender.ErrProjectIDMustBeSet`, got: %v", err)
}
})
t.Run("with active project", func(t *testing.T) {
t.Parallel()
reqLog := reqlog.RequestLog{
ID: reqLogID,
ProjectID: ulid.MustNew(ulid.Timestamp(time.Now()), ulidEntropy),
URL: exampleURL,
Method: http.MethodPost,
Proto: "HTTP/1.1",
Header: http.Header{
"X-Foo": []string{"bar"},
},
Body: []byte("foobar"),
}
reqLogMock := &ReqLogServiceMock{
FindRequestLogByIDFunc: func(ctx context.Context, id ulid.ULID) (reqlog.RequestLog, error) {
return reqLog, nil
},
}
repoMock := &RepoMock{
StoreSenderRequestFunc: func(ctx context.Context, req sender.Request) error {
return nil
},
}
svc := sender.NewService(sender.Config{
ReqLogService: reqLogMock,
Repository: repoMock,
})
projectID := ulid.MustNew(ulid.Timestamp(time.Now()), ulidEntropy)
svc.SetActiveProjectID(projectID)
exp := sender.Request{
SourceRequestLogID: reqLogID,
ProjectID: projectID,
URL: exampleURL,
Method: http.MethodPost,
Proto: sender.HTTPProto2,
Header: http.Header{
"X-Foo": []string{"bar"},
},
Body: []byte("foobar"),
}
got, err := svc.CloneFromRequestLog(context.Background(), reqLogID)
if err != nil {
t.Fatalf("unexpected error cloning from request log: %v", err)
}
if len(reqLogMock.FindRequestLogByIDCalls()) != 1 {
t.Fatal("expected `svc.reqLogSvc.FindRequestLogByID()` to have been called 1 time")
}
if got := reqLogMock.FindRequestLogByIDCalls()[0].ID; reqLogID.Compare(got) != 0 {
t.Fatalf("reqlog service call arg `id` not equal (expected: %q, got: %q)", reqLogID, got)
}
if got.ID.Compare(ulid.ULID{}) == 0 {
t.Fatal("expected request ID to be non-empty value")
}
if len(repoMock.StoreSenderRequestCalls()) != 1 {
t.Fatal("expected `svc.repo.StoreSenderRequest()` to have been called 1 time")
}
if diff := cmp.Diff(got, repoMock.StoreSenderRequestCalls()[0].Req); diff != "" {
t.Fatalf("repo call arg not equal (-exp, +got):\n%v", diff)
}
// Reset ID to make comparison with expected easier.
got.ID = ulid.ULID{}
if diff := cmp.Diff(exp, got); diff != "" {
t.Fatalf("request not equal (-exp, +got):\n%v", diff)
}
})
}
func TestSendRequest(t *testing.T) {
t.Parallel()
date := time.Now().Format(http.TimeFormat)
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Foobar", "baz")
w.Header().Set("Date", date)
fmt.Fprint(w, "baz")
}))
defer ts.Close()
tsURL, _ := url.Parse(ts.URL)
reqID := ulid.MustNew(ulid.Timestamp(time.Now()), ulidEntropy)
req := sender.Request{
ID: reqID,
ProjectID: ulid.MustNew(ulid.Timestamp(time.Now()), ulidEntropy),
URL: tsURL,
Method: http.MethodPost,
Proto: "HTTP/1.1",
Header: http.Header{
"X-Foo": []string{"bar"},
},
Body: []byte("foobar"),
}
repoMock := &RepoMock{
FindSenderRequestByIDFunc: func(ctx context.Context, id ulid.ULID) (sender.Request, error) {
return req, nil
},
StoreResponseLogFunc: func(ctx context.Context, reqLogID ulid.ULID, resLog reqlog.ResponseLog) error {
return nil
},
}
svc := sender.NewService(sender.Config{
Repository: repoMock,
})
exp := reqlog.ResponseLog{
Proto: "HTTP/1.1",
StatusCode: http.StatusOK,
Status: "200 OK",
Header: http.Header{
"Content-Length": []string{"3"},
"Content-Type": []string{"text/plain; charset=utf-8"},
"Date": []string{date},
"Foobar": []string{"baz"},
},
Body: []byte("baz"),
}
got, err := svc.SendRequest(context.Background(), reqID)
if err != nil {
t.Fatalf("unexpected error sending request: %v", err)
}
if len(repoMock.FindSenderRequestByIDCalls()) != 1 {
t.Fatal("expected `svc.repo.FindSenderRequestByID()` to have been called 1 time")
}
if diff := cmp.Diff(reqID, repoMock.FindSenderRequestByIDCalls()[0].ID); diff != "" {
t.Fatalf("call arg `id` for `svc.repo.FindSenderRequestByID()` not equal (-exp, +got):\n%v", diff)
}
if len(repoMock.StoreResponseLogCalls()) != 1 {
t.Fatal("expected `svc.repo.StoreResponseLog()` to have been called 1 time")
}
if diff := cmp.Diff(reqID, repoMock.StoreResponseLogCalls()[0].ReqLogID); diff != "" {
t.Fatalf("call arg `reqLogID` for `svc.repo.StoreResponseLog()` not equal (-exp, +got):\n%v", diff)
}
if diff := cmp.Diff(exp, repoMock.StoreResponseLogCalls()[0].ResLog); diff != "" {
t.Fatalf("call arg `resLog` for `svc.repo.StoreResponseLog()` not equal (-exp, +got):\n%v", diff)
}
if diff := cmp.Diff(repoMock.StoreResponseLogCalls()[0].ResLog, *got.Response); diff != "" {
t.Fatalf("returned response log value and persisted value not equal (-exp, +got):\n%v", diff)
}
}

50
pkg/sender/transport.go Normal file
View File

@ -0,0 +1,50 @@
package sender
import (
"crypto/tls"
"net"
"net/http"
"time"
)
type HTTPTransport struct{}
type protoCtxKey struct{}
const (
HTTPProto1 = "HTTP/1.1"
HTTPProto2 = "HTTP/2.0"
)
// h1OnlyTransport mimics `http.DefaultTransport`, but with HTTP/2 disabled.
var h1OnlyTransport = &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}).DialContext,
MaxIdleConns: 100,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
// Disable HTTP/2.
TLSNextProto: map[string]func(string, *tls.Conn) http.RoundTripper{},
}
// RountTrip implements http.RoundTripper. Based on a context value on the
// HTTP request, it switches between using `http.DefaultTransport` (which attempts
// HTTP/2) and a HTTP/1.1 only transport that's based off `http.DefaultTransport`.
func (t *HTTPTransport) RoundTrip(req *http.Request) (*http.Response, error) {
proto, ok := req.Context().Value(protoCtxKey{}).(string)
if ok && proto == HTTPProto1 {
return h1OnlyTransport.RoundTrip(req)
}
return http.DefaultTransport.RoundTrip(req)
}
func isValidProto(proto string) bool {
return proto == HTTPProto1 || proto == HTTPProto2
}