diff --git a/m2m-oauth-server/pb/tags.go b/m2m-oauth-server/pb/tags.go index 54d87687d..f77f1d074 100644 --- a/m2m-oauth-server/pb/tags.go +++ b/m2m-oauth-server/pb/tags.go @@ -8,4 +8,5 @@ const ( BlackListedKey = "blacklisted" TimestampKey = "timestamp" AudienceKey = "audience" + IssuedAtKey = "issuedAt" ) diff --git a/m2m-oauth-server/pb/token.go b/m2m-oauth-server/pb/token.go index 94d9bebda..82174ce8c 100644 --- a/m2m-oauth-server/pb/token.go +++ b/m2m-oauth-server/pb/token.go @@ -1,13 +1,10 @@ package pb import ( - "encoding/json" "errors" "fmt" - "strconv" - "github.com/hashicorp/go-multierror" - "google.golang.org/protobuf/encoding/protojson" + pkgMongo "github.com/plgd-dev/hub/v2/pkg/mongodb" ) var errTokenIsNil = errors.New("Token is nil") @@ -31,160 +28,58 @@ func (x *Token) Validate() error { return nil } -func (x *Token) ToMap() (map[string]interface{}, error) { - v := protojson.MarshalOptions{ - AllowPartial: true, - EmitUnpopulated: true, +func (x *Token) jsonToBSONTag(json map[string]interface{}) error { + json["_id"] = x.GetId() + delete(json, "id") + if _, err := pkgMongo.ConvertStringValueToInt64(json, false, "."+IssuedAtKey); err != nil { + return fmt.Errorf("cannot convert issueAt to int64: %w", err) } - data, err := v.Marshal(x) - if err != nil { - return nil, err - } - var m map[string]interface{} - err = json.Unmarshal(data, &m) - if err != nil { - return nil, err - } - return m, nil -} - -func replaceStrToInt64(m map[string]interface{}, keys ...string) error { - var errs *multierror.Error - for _, k := range keys { - exp, ok := m[k] - if ok { - str, ok := exp.(string) - if ok { - i, err := strconv.ParseInt(str, 10, 64) - if err != nil { - errs = multierror.Append(errs, fmt.Errorf("cannot convert key %v to int64, %w", k, err)) - } else { - m[k] = i - } - } - } - } - return errs.ErrorOrNil() -} - -func replaceInt64ToStr(m map[string]interface{}, keys ...string) { - for _, k := range keys { - exp, ok := m[k] - if ok { - i, ok := exp.(int64) - if ok { - m[k] = strconv.FormatInt(i, 10) - } - } - } -} - -func (x *Token) ToBsonMap() (map[string]interface{}, error) { - m, err := x.ToMap() - if err != nil { - return nil, err - } - m["_id"] = x.GetId() - delete(m, "id") - err = replaceStrToInt64(m, ExpirationKey, TimestampKey) - if err != nil { - return nil, err + if _, err := pkgMongo.ConvertStringValueToInt64(json, true, "."+ExpirationKey); err != nil { + return fmt.Errorf("cannot convert expiration to int64: %w", err) } - blackListed, ok := m[BlackListedKey] - if ok { - mapBlacklisted, ok := blackListed.(map[string]interface{}) - if ok { - err = replaceStrToInt64(mapBlacklisted, TimestampKey) - if err != nil { - return nil, err - } - } + if _, err := pkgMongo.ConvertStringValueToInt64(json, true, "."+BlackListedKey+"."+TimestampKey); err != nil { + return fmt.Errorf("cannot convert blacklisted.timestamp to int64: %w", err) } - return m, nil + return nil } -func (x *Token) FromMap(m map[string]interface{}) error { +func (x *Token) MarshalBSON() ([]byte, error) { if x == nil { - return errTokenIsNil + return nil, errTokenIsNil } - data, err := json.Marshal(m) - if err != nil { - return err - } - v := protojson.UnmarshalOptions{ - AllowPartial: true, - DiscardUnknown: true, - } - return v.Unmarshal(data, x) + return pkgMongo.MarshalProtoBSON(x, x.jsonToBSONTag) } -func (x *Token) FromBsonMap(m map[string]interface{}) error { +func (x *Token) UnmarshalBSON(data []byte) error { if x == nil { return errTokenIsNil } - m["id"] = m["_id"] - delete(m, "_id") - - replaceInt64ToStr(m, ExpirationKey, TimestampKey) - blackListed, ok := m[BlackListedKey] - if ok { - mapBlacklisted, ok := blackListed.(map[string]interface{}) + var id string + update := func(json map[string]interface{}) error { + idI, ok := json["_id"] if ok { - replaceInt64ToStr(mapBlacklisted, TimestampKey) + id = idI.(string) } + delete(json, "_id") + return nil } - - return x.FromMap(m) -} - -func (x *Token_BlackListed) ToMap() (map[string]interface{}, error) { - v := protojson.MarshalOptions{ - AllowPartial: true, - EmitUnpopulated: true, - } - data, err := v.Marshal(x) - if err != nil { - return nil, err - } - var m map[string]interface{} - err = json.Unmarshal(data, &m) - if err != nil { - return nil, err - } - return m, nil -} - -func (x *Token_BlackListed) FromMap(m map[string]interface{}) error { - if x == nil { - return errors.New("Token_BlackListed is nil") - } - data, err := json.Marshal(m) + err := pkgMongo.UnmarshalProtoBSON(data, x, update) if err != nil { return err } - v := protojson.UnmarshalOptions{ - AllowPartial: true, - DiscardUnknown: true, + if x.GetId() == "" && id != "" { + x.Id = id } - return v.Unmarshal(data, x) + return nil } -func (x *Token_BlackListed) ToBsonMap() (map[string]interface{}, error) { - m, err := x.ToMap() - if err != nil { - return nil, err +func (x *Token_BlackListed) jsonToBSONTag(json map[string]interface{}) error { + if _, err := pkgMongo.ConvertStringValueToInt64(json, false, "."+TimestampKey); err != nil { + return fmt.Errorf("cannot convert timestamp to int64: %w", err) } - err = replaceStrToInt64(m, TimestampKey) - if err != nil { - return nil, err - } - return m, nil + return nil } -func (x *Token_BlackListed) FromBsonMap(m map[string]interface{}) error { - if x == nil { - return errors.New("Token_BlackListed is nil") - } - replaceInt64ToStr(m, TimestampKey) - return x.FromMap(m) +func (x *Token_BlackListed) MarshalBSON() ([]byte, error) { + return pkgMongo.MarshalProtoBSON(x, x.jsonToBSONTag) } diff --git a/m2m-oauth-server/service/http/postToken_test.go b/m2m-oauth-server/service/http/postToken_test.go index adf229d80..9e005c9bb 100644 --- a/m2m-oauth-server/service/http/postToken_test.go +++ b/m2m-oauth-server/service/http/postToken_test.go @@ -5,6 +5,7 @@ import ( "fmt" "net/http" "testing" + "time" "github.com/golang-jwt/jwt/v5" oauthsigner "github.com/plgd-dev/hub/v2/m2m-oauth-server/oauthSigner" @@ -81,6 +82,22 @@ func TestPostToken(t *testing.T) { existOriginalTokenClaims: true, }, }, + { + name: "ownerToken with expiration- JWT", + args: m2mOauthServerTest.AccessTokenOptions{ + Ctx: context.Background(), + ClientID: m2mOauthServerTest.JWTPrivateKeyOAuthClient.ID, + GrantType: string(oauthsigner.GrantTypeClientCredentials), + Host: config.M2M_OAUTH_SERVER_HTTP_HOST, + JWT: token, + Expiration: time.Now().Add(time.Hour), + }, + wantCode: http.StatusOK, + want: want{ + owner: "1", + existOriginalTokenClaims: true, + }, + }, { name: "invalid client", args: m2mOauthServerTest.AccessTokenOptions{ @@ -103,6 +120,18 @@ func TestPostToken(t *testing.T) { }, wantCode: http.StatusUnauthorized, }, + { + name: "invalid expiration", + args: m2mOauthServerTest.AccessTokenOptions{ + Ctx: context.Background(), + ClientID: m2mOauthServerTest.JWTPrivateKeyOAuthClient.ID, + GrantType: string(oauthsigner.GrantTypeClientCredentials), + Host: config.M2M_OAUTH_SERVER_HTTP_HOST, + JWT: token, + Expiration: time.Now().Add(-time.Hour), + }, + wantCode: http.StatusUnauthorized, + }, } webTearDown := m2mOauthServerTest.SetUp(t) diff --git a/m2m-oauth-server/service/service.go b/m2m-oauth-server/service/service.go index 93ea4346d..139d5dca3 100644 --- a/m2m-oauth-server/service/service.go +++ b/m2m-oauth-server/service/service.go @@ -49,6 +49,7 @@ func createStore(ctx context.Context, config storeConfig.Config, fileWatcher *fs } }) if err2 != nil { + s.Close(ctx) return nil, fmt.Errorf("cannot create scheduler: %w", err2) } s.AddCloseFunc(func() { diff --git a/m2m-oauth-server/store/mongodb/tokens.go b/m2m-oauth-server/store/mongodb/tokens.go index 87eeea0da..08c88344e 100644 --- a/m2m-oauth-server/store/mongodb/tokens.go +++ b/m2m-oauth-server/store/mongodb/tokens.go @@ -28,11 +28,7 @@ func (s *Store) CreateToken(ctx context.Context, owner string, token *pb.Token) if err != nil { return nil, err } - m, err := token.ToBsonMap() - if err != nil { - return nil, err - } - _, err = s.Store.Collection(tokensCol).InsertOne(ctx, m) + _, err = s.Store.Collection(tokensCol).InsertOne(ctx, token) if err != nil { return nil, err } @@ -158,15 +154,11 @@ func (s *Store) BlacklistTokens(ctx context.Context, owner string, req *pb.Black Flag: true, Timestamp: time.Now().Unix(), } - value, err := blacklisted.ToBsonMap() - if err != nil { - return nil, err - } update := bson.D{ { Key: mongodb.Set, Value: bson.M{ - pb.BlackListedKey: value, + pb.BlackListedKey: &blacklisted, }, }, } diff --git a/m2m-oauth-server/store/mongodb/tokens_test.go b/m2m-oauth-server/store/mongodb/tokens_test.go index c5e2f3990..f8ecb88c5 100644 --- a/m2m-oauth-server/store/mongodb/tokens_test.go +++ b/m2m-oauth-server/store/mongodb/tokens_test.go @@ -18,16 +18,19 @@ func TestGetTokens(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), config.TEST_TIMEOUT) defer cancel() + expiration := time.Now().Add(time.Minute * 10).Unix() + // Set the owner and request parameters owner := "testOwner" tokens := []*pb.Token{ { - Id: "token1", - Owner: owner, - Version: 0, - Name: "name1", - IssuedAt: time.Now().Unix(), - ClientId: "client1", + Id: "token1", + Owner: owner, + Version: 0, + Name: "name1", + IssuedAt: time.Now().Unix(), + ClientId: "client1", + Expiration: expiration, }, { Id: "token2", @@ -50,9 +53,9 @@ func TestGetTokens(t *testing.T) { } tests := []struct { - name string - args args - wantLen int + name string + args args + want []*pb.Token }{ { name: "all tokens", @@ -61,7 +64,9 @@ func TestGetTokens(t *testing.T) { owner: owner, req: &pb.GetTokensRequest{}, }, - wantLen: 1, + want: []*pb.Token{ + tokens[0], + }, }, { name: "all tokens including blacklisted", @@ -72,7 +77,7 @@ func TestGetTokens(t *testing.T) { IncludeBlacklisted: true, }, }, - wantLen: 2, + want: tokens, }, { name: "certain token", @@ -84,7 +89,9 @@ func TestGetTokens(t *testing.T) { IncludeBlacklisted: true, }, }, - wantLen: 1, + want: []*pb.Token{ + tokens[1], + }, }, { name: "all tokens another owner", @@ -93,7 +100,7 @@ func TestGetTokens(t *testing.T) { owner: "anotherOwner", req: &pb.GetTokensRequest{}, }, - wantLen: 0, + want: nil, }, } @@ -114,7 +121,18 @@ func TestGetTokens(t *testing.T) { // Call the GetTokens method err := s.GetTokens(tt.args.ctx, tt.args.owner, tt.args.req, process) require.NoError(t, err) - require.Len(t, result, tt.wantLen) + require.Len(t, result, len(tt.want)) + for _, token := range tt.want { + require.Contains(t, result, token.GetId()) + require.Equal(t, token.GetExpiration(), result[token.GetId()].GetExpiration()) + require.Equal(t, token.GetIssuedAt(), result[token.GetId()].GetIssuedAt()) + require.Equal(t, token.GetClientId(), result[token.GetId()].GetClientId()) + require.Equal(t, token.GetOwner(), result[token.GetId()].GetOwner()) + require.Equal(t, token.GetVersion(), result[token.GetId()].GetVersion()) + require.Equal(t, token.GetName(), result[token.GetId()].GetName()) + require.Equal(t, token.GetBlacklisted().GetFlag(), result[token.GetId()].GetBlacklisted().GetFlag()) + require.Equal(t, token.GetBlacklisted().GetTimestamp(), result[token.GetId()].GetBlacklisted().GetTimestamp()) + } }) } } diff --git a/m2m-oauth-server/store/store.go b/m2m-oauth-server/store/store.go index e88a01e5f..60015a071 100644 --- a/m2m-oauth-server/store/store.go +++ b/m2m-oauth-server/store/store.go @@ -27,10 +27,6 @@ var ( ErrPartialDelete = errors.New("some errors occurred while deleting") ) -type BsonMapper interface { - FromBsonMap(m map[string]interface{}) error -} - type MongoIterator[T any] struct { Cursor *mongo.Cursor } @@ -39,15 +35,6 @@ func (i *MongoIterator[T]) Next(ctx context.Context, s *T) bool { if !i.Cursor.Next(ctx) { return false } - var tmp interface{} = s - if tmp, ok := tmp.(BsonMapper); ok { - var mapValue map[string]interface{} - err := i.Cursor.Decode(&mapValue) - if err == nil { - err = tmp.FromBsonMap(mapValue) - } - return err == nil - } err := i.Cursor.Decode(s) return err == nil } diff --git a/m2m-oauth-server/test/test.go b/m2m-oauth-server/test/test.go index a412c8c6e..2cc2b08c1 100644 --- a/m2m-oauth-server/test/test.go +++ b/m2m-oauth-server/test/test.go @@ -142,6 +142,7 @@ type AccessTokenOptions struct { Audience string JWT string PostForm bool + Expiration time.Time Ctx context.Context } @@ -193,6 +194,12 @@ func WithPostFrom(enabled bool) func(opts *AccessTokenOptions) { } } +func WithExpiration(expiration time.Time) func(opts *AccessTokenOptions) { + return func(opts *AccessTokenOptions) { + opts.Expiration = expiration + } +} + func WithContext(ctx context.Context) func(opts *AccessTokenOptions) { return func(opts *AccessTokenOptions) { opts.Ctx = ctx @@ -229,6 +236,9 @@ func GetAccessToken(t *testing.T, expectedCode int, opts ...func(opts *AccessTok reqBody[uri.ClientAssertionKey] = options.JWT reqBody[uri.ClientAssertionTypeKey] = uri.ClientAssertionTypeJWT } + if !options.Expiration.IsZero() { + reqBody[uri.ExpirationKey] = options.Expiration.Unix() + } var data []byte if options.PostForm { data = []byte(mapToURLValues(reqBody).Encode()) diff --git a/pkg/mongodb/marshal.go b/pkg/mongodb/marshal.go index 54dbd8acb..8ae33d21f 100644 --- a/pkg/mongodb/marshal.go +++ b/pkg/mongodb/marshal.go @@ -2,68 +2,230 @@ package mongodb import ( "encoding/json" + "errors" + "fmt" + "regexp" "strconv" "strings" + "github.com/hashicorp/go-multierror" "go.mongodb.org/mongo-driver/bson" "google.golang.org/protobuf/encoding/protojson" "google.golang.org/protobuf/proto" ) -type updateJSON = func(map[string]interface{}) +type updateJSON = func(map[string]any) error -func ConvertStringValueToInt64(json map[string]interface{}, path string) { - pos := strings.Index(path, ".") - if pos == -1 { - valueI, ok := json[path] - if !ok { - return - } - valueStr, ok := valueI.(string) - if !ok { - return +var ErrPathNotFound = errors.New("path not found") + +// ConvertStringValueToInt64 converts string values to int64 in a JSON map based on provided paths. +// It iterates over the specified paths in the JSON map and converts the string values found at those paths to int64 values. +// If permitMissingPaths is set to true, missing paths in the JSON map will be ignored and the modified JSON map will be returned. +// If permitMissingPaths is set to false, an error will be returned if any of the specified paths are not found in the JSON map. +// The function returns the updated JSON map with the converted int64 values. +// If an error occurs during the conversion, the partially modified JSON map is returned along with the error. +func ConvertStringValueToInt64(jsonMap any, permitMissingPaths bool, paths ...string) (any, error) { + for _, path := range paths { + newMap, err := convertPath(jsonMap, permitMissingPaths, path) + if err != nil { + return jsonMap, err } - value, err := strconv.ParseInt(valueStr, 10, 64) + jsonMap = newMap + } + return jsonMap, nil +} + +func handleSlice(slice []any, permitMissingPaths bool, remainingParts []string) ([]any, error) { + var ( + parents []any + errs *multierror.Error + ) + + for _, item := range slice { + p, err := findParents(item, permitMissingPaths, remainingParts) if err != nil { - return + errs = multierror.Append(errs, err) + } else if p != nil { + parents = append(parents, p...) } - json[path] = value - return } - elemPath := path[:pos] - elem, ok := json[elemPath] - if !ok { - return - } - elemArray, ok := elem.([]interface{}) - if ok { - for i, elem := range elemArray { - elemMap, ok2 := elem.(map[string]interface{}) - if !ok2 { - continue + if len(parents) == 0 { + return nil, errs.ErrorOrNil() + } + + return parents, errs.ErrorOrNil() +} + +func findParents(current any, permitMissingPaths bool, parts []string) ([]any, error) { + for idx, part := range parts { + if part == "" { + continue + } + switch curr := current.(type) { + case map[string]any: + if value, exists := curr[part]; exists { + current = value + } else if permitMissingPaths { + return nil, nil + } else { + return nil, fmt.Errorf("path segment %s: %w", part, ErrPathNotFound) + } + case []any: + if part == "" || part == "*" { + return handleSlice(curr, permitMissingPaths, parts[idx+1:]) + } + index, err := strconv.Atoi(part) + if err != nil { + return nil, fmt.Errorf("invalid array index %s", part) + } + if index < 0 || index >= len(curr) { + if permitMissingPaths { + return nil, nil + } + return nil, fmt.Errorf("index out of range %d: %w", index, ErrPathNotFound) } - ConvertStringValueToInt64(elemMap, path[pos+1:]) - elemArray[i] = elemMap + current = curr[index] + default: + return nil, fmt.Errorf("unsupported type %T at path segment %s", current, part) } - json[elemPath] = elemArray - return } - elemMap, ok := elem.(map[string]interface{}) + return []any{current}, nil +} + +var splitPathRE = regexp.MustCompile(`\.\[|\]\.|\.|\[|\]`) + +func splitPath(path string) []string { + parts := splitPathRE.Split(path, -1) + var cleanParts []string + for _, part := range parts { + if part != "" { + cleanParts = append(cleanParts, part) + } + } + return cleanParts +} + +func setMap(data any, permitMissingPaths bool, path string, parent map[string]any, lastPart string) (out any, err error) { + value, exists := parent[lastPart] + if !exists { + if permitMissingPaths { + return data, nil + } + return data, fmt.Errorf("path %s: %w", path, ErrPathNotFound) + } + strVal, ok := value.(string) if !ok { - return + return data, fmt.Errorf("expected string at path %s, but found %T", path, value) + } + intVal, err := strconv.ParseInt(strVal, 10, 64) + if err != nil { + return data, fmt.Errorf("error converting string to int64 at path %s: %w", path, err) } - ConvertStringValueToInt64(elemMap, path[pos+1:]) - json[elemPath] = elemMap + parent[lastPart] = intVal + + return data, nil +} + +func setSliceValue(data any, permitMissingPaths bool, path string, parent []any, index int) (out any, err error) { + if index < 0 || index >= len(parent) { + if permitMissingPaths { + return data, nil + } + return data, fmt.Errorf("index out of range %d", index) + } + if value, ok := parent[index].(string); ok { + intVal, err := strconv.ParseInt(value, 10, 64) + if err != nil { + return data, fmt.Errorf("error converting string to int64 at path %s: %w", path, err) + } + parent[index] = intVal + } else { + return data, fmt.Errorf("expected string at path %s, but found %T", path, parent[index]) + } + return data, nil +} + +func setSlice(data any, permitMissingPaths bool, path string, parent []any, lastPart string) (out any, err error) { + if lastPart == "" || lastPart == "*" { + out = data + for i := range parent { + out, err = setSliceValue(out, permitMissingPaths, path, parent, i) + if err != nil { + return out, err + } + } + return out, err + } + index, err := strconv.Atoi(lastPart) + if err != nil { + return data, fmt.Errorf("invalid array index %s", lastPart) + } + return setSliceValue(data, permitMissingPaths, path, parent, index) +} + +func setDirectValue(data any, path string) (out any, err error) { + intVal, err := strconv.ParseInt(data.(string), 10, 64) + if err != nil { + return data, fmt.Errorf("error converting string to int64 at path %s: %w", path, err) + } + return intVal, nil +} + +func convertPath(data any, permitMissingPaths bool, path string) (out any, err error) { + var parentsRaw []any + var lastPart string + var parts []string + if path == "." { + parentsRaw = []any{data} + } else { + parts = splitPath(path) + if len(parts) == 0 { + return data, errors.New("empty path") + } + + lastPart = parts[len(parts)-1] + parentsRaw, err = findParents(data, permitMissingPaths, parts[:len(parts)-1]) + if err != nil { + return data, fmt.Errorf("error finding parent for path %s: %w", path, err) + } + } + + out = data + var errs *multierror.Error + for _, parentRaw := range parentsRaw { + switch parent := parentRaw.(type) { + case map[string]any: + out, err = setMap(out, permitMissingPaths, path, parent, lastPart) + if err != nil { + errs = multierror.Append(errs, err) + } + case []any: + out, err = setSlice(out, permitMissingPaths, path, parent, lastPart) + if err != nil { + errs = multierror.Append(errs, err) + } + case string: + out, err = setDirectValue(parent, path) + if err != nil { + errs = multierror.Append(errs, err) + } + default: + return data, fmt.Errorf("unsupported type %T at parent path %s", parent, strings.Join(parts[:len(parts)-1], ".")) + } + } + return out, errs.ErrorOrNil() } func UnmarshalProtoBSON(data []byte, m proto.Message, update updateJSON) error { - var obj map[string]interface{} + var obj map[string]any if err := bson.Unmarshal(data, &obj); err != nil { return err } if update != nil { - update(obj) + if err := update(obj); err != nil { + return err + } } jsonData, err := json.Marshal(obj) if err != nil { @@ -77,13 +239,15 @@ func MarshalProtoBSON(m proto.Message, update updateJSON) ([]byte, error) { if err != nil { return nil, err } - var obj map[string]interface{} + var obj map[string]any err = json.Unmarshal(data, &obj) if err != nil { return nil, err } if update != nil { - update(obj) + if err := update(obj); err != nil { + return nil, err + } } return bson.Marshal(obj) } diff --git a/pkg/mongodb/marshal_test.go b/pkg/mongodb/marshal_test.go new file mode 100644 index 000000000..edeb1ad87 --- /dev/null +++ b/pkg/mongodb/marshal_test.go @@ -0,0 +1,223 @@ +package mongodb_test + +import ( + "testing" + + "github.com/plgd-dev/hub/v2/pkg/mongodb" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestConvertStringValueToInt64(t *testing.T) { + type args struct { + data interface{} + paths []string + permitMissingPaths bool + } + + tests := []struct { + name string + args args + want interface{} + wantErr bool + ignoreErr bool + }{ + { + name: "emptyPath", + args: args{ + data: map[string]interface{}{}, + paths: []string{""}, + }, + wantErr: true, + }, + { + name: "invalidPath", + args: args{ + data: map[string]interface{}{}, + paths: []string{"foo"}, + }, + wantErr: true, + }, + { + name: "directValue", + args: args{ + data: "123", + paths: []string{"."}, + }, + want: int64(123), + }, + { + name: "arrayValue", + args: args{ + data: []interface{}{ + "123", + "456", + "789", + }, + paths: []string{".[0]", ".[2]"}, + }, + want: []interface{}{int64(123), "456", int64(789)}, + }, + { + name: "mapValue", + args: args{ + data: map[string]interface{}{ + "foo": "123", + }, + paths: []string{".foo"}, + }, + want: map[string]interface{}{ + "foo": int64(123), + }, + }, + { + name: "mapArrayValue", + args: args{ + data: map[string]interface{}{ + "foo": []interface{}{ + "123", + "456", + "789", + }, + }, + paths: []string{".foo[0]", ".foo[2]"}, + }, + want: map[string]interface{}{ + "foo": []interface{}{int64(123), "456", int64(789)}, + }, + }, + { + name: "nestedMapValue", + args: args{ + data: map[string]interface{}{ + "foo": map[string]interface{}{ + "bar": "123", + }, + }, + paths: []string{".foo.bar"}, + }, + want: map[string]interface{}{ + "foo": map[string]interface{}{ + "bar": int64(123), + }, + }, + }, + { + name: "nestedArrayMapValue", + args: args{ + data: map[string]interface{}{ + "foo": []interface{}{ + map[string]interface{}{ + "bar": "123", + }, + map[string]interface{}{ + "bar": "456", + }, + }, + }, + paths: []string{".foo[0].bar", ".foo[1].bar"}, + }, + want: map[string]interface{}{ + "foo": []interface{}{ + map[string]interface{}{ + "bar": int64(123), + }, + map[string]interface{}{ + "bar": int64(456), + }, + }, + }, + }, + { + name: "nestedArrayMapAllValues", + args: args{ + data: map[string]interface{}{ + "foo": []interface{}{ + map[string]interface{}{ + "bar": "123", + }, + map[string]interface{}{ + "bar": "456", + }, + }, + }, + paths: []string{".foo[*].bar"}, + }, + want: map[string]interface{}{ + "foo": []interface{}{ + map[string]interface{}{ + "bar": int64(123), + }, + map[string]interface{}{ + "bar": int64(456), + }, + }, + }, + }, + { + name: "nestedArrayMapWithMissingPathsAllValues", + args: args{ + data: map[string]interface{}{ + "foo": []interface{}{ + map[string]interface{}{ + "bar": "123", + }, + map[string]interface{}{ + "efg": "456", + }, + map[string]interface{}{ + "bar": "789", + }, + }, + }, + paths: []string{".foo[*].bar"}, + permitMissingPaths: true, + }, + want: map[string]interface{}{ + "foo": []interface{}{ + map[string]interface{}{ + "bar": int64(123), + }, + map[string]interface{}{ + "efg": "456", + }, + map[string]interface{}{ + "bar": int64(789), + }, + }, + }, + }, + { + name: "mapArrayAllValues", + args: args{ + data: map[string]interface{}{ + "foo": []interface{}{ + "123", + "456", + "789", + }, + }, + paths: []string{".foo[*]"}, + }, + want: map[string]interface{}{ + "foo": []interface{}{int64(123), int64(456), int64(789)}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := mongodb.ConvertStringValueToInt64(tt.args.data, tt.args.permitMissingPaths, tt.args.paths...) + if tt.wantErr { + require.Error(t, err) + if !tt.ignoreErr { + return + } + } + if !tt.ignoreErr { + require.NoError(t, err) + } + assert.Equal(t, tt.want, got) + }) + } +} diff --git a/snippet-service/pb/appliedConfiguration.go b/snippet-service/pb/appliedConfiguration.go index 813be74dd..02ddf05c4 100644 --- a/snippet-service/pb/appliedConfiguration.go +++ b/snippet-service/pb/appliedConfiguration.go @@ -92,8 +92,11 @@ func (r *AppliedConfiguration_Resource) Clone() *AppliedConfiguration_Resource { } } -func (r *AppliedConfiguration_Resource) jsonToBSONTag(json map[string]interface{}) { - pkgMongo.ConvertStringValueToInt64(json, "validUntil") +func (r *AppliedConfiguration_Resource) jsonToBSONTag(json map[string]interface{}) error { + if _, err := pkgMongo.ConvertStringValueToInt64(json, true, "."+ValidUntil); err != nil { + return fmt.Errorf("cannot convert .validUntil to int64: %w", err) + } + return nil } func (r *AppliedConfiguration_Resource) MarshalBSON() ([]byte, error) { @@ -133,10 +136,17 @@ func (c *AppliedConfiguration) Clone() *AppliedConfiguration { } } -func (c *AppliedConfiguration) jsonToBSONTag(json map[string]interface{}) { - pkgMongo.ConvertStringValueToInt64(json, "configurationId.version") - pkgMongo.ConvertStringValueToInt64(json, "conditionId.version") - pkgMongo.ConvertStringValueToInt64(json, "resources.validUntil") +func (c *AppliedConfiguration) jsonToBSONTag(json map[string]interface{}) error { + if _, err := pkgMongo.ConvertStringValueToInt64(json, true, "."+ConfigurationIDKey+"."+VersionKey); err != nil { + return fmt.Errorf("cannot convert configurationId.version to int64: %w", err) + } + if _, err := pkgMongo.ConvertStringValueToInt64(json, true, "."+ConditionIDKey+"."+VersionKey); err != nil { + return fmt.Errorf("cannot convert conditionId.version to int64: %w", err) + } + if _, err := pkgMongo.ConvertStringValueToInt64(json, true, "."+ResourcesKey+".[*]."+ValidUntil); err != nil { + return fmt.Errorf("cannot convert resources.validUntil to int64: %w", err) + } + return nil } func (c *AppliedConfiguration) MarshalBSON() ([]byte, error) { diff --git a/snippet-service/store/appliedConfiguration.go b/snippet-service/store/appliedConfiguration.go index 9664e97e0..a4a87b93f 100644 --- a/snippet-service/store/appliedConfiguration.go +++ b/snippet-service/store/appliedConfiguration.go @@ -44,14 +44,23 @@ func (c *AppliedConfiguration) GetAppliedConfiguration() *pb.AppliedConfiguratio } func (c *AppliedConfiguration) UnmarshalBSON(data []byte) error { - update := func(json map[string]interface{}) { - recordID, ok := json[pb.RecordIDKey] + var recordID string + update := func(json map[string]interface{}) error { + recordIDI, ok := json[pb.RecordIDKey] if ok { - c.RecordID = recordID.(primitive.ObjectID).Hex() + recordID = recordIDI.(primitive.ObjectID).Hex() } delete(json, pb.RecordIDKey) + return nil + } + err := pkgMongo.UnmarshalProtoBSON(data, &c.AppliedConfiguration, update) + if err != nil { + return err } - return pkgMongo.UnmarshalProtoBSON(data, &c.AppliedConfiguration, update) + if c.GetId() == "" && recordID != "" { + c.RecordID = recordID + } + return nil } type UpdateAppliedConfigurationResourceRequest struct {