diff --git a/server/filestore.go b/server/filestore.go
index c6372ed5278..c8f8b0271d6 100644
--- a/server/filestore.go
+++ b/server/filestore.go
@@ -1748,134 +1748,33 @@ func (fs *fileStore) recoverFullState() (rerr error) {
 			return errPriorState
 		}
 		if matched = bytes.Equal(mb.lastChecksum(), lchk[:]); !matched {
-			// If we are tracking max msgs per subject and we are not up to date we should rebuild.
-			if fs.cfg.MaxMsgsPer > 0 {
-				fs.warn("Stream state block state outdated, will rebuild")
-				return errPriorState
-			}
-
-			// Remove the last message block since recover will add in the new one.
-			fs.removeMsgBlockFromList(mb)
-			// Reverse update of tracking state for this mb, will add new state in below.
-			mstate.Msgs -= mb.msgs
-			mstate.Bytes -= mb.bytes
-			if nmb, err := fs.recoverMsgBlock(mb.index); err != nil && !os.IsNotExist(err) {
-				fs.warn("Stream state could not recover last msg block")
-				os.Remove(fn)
-				return errCorruptState
-			} else if nmb != nil {
-				fs.adjustAccounting(mb, nmb)
-				updateTrackingState(&mstate, nmb)
-			}
-		}
-
-		// On success double check our state.
-		checkState := func() error {
-			// We check first and last seq and number of msgs and bytes. If there is a difference,
-			// return and error so we rebuild from the message block state on disk.
-			if !trackingStatesEqual(&fs.state, &mstate) {
-				fs.warn("Stream state encountered internal inconsistency on recover")
-				os.Remove(fn)
-				return errCorruptState
-			}
-			return nil
-		}
-
-		// We may need to check other blocks. Even if we matched last checksum we will see if there is another block.
-		for bi := blkIndex + 1; ; bi++ {
-			nmb, err := fs.recoverMsgBlock(bi)
-			if err != nil {
-				if os.IsNotExist(err) {
-					return checkState()
-				}
-				os.Remove(fn)
-				fs.warn("Stream state could not recover msg block %d", bi)
-				return err
-			}
-			if nmb != nil {
-				// Update top level accounting
-				if fseq := atomic.LoadUint64(&nmb.first.seq); fs.state.FirstSeq == 0 || fseq < fs.state.FirstSeq {
-					fs.state.FirstSeq = fseq
-					if nmb.first.ts == 0 {
-						fs.state.FirstTime = time.Time{}
-					} else {
-						fs.state.FirstTime = time.Unix(0, nmb.first.ts).UTC()
-					}
-				}
-				if lseq := atomic.LoadUint64(&nmb.last.seq); lseq > fs.state.LastSeq {
-					fs.state.LastSeq = lseq
-					if mb.last.ts == 0 {
-						fs.state.LastTime = time.Time{}
-					} else {
-						fs.state.LastTime = time.Unix(0, nmb.last.ts).UTC()
-					}
-				}
-				fs.state.Msgs += nmb.msgs
-				fs.state.Bytes += nmb.bytes
-				updateTrackingState(&mstate, nmb)
-			}
+			// Detected a stale index.db: we didn't write it on shutdown, so we can't rely on it being correct.
+			fs.warn("Stream state outdated, last block has additional entries, will rebuild")
+			return errPriorState
 		}
-}
 
-// adjustAccounting will be called when a stream state was only partially accounted for
-// within a message block, e.g. additional records were added after the stream state.
-// Lock should be held.
-func (fs *fileStore) adjustAccounting(mb, nmb *msgBlock) {
-	nmb.mu.Lock()
-	defer nmb.mu.Unlock()
+	// We need to see if any blocks exist after our last one even though we matched the last record exactly.
+	mdir := filepath.Join(fs.fcfg.StoreDir, msgDir)
+	var dirs []os.DirEntry
 
-	// First make sure the new block is loaded.
-	if nmb.cacheNotLoaded() {
-		nmb.loadMsgsWithLock()
+	<-dios
+	if f, err := os.Open(mdir); err == nil {
+		dirs, _ = f.ReadDir(-1)
+		f.Close()
 	}
-	nmb.ensurePerSubjectInfoLoaded()
-
-	var smv StoreMsg
+	dios <- struct{}{}
 
-	// Need to walk previous messages and undo psim stats.
-	// We already undid msgs and bytes accounting.
-	for seq, lseq := atomic.LoadUint64(&mb.first.seq), atomic.LoadUint64(&mb.last.seq); seq <= lseq; seq++ {
-		// Lookup the message. If an error will be deleted, so can skip.
-		sm, err := nmb.cacheLookup(seq, &smv)
-		if err != nil {
-			continue
-		}
-		if len(sm.subj) > 0 && fs.psim != nil {
-			if info, ok := fs.psim.Find(stringToBytes(sm.subj)); ok {
-				info.total--
+	var index uint32
+	for _, fi := range dirs {
+		if n, err := fmt.Sscanf(fi.Name(), blkScan, &index); err == nil && n == 1 {
+			if index > blkIndex {
+				fs.warn("Stream state outdated, found extra blocks, will rebuild")
+				return errPriorState
+			}
 			}
 		}
 	}
 
-	// Walk only new messages and update accounting at fs level. Any messages that should have
-	// triggered limits exceeded will be handled after the recovery and prior to the stream
-	// being available to the system.
-	for seq, lseq := atomic.LoadUint64(&mb.last.seq)+1, atomic.LoadUint64(&nmb.last.seq); seq <= lseq; seq++ {
-		// Lookup the message. If an error will be deleted, so can skip.
-		sm, err := nmb.cacheLookup(seq, &smv)
-		if err != nil {
-			continue
-		}
-		// Since we found it we just need to adjust fs totals and psim.
-		fs.state.Msgs++
-		fs.state.Bytes += fileStoreMsgSize(sm.subj, sm.hdr, sm.msg)
-	}
-
-	// Now check to see if we had a higher first for the recovered state mb vs nmb.
-	if atomic.LoadUint64(&nmb.first.seq) < atomic.LoadUint64(&mb.first.seq) {
-		// Now set first for nmb.
-		atomic.StoreUint64(&nmb.first.seq, atomic.LoadUint64(&mb.first.seq))
-	}
-
-	// Update top level accounting.
-	if fseq := atomic.LoadUint64(&nmb.first.seq); fs.state.FirstSeq == 0 || fseq < fs.state.FirstSeq {
-		fs.state.FirstSeq = fseq
-		fs.state.FirstTime = time.Unix(0, nmb.first.ts).UTC()
-	}
-	if lseq := atomic.LoadUint64(&nmb.last.seq); lseq > fs.state.LastSeq {
-		fs.state.LastSeq = lseq
-		fs.state.LastTime = time.Unix(0, nmb.last.ts).UTC()
-	}
+	return nil
 }
 
 // Grabs last checksum for the named block file.
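The replacement logic above reduces recovery to two checks: the last tracked block's checksum must still match what index.db recorded, and no block files may exist past the last tracked block index. Below is a minimal standalone sketch of that second check; it assumes block files are named `<index>.blk` (what the `blkScan` format string expands to), and it elides the `dios` channel send/receive, which in the real code just throttles concurrent disk I/O.

```go
package main

import (
	"fmt"
	"os"
)

// findStrayBlocks reports whether mdir contains a message block file whose
// index is greater than lastIdx. Naming assumption: blocks are "<index>.blk".
func findStrayBlocks(mdir string, lastIdx uint32) (bool, error) {
	f, err := os.Open(mdir)
	if err != nil {
		return false, err
	}
	defer f.Close()

	entries, err := f.ReadDir(-1)
	if err != nil {
		return false, err
	}
	var index uint32
	for _, fi := range entries {
		// Sscanf reports how many items matched; a well-formed block name yields 1.
		if n, err := fmt.Sscanf(fi.Name(), "%d.blk", &index); err == nil && n == 1 {
			if index > lastIdx {
				// A block exists past the recorded last block: index.db is stale.
				return true, nil
			}
		}
	}
	return false, nil
}
```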
diff --git a/server/filestore_test.go b/server/filestore_test.go
index cb5737cec2c..6fbffe6bd0e 100644
--- a/server/filestore_test.go
+++ b/server/filestore_test.go
@@ -7540,6 +7540,72 @@ func TestFileStoreDmapBlockRecoverAfterCompact(t *testing.T) {
 	require_Equal(t, dmap.Size(), 4)
 }
 
+func TestFileStoreRestoreIndexWithMatchButLeftOverBlocks(t *testing.T) {
+	sd := t.TempDir()
+	fs, err := newFileStore(
+		FileStoreConfig{StoreDir: sd, BlockSize: 256},
+		StreamConfig{Name: "zzz", Subjects: []string{"foo.*"}, Storage: FileStorage, MaxMsgsPer: 1})
+	require_NoError(t, err)
+	defer fs.Stop()
+
+	msg := []byte("hello")
+
+	// 6 msgs per block.
+	// Fill the first 2 blocks.
+	for i := 1; i <= 12; i++ {
+		fs.StoreMsg(fmt.Sprintf("foo.%d", i), nil, msg)
+	}
+	require_Equal(t, fs.numMsgBlocks(), 2)
+
+	// We will now stop, which will create an index.db file that matches the last record exactly.
+	sfile := filepath.Join(sd, msgDir, streamStreamStateFile)
+	fs.Stop()
+
+	// Grab it since we will put it back.
+	buf, err := os.ReadFile(sfile)
+	require_NoError(t, err)
+	require_True(t, len(buf) > 0)
+
+	// Now fill an additional block. With MaxMsgsPer set this will remove the first block
+	// but leave the second, so on recovery we will match the checksum for the last msg in the second block.
+
+	fs, err = newFileStore(
+		FileStoreConfig{StoreDir: sd, BlockSize: 256},
+		StreamConfig{Name: "zzz", Subjects: []string{"foo.*"}, Storage: FileStorage, MaxMsgsPer: 1})
+	require_NoError(t, err)
+	defer fs.Stop()
+
+	for i := 1; i <= 6; i++ {
+		fs.StoreMsg(fmt.Sprintf("foo.%d", i), nil, msg)
+	}
+
+	// Grab the correct state; we will use it to make sure we do the right thing.
+	var state StreamState
+	fs.FastState(&state)
+
+	require_Equal(t, state.Msgs, 12)
+	require_Equal(t, state.FirstSeq, 7)
+	require_Equal(t, state.LastSeq, 18)
+	// This will be block 2 and 3.
+	require_Equal(t, fs.numMsgBlocks(), 2)
+
+	fs.Stop()
+	// Put the old stream state back.
+	require_NoError(t, os.WriteFile(sfile, buf, defaultFilePerms))
+
+	fs, err = newFileStore(
+		FileStoreConfig{StoreDir: sd, BlockSize: 256},
+		StreamConfig{Name: "zzz", Subjects: []string{"foo.*"}, Storage: FileStorage, MaxMsgsPer: 1})
+	require_NoError(t, err)
+	defer fs.Stop()
+
+	fs.FastState(&state)
+	require_Equal(t, state.Msgs, 12)
+	require_Equal(t, state.FirstSeq, 7)
+	require_Equal(t, state.LastSeq, 18)
+}
+
 ///////////////////////////////////////////////////////////////////////////
 // Benchmarks
 ///////////////////////////////////////////////////////////////////////////
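The "6 msgs per block" comment in the test above follows from the record framing. Assuming each stored record is a fixed header plus subject, headers, payload, and a trailing checksum (the 22-byte header and 8-byte checksum below are assumptions about the filestore's record layout, not taken from this diff), a 5-byte subject and 5-byte payload give a 40-byte record, and six of those fit in a 256-byte block:

```go
package main

import "fmt"

func main() {
	// Assumed framing: fixed record header + subject + headers + payload + checksum.
	const recordHeader = 22 // assumption: fixed per-record header size
	const checksum = 8      // assumption: trailing per-record checksum size

	subj, payload := len("foo.1"), len("hello")
	recSize := recordHeader + subj + payload + checksum // 40 bytes

	const blockSize = 256
	fmt.Println(blockSize / recSize) // 6 records per 256-byte block
}
```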
diff --git a/server/jetstream.go b/server/jetstream.go
index 4ace6731ce2..4c56b79775f 100644
--- a/server/jetstream.go
+++ b/server/jetstream.go
@@ -105,6 +105,7 @@ type jetStream struct {
 	storeReserved int64
 	memUsed       int64
 	storeUsed     int64
+	queueLimit    int64
 	clustered     int32
 	mu            sync.RWMutex
 	srv           *Server
@@ -377,6 +378,9 @@ func (s *Server) enableJetStream(cfg JetStreamConfig) error {
 	}
 	s.gcbMu.Unlock()
 
+	// TODO: Not currently reloadable.
+	atomic.StoreInt64(&js.queueLimit, s.getOpts().JetStreamRequestQueueLimit)
+
 	s.js.Store(js)
 
 	// FIXME(dlc) - Allow memory only operation?
diff --git a/server/jetstream_api.go b/server/jetstream_api.go
index 479babf81ca..6e7d82f9fc0 100644
--- a/server/jetstream_api.go
+++ b/server/jetstream_api.go
@@ -299,6 +299,9 @@ const (
 	// JSAdvisoryServerRemoved notification that a server has been removed from the system.
 	JSAdvisoryServerRemoved = "$JS.EVENT.ADVISORY.SERVER.REMOVED"
 
+	// JSAdvisoryAPILimitReached notification that a server has reached the JS API hard limit.
+	JSAdvisoryAPILimitReached = "$JS.EVENT.ADVISORY.API.LIMIT_REACHED"
+
 	// JSAuditAdvisory is a notification about JetStream API access.
 	// FIXME - Add in details about who..
 	JSAuditAdvisory = "$JS.EVENT.ADVISORY.API"
@@ -346,6 +349,10 @@ const JSMaxMetadataLen = 128 * 1024
 // Picked 255 as it seems to be a widely used file name limit
 const JSMaxNameLen = 255
 
+// JSDefaultRequestQueueLimit is the default number of entries that we will
+// allow on the global request queue before we drop it and publish an advisory.
+const JSDefaultRequestQueueLimit = 10_000
+
 // Responses for API calls.
 
 // ApiResponse is a standard response from the JetStream JSON API
@@ -825,10 +832,22 @@ func (js *jetStream) apiDispatch(sub *subscription, c *client, acc *Account, sub
 	// Copy the state. Note the JSAPI only uses the hdr index to piece apart the
 	// header from the msg body. No other references are needed.
 	// Check pending and warn if getting backed up.
-	const warnThresh = 128
 	pending := s.jsAPIRoutedReqs.push(&jsAPIRoutedReq{jsub, sub, acc, subject, reply, copyBytes(rmsg), c.pa})
-	if pending >= warnThresh {
-		s.rateLimitFormatWarnf("JetStream request queue has high pending count: %d", pending)
+	limit := atomic.LoadInt64(&js.queueLimit)
+	if pending >= int(limit) {
+		s.rateLimitFormatWarnf("JetStream API queue limit reached, dropping %d requests", pending)
+		s.jsAPIRoutedReqs.drain()
+
+		s.publishAdvisory(nil, JSAdvisoryAPILimitReached, JSAPILimitReachedAdvisory{
+			TypedEvent: TypedEvent{
+				Type: JSAPILimitReachedAdvisoryType,
+				ID:   nuid.Next(),
+				Time: time.Now().UTC(),
+			},
+			Server:  s.Name(),
+			Domain:  js.config.Domain,
+			Dropped: int64(pending),
+		})
 	}
 }
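The dispatch change swaps a fixed warn threshold for a hard cap, configurable as `request_queue_limit` inside the `jetstream` config block (see the cluster test below) and defaulting to `JSDefaultRequestQueueLimit`. Once pending reaches the cap, the whole queue is dropped and an advisory goes out, rather than letting a runaway workload grow the queue without bound. A self-contained sketch of that pattern, with a hypothetical queue type standing in for the server's internal one:

```go
package main

import (
	"fmt"
	"sync"
	"sync/atomic"
)

// queue is a hypothetical stand-in for the server's internal request queue.
type queue struct {
	mu    sync.Mutex
	items []string
}

// push appends an item and returns the pending count after the append.
func (q *queue) push(it string) int {
	q.mu.Lock()
	defer q.mu.Unlock()
	q.items = append(q.items, it)
	return len(q.items)
}

// drain discards everything and returns how many items were dropped.
func (q *queue) drain() int {
	q.mu.Lock()
	defer q.mu.Unlock()
	n := len(q.items)
	q.items = nil
	return n
}

func main() {
	var limit atomic.Int64
	limit.Store(3) // stands in for js.queueLimit

	q := &queue{}
	for i := 0; i < 5; i++ {
		if pending := q.push(fmt.Sprintf("req-%d", i)); pending >= int(limit.Load()) {
			// The real code also publishes a JSAPILimitReachedAdvisory here.
			fmt.Printf("limit reached, dropping %d requests\n", q.drain())
		}
	}
}
```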
diff --git a/server/jetstream_cluster_4_test.go b/server/jetstream_cluster_4_test.go
index 592a0e0148f..fd95ce95cde 100644
--- a/server/jetstream_cluster_4_test.go
+++ b/server/jetstream_cluster_4_test.go
@@ -18,6 +18,7 @@ package server
 
 import (
 	"context"
+	"encoding/json"
 	"errors"
 	"fmt"
 	"math/rand"
@@ -2055,7 +2056,7 @@ func TestJetStreamClusterAndNamesWithSpaces(t *testing.T) {
 	}
 
 	accounts {
-		sys { 
+		sys {
 			users = [ { user: sys, pass: sys } ]
 		}
 		js {
 			jetstream: enabled
@@ -3445,3 +3446,74 @@ func TestJetStreamClusterAckDeleted(t *testing.T) {
 		)
 	}
 }
+
+func TestJetStreamClusterAPILimitDefault(t *testing.T) {
+	c := createJetStreamClusterExplicit(t, "R3S", 3)
+	defer c.shutdown()
+
+	for _, s := range c.servers {
+		s.optsMu.RLock()
+		lim := s.opts.JetStreamRequestQueueLimit
+		s.optsMu.RUnlock()
+
+		require_Equal(t, lim, JSDefaultRequestQueueLimit)
+		require_Equal(t, atomic.LoadInt64(&s.getJetStream().queueLimit), JSDefaultRequestQueueLimit)
+	}
+}
+
+func TestJetStreamClusterAPILimitAdvisory(t *testing.T) {
+	// Hit the limit straight away.
+	const queueLimit = 1
+
+	config := `
+	listen: 127.0.0.1:-1
+	server_name: %s
+	jetstream: {
+		max_mem_store: 256MB
+		max_file_store: 2GB
+		store_dir: '%s'
+		request_queue_limit: ` + fmt.Sprintf("%d", queueLimit) + `
+	}
+	cluster {
+		name: %s
+		listen: 127.0.0.1:%d
+		routes = [%s]
+	}
+	accounts { $SYS { users = [ { user: "admin", pass: "s3cr3t!" } ] } }
+	`
+	c := createJetStreamClusterWithTemplate(t, config, "R3S", 3)
+	defer c.shutdown()
+
+	c.waitOnLeader()
+	s := c.randomNonLeader()
+
+	for _, s := range c.servers {
+		lim := atomic.LoadInt64(&s.getJetStream().queueLimit)
+		require_Equal(t, lim, queueLimit)
+	}
+
+	nc, _ := jsClientConnect(t, s)
+	defer nc.Close()
+
+	snc, _ := jsClientConnect(t, c.randomServer(), nats.UserInfo("admin", "s3cr3t!"))
+	defer snc.Close()
+
+	sub, err := snc.SubscribeSync(JSAdvisoryAPILimitReached)
+	require_NoError(t, err)
+
+	ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
+	defer cancel()
+
+	require_NoError(t, nc.PublishMsg(&nats.Msg{
+		Subject: fmt.Sprintf(JSApiConsumerListT, "TEST"),
+		Reply:   nc.NewInbox(),
+	}))
+
+	// Wait for the advisory to come in.
+	msg, err := sub.NextMsgWithContext(ctx)
+	require_NoError(t, err)
+	var advisory JSAPILimitReachedAdvisory
+	require_NoError(t, json.Unmarshal(msg.Data, &advisory))
+	require_Equal(t, advisory.Domain, _EMPTY_)     // No JetStream domain was set.
+	require_Equal(t, advisory.Dropped, queueLimit) // Configured queue limit.
+}
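Outside of tests, the advisory can be watched with a plain subscription on a system-account connection, much like the cluster test above does; a rough sketch (URL and credentials are placeholders):

```go
package main

import (
	"encoding/json"
	"log"

	"github.com/nats-io/nats.go"
)

func main() {
	// Placeholder URL and credentials; this needs a system-account user.
	nc, err := nats.Connect("nats://127.0.0.1:4222", nats.UserInfo("admin", "s3cr3t!"))
	if err != nil {
		log.Fatal(err)
	}
	defer nc.Drain()

	_, err = nc.Subscribe("$JS.EVENT.ADVISORY.API.LIMIT_REACHED", func(m *nats.Msg) {
		var adv struct {
			Server  string `json:"server"`
			Domain  string `json:"domain,omitempty"`
			Dropped int64  `json:"dropped"`
		}
		if err := json.Unmarshal(m.Data, &adv); err == nil {
			log.Printf("JS API queue limit hit on %q: dropped %d requests", adv.Server, adv.Dropped)
		}
	})
	if err != nil {
		log.Fatal(err)
	}
	select {} // block forever
}
```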
diff --git a/server/jetstream_events.go b/server/jetstream_events.go
index 1852811bb96..8302fcc4048 100644
--- a/server/jetstream_events.go
+++ b/server/jetstream_events.go
@@ -283,3 +283,14 @@ type JSServerRemovedAdvisory struct {
 	Cluster string `json:"cluster"`
 	Domain  string `json:"domain,omitempty"`
 }
+
+// JSAPILimitReachedAdvisoryType is sent when the JS API request queue limit is reached.
+const JSAPILimitReachedAdvisoryType = "io.nats.jetstream.advisory.v1.api_limit_reached"
+
+// JSAPILimitReachedAdvisory is an advisory published when JetStream hits the request queue length limit.
+type JSAPILimitReachedAdvisory struct {
+	TypedEvent
+	Server  string `json:"server"`           // Server that created the event, name or ID
+	Domain  string `json:"domain,omitempty"` // Domain the server belongs to
+	Dropped int64  `json:"dropped"`          // How many messages we dropped from the queue
+}
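Given the tags above, the published advisory should serialize roughly as follows. The sketch mirrors the types locally so it compiles on its own; the `TypedEvent` field tags are assumptions based on the other advisories in this file, and all values are illustrative:

```go
package main

import (
	"encoding/json"
	"fmt"
	"time"
)

// Local mirrors of the server types, declared so the sketch is self-contained.
type TypedEvent struct {
	Type string    `json:"type"`
	ID   string    `json:"id"`
	Time time.Time `json:"timestamp"`
}

type JSAPILimitReachedAdvisory struct {
	TypedEvent
	Server  string `json:"server"`
	Domain  string `json:"domain,omitempty"`
	Dropped int64  `json:"dropped"`
}

func main() {
	adv := JSAPILimitReachedAdvisory{
		TypedEvent: TypedEvent{
			Type: "io.nats.jetstream.advisory.v1.api_limit_reached",
			ID:   "abc123", // the real code uses nuid.Next()
			Time: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC),
		},
		Server:  "n1",
		Dropped: 10_000,
	}
	b, _ := json.MarshalIndent(adv, "", "  ")
	fmt.Println(string(b)) // shows the wire shape, e.g. "dropped": 10000
}
```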
diff --git a/server/norace_test.go b/server/norace_test.go
index bc0c604e354..f1a8fa51688 100644
--- a/server/norace_test.go
+++ b/server/norace_test.go
@@ -10849,67 +10849,111 @@ func TestNoRaceJetStreamStandaloneDontReplyToAckBeforeProcessingIt(t *testing.T)
 	}
 }
 
-// Under certain scenarios an old index.db with a stream that has max msgs per set will not restore properly
-// due to and old index.db and compaction after the index.db took place which could lose per subject information.
-func TestNoRaceFileStoreMaxMsgsPerSubjectAndOldRecoverState(t *testing.T) {
-	sd := t.TempDir()
-	fs, err := newFileStore(
-		FileStoreConfig{StoreDir: sd},
-		StreamConfig{Name: "zzz", Subjects: []string{"foo.*"}, Storage: FileStorage, MaxMsgsPer: 1})
-	require_NoError(t, err)
-	defer fs.Stop()
+// Under certain scenarios an old index.db with a stream that has msg limits set will not restore properly
+// due to an old index.db and compaction that took place after it was written, which could lose per subject information.
+func TestNoRaceFileStoreMsgLimitsAndOldRecoverState(t *testing.T) {
+	for _, test := range []struct {
+		name             string
+		expectedFirstSeq uint64
+		expectedLastSeq  uint64
+		expectedMsgs     uint64
+		transform        func(StreamConfig) StreamConfig
+	}{
+		{
+			name:             "MaxMsgsPer",
+			expectedFirstSeq: 10_001,
+			expectedLastSeq:  1_010_001,
+			expectedMsgs:     1_000_001,
+			transform: func(config StreamConfig) StreamConfig {
+				config.MaxMsgsPer = 1
+				return config
+			},
+		},
+		{
+			name:             "MaxMsgs",
+			expectedFirstSeq: 10_001,
+			expectedLastSeq:  1_010_001,
+			expectedMsgs:     1_000_001,
+			transform: func(config StreamConfig) StreamConfig {
+				config.MaxMsgs = 1_000_001
+				return config
+			},
+		},
+		{
+			name:             "MaxBytes",
+			expectedFirstSeq: 8_624,
+			expectedLastSeq:  1_010_001,
+			expectedMsgs:     1_001_378,
+			transform: func(config StreamConfig) StreamConfig {
+				config.MaxBytes = 1_065_353_216
+				return config
+			},
+		},
+	} {
+		t.Run(test.name, func(t *testing.T) {
+			sd := t.TempDir()
+			fs, err := newFileStore(
+				FileStoreConfig{StoreDir: sd},
+				test.transform(StreamConfig{Name: "zzz", Subjects: []string{"foo.*"}, Storage: FileStorage}),
+			)
+			require_NoError(t, err)
+			defer fs.Stop()
 
-	msg := make([]byte, 1024)
+			msg := make([]byte, 1024)
 
-	for i := 0; i < 10_000; i++ {
-		subj := fmt.Sprintf("foo.%d", i)
-		fs.StoreMsg(subj, nil, msg)
-	}
+			for i := 0; i < 10_000; i++ {
+				subj := fmt.Sprintf("foo.%d", i)
+				fs.StoreMsg(subj, nil, msg)
+			}
 
-	// This will write the index.db file. We will capture this and use it to replace a new one.
-	sfile := filepath.Join(fs.fcfg.StoreDir, msgDir, streamStreamStateFile)
-	fs.Stop()
-	_, err = os.Stat(sfile)
-	require_NoError(t, err)
+			// This will write the index.db file. We will capture it and later use it to replace the newer one.
+			sfile := filepath.Join(fs.fcfg.StoreDir, msgDir, streamStreamStateFile)
+			fs.Stop()
+			_, err = os.Stat(sfile)
+			require_NoError(t, err)
 
-	// Read it in and make sure len > 0.
-	buf, err := os.ReadFile(sfile)
-	require_NoError(t, err)
-	require_True(t, len(buf) > 0)
+			// Read it in and make sure len > 0.
+			buf, err := os.ReadFile(sfile)
+			require_NoError(t, err)
+			require_True(t, len(buf) > 0)
 
-	// Restart
-	fs, err = newFileStore(
-		FileStoreConfig{StoreDir: sd},
-		StreamConfig{Name: "zzz", Subjects: []string{"foo.*"}, Storage: FileStorage, MaxMsgsPer: 1})
-	require_NoError(t, err)
-	defer fs.Stop()
+			// Restart
+			fs, err = newFileStore(
+				FileStoreConfig{StoreDir: sd},
+				test.transform(StreamConfig{Name: "zzz", Subjects: []string{"foo.*"}, Storage: FileStorage}),
+			)
+			require_NoError(t, err)
+			defer fs.Stop()
+
+			// Put in more messages with wider range. This will compact a bunch of the previous blocks.
+			for i := 0; i < 1_000_001; i++ {
+				subj := fmt.Sprintf("foo.%d", i)
+				fs.StoreMsg(subj, nil, msg)
+			}
+
+			var ss StreamState
+			fs.FastState(&ss)
+			require_Equal(t, ss.FirstSeq, test.expectedFirstSeq)
+			require_Equal(t, ss.LastSeq, test.expectedLastSeq)
+			require_Equal(t, ss.Msgs, test.expectedMsgs)
+
+			// Now stop again, but replace index.db with the old one.
+			fs.Stop()
+			// Put back the old stream state.
+			require_NoError(t, os.WriteFile(sfile, buf, defaultFilePerms))
+
+			// Restart
+			fs, err = newFileStore(
+				FileStoreConfig{StoreDir: sd},
+				test.transform(StreamConfig{Name: "zzz", Subjects: []string{"foo.*"}, Storage: FileStorage}),
+			)
+			require_NoError(t, err)
+			defer fs.Stop()
 
-	// Put in more messages with wider range. This will compact a bunch of the previous blocks.
-	for i := 0; i < 1_000_001; i++ {
-		subj := fmt.Sprintf("foo.%d", i)
-		fs.StoreMsg(subj, nil, msg)
+			fs.FastState(&ss)
+			require_Equal(t, ss.FirstSeq, test.expectedFirstSeq)
+			require_Equal(t, ss.LastSeq, test.expectedLastSeq)
+			require_Equal(t, ss.Msgs, test.expectedMsgs)
+		})
 	}
-
-	var ss StreamState
-	fs.FastState(&ss)
-	require_Equal(t, ss.FirstSeq, 10_001)
-	require_Equal(t, ss.LastSeq, 1_010_001)
-	require_Equal(t, ss.Msgs, 1_000_001)
-
-	// Now stop again, but replace index.db with old one.
-	fs.Stop()
-	// Put back old stream state.
-	require_NoError(t, os.WriteFile(sfile, buf, defaultFilePerms))
-
-	// Restart
-	fs, err = newFileStore(
-		FileStoreConfig{StoreDir: sd},
-		StreamConfig{Name: "zzz", Subjects: []string{"foo.*"}, Storage: FileStorage, MaxMsgsPer: 1})
-	require_NoError(t, err)
-	defer fs.Stop()
-
-	fs.FastState(&ss)
-	require_Equal(t, ss.FirstSeq, 10_001)
-	require_Equal(t, ss.LastSeq, 1_010_001)
-	require_Equal(t, ss.Msgs, 1_000_001)
 }
diff --git a/server/opts.go b/server/opts.go
index 931fa7868f6..0b4ed483dce 100644
--- a/server/opts.go
+++ b/server/opts.go
@@ -251,86 +251,87 @@ type AuthCallout struct {
 // NOTE: This structure is no longer used for monitoring endpoints
 // and json tags are deprecated and may be removed in the future.
 type Options struct {
-	ConfigFile            string            `json:"-"`
-	ServerName            string            `json:"server_name"`
-	Host                  string            `json:"addr"`
-	Port                  int               `json:"port"`
-	DontListen            bool              `json:"dont_listen"`
-	ClientAdvertise       string            `json:"-"`
-	Trace                 bool              `json:"-"`
-	Debug                 bool              `json:"-"`
-	TraceVerbose          bool              `json:"-"`
-	NoLog                 bool              `json:"-"`
-	NoSigs                bool              `json:"-"`
-	NoSublistCache        bool              `json:"-"`
-	NoHeaderSupport       bool              `json:"-"`
-	DisableShortFirstPing bool              `json:"-"`
-	Logtime               bool              `json:"-"`
-	LogtimeUTC            bool              `json:"-"`
-	MaxConn               int               `json:"max_connections"`
-	MaxSubs               int               `json:"max_subscriptions,omitempty"`
-	MaxSubTokens          uint8             `json:"-"`
-	Nkeys                 []*NkeyUser       `json:"-"`
-	Users                 []*User           `json:"-"`
-	Accounts              []*Account        `json:"-"`
-	NoAuthUser            string            `json:"-"`
-	SystemAccount         string            `json:"-"`
-	NoSystemAccount       bool              `json:"-"`
-	Username              string            `json:"-"`
-	Password              string            `json:"-"`
-	Authorization         string            `json:"-"`
-	AuthCallout           *AuthCallout      `json:"-"`
-	PingInterval          time.Duration     `json:"ping_interval"`
-	MaxPingsOut           int               `json:"ping_max"`
-	HTTPHost              string            `json:"http_host"`
-	HTTPPort              int               `json:"http_port"`
-	HTTPBasePath          string            `json:"http_base_path"`
-	HTTPSPort             int               `json:"https_port"`
-	AuthTimeout           float64           `json:"auth_timeout"`
-	MaxControlLine        int32             `json:"max_control_line"`
-	MaxPayload            int32             `json:"max_payload"`
-	MaxPending            int64             `json:"max_pending"`
-	Cluster               ClusterOpts       `json:"cluster,omitempty"`
-	Gateway               GatewayOpts       `json:"gateway,omitempty"`
-	LeafNode              LeafNodeOpts      `json:"leaf,omitempty"`
-	JetStream             bool              `json:"jetstream"`
-	JetStreamMaxMemory    int64             `json:"-"`
-	JetStreamMaxStore     int64             `json:"-"`
-	JetStreamDomain       string            `json:"-"`
-	JetStreamExtHint      string            `json:"-"`
-	JetStreamKey          string            `json:"-"`
-	JetStreamOldKey       string            `json:"-"`
-	JetStreamCipher       StoreCipher       `json:"-"`
-	JetStreamUniqueTag    string
-	JetStreamLimits       JSLimitOpts
-	JetStreamMaxCatchup   int64
-	StoreDir              string            `json:"-"`
-	SyncInterval          time.Duration     `json:"-"`
-	SyncAlways            bool              `json:"-"`
-	JsAccDefaultDomain    map[string]string `json:"-"` // account to domain name mapping
-	Websocket             WebsocketOpts     `json:"-"`
-	MQTT                  MQTTOpts          `json:"-"`
-	ProfPort              int               `json:"-"`
-	ProfBlockRate         int               `json:"-"`
-	PidFile               string            `json:"-"`
-	PortsFileDir          string            `json:"-"`
-	LogFile               string            `json:"-"`
-	LogSizeLimit          int64             `json:"-"`
-	LogMaxFiles           int64             `json:"-"`
-	Syslog                bool              `json:"-"`
-	RemoteSyslog          string            `json:"-"`
-	Routes                []*url.URL        `json:"-"`
-	RoutesStr             string            `json:"-"`
-	TLSTimeout            float64           `json:"tls_timeout"`
-	TLS                   bool              `json:"-"`
-	TLSVerify             bool              `json:"-"`
-	TLSMap                bool              `json:"-"`
-	TLSCert               string            `json:"-"`
-	TLSKey                string            `json:"-"`
-	TLSCaCert             string            `json:"-"`
-	TLSConfig             *tls.Config       `json:"-"`
-	TLSPinnedCerts        PinnedCertSet     `json:"-"`
-	TLSRateLimit          int64             `json:"-"`
+	ConfigFile                 string            `json:"-"`
+	ServerName                 string            `json:"server_name"`
+	Host                       string            `json:"addr"`
+	Port                       int               `json:"port"`
+	DontListen                 bool              `json:"dont_listen"`
+	ClientAdvertise            string            `json:"-"`
+	Trace                      bool              `json:"-"`
+	Debug                      bool              `json:"-"`
+	TraceVerbose               bool              `json:"-"`
+	NoLog                      bool              `json:"-"`
+	NoSigs                     bool              `json:"-"`
+	NoSublistCache             bool              `json:"-"`
+	NoHeaderSupport            bool              `json:"-"`
+	DisableShortFirstPing      bool              `json:"-"`
+	Logtime                    bool              `json:"-"`
+	LogtimeUTC                 bool              `json:"-"`
+	MaxConn                    int               `json:"max_connections"`
+	MaxSubs                    int               `json:"max_subscriptions,omitempty"`
+	MaxSubTokens               uint8             `json:"-"`
+	Nkeys                      []*NkeyUser       `json:"-"`
+	Users                      []*User           `json:"-"`
+	Accounts                   []*Account        `json:"-"`
+	NoAuthUser                 string            `json:"-"`
+	SystemAccount              string            `json:"-"`
+	NoSystemAccount            bool              `json:"-"`
+	Username                   string            `json:"-"`
+	Password                   string            `json:"-"`
+	Authorization              string            `json:"-"`
+	AuthCallout                *AuthCallout      `json:"-"`
+	PingInterval               time.Duration     `json:"ping_interval"`
+	MaxPingsOut                int               `json:"ping_max"`
+	HTTPHost                   string            `json:"http_host"`
+	HTTPPort                   int               `json:"http_port"`
+	HTTPBasePath               string            `json:"http_base_path"`
+	HTTPSPort                  int               `json:"https_port"`
+	AuthTimeout                float64           `json:"auth_timeout"`
+	MaxControlLine             int32             `json:"max_control_line"`
+	MaxPayload                 int32             `json:"max_payload"`
+	MaxPending                 int64             `json:"max_pending"`
+	Cluster                    ClusterOpts       `json:"cluster,omitempty"`
+	Gateway                    GatewayOpts       `json:"gateway,omitempty"`
+	LeafNode                   LeafNodeOpts      `json:"leaf,omitempty"`
+	JetStream                  bool              `json:"jetstream"`
+	JetStreamMaxMemory         int64             `json:"-"`
+	JetStreamMaxStore          int64             `json:"-"`
+	JetStreamDomain            string            `json:"-"`
+	JetStreamExtHint           string            `json:"-"`
+	JetStreamKey               string            `json:"-"`
+	JetStreamOldKey            string            `json:"-"`
+	JetStreamCipher            StoreCipher       `json:"-"`
+	JetStreamUniqueTag         string
+	JetStreamLimits            JSLimitOpts
+	JetStreamMaxCatchup        int64
+	JetStreamRequestQueueLimit int64
+	StoreDir                   string            `json:"-"`
+	SyncInterval               time.Duration     `json:"-"`
+	SyncAlways                 bool              `json:"-"`
+	JsAccDefaultDomain         map[string]string `json:"-"` // account to domain name mapping
+	Websocket                  WebsocketOpts     `json:"-"`
+	MQTT                       MQTTOpts          `json:"-"`
+	ProfPort                   int               `json:"-"`
+	ProfBlockRate              int               `json:"-"`
+	PidFile                    string            `json:"-"`
+	PortsFileDir               string            `json:"-"`
+	LogFile                    string            `json:"-"`
+	LogSizeLimit               int64             `json:"-"`
+	LogMaxFiles                int64             `json:"-"`
+	Syslog                     bool              `json:"-"`
+	RemoteSyslog               string            `json:"-"`
+	Routes                     []*url.URL        `json:"-"`
+	RoutesStr                  string            `json:"-"`
+	TLSTimeout                 float64           `json:"tls_timeout"`
+	TLS                        bool              `json:"-"`
+	TLSVerify                  bool              `json:"-"`
+	TLSMap                     bool              `json:"-"`
+	TLSCert                    string            `json:"-"`
+	TLSKey                     string            `json:"-"`
+	TLSCaCert                  string            `json:"-"`
+	TLSConfig                  *tls.Config       `json:"-"`
+	TLSPinnedCerts             PinnedCertSet     `json:"-"`
+	TLSRateLimit               int64             `json:"-"`
 	// When set to true, the server will perform the TLS handshake before
 	// sending the INFO protocol. For clients that are not configured
 	// with a similar option, their connection will fail with some sort
@@ -675,6 +676,7 @@ type TLSConfigOpts struct {
 	CertMatch         string
 	OCSPPeerConfig    *certidp.OCSPPeerConfig
 	Certificates      []*TLSCertPairOpt
+	MinVersion        uint16
 }
 
 // TLSCertPairOpt are the paths to a certificate and private key.
@@ -2234,6 +2236,12 @@ func parseJetStream(v any, opts *Options, errors *[]error, warnings *[]error) er
 				return &configErr{tk, fmt.Sprintf("%s %s", strings.ToLower(mk), err)}
 			}
 			opts.JetStreamMaxCatchup = s
+		case "request_queue_limit":
+			lim, ok := mv.(int64)
+			if !ok {
+				return &configErr{tk, fmt.Sprintf("Expected a parseable size for %q, got %v", mk, mv)}
+			}
+			opts.JetStreamRequestQueueLimit = lim
 		default:
 			if !tk.IsUsedVariable() {
 				err := &unknownConfigFieldErr{
@@ -4230,6 +4238,24 @@ func parseCurvePreferences(curveName string) (tls.CurveID, error) {
 	return curve, nil
 }
 
+func parseTLSVersion(v any) (uint16, error) {
+	var tlsVersionNumber uint16
+	switch v := v.(type) {
+	case string:
+		n, err := tlsVersionFromString(v)
+		if err != nil {
+			return 0, err
+		}
+		tlsVersionNumber = n
+	default:
+		return 0, fmt.Errorf("'min_version' wrong type: %v", v)
+	}
+	if tlsVersionNumber < tls.VersionTLS12 {
+		return 0, fmt.Errorf("unsupported TLS version: %s", tls.VersionName(tlsVersionNumber))
+	}
+	return tlsVersionNumber, nil
+}
+
 // Helper function to parse TLS configs.
 func parseTLS(v any, isClientCtx bool) (t *TLSConfigOpts, retErr error) {
 	var (
@@ -4473,6 +4499,12 @@ func parseTLS(v any, isClientCtx bool) (t *TLSConfigOpts, retErr error) {
 			}
 			tc.Certificates[i] = certPair
 		}
+	case "min_version":
+		minVersion, err := parseTLSVersion(mv)
+		if err != nil {
+			return nil, &configErr{tk, fmt.Sprintf("error parsing tls config: %v", err)}
+		}
+		tc.MinVersion = minVersion
 	default:
 		return nil, &configErr{tk, fmt.Sprintf("error parsing tls config, unknown field %q", mk)}
 	}
@@ -4824,6 +4856,13 @@ func GenTLSConfig(tc *TLSConfigOpts) (*tls.Config, error) {
 		}
 		config.ClientCAs = pool
 	}
+	// Allow setting TLS minimum version.
+	if tc.MinVersion > 0 {
+		if tc.MinVersion < tls.VersionTLS12 {
+			return nil, fmt.Errorf("unsupported minimum TLS version: %s", tls.VersionName(tc.MinVersion))
+		}
+		config.MinVersion = tc.MinVersion
+	}
 	return &config, nil
 }
 
@@ -5193,6 +5232,9 @@ func setBaselineOptions(opts *Options) {
 	if opts.SyncInterval == 0 && !opts.syncSet {
 		opts.SyncInterval = defaultSyncInterval
 	}
+	if opts.JetStreamRequestQueueLimit <= 0 {
+		opts.JetStreamRequestQueueLimit = JSDefaultRequestQueueLimit
+	}
 }
 
 func getDefaultAuthTimeout(tls *tls.Config, tlsTimeout float64) float64 {
diff --git a/server/opts_test.go b/server/opts_test.go
index 351603dfa66..0755c7d02dd 100644
--- a/server/opts_test.go
+++ b/server/opts_test.go
@@ -67,12 +67,13 @@ func TestDefaultOptions(t *testing.T) {
 		LeafNode: LeafNodeOpts{
 			ReconnectInterval: DEFAULT_LEAF_NODE_RECONNECT,
 		},
-		ConnectErrorReports:   DEFAULT_CONNECT_ERROR_REPORTS,
-		ReconnectErrorReports: DEFAULT_RECONNECT_ERROR_REPORTS,
-		MaxTracedMsgLen:       0,
-		JetStreamMaxMemory:    -1,
-		JetStreamMaxStore:     -1,
-		SyncInterval:          2 * time.Minute,
+		ConnectErrorReports:        DEFAULT_CONNECT_ERROR_REPORTS,
+		ReconnectErrorReports:      DEFAULT_RECONNECT_ERROR_REPORTS,
+		MaxTracedMsgLen:            0,
+		JetStreamMaxMemory:         -1,
+		JetStreamMaxStore:          -1,
+		SyncInterval:               2 * time.Minute,
+		JetStreamRequestQueueLimit: JSDefaultRequestQueueLimit,
 	}
 
 	opts := &Options{}
diff --git a/server/reload.go b/server/reload.go
index 15bbae1e385..347fcfd8b79 100644
--- a/server/reload.go
+++ b/server/reload.go
@@ -1150,7 +1150,7 @@ func imposeOrder(value any) error {
 		slices.SortFunc(value.Gateways, func(i, j *RemoteGatewayOpts) int { return cmp.Compare(i.Name, j.Name) })
 	case WebsocketOpts:
 		slices.Sort(value.AllowedOrigins)
-	case string, bool, uint8, int, int32, int64, time.Duration, float64, nil, LeafNodeOpts, ClusterOpts, *tls.Config, PinnedCertSet,
+	case string, bool, uint8, uint16, int, int32, int64, time.Duration, float64, nil, LeafNodeOpts, ClusterOpts, *tls.Config, PinnedCertSet,
 		*URLAccResolver, *MemAccResolver, *DirAccResolver, *CacheDirAccResolver, Authentication, MQTTOpts, jwt.TagList,
 		*OCSPConfig, map[string]string, JSLimitOpts, StoreCipher, *OCSPResponseCacheConfig:
 		// explicitly skipped types
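To make the new TLS plumbing concrete: config strings flow through parseTLSVersion (which delegates to tlsVersionFromString, added in server.go below), and GenTLSConfig enforces the same 1.2 floor when building the tls.Config. A hypothetical in-package test, for illustration only:

```go
func TestTLSMinVersionPlumbingSketch(t *testing.T) {
	// Hypothetical test, not part of this change.

	// Version strings map onto crypto/tls constants...
	if v, err := parseTLSVersion("1.3"); err != nil || v != tls.VersionTLS13 {
		t.Fatalf("expected TLS 1.3, got 0x%x (%v)", v, err)
	}
	// ...while anything below 1.2, or a non-string value, is rejected.
	if _, err := parseTLSVersion("1.1"); err == nil {
		t.Fatal("expected error for 1.1")
	}
	if _, err := parseTLSVersion(1.3); err == nil {
		t.Fatal("expected error for non-string value")
	}

	// GenTLSConfig applies the same floor when building the config. The cert
	// paths are placeholders, so loading them is what would fail first here.
	tc := &TLSConfigOpts{
		CertFile:   "server-cert.pem",
		KeyFile:    "server-key.pem",
		MinVersion: tls.VersionTLS13,
	}
	if cfg, err := GenTLSConfig(tc); err == nil && cfg.MinVersion != tls.VersionTLS13 {
		t.Fatal("expected MinVersion to be propagated")
	}
}
```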
diff --git a/server/server.go b/server/server.go
index c2693332abb..2db2f35574c 100644
--- a/server/server.go
+++ b/server/server.go
@@ -2926,8 +2926,10 @@ func (s *Server) startMonitoring(secure bool) error {
 		}
 		hp = net.JoinHostPort(opts.HTTPHost, strconv.Itoa(port))
 		config := opts.TLSConfig.Clone()
-		config.GetConfigForClient = s.getMonitoringTLSConfig
-		config.ClientAuth = tls.NoClientCert
+		if !s.ocspPeerVerify {
+			config.GetConfigForClient = s.getMonitoringTLSConfig
+			config.ClientAuth = tls.NoClientCert
+		}
 		httpListener, err = tls.Listen("tcp", hp, config)
 
 	} else {
@@ -3441,6 +3443,20 @@ func tlsVersion(ver uint16) string {
 	return fmt.Sprintf("Unknown [0x%x]", ver)
 }
 
+func tlsVersionFromString(ver string) (uint16, error) {
+	switch ver {
+	case "1.0":
+		return tls.VersionTLS10, nil
+	case "1.1":
+		return tls.VersionTLS11, nil
+	case "1.2":
+		return tls.VersionTLS12, nil
+	case "1.3":
+		return tls.VersionTLS13, nil
+	}
+	return 0, fmt.Errorf("unknown version: %v", ver)
+}
+
 // We use hex here so we don't need multiple versions
 func tlsCipher(cs uint16) string {
 	name, present := cipherMapByID[cs]
diff --git a/server/server_test.go b/server/server_test.go
index 915d2a98a55..d1ae37e0498 100644
--- a/server/server_test.go
+++ b/server/server_test.go
@@ -203,7 +203,102 @@ func TestTLSVersions(t *testing.T) {
 	}
 }
 
-func TestTlsCipher(t *testing.T) {
+func TestTLSMinVersionConfig(t *testing.T) {
+	tmpl := `
+	listen: "127.0.0.1:-1"
+	tls {
+		cert_file: "../test/configs/certs/server-cert.pem"
+		key_file: "../test/configs/certs/server-key.pem"
+		timeout: 1
+		min_version: %s
+	}
+	`
+	conf := createConfFile(t, []byte(fmt.Sprintf(tmpl, `"1.3"`)))
+	s, o := RunServerWithConfig(conf)
+	defer s.Shutdown()
+
+	connect := func(t *testing.T, tlsConf *tls.Config, expectedErr error) {
+		t.Helper()
+		opts := []nats.Option{}
+		if tlsConf != nil {
+			opts = append(opts, nats.Secure(tlsConf))
+		}
+		opts = append(opts, nats.RootCAs("../test/configs/certs/ca.pem"))
+		nc, err := nats.Connect(fmt.Sprintf("tls://localhost:%d", o.Port), opts...)
+		if expectedErr == nil {
+			if err != nil {
+				t.Fatalf("Unexpected error: %v", err)
+			}
+		} else if err == nil || err.Error() != expectedErr.Error() {
+			nc.Close()
+			t.Fatalf("Expected error %v, got: %v", expectedErr, err)
+		}
+	}
+
+	// Cannot connect with a client that caps out below the server's minimum TLS version.
+	connect(t, &tls.Config{
+		MaxVersion: tls.VersionTLS12,
+	}, errors.New(`remote error: tls: protocol version not supported`))
+
+	// Should connect since matching minimum TLS version.
+	connect(t, &tls.Config{
+		MinVersion: tls.VersionTLS13,
+	}, nil)
+
+	// Reloading with invalid values should fail.
+	if err := os.WriteFile(conf, []byte(fmt.Sprintf(tmpl, `"1.0"`)), 0666); err != nil {
+		t.Fatalf("Error creating config file: %v", err)
+	}
+	if err := s.Reload(); err == nil {
+		t.Fatalf("Expected reload to fail: %v", err)
+	}
+
+	// Reloading with original values and no changes should be ok.
+	if err := os.WriteFile(conf, []byte(fmt.Sprintf(tmpl, `"1.3"`)), 0666); err != nil {
+		t.Fatalf("Error creating config file: %v", err)
+	}
+	if err := s.Reload(); err != nil {
+		t.Fatalf("Unexpected error reloading TLS version: %v", err)
+	}
+
+	// Reloading with a new, lower minimum version.
+	if err := os.WriteFile(conf, []byte(fmt.Sprintf(tmpl, `"1.2"`)), 0666); err != nil {
+		t.Fatalf("Error creating config file: %v", err)
+	}
+	if err := s.Reload(); err != nil {
+		t.Fatalf("Unexpected error reloading: %v", err)
+	}
+
+	// Should connect since now matching the minimum TLS version.
+	connect(t, &tls.Config{
+		MaxVersion: tls.VersionTLS12,
+	}, nil)
+	connect(t, &tls.Config{
+		MinVersion: tls.VersionTLS13,
+	}, nil)
+
+	// Setting unsupported TLS versions should fail the reload.
+	if err := os.WriteFile(conf, []byte(fmt.Sprintf(tmpl, `"1.4"`)), 0666); err != nil {
+		t.Fatalf("Error creating config file: %v", err)
+	}
+	if err := s.Reload(); err == nil || !strings.Contains(err.Error(), `unknown version: 1.4`) {
+		t.Fatalf("Unexpected error reloading: %v", err)
+	}
+
+	tc := &TLSConfigOpts{
+		CertFile:   "../test/configs/certs/server-cert.pem",
+		KeyFile:    "../test/configs/certs/server-key.pem",
+		CaFile:     "../test/configs/certs/ca.pem",
+		Timeout:    4.0,
+		MinVersion: tls.VersionTLS11,
+	}
+	_, err := GenTLSConfig(tc)
+	if err == nil || err.Error() != `unsupported minimum TLS version: TLS 1.1` {
+		t.Fatalf("Expected error generating TLS config: %v", err)
+	}
+}
+
+func TestTLSCipher(t *testing.T) {
 	if strings.Compare(tlsCipher(0x0005), "TLS_RSA_WITH_RC4_128_SHA") != 0 {
 		t.Fatalf("Invalid tls cipher")
 	}
diff --git a/test/ocsp_peer_test.go b/test/ocsp_peer_test.go
index 1c97100d44a..bc66e37dbf2 100644
--- a/test/ocsp_peer_test.go
+++ b/test/ocsp_peer_test.go
@@ -16,6 +16,7 @@ package test
 import (
 	"context"
 	"crypto/tls"
+	"crypto/x509"
 	"encoding/json"
 	"errors"
 	"fmt"
@@ -2949,3 +2950,153 @@ func TestOCSPPeerNextUpdateUnset(t *testing.T) {
 		})
 	}
 }
+
+func TestOCSPMonitoringPort(t *testing.T) {
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
+	rootCAResponder := NewOCSPResponderRootCA(t)
+	rootCAResponderURL := fmt.Sprintf("http://%s", rootCAResponder.Addr)
+	defer rootCAResponder.Shutdown(ctx)
+	SetOCSPStatus(t, rootCAResponderURL, "configs/certs/ocsp_peer/mini-ca/intermediate1/intermediate1_cert.pem", ocsp.Good)
+
+	respCertPEM := "configs/certs/ocsp_peer/mini-ca/ocsp1/ocsp1_bundle.pem"
+	respKeyPEM := "configs/certs/ocsp_peer/mini-ca/ocsp1/private/ocsp1_keypair.pem"
+	issuerCertPEM := "configs/certs/ocsp_peer/mini-ca/intermediate1/intermediate1_cert.pem"
+	intermediateCA1Responder := NewOCSPResponderBase(t, issuerCertPEM, respCertPEM, respKeyPEM, true, "127.0.0.1:18888", 0, "")
+	intermediateCA1ResponderURL := fmt.Sprintf("http://%s", intermediateCA1Responder.Addr)
+	defer intermediateCA1Responder.Shutdown(ctx)
+	SetOCSPStatus(t, intermediateCA1ResponderURL, "configs/certs/ocsp_peer/mini-ca/client1/UserA1_cert.pem", ocsp.Good)
+	SetOCSPStatus(t, intermediateCA1ResponderURL, "configs/certs/ocsp_peer/mini-ca/server1/TestServer1_bundle.pem", ocsp.Good)
+
+	for _, test := range []struct {
+		name   string
+		config string
+		opts   []nats.Option
+		err    error
+		rerr   error
+	}{
+		{
+			"https with ocsp_peer",
+			`
+			net: 127.0.0.1
+			port: -1
+			https: -1
+			# Short form configuration
+			ocsp_cache: true
+			store_dir = %s
+			tls: {
+				cert_file: "configs/certs/ocsp_peer/mini-ca/server1/TestServer1_bundle.pem"
+				key_file: "configs/certs/ocsp_peer/mini-ca/server1/private/TestServer1_keypair.pem"
+				ca_file: "configs/certs/ocsp_peer/mini-ca/root/root_cert.pem"
+				timeout: 5
+				verify: true
+				# Long form configuration
+				ocsp_peer: {
+					verify: true
+					ca_timeout: 5
+					allowed_clockskew: 30
+				}
+			}
+			`,
+			[]nats.Option{
+				nats.ClientCert("./configs/certs/ocsp_peer/mini-ca/client1/UserA1_bundle.pem", "./configs/certs/ocsp_peer/mini-ca/client1/private/UserA1_keypair.pem"),
+				nats.RootCAs("./configs/certs/ocsp_peer/mini-ca/root/root_cert.pem"),
+				nats.ErrorHandler(noOpErrHandler),
+			},
+			nil,
+			nil,
+		},
+		{
+			"https with just ocsp",
+			`
+			net: 127.0.0.1
+			port: -1
+			https: -1
+			ocsp {
+				mode = always
+				url = http://127.0.0.1:18888
+			}
+			store_dir = %s
+
+			tls: {
+				cert_file: "configs/certs/ocsp_peer/mini-ca/server1/TestServer1_bundle.pem"
+				key_file: "configs/certs/ocsp_peer/mini-ca/server1/private/TestServer1_keypair.pem"
+				ca_file: "configs/certs/ocsp_peer/mini-ca/root/root_cert.pem"
+				timeout: 5
+				verify: true
+			}
+			`,
+			[]nats.Option{
+				nats.ClientCert("./configs/certs/ocsp_peer/mini-ca/client1/UserA1_bundle.pem", "./configs/certs/ocsp_peer/mini-ca/client1/private/UserA1_keypair.pem"),
+				nats.RootCAs("./configs/certs/ocsp_peer/mini-ca/root/root_cert.pem"),
+				nats.ErrorHandler(noOpErrHandler),
+			},
+			nil,
+			nil,
+		},
+	} {
+		t.Run(test.name, func(t *testing.T) {
+			deleteLocalStore(t, "")
+			content := test.config
+			conf := createConfFile(t, []byte(fmt.Sprintf(content, t.TempDir())))
+			s, opts := RunServerWithConfig(conf)
+			defer s.Shutdown()
+			nc, err := nats.Connect(fmt.Sprintf("tls://localhost:%d", opts.Port), test.opts...)
+			if test.err == nil && err != nil {
+				t.Errorf("Expected to connect, got %v", err)
+			} else if test.err != nil && err == nil {
+				t.Errorf("Expected error on connect")
+			} else if test.err != nil && err != nil {
+				// Error on connect was expected
+				if test.err.Error() != err.Error() {
+					t.Errorf("Expected error %s, got: %s", test.err, err)
+				}
+				return
+			}
+			defer nc.Close()
+			nc.Subscribe("ping", func(m *nats.Msg) {
+				m.Respond([]byte("pong"))
+			})
+			nc.Flush()
+			_, err = nc.Request("ping", []byte("ping"), 250*time.Millisecond)
+			if test.rerr != nil && err == nil {
+				t.Errorf("Expected error getting response")
+			} else if test.rerr == nil && err != nil {
+				t.Errorf("Expected response")
+			}
+
+			// Make request to the HTTPS port using the client cert.
+			tlsConfig := &tls.Config{}
+			clientCertFile := "./configs/certs/ocsp_peer/mini-ca/client1/UserA1_bundle.pem"
+			clientKeyFile := "./configs/certs/ocsp_peer/mini-ca/client1/private/UserA1_keypair.pem"
+			cert, err := tls.LoadX509KeyPair(clientCertFile, clientKeyFile)
+			if err != nil {
+				t.Fatal(err)
+			}
+			tlsConfig.Certificates = []tls.Certificate{cert}
+			caCertFile := "./configs/certs/ocsp_peer/mini-ca/root/root_cert.pem"
+			caCert, err := os.ReadFile(caCertFile)
+			if err != nil {
+				t.Fatal(err)
+			}
+			caCertPool := x509.NewCertPool()
+			caCertPool.AppendCertsFromPEM(caCert)
+			tlsConfig.RootCAs = caCertPool
+
+			hc := &http.Client{
+				Transport: &http.Transport{
+					TLSClientConfig: tlsConfig,
+				},
+			}
+			resp, err := hc.Get("https://" + s.MonitorAddr().String())
+			if err != nil {
+				t.Fatal(err)
+			}
+			defer resp.Body.Close()
+			if resp.StatusCode != 200 {
+				t.Errorf("Unexpected status: %v", resp.Status)
+			}
+		})
+	}
+}
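For reference, the two user-facing knobs introduced by this change appear in a server configuration along these lines, pulled out of the test templates above (values are illustrative; `request_queue_limit` falls back to `JSDefaultRequestQueueLimit` when unset, and `min_version` only accepts "1.2" or "1.3"):

```
jetstream {
    store_dir: '/var/lib/nats/jetstream'
    request_queue_limit: 10000
}

tls {
    cert_file: "server-cert.pem"
    key_file: "server-key.pem"
    min_version: "1.3"
}
```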