Replace SQLite with BadgerDB

This commit is contained in:
David Stotijn
2022-01-21 11:45:54 +01:00
parent 8a3b3cbf02
commit d84d2d0905
49 changed files with 2496 additions and 2677 deletions

View File

@ -80,7 +80,6 @@ type ComplexityRoot struct {
Body func(childComplexity int) int
Headers func(childComplexity int) int
Proto func(childComplexity int) int
RequestID func(childComplexity int) int
StatusCode func(childComplexity int) int
StatusReason func(childComplexity int) int
}
@ -88,20 +87,22 @@ type ComplexityRoot struct {
Mutation struct {
ClearHTTPRequestLog func(childComplexity int) int
CloseProject func(childComplexity int) int
DeleteProject func(childComplexity int, name string) int
OpenProject func(childComplexity int, name string) int
CreateProject func(childComplexity int, name string) int
DeleteProject func(childComplexity int, id ULID) int
OpenProject func(childComplexity int, id ULID) int
SetHTTPRequestLogFilter func(childComplexity int, filter *HTTPRequestLogFilterInput) int
SetScope func(childComplexity int, scope []ScopeRuleInput) int
}
Project struct {
ID func(childComplexity int) int
IsActive func(childComplexity int) int
Name func(childComplexity int) int
}
Query struct {
ActiveProject func(childComplexity int) int
HTTPRequestLog func(childComplexity int, id int64) int
HTTPRequestLog func(childComplexity int, id ULID) int
HTTPRequestLogFilter func(childComplexity int) int
HTTPRequestLogs func(childComplexity int) int
Projects func(childComplexity int) int
@ -121,15 +122,16 @@ type ComplexityRoot struct {
}
type MutationResolver interface {
OpenProject(ctx context.Context, name string) (*Project, error)
CreateProject(ctx context.Context, name string) (*Project, error)
OpenProject(ctx context.Context, id ULID) (*Project, error)
CloseProject(ctx context.Context) (*CloseProjectResult, error)
DeleteProject(ctx context.Context, name string) (*DeleteProjectResult, error)
DeleteProject(ctx context.Context, id ULID) (*DeleteProjectResult, error)
ClearHTTPRequestLog(ctx context.Context) (*ClearHTTPRequestLogResult, error)
SetScope(ctx context.Context, scope []ScopeRuleInput) ([]ScopeRule, error)
SetHTTPRequestLogFilter(ctx context.Context, filter *HTTPRequestLogFilterInput) (*HTTPRequestLogFilter, error)
}
type QueryResolver interface {
HTTPRequestLog(ctx context.Context, id int64) (*HTTPRequestLog, error)
HTTPRequestLog(ctx context.Context, id ULID) (*HTTPRequestLog, error)
HTTPRequestLogs(ctx context.Context) ([]HTTPRequestLog, error)
HTTPRequestLogFilter(ctx context.Context) (*HTTPRequestLogFilter, error)
ActiveProject(ctx context.Context) (*Project, error)
@ -278,13 +280,6 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in
return e.complexity.HTTPResponseLog.Proto(childComplexity), true
case "HttpResponseLog.requestId":
if e.complexity.HTTPResponseLog.RequestID == nil {
break
}
return e.complexity.HTTPResponseLog.RequestID(childComplexity), true
case "HttpResponseLog.statusCode":
if e.complexity.HTTPResponseLog.StatusCode == nil {
break
@ -313,6 +308,18 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in
return e.complexity.Mutation.CloseProject(childComplexity), true
case "Mutation.createProject":
if e.complexity.Mutation.CreateProject == nil {
break
}
args, err := ec.field_Mutation_createProject_args(context.TODO(), rawArgs)
if err != nil {
return 0, false
}
return e.complexity.Mutation.CreateProject(childComplexity, args["name"].(string)), true
case "Mutation.deleteProject":
if e.complexity.Mutation.DeleteProject == nil {
break
@ -323,7 +330,7 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in
return 0, false
}
return e.complexity.Mutation.DeleteProject(childComplexity, args["name"].(string)), true
return e.complexity.Mutation.DeleteProject(childComplexity, args["id"].(ULID)), true
case "Mutation.openProject":
if e.complexity.Mutation.OpenProject == nil {
@ -335,7 +342,7 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in
return 0, false
}
return e.complexity.Mutation.OpenProject(childComplexity, args["name"].(string)), true
return e.complexity.Mutation.OpenProject(childComplexity, args["id"].(ULID)), true
case "Mutation.setHttpRequestLogFilter":
if e.complexity.Mutation.SetHTTPRequestLogFilter == nil {
@ -361,6 +368,13 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in
return e.complexity.Mutation.SetScope(childComplexity, args["scope"].([]ScopeRuleInput)), true
case "Project.id":
if e.complexity.Project.ID == nil {
break
}
return e.complexity.Project.ID(childComplexity), true
case "Project.isActive":
if e.complexity.Project.IsActive == nil {
break
@ -392,7 +406,7 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in
return 0, false
}
return e.complexity.Query.HTTPRequestLog(childComplexity, args["id"].(int64)), true
return e.complexity.Query.HTTPRequestLog(childComplexity, args["id"].(ULID)), true
case "Query.httpRequestLogFilter":
if e.complexity.Query.HTTPRequestLogFilter == nil {
@ -533,7 +547,6 @@ var sources = []*ast.Source{
}
type HttpResponseLog {
requestId: ID!
proto: String!
statusCode: Int!
statusReason: String!
@ -547,6 +560,7 @@ type HttpHeader {
}
type Project {
id: ID!
name: String!
isActive: Boolean!
}
@ -605,9 +619,10 @@ type Query {
}
type Mutation {
openProject(name: String!): Project
createProject(name: String!): Project
openProject(id: ID!): Project
closeProject: CloseProjectResult!
deleteProject(name: String!): DeleteProjectResult!
deleteProject(id: ID!): DeleteProjectResult!
clearHTTPRequestLog: ClearHTTPRequestLogResult!
setScope(scope: [ScopeRuleInput!]!): [ScopeRule!]!
setHttpRequestLogFilter(
@ -637,7 +652,7 @@ var parsedSchema = gqlparser.MustLoadSchema(sources...)
// region ***************************** args.gotpl *****************************
func (ec *executionContext) field_Mutation_deleteProject_args(ctx context.Context, rawArgs map[string]interface{}) (map[string]interface{}, error) {
func (ec *executionContext) field_Mutation_createProject_args(ctx context.Context, rawArgs map[string]interface{}) (map[string]interface{}, error) {
var err error
args := map[string]interface{}{}
var arg0 string
@ -652,18 +667,33 @@ func (ec *executionContext) field_Mutation_deleteProject_args(ctx context.Contex
return args, nil
}
func (ec *executionContext) field_Mutation_openProject_args(ctx context.Context, rawArgs map[string]interface{}) (map[string]interface{}, error) {
func (ec *executionContext) field_Mutation_deleteProject_args(ctx context.Context, rawArgs map[string]interface{}) (map[string]interface{}, error) {
var err error
args := map[string]interface{}{}
var arg0 string
if tmp, ok := rawArgs["name"]; ok {
ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("name"))
arg0, err = ec.unmarshalNString2string(ctx, tmp)
var arg0 ULID
if tmp, ok := rawArgs["id"]; ok {
ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("id"))
arg0, err = ec.unmarshalNID2githubᚗcomᚋdstotijnᚋhettyᚋpkgᚋapiᚐULID(ctx, tmp)
if err != nil {
return nil, err
}
}
args["name"] = arg0
args["id"] = arg0
return args, nil
}
func (ec *executionContext) field_Mutation_openProject_args(ctx context.Context, rawArgs map[string]interface{}) (map[string]interface{}, error) {
var err error
args := map[string]interface{}{}
var arg0 ULID
if tmp, ok := rawArgs["id"]; ok {
ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("id"))
arg0, err = ec.unmarshalNID2githubᚗcomᚋdstotijnᚋhettyᚋpkgᚋapiᚐULID(ctx, tmp)
if err != nil {
return nil, err
}
}
args["id"] = arg0
return args, nil
}
@ -715,10 +745,10 @@ func (ec *executionContext) field_Query___type_args(ctx context.Context, rawArgs
func (ec *executionContext) field_Query_httpRequestLog_args(ctx context.Context, rawArgs map[string]interface{}) (map[string]interface{}, error) {
var err error
args := map[string]interface{}{}
var arg0 int64
var arg0 ULID
if tmp, ok := rawArgs["id"]; ok {
ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("id"))
arg0, err = ec.unmarshalNID2int64(ctx, tmp)
arg0, err = ec.unmarshalNID2githubᚗcomᚋdstotijnᚋhettyᚋpkgᚋapiᚐULID(ctx, tmp)
if err != nil {
return nil, err
}
@ -970,9 +1000,9 @@ func (ec *executionContext) _HttpRequestLog_id(ctx context.Context, field graphq
}
return graphql.Null
}
res := resTmp.(int64)
res := resTmp.(ULID)
fc.Result = res
return ec.marshalNID2int64(ctx, field.Selections, res)
return ec.marshalNID2githubᚗcomᚋdstotijnᚋhettyᚋpkgᚋapiᚐULID(ctx, field.Selections, res)
}
func (ec *executionContext) _HttpRequestLog_url(ctx context.Context, field graphql.CollectedField, obj *HTTPRequestLog) (ret graphql.Marshaler) {
@ -1281,41 +1311,6 @@ func (ec *executionContext) _HttpRequestLogFilter_searchExpression(ctx context.C
return ec.marshalOString2ᚖstring(ctx, field.Selections, res)
}
func (ec *executionContext) _HttpResponseLog_requestId(ctx context.Context, field graphql.CollectedField, obj *HTTPResponseLog) (ret graphql.Marshaler) {
defer func() {
if r := recover(); r != nil {
ec.Error(ctx, ec.Recover(ctx, r))
ret = graphql.Null
}
}()
fc := &graphql.FieldContext{
Object: "HttpResponseLog",
Field: field,
Args: nil,
IsMethod: false,
IsResolver: false,
}
ctx = graphql.WithFieldContext(ctx, fc)
resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) {
ctx = rctx // use context from middleware stack in children
return obj.RequestID, nil
})
if err != nil {
ec.Error(ctx, err)
return graphql.Null
}
if resTmp == nil {
if !graphql.HasFieldError(ctx, fc) {
ec.Errorf(ctx, "must not be null")
}
return graphql.Null
}
res := resTmp.(int64)
fc.Result = res
return ec.marshalNID2int64(ctx, field.Selections, res)
}
func (ec *executionContext) _HttpResponseLog_proto(ctx context.Context, field graphql.CollectedField, obj *HTTPResponseLog) (ret graphql.Marshaler) {
defer func() {
if r := recover(); r != nil {
@ -1488,6 +1483,45 @@ func (ec *executionContext) _HttpResponseLog_headers(ctx context.Context, field
return ec.marshalNHttpHeader2ᚕgithubᚗcomᚋdstotijnᚋhettyᚋpkgᚋapiᚐHTTPHeaderᚄ(ctx, field.Selections, res)
}
func (ec *executionContext) _Mutation_createProject(ctx context.Context, field graphql.CollectedField) (ret graphql.Marshaler) {
defer func() {
if r := recover(); r != nil {
ec.Error(ctx, ec.Recover(ctx, r))
ret = graphql.Null
}
}()
fc := &graphql.FieldContext{
Object: "Mutation",
Field: field,
Args: nil,
IsMethod: true,
IsResolver: true,
}
ctx = graphql.WithFieldContext(ctx, fc)
rawArgs := field.ArgumentMap(ec.Variables)
args, err := ec.field_Mutation_createProject_args(ctx, rawArgs)
if err != nil {
ec.Error(ctx, err)
return graphql.Null
}
fc.Args = args
resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) {
ctx = rctx // use context from middleware stack in children
return ec.resolvers.Mutation().CreateProject(rctx, args["name"].(string))
})
if err != nil {
ec.Error(ctx, err)
return graphql.Null
}
if resTmp == nil {
return graphql.Null
}
res := resTmp.(*Project)
fc.Result = res
return ec.marshalOProject2ᚖgithubᚗcomᚋdstotijnᚋhettyᚋpkgᚋapiᚐProject(ctx, field.Selections, res)
}
func (ec *executionContext) _Mutation_openProject(ctx context.Context, field graphql.CollectedField) (ret graphql.Marshaler) {
defer func() {
if r := recover(); r != nil {
@ -1513,7 +1547,7 @@ func (ec *executionContext) _Mutation_openProject(ctx context.Context, field gra
fc.Args = args
resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) {
ctx = rctx // use context from middleware stack in children
return ec.resolvers.Mutation().OpenProject(rctx, args["name"].(string))
return ec.resolvers.Mutation().OpenProject(rctx, args["id"].(ULID))
})
if err != nil {
ec.Error(ctx, err)
@ -1587,7 +1621,7 @@ func (ec *executionContext) _Mutation_deleteProject(ctx context.Context, field g
fc.Args = args
resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) {
ctx = rctx // use context from middleware stack in children
return ec.resolvers.Mutation().DeleteProject(rctx, args["name"].(string))
return ec.resolvers.Mutation().DeleteProject(rctx, args["id"].(ULID))
})
if err != nil {
ec.Error(ctx, err)
@ -1720,6 +1754,41 @@ func (ec *executionContext) _Mutation_setHttpRequestLogFilter(ctx context.Contex
return ec.marshalOHttpRequestLogFilter2ᚖgithubᚗcomᚋdstotijnᚋhettyᚋpkgᚋapiᚐHTTPRequestLogFilter(ctx, field.Selections, res)
}
func (ec *executionContext) _Project_id(ctx context.Context, field graphql.CollectedField, obj *Project) (ret graphql.Marshaler) {
defer func() {
if r := recover(); r != nil {
ec.Error(ctx, ec.Recover(ctx, r))
ret = graphql.Null
}
}()
fc := &graphql.FieldContext{
Object: "Project",
Field: field,
Args: nil,
IsMethod: false,
IsResolver: false,
}
ctx = graphql.WithFieldContext(ctx, fc)
resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) {
ctx = rctx // use context from middleware stack in children
return obj.ID, nil
})
if err != nil {
ec.Error(ctx, err)
return graphql.Null
}
if resTmp == nil {
if !graphql.HasFieldError(ctx, fc) {
ec.Errorf(ctx, "must not be null")
}
return graphql.Null
}
res := resTmp.(ULID)
fc.Result = res
return ec.marshalNID2githubᚗcomᚋdstotijnᚋhettyᚋpkgᚋapiᚐULID(ctx, field.Selections, res)
}
func (ec *executionContext) _Project_name(ctx context.Context, field graphql.CollectedField, obj *Project) (ret graphql.Marshaler) {
defer func() {
if r := recover(); r != nil {
@ -1815,7 +1884,7 @@ func (ec *executionContext) _Query_httpRequestLog(ctx context.Context, field gra
fc.Args = args
resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) {
ctx = rctx // use context from middleware stack in children
return ec.resolvers.Query().HTTPRequestLog(rctx, args["id"].(int64))
return ec.resolvers.Query().HTTPRequestLog(rctx, args["id"].(ULID))
})
if err != nil {
ec.Error(ctx, err)
@ -2366,6 +2435,41 @@ func (ec *executionContext) ___Directive_args(ctx context.Context, field graphql
return ec.marshalN__InputValue2ᚕgithubᚗcomᚋ99designsᚋgqlgenᚋgraphqlᚋintrospectionᚐInputValueᚄ(ctx, field.Selections, res)
}
func (ec *executionContext) ___Directive_isRepeatable(ctx context.Context, field graphql.CollectedField, obj *introspection.Directive) (ret graphql.Marshaler) {
defer func() {
if r := recover(); r != nil {
ec.Error(ctx, ec.Recover(ctx, r))
ret = graphql.Null
}
}()
fc := &graphql.FieldContext{
Object: "__Directive",
Field: field,
Args: nil,
IsMethod: false,
IsResolver: false,
}
ctx = graphql.WithFieldContext(ctx, fc)
resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) {
ctx = rctx // use context from middleware stack in children
return obj.IsRepeatable, nil
})
if err != nil {
ec.Error(ctx, err)
return graphql.Null
}
if resTmp == nil {
if !graphql.HasFieldError(ctx, fc) {
ec.Errorf(ctx, "must not be null")
}
return graphql.Null
}
res := resTmp.(bool)
fc.Result = res
return ec.marshalNBoolean2bool(ctx, field.Selections, res)
}
func (ec *executionContext) ___EnumValue_name(ctx context.Context, field graphql.CollectedField, obj *introspection.EnumValue) (ret graphql.Marshaler) {
defer func() {
if r := recover(); r != nil {
@ -3318,7 +3422,10 @@ func (ec *executionContext) ___Type_ofType(ctx context.Context, field graphql.Co
func (ec *executionContext) unmarshalInputHttpRequestLogFilterInput(ctx context.Context, obj interface{}) (HTTPRequestLogFilterInput, error) {
var it HTTPRequestLogFilterInput
var asMap = obj.(map[string]interface{})
asMap := map[string]interface{}{}
for k, v := range obj.(map[string]interface{}) {
asMap[k] = v
}
for k, v := range asMap {
switch k {
@ -3346,7 +3453,10 @@ func (ec *executionContext) unmarshalInputHttpRequestLogFilterInput(ctx context.
func (ec *executionContext) unmarshalInputScopeHeaderInput(ctx context.Context, obj interface{}) (ScopeHeaderInput, error) {
var it ScopeHeaderInput
var asMap = obj.(map[string]interface{})
asMap := map[string]interface{}{}
for k, v := range obj.(map[string]interface{}) {
asMap[k] = v
}
for k, v := range asMap {
switch k {
@ -3374,7 +3484,10 @@ func (ec *executionContext) unmarshalInputScopeHeaderInput(ctx context.Context,
func (ec *executionContext) unmarshalInputScopeRuleInput(ctx context.Context, obj interface{}) (ScopeRuleInput, error) {
var it ScopeRuleInput
var asMap = obj.(map[string]interface{})
asMap := map[string]interface{}{}
for k, v := range obj.(map[string]interface{}) {
asMap[k] = v
}
for k, v := range asMap {
switch k {
@ -3625,11 +3738,6 @@ func (ec *executionContext) _HttpResponseLog(ctx context.Context, sel ast.Select
switch field.Name {
case "__typename":
out.Values[i] = graphql.MarshalString("HttpResponseLog")
case "requestId":
out.Values[i] = ec._HttpResponseLog_requestId(ctx, field, obj)
if out.Values[i] == graphql.Null {
invalids++
}
case "proto":
out.Values[i] = ec._HttpResponseLog_proto(ctx, field, obj)
if out.Values[i] == graphql.Null {
@ -3678,6 +3786,8 @@ func (ec *executionContext) _Mutation(ctx context.Context, sel ast.SelectionSet)
switch field.Name {
case "__typename":
out.Values[i] = graphql.MarshalString("Mutation")
case "createProject":
out.Values[i] = ec._Mutation_createProject(ctx, field)
case "openProject":
out.Values[i] = ec._Mutation_openProject(ctx, field)
case "closeProject":
@ -3724,6 +3834,11 @@ func (ec *executionContext) _Project(ctx context.Context, sel ast.SelectionSet,
switch field.Name {
case "__typename":
out.Values[i] = graphql.MarshalString("Project")
case "id":
out.Values[i] = ec._Project_id(ctx, field, obj)
if out.Values[i] == graphql.Null {
invalids++
}
case "name":
out.Values[i] = ec._Project_name(ctx, field, obj)
if out.Values[i] == graphql.Null {
@ -3932,6 +4047,11 @@ func (ec *executionContext) ___Directive(ctx context.Context, sel ast.SelectionS
if out.Values[i] == graphql.Null {
invalids++
}
case "isRepeatable":
out.Values[i] = ec.___Directive_isRepeatable(ctx, field, obj)
if out.Values[i] == graphql.Null {
invalids++
}
default:
panic("unknown field " + strconv.Quote(field.Name))
}
@ -4244,6 +4364,13 @@ func (ec *executionContext) marshalNHttpHeader2ᚕgithubᚗcomᚋdstotijnᚋhett
}
wg.Wait()
for _, e := range ret {
if e == graphql.Null {
return graphql.Null
}
}
return ret
}
@ -4295,22 +4422,24 @@ func (ec *executionContext) marshalNHttpRequestLog2ᚕgithubᚗcomᚋdstotijnᚋ
}
wg.Wait()
for _, e := range ret {
if e == graphql.Null {
return graphql.Null
}
}
return ret
}
func (ec *executionContext) unmarshalNID2int64(ctx context.Context, v interface{}) (int64, error) {
res, err := graphql.UnmarshalInt64(v)
func (ec *executionContext) unmarshalNID2githubᚗcomᚋdstotijnᚋhettyᚋpkgᚋapiᚐULID(ctx context.Context, v interface{}) (ULID, error) {
var res ULID
err := res.UnmarshalGQL(v)
return res, graphql.ErrorOnPath(ctx, err)
}
func (ec *executionContext) marshalNID2int64(ctx context.Context, sel ast.SelectionSet, v int64) graphql.Marshaler {
res := graphql.MarshalInt64(v)
if res == graphql.Null {
if !graphql.HasFieldError(ctx, graphql.GetFieldContext(ctx)) {
ec.Errorf(ctx, "must not be null")
}
}
return res
func (ec *executionContext) marshalNID2githubᚗcomᚋdstotijnᚋhettyᚋpkgᚋapiᚐULID(ctx context.Context, sel ast.SelectionSet, v ULID) graphql.Marshaler {
return v
}
func (ec *executionContext) unmarshalNInt2int(ctx context.Context, v interface{}) (int, error) {
@ -4366,6 +4495,13 @@ func (ec *executionContext) marshalNProject2ᚕgithubᚗcomᚋdstotijnᚋhetty
}
wg.Wait()
for _, e := range ret {
if e == graphql.Null {
return graphql.Null
}
}
return ret
}
@ -4407,6 +4543,13 @@ func (ec *executionContext) marshalNScopeRule2ᚕgithubᚗcomᚋdstotijnᚋhetty
}
wg.Wait()
for _, e := range ret {
if e == graphql.Null {
return graphql.Null
}
}
return ret
}
@ -4504,6 +4647,13 @@ func (ec *executionContext) marshalN__Directive2ᚕgithubᚗcomᚋ99designsᚋgq
}
wg.Wait()
for _, e := range ret {
if e == graphql.Null {
return graphql.Null
}
}
return ret
}
@ -4577,6 +4727,13 @@ func (ec *executionContext) marshalN__DirectiveLocation2ᚕstringᚄ(ctx context
}
wg.Wait()
for _, e := range ret {
if e == graphql.Null {
return graphql.Null
}
}
return ret
}
@ -4626,6 +4783,13 @@ func (ec *executionContext) marshalN__InputValue2ᚕgithubᚗcomᚋ99designsᚋg
}
wg.Wait()
for _, e := range ret {
if e == graphql.Null {
return graphql.Null
}
}
return ret
}
@ -4667,6 +4831,13 @@ func (ec *executionContext) marshalN__Type2ᚕgithubᚗcomᚋ99designsᚋgqlgen
}
wg.Wait()
for _, e := range ret {
if e == graphql.Null {
return graphql.Null
}
}
return ret
}
@ -4846,6 +5017,13 @@ func (ec *executionContext) marshalO__EnumValue2ᚕgithubᚗcomᚋ99designsᚋgq
}
wg.Wait()
for _, e := range ret {
if e == graphql.Null {
return graphql.Null
}
}
return ret
}
@ -4886,6 +5064,13 @@ func (ec *executionContext) marshalO__Field2ᚕgithubᚗcomᚋ99designsᚋgqlgen
}
wg.Wait()
for _, e := range ret {
if e == graphql.Null {
return graphql.Null
}
}
return ret
}
@ -4926,6 +5111,13 @@ func (ec *executionContext) marshalO__InputValue2ᚕgithubᚗcomᚋ99designsᚋg
}
wg.Wait()
for _, e := range ret {
if e == graphql.Null {
return graphql.Null
}
}
return ret
}
@ -4973,6 +5165,13 @@ func (ec *executionContext) marshalO__Type2ᚕgithubᚗcomᚋ99designsᚋgqlgen
}
wg.Wait()
for _, e := range ret {
if e == graphql.Null {
return graphql.Null
}
}
return ret
}

31
pkg/api/models.go Normal file
View File

@ -0,0 +1,31 @@
package api
import (
"fmt"
"io"
"strconv"
"github.com/oklog/ulid"
)
type ULID ulid.ULID
func (u *ULID) UnmarshalGQL(v interface{}) (err error) {
str, ok := v.(string)
if !ok {
return fmt.Errorf("ulid must be a string")
}
id, err := ulid.Parse(str)
if err != nil {
return fmt.Errorf("failed to parse ULID: %w", err)
}
*u = ULID(id)
return nil
}
func (u ULID) MarshalGQL(w io.Writer) {
fmt.Fprint(w, strconv.Quote(ulid.ULID(u).String()))
}

View File

@ -27,7 +27,7 @@ type HTTPHeader struct {
}
type HTTPRequestLog struct {
ID int64 `json:"id"`
ID ULID `json:"id"`
URL string `json:"url"`
Method HTTPMethod `json:"method"`
Proto string `json:"proto"`
@ -48,7 +48,6 @@ type HTTPRequestLogFilterInput struct {
}
type HTTPResponseLog struct {
RequestID int64 `json:"requestId"`
Proto string `json:"proto"`
StatusCode int `json:"statusCode"`
StatusReason string `json:"statusReason"`
@ -57,6 +56,7 @@ type HTTPResponseLog struct {
}
type Project struct {
ID ULID `json:"id"`
Name string `json:"name"`
IsActive bool `json:"isActive"`
}

View File

@ -10,6 +10,7 @@ import (
"strings"
"github.com/99designs/gqlgen/graphql"
"github.com/oklog/ulid"
"github.com/vektah/gqlparser/v2/gqlerror"
"github.com/dstotijn/hetty/pkg/proj"
@ -19,9 +20,8 @@ import (
)
type Resolver struct {
RequestLogService *reqlog.Service
ProjectService proj.Service
ScopeService *scope.Scope
RequestLogService *reqlog.Service
}
type (
@ -54,8 +54,8 @@ func (r *queryResolver) HTTPRequestLogs(ctx context.Context) ([]HTTPRequestLog,
return logs, nil
}
func (r *queryResolver) HTTPRequestLog(ctx context.Context, id int64) (*HTTPRequestLog, error) {
log, err := r.RequestLogService.FindRequestLogByID(ctx, id)
func (r *queryResolver) HTTPRequestLog(ctx context.Context, id ULID) (*HTTPRequestLog, error) {
log, err := r.RequestLogService.FindRequestLogByID(ctx, ulid.ULID(id))
if errors.Is(err, reqlog.ErrRequestNotFound) {
return nil, nil
} else if err != nil {
@ -70,32 +70,32 @@ func (r *queryResolver) HTTPRequestLog(ctx context.Context, id int64) (*HTTPRequ
return &req, nil
}
func parseRequestLog(req reqlog.Request) (HTTPRequestLog, error) {
method := HTTPMethod(req.Request.Method)
func parseRequestLog(reqLog reqlog.RequestLog) (HTTPRequestLog, error) {
method := HTTPMethod(reqLog.Method)
if method != "" && !method.IsValid() {
return HTTPRequestLog{}, fmt.Errorf("request has invalid method: %v", method)
}
log := HTTPRequestLog{
ID: req.ID,
Proto: req.Request.Proto,
ID: ULID(reqLog.ID),
Proto: reqLog.Proto,
Method: method,
Timestamp: req.Timestamp,
Timestamp: ulid.Time(reqLog.ID.Time()),
}
if req.Request.URL != nil {
log.URL = req.Request.URL.String()
if reqLog.URL != nil {
log.URL = reqLog.URL.String()
}
if len(req.Body) > 0 {
reqBody := string(req.Body)
log.Body = &reqBody
if len(reqLog.Body) > 0 {
bodyStr := string(reqLog.Body)
log.Body = &bodyStr
}
if req.Request.Header != nil {
if reqLog.Header != nil {
log.Headers = make([]HTTPHeader, 0)
for key, values := range req.Request.Header {
for key, values := range reqLog.Header {
for _, value := range values {
log.Headers = append(log.Headers, HTTPHeader{
Key: key,
@ -105,27 +105,26 @@ func parseRequestLog(req reqlog.Request) (HTTPRequestLog, error) {
}
}
if req.Response != nil {
if reqLog.Response != nil {
log.Response = &HTTPResponseLog{
RequestID: req.Response.RequestID,
Proto: req.Response.Response.Proto,
StatusCode: req.Response.Response.StatusCode,
Proto: reqLog.Response.Proto,
StatusCode: reqLog.Response.StatusCode,
}
statusReasonSubs := strings.SplitN(req.Response.Response.Status, " ", 2)
statusReasonSubs := strings.SplitN(reqLog.Response.Status, " ", 2)
if len(statusReasonSubs) == 2 {
log.Response.StatusReason = statusReasonSubs[1]
}
if len(req.Response.Body) > 0 {
resBody := string(req.Response.Body)
log.Response.Body = &resBody
if len(reqLog.Response.Body) > 0 {
bodyStr := string(reqLog.Response.Body)
log.Response.Body = &bodyStr
}
if req.Response.Response.Header != nil {
if reqLog.Response.Header != nil {
log.Response.Headers = make([]HTTPHeader, 0)
for key, values := range req.Response.Response.Header {
for key, values := range reqLog.Response.Header {
for _, value := range values {
log.Response.Headers = append(log.Response.Headers, HTTPHeader{
Key: key,
@ -139,8 +138,8 @@ func parseRequestLog(req reqlog.Request) (HTTPRequestLog, error) {
return log, nil
}
func (r *mutationResolver) OpenProject(ctx context.Context, name string) (*Project, error) {
p, err := r.ProjectService.Open(ctx, name)
func (r *mutationResolver) CreateProject(ctx context.Context, name string) (*Project, error) {
p, err := r.ProjectService.CreateProject(ctx, name)
if errors.Is(err, proj.ErrInvalidName) {
return nil, gqlerror.Errorf("Project name must only contain alphanumeric or space chars.")
} else if err != nil {
@ -148,13 +147,29 @@ func (r *mutationResolver) OpenProject(ctx context.Context, name string) (*Proje
}
return &Project{
ID: ULID(p.ID),
Name: p.Name,
IsActive: p.IsActive,
IsActive: r.ProjectService.IsProjectActive(p.ID),
}, nil
}
func (r *mutationResolver) OpenProject(ctx context.Context, id ULID) (*Project, error) {
p, err := r.ProjectService.OpenProject(ctx, ulid.ULID(id))
if errors.Is(err, proj.ErrInvalidName) {
return nil, gqlerror.Errorf("Project name must only contain alphanumeric or space chars.")
} else if err != nil {
return nil, fmt.Errorf("could not open project: %w", err)
}
return &Project{
ID: ULID(p.ID),
Name: p.Name,
IsActive: r.ProjectService.IsProjectActive(p.ID),
}, nil
}
func (r *queryResolver) ActiveProject(ctx context.Context) (*Project, error) {
p, err := r.ProjectService.ActiveProject()
p, err := r.ProjectService.ActiveProject(ctx)
if errors.Is(err, proj.ErrNoProject) {
return nil, nil
} else if err != nil {
@ -162,13 +177,14 @@ func (r *queryResolver) ActiveProject(ctx context.Context) (*Project, error) {
}
return &Project{
ID: ULID(p.ID),
Name: p.Name,
IsActive: p.IsActive,
IsActive: r.ProjectService.IsProjectActive(p.ID),
}, nil
}
func (r *queryResolver) Projects(ctx context.Context) ([]Project, error) {
p, err := r.ProjectService.Projects()
p, err := r.ProjectService.Projects(ctx)
if err != nil {
return nil, fmt.Errorf("could not get projects: %w", err)
}
@ -176,8 +192,9 @@ func (r *queryResolver) Projects(ctx context.Context) ([]Project, error) {
projects := make([]Project, len(p))
for i, proj := range p {
projects[i] = Project{
ID: ULID(proj.ID),
Name: proj.Name,
IsActive: proj.IsActive,
IsActive: r.ProjectService.IsProjectActive(proj.ID),
}
}
@ -185,7 +202,7 @@ func (r *queryResolver) Projects(ctx context.Context) ([]Project, error) {
}
func (r *queryResolver) Scope(ctx context.Context) ([]ScopeRule, error) {
rules := r.ScopeService.Rules()
rules := r.ProjectService.Scope().Rules()
return scopeToScopeRules(rules), nil
}
@ -200,15 +217,15 @@ func regexpToStringPtr(r *regexp.Regexp) *string {
}
func (r *mutationResolver) CloseProject(ctx context.Context) (*CloseProjectResult, error) {
if err := r.ProjectService.Close(); err != nil {
if err := r.ProjectService.CloseProject(); err != nil {
return nil, fmt.Errorf("could not close project: %w", err)
}
return &CloseProjectResult{true}, nil
}
func (r *mutationResolver) DeleteProject(ctx context.Context, name string) (*DeleteProjectResult, error) {
if err := r.ProjectService.Delete(name); err != nil {
func (r *mutationResolver) DeleteProject(ctx context.Context, id ULID) (*DeleteProjectResult, error) {
if err := r.ProjectService.DeleteProject(ctx, ulid.ULID(id)); err != nil {
return nil, fmt.Errorf("could not delete project: %w", err)
}
@ -218,7 +235,14 @@ func (r *mutationResolver) DeleteProject(ctx context.Context, name string) (*Del
}
func (r *mutationResolver) ClearHTTPRequestLog(ctx context.Context) (*ClearHTTPRequestLogResult, error) {
if err := r.RequestLogService.ClearRequests(ctx); err != nil {
project, err := r.ProjectService.ActiveProject(ctx)
if errors.Is(err, proj.ErrNoProject) {
return nil, noActiveProjectErr(ctx)
} else if err != nil {
return nil, fmt.Errorf("could not get active project: %w", err)
}
if err := r.RequestLogService.ClearRequests(ctx, project.ID); err != nil {
return nil, fmt.Errorf("could not clear request log: %w", err)
}
@ -263,8 +287,9 @@ func (r *mutationResolver) SetScope(ctx context.Context, input []ScopeRuleInput)
}
}
if err := r.ScopeService.SetRules(ctx, rules); err != nil {
return nil, fmt.Errorf("could not set scope: %w", err)
err := r.ProjectService.SetScopeRules(ctx, rules)
if err != nil {
return nil, fmt.Errorf("could not set scope rules: %w", err)
}
return scopeToScopeRules(rules), nil
@ -283,7 +308,7 @@ func (r *mutationResolver) SetHTTPRequestLogFilter(
return nil, fmt.Errorf("could not parse request log filter: %w", err)
}
err = r.RequestLogService.SetRequestLogFilter(ctx, filter)
err = r.ProjectService.SetRequestLogFindFilter(ctx, filter)
if errors.Is(err, proj.ErrNoProject) {
return nil, noActiveProjectErr(ctx)
} else if err != nil {
@ -333,7 +358,6 @@ func findRequestsFilterFromInput(input *HTTPRequestLogFilterInput) (filter reqlo
return reqlog.FindRequestsFilter{}, fmt.Errorf("could not parse search query: %w", err)
}
filter.RawSearchExpr = *input.SearchExpression
filter.SearchExpr = expr
}
@ -350,8 +374,9 @@ func findReqFilterToHTTPReqLogFilter(findReqFilter reqlog.FindRequestsFilter) *H
OnlyInScope: findReqFilter.OnlyInScope,
}
if findReqFilter.RawSearchExpr != "" {
httpReqLogFilter.SearchExpression = &findReqFilter.RawSearchExpr
if findReqFilter.SearchExpr != nil {
searchExpr := findReqFilter.SearchExpr.String()
httpReqLogFilter.SearchExpression = &searchExpr
}
return httpReqLogFilter

View File

@ -10,7 +10,6 @@ type HttpRequestLog {
}
type HttpResponseLog {
requestId: ID!
proto: String!
statusCode: Int!
statusReason: String!
@ -24,6 +23,7 @@ type HttpHeader {
}
type Project {
id: ID!
name: String!
isActive: Boolean!
}
@ -82,9 +82,10 @@ type Query {
}
type Mutation {
openProject(name: String!): Project
createProject(name: String!): Project
openProject(id: ID!): Project
closeProject: CloseProjectResult!
deleteProject(name: String!): DeleteProjectResult!
deleteProject(id: ID!): DeleteProjectResult!
clearHTTPRequestLog: ClearHTTPRequestLogResult!
setScope(scope: [ScopeRuleInput!]!): [ScopeRule!]!
setHttpRequestLogFilter(

53
pkg/db/badger/badger.go Normal file
View File

@ -0,0 +1,53 @@
package badger
import (
"fmt"
"github.com/dgraph-io/badger/v3"
)
const (
// Key prefixes. Each prefix value should be unique.
projectPrefix = 0x00
reqLogPrefix = 0x01
resLogPrefix = 0x02
// Request log indices.
reqLogProjectIDIndex = 0x00
)
// Database is used to store and retrieve data from an underlying Badger database.
type Database struct {
badger *badger.DB
}
// OpenDatabase opens a new Badger database.
func OpenDatabase(opts badger.Options) (*Database, error) {
db, err := badger.Open(opts)
if err != nil {
return nil, fmt.Errorf("badger: failed to open database: %w", err)
}
return &Database{badger: db}, nil
}
// Close closes the underlying Badger database.
func (db *Database) Close() error {
return db.badger.Close()
}
// DatabaseFromBadgerDB returns a Database with `db` set as the underlying
// Badger database.
func DatabaseFromBadgerDB(db *badger.DB) *Database {
return &Database{badger: db}
}
func entryKey(prefix, index byte, value []byte) []byte {
// Key consists of: | prefix (byte) | index (byte) | value
key := make([]byte, 2+len(value))
key[0] = prefix
key[1] = index
copy(key[2:len(value)+2], value)
return key
}

110
pkg/db/badger/proj.go Normal file
View File

@ -0,0 +1,110 @@
package badger
import (
"bytes"
"context"
"encoding/gob"
"errors"
"fmt"
"github.com/dgraph-io/badger/v3"
"github.com/oklog/ulid"
"github.com/dstotijn/hetty/pkg/proj"
)
func (db *Database) UpsertProject(ctx context.Context, project proj.Project) error {
buf := bytes.Buffer{}
err := gob.NewEncoder(&buf).Encode(project)
if err != nil {
return fmt.Errorf("badger: failed to encode project: %w", err)
}
err = db.badger.Update(func(txn *badger.Txn) error {
return txn.Set(entryKey(projectPrefix, 0, project.ID[:]), buf.Bytes())
})
if err != nil {
return fmt.Errorf("badger: failed to commit transaction: %w", err)
}
return nil
}
func (db *Database) FindProjectByID(ctx context.Context, projectID ulid.ULID) (project proj.Project, err error) {
err = db.badger.View(func(txn *badger.Txn) error {
item, err := txn.Get(entryKey(projectPrefix, 0, projectID[:]))
if err != nil {
return err
}
err = item.Value(func(rawProject []byte) error {
return gob.NewDecoder(bytes.NewReader(rawProject)).Decode(&project)
})
if err != nil {
return fmt.Errorf("failed to retrieve or parse project: %w", err)
}
return nil
})
if errors.Is(err, badger.ErrKeyNotFound) {
return proj.Project{}, proj.ErrProjectNotFound
}
if err != nil {
return proj.Project{}, fmt.Errorf("badger: failed to commit transaction: %w", err)
}
return project, nil
}
func (db *Database) DeleteProject(ctx context.Context, projectID ulid.ULID) error {
err := db.ClearRequestLogs(ctx, projectID)
if err != nil {
return fmt.Errorf("badger: failed to delete project request logs: %w", err)
}
err = db.badger.Update(func(txn *badger.Txn) error {
return txn.Delete(entryKey(projectPrefix, 0, projectID[:]))
})
if err != nil {
return fmt.Errorf("badger: failed to delete project item: %w", err)
}
return nil
}
func (db *Database) Projects(ctx context.Context) ([]proj.Project, error) {
projects := make([]proj.Project, 0)
err := db.badger.View(func(txn *badger.Txn) error {
var rawProject []byte
prefix := entryKey(projectPrefix, 0, nil)
iterator := txn.NewIterator(badger.DefaultIteratorOptions)
defer iterator.Close()
for iterator.Seek(prefix); iterator.ValidForPrefix(prefix); iterator.Next() {
rawProject, err := iterator.Item().ValueCopy(rawProject)
if err != nil {
return fmt.Errorf("failed to copy value: %w", err)
}
var project proj.Project
err = gob.NewDecoder(bytes.NewReader(rawProject)).Decode(&project)
if err != nil {
return fmt.Errorf("failed to decode project: %w", err)
}
projects = append(projects, project)
}
return nil
})
if err != nil {
return nil, fmt.Errorf("badger: failed to commit transaction: %w", err)
}
return projects, nil
}

284
pkg/db/badger/proj_test.go Normal file
View File

@ -0,0 +1,284 @@
package badger
import (
"bytes"
"context"
"encoding/gob"
"errors"
"math/rand"
"regexp"
"testing"
"time"
badgerdb "github.com/dgraph-io/badger/v3"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/oklog/ulid"
"github.com/dstotijn/hetty/pkg/proj"
"github.com/dstotijn/hetty/pkg/scope"
"github.com/dstotijn/hetty/pkg/search"
)
//nolint:gosec
var ulidEntropy = rand.New(rand.NewSource(time.Now().UnixNano()))
var regexpCompareOpt = cmp.Comparer(func(x, y *regexp.Regexp) bool {
switch {
case x == nil && y == nil:
return true
case x == nil || y == nil:
return false
default:
return x.String() == y.String()
}
})
func TestUpsertProject(t *testing.T) {
t.Parallel()
badgerDB, err := badgerdb.Open(badgerdb.DefaultOptions("").WithInMemory(true))
if err != nil {
t.Fatalf("failed to open badger database: %v", err)
}
database := DatabaseFromBadgerDB(badgerDB)
defer database.Close()
searchExpr, err := search.ParseQuery("foo AND bar OR NOT baz")
if err != nil {
t.Fatalf("unexpected error (expected: nil, got: %v)", err)
}
exp := proj.Project{
ID: ulid.MustNew(ulid.Timestamp(time.Now()), ulidEntropy),
Name: "foobar",
Settings: proj.Settings{
ReqLogBypassOutOfScope: true,
ReqLogOnlyFindInScope: true,
ScopeRules: []scope.Rule{
{
URL: regexp.MustCompile("^https://(.*)example.com(.*)$"),
Header: scope.Header{
Key: regexp.MustCompile("^X-Foo(.*)$"),
Value: regexp.MustCompile("^foo(.*)$"),
},
Body: regexp.MustCompile("^foo(.*)"),
},
},
SearchExpr: searchExpr,
},
}
err = database.UpsertProject(context.Background(), exp)
if err != nil {
t.Fatalf("unexpected error storing project: %v", err)
}
var rawProject []byte
err = badgerDB.View(func(txn *badgerdb.Txn) error {
item, err := txn.Get(entryKey(projectPrefix, 0, exp.ID[:]))
if err != nil {
return err
}
rawProject, err = item.ValueCopy(nil)
return err
})
if err != nil {
t.Fatalf("unexpected error retrieving project from database: %v", err)
}
got := proj.Project{}
err = gob.NewDecoder(bytes.NewReader(rawProject)).Decode(&got)
if err != nil {
t.Fatalf("unexpected error decoding project: %v", err)
}
if diff := cmp.Diff(exp, got, regexpCompareOpt, cmpopts.IgnoreUnexported(proj.Project{})); diff != "" {
t.Fatalf("project not equal (-exp, +got):\n%v", diff)
}
}
func TestFindProjectByID(t *testing.T) {
t.Parallel()
t.Run("existing project", func(t *testing.T) {
t.Parallel()
badgerDB, err := badgerdb.Open(badgerdb.DefaultOptions("").WithInMemory(true))
if err != nil {
t.Fatalf("failed to open badger database: %v", err)
}
database := DatabaseFromBadgerDB(badgerDB)
defer database.Close()
exp := proj.Project{
ID: ulid.MustNew(ulid.Timestamp(time.Now()), ulidEntropy),
Name: "foobar",
Settings: proj.Settings{},
}
buf := bytes.Buffer{}
err = gob.NewEncoder(&buf).Encode(exp)
if err != nil {
t.Fatalf("unexpected error encoding project: %v", err)
}
err = badgerDB.Update(func(txn *badgerdb.Txn) error {
return txn.Set(entryKey(projectPrefix, 0, exp.ID[:]), buf.Bytes())
})
if err != nil {
t.Fatalf("unexpected error setting project: %v", err)
}
got, err := database.FindProjectByID(context.Background(), exp.ID)
if err != nil {
t.Fatalf("unexpected error finding project: %v", err)
}
if diff := cmp.Diff(exp, got, cmpopts.IgnoreUnexported(proj.Project{})); diff != "" {
t.Fatalf("project not equal (-exp, +got):\n%v", diff)
}
})
t.Run("project not found", func(t *testing.T) {
t.Parallel()
database, err := OpenDatabase(badgerdb.DefaultOptions("").WithInMemory(true))
if err != nil {
t.Fatalf("failed to open badger database: %v", err)
}
defer database.Close()
projectID := ulid.MustNew(ulid.Timestamp(time.Now()), ulidEntropy)
_, err = database.FindProjectByID(context.Background(), projectID)
if !errors.Is(err, proj.ErrProjectNotFound) {
t.Fatalf("expected `proj.ErrProjectNotFound`, got: %v", err)
}
})
}
func TestDeleteProject(t *testing.T) {
t.Parallel()
badgerDB, err := badgerdb.Open(badgerdb.DefaultOptions("").WithInMemory(true))
if err != nil {
t.Fatalf("failed to open badger database: %v", err)
}
database := DatabaseFromBadgerDB(badgerDB)
defer database.Close()
// Store fixtures.
projectID := ulid.MustNew(ulid.Timestamp(time.Now()), ulidEntropy)
reqLogID := ulid.MustNew(ulid.Timestamp(time.Now()), ulidEntropy)
err = badgerDB.Update(func(txn *badgerdb.Txn) error {
if err := txn.Set(entryKey(projectPrefix, 0, projectID[:]), nil); err != nil {
return err
}
if err := txn.Set(entryKey(reqLogPrefix, 0, reqLogID[:]), nil); err != nil {
return err
}
if err := txn.Set(entryKey(resLogPrefix, 0, reqLogID[:]), nil); err != nil {
return err
}
err := txn.Set(entryKey(reqLogPrefix, reqLogProjectIDIndex, append(projectID[:], reqLogID[:]...)), nil)
if err != nil {
return err
}
return nil
})
if err != nil {
t.Fatalf("unexpected error creating fixtures: %v", err)
}
err = database.DeleteProject(context.Background(), projectID)
if err != nil {
t.Fatalf("unexpected error deleting project: %v", err)
}
// Assert project key was deleted.
err = badgerDB.View(func(txn *badgerdb.Txn) error {
_, err := txn.Get(entryKey(projectPrefix, 0, projectID[:]))
return err
})
if !errors.Is(err, badgerdb.ErrKeyNotFound) {
t.Fatalf("expected `badger.ErrKeyNotFound`, got: %v", err)
}
// Assert request log item was deleted.
err = badgerDB.View(func(txn *badgerdb.Txn) error {
_, err := txn.Get(entryKey(reqLogPrefix, 0, reqLogID[:]))
return err
})
if !errors.Is(err, badgerdb.ErrKeyNotFound) {
t.Fatalf("expected `badger.ErrKeyNotFound`, got: %v", err)
}
// Assert response log item was deleted.
err = badgerDB.View(func(txn *badgerdb.Txn) error {
_, err := txn.Get(entryKey(resLogPrefix, 0, reqLogID[:]))
return err
})
if !errors.Is(err, badgerdb.ErrKeyNotFound) {
t.Fatalf("expected `badger.ErrKeyNotFound`, got: %v", err)
}
// Assert request log project ID index key was deleted.
err = badgerDB.View(func(txn *badgerdb.Txn) error {
_, err := txn.Get(entryKey(reqLogPrefix, reqLogProjectIDIndex, append(projectID[:], reqLogID[:]...)))
return err
})
if !errors.Is(err, badgerdb.ErrKeyNotFound) {
t.Fatalf("expected `badger.ErrKeyNotFound`, got: %v", err)
}
}
func TestProjects(t *testing.T) {
t.Parallel()
database, err := OpenDatabase(badgerdb.DefaultOptions("").WithInMemory(true))
if err != nil {
t.Fatalf("failed to open badger database: %v", err)
}
defer database.Close()
exp := []proj.Project{
{
ID: ulid.MustNew(ulid.Timestamp(time.Now()), ulidEntropy),
Name: "one",
},
{
ID: ulid.MustNew(ulid.Timestamp(time.Now())+100, ulidEntropy),
Name: "two",
},
}
// Store fixtures.
for _, project := range exp {
err = database.UpsertProject(context.Background(), project)
if err != nil {
t.Fatalf("unexpected error creating project fixture: %v", err)
}
}
got, err := database.Projects(context.Background())
if err != nil {
t.Fatalf("unexpected error finding projects: %v", err)
}
if len(exp) != len(got) {
t.Fatalf("expected %v projects, got: %v", len(exp), len(got))
}
if diff := cmp.Diff(exp, got, cmpopts.IgnoreUnexported(proj.Project{})); diff != "" {
t.Fatalf("projects not equal (-exp, +got):\n%v", diff)
}
}

251
pkg/db/badger/reqlog.go Normal file
View File

@ -0,0 +1,251 @@
package badger
import (
"bytes"
"context"
"encoding/gob"
"errors"
"fmt"
"github.com/dgraph-io/badger/v3"
"github.com/oklog/ulid"
"github.com/dstotijn/hetty/pkg/reqlog"
"github.com/dstotijn/hetty/pkg/scope"
)
func (db *Database) FindRequestLogs(ctx context.Context, filter reqlog.FindRequestsFilter, scope *scope.Scope) ([]reqlog.RequestLog, error) {
if filter.ProjectID.Compare(ulid.ULID{}) == 0 {
return nil, reqlog.ErrProjectIDMustBeSet
}
txn := db.badger.NewTransaction(false)
defer txn.Discard()
reqLogIDs, err := findRequestLogIDsByProjectID(txn, filter.ProjectID)
if err != nil {
return nil, fmt.Errorf("badger: failed to find request log IDs: %w", err)
}
reqLogs := make([]reqlog.RequestLog, 0, len(reqLogIDs))
for _, reqLogID := range reqLogIDs {
reqLog, err := getRequestLogWithResponse(txn, reqLogID)
if err != nil {
return nil, fmt.Errorf("badger: failed to get request log (id: %v): %w", reqLogID.String(), err)
}
if filter.OnlyInScope {
if !reqLog.MatchScope(scope) {
continue
}
}
// Filter by search expression.
// TODO: Once pagination is introduced, this filter logic should be done
// as items are retrieved (e.g. when using a `badger.Iterator`).
if filter.SearchExpr != nil {
match, err := reqLog.Matches(filter.SearchExpr)
if err != nil {
return nil, fmt.Errorf(
"badger: failed to match search expression for request log (id: %v): %w",
reqLogID.String(), err,
)
}
if !match {
continue
}
}
reqLogs = append(reqLogs, reqLog)
}
return reqLogs, nil
}
func getRequestLogWithResponse(txn *badger.Txn, reqLogID ulid.ULID) (reqlog.RequestLog, error) {
item, err := txn.Get(entryKey(reqLogPrefix, 0, reqLogID[:]))
if err != nil {
return reqlog.RequestLog{}, fmt.Errorf("failed to lookup request log item: %w", err)
}
reqLog := reqlog.RequestLog{
ID: reqLogID,
}
err = item.Value(func(rawReqLog []byte) error {
err = gob.NewDecoder(bytes.NewReader(rawReqLog)).Decode(&reqLog)
if err != nil {
return fmt.Errorf("failed to decode request log: %w", err)
}
return nil
})
if err != nil {
return reqlog.RequestLog{}, fmt.Errorf("failed to retrieve or parse request log value: %w", err)
}
item, err = txn.Get(entryKey(resLogPrefix, 0, reqLogID[:]))
if errors.Is(err, badger.ErrKeyNotFound) {
return reqLog, nil
}
if err != nil {
return reqlog.RequestLog{}, fmt.Errorf("failed to get response log: %w", err)
}
err = item.Value(func(rawReslog []byte) error {
var resLog reqlog.ResponseLog
err = gob.NewDecoder(bytes.NewReader(rawReslog)).Decode(&resLog)
if err != nil {
return fmt.Errorf("failed to decode response log: %w", err)
}
reqLog.Response = &resLog
return nil
})
if err != nil {
return reqlog.RequestLog{}, fmt.Errorf("failed to retrieve or parse response log value: %w", err)
}
return reqLog, nil
}
func (db *Database) FindRequestLogByID(ctx context.Context, reqLogID ulid.ULID) (reqLog reqlog.RequestLog, err error) {
txn := db.badger.NewTransaction(false)
defer txn.Discard()
reqLog, err = getRequestLogWithResponse(txn, reqLogID)
if err != nil {
return reqlog.RequestLog{}, fmt.Errorf("badger: failed to get request log: %w", err)
}
return reqLog, nil
}
func (db *Database) StoreRequestLog(ctx context.Context, reqLog reqlog.RequestLog) error {
buf := bytes.Buffer{}
err := gob.NewEncoder(&buf).Encode(reqLog)
if err != nil {
return fmt.Errorf("badger: failed to encode request log: %w", err)
}
entries := []*badger.Entry{
// Request log itself.
{
Key: entryKey(reqLogPrefix, 0, reqLog.ID[:]),
Value: buf.Bytes(),
},
// Index by project ID.
{
Key: entryKey(reqLogPrefix, reqLogProjectIDIndex, append(reqLog.ProjectID[:], reqLog.ID[:]...)),
},
}
err = db.badger.Update(func(txn *badger.Txn) error {
for i := range entries {
err := txn.SetEntry(entries[i])
if err != nil {
return err
}
}
return nil
})
if err != nil {
return fmt.Errorf("badger: failed to commit transaction: %w", err)
}
return nil
}
func (db *Database) StoreResponseLog(ctx context.Context, reqLogID ulid.ULID, resLog reqlog.ResponseLog) error {
buf := bytes.Buffer{}
err := gob.NewEncoder(&buf).Encode(resLog)
if err != nil {
return fmt.Errorf("badger: failed to encode response log: %w", err)
}
err = db.badger.Update(func(txn *badger.Txn) error {
return txn.SetEntry(&badger.Entry{
Key: entryKey(resLogPrefix, 0, reqLogID[:]),
Value: buf.Bytes(),
})
})
if err != nil {
return fmt.Errorf("badger: failed to commit transaction: %w", err)
}
return nil
}
func (db *Database) ClearRequestLogs(ctx context.Context, projectID ulid.ULID) error {
// Note: this transaction is used just for reading; we use the `badger.WriteBatch`
// API to bulk delete items.
txn := db.badger.NewTransaction(false)
defer txn.Discard()
reqLogIDs, err := findRequestLogIDsByProjectID(txn, projectID)
if err != nil {
return fmt.Errorf("badger: failed to find request log IDs: %w", err)
}
writeBatch := db.badger.NewWriteBatch()
defer writeBatch.Cancel()
for _, reqLogID := range reqLogIDs {
// Delete request logs.
err := writeBatch.Delete(entryKey(reqLogPrefix, 0, reqLogID[:]))
if err != nil {
return fmt.Errorf("badger: failed to delete request log: %w", err)
}
// Delete related response log.
err = writeBatch.Delete(entryKey(resLogPrefix, 0, reqLogID[:]))
if err != nil {
return fmt.Errorf("badger: failed to delete request log: %w", err)
}
}
if err := writeBatch.Flush(); err != nil {
return fmt.Errorf("badger: failed to commit batch write: %w", err)
}
err = db.badger.DropPrefix(entryKey(reqLogPrefix, reqLogProjectIDIndex, projectID[:]))
if err != nil {
return fmt.Errorf("badger: failed to drop request log project ID index items: %w", err)
}
return nil
}
func findRequestLogIDsByProjectID(txn *badger.Txn, projectID ulid.ULID) ([]ulid.ULID, error) {
reqLogIDs := make([]ulid.ULID, 0)
opts := badger.DefaultIteratorOptions
opts.PrefetchValues = false
iterator := txn.NewIterator(opts)
defer iterator.Close()
var projectIndexKey []byte
prefix := entryKey(reqLogPrefix, reqLogProjectIDIndex, projectID[:])
for iterator.Seek(prefix); iterator.ValidForPrefix(prefix); iterator.Next() {
projectIndexKey = iterator.Item().KeyCopy(projectIndexKey)
var id ulid.ULID
// The request log ID starts *after* the first 2 prefix and index bytes
// and the 16 byte project ID.
if err := id.UnmarshalBinary(projectIndexKey[18:]); err != nil {
return nil, fmt.Errorf("failed to parse request log ID: %w", err)
}
reqLogIDs = append(reqLogIDs, id)
}
return reqLogIDs, nil
}

View File

@ -0,0 +1,121 @@
package badger
import (
"context"
"errors"
"net/http"
"net/url"
"testing"
"time"
badgerdb "github.com/dgraph-io/badger/v3"
"github.com/google/go-cmp/cmp"
"github.com/oklog/ulid"
"github.com/dstotijn/hetty/pkg/reqlog"
)
func TestFindRequestLogs(t *testing.T) {
t.Parallel()
t.Run("without project ID in filter", func(t *testing.T) {
t.Parallel()
database, err := OpenDatabase(badgerdb.DefaultOptions("").WithInMemory(true))
if err != nil {
t.Fatalf("failed to open badger database: %v", err)
}
defer database.Close()
filter := reqlog.FindRequestsFilter{}
_, err = database.FindRequestLogs(context.Background(), filter, nil)
if !errors.Is(err, reqlog.ErrProjectIDMustBeSet) {
t.Fatalf("expected `reqlog.ErrProjectIDMustBeSet`, got: %v", err)
}
})
t.Run("returns request logs and related response logs", func(t *testing.T) {
t.Parallel()
database, err := OpenDatabase(badgerdb.DefaultOptions("").WithInMemory(true))
if err != nil {
t.Fatalf("failed to open badger database: %v", err)
}
defer database.Close()
projectID := ulid.MustNew(ulid.Timestamp(time.Now()), ulidEntropy)
exp := []reqlog.RequestLog{
{
ID: ulid.MustNew(ulid.Timestamp(time.Now()), ulidEntropy),
ProjectID: projectID,
URL: mustParseURL(t, "https://example.com/foobar"),
Method: http.MethodPost,
Proto: "HTTP/1.1",
Header: http.Header{
"X-Foo": []string{"baz"},
},
Body: []byte("foo"),
Response: &reqlog.ResponseLog{
Proto: "HTTP/1.1",
Status: "200 OK",
StatusCode: 200,
Header: http.Header{
"X-Yolo": []string{"swag"},
},
Body: []byte("bar"),
},
},
{
ID: ulid.MustNew(ulid.Timestamp(time.Now())+100, ulidEntropy),
ProjectID: projectID,
URL: mustParseURL(t, "https://example.com/foo?bar=baz"),
Method: http.MethodGet,
Proto: "HTTP/1.1",
Header: http.Header{
"X-Foo": []string{"baz"},
},
},
}
// Store fixtures.
for _, reqLog := range exp {
err = database.StoreRequestLog(context.Background(), reqLog)
if err != nil {
t.Fatalf("unexpected error creating request log fixture: %v", err)
}
if reqLog.Response != nil {
err = database.StoreResponseLog(context.Background(), reqLog.ID, *reqLog.Response)
if err != nil {
t.Fatalf("unexpected error creating response log fixture: %v", err)
}
}
}
filter := reqlog.FindRequestsFilter{
ProjectID: projectID,
}
got, err := database.FindRequestLogs(context.Background(), filter, nil)
if err != nil {
t.Fatalf("unexpected error finding request logs: %v", err)
}
if diff := cmp.Diff(exp, got); diff != "" {
t.Fatalf("request logs not equal (-exp, +got):\n%v", diff)
}
})
}
func mustParseURL(t *testing.T, s string) *url.URL {
t.Helper()
u, err := url.Parse(s)
if err != nil {
panic(err)
}
return u
}

View File

@ -1,82 +0,0 @@
package sqlite
import (
"database/sql"
"errors"
"fmt"
"net/http"
"net/url"
"strconv"
"time"
"github.com/dstotijn/hetty/pkg/reqlog"
)
type reqURL url.URL
type httpRequest struct {
ID int64 `db:"req_id"`
Proto string `db:"req_proto"`
URL reqURL `db:"url"`
Method string `db:"method"`
Body []byte `db:"req_body"`
Timestamp time.Time `db:"req_timestamp"`
httpResponse
}
type httpResponse struct {
ID sql.NullInt64 `db:"res_id"`
RequestID sql.NullInt64 `db:"res_req_id"`
Proto sql.NullString `db:"res_proto"`
StatusCode sql.NullInt64 `db:"status_code"`
StatusReason sql.NullString `db:"status_reason"`
Body []byte `db:"res_body"`
Timestamp sql.NullTime `db:"res_timestamp"`
}
// Value implements driver.Valuer.
func (u *reqURL) Scan(value interface{}) error {
rawURL, ok := value.(string)
if !ok {
return errors.New("sqlite: cannot scan non-string value")
}
parsed, err := url.Parse(rawURL)
if err != nil {
return fmt.Errorf("sqlite: could not parse URL: %w", err)
}
*u = reqURL(*parsed)
return nil
}
func (dto httpRequest) toRequestLog() reqlog.Request {
u := url.URL(dto.URL)
reqLog := reqlog.Request{
ID: dto.ID,
Request: http.Request{
Proto: dto.Proto,
Method: dto.Method,
URL: &u,
},
Body: dto.Body,
Timestamp: dto.Timestamp,
}
if dto.httpResponse.ID.Valid {
reqLog.Response = &reqlog.Response{
ID: dto.httpResponse.ID.Int64,
RequestID: dto.httpResponse.RequestID.Int64,
Response: http.Response{
Status: strconv.FormatInt(dto.StatusCode.Int64, 10) + " " + dto.StatusReason.String,
StatusCode: int(dto.StatusCode.Int64),
Proto: dto.httpResponse.Proto.String,
},
Body: dto.httpResponse.Body,
Timestamp: dto.httpResponse.Timestamp.Time,
}
}
return reqLog
}

View File

@ -1,135 +0,0 @@
package sqlite
import (
"errors"
"fmt"
"sort"
sq "github.com/Masterminds/squirrel"
"github.com/dstotijn/hetty/pkg/search"
)
var stringLiteralMap = map[string]string{
// http_requests
"req.id": "req.id",
"req.proto": "req.proto",
"req.url": "req.url",
"req.method": "req.method",
"req.body": "req.body",
"req.timestamp": "req.timestamp",
// http_responses
"res.id": "res.id",
"res.proto": "res.proto",
"res.statusCode": "res.status_code",
"res.statusReason": "res.status_reason",
"res.body": "res.body",
"res.timestamp": "res.timestamp",
// TODO: http_headers
}
func parseSearchExpr(expr search.Expression) (sq.Sqlizer, error) {
switch e := expr.(type) {
case *search.PrefixExpression:
return parsePrefixExpr(e)
case *search.InfixExpression:
return parseInfixExpr(e)
case *search.StringLiteral:
return parseStringLiteral(e)
default:
return nil, fmt.Errorf("expression type (%v) not supported", expr)
}
}
func parsePrefixExpr(expr *search.PrefixExpression) (sq.Sqlizer, error) {
switch expr.Operator {
case search.TokOpNot:
// TODO: Find a way to prefix an `sq.Sqlizer` with "NOT".
return nil, errors.New("not implemented")
default:
return nil, errors.New("operator is not supported")
}
}
func parseInfixExpr(expr *search.InfixExpression) (sq.Sqlizer, error) {
switch expr.Operator {
case search.TokOpAnd:
left, err := parseSearchExpr(expr.Left)
if err != nil {
return nil, err
}
right, err := parseSearchExpr(expr.Right)
if err != nil {
return nil, err
}
return sq.And{left, right}, nil
case search.TokOpOr:
left, err := parseSearchExpr(expr.Left)
if err != nil {
return nil, err
}
right, err := parseSearchExpr(expr.Right)
if err != nil {
return nil, err
}
return sq.Or{left, right}, nil
}
left, ok := expr.Left.(*search.StringLiteral)
if !ok {
return nil, errors.New("left operand must be a string literal")
}
right, ok := expr.Right.(*search.StringLiteral)
if !ok {
return nil, errors.New("right operand must be a string literal")
}
mappedLeft, ok := stringLiteralMap[left.Value]
if !ok {
return nil, fmt.Errorf("invalid string literal: %v", left)
}
switch expr.Operator {
case search.TokOpEq:
return sq.Eq{mappedLeft: right.Value}, nil
case search.TokOpNotEq:
return sq.NotEq{mappedLeft: right.Value}, nil
case search.TokOpGt:
return sq.Gt{mappedLeft: right.Value}, nil
case search.TokOpLt:
return sq.Lt{mappedLeft: right.Value}, nil
case search.TokOpGtEq:
return sq.GtOrEq{mappedLeft: right.Value}, nil
case search.TokOpLtEq:
return sq.LtOrEq{mappedLeft: right.Value}, nil
case search.TokOpRe:
return sq.Expr(fmt.Sprintf("regexp(?, %v)", mappedLeft), right.Value), nil
case search.TokOpNotRe:
return sq.Expr(fmt.Sprintf("NOT regexp(?, %v)", mappedLeft), right.Value), nil
default:
return nil, errors.New("unsupported operator")
}
}
func parseStringLiteral(strLiteral *search.StringLiteral) (sq.Sqlizer, error) {
// Sorting is not necessary, but makes it easier to do assertions in tests.
sortedKeys := make([]string, 0, len(stringLiteralMap))
for _, v := range stringLiteralMap {
sortedKeys = append(sortedKeys, v)
}
sort.Strings(sortedKeys)
or := make(sq.Or, len(stringLiteralMap))
for i, value := range sortedKeys {
or[i] = sq.Like{value: "%" + strLiteral.Value + "%"}
}
return or, nil
}

View File

@ -1,221 +0,0 @@
package sqlite
import (
"reflect"
"testing"
sq "github.com/Masterminds/squirrel"
"github.com/dstotijn/hetty/pkg/search"
)
func TestParseSearchExpr(t *testing.T) {
t.Parallel()
tests := []struct {
name string
searchExpr search.Expression
expectedSqlizer sq.Sqlizer
expectedError error
}{
{
name: "req.body = bar",
searchExpr: &search.InfixExpression{
Operator: search.TokOpEq,
Left: &search.StringLiteral{Value: "req.body"},
Right: &search.StringLiteral{Value: "bar"},
},
expectedSqlizer: sq.Eq{"req.body": "bar"},
expectedError: nil,
},
{
name: "req.body != bar",
searchExpr: &search.InfixExpression{
Operator: search.TokOpNotEq,
Left: &search.StringLiteral{Value: "req.body"},
Right: &search.StringLiteral{Value: "bar"},
},
expectedSqlizer: sq.NotEq{"req.body": "bar"},
expectedError: nil,
},
{
name: "req.body > bar",
searchExpr: &search.InfixExpression{
Operator: search.TokOpGt,
Left: &search.StringLiteral{Value: "req.body"},
Right: &search.StringLiteral{Value: "bar"},
},
expectedSqlizer: sq.Gt{"req.body": "bar"},
expectedError: nil,
},
{
name: "req.body < bar",
searchExpr: &search.InfixExpression{
Operator: search.TokOpLt,
Left: &search.StringLiteral{Value: "req.body"},
Right: &search.StringLiteral{Value: "bar"},
},
expectedSqlizer: sq.Lt{"req.body": "bar"},
expectedError: nil,
},
{
name: "req.body >= bar",
searchExpr: &search.InfixExpression{
Operator: search.TokOpGtEq,
Left: &search.StringLiteral{Value: "req.body"},
Right: &search.StringLiteral{Value: "bar"},
},
expectedSqlizer: sq.GtOrEq{"req.body": "bar"},
expectedError: nil,
},
{
name: "req.body <= bar",
searchExpr: &search.InfixExpression{
Operator: search.TokOpLtEq,
Left: &search.StringLiteral{Value: "req.body"},
Right: &search.StringLiteral{Value: "bar"},
},
expectedSqlizer: sq.LtOrEq{"req.body": "bar"},
expectedError: nil,
},
{
name: "req.body =~ bar",
searchExpr: &search.InfixExpression{
Operator: search.TokOpRe,
Left: &search.StringLiteral{Value: "req.body"},
Right: &search.StringLiteral{Value: "bar"},
},
expectedSqlizer: sq.Expr("regexp(?, req.body)", "bar"),
expectedError: nil,
},
{
name: "req.body !~ bar",
searchExpr: &search.InfixExpression{
Operator: search.TokOpNotRe,
Left: &search.StringLiteral{Value: "req.body"},
Right: &search.StringLiteral{Value: "bar"},
},
expectedSqlizer: sq.Expr("NOT regexp(?, req.body)", "bar"),
expectedError: nil,
},
{
name: "req.body = bar AND res.body = yolo",
searchExpr: &search.InfixExpression{
Operator: search.TokOpAnd,
Left: &search.InfixExpression{
Operator: search.TokOpEq,
Left: &search.StringLiteral{Value: "req.body"},
Right: &search.StringLiteral{Value: "bar"},
},
Right: &search.InfixExpression{
Operator: search.TokOpEq,
Left: &search.StringLiteral{Value: "res.body"},
Right: &search.StringLiteral{Value: "yolo"},
},
},
expectedSqlizer: sq.And{
sq.Eq{"req.body": "bar"},
sq.Eq{"res.body": "yolo"},
},
expectedError: nil,
},
{
name: "req.body = bar AND res.body = yolo AND req.method = POST",
searchExpr: &search.InfixExpression{
Operator: search.TokOpAnd,
Left: &search.InfixExpression{
Operator: search.TokOpEq,
Left: &search.StringLiteral{Value: "req.body"},
Right: &search.StringLiteral{Value: "bar"},
},
Right: &search.InfixExpression{
Operator: search.TokOpAnd,
Left: &search.InfixExpression{
Operator: search.TokOpEq,
Left: &search.StringLiteral{Value: "res.body"},
Right: &search.StringLiteral{Value: "yolo"},
},
Right: &search.InfixExpression{
Operator: search.TokOpEq,
Left: &search.StringLiteral{Value: "req.method"},
Right: &search.StringLiteral{Value: "POST"},
},
},
},
expectedSqlizer: sq.And{
sq.Eq{"req.body": "bar"},
sq.And{
sq.Eq{"res.body": "yolo"},
sq.Eq{"req.method": "POST"},
},
},
expectedError: nil,
},
{
name: "req.body = bar OR res.body = yolo",
searchExpr: &search.InfixExpression{
Operator: search.TokOpOr,
Left: &search.InfixExpression{
Operator: search.TokOpEq,
Left: &search.StringLiteral{Value: "req.body"},
Right: &search.StringLiteral{Value: "bar"},
},
Right: &search.InfixExpression{
Operator: search.TokOpEq,
Left: &search.StringLiteral{Value: "res.body"},
Right: &search.StringLiteral{Value: "yolo"},
},
},
expectedSqlizer: sq.Or{
sq.Eq{"req.body": "bar"},
sq.Eq{"res.body": "yolo"},
},
expectedError: nil,
},
{
name: "foo",
searchExpr: &search.StringLiteral{
Value: "foo",
},
expectedSqlizer: sq.Or{
sq.Like{"req.body": "%foo%"},
sq.Like{"req.id": "%foo%"},
sq.Like{"req.method": "%foo%"},
sq.Like{"req.proto": "%foo%"},
sq.Like{"req.timestamp": "%foo%"},
sq.Like{"req.url": "%foo%"},
sq.Like{"res.body": "%foo%"},
sq.Like{"res.id": "%foo%"},
sq.Like{"res.proto": "%foo%"},
sq.Like{"res.status_code": "%foo%"},
sq.Like{"res.status_reason": "%foo%"},
sq.Like{"res.timestamp": "%foo%"},
},
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got, err := parseSearchExpr(tt.searchExpr)
assertError(t, tt.expectedError, err)
if !reflect.DeepEqual(tt.expectedSqlizer, got) {
t.Errorf("expected: %#v, got: %#v", tt.expectedSqlizer, got)
}
})
}
}
func assertError(t *testing.T, exp, got error) {
t.Helper()
switch {
case exp == nil && got != nil:
t.Fatalf("expected: nil, got: %v", got)
case exp != nil && got == nil:
t.Fatalf("expected: %v, got: nil", exp.Error())
case exp != nil && got != nil && exp.Error() != got.Error():
t.Fatalf("expected: %v, got: %v", exp.Error(), got.Error())
}
}

View File

@ -1,709 +0,0 @@
package sqlite
import (
"context"
"database/sql"
"encoding/json"
"errors"
"fmt"
"io/ioutil"
"net/http"
"net/url"
"os"
"path/filepath"
"regexp"
"strings"
"time"
"github.com/99designs/gqlgen/graphql"
sq "github.com/Masterminds/squirrel"
"github.com/jmoiron/sqlx"
"github.com/mattn/go-sqlite3"
"github.com/dstotijn/hetty/pkg/proj"
"github.com/dstotijn/hetty/pkg/reqlog"
"github.com/dstotijn/hetty/pkg/scope"
)
var regexpFn = func(pattern string, value interface{}) (bool, error) {
switch v := value.(type) {
case string:
return regexp.MatchString(pattern, v)
case int64:
return regexp.MatchString(pattern, fmt.Sprintf("%v", v))
case []byte:
return regexp.Match(pattern, v)
default:
return false, fmt.Errorf("unsupported type %T", v)
}
}
// Client implements reqlog.Repository.
type Client struct {
db *sqlx.DB
dbPath string
activeProject string
}
type httpRequestLogsQuery struct {
requestCols []string
requestHeaderCols []string
responseHeaderCols []string
joinResponse bool
}
func init() {
sql.Register("sqlite3_with_regexp", &sqlite3.SQLiteDriver{
ConnectHook: func(conn *sqlite3.SQLiteConn) error {
return conn.RegisterFunc("regexp", regexpFn, false)
},
})
}
func New(dbPath string) (*Client, error) {
if _, err := os.Stat(dbPath); os.IsNotExist(err) {
if err := os.MkdirAll(dbPath, 0755); err != nil {
return nil, fmt.Errorf("proj: could not create project directory: %w", err)
}
}
return &Client{
dbPath: dbPath,
}, nil
}
// OpenProject opens a project database.
func (c *Client) OpenProject(name string) error {
if c.db != nil {
return errors.New("sqlite: there is already a project open")
}
opts := make(url.Values)
opts.Set("_foreign_keys", "1")
dbPath := filepath.Join(c.dbPath, name+".db")
dsn := fmt.Sprintf("file:%v?%v", dbPath, opts.Encode())
db, err := sqlx.Open("sqlite3_with_regexp", dsn)
if err != nil {
return fmt.Errorf("sqlite: could not open database: %w", err)
}
if err := db.Ping(); err != nil {
return fmt.Errorf("sqlite: could not ping database: %w", err)
}
if err := prepareSchema(db); err != nil {
return fmt.Errorf("sqlite: could not prepare schema: %w", err)
}
c.db = db
c.activeProject = name
return nil
}
func (c *Client) Projects() ([]proj.Project, error) {
files, err := ioutil.ReadDir(c.dbPath)
if err != nil {
return nil, fmt.Errorf("sqlite: could not read projects directory: %w", err)
}
projects := make([]proj.Project, len(files))
for i, file := range files {
projName := strings.TrimSuffix(file.Name(), ".db")
projects[i] = proj.Project{
Name: projName,
IsActive: c.activeProject == projName,
}
}
return projects, nil
}
func prepareSchema(db *sqlx.DB) error {
_, err := db.Exec(`CREATE TABLE IF NOT EXISTS http_requests (
id INTEGER PRIMARY KEY,
proto TEXT,
url TEXT,
method TEXT,
body BLOB,
timestamp DATETIME
)`)
if err != nil {
return fmt.Errorf("could not create http_requests table: %w", err)
}
_, err = db.Exec(`CREATE TABLE IF NOT EXISTS http_responses (
id INTEGER PRIMARY KEY,
req_id INTEGER REFERENCES http_requests(id) ON DELETE CASCADE,
proto TEXT,
status_code INTEGER,
status_reason TEXT,
body BLOB,
timestamp DATETIME
)`)
if err != nil {
return fmt.Errorf("could not create http_responses table: %w", err)
}
_, err = db.Exec(`CREATE TABLE IF NOT EXISTS http_headers (
id INTEGER PRIMARY KEY,
req_id INTEGER REFERENCES http_requests(id) ON DELETE CASCADE,
res_id INTEGER REFERENCES http_responses(id) ON DELETE CASCADE,
key TEXT,
value TEXT
)`)
if err != nil {
return fmt.Errorf("could not create http_headers table: %w", err)
}
_, err = db.Exec(`CREATE TABLE IF NOT EXISTS settings (
module TEXT PRIMARY KEY,
settings TEXT
)`)
if err != nil {
return fmt.Errorf("could not create settings table: %w", err)
}
return nil
}
// Close uses the underlying database if it's open.
func (c *Client) Close() error {
if c.db == nil {
return nil
}
if err := c.db.Close(); err != nil {
return fmt.Errorf("sqlite: could not close database: %w", err)
}
c.db = nil
c.activeProject = ""
return nil
}
func (c *Client) DeleteProject(name string) error {
if err := os.Remove(filepath.Join(c.dbPath, name+".db")); err != nil {
return fmt.Errorf("sqlite: could not remove database file: %w", err)
}
return nil
}
var reqFieldToColumnMap = map[string]string{
"proto": "proto AS req_proto",
"url": "url",
"method": "method",
"body": "body AS req_body",
"timestamp": "timestamp AS req_timestamp",
}
var resFieldToColumnMap = map[string]string{
"requestId": "req_id AS res_req_id",
"proto": "proto AS res_proto",
"statusCode": "status_code",
"statusReason": "status_reason",
"body": "body AS res_body",
"timestamp": "timestamp AS res_timestamp",
}
var headerFieldToColumnMap = map[string]string{
"key": "key",
"value": "value",
}
func (c *Client) ClearRequestLogs(ctx context.Context) error {
if c.db == nil {
return proj.ErrNoProject
}
_, err := c.db.Exec("DELETE FROM http_requests")
if err != nil {
return fmt.Errorf("sqlite: could not delete requests: %w", err)
}
return nil
}
func (c *Client) FindRequestLogs(
ctx context.Context,
filter reqlog.FindRequestsFilter,
scope *scope.Scope,
) (reqLogs []reqlog.Request, err error) {
if c.db == nil {
return nil, proj.ErrNoProject
}
httpReqLogsQuery := parseHTTPRequestLogsQuery(ctx)
reqQuery := sq.
Select(httpReqLogsQuery.requestCols...).
From("http_requests req").
OrderBy("req.id DESC")
if httpReqLogsQuery.joinResponse {
reqQuery = reqQuery.LeftJoin("http_responses res ON req.id = res.req_id")
}
if filter.OnlyInScope && scope != nil {
var ruleExpr []sq.Sqlizer
for _, rule := range scope.Rules() {
if rule.URL != nil {
ruleExpr = append(ruleExpr, sq.Expr("regexp(?, req.url)", rule.URL.String()))
}
}
if len(ruleExpr) > 0 {
reqQuery = reqQuery.Where(sq.Or(ruleExpr))
}
}
if filter.SearchExpr != nil {
sqlizer, err := parseSearchExpr(filter.SearchExpr)
if err != nil {
return nil, fmt.Errorf("sqlite: could not parse search expression: %w", err)
}
reqQuery = reqQuery.Where(sqlizer)
}
sql, args, err := reqQuery.ToSql()
if err != nil {
return nil, fmt.Errorf("sqlite: could not parse query: %w", err)
}
rows, err := c.db.QueryxContext(ctx, sql, args...)
if err != nil {
return nil, fmt.Errorf("sqlite: could not execute query: %w", err)
}
defer rows.Close()
for rows.Next() {
var dto httpRequest
err = rows.StructScan(&dto)
if err != nil {
return nil, fmt.Errorf("sqlite: could not scan row: %w", err)
}
reqLogs = append(reqLogs, dto.toRequestLog())
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("sqlite: could not iterate over rows: %w", err)
}
defer rows.Close()
if err := c.queryHeaders(ctx, httpReqLogsQuery, reqLogs); err != nil {
return nil, fmt.Errorf("sqlite: could not query headers: %w", err)
}
return reqLogs, nil
}
func (c *Client) FindRequestLogByID(ctx context.Context, id int64) (reqlog.Request, error) {
if c.db == nil {
return reqlog.Request{}, proj.ErrNoProject
}
httpReqLogsQuery := parseHTTPRequestLogsQuery(ctx)
reqQuery := sq.
Select(httpReqLogsQuery.requestCols...).
From("http_requests req").
Where("req.id = ?")
if httpReqLogsQuery.joinResponse {
reqQuery = reqQuery.LeftJoin("http_responses res ON req.id = res.req_id")
}
reqSQL, _, err := reqQuery.ToSql()
if err != nil {
return reqlog.Request{}, fmt.Errorf("sqlite: could not parse query: %w", err)
}
row := c.db.QueryRowxContext(ctx, reqSQL, id)
var dto httpRequest
err = row.StructScan(&dto)
if errors.Is(err, sql.ErrNoRows) {
return reqlog.Request{}, reqlog.ErrRequestNotFound
} else if err != nil {
return reqlog.Request{}, fmt.Errorf("sqlite: could not scan row: %w", err)
}
reqLog := dto.toRequestLog()
reqLogs := []reqlog.Request{reqLog}
if err := c.queryHeaders(ctx, httpReqLogsQuery, reqLogs); err != nil {
return reqlog.Request{}, fmt.Errorf("sqlite: could not query headers: %w", err)
}
return reqLogs[0], nil
}
func (c *Client) AddRequestLog(
ctx context.Context,
req http.Request,
body []byte,
timestamp time.Time,
) (*reqlog.Request, error) {
if c.db == nil {
return nil, proj.ErrNoProject
}
reqLog := &reqlog.Request{
Request: req,
Body: body,
Timestamp: timestamp,
}
tx, err := c.db.BeginTxx(ctx, nil)
if err != nil {
return nil, fmt.Errorf("sqlite: could not start transaction: %w", err)
}
defer tx.Rollback()
reqStmt, err := tx.PrepareContext(ctx, `INSERT INTO http_requests (
proto,
url,
method,
body,
timestamp
) VALUES (?, ?, ?, ?, ?)`)
if err != nil {
return nil, fmt.Errorf("sqlite: could not prepare statement: %w", err)
}
defer reqStmt.Close()
result, err := reqStmt.ExecContext(ctx,
reqLog.Request.Proto,
reqLog.Request.URL.String(),
reqLog.Request.Method,
reqLog.Body,
reqLog.Timestamp,
)
if err != nil {
return nil, fmt.Errorf("sqlite: could not execute statement: %w", err)
}
reqID, err := result.LastInsertId()
if err != nil {
return nil, fmt.Errorf("sqlite: could not get last insert ID: %w", err)
}
reqLog.ID = reqID
headerStmt, err := tx.PrepareContext(ctx, `INSERT INTO http_headers (
req_id,
key,
value
) VALUES (?, ?, ?)`)
if err != nil {
return nil, fmt.Errorf("sqlite: could not prepare statement: %w", err)
}
defer headerStmt.Close()
err = insertHeaders(ctx, headerStmt, reqID, reqLog.Request.Header)
if err != nil {
return nil, fmt.Errorf("sqlite: could not insert http headers: %w", err)
}
if err := tx.Commit(); err != nil {
return nil, fmt.Errorf("sqlite: could not commit transaction: %w", err)
}
return reqLog, nil
}
func (c *Client) AddResponseLog(
ctx context.Context,
reqID int64,
res http.Response,
body []byte,
timestamp time.Time,
) (*reqlog.Response, error) {
if c.db == nil {
return nil, proj.ErrNoProject
}
resLog := &reqlog.Response{
RequestID: reqID,
Response: res,
Body: body,
Timestamp: timestamp,
}
tx, err := c.db.BeginTx(ctx, nil)
if err != nil {
return nil, fmt.Errorf("sqlite: could not start transaction: %w", err)
}
defer tx.Rollback()
resStmt, err := tx.PrepareContext(ctx, `INSERT INTO http_responses (
req_id,
proto,
status_code,
status_reason,
body,
timestamp
) VALUES (?, ?, ?, ?, ?, ?)`)
if err != nil {
return nil, fmt.Errorf("sqlite: could not prepare statement: %w", err)
}
defer resStmt.Close()
var statusReason string
if len(resLog.Response.Status) > 4 {
statusReason = resLog.Response.Status[4:]
}
result, err := resStmt.ExecContext(ctx,
resLog.RequestID,
resLog.Response.Proto,
resLog.Response.StatusCode,
statusReason,
resLog.Body,
resLog.Timestamp,
)
if err != nil {
return nil, fmt.Errorf("sqlite: could not execute statement: %w", err)
}
resID, err := result.LastInsertId()
if err != nil {
return nil, fmt.Errorf("sqlite: could not get last insert ID: %w", err)
}
resLog.ID = resID
headerStmt, err := tx.PrepareContext(ctx, `INSERT INTO http_headers (
res_id,
key,
value
) VALUES (?, ?, ?)`)
if err != nil {
return nil, fmt.Errorf("sqlite: could not prepare statement: %w", err)
}
defer headerStmt.Close()
err = insertHeaders(ctx, headerStmt, resID, resLog.Response.Header)
if err != nil {
return nil, fmt.Errorf("sqlite: could not insert http headers: %w", err)
}
if err := tx.Commit(); err != nil {
return nil, fmt.Errorf("sqlite: could not commit transaction: %w", err)
}
return resLog, nil
}
func (c *Client) UpsertSettings(ctx context.Context, module string, settings interface{}) error {
if c.db == nil {
// TODO: Fix where `ErrNoProject` lives.
return proj.ErrNoProject
}
jsonSettings, err := json.Marshal(settings)
if err != nil {
return fmt.Errorf("sqlite: could not encode settings as JSON: %w", err)
}
_, err = c.db.ExecContext(ctx,
`INSERT INTO settings (module, settings) VALUES (?, ?)
ON CONFLICT(module) DO UPDATE SET settings = ?`, module, jsonSettings, jsonSettings)
if err != nil {
return fmt.Errorf("sqlite: could not insert scope settings: %w", err)
}
return nil
}
func (c *Client) FindSettingsByModule(ctx context.Context, module string, settings interface{}) error {
if c.db == nil {
return proj.ErrNoProject
}
var jsonSettings []byte
row := c.db.QueryRowContext(ctx, `SELECT settings FROM settings WHERE module = ?`, module)
err := row.Scan(&jsonSettings)
if errors.Is(err, sql.ErrNoRows) {
return proj.ErrNoSettings
} else if err != nil {
return fmt.Errorf("sqlite: could not scan row: %w", err)
}
if err := json.Unmarshal(jsonSettings, &settings); err != nil {
return fmt.Errorf("sqlite: could not decode settings from JSON: %w", err)
}
return nil
}
func insertHeaders(ctx context.Context, stmt *sql.Stmt, id int64, headers http.Header) error {
for key, values := range headers {
for _, value := range values {
if _, err := stmt.ExecContext(ctx, id, key, value); err != nil {
return fmt.Errorf("could not execute statement: %w", err)
}
}
}
return nil
}
func findHeaders(ctx context.Context, stmt *sql.Stmt, id int64) (http.Header, error) {
headers := make(http.Header)
rows, err := stmt.QueryContext(ctx, id)
if err != nil {
return nil, fmt.Errorf("sqlite: could not execute query: %w", err)
}
defer rows.Close()
for rows.Next() {
var key, value string
err := rows.Scan(&key, &value)
if err != nil {
return nil, fmt.Errorf("sqlite: could not scan row: %w", err)
}
headers.Add(key, value)
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("sqlite: could not iterate over rows: %w", err)
}
return headers, nil
}
func parseHTTPRequestLogsQuery(ctx context.Context) httpRequestLogsQuery {
var (
joinResponse bool
reqHeaderCols, resHeaderCols []string
)
opCtx := graphql.GetOperationContext(ctx)
reqFields := graphql.CollectFieldsCtx(ctx, nil)
reqCols := []string{"req.id AS req_id", "res.id AS res_id"}
for _, reqField := range reqFields {
if col, ok := reqFieldToColumnMap[reqField.Name]; ok {
reqCols = append(reqCols, "req."+col)
}
if reqField.Name == "headers" {
headerFields := graphql.CollectFields(opCtx, reqField.Selections, nil)
for _, headerField := range headerFields {
if col, ok := headerFieldToColumnMap[headerField.Name]; ok {
reqHeaderCols = append(reqHeaderCols, col)
}
}
}
if reqField.Name == "response" {
joinResponse = true
resFields := graphql.CollectFields(opCtx, reqField.Selections, nil)
for _, resField := range resFields {
if resField.Name == "headers" {
reqCols = append(reqCols, "res.id AS res_id")
headerFields := graphql.CollectFields(opCtx, resField.Selections, nil)
for _, headerField := range headerFields {
if col, ok := headerFieldToColumnMap[headerField.Name]; ok {
resHeaderCols = append(resHeaderCols, col)
}
}
}
if col, ok := resFieldToColumnMap[resField.Name]; ok {
reqCols = append(reqCols, "res."+col)
}
}
}
}
return httpRequestLogsQuery{
requestCols: reqCols,
requestHeaderCols: reqHeaderCols,
responseHeaderCols: resHeaderCols,
joinResponse: joinResponse,
}
}
func (c *Client) queryHeaders(
ctx context.Context,
query httpRequestLogsQuery,
reqLogs []reqlog.Request,
) error {
if len(query.requestHeaderCols) > 0 {
reqHeadersQuery, _, err := sq.
Select(query.requestHeaderCols...).
From("http_headers").Where("req_id = ?").
ToSql()
if err != nil {
return fmt.Errorf("could not parse request headers query: %w", err)
}
reqHeadersStmt, err := c.db.PrepareContext(ctx, reqHeadersQuery)
if err != nil {
return fmt.Errorf("could not prepare statement: %w", err)
}
defer reqHeadersStmt.Close()
for i := range reqLogs {
headers, err := findHeaders(ctx, reqHeadersStmt, reqLogs[i].ID)
if err != nil {
return fmt.Errorf("could not query request headers: %w", err)
}
reqLogs[i].Request.Header = headers
}
}
if len(query.responseHeaderCols) > 0 {
resHeadersQuery, _, err := sq.
Select(query.responseHeaderCols...).
From("http_headers").Where("res_id = ?").
ToSql()
if err != nil {
return fmt.Errorf("could not parse response headers query: %w", err)
}
resHeadersStmt, err := c.db.PrepareContext(ctx, resHeadersQuery)
if err != nil {
return fmt.Errorf("could not prepare statement: %w", err)
}
defer resHeadersStmt.Close()
for i := range reqLogs {
if reqLogs[i].Response == nil {
continue
}
headers, err := findHeaders(ctx, resHeadersStmt, reqLogs[i].Response.ID)
if err != nil {
return fmt.Errorf("could not query response headers: %w", err)
}
reqLogs[i].Response.Response.Header = headers
}
}
return nil
}
func (c *Client) IsOpen() bool {
return c.db != nil
}

View File

@ -5,129 +5,188 @@ import (
"errors"
"fmt"
"log"
"math/rand"
"regexp"
"sync"
"time"
"github.com/oklog/ulid"
"github.com/dstotijn/hetty/pkg/reqlog"
"github.com/dstotijn/hetty/pkg/scope"
"github.com/dstotijn/hetty/pkg/search"
)
//nolint:gosec
var ulidEntropy = rand.New(rand.NewSource(time.Now().UnixNano()))
type (
OnProjectOpenFn func(name string) error
OnProjectCloseFn func(name string) error
OnProjectOpenFn func(projectID ulid.ULID) error
OnProjectCloseFn func(projectID ulid.ULID) error
)
// Service is used for managing projects.
type Service interface {
Open(ctx context.Context, name string) (Project, error)
Close() error
Delete(name string) error
ActiveProject() (Project, error)
Projects() ([]Project, error)
CreateProject(ctx context.Context, name string) (Project, error)
OpenProject(ctx context.Context, projectID ulid.ULID) (Project, error)
CloseProject() error
DeleteProject(ctx context.Context, projectID ulid.ULID) error
ActiveProject(ctx context.Context) (Project, error)
IsProjectActive(projectID ulid.ULID) bool
Projects(ctx context.Context) ([]Project, error)
Scope() *scope.Scope
SetScopeRules(ctx context.Context, rules []scope.Rule) error
SetRequestLogFindFilter(ctx context.Context, filter reqlog.FindRequestsFilter) error
OnProjectOpen(fn OnProjectOpenFn)
OnProjectClose(fn OnProjectCloseFn)
}
type service struct {
repo Repository
activeProject string
reqLogSvc *reqlog.Service
scope *scope.Scope
activeProjectID ulid.ULID
onProjectOpenFns []OnProjectOpenFn
onProjectCloseFns []OnProjectCloseFn
mu sync.RWMutex
}
type Project struct {
ID ulid.ULID
Name string
IsActive bool
Settings Settings
isActive bool
}
type Settings struct {
ReqLogBypassOutOfScope bool
ReqLogOnlyFindInScope bool
ScopeRules []scope.Rule
SearchExpr search.Expression
}
var (
ErrNoProject = errors.New("proj: no open project")
ErrNoSettings = errors.New("proj: settings not found")
ErrInvalidName = errors.New("proj: invalid name, must be alphanumeric or whitespace chars")
ErrProjectNotFound = errors.New("proj: project not found")
ErrNoProject = errors.New("proj: no open project")
ErrNoSettings = errors.New("proj: settings not found")
ErrInvalidName = errors.New("proj: invalid name, must be alphanumeric or whitespace chars")
)
var nameRegexp = regexp.MustCompile(`^[\w\d\s]+$`)
type Config struct {
Repository Repository
ReqLogService *reqlog.Service
Scope *scope.Scope
}
// NewService returns a new Service.
func NewService(repo Repository) (Service, error) {
func NewService(cfg Config) (Service, error) {
return &service{
repo: repo,
repo: cfg.Repository,
reqLogSvc: cfg.ReqLogService,
scope: cfg.Scope,
}, nil
}
// Close closes the currently open project database (if there is one).
func (svc *service) Close() error {
func (svc *service) CreateProject(ctx context.Context, name string) (Project, error) {
if !nameRegexp.MatchString(name) {
return Project{}, ErrInvalidName
}
project := Project{
ID: ulid.MustNew(ulid.Timestamp(time.Now()), ulidEntropy),
Name: name,
}
err := svc.repo.UpsertProject(ctx, project)
if err != nil {
return Project{}, fmt.Errorf("proj: could not create project: %w", err)
}
return project, nil
}
// CloseProject closes the currently open project (if there is one).
func (svc *service) CloseProject() error {
svc.mu.Lock()
defer svc.mu.Unlock()
closedProject := svc.activeProject
if err := svc.repo.Close(); err != nil {
return fmt.Errorf("proj: could not close project: %w", err)
if svc.activeProjectID.Compare(ulid.ULID{}) == 0 {
return nil
}
svc.activeProject = ""
closedProjectID := svc.activeProjectID
svc.emitProjectClosed(closedProject)
svc.activeProjectID = ulid.ULID{}
svc.reqLogSvc.ActiveProjectID = ulid.ULID{}
svc.reqLogSvc.BypassOutOfScopeRequests = false
svc.reqLogSvc.FindReqsFilter = reqlog.FindRequestsFilter{}
svc.scope.SetRules(nil)
svc.emitProjectClosed(closedProjectID)
return nil
}
// Delete removes a project database file from disk (if there is one).
func (svc *service) Delete(name string) error {
if name == "" {
return errors.New("proj: name cannot be empty")
// DeleteProject removes a project from the repository.
func (svc *service) DeleteProject(ctx context.Context, projectID ulid.ULID) error {
if svc.activeProjectID.Compare(projectID) == 0 {
return fmt.Errorf("proj: project (%v) is active", projectID.String())
}
if svc.activeProject == name {
return fmt.Errorf("proj: project (%v) is active", name)
}
if err := svc.repo.DeleteProject(name); err != nil {
if err := svc.repo.DeleteProject(ctx, projectID); err != nil {
return fmt.Errorf("proj: could not delete project: %w", err)
}
return nil
}
// Open opens a database identified with `name`. If a database with this
// identifier doesn't exist yet, it will be automatically created.
func (svc *service) Open(ctx context.Context, name string) (Project, error) {
if !nameRegexp.MatchString(name) {
return Project{}, ErrInvalidName
}
// OpenProject sets a project as the currently active project.
func (svc *service) OpenProject(ctx context.Context, projectID ulid.ULID) (Project, error) {
svc.mu.Lock()
defer svc.mu.Unlock()
if err := svc.repo.Close(); err != nil {
return Project{}, fmt.Errorf("proj: could not close previously open database: %w", err)
project, err := svc.repo.FindProjectByID(ctx, projectID)
if err != nil {
return Project{}, fmt.Errorf("proj: failed to get project: %w", err)
}
if err := svc.repo.OpenProject(name); err != nil {
return Project{}, fmt.Errorf("proj: could not open database: %w", err)
svc.activeProjectID = project.ID
svc.reqLogSvc.FindReqsFilter = reqlog.FindRequestsFilter{
ProjectID: project.ID,
OnlyInScope: project.Settings.ReqLogOnlyFindInScope,
SearchExpr: project.Settings.SearchExpr,
}
svc.reqLogSvc.BypassOutOfScopeRequests = project.Settings.ReqLogBypassOutOfScope
svc.reqLogSvc.ActiveProjectID = project.ID
svc.scope.SetRules(project.Settings.ScopeRules)
svc.activeProject = name
svc.emitProjectOpened()
return Project{
Name: name,
IsActive: true,
}, nil
return project, nil
}
func (svc *service) ActiveProject() (Project, error) {
activeProject := svc.activeProject
if activeProject == "" {
func (svc *service) ActiveProject(ctx context.Context) (Project, error) {
activeProjectID := svc.activeProjectID
if activeProjectID.Compare(ulid.ULID{}) == 0 {
return Project{}, ErrNoProject
}
return Project{
Name: activeProject,
}, nil
project, err := svc.repo.FindProjectByID(ctx, activeProjectID)
if err != nil {
return Project{}, fmt.Errorf("proj: failed to get active project: %w", err)
}
project.isActive = true
return project, nil
}
func (svc *service) Projects() ([]Project, error) {
projects, err := svc.repo.Projects()
func (svc *service) Projects(ctx context.Context) ([]Project, error) {
projects, err := svc.repo.Projects(ctx)
if err != nil {
return nil, fmt.Errorf("proj: could not get projects: %w", err)
}
@ -135,6 +194,10 @@ func (svc *service) Projects() ([]Project, error) {
return projects, nil
}
func (svc *service) Scope() *scope.Scope {
return svc.scope
}
func (svc *service) OnProjectOpen(fn OnProjectOpenFn) {
svc.mu.Lock()
defer svc.mu.Unlock()
@ -151,16 +214,59 @@ func (svc *service) OnProjectClose(fn OnProjectCloseFn) {
func (svc *service) emitProjectOpened() {
for _, fn := range svc.onProjectOpenFns {
if err := fn(svc.activeProject); err != nil {
if err := fn(svc.activeProjectID); err != nil {
log.Printf("[ERROR] Could not execute onProjectOpen function: %v", err)
}
}
}
func (svc *service) emitProjectClosed(name string) {
func (svc *service) emitProjectClosed(projectID ulid.ULID) {
for _, fn := range svc.onProjectCloseFns {
if err := fn(name); err != nil {
if err := fn(projectID); err != nil {
log.Printf("[ERROR] Could not execute onProjectClose function: %v", err)
}
}
}
func (svc *service) SetScopeRules(ctx context.Context, rules []scope.Rule) error {
project, err := svc.ActiveProject(ctx)
if err != nil {
return err
}
project.Settings.ScopeRules = rules
err = svc.repo.UpsertProject(ctx, project)
if err != nil {
return fmt.Errorf("proj: failed to update project: %w", err)
}
svc.scope.SetRules(rules)
return nil
}
func (svc *service) SetRequestLogFindFilter(ctx context.Context, filter reqlog.FindRequestsFilter) error {
project, err := svc.ActiveProject(ctx)
if err != nil {
return err
}
filter.ProjectID = project.ID
project.Settings.ReqLogOnlyFindInScope = filter.OnlyInScope
project.Settings.SearchExpr = filter.SearchExpr
err = svc.repo.UpsertProject(ctx, project)
if err != nil {
return fmt.Errorf("proj: failed to update project: %w", err)
}
svc.reqLogSvc.FindReqsFilter = filter
return nil
}
func (svc *service) IsProjectActive(projectID ulid.ULID) bool {
return projectID.Compare(svc.activeProjectID) == 0
}

View File

@ -2,13 +2,14 @@ package proj
import (
"context"
"github.com/oklog/ulid"
)
type Repository interface {
UpsertSettings(ctx context.Context, module string, settings interface{}) error
FindSettingsByModule(ctx context.Context, module string, settings interface{}) error
OpenProject(name string) error
DeleteProject(name string) error
Projects() ([]Project, error)
FindProjectByID(ctx context.Context, id ulid.ULID) (Project, error)
UpsertProject(ctx context.Context, project Project) error
DeleteProject(ctx context.Context, id ulid.ULID) error
Projects(ctx context.Context) ([]Project, error)
Close() error
}

View File

@ -15,7 +15,7 @@ import (
type contextKey int
const ReqIDKey contextKey = 0
const ReqLogIDKey contextKey = 0
// Proxy implements http.Handler and offers MITM behaviour for modifying
// HTTP requests and responses.

View File

@ -1,318 +0,0 @@
// Code generated by moq; DO NOT EDIT.
// github.com/matryer/moq
package reqlog_test
import (
"context"
"github.com/dstotijn/hetty/pkg/proj"
"sync"
)
// Ensure, that ProjServiceMock does implement proj.Service.
// If this is not the case, regenerate this file with moq.
var _ proj.Service = &ProjServiceMock{}
// ProjServiceMock is a mock implementation of proj.Service.
//
// func TestSomethingThatUsesService(t *testing.T) {
//
// // make and configure a mocked proj.Service
// mockedService := &ProjServiceMock{
// ActiveProjectFunc: func() (proj.Project, error) {
// panic("mock out the ActiveProject method")
// },
// CloseFunc: func() error {
// panic("mock out the Close method")
// },
// DeleteFunc: func(name string) error {
// panic("mock out the Delete method")
// },
// OnProjectCloseFunc: func(fn proj.OnProjectCloseFn) {
// panic("mock out the OnProjectClose method")
// },
// OnProjectOpenFunc: func(fn proj.OnProjectOpenFn) {
// panic("mock out the OnProjectOpen method")
// },
// OpenFunc: func(ctx context.Context, name string) (proj.Project, error) {
// panic("mock out the Open method")
// },
// ProjectsFunc: func() ([]proj.Project, error) {
// panic("mock out the Projects method")
// },
// }
//
// // use mockedService in code that requires proj.Service
// // and then make assertions.
//
// }
type ProjServiceMock struct {
// ActiveProjectFunc mocks the ActiveProject method.
ActiveProjectFunc func() (proj.Project, error)
// CloseFunc mocks the Close method.
CloseFunc func() error
// DeleteFunc mocks the Delete method.
DeleteFunc func(name string) error
// OnProjectCloseFunc mocks the OnProjectClose method.
OnProjectCloseFunc func(fn proj.OnProjectCloseFn)
// OnProjectOpenFunc mocks the OnProjectOpen method.
OnProjectOpenFunc func(fn proj.OnProjectOpenFn)
// OpenFunc mocks the Open method.
OpenFunc func(ctx context.Context, name string) (proj.Project, error)
// ProjectsFunc mocks the Projects method.
ProjectsFunc func() ([]proj.Project, error)
// calls tracks calls to the methods.
calls struct {
// ActiveProject holds details about calls to the ActiveProject method.
ActiveProject []struct {
}
// Close holds details about calls to the Close method.
Close []struct {
}
// Delete holds details about calls to the Delete method.
Delete []struct {
// Name is the name argument value.
Name string
}
// OnProjectClose holds details about calls to the OnProjectClose method.
OnProjectClose []struct {
// Fn is the fn argument value.
Fn proj.OnProjectCloseFn
}
// OnProjectOpen holds details about calls to the OnProjectOpen method.
OnProjectOpen []struct {
// Fn is the fn argument value.
Fn proj.OnProjectOpenFn
}
// Open holds details about calls to the Open method.
Open []struct {
// Ctx is the ctx argument value.
Ctx context.Context
// Name is the name argument value.
Name string
}
// Projects holds details about calls to the Projects method.
Projects []struct {
}
}
lockActiveProject sync.RWMutex
lockClose sync.RWMutex
lockDelete sync.RWMutex
lockOnProjectClose sync.RWMutex
lockOnProjectOpen sync.RWMutex
lockOpen sync.RWMutex
lockProjects sync.RWMutex
}
// ActiveProject calls ActiveProjectFunc.
func (mock *ProjServiceMock) ActiveProject() (proj.Project, error) {
if mock.ActiveProjectFunc == nil {
panic("ProjServiceMock.ActiveProjectFunc: method is nil but Service.ActiveProject was just called")
}
callInfo := struct {
}{}
mock.lockActiveProject.Lock()
mock.calls.ActiveProject = append(mock.calls.ActiveProject, callInfo)
mock.lockActiveProject.Unlock()
return mock.ActiveProjectFunc()
}
// ActiveProjectCalls gets all the calls that were made to ActiveProject.
// Check the length with:
// len(mockedService.ActiveProjectCalls())
func (mock *ProjServiceMock) ActiveProjectCalls() []struct {
} {
var calls []struct {
}
mock.lockActiveProject.RLock()
calls = mock.calls.ActiveProject
mock.lockActiveProject.RUnlock()
return calls
}
// Close calls CloseFunc.
func (mock *ProjServiceMock) Close() error {
if mock.CloseFunc == nil {
panic("ProjServiceMock.CloseFunc: method is nil but Service.Close was just called")
}
callInfo := struct {
}{}
mock.lockClose.Lock()
mock.calls.Close = append(mock.calls.Close, callInfo)
mock.lockClose.Unlock()
return mock.CloseFunc()
}
// CloseCalls gets all the calls that were made to Close.
// Check the length with:
// len(mockedService.CloseCalls())
func (mock *ProjServiceMock) CloseCalls() []struct {
} {
var calls []struct {
}
mock.lockClose.RLock()
calls = mock.calls.Close
mock.lockClose.RUnlock()
return calls
}
// Delete calls DeleteFunc.
func (mock *ProjServiceMock) Delete(name string) error {
if mock.DeleteFunc == nil {
panic("ProjServiceMock.DeleteFunc: method is nil but Service.Delete was just called")
}
callInfo := struct {
Name string
}{
Name: name,
}
mock.lockDelete.Lock()
mock.calls.Delete = append(mock.calls.Delete, callInfo)
mock.lockDelete.Unlock()
return mock.DeleteFunc(name)
}
// DeleteCalls gets all the calls that were made to Delete.
// Check the length with:
// len(mockedService.DeleteCalls())
func (mock *ProjServiceMock) DeleteCalls() []struct {
Name string
} {
var calls []struct {
Name string
}
mock.lockDelete.RLock()
calls = mock.calls.Delete
mock.lockDelete.RUnlock()
return calls
}
// OnProjectClose calls OnProjectCloseFunc.
func (mock *ProjServiceMock) OnProjectClose(fn proj.OnProjectCloseFn) {
if mock.OnProjectCloseFunc == nil {
panic("ProjServiceMock.OnProjectCloseFunc: method is nil but Service.OnProjectClose was just called")
}
callInfo := struct {
Fn proj.OnProjectCloseFn
}{
Fn: fn,
}
mock.lockOnProjectClose.Lock()
mock.calls.OnProjectClose = append(mock.calls.OnProjectClose, callInfo)
mock.lockOnProjectClose.Unlock()
mock.OnProjectCloseFunc(fn)
}
// OnProjectCloseCalls gets all the calls that were made to OnProjectClose.
// Check the length with:
// len(mockedService.OnProjectCloseCalls())
func (mock *ProjServiceMock) OnProjectCloseCalls() []struct {
Fn proj.OnProjectCloseFn
} {
var calls []struct {
Fn proj.OnProjectCloseFn
}
mock.lockOnProjectClose.RLock()
calls = mock.calls.OnProjectClose
mock.lockOnProjectClose.RUnlock()
return calls
}
// OnProjectOpen calls OnProjectOpenFunc.
func (mock *ProjServiceMock) OnProjectOpen(fn proj.OnProjectOpenFn) {
if mock.OnProjectOpenFunc == nil {
panic("ProjServiceMock.OnProjectOpenFunc: method is nil but Service.OnProjectOpen was just called")
}
callInfo := struct {
Fn proj.OnProjectOpenFn
}{
Fn: fn,
}
mock.lockOnProjectOpen.Lock()
mock.calls.OnProjectOpen = append(mock.calls.OnProjectOpen, callInfo)
mock.lockOnProjectOpen.Unlock()
mock.OnProjectOpenFunc(fn)
}
// OnProjectOpenCalls gets all the calls that were made to OnProjectOpen.
// Check the length with:
// len(mockedService.OnProjectOpenCalls())
func (mock *ProjServiceMock) OnProjectOpenCalls() []struct {
Fn proj.OnProjectOpenFn
} {
var calls []struct {
Fn proj.OnProjectOpenFn
}
mock.lockOnProjectOpen.RLock()
calls = mock.calls.OnProjectOpen
mock.lockOnProjectOpen.RUnlock()
return calls
}
// Open calls OpenFunc.
func (mock *ProjServiceMock) Open(ctx context.Context, name string) (proj.Project, error) {
if mock.OpenFunc == nil {
panic("ProjServiceMock.OpenFunc: method is nil but Service.Open was just called")
}
callInfo := struct {
Ctx context.Context
Name string
}{
Ctx: ctx,
Name: name,
}
mock.lockOpen.Lock()
mock.calls.Open = append(mock.calls.Open, callInfo)
mock.lockOpen.Unlock()
return mock.OpenFunc(ctx, name)
}
// OpenCalls gets all the calls that were made to Open.
// Check the length with:
// len(mockedService.OpenCalls())
func (mock *ProjServiceMock) OpenCalls() []struct {
Ctx context.Context
Name string
} {
var calls []struct {
Ctx context.Context
Name string
}
mock.lockOpen.RLock()
calls = mock.calls.Open
mock.lockOpen.RUnlock()
return calls
}
// Projects calls ProjectsFunc.
func (mock *ProjServiceMock) Projects() ([]proj.Project, error) {
if mock.ProjectsFunc == nil {
panic("ProjServiceMock.ProjectsFunc: method is nil but Service.Projects was just called")
}
callInfo := struct {
}{}
mock.lockProjects.Lock()
mock.calls.Projects = append(mock.calls.Projects, callInfo)
mock.lockProjects.Unlock()
return mock.ProjectsFunc()
}
// ProjectsCalls gets all the calls that were made to Projects.
// Check the length with:
// len(mockedService.ProjectsCalls())
func (mock *ProjServiceMock) ProjectsCalls() []struct {
} {
var calls []struct {
}
mock.lockProjects.RLock()
calls = mock.calls.Projects
mock.lockProjects.RUnlock()
return calls
}

View File

@ -2,18 +2,16 @@ package reqlog
import (
"context"
"net/http"
"time"
"github.com/oklog/ulid"
"github.com/dstotijn/hetty/pkg/scope"
)
type Repository interface {
FindRequestLogs(ctx context.Context, filter FindRequestsFilter, scope *scope.Scope) ([]Request, error)
FindRequestLogByID(ctx context.Context, id int64) (Request, error)
AddRequestLog(ctx context.Context, req http.Request, body []byte, timestamp time.Time) (*Request, error)
AddResponseLog(ctx context.Context, reqID int64, res http.Response, body []byte, timestamp time.Time) (*Response, error) // nolint:lll
ClearRequestLogs(ctx context.Context) error
UpsertSettings(ctx context.Context, module string, settings interface{}) error
FindSettingsByModule(ctx context.Context, module string, settings interface{}) error
FindRequestLogs(ctx context.Context, filter FindRequestsFilter, scope *scope.Scope) ([]RequestLog, error)
FindRequestLogByID(ctx context.Context, id ulid.ULID) (RequestLog, error)
StoreRequestLog(ctx context.Context, reqLog RequestLog) error
StoreResponseLog(ctx context.Context, reqLogID ulid.ULID, resLog ResponseLog) error
ClearRequestLogs(ctx context.Context, projectID ulid.ULID) error
}

View File

@ -7,9 +7,8 @@ import (
"context"
"github.com/dstotijn/hetty/pkg/reqlog"
"github.com/dstotijn/hetty/pkg/scope"
"net/http"
"github.com/oklog/ulid"
"sync"
"time"
)
// Ensure, that RepoMock does implement reqlog.Repository.
@ -22,26 +21,20 @@ var _ reqlog.Repository = &RepoMock{}
//
// // make and configure a mocked reqlog.Repository
// mockedRepository := &RepoMock{
// AddRequestLogFunc: func(ctx context.Context, req http.Request, body []byte, timestamp time.Time) (*reqlog.Request, error) {
// panic("mock out the AddRequestLog method")
// },
// AddResponseLogFunc: func(ctx context.Context, reqID int64, res http.Response, body []byte, timestamp time.Time) (*reqlog.Response, error) {
// panic("mock out the AddResponseLog method")
// },
// ClearRequestLogsFunc: func(ctx context.Context) error {
// ClearRequestLogsFunc: func(ctx context.Context, projectID ulid.ULID) error {
// panic("mock out the ClearRequestLogs method")
// },
// FindRequestLogByIDFunc: func(ctx context.Context, id int64) (reqlog.Request, error) {
// FindRequestLogByIDFunc: func(ctx context.Context, id ulid.ULID) (reqlog.RequestLog, error) {
// panic("mock out the FindRequestLogByID method")
// },
// FindRequestLogsFunc: func(ctx context.Context, filter reqlog.FindRequestsFilter, scopeMoqParam *scope.Scope) ([]reqlog.Request, error) {
// FindRequestLogsFunc: func(ctx context.Context, filter reqlog.FindRequestsFilter, scopeMoqParam *scope.Scope) ([]reqlog.RequestLog, error) {
// panic("mock out the FindRequestLogs method")
// },
// FindSettingsByModuleFunc: func(ctx context.Context, module string, settings interface{}) error {
// panic("mock out the FindSettingsByModule method")
// StoreRequestLogFunc: func(ctx context.Context, reqLog reqlog.RequestLog) error {
// panic("mock out the StoreRequestLog method")
// },
// UpsertSettingsFunc: func(ctx context.Context, module string, settings interface{}) error {
// panic("mock out the UpsertSettings method")
// StoreResponseLogFunc: func(ctx context.Context, reqLogID ulid.ULID, resLog reqlog.ResponseLog) error {
// panic("mock out the StoreResponseLog method")
// },
// }
//
@ -50,64 +43,36 @@ var _ reqlog.Repository = &RepoMock{}
//
// }
type RepoMock struct {
// AddRequestLogFunc mocks the AddRequestLog method.
AddRequestLogFunc func(ctx context.Context, req http.Request, body []byte, timestamp time.Time) (*reqlog.Request, error)
// AddResponseLogFunc mocks the AddResponseLog method.
AddResponseLogFunc func(ctx context.Context, reqID int64, res http.Response, body []byte, timestamp time.Time) (*reqlog.Response, error)
// ClearRequestLogsFunc mocks the ClearRequestLogs method.
ClearRequestLogsFunc func(ctx context.Context) error
ClearRequestLogsFunc func(ctx context.Context, projectID ulid.ULID) error
// FindRequestLogByIDFunc mocks the FindRequestLogByID method.
FindRequestLogByIDFunc func(ctx context.Context, id int64) (reqlog.Request, error)
FindRequestLogByIDFunc func(ctx context.Context, id ulid.ULID) (reqlog.RequestLog, error)
// FindRequestLogsFunc mocks the FindRequestLogs method.
FindRequestLogsFunc func(ctx context.Context, filter reqlog.FindRequestsFilter, scopeMoqParam *scope.Scope) ([]reqlog.Request, error)
FindRequestLogsFunc func(ctx context.Context, filter reqlog.FindRequestsFilter, scopeMoqParam *scope.Scope) ([]reqlog.RequestLog, error)
// FindSettingsByModuleFunc mocks the FindSettingsByModule method.
FindSettingsByModuleFunc func(ctx context.Context, module string, settings interface{}) error
// StoreRequestLogFunc mocks the StoreRequestLog method.
StoreRequestLogFunc func(ctx context.Context, reqLog reqlog.RequestLog) error
// UpsertSettingsFunc mocks the UpsertSettings method.
UpsertSettingsFunc func(ctx context.Context, module string, settings interface{}) error
// StoreResponseLogFunc mocks the StoreResponseLog method.
StoreResponseLogFunc func(ctx context.Context, reqLogID ulid.ULID, resLog reqlog.ResponseLog) error
// calls tracks calls to the methods.
calls struct {
// AddRequestLog holds details about calls to the AddRequestLog method.
AddRequestLog []struct {
// Ctx is the ctx argument value.
Ctx context.Context
// Req is the req argument value.
Req http.Request
// Body is the body argument value.
Body []byte
// Timestamp is the timestamp argument value.
Timestamp time.Time
}
// AddResponseLog holds details about calls to the AddResponseLog method.
AddResponseLog []struct {
// Ctx is the ctx argument value.
Ctx context.Context
// ReqID is the reqID argument value.
ReqID int64
// Res is the res argument value.
Res http.Response
// Body is the body argument value.
Body []byte
// Timestamp is the timestamp argument value.
Timestamp time.Time
}
// ClearRequestLogs holds details about calls to the ClearRequestLogs method.
ClearRequestLogs []struct {
// Ctx is the ctx argument value.
Ctx context.Context
// ProjectID is the projectID argument value.
ProjectID ulid.ULID
}
// FindRequestLogByID holds details about calls to the FindRequestLogByID method.
FindRequestLogByID []struct {
// Ctx is the ctx argument value.
Ctx context.Context
// ID is the id argument value.
ID int64
ID ulid.ULID
}
// FindRequestLogs holds details about calls to the FindRequestLogs method.
FindRequestLogs []struct {
@ -118,148 +83,58 @@ type RepoMock struct {
// ScopeMoqParam is the scopeMoqParam argument value.
ScopeMoqParam *scope.Scope
}
// FindSettingsByModule holds details about calls to the FindSettingsByModule method.
FindSettingsByModule []struct {
// StoreRequestLog holds details about calls to the StoreRequestLog method.
StoreRequestLog []struct {
// Ctx is the ctx argument value.
Ctx context.Context
// Module is the module argument value.
Module string
// Settings is the settings argument value.
Settings interface{}
// ReqLog is the reqLog argument value.
ReqLog reqlog.RequestLog
}
// UpsertSettings holds details about calls to the UpsertSettings method.
UpsertSettings []struct {
// StoreResponseLog holds details about calls to the StoreResponseLog method.
StoreResponseLog []struct {
// Ctx is the ctx argument value.
Ctx context.Context
// Module is the module argument value.
Module string
// Settings is the settings argument value.
Settings interface{}
// ReqLogID is the reqLogID argument value.
ReqLogID ulid.ULID
// ResLog is the resLog argument value.
ResLog reqlog.ResponseLog
}
}
lockAddRequestLog sync.RWMutex
lockAddResponseLog sync.RWMutex
lockClearRequestLogs sync.RWMutex
lockFindRequestLogByID sync.RWMutex
lockFindRequestLogs sync.RWMutex
lockFindSettingsByModule sync.RWMutex
lockUpsertSettings sync.RWMutex
}
// AddRequestLog calls AddRequestLogFunc.
func (mock *RepoMock) AddRequestLog(ctx context.Context, req http.Request, body []byte, timestamp time.Time) (*reqlog.Request, error) {
if mock.AddRequestLogFunc == nil {
panic("RepoMock.AddRequestLogFunc: method is nil but Repository.AddRequestLog was just called")
}
callInfo := struct {
Ctx context.Context
Req http.Request
Body []byte
Timestamp time.Time
}{
Ctx: ctx,
Req: req,
Body: body,
Timestamp: timestamp,
}
mock.lockAddRequestLog.Lock()
mock.calls.AddRequestLog = append(mock.calls.AddRequestLog, callInfo)
mock.lockAddRequestLog.Unlock()
return mock.AddRequestLogFunc(ctx, req, body, timestamp)
}
// AddRequestLogCalls gets all the calls that were made to AddRequestLog.
// Check the length with:
// len(mockedRepository.AddRequestLogCalls())
func (mock *RepoMock) AddRequestLogCalls() []struct {
Ctx context.Context
Req http.Request
Body []byte
Timestamp time.Time
} {
var calls []struct {
Ctx context.Context
Req http.Request
Body []byte
Timestamp time.Time
}
mock.lockAddRequestLog.RLock()
calls = mock.calls.AddRequestLog
mock.lockAddRequestLog.RUnlock()
return calls
}
// AddResponseLog calls AddResponseLogFunc.
func (mock *RepoMock) AddResponseLog(ctx context.Context, reqID int64, res http.Response, body []byte, timestamp time.Time) (*reqlog.Response, error) {
if mock.AddResponseLogFunc == nil {
panic("RepoMock.AddResponseLogFunc: method is nil but Repository.AddResponseLog was just called")
}
callInfo := struct {
Ctx context.Context
ReqID int64
Res http.Response
Body []byte
Timestamp time.Time
}{
Ctx: ctx,
ReqID: reqID,
Res: res,
Body: body,
Timestamp: timestamp,
}
mock.lockAddResponseLog.Lock()
mock.calls.AddResponseLog = append(mock.calls.AddResponseLog, callInfo)
mock.lockAddResponseLog.Unlock()
return mock.AddResponseLogFunc(ctx, reqID, res, body, timestamp)
}
// AddResponseLogCalls gets all the calls that were made to AddResponseLog.
// Check the length with:
// len(mockedRepository.AddResponseLogCalls())
func (mock *RepoMock) AddResponseLogCalls() []struct {
Ctx context.Context
ReqID int64
Res http.Response
Body []byte
Timestamp time.Time
} {
var calls []struct {
Ctx context.Context
ReqID int64
Res http.Response
Body []byte
Timestamp time.Time
}
mock.lockAddResponseLog.RLock()
calls = mock.calls.AddResponseLog
mock.lockAddResponseLog.RUnlock()
return calls
lockClearRequestLogs sync.RWMutex
lockFindRequestLogByID sync.RWMutex
lockFindRequestLogs sync.RWMutex
lockStoreRequestLog sync.RWMutex
lockStoreResponseLog sync.RWMutex
}
// ClearRequestLogs calls ClearRequestLogsFunc.
func (mock *RepoMock) ClearRequestLogs(ctx context.Context) error {
func (mock *RepoMock) ClearRequestLogs(ctx context.Context, projectID ulid.ULID) error {
if mock.ClearRequestLogsFunc == nil {
panic("RepoMock.ClearRequestLogsFunc: method is nil but Repository.ClearRequestLogs was just called")
}
callInfo := struct {
Ctx context.Context
Ctx context.Context
ProjectID ulid.ULID
}{
Ctx: ctx,
Ctx: ctx,
ProjectID: projectID,
}
mock.lockClearRequestLogs.Lock()
mock.calls.ClearRequestLogs = append(mock.calls.ClearRequestLogs, callInfo)
mock.lockClearRequestLogs.Unlock()
return mock.ClearRequestLogsFunc(ctx)
return mock.ClearRequestLogsFunc(ctx, projectID)
}
// ClearRequestLogsCalls gets all the calls that were made to ClearRequestLogs.
// Check the length with:
// len(mockedRepository.ClearRequestLogsCalls())
func (mock *RepoMock) ClearRequestLogsCalls() []struct {
Ctx context.Context
Ctx context.Context
ProjectID ulid.ULID
} {
var calls []struct {
Ctx context.Context
Ctx context.Context
ProjectID ulid.ULID
}
mock.lockClearRequestLogs.RLock()
calls = mock.calls.ClearRequestLogs
@ -268,13 +143,13 @@ func (mock *RepoMock) ClearRequestLogsCalls() []struct {
}
// FindRequestLogByID calls FindRequestLogByIDFunc.
func (mock *RepoMock) FindRequestLogByID(ctx context.Context, id int64) (reqlog.Request, error) {
func (mock *RepoMock) FindRequestLogByID(ctx context.Context, id ulid.ULID) (reqlog.RequestLog, error) {
if mock.FindRequestLogByIDFunc == nil {
panic("RepoMock.FindRequestLogByIDFunc: method is nil but Repository.FindRequestLogByID was just called")
}
callInfo := struct {
Ctx context.Context
ID int64
ID ulid.ULID
}{
Ctx: ctx,
ID: id,
@ -290,11 +165,11 @@ func (mock *RepoMock) FindRequestLogByID(ctx context.Context, id int64) (reqlog.
// len(mockedRepository.FindRequestLogByIDCalls())
func (mock *RepoMock) FindRequestLogByIDCalls() []struct {
Ctx context.Context
ID int64
ID ulid.ULID
} {
var calls []struct {
Ctx context.Context
ID int64
ID ulid.ULID
}
mock.lockFindRequestLogByID.RLock()
calls = mock.calls.FindRequestLogByID
@ -303,7 +178,7 @@ func (mock *RepoMock) FindRequestLogByIDCalls() []struct {
}
// FindRequestLogs calls FindRequestLogsFunc.
func (mock *RepoMock) FindRequestLogs(ctx context.Context, filter reqlog.FindRequestsFilter, scopeMoqParam *scope.Scope) ([]reqlog.Request, error) {
func (mock *RepoMock) FindRequestLogs(ctx context.Context, filter reqlog.FindRequestsFilter, scopeMoqParam *scope.Scope) ([]reqlog.RequestLog, error) {
if mock.FindRequestLogsFunc == nil {
panic("RepoMock.FindRequestLogsFunc: method is nil but Repository.FindRequestLogs was just called")
}
@ -341,80 +216,76 @@ func (mock *RepoMock) FindRequestLogsCalls() []struct {
return calls
}
// FindSettingsByModule calls FindSettingsByModuleFunc.
func (mock *RepoMock) FindSettingsByModule(ctx context.Context, module string, settings interface{}) error {
if mock.FindSettingsByModuleFunc == nil {
panic("RepoMock.FindSettingsByModuleFunc: method is nil but Repository.FindSettingsByModule was just called")
// StoreRequestLog calls StoreRequestLogFunc.
func (mock *RepoMock) StoreRequestLog(ctx context.Context, reqLog reqlog.RequestLog) error {
if mock.StoreRequestLogFunc == nil {
panic("RepoMock.StoreRequestLogFunc: method is nil but Repository.StoreRequestLog was just called")
}
callInfo := struct {
Ctx context.Context
Module string
Settings interface{}
Ctx context.Context
ReqLog reqlog.RequestLog
}{
Ctx: ctx,
Module: module,
Settings: settings,
Ctx: ctx,
ReqLog: reqLog,
}
mock.lockFindSettingsByModule.Lock()
mock.calls.FindSettingsByModule = append(mock.calls.FindSettingsByModule, callInfo)
mock.lockFindSettingsByModule.Unlock()
return mock.FindSettingsByModuleFunc(ctx, module, settings)
mock.lockStoreRequestLog.Lock()
mock.calls.StoreRequestLog = append(mock.calls.StoreRequestLog, callInfo)
mock.lockStoreRequestLog.Unlock()
return mock.StoreRequestLogFunc(ctx, reqLog)
}
// FindSettingsByModuleCalls gets all the calls that were made to FindSettingsByModule.
// StoreRequestLogCalls gets all the calls that were made to StoreRequestLog.
// Check the length with:
// len(mockedRepository.FindSettingsByModuleCalls())
func (mock *RepoMock) FindSettingsByModuleCalls() []struct {
Ctx context.Context
Module string
Settings interface{}
// len(mockedRepository.StoreRequestLogCalls())
func (mock *RepoMock) StoreRequestLogCalls() []struct {
Ctx context.Context
ReqLog reqlog.RequestLog
} {
var calls []struct {
Ctx context.Context
Module string
Settings interface{}
Ctx context.Context
ReqLog reqlog.RequestLog
}
mock.lockFindSettingsByModule.RLock()
calls = mock.calls.FindSettingsByModule
mock.lockFindSettingsByModule.RUnlock()
mock.lockStoreRequestLog.RLock()
calls = mock.calls.StoreRequestLog
mock.lockStoreRequestLog.RUnlock()
return calls
}
// UpsertSettings calls UpsertSettingsFunc.
func (mock *RepoMock) UpsertSettings(ctx context.Context, module string, settings interface{}) error {
if mock.UpsertSettingsFunc == nil {
panic("RepoMock.UpsertSettingsFunc: method is nil but Repository.UpsertSettings was just called")
// StoreResponseLog calls StoreResponseLogFunc.
func (mock *RepoMock) StoreResponseLog(ctx context.Context, reqLogID ulid.ULID, resLog reqlog.ResponseLog) error {
if mock.StoreResponseLogFunc == nil {
panic("RepoMock.StoreResponseLogFunc: method is nil but Repository.StoreResponseLog was just called")
}
callInfo := struct {
Ctx context.Context
Module string
Settings interface{}
ReqLogID ulid.ULID
ResLog reqlog.ResponseLog
}{
Ctx: ctx,
Module: module,
Settings: settings,
ReqLogID: reqLogID,
ResLog: resLog,
}
mock.lockUpsertSettings.Lock()
mock.calls.UpsertSettings = append(mock.calls.UpsertSettings, callInfo)
mock.lockUpsertSettings.Unlock()
return mock.UpsertSettingsFunc(ctx, module, settings)
mock.lockStoreResponseLog.Lock()
mock.calls.StoreResponseLog = append(mock.calls.StoreResponseLog, callInfo)
mock.lockStoreResponseLog.Unlock()
return mock.StoreResponseLogFunc(ctx, reqLogID, resLog)
}
// UpsertSettingsCalls gets all the calls that were made to UpsertSettings.
// StoreResponseLogCalls gets all the calls that were made to StoreResponseLog.
// Check the length with:
// len(mockedRepository.UpsertSettingsCalls())
func (mock *RepoMock) UpsertSettingsCalls() []struct {
// len(mockedRepository.StoreResponseLogCalls())
func (mock *RepoMock) StoreResponseLogCalls() []struct {
Ctx context.Context
Module string
Settings interface{}
ReqLogID ulid.ULID
ResLog reqlog.ResponseLog
} {
var calls []struct {
Ctx context.Context
Module string
Settings interface{}
ReqLogID ulid.ULID
ResLog reqlog.ResponseLog
}
mock.lockUpsertSettings.RLock()
calls = mock.calls.UpsertSettings
mock.lockUpsertSettings.RUnlock()
mock.lockStoreResponseLog.RLock()
calls = mock.calls.StoreResponseLog
mock.lockStoreResponseLog.RUnlock()
return calls
}

View File

@ -4,15 +4,18 @@ import (
"bytes"
"compress/gzip"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"io/ioutil"
"log"
"math/rand"
"net/http"
"net/url"
"time"
"github.com/dstotijn/hetty/pkg/proj"
"github.com/oklog/ulid"
"github.com/dstotijn/hetty/pkg/proxy"
"github.com/dstotijn/hetty/pkg/scope"
"github.com/dstotijn/hetty/pkg/search"
@ -22,127 +25,109 @@ type contextKey int
const LogBypassedKey contextKey = 0
const moduleName = "reqlog"
var (
ErrRequestNotFound = errors.New("reqlog: request not found")
ErrProjectIDMustBeSet = errors.New("reqlog: project ID must be set")
)
var ErrRequestNotFound = errors.New("reqlog: request not found")
//nolint:gosec
var ulidEntropy = rand.New(rand.NewSource(time.Now().UnixNano()))
type Request struct {
ID int64
Request http.Request
Body []byte
Timestamp time.Time
Response *Response
type RequestLog struct {
ID ulid.ULID
ProjectID ulid.ULID
URL *url.URL
Method string
Proto string
Header http.Header
Body []byte
Response *ResponseLog
}
type Response struct {
ID int64
RequestID int64
Response http.Response
Body []byte
Timestamp time.Time
type ResponseLog struct {
Proto string
StatusCode int
Status string
Header http.Header
Body []byte
}
type Service struct {
BypassOutOfScopeRequests bool
FindReqsFilter FindRequestsFilter
ActiveProjectID ulid.ULID
scope *scope.Scope
repo Repository
}
type FindRequestsFilter struct {
OnlyInScope bool
SearchExpr search.Expression `json:"-"`
RawSearchExpr string
ProjectID ulid.ULID
OnlyInScope bool
SearchExpr search.Expression
}
type Config struct {
Scope *scope.Scope
Repository Repository
ProjectService proj.Service
BypassOutOfScopeRequests bool
Scope *scope.Scope
Repository Repository
}
func NewService(cfg Config) *Service {
svc := &Service{
scope: cfg.Scope,
repo: cfg.Repository,
BypassOutOfScopeRequests: cfg.BypassOutOfScopeRequests,
return &Service{
repo: cfg.Repository,
scope: cfg.Scope,
}
cfg.ProjectService.OnProjectOpen(func(_ string) error {
err := svc.repo.FindSettingsByModule(context.Background(), moduleName, svc)
if errors.Is(err, proj.ErrNoSettings) {
return nil
}
if err != nil {
return fmt.Errorf("reqlog: could not load settings: %w", err)
}
return nil
})
cfg.ProjectService.OnProjectClose(func(_ string) error {
svc.BypassOutOfScopeRequests = false
svc.FindReqsFilter = FindRequestsFilter{}
return nil
})
return svc
}
func (svc *Service) FindRequests(ctx context.Context) ([]Request, error) {
func (svc *Service) FindRequests(ctx context.Context) ([]RequestLog, error) {
return svc.repo.FindRequestLogs(ctx, svc.FindReqsFilter, svc.scope)
}
func (svc *Service) FindRequestLogByID(ctx context.Context, id int64) (Request, error) {
func (svc *Service) FindRequestLogByID(ctx context.Context, id ulid.ULID) (RequestLog, error) {
return svc.repo.FindRequestLogByID(ctx, id)
}
func (svc *Service) SetRequestLogFilter(ctx context.Context, filter FindRequestsFilter) error {
svc.FindReqsFilter = filter
return svc.repo.UpsertSettings(ctx, "reqlog", svc)
func (svc *Service) ClearRequests(ctx context.Context, projectID ulid.ULID) error {
return svc.repo.ClearRequestLogs(ctx, projectID)
}
func (svc *Service) ClearRequests(ctx context.Context) error {
return svc.repo.ClearRequestLogs(ctx)
}
func (svc *Service) addRequest(
ctx context.Context,
req http.Request,
body []byte,
timestamp time.Time,
) (*Request, error) {
return svc.repo.AddRequestLog(ctx, req, body, timestamp)
}
func (svc *Service) addResponse(
ctx context.Context,
reqID int64,
res http.Response,
body []byte,
timestamp time.Time,
) (*Response, error) {
func (svc *Service) storeResponse(ctx context.Context, reqLogID ulid.ULID, res *http.Response) error {
if res.Header.Get("Content-Encoding") == "gzip" {
gzipReader, err := gzip.NewReader(bytes.NewBuffer(body))
gzipReader, err := gzip.NewReader(res.Body)
if err != nil {
return nil, fmt.Errorf("reqlog: could not create gzip reader: %w", err)
return fmt.Errorf("could not create gzip reader: %w", err)
}
defer gzipReader.Close()
body, err = ioutil.ReadAll(gzipReader)
if err != nil {
return nil, fmt.Errorf("reqlog: could not read gzipped response body: %w", err)
buf := &bytes.Buffer{}
if _, err := io.Copy(buf, gzipReader); err != nil {
return fmt.Errorf("could not read gzipped response body: %w", err)
}
res.Body = io.NopCloser(buf)
}
return svc.repo.AddResponseLog(ctx, reqID, res, body, timestamp)
body, err := io.ReadAll(res.Body)
if err != nil {
return fmt.Errorf("could not read body: %w", err)
}
resLog := ResponseLog{
Proto: res.Proto,
StatusCode: res.StatusCode,
Status: res.Status,
Header: res.Header,
Body: body,
}
return svc.repo.StoreResponseLog(ctx, reqLogID, resLog)
}
func (svc *Service) RequestModifier(next proxy.RequestModifyFunc) proxy.RequestModifyFunc {
return func(req *http.Request) {
now := time.Now()
next(req)
clone := req.Clone(req.Context())
@ -160,10 +145,19 @@ func (svc *Service) RequestModifier(next proxy.RequestModifyFunc) proxy.RequestM
}
req.Body = ioutil.NopCloser(bytes.NewBuffer(body))
clone.Body = ioutil.NopCloser(bytes.NewBuffer(body))
}
// Bypass logging if no project is active.
if svc.ActiveProjectID.Compare(ulid.ULID{}) == 0 {
ctx := context.WithValue(req.Context(), LogBypassedKey, true)
*req = *req.WithContext(ctx)
return
}
// Bypass logging if this setting is enabled and the incoming request
// doens't match any rules of the scope.
// doesn't match any scope rules.
if svc.BypassOutOfScopeRequests && !svc.scope.Match(clone, body) {
ctx := context.WithValue(req.Context(), LogBypassedKey, true)
*req = *req.WithContext(ctx)
@ -171,26 +165,29 @@ func (svc *Service) RequestModifier(next proxy.RequestModifyFunc) proxy.RequestM
return
}
reqLog, err := svc.addRequest(req.Context(), *clone, body, now)
if errors.Is(err, proj.ErrNoProject) {
ctx := context.WithValue(req.Context(), LogBypassedKey, true)
*req = *req.WithContext(ctx)
reqLog := RequestLog{
ID: ulid.MustNew(ulid.Timestamp(time.Now()), ulidEntropy),
ProjectID: svc.ActiveProjectID,
Method: clone.Method,
URL: clone.URL,
Proto: clone.Proto,
Header: clone.Header,
Body: body,
}
return
} else if err != nil {
err := svc.repo.StoreRequestLog(req.Context(), reqLog)
if err != nil {
log.Printf("[ERROR] Could not store request log: %v", err)
return
}
ctx := context.WithValue(req.Context(), proxy.ReqIDKey, reqLog.ID)
ctx := context.WithValue(req.Context(), proxy.ReqLogIDKey, reqLog.ID)
*req = *req.WithContext(ctx)
}
}
func (svc *Service) ResponseModifier(next proxy.ResponseModifyFunc) proxy.ResponseModifyFunc {
return func(res *http.Response) error {
now := time.Now()
if err := next(res); err != nil {
return err
}
@ -199,8 +196,8 @@ func (svc *Service) ResponseModifier(next proxy.ResponseModifyFunc) proxy.Respon
return nil
}
reqID, _ := res.Request.Context().Value(proxy.ReqIDKey).(int64)
if reqID == 0 {
reqLogID, ok := res.Request.Context().Value(proxy.ReqLogIDKey).(ulid.ULID)
if !ok {
return errors.New("reqlog: request is missing ID")
}
@ -213,9 +210,10 @@ func (svc *Service) ResponseModifier(next proxy.ResponseModifyFunc) proxy.Respon
}
res.Body = ioutil.NopCloser(bytes.NewBuffer(body))
clone.Body = ioutil.NopCloser(bytes.NewBuffer(body))
go func() {
if _, err := svc.addResponse(context.Background(), reqID, clone, body, now); err != nil {
if err := svc.storeResponse(context.Background(), reqLogID, &clone); err != nil {
log.Printf("[ERROR] Could not store response log: %v", err)
}
}()
@ -223,33 +221,3 @@ func (svc *Service) ResponseModifier(next proxy.ResponseModifyFunc) proxy.Respon
return nil
}
}
// UnmarshalJSON implements json.Unmarshaler.
func (f *FindRequestsFilter) UnmarshalJSON(b []byte) error {
var dto struct {
OnlyInScope bool
RawSearchExpr string
}
if err := json.Unmarshal(b, &dto); err != nil {
return err
}
filter := FindRequestsFilter{
OnlyInScope: dto.OnlyInScope,
RawSearchExpr: dto.RawSearchExpr,
}
if dto.RawSearchExpr != "" {
expr, err := search.ParseQuery(dto.RawSearchExpr)
if err != nil {
return err
}
filter.SearchExpr = expr
}
*f = filter
return nil
}

View File

@ -1,124 +1,43 @@
package reqlog_test
//go:generate moq -out proj_mock_test.go -pkg reqlog_test ../proj Service:ProjServiceMock
//go:generate moq -out repo_mock_test.go -pkg reqlog_test . Repository:RepoMock
//go:generate go run github.com/matryer/moq -out repo_mock_test.go -pkg reqlog_test . Repository:RepoMock
import (
"context"
"io/ioutil"
"io"
"math/rand"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/dstotijn/hetty/pkg/proj"
"github.com/google/go-cmp/cmp"
"github.com/oklog/ulid"
"github.com/dstotijn/hetty/pkg/proxy"
"github.com/dstotijn/hetty/pkg/reqlog"
"github.com/dstotijn/hetty/pkg/scope"
)
//nolint:gosec
var ulidEntropy = rand.New(rand.NewSource(time.Now().UnixNano()))
//nolint:paralleltest
func TestNewService(t *testing.T) {
projSvcMock := &ProjServiceMock{
OnProjectOpenFunc: func(fn proj.OnProjectOpenFn) {},
OnProjectCloseFunc: func(fn proj.OnProjectCloseFn) {},
}
func TestRequestModifier(t *testing.T) {
repoMock := &RepoMock{
FindSettingsByModuleFunc: func(_ context.Context, _ string, _ interface{}) error {
StoreRequestLogFunc: func(_ context.Context, _ reqlog.RequestLog) error {
return nil
},
}
svc := reqlog.NewService(reqlog.Config{
ProjectService: projSvcMock,
Repository: repoMock,
})
t.Run("registered handlers for project open and close", func(t *testing.T) {
got := len(projSvcMock.OnProjectOpenCalls())
if exp := 1; exp != got {
t.Fatalf("incorrect `proj.Service.OnProjectOpen` calls (expected: %v, got: %v)", exp, got)
}
got = len(projSvcMock.OnProjectCloseCalls())
if exp := 1; exp != got {
t.Fatalf("incorrect `proj.Service.OnProjectClose` calls (expected: %v, got: %v)", exp, got)
}
})
t.Run("calls handler when project is opened", func(t *testing.T) {
// Mock opening a project.
err := projSvcMock.OnProjectOpenCalls()[0].Fn("foobar")
if err != nil {
t.Errorf("unexpected error (expected: nil, got: %v)", err)
}
// Assert that settings were fetched from repository, with `svc` as the
// destination.
got := len(repoMock.FindSettingsByModuleCalls())
if exp := 1; exp != got {
t.Fatalf("incorrect `proj.Service.OnProjectOpen` calls (expected: %v, got: %v)", exp, got)
}
findSettingsByModuleCall := repoMock.FindSettingsByModuleCalls()[0]
expModule := "reqlog"
expSettings := svc
if expModule != findSettingsByModuleCall.Module {
t.Fatalf("incorrect `module` argument for `proj.Service.OnProjectOpen` (expected: %v, got: %v)",
expModule, findSettingsByModuleCall.Module)
}
if expSettings != findSettingsByModuleCall.Settings {
t.Fatalf("incorrect `settings` argument for `proj.Service.OnProjectOpen` (expected: %v, got: %v)",
expModule, findSettingsByModuleCall.Settings)
}
})
t.Run("calls handler when project is closed", func(t *testing.T) {
// Mock updating service settings.
svc.BypassOutOfScopeRequests = true
svc.FindReqsFilter = reqlog.FindRequestsFilter{OnlyInScope: true}
// Mock closing a project.
err := projSvcMock.OnProjectCloseCalls()[0].Fn("foobar")
if err != nil {
t.Errorf("unexpected error (expected: nil, got: %v)", err)
}
// Assert that settings were set to defaults on project close.
expBypassOutOfScopeReqs := false
expFindReqsFilter := reqlog.FindRequestsFilter{}
if expBypassOutOfScopeReqs != svc.BypassOutOfScopeRequests {
t.Fatalf("incorrect `Service.BypassOutOfScopeRequests` value (expected: %v, got: %v)",
expBypassOutOfScopeReqs, svc.BypassOutOfScopeRequests)
}
if expFindReqsFilter != svc.FindReqsFilter {
t.Fatalf("incorrect `Service.FindReqsFilter` value (expected: %v, got: %v)",
expFindReqsFilter, svc.FindReqsFilter)
}
})
}
//nolint:paralleltest
func TestRequestModifier(t *testing.T) {
projSvcMock := &ProjServiceMock{
OnProjectOpenFunc: func(fn proj.OnProjectOpenFn) {},
OnProjectCloseFunc: func(fn proj.OnProjectCloseFn) {},
}
repoMock := &RepoMock{
AddRequestLogFunc: func(_ context.Context, _ http.Request, _ []byte, _ time.Time) (*reqlog.Request, error) {
return &reqlog.Request{}, nil
},
}
svc := reqlog.NewService(reqlog.Config{
ProjectService: projSvcMock,
Repository: repoMock,
Repository: repoMock,
Scope: &scope.Scope{},
})
svc.ActiveProjectID = ulid.MustNew(ulid.Timestamp(time.Now()), ulidEntropy)
next := func(req *http.Request) {
req.Body = ioutil.NopCloser(strings.NewReader("modified body"))
req.Body = io.NopCloser(strings.NewReader("modified body"))
}
reqModFn := svc.RequestModifier(next)
req := httptest.NewRequest("GET", "https://example.com/", strings.NewReader("bar"))
@ -126,49 +45,54 @@ func TestRequestModifier(t *testing.T) {
reqModFn(req)
t.Run("request log was stored in repository", func(t *testing.T) {
got := len(repoMock.AddRequestLogCalls())
if exp := 1; exp != got {
t.Fatalf("incorrect `proj.Service.AddRequestLog` calls (expected: %v, got: %v)", exp, got)
gotCount := len(repoMock.StoreRequestLogCalls())
if expCount := 1; expCount != gotCount {
t.Fatalf("incorrect `proj.Service.AddRequestLog` calls (expected: %v, got: %v)", expCount, gotCount)
}
})
t.Run("ran next modifier first, before calling repository", func(t *testing.T) {
got := repoMock.AddRequestLogCalls()[0].Body
if exp := "modified body"; exp != string(got) {
t.Fatalf("incorrect `body` argument for `Repository.AddRequestLogCalls` (expected: %v, got: %v)", exp, string(got))
exp := reqlog.RequestLog{
ID: ulid.ULID{}, // Empty value
ProjectID: svc.ActiveProjectID,
Method: req.Method,
URL: req.URL,
Proto: req.Proto,
Header: req.Header,
Body: []byte("modified body"),
}
got := repoMock.StoreRequestLogCalls()[0].ReqLog
got.ID = ulid.ULID{} // Override to empty value so we can compare against expected value.
if diff := cmp.Diff(exp, got); diff != "" {
t.Fatalf("request log not equal (-exp, +got):\n%v", diff)
}
})
}
//nolint:paralleltest
func TestResponseModifier(t *testing.T) {
projSvcMock := &ProjServiceMock{
OnProjectOpenFunc: func(fn proj.OnProjectOpenFn) {},
OnProjectCloseFunc: func(fn proj.OnProjectCloseFn) {},
}
repoMock := &RepoMock{
AddResponseLogFunc: func(_ context.Context, _ int64, _ http.Response,
_ []byte, _ time.Time) (*reqlog.Response, error) {
return &reqlog.Response{}, nil
StoreResponseLogFunc: func(_ context.Context, _ ulid.ULID, _ reqlog.ResponseLog) error {
return nil
},
}
svc := reqlog.NewService(reqlog.Config{
ProjectService: projSvcMock,
Repository: repoMock,
Repository: repoMock,
})
svc.ActiveProjectID = ulid.MustNew(ulid.Timestamp(time.Now()), ulidEntropy)
next := func(res *http.Response) error {
res.Body = ioutil.NopCloser(strings.NewReader("modified body"))
res.Body = io.NopCloser(strings.NewReader("modified body"))
return nil
}
resModFn := svc.ResponseModifier(next)
req := httptest.NewRequest("GET", "https://example.com/", strings.NewReader("bar"))
req = req.WithContext(context.WithValue(req.Context(), proxy.ReqIDKey, int64(42)))
reqLogID := ulid.MustNew(ulid.Timestamp(time.Now()), ulidEntropy)
req = req.WithContext(context.WithValue(req.Context(), proxy.ReqLogIDKey, reqLogID))
res := &http.Response{
Request: req,
Body: ioutil.NopCloser(strings.NewReader("bar")),
Body: io.NopCloser(strings.NewReader("bar")),
}
if err := resModFn(res); err != nil {
@ -178,16 +102,23 @@ func TestResponseModifier(t *testing.T) {
t.Run("request log was stored in repository", func(t *testing.T) {
// Dirty (but simple) wait for other goroutine to finish calling repository.
time.Sleep(10 * time.Millisecond)
got := len(repoMock.AddResponseLogCalls())
got := len(repoMock.StoreResponseLogCalls())
if exp := 1; exp != got {
t.Fatalf("incorrect `proj.Service.AddResponseLog` calls (expected: %v, got: %v)", exp, got)
}
})
t.Run("ran next modifier first, before calling repository", func(t *testing.T) {
got := repoMock.AddResponseLogCalls()[0].Body
if exp := "modified body"; exp != string(got) {
t.Fatalf("incorrect `body` argument for `Repository.AddResponseLogCalls` (expected: %v, got: %v)", exp, string(got))
}
t.Run("ran next modifier first, before calling repository", func(t *testing.T) {
got := repoMock.StoreResponseLogCalls()[0].ResLog.Body
if exp := "modified body"; exp != string(got) {
t.Fatalf("incorrect `ResponseLog.Body` value (expected: %v, got: %v)", exp, string(got))
}
})
t.Run("called repository with request log id", func(t *testing.T) {
got := repoMock.StoreResponseLogCalls()[0].ReqLogID
if exp := reqLogID; exp.Compare(got) != 0 {
t.Fatalf("incorrect `reqLogID` argument for `Repository.AddResponseLogCalls` (expected: %v, got: %v)", exp.String(), got.String())
}
})
})
}

235
pkg/reqlog/search.go Normal file
View File

@ -0,0 +1,235 @@
package reqlog
import (
"errors"
"fmt"
"regexp"
"strconv"
"strings"
"github.com/oklog/ulid"
"github.com/dstotijn/hetty/pkg/scope"
"github.com/dstotijn/hetty/pkg/search"
)
var reqLogSearchKeyFns = map[string]func(rl RequestLog) string{
"req.id": func(rl RequestLog) string { return rl.ID.String() },
"req.proto": func(rl RequestLog) string { return rl.Proto },
"req.url": func(rl RequestLog) string {
if rl.URL == nil {
return ""
}
return rl.URL.String()
},
"req.method": func(rl RequestLog) string { return rl.Method },
"req.body": func(rl RequestLog) string { return string(rl.Body) },
"req.timestamp": func(rl RequestLog) string { return ulid.Time(rl.ID.Time()).String() },
}
var resLogSearchKeyFns = map[string]func(rl ResponseLog) string{
"res.proto": func(rl ResponseLog) string { return rl.Proto },
"res.statusCode": func(rl ResponseLog) string { return strconv.Itoa(rl.StatusCode) },
"res.statusReason": func(rl ResponseLog) string { return rl.Status },
"res.body": func(rl ResponseLog) string { return string(rl.Body) },
}
// TODO: Request and response headers search key functions.
// Matches returns true if the supplied search expression evaluates to true.
func (reqLog RequestLog) Matches(expr search.Expression) (bool, error) {
switch e := expr.(type) {
case search.PrefixExpression:
return reqLog.matchPrefixExpr(e)
case search.InfixExpression:
return reqLog.matchInfixExpr(e)
case search.StringLiteral:
return reqLog.matchStringLiteral(e)
default:
return false, fmt.Errorf("expression type (%T) not supported", expr)
}
}
func (reqLog RequestLog) matchPrefixExpr(expr search.PrefixExpression) (bool, error) {
switch expr.Operator {
case search.TokOpNot:
match, err := reqLog.Matches(expr.Right)
if err != nil {
return false, err
}
return !match, nil
default:
return false, errors.New("operator is not supported")
}
}
func (reqLog RequestLog) matchInfixExpr(expr search.InfixExpression) (bool, error) {
switch expr.Operator {
case search.TokOpAnd:
left, err := reqLog.Matches(expr.Left)
if err != nil {
return false, err
}
right, err := reqLog.Matches(expr.Right)
if err != nil {
return false, err
}
return left && right, nil
case search.TokOpOr:
left, err := reqLog.Matches(expr.Left)
if err != nil {
return false, err
}
right, err := reqLog.Matches(expr.Right)
if err != nil {
return false, err
}
return left || right, nil
}
left, ok := expr.Left.(search.StringLiteral)
if !ok {
return false, errors.New("left operand must be a string literal")
}
leftVal := reqLog.getMappedStringLiteral(left.Value)
if expr.Operator == search.TokOpRe || expr.Operator == search.TokOpNotRe {
right, ok := expr.Right.(*regexp.Regexp)
if !ok {
return false, errors.New("right operand must be a regular expression")
}
switch expr.Operator {
case search.TokOpRe:
return right.MatchString(leftVal), nil
case search.TokOpNotRe:
return !right.MatchString(leftVal), nil
}
}
right, ok := expr.Right.(search.StringLiteral)
if !ok {
return false, errors.New("right operand must be a string literal")
}
rightVal := reqLog.getMappedStringLiteral(right.Value)
switch expr.Operator {
case search.TokOpEq:
return leftVal == rightVal, nil
case search.TokOpNotEq:
return leftVal != rightVal, nil
case search.TokOpGt:
// TODO(?) attempt to parse as int.
return leftVal > rightVal, nil
case search.TokOpLt:
// TODO(?) attempt to parse as int.
return leftVal < rightVal, nil
case search.TokOpGtEq:
// TODO(?) attempt to parse as int.
return leftVal >= rightVal, nil
case search.TokOpLtEq:
// TODO(?) attempt to parse as int.
return leftVal <= rightVal, nil
default:
return false, errors.New("unsupported operator")
}
}
func (reqLog RequestLog) getMappedStringLiteral(s string) string {
switch {
case strings.HasPrefix(s, "req."):
fn, ok := reqLogSearchKeyFns[s]
if ok {
return fn(reqLog)
}
case strings.HasPrefix(s, "res."):
if reqLog.Response == nil {
return ""
}
fn, ok := resLogSearchKeyFns[s]
if ok {
return fn(*reqLog.Response)
}
}
return s
}
func (reqLog RequestLog) matchStringLiteral(strLiteral search.StringLiteral) (bool, error) {
for _, fn := range reqLogSearchKeyFns {
if strings.Contains(
strings.ToLower(fn(reqLog)),
strings.ToLower(strLiteral.Value),
) {
return true, nil
}
}
if reqLog.Response != nil {
for _, fn := range resLogSearchKeyFns {
if strings.Contains(
strings.ToLower(fn(*reqLog.Response)),
strings.ToLower(strLiteral.Value),
) {
return true, nil
}
}
}
return false, nil
}
func (reqLog RequestLog) MatchScope(s *scope.Scope) bool {
for _, rule := range s.Rules() {
if rule.URL != nil && reqLog.URL != nil {
if matches := rule.URL.MatchString(reqLog.URL.String()); matches {
return true
}
}
for key, values := range reqLog.Header {
var keyMatches, valueMatches bool
if rule.Header.Key != nil {
if matches := rule.Header.Key.MatchString(key); matches {
keyMatches = true
}
}
if rule.Header.Value != nil {
for _, value := range values {
if matches := rule.Header.Value.MatchString(value); matches {
valueMatches = true
break
}
}
}
// When only key or value is set, match on whatever is set.
// When both are set, both must match.
switch {
case rule.Header.Key != nil && rule.Header.Value == nil && keyMatches:
return true
case rule.Header.Key == nil && rule.Header.Value != nil && valueMatches:
return true
case rule.Header.Key != nil && rule.Header.Value != nil && keyMatches && valueMatches:
return true
}
}
if rule.Body != nil {
if matches := rule.Body.Match(reqLog.Body); matches {
return true
}
}
}
return false
}

203
pkg/reqlog/search_test.go Normal file
View File

@ -0,0 +1,203 @@
package reqlog_test
import (
"testing"
"github.com/dstotijn/hetty/pkg/reqlog"
"github.com/dstotijn/hetty/pkg/search"
)
func TestRequestLogMatch(t *testing.T) {
t.Parallel()
tests := []struct {
name string
query string
requestLog reqlog.RequestLog
expectedMatch bool
expectedError error
}{
{
name: "infix expression, equal operator, match",
query: "req.body = foo",
requestLog: reqlog.RequestLog{
Body: []byte("foo"),
},
expectedMatch: true,
expectedError: nil,
},
{
name: "infix expression, not equal operator, match",
query: "req.body != bar",
requestLog: reqlog.RequestLog{
Body: []byte("foo"),
},
expectedMatch: true,
expectedError: nil,
},
{
name: "infix expression, greater than operator, match",
query: "req.body > a",
requestLog: reqlog.RequestLog{
Body: []byte("b"),
},
expectedMatch: true,
expectedError: nil,
},
{
name: "infix expression, less than operator, match",
query: "req.body < b",
requestLog: reqlog.RequestLog{
Body: []byte("a"),
},
expectedMatch: true,
expectedError: nil,
},
{
name: "infix expression, greater than or equal operator, match greater than",
query: "req.body >= a",
requestLog: reqlog.RequestLog{
Body: []byte("b"),
},
expectedMatch: true,
expectedError: nil,
},
{
name: "infix expression, greater than or equal operator, match equal",
query: "req.body >= a",
requestLog: reqlog.RequestLog{
Body: []byte("a"),
},
expectedMatch: true,
expectedError: nil,
},
{
name: "infix expression, less than or equal operator, match less than",
query: "req.body <= b",
requestLog: reqlog.RequestLog{
Body: []byte("a"),
},
expectedMatch: true,
expectedError: nil,
},
{
name: "infix expression, less than or equal operator, match equal",
query: "req.body <= b",
requestLog: reqlog.RequestLog{
Body: []byte("b"),
},
expectedMatch: true,
expectedError: nil,
},
{
name: "infix expression, regular expression operator, match",
query: `req.body =~ "^foo(.*)$"`,
requestLog: reqlog.RequestLog{
Body: []byte("foobar"),
},
expectedMatch: true,
expectedError: nil,
},
{
name: "infix expression, negate regular expression operator, match",
query: `req.body !~ "^foo(.*)$"`,
requestLog: reqlog.RequestLog{
Body: []byte("xoobar"),
},
expectedMatch: true,
expectedError: nil,
},
{
name: "infix expression, and operator, match",
query: "req.body = bar AND res.body = yolo",
requestLog: reqlog.RequestLog{
Body: []byte("bar"),
Response: &reqlog.ResponseLog{
Body: []byte("yolo"),
},
},
expectedMatch: true,
expectedError: nil,
},
{
name: "infix expression, or operator, match",
query: "req.body = bar OR res.body = yolo",
requestLog: reqlog.RequestLog{
Body: []byte("foo"),
Response: &reqlog.ResponseLog{
Body: []byte("yolo"),
},
},
expectedMatch: true,
expectedError: nil,
},
{
name: "prefix expression, not operator, match",
query: "NOT (req.body = bar)",
requestLog: reqlog.RequestLog{
Body: []byte("foo"),
},
expectedMatch: true,
expectedError: nil,
},
{
name: "string literal expression, match in request log",
query: "foo",
requestLog: reqlog.RequestLog{
Body: []byte("foo"),
},
expectedMatch: true,
expectedError: nil,
},
{
name: "string literal expression, no match",
query: "foo",
requestLog: reqlog.RequestLog{
Body: []byte("bar"),
},
expectedMatch: false,
expectedError: nil,
},
{
name: "string literal expression, match in response log",
query: "foo",
requestLog: reqlog.RequestLog{
Response: &reqlog.ResponseLog{
Body: []byte("foo"),
},
},
expectedMatch: true,
expectedError: nil,
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
searchExpr, err := search.ParseQuery(tt.query)
assertError(t, nil, err)
got, err := tt.requestLog.Matches(searchExpr)
assertError(t, tt.expectedError, err)
if tt.expectedMatch != got {
t.Errorf("expected match result: %v, got: %v", tt.expectedMatch, got)
}
})
}
}
func assertError(t *testing.T, exp, got error) {
t.Helper()
switch {
case exp == nil && got != nil:
t.Fatalf("expected: nil, got: %v", got)
case exp != nil && got == nil:
t.Fatalf("expected: %v, got: nil", exp.Error())
case exp != nil && got != nil && exp.Error() != got.Error():
t.Fatalf("expected: %v, got: %v", exp.Error(), got.Error())
}
}

View File

@ -1,8 +0,0 @@
package scope
import "context"
type Repository interface {
UpsertSettings(ctx context.Context, module string, settings interface{}) error
FindSettingsByModule(ctx context.Context, module string, settings interface{}) error
}

View File

@ -1,24 +1,16 @@
package scope
import (
"context"
"encoding/json"
"errors"
"fmt"
"bytes"
"encoding/gob"
"net/http"
"regexp"
"sync"
"github.com/dstotijn/hetty/pkg/proj"
)
const moduleName = "scope"
type Scope struct {
rules []Rule
repo Repository
mu sync.RWMutex
mu sync.RWMutex
}
type Rule struct {
@ -32,30 +24,6 @@ type Header struct {
Value *regexp.Regexp
}
func New(repo Repository, projService proj.Service) *Scope {
s := &Scope{
repo: repo,
}
projService.OnProjectOpen(func(_ string) error {
err := s.load(context.Background())
if errors.Is(err, proj.ErrNoSettings) {
return nil
}
if err != nil {
return fmt.Errorf("scope: could not load scope: %w", err)
}
return nil
})
projService.OnProjectClose(func(_ string) error {
s.unload()
return nil
})
return s
}
func (s *Scope) Rules() []Rule {
s.mu.RLock()
defer s.mu.RUnlock()
@ -63,41 +31,11 @@ func (s *Scope) Rules() []Rule {
return s.rules
}
func (s *Scope) load(ctx context.Context) error {
func (s *Scope) SetRules(rules []Rule) {
s.mu.Lock()
defer s.mu.Unlock()
var rules []Rule
err := s.repo.FindSettingsByModule(ctx, moduleName, &rules)
if errors.Is(err, proj.ErrNoSettings) {
return err
} else if err != nil {
return fmt.Errorf("scope: could not load scope settings: %w", err)
}
s.rules = rules
return nil
}
func (s *Scope) unload() {
s.mu.Lock()
defer s.mu.Unlock()
s.rules = nil
}
func (s *Scope) SetRules(ctx context.Context, rules []Rule) error {
s.mu.Lock()
defer s.mu.Unlock()
if err := s.repo.UpsertSettings(ctx, moduleName, rules); err != nil {
return fmt.Errorf("scope: cannot set rules in repository: %w", err)
}
s.rules = rules
return nil
}
func (s *Scope) Match(req *http.Request, body []byte) bool {
@ -158,48 +96,54 @@ func (r Rule) Match(req *http.Request, body []byte) bool {
return false
}
// MarshalJSON implements json.Marshaler.
func (r Rule) MarshalJSON() ([]byte, error) {
type (
headerDTO struct {
Key string
Value string
}
ruleDTO struct {
URL string
Header headerDTO
Body string
}
)
dto := ruleDTO{
URL: regexpToString(r.URL),
Header: headerDTO{
Key: regexpToString(r.Header.Key),
Value: regexpToString(r.Header.Value),
},
Body: regexpToString(r.Body),
func regexpToString(r *regexp.Regexp) string {
if r == nil {
return ""
}
return json.Marshal(dto)
return r.String()
}
// UnmarshalJSON implements json.Unmarshaler.
func (r *Rule) UnmarshalJSON(data []byte) error {
type (
headerDTO struct {
Key string
Value string
}
ruleDTO struct {
URL string
Header headerDTO
Body string
}
)
func stringToRegexp(s string) (*regexp.Regexp, error) {
if s == "" {
return nil, nil
}
var dto ruleDTO
if err := json.Unmarshal(data, &dto); err != nil {
return regexp.Compile(s)
}
type ruleDTO struct {
URL string
Header struct {
Key string
Value string
}
Body string
}
func (r Rule) MarshalBinary() ([]byte, error) {
dto := ruleDTO{
URL: regexpToString(r.URL),
Body: regexpToString(r.Body),
}
dto.Header.Key = regexpToString(r.Header.Key)
dto.Header.Value = regexpToString(r.Header.Value)
buf := bytes.Buffer{}
err := gob.NewEncoder(&buf).Encode(dto)
if err != nil {
return nil, err
}
return buf.Bytes(), nil
}
func (r *Rule) UnmarshalBinary(data []byte) error {
dto := ruleDTO{}
err := gob.NewDecoder(bytes.NewReader(data)).Decode(&dto)
if err != nil {
return err
}
@ -234,19 +178,3 @@ func (r *Rule) UnmarshalJSON(data []byte) error {
return nil
}
func regexpToString(r *regexp.Regexp) string {
if r == nil {
return ""
}
return r.String()
}
func stringToRegexp(s string) (*regexp.Regexp, error) {
if s == "" {
return nil, nil
}
return regexp.Compile(s)
}

View File

@ -1,6 +1,10 @@
package search
import "strings"
import (
"encoding/gob"
"regexp"
"strings"
)
type Expression interface {
String() string
@ -11,7 +15,7 @@ type PrefixExpression struct {
Right Expression
}
func (pe *PrefixExpression) String() string {
func (pe PrefixExpression) String() string {
b := strings.Builder{}
b.WriteString("(")
b.WriteString(pe.Operator.String())
@ -28,7 +32,7 @@ type InfixExpression struct {
Right Expression
}
func (ie *InfixExpression) String() string {
func (ie InfixExpression) String() string {
b := strings.Builder{}
b.WriteString("(")
b.WriteString(ie.Left.String())
@ -45,6 +49,32 @@ type StringLiteral struct {
Value string
}
func (sl *StringLiteral) String() string {
func (sl StringLiteral) String() string {
return sl.Value
}
type RegexpLiteral struct {
*regexp.Regexp
}
func (rl RegexpLiteral) MarshalBinary() ([]byte, error) {
return []byte(rl.Regexp.String()), nil
}
func (rl *RegexpLiteral) UnmarshalBinary(data []byte) error {
re, err := regexp.Compile(string(data))
if err != nil {
return err
}
*rl = RegexpLiteral{re}
return nil
}
func init() {
gob.Register(PrefixExpression{})
gob.Register(InfixExpression{})
gob.Register(StringLiteral{})
gob.Register(RegexpLiteral{})
}

View File

@ -2,6 +2,7 @@ package search
import (
"fmt"
"regexp"
)
type precedence int
@ -99,7 +100,7 @@ func ParseQuery(input string) (expr Expression, err error) {
case expr == nil:
expr = right
default:
expr = &InfixExpression{
expr = InfixExpression{
Operator: TokOpAnd,
Left: expr,
Right: right,
@ -170,7 +171,7 @@ func (p *Parser) parseExpression(prec precedence) (Expression, error) {
}
func parsePrefixExpression(p *Parser) (Expression, error) {
expr := &PrefixExpression{
expr := PrefixExpression{
Operator: p.cur.Type,
}
@ -187,7 +188,7 @@ func parsePrefixExpression(p *Parser) (Expression, error) {
}
func parseInfixExpression(p *Parser, left Expression) (Expression, error) {
expr := &InfixExpression{
expr := InfixExpression{
Operator: p.cur.Type,
Left: left,
}
@ -200,13 +201,24 @@ func parseInfixExpression(p *Parser, left Expression) (Expression, error) {
return nil, fmt.Errorf("could not parse expression for right operand: %w", err)
}
if expr.Operator == TokOpRe || expr.Operator == TokOpNotRe {
if rightStr, ok := right.(StringLiteral); ok {
re, err := regexp.Compile(rightStr.Value)
if err != nil {
return nil, fmt.Errorf("could not compile regular expression %q: %w", rightStr.Value, err)
}
right = re
}
}
expr.Right = right
return expr, nil
}
func parseStringLiteral(p *Parser) (Expression, error) {
return &StringLiteral{Value: p.cur.Literal}, nil
return StringLiteral{Value: p.cur.Literal}, nil
}
func parseGroupedExpression(p *Parser) (Expression, error) {
@ -227,7 +239,7 @@ func parseGroupedExpression(p *Parser) (Expression, error) {
return nil, fmt.Errorf("could not parse expression: %w", err)
}
expr = &InfixExpression{
expr = InfixExpression{
Operator: TokOpAnd,
Left: expr,
Right: right,

View File

@ -3,6 +3,7 @@ package search
import (
"errors"
"reflect"
"regexp"
"testing"
)
@ -24,101 +25,101 @@ func TestParseQuery(t *testing.T) {
{
name: "string literal expression",
input: "foobar",
expectedExpression: &StringLiteral{Value: "foobar"},
expectedExpression: StringLiteral{Value: "foobar"},
expectedError: nil,
},
{
name: "boolean expression with equal operator",
input: "foo = bar",
expectedExpression: &InfixExpression{
expectedExpression: InfixExpression{
Operator: TokOpEq,
Left: &StringLiteral{Value: "foo"},
Right: &StringLiteral{Value: "bar"},
Left: StringLiteral{Value: "foo"},
Right: StringLiteral{Value: "bar"},
},
expectedError: nil,
},
{
name: "boolean expression with not equal operator",
input: "foo != bar",
expectedExpression: &InfixExpression{
expectedExpression: InfixExpression{
Operator: TokOpNotEq,
Left: &StringLiteral{Value: "foo"},
Right: &StringLiteral{Value: "bar"},
Left: StringLiteral{Value: "foo"},
Right: StringLiteral{Value: "bar"},
},
expectedError: nil,
},
{
name: "boolean expression with greater than operator",
input: "foo > bar",
expectedExpression: &InfixExpression{
expectedExpression: InfixExpression{
Operator: TokOpGt,
Left: &StringLiteral{Value: "foo"},
Right: &StringLiteral{Value: "bar"},
Left: StringLiteral{Value: "foo"},
Right: StringLiteral{Value: "bar"},
},
expectedError: nil,
},
{
name: "boolean expression with less than operator",
input: "foo < bar",
expectedExpression: &InfixExpression{
expectedExpression: InfixExpression{
Operator: TokOpLt,
Left: &StringLiteral{Value: "foo"},
Right: &StringLiteral{Value: "bar"},
Left: StringLiteral{Value: "foo"},
Right: StringLiteral{Value: "bar"},
},
expectedError: nil,
},
{
name: "boolean expression with greater than or equal operator",
input: "foo >= bar",
expectedExpression: &InfixExpression{
expectedExpression: InfixExpression{
Operator: TokOpGtEq,
Left: &StringLiteral{Value: "foo"},
Right: &StringLiteral{Value: "bar"},
Left: StringLiteral{Value: "foo"},
Right: StringLiteral{Value: "bar"},
},
expectedError: nil,
},
{
name: "boolean expression with less than or equal operator",
input: "foo <= bar",
expectedExpression: &InfixExpression{
expectedExpression: InfixExpression{
Operator: TokOpLtEq,
Left: &StringLiteral{Value: "foo"},
Right: &StringLiteral{Value: "bar"},
Left: StringLiteral{Value: "foo"},
Right: StringLiteral{Value: "bar"},
},
expectedError: nil,
},
{
name: "boolean expression with regular expression operator",
input: "foo =~ bar",
expectedExpression: &InfixExpression{
expectedExpression: InfixExpression{
Operator: TokOpRe,
Left: &StringLiteral{Value: "foo"},
Right: &StringLiteral{Value: "bar"},
Left: StringLiteral{Value: "foo"},
Right: regexp.MustCompile("bar"),
},
expectedError: nil,
},
{
name: "boolean expression with not regular expression operator",
input: "foo !~ bar",
expectedExpression: &InfixExpression{
expectedExpression: InfixExpression{
Operator: TokOpNotRe,
Left: &StringLiteral{Value: "foo"},
Right: &StringLiteral{Value: "bar"},
Left: StringLiteral{Value: "foo"},
Right: regexp.MustCompile("bar"),
},
expectedError: nil,
},
{
name: "boolean expression with AND, OR and NOT operators",
input: "foo AND bar OR NOT baz",
expectedExpression: &InfixExpression{
expectedExpression: InfixExpression{
Operator: TokOpAnd,
Left: &StringLiteral{Value: "foo"},
Right: &InfixExpression{
Left: StringLiteral{Value: "foo"},
Right: InfixExpression{
Operator: TokOpOr,
Left: &StringLiteral{Value: "bar"},
Right: &PrefixExpression{
Left: StringLiteral{Value: "bar"},
Right: PrefixExpression{
Operator: TokOpNot,
Right: &StringLiteral{Value: "baz"},
Right: StringLiteral{Value: "baz"},
},
},
},
@ -127,16 +128,16 @@ func TestParseQuery(t *testing.T) {
{
name: "boolean expression with nested group",
input: "(foo AND bar) OR NOT baz",
expectedExpression: &InfixExpression{
expectedExpression: InfixExpression{
Operator: TokOpOr,
Left: &InfixExpression{
Left: InfixExpression{
Operator: TokOpAnd,
Left: &StringLiteral{Value: "foo"},
Right: &StringLiteral{Value: "bar"},
Left: StringLiteral{Value: "foo"},
Right: StringLiteral{Value: "bar"},
},
Right: &PrefixExpression{
Right: PrefixExpression{
Operator: TokOpNot,
Right: &StringLiteral{Value: "baz"},
Right: StringLiteral{Value: "baz"},
},
},
expectedError: nil,
@ -144,59 +145,59 @@ func TestParseQuery(t *testing.T) {
{
name: "implicit boolean expression with string literal operands",
input: "foo bar baz",
expectedExpression: &InfixExpression{
expectedExpression: InfixExpression{
Operator: TokOpAnd,
Left: &InfixExpression{
Left: InfixExpression{
Operator: TokOpAnd,
Left: &StringLiteral{Value: "foo"},
Right: &StringLiteral{Value: "bar"},
Left: StringLiteral{Value: "foo"},
Right: StringLiteral{Value: "bar"},
},
Right: &StringLiteral{Value: "baz"},
Right: StringLiteral{Value: "baz"},
},
expectedError: nil,
},
{
name: "implicit boolean expression nested in group",
input: "(foo bar)",
expectedExpression: &InfixExpression{
expectedExpression: InfixExpression{
Operator: TokOpAnd,
Left: &StringLiteral{Value: "foo"},
Right: &StringLiteral{Value: "bar"},
Left: StringLiteral{Value: "foo"},
Right: StringLiteral{Value: "bar"},
},
expectedError: nil,
},
{
name: "implicit and explicit boolean expression with string literal operands",
input: "foo bar OR baz yolo",
expectedExpression: &InfixExpression{
expectedExpression: InfixExpression{
Operator: TokOpAnd,
Left: &InfixExpression{
Left: InfixExpression{
Operator: TokOpAnd,
Left: &StringLiteral{Value: "foo"},
Right: &InfixExpression{
Left: StringLiteral{Value: "foo"},
Right: InfixExpression{
Operator: TokOpOr,
Left: &StringLiteral{Value: "bar"},
Right: &StringLiteral{Value: "baz"},
Left: StringLiteral{Value: "bar"},
Right: StringLiteral{Value: "baz"},
},
},
Right: &StringLiteral{Value: "yolo"},
Right: StringLiteral{Value: "yolo"},
},
expectedError: nil,
},
{
name: "implicit boolean expression with comparison operands",
input: "foo=bar baz=~yolo",
expectedExpression: &InfixExpression{
expectedExpression: InfixExpression{
Operator: TokOpAnd,
Left: &InfixExpression{
Left: InfixExpression{
Operator: TokOpEq,
Left: &StringLiteral{Value: "foo"},
Right: &StringLiteral{Value: "bar"},
Left: StringLiteral{Value: "foo"},
Right: StringLiteral{Value: "bar"},
},
Right: &InfixExpression{
Right: InfixExpression{
Operator: TokOpRe,
Left: &StringLiteral{Value: "baz"},
Right: &StringLiteral{Value: "yolo"},
Left: StringLiteral{Value: "baz"},
Right: regexp.MustCompile("yolo"),
},
},
expectedError: nil,
@ -204,17 +205,17 @@ func TestParseQuery(t *testing.T) {
{
name: "eq operator takes precedence over boolean ops",
input: "foo=bar OR baz=yolo",
expectedExpression: &InfixExpression{
expectedExpression: InfixExpression{
Operator: TokOpOr,
Left: &InfixExpression{
Left: InfixExpression{
Operator: TokOpEq,
Left: &StringLiteral{Value: "foo"},
Right: &StringLiteral{Value: "bar"},
Left: StringLiteral{Value: "foo"},
Right: StringLiteral{Value: "bar"},
},
Right: &InfixExpression{
Right: InfixExpression{
Operator: TokOpEq,
Left: &StringLiteral{Value: "baz"},
Right: &StringLiteral{Value: "yolo"},
Left: StringLiteral{Value: "baz"},
Right: StringLiteral{Value: "yolo"},
},
},
expectedError: nil,