Add initial support for HTTP request intercepting

This commit is contained in:
David Stotijn
2022-03-08 13:27:31 +01:00
parent bdd667381a
commit 71e550f0cd
11 changed files with 1184 additions and 28 deletions

View File

@ -28,6 +28,7 @@ import (
"github.com/dstotijn/hetty/pkg/db/badger" "github.com/dstotijn/hetty/pkg/db/badger"
"github.com/dstotijn/hetty/pkg/proj" "github.com/dstotijn/hetty/pkg/proj"
"github.com/dstotijn/hetty/pkg/proxy" "github.com/dstotijn/hetty/pkg/proxy"
"github.com/dstotijn/hetty/pkg/proxy/intercept"
"github.com/dstotijn/hetty/pkg/reqlog" "github.com/dstotijn/hetty/pkg/reqlog"
"github.com/dstotijn/hetty/pkg/scope" "github.com/dstotijn/hetty/pkg/scope"
"github.com/dstotijn/hetty/pkg/sender" "github.com/dstotijn/hetty/pkg/sender"
@ -175,6 +176,10 @@ func (cmd *HettyCommand) Exec(ctx context.Context, _ []string) error {
Logger: cmd.config.logger.Named("reqlog").Sugar(), Logger: cmd.config.logger.Named("reqlog").Sugar(),
}) })
interceptService := intercept.NewService(intercept.Config{
Logger: cmd.config.logger.Named("intercept").Sugar(),
})
senderService := sender.NewService(sender.Config{ senderService := sender.NewService(sender.Config{
Repository: badger, Repository: badger,
ReqLogService: reqLogService, ReqLogService: reqLogService,
@ -182,6 +187,7 @@ func (cmd *HettyCommand) Exec(ctx context.Context, _ []string) error {
projService, err := proj.NewService(proj.Config{ projService, err := proj.NewService(proj.Config{
Repository: badger, Repository: badger,
InterceptService: interceptService,
ReqLogService: reqLogService, ReqLogService: reqLogService,
SenderService: senderService, SenderService: senderService,
Scope: scope, Scope: scope,
@ -201,6 +207,7 @@ func (cmd *HettyCommand) Exec(ctx context.Context, _ []string) error {
proxy.UseRequestModifier(reqLogService.RequestModifier) proxy.UseRequestModifier(reqLogService.RequestModifier)
proxy.UseResponseModifier(reqLogService.ResponseModifier) proxy.UseResponseModifier(reqLogService.ResponseModifier)
proxy.UseRequestModifier(interceptService.RequestModifier)
fsSub, err := fs.Sub(adminContent, "admin") fsSub, err := fs.Sub(adminContent, "admin")
if err != nil { if err != nil {
@ -231,6 +238,7 @@ func (cmd *HettyCommand) Exec(ctx context.Context, _ []string) error {
adminRouter.Path(gqlEndpoint).Handler(api.HTTPHandler(&api.Resolver{ adminRouter.Path(gqlEndpoint).Handler(api.HTTPHandler(&api.Resolver{
ProjectService: projService, ProjectService: projService,
RequestLogService: reqLogService, RequestLogService: reqLogService,
InterceptService: interceptService,
SenderService: senderService, SenderService: senderService,
}, gqlEndpoint)) }, gqlEndpoint))

View File

@ -66,6 +66,15 @@ type ComplexityRoot struct {
Value func(childComplexity int) int Value func(childComplexity int) int
} }
HTTPRequest struct {
Body func(childComplexity int) int
Headers func(childComplexity int) int
ID func(childComplexity int) int
Method func(childComplexity int) int
Proto func(childComplexity int) int
URL func(childComplexity int) int
}
HTTPRequestLog struct { HTTPRequestLog struct {
Body func(childComplexity int) int Body func(childComplexity int) int
Headers func(childComplexity int) int Headers func(childComplexity int) int
@ -91,6 +100,10 @@ type ComplexityRoot struct {
StatusReason func(childComplexity int) int StatusReason func(childComplexity int) int
} }
ModifyRequestResult struct {
Success func(childComplexity int) int
}
Mutation struct { Mutation struct {
ClearHTTPRequestLog func(childComplexity int) int ClearHTTPRequestLog func(childComplexity int) int
CloseProject func(childComplexity int) int CloseProject func(childComplexity int) int
@ -99,6 +112,7 @@ type ComplexityRoot struct {
CreateSenderRequestFromHTTPRequestLog func(childComplexity int, id ulid.ULID) int CreateSenderRequestFromHTTPRequestLog func(childComplexity int, id ulid.ULID) int
DeleteProject func(childComplexity int, id ulid.ULID) int DeleteProject func(childComplexity int, id ulid.ULID) int
DeleteSenderRequests func(childComplexity int) int DeleteSenderRequests func(childComplexity int) int
ModifyRequest func(childComplexity int, request ModifyRequestInput) int
OpenProject func(childComplexity int, id ulid.ULID) int OpenProject func(childComplexity int, id ulid.ULID) int
SendRequest func(childComplexity int, id ulid.ULID) int SendRequest func(childComplexity int, id ulid.ULID) int
SetHTTPRequestLogFilter func(childComplexity int, filter *HTTPRequestLogFilterInput) int SetHTTPRequestLogFilter func(childComplexity int, filter *HTTPRequestLogFilterInput) int
@ -117,6 +131,7 @@ type ComplexityRoot struct {
HTTPRequestLog func(childComplexity int, id ulid.ULID) int HTTPRequestLog func(childComplexity int, id ulid.ULID) int
HTTPRequestLogFilter func(childComplexity int) int HTTPRequestLogFilter func(childComplexity int) int
HTTPRequestLogs func(childComplexity int) int HTTPRequestLogs func(childComplexity int) int
InterceptedRequests func(childComplexity int) int
Projects func(childComplexity int) int Projects func(childComplexity int) int
Scope func(childComplexity int) int Scope func(childComplexity int) int
SenderRequest func(childComplexity int, id ulid.ULID) int SenderRequest func(childComplexity int, id ulid.ULID) int
@ -165,6 +180,7 @@ type MutationResolver interface {
CreateSenderRequestFromHTTPRequestLog(ctx context.Context, id ulid.ULID) (*SenderRequest, error) CreateSenderRequestFromHTTPRequestLog(ctx context.Context, id ulid.ULID) (*SenderRequest, error)
SendRequest(ctx context.Context, id ulid.ULID) (*SenderRequest, error) SendRequest(ctx context.Context, id ulid.ULID) (*SenderRequest, error)
DeleteSenderRequests(ctx context.Context) (*DeleteSenderRequestsResult, error) DeleteSenderRequests(ctx context.Context) (*DeleteSenderRequestsResult, error)
ModifyRequest(ctx context.Context, request ModifyRequestInput) (*ModifyRequestResult, error)
} }
type QueryResolver interface { type QueryResolver interface {
HTTPRequestLog(ctx context.Context, id ulid.ULID) (*HTTPRequestLog, error) HTTPRequestLog(ctx context.Context, id ulid.ULID) (*HTTPRequestLog, error)
@ -175,6 +191,7 @@ type QueryResolver interface {
Scope(ctx context.Context) ([]ScopeRule, error) Scope(ctx context.Context) ([]ScopeRule, error)
SenderRequest(ctx context.Context, id ulid.ULID) (*SenderRequest, error) SenderRequest(ctx context.Context, id ulid.ULID) (*SenderRequest, error)
SenderRequests(ctx context.Context) ([]SenderRequest, error) SenderRequests(ctx context.Context) ([]SenderRequest, error)
InterceptedRequests(ctx context.Context) ([]HTTPRequest, error)
} }
type executableSchema struct { type executableSchema struct {
@ -234,6 +251,48 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in
return e.complexity.HTTPHeader.Value(childComplexity), true return e.complexity.HTTPHeader.Value(childComplexity), true
case "HttpRequest.body":
if e.complexity.HTTPRequest.Body == nil {
break
}
return e.complexity.HTTPRequest.Body(childComplexity), true
case "HttpRequest.headers":
if e.complexity.HTTPRequest.Headers == nil {
break
}
return e.complexity.HTTPRequest.Headers(childComplexity), true
case "HttpRequest.id":
if e.complexity.HTTPRequest.ID == nil {
break
}
return e.complexity.HTTPRequest.ID(childComplexity), true
case "HttpRequest.method":
if e.complexity.HTTPRequest.Method == nil {
break
}
return e.complexity.HTTPRequest.Method(childComplexity), true
case "HttpRequest.proto":
if e.complexity.HTTPRequest.Proto == nil {
break
}
return e.complexity.HTTPRequest.Proto(childComplexity), true
case "HttpRequest.url":
if e.complexity.HTTPRequest.URL == nil {
break
}
return e.complexity.HTTPRequest.URL(childComplexity), true
case "HttpRequestLog.body": case "HttpRequestLog.body":
if e.complexity.HTTPRequestLog.Body == nil { if e.complexity.HTTPRequestLog.Body == nil {
break break
@ -346,6 +405,13 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in
return e.complexity.HTTPResponseLog.StatusReason(childComplexity), true return e.complexity.HTTPResponseLog.StatusReason(childComplexity), true
case "ModifyRequestResult.success":
if e.complexity.ModifyRequestResult.Success == nil {
break
}
return e.complexity.ModifyRequestResult.Success(childComplexity), true
case "Mutation.clearHTTPRequestLog": case "Mutation.clearHTTPRequestLog":
if e.complexity.Mutation.ClearHTTPRequestLog == nil { if e.complexity.Mutation.ClearHTTPRequestLog == nil {
break break
@ -415,6 +481,18 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in
return e.complexity.Mutation.DeleteSenderRequests(childComplexity), true return e.complexity.Mutation.DeleteSenderRequests(childComplexity), true
case "Mutation.modifyRequest":
if e.complexity.Mutation.ModifyRequest == nil {
break
}
args, err := ec.field_Mutation_modifyRequest_args(context.TODO(), rawArgs)
if err != nil {
return 0, false
}
return e.complexity.Mutation.ModifyRequest(childComplexity, args["request"].(ModifyRequestInput)), true
case "Mutation.openProject": case "Mutation.openProject":
if e.complexity.Mutation.OpenProject == nil { if e.complexity.Mutation.OpenProject == nil {
break break
@ -529,6 +607,13 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in
return e.complexity.Query.HTTPRequestLogs(childComplexity), true return e.complexity.Query.HTTPRequestLogs(childComplexity), true
case "Query.interceptedRequests":
if e.complexity.Query.InterceptedRequests == nil {
break
}
return e.complexity.Query.InterceptedRequests(childComplexity), true
case "Query.projects": case "Query.projects":
if e.complexity.Query.Projects == nil { if e.complexity.Query.Projects == nil {
break break
@ -856,6 +941,28 @@ type SenderRequestFilter {
searchExpression: String searchExpression: String
} }
type HttpRequest {
id: ID!
url: URL!
method: HttpMethod!
proto: HttpProtocol!
headers: [HttpHeader!]!
body: String
}
input ModifyRequestInput {
id: ID!
url: URL!
method: HttpMethod!
proto: HttpProtocol!
headers: [HttpHeaderInput!]
body: String
}
type ModifyRequestResult {
success: Boolean!
}
type Query { type Query {
httpRequestLog(id: ID!): HttpRequestLog httpRequestLog(id: ID!): HttpRequestLog
httpRequestLogs: [HttpRequestLog!]! httpRequestLogs: [HttpRequestLog!]!
@ -865,6 +972,7 @@ type Query {
scope: [ScopeRule!]! scope: [ScopeRule!]!
senderRequest(id: ID!): SenderRequest senderRequest(id: ID!): SenderRequest
senderRequests: [SenderRequest!]! senderRequests: [SenderRequest!]!
interceptedRequests: [HttpRequest!]!
} }
type Mutation { type Mutation {
@ -882,6 +990,7 @@ type Mutation {
createSenderRequestFromHttpRequestLog(id: ID!): SenderRequest! createSenderRequestFromHttpRequestLog(id: ID!): SenderRequest!
sendRequest(id: ID!): SenderRequest! sendRequest(id: ID!): SenderRequest!
deleteSenderRequests: DeleteSenderRequestsResult! deleteSenderRequests: DeleteSenderRequestsResult!
modifyRequest(request: ModifyRequestInput!): ModifyRequestResult!
} }
enum HttpMethod { enum HttpMethod {
@ -973,6 +1082,21 @@ func (ec *executionContext) field_Mutation_deleteProject_args(ctx context.Contex
return args, nil return args, nil
} }
func (ec *executionContext) field_Mutation_modifyRequest_args(ctx context.Context, rawArgs map[string]interface{}) (map[string]interface{}, error) {
var err error
args := map[string]interface{}{}
var arg0 ModifyRequestInput
if tmp, ok := rawArgs["request"]; ok {
ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("request"))
arg0, err = ec.unmarshalNModifyRequestInput2githubᚗcomᚋdstotijnᚋhettyᚋpkgᚋapiᚐModifyRequestInput(ctx, tmp)
if err != nil {
return nil, err
}
}
args["request"] = arg0
return args, nil
}
func (ec *executionContext) field_Mutation_openProject_args(ctx context.Context, rawArgs map[string]interface{}) (map[string]interface{}, error) { func (ec *executionContext) field_Mutation_openProject_args(ctx context.Context, rawArgs map[string]interface{}) (map[string]interface{}, error) {
var err error var err error
args := map[string]interface{}{} args := map[string]interface{}{}
@ -1341,6 +1465,213 @@ func (ec *executionContext) _HttpHeader_value(ctx context.Context, field graphql
return ec.marshalNString2string(ctx, field.Selections, res) return ec.marshalNString2string(ctx, field.Selections, res)
} }
func (ec *executionContext) _HttpRequest_id(ctx context.Context, field graphql.CollectedField, obj *HTTPRequest) (ret graphql.Marshaler) {
defer func() {
if r := recover(); r != nil {
ec.Error(ctx, ec.Recover(ctx, r))
ret = graphql.Null
}
}()
fc := &graphql.FieldContext{
Object: "HttpRequest",
Field: field,
Args: nil,
IsMethod: false,
IsResolver: false,
}
ctx = graphql.WithFieldContext(ctx, fc)
resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) {
ctx = rctx // use context from middleware stack in children
return obj.ID, nil
})
if err != nil {
ec.Error(ctx, err)
return graphql.Null
}
if resTmp == nil {
if !graphql.HasFieldError(ctx, fc) {
ec.Errorf(ctx, "must not be null")
}
return graphql.Null
}
res := resTmp.(ulid.ULID)
fc.Result = res
return ec.marshalNID2githubᚗcomᚋoklogᚋulidᚐULID(ctx, field.Selections, res)
}
func (ec *executionContext) _HttpRequest_url(ctx context.Context, field graphql.CollectedField, obj *HTTPRequest) (ret graphql.Marshaler) {
defer func() {
if r := recover(); r != nil {
ec.Error(ctx, ec.Recover(ctx, r))
ret = graphql.Null
}
}()
fc := &graphql.FieldContext{
Object: "HttpRequest",
Field: field,
Args: nil,
IsMethod: false,
IsResolver: false,
}
ctx = graphql.WithFieldContext(ctx, fc)
resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) {
ctx = rctx // use context from middleware stack in children
return obj.URL, nil
})
if err != nil {
ec.Error(ctx, err)
return graphql.Null
}
if resTmp == nil {
if !graphql.HasFieldError(ctx, fc) {
ec.Errorf(ctx, "must not be null")
}
return graphql.Null
}
res := resTmp.(*url.URL)
fc.Result = res
return ec.marshalNURL2ᚖnetᚋurlᚐURL(ctx, field.Selections, res)
}
func (ec *executionContext) _HttpRequest_method(ctx context.Context, field graphql.CollectedField, obj *HTTPRequest) (ret graphql.Marshaler) {
defer func() {
if r := recover(); r != nil {
ec.Error(ctx, ec.Recover(ctx, r))
ret = graphql.Null
}
}()
fc := &graphql.FieldContext{
Object: "HttpRequest",
Field: field,
Args: nil,
IsMethod: false,
IsResolver: false,
}
ctx = graphql.WithFieldContext(ctx, fc)
resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) {
ctx = rctx // use context from middleware stack in children
return obj.Method, nil
})
if err != nil {
ec.Error(ctx, err)
return graphql.Null
}
if resTmp == nil {
if !graphql.HasFieldError(ctx, fc) {
ec.Errorf(ctx, "must not be null")
}
return graphql.Null
}
res := resTmp.(HTTPMethod)
fc.Result = res
return ec.marshalNHttpMethod2githubᚗcomᚋdstotijnᚋhettyᚋpkgᚋapiᚐHTTPMethod(ctx, field.Selections, res)
}
func (ec *executionContext) _HttpRequest_proto(ctx context.Context, field graphql.CollectedField, obj *HTTPRequest) (ret graphql.Marshaler) {
defer func() {
if r := recover(); r != nil {
ec.Error(ctx, ec.Recover(ctx, r))
ret = graphql.Null
}
}()
fc := &graphql.FieldContext{
Object: "HttpRequest",
Field: field,
Args: nil,
IsMethod: false,
IsResolver: false,
}
ctx = graphql.WithFieldContext(ctx, fc)
resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) {
ctx = rctx // use context from middleware stack in children
return obj.Proto, nil
})
if err != nil {
ec.Error(ctx, err)
return graphql.Null
}
if resTmp == nil {
if !graphql.HasFieldError(ctx, fc) {
ec.Errorf(ctx, "must not be null")
}
return graphql.Null
}
res := resTmp.(HTTPProtocol)
fc.Result = res
return ec.marshalNHttpProtocol2githubᚗcomᚋdstotijnᚋhettyᚋpkgᚋapiᚐHTTPProtocol(ctx, field.Selections, res)
}
func (ec *executionContext) _HttpRequest_headers(ctx context.Context, field graphql.CollectedField, obj *HTTPRequest) (ret graphql.Marshaler) {
defer func() {
if r := recover(); r != nil {
ec.Error(ctx, ec.Recover(ctx, r))
ret = graphql.Null
}
}()
fc := &graphql.FieldContext{
Object: "HttpRequest",
Field: field,
Args: nil,
IsMethod: false,
IsResolver: false,
}
ctx = graphql.WithFieldContext(ctx, fc)
resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) {
ctx = rctx // use context from middleware stack in children
return obj.Headers, nil
})
if err != nil {
ec.Error(ctx, err)
return graphql.Null
}
if resTmp == nil {
if !graphql.HasFieldError(ctx, fc) {
ec.Errorf(ctx, "must not be null")
}
return graphql.Null
}
res := resTmp.([]HTTPHeader)
fc.Result = res
return ec.marshalNHttpHeader2ᚕgithubᚗcomᚋdstotijnᚋhettyᚋpkgᚋapiᚐHTTPHeaderᚄ(ctx, field.Selections, res)
}
func (ec *executionContext) _HttpRequest_body(ctx context.Context, field graphql.CollectedField, obj *HTTPRequest) (ret graphql.Marshaler) {
defer func() {
if r := recover(); r != nil {
ec.Error(ctx, ec.Recover(ctx, r))
ret = graphql.Null
}
}()
fc := &graphql.FieldContext{
Object: "HttpRequest",
Field: field,
Args: nil,
IsMethod: false,
IsResolver: false,
}
ctx = graphql.WithFieldContext(ctx, fc)
resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) {
ctx = rctx // use context from middleware stack in children
return obj.Body, nil
})
if err != nil {
ec.Error(ctx, err)
return graphql.Null
}
if resTmp == nil {
return graphql.Null
}
res := resTmp.(*string)
fc.Result = res
return ec.marshalOString2ᚖstring(ctx, field.Selections, res)
}
func (ec *executionContext) _HttpRequestLog_id(ctx context.Context, field graphql.CollectedField, obj *HTTPRequestLog) (ret graphql.Marshaler) { func (ec *executionContext) _HttpRequestLog_id(ctx context.Context, field graphql.CollectedField, obj *HTTPRequestLog) (ret graphql.Marshaler) {
defer func() { defer func() {
if r := recover(); r != nil { if r := recover(); r != nil {
@ -1889,6 +2220,41 @@ func (ec *executionContext) _HttpResponseLog_headers(ctx context.Context, field
return ec.marshalNHttpHeader2ᚕgithubᚗcomᚋdstotijnᚋhettyᚋpkgᚋapiᚐHTTPHeaderᚄ(ctx, field.Selections, res) return ec.marshalNHttpHeader2ᚕgithubᚗcomᚋdstotijnᚋhettyᚋpkgᚋapiᚐHTTPHeaderᚄ(ctx, field.Selections, res)
} }
func (ec *executionContext) _ModifyRequestResult_success(ctx context.Context, field graphql.CollectedField, obj *ModifyRequestResult) (ret graphql.Marshaler) {
defer func() {
if r := recover(); r != nil {
ec.Error(ctx, ec.Recover(ctx, r))
ret = graphql.Null
}
}()
fc := &graphql.FieldContext{
Object: "ModifyRequestResult",
Field: field,
Args: nil,
IsMethod: false,
IsResolver: false,
}
ctx = graphql.WithFieldContext(ctx, fc)
resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) {
ctx = rctx // use context from middleware stack in children
return obj.Success, nil
})
if err != nil {
ec.Error(ctx, err)
return graphql.Null
}
if resTmp == nil {
if !graphql.HasFieldError(ctx, fc) {
ec.Errorf(ctx, "must not be null")
}
return graphql.Null
}
res := resTmp.(bool)
fc.Result = res
return ec.marshalNBoolean2bool(ctx, field.Selections, res)
}
func (ec *executionContext) _Mutation_createProject(ctx context.Context, field graphql.CollectedField) (ret graphql.Marshaler) { func (ec *executionContext) _Mutation_createProject(ctx context.Context, field graphql.CollectedField) (ret graphql.Marshaler) {
defer func() { defer func() {
if r := recover(); r != nil { if r := recover(); r != nil {
@ -2360,6 +2726,48 @@ func (ec *executionContext) _Mutation_deleteSenderRequests(ctx context.Context,
return ec.marshalNDeleteSenderRequestsResult2ᚖgithubᚗcomᚋdstotijnᚋhettyᚋpkgᚋapiᚐDeleteSenderRequestsResult(ctx, field.Selections, res) return ec.marshalNDeleteSenderRequestsResult2ᚖgithubᚗcomᚋdstotijnᚋhettyᚋpkgᚋapiᚐDeleteSenderRequestsResult(ctx, field.Selections, res)
} }
func (ec *executionContext) _Mutation_modifyRequest(ctx context.Context, field graphql.CollectedField) (ret graphql.Marshaler) {
defer func() {
if r := recover(); r != nil {
ec.Error(ctx, ec.Recover(ctx, r))
ret = graphql.Null
}
}()
fc := &graphql.FieldContext{
Object: "Mutation",
Field: field,
Args: nil,
IsMethod: true,
IsResolver: true,
}
ctx = graphql.WithFieldContext(ctx, fc)
rawArgs := field.ArgumentMap(ec.Variables)
args, err := ec.field_Mutation_modifyRequest_args(ctx, rawArgs)
if err != nil {
ec.Error(ctx, err)
return graphql.Null
}
fc.Args = args
resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) {
ctx = rctx // use context from middleware stack in children
return ec.resolvers.Mutation().ModifyRequest(rctx, args["request"].(ModifyRequestInput))
})
if err != nil {
ec.Error(ctx, err)
return graphql.Null
}
if resTmp == nil {
if !graphql.HasFieldError(ctx, fc) {
ec.Errorf(ctx, "must not be null")
}
return graphql.Null
}
res := resTmp.(*ModifyRequestResult)
fc.Result = res
return ec.marshalNModifyRequestResult2ᚖgithubᚗcomᚋdstotijnᚋhettyᚋpkgᚋapiᚐModifyRequestResult(ctx, field.Selections, res)
}
func (ec *executionContext) _Project_id(ctx context.Context, field graphql.CollectedField, obj *Project) (ret graphql.Marshaler) { func (ec *executionContext) _Project_id(ctx context.Context, field graphql.CollectedField, obj *Project) (ret graphql.Marshaler) {
defer func() { defer func() {
if r := recover(); r != nil { if r := recover(); r != nil {
@ -2747,6 +3155,41 @@ func (ec *executionContext) _Query_senderRequests(ctx context.Context, field gra
return ec.marshalNSenderRequest2ᚕgithubᚗcomᚋdstotijnᚋhettyᚋpkgᚋapiᚐSenderRequestᚄ(ctx, field.Selections, res) return ec.marshalNSenderRequest2ᚕgithubᚗcomᚋdstotijnᚋhettyᚋpkgᚋapiᚐSenderRequestᚄ(ctx, field.Selections, res)
} }
func (ec *executionContext) _Query_interceptedRequests(ctx context.Context, field graphql.CollectedField) (ret graphql.Marshaler) {
defer func() {
if r := recover(); r != nil {
ec.Error(ctx, ec.Recover(ctx, r))
ret = graphql.Null
}
}()
fc := &graphql.FieldContext{
Object: "Query",
Field: field,
Args: nil,
IsMethod: true,
IsResolver: true,
}
ctx = graphql.WithFieldContext(ctx, fc)
resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) {
ctx = rctx // use context from middleware stack in children
return ec.resolvers.Query().InterceptedRequests(rctx)
})
if err != nil {
ec.Error(ctx, err)
return graphql.Null
}
if resTmp == nil {
if !graphql.HasFieldError(ctx, fc) {
ec.Errorf(ctx, "must not be null")
}
return graphql.Null
}
res := resTmp.([]HTTPRequest)
fc.Result = res
return ec.marshalNHttpRequest2ᚕgithubᚗcomᚋdstotijnᚋhettyᚋpkgᚋapiᚐHTTPRequestᚄ(ctx, field.Selections, res)
}
func (ec *executionContext) _Query___type(ctx context.Context, field graphql.CollectedField) (ret graphql.Marshaler) { func (ec *executionContext) _Query___type(ctx context.Context, field graphql.CollectedField) (ret graphql.Marshaler) {
defer func() { defer func() {
if r := recover(); r != nil { if r := recover(); r != nil {
@ -4532,6 +4975,69 @@ func (ec *executionContext) unmarshalInputHttpRequestLogFilterInput(ctx context.
return it, nil return it, nil
} }
func (ec *executionContext) unmarshalInputModifyRequestInput(ctx context.Context, obj interface{}) (ModifyRequestInput, error) {
var it ModifyRequestInput
asMap := map[string]interface{}{}
for k, v := range obj.(map[string]interface{}) {
asMap[k] = v
}
for k, v := range asMap {
switch k {
case "id":
var err error
ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("id"))
it.ID, err = ec.unmarshalNID2githubᚗcomᚋoklogᚋulidᚐULID(ctx, v)
if err != nil {
return it, err
}
case "url":
var err error
ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("url"))
it.URL, err = ec.unmarshalNURL2ᚖnetᚋurlᚐURL(ctx, v)
if err != nil {
return it, err
}
case "method":
var err error
ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("method"))
it.Method, err = ec.unmarshalNHttpMethod2githubᚗcomᚋdstotijnᚋhettyᚋpkgᚋapiᚐHTTPMethod(ctx, v)
if err != nil {
return it, err
}
case "proto":
var err error
ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("proto"))
it.Proto, err = ec.unmarshalNHttpProtocol2githubᚗcomᚋdstotijnᚋhettyᚋpkgᚋapiᚐHTTPProtocol(ctx, v)
if err != nil {
return it, err
}
case "headers":
var err error
ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("headers"))
it.Headers, err = ec.unmarshalOHttpHeaderInput2ᚕgithubᚗcomᚋdstotijnᚋhettyᚋpkgᚋapiᚐHTTPHeaderInputᚄ(ctx, v)
if err != nil {
return it, err
}
case "body":
var err error
ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("body"))
it.Body, err = ec.unmarshalOString2ᚖstring(ctx, v)
if err != nil {
return it, err
}
}
}
return it, nil
}
func (ec *executionContext) unmarshalInputScopeHeaderInput(ctx context.Context, obj interface{}) (ScopeHeaderInput, error) { func (ec *executionContext) unmarshalInputScopeHeaderInput(ctx context.Context, obj interface{}) (ScopeHeaderInput, error) {
var it ScopeHeaderInput var it ScopeHeaderInput
asMap := map[string]interface{}{} asMap := map[string]interface{}{}
@ -4844,6 +5350,55 @@ func (ec *executionContext) _HttpHeader(ctx context.Context, sel ast.SelectionSe
return out return out
} }
var httpRequestImplementors = []string{"HttpRequest"}
func (ec *executionContext) _HttpRequest(ctx context.Context, sel ast.SelectionSet, obj *HTTPRequest) graphql.Marshaler {
fields := graphql.CollectFields(ec.OperationContext, sel, httpRequestImplementors)
out := graphql.NewFieldSet(fields)
var invalids uint32
for i, field := range fields {
switch field.Name {
case "__typename":
out.Values[i] = graphql.MarshalString("HttpRequest")
case "id":
out.Values[i] = ec._HttpRequest_id(ctx, field, obj)
if out.Values[i] == graphql.Null {
invalids++
}
case "url":
out.Values[i] = ec._HttpRequest_url(ctx, field, obj)
if out.Values[i] == graphql.Null {
invalids++
}
case "method":
out.Values[i] = ec._HttpRequest_method(ctx, field, obj)
if out.Values[i] == graphql.Null {
invalids++
}
case "proto":
out.Values[i] = ec._HttpRequest_proto(ctx, field, obj)
if out.Values[i] == graphql.Null {
invalids++
}
case "headers":
out.Values[i] = ec._HttpRequest_headers(ctx, field, obj)
if out.Values[i] == graphql.Null {
invalids++
}
case "body":
out.Values[i] = ec._HttpRequest_body(ctx, field, obj)
default:
panic("unknown field " + strconv.Quote(field.Name))
}
}
out.Dispatch()
if invalids > 0 {
return graphql.Null
}
return out
}
var httpRequestLogImplementors = []string{"HttpRequestLog"} var httpRequestLogImplementors = []string{"HttpRequestLog"}
func (ec *executionContext) _HttpRequestLog(ctx context.Context, sel ast.SelectionSet, obj *HTTPRequestLog) graphql.Marshaler { func (ec *executionContext) _HttpRequestLog(ctx context.Context, sel ast.SelectionSet, obj *HTTPRequestLog) graphql.Marshaler {
@ -4978,6 +5533,33 @@ func (ec *executionContext) _HttpResponseLog(ctx context.Context, sel ast.Select
return out return out
} }
var modifyRequestResultImplementors = []string{"ModifyRequestResult"}
func (ec *executionContext) _ModifyRequestResult(ctx context.Context, sel ast.SelectionSet, obj *ModifyRequestResult) graphql.Marshaler {
fields := graphql.CollectFields(ec.OperationContext, sel, modifyRequestResultImplementors)
out := graphql.NewFieldSet(fields)
var invalids uint32
for i, field := range fields {
switch field.Name {
case "__typename":
out.Values[i] = graphql.MarshalString("ModifyRequestResult")
case "success":
out.Values[i] = ec._ModifyRequestResult_success(ctx, field, obj)
if out.Values[i] == graphql.Null {
invalids++
}
default:
panic("unknown field " + strconv.Quote(field.Name))
}
}
out.Dispatch()
if invalids > 0 {
return graphql.Null
}
return out
}
var mutationImplementors = []string{"Mutation"} var mutationImplementors = []string{"Mutation"}
func (ec *executionContext) _Mutation(ctx context.Context, sel ast.SelectionSet) graphql.Marshaler { func (ec *executionContext) _Mutation(ctx context.Context, sel ast.SelectionSet) graphql.Marshaler {
@ -5041,6 +5623,11 @@ func (ec *executionContext) _Mutation(ctx context.Context, sel ast.SelectionSet)
if out.Values[i] == graphql.Null { if out.Values[i] == graphql.Null {
invalids++ invalids++
} }
case "modifyRequest":
out.Values[i] = ec._Mutation_modifyRequest(ctx, field)
if out.Values[i] == graphql.Null {
invalids++
}
default: default:
panic("unknown field " + strconv.Quote(field.Name)) panic("unknown field " + strconv.Quote(field.Name))
} }
@ -5204,6 +5791,20 @@ func (ec *executionContext) _Query(ctx context.Context, sel ast.SelectionSet) gr
} }
return res return res
}) })
case "interceptedRequests":
field := field
out.Concurrently(i, func() (res graphql.Marshaler) {
defer func() {
if r := recover(); r != nil {
ec.Error(ctx, ec.Recover(ctx, r))
}
}()
res = ec._Query_interceptedRequests(ctx, field)
if res == graphql.Null {
atomic.AddUint32(&invalids, 1)
}
return res
})
case "__type": case "__type":
out.Values[i] = ec._Query___type(ctx, field) out.Values[i] = ec._Query___type(ctx, field)
case "__schema": case "__schema":
@ -5751,6 +6352,54 @@ func (ec *executionContext) marshalNHttpProtocol2githubᚗcomᚋdstotijnᚋhetty
return v return v
} }
func (ec *executionContext) marshalNHttpRequest2githubᚗcomᚋdstotijnᚋhettyᚋpkgᚋapiᚐHTTPRequest(ctx context.Context, sel ast.SelectionSet, v HTTPRequest) graphql.Marshaler {
return ec._HttpRequest(ctx, sel, &v)
}
func (ec *executionContext) marshalNHttpRequest2ᚕgithubᚗcomᚋdstotijnᚋhettyᚋpkgᚋapiᚐHTTPRequestᚄ(ctx context.Context, sel ast.SelectionSet, v []HTTPRequest) graphql.Marshaler {
ret := make(graphql.Array, len(v))
var wg sync.WaitGroup
isLen1 := len(v) == 1
if !isLen1 {
wg.Add(len(v))
}
for i := range v {
i := i
fc := &graphql.FieldContext{
Index: &i,
Result: &v[i],
}
ctx := graphql.WithFieldContext(ctx, fc)
f := func(i int) {
defer func() {
if r := recover(); r != nil {
ec.Error(ctx, ec.Recover(ctx, r))
ret = nil
}
}()
if !isLen1 {
defer wg.Done()
}
ret[i] = ec.marshalNHttpRequest2githubᚗcomᚋdstotijnᚋhettyᚋpkgᚋapiᚐHTTPRequest(ctx, sel, v[i])
}
if isLen1 {
f(i)
} else {
go f(i)
}
}
wg.Wait()
for _, e := range ret {
if e == graphql.Null {
return graphql.Null
}
}
return ret
}
func (ec *executionContext) marshalNHttpRequestLog2githubᚗcomᚋdstotijnᚋhettyᚋpkgᚋapiᚐHTTPRequestLog(ctx context.Context, sel ast.SelectionSet, v HTTPRequestLog) graphql.Marshaler { func (ec *executionContext) marshalNHttpRequestLog2githubᚗcomᚋdstotijnᚋhettyᚋpkgᚋapiᚐHTTPRequestLog(ctx context.Context, sel ast.SelectionSet, v HTTPRequestLog) graphql.Marshaler {
return ec._HttpRequestLog(ctx, sel, &v) return ec._HttpRequestLog(ctx, sel, &v)
} }
@ -5829,6 +6478,25 @@ func (ec *executionContext) marshalNInt2int(ctx context.Context, sel ast.Selecti
return res return res
} }
func (ec *executionContext) unmarshalNModifyRequestInput2githubᚗcomᚋdstotijnᚋhettyᚋpkgᚋapiᚐModifyRequestInput(ctx context.Context, v interface{}) (ModifyRequestInput, error) {
res, err := ec.unmarshalInputModifyRequestInput(ctx, v)
return res, graphql.ErrorOnPath(ctx, err)
}
func (ec *executionContext) marshalNModifyRequestResult2githubᚗcomᚋdstotijnᚋhettyᚋpkgᚋapiᚐModifyRequestResult(ctx context.Context, sel ast.SelectionSet, v ModifyRequestResult) graphql.Marshaler {
return ec._ModifyRequestResult(ctx, sel, &v)
}
func (ec *executionContext) marshalNModifyRequestResult2ᚖgithubᚗcomᚋdstotijnᚋhettyᚋpkgᚋapiᚐModifyRequestResult(ctx context.Context, sel ast.SelectionSet, v *ModifyRequestResult) graphql.Marshaler {
if v == nil {
if !graphql.HasFieldError(ctx, graphql.GetFieldContext(ctx)) {
ec.Errorf(ctx, "must not be null")
}
return graphql.Null
}
return ec._ModifyRequestResult(ctx, sel, v)
}
func (ec *executionContext) marshalNProject2githubᚗcomᚋdstotijnᚋhettyᚋpkgᚋapiᚐProject(ctx context.Context, sel ast.SelectionSet, v Project) graphql.Marshaler { func (ec *executionContext) marshalNProject2githubᚗcomᚋdstotijnᚋhettyᚋpkgᚋapiᚐProject(ctx context.Context, sel ast.SelectionSet, v Project) graphql.Marshaler {
return ec._Project(ctx, sel, &v) return ec._Project(ctx, sel, &v)
} }

View File

@ -38,6 +38,15 @@ type HTTPHeaderInput struct {
Value string `json:"value"` Value string `json:"value"`
} }
type HTTPRequest struct {
ID ulid.ULID `json:"id"`
URL *url.URL `json:"url"`
Method HTTPMethod `json:"method"`
Proto HTTPProtocol `json:"proto"`
Headers []HTTPHeader `json:"headers"`
Body *string `json:"body"`
}
type HTTPRequestLog struct { type HTTPRequestLog struct {
ID ulid.ULID `json:"id"` ID ulid.ULID `json:"id"`
URL string `json:"url"` URL string `json:"url"`
@ -69,6 +78,19 @@ type HTTPResponseLog struct {
Headers []HTTPHeader `json:"headers"` Headers []HTTPHeader `json:"headers"`
} }
type ModifyRequestInput struct {
ID ulid.ULID `json:"id"`
URL *url.URL `json:"url"`
Method HTTPMethod `json:"method"`
Proto HTTPProtocol `json:"proto"`
Headers []HTTPHeaderInput `json:"headers"`
Body *string `json:"body"`
}
type ModifyRequestResult struct {
Success bool `json:"success"`
}
type Project struct { type Project struct {
ID ulid.ULID `json:"id"` ID ulid.ULID `json:"id"`
Name string `json:"name"` Name string `json:"name"`

View File

@ -3,9 +3,11 @@ package api
//go:generate go run github.com/99designs/gqlgen //go:generate go run github.com/99designs/gqlgen
import ( import (
"bytes"
"context" "context"
"errors" "errors"
"fmt" "fmt"
"io/ioutil"
"net/http" "net/http"
"regexp" "regexp"
"strings" "strings"
@ -15,6 +17,8 @@ import (
"github.com/vektah/gqlparser/v2/gqlerror" "github.com/vektah/gqlparser/v2/gqlerror"
"github.com/dstotijn/hetty/pkg/proj" "github.com/dstotijn/hetty/pkg/proj"
"github.com/dstotijn/hetty/pkg/proxy"
"github.com/dstotijn/hetty/pkg/proxy/intercept"
"github.com/dstotijn/hetty/pkg/reqlog" "github.com/dstotijn/hetty/pkg/reqlog"
"github.com/dstotijn/hetty/pkg/scope" "github.com/dstotijn/hetty/pkg/scope"
"github.com/dstotijn/hetty/pkg/search" "github.com/dstotijn/hetty/pkg/search"
@ -36,6 +40,7 @@ var revHTTPProtocolMap = map[HTTPProtocol]string{
type Resolver struct { type Resolver struct {
ProjectService proj.Service ProjectService proj.Service
RequestLogService reqlog.Service RequestLogService reqlog.Service
InterceptService *intercept.Service
SenderService sender.Service SenderService sender.Service
} }
@ -520,6 +525,46 @@ func (r *mutationResolver) DeleteSenderRequests(ctx context.Context) (*DeleteSen
return &DeleteSenderRequestsResult{true}, nil return &DeleteSenderRequestsResult{true}, nil
} }
func (r *queryResolver) InterceptedRequests(ctx context.Context) ([]HTTPRequest, error) {
reqs := r.InterceptService.Requests()
httpReqs := make([]HTTPRequest, len(reqs))
for i, req := range reqs {
req, err := parseHTTPRequest(req)
if err != nil {
return nil, err
}
httpReqs[i] = req
}
return httpReqs, nil
}
func (r *mutationResolver) ModifyRequest(ctx context.Context, input ModifyRequestInput) (*ModifyRequestResult, error) {
body := ""
if input.Body != nil {
body = *input.Body
}
//nolint:noctx
req, err := http.NewRequest(input.Method.String(), input.URL.String(), strings.NewReader(body))
if err != nil {
return nil, fmt.Errorf("failed to construct HTTP request: %w", err)
}
for _, header := range input.Headers {
req.Header.Add(header.Key, header.Value)
}
err = r.InterceptService.ModifyRequest(input.ID, req)
if err != nil {
return nil, fmt.Errorf("could not modify http request: %w", err)
}
return &ModifyRequestResult{Success: true}, nil
}
func parseSenderRequest(req sender.Request) (SenderRequest, error) { func parseSenderRequest(req sender.Request) (SenderRequest, error) {
method := HTTPMethod(req.Method) method := HTTPMethod(req.Method)
if method != "" && !method.IsValid() { if method != "" && !method.IsValid() {
@ -575,6 +620,56 @@ func parseSenderRequest(req sender.Request) (SenderRequest, error) {
return senderReq, nil return senderReq, nil
} }
func parseHTTPRequest(req *http.Request) (HTTPRequest, error) {
method := HTTPMethod(req.Method)
if method != "" && !method.IsValid() {
return HTTPRequest{}, fmt.Errorf("http request has invalid method: %v", method)
}
reqProto := httpProtocolMap[req.Proto]
if !reqProto.IsValid() {
return HTTPRequest{}, fmt.Errorf("http request has invalid protocol: %v", req.Proto)
}
id, ok := proxy.RequestIDFromContext(req.Context())
if !ok {
return HTTPRequest{}, errors.New("http request has missing ID")
}
httpReq := HTTPRequest{
ID: id,
URL: req.URL,
Method: method,
Proto: HTTPProtocol(req.Proto),
}
if req.Header != nil {
httpReq.Headers = make([]HTTPHeader, 0)
for key, values := range req.Header {
for _, value := range values {
httpReq.Headers = append(httpReq.Headers, HTTPHeader{
Key: key,
Value: value,
})
}
}
}
if req.Body != nil {
body, err := ioutil.ReadAll(req.Body)
if err != nil {
return HTTPRequest{}, fmt.Errorf("failed to read request body: %w", err)
}
req.Body = ioutil.NopCloser(bytes.NewBuffer(body))
bodyStr := string(body)
httpReq.Body = &bodyStr
}
return httpReq, nil
}
func stringPtrToRegexp(s *string) (*regexp.Regexp, error) { func stringPtrToRegexp(s *string) (*regexp.Regexp, error) {
if s == nil { if s == nil {
return nil, nil return nil, nil

View File

@ -116,6 +116,28 @@ type SenderRequestFilter {
searchExpression: String searchExpression: String
} }
type HttpRequest {
id: ID!
url: URL!
method: HttpMethod!
proto: HttpProtocol!
headers: [HttpHeader!]!
body: String
}
input ModifyRequestInput {
id: ID!
url: URL!
method: HttpMethod!
proto: HttpProtocol!
headers: [HttpHeaderInput!]
body: String
}
type ModifyRequestResult {
success: Boolean!
}
type Query { type Query {
httpRequestLog(id: ID!): HttpRequestLog httpRequestLog(id: ID!): HttpRequestLog
httpRequestLogs: [HttpRequestLog!]! httpRequestLogs: [HttpRequestLog!]!
@ -125,6 +147,7 @@ type Query {
scope: [ScopeRule!]! scope: [ScopeRule!]!
senderRequest(id: ID!): SenderRequest senderRequest(id: ID!): SenderRequest
senderRequests: [SenderRequest!]! senderRequests: [SenderRequest!]!
interceptedRequests: [HttpRequest!]!
} }
type Mutation { type Mutation {
@ -142,6 +165,7 @@ type Mutation {
createSenderRequestFromHttpRequestLog(id: ID!): SenderRequest! createSenderRequestFromHttpRequestLog(id: ID!): SenderRequest!
sendRequest(id: ID!): SenderRequest! sendRequest(id: ID!): SenderRequest!
deleteSenderRequests: DeleteSenderRequestsResult! deleteSenderRequests: DeleteSenderRequestsResult!
modifyRequest(request: ModifyRequestInput!): ModifyRequestResult!
} }
enum HttpMethod { enum HttpMethod {

View File

@ -11,6 +11,7 @@ import (
"github.com/oklog/ulid" "github.com/oklog/ulid"
"github.com/dstotijn/hetty/pkg/proxy/intercept"
"github.com/dstotijn/hetty/pkg/reqlog" "github.com/dstotijn/hetty/pkg/reqlog"
"github.com/dstotijn/hetty/pkg/scope" "github.com/dstotijn/hetty/pkg/scope"
"github.com/dstotijn/hetty/pkg/search" "github.com/dstotijn/hetty/pkg/search"
@ -37,6 +38,7 @@ type Service interface {
type service struct { type service struct {
repo Repository repo Repository
interceptSvc *intercept.Service
reqLogSvc reqlog.Service reqLogSvc reqlog.Service
senderSvc sender.Service senderSvc sender.Service
scope *scope.Scope scope *scope.Scope
@ -74,6 +76,7 @@ var nameRegexp = regexp.MustCompile(`^[\w\d\s]+$`)
type Config struct { type Config struct {
Repository Repository Repository Repository
InterceptService *intercept.Service
ReqLogService reqlog.Service ReqLogService reqlog.Service
SenderService sender.Service SenderService sender.Service
Scope *scope.Scope Scope *scope.Scope
@ -83,6 +86,7 @@ type Config struct {
func NewService(cfg Config) (Service, error) { func NewService(cfg Config) (Service, error) {
return &service{ return &service{
repo: cfg.Repository, repo: cfg.Repository,
interceptSvc: cfg.InterceptService,
reqLogSvc: cfg.ReqLogService, reqLogSvc: cfg.ReqLogService,
senderSvc: cfg.SenderService, senderSvc: cfg.SenderService,
scope: cfg.Scope, scope: cfg.Scope,
@ -117,6 +121,7 @@ func (svc *service) CloseProject() error {
} }
svc.activeProjectID = ulid.ULID{} svc.activeProjectID = ulid.ULID{}
svc.interceptSvc.ClearRequests()
svc.reqLogSvc.SetActiveProjectID(ulid.ULID{}) svc.reqLogSvc.SetActiveProjectID(ulid.ULID{})
svc.reqLogSvc.SetBypassOutOfScopeRequests(false) svc.reqLogSvc.SetBypassOutOfScopeRequests(false)
svc.reqLogSvc.SetFindReqsFilter(reqlog.FindRequestsFilter{}) svc.reqLogSvc.SetFindReqsFilter(reqlog.FindRequestsFilter{})

View File

@ -0,0 +1,189 @@
package intercept
import (
"context"
"errors"
"net/http"
"sort"
"sync"
"github.com/oklog/ulid"
"github.com/dstotijn/hetty/pkg/log"
"github.com/dstotijn/hetty/pkg/proxy"
)
var (
ErrRequestAborted = errors.New("intercept: request was aborted")
ErrRequestNotFound = errors.New("intercept: request not found")
ErrRequestDone = errors.New("intercept: request is done")
)
// Request represents a server received HTTP request, alongside a channel for sending a modified version of it to the
// routine that's awaiting it. Also contains a channel for receiving a cancellation signal.
type Request struct {
req *http.Request
ch chan<- *http.Request
done <-chan struct{}
}
type Service struct {
mu *sync.RWMutex
requests map[ulid.ULID]Request
logger log.Logger
}
type Config struct {
Logger log.Logger
}
// RequestIDs implements sort.Interface.
type RequestIDs []ulid.ULID
func NewService(cfg Config) *Service {
s := &Service{
mu: &sync.RWMutex{},
requests: make(map[ulid.ULID]Request),
logger: cfg.Logger,
}
if s.logger == nil {
s.logger = log.NewNopLogger()
}
return s
}
// RequestModifier is a proxy.RequestModifyMiddleware for intercepting HTTP
// requests.
func (svc *Service) RequestModifier(next proxy.RequestModifyFunc) proxy.RequestModifyFunc {
return func(req *http.Request) {
// This is a blocking operation, that gets unblocked when either a modified request is returned or an error
// (typically `context.Canceled`).
modifiedReq, err := svc.Intercept(req.Context(), req)
switch {
case errors.Is(err, ErrRequestAborted):
svc.logger.Debugw("Stopping intercept, request was aborted.")
// Prevent further processing by replacing req.Context with a cancelled context value.
// This will cause the http.Roundtripper in the `proxy` package to
// handle this request as an error.
ctx, cancel := context.WithCancel(context.Background())
cancel()
*req = *req.WithContext(ctx)
case errors.Is(err, context.Canceled):
svc.logger.Debugw("Stopping intercept, context was cancelled.")
case err != nil:
svc.logger.Errorw("Failed to intercept request.",
"error", err)
default:
*req = *modifiedReq.WithContext(req.Context())
next(req)
}
}
}
// Intercept adds an HTTP request to an array of pending intercepted requests, alongside channels used for sending a
// cancellation signal and receiving a modified request. It's safe for concurrent use.
func (svc *Service) Intercept(ctx context.Context, req *http.Request) (*http.Request, error) {
reqID, ok := proxy.RequestIDFromContext(ctx)
if !ok {
svc.logger.Errorw("Failed to intercept: request doesn't have an ID.")
return req, nil
}
ch := make(chan *http.Request)
done := make(chan struct{})
svc.mu.Lock()
svc.requests[reqID] = Request{
req: req,
ch: ch,
done: done,
}
svc.mu.Unlock()
// Whatever happens next (modified request returned, or a context cancelled error), any blocked channel senders
// should be unblocked, and the request should be removed from the requests queue.
defer func() {
close(done)
svc.mu.Lock()
defer svc.mu.Unlock()
delete(svc.requests, reqID)
}()
select {
case modReq := <-ch:
if modReq == nil {
return nil, ErrRequestAborted
}
return modReq, nil
case <-ctx.Done():
return nil, ctx.Err()
}
}
// ModifyRequest sends a modified HTTP request to the related channel, or returns ErrRequestDone when the request was
// cancelled. It's safe for concurrent use.
func (svc *Service) ModifyRequest(reqID ulid.ULID, modReq *http.Request) error {
svc.mu.RLock()
req, ok := svc.requests[reqID]
svc.mu.RUnlock()
if !ok {
return ErrRequestNotFound
}
select {
case <-req.done:
return ErrRequestDone
case req.ch <- modReq:
return nil
}
}
func (svc *Service) ClearRequests() {
svc.mu.Lock()
defer svc.mu.Unlock()
for _, req := range svc.requests {
select {
case <-req.done:
case req.ch <- nil:
}
}
}
// Requests returns a list of pending intercepted requests. It's safe for concurrent use.
func (svc *Service) Requests() []*http.Request {
svc.mu.RLock()
defer svc.mu.RUnlock()
ids := make([]ulid.ULID, 0, len(svc.requests))
for id := range svc.requests {
ids = append(ids, id)
}
sort.Sort(RequestIDs(ids))
reqs := make([]*http.Request, len(ids))
for i, id := range ids {
reqs[i] = svc.requests[id].req
}
return reqs
}
func (ids RequestIDs) Len() int {
return len(ids)
}
func (ids RequestIDs) Less(i, j int) bool {
return ids[i].Compare(ids[j]) == -1
}
func (ids RequestIDs) Swap(i, j int) {
ids[i], ids[j] = ids[j], ids[i]
}

View File

@ -0,0 +1,122 @@
package intercept_test
import (
"context"
"errors"
"math/rand"
"net/http"
"net/http/httptest"
"sync"
"testing"
"time"
"github.com/oklog/ulid"
"go.uber.org/zap"
"github.com/dstotijn/hetty/pkg/proxy"
"github.com/dstotijn/hetty/pkg/proxy/intercept"
)
//nolint:gosec
var ulidEntropy = rand.New(rand.NewSource(time.Now().UnixNano()))
func TestRequestModifier(t *testing.T) {
t.Parallel()
t.Run("modify request that's not found", func(t *testing.T) {
t.Parallel()
logger, _ := zap.NewDevelopment()
svc := intercept.NewService(intercept.Config{
Logger: logger.Sugar(),
})
reqID := ulid.MustNew(ulid.Timestamp(time.Now()), ulidEntropy)
err := svc.ModifyRequest(reqID, nil)
if !errors.Is(err, intercept.ErrRequestNotFound) {
t.Fatalf("expected `intercept.ErrRequestNotFound`, got: %v", err)
}
})
t.Run("modify request that's done", func(t *testing.T) {
t.Parallel()
logger, _ := zap.NewDevelopment()
svc := intercept.NewService(intercept.Config{
Logger: logger.Sugar(),
})
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
req := httptest.NewRequest("GET", "https://example.com/foo", nil)
reqID := ulid.MustNew(ulid.Timestamp(time.Now()), ulidEntropy)
*req = *req.WithContext(ctx)
*req = *req.WithContext(proxy.WithRequestID(req.Context(), reqID))
next := func(req *http.Request) {}
go svc.RequestModifier(next)(req)
// Wait shortly, to allow the req modifier goroutine to add `req` to the
// array of intercepted reqs.
time.Sleep(10 * time.Millisecond)
cancel()
err := svc.ModifyRequest(reqID, nil)
if !errors.Is(err, intercept.ErrRequestDone) {
t.Fatalf("expected `interept.ErrRequestDone`, got: %v", err)
}
})
t.Run("modify intercepted request", func(t *testing.T) {
t.Parallel()
req := httptest.NewRequest("GET", "https://example.com/foo", nil)
req.Header.Set("X-Foo", "foo")
reqID := ulid.MustNew(ulid.Timestamp(time.Now()), ulidEntropy)
*req = *req.WithContext(proxy.WithRequestID(req.Context(), reqID))
modReq := req.Clone(context.Background())
modReq.Header.Set("X-Foo", "bar")
logger, _ := zap.NewDevelopment()
svc := intercept.NewService(intercept.Config{
Logger: logger.Sugar(),
})
var got *http.Request
next := func(req *http.Request) {
got = req.Clone(context.Background())
}
var wg sync.WaitGroup
wg.Add(1)
go func() {
svc.RequestModifier(next)(req)
wg.Done()
}()
// Wait shortly, to allow the req modifier goroutine to add `req` to the
// array of intercepted reqs.
time.Sleep(10 * time.Millisecond)
err := svc.ModifyRequest(reqID, modReq)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
wg.Wait()
if got == nil {
t.Fatal("expected `got` not to be nil")
}
if exp := "bar"; exp != got.Header.Get("X-Foo") {
t.Fatalf("incorrect modified request header value (expected: %v, got: %v)", exp, got.Header.Get("X-Foo"))
}
})
}

View File

@ -7,16 +7,22 @@ import (
"crypto/x509" "crypto/x509"
"errors" "errors"
"fmt" "fmt"
"math/rand"
"net" "net"
"net/http" "net/http"
"net/http/httputil" "net/http/httputil"
"time"
"github.com/dstotijn/hetty/pkg/log" "github.com/dstotijn/hetty/pkg/log"
"github.com/oklog/ulid"
) )
//nolint:gosec
var ulidEntropy = rand.New(rand.NewSource(time.Now().UnixNano()))
type contextKey int type contextKey int
const ReqLogIDKey contextKey = 0 const reqIDKey contextKey = 0
// Proxy implements http.Handler and offers MITM behaviour for modifying // Proxy implements http.Handler and offers MITM behaviour for modifying
// HTTP requests and responses. // HTTP requests and responses.
@ -69,6 +75,10 @@ func (p *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return return
} }
reqID := ulid.MustNew(ulid.Timestamp(time.Now()), ulidEntropy)
ctx := context.WithValue(r.Context(), reqIDKey, reqID)
*r = *r.WithContext(ctx)
p.handler.ServeHTTP(w, r) p.handler.ServeHTTP(w, r)
} }
@ -110,6 +120,15 @@ func (p *Proxy) modifyResponse(res *http.Response) error {
return fn(res) return fn(res)
} }
func WithRequestID(ctx context.Context, id ulid.ULID) context.Context {
return context.WithValue(ctx, reqIDKey, id)
}
func RequestIDFromContext(ctx context.Context) (ulid.ULID, bool) {
id, ok := ctx.Value(reqIDKey).(ulid.ULID)
return id, ok
}
// handleConnect hijacks the incoming HTTP request and sets up an HTTP tunnel. // handleConnect hijacks the incoming HTTP request and sets up an HTTP tunnel.
// During the TLS handshake with the client, we use the proxy's CA config to // During the TLS handshake with the client, we use the proxy's CA config to
// create a certificate on-the-fly. // create a certificate on-the-fly.
@ -170,12 +189,10 @@ func (p *Proxy) clientTLSConn(conn net.Conn) (*tls.Conn, error) {
} }
func (p *Proxy) errorHandler(w http.ResponseWriter, r *http.Request, err error) { func (p *Proxy) errorHandler(w http.ResponseWriter, r *http.Request, err error) {
if errors.Is(err, context.Canceled) { if !errors.Is(err, context.Canceled) {
return
}
p.logger.Errorw("Failed to proxy request.", p.logger.Errorw("Failed to proxy request.",
"error", err) "error", err)
}
w.WriteHeader(http.StatusBadGateway) w.WriteHeader(http.StatusBadGateway)
} }

View File

@ -8,10 +8,8 @@ import (
"fmt" "fmt"
"io" "io"
"io/ioutil" "io/ioutil"
"math/rand"
"net/http" "net/http"
"net/url" "net/url"
"time"
"github.com/oklog/ulid" "github.com/oklog/ulid"
@ -23,16 +21,16 @@ import (
type contextKey int type contextKey int
const LogBypassedKey contextKey = 0 const (
LogBypassedKey contextKey = iota
ReqLogIDKey
)
var ( var (
ErrRequestNotFound = errors.New("reqlog: request not found") ErrRequestNotFound = errors.New("reqlog: request not found")
ErrProjectIDMustBeSet = errors.New("reqlog: project ID must be set") ErrProjectIDMustBeSet = errors.New("reqlog: project ID must be set")
) )
//nolint:gosec
var ulidEntropy = rand.New(rand.NewSource(time.Now().UnixNano()))
type RequestLog struct { type RequestLog struct {
ID ulid.ULID ID ulid.ULID
ProjectID ulid.ULID ProjectID ulid.ULID
@ -170,8 +168,14 @@ func (svc *service) RequestModifier(next proxy.RequestModifyFunc) proxy.RequestM
return return
} }
reqID, ok := proxy.RequestIDFromContext(req.Context())
if !ok {
svc.logger.Errorw("Bypassed logging: request doesn't have an ID.")
return
}
reqLog := RequestLog{ reqLog := RequestLog{
ID: ulid.MustNew(ulid.Timestamp(time.Now()), ulidEntropy), ID: reqID,
ProjectID: svc.activeProjectID, ProjectID: svc.activeProjectID,
Method: clone.Method, Method: clone.Method,
URL: clone.URL, URL: clone.URL,
@ -191,7 +195,7 @@ func (svc *service) RequestModifier(next proxy.RequestModifyFunc) proxy.RequestM
"reqLogID", reqLog.ID.String(), "reqLogID", reqLog.ID.String(),
"url", reqLog.URL.String()) "url", reqLog.URL.String())
ctx := context.WithValue(req.Context(), proxy.ReqLogIDKey, reqLog.ID) ctx := context.WithValue(req.Context(), ReqLogIDKey, reqLog.ID)
*req = *req.WithContext(ctx) *req = *req.WithContext(ctx)
} }
} }
@ -206,7 +210,7 @@ func (svc *service) ResponseModifier(next proxy.ResponseModifyFunc) proxy.Respon
return nil return nil
} }
reqLogID, ok := res.Request.Context().Value(proxy.ReqLogIDKey).(ulid.ULID) reqLogID, ok := res.Request.Context().Value(ReqLogIDKey).(ulid.ULID)
if !ok { if !ok {
return errors.New("reqlog: request is missing ID") return errors.New("reqlog: request is missing ID")
} }

View File

@ -41,6 +41,8 @@ func TestRequestModifier(t *testing.T) {
} }
reqModFn := svc.RequestModifier(next) reqModFn := svc.RequestModifier(next)
req := httptest.NewRequest("GET", "https://example.com/", strings.NewReader("bar")) req := httptest.NewRequest("GET", "https://example.com/", strings.NewReader("bar"))
reqID := ulid.MustNew(ulid.Timestamp(time.Now()), ulidEntropy)
req = req.WithContext(proxy.WithRequestID(req.Context(), reqID))
reqModFn(req) reqModFn(req)
@ -88,7 +90,7 @@ func TestResponseModifier(t *testing.T) {
req := httptest.NewRequest("GET", "https://example.com/", strings.NewReader("bar")) req := httptest.NewRequest("GET", "https://example.com/", strings.NewReader("bar"))
reqLogID := ulid.MustNew(ulid.Timestamp(time.Now()), ulidEntropy) reqLogID := ulid.MustNew(ulid.Timestamp(time.Now()), ulidEntropy)
req = req.WithContext(context.WithValue(req.Context(), proxy.ReqLogIDKey, reqLogID)) req = req.WithContext(context.WithValue(req.Context(), reqlog.ReqLogIDKey, reqLogID))
res := &http.Response{ res := &http.Response{
Request: req, Request: req,