Replace GraphQL server with Connect RPC

This commit is contained in:
David Stotijn
2025-02-05 21:54:59 +01:00
parent 52c83a1989
commit 6889c9c183
53 changed files with 5875 additions and 11685 deletions

View File

@ -1,14 +1,12 @@
package bolt
import (
"bytes"
"context"
"encoding/gob"
"errors"
"fmt"
"github.com/oklog/ulid"
bolt "go.etcd.io/bbolt"
"google.golang.org/protobuf/proto"
"github.com/dstotijn/hetty/pkg/proj"
)
@ -32,13 +30,13 @@ func projectsBucket(tx *bolt.Tx) (*bolt.Bucket, error) {
return b, nil
}
func projectBucket(tx *bolt.Tx, projectID []byte) (*bolt.Bucket, error) {
func projectBucket(tx *bolt.Tx, projectID string) (*bolt.Bucket, error) {
pb, err := projectsBucket(tx)
if err != nil {
return nil, err
}
b := pb.Bucket(projectID[:])
b := pb.Bucket([]byte(projectID))
if b == nil {
return nil, ErrProjectBucketNotFound
}
@ -46,21 +44,19 @@ func projectBucket(tx *bolt.Tx, projectID []byte) (*bolt.Bucket, error) {
return b, nil
}
func (db *Database) UpsertProject(ctx context.Context, project proj.Project) error {
buf := bytes.Buffer{}
err := gob.NewEncoder(&buf).Encode(project)
if err != nil {
return fmt.Errorf("bolt: failed to encode project: %w", err)
}
err = db.bolt.Update(func(tx *bolt.Tx) error {
b, err := createNestedBucket(tx, projectsBucketName, project.ID[:])
func (db *Database) UpsertProject(ctx context.Context, project *proj.Project) error {
err := db.bolt.Update(func(tx *bolt.Tx) error {
b, err := createNestedBucket(tx, projectsBucketName, []byte(project.Id))
if err != nil {
return fmt.Errorf("bolt: failed to create project bucket: %w", err)
}
err = b.Put(projectKey, buf.Bytes())
buf, err := proto.Marshal(project)
if err != nil {
return fmt.Errorf("bolt: failed to marshal project: %w", err)
}
err = b.Put(projectKey, buf)
if err != nil {
return fmt.Errorf("bolt: failed to upsert project: %w", err)
}
@ -84,9 +80,11 @@ func (db *Database) UpsertProject(ctx context.Context, project proj.Project) err
return nil
}
func (db *Database) FindProjectByID(ctx context.Context, projectID ulid.ULID) (project proj.Project, err error) {
err = db.bolt.View(func(tx *bolt.Tx) error {
bucket, err := projectBucket(tx, projectID[:])
func (db *Database) FindProjectByID(ctx context.Context, projectID string) (*proj.Project, error) {
project := &proj.Project{}
err := db.bolt.View(func(tx *bolt.Tx) error {
bucket, err := projectBucket(tx, projectID)
if errors.Is(err, ErrProjectsBucketNotFound) || errors.Is(err, ErrProjectBucketNotFound) {
return proj.ErrProjectNotFound
}
@ -99,28 +97,28 @@ func (db *Database) FindProjectByID(ctx context.Context, projectID ulid.ULID) (p
return proj.ErrProjectNotFound
}
err = gob.NewDecoder(bytes.NewReader(rawProject)).Decode(&project)
err = proto.Unmarshal(rawProject, project)
if err != nil {
return fmt.Errorf("failed to decode project: %w", err)
return fmt.Errorf("failed to unmarshal project: %w", err)
}
return nil
})
if err != nil {
return proj.Project{}, fmt.Errorf("bolt: failed to commit transaction: %w", err)
return nil, fmt.Errorf("bolt: failed to commit transaction: %w", err)
}
return project, nil
}
func (db *Database) DeleteProject(ctx context.Context, projectID ulid.ULID) error {
func (db *Database) DeleteProject(ctx context.Context, projectID string) error {
err := db.bolt.Update(func(tx *bolt.Tx) error {
pb, err := projectsBucket(tx)
if err != nil {
return err
}
err = pb.DeleteBucket(projectID[:])
err = pb.DeleteBucket([]byte(projectID))
if err != nil {
return fmt.Errorf("failed to delete project bucket: %w", err)
}
@ -134,8 +132,8 @@ func (db *Database) DeleteProject(ctx context.Context, projectID ulid.ULID) erro
return nil
}
func (db *Database) Projects(ctx context.Context) ([]proj.Project, error) {
projects := make([]proj.Project, 0)
func (db *Database) Projects(ctx context.Context) ([]*proj.Project, error) {
projects := make([]*proj.Project, 0)
err := db.bolt.View(func(tx *bolt.Tx) error {
pb, err := projectsBucket(tx)
@ -144,7 +142,7 @@ func (db *Database) Projects(ctx context.Context) ([]proj.Project, error) {
}
err = pb.ForEachBucket(func(projectID []byte) error {
bucket, err := projectBucket(tx, projectID)
bucket, err := projectBucket(tx, string(projectID))
if err != nil {
return err
}
@ -154,16 +152,16 @@ func (db *Database) Projects(ctx context.Context) ([]proj.Project, error) {
return proj.ErrProjectNotFound
}
var project proj.Project
err = gob.NewDecoder(bytes.NewReader(rawProject)).Decode(&project)
project := &proj.Project{}
err = proto.Unmarshal(rawProject, project)
if err != nil {
return fmt.Errorf("bolt: failed to decode project: %w", err)
return fmt.Errorf("failed to unmarshal project: %w", err)
}
projects = append(projects, project)
return nil
})
if err != nil {
return fmt.Errorf("bolt: failed to iterate over projects: %w", err)
return fmt.Errorf("failed to iterate over projects: %w", err)
}
return nil

View File

@ -1,40 +1,21 @@
package bolt_test
import (
"bytes"
"context"
"encoding/gob"
"errors"
"math/rand"
"regexp"
"testing"
"time"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/oklog/ulid"
"github.com/oklog/ulid/v2"
"go.etcd.io/bbolt"
"google.golang.org/protobuf/proto"
"github.com/dstotijn/hetty/pkg/db/bolt"
"github.com/dstotijn/hetty/pkg/filter"
"github.com/dstotijn/hetty/pkg/proj"
"github.com/dstotijn/hetty/pkg/reqlog"
"github.com/dstotijn/hetty/pkg/scope"
"github.com/dstotijn/hetty/pkg/testutil"
)
//nolint:gosec
var ulidEntropy = rand.New(rand.NewSource(time.Now().UnixNano()))
var regexpCompareOpt = cmp.Comparer(func(x, y *regexp.Regexp) bool {
switch {
case x == nil && y == nil:
return true
case x == nil || y == nil:
return false
default:
return x.String() == y.String()
}
})
func TestUpsertProject(t *testing.T) {
t.Parallel()
@ -50,27 +31,20 @@ func TestUpsertProject(t *testing.T) {
}
defer db.Close()
searchExpr, err := filter.ParseQuery("foo AND bar OR NOT baz")
if err != nil {
t.Fatalf("unexpected error (expected: nil, got: %v)", err)
}
exp := proj.Project{
ID: ulid.MustNew(ulid.Timestamp(time.Now()), ulidEntropy),
Name: "foobar",
Settings: proj.Settings{
ReqLogBypassOutOfScope: true,
ReqLogOnlyFindInScope: true,
ReqLogSearchExpr: searchExpr,
ScopeRules: []scope.Rule{
{
URL: regexp.MustCompile("^https://(.*)example.com(.*)$"),
Header: scope.Header{
Key: regexp.MustCompile("^X-Foo(.*)$"),
Value: regexp.MustCompile("^foo(.*)$"),
},
Body: regexp.MustCompile("^foo(.*)"),
},
exp := &proj.Project{
Id: "foobar-project-id",
Name: "foobar",
ReqLogBypassOutOfScope: true,
ReqLogFilter: &reqlog.RequestLogsFilter{
OnlyInScope: true,
SearchExpr: "foo AND bar OR NOT baz",
},
ScopeRules: []*scope.ScopeRule{
{
UrlRegexp: "^https://(.*)example.com(.*)$",
HeaderKeyRegexp: "^X-Foo(.*)$",
HeaderValueRegexp: "^foo(.*)$",
BodyRegexp: "^foo(.*)",
},
},
}
@ -83,7 +57,7 @@ func TestUpsertProject(t *testing.T) {
var rawProject []byte
err = boltDB.View(func(tx *bbolt.Tx) error {
rawProject = tx.Bucket([]byte("projects")).Bucket(exp.ID[:]).Get([]byte("project"))
rawProject = tx.Bucket([]byte("projects")).Bucket([]byte(exp.Id)).Get([]byte("project"))
return nil
})
if err != nil {
@ -93,16 +67,14 @@ func TestUpsertProject(t *testing.T) {
t.Fatalf("expected raw project to be retrieved, got: nil")
}
got := proj.Project{}
got := &proj.Project{}
err = gob.NewDecoder(bytes.NewReader(rawProject)).Decode(&got)
err = proto.Unmarshal(rawProject, got)
if err != nil {
t.Fatalf("unexpected error decoding project: %v", err)
}
if diff := cmp.Diff(exp, got, regexpCompareOpt, cmpopts.IgnoreUnexported(proj.Project{})); diff != "" {
t.Fatalf("project not equal (-exp, +got):\n%v", diff)
}
testutil.ProtoDiff(t, "project not equal", exp, got, "id")
}
func TestFindProjectByID(t *testing.T) {
@ -123,36 +95,32 @@ func TestFindProjectByID(t *testing.T) {
}
defer db.Close()
exp := proj.Project{
ID: ulid.MustNew(ulid.Timestamp(time.Now()), ulidEntropy),
exp := &proj.Project{
Id: ulid.Make().String(),
}
buf := bytes.Buffer{}
err = gob.NewEncoder(&buf).Encode(exp)
buf, err := proto.Marshal(exp)
if err != nil {
t.Fatalf("unexpected error encoding project: %v", err)
}
err = boltDB.Update(func(tx *bbolt.Tx) error {
b, err := tx.Bucket([]byte("projects")).CreateBucket(exp.ID[:])
b, err := tx.Bucket([]byte("projects")).CreateBucket([]byte(exp.Id))
if err != nil {
return err
}
return b.Put([]byte("project"), buf.Bytes())
return b.Put([]byte("project"), buf)
})
if err != nil {
t.Fatalf("unexpected error setting project: %v", err)
}
got, err := db.FindProjectByID(context.Background(), exp.ID)
got, err := db.FindProjectByID(context.Background(), exp.Id)
if err != nil {
t.Fatalf("unexpected error finding project: %v", err)
}
if diff := cmp.Diff(exp, got, cmpopts.IgnoreUnexported(proj.Project{})); diff != "" {
t.Fatalf("project not equal (-exp, +got):\n%v", diff)
}
testutil.ProtoDiff(t, "project not equal", exp, got)
})
t.Run("project not found", func(t *testing.T) {
@ -170,7 +138,7 @@ func TestFindProjectByID(t *testing.T) {
}
defer db.Close()
projectID := ulid.MustNew(ulid.Timestamp(time.Now()), ulidEntropy)
projectID := ulid.Make().String()
_, err = db.FindProjectByID(context.Background(), projectID)
if !errors.Is(err, proj.ErrProjectNotFound) {
@ -195,9 +163,9 @@ func TestDeleteProject(t *testing.T) {
defer db.Close()
// Insert test fixture.
projectID := ulid.MustNew(ulid.Timestamp(time.Now()), ulidEntropy)
err = db.UpsertProject(context.Background(), proj.Project{
ID: projectID,
projectID := ulid.Make().String()
err = db.UpsertProject(context.Background(), &proj.Project{
Id: projectID,
})
if err != nil {
t.Fatalf("unexpected error storing project: %v", err)
@ -209,8 +177,8 @@ func TestDeleteProject(t *testing.T) {
}
var got *bbolt.Bucket
err = boltDB.View(func(tx *bbolt.Tx) error {
got = tx.Bucket([]byte("projects")).Bucket(projectID[:])
_ = boltDB.View(func(tx *bbolt.Tx) error {
got = tx.Bucket([]byte("projects")).Bucket([]byte(projectID))
return nil
})
if got != nil {
@ -233,13 +201,13 @@ func TestProjects(t *testing.T) {
}
defer db.Close()
exp := []proj.Project{
exp := []*proj.Project{
{
ID: ulid.MustNew(ulid.Timestamp(time.Now()), ulidEntropy),
Id: ulid.Make().String(),
Name: "one",
},
{
ID: ulid.MustNew(ulid.Timestamp(time.Now())+100, ulidEntropy),
Id: ulid.Make().String(),
Name: "two",
},
}
@ -261,7 +229,5 @@ func TestProjects(t *testing.T) {
t.Fatalf("expected %v projects, got: %v", len(exp), len(got))
}
if diff := cmp.Diff(exp, got, cmpopts.IgnoreUnexported(proj.Project{})); diff != "" {
t.Fatalf("projects not equal (-exp, +got):\n%v", diff)
}
testutil.ProtoSlicesDiff(t, "projects not equal", exp, got)
}

View File

@ -1,25 +1,23 @@
package bolt
import (
"bytes"
"context"
"encoding/gob"
"errors"
"fmt"
"github.com/oklog/ulid"
bolt "go.etcd.io/bbolt"
"google.golang.org/protobuf/proto"
"github.com/dstotijn/hetty/pkg/http"
"github.com/dstotijn/hetty/pkg/reqlog"
"github.com/dstotijn/hetty/pkg/scope"
)
var ErrRequestLogsBucketNotFound = errors.New("bolt: request logs bucket not found")
var reqLogsBucketName = []byte("request_logs")
func requestLogsBucket(tx *bolt.Tx, projectID ulid.ULID) (*bolt.Bucket, error) {
pb, err := projectBucket(tx, projectID[:])
func requestLogsBucket(tx *bolt.Tx, projectID string) (*bolt.Bucket, error) {
pb, err := projectBucket(tx, projectID)
if err != nil {
return nil, err
}
@ -32,47 +30,36 @@ func requestLogsBucket(tx *bolt.Tx, projectID ulid.ULID) (*bolt.Bucket, error) {
return b, nil
}
func (db *Database) FindRequestLogs(ctx context.Context, filter reqlog.FindRequestsFilter, scope *scope.Scope) (reqLogs []reqlog.RequestLog, err error) {
if filter.ProjectID.Compare(ulid.ULID{}) == 0 {
return nil, reqlog.ErrProjectIDMustBeSet
}
func (db *Database) FindRequestLogs(ctx context.Context, projectID string, filterFn func(*reqlog.HttpRequestLog) (bool, error)) (reqLogs []*reqlog.HttpRequestLog, err error) {
tx, err := db.bolt.Begin(false)
if err != nil {
return nil, fmt.Errorf("bolt: failed to begin transaction: %w", err)
}
defer tx.Rollback()
b, err := requestLogsBucket(tx, filter.ProjectID)
b, err := requestLogsBucket(tx, projectID)
if err != nil {
return nil, fmt.Errorf("bolt: failed to get request logs bucket: %w", err)
}
err = b.ForEach(func(reqLogID, rawReqLog []byte) error {
var reqLog reqlog.RequestLog
err = gob.NewDecoder(bytes.NewReader(rawReqLog)).Decode(&reqLog)
var reqLog reqlog.HttpRequestLog
err = proto.Unmarshal(rawReqLog, &reqLog)
if err != nil {
return fmt.Errorf("failed to decode request log: %w", err)
}
if filter.OnlyInScope && !reqLog.MatchScope(scope) {
return nil
}
// Filter by search expression. TODO: Once pagination is introduced,
// this filter logic should be done as items are retrieved.
if filter.SearchExpr != nil {
match, err := reqLog.Matches(filter.SearchExpr)
if filterFn != nil {
match, err := filterFn(&reqLog)
if err != nil {
return fmt.Errorf("failed to match search expression for request log (id: %v): %w", reqLogID, err)
return fmt.Errorf("failed to filter request log: %w", err)
}
if !match {
return nil
}
}
reqLogs = append(reqLogs, reqLog)
reqLogs = append(reqLogs, &reqLog)
return nil
})
if err != nil {
@ -87,46 +74,45 @@ func (db *Database) FindRequestLogs(ctx context.Context, filter reqlog.FindReque
return reqLogs, nil
}
func (db *Database) FindRequestLogByID(ctx context.Context, projectID, reqLogID ulid.ULID) (reqLog reqlog.RequestLog, err error) {
err = db.bolt.View(func(tx *bolt.Tx) error {
func (db *Database) FindRequestLogByID(ctx context.Context, projectID, reqLogID string) (*reqlog.HttpRequestLog, error) {
reqLog := &reqlog.HttpRequestLog{}
err := db.bolt.View(func(tx *bolt.Tx) error {
b, err := requestLogsBucket(tx, projectID)
if err != nil {
return fmt.Errorf("bolt: failed to get request logs bucket: %w", err)
}
rawReqLog := b.Get(reqLogID[:])
rawReqLog := b.Get([]byte(reqLogID))
if rawReqLog == nil {
return reqlog.ErrRequestNotFound
return reqlog.ErrRequestLogNotFound
}
err = gob.NewDecoder(bytes.NewReader(rawReqLog)).Decode(&reqLog)
err = proto.Unmarshal(rawReqLog, reqLog)
if err != nil {
return fmt.Errorf("failed to decode request log: %w", err)
return fmt.Errorf("failed to unmarshal request log: %w", err)
}
return nil
})
if err != nil {
return reqlog.RequestLog{}, fmt.Errorf("bolt: failed to find request log by ID: %w", err)
return nil, fmt.Errorf("bolt: failed to find request log by ID: %w", err)
}
return reqLog, nil
}
func (db *Database) StoreRequestLog(ctx context.Context, reqLog reqlog.RequestLog) error {
buf := bytes.Buffer{}
err := gob.NewEncoder(&buf).Encode(reqLog)
func (db *Database) StoreRequestLog(ctx context.Context, reqLog *reqlog.HttpRequestLog) error {
encReqLog, err := proto.Marshal(reqLog)
if err != nil {
return fmt.Errorf("bolt: failed to encode request log: %w", err)
return fmt.Errorf("bolt: failed to marshal request log: %w", err)
}
err = db.bolt.Update(func(txn *bolt.Tx) error {
b, err := requestLogsBucket(txn, reqLog.ProjectID)
b, err := requestLogsBucket(txn, reqLog.ProjectId)
if err != nil {
return fmt.Errorf("failed to get request logs bucket: %w", err)
}
err = b.Put(reqLog.ID[:], buf.Bytes())
err = b.Put([]byte(reqLog.Id), encReqLog)
if err != nil {
return fmt.Errorf("failed to put request log: %w", err)
}
@ -140,40 +126,32 @@ func (db *Database) StoreRequestLog(ctx context.Context, reqLog reqlog.RequestLo
return nil
}
func (db *Database) StoreResponseLog(ctx context.Context, projectID, reqLogID ulid.ULID, resLog reqlog.ResponseLog) error {
buf := bytes.Buffer{}
err := gob.NewEncoder(&buf).Encode(resLog)
if err != nil {
return fmt.Errorf("bolt: failed to encode response log: %w", err)
}
err = db.bolt.Update(func(txn *bolt.Tx) error {
func (db *Database) StoreResponseLog(ctx context.Context, projectID, reqLogID string, resLog *http.Response) error {
err := db.bolt.Update(func(txn *bolt.Tx) error {
b, err := requestLogsBucket(txn, projectID)
if err != nil {
return fmt.Errorf("failed to get request logs bucket: %w", err)
}
rawReqLog := b.Get(reqLogID[:])
if rawReqLog == nil {
return reqlog.ErrRequestNotFound
encReqLog := b.Get([]byte(reqLogID))
if encReqLog == nil {
return reqlog.ErrRequestLogNotFound
}
var reqLog reqlog.RequestLog
err = gob.NewDecoder(bytes.NewReader(rawReqLog)).Decode(&reqLog)
var reqLog reqlog.HttpRequestLog
err = proto.Unmarshal(encReqLog, &reqLog)
if err != nil {
return fmt.Errorf("failed to decode request log: %w", err)
}
reqLog.Response = &resLog
reqLog.Response = resLog
buf := bytes.Buffer{}
err = gob.NewEncoder(&buf).Encode(reqLog)
encReqLog, err = proto.Marshal(&reqLog)
if err != nil {
return fmt.Errorf("failed to encode request log: %w", err)
}
err = b.Put(reqLog.ID[:], buf.Bytes())
err = b.Put([]byte(reqLogID), encReqLog)
if err != nil {
return fmt.Errorf("failed to put request log: %w", err)
}
@ -187,9 +165,9 @@ func (db *Database) StoreResponseLog(ctx context.Context, projectID, reqLogID ul
return nil
}
func (db *Database) ClearRequestLogs(ctx context.Context, projectID ulid.ULID) error {
func (db *Database) ClearRequestLogs(ctx context.Context, projectID string) error {
err := db.bolt.Update(func(txn *bolt.Tx) error {
pb, err := projectBucket(txn, projectID[:])
pb, err := projectBucket(txn, projectID)
if err != nil {
return fmt.Errorf("failed to get project bucket: %w", err)
}

View File

@ -2,47 +2,21 @@ package bolt_test
import (
"context"
"errors"
"net/http"
"net/url"
"testing"
"time"
"github.com/google/go-cmp/cmp"
"github.com/oklog/ulid"
"github.com/oklog/ulid/v2"
"go.etcd.io/bbolt"
"github.com/dstotijn/hetty/pkg/db/bolt"
"github.com/dstotijn/hetty/pkg/http"
"github.com/dstotijn/hetty/pkg/proj"
"github.com/dstotijn/hetty/pkg/reqlog"
"github.com/dstotijn/hetty/pkg/testutil"
)
func TestFindRequestLogs(t *testing.T) {
t.Parallel()
t.Run("without project ID in filter", func(t *testing.T) {
t.Parallel()
path := t.TempDir() + "bolt.db"
boltDB, err := bbolt.Open(path, 0o600, nil)
if err != nil {
t.Fatalf("failed to open bolt database: %v", err)
}
db, err := bolt.DatabaseFromBoltDB(boltDB)
if err != nil {
t.Fatalf("failed to create database: %v", err)
}
defer db.Close()
filter := reqlog.FindRequestsFilter{}
_, err = db.FindRequestLogs(context.Background(), filter, nil)
if !errors.Is(err, reqlog.ErrProjectIDMustBeSet) {
t.Fatalf("expected `reqlog.ErrProjectIDMustBeSet`, got: %v", err)
}
})
t.Run("returns request logs and related response logs", func(t *testing.T) {
t.Parallel()
@ -58,44 +32,56 @@ func TestFindRequestLogs(t *testing.T) {
}
defer db.Close()
projectID := ulid.MustNew(ulid.Timestamp(time.Now()), ulidEntropy)
projectID := ulid.Make().String()
err = db.UpsertProject(context.Background(), proj.Project{
ID: projectID,
err = db.UpsertProject(context.Background(), &proj.Project{
Id: projectID,
})
if err != nil {
t.Fatalf("unexpected error upserting project: %v", err)
}
fixtures := []reqlog.RequestLog{
fixtures := []*reqlog.HttpRequestLog{
{
ID: ulid.MustNew(ulid.Timestamp(time.Now()), ulidEntropy),
ProjectID: projectID,
URL: mustParseURL(t, "https://example.com/foobar"),
Method: http.MethodPost,
Proto: "HTTP/1.1",
Header: http.Header{
"X-Foo": []string{"baz"},
Id: ulid.Make().String(),
ProjectId: projectID,
Request: &http.Request{
Url: "https://example.com/foobar",
Method: http.Method_METHOD_POST,
Protocol: http.Protocol_PROTOCOL_HTTP11,
Headers: []*http.Header{
{Key: "X-Foo", Value: "baz"},
},
Body: []byte("foo"),
},
Body: []byte("foo"),
Response: &reqlog.ResponseLog{
Proto: "HTTP/1.1",
Response: &http.Response{
Status: "200 OK",
StatusCode: 200,
Header: http.Header{
"X-Yolo": []string{"swag"},
Headers: []*http.Header{
{Key: "X-Yolo", Value: "swag"},
},
Body: []byte("bar"),
},
},
{
ID: ulid.MustNew(ulid.Timestamp(time.Now())+100, ulidEntropy),
ProjectID: projectID,
URL: mustParseURL(t, "https://example.com/foo?bar=baz"),
Method: http.MethodGet,
Proto: "HTTP/1.1",
Header: http.Header{
"X-Foo": []string{"baz"},
Id: ulid.Make().String(),
ProjectId: projectID,
Request: &http.Request{
Url: "https://example.com/foo?bar=baz",
Method: http.Method_METHOD_GET,
Protocol: http.Protocol_PROTOCOL_HTTP11,
Headers: []*http.Header{
{Key: "X-Foo", Value: "baz"},
},
Body: []byte("foo"),
},
Response: &http.Response{
Status: "200 OK",
StatusCode: 200,
Headers: []*http.Header{
{Key: "X-Yolo", Value: "swag"},
},
Body: []byte("bar"),
},
},
}
@ -108,34 +94,17 @@ func TestFindRequestLogs(t *testing.T) {
}
}
filter := reqlog.FindRequestsFilter{
ProjectID: projectID,
}
got, err := db.FindRequestLogs(context.Background(), filter, nil)
got, err := db.FindRequestLogs(context.Background(), projectID, nil)
if err != nil {
t.Fatalf("unexpected error finding request logs: %v", err)
}
// We expect the found request logs are *reversed*, e.g. newest first.
exp := make([]reqlog.RequestLog, len(fixtures))
exp := make([]*reqlog.HttpRequestLog, len(fixtures))
for i, j := 0, len(fixtures)-1; i < j; i, j = i+1, j-1 {
exp[i], exp[j] = fixtures[j], fixtures[i]
}
if diff := cmp.Diff(exp, got); diff != "" {
t.Fatalf("request logs not equal (-exp, +got):\n%v", diff)
}
testutil.ProtoSlicesDiff(t, "request logs not equal", exp, got)
})
}
func mustParseURL(t *testing.T, s string) *url.URL {
t.Helper()
u, err := url.Parse(s)
if err != nil {
panic(err)
}
return u
}

View File

@ -1,16 +1,13 @@
package bolt
import (
"bytes"
"context"
"encoding/gob"
"errors"
"fmt"
"github.com/oklog/ulid"
bolt "go.etcd.io/bbolt"
"google.golang.org/protobuf/proto"
"github.com/dstotijn/hetty/pkg/scope"
"github.com/dstotijn/hetty/pkg/sender"
)
@ -18,8 +15,8 @@ var ErrSenderRequestsBucketNotFound = errors.New("bolt: sender requests bucket n
var senderReqsBucketName = []byte("sender_requests")
func senderReqsBucket(tx *bolt.Tx, projectID ulid.ULID) (*bolt.Bucket, error) {
pb, err := projectBucket(tx, projectID[:])
func senderReqsBucket(tx *bolt.Tx, projectID string) (*bolt.Bucket, error) {
pb, err := projectBucket(tx, projectID)
if err != nil {
return nil, err
}
@ -32,21 +29,19 @@ func senderReqsBucket(tx *bolt.Tx, projectID ulid.ULID) (*bolt.Bucket, error) {
return b, nil
}
func (db *Database) StoreSenderRequest(ctx context.Context, req sender.Request) error {
buf := bytes.Buffer{}
err := gob.NewEncoder(&buf).Encode(req)
func (db *Database) StoreSenderRequest(ctx context.Context, req *sender.Request) error {
rawReq, err := proto.Marshal(req)
if err != nil {
return fmt.Errorf("bolt: failed to encode sender request: %w", err)
return fmt.Errorf("bolt: failed to marshal sender request: %w", err)
}
err = db.bolt.Update(func(tx *bolt.Tx) error {
senderReqsBucket, err := senderReqsBucket(tx, req.ProjectID)
senderReqsBucket, err := senderReqsBucket(tx, req.ProjectId)
if err != nil {
return fmt.Errorf("failed to get sender requests bucket: %w", err)
}
err = senderReqsBucket.Put(req.ID[:], buf.Bytes())
err = senderReqsBucket.Put([]byte(req.Id), rawReq)
if err != nil {
return fmt.Errorf("failed to put sender request: %w", err)
}
@ -60,9 +55,9 @@ func (db *Database) StoreSenderRequest(ctx context.Context, req sender.Request)
return nil
}
func (db *Database) FindSenderRequestByID(ctx context.Context, projectID, senderReqID ulid.ULID) (req sender.Request, err error) {
if projectID.Compare(ulid.ULID{}) == 0 {
return sender.Request{}, sender.ErrProjectIDMustBeSet
func (db *Database) FindSenderRequestByID(ctx context.Context, projectID, senderReqID string) (req *sender.Request, err error) {
if projectID == "" {
return nil, sender.ErrProjectIDMustBeSet
}
err = db.bolt.View(func(tx *bolt.Tx) error {
@ -71,63 +66,49 @@ func (db *Database) FindSenderRequestByID(ctx context.Context, projectID, sender
return fmt.Errorf("failed to get sender requests bucket: %w", err)
}
rawSenderReq := senderReqsBucket.Get(senderReqID[:])
rawSenderReq := senderReqsBucket.Get([]byte(senderReqID))
if rawSenderReq == nil {
return sender.ErrRequestNotFound
}
err = gob.NewDecoder(bytes.NewReader(rawSenderReq)).Decode(&req)
req = &sender.Request{}
err = proto.Unmarshal(rawSenderReq, req)
if err != nil {
return fmt.Errorf("failed to decode sender request: %w", err)
return fmt.Errorf("failed to unmarshal sender request: %w", err)
}
return nil
})
if err != nil {
return sender.Request{}, fmt.Errorf("bolt: failed to commit transaction: %w", err)
return nil, fmt.Errorf("bolt: failed to commit transaction: %w", err)
}
return req, nil
}
func (db *Database) FindSenderRequests(ctx context.Context, filter sender.FindRequestsFilter, scope *scope.Scope) (reqs []sender.Request, err error) {
if filter.ProjectID.Compare(ulid.ULID{}) == 0 {
return nil, sender.ErrProjectIDMustBeSet
}
func (db *Database) FindSenderRequests(ctx context.Context, projectID string, filterFn func(req *sender.Request) (bool, error)) (reqs []*sender.Request, err error) {
tx, err := db.bolt.Begin(false)
if err != nil {
return nil, fmt.Errorf("bolt: failed to begin transaction: %w", err)
}
defer tx.Rollback()
b, err := senderReqsBucket(tx, filter.ProjectID)
b, err := senderReqsBucket(tx, projectID)
if err != nil {
return nil, fmt.Errorf("failed to get sender requests bucket: %w", err)
}
err = b.ForEach(func(senderReqID, rawSenderReq []byte) error {
var req sender.Request
err = gob.NewDecoder(bytes.NewReader(rawSenderReq)).Decode(&req)
req := &sender.Request{}
err = proto.Unmarshal(rawSenderReq, req)
if err != nil {
return fmt.Errorf("failed to decode sender request: %w", err)
return fmt.Errorf("failed to unmarshal sender request: %w", err)
}
if filter.OnlyInScope {
if !req.MatchScope(scope) {
return nil
}
}
// Filter by search expression. TODO: Once pagination is introduced,
// this filter logic should be done as items are retrieved.
if filter.SearchExpr != nil {
match, err := req.Matches(filter.SearchExpr)
if filterFn != nil {
match, err := filterFn(req)
if err != nil {
return fmt.Errorf(
"bolt: failed to match search expression for sender request (id: %v): %w",
senderReqID, err,
)
return fmt.Errorf("failed to filter sender request: %w", err)
}
if !match {
@ -150,7 +131,7 @@ func (db *Database) FindSenderRequests(ctx context.Context, filter sender.FindRe
return reqs, nil
}
func (db *Database) DeleteSenderRequests(ctx context.Context, projectID ulid.ULID) error {
func (db *Database) DeleteSenderRequests(ctx context.Context, projectID string) error {
err := db.bolt.Update(func(tx *bolt.Tx) error {
senderReqsBucket, err := senderReqsBucket(tx, projectID)
if err != nil {

View File

@ -3,19 +3,17 @@ package bolt_test
import (
"context"
"errors"
"net/http"
"net/url"
"testing"
"time"
"github.com/google/go-cmp/cmp"
"github.com/oklog/ulid"
"github.com/oklog/ulid/v2"
"go.etcd.io/bbolt"
"github.com/dstotijn/hetty/pkg/db/bolt"
"github.com/dstotijn/hetty/pkg/http"
"github.com/dstotijn/hetty/pkg/proj"
"github.com/dstotijn/hetty/pkg/reqlog"
"github.com/dstotijn/hetty/pkg/sender"
"github.com/dstotijn/hetty/pkg/testutil"
)
var exampleURL = func() *url.URL {
@ -43,11 +41,11 @@ func TestFindRequestByID(t *testing.T) {
}
defer db.Close()
projectID := ulid.MustNew(ulid.Timestamp(time.Now()), ulidEntropy)
reqID := ulid.MustNew(ulid.Timestamp(time.Now()), ulidEntropy)
projectID := "foobar-project-id"
reqID := "foobar-req-id"
err = db.UpsertProject(context.Background(), proj.Project{
ID: projectID,
err = db.UpsertProject(context.Background(), &proj.Project{
Id: projectID,
})
if err != nil {
t.Fatalf("unexpected error upserting project: %v", err)
@ -67,24 +65,31 @@ func TestFindRequestByID(t *testing.T) {
t.Run("sender request found", func(t *testing.T) {
t.Parallel()
exp := sender.Request{
ID: ulid.MustNew(ulid.Timestamp(time.Now()), ulidEntropy),
ProjectID: projectID,
SourceRequestLogID: ulid.MustNew(ulid.Timestamp(time.Now()), ulidEntropy),
URL: exampleURL,
Method: http.MethodGet,
Proto: sender.HTTPProto20,
Header: http.Header{
"X-Foo": []string{"bar"},
exp := &sender.Request{
Id: "foobar-sender-req-id",
ProjectId: projectID,
SourceRequestLogId: "foobar-req-log-id",
HttpRequest: &http.Request{
Url: exampleURL.String(),
Method: http.Method_METHOD_GET,
Protocol: http.Protocol_PROTOCOL_HTTP20,
Headers: []*http.Header{
{
Key: "X-Foo",
Value: "bar",
},
},
Body: []byte("foo"),
},
Body: []byte("foo"),
Response: &reqlog.ResponseLog{
Proto: "HTTP/2.0",
HttpResponse: &http.Response{
Protocol: http.Protocol_PROTOCOL_HTTP20,
Status: "200 OK",
StatusCode: 200,
Header: http.Header{
"X-Yolo": []string{"swag"},
Headers: []*http.Header{
{
Key: "X-Yolo",
Value: "swag",
},
},
Body: []byte("bar"),
},
@ -95,14 +100,12 @@ func TestFindRequestByID(t *testing.T) {
t.Fatalf("unexpected error (expected: nil, got: %v)", err)
}
got, err := db.FindSenderRequestByID(context.Background(), exp.ProjectID, exp.ID)
got, err := db.FindSenderRequestByID(context.Background(), projectID, exp.Id)
if err != nil {
t.Fatalf("unexpected error (expected: nil, got: %v)", err)
}
if diff := cmp.Diff(exp, got); diff != "" {
t.Fatalf("sender request not equal (-exp, +got):\n%v", diff)
}
testutil.ProtoDiff(t, "sender request not equal", exp, got, "id")
})
})
}
@ -110,30 +113,6 @@ func TestFindRequestByID(t *testing.T) {
func TestFindSenderRequests(t *testing.T) {
t.Parallel()
t.Run("without project ID in filter", func(t *testing.T) {
t.Parallel()
path := t.TempDir() + "bolt.db"
boltDB, err := bbolt.Open(path, 0o600, nil)
if err != nil {
t.Fatalf("failed to open bolt database: %v", err)
}
defer boltDB.Close()
db, err := bolt.DatabaseFromBoltDB(boltDB)
if err != nil {
t.Fatalf("failed to create database: %v", err)
}
defer db.Close()
filter := sender.FindRequestsFilter{}
_, err = db.FindSenderRequests(context.Background(), filter, nil)
if !errors.Is(err, sender.ErrProjectIDMustBeSet) {
t.Fatalf("expected `sender.ErrProjectIDMustBeSet`, got: %v", err)
}
})
t.Run("returns sender requests and related response logs", func(t *testing.T) {
t.Parallel()
@ -150,48 +129,61 @@ func TestFindSenderRequests(t *testing.T) {
}
defer db.Close()
projectID := ulid.MustNew(ulid.Timestamp(time.Now()), ulidEntropy)
projectID := "foobar-project-id"
err = db.UpsertProject(context.Background(), proj.Project{
ID: projectID,
Name: "foobar",
Settings: proj.Settings{},
err = db.UpsertProject(context.Background(), &proj.Project{
Id: projectID,
Name: "foobar",
})
if err != nil {
t.Fatalf("unexpected error creating project (expected: nil, got: %v)", err)
}
fixtures := []sender.Request{
fixtures := []*sender.Request{
{
ID: ulid.MustNew(ulid.Timestamp(time.Now()), ulidEntropy),
ProjectID: projectID,
SourceRequestLogID: ulid.MustNew(ulid.Timestamp(time.Now()), ulidEntropy),
URL: exampleURL,
Method: http.MethodPost,
Proto: "HTTP/1.1",
Header: http.Header{
"X-Foo": []string{"baz"},
Id: ulid.Make().String(),
ProjectId: projectID,
SourceRequestLogId: "foobar-req-log-id-1",
HttpRequest: &http.Request{
Url: exampleURL.String(),
Method: http.Method_METHOD_POST,
Protocol: http.Protocol_PROTOCOL_HTTP11,
Headers: []*http.Header{
{
Key: "X-Foo",
Value: "baz",
},
},
Body: []byte("foo"),
},
Body: []byte("foo"),
Response: &reqlog.ResponseLog{
Proto: "HTTP/1.1",
HttpResponse: &http.Response{
Protocol: http.Protocol_PROTOCOL_HTTP11,
Status: "200 OK",
StatusCode: 200,
Header: http.Header{
"X-Yolo": []string{"swag"},
Headers: []*http.Header{
{
Key: "X-Yolo",
Value: "swag",
},
},
Body: []byte("bar"),
},
},
{
ID: ulid.MustNew(ulid.Timestamp(time.Now())+100, ulidEntropy),
ProjectID: projectID,
SourceRequestLogID: ulid.MustNew(ulid.Timestamp(time.Now()), ulidEntropy),
URL: exampleURL,
Method: http.MethodGet,
Proto: "HTTP/1.1",
Header: http.Header{
"X-Foo": []string{"baz"},
Id: ulid.Make().String(),
ProjectId: projectID,
SourceRequestLogId: "foobar-req-log-id-2",
HttpRequest: &http.Request{
Url: exampleURL.String(),
Method: http.Method_METHOD_GET,
Protocol: http.Protocol_PROTOCOL_HTTP11,
Headers: []*http.Header{
{
Key: "X-Foo",
Value: "baz",
},
},
Body: []byte("foo"),
},
},
}
@ -204,23 +196,17 @@ func TestFindSenderRequests(t *testing.T) {
}
}
filter := sender.FindRequestsFilter{
ProjectID: projectID,
}
got, err := db.FindSenderRequests(context.Background(), filter, nil)
got, err := db.FindSenderRequests(context.Background(), projectID, nil)
if err != nil {
t.Fatalf("unexpected error finding sender requests: %v", err)
}
// We expect the found sender requests are *reversed*, e.g. newest first.
exp := make([]sender.Request, len(fixtures))
exp := make([]*sender.Request, len(fixtures))
for i, j := 0, len(fixtures)-1; i < j; i, j = i+1, j-1 {
exp[i], exp[j] = fixtures[j], fixtures[i]
}
if diff := cmp.Diff(exp, got); diff != "" {
t.Fatalf("sender requests not equal (-exp, +got):\n%v", diff)
}
testutil.ProtoSlicesDiff(t, "sender requests not equal", exp, got)
})
}