Replace Cayley with SQLite3

This commit is contained in:
David Stotijn
2020-10-04 11:50:03 +02:00
parent d48f1f058d
commit ba7d88dfc5
22 changed files with 649 additions and 861 deletions

View File

@ -68,13 +68,13 @@ type ComplexityRoot struct {
}
Query struct {
HTTPRequestLog func(childComplexity int, id string) int
HTTPRequestLog func(childComplexity int, id int64) int
HTTPRequestLogs func(childComplexity int) int
}
}
type QueryResolver interface {
HTTPRequestLog(ctx context.Context, id string) (*HTTPRequestLog, error)
HTTPRequestLog(ctx context.Context, id int64) (*HTTPRequestLog, error)
HTTPRequestLogs(ctx context.Context) ([]HTTPRequestLog, error)
}
@ -215,7 +215,7 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in
return 0, false
}
return e.complexity.Query.HTTPRequestLog(childComplexity, args["id"].(string)), true
return e.complexity.Query.HTTPRequestLog(childComplexity, args["id"].(int64)), true
case "Query.httpRequestLogs":
if e.complexity.Query.HTTPRequestLogs == nil {
@ -342,9 +342,9 @@ func (ec *executionContext) field_Query___type_args(ctx context.Context, rawArgs
func (ec *executionContext) field_Query_httpRequestLog_args(ctx context.Context, rawArgs map[string]interface{}) (map[string]interface{}, error) {
var err error
args := map[string]interface{}{}
var arg0 string
var arg0 int64
if tmp, ok := rawArgs["id"]; ok {
arg0, err = ec.unmarshalNID2string(ctx, tmp)
arg0, err = ec.unmarshalNID2int64(ctx, tmp)
if err != nil {
return nil, err
}
@ -486,9 +486,9 @@ func (ec *executionContext) _HttpRequestLog_id(ctx context.Context, field graphq
}
return graphql.Null
}
res := resTmp.(string)
res := resTmp.(int64)
fc.Result = res
return ec.marshalNID2string(ctx, field.Selections, res)
return ec.marshalNID2int64(ctx, field.Selections, res)
}
func (ec *executionContext) _HttpRequestLog_url(ctx context.Context, field graphql.CollectedField, obj *HTTPRequestLog) (ret graphql.Marshaler) {
@ -752,9 +752,9 @@ func (ec *executionContext) _HttpResponseLog_requestId(ctx context.Context, fiel
}
return graphql.Null
}
res := resTmp.(string)
res := resTmp.(int64)
fc.Result = res
return ec.marshalNID2string(ctx, field.Selections, res)
return ec.marshalNID2int64(ctx, field.Selections, res)
}
func (ec *executionContext) _HttpResponseLog_proto(ctx context.Context, field graphql.CollectedField, obj *HTTPResponseLog) (ret graphql.Marshaler) {
@ -948,7 +948,7 @@ func (ec *executionContext) _Query_httpRequestLog(ctx context.Context, field gra
fc.Args = args
resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) {
ctx = rctx // use context from middleware stack in children
return ec.resolvers.Query().HTTPRequestLog(rctx, args["id"].(string))
return ec.resolvers.Query().HTTPRequestLog(rctx, args["id"].(int64))
})
if err != nil {
ec.Error(ctx, err)
@ -2670,12 +2670,12 @@ func (ec *executionContext) marshalNHttpRequestLog2ᚕgithubᚗcomᚋdstotijnᚋ
return ret
}
func (ec *executionContext) unmarshalNID2string(ctx context.Context, v interface{}) (string, error) {
return graphql.UnmarshalID(v)
func (ec *executionContext) unmarshalNID2int64(ctx context.Context, v interface{}) (int64, error) {
return graphql.UnmarshalInt64(v)
}
func (ec *executionContext) marshalNID2string(ctx context.Context, sel ast.SelectionSet, v string) graphql.Marshaler {
res := graphql.MarshalID(v)
func (ec *executionContext) marshalNID2int64(ctx context.Context, sel ast.SelectionSet, v int64) graphql.Marshaler {
res := graphql.MarshalInt64(v)
if res == graphql.Null {
if !graphql.HasFieldError(ctx, graphql.GetFieldContext(ctx)) {
ec.Errorf(ctx, "must not be null")

View File

@ -15,7 +15,7 @@ type HTTPHeader struct {
}
type HTTPRequestLog struct {
ID string `json:"id"`
ID int64 `json:"id"`
URL string `json:"url"`
Method HTTPMethod `json:"method"`
Proto string `json:"proto"`
@ -26,7 +26,7 @@ type HTTPRequestLog struct {
}
type HTTPResponseLog struct {
RequestID string `json:"requestId"`
RequestID int64 `json:"requestId"`
Proto string `json:"proto"`
Status string `json:"status"`
StatusCode int `json:"statusCode"`

View File

@ -6,8 +6,6 @@ import (
"context"
"fmt"
"github.com/google/uuid"
"github.com/dstotijn/hetty/pkg/reqlog"
)
@ -38,12 +36,8 @@ func (r *queryResolver) HTTPRequestLogs(ctx context.Context) ([]HTTPRequestLog,
return logs, nil
}
func (r *queryResolver) HTTPRequestLog(ctx context.Context, id string) (*HTTPRequestLog, error) {
reqLogID, err := uuid.Parse(id)
if err != nil {
return nil, fmt.Errorf("invalid id: %v", err)
}
log, err := r.RequestLogService.FindRequestLogByID(ctx, reqLogID)
func (r *queryResolver) HTTPRequestLog(ctx context.Context, id int64) (*HTTPRequestLog, error) {
log, err := r.RequestLogService.FindRequestLogByID(ctx, id)
if err == reqlog.ErrRequestNotFound {
return nil, nil
}
@ -65,7 +59,7 @@ func parseRequestLog(req reqlog.Request) (HTTPRequestLog, error) {
}
log := HTTPRequestLog{
ID: req.ID.String(),
ID: req.ID,
URL: req.Request.URL.String(),
Proto: req.Request.Proto,
Method: method,
@ -91,7 +85,7 @@ func parseRequestLog(req reqlog.Request) (HTTPRequestLog, error) {
if req.Response != nil {
log.Response = &HTTPResponseLog{
RequestID: req.ID.String(),
RequestID: req.ID,
Proto: req.Response.Response.Proto,
Status: req.Response.Response.Status,
StatusCode: req.Response.Response.StatusCode,

View File

@ -1,64 +0,0 @@
package cayley
import (
"os"
"path/filepath"
"github.com/cayleygraph/cayley/clog"
"github.com/cayleygraph/cayley/graph"
hkv "github.com/hidal-go/hidalgo/kv"
"github.com/hidal-go/hidalgo/kv/bolt"
)
const Type = bolt.Name
func boltFilePath(path, filename string) string {
return filepath.Join(path, filename)
}
func boltCreate(path string, opt graph.Options) (hkv.KV, error) {
filename, err := opt.StringKey("filename", "indexes.bolt")
if err != nil {
return nil, err
}
err = os.MkdirAll(path, 0700)
if err != nil {
return nil, err
}
db, err := bolt.Open(boltFilePath(path, filename), nil)
if err != nil {
clog.Errorf("Error: couldn't create Bolt database: %v", err)
return nil, err
}
return db, nil
}
func boltOpen(path string, opt graph.Options) (hkv.KV, error) {
filename, err := opt.StringKey("filename", "indexes.bolt")
if err != nil {
return nil, err
}
db, err := bolt.Open(boltFilePath(path, filename), nil)
if err != nil {
clog.Errorf("Error, couldn't open! %v", err)
return nil, err
}
bdb := db.DB()
bdb.NoSync, err = opt.BoolKey("nosync", false)
if err != nil {
db.Close()
return nil, err
}
bdb.NoGrowSync = bdb.NoSync
if bdb.NoSync {
clog.Infof("Running in nosync mode")
}
return db, nil
}

View File

@ -1,325 +0,0 @@
package cayley
import (
"context"
"fmt"
"log"
"net/http"
"net/url"
"path"
"strings"
"sync"
"time"
"github.com/cayleygraph/cayley"
"github.com/cayleygraph/cayley/graph"
"github.com/cayleygraph/cayley/graph/kv"
cpath "github.com/cayleygraph/cayley/graph/path"
"github.com/cayleygraph/cayley/schema"
"github.com/cayleygraph/quad"
"github.com/cayleygraph/quad/voc"
"github.com/cayleygraph/quad/voc/rdf"
"github.com/google/uuid"
"github.com/dstotijn/hetty/pkg/reqlog"
"github.com/dstotijn/hetty/pkg/scope"
)
type HTTPRequest struct {
rdfType struct{} `quad:"@type > hy:HTTPRequest"`
ID quad.IRI `quad:"@id"`
Proto string `quad:"hy:proto"`
URL string `quad:"hy:url"`
Method string `quad:"hy:method"`
Body string `quad:"hy:body,optional"`
Headers []HTTPHeader `quad:"hy:header"`
Timestamp time.Time `quad:"hy:timestamp"`
Response *HTTPResponse `quad:"hy:request < *,optional"`
}
type HTTPResponse struct {
rdfType struct{} `quad:"@type > hy:HTTPResponse"`
RequestID quad.IRI `quad:"hy:request"`
Proto string `quad:"hy:proto"`
Status string `quad:"hy:status"`
StatusCode int `quad:"hy:status_code"`
Headers []HTTPHeader `quad:"hy:header"`
Body string `quad:"hy:body,optional"`
Timestamp time.Time `quad:"hy:timestamp"`
}
type HTTPHeader struct {
rdfType struct{} `quad:"@type > hy:HTTPHeader"`
Key string `quad:"hy:key"`
Value string `quad:"hy:value,optional"`
}
type Database struct {
store *cayley.Handle
schema *schema.Config
mu sync.Mutex
}
func init() {
voc.RegisterPrefix("hy:", "https://hetty.xyz/")
schema.RegisterType(quad.IRI("hy:HTTPRequest"), HTTPRequest{})
schema.RegisterType(quad.IRI("hy:HTTPResponse"), HTTPResponse{})
schema.RegisterType(quad.IRI("hy:HTTPHeader"), HTTPHeader{})
kv.Register(Type, kv.Registration{
NewFunc: boltOpen,
InitFunc: boltCreate,
IsPersistent: true,
})
}
func NewDatabase(filename string) (*Database, error) {
dir, file := path.Split(filename)
if dir == "" {
dir = "."
}
opts := graph.Options{
"filename": file,
}
schemaCfg := schema.NewConfig()
schemaCfg.GenerateID = func(_ interface{}) quad.Value {
return quad.BNode(uuid.New().String())
}
// Initialize the database.
err := graph.InitQuadStore("bolt", dir, opts)
if err != nil && err != graph.ErrDatabaseExists {
return nil, fmt.Errorf("cayley: could not initialize database: %v", err)
}
// Open the database.
store, err := cayley.NewGraph("bolt", dir, opts)
if err != nil {
return nil, fmt.Errorf("cayley: could not open database: %v", err)
}
return &Database{
store: store,
schema: schemaCfg,
}, nil
}
func (db *Database) Close() error {
return db.store.Close()
}
func (db *Database) FindRequestLogs(ctx context.Context, opts reqlog.FindRequestsOptions, scope *scope.Scope) ([]reqlog.Request, error) {
db.mu.Lock()
defer db.mu.Unlock()
var reqLogs []reqlog.Request
var reqs []HTTPRequest
reqPath := cayley.StartPath(db.store, quad.IRI("hy:HTTPRequest")).In(quad.IRI(rdf.Type))
if opts.OmitOutOfScope {
var filterPath *cpath.Path
for _, rule := range scope.Rules() {
if rule.URL != nil {
if filterPath == nil {
filterPath = reqPath.Out(quad.IRI("hy:url")).Regex(rule.URL).In(quad.IRI("hy:url"))
} else {
filterPath = filterPath.Or(reqPath.Out(quad.IRI("hy:url")).Regex(rule.URL).In(quad.IRI("hy:url")))
}
}
}
if filterPath != nil {
reqPath = filterPath
}
}
err := reqPath.Iterate(ctx).EachValue(db.store, func(v quad.Value) {
var req HTTPRequest
if err := db.schema.LoadToDepth(ctx, db.store, &req, 0, v); err != nil {
log.Printf("[ERROR] Could not load sub-graph for http requests: %v", err)
return
}
reqs = append(reqs, req)
})
if err != nil {
return nil, fmt.Errorf("cayley: could not iterate over http requests: %v", err)
}
for _, req := range reqs {
reqLog, err := parseRequestQuads(req, nil)
if err != nil {
return nil, fmt.Errorf("cayley: could not parse request quads (id: %v): %v", req.ID, err)
}
reqLogs = append(reqLogs, reqLog)
}
// By default, all retrieved requests are ordered chronologically, oldest first.
// Reverse the order, so newest logs are first.
for i := len(reqLogs)/2 - 1; i >= 0; i-- {
opp := len(reqLogs) - 1 - i
reqLogs[i], reqLogs[opp] = reqLogs[opp], reqLogs[i]
}
return reqLogs, nil
}
func (db *Database) FindRequestLogByID(ctx context.Context, id uuid.UUID) (reqlog.Request, error) {
db.mu.Lock()
defer db.mu.Unlock()
var req HTTPRequest
err := db.schema.LoadTo(ctx, db.store, &req, iriFromUUID(id))
if schema.IsNotFound(err) {
return reqlog.Request{}, reqlog.ErrRequestNotFound
}
if err != nil {
return reqlog.Request{}, fmt.Errorf("cayley: could not load value: %v", err)
}
reqLog, err := parseRequestQuads(req, nil)
if err != nil {
return reqlog.Request{}, fmt.Errorf("cayley: could not parse request log (id: %v): %v", req.ID, err)
}
return reqLog, nil
}
func (db *Database) AddRequestLog(ctx context.Context, reqLog reqlog.Request) error {
db.mu.Lock()
defer db.mu.Unlock()
httpReq := HTTPRequest{
ID: iriFromUUID(reqLog.ID),
Proto: reqLog.Request.Proto,
Method: reqLog.Request.Method,
URL: reqLog.Request.URL.String(),
Headers: httpHeadersSliceFromMap(reqLog.Request.Header),
Body: string(reqLog.Body),
Timestamp: reqLog.Timestamp,
}
tx := cayley.NewTransaction()
qw := graph.NewTxWriter(tx, graph.Add)
_, err := db.schema.WriteAsQuads(qw, httpReq)
if err != nil {
return fmt.Errorf("cayley: could not write quads: %v", err)
}
if err := db.store.ApplyTransaction(tx); err != nil {
return fmt.Errorf("cayley: could not apply transaction: %v", err)
}
return nil
}
func (db *Database) AddResponseLog(ctx context.Context, resLog reqlog.Response) error {
db.mu.Lock()
defer db.mu.Unlock()
httpRes := HTTPResponse{
RequestID: iriFromUUID(resLog.RequestID),
Proto: resLog.Response.Proto,
Status: resLog.Response.Status,
StatusCode: resLog.Response.StatusCode,
Headers: httpHeadersSliceFromMap(resLog.Response.Header),
Body: string(resLog.Body),
Timestamp: resLog.Timestamp,
}
tx := cayley.NewTransaction()
qw := graph.NewTxWriter(tx, graph.Add)
_, err := db.schema.WriteAsQuads(qw, httpRes)
if err != nil {
return fmt.Errorf("cayley: could not write response quads: %v", err)
}
if err := db.store.ApplyTransaction(tx); err != nil {
return fmt.Errorf("cayley: could not apply transaction: %v", err)
}
return nil
}
func iriFromUUID(id uuid.UUID) quad.IRI {
return quad.IRI("hy:" + id.String()).Full().Short()
}
func uuidFromIRI(iri quad.IRI) (uuid.UUID, error) {
iriString := iri.Short().String()
stripped := strings.TrimRight(strings.TrimLeft(iriString, "<hy:"), ">")
id, err := uuid.Parse(stripped)
if err != nil {
return uuid.Nil, err
}
return id, nil
}
func httpHeadersSliceFromMap(hm http.Header) []HTTPHeader {
if hm == nil {
return nil
}
var hs []HTTPHeader
for key, values := range hm {
for _, value := range values {
hs = append(hs, HTTPHeader{Key: key, Value: value})
}
}
return hs
}
func httpHeadersMapFromSlice(hs []HTTPHeader) http.Header {
if hs == nil {
return nil
}
hm := make(http.Header)
for _, header := range hs {
hm.Add(header.Key, header.Value)
}
return hm
}
func parseRequestQuads(req HTTPRequest, _ *HTTPResponse) (reqlog.Request, error) {
reqID, err := uuidFromIRI(req.ID)
if err != nil {
return reqlog.Request{}, fmt.Errorf("cannot parse request id: %v", err)
}
u, err := url.Parse(req.URL)
if err != nil {
return reqlog.Request{}, fmt.Errorf("cannot parse request url: %v", err)
}
reqLog := reqlog.Request{
ID: reqID,
Request: http.Request{
Method: req.Method,
URL: u,
Proto: req.Proto,
Header: httpHeadersMapFromSlice(req.Headers),
},
Timestamp: req.Timestamp,
}
if req.Body != "" {
reqLog.Body = []byte(reqLog.Body)
}
if req.Response != nil {
reqLog.Response = &reqlog.Response{
RequestID: reqID,
Response: http.Response{
Proto: req.Response.Proto,
Status: req.Response.Status,
StatusCode: req.Response.StatusCode,
Header: httpHeadersMapFromSlice(req.Response.Headers),
},
}
if req.Response.Body != "" {
reqLog.Response.Body = []byte(req.Response.Body)
}
}
return reqLog, nil
}

11
pkg/db/sqlite/dto.go Normal file
View File

@ -0,0 +1,11 @@
package sqlite
import "time"
type httpResponse struct {
ID *int64
Proto *string
StatusCode *int
Body *[]byte
Timestamp *time.Time
}

493
pkg/db/sqlite/sqlite.go Normal file
View File

@ -0,0 +1,493 @@
package sqlite
import (
"context"
"database/sql"
"fmt"
"net/http"
"net/url"
"os"
"path/filepath"
"strconv"
"time"
"github.com/dstotijn/hetty/pkg/reqlog"
"github.com/dstotijn/hetty/pkg/scope"
// Register sqlite3 for use via database/sql.
_ "github.com/mattn/go-sqlite3"
)
// Client implements reqlog.Repository.
type Client struct {
db *sql.DB
}
// New returns a new Client.
func New(filename string) (*Client, error) {
// Create directory for DB if it doesn't exist yet.
if dbDir, _ := filepath.Split(filename); dbDir != "" {
if _, err := os.Stat(dbDir); os.IsNotExist(err) {
os.Mkdir(dbDir, 0755)
}
}
opts := make(url.Values)
opts.Set("_foreign_keys", "1")
dsn := fmt.Sprintf("file:%v?%v", filename, opts.Encode())
db, err := sql.Open("sqlite3", dsn)
if err != nil {
return nil, err
}
if err := db.Ping(); err != nil {
return nil, fmt.Errorf("sqlite: could not ping database: %v", err)
}
c := &Client{db: db}
if err := c.prepareSchema(); err != nil {
return nil, fmt.Errorf("sqlite: could not prepare schema: %v", err)
}
return &Client{db: db}, nil
}
func (c Client) prepareSchema() error {
_, err := c.db.Exec(`CREATE TABLE IF NOT EXISTS http_requests (
id INTEGER PRIMARY KEY,
proto TEXT,
url TEXT,
method TEXT,
body BLOB,
timestamp DATETIME
)`)
if err != nil {
return fmt.Errorf("could not create http_requests table: %v", err)
}
_, err = c.db.Exec(`CREATE TABLE IF NOT EXISTS http_responses (
id INTEGER PRIMARY KEY,
req_id INTEGER REFERENCES http_requests(id) ON DELETE CASCADE,
proto TEXT,
status_code INTEGER,
status_reason TEXT,
body BLOB,
timestamp DATETIME
)`)
if err != nil {
return fmt.Errorf("could not create http_responses table: %v", err)
}
_, err = c.db.Exec(`CREATE TABLE IF NOT EXISTS http_headers (
id INTEGER PRIMARY KEY,
req_id INTEGER REFERENCES http_requests(id) ON DELETE CASCADE,
res_id INTEGER REFERENCES http_responses(id) ON DELETE CASCADE,
key TEXT,
value TEXT
)`)
if err != nil {
return fmt.Errorf("could not create http_headers table: %v", err)
}
return nil
}
// Close uses the underlying database.
func (c *Client) Close() error {
return c.db.Close()
}
func (c *Client) FindRequestLogs(
ctx context.Context,
opts reqlog.FindRequestsOptions,
scope *scope.Scope,
) (reqLogs []reqlog.Request, err error) {
// TODO: Pass GraphQL field collections upstream, so we can query only
// requested fields.
// TODO: Use opts and scope to filter.
reqQuery := `SELECT
req.id,
req.proto,
req.url,
req.method,
req.body,
req.timestamp,
res.id,
res.proto,
res.status_code,
res.status_reason,
res.body,
res.timestamp
FROM http_requests req
LEFT JOIN http_responses res ON req.id = res.req_id
ORDER BY req.id DESC`
rows, err := c.db.QueryContext(ctx, reqQuery)
if err != nil {
return nil, fmt.Errorf("sqlite: could not execute query: %v", err)
}
defer rows.Close()
for rows.Next() {
var reqLog reqlog.Request
var resDTO httpResponse
var statusReason *string
var rawURL string
err := rows.Scan(
&reqLog.ID,
&reqLog.Request.Proto,
&rawURL,
&reqLog.Request.Method,
&reqLog.Body,
&reqLog.Timestamp,
&resDTO.ID,
&resDTO.Proto,
&resDTO.StatusCode,
&statusReason,
&resDTO.Body,
&resDTO.Timestamp,
)
if err != nil {
return nil, fmt.Errorf("sqlite: could not scan row: %v", err)
}
u, err := url.Parse(rawURL)
if err != nil {
return nil, fmt.Errorf("sqlite: could not parse URL: %v", err)
}
reqLog.Request.URL = u
if resDTO.ID != nil {
status := strconv.Itoa(*resDTO.StatusCode) + " " + *statusReason
reqLog.Response = &reqlog.Response{
ID: *resDTO.ID,
RequestID: reqLog.ID,
Response: http.Response{
Status: status,
StatusCode: *resDTO.StatusCode,
Proto: *resDTO.Proto,
},
Body: *resDTO.Body,
Timestamp: *resDTO.Timestamp,
}
}
reqLogs = append(reqLogs, reqLog)
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("sqlite: could not iterate over rows: %v", err)
}
rows.Close()
reqHeadersStmt, err := c.db.PrepareContext(ctx, `SELECT key, value FROM http_headers WHERE req_id = ?`)
if err != nil {
return nil, fmt.Errorf("sqlite: could not prepare statement: %v", err)
}
defer reqHeadersStmt.Close()
resHeadersStmt, err := c.db.PrepareContext(ctx, `SELECT key, value FROM http_headers WHERE res_id = ?`)
if err != nil {
return nil, fmt.Errorf("sqlite: could not prepare statement: %v", err)
}
defer resHeadersStmt.Close()
for _, reqLog := range reqLogs {
headers, err := findHeaders(ctx, reqHeadersStmt, reqLog.ID)
if err != nil {
return nil, fmt.Errorf("sqlite: could not query request headers: %v", err)
}
reqLog.Request.Header = headers
if reqLog.Response != nil {
headers, err := findHeaders(ctx, resHeadersStmt, reqLog.Response.ID)
if err != nil {
return nil, fmt.Errorf("sqlite: could not query response headers: %v", err)
}
reqLog.Response.Response.Header = headers
}
}
return reqLogs, nil
}
func (c *Client) FindRequestLogByID(ctx context.Context, id int64) (reqlog.Request, error) {
// TODO: Pass GraphQL field collections upstream, so we can query only
// requested fields.
reqQuery := `SELECT
req.id,
req.proto,
req.url,
req.method,
req.body,
req.timestamp,
res.id,
res.proto,
res.status_code,
res.status_reason,
res.body,
res.timestamp
FROM http_requests req
LEFT JOIN http_responses res ON req.id = res.req_id
WHERE req_id = ?
ORDER BY req.id DESC`
var reqLog reqlog.Request
var resDTO httpResponse
var statusReason *string
var rawURL string
err := c.db.QueryRowContext(ctx, reqQuery, id).Scan(
&reqLog.ID,
&reqLog.Request.Proto,
&rawURL,
&reqLog.Request.Method,
&reqLog.Body,
&reqLog.Timestamp,
&resDTO.ID,
&resDTO.Proto,
&resDTO.StatusCode,
&statusReason,
&resDTO.Body,
&resDTO.Timestamp,
)
if err == sql.ErrNoRows {
return reqlog.Request{}, reqlog.ErrRequestNotFound
}
if err != nil {
return reqlog.Request{}, fmt.Errorf("sqlite: could not scan row: %v", err)
}
u, err := url.Parse(rawURL)
if err != nil {
return reqlog.Request{}, fmt.Errorf("sqlite: could not parse URL: %v", err)
}
reqLog.Request.URL = u
if resDTO.ID != nil {
status := strconv.Itoa(*resDTO.StatusCode) + " " + *statusReason
reqLog.Response = &reqlog.Response{
ID: *resDTO.ID,
RequestID: reqLog.ID,
Response: http.Response{
Status: status,
StatusCode: *resDTO.StatusCode,
Proto: *resDTO.Proto,
},
Body: *resDTO.Body,
Timestamp: *resDTO.Timestamp,
}
}
reqHeadersStmt, err := c.db.PrepareContext(ctx, `SELECT key, value FROM http_headers WHERE req_id = ?`)
if err != nil {
return reqlog.Request{}, fmt.Errorf("sqlite: could not prepare statement: %v", err)
}
defer reqHeadersStmt.Close()
resHeadersStmt, err := c.db.PrepareContext(ctx, `SELECT key, value FROM http_headers WHERE res_id = ?`)
if err != nil {
return reqlog.Request{}, fmt.Errorf("sqlite: could not prepare statement: %v", err)
}
defer resHeadersStmt.Close()
headers, err := findHeaders(ctx, reqHeadersStmt, reqLog.ID)
if err != nil {
return reqlog.Request{}, fmt.Errorf("sqlite: could not query request headers: %v", err)
}
reqLog.Request.Header = headers
if reqLog.Response != nil {
headers, err := findHeaders(ctx, resHeadersStmt, reqLog.Response.ID)
if err != nil {
return reqlog.Request{}, fmt.Errorf("sqlite: could not query response headers: %v", err)
}
reqLog.Response.Response.Header = headers
}
return reqLog, nil
}
func (c *Client) AddRequestLog(
ctx context.Context,
req http.Request,
body []byte,
timestamp time.Time,
) (*reqlog.Request, error) {
reqLog := &reqlog.Request{
Request: req,
Body: body,
Timestamp: timestamp,
}
tx, err := c.db.BeginTx(ctx, nil)
if err != nil {
return nil, fmt.Errorf("sqlite: could not start transaction: %v", err)
}
defer tx.Rollback()
reqStmt, err := tx.PrepareContext(ctx, `INSERT INTO http_requests (
proto,
url,
method,
body,
timestamp
) VALUES (?, ?, ?, ?, ?)`)
if err != nil {
return nil, fmt.Errorf("sqlite: could not prepare statement: %v", err)
}
defer reqStmt.Close()
result, err := reqStmt.ExecContext(ctx,
reqLog.Request.Proto,
reqLog.Request.URL.String(),
reqLog.Request.Method,
reqLog.Body,
reqLog.Timestamp,
)
if err != nil {
return nil, fmt.Errorf("sqlite: could not execute statement: %v", err)
}
reqID, err := result.LastInsertId()
if err != nil {
return nil, fmt.Errorf("sqlite: could not get last insert ID: %v", err)
}
reqLog.ID = reqID
headerStmt, err := tx.PrepareContext(ctx, `INSERT INTO http_headers (
req_id,
key,
value
) VALUES (?, ?, ?)`)
if err != nil {
return nil, fmt.Errorf("sqlite: could not prepare statement: %v", err)
}
defer headerStmt.Close()
err = insertHeaders(ctx, headerStmt, reqID, reqLog.Request.Header)
if err != nil {
return nil, fmt.Errorf("sqlite: could not insert http headers: %v", err)
}
if err := tx.Commit(); err != nil {
return nil, fmt.Errorf("sqlite: could not commit transaction: %v", err)
}
return reqLog, nil
}
func (c *Client) AddResponseLog(
ctx context.Context,
reqID int64,
res http.Response,
body []byte,
timestamp time.Time,
) (*reqlog.Response, error) {
resLog := &reqlog.Response{
RequestID: reqID,
Response: res,
Body: body,
Timestamp: timestamp,
}
tx, err := c.db.BeginTx(ctx, nil)
if err != nil {
return nil, fmt.Errorf("sqlite: could not start transaction: %v", err)
}
defer tx.Rollback()
resStmt, err := tx.PrepareContext(ctx, `INSERT INTO http_responses (
req_id,
proto,
status_code,
status_reason,
body,
timestamp
) VALUES (?, ?, ?, ?, ?, ?)`)
if err != nil {
return nil, fmt.Errorf("sqlite: could not prepare statement: %v", err)
}
defer resStmt.Close()
var statusReason string
if len(resLog.Response.Status) > 4 {
statusReason = resLog.Response.Status[4:]
}
result, err := resStmt.ExecContext(ctx,
resLog.RequestID,
resLog.Response.Proto,
resLog.Response.StatusCode,
statusReason,
resLog.Body,
resLog.Timestamp,
)
if err != nil {
return nil, fmt.Errorf("sqlite: could not execute statement: %v", err)
}
resID, err := result.LastInsertId()
if err != nil {
return nil, fmt.Errorf("sqlite: could not get last insert ID: %v", err)
}
resLog.ID = resID
headerStmt, err := tx.PrepareContext(ctx, `INSERT INTO http_headers (
res_id,
key,
value
) VALUES (?, ?, ?)`)
if err != nil {
return nil, fmt.Errorf("sqlite: could not prepare statement: %v", err)
}
defer headerStmt.Close()
err = insertHeaders(ctx, headerStmt, resID, resLog.Response.Header)
if err != nil {
return nil, fmt.Errorf("sqlite: could not insert http headers: %v", err)
}
if err := tx.Commit(); err != nil {
return nil, fmt.Errorf("sqlite: could not commit transaction: %v", err)
}
return resLog, nil
}
func insertHeaders(ctx context.Context, stmt *sql.Stmt, id int64, headers http.Header) error {
for key, values := range headers {
for _, value := range values {
if _, err := stmt.ExecContext(ctx, id, key, value); err != nil {
return fmt.Errorf("could not execute statement: %v", err)
}
}
}
return nil
}
func findHeaders(ctx context.Context, stmt *sql.Stmt, id int64) (http.Header, error) {
headers := make(http.Header)
rows, err := stmt.QueryContext(ctx, id)
if err != nil {
return nil, fmt.Errorf("sqlite: could not execute query: %v", err)
}
defer rows.Close()
for rows.Next() {
var key, value string
err := rows.Scan(
&key,
&value,
)
if err != nil {
return nil, fmt.Errorf("sqlite: could not scan row: %v", err)
}
headers.Add(key, value)
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("sqlite: could not iterate over rows: %v", err)
}
return headers, nil
}

View File

@ -12,8 +12,6 @@ import (
"net/http/httputil"
"github.com/dstotijn/hetty/pkg/scope"
"github.com/google/uuid"
)
type contextKey int
@ -56,11 +54,6 @@ func NewProxy(ca *x509.Certificate, key crypto.PrivateKey) (*Proxy, error) {
}
func (p *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// Add a unique request ID, to be used for correlating responses to requests.
reqID := uuid.New()
ctx := context.WithValue(r.Context(), ReqIDKey, reqID)
r = r.WithContext(ctx)
if r.Method == http.MethodConnect {
p.handleConnect(w, r)
return

View File

@ -2,14 +2,15 @@ package reqlog
import (
"context"
"net/http"
"time"
"github.com/dstotijn/hetty/pkg/scope"
"github.com/google/uuid"
)
type Repository interface {
FindRequestLogs(ctx context.Context, opts FindRequestsOptions, scope *scope.Scope) ([]Request, error)
FindRequestLogByID(ctx context.Context, id uuid.UUID) (Request, error)
AddRequestLog(ctx context.Context, reqLog Request) error
AddResponseLog(ctx context.Context, resLog Response) error
FindRequestLogByID(ctx context.Context, id int64) (Request, error)
AddRequestLog(ctx context.Context, req http.Request, body []byte, timestamp time.Time) (*Request, error)
AddResponseLog(ctx context.Context, reqID int64, res http.Response, body []byte, timestamp time.Time) (*Response, error)
}

View File

@ -13,8 +13,6 @@ import (
"github.com/dstotijn/hetty/pkg/proxy"
"github.com/dstotijn/hetty/pkg/scope"
"github.com/google/uuid"
)
type contextKey int
@ -24,7 +22,7 @@ const LogBypassedKey contextKey = 0
var ErrRequestNotFound = errors.New("reqlog: request not found")
type Request struct {
ID uuid.UUID
ID int64
Request http.Request
Body []byte
Timestamp time.Time
@ -32,7 +30,8 @@ type Request struct {
}
type Response struct {
RequestID uuid.UUID
ID int64
RequestID int64
Response http.Response
Body []byte
Timestamp time.Time
@ -72,46 +71,44 @@ func (svc *Service) FindRequests(ctx context.Context, opts FindRequestsOptions)
return svc.repo.FindRequestLogs(ctx, opts, scope)
}
func (svc *Service) FindRequestLogByID(ctx context.Context, id uuid.UUID) (Request, error) {
func (svc *Service) FindRequestLogByID(ctx context.Context, id int64) (Request, error) {
return svc.repo.FindRequestLogByID(ctx, id)
}
func (svc *Service) addRequest(ctx context.Context, reqID uuid.UUID, req http.Request, body []byte) error {
reqLog := Request{
ID: reqID,
Request: req,
Body: body,
Timestamp: time.Now(),
}
return svc.repo.AddRequestLog(ctx, reqLog)
func (svc *Service) addRequest(
ctx context.Context,
req http.Request,
body []byte,
timestamp time.Time,
) (*Request, error) {
return svc.repo.AddRequestLog(ctx, req, body, timestamp)
}
func (svc *Service) addResponse(ctx context.Context, reqID uuid.UUID, res http.Response, body []byte) error {
func (svc *Service) addResponse(
ctx context.Context,
reqID int64,
res http.Response,
body []byte,
timestamp time.Time,
) (*Response, error) {
if res.Header.Get("Content-Encoding") == "gzip" {
gzipReader, err := gzip.NewReader(bytes.NewBuffer(body))
if err != nil {
return fmt.Errorf("reqlog: could not create gzip reader: %v", err)
return nil, fmt.Errorf("reqlog: could not create gzip reader: %v", err)
}
defer gzipReader.Close()
body, err = ioutil.ReadAll(gzipReader)
if err != nil {
return fmt.Errorf("reqlog: could not read gzipped response body: %v", err)
return nil, fmt.Errorf("reqlog: could not read gzipped response body: %v", err)
}
}
resLog := Response{
RequestID: reqID,
Response: res,
Body: body,
Timestamp: time.Now(),
}
return svc.repo.AddResponseLog(ctx, resLog)
return svc.repo.AddResponseLog(ctx, reqID, res, body, timestamp)
}
func (svc *Service) RequestModifier(next proxy.RequestModifyFunc) proxy.RequestModifyFunc {
return func(req *http.Request) {
now := time.Now()
next(req)
clone := req.Clone(req.Context())
@ -131,26 +128,23 @@ func (svc *Service) RequestModifier(next proxy.RequestModifyFunc) proxy.RequestM
// doens't match any rules of the scope.
if svc.BypassOutOfScopeRequests && !svc.scope.Match(clone, body) {
ctx := context.WithValue(req.Context(), LogBypassedKey, true)
req = req.WithContext(ctx)
*req = *req.WithContext(ctx)
return
}
reqID, _ := req.Context().Value(proxy.ReqIDKey).(uuid.UUID)
if reqID == uuid.Nil {
log.Println("[ERROR] Request is missing a related request ID")
reqLog, err := svc.addRequest(req.Context(), *clone, body, now)
if err != nil {
log.Printf("[ERROR] Could not store request log: %v", err)
return
}
go func() {
if err := svc.addRequest(context.Background(), reqID, *clone, body); err != nil {
log.Printf("[ERROR] Could not store request log: %v", err)
}
}()
ctx := context.WithValue(req.Context(), proxy.ReqIDKey, reqLog.ID)
*req = *req.WithContext(ctx)
}
}
func (svc *Service) ResponseModifier(next proxy.ResponseModifyFunc) proxy.ResponseModifyFunc {
return func(res *http.Response) error {
now := time.Now()
if err := next(res); err != nil {
return err
}
@ -159,8 +153,8 @@ func (svc *Service) ResponseModifier(next proxy.ResponseModifyFunc) proxy.Respon
return nil
}
reqID, _ := res.Request.Context().Value(proxy.ReqIDKey).(uuid.UUID)
if reqID == uuid.Nil {
reqID, _ := res.Request.Context().Value(proxy.ReqIDKey).(int64)
if reqID == 0 {
return errors.New("reqlog: request is missing ID")
}
@ -174,7 +168,7 @@ func (svc *Service) ResponseModifier(next proxy.ResponseModifyFunc) proxy.Respon
res.Body = ioutil.NopCloser(bytes.NewBuffer(body))
go func() {
if err := svc.addResponse(res.Request.Context(), reqID, clone, body); err != nil {
if _, err := svc.addResponse(context.Background(), reqID, clone, body, now); err != nil {
log.Printf("[ERROR] Could not store response log: %v", err)
}
}()