diff --git a/.gitignore b/.gitignore index bc14fc2..7369c3d 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ **/rice-box.go +sqlite3_mod_regexp.dylib dist hetty hetty.bolt diff --git a/.goreleaser.yml b/.goreleaser.yml index a32bb10..2b6c40e 100644 --- a/.goreleaser.yml +++ b/.goreleaser.yml @@ -1,6 +1,7 @@ env: - GO111MODULE=on - CGO_ENABLED=1 + - CGO_CFLAGS=-I/go/pkg/mod/github.com/mattn/go-sqlite3@v1.14.4 before: hooks: - make clean @@ -15,6 +16,7 @@ builds: env: - CC=o64-clang - CXX=o64-clang++ + - CGO_LDFLAGS=-Wl,-undefined,dynamic_lookup flags: - -mod=readonly ldflags: @@ -24,22 +26,27 @@ builds: - amd64 goos: - linux - flags: - - -mod=readonly - ldflags: - - id: hetty-windows-amd64 - main: ./cmd/hetty - goarch: - - amd64 - goos: - - windows env: - - CC=x86_64-w64-mingw32-gcc - - CXX=x86_64-w64-mingw32-g++ + - CGO_CFLAGS=-I/go/pkg/mod/github.com/mattn/go-sqlite3@v1.14.4 + - CGO_LDFLAGS=-Wl,--unresolved-symbols=ignore-in-object-files flags: - -mod=readonly ldflags: - - -buildmode=exe + # - id: hetty-windows-amd64 + # main: ./cmd/hetty + # goarch: + # - amd64 + # goos: + # - windows + # env: + # - CC=x86_64-w64-mingw32-gcc + # - CXX=x86_64-w64-mingw32-g++ + # - CGO_CFLAGS=-I/go/pkg/mod/github.com/mattn/go-sqlite3@v1.14.4 + # - CGO_LDFLAGS=-Wl,--unresolved-symbols=ignore-in-object-files # Not working :( + # flags: + # - -mod=readonly + # ldflags: + # - -buildmode=exe archives: - replacements: darwin: macOS diff --git a/Dockerfile b/Dockerfile index 83a7729..bbab5cf 100644 --- a/Dockerfile +++ b/Dockerfile @@ -2,13 +2,16 @@ ARG GO_VERSION=1.15 ARG CGO_ENABLED=1 ARG NODE_VERSION=14.11 -FROM golang:${GO_VERSION}-alpine AS go-builder +FROM golang:${GO_VERSION} AS go-builder WORKDIR /app -RUN apk add --no-cache build-base +RUN apt-get update && \ + apt-get install -y build-essential COPY go.mod go.sum ./ RUN go mod download COPY cmd ./cmd COPY pkg ./pkg +ENV CGO_CFLAGS=-I/go/pkg/mod/github.com/mattn/go-sqlite3@v1.14.4 +ENV CGO_LDFLAGS=-Wl,--unresolved-symbols=ignore-in-object-files RUN go build -o hetty ./cmd/hetty FROM node:${NODE_VERSION}-alpine AS node-builder @@ -19,7 +22,7 @@ COPY admin/ . ENV NEXT_TELEMETRY_DISABLED=1 RUN yarn run export -FROM alpine:3.12 +FROM debian:buster-slim WORKDIR /app COPY --from=go-builder /app/hetty . COPY --from=node-builder /app/dist admin diff --git a/Makefile b/Makefile index 21e0c44..980ff0f 100644 --- a/Makefile +++ b/Makefile @@ -12,7 +12,8 @@ embed: .PHONY: embed build: embed - env CGO_ENABLED=1 go build ./cmd/hetty + env CGO_ENABLED=1 CGO_CFLAGS="-DUSE_LIBSQLITE3" CGO_LDFLAGS="-Wl,-undefined,dynamic_lookup" \ + go build -tags libsqlite3 ./cmd/hetty .PHONY: build clean: diff --git a/README.md b/README.md index 00a5f00..21f69c8 100644 --- a/README.md +++ b/README.md @@ -10,7 +10,7 @@ - [x] HTTP man-in-the-middle (MITM) proxy and GraphQL server. - [x] Web interface (Next.js) with proxy log viewer. -- [ ] Add scope support to the proxy. +- [x] Add scope support to the proxy. - [ ] Full text search (with regex) in proxy log viewer. - [x] Project management. - [ ] Sender module for sending manual HTTP requests, either from scratch or based diff --git a/admin/package.json b/admin/package.json index 5c324c9..6f20d56 100644 --- a/admin/package.json +++ b/admin/package.json @@ -6,7 +6,7 @@ "dev": "next dev", "build": "next build", "start": "next start", - "export": "next build && next export -o dist" + "export": "rm -rf .next && next build && next export -o dist" }, "dependencies": { "@apollo/client": "^3.2.0", diff --git a/admin/src/components/Layout.tsx b/admin/src/components/Layout.tsx index 2a32a2a..4d037d5 100644 --- a/admin/src/components/Layout.tsx +++ b/admin/src/components/Layout.tsx @@ -22,6 +22,7 @@ import HomeIcon from "@material-ui/icons/Home"; import SettingsEthernetIcon from "@material-ui/icons/SettingsEthernet"; import SendIcon from "@material-ui/icons/Send"; import FolderIcon from "@material-ui/icons/Folder"; +import LocationSearchingIcon from "@material-ui/icons/LocationSearching"; import ChevronLeftIcon from "@material-ui/icons/ChevronLeft"; import ChevronRightIcon from "@material-ui/icons/ChevronRight"; import clsx from "clsx"; @@ -33,6 +34,7 @@ export enum Page { ProxySetup, ProxyLogs, Sender, + Scope, } const drawerWidth = 240; @@ -236,6 +238,22 @@ export function Layout({ title, page, children }: Props): JSX.Element { + + + + + + + + + + + createStyles({ + root: { + padding: "2px 4px", + display: "flex", + alignItems: "center", + width: 400, + }, + input: { + marginLeft: theme.spacing(1), + flex: 1, + }, + iconButton: { + padding: 10, + }, + filterPopper: { + width: 400, + marginTop: 6, + zIndex: 99, + }, + filterOptions: { + padding: theme.spacing(2), + }, + filterLoading: { + marginRight: 1, + color: theme.palette.text.primary, + }, + }) +); + +export interface SearchFilter { + onlyInScope: boolean; +} + +function Search(): JSX.Element { + const classes = useStyles(); + const theme = useTheme(); + + const { loading: filterLoading, error: filterErr, data: filter } = useQuery( + FILTER + ); + + const client = useApolloClient(); + const [ + setFilterMutate, + { error: setFilterErr, loading: setFilterLoading }, + ] = useMutation<{ + setHttpRequestLogFilter: SearchFilter | null; + }>(SET_FILTER, { + update(_, { data: { setHttpRequestLogFilter } }) { + client.writeQuery({ + query: FILTER, + data: { + httpRequestLogFilter: setHttpRequestLogFilter, + }, + }); + }, + }); + + const filterRef = useRef(); + const [filterOpen, setFilterOpen] = useState(false); + + const handleSubmit = (e: React.SyntheticEvent) => { + e.preventDefault(); + }; + + const handleClickAway = (event: React.MouseEvent) => { + if (filterRef.current.contains(event.target as HTMLElement)) { + return; + } + setFilterOpen(false); + }; + + return ( + + + {filterErr && ( + + + Error fetching filter: {filterErr.message} + + + )} + {setFilterErr && ( + + + Error setting filter: {setFilterErr.message} + + + )} + + + setFilterOpen(!filterOpen)} + style={{ + color: + filter?.httpRequestLogFilter !== null + ? theme.palette.secondary.main + : "inherit", + }} + > + {filterLoading || setFilterLoading ? ( + + ) : ( + + )} + + + setFilterOpen(true)} + /> + + + + + + + + + + + setFilterMutate({ + variables: { + filter: { + ...withoutTypename(filter?.httpRequestLogFilter), + onlyInScope: e.target.checked, + }, + }, + }) + } + /> + } + label="Only show in-scope requests" + /> + + + + + ); +} + +export default Search; diff --git a/admin/src/components/scope/AddRule.tsx b/admin/src/components/scope/AddRule.tsx new file mode 100644 index 0000000..b27e90a --- /dev/null +++ b/admin/src/components/scope/AddRule.tsx @@ -0,0 +1,140 @@ +import { gql, useApolloClient, useMutation } from "@apollo/client"; +import { + Box, + Button, + CircularProgress, + createStyles, + FormControl, + FormControlLabel, + FormLabel, + makeStyles, + Radio, + RadioGroup, + TextField, + Theme, +} from "@material-ui/core"; +import AddIcon from "@material-ui/icons/Add"; +import { Alert } from "@material-ui/lab"; +import React from "react"; +import { SCOPE } from "./Rules"; + +const SET_SCOPE = gql` + mutation SetScope($scope: [ScopeRuleInput!]!) { + setScope(scope: $scope) { + url + } + } +`; + +const useStyles = makeStyles((theme: Theme) => + createStyles({ + ruleExpression: { + fontFamily: "'JetBrains Mono', monospace", + }, + }) +); + +function AddRule(): JSX.Element { + const classes = useStyles(); + + const [ruleType, setRuleType] = React.useState("url"); + const [expression, setExpression] = React.useState(null); + + const client = useApolloClient(); + const [setScope, { error, loading }] = useMutation(SET_SCOPE, { + onError() {}, + onCompleted() { + expression.value = ""; + }, + update(_, { data: { setScope } }) { + client.writeQuery({ + query: SCOPE, + data: { scope: setScope }, + }); + }, + }); + + const handleTypeChange = (e: React.ChangeEvent, value: string) => { + setRuleType(value); + }; + const handleSubmit = (e: React.SyntheticEvent) => { + e.preventDefault(); + let scope = []; + + try { + const data = client.readQuery({ + query: SCOPE, + }); + scope = data.scope; + } catch (e) {} + + setScope({ + variables: { + scope: [ + ...scope.map(({ url }) => ({ url })), + { url: expression.value }, + ], + }, + }); + }; + + return ( +
+ {error && ( + + Error adding rule: {error.message} + + )} +
+ + + Rule Type + + + } label="URL" /> + + + + { + setExpression(node); + }, + }} + margin="normal" + /> + + + + +
+
+ ); +} + +export default AddRule; diff --git a/admin/src/components/scope/RuleListItem.tsx b/admin/src/components/scope/RuleListItem.tsx new file mode 100644 index 0000000..f22d6d6 --- /dev/null +++ b/admin/src/components/scope/RuleListItem.tsx @@ -0,0 +1,86 @@ +import { gql, useApolloClient, useMutation, useQuery } from "@apollo/client"; +import { + Avatar, + Chip, + IconButton, + ListItem, + ListItemAvatar, + ListItemSecondaryAction, + ListItemText, + Tooltip, +} from "@material-ui/core"; +import CodeIcon from "@material-ui/icons/Code"; +import DeleteIcon from "@material-ui/icons/Delete"; +import React from "react"; +import { SCOPE } from "./Rules"; + +const SET_SCOPE = gql` + mutation SetScope($scope: [ScopeRuleInput!]!) { + setScope(scope: $scope) { + url + } + } +`; + +function RuleListItem({ scope, rule, index }): JSX.Element { + const client = useApolloClient(); + const [setScope, { loading }] = useMutation(SET_SCOPE, { + update(_, { data: { setScope } }) { + client.writeQuery({ + query: SCOPE, + data: { scope: setScope }, + }); + }, + }); + + const handleDelete = (index: number) => { + const clone = [...scope]; + clone.splice(index, 1); + setScope({ + variables: { + scope: clone.map(({ url }) => ({ url })), + }, + }); + }; + + return ( + + + + + + + + + + + + handleDelete(index)} disabled={loading}> + + + + + + + ); +} + +function RuleListItemText({ rule }): JSX.Element { + let text: JSX.Element; + + if (rule.url) { + text = {rule.url}; + } + + // TODO: Parse and handle rule.header and rule.body. + + return {text}; +} + +function RuleTypeChip({ rule }): JSX.Element { + if (rule.url) { + return ; + } +} + +export default RuleListItem; diff --git a/admin/src/components/scope/Rules.tsx b/admin/src/components/scope/Rules.tsx new file mode 100644 index 0000000..684a524 --- /dev/null +++ b/admin/src/components/scope/Rules.tsx @@ -0,0 +1,55 @@ +import { gql, useQuery } from "@apollo/client"; +import { + CircularProgress, + createStyles, + List, + makeStyles, + Theme, +} from "@material-ui/core"; +import { Alert } from "@material-ui/lab"; +import React from "react"; +import RuleListItem from "./RuleListItem"; + +const useStyles = makeStyles((theme: Theme) => + createStyles({ + rulesList: { + backgroundColor: theme.palette.background.paper, + }, + }) +); + +export const SCOPE = gql` + query Scope { + scope { + url + } + } +`; + +function Rules(): JSX.Element { + const classes = useStyles(); + const { loading, error, data } = useQuery(SCOPE); + + return ( +
+ {loading && } + {error && ( + Error fetching scope: {error.message} + )} + {data?.scope.length > 0 && ( + + {data.scope.map((rule, index) => ( + + ))} + + )} +
+ ); +} + +export default Rules; diff --git a/admin/src/lib/omitTypename.ts b/admin/src/lib/omitTypename.ts new file mode 100644 index 0000000..51818bf --- /dev/null +++ b/admin/src/lib/omitTypename.ts @@ -0,0 +1,5 @@ +const omitTypename = (key, value) => (key === "__typename" ? undefined : value); + +export function withoutTypename(input: any): any { + return JSON.parse(JSON.stringify(input), omitTypename); +} diff --git a/admin/src/pages/proxy/logs/index.tsx b/admin/src/pages/proxy/logs/index.tsx index fb4c035..4ae7f11 100644 --- a/admin/src/pages/proxy/logs/index.tsx +++ b/admin/src/pages/proxy/logs/index.tsx @@ -1,9 +1,15 @@ +import { Box } from "@material-ui/core"; + import LogsOverview from "../../../components/reqlog/LogsOverview"; import Layout, { Page } from "../../../components/Layout"; +import Search from "../../../components/reqlog/Search"; function ProxyLogs(): JSX.Element { return ( + + + ); diff --git a/admin/src/pages/scope/index.tsx b/admin/src/pages/scope/index.tsx new file mode 100644 index 0000000..aa3cde2 --- /dev/null +++ b/admin/src/pages/scope/index.tsx @@ -0,0 +1,39 @@ +import { Box, Divider, Grid, Typography } from "@material-ui/core"; +import React from "react"; + +import Layout, { Page } from "../../components/Layout"; +import AddRule from "../../components/scope/AddRule"; +import Rules from "../../components/scope/Rules"; + +function Index(): JSX.Element { + return ( + + + + Scope + + + Scope rules are used by various modules in Hetty and can influence + their behavior. For example: the Proxy logs module can match incoming + requests against scope rules and decide its behavior (e.g. log or + bypass) based on the outcome of the match. All scope configuration is + stored per project. + + + + + + + + + + + + + + + + ); +} + +export default Index; diff --git a/cmd/hetty/main.go b/cmd/hetty/main.go index e9e8aab..f402cc9 100644 --- a/cmd/hetty/main.go +++ b/cmd/hetty/main.go @@ -11,9 +11,11 @@ import ( rice "github.com/GeertJohan/go.rice" "github.com/dstotijn/hetty/pkg/api" + "github.com/dstotijn/hetty/pkg/db/sqlite" "github.com/dstotijn/hetty/pkg/proj" "github.com/dstotijn/hetty/pkg/proxy" "github.com/dstotijn/hetty/pkg/reqlog" + "github.com/dstotijn/hetty/pkg/scope" "github.com/99designs/gqlgen/graphql/handler" "github.com/99designs/gqlgen/graphql/playground" @@ -58,15 +60,23 @@ func main() { log.Fatalf("[FATAL] Could not create/load CA key pair: %v", err) } - projService, err := proj.NewService(projPath) + db, err := sqlite.New(projPath) + if err != nil { + log.Fatalf("[FATAL] Could not initialize database client: %v", err) + } + + projService, err := proj.NewService(db) if err != nil { log.Fatalf("[FATAL] Could not create new project service: %v", err) } defer projService.Close() + scope := scope.New(db, projService) + reqLogService := reqlog.NewService(reqlog.Config{ - Scope: projService.Scope, - Repository: projService.Database(), + Scope: scope, + ProjectService: projService, + Repository: db, }) p, err := proxy.NewProxy(caCert, caKey) @@ -102,6 +112,7 @@ func main() { adminRouter.Path("/api/graphql/").Handler(handler.NewDefaultServer(api.NewExecutableSchema(api.Config{Resolvers: &api.Resolver{ RequestLogService: reqLogService, ProjectService: projService, + ScopeService: scope, }}))) // Admin interface. diff --git a/modd.conf b/modd.conf index d538c37..cd06c3d 100644 --- a/modd.conf +++ b/modd.conf @@ -1,3 +1,4 @@ +@env = CGO_CFLAGS=-DUSE_LIBSQLITE3 CGO_LDFLAGS=-Wl,-undefined,dynamic_lookup **/*.go { - daemon +sigterm: go run ./cmd/hetty + daemon +sigterm: @env go run -tags libsqlite3 ./cmd/hetty } \ No newline at end of file diff --git a/pkg/api/generated.go b/pkg/api/generated.go index 0450e93..f0abe64 100644 --- a/pkg/api/generated.go +++ b/pkg/api/generated.go @@ -67,6 +67,10 @@ type ComplexityRoot struct { URL func(childComplexity int) int } + HTTPRequestLogFilter struct { + OnlyInScope func(childComplexity int) int + } + HTTPResponseLog struct { Body func(childComplexity int) int Headers func(childComplexity int) int @@ -77,9 +81,11 @@ type ComplexityRoot struct { } Mutation struct { - CloseProject func(childComplexity int) int - DeleteProject func(childComplexity int, name string) int - OpenProject func(childComplexity int, name string) int + CloseProject func(childComplexity int) int + DeleteProject func(childComplexity int, name string) int + OpenProject func(childComplexity int, name string) int + SetHTTPRequestLogFilter func(childComplexity int, filter *HTTPRequestLogFilterInput) int + SetScope func(childComplexity int, scope []ScopeRuleInput) int } Project struct { @@ -88,10 +94,23 @@ type ComplexityRoot struct { } Query struct { - ActiveProject func(childComplexity int) int - HTTPRequestLog func(childComplexity int, id int64) int - HTTPRequestLogs func(childComplexity int) int - Projects func(childComplexity int) int + ActiveProject func(childComplexity int) int + HTTPRequestLog func(childComplexity int, id int64) int + HTTPRequestLogFilter func(childComplexity int) int + HTTPRequestLogs func(childComplexity int) int + Projects func(childComplexity int) int + Scope func(childComplexity int) int + } + + ScopeHeader struct { + Key func(childComplexity int) int + Value func(childComplexity int) int + } + + ScopeRule struct { + Body func(childComplexity int) int + Header func(childComplexity int) int + URL func(childComplexity int) int } } @@ -99,12 +118,16 @@ type MutationResolver interface { OpenProject(ctx context.Context, name string) (*Project, error) CloseProject(ctx context.Context) (*CloseProjectResult, error) DeleteProject(ctx context.Context, name string) (*DeleteProjectResult, error) + SetScope(ctx context.Context, scope []ScopeRuleInput) ([]ScopeRule, error) + SetHTTPRequestLogFilter(ctx context.Context, filter *HTTPRequestLogFilterInput) (*HTTPRequestLogFilter, error) } type QueryResolver interface { HTTPRequestLog(ctx context.Context, id int64) (*HTTPRequestLog, error) HTTPRequestLogs(ctx context.Context) ([]HTTPRequestLog, error) + HTTPRequestLogFilter(ctx context.Context) (*HTTPRequestLogFilter, error) ActiveProject(ctx context.Context) (*Project, error) Projects(ctx context.Context) ([]Project, error) + Scope(ctx context.Context) ([]ScopeRule, error) } type executableSchema struct { @@ -206,6 +229,13 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in return e.complexity.HTTPRequestLog.URL(childComplexity), true + case "HttpRequestLogFilter.onlyInScope": + if e.complexity.HTTPRequestLogFilter.OnlyInScope == nil { + break + } + + return e.complexity.HTTPRequestLogFilter.OnlyInScope(childComplexity), true + case "HttpResponseLog.body": if e.complexity.HTTPResponseLog.Body == nil { break @@ -279,6 +309,30 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in return e.complexity.Mutation.OpenProject(childComplexity, args["name"].(string)), true + case "Mutation.setHttpRequestLogFilter": + if e.complexity.Mutation.SetHTTPRequestLogFilter == nil { + break + } + + args, err := ec.field_Mutation_setHttpRequestLogFilter_args(context.TODO(), rawArgs) + if err != nil { + return 0, false + } + + return e.complexity.Mutation.SetHTTPRequestLogFilter(childComplexity, args["filter"].(*HTTPRequestLogFilterInput)), true + + case "Mutation.setScope": + if e.complexity.Mutation.SetScope == nil { + break + } + + args, err := ec.field_Mutation_setScope_args(context.TODO(), rawArgs) + if err != nil { + return 0, false + } + + return e.complexity.Mutation.SetScope(childComplexity, args["scope"].([]ScopeRuleInput)), true + case "Project.isActive": if e.complexity.Project.IsActive == nil { break @@ -312,6 +366,13 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in return e.complexity.Query.HTTPRequestLog(childComplexity, args["id"].(int64)), true + case "Query.httpRequestLogFilter": + if e.complexity.Query.HTTPRequestLogFilter == nil { + break + } + + return e.complexity.Query.HTTPRequestLogFilter(childComplexity), true + case "Query.httpRequestLogs": if e.complexity.Query.HTTPRequestLogs == nil { break @@ -326,6 +387,48 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in return e.complexity.Query.Projects(childComplexity), true + case "Query.scope": + if e.complexity.Query.Scope == nil { + break + } + + return e.complexity.Query.Scope(childComplexity), true + + case "ScopeHeader.key": + if e.complexity.ScopeHeader.Key == nil { + break + } + + return e.complexity.ScopeHeader.Key(childComplexity), true + + case "ScopeHeader.value": + if e.complexity.ScopeHeader.Value == nil { + break + } + + return e.complexity.ScopeHeader.Value(childComplexity), true + + case "ScopeRule.body": + if e.complexity.ScopeRule.Body == nil { + break + } + + return e.complexity.ScopeRule.Body(childComplexity), true + + case "ScopeRule.header": + if e.complexity.ScopeRule.Header == nil { + break + } + + return e.complexity.ScopeRule.Header(childComplexity), true + + case "ScopeRule.url": + if e.complexity.ScopeRule.URL == nil { + break + } + + return e.complexity.ScopeRule.URL(childComplexity), true + } return 0, false } @@ -420,6 +523,28 @@ type Project { isActive: Boolean! } +type ScopeRule { + url: Regexp + header: ScopeHeader + body: Regexp +} + +input ScopeRuleInput { + url: Regexp + header: ScopeHeaderInput + body: Regexp +} + +type ScopeHeader { + key: Regexp + value: Regexp +} + +input ScopeHeaderInput { + key: Regexp + value: Regexp +} + type CloseProjectResult { success: Boolean! } @@ -428,17 +553,31 @@ type DeleteProjectResult { success: Boolean! } +input HttpRequestLogFilterInput { + onlyInScope: Boolean +} + +type HttpRequestLogFilter { + onlyInScope: Boolean! +} + type Query { httpRequestLog(id: ID!): HttpRequestLog httpRequestLogs: [HttpRequestLog!]! + httpRequestLogFilter: HttpRequestLogFilter activeProject: Project projects: [Project!]! + scope: [ScopeRule!]! } type Mutation { openProject(name: String!): Project closeProject: CloseProjectResult! deleteProject(name: String!): DeleteProjectResult! + setScope(scope: [ScopeRuleInput!]!): [ScopeRule!]! + setHttpRequestLogFilter( + filter: HttpRequestLogFilterInput + ): HttpRequestLogFilter } enum HttpMethod { @@ -454,6 +593,7 @@ enum HttpMethod { } scalar Time +scalar Regexp `, BuiltIn: false}, } var parsedSchema = gqlparser.MustLoadSchema(sources...) @@ -492,6 +632,36 @@ func (ec *executionContext) field_Mutation_openProject_args(ctx context.Context, return args, nil } +func (ec *executionContext) field_Mutation_setHttpRequestLogFilter_args(ctx context.Context, rawArgs map[string]interface{}) (map[string]interface{}, error) { + var err error + args := map[string]interface{}{} + var arg0 *HTTPRequestLogFilterInput + if tmp, ok := rawArgs["filter"]; ok { + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("filter")) + arg0, err = ec.unmarshalOHttpRequestLogFilterInput2ᚖgithubᚗcomᚋdstotijnᚋhettyᚋpkgᚋapiᚐHTTPRequestLogFilterInput(ctx, tmp) + if err != nil { + return nil, err + } + } + args["filter"] = arg0 + return args, nil +} + +func (ec *executionContext) field_Mutation_setScope_args(ctx context.Context, rawArgs map[string]interface{}) (map[string]interface{}, error) { + var err error + args := map[string]interface{}{} + var arg0 []ScopeRuleInput + if tmp, ok := rawArgs["scope"]; ok { + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("scope")) + arg0, err = ec.unmarshalNScopeRuleInput2ᚕgithubᚗcomᚋdstotijnᚋhettyᚋpkgᚋapiᚐScopeRuleInputᚄ(ctx, tmp) + if err != nil { + return nil, err + } + } + args["scope"] = arg0 + return args, nil +} + func (ec *executionContext) field_Query___type_args(ctx context.Context, rawArgs map[string]interface{}) (map[string]interface{}, error) { var err error args := map[string]interface{}{} @@ -974,6 +1144,41 @@ func (ec *executionContext) _HttpRequestLog_response(ctx context.Context, field return ec.marshalOHttpResponseLog2ᚖgithubᚗcomᚋdstotijnᚋhettyᚋpkgᚋapiᚐHTTPResponseLog(ctx, field.Selections, res) } +func (ec *executionContext) _HttpRequestLogFilter_onlyInScope(ctx context.Context, field graphql.CollectedField, obj *HTTPRequestLogFilter) (ret graphql.Marshaler) { + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + ret = graphql.Null + } + }() + fc := &graphql.FieldContext{ + Object: "HttpRequestLogFilter", + Field: field, + Args: nil, + IsMethod: false, + IsResolver: false, + } + + ctx = graphql.WithFieldContext(ctx, fc) + resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return obj.OnlyInScope, nil + }) + if err != nil { + ec.Error(ctx, err) + return graphql.Null + } + if resTmp == nil { + if !graphql.HasFieldError(ctx, fc) { + ec.Errorf(ctx, "must not be null") + } + return graphql.Null + } + res := resTmp.(bool) + fc.Result = res + return ec.marshalNBoolean2bool(ctx, field.Selections, res) +} + func (ec *executionContext) _HttpResponseLog_requestId(ctx context.Context, field graphql.CollectedField, obj *HTTPResponseLog) (ret graphql.Marshaler) { defer func() { if r := recover(); r != nil { @@ -1297,6 +1502,87 @@ func (ec *executionContext) _Mutation_deleteProject(ctx context.Context, field g return ec.marshalNDeleteProjectResult2ᚖgithubᚗcomᚋdstotijnᚋhettyᚋpkgᚋapiᚐDeleteProjectResult(ctx, field.Selections, res) } +func (ec *executionContext) _Mutation_setScope(ctx context.Context, field graphql.CollectedField) (ret graphql.Marshaler) { + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + ret = graphql.Null + } + }() + fc := &graphql.FieldContext{ + Object: "Mutation", + Field: field, + Args: nil, + IsMethod: true, + IsResolver: true, + } + + ctx = graphql.WithFieldContext(ctx, fc) + rawArgs := field.ArgumentMap(ec.Variables) + args, err := ec.field_Mutation_setScope_args(ctx, rawArgs) + if err != nil { + ec.Error(ctx, err) + return graphql.Null + } + fc.Args = args + resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return ec.resolvers.Mutation().SetScope(rctx, args["scope"].([]ScopeRuleInput)) + }) + if err != nil { + ec.Error(ctx, err) + return graphql.Null + } + if resTmp == nil { + if !graphql.HasFieldError(ctx, fc) { + ec.Errorf(ctx, "must not be null") + } + return graphql.Null + } + res := resTmp.([]ScopeRule) + fc.Result = res + return ec.marshalNScopeRule2ᚕgithubᚗcomᚋdstotijnᚋhettyᚋpkgᚋapiᚐScopeRuleᚄ(ctx, field.Selections, res) +} + +func (ec *executionContext) _Mutation_setHttpRequestLogFilter(ctx context.Context, field graphql.CollectedField) (ret graphql.Marshaler) { + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + ret = graphql.Null + } + }() + fc := &graphql.FieldContext{ + Object: "Mutation", + Field: field, + Args: nil, + IsMethod: true, + IsResolver: true, + } + + ctx = graphql.WithFieldContext(ctx, fc) + rawArgs := field.ArgumentMap(ec.Variables) + args, err := ec.field_Mutation_setHttpRequestLogFilter_args(ctx, rawArgs) + if err != nil { + ec.Error(ctx, err) + return graphql.Null + } + fc.Args = args + resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return ec.resolvers.Mutation().SetHTTPRequestLogFilter(rctx, args["filter"].(*HTTPRequestLogFilterInput)) + }) + if err != nil { + ec.Error(ctx, err) + return graphql.Null + } + if resTmp == nil { + return graphql.Null + } + res := resTmp.(*HTTPRequestLogFilter) + fc.Result = res + return ec.marshalOHttpRequestLogFilter2ᚖgithubᚗcomᚋdstotijnᚋhettyᚋpkgᚋapiᚐHTTPRequestLogFilter(ctx, field.Selections, res) +} + func (ec *executionContext) _Project_name(ctx context.Context, field graphql.CollectedField, obj *Project) (ret graphql.Marshaler) { defer func() { if r := recover(); r != nil { @@ -1441,6 +1727,38 @@ func (ec *executionContext) _Query_httpRequestLogs(ctx context.Context, field gr return ec.marshalNHttpRequestLog2ᚕgithubᚗcomᚋdstotijnᚋhettyᚋpkgᚋapiᚐHTTPRequestLogᚄ(ctx, field.Selections, res) } +func (ec *executionContext) _Query_httpRequestLogFilter(ctx context.Context, field graphql.CollectedField) (ret graphql.Marshaler) { + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + ret = graphql.Null + } + }() + fc := &graphql.FieldContext{ + Object: "Query", + Field: field, + Args: nil, + IsMethod: true, + IsResolver: true, + } + + ctx = graphql.WithFieldContext(ctx, fc) + resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return ec.resolvers.Query().HTTPRequestLogFilter(rctx) + }) + if err != nil { + ec.Error(ctx, err) + return graphql.Null + } + if resTmp == nil { + return graphql.Null + } + res := resTmp.(*HTTPRequestLogFilter) + fc.Result = res + return ec.marshalOHttpRequestLogFilter2ᚖgithubᚗcomᚋdstotijnᚋhettyᚋpkgᚋapiᚐHTTPRequestLogFilter(ctx, field.Selections, res) +} + func (ec *executionContext) _Query_activeProject(ctx context.Context, field graphql.CollectedField) (ret graphql.Marshaler) { defer func() { if r := recover(); r != nil { @@ -1508,6 +1826,41 @@ func (ec *executionContext) _Query_projects(ctx context.Context, field graphql.C return ec.marshalNProject2ᚕgithubᚗcomᚋdstotijnᚋhettyᚋpkgᚋapiᚐProjectᚄ(ctx, field.Selections, res) } +func (ec *executionContext) _Query_scope(ctx context.Context, field graphql.CollectedField) (ret graphql.Marshaler) { + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + ret = graphql.Null + } + }() + fc := &graphql.FieldContext{ + Object: "Query", + Field: field, + Args: nil, + IsMethod: true, + IsResolver: true, + } + + ctx = graphql.WithFieldContext(ctx, fc) + resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return ec.resolvers.Query().Scope(rctx) + }) + if err != nil { + ec.Error(ctx, err) + return graphql.Null + } + if resTmp == nil { + if !graphql.HasFieldError(ctx, fc) { + ec.Errorf(ctx, "must not be null") + } + return graphql.Null + } + res := resTmp.([]ScopeRule) + fc.Result = res + return ec.marshalNScopeRule2ᚕgithubᚗcomᚋdstotijnᚋhettyᚋpkgᚋapiᚐScopeRuleᚄ(ctx, field.Selections, res) +} + func (ec *executionContext) _Query___type(ctx context.Context, field graphql.CollectedField) (ret graphql.Marshaler) { defer func() { if r := recover(); r != nil { @@ -1579,6 +1932,166 @@ func (ec *executionContext) _Query___schema(ctx context.Context, field graphql.C return ec.marshalO__Schema2ᚖgithubᚗcomᚋ99designsᚋgqlgenᚋgraphqlᚋintrospectionᚐSchema(ctx, field.Selections, res) } +func (ec *executionContext) _ScopeHeader_key(ctx context.Context, field graphql.CollectedField, obj *ScopeHeader) (ret graphql.Marshaler) { + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + ret = graphql.Null + } + }() + fc := &graphql.FieldContext{ + Object: "ScopeHeader", + Field: field, + Args: nil, + IsMethod: false, + IsResolver: false, + } + + ctx = graphql.WithFieldContext(ctx, fc) + resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return obj.Key, nil + }) + if err != nil { + ec.Error(ctx, err) + return graphql.Null + } + if resTmp == nil { + return graphql.Null + } + res := resTmp.(*string) + fc.Result = res + return ec.marshalORegexp2ᚖstring(ctx, field.Selections, res) +} + +func (ec *executionContext) _ScopeHeader_value(ctx context.Context, field graphql.CollectedField, obj *ScopeHeader) (ret graphql.Marshaler) { + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + ret = graphql.Null + } + }() + fc := &graphql.FieldContext{ + Object: "ScopeHeader", + Field: field, + Args: nil, + IsMethod: false, + IsResolver: false, + } + + ctx = graphql.WithFieldContext(ctx, fc) + resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return obj.Value, nil + }) + if err != nil { + ec.Error(ctx, err) + return graphql.Null + } + if resTmp == nil { + return graphql.Null + } + res := resTmp.(*string) + fc.Result = res + return ec.marshalORegexp2ᚖstring(ctx, field.Selections, res) +} + +func (ec *executionContext) _ScopeRule_url(ctx context.Context, field graphql.CollectedField, obj *ScopeRule) (ret graphql.Marshaler) { + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + ret = graphql.Null + } + }() + fc := &graphql.FieldContext{ + Object: "ScopeRule", + Field: field, + Args: nil, + IsMethod: false, + IsResolver: false, + } + + ctx = graphql.WithFieldContext(ctx, fc) + resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return obj.URL, nil + }) + if err != nil { + ec.Error(ctx, err) + return graphql.Null + } + if resTmp == nil { + return graphql.Null + } + res := resTmp.(*string) + fc.Result = res + return ec.marshalORegexp2ᚖstring(ctx, field.Selections, res) +} + +func (ec *executionContext) _ScopeRule_header(ctx context.Context, field graphql.CollectedField, obj *ScopeRule) (ret graphql.Marshaler) { + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + ret = graphql.Null + } + }() + fc := &graphql.FieldContext{ + Object: "ScopeRule", + Field: field, + Args: nil, + IsMethod: false, + IsResolver: false, + } + + ctx = graphql.WithFieldContext(ctx, fc) + resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return obj.Header, nil + }) + if err != nil { + ec.Error(ctx, err) + return graphql.Null + } + if resTmp == nil { + return graphql.Null + } + res := resTmp.(*ScopeHeader) + fc.Result = res + return ec.marshalOScopeHeader2ᚖgithubᚗcomᚋdstotijnᚋhettyᚋpkgᚋapiᚐScopeHeader(ctx, field.Selections, res) +} + +func (ec *executionContext) _ScopeRule_body(ctx context.Context, field graphql.CollectedField, obj *ScopeRule) (ret graphql.Marshaler) { + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + ret = graphql.Null + } + }() + fc := &graphql.FieldContext{ + Object: "ScopeRule", + Field: field, + Args: nil, + IsMethod: false, + IsResolver: false, + } + + ctx = graphql.WithFieldContext(ctx, fc) + resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return obj.Body, nil + }) + if err != nil { + ec.Error(ctx, err) + return graphql.Null + } + if resTmp == nil { + return graphql.Null + } + res := resTmp.(*string) + fc.Result = res + return ec.marshalORegexp2ᚖstring(ctx, field.Selections, res) +} + func (ec *executionContext) ___Directive_name(ctx context.Context, field graphql.CollectedField, obj *introspection.Directive) (ret graphql.Marshaler) { defer func() { if r := recover(); r != nil { @@ -2666,6 +3179,90 @@ func (ec *executionContext) ___Type_ofType(ctx context.Context, field graphql.Co // region **************************** input.gotpl ***************************** +func (ec *executionContext) unmarshalInputHttpRequestLogFilterInput(ctx context.Context, obj interface{}) (HTTPRequestLogFilterInput, error) { + var it HTTPRequestLogFilterInput + var asMap = obj.(map[string]interface{}) + + for k, v := range asMap { + switch k { + case "onlyInScope": + var err error + + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("onlyInScope")) + it.OnlyInScope, err = ec.unmarshalOBoolean2ᚖbool(ctx, v) + if err != nil { + return it, err + } + } + } + + return it, nil +} + +func (ec *executionContext) unmarshalInputScopeHeaderInput(ctx context.Context, obj interface{}) (ScopeHeaderInput, error) { + var it ScopeHeaderInput + var asMap = obj.(map[string]interface{}) + + for k, v := range asMap { + switch k { + case "key": + var err error + + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("key")) + it.Key, err = ec.unmarshalORegexp2ᚖstring(ctx, v) + if err != nil { + return it, err + } + case "value": + var err error + + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("value")) + it.Value, err = ec.unmarshalORegexp2ᚖstring(ctx, v) + if err != nil { + return it, err + } + } + } + + return it, nil +} + +func (ec *executionContext) unmarshalInputScopeRuleInput(ctx context.Context, obj interface{}) (ScopeRuleInput, error) { + var it ScopeRuleInput + var asMap = obj.(map[string]interface{}) + + for k, v := range asMap { + switch k { + case "url": + var err error + + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("url")) + it.URL, err = ec.unmarshalORegexp2ᚖstring(ctx, v) + if err != nil { + return it, err + } + case "header": + var err error + + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("header")) + it.Header, err = ec.unmarshalOScopeHeaderInput2ᚖgithubᚗcomᚋdstotijnᚋhettyᚋpkgᚋapiᚐScopeHeaderInput(ctx, v) + if err != nil { + return it, err + } + case "body": + var err error + + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("body")) + it.Body, err = ec.unmarshalORegexp2ᚖstring(ctx, v) + if err != nil { + return it, err + } + } + } + + return it, nil +} + // endregion **************************** input.gotpl ***************************** // region ************************** interface.gotpl *************************** @@ -2816,6 +3413,33 @@ func (ec *executionContext) _HttpRequestLog(ctx context.Context, sel ast.Selecti return out } +var httpRequestLogFilterImplementors = []string{"HttpRequestLogFilter"} + +func (ec *executionContext) _HttpRequestLogFilter(ctx context.Context, sel ast.SelectionSet, obj *HTTPRequestLogFilter) graphql.Marshaler { + fields := graphql.CollectFields(ec.OperationContext, sel, httpRequestLogFilterImplementors) + + out := graphql.NewFieldSet(fields) + var invalids uint32 + for i, field := range fields { + switch field.Name { + case "__typename": + out.Values[i] = graphql.MarshalString("HttpRequestLogFilter") + case "onlyInScope": + out.Values[i] = ec._HttpRequestLogFilter_onlyInScope(ctx, field, obj) + if out.Values[i] == graphql.Null { + invalids++ + } + default: + panic("unknown field " + strconv.Quote(field.Name)) + } + } + out.Dispatch() + if invalids > 0 { + return graphql.Null + } + return out +} + var httpResponseLogImplementors = []string{"HttpResponseLog"} func (ec *executionContext) _HttpResponseLog(ctx context.Context, sel ast.SelectionSet, obj *HTTPResponseLog) graphql.Marshaler { @@ -2892,6 +3516,13 @@ func (ec *executionContext) _Mutation(ctx context.Context, sel ast.SelectionSet) if out.Values[i] == graphql.Null { invalids++ } + case "setScope": + out.Values[i] = ec._Mutation_setScope(ctx, field) + if out.Values[i] == graphql.Null { + invalids++ + } + case "setHttpRequestLogFilter": + out.Values[i] = ec._Mutation_setHttpRequestLogFilter(ctx, field) default: panic("unknown field " + strconv.Quote(field.Name)) } @@ -2975,6 +3606,17 @@ func (ec *executionContext) _Query(ctx context.Context, sel ast.SelectionSet) gr } return res }) + case "httpRequestLogFilter": + field := field + out.Concurrently(i, func() (res graphql.Marshaler) { + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + } + }() + res = ec._Query_httpRequestLogFilter(ctx, field) + return res + }) case "activeProject": field := field out.Concurrently(i, func() (res graphql.Marshaler) { @@ -3000,6 +3642,20 @@ func (ec *executionContext) _Query(ctx context.Context, sel ast.SelectionSet) gr } return res }) + case "scope": + field := field + out.Concurrently(i, func() (res graphql.Marshaler) { + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + } + }() + res = ec._Query_scope(ctx, field) + if res == graphql.Null { + atomic.AddUint32(&invalids, 1) + } + return res + }) case "__type": out.Values[i] = ec._Query___type(ctx, field) case "__schema": @@ -3015,6 +3671,60 @@ func (ec *executionContext) _Query(ctx context.Context, sel ast.SelectionSet) gr return out } +var scopeHeaderImplementors = []string{"ScopeHeader"} + +func (ec *executionContext) _ScopeHeader(ctx context.Context, sel ast.SelectionSet, obj *ScopeHeader) graphql.Marshaler { + fields := graphql.CollectFields(ec.OperationContext, sel, scopeHeaderImplementors) + + out := graphql.NewFieldSet(fields) + var invalids uint32 + for i, field := range fields { + switch field.Name { + case "__typename": + out.Values[i] = graphql.MarshalString("ScopeHeader") + case "key": + out.Values[i] = ec._ScopeHeader_key(ctx, field, obj) + case "value": + out.Values[i] = ec._ScopeHeader_value(ctx, field, obj) + default: + panic("unknown field " + strconv.Quote(field.Name)) + } + } + out.Dispatch() + if invalids > 0 { + return graphql.Null + } + return out +} + +var scopeRuleImplementors = []string{"ScopeRule"} + +func (ec *executionContext) _ScopeRule(ctx context.Context, sel ast.SelectionSet, obj *ScopeRule) graphql.Marshaler { + fields := graphql.CollectFields(ec.OperationContext, sel, scopeRuleImplementors) + + out := graphql.NewFieldSet(fields) + var invalids uint32 + for i, field := range fields { + switch field.Name { + case "__typename": + out.Values[i] = graphql.MarshalString("ScopeRule") + case "url": + out.Values[i] = ec._ScopeRule_url(ctx, field, obj) + case "header": + out.Values[i] = ec._ScopeRule_header(ctx, field, obj) + case "body": + out.Values[i] = ec._ScopeRule_body(ctx, field, obj) + default: + panic("unknown field " + strconv.Quote(field.Name)) + } + } + out.Dispatch() + if invalids > 0 { + return graphql.Null + } + return out +} + var __DirectiveImplementors = []string{"__Directive"} func (ec *executionContext) ___Directive(ctx context.Context, sel ast.SelectionSet, obj *introspection.Directive) graphql.Marshaler { @@ -3466,6 +4176,73 @@ func (ec *executionContext) marshalNProject2ᚕgithubᚗcomᚋdstotijnᚋhetty return ret } +func (ec *executionContext) marshalNScopeRule2githubᚗcomᚋdstotijnᚋhettyᚋpkgᚋapiᚐScopeRule(ctx context.Context, sel ast.SelectionSet, v ScopeRule) graphql.Marshaler { + return ec._ScopeRule(ctx, sel, &v) +} + +func (ec *executionContext) marshalNScopeRule2ᚕgithubᚗcomᚋdstotijnᚋhettyᚋpkgᚋapiᚐScopeRuleᚄ(ctx context.Context, sel ast.SelectionSet, v []ScopeRule) graphql.Marshaler { + ret := make(graphql.Array, len(v)) + var wg sync.WaitGroup + isLen1 := len(v) == 1 + if !isLen1 { + wg.Add(len(v)) + } + for i := range v { + i := i + fc := &graphql.FieldContext{ + Index: &i, + Result: &v[i], + } + ctx := graphql.WithFieldContext(ctx, fc) + f := func(i int) { + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + ret = nil + } + }() + if !isLen1 { + defer wg.Done() + } + ret[i] = ec.marshalNScopeRule2githubᚗcomᚋdstotijnᚋhettyᚋpkgᚋapiᚐScopeRule(ctx, sel, v[i]) + } + if isLen1 { + f(i) + } else { + go f(i) + } + + } + wg.Wait() + return ret +} + +func (ec *executionContext) unmarshalNScopeRuleInput2githubᚗcomᚋdstotijnᚋhettyᚋpkgᚋapiᚐScopeRuleInput(ctx context.Context, v interface{}) (ScopeRuleInput, error) { + res, err := ec.unmarshalInputScopeRuleInput(ctx, v) + return res, graphql.ErrorOnPath(ctx, err) +} + +func (ec *executionContext) unmarshalNScopeRuleInput2ᚕgithubᚗcomᚋdstotijnᚋhettyᚋpkgᚋapiᚐScopeRuleInputᚄ(ctx context.Context, v interface{}) ([]ScopeRuleInput, error) { + var vSlice []interface{} + if v != nil { + if tmp1, ok := v.([]interface{}); ok { + vSlice = tmp1 + } else { + vSlice = []interface{}{v} + } + } + var err error + res := make([]ScopeRuleInput, len(vSlice)) + for i := range vSlice { + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithIndex(i)) + res[i], err = ec.unmarshalNScopeRuleInput2githubᚗcomᚋdstotijnᚋhettyᚋpkgᚋapiᚐScopeRuleInput(ctx, vSlice[i]) + if err != nil { + return nil, err + } + } + return res, nil +} + func (ec *executionContext) unmarshalNString2string(ctx context.Context, v interface{}) (string, error) { res, err := graphql.UnmarshalString(v) return res, graphql.ErrorOnPath(ctx, err) @@ -3756,6 +4533,21 @@ func (ec *executionContext) marshalOHttpRequestLog2ᚖgithubᚗcomᚋdstotijnᚋ return ec._HttpRequestLog(ctx, sel, v) } +func (ec *executionContext) marshalOHttpRequestLogFilter2ᚖgithubᚗcomᚋdstotijnᚋhettyᚋpkgᚋapiᚐHTTPRequestLogFilter(ctx context.Context, sel ast.SelectionSet, v *HTTPRequestLogFilter) graphql.Marshaler { + if v == nil { + return graphql.Null + } + return ec._HttpRequestLogFilter(ctx, sel, v) +} + +func (ec *executionContext) unmarshalOHttpRequestLogFilterInput2ᚖgithubᚗcomᚋdstotijnᚋhettyᚋpkgᚋapiᚐHTTPRequestLogFilterInput(ctx context.Context, v interface{}) (*HTTPRequestLogFilterInput, error) { + if v == nil { + return nil, nil + } + res, err := ec.unmarshalInputHttpRequestLogFilterInput(ctx, v) + return &res, graphql.ErrorOnPath(ctx, err) +} + func (ec *executionContext) marshalOHttpResponseLog2ᚖgithubᚗcomᚋdstotijnᚋhettyᚋpkgᚋapiᚐHTTPResponseLog(ctx context.Context, sel ast.SelectionSet, v *HTTPResponseLog) graphql.Marshaler { if v == nil { return graphql.Null @@ -3770,6 +4562,36 @@ func (ec *executionContext) marshalOProject2ᚖgithubᚗcomᚋdstotijnᚋhetty return ec._Project(ctx, sel, v) } +func (ec *executionContext) unmarshalORegexp2ᚖstring(ctx context.Context, v interface{}) (*string, error) { + if v == nil { + return nil, nil + } + res, err := graphql.UnmarshalString(v) + return &res, graphql.ErrorOnPath(ctx, err) +} + +func (ec *executionContext) marshalORegexp2ᚖstring(ctx context.Context, sel ast.SelectionSet, v *string) graphql.Marshaler { + if v == nil { + return graphql.Null + } + return graphql.MarshalString(*v) +} + +func (ec *executionContext) marshalOScopeHeader2ᚖgithubᚗcomᚋdstotijnᚋhettyᚋpkgᚋapiᚐScopeHeader(ctx context.Context, sel ast.SelectionSet, v *ScopeHeader) graphql.Marshaler { + if v == nil { + return graphql.Null + } + return ec._ScopeHeader(ctx, sel, v) +} + +func (ec *executionContext) unmarshalOScopeHeaderInput2ᚖgithubᚗcomᚋdstotijnᚋhettyᚋpkgᚋapiᚐScopeHeaderInput(ctx context.Context, v interface{}) (*ScopeHeaderInput, error) { + if v == nil { + return nil, nil + } + res, err := ec.unmarshalInputScopeHeaderInput(ctx, v) + return &res, graphql.ErrorOnPath(ctx, err) +} + func (ec *executionContext) unmarshalOString2string(ctx context.Context, v interface{}) (string, error) { res, err := graphql.UnmarshalString(v) return res, graphql.ErrorOnPath(ctx, err) diff --git a/pkg/api/models_gen.go b/pkg/api/models_gen.go index 46a7266..81c83d4 100644 --- a/pkg/api/models_gen.go +++ b/pkg/api/models_gen.go @@ -33,6 +33,14 @@ type HTTPRequestLog struct { Response *HTTPResponseLog `json:"response"` } +type HTTPRequestLogFilter struct { + OnlyInScope bool `json:"onlyInScope"` +} + +type HTTPRequestLogFilterInput struct { + OnlyInScope *bool `json:"onlyInScope"` +} + type HTTPResponseLog struct { RequestID int64 `json:"requestId"` Proto string `json:"proto"` @@ -47,6 +55,28 @@ type Project struct { IsActive bool `json:"isActive"` } +type ScopeHeader struct { + Key *string `json:"key"` + Value *string `json:"value"` +} + +type ScopeHeaderInput struct { + Key *string `json:"key"` + Value *string `json:"value"` +} + +type ScopeRule struct { + URL *string `json:"url"` + Header *ScopeHeader `json:"header"` + Body *string `json:"body"` +} + +type ScopeRuleInput struct { + URL *string `json:"url"` + Header *ScopeHeaderInput `json:"header"` + Body *string `json:"body"` +} + type HTTPMethod string const ( diff --git a/pkg/api/resolvers.go b/pkg/api/resolvers.go index 93f4fa0..e15b7b1 100644 --- a/pkg/api/resolvers.go +++ b/pkg/api/resolvers.go @@ -5,17 +5,20 @@ package api import ( "context" "fmt" + "regexp" "strings" "github.com/99designs/gqlgen/graphql" "github.com/dstotijn/hetty/pkg/proj" "github.com/dstotijn/hetty/pkg/reqlog" + "github.com/dstotijn/hetty/pkg/scope" "github.com/vektah/gqlparser/v2/gqlerror" ) type Resolver struct { RequestLogService *reqlog.Service ProjectService *proj.Service + ScopeService *scope.Scope } type queryResolver struct{ *Resolver } @@ -25,9 +28,8 @@ func (r *Resolver) Query() QueryResolver { return &queryResolver{r} } func (r *Resolver) Mutation() MutationResolver { return &mutationResolver{r} } func (r *queryResolver) HTTPRequestLogs(ctx context.Context) ([]HTTPRequestLog, error) { - opts := reqlog.FindRequestsOptions{OmitOutOfScope: false} - reqs, err := r.RequestLogService.FindRequests(ctx, opts) - if err == reqlog.ErrNoProject { + reqs, err := r.RequestLogService.FindRequests(ctx) + if err == proj.ErrNoProject { return nil, &gqlerror.Error{ Path: graphql.GetPath(ctx), Message: "No active project.", @@ -133,7 +135,7 @@ func parseRequestLog(req reqlog.Request) (HTTPRequestLog, error) { } func (r *mutationResolver) OpenProject(ctx context.Context, name string) (*Project, error) { - p, err := r.ProjectService.Open(name) + p, err := r.ProjectService.Open(ctx, name) if err == proj.ErrInvalidName { return nil, gqlerror.Errorf("Project name must only contain alphanumeric or space chars.") } @@ -178,6 +180,19 @@ func (r *queryResolver) Projects(ctx context.Context) ([]Project, error) { return projects, nil } +func (r *queryResolver) Scope(ctx context.Context) ([]ScopeRule, error) { + rules := r.ScopeService.Rules() + return scopeToScopeRules(rules), nil +} + +func regexpToStringPtr(r *regexp.Regexp) *string { + if r == nil { + return nil + } + s := r.String() + return &s +} + func (r *mutationResolver) CloseProject(ctx context.Context) (*CloseProjectResult, error) { if err := r.ProjectService.Close(); err != nil { return nil, fmt.Errorf("could not close project: %v", err) @@ -193,3 +208,107 @@ func (r *mutationResolver) DeleteProject(ctx context.Context, name string) (*Del Success: true, }, nil } + +func (r *mutationResolver) SetScope(ctx context.Context, input []ScopeRuleInput) ([]ScopeRule, error) { + rules := make([]scope.Rule, len(input)) + for i, rule := range input { + u, err := stringPtrToRegexp(rule.URL) + if err != nil { + return nil, fmt.Errorf("invalid URL in scope rule: %v", err) + } + var headerKey, headerValue *regexp.Regexp + if rule.Header != nil { + headerKey, err = stringPtrToRegexp(rule.Header.Key) + if err != nil { + return nil, fmt.Errorf("invalid header key in scope rule: %v", err) + } + headerValue, err = stringPtrToRegexp(rule.Header.Key) + if err != nil { + return nil, fmt.Errorf("invalid header value in scope rule: %v", err) + } + } + body, err := stringPtrToRegexp(rule.Body) + if err != nil { + return nil, fmt.Errorf("invalid body in scope rule: %v", err) + } + rules[i] = scope.Rule{ + URL: u, + Header: scope.Header{ + Key: headerKey, + Value: headerValue, + }, + Body: body, + } + } + + if err := r.ScopeService.SetRules(ctx, rules); err != nil { + return nil, fmt.Errorf("could not set scope: %v", err) + } + + return scopeToScopeRules(rules), nil +} + +func (r *queryResolver) HTTPRequestLogFilter(ctx context.Context) (*HTTPRequestLogFilter, error) { + return findReqFilterToHTTPReqLogFilter(r.RequestLogService.FindReqsFilter), nil +} + +func (r *mutationResolver) SetHTTPRequestLogFilter( + ctx context.Context, + input *HTTPRequestLogFilterInput, +) (*HTTPRequestLogFilter, error) { + filter := findRequestsFilterFromInput(input) + if err := r.RequestLogService.SetRequestLogFilter(ctx, filter); err != nil { + return nil, fmt.Errorf("could not set request log filter: %v", err) + } + + empty := reqlog.FindRequestsFilter{} + if filter == empty { + return nil, nil + } + return findReqFilterToHTTPReqLogFilter(filter), nil +} + +func stringPtrToRegexp(s *string) (*regexp.Regexp, error) { + if s == nil { + return nil, nil + } + return regexp.Compile(*s) +} + +func scopeToScopeRules(rules []scope.Rule) []ScopeRule { + scopeRules := make([]ScopeRule, len(rules)) + for i, rule := range rules { + scopeRules[i].URL = regexpToStringPtr(rule.URL) + if rule.Header.Key != nil || rule.Header.Value != nil { + scopeRules[i].Header = &ScopeHeader{ + Key: regexpToStringPtr(rule.Header.Key), + Value: regexpToStringPtr(rule.Header.Value), + } + } + scopeRules[i].Body = regexpToStringPtr(rule.Body) + } + return scopeRules +} + +func findRequestsFilterFromInput(input *HTTPRequestLogFilterInput) (filter reqlog.FindRequestsFilter) { + if input == nil { + return + } + if input.OnlyInScope != nil { + filter.OnlyInScope = *input.OnlyInScope + } + + return +} + +func findReqFilterToHTTPReqLogFilter(findReqFilter reqlog.FindRequestsFilter) *HTTPRequestLogFilter { + empty := reqlog.FindRequestsFilter{} + if findReqFilter == empty { + return nil + } + httpReqLogFilter := &HTTPRequestLogFilter{ + OnlyInScope: findReqFilter.OnlyInScope, + } + + return httpReqLogFilter +} diff --git a/pkg/api/schema.graphql b/pkg/api/schema.graphql index edc9859..32691b2 100644 --- a/pkg/api/schema.graphql +++ b/pkg/api/schema.graphql @@ -28,6 +28,28 @@ type Project { isActive: Boolean! } +type ScopeRule { + url: Regexp + header: ScopeHeader + body: Regexp +} + +input ScopeRuleInput { + url: Regexp + header: ScopeHeaderInput + body: Regexp +} + +type ScopeHeader { + key: Regexp + value: Regexp +} + +input ScopeHeaderInput { + key: Regexp + value: Regexp +} + type CloseProjectResult { success: Boolean! } @@ -36,17 +58,31 @@ type DeleteProjectResult { success: Boolean! } +input HttpRequestLogFilterInput { + onlyInScope: Boolean +} + +type HttpRequestLogFilter { + onlyInScope: Boolean! +} + type Query { httpRequestLog(id: ID!): HttpRequestLog httpRequestLogs: [HttpRequestLog!]! + httpRequestLogFilter: HttpRequestLogFilter activeProject: Project projects: [Project!]! + scope: [ScopeRule!]! } type Mutation { openProject(name: String!): Project closeProject: CloseProjectResult! deleteProject(name: String!): DeleteProjectResult! + setScope(scope: [ScopeRuleInput!]!): [ScopeRule!]! + setHttpRequestLogFilter( + filter: HttpRequestLogFilterInput + ): HttpRequestLogFilter } enum HttpMethod { @@ -62,3 +98,4 @@ enum HttpMethod { } scalar Time +scalar Regexp diff --git a/pkg/db/sqlite/regexp/regexp.c b/pkg/db/sqlite/regexp/regexp.c new file mode 100644 index 0000000..de56133 --- /dev/null +++ b/pkg/db/sqlite/regexp/regexp.c @@ -0,0 +1,759 @@ +/* +** 2012-11-13 +** +** The author disclaims copyright to this source code. In place of +** a legal notice, here is a blessing: +** +** May you do good and not evil. +** May you find forgiveness for yourself and forgive others. +** May you share freely, never taking more than you give. +** +****************************************************************************** +** +** The code in this file implements a compact but reasonably +** efficient regular-expression matcher for posix extended regular +** expressions against UTF8 text. +** +** This file is an SQLite extension. It registers a single function +** named "regexp(A,B)" where A is the regular expression and B is the +** string to be matched. By registering this function, SQLite will also +** then implement the "B regexp A" operator. Note that with the function +** the regular expression comes first, but with the operator it comes +** second. +** +** The following regular expression syntax is supported: +** +** X* zero or more occurrences of X +** X+ one or more occurrences of X +** X? zero or one occurrences of X +** X{p,q} between p and q occurrences of X +** (X) match X +** X|Y X or Y +** ^X X occurring at the beginning of the string +** X$ X occurring at the end of the string +** . Match any single character +** \c Character c where c is one of \{}()[]|*+?. +** \c C-language escapes for c in afnrtv. ex: \t or \n +** \uXXXX Where XXXX is exactly 4 hex digits, unicode value XXXX +** \xXX Where XX is exactly 2 hex digits, unicode value XX +** [abc] Any single character from the set abc +** [^abc] Any single character not in the set abc +** [a-z] Any single character in the range a-z +** [^a-z] Any single character not in the range a-z +** \b Word boundary +** \w Word character. [A-Za-z0-9_] +** \W Non-word character +** \d Digit +** \D Non-digit +** \s Whitespace character +** \S Non-whitespace character +** +** A nondeterministic finite automaton (NFA) is used for matching, so the +** performance is bounded by O(N*M) where N is the size of the regular +** expression and M is the size of the input string. The matcher never +** exhibits exponential behavior. Note that the X{p,q} operator expands +** to p copies of X following by q-p copies of X? and that the size of the +** regular expression in the O(N*M) performance bound is computed after +** this expansion. +*/ +#include +#include +#include "sqlite3ext.h" +SQLITE_EXTENSION_INIT1 + +/* +** The following #defines change the names of some functions implemented in +** this file to prevent name collisions with C-library functions of the +** same name. +*/ +#define re_match sqlite3re_match +#define re_compile sqlite3re_compile +#define re_free sqlite3re_free + +/* The end-of-input character */ +#define RE_EOF 0 /* End of input */ + +/* The NFA is implemented as sequence of opcodes taken from the following +** set. Each opcode has a single integer argument. +*/ +#define RE_OP_MATCH 1 /* Match the one character in the argument */ +#define RE_OP_ANY 2 /* Match any one character. (Implements ".") */ +#define RE_OP_ANYSTAR 3 /* Special optimized version of .* */ +#define RE_OP_FORK 4 /* Continue to both next and opcode at iArg */ +#define RE_OP_GOTO 5 /* Jump to opcode at iArg */ +#define RE_OP_ACCEPT 6 /* Halt and indicate a successful match */ +#define RE_OP_CC_INC 7 /* Beginning of a [...] character class */ +#define RE_OP_CC_EXC 8 /* Beginning of a [^...] character class */ +#define RE_OP_CC_VALUE 9 /* Single value in a character class */ +#define RE_OP_CC_RANGE 10 /* Range of values in a character class */ +#define RE_OP_WORD 11 /* Perl word character [A-Za-z0-9_] */ +#define RE_OP_NOTWORD 12 /* Not a perl word character */ +#define RE_OP_DIGIT 13 /* digit: [0-9] */ +#define RE_OP_NOTDIGIT 14 /* Not a digit */ +#define RE_OP_SPACE 15 /* space: [ \t\n\r\v\f] */ +#define RE_OP_NOTSPACE 16 /* Not a digit */ +#define RE_OP_BOUNDARY 17 /* Boundary between word and non-word */ + +/* Each opcode is a "state" in the NFA */ +typedef unsigned short ReStateNumber; + +/* Because this is an NFA and not a DFA, multiple states can be active at +** once. An instance of the following object records all active states in +** the NFA. The implementation is optimized for the common case where the +** number of actives states is small. +*/ +typedef struct ReStateSet { + unsigned nState; /* Number of current states */ + ReStateNumber *aState; /* Current states */ +} ReStateSet; + +/* An input string read one character at a time. +*/ +typedef struct ReInput ReInput; +struct ReInput { + const unsigned char *z; /* All text */ + int i; /* Next byte to read */ + int mx; /* EOF when i>=mx */ +}; + +/* A compiled NFA (or an NFA that is in the process of being compiled) is +** an instance of the following object. +*/ +typedef struct ReCompiled ReCompiled; +struct ReCompiled { + ReInput sIn; /* Regular expression text */ + const char *zErr; /* Error message to return */ + char *aOp; /* Operators for the virtual machine */ + int *aArg; /* Arguments to each operator */ + unsigned (*xNextChar)(ReInput*); /* Next character function */ + unsigned char zInit[12]; /* Initial text to match */ + int nInit; /* Number of characters in zInit */ + unsigned nState; /* Number of entries in aOp[] and aArg[] */ + unsigned nAlloc; /* Slots allocated for aOp[] and aArg[] */ +}; + +/* Add a state to the given state set if it is not already there */ +static void re_add_state(ReStateSet *pSet, int newState){ + unsigned i; + for(i=0; inState; i++) if( pSet->aState[i]==newState ) return; + pSet->aState[pSet->nState++] = (ReStateNumber)newState; +} + +/* Extract the next unicode character from *pzIn and return it. Advance +** *pzIn to the first byte past the end of the character returned. To +** be clear: this routine converts utf8 to unicode. This routine is +** optimized for the common case where the next character is a single byte. +*/ +static unsigned re_next_char(ReInput *p){ + unsigned c; + if( p->i>=p->mx ) return 0; + c = p->z[p->i++]; + if( c>=0x80 ){ + if( (c&0xe0)==0xc0 && p->imx && (p->z[p->i]&0xc0)==0x80 ){ + c = (c&0x1f)<<6 | (p->z[p->i++]&0x3f); + if( c<0x80 ) c = 0xfffd; + }else if( (c&0xf0)==0xe0 && p->i+1mx && (p->z[p->i]&0xc0)==0x80 + && (p->z[p->i+1]&0xc0)==0x80 ){ + c = (c&0x0f)<<12 | ((p->z[p->i]&0x3f)<<6) | (p->z[p->i+1]&0x3f); + p->i += 2; + if( c<=0x7ff || (c>=0xd800 && c<=0xdfff) ) c = 0xfffd; + }else if( (c&0xf8)==0xf0 && p->i+3mx && (p->z[p->i]&0xc0)==0x80 + && (p->z[p->i+1]&0xc0)==0x80 && (p->z[p->i+2]&0xc0)==0x80 ){ + c = (c&0x07)<<18 | ((p->z[p->i]&0x3f)<<12) | ((p->z[p->i+1]&0x3f)<<6) + | (p->z[p->i+2]&0x3f); + p->i += 3; + if( c<=0xffff || c>0x10ffff ) c = 0xfffd; + }else{ + c = 0xfffd; + } + } + return c; +} +static unsigned re_next_char_nocase(ReInput *p){ + unsigned c = re_next_char(p); + if( c>='A' && c<='Z' ) c += 'a' - 'A'; + return c; +} + +/* Return true if c is a perl "word" character: [A-Za-z0-9_] */ +static int re_word_char(int c){ + return (c>='0' && c<='9') || (c>='a' && c<='z') + || (c>='A' && c<='Z') || c=='_'; +} + +/* Return true if c is a "digit" character: [0-9] */ +static int re_digit_char(int c){ + return (c>='0' && c<='9'); +} + +/* Return true if c is a perl "space" character: [ \t\r\n\v\f] */ +static int re_space_char(int c){ + return c==' ' || c=='\t' || c=='\n' || c=='\r' || c=='\v' || c=='\f'; +} + +/* Run a compiled regular expression on the zero-terminated input +** string zIn[]. Return true on a match and false if there is no match. +*/ +static int re_match(ReCompiled *pRe, const unsigned char *zIn, int nIn){ + ReStateSet aStateSet[2], *pThis, *pNext; + ReStateNumber aSpace[100]; + ReStateNumber *pToFree; + unsigned int i = 0; + unsigned int iSwap = 0; + int c = RE_EOF+1; + int cPrev = 0; + int rc = 0; + ReInput in; + + in.z = zIn; + in.i = 0; + in.mx = nIn>=0 ? nIn : (int)strlen((char const*)zIn); + + /* Look for the initial prefix match, if there is one. */ + if( pRe->nInit ){ + unsigned char x = pRe->zInit[0]; + while( in.i+pRe->nInit<=in.mx + && (zIn[in.i]!=x || + strncmp((const char*)zIn+in.i, (const char*)pRe->zInit, pRe->nInit)!=0) + ){ + in.i++; + } + if( in.i+pRe->nInit>in.mx ) return 0; + } + + if( pRe->nState<=(sizeof(aSpace)/(sizeof(aSpace[0])*2)) ){ + pToFree = 0; + aStateSet[0].aState = aSpace; + }else{ + pToFree = sqlite3_malloc64( sizeof(ReStateNumber)*2*pRe->nState ); + if( pToFree==0 ) return -1; + aStateSet[0].aState = pToFree; + } + aStateSet[1].aState = &aStateSet[0].aState[pRe->nState]; + pNext = &aStateSet[1]; + pNext->nState = 0; + re_add_state(pNext, 0); + while( c!=RE_EOF && pNext->nState>0 ){ + cPrev = c; + c = pRe->xNextChar(&in); + pThis = pNext; + pNext = &aStateSet[iSwap]; + iSwap = 1 - iSwap; + pNext->nState = 0; + for(i=0; inState; i++){ + int x = pThis->aState[i]; + switch( pRe->aOp[x] ){ + case RE_OP_MATCH: { + if( pRe->aArg[x]==c ) re_add_state(pNext, x+1); + break; + } + case RE_OP_ANY: { + re_add_state(pNext, x+1); + break; + } + case RE_OP_WORD: { + if( re_word_char(c) ) re_add_state(pNext, x+1); + break; + } + case RE_OP_NOTWORD: { + if( !re_word_char(c) ) re_add_state(pNext, x+1); + break; + } + case RE_OP_DIGIT: { + if( re_digit_char(c) ) re_add_state(pNext, x+1); + break; + } + case RE_OP_NOTDIGIT: { + if( !re_digit_char(c) ) re_add_state(pNext, x+1); + break; + } + case RE_OP_SPACE: { + if( re_space_char(c) ) re_add_state(pNext, x+1); + break; + } + case RE_OP_NOTSPACE: { + if( !re_space_char(c) ) re_add_state(pNext, x+1); + break; + } + case RE_OP_BOUNDARY: { + if( re_word_char(c)!=re_word_char(cPrev) ) re_add_state(pThis, x+1); + break; + } + case RE_OP_ANYSTAR: { + re_add_state(pNext, x); + re_add_state(pThis, x+1); + break; + } + case RE_OP_FORK: { + re_add_state(pThis, x+pRe->aArg[x]); + re_add_state(pThis, x+1); + break; + } + case RE_OP_GOTO: { + re_add_state(pThis, x+pRe->aArg[x]); + break; + } + case RE_OP_ACCEPT: { + rc = 1; + goto re_match_end; + } + case RE_OP_CC_INC: + case RE_OP_CC_EXC: { + int j = 1; + int n = pRe->aArg[x]; + int hit = 0; + for(j=1; j>0 && jaOp[x+j]==RE_OP_CC_VALUE ){ + if( pRe->aArg[x+j]==c ){ + hit = 1; + j = -1; + } + }else{ + if( pRe->aArg[x+j]<=c && pRe->aArg[x+j+1]>=c ){ + hit = 1; + j = -1; + }else{ + j++; + } + } + } + if( pRe->aOp[x]==RE_OP_CC_EXC ) hit = !hit; + if( hit ) re_add_state(pNext, x+n); + break; + } + } + } + } + for(i=0; inState; i++){ + if( pRe->aOp[pNext->aState[i]]==RE_OP_ACCEPT ){ rc = 1; break; } + } +re_match_end: + sqlite3_free(pToFree); + return rc; +} + +/* Resize the opcode and argument arrays for an RE under construction. +*/ +static int re_resize(ReCompiled *p, int N){ + char *aOp; + int *aArg; + aOp = sqlite3_realloc64(p->aOp, N*sizeof(p->aOp[0])); + if( aOp==0 ) return 1; + p->aOp = aOp; + aArg = sqlite3_realloc64(p->aArg, N*sizeof(p->aArg[0])); + if( aArg==0 ) return 1; + p->aArg = aArg; + p->nAlloc = N; + return 0; +} + +/* Insert a new opcode and argument into an RE under construction. The +** insertion point is just prior to existing opcode iBefore. +*/ +static int re_insert(ReCompiled *p, int iBefore, int op, int arg){ + int i; + if( p->nAlloc<=p->nState && re_resize(p, p->nAlloc*2) ) return 0; + for(i=p->nState; i>iBefore; i--){ + p->aOp[i] = p->aOp[i-1]; + p->aArg[i] = p->aArg[i-1]; + } + p->nState++; + p->aOp[iBefore] = (char)op; + p->aArg[iBefore] = arg; + return iBefore; +} + +/* Append a new opcode and argument to the end of the RE under construction. +*/ +static int re_append(ReCompiled *p, int op, int arg){ + return re_insert(p, p->nState, op, arg); +} + +/* Make a copy of N opcodes starting at iStart onto the end of the RE +** under construction. +*/ +static void re_copy(ReCompiled *p, int iStart, int N){ + if( p->nState+N>=p->nAlloc && re_resize(p, p->nAlloc*2+N) ) return; + memcpy(&p->aOp[p->nState], &p->aOp[iStart], N*sizeof(p->aOp[0])); + memcpy(&p->aArg[p->nState], &p->aArg[iStart], N*sizeof(p->aArg[0])); + p->nState += N; +} + +/* Return true if c is a hexadecimal digit character: [0-9a-fA-F] +** If c is a hex digit, also set *pV = (*pV)*16 + valueof(c). If +** c is not a hex digit *pV is unchanged. +*/ +static int re_hex(int c, int *pV){ + if( c>='0' && c<='9' ){ + c -= '0'; + }else if( c>='a' && c<='f' ){ + c -= 'a' - 10; + }else if( c>='A' && c<='F' ){ + c -= 'A' - 10; + }else{ + return 0; + } + *pV = (*pV)*16 + (c & 0xff); + return 1; +} + +/* A backslash character has been seen, read the next character and +** return its interpretation. +*/ +static unsigned re_esc_char(ReCompiled *p){ + static const char zEsc[] = "afnrtv\\()*.+?[$^{|}]"; + static const char zTrans[] = "\a\f\n\r\t\v"; + int i, v = 0; + char c; + if( p->sIn.i>=p->sIn.mx ) return 0; + c = p->sIn.z[p->sIn.i]; + if( c=='u' && p->sIn.i+4sIn.mx ){ + const unsigned char *zIn = p->sIn.z + p->sIn.i; + if( re_hex(zIn[1],&v) + && re_hex(zIn[2],&v) + && re_hex(zIn[3],&v) + && re_hex(zIn[4],&v) + ){ + p->sIn.i += 5; + return v; + } + } + if( c=='x' && p->sIn.i+2sIn.mx ){ + const unsigned char *zIn = p->sIn.z + p->sIn.i; + if( re_hex(zIn[1],&v) + && re_hex(zIn[2],&v) + ){ + p->sIn.i += 3; + return v; + } + } + for(i=0; zEsc[i] && zEsc[i]!=c; i++){} + if( zEsc[i] ){ + if( i<6 ) c = zTrans[i]; + p->sIn.i++; + }else{ + p->zErr = "unknown \\ escape"; + } + return c; +} + +/* Forward declaration */ +static const char *re_subcompile_string(ReCompiled*); + +/* Peek at the next byte of input */ +static unsigned char rePeek(ReCompiled *p){ + return p->sIn.isIn.mx ? p->sIn.z[p->sIn.i] : 0; +} + +/* Compile RE text into a sequence of opcodes. Continue up to the +** first unmatched ")" character, then return. If an error is found, +** return a pointer to the error message string. +*/ +static const char *re_subcompile_re(ReCompiled *p){ + const char *zErr; + int iStart, iEnd, iGoto; + iStart = p->nState; + zErr = re_subcompile_string(p); + if( zErr ) return zErr; + while( rePeek(p)=='|' ){ + iEnd = p->nState; + re_insert(p, iStart, RE_OP_FORK, iEnd + 2 - iStart); + iGoto = re_append(p, RE_OP_GOTO, 0); + p->sIn.i++; + zErr = re_subcompile_string(p); + if( zErr ) return zErr; + p->aArg[iGoto] = p->nState - iGoto; + } + return 0; +} + +/* Compile an element of regular expression text (anything that can be +** an operand to the "|" operator). Return NULL on success or a pointer +** to the error message if there is a problem. +*/ +static const char *re_subcompile_string(ReCompiled *p){ + int iPrev = -1; + int iStart; + unsigned c; + const char *zErr; + while( (c = p->xNextChar(&p->sIn))!=0 ){ + iStart = p->nState; + switch( c ){ + case '|': + case '$': + case ')': { + p->sIn.i--; + return 0; + } + case '(': { + zErr = re_subcompile_re(p); + if( zErr ) return zErr; + if( rePeek(p)!=')' ) return "unmatched '('"; + p->sIn.i++; + break; + } + case '.': { + if( rePeek(p)=='*' ){ + re_append(p, RE_OP_ANYSTAR, 0); + p->sIn.i++; + }else{ + re_append(p, RE_OP_ANY, 0); + } + break; + } + case '*': { + if( iPrev<0 ) return "'*' without operand"; + re_insert(p, iPrev, RE_OP_GOTO, p->nState - iPrev + 1); + re_append(p, RE_OP_FORK, iPrev - p->nState + 1); + break; + } + case '+': { + if( iPrev<0 ) return "'+' without operand"; + re_append(p, RE_OP_FORK, iPrev - p->nState); + break; + } + case '?': { + if( iPrev<0 ) return "'?' without operand"; + re_insert(p, iPrev, RE_OP_FORK, p->nState - iPrev+1); + break; + } + case '{': { + int m = 0, n = 0; + int sz, j; + if( iPrev<0 ) return "'{m,n}' without operand"; + while( (c=rePeek(p))>='0' && c<='9' ){ m = m*10 + c - '0'; p->sIn.i++; } + n = m; + if( c==',' ){ + p->sIn.i++; + n = 0; + while( (c=rePeek(p))>='0' && c<='9' ){ n = n*10 + c-'0'; p->sIn.i++; } + } + if( c!='}' ) return "unmatched '{'"; + if( n>0 && nsIn.i++; + sz = p->nState - iPrev; + if( m==0 ){ + if( n==0 ) return "both m and n are zero in '{m,n}'"; + re_insert(p, iPrev, RE_OP_FORK, sz+1); + n--; + }else{ + for(j=1; j0 ){ + re_append(p, RE_OP_FORK, -sz); + } + break; + } + case '[': { + int iFirst = p->nState; + if( rePeek(p)=='^' ){ + re_append(p, RE_OP_CC_EXC, 0); + p->sIn.i++; + }else{ + re_append(p, RE_OP_CC_INC, 0); + } + while( (c = p->xNextChar(&p->sIn))!=0 ){ + if( c=='[' && rePeek(p)==':' ){ + return "POSIX character classes not supported"; + } + if( c=='\\' ) c = re_esc_char(p); + if( rePeek(p)=='-' ){ + re_append(p, RE_OP_CC_RANGE, c); + p->sIn.i++; + c = p->xNextChar(&p->sIn); + if( c=='\\' ) c = re_esc_char(p); + re_append(p, RE_OP_CC_RANGE, c); + }else{ + re_append(p, RE_OP_CC_VALUE, c); + } + if( rePeek(p)==']' ){ p->sIn.i++; break; } + } + if( c==0 ) return "unclosed '['"; + p->aArg[iFirst] = p->nState - iFirst; + break; + } + case '\\': { + int specialOp = 0; + switch( rePeek(p) ){ + case 'b': specialOp = RE_OP_BOUNDARY; break; + case 'd': specialOp = RE_OP_DIGIT; break; + case 'D': specialOp = RE_OP_NOTDIGIT; break; + case 's': specialOp = RE_OP_SPACE; break; + case 'S': specialOp = RE_OP_NOTSPACE; break; + case 'w': specialOp = RE_OP_WORD; break; + case 'W': specialOp = RE_OP_NOTWORD; break; + } + if( specialOp ){ + p->sIn.i++; + re_append(p, specialOp, 0); + }else{ + c = re_esc_char(p); + re_append(p, RE_OP_MATCH, c); + } + break; + } + default: { + re_append(p, RE_OP_MATCH, c); + break; + } + } + iPrev = iStart; + } + return 0; +} + +/* Free and reclaim all the memory used by a previously compiled +** regular expression. Applications should invoke this routine once +** for every call to re_compile() to avoid memory leaks. +*/ +static void re_free(ReCompiled *pRe){ + if( pRe ){ + sqlite3_free(pRe->aOp); + sqlite3_free(pRe->aArg); + sqlite3_free(pRe); + } +} + +/* +** Compile a textual regular expression in zIn[] into a compiled regular +** expression suitable for us by re_match() and return a pointer to the +** compiled regular expression in *ppRe. Return NULL on success or an +** error message if something goes wrong. +*/ +static const char *re_compile(ReCompiled **ppRe, const char *zIn, int noCase){ + ReCompiled *pRe; + const char *zErr; + int i, j; + + *ppRe = 0; + pRe = sqlite3_malloc( sizeof(*pRe) ); + if( pRe==0 ){ + return "out of memory"; + } + memset(pRe, 0, sizeof(*pRe)); + pRe->xNextChar = noCase ? re_next_char_nocase : re_next_char; + if( re_resize(pRe, 30) ){ + re_free(pRe); + return "out of memory"; + } + if( zIn[0]=='^' ){ + zIn++; + }else{ + re_append(pRe, RE_OP_ANYSTAR, 0); + } + pRe->sIn.z = (unsigned char*)zIn; + pRe->sIn.i = 0; + pRe->sIn.mx = (int)strlen(zIn); + zErr = re_subcompile_re(pRe); + if( zErr ){ + re_free(pRe); + return zErr; + } + if( rePeek(pRe)=='$' && pRe->sIn.i+1>=pRe->sIn.mx ){ + re_append(pRe, RE_OP_MATCH, RE_EOF); + re_append(pRe, RE_OP_ACCEPT, 0); + *ppRe = pRe; + }else if( pRe->sIn.i>=pRe->sIn.mx ){ + re_append(pRe, RE_OP_ACCEPT, 0); + *ppRe = pRe; + }else{ + re_free(pRe); + return "unrecognized character"; + } + + /* The following is a performance optimization. If the regex begins with + ** ".*" (if the input regex lacks an initial "^") and afterwards there are + ** one or more matching characters, enter those matching characters into + ** zInit[]. The re_match() routine can then search ahead in the input + ** string looking for the initial match without having to run the whole + ** regex engine over the string. Do not worry able trying to match + ** unicode characters beyond plane 0 - those are very rare and this is + ** just an optimization. */ + if( pRe->aOp[0]==RE_OP_ANYSTAR ){ + for(j=0, i=1; jzInit)-2 && pRe->aOp[i]==RE_OP_MATCH; i++){ + unsigned x = pRe->aArg[i]; + if( x<=127 ){ + pRe->zInit[j++] = (unsigned char)x; + }else if( x<=0xfff ){ + pRe->zInit[j++] = (unsigned char)(0xc0 | (x>>6)); + pRe->zInit[j++] = 0x80 | (x&0x3f); + }else if( x<=0xffff ){ + pRe->zInit[j++] = (unsigned char)(0xd0 | (x>>12)); + pRe->zInit[j++] = 0x80 | ((x>>6)&0x3f); + pRe->zInit[j++] = 0x80 | (x&0x3f); + }else{ + break; + } + } + if( j>0 && pRe->zInit[j-1]==0 ) j--; + pRe->nInit = j; + } + return pRe->zErr; +} + +/* +** Implementation of the regexp() SQL function. This function implements +** the build-in REGEXP operator. The first argument to the function is the +** pattern and the second argument is the string. So, the SQL statements: +** +** A REGEXP B +** +** is implemented as regexp(B,A). +*/ +static void re_sql_func( + sqlite3_context *context, + int argc, + sqlite3_value **argv +){ + ReCompiled *pRe; /* Compiled regular expression */ + const char *zPattern; /* The regular expression */ + const unsigned char *zStr;/* String being searched */ + const char *zErr; /* Compile error message */ + int setAux = 0; /* True to invoke sqlite3_set_auxdata() */ + + pRe = sqlite3_get_auxdata(context, 0); + if( pRe==0 ){ + zPattern = (const char*)sqlite3_value_text(argv[0]); + if( zPattern==0 ) return; + zErr = re_compile(&pRe, zPattern, 0); + if( zErr ){ + re_free(pRe); + sqlite3_result_error(context, zErr, -1); + return; + } + if( pRe==0 ){ + sqlite3_result_error_nomem(context); + return; + } + setAux = 1; + } + zStr = (const unsigned char*)sqlite3_value_text(argv[1]); + if( zStr!=0 ){ + sqlite3_result_int(context, re_match(pRe, zStr, -1)); + } + if( setAux ){ + sqlite3_set_auxdata(context, 0, pRe, (void(*)(void*))re_free); + } +} + +/* +** Invoke this routine to register the regexp() function with the +** SQLite database connection. +*/ +#ifdef _WIN32 +__declspec(dllexport) +#endif +int sqlite3_regexp_init( + sqlite3 *db, + char **pzErrMsg, + const sqlite3_api_routines *pApi +){ + int rc = SQLITE_OK; + SQLITE_EXTENSION_INIT2(pApi); + rc = sqlite3_create_function(db, "regexp", 2, SQLITE_UTF8, 0, re_sql_func, 0, 0); + return rc; +} diff --git a/pkg/db/sqlite/regexp/regexp.go b/pkg/db/sqlite/regexp/regexp.go new file mode 100644 index 0000000..d0dd9e4 --- /dev/null +++ b/pkg/db/sqlite/regexp/regexp.go @@ -0,0 +1,16 @@ +package regexp + +// #ifndef USE_LIBSQLITE3 +// #include +// #else +// #include +// #endif +// +// // Extension function defined in regexp.c. +// extern int sqlite3_regexp_init(sqlite3*, char**, const sqlite3_api_routines*); +// +// // Use constructor to register extension function with sqlite. +// void __attribute__((constructor)) init(void) { +// sqlite3_auto_extension((void*) sqlite3_regexp_init); +// } +import "C" diff --git a/pkg/db/sqlite/sqlite.go b/pkg/db/sqlite/sqlite.go index 1724218..61289f1 100644 --- a/pkg/db/sqlite/sqlite.go +++ b/pkg/db/sqlite/sqlite.go @@ -3,12 +3,18 @@ package sqlite import ( "context" "database/sql" + "encoding/json" "errors" "fmt" + "io/ioutil" "net/http" "net/url" + "os" + "path/filepath" + "strings" "time" + "github.com/dstotijn/hetty/pkg/proj" "github.com/dstotijn/hetty/pkg/reqlog" "github.com/dstotijn/hetty/pkg/scope" @@ -16,13 +22,18 @@ import ( sq "github.com/Masterminds/squirrel" "github.com/jmoiron/sqlx" - // Register sqlite3 for use via database/sql. + // Register `sqlite3` driver. _ "github.com/mattn/go-sqlite3" + + // Register `regexp()` function. + _ "github.com/dstotijn/hetty/pkg/db/sqlite/regexp" ) // Client implements reqlog.Repository. type Client struct { - db *sqlx.DB + db *sqlx.DB + dbPath string + activeProject string } type httpRequestLogsQuery struct { @@ -32,16 +43,28 @@ type httpRequestLogsQuery struct { joinResponse bool } -// Open opens a database. -func (c *Client) Open(filename string) error { +func New(dbPath string) (*Client, error) { + if _, err := os.Stat(dbPath); os.IsNotExist(err) { + if err := os.MkdirAll(dbPath, 0755); err != nil { + return nil, fmt.Errorf("proj: could not create project directory: %v", err) + } + } + return &Client{ + dbPath: dbPath, + }, nil +} + +// OpenProject opens a project database. +func (c *Client) OpenProject(name string) error { if c.db != nil { - return errors.New("sqlite: database already open") + return errors.New("sqlite: there is already a project open") } opts := make(url.Values) opts.Set("_foreign_keys", "1") - dsn := fmt.Sprintf("file:%v?%v", filename, opts.Encode()) + dbPath := filepath.Join(c.dbPath, name+".db") + dsn := fmt.Sprintf("file:%v?%v", dbPath, opts.Encode()) db, err := sqlx.Open("sqlite3", dsn) if err != nil { return fmt.Errorf("sqlite: could not open database: %v", err) @@ -56,10 +79,29 @@ func (c *Client) Open(filename string) error { } c.db = db + c.activeProject = name return nil } +func (c *Client) Projects() ([]proj.Project, error) { + files, err := ioutil.ReadDir(c.dbPath) + if err != nil { + return nil, fmt.Errorf("sqlite: could not read projects directory: %v", err) + } + + projects := make([]proj.Project, len(files)) + for i, file := range files { + projName := strings.TrimSuffix(file.Name(), ".db") + projects[i] = proj.Project{ + Name: projName, + IsActive: c.activeProject == projName, + } + } + + return projects, nil +} + func prepareSchema(db *sqlx.DB) error { _, err := db.Exec(`CREATE TABLE IF NOT EXISTS http_requests ( id INTEGER PRIMARY KEY, @@ -97,6 +139,14 @@ func prepareSchema(db *sqlx.DB) error { return fmt.Errorf("could not create http_headers table: %v", err) } + _, err = db.Exec(`CREATE TABLE IF NOT EXISTS settings ( + module TEXT PRIMARY KEY, + settings TEXT + )`) + if err != nil { + return fmt.Errorf("could not create settings table: %v", err) + } + return nil } @@ -108,7 +158,18 @@ func (c *Client) Close() error { if err := c.db.Close(); err != nil { return fmt.Errorf("sqlite: could not close database: %v", err) } + c.db = nil + c.activeProject = "" + + return nil +} + +func (c *Client) DeleteProject(name string) error { + if err := os.Remove(filepath.Join(c.dbPath, name+".db")); err != nil { + return fmt.Errorf("sqlite: could not remove database file: %v", err) + } + return nil } @@ -136,11 +197,11 @@ var headerFieldToColumnMap = map[string]string{ func (c *Client) FindRequestLogs( ctx context.Context, - opts reqlog.FindRequestsOptions, + filter reqlog.FindRequestsFilter, scope *scope.Scope, ) (reqLogs []reqlog.Request, err error) { if c.db == nil { - return nil, reqlog.ErrNoProject + return nil, proj.ErrNoProject } httpReqLogsQuery := parseHTTPRequestLogsQuery(ctx) @@ -153,12 +214,24 @@ func (c *Client) FindRequestLogs( reqQuery = reqQuery.LeftJoin("http_responses res ON req.id = res.req_id") } - sql, _, err := reqQuery.ToSql() + if filter.OnlyInScope && scope != nil { + var ruleExpr []sq.Sqlizer + for _, rule := range scope.Rules() { + if rule.URL != nil { + ruleExpr = append(ruleExpr, sq.Expr("req.url regexp ?", rule.URL.String())) + } + } + if len(ruleExpr) > 0 { + reqQuery = reqQuery.Where(sq.Or(ruleExpr)) + } + } + + sql, args, err := reqQuery.ToSql() if err != nil { return nil, fmt.Errorf("sqlite: could not parse query: %v", err) } - rows, err := c.db.QueryxContext(ctx, sql, nil) + rows, err := c.db.QueryxContext(ctx, sql, args...) if err != nil { return nil, fmt.Errorf("sqlite: could not execute query: %v", err) } @@ -186,7 +259,7 @@ func (c *Client) FindRequestLogs( func (c *Client) FindRequestLogByID(ctx context.Context, id int64) (reqlog.Request, error) { if c.db == nil { - return reqlog.Request{}, reqlog.ErrNoProject + return reqlog.Request{}, proj.ErrNoProject } httpReqLogsQuery := parseHTTPRequestLogsQuery(ctx) @@ -229,7 +302,7 @@ func (c *Client) AddRequestLog( timestamp time.Time, ) (*reqlog.Request, error) { if c.db == nil { - return nil, reqlog.ErrNoProject + return nil, proj.ErrNoProject } reqLog := &reqlog.Request{ @@ -303,7 +376,7 @@ func (c *Client) AddResponseLog( timestamp time.Time, ) (*reqlog.Response, error) { if c.db == nil { - return nil, reqlog.ErrNoProject + return nil, proj.ErrNoProject } resLog := &reqlog.Response{ @@ -376,6 +449,49 @@ func (c *Client) AddResponseLog( return resLog, nil } +func (c *Client) UpsertSettings(ctx context.Context, module string, settings interface{}) error { + if c.db == nil { + // TODO: Fix where `ErrNoProject` lives. + return proj.ErrNoProject + } + + jsonSettings, err := json.Marshal(settings) + if err != nil { + return fmt.Errorf("sqlite: could not encode settings as JSON: %v", err) + } + + _, err = c.db.ExecContext(ctx, + `INSERT INTO settings (module, settings) VALUES (?, ?) + ON CONFLICT(module) DO UPDATE SET settings = ?`, module, jsonSettings, jsonSettings) + if err != nil { + return fmt.Errorf("sqlite: could not insert scope settings: %v", err) + } + + return nil +} + +func (c *Client) FindSettingsByModule(ctx context.Context, module string, settings interface{}) error { + if c.db == nil { + return proj.ErrNoProject + } + + var jsonSettings []byte + row := c.db.QueryRowContext(ctx, `SELECT settings FROM settings WHERE module = ?`, module) + err := row.Scan(&jsonSettings) + if err == sql.ErrNoRows { + return proj.ErrNoSettings + } + if err != nil { + return fmt.Errorf("sqlite: could not scan row: %v", err) + } + + if err := json.Unmarshal(jsonSettings, &settings); err != nil { + return fmt.Errorf("sqlite: could not decode settings from JSON: %v", err) + } + + return nil +} + func insertHeaders(ctx context.Context, stmt *sql.Stmt, id int64, headers http.Header) error { for key, values := range headers { for _, value := range values { diff --git a/pkg/proj/proj.go b/pkg/proj/proj.go index 0b07fb2..df69a42 100644 --- a/pkg/proj/proj.go +++ b/pkg/proj/proj.go @@ -1,25 +1,24 @@ package proj import ( + "context" "errors" "fmt" - "io/ioutil" - "os" - "path/filepath" + "log" "regexp" - "strings" - - "github.com/dstotijn/hetty/pkg/db/sqlite" - "github.com/dstotijn/hetty/pkg/scope" + "sync" ) +type OnProjectOpenFn func(name string) error +type OnProjectCloseFn func(name string) error + // Service is used for managing projects. type Service struct { - dbPath string - db *sqlite.Client - name string - - Scope *scope.Scope + repo Repository + activeProject string + onProjectOpenFns []OnProjectOpenFn + onProjectCloseFns []OnProjectCloseFn + mu sync.RWMutex } type Project struct { @@ -29,33 +28,33 @@ type Project struct { var ( ErrNoProject = errors.New("proj: no open project") + ErrNoSettings = errors.New("proj: settings not found") ErrInvalidName = errors.New("proj: invalid name, must be alphanumeric or whitespace chars") ) var nameRegexp = regexp.MustCompile(`^[\w\d\s]+$`) // NewService returns a new Service. -func NewService(dbPath string) (*Service, error) { - // Create directory for DBs if it doesn't exist yet. - if _, err := os.Stat(dbPath); os.IsNotExist(err) { - if err := os.MkdirAll(dbPath, 0755); err != nil { - return nil, fmt.Errorf("proj: could not create project directory: %v", err) - } - } - +func NewService(repo Repository) (*Service, error) { return &Service{ - dbPath: dbPath, - db: &sqlite.Client{}, - Scope: scope.New(nil), + repo: repo, }, nil } // Close closes the currently open project database (if there is one). func (svc *Service) Close() error { - if err := svc.db.Close(); err != nil { + svc.mu.Lock() + defer svc.mu.Unlock() + + closedProject := svc.activeProject + if err := svc.repo.Close(); err != nil { return fmt.Errorf("proj: could not close project: %v", err) } - svc.name = "" + + svc.activeProject = "" + + svc.emitProjectClosed(closedProject) + return nil } @@ -64,41 +63,37 @@ func (svc *Service) Delete(name string) error { if name == "" { return errors.New("proj: name cannot be empty") } - if svc.name == name { + if svc.activeProject == name { return fmt.Errorf("proj: project (%v) is active", name) } - if err := os.Remove(filepath.Join(svc.dbPath, name+".db")); err != nil { - return fmt.Errorf("proj: could not remove database file: %v", err) + if err := svc.repo.DeleteProject(name); err != nil { + return fmt.Errorf("proj: could not delete project: %v", err) } return nil } -// Database returns the currently open database. If no database is open, it will -// return `nil`. -func (svc *Service) Database() *sqlite.Client { - return svc.db -} - // Open opens a database identified with `name`. If a database with this // identifier doesn't exist yet, it will be automatically created. -func (svc *Service) Open(name string) (Project, error) { +func (svc *Service) Open(ctx context.Context, name string) (Project, error) { if !nameRegexp.MatchString(name) { return Project{}, ErrInvalidName } - if err := svc.db.Close(); err != nil { + + svc.mu.Lock() + defer svc.mu.Unlock() + + if err := svc.repo.Close(); err != nil { return Project{}, fmt.Errorf("proj: could not close previously open database: %v", err) } - dbPath := filepath.Join(svc.dbPath, name+".db") - - err := svc.db.Open(dbPath) - if err != nil { + if err := svc.repo.OpenProject(name); err != nil { return Project{}, fmt.Errorf("proj: could not open database: %v", err) } - svc.name = name + svc.activeProject = name + svc.emitProjectOpened() return Project{ Name: name, @@ -107,29 +102,51 @@ func (svc *Service) Open(name string) (Project, error) { } func (svc *Service) ActiveProject() (Project, error) { - if !svc.db.IsOpen() { + activeProject := svc.activeProject + if activeProject == "" { return Project{}, ErrNoProject } return Project{ - Name: svc.name, + Name: activeProject, }, nil } func (svc *Service) Projects() ([]Project, error) { - files, err := ioutil.ReadDir(svc.dbPath) + projects, err := svc.repo.Projects() if err != nil { - return nil, fmt.Errorf("proj: could not read projects directory: %v", err) - } - - projects := make([]Project, len(files)) - for i, file := range files { - projName := strings.TrimSuffix(file.Name(), ".db") - projects[i] = Project{ - Name: projName, - IsActive: svc.name == projName, - } + return nil, fmt.Errorf("proj: could not get projects: %v", err) } return projects, nil } + +func (svc *Service) OnProjectOpen(fn OnProjectOpenFn) { + svc.mu.Lock() + defer svc.mu.Unlock() + + svc.onProjectOpenFns = append(svc.onProjectOpenFns, fn) +} + +func (svc *Service) OnProjectClose(fn OnProjectCloseFn) { + svc.mu.Lock() + defer svc.mu.Unlock() + + svc.onProjectCloseFns = append(svc.onProjectCloseFns, fn) +} + +func (svc *Service) emitProjectOpened() { + for _, fn := range svc.onProjectOpenFns { + if err := fn(svc.activeProject); err != nil { + log.Printf("[ERROR] Could not execute onProjectOpen function: %v", err) + } + } +} + +func (svc *Service) emitProjectClosed(name string) { + for _, fn := range svc.onProjectCloseFns { + if err := fn(name); err != nil { + log.Printf("[ERROR] Could not execute onProjectClose function: %v", err) + } + } +} diff --git a/pkg/proj/repo.go b/pkg/proj/repo.go new file mode 100644 index 0000000..09be551 --- /dev/null +++ b/pkg/proj/repo.go @@ -0,0 +1,14 @@ +package proj + +import ( + "context" +) + +type Repository interface { + UpsertSettings(ctx context.Context, module string, settings interface{}) error + FindSettingsByModule(ctx context.Context, module string, settings interface{}) error + OpenProject(name string) error + DeleteProject(name string) error + Projects() ([]Project, error) + Close() error +} diff --git a/pkg/reqlog/repo.go b/pkg/reqlog/repo.go index 217f4cf..8607a60 100644 --- a/pkg/reqlog/repo.go +++ b/pkg/reqlog/repo.go @@ -13,8 +13,10 @@ type RepositoryProvider interface { } type Repository interface { - FindRequestLogs(ctx context.Context, opts FindRequestsOptions, scope *scope.Scope) ([]Request, error) + FindRequestLogs(ctx context.Context, filter FindRequestsFilter, scope *scope.Scope) ([]Request, error) FindRequestLogByID(ctx context.Context, id int64) (Request, error) AddRequestLog(ctx context.Context, req http.Request, body []byte, timestamp time.Time) (*Request, error) AddResponseLog(ctx context.Context, reqID int64, res http.Response, body []byte, timestamp time.Time) (*Response, error) + UpsertSettings(ctx context.Context, module string, settings interface{}) error + FindSettingsByModule(ctx context.Context, module string, settings interface{}) error } diff --git a/pkg/reqlog/reqlog.go b/pkg/reqlog/reqlog.go index 3232a3f..0957b56 100644 --- a/pkg/reqlog/reqlog.go +++ b/pkg/reqlog/reqlog.go @@ -11,6 +11,7 @@ import ( "net/http" "time" + "github.com/dstotijn/hetty/pkg/proj" "github.com/dstotijn/hetty/pkg/proxy" "github.com/dstotijn/hetty/pkg/scope" ) @@ -19,9 +20,10 @@ type contextKey int const LogBypassedKey contextKey = 0 +const moduleName = "reqlog" + var ( ErrRequestNotFound = errors.New("reqlog: request not found") - ErrNoProject = errors.New("reqlog: no project") ) type Request struct { @@ -42,42 +44,61 @@ type Response struct { type Service struct { BypassOutOfScopeRequests bool + FindReqsFilter FindRequestsFilter scope *scope.Scope repo Repository } -type FindRequestsOptions struct { - OmitOutOfScope bool +type FindRequestsFilter struct { + OnlyInScope bool } type Config struct { Scope *scope.Scope Repository Repository + ProjectService *proj.Service BypassOutOfScopeRequests bool } func NewService(cfg Config) *Service { - return &Service{ + svc := &Service{ scope: cfg.Scope, repo: cfg.Repository, BypassOutOfScopeRequests: cfg.BypassOutOfScopeRequests, } + + cfg.ProjectService.OnProjectOpen(func(_ string) error { + err := svc.loadSettings() + if err == proj.ErrNoSettings { + return nil + } + if err != nil { + return fmt.Errorf("reqlog: could not load settings: %v", err) + } + return nil + }) + cfg.ProjectService.OnProjectClose(func(_ string) error { + svc.unloadSettings() + return nil + }) + + return svc } -func (svc *Service) FindRequests(ctx context.Context, opts FindRequestsOptions) ([]Request, error) { - var scope *scope.Scope - if opts.OmitOutOfScope { - scope = svc.scope - } - - return svc.repo.FindRequestLogs(ctx, opts, scope) +func (svc *Service) FindRequests(ctx context.Context) ([]Request, error) { + return svc.repo.FindRequestLogs(ctx, svc.FindReqsFilter, svc.scope) } func (svc *Service) FindRequestLogByID(ctx context.Context, id int64) (Request, error) { return svc.repo.FindRequestLogByID(ctx, id) } +func (svc *Service) SetRequestLogFilter(ctx context.Context, filter FindRequestsFilter) error { + svc.FindReqsFilter = filter + return svc.repo.UpsertSettings(ctx, "reqlog", svc) +} + func (svc *Service) addRequest( ctx context.Context, req http.Request, @@ -136,7 +157,7 @@ func (svc *Service) RequestModifier(next proxy.RequestModifyFunc) proxy.RequestM } reqLog, err := svc.addRequest(req.Context(), *clone, body, now) - if err == ErrNoProject { + if err == proj.ErrNoProject { ctx := context.WithValue(req.Context(), LogBypassedKey, true) *req = *req.WithContext(ctx) return @@ -184,3 +205,12 @@ func (svc *Service) ResponseModifier(next proxy.ResponseModifyFunc) proxy.Respon return nil } } + +func (svc *Service) loadSettings() error { + return svc.repo.FindSettingsByModule(context.Background(), moduleName, svc) +} + +func (svc *Service) unloadSettings() { + svc.BypassOutOfScopeRequests = false + svc.FindReqsFilter = FindRequestsFilter{} +} diff --git a/pkg/scope/repo.go b/pkg/scope/repo.go new file mode 100644 index 0000000..086f12f --- /dev/null +++ b/pkg/scope/repo.go @@ -0,0 +1,8 @@ +package scope + +import "context" + +type Repository interface { + UpsertSettings(ctx context.Context, module string, settings interface{}) error + FindSettingsByModule(ctx context.Context, module string, settings interface{}) error +} diff --git a/pkg/scope/scope.go b/pkg/scope/scope.go index 585349e..9d8dd03 100644 --- a/pkg/scope/scope.go +++ b/pkg/scope/scope.go @@ -1,14 +1,23 @@ package scope import ( + "context" + "encoding/json" + "fmt" "net/http" "regexp" "sync" + + "github.com/dstotijn/hetty/pkg/proj" ) +const moduleName = "scope" + type Scope struct { - mu sync.RWMutex rules []Rule + repo Repository + + mu sync.RWMutex } type Rule struct { @@ -22,11 +31,26 @@ type Header struct { Value *regexp.Regexp } -func New(rules []Rule) *Scope { - s := &Scope{} - if rules != nil { - s.rules = rules +func New(repo Repository, projService *proj.Service) *Scope { + s := &Scope{ + repo: repo, } + + projService.OnProjectOpen(func(_ string) error { + err := s.load(context.Background()) + if err == proj.ErrNoSettings { + return nil + } + if err != nil { + return fmt.Errorf("scope: could not load scope: %v", err) + } + return nil + }) + projService.OnProjectClose(func(_ string) error { + s.unload() + return nil + }) + return s } @@ -36,11 +60,41 @@ func (s *Scope) Rules() []Rule { return s.rules } -func (s *Scope) SetRules(rules []Rule) { +func (s *Scope) load(ctx context.Context) error { s.mu.Lock() defer s.mu.Unlock() + var rules []Rule + err := s.repo.FindSettingsByModule(ctx, moduleName, &rules) + if err == proj.ErrNoSettings { + return err + } + if err != nil { + return fmt.Errorf("scope: could not load scope settings: %v", err) + } + s.rules = rules + + return nil +} + +func (s *Scope) unload() { + s.mu.Lock() + defer s.mu.Unlock() + s.rules = nil +} + +func (s *Scope) SetRules(ctx context.Context, rules []Rule) error { + s.mu.Lock() + defer s.mu.Unlock() + + if err := s.repo.UpsertSettings(ctx, moduleName, rules); err != nil { + return fmt.Errorf("scope: cannot set rules in repository: %v", err) + } + + s.rules = rules + + return nil } func (s *Scope) Match(req *http.Request, body []byte) bool { @@ -97,3 +151,87 @@ func (r Rule) Match(req *http.Request, body []byte) bool { return false } + +// MarshalJSON implements json.Marshaler. +func (r Rule) MarshalJSON() ([]byte, error) { + type headerDTO struct { + Key string + Value string + } + type ruleDTO struct { + URL string + Header headerDTO + Body string + } + + dto := ruleDTO{ + URL: regexpToString(r.URL), + Header: headerDTO{ + Key: regexpToString(r.Header.Key), + Value: regexpToString(r.Header.Value), + }, + Body: regexpToString(r.Body), + } + + return json.Marshal(dto) +} + +// UnmarshalJSON implements json.Unmarshaler. +func (r *Rule) UnmarshalJSON(data []byte) error { + type headerDTO struct { + Key string + Value string + } + type ruleDTO struct { + URL string + Header headerDTO + Body string + } + + var dto ruleDTO + if err := json.Unmarshal(data, &dto); err != nil { + return err + } + + url, err := stringToRegexp(dto.URL) + if err != nil { + return err + } + headerKey, err := stringToRegexp(dto.Header.Key) + if err != nil { + return err + } + headerValue, err := stringToRegexp(dto.Header.Value) + if err != nil { + return err + } + body, err := stringToRegexp(dto.Body) + if err != nil { + return err + } + + *r = Rule{ + URL: url, + Header: Header{ + Key: headerKey, + Value: headerValue, + }, + Body: body, + } + + return nil +} + +func regexpToString(r *regexp.Regexp) string { + if r == nil { + return "" + } + return r.String() +} + +func stringToRegexp(s string) (*regexp.Regexp, error) { + if s == "" { + return nil, nil + } + return regexp.Compile(s) +}