Add middleware chaining for request and response modifiers

This commit is contained in:
David Stotijn
2019-11-30 09:24:38 +01:00
parent 9d9c46c63f
commit e4f5f2278a
3 changed files with 58 additions and 28 deletions

15
main.go
View File

@ -28,23 +28,28 @@ func main() {
log.Fatalf("[FATAL] Could not parse CA: %v", err) log.Fatalf("[FATAL] Could not parse CA: %v", err)
} }
proxy, err := proxy.NewProxy(caCert, tlsCA.PrivateKey) p, err := proxy.NewProxy(caCert, tlsCA.PrivateKey)
if err != nil { if err != nil {
log.Fatalf("[FATAL] Could not create Proxy: %v", err) log.Fatalf("[FATAL] Could not create Proxy: %v", err)
} }
proxy.UseRequestModifier(func(req *http.Request) { p.UseRequestModifier(func(next proxy.RequestModifyFunc) proxy.RequestModifyFunc {
return func(req *http.Request) {
log.Printf("[DEBUG] Incoming request: %v", req.URL) log.Printf("[DEBUG] Incoming request: %v", req.URL)
next(req)
}
}) })
proxy.UseResponseModifier(func(res *http.Response) error { p.UseResponseModifier(func(next proxy.ResponseModifyFunc) proxy.ResponseModifyFunc {
return func(res *http.Response) error {
log.Printf("[DEBUG] Downstream response: %v %v %v", res.Proto, res.StatusCode, http.StatusText(res.StatusCode)) log.Printf("[DEBUG] Downstream response: %v %v %v", res.Proto, res.StatusCode, http.StatusText(res.StatusCode))
return nil return next(res)
}
}) })
s := &http.Server{ s := &http.Server{
Addr: ":8080", Addr: ":8080",
Handler: proxy, Handler: p,
TLSNextProto: map[string]func(*http.Server, *tls.Conn, http.Handler){}, // Disable HTTP/2 TLSNextProto: map[string]func(*http.Server, *tls.Conn, http.Handler){}, // Disable HTTP/2
} }

View File

@ -7,5 +7,18 @@ var (
nopResModifier = func(res *http.Response) error { return nil } nopResModifier = func(res *http.Response) error { return nil }
) )
// RequestModifyFunc defines a type for a function that can modify a HTTP
// request before it's proxied.
type RequestModifyFunc func(req *http.Request) type RequestModifyFunc func(req *http.Request)
// RequestModifyMiddleware defines a type for chaining request modifier
// middleware.
type RequestModifyMiddleware func(next RequestModifyFunc) RequestModifyFunc
// ResponseModifyFunc defines a type for a function that can modify a HTTP
// response before it's written back to the client.
type ResponseModifyFunc func(res *http.Response) error type ResponseModifyFunc func(res *http.Response) error
// ResponseModifyMiddleware defines a type for chaining response modifier
// middleware.
type ResponseModifyMiddleware func(ResponseModifyFunc) ResponseModifyFunc

View File

@ -12,14 +12,6 @@ import (
"net/http/httputil" "net/http/httputil"
) )
func errorHandler(w http.ResponseWriter, r *http.Request, err error) {
if err == context.Canceled {
return
}
log.Printf("[ERROR]: Proxy error: %v", err)
w.WriteHeader(http.StatusBadGateway)
}
// Proxy implements http.Handler and offers MITM behaviour for modifying // Proxy implements http.Handler and offers MITM behaviour for modifying
// HTTP requests and responses. // HTTP requests and responses.
type Proxy struct { type Proxy struct {
@ -27,8 +19,8 @@ type Proxy struct {
handler http.Handler handler http.Handler
// TODO: Add mutex for modifier funcs. // TODO: Add mutex for modifier funcs.
reqModifier RequestModifyFunc reqModifiers []RequestModifyMiddleware
resModifier ResponseModifyFunc resModifiers []ResponseModifyMiddleware
} }
// NewProxy returns a new Proxy. // NewProxy returns a new Proxy.
@ -40,8 +32,8 @@ func NewProxy(ca *x509.Certificate, key crypto.PrivateKey) (*Proxy, error) {
p := &Proxy{ p := &Proxy{
certConfig: certConfig, certConfig: certConfig,
reqModifier: nopReqModifier, reqModifiers: make([]RequestModifyMiddleware, 0),
resModifier: nopResModifier, resModifiers: make([]ResponseModifyMiddleware, 0),
} }
p.handler = &httputil.ReverseProxy{ p.handler = &httputil.ReverseProxy{
@ -62,6 +54,14 @@ func (p *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
p.handler.ServeHTTP(w, r) p.handler.ServeHTTP(w, r)
} }
func (p *Proxy) UseRequestModifier(fn ...RequestModifyMiddleware) {
p.reqModifiers = append(p.reqModifiers, fn...)
}
func (p *Proxy) UseResponseModifier(fn ...ResponseModifyMiddleware) {
p.resModifiers = append(p.resModifiers, fn...)
}
func (p *Proxy) modifyRequest(r *http.Request) { func (p *Proxy) modifyRequest(r *http.Request) {
// Fix r.URL for HTTPS requests after CONNECT. // Fix r.URL for HTTPS requests after CONNECT.
if r.URL.Scheme == "" { if r.URL.Scheme == "" {
@ -69,11 +69,23 @@ func (p *Proxy) modifyRequest(r *http.Request) {
r.URL.Scheme = "https" r.URL.Scheme = "https"
} }
p.reqModifier(r) fn := nopReqModifier
for i := len(p.reqModifiers) - 1; i >= 0; i-- {
fn = p.reqModifiers[i](fn)
}
fn(r)
} }
func (p *Proxy) modifyResponse(res *http.Response) error { func (p *Proxy) modifyResponse(res *http.Response) error {
return p.resModifier(res) fn := nopResModifier
for i := len(p.resModifiers) - 1; i >= 0; i-- {
fn = p.resModifiers[i](fn)
}
return fn(res)
} }
// handleConnect hijacks the incoming HTTP request and sets up an HTTP tunnel. // handleConnect hijacks the incoming HTTP request and sets up an HTTP tunnel.
@ -126,12 +138,12 @@ func (p *Proxy) clientTLSConn(conn net.Conn) (*tls.Conn, error) {
return tlsConn, nil return tlsConn, nil
} }
func (p *Proxy) UseRequestModifier(fn RequestModifyFunc) { func errorHandler(w http.ResponseWriter, r *http.Request, err error) {
p.reqModifier = fn if err == context.Canceled {
} return
}
func (p *Proxy) UseResponseModifier(fn ResponseModifyFunc) { log.Printf("[ERROR]: Proxy error: %v", err)
p.resModifier = fn w.WriteHeader(http.StatusBadGateway)
} }
func writeError(w http.ResponseWriter, r *http.Request, code int) { func writeError(w http.ResponseWriter, r *http.Request, code int) {