diff --git a/main.go b/main.go index 50c3ffe..cb4b6c8 100644 --- a/main.go +++ b/main.go @@ -28,23 +28,28 @@ func main() { log.Fatalf("[FATAL] Could not parse CA: %v", err) } - proxy, err := proxy.NewProxy(caCert, tlsCA.PrivateKey) + p, err := proxy.NewProxy(caCert, tlsCA.PrivateKey) if err != nil { log.Fatalf("[FATAL] Could not create Proxy: %v", err) } - proxy.UseRequestModifier(func(req *http.Request) { - log.Printf("[DEBUG] Incoming request: %v", req.URL) + p.UseRequestModifier(func(next proxy.RequestModifyFunc) proxy.RequestModifyFunc { + return func(req *http.Request) { + log.Printf("[DEBUG] Incoming request: %v", req.URL) + next(req) + } }) - proxy.UseResponseModifier(func(res *http.Response) error { - log.Printf("[DEBUG] Downstream response: %v %v %v", res.Proto, res.StatusCode, http.StatusText(res.StatusCode)) - return nil + p.UseResponseModifier(func(next proxy.ResponseModifyFunc) proxy.ResponseModifyFunc { + return func(res *http.Response) error { + log.Printf("[DEBUG] Downstream response: %v %v %v", res.Proto, res.StatusCode, http.StatusText(res.StatusCode)) + return next(res) + } }) s := &http.Server{ Addr: ":8080", - Handler: proxy, + Handler: p, TLSNextProto: map[string]func(*http.Server, *tls.Conn, http.Handler){}, // Disable HTTP/2 } diff --git a/proxy/modify.go b/proxy/modify.go index 4364083..dc8d851 100644 --- a/proxy/modify.go +++ b/proxy/modify.go @@ -7,5 +7,18 @@ var ( nopResModifier = func(res *http.Response) error { return nil } ) +// RequestModifyFunc defines a type for a function that can modify a HTTP +// request before it's proxied. type RequestModifyFunc func(req *http.Request) + +// RequestModifyMiddleware defines a type for chaining request modifier +// middleware. +type RequestModifyMiddleware func(next RequestModifyFunc) RequestModifyFunc + +// ResponseModifyFunc defines a type for a function that can modify a HTTP +// response before it's written back to the client. type ResponseModifyFunc func(res *http.Response) error + +// ResponseModifyMiddleware defines a type for chaining response modifier +// middleware. +type ResponseModifyMiddleware func(ResponseModifyFunc) ResponseModifyFunc diff --git a/proxy/proxy.go b/proxy/proxy.go index 4597b6c..f238ff3 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -12,14 +12,6 @@ import ( "net/http/httputil" ) -func errorHandler(w http.ResponseWriter, r *http.Request, err error) { - if err == context.Canceled { - return - } - log.Printf("[ERROR]: Proxy error: %v", err) - w.WriteHeader(http.StatusBadGateway) -} - // Proxy implements http.Handler and offers MITM behaviour for modifying // HTTP requests and responses. type Proxy struct { @@ -27,8 +19,8 @@ type Proxy struct { handler http.Handler // TODO: Add mutex for modifier funcs. - reqModifier RequestModifyFunc - resModifier ResponseModifyFunc + reqModifiers []RequestModifyMiddleware + resModifiers []ResponseModifyMiddleware } // NewProxy returns a new Proxy. @@ -39,9 +31,9 @@ func NewProxy(ca *x509.Certificate, key crypto.PrivateKey) (*Proxy, error) { } p := &Proxy{ - certConfig: certConfig, - reqModifier: nopReqModifier, - resModifier: nopResModifier, + certConfig: certConfig, + reqModifiers: make([]RequestModifyMiddleware, 0), + resModifiers: make([]ResponseModifyMiddleware, 0), } p.handler = &httputil.ReverseProxy{ @@ -62,6 +54,14 @@ func (p *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { p.handler.ServeHTTP(w, r) } +func (p *Proxy) UseRequestModifier(fn ...RequestModifyMiddleware) { + p.reqModifiers = append(p.reqModifiers, fn...) +} + +func (p *Proxy) UseResponseModifier(fn ...ResponseModifyMiddleware) { + p.resModifiers = append(p.resModifiers, fn...) +} + func (p *Proxy) modifyRequest(r *http.Request) { // Fix r.URL for HTTPS requests after CONNECT. if r.URL.Scheme == "" { @@ -69,11 +69,23 @@ func (p *Proxy) modifyRequest(r *http.Request) { r.URL.Scheme = "https" } - p.reqModifier(r) + fn := nopReqModifier + + for i := len(p.reqModifiers) - 1; i >= 0; i-- { + fn = p.reqModifiers[i](fn) + } + + fn(r) } func (p *Proxy) modifyResponse(res *http.Response) error { - return p.resModifier(res) + fn := nopResModifier + + for i := len(p.resModifiers) - 1; i >= 0; i-- { + fn = p.resModifiers[i](fn) + } + + return fn(res) } // handleConnect hijacks the incoming HTTP request and sets up an HTTP tunnel. @@ -126,12 +138,12 @@ func (p *Proxy) clientTLSConn(conn net.Conn) (*tls.Conn, error) { return tlsConn, nil } -func (p *Proxy) UseRequestModifier(fn RequestModifyFunc) { - p.reqModifier = fn -} - -func (p *Proxy) UseResponseModifier(fn ResponseModifyFunc) { - p.resModifier = fn +func errorHandler(w http.ResponseWriter, r *http.Request, err error) { + if err == context.Canceled { + return + } + log.Printf("[ERROR]: Proxy error: %v", err) + w.WriteHeader(http.StatusBadGateway) } func writeError(w http.ResponseWriter, r *http.Request, code int) {