Files
hetty/pkg/reqlog/reqlog.go

217 lines
5.0 KiB
Go
Raw Normal View History

2019-12-01 14:07:12 +01:00
package reqlog
import (
"bytes"
"compress/gzip"
2020-09-26 23:36:48 +02:00
"context"
"errors"
"fmt"
"io/ioutil"
"log"
2019-12-01 14:07:12 +01:00
"net/http"
"time"
2020-10-29 20:54:17 +01:00
"github.com/dstotijn/hetty/pkg/proj"
2020-09-22 18:33:02 +02:00
"github.com/dstotijn/hetty/pkg/proxy"
2020-10-01 21:46:35 +02:00
"github.com/dstotijn/hetty/pkg/scope"
2019-12-01 14:07:12 +01:00
)
2020-10-01 21:46:35 +02:00
// contextKey is an unexported type for this package's context keys, so they
// cannot collide with context keys defined by other packages.
type contextKey int

// LogBypassedKey is the request-context key that RequestModifier sets to true
// when a request is not logged; ResponseModifier checks it to skip logging the
// corresponding response.
const LogBypassedKey = contextKey(0)
2020-10-29 20:54:17 +01:00
// moduleName identifies this package's settings in the repository;
// loadSettings passes it to Repository.FindSettingsByModule.
const moduleName = "reqlog"
2020-10-11 17:09:39 +02:00
// ErrRequestNotFound reports that a requested request log does not exist.
// It is not returned anywhere in this file; presumably Repository
// implementations return it (TODO confirm against the repository code).
var ErrRequestNotFound = errors.New("reqlog: request not found")
type Request struct {
2020-10-04 11:50:03 +02:00
ID int64
Request http.Request
Body []byte
Timestamp time.Time
Response *Response
2019-12-01 14:07:12 +01:00
}
// Response is a logged HTTP response, linked to its request log entry via
// RequestID.
type Response struct {
	// ID identifies this response log entry; RequestID is the ID of the
	// Request log entry it belongs to.
	ID, RequestID int64
	// Response is a copy of the logged response.
	Response http.Response
	// Body holds the response body bytes, kept separately from Response.Body.
	Body []byte
	// Timestamp is when the response was logged.
	Timestamp time.Time
}
type Service struct {
2020-10-01 21:46:35 +02:00
BypassOutOfScopeRequests bool
2020-10-29 20:54:17 +01:00
FindReqsFilter FindRequestsFilter
2020-10-01 21:46:35 +02:00
scope *scope.Scope
repo Repository
}
2020-10-29 20:54:17 +01:00
// FindRequestsFilter holds the criteria FindRequests passes to the
// repository.
type FindRequestsFilter struct {
	// OnlyInScope presumably limits results to requests matching the
	// service scope; the filtering itself is done by the repository.
	OnlyInScope bool
}
type Config struct {
Scope *scope.Scope
Repository Repository
2020-10-29 20:54:17 +01:00
ProjectService *proj.Service
2020-10-01 21:46:35 +02:00
BypassOutOfScopeRequests bool
2019-12-01 14:07:12 +01:00
}
2020-10-01 21:46:35 +02:00
func NewService(cfg Config) *Service {
2020-10-29 20:54:17 +01:00
svc := &Service{
2020-10-01 21:46:35 +02:00
scope: cfg.Scope,
repo: cfg.Repository,
BypassOutOfScopeRequests: cfg.BypassOutOfScopeRequests,
}
2019-12-01 14:07:12 +01:00
2020-10-29 20:54:17 +01:00
cfg.ProjectService.OnProjectOpen(func(_ string) error {
err := svc.loadSettings()
if err == proj.ErrNoSettings {
return nil
}
if err != nil {
return fmt.Errorf("reqlog: could not load settings: %v", err)
}
return nil
})
cfg.ProjectService.OnProjectClose(func(_ string) error {
svc.unloadSettings()
return nil
})
2020-10-01 21:46:35 +02:00
2020-10-29 20:54:17 +01:00
return svc
}
// FindRequests returns the request logs the repository selects for the
// service's current filter (FindReqsFilter) and scope.
func (svc *Service) FindRequests(ctx context.Context) ([]Request, error) {
	filter, sc := svc.FindReqsFilter, svc.scope
	return svc.repo.FindRequestLogs(ctx, filter, sc)
}
2020-10-04 11:50:03 +02:00
func (svc *Service) FindRequestLogByID(ctx context.Context, id int64) (Request, error) {
2020-09-26 23:36:48 +02:00
return svc.repo.FindRequestLogByID(ctx, id)
}
2020-10-29 20:54:17 +01:00
// SetRequestLogFilter sets the filter used by FindRequests and persists the
// service's settings through the repository.
//
// The in-memory filter is updated even if persisting the settings fails.
func (svc *Service) SetRequestLogFilter(ctx context.Context, filter FindRequestsFilter) error {
	svc.FindReqsFilter = filter
	// Use moduleName so settings are written under the same key that
	// loadSettings reads them back from.
	return svc.repo.UpsertSettings(ctx, moduleName, svc)
}
2020-10-04 11:50:03 +02:00
func (svc *Service) addRequest(
ctx context.Context,
req http.Request,
body []byte,
timestamp time.Time,
) (*Request, error) {
return svc.repo.AddRequestLog(ctx, req, body, timestamp)
2019-12-01 14:07:12 +01:00
}
2020-10-04 11:50:03 +02:00
func (svc *Service) addResponse(
ctx context.Context,
reqID int64,
res http.Response,
body []byte,
timestamp time.Time,
) (*Response, error) {
if res.Header.Get("Content-Encoding") == "gzip" {
gzipReader, err := gzip.NewReader(bytes.NewBuffer(body))
if err != nil {
2020-10-04 11:50:03 +02:00
return nil, fmt.Errorf("reqlog: could not create gzip reader: %v", err)
}
defer gzipReader.Close()
body, err = ioutil.ReadAll(gzipReader)
if err != nil {
2020-10-04 11:50:03 +02:00
return nil, fmt.Errorf("reqlog: could not read gzipped response body: %v", err)
}
}
2020-10-04 11:50:03 +02:00
return svc.repo.AddResponseLog(ctx, reqID, res, body, timestamp)
}
func (svc *Service) RequestModifier(next proxy.RequestModifyFunc) proxy.RequestModifyFunc {
return func(req *http.Request) {
2020-10-04 11:50:03 +02:00
now := time.Now()
next(req)
clone := req.Clone(req.Context())
var body []byte
if req.Body != nil {
// TODO: Use io.LimitReader.
var err error
body, err = ioutil.ReadAll(req.Body)
if err != nil {
log.Printf("[ERROR] Could not read request body for logging: %v", err)
return
}
req.Body = ioutil.NopCloser(bytes.NewBuffer(body))
}
2020-10-01 21:46:35 +02:00
// Bypass logging if this setting is enabled and the incoming request
// doens't match any rules of the scope.
if svc.BypassOutOfScopeRequests && !svc.scope.Match(clone, body) {
ctx := context.WithValue(req.Context(), LogBypassedKey, true)
2020-10-04 11:50:03 +02:00
*req = *req.WithContext(ctx)
2020-10-01 21:46:35 +02:00
return
}
2020-10-04 11:50:03 +02:00
reqLog, err := svc.addRequest(req.Context(), *clone, body, now)
2020-10-29 20:54:17 +01:00
if err == proj.ErrNoProject {
2020-10-11 17:09:39 +02:00
ctx := context.WithValue(req.Context(), LogBypassedKey, true)
*req = *req.WithContext(ctx)
return
}
2020-10-04 11:50:03 +02:00
if err != nil {
log.Printf("[ERROR] Could not store request log: %v", err)
return
}
2020-10-04 11:50:03 +02:00
ctx := context.WithValue(req.Context(), proxy.ReqIDKey, reqLog.ID)
*req = *req.WithContext(ctx)
}
}
func (svc *Service) ResponseModifier(next proxy.ResponseModifyFunc) proxy.ResponseModifyFunc {
return func(res *http.Response) error {
2020-10-04 11:50:03 +02:00
now := time.Now()
if err := next(res); err != nil {
return err
}
2020-10-01 21:46:35 +02:00
if bypassed, _ := res.Request.Context().Value(LogBypassedKey).(bool); bypassed {
return nil
}
2020-10-04 11:50:03 +02:00
reqID, _ := res.Request.Context().Value(proxy.ReqIDKey).(int64)
if reqID == 0 {
return errors.New("reqlog: request is missing ID")
}
clone := *res
// TODO: Use io.LimitReader.
body, err := ioutil.ReadAll(res.Body)
if err != nil {
return fmt.Errorf("reqlog: could not read response body: %v", err)
}
res.Body = ioutil.NopCloser(bytes.NewBuffer(body))
go func() {
2020-10-04 11:50:03 +02:00
if _, err := svc.addResponse(context.Background(), reqID, clone, body, now); err != nil {
log.Printf("[ERROR] Could not store response log: %v", err)
}
}()
return nil
}
2019-12-01 14:07:12 +01:00
}
2020-10-29 20:54:17 +01:00
// loadSettings asks the repository to load the settings stored under
// moduleName into svc. It is called from the project-open hook registered in
// NewService; that callback receives no context, so a background context is
// used.
func (svc *Service) loadSettings() error {
	ctx := context.Background()
	return svc.repo.FindSettingsByModule(ctx, moduleName, svc)
}
// unloadSettings resets the service's configurable settings,
// FindReqsFilter and BypassOutOfScopeRequests, to their zero values.
func (svc *Service) unloadSettings() {
	var zeroFilter FindRequestsFilter
	svc.FindReqsFilter = zeroFilter
	svc.BypassOutOfScopeRequests = false
}