From d9c76059aeea40b96dc3aa7592cb77600d64e42e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Matczuk?= Date: Wed, 27 Sep 2023 14:49:21 +0200 Subject: [PATCH] martian: add TLS termination in CONNECT This adds support for special header "X-Martian-Terminate-TLS". When set to true the proxy will perform TLS handshake before returning the connection. This has the benefit of using the proxy's TLS configuration for the handshake. --- internal/martian/handler.go | 13 ++++ internal/martian/proxy.go | 22 ++++++ internal/martian/proxy_test.go | 124 +++++++++++++++++++++++++++++++++ 3 files changed, 159 insertions(+) diff --git a/internal/martian/handler.go b/internal/martian/handler.go index 00da7095f..a5214b593 100644 --- a/internal/martian/handler.go +++ b/internal/martian/handler.go @@ -18,6 +18,7 @@ package martian import ( "context" + "crypto/tls" "fmt" "io" "net" @@ -143,6 +144,18 @@ func (p proxyHandler) handleConnectRequest(ctx *Context, rw http.ResponseWriter, defer cconn.Close() cr = cconn cw = cconn + + if shouldTerminateTLS(req) { + log.Debugf(req.Context(), "attempting to terminate TLS on CONNECT tunnel: %s", req.URL.Host) + tconn := tls.Client(cconn, p.clientTLSConfig()) + if err := tconn.Handshake(); err == nil { + cr = tconn + cw = tconn + } else { + log.Errorf(req.Context(), "failed to terminate TLS on CONNECT tunnel: %v", err) + cerr = err + } + } } } diff --git a/internal/martian/proxy.go b/internal/martian/proxy.go index 9d74d55bf..18d0fbf73 100644 --- a/internal/martian/proxy.go +++ b/internal/martian/proxy.go @@ -25,6 +25,7 @@ import ( "net" "net/http" "net/url" + "strconv" "strings" "sync" "time" @@ -417,6 +418,15 @@ func (p *Proxy) shouldMITM(req *http.Request) bool { return true } +func shouldTerminateTLS(req *http.Request) bool { + h := req.Header.Get("X-Martian-Terminate-TLS") + if h == "" { + return false + } + b, _ := strconv.ParseBool(h) + return b +} + func (p *Proxy) handleMITM(ctx *Context, req *http.Request, session *Session, brw *bufio.ReadWriter, conn net.Conn) error { log.Debugf(req.Context(), "attempting MITM for connection: %s / %s", req.Host, req.URL.String()) @@ -520,6 +530,18 @@ func (p *Proxy) handleConnectRequest(ctx *Context, req *http.Request, session *S defer cconn.Close() cr = cconn cw = cconn + + if shouldTerminateTLS(req) { + log.Debugf(req.Context(), "attempting to terminate TLS on CONNECT tunnel: %s", req.URL.Host) + tconn := tls.Client(cconn, p.clientTLSConfig()) + if err := tconn.Handshake(); err == nil { + cr = tconn + cw = tconn + } else { + log.Errorf(req.Context(), "failed to terminate TLS on CONNECT tunnel: %v", err) + cerr = err + } + } } } diff --git a/internal/martian/proxy_test.go b/internal/martian/proxy_test.go index 91fe9fca3..b425277bd 100644 --- a/internal/martian/proxy_test.go +++ b/internal/martian/proxy_test.go @@ -1076,6 +1076,130 @@ func TestIntegrationConnectPassthrough(t *testing.T) { } } +func TestIntegrationConnectTerminateTLS(t *testing.T) { + t.Parallel() + + l := newListener(t) + p := NewProxy() + defer p.Close() + + // Test TLS server. + ca, priv, err := mitm.NewAuthority("martian.proxy", "Martian Authority", time.Hour) + if err != nil { + t.Fatalf("mitm.NewAuthority(): got %v, want no error", err) + } + mc, err := mitm.NewConfig(ca, priv) + if err != nil { + t.Fatalf("mitm.NewConfig(): got %v, want no error", err) + } + + // Set the TLS config to terminate TLS. + roots := x509.NewCertPool() + roots.AddCert(ca) + + rt := http.DefaultTransport.(*http.Transport).Clone() //nolint:force + rt.TLSClientConfig = &tls.Config{ + ServerName: "example.com", + RootCAs: roots, + } + p.SetRoundTripper(rt) + + tl, err := net.Listen("tcp", "[::]:0") + if err != nil { + t.Fatalf("tls.Listen(): got %v, want no error", err) + } + tl = tls.NewListener(tl, mc.TLS()) + + go http.Serve(tl, http.HandlerFunc( + func(rw http.ResponseWriter, req *http.Request) { + rw.WriteHeader(299) + })) + + tm := martiantest.NewModifier() + reqerr := errors.New("request error") + reserr := errors.New("response error") + + // Force the CONNECT request to dial the local TLS server. + tm.RequestFunc(func(req *http.Request) { + req.URL.Host = tl.Addr().String() + }) + + tm.RequestError(reqerr) + tm.ResponseError(reserr) + + p.SetRequestModifier(tm) + p.SetResponseModifier(tm) + + go serve(p, l) + + conn, err := l.dial() + if err != nil { + t.Fatalf("net.Dial(): got %v, want no error", err) + } + defer conn.Close() + + req, err := http.NewRequest(http.MethodConnect, "//example.com:443", http.NoBody) + if err != nil { + t.Fatalf("http.NewRequest(): got %v, want no error", err) + } + req.Header.Set("X-Martian-Terminate-TLS", "true") + + // CONNECT example.com:443 HTTP/1.1 + // Host: example.com + // X-Martian-Terminate-TLS: true + // + // Rewritten to CONNECT to host:port in CONNECT request modifier. + if err := req.Write(conn); err != nil { + t.Fatalf("req.Write(): got %v, want no error", err) + } + + // CONNECT response after establishing tunnel. + res, err := http.ReadResponse(bufio.NewReader(conn), req) + if err != nil { + t.Fatalf("http.ReadResponse(): got %v, want no error", err) + } + + if got, want := res.StatusCode, 200; got != want { + t.Fatalf("res.StatusCode: got %d, want %d", got, want) + } + + if !tm.RequestModified() { + t.Error("tm.RequestModified(): got false, want true") + } + if !tm.ResponseModified() { + t.Error("tm.ResponseModified(): got false, want true") + } + if got, want := res.Header.Get("Warning"), reserr.Error(); !strings.Contains(got, want) { + t.Errorf("res.Header.Get(%q): got %q, want to contain %q", "Warning", got, want) + } + + req, err = http.NewRequest(http.MethodGet, "https://example.com", http.NoBody) + if err != nil { + t.Fatalf("http.NewRequest(): got %v, want no error", err) + } + req.Header.Set("Connection", "close") + + // GET / HTTP/1.1 + // Host: example.com + // Connection: close + if err := req.Write(conn); err != nil { + t.Fatalf("req.Write(): got %v, want no error", err) + } + + res, err = http.ReadResponse(bufio.NewReader(conn), req) + if err != nil { + t.Fatalf("http.ReadResponse(): got %v, want no error", err) + } + defer res.Body.Close() + + if got, want := res.StatusCode, 299; got != want { + t.Fatalf("res.StatusCode: got %d, want %d", got, want) + } + if got, want := res.Header.Get("Warning"), reserr.Error(); strings.Contains(got, want) { + t.Errorf("res.Header.Get(%q): got %s, want to not contain %s", "Warning", got, want) + } +} + func TestIntegrationMITM(t *testing.T) { t.Parallel()