Skip to content

Commit

Permalink
Implement strict decoding for JetStream API requests (#5858)
Browse files Browse the repository at this point in the history
This implements optional strict JSON decoding for JetStream.

The intent of this is to minimize accidental misalignments between
server and clients, we've had numerous of these across the various
clients, especially for rarely used fields in the request payloads.

Signed-off-by: Casper Beyer <[email protected]>
  • Loading branch information
caspervonb authored Oct 30, 2024
1 parent e339076 commit ea1df00
Show file tree
Hide file tree
Showing 6 changed files with 169 additions and 22 deletions.
10 changes: 10 additions & 0 deletions server/jetstream.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ type JetStreamConfig struct {
Domain string `json:"domain,omitempty"`
CompressOK bool `json:"compress_ok,omitempty"`
UniqueTag string `json:"unique_tag,omitempty"`
Strict bool `json:"strict,omitempty"`
}

// Statistics about JetStream for this server.
Expand Down Expand Up @@ -462,6 +463,11 @@ func (s *Server) enableJetStream(cfg JetStreamConfig) error {
s.Noticef("")
}
s.Noticef("---------------- JETSTREAM ----------------")

if cfg.Strict {
s.Noticef(" Strict: %t", cfg.Strict)
}

s.Noticef(" Max Memory: %s", friendlyBytes(cfg.MaxMemory))
s.Noticef(" Max Storage: %s", friendlyBytes(cfg.MaxStore))
s.Noticef(" Store Directory: \"%s\"", cfg.StoreDir)
Expand Down Expand Up @@ -554,6 +560,7 @@ func (s *Server) restartJetStream() error {
MaxMemory: opts.JetStreamMaxMemory,
MaxStore: opts.JetStreamMaxStore,
Domain: opts.JetStreamDomain,
Strict: opts.JetStreamStrict,
}
s.Noticef("Restarting JetStream")
err := s.EnableJetStream(&cfg)
Expand Down Expand Up @@ -2527,6 +2534,9 @@ func (s *Server) dynJetStreamConfig(storeDir string, maxStore, maxMem int64) *Je

opts := s.getOpts()

// Strict mode.
jsc.Strict = opts.JetStreamStrict

// Sync options.
jsc.SyncInterval = opts.SyncInterval
jsc.SyncAlways = opts.SyncAlways
Expand Down
71 changes: 49 additions & 22 deletions server/jetstream_api.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"encoding/json"
"errors"
"fmt"
"io"
"math/rand"
"os"
"path/filepath"
Expand Down Expand Up @@ -1164,6 +1165,32 @@ func (s *Server) getRequestInfo(c *client, raw []byte) (pci *ClientInfo, acc *Ac
return &ci, acc, hdr, msg, nil
}

func (s *Server) unmarshalRequest(c *client, acc *Account, subject string, msg []byte, v any) error {
decoder := json.NewDecoder(bytes.NewReader(msg))
decoder.DisallowUnknownFields()

for {
if err := decoder.Decode(v); err != nil {
if err == io.EOF {
return nil
}

var syntaxErr *json.SyntaxError
if errors.As(err, &syntaxErr) {
err = fmt.Errorf("%w at offset %d", err, syntaxErr.Offset)
}

c.RateLimitWarnf("Invalid JetStream request '%s > %s': %s", acc, subject, err)

if s.JetStreamConfig().Strict {
return err
}

return json.Unmarshal(msg, v)
}
}
}

func (a *Account) trackAPI() {
a.mu.RLock()
jsa := a.js
Expand Down Expand Up @@ -1293,7 +1320,7 @@ func (s *Server) jsTemplateCreateRequest(sub *subscription, c *client, _ *Accoun
}

var cfg StreamTemplateConfig
if err := json.Unmarshal(msg, &cfg); err != nil {
if err := s.unmarshalRequest(c, acc, subject, msg, &cfg); err != nil {
resp.Error = NewJSInvalidJSONError(err)
s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp))
return
Expand Down Expand Up @@ -1350,7 +1377,7 @@ func (s *Server) jsTemplateNamesRequest(sub *subscription, c *client, _ *Account
var offset int
if isJSONObjectOrArray(msg) {
var req JSApiStreamTemplatesRequest
if err := json.Unmarshal(msg, &req); err != nil {
if err := s.unmarshalRequest(c, acc, subject, msg, &req); err != nil {
resp.Error = NewJSInvalidJSONError(err)
s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp))
return
Expand Down Expand Up @@ -1533,7 +1560,7 @@ func (s *Server) jsStreamCreateRequest(sub *subscription, c *client, _ *Account,
}

var cfg StreamConfigRequest
if err := json.Unmarshal(msg, &cfg); err != nil {
if err := s.unmarshalRequest(c, acc, subject, msg, &cfg); err != nil {
resp.Error = NewJSInvalidJSONError(err)
s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp))
return
Expand Down Expand Up @@ -1644,7 +1671,7 @@ func (s *Server) jsStreamUpdateRequest(sub *subscription, c *client, _ *Account,
return
}
var ncfg StreamConfigRequest
if err := json.Unmarshal(msg, &ncfg); err != nil {
if err := s.unmarshalRequest(c, acc, subject, msg, &ncfg); err != nil {
resp.Error = NewJSInvalidJSONError(err)
s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp))
return
Expand Down Expand Up @@ -1743,7 +1770,7 @@ func (s *Server) jsStreamNamesRequest(sub *subscription, c *client, _ *Account,

if isJSONObjectOrArray(msg) {
var req JSApiStreamNamesRequest
if err := json.Unmarshal(msg, &req); err != nil {
if err := s.unmarshalRequest(c, acc, subject, msg, &req); err != nil {
resp.Error = NewJSInvalidJSONError(err)
s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp))
return
Expand Down Expand Up @@ -1873,7 +1900,7 @@ func (s *Server) jsStreamListRequest(sub *subscription, c *client, _ *Account, s

if isJSONObjectOrArray(msg) {
var req JSApiStreamListRequest
if err := json.Unmarshal(msg, &req); err != nil {
if err := s.unmarshalRequest(c, acc, subject, msg, &req); err != nil {
resp.Error = NewJSInvalidJSONError(err)
s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp))
return
Expand Down Expand Up @@ -2043,7 +2070,7 @@ func (s *Server) jsStreamInfoRequest(sub *subscription, c *client, a *Account, s
var offset int
if isJSONObjectOrArray(msg) {
var req JSApiStreamInfoRequest
if err := json.Unmarshal(msg, &req); err != nil {
if err := s.unmarshalRequest(c, acc, subject, msg, &req); err != nil {
resp.Error = NewJSInvalidJSONError(err)
s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp))
return
Expand Down Expand Up @@ -2400,7 +2427,7 @@ func (s *Server) jsStreamRemovePeerRequest(sub *subscription, c *client, _ *Acco
}

var req JSApiStreamRemovePeerRequest
if err := json.Unmarshal(msg, &req); err != nil {
if err := s.unmarshalRequest(c, acc, subject, msg, &req); err != nil {
resp.Error = NewJSInvalidJSONError(err)
s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp))
return
Expand Down Expand Up @@ -2480,7 +2507,7 @@ func (s *Server) jsLeaderServerRemoveRequest(sub *subscription, c *client, _ *Ac
}

var req JSApiMetaServerRemoveRequest
if err := json.Unmarshal(msg, &req); err != nil {
if err := s.unmarshalRequest(c, acc, subject, msg, &req); err != nil {
resp.Error = NewJSInvalidJSONError(err)
s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp))
return
Expand Down Expand Up @@ -2583,7 +2610,7 @@ func (s *Server) jsLeaderServerStreamMoveRequest(sub *subscription, c *client, _
var resp = JSApiStreamUpdateResponse{ApiResponse: ApiResponse{Type: JSApiStreamUpdateResponseType}}

var req JSApiMetaServerStreamMoveRequest
if err := json.Unmarshal(msg, &req); err != nil {
if err := s.unmarshalRequest(c, acc, subject, msg, &req); err != nil {
resp.Error = NewJSInvalidJSONError(err)
s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp))
return
Expand All @@ -2610,7 +2637,7 @@ func (s *Server) jsLeaderServerStreamMoveRequest(sub *subscription, c *client, _
if ok {
sa, ok := streams[streamName]
if ok {
cfg = *sa.Config.clone()
cfg = *sa.Config
streamFound = true
currPeers = sa.Group.Peers
currCluster = sa.Group.Cluster
Expand Down Expand Up @@ -2752,7 +2779,7 @@ func (s *Server) jsLeaderServerStreamCancelMoveRequest(sub *subscription, c *cli
if ok {
sa, ok := streams[streamName]
if ok {
cfg = *sa.Config.clone()
cfg = *sa.Config
streamFound = true
currPeers = sa.Group.Peers
}
Expand Down Expand Up @@ -2933,7 +2960,7 @@ func (s *Server) jsLeaderStepDownRequest(sub *subscription, c *client, _ *Accoun

if isJSONObjectOrArray(msg) {
var req JSApiLeaderStepdownRequest
if err := json.Unmarshal(msg, &req); err != nil {
if err := s.unmarshalRequest(c, acc, subject, msg, &req); err != nil {
resp.Error = NewJSInvalidJSONError(err)
s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp))
return
Expand Down Expand Up @@ -3160,7 +3187,7 @@ func (s *Server) jsMsgDeleteRequest(sub *subscription, c *client, _ *Account, su
return
}
var req JSApiMsgDeleteRequest
if err := json.Unmarshal(msg, &req); err != nil {
if err := s.unmarshalRequest(c, acc, subject, msg, &req); err != nil {
resp.Error = NewJSInvalidJSONError(err)
s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp))
return
Expand Down Expand Up @@ -3279,7 +3306,7 @@ func (s *Server) jsMsgGetRequest(sub *subscription, c *client, _ *Account, subje
return
}
var req JSApiMsgGetRequest
if err := json.Unmarshal(msg, &req); err != nil {
if err := s.unmarshalRequest(c, acc, subject, msg, &req); err != nil {
resp.Error = NewJSInvalidJSONError(err)
s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp))
return
Expand Down Expand Up @@ -3422,7 +3449,7 @@ func (s *Server) jsStreamPurgeRequest(sub *subscription, c *client, _ *Account,
var purgeRequest *JSApiStreamPurgeRequest
if isJSONObjectOrArray(msg) {
var req JSApiStreamPurgeRequest
if err := json.Unmarshal(msg, &req); err != nil {
if err := s.unmarshalRequest(c, acc, subject, msg, &req); err != nil {
resp.Error = NewJSInvalidJSONError(err)
s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp))
return
Expand Down Expand Up @@ -3512,7 +3539,7 @@ func (s *Server) jsStreamRestoreRequest(sub *subscription, c *client, _ *Account
}

var req JSApiStreamRestoreRequest
if err := json.Unmarshal(msg, &req); err != nil {
if err := s.unmarshalRequest(c, acc, subject, msg, &req); err != nil {
resp.Error = NewJSInvalidJSONError(err)
s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp))
return
Expand Down Expand Up @@ -3815,7 +3842,7 @@ func (s *Server) jsStreamSnapshotRequest(sub *subscription, c *client, _ *Accoun
}

var req JSApiStreamSnapshotRequest
if err := json.Unmarshal(msg, &req); err != nil {
if err := s.unmarshalRequest(c, acc, subject, msg, &req); err != nil {
resp.Error = NewJSInvalidJSONError(err)
s.sendAPIErrResponse(ci, acc, subject, reply, smsg, s.jsonResponse(&resp))
return
Expand Down Expand Up @@ -4013,7 +4040,7 @@ func (s *Server) jsConsumerCreateRequest(sub *subscription, c *client, a *Accoun
var resp = JSApiConsumerCreateResponse{ApiResponse: ApiResponse{Type: JSApiConsumerCreateResponseType}}

var req CreateConsumerRequest
if err := json.Unmarshal(msg, &req); err != nil {
if err := s.unmarshalRequest(c, acc, subject, msg, &req); err != nil {
resp.Error = NewJSInvalidJSONError(err)
s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp))
return
Expand Down Expand Up @@ -4255,7 +4282,7 @@ func (s *Server) jsConsumerNamesRequest(sub *subscription, c *client, _ *Account
var offset int
if isJSONObjectOrArray(msg) {
var req JSApiConsumersRequest
if err := json.Unmarshal(msg, &req); err != nil {
if err := s.unmarshalRequest(c, acc, subject, msg, &req); err != nil {
resp.Error = NewJSInvalidJSONError(err)
s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp))
return
Expand Down Expand Up @@ -4377,7 +4404,7 @@ func (s *Server) jsConsumerListRequest(sub *subscription, c *client, _ *Account,
var offset int
if isJSONObjectOrArray(msg) {
var req JSApiConsumersRequest
if err := json.Unmarshal(msg, &req); err != nil {
if err := s.unmarshalRequest(c, acc, subject, msg, &req); err != nil {
resp.Error = NewJSInvalidJSONError(err)
s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp))
return
Expand Down Expand Up @@ -4688,7 +4715,7 @@ func (s *Server) jsConsumerPauseRequest(sub *subscription, c *client, _ *Account
var resp = JSApiConsumerPauseResponse{ApiResponse: ApiResponse{Type: JSApiConsumerPauseResponseType}}

if isJSONObjectOrArray(msg) {
if err := json.Unmarshal(msg, &req); err != nil {
if err := s.unmarshalRequest(c, acc, subject, msg, &req); err != nil {
resp.Error = NewJSInvalidJSONError(err)
s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp))
return
Expand Down
101 changes: 101 additions & 0 deletions server/jetstream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24127,6 +24127,107 @@ func TestJetStreamStreamCreatePedanticMode(t *testing.T) {
}
}

func TestJetStreamStrictMode(t *testing.T) {
cfgFmt := []byte(fmt.Sprintf(`
jetstream: {
strict: true
enabled: true
max_file_store: 100MB
store_dir: %s
limits: {duplicate_window: "1m", max_request_batch: 250}
}
`, t.TempDir()))
conf := createConfFile(t, cfgFmt)
s, _ := RunServerWithConfig(conf)
defer s.Shutdown()

nc, err := nats.Connect(s.ClientURL())
if err != nil {
t.Fatalf("Error connecting to NATS: %v", err)
}
defer nc.Close()

tests := []struct {
name string
subject string
payload []byte
expectedErr string
}{
{
name: "Stream Create",
subject: "$JS.API.STREAM.CREATE.TEST_STREAM",
payload: []byte(`{"name":"TEST_STREAM","subjects":["test.>"],"extra_field":"unexpected"}`),
expectedErr: "invalid JSON",
},
{
name: "Stream Update",
subject: "$JS.API.STREAM.UPDATE.TEST_STREAM",
payload: []byte(`{"name":"TEST_STREAM","subjects":["test.>"],"extra_field":"unexpected"}`),
expectedErr: "invalid JSON",
},
{
name: "Stream Delete",
subject: "$JS.API.STREAM.DELETE.TEST_STREAM",
payload: []byte(`{"extra_field":"unexpected"}`),
expectedErr: "expected an empty request payload",
},
{
name: "Stream Info",
subject: "$JS.API.STREAM.INFO.TEST_STREAM",
payload: []byte(`{"extra_field":"unexpected"}`),
expectedErr: "invalid JSON",
},
{
name: "Consumer Create",
subject: "$JS.API.CONSUMER.CREATE.TEST_STREAM.TEST_CONSUMER",
payload: []byte(`{"durable_name":"TEST_CONSUMER","ack_policy":"explicit","extra_field":"unexpected"}`),
expectedErr: "invalid JSON",
},
{
name: "Consumer Delete",
subject: "$JS.API.CONSUMER.DELETE.TEST_STREAM.TEST_CONSUMER",
payload: []byte(`{"extra_field":"unexpected"}`),
expectedErr: "expected an empty request payload",
},
{
name: "Consumer Info",
subject: "$JS.API.CONSUMER.INFO.TEST_STREAM.TEST_CONSUMER",
payload: []byte(`{"extra_field":"unexpected"}`),
expectedErr: "expected an empty request payload",
},
{
name: "Stream List",
subject: "$JS.API.STREAM.LIST",
payload: []byte(`{"extra_field":"unexpected"}`),
expectedErr: "invalid JSON",
},
{
name: "Consumer List",
subject: "$JS.API.CONSUMER.LIST.TEST_STREAM",
payload: []byte(`{"extra_field":"unexpected"}`),
expectedErr: "invalid JSON",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
resp, err := nc.Request(tt.subject, tt.payload, time.Second*10)
if err != nil {
t.Fatalf("Request failed: %v", err)
}

var apiResp ApiResponse = ApiResponse{}

if err := json.Unmarshal(resp.Data, &apiResp); err != nil {
t.Fatalf("Error unmarshalling response: %v", err)
}

require_NotNil(t, apiResp.Error.Description)
require_Contains(t, apiResp.Error.Description, tt.expectedErr)
})
}
}

func addConsumerWithError(t *testing.T, nc *nats.Conn, cfg *CreateConsumerRequest) (*ConsumerInfo, *ApiError) {
t.Helper()
req, err := json.Marshal(cfg)
Expand Down
Loading

0 comments on commit ea1df00

Please sign in to comment.