diff --git a/internal/api/group.go b/internal/api/group.go index e48191ee15..bff0089748 100644 --- a/internal/api/group.go +++ b/internal/api/group.go @@ -19,8 +19,6 @@ import ( "github.com/openimsdk/open-im-server/v3/pkg/rpcclient" "github.com/openimsdk/protocol/group" "github.com/openimsdk/tools/a2r" - "github.com/openimsdk/tools/apiresp" - "github.com/openimsdk/tools/log" ) type GroupApi rpcclient.Group @@ -148,45 +146,7 @@ func (o *GroupApi) GetIncrementalGroupMember(c *gin.Context) { } func (o *GroupApi) GetIncrementalGroupMemberBatch(c *gin.Context) { - type BatchIncrementalReq struct { - UserID string `json:"user_id"` - List []*group.GetIncrementalGroupMemberReq `json:"list"` - } - type BatchIncrementalResp struct { - List map[string]*group.GetIncrementalGroupMemberResp `json:"list"` - } - req, err := a2r.ParseRequestNotCheck[BatchIncrementalReq](c) - if err != nil { - apiresp.GinError(c, err) - return - } - resp := &BatchIncrementalResp{ - List: make(map[string]*group.GetIncrementalGroupMemberResp), - } - var ( - changeCount int - ) - for _, req := range req.List { - if _, ok := resp.List[req.GroupID]; ok { - continue - } - res, err := o.Client.GetIncrementalGroupMember(c, req) - if err != nil { - if len(resp.List) == 0 { - apiresp.GinError(c, err) - } else { - log.ZError(c, "group incr sync versopn", err, "groupID", req.GroupID, "success", len(resp.List)) - apiresp.GinSuccess(c, resp) - } - return - } - resp.List[req.GroupID] = res - changeCount += len(res.Insert) + len(res.Delete) + len(res.Update) - if changeCount >= 200 { - break - } - } - apiresp.GinSuccess(c, resp) + a2r.Call(group.GroupClient.BatchGetIncrementalGroupMember, o.Client, c) } func (o *GroupApi) GetFullGroupMemberUserIDs(c *gin.Context) { diff --git a/internal/rpc/conversation/sync.go b/internal/rpc/conversation/sync.go index 29c11c4a10..ad88b2bbd5 100644 --- a/internal/rpc/conversation/sync.go +++ b/internal/rpc/conversation/sync.go @@ -2,6 +2,7 @@ package conversation import ( "context" + "github.com/openimsdk/open-im-server/v3/internal/rpc/incrversion" "github.com/openimsdk/open-im-server/v3/pkg/common/storage/model" "github.com/openimsdk/open-im-server/v3/pkg/util/hashutil" @@ -40,7 +41,6 @@ func (c *conversationServer) GetIncrementalConversation(ctx context.Context, req Find: func(ctx context.Context, conversationIDs []string) ([]*conversation.Conversation, error) { return c.getConversations(ctx, req.UserID, conversationIDs) }, - ID: func(elem *conversation.Conversation) string { return elem.GroupID }, Resp: func(version *model.VersionLog, delIDs []string, insertList, updateList []*conversation.Conversation, full bool) *conversation.GetIncrementalConversationResp { return &conversation.GetIncrementalConversationResp{ VersionID: version.ID.Hex(), diff --git a/internal/rpc/friend/sync.go b/internal/rpc/friend/sync.go index eee9f2afdc..145c287da0 100644 --- a/internal/rpc/friend/sync.go +++ b/internal/rpc/friend/sync.go @@ -78,7 +78,6 @@ func (s *friendServer) GetIncrementalFriends(ctx context.Context, req *relation. Find: func(ctx context.Context, ids []string) ([]*sdkws.FriendInfo, error) { return s.getFriend(ctx, req.UserID, ids) }, - ID: func(elem *sdkws.FriendInfo) string { return elem.FriendUser.UserID }, Resp: func(version *model.VersionLog, deleteIds []string, insertList, updateList []*sdkws.FriendInfo, full bool) *relation.GetIncrementalFriendsResp { return &relation.GetIncrementalFriendsResp{ VersionID: version.ID.Hex(), diff --git a/internal/rpc/group/sync.go b/internal/rpc/group/sync.go index f89a98ee84..0592aa811c 100644 --- a/internal/rpc/group/sync.go +++ b/internal/rpc/group/sync.go @@ -2,6 +2,7 @@ package group import ( "context" + "github.com/openimsdk/open-im-server/v3/internal/rpc/incrversion" "github.com/openimsdk/open-im-server/v3/pkg/authverify" "github.com/openimsdk/open-im-server/v3/pkg/common/servererrs" @@ -10,13 +11,10 @@ import ( "github.com/openimsdk/protocol/constant" pbgroup "github.com/openimsdk/protocol/group" "github.com/openimsdk/protocol/sdkws" + "github.com/openimsdk/tools/errs" + "github.com/openimsdk/tools/log" ) -func (s *groupServer) BatchGetIncrementalGroupMember(ctx context.Context, req *pbgroup.BatchGetIncrementalGroupMemberReq) (*pbgroup.BatchGetIncrementalGroupMemberResp, error) { - //TODO implement me - panic("implement me") -} - func (s *groupServer) GetFullGroupMemberUserIDs(ctx context.Context, req *pbgroup.GetFullGroupMemberUserIDsReq) (*pbgroup.GetFullGroupMemberUserIDsResp, error) { vl, err := s.db.FindMaxGroupMemberVersionCache(ctx, req.GroupID) if err != nil { @@ -104,7 +102,6 @@ func (s *groupServer) GetIncrementalGroupMember(ctx context.Context, req *pbgrou Find: func(ctx context.Context, ids []string) ([]*sdkws.GroupMemberFullInfo, error) { return s.getGroupMembersInfo(ctx, req.GroupID, ids) }, - ID: func(elem *sdkws.GroupMemberFullInfo) string { return elem.UserID }, Resp: func(version *model.VersionLog, delIDs []string, insertList, updateList []*sdkws.GroupMemberFullInfo, full bool) *pbgroup.GetIncrementalGroupMemberResp { return &pbgroup.GetIncrementalGroupMemberResp{ VersionID: version.ID.Hex(), @@ -135,6 +132,150 @@ func (s *groupServer) GetIncrementalGroupMember(ctx context.Context, req *pbgrou return resp, nil } +func (s *groupServer) BatchGetIncrementalGroupMember(ctx context.Context, req *pbgroup.BatchGetIncrementalGroupMemberReq) (resp *pbgroup.BatchGetIncrementalGroupMemberResp, err error) { + type VersionInfo struct { + GroupID string + VersionID string + VersionNumber uint64 + } + + var groupIDs []string + + groupsVersionMap := make(map[string]*VersionInfo) + groupsMap := make(map[string]*model.Group) + hasGroupUpdateMap := make(map[string]bool) + sortVersionMap := make(map[string]uint64) + + var targetKeys, versionIDs []string + var versionNumbers []uint64 + + var requestBodyLen int + + for _, group := range req.ReqList { + groupsVersionMap[group.GroupID] = &VersionInfo{ + GroupID: group.GroupID, + VersionID: group.VersionID, + VersionNumber: group.Version, + } + + groupIDs = append(groupIDs, group.GroupID) + } + + groups, err := s.db.FindGroup(ctx, groupIDs) + if err != nil { + return nil, errs.Wrap(err) + } + + for _, group := range groups { + if group.Status == constant.GroupStatusDismissed { + err = servererrs.ErrDismissedAlready.Wrap() + log.ZError(ctx, "This group is Dismissed Already", err, "group is", group.GroupID) + + delete(groupsVersionMap, group.GroupID) + } else { + groupsMap[group.GroupID] = group + } + } + + for groupID, vInfo := range groupsVersionMap { + targetKeys = append(targetKeys, groupID) + versionIDs = append(versionIDs, vInfo.VersionID) + versionNumbers = append(versionNumbers, vInfo.VersionNumber) + } + + opt := incrversion.BatchOption[[]*sdkws.GroupMemberFullInfo, pbgroup.BatchGetIncrementalGroupMemberResp]{ + Ctx: ctx, + TargetKeys: targetKeys, + VersionIDs: versionIDs, + VersionNumbers: versionNumbers, + Versions: func(ctx context.Context, groupIDs []string, versions []uint64, limits []int) (map[string]*model.VersionLog, error) { + vLogs, err := s.db.BatchFindMemberIncrVersion(ctx, groupIDs, versions, limits) + if err != nil { + return nil, errs.Wrap(err) + } + + for groupID, vlog := range vLogs { + vlogElems := make([]model.VersionLogElem, 0, len(vlog.Logs)) + for i, log := range vlog.Logs { + switch log.EID { + case model.VersionGroupChangeID: + vlog.LogLen-- + hasGroupUpdateMap[groupID] = true + case model.VersionSortChangeID: + vlog.LogLen-- + sortVersionMap[groupID] = uint64(log.Version) + default: + vlogElems = append(vlogElems, vlog.Logs[i]) + } + } + vlog.Logs = vlogElems + if vlog.LogLen > 0 { + hasGroupUpdateMap[groupID] = true + } + } + + return vLogs, nil + }, + CacheMaxVersions: s.db.BatchFindMaxGroupMemberVersionCache, + Find: func(ctx context.Context, groupID string, ids []string) ([]*sdkws.GroupMemberFullInfo, error) { + memberInfo, err := s.getGroupMembersInfo(ctx, groupID, ids) + if err != nil { + return nil, err + } + + return memberInfo, err + }, + Resp: func(versions map[string]*model.VersionLog, deleteIdsMap map[string][]string, insertListMap, updateListMap map[string][]*sdkws.GroupMemberFullInfo, fullMap map[string]bool) *pbgroup.BatchGetIncrementalGroupMemberResp { + resList := make(map[string]*pbgroup.GetIncrementalGroupMemberResp) + + for groupID, versionLog := range versions { + resList[groupID] = &pbgroup.GetIncrementalGroupMemberResp{ + VersionID: versionLog.ID.Hex(), + Version: uint64(versionLog.Version), + Full: fullMap[groupID], + Delete: deleteIdsMap[groupID], + Insert: insertListMap[groupID], + Update: updateListMap[groupID], + SortVersion: sortVersionMap[groupID], + } + + requestBodyLen += len(insertListMap[groupID]) + len(updateListMap[groupID]) + len(deleteIdsMap[groupID]) + if requestBodyLen > 200 { + break + } + } + + return &pbgroup.BatchGetIncrementalGroupMemberResp{ + RespList: resList, + } + }, + } + + resp, err = opt.Build() + if err != nil { + return nil, errs.Wrap(err) + } + + for groupID, val := range resp.RespList { + if val.Full || hasGroupUpdateMap[groupID] { + count, err := s.db.FindGroupMemberNum(ctx, groupID) + if err != nil { + return nil, err + } + + owner, err := s.db.TakeGroupOwner(ctx, groupID) + if err != nil { + return nil, err + } + + resp.RespList[groupID].Group = s.groupDB2PB(groupsMap[groupID], owner.UserID, count) + } + } + + return resp, nil + +} + func (s *groupServer) GetIncrementalJoinGroup(ctx context.Context, req *pbgroup.GetIncrementalJoinGroupReq) (*pbgroup.GetIncrementalJoinGroupResp, error) { if err := authverify.CheckAccessV3(ctx, req.UserID, s.config.Share.IMAdminUserID); err != nil { return nil, err @@ -147,7 +288,6 @@ func (s *groupServer) GetIncrementalJoinGroup(ctx context.Context, req *pbgroup. Version: s.db.FindJoinIncrVersion, CacheMaxVersion: s.db.FindMaxJoinGroupVersionCache, Find: s.getGroupsInfo, - ID: func(elem *sdkws.GroupInfo) string { return elem.GroupID }, Resp: func(version *model.VersionLog, delIDs []string, insertList, updateList []*sdkws.GroupInfo, full bool) *pbgroup.GetIncrementalJoinGroupResp { return &pbgroup.GetIncrementalJoinGroupResp{ VersionID: version.ID.Hex(), diff --git a/internal/rpc/incrversion/batch_option.go b/internal/rpc/incrversion/batch_option.go new file mode 100644 index 0000000000..34d1b25066 --- /dev/null +++ b/internal/rpc/incrversion/batch_option.go @@ -0,0 +1,207 @@ +package incrversion + +import ( + "context" + "fmt" + + "github.com/openimsdk/open-im-server/v3/pkg/common/storage/model" + "github.com/openimsdk/tools/errs" + "go.mongodb.org/mongo-driver/bson/primitive" +) + +type BatchOption[A, B any] struct { + Ctx context.Context + TargetKeys []string + VersionIDs []string + VersionNumbers []uint64 + //SyncLimit int + Versions func(ctx context.Context, dIds []string, versions []uint64, limits []int) (map[string]*model.VersionLog, error) + CacheMaxVersions func(ctx context.Context, dIds []string) (map[string]*model.VersionLog, error) + Find func(ctx context.Context, dId string, ids []string) (A, error) + Resp func(versionsMap map[string]*model.VersionLog, deleteIdsMap map[string][]string, insertListMap, updateListMap map[string]A, fullMap map[string]bool) *B +} + +func (o *BatchOption[A, B]) newError(msg string) error { + return errs.ErrInternalServer.WrapMsg(msg) +} + +func (o *BatchOption[A, B]) check() error { + if o.Ctx == nil { + return o.newError("opt ctx is nil") + } + if len(o.TargetKeys) == 0 { + return o.newError("targetKeys is empty") + } + if o.Versions == nil { + return o.newError("func versions is nil") + } + if o.Find == nil { + return o.newError("func find is nil") + } + if o.Resp == nil { + return o.newError("func resp is nil") + } + return nil +} + +func (o *BatchOption[A, B]) validVersions() []bool { + valids := make([]bool, len(o.VersionIDs)) + for i, versionID := range o.VersionIDs { + objID, err := primitive.ObjectIDFromHex(versionID) + valids[i] = (err == nil && (!objID.IsZero()) && o.VersionNumbers[i] > 0) + } + return valids +} + +func (o *BatchOption[A, B]) equalIDs(objIDs []primitive.ObjectID) []bool { + equals := make([]bool, len(o.VersionIDs)) + for i, versionID := range o.VersionIDs { + equals[i] = versionID == objIDs[i].Hex() + } + return equals +} + +func (o *BatchOption[A, B]) getVersions(tags *[]int) (versions map[string]*model.VersionLog, err error) { + var dIDs []string + var versionNums []uint64 + var limits []int + + valids := o.validVersions() + + if o.CacheMaxVersions == nil { + for i, valid := range valids { + if valid { + (*tags)[i] = tagQuery + dIDs = append(dIDs, o.TargetKeys[i]) + versionNums = append(versionNums, o.VersionNumbers[i]) + limits = append(limits, syncLimit) + } else { + (*tags)[i] = tagFull + dIDs = append(dIDs, o.TargetKeys[i]) + versionNums = append(versionNums, 0) + limits = append(limits, 0) + } + } + + versions, err = o.Versions(o.Ctx, dIDs, versionNums, limits) + if err != nil { + return nil, errs.Wrap(err) + } + return versions, nil + + } else { + caches, err := o.CacheMaxVersions(o.Ctx, o.TargetKeys) + if err != nil { + return nil, errs.Wrap(err) + } + + objIDs := make([]primitive.ObjectID, len(o.VersionIDs)) + + for i, versionID := range o.VersionIDs { + objID, _ := primitive.ObjectIDFromHex(versionID) + objIDs[i] = objID + } + + equals := o.equalIDs(objIDs) + for i, valid := range valids { + if !valid { + (*tags)[i] = tagFull + } else if !equals[i] { + (*tags)[i] = tagFull + } else if o.VersionNumbers[i] == uint64(caches[o.TargetKeys[i]].Version) { + (*tags)[i] = tagEqual + } else { + (*tags)[i] = tagQuery + dIDs = append(dIDs, o.TargetKeys[i]) + versionNums = append(versionNums, o.VersionNumbers[i]) + limits = append(limits, syncLimit) + + delete(caches, o.TargetKeys[i]) + } + } + + if dIDs != nil { + versionMap, err := o.Versions(o.Ctx, dIDs, versionNums, limits) + if err != nil { + return nil, errs.Wrap(err) + } + + for k, v := range versionMap { + caches[k] = v + } + } + + versions = caches + } + return versions, nil +} + +func (o *BatchOption[A, B]) Build() (*B, error) { + if err := o.check(); err != nil { + return nil, errs.Wrap(err) + } + + tags := make([]int, len(o.TargetKeys)) + versions, err := o.getVersions(&tags) + if err != nil { + return nil, errs.Wrap(err) + } + + fullMap := make(map[string]bool) + for i, tag := range tags { + switch tag { + case tagQuery: + vLog := versions[o.TargetKeys[i]] + fullMap[o.TargetKeys[i]] = vLog.ID.Hex() != o.VersionIDs[i] || uint64(vLog.Version) < o.VersionNumbers[i] || len(vLog.Logs) != vLog.LogLen + case tagFull: + fullMap[o.TargetKeys[i]] = true + case tagEqual: + fullMap[o.TargetKeys[i]] = false + default: + panic(fmt.Errorf("undefined tag %d", tag)) + } + } + + var ( + insertIdsMap = make(map[string][]string) + deleteIdsMap = make(map[string][]string) + updateIdsMap = make(map[string][]string) + ) + + for _, targetKey := range o.TargetKeys { + if !fullMap[targetKey] { + version := versions[targetKey] + insertIds, deleteIds, updateIds := version.DeleteAndChangeIDs() + insertIdsMap[targetKey] = insertIds + deleteIdsMap[targetKey] = deleteIds + updateIdsMap[targetKey] = updateIds + } + } + + var ( + insertListMap = make(map[string]A) + updateListMap = make(map[string]A) + ) + + for targetKey, insertIds := range insertIdsMap { + if len(insertIds) > 0 { + insertList, err := o.Find(o.Ctx, targetKey, insertIds) + if err != nil { + return nil, errs.Wrap(err) + } + insertListMap[targetKey] = insertList + } + } + + for targetKey, updateIds := range updateIdsMap { + if len(updateIds) > 0 { + updateList, err := o.Find(o.Ctx, targetKey, updateIds) + if err != nil { + return nil, errs.Wrap(err) + } + updateListMap[targetKey] = updateList + } + } + + return o.Resp(versions, deleteIdsMap, insertListMap, updateListMap, fullMap), nil +} diff --git a/internal/rpc/incrversion/option.go b/internal/rpc/incrversion/option.go index f7a71244a0..af1200d5c0 100644 --- a/internal/rpc/incrversion/option.go +++ b/internal/rpc/incrversion/option.go @@ -3,6 +3,7 @@ package incrversion import ( "context" "fmt" + "github.com/openimsdk/open-im-server/v3/pkg/common/storage/model" "github.com/openimsdk/tools/errs" "go.mongodb.org/mongo-driver/bson/primitive" @@ -20,7 +21,7 @@ const syncLimit = 200 const ( tagQuery = iota + 1 tagFull - tageEqual + tagEqual ) type Option[A, B any] struct { @@ -33,7 +34,6 @@ type Option[A, B any] struct { Version func(ctx context.Context, dId string, version uint, limit int) (*model.VersionLog, error) //SortID func(ctx context.Context, dId string) ([]string, error) Find func(ctx context.Context, ids []string) ([]A, error) - ID func(elem A) string Resp func(version *model.VersionLog, deleteIds []string, insertList, updateList []A, full bool) *B } @@ -60,9 +60,6 @@ func (o *Option[A, B]) check() error { if o.Find == nil { return o.newError("func find is nil") } - if o.ID == nil { - return o.newError("func id is nil") - } if o.Resp == nil { return o.newError("func resp is nil") } @@ -100,7 +97,7 @@ func (o *Option[A, B]) getVersion(tag *int) (*model.VersionLog, error) { return cache, nil } if o.VersionNumber == uint64(cache.Version) { - *tag = tageEqual + *tag = tagEqual return cache, nil } *tag = tagQuery @@ -123,7 +120,7 @@ func (o *Option[A, B]) Build() (*B, error) { full = version.ID.Hex() != o.VersionID || uint64(version.Version) < o.VersionNumber || len(version.Logs) != version.LogLen case tagFull: full = true - case tageEqual: + case tagEqual: full = false default: panic(fmt.Errorf("undefined tag %d", tag)) diff --git a/pkg/common/storage/cache/group.go b/pkg/common/storage/cache/group.go index 91953d9f9a..1ec0462956 100644 --- a/pkg/common/storage/cache/group.go +++ b/pkg/common/storage/cache/group.go @@ -64,5 +64,6 @@ type GroupCache interface { DelMaxGroupMemberVersion(groupIDs ...string) GroupCache DelMaxJoinGroupVersion(userIDs ...string) GroupCache FindMaxGroupMemberVersion(ctx context.Context, groupID string) (*model.VersionLog, error) + BatchFindMaxGroupMemberVersion(ctx context.Context, groupIDs []string) ([]*model.VersionLog, error) FindMaxJoinGroupVersion(ctx context.Context, userID string) (*model.VersionLog, error) } diff --git a/pkg/common/storage/cache/redis/group.go b/pkg/common/storage/cache/redis/group.go index d327c218f3..736111df31 100644 --- a/pkg/common/storage/cache/redis/group.go +++ b/pkg/common/storage/cache/redis/group.go @@ -17,6 +17,8 @@ package redis import ( "context" "fmt" + "time" + "github.com/dtm-labs/rockscache" "github.com/openimsdk/open-im-server/v3/pkg/common/config" "github.com/openimsdk/open-im-server/v3/pkg/common/storage/cache" @@ -28,7 +30,6 @@ import ( "github.com/openimsdk/tools/errs" "github.com/openimsdk/tools/log" "github.com/redis/go-redis/v9" - "time" ) const ( @@ -390,6 +391,21 @@ func (g *GroupCacheRedis) FindMaxGroupMemberVersion(ctx context.Context, groupID }) } +func (g *GroupCacheRedis) BatchFindMaxGroupMemberVersion(ctx context.Context, groupIDs []string) ([]*model.VersionLog, error) { + return batchGetCache2(ctx, g.rcClient, g.expireTime, groupIDs, + func(groupID string) string { + return g.getGroupMemberMaxVersionKey(groupID) + }, func(versionLog *model.VersionLog) string { + return versionLog.DID + }, func(ctx context.Context, groupIDs []string) ([]*model.VersionLog, error) { + // create two slices with len is groupIDs, just need 0 + versions := make([]uint, len(groupIDs)) + limits := make([]int, len(groupIDs)) + + return g.groupMemberDB.BatchFindMemberIncrVersion(ctx, groupIDs, versions, limits) + }) +} + func (g *GroupCacheRedis) FindMaxJoinGroupVersion(ctx context.Context, userID string) (*model.VersionLog, error) { return getCache(ctx, g.rcClient, g.getJoinGroupMaxVersionKey(userID), g.expireTime, func(ctx context.Context) (*model.VersionLog, error) { return g.groupMemberDB.FindJoinIncrVersion(ctx, userID, 0, 0) diff --git a/pkg/common/storage/controller/group.go b/pkg/common/storage/controller/group.go index 3a5f48d4c6..072429ed09 100644 --- a/pkg/common/storage/controller/group.go +++ b/pkg/common/storage/controller/group.go @@ -16,17 +16,19 @@ package controller import ( "context" + "time" + "github.com/openimsdk/open-im-server/v3/pkg/common/config" redis2 "github.com/openimsdk/open-im-server/v3/pkg/common/storage/cache/redis" "github.com/openimsdk/open-im-server/v3/pkg/common/storage/common" "github.com/openimsdk/open-im-server/v3/pkg/common/storage/database" "github.com/openimsdk/open-im-server/v3/pkg/common/storage/model" - "time" "github.com/openimsdk/open-im-server/v3/pkg/common/storage/cache" "github.com/openimsdk/protocol/constant" "github.com/openimsdk/tools/db/pagination" "github.com/openimsdk/tools/db/tx" + "github.com/openimsdk/tools/errs" "github.com/openimsdk/tools/utils/datautil" "github.com/redis/go-redis/v9" ) @@ -108,6 +110,7 @@ type GroupDatabase interface { DeleteGroupMemberHash(ctx context.Context, groupIDs []string) error FindMemberIncrVersion(ctx context.Context, groupID string, version uint, limit int) (*model.VersionLog, error) + BatchFindMemberIncrVersion(ctx context.Context, groupIDs []string, versions []uint64, limits []int) (map[string]*model.VersionLog, error) FindJoinIncrVersion(ctx context.Context, userID string, version uint, limit int) (*model.VersionLog, error) MemberGroupIncrVersion(ctx context.Context, groupID string, userIDs []string, state int32) error @@ -115,6 +118,7 @@ type GroupDatabase interface { //FindSortJoinGroupIDs(ctx context.Context, userID string) ([]string, error) FindMaxGroupMemberVersionCache(ctx context.Context, groupID string) (*model.VersionLog, error) + BatchFindMaxGroupMemberVersionCache(ctx context.Context, groupIDs []string) (map[string]*model.VersionLog, error) FindMaxJoinGroupVersionCache(ctx context.Context, userID string) (*model.VersionLog, error) SearchJoinGroup(ctx context.Context, userID string, keyword string, pagination pagination.Pagination) (int64, []*model.Group, error) @@ -498,6 +502,29 @@ func (g *groupDatabase) FindMemberIncrVersion(ctx context.Context, groupID strin return g.groupMemberDB.FindMemberIncrVersion(ctx, groupID, version, limit) } +func (g *groupDatabase) BatchFindMemberIncrVersion(ctx context.Context, groupIDs []string, versions []uint64, limits []int) (map[string]*model.VersionLog, error) { + if len(groupIDs) == 0 { + return nil, errs.Wrap(errs.New("groupIDs is nil.")) + } + + // convert []uint64 to []uint + var uintVersions []uint + for _, version := range versions { + uintVersions = append(uintVersions, uint(version)) + } + + versionLogs, err := g.groupMemberDB.BatchFindMemberIncrVersion(ctx, groupIDs, uintVersions, limits) + if err != nil { + return nil, errs.Wrap(err) + } + + groupMemberIncrVersionsMap := datautil.SliceToMap(versionLogs, func(e *model.VersionLog) string { + return e.DID + }) + + return groupMemberIncrVersionsMap, nil +} + func (g *groupDatabase) FindJoinIncrVersion(ctx context.Context, userID string, version uint, limit int) (*model.VersionLog, error) { return g.groupMemberDB.FindJoinIncrVersion(ctx, userID, version, limit) } @@ -506,6 +533,20 @@ func (g *groupDatabase) FindMaxGroupMemberVersionCache(ctx context.Context, grou return g.cache.FindMaxGroupMemberVersion(ctx, groupID) } +func (g *groupDatabase) BatchFindMaxGroupMemberVersionCache(ctx context.Context, groupIDs []string) (map[string]*model.VersionLog, error) { + if len(groupIDs) == 0 { + return nil, errs.Wrap(errs.New("groupIDs is nil in Cache.")) + } + versionLogs, err := g.cache.BatchFindMaxGroupMemberVersion(ctx, groupIDs) + if err != nil { + return nil, errs.Wrap(err) + } + maxGroupMemberVersionsMap := datautil.SliceToMap(versionLogs, func(e *model.VersionLog) string { + return e.DID + }) + return maxGroupMemberVersionsMap, nil +} + func (g *groupDatabase) FindMaxJoinGroupVersionCache(ctx context.Context, userID string) (*model.VersionLog, error) { return g.cache.FindMaxJoinGroupVersion(ctx, userID) } diff --git a/pkg/common/storage/database/group_member.go b/pkg/common/storage/database/group_member.go index 43a7e6095e..0ddf0654c0 100644 --- a/pkg/common/storage/database/group_member.go +++ b/pkg/common/storage/database/group_member.go @@ -16,6 +16,7 @@ package database import ( "context" + "github.com/openimsdk/open-im-server/v3/pkg/common/storage/model" "github.com/openimsdk/tools/db/pagination" ) @@ -40,5 +41,6 @@ type GroupMember interface { JoinGroupIncrVersion(ctx context.Context, userID string, groupIDs []string, state int32) error MemberGroupIncrVersion(ctx context.Context, groupID string, userIDs []string, state int32) error FindMemberIncrVersion(ctx context.Context, groupID string, version uint, limit int) (*model.VersionLog, error) + BatchFindMemberIncrVersion(ctx context.Context, groupIDs []string, versions []uint, limits []int) ([]*model.VersionLog, error) FindJoinIncrVersion(ctx context.Context, userID string, version uint, limit int) (*model.VersionLog, error) } diff --git a/pkg/common/storage/database/mgo/group_member.go b/pkg/common/storage/database/mgo/group_member.go index 42b3dd72bd..2fdf2003b5 100644 --- a/pkg/common/storage/database/mgo/group_member.go +++ b/pkg/common/storage/database/mgo/group_member.go @@ -16,6 +16,7 @@ package mgo import ( "context" + "github.com/openimsdk/open-im-server/v3/pkg/common/storage/database" "github.com/openimsdk/open-im-server/v3/pkg/common/storage/model" "github.com/openimsdk/tools/log" @@ -230,6 +231,11 @@ func (g *GroupMemberMgo) FindMemberIncrVersion(ctx context.Context, groupID stri return g.member.FindChangeLog(ctx, groupID, version, limit) } +func (g *GroupMemberMgo) BatchFindMemberIncrVersion(ctx context.Context, groupIDs []string, versions []uint, limits []int) ([]*model.VersionLog, error) { + log.ZDebug(ctx, "Batch find member incr version", "groupIDs", groupIDs, "versions", versions) + return g.member.BatchFindChangeLog(ctx, groupIDs, versions, limits) +} + func (g *GroupMemberMgo) FindJoinIncrVersion(ctx context.Context, userID string, version uint, limit int) (*model.VersionLog, error) { log.ZDebug(ctx, "find join incr version", "userID", userID, "version", version) return g.join.FindChangeLog(ctx, userID, version, limit) diff --git a/pkg/common/storage/database/mgo/version_log.go b/pkg/common/storage/database/mgo/version_log.go index 3b449007bc..2c4bdef4e8 100644 --- a/pkg/common/storage/database/mgo/version_log.go +++ b/pkg/common/storage/database/mgo/version_log.go @@ -3,6 +3,8 @@ package mgo import ( "context" "errors" + "time" + "github.com/openimsdk/open-im-server/v3/pkg/common/storage/database" "github.com/openimsdk/open-im-server/v3/pkg/common/storage/model" "github.com/openimsdk/open-im-server/v3/pkg/common/storage/versionctx" @@ -13,7 +15,6 @@ import ( "go.mongodb.org/mongo-driver/bson/primitive" "go.mongodb.org/mongo-driver/mongo" "go.mongodb.org/mongo-driver/mongo/options" - "time" ) func NewVersionLog(coll *mongo.Collection) (database.VersionLog, error) { @@ -35,6 +36,7 @@ func (l *VersionLogMgo) initIndex(ctx context.Context) error { }, Options: options.Index().SetUnique(true), }) + return err } @@ -198,6 +200,26 @@ func (l *VersionLogMgo) FindChangeLog(ctx context.Context, dId string, version u } } +func (l *VersionLogMgo) BatchFindChangeLog(ctx context.Context, dIds []string, versions []uint, limits []int) (vLogs []*model.VersionLog, err error) { + for i := 0; i < len(dIds); i++ { + if vLog, err := l.findChangeLog(ctx, dIds[i], versions[i], limits[i]); err == nil { + vLogs = append(vLogs, vLog) + } else if !errors.Is(err, mongo.ErrNoDocuments) { + log.ZError(ctx, "findChangeLog error:", errs.Wrap(err)) + } + log.ZDebug(ctx, "init doc", "dId", dIds[i]) + if res, err := l.initDoc(ctx, dIds[i], nil, 0, time.Now()); err == nil { + log.ZDebug(ctx, "init doc success", "dId", dIds[i]) + vLogs = append(vLogs, res) + } else if mongo.IsDuplicateKeyError(err) { + l.findChangeLog(ctx, dIds[i], versions[i], limits[i]) + } else { + log.ZError(ctx, "init doc error:", errs.Wrap(err)) + } + } + return vLogs, errs.Wrap(err) +} + func (l *VersionLogMgo) findChangeLog(ctx context.Context, dId string, version uint, limit int) (*model.VersionLog, error) { if version == 0 && limit == 0 { return l.findDoc(ctx, dId) diff --git a/pkg/common/storage/database/version_log.go b/pkg/common/storage/database/version_log.go index 9d7bcc1724..28224a7c79 100644 --- a/pkg/common/storage/database/version_log.go +++ b/pkg/common/storage/database/version_log.go @@ -2,8 +2,9 @@ package database import ( "context" - "github.com/openimsdk/open-im-server/v3/pkg/common/storage/model" "time" + + "github.com/openimsdk/open-im-server/v3/pkg/common/storage/model" ) const ( @@ -14,6 +15,7 @@ const ( type VersionLog interface { IncrVersion(ctx context.Context, dId string, eIds []string, state int32) error FindChangeLog(ctx context.Context, dId string, version uint, limit int) (*model.VersionLog, error) + BatchFindChangeLog(ctx context.Context, dIds []string, versions []uint, limits []int) ([]*model.VersionLog, error) DeleteAfterUnchangedLog(ctx context.Context, deadline time.Time) error Delete(ctx context.Context, dId string) error }