diff --git a/internal/ocsp/ocsp.go b/internal/ocsp/ocsp.go index 8beab9d3c16..3110cc9d58a 100644 --- a/internal/ocsp/ocsp.go +++ b/internal/ocsp/ocsp.go @@ -1,4 +1,4 @@ -// Copyright 2019-2021 The NATS Authors +// Copyright 2019-2024 The NATS Authors // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at diff --git a/server/gateway.go b/server/gateway.go index 65769546892..165f797f5bd 100644 --- a/server/gateway.go +++ b/server/gateway.go @@ -816,6 +816,14 @@ func (s *Server) createGateway(cfg *gatewayCfg, url *url.URL, conn net.Conn) { tlsConfig = cfg.TLSConfig.Clone() timeout = cfg.TLSTimeout cfg.RUnlock() + + // Ensure that OCSP callbacks are always setup on gateway reconnect. + if (opts.Gateway.TLSConfig.GetClientCertificate != nil && tlsConfig.GetClientCertificate == nil) { + tlsConfig.GetClientCertificate = opts.Gateway.TLSConfig.GetClientCertificate + } + if (opts.Gateway.TLSConfig.VerifyConnection != nil && tlsConfig.VerifyConnection == nil) { + tlsConfig.VerifyConnection = opts.Gateway.TLSConfig.VerifyConnection + } } else { tlsConfig = opts.Gateway.TLSConfig timeout = opts.Gateway.TLSTimeout diff --git a/server/gateway_test.go b/server/gateway_test.go index e85f2956d1a..b5796828f19 100644 --- a/server/gateway_test.go +++ b/server/gateway_test.go @@ -6954,13 +6954,18 @@ func TestGatewayConnectEvents(t *testing.T) { checkEvents(t, "Queued", true) } -func disconnectInboundGatewaysAsStale(s *Server) { +func disconnectInboundGateways(s *Server) { s.gateway.RLock() in := s.gateway.in s.gateway.RUnlock() + + s.gateway.RLock() for _, client := range in { - client.closeConnection(StaleConnection) + s.gateway.RUnlock() + client.closeConnection(ClientClosed) + s.gateway.RLock() } + s.gateway.RUnlock() } type testMissingOCSPStapleLogger struct { @@ -6978,7 +6983,7 @@ func (l *testMissingOCSPStapleLogger) Errorf(format string, v ...interface{}) { } } -func TestOCSPGatewayMissingPeerStaple(t *testing.T) { +func TestOCSPGatewayMissingPeerStapleIssue(t *testing.T) { const ( caCert = "../test/configs/certs/ocsp/ca-cert.pem" caKey = "../test/configs/certs/ocsp/ca-key.pem" @@ -7072,6 +7077,7 @@ func TestOCSPGatewayMissingPeerStaple(t *testing.T) { name: "A" url: "nats://127.0.0.1:%d" }] + tls { cert_file: "../test/configs/certs/ocsp/server-status-request-url-04-cert.pem" key_file: "../test/configs/certs/ocsp/server-status-request-url-04-key.pem" @@ -7136,7 +7142,7 @@ func TestOCSPGatewayMissingPeerStaple(t *testing.T) { //////////////////////////////////////////////////////////////////////////// // // - // A and B are connected at this point and C is starting with certs that // + // A and B are connected at this point and A is starting with certs that // // will be rotated. // // //////////////////////////////////////////////////////////////////////////// @@ -7233,7 +7239,7 @@ func TestOCSPGatewayMissingPeerStaple(t *testing.T) { } // Reload and disconnect very fast trying to produce the race. - ctx, cancel = context.WithTimeout(context.Background(), 30*time.Second) + ctx, cancel = context.WithTimeout(context.Background(), 15*time.Second) defer cancel() // Swap logger from server to capture the missing peer log. @@ -7244,9 +7250,18 @@ func TestOCSPGatewayMissingPeerStaple(t *testing.T) { srvB.SetLogger(lB, false, false) lC := &testMissingOCSPStapleLogger{ch: make(chan string, 30)} - srvB.SetLogger(lC, false, false) + srvC.SetLogger(lC, false, false) + + // Start with a reload from the last server that connected directly to A. + if err := srvC.Reload(); err != nil { + t.Fatal(err) + } + // Stress reconnections and reloading servers without getting + // missing OCSP peer staple errors. var wg sync.WaitGroup + + wg.Add(1) go func() { for range time.NewTicker(500 * time.Millisecond).C { select { @@ -7255,11 +7270,11 @@ func TestOCSPGatewayMissingPeerStaple(t *testing.T) { return default: } - disconnectInboundGatewaysAsStale(srvC) + disconnectInboundGateways(srvA) } }() - wg.Add(1) + wg.Add(1) go func() { for range time.NewTicker(500 * time.Millisecond).C { select { @@ -7268,13 +7283,26 @@ func TestOCSPGatewayMissingPeerStaple(t *testing.T) { return default: } - disconnectInboundGatewaysAsStale(srvA) + disconnectInboundGateways(srvB) } }() + wg.Add(1) + go func() { + for range time.NewTicker(500 * time.Millisecond).C { + select { + case <-ctx.Done(): + wg.Done() + return + default: + } + disconnectInboundGateways(srvC) + } + }() + wg.Add(1) go func() { - for range time.NewTicker(1 * time.Second).C { + for range time.NewTicker(700 * time.Millisecond).C { select { case <-ctx.Done(): wg.Done() @@ -7286,16 +7314,45 @@ func TestOCSPGatewayMissingPeerStaple(t *testing.T) { } } }() + + wg.Add(1) + go func() { + for range time.NewTicker(800 * time.Millisecond).C { + select { + case <-ctx.Done(): + wg.Done() + return + default: + } + if err := srvB.Reload(); err != nil { + t.Fatal(err) + } + } + }() + wg.Add(1) + go func() { + for range time.NewTicker(900 * time.Millisecond).C { + select { + case <-ctx.Done(): + wg.Done() + return + default: + } + if err := srvA.Reload(); err != nil { + t.Fatal(err) + } + } + }() select { case <-ctx.Done(): case msg := <-lA.ch: - t.Fatalf("Got OCSP Staple error: %v", msg) + t.Fatalf("Server A: Got OCSP Staple error: %v", msg) case msg := <-lB.ch: - t.Fatalf("Got OCSP Staple error: %v", msg) + t.Fatalf("Server B: Got OCSP Staple error: %v", msg) case msg := <-lC.ch: - t.Fatalf("Got OCSP Staple error: %v", msg) + t.Fatalf("Server C: Got OCSP Staple error: %v", msg) } wg.Wait() }