mirror of
https://github.com/dstotijn/hetty.git
synced 2025-07-01 18:47:29 -04:00
Add support for intercepting HTTP responses
This commit is contained in:
@ -16,11 +16,16 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
ErrRequestAborted = errors.New("intercept: request was aborted")
|
||||
ErrRequestNotFound = errors.New("intercept: request not found")
|
||||
ErrRequestDone = errors.New("intercept: request is done")
|
||||
ErrRequestAborted = errors.New("intercept: request was aborted")
|
||||
ErrRequestNotFound = errors.New("intercept: request not found")
|
||||
ErrRequestDone = errors.New("intercept: request is done")
|
||||
ErrResponseNotFound = errors.New("intercept: response not found")
|
||||
)
|
||||
|
||||
type contextKey int
|
||||
|
||||
const interceptResponseKey contextKey = 0
|
||||
|
||||
// Request represents a server received HTTP request, alongside a channel for sending a modified version of it to the
|
||||
// routine that's awaiting it. Also contains a channel for receiving a cancellation signal.
|
||||
type Request struct {
|
||||
@ -29,9 +34,24 @@ type Request struct {
|
||||
done <-chan struct{}
|
||||
}
|
||||
|
||||
// Response represents an HTTP response from a proxied request, alongside a channel for sending a modified version of it
|
||||
// to the routine that's awaiting it. Also contains a channel for receiving a cancellation signal.
|
||||
type Response struct {
|
||||
res *http.Response
|
||||
ch chan<- *http.Response
|
||||
done <-chan struct{}
|
||||
}
|
||||
|
||||
type Item struct {
|
||||
Request *http.Request
|
||||
Response *http.Response
|
||||
}
|
||||
|
||||
type Service struct {
|
||||
mu *sync.RWMutex
|
||||
reqMu *sync.RWMutex
|
||||
resMu *sync.RWMutex
|
||||
requests map[ulid.ULID]Request
|
||||
responses map[ulid.ULID]Response
|
||||
logger log.Logger
|
||||
enabled bool
|
||||
reqFilter search.Expression
|
||||
@ -48,8 +68,10 @@ type RequestIDs []ulid.ULID
|
||||
|
||||
func NewService(cfg Config) *Service {
|
||||
s := &Service{
|
||||
mu: &sync.RWMutex{},
|
||||
reqMu: &sync.RWMutex{},
|
||||
resMu: &sync.RWMutex{},
|
||||
requests: make(map[ulid.ULID]Request),
|
||||
responses: make(map[ulid.ULID]Response),
|
||||
logger: cfg.Logger,
|
||||
enabled: cfg.Enabled,
|
||||
reqFilter: cfg.RequestFilter,
|
||||
@ -62,13 +84,12 @@ func NewService(cfg Config) *Service {
|
||||
return s
|
||||
}
|
||||
|
||||
// RequestModifier is a proxy.RequestModifyMiddleware for intercepting HTTP
|
||||
// requests.
|
||||
// RequestModifier is a proxy.RequestModifyMiddleware for intercepting HTTP requests.
|
||||
func (svc *Service) RequestModifier(next proxy.RequestModifyFunc) proxy.RequestModifyFunc {
|
||||
return func(req *http.Request) {
|
||||
// This is a blocking operation, that gets unblocked when either a modified request is returned or an error
|
||||
// (typically `context.Canceled`).
|
||||
modifiedReq, err := svc.Intercept(req.Context(), req)
|
||||
modifiedReq, err := svc.InterceptRequest(req.Context(), req)
|
||||
|
||||
switch {
|
||||
case errors.Is(err, ErrRequestAborted):
|
||||
@ -86,24 +107,24 @@ func (svc *Service) RequestModifier(next proxy.RequestModifyFunc) proxy.RequestM
|
||||
svc.logger.Errorw("Failed to intercept request.",
|
||||
"error", err)
|
||||
default:
|
||||
*req = *modifiedReq.WithContext(req.Context())
|
||||
*req = *modifiedReq
|
||||
next(req)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Intercept adds an HTTP request to an array of pending intercepted requests, alongside channels used for sending a
|
||||
// cancellation signal and receiving a modified request. It's safe for concurrent use.
|
||||
func (svc *Service) Intercept(ctx context.Context, req *http.Request) (*http.Request, error) {
|
||||
// InterceptRequest adds an HTTP request to an array of pending intercepted requests, alongside channels used for
|
||||
// sending a cancellation signal and receiving a modified request. It's safe for concurrent use.
|
||||
func (svc *Service) InterceptRequest(ctx context.Context, req *http.Request) (*http.Request, error) {
|
||||
reqID, ok := proxy.RequestIDFromContext(ctx)
|
||||
if !ok {
|
||||
svc.logger.Errorw("Failed to intercept: request doesn't have an ID.")
|
||||
svc.logger.Errorw("Failed to intercept: context doesn't have an ID.")
|
||||
return req, nil
|
||||
}
|
||||
|
||||
if !svc.enabled {
|
||||
// If intercept is disabled, return the incoming request as-is.
|
||||
svc.logger.Debugw("Bypassed interception: module disabled.")
|
||||
// If request intercept is disabled, return the incoming request as-is.
|
||||
svc.logger.Debugw("Bypassed request interception: feature disabled.")
|
||||
return req, nil
|
||||
}
|
||||
|
||||
@ -116,7 +137,7 @@ func (svc *Service) Intercept(ctx context.Context, req *http.Request) (*http.Req
|
||||
}
|
||||
|
||||
if !match {
|
||||
svc.logger.Debugw("Bypassed interception: request rules don't match.")
|
||||
svc.logger.Debugw("Bypassed request interception: request rules don't match.")
|
||||
return req, nil
|
||||
}
|
||||
}
|
||||
@ -124,20 +145,20 @@ func (svc *Service) Intercept(ctx context.Context, req *http.Request) (*http.Req
|
||||
ch := make(chan *http.Request)
|
||||
done := make(chan struct{})
|
||||
|
||||
svc.mu.Lock()
|
||||
svc.reqMu.Lock()
|
||||
svc.requests[reqID] = Request{
|
||||
req: req,
|
||||
ch: ch,
|
||||
done: done,
|
||||
}
|
||||
svc.mu.Unlock()
|
||||
svc.reqMu.Unlock()
|
||||
|
||||
// Whatever happens next (modified request returned, or a context cancelled error), any blocked channel senders
|
||||
// should be unblocked, and the request should be removed from the requests queue.
|
||||
defer func() {
|
||||
close(done)
|
||||
svc.mu.Lock()
|
||||
defer svc.mu.Unlock()
|
||||
svc.reqMu.Lock()
|
||||
defer svc.reqMu.Unlock()
|
||||
delete(svc.requests, reqID)
|
||||
}()
|
||||
|
||||
@ -155,15 +176,20 @@ func (svc *Service) Intercept(ctx context.Context, req *http.Request) (*http.Req
|
||||
|
||||
// ModifyRequest sends a modified HTTP request to the related channel, or returns ErrRequestDone when the request was
|
||||
// cancelled. It's safe for concurrent use.
|
||||
func (svc *Service) ModifyRequest(reqID ulid.ULID, modReq *http.Request) error {
|
||||
svc.mu.RLock()
|
||||
func (svc *Service) ModifyRequest(reqID ulid.ULID, modReq *http.Request, modifyResponse *bool) error {
|
||||
svc.reqMu.RLock()
|
||||
req, ok := svc.requests[reqID]
|
||||
svc.mu.RUnlock()
|
||||
svc.reqMu.RUnlock()
|
||||
|
||||
if !ok {
|
||||
return ErrRequestNotFound
|
||||
}
|
||||
|
||||
*modReq = *modReq.WithContext(req.req.Context())
|
||||
if modifyResponse != nil {
|
||||
*modReq = *modReq.WithContext(WithInterceptResponse(modReq.Context(), *modifyResponse))
|
||||
}
|
||||
|
||||
select {
|
||||
case <-req.done:
|
||||
return ErrRequestDone
|
||||
@ -174,12 +200,12 @@ func (svc *Service) ModifyRequest(reqID ulid.ULID, modReq *http.Request) error {
|
||||
|
||||
// CancelRequest ensures an intercepted request is dropped.
|
||||
func (svc *Service) CancelRequest(reqID ulid.ULID) error {
|
||||
return svc.ModifyRequest(reqID, nil)
|
||||
return svc.ModifyRequest(reqID, nil, nil)
|
||||
}
|
||||
|
||||
func (svc *Service) ClearRequests() {
|
||||
svc.mu.Lock()
|
||||
defer svc.mu.Unlock()
|
||||
svc.reqMu.Lock()
|
||||
defer svc.reqMu.Unlock()
|
||||
|
||||
for _, req := range svc.requests {
|
||||
select {
|
||||
@ -189,47 +215,94 @@ func (svc *Service) ClearRequests() {
|
||||
}
|
||||
}
|
||||
|
||||
// Requests returns a list of pending intercepted requests. It's safe for concurrent use.
|
||||
func (svc *Service) Requests() []*http.Request {
|
||||
svc.mu.RLock()
|
||||
defer svc.mu.RUnlock()
|
||||
func (svc *Service) ClearResponses() {
|
||||
svc.resMu.Lock()
|
||||
defer svc.resMu.Unlock()
|
||||
|
||||
for _, res := range svc.responses {
|
||||
select {
|
||||
case <-res.done:
|
||||
case res.ch <- nil:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Items returns a list of pending items (requests and responses). It's safe for concurrent use.
|
||||
func (svc *Service) Items() []Item {
|
||||
svc.reqMu.RLock()
|
||||
defer svc.reqMu.RUnlock()
|
||||
|
||||
svc.resMu.RLock()
|
||||
defer svc.resMu.RUnlock()
|
||||
|
||||
reqIDs := make([]ulid.ULID, 0, len(svc.requests)+len(svc.responses))
|
||||
|
||||
ids := make([]ulid.ULID, 0, len(svc.requests))
|
||||
for id := range svc.requests {
|
||||
ids = append(ids, id)
|
||||
reqIDs = append(reqIDs, id)
|
||||
}
|
||||
|
||||
sort.Sort(RequestIDs(ids))
|
||||
|
||||
reqs := make([]*http.Request, len(ids))
|
||||
for i, id := range ids {
|
||||
reqs[i] = svc.requests[id].req
|
||||
for id := range svc.responses {
|
||||
reqIDs = append(reqIDs, id)
|
||||
}
|
||||
|
||||
return reqs
|
||||
sort.Sort(RequestIDs(reqIDs))
|
||||
|
||||
items := make([]Item, len(reqIDs))
|
||||
|
||||
for i, id := range reqIDs {
|
||||
item := Item{}
|
||||
|
||||
if req, ok := svc.requests[id]; ok {
|
||||
item.Request = req.req
|
||||
}
|
||||
|
||||
if res, ok := svc.responses[id]; ok {
|
||||
item.Response = res.res
|
||||
}
|
||||
|
||||
items[i] = item
|
||||
}
|
||||
|
||||
return items
|
||||
}
|
||||
|
||||
func (svc *Service) UpdateSettings(settings Settings) {
|
||||
// When updating from `enabled` -> `disabled`, clear any pending reqs.
|
||||
if svc.enabled && !settings.Enabled {
|
||||
svc.ClearRequests()
|
||||
svc.ClearResponses()
|
||||
}
|
||||
|
||||
svc.enabled = settings.Enabled
|
||||
svc.reqFilter = settings.RequestFilter
|
||||
}
|
||||
|
||||
// Request returns an intercepted request by ID. It's safe for concurrent use.
|
||||
func (svc *Service) RequestByID(id ulid.ULID) (*http.Request, error) {
|
||||
svc.mu.RLock()
|
||||
defer svc.mu.RUnlock()
|
||||
// ItemByID returns an intercepted item (request and possible response) by ID. It's safe for concurrent use.
|
||||
func (svc *Service) ItemByID(id ulid.ULID) (Item, error) {
|
||||
svc.reqMu.RLock()
|
||||
defer svc.reqMu.RUnlock()
|
||||
|
||||
req, ok := svc.requests[id]
|
||||
if !ok {
|
||||
return nil, ErrRequestNotFound
|
||||
svc.resMu.RLock()
|
||||
defer svc.resMu.RUnlock()
|
||||
|
||||
item := Item{}
|
||||
found := false
|
||||
|
||||
if req, ok := svc.requests[id]; ok {
|
||||
item.Request = req.req
|
||||
found = true
|
||||
}
|
||||
|
||||
return req.req, nil
|
||||
if res, ok := svc.responses[id]; ok {
|
||||
item.Response = res.res
|
||||
found = true
|
||||
}
|
||||
|
||||
if !found {
|
||||
return Item{}, ErrRequestNotFound
|
||||
}
|
||||
|
||||
return item, nil
|
||||
}
|
||||
|
||||
func (ids RequestIDs) Len() int {
|
||||
@ -243,3 +316,124 @@ func (ids RequestIDs) Less(i, j int) bool {
|
||||
func (ids RequestIDs) Swap(i, j int) {
|
||||
ids[i], ids[j] = ids[j], ids[i]
|
||||
}
|
||||
|
||||
func WithInterceptResponse(ctx context.Context, value bool) context.Context {
|
||||
return context.WithValue(ctx, interceptResponseKey, value)
|
||||
}
|
||||
|
||||
func ShouldInterceptResponseFromContext(ctx context.Context) (bool, bool) {
|
||||
shouldIntercept, ok := ctx.Value(interceptResponseKey).(bool)
|
||||
return shouldIntercept, ok
|
||||
}
|
||||
|
||||
// ResponseModifier is a proxy.ResponseModifyMiddleware for intercepting HTTP responses.
|
||||
func (svc *Service) ResponseModifier(next proxy.ResponseModifyFunc) proxy.ResponseModifyFunc {
|
||||
return func(res *http.Response) error {
|
||||
// This is a blocking operation, that gets unblocked when either a modified response is returned or an error.
|
||||
//nolint:bodyclose
|
||||
modifiedRes, err := svc.InterceptResponse(res.Request.Context(), res)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to intercept response: %w", err)
|
||||
}
|
||||
|
||||
*res = *modifiedRes
|
||||
|
||||
return next(res)
|
||||
}
|
||||
}
|
||||
|
||||
// InterceptResponse adds an HTTP response to an array of pending intercepted responses, alongside channels used for
|
||||
// sending a cancellation signal and receiving a modified response. It's safe for concurrent use.
|
||||
func (svc *Service) InterceptResponse(ctx context.Context, res *http.Response) (*http.Response, error) {
|
||||
reqID, ok := proxy.RequestIDFromContext(ctx)
|
||||
if !ok {
|
||||
svc.logger.Errorw("Failed to intercept: context doesn't have an ID.")
|
||||
return res, nil
|
||||
}
|
||||
|
||||
shouldIntercept, ok := ShouldInterceptResponseFromContext(ctx)
|
||||
if ok && !shouldIntercept {
|
||||
// If the related request explicitly disabled response intercept, return the response as-is.
|
||||
svc.logger.Debugw("Bypassed response interception: related request explicitly disabled response intercept.")
|
||||
return res, nil
|
||||
}
|
||||
|
||||
if !svc.enabled {
|
||||
// If the feature is disabled, return the response as-is.
|
||||
svc.logger.Debugw("Bypassed response interception: feature disabled.")
|
||||
return res, nil
|
||||
}
|
||||
|
||||
// if svc.reqFilter != nil {
|
||||
// match, err := MatchRequestFilter(req, svc.reqFilter)
|
||||
// if err != nil {
|
||||
// return nil, fmt.Errorf("intercept: failed to match request rules for request (id: %v): %w",
|
||||
// reqID.String(), err,
|
||||
// )
|
||||
// }
|
||||
|
||||
// if !match {
|
||||
// svc.logger.Debugw("Bypassed interception: request rules don't match.")
|
||||
// return req, nil
|
||||
// }
|
||||
// }
|
||||
|
||||
ch := make(chan *http.Response)
|
||||
done := make(chan struct{})
|
||||
|
||||
svc.resMu.Lock()
|
||||
svc.responses[reqID] = Response{
|
||||
res: res,
|
||||
ch: ch,
|
||||
done: done,
|
||||
}
|
||||
svc.resMu.Unlock()
|
||||
|
||||
// Whatever happens next (modified response returned, or a context cancelled error), any blocked channel senders
|
||||
// should be unblocked, and the response should be removed from the responses queue.
|
||||
defer func() {
|
||||
close(done)
|
||||
svc.resMu.Lock()
|
||||
defer svc.resMu.Unlock()
|
||||
delete(svc.responses, reqID)
|
||||
}()
|
||||
|
||||
select {
|
||||
case modRes := <-ch:
|
||||
if modRes == nil {
|
||||
return nil, ErrRequestAborted
|
||||
}
|
||||
|
||||
return modRes, nil
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
// ModifyResponse sends a modified HTTP response to the related channel, or returns ErrRequestDone when the related
|
||||
// request was cancelled. It's safe for concurrent use.
|
||||
func (svc *Service) ModifyResponse(reqID ulid.ULID, modRes *http.Response) error {
|
||||
svc.resMu.RLock()
|
||||
res, ok := svc.responses[reqID]
|
||||
svc.resMu.RUnlock()
|
||||
|
||||
if !ok {
|
||||
return ErrRequestNotFound
|
||||
}
|
||||
|
||||
if modRes != nil {
|
||||
modRes.Request = res.res.Request
|
||||
}
|
||||
|
||||
select {
|
||||
case <-res.done:
|
||||
return ErrRequestDone
|
||||
case res.ch <- modRes:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// CancelResponse ensures an intercepted response is dropped.
|
||||
func (svc *Service) CancelResponse(reqID ulid.ULID) error {
|
||||
return svc.ModifyResponse(reqID, nil)
|
||||
}
|
||||
|
Reference in New Issue
Block a user