mirror of
https://github.com/dstotijn/hetty.git
synced 2025-07-01 18:47:29 -04:00
Add intercept module
This commit is contained in:
2214
pkg/api/generated.go
2214
pkg/api/generated.go
File diff suppressed because it is too large
Load Diff
@ -12,6 +12,14 @@ import (
|
||||
"github.com/oklog/ulid"
|
||||
)
|
||||
|
||||
type CancelRequestResult struct {
|
||||
Success bool `json:"success"`
|
||||
}
|
||||
|
||||
type CancelResponseResult struct {
|
||||
Success bool `json:"success"`
|
||||
}
|
||||
|
||||
type ClearHTTPRequestLogResult struct {
|
||||
Success bool `json:"success"`
|
||||
}
|
||||
@ -38,6 +46,16 @@ type HTTPHeaderInput struct {
|
||||
Value string `json:"value"`
|
||||
}
|
||||
|
||||
type HTTPRequest struct {
|
||||
ID ulid.ULID `json:"id"`
|
||||
URL *url.URL `json:"url"`
|
||||
Method HTTPMethod `json:"method"`
|
||||
Proto HTTPProtocol `json:"proto"`
|
||||
Headers []HTTPHeader `json:"headers"`
|
||||
Body *string `json:"body"`
|
||||
Response *HTTPResponse `json:"response"`
|
||||
}
|
||||
|
||||
type HTTPRequestLog struct {
|
||||
ID ulid.ULID `json:"id"`
|
||||
URL string `json:"url"`
|
||||
@ -59,6 +77,16 @@ type HTTPRequestLogFilterInput struct {
|
||||
SearchExpression *string `json:"searchExpression"`
|
||||
}
|
||||
|
||||
type HTTPResponse struct {
|
||||
// Will be the same ID as its related request ID.
|
||||
ID ulid.ULID `json:"id"`
|
||||
Proto HTTPProtocol `json:"proto"`
|
||||
StatusCode int `json:"statusCode"`
|
||||
StatusReason string `json:"statusReason"`
|
||||
Body *string `json:"body"`
|
||||
Headers []HTTPHeader `json:"headers"`
|
||||
}
|
||||
|
||||
type HTTPResponseLog struct {
|
||||
// Will be the same ID as its related request ID.
|
||||
ID ulid.ULID `json:"id"`
|
||||
@ -69,10 +97,49 @@ type HTTPResponseLog struct {
|
||||
Headers []HTTPHeader `json:"headers"`
|
||||
}
|
||||
|
||||
type InterceptSettings struct {
|
||||
RequestsEnabled bool `json:"requestsEnabled"`
|
||||
ResponsesEnabled bool `json:"responsesEnabled"`
|
||||
RequestFilter *string `json:"requestFilter"`
|
||||
ResponseFilter *string `json:"responseFilter"`
|
||||
}
|
||||
|
||||
type ModifyRequestInput struct {
|
||||
ID ulid.ULID `json:"id"`
|
||||
URL *url.URL `json:"url"`
|
||||
Method HTTPMethod `json:"method"`
|
||||
Proto HTTPProtocol `json:"proto"`
|
||||
Headers []HTTPHeaderInput `json:"headers"`
|
||||
Body *string `json:"body"`
|
||||
ModifyResponse *bool `json:"modifyResponse"`
|
||||
}
|
||||
|
||||
type ModifyRequestResult struct {
|
||||
Success bool `json:"success"`
|
||||
}
|
||||
|
||||
type ModifyResponseInput struct {
|
||||
RequestID ulid.ULID `json:"requestID"`
|
||||
Proto HTTPProtocol `json:"proto"`
|
||||
Headers []HTTPHeaderInput `json:"headers"`
|
||||
Body *string `json:"body"`
|
||||
StatusCode int `json:"statusCode"`
|
||||
StatusReason string `json:"statusReason"`
|
||||
}
|
||||
|
||||
type ModifyResponseResult struct {
|
||||
Success bool `json:"success"`
|
||||
}
|
||||
|
||||
type Project struct {
|
||||
ID ulid.ULID `json:"id"`
|
||||
Name string `json:"name"`
|
||||
IsActive bool `json:"isActive"`
|
||||
ID ulid.ULID `json:"id"`
|
||||
Name string `json:"name"`
|
||||
IsActive bool `json:"isActive"`
|
||||
Settings *ProjectSettings `json:"settings"`
|
||||
}
|
||||
|
||||
type ProjectSettings struct {
|
||||
Intercept *InterceptSettings `json:"intercept"`
|
||||
}
|
||||
|
||||
type ScopeHeader struct {
|
||||
@ -128,6 +195,13 @@ type SenderRequestInput struct {
|
||||
Body *string `json:"body"`
|
||||
}
|
||||
|
||||
type UpdateInterceptSettingsInput struct {
|
||||
RequestsEnabled bool `json:"requestsEnabled"`
|
||||
ResponsesEnabled bool `json:"responsesEnabled"`
|
||||
RequestFilter *string `json:"requestFilter"`
|
||||
ResponseFilter *string `json:"responseFilter"`
|
||||
}
|
||||
|
||||
type HTTPMethod string
|
||||
|
||||
const (
|
||||
|
@ -3,9 +3,12 @@ package api
|
||||
//go:generate go run github.com/99designs/gqlgen
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strings"
|
||||
@ -15,6 +18,8 @@ import (
|
||||
"github.com/vektah/gqlparser/v2/gqlerror"
|
||||
|
||||
"github.com/dstotijn/hetty/pkg/proj"
|
||||
"github.com/dstotijn/hetty/pkg/proxy"
|
||||
"github.com/dstotijn/hetty/pkg/proxy/intercept"
|
||||
"github.com/dstotijn/hetty/pkg/reqlog"
|
||||
"github.com/dstotijn/hetty/pkg/scope"
|
||||
"github.com/dstotijn/hetty/pkg/search"
|
||||
@ -36,6 +41,7 @@ var revHTTPProtocolMap = map[HTTPProtocol]string{
|
||||
type Resolver struct {
|
||||
ProjectService proj.Service
|
||||
RequestLogService reqlog.Service
|
||||
InterceptService *intercept.Service
|
||||
SenderService sender.Service
|
||||
}
|
||||
|
||||
@ -179,11 +185,9 @@ func (r *mutationResolver) CreateProject(ctx context.Context, name string) (*Pro
|
||||
return nil, fmt.Errorf("could not open project: %w", err)
|
||||
}
|
||||
|
||||
return &Project{
|
||||
ID: p.ID,
|
||||
Name: p.Name,
|
||||
IsActive: r.ProjectService.IsProjectActive(p.ID),
|
||||
}, nil
|
||||
project := parseProject(r.ProjectService, p)
|
||||
|
||||
return &project, nil
|
||||
}
|
||||
|
||||
func (r *mutationResolver) OpenProject(ctx context.Context, id ulid.ULID) (*Project, error) {
|
||||
@ -194,11 +198,9 @@ func (r *mutationResolver) OpenProject(ctx context.Context, id ulid.ULID) (*Proj
|
||||
return nil, fmt.Errorf("could not open project: %w", err)
|
||||
}
|
||||
|
||||
return &Project{
|
||||
ID: p.ID,
|
||||
Name: p.Name,
|
||||
IsActive: r.ProjectService.IsProjectActive(p.ID),
|
||||
}, nil
|
||||
project := parseProject(r.ProjectService, p)
|
||||
|
||||
return &project, nil
|
||||
}
|
||||
|
||||
func (r *queryResolver) ActiveProject(ctx context.Context) (*Project, error) {
|
||||
@ -209,11 +211,9 @@ func (r *queryResolver) ActiveProject(ctx context.Context) (*Project, error) {
|
||||
return nil, fmt.Errorf("could not open project: %w", err)
|
||||
}
|
||||
|
||||
return &Project{
|
||||
ID: p.ID,
|
||||
Name: p.Name,
|
||||
IsActive: r.ProjectService.IsProjectActive(p.ID),
|
||||
}, nil
|
||||
project := parseProject(r.ProjectService, p)
|
||||
|
||||
return &project, nil
|
||||
}
|
||||
|
||||
func (r *queryResolver) Projects(ctx context.Context) ([]Project, error) {
|
||||
@ -224,11 +224,7 @@ func (r *queryResolver) Projects(ctx context.Context) ([]Project, error) {
|
||||
|
||||
projects := make([]Project, len(p))
|
||||
for i, proj := range p {
|
||||
projects[i] = Project{
|
||||
ID: proj.ID,
|
||||
Name: proj.Name,
|
||||
IsActive: r.ProjectService.IsProjectActive(proj.ID),
|
||||
}
|
||||
projects[i] = parseProject(r.ProjectService, proj)
|
||||
}
|
||||
|
||||
return projects, nil
|
||||
@ -520,6 +516,166 @@ func (r *mutationResolver) DeleteSenderRequests(ctx context.Context) (*DeleteSen
|
||||
return &DeleteSenderRequestsResult{true}, nil
|
||||
}
|
||||
|
||||
func (r *queryResolver) InterceptedRequests(ctx context.Context) (httpReqs []HTTPRequest, err error) {
|
||||
items := r.InterceptService.Items()
|
||||
|
||||
for _, item := range items {
|
||||
req, err := parseInterceptItem(item)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
httpReqs = append(httpReqs, req)
|
||||
}
|
||||
|
||||
return httpReqs, nil
|
||||
}
|
||||
|
||||
func (r *queryResolver) InterceptedRequest(ctx context.Context, id ulid.ULID) (*HTTPRequest, error) {
|
||||
item, err := r.InterceptService.ItemByID(id)
|
||||
if errors.Is(err, intercept.ErrRequestNotFound) {
|
||||
return nil, nil
|
||||
} else if err != nil {
|
||||
return nil, fmt.Errorf("could not get request by ID: %w", err)
|
||||
}
|
||||
|
||||
req, err := parseInterceptItem(item)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &req, nil
|
||||
}
|
||||
|
||||
func (r *mutationResolver) ModifyRequest(ctx context.Context, input ModifyRequestInput) (*ModifyRequestResult, error) {
|
||||
body := ""
|
||||
if input.Body != nil {
|
||||
body = *input.Body
|
||||
}
|
||||
|
||||
//nolint:noctx
|
||||
req, err := http.NewRequest(input.Method.String(), input.URL.String(), strings.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to construct HTTP request: %w", err)
|
||||
}
|
||||
|
||||
for _, header := range input.Headers {
|
||||
req.Header.Add(header.Key, header.Value)
|
||||
}
|
||||
|
||||
err = r.InterceptService.ModifyRequest(input.ID, req, input.ModifyResponse)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not modify http request: %w", err)
|
||||
}
|
||||
|
||||
return &ModifyRequestResult{Success: true}, nil
|
||||
}
|
||||
|
||||
func (r *mutationResolver) CancelRequest(ctx context.Context, id ulid.ULID) (*CancelRequestResult, error) {
|
||||
err := r.InterceptService.CancelRequest(id)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not cancel http request: %w", err)
|
||||
}
|
||||
|
||||
return &CancelRequestResult{Success: true}, nil
|
||||
}
|
||||
|
||||
func (r *mutationResolver) ModifyResponse(
|
||||
ctx context.Context,
|
||||
input ModifyResponseInput,
|
||||
) (*ModifyResponseResult, error) {
|
||||
res := &http.Response{
|
||||
Header: make(http.Header),
|
||||
Status: fmt.Sprintf("%v %v", input.StatusCode, input.StatusReason),
|
||||
StatusCode: input.StatusCode,
|
||||
Proto: revHTTPProtocolMap[input.Proto],
|
||||
}
|
||||
|
||||
var ok bool
|
||||
if res.ProtoMajor, res.ProtoMinor, ok = http.ParseHTTPVersion(res.Proto); !ok {
|
||||
return nil, fmt.Errorf("malformed HTTP version: %q", res.Proto)
|
||||
}
|
||||
|
||||
var body string
|
||||
if input.Body != nil {
|
||||
body = *input.Body
|
||||
}
|
||||
|
||||
res.Body = io.NopCloser(strings.NewReader(body))
|
||||
|
||||
for _, header := range input.Headers {
|
||||
res.Header.Add(header.Key, header.Value)
|
||||
}
|
||||
|
||||
err := r.InterceptService.ModifyResponse(input.RequestID, res)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not modify http request: %w", err)
|
||||
}
|
||||
|
||||
return &ModifyResponseResult{Success: true}, nil
|
||||
}
|
||||
|
||||
func (r *mutationResolver) CancelResponse(ctx context.Context, requestID ulid.ULID) (*CancelResponseResult, error) {
|
||||
err := r.InterceptService.CancelResponse(requestID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not cancel http response: %w", err)
|
||||
}
|
||||
|
||||
return &CancelResponseResult{Success: true}, nil
|
||||
}
|
||||
|
||||
func (r *mutationResolver) UpdateInterceptSettings(
|
||||
ctx context.Context,
|
||||
input UpdateInterceptSettingsInput,
|
||||
) (*InterceptSettings, error) {
|
||||
settings := intercept.Settings{
|
||||
RequestsEnabled: input.RequestsEnabled,
|
||||
ResponsesEnabled: input.ResponsesEnabled,
|
||||
}
|
||||
|
||||
if input.RequestFilter != nil && *input.RequestFilter != "" {
|
||||
expr, err := search.ParseQuery(*input.RequestFilter)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not parse request filter: %w", err)
|
||||
}
|
||||
|
||||
settings.RequestFilter = expr
|
||||
}
|
||||
|
||||
if input.ResponseFilter != nil && *input.ResponseFilter != "" {
|
||||
expr, err := search.ParseQuery(*input.ResponseFilter)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not parse response filter: %w", err)
|
||||
}
|
||||
|
||||
settings.ResponseFilter = expr
|
||||
}
|
||||
|
||||
err := r.ProjectService.UpdateInterceptSettings(ctx, settings)
|
||||
if errors.Is(err, proj.ErrNoProject) {
|
||||
return nil, noActiveProjectErr(ctx)
|
||||
} else if err != nil {
|
||||
return nil, fmt.Errorf("could not update intercept settings: %w", err)
|
||||
}
|
||||
|
||||
updated := &InterceptSettings{
|
||||
RequestsEnabled: settings.RequestsEnabled,
|
||||
ResponsesEnabled: settings.ResponsesEnabled,
|
||||
}
|
||||
|
||||
if settings.RequestFilter != nil {
|
||||
reqFilter := settings.RequestFilter.String()
|
||||
updated.RequestFilter = &reqFilter
|
||||
}
|
||||
|
||||
if settings.ResponseFilter != nil {
|
||||
resFilter := settings.ResponseFilter.String()
|
||||
updated.ResponseFilter = &resFilter
|
||||
}
|
||||
|
||||
return updated, nil
|
||||
}
|
||||
|
||||
func parseSenderRequest(req sender.Request) (SenderRequest, error) {
|
||||
method := HTTPMethod(req.Method)
|
||||
if method != "" && !method.IsValid() {
|
||||
@ -575,6 +731,155 @@ func parseSenderRequest(req sender.Request) (SenderRequest, error) {
|
||||
return senderReq, nil
|
||||
}
|
||||
|
||||
func parseHTTPRequest(req *http.Request) (HTTPRequest, error) {
|
||||
method := HTTPMethod(req.Method)
|
||||
if method != "" && !method.IsValid() {
|
||||
return HTTPRequest{}, fmt.Errorf("http request has invalid method: %v", method)
|
||||
}
|
||||
|
||||
reqProto := httpProtocolMap[req.Proto]
|
||||
if !reqProto.IsValid() {
|
||||
return HTTPRequest{}, fmt.Errorf("http request has invalid protocol: %v", req.Proto)
|
||||
}
|
||||
|
||||
id, ok := proxy.RequestIDFromContext(req.Context())
|
||||
if !ok {
|
||||
return HTTPRequest{}, errors.New("http request has missing ID")
|
||||
}
|
||||
|
||||
httpReq := HTTPRequest{
|
||||
ID: id,
|
||||
URL: req.URL,
|
||||
Method: method,
|
||||
Proto: HTTPProtocol(req.Proto),
|
||||
}
|
||||
|
||||
if req.Header != nil {
|
||||
httpReq.Headers = make([]HTTPHeader, 0)
|
||||
|
||||
for key, values := range req.Header {
|
||||
for _, value := range values {
|
||||
httpReq.Headers = append(httpReq.Headers, HTTPHeader{
|
||||
Key: key,
|
||||
Value: value,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if req.Body != nil {
|
||||
body, err := ioutil.ReadAll(req.Body)
|
||||
if err != nil {
|
||||
return HTTPRequest{}, fmt.Errorf("failed to read request body: %w", err)
|
||||
}
|
||||
|
||||
req.Body = ioutil.NopCloser(bytes.NewBuffer(body))
|
||||
bodyStr := string(body)
|
||||
httpReq.Body = &bodyStr
|
||||
}
|
||||
|
||||
return httpReq, nil
|
||||
}
|
||||
|
||||
func parseHTTPResponse(res *http.Response) (HTTPResponse, error) {
|
||||
resProto := httpProtocolMap[res.Proto]
|
||||
if !resProto.IsValid() {
|
||||
return HTTPResponse{}, fmt.Errorf("http response has invalid protocol: %v", res.Proto)
|
||||
}
|
||||
|
||||
id, ok := proxy.RequestIDFromContext(res.Request.Context())
|
||||
if !ok {
|
||||
return HTTPResponse{}, errors.New("http response has missing ID")
|
||||
}
|
||||
|
||||
httpRes := HTTPResponse{
|
||||
ID: id,
|
||||
Proto: resProto,
|
||||
StatusCode: res.StatusCode,
|
||||
}
|
||||
|
||||
statusReasonSubs := strings.SplitN(res.Status, " ", 2)
|
||||
|
||||
if len(statusReasonSubs) == 2 {
|
||||
httpRes.StatusReason = statusReasonSubs[1]
|
||||
}
|
||||
|
||||
if res.Header != nil {
|
||||
httpRes.Headers = make([]HTTPHeader, 0)
|
||||
|
||||
for key, values := range res.Header {
|
||||
for _, value := range values {
|
||||
httpRes.Headers = append(httpRes.Headers, HTTPHeader{
|
||||
Key: key,
|
||||
Value: value,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if res.Body != nil {
|
||||
body, err := ioutil.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
return HTTPResponse{}, fmt.Errorf("failed to read response body: %w", err)
|
||||
}
|
||||
|
||||
res.Body = ioutil.NopCloser(bytes.NewBuffer(body))
|
||||
bodyStr := string(body)
|
||||
httpRes.Body = &bodyStr
|
||||
}
|
||||
|
||||
return httpRes, nil
|
||||
}
|
||||
|
||||
func parseInterceptItem(item intercept.Item) (req HTTPRequest, err error) {
|
||||
if item.Response != nil {
|
||||
req, err = parseHTTPRequest(item.Response.Request)
|
||||
if err != nil {
|
||||
return HTTPRequest{}, err
|
||||
}
|
||||
|
||||
res, err := parseHTTPResponse(item.Response)
|
||||
if err != nil {
|
||||
return HTTPRequest{}, err
|
||||
}
|
||||
|
||||
req.Response = &res
|
||||
} else if item.Request != nil {
|
||||
req, err = parseHTTPRequest(item.Request)
|
||||
if err != nil {
|
||||
return HTTPRequest{}, err
|
||||
}
|
||||
}
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func parseProject(projSvc proj.Service, p proj.Project) Project {
|
||||
project := Project{
|
||||
ID: p.ID,
|
||||
Name: p.Name,
|
||||
IsActive: projSvc.IsProjectActive(p.ID),
|
||||
Settings: &ProjectSettings{
|
||||
Intercept: &InterceptSettings{
|
||||
RequestsEnabled: p.Settings.InterceptRequests,
|
||||
ResponsesEnabled: p.Settings.InterceptResponses,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if p.Settings.InterceptRequestFilter != nil {
|
||||
interceptReqFilter := p.Settings.InterceptRequestFilter.String()
|
||||
project.Settings.Intercept.RequestFilter = &interceptReqFilter
|
||||
}
|
||||
|
||||
if p.Settings.InterceptResponseFilter != nil {
|
||||
interceptResFilter := p.Settings.InterceptResponseFilter.String()
|
||||
project.Settings.Intercept.ResponseFilter = &interceptResFilter
|
||||
}
|
||||
|
||||
return project
|
||||
}
|
||||
|
||||
func stringPtrToRegexp(s *string) (*regexp.Regexp, error) {
|
||||
if s == nil {
|
||||
return nil, nil
|
||||
|
@ -30,6 +30,11 @@ type Project {
|
||||
id: ID!
|
||||
name: String!
|
||||
isActive: Boolean!
|
||||
settings: ProjectSettings!
|
||||
}
|
||||
|
||||
type ProjectSettings {
|
||||
intercept: InterceptSettings!
|
||||
}
|
||||
|
||||
type ScopeRule {
|
||||
@ -116,6 +121,77 @@ type SenderRequestFilter {
|
||||
searchExpression: String
|
||||
}
|
||||
|
||||
type HttpRequest {
|
||||
id: ID!
|
||||
url: URL!
|
||||
method: HttpMethod!
|
||||
proto: HttpProtocol!
|
||||
headers: [HttpHeader!]!
|
||||
body: String
|
||||
response: HttpResponse
|
||||
}
|
||||
|
||||
type HttpResponse {
|
||||
"""
|
||||
Will be the same ID as its related request ID.
|
||||
"""
|
||||
id: ID!
|
||||
proto: HttpProtocol!
|
||||
statusCode: Int!
|
||||
statusReason: String!
|
||||
body: String
|
||||
headers: [HttpHeader!]!
|
||||
}
|
||||
|
||||
input ModifyRequestInput {
|
||||
id: ID!
|
||||
url: URL!
|
||||
method: HttpMethod!
|
||||
proto: HttpProtocol!
|
||||
headers: [HttpHeaderInput!]
|
||||
body: String
|
||||
modifyResponse: Boolean
|
||||
}
|
||||
|
||||
type ModifyRequestResult {
|
||||
success: Boolean!
|
||||
}
|
||||
|
||||
type CancelRequestResult {
|
||||
success: Boolean!
|
||||
}
|
||||
|
||||
input ModifyResponseInput {
|
||||
requestID: ID!
|
||||
proto: HttpProtocol!
|
||||
headers: [HttpHeaderInput!]
|
||||
body: String
|
||||
statusCode: Int!
|
||||
statusReason: String!
|
||||
}
|
||||
|
||||
type ModifyResponseResult {
|
||||
success: Boolean!
|
||||
}
|
||||
|
||||
type CancelResponseResult {
|
||||
success: Boolean!
|
||||
}
|
||||
|
||||
input UpdateInterceptSettingsInput {
|
||||
requestsEnabled: Boolean!
|
||||
responsesEnabled: Boolean!
|
||||
requestFilter: String
|
||||
responseFilter: String
|
||||
}
|
||||
|
||||
type InterceptSettings {
|
||||
requestsEnabled: Boolean!
|
||||
responsesEnabled: Boolean!
|
||||
requestFilter: String
|
||||
responseFilter: String
|
||||
}
|
||||
|
||||
type Query {
|
||||
httpRequestLog(id: ID!): HttpRequestLog
|
||||
httpRequestLogs: [HttpRequestLog!]!
|
||||
@ -125,6 +201,8 @@ type Query {
|
||||
scope: [ScopeRule!]!
|
||||
senderRequest(id: ID!): SenderRequest
|
||||
senderRequests: [SenderRequest!]!
|
||||
interceptedRequests: [HttpRequest!]!
|
||||
interceptedRequest(id: ID!): HttpRequest
|
||||
}
|
||||
|
||||
type Mutation {
|
||||
@ -142,6 +220,13 @@ type Mutation {
|
||||
createSenderRequestFromHttpRequestLog(id: ID!): SenderRequest!
|
||||
sendRequest(id: ID!): SenderRequest!
|
||||
deleteSenderRequests: DeleteSenderRequestsResult!
|
||||
modifyRequest(request: ModifyRequestInput!): ModifyRequestResult!
|
||||
cancelRequest(id: ID!): CancelRequestResult!
|
||||
modifyResponse(response: ModifyResponseInput!): ModifyResponseResult!
|
||||
cancelResponse(requestID: ID!): CancelResponseResult!
|
||||
updateInterceptSettings(
|
||||
input: UpdateInterceptSettingsInput!
|
||||
): InterceptSettings!
|
||||
}
|
||||
|
||||
enum HttpMethod {
|
||||
|
@ -11,6 +11,7 @@ import (
|
||||
|
||||
"github.com/oklog/ulid"
|
||||
|
||||
"github.com/dstotijn/hetty/pkg/proxy/intercept"
|
||||
"github.com/dstotijn/hetty/pkg/reqlog"
|
||||
"github.com/dstotijn/hetty/pkg/scope"
|
||||
"github.com/dstotijn/hetty/pkg/search"
|
||||
@ -33,10 +34,12 @@ type Service interface {
|
||||
SetScopeRules(ctx context.Context, rules []scope.Rule) error
|
||||
SetRequestLogFindFilter(ctx context.Context, filter reqlog.FindRequestsFilter) error
|
||||
SetSenderRequestFindFilter(ctx context.Context, filter sender.FindRequestsFilter) error
|
||||
UpdateInterceptSettings(ctx context.Context, settings intercept.Settings) error
|
||||
}
|
||||
|
||||
type service struct {
|
||||
repo Repository
|
||||
interceptSvc *intercept.Service
|
||||
reqLogSvc reqlog.Service
|
||||
senderSvc sender.Service
|
||||
scope *scope.Scope
|
||||
@ -53,13 +56,22 @@ type Project struct {
|
||||
}
|
||||
|
||||
type Settings struct {
|
||||
// Request log settings
|
||||
ReqLogBypassOutOfScope bool
|
||||
ReqLogOnlyFindInScope bool
|
||||
ReqLogSearchExpr search.Expression
|
||||
|
||||
// Intercept settings
|
||||
InterceptRequests bool
|
||||
InterceptResponses bool
|
||||
InterceptRequestFilter search.Expression
|
||||
InterceptResponseFilter search.Expression
|
||||
|
||||
// Sender settings
|
||||
SenderOnlyFindInScope bool
|
||||
SenderSearchExpr search.Expression
|
||||
|
||||
// Scope settings
|
||||
ScopeRules []scope.Rule
|
||||
}
|
||||
|
||||
@ -73,19 +85,21 @@ var (
|
||||
var nameRegexp = regexp.MustCompile(`^[\w\d\s]+$`)
|
||||
|
||||
type Config struct {
|
||||
Repository Repository
|
||||
ReqLogService reqlog.Service
|
||||
SenderService sender.Service
|
||||
Scope *scope.Scope
|
||||
Repository Repository
|
||||
InterceptService *intercept.Service
|
||||
ReqLogService reqlog.Service
|
||||
SenderService sender.Service
|
||||
Scope *scope.Scope
|
||||
}
|
||||
|
||||
// NewService returns a new Service.
|
||||
func NewService(cfg Config) (Service, error) {
|
||||
return &service{
|
||||
repo: cfg.Repository,
|
||||
reqLogSvc: cfg.ReqLogService,
|
||||
senderSvc: cfg.SenderService,
|
||||
scope: cfg.Scope,
|
||||
repo: cfg.Repository,
|
||||
interceptSvc: cfg.InterceptService,
|
||||
reqLogSvc: cfg.ReqLogService,
|
||||
senderSvc: cfg.SenderService,
|
||||
scope: cfg.Scope,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@ -120,6 +134,12 @@ func (svc *service) CloseProject() error {
|
||||
svc.reqLogSvc.SetActiveProjectID(ulid.ULID{})
|
||||
svc.reqLogSvc.SetBypassOutOfScopeRequests(false)
|
||||
svc.reqLogSvc.SetFindReqsFilter(reqlog.FindRequestsFilter{})
|
||||
svc.interceptSvc.UpdateSettings(intercept.Settings{
|
||||
RequestsEnabled: false,
|
||||
ResponsesEnabled: false,
|
||||
RequestFilter: nil,
|
||||
ResponseFilter: nil,
|
||||
})
|
||||
svc.senderSvc.SetActiveProjectID(ulid.ULID{})
|
||||
svc.senderSvc.SetFindReqsFilter(sender.FindRequestsFilter{})
|
||||
svc.scope.SetRules(nil)
|
||||
@ -152,6 +172,7 @@ func (svc *service) OpenProject(ctx context.Context, projectID ulid.ULID) (Proje
|
||||
|
||||
svc.activeProjectID = project.ID
|
||||
|
||||
// Request log settings.
|
||||
svc.reqLogSvc.SetFindReqsFilter(reqlog.FindRequestsFilter{
|
||||
ProjectID: project.ID,
|
||||
OnlyInScope: project.Settings.ReqLogOnlyFindInScope,
|
||||
@ -160,6 +181,15 @@ func (svc *service) OpenProject(ctx context.Context, projectID ulid.ULID) (Proje
|
||||
svc.reqLogSvc.SetBypassOutOfScopeRequests(project.Settings.ReqLogBypassOutOfScope)
|
||||
svc.reqLogSvc.SetActiveProjectID(project.ID)
|
||||
|
||||
// Intercept settings.
|
||||
svc.interceptSvc.UpdateSettings(intercept.Settings{
|
||||
RequestsEnabled: project.Settings.InterceptRequests,
|
||||
ResponsesEnabled: project.Settings.InterceptResponses,
|
||||
RequestFilter: project.Settings.InterceptRequestFilter,
|
||||
ResponseFilter: project.Settings.InterceptResponseFilter,
|
||||
})
|
||||
|
||||
// Sender settings.
|
||||
svc.senderSvc.SetActiveProjectID(project.ID)
|
||||
svc.senderSvc.SetFindReqsFilter(sender.FindRequestsFilter{
|
||||
ProjectID: project.ID,
|
||||
@ -167,6 +197,7 @@ func (svc *service) OpenProject(ctx context.Context, projectID ulid.ULID) (Proje
|
||||
SearchExpr: project.Settings.SenderSearchExpr,
|
||||
})
|
||||
|
||||
// Scope settings.
|
||||
svc.scope.SetRules(project.Settings.ScopeRules)
|
||||
|
||||
return project, nil
|
||||
@ -264,3 +295,24 @@ func (svc *service) SetSenderRequestFindFilter(ctx context.Context, filter sende
|
||||
func (svc *service) IsProjectActive(projectID ulid.ULID) bool {
|
||||
return projectID.Compare(svc.activeProjectID) == 0
|
||||
}
|
||||
|
||||
func (svc *service) UpdateInterceptSettings(ctx context.Context, settings intercept.Settings) error {
|
||||
project, err := svc.ActiveProject(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
project.Settings.InterceptRequests = settings.RequestsEnabled
|
||||
project.Settings.InterceptResponses = settings.ResponsesEnabled
|
||||
project.Settings.InterceptRequestFilter = settings.RequestFilter
|
||||
project.Settings.InterceptResponseFilter = settings.ResponseFilter
|
||||
|
||||
err = svc.repo.UpsertProject(ctx, project)
|
||||
if err != nil {
|
||||
return fmt.Errorf("proj: failed to update project: %w", err)
|
||||
}
|
||||
|
||||
svc.interceptSvc.UpdateSettings(settings)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
35
pkg/proxy/gzip.go
Normal file
35
pkg/proxy/gzip.go
Normal file
@ -0,0 +1,35 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
func gunzipResponseBody(res *http.Response) error {
|
||||
if res.Header.Get("Content-Encoding") != "gzip" {
|
||||
return nil
|
||||
}
|
||||
|
||||
gzipReader, err := gzip.NewReader(res.Body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("proxy: could not create gzip reader: %w", err)
|
||||
}
|
||||
defer gzipReader.Close()
|
||||
|
||||
buf := &bytes.Buffer{}
|
||||
|
||||
//nolint:gosec
|
||||
if _, err := io.Copy(buf, gzipReader); err != nil {
|
||||
return fmt.Errorf("proxy: could not read gzipped response body: %w", err)
|
||||
}
|
||||
|
||||
res.Body = io.NopCloser(buf)
|
||||
res.Header.Del("Content-Encoding")
|
||||
res.Header.Set("Content-Length", fmt.Sprint(buf.Len()))
|
||||
res.ContentLength = int64(buf.Len())
|
||||
|
||||
return nil
|
||||
}
|
395
pkg/proxy/intercept/filter.go
Normal file
395
pkg/proxy/intercept/filter.go
Normal file
@ -0,0 +1,395 @@
|
||||
package intercept
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/dstotijn/hetty/pkg/scope"
|
||||
"github.com/dstotijn/hetty/pkg/search"
|
||||
)
|
||||
|
||||
//nolint:unparam
|
||||
var reqFilterKeyFns = map[string]func(req *http.Request) (string, error){
|
||||
"proto": func(req *http.Request) (string, error) { return req.Proto, nil },
|
||||
"url": func(req *http.Request) (string, error) {
|
||||
if req.URL == nil {
|
||||
return "", nil
|
||||
}
|
||||
return req.URL.String(), nil
|
||||
},
|
||||
"method": func(req *http.Request) (string, error) { return req.Method, nil },
|
||||
"body": func(req *http.Request) (string, error) {
|
||||
if req.Body == nil {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(req.Body)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
req.Body = ioutil.NopCloser(bytes.NewBuffer(body))
|
||||
return string(body), nil
|
||||
},
|
||||
}
|
||||
|
||||
//nolint:unparam
|
||||
var resFilterKeyFns = map[string]func(res *http.Response) (string, error){
|
||||
"proto": func(res *http.Response) (string, error) { return res.Proto, nil },
|
||||
"statusCode": func(res *http.Response) (string, error) { return strconv.Itoa(res.StatusCode), nil },
|
||||
"statusReason": func(res *http.Response) (string, error) {
|
||||
statusReasonSubs := strings.SplitN(res.Status, " ", 2)
|
||||
|
||||
if len(statusReasonSubs) != 2 {
|
||||
return "", fmt.Errorf("invalid response status %q", res.Status)
|
||||
}
|
||||
return statusReasonSubs[1], nil
|
||||
},
|
||||
"body": func(res *http.Response) (string, error) {
|
||||
if res.Body == nil {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
res.Body = ioutil.NopCloser(bytes.NewBuffer(body))
|
||||
|
||||
return string(body), nil
|
||||
},
|
||||
}
|
||||
|
||||
// MatchRequestFilter returns true if an HTTP request matches the request filter expression.
|
||||
func MatchRequestFilter(req *http.Request, expr search.Expression) (bool, error) {
|
||||
switch e := expr.(type) {
|
||||
case search.PrefixExpression:
|
||||
return matchReqPrefixExpr(req, e)
|
||||
case search.InfixExpression:
|
||||
return matchReqInfixExpr(req, e)
|
||||
case search.StringLiteral:
|
||||
return matchReqStringLiteral(req, e)
|
||||
default:
|
||||
return false, fmt.Errorf("expression type (%T) not supported", expr)
|
||||
}
|
||||
}
|
||||
|
||||
func matchReqPrefixExpr(req *http.Request, expr search.PrefixExpression) (bool, error) {
|
||||
switch expr.Operator {
|
||||
case search.TokOpNot:
|
||||
match, err := MatchRequestFilter(req, expr.Right)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return !match, nil
|
||||
default:
|
||||
return false, errors.New("operator is not supported")
|
||||
}
|
||||
}
|
||||
|
||||
func matchReqInfixExpr(req *http.Request, expr search.InfixExpression) (bool, error) {
|
||||
switch expr.Operator {
|
||||
case search.TokOpAnd:
|
||||
left, err := MatchRequestFilter(req, expr.Left)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
right, err := MatchRequestFilter(req, expr.Right)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return left && right, nil
|
||||
case search.TokOpOr:
|
||||
left, err := MatchRequestFilter(req, expr.Left)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
right, err := MatchRequestFilter(req, expr.Right)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return left || right, nil
|
||||
}
|
||||
|
||||
left, ok := expr.Left.(search.StringLiteral)
|
||||
if !ok {
|
||||
return false, errors.New("left operand must be a string literal")
|
||||
}
|
||||
|
||||
leftVal, err := getMappedStringLiteralFromReq(req, left.Value)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to get string literal from request for left operand: %w", err)
|
||||
}
|
||||
|
||||
if expr.Operator == search.TokOpRe || expr.Operator == search.TokOpNotRe {
|
||||
right, ok := expr.Right.(search.RegexpLiteral)
|
||||
if !ok {
|
||||
return false, errors.New("right operand must be a regular expression")
|
||||
}
|
||||
|
||||
switch expr.Operator {
|
||||
case search.TokOpRe:
|
||||
return right.MatchString(leftVal), nil
|
||||
case search.TokOpNotRe:
|
||||
return !right.MatchString(leftVal), nil
|
||||
}
|
||||
}
|
||||
|
||||
right, ok := expr.Right.(search.StringLiteral)
|
||||
if !ok {
|
||||
return false, errors.New("right operand must be a string literal")
|
||||
}
|
||||
|
||||
rightVal, err := getMappedStringLiteralFromReq(req, right.Value)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to get string literal from request for right operand: %w", err)
|
||||
}
|
||||
|
||||
switch expr.Operator {
|
||||
case search.TokOpEq:
|
||||
return leftVal == rightVal, nil
|
||||
case search.TokOpNotEq:
|
||||
return leftVal != rightVal, nil
|
||||
case search.TokOpGt:
|
||||
// TODO(?) attempt to parse as int.
|
||||
return leftVal > rightVal, nil
|
||||
case search.TokOpLt:
|
||||
// TODO(?) attempt to parse as int.
|
||||
return leftVal < rightVal, nil
|
||||
case search.TokOpGtEq:
|
||||
// TODO(?) attempt to parse as int.
|
||||
return leftVal >= rightVal, nil
|
||||
case search.TokOpLtEq:
|
||||
// TODO(?) attempt to parse as int.
|
||||
return leftVal <= rightVal, nil
|
||||
default:
|
||||
return false, errors.New("unsupported operator")
|
||||
}
|
||||
}
|
||||
|
||||
func getMappedStringLiteralFromReq(req *http.Request, s string) (string, error) {
|
||||
fn, ok := reqFilterKeyFns[s]
|
||||
if ok {
|
||||
return fn(req)
|
||||
}
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func matchReqStringLiteral(req *http.Request, strLiteral search.StringLiteral) (bool, error) {
|
||||
for _, fn := range reqFilterKeyFns {
|
||||
value, err := fn(req)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if strings.Contains(strings.ToLower(value), strings.ToLower(strLiteral.Value)) {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func MatchRequestScope(req *http.Request, s *scope.Scope) (bool, error) {
|
||||
for _, rule := range s.Rules() {
|
||||
if rule.URL != nil && req.URL != nil {
|
||||
if matches := rule.URL.MatchString(req.URL.String()); matches {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
|
||||
for key, values := range req.Header {
|
||||
var keyMatches, valueMatches bool
|
||||
|
||||
if rule.Header.Key != nil {
|
||||
if matches := rule.Header.Key.MatchString(key); matches {
|
||||
keyMatches = true
|
||||
}
|
||||
}
|
||||
|
||||
if rule.Header.Value != nil {
|
||||
for _, value := range values {
|
||||
if matches := rule.Header.Value.MatchString(value); matches {
|
||||
valueMatches = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// When only key or value is set, match on whatever is set.
|
||||
// When both are set, both must match.
|
||||
switch {
|
||||
case rule.Header.Key != nil && rule.Header.Value == nil && keyMatches:
|
||||
return true, nil
|
||||
case rule.Header.Key == nil && rule.Header.Value != nil && valueMatches:
|
||||
return true, nil
|
||||
case rule.Header.Key != nil && rule.Header.Value != nil && keyMatches && valueMatches:
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
|
||||
if rule.Body != nil {
|
||||
body, err := io.ReadAll(req.Body)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to read request body: %w", err)
|
||||
}
|
||||
|
||||
req.Body = ioutil.NopCloser(bytes.NewBuffer(body))
|
||||
|
||||
if matches := rule.Body.Match(body); matches {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// MatchResponseFilter returns true if an HTTP response matches the response filter expression.
|
||||
func MatchResponseFilter(res *http.Response, expr search.Expression) (bool, error) {
|
||||
switch e := expr.(type) {
|
||||
case search.PrefixExpression:
|
||||
return matchResPrefixExpr(res, e)
|
||||
case search.InfixExpression:
|
||||
return matchResInfixExpr(res, e)
|
||||
case search.StringLiteral:
|
||||
return matchResStringLiteral(res, e)
|
||||
default:
|
||||
return false, fmt.Errorf("expression type (%T) not supported", expr)
|
||||
}
|
||||
}
|
||||
|
||||
func matchResPrefixExpr(res *http.Response, expr search.PrefixExpression) (bool, error) {
|
||||
switch expr.Operator {
|
||||
case search.TokOpNot:
|
||||
match, err := MatchResponseFilter(res, expr.Right)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return !match, nil
|
||||
default:
|
||||
return false, errors.New("operator is not supported")
|
||||
}
|
||||
}
|
||||
|
||||
func matchResInfixExpr(res *http.Response, expr search.InfixExpression) (bool, error) {
|
||||
switch expr.Operator {
|
||||
case search.TokOpAnd:
|
||||
left, err := MatchResponseFilter(res, expr.Left)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
right, err := MatchResponseFilter(res, expr.Right)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return left && right, nil
|
||||
case search.TokOpOr:
|
||||
left, err := MatchResponseFilter(res, expr.Left)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
right, err := MatchResponseFilter(res, expr.Right)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return left || right, nil
|
||||
}
|
||||
|
||||
left, ok := expr.Left.(search.StringLiteral)
|
||||
if !ok {
|
||||
return false, errors.New("left operand must be a string literal")
|
||||
}
|
||||
|
||||
leftVal, err := getMappedStringLiteralFromRes(res, left.Value)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to get string literal from response for left operand: %w", err)
|
||||
}
|
||||
|
||||
if expr.Operator == search.TokOpRe || expr.Operator == search.TokOpNotRe {
|
||||
right, ok := expr.Right.(search.RegexpLiteral)
|
||||
if !ok {
|
||||
return false, errors.New("right operand must be a regular expression")
|
||||
}
|
||||
|
||||
switch expr.Operator {
|
||||
case search.TokOpRe:
|
||||
return right.MatchString(leftVal), nil
|
||||
case search.TokOpNotRe:
|
||||
return !right.MatchString(leftVal), nil
|
||||
}
|
||||
}
|
||||
|
||||
right, ok := expr.Right.(search.StringLiteral)
|
||||
if !ok {
|
||||
return false, errors.New("right operand must be a string literal")
|
||||
}
|
||||
|
||||
rightVal, err := getMappedStringLiteralFromRes(res, right.Value)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to get string literal from response for right operand: %w", err)
|
||||
}
|
||||
|
||||
switch expr.Operator {
|
||||
case search.TokOpEq:
|
||||
return leftVal == rightVal, nil
|
||||
case search.TokOpNotEq:
|
||||
return leftVal != rightVal, nil
|
||||
case search.TokOpGt:
|
||||
// TODO(?) attempt to parse as int.
|
||||
return leftVal > rightVal, nil
|
||||
case search.TokOpLt:
|
||||
// TODO(?) attempt to parse as int.
|
||||
return leftVal < rightVal, nil
|
||||
case search.TokOpGtEq:
|
||||
// TODO(?) attempt to parse as int.
|
||||
return leftVal >= rightVal, nil
|
||||
case search.TokOpLtEq:
|
||||
// TODO(?) attempt to parse as int.
|
||||
return leftVal <= rightVal, nil
|
||||
default:
|
||||
return false, errors.New("unsupported operator")
|
||||
}
|
||||
}
|
||||
|
||||
func getMappedStringLiteralFromRes(res *http.Response, s string) (string, error) {
|
||||
fn, ok := resFilterKeyFns[s]
|
||||
if ok {
|
||||
return fn(res)
|
||||
}
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func matchResStringLiteral(res *http.Response, strLiteral search.StringLiteral) (bool, error) {
|
||||
for _, fn := range resFilterKeyFns {
|
||||
value, err := fn(res)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if strings.Contains(strings.ToLower(value), strings.ToLower(strLiteral.Value)) {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
452
pkg/proxy/intercept/intercept.go
Normal file
452
pkg/proxy/intercept/intercept.go
Normal file
@ -0,0 +1,452 @@
|
||||
package intercept
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sort"
|
||||
"sync"
|
||||
|
||||
"github.com/oklog/ulid"
|
||||
|
||||
"github.com/dstotijn/hetty/pkg/log"
|
||||
"github.com/dstotijn/hetty/pkg/proxy"
|
||||
"github.com/dstotijn/hetty/pkg/search"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrRequestAborted = errors.New("intercept: request was aborted")
|
||||
ErrRequestNotFound = errors.New("intercept: request not found")
|
||||
ErrRequestDone = errors.New("intercept: request is done")
|
||||
ErrResponseNotFound = errors.New("intercept: response not found")
|
||||
)
|
||||
|
||||
type contextKey int
|
||||
|
||||
const interceptResponseKey contextKey = 0
|
||||
|
||||
// Request represents a server received HTTP request, alongside a channel for sending a modified version of it to the
|
||||
// routine that's awaiting it. Also contains a channel for receiving a cancellation signal.
|
||||
type Request struct {
|
||||
req *http.Request
|
||||
ch chan<- *http.Request
|
||||
done <-chan struct{}
|
||||
}
|
||||
|
||||
// Response represents an HTTP response from a proxied request, alongside a channel for sending a modified version of it
|
||||
// to the routine that's awaiting it. Also contains a channel for receiving a cancellation signal.
|
||||
type Response struct {
|
||||
res *http.Response
|
||||
ch chan<- *http.Response
|
||||
done <-chan struct{}
|
||||
}
|
||||
|
||||
type Item struct {
|
||||
Request *http.Request
|
||||
Response *http.Response
|
||||
}
|
||||
|
||||
type Service struct {
|
||||
reqMu *sync.RWMutex
|
||||
resMu *sync.RWMutex
|
||||
requests map[ulid.ULID]Request
|
||||
responses map[ulid.ULID]Response
|
||||
logger log.Logger
|
||||
|
||||
requestsEnabled bool
|
||||
responsesEnabled bool
|
||||
reqFilter search.Expression
|
||||
resFilter search.Expression
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
Logger log.Logger
|
||||
RequestsEnabled bool
|
||||
ResponsesEnabled bool
|
||||
RequestFilter search.Expression
|
||||
ResponseFilter search.Expression
|
||||
}
|
||||
|
||||
// RequestIDs implements sort.Interface.
|
||||
type RequestIDs []ulid.ULID
|
||||
|
||||
func NewService(cfg Config) *Service {
|
||||
s := &Service{
|
||||
reqMu: &sync.RWMutex{},
|
||||
resMu: &sync.RWMutex{},
|
||||
requests: make(map[ulid.ULID]Request),
|
||||
responses: make(map[ulid.ULID]Response),
|
||||
logger: cfg.Logger,
|
||||
requestsEnabled: cfg.RequestsEnabled,
|
||||
responsesEnabled: cfg.ResponsesEnabled,
|
||||
reqFilter: cfg.RequestFilter,
|
||||
resFilter: cfg.ResponseFilter,
|
||||
}
|
||||
|
||||
if s.logger == nil {
|
||||
s.logger = log.NewNopLogger()
|
||||
}
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
// RequestModifier is a proxy.RequestModifyMiddleware for intercepting HTTP requests.
|
||||
func (svc *Service) RequestModifier(next proxy.RequestModifyFunc) proxy.RequestModifyFunc {
|
||||
return func(req *http.Request) {
|
||||
// This is a blocking operation, that gets unblocked when either a modified request is returned or an error
|
||||
// (typically `context.Canceled`).
|
||||
modifiedReq, err := svc.InterceptRequest(req.Context(), req)
|
||||
|
||||
switch {
|
||||
case errors.Is(err, ErrRequestAborted):
|
||||
svc.logger.Debugw("Stopping intercept, request was aborted.")
|
||||
// Prevent further processing by replacing req.Context with a cancelled context value.
|
||||
// This will cause the http.Roundtripper in the `proxy` package to
|
||||
// handle this request as an error.
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
*req = *req.WithContext(ctx)
|
||||
case errors.Is(err, context.Canceled):
|
||||
svc.logger.Debugw("Stopping intercept, context was cancelled.")
|
||||
case err != nil:
|
||||
svc.logger.Errorw("Failed to intercept request.",
|
||||
"error", err)
|
||||
default:
|
||||
*req = *modifiedReq
|
||||
next(req)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// InterceptRequest adds an HTTP request to an array of pending intercepted requests, alongside channels used for
|
||||
// sending a cancellation signal and receiving a modified request. It's safe for concurrent use.
|
||||
func (svc *Service) InterceptRequest(ctx context.Context, req *http.Request) (*http.Request, error) {
|
||||
reqID, ok := proxy.RequestIDFromContext(ctx)
|
||||
if !ok {
|
||||
svc.logger.Errorw("Failed to intercept: context doesn't have an ID.")
|
||||
return req, nil
|
||||
}
|
||||
|
||||
if !svc.requestsEnabled {
|
||||
// If request intercept is disabled, return the incoming request as-is.
|
||||
svc.logger.Debugw("Bypassed request interception: feature disabled.")
|
||||
return req, nil
|
||||
}
|
||||
|
||||
if svc.reqFilter != nil {
|
||||
match, err := MatchRequestFilter(req, svc.reqFilter)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("intercept: failed to match request rules for request (id: %v): %w",
|
||||
reqID.String(), err,
|
||||
)
|
||||
}
|
||||
|
||||
if !match {
|
||||
svc.logger.Debugw("Bypassed request interception: request rules don't match.")
|
||||
return req, nil
|
||||
}
|
||||
}
|
||||
|
||||
ch := make(chan *http.Request)
|
||||
done := make(chan struct{})
|
||||
|
||||
svc.reqMu.Lock()
|
||||
svc.requests[reqID] = Request{
|
||||
req: req,
|
||||
ch: ch,
|
||||
done: done,
|
||||
}
|
||||
svc.reqMu.Unlock()
|
||||
|
||||
// Whatever happens next (modified request returned, or a context cancelled error), any blocked channel senders
|
||||
// should be unblocked, and the request should be removed from the requests queue.
|
||||
defer func() {
|
||||
close(done)
|
||||
svc.reqMu.Lock()
|
||||
defer svc.reqMu.Unlock()
|
||||
delete(svc.requests, reqID)
|
||||
}()
|
||||
|
||||
select {
|
||||
case modReq := <-ch:
|
||||
if modReq == nil {
|
||||
return nil, ErrRequestAborted
|
||||
}
|
||||
|
||||
return modReq, nil
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
// ModifyRequest sends a modified HTTP request to the related channel, or returns ErrRequestDone when the request was
|
||||
// cancelled. It's safe for concurrent use.
|
||||
func (svc *Service) ModifyRequest(reqID ulid.ULID, modReq *http.Request, modifyResponse *bool) error {
|
||||
svc.reqMu.RLock()
|
||||
req, ok := svc.requests[reqID]
|
||||
svc.reqMu.RUnlock()
|
||||
|
||||
if !ok {
|
||||
return ErrRequestNotFound
|
||||
}
|
||||
|
||||
*modReq = *modReq.WithContext(req.req.Context())
|
||||
if modifyResponse != nil {
|
||||
*modReq = *modReq.WithContext(WithInterceptResponse(modReq.Context(), *modifyResponse))
|
||||
}
|
||||
|
||||
select {
|
||||
case <-req.done:
|
||||
return ErrRequestDone
|
||||
case req.ch <- modReq:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// CancelRequest ensures an intercepted request is dropped.
|
||||
func (svc *Service) CancelRequest(reqID ulid.ULID) error {
|
||||
return svc.ModifyRequest(reqID, nil, nil)
|
||||
}
|
||||
|
||||
func (svc *Service) ClearRequests() {
|
||||
svc.reqMu.Lock()
|
||||
defer svc.reqMu.Unlock()
|
||||
|
||||
for _, req := range svc.requests {
|
||||
select {
|
||||
case <-req.done:
|
||||
case req.ch <- nil:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (svc *Service) ClearResponses() {
|
||||
svc.resMu.Lock()
|
||||
defer svc.resMu.Unlock()
|
||||
|
||||
for _, res := range svc.responses {
|
||||
select {
|
||||
case <-res.done:
|
||||
case res.ch <- nil:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Items returns a list of pending items (requests and responses). It's safe for concurrent use.
|
||||
func (svc *Service) Items() []Item {
|
||||
svc.reqMu.RLock()
|
||||
defer svc.reqMu.RUnlock()
|
||||
|
||||
svc.resMu.RLock()
|
||||
defer svc.resMu.RUnlock()
|
||||
|
||||
reqIDs := make([]ulid.ULID, 0, len(svc.requests)+len(svc.responses))
|
||||
|
||||
for id := range svc.requests {
|
||||
reqIDs = append(reqIDs, id)
|
||||
}
|
||||
|
||||
for id := range svc.responses {
|
||||
reqIDs = append(reqIDs, id)
|
||||
}
|
||||
|
||||
sort.Sort(RequestIDs(reqIDs))
|
||||
|
||||
items := make([]Item, len(reqIDs))
|
||||
|
||||
for i, id := range reqIDs {
|
||||
item := Item{}
|
||||
|
||||
if req, ok := svc.requests[id]; ok {
|
||||
item.Request = req.req
|
||||
}
|
||||
|
||||
if res, ok := svc.responses[id]; ok {
|
||||
item.Response = res.res
|
||||
}
|
||||
|
||||
items[i] = item
|
||||
}
|
||||
|
||||
return items
|
||||
}
|
||||
|
||||
func (svc *Service) UpdateSettings(settings Settings) {
|
||||
// When updating from requests `enabled` -> `disabled`, clear any pending reqs.
|
||||
if svc.requestsEnabled && !settings.RequestsEnabled {
|
||||
svc.ClearRequests()
|
||||
}
|
||||
|
||||
// When updating from responses `enabled` -> `disabled`, clear any pending responses.
|
||||
if svc.responsesEnabled && !settings.ResponsesEnabled {
|
||||
svc.ClearResponses()
|
||||
}
|
||||
|
||||
svc.requestsEnabled = settings.RequestsEnabled
|
||||
svc.responsesEnabled = settings.ResponsesEnabled
|
||||
svc.reqFilter = settings.RequestFilter
|
||||
svc.resFilter = settings.ResponseFilter
|
||||
}
|
||||
|
||||
// ItemByID returns an intercepted item (request and possible response) by ID. It's safe for concurrent use.
|
||||
func (svc *Service) ItemByID(id ulid.ULID) (Item, error) {
|
||||
svc.reqMu.RLock()
|
||||
defer svc.reqMu.RUnlock()
|
||||
|
||||
svc.resMu.RLock()
|
||||
defer svc.resMu.RUnlock()
|
||||
|
||||
item := Item{}
|
||||
found := false
|
||||
|
||||
if req, ok := svc.requests[id]; ok {
|
||||
item.Request = req.req
|
||||
found = true
|
||||
}
|
||||
|
||||
if res, ok := svc.responses[id]; ok {
|
||||
item.Response = res.res
|
||||
found = true
|
||||
}
|
||||
|
||||
if !found {
|
||||
return Item{}, ErrRequestNotFound
|
||||
}
|
||||
|
||||
return item, nil
|
||||
}
|
||||
|
||||
func (ids RequestIDs) Len() int {
|
||||
return len(ids)
|
||||
}
|
||||
|
||||
func (ids RequestIDs) Less(i, j int) bool {
|
||||
return ids[i].Compare(ids[j]) == -1
|
||||
}
|
||||
|
||||
func (ids RequestIDs) Swap(i, j int) {
|
||||
ids[i], ids[j] = ids[j], ids[i]
|
||||
}
|
||||
|
||||
func WithInterceptResponse(ctx context.Context, value bool) context.Context {
|
||||
return context.WithValue(ctx, interceptResponseKey, value)
|
||||
}
|
||||
|
||||
func ShouldInterceptResponseFromContext(ctx context.Context) (bool, bool) {
|
||||
shouldIntercept, ok := ctx.Value(interceptResponseKey).(bool)
|
||||
return shouldIntercept, ok
|
||||
}
|
||||
|
||||
// ResponseModifier is a proxy.ResponseModifyMiddleware for intercepting HTTP responses.
|
||||
func (svc *Service) ResponseModifier(next proxy.ResponseModifyFunc) proxy.ResponseModifyFunc {
|
||||
return func(res *http.Response) error {
|
||||
// This is a blocking operation, that gets unblocked when either a modified response is returned or an error.
|
||||
//nolint:bodyclose
|
||||
modifiedRes, err := svc.InterceptResponse(res.Request.Context(), res)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to intercept response: %w", err)
|
||||
}
|
||||
|
||||
*res = *modifiedRes
|
||||
|
||||
return next(res)
|
||||
}
|
||||
}
|
||||
|
||||
// InterceptResponse adds an HTTP response to an array of pending intercepted responses, alongside channels used for
|
||||
// sending a cancellation signal and receiving a modified response. It's safe for concurrent use.
|
||||
func (svc *Service) InterceptResponse(ctx context.Context, res *http.Response) (*http.Response, error) {
|
||||
reqID, ok := proxy.RequestIDFromContext(ctx)
|
||||
if !ok {
|
||||
svc.logger.Errorw("Failed to intercept: context doesn't have an ID.")
|
||||
return res, nil
|
||||
}
|
||||
|
||||
shouldIntercept, ok := ShouldInterceptResponseFromContext(ctx)
|
||||
if ok && !shouldIntercept {
|
||||
// If the related request explicitly disabled response intercept, return the response as-is.
|
||||
svc.logger.Debugw("Bypassed response interception: related request explicitly disabled response intercept.")
|
||||
return res, nil
|
||||
}
|
||||
|
||||
// If global response intercept is disabled and interception is *not* explicitly enabled for this response: bypass.
|
||||
if !svc.responsesEnabled && !(ok && shouldIntercept) {
|
||||
svc.logger.Debugw("Bypassed response interception: feature disabled.")
|
||||
return res, nil
|
||||
}
|
||||
|
||||
if svc.resFilter != nil {
|
||||
match, err := MatchResponseFilter(res, svc.resFilter)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("intercept: failed to match response rules for response (id: %v): %w",
|
||||
reqID.String(), err,
|
||||
)
|
||||
}
|
||||
|
||||
if !match {
|
||||
svc.logger.Debugw("Bypassed response interception: response rules don't match.")
|
||||
return res, nil
|
||||
}
|
||||
}
|
||||
|
||||
ch := make(chan *http.Response)
|
||||
done := make(chan struct{})
|
||||
|
||||
svc.resMu.Lock()
|
||||
svc.responses[reqID] = Response{
|
||||
res: res,
|
||||
ch: ch,
|
||||
done: done,
|
||||
}
|
||||
svc.resMu.Unlock()
|
||||
|
||||
// Whatever happens next (modified response returned, or a context cancelled error), any blocked channel senders
|
||||
// should be unblocked, and the response should be removed from the responses queue.
|
||||
defer func() {
|
||||
close(done)
|
||||
svc.resMu.Lock()
|
||||
defer svc.resMu.Unlock()
|
||||
delete(svc.responses, reqID)
|
||||
}()
|
||||
|
||||
select {
|
||||
case modRes := <-ch:
|
||||
if modRes == nil {
|
||||
return nil, ErrRequestAborted
|
||||
}
|
||||
|
||||
return modRes, nil
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
// ModifyResponse sends a modified HTTP response to the related channel, or returns ErrRequestDone when the related
|
||||
// request was cancelled. It's safe for concurrent use.
|
||||
func (svc *Service) ModifyResponse(reqID ulid.ULID, modRes *http.Response) error {
|
||||
svc.resMu.RLock()
|
||||
res, ok := svc.responses[reqID]
|
||||
svc.resMu.RUnlock()
|
||||
|
||||
if !ok {
|
||||
return ErrRequestNotFound
|
||||
}
|
||||
|
||||
if modRes != nil {
|
||||
modRes.Request = res.res.Request
|
||||
}
|
||||
|
||||
select {
|
||||
case <-res.done:
|
||||
return ErrRequestDone
|
||||
case res.ch <- modRes:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// CancelResponse ensures an intercepted response is dropped.
|
||||
func (svc *Service) CancelResponse(reqID ulid.ULID) error {
|
||||
return svc.ModifyResponse(reqID, nil)
|
||||
}
|
270
pkg/proxy/intercept/intercept_test.go
Normal file
270
pkg/proxy/intercept/intercept_test.go
Normal file
@ -0,0 +1,270 @@
|
||||
package intercept_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/oklog/ulid"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/dstotijn/hetty/pkg/proxy"
|
||||
"github.com/dstotijn/hetty/pkg/proxy/intercept"
|
||||
)
|
||||
|
||||
//nolint:gosec
|
||||
var ulidEntropy = rand.New(rand.NewSource(time.Now().UnixNano()))
|
||||
|
||||
func TestRequestModifier(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("modify request that's not found", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger, _ := zap.NewDevelopment()
|
||||
svc := intercept.NewService(intercept.Config{
|
||||
Logger: logger.Sugar(),
|
||||
RequestsEnabled: true,
|
||||
ResponsesEnabled: false,
|
||||
})
|
||||
|
||||
reqID := ulid.MustNew(ulid.Timestamp(time.Now()), ulidEntropy)
|
||||
|
||||
err := svc.ModifyRequest(reqID, nil, nil)
|
||||
if !errors.Is(err, intercept.ErrRequestNotFound) {
|
||||
t.Fatalf("expected `intercept.ErrRequestNotFound`, got: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("modify request that's done", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger, _ := zap.NewDevelopment()
|
||||
svc := intercept.NewService(intercept.Config{
|
||||
Logger: logger.Sugar(),
|
||||
RequestsEnabled: true,
|
||||
ResponsesEnabled: false,
|
||||
})
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
req := httptest.NewRequest("GET", "https://example.com/foo", nil)
|
||||
reqID := ulid.MustNew(ulid.Timestamp(time.Now()), ulidEntropy)
|
||||
*req = *req.WithContext(ctx)
|
||||
*req = *req.WithContext(proxy.WithRequestID(req.Context(), reqID))
|
||||
|
||||
next := func(req *http.Request) {}
|
||||
go svc.RequestModifier(next)(req)
|
||||
|
||||
// Wait shortly, to allow the req modifier goroutine to add `req` to the
|
||||
// array of intercepted reqs.
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
cancel()
|
||||
|
||||
modReq := req.Clone(req.Context())
|
||||
modReq.Header.Set("X-Foo", "bar")
|
||||
|
||||
err := svc.ModifyRequest(reqID, modReq, nil)
|
||||
if !errors.Is(err, intercept.ErrRequestDone) {
|
||||
t.Fatalf("expected `intercept.ErrRequestDone`, got: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("modify intercepted request", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
req := httptest.NewRequest("GET", "https://example.com/foo", nil)
|
||||
req.Header.Set("X-Foo", "foo")
|
||||
|
||||
reqID := ulid.MustNew(ulid.Timestamp(time.Now()), ulidEntropy)
|
||||
*req = *req.WithContext(proxy.WithRequestID(req.Context(), reqID))
|
||||
|
||||
modReq := req.Clone(context.Background())
|
||||
modReq.Header.Set("X-Foo", "bar")
|
||||
|
||||
logger, _ := zap.NewDevelopment()
|
||||
svc := intercept.NewService(intercept.Config{
|
||||
Logger: logger.Sugar(),
|
||||
RequestsEnabled: true,
|
||||
ResponsesEnabled: false,
|
||||
})
|
||||
|
||||
var got *http.Request
|
||||
|
||||
next := func(req *http.Request) {
|
||||
got = req.Clone(context.Background())
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
|
||||
go func() {
|
||||
svc.RequestModifier(next)(req)
|
||||
wg.Done()
|
||||
}()
|
||||
|
||||
// Wait shortly, to allow the req modifier goroutine to add `req` to the
|
||||
// array of intercepted reqs.
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
err := svc.ModifyRequest(reqID, modReq, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
if got == nil {
|
||||
t.Fatal("expected `got` not to be nil")
|
||||
}
|
||||
|
||||
if exp := "bar"; exp != got.Header.Get("X-Foo") {
|
||||
t.Fatalf("incorrect modified request header value (expected: %v, got: %v)", exp, got.Header.Get("X-Foo"))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestResponseModifier(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("modify response that's not found", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger, _ := zap.NewDevelopment()
|
||||
svc := intercept.NewService(intercept.Config{
|
||||
Logger: logger.Sugar(),
|
||||
RequestsEnabled: false,
|
||||
ResponsesEnabled: true,
|
||||
})
|
||||
|
||||
reqID := ulid.MustNew(ulid.Timestamp(time.Now()), ulidEntropy)
|
||||
|
||||
err := svc.ModifyResponse(reqID, nil)
|
||||
if !errors.Is(err, intercept.ErrRequestNotFound) {
|
||||
t.Fatalf("expected `intercept.ErrRequestNotFound`, got: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("modify response of request that's done", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger, _ := zap.NewDevelopment()
|
||||
svc := intercept.NewService(intercept.Config{
|
||||
Logger: logger.Sugar(),
|
||||
RequestsEnabled: false,
|
||||
ResponsesEnabled: true,
|
||||
})
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
req := httptest.NewRequest("GET", "https://example.com/foo", nil)
|
||||
reqID := ulid.MustNew(ulid.Timestamp(time.Now()), ulidEntropy)
|
||||
*req = *req.WithContext(ctx)
|
||||
*req = *req.WithContext(proxy.WithRequestID(req.Context(), reqID))
|
||||
|
||||
res := &http.Response{
|
||||
Request: req,
|
||||
Header: make(http.Header),
|
||||
}
|
||||
res.Header.Add("X-Foo", "foo")
|
||||
|
||||
var modErr error
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
|
||||
next := func(res *http.Response) error { return nil }
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
modErr = svc.ResponseModifier(next)(res)
|
||||
}()
|
||||
|
||||
// Wait shortly, to allow the res modifier goroutine to add `res` to the
|
||||
// array of intercepted responses.
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
cancel()
|
||||
|
||||
modRes := *res
|
||||
modRes.Header = make(http.Header)
|
||||
modRes.Header.Set("X-Foo", "bar")
|
||||
|
||||
err := svc.ModifyResponse(reqID, &modRes)
|
||||
if !errors.Is(err, intercept.ErrRequestDone) {
|
||||
t.Fatalf("expected `intercept.ErrRequestDone`, got: %v", err)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
if !errors.Is(modErr, context.Canceled) {
|
||||
t.Fatalf("expected `context.Canceled`, got: %v", modErr)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("modify intercepted response", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
req := httptest.NewRequest("GET", "https://example.com/foo", nil)
|
||||
req.Header.Set("X-Foo", "foo")
|
||||
|
||||
reqID := ulid.MustNew(ulid.Timestamp(time.Now()), ulidEntropy)
|
||||
*req = *req.WithContext(proxy.WithRequestID(req.Context(), reqID))
|
||||
|
||||
res := &http.Response{
|
||||
Request: req,
|
||||
Header: make(http.Header),
|
||||
}
|
||||
res.Header.Add("X-Foo", "foo")
|
||||
|
||||
modRes := *res
|
||||
modRes.Header = make(http.Header)
|
||||
modRes.Header.Set("X-Foo", "bar")
|
||||
|
||||
logger, _ := zap.NewDevelopment()
|
||||
svc := intercept.NewService(intercept.Config{
|
||||
Logger: logger.Sugar(),
|
||||
RequestsEnabled: false,
|
||||
ResponsesEnabled: true,
|
||||
})
|
||||
|
||||
var gotHeader string
|
||||
|
||||
var next proxy.ResponseModifyFunc = func(res *http.Response) error {
|
||||
gotHeader = res.Header.Get("X-Foo")
|
||||
return nil
|
||||
}
|
||||
|
||||
var modErr error
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
|
||||
go func() {
|
||||
modErr = svc.ResponseModifier(next)(res)
|
||||
wg.Done()
|
||||
}()
|
||||
|
||||
// Wait shortly, to allow the res modifier goroutine to add `req` to the
|
||||
// array of intercepted reqs.
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
err := svc.ModifyResponse(reqID, &modRes)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
if modErr != nil {
|
||||
t.Fatalf("unexpected error: %v", modErr)
|
||||
}
|
||||
|
||||
if exp := "bar"; exp != gotHeader {
|
||||
t.Fatalf("incorrect modified request header value (expected: %v, got: %v)", exp, gotHeader)
|
||||
}
|
||||
})
|
||||
}
|
10
pkg/proxy/intercept/settings.go
Normal file
10
pkg/proxy/intercept/settings.go
Normal file
@ -0,0 +1,10 @@
|
||||
package intercept
|
||||
|
||||
import "github.com/dstotijn/hetty/pkg/search"
|
||||
|
||||
type Settings struct {
|
||||
RequestsEnabled bool
|
||||
ResponsesEnabled bool
|
||||
RequestFilter search.Expression
|
||||
ResponseFilter search.Expression
|
||||
}
|
@ -7,16 +7,24 @@ import (
|
||||
"crypto/x509"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/oklog/ulid"
|
||||
|
||||
"github.com/dstotijn/hetty/pkg/log"
|
||||
)
|
||||
|
||||
//nolint:gosec
|
||||
var ulidEntropy = rand.New(rand.NewSource(time.Now().UnixNano()))
|
||||
|
||||
type contextKey int
|
||||
|
||||
const ReqLogIDKey contextKey = 0
|
||||
const reqIDKey contextKey = 0
|
||||
|
||||
// Proxy implements http.Handler and offers MITM behaviour for modifying
|
||||
// HTTP requests and responses.
|
||||
@ -54,7 +62,25 @@ func NewProxy(cfg Config) (*Proxy, error) {
|
||||
p.logger = log.NewNopLogger()
|
||||
}
|
||||
|
||||
transport := &http.Transport{
|
||||
// Values taken from `http.DefaultTransport`.
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
DialContext: (&net.Dialer{
|
||||
Timeout: 30 * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
}).DialContext,
|
||||
ForceAttemptHTTP2: true,
|
||||
MaxIdleConns: 100,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
|
||||
// Non-default transport values.
|
||||
DisableCompression: true,
|
||||
}
|
||||
|
||||
p.handler = &httputil.ReverseProxy{
|
||||
Transport: transport,
|
||||
Director: p.modifyRequest,
|
||||
ModifyResponse: p.modifyResponse,
|
||||
ErrorHandler: p.errorHandler,
|
||||
@ -69,6 +95,10 @@ func (p *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
reqID := ulid.MustNew(ulid.Timestamp(time.Now()), ulidEntropy)
|
||||
ctx := context.WithValue(r.Context(), reqIDKey, reqID)
|
||||
*r = *r.WithContext(ctx)
|
||||
|
||||
p.handler.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
@ -91,6 +121,25 @@ func (p *Proxy) modifyRequest(r *http.Request) {
|
||||
// set this header.
|
||||
r.Header["X-Forwarded-For"] = nil
|
||||
|
||||
// Strip unsupported encodings.
|
||||
if acceptEncs := r.Header.Get("Accept-Encoding"); acceptEncs != "" {
|
||||
directives := strings.Split(acceptEncs, ",")
|
||||
updated := make([]string, 0, len(directives))
|
||||
|
||||
for _, directive := range directives {
|
||||
stripped := strings.TrimSpace(directive)
|
||||
if strings.HasPrefix(stripped, "*") || strings.HasPrefix(stripped, "gzip") {
|
||||
updated = append(updated, stripped)
|
||||
}
|
||||
}
|
||||
|
||||
if len(updated) == 0 {
|
||||
r.Header.Del("Accept-Encoding")
|
||||
} else {
|
||||
r.Header.Set("Accept-Encoding", strings.Join(updated, ", "))
|
||||
}
|
||||
}
|
||||
|
||||
fn := nopReqModifier
|
||||
|
||||
for i := len(p.reqModifiers) - 1; i >= 0; i-- {
|
||||
@ -103,6 +152,11 @@ func (p *Proxy) modifyRequest(r *http.Request) {
|
||||
func (p *Proxy) modifyResponse(res *http.Response) error {
|
||||
fn := nopResModifier
|
||||
|
||||
// TODO: Make decompressing gzip formatted response bodies a configurable project setting.
|
||||
if err := gunzipResponseBody(res); err != nil {
|
||||
return fmt.Errorf("proxy: failed to gunzip response body: %w", err)
|
||||
}
|
||||
|
||||
for i := len(p.resModifiers) - 1; i >= 0; i-- {
|
||||
fn = p.resModifiers[i](fn)
|
||||
}
|
||||
@ -110,6 +164,15 @@ func (p *Proxy) modifyResponse(res *http.Response) error {
|
||||
return fn(res)
|
||||
}
|
||||
|
||||
func WithRequestID(ctx context.Context, id ulid.ULID) context.Context {
|
||||
return context.WithValue(ctx, reqIDKey, id)
|
||||
}
|
||||
|
||||
func RequestIDFromContext(ctx context.Context) (ulid.ULID, bool) {
|
||||
id, ok := ctx.Value(reqIDKey).(ulid.ULID)
|
||||
return id, ok
|
||||
}
|
||||
|
||||
// handleConnect hijacks the incoming HTTP request and sets up an HTTP tunnel.
|
||||
// During the TLS handshake with the client, we use the proxy's CA config to
|
||||
// create a certificate on-the-fly.
|
||||
@ -170,13 +233,14 @@ func (p *Proxy) clientTLSConn(conn net.Conn) (*tls.Conn, error) {
|
||||
}
|
||||
|
||||
func (p *Proxy) errorHandler(w http.ResponseWriter, r *http.Request, err error) {
|
||||
if errors.Is(err, context.Canceled) {
|
||||
return
|
||||
switch {
|
||||
case !errors.Is(err, context.Canceled):
|
||||
p.logger.Errorw("Failed to proxy request.",
|
||||
"error", err)
|
||||
case errors.Is(err, context.Canceled):
|
||||
p.logger.Debugw("Proxy request was cancelled.")
|
||||
}
|
||||
|
||||
p.logger.Errorw("Failed to proxy request.",
|
||||
"error", err)
|
||||
|
||||
w.WriteHeader(http.StatusBadGateway)
|
||||
}
|
||||
|
||||
|
@ -2,16 +2,13 @@ package reqlog
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"github.com/oklog/ulid"
|
||||
|
||||
@ -23,16 +20,16 @@ import (
|
||||
|
||||
type contextKey int
|
||||
|
||||
const LogBypassedKey contextKey = 0
|
||||
const (
|
||||
LogBypassedKey contextKey = iota
|
||||
ReqLogIDKey
|
||||
)
|
||||
|
||||
var (
|
||||
ErrRequestNotFound = errors.New("reqlog: request not found")
|
||||
ErrProjectIDMustBeSet = errors.New("reqlog: project ID must be set")
|
||||
)
|
||||
|
||||
//nolint:gosec
|
||||
var ulidEntropy = rand.New(rand.NewSource(time.Now().UnixNano()))
|
||||
|
||||
type RequestLog struct {
|
||||
ID ulid.ULID
|
||||
ProjectID ulid.ULID
|
||||
@ -170,8 +167,14 @@ func (svc *service) RequestModifier(next proxy.RequestModifyFunc) proxy.RequestM
|
||||
return
|
||||
}
|
||||
|
||||
reqID, ok := proxy.RequestIDFromContext(req.Context())
|
||||
if !ok {
|
||||
svc.logger.Errorw("Bypassed logging: request doesn't have an ID.")
|
||||
return
|
||||
}
|
||||
|
||||
reqLog := RequestLog{
|
||||
ID: ulid.MustNew(ulid.Timestamp(time.Now()), ulidEntropy),
|
||||
ID: reqID,
|
||||
ProjectID: svc.activeProjectID,
|
||||
Method: clone.Method,
|
||||
URL: clone.URL,
|
||||
@ -191,7 +194,7 @@ func (svc *service) RequestModifier(next proxy.RequestModifyFunc) proxy.RequestM
|
||||
"reqLogID", reqLog.ID.String(),
|
||||
"url", reqLog.URL.String())
|
||||
|
||||
ctx := context.WithValue(req.Context(), proxy.ReqLogIDKey, reqLog.ID)
|
||||
ctx := context.WithValue(req.Context(), ReqLogIDKey, reqLog.ID)
|
||||
*req = *req.WithContext(ctx)
|
||||
}
|
||||
}
|
||||
@ -206,21 +209,23 @@ func (svc *service) ResponseModifier(next proxy.ResponseModifyFunc) proxy.Respon
|
||||
return nil
|
||||
}
|
||||
|
||||
reqLogID, ok := res.Request.Context().Value(proxy.ReqLogIDKey).(ulid.ULID)
|
||||
reqLogID, ok := res.Request.Context().Value(ReqLogIDKey).(ulid.ULID)
|
||||
if !ok {
|
||||
return errors.New("reqlog: request is missing ID")
|
||||
}
|
||||
|
||||
clone := *res
|
||||
|
||||
// TODO: Use io.LimitReader.
|
||||
body, err := ioutil.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("reqlog: could not read response body: %w", err)
|
||||
}
|
||||
if res.Body != nil {
|
||||
// TODO: Use io.LimitReader.
|
||||
body, err := io.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("reqlog: could not read response body: %w", err)
|
||||
}
|
||||
|
||||
res.Body = ioutil.NopCloser(bytes.NewBuffer(body))
|
||||
clone.Body = ioutil.NopCloser(bytes.NewBuffer(body))
|
||||
res.Body = io.NopCloser(bytes.NewBuffer(body))
|
||||
clone.Body = io.NopCloser(bytes.NewBuffer(body))
|
||||
}
|
||||
|
||||
go func() {
|
||||
if err := svc.storeResponse(context.Background(), reqLogID, &clone); err != nil {
|
||||
@ -261,23 +266,6 @@ func (svc *service) BypassOutOfScopeRequests() bool {
|
||||
}
|
||||
|
||||
func ParseHTTPResponse(res *http.Response) (ResponseLog, error) {
|
||||
if res.Header.Get("Content-Encoding") == "gzip" {
|
||||
gzipReader, err := gzip.NewReader(res.Body)
|
||||
if err != nil {
|
||||
return ResponseLog{}, fmt.Errorf("reqlog: could not create gzip reader: %w", err)
|
||||
}
|
||||
defer gzipReader.Close()
|
||||
|
||||
buf := &bytes.Buffer{}
|
||||
|
||||
//nolint:gosec
|
||||
if _, err := io.Copy(buf, gzipReader); err != nil {
|
||||
return ResponseLog{}, fmt.Errorf("reqlog: could not read gzipped response body: %w", err)
|
||||
}
|
||||
|
||||
res.Body = io.NopCloser(buf)
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
return ResponseLog{}, fmt.Errorf("reqlog: could not read body: %w", err)
|
||||
|
@ -41,6 +41,8 @@ func TestRequestModifier(t *testing.T) {
|
||||
}
|
||||
reqModFn := svc.RequestModifier(next)
|
||||
req := httptest.NewRequest("GET", "https://example.com/", strings.NewReader("bar"))
|
||||
reqID := ulid.MustNew(ulid.Timestamp(time.Now()), ulidEntropy)
|
||||
req = req.WithContext(proxy.WithRequestID(req.Context(), reqID))
|
||||
|
||||
reqModFn(req)
|
||||
|
||||
@ -88,7 +90,7 @@ func TestResponseModifier(t *testing.T) {
|
||||
|
||||
req := httptest.NewRequest("GET", "https://example.com/", strings.NewReader("bar"))
|
||||
reqLogID := ulid.MustNew(ulid.Timestamp(time.Now()), ulidEntropy)
|
||||
req = req.WithContext(context.WithValue(req.Context(), proxy.ReqLogIDKey, reqLogID))
|
||||
req = req.WithContext(context.WithValue(req.Context(), reqlog.ReqLogIDKey, reqLogID))
|
||||
|
||||
res := &http.Response{
|
||||
Request: req,
|
||||
|
@ -3,7 +3,6 @@ package reqlog
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
@ -100,7 +99,7 @@ func (reqLog RequestLog) matchInfixExpr(expr search.InfixExpression) (bool, erro
|
||||
leftVal := reqLog.getMappedStringLiteral(left.Value)
|
||||
|
||||
if expr.Operator == search.TokOpRe || expr.Operator == search.TokOpNotRe {
|
||||
right, ok := expr.Right.(*regexp.Regexp)
|
||||
right, ok := expr.Right.(search.RegexpLiteral)
|
||||
if !ok {
|
||||
return false, errors.New("right operand must be a regular expression")
|
||||
}
|
||||
|
@ -3,6 +3,7 @@ package search
|
||||
import (
|
||||
"encoding/gob"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
@ -50,13 +51,17 @@ type StringLiteral struct {
|
||||
}
|
||||
|
||||
func (sl StringLiteral) String() string {
|
||||
return sl.Value
|
||||
return strconv.Quote(sl.Value)
|
||||
}
|
||||
|
||||
type RegexpLiteral struct {
|
||||
*regexp.Regexp
|
||||
}
|
||||
|
||||
func (rl RegexpLiteral) String() string {
|
||||
return strconv.Quote(rl.Regexp.String())
|
||||
}
|
||||
|
||||
func (rl RegexpLiteral) MarshalBinary() ([]byte, error) {
|
||||
return []byte(rl.Regexp.String()), nil
|
||||
}
|
||||
|
@ -208,7 +208,7 @@ func parseInfixExpression(p *Parser, left Expression) (Expression, error) {
|
||||
return nil, fmt.Errorf("could not compile regular expression %q: %w", rightStr.Value, err)
|
||||
}
|
||||
|
||||
right = re
|
||||
right = RegexpLiteral{re}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -94,7 +94,7 @@ func TestParseQuery(t *testing.T) {
|
||||
expectedExpression: InfixExpression{
|
||||
Operator: TokOpRe,
|
||||
Left: StringLiteral{Value: "foo"},
|
||||
Right: regexp.MustCompile("bar"),
|
||||
Right: RegexpLiteral{regexp.MustCompile("bar")},
|
||||
},
|
||||
expectedError: nil,
|
||||
},
|
||||
@ -104,7 +104,7 @@ func TestParseQuery(t *testing.T) {
|
||||
expectedExpression: InfixExpression{
|
||||
Operator: TokOpNotRe,
|
||||
Left: StringLiteral{Value: "foo"},
|
||||
Right: regexp.MustCompile("bar"),
|
||||
Right: RegexpLiteral{regexp.MustCompile("bar")},
|
||||
},
|
||||
expectedError: nil,
|
||||
},
|
||||
@ -197,7 +197,7 @@ func TestParseQuery(t *testing.T) {
|
||||
Right: InfixExpression{
|
||||
Operator: TokOpRe,
|
||||
Left: StringLiteral{Value: "baz"},
|
||||
Right: regexp.MustCompile("yolo"),
|
||||
Right: RegexpLiteral{regexp.MustCompile("yolo")},
|
||||
},
|
||||
},
|
||||
expectedError: nil,
|
||||
|
@ -3,7 +3,6 @@ package sender
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/oklog/ulid"
|
||||
@ -93,7 +92,7 @@ func (req Request) matchInfixExpr(expr search.InfixExpression) (bool, error) {
|
||||
leftVal := req.getMappedStringLiteral(left.Value)
|
||||
|
||||
if expr.Operator == search.TokOpRe || expr.Operator == search.TokOpNotRe {
|
||||
right, ok := expr.Right.(*regexp.Regexp)
|
||||
right, ok := expr.Right.(search.RegexpLiteral)
|
||||
if !ok {
|
||||
return false, errors.New("right operand must be a regular expression")
|
||||
}
|
||||
|
Reference in New Issue
Block a user