Skip to content

Commit

Permalink
ocsp: Ensure callbacks are set when gateways reconnect
Browse files Browse the repository at this point in the history
Signed-off-by: Waldemar Quevedo <[email protected]>
  • Loading branch information
wallyqs committed Mar 22, 2024
1 parent c814866 commit dcfce2e
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 14 deletions.
2 changes: 1 addition & 1 deletion internal/ocsp/ocsp.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright 2019-2021 The NATS Authors
// Copyright 2019-2024 The NATS Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
Expand Down
8 changes: 8 additions & 0 deletions server/gateway.go
Original file line number Diff line number Diff line change
Expand Up @@ -816,6 +816,14 @@ func (s *Server) createGateway(cfg *gatewayCfg, url *url.URL, conn net.Conn) {
tlsConfig = cfg.TLSConfig.Clone()
timeout = cfg.TLSTimeout
cfg.RUnlock()

// Ensure that OCSP callbacks are always setup on gateway reconnect.
if (opts.Gateway.TLSConfig.GetClientCertificate != nil && tlsConfig.GetClientCertificate == nil) {
tlsConfig.GetClientCertificate = opts.Gateway.TLSConfig.GetClientCertificate
}
if (opts.Gateway.TLSConfig.VerifyConnection != nil && tlsConfig.VerifyConnection == nil) {
tlsConfig.VerifyConnection = opts.Gateway.TLSConfig.VerifyConnection
}
} else {
tlsConfig = opts.Gateway.TLSConfig
timeout = opts.Gateway.TLSTimeout
Expand Down
83 changes: 70 additions & 13 deletions server/gateway_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6954,13 +6954,18 @@ func TestGatewayConnectEvents(t *testing.T) {
checkEvents(t, "Queued", true)
}

func disconnectInboundGatewaysAsStale(s *Server) {
func disconnectInboundGateways(s *Server) {
s.gateway.RLock()
in := s.gateway.in
s.gateway.RUnlock()

s.gateway.RLock()
for _, client := range in {
client.closeConnection(StaleConnection)
s.gateway.RUnlock()
client.closeConnection(ClientClosed)
s.gateway.RLock()
}
s.gateway.RUnlock()
}

type testMissingOCSPStapleLogger struct {
Expand All @@ -6978,7 +6983,7 @@ func (l *testMissingOCSPStapleLogger) Errorf(format string, v ...interface{}) {
}
}

func TestOCSPGatewayMissingPeerStaple(t *testing.T) {
func TestOCSPGatewayMissingPeerStapleIssue(t *testing.T) {
const (
caCert = "../test/configs/certs/ocsp/ca-cert.pem"
caKey = "../test/configs/certs/ocsp/ca-key.pem"
Expand Down Expand Up @@ -7072,6 +7077,7 @@ func TestOCSPGatewayMissingPeerStaple(t *testing.T) {
name: "A"
url: "nats://127.0.0.1:%d"
}]
tls {
cert_file: "../test/configs/certs/ocsp/server-status-request-url-04-cert.pem"
key_file: "../test/configs/certs/ocsp/server-status-request-url-04-key.pem"
Expand Down Expand Up @@ -7136,7 +7142,7 @@ func TestOCSPGatewayMissingPeerStaple(t *testing.T) {

////////////////////////////////////////////////////////////////////////////
// //
// A and B are connected at this point and C is starting with certs that //
// A and B are connected at this point and A is starting with certs that //
// will be rotated.
// //
////////////////////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -7233,7 +7239,7 @@ func TestOCSPGatewayMissingPeerStaple(t *testing.T) {
}

// Reload and disconnect very fast trying to produce the race.
ctx, cancel = context.WithTimeout(context.Background(), 30*time.Second)
ctx, cancel = context.WithTimeout(context.Background(), 15*time.Second)
defer cancel()

// Swap logger from server to capture the missing peer log.
Expand All @@ -7244,9 +7250,18 @@ func TestOCSPGatewayMissingPeerStaple(t *testing.T) {
srvB.SetLogger(lB, false, false)

lC := &testMissingOCSPStapleLogger{ch: make(chan string, 30)}
srvB.SetLogger(lC, false, false)
srvC.SetLogger(lC, false, false)

// Start with a reload from the last server that connected directly to A.
if err := srvC.Reload(); err != nil {
t.Fatal(err)
}

// Stress reconnections and reloading servers without getting
// missing OCSP peer staple errors.
var wg sync.WaitGroup

wg.Add(1)
go func() {
for range time.NewTicker(500 * time.Millisecond).C {
select {
Expand All @@ -7255,11 +7270,11 @@ func TestOCSPGatewayMissingPeerStaple(t *testing.T) {
return
default:
}
disconnectInboundGatewaysAsStale(srvC)
disconnectInboundGateways(srvA)
}
}()
wg.Add(1)

wg.Add(1)
go func() {
for range time.NewTicker(500 * time.Millisecond).C {
select {
Expand All @@ -7268,13 +7283,26 @@ func TestOCSPGatewayMissingPeerStaple(t *testing.T) {
return
default:
}
disconnectInboundGatewaysAsStale(srvA)
disconnectInboundGateways(srvB)
}
}()

wg.Add(1)
go func() {
for range time.NewTicker(500 * time.Millisecond).C {
select {
case <-ctx.Done():
wg.Done()
return
default:
}
disconnectInboundGateways(srvC)
}
}()

wg.Add(1)
go func() {
for range time.NewTicker(1 * time.Second).C {
for range time.NewTicker(700 * time.Millisecond).C {
select {
case <-ctx.Done():
wg.Done()
Expand All @@ -7286,16 +7314,45 @@ func TestOCSPGatewayMissingPeerStaple(t *testing.T) {
}
}
}()

wg.Add(1)
go func() {
for range time.NewTicker(800 * time.Millisecond).C {
select {
case <-ctx.Done():
wg.Done()
return
default:
}
if err := srvB.Reload(); err != nil {
t.Fatal(err)
}
}
}()

wg.Add(1)
go func() {
for range time.NewTicker(900 * time.Millisecond).C {
select {
case <-ctx.Done():
wg.Done()
return
default:
}
if err := srvA.Reload(); err != nil {
t.Fatal(err)
}
}
}()

select {
case <-ctx.Done():
case msg := <-lA.ch:
t.Fatalf("Got OCSP Staple error: %v", msg)
t.Fatalf("Server A: Got OCSP Staple error: %v", msg)
case msg := <-lB.ch:
t.Fatalf("Got OCSP Staple error: %v", msg)
t.Fatalf("Server B: Got OCSP Staple error: %v", msg)
case msg := <-lC.ch:
t.Fatalf("Got OCSP Staple error: %v", msg)
t.Fatalf("Server C: Got OCSP Staple error: %v", msg)
}
wg.Wait()
}

0 comments on commit dcfce2e

Please sign in to comment.