Files
hetty/proxy.go

118 lines
2.7 KiB
Go
Raw Normal View History

2019-11-17 15:09:37 +01:00
package main
import (
2019-11-24 00:14:49 +01:00
"context"
"crypto"
"crypto/tls"
"crypto/x509"
"fmt"
2019-11-17 15:09:37 +01:00
"log"
"net"
"net/http"
"net/http/httputil"
)
2019-11-24 00:14:49 +01:00
// Reverse proxies that forward a request to the host named in its Host field.
//
// httpHandler relays plain HTTP requests (scheme "http"). httpsHandler relays
// requests read from inside intercepted CONNECT tunnels and sends them
// upstream with scheme "https". Both report upstream failures through
// proxyErrorHandler.
var (
	httpHandler = &httputil.ReverseProxy{
		ErrorHandler: proxyErrorHandler,
		Director: func(req *http.Request) {
			req.URL.Scheme = "http"
			req.URL.Host = req.Host
		},
	}
	httpsHandler = &httputil.ReverseProxy{
		ErrorHandler: proxyErrorHandler,
		Director: func(req *http.Request) {
			req.URL.Scheme = "https"
			req.URL.Host = req.Host
		},
	}
)
// proxyErrorHandler is the ErrorHandler shared by both reverse proxies. It
// logs the failure and replies 502 Bad Gateway, except when the request was
// canceled. In that case it returns silently without writing a response.
//
// Cancellation is detected both from err itself and from the request's
// context. The transport may report a canceled request with an error other
// than context.Canceled (e.g. "net/http: request canceled"), which previously
// produced a spurious log line and 502 for a client that had already gone away.
func proxyErrorHandler(w http.ResponseWriter, r *http.Request, err error) {
	if err == context.Canceled || r.Context().Err() == context.Canceled {
		return
	}
	log.Printf("[ERROR]: Proxy error: %v", err)
	w.WriteHeader(http.StatusBadGateway)
}
2019-11-17 15:09:37 +01:00
// Proxy is used to forward HTTP requests.
type Proxy struct {
2019-11-24 00:14:49 +01:00
certConfig *CertConfig
2019-11-17 15:09:37 +01:00
}
// NewProxy returns a new Proxy.
2019-11-24 00:14:49 +01:00
func NewProxy(ca *x509.Certificate, key crypto.PrivateKey) (*Proxy, error) {
certConfig, err := NewCertConfig(ca, key)
if err != nil {
return nil, err
2019-11-17 15:09:37 +01:00
}
2019-11-24 00:14:49 +01:00
return &Proxy{
certConfig: certConfig,
}, nil
2019-11-17 15:09:37 +01:00
}
func (p *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if r.Method == http.MethodConnect {
p.handleConnect(w, r)
return
}
2019-11-24 00:14:49 +01:00
httpHandler.ServeHTTP(w, r)
2019-11-17 15:09:37 +01:00
}
2019-11-24 00:14:49 +01:00
// handleConnect hijacks the incoming HTTP request and sets up an HTTP tunnel.
// During the TLS handshake with the client, we use the proxy's CA config to
// create a certificate on-the-fly.
2019-11-17 15:09:37 +01:00
func (p *Proxy) handleConnect(w http.ResponseWriter, r *http.Request) {
hj, ok := w.(http.Hijacker)
if !ok {
2019-11-24 00:14:49 +01:00
log.Printf("[ERROR] handleConnect: ResponseWriter is not a http.Hijacker (type: %T)", w)
2019-11-17 15:09:37 +01:00
writeError(w, r, http.StatusServiceUnavailable)
return
}
w.WriteHeader(http.StatusOK)
clientConn, _, err := hj.Hijack()
if err != nil {
2019-11-24 00:14:49 +01:00
log.Printf("[ERROR] Hijacking client connection failed: %v", err)
2019-11-17 15:09:37 +01:00
writeError(w, r, http.StatusServiceUnavailable)
return
}
defer clientConn.Close()
2019-11-24 00:14:49 +01:00
// Secure connection to client.
clientConn, err = p.clientTLSConn(clientConn)
if err != nil {
log.Printf("[ERROR] Securing client connection failed: %v", err)
return
}
clientConnNotify := ConnNotify{clientConn, make(chan struct{})}
l := &OnceAcceptListener{clientConnNotify.Conn}
err = http.Serve(l, httpsHandler)
if err != nil && err != ErrAlreadyAccepted {
log.Printf("[ERROR] Serving HTTP request failed: %v", err)
}
<-clientConnNotify.closed
2019-11-17 15:09:37 +01:00
}
2019-11-24 00:14:49 +01:00
func (p *Proxy) clientTLSConn(conn net.Conn) (*tls.Conn, error) {
tlsConfig := p.certConfig.TLSConfig()
tlsConn := tls.Server(conn, tlsConfig)
if err := tlsConn.Handshake(); err != nil {
tlsConn.Close()
return nil, fmt.Errorf("handshake error: %v", err)
}
return tlsConn, nil
2019-11-17 15:09:37 +01:00
}
// writeError replies with the given HTTP status code and that code's standard
// status text as a plain-text body. The request argument is currently unused.
func writeError(w http.ResponseWriter, r *http.Request, code int) {
	text := http.StatusText(code)
	http.Error(w, text, code)
}