Skip to content

Commit

Permalink
[ADDED] TLSHandshakeFirst option (#1433)
Browse files Browse the repository at this point in the history
This allows to connect to a server that is also configured to
perform the TLS handshake first, that is, before sending the INFO
protocol.

Signed-off-by: Ivan Kozlovic <[email protected]>
  • Loading branch information
kozlovic authored Oct 9, 2023
1 parent e0f193c commit 1941a1a
Show file tree
Hide file tree
Showing 2 changed files with 147 additions and 1 deletion.
39 changes: 38 additions & 1 deletion nats.go
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,13 @@ type Options struct {
// TLSCertCB is used to fetch and return custom tls certificate.
TLSCertCB TLSCertHandler

// TLSHandshakeFirst is used to instruct the library perform
// the TLS handshake right after the connect and before receiving
// the INFO protocol from the server. If this option is enabled
// but the server is not configured to perform the TLS handshake
// first, the connection will fail.
TLSHandshakeFirst bool

// RootCAsCB is used to fetch and return a set of root certificate
// authorities that clients use when verifying server certificates.
RootCAsCB RootCAsHandler
Expand Down Expand Up @@ -1315,6 +1322,17 @@ func SkipHostLookup() Option {
}
}

// TLSHandshakeFirst is an Option to perform the TLS handshake first, that is
// before receiving the INFO protocol. This requires the server to also be
// configured with such option, otherwise the connection will fail.
func TLSHandshakeFirst() Option {
return func(o *Options) error {
o.TLSHandshakeFirst = true
o.Secure = true
return nil
}
}

// Handler processing

// SetDisconnectHandler will set the disconnect event handler.
Expand Down Expand Up @@ -1481,6 +1499,12 @@ func (o Options) Connect() (*Conn, error) {
}
}

// If the TLSHandshakeFirst option is specified, make sure that
// the Secure boolean is true.
if nc.Opts.TLSHandshakeFirst {
nc.Opts.Secure = true
}

if err := nc.setupServerPool(); err != nil {
return nil, err
}
Expand Down Expand Up @@ -2235,6 +2259,14 @@ func (nc *Conn) processConnectInit() error {
// Set our status to connecting.
nc.changeConnStatus(CONNECTING)

// If we need to have a TLS connection and want the TLS handshake to occur
// first, do it now.
if nc.Opts.Secure && nc.Opts.TLSHandshakeFirst {
if err := nc.makeTLSConn(); err != nil {
return err
}
}

// Process the INFO protocol received from the server
err := nc.processExpectedInfo()
if err != nil {
Expand Down Expand Up @@ -2351,8 +2383,13 @@ func (nc *Conn) checkForSecure() error {
o.Secure = true
}

// Need to rewrap with bufio
if o.Secure {
// If TLS handshake first is true, we have already done
// the handshake, so we are done here.
if o.TLSHandshakeFirst {
return nil
}
// Need to rewrap with bufio
if err := nc.makeTLSConn(); err != nil {
return err
}
Expand Down
109 changes: 109 additions & 0 deletions test/conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2863,3 +2863,112 @@ func TestConnStatusChangedEvents(t *testing.T) {
time.Sleep(100 * time.Millisecond)
})
}

func TestTLSHandshakeFirst(t *testing.T) {
s, opts := RunServerWithConfig("./configs/tls.conf")
defer s.Shutdown()

secureURL := fmt.Sprintf("tls://derek:porkchop@localhost:%d", opts.Port)
nc, err := nats.Connect(secureURL,
nats.RootCAs("./configs/certs/ca.pem"),
nats.TLSHandshakeFirst())
if err == nil || !strings.Contains(err.Error(), "TLS handshake") {
if err == nil {
nc.Close()
}
t.Fatalf("Expected error about not being a TLS handshake, got %v", err)
}

tc := &server.TLSConfigOpts{
CertFile: "./configs/certs/server.pem",
KeyFile: "./configs/certs/key.pem",
}
tlsConf, err := server.GenTLSConfig(tc)
if err != nil {
t.Fatalf("Can't build TLCConfig: %v", err)
}
tlsConf.ServerName = "localhost"

// Start a mockup server that will do the TLS handshake first
// and then send the INFO protcol.
l, e := net.Listen("tcp", ":0")
if e != nil {
t.Fatal("Could not listen on an ephemeral port")
}
tl := l.(*net.TCPListener)
defer tl.Close()

addr := tl.Addr().(*net.TCPAddr)

errCh := make(chan error, 1)
doneCh := make(chan struct{})
wg := sync.WaitGroup{}
wg.Add(1)
go func() {
defer wg.Done()
conn, err := l.Accept()
if err != nil {
errCh <- fmt.Errorf("error accepting client connection: %v", err)
return
}
defer conn.Close()

// Do the TLS handshake now.
conn = tls.Server(conn, tlsConf)
tlsconn := conn.(*tls.Conn)
if err := tlsconn.Handshake(); err != nil {
errCh <- fmt.Errorf("Server error during handshake: %v", err)
return
}

// Send back the INFO
info := fmt.Sprintf("INFO {\"server_id\":\"foobar\",\"host\":\"localhost\",\"port\":%d,\"auth_required\":false,\"tls_required\":true,\"tls_available\":true,\"tls_verify\":true,\"max_payload\":1048576}\r\n", addr.Port)
tlsconn.Write([]byte(info))

// Read connect and ping commands sent from the client
line := make([]byte, 256)
_, err = tlsconn.Read(line)
if err != nil {
errCh <- fmt.Errorf("expected CONNECT and PING from client, got: %s", err)
return
}
tlsconn.Write([]byte("PONG\r\n"))

// Wait for the signal that client is ok
<-doneCh
// Server is done now.
errCh <- nil
}()

time.Sleep(100 * time.Millisecond)

secureURL = fmt.Sprintf("tls://derek:porkchop@localhost:%d", addr.Port)
nc, err = nats.Connect(secureURL,
nats.RootCAs("./configs/certs/ca.pem"),
nats.TLSHandshakeFirst())
if err != nil {
wg.Wait()
e := <-errCh
t.Fatalf("Unexpected error: %v (server error=%s)", err, e.Error())
}

state, err := nc.TLSConnectionState()
if err != nil {
t.Fatalf("Expected connection state: %v", err)
}
if !state.HandshakeComplete {
t.Fatalf("Expected valid connection state")
}
nc.Close()

close(doneCh)
wg.Wait()
select {
case e := <-errCh:
if e != nil {
t.Fatalf("Error from server: %v", err)
}
case <-time.After(2 * time.Second):
t.Fatal("Server did not exit")
}
}

0 comments on commit 1941a1a

Please sign in to comment.