Add scope support

This commit is contained in:
David Stotijn
2020-10-29 20:54:17 +01:00
parent 98dacbe849
commit 0d04996f06
30 changed files with 2807 additions and 119 deletions

File diff suppressed because it is too large Load Diff

View File

@ -33,6 +33,14 @@ type HTTPRequestLog struct {
Response *HTTPResponseLog `json:"response"`
}
type HTTPRequestLogFilter struct {
OnlyInScope bool `json:"onlyInScope"`
}
type HTTPRequestLogFilterInput struct {
OnlyInScope *bool `json:"onlyInScope"`
}
type HTTPResponseLog struct {
RequestID int64 `json:"requestId"`
Proto string `json:"proto"`
@ -47,6 +55,28 @@ type Project struct {
IsActive bool `json:"isActive"`
}
type ScopeHeader struct {
Key *string `json:"key"`
Value *string `json:"value"`
}
type ScopeHeaderInput struct {
Key *string `json:"key"`
Value *string `json:"value"`
}
type ScopeRule struct {
URL *string `json:"url"`
Header *ScopeHeader `json:"header"`
Body *string `json:"body"`
}
type ScopeRuleInput struct {
URL *string `json:"url"`
Header *ScopeHeaderInput `json:"header"`
Body *string `json:"body"`
}
type HTTPMethod string
const (

View File

@ -5,17 +5,20 @@ package api
import (
"context"
"fmt"
"regexp"
"strings"
"github.com/99designs/gqlgen/graphql"
"github.com/dstotijn/hetty/pkg/proj"
"github.com/dstotijn/hetty/pkg/reqlog"
"github.com/dstotijn/hetty/pkg/scope"
"github.com/vektah/gqlparser/v2/gqlerror"
)
type Resolver struct {
RequestLogService *reqlog.Service
ProjectService *proj.Service
ScopeService *scope.Scope
}
type queryResolver struct{ *Resolver }
@ -25,9 +28,8 @@ func (r *Resolver) Query() QueryResolver { return &queryResolver{r} }
func (r *Resolver) Mutation() MutationResolver { return &mutationResolver{r} }
func (r *queryResolver) HTTPRequestLogs(ctx context.Context) ([]HTTPRequestLog, error) {
opts := reqlog.FindRequestsOptions{OmitOutOfScope: false}
reqs, err := r.RequestLogService.FindRequests(ctx, opts)
if err == reqlog.ErrNoProject {
reqs, err := r.RequestLogService.FindRequests(ctx)
if err == proj.ErrNoProject {
return nil, &gqlerror.Error{
Path: graphql.GetPath(ctx),
Message: "No active project.",
@ -133,7 +135,7 @@ func parseRequestLog(req reqlog.Request) (HTTPRequestLog, error) {
}
func (r *mutationResolver) OpenProject(ctx context.Context, name string) (*Project, error) {
p, err := r.ProjectService.Open(name)
p, err := r.ProjectService.Open(ctx, name)
if err == proj.ErrInvalidName {
return nil, gqlerror.Errorf("Project name must only contain alphanumeric or space chars.")
}
@ -178,6 +180,19 @@ func (r *queryResolver) Projects(ctx context.Context) ([]Project, error) {
return projects, nil
}
func (r *queryResolver) Scope(ctx context.Context) ([]ScopeRule, error) {
rules := r.ScopeService.Rules()
return scopeToScopeRules(rules), nil
}
func regexpToStringPtr(r *regexp.Regexp) *string {
if r == nil {
return nil
}
s := r.String()
return &s
}
func (r *mutationResolver) CloseProject(ctx context.Context) (*CloseProjectResult, error) {
if err := r.ProjectService.Close(); err != nil {
return nil, fmt.Errorf("could not close project: %v", err)
@ -193,3 +208,107 @@ func (r *mutationResolver) DeleteProject(ctx context.Context, name string) (*Del
Success: true,
}, nil
}
func (r *mutationResolver) SetScope(ctx context.Context, input []ScopeRuleInput) ([]ScopeRule, error) {
rules := make([]scope.Rule, len(input))
for i, rule := range input {
u, err := stringPtrToRegexp(rule.URL)
if err != nil {
return nil, fmt.Errorf("invalid URL in scope rule: %v", err)
}
var headerKey, headerValue *regexp.Regexp
if rule.Header != nil {
headerKey, err = stringPtrToRegexp(rule.Header.Key)
if err != nil {
return nil, fmt.Errorf("invalid header key in scope rule: %v", err)
}
headerValue, err = stringPtrToRegexp(rule.Header.Key)
if err != nil {
return nil, fmt.Errorf("invalid header value in scope rule: %v", err)
}
}
body, err := stringPtrToRegexp(rule.Body)
if err != nil {
return nil, fmt.Errorf("invalid body in scope rule: %v", err)
}
rules[i] = scope.Rule{
URL: u,
Header: scope.Header{
Key: headerKey,
Value: headerValue,
},
Body: body,
}
}
if err := r.ScopeService.SetRules(ctx, rules); err != nil {
return nil, fmt.Errorf("could not set scope: %v", err)
}
return scopeToScopeRules(rules), nil
}
func (r *queryResolver) HTTPRequestLogFilter(ctx context.Context) (*HTTPRequestLogFilter, error) {
return findReqFilterToHTTPReqLogFilter(r.RequestLogService.FindReqsFilter), nil
}
func (r *mutationResolver) SetHTTPRequestLogFilter(
ctx context.Context,
input *HTTPRequestLogFilterInput,
) (*HTTPRequestLogFilter, error) {
filter := findRequestsFilterFromInput(input)
if err := r.RequestLogService.SetRequestLogFilter(ctx, filter); err != nil {
return nil, fmt.Errorf("could not set request log filter: %v", err)
}
empty := reqlog.FindRequestsFilter{}
if filter == empty {
return nil, nil
}
return findReqFilterToHTTPReqLogFilter(filter), nil
}
func stringPtrToRegexp(s *string) (*regexp.Regexp, error) {
if s == nil {
return nil, nil
}
return regexp.Compile(*s)
}
func scopeToScopeRules(rules []scope.Rule) []ScopeRule {
scopeRules := make([]ScopeRule, len(rules))
for i, rule := range rules {
scopeRules[i].URL = regexpToStringPtr(rule.URL)
if rule.Header.Key != nil || rule.Header.Value != nil {
scopeRules[i].Header = &ScopeHeader{
Key: regexpToStringPtr(rule.Header.Key),
Value: regexpToStringPtr(rule.Header.Value),
}
}
scopeRules[i].Body = regexpToStringPtr(rule.Body)
}
return scopeRules
}
func findRequestsFilterFromInput(input *HTTPRequestLogFilterInput) (filter reqlog.FindRequestsFilter) {
if input == nil {
return
}
if input.OnlyInScope != nil {
filter.OnlyInScope = *input.OnlyInScope
}
return
}
func findReqFilterToHTTPReqLogFilter(findReqFilter reqlog.FindRequestsFilter) *HTTPRequestLogFilter {
empty := reqlog.FindRequestsFilter{}
if findReqFilter == empty {
return nil
}
httpReqLogFilter := &HTTPRequestLogFilter{
OnlyInScope: findReqFilter.OnlyInScope,
}
return httpReqLogFilter
}

View File

@ -28,6 +28,28 @@ type Project {
isActive: Boolean!
}
type ScopeRule {
url: Regexp
header: ScopeHeader
body: Regexp
}
input ScopeRuleInput {
url: Regexp
header: ScopeHeaderInput
body: Regexp
}
type ScopeHeader {
key: Regexp
value: Regexp
}
input ScopeHeaderInput {
key: Regexp
value: Regexp
}
type CloseProjectResult {
success: Boolean!
}
@ -36,17 +58,31 @@ type DeleteProjectResult {
success: Boolean!
}
input HttpRequestLogFilterInput {
onlyInScope: Boolean
}
type HttpRequestLogFilter {
onlyInScope: Boolean!
}
type Query {
httpRequestLog(id: ID!): HttpRequestLog
httpRequestLogs: [HttpRequestLog!]!
httpRequestLogFilter: HttpRequestLogFilter
activeProject: Project
projects: [Project!]!
scope: [ScopeRule!]!
}
type Mutation {
openProject(name: String!): Project
closeProject: CloseProjectResult!
deleteProject(name: String!): DeleteProjectResult!
setScope(scope: [ScopeRuleInput!]!): [ScopeRule!]!
setHttpRequestLogFilter(
filter: HttpRequestLogFilterInput
): HttpRequestLogFilter
}
enum HttpMethod {
@ -62,3 +98,4 @@ enum HttpMethod {
}
scalar Time
scalar Regexp

View File

@ -0,0 +1,759 @@
/*
** 2012-11-13
**
** The author disclaims copyright to this source code. In place of
** a legal notice, here is a blessing:
**
** May you do good and not evil.
** May you find forgiveness for yourself and forgive others.
** May you share freely, never taking more than you give.
**
******************************************************************************
**
** The code in this file implements a compact but reasonably
** efficient regular-expression matcher for posix extended regular
** expressions against UTF8 text.
**
** This file is an SQLite extension. It registers a single function
** named "regexp(A,B)" where A is the regular expression and B is the
** string to be matched. By registering this function, SQLite will also
** then implement the "B regexp A" operator. Note that with the function
** the regular expression comes first, but with the operator it comes
** second.
**
** The following regular expression syntax is supported:
**
** X* zero or more occurrences of X
** X+ one or more occurrences of X
** X? zero or one occurrences of X
** X{p,q} between p and q occurrences of X
** (X) match X
** X|Y X or Y
** ^X X occurring at the beginning of the string
** X$ X occurring at the end of the string
** . Match any single character
** \c Character c where c is one of \{}()[]|*+?.
** \c C-language escapes for c in afnrtv. ex: \t or \n
** \uXXXX Where XXXX is exactly 4 hex digits, unicode value XXXX
** \xXX Where XX is exactly 2 hex digits, unicode value XX
** [abc] Any single character from the set abc
** [^abc] Any single character not in the set abc
** [a-z] Any single character in the range a-z
** [^a-z] Any single character not in the range a-z
** \b Word boundary
** \w Word character. [A-Za-z0-9_]
** \W Non-word character
** \d Digit
** \D Non-digit
** \s Whitespace character
** \S Non-whitespace character
**
** A nondeterministic finite automaton (NFA) is used for matching, so the
** performance is bounded by O(N*M) where N is the size of the regular
** expression and M is the size of the input string. The matcher never
** exhibits exponential behavior. Note that the X{p,q} operator expands
** to p copies of X following by q-p copies of X? and that the size of the
** regular expression in the O(N*M) performance bound is computed after
** this expansion.
*/
#include <string.h>
#include <stdlib.h>
#include "sqlite3ext.h"
SQLITE_EXTENSION_INIT1
/*
** The following #defines change the names of some functions implemented in
** this file to prevent name collisions with C-library functions of the
** same name.
*/
#define re_match sqlite3re_match
#define re_compile sqlite3re_compile
#define re_free sqlite3re_free
/* The end-of-input character */
#define RE_EOF 0 /* End of input */
/* The NFA is implemented as sequence of opcodes taken from the following
** set. Each opcode has a single integer argument.
*/
#define RE_OP_MATCH 1 /* Match the one character in the argument */
#define RE_OP_ANY 2 /* Match any one character. (Implements ".") */
#define RE_OP_ANYSTAR 3 /* Special optimized version of .* */
#define RE_OP_FORK 4 /* Continue to both next and opcode at iArg */
#define RE_OP_GOTO 5 /* Jump to opcode at iArg */
#define RE_OP_ACCEPT 6 /* Halt and indicate a successful match */
#define RE_OP_CC_INC 7 /* Beginning of a [...] character class */
#define RE_OP_CC_EXC 8 /* Beginning of a [^...] character class */
#define RE_OP_CC_VALUE 9 /* Single value in a character class */
#define RE_OP_CC_RANGE 10 /* Range of values in a character class */
#define RE_OP_WORD 11 /* Perl word character [A-Za-z0-9_] */
#define RE_OP_NOTWORD 12 /* Not a perl word character */
#define RE_OP_DIGIT 13 /* digit: [0-9] */
#define RE_OP_NOTDIGIT 14 /* Not a digit */
#define RE_OP_SPACE 15 /* space: [ \t\n\r\v\f] */
#define RE_OP_NOTSPACE 16 /* Not a digit */
#define RE_OP_BOUNDARY 17 /* Boundary between word and non-word */
/* Each opcode is a "state" in the NFA */
typedef unsigned short ReStateNumber;
/* Because this is an NFA and not a DFA, multiple states can be active at
** once. An instance of the following object records all active states in
** the NFA. The implementation is optimized for the common case where the
** number of actives states is small.
*/
typedef struct ReStateSet {
unsigned nState; /* Number of current states */
ReStateNumber *aState; /* Current states */
} ReStateSet;
/* An input string read one character at a time.
*/
typedef struct ReInput ReInput;
struct ReInput {
const unsigned char *z; /* All text */
int i; /* Next byte to read */
int mx; /* EOF when i>=mx */
};
/* A compiled NFA (or an NFA that is in the process of being compiled) is
** an instance of the following object.
*/
typedef struct ReCompiled ReCompiled;
struct ReCompiled {
ReInput sIn; /* Regular expression text */
const char *zErr; /* Error message to return */
char *aOp; /* Operators for the virtual machine */
int *aArg; /* Arguments to each operator */
unsigned (*xNextChar)(ReInput*); /* Next character function */
unsigned char zInit[12]; /* Initial text to match */
int nInit; /* Number of characters in zInit */
unsigned nState; /* Number of entries in aOp[] and aArg[] */
unsigned nAlloc; /* Slots allocated for aOp[] and aArg[] */
};
/* Add a state to the given state set if it is not already there */
static void re_add_state(ReStateSet *pSet, int newState){
unsigned i;
for(i=0; i<pSet->nState; i++) if( pSet->aState[i]==newState ) return;
pSet->aState[pSet->nState++] = (ReStateNumber)newState;
}
/* Extract the next unicode character from *pzIn and return it. Advance
** *pzIn to the first byte past the end of the character returned. To
** be clear: this routine converts utf8 to unicode. This routine is
** optimized for the common case where the next character is a single byte.
*/
static unsigned re_next_char(ReInput *p){
unsigned c;
if( p->i>=p->mx ) return 0;
c = p->z[p->i++];
if( c>=0x80 ){
if( (c&0xe0)==0xc0 && p->i<p->mx && (p->z[p->i]&0xc0)==0x80 ){
c = (c&0x1f)<<6 | (p->z[p->i++]&0x3f);
if( c<0x80 ) c = 0xfffd;
}else if( (c&0xf0)==0xe0 && p->i+1<p->mx && (p->z[p->i]&0xc0)==0x80
&& (p->z[p->i+1]&0xc0)==0x80 ){
c = (c&0x0f)<<12 | ((p->z[p->i]&0x3f)<<6) | (p->z[p->i+1]&0x3f);
p->i += 2;
if( c<=0x7ff || (c>=0xd800 && c<=0xdfff) ) c = 0xfffd;
}else if( (c&0xf8)==0xf0 && p->i+3<p->mx && (p->z[p->i]&0xc0)==0x80
&& (p->z[p->i+1]&0xc0)==0x80 && (p->z[p->i+2]&0xc0)==0x80 ){
c = (c&0x07)<<18 | ((p->z[p->i]&0x3f)<<12) | ((p->z[p->i+1]&0x3f)<<6)
| (p->z[p->i+2]&0x3f);
p->i += 3;
if( c<=0xffff || c>0x10ffff ) c = 0xfffd;
}else{
c = 0xfffd;
}
}
return c;
}
static unsigned re_next_char_nocase(ReInput *p){
unsigned c = re_next_char(p);
if( c>='A' && c<='Z' ) c += 'a' - 'A';
return c;
}
/* Return true if c is a perl "word" character: [A-Za-z0-9_] */
static int re_word_char(int c){
return (c>='0' && c<='9') || (c>='a' && c<='z')
|| (c>='A' && c<='Z') || c=='_';
}
/* Return true if c is a "digit" character: [0-9] */
static int re_digit_char(int c){
return (c>='0' && c<='9');
}
/* Return true if c is a perl "space" character: [ \t\r\n\v\f] */
static int re_space_char(int c){
return c==' ' || c=='\t' || c=='\n' || c=='\r' || c=='\v' || c=='\f';
}
/* Run a compiled regular expression on the zero-terminated input
** string zIn[]. Return true on a match and false if there is no match.
*/
static int re_match(ReCompiled *pRe, const unsigned char *zIn, int nIn){
ReStateSet aStateSet[2], *pThis, *pNext;
ReStateNumber aSpace[100];
ReStateNumber *pToFree;
unsigned int i = 0;
unsigned int iSwap = 0;
int c = RE_EOF+1;
int cPrev = 0;
int rc = 0;
ReInput in;
in.z = zIn;
in.i = 0;
in.mx = nIn>=0 ? nIn : (int)strlen((char const*)zIn);
/* Look for the initial prefix match, if there is one. */
if( pRe->nInit ){
unsigned char x = pRe->zInit[0];
while( in.i+pRe->nInit<=in.mx
&& (zIn[in.i]!=x ||
strncmp((const char*)zIn+in.i, (const char*)pRe->zInit, pRe->nInit)!=0)
){
in.i++;
}
if( in.i+pRe->nInit>in.mx ) return 0;
}
if( pRe->nState<=(sizeof(aSpace)/(sizeof(aSpace[0])*2)) ){
pToFree = 0;
aStateSet[0].aState = aSpace;
}else{
pToFree = sqlite3_malloc64( sizeof(ReStateNumber)*2*pRe->nState );
if( pToFree==0 ) return -1;
aStateSet[0].aState = pToFree;
}
aStateSet[1].aState = &aStateSet[0].aState[pRe->nState];
pNext = &aStateSet[1];
pNext->nState = 0;
re_add_state(pNext, 0);
while( c!=RE_EOF && pNext->nState>0 ){
cPrev = c;
c = pRe->xNextChar(&in);
pThis = pNext;
pNext = &aStateSet[iSwap];
iSwap = 1 - iSwap;
pNext->nState = 0;
for(i=0; i<pThis->nState; i++){
int x = pThis->aState[i];
switch( pRe->aOp[x] ){
case RE_OP_MATCH: {
if( pRe->aArg[x]==c ) re_add_state(pNext, x+1);
break;
}
case RE_OP_ANY: {
re_add_state(pNext, x+1);
break;
}
case RE_OP_WORD: {
if( re_word_char(c) ) re_add_state(pNext, x+1);
break;
}
case RE_OP_NOTWORD: {
if( !re_word_char(c) ) re_add_state(pNext, x+1);
break;
}
case RE_OP_DIGIT: {
if( re_digit_char(c) ) re_add_state(pNext, x+1);
break;
}
case RE_OP_NOTDIGIT: {
if( !re_digit_char(c) ) re_add_state(pNext, x+1);
break;
}
case RE_OP_SPACE: {
if( re_space_char(c) ) re_add_state(pNext, x+1);
break;
}
case RE_OP_NOTSPACE: {
if( !re_space_char(c) ) re_add_state(pNext, x+1);
break;
}
case RE_OP_BOUNDARY: {
if( re_word_char(c)!=re_word_char(cPrev) ) re_add_state(pThis, x+1);
break;
}
case RE_OP_ANYSTAR: {
re_add_state(pNext, x);
re_add_state(pThis, x+1);
break;
}
case RE_OP_FORK: {
re_add_state(pThis, x+pRe->aArg[x]);
re_add_state(pThis, x+1);
break;
}
case RE_OP_GOTO: {
re_add_state(pThis, x+pRe->aArg[x]);
break;
}
case RE_OP_ACCEPT: {
rc = 1;
goto re_match_end;
}
case RE_OP_CC_INC:
case RE_OP_CC_EXC: {
int j = 1;
int n = pRe->aArg[x];
int hit = 0;
for(j=1; j>0 && j<n; j++){
if( pRe->aOp[x+j]==RE_OP_CC_VALUE ){
if( pRe->aArg[x+j]==c ){
hit = 1;
j = -1;
}
}else{
if( pRe->aArg[x+j]<=c && pRe->aArg[x+j+1]>=c ){
hit = 1;
j = -1;
}else{
j++;
}
}
}
if( pRe->aOp[x]==RE_OP_CC_EXC ) hit = !hit;
if( hit ) re_add_state(pNext, x+n);
break;
}
}
}
}
for(i=0; i<pNext->nState; i++){
if( pRe->aOp[pNext->aState[i]]==RE_OP_ACCEPT ){ rc = 1; break; }
}
re_match_end:
sqlite3_free(pToFree);
return rc;
}
/* Resize the opcode and argument arrays for an RE under construction.
*/
static int re_resize(ReCompiled *p, int N){
char *aOp;
int *aArg;
aOp = sqlite3_realloc64(p->aOp, N*sizeof(p->aOp[0]));
if( aOp==0 ) return 1;
p->aOp = aOp;
aArg = sqlite3_realloc64(p->aArg, N*sizeof(p->aArg[0]));
if( aArg==0 ) return 1;
p->aArg = aArg;
p->nAlloc = N;
return 0;
}
/* Insert a new opcode and argument into an RE under construction. The
** insertion point is just prior to existing opcode iBefore.
*/
static int re_insert(ReCompiled *p, int iBefore, int op, int arg){
int i;
if( p->nAlloc<=p->nState && re_resize(p, p->nAlloc*2) ) return 0;
for(i=p->nState; i>iBefore; i--){
p->aOp[i] = p->aOp[i-1];
p->aArg[i] = p->aArg[i-1];
}
p->nState++;
p->aOp[iBefore] = (char)op;
p->aArg[iBefore] = arg;
return iBefore;
}
/* Append a new opcode and argument to the end of the RE under construction.
*/
static int re_append(ReCompiled *p, int op, int arg){
return re_insert(p, p->nState, op, arg);
}
/* Make a copy of N opcodes starting at iStart onto the end of the RE
** under construction.
*/
static void re_copy(ReCompiled *p, int iStart, int N){
if( p->nState+N>=p->nAlloc && re_resize(p, p->nAlloc*2+N) ) return;
memcpy(&p->aOp[p->nState], &p->aOp[iStart], N*sizeof(p->aOp[0]));
memcpy(&p->aArg[p->nState], &p->aArg[iStart], N*sizeof(p->aArg[0]));
p->nState += N;
}
/* Return true if c is a hexadecimal digit character: [0-9a-fA-F]
** If c is a hex digit, also set *pV = (*pV)*16 + valueof(c). If
** c is not a hex digit *pV is unchanged.
*/
static int re_hex(int c, int *pV){
if( c>='0' && c<='9' ){
c -= '0';
}else if( c>='a' && c<='f' ){
c -= 'a' - 10;
}else if( c>='A' && c<='F' ){
c -= 'A' - 10;
}else{
return 0;
}
*pV = (*pV)*16 + (c & 0xff);
return 1;
}
/* A backslash character has been seen, read the next character and
** return its interpretation.
*/
static unsigned re_esc_char(ReCompiled *p){
static const char zEsc[] = "afnrtv\\()*.+?[$^{|}]";
static const char zTrans[] = "\a\f\n\r\t\v";
int i, v = 0;
char c;
if( p->sIn.i>=p->sIn.mx ) return 0;
c = p->sIn.z[p->sIn.i];
if( c=='u' && p->sIn.i+4<p->sIn.mx ){
const unsigned char *zIn = p->sIn.z + p->sIn.i;
if( re_hex(zIn[1],&v)
&& re_hex(zIn[2],&v)
&& re_hex(zIn[3],&v)
&& re_hex(zIn[4],&v)
){
p->sIn.i += 5;
return v;
}
}
if( c=='x' && p->sIn.i+2<p->sIn.mx ){
const unsigned char *zIn = p->sIn.z + p->sIn.i;
if( re_hex(zIn[1],&v)
&& re_hex(zIn[2],&v)
){
p->sIn.i += 3;
return v;
}
}
for(i=0; zEsc[i] && zEsc[i]!=c; i++){}
if( zEsc[i] ){
if( i<6 ) c = zTrans[i];
p->sIn.i++;
}else{
p->zErr = "unknown \\ escape";
}
return c;
}
/* Forward declaration */
static const char *re_subcompile_string(ReCompiled*);
/* Peek at the next byte of input */
static unsigned char rePeek(ReCompiled *p){
return p->sIn.i<p->sIn.mx ? p->sIn.z[p->sIn.i] : 0;
}
/* Compile RE text into a sequence of opcodes. Continue up to the
** first unmatched ")" character, then return. If an error is found,
** return a pointer to the error message string.
*/
static const char *re_subcompile_re(ReCompiled *p){
const char *zErr;
int iStart, iEnd, iGoto;
iStart = p->nState;
zErr = re_subcompile_string(p);
if( zErr ) return zErr;
while( rePeek(p)=='|' ){
iEnd = p->nState;
re_insert(p, iStart, RE_OP_FORK, iEnd + 2 - iStart);
iGoto = re_append(p, RE_OP_GOTO, 0);
p->sIn.i++;
zErr = re_subcompile_string(p);
if( zErr ) return zErr;
p->aArg[iGoto] = p->nState - iGoto;
}
return 0;
}
/* Compile an element of regular expression text (anything that can be
** an operand to the "|" operator). Return NULL on success or a pointer
** to the error message if there is a problem.
*/
static const char *re_subcompile_string(ReCompiled *p){
int iPrev = -1;
int iStart;
unsigned c;
const char *zErr;
while( (c = p->xNextChar(&p->sIn))!=0 ){
iStart = p->nState;
switch( c ){
case '|':
case '$':
case ')': {
p->sIn.i--;
return 0;
}
case '(': {
zErr = re_subcompile_re(p);
if( zErr ) return zErr;
if( rePeek(p)!=')' ) return "unmatched '('";
p->sIn.i++;
break;
}
case '.': {
if( rePeek(p)=='*' ){
re_append(p, RE_OP_ANYSTAR, 0);
p->sIn.i++;
}else{
re_append(p, RE_OP_ANY, 0);
}
break;
}
case '*': {
if( iPrev<0 ) return "'*' without operand";
re_insert(p, iPrev, RE_OP_GOTO, p->nState - iPrev + 1);
re_append(p, RE_OP_FORK, iPrev - p->nState + 1);
break;
}
case '+': {
if( iPrev<0 ) return "'+' without operand";
re_append(p, RE_OP_FORK, iPrev - p->nState);
break;
}
case '?': {
if( iPrev<0 ) return "'?' without operand";
re_insert(p, iPrev, RE_OP_FORK, p->nState - iPrev+1);
break;
}
case '{': {
int m = 0, n = 0;
int sz, j;
if( iPrev<0 ) return "'{m,n}' without operand";
while( (c=rePeek(p))>='0' && c<='9' ){ m = m*10 + c - '0'; p->sIn.i++; }
n = m;
if( c==',' ){
p->sIn.i++;
n = 0;
while( (c=rePeek(p))>='0' && c<='9' ){ n = n*10 + c-'0'; p->sIn.i++; }
}
if( c!='}' ) return "unmatched '{'";
if( n>0 && n<m ) return "n less than m in '{m,n}'";
p->sIn.i++;
sz = p->nState - iPrev;
if( m==0 ){
if( n==0 ) return "both m and n are zero in '{m,n}'";
re_insert(p, iPrev, RE_OP_FORK, sz+1);
n--;
}else{
for(j=1; j<m; j++) re_copy(p, iPrev, sz);
}
for(j=m; j<n; j++){
re_append(p, RE_OP_FORK, sz+1);
re_copy(p, iPrev, sz);
}
if( n==0 && m>0 ){
re_append(p, RE_OP_FORK, -sz);
}
break;
}
case '[': {
int iFirst = p->nState;
if( rePeek(p)=='^' ){
re_append(p, RE_OP_CC_EXC, 0);
p->sIn.i++;
}else{
re_append(p, RE_OP_CC_INC, 0);
}
while( (c = p->xNextChar(&p->sIn))!=0 ){
if( c=='[' && rePeek(p)==':' ){
return "POSIX character classes not supported";
}
if( c=='\\' ) c = re_esc_char(p);
if( rePeek(p)=='-' ){
re_append(p, RE_OP_CC_RANGE, c);
p->sIn.i++;
c = p->xNextChar(&p->sIn);
if( c=='\\' ) c = re_esc_char(p);
re_append(p, RE_OP_CC_RANGE, c);
}else{
re_append(p, RE_OP_CC_VALUE, c);
}
if( rePeek(p)==']' ){ p->sIn.i++; break; }
}
if( c==0 ) return "unclosed '['";
p->aArg[iFirst] = p->nState - iFirst;
break;
}
case '\\': {
int specialOp = 0;
switch( rePeek(p) ){
case 'b': specialOp = RE_OP_BOUNDARY; break;
case 'd': specialOp = RE_OP_DIGIT; break;
case 'D': specialOp = RE_OP_NOTDIGIT; break;
case 's': specialOp = RE_OP_SPACE; break;
case 'S': specialOp = RE_OP_NOTSPACE; break;
case 'w': specialOp = RE_OP_WORD; break;
case 'W': specialOp = RE_OP_NOTWORD; break;
}
if( specialOp ){
p->sIn.i++;
re_append(p, specialOp, 0);
}else{
c = re_esc_char(p);
re_append(p, RE_OP_MATCH, c);
}
break;
}
default: {
re_append(p, RE_OP_MATCH, c);
break;
}
}
iPrev = iStart;
}
return 0;
}
/* Free and reclaim all the memory used by a previously compiled
** regular expression. Applications should invoke this routine once
** for every call to re_compile() to avoid memory leaks.
*/
static void re_free(ReCompiled *pRe){
if( pRe ){
sqlite3_free(pRe->aOp);
sqlite3_free(pRe->aArg);
sqlite3_free(pRe);
}
}
/*
** Compile a textual regular expression in zIn[] into a compiled regular
** expression suitable for us by re_match() and return a pointer to the
** compiled regular expression in *ppRe. Return NULL on success or an
** error message if something goes wrong.
*/
static const char *re_compile(ReCompiled **ppRe, const char *zIn, int noCase){
ReCompiled *pRe;
const char *zErr;
int i, j;
*ppRe = 0;
pRe = sqlite3_malloc( sizeof(*pRe) );
if( pRe==0 ){
return "out of memory";
}
memset(pRe, 0, sizeof(*pRe));
pRe->xNextChar = noCase ? re_next_char_nocase : re_next_char;
if( re_resize(pRe, 30) ){
re_free(pRe);
return "out of memory";
}
if( zIn[0]=='^' ){
zIn++;
}else{
re_append(pRe, RE_OP_ANYSTAR, 0);
}
pRe->sIn.z = (unsigned char*)zIn;
pRe->sIn.i = 0;
pRe->sIn.mx = (int)strlen(zIn);
zErr = re_subcompile_re(pRe);
if( zErr ){
re_free(pRe);
return zErr;
}
if( rePeek(pRe)=='$' && pRe->sIn.i+1>=pRe->sIn.mx ){
re_append(pRe, RE_OP_MATCH, RE_EOF);
re_append(pRe, RE_OP_ACCEPT, 0);
*ppRe = pRe;
}else if( pRe->sIn.i>=pRe->sIn.mx ){
re_append(pRe, RE_OP_ACCEPT, 0);
*ppRe = pRe;
}else{
re_free(pRe);
return "unrecognized character";
}
/* The following is a performance optimization. If the regex begins with
** ".*" (if the input regex lacks an initial "^") and afterwards there are
** one or more matching characters, enter those matching characters into
** zInit[]. The re_match() routine can then search ahead in the input
** string looking for the initial match without having to run the whole
** regex engine over the string. Do not worry able trying to match
** unicode characters beyond plane 0 - those are very rare and this is
** just an optimization. */
if( pRe->aOp[0]==RE_OP_ANYSTAR ){
for(j=0, i=1; j<sizeof(pRe->zInit)-2 && pRe->aOp[i]==RE_OP_MATCH; i++){
unsigned x = pRe->aArg[i];
if( x<=127 ){
pRe->zInit[j++] = (unsigned char)x;
}else if( x<=0xfff ){
pRe->zInit[j++] = (unsigned char)(0xc0 | (x>>6));
pRe->zInit[j++] = 0x80 | (x&0x3f);
}else if( x<=0xffff ){
pRe->zInit[j++] = (unsigned char)(0xd0 | (x>>12));
pRe->zInit[j++] = 0x80 | ((x>>6)&0x3f);
pRe->zInit[j++] = 0x80 | (x&0x3f);
}else{
break;
}
}
if( j>0 && pRe->zInit[j-1]==0 ) j--;
pRe->nInit = j;
}
return pRe->zErr;
}
/*
** Implementation of the regexp() SQL function. This function implements
** the build-in REGEXP operator. The first argument to the function is the
** pattern and the second argument is the string. So, the SQL statements:
**
** A REGEXP B
**
** is implemented as regexp(B,A).
*/
static void re_sql_func(
sqlite3_context *context,
int argc,
sqlite3_value **argv
){
ReCompiled *pRe; /* Compiled regular expression */
const char *zPattern; /* The regular expression */
const unsigned char *zStr;/* String being searched */
const char *zErr; /* Compile error message */
int setAux = 0; /* True to invoke sqlite3_set_auxdata() */
pRe = sqlite3_get_auxdata(context, 0);
if( pRe==0 ){
zPattern = (const char*)sqlite3_value_text(argv[0]);
if( zPattern==0 ) return;
zErr = re_compile(&pRe, zPattern, 0);
if( zErr ){
re_free(pRe);
sqlite3_result_error(context, zErr, -1);
return;
}
if( pRe==0 ){
sqlite3_result_error_nomem(context);
return;
}
setAux = 1;
}
zStr = (const unsigned char*)sqlite3_value_text(argv[1]);
if( zStr!=0 ){
sqlite3_result_int(context, re_match(pRe, zStr, -1));
}
if( setAux ){
sqlite3_set_auxdata(context, 0, pRe, (void(*)(void*))re_free);
}
}
/*
** Invoke this routine to register the regexp() function with the
** SQLite database connection.
*/
#ifdef _WIN32
__declspec(dllexport)
#endif
int sqlite3_regexp_init(
sqlite3 *db,
char **pzErrMsg,
const sqlite3_api_routines *pApi
){
int rc = SQLITE_OK;
SQLITE_EXTENSION_INIT2(pApi);
rc = sqlite3_create_function(db, "regexp", 2, SQLITE_UTF8, 0, re_sql_func, 0, 0);
return rc;
}

View File

@ -0,0 +1,16 @@
package regexp
// #ifndef USE_LIBSQLITE3
// #include <sqlite3-binding.h>
// #else
// #include <sqlite3.h>
// #endif
//
// // Extension function defined in regexp.c.
// extern int sqlite3_regexp_init(sqlite3*, char**, const sqlite3_api_routines*);
//
// // Use constructor to register extension function with sqlite.
// void __attribute__((constructor)) init(void) {
// sqlite3_auto_extension((void*) sqlite3_regexp_init);
// }
import "C"

View File

@ -3,12 +3,18 @@ package sqlite
import (
"context"
"database/sql"
"encoding/json"
"errors"
"fmt"
"io/ioutil"
"net/http"
"net/url"
"os"
"path/filepath"
"strings"
"time"
"github.com/dstotijn/hetty/pkg/proj"
"github.com/dstotijn/hetty/pkg/reqlog"
"github.com/dstotijn/hetty/pkg/scope"
@ -16,13 +22,18 @@ import (
sq "github.com/Masterminds/squirrel"
"github.com/jmoiron/sqlx"
// Register sqlite3 for use via database/sql.
// Register `sqlite3` driver.
_ "github.com/mattn/go-sqlite3"
// Register `regexp()` function.
_ "github.com/dstotijn/hetty/pkg/db/sqlite/regexp"
)
// Client implements reqlog.Repository.
type Client struct {
db *sqlx.DB
db *sqlx.DB
dbPath string
activeProject string
}
type httpRequestLogsQuery struct {
@ -32,16 +43,28 @@ type httpRequestLogsQuery struct {
joinResponse bool
}
// Open opens a database.
func (c *Client) Open(filename string) error {
func New(dbPath string) (*Client, error) {
if _, err := os.Stat(dbPath); os.IsNotExist(err) {
if err := os.MkdirAll(dbPath, 0755); err != nil {
return nil, fmt.Errorf("proj: could not create project directory: %v", err)
}
}
return &Client{
dbPath: dbPath,
}, nil
}
// OpenProject opens a project database.
func (c *Client) OpenProject(name string) error {
if c.db != nil {
return errors.New("sqlite: database already open")
return errors.New("sqlite: there is already a project open")
}
opts := make(url.Values)
opts.Set("_foreign_keys", "1")
dsn := fmt.Sprintf("file:%v?%v", filename, opts.Encode())
dbPath := filepath.Join(c.dbPath, name+".db")
dsn := fmt.Sprintf("file:%v?%v", dbPath, opts.Encode())
db, err := sqlx.Open("sqlite3", dsn)
if err != nil {
return fmt.Errorf("sqlite: could not open database: %v", err)
@ -56,10 +79,29 @@ func (c *Client) Open(filename string) error {
}
c.db = db
c.activeProject = name
return nil
}
func (c *Client) Projects() ([]proj.Project, error) {
files, err := ioutil.ReadDir(c.dbPath)
if err != nil {
return nil, fmt.Errorf("sqlite: could not read projects directory: %v", err)
}
projects := make([]proj.Project, len(files))
for i, file := range files {
projName := strings.TrimSuffix(file.Name(), ".db")
projects[i] = proj.Project{
Name: projName,
IsActive: c.activeProject == projName,
}
}
return projects, nil
}
func prepareSchema(db *sqlx.DB) error {
_, err := db.Exec(`CREATE TABLE IF NOT EXISTS http_requests (
id INTEGER PRIMARY KEY,
@ -97,6 +139,14 @@ func prepareSchema(db *sqlx.DB) error {
return fmt.Errorf("could not create http_headers table: %v", err)
}
_, err = db.Exec(`CREATE TABLE IF NOT EXISTS settings (
module TEXT PRIMARY KEY,
settings TEXT
)`)
if err != nil {
return fmt.Errorf("could not create settings table: %v", err)
}
return nil
}
@ -108,7 +158,18 @@ func (c *Client) Close() error {
if err := c.db.Close(); err != nil {
return fmt.Errorf("sqlite: could not close database: %v", err)
}
c.db = nil
c.activeProject = ""
return nil
}
func (c *Client) DeleteProject(name string) error {
if err := os.Remove(filepath.Join(c.dbPath, name+".db")); err != nil {
return fmt.Errorf("sqlite: could not remove database file: %v", err)
}
return nil
}
@ -136,11 +197,11 @@ var headerFieldToColumnMap = map[string]string{
func (c *Client) FindRequestLogs(
ctx context.Context,
opts reqlog.FindRequestsOptions,
filter reqlog.FindRequestsFilter,
scope *scope.Scope,
) (reqLogs []reqlog.Request, err error) {
if c.db == nil {
return nil, reqlog.ErrNoProject
return nil, proj.ErrNoProject
}
httpReqLogsQuery := parseHTTPRequestLogsQuery(ctx)
@ -153,12 +214,24 @@ func (c *Client) FindRequestLogs(
reqQuery = reqQuery.LeftJoin("http_responses res ON req.id = res.req_id")
}
sql, _, err := reqQuery.ToSql()
if filter.OnlyInScope && scope != nil {
var ruleExpr []sq.Sqlizer
for _, rule := range scope.Rules() {
if rule.URL != nil {
ruleExpr = append(ruleExpr, sq.Expr("req.url regexp ?", rule.URL.String()))
}
}
if len(ruleExpr) > 0 {
reqQuery = reqQuery.Where(sq.Or(ruleExpr))
}
}
sql, args, err := reqQuery.ToSql()
if err != nil {
return nil, fmt.Errorf("sqlite: could not parse query: %v", err)
}
rows, err := c.db.QueryxContext(ctx, sql, nil)
rows, err := c.db.QueryxContext(ctx, sql, args...)
if err != nil {
return nil, fmt.Errorf("sqlite: could not execute query: %v", err)
}
@ -186,7 +259,7 @@ func (c *Client) FindRequestLogs(
func (c *Client) FindRequestLogByID(ctx context.Context, id int64) (reqlog.Request, error) {
if c.db == nil {
return reqlog.Request{}, reqlog.ErrNoProject
return reqlog.Request{}, proj.ErrNoProject
}
httpReqLogsQuery := parseHTTPRequestLogsQuery(ctx)
@ -229,7 +302,7 @@ func (c *Client) AddRequestLog(
timestamp time.Time,
) (*reqlog.Request, error) {
if c.db == nil {
return nil, reqlog.ErrNoProject
return nil, proj.ErrNoProject
}
reqLog := &reqlog.Request{
@ -303,7 +376,7 @@ func (c *Client) AddResponseLog(
timestamp time.Time,
) (*reqlog.Response, error) {
if c.db == nil {
return nil, reqlog.ErrNoProject
return nil, proj.ErrNoProject
}
resLog := &reqlog.Response{
@ -376,6 +449,49 @@ func (c *Client) AddResponseLog(
return resLog, nil
}
func (c *Client) UpsertSettings(ctx context.Context, module string, settings interface{}) error {
if c.db == nil {
// TODO: Fix where `ErrNoProject` lives.
return proj.ErrNoProject
}
jsonSettings, err := json.Marshal(settings)
if err != nil {
return fmt.Errorf("sqlite: could not encode settings as JSON: %v", err)
}
_, err = c.db.ExecContext(ctx,
`INSERT INTO settings (module, settings) VALUES (?, ?)
ON CONFLICT(module) DO UPDATE SET settings = ?`, module, jsonSettings, jsonSettings)
if err != nil {
return fmt.Errorf("sqlite: could not insert scope settings: %v", err)
}
return nil
}
func (c *Client) FindSettingsByModule(ctx context.Context, module string, settings interface{}) error {
if c.db == nil {
return proj.ErrNoProject
}
var jsonSettings []byte
row := c.db.QueryRowContext(ctx, `SELECT settings FROM settings WHERE module = ?`, module)
err := row.Scan(&jsonSettings)
if err == sql.ErrNoRows {
return proj.ErrNoSettings
}
if err != nil {
return fmt.Errorf("sqlite: could not scan row: %v", err)
}
if err := json.Unmarshal(jsonSettings, &settings); err != nil {
return fmt.Errorf("sqlite: could not decode settings from JSON: %v", err)
}
return nil
}
func insertHeaders(ctx context.Context, stmt *sql.Stmt, id int64, headers http.Header) error {
for key, values := range headers {
for _, value := range values {

View File

@ -1,25 +1,24 @@
package proj
import (
"context"
"errors"
"fmt"
"io/ioutil"
"os"
"path/filepath"
"log"
"regexp"
"strings"
"github.com/dstotijn/hetty/pkg/db/sqlite"
"github.com/dstotijn/hetty/pkg/scope"
"sync"
)
type OnProjectOpenFn func(name string) error
type OnProjectCloseFn func(name string) error
// Service is used for managing projects.
type Service struct {
dbPath string
db *sqlite.Client
name string
Scope *scope.Scope
repo Repository
activeProject string
onProjectOpenFns []OnProjectOpenFn
onProjectCloseFns []OnProjectCloseFn
mu sync.RWMutex
}
type Project struct {
@ -29,33 +28,33 @@ type Project struct {
var (
ErrNoProject = errors.New("proj: no open project")
ErrNoSettings = errors.New("proj: settings not found")
ErrInvalidName = errors.New("proj: invalid name, must be alphanumeric or whitespace chars")
)
var nameRegexp = regexp.MustCompile(`^[\w\d\s]+$`)
// NewService returns a new Service.
func NewService(dbPath string) (*Service, error) {
// Create directory for DBs if it doesn't exist yet.
if _, err := os.Stat(dbPath); os.IsNotExist(err) {
if err := os.MkdirAll(dbPath, 0755); err != nil {
return nil, fmt.Errorf("proj: could not create project directory: %v", err)
}
}
func NewService(repo Repository) (*Service, error) {
return &Service{
dbPath: dbPath,
db: &sqlite.Client{},
Scope: scope.New(nil),
repo: repo,
}, nil
}
// Close closes the currently open project database (if there is one).
func (svc *Service) Close() error {
if err := svc.db.Close(); err != nil {
svc.mu.Lock()
defer svc.mu.Unlock()
closedProject := svc.activeProject
if err := svc.repo.Close(); err != nil {
return fmt.Errorf("proj: could not close project: %v", err)
}
svc.name = ""
svc.activeProject = ""
svc.emitProjectClosed(closedProject)
return nil
}
@ -64,41 +63,37 @@ func (svc *Service) Delete(name string) error {
if name == "" {
return errors.New("proj: name cannot be empty")
}
if svc.name == name {
if svc.activeProject == name {
return fmt.Errorf("proj: project (%v) is active", name)
}
if err := os.Remove(filepath.Join(svc.dbPath, name+".db")); err != nil {
return fmt.Errorf("proj: could not remove database file: %v", err)
if err := svc.repo.DeleteProject(name); err != nil {
return fmt.Errorf("proj: could not delete project: %v", err)
}
return nil
}
// Database returns the currently open database. If no database is open, it will
// return `nil`.
func (svc *Service) Database() *sqlite.Client {
return svc.db
}
// Open opens a database identified with `name`. If a database with this
// identifier doesn't exist yet, it will be automatically created.
func (svc *Service) Open(name string) (Project, error) {
func (svc *Service) Open(ctx context.Context, name string) (Project, error) {
if !nameRegexp.MatchString(name) {
return Project{}, ErrInvalidName
}
if err := svc.db.Close(); err != nil {
svc.mu.Lock()
defer svc.mu.Unlock()
if err := svc.repo.Close(); err != nil {
return Project{}, fmt.Errorf("proj: could not close previously open database: %v", err)
}
dbPath := filepath.Join(svc.dbPath, name+".db")
err := svc.db.Open(dbPath)
if err != nil {
if err := svc.repo.OpenProject(name); err != nil {
return Project{}, fmt.Errorf("proj: could not open database: %v", err)
}
svc.name = name
svc.activeProject = name
svc.emitProjectOpened()
return Project{
Name: name,
@ -107,29 +102,51 @@ func (svc *Service) Open(name string) (Project, error) {
}
func (svc *Service) ActiveProject() (Project, error) {
if !svc.db.IsOpen() {
activeProject := svc.activeProject
if activeProject == "" {
return Project{}, ErrNoProject
}
return Project{
Name: svc.name,
Name: activeProject,
}, nil
}
func (svc *Service) Projects() ([]Project, error) {
files, err := ioutil.ReadDir(svc.dbPath)
projects, err := svc.repo.Projects()
if err != nil {
return nil, fmt.Errorf("proj: could not read projects directory: %v", err)
}
projects := make([]Project, len(files))
for i, file := range files {
projName := strings.TrimSuffix(file.Name(), ".db")
projects[i] = Project{
Name: projName,
IsActive: svc.name == projName,
}
return nil, fmt.Errorf("proj: could not get projects: %v", err)
}
return projects, nil
}
func (svc *Service) OnProjectOpen(fn OnProjectOpenFn) {
svc.mu.Lock()
defer svc.mu.Unlock()
svc.onProjectOpenFns = append(svc.onProjectOpenFns, fn)
}
func (svc *Service) OnProjectClose(fn OnProjectCloseFn) {
svc.mu.Lock()
defer svc.mu.Unlock()
svc.onProjectCloseFns = append(svc.onProjectCloseFns, fn)
}
func (svc *Service) emitProjectOpened() {
for _, fn := range svc.onProjectOpenFns {
if err := fn(svc.activeProject); err != nil {
log.Printf("[ERROR] Could not execute onProjectOpen function: %v", err)
}
}
}
func (svc *Service) emitProjectClosed(name string) {
for _, fn := range svc.onProjectCloseFns {
if err := fn(name); err != nil {
log.Printf("[ERROR] Could not execute onProjectClose function: %v", err)
}
}
}

14
pkg/proj/repo.go Normal file
View File

@ -0,0 +1,14 @@
package proj
import (
"context"
)
type Repository interface {
UpsertSettings(ctx context.Context, module string, settings interface{}) error
FindSettingsByModule(ctx context.Context, module string, settings interface{}) error
OpenProject(name string) error
DeleteProject(name string) error
Projects() ([]Project, error)
Close() error
}

View File

@ -13,8 +13,10 @@ type RepositoryProvider interface {
}
type Repository interface {
FindRequestLogs(ctx context.Context, opts FindRequestsOptions, scope *scope.Scope) ([]Request, error)
FindRequestLogs(ctx context.Context, filter FindRequestsFilter, scope *scope.Scope) ([]Request, error)
FindRequestLogByID(ctx context.Context, id int64) (Request, error)
AddRequestLog(ctx context.Context, req http.Request, body []byte, timestamp time.Time) (*Request, error)
AddResponseLog(ctx context.Context, reqID int64, res http.Response, body []byte, timestamp time.Time) (*Response, error)
UpsertSettings(ctx context.Context, module string, settings interface{}) error
FindSettingsByModule(ctx context.Context, module string, settings interface{}) error
}

View File

@ -11,6 +11,7 @@ import (
"net/http"
"time"
"github.com/dstotijn/hetty/pkg/proj"
"github.com/dstotijn/hetty/pkg/proxy"
"github.com/dstotijn/hetty/pkg/scope"
)
@ -19,9 +20,10 @@ type contextKey int
const LogBypassedKey contextKey = 0
const moduleName = "reqlog"
var (
ErrRequestNotFound = errors.New("reqlog: request not found")
ErrNoProject = errors.New("reqlog: no project")
)
type Request struct {
@ -42,42 +44,61 @@ type Response struct {
type Service struct {
BypassOutOfScopeRequests bool
FindReqsFilter FindRequestsFilter
scope *scope.Scope
repo Repository
}
type FindRequestsOptions struct {
OmitOutOfScope bool
type FindRequestsFilter struct {
OnlyInScope bool
}
type Config struct {
Scope *scope.Scope
Repository Repository
ProjectService *proj.Service
BypassOutOfScopeRequests bool
}
func NewService(cfg Config) *Service {
return &Service{
svc := &Service{
scope: cfg.Scope,
repo: cfg.Repository,
BypassOutOfScopeRequests: cfg.BypassOutOfScopeRequests,
}
cfg.ProjectService.OnProjectOpen(func(_ string) error {
err := svc.loadSettings()
if err == proj.ErrNoSettings {
return nil
}
if err != nil {
return fmt.Errorf("reqlog: could not load settings: %v", err)
}
return nil
})
cfg.ProjectService.OnProjectClose(func(_ string) error {
svc.unloadSettings()
return nil
})
return svc
}
func (svc *Service) FindRequests(ctx context.Context, opts FindRequestsOptions) ([]Request, error) {
var scope *scope.Scope
if opts.OmitOutOfScope {
scope = svc.scope
}
return svc.repo.FindRequestLogs(ctx, opts, scope)
func (svc *Service) FindRequests(ctx context.Context) ([]Request, error) {
return svc.repo.FindRequestLogs(ctx, svc.FindReqsFilter, svc.scope)
}
func (svc *Service) FindRequestLogByID(ctx context.Context, id int64) (Request, error) {
return svc.repo.FindRequestLogByID(ctx, id)
}
func (svc *Service) SetRequestLogFilter(ctx context.Context, filter FindRequestsFilter) error {
svc.FindReqsFilter = filter
return svc.repo.UpsertSettings(ctx, "reqlog", svc)
}
func (svc *Service) addRequest(
ctx context.Context,
req http.Request,
@ -136,7 +157,7 @@ func (svc *Service) RequestModifier(next proxy.RequestModifyFunc) proxy.RequestM
}
reqLog, err := svc.addRequest(req.Context(), *clone, body, now)
if err == ErrNoProject {
if err == proj.ErrNoProject {
ctx := context.WithValue(req.Context(), LogBypassedKey, true)
*req = *req.WithContext(ctx)
return
@ -184,3 +205,12 @@ func (svc *Service) ResponseModifier(next proxy.ResponseModifyFunc) proxy.Respon
return nil
}
}
func (svc *Service) loadSettings() error {
return svc.repo.FindSettingsByModule(context.Background(), moduleName, svc)
}
func (svc *Service) unloadSettings() {
svc.BypassOutOfScopeRequests = false
svc.FindReqsFilter = FindRequestsFilter{}
}

8
pkg/scope/repo.go Normal file
View File

@ -0,0 +1,8 @@
package scope
import "context"
type Repository interface {
UpsertSettings(ctx context.Context, module string, settings interface{}) error
FindSettingsByModule(ctx context.Context, module string, settings interface{}) error
}

View File

@ -1,14 +1,23 @@
package scope
import (
"context"
"encoding/json"
"fmt"
"net/http"
"regexp"
"sync"
"github.com/dstotijn/hetty/pkg/proj"
)
const moduleName = "scope"
type Scope struct {
mu sync.RWMutex
rules []Rule
repo Repository
mu sync.RWMutex
}
type Rule struct {
@ -22,11 +31,26 @@ type Header struct {
Value *regexp.Regexp
}
func New(rules []Rule) *Scope {
s := &Scope{}
if rules != nil {
s.rules = rules
func New(repo Repository, projService *proj.Service) *Scope {
s := &Scope{
repo: repo,
}
projService.OnProjectOpen(func(_ string) error {
err := s.load(context.Background())
if err == proj.ErrNoSettings {
return nil
}
if err != nil {
return fmt.Errorf("scope: could not load scope: %v", err)
}
return nil
})
projService.OnProjectClose(func(_ string) error {
s.unload()
return nil
})
return s
}
@ -36,11 +60,41 @@ func (s *Scope) Rules() []Rule {
return s.rules
}
func (s *Scope) SetRules(rules []Rule) {
func (s *Scope) load(ctx context.Context) error {
s.mu.Lock()
defer s.mu.Unlock()
var rules []Rule
err := s.repo.FindSettingsByModule(ctx, moduleName, &rules)
if err == proj.ErrNoSettings {
return err
}
if err != nil {
return fmt.Errorf("scope: could not load scope settings: %v", err)
}
s.rules = rules
return nil
}
func (s *Scope) unload() {
s.mu.Lock()
defer s.mu.Unlock()
s.rules = nil
}
func (s *Scope) SetRules(ctx context.Context, rules []Rule) error {
s.mu.Lock()
defer s.mu.Unlock()
if err := s.repo.UpsertSettings(ctx, moduleName, rules); err != nil {
return fmt.Errorf("scope: cannot set rules in repository: %v", err)
}
s.rules = rules
return nil
}
func (s *Scope) Match(req *http.Request, body []byte) bool {
@ -97,3 +151,87 @@ func (r Rule) Match(req *http.Request, body []byte) bool {
return false
}
// MarshalJSON implements json.Marshaler.
func (r Rule) MarshalJSON() ([]byte, error) {
type headerDTO struct {
Key string
Value string
}
type ruleDTO struct {
URL string
Header headerDTO
Body string
}
dto := ruleDTO{
URL: regexpToString(r.URL),
Header: headerDTO{
Key: regexpToString(r.Header.Key),
Value: regexpToString(r.Header.Value),
},
Body: regexpToString(r.Body),
}
return json.Marshal(dto)
}
// UnmarshalJSON implements json.Unmarshaler.
func (r *Rule) UnmarshalJSON(data []byte) error {
type headerDTO struct {
Key string
Value string
}
type ruleDTO struct {
URL string
Header headerDTO
Body string
}
var dto ruleDTO
if err := json.Unmarshal(data, &dto); err != nil {
return err
}
url, err := stringToRegexp(dto.URL)
if err != nil {
return err
}
headerKey, err := stringToRegexp(dto.Header.Key)
if err != nil {
return err
}
headerValue, err := stringToRegexp(dto.Header.Value)
if err != nil {
return err
}
body, err := stringToRegexp(dto.Body)
if err != nil {
return err
}
*r = Rule{
URL: url,
Header: Header{
Key: headerKey,
Value: headerValue,
},
Body: body,
}
return nil
}
func regexpToString(r *regexp.Regexp) string {
if r == nil {
return ""
}
return r.String()
}
func stringToRegexp(s string) (*regexp.Regexp, error) {
if s == "" {
return nil, nil
}
return regexp.Compile(s)
}