From ea1df00b91963c690fe94a4fb7a0047cef3b5a14 Mon Sep 17 00:00:00 2001 From: Casper Beyer Date: Wed, 30 Oct 2024 16:00:38 +0100 Subject: [PATCH] Implement strict decoding for JetStream API requests (#5858) This implements optional strict JSON decoding for JetStream. The intent of this is to minimize accidental misalignments between server and clients, we've had numerous of these across the various clients, especially for rarely used fields in the request payloads. Signed-off-by: Casper Beyer --- server/jetstream.go | 10 ++++ server/jetstream_api.go | 71 ++++++++++++++++++--------- server/jetstream_test.go | 101 +++++++++++++++++++++++++++++++++++++++ server/opts.go | 7 +++ server/server.go | 1 + server/test_test.go | 1 + 6 files changed, 169 insertions(+), 22 deletions(-) diff --git a/server/jetstream.go b/server/jetstream.go index fea7157506e..32305121a0d 100644 --- a/server/jetstream.go +++ b/server/jetstream.go @@ -48,6 +48,7 @@ type JetStreamConfig struct { Domain string `json:"domain,omitempty"` CompressOK bool `json:"compress_ok,omitempty"` UniqueTag string `json:"unique_tag,omitempty"` + Strict bool `json:"strict,omitempty"` } // Statistics about JetStream for this server. @@ -462,6 +463,11 @@ func (s *Server) enableJetStream(cfg JetStreamConfig) error { s.Noticef("") } s.Noticef("---------------- JETSTREAM ----------------") + + if cfg.Strict { + s.Noticef(" Strict: %t", cfg.Strict) + } + s.Noticef(" Max Memory: %s", friendlyBytes(cfg.MaxMemory)) s.Noticef(" Max Storage: %s", friendlyBytes(cfg.MaxStore)) s.Noticef(" Store Directory: \"%s\"", cfg.StoreDir) @@ -554,6 +560,7 @@ func (s *Server) restartJetStream() error { MaxMemory: opts.JetStreamMaxMemory, MaxStore: opts.JetStreamMaxStore, Domain: opts.JetStreamDomain, + Strict: opts.JetStreamStrict, } s.Noticef("Restarting JetStream") err := s.EnableJetStream(&cfg) @@ -2527,6 +2534,9 @@ func (s *Server) dynJetStreamConfig(storeDir string, maxStore, maxMem int64) *Je opts := s.getOpts() + // Strict mode. + jsc.Strict = opts.JetStreamStrict + // Sync options. jsc.SyncInterval = opts.SyncInterval jsc.SyncAlways = opts.SyncAlways diff --git a/server/jetstream_api.go b/server/jetstream_api.go index 39500e48602..88b1f9e654b 100644 --- a/server/jetstream_api.go +++ b/server/jetstream_api.go @@ -19,6 +19,7 @@ import ( "encoding/json" "errors" "fmt" + "io" "math/rand" "os" "path/filepath" @@ -1164,6 +1165,32 @@ func (s *Server) getRequestInfo(c *client, raw []byte) (pci *ClientInfo, acc *Ac return &ci, acc, hdr, msg, nil } +func (s *Server) unmarshalRequest(c *client, acc *Account, subject string, msg []byte, v any) error { + decoder := json.NewDecoder(bytes.NewReader(msg)) + decoder.DisallowUnknownFields() + + for { + if err := decoder.Decode(v); err != nil { + if err == io.EOF { + return nil + } + + var syntaxErr *json.SyntaxError + if errors.As(err, &syntaxErr) { + err = fmt.Errorf("%w at offset %d", err, syntaxErr.Offset) + } + + c.RateLimitWarnf("Invalid JetStream request '%s > %s': %s", acc, subject, err) + + if s.JetStreamConfig().Strict { + return err + } + + return json.Unmarshal(msg, v) + } + } +} + func (a *Account) trackAPI() { a.mu.RLock() jsa := a.js @@ -1293,7 +1320,7 @@ func (s *Server) jsTemplateCreateRequest(sub *subscription, c *client, _ *Accoun } var cfg StreamTemplateConfig - if err := json.Unmarshal(msg, &cfg); err != nil { + if err := s.unmarshalRequest(c, acc, subject, msg, &cfg); err != nil { resp.Error = NewJSInvalidJSONError(err) s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp)) return @@ -1350,7 +1377,7 @@ func (s *Server) jsTemplateNamesRequest(sub *subscription, c *client, _ *Account var offset int if isJSONObjectOrArray(msg) { var req JSApiStreamTemplatesRequest - if err := json.Unmarshal(msg, &req); err != nil { + if err := s.unmarshalRequest(c, acc, subject, msg, &req); err != nil { resp.Error = NewJSInvalidJSONError(err) s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp)) return @@ -1533,7 +1560,7 @@ func (s *Server) jsStreamCreateRequest(sub *subscription, c *client, _ *Account, } var cfg StreamConfigRequest - if err := json.Unmarshal(msg, &cfg); err != nil { + if err := s.unmarshalRequest(c, acc, subject, msg, &cfg); err != nil { resp.Error = NewJSInvalidJSONError(err) s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp)) return @@ -1644,7 +1671,7 @@ func (s *Server) jsStreamUpdateRequest(sub *subscription, c *client, _ *Account, return } var ncfg StreamConfigRequest - if err := json.Unmarshal(msg, &ncfg); err != nil { + if err := s.unmarshalRequest(c, acc, subject, msg, &ncfg); err != nil { resp.Error = NewJSInvalidJSONError(err) s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp)) return @@ -1743,7 +1770,7 @@ func (s *Server) jsStreamNamesRequest(sub *subscription, c *client, _ *Account, if isJSONObjectOrArray(msg) { var req JSApiStreamNamesRequest - if err := json.Unmarshal(msg, &req); err != nil { + if err := s.unmarshalRequest(c, acc, subject, msg, &req); err != nil { resp.Error = NewJSInvalidJSONError(err) s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp)) return @@ -1873,7 +1900,7 @@ func (s *Server) jsStreamListRequest(sub *subscription, c *client, _ *Account, s if isJSONObjectOrArray(msg) { var req JSApiStreamListRequest - if err := json.Unmarshal(msg, &req); err != nil { + if err := s.unmarshalRequest(c, acc, subject, msg, &req); err != nil { resp.Error = NewJSInvalidJSONError(err) s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp)) return @@ -2043,7 +2070,7 @@ func (s *Server) jsStreamInfoRequest(sub *subscription, c *client, a *Account, s var offset int if isJSONObjectOrArray(msg) { var req JSApiStreamInfoRequest - if err := json.Unmarshal(msg, &req); err != nil { + if err := s.unmarshalRequest(c, acc, subject, msg, &req); err != nil { resp.Error = NewJSInvalidJSONError(err) s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp)) return @@ -2400,7 +2427,7 @@ func (s *Server) jsStreamRemovePeerRequest(sub *subscription, c *client, _ *Acco } var req JSApiStreamRemovePeerRequest - if err := json.Unmarshal(msg, &req); err != nil { + if err := s.unmarshalRequest(c, acc, subject, msg, &req); err != nil { resp.Error = NewJSInvalidJSONError(err) s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp)) return @@ -2480,7 +2507,7 @@ func (s *Server) jsLeaderServerRemoveRequest(sub *subscription, c *client, _ *Ac } var req JSApiMetaServerRemoveRequest - if err := json.Unmarshal(msg, &req); err != nil { + if err := s.unmarshalRequest(c, acc, subject, msg, &req); err != nil { resp.Error = NewJSInvalidJSONError(err) s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp)) return @@ -2583,7 +2610,7 @@ func (s *Server) jsLeaderServerStreamMoveRequest(sub *subscription, c *client, _ var resp = JSApiStreamUpdateResponse{ApiResponse: ApiResponse{Type: JSApiStreamUpdateResponseType}} var req JSApiMetaServerStreamMoveRequest - if err := json.Unmarshal(msg, &req); err != nil { + if err := s.unmarshalRequest(c, acc, subject, msg, &req); err != nil { resp.Error = NewJSInvalidJSONError(err) s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp)) return @@ -2610,7 +2637,7 @@ func (s *Server) jsLeaderServerStreamMoveRequest(sub *subscription, c *client, _ if ok { sa, ok := streams[streamName] if ok { - cfg = *sa.Config.clone() + cfg = *sa.Config streamFound = true currPeers = sa.Group.Peers currCluster = sa.Group.Cluster @@ -2752,7 +2779,7 @@ func (s *Server) jsLeaderServerStreamCancelMoveRequest(sub *subscription, c *cli if ok { sa, ok := streams[streamName] if ok { - cfg = *sa.Config.clone() + cfg = *sa.Config streamFound = true currPeers = sa.Group.Peers } @@ -2933,7 +2960,7 @@ func (s *Server) jsLeaderStepDownRequest(sub *subscription, c *client, _ *Accoun if isJSONObjectOrArray(msg) { var req JSApiLeaderStepdownRequest - if err := json.Unmarshal(msg, &req); err != nil { + if err := s.unmarshalRequest(c, acc, subject, msg, &req); err != nil { resp.Error = NewJSInvalidJSONError(err) s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp)) return @@ -3160,7 +3187,7 @@ func (s *Server) jsMsgDeleteRequest(sub *subscription, c *client, _ *Account, su return } var req JSApiMsgDeleteRequest - if err := json.Unmarshal(msg, &req); err != nil { + if err := s.unmarshalRequest(c, acc, subject, msg, &req); err != nil { resp.Error = NewJSInvalidJSONError(err) s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp)) return @@ -3279,7 +3306,7 @@ func (s *Server) jsMsgGetRequest(sub *subscription, c *client, _ *Account, subje return } var req JSApiMsgGetRequest - if err := json.Unmarshal(msg, &req); err != nil { + if err := s.unmarshalRequest(c, acc, subject, msg, &req); err != nil { resp.Error = NewJSInvalidJSONError(err) s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp)) return @@ -3422,7 +3449,7 @@ func (s *Server) jsStreamPurgeRequest(sub *subscription, c *client, _ *Account, var purgeRequest *JSApiStreamPurgeRequest if isJSONObjectOrArray(msg) { var req JSApiStreamPurgeRequest - if err := json.Unmarshal(msg, &req); err != nil { + if err := s.unmarshalRequest(c, acc, subject, msg, &req); err != nil { resp.Error = NewJSInvalidJSONError(err) s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp)) return @@ -3512,7 +3539,7 @@ func (s *Server) jsStreamRestoreRequest(sub *subscription, c *client, _ *Account } var req JSApiStreamRestoreRequest - if err := json.Unmarshal(msg, &req); err != nil { + if err := s.unmarshalRequest(c, acc, subject, msg, &req); err != nil { resp.Error = NewJSInvalidJSONError(err) s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp)) return @@ -3815,7 +3842,7 @@ func (s *Server) jsStreamSnapshotRequest(sub *subscription, c *client, _ *Accoun } var req JSApiStreamSnapshotRequest - if err := json.Unmarshal(msg, &req); err != nil { + if err := s.unmarshalRequest(c, acc, subject, msg, &req); err != nil { resp.Error = NewJSInvalidJSONError(err) s.sendAPIErrResponse(ci, acc, subject, reply, smsg, s.jsonResponse(&resp)) return @@ -4013,7 +4040,7 @@ func (s *Server) jsConsumerCreateRequest(sub *subscription, c *client, a *Accoun var resp = JSApiConsumerCreateResponse{ApiResponse: ApiResponse{Type: JSApiConsumerCreateResponseType}} var req CreateConsumerRequest - if err := json.Unmarshal(msg, &req); err != nil { + if err := s.unmarshalRequest(c, acc, subject, msg, &req); err != nil { resp.Error = NewJSInvalidJSONError(err) s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp)) return @@ -4255,7 +4282,7 @@ func (s *Server) jsConsumerNamesRequest(sub *subscription, c *client, _ *Account var offset int if isJSONObjectOrArray(msg) { var req JSApiConsumersRequest - if err := json.Unmarshal(msg, &req); err != nil { + if err := s.unmarshalRequest(c, acc, subject, msg, &req); err != nil { resp.Error = NewJSInvalidJSONError(err) s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp)) return @@ -4377,7 +4404,7 @@ func (s *Server) jsConsumerListRequest(sub *subscription, c *client, _ *Account, var offset int if isJSONObjectOrArray(msg) { var req JSApiConsumersRequest - if err := json.Unmarshal(msg, &req); err != nil { + if err := s.unmarshalRequest(c, acc, subject, msg, &req); err != nil { resp.Error = NewJSInvalidJSONError(err) s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp)) return @@ -4688,7 +4715,7 @@ func (s *Server) jsConsumerPauseRequest(sub *subscription, c *client, _ *Account var resp = JSApiConsumerPauseResponse{ApiResponse: ApiResponse{Type: JSApiConsumerPauseResponseType}} if isJSONObjectOrArray(msg) { - if err := json.Unmarshal(msg, &req); err != nil { + if err := s.unmarshalRequest(c, acc, subject, msg, &req); err != nil { resp.Error = NewJSInvalidJSONError(err) s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp)) return diff --git a/server/jetstream_test.go b/server/jetstream_test.go index c3723c4a370..cf46f9c6fa0 100644 --- a/server/jetstream_test.go +++ b/server/jetstream_test.go @@ -24127,6 +24127,107 @@ func TestJetStreamStreamCreatePedanticMode(t *testing.T) { } } +func TestJetStreamStrictMode(t *testing.T) { + cfgFmt := []byte(fmt.Sprintf(` + jetstream: { + strict: true + enabled: true + max_file_store: 100MB + store_dir: %s + limits: {duplicate_window: "1m", max_request_batch: 250} + } + `, t.TempDir())) + conf := createConfFile(t, cfgFmt) + s, _ := RunServerWithConfig(conf) + defer s.Shutdown() + + nc, err := nats.Connect(s.ClientURL()) + if err != nil { + t.Fatalf("Error connecting to NATS: %v", err) + } + defer nc.Close() + + tests := []struct { + name string + subject string + payload []byte + expectedErr string + }{ + { + name: "Stream Create", + subject: "$JS.API.STREAM.CREATE.TEST_STREAM", + payload: []byte(`{"name":"TEST_STREAM","subjects":["test.>"],"extra_field":"unexpected"}`), + expectedErr: "invalid JSON", + }, + { + name: "Stream Update", + subject: "$JS.API.STREAM.UPDATE.TEST_STREAM", + payload: []byte(`{"name":"TEST_STREAM","subjects":["test.>"],"extra_field":"unexpected"}`), + expectedErr: "invalid JSON", + }, + { + name: "Stream Delete", + subject: "$JS.API.STREAM.DELETE.TEST_STREAM", + payload: []byte(`{"extra_field":"unexpected"}`), + expectedErr: "expected an empty request payload", + }, + { + name: "Stream Info", + subject: "$JS.API.STREAM.INFO.TEST_STREAM", + payload: []byte(`{"extra_field":"unexpected"}`), + expectedErr: "invalid JSON", + }, + { + name: "Consumer Create", + subject: "$JS.API.CONSUMER.CREATE.TEST_STREAM.TEST_CONSUMER", + payload: []byte(`{"durable_name":"TEST_CONSUMER","ack_policy":"explicit","extra_field":"unexpected"}`), + expectedErr: "invalid JSON", + }, + { + name: "Consumer Delete", + subject: "$JS.API.CONSUMER.DELETE.TEST_STREAM.TEST_CONSUMER", + payload: []byte(`{"extra_field":"unexpected"}`), + expectedErr: "expected an empty request payload", + }, + { + name: "Consumer Info", + subject: "$JS.API.CONSUMER.INFO.TEST_STREAM.TEST_CONSUMER", + payload: []byte(`{"extra_field":"unexpected"}`), + expectedErr: "expected an empty request payload", + }, + { + name: "Stream List", + subject: "$JS.API.STREAM.LIST", + payload: []byte(`{"extra_field":"unexpected"}`), + expectedErr: "invalid JSON", + }, + { + name: "Consumer List", + subject: "$JS.API.CONSUMER.LIST.TEST_STREAM", + payload: []byte(`{"extra_field":"unexpected"}`), + expectedErr: "invalid JSON", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + resp, err := nc.Request(tt.subject, tt.payload, time.Second*10) + if err != nil { + t.Fatalf("Request failed: %v", err) + } + + var apiResp ApiResponse = ApiResponse{} + + if err := json.Unmarshal(resp.Data, &apiResp); err != nil { + t.Fatalf("Error unmarshalling response: %v", err) + } + + require_NotNil(t, apiResp.Error.Description) + require_Contains(t, apiResp.Error.Description, tt.expectedErr) + }) + } +} + func addConsumerWithError(t *testing.T, nc *nats.Conn, cfg *CreateConsumerRequest) (*ConsumerInfo, *ApiError) { t.Helper() req, err := json.Marshal(cfg) diff --git a/server/opts.go b/server/opts.go index c80866115b2..e5794040840 100644 --- a/server/opts.go +++ b/server/opts.go @@ -324,6 +324,7 @@ type Options struct { Gateway GatewayOpts `json:"gateway,omitempty"` LeafNode LeafNodeOpts `json:"leaf,omitempty"` JetStream bool `json:"jetstream"` + JetStreamStrict bool `json:"-"` JetStreamMaxMemory int64 `json:"-"` JetStreamMaxStore int64 `json:"-"` JetStreamDomain string `json:"-"` @@ -2353,6 +2354,12 @@ func parseJetStream(v any, opts *Options, errors *[]error, warnings *[]error) er for mk, mv := range vv { tk, mv = unwrapValue(mv, <) switch strings.ToLower(mk) { + case "strict": + if v, ok := mv.(bool); ok { + opts.JetStreamStrict = v + } else { + return &configErr{tk, fmt.Sprintf("Expected 'true' or 'false' for bool value, got '%s'", mv)} + } case "store", "store_dir", "storedir": // StoreDir can be set at the top level as well so have to prevent ambiguous declarations. if opts.StoreDir != _EMPTY_ { diff --git a/server/server.go b/server/server.go index f7e0d2b007b..7fa37ce201e 100644 --- a/server/server.go +++ b/server/server.go @@ -2358,6 +2358,7 @@ func (s *Server) Start() { StoreDir: opts.StoreDir, SyncInterval: opts.SyncInterval, SyncAlways: opts.SyncAlways, + Strict: opts.JetStreamStrict, MaxMemory: opts.JetStreamMaxMemory, MaxStore: opts.JetStreamMaxStore, Domain: opts.JetStreamDomain, diff --git a/server/test_test.go b/server/test_test.go index 1dc3a4e90ab..f3003b1edb2 100644 --- a/server/test_test.go +++ b/server/test_test.go @@ -32,6 +32,7 @@ var DefaultTestOptions = Options{ NoSigs: true, MaxControlLine: 4096, DisableShortFirstPing: true, + JetStreamStrict: true, } func testDefaultClusterOptionsForLeafNodes() *Options {