From ce96de2ed5cfede3d7458de7c5282f3183f8bba2 Mon Sep 17 00:00:00 2001 From: Ivan Kozlovic Date: Mon, 9 Oct 2023 15:40:07 -0600 Subject: [PATCH] [ADDED] TLS: Handshake First for client connections A new option instructs the server to perform the TLS handshake first, that is prior to sending the INFO protocol to the client. Only clients that implement an equivalent option would be able to connect if the server runs with this option enabled. The configuration would look something like this: ``` ... tls { cert_file: ... key_file: ... handshake_first: true } ``` The same option can be set to "auto" or a Go time duration to fall back to the old behavior. This is intended for deployments where it is known that not all clients have been upgraded to a client library providing the TLS handshake first option. After the delay has elapsed without receiving the TLS handshake from the client, the server reverts to sending the INFO protocol so that older clients can connect. Clients that do connect with the "TLS first" option will be marked as such in the monitoring's Connz page/result. It will allow the administrator to keep track of applications still needing to upgrade. The configuration would be similar to: ``` ... tls { cert_file: ... key_file: ... handshake_first: auto } ``` With the above value, the fallback delay used by the server is 50ms. The duration can be explicitly set, say 300 milliseconds: ``` ... tls { cert_file: ... key_file: ... handshake_first: "300ms" } ``` It is understood that any configuration other than "true" will result in the server sending the INFO protocol after the elapsed amount of time without the client initiating the TLS handshake. Therefore, for administrators that do not want any data transmitted in plain text, the value must be set to "true" only. It will require applications to be updated to a library that provides the option, which may or may not be readily available. 
Signed-off-by: Ivan Kozlovic --- go.mod | 2 +- go.sum | 4 +- server/client.go | 1 + server/client_test.go | 342 ++++++++++++++++++++++++++++++++++++ server/config_check_test.go | 24 +++ server/const.go | 6 + server/monitor.go | 2 + server/opts.go | 52 +++++- server/reload.go | 26 +++ server/server.go | 90 ++++++++-- test/tls_test.go | 1 + 11 files changed, 528 insertions(+), 22 deletions(-) diff --git a/go.mod b/go.mod index 62917a587e1..f8fab6b30df 100644 --- a/go.mod +++ b/go.mod @@ -6,7 +6,7 @@ require ( github.com/klauspost/compress v1.17.0 github.com/minio/highwayhash v1.0.2 github.com/nats-io/jwt/v2 v2.5.2 - github.com/nats-io/nats.go v1.30.2 + github.com/nats-io/nats.go v1.30.3-0.20231009181226-1941a1a4f14f github.com/nats-io/nkeys v0.4.5 github.com/nats-io/nuid v1.0.1 go.uber.org/automaxprocs v1.5.3 diff --git a/go.sum b/go.sum index 72369f26e5d..cc87607cea8 100644 --- a/go.sum +++ b/go.sum @@ -15,8 +15,8 @@ github.com/minio/highwayhash v1.0.2 h1:Aak5U0nElisjDCfPSG79Tgzkn2gl66NxOMspRrKnA github.com/minio/highwayhash v1.0.2/go.mod h1:BQskDq+xkJ12lmlUUi7U0M5Swg3EWR+dLTk+kldvVxY= github.com/nats-io/jwt/v2 v2.5.2 h1:DhGH+nKt+wIkDxM6qnVSKjokq5t59AZV5HRcFW0zJwU= github.com/nats-io/jwt/v2 v2.5.2/go.mod h1:24BeQtRwxRV8ruvC4CojXlx/WQ/VjuwlYiH+vu/+ibI= -github.com/nats-io/nats.go v1.30.2 h1:aloM0TGpPorZKQhbAkdCzYDj+ZmsJDyeo3Gkbr72NuY= -github.com/nats-io/nats.go v1.30.2/go.mod h1:dcfhUgmQNN4GJEfIb2f9R7Fow+gzBF4emzDHrVBd5qM= +github.com/nats-io/nats.go v1.30.3-0.20231009181226-1941a1a4f14f h1:1OBmQ3HJsJAX4vemhoCQjonLBaQ7yx/7PUe6oF1kzvE= +github.com/nats-io/nats.go v1.30.3-0.20231009181226-1941a1a4f14f/go.mod h1:dcfhUgmQNN4GJEfIb2f9R7Fow+gzBF4emzDHrVBd5qM= github.com/nats-io/nkeys v0.4.5 h1:Zdz2BUlFm4fJlierwvGK+yl20IAKUm7eV6AAZXEhkPk= github.com/nats-io/nkeys v0.4.5/go.mod h1:XUkxdLPTufzlihbamfzQ7mw/VGx6ObUs+0bN5sNvt64= github.com/nats-io/nuid v1.0.1 h1:5iA8DT8V7q8WK2EScv2padNa/rTESc1KdnPw4TC2paw= diff --git a/server/client.go b/server/client.go index 
e3364c8a80f..afa849b8784 100644 --- a/server/client.go +++ b/server/client.go @@ -141,6 +141,7 @@ const ( expectConnect // Marks if this connection is expected to send a CONNECT connectProcessFinished // Marks if this connection has finished the connect process. compressionNegotiated // Marks if this connection has negotiated compression level with remote. + didTLSFirst // Marks if this connection requested and was accepted doing the TLS handshake first (prior to INFO). ) // set the flag (would be equivalent to set the boolean to true) diff --git a/server/client_test.go b/server/client_test.go index 95051a4f66e..55fef6fff38 100644 --- a/server/client_test.go +++ b/server/client_test.go @@ -2602,3 +2602,345 @@ func TestClientUserInfoReq(t *testing.T) { t.Fatalf("User info for %q did not match", "admin") } } + +func TestTLSClientHandshakeFirst(t *testing.T) { + tmpl := ` + listen: "127.0.0.1:-1" + tls { + cert_file: "../test/configs/certs/server-cert.pem" + key_file: "../test/configs/certs/server-key.pem" + timeout: 1 + first: %s + } + ` + conf := createConfFile(t, []byte(fmt.Sprintf(tmpl, "true"))) + s, o := RunServerWithConfig(conf) + defer s.Shutdown() + + connect := func(tlsfirst, expectedOk bool) { + opts := []nats.Option{nats.RootCAs("../test/configs/certs/ca.pem")} + if tlsfirst { + opts = append(opts, nats.TLSHandshakeFirst()) + } + nc, err := nats.Connect(fmt.Sprintf("tls://localhost:%d", o.Port), opts...) + if expectedOk { + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + if tlsfirst { + cz, err := s.Connz(nil) + if err != nil { + t.Fatalf("Error getting connz: %v", err) + } + if !cz.Conns[0].TLSFirst { + t.Fatal("Expected TLSFirst boolean to be set, it was not") + } + } + } else if !expectedOk && err == nil { + nc.Close() + t.Fatal("Expected error, got none") + } + } + + // Server is TLS first, but client is not, so should fail. + connect(false, false) + + // Now client is TLS first too, so should work. 
+ connect(true, true) + + // Config reload the server and disable tls first + reloadUpdateConfig(t, s, conf, fmt.Sprintf(tmpl, "false")) + + // Now if client wants TLS first, connection should fail. + connect(true, false) + + // But if it does not, should be ok. + connect(false, true) + + // Config reload the server again and enable tls first + reloadUpdateConfig(t, s, conf, fmt.Sprintf(tmpl, "true")) + + // If both client and server are TLS first, this should work. + connect(true, true) +} + +func TestTLSClientHandshakeFirstFallbackDelayConfigValues(t *testing.T) { + tmpl := ` + listen: "127.0.0.1:-1" + tls { + cert_file: "../test/configs/certs/server-cert.pem" + key_file: "../test/configs/certs/server-key.pem" + timeout: 1 + first: %s + } + ` + for _, test := range []struct { + name string + val string + first bool + delay time.Duration + }{ + {"first as boolean true", "true", true, 0}, + {"first as boolean false", "false", false, 0}, + {"first as string true", "\"true\"", true, 0}, + {"first as string false", "\"false\"", false, 0}, + {"first as string on", "on", true, 0}, + {"first as string off", "off", false, 0}, + {"first as string auto", "auto", true, DEFAULT_TLS_HANDSHAKE_FIRST_FALLBACK_DELAY}, + {"first as string auto_fallback", "auto_fallback", true, DEFAULT_TLS_HANDSHAKE_FIRST_FALLBACK_DELAY}, + {"first as fallback duration", "300ms", true, 300 * time.Millisecond}, + } { + t.Run(test.name, func(t *testing.T) { + conf := createConfFile(t, []byte(fmt.Sprintf(tmpl, test.val))) + s, o := RunServerWithConfig(conf) + defer s.Shutdown() + + if test.first { + if !o.TLSHandshakeFirst { + t.Fatal("Expected tls first to be true, was not") + } + if test.delay != o.TLSHandshakeFirstFallback { + t.Fatalf("Expected fallback delay to be %v, got %v", test.delay, o.TLSHandshakeFirstFallback) + } + } else { + if o.TLSHandshakeFirst { + t.Fatal("Expected tls first to be false, was not") + } + if o.TLSHandshakeFirstFallback != 0 { + t.Fatalf("Expected fallback delay to be 
0, got %v", o.TLSHandshakeFirstFallback) + } + } + }) + } +} + +type pauseAfterDial struct { + delay time.Duration +} + +func (d *pauseAfterDial) Dial(network, address string) (net.Conn, error) { + c, err := net.Dial(network, address) + if err != nil { + return nil, err + } + time.Sleep(d.delay) + return c, nil +} + +func TestTLSClientHandshakeFirstFallbackDelay(t *testing.T) { + // Using certificates with RSA 4K to make sure that the fallback does + // not prevent a client with TLS first to successfully connect. + tmpl := ` + listen: "127.0.0.1:-1" + tls { + cert_file: "./configs/certs/tls/benchmark-server-cert-rsa-4096.pem" + key_file: "./configs/certs/tls/benchmark-server-key-rsa-4096.pem" + timeout: 1 + first: %s + } + ` + conf := createConfFile(t, []byte(fmt.Sprintf(tmpl, "auto"))) + s, o := RunServerWithConfig(conf) + defer s.Shutdown() + + url := fmt.Sprintf("tls://localhost:%d", o.Port) + d := &pauseAfterDial{delay: DEFAULT_TLS_HANDSHAKE_FIRST_FALLBACK_DELAY + 100*time.Millisecond} + + // Connect a client without "TLS first" and it should be accepted. + nc, err := nats.Connect(url, + nats.SetCustomDialer(d), + nats.Secure(&tls.Config{ + ServerName: "reuben.nats.io", + MinVersion: tls.VersionTLS12, + }), + nats.RootCAs("./configs/certs/tls/benchmark-ca-cert.pem")) + require_NoError(t, err) + defer nc.Close() + // Check that the TLS first in monitoring is set to false + cs, err := s.Connz(nil) + require_NoError(t, err) + if cs.Conns[0].TLSFirst { + t.Fatal("Expected monitoring ConnInfo.TLSFirst to be false, it was not") + } + nc.Close() + + // Wait for the client to be removed + checkClientsCount(t, s, 0) + + // Increase the fallback delay with config reload. + reloadUpdateConfig(t, s, conf, fmt.Sprintf(tmpl, "\"1s\"")) + + // This time, start the client with "TLS first". + // We will also make sure that we did not wait for the fallback delay + // in order to connect. 
+ start := time.Now() + nc, err = nats.Connect(url, + nats.SetCustomDialer(d), + nats.Secure(&tls.Config{ + ServerName: "reuben.nats.io", + MinVersion: tls.VersionTLS12, + }), + nats.RootCAs("./configs/certs/tls/benchmark-ca-cert.pem"), + nats.TLSHandshakeFirst()) + require_NoError(t, err) + require_True(t, time.Since(start) < 500*time.Millisecond) + defer nc.Close() + + // Check that the TLS first in monitoring is set to true. + cs, err = s.Connz(nil) + require_NoError(t, err) + if !cs.Conns[0].TLSFirst { + t.Fatal("Expected monitoring ConnInfo.TLSFirst to be true, it was not") + } + nc.Close() +} + +func TestTLSClientHandshakeFirstFallbackDelayAndAllowNonTLS(t *testing.T) { + tmpl := ` + listen: "127.0.0.1:-1" + tls { + cert_file: "../test/configs/certs/server-cert.pem" + key_file: "../test/configs/certs/server-key.pem" + timeout: 1 + first: %s + } + allow_non_tls: true + ` + conf := createConfFile(t, []byte(fmt.Sprintf(tmpl, "true"))) + s, o := RunServerWithConfig(conf) + defer s.Shutdown() + + // We first start with a server that has handshake first set to true + // and allow_non_tls. In that case, only "TLS first" clients should be + // accepted. + url := fmt.Sprintf("tls://localhost:%d", o.Port) + nc, err := nats.Connect(url, + nats.RootCAs("../test/configs/certs/ca.pem"), + nats.TLSHandshakeFirst()) + require_NoError(t, err) + defer nc.Close() + // Check that the TLS first in monitoring is set to true + cs, err := s.Connz(nil) + require_NoError(t, err) + if !cs.Conns[0].TLSFirst { + t.Fatal("Expected monitoring ConnInfo.TLSFirst to be true, it was not") + } + nc.Close() + + // Client not using "TLS First" should fail. + nc, err = nats.Connect(url, nats.RootCAs("../test/configs/certs/ca.pem")) + if err == nil { + nc.Close() + t.Fatal("Expected connection to fail, it did not") + } + + // And non TLS clients should also fail to connect. 
+ nc, err = nats.Connect(fmt.Sprintf("nats://127.0.0.1:%d", o.Port)) + if err == nil { + nc.Close() + t.Fatal("Expected connection to fail, it did not") + } + + // Now we will replace TLS first in server with a fallback delay. + reloadUpdateConfig(t, s, conf, fmt.Sprintf(tmpl, "\"25ms\"")) + + // Clients with "TLS first" should still be able to connect + nc, err = nats.Connect(url, + nats.RootCAs("../test/configs/certs/ca.pem"), + nats.TLSHandshakeFirst()) + require_NoError(t, err) + defer nc.Close() + + checkConnInfo := func(isTLS, isTLSFirst bool) { + t.Helper() + cs, err = s.Connz(nil) + require_NoError(t, err) + conn := cs.Conns[0] + if !isTLS { + if conn.TLSVersion != _EMPTY_ { + t.Fatalf("Being a non TLS client, there should not be TLSVersion set, got %v", conn.TLSVersion) + } + if conn.TLSFirst { + t.Fatal("Being a non TLS client, TLSFirst should not be set, but it was") + } + return + } + if isTLSFirst && !conn.TLSFirst { + t.Fatal("Expected monitoring ConnInfo.TLSFirst to be true, it was not") + } else if !isTLSFirst && conn.TLSFirst { + t.Fatal("Expected monitoring ConnInfo.TLSFirst to be false, it was not") + } + nc.Close() + + checkClientsCount(t, s, 0) + } + checkConnInfo(true, true) + + // Clients with TLS but not "TLS first" should also be able to connect. + nc, err = nats.Connect(url, nats.RootCAs("../test/configs/certs/ca.pem")) + require_NoError(t, err) + defer nc.Close() + checkConnInfo(true, false) + + // And non TLS clients should also be able to connect. 
+ nc, err = nats.Connect(fmt.Sprintf("nats://127.0.0.1:%d", o.Port)) + require_NoError(t, err) + defer nc.Close() + checkConnInfo(false, false) +} + +func TestTLSClientHandshakeFirstAndInProcessConnection(t *testing.T) { + conf := createConfFile(t, []byte(` + listen: "127.0.0.1:-1" + tls { + cert_file: "../test/configs/certs/server-cert.pem" + key_file: "../test/configs/certs/server-key.pem" + timeout: 1 + first: true + } + `)) + s, _ := RunServerWithConfig(conf) + defer s.Shutdown() + + // Check that we can create an in process connection that does not use TLS + nc, err := nats.Connect(_EMPTY_, nats.InProcessServer(s)) + require_NoError(t, err) + defer nc.Close() + if nc.TLSRequired() { + t.Fatalf("Shouldn't have required TLS for in-process connection") + } + if _, err = nc.TLSConnectionState(); err == nil { + t.Fatal("Should have got an error retrieving TLS connection state") + } + nc.Close() + + // If the client wants TLS, it should get a TLS connection. + nc, err = nats.Connect(_EMPTY_, + nats.InProcessServer(s), + nats.RootCAs("../test/configs/certs/ca.pem")) + require_NoError(t, err) + defer nc.Close() + if _, err = nc.TLSConnectionState(); err != nil { + t.Fatal("Should have not got an error retrieving TLS connection state") + } + // However, the server would not have sent that TLS was required, + // but instead it is available. + if nc.TLSRequired() { + t.Fatalf("Shouldn't have required TLS for in-process connection") + } + nc.Close() + + // The in-process connection with TLS and "TLS first" should also be working. 
+ nc, err = nats.Connect(_EMPTY_, + nats.InProcessServer(s), + nats.RootCAs("../test/configs/certs/ca.pem"), + nats.TLSHandshakeFirst()) + require_NoError(t, err) + defer nc.Close() + if !nc.TLSRequired() { + t.Fatalf("The server should have sent that TLS is required") + } + if _, err = nc.TLSConnectionState(); err != nil { + t.Fatal("Should have not got an error retrieving TLS connection state") + } +} diff --git a/server/config_check_test.go b/server/config_check_test.go index 34c95b17e05..a07962b7a18 100644 --- a/server/config_check_test.go +++ b/server/config_check_test.go @@ -1808,6 +1808,30 @@ func TestConfigCheck(t *testing.T) { errorPos: 0, reason: "", }, + { + name: "TLS handshake first, wrong type", + config: ` + port: -1 + tls { + first: 123 + } + `, + err: fmt.Errorf("field %q should be a boolean or a string, got int64", "first"), + errorLine: 4, + errorPos: 6, + }, + { + name: "TLS handshake first, wrong value", + config: ` + port: -1 + tls { + first: "123" + } + `, + err: fmt.Errorf("field %q's value %q is invalid", "first", "123"), + errorLine: 4, + errorPos: 6, + }, } checkConfig := func(config string) error { diff --git a/server/const.go b/server/const.go index 64ec6b6267f..91b5c76e04b 100644 --- a/server/const.go +++ b/server/const.go @@ -82,6 +82,12 @@ const ( // TLS_TIMEOUT is the TLS wait time. TLS_TIMEOUT = 2 * time.Second + // DEFAULT_TLS_HANDSHAKE_FIRST_FALLBACK_DELAY is the default amount of + // time for the server to wait for the TLS handshake with a client to + // be initiated before falling back to sending the INFO protocol first. + // See TLSHandshakeFirst and TLSHandshakeFirstFallback options. + DEFAULT_TLS_HANDSHAKE_FIRST_FALLBACK_DELAY = 50 * time.Millisecond + // AUTH_TIMEOUT is the authorization wait time. 
AUTH_TIMEOUT = 2 * time.Second diff --git a/server/monitor.go b/server/monitor.go index 66f5e81a363..073c468e0d7 100644 --- a/server/monitor.go +++ b/server/monitor.go @@ -130,6 +130,7 @@ type ConnInfo struct { TLSVersion string `json:"tls_version,omitempty"` TLSCipher string `json:"tls_cipher_suite,omitempty"` TLSPeerCerts []*TLSPeerCert `json:"tls_peer_certs,omitempty"` + TLSFirst bool `json:"tls_first,omitempty"` AuthorizedUser string `json:"authorized_user,omitempty"` Account string `json:"account,omitempty"` Subs []string `json:"subscriptions_list,omitempty"` @@ -568,6 +569,7 @@ func (ci *ConnInfo) fill(client *client, nc net.Conn, now time.Time, auth bool) if auth && len(cs.PeerCertificates) > 0 { ci.TLSPeerCerts = makePeerCerts(cs.PeerCertificates) } + ci.TLSFirst = client.flags.isSet(didTLSFirst) } } diff --git a/server/opts.go b/server/opts.go index 039f982c027..38b5799e1d0 100644 --- a/server/opts.go +++ b/server/opts.go @@ -327,11 +327,23 @@ type Options struct { TLSConfig *tls.Config `json:"-"` TLSPinnedCerts PinnedCertSet `json:"-"` TLSRateLimit int64 `json:"-"` - AllowNonTLS bool `json:"-"` - WriteDeadline time.Duration `json:"-"` - MaxClosedClients int `json:"-"` - LameDuckDuration time.Duration `json:"-"` - LameDuckGracePeriod time.Duration `json:"-"` + // When set to true, the server will perform the TLS handshake before + // sending the INFO protocol. For clients that are not configured + // with a similar option, their connection will fail with some sort + // of timeout or EOF error since they are expecting to receive an + // INFO protocol first. + TLSHandshakeFirst bool `json:"-"` + // If TLSHandshakeFirst is true and this value is strictly positive, + // the server will wait for that amount of time for the TLS handshake + // to start before falling back to previous behavior of sending the + // INFO protocol first. It allows for a mix of newer clients that can + // require a TLS handshake first, and older clients that can't. 
+ TLSHandshakeFirstFallback time.Duration `json:"-"` + AllowNonTLS bool `json:"-"` + WriteDeadline time.Duration `json:"-"` + MaxClosedClients int `json:"-"` + LameDuckDuration time.Duration `json:"-"` + LameDuckGracePeriod time.Duration `json:"-"` // MaxTracedMsgLen is the maximum printable length for traced messages. MaxTracedMsgLen int `json:"-"` @@ -638,7 +650,8 @@ type TLSConfigOpts struct { Insecure bool Map bool TLSCheckKnownURLs bool - HandshakeFirst bool // Indicate that the TLS handshake should occur first, before sending the INFO protocol + HandshakeFirst bool // Indicate that the TLS handshake should occur first, before sending the INFO protocol. + FallbackDelay time.Duration // Where supported, indicates how long to wait for the handshake before falling back to sending the INFO protocol first. Timeout float64 RateLimit int64 Ciphers []uint16 @@ -1072,6 +1085,8 @@ func (o *Options) processConfigFileLine(k string, v interface{}, errors *[]error o.TLSMap = tc.Map o.TLSPinnedCerts = tc.PinnedCerts o.TLSRateLimit = tc.RateLimit + o.TLSHandshakeFirst = tc.HandshakeFirst + o.TLSHandshakeFirstFallback = tc.FallbackDelay // Need to keep track of path of the original TLS config // and certs path for OCSP Stapling monitoring. @@ -4312,7 +4327,30 @@ func parseTLS(v interface{}, isClientCtx bool) (t *TLSConfigOpts, retErr error) } tc.CertMatch = certMatch case "handshake_first", "first", "immediate": - tc.HandshakeFirst = mv.(bool) + switch mv := mv.(type) { + case bool: + tc.HandshakeFirst = mv + case string: + switch strings.ToLower(mv) { + case "true", "on": + tc.HandshakeFirst = true + case "false", "off": + tc.HandshakeFirst = false + case "auto", "auto_fallback": + tc.HandshakeFirst = true + tc.FallbackDelay = DEFAULT_TLS_HANDSHAKE_FIRST_FALLBACK_DELAY + default: + // Check to see if this is a duration. 
+ if dur, err := time.ParseDuration(mv); err == nil { + tc.HandshakeFirst = true + tc.FallbackDelay = dur + break + } + return nil, &configErr{tk, fmt.Sprintf("field %q's value %q is invalid", mk, mv)} + } + default: + return nil, &configErr{tk, fmt.Sprintf("field %q should be a boolean or a string, got %T", mk, mv)} + } case "ocsp_peer": switch vv := mv.(type) { case bool: diff --git a/server/reload.go b/server/reload.go index 239881715e1..4e1c2f71b5e 100644 --- a/server/reload.go +++ b/server/reload.go @@ -266,6 +266,28 @@ func (t *tlsPinnedCertOption) Apply(server *Server) { server.Noticef("Reloaded: %d pinned_certs", len(t.newValue)) } +// tlsHandshakeFirst implements the option interface for the tls `handshake first` setting. +type tlsHandshakeFirst struct { + noopOption + newValue bool +} + +// Apply only logs the new value; no other server state is updated here. +func (t *tlsHandshakeFirst) Apply(server *Server) { + server.Noticef("Reloaded: Client TLS handshake first: %v", t.newValue) +} + +// tlsHandshakeFirstFallback implements the option interface for the tls `handshake first fallback delay` setting. +type tlsHandshakeFirstFallback struct { + noopOption + newValue time.Duration +} + +// Apply only logs the new value; no other server state is updated here. +func (t *tlsHandshakeFirstFallback) Apply(server *Server) { + server.Noticef("Reloaded: Client TLS handshake first fallback delay: %v", t.newValue) +} + // authOption is a base struct that provides default option behaviors. 
type authOption struct { noopOption @@ -1222,6 +1244,10 @@ func (s *Server) diffOptions(newOpts *Options) ([]option, error) { diffOpts = append(diffOpts, &tlsTimeoutOption{newValue: newValue.(float64)}) case "tlspinnedcerts": diffOpts = append(diffOpts, &tlsPinnedCertOption{newValue: newValue.(PinnedCertSet)}) + case "tlshandshakefirst": + diffOpts = append(diffOpts, &tlsHandshakeFirst{newValue: newValue.(bool)}) + case "tlshandshakefirstfallback": + diffOpts = append(diffOpts, &tlsHandshakeFirstFallback{newValue: newValue.(time.Duration)}) case "username": diffOpts = append(diffOpts, &usernameOption{}) case "password": diff --git a/server/server.go b/server/server.go index d1d0d109d49..b113890305b 100644 --- a/server/server.go +++ b/server/server.go @@ -2573,6 +2573,9 @@ func (s *Server) AcceptLoop(clr chan struct{}) { // Alert of TLS enabled. if opts.TLSConfig != nil { s.Noticef("TLS required for client connections") + if opts.TLSHandshakeFirst && opts.TLSHandshakeFirstFallback == 0 { + s.Warnf("Clients that are not using \"TLS Handshake First\" option will fail to connect") + } } // If server was started with RANDOM_PORT (-1), opts.Port would be equal @@ -3041,10 +3044,37 @@ func (s *Server) createClientEx(conn net.Conn, inProcess bool) *client { c.Debugf("Client connection created") - // Send our information. - // Need to be sent in place since writeLoop cannot be started until - // TLS handshake is done (if applicable). - c.sendProtoNow(c.generateClientInfoJSON(info)) + // Save info.TLSRequired value since we may need to change it back and forth. + orgInfoTLSReq := info.TLSRequired + + var tlsFirstFallback time.Duration + // Check if we should do TLS first. + tlsFirst := opts.TLSConfig != nil && opts.TLSHandshakeFirst + if tlsFirst { + // Make sure info.TLSRequired is set to true (it could be false + // if AllowNonTLS is enabled). + info.TLSRequired = true + // Get the fallback delay value if applicable. 
+ if f := opts.TLSHandshakeFirstFallback; f > 0 { + tlsFirstFallback = f + } else if inProcess { + // For in-process connection, we will always have a fallback + // delay. It allows support for non-TLS, TLS and "TLS First" + // in-process clients to successfully connect. + tlsFirstFallback = DEFAULT_TLS_HANDSHAKE_FIRST_FALLBACK_DELAY + } + } + + // Decide if we are going to require TLS or not and generate INFO json. + tlsRequired := info.TLSRequired + infoBytes := c.generateClientInfoJSON(info) + + // Send our information, except if TLS and TLSHandshakeFirst is requested. + if !tlsFirst { + // Need to be sent in place since writeLoop cannot be started until + // TLS handshake is done (if applicable). + c.sendProtoNow(infoBytes) + } // Unlock to register c.mu.Unlock() @@ -3077,20 +3107,50 @@ func (s *Server) createClientEx(conn net.Conn, inProcess bool) *client { } s.clients[c.cid] = c - tlsRequired := info.TLSRequired s.mu.Unlock() // Re-Grab lock c.mu.Lock() - // Connection could have been closed while sending the INFO proto. isClosed := c.isClosed() - var pre []byte + // We need first to check for "TLS First" fallback delay. + if !isClosed && tlsFirstFallback > 0 { + // We wait and see if we are getting any data. Since we did not send + // the INFO protocol yet, only clients that use TLS first should be + // sending data (the TLS handshake). We don't really check the content: + // if it is a rogue agent and not an actual client performing the + // TLS handshake, the error will be detected when performing the + // handshake on our side. + pre = make([]byte, 4) + c.nc.SetReadDeadline(time.Now().Add(tlsFirstFallback)) + n, _ := io.ReadFull(c.nc, pre[:]) + c.nc.SetReadDeadline(time.Time{}) + // If we get any data (regardless of possible timeout), we will proceed + // with the TLS handshake. + if n > 0 { + pre = pre[:n] + } else { + // We did not get anything so we will send the INFO protocol. 
+ pre = nil + + // Restore the original info.TLSRequired value if it is + // different than the current value and regenerate infoBytes. + if orgInfoTLSReq != info.TLSRequired { + info.TLSRequired = orgInfoTLSReq + infoBytes = c.generateClientInfoJSON(info) + } + c.sendProtoNow(infoBytes) + // Set the boolean to false for the rest of the function. + tlsFirst = false + // Check closed status again + isClosed = c.isClosed() + } + } // If we have both TLS and non-TLS allowed we need to see which // one the client wants. We'll always allow this for in-process // connections. - if !isClosed && opts.TLSConfig != nil && (inProcess || opts.AllowNonTLS) { + if !isClosed && !tlsFirst && opts.TLSConfig != nil && (inProcess || opts.AllowNonTLS) { pre = make([]byte, 4) c.nc.SetReadDeadline(time.Now().Add(secondsToDuration(opts.TLSTimeout))) n, _ := io.ReadFull(c.nc, pre[:]) @@ -3125,12 +3185,18 @@ func (s *Server) createClientEx(conn net.Conn, inProcess bool) *client { } } - // If connection is marked as closed, bail out. + // Now, send the INFO if it was delayed + if !isClosed && tlsFirst { + c.flags.set(didTLSFirst) + c.sendProtoNow(infoBytes) + // Check closed status + isClosed = c.isClosed() + } + + // Connection could have been closed while sending the INFO proto. if isClosed { c.mu.Unlock() - // Connection could have been closed due to TLS timeout or while trying - // to send the INFO protocol. We need to call closeConnection() to make - // sure that proper cleanup is done. + // We need to call closeConnection() to make sure that proper cleanup is done. c.closeConnection(WriteError) return nil } diff --git a/test/tls_test.go b/test/tls_test.go index ad0c91af179..0fc391cb831 100644 --- a/test/tls_test.go +++ b/test/tls_test.go @@ -82,6 +82,7 @@ func TestTLSInProcessConnection(t *testing.T) { if err != nil { t.Fatal(err) } + defer nc.Close() if nc.TLSRequired() { t.Fatalf("Shouldn't have required TLS for in-process connection")