diff --git a/cmd/hetty/hetty.go b/cmd/hetty/hetty.go index 90d7095..d7e3672 100644 --- a/cmd/hetty/hetty.go +++ b/cmd/hetty/hetty.go @@ -28,6 +28,7 @@ import ( "github.com/dstotijn/hetty/pkg/db/badger" "github.com/dstotijn/hetty/pkg/proj" "github.com/dstotijn/hetty/pkg/proxy" + "github.com/dstotijn/hetty/pkg/proxy/intercept" "github.com/dstotijn/hetty/pkg/reqlog" "github.com/dstotijn/hetty/pkg/scope" "github.com/dstotijn/hetty/pkg/sender" @@ -175,16 +176,21 @@ func (cmd *HettyCommand) Exec(ctx context.Context, _ []string) error { Logger: cmd.config.logger.Named("reqlog").Sugar(), }) + interceptService := intercept.NewService(intercept.Config{ + Logger: cmd.config.logger.Named("intercept").Sugar(), + }) + senderService := sender.NewService(sender.Config{ Repository: badger, ReqLogService: reqLogService, }) projService, err := proj.NewService(proj.Config{ - Repository: badger, - ReqLogService: reqLogService, - SenderService: senderService, - Scope: scope, + Repository: badger, + InterceptService: interceptService, + ReqLogService: reqLogService, + SenderService: senderService, + Scope: scope, }) if err != nil { cmd.config.logger.Fatal("Failed to create new projects service.", zap.Error(err)) @@ -201,6 +207,7 @@ func (cmd *HettyCommand) Exec(ctx context.Context, _ []string) error { proxy.UseRequestModifier(reqLogService.RequestModifier) proxy.UseResponseModifier(reqLogService.ResponseModifier) + proxy.UseRequestModifier(interceptService.RequestModifier) fsSub, err := fs.Sub(adminContent, "admin") if err != nil { @@ -231,6 +238,7 @@ func (cmd *HettyCommand) Exec(ctx context.Context, _ []string) error { adminRouter.Path(gqlEndpoint).Handler(api.HTTPHandler(&api.Resolver{ ProjectService: projService, RequestLogService: reqLogService, + InterceptService: interceptService, SenderService: senderService, }, gqlEndpoint)) diff --git a/pkg/api/generated.go b/pkg/api/generated.go index 4aea32f..2d5040c 100644 --- a/pkg/api/generated.go +++ b/pkg/api/generated.go @@ -66,6 +66,15 @@ type ComplexityRoot struct { Value func(childComplexity int) int } + HTTPRequest struct { + Body func(childComplexity int) int + Headers func(childComplexity int) int + ID func(childComplexity int) int + Method func(childComplexity int) int + Proto func(childComplexity int) int + URL func(childComplexity int) int + } + HTTPRequestLog struct { Body func(childComplexity int) int Headers func(childComplexity int) int @@ -91,6 +100,10 @@ type ComplexityRoot struct { StatusReason func(childComplexity int) int } + ModifyRequestResult struct { + Success func(childComplexity int) int + } + Mutation struct { ClearHTTPRequestLog func(childComplexity int) int CloseProject func(childComplexity int) int @@ -99,6 +112,7 @@ type ComplexityRoot struct { CreateSenderRequestFromHTTPRequestLog func(childComplexity int, id ulid.ULID) int DeleteProject func(childComplexity int, id ulid.ULID) int DeleteSenderRequests func(childComplexity int) int + ModifyRequest func(childComplexity int, request ModifyRequestInput) int OpenProject func(childComplexity int, id ulid.ULID) int SendRequest func(childComplexity int, id ulid.ULID) int SetHTTPRequestLogFilter func(childComplexity int, filter *HTTPRequestLogFilterInput) int @@ -117,6 +131,7 @@ type ComplexityRoot struct { HTTPRequestLog func(childComplexity int, id ulid.ULID) int HTTPRequestLogFilter func(childComplexity int) int HTTPRequestLogs func(childComplexity int) int + InterceptedRequests func(childComplexity int) int Projects func(childComplexity int) int Scope func(childComplexity int) int SenderRequest func(childComplexity int, id ulid.ULID) int @@ -165,6 +180,7 @@ type MutationResolver interface { CreateSenderRequestFromHTTPRequestLog(ctx context.Context, id ulid.ULID) (*SenderRequest, error) SendRequest(ctx context.Context, id ulid.ULID) (*SenderRequest, error) DeleteSenderRequests(ctx context.Context) (*DeleteSenderRequestsResult, error) + ModifyRequest(ctx context.Context, request ModifyRequestInput) (*ModifyRequestResult, error) } type QueryResolver interface { HTTPRequestLog(ctx context.Context, id ulid.ULID) (*HTTPRequestLog, error) @@ -175,6 +191,7 @@ type QueryResolver interface { Scope(ctx context.Context) ([]ScopeRule, error) SenderRequest(ctx context.Context, id ulid.ULID) (*SenderRequest, error) SenderRequests(ctx context.Context) ([]SenderRequest, error) + InterceptedRequests(ctx context.Context) ([]HTTPRequest, error) } type executableSchema struct { @@ -234,6 +251,48 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in return e.complexity.HTTPHeader.Value(childComplexity), true + case "HttpRequest.body": + if e.complexity.HTTPRequest.Body == nil { + break + } + + return e.complexity.HTTPRequest.Body(childComplexity), true + + case "HttpRequest.headers": + if e.complexity.HTTPRequest.Headers == nil { + break + } + + return e.complexity.HTTPRequest.Headers(childComplexity), true + + case "HttpRequest.id": + if e.complexity.HTTPRequest.ID == nil { + break + } + + return e.complexity.HTTPRequest.ID(childComplexity), true + + case "HttpRequest.method": + if e.complexity.HTTPRequest.Method == nil { + break + } + + return e.complexity.HTTPRequest.Method(childComplexity), true + + case "HttpRequest.proto": + if e.complexity.HTTPRequest.Proto == nil { + break + } + + return e.complexity.HTTPRequest.Proto(childComplexity), true + + case "HttpRequest.url": + if e.complexity.HTTPRequest.URL == nil { + break + } + + return e.complexity.HTTPRequest.URL(childComplexity), true + case "HttpRequestLog.body": if e.complexity.HTTPRequestLog.Body == nil { break @@ -346,6 +405,13 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in return e.complexity.HTTPResponseLog.StatusReason(childComplexity), true + case "ModifyRequestResult.success": + if e.complexity.ModifyRequestResult.Success == nil { + break + } + + return e.complexity.ModifyRequestResult.Success(childComplexity), true + case "Mutation.clearHTTPRequestLog": if e.complexity.Mutation.ClearHTTPRequestLog == nil { break @@ -415,6 +481,18 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in return e.complexity.Mutation.DeleteSenderRequests(childComplexity), true + case "Mutation.modifyRequest": + if e.complexity.Mutation.ModifyRequest == nil { + break + } + + args, err := ec.field_Mutation_modifyRequest_args(context.TODO(), rawArgs) + if err != nil { + return 0, false + } + + return e.complexity.Mutation.ModifyRequest(childComplexity, args["request"].(ModifyRequestInput)), true + case "Mutation.openProject": if e.complexity.Mutation.OpenProject == nil { break @@ -529,6 +607,13 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in return e.complexity.Query.HTTPRequestLogs(childComplexity), true + case "Query.interceptedRequests": + if e.complexity.Query.InterceptedRequests == nil { + break + } + + return e.complexity.Query.InterceptedRequests(childComplexity), true + case "Query.projects": if e.complexity.Query.Projects == nil { break @@ -856,6 +941,28 @@ type SenderRequestFilter { searchExpression: String } +type HttpRequest { + id: ID! + url: URL! + method: HttpMethod! + proto: HttpProtocol! + headers: [HttpHeader!]! + body: String +} + +input ModifyRequestInput { + id: ID! + url: URL! + method: HttpMethod! + proto: HttpProtocol! + headers: [HttpHeaderInput!] + body: String +} + +type ModifyRequestResult { + success: Boolean! +} + type Query { httpRequestLog(id: ID!): HttpRequestLog httpRequestLogs: [HttpRequestLog!]! @@ -865,6 +972,7 @@ type Query { scope: [ScopeRule!]! senderRequest(id: ID!): SenderRequest senderRequests: [SenderRequest!]! + interceptedRequests: [HttpRequest!]! } type Mutation { @@ -882,6 +990,7 @@ type Mutation { createSenderRequestFromHttpRequestLog(id: ID!): SenderRequest! sendRequest(id: ID!): SenderRequest! deleteSenderRequests: DeleteSenderRequestsResult! + modifyRequest(request: ModifyRequestInput!): ModifyRequestResult! } enum HttpMethod { @@ -973,6 +1082,21 @@ func (ec *executionContext) field_Mutation_deleteProject_args(ctx context.Contex return args, nil } +func (ec *executionContext) field_Mutation_modifyRequest_args(ctx context.Context, rawArgs map[string]interface{}) (map[string]interface{}, error) { + var err error + args := map[string]interface{}{} + var arg0 ModifyRequestInput + if tmp, ok := rawArgs["request"]; ok { + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("request")) + arg0, err = ec.unmarshalNModifyRequestInput2githubᚗcomᚋdstotijnᚋhettyᚋpkgᚋapiᚐModifyRequestInput(ctx, tmp) + if err != nil { + return nil, err + } + } + args["request"] = arg0 + return args, nil +} + func (ec *executionContext) field_Mutation_openProject_args(ctx context.Context, rawArgs map[string]interface{}) (map[string]interface{}, error) { var err error args := map[string]interface{}{} @@ -1341,6 +1465,213 @@ func (ec *executionContext) _HttpHeader_value(ctx context.Context, field graphql return ec.marshalNString2string(ctx, field.Selections, res) } +func (ec *executionContext) _HttpRequest_id(ctx context.Context, field graphql.CollectedField, obj *HTTPRequest) (ret graphql.Marshaler) { + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + ret = graphql.Null + } + }() + fc := &graphql.FieldContext{ + Object: "HttpRequest", + Field: field, + Args: nil, + IsMethod: false, + IsResolver: false, + } + + ctx = graphql.WithFieldContext(ctx, fc) + resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return obj.ID, nil + }) + if err != nil { + ec.Error(ctx, err) + return graphql.Null + } + if resTmp == nil { + if !graphql.HasFieldError(ctx, fc) { + ec.Errorf(ctx, "must not be null") + } + return graphql.Null + } + res := resTmp.(ulid.ULID) + fc.Result = res + return ec.marshalNID2githubᚗcomᚋoklogᚋulidᚐULID(ctx, field.Selections, res) +} + +func (ec *executionContext) _HttpRequest_url(ctx context.Context, field graphql.CollectedField, obj *HTTPRequest) (ret graphql.Marshaler) { + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + ret = graphql.Null + } + }() + fc := &graphql.FieldContext{ + Object: "HttpRequest", + Field: field, + Args: nil, + IsMethod: false, + IsResolver: false, + } + + ctx = graphql.WithFieldContext(ctx, fc) + resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return obj.URL, nil + }) + if err != nil { + ec.Error(ctx, err) + return graphql.Null + } + if resTmp == nil { + if !graphql.HasFieldError(ctx, fc) { + ec.Errorf(ctx, "must not be null") + } + return graphql.Null + } + res := resTmp.(*url.URL) + fc.Result = res + return ec.marshalNURL2ᚖnetᚋurlᚐURL(ctx, field.Selections, res) +} + +func (ec *executionContext) _HttpRequest_method(ctx context.Context, field graphql.CollectedField, obj *HTTPRequest) (ret graphql.Marshaler) { + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + ret = graphql.Null + } + }() + fc := &graphql.FieldContext{ + Object: "HttpRequest", + Field: field, + Args: nil, + IsMethod: false, + IsResolver: false, + } + + ctx = graphql.WithFieldContext(ctx, fc) + resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return obj.Method, nil + }) + if err != nil { + ec.Error(ctx, err) + return graphql.Null + } + if resTmp == nil { + if !graphql.HasFieldError(ctx, fc) { + ec.Errorf(ctx, "must not be null") + } + return graphql.Null + } + res := resTmp.(HTTPMethod) + fc.Result = res + return ec.marshalNHttpMethod2githubᚗcomᚋdstotijnᚋhettyᚋpkgᚋapiᚐHTTPMethod(ctx, field.Selections, res) +} + +func (ec *executionContext) _HttpRequest_proto(ctx context.Context, field graphql.CollectedField, obj *HTTPRequest) (ret graphql.Marshaler) { + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + ret = graphql.Null + } + }() + fc := &graphql.FieldContext{ + Object: "HttpRequest", + Field: field, + Args: nil, + IsMethod: false, + IsResolver: false, + } + + ctx = graphql.WithFieldContext(ctx, fc) + resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return obj.Proto, nil + }) + if err != nil { + ec.Error(ctx, err) + return graphql.Null + } + if resTmp == nil { + if !graphql.HasFieldError(ctx, fc) { + ec.Errorf(ctx, "must not be null") + } + return graphql.Null + } + res := resTmp.(HTTPProtocol) + fc.Result = res + return ec.marshalNHttpProtocol2githubᚗcomᚋdstotijnᚋhettyᚋpkgᚋapiᚐHTTPProtocol(ctx, field.Selections, res) +} + +func (ec *executionContext) _HttpRequest_headers(ctx context.Context, field graphql.CollectedField, obj *HTTPRequest) (ret graphql.Marshaler) { + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + ret = graphql.Null + } + }() + fc := &graphql.FieldContext{ + Object: "HttpRequest", + Field: field, + Args: nil, + IsMethod: false, + IsResolver: false, + } + + ctx = graphql.WithFieldContext(ctx, fc) + resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return obj.Headers, nil + }) + if err != nil { + ec.Error(ctx, err) + return graphql.Null + } + if resTmp == nil { + if !graphql.HasFieldError(ctx, fc) { + ec.Errorf(ctx, "must not be null") + } + return graphql.Null + } + res := resTmp.([]HTTPHeader) + fc.Result = res + return ec.marshalNHttpHeader2ᚕgithubᚗcomᚋdstotijnᚋhettyᚋpkgᚋapiᚐHTTPHeaderᚄ(ctx, field.Selections, res) +} + +func (ec *executionContext) _HttpRequest_body(ctx context.Context, field graphql.CollectedField, obj *HTTPRequest) (ret graphql.Marshaler) { + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + ret = graphql.Null + } + }() + fc := &graphql.FieldContext{ + Object: "HttpRequest", + Field: field, + Args: nil, + IsMethod: false, + IsResolver: false, + } + + ctx = graphql.WithFieldContext(ctx, fc) + resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return obj.Body, nil + }) + if err != nil { + ec.Error(ctx, err) + return graphql.Null + } + if resTmp == nil { + return graphql.Null + } + res := resTmp.(*string) + fc.Result = res + return ec.marshalOString2ᚖstring(ctx, field.Selections, res) +} + func (ec *executionContext) _HttpRequestLog_id(ctx context.Context, field graphql.CollectedField, obj *HTTPRequestLog) (ret graphql.Marshaler) { defer func() { if r := recover(); r != nil { @@ -1889,6 +2220,41 @@ func (ec *executionContext) _HttpResponseLog_headers(ctx context.Context, field return ec.marshalNHttpHeader2ᚕgithubᚗcomᚋdstotijnᚋhettyᚋpkgᚋapiᚐHTTPHeaderᚄ(ctx, field.Selections, res) } +func (ec *executionContext) _ModifyRequestResult_success(ctx context.Context, field graphql.CollectedField, obj *ModifyRequestResult) (ret graphql.Marshaler) { + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + ret = graphql.Null + } + }() + fc := &graphql.FieldContext{ + Object: "ModifyRequestResult", + Field: field, + Args: nil, + IsMethod: false, + IsResolver: false, + } + + ctx = graphql.WithFieldContext(ctx, fc) + resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return obj.Success, nil + }) + if err != nil { + ec.Error(ctx, err) + return graphql.Null + } + if resTmp == nil { + if !graphql.HasFieldError(ctx, fc) { + ec.Errorf(ctx, "must not be null") + } + return graphql.Null + } + res := resTmp.(bool) + fc.Result = res + return ec.marshalNBoolean2bool(ctx, field.Selections, res) +} + func (ec *executionContext) _Mutation_createProject(ctx context.Context, field graphql.CollectedField) (ret graphql.Marshaler) { defer func() { if r := recover(); r != nil { @@ -2360,6 +2726,48 @@ func (ec *executionContext) _Mutation_deleteSenderRequests(ctx context.Context, return ec.marshalNDeleteSenderRequestsResult2ᚖgithubᚗcomᚋdstotijnᚋhettyᚋpkgᚋapiᚐDeleteSenderRequestsResult(ctx, field.Selections, res) } +func (ec *executionContext) _Mutation_modifyRequest(ctx context.Context, field graphql.CollectedField) (ret graphql.Marshaler) { + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + ret = graphql.Null + } + }() + fc := &graphql.FieldContext{ + Object: "Mutation", + Field: field, + Args: nil, + IsMethod: true, + IsResolver: true, + } + + ctx = graphql.WithFieldContext(ctx, fc) + rawArgs := field.ArgumentMap(ec.Variables) + args, err := ec.field_Mutation_modifyRequest_args(ctx, rawArgs) + if err != nil { + ec.Error(ctx, err) + return graphql.Null + } + fc.Args = args + resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return ec.resolvers.Mutation().ModifyRequest(rctx, args["request"].(ModifyRequestInput)) + }) + if err != nil { + ec.Error(ctx, err) + return graphql.Null + } + if resTmp == nil { + if !graphql.HasFieldError(ctx, fc) { + ec.Errorf(ctx, "must not be null") + } + return graphql.Null + } + res := resTmp.(*ModifyRequestResult) + fc.Result = res + return ec.marshalNModifyRequestResult2ᚖgithubᚗcomᚋdstotijnᚋhettyᚋpkgᚋapiᚐModifyRequestResult(ctx, field.Selections, res) +} + func (ec *executionContext) _Project_id(ctx context.Context, field graphql.CollectedField, obj *Project) (ret graphql.Marshaler) { defer func() { if r := recover(); r != nil { @@ -2747,6 +3155,41 @@ func (ec *executionContext) _Query_senderRequests(ctx context.Context, field gra return ec.marshalNSenderRequest2ᚕgithubᚗcomᚋdstotijnᚋhettyᚋpkgᚋapiᚐSenderRequestᚄ(ctx, field.Selections, res) } +func (ec *executionContext) _Query_interceptedRequests(ctx context.Context, field graphql.CollectedField) (ret graphql.Marshaler) { + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + ret = graphql.Null + } + }() + fc := &graphql.FieldContext{ + Object: "Query", + Field: field, + Args: nil, + IsMethod: true, + IsResolver: true, + } + + ctx = graphql.WithFieldContext(ctx, fc) + resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return ec.resolvers.Query().InterceptedRequests(rctx) + }) + if err != nil { + ec.Error(ctx, err) + return graphql.Null + } + if resTmp == nil { + if !graphql.HasFieldError(ctx, fc) { + ec.Errorf(ctx, "must not be null") + } + return graphql.Null + } + res := resTmp.([]HTTPRequest) + fc.Result = res + return ec.marshalNHttpRequest2ᚕgithubᚗcomᚋdstotijnᚋhettyᚋpkgᚋapiᚐHTTPRequestᚄ(ctx, field.Selections, res) +} + func (ec *executionContext) _Query___type(ctx context.Context, field graphql.CollectedField) (ret graphql.Marshaler) { defer func() { if r := recover(); r != nil { @@ -4532,6 +4975,69 @@ func (ec *executionContext) unmarshalInputHttpRequestLogFilterInput(ctx context. return it, nil } +func (ec *executionContext) unmarshalInputModifyRequestInput(ctx context.Context, obj interface{}) (ModifyRequestInput, error) { + var it ModifyRequestInput + asMap := map[string]interface{}{} + for k, v := range obj.(map[string]interface{}) { + asMap[k] = v + } + + for k, v := range asMap { + switch k { + case "id": + var err error + + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("id")) + it.ID, err = ec.unmarshalNID2githubᚗcomᚋoklogᚋulidᚐULID(ctx, v) + if err != nil { + return it, err + } + case "url": + var err error + + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("url")) + it.URL, err = ec.unmarshalNURL2ᚖnetᚋurlᚐURL(ctx, v) + if err != nil { + return it, err + } + case "method": + var err error + + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("method")) + it.Method, err = ec.unmarshalNHttpMethod2githubᚗcomᚋdstotijnᚋhettyᚋpkgᚋapiᚐHTTPMethod(ctx, v) + if err != nil { + return it, err + } + case "proto": + var err error + + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("proto")) + it.Proto, err = ec.unmarshalNHttpProtocol2githubᚗcomᚋdstotijnᚋhettyᚋpkgᚋapiᚐHTTPProtocol(ctx, v) + if err != nil { + return it, err + } + case "headers": + var err error + + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("headers")) + it.Headers, err = ec.unmarshalOHttpHeaderInput2ᚕgithubᚗcomᚋdstotijnᚋhettyᚋpkgᚋapiᚐHTTPHeaderInputᚄ(ctx, v) + if err != nil { + return it, err + } + case "body": + var err error + + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("body")) + it.Body, err = ec.unmarshalOString2ᚖstring(ctx, v) + if err != nil { + return it, err + } + } + } + + return it, nil +} + func (ec *executionContext) unmarshalInputScopeHeaderInput(ctx context.Context, obj interface{}) (ScopeHeaderInput, error) { var it ScopeHeaderInput asMap := map[string]interface{}{} @@ -4844,6 +5350,55 @@ func (ec *executionContext) _HttpHeader(ctx context.Context, sel ast.SelectionSe return out } +var httpRequestImplementors = []string{"HttpRequest"} + +func (ec *executionContext) _HttpRequest(ctx context.Context, sel ast.SelectionSet, obj *HTTPRequest) graphql.Marshaler { + fields := graphql.CollectFields(ec.OperationContext, sel, httpRequestImplementors) + + out := graphql.NewFieldSet(fields) + var invalids uint32 + for i, field := range fields { + switch field.Name { + case "__typename": + out.Values[i] = graphql.MarshalString("HttpRequest") + case "id": + out.Values[i] = ec._HttpRequest_id(ctx, field, obj) + if out.Values[i] == graphql.Null { + invalids++ + } + case "url": + out.Values[i] = ec._HttpRequest_url(ctx, field, obj) + if out.Values[i] == graphql.Null { + invalids++ + } + case "method": + out.Values[i] = ec._HttpRequest_method(ctx, field, obj) + if out.Values[i] == graphql.Null { + invalids++ + } + case "proto": + out.Values[i] = ec._HttpRequest_proto(ctx, field, obj) + if out.Values[i] == graphql.Null { + invalids++ + } + case "headers": + out.Values[i] = ec._HttpRequest_headers(ctx, field, obj) + if out.Values[i] == graphql.Null { + invalids++ + } + case "body": + out.Values[i] = ec._HttpRequest_body(ctx, field, obj) + default: + panic("unknown field " + strconv.Quote(field.Name)) + } + } + out.Dispatch() + if invalids > 0 { + return graphql.Null + } + return out +} + var httpRequestLogImplementors = []string{"HttpRequestLog"} func (ec *executionContext) _HttpRequestLog(ctx context.Context, sel ast.SelectionSet, obj *HTTPRequestLog) graphql.Marshaler { @@ -4978,6 +5533,33 @@ func (ec *executionContext) _HttpResponseLog(ctx context.Context, sel ast.Select return out } +var modifyRequestResultImplementors = []string{"ModifyRequestResult"} + +func (ec *executionContext) _ModifyRequestResult(ctx context.Context, sel ast.SelectionSet, obj *ModifyRequestResult) graphql.Marshaler { + fields := graphql.CollectFields(ec.OperationContext, sel, modifyRequestResultImplementors) + + out := graphql.NewFieldSet(fields) + var invalids uint32 + for i, field := range fields { + switch field.Name { + case "__typename": + out.Values[i] = graphql.MarshalString("ModifyRequestResult") + case "success": + out.Values[i] = ec._ModifyRequestResult_success(ctx, field, obj) + if out.Values[i] == graphql.Null { + invalids++ + } + default: + panic("unknown field " + strconv.Quote(field.Name)) + } + } + out.Dispatch() + if invalids > 0 { + return graphql.Null + } + return out +} + var mutationImplementors = []string{"Mutation"} func (ec *executionContext) _Mutation(ctx context.Context, sel ast.SelectionSet) graphql.Marshaler { @@ -5041,6 +5623,11 @@ func (ec *executionContext) _Mutation(ctx context.Context, sel ast.SelectionSet) if out.Values[i] == graphql.Null { invalids++ } + case "modifyRequest": + out.Values[i] = ec._Mutation_modifyRequest(ctx, field) + if out.Values[i] == graphql.Null { + invalids++ + } default: panic("unknown field " + strconv.Quote(field.Name)) } @@ -5204,6 +5791,20 @@ func (ec *executionContext) _Query(ctx context.Context, sel ast.SelectionSet) gr } return res }) + case "interceptedRequests": + field := field + out.Concurrently(i, func() (res graphql.Marshaler) { + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + } + }() + res = ec._Query_interceptedRequests(ctx, field) + if res == graphql.Null { + atomic.AddUint32(&invalids, 1) + } + return res + }) case "__type": out.Values[i] = ec._Query___type(ctx, field) case "__schema": @@ -5751,6 +6352,54 @@ func (ec *executionContext) marshalNHttpProtocol2githubᚗcomᚋdstotijnᚋhetty return v } +func (ec *executionContext) marshalNHttpRequest2githubᚗcomᚋdstotijnᚋhettyᚋpkgᚋapiᚐHTTPRequest(ctx context.Context, sel ast.SelectionSet, v HTTPRequest) graphql.Marshaler { + return ec._HttpRequest(ctx, sel, &v) +} + +func (ec *executionContext) marshalNHttpRequest2ᚕgithubᚗcomᚋdstotijnᚋhettyᚋpkgᚋapiᚐHTTPRequestᚄ(ctx context.Context, sel ast.SelectionSet, v []HTTPRequest) graphql.Marshaler { + ret := make(graphql.Array, len(v)) + var wg sync.WaitGroup + isLen1 := len(v) == 1 + if !isLen1 { + wg.Add(len(v)) + } + for i := range v { + i := i + fc := &graphql.FieldContext{ + Index: &i, + Result: &v[i], + } + ctx := graphql.WithFieldContext(ctx, fc) + f := func(i int) { + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + ret = nil + } + }() + if !isLen1 { + defer wg.Done() + } + ret[i] = ec.marshalNHttpRequest2githubᚗcomᚋdstotijnᚋhettyᚋpkgᚋapiᚐHTTPRequest(ctx, sel, v[i]) + } + if isLen1 { + f(i) + } else { + go f(i) + } + + } + wg.Wait() + + for _, e := range ret { + if e == graphql.Null { + return graphql.Null + } + } + + return ret +} + func (ec *executionContext) marshalNHttpRequestLog2githubᚗcomᚋdstotijnᚋhettyᚋpkgᚋapiᚐHTTPRequestLog(ctx context.Context, sel ast.SelectionSet, v HTTPRequestLog) graphql.Marshaler { return ec._HttpRequestLog(ctx, sel, &v) } @@ -5829,6 +6478,25 @@ func (ec *executionContext) marshalNInt2int(ctx context.Context, sel ast.Selecti return res } +func (ec *executionContext) unmarshalNModifyRequestInput2githubᚗcomᚋdstotijnᚋhettyᚋpkgᚋapiᚐModifyRequestInput(ctx context.Context, v interface{}) (ModifyRequestInput, error) { + res, err := ec.unmarshalInputModifyRequestInput(ctx, v) + return res, graphql.ErrorOnPath(ctx, err) +} + +func (ec *executionContext) marshalNModifyRequestResult2githubᚗcomᚋdstotijnᚋhettyᚋpkgᚋapiᚐModifyRequestResult(ctx context.Context, sel ast.SelectionSet, v ModifyRequestResult) graphql.Marshaler { + return ec._ModifyRequestResult(ctx, sel, &v) +} + +func (ec *executionContext) marshalNModifyRequestResult2ᚖgithubᚗcomᚋdstotijnᚋhettyᚋpkgᚋapiᚐModifyRequestResult(ctx context.Context, sel ast.SelectionSet, v *ModifyRequestResult) graphql.Marshaler { + if v == nil { + if !graphql.HasFieldError(ctx, graphql.GetFieldContext(ctx)) { + ec.Errorf(ctx, "must not be null") + } + return graphql.Null + } + return ec._ModifyRequestResult(ctx, sel, v) +} + func (ec *executionContext) marshalNProject2githubᚗcomᚋdstotijnᚋhettyᚋpkgᚋapiᚐProject(ctx context.Context, sel ast.SelectionSet, v Project) graphql.Marshaler { return ec._Project(ctx, sel, &v) } diff --git a/pkg/api/models_gen.go b/pkg/api/models_gen.go index 8f4d3e0..6cc66b5 100644 --- a/pkg/api/models_gen.go +++ b/pkg/api/models_gen.go @@ -38,6 +38,15 @@ type HTTPHeaderInput struct { Value string `json:"value"` } +type HTTPRequest struct { + ID ulid.ULID `json:"id"` + URL *url.URL `json:"url"` + Method HTTPMethod `json:"method"` + Proto HTTPProtocol `json:"proto"` + Headers []HTTPHeader `json:"headers"` + Body *string `json:"body"` +} + type HTTPRequestLog struct { ID ulid.ULID `json:"id"` URL string `json:"url"` @@ -69,6 +78,19 @@ type HTTPResponseLog struct { Headers []HTTPHeader `json:"headers"` } +type ModifyRequestInput struct { + ID ulid.ULID `json:"id"` + URL *url.URL `json:"url"` + Method HTTPMethod `json:"method"` + Proto HTTPProtocol `json:"proto"` + Headers []HTTPHeaderInput `json:"headers"` + Body *string `json:"body"` +} + +type ModifyRequestResult struct { + Success bool `json:"success"` +} + type Project struct { ID ulid.ULID `json:"id"` Name string `json:"name"` diff --git a/pkg/api/resolvers.go b/pkg/api/resolvers.go index c481441..d611e9b 100644 --- a/pkg/api/resolvers.go +++ b/pkg/api/resolvers.go @@ -3,9 +3,11 @@ package api //go:generate go run github.com/99designs/gqlgen import ( + "bytes" "context" "errors" "fmt" + "io/ioutil" "net/http" "regexp" "strings" @@ -15,6 +17,8 @@ import ( "github.com/vektah/gqlparser/v2/gqlerror" "github.com/dstotijn/hetty/pkg/proj" + "github.com/dstotijn/hetty/pkg/proxy" + "github.com/dstotijn/hetty/pkg/proxy/intercept" "github.com/dstotijn/hetty/pkg/reqlog" "github.com/dstotijn/hetty/pkg/scope" "github.com/dstotijn/hetty/pkg/search" @@ -36,6 +40,7 @@ var revHTTPProtocolMap = map[HTTPProtocol]string{ type Resolver struct { ProjectService proj.Service RequestLogService reqlog.Service + InterceptService *intercept.Service SenderService sender.Service } @@ -520,6 +525,46 @@ func (r *mutationResolver) DeleteSenderRequests(ctx context.Context) (*DeleteSen return &DeleteSenderRequestsResult{true}, nil } +func (r *queryResolver) InterceptedRequests(ctx context.Context) ([]HTTPRequest, error) { + reqs := r.InterceptService.Requests() + httpReqs := make([]HTTPRequest, len(reqs)) + + for i, req := range reqs { + req, err := parseHTTPRequest(req) + if err != nil { + return nil, err + } + + httpReqs[i] = req + } + + return httpReqs, nil +} + +func (r *mutationResolver) ModifyRequest(ctx context.Context, input ModifyRequestInput) (*ModifyRequestResult, error) { + body := "" + if input.Body != nil { + body = *input.Body + } + + //nolint:noctx + req, err := http.NewRequest(input.Method.String(), input.URL.String(), strings.NewReader(body)) + if err != nil { + return nil, fmt.Errorf("failed to construct HTTP request: %w", err) + } + + for _, header := range input.Headers { + req.Header.Add(header.Key, header.Value) + } + + err = r.InterceptService.ModifyRequest(input.ID, req) + if err != nil { + return nil, fmt.Errorf("could not modify http request: %w", err) + } + + return &ModifyRequestResult{Success: true}, nil +} + func parseSenderRequest(req sender.Request) (SenderRequest, error) { method := HTTPMethod(req.Method) if method != "" && !method.IsValid() { @@ -575,6 +620,56 @@ func parseSenderRequest(req sender.Request) (SenderRequest, error) { return senderReq, nil } +func parseHTTPRequest(req *http.Request) (HTTPRequest, error) { + method := HTTPMethod(req.Method) + if method != "" && !method.IsValid() { + return HTTPRequest{}, fmt.Errorf("http request has invalid method: %v", method) + } + + reqProto := httpProtocolMap[req.Proto] + if !reqProto.IsValid() { + return HTTPRequest{}, fmt.Errorf("http request has invalid protocol: %v", req.Proto) + } + + id, ok := proxy.RequestIDFromContext(req.Context()) + if !ok { + return HTTPRequest{}, errors.New("http request has missing ID") + } + + httpReq := HTTPRequest{ + ID: id, + URL: req.URL, + Method: method, + Proto: HTTPProtocol(req.Proto), + } + + if req.Header != nil { + httpReq.Headers = make([]HTTPHeader, 0) + + for key, values := range req.Header { + for _, value := range values { + httpReq.Headers = append(httpReq.Headers, HTTPHeader{ + Key: key, + Value: value, + }) + } + } + } + + if req.Body != nil { + body, err := ioutil.ReadAll(req.Body) + if err != nil { + return HTTPRequest{}, fmt.Errorf("failed to read request body: %w", err) + } + + req.Body = ioutil.NopCloser(bytes.NewBuffer(body)) + bodyStr := string(body) + httpReq.Body = &bodyStr + } + + return httpReq, nil +} + func stringPtrToRegexp(s *string) (*regexp.Regexp, error) { if s == nil { return nil, nil diff --git a/pkg/api/schema.graphql b/pkg/api/schema.graphql index 5bbfd05..797a8bf 100644 --- a/pkg/api/schema.graphql +++ b/pkg/api/schema.graphql @@ -116,6 +116,28 @@ type SenderRequestFilter { searchExpression: String } +type HttpRequest { + id: ID! + url: URL! + method: HttpMethod! + proto: HttpProtocol! + headers: [HttpHeader!]! + body: String +} + +input ModifyRequestInput { + id: ID! + url: URL! + method: HttpMethod! + proto: HttpProtocol! + headers: [HttpHeaderInput!] + body: String +} + +type ModifyRequestResult { + success: Boolean! +} + type Query { httpRequestLog(id: ID!): HttpRequestLog httpRequestLogs: [HttpRequestLog!]! @@ -125,6 +147,7 @@ type Query { scope: [ScopeRule!]! senderRequest(id: ID!): SenderRequest senderRequests: [SenderRequest!]! + interceptedRequests: [HttpRequest!]! } type Mutation { @@ -142,6 +165,7 @@ type Mutation { createSenderRequestFromHttpRequestLog(id: ID!): SenderRequest! sendRequest(id: ID!): SenderRequest! deleteSenderRequests: DeleteSenderRequestsResult! + modifyRequest(request: ModifyRequestInput!): ModifyRequestResult! } enum HttpMethod { diff --git a/pkg/proj/proj.go b/pkg/proj/proj.go index 5456df7..e26443c 100644 --- a/pkg/proj/proj.go +++ b/pkg/proj/proj.go @@ -11,6 +11,7 @@ import ( "github.com/oklog/ulid" + "github.com/dstotijn/hetty/pkg/proxy/intercept" "github.com/dstotijn/hetty/pkg/reqlog" "github.com/dstotijn/hetty/pkg/scope" "github.com/dstotijn/hetty/pkg/search" @@ -37,6 +38,7 @@ type Service interface { type service struct { repo Repository + interceptSvc *intercept.Service reqLogSvc reqlog.Service senderSvc sender.Service scope *scope.Scope @@ -73,19 +75,21 @@ var ( var nameRegexp = regexp.MustCompile(`^[\w\d\s]+$`) type Config struct { - Repository Repository - ReqLogService reqlog.Service - SenderService sender.Service - Scope *scope.Scope + Repository Repository + InterceptService *intercept.Service + ReqLogService reqlog.Service + SenderService sender.Service + Scope *scope.Scope } // NewService returns a new Service. func NewService(cfg Config) (Service, error) { return &service{ - repo: cfg.Repository, - reqLogSvc: cfg.ReqLogService, - senderSvc: cfg.SenderService, - scope: cfg.Scope, + repo: cfg.Repository, + interceptSvc: cfg.InterceptService, + reqLogSvc: cfg.ReqLogService, + senderSvc: cfg.SenderService, + scope: cfg.Scope, }, nil } @@ -117,6 +121,7 @@ func (svc *service) CloseProject() error { } svc.activeProjectID = ulid.ULID{} + svc.interceptSvc.ClearRequests() svc.reqLogSvc.SetActiveProjectID(ulid.ULID{}) svc.reqLogSvc.SetBypassOutOfScopeRequests(false) svc.reqLogSvc.SetFindReqsFilter(reqlog.FindRequestsFilter{}) diff --git a/pkg/proxy/intercept/intercept.go b/pkg/proxy/intercept/intercept.go new file mode 100644 index 0000000..9b29205 --- /dev/null +++ b/pkg/proxy/intercept/intercept.go @@ -0,0 +1,189 @@ +package intercept + +import ( + "context" + "errors" + "net/http" + "sort" + "sync" + + "github.com/oklog/ulid" + + "github.com/dstotijn/hetty/pkg/log" + "github.com/dstotijn/hetty/pkg/proxy" +) + +var ( + ErrRequestAborted = errors.New("intercept: request was aborted") + ErrRequestNotFound = errors.New("intercept: request not found") + ErrRequestDone = errors.New("intercept: request is done") +) + +// Request represents a server received HTTP request, alongside a channel for sending a modified version of it to the +// routine that's awaiting it. Also contains a channel for receiving a cancellation signal. +type Request struct { + req *http.Request + ch chan<- *http.Request + done <-chan struct{} +} + +type Service struct { + mu *sync.RWMutex + requests map[ulid.ULID]Request + logger log.Logger +} + +type Config struct { + Logger log.Logger +} + +// RequestIDs implements sort.Interface. +type RequestIDs []ulid.ULID + +func NewService(cfg Config) *Service { + s := &Service{ + mu: &sync.RWMutex{}, + requests: make(map[ulid.ULID]Request), + logger: cfg.Logger, + } + + if s.logger == nil { + s.logger = log.NewNopLogger() + } + + return s +} + +// RequestModifier is a proxy.RequestModifyMiddleware for intercepting HTTP +// requests. +func (svc *Service) RequestModifier(next proxy.RequestModifyFunc) proxy.RequestModifyFunc { + return func(req *http.Request) { + // This is a blocking operation, that gets unblocked when either a modified request is returned or an error + // (typically `context.Canceled`). + modifiedReq, err := svc.Intercept(req.Context(), req) + + switch { + case errors.Is(err, ErrRequestAborted): + svc.logger.Debugw("Stopping intercept, request was aborted.") + // Prevent further processing by replacing req.Context with a cancelled context value. + // This will cause the http.Roundtripper in the `proxy` package to + // handle this request as an error. + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + *req = *req.WithContext(ctx) + case errors.Is(err, context.Canceled): + svc.logger.Debugw("Stopping intercept, context was cancelled.") + case err != nil: + svc.logger.Errorw("Failed to intercept request.", + "error", err) + default: + *req = *modifiedReq.WithContext(req.Context()) + next(req) + } + } +} + +// Intercept adds an HTTP request to an array of pending intercepted requests, alongside channels used for sending a +// cancellation signal and receiving a modified request. It's safe for concurrent use. +func (svc *Service) Intercept(ctx context.Context, req *http.Request) (*http.Request, error) { + reqID, ok := proxy.RequestIDFromContext(ctx) + if !ok { + svc.logger.Errorw("Failed to intercept: request doesn't have an ID.") + return req, nil + } + + ch := make(chan *http.Request) + done := make(chan struct{}) + + svc.mu.Lock() + svc.requests[reqID] = Request{ + req: req, + ch: ch, + done: done, + } + svc.mu.Unlock() + + // Whatever happens next (modified request returned, or a context cancelled error), any blocked channel senders + // should be unblocked, and the request should be removed from the requests queue. + defer func() { + close(done) + svc.mu.Lock() + defer svc.mu.Unlock() + delete(svc.requests, reqID) + }() + + select { + case modReq := <-ch: + if modReq == nil { + return nil, ErrRequestAborted + } + + return modReq, nil + case <-ctx.Done(): + return nil, ctx.Err() + } +} + +// ModifyRequest sends a modified HTTP request to the related channel, or returns ErrRequestDone when the request was +// cancelled. It's safe for concurrent use. +func (svc *Service) ModifyRequest(reqID ulid.ULID, modReq *http.Request) error { + svc.mu.RLock() + req, ok := svc.requests[reqID] + svc.mu.RUnlock() + + if !ok { + return ErrRequestNotFound + } + + select { + case <-req.done: + return ErrRequestDone + case req.ch <- modReq: + return nil + } +} + +func (svc *Service) ClearRequests() { + svc.mu.Lock() + defer svc.mu.Unlock() + + for _, req := range svc.requests { + select { + case <-req.done: + case req.ch <- nil: + } + } +} + +// Requests returns a list of pending intercepted requests. It's safe for concurrent use. +func (svc *Service) Requests() []*http.Request { + svc.mu.RLock() + defer svc.mu.RUnlock() + + ids := make([]ulid.ULID, 0, len(svc.requests)) + for id := range svc.requests { + ids = append(ids, id) + } + + sort.Sort(RequestIDs(ids)) + + reqs := make([]*http.Request, len(ids)) + for i, id := range ids { + reqs[i] = svc.requests[id].req + } + + return reqs +} + +func (ids RequestIDs) Len() int { + return len(ids) +} + +func (ids RequestIDs) Less(i, j int) bool { + return ids[i].Compare(ids[j]) == -1 +} + +func (ids RequestIDs) Swap(i, j int) { + ids[i], ids[j] = ids[j], ids[i] +} diff --git a/pkg/proxy/intercept/intercept_test.go b/pkg/proxy/intercept/intercept_test.go new file mode 100644 index 0000000..77b794c --- /dev/null +++ b/pkg/proxy/intercept/intercept_test.go @@ -0,0 +1,122 @@ +package intercept_test + +import ( + "context" + "errors" + "math/rand" + "net/http" + "net/http/httptest" + "sync" + "testing" + "time" + + "github.com/oklog/ulid" + "go.uber.org/zap" + + "github.com/dstotijn/hetty/pkg/proxy" + "github.com/dstotijn/hetty/pkg/proxy/intercept" +) + +//nolint:gosec +var ulidEntropy = rand.New(rand.NewSource(time.Now().UnixNano())) + +func TestRequestModifier(t *testing.T) { + t.Parallel() + + t.Run("modify request that's not found", func(t *testing.T) { + t.Parallel() + + logger, _ := zap.NewDevelopment() + svc := intercept.NewService(intercept.Config{ + Logger: logger.Sugar(), + }) + + reqID := ulid.MustNew(ulid.Timestamp(time.Now()), ulidEntropy) + + err := svc.ModifyRequest(reqID, nil) + if !errors.Is(err, intercept.ErrRequestNotFound) { + t.Fatalf("expected `intercept.ErrRequestNotFound`, got: %v", err) + } + }) + + t.Run("modify request that's done", func(t *testing.T) { + t.Parallel() + + logger, _ := zap.NewDevelopment() + svc := intercept.NewService(intercept.Config{ + Logger: logger.Sugar(), + }) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + req := httptest.NewRequest("GET", "https://example.com/foo", nil) + reqID := ulid.MustNew(ulid.Timestamp(time.Now()), ulidEntropy) + *req = *req.WithContext(ctx) + *req = *req.WithContext(proxy.WithRequestID(req.Context(), reqID)) + + next := func(req *http.Request) {} + go svc.RequestModifier(next)(req) + + // Wait shortly, to allow the req modifier goroutine to add `req` to the + // array of intercepted reqs. + time.Sleep(10 * time.Millisecond) + cancel() + + err := svc.ModifyRequest(reqID, nil) + if !errors.Is(err, intercept.ErrRequestDone) { + t.Fatalf("expected `interept.ErrRequestDone`, got: %v", err) + } + }) + + t.Run("modify intercepted request", func(t *testing.T) { + t.Parallel() + + req := httptest.NewRequest("GET", "https://example.com/foo", nil) + req.Header.Set("X-Foo", "foo") + + reqID := ulid.MustNew(ulid.Timestamp(time.Now()), ulidEntropy) + *req = *req.WithContext(proxy.WithRequestID(req.Context(), reqID)) + + modReq := req.Clone(context.Background()) + modReq.Header.Set("X-Foo", "bar") + + logger, _ := zap.NewDevelopment() + svc := intercept.NewService(intercept.Config{ + Logger: logger.Sugar(), + }) + + var got *http.Request + + next := func(req *http.Request) { + got = req.Clone(context.Background()) + } + + var wg sync.WaitGroup + wg.Add(1) + + go func() { + svc.RequestModifier(next)(req) + wg.Done() + }() + + // Wait shortly, to allow the req modifier goroutine to add `req` to the + // array of intercepted reqs. + time.Sleep(10 * time.Millisecond) + + err := svc.ModifyRequest(reqID, modReq) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + wg.Wait() + + if got == nil { + t.Fatal("expected `got` not to be nil") + } + + if exp := "bar"; exp != got.Header.Get("X-Foo") { + t.Fatalf("incorrect modified request header value (expected: %v, got: %v)", exp, got.Header.Get("X-Foo")) + } + }) +} diff --git a/pkg/proxy/proxy.go b/pkg/proxy/proxy.go index 0e30f65..0f4c57a 100644 --- a/pkg/proxy/proxy.go +++ b/pkg/proxy/proxy.go @@ -7,16 +7,22 @@ import ( "crypto/x509" "errors" "fmt" + "math/rand" "net" "net/http" "net/http/httputil" + "time" "github.com/dstotijn/hetty/pkg/log" + "github.com/oklog/ulid" ) +//nolint:gosec +var ulidEntropy = rand.New(rand.NewSource(time.Now().UnixNano())) + type contextKey int -const ReqLogIDKey contextKey = 0 +const reqIDKey contextKey = 0 // Proxy implements http.Handler and offers MITM behaviour for modifying // HTTP requests and responses. @@ -69,6 +75,10 @@ func (p *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } + reqID := ulid.MustNew(ulid.Timestamp(time.Now()), ulidEntropy) + ctx := context.WithValue(r.Context(), reqIDKey, reqID) + *r = *r.WithContext(ctx) + p.handler.ServeHTTP(w, r) } @@ -110,6 +120,15 @@ func (p *Proxy) modifyResponse(res *http.Response) error { return fn(res) } +func WithRequestID(ctx context.Context, id ulid.ULID) context.Context { + return context.WithValue(ctx, reqIDKey, id) +} + +func RequestIDFromContext(ctx context.Context) (ulid.ULID, bool) { + id, ok := ctx.Value(reqIDKey).(ulid.ULID) + return id, ok +} + // handleConnect hijacks the incoming HTTP request and sets up an HTTP tunnel. // During the TLS handshake with the client, we use the proxy's CA config to // create a certificate on-the-fly. @@ -170,13 +189,11 @@ func (p *Proxy) clientTLSConn(conn net.Conn) (*tls.Conn, error) { } func (p *Proxy) errorHandler(w http.ResponseWriter, r *http.Request, err error) { - if errors.Is(err, context.Canceled) { - return + if !errors.Is(err, context.Canceled) { + p.logger.Errorw("Failed to proxy request.", + "error", err) } - p.logger.Errorw("Failed to proxy request.", - "error", err) - w.WriteHeader(http.StatusBadGateway) } diff --git a/pkg/reqlog/reqlog.go b/pkg/reqlog/reqlog.go index ce562e9..d7e652e 100644 --- a/pkg/reqlog/reqlog.go +++ b/pkg/reqlog/reqlog.go @@ -8,10 +8,8 @@ import ( "fmt" "io" "io/ioutil" - "math/rand" "net/http" "net/url" - "time" "github.com/oklog/ulid" @@ -23,16 +21,16 @@ import ( type contextKey int -const LogBypassedKey contextKey = 0 +const ( + LogBypassedKey contextKey = iota + ReqLogIDKey +) var ( ErrRequestNotFound = errors.New("reqlog: request not found") ErrProjectIDMustBeSet = errors.New("reqlog: project ID must be set") ) -//nolint:gosec -var ulidEntropy = rand.New(rand.NewSource(time.Now().UnixNano())) - type RequestLog struct { ID ulid.ULID ProjectID ulid.ULID @@ -170,8 +168,14 @@ func (svc *service) RequestModifier(next proxy.RequestModifyFunc) proxy.RequestM return } + reqID, ok := proxy.RequestIDFromContext(req.Context()) + if !ok { + svc.logger.Errorw("Bypassed logging: request doesn't have an ID.") + return + } + reqLog := RequestLog{ - ID: ulid.MustNew(ulid.Timestamp(time.Now()), ulidEntropy), + ID: reqID, ProjectID: svc.activeProjectID, Method: clone.Method, URL: clone.URL, @@ -191,7 +195,7 @@ func (svc *service) RequestModifier(next proxy.RequestModifyFunc) proxy.RequestM "reqLogID", reqLog.ID.String(), "url", reqLog.URL.String()) - ctx := context.WithValue(req.Context(), proxy.ReqLogIDKey, reqLog.ID) + ctx := context.WithValue(req.Context(), ReqLogIDKey, reqLog.ID) *req = *req.WithContext(ctx) } } @@ -206,7 +210,7 @@ func (svc *service) ResponseModifier(next proxy.ResponseModifyFunc) proxy.Respon return nil } - reqLogID, ok := res.Request.Context().Value(proxy.ReqLogIDKey).(ulid.ULID) + reqLogID, ok := res.Request.Context().Value(ReqLogIDKey).(ulid.ULID) if !ok { return errors.New("reqlog: request is missing ID") } diff --git a/pkg/reqlog/reqlog_test.go b/pkg/reqlog/reqlog_test.go index 38c0e7b..25eebeb 100644 --- a/pkg/reqlog/reqlog_test.go +++ b/pkg/reqlog/reqlog_test.go @@ -41,6 +41,8 @@ func TestRequestModifier(t *testing.T) { } reqModFn := svc.RequestModifier(next) req := httptest.NewRequest("GET", "https://example.com/", strings.NewReader("bar")) + reqID := ulid.MustNew(ulid.Timestamp(time.Now()), ulidEntropy) + req = req.WithContext(proxy.WithRequestID(req.Context(), reqID)) reqModFn(req) @@ -88,7 +90,7 @@ func TestResponseModifier(t *testing.T) { req := httptest.NewRequest("GET", "https://example.com/", strings.NewReader("bar")) reqLogID := ulid.MustNew(ulid.Timestamp(time.Now()), ulidEntropy) - req = req.WithContext(context.WithValue(req.Context(), proxy.ReqLogIDKey, reqLogID)) + req = req.WithContext(context.WithValue(req.Context(), reqlog.ReqLogIDKey, reqLogID)) res := &http.Response{ Request: req,