diff --git a/go/cmd/vtctldclient/command/routing_rules.go b/go/cmd/vtctldclient/command/routing_rules.go index 0ffee0c2c24..8a228589098 100644 --- a/go/cmd/vtctldclient/command/routing_rules.go +++ b/go/cmd/vtctldclient/command/routing_rules.go @@ -82,7 +82,7 @@ func commandApplyRoutingRules(cmd *cobra.Command, args []string) error { } rr := &vschemapb.RoutingRules{} - if err := json2.Unmarshal(rulesBytes, &rr); err != nil { + if err := json2.UnmarshalPB(rulesBytes, rr); err != nil { return err } diff --git a/go/cmd/vtctldclient/command/shard_routing_rules.go b/go/cmd/vtctldclient/command/shard_routing_rules.go index 10ce7e81747..2214269d0f3 100644 --- a/go/cmd/vtctldclient/command/shard_routing_rules.go +++ b/go/cmd/vtctldclient/command/shard_routing_rules.go @@ -87,7 +87,7 @@ func commandApplyShardRoutingRules(cmd *cobra.Command, args []string) error { } srr := &vschemapb.ShardRoutingRules{} - if err := json2.Unmarshal(rulesBytes, &srr); err != nil { + if err := json2.UnmarshalPB(rulesBytes, srr); err != nil { return err } // Round-trip so when we display the result it's readable. diff --git a/go/json2/unmarshal.go b/go/json2/unmarshal.go index e382b8ad47a..e2034fa71c9 100644 --- a/go/json2/unmarshal.go +++ b/go/json2/unmarshal.go @@ -33,8 +33,7 @@ var carriageReturn = []byte("\n") // efficient and should not be used for high QPS operations. func Unmarshal(data []byte, v any) error { if pb, ok := v.(proto.Message); ok { - opts := protojson.UnmarshalOptions{DiscardUnknown: true} - return annotate(data, opts.Unmarshal(data, pb)) + return UnmarshalPB(data, pb) } return annotate(data, json.Unmarshal(data, v)) } @@ -53,3 +52,9 @@ func annotate(data []byte, err error) error { return fmt.Errorf("line: %d, position %d: %v", line, pos, err) } + +// UnmarshalPB is similar to Unmarshal but specifically for proto.Message to add type safety. +func UnmarshalPB(data []byte, pb proto.Message) error { + opts := protojson.UnmarshalOptions{DiscardUnknown: true} + return annotate(data, opts.Unmarshal(data, pb)) +} diff --git a/go/json2/unmarshal_test.go b/go/json2/unmarshal_test.go index 9b6a6af1ca2..e46c7c6e123 100644 --- a/go/json2/unmarshal_test.go +++ b/go/json2/unmarshal_test.go @@ -18,6 +18,10 @@ package json2 import ( "testing" + + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/encoding/protojson" + "google.golang.org/protobuf/types/known/emptypb" ) func TestUnmarshal(t *testing.T) { @@ -48,3 +52,14 @@ func TestUnmarshal(t *testing.T) { } } } + +func TestUnmarshalPB(t *testing.T) { + want := &emptypb.Empty{} + json, err := protojson.Marshal(want) + require.NoError(t, err) + + var got emptypb.Empty + err = UnmarshalPB(json, &got) + require.NoError(t, err) + require.Equal(t, want, &got) +}