Add scope support

This commit is contained in:
David Stotijn
2020-10-29 20:54:17 +01:00
parent 98dacbe849
commit 0d04996f06
30 changed files with 2807 additions and 119 deletions

View File

@ -3,12 +3,18 @@ package sqlite
import (
"context"
"database/sql"
"encoding/json"
"errors"
"fmt"
"io/ioutil"
"net/http"
"net/url"
"os"
"path/filepath"
"strings"
"time"
"github.com/dstotijn/hetty/pkg/proj"
"github.com/dstotijn/hetty/pkg/reqlog"
"github.com/dstotijn/hetty/pkg/scope"
@ -16,13 +22,18 @@ import (
sq "github.com/Masterminds/squirrel"
"github.com/jmoiron/sqlx"
// Register sqlite3 for use via database/sql.
// Register `sqlite3` driver.
_ "github.com/mattn/go-sqlite3"
// Register `regexp()` function.
_ "github.com/dstotijn/hetty/pkg/db/sqlite/regexp"
)
// Client implements reqlog.Repository.
type Client struct {
db *sqlx.DB
db *sqlx.DB
dbPath string
activeProject string
}
type httpRequestLogsQuery struct {
@ -32,16 +43,28 @@ type httpRequestLogsQuery struct {
joinResponse bool
}
// Open opens a database.
func (c *Client) Open(filename string) error {
func New(dbPath string) (*Client, error) {
if _, err := os.Stat(dbPath); os.IsNotExist(err) {
if err := os.MkdirAll(dbPath, 0755); err != nil {
return nil, fmt.Errorf("proj: could not create project directory: %v", err)
}
}
return &Client{
dbPath: dbPath,
}, nil
}
// OpenProject opens a project database.
func (c *Client) OpenProject(name string) error {
if c.db != nil {
return errors.New("sqlite: database already open")
return errors.New("sqlite: there is already a project open")
}
opts := make(url.Values)
opts.Set("_foreign_keys", "1")
dsn := fmt.Sprintf("file:%v?%v", filename, opts.Encode())
dbPath := filepath.Join(c.dbPath, name+".db")
dsn := fmt.Sprintf("file:%v?%v", dbPath, opts.Encode())
db, err := sqlx.Open("sqlite3", dsn)
if err != nil {
return fmt.Errorf("sqlite: could not open database: %v", err)
@ -56,10 +79,29 @@ func (c *Client) Open(filename string) error {
}
c.db = db
c.activeProject = name
return nil
}
func (c *Client) Projects() ([]proj.Project, error) {
files, err := ioutil.ReadDir(c.dbPath)
if err != nil {
return nil, fmt.Errorf("sqlite: could not read projects directory: %v", err)
}
projects := make([]proj.Project, len(files))
for i, file := range files {
projName := strings.TrimSuffix(file.Name(), ".db")
projects[i] = proj.Project{
Name: projName,
IsActive: c.activeProject == projName,
}
}
return projects, nil
}
func prepareSchema(db *sqlx.DB) error {
_, err := db.Exec(`CREATE TABLE IF NOT EXISTS http_requests (
id INTEGER PRIMARY KEY,
@ -97,6 +139,14 @@ func prepareSchema(db *sqlx.DB) error {
return fmt.Errorf("could not create http_headers table: %v", err)
}
_, err = db.Exec(`CREATE TABLE IF NOT EXISTS settings (
module TEXT PRIMARY KEY,
settings TEXT
)`)
if err != nil {
return fmt.Errorf("could not create settings table: %v", err)
}
return nil
}
@ -108,7 +158,18 @@ func (c *Client) Close() error {
if err := c.db.Close(); err != nil {
return fmt.Errorf("sqlite: could not close database: %v", err)
}
c.db = nil
c.activeProject = ""
return nil
}
func (c *Client) DeleteProject(name string) error {
if err := os.Remove(filepath.Join(c.dbPath, name+".db")); err != nil {
return fmt.Errorf("sqlite: could not remove database file: %v", err)
}
return nil
}
@ -136,11 +197,11 @@ var headerFieldToColumnMap = map[string]string{
func (c *Client) FindRequestLogs(
ctx context.Context,
opts reqlog.FindRequestsOptions,
filter reqlog.FindRequestsFilter,
scope *scope.Scope,
) (reqLogs []reqlog.Request, err error) {
if c.db == nil {
return nil, reqlog.ErrNoProject
return nil, proj.ErrNoProject
}
httpReqLogsQuery := parseHTTPRequestLogsQuery(ctx)
@ -153,12 +214,24 @@ func (c *Client) FindRequestLogs(
reqQuery = reqQuery.LeftJoin("http_responses res ON req.id = res.req_id")
}
sql, _, err := reqQuery.ToSql()
if filter.OnlyInScope && scope != nil {
var ruleExpr []sq.Sqlizer
for _, rule := range scope.Rules() {
if rule.URL != nil {
ruleExpr = append(ruleExpr, sq.Expr("req.url regexp ?", rule.URL.String()))
}
}
if len(ruleExpr) > 0 {
reqQuery = reqQuery.Where(sq.Or(ruleExpr))
}
}
sql, args, err := reqQuery.ToSql()
if err != nil {
return nil, fmt.Errorf("sqlite: could not parse query: %v", err)
}
rows, err := c.db.QueryxContext(ctx, sql, nil)
rows, err := c.db.QueryxContext(ctx, sql, args...)
if err != nil {
return nil, fmt.Errorf("sqlite: could not execute query: %v", err)
}
@ -186,7 +259,7 @@ func (c *Client) FindRequestLogs(
func (c *Client) FindRequestLogByID(ctx context.Context, id int64) (reqlog.Request, error) {
if c.db == nil {
return reqlog.Request{}, reqlog.ErrNoProject
return reqlog.Request{}, proj.ErrNoProject
}
httpReqLogsQuery := parseHTTPRequestLogsQuery(ctx)
@ -229,7 +302,7 @@ func (c *Client) AddRequestLog(
timestamp time.Time,
) (*reqlog.Request, error) {
if c.db == nil {
return nil, reqlog.ErrNoProject
return nil, proj.ErrNoProject
}
reqLog := &reqlog.Request{
@ -303,7 +376,7 @@ func (c *Client) AddResponseLog(
timestamp time.Time,
) (*reqlog.Response, error) {
if c.db == nil {
return nil, reqlog.ErrNoProject
return nil, proj.ErrNoProject
}
resLog := &reqlog.Response{
@ -376,6 +449,49 @@ func (c *Client) AddResponseLog(
return resLog, nil
}
func (c *Client) UpsertSettings(ctx context.Context, module string, settings interface{}) error {
if c.db == nil {
// TODO: Fix where `ErrNoProject` lives.
return proj.ErrNoProject
}
jsonSettings, err := json.Marshal(settings)
if err != nil {
return fmt.Errorf("sqlite: could not encode settings as JSON: %v", err)
}
_, err = c.db.ExecContext(ctx,
`INSERT INTO settings (module, settings) VALUES (?, ?)
ON CONFLICT(module) DO UPDATE SET settings = ?`, module, jsonSettings, jsonSettings)
if err != nil {
return fmt.Errorf("sqlite: could not insert scope settings: %v", err)
}
return nil
}
func (c *Client) FindSettingsByModule(ctx context.Context, module string, settings interface{}) error {
if c.db == nil {
return proj.ErrNoProject
}
var jsonSettings []byte
row := c.db.QueryRowContext(ctx, `SELECT settings FROM settings WHERE module = ?`, module)
err := row.Scan(&jsonSettings)
if err == sql.ErrNoRows {
return proj.ErrNoSettings
}
if err != nil {
return fmt.Errorf("sqlite: could not scan row: %v", err)
}
if err := json.Unmarshal(jsonSettings, &settings); err != nil {
return fmt.Errorf("sqlite: could not decode settings from JSON: %v", err)
}
return nil
}
func insertHeaders(ctx context.Context, stmt *sql.Stmt, id int64, headers http.Header) error {
for key, values := range headers {
for _, value := range values {