From 2345fa3e3ecf4d9a46bdca9d420a71a6eee94284 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Matczuk?= Date: Thu, 12 Sep 2024 13:38:31 +0200 Subject: [PATCH] chore(martian): extract common code of proxyConn.handleConnectRequest() and proxyHandler.handleConnectRequest() to a function Add Proxy.Connect() that wraps connect() and also handles ConnectFunc and TLS termination. Fixes #445 --- internal/martian/proxy_conn.go | 31 ++--------------------------- internal/martian/proxy_connect.go | 29 +++++++++++++++++++++++++++ internal/martian/proxy_handler.go | 33 +------------------------------ 3 files changed, 32 insertions(+), 61 deletions(-) diff --git a/internal/martian/proxy_conn.go b/internal/martian/proxy_conn.go index 3c48d413..551807a9 100644 --- a/internal/martian/proxy_conn.go +++ b/internal/martian/proxy_conn.go @@ -203,35 +203,8 @@ func (p *proxyConn) handleConnectRequest(req *http.Request) error { return p.handleMITM(req) } - var ( - res *http.Response - crw io.ReadWriteCloser - cerr error - ) - if p.ConnectFunc != nil { - res, crw, cerr = p.ConnectFunc(req) - } - if p.ConnectFunc == nil || errors.Is(cerr, ErrConnectFallback) { - var cconn net.Conn - res, cconn, cerr = p.connect(req) - - if cconn != nil { - defer cconn.Close() - crw = cconn - - if terminateTLS { - log.Debugf(ctx, "attempting to terminate TLS on CONNECT tunnel: %s", req.URL.Host) - tconn := tls.Client(cconn, p.clientTLSConfig()) - if err := tconn.Handshake(); err == nil { - crw = tconn - } else { - log.Errorf(ctx, "failed to terminate TLS on CONNECT tunnel: %v", err) - cerr = err - } - } - } - } - + log.Debugf(ctx, "attempting to establish CONNECT tunnel: %s", req.URL.Host) + res, crw, cerr := p.Connect(ctx, req, terminateTLS) if res != nil { defer res.Body.Close() } diff --git a/internal/martian/proxy_connect.go b/internal/martian/proxy_connect.go index 443bd7f0..f8854e4f 100644 --- a/internal/martian/proxy_connect.go +++ b/internal/martian/proxy_connect.go @@ -17,6 +17,7 @@ package martian import ( + "context" "crypto/tls" "errors" "fmt" @@ -50,6 +51,34 @@ var ErrConnectFallback = errors.New("martian: connect fallback") // If the returned net.Conn is not nil, the response must be not nil. type ConnectFunc func(req *http.Request) (*http.Response, io.ReadWriteCloser, error) +func (p *Proxy) Connect(ctx context.Context, req *http.Request, terminateTLS bool) (res *http.Response, crw io.ReadWriteCloser, cerr error) { + if p.ConnectFunc != nil { + res, crw, cerr = p.ConnectFunc(req) + } + if p.ConnectFunc == nil || errors.Is(cerr, ErrConnectFallback) { + var cconn net.Conn + res, cconn, cerr = p.connect(req) + + if cconn != nil { + defer cconn.Close() + crw = cconn + + if terminateTLS { + log.Debugf(ctx, "attempting to terminate TLS on CONNECT tunnel: %s", req.URL.Host) + tconn := tls.Client(cconn, p.clientTLSConfig()) + if err := tconn.Handshake(); err == nil { + crw = tconn + } else { + log.Errorf(ctx, "failed to terminate TLS on CONNECT tunnel: %v", err) + cerr = err + } + } + } + } + + return +} + func (p *Proxy) connect(req *http.Request) (*http.Response, net.Conn, error) { ctx := req.Context() diff --git a/internal/martian/proxy_handler.go b/internal/martian/proxy_handler.go index 754732c5..35fedaf5 100644 --- a/internal/martian/proxy_handler.go +++ b/internal/martian/proxy_handler.go @@ -18,11 +18,8 @@ package martian import ( "context" - "crypto/tls" - "errors" "fmt" "io" - "net" "net/http" "strings" @@ -113,35 +110,7 @@ func (p proxyHandler) handleConnectRequest(rw http.ResponseWriter, req *http.Req } log.Debugf(ctx, "attempting to establish CONNECT tunnel: %s", req.URL.Host) - var ( - res *http.Response - crw io.ReadWriteCloser - cerr error - ) - if p.ConnectFunc != nil { - res, crw, cerr = p.ConnectFunc(req) - } - if p.ConnectFunc == nil || errors.Is(cerr, ErrConnectFallback) { - var cconn net.Conn - res, cconn, cerr = p.connect(req) - - if cconn != nil { - defer cconn.Close() - crw = cconn - - if terminateTLS { - log.Debugf(ctx, "attempting to terminate TLS on CONNECT tunnel: %s", req.URL.Host) - tconn := tls.Client(cconn, p.clientTLSConfig()) - if err := tconn.Handshake(); err == nil { - crw = tconn - } else { - log.Errorf(ctx, "failed to terminate TLS on CONNECT tunnel: %v", err) - cerr = err - } - } - } - } - + res, crw, cerr := p.Connect(ctx, req, terminateTLS) if res != nil { defer res.Body.Close() }