Replace BadgerDB with bbolt

This commit is contained in:
David Stotijn
2025-01-13 23:15:18 +01:00
parent 24c2ecfa4b
commit fcf0e1c51e
23 changed files with 1153 additions and 1244 deletions

60
pkg/db/bolt/bolt.go Normal file
View File

@ -0,0 +1,60 @@
package bolt
import (
"fmt"
bolt "go.etcd.io/bbolt"
)
// Database is used to store and retrieve data from an underlying Bolt database.
type Database struct {
bolt *bolt.DB
}
// OpenDatabase opens a new Bolt database.
func OpenDatabase(path string, opts *bolt.Options) (*Database, error) {
db, err := bolt.Open(path, 0o600, opts)
if err != nil {
return nil, fmt.Errorf("bolt: failed to open database: %w", err)
}
return DatabaseFromBoltDB(db)
}
// Close closes the underlying Bolt database.
func (db *Database) Close() error {
return db.bolt.Close()
}
// DatabaseFromBoltDB returns a Database with `db` set as the underlying Bolt
// database.
func DatabaseFromBoltDB(db *bolt.DB) (*Database, error) {
err := db.Update(func(tx *bolt.Tx) error {
_, err := tx.CreateBucketIfNotExists(projectsBucketName)
if err != nil {
return err
}
return nil
})
if err != nil {
return nil, fmt.Errorf("bolt: failed to create projects bucket: %w", err)
}
return &Database{bolt: db}, nil
}
func createNestedBucket(tx *bolt.Tx, names ...[]byte) (b *bolt.Bucket, err error) {
for i, name := range names {
if b == nil {
b, err = tx.CreateBucketIfNotExists(name)
} else {
b, err = b.CreateBucketIfNotExists(name)
}
if err != nil {
return nil, fmt.Errorf("bolt: failed to create nested bucket %q: %w", names[:i+1], err)
}
}
return b, nil
}

23
pkg/db/bolt/logger.go Normal file
View File

@ -0,0 +1,23 @@
package bolt
import (
bolt "go.etcd.io/bbolt"
"go.uber.org/zap"
)
// Interface guard.
var _ bolt.Logger = (*Logger)(nil)
type Logger struct {
*zap.SugaredLogger
}
// Warning implements bbolt.Logger.
func (l *Logger) Warning(v ...interface{}) {
l.Warn(v...)
}
// Warningf implements bbolt.Logger.
func (l *Logger) Warningf(format string, v ...interface{}) {
l.Warnf(format, v...)
}

176
pkg/db/bolt/proj.go Normal file
View File

@ -0,0 +1,176 @@
package bolt
import (
"bytes"
"context"
"encoding/gob"
"errors"
"fmt"
"github.com/oklog/ulid"
bolt "go.etcd.io/bbolt"
"github.com/dstotijn/hetty/pkg/proj"
)
var (
ErrProjectsBucketNotFound = errors.New("bolt: projects bucket not found")
ErrProjectBucketNotFound = errors.New("bolt: project bucket not found")
)
var (
projectsBucketName = []byte("projects")
projectKey = []byte("project")
)
func projectsBucket(tx *bolt.Tx) (*bolt.Bucket, error) {
b := tx.Bucket(projectsBucketName)
if b == nil {
return nil, ErrProjectsBucketNotFound
}
return b, nil
}
func projectBucket(tx *bolt.Tx, projectID []byte) (*bolt.Bucket, error) {
pb, err := projectsBucket(tx)
if err != nil {
return nil, err
}
b := pb.Bucket(projectID[:])
if b == nil {
return nil, ErrProjectBucketNotFound
}
return b, nil
}
func (db *Database) UpsertProject(ctx context.Context, project proj.Project) error {
buf := bytes.Buffer{}
err := gob.NewEncoder(&buf).Encode(project)
if err != nil {
return fmt.Errorf("bolt: failed to encode project: %w", err)
}
err = db.bolt.Update(func(tx *bolt.Tx) error {
b, err := createNestedBucket(tx, projectsBucketName, project.ID[:])
if err != nil {
return fmt.Errorf("bolt: failed to create project bucket: %w", err)
}
err = b.Put(projectKey, buf.Bytes())
if err != nil {
return fmt.Errorf("bolt: failed to upsert project: %w", err)
}
_, err = b.CreateBucketIfNotExists(reqLogsBucketName)
if err != nil {
return fmt.Errorf("bolt: failed to create request logs bucket: %w", err)
}
_, err = b.CreateBucketIfNotExists(senderReqsBucketName)
if err != nil {
return fmt.Errorf("bolt: failed to create sender requests bucket: %w", err)
}
return nil
})
if err != nil {
return fmt.Errorf("bolt: failed to commit transaction: %w", err)
}
return nil
}
func (db *Database) FindProjectByID(ctx context.Context, projectID ulid.ULID) (project proj.Project, err error) {
err = db.bolt.View(func(tx *bolt.Tx) error {
bucket, err := projectBucket(tx, projectID[:])
if errors.Is(err, ErrProjectsBucketNotFound) || errors.Is(err, ErrProjectBucketNotFound) {
return proj.ErrProjectNotFound
}
if err != nil {
return err
}
rawProject := bucket.Get(projectKey)
if rawProject == nil {
return proj.ErrProjectNotFound
}
err = gob.NewDecoder(bytes.NewReader(rawProject)).Decode(&project)
if err != nil {
return fmt.Errorf("failed to decode project: %w", err)
}
return nil
})
if err != nil {
return proj.Project{}, fmt.Errorf("bolt: failed to commit transaction: %w", err)
}
return project, nil
}
func (db *Database) DeleteProject(ctx context.Context, projectID ulid.ULID) error {
err := db.bolt.Update(func(tx *bolt.Tx) error {
pb, err := projectsBucket(tx)
if err != nil {
return err
}
err = pb.DeleteBucket(projectID[:])
if err != nil {
return fmt.Errorf("failed to delete project bucket: %w", err)
}
return nil
})
if err != nil {
return fmt.Errorf("bolt: failed to commit transaction: %w", err)
}
return nil
}
func (db *Database) Projects(ctx context.Context) ([]proj.Project, error) {
projects := make([]proj.Project, 0)
err := db.bolt.View(func(tx *bolt.Tx) error {
pb, err := projectsBucket(tx)
if err != nil {
return err
}
err = pb.ForEachBucket(func(projectID []byte) error {
bucket, err := projectBucket(tx, projectID)
if err != nil {
return err
}
rawProject := bucket.Get(projectKey)
if rawProject == nil {
return proj.ErrProjectNotFound
}
var project proj.Project
err = gob.NewDecoder(bytes.NewReader(rawProject)).Decode(&project)
if err != nil {
return fmt.Errorf("bolt: failed to decode project: %w", err)
}
projects = append(projects, project)
return nil
})
if err != nil {
return fmt.Errorf("bolt: failed to iterate over projects: %w", err)
}
return nil
})
if err != nil {
return nil, fmt.Errorf("bolt: failed to commit transaction: %w", err)
}
return projects, nil
}

267
pkg/db/bolt/proj_test.go Normal file
View File

@ -0,0 +1,267 @@
package bolt_test
import (
"bytes"
"context"
"encoding/gob"
"errors"
"math/rand"
"regexp"
"testing"
"time"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/oklog/ulid"
"go.etcd.io/bbolt"
"github.com/dstotijn/hetty/pkg/db/bolt"
"github.com/dstotijn/hetty/pkg/filter"
"github.com/dstotijn/hetty/pkg/proj"
"github.com/dstotijn/hetty/pkg/scope"
)
//nolint:gosec
var ulidEntropy = rand.New(rand.NewSource(time.Now().UnixNano()))
var regexpCompareOpt = cmp.Comparer(func(x, y *regexp.Regexp) bool {
switch {
case x == nil && y == nil:
return true
case x == nil || y == nil:
return false
default:
return x.String() == y.String()
}
})
func TestUpsertProject(t *testing.T) {
t.Parallel()
path := t.TempDir() + "bolt.db"
boltDB, err := bbolt.Open(path, 0o600, nil)
if err != nil {
t.Fatalf("failed to open bolt database: %v", err)
}
db, err := bolt.DatabaseFromBoltDB(boltDB)
if err != nil {
t.Fatalf("failed to create database: %v", err)
}
defer db.Close()
searchExpr, err := filter.ParseQuery("foo AND bar OR NOT baz")
if err != nil {
t.Fatalf("unexpected error (expected: nil, got: %v)", err)
}
exp := proj.Project{
ID: ulid.MustNew(ulid.Timestamp(time.Now()), ulidEntropy),
Name: "foobar",
Settings: proj.Settings{
ReqLogBypassOutOfScope: true,
ReqLogOnlyFindInScope: true,
ReqLogSearchExpr: searchExpr,
ScopeRules: []scope.Rule{
{
URL: regexp.MustCompile("^https://(.*)example.com(.*)$"),
Header: scope.Header{
Key: regexp.MustCompile("^X-Foo(.*)$"),
Value: regexp.MustCompile("^foo(.*)$"),
},
Body: regexp.MustCompile("^foo(.*)"),
},
},
},
}
err = db.UpsertProject(context.Background(), exp)
if err != nil {
t.Fatalf("unexpected error storing project: %v", err)
}
var rawProject []byte
err = boltDB.View(func(tx *bbolt.Tx) error {
rawProject = tx.Bucket([]byte("projects")).Bucket(exp.ID[:]).Get([]byte("project"))
return nil
})
if err != nil {
t.Fatalf("unexpected error retrieving project from database: %v", err)
}
if rawProject == nil {
t.Fatalf("expected raw project to be retrieved, got: nil")
}
got := proj.Project{}
err = gob.NewDecoder(bytes.NewReader(rawProject)).Decode(&got)
if err != nil {
t.Fatalf("unexpected error decoding project: %v", err)
}
if diff := cmp.Diff(exp, got, regexpCompareOpt, cmpopts.IgnoreUnexported(proj.Project{})); diff != "" {
t.Fatalf("project not equal (-exp, +got):\n%v", diff)
}
}
func TestFindProjectByID(t *testing.T) {
t.Parallel()
t.Run("existing project", func(t *testing.T) {
t.Parallel()
path := t.TempDir() + "bolt.db"
boltDB, err := bbolt.Open(path, 0o600, nil)
if err != nil {
t.Fatalf("failed to open bolt database: %v", err)
}
db, err := bolt.DatabaseFromBoltDB(boltDB)
if err != nil {
t.Fatalf("failed to create database: %v", err)
}
defer db.Close()
exp := proj.Project{
ID: ulid.MustNew(ulid.Timestamp(time.Now()), ulidEntropy),
}
buf := bytes.Buffer{}
err = gob.NewEncoder(&buf).Encode(exp)
if err != nil {
t.Fatalf("unexpected error encoding project: %v", err)
}
err = boltDB.Update(func(tx *bbolt.Tx) error {
b, err := tx.Bucket([]byte("projects")).CreateBucket(exp.ID[:])
if err != nil {
return err
}
return b.Put([]byte("project"), buf.Bytes())
})
if err != nil {
t.Fatalf("unexpected error setting project: %v", err)
}
got, err := db.FindProjectByID(context.Background(), exp.ID)
if err != nil {
t.Fatalf("unexpected error finding project: %v", err)
}
if diff := cmp.Diff(exp, got, cmpopts.IgnoreUnexported(proj.Project{})); diff != "" {
t.Fatalf("project not equal (-exp, +got):\n%v", diff)
}
})
t.Run("project not found", func(t *testing.T) {
t.Parallel()
path := t.TempDir() + "bolt.db"
boltDB, err := bbolt.Open(path, 0o600, nil)
if err != nil {
t.Fatalf("failed to open bolt database: %v", err)
}
db, err := bolt.DatabaseFromBoltDB(boltDB)
if err != nil {
t.Fatalf("failed to create database: %v", err)
}
defer db.Close()
projectID := ulid.MustNew(ulid.Timestamp(time.Now()), ulidEntropy)
_, err = db.FindProjectByID(context.Background(), projectID)
if !errors.Is(err, proj.ErrProjectNotFound) {
t.Fatalf("expected `proj.ErrProjectNotFound`, got: %v", err)
}
})
}
func TestDeleteProject(t *testing.T) {
t.Parallel()
path := t.TempDir() + "bolt.db"
boltDB, err := bbolt.Open(path, 0o600, nil)
if err != nil {
t.Fatalf("failed to open bolt database: %v", err)
}
db, err := bolt.DatabaseFromBoltDB(boltDB)
if err != nil {
t.Fatalf("failed to create database: %v", err)
}
defer db.Close()
// Insert test fixture.
projectID := ulid.MustNew(ulid.Timestamp(time.Now()), ulidEntropy)
err = db.UpsertProject(context.Background(), proj.Project{
ID: projectID,
})
if err != nil {
t.Fatalf("unexpected error storing project: %v", err)
}
err = db.DeleteProject(context.Background(), projectID)
if err != nil {
t.Fatalf("unexpected error deleting project: %v", err)
}
var got *bbolt.Bucket
err = boltDB.View(func(tx *bbolt.Tx) error {
got = tx.Bucket([]byte("projects")).Bucket(projectID[:])
return nil
})
if got != nil {
t.Fatalf("expected bucket to be nil, got: %v", got)
}
}
func TestProjects(t *testing.T) {
t.Parallel()
path := t.TempDir() + "bolt.db"
boltDB, err := bbolt.Open(path, 0o600, nil)
if err != nil {
t.Fatalf("failed to open bolt database: %v", err)
}
db, err := bolt.DatabaseFromBoltDB(boltDB)
if err != nil {
t.Fatalf("failed to create database: %v", err)
}
defer db.Close()
exp := []proj.Project{
{
ID: ulid.MustNew(ulid.Timestamp(time.Now()), ulidEntropy),
Name: "one",
},
{
ID: ulid.MustNew(ulid.Timestamp(time.Now())+100, ulidEntropy),
Name: "two",
},
}
// Store fixtures.
for _, project := range exp {
err = db.UpsertProject(context.Background(), project)
if err != nil {
t.Fatalf("unexpected error creating project fixture: %v", err)
}
}
got, err := db.Projects(context.Background())
if err != nil {
t.Fatalf("unexpected error finding projects: %v", err)
}
if len(exp) != len(got) {
t.Fatalf("expected %v projects, got: %v", len(exp), len(got))
}
if diff := cmp.Diff(exp, got, cmpopts.IgnoreUnexported(proj.Project{})); diff != "" {
t.Fatalf("projects not equal (-exp, +got):\n%v", diff)
}
}

204
pkg/db/bolt/reqlog.go Normal file
View File

@ -0,0 +1,204 @@
package bolt
import (
"bytes"
"context"
"encoding/gob"
"errors"
"fmt"
"github.com/oklog/ulid"
bolt "go.etcd.io/bbolt"
"github.com/dstotijn/hetty/pkg/reqlog"
"github.com/dstotijn/hetty/pkg/scope"
)
var ErrRequestLogsBucketNotFound = errors.New("bolt: request logs bucket not found")
var reqLogsBucketName = []byte("request_logs")
func requestLogsBucket(tx *bolt.Tx, projectID ulid.ULID) (*bolt.Bucket, error) {
pb, err := projectBucket(tx, projectID[:])
if err != nil {
return nil, err
}
b := pb.Bucket(reqLogsBucketName)
if b == nil {
return nil, ErrRequestLogsBucketNotFound
}
return b, nil
}
func (db *Database) FindRequestLogs(ctx context.Context, filter reqlog.FindRequestsFilter, scope *scope.Scope) (reqLogs []reqlog.RequestLog, err error) {
if filter.ProjectID.Compare(ulid.ULID{}) == 0 {
return nil, reqlog.ErrProjectIDMustBeSet
}
tx, err := db.bolt.Begin(false)
if err != nil {
return nil, fmt.Errorf("bolt: failed to begin transaction: %w", err)
}
defer tx.Rollback()
b, err := requestLogsBucket(tx, filter.ProjectID)
if err != nil {
return nil, fmt.Errorf("bolt: failed to get request logs bucket: %w", err)
}
err = b.ForEach(func(reqLogID, rawReqLog []byte) error {
var reqLog reqlog.RequestLog
err = gob.NewDecoder(bytes.NewReader(rawReqLog)).Decode(&reqLog)
if err != nil {
return fmt.Errorf("failed to decode request log: %w", err)
}
if filter.OnlyInScope && !reqLog.MatchScope(scope) {
return nil
}
// Filter by search expression. TODO: Once pagination is introduced,
// this filter logic should be done as items are retrieved.
if filter.SearchExpr != nil {
match, err := reqLog.Matches(filter.SearchExpr)
if err != nil {
return fmt.Errorf("failed to match search expression for request log (id: %v): %w", reqLogID, err)
}
if !match {
return nil
}
}
reqLogs = append(reqLogs, reqLog)
return nil
})
if err != nil {
return nil, fmt.Errorf("bolt: failed to iterate over request logs: %w", err)
}
// Reverse items, so newest requests appear first.
for i, j := 0, len(reqLogs)-1; i < j; i, j = i+1, j-1 {
reqLogs[i], reqLogs[j] = reqLogs[j], reqLogs[i]
}
return reqLogs, nil
}
func (db *Database) FindRequestLogByID(ctx context.Context, projectID, reqLogID ulid.ULID) (reqLog reqlog.RequestLog, err error) {
err = db.bolt.View(func(tx *bolt.Tx) error {
b, err := requestLogsBucket(tx, projectID)
if err != nil {
return fmt.Errorf("bolt: failed to get request logs bucket: %w", err)
}
rawReqLog := b.Get(reqLogID[:])
if rawReqLog == nil {
return reqlog.ErrRequestNotFound
}
err = gob.NewDecoder(bytes.NewReader(rawReqLog)).Decode(&reqLog)
if err != nil {
return fmt.Errorf("failed to decode request log: %w", err)
}
return nil
})
if err != nil {
return reqlog.RequestLog{}, fmt.Errorf("bolt: failed to find request log by ID: %w", err)
}
return reqLog, nil
}
func (db *Database) StoreRequestLog(ctx context.Context, reqLog reqlog.RequestLog) error {
buf := bytes.Buffer{}
err := gob.NewEncoder(&buf).Encode(reqLog)
if err != nil {
return fmt.Errorf("bolt: failed to encode request log: %w", err)
}
err = db.bolt.Update(func(txn *bolt.Tx) error {
b, err := requestLogsBucket(txn, reqLog.ProjectID)
if err != nil {
return fmt.Errorf("failed to get request logs bucket: %w", err)
}
err = b.Put(reqLog.ID[:], buf.Bytes())
if err != nil {
return fmt.Errorf("failed to put request log: %w", err)
}
return nil
})
if err != nil {
return fmt.Errorf("bolt: failed to commit transaction: %w", err)
}
return nil
}
func (db *Database) StoreResponseLog(ctx context.Context, projectID, reqLogID ulid.ULID, resLog reqlog.ResponseLog) error {
buf := bytes.Buffer{}
err := gob.NewEncoder(&buf).Encode(resLog)
if err != nil {
return fmt.Errorf("bolt: failed to encode response log: %w", err)
}
err = db.bolt.Update(func(txn *bolt.Tx) error {
b, err := requestLogsBucket(txn, projectID)
if err != nil {
return fmt.Errorf("failed to get request logs bucket: %w", err)
}
rawReqLog := b.Get(reqLogID[:])
if rawReqLog == nil {
return reqlog.ErrRequestNotFound
}
var reqLog reqlog.RequestLog
err = gob.NewDecoder(bytes.NewReader(rawReqLog)).Decode(&reqLog)
if err != nil {
return fmt.Errorf("failed to decode request log: %w", err)
}
reqLog.Response = &resLog
buf := bytes.Buffer{}
err = gob.NewEncoder(&buf).Encode(reqLog)
if err != nil {
return fmt.Errorf("failed to encode request log: %w", err)
}
err = b.Put(reqLog.ID[:], buf.Bytes())
if err != nil {
return fmt.Errorf("failed to put request log: %w", err)
}
return nil
})
if err != nil {
return fmt.Errorf("bolt: failed to commit transaction: %w", err)
}
return nil
}
func (db *Database) ClearRequestLogs(ctx context.Context, projectID ulid.ULID) error {
err := db.bolt.Update(func(txn *bolt.Tx) error {
pb, err := projectBucket(txn, projectID[:])
if err != nil {
return fmt.Errorf("failed to get project bucket: %w", err)
}
return pb.DeleteBucket(reqLogsBucketName)
})
if err != nil {
return fmt.Errorf("bolt: failed to commit transaction: %w", err)
}
return nil
}

141
pkg/db/bolt/reqlog_test.go Normal file
View File

@ -0,0 +1,141 @@
package bolt_test
import (
"context"
"errors"
"net/http"
"net/url"
"testing"
"time"
"github.com/google/go-cmp/cmp"
"github.com/oklog/ulid"
"go.etcd.io/bbolt"
"github.com/dstotijn/hetty/pkg/db/bolt"
"github.com/dstotijn/hetty/pkg/proj"
"github.com/dstotijn/hetty/pkg/reqlog"
)
func TestFindRequestLogs(t *testing.T) {
t.Parallel()
t.Run("without project ID in filter", func(t *testing.T) {
t.Parallel()
path := t.TempDir() + "bolt.db"
boltDB, err := bbolt.Open(path, 0o600, nil)
if err != nil {
t.Fatalf("failed to open bolt database: %v", err)
}
db, err := bolt.DatabaseFromBoltDB(boltDB)
if err != nil {
t.Fatalf("failed to create database: %v", err)
}
defer db.Close()
filter := reqlog.FindRequestsFilter{}
_, err = db.FindRequestLogs(context.Background(), filter, nil)
if !errors.Is(err, reqlog.ErrProjectIDMustBeSet) {
t.Fatalf("expected `reqlog.ErrProjectIDMustBeSet`, got: %v", err)
}
})
t.Run("returns request logs and related response logs", func(t *testing.T) {
t.Parallel()
path := t.TempDir() + "bolt.db"
boltDB, err := bbolt.Open(path, 0o600, nil)
if err != nil {
t.Fatalf("failed to open bolt database: %v", err)
}
db, err := bolt.DatabaseFromBoltDB(boltDB)
if err != nil {
t.Fatalf("failed to create database: %v", err)
}
defer db.Close()
projectID := ulid.MustNew(ulid.Timestamp(time.Now()), ulidEntropy)
err = db.UpsertProject(context.Background(), proj.Project{
ID: projectID,
})
if err != nil {
t.Fatalf("unexpected error upserting project: %v", err)
}
fixtures := []reqlog.RequestLog{
{
ID: ulid.MustNew(ulid.Timestamp(time.Now()), ulidEntropy),
ProjectID: projectID,
URL: mustParseURL(t, "https://example.com/foobar"),
Method: http.MethodPost,
Proto: "HTTP/1.1",
Header: http.Header{
"X-Foo": []string{"baz"},
},
Body: []byte("foo"),
Response: &reqlog.ResponseLog{
Proto: "HTTP/1.1",
Status: "200 OK",
StatusCode: 200,
Header: http.Header{
"X-Yolo": []string{"swag"},
},
Body: []byte("bar"),
},
},
{
ID: ulid.MustNew(ulid.Timestamp(time.Now())+100, ulidEntropy),
ProjectID: projectID,
URL: mustParseURL(t, "https://example.com/foo?bar=baz"),
Method: http.MethodGet,
Proto: "HTTP/1.1",
Header: http.Header{
"X-Foo": []string{"baz"},
},
},
}
// Store fixtures.
for _, reqLog := range fixtures {
err = db.StoreRequestLog(context.Background(), reqLog)
if err != nil {
t.Fatalf("unexpected error creating request log fixture: %v", err)
}
}
filter := reqlog.FindRequestsFilter{
ProjectID: projectID,
}
got, err := db.FindRequestLogs(context.Background(), filter, nil)
if err != nil {
t.Fatalf("unexpected error finding request logs: %v", err)
}
// We expect the found request logs are *reversed*, e.g. newest first.
exp := make([]reqlog.RequestLog, len(fixtures))
for i, j := 0, len(fixtures)-1; i < j; i, j = i+1, j-1 {
exp[i], exp[j] = fixtures[j], fixtures[i]
}
if diff := cmp.Diff(exp, got); diff != "" {
t.Fatalf("request logs not equal (-exp, +got):\n%v", diff)
}
})
}
func mustParseURL(t *testing.T, s string) *url.URL {
t.Helper()
u, err := url.Parse(s)
if err != nil {
panic(err)
}
return u
}

172
pkg/db/bolt/sender.go Normal file
View File

@ -0,0 +1,172 @@
package bolt
import (
"bytes"
"context"
"encoding/gob"
"errors"
"fmt"
"github.com/oklog/ulid"
bolt "go.etcd.io/bbolt"
"github.com/dstotijn/hetty/pkg/scope"
"github.com/dstotijn/hetty/pkg/sender"
)
var ErrSenderRequestsBucketNotFound = errors.New("bolt: sender requests bucket not found")
var senderReqsBucketName = []byte("sender_requests")
func senderReqsBucket(tx *bolt.Tx, projectID ulid.ULID) (*bolt.Bucket, error) {
pb, err := projectBucket(tx, projectID[:])
if err != nil {
return nil, err
}
b := pb.Bucket(senderReqsBucketName)
if b == nil {
return nil, ErrSenderRequestsBucketNotFound
}
return b, nil
}
func (db *Database) StoreSenderRequest(ctx context.Context, req sender.Request) error {
buf := bytes.Buffer{}
err := gob.NewEncoder(&buf).Encode(req)
if err != nil {
return fmt.Errorf("bolt: failed to encode sender request: %w", err)
}
err = db.bolt.Update(func(tx *bolt.Tx) error {
senderReqsBucket, err := senderReqsBucket(tx, req.ProjectID)
if err != nil {
return fmt.Errorf("failed to get sender requests bucket: %w", err)
}
err = senderReqsBucket.Put(req.ID[:], buf.Bytes())
if err != nil {
return fmt.Errorf("failed to put sender request: %w", err)
}
return nil
})
if err != nil {
return fmt.Errorf("bolt: failed to commit transaction: %w", err)
}
return nil
}
func (db *Database) FindSenderRequestByID(ctx context.Context, projectID, senderReqID ulid.ULID) (req sender.Request, err error) {
if projectID.Compare(ulid.ULID{}) == 0 {
return sender.Request{}, sender.ErrProjectIDMustBeSet
}
err = db.bolt.View(func(tx *bolt.Tx) error {
senderReqsBucket, err := senderReqsBucket(tx, projectID)
if err != nil {
return fmt.Errorf("failed to get sender requests bucket: %w", err)
}
rawSenderReq := senderReqsBucket.Get(senderReqID[:])
if rawSenderReq == nil {
return sender.ErrRequestNotFound
}
err = gob.NewDecoder(bytes.NewReader(rawSenderReq)).Decode(&req)
if err != nil {
return fmt.Errorf("failed to decode sender request: %w", err)
}
return nil
})
if err != nil {
return sender.Request{}, fmt.Errorf("bolt: failed to commit transaction: %w", err)
}
return req, nil
}
func (db *Database) FindSenderRequests(ctx context.Context, filter sender.FindRequestsFilter, scope *scope.Scope) (reqs []sender.Request, err error) {
if filter.ProjectID.Compare(ulid.ULID{}) == 0 {
return nil, sender.ErrProjectIDMustBeSet
}
tx, err := db.bolt.Begin(false)
if err != nil {
return nil, fmt.Errorf("bolt: failed to begin transaction: %w", err)
}
defer tx.Rollback()
b, err := senderReqsBucket(tx, filter.ProjectID)
if err != nil {
return nil, fmt.Errorf("failed to get sender requests bucket: %w", err)
}
err = b.ForEach(func(senderReqID, rawSenderReq []byte) error {
var req sender.Request
err = gob.NewDecoder(bytes.NewReader(rawSenderReq)).Decode(&req)
if err != nil {
return fmt.Errorf("failed to decode sender request: %w", err)
}
if filter.OnlyInScope {
if !req.MatchScope(scope) {
return nil
}
}
// Filter by search expression. TODO: Once pagination is introduced,
// this filter logic should be done as items are retrieved.
if filter.SearchExpr != nil {
match, err := req.Matches(filter.SearchExpr)
if err != nil {
return fmt.Errorf(
"bolt: failed to match search expression for sender request (id: %v): %w",
senderReqID, err,
)
}
if !match {
return nil
}
}
reqs = append(reqs, req)
return nil
})
if err != nil {
return nil, fmt.Errorf("bolt: failed to commit transaction: %w", err)
}
// Reverse items, so newest requests appear first.
for i, j := 0, len(reqs)-1; i < j; i, j = i+1, j-1 {
reqs[i], reqs[j] = reqs[j], reqs[i]
}
return reqs, nil
}
func (db *Database) DeleteSenderRequests(ctx context.Context, projectID ulid.ULID) error {
err := db.bolt.Update(func(tx *bolt.Tx) error {
senderReqsBucket, err := senderReqsBucket(tx, projectID)
if err != nil {
return fmt.Errorf("failed to get sender requests bucket: %w", err)
}
err = senderReqsBucket.DeleteBucket(senderReqsBucketName)
if err != nil {
return fmt.Errorf("failed to delete sender requests bucket: %w", err)
}
return nil
})
if err != nil {
return fmt.Errorf("bolt: failed to commit transaction: %w", err)
}
return nil
}

226
pkg/db/bolt/sender_test.go Normal file
View File

@ -0,0 +1,226 @@
package bolt_test
import (
"context"
"errors"
"net/http"
"net/url"
"testing"
"time"
"github.com/google/go-cmp/cmp"
"github.com/oklog/ulid"
"go.etcd.io/bbolt"
"github.com/dstotijn/hetty/pkg/db/bolt"
"github.com/dstotijn/hetty/pkg/proj"
"github.com/dstotijn/hetty/pkg/reqlog"
"github.com/dstotijn/hetty/pkg/sender"
)
var exampleURL = func() *url.URL {
u, err := url.Parse("https://example.com/foobar")
if err != nil {
panic(err)
}
return u
}()
func TestFindRequestByID(t *testing.T) {
t.Parallel()
path := t.TempDir() + "bolt.db"
boltDB, err := bbolt.Open(path, 0o600, nil)
if err != nil {
t.Fatalf("failed to open bolt database: %v", err)
}
defer boltDB.Close()
db, err := bolt.DatabaseFromBoltDB(boltDB)
if err != nil {
t.Fatalf("failed to create database: %v", err)
}
defer db.Close()
projectID := ulid.MustNew(ulid.Timestamp(time.Now()), ulidEntropy)
reqID := ulid.MustNew(ulid.Timestamp(time.Now()), ulidEntropy)
err = db.UpsertProject(context.Background(), proj.Project{
ID: projectID,
})
if err != nil {
t.Fatalf("unexpected error upserting project: %v", err)
}
// See: https://go.dev/blog/subtests#cleaning-up-after-a-group-of-parallel-tests
t.Run("group", func(t *testing.T) {
t.Run("sender request not found", func(t *testing.T) {
t.Parallel()
_, err := db.FindSenderRequestByID(context.Background(), projectID, reqID)
if !errors.Is(err, sender.ErrRequestNotFound) {
t.Fatalf("expected `sender.ErrRequestNotFound`, got: %v", err)
}
})
t.Run("sender request found", func(t *testing.T) {
t.Parallel()
exp := sender.Request{
ID: ulid.MustNew(ulid.Timestamp(time.Now()), ulidEntropy),
ProjectID: projectID,
SourceRequestLogID: ulid.MustNew(ulid.Timestamp(time.Now()), ulidEntropy),
URL: exampleURL,
Method: http.MethodGet,
Proto: sender.HTTPProto20,
Header: http.Header{
"X-Foo": []string{"bar"},
},
Body: []byte("foo"),
Response: &reqlog.ResponseLog{
Proto: "HTTP/2.0",
Status: "200 OK",
StatusCode: 200,
Header: http.Header{
"X-Yolo": []string{"swag"},
},
Body: []byte("bar"),
},
}
err := db.StoreSenderRequest(context.Background(), exp)
if err != nil {
t.Fatalf("unexpected error (expected: nil, got: %v)", err)
}
got, err := db.FindSenderRequestByID(context.Background(), exp.ProjectID, exp.ID)
if err != nil {
t.Fatalf("unexpected error (expected: nil, got: %v)", err)
}
if diff := cmp.Diff(exp, got); diff != "" {
t.Fatalf("sender request not equal (-exp, +got):\n%v", diff)
}
})
})
}
func TestFindSenderRequests(t *testing.T) {
t.Parallel()
t.Run("without project ID in filter", func(t *testing.T) {
t.Parallel()
path := t.TempDir() + "bolt.db"
boltDB, err := bbolt.Open(path, 0o600, nil)
if err != nil {
t.Fatalf("failed to open bolt database: %v", err)
}
defer boltDB.Close()
db, err := bolt.DatabaseFromBoltDB(boltDB)
if err != nil {
t.Fatalf("failed to create database: %v", err)
}
defer db.Close()
filter := sender.FindRequestsFilter{}
_, err = db.FindSenderRequests(context.Background(), filter, nil)
if !errors.Is(err, sender.ErrProjectIDMustBeSet) {
t.Fatalf("expected `sender.ErrProjectIDMustBeSet`, got: %v", err)
}
})
t.Run("returns sender requests and related response logs", func(t *testing.T) {
t.Parallel()
path := t.TempDir() + "bolt.db"
boltDB, err := bbolt.Open(path, 0o600, nil)
if err != nil {
t.Fatalf("failed to open bolt database: %v", err)
}
defer boltDB.Close()
db, err := bolt.DatabaseFromBoltDB(boltDB)
if err != nil {
t.Fatalf("failed to create database: %v", err)
}
defer db.Close()
projectID := ulid.MustNew(ulid.Timestamp(time.Now()), ulidEntropy)
err = db.UpsertProject(context.Background(), proj.Project{
ID: projectID,
Name: "foobar",
Settings: proj.Settings{},
})
if err != nil {
t.Fatalf("unexpected error creating project (expected: nil, got: %v)", err)
}
fixtures := []sender.Request{
{
ID: ulid.MustNew(ulid.Timestamp(time.Now()), ulidEntropy),
ProjectID: projectID,
SourceRequestLogID: ulid.MustNew(ulid.Timestamp(time.Now()), ulidEntropy),
URL: exampleURL,
Method: http.MethodPost,
Proto: "HTTP/1.1",
Header: http.Header{
"X-Foo": []string{"baz"},
},
Body: []byte("foo"),
Response: &reqlog.ResponseLog{
Proto: "HTTP/1.1",
Status: "200 OK",
StatusCode: 200,
Header: http.Header{
"X-Yolo": []string{"swag"},
},
Body: []byte("bar"),
},
},
{
ID: ulid.MustNew(ulid.Timestamp(time.Now())+100, ulidEntropy),
ProjectID: projectID,
SourceRequestLogID: ulid.MustNew(ulid.Timestamp(time.Now()), ulidEntropy),
URL: exampleURL,
Method: http.MethodGet,
Proto: "HTTP/1.1",
Header: http.Header{
"X-Foo": []string{"baz"},
},
},
}
// Store fixtures.
for _, senderReq := range fixtures {
err = db.StoreSenderRequest(context.Background(), senderReq)
if err != nil {
t.Fatalf("unexpected error creating request log fixture: %v", err)
}
}
filter := sender.FindRequestsFilter{
ProjectID: projectID,
}
got, err := db.FindSenderRequests(context.Background(), filter, nil)
if err != nil {
t.Fatalf("unexpected error finding sender requests: %v", err)
}
// We expect the found sender requests are *reversed*, e.g. newest first.
exp := make([]sender.Request, len(fixtures))
for i, j := 0, len(fixtures)-1; i < j; i, j = i+1, j-1 {
exp[i], exp[j] = fixtures[j], fixtures[i]
}
if diff := cmp.Diff(exp, got); diff != "" {
t.Fatalf("sender requests not equal (-exp, +got):\n%v", diff)
}
})
}