diff --git a/p2p/transport/websocket/addrs.go b/p2p/transport/websocket/addrs.go index fed649dcbc..5fea8567b5 100644 --- a/p2p/transport/websocket/addrs.go +++ b/p2p/transport/websocket/addrs.go @@ -132,7 +132,7 @@ func parseMultiaddr(maddr ma.Multiaddr) (*url.URL, error) { type parsedWebsocketMultiaddr struct { isWSS bool - // sni is the SNI value for the TLS handshake + // sni is the SNI value for the TLS handshake, and for setting HTTP Host header sni *ma.Component // the rest of the multiaddr before the /tls/sni/example.com/ws or /ws or /wss restMultiaddr ma.Multiaddr diff --git a/p2p/transport/websocket/websocket.go b/p2p/transport/websocket/websocket.go index 2e9fc0b032..f1294a5702 100644 --- a/p2p/transport/websocket/websocket.go +++ b/p2p/transport/websocket/websocket.go @@ -4,6 +4,7 @@ package websocket import ( "context" "crypto/tls" + "net" "net/http" "time" @@ -186,6 +187,17 @@ func (t *WebsocketTransport) maDial(ctx context.Context, raddr ma.Multiaddr) (ma copytlsClientConf := t.tlsClientConf.Clone() copytlsClientConf.ServerName = sni dialer.TLSClientConfig = copytlsClientConf + ipAddr := wsurl.Host + // Setting the NetDial because we already have the resolved IP address, so we don't want to do another resolution. + // We set the `.Host` to the sni field so that the host header gets properly set. + dialer.NetDial = func(network, address string) (net.Conn, error) { + tcpAddr, err := net.ResolveTCPAddr(network, ipAddr) + if err != nil { + return nil, err + } + return net.DialTCP("tcp", nil, tcpAddr) + } + wsurl.Host = sni + ":" + wsurl.Port() } else { dialer.TLSClientConfig = t.tlsClientConf } diff --git a/p2p/transport/websocket/websocket_test.go b/p2p/transport/websocket/websocket_test.go index 1961e9cec9..70e122d821 100644 --- a/p2p/transport/websocket/websocket_test.go +++ b/p2p/transport/websocket/websocket_test.go @@ -9,10 +9,13 @@ import ( "crypto/tls" "crypto/x509" "crypto/x509/pkix" + "errors" "fmt" "io" "math/big" "net" + "net/http" + "strings" "testing" "time" @@ -218,6 +221,44 @@ func getTLSConf(t *testing.T, ip net.IP, start, end time.Time) *tls.Config { } } +func TestHostHeaderWss(t *testing.T) { + server := &http.Server{} + l, err := net.Listen("tcp", ":0") + require.NoError(t, err) + defer server.Close() + + errChan := make(chan error, 1) + go func() { + server.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + defer close(errChan) + if !strings.Contains(r.Host, "example.com") { + errChan <- errors.New("Didn't see host header") + } + w.WriteHeader(http.StatusNotFound) + }) + server.TLSConfig = getTLSConf(t, net.ParseIP("127.0.0.1"), time.Now(), time.Now().Add(time.Hour)) + server.ServeTLS(l, "", "") + }() + + _, port, err := net.SplitHostPort(l.Addr().String()) + require.NoError(t, err) + serverMA := ma.StringCast("/ip4/127.0.0.1/tcp/" + port + "/tls/sni/example.com/ws") + + tlsConfig := &tls.Config{InsecureSkipVerify: true} // Our test server doesn't have a cert signed by a CA + _, u := newSecureUpgrader(t) + tpt, err := New(u, network.NullResourceManager, WithTLSClientConfig(tlsConfig)) + require.NoError(t, err) + + masToDial, err := tpt.Resolve(context.Background(), serverMA) + require.NoError(t, err) + + _, err = tpt.Dial(context.Background(), masToDial[0], test.RandPeerIDFatal(t)) + require.Error(t, err) + + err = <-errChan + require.NoError(t, err) +} + func TestDialWss(t *testing.T) { serverMA, rid, errChan := testWSSServer(t, ma.StringCast("/ip4/127.0.0.1/tcp/0/tls/sni/example.com/ws")) require.Contains(t, serverMA.String(), "tls")