diff --git a/admin/api.go b/admin/api.go index 66ec5dae5..aec490231 100644 --- a/admin/api.go +++ b/admin/api.go @@ -20,64 +20,39 @@ package admin import ( "context" + apisecurity "github.com/polarismesh/specification/source/go/api/v1/security" apiservice "github.com/polarismesh/specification/source/go/api/v1/service_manage" - connlimit "github.com/polarismesh/polaris/common/conn/limit" "github.com/polarismesh/polaris/common/model" + "github.com/polarismesh/polaris/common/model/admin" ) -type ConnReq struct { - Protocol string - Host string - Port int - Amount int -} - -type ConnCountResp struct { - Protocol string - Total int32 - Host map[string]int32 -} - -type ConnStatsResp struct { - Protocol string - ActiveConnTotal int32 - StatsTotal int - StatsSize int - Stats []*connlimit.HostConnStat -} - -type ScopeLevel struct { - Name string - Level string -} - // AdminOperateServer Maintain related operation type AdminOperateServer interface { // GetServerConnections Get connection count - GetServerConnections(ctx context.Context, req *ConnReq) (*ConnCountResp, error) + GetServerConnections(ctx context.Context, req *admin.ConnReq) (*admin.ConnCountResp, error) // GetServerConnStats 获取连接缓存里面的统计信息 - GetServerConnStats(ctx context.Context, req *ConnReq) (*ConnStatsResp, error) + GetServerConnStats(ctx context.Context, req *admin.ConnReq) (*admin.ConnStatsResp, error) // CloseConnections Close connection by ip - CloseConnections(ctx context.Context, reqs []ConnReq) error + CloseConnections(ctx context.Context, reqs []admin.ConnReq) error // FreeOSMemory Free system memory FreeOSMemory(ctx context.Context) error // CleanInstance Clean deleted instance CleanInstance(ctx context.Context, req *apiservice.Instance) *apiservice.Response - // BatchCleanInstances Batch clean deleted instances BatchCleanInstances(ctx context.Context, batchSize uint32) (uint32, error) // GetLastHeartbeat Get last heartbeat GetLastHeartbeat(ctx context.Context, req *apiservice.Instance) *apiservice.Response - // GetLogOutputLevel Get log output level - GetLogOutputLevel(ctx context.Context) ([]ScopeLevel, error) + GetLogOutputLevel(ctx context.Context) ([]admin.ScopeLevel, error) // SetLogOutputLevel Set log output level by scope SetLogOutputLevel(ctx context.Context, scope string, level string) error // ListLeaderElections - ListLeaderElections(ctx context.Context) ([]*model.LeaderElection, error) + ListLeaderElections(ctx context.Context) ([]*admin.LeaderElection, error) // ReleaseLeaderElection ReleaseLeaderElection(ctx context.Context, electKey string) error // GetCMDBInfo get cmdb info GetCMDBInfo(ctx context.Context) ([]model.LocationView, error) + // InitMainUser + InitMainUser(ctx context.Context, user apisecurity.User) error } diff --git a/admin/maintain.go b/admin/maintain.go index bc86c6c3a..208eded34 100644 --- a/admin/maintain.go +++ b/admin/maintain.go @@ -24,6 +24,7 @@ import ( "time" apimodel "github.com/polarismesh/specification/source/go/api/v1/model" + apisecurity "github.com/polarismesh/specification/source/go/api/v1/security" apiservice "github.com/polarismesh/specification/source/go/api/v1/service_manage" "go.uber.org/zap" @@ -31,12 +32,21 @@ import ( connlimit "github.com/polarismesh/polaris/common/conn/limit" commonlog "github.com/polarismesh/polaris/common/log" "github.com/polarismesh/polaris/common/model" + "github.com/polarismesh/polaris/common/model/admin" commonstore "github.com/polarismesh/polaris/common/store" "github.com/polarismesh/polaris/common/utils" "github.com/polarismesh/polaris/plugin" ) -func (s *Server) GetServerConnections(_ context.Context, req *ConnReq) (*ConnCountResp, error) { +func (s *Server) HasMainUser(ctx context.Context, user apisecurity.User) (bool, error) { + return false, nil +} + +func (s *Server) InitMainUser(ctx context.Context, user apisecurity.User) error { + return nil +} + +func (s *Server) GetServerConnections(_ context.Context, req *admin.ConnReq) (*admin.ConnCountResp, error) { if req.Protocol == "" { return nil, errors.New("missing param protocol") } @@ -46,7 +56,7 @@ func (s *Server) GetServerConnections(_ context.Context, req *ConnReq) (*ConnCou return nil, errors.New("not found the protocol") } - var resp = ConnCountResp{ + var resp = admin.ConnCountResp{ Protocol: req.Protocol, Total: lis.GetListenerConnCount(), Host: map[string]int32{}, @@ -63,7 +73,7 @@ func (s *Server) GetServerConnections(_ context.Context, req *ConnReq) (*ConnCou return &resp, nil } -func (s *Server) GetServerConnStats(_ context.Context, req *ConnReq) (*ConnStatsResp, error) { +func (s *Server) GetServerConnStats(_ context.Context, req *admin.ConnReq) (*admin.ConnStatsResp, error) { if req.Protocol == "" { return nil, errors.New("missing param protocol") } @@ -73,7 +83,7 @@ func (s *Server) GetServerConnStats(_ context.Context, req *ConnReq) (*ConnStats return nil, errors.New("not found the protocol") } - var resp ConnStatsResp + var resp admin.ConnStatsResp resp.Protocol = req.Protocol resp.ActiveConnTotal = lis.GetListenerConnCount() @@ -100,7 +110,7 @@ func (s *Server) GetServerConnStats(_ context.Context, req *ConnReq) (*ConnStats return &resp, nil } -func (s *Server) CloseConnections(_ context.Context, reqs []ConnReq) error { +func (s *Server) CloseConnections(_ context.Context, reqs []admin.ConnReq) error { for _, entry := range reqs { listener := connlimit.GetLimitListener(entry.Protocol) if listener == nil { @@ -172,11 +182,11 @@ func (s *Server) GetLastHeartbeat(_ context.Context, req *apiservice.Instance) * return s.healthCheckServer.GetLastHeartbeat(req) } -func (s *Server) GetLogOutputLevel(_ context.Context) ([]ScopeLevel, error) { +func (s *Server) GetLogOutputLevel(_ context.Context) ([]admin.ScopeLevel, error) { scopes := commonlog.Scopes() - out := make([]ScopeLevel, 0, len(scopes)) + out := make([]admin.ScopeLevel, 0, len(scopes)) for k := range scopes { - out = append(out, ScopeLevel{ + out = append(out, admin.ScopeLevel{ Name: k, Level: scopes[k].GetOutputLevel().Name(), }) @@ -189,7 +199,7 @@ func (s *Server) SetLogOutputLevel(_ context.Context, scope string, level string return commonlog.SetLogOutputLevel(scope, level) } -func (s *Server) ListLeaderElections(_ context.Context) ([]*model.LeaderElection, error) { +func (s *Server) ListLeaderElections(_ context.Context) ([]*admin.LeaderElection, error) { return s.storage.ListLeaderElections() } diff --git a/admin/maintain_authability.go b/admin/maintain_authability.go index a8ef51e8f..d90e14a56 100644 --- a/admin/maintain_authability.go +++ b/admin/maintain_authability.go @@ -20,19 +20,29 @@ package admin import ( "context" + apisecurity "github.com/polarismesh/specification/source/go/api/v1/security" apiservice "github.com/polarismesh/specification/source/go/api/v1/service_manage" api "github.com/polarismesh/polaris/common/api/v1" "github.com/polarismesh/polaris/common/model" + "github.com/polarismesh/polaris/common/model/admin" + authcommon "github.com/polarismesh/polaris/common/model/auth" "github.com/polarismesh/polaris/common/utils" ) var _ AdminOperateServer = (*serverAuthAbility)(nil) -func (svr *serverAuthAbility) GetServerConnections(ctx context.Context, req *ConnReq) (*ConnCountResp, error) { - authCtx := svr.collectMaintainAuthContext(ctx, model.Read, "GetServerConnections") - _, err := svr.strategyMgn.GetAuthChecker().CheckConsolePermission(authCtx) - if err != nil { +func (s *serverAuthAbility) HasMainUser(ctx context.Context) (bool, error) { + return false, nil +} + +func (s *serverAuthAbility) InitMainUser(ctx context.Context, user apisecurity.User) error { + return nil +} + +func (svr *serverAuthAbility) GetServerConnections(ctx context.Context, req *admin.ConnReq) (*admin.ConnCountResp, error) { + authCtx := svr.collectMaintainAuthContext(ctx, authcommon.Read, authcommon.DescribeServerConnections) + if _, err := svr.strategyMgn.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { return nil, err } @@ -42,10 +52,9 @@ func (svr *serverAuthAbility) GetServerConnections(ctx context.Context, req *Con return svr.targetServer.GetServerConnections(ctx, req) } -func (svr *serverAuthAbility) GetServerConnStats(ctx context.Context, req *ConnReq) (*ConnStatsResp, error) { - authCtx := svr.collectMaintainAuthContext(ctx, model.Read, "GetServerConnStats") - _, err := svr.strategyMgn.GetAuthChecker().CheckConsolePermission(authCtx) - if err != nil { +func (svr *serverAuthAbility) GetServerConnStats(ctx context.Context, req *admin.ConnReq) (*admin.ConnStatsResp, error) { + authCtx := svr.collectMaintainAuthContext(ctx, authcommon.Read, authcommon.DescribeServerConnStats) + if _, err := svr.strategyMgn.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { return nil, err } @@ -55,10 +64,9 @@ func (svr *serverAuthAbility) GetServerConnStats(ctx context.Context, req *ConnR return svr.targetServer.GetServerConnStats(ctx, req) } -func (svr *serverAuthAbility) CloseConnections(ctx context.Context, reqs []ConnReq) error { - authCtx := svr.collectMaintainAuthContext(ctx, model.Delete, "CloseConnections") - _, err := svr.strategyMgn.GetAuthChecker().CheckConsolePermission(authCtx) - if err != nil { +func (svr *serverAuthAbility) CloseConnections(ctx context.Context, reqs []admin.ConnReq) error { + authCtx := svr.collectMaintainAuthContext(ctx, authcommon.Delete, authcommon.CloseConnections) + if _, err := svr.strategyMgn.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { return err } @@ -69,9 +77,8 @@ func (svr *serverAuthAbility) CloseConnections(ctx context.Context, reqs []ConnR } func (svr *serverAuthAbility) FreeOSMemory(ctx context.Context) error { - authCtx := svr.collectMaintainAuthContext(ctx, model.Modify, "FreeOSMemory") - _, err := svr.strategyMgn.GetAuthChecker().CheckConsolePermission(authCtx) - if err != nil { + authCtx := svr.collectMaintainAuthContext(ctx, authcommon.Modify, authcommon.FreeOSMemory) + if _, err := svr.strategyMgn.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { return err } @@ -82,9 +89,8 @@ func (svr *serverAuthAbility) FreeOSMemory(ctx context.Context) error { } func (svr *serverAuthAbility) CleanInstance(ctx context.Context, req *apiservice.Instance) *apiservice.Response { - authCtx := svr.collectMaintainAuthContext(ctx, model.Delete, "CleanInstance") - _, err := svr.strategyMgn.GetAuthChecker().CheckConsolePermission(authCtx) - if err != nil { + authCtx := svr.collectMaintainAuthContext(ctx, authcommon.Delete, authcommon.CleanInstance) + if _, err := svr.strategyMgn.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { return api.NewResponseWithMsg(convertToErrCode(err), err.Error()) } @@ -95,9 +101,8 @@ func (svr *serverAuthAbility) CleanInstance(ctx context.Context, req *apiservice } func (svr *serverAuthAbility) BatchCleanInstances(ctx context.Context, batchSize uint32) (uint32, error) { - authCtx := svr.collectMaintainAuthContext(ctx, model.Delete, "BatchCleanInstances") - _, err := svr.strategyMgn.GetAuthChecker().CheckConsolePermission(authCtx) - if err != nil { + authCtx := svr.collectMaintainAuthContext(ctx, authcommon.Delete, authcommon.BatchCleanInstances) + if _, err := svr.strategyMgn.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { return 0, err } @@ -105,9 +110,8 @@ func (svr *serverAuthAbility) BatchCleanInstances(ctx context.Context, batchSize } func (svr *serverAuthAbility) GetLastHeartbeat(ctx context.Context, req *apiservice.Instance) *apiservice.Response { - authCtx := svr.collectMaintainAuthContext(ctx, model.Read, "GetLastHeartbeat") - _, err := svr.strategyMgn.GetAuthChecker().CheckConsolePermission(authCtx) - if err != nil { + authCtx := svr.collectMaintainAuthContext(ctx, authcommon.Read, authcommon.DescribeInstanceLastHeartbeat) + if _, err := svr.strategyMgn.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { return api.NewResponseWithMsg(convertToErrCode(err), err.Error()) } @@ -117,10 +121,9 @@ func (svr *serverAuthAbility) GetLastHeartbeat(ctx context.Context, req *apiserv return svr.targetServer.GetLastHeartbeat(ctx, req) } -func (svr *serverAuthAbility) GetLogOutputLevel(ctx context.Context) ([]ScopeLevel, error) { - authCtx := svr.collectMaintainAuthContext(ctx, model.Read, "GetLogOutputLevel") - _, err := svr.strategyMgn.GetAuthChecker().CheckConsolePermission(authCtx) - if err != nil { +func (svr *serverAuthAbility) GetLogOutputLevel(ctx context.Context) ([]admin.ScopeLevel, error) { + authCtx := svr.collectMaintainAuthContext(ctx, authcommon.Read, authcommon.DescribeGetLogOutputLevel) + if _, err := svr.strategyMgn.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { return nil, err } @@ -131,19 +134,17 @@ func (svr *serverAuthAbility) GetLogOutputLevel(ctx context.Context) ([]ScopeLev } func (svr *serverAuthAbility) SetLogOutputLevel(ctx context.Context, scope string, level string) error { - authCtx := svr.collectMaintainAuthContext(ctx, model.Modify, "SetLogOutputLevel") - _, err := svr.strategyMgn.GetAuthChecker().CheckConsolePermission(authCtx) - if err != nil { + authCtx := svr.collectMaintainAuthContext(ctx, authcommon.Modify, authcommon.UpdateLogOutputLevel) + if _, err := svr.strategyMgn.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { return err } return svr.targetServer.SetLogOutputLevel(ctx, scope, level) } -func (svr *serverAuthAbility) ListLeaderElections(ctx context.Context) ([]*model.LeaderElection, error) { - authCtx := svr.collectMaintainAuthContext(ctx, model.Read, "ListLeaderElections") - _, err := svr.strategyMgn.GetAuthChecker().CheckConsolePermission(authCtx) - if err != nil { +func (svr *serverAuthAbility) ListLeaderElections(ctx context.Context) ([]*admin.LeaderElection, error) { + authCtx := svr.collectMaintainAuthContext(ctx, authcommon.Read, authcommon.DescribeLeaderElections) + if _, err := svr.strategyMgn.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { return nil, err } @@ -154,9 +155,8 @@ func (svr *serverAuthAbility) ListLeaderElections(ctx context.Context) ([]*model } func (svr *serverAuthAbility) ReleaseLeaderElection(ctx context.Context, electKey string) error { - authCtx := svr.collectMaintainAuthContext(ctx, model.Modify, "ReleaseLeaderElection") - _, err := svr.strategyMgn.GetAuthChecker().CheckConsolePermission(authCtx) - if err != nil { + authCtx := svr.collectMaintainAuthContext(ctx, authcommon.Modify, authcommon.ReleaseLeaderElection) + if _, err := svr.strategyMgn.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { return err } @@ -167,9 +167,8 @@ func (svr *serverAuthAbility) ReleaseLeaderElection(ctx context.Context, electKe } func (svr *serverAuthAbility) GetCMDBInfo(ctx context.Context) ([]model.LocationView, error) { - authCtx := svr.collectMaintainAuthContext(ctx, model.Read, "GetCMDBInfo") - _, err := svr.strategyMgn.GetAuthChecker().CheckConsolePermission(authCtx) - if err != nil { + authCtx := svr.collectMaintainAuthContext(ctx, authcommon.Read, authcommon.DescribeCMDBInfo) + if _, err := svr.strategyMgn.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { return nil, err } diff --git a/admin/server_authability.go b/admin/server_authability.go index 87783161c..2ddfbcbde 100644 --- a/admin/server_authability.go +++ b/admin/server_authability.go @@ -24,7 +24,7 @@ import ( apimodel "github.com/polarismesh/specification/source/go/api/v1/model" "github.com/polarismesh/polaris/auth" - "github.com/polarismesh/polaris/common/model" + authcommon "github.com/polarismesh/polaris/common/model/auth" ) // serverAuthAbility 带有鉴权能力的 maintainServer @@ -45,22 +45,22 @@ func newServerAuthAbility(targetServer *Server, return proxy } -func (svr *serverAuthAbility) collectMaintainAuthContext(ctx context.Context, resourceOp model.ResourceOperation, - methodName string) *model.AcquireContext { - return model.NewAcquireContext( - model.WithRequestContext(ctx), - model.WithOperation(resourceOp), - model.WithModule(model.MaintainModule), - model.WithMethod(methodName), +func (svr *serverAuthAbility) collectMaintainAuthContext(ctx context.Context, resourceOp authcommon.ResourceOperation, + methodName authcommon.ServerFunctionName) *authcommon.AcquireContext { + return authcommon.NewAcquireContext( + authcommon.WithRequestContext(ctx), + authcommon.WithOperation(resourceOp), + authcommon.WithModule(authcommon.MaintainModule), + authcommon.WithMethod(methodName), ) } func convertToErrCode(err error) apimodel.Code { - if errors.Is(err, model.ErrorTokenNotExist) { + if errors.Is(err, authcommon.ErrorTokenNotExist) { return apimodel.Code_TokenNotExisted } - if errors.Is(err, model.ErrorTokenDisabled) { + if errors.Is(err, authcommon.ErrorTokenDisabled) { return apimodel.Code_TokenDisabled } diff --git a/apiserver/grpcserver/base.go b/apiserver/grpcserver/base.go index d42c6205b..f37e19cfa 100644 --- a/apiserver/grpcserver/base.go +++ b/apiserver/grpcserver/base.go @@ -37,7 +37,7 @@ import ( connlimit "github.com/polarismesh/polaris/common/conn/limit" commonlog "github.com/polarismesh/polaris/common/log" "github.com/polarismesh/polaris/common/metrics" - "github.com/polarismesh/polaris/common/model" + authcommon "github.com/polarismesh/polaris/common/model/auth" "github.com/polarismesh/polaris/common/secure" "github.com/polarismesh/polaris/common/utils" "github.com/polarismesh/polaris/plugin" @@ -58,7 +58,7 @@ type BaseGrpcServer struct { protocol string - bz model.BzModule + bz authcommon.BzModule server *grpc.Server statis plugin.Statis @@ -436,24 +436,24 @@ func (b *BaseGrpcServer) AllowAccess(method string) bool { } type connCounterHook struct { - bz model.BzModule + bz authcommon.BzModule } func (h *connCounterHook) OnAccept(conn net.Conn) { - if h.bz == model.DiscoverModule { + if h.bz == authcommon.DiscoverModule { metrics.AddDiscoveryClientConn() } - if h.bz == model.ConfigModule { + if h.bz == authcommon.ConfigModule { metrics.AddConfigurationClientConn() } metrics.AddSDKClientConn() } func (h *connCounterHook) OnRelease(conn net.Conn) { - if h.bz == model.DiscoverModule { + if h.bz == authcommon.DiscoverModule { metrics.RemoveDiscoveryClientConn() } - if h.bz == model.ConfigModule { + if h.bz == authcommon.ConfigModule { metrics.RemoveConfigurationClientConn() } metrics.RemoveSDKClientConn() diff --git a/apiserver/grpcserver/config/server.go b/apiserver/grpcserver/config/server.go index b26304a55..fb53325e2 100644 --- a/apiserver/grpcserver/config/server.go +++ b/apiserver/grpcserver/config/server.go @@ -28,7 +28,7 @@ import ( "github.com/polarismesh/polaris/apiserver/grpcserver" "github.com/polarismesh/polaris/apiserver/grpcserver/utils" commonlog "github.com/polarismesh/polaris/common/log" - "github.com/polarismesh/polaris/common/model" + authcommon "github.com/polarismesh/polaris/common/model/auth" "github.com/polarismesh/polaris/config" ) @@ -58,7 +58,7 @@ func (g *ConfigGRPCServer) Initialize(ctx context.Context, option map[string]int apiConf map[string]apiserver.APIConfig) error { g.openAPI = apiConf return g.BaseGrpcServer.Initialize(ctx, option, - grpcserver.WithModule(model.ConfigModule), + grpcserver.WithModule(authcommon.ConfigModule), grpcserver.WithProtocol(g.GetProtocol()), grpcserver.WithLogger(commonlog.FindScope(commonlog.APIServerLoggerName)), ) diff --git a/apiserver/grpcserver/discover/server.go b/apiserver/grpcserver/discover/server.go index 77d99f6b1..140c4bfd9 100644 --- a/apiserver/grpcserver/discover/server.go +++ b/apiserver/grpcserver/discover/server.go @@ -31,7 +31,7 @@ import ( v1 "github.com/polarismesh/polaris/apiserver/grpcserver/discover/v1" "github.com/polarismesh/polaris/apiserver/grpcserver/utils" commonlog "github.com/polarismesh/polaris/common/log" - "github.com/polarismesh/polaris/common/model" + authcommon "github.com/polarismesh/polaris/common/model/auth" "github.com/polarismesh/polaris/service" "github.com/polarismesh/polaris/service/healthcheck" ) @@ -153,7 +153,7 @@ func (g *GRPCServer) allowAccess(method string) bool { func (g *GRPCServer) buildInitOptions(option map[string]interface{}) []grpcserver.InitOption { initOptions := []grpcserver.InitOption{ - grpcserver.WithModule(model.DiscoverModule), + grpcserver.WithModule(authcommon.DiscoverModule), grpcserver.WithProtocol(g.GetProtocol()), grpcserver.WithLogger(commonlog.FindScope(commonlog.APIServerLoggerName)), grpcserver.WithMessageToCacheObject(discoverCacheConvert), diff --git a/apiserver/grpcserver/option.go b/apiserver/grpcserver/option.go index 16928fb69..9f447b78e 100644 --- a/apiserver/grpcserver/option.go +++ b/apiserver/grpcserver/option.go @@ -19,13 +19,13 @@ package grpcserver import ( commonlog "github.com/polarismesh/polaris/common/log" - "github.com/polarismesh/polaris/common/model" + authcommon "github.com/polarismesh/polaris/common/model/auth" ) type InitOption func(svr *BaseGrpcServer) // WithModule set bz module -func WithModule(bz model.BzModule) InitOption { +func WithModule(bz authcommon.BzModule) InitOption { return func(svr *BaseGrpcServer) { svr.bz = bz } diff --git a/apiserver/httpserver/admin_access.go b/apiserver/httpserver/admin_access.go index de388929d..80589b46a 100644 --- a/apiserver/httpserver/admin_access.go +++ b/apiserver/httpserver/admin_access.go @@ -27,11 +27,10 @@ import ( apimodel "github.com/polarismesh/specification/source/go/api/v1/model" apiservice "github.com/polarismesh/specification/source/go/api/v1/service_manage" - "github.com/polarismesh/polaris/admin" "github.com/polarismesh/polaris/apiserver/httpserver/docs" httpcommon "github.com/polarismesh/polaris/apiserver/httpserver/utils" api "github.com/polarismesh/polaris/common/api/v1" - "github.com/polarismesh/polaris/common/model" + "github.com/polarismesh/polaris/common/model/admin" "github.com/polarismesh/polaris/common/utils" ) @@ -262,7 +261,7 @@ func (h *HTTPServer) ListLeaderElections(req *restful.Request, rsp *restful.Resp return } if leaders == nil { - leaders = []*model.LeaderElection{} + leaders = []*admin.LeaderElection{} } _ = rsp.WriteAsJson(leaders) diff --git a/apiserver/httpserver/auth_access.go b/apiserver/httpserver/auth_access.go index 6766ce14d..15c5fc0e2 100644 --- a/apiserver/httpserver/auth_access.go +++ b/apiserver/httpserver/auth_access.go @@ -43,7 +43,7 @@ func (h *HTTPServer) GetAuthServer(ws *restful.WebService) error { ws.Route(docs.EnrichUpdateUserApiDocs(ws.PUT("/user").To(h.UpdateUser))) ws.Route(docs.EnrichUpdateUserPasswordApiDocs(ws.PUT("/user/password").To(h.UpdateUserPassword))) ws.Route(docs.EnrichGetUserTokenApiDocs(ws.GET("/user/token").To(h.GetUserToken))) - ws.Route(docs.EnrichUpdateUserTokenApiDocs(ws.PUT("/user/token/status").To(h.UpdateUserToken))) + ws.Route(docs.EnrichUpdateUserTokenApiDocs(ws.PUT("/user/token/status").To(h.EnableUserToken))) ws.Route(docs.EnrichResetUserTokenApiDocs(ws.PUT("/user/token/refresh").To(h.ResetUserToken))) // ws.Route(docs.EnrichCreateGroupApiDocs(ws.POST("/usergroup").To(h.CreateGroup))) @@ -52,7 +52,7 @@ func (h *HTTPServer) GetAuthServer(ws *restful.WebService) error { ws.Route(docs.EnrichDeleteGroupsApiDocs(ws.POST("/usergroups/delete").To(h.DeleteGroups))) ws.Route(docs.EnrichGetGroupApiDocs(ws.GET("/usergroup/detail").To(h.GetGroup))) ws.Route(docs.EnrichGetGroupTokenApiDocs(ws.GET("/usergroup/token").To(h.GetGroupToken))) - ws.Route(docs.EnrichUpdateGroupTokenApiDocs(ws.PUT("/usergroup/token/status").To(h.UpdateGroupToken))) + ws.Route(docs.EnrichUpdateGroupTokenApiDocs(ws.PUT("/usergroup/token/status").To(h.EnableGroupToken))) ws.Route(docs.EnrichResetGroupTokenApiDocs(ws.PUT("/usergroup/token/refresh").To(h.ResetGroupToken))) ws.Route(docs.EnrichCreateStrategyApiDocs(ws.POST("/auth/strategy").To(h.CreateStrategy))) @@ -213,8 +213,8 @@ func (h *HTTPServer) GetUserToken(req *restful.Request, rsp *restful.Response) { handler.WriteHeaderAndProto(h.userMgn.GetUserToken(handler.ParseHeaderContext(), user)) } -// UpdateUserToken 更改用户的token -func (h *HTTPServer) UpdateUserToken(req *restful.Request, rsp *restful.Response) { +// EnableUserToken 更改用户的token +func (h *HTTPServer) EnableUserToken(req *restful.Request, rsp *restful.Response) { handler := &httpcommon.Handler{ Request: req, Response: rsp, @@ -228,7 +228,7 @@ func (h *HTTPServer) UpdateUserToken(req *restful.Request, rsp *restful.Response return } - handler.WriteHeaderAndProto(h.userMgn.UpdateUserToken(ctx, user)) + handler.WriteHeaderAndProto(h.userMgn.EnableUserToken(ctx, user)) } // ResetUserToken 重置用户 token @@ -358,8 +358,8 @@ func (h *HTTPServer) GetGroupToken(req *restful.Request, rsp *restful.Response) handler.WriteHeaderAndProto(h.userMgn.GetGroupToken(ctx, group)) } -// UpdateGroupToken 更新用户组 token -func (h *HTTPServer) UpdateGroupToken(req *restful.Request, rsp *restful.Response) { +// EnableGroupToken 更新用户组 token +func (h *HTTPServer) EnableGroupToken(req *restful.Request, rsp *restful.Response) { handler := &httpcommon.Handler{ Request: req, Response: rsp, @@ -373,7 +373,7 @@ func (h *HTTPServer) UpdateGroupToken(req *restful.Request, rsp *restful.Respons return } - handler.WriteHeaderAndProto(h.userMgn.UpdateGroupToken(ctx, group)) + handler.WriteHeaderAndProto(h.userMgn.EnableGroupToken(ctx, group)) } // ResetGroupToken 重置用户组 token diff --git a/apiserver/httpserver/docs/admin_apidoc.go b/apiserver/httpserver/docs/admin_apidoc.go index 6c387aeae..928a663c2 100644 --- a/apiserver/httpserver/docs/admin_apidoc.go +++ b/apiserver/httpserver/docs/admin_apidoc.go @@ -23,8 +23,8 @@ import ( "github.com/polarismesh/specification/source/go/api/v1/service_manage" apiservice "github.com/polarismesh/specification/source/go/api/v1/service_manage" - "github.com/polarismesh/polaris/admin" "github.com/polarismesh/polaris/common/model" + "github.com/polarismesh/polaris/common/model/admin" ) var ( @@ -115,7 +115,7 @@ func EnrichListLeaderElectionsApiDocs(r *restful.RouteBuilder) *restful.RouteBui return r. Doc("获取选主的结果"). Metadata(restfulspec.KeyOpenAPITags, maintainApiTags). - Returns(0, "", []model.LeaderElection{}) + Returns(0, "", []admin.LeaderElection{}) } func EnrichReleaseLeaderElectionApiDocs(r *restful.RouteBuilder) *restful.RouteBuilder { diff --git a/apiserver/httpserver/utils/handler.go b/apiserver/httpserver/utils/handler.go index 3dc8dba6f..4aa2ad794 100644 --- a/apiserver/httpserver/utils/handler.go +++ b/apiserver/httpserver/utils/handler.go @@ -84,39 +84,7 @@ func (h *Handler) parseArray(createMessage func() proto.Message, jsonDecoder *js return nil, err } } - return h.postParseMessage(requestID) -} - -func (h *Handler) postParseMessage(requestID string) (context.Context, error) { - platformID := h.Request.HeaderParameter("Platform-Id") - platformToken := h.Request.HeaderParameter("Platform-Token") - token := h.Request.HeaderParameter("Polaris-Token") - authToken := h.Request.HeaderParameter(utils.HeaderAuthTokenKey) - ctx := context.Background() - ctx = context.WithValue(ctx, utils.StringContext("request-id"), requestID) - ctx = context.WithValue(ctx, utils.StringContext("platform-id"), platformID) - ctx = context.WithValue(ctx, utils.StringContext("platform-token"), platformToken) - if token != "" { - ctx = context.WithValue(ctx, utils.StringContext("polaris-token"), token) - } - if authToken != "" { - ctx = context.WithValue(ctx, utils.ContextAuthTokenKey, authToken) - } - - var operator string - addrSlice := strings.Split(h.Request.Request.RemoteAddr, ":") - if len(addrSlice) == 2 { - operator = "HTTP:" + addrSlice[0] - if platformID != "" { - operator += "(" + platformID + ")" - } - } - if staffName := h.Request.HeaderParameter("Staffname"); staffName != "" { - operator = staffName - } - ctx = context.WithValue(ctx, utils.StringContext("operator"), operator) - - return ctx, nil + return h.ParseHeaderContext(), nil } // Parse 解析请求 @@ -126,7 +94,7 @@ func (h *Handler) Parse(message proto.Message) (context.Context, error) { accesslog.Error(err.Error(), utils.ZapRequestID(requestID)) return nil, err } - return h.postParseMessage(requestID) + return h.ParseHeaderContext(), nil } // ParseHeaderContext 将http请求header中携带的用户信息提取出来 @@ -141,6 +109,7 @@ func (h *Handler) ParseHeaderContext() context.Context { ctx = context.WithValue(ctx, utils.StringContext("request-id"), requestID) ctx = context.WithValue(ctx, utils.StringContext("platform-id"), platformID) ctx = context.WithValue(ctx, utils.StringContext("platform-token"), platformToken) + ctx = context.WithValue(ctx, utils.ContextRequestHeaders, h.Request.Request.Header) ctx = context.WithValue(ctx, utils.ContextClientAddress, h.Request.Request.RemoteAddr) if token != "" { ctx = context.WithValue(ctx, utils.StringContext("polaris-token"), token) diff --git a/apiserver/nacosserver/v1/auth.go b/apiserver/nacosserver/v1/auth.go index 2090b55dd..a3cd0830d 100644 --- a/apiserver/nacosserver/v1/auth.go +++ b/apiserver/nacosserver/v1/auth.go @@ -23,7 +23,7 @@ import ( "github.com/emicklei/go-restful/v3" nacoshttp "github.com/polarismesh/polaris/apiserver/nacosserver/v1/http" - "github.com/polarismesh/polaris/common/model" + authcommon "github.com/polarismesh/polaris/common/model/auth" ) func (n *NacosV1Server) GetAuthServer() (*restful.WebService, error) { @@ -51,9 +51,9 @@ func (n *NacosV1Server) Login(req *restful.Request, rsp *restful.Response) { func (n *NacosV1Server) handleLogin(ctx context.Context, params map[string]string) (map[string]interface{}, error) { username := params["username"] token := params["password"] - authCtx := model.NewAcquireContext( - model.WithFromClient(), - model.WithRequestContext(ctx), + authCtx := authcommon.NewAcquireContext( + authcommon.WithFromClient(), + authcommon.WithRequestContext(ctx), ) if err := n.discoverOpt.UserSvr.CheckCredential(authCtx); err != nil { diff --git a/apiserver/xdsserverv3/rds.go b/apiserver/xdsserverv3/rds.go index fa33c5e05..2174d365c 100644 --- a/apiserver/xdsserverv3/rds.go +++ b/apiserver/xdsserverv3/rds.go @@ -231,8 +231,8 @@ func (rds *RDSBuilder) makeGatewayRoutes(option *resource.BuildOption) ([]*route continue } - for i := range rule.RuleRouting.Rules { - subRule := rule.RuleRouting.Rules[i] + for i := range rule.RuleRouting.RuleRouting.Rules { + subRule := rule.RuleRouting.RuleRouting.Rules[i] // 先判断 dest 的服务是否满足目标 namespace var ( matchNamespace bool diff --git a/auth/api.go b/auth/api.go index e571b6e18..56fe7fdf0 100644 --- a/auth/api.go +++ b/auth/api.go @@ -25,30 +25,44 @@ import ( apiservice "github.com/polarismesh/specification/source/go/api/v1/service_manage" cachetypes "github.com/polarismesh/polaris/cache/api" - "github.com/polarismesh/polaris/common/model" + authcommon "github.com/polarismesh/polaris/common/model/auth" "github.com/polarismesh/polaris/store" ) // AuthChecker 权限管理通用接口定义 type AuthChecker interface { // CheckClientPermission 执行检查客户端动作判断是否有权限,并且对 RequestContext 注入操作者数据 - CheckClientPermission(preCtx *model.AcquireContext) (bool, error) + CheckClientPermission(preCtx *authcommon.AcquireContext) (bool, error) // CheckConsolePermission 执行检查控制台动作判断是否有权限,并且对 RequestContext 注入操作者数据 - CheckConsolePermission(preCtx *model.AcquireContext) (bool, error) + CheckConsolePermission(preCtx *authcommon.AcquireContext) (bool, error) // IsOpenConsoleAuth 返回是否开启了操作鉴权,可以用于前端查询 IsOpenConsoleAuth() bool // IsOpenClientAuth IsOpenClientAuth() bool - // AllowResourceOperate 是否允许资源的操作 - AllowResourceOperate(ctx *model.AcquireContext, opInfo *model.ResourceOpInfo) bool + // ResourcePredicate 是否允许资源的操作 + ResourcePredicate(ctx *authcommon.AcquireContext, opInfo *authcommon.ResourceEntry) bool } // StrategyServer 策略相关操作 type StrategyServer interface { // Initialize 执行初始化动作 - Initialize(options *Config, storage store.Store, cacheMgr cachetypes.CacheManager, userSvr UserServer) error + Initialize(*Config, store.Store, cachetypes.CacheManager, UserServer) error // Name 策略管理server名称 Name() string + // PolicyOperator . + PolicyOperator + // RoleOperator . + RoleOperator + // PolicyHelper . + PolicyHelper() PolicyHelper + // GetAuthChecker 获取鉴权检查器 + GetAuthChecker() AuthChecker + // AfterResourceOperation 操作完资源的后置处理逻辑 + AfterResourceOperation(afterCtx *authcommon.AcquireContext) error +} + +// PolicyOperator 策略管理 +type PolicyOperator interface { // CreateStrategy 创建策略 CreateStrategy(ctx context.Context, strategy *apisecurity.AuthStrategy) *apiservice.Response // UpdateStrategies 批量更新策略 @@ -63,22 +77,30 @@ type StrategyServer interface { GetStrategy(ctx context.Context, strategy *apisecurity.AuthStrategy) *apiservice.Response // GetPrincipalResources 获取某个 principal 的所有可操作资源列表 GetPrincipalResources(ctx context.Context, query map[string]string) *apiservice.Response - // GetAuthChecker 获取鉴权检查器 - GetAuthChecker() AuthChecker - // AfterResourceOperation 操作完资源的后置处理逻辑 - AfterResourceOperation(afterCtx *model.AcquireContext) error +} + +// RoleOperator 角色管理 +type RoleOperator interface { + // CreateRoles 批量创建角色 + CreateRoles(ctx context.Context, reqs []*apisecurity.Role) *apiservice.BatchWriteResponse + // UpdateRoles 批量更新角色 + UpdateRoles(ctx context.Context, reqs []*apisecurity.Role) *apiservice.BatchWriteResponse + // DeleteRoles 批量删除角色 + DeleteRoles(ctx context.Context, reqs []*apisecurity.Role) *apiservice.BatchWriteResponse + // GetRoles 查询角色列表 + GetRoles(ctx context.Context, query map[string]string) *apiservice.BatchQueryResponse } // UserServer 用户数据管理 server type UserServer interface { // Initialize 初始化 - Initialize(authOpt *Config, storage store.Store, cacheMgn cachetypes.CacheManager) error + Initialize(*Config, store.Store, StrategyServer, cachetypes.CacheManager) error // Name 用户数据管理server名称 Name() string // Login 登录动作 Login(req *apisecurity.LoginRequest) *apiservice.Response // CheckCredential 检查当前操作用户凭证 - CheckCredential(authCtx *model.AcquireContext) error + CheckCredential(authCtx *authcommon.AcquireContext) error // UserOperator UserOperator // GroupOperator @@ -100,8 +122,8 @@ type UserOperator interface { GetUsers(ctx context.Context, query map[string]string) *apiservice.BatchQueryResponse // GetUserToken 获取用户的 token GetUserToken(ctx context.Context, user *apisecurity.User) *apiservice.Response - // UpdateUserToken 禁止用户的token使用 - UpdateUserToken(ctx context.Context, user *apisecurity.User) *apiservice.Response + // EnableUserToken 禁止用户的token使用 + EnableUserToken(ctx context.Context, user *apisecurity.User) *apiservice.Response // ResetUserToken 重置用户的token ResetUserToken(ctx context.Context, user *apisecurity.User) *apiservice.Response } @@ -119,8 +141,8 @@ type GroupOperator interface { GetGroup(ctx context.Context, req *apisecurity.UserGroup) *apiservice.Response // GetGroupToken 获取用户组的 token GetGroupToken(ctx context.Context, group *apisecurity.UserGroup) *apiservice.Response - // UpdateGroupToken 取消用户组的 token 使用 - UpdateGroupToken(ctx context.Context, group *apisecurity.UserGroup) *apiservice.Response + // EnableGroupToken 取消用户组的 token 使用 + EnableGroupToken(ctx context.Context, group *apisecurity.UserGroup) *apiservice.Response // ResetGroupToken 重置用户组的 token ResetGroupToken(ctx context.Context, group *apisecurity.UserGroup) *apiservice.Response } @@ -142,6 +164,14 @@ type UserHelper interface { GetGroup(ctx context.Context, req *apisecurity.UserGroup) *apisecurity.UserGroup } +// PolicyHelper . +type PolicyHelper interface { + // CreatePrincipal 创建 principal 的默认 policy 资源 + CreatePrincipal(ctx context.Context, tx store.Tx, p authcommon.Principal) error + // CleanPrincipal 清理 principal 所关联的 policy、role 资源 + CleanPrincipal(ctx context.Context, tx store.Tx, p authcommon.Principal) error +} + // OperatorInfo 根据 token 解析出来的具体额外信息 type OperatorInfo struct { // Origin 原始 token 字符串 @@ -151,7 +181,7 @@ type OperatorInfo struct { // OwnerID 当前用户/用户组对应的 owner OwnerID string // Role 如果当前是 user token 的话,该值才能有信息 - Role model.UserRoleType + Role authcommon.UserRoleType // IsUserToken 当前 token 是否是 user 的 token IsUserToken bool // Disable 标识用户 token 是否被禁用 @@ -176,7 +206,7 @@ func IsEmptyOperator(t OperatorInfo) bool { // IsSubAccount 当前 token 对应的账户类型 func IsSubAccount(t OperatorInfo) bool { - return t.Role == model.SubAccountUserRole + return t.Role == authcommon.SubAccountUserRole } func (t *OperatorInfo) String() string { diff --git a/auth/auth.go b/auth/auth.go index 8cbb006ed..2c04e8365 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -167,7 +167,7 @@ func initialize(_ context.Context, authOpt *Config, storage store.Store, return nil, nil, fmt.Errorf("no such StrategyServer plugin. name(%s)", policyMgrName) } - if err := userMgr.Initialize(authOpt, storage, cacheMgr); err != nil { + if err := userMgr.Initialize(authOpt, storage, policyMgr, cacheMgr); err != nil { log.Printf("UserServer do initialize err: %s", err.Error()) return nil, nil, err } diff --git a/auth/mock/api_mock.go b/auth/mock/api_mock.go index 94b66872f..6e2f0deae 100644 --- a/auth/mock/api_mock.go +++ b/auth/mock/api_mock.go @@ -9,194 +9,190 @@ import ( reflect "reflect" gomock "github.com/golang/mock/gomock" - security "github.com/polarismesh/specification/source/go/api/v1/security" - service_manage "github.com/polarismesh/specification/source/go/api/v1/service_manage" - auth "github.com/polarismesh/polaris/auth" - cache "github.com/polarismesh/polaris/cache" - model "github.com/polarismesh/polaris/common/model" + api "github.com/polarismesh/polaris/cache/api" + auth0 "github.com/polarismesh/polaris/common/model/auth" store "github.com/polarismesh/polaris/store" + security "github.com/polarismesh/specification/source/go/api/v1/security" + service_manage "github.com/polarismesh/specification/source/go/api/v1/service_manage" ) -// MockAuthServer is a mock of AuthServer interface. -type MockAuthServer struct { +// MockAuthChecker is a mock of AuthChecker interface. +type MockAuthChecker struct { ctrl *gomock.Controller - recorder *MockAuthServerMockRecorder + recorder *MockAuthCheckerMockRecorder } -// MockAuthServerMockRecorder is the mock recorder for MockAuthServer. -type MockAuthServerMockRecorder struct { - mock *MockAuthServer +// MockAuthCheckerMockRecorder is the mock recorder for MockAuthChecker. +type MockAuthCheckerMockRecorder struct { + mock *MockAuthChecker } -// NewMockAuthServer creates a new mock instance. -func NewMockAuthServer(ctrl *gomock.Controller) *MockAuthServer { - mock := &MockAuthServer{ctrl: ctrl} - mock.recorder = &MockAuthServerMockRecorder{mock} +// NewMockAuthChecker creates a new mock instance. +func NewMockAuthChecker(ctrl *gomock.Controller) *MockAuthChecker { + mock := &MockAuthChecker{ctrl: ctrl} + mock.recorder = &MockAuthCheckerMockRecorder{mock} return mock } // EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockAuthServer) EXPECT() *MockAuthServerMockRecorder { +func (m *MockAuthChecker) EXPECT() *MockAuthCheckerMockRecorder { return m.recorder } -// AfterResourceOperation mocks base method. -func (m *MockAuthServer) AfterResourceOperation(afterCtx *model.AcquireContext) error { +// CheckClientPermission mocks base method. +func (m *MockAuthChecker) CheckClientPermission(preCtx *auth0.AcquireContext) (bool, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "AfterResourceOperation", afterCtx) - ret0, _ := ret[0].(error) - return ret0 + ret := m.ctrl.Call(m, "CheckClientPermission", preCtx) + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(error) + return ret0, ret1 } -// AfterResourceOperation indicates an expected call of AfterResourceOperation. -func (mr *MockAuthServerMockRecorder) AfterResourceOperation(afterCtx interface{}) *gomock.Call { +// CheckClientPermission indicates an expected call of CheckClientPermission. +func (mr *MockAuthCheckerMockRecorder) CheckClientPermission(preCtx interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AfterResourceOperation", reflect.TypeOf((*MockAuthServer)(nil).AfterResourceOperation), afterCtx) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CheckClientPermission", reflect.TypeOf((*MockAuthChecker)(nil).CheckClientPermission), preCtx) } -// CreateGroup mocks base method. -func (m *MockAuthServer) CreateGroup(ctx context.Context, group *security.UserGroup) *service_manage.Response { +// CheckConsolePermission mocks base method. +func (m *MockAuthChecker) CheckConsolePermission(preCtx *auth0.AcquireContext) (bool, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "CreateGroup", ctx, group) - ret0, _ := ret[0].(*service_manage.Response) - return ret0 + ret := m.ctrl.Call(m, "CheckConsolePermission", preCtx) + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(error) + return ret0, ret1 } -// CreateGroup indicates an expected call of CreateGroup. -func (mr *MockAuthServerMockRecorder) CreateGroup(ctx, group interface{}) *gomock.Call { +// CheckConsolePermission indicates an expected call of CheckConsolePermission. +func (mr *MockAuthCheckerMockRecorder) CheckConsolePermission(preCtx interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateGroup", reflect.TypeOf((*MockAuthServer)(nil).CreateGroup), ctx, group) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CheckConsolePermission", reflect.TypeOf((*MockAuthChecker)(nil).CheckConsolePermission), preCtx) } -// CreateStrategy mocks base method. -func (m *MockAuthServer) CreateStrategy(ctx context.Context, strategy *security.AuthStrategy) *service_manage.Response { +// IsOpenClientAuth mocks base method. +func (m *MockAuthChecker) IsOpenClientAuth() bool { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "CreateStrategy", ctx, strategy) - ret0, _ := ret[0].(*service_manage.Response) + ret := m.ctrl.Call(m, "IsOpenClientAuth") + ret0, _ := ret[0].(bool) return ret0 } -// CreateStrategy indicates an expected call of CreateStrategy. -func (mr *MockAuthServerMockRecorder) CreateStrategy(ctx, strategy interface{}) *gomock.Call { +// IsOpenClientAuth indicates an expected call of IsOpenClientAuth. +func (mr *MockAuthCheckerMockRecorder) IsOpenClientAuth() *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateStrategy", reflect.TypeOf((*MockAuthServer)(nil).CreateStrategy), ctx, strategy) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsOpenClientAuth", reflect.TypeOf((*MockAuthChecker)(nil).IsOpenClientAuth)) } -// CreateUsers mocks base method. -func (m *MockAuthServer) CreateUsers(ctx context.Context, users []*security.User) *service_manage.BatchWriteResponse { +// IsOpenConsoleAuth mocks base method. +func (m *MockAuthChecker) IsOpenConsoleAuth() bool { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "CreateUsers", ctx, users) - ret0, _ := ret[0].(*service_manage.BatchWriteResponse) + ret := m.ctrl.Call(m, "IsOpenConsoleAuth") + ret0, _ := ret[0].(bool) return ret0 } -// CreateUsers indicates an expected call of CreateUsers. -func (mr *MockAuthServerMockRecorder) CreateUsers(ctx, users interface{}) *gomock.Call { +// IsOpenConsoleAuth indicates an expected call of IsOpenConsoleAuth. +func (mr *MockAuthCheckerMockRecorder) IsOpenConsoleAuth() *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateUsers", reflect.TypeOf((*MockAuthServer)(nil).CreateUsers), ctx, users) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsOpenConsoleAuth", reflect.TypeOf((*MockAuthChecker)(nil).IsOpenConsoleAuth)) } -// DeleteGroups mocks base method. -func (m *MockAuthServer) DeleteGroups(ctx context.Context, group []*security.UserGroup) *service_manage.BatchWriteResponse { +// ResourcePredicate mocks base method. +func (m *MockAuthChecker) ResourcePredicate(ctx *auth0.AcquireContext, opInfo *auth0.ResourceEntry) bool { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "DeleteGroups", ctx, group) - ret0, _ := ret[0].(*service_manage.BatchWriteResponse) + ret := m.ctrl.Call(m, "ResourcePredicate", ctx, opInfo) + ret0, _ := ret[0].(bool) return ret0 } -// DeleteGroups indicates an expected call of DeleteGroups. -func (mr *MockAuthServerMockRecorder) DeleteGroups(ctx, group interface{}) *gomock.Call { +// ResourcePredicate indicates an expected call of ResourcePredicate. +func (mr *MockAuthCheckerMockRecorder) ResourcePredicate(ctx, opInfo interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteGroups", reflect.TypeOf((*MockAuthServer)(nil).DeleteGroups), ctx, group) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ResourcePredicate", reflect.TypeOf((*MockAuthChecker)(nil).ResourcePredicate), ctx, opInfo) } -// DeleteStrategies mocks base method. -func (m *MockAuthServer) DeleteStrategies(ctx context.Context, reqs []*security.AuthStrategy) *service_manage.BatchWriteResponse { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "DeleteStrategies", ctx, reqs) - ret0, _ := ret[0].(*service_manage.BatchWriteResponse) - return ret0 +// MockStrategyServer is a mock of StrategyServer interface. +type MockStrategyServer struct { + ctrl *gomock.Controller + recorder *MockStrategyServerMockRecorder } -// DeleteStrategies indicates an expected call of DeleteStrategies. -func (mr *MockAuthServerMockRecorder) DeleteStrategies(ctx, reqs interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteStrategies", reflect.TypeOf((*MockAuthServer)(nil).DeleteStrategies), ctx, reqs) +// MockStrategyServerMockRecorder is the mock recorder for MockStrategyServer. +type MockStrategyServerMockRecorder struct { + mock *MockStrategyServer } -// DeleteUsers mocks base method. -func (m *MockAuthServer) DeleteUsers(ctx context.Context, users []*security.User) *service_manage.BatchWriteResponse { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "DeleteUsers", ctx, users) - ret0, _ := ret[0].(*service_manage.BatchWriteResponse) - return ret0 +// NewMockStrategyServer creates a new mock instance. +func NewMockStrategyServer(ctrl *gomock.Controller) *MockStrategyServer { + mock := &MockStrategyServer{ctrl: ctrl} + mock.recorder = &MockStrategyServerMockRecorder{mock} + return mock } -// DeleteUsers indicates an expected call of DeleteUsers. -func (mr *MockAuthServerMockRecorder) DeleteUsers(ctx, users interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteUsers", reflect.TypeOf((*MockAuthServer)(nil).DeleteUsers), ctx, users) +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockStrategyServer) EXPECT() *MockStrategyServerMockRecorder { + return m.recorder } -// GetAuthChecker mocks base method. -func (m *MockAuthServer) GetAuthChecker() auth.AuthChecker { +// AfterResourceOperation mocks base method. +func (m *MockStrategyServer) AfterResourceOperation(afterCtx *auth0.AcquireContext) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetAuthChecker") - ret0, _ := ret[0].(auth.AuthChecker) + ret := m.ctrl.Call(m, "AfterResourceOperation", afterCtx) + ret0, _ := ret[0].(error) return ret0 } -// GetAuthChecker indicates an expected call of GetAuthChecker. -func (mr *MockAuthServerMockRecorder) GetAuthChecker() *gomock.Call { +// AfterResourceOperation indicates an expected call of AfterResourceOperation. +func (mr *MockStrategyServerMockRecorder) AfterResourceOperation(afterCtx interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAuthChecker", reflect.TypeOf((*MockAuthServer)(nil).GetAuthChecker)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AfterResourceOperation", reflect.TypeOf((*MockStrategyServer)(nil).AfterResourceOperation), afterCtx) } -// GetGroup mocks base method. -func (m *MockAuthServer) GetGroup(ctx context.Context, req *security.UserGroup) *service_manage.Response { +// CreateStrategy mocks base method. +func (m *MockStrategyServer) CreateStrategy(ctx context.Context, strategy *security.AuthStrategy) *service_manage.Response { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetGroup", ctx, req) + ret := m.ctrl.Call(m, "CreateStrategy", ctx, strategy) ret0, _ := ret[0].(*service_manage.Response) return ret0 } -// GetGroup indicates an expected call of GetGroup. -func (mr *MockAuthServerMockRecorder) GetGroup(ctx, req interface{}) *gomock.Call { +// CreateStrategy indicates an expected call of CreateStrategy. +func (mr *MockStrategyServerMockRecorder) CreateStrategy(ctx, strategy interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGroup", reflect.TypeOf((*MockAuthServer)(nil).GetGroup), ctx, req) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateStrategy", reflect.TypeOf((*MockStrategyServer)(nil).CreateStrategy), ctx, strategy) } -// GetGroupToken mocks base method. -func (m *MockAuthServer) GetGroupToken(ctx context.Context, group *security.UserGroup) *service_manage.Response { +// DeleteStrategies mocks base method. +func (m *MockStrategyServer) DeleteStrategies(ctx context.Context, reqs []*security.AuthStrategy) *service_manage.BatchWriteResponse { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetGroupToken", ctx, group) - ret0, _ := ret[0].(*service_manage.Response) + ret := m.ctrl.Call(m, "DeleteStrategies", ctx, reqs) + ret0, _ := ret[0].(*service_manage.BatchWriteResponse) return ret0 } -// GetGroupToken indicates an expected call of GetGroupToken. -func (mr *MockAuthServerMockRecorder) GetGroupToken(ctx, group interface{}) *gomock.Call { +// DeleteStrategies indicates an expected call of DeleteStrategies. +func (mr *MockStrategyServerMockRecorder) DeleteStrategies(ctx, reqs interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGroupToken", reflect.TypeOf((*MockAuthServer)(nil).GetGroupToken), ctx, group) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteStrategies", reflect.TypeOf((*MockStrategyServer)(nil).DeleteStrategies), ctx, reqs) } -// GetGroups mocks base method. -func (m *MockAuthServer) GetGroups(ctx context.Context, query map[string]string) *service_manage.BatchQueryResponse { +// GetAuthChecker mocks base method. +func (m *MockStrategyServer) GetAuthChecker() auth.AuthChecker { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetGroups", ctx, query) - ret0, _ := ret[0].(*service_manage.BatchQueryResponse) + ret := m.ctrl.Call(m, "GetAuthChecker") + ret0, _ := ret[0].(auth.AuthChecker) return ret0 } -// GetGroups indicates an expected call of GetGroups. -func (mr *MockAuthServerMockRecorder) GetGroups(ctx, query interface{}) *gomock.Call { +// GetAuthChecker indicates an expected call of GetAuthChecker. +func (mr *MockStrategyServerMockRecorder) GetAuthChecker() *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGroups", reflect.TypeOf((*MockAuthServer)(nil).GetGroups), ctx, query) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAuthChecker", reflect.TypeOf((*MockStrategyServer)(nil).GetAuthChecker)) } // GetPrincipalResources mocks base method. -func (m *MockAuthServer) GetPrincipalResources(ctx context.Context, query map[string]string) *service_manage.Response { +func (m *MockStrategyServer) GetPrincipalResources(ctx context.Context, query map[string]string) *service_manage.Response { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetPrincipalResources", ctx, query) ret0, _ := ret[0].(*service_manage.Response) @@ -204,13 +200,13 @@ func (m *MockAuthServer) GetPrincipalResources(ctx context.Context, query map[st } // GetPrincipalResources indicates an expected call of GetPrincipalResources. -func (mr *MockAuthServerMockRecorder) GetPrincipalResources(ctx, query interface{}) *gomock.Call { +func (mr *MockStrategyServerMockRecorder) GetPrincipalResources(ctx, query interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPrincipalResources", reflect.TypeOf((*MockAuthServer)(nil).GetPrincipalResources), ctx, query) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPrincipalResources", reflect.TypeOf((*MockStrategyServer)(nil).GetPrincipalResources), ctx, query) } // GetStrategies mocks base method. -func (m *MockAuthServer) GetStrategies(ctx context.Context, query map[string]string) *service_manage.BatchQueryResponse { +func (m *MockStrategyServer) GetStrategies(ctx context.Context, query map[string]string) *service_manage.BatchQueryResponse { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetStrategies", ctx, query) ret0, _ := ret[0].(*service_manage.BatchQueryResponse) @@ -218,13 +214,13 @@ func (m *MockAuthServer) GetStrategies(ctx context.Context, query map[string]str } // GetStrategies indicates an expected call of GetStrategies. -func (mr *MockAuthServerMockRecorder) GetStrategies(ctx, query interface{}) *gomock.Call { +func (mr *MockStrategyServerMockRecorder) GetStrategies(ctx, query interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetStrategies", reflect.TypeOf((*MockAuthServer)(nil).GetStrategies), ctx, query) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetStrategies", reflect.TypeOf((*MockStrategyServer)(nil).GetStrategies), ctx, query) } // GetStrategy mocks base method. -func (m *MockAuthServer) GetStrategy(ctx context.Context, strategy *security.AuthStrategy) *service_manage.Response { +func (m *MockStrategyServer) GetStrategy(ctx context.Context, strategy *security.AuthStrategy) *service_manage.Response { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetStrategy", ctx, strategy) ret0, _ := ret[0].(*service_manage.Response) @@ -232,327 +228,395 @@ func (m *MockAuthServer) GetStrategy(ctx context.Context, strategy *security.Aut } // GetStrategy indicates an expected call of GetStrategy. -func (mr *MockAuthServerMockRecorder) GetStrategy(ctx, strategy interface{}) *gomock.Call { +func (mr *MockStrategyServerMockRecorder) GetStrategy(ctx, strategy interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetStrategy", reflect.TypeOf((*MockAuthServer)(nil).GetStrategy), ctx, strategy) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetStrategy", reflect.TypeOf((*MockStrategyServer)(nil).GetStrategy), ctx, strategy) } -// GetUserToken mocks base method. -func (m *MockAuthServer) GetUserToken(ctx context.Context, user *security.User) *service_manage.Response { +// Initialize mocks base method. +func (m *MockStrategyServer) Initialize(arg0 *auth.Config, arg1 store.Store, arg2 api.CacheManager, arg3 auth.UserServer) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetUserToken", ctx, user) - ret0, _ := ret[0].(*service_manage.Response) + ret := m.ctrl.Call(m, "Initialize", arg0, arg1, arg2, arg3) + ret0, _ := ret[0].(error) return ret0 } -// GetUserToken indicates an expected call of GetUserToken. -func (mr *MockAuthServerMockRecorder) GetUserToken(ctx, user interface{}) *gomock.Call { +// Initialize indicates an expected call of Initialize. +func (mr *MockStrategyServerMockRecorder) Initialize(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserToken", reflect.TypeOf((*MockAuthServer)(nil).GetUserToken), ctx, user) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Initialize", reflect.TypeOf((*MockStrategyServer)(nil).Initialize), arg0, arg1, arg2, arg3) } -// GetUsers mocks base method. -func (m *MockAuthServer) GetUsers(ctx context.Context, query map[string]string) *service_manage.BatchQueryResponse { +// Name mocks base method. +func (m *MockStrategyServer) Name() string { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetUsers", ctx, query) - ret0, _ := ret[0].(*service_manage.BatchQueryResponse) + ret := m.ctrl.Call(m, "Name") + ret0, _ := ret[0].(string) return ret0 } -// GetUsers indicates an expected call of GetUsers. -func (mr *MockAuthServerMockRecorder) GetUsers(ctx, query interface{}) *gomock.Call { +// Name indicates an expected call of Name. +func (mr *MockStrategyServerMockRecorder) Name() *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUsers", reflect.TypeOf((*MockAuthServer)(nil).GetUsers), ctx, query) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Name", reflect.TypeOf((*MockStrategyServer)(nil).Name)) } -// Initialize mocks base method. -func (m *MockAuthServer) Initialize(authOpt *auth.Config, storage store.Store, cacheMgn *cache.CacheManager) error { +// UpdateStrategies mocks base method. +func (m *MockStrategyServer) UpdateStrategies(ctx context.Context, reqs []*security.ModifyAuthStrategy) *service_manage.BatchWriteResponse { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpdateStrategies", ctx, reqs) + ret0, _ := ret[0].(*service_manage.BatchWriteResponse) + return ret0 +} + +// UpdateStrategies indicates an expected call of UpdateStrategies. +func (mr *MockStrategyServerMockRecorder) UpdateStrategies(ctx, reqs interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateStrategies", reflect.TypeOf((*MockStrategyServer)(nil).UpdateStrategies), ctx, reqs) +} + +// MockUserServer is a mock of UserServer interface. +type MockUserServer struct { + ctrl *gomock.Controller + recorder *MockUserServerMockRecorder +} + +// MockUserServerMockRecorder is the mock recorder for MockUserServer. +type MockUserServerMockRecorder struct { + mock *MockUserServer +} + +// NewMockUserServer creates a new mock instance. +func NewMockUserServer(ctrl *gomock.Controller) *MockUserServer { + mock := &MockUserServer{ctrl: ctrl} + mock.recorder = &MockUserServerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockUserServer) EXPECT() *MockUserServerMockRecorder { + return m.recorder +} + +// CheckCredential mocks base method. +func (m *MockUserServer) CheckCredential(authCtx *auth0.AcquireContext) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Initialize", authOpt, storage, cacheMgn) + ret := m.ctrl.Call(m, "CheckCredential", authCtx) ret0, _ := ret[0].(error) return ret0 } -// Initialize indicates an expected call of Initialize. -func (mr *MockAuthServerMockRecorder) Initialize(authOpt, storage, cacheMgn interface{}) *gomock.Call { +// CheckCredential indicates an expected call of CheckCredential. +func (mr *MockUserServerMockRecorder) CheckCredential(authCtx interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Initialize", reflect.TypeOf((*MockAuthServer)(nil).Initialize), authOpt, storage, cacheMgn) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CheckCredential", reflect.TypeOf((*MockUserServer)(nil).CheckCredential), authCtx) } -// Login mocks base method. -func (m *MockAuthServer) Login(req *security.LoginRequest) *service_manage.Response { +// CreateGroup mocks base method. +func (m *MockUserServer) CreateGroup(ctx context.Context, group *security.UserGroup) *service_manage.Response { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Login", req) + ret := m.ctrl.Call(m, "CreateGroup", ctx, group) ret0, _ := ret[0].(*service_manage.Response) return ret0 } -// Login indicates an expected call of Login. -func (mr *MockAuthServerMockRecorder) Login(req interface{}) *gomock.Call { +// CreateGroup indicates an expected call of CreateGroup. +func (mr *MockUserServerMockRecorder) CreateGroup(ctx, group interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Login", reflect.TypeOf((*MockAuthServer)(nil).Login), req) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateGroup", reflect.TypeOf((*MockUserServer)(nil).CreateGroup), ctx, group) } -// Name mocks base method. -func (m *MockAuthServer) Name() string { +// CreateUsers mocks base method. +func (m *MockUserServer) CreateUsers(ctx context.Context, users []*security.User) *service_manage.BatchWriteResponse { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Name") - ret0, _ := ret[0].(string) + ret := m.ctrl.Call(m, "CreateUsers", ctx, users) + ret0, _ := ret[0].(*service_manage.BatchWriteResponse) return ret0 } -// Name indicates an expected call of Name. -func (mr *MockAuthServerMockRecorder) Name() *gomock.Call { +// CreateUsers indicates an expected call of CreateUsers. +func (mr *MockUserServerMockRecorder) CreateUsers(ctx, users interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Name", reflect.TypeOf((*MockAuthServer)(nil).Name)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateUsers", reflect.TypeOf((*MockUserServer)(nil).CreateUsers), ctx, users) } -// ResetGroupToken mocks base method. -func (m *MockAuthServer) ResetGroupToken(ctx context.Context, group *security.UserGroup) *service_manage.Response { +// DeleteGroups mocks base method. +func (m *MockUserServer) DeleteGroups(ctx context.Context, group []*security.UserGroup) *service_manage.BatchWriteResponse { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ResetGroupToken", ctx, group) + ret := m.ctrl.Call(m, "DeleteGroups", ctx, group) + ret0, _ := ret[0].(*service_manage.BatchWriteResponse) + return ret0 +} + +// DeleteGroups indicates an expected call of DeleteGroups. +func (mr *MockUserServerMockRecorder) DeleteGroups(ctx, group interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteGroups", reflect.TypeOf((*MockUserServer)(nil).DeleteGroups), ctx, group) +} + +// DeleteUsers mocks base method. +func (m *MockUserServer) DeleteUsers(ctx context.Context, users []*security.User) *service_manage.BatchWriteResponse { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeleteUsers", ctx, users) + ret0, _ := ret[0].(*service_manage.BatchWriteResponse) + return ret0 +} + +// DeleteUsers indicates an expected call of DeleteUsers. +func (mr *MockUserServerMockRecorder) DeleteUsers(ctx, users interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteUsers", reflect.TypeOf((*MockUserServer)(nil).DeleteUsers), ctx, users) +} + +// EnableGroupToken mocks base method. +func (m *MockUserServer) EnableGroupToken(ctx context.Context, group *security.UserGroup) *service_manage.Response { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "EnableGroupToken", ctx, group) ret0, _ := ret[0].(*service_manage.Response) return ret0 } -// ResetGroupToken indicates an expected call of ResetGroupToken. -func (mr *MockAuthServerMockRecorder) ResetGroupToken(ctx, group interface{}) *gomock.Call { +// EnableGroupToken indicates an expected call of EnableGroupToken. +func (mr *MockUserServerMockRecorder) EnableGroupToken(ctx, group interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ResetGroupToken", reflect.TypeOf((*MockAuthServer)(nil).ResetGroupToken), ctx, group) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "EnableGroupToken", reflect.TypeOf((*MockUserServer)(nil).EnableGroupToken), ctx, group) } -// ResetUserToken mocks base method. -func (m *MockAuthServer) ResetUserToken(ctx context.Context, user *security.User) *service_manage.Response { +// EnableUserToken mocks base method. +func (m *MockUserServer) EnableUserToken(ctx context.Context, user *security.User) *service_manage.Response { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ResetUserToken", ctx, user) + ret := m.ctrl.Call(m, "EnableUserToken", ctx, user) ret0, _ := ret[0].(*service_manage.Response) return ret0 } -// ResetUserToken indicates an expected call of ResetUserToken. -func (mr *MockAuthServerMockRecorder) ResetUserToken(ctx, user interface{}) *gomock.Call { +// EnableUserToken indicates an expected call of EnableUserToken. +func (mr *MockUserServerMockRecorder) EnableUserToken(ctx, user interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ResetUserToken", reflect.TypeOf((*MockAuthServer)(nil).ResetUserToken), ctx, user) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "EnableUserToken", reflect.TypeOf((*MockUserServer)(nil).EnableUserToken), ctx, user) } -// UpdateGroupToken mocks base method. -func (m *MockAuthServer) UpdateGroupToken(ctx context.Context, group *security.UserGroup) *service_manage.Response { +// GetGroup mocks base method. +func (m *MockUserServer) GetGroup(ctx context.Context, req *security.UserGroup) *service_manage.Response { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "UpdateGroupToken", ctx, group) + ret := m.ctrl.Call(m, "GetGroup", ctx, req) ret0, _ := ret[0].(*service_manage.Response) return ret0 } -// UpdateGroupToken indicates an expected call of UpdateGroupToken. -func (mr *MockAuthServerMockRecorder) UpdateGroupToken(ctx, group interface{}) *gomock.Call { +// GetGroup indicates an expected call of GetGroup. +func (mr *MockUserServerMockRecorder) GetGroup(ctx, req interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateGroupToken", reflect.TypeOf((*MockAuthServer)(nil).UpdateGroupToken), ctx, group) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGroup", reflect.TypeOf((*MockUserServer)(nil).GetGroup), ctx, req) } -// UpdateGroups mocks base method. -func (m *MockAuthServer) UpdateGroups(ctx context.Context, groups []*security.ModifyUserGroup) *service_manage.BatchWriteResponse { +// GetGroupToken mocks base method. +func (m *MockUserServer) GetGroupToken(ctx context.Context, group *security.UserGroup) *service_manage.Response { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "UpdateGroups", ctx, groups) - ret0, _ := ret[0].(*service_manage.BatchWriteResponse) + ret := m.ctrl.Call(m, "GetGroupToken", ctx, group) + ret0, _ := ret[0].(*service_manage.Response) return ret0 } -// UpdateGroups indicates an expected call of UpdateGroups. -func (mr *MockAuthServerMockRecorder) UpdateGroups(ctx, groups interface{}) *gomock.Call { +// GetGroupToken indicates an expected call of GetGroupToken. +func (mr *MockUserServerMockRecorder) GetGroupToken(ctx, group interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateGroups", reflect.TypeOf((*MockAuthServer)(nil).UpdateGroups), ctx, groups) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGroupToken", reflect.TypeOf((*MockUserServer)(nil).GetGroupToken), ctx, group) } -// UpdateStrategies mocks base method. -func (m *MockAuthServer) UpdateStrategies(ctx context.Context, reqs []*security.ModifyAuthStrategy) *service_manage.BatchWriteResponse { +// GetGroups mocks base method. +func (m *MockUserServer) GetGroups(ctx context.Context, query map[string]string) *service_manage.BatchQueryResponse { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "UpdateStrategies", ctx, reqs) - ret0, _ := ret[0].(*service_manage.BatchWriteResponse) + ret := m.ctrl.Call(m, "GetGroups", ctx, query) + ret0, _ := ret[0].(*service_manage.BatchQueryResponse) return ret0 } -// UpdateStrategies indicates an expected call of UpdateStrategies. -func (mr *MockAuthServerMockRecorder) UpdateStrategies(ctx, reqs interface{}) *gomock.Call { +// GetGroups indicates an expected call of GetGroups. +func (mr *MockUserServerMockRecorder) GetGroups(ctx, query interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateStrategies", reflect.TypeOf((*MockAuthServer)(nil).UpdateStrategies), ctx, reqs) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGroups", reflect.TypeOf((*MockUserServer)(nil).GetGroups), ctx, query) } -// UpdateUser mocks base method. -func (m *MockAuthServer) UpdateUser(ctx context.Context, user *security.User) *service_manage.Response { +// GetUserHelper mocks base method. +func (m *MockUserServer) GetUserHelper() auth.UserHelper { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "UpdateUser", ctx, user) - ret0, _ := ret[0].(*service_manage.Response) + ret := m.ctrl.Call(m, "GetUserHelper") + ret0, _ := ret[0].(auth.UserHelper) return ret0 } -// UpdateUser indicates an expected call of UpdateUser. -func (mr *MockAuthServerMockRecorder) UpdateUser(ctx, user interface{}) *gomock.Call { +// GetUserHelper indicates an expected call of GetUserHelper. +func (mr *MockUserServerMockRecorder) GetUserHelper() *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateUser", reflect.TypeOf((*MockAuthServer)(nil).UpdateUser), ctx, user) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserHelper", reflect.TypeOf((*MockUserServer)(nil).GetUserHelper)) } -// UpdateUserPassword mocks base method. -func (m *MockAuthServer) UpdateUserPassword(ctx context.Context, req *security.ModifyUserPassword) *service_manage.Response { +// GetUserToken mocks base method. +func (m *MockUserServer) GetUserToken(ctx context.Context, user *security.User) *service_manage.Response { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "UpdateUserPassword", ctx, req) + ret := m.ctrl.Call(m, "GetUserToken", ctx, user) ret0, _ := ret[0].(*service_manage.Response) return ret0 } -// UpdateUserPassword indicates an expected call of UpdateUserPassword. -func (mr *MockAuthServerMockRecorder) UpdateUserPassword(ctx, req interface{}) *gomock.Call { +// GetUserToken indicates an expected call of GetUserToken. +func (mr *MockUserServerMockRecorder) GetUserToken(ctx, user interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateUserPassword", reflect.TypeOf((*MockAuthServer)(nil).UpdateUserPassword), ctx, req) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserToken", reflect.TypeOf((*MockUserServer)(nil).GetUserToken), ctx, user) } -// UpdateUserToken mocks base method. -func (m *MockAuthServer) UpdateUserToken(ctx context.Context, user *security.User) *service_manage.Response { +// GetUsers mocks base method. +func (m *MockUserServer) GetUsers(ctx context.Context, query map[string]string) *service_manage.BatchQueryResponse { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "UpdateUserToken", ctx, user) - ret0, _ := ret[0].(*service_manage.Response) + ret := m.ctrl.Call(m, "GetUsers", ctx, query) + ret0, _ := ret[0].(*service_manage.BatchQueryResponse) return ret0 } -// UpdateUserToken indicates an expected call of UpdateUserToken. -func (mr *MockAuthServerMockRecorder) UpdateUserToken(ctx, user interface{}) *gomock.Call { +// GetUsers indicates an expected call of GetUsers. +func (mr *MockUserServerMockRecorder) GetUsers(ctx, query interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateUserToken", reflect.TypeOf((*MockAuthServer)(nil).UpdateUserToken), ctx, user) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUsers", reflect.TypeOf((*MockUserServer)(nil).GetUsers), ctx, query) } -// MockAuthChecker is a mock of AuthChecker interface. -type MockAuthChecker struct { - ctrl *gomock.Controller - recorder *MockAuthCheckerMockRecorder +// Initialize mocks base method. +func (m *MockUserServer) Initialize(arg0 *auth.Config, arg1 store.Store, arg2 auth.StrategyServer, arg3 api.CacheManager) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Initialize", arg0, arg1, arg2, arg3) + ret0, _ := ret[0].(error) + return ret0 } -// MockAuthCheckerMockRecorder is the mock recorder for MockAuthChecker. -type MockAuthCheckerMockRecorder struct { - mock *MockAuthChecker +// Initialize indicates an expected call of Initialize. +func (mr *MockUserServerMockRecorder) Initialize(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Initialize", reflect.TypeOf((*MockUserServer)(nil).Initialize), arg0, arg1, arg2, arg3) } -// NewMockAuthChecker creates a new mock instance. -func NewMockAuthChecker(ctrl *gomock.Controller) *MockAuthChecker { - mock := &MockAuthChecker{ctrl: ctrl} - mock.recorder = &MockAuthCheckerMockRecorder{mock} - return mock +// Login mocks base method. +func (m *MockUserServer) Login(req *security.LoginRequest) *service_manage.Response { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Login", req) + ret0, _ := ret[0].(*service_manage.Response) + return ret0 } -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockAuthChecker) EXPECT() *MockAuthCheckerMockRecorder { - return m.recorder +// Login indicates an expected call of Login. +func (mr *MockUserServerMockRecorder) Login(req interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Login", reflect.TypeOf((*MockUserServer)(nil).Login), req) } -// CheckClientPermission mocks base method. -func (m *MockAuthChecker) CheckClientPermission(preCtx *model.AcquireContext) (bool, error) { +// Name mocks base method. +func (m *MockUserServer) Name() string { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "CheckClientPermission", preCtx) - ret0, _ := ret[0].(bool) - ret1, _ := ret[1].(error) - return ret0, ret1 + ret := m.ctrl.Call(m, "Name") + ret0, _ := ret[0].(string) + return ret0 } -// CheckClientPermission indicates an expected call of CheckClientPermission. -func (mr *MockAuthCheckerMockRecorder) CheckClientPermission(preCtx interface{}) *gomock.Call { +// Name indicates an expected call of Name. +func (mr *MockUserServerMockRecorder) Name() *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CheckClientPermission", reflect.TypeOf((*MockAuthChecker)(nil).CheckClientPermission), preCtx) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Name", reflect.TypeOf((*MockUserServer)(nil).Name)) } -// CheckConsolePermission mocks base method. -func (m *MockAuthChecker) CheckConsolePermission(preCtx *model.AcquireContext) (bool, error) { +// ResetGroupToken mocks base method. +func (m *MockUserServer) ResetGroupToken(ctx context.Context, group *security.UserGroup) *service_manage.Response { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "CheckConsolePermission", preCtx) - ret0, _ := ret[0].(bool) - ret1, _ := ret[1].(error) - return ret0, ret1 + ret := m.ctrl.Call(m, "ResetGroupToken", ctx, group) + ret0, _ := ret[0].(*service_manage.Response) + return ret0 } -// CheckConsolePermission indicates an expected call of CheckConsolePermission. -func (mr *MockAuthCheckerMockRecorder) CheckConsolePermission(preCtx interface{}) *gomock.Call { +// ResetGroupToken indicates an expected call of ResetGroupToken. +func (mr *MockUserServerMockRecorder) ResetGroupToken(ctx, group interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CheckConsolePermission", reflect.TypeOf((*MockAuthChecker)(nil).CheckConsolePermission), preCtx) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ResetGroupToken", reflect.TypeOf((*MockUserServer)(nil).ResetGroupToken), ctx, group) } -// Initialize mocks base method. -func (m *MockAuthChecker) Initialize(options *auth.Config, storage store.Store, cacheMgn *cache.CacheManager) error { +// ResetUserToken mocks base method. +func (m *MockUserServer) ResetUserToken(ctx context.Context, user *security.User) *service_manage.Response { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Initialize", options, storage, cacheMgn) - ret0, _ := ret[0].(error) + ret := m.ctrl.Call(m, "ResetUserToken", ctx, user) + ret0, _ := ret[0].(*service_manage.Response) return ret0 } -// Initialize indicates an expected call of Initialize. -func (mr *MockAuthCheckerMockRecorder) Initialize(options, storage, cacheMgn interface{}) *gomock.Call { +// ResetUserToken indicates an expected call of ResetUserToken. +func (mr *MockUserServerMockRecorder) ResetUserToken(ctx, user interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Initialize", reflect.TypeOf((*MockAuthChecker)(nil).Initialize), options, storage, cacheMgn) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ResetUserToken", reflect.TypeOf((*MockUserServer)(nil).ResetUserToken), ctx, user) } -// IsOpenClientAuth mocks base method. -func (m *MockAuthChecker) IsOpenClientAuth() bool { +// UpdateGroups mocks base method. +func (m *MockUserServer) UpdateGroups(ctx context.Context, groups []*security.ModifyUserGroup) *service_manage.BatchWriteResponse { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "IsOpenClientAuth") - ret0, _ := ret[0].(bool) + ret := m.ctrl.Call(m, "UpdateGroups", ctx, groups) + ret0, _ := ret[0].(*service_manage.BatchWriteResponse) return ret0 } -// IsOpenClientAuth indicates an expected call of IsOpenClientAuth. -func (mr *MockAuthCheckerMockRecorder) IsOpenClientAuth() *gomock.Call { +// UpdateGroups indicates an expected call of UpdateGroups. +func (mr *MockUserServerMockRecorder) UpdateGroups(ctx, groups interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsOpenClientAuth", reflect.TypeOf((*MockAuthChecker)(nil).IsOpenClientAuth)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateGroups", reflect.TypeOf((*MockUserServer)(nil).UpdateGroups), ctx, groups) } -// IsOpenConsoleAuth mocks base method. -func (m *MockAuthChecker) IsOpenConsoleAuth() bool { +// UpdateUser mocks base method. +func (m *MockUserServer) UpdateUser(ctx context.Context, user *security.User) *service_manage.Response { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "IsOpenConsoleAuth") - ret0, _ := ret[0].(bool) + ret := m.ctrl.Call(m, "UpdateUser", ctx, user) + ret0, _ := ret[0].(*service_manage.Response) return ret0 } -// IsOpenConsoleAuth indicates an expected call of IsOpenConsoleAuth. -func (mr *MockAuthCheckerMockRecorder) IsOpenConsoleAuth() *gomock.Call { +// UpdateUser indicates an expected call of UpdateUser. +func (mr *MockUserServerMockRecorder) UpdateUser(ctx, user interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsOpenConsoleAuth", reflect.TypeOf((*MockAuthChecker)(nil).IsOpenConsoleAuth)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateUser", reflect.TypeOf((*MockUserServer)(nil).UpdateUser), ctx, user) } -// VerifyCredential mocks base method. -func (m *MockAuthChecker) VerifyCredential(preCtx *model.AcquireContext) error { +// UpdateUserPassword mocks base method. +func (m *MockUserServer) UpdateUserPassword(ctx context.Context, req *security.ModifyUserPassword) *service_manage.Response { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "VerifyCredential", preCtx) - ret0, _ := ret[0].(error) + ret := m.ctrl.Call(m, "UpdateUserPassword", ctx, req) + ret0, _ := ret[0].(*service_manage.Response) return ret0 } -// VerifyCredential indicates an expected call of VerifyCredential. -func (mr *MockAuthCheckerMockRecorder) VerifyCredential(preCtx interface{}) *gomock.Call { +// UpdateUserPassword indicates an expected call of UpdateUserPassword. +func (mr *MockUserServerMockRecorder) UpdateUserPassword(ctx, req interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "VerifyCredential", reflect.TypeOf((*MockAuthChecker)(nil).VerifyCredential), preCtx) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateUserPassword", reflect.TypeOf((*MockUserServer)(nil).UpdateUserPassword), ctx, req) } -// MockUserServer is a mock of UserServer interface. -type MockUserServer struct { +// MockUserOperator is a mock of UserOperator interface. +type MockUserOperator struct { ctrl *gomock.Controller - recorder *MockUserServerMockRecorder + recorder *MockUserOperatorMockRecorder } -// MockUserServerMockRecorder is the mock recorder for MockUserServer. -type MockUserServerMockRecorder struct { - mock *MockUserServer +// MockUserOperatorMockRecorder is the mock recorder for MockUserOperator. +type MockUserOperatorMockRecorder struct { + mock *MockUserOperator } -// NewMockUserServer creates a new mock instance. -func NewMockUserServer(ctrl *gomock.Controller) *MockUserServer { - mock := &MockUserServer{ctrl: ctrl} - mock.recorder = &MockUserServerMockRecorder{mock} +// NewMockUserOperator creates a new mock instance. +func NewMockUserOperator(ctrl *gomock.Controller) *MockUserOperator { + mock := &MockUserOperator{ctrl: ctrl} + mock.recorder = &MockUserOperatorMockRecorder{mock} return mock } // EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockUserServer) EXPECT() *MockUserServerMockRecorder { +func (m *MockUserOperator) EXPECT() *MockUserOperatorMockRecorder { return m.recorder } // CreateUsers mocks base method. -func (m *MockUserServer) CreateUsers(ctx context.Context, users []*security.User) *service_manage.BatchWriteResponse { +func (m *MockUserOperator) CreateUsers(ctx context.Context, users []*security.User) *service_manage.BatchWriteResponse { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "CreateUsers", ctx, users) ret0, _ := ret[0].(*service_manage.BatchWriteResponse) @@ -560,13 +624,13 @@ func (m *MockUserServer) CreateUsers(ctx context.Context, users []*security.User } // CreateUsers indicates an expected call of CreateUsers. -func (mr *MockUserServerMockRecorder) CreateUsers(ctx, users interface{}) *gomock.Call { +func (mr *MockUserOperatorMockRecorder) CreateUsers(ctx, users interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateUsers", reflect.TypeOf((*MockUserServer)(nil).CreateUsers), ctx, users) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateUsers", reflect.TypeOf((*MockUserOperator)(nil).CreateUsers), ctx, users) } // DeleteUsers mocks base method. -func (m *MockUserServer) DeleteUsers(ctx context.Context, users []*security.User) *service_manage.BatchWriteResponse { +func (m *MockUserOperator) DeleteUsers(ctx context.Context, users []*security.User) *service_manage.BatchWriteResponse { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "DeleteUsers", ctx, users) ret0, _ := ret[0].(*service_manage.BatchWriteResponse) @@ -574,13 +638,27 @@ func (m *MockUserServer) DeleteUsers(ctx context.Context, users []*security.User } // DeleteUsers indicates an expected call of DeleteUsers. -func (mr *MockUserServerMockRecorder) DeleteUsers(ctx, users interface{}) *gomock.Call { +func (mr *MockUserOperatorMockRecorder) DeleteUsers(ctx, users interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteUsers", reflect.TypeOf((*MockUserServer)(nil).DeleteUsers), ctx, users) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteUsers", reflect.TypeOf((*MockUserOperator)(nil).DeleteUsers), ctx, users) +} + +// EnableUserToken mocks base method. +func (m *MockUserOperator) EnableUserToken(ctx context.Context, user *security.User) *service_manage.Response { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "EnableUserToken", ctx, user) + ret0, _ := ret[0].(*service_manage.Response) + return ret0 +} + +// EnableUserToken indicates an expected call of EnableUserToken. +func (mr *MockUserOperatorMockRecorder) EnableUserToken(ctx, user interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "EnableUserToken", reflect.TypeOf((*MockUserOperator)(nil).EnableUserToken), ctx, user) } // GetUserToken mocks base method. -func (m *MockUserServer) GetUserToken(ctx context.Context, user *security.User) *service_manage.Response { +func (m *MockUserOperator) GetUserToken(ctx context.Context, user *security.User) *service_manage.Response { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetUserToken", ctx, user) ret0, _ := ret[0].(*service_manage.Response) @@ -588,13 +666,13 @@ func (m *MockUserServer) GetUserToken(ctx context.Context, user *security.User) } // GetUserToken indicates an expected call of GetUserToken. -func (mr *MockUserServerMockRecorder) GetUserToken(ctx, user interface{}) *gomock.Call { +func (mr *MockUserOperatorMockRecorder) GetUserToken(ctx, user interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserToken", reflect.TypeOf((*MockUserServer)(nil).GetUserToken), ctx, user) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserToken", reflect.TypeOf((*MockUserOperator)(nil).GetUserToken), ctx, user) } // GetUsers mocks base method. -func (m *MockUserServer) GetUsers(ctx context.Context, query map[string]string) *service_manage.BatchQueryResponse { +func (m *MockUserOperator) GetUsers(ctx context.Context, query map[string]string) *service_manage.BatchQueryResponse { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetUsers", ctx, query) ret0, _ := ret[0].(*service_manage.BatchQueryResponse) @@ -602,13 +680,13 @@ func (m *MockUserServer) GetUsers(ctx context.Context, query map[string]string) } // GetUsers indicates an expected call of GetUsers. -func (mr *MockUserServerMockRecorder) GetUsers(ctx, query interface{}) *gomock.Call { +func (mr *MockUserOperatorMockRecorder) GetUsers(ctx, query interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUsers", reflect.TypeOf((*MockUserServer)(nil).GetUsers), ctx, query) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUsers", reflect.TypeOf((*MockUserOperator)(nil).GetUsers), ctx, query) } // ResetUserToken mocks base method. -func (m *MockUserServer) ResetUserToken(ctx context.Context, user *security.User) *service_manage.Response { +func (m *MockUserOperator) ResetUserToken(ctx context.Context, user *security.User) *service_manage.Response { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "ResetUserToken", ctx, user) ret0, _ := ret[0].(*service_manage.Response) @@ -616,13 +694,13 @@ func (m *MockUserServer) ResetUserToken(ctx context.Context, user *security.User } // ResetUserToken indicates an expected call of ResetUserToken. -func (mr *MockUserServerMockRecorder) ResetUserToken(ctx, user interface{}) *gomock.Call { +func (mr *MockUserOperatorMockRecorder) ResetUserToken(ctx, user interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ResetUserToken", reflect.TypeOf((*MockUserServer)(nil).ResetUserToken), ctx, user) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ResetUserToken", reflect.TypeOf((*MockUserOperator)(nil).ResetUserToken), ctx, user) } // UpdateUser mocks base method. -func (m *MockUserServer) UpdateUser(ctx context.Context, user *security.User) *service_manage.Response { +func (m *MockUserOperator) UpdateUser(ctx context.Context, user *security.User) *service_manage.Response { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "UpdateUser", ctx, user) ret0, _ := ret[0].(*service_manage.Response) @@ -630,13 +708,13 @@ func (m *MockUserServer) UpdateUser(ctx context.Context, user *security.User) *s } // UpdateUser indicates an expected call of UpdateUser. -func (mr *MockUserServerMockRecorder) UpdateUser(ctx, user interface{}) *gomock.Call { +func (mr *MockUserOperatorMockRecorder) UpdateUser(ctx, user interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateUser", reflect.TypeOf((*MockUserServer)(nil).UpdateUser), ctx, user) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateUser", reflect.TypeOf((*MockUserOperator)(nil).UpdateUser), ctx, user) } // UpdateUserPassword mocks base method. -func (m *MockUserServer) UpdateUserPassword(ctx context.Context, req *security.ModifyUserPassword) *service_manage.Response { +func (m *MockUserOperator) UpdateUserPassword(ctx context.Context, req *security.ModifyUserPassword) *service_manage.Response { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "UpdateUserPassword", ctx, req) ret0, _ := ret[0].(*service_manage.Response) @@ -644,23 +722,9 @@ func (m *MockUserServer) UpdateUserPassword(ctx context.Context, req *security.M } // UpdateUserPassword indicates an expected call of UpdateUserPassword. -func (mr *MockUserServerMockRecorder) UpdateUserPassword(ctx, req interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateUserPassword", reflect.TypeOf((*MockUserServer)(nil).UpdateUserPassword), ctx, req) -} - -// UpdateUserToken mocks base method. -func (m *MockUserServer) UpdateUserToken(ctx context.Context, user *security.User) *service_manage.Response { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "UpdateUserToken", ctx, user) - ret0, _ := ret[0].(*service_manage.Response) - return ret0 -} - -// UpdateUserToken indicates an expected call of UpdateUserToken. -func (mr *MockUserServerMockRecorder) UpdateUserToken(ctx, user interface{}) *gomock.Call { +func (mr *MockUserOperatorMockRecorder) UpdateUserPassword(ctx, req interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateUserToken", reflect.TypeOf((*MockUserServer)(nil).UpdateUserToken), ctx, user) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateUserPassword", reflect.TypeOf((*MockUserOperator)(nil).UpdateUserPassword), ctx, req) } // MockGroupOperator is a mock of GroupOperator interface. @@ -714,6 +778,20 @@ func (mr *MockGroupOperatorMockRecorder) DeleteGroups(ctx, group interface{}) *g return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteGroups", reflect.TypeOf((*MockGroupOperator)(nil).DeleteGroups), ctx, group) } +// EnableGroupToken mocks base method. +func (m *MockGroupOperator) EnableGroupToken(ctx context.Context, group *security.UserGroup) *service_manage.Response { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "EnableGroupToken", ctx, group) + ret0, _ := ret[0].(*service_manage.Response) + return ret0 +} + +// EnableGroupToken indicates an expected call of EnableGroupToken. +func (mr *MockGroupOperatorMockRecorder) EnableGroupToken(ctx, group interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "EnableGroupToken", reflect.TypeOf((*MockGroupOperator)(nil).EnableGroupToken), ctx, group) +} + // GetGroup mocks base method. func (m *MockGroupOperator) GetGroup(ctx context.Context, req *security.UserGroup) *service_manage.Response { m.ctrl.T.Helper() @@ -770,20 +848,6 @@ func (mr *MockGroupOperatorMockRecorder) ResetGroupToken(ctx, group interface{}) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ResetGroupToken", reflect.TypeOf((*MockGroupOperator)(nil).ResetGroupToken), ctx, group) } -// UpdateGroupToken mocks base method. -func (m *MockGroupOperator) UpdateGroupToken(ctx context.Context, group *security.UserGroup) *service_manage.Response { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "UpdateGroupToken", ctx, group) - ret0, _ := ret[0].(*service_manage.Response) - return ret0 -} - -// UpdateGroupToken indicates an expected call of UpdateGroupToken. -func (mr *MockGroupOperatorMockRecorder) UpdateGroupToken(ctx, group interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateGroupToken", reflect.TypeOf((*MockGroupOperator)(nil).UpdateGroupToken), ctx, group) -} - // UpdateGroups mocks base method. func (m *MockGroupOperator) UpdateGroups(ctx context.Context, groups []*security.ModifyUserGroup) *service_manage.BatchWriteResponse { m.ctrl.T.Helper() @@ -798,109 +862,123 @@ func (mr *MockGroupOperatorMockRecorder) UpdateGroups(ctx, groups interface{}) * return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateGroups", reflect.TypeOf((*MockGroupOperator)(nil).UpdateGroups), ctx, groups) } -// MockStrategyServer is a mock of StrategyServer interface. -type MockStrategyServer struct { +// MockUserHelper is a mock of UserHelper interface. +type MockUserHelper struct { ctrl *gomock.Controller - recorder *MockStrategyServerMockRecorder + recorder *MockUserHelperMockRecorder } -// MockStrategyServerMockRecorder is the mock recorder for MockStrategyServer. -type MockStrategyServerMockRecorder struct { - mock *MockStrategyServer +// MockUserHelperMockRecorder is the mock recorder for MockUserHelper. +type MockUserHelperMockRecorder struct { + mock *MockUserHelper } -// NewMockStrategyServer creates a new mock instance. -func NewMockStrategyServer(ctrl *gomock.Controller) *MockStrategyServer { - mock := &MockStrategyServer{ctrl: ctrl} - mock.recorder = &MockStrategyServerMockRecorder{mock} +// NewMockUserHelper creates a new mock instance. +func NewMockUserHelper(ctrl *gomock.Controller) *MockUserHelper { + mock := &MockUserHelper{ctrl: ctrl} + mock.recorder = &MockUserHelperMockRecorder{mock} return mock } // EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockStrategyServer) EXPECT() *MockStrategyServerMockRecorder { +func (m *MockUserHelper) EXPECT() *MockUserHelperMockRecorder { return m.recorder } -// CreateStrategy mocks base method. -func (m *MockStrategyServer) CreateStrategy(ctx context.Context, strategy *security.AuthStrategy) *service_manage.Response { +// CheckGroupsExist mocks base method. +func (m *MockUserHelper) CheckGroupsExist(ctx context.Context, groups []*security.UserGroup) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "CreateStrategy", ctx, strategy) - ret0, _ := ret[0].(*service_manage.Response) + ret := m.ctrl.Call(m, "CheckGroupsExist", ctx, groups) + ret0, _ := ret[0].(error) return ret0 } -// CreateStrategy indicates an expected call of CreateStrategy. -func (mr *MockStrategyServerMockRecorder) CreateStrategy(ctx, strategy interface{}) *gomock.Call { +// CheckGroupsExist indicates an expected call of CheckGroupsExist. +func (mr *MockUserHelperMockRecorder) CheckGroupsExist(ctx, groups interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateStrategy", reflect.TypeOf((*MockStrategyServer)(nil).CreateStrategy), ctx, strategy) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CheckGroupsExist", reflect.TypeOf((*MockUserHelper)(nil).CheckGroupsExist), ctx, groups) } -// DeleteStrategies mocks base method. -func (m *MockStrategyServer) DeleteStrategies(ctx context.Context, reqs []*security.AuthStrategy) *service_manage.BatchWriteResponse { +// CheckUserInGroup mocks base method. +func (m *MockUserHelper) CheckUserInGroup(ctx context.Context, group *security.UserGroup, user *security.User) bool { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "DeleteStrategies", ctx, reqs) - ret0, _ := ret[0].(*service_manage.BatchWriteResponse) + ret := m.ctrl.Call(m, "CheckUserInGroup", ctx, group, user) + ret0, _ := ret[0].(bool) return ret0 } -// DeleteStrategies indicates an expected call of DeleteStrategies. -func (mr *MockStrategyServerMockRecorder) DeleteStrategies(ctx, reqs interface{}) *gomock.Call { +// CheckUserInGroup indicates an expected call of CheckUserInGroup. +func (mr *MockUserHelperMockRecorder) CheckUserInGroup(ctx, group, user interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteStrategies", reflect.TypeOf((*MockStrategyServer)(nil).DeleteStrategies), ctx, reqs) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CheckUserInGroup", reflect.TypeOf((*MockUserHelper)(nil).CheckUserInGroup), ctx, group, user) } -// GetPrincipalResources mocks base method. -func (m *MockStrategyServer) GetPrincipalResources(ctx context.Context, query map[string]string) *service_manage.Response { +// CheckUsersExist mocks base method. +func (m *MockUserHelper) CheckUsersExist(ctx context.Context, users []*security.User) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetPrincipalResources", ctx, query) - ret0, _ := ret[0].(*service_manage.Response) + ret := m.ctrl.Call(m, "CheckUsersExist", ctx, users) + ret0, _ := ret[0].(error) return ret0 } -// GetPrincipalResources indicates an expected call of GetPrincipalResources. -func (mr *MockStrategyServerMockRecorder) GetPrincipalResources(ctx, query interface{}) *gomock.Call { +// CheckUsersExist indicates an expected call of CheckUsersExist. +func (mr *MockUserHelperMockRecorder) CheckUsersExist(ctx, users interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPrincipalResources", reflect.TypeOf((*MockStrategyServer)(nil).GetPrincipalResources), ctx, query) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CheckUsersExist", reflect.TypeOf((*MockUserHelper)(nil).CheckUsersExist), ctx, users) } -// GetStrategies mocks base method. -func (m *MockStrategyServer) GetStrategies(ctx context.Context, query map[string]string) *service_manage.BatchQueryResponse { +// GetGroup mocks base method. +func (m *MockUserHelper) GetGroup(ctx context.Context, req *security.UserGroup) *security.UserGroup { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetStrategies", ctx, query) - ret0, _ := ret[0].(*service_manage.BatchQueryResponse) + ret := m.ctrl.Call(m, "GetGroup", ctx, req) + ret0, _ := ret[0].(*security.UserGroup) return ret0 } -// GetStrategies indicates an expected call of GetStrategies. -func (mr *MockStrategyServerMockRecorder) GetStrategies(ctx, query interface{}) *gomock.Call { +// GetGroup indicates an expected call of GetGroup. +func (mr *MockUserHelperMockRecorder) GetGroup(ctx, req interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetStrategies", reflect.TypeOf((*MockStrategyServer)(nil).GetStrategies), ctx, query) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGroup", reflect.TypeOf((*MockUserHelper)(nil).GetGroup), ctx, req) } -// GetStrategy mocks base method. -func (m *MockStrategyServer) GetStrategy(ctx context.Context, strategy *security.AuthStrategy) *service_manage.Response { +// GetUser mocks base method. +func (m *MockUserHelper) GetUser(ctx context.Context, user *security.User) *security.User { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetStrategy", ctx, strategy) - ret0, _ := ret[0].(*service_manage.Response) + ret := m.ctrl.Call(m, "GetUser", ctx, user) + ret0, _ := ret[0].(*security.User) return ret0 } -// GetStrategy indicates an expected call of GetStrategy. -func (mr *MockStrategyServerMockRecorder) GetStrategy(ctx, strategy interface{}) *gomock.Call { +// GetUser indicates an expected call of GetUser. +func (mr *MockUserHelperMockRecorder) GetUser(ctx, user interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetStrategy", reflect.TypeOf((*MockStrategyServer)(nil).GetStrategy), ctx, strategy) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUser", reflect.TypeOf((*MockUserHelper)(nil).GetUser), ctx, user) } -// UpdateStrategies mocks base method. -func (m *MockStrategyServer) UpdateStrategies(ctx context.Context, reqs []*security.ModifyAuthStrategy) *service_manage.BatchWriteResponse { +// GetUserByID mocks base method. +func (m *MockUserHelper) GetUserByID(ctx context.Context, id string) *security.User { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "UpdateStrategies", ctx, reqs) - ret0, _ := ret[0].(*service_manage.BatchWriteResponse) + ret := m.ctrl.Call(m, "GetUserByID", ctx, id) + ret0, _ := ret[0].(*security.User) return ret0 } -// UpdateStrategies indicates an expected call of UpdateStrategies. -func (mr *MockStrategyServerMockRecorder) UpdateStrategies(ctx, reqs interface{}) *gomock.Call { +// GetUserByID indicates an expected call of GetUserByID. +func (mr *MockUserHelperMockRecorder) GetUserByID(ctx, id interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateStrategies", reflect.TypeOf((*MockStrategyServer)(nil).UpdateStrategies), ctx, reqs) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserByID", reflect.TypeOf((*MockUserHelper)(nil).GetUserByID), ctx, id) +} + +// GetUserOwnGroup mocks base method. +func (m *MockUserHelper) GetUserOwnGroup(ctx context.Context, user *security.User) []*security.UserGroup { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetUserOwnGroup", ctx, user) + ret0, _ := ret[0].([]*security.UserGroup) + return ret0 +} + +// GetUserOwnGroup indicates an expected call of GetUserOwnGroup. +func (mr *MockUserHelperMockRecorder) GetUserOwnGroup(ctx, user interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserOwnGroup", reflect.TypeOf((*MockUserHelper)(nil).GetUserOwnGroup), ctx, user) } diff --git a/auth/policy/access.go b/auth/policy/access.go deleted file mode 100644 index 752bfaac9..000000000 --- a/auth/policy/access.go +++ /dev/null @@ -1,57 +0,0 @@ -/** - * Tencent is pleased to support the open source community by making Polaris available. - * - * Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. - * - * Licensed under the BSD 3-Clause License (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://opensource.org/licenses/BSD-3-Clause - * - * Unless required by applicable law or agreed to in writing, software distributed - * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR - * CONDITIONS OF ANY KIND, either express or implied. See the License for the - * specific language governing permissions and limitations under the License. - */ - -package policy - -import ( - "context" - - apisecurity "github.com/polarismesh/specification/source/go/api/v1/security" - apiservice "github.com/polarismesh/specification/source/go/api/v1/service_manage" -) - -// CreateStrategy 创建鉴权策略 -func (svr *Server) CreateStrategy(ctx context.Context, req *apisecurity.AuthStrategy) *apiservice.Response { - return svr.handleCreateStrategy(ctx, req) -} - -// UpdateStrategies 批量修改鉴权 -func (svr *Server) UpdateStrategies( - ctx context.Context, reqs []*apisecurity.ModifyAuthStrategy) *apiservice.BatchWriteResponse { - return svr.handleUpdateStrategies(ctx, reqs) -} - -// DeleteStrategies 批量删除鉴权策略 -func (svr *Server) DeleteStrategies( - ctx context.Context, reqs []*apisecurity.AuthStrategy) *apiservice.BatchWriteResponse { - return svr.handleDeleteStrategies(ctx, reqs) -} - -// GetStrategies 批量查询鉴权策略 -func (svr *Server) GetStrategies(ctx context.Context, query map[string]string) *apiservice.BatchQueryResponse { - return svr.handleGetStrategies(ctx, query) -} - -// GetStrategy 查询单个鉴权策略 -func (svr *Server) GetStrategy(ctx context.Context, req *apisecurity.AuthStrategy) *apiservice.Response { - return svr.handleGetStrategy(ctx, req) -} - -// GetPrincipalResources 查询鉴权策略所属资源 -func (svr *Server) GetPrincipalResources(ctx context.Context, query map[string]string) *apiservice.Response { - return svr.handleGetPrincipalResources(ctx, query) -} diff --git a/common/model/admin.go b/auth/policy/action_map.go similarity index 75% rename from common/model/admin.go rename to auth/policy/action_map.go index df44779a7..16835d8e3 100644 --- a/common/model/admin.go +++ b/auth/policy/action_map.go @@ -15,17 +15,18 @@ * specific language governing permissions and limitations under the License. */ -package model +package policy -import "time" +var actionsMap map[string]string -// LeaderElection leader election info -type LeaderElection struct { - ElectKey string - Host string - Ctime int64 - CreateTime time.Time - Mtime int64 - ModifyTime time.Time - Valid bool +func InitActionMapping() { + // TODO +} + +func GetRealAction(s string) string { + val, ok := actionsMap[s] + if !ok { + return s + } + return val } diff --git a/auth/policy/auth_checker.go b/auth/policy/auth_checker.go index 828da26c5..0a25a5958 100644 --- a/auth/policy/auth_checker.go +++ b/auth/policy/auth_checker.go @@ -18,6 +18,8 @@ package policy import ( + "strings" + "github.com/pkg/errors" apisecurity "github.com/polarismesh/specification/source/go/api/v1/security" "go.uber.org/zap" @@ -25,7 +27,7 @@ import ( "github.com/polarismesh/polaris/auth" cachetypes "github.com/polarismesh/polaris/cache/api" api "github.com/polarismesh/polaris/common/api/v1" - "github.com/polarismesh/polaris/common/model" + authcommon "github.com/polarismesh/polaris/common/model/auth" "github.com/polarismesh/polaris/common/utils" "github.com/polarismesh/polaris/store" ) @@ -41,9 +43,10 @@ var ( // DefaultAuthChecker 北极星自带的默认鉴权中心 type DefaultAuthChecker struct { - conf *AuthConfig - cacheMgr cachetypes.CacheManager - userSvr auth.UserServer + conf *AuthConfig + cacheMgr cachetypes.CacheManager + userSvr auth.UserServer + policyMgr *Server } // Initialize 执行初始化动作 @@ -55,6 +58,18 @@ func (d *DefaultAuthChecker) Initialize(conf *AuthConfig, s store.Store, return nil } +func (d *DefaultAuthChecker) SetCacheMgr(mgr cachetypes.CacheManager) { + d.cacheMgr = mgr +} + +func (d *DefaultAuthChecker) GetConfig() *AuthConfig { + return d.conf +} + +func (d *DefaultAuthChecker) SetConfig(conf *AuthConfig) { + d.conf = conf +} + // Cache 获取缓存统一管理 func (d *DefaultAuthChecker) Cache() cachetypes.CacheManager { return d.cacheMgr @@ -76,34 +91,21 @@ func (d *DefaultAuthChecker) IsOpenAuth() bool { } // AllowResourceOperate 是否允许资源的操作 -func (d *DefaultAuthChecker) AllowResourceOperate(ctx *model.AcquireContext, opInfo *model.ResourceOpInfo) bool { - // 如果鉴权能力没有开启,那就默认都可以进行编辑 +func (d *DefaultAuthChecker) ResourcePredicate(ctx *authcommon.AcquireContext, res *authcommon.ResourceEntry) bool { + // 如果鉴权能力没有开启,那就默认都可以进行操作 if !d.IsOpenAuth() { return true } - attachVal, ok := ctx.GetAttachment(model.TokenDetailInfoKey) + + p, ok := ctx.GetAttachment(authcommon.PrincipalKey) if !ok { - // TODO need log return false } - tokenInfo, ok := attachVal.(auth.OperatorInfo) - - principal := model.Principal{ - PrincipalID: tokenInfo.OperatorID, - PrincipalRole: func() model.PrincipalType { - if tokenInfo.IsUserToken { - return model.PrincipalUser - } - return model.PrincipalGroup - }(), - } - - editable := d.cacheMgr.AuthStrategy().IsResourceEditable(principal, opInfo.ResourceType, opInfo.ResourceID) - return editable + return d.cacheMgr.AuthStrategy().Hint(p.(authcommon.Principal), res) != apisecurity.AuthAction_DENY } // CheckClientPermission 执行检查客户端动作判断是否有权限,并且对 RequestContext 注入操作者数据 -func (d *DefaultAuthChecker) CheckClientPermission(preCtx *model.AcquireContext) (bool, error) { +func (d *DefaultAuthChecker) CheckClientPermission(preCtx *authcommon.AcquireContext) (bool, error) { preCtx.SetFromClient() if !d.IsOpenClientAuth() { return true, nil @@ -115,7 +117,7 @@ func (d *DefaultAuthChecker) CheckClientPermission(preCtx *model.AcquireContext) } // CheckConsolePermission 执行检查控制台动作判断是否有权限,并且对 RequestContext 注入操作者数据 -func (d *DefaultAuthChecker) CheckConsolePermission(preCtx *model.AcquireContext) (bool, error) { +func (d *DefaultAuthChecker) CheckConsolePermission(preCtx *authcommon.AcquireContext) (bool, error) { preCtx.SetFromConsole() if !d.IsOpenConsoleAuth() { return true, nil @@ -123,39 +125,9 @@ func (d *DefaultAuthChecker) CheckConsolePermission(preCtx *model.AcquireContext if d.IsOpenConsoleAuth() && !d.conf.ConsoleStrict { preCtx.SetAllowAnonymous(true) } - if preCtx.GetModule() == model.MaintainModule { - return d.checkMaintainPermission(preCtx) - } return d.CheckPermission(preCtx) } -// CheckMaintainPermission 执行检查运维动作判断是否有权限 -func (d *DefaultAuthChecker) checkMaintainPermission(preCtx *model.AcquireContext) (bool, error) { - if preCtx.GetOperation() == model.Read { - return true, nil - } - - attachVal, ok := preCtx.GetAttachment(model.TokenDetailInfoKey) - if !ok { - return false, model.ErrorTokenNotExist - } - tokenInfo, ok := attachVal.(auth.OperatorInfo) - if !ok { - return false, model.ErrorTokenNotExist - } - - if tokenInfo.Disable { - return false, model.ErrorTokenDisabled - } - if !tokenInfo.IsUserToken { - return false, errors.New("only user role can access maintain API") - } - if tokenInfo.Role != model.OwnerUserRole { - return false, errors.New("only owner account can access maintain API") - } - return true, nil -} - // CheckPermission 执行检查动作判断是否有权限 // // step 1. 判断是否开启了鉴权 @@ -165,95 +137,190 @@ func (d *DefaultAuthChecker) checkMaintainPermission(preCtx *model.AcquireContex // b. 写操作,快速失败 // step 3. 拉取token对应的操作者相关信息,注入到请求上下文中 // step 4. 进行权限检查 -func (d *DefaultAuthChecker) CheckPermission(authCtx *model.AcquireContext) (bool, error) { +func (d *DefaultAuthChecker) CheckPermission(authCtx *authcommon.AcquireContext) (bool, error) { if err := d.userSvr.CheckCredential(authCtx); err != nil { return false, err } - - attachVal, ok := authCtx.GetAttachment(model.TokenDetailInfoKey) - if !ok { - return false, model.ErrorTokenNotExist - } - operatorInfo, ok := attachVal.(auth.OperatorInfo) - if !ok { - return false, model.ErrorTokenNotExist - } - // 这里需要检查当 token 被禁止的情况,如果 token 被禁止,无论是否可以操作目标资源,都无法进行写操作 - if operatorInfo.Disable { - return false, model.ErrorTokenDisabled + if log.DebugEnabled() { + log.Debug("[Auth][Checker] check permission args", utils.RequestID(authCtx.GetRequestContext()), + zap.String("method", string(authCtx.GetMethod())), zap.Any("resources", authCtx.GetAccessResources())) } - log.Debug("[Auth][Checker] check permission args", utils.RequestID(authCtx.GetRequestContext()), - zap.String("method", authCtx.GetMethod()), zap.Any("resources", authCtx.GetAccessResources())) - if pass, _ := d.doCheckPermission(authCtx); pass { - return ok, nil + return true, nil } - // 强制同步一次db中strategy数据到cache - if err := d.cacheMgr.AuthStrategy().ForceSync(); err != nil { - log.Error("[Auth][Checker] force sync strategy to cache failed", - utils.RequestID(authCtx.GetRequestContext()), zap.Error(err)) + // 触发缓存的同步,避免鉴权策略和角色信息不一致导致的权限检查失败 + if err := d.resyncData(authCtx); err != nil { return false, err } return d.doCheckPermission(authCtx) } +func (d *DefaultAuthChecker) resyncData(authCtx *authcommon.AcquireContext) error { + if err := d.cacheMgr.AuthStrategy().Update(); err != nil { + log.Error("[Auth][Checker] force sync policy rule to cache failed", utils.RequestID(authCtx.GetRequestContext()), zap.Error(err)) + return err + } + if err := d.cacheMgr.Role().Update(); err != nil { + log.Error("[Auth][Checker] force sync role to cache failed", utils.RequestID(authCtx.GetRequestContext()), zap.Error(err)) + return err + } + return nil +} + // doCheckPermission 执行权限检查 -func (d *DefaultAuthChecker) doCheckPermission(authCtx *model.AcquireContext) (bool, error) { +func (d *DefaultAuthChecker) doCheckPermission(authCtx *authcommon.AcquireContext) (bool, error) { + p, _ := authCtx.GetAttachments()[authcommon.PrincipalKey].(authcommon.Principal) + if d.IsCredible(authCtx) { + return true, nil + } - var checkNamespace, checkSvc, checkCfgGroup bool + allowPolicies := d.cacheMgr.AuthStrategy().GetPrincipalPolicies("allow", p) + denyPolicies := d.cacheMgr.AuthStrategy().GetPrincipalPolicies("deny", p) - reqRes := authCtx.GetAccessResources() - nsResEntries := reqRes[apisecurity.ResourceType_Namespaces] - svcResEntries := reqRes[apisecurity.ResourceType_Services] - cfgResEntries := reqRes[apisecurity.ResourceType_ConfigGroups] - - principleID, _ := authCtx.GetAttachments()[model.OperatorIDKey].(string) - principleType, _ := authCtx.GetAttachments()[model.OperatorPrincipalType].(model.PrincipalType) - p := model.Principal{ - PrincipalID: principleID, - PrincipalRole: principleType, - } - checkNamespace = d.checkAction(p, apisecurity.ResourceType_Namespaces, nsResEntries, authCtx) - checkSvc = d.checkAction(p, apisecurity.ResourceType_Services, svcResEntries, authCtx) - checkCfgGroup = d.checkAction(p, apisecurity.ResourceType_ConfigGroups, cfgResEntries, authCtx) + resources := authCtx.GetAccessResources() - checkAllResEntries := checkNamespace && checkSvc && checkCfgGroup + // 先执行 deny 策略 + for i := range denyPolicies { + item := denyPolicies[i] + if d.MatchPolicy(authCtx, item, p, resources) { + return false, ErrorNotPermission + } + } - var err error - if !checkAllResEntries { - err = ErrorNotPermission + // 处理 allow 策略,只要有一个放开,就可以认为通过 + for i := range allowPolicies { + item := allowPolicies[i] + if d.MatchPolicy(authCtx, item, p, resources) { + return true, nil + } } - return checkAllResEntries, err + return false, ErrorNotPermission } -// checkAction 检查操作是否和策略匹配 -func (d *DefaultAuthChecker) checkAction(principal model.Principal, - resType apisecurity.ResourceType, resources []model.ResourceEntry, ctx *model.AcquireContext) bool { - // TODO 后续可针对读写操作进行鉴权, 并且可以针对具体的方法调用进行鉴权控制 - - switch ctx.GetOperation() { - case model.Read: - return true - default: - for _, entry := range resources { - if !d.cacheMgr.AuthStrategy().IsResourceEditable(principal, resType, entry.ID) { - return false - } +// IsCredible 检查是否是可信的请求 +func (d *DefaultAuthChecker) IsCredible(authCtx *authcommon.AcquireContext) bool { + reqHeaders, ok := authCtx.GetRequestContext().Value(utils.ContextRequestHeaders).(map[string][]string) + if !ok || len(d.conf.CredibleHeaders) == 0 { + return false + } + matched := true + for k, v := range d.conf.CredibleHeaders { + val, exist := reqHeaders[strings.ToLower(k)] + if !exist { + matched = false + break + } + if len(val) == 0 { + matched = false } + matched = v == val[0] + if !matched { + break + } + } + return matched +} + +// MatchPolicy 检查策略是否匹配 +func (d *DefaultAuthChecker) MatchPolicy(authCtx *authcommon.AcquireContext, policy *authcommon.StrategyDetail, + principal authcommon.Principal, resources map[apisecurity.ResourceType][]authcommon.ResourceEntry) bool { + if !d.MatchCalleeFunctions(authCtx, principal, policy) { + return false + } + if !d.MatchResourceOperateable(authCtx, principal, policy) { + return false + } + if !d.MatchResourceConditions(authCtx, principal, policy) { + return false } return true } -func (d *DefaultAuthChecker) SetCacheMgr(mgr cachetypes.CacheManager) { - d.cacheMgr = mgr +// MatchCalleeFunctions 检查操作方法是否和策略匹配 +func (d *DefaultAuthChecker) MatchCalleeFunctions(authCtx *authcommon.AcquireContext, + principal authcommon.Principal, policy *authcommon.StrategyDetail) bool { + functions := policy.CalleeMethods + for i := range functions { + if functions[i] == string(authCtx.GetMethod()) { + return true + } + if utils.IsWildMatch(string(authCtx.GetMethod()), functions[i]) { + return true + } + } + return false } -func (d *DefaultAuthChecker) GetConfig() *AuthConfig { - return d.conf +// checkAction 检查操作资源是否和策略匹配 +func (d *DefaultAuthChecker) MatchResourceOperateable(authCtx *authcommon.AcquireContext, + principal authcommon.Principal, policy *authcommon.StrategyDetail) bool { + matchCheck := func(resType apisecurity.ResourceType, resources []authcommon.ResourceEntry) bool { + for i := range resources { + actionResult := d.cacheMgr.AuthStrategy().Hint(principal, &resources[i]) + if actionResult.String() == policy.Action { + return true + } + } + return false + } + + reqRes := authCtx.GetAccessResources() + isMatch := false + for k, v := range reqRes { + if isMatch = matchCheck(k, v); isMatch { + break + } + } + return isMatch } -func (d *DefaultAuthChecker) SetConfig(conf *AuthConfig) { - d.conf = conf +// MatchResourceConditions 检查操作资源所拥有的标签是否和策略匹配 +func (d *DefaultAuthChecker) MatchResourceConditions(authCtx *authcommon.AcquireContext, + principal authcommon.Principal, policy *authcommon.StrategyDetail) bool { + matchCheck := func(resType apisecurity.ResourceType, resources []authcommon.ResourceEntry) bool { + conditions := policy.Conditions + + for i := range resources { + allMatch := true + for j := range conditions { + condition := conditions[j] + resVal, ok := resources[i].Metadata[condition.Key] + if !ok { + allMatch = false + break + } + compareFunc, ok := conditionCompareDict[condition.CompareFunc] + if !ok { + allMatch = false + break + } + if allMatch = compareFunc(resVal, condition.Value); !allMatch { + break + } + } + if allMatch { + return true + } + } + return false + } + + reqRes := authCtx.GetAccessResources() + isMatch := false + for k, v := range reqRes { + if isMatch = matchCheck(k, v); isMatch { + break + } + } + return isMatch } + +var ( + conditionCompareDict = map[string]func(string, string) bool{ + "for_any_value:string_equal": func(s1, s2 string) bool { + return s1 == s2 + }, + } +) diff --git a/auth/policy/auth_checker_test.go b/auth/policy/auth_checker_test.go index bddb12440..4668fc661 100644 --- a/auth/policy/auth_checker_test.go +++ b/auth/policy/auth_checker_test.go @@ -29,7 +29,8 @@ import ( "github.com/polarismesh/polaris/auth/policy" defaultuser "github.com/polarismesh/polaris/auth/user" "github.com/polarismesh/polaris/cache" - "github.com/polarismesh/polaris/common/model" + cachetypes "github.com/polarismesh/polaris/cache/api" + authcommon "github.com/polarismesh/polaris/common/model/auth" "github.com/polarismesh/polaris/common/utils" ) @@ -54,7 +55,7 @@ func Test_DefaultAuthChecker_CheckConsolePermission_Write_NoStrict(t *testing.T) storage.EXPECT().GetUsersForCache(gomock.Any(), gomock.Any()).AnyTimes().Return(users, nil) storage.EXPECT().GetGroupsForCache(gomock.Any(), gomock.Any()).AnyTimes().Return(groups, nil) - storage.EXPECT().GetStrategyDetailsForCache(gomock.Any(), gomock.Any()).AnyTimes().Return(strategies, nil) + storage.EXPECT().GetMoreStrategies(gomock.Any(), gomock.Any()).AnyTimes().Return(strategies, nil) storage.EXPECT().GetMoreNamespaces(gomock.Any()).AnyTimes().Return(namespaces, nil) storage.EXPECT().GetMoreServices(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes().Return(serviceMap, nil) @@ -79,7 +80,7 @@ func Test_DefaultAuthChecker_CheckConsolePermission_Write_NoStrict(t *testing.T) "salt": "polarismesh@2021", }, }, - }, storage, cacheMgr) + }, storage, nil, cacheMgr) _, svr, err := newPolicyServer() if err != nil { @@ -100,12 +101,12 @@ func Test_DefaultAuthChecker_CheckConsolePermission_Write_NoStrict(t *testing.T) t.Run("权限检查非严格模式-主账户资源访问检查", func(t *testing.T) { ctx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, users[0].Token) - authCtx := model.NewAcquireContext( - model.WithRequestContext(ctx), - model.WithMethod("Test_DefaultAuthChecker_VerifyCredential"), - model.WithOperation(model.Create), - model.WithModule(model.DiscoverModule), - model.WithAccessResources(map[apisecurity.ResourceType][]model.ResourceEntry{ + authCtx := authcommon.NewAcquireContext( + authcommon.WithRequestContext(ctx), + authcommon.WithMethod("Test_DefaultAuthChecker_VerifyCredential"), + authcommon.WithOperation(authcommon.Create), + authcommon.WithModule(authcommon.DiscoverModule), + authcommon.WithAccessResources(map[apisecurity.ResourceType][]authcommon.ResourceEntry{ apisecurity.ResourceType_Services: { { ID: services[0].ID, @@ -122,12 +123,12 @@ func Test_DefaultAuthChecker_CheckConsolePermission_Write_NoStrict(t *testing.T) t.Run("权限检查非严格模式-子账户资源访问检查(无操作权限)", func(t *testing.T) { ctx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, users[1].Token) - authCtx := model.NewAcquireContext( - model.WithRequestContext(ctx), - model.WithMethod("Test_DefaultAuthChecker_VerifyCredential"), - model.WithOperation(model.Create), - model.WithModule(model.DiscoverModule), - model.WithAccessResources(map[apisecurity.ResourceType][]model.ResourceEntry{ + authCtx := authcommon.NewAcquireContext( + authcommon.WithRequestContext(ctx), + authcommon.WithMethod("Test_DefaultAuthChecker_VerifyCredential"), + authcommon.WithOperation(authcommon.Create), + authcommon.WithModule(authcommon.DiscoverModule), + authcommon.WithAccessResources(map[apisecurity.ResourceType][]authcommon.ResourceEntry{ apisecurity.ResourceType_Services: { { ID: services[0].ID, @@ -144,12 +145,12 @@ func Test_DefaultAuthChecker_CheckConsolePermission_Write_NoStrict(t *testing.T) t.Run("权限检查非严格模式-子账户资源访问检查(有操作权限)", func(t *testing.T) { ctx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, users[1].Token) - authCtx := model.NewAcquireContext( - model.WithRequestContext(ctx), - model.WithMethod("Test_DefaultAuthChecker_VerifyCredential"), - model.WithOperation(model.Create), - model.WithModule(model.DiscoverModule), - model.WithAccessResources(map[apisecurity.ResourceType][]model.ResourceEntry{ + authCtx := authcommon.NewAcquireContext( + authcommon.WithRequestContext(ctx), + authcommon.WithMethod("Test_DefaultAuthChecker_VerifyCredential"), + authcommon.WithOperation(authcommon.Create), + authcommon.WithModule(authcommon.DiscoverModule), + authcommon.WithAccessResources(map[apisecurity.ResourceType][]authcommon.ResourceEntry{ apisecurity.ResourceType_Services: { { ID: services[1].ID, @@ -166,12 +167,12 @@ func Test_DefaultAuthChecker_CheckConsolePermission_Write_NoStrict(t *testing.T) t.Run("权限检查非严格模式-子账户资源访问检查(资源无绑定策略)", func(t *testing.T) { ctx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, users[1].Token) - authCtx := model.NewAcquireContext( - model.WithRequestContext(ctx), - model.WithMethod("Test_DefaultAuthChecker_VerifyCredential"), - model.WithOperation(model.Create), - model.WithModule(model.DiscoverModule), - model.WithAccessResources(map[apisecurity.ResourceType][]model.ResourceEntry{ + authCtx := authcommon.NewAcquireContext( + authcommon.WithRequestContext(ctx), + authcommon.WithMethod("Test_DefaultAuthChecker_VerifyCredential"), + authcommon.WithOperation(authcommon.Create), + authcommon.WithModule(authcommon.DiscoverModule), + authcommon.WithAccessResources(map[apisecurity.ResourceType][]authcommon.ResourceEntry{ apisecurity.ResourceType_Services: { { ID: services[freeIndex].ID, @@ -188,12 +189,12 @@ func Test_DefaultAuthChecker_CheckConsolePermission_Write_NoStrict(t *testing.T) t.Run("权限检查非严格模式-子账户访问用户组资源检查(属于用户组成员)", func(t *testing.T) { ctx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, users[1].Token) - authCtx := model.NewAcquireContext( - model.WithRequestContext(ctx), - model.WithMethod("Test_DefaultAuthChecker_VerifyCredential"), - model.WithOperation(model.Create), - model.WithModule(model.DiscoverModule), - model.WithAccessResources(map[apisecurity.ResourceType][]model.ResourceEntry{ + authCtx := authcommon.NewAcquireContext( + authcommon.WithRequestContext(ctx), + authcommon.WithMethod("Test_DefaultAuthChecker_VerifyCredential"), + authcommon.WithOperation(authcommon.Create), + authcommon.WithModule(authcommon.DiscoverModule), + authcommon.WithAccessResources(map[apisecurity.ResourceType][]authcommon.ResourceEntry{ apisecurity.ResourceType_Services: { { ID: services[(len(users)-1)+2].ID, @@ -210,12 +211,12 @@ func Test_DefaultAuthChecker_CheckConsolePermission_Write_NoStrict(t *testing.T) t.Run("权限检查非严格模式-子账户访问用户组资源检查(不属于用户组成员)", func(t *testing.T) { ctx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, users[1].Token) - authCtx := model.NewAcquireContext( - model.WithRequestContext(ctx), - model.WithMethod("Test_DefaultAuthChecker_VerifyCredential"), - model.WithOperation(model.Create), - model.WithModule(model.DiscoverModule), - model.WithAccessResources(map[apisecurity.ResourceType][]model.ResourceEntry{ + authCtx := authcommon.NewAcquireContext( + authcommon.WithRequestContext(ctx), + authcommon.WithMethod("Test_DefaultAuthChecker_VerifyCredential"), + authcommon.WithOperation(authcommon.Create), + authcommon.WithModule(authcommon.DiscoverModule), + authcommon.WithAccessResources(map[apisecurity.ResourceType][]authcommon.ResourceEntry{ apisecurity.ResourceType_Services: { { ID: services[(len(users)-1)+4].ID, @@ -232,13 +233,13 @@ func Test_DefaultAuthChecker_CheckConsolePermission_Write_NoStrict(t *testing.T) t.Run("权限检查非严格模式-用户组访问组内成员资源检查", func(t *testing.T) { ctx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, groups[1].Token) - authCtx := model.NewAcquireContext( - model.WithRequestContext(ctx), - model.WithMethod("Test_DefaultAuthChecker_VerifyCredential"), - // model.WithToken(groups[1].Token), - model.WithOperation(model.Create), - model.WithModule(model.DiscoverModule), - model.WithAccessResources(map[apisecurity.ResourceType][]model.ResourceEntry{ + authCtx := authcommon.NewAcquireContext( + authcommon.WithRequestContext(ctx), + authcommon.WithMethod("Test_DefaultAuthChecker_VerifyCredential"), + // authcommon.WithToken(groups[1].Token), + authcommon.WithOperation(authcommon.Create), + authcommon.WithModule(authcommon.DiscoverModule), + authcommon.WithAccessResources(map[apisecurity.ResourceType][]authcommon.ResourceEntry{ apisecurity.ResourceType_Services: { { ID: services[0].ID, @@ -255,12 +256,12 @@ func Test_DefaultAuthChecker_CheckConsolePermission_Write_NoStrict(t *testing.T) t.Run("权限检查非严格模式-token非法-匿名账户资源访问检查(资源无绑定策略)", func(t *testing.T) { ctx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, "users[1].Token") - authCtx := model.NewAcquireContext( - model.WithRequestContext(ctx), - model.WithMethod("Test_DefaultAuthChecker_VerifyCredential"), - model.WithOperation(model.Create), - model.WithModule(model.DiscoverModule), - model.WithAccessResources(map[apisecurity.ResourceType][]model.ResourceEntry{ + authCtx := authcommon.NewAcquireContext( + authcommon.WithRequestContext(ctx), + authcommon.WithMethod("Test_DefaultAuthChecker_VerifyCredential"), + authcommon.WithOperation(authcommon.Create), + authcommon.WithModule(authcommon.DiscoverModule), + authcommon.WithAccessResources(map[apisecurity.ResourceType][]authcommon.ResourceEntry{ apisecurity.ResourceType_Services: { { ID: services[freeIndex].ID, @@ -277,12 +278,12 @@ func Test_DefaultAuthChecker_CheckConsolePermission_Write_NoStrict(t *testing.T) t.Run("权限检查非严格模式-token为空-匿名账户资源访问检查(资源无绑定策略)", func(t *testing.T) { ctx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, "") - authCtx := model.NewAcquireContext( - model.WithRequestContext(ctx), - model.WithMethod("Test_DefaultAuthChecker_VerifyCredential"), - model.WithOperation(model.Create), - model.WithModule(model.DiscoverModule), - model.WithAccessResources(map[apisecurity.ResourceType][]model.ResourceEntry{ + authCtx := authcommon.NewAcquireContext( + authcommon.WithRequestContext(ctx), + authcommon.WithMethod("Test_DefaultAuthChecker_VerifyCredential"), + authcommon.WithOperation(authcommon.Create), + authcommon.WithModule(authcommon.DiscoverModule), + authcommon.WithAccessResources(map[apisecurity.ResourceType][]authcommon.ResourceEntry{ apisecurity.ResourceType_Services: { { ID: services[freeIndex].ID, @@ -315,7 +316,7 @@ func Test_DefaultAuthChecker_CheckConsolePermission_Write_Strict(t *testing.T) { storage.EXPECT().GetUsersForCache(gomock.Any(), gomock.Any()).AnyTimes().Return(users, nil) storage.EXPECT().GetGroupsForCache(gomock.Any(), gomock.Any()).AnyTimes().Return(groups, nil) - storage.EXPECT().GetStrategyDetailsForCache(gomock.Any(), gomock.Any()).AnyTimes().Return(strategies, nil) + storage.EXPECT().GetMoreStrategies(gomock.Any(), gomock.Any()).AnyTimes().Return(strategies, nil) storage.EXPECT().GetMoreNamespaces(gomock.Any()).AnyTimes().Return(namespaces, nil) storage.EXPECT().GetMoreServices(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes().Return(serviceMap, nil) @@ -341,7 +342,7 @@ func Test_DefaultAuthChecker_CheckConsolePermission_Write_Strict(t *testing.T) { "salt": "polarismesh@2021", }, }, - }, storage, cacheMgr) + }, storage, nil, cacheMgr) _, svr, err := newPolicyServer() if err != nil { @@ -362,13 +363,13 @@ func Test_DefaultAuthChecker_CheckConsolePermission_Write_Strict(t *testing.T) { t.Run("权限检查严格模式-主账户操作资源", func(t *testing.T) { ctx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, users[0].Token) - authCtx := model.NewAcquireContext( - model.WithRequestContext(ctx), - model.WithMethod("Test_DefaultAuthChecker_CheckConsolePermission_Write_Strict"), - // model.WithToken(users[0].Token), - model.WithOperation(model.Create), - model.WithModule(model.DiscoverModule), - model.WithAccessResources(map[apisecurity.ResourceType][]model.ResourceEntry{ + authCtx := authcommon.NewAcquireContext( + authcommon.WithRequestContext(ctx), + authcommon.WithMethod("Test_DefaultAuthChecker_CheckConsolePermission_Write_Strict"), + // authcommon.WithToken(users[0].Token), + authcommon.WithOperation(authcommon.Create), + authcommon.WithModule(authcommon.DiscoverModule), + authcommon.WithAccessResources(map[apisecurity.ResourceType][]authcommon.ResourceEntry{ apisecurity.ResourceType_Services: { { ID: services[0].ID, @@ -385,13 +386,13 @@ func Test_DefaultAuthChecker_CheckConsolePermission_Write_Strict(t *testing.T) { t.Run("权限检查严格模式-子账户操作资源(无操作权限)", func(t *testing.T) { ctx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, users[1].Token) - authCtx := model.NewAcquireContext( - model.WithRequestContext(ctx), - model.WithMethod("Test_DefaultAuthChecker_CheckConsolePermission_Write_Strict"), - // model.WithToken(users[1].Token), - model.WithOperation(model.Create), - model.WithModule(model.DiscoverModule), - model.WithAccessResources(map[apisecurity.ResourceType][]model.ResourceEntry{ + authCtx := authcommon.NewAcquireContext( + authcommon.WithRequestContext(ctx), + authcommon.WithMethod("Test_DefaultAuthChecker_CheckConsolePermission_Write_Strict"), + // authcommon.WithToken(users[1].Token), + authcommon.WithOperation(authcommon.Create), + authcommon.WithModule(authcommon.DiscoverModule), + authcommon.WithAccessResources(map[apisecurity.ResourceType][]authcommon.ResourceEntry{ apisecurity.ResourceType_Services: { { ID: services[0].ID, @@ -407,13 +408,13 @@ func Test_DefaultAuthChecker_CheckConsolePermission_Write_Strict(t *testing.T) { t.Run("权限检查严格模式-子账户操作资源(有操作权限)", func(t *testing.T) { ctx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, users[1].Token) - authCtx := model.NewAcquireContext( - model.WithRequestContext(ctx), - model.WithMethod("Test_DefaultAuthChecker_CheckConsolePermission_Write_Strict"), - // model.WithToken(users[1].Token), - model.WithOperation(model.Create), - model.WithModule(model.DiscoverModule), - model.WithAccessResources(map[apisecurity.ResourceType][]model.ResourceEntry{ + authCtx := authcommon.NewAcquireContext( + authcommon.WithRequestContext(ctx), + authcommon.WithMethod("Test_DefaultAuthChecker_CheckConsolePermission_Write_Strict"), + // authcommon.WithToken(users[1].Token), + authcommon.WithOperation(authcommon.Create), + authcommon.WithModule(authcommon.DiscoverModule), + authcommon.WithAccessResources(map[apisecurity.ResourceType][]authcommon.ResourceEntry{ apisecurity.ResourceType_Services: { { ID: services[1].ID, @@ -429,13 +430,13 @@ func Test_DefaultAuthChecker_CheckConsolePermission_Write_Strict(t *testing.T) { t.Run("权限检查严格模式-token非法-匿名账户操作资源(资源有策略)", func(t *testing.T) { ctx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, "Test_DefaultAuthChecker_CheckConsolePermission_Write_Strict") - authCtx := model.NewAcquireContext( - model.WithRequestContext(ctx), - model.WithMethod("Test_DefaultAuthChecker_CheckConsolePermission_Write_Strict"), - // model.WithToken("Test_DefaultAuthChecker_VerifyCredential"), - model.WithOperation(model.Create), - model.WithModule(model.DiscoverModule), - model.WithAccessResources(map[apisecurity.ResourceType][]model.ResourceEntry{ + authCtx := authcommon.NewAcquireContext( + authcommon.WithRequestContext(ctx), + authcommon.WithMethod("Test_DefaultAuthChecker_CheckConsolePermission_Write_Strict"), + // authcommon.WithToken("Test_DefaultAuthChecker_VerifyCredential"), + authcommon.WithOperation(authcommon.Create), + authcommon.WithModule(authcommon.DiscoverModule), + authcommon.WithAccessResources(map[apisecurity.ResourceType][]authcommon.ResourceEntry{ apisecurity.ResourceType_Services: { { ID: services[1].ID, @@ -451,13 +452,13 @@ func Test_DefaultAuthChecker_CheckConsolePermission_Write_Strict(t *testing.T) { t.Run("权限检查严格模式-token为空-匿名账户操作资源(资源有策略)", func(t *testing.T) { ctx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, "") - authCtx := model.NewAcquireContext( - model.WithRequestContext(ctx), - model.WithMethod("Test_DefaultAuthChecker_CheckConsolePermission_Write_Strict"), - // model.WithToken(""), - model.WithModule(model.DiscoverModule), - model.WithOperation(model.Create), - model.WithAccessResources(map[apisecurity.ResourceType][]model.ResourceEntry{ + authCtx := authcommon.NewAcquireContext( + authcommon.WithRequestContext(ctx), + authcommon.WithMethod("Test_DefaultAuthChecker_CheckConsolePermission_Write_Strict"), + // authcommon.WithToken(""), + authcommon.WithModule(authcommon.DiscoverModule), + authcommon.WithOperation(authcommon.Create), + authcommon.WithAccessResources(map[apisecurity.ResourceType][]authcommon.ResourceEntry{ apisecurity.ResourceType_Services: { { ID: services[1].ID, @@ -473,13 +474,13 @@ func Test_DefaultAuthChecker_CheckConsolePermission_Write_Strict(t *testing.T) { t.Run("权限检查严格模式-token非法-匿名账户操作资源(资源没有策略)", func(t *testing.T) { ctx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, "Test_DefaultAuthChecker_CheckConsolePermission_Write_Strict") - authCtx := model.NewAcquireContext( - model.WithRequestContext(ctx), - model.WithMethod("Test_DefaultAuthChecker_CheckConsolePermission_Write_Strict"), - // model.WithToken("Test_DefaultAuthChecker_VerifyCredential"), - model.WithOperation(model.Create), - model.WithModule(model.DiscoverModule), - model.WithAccessResources(map[apisecurity.ResourceType][]model.ResourceEntry{ + authCtx := authcommon.NewAcquireContext( + authcommon.WithRequestContext(ctx), + authcommon.WithMethod("Test_DefaultAuthChecker_CheckConsolePermission_Write_Strict"), + // authcommon.WithToken("Test_DefaultAuthChecker_VerifyCredential"), + authcommon.WithOperation(authcommon.Create), + authcommon.WithModule(authcommon.DiscoverModule), + authcommon.WithAccessResources(map[apisecurity.ResourceType][]authcommon.ResourceEntry{ apisecurity.ResourceType_Services: { { ID: services[freeIndex].ID, @@ -506,13 +507,13 @@ func Test_DefaultAuthChecker_CheckConsolePermission_Write_Strict(t *testing.T) { t.Run("权限检查严格模式-token为空-匿名账户操作资源(资源没有策略)", func(t *testing.T) { ctx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, "") - authCtx := model.NewAcquireContext( - model.WithRequestContext(ctx), - model.WithMethod("Test_DefaultAuthChecker_CheckConsolePermission_Write_Strict"), - // model.WithToken(""), - model.WithOperation(model.Create), - model.WithModule(model.DiscoverModule), - model.WithAccessResources(map[apisecurity.ResourceType][]model.ResourceEntry{ + authCtx := authcommon.NewAcquireContext( + authcommon.WithRequestContext(ctx), + authcommon.WithMethod("Test_DefaultAuthChecker_CheckConsolePermission_Write_Strict"), + // authcommon.WithToken(""), + authcommon.WithOperation(authcommon.Create), + authcommon.WithModule(authcommon.DiscoverModule), + authcommon.WithAccessResources(map[apisecurity.ResourceType][]authcommon.ResourceEntry{ apisecurity.ResourceType_Services: { { ID: services[freeIndex].ID, @@ -554,7 +555,7 @@ func Test_DefaultAuthChecker_CheckConsolePermission_Read_NoStrict(t *testing.T) storage.EXPECT().GetUsersForCache(gomock.Any(), gomock.Any()).AnyTimes().Return(users, nil) storage.EXPECT().GetGroupsForCache(gomock.Any(), gomock.Any()).AnyTimes().Return(groups, nil) - storage.EXPECT().GetStrategyDetailsForCache(gomock.Any(), gomock.Any()).AnyTimes().Return(strategies, nil) + storage.EXPECT().GetMoreStrategies(gomock.Any(), gomock.Any()).AnyTimes().Return(strategies, nil) storage.EXPECT().GetMoreNamespaces(gomock.Any()).AnyTimes().Return(namespaces, nil) storage.EXPECT().GetMoreServices(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes().Return(serviceMap, nil) @@ -580,7 +581,7 @@ func Test_DefaultAuthChecker_CheckConsolePermission_Read_NoStrict(t *testing.T) "salt": "polarismesh@2021", }, }, - }, storage, cacheMgr) + }, storage, nil, cacheMgr) _, svr, err := newPolicyServer() if err != nil { @@ -601,13 +602,13 @@ func Test_DefaultAuthChecker_CheckConsolePermission_Read_NoStrict(t *testing.T) t.Run("权限检查非严格模式-主账户正常读操作", func(t *testing.T) { ctx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, users[0].Token) - authCtx := model.NewAcquireContext( - model.WithRequestContext(ctx), - model.WithMethod("Test_DefaultAuthChecker_VerifyCredential"), - // model.WithToken(users[0].Token), - model.WithOperation(model.Read), - model.WithModule(model.DiscoverModule), - model.WithAccessResources(map[apisecurity.ResourceType][]model.ResourceEntry{ + authCtx := authcommon.NewAcquireContext( + authcommon.WithRequestContext(ctx), + authcommon.WithMethod("Test_DefaultAuthChecker_VerifyCredential"), + // authcommon.WithToken(users[0].Token), + authcommon.WithOperation(authcommon.Read), + authcommon.WithModule(authcommon.DiscoverModule), + authcommon.WithAccessResources(map[apisecurity.ResourceType][]authcommon.ResourceEntry{ apisecurity.ResourceType_Services: { { ID: services[0].ID, @@ -623,13 +624,13 @@ func Test_DefaultAuthChecker_CheckConsolePermission_Read_NoStrict(t *testing.T) t.Run("权限检查非严格模式-子账户正常读操作-资源有权限", func(t *testing.T) { ctx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, users[1].Token) - authCtx := model.NewAcquireContext( - model.WithRequestContext(ctx), - model.WithMethod("Test_DefaultAuthChecker_VerifyCredential"), - // model.WithToken(users[1].Token), - model.WithOperation(model.Read), - model.WithModule(model.DiscoverModule), - model.WithAccessResources(map[apisecurity.ResourceType][]model.ResourceEntry{ + authCtx := authcommon.NewAcquireContext( + authcommon.WithRequestContext(ctx), + authcommon.WithMethod("Test_DefaultAuthChecker_VerifyCredential"), + // authcommon.WithToken(users[1].Token), + authcommon.WithOperation(authcommon.Read), + authcommon.WithModule(authcommon.DiscoverModule), + authcommon.WithAccessResources(map[apisecurity.ResourceType][]authcommon.ResourceEntry{ apisecurity.ResourceType_Services: { { ID: services[1].ID, @@ -645,13 +646,13 @@ func Test_DefaultAuthChecker_CheckConsolePermission_Read_NoStrict(t *testing.T) t.Run("权限检查非严格模式-子账户正常读操作-资源无权限", func(t *testing.T) { ctx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, users[1].Token) - authCtx := model.NewAcquireContext( - model.WithRequestContext(ctx), - model.WithMethod("Test_DefaultAuthChecker_VerifyCredential"), - // model.WithToken(users[1].Token), - model.WithOperation(model.Read), - model.WithModule(model.DiscoverModule), - model.WithAccessResources(map[apisecurity.ResourceType][]model.ResourceEntry{ + authCtx := authcommon.NewAcquireContext( + authcommon.WithRequestContext(ctx), + authcommon.WithMethod("Test_DefaultAuthChecker_VerifyCredential"), + // authcommon.WithToken(users[1].Token), + authcommon.WithOperation(authcommon.Read), + authcommon.WithModule(authcommon.DiscoverModule), + authcommon.WithAccessResources(map[apisecurity.ResourceType][]authcommon.ResourceEntry{ apisecurity.ResourceType_Services: { { ID: services[0].ID, @@ -667,13 +668,13 @@ func Test_DefaultAuthChecker_CheckConsolePermission_Read_NoStrict(t *testing.T) t.Run("权限检查非严格模式-子账户正常读操作-资源无绑定策略", func(t *testing.T) { ctx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, users[1].Token) - authCtx := model.NewAcquireContext( - model.WithRequestContext(ctx), - model.WithMethod("Test_DefaultAuthChecker_VerifyCredential"), - // model.WithToken(users[1].Token), - model.WithOperation(model.Read), - model.WithModule(model.DiscoverModule), - model.WithAccessResources(map[apisecurity.ResourceType][]model.ResourceEntry{ + authCtx := authcommon.NewAcquireContext( + authcommon.WithRequestContext(ctx), + authcommon.WithMethod("Test_DefaultAuthChecker_VerifyCredential"), + // authcommon.WithToken(users[1].Token), + authcommon.WithOperation(authcommon.Read), + authcommon.WithModule(authcommon.DiscoverModule), + authcommon.WithAccessResources(map[apisecurity.ResourceType][]authcommon.ResourceEntry{ apisecurity.ResourceType_Services: { { ID: services[freeIndex].ID, @@ -689,13 +690,13 @@ func Test_DefaultAuthChecker_CheckConsolePermission_Read_NoStrict(t *testing.T) t.Run("权限检查非严格模式-匿名账户正常读操作-token为空-资源有策略", func(t *testing.T) { ctx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, "") - authCtx := model.NewAcquireContext( - model.WithRequestContext(ctx), - model.WithMethod("Test_DefaultAuthChecker_VerifyCredential"), - // model.WithToken(""), - model.WithOperation(model.Read), - model.WithModule(model.DiscoverModule), - model.WithAccessResources(map[apisecurity.ResourceType][]model.ResourceEntry{ + authCtx := authcommon.NewAcquireContext( + authcommon.WithRequestContext(ctx), + authcommon.WithMethod("Test_DefaultAuthChecker_VerifyCredential"), + // authcommon.WithToken(""), + authcommon.WithOperation(authcommon.Read), + authcommon.WithModule(authcommon.DiscoverModule), + authcommon.WithAccessResources(map[apisecurity.ResourceType][]authcommon.ResourceEntry{ apisecurity.ResourceType_Services: { { ID: services[0].ID, @@ -711,12 +712,12 @@ func Test_DefaultAuthChecker_CheckConsolePermission_Read_NoStrict(t *testing.T) t.Run("权限检查非严格模式-匿名账户正常读操作-token为空-资源无策略", func(t *testing.T) { ctx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, "") - authCtx := model.NewAcquireContext( - model.WithRequestContext(ctx), - model.WithMethod("Test_DefaultAuthChecker_VerifyCredential"), - model.WithOperation(model.Read), - model.WithModule(model.DiscoverModule), - model.WithAccessResources(map[apisecurity.ResourceType][]model.ResourceEntry{ + authCtx := authcommon.NewAcquireContext( + authcommon.WithRequestContext(ctx), + authcommon.WithMethod("Test_DefaultAuthChecker_VerifyCredential"), + authcommon.WithOperation(authcommon.Read), + authcommon.WithModule(authcommon.DiscoverModule), + authcommon.WithAccessResources(map[apisecurity.ResourceType][]authcommon.ResourceEntry{ apisecurity.ResourceType_Services: { { ID: services[freeIndex].ID, @@ -732,13 +733,13 @@ func Test_DefaultAuthChecker_CheckConsolePermission_Read_NoStrict(t *testing.T) t.Run("权限检查非严格模式-匿名账户正常读操作-token非法-资源有策略", func(t *testing.T) { ctx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, "Test_DefaultAuthChecker_VerifyCredential") - authCtx := model.NewAcquireContext( - model.WithRequestContext(ctx), - model.WithMethod("Test_DefaultAuthChecker_VerifyCredential"), - // model.WithToken("Test_DefaultAuthChecker_VerifyCredential"), - model.WithOperation(model.Read), - model.WithModule(model.DiscoverModule), - model.WithAccessResources(map[apisecurity.ResourceType][]model.ResourceEntry{ + authCtx := authcommon.NewAcquireContext( + authcommon.WithRequestContext(ctx), + authcommon.WithMethod("Test_DefaultAuthChecker_VerifyCredential"), + // authcommon.WithToken("Test_DefaultAuthChecker_VerifyCredential"), + authcommon.WithOperation(authcommon.Read), + authcommon.WithModule(authcommon.DiscoverModule), + authcommon.WithAccessResources(map[apisecurity.ResourceType][]authcommon.ResourceEntry{ apisecurity.ResourceType_Services: { { ID: services[0].ID, @@ -754,13 +755,13 @@ func Test_DefaultAuthChecker_CheckConsolePermission_Read_NoStrict(t *testing.T) t.Run("权限检查非严格模式-匿名账户正常读操作-token非法-资源无策略", func(t *testing.T) { ctx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, "Test_DefaultAuthChecker_VerifyCredential") - authCtx := model.NewAcquireContext( - model.WithRequestContext(ctx), - model.WithMethod("Test_DefaultAuthChecker_VerifyCredential"), - // model.WithToken("Test_DefaultAuthChecker_VerifyCredential"), - model.WithOperation(model.Read), - model.WithModule(model.DiscoverModule), - model.WithAccessResources(map[apisecurity.ResourceType][]model.ResourceEntry{ + authCtx := authcommon.NewAcquireContext( + authcommon.WithRequestContext(ctx), + authcommon.WithMethod("Test_DefaultAuthChecker_VerifyCredential"), + // authcommon.WithToken("Test_DefaultAuthChecker_VerifyCredential"), + authcommon.WithOperation(authcommon.Read), + authcommon.WithModule(authcommon.DiscoverModule), + authcommon.WithAccessResources(map[apisecurity.ResourceType][]authcommon.ResourceEntry{ apisecurity.ResourceType_Services: { { ID: services[freeIndex].ID, @@ -792,7 +793,7 @@ func Test_DefaultAuthChecker_CheckConsolePermission_Read_Strict(t *testing.T) { storage.EXPECT().GetUsersForCache(gomock.Any(), gomock.Any()).AnyTimes().Return(users, nil) storage.EXPECT().GetGroupsForCache(gomock.Any(), gomock.Any()).AnyTimes().Return(groups, nil) - storage.EXPECT().GetStrategyDetailsForCache(gomock.Any(), gomock.Any()).AnyTimes().Return(strategies, nil) + storage.EXPECT().GetMoreStrategies(gomock.Any(), gomock.Any()).AnyTimes().Return(strategies, nil) storage.EXPECT().GetMoreNamespaces(gomock.Any()).AnyTimes().Return(namespaces, nil) storage.EXPECT().GetMoreServices(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes().Return(serviceMap, nil) @@ -818,7 +819,7 @@ func Test_DefaultAuthChecker_CheckConsolePermission_Read_Strict(t *testing.T) { "salt": "polarismesh@2021", }, }, - }, storage, cacheMgr) + }, storage, nil, cacheMgr) _, svr, err := newPolicyServer() if err != nil { @@ -848,13 +849,13 @@ func Test_DefaultAuthChecker_CheckConsolePermission_Read_Strict(t *testing.T) { t.Run("权限检查严格模式-主账户正常读操作", func(t *testing.T) { ctx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, users[0].Token) - authCtx := model.NewAcquireContext( - model.WithRequestContext(ctx), - model.WithMethod("Test_DefaultAuthChecker_VerifyCredential"), - // model.WithToken(users[0].Token), - model.WithOperation(model.Read), - model.WithModule(model.DiscoverModule), - model.WithAccessResources(map[apisecurity.ResourceType][]model.ResourceEntry{ + authCtx := authcommon.NewAcquireContext( + authcommon.WithRequestContext(ctx), + authcommon.WithMethod("Test_DefaultAuthChecker_VerifyCredential"), + // authcommon.WithToken(users[0].Token), + authcommon.WithOperation(authcommon.Read), + authcommon.WithModule(authcommon.DiscoverModule), + authcommon.WithAccessResources(map[apisecurity.ResourceType][]authcommon.ResourceEntry{ apisecurity.ResourceType_Services: { { ID: services[0].ID, @@ -870,13 +871,13 @@ func Test_DefaultAuthChecker_CheckConsolePermission_Read_Strict(t *testing.T) { t.Run("权限检查严格模式-子账户正常读操作-资源有权限", func(t *testing.T) { ctx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, users[1].Token) - authCtx := model.NewAcquireContext( - model.WithRequestContext(ctx), - model.WithMethod("Test_DefaultAuthChecker_VerifyCredential"), - // model.WithToken(users[1].Token), - model.WithOperation(model.Read), - model.WithModule(model.DiscoverModule), - model.WithAccessResources(map[apisecurity.ResourceType][]model.ResourceEntry{ + authCtx := authcommon.NewAcquireContext( + authcommon.WithRequestContext(ctx), + authcommon.WithMethod("Test_DefaultAuthChecker_VerifyCredential"), + // authcommon.WithToken(users[1].Token), + authcommon.WithOperation(authcommon.Read), + authcommon.WithModule(authcommon.DiscoverModule), + authcommon.WithAccessResources(map[apisecurity.ResourceType][]authcommon.ResourceEntry{ apisecurity.ResourceType_Services: { { ID: services[1].ID, @@ -892,13 +893,13 @@ func Test_DefaultAuthChecker_CheckConsolePermission_Read_Strict(t *testing.T) { t.Run("权限检查严格模式-子账户正常读操作-资源无权限", func(t *testing.T) { ctx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, users[1].Token) - authCtx := model.NewAcquireContext( - model.WithRequestContext(ctx), - model.WithMethod("Test_DefaultAuthChecker_VerifyCredential"), - // model.WithToken(users[1].Token), - model.WithOperation(model.Read), - model.WithModule(model.DiscoverModule), - model.WithAccessResources(map[apisecurity.ResourceType][]model.ResourceEntry{ + authCtx := authcommon.NewAcquireContext( + authcommon.WithRequestContext(ctx), + authcommon.WithMethod("Test_DefaultAuthChecker_VerifyCredential"), + // authcommon.WithToken(users[1].Token), + authcommon.WithOperation(authcommon.Read), + authcommon.WithModule(authcommon.DiscoverModule), + authcommon.WithAccessResources(map[apisecurity.ResourceType][]authcommon.ResourceEntry{ apisecurity.ResourceType_Services: { { ID: services[0].ID, @@ -914,13 +915,13 @@ func Test_DefaultAuthChecker_CheckConsolePermission_Read_Strict(t *testing.T) { t.Run("权限检查严格模式-子账户正常读操作-资源无绑定策略", func(t *testing.T) { ctx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, users[1].Token) - authCtx := model.NewAcquireContext( - model.WithRequestContext(ctx), - model.WithMethod("Test_DefaultAuthChecker_VerifyCredential"), - // model.WithToken(users[1].Token), - model.WithOperation(model.Read), - model.WithModule(model.DiscoverModule), - model.WithAccessResources(map[apisecurity.ResourceType][]model.ResourceEntry{ + authCtx := authcommon.NewAcquireContext( + authcommon.WithRequestContext(ctx), + authcommon.WithMethod("Test_DefaultAuthChecker_VerifyCredential"), + // authcommon.WithToken(users[1].Token), + authcommon.WithOperation(authcommon.Read), + authcommon.WithModule(authcommon.DiscoverModule), + authcommon.WithAccessResources(map[apisecurity.ResourceType][]authcommon.ResourceEntry{ apisecurity.ResourceType_Services: { { ID: services[freeIndex].ID, @@ -936,13 +937,13 @@ func Test_DefaultAuthChecker_CheckConsolePermission_Read_Strict(t *testing.T) { t.Run("权限检查严格模式-匿名账户正常读操作-token为空-资源有策略", func(t *testing.T) { ctx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, "") - authCtx := model.NewAcquireContext( - model.WithRequestContext(ctx), - model.WithMethod("Test_DefaultAuthChecker_VerifyCredential"), - // model.WithToken(""), - model.WithOperation(model.Read), - model.WithModule(model.DiscoverModule), - model.WithAccessResources(map[apisecurity.ResourceType][]model.ResourceEntry{ + authCtx := authcommon.NewAcquireContext( + authcommon.WithRequestContext(ctx), + authcommon.WithMethod("Test_DefaultAuthChecker_VerifyCredential"), + // authcommon.WithToken(""), + authcommon.WithOperation(authcommon.Read), + authcommon.WithModule(authcommon.DiscoverModule), + authcommon.WithAccessResources(map[apisecurity.ResourceType][]authcommon.ResourceEntry{ apisecurity.ResourceType_Services: { { ID: services[0].ID, @@ -958,13 +959,13 @@ func Test_DefaultAuthChecker_CheckConsolePermission_Read_Strict(t *testing.T) { t.Run("权限检查严格模式-匿名账户正常读操作-token为空-资源无策略", func(t *testing.T) { ctx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, "") - authCtx := model.NewAcquireContext( - model.WithRequestContext(ctx), - model.WithMethod("Test_DefaultAuthChecker_VerifyCredential"), - // model.WithToken(""), - model.WithOperation(model.Read), - model.WithModule(model.DiscoverModule), - model.WithAccessResources(map[apisecurity.ResourceType][]model.ResourceEntry{ + authCtx := authcommon.NewAcquireContext( + authcommon.WithRequestContext(ctx), + authcommon.WithMethod("Test_DefaultAuthChecker_VerifyCredential"), + // authcommon.WithToken(""), + authcommon.WithOperation(authcommon.Read), + authcommon.WithModule(authcommon.DiscoverModule), + authcommon.WithAccessResources(map[apisecurity.ResourceType][]authcommon.ResourceEntry{ apisecurity.ResourceType_Services: { { ID: services[freeIndex].ID, @@ -980,13 +981,13 @@ func Test_DefaultAuthChecker_CheckConsolePermission_Read_Strict(t *testing.T) { t.Run("权限检查严格模式-匿名账户正常读操作-token非法-资源有策略", func(t *testing.T) { ctx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, "Test_DefaultAuthChecker_VerifyCredential") - authCtx := model.NewAcquireContext( - model.WithRequestContext(ctx), - model.WithMethod("Test_DefaultAuthChecker_VerifyCredential"), - // model.WithToken("Test_DefaultAuthChecker_VerifyCredential"), - model.WithOperation(model.Read), - model.WithModule(model.DiscoverModule), - model.WithAccessResources(map[apisecurity.ResourceType][]model.ResourceEntry{ + authCtx := authcommon.NewAcquireContext( + authcommon.WithRequestContext(ctx), + authcommon.WithMethod("Test_DefaultAuthChecker_VerifyCredential"), + // authcommon.WithToken("Test_DefaultAuthChecker_VerifyCredential"), + authcommon.WithOperation(authcommon.Read), + authcommon.WithModule(authcommon.DiscoverModule), + authcommon.WithAccessResources(map[apisecurity.ResourceType][]authcommon.ResourceEntry{ apisecurity.ResourceType_Services: { { ID: services[0].ID, @@ -1002,13 +1003,13 @@ func Test_DefaultAuthChecker_CheckConsolePermission_Read_Strict(t *testing.T) { t.Run("权限检查严格模式-匿名账户正常读操作-token非法-资源无策略", func(t *testing.T) { ctx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, "Test_DefaultAuthChecker_VerifyCredential") - authCtx := model.NewAcquireContext( - model.WithRequestContext(ctx), - model.WithMethod("Test_DefaultAuthChecker_VerifyCredential"), - // model.WithToken("Test_DefaultAuthChecker_VerifyCredential"), - model.WithOperation(model.Read), - model.WithModule(model.DiscoverModule), - model.WithAccessResources(map[apisecurity.ResourceType][]model.ResourceEntry{ + authCtx := authcommon.NewAcquireContext( + authcommon.WithRequestContext(ctx), + authcommon.WithMethod("Test_DefaultAuthChecker_VerifyCredential"), + // authcommon.WithToken("Test_DefaultAuthChecker_VerifyCredential"), + authcommon.WithOperation(authcommon.Read), + authcommon.WithModule(authcommon.DiscoverModule), + authcommon.WithAccessResources(map[apisecurity.ResourceType][]authcommon.ResourceEntry{ apisecurity.ResourceType_Services: { { ID: services[freeIndex].ID, @@ -1105,3 +1106,31 @@ func Test_DefaultAuthChecker_Initialize(t *testing.T) { }) } + +func TestDefaultAuthChecker_isCredible(t *testing.T) { + type fields struct { + conf *policy.AuthConfig + cacheMgr cachetypes.CacheManager + userSvr auth.UserServer + } + type args struct { + authCtx *authcommon.AcquireContext + } + tests := []struct { + name string + fields fields + args args + want bool + }{ + {}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + d := &policy.DefaultAuthChecker{} + d.SetConfig(tt.fields.conf) + if got := d.IsCredible(tt.args.authCtx); got != tt.want { + t.Errorf("DefaultAuthChecker.isCredible() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/auth/policy/common_test.go b/auth/policy/common_test.go index c64e4718c..72b0e02ef 100644 --- a/auth/policy/common_test.go +++ b/auth/policy/common_test.go @@ -32,6 +32,7 @@ import ( "github.com/polarismesh/polaris/cache" "github.com/polarismesh/polaris/common/metrics" "github.com/polarismesh/polaris/common/model" + authcommon "github.com/polarismesh/polaris/common/model/auth" "github.com/polarismesh/polaris/common/utils" storemock "github.com/polarismesh/polaris/store/mock" ) @@ -116,9 +117,9 @@ func createMockService(namespaces []*model.Namespace) []*model.Service { return services } -func createMockStrategy(users []*model.User, groups []*model.UserGroupDetail, services []*model.Service) ([]*model.StrategyDetail, []*model.StrategyDetail) { - strategies := make([]*model.StrategyDetail, 0, len(users)+len(groups)) - defaultStrategies := make([]*model.StrategyDetail, 0, len(users)+len(groups)) +func createMockStrategy(users []*authcommon.User, groups []*authcommon.UserGroupDetail, services []*model.Service) ([]*authcommon.StrategyDetail, []*authcommon.StrategyDetail) { + strategies := make([]*authcommon.StrategyDetail, 0, len(users)+len(groups)) + defaultStrategies := make([]*authcommon.StrategyDetail, 0, len(users)+len(groups)) owner := "" for i := 0; i < len(users); i++ { @@ -133,20 +134,20 @@ func createMockStrategy(users []*model.User, groups []*model.UserGroupDetail, se user := users[i] service := services[i] id := utils.NewUUID() - strategies = append(strategies, &model.StrategyDetail{ + strategies = append(strategies, &authcommon.StrategyDetail{ ID: id, Name: fmt.Sprintf("strategy_user_%s_%d", user.Name, i), Action: apisecurity.AuthAction_READ_WRITE.String(), Comment: "", - Principals: []model.Principal{ + Principals: []authcommon.Principal{ { PrincipalID: user.ID, - PrincipalRole: model.PrincipalUser, + PrincipalType: authcommon.PrincipalUser, }, }, Default: false, Owner: owner, - Resources: []model.StrategyResource{ + Resources: []authcommon.StrategyResource{ { StrategyID: id, ResType: int32(apisecurity.ResourceType_Namespaces), @@ -164,20 +165,20 @@ func createMockStrategy(users []*model.User, groups []*model.UserGroupDetail, se ModifyTime: time.Time{}, }) - defaultStrategies = append(defaultStrategies, &model.StrategyDetail{ + defaultStrategies = append(defaultStrategies, &authcommon.StrategyDetail{ ID: id, Name: fmt.Sprintf("strategy_default_user_%s_%d", user.Name, i), Action: apisecurity.AuthAction_READ_WRITE.String(), Comment: "", - Principals: []model.Principal{ + Principals: []authcommon.Principal{ { PrincipalID: user.ID, - PrincipalRole: model.PrincipalUser, + PrincipalType: authcommon.PrincipalUser, }, }, Default: true, Owner: owner, - Resources: []model.StrategyResource{ + Resources: []authcommon.StrategyResource{ { StrategyID: id, ResType: int32(apisecurity.ResourceType_Namespaces), @@ -200,20 +201,20 @@ func createMockStrategy(users []*model.User, groups []*model.UserGroupDetail, se group := groups[i] service := services[len(users)+i] id := utils.NewUUID() - strategies = append(strategies, &model.StrategyDetail{ + strategies = append(strategies, &authcommon.StrategyDetail{ ID: id, Name: fmt.Sprintf("strategy_group_%s_%d", group.Name, i), Action: apisecurity.AuthAction_READ_WRITE.String(), Comment: "", - Principals: []model.Principal{ + Principals: []authcommon.Principal{ { PrincipalID: group.ID, - PrincipalRole: model.PrincipalGroup, + PrincipalType: authcommon.PrincipalGroup, }, }, Default: false, Owner: owner, - Resources: []model.StrategyResource{ + Resources: []authcommon.StrategyResource{ { StrategyID: id, ResType: int32(apisecurity.ResourceType_Namespaces), @@ -231,20 +232,20 @@ func createMockStrategy(users []*model.User, groups []*model.UserGroupDetail, se ModifyTime: time.Time{}, }) - defaultStrategies = append(defaultStrategies, &model.StrategyDetail{ + defaultStrategies = append(defaultStrategies, &authcommon.StrategyDetail{ ID: id, Name: fmt.Sprintf("strategy_default_group_%s_%d", group.Name, i), Action: apisecurity.AuthAction_READ_WRITE.String(), Comment: "", - Principals: []model.Principal{ + Principals: []authcommon.Principal{ { PrincipalID: group.ID, - PrincipalRole: model.PrincipalGroup, + PrincipalType: authcommon.PrincipalGroup, }, }, Default: true, Owner: owner, - Resources: []model.StrategyResource{ + Resources: []authcommon.StrategyResource{ { StrategyID: id, ResType: int32(apisecurity.ResourceType_Namespaces), @@ -278,8 +279,8 @@ func convertServiceSliceToMap(services []*model.Service) map[string]*model.Servi } // createMockUser 默认 users[0] 为 owner 用户 -func createMockUser(total int, prefix ...string) []*model.User { - users := make([]*model.User, 0, total) +func createMockUser(total int, prefix ...string) []*authcommon.User { + users := make([]*authcommon.User, 0, total) ownerId := utils.NewUUID() @@ -295,7 +296,7 @@ func createMockUser(total int, prefix ...string) []*model.User { } pwd, _ := bcrypt.GenerateFromPassword([]byte("polaris"), bcrypt.DefaultCost) token, _ := defaultuser.CreateToken(id, "", "polarismesh@2021") - users = append(users, &model.User{ + users = append(users, &authcommon.User{ ID: id, Name: fmt.Sprintf(nameTemp, i), Password: string(pwd), @@ -308,11 +309,11 @@ func createMockUser(total int, prefix ...string) []*model.User { Source: "Polaris", Mobile: "", Email: "", - Type: func() model.UserRoleType { + Type: func() authcommon.UserRoleType { if id == ownerId { - return model.OwnerUserRole + return authcommon.OwnerUserRole } - return model.SubAccountUserRole + return authcommon.SubAccountUserRole }(), Token: token, TokenEnable: true, @@ -343,16 +344,16 @@ func createApiMockUser(total int, prefix ...string) []*apisecurity.User { return users } -func createMockUserGroup(users []*model.User) []*model.UserGroupDetail { - groups := make([]*model.UserGroupDetail, 0, len(users)) +func createMockUserGroup(users []*authcommon.User) []*authcommon.UserGroupDetail { + groups := make([]*authcommon.UserGroupDetail, 0, len(users)) for i := range users { user := users[i] id := utils.NewUUID() token, _ := defaultuser.CreateToken("", id, "polarismesh@2021") - groups = append(groups, &model.UserGroupDetail{ - UserGroup: &model.UserGroup{ + groups = append(groups, &authcommon.UserGroupDetail{ + UserGroup: &authcommon.UserGroup{ ID: id, Name: fmt.Sprintf("test-group-%d", i), Owner: users[0].ID, @@ -374,9 +375,9 @@ func createMockUserGroup(users []*model.User) []*model.UserGroupDetail { // createMockApiUserGroup func createMockApiUserGroup(users []*apisecurity.User) []*apisecurity.UserGroup { - musers := make([]*model.User, 0, len(users)) + musers := make([]*authcommon.User, 0, len(users)) for i := range users { - musers = append(musers, &model.User{ + musers = append(musers, &authcommon.User{ ID: users[i].GetId().GetValue(), }) } diff --git a/auth/policy/default.go b/auth/policy/default.go index 337feff77..ccff9599b 100644 --- a/auth/policy/default.go +++ b/auth/policy/default.go @@ -23,6 +23,7 @@ import ( "github.com/polarismesh/polaris/auth" policy_auth "github.com/polarismesh/polaris/auth/policy/inteceptor/auth" + "github.com/polarismesh/polaris/auth/policy/inteceptor/paramcheck" ) type ServerProxyFactory func(svr *Server, pre auth.StrategyServer) (auth.StrategyServer, error) @@ -53,6 +54,9 @@ func loadInteceptors() { RegisterServerProxy("auth", func(svr *Server, pre auth.StrategyServer) (auth.StrategyServer, error) { return policy_auth.NewServer(pre), nil }) + RegisterServerProxy("paramcheck", func(svr *Server, pre auth.StrategyServer) (auth.StrategyServer, error) { + return paramcheck.NewServer(pre), nil + }) } func BuildServer() (*Server, auth.StrategyServer, error) { @@ -61,7 +65,7 @@ func BuildServer() (*Server, auth.StrategyServer, error) { var nextSvr auth.StrategyServer nextSvr = svr // 需要返回包装代理的 DiscoverServer - order := []string{"auth"} + order := GetChainOrder() for i := range order { factory, exist := serverProxyFactories[order[i]] if !exist { @@ -76,3 +80,10 @@ func BuildServer() (*Server, auth.StrategyServer, error) { } return svr, nextSvr, nil } + +func GetChainOrder() []string { + return []string{ + "auth", + "paramcheck", + } +} diff --git a/auth/policy/helper.go b/auth/policy/helper.go new file mode 100644 index 000000000..e77f14de1 --- /dev/null +++ b/auth/policy/helper.go @@ -0,0 +1,60 @@ +package policy + +import ( + "context" + + "github.com/polarismesh/polaris/auth" + cachetypes "github.com/polarismesh/polaris/cache/api" + authcommon "github.com/polarismesh/polaris/common/model/auth" + "github.com/polarismesh/polaris/common/utils" + "github.com/polarismesh/polaris/store" + apisecurity "github.com/polarismesh/specification/source/go/api/v1/security" +) + +type DefaultPolicyHelper struct { + options *AuthConfig + storage store.Store + cacheMgr cachetypes.CacheManager + checker auth.AuthChecker +} + +// CreatePrincipal 创建 principal 的默认 policy 资源 +func (h *DefaultPolicyHelper) CreatePrincipal(ctx context.Context, tx store.Tx, p authcommon.Principal) error { + if !h.options.OpenPrincipalDefaultPolicy { + return nil + } + + if err := h.storage.AddStrategy(tx, defaultPrincipalPolicy(p)); err != nil { + return err + } + return nil +} + +func defaultPrincipalPolicy(p authcommon.Principal) *authcommon.StrategyDetail { + // Create the user's default weight policy + return &authcommon.StrategyDetail{ + ID: utils.NewUUID(), + Name: authcommon.BuildDefaultStrategyName(authcommon.PrincipalUser, p.Name), + Action: apisecurity.AuthAction_READ_WRITE.String(), + Default: true, + Owner: p.Owner, + Revision: utils.NewUUID(), + Resources: []authcommon.StrategyResource{}, + Valid: true, + Comment: "Default Strategy", + } +} + +// CleanPrincipal 清理 principal 所关联的 policy、role 资源 +func (h *DefaultPolicyHelper) CleanPrincipal(ctx context.Context, tx store.Tx, p authcommon.Principal) error { + if h.options.OpenPrincipalDefaultPolicy { + if err := h.storage.CleanPrincipalPolicies(tx, p); err != nil { + return err + } + } + + if err := h.storage.CleanPrincipalRoles(tx, &p); err != nil { + return err + } + return nil +} diff --git a/auth/policy/inteceptor/auth/server.go b/auth/policy/inteceptor/auth/server.go index be4812d06..128046b18 100644 --- a/auth/policy/inteceptor/auth/server.go +++ b/auth/policy/inteceptor/auth/server.go @@ -20,28 +20,22 @@ package auth import ( "context" - apimodel "github.com/polarismesh/specification/source/go/api/v1/model" + "github.com/golang/protobuf/ptypes/wrappers" apisecurity "github.com/polarismesh/specification/source/go/api/v1/security" apiservice "github.com/polarismesh/specification/source/go/api/v1/service_manage" - "go.uber.org/zap" "github.com/polarismesh/polaris/auth" cachetypes "github.com/polarismesh/polaris/cache/api" api "github.com/polarismesh/polaris/common/api/v1" - "github.com/polarismesh/polaris/common/model" - "github.com/polarismesh/polaris/common/utils" + authcommon "github.com/polarismesh/polaris/common/model/auth" "github.com/polarismesh/polaris/store" ) -var ( - // MustOwner 必须超级账户 or 主账户 - MustOwner = true - // NotOwner 任意账户 - NotOwner = false - // WriteOp 写操作 - WriteOp = true - // ReadOp 读操作 - ReadOp = false +type ( + PolicyInfoGetter interface { + GetId() *wrappers.StringValue + GetName() *wrappers.StringValue + } ) func NewServer(nextSvr auth.StrategyServer) auth.StrategyServer { @@ -55,6 +49,11 @@ type Server struct { userSvr auth.UserServer } +// PolicyHelper implements auth.StrategyServer. +func (svr *Server) PolicyHelper() auth.PolicyHelper { + return svr.nextSvr.PolicyHelper() +} + // Initialize 执行初始化动作 func (svr *Server) Initialize(options *auth.Config, storage store.Store, cacheMgr cachetypes.CacheManager, userSvr auth.UserServer) error { svr.userSvr = userSvr @@ -68,19 +67,39 @@ func (svr *Server) Name() string { // CreateStrategy 创建策略 func (svr *Server) CreateStrategy(ctx context.Context, strategy *apisecurity.AuthStrategy) *apiservice.Response { - ctx, rsp := svr.verifyAuth(ctx, WriteOp, MustOwner) - if rsp != nil { - return rsp + authCtx := authcommon.NewAcquireContext( + authcommon.WithRequestContext(ctx), + authcommon.WithOperation(authcommon.Create), + authcommon.WithModule(authcommon.AuthModule), + authcommon.WithMethod(authcommon.CreateAuthPolicy), + ) + + if _, err := svr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + resp := api.NewResponseWithMsg(authcommon.ConvertToErrCode(err), err.Error()) + return resp } return svr.nextSvr.CreateStrategy(ctx, strategy) } // UpdateStrategies 批量更新策略 func (svr *Server) UpdateStrategies(ctx context.Context, reqs []*apisecurity.ModifyAuthStrategy) *apiservice.BatchWriteResponse { - ctx, rsp := svr.verifyAuth(ctx, WriteOp, MustOwner) - if rsp != nil { - resp := api.NewAuthBatchWriteResponse(apimodel.Code_ExecuteSuccess) - api.Collect(resp, rsp) + resources := make([]authcommon.ResourceEntry, 0, len(reqs)) + for i := range reqs { + item := reqs[i] + resources = append(resources, authcommon.ResourceEntry{ + ID: item.GetId().GetValue(), + }) + } + + authCtx := authcommon.NewAcquireContext( + authcommon.WithRequestContext(ctx), + authcommon.WithOperation(authcommon.Modify), + authcommon.WithModule(authcommon.AuthModule), + authcommon.WithMethod(authcommon.UpdateAuthPolicies), + ) + + if _, err := svr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + resp := api.NewBatchWriteResponseWithMsg(authcommon.ConvertToErrCode(err), err.Error()) return resp } return svr.nextSvr.UpdateStrategies(ctx, reqs) @@ -88,10 +107,23 @@ func (svr *Server) UpdateStrategies(ctx context.Context, reqs []*apisecurity.Mod // DeleteStrategies 删除策略 func (svr *Server) DeleteStrategies(ctx context.Context, reqs []*apisecurity.AuthStrategy) *apiservice.BatchWriteResponse { - ctx, rsp := svr.verifyAuth(ctx, WriteOp, MustOwner) - if rsp != nil { - resp := api.NewAuthBatchWriteResponse(apimodel.Code_ExecuteSuccess) - api.Collect(resp, rsp) + resources := make([]authcommon.ResourceEntry, 0, len(reqs)) + for i := range reqs { + item := reqs[i] + resources = append(resources, authcommon.ResourceEntry{ + ID: item.GetId().GetValue(), + }) + } + + authCtx := authcommon.NewAcquireContext( + authcommon.WithRequestContext(ctx), + authcommon.WithOperation(authcommon.Delete), + authcommon.WithModule(authcommon.AuthModule), + authcommon.WithMethod(authcommon.DeleteAuthPolicies), + ) + + if _, err := svr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + resp := api.NewBatchWriteResponseWithMsg(authcommon.ConvertToErrCode(err), err.Error()) return resp } return svr.nextSvr.DeleteStrategies(ctx, reqs) @@ -101,27 +133,66 @@ func (svr *Server) DeleteStrategies(ctx context.Context, reqs []*apisecurity.Aut // support 1. 支持按照 principal-id + principal-role 进行查询 // support 2. 支持普通的鉴权策略查询 func (svr *Server) GetStrategies(ctx context.Context, query map[string]string) *apiservice.BatchQueryResponse { - ctx, rsp := svr.verifyAuth(ctx, ReadOp, NotOwner) - if rsp != nil { - return api.NewAuthBatchQueryResponseWithMsg(apimodel.Code(rsp.GetCode().Value), rsp.Info.Value) + authCtx := authcommon.NewAcquireContext( + authcommon.WithRequestContext(ctx), + authcommon.WithOperation(authcommon.Read), + authcommon.WithModule(authcommon.AuthModule), + authcommon.WithMethod(authcommon.DescribeAuthPolicies), + ) + + if err := svr.userSvr.CheckCredential(authCtx); err != nil { + return api.NewAuthBatchQueryResponseWithMsg(authcommon.ConvertToErrCode(err), err.Error()) } + + checker := svr.GetAuthChecker() + cachetypes.AppendAuthPolicyPredicate(ctx, func(ctx context.Context, sd *authcommon.StrategyDetail) bool { + return checker.ResourcePredicate(authCtx, &authcommon.ResourceEntry{ + Type: apisecurity.ResourceType_PolicyRules, + ID: sd.ID, + }) + }) + return svr.nextSvr.GetStrategies(ctx, query) } // GetStrategy 获取策略详细 func (svr *Server) GetStrategy(ctx context.Context, strategy *apisecurity.AuthStrategy) *apiservice.Response { - ctx, rsp := svr.verifyAuth(ctx, ReadOp, NotOwner) - if rsp != nil { - return rsp + authCtx := authcommon.NewAcquireContext( + authcommon.WithRequestContext(ctx), + authcommon.WithOperation(authcommon.Read), + authcommon.WithModule(authcommon.AuthModule), + authcommon.WithMethod(authcommon.DescribeAuthPolicyDetail), + ) + + checker := svr.GetAuthChecker() + + if _, err := checker.CheckConsolePermission(authCtx); err != nil { + return api.NewResponseWithMsg(authcommon.ConvertToErrCode(err), err.Error()) } + + cachetypes.AppendAuthPolicyPredicate(ctx, func(ctx context.Context, sd *authcommon.StrategyDetail) bool { + return checker.ResourcePredicate(authCtx, &authcommon.ResourceEntry{ + Type: apisecurity.ResourceType_PolicyRules, + ID: sd.ID, + }) + }) + return svr.nextSvr.GetStrategy(ctx, strategy) } // GetPrincipalResources 获取某个 principal 的所有可操作资源列表 func (svr *Server) GetPrincipalResources(ctx context.Context, query map[string]string) *apiservice.Response { - ctx, rsp := svr.verifyAuth(ctx, ReadOp, NotOwner) - if rsp != nil { - return rsp + authCtx := authcommon.NewAcquireContext( + authcommon.WithRequestContext(ctx), + authcommon.WithOperation(authcommon.Read), + authcommon.WithModule(authcommon.AuthModule), + authcommon.WithMethod(authcommon.DescribePrincipalResources), + ) + + checker := svr.GetAuthChecker() + + if _, err := checker.CheckConsolePermission(authCtx); err != nil { + return api.NewResponseWithMsg(authcommon.ConvertToErrCode(err), err.Error()) } return svr.nextSvr.GetPrincipalResources(ctx, query) } @@ -132,57 +203,26 @@ func (svr *Server) GetAuthChecker() auth.AuthChecker { } // AfterResourceOperation 操作完资源的后置处理逻辑 -func (svr *Server) AfterResourceOperation(afterCtx *model.AcquireContext) error { +func (svr *Server) AfterResourceOperation(afterCtx *authcommon.AcquireContext) error { return svr.nextSvr.AfterResourceOperation(afterCtx) } -// verifyAuth 用于 user、group 以及 strategy 模块的鉴权工作检查 -func (svr *Server) verifyAuth(ctx context.Context, isWrite bool, - needOwner bool) (context.Context, *apiservice.Response) { - reqId := utils.ParseRequestID(ctx) - authToken := utils.ParseAuthToken(ctx) - - if authToken == "" { - log.Error("[Auth][Server] auth token is empty", utils.ZapRequestID(reqId)) - return nil, api.NewAuthResponse(apimodel.Code_EmptyAutToken) - } - - authCtx := model.NewAcquireContext( - model.WithRequestContext(ctx), - model.WithModule(model.AuthModule), - ) - - // case 1. 如果 error 不是 token 被禁止的 error,直接返回 - // case 2. 如果 error 是 token 被禁止,按下面情况判断 - // i. 如果当前只是一个数据的读取操作,则放通 - // ii. 如果当前是一个数据的写操作,则只能允许处于正常的 token 进行操作 - if err := svr.userSvr.CheckCredential(authCtx); err != nil { - log.Error("[Auth][Server] verify auth token", utils.ZapRequestID(reqId), zap.Error(err)) - return nil, api.NewAuthResponse(apimodel.Code_AuthTokenForbidden) - } - - attachVal, exist := authCtx.GetAttachment(model.TokenDetailInfoKey) - if !exist { - log.Error("[Auth][Server] token detail info not exist", utils.ZapRequestID(reqId)) - return nil, api.NewAuthResponse(apimodel.Code_TokenNotExisted) - } - - operateInfo := attachVal.(auth.OperatorInfo) - if isWrite && operateInfo.Disable { - log.Error("[Auth][Server] token is disabled", utils.ZapRequestID(reqId), - zap.String("operation", authCtx.GetMethod())) - return nil, api.NewAuthResponse(apimodel.Code_TokenDisabled) - } +// CreateRoles 批量创建角色 +func (svr *Server) CreateRoles(ctx context.Context, reqs []*apisecurity.Role) *apiservice.BatchWriteResponse { + return nil +} - if !operateInfo.IsUserToken { - log.Error("[Auth][Server] only user role can access this API", utils.ZapRequestID(reqId)) - return nil, api.NewAuthResponse(apimodel.Code_OperationRoleForbidden) - } +// UpdateRoles 批量更新角色 +func (svr *Server) UpdateRoles(ctx context.Context, reqs []*apisecurity.Role) *apiservice.BatchWriteResponse { + return nil +} - if needOwner && auth.IsSubAccount(operateInfo) { - log.Error("[Auth][Server] only admin/owner account can access this API", utils.ZapRequestID(reqId)) - return nil, api.NewAuthResponse(apimodel.Code_OperationRoleForbidden) - } +// DeleteRoles 批量删除角色 +func (svr *Server) DeleteRoles(ctx context.Context, reqs []*apisecurity.Role) *apiservice.BatchWriteResponse { + return nil +} - return authCtx.GetRequestContext(), nil +// GetRoles 查询角色列表 +func (svr *Server) GetRoles(ctx context.Context, query map[string]string) *apiservice.BatchQueryResponse { + return nil } diff --git a/auth/policy/inteceptor/paramcheck/server.go b/auth/policy/inteceptor/paramcheck/server.go new file mode 100644 index 000000000..3ae25a934 --- /dev/null +++ b/auth/policy/inteceptor/paramcheck/server.go @@ -0,0 +1,160 @@ +/** + * Tencent is pleased to support the open source community by making Polaris available. + * + * Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. + * + * Licensed under the BSD 3-Clause License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://opensource.org/licenses/BSD-3-Clause + * + * Unless required by applicable law or agreed to in writing, software distributed + * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR + * CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package paramcheck + +import ( + "context" + "strconv" + + apimodel "github.com/polarismesh/specification/source/go/api/v1/model" + apisecurity "github.com/polarismesh/specification/source/go/api/v1/security" + apiservice "github.com/polarismesh/specification/source/go/api/v1/service_manage" + "go.uber.org/zap" + + "github.com/polarismesh/polaris/auth" + cachetypes "github.com/polarismesh/polaris/cache/api" + api "github.com/polarismesh/polaris/common/api/v1" + "github.com/polarismesh/polaris/common/log" + authcommon "github.com/polarismesh/polaris/common/model/auth" + "github.com/polarismesh/polaris/common/utils" + "github.com/polarismesh/polaris/store" +) + +var ( + // StrategyFilterAttributes strategy filter attributes + StrategyFilterAttributes = map[string]bool{ + "id": true, + "name": true, + "owner": true, + "offset": true, + "limit": true, + "principal_id": true, + "principal_type": true, + "res_id": true, + "res_type": true, + "default": true, + "show_detail": true, + } +) + +func NewServer(nextSvr auth.StrategyServer) auth.StrategyServer { + return &Server{ + nextSvr: nextSvr, + } +} + +type Server struct { + nextSvr auth.StrategyServer + userSvr auth.UserServer +} + +// PolicyHelper implements auth.StrategyServer. +func (svr *Server) PolicyHelper() auth.PolicyHelper { + return svr.nextSvr.PolicyHelper() +} + +// Initialize 执行初始化动作 +func (svr *Server) Initialize(options *auth.Config, storage store.Store, cacheMgr cachetypes.CacheManager, userSvr auth.UserServer) error { + svr.userSvr = userSvr + return svr.nextSvr.Initialize(options, storage, cacheMgr, userSvr) +} + +// Name 策略管理server名称 +func (svr *Server) Name() string { + return svr.nextSvr.Name() +} + +// CreateStrategy 创建策略 +func (svr *Server) CreateStrategy(ctx context.Context, strategy *apisecurity.AuthStrategy) *apiservice.Response { + return svr.nextSvr.CreateStrategy(ctx, strategy) +} + +// UpdateStrategies 批量更新策略 +func (svr *Server) UpdateStrategies(ctx context.Context, reqs []*apisecurity.ModifyAuthStrategy) *apiservice.BatchWriteResponse { + return svr.nextSvr.UpdateStrategies(ctx, reqs) +} + +// DeleteStrategies 删除策略 +func (svr *Server) DeleteStrategies(ctx context.Context, reqs []*apisecurity.AuthStrategy) *apiservice.BatchWriteResponse { + return svr.nextSvr.DeleteStrategies(ctx, reqs) +} + +// GetStrategies 获取资源列表 +// support 1. 支持按照 principal-id + principal-role 进行查询 +// support 2. 支持普通的鉴权策略查询 +func (svr *Server) GetStrategies(ctx context.Context, query map[string]string) *apiservice.BatchQueryResponse { + log.Debug("[Auth][Strategy] origin get strategies query params", utils.RequestID(ctx), zap.Any("query", query)) + + searchFilters := make(map[string]string, len(query)) + for key, value := range query { + if _, ok := StrategyFilterAttributes[key]; !ok { + log.Errorf("[Auth][Strategy] get strategies attribute(%s) it not allowed", key) + return api.NewAuthBatchQueryResponseWithMsg(apimodel.Code_InvalidParameter, key+" is not allowed") + } + searchFilters[key] = value + } + + offset, limit, err := utils.ParseOffsetAndLimit(searchFilters) + + if err != nil { + return api.NewAuthBatchQueryResponse(apimodel.Code_InvalidParameter) + } + searchFilters["offset"] = strconv.FormatUint(uint64(offset), 10) + searchFilters["limit"] = strconv.FormatUint(uint64(limit), 10) + return svr.nextSvr.GetStrategies(ctx, query) +} + +// GetStrategy 获取策略详细 +func (svr *Server) GetStrategy(ctx context.Context, strategy *apisecurity.AuthStrategy) *apiservice.Response { + return svr.nextSvr.GetStrategy(ctx, strategy) +} + +// GetPrincipalResources 获取某个 principal 的所有可操作资源列表 +func (svr *Server) GetPrincipalResources(ctx context.Context, query map[string]string) *apiservice.Response { + return svr.nextSvr.GetPrincipalResources(ctx, query) +} + +// GetAuthChecker 获取鉴权检查器 +func (svr *Server) GetAuthChecker() auth.AuthChecker { + return svr.nextSvr.GetAuthChecker() +} + +// AfterResourceOperation 操作完资源的后置处理逻辑 +func (svr *Server) AfterResourceOperation(afterCtx *authcommon.AcquireContext) error { + return svr.nextSvr.AfterResourceOperation(afterCtx) +} + +// CreateRoles 批量创建角色 +func (svr *Server) CreateRoles(ctx context.Context, reqs []*apisecurity.Role) *apiservice.BatchWriteResponse { + return svr.nextSvr.CreateRoles(ctx, reqs) +} + +// UpdateRoles 批量更新角色 +func (svr *Server) UpdateRoles(ctx context.Context, reqs []*apisecurity.Role) *apiservice.BatchWriteResponse { + return svr.nextSvr.UpdateRoles(ctx, reqs) +} + +// DeleteRoles 批量删除角色 +func (svr *Server) DeleteRoles(ctx context.Context, reqs []*apisecurity.Role) *apiservice.BatchWriteResponse { + return svr.nextSvr.DeleteRoles(ctx, reqs) +} + +// GetRoles 查询角色列表 +func (svr *Server) GetRoles(ctx context.Context, query map[string]string) *apiservice.BatchQueryResponse { + return svr.nextSvr.GetRoles(ctx, query) +} diff --git a/auth/policy/role.go b/auth/policy/role.go new file mode 100644 index 000000000..2417883ba --- /dev/null +++ b/auth/policy/role.go @@ -0,0 +1,28 @@ +package policy + +import ( + "context" + + apisecurity "github.com/polarismesh/specification/source/go/api/v1/security" + apiservice "github.com/polarismesh/specification/source/go/api/v1/service_manage" +) + +// CreateRoles 批量创建角色 +func (svr *Server) CreateRoles(ctx context.Context, reqs []*apisecurity.Role) *apiservice.BatchWriteResponse { + return nil +} + +// UpdateRoles 批量更新角色 +func (svr *Server) UpdateRoles(ctx context.Context, reqs []*apisecurity.Role) *apiservice.BatchWriteResponse { + return nil +} + +// DeleteRoles 批量删除角色 +func (svr *Server) DeleteRoles(ctx context.Context, reqs []*apisecurity.Role) *apiservice.BatchWriteResponse { + return nil +} + +// GetRoles 查询角色列表 +func (svr *Server) GetRoles(ctx context.Context, query map[string]string) *apiservice.BatchQueryResponse { + return nil +} diff --git a/auth/policy/server.go b/auth/policy/server.go index 804be3123..5603f51c0 100644 --- a/auth/policy/server.go +++ b/auth/policy/server.go @@ -31,6 +31,7 @@ import ( "github.com/polarismesh/polaris/auth" cachetypes "github.com/polarismesh/polaris/cache/api" "github.com/polarismesh/polaris/common/model" + authcommon "github.com/polarismesh/polaris/common/model/auth" "github.com/polarismesh/polaris/common/utils" "github.com/polarismesh/polaris/plugin" "github.com/polarismesh/polaris/store" @@ -49,6 +50,10 @@ type AuthConfig struct { ConsoleStrict bool `json:"consoleStrict"` // ClientStrict 是否启用鉴权的严格模式,即对于没有任何鉴权策略的资源,也必须带上正确的token才能操作, 默认关闭 ClientStrict bool `json:"clientStrict"` + // CredibleHeaders 可信请求 Header + CredibleHeaders map[string]string + // OpenPrincipalDefaultPolicy 是否开启 principal 默认策略 + OpenPrincipalDefaultPolicy bool `json:"openPrincipalDefaultPolicy"` } // DefaultAuthConfig 返回一个默认的鉴权配置 @@ -68,12 +73,21 @@ func DefaultAuthConfig() *AuthConfig { type Server struct { options *AuthConfig storage store.Store - history plugin.History cacheMgr cachetypes.CacheManager - checker *DefaultAuthChecker + checker auth.AuthChecker userSvr auth.UserServer } +// PolicyHelper implements auth.StrategyServer. +func (svr *Server) PolicyHelper() auth.PolicyHelper { + return &DefaultPolicyHelper{ + options: svr.options, + storage: svr.storage, + cacheMgr: svr.cacheMgr, + checker: svr.checker, + } +} + // initialize func (svr *Server) Initialize(options *auth.Config, storage store.Store, cacheMgr cachetypes.CacheManager, userSvr auth.UserServer) error { svr.cacheMgr = cacheMgr @@ -86,14 +100,12 @@ func (svr *Server) Initialize(options *auth.Config, storage store.Store, cacheMg _ = cacheMgr.OpenResourceCache(cachetypes.ConfigEntry{ Name: cachetypes.StrategyRuleName, }) - // 获取History插件,注意:插件的配置在bootstrap已经设置好 - svr.history = plugin.GetHistory() - if svr.history == nil { - log.Warnf("Not Found History Log Plugin") - } - svr.checker = &DefaultAuthChecker{} - svr.checker.Initialize(svr.options, svr.storage, cacheMgr, userSvr) + checker := &DefaultAuthChecker{ + policyMgr: svr, + } + checker.Initialize(svr.options, svr.storage, cacheMgr, userSvr) + svr.checker = checker return nil } @@ -152,23 +164,17 @@ func (svr *Server) GetAuthChecker() auth.AuthChecker { // RecordHistory Server对外提供history插件的简单封装 func (svr *Server) RecordHistory(entry *model.RecordEntry) { - // 如果插件没有初始化,那么不记录history - if svr.history == nil { - return - } - // 如果数据为空,则不需要打印了 - if entry == nil { - return - } + plugin.GetHistory().Record(entry) +} - // 调用插件记录history - svr.history.Record(entry) +func (svr *Server) isOpenAuth() bool { + return svr.checker.IsOpenClientAuth() || svr.checker.IsOpenConsoleAuth() } // AfterResourceOperation 对于资源的添加删除操作,需要执行后置逻辑 // 所有子用户或者用户分组,都默认获得对所创建的资源的写权限 -func (svr *Server) AfterResourceOperation(afterCtx *model.AcquireContext) error { - if !svr.checker.IsOpenAuth() || afterCtx.GetOperation() == model.Read { +func (svr *Server) AfterResourceOperation(afterCtx *authcommon.AcquireContext) error { + if !svr.isOpenAuth() || afterCtx.GetOperation() == authcommon.Read { return nil } @@ -181,7 +187,7 @@ func (svr *Server) AfterResourceOperation(afterCtx *model.AcquireContext) error return nil } - attachVal, ok := afterCtx.GetAttachment(model.TokenDetailInfoKey) + attachVal, ok := afterCtx.GetAttachment(authcommon.TokenDetailInfoKey) if !ok { return nil } @@ -195,13 +201,13 @@ func (svr *Server) AfterResourceOperation(afterCtx *model.AcquireContext) error return nil } - addUserIds := afterCtx.GetAttachments()[model.LinkUsersKey].([]string) - addGroupIds := afterCtx.GetAttachments()[model.LinkGroupsKey].([]string) - removeUserIds := afterCtx.GetAttachments()[model.RemoveLinkUsersKey].([]string) - removeGroupIds := afterCtx.GetAttachments()[model.RemoveLinkGroupsKey].([]string) + addUserIds := afterCtx.GetAttachments()[authcommon.LinkUsersKey].([]string) + addGroupIds := afterCtx.GetAttachments()[authcommon.LinkGroupsKey].([]string) + removeUserIds := afterCtx.GetAttachments()[authcommon.RemoveLinkUsersKey].([]string) + removeGroupIds := afterCtx.GetAttachments()[authcommon.RemoveLinkGroupsKey].([]string) // 只有在创建一个资源的时候,才需要把当前的创建者一并加到里面去 - if afterCtx.GetOperation() == model.Create { + if afterCtx.GetOperation() == authcommon.Create { if tokenInfo.IsUserToken { addUserIds = append(addUserIds, tokenInfo.OperatorID) } else { @@ -210,28 +216,28 @@ func (svr *Server) AfterResourceOperation(afterCtx *model.AcquireContext) error } log.Info("[Auth][Server] add resource to principal default strategy", - zap.Any("resource", afterCtx.GetAttachments()[model.ResourceAttachmentKey]), + zap.Any("resource", afterCtx.GetAttachments()[authcommon.ResourceAttachmentKey]), zap.Any("add_user", addUserIds), zap.Any("add_group", addGroupIds), zap.Any("remove_user", removeUserIds), zap.Any("remove_group", removeGroupIds), ) // 添加某些用户、用户组与资源的默认授权关系 - if err := svr.handleUserStrategy(addUserIds, afterCtx, false); err != nil { + if err := svr.handleChangeUserPolicy(addUserIds, afterCtx, false); err != nil { log.Error("[Auth][Server] add user link resource", zap.Error(err)) return err } - if err := svr.handleGroupStrategy(addGroupIds, afterCtx, false); err != nil { + if err := svr.handleChangeUserGroupPolicy(addGroupIds, afterCtx, false); err != nil { log.Error("[Auth][Server] add group link resource", zap.Error(err)) return err } // 清理某些用户、用户组与资源的默认授权关系 - if err := svr.handleUserStrategy(removeUserIds, afterCtx, true); err != nil { + if err := svr.handleChangeUserPolicy(removeUserIds, afterCtx, true); err != nil { log.Error("[Auth][Server] remove user link resource", zap.Error(err)) return err } - if err := svr.handleGroupStrategy(removeGroupIds, afterCtx, true); err != nil { + if err := svr.handleChangeUserGroupPolicy(removeGroupIds, afterCtx, true); err != nil { log.Error("[Auth][Server] remove group link resource", zap.Error(err)) return err } @@ -240,7 +246,7 @@ func (svr *Server) AfterResourceOperation(afterCtx *model.AcquireContext) error } // handleUserStrategy -func (svr *Server) handleUserStrategy(userIds []string, afterCtx *model.AcquireContext, isRemove bool) error { +func (svr *Server) handleChangeUserPolicy(userIds []string, afterCtx *authcommon.AcquireContext, isRemove bool) error { for index := range utils.StringSliceDeDuplication(userIds) { userId := userIds[index] user := svr.userSvr.GetUserHelper().GetUser(context.TODO(), &apisecurity.User{ @@ -254,7 +260,7 @@ func (svr *Server) handleUserStrategy(userIds []string, afterCtx *model.AcquireC if ownerId == "" { ownerId = user.GetId().GetValue() } - if err := svr.handlerModifyDefaultStrategy(userId, ownerId, model.PrincipalUser, + if err := svr.changePrincipalPolicies(userId, ownerId, authcommon.PrincipalUser, afterCtx, isRemove); err != nil { return err } @@ -263,7 +269,7 @@ func (svr *Server) handleUserStrategy(userIds []string, afterCtx *model.AcquireC } // handleGroupStrategy -func (svr *Server) handleGroupStrategy(groupIds []string, afterCtx *model.AcquireContext, isRemove bool) error { +func (svr *Server) handleChangeUserGroupPolicy(groupIds []string, afterCtx *authcommon.AcquireContext, isRemove bool) error { for index := range utils.StringSliceDeDuplication(groupIds) { groupId := groupIds[index] group := svr.userSvr.GetUserHelper().GetGroup(context.TODO(), &apisecurity.UserGroup{ @@ -273,7 +279,7 @@ func (svr *Server) handleGroupStrategy(groupIds []string, afterCtx *model.Acquir return errors.New("not found target group") } ownerId := group.GetOwner().GetValue() - if err := svr.handlerModifyDefaultStrategy(groupId, ownerId, model.PrincipalGroup, + if err := svr.changePrincipalPolicies(groupId, ownerId, authcommon.PrincipalGroup, afterCtx, isRemove); err != nil { return err } @@ -282,10 +288,10 @@ func (svr *Server) handleGroupStrategy(groupIds []string, afterCtx *model.Acquir return nil } -// handlerModifyDefaultStrategy 处理默认策略的修改 +// changePrincipalPolicies 处理默认策略的修改 // case 1. 如果默认策略是全部放通 -func (svr *Server) handlerModifyDefaultStrategy(id, ownerId string, uType model.PrincipalType, - afterCtx *model.AcquireContext, cleanRealtion bool) error { +func (svr *Server) changePrincipalPolicies(id, ownerId string, uType authcommon.PrincipalType, + afterCtx *authcommon.AcquireContext, cleanRealtion bool) error { // Get the default policy rules strategy, err := svr.storage.GetDefaultStrategyDetailByPrincipal(id, uType) if err != nil { @@ -298,26 +304,26 @@ func (svr *Server) handlerModifyDefaultStrategy(id, ownerId string, uType model. } var ( - strategyResource = make([]model.StrategyResource, 0) + strategyResource = make([]authcommon.StrategyResource, 0) strategyId = strategy.ID ) - attachVal, ok := afterCtx.GetAttachment(model.ResourceAttachmentKey) + attachVal, ok := afterCtx.GetAttachment(authcommon.ResourceAttachmentKey) if !ok { return nil } - resources, ok := attachVal.(map[apisecurity.ResourceType][]model.ResourceEntry) + resources, ok := attachVal.(map[apisecurity.ResourceType][]authcommon.ResourceEntry) if !ok { return nil } // 资源删除时,清理该资源与所有策略的关联关系 - if afterCtx.GetOperation() == model.Delete { + if afterCtx.GetOperation() == authcommon.Delete { strategyId = "" } for rType, rIds := range resources { for i := range rIds { id := rIds[i] - strategyResource = append(strategyResource, model.StrategyResource{ + strategyResource = append(strategyResource, authcommon.StrategyResource{ StrategyID: strategyId, ResType: int32(rType), ResID: id.ID, @@ -333,11 +339,11 @@ func (svr *Server) handlerModifyDefaultStrategy(id, ownerId string, uType model. HappenTime: time.Now(), } - if afterCtx.GetOperation() == model.Delete || cleanRealtion { + if afterCtx.GetOperation() == authcommon.Delete || cleanRealtion { if err = svr.storage.RemoveStrategyResources(strategyResource); err != nil { log.Error("[Auth][Server] remove default strategy resource", zap.String("owner", ownerId), zap.String("id", id), - zap.String("type", model.PrincipalNames[uType]), zap.Error(err)) + zap.String("type", authcommon.PrincipalNames[uType]), zap.Error(err)) return err } entry.OperationType = model.ODelete @@ -348,7 +354,7 @@ func (svr *Server) handlerModifyDefaultStrategy(id, ownerId string, uType model. if err = svr.storage.LooseAddStrategyResources(strategyResource); err != nil { log.Error("[Auth][Server] update default strategy resource", zap.String("owner", ownerId), zap.String("id", id), zap.String("id", id), - zap.String("type", model.PrincipalNames[uType]), zap.Error(err)) + zap.String("type", authcommon.PrincipalNames[uType]), zap.Error(err)) return err } entry.OperationType = model.OUpdate diff --git a/auth/policy/server_test.go b/auth/policy/server_test.go new file mode 100644 index 000000000..325113678 --- /dev/null +++ b/auth/policy/server_test.go @@ -0,0 +1,302 @@ +/** + * Tencent is pleased to support the open source community by making Polaris available. + * + * Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. + * + * Licensed under the BSD 3-Clause License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://opensource.org/licenses/BSD-3-Clause + * + * Unless required by applicable law or agreed to in writing, software distributed + * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR + * CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package policy_test + +import ( + "errors" + "testing" + + "github.com/golang/mock/gomock" + "github.com/polarismesh/specification/source/go/api/v1/security" + "github.com/stretchr/testify/assert" + "google.golang.org/protobuf/types/known/wrapperspb" + + "github.com/polarismesh/polaris/auth" + authmock "github.com/polarismesh/polaris/auth/mock" + "github.com/polarismesh/polaris/auth/policy" + authcommon "github.com/polarismesh/polaris/common/model/auth" + "github.com/polarismesh/polaris/common/utils" + storemock "github.com/polarismesh/polaris/store/mock" +) + +func Test_AfterResourceOperation(t *testing.T) { + svr := &policy.Server{} + + t.Run("not_need_auth", func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + mockChecker := authmock.NewMockAuthChecker(ctrl) + svr.MockAuthChecker(mockChecker) + mockChecker.EXPECT().IsOpenClientAuth().Return(false).AnyTimes() + mockChecker.EXPECT().IsOpenConsoleAuth().Return(false).AnyTimes() + + err := svr.AfterResourceOperation(authcommon.NewAcquireContext()) + assert.NoError(t, err) + }) + + t.Run("read_op", func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + mockChecker := authmock.NewMockAuthChecker(ctrl) + svr.MockAuthChecker(mockChecker) + mockChecker.EXPECT().IsOpenClientAuth().Return(true).AnyTimes() + mockChecker.EXPECT().IsOpenConsoleAuth().Return(false).AnyTimes() + + err := svr.AfterResourceOperation(authcommon.NewAcquireContext( + authcommon.WithOperation(authcommon.Read), + )) + assert.NoError(t, err) + }) + + t.Run("from_client_not_auth", func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + mockChecker := authmock.NewMockAuthChecker(ctrl) + svr.MockAuthChecker(mockChecker) + mockChecker.EXPECT().IsOpenClientAuth().Return(false).AnyTimes() + mockChecker.EXPECT().IsOpenConsoleAuth().Return(true).AnyTimes() + + err := svr.AfterResourceOperation(authcommon.NewAcquireContext( + authcommon.WithOperation(authcommon.Create), + authcommon.WithFromClient(), + )) + assert.NoError(t, err) + }) + + t.Run("from_console_not_auth", func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + mockChecker := authmock.NewMockAuthChecker(ctrl) + svr.MockAuthChecker(mockChecker) + mockChecker.EXPECT().IsOpenClientAuth().Return(true).AnyTimes() + mockChecker.EXPECT().IsOpenConsoleAuth().Return(false).AnyTimes() + + err := svr.AfterResourceOperation(authcommon.NewAcquireContext( + authcommon.WithOperation(authcommon.Create), + authcommon.WithFromConsole(), + )) + assert.NoError(t, err) + }) + + t.Run("not_token_detial", func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + mockChecker := authmock.NewMockAuthChecker(ctrl) + svr.MockAuthChecker(mockChecker) + mockChecker.EXPECT().IsOpenClientAuth().Return(true).AnyTimes() + mockChecker.EXPECT().IsOpenConsoleAuth().Return(false).AnyTimes() + + err := svr.AfterResourceOperation(authcommon.NewAcquireContext( + authcommon.WithOperation(authcommon.Create), + authcommon.WithFromClient(), + )) + assert.NoError(t, err) + }) + + t.Run("invalid_token_detial", func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + mockChecker := authmock.NewMockAuthChecker(ctrl) + svr.MockAuthChecker(mockChecker) + mockChecker.EXPECT().IsOpenClientAuth().Return(true).AnyTimes() + mockChecker.EXPECT().IsOpenConsoleAuth().Return(false).AnyTimes() + + ctx := authcommon.NewAcquireContext( + authcommon.WithOperation(authcommon.Create), + authcommon.WithFromClient(), + ) + ctx.SetAttachment(authcommon.TokenDetailInfoKey, map[string]string{}) + err := svr.AfterResourceOperation(ctx) + assert.NoError(t, err) + }) + + t.Run("empty_token_detial", func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + mockChecker := authmock.NewMockAuthChecker(ctrl) + svr.MockAuthChecker(mockChecker) + mockChecker.EXPECT().IsOpenClientAuth().Return(true).AnyTimes() + mockChecker.EXPECT().IsOpenConsoleAuth().Return(false).AnyTimes() + + ctx := authcommon.NewAcquireContext( + authcommon.WithOperation(authcommon.Create), + authcommon.WithFromClient(), + ) + t.Run("origin_empty", func(t *testing.T) { + ctx.SetAttachment(authcommon.TokenDetailInfoKey, auth.OperatorInfo{ + Origin: "", + }) + err := svr.AfterResourceOperation(ctx) + assert.NoError(t, err) + }) + + t.Run("is_anonymous", func(t *testing.T) { + ctx.SetAttachment(authcommon.TokenDetailInfoKey, auth.OperatorInfo{ + Origin: "123", + Anonymous: true, + }) + err := svr.AfterResourceOperation(ctx) + assert.NoError(t, err) + }) + }) + + t.Run("change_principal_policy", func(t *testing.T) { + ctrl := gomock.NewController(t) + mockChecker := authmock.NewMockAuthChecker(ctrl) + svr.MockAuthChecker(mockChecker) + mockChecker.EXPECT().IsOpenClientAuth().Return(true).AnyTimes() + mockChecker.EXPECT().IsOpenConsoleAuth().Return(true).AnyTimes() + + ctx := authcommon.NewAcquireContext( + authcommon.WithOperation(authcommon.Create), + authcommon.WithFromClient(), + ) + + ownerId := "mock_auth_owner" + curUserId := "123" + + t.Run("user", func(t *testing.T) { + ctx.SetAttachment(authcommon.TokenDetailInfoKey, auth.OperatorInfo{ + Origin: curUserId, + OperatorID: curUserId, + OwnerID: ownerId, + Role: authcommon.OwnerUserRole, + IsUserToken: true, + }) + + initMockAcquireContext(ctx) + + t.Run("not_found_user", func(t *testing.T) { + userSvr := authmock.NewMockUserServer(ctrl) + mockHelper := authmock.NewMockUserHelper(ctrl) + + userSvr.EXPECT().GetUserHelper().Return(mockHelper) + mockHelper.EXPECT().GetUser(gomock.Any(), gomock.Any()).Return(nil) + + svr.MockUserServer(userSvr) + + err := svr.AfterResourceOperation(ctx) + assert.Error(t, err) + assert.Equal(t, "not found target user", err.Error()) + }) + + t.Run("found_user", func(t *testing.T) { + userSvr := authmock.NewMockUserServer(ctrl) + mockHelper := authmock.NewMockUserHelper(ctrl) + + userSvr.EXPECT().GetUserHelper().Return(mockHelper).AnyTimes() + mockHelper.EXPECT().GetUser(gomock.Any(), gomock.Any()).Return(&security.User{ + Id: wrapperspb.String(curUserId), + Owner: wrapperspb.String(ownerId), + }).AnyTimes() + + svr.MockUserServer(userSvr) + + t.Run("store_has_err", func(t *testing.T) { + sctrl := gomock.NewController(t) + defer sctrl.Finish() + mockStore := storemock.NewMockStore(sctrl) + + mockStore.EXPECT().GetDefaultStrategyDetailByPrincipal(gomock.Any(), gomock.Any()).Return(nil, errors.New("mock_err")) + svr.MockStore(mockStore) + + err := svr.AfterResourceOperation(ctx) + assert.Error(t, err) + assert.Equal(t, "mock_err", err.Error()) + }) + + t.Run("not_found_default_policy", func(t *testing.T) { + sctrl := gomock.NewController(t) + defer sctrl.Finish() + mockStore := storemock.NewMockStore(sctrl) + + mockStore.EXPECT().GetDefaultStrategyDetailByPrincipal(gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes() + svr.MockStore(mockStore) + + err := svr.AfterResourceOperation(ctx) + assert.Error(t, err) + assert.Equal(t, "not found default strategy rule", err.Error()) + }) + + t.Run("not_op_resource", func(t *testing.T) { + sctrl := gomock.NewController(t) + defer sctrl.Finish() + mockStore := storemock.NewMockStore(sctrl) + + mockStore.EXPECT().GetDefaultStrategyDetailByPrincipal(gomock.Any(), gomock.Any()).Return(&authcommon.StrategyDetail{}, nil) + svr.MockStore(mockStore) + + err := svr.AfterResourceOperation(ctx) + assert.NoError(t, err) + }) + + t.Run("invalid_op_resource", func(t *testing.T) { + sctrl := gomock.NewController(t) + defer sctrl.Finish() + mockStore := storemock.NewMockStore(sctrl) + + mockStore.EXPECT().GetDefaultStrategyDetailByPrincipal(gomock.Any(), gomock.Any()).Return(&authcommon.StrategyDetail{}, nil) + svr.MockStore(mockStore) + + ctx.SetAttachment(authcommon.ResourceAttachmentKey, map[string]interface{}{}) + + err := svr.AfterResourceOperation(ctx) + assert.NoError(t, err) + }) + + t.Run("delete_resource", func(t *testing.T) { + delCtx := authcommon.NewAcquireContext( + authcommon.WithOperation(authcommon.Delete), + authcommon.WithFromClient(), + ) + delCtx.SetAttachment(authcommon.TokenDetailInfoKey, auth.OperatorInfo{ + Origin: curUserId, + OperatorID: curUserId, + OwnerID: ownerId, + Role: authcommon.OwnerUserRole, + IsUserToken: true, + }) + + initMockAcquireContext(delCtx) + + sctrl := gomock.NewController(t) + defer sctrl.Finish() + mockStore := storemock.NewMockStore(sctrl) + mockStore.EXPECT().RemoveStrategyResources(gomock.Any()).DoAndReturn(func(args interface{}) error { + resources := args.([]authcommon.StrategyResource) + for i := range resources { + assert.True(t, resources[i].StrategyID == "", utils.MustJson(resources[i])) + } + return nil + }).Times(1) + + svr.MockStore(mockStore) + err := svr.AfterResourceOperation(delCtx) + assert.NoError(t, err) + }) + }) + }) + }) +} + +func initMockAcquireContext(ctx *authcommon.AcquireContext) { + ctx.SetAttachment(authcommon.LinkUsersKey, []string{}) + ctx.SetAttachment(authcommon.LinkGroupsKey, []string{}) + ctx.SetAttachment(authcommon.RemoveLinkUsersKey, []string{}) + ctx.SetAttachment(authcommon.RemoveLinkGroupsKey, []string{}) +} diff --git a/auth/policy/strategy.go b/auth/policy/strategy.go index 3640b6862..7e48ed12b 100644 --- a/auth/policy/strategy.go +++ b/auth/policy/strategy.go @@ -32,6 +32,7 @@ import ( "go.uber.org/zap" "google.golang.org/protobuf/types/known/wrapperspb" + cachetypes "github.com/polarismesh/polaris/cache/api" api "github.com/polarismesh/polaris/common/api/v1" "github.com/polarismesh/polaris/common/model" authcommon "github.com/polarismesh/polaris/common/model/auth" @@ -42,31 +43,12 @@ import ( type ( // StrategyDetail2Api strategy detail to *apisecurity.AuthStrategy func - StrategyDetail2Api func(user *model.StrategyDetail) *apisecurity.AuthStrategy + StrategyDetail2Api func(user *authcommon.StrategyDetail) *apisecurity.AuthStrategy ) -var ( - // StrategyFilterAttributes strategy filter attributes - StrategyFilterAttributes = map[string]bool{ - "id": true, - "name": true, - "owner": true, - "offset": true, - "limit": true, - "principal_id": true, - "principal_type": true, - "res_id": true, - "res_type": true, - "default": true, - "show_detail": true, - } -) - -// handleCreateStrategy 创建鉴权策略 -func (svr *Server) handleCreateStrategy(ctx context.Context, req *apisecurity.AuthStrategy) *apiservice.Response { - requestID := utils.ParseRequestID(ctx) +// CreateStrategy 创建鉴权策略 +func (svr *Server) CreateStrategy(ctx context.Context, req *apisecurity.AuthStrategy) *apiservice.Response { req.Owner = utils.NewStringValue(utils.ParseOwnerID(ctx)) - if checkErrResp := svr.checkCreateStrategy(req); checkErrResp != nil { return checkErrResp } @@ -74,21 +56,34 @@ func (svr *Server) handleCreateStrategy(ctx context.Context, req *apisecurity.Au req.Resources = svr.normalizeResource(req.Resources) data := svr.createAuthStrategyModel(req) - if err := svr.storage.AddStrategy(data); err != nil { - log.Error("[Auth][Strategy] create strategy into store", utils.ZapRequestID(requestID), + + tx, err := svr.storage.StartTx() + if err != nil { + log.Error("[Auth][Strategy] start tx", utils.RequestID(ctx), zap.Error(err)) + return api.NewAuthResponse(commonstore.StoreCode2APICode(err)) + } + defer func() { + _ = tx.Rollback() + }() + + if err := svr.storage.AddStrategy(tx, data); err != nil { + log.Error("[Auth][Strategy] create strategy into store", utils.RequestID(ctx), zap.Error(err)) return api.NewAuthResponse(commonstore.StoreCode2APICode(err)) } + if err := tx.Commit(); err != nil { + log.Error("[Auth][Strategy] create strategy commit tx", utils.RequestID(ctx), zap.Error(err)) + return api.NewAuthResponse(commonstore.StoreCode2APICode(err)) + } - log.Info("[Auth][Strategy] create strategy", utils.ZapRequestID(requestID), - zap.String("name", req.Name.GetValue())) + log.Info("[Auth][Strategy] create strategy", utils.RequestID(ctx), zap.String("name", req.Name.GetValue())) svr.RecordHistory(authStrategyRecordEntry(ctx, req, data, model.OCreate)) return api.NewAuthStrategyResponse(apimodel.Code_ExecuteSuccess, req) } -// handleUpdateStrategies 批量修改鉴权 -func (svr *Server) handleUpdateStrategies( +// UpdateStrategies 批量修改鉴权 +func (svr *Server) UpdateStrategies( ctx context.Context, reqs []*apisecurity.ModifyAuthStrategy) *apiservice.BatchWriteResponse { resp := api.NewAuthBatchWriteResponse(apimodel.Code_ExecuteSuccess) @@ -140,8 +135,8 @@ func (svr *Server) UpdateStrategy(ctx context.Context, req *apisecurity.ModifyAu return api.NewModifyAuthStrategyResponse(apimodel.Code_ExecuteSuccess, req) } -// handleDeleteStrategies 批量删除鉴权策略 -func (svr *Server) handleDeleteStrategies( +// DeleteStrategies 批量删除鉴权策略 +func (svr *Server) DeleteStrategies( ctx context.Context, reqs []*apisecurity.AuthStrategy) *apiservice.BatchWriteResponse { resp := api.NewAuthBatchWriteResponse(apimodel.Code_ExecuteSuccess) for index := range reqs { @@ -191,7 +186,7 @@ func (svr *Server) DeleteStrategy(ctx context.Context, req *apisecurity.AuthStra return api.NewAuthStrategyResponse(apimodel.Code_ExecuteSuccess, req) } -// GetStrategies 查询鉴权策略列表 +// GetStrategies 批量查询鉴权策略 // Case 1. 如果是以资源视角来查询鉴权策略,那么就会忽略自动根据账户类型进行数据查看的限制 // // eg. 比如当前子账户A想要查看资源R的相关的策略,那么不在会自动注入 principal_id 以及 principal_type 的查询条件 @@ -203,35 +198,18 @@ func (svr *Server) DeleteStrategy(ctx context.Context, req *apisecurity.AuthStra // a. 如果当前是超级管理账户,则按照传入的 query 进行查询即可 // b. 如果当前是主账户,则自动注入 owner 字段,即只能查看策略的 owner 是自己的策略 // c. 如果当前是子账户,则自动注入 principal_id 以及 principal_type 字段,即稚嫩查询与自己有关的策略 -func (svr *Server) handleGetStrategies(ctx context.Context, query map[string]string) *apiservice.BatchQueryResponse { - requestID := utils.ParseRequestID(ctx) - platformID := utils.ParsePlatformID(ctx) - log.Debug("[Auth][Strategy] origin get strategies query params", utils.ZapRequestID(requestID), - utils.ZapPlatformID(platformID), zap.Any("query", query)) - - showDetail := query["show_detail"] - - searchFilters := make(map[string]string, len(query)) - for key, value := range query { - if _, ok := StrategyFilterAttributes[key]; !ok { - log.Errorf("[Auth][Strategy] get strategies attribute(%s) it not allowed", key) - return api.NewAuthBatchQueryResponseWithMsg(apimodel.Code_InvalidParameter, key+" is not allowed") - } - searchFilters[key] = value - } - - searchFilters = ParseStrategySearchArgs(ctx, searchFilters) - offset, limit, err := utils.ParseOffsetAndLimit(searchFilters) - - if err != nil { - return api.NewAuthBatchQueryResponse(apimodel.Code_InvalidParameter) - } - - total, strategies, err := svr.storage.GetStrategies(searchFilters, offset, limit) +func (svr *Server) GetStrategies(ctx context.Context, filters map[string]string) *apiservice.BatchQueryResponse { + filters = ParseStrategySearchArgs(ctx, filters) + offset, limit, _ := utils.ParseOffsetAndLimit(filters) + total, strategies, err := svr.cacheMgr.AuthStrategy().Query(ctx, cachetypes.PolicySearchArgs{ + Filters: filters, + Offset: offset, + Limit: limit, + }) if err != nil { - log.Error("[Auth][Strategy] get strategies from store", zap.Any("query", searchFilters), - zap.Error(err)) + log.Error("[Auth][Strategy] get strategies from store", zap.Any("query", filters), + utils.RequestID(ctx), zap.Error(err)) return api.NewAuthBatchQueryResponse(commonstore.StoreCode2APICode(err)) } @@ -239,8 +217,8 @@ func (svr *Server) handleGetStrategies(ctx context.Context, query map[string]str resp.Amount = utils.NewUInt32Value(total) resp.Size = utils.NewUInt32Value(uint32(len(strategies))) - if strings.Compare(showDetail, "true") == 0 { - log.Info("[Auth][Strategy] fill strategy detail", utils.ZapRequestID(requestID)) + if strings.Compare(filters["show_detail"], "true") == 0 { + log.Info("[Auth][Strategy] fill strategy detail", utils.RequestID(ctx)) resp.AuthStrategies = enhancedAuthStrategy2Api(strategies, svr.authStrategyFull2Api) } else { resp.AuthStrategies = enhancedAuthStrategy2Api(strategies, svr.authStrategy2Api) @@ -281,7 +259,7 @@ func ParseStrategySearchArgs(ctx context.Context, searchFilters map[string]strin } } - if authcommon.ParseUserRole(ctx) != model.AdminUserRole { + if authcommon.ParseUserRole(ctx) != authcommon.AdminUserRole { // 如果当前账户不是 admin 角色,既不是走资源视角查看,也不是指定principal查看,那么只能查询当前操作用户被关联到的鉴权策略, if _, ok := searchFilters["res_id"]; !ok { // 设置 owner 参数,只能查看对应 owner 下的策略 @@ -290,7 +268,7 @@ func ParseStrategySearchArgs(ctx context.Context, searchFilters map[string]strin // 如果当前不是 owner 角色,那么只能查询与自己有关的策略 if !utils.ParseIsOwner(ctx) { searchFilters["principal_id"] = utils.ParseUserID(ctx) - searchFilters["principal_type"] = strconv.Itoa(int(model.PrincipalUser)) + searchFilters["principal_type"] = strconv.Itoa(int(authcommon.PrincipalUser)) } } } @@ -299,11 +277,11 @@ func ParseStrategySearchArgs(ctx context.Context, searchFilters map[string]strin return searchFilters } -// handleGetStrategy 根据策略ID获取详细的鉴权策略 +// GetStrategy 根据策略ID获取详细的鉴权策略 // Case 1 如果当前操作者是该策略 principal 中的一员,则可以查看 // Case 2 如果当前操作者是该策略的 owner,则可以查看 // Case 3 如果当前操作者是admin角色,直接查看 -func (svr *Server) handleGetStrategy(ctx context.Context, req *apisecurity.AuthStrategy) *apiservice.Response { +func (svr *Server) GetStrategy(ctx context.Context, req *apisecurity.AuthStrategy) *apiservice.Response { userId := utils.ParseUserID(ctx) isOwner := utils.ParseIsOwner(ctx) @@ -324,7 +302,7 @@ func (svr *Server) handleGetStrategy(ctx context.Context, req *apisecurity.AuthS var canView bool if isOwner { // 是否是本鉴权策略的 owner 账户, 或者是否是超级管理员, 是的话则快速跳过下面的检查 - canView = (ret.Owner == userId) || authcommon.ParseUserRole(ctx) == model.AdminUserRole + canView = (ret.Owner == userId) || authcommon.ParseUserRole(ctx) == authcommon.AdminUserRole } // 判断是否在该策略所属的成员列表中,如果自己在某个用户组,而该用户组又在这个策略的成员中,则也是可以查看的 @@ -334,11 +312,11 @@ func (svr *Server) handleGetStrategy(ctx context.Context, req *apisecurity.AuthS } for index := range ret.Principals { principal := ret.Principals[index] - if principal.PrincipalRole == model.PrincipalUser && principal.PrincipalID == userId { + if principal.PrincipalType == authcommon.PrincipalUser && principal.PrincipalID == userId { canView = true break } - if principal.PrincipalRole == model.PrincipalGroup { + if principal.PrincipalType == authcommon.PrincipalGroup { group := &apisecurity.UserGroup{ Id: wrapperspb.String(principal.PrincipalID), } @@ -361,8 +339,8 @@ func (svr *Server) handleGetStrategy(ctx context.Context, req *apisecurity.AuthS return api.NewAuthStrategyResponse(apimodel.Code_ExecuteSuccess, svr.authStrategyFull2Api(ret)) } -// handleGetPrincipalResources 获取某个principal可以获取到的所有资源ID数据信息 -func (svr *Server) handleGetPrincipalResources(ctx context.Context, query map[string]string) *apiservice.Response { +// GetPrincipalResources 获取某个principal可以获取到的所有资源ID数据信息 +func (svr *Server) GetPrincipalResources(ctx context.Context, query map[string]string) *apiservice.Response { requestID := utils.ParseRequestID(ctx) if len(query) == 0 { return api.NewAuthResponse(apimodel.Code_EmptyRequest) @@ -381,23 +359,23 @@ func (svr *Server) handleGetPrincipalResources(ctx context.Context, query map[st } principalRole, _ := strconv.ParseInt(principalType, 10, 64) - if err := model.CheckPrincipalType(int(principalRole)); err != nil { + if err := authcommon.CheckPrincipalType(int(principalRole)); err != nil { return api.NewAuthResponse(apimodel.Code_InvalidPrincipalType) } var ( - resources = make([]model.StrategyResource, 0, 20) + resources = make([]authcommon.StrategyResource, 0, 20) err error ) // 找这个用户所关联的用户组 - if model.PrincipalType(principalRole) == model.PrincipalUser { + if authcommon.PrincipalType(principalRole) == authcommon.PrincipalUser { groups := svr.userSvr.GetUserHelper().GetUserOwnGroup(ctx, &apisecurity.User{ Id: wrapperspb.String(principalId), }) for i := range groups { item := groups[i] - res, err := svr.storage.GetStrategyResources(item.GetId().GetValue(), model.PrincipalGroup) + res, err := svr.storage.GetStrategyResources(item.GetId().GetValue(), authcommon.PrincipalGroup) if err != nil { log.Error("[Auth][Strategy] get principal link resource", utils.ZapRequestID(requestID), zap.String("principal-id", principalId), zap.Any("principal-role", principalRole), zap.Error(err)) @@ -407,7 +385,7 @@ func (svr *Server) handleGetPrincipalResources(ctx context.Context, query map[st } } - pResources, err := svr.storage.GetStrategyResources(principalId, model.PrincipalType(principalRole)) + pResources, err := svr.storage.GetStrategyResources(principalId, authcommon.PrincipalType(principalRole)) if err != nil { log.Error("[Auth][Strategy] get principal link resource", utils.ZapRequestID(requestID), zap.String("principal-id", principalId), zap.Any("principal-role", principalRole), zap.Error(err)) @@ -423,7 +401,7 @@ func (svr *Server) handleGetPrincipalResources(ctx context.Context, query map[st }, } - svr.fillResourceInfo(tmp, &model.StrategyDetail{ + svr.fillResourceInfo(tmp, &authcommon.StrategyDetail{ Resources: resourceDeduplication(resources), }) @@ -431,7 +409,7 @@ func (svr *Server) handleGetPrincipalResources(ctx context.Context, query map[st } // enhancedAuthStrategy2Api -func enhancedAuthStrategy2Api(s []*model.StrategyDetail, fn StrategyDetail2Api) []*apisecurity.AuthStrategy { +func enhancedAuthStrategy2Api(s []*authcommon.StrategyDetail, fn StrategyDetail2Api) []*apisecurity.AuthStrategy { out := make([]*apisecurity.AuthStrategy, 0, len(s)) for k := range s { out = append(out, fn(s[k])) @@ -440,7 +418,7 @@ func enhancedAuthStrategy2Api(s []*model.StrategyDetail, fn StrategyDetail2Api) } // authStrategy2Api -func (svr *Server) authStrategy2Api(s *model.StrategyDetail) *apisecurity.AuthStrategy { +func (svr *Server) authStrategy2Api(s *authcommon.StrategyDetail) *apisecurity.AuthStrategy { if s == nil { return nil } @@ -461,7 +439,7 @@ func (svr *Server) authStrategy2Api(s *model.StrategyDetail) *apisecurity.AuthSt } // authStrategyFull2Api -func (svr *Server) authStrategyFull2Api(data *model.StrategyDetail) *apisecurity.AuthStrategy { +func (svr *Server) authStrategyFull2Api(data *authcommon.StrategyDetail) *apisecurity.AuthStrategy { if data == nil { return nil } @@ -470,7 +448,7 @@ func (svr *Server) authStrategyFull2Api(data *model.StrategyDetail) *apisecurity groups := make([]*wrappers.StringValue, 0, len(data.Principals)) for index := range data.Principals { principal := data.Principals[index] - if principal.PrincipalRole == model.PrincipalUser { + if principal.PrincipalType == authcommon.PrincipalUser { users = append(users, utils.NewStringValue(principal.PrincipalID)) } else { groups = append(groups, utils.NewStringValue(principal.PrincipalID)) @@ -495,8 +473,8 @@ func (svr *Server) authStrategyFull2Api(data *model.StrategyDetail) *apisecurity } // createAuthStrategyModel 创建鉴权策略的存储模型 -func (svr *Server) createAuthStrategyModel(strategy *apisecurity.AuthStrategy) *model.StrategyDetail { - ret := &model.StrategyDetail{ +func (svr *Server) createAuthStrategyModel(strategy *apisecurity.AuthStrategy) *authcommon.StrategyDetail { + ret := &authcommon.StrategyDetail{ ID: utils.NewUUID(), Name: strategy.Name.GetValue(), Action: apisecurity.AuthAction_READ_WRITE.String(), @@ -510,7 +488,7 @@ func (svr *Server) createAuthStrategyModel(strategy *apisecurity.AuthStrategy) * } // 收集涉及的资源信息 - resEntry := make([]model.StrategyResource, 0, 20) + resEntry := make([]authcommon.StrategyResource, 0, 20) resEntry = append(resEntry, svr.collectResEntry(ret.ID, apisecurity.ResourceType_Namespaces, strategy.GetResources().GetNamespaces(), false)...) resEntry = append(resEntry, svr.collectResEntry(ret.ID, apisecurity.ResourceType_Services, @@ -519,10 +497,10 @@ func (svr *Server) createAuthStrategyModel(strategy *apisecurity.AuthStrategy) * strategy.GetResources().GetConfigGroups(), false)...) // 收集涉及的 principal 信息 - principals := make([]model.Principal, 0, 20) - principals = append(principals, collectPrincipalEntry(ret.ID, model.PrincipalUser, + principals := make([]authcommon.Principal, 0, 20) + principals = append(principals, collectPrincipalEntry(ret.ID, authcommon.PrincipalUser, strategy.GetPrincipals().GetUsers())...) - principals = append(principals, collectPrincipalEntry(ret.ID, model.PrincipalGroup, + principals = append(principals, collectPrincipalEntry(ret.ID, authcommon.PrincipalGroup, strategy.GetPrincipals().GetGroups())...) ret.Resources = resEntry @@ -533,9 +511,9 @@ func (svr *Server) createAuthStrategyModel(strategy *apisecurity.AuthStrategy) * // updateAuthStrategyAttribute 更新计算鉴权策略的属性 func (svr *Server) updateAuthStrategyAttribute(ctx context.Context, strategy *apisecurity.ModifyAuthStrategy, - saved *model.StrategyDetail) (*model.ModifyStrategyDetail, bool) { + saved *authcommon.StrategyDetail) (*authcommon.ModifyStrategyDetail, bool) { var needUpdate bool - ret := &model.ModifyStrategyDetail{ + ret := &authcommon.ModifyStrategyDetail{ ID: strategy.Id.GetValue(), Name: saved.Name, Action: saved.Action, @@ -568,9 +546,9 @@ func (svr *Server) updateAuthStrategyAttribute(ctx context.Context, strategy *ap // computeResourceChange 计算资源的变化情况,判断是否涉及变更 func (svr *Server) computeResourceChange( - modify *model.ModifyStrategyDetail, strategy *apisecurity.ModifyAuthStrategy) bool { + modify *authcommon.ModifyStrategyDetail, strategy *apisecurity.ModifyAuthStrategy) bool { var needUpdate bool - addResEntry := make([]model.StrategyResource, 0) + addResEntry := make([]authcommon.StrategyResource, 0) addResEntry = append(addResEntry, svr.collectResEntry(modify.ID, apisecurity.ResourceType_Namespaces, strategy.GetAddResources().GetNamespaces(), false)...) addResEntry = append(addResEntry, svr.collectResEntry(modify.ID, apisecurity.ResourceType_Services, @@ -583,7 +561,7 @@ func (svr *Server) computeResourceChange( modify.AddResources = addResEntry } - removeResEntry := make([]model.StrategyResource, 0) + removeResEntry := make([]authcommon.StrategyResource, 0) removeResEntry = append(removeResEntry, svr.collectResEntry(modify.ID, apisecurity.ResourceType_Namespaces, strategy.GetRemoveResources().GetNamespaces(), true)...) removeResEntry = append(removeResEntry, svr.collectResEntry(modify.ID, apisecurity.ResourceType_Services, @@ -600,12 +578,12 @@ func (svr *Server) computeResourceChange( } // computePrincipalChange 计算 principal 的变化情况,判断是否涉及变更 -func computePrincipalChange(modify *model.ModifyStrategyDetail, strategy *apisecurity.ModifyAuthStrategy) bool { +func computePrincipalChange(modify *authcommon.ModifyStrategyDetail, strategy *apisecurity.ModifyAuthStrategy) bool { var needUpdate bool - addPrincipals := make([]model.Principal, 0) - addPrincipals = append(addPrincipals, collectPrincipalEntry(modify.ID, model.PrincipalUser, + addPrincipals := make([]authcommon.Principal, 0) + addPrincipals = append(addPrincipals, collectPrincipalEntry(modify.ID, authcommon.PrincipalUser, strategy.GetAddPrincipals().GetUsers())...) - addPrincipals = append(addPrincipals, collectPrincipalEntry(modify.ID, model.PrincipalGroup, + addPrincipals = append(addPrincipals, collectPrincipalEntry(modify.ID, authcommon.PrincipalGroup, strategy.GetAddPrincipals().GetGroups())...) if len(addPrincipals) != 0 { @@ -613,10 +591,10 @@ func computePrincipalChange(modify *model.ModifyStrategyDetail, strategy *apisec modify.AddPrincipals = addPrincipals } - removePrincipals := make([]model.Principal, 0) - removePrincipals = append(removePrincipals, collectPrincipalEntry(modify.ID, model.PrincipalUser, + removePrincipals := make([]authcommon.Principal, 0) + removePrincipals = append(removePrincipals, collectPrincipalEntry(modify.ID, authcommon.PrincipalUser, strategy.GetRemovePrincipals().GetUsers())...) - removePrincipals = append(removePrincipals, collectPrincipalEntry(modify.ID, model.PrincipalGroup, + removePrincipals = append(removePrincipals, collectPrincipalEntry(modify.ID, authcommon.PrincipalGroup, strategy.GetRemovePrincipals().GetGroups())...) if len(removePrincipals) != 0 { @@ -627,10 +605,10 @@ func computePrincipalChange(modify *model.ModifyStrategyDetail, strategy *apisec return needUpdate } -// collectResEntry 将资源ID转换为对应的 []model.StrategyResource 数组 +// collectResEntry 将资源ID转换为对应的 []authcommon.StrategyResource 数组 func (svr *Server) collectResEntry(ruleId string, resType apisecurity.ResourceType, - res []*apisecurity.StrategyResourceEntry, delete bool) []model.StrategyResource { - resEntries := make([]model.StrategyResource, 0, len(res)+1) + res []*apisecurity.StrategyResourceEntry, delete bool) []authcommon.StrategyResource { + resEntries := make([]authcommon.StrategyResource, 0, len(res)+1) if len(res) == 0 { return resEntries } @@ -640,7 +618,7 @@ func (svr *Server) collectResEntry(ruleId string, resType apisecurity.ResourceTy if !delete { // 归一化处理 if res[index].GetId().GetValue() == "*" || res[index].GetName().GetValue() == "*" { - return []model.StrategyResource{ + return []authcommon.StrategyResource{ { StrategyID: ruleId, ResType: int32(resType), @@ -650,7 +628,7 @@ func (svr *Server) collectResEntry(ruleId string, resType apisecurity.ResourceTy } } - entry := model.StrategyResource{ + entry := authcommon.StrategyResource{ StrategyID: ruleId, ResType: int32(resType), ResID: res[index].GetId().GetValue(), @@ -662,18 +640,18 @@ func (svr *Server) collectResEntry(ruleId string, resType apisecurity.ResourceTy return resEntries } -// collectPrincipalEntry 将 Principal 转换为对应的 []model.Principal 数组 -func collectPrincipalEntry(ruleID string, uType model.PrincipalType, res []*apisecurity.Principal) []model.Principal { - principals := make([]model.Principal, 0, len(res)+1) +// collectPrincipalEntry 将 Principal 转换为对应的 []authcommon.Principal 数组 +func collectPrincipalEntry(ruleID string, uType authcommon.PrincipalType, res []*apisecurity.Principal) []authcommon.Principal { + principals := make([]authcommon.Principal, 0, len(res)+1) if len(res) == 0 { return principals } for index := range res { - principals = append(principals, model.Principal{ + principals = append(principals, authcommon.Principal{ StrategyID: ruleID, PrincipalID: res[index].GetId().GetValue(), - PrincipalRole: uType, + PrincipalType: uType, }) } @@ -705,9 +683,9 @@ func (svr *Server) checkCreateStrategy(req *apisecurity.AuthStrategy) *apiservic // Case 1. 修改的是默认鉴权策略的话,只能修改资源,不能添加用户 or 用户组 // Case 2. 鉴权策略只能被自己的 owner 对应的用户修改 func (svr *Server) checkUpdateStrategy(ctx context.Context, req *apisecurity.ModifyAuthStrategy, - saved *model.StrategyDetail) *apiservice.Response { + saved *authcommon.StrategyDetail) *apiservice.Response { userId := utils.ParseUserID(ctx) - if authcommon.ParseUserRole(ctx) != model.AdminUserRole { + if authcommon.ParseUserRole(ctx) != authcommon.AdminUserRole { if !utils.ParseIsOwner(ctx) || userId != saved.Owner { log.Error("[Auth][Strategy] modify strategy denied, current user not owner", utils.ZapRequestID(utils.ParseRequestID(ctx)), @@ -727,7 +705,7 @@ func (svr *Server) checkUpdateStrategy(ctx context.Context, req *apisecurity.Mod } // 主账户的默认策略禁止编辑 - if len(saved.Principals) == 1 && saved.Principals[0].PrincipalRole == model.PrincipalUser { + if len(saved.Principals) == 1 && saved.Principals[0].PrincipalType == authcommon.PrincipalUser { if saved.Principals[0].PrincipalID == utils.ParseOwnerID(ctx) { return api.NewAuthResponse(apimodel.Code_NotAllowModifyOwnerDefaultStrategy) } @@ -753,7 +731,7 @@ func (svr *Server) checkUpdateStrategy(ctx context.Context, req *apisecurity.Mod } // authStrategyRecordEntry 转换为鉴权策略的记录结构体 -func authStrategyRecordEntry(ctx context.Context, req *apisecurity.AuthStrategy, md *model.StrategyDetail, +func authStrategyRecordEntry(ctx context.Context, req *apisecurity.AuthStrategy, md *authcommon.StrategyDetail, operationType model.OperationType) *model.RecordEntry { marshaler := jsonpb.Marshaler{} @@ -773,7 +751,7 @@ func authStrategyRecordEntry(ctx context.Context, req *apisecurity.AuthStrategy, // authModifyStrategyRecordEntry func authModifyStrategyRecordEntry( - ctx context.Context, req *apisecurity.ModifyAuthStrategy, md *model.ModifyStrategyDetail, + ctx context.Context, req *apisecurity.ModifyAuthStrategy, md *authcommon.ModifyStrategyDetail, operationType model.OperationType) *model.RecordEntry { marshaler := jsonpb.Marshaler{} @@ -883,7 +861,6 @@ func (svr *Server) normalizeResource(resources *apisecurity.StrategyResources) * break } } - services := resources.GetServices() for index := range services { val := services[index] @@ -894,17 +871,16 @@ func (svr *Server) normalizeResource(resources *apisecurity.StrategyResources) * break } } - return resources } // fillPrincipalInfo 填充 principal 摘要信息 -func (svr *Server) fillPrincipalInfo(resp *apisecurity.AuthStrategy, data *model.StrategyDetail) { +func (svr *Server) fillPrincipalInfo(resp *apisecurity.AuthStrategy, data *authcommon.StrategyDetail) { users := make([]*apisecurity.Principal, 0, len(data.Principals)) groups := make([]*apisecurity.Principal, 0, len(data.Principals)) for index := range data.Principals { principal := data.Principals[index] - if principal.PrincipalRole == model.PrincipalUser { + if principal.PrincipalType == authcommon.PrincipalUser { user := svr.userSvr.GetUserHelper().GetUser(context.TODO(), &apisecurity.User{ Id: wrapperspb.String(principal.PrincipalID), }) @@ -936,7 +912,7 @@ func (svr *Server) fillPrincipalInfo(resp *apisecurity.AuthStrategy, data *model } // fillResourceInfo 填充资源摘要信息 -func (svr *Server) fillResourceInfo(resp *apisecurity.AuthStrategy, data *model.StrategyDetail) { +func (svr *Server) fillResourceInfo(resp *apisecurity.AuthStrategy, data *authcommon.StrategyDetail) { namespaces := make([]*apisecurity.StrategyResourceEntry, 0, len(data.Resources)) services := make([]*apisecurity.StrategyResourceEntry, 0, len(data.Resources)) configGroups := make([]*apisecurity.StrategyResourceEntry, 0, len(data.Resources)) @@ -1062,9 +1038,9 @@ func (f *resourceFilter) GetFilter(t apisecurity.ResourceType) (map[string]struc } // filter different types of Strategy resources -func resourceDeduplication(resources []model.StrategyResource) []model.StrategyResource { +func resourceDeduplication(resources []authcommon.StrategyResource) []authcommon.StrategyResource { rLen := len(resources) - ret := make([]model.StrategyResource, 0, rLen) + ret := make([]authcommon.StrategyResource, 0, rLen) rf := resourceFilter{ ns: make(map[string]struct{}, rLen), svc: make(map[string]struct{}, rLen), diff --git a/auth/policy/strategy_test.go b/auth/policy/strategy_test.go index 4ce9a9934..5626fd7d6 100644 --- a/auth/policy/strategy_test.go +++ b/auth/policy/strategy_test.go @@ -37,23 +37,24 @@ import ( api "github.com/polarismesh/polaris/common/api/v1" "github.com/polarismesh/polaris/common/eventhub" "github.com/polarismesh/polaris/common/model" + authcommon "github.com/polarismesh/polaris/common/model/auth" "github.com/polarismesh/polaris/common/utils" storemock "github.com/polarismesh/polaris/store/mock" ) type StrategyTest struct { - admin *model.User - ownerOne *model.User - ownerTwo *model.User + admin *authcommon.User + ownerOne *authcommon.User + ownerTwo *authcommon.User namespaces []*model.Namespace services []*model.Service - strategies []*model.StrategyDetail - allStrategies []*model.StrategyDetail - defaultStrategies []*model.StrategyDetail + strategies []*authcommon.StrategyDetail + allStrategies []*authcommon.StrategyDetail + defaultStrategies []*authcommon.StrategyDetail - users []*model.User - groups []*model.UserGroupDetail + users []*authcommon.User + groups []*authcommon.UserGroupDetail storage *storemock.MockStore cacheMgn *cache.CacheManager @@ -80,7 +81,7 @@ func newStrategyTest(t *testing.T) *StrategyTest { serviceMap := convertServiceSliceToMap(services) defaultStrategies, strategies := createMockStrategy(users, groups, services[:len(users)+len(groups)]) - allStrategies := make([]*model.StrategyDetail, 0, len(defaultStrategies)+len(strategies)) + allStrategies := make([]*authcommon.StrategyDetail, 0, len(defaultStrategies)+len(strategies)) allStrategies = append(allStrategies, defaultStrategies...) allStrategies = append(allStrategies, strategies...) @@ -90,7 +91,7 @@ func newStrategyTest(t *testing.T) *StrategyTest { storage.EXPECT().GetUnixSecond(gomock.Any()).AnyTimes().Return(time.Now().Unix(), nil) storage.EXPECT().GetUsersForCache(gomock.Any(), gomock.Any()).AnyTimes().Return(users, nil) storage.EXPECT().GetGroupsForCache(gomock.Any(), gomock.Any()).AnyTimes().Return(groups, nil) - storage.EXPECT().GetStrategyDetailsForCache(gomock.Any(), gomock.Any()).AnyTimes().Return(allStrategies, nil) + storage.EXPECT().GetMoreStrategies(gomock.Any(), gomock.Any()).AnyTimes().Return(allStrategies, nil) storage.EXPECT().GetMoreNamespaces(gomock.Any()).AnyTimes().Return(namespaces, nil) storage.EXPECT().GetMoreServices(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes().Return(serviceMap, nil) storage.EXPECT().GetStrategyResources(gomock.Eq(users[1].ID), gomock.Any()).AnyTimes().Return(strategies[1].Resources, nil) @@ -141,7 +142,7 @@ func newStrategyTest(t *testing.T) *StrategyTest { "salt": "polarismesh@2021", }, }, - }, storage, cacheMgn) + }, storage, nil, cacheMgn) _, svr, err := newPolicyServer() if err != nil { @@ -217,7 +218,7 @@ func Test_CreateStrategy(t *testing.T) { _ = strategyTest.cacheMgn.TestUpdate() t.Run("正常创建鉴权策略", func(t *testing.T) { - strategyTest.storage.EXPECT().AddStrategy(gomock.Any()).Return(nil) + strategyTest.storage.EXPECT().AddStrategy(gomock.Any(), gomock.Any()).Return(nil) valCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, strategyTest.users[0].Token) strategyId := utils.NewUUID() @@ -776,7 +777,7 @@ func Test_parseStrategySearchArgs(t *testing.T) { ctx: func() context.Context { ctx := context.WithValue(context.Background(), utils.ContextOwnerIDKey, "owner") ctx = context.WithValue(ctx, utils.ContextUserIDKey, "user") - ctx = context.WithValue(ctx, utils.ContextUserRoleIDKey, model.OwnerUserRole) + ctx = context.WithValue(ctx, utils.ContextUserRoleIDKey, authcommon.OwnerUserRole) ctx = context.WithValue(ctx, utils.ContextIsOwnerKey, true) return ctx }(), @@ -795,7 +796,7 @@ func Test_parseStrategySearchArgs(t *testing.T) { ctx: func() context.Context { ctx := context.WithValue(context.Background(), utils.ContextOwnerIDKey, "owner") ctx = context.WithValue(ctx, utils.ContextUserIDKey, "user") - ctx = context.WithValue(ctx, utils.ContextUserRoleIDKey, model.OwnerUserRole) + ctx = context.WithValue(ctx, utils.ContextUserRoleIDKey, authcommon.OwnerUserRole) ctx = context.WithValue(ctx, utils.ContextIsOwnerKey, true) return ctx }(), @@ -814,7 +815,7 @@ func Test_parseStrategySearchArgs(t *testing.T) { ctx: func() context.Context { ctx := context.WithValue(context.Background(), utils.ContextOwnerIDKey, "owner") ctx = context.WithValue(ctx, utils.ContextUserIDKey, "user") - ctx = context.WithValue(ctx, utils.ContextUserRoleIDKey, model.SubAccountUserRole) + ctx = context.WithValue(ctx, utils.ContextUserRoleIDKey, authcommon.SubAccountUserRole) ctx = context.WithValue(ctx, utils.ContextIsOwnerKey, false) return ctx }(), @@ -834,7 +835,7 @@ func Test_parseStrategySearchArgs(t *testing.T) { ctx: func() context.Context { ctx := context.WithValue(context.Background(), utils.ContextOwnerIDKey, "owner") ctx = context.WithValue(ctx, utils.ContextUserIDKey, "user") - ctx = context.WithValue(ctx, utils.ContextUserRoleIDKey, model.OwnerUserRole) + ctx = context.WithValue(ctx, utils.ContextUserRoleIDKey, authcommon.OwnerUserRole) ctx = context.WithValue(ctx, utils.ContextIsOwnerKey, true) return ctx }(), @@ -853,7 +854,7 @@ func Test_parseStrategySearchArgs(t *testing.T) { ctx: func() context.Context { ctx := context.WithValue(context.Background(), utils.ContextOwnerIDKey, "owner") ctx = context.WithValue(ctx, utils.ContextUserIDKey, "user") - ctx = context.WithValue(ctx, utils.ContextUserRoleIDKey, model.OwnerUserRole) + ctx = context.WithValue(ctx, utils.ContextUserRoleIDKey, authcommon.OwnerUserRole) ctx = context.WithValue(ctx, utils.ContextIsOwnerKey, true) return ctx }(), diff --git a/auth/policy/test_export.go b/auth/policy/test_export.go new file mode 100644 index 000000000..9c693f585 --- /dev/null +++ b/auth/policy/test_export.go @@ -0,0 +1,38 @@ +/** + * Tencent is pleased to support the open source community by making Polaris available. + * + * Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. + * + * Licensed under the BSD 3-Clause License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://opensource.org/licenses/BSD-3-Clause + * + * Unless required by applicable law or agreed to in writing, software distributed + * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR + * CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package policy + +import ( + "github.com/polarismesh/polaris/auth" + "github.com/polarismesh/polaris/store" +) + +// MockAuthChecker mock auth.AuthChecker for unit test +func (svr *Server) MockAuthChecker(checker auth.AuthChecker) { + svr.checker = checker +} + +// MockStore mock store.Store for unit test +func (svr *Server) MockStore(storage store.Store) { + svr.storage = storage +} + +// MockUserServer mock auth.UserServer for unit test +func (svr *Server) MockUserServer(userSvr auth.UserServer) { + svr.userSvr = userSvr +} diff --git a/auth/policy/utils_test.go b/auth/policy/utils_test.go new file mode 100644 index 000000000..7ec9e822b --- /dev/null +++ b/auth/policy/utils_test.go @@ -0,0 +1,84 @@ +/** + * Tencent is pleased to support the open source community by making Polaris available. + * + * Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. + * + * Licensed under the BSD 3-Clause License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://opensource.org/licenses/BSD-3-Clause + * + * Unless required by applicable law or agreed to in writing, software distributed + * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR + * CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package policy_test + +import ( + "strings" + "testing" + + "github.com/golang/protobuf/ptypes/wrappers" + "github.com/pkg/errors" + "github.com/stretchr/testify/assert" + + "github.com/polarismesh/polaris/auth/policy" + "github.com/polarismesh/polaris/common/utils" +) + +func TestCheckName(t *testing.T) { + tests := []struct { + name string + input *wrappers.StringValue + expected error + }{ + { + name: "测试空名称", + input: nil, + expected: errors.New(utils.NilErrString), + }, + { + name: "测试空名称", + input: utils.NewStringValue(""), + expected: errors.New(utils.EmptyErrString), + }, + { + name: "测试非法用户名", + input: utils.NewStringValue("polariadmin"), + expected: errors.New("illegal username"), + }, + { + name: "测试名称长度超过限制", + input: utils.NewStringValue(strings.Repeat("a", utils.MaxNameLength+1)), + expected: errors.New("name too long"), + }, + { + name: "测试包含无效字符的名称", + input: utils.NewStringValue("invalid*name"), + expected: errors.New("name contains invalid character"), + }, + { + name: "测试有效的名称", + input: utils.NewStringValue("valid_name"), + expected: nil, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + actual := policy.CheckName(tc.input) + if tc.expected == nil && actual != nil { + t.Fatal(tc.name) + } + if tc.expected != nil && actual == nil { + t.Fatal(nil) + } + if tc.expected != nil && actual != nil { + assert.Equal(t, tc.expected.Error(), actual.Error()) + } + }) + } +} diff --git a/auth/user/common_test.go b/auth/user/common_test.go index e50624246..a5479c62a 100644 --- a/auth/user/common_test.go +++ b/auth/user/common_test.go @@ -32,6 +32,7 @@ import ( "github.com/polarismesh/polaris/cache" "github.com/polarismesh/polaris/common/metrics" "github.com/polarismesh/polaris/common/model" + authcommon "github.com/polarismesh/polaris/common/model/auth" "github.com/polarismesh/polaris/common/utils" storemock "github.com/polarismesh/polaris/store/mock" ) @@ -120,8 +121,8 @@ func createMockService(namespaces []*model.Namespace) []*model.Service { } // createMockUser 默认 users[0] 为 owner 用户 -func createMockUser(total int, prefix ...string) []*model.User { - users := make([]*model.User, 0, total) +func createMockUser(total int, prefix ...string) []*authcommon.User { + users := make([]*authcommon.User, 0, total) ownerId := utils.NewUUID() @@ -137,7 +138,7 @@ func createMockUser(total int, prefix ...string) []*model.User { } pwd, _ := bcrypt.GenerateFromPassword([]byte("polaris"), bcrypt.DefaultCost) token, _ := defaultuser.CreateToken(id, "", "polarismesh@2021") - users = append(users, &model.User{ + users = append(users, &authcommon.User{ ID: id, Name: fmt.Sprintf(nameTemp, i), Password: string(pwd), @@ -150,11 +151,11 @@ func createMockUser(total int, prefix ...string) []*model.User { Source: "Polaris", Mobile: "", Email: "", - Type: func() model.UserRoleType { + Type: func() authcommon.UserRoleType { if id == ownerId { - return model.OwnerUserRole + return authcommon.OwnerUserRole } - return model.SubAccountUserRole + return authcommon.SubAccountUserRole }(), Token: token, TokenEnable: true, @@ -185,8 +186,8 @@ func createApiMockUser(total int, prefix ...string) []*apisecurity.User { return users } -func createMockUserGroup(users []*model.User) []*model.UserGroupDetail { - groups := make([]*model.UserGroupDetail, 0, len(users)) +func createMockUserGroup(users []*authcommon.User) []*authcommon.UserGroupDetail { + groups := make([]*authcommon.UserGroupDetail, 0, len(users)) for i := range users { user := users[i] @@ -194,8 +195,8 @@ func createMockUserGroup(users []*model.User) []*model.UserGroupDetail { token, _ := defaultuser.CreateToken("", id, _defaultSalt) - groups = append(groups, &model.UserGroupDetail{ - UserGroup: &model.UserGroup{ + groups = append(groups, &authcommon.UserGroupDetail{ + UserGroup: &authcommon.UserGroup{ ID: id, Name: fmt.Sprintf("test-group-%d", i), Owner: users[0].ID, @@ -217,9 +218,9 @@ func createMockUserGroup(users []*model.User) []*model.UserGroupDetail { // createMockApiUserGroup func createMockApiUserGroup(users []*apisecurity.User) []*apisecurity.UserGroup { - musers := make([]*model.User, 0, len(users)) + musers := make([]*authcommon.User, 0, len(users)) for i := range users { - musers = append(musers, &model.User{ + musers = append(musers, &authcommon.User{ ID: users[i].GetId().GetValue(), }) } @@ -244,9 +245,9 @@ func createMockApiUserGroup(users []*apisecurity.User) []*apisecurity.UserGroup return ret } -func createMockStrategy(users []*model.User, groups []*model.UserGroupDetail, services []*model.Service) ([]*model.StrategyDetail, []*model.StrategyDetail) { - strategies := make([]*model.StrategyDetail, 0, len(users)+len(groups)) - defaultStrategies := make([]*model.StrategyDetail, 0, len(users)+len(groups)) +func createMockStrategy(users []*authcommon.User, groups []*authcommon.UserGroupDetail, services []*model.Service) ([]*authcommon.StrategyDetail, []*authcommon.StrategyDetail) { + strategies := make([]*authcommon.StrategyDetail, 0, len(users)+len(groups)) + defaultStrategies := make([]*authcommon.StrategyDetail, 0, len(users)+len(groups)) owner := "" for i := 0; i < len(users); i++ { @@ -261,20 +262,20 @@ func createMockStrategy(users []*model.User, groups []*model.UserGroupDetail, se user := users[i] service := services[i] id := utils.NewUUID() - strategies = append(strategies, &model.StrategyDetail{ + strategies = append(strategies, &authcommon.StrategyDetail{ ID: id, Name: fmt.Sprintf("strategy_user_%s_%d", user.Name, i), Action: apisecurity.AuthAction_READ_WRITE.String(), Comment: "", - Principals: []model.Principal{ + Principals: []authcommon.Principal{ { PrincipalID: user.ID, - PrincipalRole: model.PrincipalUser, + PrincipalType: authcommon.PrincipalUser, }, }, Default: false, Owner: owner, - Resources: []model.StrategyResource{ + Resources: []authcommon.StrategyResource{ { StrategyID: id, ResType: int32(apisecurity.ResourceType_Namespaces), @@ -292,20 +293,20 @@ func createMockStrategy(users []*model.User, groups []*model.UserGroupDetail, se ModifyTime: time.Time{}, }) - defaultStrategies = append(defaultStrategies, &model.StrategyDetail{ + defaultStrategies = append(defaultStrategies, &authcommon.StrategyDetail{ ID: id, Name: fmt.Sprintf("strategy_default_user_%s_%d", user.Name, i), Action: apisecurity.AuthAction_READ_WRITE.String(), Comment: "", - Principals: []model.Principal{ + Principals: []authcommon.Principal{ { PrincipalID: user.ID, - PrincipalRole: model.PrincipalUser, + PrincipalType: authcommon.PrincipalUser, }, }, Default: true, Owner: owner, - Resources: []model.StrategyResource{ + Resources: []authcommon.StrategyResource{ { StrategyID: id, ResType: int32(apisecurity.ResourceType_Namespaces), @@ -328,20 +329,20 @@ func createMockStrategy(users []*model.User, groups []*model.UserGroupDetail, se group := groups[i] service := services[len(users)+i] id := utils.NewUUID() - strategies = append(strategies, &model.StrategyDetail{ + strategies = append(strategies, &authcommon.StrategyDetail{ ID: id, Name: fmt.Sprintf("strategy_group_%s_%d", group.Name, i), Action: apisecurity.AuthAction_READ_WRITE.String(), Comment: "", - Principals: []model.Principal{ + Principals: []authcommon.Principal{ { PrincipalID: group.ID, - PrincipalRole: model.PrincipalGroup, + PrincipalType: authcommon.PrincipalGroup, }, }, Default: false, Owner: owner, - Resources: []model.StrategyResource{ + Resources: []authcommon.StrategyResource{ { StrategyID: id, ResType: int32(apisecurity.ResourceType_Namespaces), @@ -359,20 +360,20 @@ func createMockStrategy(users []*model.User, groups []*model.UserGroupDetail, se ModifyTime: time.Time{}, }) - defaultStrategies = append(defaultStrategies, &model.StrategyDetail{ + defaultStrategies = append(defaultStrategies, &authcommon.StrategyDetail{ ID: id, Name: fmt.Sprintf("strategy_default_group_%s_%d", group.Name, i), Action: apisecurity.AuthAction_READ_WRITE.String(), Comment: "", - Principals: []model.Principal{ + Principals: []authcommon.Principal{ { PrincipalID: group.ID, - PrincipalRole: model.PrincipalGroup, + PrincipalType: authcommon.PrincipalGroup, }, }, Default: true, Owner: owner, - Resources: []model.StrategyResource{ + Resources: []authcommon.StrategyResource{ { StrategyID: id, ResType: int32(apisecurity.ResourceType_Namespaces), diff --git a/auth/user/default.go b/auth/user/default.go index a92a536b6..001cacf97 100644 --- a/auth/user/default.go +++ b/auth/user/default.go @@ -23,6 +23,7 @@ import ( "github.com/polarismesh/polaris/auth" user_auth "github.com/polarismesh/polaris/auth/user/inteceptor/auth" + "github.com/polarismesh/polaris/auth/user/inteceptor/paramcheck" ) type ServerProxyFactory func(svr *Server, pre auth.UserServer) (auth.UserServer, error) @@ -53,6 +54,9 @@ func loadInteceptors() { RegisterServerProxy("auth", func(svr *Server, pre auth.UserServer) (auth.UserServer, error) { return user_auth.NewServer(pre), nil }) + RegisterServerProxy("paramcheck", func(svr *Server, pre auth.UserServer) (auth.UserServer, error) { + return paramcheck.NewServer(pre), nil + }) } func BuildServer() (*Server, auth.UserServer, error) { @@ -61,7 +65,7 @@ func BuildServer() (*Server, auth.UserServer, error) { var nextSvr auth.UserServer nextSvr = svr // 需要返回包装代理的 DiscoverServer - order := []string{"auth"} + order := GetChainOrder() for i := range order { factory, exist := serverProxyFactories[order[i]] if !exist { @@ -76,3 +80,10 @@ func BuildServer() (*Server, auth.UserServer, error) { } return svr, nextSvr, nil } + +func GetChainOrder() []string { + return []string{ + "auth", + "paramcheck", + } +} diff --git a/auth/user/group.go b/auth/user/group.go index d40a6ae7a..15dbbe487 100644 --- a/auth/user/group.go +++ b/auth/user/group.go @@ -28,8 +28,10 @@ import ( apiservice "github.com/polarismesh/specification/source/go/api/v1/service_manage" "go.uber.org/zap" + cachetypes "github.com/polarismesh/polaris/cache/api" api "github.com/polarismesh/polaris/common/api/v1" "github.com/polarismesh/polaris/common/model" + authcommon "github.com/polarismesh/polaris/common/model/auth" commonstore "github.com/polarismesh/polaris/common/store" commontime "github.com/polarismesh/polaris/common/time" "github.com/polarismesh/polaris/common/utils" @@ -37,31 +39,12 @@ import ( type ( // UserGroup2Api is the user group to api - UserGroup2Api func(user *model.UserGroup) *apisecurity.UserGroup -) - -var ( - // UserLinkGroupAttributes is the user link group attributes - UserLinkGroupAttributes = map[string]struct{}{ - "id": {}, - "user_id": {}, - "user_name": {}, - "group_id": {}, - "name": {}, - "offset": {}, - "limit": {}, - "owner": {}, - } + UserGroup2Api func(user *authcommon.UserGroup) *apisecurity.UserGroup ) // CreateGroup create a group func (svr *Server) CreateGroup(ctx context.Context, req *apisecurity.UserGroup) *apiservice.Response { - var ( - requestID = utils.ParseRequestID(ctx) - platformID = utils.ParsePlatformID(ctx) - ownerID = utils.ParseOwnerID(ctx) - ) - + ownerID := utils.ParseOwnerID(ctx) req.Owner = utils.NewStringValue(ownerID) if rsp := svr.preCheckGroupRelation(req.GetRelation()); rsp != nil { return rsp @@ -70,8 +53,7 @@ func (svr *Server) CreateGroup(ctx context.Context, req *apisecurity.UserGroup) // 根据 owner + groupname 确定唯一的用户组信息 group, err := svr.storage.GetGroupByName(req.GetName().GetValue(), ownerID) if err != nil { - log.Error("get group when create", utils.ZapRequestID(requestID), - utils.ZapPlatformID(platformID), zap.Error(err)) + log.Error("get group when create", utils.RequestID(ctx), zap.Error(err)) return api.NewGroupResponse(commonstore.StoreCode2APICode(err), req) } @@ -81,22 +63,41 @@ func (svr *Server) CreateGroup(ctx context.Context, req *apisecurity.UserGroup) data, err := svr.createGroupModel(req) if err != nil { - log.Error("create group model", utils.ZapRequestID(requestID), - utils.ZapPlatformID(platformID), zap.Error(err)) + log.Error("create group model", utils.RequestID(ctx), zap.Error(err)) return api.NewAuthResponseWithMsg(apimodel.Code_ExecuteException, err.Error()) } - if err := svr.storage.AddGroup(data); err != nil { - log.Error(err.Error(), utils.ZapRequestID(requestID), utils.ZapPlatformID(platformID)) + tx, err := svr.storage.StartTx() + if err != nil { + log.Error("[Auth][User] create user_group begion storage tx", utils.RequestID(ctx), zap.Error(err)) + return api.NewAuthResponse(apimodel.Code_ExecuteException) + } + defer func() { + _ = tx.Rollback() + }() + + if err := svr.storage.AddGroup(tx, data); err != nil { + log.Error(err.Error(), utils.RequestID(ctx)) return api.NewAuthResponseWithMsg(commonstore.StoreCode2APICode(err), err.Error()) } + if err := svr.policySvr.PolicyHelper().CreatePrincipal(ctx, tx, authcommon.Principal{ + PrincipalID: data.ID, + PrincipalType: authcommon.PrincipalGroup, + Owner: data.Owner, + Name: data.Name, + }); err != nil { + log.Error("[Auth][User] add user_group default policy rule", utils.RequestID(ctx), zap.Error(err)) + return api.NewAuthResponse(commonstore.StoreCode2APICode(err)) + } + if err := tx.Commit(); err != nil { + log.Error("[Auth][User] create user_group commit storage tx", utils.RequestID(ctx), zap.Error(err)) + return api.NewAuthResponse(apimodel.Code_ExecuteException) + } - log.Info("create group", zap.String("name", req.Name.GetValue()), utils.ZapRequestID(requestID), - utils.ZapPlatformID(platformID)) + log.Info("create group", zap.String("name", req.Name.GetValue()), utils.RequestID(ctx)) svr.RecordHistory(userGroupRecordEntry(ctx, req, data.UserGroup, model.OCreate)) req.Id = utils.NewStringValue(data.ID) - return api.NewGroupResponse(apimodel.Code_ExecuteSuccess, req) } @@ -105,11 +106,9 @@ func (svr *Server) UpdateGroups( ctx context.Context, groups []*apisecurity.ModifyUserGroup) *apiservice.BatchWriteResponse { resp := api.NewAuthBatchWriteResponse(apimodel.Code_ExecuteSuccess) for index := range groups { - req := groups[index] - ret := svr.UpdateGroup(ctx, req) + ret := svr.UpdateGroup(ctx, groups[index]) api.Collect(resp, ret) } - return resp } @@ -163,11 +162,27 @@ func (svr *Server) DeleteGroup(ctx context.Context, req *apisecurity.UserGroup) if group == nil { return api.NewGroupResponse(apimodel.Code_ExecuteSuccess, req) } + tx, err := svr.storage.StartTx() + if err != nil { + log.Error("[Auth][User] delete user_group begion storage tx", utils.RequestID(ctx), zap.Error(err)) + return api.NewAuthResponse(apimodel.Code_ExecuteException) + } + defer func() { + _ = tx.Rollback() + }() - if err := svr.storage.DeleteGroup(group); err != nil { + if err := svr.storage.DeleteGroup(tx, group); err != nil { log.Error("delete group from store", utils.RequestID(ctx), zap.Error(err)) return api.NewAuthResponseWithMsg(commonstore.StoreCode2APICode(err), err.Error()) } + if err := svr.policySvr.PolicyHelper().CleanPrincipal(ctx, tx, authcommon.Principal{ + PrincipalID: group.ID, + PrincipalType: authcommon.PrincipalGroup, + Owner: group.Owner, + }); err != nil { + log.Error("[Auth][User] delete user_group from policy server", utils.RequestID(ctx), zap.Error(err)) + return api.NewAuthResponse(commonstore.StoreCode2APICode(err)) + } log.Info("delete group", utils.RequestID(ctx), zap.String("name", req.Name.GetValue())) svr.RecordHistory(userGroupRecordEntry(ctx, req, group.UserGroup, model.ODelete)) @@ -176,27 +191,16 @@ func (svr *Server) DeleteGroup(ctx context.Context, req *apisecurity.UserGroup) } // GetGroups 查看用户组 -func (svr *Server) GetGroups(ctx context.Context, query map[string]string) *apiservice.BatchQueryResponse { - log.Info("[Auth][Group] origin get groups query params", - utils.RequestID(ctx), zap.Any("query", query)) - - offset, limit, err := utils.ParseOffsetAndLimit(query) - if err != nil { - return api.NewAuthBatchQueryResponse(apimodel.Code_InvalidParameter) - } - - searchFilters := make(map[string]string, len(query)) - for key, value := range query { - if _, ok := UserLinkGroupAttributes[key]; !ok { - log.Errorf("[Auth][Group] get groups attribute(%s) it not allowed", key) - return api.NewAuthBatchQueryResponseWithMsg(apimodel.Code_InvalidParameter, key+" is not allowed") - } - searchFilters[key] = value - } - - total, groups, err := svr.storage.GetGroups(searchFilters, offset, limit) +func (svr *Server) GetGroups(ctx context.Context, filters map[string]string) *apiservice.BatchQueryResponse { + offset, limit, _ := utils.ParseOffsetAndLimit(filters) + total, groups, err := svr.cacheMgr.User().QueryUserGroups(ctx, cachetypes.UserGroupSearchArgs{ + Filters: filters, + Offset: offset, + Limit: limit, + }) if err != nil { - log.Errorf("[Auth][Group] get groups req(%+v) store err: %s", query, err.Error()) + log.Error("[Auth][Group] list user_group from store", utils.RequestID(ctx), + zap.Any("filters", filters), zap.Error(err)) return api.NewAuthBatchQueryResponse(commonstore.StoreCode2APICode(err)) } @@ -222,12 +226,10 @@ func (svr *Server) GetGroup(ctx context.Context, req *apisecurity.UserGroup) *ap if req.GetId().GetValue() == "" { return api.NewAuthResponse(apimodel.Code_InvalidUserGroupID) } - group, errResp := svr.getGroupFromDB(req.Id.Value) if errResp != nil { return errResp } - return api.NewGroupResponse(apimodel.Code_ExecuteSuccess, svr.userGroupDetail2Api(group)) } @@ -237,32 +239,27 @@ func (svr *Server) GetGroupToken(ctx context.Context, req *apisecurity.UserGroup return api.NewAuthResponse(apimodel.Code_InvalidUserGroupID) } - groupCache, errResp := svr.getGroupFromCache(req) - if errResp != nil { - return errResp + group := svr.cacheMgr.User().GetGroup(req.Id.GetValue()) + if group == nil { + return api.NewGroupResponse(apimodel.Code_NotFoundUserGroup, req) } - req.AuthToken = utils.NewStringValue(groupCache.Token) - req.TokenEnable = utils.NewBoolValue(groupCache.TokenEnable) + req.AuthToken = utils.NewStringValue(group.Token) + req.TokenEnable = utils.NewBoolValue(group.TokenEnable) return api.NewGroupResponse(apimodel.Code_ExecuteSuccess, req) } -// UpdateGroupToken 调整用户组 token 的使用状态 (禁用|开启) -func (svr *Server) UpdateGroupToken(ctx context.Context, req *apisecurity.UserGroup) *apiservice.Response { - var ( - requestID = utils.ParseRequestID(ctx) - platformID = utils.ParsePlatformID(ctx) - group, errResp = svr.getGroupFromDB(req.Id.GetValue()) - ) - +// EnableGroupToken 调整用户组 token 的使用状态 (禁用|开启) +func (svr *Server) EnableGroupToken(ctx context.Context, req *apisecurity.UserGroup) *apiservice.Response { + group, errResp := svr.getGroupFromDB(req.Id.GetValue()) if errResp != nil { return errResp } group.TokenEnable = req.TokenEnable.GetValue() - modifyReq := &model.ModifyUserGroup{ + modifyReq := &authcommon.ModifyUserGroup{ ID: group.ID, Owner: group.Owner, Token: group.Token, @@ -271,12 +268,12 @@ func (svr *Server) UpdateGroupToken(ctx context.Context, req *apisecurity.UserGr } if err := svr.storage.UpdateGroup(modifyReq); err != nil { - log.Error(err.Error(), utils.ZapRequestID(requestID), utils.ZapPlatformID(platformID)) + log.Error(err.Error(), utils.RequestID(ctx)) return api.NewAuthResponseWithMsg(commonstore.StoreCode2APICode(err), err.Error()) } log.Info("update group token", zap.String("id", req.Id.GetValue()), - zap.Bool("enable", group.TokenEnable), utils.ZapRequestID(requestID), utils.ZapPlatformID(platformID)) + zap.Bool("enable", group.TokenEnable), utils.RequestID(ctx)) svr.RecordHistory(userGroupRecordEntry(ctx, req, group.UserGroup, model.OUpdateToken)) return api.NewGroupResponse(apimodel.Code_ExecuteSuccess, req) @@ -306,7 +303,7 @@ func (svr *Server) ResetGroupToken(ctx context.Context, req *apisecurity.UserGro } group.Token = newToken - modifyReq := &model.ModifyUserGroup{ + modifyReq := &authcommon.ModifyUserGroup{ ID: group.ID, Owner: group.Owner, Token: group.Token, @@ -329,7 +326,7 @@ func (svr *Server) ResetGroupToken(ctx context.Context, req *apisecurity.UserGro } // getGroupFromDB 获取用户组 -func (svr *Server) getGroupFromDB(id string) (*model.UserGroupDetail, *apiservice.Response) { +func (svr *Server) getGroupFromDB(id string) (*authcommon.UserGroupDetail, *apiservice.Response) { group, err := svr.storage.GetGroup(id) if err != nil { log.Error("get group from store", zap.Error(err)) @@ -338,17 +335,6 @@ func (svr *Server) getGroupFromDB(id string) (*model.UserGroupDetail, *apiservic if group == nil { return nil, api.NewAuthResponse(apimodel.Code_NotFoundUserGroup) } - - return group, nil -} - -// getGroupFromCache 从缓存中获取用户组信息数据 -func (svr *Server) getGroupFromCache(req *apisecurity.UserGroup) (*model.UserGroupDetail, *apiservice.Response) { - group := svr.cacheMgr.User().GetGroup(req.Id.GetValue()) - if group == nil { - return nil, api.NewGroupResponse(apimodel.Code_NotFoundUserGroup, req) - } - return group, nil } @@ -386,11 +372,11 @@ func (svr *Server) checkUpdateGroup(ctx context.Context, req *apisecurity.Modify } // UpdateGroupAttribute 更新计算用户组更新时的结构体数据,并判断是否需要执行更新操作 -func UpdateGroupAttribute(ctx context.Context, old *model.UserGroup, newUser *apisecurity.ModifyUserGroup) ( - *model.ModifyUserGroup, bool) { +func UpdateGroupAttribute(ctx context.Context, old *authcommon.UserGroup, newUser *apisecurity.ModifyUserGroup) ( + *authcommon.ModifyUserGroup, bool) { var ( needUpdate bool - ret = &model.ModifyUserGroup{ + ret = &authcommon.ModifyUserGroup{ ID: old.ID, Token: old.Token, TokenEnable: old.TokenEnable, @@ -429,24 +415,24 @@ func UpdateGroupAttribute(ctx context.Context, old *model.UserGroup, newUser *ap } // enhancedGroups2Api 数组专为 []*apisecurity.UserGroup -func enhancedGroups2Api(groups []*model.UserGroup, handler UserGroup2Api) []*apisecurity.UserGroup { +func enhancedGroups2Api(groups []*authcommon.UserGroupDetail, handler UserGroup2Api) []*apisecurity.UserGroup { out := make([]*apisecurity.UserGroup, 0, len(groups)) for k := range groups { - out = append(out, handler(groups[k])) + out = append(out, handler(groups[k].UserGroup)) } return out } // createGroupModel 创建用户组的存储模型 -func (svr *Server) createGroupModel(req *apisecurity.UserGroup) (group *model.UserGroupDetail, err error) { +func (svr *Server) createGroupModel(req *apisecurity.UserGroup) (group *authcommon.UserGroupDetail, err error) { ids := make(map[string]struct{}, len(req.GetRelation().GetUsers())) for index := range req.GetRelation().GetUsers() { ids[req.GetRelation().GetUsers()[index].GetId().GetValue()] = struct{}{} } - group = &model.UserGroupDetail{ - UserGroup: &model.UserGroup{ + group = &authcommon.UserGroupDetail{ + UserGroup: &authcommon.UserGroup{ ID: utils.NewUUID(), Name: req.GetName().GetValue(), Owner: req.GetOwner().GetValue(), @@ -466,7 +452,7 @@ func (svr *Server) createGroupModel(req *apisecurity.UserGroup) (group *model.Us } // model.UserGroup 转为 api.UserGroup -func userGroup2Api(group *model.UserGroup) *apisecurity.UserGroup { +func userGroup2Api(group *authcommon.UserGroup) *apisecurity.UserGroup { if group == nil { return nil } @@ -486,7 +472,7 @@ func userGroup2Api(group *model.UserGroup) *apisecurity.UserGroup { } // model.UserGroupDetail 转为 api.UserGroup,并且主动填充 user 的信息数据 -func (svr *Server) userGroupDetail2Api(group *model.UserGroupDetail) *apisecurity.UserGroup { +func (svr *Server) userGroupDetail2Api(group *authcommon.UserGroupDetail) *apisecurity.UserGroup { if group == nil { return nil } @@ -524,7 +510,7 @@ func (svr *Server) userGroupDetail2Api(group *model.UserGroupDetail) *apisecurit } // userGroupRecordEntry 生成用户组的记录entry -func userGroupRecordEntry(ctx context.Context, req *apisecurity.UserGroup, md *model.UserGroup, +func userGroupRecordEntry(ctx context.Context, req *apisecurity.UserGroup, md *authcommon.UserGroup, operationType model.OperationType) *model.RecordEntry { marshaler := jsonpb.Marshaler{} @@ -543,7 +529,7 @@ func userGroupRecordEntry(ctx context.Context, req *apisecurity.UserGroup, md *m } // 生成修改用户组的记录entry -func modifyUserGroupRecordEntry(ctx context.Context, req *apisecurity.ModifyUserGroup, md *model.UserGroup, +func modifyUserGroupRecordEntry(ctx context.Context, req *apisecurity.ModifyUserGroup, md *authcommon.UserGroup, operationType model.OperationType) *model.RecordEntry { marshaler := jsonpb.Marshaler{} @@ -562,7 +548,7 @@ func modifyUserGroupRecordEntry(ctx context.Context, req *apisecurity.ModifyUser } // 生成用户-用户组关联关系的记录entry -func userRelationRecordEntry(ctx context.Context, req *apisecurity.UserGroupRelation, md *model.UserGroup, +func userRelationRecordEntry(ctx context.Context, req *apisecurity.UserGroupRelation, md *authcommon.UserGroup, operationType model.OperationType) *model.RecordEntry { marshaler := jsonpb.Marshaler{} @@ -579,3 +565,18 @@ func userRelationRecordEntry(ctx context.Context, req *apisecurity.UserGroupRela return entry } + +func defaultUserGroupPolicy(u *authcommon.UserGroupDetail) *authcommon.StrategyDetail { + // Create the user's default weight policy + return &authcommon.StrategyDetail{ + ID: utils.NewUUID(), + Name: authcommon.BuildDefaultStrategyName(authcommon.PrincipalGroup, u.Name), + Action: apisecurity.AuthAction_READ_WRITE.String(), + Default: true, + Owner: u.Owner, + Revision: utils.NewUUID(), + Resources: []authcommon.StrategyResource{}, + Valid: true, + Comment: "Default Strategy", + } +} diff --git a/auth/user/group_test.go b/auth/user/group_test.go index e36229dc9..087a95439 100644 --- a/auth/user/group_test.go +++ b/auth/user/group_test.go @@ -32,7 +32,7 @@ import ( "github.com/polarismesh/polaris/cache" cachetypes "github.com/polarismesh/polaris/cache/api" v1 "github.com/polarismesh/polaris/common/api/v1" - "github.com/polarismesh/polaris/common/model" + authcommon "github.com/polarismesh/polaris/common/model/auth" "github.com/polarismesh/polaris/common/utils" storemock "github.com/polarismesh/polaris/store/mock" ) @@ -40,13 +40,13 @@ import ( type GroupTest struct { ctrl *gomock.Controller - ownerOne *model.User - ownerTwo *model.User + ownerOne *authcommon.User + ownerTwo *authcommon.User - users []*model.User - groups []*model.UserGroupDetail - newGroups []*model.UserGroupDetail - allGroups []*model.UserGroupDetail + users []*authcommon.User + groups []*authcommon.UserGroupDetail + newGroups []*authcommon.UserGroupDetail + allGroups []*authcommon.UserGroupDetail storage *storemock.MockStore cacheMgn *cache.CacheManager @@ -72,7 +72,7 @@ func newGroupTest(t *testing.T) *GroupTest { storage.EXPECT().GetServicesCount().AnyTimes().Return(uint32(1), nil) storage.EXPECT().GetUnixSecond(gomock.Any()).AnyTimes().Return(time.Now().Unix(), nil) - storage.EXPECT().AddGroup(gomock.Any()).AnyTimes().Return(nil) + storage.EXPECT().AddGroup(gomock.Any(), gomock.Any()).AnyTimes().Return(nil) storage.EXPECT().UpdateUser(gomock.Any()).AnyTimes().Return(nil) storage.EXPECT().GetUsersForCache(gomock.Any(), gomock.Any()).AnyTimes().Return(append(users, newUsers...), nil) storage.EXPECT().GetGroupsForCache(gomock.Any(), gomock.Any()).AnyTimes().Return(allGroups, nil) @@ -104,7 +104,7 @@ func newGroupTest(t *testing.T) *GroupTest { "salt": "polarismesh@2021", }, }, - }, storage, cacheMgn) + }, storage, nil, cacheMgn) _ = cacheMgn.TestUpdate() return &GroupTest{ ctrl: ctrl, @@ -496,7 +496,7 @@ func Test_server_DeleteGroup(t *testing.T) { reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, groupTest.ownerOne.Token) groupTest.storage.EXPECT().GetGroup(gomock.Any()).AnyTimes().Return(groupTest.groups[0], nil) - groupTest.storage.EXPECT().DeleteGroup(gomock.Any()).AnyTimes().Return(nil) + groupTest.storage.EXPECT().DeleteGroup(gomock.Any(), gomock.Any()).AnyTimes().Return(nil) batchResp := groupTest.svr.DeleteGroups(reqCtx, []*apisecurity.UserGroup{ { @@ -516,7 +516,7 @@ func Test_server_DeleteGroup(t *testing.T) { reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, groupTest.ownerOne.Token) groupTest.storage.EXPECT().GetGroup(gomock.Any()).AnyTimes().Return(nil, nil) - groupTest.storage.EXPECT().DeleteGroup(gomock.Any()).AnyTimes().Return(nil) + groupTest.storage.EXPECT().DeleteGroup(gomock.Any(), gomock.Any()).AnyTimes().Return(nil) batchResp := groupTest.svr.DeleteGroups(reqCtx, []*apisecurity.UserGroup{ { @@ -536,7 +536,7 @@ func Test_server_DeleteGroup(t *testing.T) { reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, groupTest.ownerTwo.Token) groupTest.storage.EXPECT().GetGroup(gomock.Any()).AnyTimes().Return(groupTest.groups[0], nil) - groupTest.storage.EXPECT().DeleteGroup(gomock.Any()).AnyTimes().Return(nil) + groupTest.storage.EXPECT().DeleteGroup(gomock.Any(), gomock.Any()).AnyTimes().Return(nil) batchResp := groupTest.svr.DeleteGroups(reqCtx, []*apisecurity.UserGroup{ { @@ -556,7 +556,7 @@ func Test_server_DeleteGroup(t *testing.T) { reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, groupTest.users[1].Token) groupTest.storage.EXPECT().GetGroup(gomock.Any()).AnyTimes().Return(groupTest.groups[0], nil) - groupTest.storage.EXPECT().DeleteGroup(gomock.Any()).AnyTimes().Return(nil) + groupTest.storage.EXPECT().DeleteGroup(gomock.Any(), gomock.Any()).AnyTimes().Return(nil) batchResp := groupTest.svr.DeleteGroups(reqCtx, []*apisecurity.UserGroup{ { @@ -580,7 +580,7 @@ func Test_server_UpdateGroupToken(t *testing.T) { groupTest.storage.EXPECT().GetGroup(gomock.Any()).AnyTimes().Return(groupTest.groups[0], nil) groupTest.storage.EXPECT().UpdateGroup(gomock.Any()).AnyTimes().Return(nil) - batchResp := groupTest.svr.UpdateGroupToken(reqCtx, &apisecurity.UserGroup{ + batchResp := groupTest.svr.EnableGroupToken(reqCtx, &apisecurity.UserGroup{ Id: utils.NewStringValue(groupTest.groups[2].ID), }) @@ -596,7 +596,7 @@ func Test_server_UpdateGroupToken(t *testing.T) { groupTest.storage.EXPECT().GetGroup(gomock.Any()).AnyTimes().Return(groupTest.groups[0], nil) - batchResp := groupTest.svr.UpdateGroupToken(reqCtx, &apisecurity.UserGroup{ + batchResp := groupTest.svr.EnableGroupToken(reqCtx, &apisecurity.UserGroup{ Id: utils.NewStringValue(groupTest.groups[2].ID), }) @@ -612,7 +612,7 @@ func Test_server_UpdateGroupToken(t *testing.T) { groupTest.storage.EXPECT().GetGroup(gomock.Any()).AnyTimes().Return(groupTest.groups[0], nil) - batchResp := groupTest.svr.UpdateGroupToken(reqCtx, &apisecurity.UserGroup{ + batchResp := groupTest.svr.EnableGroupToken(reqCtx, &apisecurity.UserGroup{ Id: utils.NewStringValue(groupTest.groups[2].ID), }) diff --git a/auth/user/inteceptor/auth/server.go b/auth/user/inteceptor/auth/server.go index 38171005b..6cc411516 100644 --- a/auth/user/inteceptor/auth/server.go +++ b/auth/user/inteceptor/auth/server.go @@ -24,29 +24,16 @@ import ( apimodel "github.com/polarismesh/specification/source/go/api/v1/model" apisecurity "github.com/polarismesh/specification/source/go/api/v1/security" apiservice "github.com/polarismesh/specification/source/go/api/v1/service_manage" - "go.uber.org/zap" - "google.golang.org/protobuf/types/known/wrapperspb" "github.com/polarismesh/polaris/auth" cachetypes "github.com/polarismesh/polaris/cache/api" api "github.com/polarismesh/polaris/common/api/v1" - "github.com/polarismesh/polaris/common/model" authcommon "github.com/polarismesh/polaris/common/model/auth" + authmodel "github.com/polarismesh/polaris/common/model/auth" "github.com/polarismesh/polaris/common/utils" "github.com/polarismesh/polaris/store" ) -var ( - // MustOwner 必须超级账户 or 主账户 - MustOwner = true - // NotOwner 任意账户 - NotOwner = false - // WriteOp 写操作 - WriteOp = true - // ReadOp 读操作 - ReadOp = false -) - func NewServer(nextSvr auth.UserServer) auth.UserServer { return &Server{ nextSvr: nextSvr, @@ -54,12 +41,13 @@ func NewServer(nextSvr auth.UserServer) auth.UserServer { } type Server struct { - nextSvr auth.UserServer + nextSvr auth.UserServer + policySvr auth.StrategyServer } // Initialize 初始化 -func (svr *Server) Initialize(authOpt *auth.Config, storage store.Store, cacheMgr cachetypes.CacheManager) error { - return svr.nextSvr.Initialize(authOpt, storage, cacheMgr) +func (svr *Server) Initialize(authOpt *auth.Config, storage store.Store, policyMgr auth.StrategyServer, cacheMgr cachetypes.CacheManager) error { + return svr.nextSvr.Initialize(authOpt, storage, policyMgr, cacheMgr) } // Name 用户数据管理server名称 @@ -73,7 +61,7 @@ func (svr *Server) Login(req *apisecurity.LoginRequest) *apiservice.Response { } // CheckCredential 检查当前操作用户凭证 -func (svr *Server) CheckCredential(authCtx *model.AcquireContext) error { +func (svr *Server) CheckCredential(authCtx *authmodel.AcquireContext) error { return svr.nextSvr.CheckCredential(authCtx) } @@ -84,408 +72,443 @@ func (svr *Server) GetUserHelper() auth.UserHelper { // CreateUsers 批量创建用户 func (svr *Server) CreateUsers(ctx context.Context, users []*apisecurity.User) *apiservice.BatchWriteResponse { - ctx, rsp := svr.verifyAuth(ctx, WriteOp, MustOwner) - if rsp != nil { - resp := api.NewAuthBatchWriteResponse(apimodel.Code_ExecuteSuccess) - api.Collect(resp, rsp) - return resp + authCtx := authcommon.NewAcquireContext( + authcommon.WithRequestContext(ctx), + authcommon.WithOperation(authcommon.Create), + authcommon.WithModule(authcommon.AuthModule), + authcommon.WithMethod(authcommon.CreateUsers), + ) + + if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + return api.NewBatchWriteResponseWithMsg(authcommon.ConvertToErrCode(err), err.Error()) } - return svr.nextSvr.CreateUsers(ctx, users) + return svr.nextSvr.CreateUsers(authCtx.GetRequestContext(), users) } // UpdateUser 更新用户信息 func (svr *Server) UpdateUser(ctx context.Context, user *apisecurity.User) *apiservice.Response { - ctx, rsp := svr.verifyAuth(ctx, ReadOp, NotOwner) - if rsp != nil { - rsp.User = user - return rsp - } - helper := svr.GetUserHelper() - targetUser := helper.GetUserByID(ctx, user.GetId().GetValue()) - if targetUser == nil { - return api.NewAuthResponse(apimodel.Code_NotFoundUser) - } - if !checkUserViewPermission(ctx, targetUser) { - return api.NewAuthResponse(apimodel.Code_NotAllowedAccess) + helper := svr.nextSvr.GetUserHelper() + saveUser := helper.GetUserByID(ctx, user.GetId().GetValue()) + if saveUser == nil { + return api.NewResponse(apimodel.Code_NotFoundUser) + } + + authCtx := authcommon.NewAcquireContext( + authcommon.WithRequestContext(ctx), + authcommon.WithOperation(authcommon.Modify), + authcommon.WithModule(authcommon.AuthModule), + authcommon.WithMethod(authcommon.UpdateUser), + authcommon.WithAccessResources(map[apisecurity.ResourceType][]authmodel.ResourceEntry{ + apisecurity.ResourceType_Users: { + authmodel.ResourceEntry{ + ID: user.GetId().GetValue(), + Type: apisecurity.ResourceType_Users, + Metadata: saveUser.Metadata, + }, + }, + }), + ) + + if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + return api.NewResponseWithMsg(authcommon.ConvertToErrCode(err), err.Error()) } - return svr.nextSvr.UpdateUser(ctx, user) + return svr.nextSvr.UpdateUser(authCtx.GetRequestContext(), user) } // UpdateUserPassword 更新用户密码 func (svr *Server) UpdateUserPassword(ctx context.Context, req *apisecurity.ModifyUserPassword) *apiservice.Response { - ctx, rsp := svr.verifyAuth(ctx, ReadOp, NotOwner) - if rsp != nil { - return rsp - } - helper := svr.GetUserHelper() - targetUser := helper.GetUserByID(ctx, req.GetId().GetValue()) - if targetUser == nil { - return api.NewAuthResponse(apimodel.Code_NotFoundUser) - } - if !checkUserViewPermission(ctx, targetUser) { - return api.NewAuthResponse(apimodel.Code_NotAllowedAccess) + helper := svr.nextSvr.GetUserHelper() + saveUser := helper.GetUserByID(ctx, req.GetId().GetValue()) + if saveUser == nil { + return api.NewResponse(apimodel.Code_NotFoundUser) + } + + authCtx := authcommon.NewAcquireContext( + authcommon.WithRequestContext(ctx), + authcommon.WithOperation(authcommon.Modify), + authcommon.WithModule(authcommon.AuthModule), + authcommon.WithMethod(authcommon.UpdateUserPassword), + authcommon.WithAccessResources(map[apisecurity.ResourceType][]authmodel.ResourceEntry{ + apisecurity.ResourceType_Users: { + authmodel.ResourceEntry{ + ID: req.GetId().GetValue(), + Type: apisecurity.ResourceType_Users, + Metadata: saveUser.Metadata, + }, + }, + }), + ) + + if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + return api.NewResponseWithMsg(authcommon.ConvertToErrCode(err), err.Error()) } - return svr.nextSvr.UpdateUserPassword(ctx, req) + return svr.nextSvr.UpdateUserPassword(authCtx.GetRequestContext(), req) } // DeleteUsers 批量删除用户 func (svr *Server) DeleteUsers(ctx context.Context, users []*apisecurity.User) *apiservice.BatchWriteResponse { - ctx, rsp := svr.verifyAuth(ctx, WriteOp, MustOwner) - if rsp != nil { - resp := api.NewAuthBatchWriteResponse(apimodel.Code_ExecuteSuccess) - api.Collect(resp, rsp) - return resp - } + helper := svr.nextSvr.GetUserHelper() + resources := make([]authcommon.ResourceEntry, 0, len(users)) for i := range users { - user := users[i] - helper := svr.GetUserHelper() - targetUser := helper.GetUserByID(ctx, user.GetId().GetValue()) - // 已经删除的用户没必要在删除一次 - if targetUser == nil { - continue - } - if !checkUserViewPermission(ctx, targetUser) { - return api.NewAuthBatchWriteResponse(apimodel.Code_NotAllowedAccess) + saveUser := helper.GetUserByID(ctx, users[i].GetId().GetValue()) + if saveUser == nil { + return api.NewBatchWriteResponse(apimodel.Code_NotFoundUser) } + resources = append(resources, authmodel.ResourceEntry{ + ID: users[i].GetId().GetValue(), + Type: apisecurity.ResourceType_Users, + Metadata: saveUser.Metadata, + }) } - return svr.nextSvr.DeleteUsers(ctx, users) + authCtx := authcommon.NewAcquireContext( + authcommon.WithRequestContext(ctx), + authcommon.WithOperation(authcommon.Delete), + authcommon.WithModule(authcommon.AuthModule), + authcommon.WithMethod(authcommon.DeleteUsers), + authcommon.WithAccessResources(map[apisecurity.ResourceType][]authmodel.ResourceEntry{ + apisecurity.ResourceType_Users: resources, + }), + ) + if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + return api.NewBatchWriteResponseWithMsg(authcommon.ConvertToErrCode(err), err.Error()) + } + return svr.nextSvr.DeleteUsers(authCtx.GetRequestContext(), users) } // GetUsers 查询用户列表 func (svr *Server) GetUsers(ctx context.Context, query map[string]string) *apiservice.BatchQueryResponse { - ctx, rsp := svr.verifyAuth(ctx, ReadOp, NotOwner) - if rsp != nil { - return api.NewAuthBatchQueryResponseWithMsg(apimodel.Code(rsp.GetCode().Value), rsp.Info.Value) + authCtx := authcommon.NewAcquireContext( + authcommon.WithRequestContext(ctx), + authcommon.WithOperation(authcommon.Read), + authcommon.WithModule(authcommon.AuthModule), + authcommon.WithMethod(authcommon.DescribeUsers), + ) + if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + return api.NewBatchQueryResponseWithMsg(authcommon.ConvertToErrCode(err), err.Error()) } + ctx = authCtx.GetRequestContext() query["hide_admin"] = strconv.FormatBool(true) // 如果不是超级管理员,查看数据有限制 - if authcommon.ParseUserRole(ctx) != model.AdminUserRole { + if authcommon.ParseUserRole(ctx) != authmodel.AdminUserRole { // 设置 owner 参数,只能查看对应 owner 下的用户 query["owner"] = utils.ParseOwnerID(ctx) } - return svr.nextSvr.GetUsers(ctx, query) + + cachetypes.AppendUserPredicate(ctx, func(ctx context.Context, u *authcommon.User) bool { + return svr.policySvr.GetAuthChecker().ResourcePredicate(authCtx, &authmodel.ResourceEntry{ + Type: apisecurity.ResourceType_Users, + ID: u.ID, + Metadata: u.Metadata, + }) + }) + + return svr.nextSvr.GetUsers(authCtx.GetRequestContext(), query) } // GetUserToken 获取用户的 token func (svr *Server) GetUserToken(ctx context.Context, user *apisecurity.User) *apiservice.Response { - ctx, rsp := svr.verifyAuth(ctx, ReadOp, NotOwner) - if rsp != nil { - return rsp - } - helper := svr.GetUserHelper() - targetUser := helper.GetUser(ctx, user) - if !checkUserViewPermission(ctx, targetUser) { - return api.NewAuthResponse(apimodel.Code_NotAllowedAccess) + helper := svr.nextSvr.GetUserHelper() + saveUser := helper.GetUserByID(ctx, user.GetId().GetValue()) + if saveUser == nil { + return api.NewResponse(apimodel.Code_NotFoundUser) + } + authCtx := authcommon.NewAcquireContext( + authcommon.WithRequestContext(ctx), + authcommon.WithOperation(authcommon.Read), + authcommon.WithModule(authcommon.AuthModule), + authcommon.WithMethod(authcommon.DescribeUserToken), + authcommon.WithAccessResources(map[apisecurity.ResourceType][]authmodel.ResourceEntry{ + apisecurity.ResourceType_Users: { + authmodel.ResourceEntry{ + ID: user.GetId().GetValue(), + Type: apisecurity.ResourceType_Users, + Metadata: saveUser.Metadata, + }, + }, + }), + ) + + if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + return api.NewResponseWithMsg(authcommon.ConvertToErrCode(err), err.Error()) } - return svr.nextSvr.GetUserToken(ctx, user) + return svr.nextSvr.GetUserToken(authCtx.GetRequestContext(), user) } // UpdateUserToken 禁止用户的token使用 -func (svr *Server) UpdateUserToken(ctx context.Context, user *apisecurity.User) *apiservice.Response { - ctx, rsp := svr.verifyAuth(ctx, WriteOp, NotOwner) - if rsp != nil { - return rsp - } - helper := svr.GetUserHelper() - targetUser := helper.GetUser(ctx, user) - if !checkUserViewPermission(ctx, targetUser) { - return api.NewAuthResponse(apimodel.Code_NotAllowedAccess) - } - if authcommon.ParseUserRole(ctx) != model.AdminUserRole { - if targetUser.GetUserType().GetValue() != strconv.Itoa(int(model.SubAccountUserRole)) { - return api.NewUserResponseWithMsg(apimodel.Code_NotAllowedAccess, "only disable sub-account token", user) - } +func (svr *Server) EnableUserToken(ctx context.Context, user *apisecurity.User) *apiservice.Response { + helper := svr.nextSvr.GetUserHelper() + saveUser := helper.GetUserByID(ctx, user.GetId().GetValue()) + if saveUser == nil { + return api.NewResponse(apimodel.Code_NotFoundUser) + } + authCtx := authcommon.NewAcquireContext( + authcommon.WithRequestContext(ctx), + authcommon.WithOperation(authcommon.Modify), + authcommon.WithModule(authcommon.AuthModule), + authcommon.WithMethod(authcommon.EnableUserToken), + authcommon.WithAccessResources(map[apisecurity.ResourceType][]authmodel.ResourceEntry{ + apisecurity.ResourceType_Users: { + authmodel.ResourceEntry{ + ID: user.GetId().GetValue(), + Type: apisecurity.ResourceType_Users, + Metadata: saveUser.Metadata, + }, + }, + }), + ) + + if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + return api.NewResponseWithMsg(authcommon.ConvertToErrCode(err), err.Error()) } - return svr.nextSvr.UpdateUserToken(ctx, user) + return svr.nextSvr.EnableUserToken(ctx, user) } // ResetUserToken 重置用户的token func (svr *Server) ResetUserToken(ctx context.Context, user *apisecurity.User) *apiservice.Response { - ctx, rsp := svr.verifyAuth(ctx, WriteOp, NotOwner) - if rsp != nil { - return rsp - } - helper := svr.GetUserHelper() - targetUser := helper.GetUserByID(ctx, user.GetId().GetValue()) - if !checkUserViewPermission(ctx, targetUser) { - return api.NewAuthResponse(apimodel.Code_NotAllowedAccess) + helper := svr.nextSvr.GetUserHelper() + saveUser := helper.GetUserByID(ctx, user.GetId().GetValue()) + if saveUser == nil { + return api.NewResponse(apimodel.Code_NotFoundUser) + } + authCtx := authcommon.NewAcquireContext( + authcommon.WithRequestContext(ctx), + authcommon.WithOperation(authcommon.Modify), + authcommon.WithModule(authcommon.AuthModule), + authcommon.WithMethod(authcommon.ResetUserToken), + authcommon.WithAccessResources(map[apisecurity.ResourceType][]authmodel.ResourceEntry{ + apisecurity.ResourceType_Users: { + authmodel.ResourceEntry{ + ID: user.GetId().GetValue(), + Type: apisecurity.ResourceType_Users, + Metadata: saveUser.Metadata, + }, + }, + }), + ) + + if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + return api.NewResponseWithMsg(authcommon.ConvertToErrCode(err), err.Error()) } return svr.nextSvr.ResetUserToken(ctx, user) } // CreateGroup 创建用户组 func (svr *Server) CreateGroup(ctx context.Context, group *apisecurity.UserGroup) *apiservice.Response { - ctx, rsp := svr.verifyAuth(ctx, WriteOp, MustOwner) - if rsp != nil { - return rsp + authCtx := authcommon.NewAcquireContext( + authcommon.WithRequestContext(ctx), + authcommon.WithOperation(authcommon.Create), + authcommon.WithModule(authcommon.AuthModule), + authcommon.WithMethod(authcommon.CreateUserGroup), + ) + + if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + return api.NewResponseWithMsg(authcommon.ConvertToErrCode(err), err.Error()) } - return svr.nextSvr.CreateGroup(ctx, group) + return svr.nextSvr.CreateGroup(authCtx.GetRequestContext(), group) } // UpdateGroups 更新用户组 func (svr *Server) UpdateGroups(ctx context.Context, groups []*apisecurity.ModifyUserGroup) *apiservice.BatchWriteResponse { - ctx, rsp := svr.verifyAuth(ctx, WriteOp, MustOwner) - if rsp != nil { - resp := api.NewAuthBatchWriteResponse(apimodel.Code_ExecuteSuccess) - api.Collect(resp, rsp) - return resp - } - - resp := api.NewAuthBatchWriteResponse(apimodel.Code_ExecuteSuccess) + helper := svr.nextSvr.GetUserHelper() + resources := make([]authcommon.ResourceEntry, 0, len(groups)) for i := range groups { - item := groups[i] - rsp := svr.checkUpdateGroup(ctx, item) - api.Collect(resp, rsp) - } - if !api.IsSuccess(resp) { - return resp + saveGroup := helper.GetGroup(ctx, &apisecurity.UserGroup{Id: groups[i].GetId()}) + if saveGroup == nil { + return api.NewBatchWriteResponse(apimodel.Code_NotFoundUserGroup) + } + resources = append(resources, authmodel.ResourceEntry{ + Type: apisecurity.ResourceType_UserGroups, + ID: groups[i].GetId().GetValue(), + Metadata: saveGroup.Metadata, + }) } - return svr.nextSvr.UpdateGroups(ctx, groups) + authCtx := authcommon.NewAcquireContext( + authcommon.WithRequestContext(ctx), + authcommon.WithOperation(authcommon.Modify), + authcommon.WithModule(authcommon.AuthModule), + authcommon.WithMethod(authcommon.UpdateUserGroups), + authcommon.WithAccessResources(map[apisecurity.ResourceType][]authmodel.ResourceEntry{ + apisecurity.ResourceType_UserGroups: resources, + }), + ) + + if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + return api.NewBatchWriteResponseWithMsg(authcommon.ConvertToErrCode(err), err.Error()) + } + return svr.nextSvr.UpdateGroups(authCtx.GetRequestContext(), groups) } // DeleteGroups 批量删除用户组 func (svr *Server) DeleteGroups(ctx context.Context, groups []*apisecurity.UserGroup) *apiservice.BatchWriteResponse { - ctx, rsp := svr.verifyAuth(ctx, WriteOp, MustOwner) - if rsp != nil { - resp := api.NewAuthBatchWriteResponse(apimodel.Code_ExecuteSuccess) - api.Collect(resp, rsp) - return resp - } - resp := api.NewAuthBatchWriteResponse(apimodel.Code_ExecuteSuccess) + helper := svr.nextSvr.GetUserHelper() + resources := make([]authcommon.ResourceEntry, 0, len(groups)) for i := range groups { - item := groups[i] - if !svr.checkGroupViewAuth(ctx, item.GetId().GetValue()) { - api.Collect(resp, api.NewAuthResponse(apimodel.Code_NotAllowedAccess)) + saveGroup := helper.GetGroup(ctx, &apisecurity.UserGroup{Id: groups[i].GetId()}) + if saveGroup == nil { + return api.NewBatchWriteResponse(apimodel.Code_NotFoundUserGroup) } + resources = append(resources, authmodel.ResourceEntry{ + ID: groups[i].GetId().GetValue(), + }) } - if !api.IsSuccess(resp) { - return resp + + authCtx := authcommon.NewAcquireContext( + authcommon.WithRequestContext(ctx), + authcommon.WithOperation(authcommon.Modify), + authcommon.WithModule(authcommon.AuthModule), + authcommon.WithMethod(authcommon.DeleteUserGroups), + authcommon.WithAccessResources(map[apisecurity.ResourceType][]authmodel.ResourceEntry{ + apisecurity.ResourceType_UserGroups: resources, + }), + ) + + if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + return api.NewBatchWriteResponseWithMsg(authcommon.ConvertToErrCode(err), err.Error()) } return svr.nextSvr.DeleteGroups(ctx, groups) } // GetGroups 查询用户组列表(不带用户详细信息) func (svr *Server) GetGroups(ctx context.Context, query map[string]string) *apiservice.BatchQueryResponse { - ctx, rsp := svr.verifyAuth(ctx, WriteOp, MustOwner) - if rsp != nil { - resp := api.NewAuthBatchQueryResponse(apimodel.Code_ExecuteSuccess) - api.QueryCollect(resp, rsp) - return resp + authCtx := authcommon.NewAcquireContext( + authcommon.WithRequestContext(ctx), + authcommon.WithOperation(authcommon.Read), + authcommon.WithModule(authcommon.AuthModule), + authcommon.WithMethod(authcommon.DescribeUserGroups), + ) + if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + return api.NewBatchQueryResponseWithMsg(authcommon.ConvertToErrCode(err), err.Error()) } - - delete(query, "owner") - if authcommon.ParseUserRole(ctx) != model.AdminUserRole { + ctx = authCtx.GetRequestContext() + if authcommon.ParseUserRole(ctx) != authmodel.AdminUserRole { // step 1: 设置 owner 信息,只能查看归属主帐户下的用户组 query["owner"] = utils.ParseOwnerID(ctx) - if authcommon.ParseUserRole(ctx) != model.OwnerUserRole { - // step 2: 非主帐户,只能查看自己所在的用户组 - if _, ok := query["user_id"]; !ok { - query["user_id"] = utils.ParseUserID(ctx) - } - } } + cachetypes.AppendUserPredicate(ctx, func(ctx context.Context, u *authcommon.User) bool { + return svr.policySvr.GetAuthChecker().ResourcePredicate(authCtx, &authmodel.ResourceEntry{ + Type: apisecurity.ResourceType_UserGroups, + ID: u.ID, + Metadata: u.Metadata, + }) + }) + delete(query, "owner") return svr.nextSvr.GetGroups(ctx, query) } // GetGroup 根据用户组信息,查询该用户组下的用户相信 func (svr *Server) GetGroup(ctx context.Context, req *apisecurity.UserGroup) *apiservice.Response { - ctx, rsp := svr.verifyAuth(ctx, ReadOp, NotOwner) - if rsp != nil { - return rsp - } - if !svr.checkGroupViewAuth(ctx, req.GetId().GetValue()) { - return api.NewAuthResponse(apimodel.Code_NotAllowedAccess) + helper := svr.nextSvr.GetUserHelper() + saveGroup := helper.GetGroup(ctx, req) + if saveGroup == nil { + return api.NewResponse(apimodel.Code_NotFoundUserGroup) + } + authCtx := authcommon.NewAcquireContext( + authcommon.WithRequestContext(ctx), + authcommon.WithOperation(authcommon.Read), + authcommon.WithModule(authcommon.AuthModule), + authcommon.WithMethod(authcommon.DescribeUserGroupDetail), + authcommon.WithAccessResources(map[apisecurity.ResourceType][]authmodel.ResourceEntry{ + apisecurity.ResourceType_UserGroups: { + authmodel.ResourceEntry{ + Type: apisecurity.ResourceType_UserGroups, + ID: req.GetId().GetValue(), + Metadata: saveGroup.Metadata, + }, + }, + }), + ) + + if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + return api.NewResponseWithMsg(authcommon.ConvertToErrCode(err), err.Error()) } return svr.nextSvr.GetGroup(ctx, req) } // GetGroupToken 获取用户组的 token func (svr *Server) GetGroupToken(ctx context.Context, group *apisecurity.UserGroup) *apiservice.Response { - ctx, rsp := svr.verifyAuth(ctx, ReadOp, NotOwner) - if rsp != nil { - return rsp - } - if !svr.checkGroupViewAuth(ctx, group.GetId().GetValue()) { - return api.NewAuthResponse(apimodel.Code_NotAllowedAccess) - } - return svr.nextSvr.GetGroupToken(ctx, group) -} - -// UpdateGroupToken 取消用户组的 token 使用 -func (svr *Server) UpdateGroupToken(ctx context.Context, group *apisecurity.UserGroup) *apiservice.Response { - ctx, rsp := svr.verifyAuth(ctx, WriteOp, MustOwner) - if rsp != nil { - return rsp - } - saveGroup := svr.GetUserHelper().GetGroup(ctx, &apisecurity.UserGroup{ - Id: wrapperspb.String(group.GetId().GetValue()), - }) + helper := svr.nextSvr.GetUserHelper() + saveGroup := helper.GetGroup(ctx, group) if saveGroup == nil { - return api.NewAuthResponse(apimodel.Code_NotFoundUserGroup) - } - if authcommon.ParseUserRole(ctx) != model.AdminUserRole { - if saveGroup.GetOwner().GetValue() != utils.ParseUserID(ctx) { - return api.NewAuthResponse(apimodel.Code_NotAllowedAccess) - } - } - return svr.nextSvr.UpdateGroupToken(ctx, group) -} - -// ResetGroupToken 重置用户组的 token -func (svr *Server) ResetGroupToken(ctx context.Context, group *apisecurity.UserGroup) *apiservice.Response { - ctx, rsp := svr.verifyAuth(ctx, WriteOp, MustOwner) - if rsp != nil { - return rsp - } - saveGroup := svr.GetUserHelper().GetGroup(ctx, &apisecurity.UserGroup{ - Id: wrapperspb.String(group.GetId().GetValue()), - }) - if saveGroup == nil { - return api.NewAuthResponse(apimodel.Code_NotFoundUserGroup) - } - if authcommon.ParseUserRole(ctx) != model.AdminUserRole { - if saveGroup.GetOwner().GetValue() != utils.ParseUserID(ctx) { - return api.NewAuthResponse(apimodel.Code_NotAllowedAccess) - } - } - return svr.nextSvr.ResetGroupToken(ctx, group) -} - -// verifyAuth 用于 user、group 以及 strategy 模块的鉴权工作检查 -func (svr *Server) verifyAuth(ctx context.Context, isWrite bool, - needOwner bool) (context.Context, *apiservice.Response) { - authToken := utils.ParseAuthToken(ctx) - - if authToken == "" { - log.Error("[Auth][Server] auth token is empty", utils.RequestID(ctx)) - return nil, api.NewAuthResponse(apimodel.Code_EmptyAutToken) - } - - authCtx := model.NewAcquireContext( - model.WithRequestContext(ctx), - model.WithModule(model.AuthModule), + return api.NewResponse(apimodel.Code_NotFoundUserGroup) + } + authCtx := authcommon.NewAcquireContext( + authcommon.WithRequestContext(ctx), + authcommon.WithOperation(authcommon.Read), + authcommon.WithModule(authcommon.AuthModule), + authcommon.WithMethod(authcommon.DescribeUserGroupToken), + authcommon.WithAccessResources(map[apisecurity.ResourceType][]authmodel.ResourceEntry{ + apisecurity.ResourceType_UserGroups: { + authmodel.ResourceEntry{ + ID: group.GetId().GetValue(), + Type: apisecurity.ResourceType_UserGroups, + Metadata: saveGroup.Metadata, + }, + }, + }), ) - // case 1. 如果 error 不是 token 被禁止的 error,直接返回 - // case 2. 如果 error 是 token 被禁止,按下面情况判断 - // i. 如果当前只是一个数据的读取操作,则放通 - // ii. 如果当前是一个数据的写操作,则只能允许处于正常的 token 进行操作 - if err := svr.CheckCredential(authCtx); err != nil { - log.Error("[Auth][Server] verify auth token", utils.RequestID(ctx), zap.Error(err)) - return nil, api.NewAuthResponse(apimodel.Code_AuthTokenForbidden) - } - - attachVal, exist := authCtx.GetAttachment(model.TokenDetailInfoKey) - if !exist { - log.Error("[Auth][Server] token detail info not exist", utils.RequestID(ctx)) - return nil, api.NewAuthResponse(apimodel.Code_TokenNotExisted) + if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + return api.NewResponseWithMsg(authcommon.ConvertToErrCode(err), err.Error()) } - - operateInfo := attachVal.(auth.OperatorInfo) - if isWrite && operateInfo.Disable { - log.Error("[Auth][Server] token is disabled", utils.RequestID(ctx), - zap.String("operation", authCtx.GetMethod())) - return nil, api.NewAuthResponse(apimodel.Code_TokenDisabled) - } - - if !operateInfo.IsUserToken { - log.Error("[Auth][Server] only user role can access this API", utils.RequestID(ctx)) - return nil, api.NewAuthResponse(apimodel.Code_OperationRoleForbidden) - } - - if needOwner && auth.IsSubAccount(operateInfo) { - log.Error("[Auth][Server] only admin/owner account can access this API", utils.RequestID(ctx)) - return nil, api.NewAuthResponse(apimodel.Code_OperationRoleForbidden) - } - - return authCtx.GetRequestContext(), nil -} - -// checkUserViewPermission 检查是否可以操作该用户 -// Case 1: 如果是自己操作自己,通过 -// Case 2: 如果是主账户操作自己的子账户,通过 -// Case 3: 如果是超级账户,通过 -func checkUserViewPermission(ctx context.Context, user *apisecurity.User) bool { - role := authcommon.ParseUserRole(ctx) - if role == model.AdminUserRole { - log.Debug("check user view permission", utils.RequestID(ctx), zap.Bool("admin", true)) - return true - } - - userId := utils.ParseUserID(ctx) - if user.GetId().GetValue() == userId { - return true - } - - if user.GetOwner().GetValue() == userId { - log.Debug("check user view permission", utils.RequestID(ctx), - zap.Any("user", user), zap.String("owner", user.GetOwner().GetValue()), zap.String("operator", userId)) - return true - } - log.Warn("check user view permission", utils.RequestID(ctx), - zap.Any("user", user), zap.String("owner", user.GetOwner().GetValue()), zap.String("operator", userId)) - return false + return svr.nextSvr.GetGroupToken(ctx, group) } -// checkUpdateGroup 检查用户组的更新请求 -func (svr *Server) checkUpdateGroup(ctx context.Context, req *apisecurity.ModifyUserGroup) *apiservice.Response { - userId := utils.ParseUserID(ctx) - isOwner := utils.ParseIsOwner(ctx) - saveGroup := svr.GetUserHelper().GetGroup(ctx, &apisecurity.UserGroup{ - Id: wrapperspb.String(req.GetId().GetValue()), - }) +// EnableGroupToken 取消用户组的 token 使用 +func (svr *Server) EnableGroupToken(ctx context.Context, group *apisecurity.UserGroup) *apiservice.Response { + helper := svr.nextSvr.GetUserHelper() + saveGroup := helper.GetGroup(ctx, group) if saveGroup == nil { - return api.NewAuthResponse(apimodel.Code_NotFoundUserGroup) - } + return api.NewResponse(apimodel.Code_NotFoundUserGroup) + } + authCtx := authcommon.NewAcquireContext( + authcommon.WithRequestContext(ctx), + authcommon.WithOperation(authcommon.Modify), + authcommon.WithModule(authcommon.AuthModule), + authcommon.WithMethod(authcommon.EnableUserGroupToken), + authcommon.WithAccessResources(map[apisecurity.ResourceType][]authmodel.ResourceEntry{ + apisecurity.ResourceType_UserGroups: { + authmodel.ResourceEntry{ + ID: group.GetId().GetValue(), + Type: apisecurity.ResourceType_UserGroups, + Metadata: saveGroup.Metadata, + }, + }, + }), + ) - // 满足以下情况才可以进行操作 - // 1.管理员 - // 2.自己在这个用户组里面 - // 3.自己是这个用户组的owner角色 - if authcommon.ParseUserRole(ctx) != model.AdminUserRole { - inGroup := false - for i := range saveGroup.GetRelation().GetUsers() { - if userId == saveGroup.GetRelation().GetUsers()[i].GetId().GetValue() { - inGroup = true - break - } - } - if !inGroup && saveGroup.GetOwner().GetValue() != userId { - return api.NewAuthResponse(apimodel.Code_NotAllowedAccess) - } - // 如果当前用户只是在这个组里面,但不是该用户组的owner,那只能添加用户,不能删除用户 - if inGroup && !isOwner && len(req.GetRemoveRelations().GetUsers()) != 0 { - return api.NewAuthResponseWithMsg( - apimodel.Code_NotAllowedAccess, "only main account can remove user from usergroup") - } + if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + return api.NewResponseWithMsg(authcommon.ConvertToErrCode(err), err.Error()) } - return nil + return svr.nextSvr.EnableGroupToken(ctx, group) } -func (svr *Server) checkGroupViewAuth(ctx context.Context, id string) bool { - saveGroup := svr.GetUserHelper().GetGroup(ctx, &apisecurity.UserGroup{ - Id: wrapperspb.String(id), - }) +// ResetGroupToken 重置用户组的 token +func (svr *Server) ResetGroupToken(ctx context.Context, group *apisecurity.UserGroup) *apiservice.Response { + helper := svr.nextSvr.GetUserHelper() + saveGroup := helper.GetGroup(ctx, group) if saveGroup == nil { - return false - } + return api.NewResponse(apimodel.Code_NotFoundUserGroup) + } + authCtx := authcommon.NewAcquireContext( + authcommon.WithRequestContext(ctx), + authcommon.WithOperation(authcommon.Modify), + authcommon.WithModule(authcommon.AuthModule), + authcommon.WithMethod(authcommon.ResetUserGroupToken), + authcommon.WithAccessResources(map[apisecurity.ResourceType][]authmodel.ResourceEntry{ + apisecurity.ResourceType_UserGroups: { + authmodel.ResourceEntry{ + ID: group.GetId().GetValue(), + Type: apisecurity.ResourceType_UserGroups, + Metadata: saveGroup.Metadata, + }, + }, + }), + ) - if authcommon.ParseUserRole(ctx) != model.AdminUserRole { - userID := utils.ParseUserID(ctx) - inGroup := svr.GetUserHelper().CheckUserInGroup(ctx, &apisecurity.UserGroup{ - Id: wrapperspb.String(id), - }, &apisecurity.User{ - Id: wrapperspb.String(userID), - }) - isGroupOwner := saveGroup.GetOwner().GetValue() == userID - if !isGroupOwner && !inGroup { - log.Error("can't see group info", zap.String("user", userID), - zap.String("group", id), zap.Bool("group-owner", isGroupOwner), - zap.Bool("in-group", inGroup)) - return false - } + if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + return api.NewResponseWithMsg(authcommon.ConvertToErrCode(err), err.Error()) } - return true + return svr.nextSvr.ResetGroupToken(ctx, group) } diff --git a/auth/user/inteceptor/paramcheck/server.go b/auth/user/inteceptor/paramcheck/server.go new file mode 100644 index 000000000..f6917b00d --- /dev/null +++ b/auth/user/inteceptor/paramcheck/server.go @@ -0,0 +1,228 @@ +/** + * Tencent is pleased to support the open source community by making Polaris available. + * + * Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. + * + * Licensed under the BSD 3-Clause License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://opensource.org/licenses/BSD-3-Clause + * + * Unless required by applicable law or agreed to in writing, software distributed + * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR + * CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package paramcheck + +import ( + "context" + "strconv" + + apimodel "github.com/polarismesh/specification/source/go/api/v1/model" + apisecurity "github.com/polarismesh/specification/source/go/api/v1/security" + apiservice "github.com/polarismesh/specification/source/go/api/v1/service_manage" + "go.uber.org/zap" + + "github.com/polarismesh/polaris/auth" + cachetypes "github.com/polarismesh/polaris/cache/api" + api "github.com/polarismesh/polaris/common/api/v1" + "github.com/polarismesh/polaris/common/log" + authcommon "github.com/polarismesh/polaris/common/model/auth" + "github.com/polarismesh/polaris/common/utils" + "github.com/polarismesh/polaris/store" +) + +var ( + // UserFilterAttributes 查询用户所能允许的参数查询列表 + UserFilterAttributes = map[string]bool{ + "id": true, + "name": true, + "owner": true, + "source": true, + "offset": true, + "group_id": true, + "limit": true, + "hide_admin": true, + } + // UserGroupAttributes is the user link group attributes + UserGroupAttributes = map[string]struct{}{ + "id": {}, + "user_id": {}, + "user_name": {}, + "group_id": {}, + "name": {}, + "offset": {}, + "limit": {}, + "owner": {}, + } +) + +func NewServer(nextSvr auth.UserServer) auth.UserServer { + return &Server{ + nextSvr: nextSvr, + } +} + +type Server struct { + nextSvr auth.UserServer + policySvr auth.StrategyServer +} + +// Initialize 初始化 +func (svr *Server) Initialize(authOpt *auth.Config, storage store.Store, policySvr auth.StrategyServer, + cacheMgr cachetypes.CacheManager) error { + return svr.nextSvr.Initialize(authOpt, storage, policySvr, cacheMgr) +} + +// Name 用户数据管理server名称 +func (svr *Server) Name() string { + return svr.nextSvr.Name() +} + +// Login 登录动作 +func (svr *Server) Login(req *apisecurity.LoginRequest) *apiservice.Response { + return svr.nextSvr.Login(req) +} + +// CheckCredential 检查当前操作用户凭证 +func (svr *Server) CheckCredential(authCtx *authcommon.AcquireContext) error { + return svr.nextSvr.CheckCredential(authCtx) +} + +// GetUserHelper +func (svr *Server) GetUserHelper() auth.UserHelper { + return svr.nextSvr.GetUserHelper() +} + +// CreateUsers 批量创建用户 +func (svr *Server) CreateUsers(ctx context.Context, users []*apisecurity.User) *apiservice.BatchWriteResponse { + return svr.nextSvr.CreateUsers(ctx, users) +} + +// UpdateUser 更新用户信息 +func (svr *Server) UpdateUser(ctx context.Context, user *apisecurity.User) *apiservice.Response { + return svr.nextSvr.UpdateUser(ctx, user) +} + +// UpdateUserPassword 更新用户密码 +func (svr *Server) UpdateUserPassword(ctx context.Context, req *apisecurity.ModifyUserPassword) *apiservice.Response { + return svr.nextSvr.UpdateUserPassword(ctx, req) +} + +// DeleteUsers 批量删除用户 +func (svr *Server) DeleteUsers(ctx context.Context, users []*apisecurity.User) *apiservice.BatchWriteResponse { + return svr.nextSvr.DeleteUsers(ctx, users) +} + +// GetUsers 查询用户列表 +func (svr *Server) GetUsers(ctx context.Context, query map[string]string) *apiservice.BatchQueryResponse { + log.Debug("[Auth][User] origin get users query params", utils.RequestID(ctx), zap.Any("query", query)) + var ( + offset, limit uint32 + err error + searchFilters = make(map[string]string, len(query)+1) + ) + + for key, value := range query { + if _, ok := UserFilterAttributes[key]; !ok { + log.Error("[Auth][User] attribute it not allowed", utils.RequestID(ctx), zap.String("key", key)) + return api.NewAuthBatchQueryResponseWithMsg(apimodel.Code_InvalidParameter, key+" is not allowed") + } + searchFilters[key] = value + } + + offset, limit, err = utils.ParseOffsetAndLimit(searchFilters) + if err != nil { + return api.NewAuthBatchQueryResponse(apimodel.Code_InvalidParameter) + } + searchFilters["offset"] = strconv.FormatUint(uint64(offset), 10) + searchFilters["limit"] = strconv.FormatUint(uint64(limit), 10) + return svr.nextSvr.GetUsers(ctx, query) +} + +// GetUserToken 获取用户的 token +func (svr *Server) GetUserToken(ctx context.Context, user *apisecurity.User) *apiservice.Response { + return svr.nextSvr.GetUserToken(ctx, user) +} + +// EnableUserToken 禁止用户的token使用 +func (svr *Server) EnableUserToken(ctx context.Context, user *apisecurity.User) *apiservice.Response { + helper := svr.nextSvr.GetUserHelper() + saveUser := helper.GetUserByID(ctx, user.GetId().GetValue()) + if saveUser == nil { + return api.NewResponse(apimodel.Code_NotFoundUser) + } + if authcommon.ParseUserRole(ctx) != authcommon.AdminUserRole { + if saveUser.GetUserType().GetValue() != strconv.Itoa(int(authcommon.SubAccountUserRole)) { + return api.NewUserResponseWithMsg(apimodel.Code_NotAllowedAccess, "only disable sub-account token", user) + } + } + return svr.nextSvr.EnableUserToken(ctx, user) +} + +// ResetUserToken 重置用户的token +func (svr *Server) ResetUserToken(ctx context.Context, user *apisecurity.User) *apiservice.Response { + return svr.nextSvr.ResetUserToken(ctx, user) +} + +// CreateGroup 创建用户组 +func (svr *Server) CreateGroup(ctx context.Context, group *apisecurity.UserGroup) *apiservice.Response { + return svr.nextSvr.CreateGroup(ctx, group) +} + +// UpdateGroups 更新用户组 +func (svr *Server) UpdateGroups(ctx context.Context, groups []*apisecurity.ModifyUserGroup) *apiservice.BatchWriteResponse { + return svr.nextSvr.UpdateGroups(ctx, groups) +} + +// DeleteGroups 批量删除用户组 +func (svr *Server) DeleteGroups(ctx context.Context, groups []*apisecurity.UserGroup) *apiservice.BatchWriteResponse { + return svr.nextSvr.DeleteGroups(ctx, groups) +} + +// GetGroups 查询用户组列表(不带用户详细信息) +func (svr *Server) GetGroups(ctx context.Context, query map[string]string) *apiservice.BatchQueryResponse { + log.Info("[Auth][Group] origin get groups query params", + utils.RequestID(ctx), zap.Any("query", query)) + + offset, limit, err := utils.ParseOffsetAndLimit(query) + if err != nil { + return api.NewAuthBatchQueryResponse(apimodel.Code_InvalidParameter) + } + + searchFilters := make(map[string]string, len(query)) + for key, value := range query { + if _, ok := UserGroupAttributes[key]; !ok { + log.Error("[Auth][Group] get groups attribute it not allowed", utils.RequestID(ctx), zap.String("key", key)) + return api.NewAuthBatchQueryResponseWithMsg(apimodel.Code_InvalidParameter, key+" is not allowed") + } + searchFilters[key] = value + } + + searchFilters["offset"] = strconv.FormatUint(uint64(offset), 10) + searchFilters["limit"] = strconv.FormatUint(uint64(limit), 10) + return svr.nextSvr.GetGroups(ctx, query) +} + +// GetGroup 根据用户组信息,查询该用户组下的用户相信 +func (svr *Server) GetGroup(ctx context.Context, req *apisecurity.UserGroup) *apiservice.Response { + return svr.nextSvr.GetGroup(ctx, req) +} + +// GetGroupToken 获取用户组的 token +func (svr *Server) GetGroupToken(ctx context.Context, group *apisecurity.UserGroup) *apiservice.Response { + return svr.nextSvr.GetGroupToken(ctx, group) +} + +// UpdateGroupToken 取消用户组的 token 使用 +func (svr *Server) EnableGroupToken(ctx context.Context, group *apisecurity.UserGroup) *apiservice.Response { + return svr.nextSvr.EnableGroupToken(ctx, group) +} + +// ResetGroupToken 重置用户组的 token +func (svr *Server) ResetGroupToken(ctx context.Context, group *apisecurity.UserGroup) *apiservice.Response { + return svr.nextSvr.ResetGroupToken(ctx, group) +} diff --git a/auth/user/server.go b/auth/user/server.go index 2c76d14fa..db1484327 100644 --- a/auth/user/server.go +++ b/auth/user/server.go @@ -30,6 +30,7 @@ import ( cachetypes "github.com/polarismesh/polaris/cache/api" api "github.com/polarismesh/polaris/common/api/v1" "github.com/polarismesh/polaris/common/model" + authcommon "github.com/polarismesh/polaris/common/model/auth" "github.com/polarismesh/polaris/common/utils" "github.com/polarismesh/polaris/plugin" "github.com/polarismesh/polaris/store" @@ -63,11 +64,11 @@ func DefaultUserConfig() *AuthConfig { } type Server struct { - authOpt *AuthConfig - storage store.Store - history plugin.History - cacheMgr cachetypes.CacheManager - helper auth.UserHelper + authOpt *AuthConfig + storage store.Store + policySvr auth.StrategyServer + cacheMgr cachetypes.CacheManager + helper auth.UserHelper } // Name of the user operator plugin @@ -75,9 +76,10 @@ func (svr *Server) Name() string { return auth.DefaultUserMgnPluginName } -func (svr *Server) Initialize(authOpt *auth.Config, storage store.Store, cacheMgr cachetypes.CacheManager) error { +func (svr *Server) Initialize(authOpt *auth.Config, storage store.Store, policySvr auth.StrategyServer, cacheMgr cachetypes.CacheManager) error { svr.cacheMgr = cacheMgr svr.storage = storage + svr.policySvr = policySvr if err := svr.parseOptions(authOpt); err != nil { return err } @@ -85,11 +87,6 @@ func (svr *Server) Initialize(authOpt *auth.Config, storage store.Store, cacheMg _ = cacheMgr.OpenResourceCache(cachetypes.ConfigEntry{ Name: cachetypes.UsersName, }) - // 获取History插件,注意:插件的配置在bootstrap已经设置好 - svr.history = plugin.GetHistory() - if svr.history == nil { - log.Warnf("Not Found History Log Plugin") - } svr.helper = &DefaultUserHelper{svr: svr} return nil } @@ -150,9 +147,9 @@ func (svr *Server) Login(req *apisecurity.LoginRequest) *apiservice.Response { if err != nil { if errors.Is(err, bcrypt.ErrMismatchedHashAndPassword) { return api.NewAuthResponseWithMsg( - apimodel.Code_NotAllowedAccess, model.ErrorWrongUsernameOrPassword.Error()) + apimodel.Code_NotAllowedAccess, authcommon.ErrorWrongUsernameOrPassword.Error()) } - return api.NewAuthResponseWithMsg(apimodel.Code_ExecuteException, model.ErrorWrongUsernameOrPassword.Error()) + return api.NewAuthResponseWithMsg(apimodel.Code_ExecuteException, authcommon.ErrorWrongUsernameOrPassword.Error()) } return api.NewLoginResponse(apimodel.Code_ExecuteSuccess, &apisecurity.LoginResponse{ @@ -160,23 +157,13 @@ func (svr *Server) Login(req *apisecurity.LoginRequest) *apiservice.Response { OwnerId: utils.NewStringValue(user.Owner), Token: utils.NewStringValue(user.Token), Name: utils.NewStringValue(user.Name), - Role: utils.NewStringValue(model.UserRoleNames[user.Type]), + Role: utils.NewStringValue(authcommon.UserRoleNames[user.Type]), }) } // RecordHistory Server对外提供history插件的简单封装 func (svr *Server) RecordHistory(entry *model.RecordEntry) { - // 如果插件没有初始化,那么不记录history - if svr.history == nil { - return - } - // 如果数据为空,则不需要打印了 - if entry == nil { - return - } - - // 调用插件记录history - svr.history.Record(entry) + plugin.GetHistory().Record(entry) } func (svr *Server) GetUserHelper() auth.UserHelper { diff --git a/auth/user/token.go b/auth/user/token.go index 3942257ef..411401205 100644 --- a/auth/user/token.go +++ b/auth/user/token.go @@ -30,13 +30,13 @@ import ( "github.com/google/uuid" "github.com/polarismesh/polaris/auth" - "github.com/polarismesh/polaris/common/model" + authcommon "github.com/polarismesh/polaris/common/model/auth" ) // decodeToken 解析 token 信息,如果 t == "",直接返回一个空对象 func (svr *Server) decodeToken(t string) (auth.OperatorInfo, error) { if t == "" { - return auth.OperatorInfo{}, model.ErrorTokenInvalid + return auth.OperatorInfo{}, authcommon.ErrorTokenInvalid } ret, err := DecryptMessage([]byte(svr.authOpt.Salt), t) @@ -45,59 +45,79 @@ func (svr *Server) decodeToken(t string) (auth.OperatorInfo, error) { } tokenDetails := strings.Split(ret, TokenSplit) if len(tokenDetails) != 2 { - return auth.OperatorInfo{}, model.ErrorTokenInvalid + return auth.OperatorInfo{}, authcommon.ErrorTokenInvalid } detail := strings.Split(tokenDetails[1], "/") if len(detail) != 2 { - return auth.OperatorInfo{}, model.ErrorTokenInvalid + return auth.OperatorInfo{}, authcommon.ErrorTokenInvalid } tokenInfo := auth.OperatorInfo{ Origin: t, - IsUserToken: detail[0] == model.TokenForUser, + IsUserToken: detail[0] == authcommon.TokenForUser, OperatorID: detail[1], - Role: model.UnknownUserRole, + Role: authcommon.UnknownUserRole, } return tokenInfo, nil } +type TokenPrincipal interface { + GetToken() string + Disable() bool + OwnerID() string + SelfID() string +} + // checkToken 对 token 进行检查,如果 token 是一个空,直接返回默认值,但是不返回错误 // return {owner-id} {is-owner} {error} func (svr *Server) checkToken(tokenInfo *auth.OperatorInfo) (string, bool, error) { if auth.IsEmptyOperator(*tokenInfo) { return "", false, nil } + principal, err := svr.getTokenPrincipal(tokenInfo) + if err != nil { + return "", false, err + } - id := tokenInfo.OperatorID + if tokenInfo.Origin != principal.GetToken() { + return "", false, authcommon.ErrorTokenNotExist + } + tokenInfo.Disable = principal.Disable() + if principal.OwnerID() == "" { + return principal.SelfID(), true, nil + } + + return principal.OwnerID(), false, nil +} + +func (svr *Server) getTokenPrincipal(tokenInfo *auth.OperatorInfo) (TokenPrincipal, error) { if tokenInfo.IsUserToken { - user := svr.cacheMgr.User().GetUserByID(id) - if user == nil { - return "", false, model.ErrorNoUser + user := svr.cacheMgr.User().GetUserByID(tokenInfo.OperatorID) + if user != nil { + return user, nil } - - if tokenInfo.Origin != user.Token { - return "", false, model.ErrorTokenNotExist + if err := svr.cacheMgr.User().Update(); err != nil { + return nil, err } - - tokenInfo.Disable = !user.TokenEnable - if user.Owner == "" { - return user.ID, true, nil + user = svr.cacheMgr.User().GetUserByID(tokenInfo.OperatorID) + if user != nil { + return user, nil } - - return user.Owner, false, nil + return nil, authcommon.ErrorNoUser } - group := svr.cacheMgr.User().GetGroup(id) - if group == nil { - return "", false, model.ErrorNoUserGroup + group := svr.cacheMgr.User().GetGroup(tokenInfo.OperatorID) + if group != nil { + return group, nil } - - if tokenInfo.Origin != group.Token { - return "", false, model.ErrorTokenNotExist + if err := svr.cacheMgr.User().Update(); err != nil { + return nil, err } - - tokenInfo.Disable = !group.TokenEnable - return group.Owner, false, nil + group = svr.cacheMgr.User().GetGroup(tokenInfo.OperatorID) + if group != nil { + return group, nil + } + return nil, authcommon.ErrorNoUserGroup } const ( @@ -125,9 +145,9 @@ func CreateToken(uid, gid string, salt string) (string, error) { var val string if uid == "" { - val = fmt.Sprintf("%s/%s", model.TokenForUserGroup, gid) + val = fmt.Sprintf("%s/%s", authcommon.TokenForUserGroup, gid) } else { - val = fmt.Sprintf("%s/%s", model.TokenForUser, uid) + val = fmt.Sprintf("%s/%s", authcommon.TokenForUser, uid) } token := fmt.Sprintf(TokenPattern, uuid.NewString()[8:16], val) diff --git a/auth/user/user.go b/auth/user/user.go index e04453830..54535f9f0 100644 --- a/auth/user/user.go +++ b/auth/user/user.go @@ -31,6 +31,7 @@ import ( "golang.org/x/crypto/bcrypt" "github.com/polarismesh/polaris/auth" + cachetypes "github.com/polarismesh/polaris/cache/api" api "github.com/polarismesh/polaris/common/api/v1" "github.com/polarismesh/polaris/common/model" authcommon "github.com/polarismesh/polaris/common/model/auth" @@ -41,22 +42,7 @@ import ( type ( // User2Api convert user to api.User - User2Api func(user *model.User) *apisecurity.User -) - -var ( - - // UserFilterAttributes 查询用户所能允许的参数查询列表 - UserFilterAttributes = map[string]bool{ - "id": true, - "name": true, - "owner": true, - "source": true, - "offset": true, - "group_id": true, - "limit": true, - "hide_admin": true, - } + User2Api func(user *authcommon.User) *apisecurity.User ) // CreateUsers 批量创建用户 @@ -81,7 +67,7 @@ func (svr *Server) CreateUser(ctx context.Context, req *apisecurity.User) *apise } // 如果创建的目标账户类型是非子账户,则 ownerId 需要设置为 “” - if convertCreateUserRole(authcommon.ParseUserRole(ctx)) != model.SubAccountUserRole { + if convertCreateUserRole(authcommon.ParseUserRole(ctx)) != authcommon.SubAccountUserRole { ownerID = "" } @@ -115,17 +101,40 @@ func (svr *Server) CreateUser(ctx context.Context, req *apisecurity.User) *apise func (svr *Server) createUser(ctx context.Context, req *apisecurity.User) *apiservice.Response { data, err := svr.createUserModel(req, authcommon.ParseUserRole(ctx)) - if err != nil { log.Error("[Auth][User] create user model", utils.RequestID(ctx), zap.Error(err)) return api.NewAuthResponse(apimodel.Code_ExecuteException) } - if err := svr.storage.AddUser(data); err != nil { + tx, err := svr.storage.StartTx() + if err != nil { + log.Error("[Auth][User] create user begion storage tx", utils.RequestID(ctx), zap.Error(err)) + return api.NewAuthResponse(apimodel.Code_ExecuteException) + } + defer func() { + _ = tx.Rollback() + }() + + if err := svr.storage.AddUser(tx, data); err != nil { log.Error("[Auth][User] add user into store", utils.RequestID(ctx), zap.Error(err)) return api.NewAuthResponse(commonstore.StoreCode2APICode(err)) } + if err := svr.policySvr.PolicyHelper().CreatePrincipal(ctx, tx, authcommon.Principal{ + PrincipalID: data.ID, + PrincipalType: authcommon.PrincipalUser, + Owner: data.Owner, + Name: data.Name, + }); err != nil { + log.Error("[Auth][User] add user default policy rule", utils.RequestID(ctx), zap.Error(err)) + return api.NewAuthResponse(commonstore.StoreCode2APICode(err)) + } + + if err := tx.Commit(); err != nil { + log.Error("[Auth][User] create user commit storage tx", utils.RequestID(ctx), zap.Error(err)) + return api.NewAuthResponse(apimodel.Code_ExecuteException) + } + log.Info("[Auth][User] create user", utils.RequestID(ctx), zap.String("name", req.GetName().GetValue())) svr.RecordHistory(userRecordEntry(ctx, req, data, model.OCreate)) @@ -183,8 +192,8 @@ func (svr *Server) UpdateUserPassword(ctx context.Context, req *apisecurity.Modi return api.NewAuthResponse(apimodel.Code_NotFoundUser) } - ignoreOrigin := authcommon.ParseUserRole(ctx) == model.AdminUserRole || - authcommon.ParseUserRole(ctx) == model.OwnerUserRole + ignoreOrigin := authcommon.ParseUserRole(ctx) == authcommon.AdminUserRole || + authcommon.ParseUserRole(ctx) == authcommon.OwnerUserRole data, needUpdate, err := updateUserPasswordAttribute(ignoreOrigin, user, req) if err != nil { log.Error("[Auth][User] compute user update attribute", zap.Error(err), @@ -227,10 +236,9 @@ func (svr *Server) DeleteUsers(ctx context.Context, reqs []*apisecurity.User) *a // Case 3. 主账户角色下,只能删除自己创建的子账户 // Case 4. 超级账户角色下,可以删除任意账户 func (svr *Server) DeleteUser(ctx context.Context, req *apisecurity.User) *apiservice.Response { - requestID := utils.ParseRequestID(ctx) user, err := svr.storage.GetUser(req.Id.GetValue()) if err != nil { - log.Error("[Auth][User] get user from store", utils.ZapRequestID(requestID), zap.Error(err)) + log.Error("[Auth][User] get user from store", utils.RequestID(ctx), zap.Error(err)) return api.NewUserResponse(commonstore.StoreCode2APICode(err), req) } if user == nil { @@ -239,14 +247,14 @@ func (svr *Server) DeleteUser(ctx context.Context, req *apisecurity.User) *apise if user.ID == utils.ParseOwnerID(ctx) { log.Error("[Auth][User] delete user forbidden, can't delete when self is owner", - utils.ZapRequestID(requestID), zap.String("name", req.Name.GetValue())) + utils.RequestID(ctx), zap.String("name", req.Name.GetValue())) return api.NewUserResponse(apimodel.Code_NotAllowedAccess, req) } - if user.Type == model.OwnerUserRole { + if user.Type == authcommon.OwnerUserRole { count, err := svr.storage.GetSubCount(user) if err != nil { log.Error("[Auth][User] get user sub-account", zap.String("owner", user.ID), - utils.ZapRequestID(requestID), zap.Error(err)) + utils.RequestID(ctx), zap.Error(err)) return api.NewUserResponse(commonstore.StoreCode2APICode(err), req) } if count != 0 { @@ -254,53 +262,49 @@ func (svr *Server) DeleteUser(ctx context.Context, req *apisecurity.User) *apise return api.NewUserResponse(apimodel.Code_SubAccountExisted, req) } } + tx, err := svr.storage.StartTx() + if err != nil { + log.Error("[Auth][User] delete user begion storage tx", utils.RequestID(ctx), zap.Error(err)) + return api.NewAuthResponse(apimodel.Code_ExecuteException) + } + defer func() { + _ = tx.Rollback() + }() - if err := svr.storage.DeleteUser(user); err != nil { - log.Error("[Auth][User] delete user from store", utils.ZapRequestID(requestID), zap.Error(err)) + if err := svr.storage.DeleteUser(tx, user); err != nil { + log.Error("[Auth][User] delete user from store", utils.RequestID(ctx), zap.Error(err)) + return api.NewAuthResponse(commonstore.StoreCode2APICode(err)) + } + if err := svr.policySvr.PolicyHelper().CleanPrincipal(ctx, tx, authcommon.Principal{ + PrincipalID: user.ID, + PrincipalType: authcommon.PrincipalUser, + Owner: user.Owner, + }); err != nil { + log.Error("[Auth][User] delete user from policy server", utils.RequestID(ctx), zap.Error(err)) return api.NewAuthResponse(commonstore.StoreCode2APICode(err)) } + if err := tx.Commit(); err != nil { + log.Error("[Auth][User] delete user commit storage tx", utils.RequestID(ctx), zap.Error(err)) + return api.NewAuthResponse(apimodel.Code_ExecuteException) + } - log.Info("[Auth][User] delete user", utils.ZapRequestID(requestID), - zap.String("name", req.Name.GetValue())) + log.Info("[Auth][User] delete user", utils.RequestID(ctx), zap.String("name", req.Name.GetValue())) svr.RecordHistory(userRecordEntry(ctx, req, user, model.ODelete)) return api.NewUserResponse(apimodel.Code_ExecuteSuccess, req) } // GetUsers 查询用户列表 -func (svr *Server) GetUsers(ctx context.Context, query map[string]string) *apiservice.BatchQueryResponse { - requestID := utils.ParseRequestID(ctx) - log.Debug("[Auth][User] origin get users query params", - utils.ZapRequestID(requestID), zap.Any("query", query)) - - var ( - offset, limit uint32 - err error - searchFilters = make(map[string]string, len(query)+1) - ) - - for key, value := range query { - if _, ok := UserFilterAttributes[key]; !ok { - log.Errorf("[Auth][User] attribute(%s) it not allowed", key) - return api.NewAuthBatchQueryResponseWithMsg(apimodel.Code_InvalidParameter, key+" is not allowed") - } - - searchFilters[key] = value - } - - var ( - total uint32 - users []*model.User - ) - - offset, limit, err = utils.ParseOffsetAndLimit(searchFilters) +func (svr *Server) GetUsers(ctx context.Context, filters map[string]string) *apiservice.BatchQueryResponse { + offset, limit, _ := utils.ParseOffsetAndLimit(filters) + + total, users, err := svr.cacheMgr.User().QueryUsers(ctx, cachetypes.UserSearchArgs{ + Filters: filters, + Offset: offset, + Limit: limit, + }) if err != nil { - return api.NewAuthBatchQueryResponse(apimodel.Code_InvalidParameter) - } - - total, users, err = svr.storage.GetUsers(searchFilters, offset, limit) - if err != nil { - log.Error("[Auth][User] get user from store", zap.Any("req", searchFilters), + log.Error("[Auth][User] get user from store", utils.RequestID(ctx), zap.Any("req", filters), zap.Error(err)) return api.NewAuthBatchQueryResponse(commonstore.StoreCode2APICode(err)) } @@ -314,7 +318,7 @@ func (svr *Server) GetUsers(ctx context.Context, query map[string]string) *apise // GetUserToken 获取用户 token func (svr *Server) GetUserToken(ctx context.Context, req *apisecurity.User) *apiservice.Response { - var user *model.User + var user *authcommon.User if req.GetId().GetValue() != "" { user = svr.cacheMgr.User().GetUserByID(req.GetId().GetValue()) } else if req.GetName().GetValue() != "" { @@ -348,8 +352,8 @@ func (svr *Server) GetUserToken(ctx context.Context, req *apisecurity.User) *api return api.NewUserResponse(apimodel.Code_ExecuteSuccess, out) } -// UpdateUserToken 更新用户 token -func (svr *Server) UpdateUserToken(ctx context.Context, req *apisecurity.User) *apiservice.Response { +// EnableUserToken 更新用户 token +func (svr *Server) EnableUserToken(ctx context.Context, req *apisecurity.User) *apiservice.Response { if checkErrResp := checkUpdateUser(req); checkErrResp != nil { return checkErrResp } @@ -418,18 +422,18 @@ func (svr *Server) ResetUserToken(ctx context.Context, req *apisecurity.User) *a // step 2. 最后对 token 进行一些验证步骤的执行 // step 3. 兜底措施:如果开启了鉴权的非严格模式,则根据错误的类型,判断是否转为匿名用户进行访问 // - 如果是访问权限控制相关模块(用户、用户组、权限策略),不得转为匿名用户 -func (svr *Server) CheckCredential(authCtx *model.AcquireContext) error { +func (svr *Server) CheckCredential(authCtx *authcommon.AcquireContext) error { checkErr := func() error { authToken := utils.ParseAuthToken(authCtx.GetRequestContext()) operator, err := svr.decodeToken(authToken) if err != nil { log.Error("[Auth][Checker] decode token", utils.RequestID(authCtx.GetRequestContext()), zap.Error(err)) - return model.ErrorTokenInvalid + return authcommon.ErrorTokenInvalid } ownerId, isOwner, err := svr.checkToken(&operator) if err != nil { - log.Errorf("[Auth][Checker] check token err : %s", err.Error()) + log.Error("[Auth][Checker] check token", utils.RequestID(authCtx.GetRequestContext()), zap.Error(err)) return err } @@ -441,8 +445,9 @@ func (svr *Server) CheckCredential(authCtx *model.AcquireContext) error { authCtx.SetRequestContext(ctx) svr.parseOperatorInfo(operator, authCtx) if operator.Disable { - log.Warn("[Auth][Checker] token already disabled", utils.RequestID(authCtx.GetRequestContext()), - zap.Any("token", operator.String())) + log.Error("[Auth][Checker] token has been set disable", utils.RequestID(authCtx.GetRequestContext()), + zap.String("operator", operator.String())) + return authcommon.ErrorTokenDisabled } return nil }() @@ -454,13 +459,12 @@ func (svr *Server) CheckCredential(authCtx *model.AcquireContext) error { log.Warn("[Auth][Checker] parse operator info, downgrade to anonymous", utils.RequestID(authCtx.GetRequestContext()), zap.Error(checkErr)) // 操作者信息解析失败,降级为匿名用户 - authCtx.SetAttachment(model.TokenDetailInfoKey, auth.NewAnonymous()) + authCtx.SetAttachment(authcommon.TokenDetailInfoKey, auth.NewAnonymous()) } - return nil } -func (svr *Server) parseOperatorInfo(operator auth.OperatorInfo, authCtx *model.AcquireContext) { +func (svr *Server) parseOperatorInfo(operator auth.OperatorInfo, authCtx *authcommon.AcquireContext) { ctx := authCtx.GetRequestContext() if operator.IsUserToken { user := svr.cacheMgr.User().GetUserByID(operator.OperatorID) @@ -478,38 +482,41 @@ func (svr *Server) parseOperatorInfo(operator auth.OperatorInfo, authCtx *model. } } - authCtx.SetAttachment(model.OperatorRoleKey, operator.Role) - authCtx.SetAttachment(model.OperatorPrincipalType, func() model.PrincipalType { - if operator.IsUserToken { - return model.PrincipalUser - } - return model.PrincipalGroup - }()) - authCtx.SetAttachment(model.OperatorIDKey, operator.OperatorID) - authCtx.SetAttachment(model.OperatorOwnerKey, operator) - authCtx.SetAttachment(model.TokenDetailInfoKey, operator) + authCtx.SetAttachment(authcommon.PrincipalKey, authcommon.Principal{ + PrincipalID: operator.OperatorID, + PrincipalType: func() authcommon.PrincipalType { + if operator.IsUserToken { + return authcommon.PrincipalUser + } + return authcommon.PrincipalGroup + }(), + }) + authCtx.SetAttachment(authcommon.OperatorRoleKey, operator.Role) + authCtx.SetAttachment(authcommon.OperatorIDKey, operator.OperatorID) + authCtx.SetAttachment(authcommon.OperatorOwnerKey, operator) + authCtx.SetAttachment(authcommon.TokenDetailInfoKey, operator) authCtx.SetRequestContext(ctx) } -func canDowngradeAnonymous(authCtx *model.AcquireContext, err error) bool { - if authCtx.GetModule() == model.AuthModule { +func canDowngradeAnonymous(authCtx *authcommon.AcquireContext, err error) bool { + if authCtx.GetModule() == authcommon.AuthModule || authCtx.GetModule() == authcommon.MaintainModule { return false } if !authCtx.IsAllowAnonymous() { return false } - if errors.Is(err, model.ErrorTokenInvalid) { + if errors.Is(err, authcommon.ErrorTokenInvalid) { return true } - if errors.Is(err, model.ErrorTokenNotExist) { + if errors.Is(err, authcommon.ErrorTokenNotExist) { return true } return false } // user 数组转为[]*apisecurity.User -func enhancedUsers2Api(users []*model.User, handler User2Api) []*apisecurity.User { +func enhancedUsers2Api(users []*authcommon.User, handler User2Api) []*apisecurity.User { out := make([]*apisecurity.User, 0, len(users)) for _, entry := range users { outUser := handler(entry) @@ -520,7 +527,7 @@ func enhancedUsers2Api(users []*model.User, handler User2Api) []*apisecurity.Use } // model.Service 转为 api.Service -func user2Api(user *model.User) *apisecurity.User { +func user2Api(user *authcommon.User) *apisecurity.User { if user == nil { return nil } @@ -535,14 +542,14 @@ func user2Api(user *model.User) *apisecurity.User { Comment: utils.NewStringValue(user.Comment), Ctime: utils.NewStringValue(commontime.Time2String(user.CreateTime)), Mtime: utils.NewStringValue(commontime.Time2String(user.ModifyTime)), - UserType: utils.NewStringValue(model.UserRoleNames[user.Type]), + UserType: utils.NewStringValue(authcommon.UserRoleNames[user.Type]), } return out } // 生成用户的记录entry -func userRecordEntry(ctx context.Context, req *apisecurity.User, md *model.User, +func userRecordEntry(ctx context.Context, req *apisecurity.User, md *authcommon.User, operationType model.OperationType) *model.RecordEntry { marshaler := jsonpb.Marshaler{} @@ -600,7 +607,7 @@ func checkUpdateUser(req *apisecurity.User) *apiservice.Response { } // updateUserAttribute 更新用户属性 -func updateUserAttribute(old *model.User, newUser *apisecurity.User) (*model.User, bool, error) { +func updateUserAttribute(old *authcommon.User, newUser *apisecurity.User) (*authcommon.User, bool, error) { var needUpdate = true if newUser.Comment != nil && old.Comment != newUser.Comment.GetValue() { @@ -612,7 +619,7 @@ func updateUserAttribute(old *model.User, newUser *apisecurity.User) (*model.Use // updateUserAttribute 更新用户密码信息,如果用户的密码被更新 func updateUserPasswordAttribute( - isAdmin bool, user *model.User, req *apisecurity.ModifyUserPassword) (*model.User, bool, error) { + isAdmin bool, user *authcommon.User, req *apisecurity.ModifyUserPassword) (*authcommon.User, bool, error) { needUpdate := false if err := CheckPassword(req.NewPassword); err != nil { @@ -642,7 +649,7 @@ func updateUserPasswordAttribute( } // createUserModel 创建用户模型 -func (svr *Server) createUserModel(req *apisecurity.User, role model.UserRoleType) (*model.User, error) { +func (svr *Server) createUserModel(req *apisecurity.User, role authcommon.UserRoleType) (*authcommon.User, error) { pwd, err := bcrypt.GenerateFromPassword([]byte(req.GetPassword().GetValue()), bcrypt.DefaultCost) if err != nil { return nil, err @@ -653,7 +660,7 @@ func (svr *Server) createUserModel(req *apisecurity.User, role model.UserRoleTyp id = req.GetId().GetValue() } - user := &model.User{ + user := &authcommon.User{ ID: id, Name: req.GetName().GetValue(), Password: string(pwd), @@ -668,7 +675,7 @@ func (svr *Server) createUserModel(req *apisecurity.User, role model.UserRoleTyp } // 如果不是子账户的话,owner 就是自己 - if user.Type != model.SubAccountUserRole { + if user.Type != authcommon.SubAccountUserRole { user.Owner = "" } @@ -683,14 +690,14 @@ func (svr *Server) createUserModel(req *apisecurity.User, role model.UserRoleTyp } // convertCreateUserRole 转换为创建的目标用户的用户角色类型 -func convertCreateUserRole(role model.UserRoleType) model.UserRoleType { - if role == model.AdminUserRole { - return model.OwnerUserRole +func convertCreateUserRole(role authcommon.UserRoleType) authcommon.UserRoleType { + if role == authcommon.AdminUserRole { + return authcommon.OwnerUserRole } - if role == model.OwnerUserRole { - return model.SubAccountUserRole + if role == authcommon.OwnerUserRole { + return authcommon.SubAccountUserRole } - return model.SubAccountUserRole + return authcommon.SubAccountUserRole } diff --git a/auth/user/user_test.go b/auth/user/user_test.go index 907f78f68..3df71acd2 100644 --- a/auth/user/user_test.go +++ b/auth/user/user_test.go @@ -34,21 +34,21 @@ import ( cachetypes "github.com/polarismesh/polaris/cache/api" api "github.com/polarismesh/polaris/common/api/v1" commonlog "github.com/polarismesh/polaris/common/log" - "github.com/polarismesh/polaris/common/model" + authcommon "github.com/polarismesh/polaris/common/model/auth" "github.com/polarismesh/polaris/common/utils" storemock "github.com/polarismesh/polaris/store/mock" ) type UserTest struct { - admin *model.User - ownerOne *model.User - ownerTwo *model.User + admin *authcommon.User + ownerOne *authcommon.User + ownerTwo *authcommon.User - users []*model.User - newUsers []*model.User - groups []*model.UserGroupDetail - newGroups []*model.UserGroupDetail - allGroups []*model.UserGroupDetail + users []*authcommon.User + newUsers []*authcommon.User + groups []*authcommon.UserGroupDetail + newGroups []*authcommon.UserGroupDetail + allGroups []*authcommon.UserGroupDetail storage *storemock.MockStore cacheMgn *cache.CacheManager @@ -68,16 +68,16 @@ func newUserTest(t *testing.T) *UserTest { users := createMockUser(10, "one") newUsers := createMockUser(10, "two") admin := createMockUser(1, "admin")[0] - admin.Type = model.AdminUserRole + admin.Type = authcommon.AdminUserRole admin.Owner = "" groups := createMockUserGroup(users) storage := storemock.NewMockStore(ctrl) storage.EXPECT().GetUnixSecond(gomock.Any()).AnyTimes().Return(time.Now().Unix(), nil) storage.EXPECT().GetServicesCount().AnyTimes().Return(uint32(1), nil) - storage.EXPECT().AddUser(gomock.Any()).AnyTimes().Return(nil) + storage.EXPECT().AddUser(gomock.Any(), gomock.Any()).AnyTimes().Return(nil) storage.EXPECT().GetUserByName(gomock.Eq("create-user-1"), gomock.Any()).AnyTimes().Return(nil, nil) - storage.EXPECT().GetUserByName(gomock.Eq("create-user-2"), gomock.Any()).AnyTimes().Return(&model.User{ + storage.EXPECT().GetUserByName(gomock.Eq("create-user-2"), gomock.Any()).AnyTimes().Return(&authcommon.User{ Name: "create-user-2", }, nil) @@ -86,7 +86,7 @@ func newUserTest(t *testing.T) *UserTest { storage.EXPECT().GetUsersForCache(gomock.Any(), gomock.Any()).AnyTimes().Return(allUsers, nil) storage.EXPECT().GetGroupsForCache(gomock.Any(), gomock.Any()).AnyTimes().Return(groups, nil) storage.EXPECT().UpdateUser(gomock.Any()).AnyTimes().Return(nil) - storage.EXPECT().DeleteUser(gomock.Any()).AnyTimes().Return(nil) + storage.EXPECT().DeleteUser(gomock.Any(), gomock.Any()).AnyTimes().Return(nil) cfg := &cache.Config{} ctx, cancel := context.WithCancel(context.Background()) @@ -117,7 +117,7 @@ func newUserTest(t *testing.T) *UserTest { "salt": "polarismesh@2021", }, }, - }, storage, cacheMgn) + }, storage, nil, cacheMgn) _ = cacheMgn.TestUpdate() @@ -341,7 +341,7 @@ func Test_server_Login(t *testing.T) { assert.False(t, api.IsSuccess(rsp), rsp.GetInfo().GetValue()) assert.Equal(t, uint32(apimodel.Code_NotAllowedAccess), rsp.GetCode().GetValue()) - assert.Contains(t, rsp.GetInfo().GetValue(), model.ErrorWrongUsernameOrPassword.Error()) + assert.Contains(t, rsp.GetInfo().GetValue(), authcommon.ErrorWrongUsernameOrPassword.Error()) }) } @@ -388,7 +388,7 @@ func Test_server_UpdateUser(t *testing.T) { Comment: &wrappers.StringValue{Value: "update owner account info"}, } - userTest.storage.EXPECT().GetUser(gomock.Any()).Return(&model.User{ + userTest.storage.EXPECT().GetUser(gomock.Any()).Return(&authcommon.User{ ID: uid, Owner: utils.NewUUID(), }, nil) @@ -528,7 +528,7 @@ func Test_server_UpdateUserPassword(t *testing.T) { NewPassword: &wrappers.StringValue{Value: "polaris@subaccount"}, } - userTest.storage.EXPECT().GetUser(gomock.Any()).Return(&model.User{ + userTest.storage.EXPECT().GetUser(gomock.Any()).Return(&authcommon.User{ ID: uid, Owner: utils.NewUUID(), }, nil) @@ -624,9 +624,9 @@ func Test_server_DeleteUser(t *testing.T) { }) uid := utils.NewUUID() - userTest.storage.EXPECT().GetUser(gomock.Any()).Return(&model.User{ + userTest.storage.EXPECT().GetUser(gomock.Any()).Return(&authcommon.User{ ID: uid, - Type: model.OwnerUserRole, + Type: authcommon.OwnerUserRole, Owner: "", }, nil) @@ -666,9 +666,9 @@ func Test_server_DeleteUser(t *testing.T) { uid := utils.NewUUID() oid := utils.NewUUID() - userTest.storage.EXPECT().GetUser(gomock.Any()).Return(&model.User{ + userTest.storage.EXPECT().GetUser(gomock.Any()).Return(&authcommon.User{ ID: uid, - Type: model.OwnerUserRole, + Type: authcommon.OwnerUserRole, Owner: oid, }, nil).AnyTimes() @@ -857,7 +857,7 @@ func Test_server_UpdateUserToken(t *testing.T) { defer userTest.Clean() _ = userTest.cacheMgn.TestUpdate() reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, userTest.ownerOne.Token) - resp := userTest.svr.UpdateUserToken(reqCtx, &apisecurity.User{ + resp := userTest.svr.EnableUserToken(reqCtx, &apisecurity.User{ Id: utils.NewStringValue(userTest.ownerOne.ID), }) @@ -870,10 +870,10 @@ func Test_server_UpdateUserToken(t *testing.T) { _ = userTest.cacheMgn.TestUpdate() reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, userTest.users[4].Token) - userTest.storage.EXPECT().GetUser(gomock.Any()).Return(&model.User{}, nil).AnyTimes() + userTest.storage.EXPECT().GetUser(gomock.Any()).Return(&authcommon.User{}, nil).AnyTimes() userTest.storage.EXPECT().UpdateUser(gomock.Any()).Return(nil).AnyTimes() - resp := userTest.svr.UpdateUserToken(reqCtx, &apisecurity.User{ + resp := userTest.svr.EnableUserToken(reqCtx, &apisecurity.User{ Id: utils.NewStringValue(userTest.users[4].ID), }) @@ -885,7 +885,7 @@ func Test_server_UpdateUserToken(t *testing.T) { defer userTest.Clean() reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, userTest.ownerOne.Token) userTest.storage.EXPECT().GetUser(gomock.Eq(userTest.users[3].ID)).Return(userTest.users[3], nil) - resp := userTest.svr.UpdateUserToken(reqCtx, &apisecurity.User{ + resp := userTest.svr.EnableUserToken(reqCtx, &apisecurity.User{ Id: utils.NewStringValue(userTest.users[3].ID), }) @@ -900,7 +900,7 @@ func Test_server_UpdateUserToken(t *testing.T) { t.Logf("operator-id : %s, user-two-owner : %s", userTest.ownerOne.ID, userTest.ownerTwo.ID) userTest.storage.EXPECT().GetUser(gomock.Eq(userTest.ownerTwo.ID)).Return(userTest.ownerTwo, nil).AnyTimes() - resp := userTest.svr.UpdateUserToken(reqCtx, &apisecurity.User{ + resp := userTest.svr.EnableUserToken(reqCtx, &apisecurity.User{ Id: utils.NewStringValue(userTest.ownerTwo.ID), }) @@ -914,7 +914,7 @@ func Test_server_UpdateUserToken(t *testing.T) { _ = userTest.cacheMgn.TestUpdate() reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, userTest.ownerOne.Token) userTest.storage.EXPECT().GetUser(gomock.Eq(userTest.newUsers[3].ID)).Return(userTest.newUsers[3], nil).AnyTimes() - resp := userTest.svr.UpdateUserToken(reqCtx, &apisecurity.User{ + resp := userTest.svr.EnableUserToken(reqCtx, &apisecurity.User{ Id: utils.NewStringValue(userTest.newUsers[3].ID), }) @@ -945,7 +945,7 @@ func Test_AuthServer_NormalOperateUser(t *testing.T) { }) t.Run("非正常创建用户-直接操作存储层", func(t *testing.T) { - err := suit.Storage.AddUser(&model.User{}) + err := suit.Storage.AddUser(nil, &authcommon.User{}) assert.Error(t, err) }) diff --git a/bootstrap/server.go b/bootstrap/server.go index b3c17f314..117679ddb 100644 --- a/bootstrap/server.go +++ b/bootstrap/server.go @@ -42,6 +42,7 @@ import ( "github.com/polarismesh/polaris/common/log" "github.com/polarismesh/polaris/common/metrics" "github.com/polarismesh/polaris/common/model" + authcommon "github.com/polarismesh/polaris/common/model/auth" "github.com/polarismesh/polaris/common/utils" "github.com/polarismesh/polaris/common/version" config_center "github.com/polarismesh/polaris/config" @@ -347,19 +348,7 @@ func RestartServers(errCh chan error) error { return err } log.Infof("new config: %+v", cfg) - - // 把配置的每个apiserver,进行重启 - for _, protocol := range cfg.APIServers { - server, exist := apiserver.Slots[protocol.Name] - if !exist { - log.Errorf("api server slot %s not exists\n", protocol.Name) - return err - } - log.Infof("begin restarting server: %s", protocol.Name) - if err := server.Restart(protocol.Option, protocol.API, errCh); err != nil { - return err - } - } + // TODO: 配置的动态加载后续统一设计 return nil } @@ -451,8 +440,11 @@ func genContext() context.Context { ctx := context.Background() reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, "") ctx = context.WithValue(ctx, utils.StringContext("request-id"), fmt.Sprintf("self-%d", time.Now().Nanosecond())) - ctx = context.WithValue(ctx, utils.ContextAuthContextKey, model.NewAcquireContext( - model.WithOperation(model.Read), model.WithModule(model.BootstrapModule), model.WithRequestContext(reqCtx))) + ctx = context.WithValue(ctx, utils.ContextAuthContextKey, + authcommon.NewAcquireContext( + authcommon.WithOperation(authcommon.Read), + authcommon.WithModule(authcommon.BootstrapModule), + authcommon.WithRequestContext(reqCtx))) return ctx } diff --git a/cache/api/funcs.go b/cache/api/funcs.go index 9567e3481..0fa3d26a6 100644 --- a/cache/api/funcs.go +++ b/cache/api/funcs.go @@ -18,10 +18,10 @@ package api import ( + "context" "crypto/sha1" "encoding/hex" "hash" - "sort" "time" ) @@ -53,8 +53,6 @@ func CompositeComputeRevision(revisions []string) (string, error) { } h := sha1.New() - sort.Strings(revisions) - for i := range revisions { if _, err := h.Write([]byte(revisions[i])); err != nil { return "", err @@ -63,3 +61,284 @@ func CompositeComputeRevision(revisions []string) (string, error) { return hex.EncodeToString(h.Sum(nil)), nil } + +// + +type ( + namespacePredicateCtxKey struct{} + servicePredicateCtxKey struct{} + routeRulePredicateCtxKey struct{} + ratelimitRulePredicateCtxKey struct{} + circuitbreakerRulePredicateCtxKey struct{} + faultdetectRulePredicateCtxKey struct{} + laneRulePredicateCtxKey struct{} + configGroupPredicateCtxKey struct{} + userPredicateCtxKey struct{} + userGroupPredicateCtxKey struct{} + authPolicyPredicateCtxKey struct{} + authRolePredicateCtxKey struct{} +) + +func AppendNamespacePredicate(ctx context.Context, p NamespacePredicate) context.Context { + var predicates []NamespacePredicate + + val := ctx.Value(namespacePredicateCtxKey{}) + if val != nil { + predicates, _ = val.([]NamespacePredicate) + } + + predicates = append(predicates, p) + return context.WithValue(ctx, namespacePredicateCtxKey{}, predicates) +} + +func LoadNamespacePredicates(ctx context.Context) []NamespacePredicate { + var predicates []NamespacePredicate + + val := ctx.Value(namespacePredicateCtxKey{}) + if val != nil { + predicates, _ = val.([]NamespacePredicate) + } + return predicates +} + +func AppendServicePredicate(ctx context.Context, p ServicePredicate) context.Context { + var predicates []ServicePredicate + + val := ctx.Value(servicePredicateCtxKey{}) + if val != nil { + predicates, _ = val.([]ServicePredicate) + } + + predicates = append(predicates, p) + return context.WithValue(ctx, servicePredicateCtxKey{}, predicates) +} + +func LoadServicePredicates(ctx context.Context) []ServicePredicate { + var predicates []ServicePredicate + + val := ctx.Value(servicePredicateCtxKey{}) + if val != nil { + predicates, _ = val.([]ServicePredicate) + } + return predicates +} + +func AppendRouterRulePredicate(ctx context.Context, p RouteRulePredicate) context.Context { + var predicates []RouteRulePredicate + + val := ctx.Value(routeRulePredicateCtxKey{}) + if val != nil { + predicates, _ = val.([]RouteRulePredicate) + } + + predicates = append(predicates, p) + return context.WithValue(ctx, routeRulePredicateCtxKey{}, predicates) +} + +func LoadRouterRulePredicates(ctx context.Context) []RouteRulePredicate { + var predicates []RouteRulePredicate + + val := ctx.Value(routeRulePredicateCtxKey{}) + if val != nil { + predicates, _ = val.([]RouteRulePredicate) + } + return predicates +} + +func AppendRatelimitRulePredicate(ctx context.Context, p RateLimitRulePredicate) context.Context { + var predicates []RateLimitRulePredicate + + val := ctx.Value(ratelimitRulePredicateCtxKey{}) + if val != nil { + predicates, _ = val.([]RateLimitRulePredicate) + } + + predicates = append(predicates, p) + return context.WithValue(ctx, ratelimitRulePredicateCtxKey{}, predicates) +} + +func LoadRatelimitRulePredicates(ctx context.Context) []RateLimitRulePredicate { + var predicates []RateLimitRulePredicate + + val := ctx.Value(ratelimitRulePredicateCtxKey{}) + if val != nil { + predicates, _ = val.([]RateLimitRulePredicate) + } + return predicates +} + +func AppendCircuitBreakerRulePredicate(ctx context.Context, p CircuitBreakerPredicate) context.Context { + var predicates []CircuitBreakerPredicate + + val := ctx.Value(circuitbreakerRulePredicateCtxKey{}) + if val != nil { + predicates, _ = val.([]CircuitBreakerPredicate) + } + + predicates = append(predicates, p) + return context.WithValue(ctx, circuitbreakerRulePredicateCtxKey{}, predicates) +} + +func LoadCircuitBreakerRulePredicates(ctx context.Context) []CircuitBreakerPredicate { + var predicates []CircuitBreakerPredicate + + val := ctx.Value(circuitbreakerRulePredicateCtxKey{}) + if val != nil { + predicates, _ = val.([]CircuitBreakerPredicate) + } + return predicates +} + +func AppendFaultDetectRulePredicate(ctx context.Context, p FaultDetectPredicate) context.Context { + var predicates []FaultDetectPredicate + + val := ctx.Value(faultdetectRulePredicateCtxKey{}) + if val != nil { + predicates, _ = val.([]FaultDetectPredicate) + } + + predicates = append(predicates, p) + return context.WithValue(ctx, faultdetectRulePredicateCtxKey{}, predicates) +} + +func LoadFaultDetectRulePredicates(ctx context.Context) []FaultDetectPredicate { + var predicates []FaultDetectPredicate + + val := ctx.Value(faultdetectRulePredicateCtxKey{}) + if val != nil { + predicates, _ = val.([]FaultDetectPredicate) + } + return predicates +} + +func AppendLaneRulePredicate(ctx context.Context, p LanePredicate) context.Context { + var predicates []LanePredicate + + val := ctx.Value(laneRulePredicateCtxKey{}) + if val != nil { + predicates, _ = val.([]LanePredicate) + } + + predicates = append(predicates, p) + return context.WithValue(ctx, laneRulePredicateCtxKey{}, predicates) +} + +func LoadLaneRulePredicates(ctx context.Context) []LanePredicate { + var predicates []LanePredicate + + val := ctx.Value(laneRulePredicateCtxKey{}) + if val != nil { + predicates, _ = val.([]LanePredicate) + } + return predicates +} + +func AppendConfigGroupPredicate(ctx context.Context, p ConfigGroupPredicate) context.Context { + var predicates []ConfigGroupPredicate + + val := ctx.Value(configGroupPredicateCtxKey{}) + if val != nil { + predicates, _ = val.([]ConfigGroupPredicate) + } + + predicates = append(predicates, p) + return context.WithValue(ctx, configGroupPredicateCtxKey{}, predicates) +} + +func LoadConfigGroupPredicates(ctx context.Context) []ConfigGroupPredicate { + var predicates []ConfigGroupPredicate + + val := ctx.Value(configGroupPredicateCtxKey{}) + if val != nil { + predicates, _ = val.([]ConfigGroupPredicate) + } + return predicates +} + +func AppendUserPredicate(ctx context.Context, p UserPredicate) context.Context { + var predicates []UserPredicate + + val := ctx.Value(userPredicateCtxKey{}) + if val != nil { + predicates, _ = val.([]UserPredicate) + } + + predicates = append(predicates, p) + return context.WithValue(ctx, userPredicateCtxKey{}, predicates) +} + +func LoadUserPredicates(ctx context.Context) []UserPredicate { + var predicates []UserPredicate + + val := ctx.Value(userPredicateCtxKey{}) + if val != nil { + predicates, _ = val.([]UserPredicate) + } + return predicates +} + +func AppendUserGroupPredicate(ctx context.Context, p UserGroupPredicate) context.Context { + var predicates []UserGroupPredicate + + val := ctx.Value(userGroupPredicateCtxKey{}) + if val != nil { + predicates, _ = val.([]UserGroupPredicate) + } + + predicates = append(predicates, p) + return context.WithValue(ctx, userGroupPredicateCtxKey{}, predicates) +} + +func LoadUserGroupPredicates(ctx context.Context) []UserGroupPredicate { + var predicates []UserGroupPredicate + + val := ctx.Value(userGroupPredicateCtxKey{}) + if val != nil { + predicates, _ = val.([]UserGroupPredicate) + } + return predicates +} + +func AppendAuthPolicyPredicate(ctx context.Context, p AuthPolicyPredicate) context.Context { + var predicates []AuthPolicyPredicate + + val := ctx.Value(authPolicyPredicateCtxKey{}) + if val != nil { + predicates, _ = val.([]AuthPolicyPredicate) + } + + predicates = append(predicates, p) + return context.WithValue(ctx, authPolicyPredicateCtxKey{}, predicates) +} + +func LoadAuthPolicyPredicates(ctx context.Context) []AuthPolicyPredicate { + var predicates []AuthPolicyPredicate + + val := ctx.Value(userGroupPredicateCtxKey{}) + if val != nil { + predicates, _ = val.([]AuthPolicyPredicate) + } + return predicates +} + +func AppendAuthRolePredicate(ctx context.Context, p AuthRolePredicate) context.Context { + var predicates []AuthRolePredicate + + val := ctx.Value(authRolePredicateCtxKey{}) + if val != nil { + predicates, _ = val.([]AuthRolePredicate) + } + + predicates = append(predicates, p) + return context.WithValue(ctx, authRolePredicateCtxKey{}, predicates) +} + +func LoadAuthRolePredicates(ctx context.Context) []AuthRolePredicate { + var predicates []AuthRolePredicate + + val := ctx.Value(authRolePredicateCtxKey{}) + if val != nil { + predicates, _ = val.([]AuthRolePredicate) + } + return predicates +} diff --git a/cache/api/types.go b/cache/api/types.go index 3020e2fb4..d000f1828 100644 --- a/cache/api/types.go +++ b/cache/api/types.go @@ -30,6 +30,7 @@ import ( "github.com/polarismesh/polaris/common/metrics" "github.com/polarismesh/polaris/common/model" + authcommon "github.com/polarismesh/polaris/common/model/auth" "github.com/polarismesh/polaris/store" ) @@ -66,6 +67,8 @@ const ( UsersName = "users" // StrategyRuleName strategy rule config name StrategyRuleName = "strategyRule" + // RolesName role data config name + RolesName = "roles" // ServiceContractName service contract config name ServiceContractName = "serviceContract" // GrayName gray config name @@ -93,6 +96,7 @@ const ( CacheServiceContract CacheGray CacheLaneRule + CacheRole CacheLast ) @@ -161,9 +165,14 @@ type CacheManager interface { ConfigGroup() ConfigGroupCache // Gray get Gray cache information Gray() GrayCache + // Role Get role cache information + Role() RoleCache } type ( + // NamespacePredicate . + NamespacePredicate func(context.Context, *model.Namespace) bool + // NamespaceCache 命名空间的 Cache 接口 NamespaceCache interface { Cache @@ -179,6 +188,9 @@ type ( ) type ( + // ServicePredicate . + ServicePredicate func(context.Context, *model.Service) bool + // ServiceIterProc 迭代回调函数 ServiceIterProc func(key string, value *model.Service) (bool, error) @@ -206,6 +218,8 @@ type ( OnlyExistHealthInstance bool // OnlyExistInstance 只展示存在实例的服务 OnlyExistInstance bool + // Predicates 额外的数据检查 + Predicates []ServicePredicate } // ServiceCache 服务数据缓存接口 @@ -229,7 +243,7 @@ type ( // GetServiceByCl5Name Get the corresponding SID according to CL5name GetServiceByCl5Name(cl5Name string) *model.Service // GetServicesByFilter Serving the service filtering in the cache through Filter - GetServicesByFilter(serviceFilters *ServiceArgs, + GetServicesByFilter(ctx context.Context, serviceFilters *ServiceArgs, instanceFilters *store.InstanceArgs, offset, limit uint32) (uint32, []*model.EnhancedService, error) // ListServices get service list and revision by namespace ListServices(ns string) (string, []*model.Service) @@ -297,23 +311,93 @@ type ( ) type ( + // FaultDetectPredicate . + FaultDetectPredicate func(context.Context, *model.FaultDetectRule) bool + // FaultDetectArgs + FaultDetectArgs struct { + // Filter extend filter params + Filter map[string]string + // ID route rule id + ID string + // Name route rule name + Name string + // Service service name + Service string + // Namespace namesapce + Namespace string + ServiceNamespace string + DstNamespace string + DstService string + DstMethod string + // Enable + Enable *bool + // Offset + Offset uint32 + // Limit + Limit uint32 + // OrderField Sort field + OrderField string + // OrderType Sorting rules + OrderType string + // Predicates 额外的数据检查 + Predicates []FaultDetectPredicate + } + // FaultDetectCache fault detect rule cache service FaultDetectCache interface { Cache + // Query . + Query(context.Context, *FaultDetectArgs) (uint32, []*model.FaultDetectRule, error) // GetFaultDetectConfig 根据ServiceID获取探测配置 GetFaultDetectConfig(svcName string, namespace string) *model.ServiceWithFaultDetectRules + // GetRule 获取规则 ID 获取主动探测规则 + GetRule(id string) *model.FaultDetectRule } ) type ( + // LanePredicate . + LanePredicate func(context.Context, *model.LaneGroupProto) bool + // LaneGroupArgs . + LaneGroupArgs struct { + // Filter extend filter params + Filter map[string]string + // ID route rule id + ID string + // Name route rule name + Name string + // Service service name + Service string + // Namespace namesapce + Namespace string + // Enable + Enable *bool + // Offset + Offset uint32 + // Limit + Limit uint32 + // OrderField Sort field + OrderField string + // OrderType Sorting rules + OrderType string + // Predicates 额外的数据检查 + Predicates []FaultDetectPredicate + } + // LaneCache . LaneCache interface { Cache + // Query . + Query(context.Context, *LaneGroupArgs) (uint32, []*model.LaneGroupProto, error) // GetLaneRules 根据serviceID获取泳道规则 GetLaneRules(serviceKey *model.Service) ([]*model.LaneGroupProto, string) + // GetRule 获取规则 ID 获取全链路灰度规则 + GetRule(id string) *model.LaneGroup } ) type ( + // RouteRulePredicate . + RouteRulePredicate func(context.Context, *model.ExtendRouterConfig) bool // RoutingArgs Routing rules query parameters RoutingArgs struct { // Filter extend filter params @@ -344,6 +428,8 @@ type ( OrderField string // OrderType Sorting rules OrderType string + // Predicates 额外的数据检查 + Predicates []RouteRulePredicate } // RouterRuleIterProc Method definition of routing rules @@ -356,20 +442,26 @@ type ( GetRouterConfig(id, service, namespace string) (*apitraffic.Routing, error) // GetRouterConfig Obtain routing configuration based on serviceid GetRouterConfigV2(id, service, namespace string) (*apitraffic.Routing, error) + // GetNearbyRouteRule 根据服务名查询就近路由数据 + GetNearbyRouteRule(service, namespace string) ([]*apitraffic.RouteRule, string, error) // GetRoutingConfigCount Get the total number of routing configuration cache GetRoutingConfigCount() int // QueryRoutingConfigsV2 Query Route Configuration List - QueryRoutingConfigsV2(args *RoutingArgs) (uint32, []*model.ExtendRouterConfig, error) + QueryRoutingConfigsV2(context.Context, *RoutingArgs) (uint32, []*model.ExtendRouterConfig, error) // ListRouterRule list all router rule ListRouterRule(service, namespace string) []*model.ExtendRouterConfig // IsConvertFromV1 Whether the current routing rules are converted from the V1 rule IsConvertFromV1(id string) (string, bool) // IteratorRouterRule iterator router rule IteratorRouterRule(iterProc RouterRuleIterProc) + // GetRule 获取规则 ID 获取路由规则 + GetRule(id string) *model.ExtendRouterConfig } ) type ( + // RateLimitRulePredicate . + RateLimitRulePredicate func(context.Context, *model.RateLimit) bool // RateLimitRuleArgs ratelimit rules query parameters RateLimitRuleArgs struct { // Filter extend filter params @@ -392,6 +484,8 @@ type ( OrderField string // OrderType Sorting rules OrderType string + // Predicates . + Predicates []RateLimitRulePredicate } // RateLimitIterProc rate limit iter func @@ -405,9 +499,11 @@ type ( // GetRateLimitRules 根据serviceID获取限流数据 GetRateLimitRules(serviceKey model.ServiceKey) ([]*model.RateLimit, string) // QueryRateLimitRules - QueryRateLimitRules(args RateLimitRuleArgs) (uint32, []*model.RateLimit, error) + QueryRateLimitRules(context.Context, RateLimitRuleArgs) (uint32, []*model.RateLimit, error) // GetRateLimitsCount 获取限流规则总数 GetRateLimitsCount() int + // GetRule 获取规则 ID 获取限流规则 + GetRule(id string) *model.RateLimit } ) @@ -429,11 +525,50 @@ type ( ) type ( + // CircuitBreakerPredicate . + CircuitBreakerPredicate func(context.Context, *model.CircuitBreakerRule) bool + // CircuitBreakerRuleArgs . + CircuitBreakerRuleArgs struct { + // Filter extend filter params + Filter map[string]string + // ID route rule id + ID string + // Name route rule name + Name string + // Service service name + Service string + // Namespace namesapce + Namespace string + // SourceService source service name + SourceService string + // SourceNamespace source service namespace + SourceNamespace string + // DestinationService destination service name + DestinationService string + // DestinationNamespace destination service namespace + DestinationNamespace string + // Enable + Enable *bool + // Offset + Offset uint32 + // Limit + Limit uint32 + // OrderField Sort field + OrderField string + // OrderType Sorting rules + OrderType string + // Predicates 额外的数据检查 + Predicates []CircuitBreakerPredicate + } // CircuitBreakerCache circuitBreaker配置的cache接口 CircuitBreakerCache interface { Cache + // Query . + Query(context.Context, *CircuitBreakerRuleArgs) (uint32, []*model.CircuitBreakerRule, error) // GetCircuitBreakerConfig 根据ServiceID获取熔断配置 GetCircuitBreakerConfig(svcName string, namespace string) *model.ServiceWithCircuitBreakerRules + // GetRule 获取规则 ID 获取熔断规则 + GetRule(id string) *model.CircuitBreakerRule } ) @@ -490,6 +625,9 @@ type ( OrderType string } + // ConfigGroupPredicate . + ConfigGroupPredicate func(context.Context, *model.ConfigFileGroup) bool + // ConfigGroupCache file cache ConfigGroupCache interface { Cache @@ -520,38 +658,82 @@ type ( ) type ( + UserSearchArgs struct { + Filters map[string]string + Offset uint32 + Limit uint32 + } + UserGroupSearchArgs struct { + Filters map[string]string + Offset uint32 + Limit uint32 + } + + // UserPredicate . + UserPredicate func(context.Context, *authcommon.User) bool + // UserGroupPredicate . + UserGroupPredicate func(context.Context, *authcommon.UserGroupDetail) bool // UserCache User information cache UserCache interface { Cache // GetAdmin 获取管理员信息 - GetAdmin() *model.User + GetAdmin() *authcommon.User // GetUserByID - GetUserByID(id string) *model.User + GetUserByID(id string) *authcommon.User // GetUserByName - GetUserByName(name, ownerName string) *model.User + GetUserByName(name, ownerName string) *authcommon.User // GetUserGroup - GetGroup(id string) *model.UserGroupDetail + GetGroup(id string) *authcommon.UserGroupDetail // IsUserInGroup 判断 userid 是否在对应的 group 中 IsUserInGroup(userId, groupId string) bool // IsOwner IsOwner(id string) bool // GetUserLinkGroupIds GetUserLinkGroupIds(id string) []string + // QueryUsers . + QueryUsers(context.Context, UserSearchArgs) (uint32, []*authcommon.User, error) + // QueryUserGroups . + QueryUserGroups(context.Context, UserGroupSearchArgs) (uint32, []*authcommon.UserGroupDetail, error) + } + + PolicySearchArgs struct { + Filters map[string]string + Offset uint32 + Limit uint32 } + // AuthPolicyPredicate . + AuthPolicyPredicate func(context.Context, *authcommon.StrategyDetail) bool + // StrategyCache is a cache for strategy rules. StrategyCache interface { Cache - // GetStrategyDetailsByUID - GetStrategyDetailsByUID(uid string) []*model.StrategyDetail - // GetStrategyDetailsByGroupID returns all strategy details of a group. - GetStrategyDetailsByGroupID(groupId string) []*model.StrategyDetail - // IsResourceLinkStrategy 该资源是否关联了鉴权策略 - IsResourceLinkStrategy(resType apisecurity.ResourceType, resId string) bool - // IsResourceEditable 判断该资源是否可以操作 - IsResourceEditable(principal model.Principal, resType apisecurity.ResourceType, resId string) bool - // ForceSync 强制同步鉴权策略到cache (串行) - ForceSync() error + // GetPrincipalPolicies 根据 effect 获取 principal 的策略信息 + GetPrincipalPolicies(effect string, p authcommon.Principal) []*authcommon.StrategyDetail + // Hint 确认某个 principal 对于资源的访问权限 + Hint(p authcommon.Principal, r *authcommon.ResourceEntry) apisecurity.AuthAction + // Query . + Query(context.Context, PolicySearchArgs) (uint32, []*authcommon.StrategyDetail, error) + } + + RoleSearchArgs struct { + Filters map[string]string + Offset uint32 + Limit uint32 + } + + // AuthPolicyPredicate . + AuthRolePredicate func(context.Context, *authcommon.Role) bool + + // RoleCache . + RoleCache interface { + Cache + // GetRole . + GetRole(id string) *authcommon.Role + // Query . + Query(context.Context, RoleSearchArgs) (uint32, []*authcommon.Role, error) + // GetPrincipalRoles . + GetPrincipalRoles(authcommon.Principal) []*authcommon.Role } ) diff --git a/cache/auth/default.go b/cache/auth/default.go index e2a5ba9de..b53686af6 100644 --- a/cache/auth/default.go +++ b/cache/auth/default.go @@ -23,5 +23,5 @@ import ( var ( _ types.UserCache = (*userCache)(nil) - _ types.StrategyCache = (*strategyCache)(nil) + _ types.StrategyCache = (*policyCache)(nil) ) diff --git a/cache/auth/policy.go b/cache/auth/policy.go new file mode 100644 index 000000000..771d1bb86 --- /dev/null +++ b/cache/auth/policy.go @@ -0,0 +1,411 @@ +/** + * Tencent is pleased to support the open source community by making Polaris available. + * + * Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. + * + * Licensed under the BSD 3-Clause License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://opensource.org/licenses/BSD-3-Clause + * + * Unless required by applicable law or agreed to in writing, software distributed + * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR + * CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package auth + +import ( + "context" + "fmt" + "math" + "sort" + "strconv" + "time" + + apisecurity "github.com/polarismesh/specification/source/go/api/v1/security" + "go.uber.org/zap" + "golang.org/x/sync/singleflight" + + types "github.com/polarismesh/polaris/cache/api" + authcommon "github.com/polarismesh/polaris/common/model/auth" + "github.com/polarismesh/polaris/common/utils" + "github.com/polarismesh/polaris/store" +) + +const ( + removePrincipalChSize = 8 +) + +// policyCache +type policyCache struct { + *types.BaseCache + + rules *utils.SyncMap[string, *authcommon.PolicyDetailCache] + allowPolicies map[authcommon.PrincipalType]*utils.SyncMap[string, *utils.SyncSet[string]] + denyPolicies map[authcommon.PrincipalType]*utils.SyncMap[string, *utils.SyncSet[string]] + + // principalResources + principalResources map[authcommon.PrincipalType]*utils.SyncMap[string, *authcommon.PrincipalResourceContainer] + + allowResourceLabels *utils.SyncMap[string, *utils.RefSyncSet[string]] + denyResourceLabels *utils.SyncMap[string, *utils.RefSyncSet[string]] + + singleFlight *singleflight.Group +} + +// NewStrategyCache +func NewStrategyCache(storage store.Store, cacheMgr types.CacheManager) types.StrategyCache { + return &policyCache{ + BaseCache: types.NewBaseCache(storage, cacheMgr), + singleFlight: new(singleflight.Group), + } +} + +func (sc *policyCache) Initialize(c map[string]interface{}) error { + sc.initContainers() + return nil +} + +func (sc *policyCache) Clear() error { + sc.BaseCache.Clear() + sc.initContainers() + return nil +} + +func (sc *policyCache) initContainers() { + sc.rules = utils.NewSyncMap[string, *authcommon.PolicyDetailCache]() + sc.allowPolicies = map[authcommon.PrincipalType]*utils.SyncMap[string, *utils.SyncSet[string]]{ + authcommon.PrincipalUser: utils.NewSyncMap[string, *utils.SyncSet[string]](), + authcommon.PrincipalGroup: utils.NewSyncMap[string, *utils.SyncSet[string]](), + } + sc.denyPolicies = map[authcommon.PrincipalType]*utils.SyncMap[string, *utils.SyncSet[string]]{ + authcommon.PrincipalUser: utils.NewSyncMap[string, *utils.SyncSet[string]](), + authcommon.PrincipalGroup: utils.NewSyncMap[string, *utils.SyncSet[string]](), + } + sc.principalResources = map[authcommon.PrincipalType]*utils.SyncMap[string, *authcommon.PrincipalResourceContainer]{ + authcommon.PrincipalUser: utils.NewSyncMap[string, *authcommon.PrincipalResourceContainer](), + authcommon.PrincipalGroup: utils.NewSyncMap[string, *authcommon.PrincipalResourceContainer](), + } + sc.allowResourceLabels = utils.NewSyncMap[string, *utils.RefSyncSet[string]]() + sc.denyResourceLabels = utils.NewSyncMap[string, *utils.RefSyncSet[string]]() +} + +func (sc *policyCache) Name() string { + return types.StrategyRuleName +} + +func (sc *policyCache) Update() error { + // 多个线程竞争,只有一个线程进行更新 + _, err, _ := sc.singleFlight.Do(sc.Name(), func() (interface{}, error) { + return nil, sc.DoCacheUpdate(sc.Name(), sc.realUpdate) + }) + return err +} + +func (sc *policyCache) realUpdate() (map[string]time.Time, int64, error) { + // 获取几秒前的全部数据 + var ( + start = time.Now() + lastTime = sc.LastFetchTime() + strategies, err = sc.BaseCache.Store().GetMoreStrategies(lastTime, sc.IsFirstUpdate()) + ) + if err != nil { + log.Errorf("[Cache][AuthStrategy] refresh auth strategy cache err: %s", err.Error()) + return nil, -1, err + } + + lastMtimes, add, update, del := sc.setStrategys(strategies) + log.Info("[Cache][AuthStrategy] get more auth strategy", + zap.Int("add", add), zap.Int("update", update), zap.Int("delete", del), + zap.Time("last", lastTime), zap.Duration("used", time.Since(start))) + return lastMtimes, int64(len(strategies)), nil +} + +// setStrategys 处理策略的数据更新情况 +// step 1. 先处理resource以及principal的数据更新情况(主要是为了能够获取到新老数据进行对比计算) +// step 2. 处理真正的 strategy 的缓存更新 +func (sc *policyCache) setStrategys(strategies []*authcommon.StrategyDetail) (map[string]time.Time, int, int, int) { + var add, remove, update int + lastMtime := sc.LastMtime(sc.Name()).Unix() + + for index := range strategies { + rule := strategies[index] + sc.handlePrincipalPolicies(rule) + if !rule.Valid { + sc.rules.Delete(rule.ID) + remove++ + } else { + if _, ok := sc.rules.Load(rule.ID); !ok { + add++ + } else { + update++ + } + sc.rules.Store(rule.ID, authcommon.NewPolicyDetailCache(rule)) + } + + lastMtime = int64(math.Max(float64(lastMtime), float64(rule.ModifyTime.Unix()))) + } + + return map[string]time.Time{sc.Name(): time.Unix(lastMtime, 0)}, add, update, remove +} + +// handlePrincipalPolicies +func (sc *policyCache) handlePrincipalPolicies(rule *authcommon.StrategyDetail) { + // 计算 uid -> auth rule + principals := rule.Principals + + if oldRule, exist := sc.rules.Load(rule.ID); exist { + delMembers := make([]authcommon.Principal, 0, 8) + // 计算前后对比, principal 的变化 + newRes := make(map[string]struct{}, len(principals)) + for i := range principals { + newRes[fmt.Sprintf("%d_%s", principals[i].PrincipalType, principals[i].PrincipalID)] = struct{}{} + } + + // 筛选出从策略中被踢出的 principal 列表 + for i := range oldRule.Principals { + item := oldRule.Principals[i] + if _, ok := newRes[fmt.Sprintf("%d_%s", item.PrincipalType, item.PrincipalID)]; !ok { + delMembers = append(delMembers, item) + } + } + + // 针对被剔除的 principal 列表,清理掉所关联的鉴权策略信息 + for rIndex := range delMembers { + principal := delMembers[rIndex] + sc.writePrincipalLink(principal, rule, true) + } + } + if rule.Valid { + for pos := range principals { + principal := principals[pos] + sc.writePrincipalLink(principal, rule, false) + } + } else { + for pos := range principals { + principal := principals[pos] + sc.writePrincipalLink(principal, rule, true) + } + } +} + +func (sc *policyCache) writePrincipalLink(principal authcommon.Principal, rule *authcommon.StrategyDetail, del bool) { + linkContainers := sc.allowPolicies[principal.PrincipalType] + if rule.Action == apisecurity.AuthAction_DENY.String() { + linkContainers = sc.denyPolicies[principal.PrincipalType] + } + values, ok := linkContainers.Load(principal.PrincipalID) + if !ok && !del { + linkContainers.ComputeIfAbsent(principal.PrincipalID, func(k string) *utils.SyncSet[string] { + return utils.NewSyncSet[string]() + }) + } + if del { + values.Remove(rule.ID) + } else { + values, _ := linkContainers.Load(principal.PrincipalID) + values.Add(rule.ID) + } + + principalResources, _ := sc.principalResources[principal.PrincipalType].ComputeIfAbsent(principal.PrincipalID, func(k string) *authcommon.PrincipalResourceContainer { + return authcommon.NewPrincipalResourceContainer() + }) + + if rule.IsDeny() { + for i := range rule.Resources { + item := rule.Resources[i] + if rule.Valid { + principalResources.SaveDenyResource(item) + } else { + principalResources.DelDenyResource(item) + } + } + return + } + for i := range rule.Resources { + item := rule.Resources[i] + if rule.Valid { + principalResources.SaveAllowResource(item) + } else { + principalResources.DelAllowResource(item) + } + } +} + +func (sc *policyCache) GetPrincipalPolicies(effect string, p authcommon.Principal) []*authcommon.StrategyDetail { + var ruleIds *utils.SyncSet[string] + var exist bool + switch effect { + case "allow": + ruleIds, exist = sc.allowPolicies[p.PrincipalType].Load(p.PrincipalID) + case "deny": + ruleIds, exist = sc.denyPolicies[p.PrincipalType].Load(p.PrincipalID) + default: + allowRuleIds, allowExist := sc.allowPolicies[p.PrincipalType].Load(p.PrincipalID) + denyRuleIds, denyExist := sc.denyPolicies[p.PrincipalType].Load(p.PrincipalID) + if allowRuleIds == nil { + allowRuleIds = utils.NewSyncSet[string]() + } + allowRuleIds.AddAll(denyRuleIds) + + ruleIds = allowRuleIds + exist = allowExist || denyExist + } + + if !exist { + return nil + } + if ruleIds.Len() == 0 { + return nil + } + result := make([]*authcommon.StrategyDetail, 0, 16) + ruleIds.Range(func(val string) { + strategy, ok := sc.rules.Load(val) + if ok { + result = append(result, strategy.StrategyDetail) + } + }) + return result +} + +// GetPrincipalResources 返回 principal 的资源信息,返回顺序为 (allow, deny) +func (sc *policyCache) Hint(p authcommon.Principal, r *authcommon.ResourceEntry) apisecurity.AuthAction { + resources, ok := sc.principalResources[p.PrincipalType].Load(p.PrincipalID) + if !ok { + return apisecurity.AuthAction_DENY + } + action, ok := resources.Hint(r.Type, r.ID) + if ok { + return action + } + // 如果没办法从直接的 resource 中判断出来,那就根据资源标签在确认下,注意,这里必须 allMatch 才可以 + if sc.hintLabels(p, r, sc.denyResourceLabels) { + return apisecurity.AuthAction_DENY + } + if sc.hintLabels(p, r, sc.allowResourceLabels) { + return apisecurity.AuthAction_ALLOW + } + return apisecurity.AuthAction_DENY +} + +func (sc *policyCache) hintLabels(p authcommon.Principal, r *authcommon.ResourceEntry, + containers *utils.SyncMap[string, *utils.RefSyncSet[string]]) bool { + allMatch := true + for k, v := range r.Metadata { + labelVals, ok := sc.denyResourceLabels.Load(k) + if !ok { + allMatch = false + } + allMatch = labelVals.Contains(v) + if !allMatch { + break + } + } + return allMatch +} + +// Query implements api.StrategyCache. +func (sc *policyCache) Query(ctx context.Context, args types.PolicySearchArgs) (uint32, []*authcommon.StrategyDetail, error) { + if err := sc.Update(); err != nil { + return 0, nil, err + } + + searchId, hasId := args.Filters["id"] + searchName, hasName := args.Filters["name"] + searchOwner, hasOwner := args.Filters["owner"] + searchDefault, hasDefault := args.Filters["default"] + searchResType, hasResType := args.Filters["res_type"] + searchResID, _ := args.Filters["res_id"] + searchPrincipalId, hasPrincipalId := args.Filters["principal_id"] + searchPrincipalType, _ := args.Filters["principal_type"] + + predicates := types.LoadAuthPolicyPredicates(ctx) + + rules := make([]*authcommon.StrategyDetail, 0, args.Limit) + + sc.rules.Range(func(key string, val *authcommon.PolicyDetailCache) { + if hasId && val.ID != searchId { + return + } + if hasName && !utils.IsWildMatch(val.Name, searchName) { + return + } + if hasOwner && searchOwner != val.Owner { + if !hasPrincipalId { + return + } + if searchPrincipalType == strconv.Itoa(int(authcommon.PrincipalUser)) { + if _, ok := val.UserPrincipal[searchPrincipalId]; !ok { + return + } + } + if searchPrincipalType == strconv.Itoa(int(authcommon.PrincipalGroup)) { + if _, ok := val.GroupPrincipal[searchPrincipalId]; !ok { + return + } + } + } + if hasDefault && searchDefault != strconv.FormatBool(val.Default) { + return + } + if hasResType { + resources, ok := val.ResourceDict[authcommon.SearchTypeMapping[searchResType]] + if !ok { + return + } + if !resources.Contains(searchResID) { + return + } + } + if hasPrincipalId { + if searchPrincipalType == strconv.Itoa(int(authcommon.PrincipalUser)) { + if _, ok := val.UserPrincipal[searchPrincipalId]; !ok { + return + } + } + if searchPrincipalType == strconv.Itoa(int(authcommon.PrincipalGroup)) { + if _, ok := val.GroupPrincipal[searchPrincipalId]; !ok { + return + } + } + } + + for i := range predicates { + if !predicates[i](ctx, val.StrategyDetail) { + return + } + } + rules = append(rules, val.StrategyDetail) + }) + + total, ret := sc.toPage(rules, args) + return total, ret, nil +} + +func (sc *policyCache) toPage(rules []*authcommon.StrategyDetail, args types.PolicySearchArgs) (uint32, []*authcommon.StrategyDetail) { + beginIndex := args.Offset + endIndex := beginIndex + args.Limit + totalCount := uint32(len(rules)) + + if totalCount == 0 { + return totalCount, []*authcommon.StrategyDetail{} + } + if beginIndex >= endIndex { + return totalCount, []*authcommon.StrategyDetail{} + } + if beginIndex >= totalCount { + return totalCount, []*authcommon.StrategyDetail{} + } + if endIndex > totalCount { + endIndex = totalCount + } + + sort.Slice(rules, func(i, j int) bool { + return rules[i].ModifyTime.After(rules[j].ModifyTime) + }) + + return totalCount, rules[beginIndex:endIndex] +} diff --git a/cache/auth/policy_test.go b/cache/auth/policy_test.go new file mode 100644 index 000000000..f9f538a43 --- /dev/null +++ b/cache/auth/policy_test.go @@ -0,0 +1,18 @@ +/** + * Tencent is pleased to support the open source community by making Polaris available. + * + * Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. + * + * Licensed under the BSD 3-Clause License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://opensource.org/licenses/BSD-3-Clause + * + * Unless required by applicable law or agreed to in writing, software distributed + * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR + * CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package auth diff --git a/cache/auth/role.go b/cache/auth/role.go new file mode 100644 index 000000000..f249c5b66 --- /dev/null +++ b/cache/auth/role.go @@ -0,0 +1,262 @@ +/** + * Tencent is pleased to support the open source community by making Polaris available. + * + * Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. + * + * Licensed under the BSD 3-Clause License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://opensource.org/licenses/BSD-3-Clause + * + * Unless required by applicable law or agreed to in writing, software distributed + * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR + * CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package auth + +import ( + "context" + "time" + + types "github.com/polarismesh/polaris/cache/api" + authcommon "github.com/polarismesh/polaris/common/model/auth" + "github.com/polarismesh/polaris/common/utils" + "github.com/polarismesh/polaris/store" + "go.uber.org/zap" + "golang.org/x/sync/singleflight" +) + +// NewRoleCache +func NewRoleCache(storage store.Store, cacheMgr types.CacheManager) types.RoleCache { + return &roleCache{ + BaseCache: types.NewBaseCache(storage, cacheMgr), + singleFlight: new(singleflight.Group), + } +} + +type roleCache struct { + *types.BaseCache + // roles + roles *utils.SyncMap[string, *authcommon.Role] + // principalRoles + principalRoles map[authcommon.PrincipalType]*utils.SyncMap[string, *utils.SyncSet[string]] + singleFlight *singleflight.Group +} + +// Initialize implements api.RoleCache. +func (r *roleCache) Initialize(c map[string]interface{}) error { + r.roles = utils.NewSyncMap[string, *authcommon.Role]() + r.principalRoles = map[authcommon.PrincipalType]*utils.SyncMap[string, *utils.SyncSet[string]]{ + authcommon.PrincipalUser: utils.NewSyncMap[string, *utils.SyncSet[string]](), + authcommon.PrincipalGroup: utils.NewSyncMap[string, *utils.SyncSet[string]](), + } + return nil +} + +// Name implements api.RoleCache. +func (r *roleCache) Name() string { + return types.RolesName +} + +// Clear implements api.RoleCache. +// Subtle: this method shadows the method (*BaseCache).Clear of roleCache.BaseCache. +func (r *roleCache) Clear() error { + r.roles = utils.NewSyncMap[string, *authcommon.Role]() + r.principalRoles = map[authcommon.PrincipalType]*utils.SyncMap[string, *utils.SyncSet[string]]{ + authcommon.PrincipalUser: utils.NewSyncMap[string, *utils.SyncSet[string]](), + authcommon.PrincipalGroup: utils.NewSyncMap[string, *utils.SyncSet[string]](), + } + return nil +} + +// Update implements api.RoleCache. +func (r *roleCache) Update() error { + // 多个线程竞争,只有一个线程进行更新 + _, err, _ := r.singleFlight.Do(r.Name(), func() (interface{}, error) { + return nil, r.DoCacheUpdate(r.Name(), r.realUpdate) + }) + return err +} + +func (r *roleCache) realUpdate() (map[string]time.Time, int64, error) { + // 获取几秒前的全部数据 + var ( + start = time.Now() + lastTime = r.LastFetchTime() + roles, err = r.BaseCache.Store().GetMoreRoles(r.IsFirstUpdate(), lastTime) + ) + if err != nil { + log.Errorf("[Cache][Roles] refresh auth roles cache err: %s", err.Error()) + return nil, -1, err + } + + lastMtime, add, update, del := r.setRoles(roles) + log.Info("[Cache][Roles] get more auth role", + zap.Int("add", add), zap.Int("update", update), zap.Int("delete", del), + zap.Time("last", lastTime), zap.Duration("used", time.Since(start))) + return map[string]time.Time{ + r.Name(): lastMtime, + }, int64(len(roles)), nil +} + +func (r *roleCache) setRoles(roles []*authcommon.Role) (time.Time, int, int, int) { + var add, remove, update int + lastMtime := r.LastMtime(r.Name()).Unix() + + for i := range roles { + item := roles[i] + oldVal, exist := r.roles.Load(item.ID) + r.dealPrincipalRoles(oldVal, true) + if !item.Valid { + remove++ + r.roles.Delete(item.ID) + } else { + if exist { + update++ + } else { + add++ + } + r.dealPrincipalRoles(item, false) + r.roles.Store(item.ID, item) + } + } + r.cleanEmptyPrincipalRoles() + return time.Unix(lastMtime, 0), add, update, remove +} + +func (r *roleCache) cleanEmptyPrincipalRoles() { + // 清理掉 principal 没有关联任何 role 的容器 + for pt := range r.principalRoles { + r.principalRoles[pt].Range(func(key string, val *utils.SyncSet[string]) { + if val.Len() == 0 { + r.principalRoles[pt].Delete(key) + } + }) + } +} + +// dealPrincipalRoles 处理 principal 和 role 的关联关系 +func (r *roleCache) dealPrincipalRoles(role *authcommon.Role, isDel bool) { + if role == nil { + return + } + if isDel { + users := role.Users + for i := range users { + container, _ := r.principalRoles[authcommon.PrincipalUser].ComputeIfAbsent(users[i].SelfID(), + func(k string) *utils.SyncSet[string] { + return utils.NewSyncSet[string]() + }) + container.Remove(role.ID) + } + groups := role.UserGroups + for i := range groups { + container, _ := r.principalRoles[authcommon.PrincipalGroup].ComputeIfAbsent(groups[i].SelfID(), + func(k string) *utils.SyncSet[string] { + return utils.NewSyncSet[string]() + }) + container.Remove(role.ID) + } + return + } + users := role.Users + for i := range users { + container, _ := r.principalRoles[authcommon.PrincipalUser].ComputeIfAbsent(users[i].SelfID(), + func(k string) *utils.SyncSet[string] { + return utils.NewSyncSet[string]() + }) + container.Add(role.ID) + } + groups := role.UserGroups + for i := range groups { + container, _ := r.principalRoles[authcommon.PrincipalGroup].ComputeIfAbsent(groups[i].SelfID(), + func(k string) *utils.SyncSet[string] { + return utils.NewSyncSet[string]() + }) + container.Add(role.ID) + } +} + +// Query implements api.RoleCache. +func (r *roleCache) Query(ctx context.Context, args types.RoleSearchArgs) (uint32, []*authcommon.Role, error) { + if err := r.Update(); err != nil { + return 0, nil, err + } + var ( + total uint32 + roles []*authcommon.Role + ) + + searchId, hasId := args.Filters["id"] + searchName, hasName := args.Filters["name"] + searchSource, hasSource := args.Filters["source"] + + predicates := types.LoadAuthRolePredicates(ctx) + + r.roles.Range(func(key string, val *authcommon.Role) { + if hasId && key != searchId { + return + } + if hasName { + if !utils.IsWildMatch(val.Name, searchName) { + return + } + } + if hasSource { + if !utils.IsWildMatch(val.Source, searchSource) { + return + } + } + for i := range predicates { + if !predicates[i](ctx, val) { + return + } + } + roles = append(roles, val) + }) + + total, roles = r.toPage(total, roles, args) + return total, roles, nil +} + +func (r *roleCache) toPage(total uint32, roles []*authcommon.Role, args types.RoleSearchArgs) (uint32, []*authcommon.Role) { + if args.Limit == 0 { + return total, roles + } + start := args.Limit * (args.Offset - 1) + end := args.Limit * args.Offset + if start > total { + return total, nil + } + if end > total { + end = total + } + return total, roles[start:end] +} + +// GetPrincipalRoles implements api.RoleCache. +func (r *roleCache) GetPrincipalRoles(p authcommon.Principal) []*authcommon.Role { + containers, ok := r.principalRoles[p.PrincipalType].Load(p.PrincipalID) + if !ok { + return nil + } + + result := make([]*authcommon.Role, 0, containers.Len()) + containers.Range(func(val string) { + role, ok := r.roles.Load(val) + if !ok { + return + } + result = append(result, role) + }) + return result +} + +// GetRole implements api.RoleCache. +func (r *roleCache) GetRole(id string) *authcommon.Role { + ret, _ := r.roles.Load(id) + return ret +} diff --git a/cache/auth/strategy.go b/cache/auth/strategy.go deleted file mode 100644 index e71d624fe..000000000 --- a/cache/auth/strategy.go +++ /dev/null @@ -1,440 +0,0 @@ -/** - * Tencent is pleased to support the open source community by making Polaris available. - * - * Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. - * - * Licensed under the BSD 3-Clause License (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://opensource.org/licenses/BSD-3-Clause - * - * Unless required by applicable law or agreed to in writing, software distributed - * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR - * CONDITIONS OF ANY KIND, either express or implied. See the License for the - * specific language governing permissions and limitations under the License. - */ - -package auth - -import ( - "fmt" - "math" - "time" - - apisecurity "github.com/polarismesh/specification/source/go/api/v1/security" - "go.uber.org/zap" - "golang.org/x/sync/singleflight" - - types "github.com/polarismesh/polaris/cache/api" - "github.com/polarismesh/polaris/common/model" - "github.com/polarismesh/polaris/common/utils" - "github.com/polarismesh/polaris/store" -) - -const ( - removePrincipalChSize = 8 -) - -// strategyCache -type strategyCache struct { - *types.BaseCache - - storage store.Store - strategys *utils.SyncMap[string, *model.StrategyDetailCache] - uid2Strategy *utils.SyncMap[string, *utils.SyncSet[string]] - groupid2Strategy *utils.SyncMap[string, *utils.SyncSet[string]] - - namespace2Strategy *utils.SyncMap[string, *utils.SyncSet[string]] - service2Strategy *utils.SyncMap[string, *utils.SyncSet[string]] - configGroup2Strategy *utils.SyncMap[string, *utils.SyncSet[string]] - - lastMtime int64 - userCache *userCache - singleFlight *singleflight.Group -} - -// NewStrategyCache -func NewStrategyCache(storage store.Store, cacheMgr types.CacheManager) types.StrategyCache { - return &strategyCache{ - BaseCache: types.NewBaseCache(storage, cacheMgr), - storage: storage, - } -} - -func (sc *strategyCache) Initialize(c map[string]interface{}) error { - sc.userCache = sc.BaseCache.CacheMgr.GetCacher(types.CacheUser).(*userCache) - sc.strategys = utils.NewSyncMap[string, *model.StrategyDetailCache]() - sc.uid2Strategy = utils.NewSyncMap[string, *utils.SyncSet[string]]() - sc.groupid2Strategy = utils.NewSyncMap[string, *utils.SyncSet[string]]() - sc.namespace2Strategy = utils.NewSyncMap[string, *utils.SyncSet[string]]() - sc.service2Strategy = utils.NewSyncMap[string, *utils.SyncSet[string]]() - sc.configGroup2Strategy = utils.NewSyncMap[string, *utils.SyncSet[string]]() - sc.singleFlight = new(singleflight.Group) - sc.lastMtime = 0 - return nil -} - -func (sc *strategyCache) Update() error { - // 多个线程竞争,只有一个线程进行更新 - _, err, _ := sc.singleFlight.Do(sc.Name(), func() (interface{}, error) { - return nil, sc.DoCacheUpdate(sc.Name(), sc.realUpdate) - }) - return err -} - -func (sc *strategyCache) ForceSync() error { - return sc.Update() -} - -func (sc *strategyCache) realUpdate() (map[string]time.Time, int64, error) { - // 获取几秒前的全部数据 - var ( - start = time.Now() - lastTime = sc.LastFetchTime() - strategies, err = sc.storage.GetStrategyDetailsForCache(lastTime, sc.IsFirstUpdate()) - ) - if err != nil { - log.Errorf("[Cache][AuthStrategy] refresh auth strategy cache err: %s", err.Error()) - return nil, -1, err - } - - lastMtimes, add, update, del := sc.setStrategys(strategies) - log.Info("[Cache][AuthStrategy] get more auth strategy", - zap.Int("add", add), zap.Int("update", update), zap.Int("delete", del), - zap.Time("last", lastTime), zap.Duration("used", time.Since(start))) - return lastMtimes, int64(len(strategies)), nil -} - -// setStrategys 处理策略的数据更新情况 -// step 1. 先处理resource以及principal的数据更新情况(主要是为了能够获取到新老数据进行对比计算) -// step 2. 处理真正的 strategy 的缓存更新 -func (sc *strategyCache) setStrategys(strategies []*model.StrategyDetail) (map[string]time.Time, int, int, int) { - var add, remove, update int - - sc.handlerResourceStrategy(strategies) - sc.handlerPrincipalStrategy(strategies) - - lastMtime := sc.LastMtime(sc.Name()).Unix() - - for index := range strategies { - rule := strategies[index] - if !rule.Valid { - sc.strategys.Delete(rule.ID) - remove++ - } else { - _, ok := sc.strategys.Load(rule.ID) - if !ok { - add++ - } else { - update++ - } - sc.strategys.Store(rule.ID, buildEnchanceStrategyDetail(rule)) - } - - lastMtime = int64(math.Max(float64(lastMtime), float64(rule.ModifyTime.Unix()))) - } - - return map[string]time.Time{sc.Name(): time.Unix(lastMtime, 0)}, add, update, remove -} - -func buildEnchanceStrategyDetail(strategy *model.StrategyDetail) *model.StrategyDetailCache { - users := make(map[string]model.Principal, 0) - groups := make(map[string]model.Principal, 0) - - for index := range strategy.Principals { - principal := strategy.Principals[index] - if principal.PrincipalRole == model.PrincipalUser { - users[principal.PrincipalID] = principal - } else { - groups[principal.PrincipalID] = principal - } - } - - return &model.StrategyDetailCache{ - StrategyDetail: strategy, - UserPrincipal: users, - GroupPrincipal: groups, - } -} - -func (sc *strategyCache) writeSet(linkContainers *utils.SyncMap[string, *utils.SyncSet[string]], key, val string, isDel bool) { - if isDel { - values, ok := linkContainers.Load(key) - if ok { - values.Remove(val) - } - } else { - if _, ok := linkContainers.Load(key); !ok { - linkContainers.Store(key, utils.NewSyncSet[string]()) - } - values, _ := linkContainers.Load(key) - values.Add(val) - } -} - -// handlerResourceStrategy 处理资源视角下策略的缓存 -// 根据新老策略的资源列表比对,计算出哪些资源不在和该策略存在关联关系,哪些资源新增了相关的策略 -func (sc *strategyCache) handlerResourceStrategy(strategies []*model.StrategyDetail) { - operateLink := func(resType int32, resId, strategyId string, remove bool) { - switch resType { - case int32(apisecurity.ResourceType_Namespaces): - sc.writeSet(sc.namespace2Strategy, resId, strategyId, remove) - case int32(apisecurity.ResourceType_Services): - sc.writeSet(sc.service2Strategy, resId, strategyId, remove) - case int32(apisecurity.ResourceType_ConfigGroups): - sc.writeSet(sc.configGroup2Strategy, resId, strategyId, remove) - } - } - - for sIndex := range strategies { - rule := strategies[sIndex] - addRes := rule.Resources - - if oldRule, exist := sc.strategys.Load(rule.ID); exist { - delRes := make([]model.StrategyResource, 0, 8) - // 计算前后对比, resource 的变化 - newRes := make(map[string]struct{}, len(addRes)) - for i := range addRes { - newRes[fmt.Sprintf("%d_%s", addRes[i].ResType, addRes[i].ResID)] = struct{}{} - } - - // 筛选出从策略中被踢出的 resource 列表 - for i := range oldRule.Resources { - item := oldRule.Resources[i] - if _, ok := newRes[fmt.Sprintf("%d_%s", item.ResType, item.ResID)]; !ok { - delRes = append(delRes, item) - } - } - - // 针对被剔除的 resource 列表,清理掉所关联的鉴权策略信息 - for rIndex := range delRes { - resource := delRes[rIndex] - operateLink(resource.ResType, resource.ResID, rule.ID, true) - } - } - - for rIndex := range addRes { - resource := addRes[rIndex] - if rule.Valid { - operateLink(resource.ResType, resource.ResID, rule.ID, false) - } else { - operateLink(resource.ResType, resource.ResID, rule.ID, true) - } - } - } -} - -// handlerPrincipalStrategy -func (sc *strategyCache) handlerPrincipalStrategy(strategies []*model.StrategyDetail) { - for index := range strategies { - rule := strategies[index] - // 计算 uid -> auth rule - principals := rule.Principals - - if oldRule, exist := sc.strategys.Load(rule.ID); exist { - delMembers := make([]model.Principal, 0, 8) - // 计算前后对比, principal 的变化 - newRes := make(map[string]struct{}, len(principals)) - for i := range principals { - newRes[fmt.Sprintf("%d_%s", principals[i].PrincipalRole, principals[i].PrincipalID)] = struct{}{} - } - - // 筛选出从策略中被踢出的 principal 列表 - for i := range oldRule.Principals { - item := oldRule.Principals[i] - if _, ok := newRes[fmt.Sprintf("%d_%s", item.PrincipalRole, item.PrincipalID)]; !ok { - delMembers = append(delMembers, item) - } - } - - // 针对被剔除的 principal 列表,清理掉所关联的鉴权策略信息 - for rIndex := range delMembers { - principal := delMembers[rIndex] - sc.removePrincipalLink(principal, rule) - } - } - if rule.Valid { - for pos := range principals { - principal := principals[pos] - sc.addPrincipalLink(principal, rule) - } - } else { - for pos := range principals { - principal := principals[pos] - sc.removePrincipalLink(principal, rule) - } - } - } -} - -func (sc *strategyCache) removePrincipalLink(principal model.Principal, rule *model.StrategyDetail) { - linkContainers := sc.uid2Strategy - if principal.PrincipalRole != model.PrincipalUser { - linkContainers = sc.groupid2Strategy - } - sc.writeSet(linkContainers, principal.PrincipalID, rule.ID, true) -} - -func (sc *strategyCache) addPrincipalLink(principal model.Principal, rule *model.StrategyDetail) { - linkContainers := sc.uid2Strategy - if principal.PrincipalRole != model.PrincipalUser { - linkContainers = sc.groupid2Strategy - } - sc.writeSet(linkContainers, principal.PrincipalID, rule.ID, false) -} - -func (sc *strategyCache) Clear() error { - sc.BaseCache.Clear() - sc.strategys = utils.NewSyncMap[string, *model.StrategyDetailCache]() - sc.uid2Strategy = utils.NewSyncMap[string, *utils.SyncSet[string]]() - sc.groupid2Strategy = utils.NewSyncMap[string, *utils.SyncSet[string]]() - sc.namespace2Strategy = utils.NewSyncMap[string, *utils.SyncSet[string]]() - sc.service2Strategy = utils.NewSyncMap[string, *utils.SyncSet[string]]() - sc.configGroup2Strategy = utils.NewSyncMap[string, *utils.SyncSet[string]]() - sc.lastMtime = 0 - return nil -} - -func (sc *strategyCache) Name() string { - return types.StrategyRuleName -} - -// 对于 check 逻辑,如果是计算 * 策略,则必须要求 * 资源下必须有策略 -// 如果是具体的资源ID,则该资源下不必有策略,如果没有策略就认为这个资源是可以被任何人编辑的 -func (sc *strategyCache) checkResourceEditable(strategIds *utils.SyncSet[string], principal model.Principal, mustCheck bool) bool { - // 是否可以编辑 - editable := false - // 是否真的包含策略 - isCheck := strategIds.Len() != 0 - - // 如果根本没有遍历过,则表示该资源下没有对应的策略列表,直接返回可编辑状态即可 - if !isCheck && !mustCheck { - return true - } - - strategIds.Range(func(strategyId string) { - isCheck = true - if rule, ok := sc.strategys.Load(strategyId); ok { - if principal.PrincipalRole == model.PrincipalUser { - _, exist := rule.UserPrincipal[principal.PrincipalID] - editable = editable || exist - } else { - _, exist := rule.GroupPrincipal[principal.PrincipalID] - editable = editable || exist - } - } - }) - - return editable -} - -// IsResourceEditable 判断当前资源是否可以操作 -// 这里需要考虑两种情况,一种是 “ * ” 策略,另一种是明确指出了具体的资源ID的策略 -func (sc *strategyCache) IsResourceEditable( - principal model.Principal, resType apisecurity.ResourceType, resId string) bool { - var ( - valAll, val *utils.SyncSet[string] - ok bool - ) - switch resType { - case apisecurity.ResourceType_Namespaces: - val, ok = sc.namespace2Strategy.Load(resId) - valAll, _ = sc.namespace2Strategy.Load("*") - case apisecurity.ResourceType_Services: - val, ok = sc.service2Strategy.Load(resId) - valAll, _ = sc.service2Strategy.Load("*") - case apisecurity.ResourceType_ConfigGroups: - val, ok = sc.configGroup2Strategy.Load(resId) - valAll, _ = sc.configGroup2Strategy.Load("*") - } - - // 代表该资源没有关联到任何策略,任何人都可以编辑 - if !ok { - return true - } - - principals := make([]model.Principal, 0, 4) - principals = append(principals, principal) - if principal.PrincipalRole == model.PrincipalUser { - groupids := sc.userCache.GetUserLinkGroupIds(principal.PrincipalID) - for i := range groupids { - principals = append(principals, model.Principal{ - PrincipalID: groupids[i], - PrincipalRole: model.PrincipalGroup, - }) - } - } - - for i := range principals { - item := principals[i] - if valAll != nil && sc.checkResourceEditable(valAll, item, true) { - return true - } - - if sc.checkResourceEditable(val, item, false) { - return true - } - } - - return false -} - -func (sc *strategyCache) GetStrategyDetailsByUID(uid string) []*model.StrategyDetail { - return sc.getStrategyDetails(uid, "") -} - -func (sc *strategyCache) GetStrategyDetailsByGroupID(groupid string) []*model.StrategyDetail { - return sc.getStrategyDetails("", groupid) -} - -func (sc *strategyCache) getStrategyDetails(uid string, gid string) []*model.StrategyDetail { - var ( - strategyIds []string - ) - if uid != "" { - sets, ok := sc.uid2Strategy.Load(uid) - if !ok { - return nil - } - strategyIds = sets.ToSlice() - } else if gid != "" { - sets, ok := sc.groupid2Strategy.Load(gid) - if !ok { - return nil - } - strategyIds = sets.ToSlice() - } - - result := make([]*model.StrategyDetail, 0, 16) - if len(strategyIds) > 0 { - for i := range strategyIds { - strategy, ok := sc.strategys.Load(strategyIds[i]) - if ok { - result = append(result, strategy.StrategyDetail) - } - } - } - return result -} - -// IsResourceLinkStrategy 校验 -func (sc *strategyCache) IsResourceLinkStrategy(resType apisecurity.ResourceType, resId string) bool { - hasLinkRule := func(sets *utils.SyncSet[string]) bool { - return sets.Len() != 0 - } - - switch resType { - case apisecurity.ResourceType_Namespaces: - val, ok := sc.namespace2Strategy.Load(resId) - return ok && hasLinkRule(val) - case apisecurity.ResourceType_Services: - val, ok := sc.service2Strategy.Load(resId) - return ok && hasLinkRule(val) - case apisecurity.ResourceType_ConfigGroups: - val, ok := sc.configGroup2Strategy.Load(resId) - return ok && hasLinkRule(val) - default: - return true - } -} diff --git a/cache/auth/strategy_test.go b/cache/auth/strategy_test.go deleted file mode 100644 index 8549f757f..000000000 --- a/cache/auth/strategy_test.go +++ /dev/null @@ -1,520 +0,0 @@ -/** - * Tencent is pleased to support the open source community by making Polaris available. - * - * Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. - * - * Licensed under the BSD 3-Clause License (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://opensource.org/licenses/BSD-3-Clause - * - * Unless required by applicable law or agreed to in writing, software distributed - * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR - * CONDITIONS OF ANY KIND, either express or implied. See the License for the - * specific language governing permissions and limitations under the License. - */ - -package auth - -import ( - "fmt" - "testing" - "time" - - "github.com/golang/mock/gomock" - apisecurity "github.com/polarismesh/specification/source/go/api/v1/security" - "github.com/stretchr/testify/assert" - - types "github.com/polarismesh/polaris/cache/api" - cachemock "github.com/polarismesh/polaris/cache/mock" - "github.com/polarismesh/polaris/common/model" - "github.com/polarismesh/polaris/common/utils" - "github.com/polarismesh/polaris/store/mock" -) - -func Test_strategyCache(t *testing.T) { - t.Run("get_policy", func(t *testing.T) { - ctrl := gomock.NewController(t) - mockCacheMgr := cachemock.NewMockCacheManager(ctrl) - mockStore := mock.NewMockStore(ctrl) - - t.Cleanup(func() { - ctrl.Finish() - }) - - userCache := NewUserCache(mockStore, mockCacheMgr) - strategyCache := NewStrategyCache(mockStore, mockCacheMgr).(*strategyCache) - - mockStore.EXPECT().GetUnixSecond(gomock.Any()).Return(time.Now().Unix(), nil) - mockStore.EXPECT().GetStrategyDetailsForCache(gomock.Any(), gomock.Any()).Return(buildStrategies(10), nil).AnyTimes() - mockCacheMgr.EXPECT().GetCacher(types.CacheUser).Return(userCache).AnyTimes() - mockCacheMgr.EXPECT().GetReportInterval().Return(time.Second).AnyTimes() - mockCacheMgr.EXPECT().GetUpdateCacheInterval().Return(time.Second).AnyTimes() - - userCache.Initialize(map[string]interface{}{}) - strategyCache.Initialize(map[string]interface{}{}) - - _ = strategyCache.ForceSync() - _, _, _ = strategyCache.realUpdate() - - policies := strategyCache.GetStrategyDetailsByUID("user-1") - assert.True(t, len(policies) > 0, len(policies)) - - policies = strategyCache.GetStrategyDetailsByGroupID("group-1") - assert.True(t, len(policies) > 0, len(policies)) - - policies = strategyCache.GetStrategyDetailsByUID("fake-user-1") - assert.True(t, len(policies) == 0, len(policies)) - - policies = strategyCache.GetStrategyDetailsByGroupID("fake-group-1") - assert.True(t, len(policies) == 0, len(policies)) - }) - - t.Run("资源没有关联任何策略", func(t *testing.T) { - ctrl := gomock.NewController(t) - mockCacheMgr := cachemock.NewMockCacheManager(ctrl) - mockStore := mock.NewMockStore(ctrl) - - t.Cleanup(func() { - ctrl.Finish() - }) - - userCache := NewUserCache(mockStore, mockCacheMgr) - strategyCache := NewStrategyCache(mockStore, mockCacheMgr).(*strategyCache) - - mockStore.EXPECT().GetUnixSecond(gomock.Any()).Return(time.Now().Unix(), nil) - mockStore.EXPECT().GetStrategyDetailsForCache(gomock.Any(), gomock.Any()).Return(buildStrategies(10), nil).AnyTimes() - mockCacheMgr.EXPECT().GetCacher(types.CacheUser).Return(userCache).AnyTimes() - mockCacheMgr.EXPECT().GetReportInterval().Return(time.Second).AnyTimes() - mockCacheMgr.EXPECT().GetUpdateCacheInterval().Return(time.Second).AnyTimes() - - userCache.Initialize(map[string]interface{}{}) - strategyCache.Initialize(map[string]interface{}{}) - - _ = strategyCache.ForceSync() - _, _, _ = strategyCache.realUpdate() - - ret := strategyCache.IsResourceEditable(model.Principal{ - PrincipalID: "user-1", - PrincipalRole: model.PrincipalUser, - }, apisecurity.ResourceType_Namespaces, "namespace-1") - - assert.True(t, ret, "must be true") - - ret = strategyCache.IsResourceLinkStrategy(apisecurity.ResourceType_Namespaces, "namespace-1") - assert.True(t, ret, "must be true") - ret = strategyCache.IsResourceLinkStrategy(apisecurity.ResourceType_Services, "service-1") - assert.True(t, ret, "must be true") - ret = strategyCache.IsResourceLinkStrategy(apisecurity.ResourceType_ConfigGroups, "config_group-1") - assert.True(t, ret, "must be true") - - strategyCache.Clear() - }) - - t.Run("操作的目标资源关联了策略-自己在principal-user列表中", func(t *testing.T) { - ctrl := gomock.NewController(t) - mockCacheMgr := cachemock.NewMockCacheManager(ctrl) - mockStore := mock.NewMockStore(ctrl) - - t.Cleanup(func() { - ctrl.Finish() - }) - - userCache := NewUserCache(mockStore, mockCacheMgr) - strategyCache := NewStrategyCache(mockStore, mockCacheMgr).(*strategyCache) - - mockCacheMgr.EXPECT().GetCacher(types.CacheUser).Return(userCache).AnyTimes() - mockCacheMgr.EXPECT().GetReportInterval().Return(time.Second).AnyTimes() - mockCacheMgr.EXPECT().GetUpdateCacheInterval().Return(time.Second).AnyTimes() - - userCache.Initialize(map[string]interface{}{}) - strategyCache.Initialize(map[string]interface{}{}) - - strategyCache.setStrategys([]*model.StrategyDetail{ - { - ID: fmt.Sprintf("rule-%d", 1), - Name: fmt.Sprintf("rule-%d", 1), - Principals: []model.Principal{ - { - PrincipalID: "user-1", - PrincipalRole: model.PrincipalUser, - }, - }, - Valid: true, - Resources: []model.StrategyResource{ - { - StrategyID: fmt.Sprintf("rule-%d", 1), - ResType: 0, - ResID: "*", - }, - }, - }, - }) - - ret := strategyCache.IsResourceEditable(model.Principal{ - PrincipalID: "user-1", - PrincipalRole: model.PrincipalUser, - }, apisecurity.ResourceType_Namespaces, "namespace-1") - - assert.True(t, ret, "must be true") - }) - - t.Run("操作的目标资源关联了策略-自己不在principal-user列表中", func(t *testing.T) { - ctrl := gomock.NewController(t) - mockCacheMgr := cachemock.NewMockCacheManager(ctrl) - mockStore := mock.NewMockStore(ctrl) - - t.Cleanup(func() { - ctrl.Finish() - }) - - userCache := NewUserCache(mockStore, mockCacheMgr) - strategyCache := NewStrategyCache(mockStore, mockCacheMgr).(*strategyCache) - - mockCacheMgr.EXPECT().GetCacher(types.CacheUser).Return(userCache).AnyTimes() - mockCacheMgr.EXPECT().GetReportInterval().Return(time.Second).AnyTimes() - mockCacheMgr.EXPECT().GetUpdateCacheInterval().Return(time.Second).AnyTimes() - - userCache.Initialize(map[string]interface{}{}) - strategyCache.Initialize(map[string]interface{}{}) - - strategyCache.setStrategys(buildStrategies(10)) - - ret := strategyCache.IsResourceEditable(model.Principal{ - PrincipalID: "user-20", - PrincipalRole: model.PrincipalUser, - }, apisecurity.ResourceType_Namespaces, "namespace-1") - assert.False(t, ret, "must be false") - - ret = strategyCache.IsResourceEditable(model.Principal{ - PrincipalID: "user-20", - PrincipalRole: model.PrincipalUser, - }, apisecurity.ResourceType_Services, "service-1") - assert.False(t, ret, "must be false") - - ret = strategyCache.IsResourceEditable(model.Principal{ - PrincipalID: "user-20", - PrincipalRole: model.PrincipalUser, - }, apisecurity.ResourceType_ConfigGroups, "config_group-1") - assert.False(t, ret, "must be false") - }) - - t.Run("操作的目标资源关联了策略-自己属于principal-group中组成员", func(t *testing.T) { - ctrl := gomock.NewController(t) - mockCacheMgr := cachemock.NewMockCacheManager(ctrl) - mockStore := mock.NewMockStore(ctrl) - - t.Cleanup(func() { - ctrl.Finish() - }) - - userCache := NewUserCache(mockStore, mockCacheMgr).(*userCache) - strategyCache := NewStrategyCache(mockStore, mockCacheMgr).(*strategyCache) - - mockCacheMgr.EXPECT().GetCacher(types.CacheUser).Return(userCache).AnyTimes() - mockCacheMgr.EXPECT().GetReportInterval().Return(time.Second).AnyTimes() - mockCacheMgr.EXPECT().GetUpdateCacheInterval().Return(time.Second).AnyTimes() - - userCache.Initialize(map[string]interface{}{}) - strategyCache.Initialize(map[string]interface{}{}) - - userCache.groups.Store("group-1", &model.UserGroupDetail{ - UserGroup: &model.UserGroup{ - ID: "group-1", - }, - UserIds: map[string]struct{}{ - "user-1": {}, - }, - }) - - userCache.user2Groups.Store("user-1", utils.NewSyncSet[string]()) - links, _ := userCache.user2Groups.Load("user-1") - links.Add("group-1") - - strategyCache.setStrategys(buildStrategies(10)) - - ret := strategyCache.IsResourceEditable(model.Principal{ - PrincipalID: "user-1", - PrincipalRole: model.PrincipalUser, - }, apisecurity.ResourceType_Namespaces, "namespace-1") - - assert.True(t, ret, "must be true") - }) - - t.Run("操作关联策略的资源-策略在操作成功-策略移除操作失败", func(t *testing.T) { - ctrl := gomock.NewController(t) - mockCacheMgr := cachemock.NewMockCacheManager(ctrl) - mockStore := mock.NewMockStore(ctrl) - - t.Cleanup(func() { - ctrl.Finish() - }) - - userCache := NewUserCache(mockStore, mockCacheMgr).(*userCache) - strategyCache := NewStrategyCache(mockStore, mockCacheMgr).(*strategyCache) - - mockCacheMgr.EXPECT().GetCacher(types.CacheUser).Return(userCache).AnyTimes() - mockCacheMgr.EXPECT().GetReportInterval().Return(time.Second).AnyTimes() - mockCacheMgr.EXPECT().GetUpdateCacheInterval().Return(time.Second).AnyTimes() - - userCache.Initialize(map[string]interface{}{}) - strategyCache.Initialize(map[string]interface{}{}) - - userCache.groups.Store("group-1", &model.UserGroupDetail{ - UserGroup: &model.UserGroup{ - ID: "group-1", - }, - UserIds: map[string]struct{}{ - "user-1": {}, - }, - }) - - userCache.user2Groups.Store("user-1", utils.NewSyncSet[string]()) - links, _ := userCache.user2Groups.Load("user-1") - links.Add("group-1") - strategyCache.strategys.Store("rule-1", &model.StrategyDetailCache{ - StrategyDetail: &model.StrategyDetail{ - ID: "rule-1", - Name: "rule-1", - Principals: []model.Principal{}, - Resources: []model.StrategyResource{}, - }, - GroupPrincipal: map[string]model.Principal{ - "group-1": { - PrincipalID: "group-1", - }, - }, - }) - strategyCache.strategys.Store("rule-2", &model.StrategyDetailCache{ - StrategyDetail: &model.StrategyDetail{ - ID: "rule-2", - Name: "rule-2", - Principals: []model.Principal{}, - Resources: []model.StrategyResource{}, - }, - GroupPrincipal: map[string]model.Principal{ - "group-2": { - PrincipalID: "group-2", - }, - }, - }) - - strategyCache.writeSet(strategyCache.namespace2Strategy, "namespace-1", "rule-1", false) - strategyCache.writeSet(strategyCache.namespace2Strategy, "namespace-1", "rule-2", false) - - ret := strategyCache.IsResourceEditable(model.Principal{ - PrincipalID: "user-1", - PrincipalRole: model.PrincipalUser, - }, apisecurity.ResourceType_Namespaces, "namespace-1") - - assert.True(t, ret, "must be true") - - strategyCache.handlerResourceStrategy([]*model.StrategyDetail{ - { - ID: "rule-1", - Name: "rule-1", - Valid: false, - Principals: []model.Principal{}, - Resources: []model.StrategyResource{ - { - StrategyID: "rule-1", - ResType: 0, - ResID: "namespace-1", - }, - }, - }, - }) - - ret = strategyCache.IsResourceEditable(model.Principal{ - PrincipalID: "user-1", - PrincipalRole: model.PrincipalUser, - }, apisecurity.ResourceType_Namespaces, "namespace-1") - - assert.False(t, ret, "must be false") - }) - - t.Run("", func(t *testing.T) { - ctrl := gomock.NewController(t) - mockCacheMgr := cachemock.NewMockCacheManager(ctrl) - mockStore := mock.NewMockStore(ctrl) - - t.Cleanup(func() { - ctrl.Finish() - }) - - userCache := NewUserCache(mockStore, mockCacheMgr).(*userCache) - strategyCache := NewStrategyCache(mockStore, mockCacheMgr).(*strategyCache) - - mockCacheMgr.EXPECT().GetCacher(types.CacheUser).Return(userCache).AnyTimes() - mockCacheMgr.EXPECT().GetReportInterval().Return(time.Second).AnyTimes() - mockCacheMgr.EXPECT().GetUpdateCacheInterval().Return(time.Second).AnyTimes() - - userCache.Initialize(map[string]interface{}{}) - strategyCache.Initialize(map[string]interface{}{}) - - userCache.groups.Store("group-1", &model.UserGroupDetail{ - UserGroup: &model.UserGroup{ - ID: "group-1", - }, - UserIds: map[string]struct{}{ - "user-1": {}, - }, - }) - - strategyDetail := &model.StrategyDetail{ - ID: "rule-1", - Name: "rule-1", - Principals: []model.Principal{ - { - PrincipalID: "user-1", - PrincipalRole: model.PrincipalUser, - }, - { - PrincipalID: "group-1", - PrincipalRole: model.PrincipalGroup, - }, - }, - Valid: true, - Resources: []model.StrategyResource{ - { - StrategyID: "rule-1", - ResType: 0, - ResID: "*", - }, - }, - } - - strategyDetail2 := &model.StrategyDetail{ - ID: "rule-2", - Name: "rule-2", - Principals: []model.Principal{ - { - PrincipalID: "user-2", - PrincipalRole: model.PrincipalUser, - }, - { - PrincipalID: "group-2", - PrincipalRole: model.PrincipalGroup, - }, - }, - Valid: true, - Resources: []model.StrategyResource{ - { - StrategyID: "rule-2", - ResType: 0, - ResID: "namespace-1", - }, - }, - } - - strategyCache.strategys.Store("rule-1", &model.StrategyDetailCache{ - StrategyDetail: strategyDetail, - UserPrincipal: map[string]model.Principal{ - "user-1": { - PrincipalID: "user-1", - }, - }, - GroupPrincipal: map[string]model.Principal{ - "group-1": { - PrincipalID: "group-1", - }, - }, - }) - strategyCache.strategys.Store("rule-2", &model.StrategyDetailCache{ - StrategyDetail: strategyDetail2, - UserPrincipal: map[string]model.Principal{ - "user-2": { - PrincipalID: "user-2", - }, - }, - GroupPrincipal: map[string]model.Principal{ - "group-2": { - PrincipalID: "group-2", - }, - }, - }) - - strategyCache.handlerPrincipalStrategy([]*model.StrategyDetail{strategyDetail2}) - strategyCache.handlerResourceStrategy([]*model.StrategyDetail{strategyDetail2}) - strategyCache.handlerPrincipalStrategy([]*model.StrategyDetail{strategyDetail}) - strategyCache.handlerResourceStrategy([]*model.StrategyDetail{strategyDetail}) - ret := strategyCache.IsResourceEditable(model.Principal{ - PrincipalID: "user-1", - PrincipalRole: model.PrincipalUser, - }, apisecurity.ResourceType_Namespaces, "namespace-1") - - assert.True(t, ret, "must be true") - - ret = strategyCache.IsResourceLinkStrategy(apisecurity.ResourceType_Namespaces, "namespace-1") - assert.True(t, ret, "must be true") - - strategyDetail.Valid = false - - strategyCache.handlerPrincipalStrategy([]*model.StrategyDetail{strategyDetail}) - strategyCache.handlerResourceStrategy([]*model.StrategyDetail{strategyDetail}) - strategyCache.strategys.Delete(strategyDetail.ID) - ret = strategyCache.IsResourceEditable(model.Principal{ - PrincipalID: "user-1", - PrincipalRole: model.PrincipalUser, - }, apisecurity.ResourceType_Namespaces, "namespace-1") - - assert.False(t, ret, "must be false") - }) -} - -func buildStrategies(num int) []*model.StrategyDetail { - - ret := make([]*model.StrategyDetail, 0, num) - - for i := 0; i < num; i++ { - principals := make([]model.Principal, 0, num) - for j := 0; j < num; j++ { - principals = append(principals, model.Principal{ - PrincipalID: fmt.Sprintf("user-%d", i+1), - PrincipalRole: model.PrincipalUser, - }, model.Principal{ - PrincipalID: fmt.Sprintf("group-%d", i+1), - PrincipalRole: model.PrincipalGroup, - }) - } - - ret = append(ret, &model.StrategyDetail{ - ID: fmt.Sprintf("rule-%d", i+1), - Name: fmt.Sprintf("rule-%d", i+1), - Principals: principals, - Valid: true, - Resources: []model.StrategyResource{ - { - StrategyID: fmt.Sprintf("rule-%d", i+1), - ResType: 0, - ResID: fmt.Sprintf("namespace-%d", i+1), - }, - { - StrategyID: fmt.Sprintf("rule-%d", i+1), - ResType: 1, - ResID: fmt.Sprintf("service-%d", i+1), - }, - { - StrategyID: fmt.Sprintf("rule-%d", i+1), - ResType: 2, - ResID: fmt.Sprintf("config_group-%d", i+1), - }, - }, - }) - } - - return ret -} - -func testBuildPrincipalMap(principals []model.Principal, role model.PrincipalType) map[string]model.Principal { - ret := make(map[string]model.Principal, 0) - for i := range principals { - principal := principals[i] - if principal.PrincipalRole == role { - ret[principal.PrincipalID] = principal - } - } - - return ret -} diff --git a/cache/auth/user.go b/cache/auth/user.go index 673724c3d..1e7084875 100644 --- a/cache/auth/user.go +++ b/cache/auth/user.go @@ -18,6 +18,7 @@ package auth import ( + "context" "fmt" "math" "sync/atomic" @@ -27,7 +28,7 @@ import ( "golang.org/x/sync/singleflight" types "github.com/polarismesh/polaris/cache/api" - "github.com/polarismesh/polaris/common/model" + authcommon "github.com/polarismesh/polaris/common/model/auth" "github.com/polarismesh/polaris/common/utils" "github.com/polarismesh/polaris/store" ) @@ -54,11 +55,11 @@ type userCache struct { adminUser atomic.Value // userid -> user - users *utils.SyncMap[string, *model.User] + users *utils.SyncMap[string, *authcommon.User] // username -> user - name2Users *utils.SyncMap[string, *model.User] + name2Users *utils.SyncMap[string, *authcommon.User] // groupid -> group - groups *utils.SyncMap[string, *model.UserGroupDetail] + groups *utils.SyncMap[string, *authcommon.UserGroupDetail] // userid -> groups user2Groups *utils.SyncMap[string, *utils.SyncSet[string]] @@ -78,9 +79,9 @@ func NewUserCache(storage store.Store, cacheMgr types.CacheManager) types.UserCa // Initialize func (uc *userCache) Initialize(_ map[string]interface{}) error { - uc.users = utils.NewSyncMap[string, *model.User]() - uc.name2Users = utils.NewSyncMap[string, *model.User]() - uc.groups = utils.NewSyncMap[string, *model.UserGroupDetail]() + uc.users = utils.NewSyncMap[string, *authcommon.User]() + uc.name2Users = utils.NewSyncMap[string, *authcommon.User]() + uc.groups = utils.NewSyncMap[string, *authcommon.UserGroupDetail]() uc.user2Groups = utils.NewSyncMap[string, *utils.SyncSet[string]]() uc.adminUser = atomic.Value{} uc.singleFlight = new(singleflight.Group) @@ -121,12 +122,12 @@ func (uc *userCache) realUpdate() (map[string]time.Time, int64, error) { return lastMimes, int64(len(users) + len(groups)), nil } -func (uc *userCache) setUserAndGroups(users []*model.User, - groups []*model.UserGroupDetail) (map[string]time.Time, userRefreshResult) { +func (uc *userCache) setUserAndGroups(users []*authcommon.User, + groups []*authcommon.UserGroupDetail) (map[string]time.Time, userRefreshResult) { ret := userRefreshResult{} - ownerSupplier := func(user *model.User) *model.User { - if user.Type == model.SubAccountUserRole { + ownerSupplier := func(user *authcommon.User) *authcommon.User { + if user.Type == authcommon.SubAccountUserRole { owner, _ := uc.users.Load(user.Owner) return owner } @@ -137,13 +138,13 @@ func (uc *userCache) setUserAndGroups(users []*model.User, // 更新 users 缓存 // step 1. 先更新 owner 用户 - uc.handlerUserCacheUpdate(lastMimes, &ret, users, func(user *model.User) bool { - return user.Type == model.OwnerUserRole + uc.handlerUserCacheUpdate(lastMimes, &ret, users, func(user *authcommon.User) bool { + return user.Type == authcommon.OwnerUserRole }, ownerSupplier) // step 2. 更新非 owner 用户 - uc.handlerUserCacheUpdate(lastMimes, &ret, users, func(user *model.User) bool { - return user.Type == model.SubAccountUserRole + uc.handlerUserCacheUpdate(lastMimes, &ret, users, func(user *authcommon.User) bool { + return user.Type == authcommon.SubAccountUserRole }, ownerSupplier) uc.handlerGroupCacheUpdate(lastMimes, &ret, groups) @@ -151,8 +152,8 @@ func (uc *userCache) setUserAndGroups(users []*model.User, } // handlerUserCacheUpdate 处理用户信息更新 -func (uc *userCache) handlerUserCacheUpdate(lastMimes map[string]time.Time, ret *userRefreshResult, users []*model.User, - filter func(user *model.User) bool, ownerSupplier func(user *model.User) *model.User) { +func (uc *userCache) handlerUserCacheUpdate(lastMimes map[string]time.Time, ret *userRefreshResult, users []*authcommon.User, + filter func(user *authcommon.User) bool, ownerSupplier func(user *authcommon.User) *authcommon.User) { lastUserMtime := uc.LastMtime("users").Unix() @@ -161,7 +162,7 @@ func (uc *userCache) handlerUserCacheUpdate(lastMimes map[string]time.Time, ret lastUserMtime = int64(math.Max(float64(lastUserMtime), float64(user.ModifyTime.Unix()))) - if user.Type == model.AdminUserRole { + if user.Type == authcommon.AdminUserRole { uc.adminUser.Store(user) uc.users.Store(user.ID, user) uc.name2Users.Store(fmt.Sprintf(NameLinkOwnerTemp, user.Name, user.Name), user) @@ -197,7 +198,7 @@ func (uc *userCache) handlerUserCacheUpdate(lastMimes map[string]time.Time, ret // handlerGroupCacheUpdate 处理用户组信息更新 func (uc *userCache) handlerGroupCacheUpdate(lastMimes map[string]time.Time, ret *userRefreshResult, - groups []*model.UserGroupDetail) { + groups []*authcommon.UserGroupDetail) { lastGroupMtime := uc.LastMtime("group").Unix() @@ -211,7 +212,7 @@ func (uc *userCache) handlerGroupCacheUpdate(lastMimes map[string]time.Time, ret uc.groups.Delete(group.ID) ret.groupDel++ } else { - var oldGroup *model.UserGroupDetail + var oldGroup *authcommon.UserGroupDetail if oldVal, ok := uc.groups.Load(group.ID); ok { ret.groupUpdate++ oldGroup = oldVal @@ -253,9 +254,9 @@ func (uc *userCache) handlerGroupCacheUpdate(lastMimes map[string]time.Time, ret func (uc *userCache) Clear() error { uc.BaseCache.Clear() - uc.users = utils.NewSyncMap[string, *model.User]() - uc.name2Users = utils.NewSyncMap[string, *model.User]() - uc.groups = utils.NewSyncMap[string, *model.UserGroupDetail]() + uc.users = utils.NewSyncMap[string, *authcommon.User]() + uc.name2Users = utils.NewSyncMap[string, *authcommon.User]() + uc.groups = utils.NewSyncMap[string, *authcommon.UserGroupDetail]() uc.user2Groups = utils.NewSyncMap[string, *utils.SyncSet[string]]() uc.adminUser = atomic.Value{} uc.lastUserMtime = 0 @@ -268,13 +269,13 @@ func (uc *userCache) Name() string { } // GetAdmin 获取管理员数据信息 -func (uc *userCache) GetAdmin() *model.User { +func (uc *userCache) GetAdmin() *authcommon.User { val := uc.adminUser.Load() if val == nil { return nil } - return val.(*model.User) + return val.(*authcommon.User) } // IsOwner 判断当前用户是否是 owner 角色 @@ -284,7 +285,7 @@ func (uc *userCache) IsOwner(id string) bool { return false } ut := val.Type - return ut == model.AdminUserRole || ut == model.OwnerUserRole + return ut == authcommon.AdminUserRole || ut == authcommon.OwnerUserRole } func (uc *userCache) IsUserInGroup(userId, groupId string) bool { @@ -297,7 +298,7 @@ func (uc *userCache) IsUserInGroup(userId, groupId string) bool { } // GetUserByID 根据用户ID获取用户缓存对象 -func (uc *userCache) GetUserByID(id string) *model.User { +func (uc *userCache) GetUserByID(id string) *authcommon.User { if id == "" { return nil } @@ -310,7 +311,7 @@ func (uc *userCache) GetUserByID(id string) *model.User { } // GetUserByName 通过用户 name 以及 owner 获取用户缓存对象 -func (uc *userCache) GetUserByName(name, ownerName string) *model.User { +func (uc *userCache) GetUserByName(name, ownerName string) *authcommon.User { val, ok := uc.name2Users.Load(fmt.Sprintf(NameLinkOwnerTemp, ownerName, name)) if !ok { @@ -320,7 +321,7 @@ func (uc *userCache) GetUserByName(name, ownerName string) *model.User { } // GetGroup 通过用户组ID获取用户组缓存对象 -func (uc *userCache) GetGroup(id string) *model.UserGroupDetail { +func (uc *userCache) GetGroup(id string) *authcommon.UserGroupDetail { if id == "" { return nil } @@ -344,3 +345,133 @@ func (uc *userCache) GetUserLinkGroupIds(userId string) []string { } return val.ToSlice() } + +// QueryUsers . +func (uc *userCache) QueryUsers(ctx context.Context, args types.UserSearchArgs) (uint32, []*authcommon.User, error) { + searchId, hasId := args.Filters["id"] + searchName, hasName := args.Filters["name"] + searchOwner, hasOwner := args.Filters["owner"] + searchSource, hasSource := args.Filters["source"] + searchGroupId, hasGroup := args.Filters["group_id"] + + predicates := types.LoadUserPredicates(ctx) + + if hasGroup { + g, ok := uc.groups.Load(searchGroupId) + if !ok { + return 0, nil, nil + } + predicates = append(predicates, func(ctx context.Context, u *authcommon.User) bool { + _, exist := g.UserIds[u.ID] + return exist + }) + } + + result := make([]*authcommon.User, 0, 32) + uc.users.Range(func(key string, val *authcommon.User) { + // 超级账户不做展示 + if authcommon.UserRoleType(val.Type) == authcommon.AdminUserRole { + return + } + if hasId && searchId != key { + return + } + if hasOwner && val.Owner != searchOwner { + return + } + if hasName && !utils.IsWildMatch(val.Name, searchName) { + return + } + if hasSource && !utils.IsWildMatch(val.Source, searchSource) { + return + } + for i := range predicates { + if !predicates[i](ctx, val) { + return + } + } + result = append(result, val) + }) + + total, ret := uc.listUsersPage(result, args) + return total, ret, nil +} + +func (uc *userCache) listUsersPage(users []*authcommon.User, args types.UserSearchArgs) (uint32, []*authcommon.User) { + total := uint32(len(users)) + if args.Limit == 0 { + return total, nil + } + start := args.Limit * (args.Offset - 1) + end := args.Limit * args.Offset + if start > total { + return total, nil + } + if end > total { + end = total + } + return total, users[start:end] +} + +// QueryUserGroups . +func (uc *userCache) QueryUserGroups(ctx context.Context, args types.UserGroupSearchArgs) (uint32, []*authcommon.UserGroupDetail, error) { + searchId, hasId := args.Filters["id"] + searchName, hasName := args.Filters["name"] + searchOwner, hasOwner := args.Filters["owner"] + searchSource, hasSource := args.Filters["source"] + + predicates := types.LoadUserGroupPredicates(ctx) + + searchUserId, hasUserId := args.Filters["user_id"] + if hasUserId { + if _, ok := uc.users.Load(searchUserId); !ok { + return 0, nil, nil + } + predicates = append(predicates, func(ctx context.Context, ugd *authcommon.UserGroupDetail) bool { + _, exist := ugd.UserIds[searchUserId] + return exist + }) + } + + result := make([]*authcommon.UserGroupDetail, 0, 32) + uc.groups.Range(func(key string, val *authcommon.UserGroupDetail) { + // 超级账户不做展示 + if hasId && searchId != key { + return + } + if hasOwner && val.Owner != searchOwner { + return + } + if hasName && !utils.IsWildMatch(val.Name, searchName) { + return + } + if hasSource && !utils.IsWildMatch(val.Source, searchSource) { + return + } + for i := range predicates { + if !predicates[i](ctx, val) { + return + } + } + result = append(result, val) + }) + + total, ret := uc.listUserGroupsPage(result, args) + return total, ret, nil +} + +func (uc *userCache) listUserGroupsPage(groups []*authcommon.UserGroupDetail, args types.UserGroupSearchArgs) (uint32, []*authcommon.UserGroupDetail) { + total := uint32(len(groups)) + if args.Limit == 0 { + return total, nil + } + start := args.Limit * (args.Offset - 1) + end := args.Limit * args.Offset + if start > total { + return total, nil + } + if end > total { + end = total + } + return total, groups[start:end] +} diff --git a/cache/auth/user_test.go b/cache/auth/user_test.go index 8d7330023..e87648ae4 100644 --- a/cache/auth/user_test.go +++ b/cache/auth/user_test.go @@ -28,7 +28,7 @@ import ( "github.com/stretchr/testify/assert" types "github.com/polarismesh/polaris/cache/api" - "github.com/polarismesh/polaris/common/model" + authcommon "github.com/polarismesh/polaris/common/model/auth" "github.com/polarismesh/polaris/common/utils" "github.com/polarismesh/polaris/store/mock" ) @@ -49,24 +49,24 @@ func newTestUserCache(t *testing.T) (*gomock.Controller, *mock.MockStore, *userC } // 生成测试数据 -func genModelUsers(total int) []*model.User { +func genModelUsers(total int) []*authcommon.User { if total%10 != 0 { panic(errors.New("total must like 10, 20, 30, 40, ...")) } - out := make([]*model.User, 0, total) + out := make([]*authcommon.User, 0, total) - var owner *model.User + var owner *authcommon.User for i := 0; i < total; i++ { if i%10 == 0 { - owner = &model.User{ + owner = &authcommon.User{ ID: fmt.Sprintf("owner-user-%d", i), Name: fmt.Sprintf("owner-user-%d", i), Password: fmt.Sprintf("owner-user-%d", i), Owner: "", Source: "Polaris", - Type: model.OwnerUserRole, + Type: authcommon.OwnerUserRole, Token: fmt.Sprintf("owner-user-%d", i), Valid: true, } @@ -74,13 +74,13 @@ func genModelUsers(total int) []*model.User { continue } - entry := &model.User{ + entry := &authcommon.User{ ID: fmt.Sprintf("sub-user-%d", i), Name: fmt.Sprintf("sub-user-%d", i), Password: fmt.Sprintf("sub-user-%d", i), Owner: owner.ID, Source: "Polaris", - Type: model.SubAccountUserRole, + Type: authcommon.SubAccountUserRole, Token: fmt.Sprintf("sub-user-%d", i), Valid: true, } @@ -90,13 +90,13 @@ func genModelUsers(total int) []*model.User { return out } -func genModelUserGroups(users []*model.User) []*model.UserGroupDetail { +func genModelUserGroups(users []*authcommon.User) []*authcommon.UserGroupDetail { - out := make([]*model.UserGroupDetail, 0, len(users)) + out := make([]*authcommon.UserGroupDetail, 0, len(users)) for i := 0; i < len(users); i++ { - entry := &model.UserGroupDetail{ - UserGroup: &model.UserGroup{ + entry := &authcommon.UserGroupDetail{ + UserGroup: &authcommon.UserGroup{ ID: utils.NewUUID(), Name: fmt.Sprintf("group-%d", i), Owner: users[0].ID, @@ -124,16 +124,16 @@ func TestUserCache_UpdateNormal(t *testing.T) { users := genModelUsers(10) groups := genModelUserGroups(users) - admin := &model.User{ + admin := &authcommon.User{ ID: "admin-polaris", Name: "admin-polaris", - Type: model.AdminUserRole, + Type: authcommon.AdminUserRole, Valid: true, } t.Run("首次更新用户", func(t *testing.T) { - copyUsers := make([]*model.User, 0, len(users)) - copyGroups := make([]*model.UserGroupDetail, 0, len(groups)) + copyUsers := make([]*authcommon.User, 0, len(users)) + copyGroups := make([]*authcommon.UserGroupDetail, 0, len(groups)) for i := range users { copyUser := *users[i] @@ -187,7 +187,7 @@ func TestUserCache_UpdateNormal(t *testing.T) { deleteCnt := 0 for i := range users { // 主账户/管理账户 不能删除,因此这里对于第一个用户需要跳过 - if users[i].Type != model.SubAccountUserRole { + if users[i].Type != authcommon.SubAccountUserRole { continue } if rand.Int31n(3) < 1 { @@ -199,8 +199,8 @@ func TestUserCache_UpdateNormal(t *testing.T) { users[i].Comment = fmt.Sprintf("Update user %d", i) } - copyUsers := make([]*model.User, 0, len(users)) - copyGroups := make([]*model.UserGroupDetail, 0, len(groups)) + copyUsers := make([]*authcommon.User, 0, len(users)) + copyGroups := make([]*authcommon.UserGroupDetail, 0, len(groups)) for i := range users { copyUser := *users[i] diff --git a/cache/cache.go b/cache/cache.go index f183fc669..67c9cf35d 100644 --- a/cache/cache.go +++ b/cache/cache.go @@ -254,6 +254,11 @@ func (nc *CacheManager) Gray() types.GrayCache { return nc.caches[types.CacheGray].(types.GrayCache) } +// Role get Role cache information +func (nc *CacheManager) Role() types.RoleCache { + return nc.caches[types.CacheRole].(types.RoleCache) +} + // GetCacher get types.Cache impl func (nc *CacheManager) GetCacher(cacheIndex types.CacheIndex) types.Cache { return nc.caches[cacheIndex] diff --git a/cache/config/config_file.go b/cache/config/config_file.go index b6ff0bdde..6ba757c5a 100644 --- a/cache/config/config_file.go +++ b/cache/config/config_file.go @@ -66,7 +66,7 @@ type fileCache struct { // NewConfigFileCache 创建文件缓存 func NewConfigFileCache(storage store.Store, cacheMgr types.CacheManager) types.ConfigFileCache { fc := &fileCache{ - storage: storage, + storage: storage, } fc.BaseCache = types.NewBaseCacheWithRepoerMetrics(storage, cacheMgr, fc.reportMetricsInfo) return fc diff --git a/cache/config/config_file_metrics_test.go b/cache/config/config_file_metrics_test.go index e7470ffe7..2002446a9 100644 --- a/cache/config/config_file_metrics_test.go +++ b/cache/config/config_file_metrics_test.go @@ -59,7 +59,6 @@ func Test_cleanExpireConfigFileMetricLabel(t *testing.T) { }, }, }, - } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { diff --git a/cache/config/config_group.go b/cache/config/config_group.go index 04db33c4c..206fdf804 100644 --- a/cache/config/config_group.go +++ b/cache/config/config_group.go @@ -47,7 +47,7 @@ type configGroupCache struct { // NewConfigGroupCache 创建文件缓存 func NewConfigGroupCache(storage store.Store, cacheMgr types.CacheManager) types.ConfigGroupCache { gc := &configGroupCache{ - storage: storage, + storage: storage, } gc.BaseCache = types.NewBaseCacheWithRepoerMetrics(storage, cacheMgr, gc.reportMetricsInfo) return gc diff --git a/cache/default.go b/cache/default.go index 4b1144986..d89e03fce 100644 --- a/cache/default.go +++ b/cache/default.go @@ -50,6 +50,7 @@ func init() { RegisterCache(types.ServiceContractName, types.CacheServiceContract) RegisterCache(types.GrayName, types.CacheGray) RegisterCache(types.LaneRuleName, types.CacheLaneRule) + RegisterCache(types.RolesName, types.CacheRole) } var ( @@ -93,7 +94,7 @@ func newCacheManager(ctx context.Context, cacheOpt *Config, storage store.Store) // 注册发现 & 服务治理缓存 mgr.RegisterCacher(types.CacheService, cachesvc.NewServiceCache(storage, mgr)) mgr.RegisterCacher(types.CacheInstance, cachesvc.NewInstanceCache(storage, mgr)) - mgr.RegisterCacher(types.CacheRoutingConfig, cachesvc.NewRoutingConfigCache(storage, mgr)) + mgr.RegisterCacher(types.CacheRoutingConfig, cachesvc.NewRouteRuleCache(storage, mgr)) mgr.RegisterCacher(types.CacheRateLimit, cachesvc.NewRateLimitCache(storage, mgr)) mgr.RegisterCacher(types.CacheCircuitBreaker, cachesvc.NewCircuitBreakerCache(storage, mgr)) mgr.RegisterCacher(types.CacheFaultDetector, cachesvc.NewFaultDetectCache(storage, mgr)) @@ -106,6 +107,7 @@ func newCacheManager(ctx context.Context, cacheOpt *Config, storage store.Store) // 用户/用户组 & 鉴权规则缓存 mgr.RegisterCacher(types.CacheUser, cacheauth.NewUserCache(storage, mgr)) mgr.RegisterCacher(types.CacheAuthStrategy, cacheauth.NewStrategyCache(storage, mgr)) + mgr.RegisterCacher(types.CacheRole, cacheauth.NewRoleCache(storage, mgr)) // 北极星SDK Client mgr.RegisterCacher(types.CacheClient, cacheclient.NewClientCache(storage, mgr)) mgr.RegisterCacher(types.CacheGray, cachegray.NewGrayCache(storage, mgr)) diff --git a/cache/mock/cache_mock.go b/cache/mock/cache_mock.go index ab05d026c..330596ad5 100644 --- a/cache/mock/cache_mock.go +++ b/cache/mock/cache_mock.go @@ -12,6 +12,7 @@ import ( gomock "github.com/golang/mock/gomock" api "github.com/polarismesh/polaris/cache/api" model "github.com/polarismesh/polaris/common/model" + auth "github.com/polarismesh/polaris/common/model/auth" store "github.com/polarismesh/polaris/store" model0 "github.com/polarismesh/specification/source/go/api/v1/model" security "github.com/polarismesh/specification/source/go/api/v1/security" @@ -375,6 +376,20 @@ func (mr *MockCacheManagerMockRecorder) RegisterCacher(cacheIndex, item interfac return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RegisterCacher", reflect.TypeOf((*MockCacheManager)(nil).RegisterCacher), cacheIndex, item) } +// Role mocks base method. +func (m *MockCacheManager) Role() api.RoleCache { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Role") + ret0, _ := ret[0].(api.RoleCache) + return ret0 +} + +// Role indicates an expected call of Role. +func (mr *MockCacheManagerMockRecorder) Role() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Role", reflect.TypeOf((*MockCacheManager)(nil).Role)) +} + // RoutingConfig mocks base method. func (m *MockCacheManager) RoutingConfig() api.RoutingConfigCache { m.ctrl.T.Helper() @@ -742,9 +757,9 @@ func (mr *MockServiceCacheMockRecorder) GetServiceByName(name, namespace interfa } // GetServicesByFilter mocks base method. -func (m *MockServiceCache) GetServicesByFilter(serviceFilters *api.ServiceArgs, instanceFilters *store.InstanceArgs, offset, limit uint32) (uint32, []*model.EnhancedService, error) { +func (m *MockServiceCache) GetServicesByFilter(ctx context.Context, serviceFilters *api.ServiceArgs, instanceFilters *store.InstanceArgs, offset, limit uint32) (uint32, []*model.EnhancedService, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetServicesByFilter", serviceFilters, instanceFilters, offset, limit) + ret := m.ctrl.Call(m, "GetServicesByFilter", ctx, serviceFilters, instanceFilters, offset, limit) ret0, _ := ret[0].(uint32) ret1, _ := ret[1].([]*model.EnhancedService) ret2, _ := ret[2].(error) @@ -752,9 +767,9 @@ func (m *MockServiceCache) GetServicesByFilter(serviceFilters *api.ServiceArgs, } // GetServicesByFilter indicates an expected call of GetServicesByFilter. -func (mr *MockServiceCacheMockRecorder) GetServicesByFilter(serviceFilters, instanceFilters, offset, limit interface{}) *gomock.Call { +func (mr *MockServiceCacheMockRecorder) GetServicesByFilter(ctx, serviceFilters, instanceFilters, offset, limit interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServicesByFilter", reflect.TypeOf((*MockServiceCache)(nil).GetServicesByFilter), serviceFilters, instanceFilters, offset, limit) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServicesByFilter", reflect.TypeOf((*MockServiceCache)(nil).GetServicesByFilter), ctx, serviceFilters, instanceFilters, offset, limit) } // GetServicesCount mocks base method. @@ -1369,6 +1384,20 @@ func (mr *MockFaultDetectCacheMockRecorder) GetFaultDetectConfig(svcName, namesp return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetFaultDetectConfig", reflect.TypeOf((*MockFaultDetectCache)(nil).GetFaultDetectConfig), svcName, namespace) } +// GetRule mocks base method. +func (m *MockFaultDetectCache) GetRule(id string) *model.FaultDetectRule { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetRule", id) + ret0, _ := ret[0].(*model.FaultDetectRule) + return ret0 +} + +// GetRule indicates an expected call of GetRule. +func (mr *MockFaultDetectCacheMockRecorder) GetRule(id interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRule", reflect.TypeOf((*MockFaultDetectCache)(nil).GetRule), id) +} + // Initialize mocks base method. func (m *MockFaultDetectCache) Initialize(c map[string]interface{}) error { m.ctrl.T.Helper() @@ -1397,6 +1426,22 @@ func (mr *MockFaultDetectCacheMockRecorder) Name() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Name", reflect.TypeOf((*MockFaultDetectCache)(nil).Name)) } +// Query mocks base method. +func (m *MockFaultDetectCache) Query(arg0 context.Context, arg1 *api.FaultDetectArgs) (uint32, []*model.FaultDetectRule, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Query", arg0, arg1) + ret0, _ := ret[0].(uint32) + ret1, _ := ret[1].([]*model.FaultDetectRule) + ret2, _ := ret[2].(error) + return ret0, ret1, ret2 +} + +// Query indicates an expected call of Query. +func (mr *MockFaultDetectCacheMockRecorder) Query(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Query", reflect.TypeOf((*MockFaultDetectCache)(nil).Query), arg0, arg1) +} + // Update mocks base method. func (m *MockFaultDetectCache) Update() error { m.ctrl.T.Helper() @@ -1477,6 +1522,20 @@ func (mr *MockLaneCacheMockRecorder) GetLaneRules(serviceKey interface{}) *gomoc return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLaneRules", reflect.TypeOf((*MockLaneCache)(nil).GetLaneRules), serviceKey) } +// GetRule mocks base method. +func (m *MockLaneCache) GetRule(id string) *model.LaneGroup { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetRule", id) + ret0, _ := ret[0].(*model.LaneGroup) + return ret0 +} + +// GetRule indicates an expected call of GetRule. +func (mr *MockLaneCacheMockRecorder) GetRule(id interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRule", reflect.TypeOf((*MockLaneCache)(nil).GetRule), id) +} + // Initialize mocks base method. func (m *MockLaneCache) Initialize(c map[string]interface{}) error { m.ctrl.T.Helper() @@ -1505,6 +1564,22 @@ func (mr *MockLaneCacheMockRecorder) Name() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Name", reflect.TypeOf((*MockLaneCache)(nil).Name)) } +// Query mocks base method. +func (m *MockLaneCache) Query(arg0 context.Context, arg1 *api.LaneGroupArgs) (uint32, []*model.LaneGroupProto, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Query", arg0, arg1) + ret0, _ := ret[0].(uint32) + ret1, _ := ret[1].([]*model.LaneGroupProto) + ret2, _ := ret[2].(error) + return ret0, ret1, ret2 +} + +// Query indicates an expected call of Query. +func (mr *MockLaneCacheMockRecorder) Query(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Query", reflect.TypeOf((*MockLaneCache)(nil).Query), arg0, arg1) +} + // Update mocks base method. func (m *MockLaneCache) Update() error { m.ctrl.T.Helper() @@ -1570,6 +1645,22 @@ func (mr *MockRoutingConfigCacheMockRecorder) Close() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockRoutingConfigCache)(nil).Close)) } +// GetNearbyRouteRule mocks base method. +func (m *MockRoutingConfigCache) GetNearbyRouteRule(service, namespace string) ([]*traffic_manage.RouteRule, string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetNearbyRouteRule", service, namespace) + ret0, _ := ret[0].([]*traffic_manage.RouteRule) + ret1, _ := ret[1].(string) + ret2, _ := ret[2].(error) + return ret0, ret1, ret2 +} + +// GetNearbyRouteRule indicates an expected call of GetNearbyRouteRule. +func (mr *MockRoutingConfigCacheMockRecorder) GetNearbyRouteRule(service, namespace interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetNearbyRouteRule", reflect.TypeOf((*MockRoutingConfigCache)(nil).GetNearbyRouteRule), service, namespace) +} + // GetRouterConfig mocks base method. func (m *MockRoutingConfigCache) GetRouterConfig(id, service, namespace string) (*traffic_manage.Routing, error) { m.ctrl.T.Helper() @@ -1614,6 +1705,20 @@ func (mr *MockRoutingConfigCacheMockRecorder) GetRoutingConfigCount() *gomock.Ca return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRoutingConfigCount", reflect.TypeOf((*MockRoutingConfigCache)(nil).GetRoutingConfigCount)) } +// GetRule mocks base method. +func (m *MockRoutingConfigCache) GetRule(id string) *model.ExtendRouterConfig { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetRule", id) + ret0, _ := ret[0].(*model.ExtendRouterConfig) + return ret0 +} + +// GetRule indicates an expected call of GetRule. +func (mr *MockRoutingConfigCacheMockRecorder) GetRule(id interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRule", reflect.TypeOf((*MockRoutingConfigCache)(nil).GetRule), id) +} + // Initialize mocks base method. func (m *MockRoutingConfigCache) Initialize(c map[string]interface{}) error { m.ctrl.T.Helper() @@ -1684,9 +1789,9 @@ func (mr *MockRoutingConfigCacheMockRecorder) Name() *gomock.Call { } // QueryRoutingConfigsV2 mocks base method. -func (m *MockRoutingConfigCache) QueryRoutingConfigsV2(args *api.RoutingArgs) (uint32, []*model.ExtendRouterConfig, error) { +func (m *MockRoutingConfigCache) QueryRoutingConfigsV2(arg0 context.Context, arg1 *api.RoutingArgs) (uint32, []*model.ExtendRouterConfig, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "QueryRoutingConfigsV2", args) + ret := m.ctrl.Call(m, "QueryRoutingConfigsV2", arg0, arg1) ret0, _ := ret[0].(uint32) ret1, _ := ret[1].([]*model.ExtendRouterConfig) ret2, _ := ret[2].(error) @@ -1694,9 +1799,9 @@ func (m *MockRoutingConfigCache) QueryRoutingConfigsV2(args *api.RoutingArgs) (u } // QueryRoutingConfigsV2 indicates an expected call of QueryRoutingConfigsV2. -func (mr *MockRoutingConfigCacheMockRecorder) QueryRoutingConfigsV2(args interface{}) *gomock.Call { +func (mr *MockRoutingConfigCacheMockRecorder) QueryRoutingConfigsV2(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryRoutingConfigsV2", reflect.TypeOf((*MockRoutingConfigCache)(nil).QueryRoutingConfigsV2), args) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryRoutingConfigsV2", reflect.TypeOf((*MockRoutingConfigCache)(nil).QueryRoutingConfigsV2), arg0, arg1) } // Update mocks base method. @@ -1793,6 +1898,20 @@ func (mr *MockRateLimitCacheMockRecorder) GetRateLimitsCount() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRateLimitsCount", reflect.TypeOf((*MockRateLimitCache)(nil).GetRateLimitsCount)) } +// GetRule mocks base method. +func (m *MockRateLimitCache) GetRule(id string) *model.RateLimit { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetRule", id) + ret0, _ := ret[0].(*model.RateLimit) + return ret0 +} + +// GetRule indicates an expected call of GetRule. +func (mr *MockRateLimitCacheMockRecorder) GetRule(id interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRule", reflect.TypeOf((*MockRateLimitCache)(nil).GetRule), id) +} + // Initialize mocks base method. func (m *MockRateLimitCache) Initialize(c map[string]interface{}) error { m.ctrl.T.Helper() @@ -1834,9 +1953,9 @@ func (mr *MockRateLimitCacheMockRecorder) Name() *gomock.Call { } // QueryRateLimitRules mocks base method. -func (m *MockRateLimitCache) QueryRateLimitRules(args api.RateLimitRuleArgs) (uint32, []*model.RateLimit, error) { +func (m *MockRateLimitCache) QueryRateLimitRules(arg0 context.Context, arg1 api.RateLimitRuleArgs) (uint32, []*model.RateLimit, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "QueryRateLimitRules", args) + ret := m.ctrl.Call(m, "QueryRateLimitRules", arg0, arg1) ret0, _ := ret[0].(uint32) ret1, _ := ret[1].([]*model.RateLimit) ret2, _ := ret[2].(error) @@ -1844,9 +1963,9 @@ func (m *MockRateLimitCache) QueryRateLimitRules(args api.RateLimitRuleArgs) (ui } // QueryRateLimitRules indicates an expected call of QueryRateLimitRules. -func (mr *MockRateLimitCacheMockRecorder) QueryRateLimitRules(args interface{}) *gomock.Call { +func (mr *MockRateLimitCacheMockRecorder) QueryRateLimitRules(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryRateLimitRules", reflect.TypeOf((*MockRateLimitCache)(nil).QueryRateLimitRules), args) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryRateLimitRules", reflect.TypeOf((*MockRateLimitCache)(nil).QueryRateLimitRules), arg0, arg1) } // Update mocks base method. @@ -2091,6 +2210,20 @@ func (mr *MockCircuitBreakerCacheMockRecorder) GetCircuitBreakerConfig(svcName, return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetCircuitBreakerConfig", reflect.TypeOf((*MockCircuitBreakerCache)(nil).GetCircuitBreakerConfig), svcName, namespace) } +// GetRule mocks base method. +func (m *MockCircuitBreakerCache) GetRule(id string) *model.CircuitBreakerRule { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetRule", id) + ret0, _ := ret[0].(*model.CircuitBreakerRule) + return ret0 +} + +// GetRule indicates an expected call of GetRule. +func (mr *MockCircuitBreakerCacheMockRecorder) GetRule(id interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRule", reflect.TypeOf((*MockCircuitBreakerCache)(nil).GetRule), id) +} + // Initialize mocks base method. func (m *MockCircuitBreakerCache) Initialize(c map[string]interface{}) error { m.ctrl.T.Helper() @@ -2119,6 +2252,22 @@ func (mr *MockCircuitBreakerCacheMockRecorder) Name() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Name", reflect.TypeOf((*MockCircuitBreakerCache)(nil).Name)) } +// Query mocks base method. +func (m *MockCircuitBreakerCache) Query(arg0 context.Context, arg1 *api.CircuitBreakerRuleArgs) (uint32, []*model.CircuitBreakerRule, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Query", arg0, arg1) + ret0, _ := ret[0].(uint32) + ret1, _ := ret[1].([]*model.CircuitBreakerRule) + ret2, _ := ret[2].(error) + return ret0, ret1, ret2 +} + +// Query indicates an expected call of Query. +func (mr *MockCircuitBreakerCacheMockRecorder) Query(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Query", reflect.TypeOf((*MockCircuitBreakerCache)(nil).Query), arg0, arg1) +} + // Update mocks base method. func (m *MockCircuitBreakerCache) Update() error { m.ctrl.T.Helper() @@ -2503,10 +2652,10 @@ func (mr *MockUserCacheMockRecorder) Close() *gomock.Call { } // GetAdmin mocks base method. -func (m *MockUserCache) GetAdmin() *model.User { +func (m *MockUserCache) GetAdmin() *auth.User { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetAdmin") - ret0, _ := ret[0].(*model.User) + ret0, _ := ret[0].(*auth.User) return ret0 } @@ -2517,10 +2666,10 @@ func (mr *MockUserCacheMockRecorder) GetAdmin() *gomock.Call { } // GetGroup mocks base method. -func (m *MockUserCache) GetGroup(id string) *model.UserGroupDetail { +func (m *MockUserCache) GetGroup(id string) *auth.UserGroupDetail { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetGroup", id) - ret0, _ := ret[0].(*model.UserGroupDetail) + ret0, _ := ret[0].(*auth.UserGroupDetail) return ret0 } @@ -2531,10 +2680,10 @@ func (mr *MockUserCacheMockRecorder) GetGroup(id interface{}) *gomock.Call { } // GetUserByID mocks base method. -func (m *MockUserCache) GetUserByID(id string) *model.User { +func (m *MockUserCache) GetUserByID(id string) *auth.User { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetUserByID", id) - ret0, _ := ret[0].(*model.User) + ret0, _ := ret[0].(*auth.User) return ret0 } @@ -2545,10 +2694,10 @@ func (mr *MockUserCacheMockRecorder) GetUserByID(id interface{}) *gomock.Call { } // GetUserByName mocks base method. -func (m *MockUserCache) GetUserByName(name, ownerName string) *model.User { +func (m *MockUserCache) GetUserByName(name, ownerName string) *auth.User { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetUserByName", name, ownerName) - ret0, _ := ret[0].(*model.User) + ret0, _ := ret[0].(*auth.User) return ret0 } @@ -2628,6 +2777,38 @@ func (mr *MockUserCacheMockRecorder) Name() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Name", reflect.TypeOf((*MockUserCache)(nil).Name)) } +// QueryUserGroups mocks base method. +func (m *MockUserCache) QueryUserGroups(arg0 context.Context, arg1 api.UserGroupSearchArgs) (uint32, []*auth.UserGroupDetail, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "QueryUserGroups", arg0, arg1) + ret0, _ := ret[0].(uint32) + ret1, _ := ret[1].([]*auth.UserGroupDetail) + ret2, _ := ret[2].(error) + return ret0, ret1, ret2 +} + +// QueryUserGroups indicates an expected call of QueryUserGroups. +func (mr *MockUserCacheMockRecorder) QueryUserGroups(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryUserGroups", reflect.TypeOf((*MockUserCache)(nil).QueryUserGroups), arg0, arg1) +} + +// QueryUsers mocks base method. +func (m *MockUserCache) QueryUsers(arg0 context.Context, arg1 api.UserSearchArgs) (uint32, []*auth.User, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "QueryUsers", arg0, arg1) + ret0, _ := ret[0].(uint32) + ret1, _ := ret[1].([]*auth.User) + ret2, _ := ret[2].(error) + return ret0, ret1, ret2 +} + +// QueryUsers indicates an expected call of QueryUsers. +func (mr *MockUserCacheMockRecorder) QueryUsers(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryUsers", reflect.TypeOf((*MockUserCache)(nil).QueryUsers), arg0, arg1) +} + // Update mocks base method. func (m *MockUserCache) Update() error { m.ctrl.T.Helper() @@ -2693,92 +2874,187 @@ func (mr *MockStrategyCacheMockRecorder) Close() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockStrategyCache)(nil).Close)) } -// ForceSync mocks base method. -func (m *MockStrategyCache) ForceSync() error { +// GetPrincipalPolicies mocks base method. +func (m *MockStrategyCache) GetPrincipalPolicies(effect string, p auth.Principal) []*auth.StrategyDetail { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ForceSync") + ret := m.ctrl.Call(m, "GetPrincipalPolicies", effect, p) + ret0, _ := ret[0].([]*auth.StrategyDetail) + return ret0 +} + +// GetPrincipalPolicies indicates an expected call of GetPrincipalPolicies. +func (mr *MockStrategyCacheMockRecorder) GetPrincipalPolicies(effect, p interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPrincipalPolicies", reflect.TypeOf((*MockStrategyCache)(nil).GetPrincipalPolicies), effect, p) +} + +// Hint mocks base method. +func (m *MockStrategyCache) Hint(p auth.Principal, r *auth.ResourceEntry) security.AuthAction { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Hint", p, r) + ret0, _ := ret[0].(security.AuthAction) + return ret0 +} + +// Hint indicates an expected call of Hint. +func (mr *MockStrategyCacheMockRecorder) Hint(p, r interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Hint", reflect.TypeOf((*MockStrategyCache)(nil).Hint), p, r) +} + +// Initialize mocks base method. +func (m *MockStrategyCache) Initialize(c map[string]interface{}) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Initialize", c) ret0, _ := ret[0].(error) return ret0 } -// ForceSync indicates an expected call of ForceSync. -func (mr *MockStrategyCacheMockRecorder) ForceSync() *gomock.Call { +// Initialize indicates an expected call of Initialize. +func (mr *MockStrategyCacheMockRecorder) Initialize(c interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ForceSync", reflect.TypeOf((*MockStrategyCache)(nil).ForceSync)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Initialize", reflect.TypeOf((*MockStrategyCache)(nil).Initialize), c) } -// GetStrategyDetailsByGroupID mocks base method. -func (m *MockStrategyCache) GetStrategyDetailsByGroupID(groupId string) []*model.StrategyDetail { +// Name mocks base method. +func (m *MockStrategyCache) Name() string { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetStrategyDetailsByGroupID", groupId) - ret0, _ := ret[0].([]*model.StrategyDetail) + ret := m.ctrl.Call(m, "Name") + ret0, _ := ret[0].(string) return ret0 } -// GetStrategyDetailsByGroupID indicates an expected call of GetStrategyDetailsByGroupID. -func (mr *MockStrategyCacheMockRecorder) GetStrategyDetailsByGroupID(groupId interface{}) *gomock.Call { +// Name indicates an expected call of Name. +func (mr *MockStrategyCacheMockRecorder) Name() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Name", reflect.TypeOf((*MockStrategyCache)(nil).Name)) +} + +// Query mocks base method. +func (m *MockStrategyCache) Query(arg0 context.Context, arg1 api.PolicySearchArgs) (uint32, []*auth.StrategyDetail, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Query", arg0, arg1) + ret0, _ := ret[0].(uint32) + ret1, _ := ret[1].([]*auth.StrategyDetail) + ret2, _ := ret[2].(error) + return ret0, ret1, ret2 +} + +// Query indicates an expected call of Query. +func (mr *MockStrategyCacheMockRecorder) Query(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetStrategyDetailsByGroupID", reflect.TypeOf((*MockStrategyCache)(nil).GetStrategyDetailsByGroupID), groupId) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Query", reflect.TypeOf((*MockStrategyCache)(nil).Query), arg0, arg1) } -// GetStrategyDetailsByUID mocks base method. -func (m *MockStrategyCache) GetStrategyDetailsByUID(uid string) []*model.StrategyDetail { +// Update mocks base method. +func (m *MockStrategyCache) Update() error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetStrategyDetailsByUID", uid) - ret0, _ := ret[0].([]*model.StrategyDetail) + ret := m.ctrl.Call(m, "Update") + ret0, _ := ret[0].(error) return ret0 } -// GetStrategyDetailsByUID indicates an expected call of GetStrategyDetailsByUID. -func (mr *MockStrategyCacheMockRecorder) GetStrategyDetailsByUID(uid interface{}) *gomock.Call { +// Update indicates an expected call of Update. +func (mr *MockStrategyCacheMockRecorder) Update() *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetStrategyDetailsByUID", reflect.TypeOf((*MockStrategyCache)(nil).GetStrategyDetailsByUID), uid) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Update", reflect.TypeOf((*MockStrategyCache)(nil).Update)) } -// Initialize mocks base method. -func (m *MockStrategyCache) Initialize(c map[string]interface{}) error { +// MockRoleCache is a mock of RoleCache interface. +type MockRoleCache struct { + ctrl *gomock.Controller + recorder *MockRoleCacheMockRecorder +} + +// MockRoleCacheMockRecorder is the mock recorder for MockRoleCache. +type MockRoleCacheMockRecorder struct { + mock *MockRoleCache +} + +// NewMockRoleCache creates a new mock instance. +func NewMockRoleCache(ctrl *gomock.Controller) *MockRoleCache { + mock := &MockRoleCache{ctrl: ctrl} + mock.recorder = &MockRoleCacheMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockRoleCache) EXPECT() *MockRoleCacheMockRecorder { + return m.recorder +} + +// Clear mocks base method. +func (m *MockRoleCache) Clear() error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Initialize", c) + ret := m.ctrl.Call(m, "Clear") ret0, _ := ret[0].(error) return ret0 } -// Initialize indicates an expected call of Initialize. -func (mr *MockStrategyCacheMockRecorder) Initialize(c interface{}) *gomock.Call { +// Clear indicates an expected call of Clear. +func (mr *MockRoleCacheMockRecorder) Clear() *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Initialize", reflect.TypeOf((*MockStrategyCache)(nil).Initialize), c) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Clear", reflect.TypeOf((*MockRoleCache)(nil).Clear)) } -// IsResourceEditable mocks base method. -func (m *MockStrategyCache) IsResourceEditable(principal model.Principal, resType security.ResourceType, resId string) bool { +// Close mocks base method. +func (m *MockRoleCache) Close() error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "IsResourceEditable", principal, resType, resId) - ret0, _ := ret[0].(bool) + ret := m.ctrl.Call(m, "Close") + ret0, _ := ret[0].(error) return ret0 } -// IsResourceEditable indicates an expected call of IsResourceEditable. -func (mr *MockStrategyCacheMockRecorder) IsResourceEditable(principal, resType, resId interface{}) *gomock.Call { +// Close indicates an expected call of Close. +func (mr *MockRoleCacheMockRecorder) Close() *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsResourceEditable", reflect.TypeOf((*MockStrategyCache)(nil).IsResourceEditable), principal, resType, resId) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockRoleCache)(nil).Close)) } -// IsResourceLinkStrategy mocks base method. -func (m *MockStrategyCache) IsResourceLinkStrategy(resType security.ResourceType, resId string) bool { +// GetPrincipalRoles mocks base method. +func (m *MockRoleCache) GetPrincipalRoles(arg0 auth.Principal) []*auth.Role { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "IsResourceLinkStrategy", resType, resId) - ret0, _ := ret[0].(bool) + ret := m.ctrl.Call(m, "GetPrincipalRoles", arg0) + ret0, _ := ret[0].([]*auth.Role) + return ret0 +} + +// GetPrincipalRoles indicates an expected call of GetPrincipalRoles. +func (mr *MockRoleCacheMockRecorder) GetPrincipalRoles(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPrincipalRoles", reflect.TypeOf((*MockRoleCache)(nil).GetPrincipalRoles), arg0) +} + +// GetRole mocks base method. +func (m *MockRoleCache) GetRole(id string) *auth.Role { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetRole", id) + ret0, _ := ret[0].(*auth.Role) + return ret0 +} + +// GetRole indicates an expected call of GetRole. +func (mr *MockRoleCacheMockRecorder) GetRole(id interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRole", reflect.TypeOf((*MockRoleCache)(nil).GetRole), id) +} + +// Initialize mocks base method. +func (m *MockRoleCache) Initialize(c map[string]interface{}) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Initialize", c) + ret0, _ := ret[0].(error) return ret0 } -// IsResourceLinkStrategy indicates an expected call of IsResourceLinkStrategy. -func (mr *MockStrategyCacheMockRecorder) IsResourceLinkStrategy(resType, resId interface{}) *gomock.Call { +// Initialize indicates an expected call of Initialize. +func (mr *MockRoleCacheMockRecorder) Initialize(c interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsResourceLinkStrategy", reflect.TypeOf((*MockStrategyCache)(nil).IsResourceLinkStrategy), resType, resId) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Initialize", reflect.TypeOf((*MockRoleCache)(nil).Initialize), c) } // Name mocks base method. -func (m *MockStrategyCache) Name() string { +func (m *MockRoleCache) Name() string { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Name") ret0, _ := ret[0].(string) @@ -2786,13 +3062,29 @@ func (m *MockStrategyCache) Name() string { } // Name indicates an expected call of Name. -func (mr *MockStrategyCacheMockRecorder) Name() *gomock.Call { +func (mr *MockRoleCacheMockRecorder) Name() *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Name", reflect.TypeOf((*MockStrategyCache)(nil).Name)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Name", reflect.TypeOf((*MockRoleCache)(nil).Name)) +} + +// Query mocks base method. +func (m *MockRoleCache) Query(arg0 context.Context, arg1 api.RoleSearchArgs) (uint32, []*auth.Role, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Query", arg0, arg1) + ret0, _ := ret[0].(uint32) + ret1, _ := ret[1].([]*auth.Role) + ret2, _ := ret[2].(error) + return ret0, ret1, ret2 +} + +// Query indicates an expected call of Query. +func (mr *MockRoleCacheMockRecorder) Query(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Query", reflect.TypeOf((*MockRoleCache)(nil).Query), arg0, arg1) } // Update mocks base method. -func (m *MockStrategyCache) Update() error { +func (m *MockRoleCache) Update() error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Update") ret0, _ := ret[0].(error) @@ -2800,9 +3092,9 @@ func (m *MockStrategyCache) Update() error { } // Update indicates an expected call of Update. -func (mr *MockStrategyCacheMockRecorder) Update() *gomock.Call { +func (mr *MockRoleCacheMockRecorder) Update() *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Update", reflect.TypeOf((*MockStrategyCache)(nil).Update)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Update", reflect.TypeOf((*MockRoleCache)(nil).Update)) } // MockClientCache is a mock of ClientCache interface. diff --git a/cache/service/circuitbreaker.go b/cache/service/circuitbreaker.go index 3cfe8a4cf..ff02db390 100644 --- a/cache/service/circuitbreaker.go +++ b/cache/service/circuitbreaker.go @@ -18,6 +18,7 @@ package service import ( + "context" "crypto/sha1" "fmt" "sort" @@ -390,3 +391,14 @@ func (c *circuitBreakerCache) GetCircuitBreakerCount() int { } return len(names) } + +// Query implements api.CircuitBreakerCache. +func (c *circuitBreakerCache) Query(context.Context, *types.CircuitBreakerRuleArgs) (uint32, []*model.CircuitBreakerRule, error) { + panic("unimplemented") +} + +// GetRule implements api.FaultDetectCache. +func (f *circuitBreakerCache) GetRule(id string) *model.CircuitBreakerRule { + rule, _ := f.rules.Load(id) + return rule +} diff --git a/cache/service/default.go b/cache/service/default.go index 678753326..4a3d36ee3 100644 --- a/cache/service/default.go +++ b/cache/service/default.go @@ -24,7 +24,7 @@ import ( var ( _ types.InstanceCache = (*instanceCache)(nil) _ types.ServiceCache = (*serviceCache)(nil) - _ types.RoutingConfigCache = (*routingConfigCache)(nil) + _ types.RoutingConfigCache = (*RouteRuleCache)(nil) _ types.CircuitBreakerCache = (*circuitBreakerCache)(nil) _ types.RateLimitCache = (*rateLimitCache)(nil) _ types.FaultDetectCache = (*faultDetectCache)(nil) diff --git a/cache/service/faultdetect.go b/cache/service/faultdetect.go index 331336b17..92a88d9b7 100644 --- a/cache/service/faultdetect.go +++ b/cache/service/faultdetect.go @@ -18,6 +18,7 @@ package service import ( + "context" "crypto/sha1" "fmt" "sort" @@ -348,3 +349,14 @@ func (f *faultDetectCache) GetFaultDetectRuleCount(fun func(k, v interface{}) bo } } } + +// Query implements api.FaultDetectCache. +func (f *faultDetectCache) Query(context.Context, *types.FaultDetectArgs) (uint32, []*model.FaultDetectRule, error) { + panic("unimplemented") +} + +// GetRule implements api.FaultDetectCache. +func (f *faultDetectCache) GetRule(id string) *model.FaultDetectRule { + rule, _ := f.rules.Load(id) + return rule +} diff --git a/cache/service/lane.go b/cache/service/lane.go index 96cb76a32..7d5b31973 100644 --- a/cache/service/lane.go +++ b/cache/service/lane.go @@ -18,6 +18,7 @@ package service import ( + "context" "time" "github.com/golang/protobuf/proto" @@ -43,8 +44,8 @@ type LaneCache struct { *types.BaseCache // single . single singleflight.Group - // groups name -> *model.LaneGroupProto - groups *utils.SyncMap[string, *model.LaneGroupProto] + // groups id -> *model.LaneGroupProto + rules *utils.SyncMap[string, *model.LaneGroupProto] // serviceRules namespace -> service -> []*model.LaneRuleProto serviceRules *utils.SyncMap[string, *utils.SyncMap[string, *utils.SyncMap[string, *model.LaneGroupProto]]] // revisions namespace -> service -> revision @@ -55,7 +56,7 @@ type LaneCache struct { func (lc *LaneCache) Initialize(c map[string]interface{}) error { lc.serviceRules = utils.NewSyncMap[string, *utils.SyncMap[string, *utils.SyncMap[string, *model.LaneGroupProto]]]() lc.revisions = utils.NewSyncMap[string, *utils.SyncMap[string, string]]() - lc.groups = utils.NewSyncMap[string, *model.LaneGroupProto]() + lc.rules = utils.NewSyncMap[string, *model.LaneGroupProto]() lc.single = singleflight.Group{} return nil } @@ -113,10 +114,10 @@ func (lc *LaneCache) setLaneRules(items map[string]*model.LaneGroup) (time.Time, lastMtime = item.ModifyTime.Unix() } - oldVal, exist := lc.groups.Load(item.ID) + oldVal, exist := lc.rules.Load(item.ID) if !item.Valid { del++ - _, _ = lc.groups.Delete(item.ID) + _, _ = lc.rules.Delete(item.ID) if exist { lc.processLaneRuleDelete(oldVal, affectSvcs) } @@ -127,7 +128,7 @@ func (lc *LaneCache) setLaneRules(items map[string]*model.LaneGroup) (time.Time, } else { add++ } - lc.groups.Store(item.ID, saveVal) + lc.rules.Store(item.ID, saveVal) lc.processLaneRuleUpsert(oldVal, saveVal, affectSvcs) } lc.postUpdateRevisions(affectSvcs) @@ -337,7 +338,7 @@ func (lc *LaneCache) LastMtime() time.Time { // Clear . func (lc *LaneCache) Clear() error { lc.revisions = utils.NewSyncMap[string, *utils.SyncMap[string, string]]() - lc.groups = utils.NewSyncMap[string, *model.LaneGroupProto]() + lc.rules = utils.NewSyncMap[string, *model.LaneGroupProto]() lc.serviceRules = utils.NewSyncMap[string, *utils.SyncMap[string, *utils.SyncMap[string, *model.LaneGroupProto]]]() return nil } @@ -354,3 +355,14 @@ func anyToSelector(data *anypb.Any, msg proto.Message) error { } return nil } + +// Query implements api.LaneCache. +func (lc *LaneCache) Query(context.Context, *types.LaneGroupArgs) (uint32, []*model.LaneGroupProto, error) { + panic("unimplemented") +} + +// GetRule implements api.LaneCache. +func (f *LaneCache) GetRule(id string) *model.LaneGroup { + rule, _ := f.rules.Load(id) + return rule.LaneGroup +} diff --git a/cache/service/ratelimit_bucket.go b/cache/service/ratelimit_bucket.go index 98c4d6e0c..2c4f5c9f5 100644 --- a/cache/service/ratelimit_bucket.go +++ b/cache/service/ratelimit_bucket.go @@ -30,33 +30,36 @@ import ( "github.com/polarismesh/polaris/common/utils" ) -func newRateLimitRuleBucket() *rateLimitRuleBucket { - return &rateLimitRuleBucket{ +func newRateLimitRuleBucket() *RateLimitRuleContainer { + return &RateLimitRuleContainer{ ids: utils.NewSyncMap[string, *model.RateLimit](), rules: utils.NewSyncMap[string, *subRateLimitRuleBucket](), } } -type rateLimitRuleBucket struct { +type RateLimitRuleContainer struct { ids *utils.SyncMap[string, *model.RateLimit] rules *utils.SyncMap[string, *subRateLimitRuleBucket] } -func (r *rateLimitRuleBucket) foreach(proc types.RateLimitIterProc) { +func (r *RateLimitRuleContainer) foreach(proc types.RateLimitIterProc) { r.rules.Range(func(key string, val *subRateLimitRuleBucket) { val.foreach(proc) }) } -func (r *rateLimitRuleBucket) count() int { +func (r *RateLimitRuleContainer) count() int { return r.ids.Len() } -func (r *rateLimitRuleBucket) saveRule(rule *model.RateLimit) { +func (r *RateLimitRuleContainer) saveRule(rule *model.RateLimit) { r.cleanOldSvcRule(rule) r.ids.Store(rule.ID, rule) - key := buildServiceKey(rule.Proto.GetNamespace().GetValue(), rule.Proto.GetService().GetValue()) + key := (&model.ServiceKey{ + Namespace: rule.Proto.GetNamespace().GetValue(), + Name: rule.Proto.GetService().GetValue(), + }).Domain() if _, ok := r.rules.Load(key); !ok { r.rules.Store(key, &subRateLimitRuleBucket{ @@ -69,13 +72,17 @@ func (r *rateLimitRuleBucket) saveRule(rule *model.RateLimit) { } // cleanOldSvcRule 清理规则之前绑定的服务数据信息 -func (r *rateLimitRuleBucket) cleanOldSvcRule(rule *model.RateLimit) { +func (r *RateLimitRuleContainer) cleanOldSvcRule(rule *model.RateLimit) { oldRule, ok := r.ids.Load(rule.ID) if !ok { return } + // 清理原来老记录的绑定数据信息 - key := buildServiceKey(oldRule.Proto.GetNamespace().GetValue(), oldRule.Proto.GetService().GetValue()) + key := (&model.ServiceKey{ + Namespace: oldRule.Proto.GetNamespace().GetValue(), + Name: oldRule.Proto.GetService().GetValue(), + }).Domain() bucket, ok := r.rules.Load(key) if !ok { return @@ -87,11 +94,14 @@ func (r *rateLimitRuleBucket) cleanOldSvcRule(rule *model.RateLimit) { } } -func (r *rateLimitRuleBucket) delRule(rule *model.RateLimit) { +func (r *RateLimitRuleContainer) delRule(rule *model.RateLimit) { r.cleanOldSvcRule(rule) r.ids.Delete(rule.ID) - key := buildServiceKey(rule.Proto.GetNamespace().GetValue(), rule.Proto.GetService().GetValue()) + key := (&model.ServiceKey{ + Namespace: rule.Proto.GetNamespace().GetValue(), + Name: rule.Proto.GetService().GetValue(), + }).Domain() if _, ok := r.rules.Load(key); !ok { return } @@ -103,13 +113,13 @@ func (r *rateLimitRuleBucket) delRule(rule *model.RateLimit) { } } -func (r *rateLimitRuleBucket) getRuleByID(id string) *model.RateLimit { +func (r *RateLimitRuleContainer) getRuleByID(id string) *model.RateLimit { ret, _ := r.ids.Load(id) return ret } -func (r *rateLimitRuleBucket) getRules(serviceKey model.ServiceKey) ([]*model.RateLimit, string) { - key := buildServiceKey(serviceKey.Namespace, serviceKey.Name) +func (r *RateLimitRuleContainer) getRules(serviceKey model.ServiceKey) ([]*model.RateLimit, string) { + key := (&serviceKey).Domain() if _, ok := r.rules.Load(key); !ok { return nil, "" } @@ -118,8 +128,8 @@ func (r *rateLimitRuleBucket) getRules(serviceKey model.ServiceKey) ([]*model.Ra return b.toSlice(), b.revision } -func (r *rateLimitRuleBucket) reloadRevision(serviceKey model.ServiceKey) { - key := buildServiceKey(serviceKey.Namespace, serviceKey.Name) +func (r *RateLimitRuleContainer) reloadRevision(serviceKey model.ServiceKey) { + key := serviceKey.Domain() v, ok := r.rules.Load(key) if !ok { return diff --git a/cache/service/ratelimit_config.go b/cache/service/ratelimit_config.go index 375f489b3..dbae4b4f2 100644 --- a/cache/service/ratelimit_config.go +++ b/cache/service/ratelimit_config.go @@ -40,7 +40,7 @@ type rateLimitCache struct { waitFixRules map[string]struct{} svcCache types.ServiceCache storage store.Store - rules *rateLimitRuleBucket + rules *RateLimitRuleContainer singleFlight singleflight.Group } @@ -194,8 +194,8 @@ func (rlc *rateLimitCache) fixRulesServiceInfo() { if svc != nil { rule.Proto.Namespace = utils.NewStringValue(svc.Namespace) rule.Proto.Name = utils.NewStringValue(svc.Name) + delete(rlc.waitFixRules, rule.ID) } - delete(rlc.waitFixRules, rule.ID) } } @@ -224,3 +224,8 @@ func (rlc *rateLimitCache) fixRuleServiceInfo(rateLimit *model.RateLimit) { } delete(rlc.waitFixRules, rateLimit.ID) } + +// GetRule implements api.RateLimitCache. +func (rlc *rateLimitCache) GetRule(id string) *model.RateLimit { + return rlc.rules.getRuleByID(id) +} diff --git a/cache/service/ratelimit_config_test.go b/cache/service/ratelimit_config_test.go index bc0b62a9a..61efdc6d1 100644 --- a/cache/service/ratelimit_config_test.go +++ b/cache/service/ratelimit_config_test.go @@ -18,6 +18,7 @@ package service import ( + "context" "encoding/json" "fmt" "testing" @@ -401,7 +402,7 @@ func Test_QueryRateLimitRules(t *testing.T) { } t.Run("根据ID进行查询", func(t *testing.T) { - total, ret, err := rlc.QueryRateLimitRules(types.RateLimitRuleArgs{ + total, ret, err := rlc.QueryRateLimitRules(context.TODO(), types.RateLimitRuleArgs{ ID: rateLimits[0].ID, Offset: 0, Limit: 100, @@ -414,7 +415,7 @@ func Test_QueryRateLimitRules(t *testing.T) { }) t.Run("根据Name进行查询", func(t *testing.T) { - total, ret, err := rlc.QueryRateLimitRules(types.RateLimitRuleArgs{ + total, ret, err := rlc.QueryRateLimitRules(context.TODO(), types.RateLimitRuleArgs{ Name: rateLimits[0].Name, Offset: 0, Limit: 100, @@ -427,7 +428,7 @@ func Test_QueryRateLimitRules(t *testing.T) { }) t.Run("根据Namespace&Service进行查询", func(t *testing.T) { - total, ret, err := rlc.QueryRateLimitRules(types.RateLimitRuleArgs{ + total, ret, err := rlc.QueryRateLimitRules(context.TODO(), types.RateLimitRuleArgs{ Service: "service-0", Namespace: "default", Offset: 0, @@ -444,7 +445,7 @@ func Test_QueryRateLimitRules(t *testing.T) { }) t.Run("根据分页进行查询", func(t *testing.T) { - total, ret, err := rlc.QueryRateLimitRules(types.RateLimitRuleArgs{ + total, ret, err := rlc.QueryRateLimitRules(context.TODO(), types.RateLimitRuleArgs{ Offset: 10, Limit: 5, }) @@ -453,7 +454,7 @@ func Test_QueryRateLimitRules(t *testing.T) { assert.Equal(t, int64(total), int64(len(rateLimits))) assert.Equal(t, int64(5), int64(len(ret))) - total, ret, err = rlc.QueryRateLimitRules(types.RateLimitRuleArgs{ + total, ret, err = rlc.QueryRateLimitRules(context.TODO(), types.RateLimitRuleArgs{ Offset: 100, Limit: 5, }) @@ -465,7 +466,7 @@ func Test_QueryRateLimitRules(t *testing.T) { t.Run("根据Disable进行查询", func(t *testing.T) { disable := true - total, ret, err := rlc.QueryRateLimitRules(types.RateLimitRuleArgs{ + total, ret, err := rlc.QueryRateLimitRules(context.TODO(), types.RateLimitRuleArgs{ Disable: &disable, Offset: 0, Limit: 100, @@ -476,7 +477,7 @@ func Test_QueryRateLimitRules(t *testing.T) { assert.Equal(t, int64(0), int64(len(ret))) disable = false - total, ret, err = rlc.QueryRateLimitRules(types.RateLimitRuleArgs{ + total, ret, err = rlc.QueryRateLimitRules(context.TODO(), types.RateLimitRuleArgs{ Disable: &disable, Offset: 0, Limit: 100, diff --git a/cache/service/ratelimit_query.go b/cache/service/ratelimit_query.go index aa03b28be..53f643073 100644 --- a/cache/service/ratelimit_query.go +++ b/cache/service/ratelimit_query.go @@ -18,6 +18,7 @@ package service import ( + "context" "sort" "strings" @@ -35,7 +36,7 @@ func (rlc *rateLimitCache) forceUpdate() error { } // QueryRateLimitRules -func (rlc *rateLimitCache) QueryRateLimitRules(args types.RateLimitRuleArgs) (uint32, []*model.RateLimit, error) { +func (rlc *rateLimitCache) QueryRateLimitRules(ctx context.Context, args types.RateLimitRuleArgs) (uint32, []*model.RateLimit, error) { if err := rlc.forceUpdate(); err != nil { return 0, nil, err } diff --git a/cache/service/router_rule.go b/cache/service/router_rule.go index 2d52d62f9..1529494bf 100644 --- a/cache/service/router_rule.go +++ b/cache/service/router_rule.go @@ -19,8 +19,6 @@ package service import ( "fmt" - "sort" - "sync" "time" apitraffic "github.com/polarismesh/specification/source/go/api/v1/traffic_manage" @@ -34,46 +32,45 @@ import ( ) type ( - // routingConfigCache Routing rules cache - routingConfigCache struct { + // RouteRuleCache Routing rules cache + RouteRuleCache struct { *types.BaseCache serviceCache types.ServiceCache storage store.Store - bucket *routeRuleBucket + container *RouteRuleContainer lastMtimeV1 time.Time lastMtimeV2 time.Time singleFlight singleflight.Group - // pendingV1RuleIds Records need to be converted from V1 to V2 routing rules ID - plock sync.Mutex - pendingV1RuleIds map[string]*model.RoutingConfig + // waitDealV1RuleIds Records need to be converted from V1 to V2 routing rules ID + waitDealV1RuleIds *utils.SyncMap[string, *model.RoutingConfig] } ) -// NewRoutingConfigCache Return a object of operating RoutingConfigcache -func NewRoutingConfigCache(s store.Store, cacheMgr types.CacheManager) types.RoutingConfigCache { - return &routingConfigCache{ +// NewRouteRuleCache Return a object of operating RouteRuleCache +func NewRouteRuleCache(s store.Store, cacheMgr types.CacheManager) types.RoutingConfigCache { + return &RouteRuleCache{ BaseCache: types.NewBaseCache(s, cacheMgr), storage: s, } } // initialize The function of implementing the cache interface -func (rc *routingConfigCache) Initialize(_ map[string]interface{}) error { +func (rc *RouteRuleCache) Initialize(_ map[string]interface{}) error { rc.lastMtimeV1 = time.Unix(0, 0) rc.lastMtimeV2 = time.Unix(0, 0) - rc.pendingV1RuleIds = make(map[string]*model.RoutingConfig) - rc.bucket = newRouteRuleBucket() + rc.waitDealV1RuleIds = utils.NewSyncMap[string, *model.RoutingConfig]() + rc.container = newRouteRuleContainer() rc.serviceCache = rc.BaseCache.CacheMgr.GetCacher(types.CacheService).(*serviceCache) return nil } // Update The function of implementing the cache interface -func (rc *routingConfigCache) Update() error { +func (rc *RouteRuleCache) Update() error { // Multiple thread competition, only one thread is updated _, err, _ := rc.singleFlight.Do(rc.Name(), func() (interface{}, error) { return nil, rc.DoCacheUpdate(rc.Name(), rc.realUpdate) @@ -82,7 +79,7 @@ func (rc *routingConfigCache) Update() error { } // update The function of implementing the cache interface -func (rc *routingConfigCache) realUpdate() (map[string]time.Time, int64, error) { +func (rc *RouteRuleCache) realUpdate() (map[string]time.Time, int64, error) { outV1, err := rc.storage.GetRoutingConfigsForCache(rc.LastFetchTime(), rc.IsFirstUpdate()) if err != nil { log.Errorf("[Cache] routing config v1 cache get from store err: %s", err.Error()) @@ -98,53 +95,49 @@ func (rc *routingConfigCache) realUpdate() (map[string]time.Time, int64, error) lastMtimes := map[string]time.Time{} rc.setRoutingConfigV1(lastMtimes, outV1) rc.setRoutingConfigV2(lastMtimes, outV2) + rc.container.reload() return lastMtimes, int64(len(outV1) + len(outV2)), err } // Clear The function of implementing the cache interface -func (rc *routingConfigCache) Clear() error { +func (rc *RouteRuleCache) Clear() error { rc.BaseCache.Clear() - rc.pendingV1RuleIds = make(map[string]*model.RoutingConfig) - rc.bucket = newRouteRuleBucket() + rc.waitDealV1RuleIds = utils.NewSyncMap[string, *model.RoutingConfig]() + rc.container = newRouteRuleContainer() rc.lastMtimeV1 = time.Unix(0, 0) rc.lastMtimeV2 = time.Unix(0, 0) return nil } // Name The function of implementing the cache interface -func (rc *routingConfigCache) Name() string { +func (rc *RouteRuleCache) Name() string { return types.RoutingConfigName } -func (rc *routingConfigCache) ListRouterRule(service, namespace string) []*model.ExtendRouterConfig { - routerRules := rc.bucket.listEnableRules(service, namespace, true) +func (rc *RouteRuleCache) ListRouterRule(service, namespace string) []*model.ExtendRouterConfig { + routerRules := rc.container.SearchCustomRules(service, namespace) ret := make([]*model.ExtendRouterConfig, 0, len(routerRules)) - for level := range routerRules { - items := routerRules[level] - ret = append(ret, items...) - } + ret = append(ret, routerRules...) return ret } // GetRouterConfigV2 Obtain routing configuration based on serviceid -func (rc *routingConfigCache) GetRouterConfigV2(id, service, namespace string) (*apitraffic.Routing, error) { +func (rc *RouteRuleCache) GetRouterConfigV2(id, service, namespace string) (*apitraffic.Routing, error) { if id == "" && service == "" && namespace == "" { return nil, nil } - routerRules := rc.bucket.listEnableRules(service, namespace, true) - revisions := make([]string, 0, 8) + routerRules := rc.container.SearchCustomRules(service, namespace) + revisions := make([]string, 0, len(routerRules)) rulesV2 := make([]*apitraffic.RouteRule, 0, len(routerRules)) - for level := range routerRules { - items := routerRules[level] - for i := range items { - entry, err := items[i].ToApi() - if err != nil { - return nil, err - } - rulesV2 = append(rulesV2, entry) - revisions = append(revisions, entry.GetRevision()) + for i := range routerRules { + item := routerRules[i] + entry, err := item.ToApi() + if err != nil { + return nil, err } + rulesV2 = append(rulesV2, entry) + revisions = append(revisions, entry.GetRevision()) } revision, err := types.CompositeComputeRevision(revisions) if err != nil { @@ -162,61 +155,85 @@ func (rc *routingConfigCache) GetRouterConfigV2(id, service, namespace string) ( } // GetRouterConfig Obtain routing configuration based on serviceid -func (rc *routingConfigCache) GetRouterConfig(id, service, namespace string) (*apitraffic.Routing, error) { - if id == "" && service == "" && namespace == "" { +func (rc *RouteRuleCache) GetRouterConfig(id, svcName, namespace string) (*apitraffic.Routing, error) { + if id == "" && svcName == "" && namespace == "" { return nil, nil } - routerRules := rc.bucket.listEnableRules(service, namespace, false) - inBounds, outBounds, revisions := rc.convertV2toV1(routerRules, service, namespace) + key := model.ServiceKey{Namespace: namespace, Name: svcName} + + revisions := []string{} + inRule, inRevision := rc.container.customContainers[model.TrafficDirection_INBOUND].SearchCustomRuleV1(key) + revisions = append(revisions, inRevision...) + outRule, outRevision := rc.container.customContainers[model.TrafficDirection_OUTBOUND].SearchCustomRuleV1(key) + revisions = append(revisions, outRevision...) + revision, err := types.CompositeComputeRevision(revisions) if err != nil { log.Warn("[Cache][Routing] v2=>v1 compute revisions fail, use fake revision", zap.Error(err)) revision = utils.NewV2Revision() } - resp := &apitraffic.Routing{ + return &apitraffic.Routing{ Namespace: utils.NewStringValue(namespace), - Service: utils.NewStringValue(service), - Inbounds: inBounds, - Outbounds: outBounds, + Service: utils.NewStringValue(svcName), + Inbounds: inRule.Inbounds, + Outbounds: outRule.Outbounds, Revision: utils.NewStringValue(revision), - } - - return formatRoutingResponseV1(resp), nil + }, nil } -// formatRoutingResponseV1 Give the client's cache, no need to expose EXTENDINFO information data -func formatRoutingResponseV1(ret *apitraffic.Routing) *apitraffic.Routing { - inBounds := ret.Inbounds - outBounds := ret.Outbounds +// GetNearbyRouteRule 根据服务名查询就近路由数据 +func (rc *RouteRuleCache) GetNearbyRouteRule(service, namespace string) ([]*apitraffic.RouteRule, string, error) { + if service == "" && namespace == "" { + return nil, "", nil + } - for i := range inBounds { - inBounds[i].ExtendInfo = nil + svcKey := model.ServiceKey{ + Namespace: namespace, + Name: service, } - for i := range outBounds { - outBounds[i].ExtendInfo = nil + routerRules := rc.container.nearbyContainers.SearchRouteRuleV2(svcKey) + revisions := make([]string, 0, len(routerRules)) + ret := make([]*apitraffic.RouteRule, 0, len(routerRules)) + for i := range routerRules { + item := routerRules[i] + entry, err := item.ToApi() + if err != nil { + return nil, "", err + } + ret = append(ret, entry) + revisions = append(revisions, entry.GetRevision()) } - return ret + revision, err := types.CompositeComputeRevision(revisions) + if err != nil { + log.Warn("[Cache][Routing] v2=>v1 compute revisions fail, use fake revision", zap.Error(err)) + revision = utils.NewV2Revision() + } + + return ret, revision, nil } // IteratorRouterRule -func (rc *routingConfigCache) IteratorRouterRule(iterProc types.RouterRuleIterProc) { +func (rc *RouteRuleCache) IteratorRouterRule(iterProc types.RouterRuleIterProc) { // need to traverse the Routing cache bucket of V2 here - rc.bucket.foreach(iterProc) + rc.container.foreach(iterProc) } // GetRoutingConfigCount Get the total number of routing configuration cache -func (rc *routingConfigCache) GetRoutingConfigCount() int { - return rc.bucket.size() +func (rc *RouteRuleCache) GetRoutingConfigCount() int { + return rc.container.size() } -// setRoutingConfigV1 Update the data of the store to the cache and convert to v2 model -func (rc *routingConfigCache) setRoutingConfigV1(lastMtimes map[string]time.Time, cs []*model.RoutingConfig) { - rc.plock.Lock() - defer rc.plock.Unlock() +// GetRule implements api.RoutingConfigCache. +func (rc *RouteRuleCache) GetRule(id string) *model.ExtendRouterConfig { + rule, _ := rc.container.rules.Load(id) + return rule +} +// setRoutingConfigV1 Update the data of the store to the cache and convert to v2 model +func (rc *RouteRuleCache) setRoutingConfigV1(lastMtimes map[string]time.Time, cs []*model.RoutingConfig) { if len(cs) == 0 { return } @@ -230,38 +247,35 @@ func (rc *routingConfigCache) setRoutingConfigV1(lastMtimes map[string]time.Time } if !entry.Valid { // Delete the cache converted to V2 - rc.bucket.deleteV1(entry.ID) + rc.container.deleteV1(entry.ID) continue } - rc.pendingV1RuleIds[entry.ID] = entry + rc.waitDealV1RuleIds.Store(entry.ID, entry) } - for id := range rc.pendingV1RuleIds { - entry := rc.pendingV1RuleIds[id] + rc.waitDealV1RuleIds.Range(func(key string, val *model.RoutingConfig) { // Save to the new V2 cache - ok, v2rule, err := rc.convertV1toV2(entry) + ok, rules, err := rc.convertV1toV2(val) if err != nil { log.Warn("[Cache] routing parse v1 => v2 fail, will try again next", - zap.String("rule-id", entry.ID), zap.Error(err)) - continue + zap.String("rule-id", val.ID), zap.Error(err)) + return } if !ok { - log.Warn("[Cache] routing parse v1 => v2 is nil, will try again next", - zap.String("rule-id", entry.ID)) - continue + log.Warn("[Cache] routing parse v1 => v2 is nil, will try again next", zap.String("rule-id", val.ID)) + return } - if ok && v2rule != nil { - delete(rc.pendingV1RuleIds, id) - rc.bucket.saveV1(entry, v2rule) + if ok && len(rules) != 0 { + rc.waitDealV1RuleIds.Delete(key) + rc.container.saveV1(val, rules) } - } - + }) lastMtimes[rc.Name()] = time.Unix(lastMtimeV1, 0) - log.Infof("[Cache] convert routing parse v1 => v2 count : %d", rc.bucket.convertV2Size()) + log.Infof("[Cache] convert routing parse v1 => v2 count : %d", rc.container.convertV2Size()) } // setRoutingConfigV2 Store V2 Router Caches -func (rc *routingConfigCache) setRoutingConfigV2(lastMtimes map[string]time.Time, cs []*model.RouterConfig) { +func (rc *RouteRuleCache) setRoutingConfigV2(lastMtimes map[string]time.Time, cs []*model.RouterConfig) { if len(cs) == 0 { return } @@ -275,7 +289,7 @@ func (rc *routingConfigCache) setRoutingConfigV2(lastMtimes map[string]time.Time lastMtimeV2 = entry.ModifyTime.Unix() } if !entry.Valid { - rc.bucket.deleteV2(entry.ID) + rc.container.deleteV2(entry.ID) continue } extendEntry, err := entry.ToExpendRoutingConfig() @@ -283,17 +297,17 @@ func (rc *routingConfigCache) setRoutingConfigV2(lastMtimes map[string]time.Time log.Error("[Cache] routing config v2 convert to expend", zap.Error(err)) continue } - rc.bucket.saveV2(extendEntry) + rc.container.saveV2(extendEntry) } lastMtimes[rc.Name()+"v2"] = time.Unix(lastMtimeV2, 0) } -func (rc *routingConfigCache) IsConvertFromV1(id string) (string, bool) { - val, ok := rc.bucket.v1rulesToOld[id] +func (rc *RouteRuleCache) IsConvertFromV1(id string) (string, bool) { + val, ok := rc.container.v1rulesToOld[id] return val, ok } -func (rc *routingConfigCache) convertV1toV2(rule *model.RoutingConfig) (bool, []*model.ExtendRouterConfig, error) { +func (rc *RouteRuleCache) convertV1toV2(rule *model.RoutingConfig) (bool, []*model.ExtendRouterConfig, error) { svc := rc.serviceCache.GetServiceByID(rule.ID) if svc == nil { s, err := rc.storage.GetServiceByID(rule.ID) @@ -320,44 +334,3 @@ func (rc *routingConfigCache) convertV1toV2(rule *model.RoutingConfig) (bool, [] return true, ret, nil } - -// convertV2toV1 The routing rules of the V2 version are converted to V1 version to return to the client, -// which is used to compatible with SDK issuance configuration. -func (rc *routingConfigCache) convertV2toV1(entries map[routingLevel][]*model.ExtendRouterConfig, - service, namespace string) ([]*apitraffic.Route, []*apitraffic.Route, []string) { - level1 := entries[level1RoutingV2] - sort.Slice(level1, func(i, j int) bool { - return model.CompareRoutingV2(level1[i], level1[j]) - }) - - level2 := entries[level2RoutingV2] - sort.Slice(level2, func(i, j int) bool { - return model.CompareRoutingV2(level2[i], level2[j]) - }) - - level3 := entries[level3RoutingV2] - sort.Slice(level3, func(i, j int) bool { - return model.CompareRoutingV2(level3[i], level3[j]) - }) - - level1inRoutes, level1outRoutes, level1Revisions := model.BuildV1RoutesFromV2(service, namespace, level1) - level2inRoutes, level2outRoutes, level2Revisions := model.BuildV1RoutesFromV2(service, namespace, level2) - level3inRoutes, level3outRoutes, level3Revisions := model.BuildV1RoutesFromV2(service, namespace, level3) - - revisions := make([]string, 0, len(level1Revisions)+len(level2Revisions)+len(level3Revisions)) - revisions = append(revisions, level1Revisions...) - revisions = append(revisions, level2Revisions...) - revisions = append(revisions, level3Revisions...) - - inRoutes := make([]*apitraffic.Route, 0, len(level1inRoutes)+len(level2inRoutes)+len(level3inRoutes)) - inRoutes = append(inRoutes, level1inRoutes...) - inRoutes = append(inRoutes, level2inRoutes...) - inRoutes = append(inRoutes, level3inRoutes...) - - outRoutes := make([]*apitraffic.Route, 0, len(level1outRoutes)+len(level2outRoutes)+len(level3outRoutes)) - outRoutes = append(outRoutes, level1outRoutes...) - outRoutes = append(outRoutes, level2outRoutes...) - outRoutes = append(outRoutes, level3outRoutes...) - - return inRoutes, outRoutes, revisions -} diff --git a/cache/service/router_rule_bucket.go b/cache/service/router_rule_bucket.go index 7f5f19376..5b32ea53f 100644 --- a/cache/service/router_rule_bucket.go +++ b/cache/service/router_rule_bucket.go @@ -18,186 +18,342 @@ package service import ( - "fmt" + "sort" "sync" apitraffic "github.com/polarismesh/specification/source/go/api/v1/traffic_manage" types "github.com/polarismesh/polaris/cache/api" "github.com/polarismesh/polaris/common/model" + "github.com/polarismesh/polaris/common/utils" ) -type ( - routingLevel int16 - boundType int16 - - serviceInfo interface { - GetNamespace() string - GetService() string - } -) - -const ( - _ routingLevel = iota - level1RoutingV2 - level2RoutingV2 - level3RoutingV2 - - _ boundType = iota - inBound - outBound -) - -func newRouteRuleBucket() *routeRuleBucket { - return &routeRuleBucket{ - rules: make(map[string]*model.ExtendRouterConfig), - level1Rules: map[string]map[string]struct{}{}, - level2Rules: map[boundType]map[string]map[string]struct{}{ - inBound: {}, - outBound: {}, - }, - level3Rules: map[boundType]map[string]struct{}{ - inBound: {}, - outBound: {}, - }, - v1rules: map[string][]*model.ExtendRouterConfig{}, - v1rulesToOld: map[string]string{}, - } -} - -// ServiceWithCircuitBreakerRules 与服务关系绑定的熔断规则 +// ServiceWithRouterRules 与服务绑定的路由规则数据 type ServiceWithRouterRules struct { - mutex sync.RWMutex - Service model.ServiceKey - v2Rules map[string]*model.ExtendRouterConfig - v1Rules *apitraffic.Routing - Revision string + direction model.TrafficDirection + mutex sync.RWMutex + Service model.ServiceKey + // sortKeys: 针对 customv2Rules 做了排序 + sortKeys []string + rules map[string]*model.ExtendRouterConfig + revision string + + customv1Rules *apitraffic.Routing } -func NewServiceWithRouterRules(svcKey model.ServiceKey) *ServiceWithRouterRules { +func NewServiceWithRouterRules(svcKey model.ServiceKey, direction model.TrafficDirection) *ServiceWithRouterRules { return &ServiceWithRouterRules{ - Service: svcKey, - v2Rules: make(map[string]*model.ExtendRouterConfig), + direction: direction, + Service: svcKey, + rules: make(map[string]*model.ExtendRouterConfig), } } +// AddRouterRule 添加路由规则,注意,这里只会保留处于 Enable 状态的路由规则 func (s *ServiceWithRouterRules) AddRouterRule(rule *model.ExtendRouterConfig) { + if !rule.Enable { + return + } + if rule.GetRoutingPolicy() == apitraffic.RoutingPolicy_RulePolicy { + s.customv1Rules = &apitraffic.Routing{ + Inbounds: []*apitraffic.Route{}, + Outbounds: []*apitraffic.Route{}, + } + } + s.mutex.Lock() defer s.mutex.Unlock() - s.v2Rules[rule.ID] = rule + s.rules[rule.ID] = rule } func (s *ServiceWithRouterRules) DelRouterRule(id string) { s.mutex.Lock() defer s.mutex.Unlock() - delete(s.v2Rules, id) + delete(s.rules, id) } +// IterateRouterRules 这里是可以保证按照路由规则优先顺序进行遍历 func (s *ServiceWithRouterRules) IterateRouterRules(callback func(*model.ExtendRouterConfig)) { s.mutex.RLock() defer s.mutex.RUnlock() - for _, rule := range s.v2Rules { - callback(rule) + + for _, key := range s.sortKeys { + val, ok := s.rules[key] + if ok { + callback(val) + } } + } func (s *ServiceWithRouterRules) CountRouterRules() int { s.mutex.RLock() defer s.mutex.RUnlock() - return len(s.v2Rules) + return len(s.rules) } func (s *ServiceWithRouterRules) Clear() { s.mutex.Lock() defer s.mutex.Unlock() - s.v2Rules = make(map[string]*model.ExtendRouterConfig) - s.Revision = "" + s.rules = make(map[string]*model.ExtendRouterConfig) + s.revision = "" } -// routeRuleBucket v2 路由规则缓存 bucket -type routeRuleBucket struct { - lock sync.RWMutex - // rules id => routing rule - rules map[string]*model.ExtendRouterConfig - // level1Rules service(name)+namespace => 路由规则ID列表,只针对某个具体的服务有效 - level1Rules map[string]map[string]struct{} - // level2Rules service(*) + namespace => 路由规则ID列表, 针对某个命名空间下所有服务都生效的路由规则 - level2Rules map[boundType]map[string]map[string]struct{} - // level3Rules service(*) + namespace(*) => 路由规则ID列表, 针对所有命名空间下的所有服务都生效的规则 - level3Rules map[boundType]map[string]struct{} - // v1rules service-id => []*model.ExtendRouterConfig v1 版本的规则自动转为 v2 版本的规则,用于 v2 接口的数据查看 - v1rules map[string][]*model.ExtendRouterConfig +func (s *ServiceWithRouterRules) reload() { + s.mutex.Lock() + defer s.mutex.Unlock() + + s.reloadRuleOrder() + s.reloadRevision() + s.reloadV1Rules() +} + +func (s *ServiceWithRouterRules) reloadRuleOrder() { + curRules := make([]*model.ExtendRouterConfig, 0, len(s.rules)) + for i := range s.rules { + curRules = append(curRules, s.rules[i]) + } + + sort.Slice(curRules, func(i, j int) bool { + return model.CompareRoutingV2(curRules[i], curRules[j]) + }) + + curKeys := make([]string, 0, len(curRules)) + for i := range curRules { + curKeys = append(curKeys, curRules[i].ID) + } + + s.sortKeys = curKeys +} - // fetched service cache +func (s *ServiceWithRouterRules) reloadRevision() { + revisioins := make([]string, 0, len(s.rules)) + for i := range s.sortKeys { + revisioins = append(revisioins, s.rules[s.sortKeys[i]].Revision) + } + s.revision, _ = types.CompositeComputeRevision(revisioins) +} + +func (s *ServiceWithRouterRules) reloadV1Rules() { + if s.customv1Rules == nil { + return + } + + rules := make([]*model.ExtendRouterConfig, 0, 32) + for i := range s.sortKeys { + rule, ok := s.rules[s.sortKeys[i]] + if !ok { + continue + } + rules = append(rules, rule) + } + + routes := make([]*apitraffic.Route, 0, 32) + + for i := range rules { + if rules[i].Priority != uint32(apitraffic.RoutingPolicy_RulePolicy) { + continue + } + routes = append(routes, model.BuildRoutes(rules[i], s.direction)...) + } + + s.customv1Rules = &apitraffic.Routing{} + switch s.direction { + case model.TrafficDirection_INBOUND: + s.customv1Rules.Inbounds = routes + case model.TrafficDirection_OUTBOUND: + s.customv1Rules.Outbounds = routes + } +} + +func newClientRouteRuleContainer(direction model.TrafficDirection) *ClientRouteRuleContainer { + return &ClientRouteRuleContainer{ + direction: direction, + exactRules: utils.NewSyncMap[string, *ServiceWithRouterRules](), + nsWildcardRules: utils.NewSyncMap[string, *ServiceWithRouterRules](), + allWildcardRules: NewServiceWithRouterRules(model.ServiceKey{Namespace: types.AllMatched, Name: types.AllMatched}, direction), + } +} + +type ClientRouteRuleContainer struct { + direction model.TrafficDirection // key1: namespace, key2: service - routerRules map[string]map[string]*ServiceWithRouterRules - // key1: namespace - nsWildcardRules map[string]*ServiceWithRouterRules + exactRules *utils.SyncMap[string, *ServiceWithRouterRules] + // key1: namespace is exact, service is full match + nsWildcardRules *utils.SyncMap[string, *ServiceWithRouterRules] // all rules are wildcard specific allWildcardRules *ServiceWithRouterRules - - // v1rulesToOld 转为 v2 规则id 对应的原本的 v1 规则id 信息 - v1rulesToOld map[string]string } -func (b *routeRuleBucket) getV2(id string) *model.ExtendRouterConfig { - b.lock.RLock() - defer b.lock.RUnlock() +func (c *ClientRouteRuleContainer) SearchRouteRuleV2(svc model.ServiceKey) []*model.ExtendRouterConfig { + ret := make([]*model.ExtendRouterConfig, 0, 32) + + exactRule, existExactRule := c.exactRules.Load(svc.Domain()) + if existExactRule { + exactRule.IterateRouterRules(func(erc *model.ExtendRouterConfig) { + ret = append(ret, erc) + }) + } - return b.rules[id] + nsWildcardRule, existNsWildcardRule := c.nsWildcardRules.Load(svc.Namespace) + if existNsWildcardRule { + nsWildcardRule.IterateRouterRules(func(erc *model.ExtendRouterConfig) { + ret = append(ret, erc) + }) + } + + c.allWildcardRules.IterateRouterRules(func(erc *model.ExtendRouterConfig) { + ret = append(ret, erc) + }) + return ret } -func (b *routeRuleBucket) saveV2(conf *model.ExtendRouterConfig) { - b.lock.Lock() - defer b.lock.Unlock() +// SearchCustomRuleV1 针对 v1 客户端拉取路由规则 +func (c *ClientRouteRuleContainer) SearchCustomRuleV1(svc model.ServiceKey) (*apitraffic.Routing, []string) { + ret := &apitraffic.Routing{ + Inbounds: make([]*apitraffic.Route, 0, 8), + Outbounds: make([]*apitraffic.Route, 0, 8), + } + exactRule, existExactRule := c.exactRules.Load(svc.Domain()) + nsWildcardRule, existNsWildcardRule := c.nsWildcardRules.Load(svc.Namespace) - b.rules[conf.ID] = conf - handler := func(bt boundType, item serviceInfo) { - // level1 级别 cache 处理 - if item.GetService() != model.MatchAll && item.GetNamespace() != model.MatchAll { - key := buildServiceKey(item.GetNamespace(), item.GetService()) - if _, ok := b.level1Rules[key]; !ok { - b.level1Rules[key] = map[string]struct{}{} - } + revisions := make([]string, 0, 2) - b.level1Rules[key][conf.ID] = struct{}{} - return + switch c.direction { + case model.TrafficDirection_INBOUND: + if existExactRule { + ret.Inbounds = append(ret.Inbounds, exactRule.customv1Rules.Inbounds...) + } + if existNsWildcardRule { + ret.Inbounds = append(ret.Inbounds, nsWildcardRule.customv1Rules.Inbounds...) + } + default: + if existExactRule { + ret.Outbounds = append(ret.Outbounds, exactRule.customv1Rules.Outbounds...) + revisions = append(revisions, exactRule.revision) } - // level2 级别 cache 处理 - if item.GetService() == model.MatchAll && item.GetNamespace() != model.MatchAll { - if _, ok := b.level2Rules[bt][item.GetNamespace()]; !ok { - b.level2Rules[bt][item.GetNamespace()] = map[string]struct{}{} - } - b.level2Rules[bt][item.GetNamespace()][conf.ID] = struct{}{} + if existNsWildcardRule { + ret.Outbounds = append(ret.Outbounds, nsWildcardRule.customv1Rules.Outbounds...) + } + } + if existExactRule { + revisions = append(revisions, exactRule.revision) + } + if existNsWildcardRule { + revisions = append(revisions, nsWildcardRule.revision) + } + + return ret, revisions +} + +func (c *ClientRouteRuleContainer) SaveRule(svcKey model.ServiceKey, item *model.ExtendRouterConfig) { + // level1 级别 cache 处理 + if svcKey.Name != model.MatchAll && svcKey.Namespace != model.MatchAll { + c.exactRules.ComputeIfAbsent(svcKey.Domain(), func(k string) *ServiceWithRouterRules { + return NewServiceWithRouterRules(svcKey, c.direction) + }) + svcContainer, _ := c.exactRules.Load(svcKey.Domain()) + svcContainer.AddRouterRule(item) + } + // level2 级别 cache 处理 + if svcKey.Name == model.MatchAll && svcKey.Namespace != model.MatchAll { + c.nsWildcardRules.ComputeIfAbsent(svcKey.Namespace, func(k string) *ServiceWithRouterRules { + return NewServiceWithRouterRules(svcKey, c.direction) + }) + + nsRules, _ := c.nsWildcardRules.Load(svcKey.Namespace) + nsRules.AddRouterRule(item) + } + // level3 级别 cache 处理 + if svcKey.Name == model.MatchAll && svcKey.Namespace == model.MatchAll { + c.allWildcardRules.AddRouterRule(item) + } +} + +func (c *ClientRouteRuleContainer) RemoveRule(svcKey model.ServiceKey, ruleId string) { + // level1 级别 cache 处理 + if svcKey.Name != model.MatchAll && svcKey.Namespace != model.MatchAll { + svcContainer, ok := c.exactRules.Load(svcKey.Domain()) + if !ok { return } - // level3 级别 cache 处理 - if item.GetService() == model.MatchAll && item.GetNamespace() == model.MatchAll { - b.level3Rules[bt][conf.ID] = struct{}{} + svcContainer.DelRouterRule(ruleId) + } + // level2 级别 cache 处理 + if svcKey.Name == model.MatchAll && svcKey.Namespace != model.MatchAll { + nsRules, ok := c.nsWildcardRules.Load(svcKey.Namespace) + if !ok { return } + nsRules.DelRouterRule(ruleId) + } + // level3 级别 cache 处理 + if svcKey.Name == model.MatchAll && svcKey.Namespace == model.MatchAll { + c.allWildcardRules.DelRouterRule(ruleId) } +} - if conf.GetRoutingPolicy() == apitraffic.RoutingPolicy_RulePolicy { - subRules := conf.RuleRouting.Rules - for i := range subRules { - sources := subRules[i].Sources - for i := range sources { - item := sources[i] - handler(outBound, item) - } - - destinations := subRules[i].Destinations - for i := range destinations { - item := destinations[i] - handler(inBound, item) - } - } +func newRouteRuleContainer() *RouteRuleContainer { + return &RouteRuleContainer{ + rules: utils.NewSyncMap[string, *model.ExtendRouterConfig](), + v1rules: map[string][]*model.ExtendRouterConfig{}, + v1rulesToOld: map[string]string{}, + nearbyContainers: newClientRouteRuleContainer(model.TrafficDirection_INBOUND), + customContainers: map[model.TrafficDirection]*ClientRouteRuleContainer{ + model.TrafficDirection_INBOUND: newClientRouteRuleContainer(model.TrafficDirection_INBOUND), + model.TrafficDirection_OUTBOUND: newClientRouteRuleContainer(model.TrafficDirection_OUTBOUND), + }, + effect: utils.NewSyncSet[model.ServiceKey](), } } +// RouteRuleContainer v2 路由规则缓存 bucket +type RouteRuleContainer struct { + // rules id => routing rule + rules *utils.SyncMap[string, *model.ExtendRouterConfig] + + // 就近路由规则缓存 + nearbyContainers *ClientRouteRuleContainer + // 自定义路由规则缓存 + customContainers map[model.TrafficDirection]*ClientRouteRuleContainer + + // effect 记录一次缓存更新中,那些服务的路由出现了更新 + effect *utils.SyncSet[model.ServiceKey] + + // ------- 这里的逻辑都是为了兼容老的数据规则,这个将在1.18.2代码中移除,通过升级工具一次性处理 ------ + lock sync.RWMutex + // v1rules service-id => []*model.ExtendRouterConfig v1 版本的规则自动转为 v2 版本的规则,用于 v2 接口的数据查看 + v1rules map[string][]*model.ExtendRouterConfig + // v1rulesToOld 转为 v2 规则id 对应的原本的 v1 规则id 信息 + v1rulesToOld map[string]string +} + +func (b *RouteRuleContainer) saveV2(conf *model.ExtendRouterConfig) { + b.rules.Store(conf.ID, conf) + handler := func(container *ClientRouteRuleContainer, svcKey model.ServiceKey) { + b.effect.Add(svcKey) + container.SaveRule(svcKey, conf) + } + + switch conf.GetRoutingPolicy() { + case apitraffic.RoutingPolicy_RulePolicy: + handler(b.customContainers[model.TrafficDirection_OUTBOUND], conf.RuleRouting.Caller) + handler(b.customContainers[model.TrafficDirection_INBOUND], conf.RuleRouting.Callee) + case apitraffic.RoutingPolicy_NearbyPolicy: + handler(b.nearbyContainers, model.ServiceKey{ + Namespace: conf.NearbyRouting.Namespace, + Name: conf.NearbyRouting.Service, + }) + } + +} + // saveV1 保存 v1 级别的路由规则 -func (b *routeRuleBucket) saveV1(v1rule *model.RoutingConfig, v2rules []*model.ExtendRouterConfig) { +func (b *RouteRuleContainer) saveV1(v1rule *model.RoutingConfig, v2rules []*model.ExtendRouterConfig) { + for i := range v2rules { + b.saveV2(v2rules[i]) + } + b.lock.Lock() defer b.lock.Unlock() @@ -209,55 +365,39 @@ func (b *routeRuleBucket) saveV1(v1rule *model.RoutingConfig, v2rules []*model.E } } -func (b *routeRuleBucket) convertV2Size() uint32 { +func (b *RouteRuleContainer) convertV2Size() uint32 { b.lock.RLock() defer b.lock.RUnlock() return uint32(len(b.v1rulesToOld)) } -func (b *routeRuleBucket) deleteV2(id string) { - b.lock.Lock() - defer b.lock.Unlock() - - rule := b.rules[id] - delete(b.rules, id) - - if rule == nil { +func (b *RouteRuleContainer) deleteV2(id string) { + rule, exist := b.rules.Load(id) + b.rules.Delete(id) + if !exist { return } - if rule.GetRoutingPolicy() != apitraffic.RoutingPolicy_RulePolicy { - return + handler := func(container *ClientRouteRuleContainer, svcKey model.ServiceKey) { + b.effect.Add(svcKey) + container.RemoveRule(svcKey, id) } - subRules := rule.RuleRouting.Rules - for i := range subRules { - for j := range subRules[i].GetSources() { - source := subRules[i].GetSources()[j] - service := source.GetService() - namespace := source.GetNamespace() - - if service == model.MatchAll && namespace == model.MatchAll { - delete(b.level3Rules[outBound], id) - delete(b.level3Rules[inBound], id) - } - - if service == model.MatchAll && namespace != model.MatchAll { - delete(b.level2Rules[outBound][namespace], id) - delete(b.level2Rules[inBound][namespace], id) - } - - if service != model.MatchAll && namespace != model.MatchAll { - key := buildServiceKey(namespace, service) - delete(b.level1Rules[key], id) - } - } + switch rule.GetRoutingPolicy() { + case apitraffic.RoutingPolicy_RulePolicy: + handler(b.customContainers[model.TrafficDirection_OUTBOUND], rule.RuleRouting.Caller) + handler(b.customContainers[model.TrafficDirection_INBOUND], rule.RuleRouting.Callee) + case apitraffic.RoutingPolicy_NearbyPolicy: + handler(b.nearbyContainers, model.ServiceKey{ + Namespace: rule.NearbyRouting.Namespace, + Name: rule.NearbyRouting.Service, + }) } } // deleteV1 删除 v1 的路由规则 -func (b *routeRuleBucket) deleteV1(serviceId string) { +func (b *RouteRuleContainer) deleteV1(serviceId string) { b.lock.Lock() defer b.lock.Unlock() @@ -269,16 +409,17 @@ func (b *routeRuleBucket) deleteV1(serviceId string) { for i := range items { delete(b.v1rulesToOld, items[i].ID) + b.deleteV2(items[i].ID) } delete(b.v1rules, serviceId) } // size Number of routing-v2 cache rules -func (b *routeRuleBucket) size() int { +func (b *RouteRuleContainer) size() int { b.lock.RLock() defer b.lock.RUnlock() - cnt := len(b.rules) + cnt := b.rules.Len() for k := range b.v1rules { cnt += len(b.v1rules[k]) } @@ -286,76 +427,34 @@ func (b *routeRuleBucket) size() int { return cnt } -// listEnableRules Inquire the routing rules of the V2 version through the service name, -// and perform some filtering according to the Predicate -func (b *routeRuleBucket) listEnableRules(service, namespace string, enableFullMatch bool) map[routingLevel][]*model.ExtendRouterConfig { - ret := make(map[routingLevel][]*model.ExtendRouterConfig) - tmpRecord := map[string]struct{}{} +func (b *RouteRuleContainer) SearchCustomRules(svcName, namespace string) []*model.ExtendRouterConfig { + ruleIds := map[string]struct{}{} - b.lock.RLock() - defer b.lock.RUnlock() + svcKey := model.ServiceKey{Namespace: namespace, Name: svcName} - predicate := func(item *model.ExtendRouterConfig) bool { - return item.Enable - } + ret := make([]*model.ExtendRouterConfig, 0, 32) - // Query Level1 V2 version routing rules - key := buildServiceKey(namespace, service) - ids := b.level1Rules[key] - level1 := make([]*model.ExtendRouterConfig, 0, 4) - for i := range ids { - if v, ok := b.rules[i]; ok && predicate(v) { - level1 = append(level1, v) - tmpRecord[v.ID] = struct{}{} - } + rules := b.customContainers[model.TrafficDirection_INBOUND].SearchRouteRuleV2(svcKey) + ret = append(ret, rules...) + for i := range rules { + ruleIds[rules[i].ID] = struct{}{} } - ret[level1RoutingV2] = level1 - - handler := func(ids map[string]struct{}, bt boundType) []*model.ExtendRouterConfig { - ret := make([]*model.ExtendRouterConfig, 0, 4) - - for k := range ids { - v := b.rules[k] - if v == nil { - continue - } - if _, ok := tmpRecord[v.ID]; ok { - continue - } - if !predicate(v) { - continue - } - ret = append(ret, v) - tmpRecord[v.ID] = struct{}{} - } - return ret + rules = b.customContainers[model.TrafficDirection_OUTBOUND].SearchRouteRuleV2(svcKey) + for i := range rules { + if _, ok := ruleIds[rules[i].ID]; !ok { + ret = append(ret, rules[i]) + } } - // Query Level 2 level routing-v2 rules - level2 := make([]*model.ExtendRouterConfig, 0, 4) - level2 = append(level2, handler(b.level2Rules[outBound][namespace], outBound)...) - level2 = append(level2, handler(b.level2Rules[inBound][namespace], inBound)...) - ret[level2RoutingV2] = level2 - - if enableFullMatch { - // Query Level3 level routing-v2 rules - level3 := make([]*model.ExtendRouterConfig, 0, 4) - level3 = append(level3, handler(b.level3Rules[outBound], outBound)...) - level3 = append(level3, handler(b.level3Rules[inBound], inBound)...) - ret[level3RoutingV2] = level3 - } return ret } // foreach Traversing all routing rules -func (b *routeRuleBucket) foreach(proc types.RouterRuleIterProc) { - b.lock.RLock() - defer b.lock.RUnlock() - - for k, v := range b.rules { - proc(k, v) - } +func (b *RouteRuleContainer) foreach(proc types.RouterRuleIterProc) { + b.rules.Range(func(key string, val *model.ExtendRouterConfig) { + proc(key, val) + }) for _, rules := range b.v1rules { for i := range rules { @@ -364,6 +463,51 @@ func (b *routeRuleBucket) foreach(proc types.RouterRuleIterProc) { } } -func buildServiceKey(namespace, service string) string { - return fmt.Sprintf("%s@@%s", namespace, service) +func (b *RouteRuleContainer) reload() { + b.effect.Range(func(val model.ServiceKey) { + b.reloadCustom(val) + b.reloadNearby(val) + }) +} + +func (b *RouteRuleContainer) reloadCustom(val model.ServiceKey) { + // 处理自定义路由 + // 处理 exact + rules, ok := b.customContainers[model.TrafficDirection_INBOUND].exactRules.Load(val.Domain()) + if ok { + rules.reload() + } + rules, ok = b.customContainers[model.TrafficDirection_OUTBOUND].exactRules.Load(val.Domain()) + if ok { + rules.reload() + } + + // 处理 ns wildcard + rules, ok = b.customContainers[model.TrafficDirection_INBOUND].nsWildcardRules.Load(val.Namespace) + if ok { + rules.reload() + } + rules, ok = b.customContainers[model.TrafficDirection_OUTBOUND].nsWildcardRules.Load(val.Namespace) + if ok { + rules.reload() + } + + // 处理 all wildcard + b.customContainers[model.TrafficDirection_INBOUND].allWildcardRules.reload() + b.customContainers[model.TrafficDirection_OUTBOUND].allWildcardRules.reload() +} + +func (b *RouteRuleContainer) reloadNearby(val model.ServiceKey) { + // 处理 exact + rules, ok := b.nearbyContainers.exactRules.Load(val.Domain()) + if ok { + rules.reload() + } + // 处理 ns wildcard + rules, ok = b.nearbyContainers.nsWildcardRules.Load(val.Namespace) + if ok { + rules.reload() + } + // 处理 all wildcard + b.nearbyContainers.allWildcardRules.reload() } diff --git a/cache/service/router_rule_query.go b/cache/service/router_rule_query.go index 09c78dc11..b5262c287 100644 --- a/cache/service/router_rule_query.go +++ b/cache/service/router_rule_query.go @@ -18,6 +18,7 @@ package service import ( + "context" "sort" "strings" @@ -29,7 +30,7 @@ import ( ) // forceUpdate 更新配置 -func (rc *routingConfigCache) forceUpdate() error { +func (rc *RouteRuleCache) forceUpdate() error { if err := rc.Update(); err != nil { return err } @@ -53,8 +54,8 @@ func queryRoutingRuleV2ByService(rule *model.ExtendRouterConfig, sourceNamespace destService, isWildDestSvc := utils.ParseWildName(destService) destNamespace, isWildDestNamespace := utils.ParseWildName(destNamespace) - for i := range rule.RuleRouting.Rules { - subRule := rule.RuleRouting.Rules[i] + for i := range rule.RuleRouting.RuleRouting.Rules { + subRule := rule.RuleRouting.RuleRouting.Rules[i] sources := subRule.GetSources() if hasSourceNamespace || hasSourceSvc { for i := range sources { @@ -119,7 +120,7 @@ func queryRoutingRuleV2ByService(rule *model.ExtendRouterConfig, sourceNamespace } // QueryRoutingConfigsV2 Query Route Configuration List -func (rc *routingConfigCache) QueryRoutingConfigsV2(args *types.RoutingArgs) (uint32, []*model.ExtendRouterConfig, error) { +func (rc *RouteRuleCache) QueryRoutingConfigsV2(ctx context.Context, args *types.RoutingArgs) (uint32, []*model.ExtendRouterConfig, error) { if err := rc.forceUpdate(); err != nil { return 0, nil, err } @@ -188,7 +189,7 @@ func (rc *routingConfigCache) QueryRoutingConfigsV2(args *types.RoutingArgs) (ui return amount, routings, nil } -func (rc *routingConfigCache) sortBeforeTrim(routings []*model.ExtendRouterConfig, +func (rc *RouteRuleCache) sortBeforeTrim(routings []*model.ExtendRouterConfig, args *types.RoutingArgs) (uint32, []*model.ExtendRouterConfig) { amount := uint32(len(routings)) if args.Offset >= amount || args.Limit == 0 { diff --git a/cache/service/router_rule_query_test.go b/cache/service/router_rule_query_test.go index 1c05075a8..d2628f264 100644 --- a/cache/service/router_rule_query_test.go +++ b/cache/service/router_rule_query_test.go @@ -43,19 +43,21 @@ func Test_queryRoutingRuleV2ByService(t *testing.T) { name: "命名空间-或-精确查询", args: args{ rule: &model.ExtendRouterConfig{ - RuleRouting: &apitraffic.RuleRoutingConfig{ - Rules: []*apitraffic.SubRuleRouting{ - { - Sources: []*apitraffic.SourceService{ - { - Service: "test-1", - Namespace: "test-1", + RuleRouting: &model.RuleRoutingConfigWrapper{ + RuleRouting: &apitraffic.RuleRoutingConfig{ + Rules: []*apitraffic.SubRuleRouting{ + { + Sources: []*apitraffic.SourceService{ + { + Service: "test-1", + Namespace: "test-1", + }, }, - }, - Destinations: []*apitraffic.DestinationGroup{ - { - Service: "test-1", - Namespace: "test-1", + Destinations: []*apitraffic.DestinationGroup{ + { + Service: "test-1", + Namespace: "test-1", + }, }, }, }, @@ -74,19 +76,21 @@ func Test_queryRoutingRuleV2ByService(t *testing.T) { name: "命名空间-与-精确查询", args: args{ rule: &model.ExtendRouterConfig{ - RuleRouting: &apitraffic.RuleRoutingConfig{ - Rules: []*apitraffic.SubRuleRouting{ - { - Sources: []*apitraffic.SourceService{ - { - Service: "test-1", - Namespace: "test-1", + RuleRouting: &model.RuleRoutingConfigWrapper{ + RuleRouting: &apitraffic.RuleRoutingConfig{ + Rules: []*apitraffic.SubRuleRouting{ + { + Sources: []*apitraffic.SourceService{ + { + Service: "test-1", + Namespace: "test-1", + }, }, - }, - Destinations: []*apitraffic.DestinationGroup{ - { - Service: "test-1", - Namespace: "test-1", + Destinations: []*apitraffic.DestinationGroup{ + { + Service: "test-1", + Namespace: "test-1", + }, }, }, }, @@ -105,19 +109,21 @@ func Test_queryRoutingRuleV2ByService(t *testing.T) { name: "命名空间-或-模糊查询", args: args{ rule: &model.ExtendRouterConfig{ - RuleRouting: &apitraffic.RuleRoutingConfig{ - Rules: []*apitraffic.SubRuleRouting{ - { - Sources: []*apitraffic.SourceService{ - { - Service: "test-1", - Namespace: "test-1", + RuleRouting: &model.RuleRoutingConfigWrapper{ + RuleRouting: &apitraffic.RuleRoutingConfig{ + Rules: []*apitraffic.SubRuleRouting{ + { + Sources: []*apitraffic.SourceService{ + { + Service: "test-1", + Namespace: "test-1", + }, }, - }, - Destinations: []*apitraffic.DestinationGroup{ - { - Service: "test-1", - Namespace: "test-1", + Destinations: []*apitraffic.DestinationGroup{ + { + Service: "test-1", + Namespace: "test-1", + }, }, }, }, @@ -136,19 +142,21 @@ func Test_queryRoutingRuleV2ByService(t *testing.T) { name: "命名空间-与-模糊查询", args: args{ rule: &model.ExtendRouterConfig{ - RuleRouting: &apitraffic.RuleRoutingConfig{ - Rules: []*apitraffic.SubRuleRouting{ - { - Sources: []*apitraffic.SourceService{ - { - Service: "test-1", - Namespace: "test-1", + RuleRouting: &model.RuleRoutingConfigWrapper{ + RuleRouting: &apitraffic.RuleRoutingConfig{ + Rules: []*apitraffic.SubRuleRouting{ + { + Sources: []*apitraffic.SourceService{ + { + Service: "test-1", + Namespace: "test-1", + }, }, - }, - Destinations: []*apitraffic.DestinationGroup{ - { - Service: "test-1", - Namespace: "test-1", + Destinations: []*apitraffic.DestinationGroup{ + { + Service: "test-1", + Namespace: "test-1", + }, }, }, }, @@ -167,19 +175,21 @@ func Test_queryRoutingRuleV2ByService(t *testing.T) { name: "(命名空间精确查询+服务名精确查询)-或", args: args{ rule: &model.ExtendRouterConfig{ - RuleRouting: &apitraffic.RuleRoutingConfig{ - Rules: []*apitraffic.SubRuleRouting{ - { - Sources: []*apitraffic.SourceService{ - { - Service: "test-1", - Namespace: "test-1", + RuleRouting: &model.RuleRoutingConfigWrapper{ + RuleRouting: &apitraffic.RuleRoutingConfig{ + Rules: []*apitraffic.SubRuleRouting{ + { + Sources: []*apitraffic.SourceService{ + { + Service: "test-1", + Namespace: "test-1", + }, }, - }, - Destinations: []*apitraffic.DestinationGroup{ - { - Service: "test-1", - Namespace: "test-1", + Destinations: []*apitraffic.DestinationGroup{ + { + Service: "test-1", + Namespace: "test-1", + }, }, }, }, @@ -198,19 +208,21 @@ func Test_queryRoutingRuleV2ByService(t *testing.T) { name: "(命名空间精确查询+服务名精确查询)-与", args: args{ rule: &model.ExtendRouterConfig{ - RuleRouting: &apitraffic.RuleRoutingConfig{ - Rules: []*apitraffic.SubRuleRouting{ - { - Sources: []*apitraffic.SourceService{ - { - Service: "test-1", - Namespace: "test-1", + RuleRouting: &model.RuleRoutingConfigWrapper{ + RuleRouting: &apitraffic.RuleRoutingConfig{ + Rules: []*apitraffic.SubRuleRouting{ + { + Sources: []*apitraffic.SourceService{ + { + Service: "test-1", + Namespace: "test-1", + }, }, - }, - Destinations: []*apitraffic.DestinationGroup{ - { - Service: "test-1", - Namespace: "test-1", + Destinations: []*apitraffic.DestinationGroup{ + { + Service: "test-1", + Namespace: "test-1", + }, }, }, }, @@ -229,19 +241,21 @@ func Test_queryRoutingRuleV2ByService(t *testing.T) { name: "(命名空间精确查询+服务名精确查询)-与", args: args{ rule: &model.ExtendRouterConfig{ - RuleRouting: &apitraffic.RuleRoutingConfig{ - Rules: []*apitraffic.SubRuleRouting{ - { - Sources: []*apitraffic.SourceService{ - { - Service: "test-1", - Namespace: "test-1", + RuleRouting: &model.RuleRoutingConfigWrapper{ + RuleRouting: &apitraffic.RuleRoutingConfig{ + Rules: []*apitraffic.SubRuleRouting{ + { + Sources: []*apitraffic.SourceService{ + { + Service: "test-1", + Namespace: "test-1", + }, }, - }, - Destinations: []*apitraffic.DestinationGroup{ - { - Service: "test-1", - Namespace: "test-1", + Destinations: []*apitraffic.DestinationGroup{ + { + Service: "test-1", + Namespace: "test-1", + }, }, }, }, @@ -260,19 +274,21 @@ func Test_queryRoutingRuleV2ByService(t *testing.T) { name: "(命名空间模糊+服务名精确查询)-或", args: args{ rule: &model.ExtendRouterConfig{ - RuleRouting: &apitraffic.RuleRoutingConfig{ - Rules: []*apitraffic.SubRuleRouting{ - { - Sources: []*apitraffic.SourceService{ - { - Service: "test-1", - Namespace: "test-1", + RuleRouting: &model.RuleRoutingConfigWrapper{ + RuleRouting: &apitraffic.RuleRoutingConfig{ + Rules: []*apitraffic.SubRuleRouting{ + { + Sources: []*apitraffic.SourceService{ + { + Service: "test-1", + Namespace: "test-1", + }, }, - }, - Destinations: []*apitraffic.DestinationGroup{ - { - Service: "test-1", - Namespace: "test-1", + Destinations: []*apitraffic.DestinationGroup{ + { + Service: "test-1", + Namespace: "test-1", + }, }, }, }, @@ -291,19 +307,21 @@ func Test_queryRoutingRuleV2ByService(t *testing.T) { name: "(命名空间模糊+服务名精确查询)-或", args: args{ rule: &model.ExtendRouterConfig{ - RuleRouting: &apitraffic.RuleRoutingConfig{ - Rules: []*apitraffic.SubRuleRouting{ - { - Sources: []*apitraffic.SourceService{ - { - Service: "test-1", - Namespace: "test-1", + RuleRouting: &model.RuleRoutingConfigWrapper{ + RuleRouting: &apitraffic.RuleRoutingConfig{ + Rules: []*apitraffic.SubRuleRouting{ + { + Sources: []*apitraffic.SourceService{ + { + Service: "test-1", + Namespace: "test-1", + }, }, - }, - Destinations: []*apitraffic.DestinationGroup{ - { - Service: "test-1", - Namespace: "test-1", + Destinations: []*apitraffic.DestinationGroup{ + { + Service: "test-1", + Namespace: "test-1", + }, }, }, }, @@ -322,19 +340,21 @@ func Test_queryRoutingRuleV2ByService(t *testing.T) { name: "(命名空间模糊+服务名精确查询)-或", args: args{ rule: &model.ExtendRouterConfig{ - RuleRouting: &apitraffic.RuleRoutingConfig{ - Rules: []*apitraffic.SubRuleRouting{ - { - Sources: []*apitraffic.SourceService{ - { - Service: "test-1", - Namespace: "test-1", + RuleRouting: &model.RuleRoutingConfigWrapper{ + RuleRouting: &apitraffic.RuleRoutingConfig{ + Rules: []*apitraffic.SubRuleRouting{ + { + Sources: []*apitraffic.SourceService{ + { + Service: "test-1", + Namespace: "test-1", + }, }, - }, - Destinations: []*apitraffic.DestinationGroup{ - { - Service: "test-1", - Namespace: "test-1", + Destinations: []*apitraffic.DestinationGroup{ + { + Service: "test-1", + Namespace: "test-1", + }, }, }, }, diff --git a/cache/service/service_query.go b/cache/service/service_query.go index 5cd8f6ea5..27e6f4ee4 100644 --- a/cache/service/service_query.go +++ b/cache/service/service_query.go @@ -18,6 +18,7 @@ package service import ( + "context" "sort" "strings" @@ -40,7 +41,7 @@ func (sc *serviceCache) forceUpdate() error { } // GetServicesByFilter 通过filter在缓存中进行服务过滤 -func (sc *serviceCache) GetServicesByFilter(serviceFilters *types.ServiceArgs, +func (sc *serviceCache) GetServicesByFilter(ctx context.Context, serviceFilters *types.ServiceArgs, instanceFilters *store.InstanceArgs, offset, limit uint32) (uint32, []*model.EnhancedService, error) { if err := sc.forceUpdate(); err != nil { diff --git a/cache/service/service_test.go b/cache/service/service_test.go index 087c8555a..b9015e3ec 100644 --- a/cache/service/service_test.go +++ b/cache/service/service_test.go @@ -420,7 +420,7 @@ func TestServiceCache_GetServicesByFilter(t *testing.T) { svcArgs := &types.ServiceArgs{ EmptyCondition: true, } - amount, services, err := sc.GetServicesByFilter(svcArgs, instArgs, 0, 10) + amount, services, err := sc.GetServicesByFilter(context.Background(), svcArgs, instArgs, 0, 10) if err != nil { t.Fatal(err) } diff --git a/common/model/admin/admin.go b/common/model/admin/admin.go new file mode 100644 index 000000000..2d6bfd81b --- /dev/null +++ b/common/model/admin/admin.go @@ -0,0 +1,61 @@ +/** + * Tencent is pleased to support the open source community by making Polaris available. + * + * Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. + * + * Licensed under the BSD 3-Clause License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://opensource.org/licenses/BSD-3-Clause + * + * Unless required by applicable law or agreed to in writing, software distributed + * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR + * CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package admin + +import ( + "time" + + connlimit "github.com/polarismesh/polaris/common/conn/limit" +) + +// LeaderElection leader election info +type LeaderElection struct { + ElectKey string + Host string + Ctime int64 + CreateTime time.Time + Mtime int64 + ModifyTime time.Time + Valid bool +} + +type ConnReq struct { + Protocol string + Host string + Port int + Amount int +} + +type ConnCountResp struct { + Protocol string + Total int32 + Host map[string]int32 +} + +type ConnStatsResp struct { + Protocol string + ActiveConnTotal int32 + StatsTotal int + StatsSize int + Stats []*connlimit.HostConnStat +} + +type ScopeLevel struct { + Name string + Level string +} diff --git a/common/model/auth.go b/common/model/auth.go deleted file mode 100644 index afb94a1e3..000000000 --- a/common/model/auth.go +++ /dev/null @@ -1,393 +0,0 @@ -/** - * Tencent is pleased to support the open source community by making Polaris available. - * - * Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. - * - * Licensed under the BSD 3-Clause License (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://opensource.org/licenses/BSD-3-Clause - * - * Unless required by applicable law or agreed to in writing, software distributed - * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR - * CONDITIONS OF ANY KIND, either express or implied. See the License for the - * specific language governing permissions and limitations under the License. - */ - -package model - -import ( - "errors" - "fmt" - "strconv" - "time" - - apimodel "github.com/polarismesh/specification/source/go/api/v1/model" - apisecurity "github.com/polarismesh/specification/source/go/api/v1/security" - "google.golang.org/protobuf/types/known/wrapperspb" - - commontime "github.com/polarismesh/polaris/common/time" -) - -var ( - // ErrorNoUser 没有找到对应的用户 - ErrorNoUser error = errors.New("no such user") - - // ErrorNoUserGroup 没有找到对应的用户组 - ErrorNoUserGroup error = errors.New("no such user group") - - // ErrorNoNamespace 没有找到对应的命名空间 - ErrorNoNamespace error = errors.New("no such namespace") - - // ErrorNoService 没有找到对应的服务 - ErrorNoService error = errors.New("no such service") - - // ErrorWrongUsernameOrPassword 用户或者密码错误 - ErrorWrongUsernameOrPassword error = errors.New("name or password is wrong") - - // ErrorTokenNotExist token 不存在 - ErrorTokenNotExist error = errors.New("token not exist") - - // ErrorTokenInvalid 非法的 token - ErrorTokenInvalid error = errors.New("invalid token") - - // ErrorTokenDisabled token 已经被禁用 - ErrorTokenDisabled error = errors.New("token already disabled") -) - -func ConvertToErrCode(err error) apimodel.Code { - if errors.Is(err, ErrorTokenNotExist) { - return apimodel.Code_TokenNotExisted - } - - if errors.Is(err, ErrorTokenDisabled) { - return apimodel.Code_TokenDisabled - } - - return apimodel.Code_NotAllowedAccess -} - -const ( - OperatorRoleKey string = "operator_role" - OperatorPrincipalType string = "operator_principal" - OperatorIDKey string = "operator_id" - OperatorOwnerKey string = "operator_owner" - OperatorLinkStrategy string = "operator_link_strategy" - LinkUsersKey string = "link_users" - LinkGroupsKey string = "link_groups" - RemoveLinkUsersKey string = "remove_link_users" - RemoveLinkGroupsKey string = "remove_link_groups" - - TokenDetailInfoKey string = "TokenInfo" - TokenForUser string = "uid" - TokenForUserGroup string = "groupid" - - ResourceAttachmentKey string = "resource_attachment" -) - -func _() { - // An "invalid array index" compiler error signifies that the constant values have changed. - // Re-run the stringer command to generate them again. - var x [1]struct{} - _ = x[PrincipalUser-1] - _ = x[PrincipalGroup-2] -} - -const _PrincipalType_name = "PrincipalUserPrincipalGroup" - -var _PrincipalType_index = [...]uint8{0, 13, 27} - -func (i PrincipalType) String() string { - i -= 1 - if i < 0 || i >= PrincipalType(len(_PrincipalType_index)-1) { - return "PrincipalType(" + strconv.FormatInt(int64(i+1), 10) + ")" - } - return _PrincipalType_name[_PrincipalType_index[i]:_PrincipalType_index[i+1]] -} - -//go:generate stringer -type=PrincipalType -type PrincipalType int - -const ( - PrincipalUser PrincipalType = 1 - PrincipalGroup PrincipalType = 2 -) - -// CheckPrincipalType 检查鉴权策略成员角色信息 -func CheckPrincipalType(role int) error { - switch PrincipalType(role) { - case PrincipalUser: - return nil - case PrincipalGroup: - return nil - default: - return errors.New("invalid principal type") - } -} - -var ( - // PrincipalNames principal name map - PrincipalNames = map[PrincipalType]string{ - PrincipalUser: "user", - PrincipalGroup: "group", - } -) - -const ( - - // DefaultStrategySuffix 默认策略的名称前缀 - DefaultStrategySuffix string = "的默认策略" -) - -// BuildDefaultStrategyName 构建默认鉴权策略的名称信息 -func BuildDefaultStrategyName(role PrincipalType, name string) string { - if role == PrincipalUser { - return fmt.Sprintf("%s%s%s", "(用户) ", name, DefaultStrategySuffix) - } - return fmt.Sprintf("%s%s%s", "(用户组) ", name, DefaultStrategySuffix) -} - -// ResourceOperation 资源操作 -type ResourceOperation int16 - -const ( - // Read 只读动作 - Read ResourceOperation = 10 - // Create 创建动作 - Create ResourceOperation = 20 - // Modify 修改动作 - Modify ResourceOperation = 30 - // Delete 删除动作 - Delete ResourceOperation = 40 -) - -// BzModule 模块标识 -type BzModule int16 - -const ( - // UnknowModule 未知模块 - UnknowModule BzModule = iota - // CoreModule 核心模块 - CoreModule - // DiscoverModule 服务模块 - DiscoverModule - // ConfigModule 配置模块 - ConfigModule - // AuthModule 鉴权模块 - AuthModule - // MaintainModule 运维操作模块 - MaintainModule - // BootstrapModule 初始化模块 - BootstrapModule -) - -// UserRoleType 用户角色类型 -type UserRoleType int - -const ( - UnknownUserRole UserRoleType = -1 - AdminUserRole UserRoleType = 0 - OwnerUserRole UserRoleType = 20 - SubAccountUserRole UserRoleType = 50 -) - -var ( - UserRoleNames = map[UserRoleType]string{ - AdminUserRole: "admin", - OwnerUserRole: "main", - SubAccountUserRole: "sub", - } -) - -// ResourceEntry 资源最简单信息 -type ResourceEntry struct { - ID string - Owner string -} - -// User 用户 -type User struct { - ID string - Name string - Password string - Owner string - Source string - Mobile string - Email string - Type UserRoleType - Token string - TokenEnable bool - Valid bool - Comment string - CreateTime time.Time - ModifyTime time.Time -} - -func (u *User) ToSpec() *apisecurity.User { - if u == nil { - return nil - } - return &apisecurity.User{ - Id: wrapperspb.String(u.ID), - Name: wrapperspb.String(u.Name), - Password: wrapperspb.String(u.Password), - Owner: wrapperspb.String(u.Owner), - Source: wrapperspb.String(u.Source), - AuthToken: wrapperspb.String(u.Token), - TokenEnable: wrapperspb.Bool(u.TokenEnable), - Comment: wrapperspb.String(u.Comment), - UserType: wrapperspb.String(fmt.Sprintf("%d", u.Type)), - } -} - -// UserGroupDetail 用户组详细(带用户列表) -type UserGroupDetail struct { - *UserGroup - - // UserIds改为 map 的形式,加速查询 - UserIds map[string]struct{} -} - -// ToUserIdSlice 将用户ID Map 专为 slice -func (ugd *UserGroupDetail) ToUserIdSlice() []string { - uids := make([]string, 0, len(ugd.UserIds)) - for uid := range ugd.UserIds { - uids = append(uids, uid) - } - return uids -} - -func (ugd *UserGroupDetail) ListSpecUser() []*apisecurity.User { - users := make([]*apisecurity.User, 0, len(ugd.UserIds)) - for i := range ugd.UserIds { - users = append(users, &apisecurity.User{ - Id: wrapperspb.String(i), - }) - } - return users -} - -// ToSpec 将用户ID Map 专为 slice -func (ugd *UserGroupDetail) ToSpec() *apisecurity.UserGroup { - if ugd == nil { - return nil - } - return &apisecurity.UserGroup{ - Id: wrapperspb.String(ugd.ID), - Name: wrapperspb.String(ugd.Name), - Owner: wrapperspb.String(ugd.Owner), - AuthToken: wrapperspb.String(ugd.Token), - TokenEnable: wrapperspb.Bool(ugd.TokenEnable), - Comment: wrapperspb.String(ugd.Comment), - Ctime: wrapperspb.String(commontime.Time2String(ugd.CreateTime)), - Mtime: wrapperspb.String(commontime.Time2String(ugd.ModifyTime)), - Relation: &apisecurity.UserGroupRelation{ - GroupId: wrapperspb.String(ugd.ID), - Users: ugd.ListSpecUser(), - }, - UserCount: wrapperspb.UInt32(uint32(len(ugd.UserIds))), - } -} - -// UserGroup 用户组 -type UserGroup struct { - ID string - Name string - Owner string - Token string - TokenEnable bool - Valid bool - Comment string - CreateTime time.Time - ModifyTime time.Time -} - -// ModifyUserGroup 用户组修改 -type ModifyUserGroup struct { - ID string - Owner string - Token string - TokenEnable bool - Comment string - AddUserIds []string - RemoveUserIds []string -} - -// UserGroupRelation 用户-用户组关联关系具体信息 -type UserGroupRelation struct { - GroupID string - UserIds []string - CreateTime time.Time - ModifyTime time.Time -} - -// StrategyDetail 鉴权策略详细 -type StrategyDetail struct { - ID string - Name string - Action string - Comment string - Principals []Principal - Default bool - Owner string - Resources []StrategyResource - Valid bool - Revision string - CreateTime time.Time - ModifyTime time.Time -} - -// StrategyDetailCache 鉴权策略详细 -type StrategyDetailCache struct { - *StrategyDetail - UserPrincipal map[string]Principal - GroupPrincipal map[string]Principal -} - -// ModifyStrategyDetail 修改鉴权策略详细 -type ModifyStrategyDetail struct { - ID string - Name string - Action string - Comment string - AddPrincipals []Principal - RemovePrincipals []Principal - AddResources []StrategyResource - RemoveResources []StrategyResource - ModifyTime time.Time -} - -// Strategy 策略main信息 -type Strategy struct { - ID string - Name string - Principal string - Action string - Comment string - Owner string - Default bool - Valid bool - CreateTime time.Time - ModifyTime time.Time -} - -// StrategyResource 策略资源 -type StrategyResource struct { - StrategyID string - ResType int32 - ResID string -} - -// Principal 策略相关人 -type Principal struct { - StrategyID string - PrincipalID string - PrincipalRole PrincipalType -} - -type OperateResource struct { - ResOwner string - ResType apisecurity.ResourceType - ResID string -} diff --git a/common/model/acquire_context.go b/common/model/auth/acquire_context.go similarity index 97% rename from common/model/acquire_context.go rename to common/model/auth/acquire_context.go index 91bfae548..4a8c7a92d 100644 --- a/common/model/acquire_context.go +++ b/common/model/auth/acquire_context.go @@ -15,7 +15,7 @@ * specific language governing permissions and limitations under the License. */ -package model +package auth import ( "context" @@ -38,7 +38,7 @@ type AcquireContext struct { // Module 来自那个业务层(服务注册与服务治理、配置模块) module BzModule // Method 操作函数 - method string + method ServerFunctionName // Operation 本次操作涉及的动作 operation ResourceOperation // Resources 本次 @@ -96,7 +96,7 @@ func WithModule(module BzModule) acquireContextOption { } // WithMethod 本次操作函数名称 -func WithMethod(method string) acquireContextOption { +func WithMethod(method ServerFunctionName) acquireContextOption { return func(authCtx *AcquireContext) { authCtx.method = method } @@ -213,7 +213,7 @@ func (authCtx *AcquireContext) SetAttachment(key string, val interface{}) { } // GetMethod 获取本次请求涉及的操作函数 -func (authCtx *AcquireContext) GetMethod() string { +func (authCtx *AcquireContext) GetMethod() ServerFunctionName { return authCtx.method } @@ -262,5 +262,4 @@ type ResourceOpInfo struct { Namespace string ResourceName string ResourceID string - Operation ResourceOperation } diff --git a/common/model/auth/auth.go b/common/model/auth/auth.go index 7cd06219a..0f879dc17 100644 --- a/common/model/auth/auth.go +++ b/common/model/auth/auth.go @@ -19,17 +19,482 @@ package auth import ( "context" + "errors" + "fmt" + "strconv" + "time" - "github.com/polarismesh/polaris/common/model" + apimodel "github.com/polarismesh/specification/source/go/api/v1/model" + apisecurity "github.com/polarismesh/specification/source/go/api/v1/security" + "google.golang.org/protobuf/types/known/wrapperspb" + + commontime "github.com/polarismesh/polaris/common/time" "github.com/polarismesh/polaris/common/utils" ) +var ( + // ErrorNoUser 没有找到对应的用户 + ErrorNoUser error = errors.New("no such user") + + // ErrorNoUserGroup 没有找到对应的用户组 + ErrorNoUserGroup error = errors.New("no such user group") + // ErrorWrongUsernameOrPassword 用户或者密码错误 + ErrorWrongUsernameOrPassword error = errors.New("name or password is wrong") + // ErrorTokenNotExist token 不存在 + ErrorTokenNotExist error = errors.New("token not exist") + // ErrorTokenInvalid 非法的 token + ErrorTokenInvalid error = errors.New("invalid token") + // ErrorTokenDisabled token 已经被禁用 + ErrorTokenDisabled error = errors.New("token already disabled") +) + +func ConvertToErrCode(err error) apimodel.Code { + if errors.Is(err, ErrorTokenNotExist) { + return apimodel.Code_TokenNotExisted + } + + if errors.Is(err, ErrorTokenDisabled) { + return apimodel.Code_TokenDisabled + } + + return apimodel.Code_NotAllowedAccess +} + +const ( + OperatorRoleKey string = "operator_role" + OperatorIDKey string = "operator_id" + OperatorOwnerKey string = "operator_owner" + OperatorLinkStrategy string = "operator_link_strategy" + PrincipalKey string = "principal" + LinkUsersKey string = "link_users" + LinkGroupsKey string = "link_groups" + RemoveLinkUsersKey string = "remove_link_users" + RemoveLinkGroupsKey string = "remove_link_groups" + + TokenDetailInfoKey string = "TokenInfo" + TokenForUser string = "uid" + TokenForUserGroup string = "groupid" + + ResourceAttachmentKey string = "resource_attachment" +) + +func _() { + // An "invalid array index" compiler error signifies that the constant values have changed. + // Re-run the stringer command to generate them again. + var x [1]struct{} + _ = x[PrincipalUser-1] + _ = x[PrincipalGroup-2] +} + +const _PrincipalType_name = "PrincipalUserPrincipalGroup" + +var _PrincipalType_index = [...]uint8{0, 13, 27} + +func (i PrincipalType) String() string { + i -= 1 + if i < 0 || i >= PrincipalType(len(_PrincipalType_index)-1) { + return "PrincipalType(" + strconv.FormatInt(int64(i+1), 10) + ")" + } + return _PrincipalType_name[_PrincipalType_index[i]:_PrincipalType_index[i+1]] +} + +//go:generate stringer -type=PrincipalType +type PrincipalType int + +const ( + PrincipalUser PrincipalType = 1 + PrincipalGroup PrincipalType = 2 + PrincipalRole PrincipalType = 3 +) + +// CheckPrincipalType 检查鉴权策略成员角色信息 +func CheckPrincipalType(role int) error { + switch PrincipalType(role) { + case PrincipalUser: + return nil + case PrincipalGroup: + return nil + case PrincipalRole: + return nil + default: + return errors.New("invalid principal type") + } +} + +var ( + // PrincipalNames principal name map + PrincipalNames = map[PrincipalType]string{ + PrincipalUser: "user", + PrincipalGroup: "group", + PrincipalRole: "role", + } +) + +const ( + + // DefaultStrategySuffix 默认策略的名称前缀 + DefaultStrategySuffix string = "的默认策略" +) + +// BuildDefaultStrategyName 构建默认鉴权策略的名称信息 +func BuildDefaultStrategyName(role PrincipalType, name string) string { + if role == PrincipalUser { + return fmt.Sprintf("%s%s%s", "(用户) ", name, DefaultStrategySuffix) + } + return fmt.Sprintf("%s%s%s", "(用户组) ", name, DefaultStrategySuffix) +} + +// ResourceOperation 资源操作 +type ResourceOperation int16 + +const ( + // Read 只读动作 + Read ResourceOperation = 10 + // Create 创建动作 + Create ResourceOperation = 20 + // Modify 修改动作 + Modify ResourceOperation = 30 + // Delete 删除动作 + Delete ResourceOperation = 40 +) + +// BzModule 模块标识 +type BzModule int16 + +const ( + // UnknowModule 未知模块 + UnknowModule BzModule = iota + // CoreModule 核心模块 + CoreModule + // DiscoverModule 服务模块 + DiscoverModule + // ConfigModule 配置模块 + ConfigModule + // AuthModule 鉴权模块 + AuthModule + // MaintainModule 运维操作模块 + MaintainModule + // BootstrapModule 初始化模块 + BootstrapModule +) + +// UserRoleType 用户角色类型 +type UserRoleType int + +const ( + UnknownUserRole UserRoleType = -1 + AdminUserRole UserRoleType = 0 + OwnerUserRole UserRoleType = 20 + SubAccountUserRole UserRoleType = 50 +) + +var ( + UserRoleNames = map[UserRoleType]string{ + AdminUserRole: "admin", + OwnerUserRole: "main", + SubAccountUserRole: "sub", + } +) + +// ResourceEntry 资源最简单信息 +type ResourceEntry struct { + Type apisecurity.ResourceType + ID string + Owner string + Metadata map[string]string +} + +// User 用户 +type User struct { + ID string + Name string + Password string + Owner string + Source string + Mobile string + Email string + Type UserRoleType + Metadata map[string]string + Token string + TokenEnable bool + Valid bool + Comment string + CreateTime time.Time + ModifyTime time.Time +} + +func (u *User) GetToken() string { + return u.Token +} + +func (u *User) Disable() bool { + return !u.TokenEnable +} + +func (u *User) OwnerID() string { + return u.Owner +} + +func (u *User) SelfID() string { + return u.ID +} + +func (u *User) ToSpec() *apisecurity.User { + if u == nil { + return nil + } + return &apisecurity.User{ + Id: wrapperspb.String(u.ID), + Name: wrapperspb.String(u.Name), + Password: wrapperspb.String(u.Password), + Owner: wrapperspb.String(u.Owner), + Source: wrapperspb.String(u.Source), + AuthToken: wrapperspb.String(u.Token), + TokenEnable: wrapperspb.Bool(u.TokenEnable), + Comment: wrapperspb.String(u.Comment), + UserType: wrapperspb.String(fmt.Sprintf("%d", u.Type)), + } +} + +// UserGroupDetail 用户组详细(带用户列表) +type UserGroupDetail struct { + *UserGroup + + // UserIds改为 map 的形式,加速查询 + UserIds map[string]struct{} +} + +// ToUserIdSlice 将用户ID Map 专为 slice +func (ugd *UserGroupDetail) ToUserIdSlice() []string { + uids := make([]string, 0, len(ugd.UserIds)) + for uid := range ugd.UserIds { + uids = append(uids, uid) + } + return uids +} + +func (ugd *UserGroupDetail) ListSpecUser() []*apisecurity.User { + users := make([]*apisecurity.User, 0, len(ugd.UserIds)) + for i := range ugd.UserIds { + users = append(users, &apisecurity.User{ + Id: wrapperspb.String(i), + }) + } + return users +} + +// ToSpec 将用户ID Map 专为 slice +func (ugd *UserGroupDetail) ToSpec() *apisecurity.UserGroup { + if ugd == nil { + return nil + } + return &apisecurity.UserGroup{ + Id: wrapperspb.String(ugd.ID), + Name: wrapperspb.String(ugd.Name), + Owner: wrapperspb.String(ugd.Owner), + AuthToken: wrapperspb.String(ugd.Token), + TokenEnable: wrapperspb.Bool(ugd.TokenEnable), + Comment: wrapperspb.String(ugd.Comment), + Ctime: wrapperspb.String(commontime.Time2String(ugd.CreateTime)), + Mtime: wrapperspb.String(commontime.Time2String(ugd.ModifyTime)), + Relation: &apisecurity.UserGroupRelation{ + GroupId: wrapperspb.String(ugd.ID), + Users: ugd.ListSpecUser(), + }, + UserCount: wrapperspb.UInt32(uint32(len(ugd.UserIds))), + } +} + +// UserGroup 用户组 +type UserGroup struct { + ID string + Name string + Owner string + Token string + TokenEnable bool + Metadata map[string]string + Valid bool + Comment string + Source string + CreateTime time.Time + ModifyTime time.Time +} + +func (u *UserGroup) GetToken() string { + return u.Token +} + +func (u *UserGroup) Disable() bool { + return !u.TokenEnable +} + +func (u *UserGroup) OwnerID() string { + return u.Owner +} + +func (u *UserGroup) SelfID() string { + return u.ID +} + +// ModifyUserGroup 用户组修改 +type ModifyUserGroup struct { + ID string + Owner string + Token string + TokenEnable bool + Comment string + Metadata map[string]string + AddUserIds []string + RemoveUserIds []string +} + +// UserGroupRelation 用户-用户组关联关系具体信息 +type UserGroupRelation struct { + GroupID string + UserIds []string + CreateTime time.Time + ModifyTime time.Time +} + +type Condition struct { + Key string + Value string + CompareFunc string +} + +// StrategyDetail 鉴权策略详细 +type StrategyDetail struct { + ID string + Name string + // Action: 只有 allow 以及 deny + Action string + Comment string + Default bool + Owner string + // CalleeMethods 允许访问的服务端接口 + CalleeMethods []string + Resources []StrategyResource + Conditions []Condition + Principals []Principal + Valid bool + Revision string + Metadata map[string]string + CreateTime time.Time + ModifyTime time.Time +} + +func (s *StrategyDetail) IsDeny() bool { + return s.Action == apisecurity.AuthAction_DENY.String() +} + +func NewPolicyDetailCache(d *StrategyDetail) *PolicyDetailCache { + users := make(map[string]Principal, 0) + groups := make(map[string]Principal, 0) + + for index := range d.Principals { + principal := d.Principals[index] + if principal.PrincipalType == PrincipalUser { + users[principal.PrincipalID] = principal + } else { + groups[principal.PrincipalID] = principal + } + } + + resources := map[apisecurity.ResourceType]*utils.SyncSet[string]{} + for index := range d.Resources { + resource := d.Resources[index] + resType := apisecurity.ResourceType(resource.ResType) + if _, ok := resources[resType]; !ok { + resources[resType] = utils.NewSyncSet[string]() + } + resources[resType].Add(resource.ResID) + } + + return &PolicyDetailCache{ + StrategyDetail: d, + UserPrincipal: users, + GroupPrincipal: groups, + ResourceDict: resources, + } +} + +// PolicyDetailCache 鉴权策略详细 +type PolicyDetailCache struct { + *StrategyDetail + UserPrincipal map[string]Principal + GroupPrincipal map[string]Principal + ResourceDict map[apisecurity.ResourceType]*utils.SyncSet[string] +} + +// ModifyStrategyDetail 修改鉴权策略详细 +type ModifyStrategyDetail struct { + ID string + Name string + Action string + Comment string + Metadata map[string]string + AddPrincipals []Principal + RemovePrincipals []Principal + AddResources []StrategyResource + RemoveResources []StrategyResource + ModifyTime time.Time +} + +// Strategy 策略main信息 +type Strategy struct { + ID string + Name string + Principal string + Action string + Comment string + Owner string + Default bool + Valid bool + CreateTime time.Time + ModifyTime time.Time +} + +// StrategyResource 策略资源 +type StrategyResource struct { + StrategyID string + ResType int32 + ResID string +} + +// Principal 策略相关人 +type Principal struct { + StrategyID string + Name string + Owner string + PrincipalID string + PrincipalType PrincipalType +} + +func (p Principal) String() string { + return fmt.Sprintf("%s/%s", p.PrincipalType.String(), p.PrincipalID) +} + // ParseUserRole 从ctx中解析用户角色 -func ParseUserRole(ctx context.Context) model.UserRoleType { +func ParseUserRole(ctx context.Context) UserRoleType { if ctx == nil { - return model.SubAccountUserRole + return SubAccountUserRole } - role, _ := ctx.Value(utils.ContextUserRoleIDKey).(model.UserRoleType) + role, _ := ctx.Value(utils.ContextUserRoleIDKey).(UserRoleType) return role } + +type Role struct { + ID string + Name string + Owner string + Source string + Type string + Metadata map[string]string + Valid bool + Comment string + CreateTime time.Time + ModifyTime time.Time + Users []*User + UserGroups []*UserGroup +} diff --git a/common/model/auth/const.go b/common/model/auth/const.go new file mode 100644 index 000000000..88ecf139d --- /dev/null +++ b/common/model/auth/const.go @@ -0,0 +1,244 @@ +/** + * Tencent is pleased to support the open source community by making Polaris available. + * + * Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. + * + * Licensed under the BSD 3-Clause License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://opensource.org/licenses/BSD-3-Clause + * + * Unless required by applicable law or agreed to in writing, software distributed + * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR + * CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package auth + +import ( + apisecurity "github.com/polarismesh/specification/source/go/api/v1/security" +) + +type ServerFunctionName string + +// SDK 接口 +const ( + // 注册发现接口 + RegisterInstance ServerFunctionName = "RegisterInstance" + DeregisterInstance ServerFunctionName = "DeregisterInstance" + ReportServiceContract ServerFunctionName = "ReportServiceContract" + DiscoverServices ServerFunctionName = "DiscoverServices" + DiscoverInstances ServerFunctionName = "DiscoverInstances" + UpdateInstance ServerFunctionName = "UpdateInstance" + + // 服务治理接口 + DiscoverRouterRule ServerFunctionName = "DiscoverRouterRule" + DiscoverRateLimitRule ServerFunctionName = "DiscoverRateLimitRule" + DiscoverCircuitBreakerRule ServerFunctionName = "DiscoverCircuitBreakerRule" + DiscoverFaultDetectRule ServerFunctionName = "DiscoverFaultDetectRule" + DiscoverServiceContract ServerFunctionName = "DiscoverServiceContract" + DiscoverLaneRule ServerFunctionName = "DiscoverLaneRule" + + // 配置接口 + DiscoverConfigFile ServerFunctionName = "DiscoverConfigFile" + WatchConfigFile ServerFunctionName = "WatchConfigFile" + DiscoverConfigFileNames ServerFunctionName = "DiscoverConfigFileNames" + DiscoverConfigGroups ServerFunctionName = "DiscoverConfigGroups" +) + +// 命名空间 +const ( + CreateNamespace ServerFunctionName = "CreateNamespace" + CreateNamespaces ServerFunctionName = "CreateNamespaces" + DeleteNamespace ServerFunctionName = "DeleteNamespace" + DeleteNamespaces ServerFunctionName = "DeleteNamespaces" + UpdateNamespaces ServerFunctionName = "UpdateNamespaces" + UpdateNamespaceToken ServerFunctionName = "UpdateNamespaceToken" + DescribeNamespaces ServerFunctionName = "DescribeNamespaces" + DescribeNamespaceToken ServerFunctionName = "DescribeNamespaceToken" +) + +// 服务/服务别名 +const ( + CreateServices ServerFunctionName = "CreateServices" + DeleteServices ServerFunctionName = "DeleteServices" + UpdateServices ServerFunctionName = "UpdateServices" + UpdateServiceToken ServerFunctionName = "UpdateServiceToken" + DescribeAllServices ServerFunctionName = "DescribeAllServices" + DescribeServices ServerFunctionName = "DescribeServices" + DescribeServicesCount ServerFunctionName = "DescribeServicesCount" + DescribeServiceToken ServerFunctionName = "DescribeServiceToken" + DescribeServiceOwner ServerFunctionName = "DescribeServiceOwner" + + CreateServiceAlias ServerFunctionName = "CreateServiceAlias" + DeleteServiceAliases ServerFunctionName = "DeleteServiceAliases" + UpdateServiceAlias ServerFunctionName = "UpdateServiceAlias" + DescribeServiceAliases ServerFunctionName = "DescribeServiceAliases" +) + +// 服务接口定义 +const ( + CreateServiceContracts ServerFunctionName = "CreateServiceContracts" + DescribeServiceContracts ServerFunctionName = "DescribeServiceContracts" + DescribeServiceContractVersions ServerFunctionName = "DescribeServiceContractVersions" + DeleteServiceContracts ServerFunctionName = "DeleteServiceContracts" + + CreateServiceContractInterfaces ServerFunctionName = "CreateServiceContractInterfaces" + AppendServiceContractInterfaces ServerFunctionName = "AppendServiceContractInterfaces" + DeleteServiceContractInterfaces ServerFunctionName = "DeleteServiceContractInterfaces" +) + +// 服务实例 +const ( + CreateInstances ServerFunctionName = "CreateInstances" + DeleteInstances ServerFunctionName = "DeleteInstances" + DeleteInstancesByHost ServerFunctionName = "DeleteInstancesByHost" + UpdateInstances ServerFunctionName = "UpdateInstances" + UpdateInstancesIsolate ServerFunctionName = "UpdateInstancesIsolate" + DescribeInstances ServerFunctionName = "DescribeInstances" + DescribeInstancesCount ServerFunctionName = "DescribeInstancesCount" + DescribeInstanceLabels ServerFunctionName = "DescribeInstanceLabels" + CleanInstance ServerFunctionName = "CleanInstance" + BatchCleanInstances ServerFunctionName = "BatchCleanInstances" + DescribeInstanceLastHeartbeat ServerFunctionName = "DescribeInstanceLastHeartbeat" +) + +// 配置 +const ( + // 配置分组 + CreateConfigFileGroup ServerFunctionName = "CreateConfigFileGroup" + DeleteConfigFileGroup ServerFunctionName = "DeleteConfigFileGroup" + UpdateConfigFileGroup ServerFunctionName = "UpdateConfigFileGroup" + DescribeConfigFileGroups ServerFunctionName = "DescribeConfigFileGroups" + + // 配置文件 + PublishConfigFile ServerFunctionName = "PublishConfigFile" + CreateConfigFile ServerFunctionName = "CreateConfigFile" + UpdateConfigFile ServerFunctionName = "UpdateConfigFile" + DeleteConfigFile ServerFunctionName = "DeleteConfigFile" + DescribeConfigFileRichInfo ServerFunctionName = "DescribeConfigFileRichInfo" + DescribeConfigFiles ServerFunctionName = "DescribeConfigFiles" + BatchDeleteConfigFiles ServerFunctionName = "BatchDeleteConfigFiles" + ExportConfigFiles ServerFunctionName = "ExportConfigFiles" + ImportConfigFiles ServerFunctionName = "ImportConfigFiles" + + // 配置发布历史 + DescribeConfigFileReleaseHistories ServerFunctionName = "DescribeConfigFileReleaseHistories" + + // 配置发布 + RollbackConfigFileReleases ServerFunctionName = "RollbackConfigFileReleases" + DeleteConfigFileReleases ServerFunctionName = "DeleteConfigFileReleases" + StopGrayConfigFileReleases ServerFunctionName = "StopGrayConfigFileReleases" + DescribeConfigFileRelease ServerFunctionName = "DescribeConfigFileRelease" + DescribeConfigFileReleases ServerFunctionName = "DescribeConfigFileReleases" + DescribeConfigFileReleaseVersions ServerFunctionName = "DescribeConfigFileReleaseVersions" + UpsertAndReleaseConfigFile ServerFunctionName = "UpsertAndReleaseConfigFile" + + // 配置模板 + DescribeAllConfigFileTemplates ServerFunctionName = "DescribeAllConfigFileTemplates" + DescribeConfigFileTemplate ServerFunctionName = "DescribeConfigFileTemplate" + CreateConfigFileTemplate ServerFunctionName = "CreateConfigFileTemplate" +) + +// 路由 +const ( + CreateRouteRules ServerFunctionName = "CreateRouteRules" + DeleteRouteRules ServerFunctionName = "DeleteRouteRules" + UpdateRouteRules ServerFunctionName = "UpdateRouteRules" + EnableRouteRules ServerFunctionName = "EnableRouteRules" + DescribeRouteRules ServerFunctionName = "DescribeRouteRules" +) + +// 限流 +const ( + CreateRateLimitRules ServerFunctionName = "CreateRateLimitRules" + DeleteRateLimitRules ServerFunctionName = "DeleteRateLimitRules" + UpdateRateLimitRules ServerFunctionName = "UpdateRateLimitRules" + EnableRateLimitRules ServerFunctionName = "EnableRateLimitRules" + DescribeRateLimitRules ServerFunctionName = "DescribeRateLimitRules" +) + +// 熔断 +const ( + CreateCircuitBreakerRules ServerFunctionName = "CreateCircuitBreakerRules" + DeleteCircuitBreakerRules ServerFunctionName = "DeleteCircuitBreakerRules" + EnableCircuitBreakerRules ServerFunctionName = "EnableCircuitBreakerRules" + UpdateCircuitBreakerRules ServerFunctionName = "UpdateCircuitBreakerRules" + DescribeCircuitBreakerRules ServerFunctionName = "DescribeCircuitBreakerRules" +) + +// 主动探测 +const ( + CreateFaultDetectRules ServerFunctionName = "CreateFaultDetectRules" + DeleteFaultDetectRules ServerFunctionName = "DeleteFaultDetectRules" + EnableFaultDetectRules ServerFunctionName = "EnableFaultDetectRules" + UpdateFaultDetectRules ServerFunctionName = "UpdateFaultDetectRules" + DescribeFaultDetectRules ServerFunctionName = "DescribeFaultDetectRules" +) + +// 全链路灰度 +const () + +// 用户/用户组 +const ( + // 用户 + CreateUsers ServerFunctionName = "CreateUsers" + DeleteUsers ServerFunctionName = "DeleteUsers" + DescribeUsers ServerFunctionName = "DescribeUsers" + DescribeUserToken ServerFunctionName = "DescribeUserToken" + EnableUserToken ServerFunctionName = "EnableUserToken" + ResetUserToken ServerFunctionName = "ResetUserToken" + UpdateUser ServerFunctionName = "UpdateUser" + UpdateUserPassword ServerFunctionName = "UpdateUserPassword" + + // 用户组 + CreateUserGroup ServerFunctionName = "CreateUserGroup" + UpdateUserGroups ServerFunctionName = "UpdateUserGroups" + DeleteUserGroups ServerFunctionName = "DeleteUserGroups" + DescribeUserGroups ServerFunctionName = "DescribeUserGroups" + DescribeUserGroupDetail ServerFunctionName = "DescribeUserGroupDetail" + DescribeUserGroupToken ServerFunctionName = "DescribeUserGroupToken" + EnableUserGroupToken ServerFunctionName = "EnableUserGroupToken" + ResetUserGroupToken ServerFunctionName = "ResetUserGroupToken" +) + +// 策略/角色 +const ( + // 策略 + CreateAuthPolicy ServerFunctionName = "CreateAuthPolicy" + UpdateAuthPolicies ServerFunctionName = "UpdateAuthPolicies" + DeleteAuthPolicies ServerFunctionName = "DeleteAuthPolicies" + DescribeAuthPolicies ServerFunctionName = "DescribeAuthPolicies" + DescribeAuthPolicyDetail ServerFunctionName = "DescribeAuthPolicyDetail" + DescribePrincipalResources ServerFunctionName = "DescribePrincipalResources" + + // 角色 + CreateAuthRoles ServerFunctionName = "CreateAuthRoles" + UpdateAuthRoles ServerFunctionName = "UpdateAuthRoles" + DeleteAuthRoles ServerFunctionName = "DeleteAuthRoles" + DescribeAuthRoles ServerFunctionName = "DescribeAuthRoles" + DescribeAuthRoleDetail ServerFunctionName = "DescribeAuthRoleDetail" +) + +// 运维接口 +const ( + DescribeServerConnections ServerFunctionName = "DescribeServerConnections" + DescribeServerConnStats ServerFunctionName = "DescribeServerConnStats" + CloseConnections ServerFunctionName = "CloseConnections" + FreeOSMemory ServerFunctionName = "FreeOSMemory" + DescribeLeaderElections ServerFunctionName = "DescribeLeaderElections" + ReleaseLeaderElection ServerFunctionName = "ReleaseLeaderElection" + DescribeGetLogOutputLevel ServerFunctionName = "DescribeGetLogOutputLevel" + UpdateLogOutputLevel ServerFunctionName = "UpdateLogOutputLevel" + DescribeCMDBInfo ServerFunctionName = "DescribeCMDBInfo" +) + +var ( + SearchTypeMapping = map[string]apisecurity.ResourceType{ + "0": apisecurity.ResourceType_Namespaces, + "1": apisecurity.ResourceType_Services, + "2": apisecurity.ResourceType_ConfigGroups, + } +) diff --git a/common/model/auth/container.go b/common/model/auth/container.go new file mode 100644 index 000000000..4833b3108 --- /dev/null +++ b/common/model/auth/container.go @@ -0,0 +1,98 @@ +/** + * Tencent is pleased to support the open source community by making Polaris available. + * + * Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. + * + * Licensed under the BSD 3-Clause License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://opensource.org/licenses/BSD-3-Clause + * + * Unless required by applicable law or agreed to in writing, software distributed + * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR + * CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package auth + +import ( + "github.com/polarismesh/polaris/common/utils" + apisecurity "github.com/polarismesh/specification/source/go/api/v1/security" +) + +// PrincipalResourceContainer principal 资源容器 +type PrincipalResourceContainer struct { + denyResources *utils.SyncMap[apisecurity.ResourceType, *utils.RefSyncSet[string]] + allowResources *utils.SyncMap[apisecurity.ResourceType, *utils.RefSyncSet[string]] +} + +// NewPrincipalResourceContainer 创建 PrincipalResourceContainer 对象 +func NewPrincipalResourceContainer() *PrincipalResourceContainer { + return &PrincipalResourceContainer{ + allowResources: utils.NewSyncMap[apisecurity.ResourceType, *utils.RefSyncSet[string]](), + denyResources: utils.NewSyncMap[apisecurity.ResourceType, *utils.RefSyncSet[string]](), + } +} + +// Hint 返回该资源命中的策略类型, 优先匹配 deny, 其次匹配 allow, 否则返回 deny +func (p *PrincipalResourceContainer) Hint(rt apisecurity.ResourceType, resId string) (apisecurity.AuthAction, bool) { + ids, ok := p.denyResources.Load(rt) + if ok { + if ids.Contains(resId) { + return apisecurity.AuthAction_DENY, true + } + } + ids, ok = p.allowResources.Load(rt) + if ok { + if ids.Contains(resId) { + return apisecurity.AuthAction_ALLOW, true + } + } + return 0, false +} + +// SaveAllowResource 保存允许的资源 +func (p *PrincipalResourceContainer) SaveAllowResource(r StrategyResource) { + p.saveResource(p.allowResources, r) +} + +// DelAllowResource 删除允许的资源 +func (p *PrincipalResourceContainer) DelAllowResource(r StrategyResource) { + p.delResource(p.allowResources, r) +} + +// SaveDenyResource 保存拒绝的资源 +func (p *PrincipalResourceContainer) SaveDenyResource(r StrategyResource) { + p.saveResource(p.denyResources, r) +} + +// DelDenyResource 删除拒绝的资源 +func (p *PrincipalResourceContainer) DelDenyResource(r StrategyResource) { + p.delResource(p.denyResources, r) +} + +func (p *PrincipalResourceContainer) saveResource( + container *utils.SyncMap[apisecurity.ResourceType, *utils.RefSyncSet[string]], res StrategyResource) { + + resType := apisecurity.ResourceType(res.ResType) + container.ComputeIfAbsent(resType, func(k apisecurity.ResourceType) *utils.RefSyncSet[string] { + return utils.NewRefSyncSet[string]() + }) + + ids, _ := container.Load(resType) + ids.Add(res.ResID) +} + +func (p *PrincipalResourceContainer) delResource( + container *utils.SyncMap[apisecurity.ResourceType, *utils.RefSyncSet[string]], r StrategyResource) { + + resType := apisecurity.ResourceType(r.ResType) + container.ComputeIfAbsent(resType, func(k apisecurity.ResourceType) *utils.RefSyncSet[string] { + return utils.NewRefSyncSet[string]() + }) + + ids, _ := container.Load(resType) + ids.Remove(r.ResID) +} diff --git a/common/model/naming.go b/common/model/naming.go index b949a6c4a..aa3af6ab9 100644 --- a/common/model/naming.go +++ b/common/model/naming.go @@ -19,6 +19,7 @@ package model import ( "context" + "errors" "fmt" "strconv" "strings" @@ -35,6 +36,13 @@ import ( "github.com/polarismesh/polaris/common/utils" ) +var ( + // ErrorNoNamespace 没有找到对应的命名空间 + ErrorNoNamespace error = errors.New("no such namespace") + // ErrorNoService 没有找到对应的服务 + ErrorNoService error = errors.New("no such service") +) + func ExportToMap(exportTo []*wrappers.StringValue) map[string]struct{} { ret := make(map[string]struct{}) for _, v := range exportTo { @@ -54,6 +62,8 @@ type Namespace struct { ModifyTime time.Time // ServiceExportTo 服务可见性设置 ServiceExportTo map[string]struct{} + // Metadata 命名空间元数据 + Metadata map[string]string } func (n *Namespace) ListServiceExportTo() []*wrappers.StringValue { diff --git a/common/model/routing.go b/common/model/routing.go index 1fd9e4d97..85a3b9ff8 100644 --- a/common/model/routing.go +++ b/common/model/routing.go @@ -20,7 +20,6 @@ package model import ( "encoding/json" "fmt" - "sort" "strings" "time" @@ -35,6 +34,13 @@ import ( "github.com/polarismesh/polaris/common/utils" ) +type TrafficDirection string + +const ( + TrafficDirection_INBOUND TrafficDirection = "TrafficDirection_INBOUND" + TrafficDirection_OUTBOUND TrafficDirection = "TrafficDirection_OUTBOUND" +) + const ( // V2RuleIDKey v2 版本的规则路由 ID V2RuleIDKey = "__routing_v2_id__" @@ -55,14 +61,18 @@ var ( RuleRoutingTypeUrl string // MetaRoutingTypeUrl 记录 anypb.Any 中关于 MetadataRoutingConfig 的 url 信息 MetaRoutingTypeUrl string + // NearbyRoutingTypeUrl 记录 anypb.Any 中关于 NearbyRoutingConfig 的 url 信息 + NearbyRoutingTypeUrl string ) func init() { ruleAny, _ := ptypes.MarshalAny(&apitraffic.RuleRoutingConfig{}) metaAny, _ := ptypes.MarshalAny(&apitraffic.MetadataRoutingConfig{}) + nearbyAny, _ := ptypes.MarshalAny(&apitraffic.NearbyRoutingConfig{}) RuleRoutingTypeUrl = ruleAny.GetTypeUrl() MetaRoutingTypeUrl = metaAny.GetTypeUrl() + NearbyRoutingTypeUrl = nearbyAny.GetTypeUrl() } /* @@ -91,9 +101,11 @@ type ExtendRouterConfig struct { // MetadataRouting 元数据路由配置 MetadataRouting *apitraffic.MetadataRoutingConfig // RuleRouting 规则路由配置 - RuleRouting *apitraffic.RuleRoutingConfig - // ExtendInfo 额外信息数据 - ExtendInfo map[string]string + RuleRouting *RuleRoutingConfigWrapper + // NearbyRouting 就近路由规则数据 + NearbyRouting *apitraffic.NearbyRoutingConfig + // Metadata . + Metadata map[string]string } // ToApi Turn to API object @@ -103,13 +115,19 @@ func (r *ExtendRouterConfig) ToApi() (*apitraffic.RouteRule, error) { err error ) - if r.GetRoutingPolicy() == apitraffic.RoutingPolicy_MetadataPolicy { + switch r.GetRoutingPolicy() { + case apitraffic.RoutingPolicy_RulePolicy: + anyValue, err = ptypes.MarshalAny(r.NearbyRouting) + if err != nil { + return nil, err + } + case apitraffic.RoutingPolicy_MetadataPolicy: anyValue, err = ptypes.MarshalAny(r.MetadataRouting) if err != nil { return nil, err } - } else { - anyValue, err = ptypes.MarshalAny(r.RuleRouting) + default: + anyValue, err = ptypes.MarshalAny(r.RuleRouting.RuleRouting) if err != nil { return nil, err } @@ -137,6 +155,13 @@ func (r *ExtendRouterConfig) ToApi() (*apitraffic.RouteRule, error) { return rule, nil } +type RuleRoutingConfigWrapper struct { + Caller ServiceKey + Callee ServiceKey + // RuleRouting 规则路由配置 + RuleRouting *apitraffic.RuleRoutingConfig +} + // RouterConfig Routing rules type RouterConfig struct { // ID The unique id of the rules @@ -198,8 +223,7 @@ func (r *RouterConfig) ToExpendRoutingConfig() (*ExtendRouterConfig, error) { if err = utils.UnmarshalFromJsonString(rule, configText); nil != err { return nil, err } - parseSubRouteRule(rule) - ret.RuleRouting = rule + ret.RuleRouting = parseSubRouteRule(rule) break case apitraffic.RoutingPolicy_MetadataPolicy: rule := &apitraffic.MetadataRoutingConfig{} @@ -208,6 +232,13 @@ func (r *RouterConfig) ToExpendRoutingConfig() (*ExtendRouterConfig, error) { } ret.MetadataRouting = rule break + case apitraffic.RoutingPolicy_NearbyPolicy: + rule := &apitraffic.NearbyRoutingConfig{} + if err = utils.UnmarshalFromJsonString(rule, configText); nil != err { + return nil, err + } + ret.NearbyRouting = rule + break } return ret, nil } @@ -231,8 +262,7 @@ func (r *RouterConfig) parseBinaryAnyMessage( if err := unmarshalToAny(anyMsg, rule); nil != err { return err } - parseSubRouteRule(rule) - ret.RuleRouting = rule + ret.RuleRouting = parseSubRouteRule(rule) case apitraffic.RoutingPolicy_MetadataPolicy: rule := &apitraffic.MetadataRoutingConfig{} anyMsg := &anypb.Any{ @@ -243,6 +273,16 @@ func (r *RouterConfig) parseBinaryAnyMessage( return err } ret.MetadataRouting = rule + case apitraffic.RoutingPolicy_NearbyPolicy: + rule := &apitraffic.NearbyRoutingConfig{} + anyMsg := &anypb.Any{ + TypeUrl: NearbyRoutingTypeUrl, + Value: []byte(r.Config), + } + if err := unmarshalToAny(anyMsg, rule); nil != err { + return err + } + ret.NearbyRouting = rule } return nil } @@ -297,13 +337,17 @@ func ParseRouteRuleAnyToMessage(policy apitraffic.RoutingPolicy, anyMessage *any return nil, err } break - default: + case apitraffic.RoutingPolicy_NearbyPolicy: + rule = &apitraffic.NearbyRoutingConfig{} + if err := unmarshalToAny(anyMessage, rule); err != nil { + return nil, err + } break } return rule, nil } -func parseSubRouteRule(ruleRouting *apitraffic.RuleRoutingConfig) { +func parseSubRouteRule(ruleRouting *apitraffic.RuleRoutingConfig) *RuleRoutingConfigWrapper { if len(ruleRouting.Rules) == 0 { subRule := &apitraffic.SubRuleRouting{ Name: "", @@ -327,6 +371,28 @@ func parseSubRouteRule(ruleRouting *apitraffic.RuleRoutingConfig) { ruleRouting.Destinations = nil ruleRouting.Sources = nil } + + wrapper := &RuleRoutingConfigWrapper{ + RuleRouting: ruleRouting, + } + + for i := range ruleRouting.Rules { + item := ruleRouting.Rules[i] + source := item.Sources[0] + destination := item.Destinations[0] + + wrapper.Caller = ServiceKey{ + Namespace: source.Namespace, + Name: source.Service, + } + wrapper.Callee = ServiceKey{ + Namespace: destination.Namespace, + Name: destination.Service, + } + break + } + + return wrapper } const ( @@ -372,259 +438,6 @@ func RoutingConfigV1ToAPI(req *RoutingConfig, service string, namespace string) return out, nil } -// CompositeRoutingV1AndV2 The routing rules of the V1 version and the rules of the V2 version -func CompositeRoutingV1AndV2(v1rule *apitraffic.Routing, level1, level2, - level3 []*ExtendRouterConfig) (*apitraffic.Routing, []string) { - sort.Slice(level1, func(i, j int) bool { - return CompareRoutingV2(level1[i], level1[j]) - }) - - sort.Slice(level2, func(i, j int) bool { - return CompareRoutingV2(level2[i], level2[j]) - }) - - sort.Slice(level3, func(i, j int) bool { - return CompareRoutingV2(level3[i], level3[j]) - }) - - level1inRoutes, level1outRoutes, level1Revisions := - BuildV1RoutesFromV2(v1rule.Service.Value, v1rule.Namespace.Value, level1) - level2inRoutes, level2outRoutes, level2Revisions := - BuildV1RoutesFromV2(v1rule.Service.Value, v1rule.Namespace.Value, level2) - level3inRoutes, level3outRoutes, level3Revisions := - BuildV1RoutesFromV2(v1rule.Service.Value, v1rule.Namespace.Value, level3) - - inBounds := v1rule.GetInbounds() - outBounds := v1rule.GetOutbounds() - - // Processing inbounds rules,level1 cache -> v1rules -> level2 cache -> level3 cache - if len(level1inRoutes) > 0 { - v1rule.Inbounds = append(level1inRoutes, inBounds...) - } - if len(level2inRoutes) > 0 { - v1rule.Inbounds = append(v1rule.Inbounds, level2inRoutes...) - } - if len(level3inRoutes) > 0 { - v1rule.Inbounds = append(v1rule.Inbounds, level3inRoutes...) - } - - // Processing OutBounds rules,level1 cache -> v1rules -> level2 cache -> level3 cache - if len(level1outRoutes) > 0 { - v1rule.Outbounds = append(level1outRoutes, outBounds...) - } - if len(level2outRoutes) > 0 { - v1rule.Outbounds = append(v1rule.Outbounds, level2outRoutes...) - } - if len(level3outRoutes) > 0 { - v1rule.Outbounds = append(v1rule.Outbounds, level3outRoutes...) - } - - revisions := make([]string, 0, 1+len(level1Revisions)+len(level2Revisions)+len(level3Revisions)) - revisions = append(revisions, v1rule.GetRevision().GetValue()) - if len(level1Revisions) > 0 { - revisions = append(revisions, level1Revisions...) - } - if len(level2Revisions) > 0 { - revisions = append(revisions, level2Revisions...) - } - if len(level3Revisions) > 0 { - revisions = append(revisions, level3Revisions...) - } - - return v1rule, revisions -} - -// BuildV1RoutesFromV2 According to the routing rules of the V2 version, it is adapted to the V1 version -// of the routing rules. -// return inBound outBound revisions -func BuildV1RoutesFromV2(service, namespace string, - entries []*ExtendRouterConfig) ([]*apitraffic.Route, []*apitraffic.Route, []string) { - if len(entries) == 0 { - return []*apitraffic.Route{}, []*apitraffic.Route{}, []string{} - } - - revisions := make([]string, 0, len(entries)) - outRoutes := make([]*apitraffic.Route, 0, 8) - inRoutes := make([]*apitraffic.Route, 0, 8) - for i := range entries { - if !entries[i].Enable { - continue - } - outRoutes = append(outRoutes, BuildOutBoundsFromV2(service, namespace, entries[i])...) - inRoutes = append(inRoutes, BuildInBoundsFromV2(service, namespace, entries[i])...) - revisions = append(revisions, entries[i].Revision) - } - - return inRoutes, outRoutes, revisions -} - -// BuildOutBoundsFromV2 According to the routing rules of the V2 version, it is adapted to the -// outbounds in the routing rule of V1 version -func BuildOutBoundsFromV2(service, namespace string, item *ExtendRouterConfig) []*apitraffic.Route { - if item.GetRoutingPolicy() != apitraffic.RoutingPolicy_RulePolicy { - return []*apitraffic.Route{} - } - - var find bool - - matchService := func(source *apitraffic.SourceService) bool { - if source.Service == service && source.Namespace == namespace { - return true - } - if source.Namespace == namespace && source.Service == MatchAll { - return true - } - if source.Namespace == MatchAll && source.Service == MatchAll { - return true - } - return false - } - - routes := make([]*apitraffic.Route, 0, 8) - for i := range item.RuleRouting.Rules { - subRule := item.RuleRouting.Rules[i] - sources := item.RuleRouting.Rules[i].Sources - v1sources := make([]*apitraffic.Source, 0, len(sources)) - for i := range sources { - if matchService(sources[i]) { - find = true - entry := &apitraffic.Source{ - Service: utils.NewStringValue(service), - Namespace: utils.NewStringValue(namespace), - } - entry.Metadata = RoutingArguments2Labels(sources[i].GetArguments()) - v1sources = append(v1sources, entry) - } - } - - if !find { - break - } - - destinations := item.RuleRouting.Rules[i].Destinations - v1destinations := make([]*apitraffic.Destination, 0, len(destinations)) - for i := range destinations { - name := fmt.Sprintf("%s.%s.%s", item.Name, subRule.Name, destinations[i].Name) - entry := &apitraffic.Destination{ - Name: utils.NewStringValue(name), - Service: utils.NewStringValue(destinations[i].Service), - Namespace: utils.NewStringValue(destinations[i].Namespace), - Priority: utils.NewUInt32Value(destinations[i].GetPriority()), - Weight: utils.NewUInt32Value(destinations[i].GetWeight()), - Transfer: utils.NewStringValue(destinations[i].GetTransfer()), - Isolate: utils.NewBoolValue(destinations[i].GetIsolate()), - } - - v1labels := make(map[string]*apimodel.MatchString) - v2labels := destinations[i].GetLabels() - for index := range v2labels { - v1labels[index] = &apimodel.MatchString{ - Type: v2labels[index].GetType(), - Value: v2labels[index].GetValue(), - ValueType: v2labels[index].GetValueType(), - } - } - - entry.Metadata = v1labels - v1destinations = append(v1destinations, entry) - } - - routes = append(routes, &apitraffic.Route{ - Sources: v1sources, - Destinations: v1destinations, - ExtendInfo: map[string]string{ - V2RuleIDKey: item.ID, - }, - }) - } - - return routes -} - -// BuildInBoundsFromV2 Convert the routing rules of V2 to the inbounds in the routing rule of V1 -func BuildInBoundsFromV2(service, namespace string, item *ExtendRouterConfig) []*apitraffic.Route { - if item.GetRoutingPolicy() != apitraffic.RoutingPolicy_RulePolicy { - return []*apitraffic.Route{} - } - - var find bool - - matchService := func(destination *apitraffic.DestinationGroup) bool { - if destination.Service == service && destination.Namespace == namespace { - return true - } - if destination.Namespace == namespace && destination.Service == MatchAll { - return true - } - if destination.Namespace == MatchAll && destination.Service == MatchAll { - return true - } - return false - } - - routes := make([]*apitraffic.Route, 0, 8) - - for i := range item.RuleRouting.Rules { - subRule := item.RuleRouting.Rules[i] - destinations := item.RuleRouting.Rules[i].Destinations - v1destinations := make([]*apitraffic.Destination, 0, len(destinations)) - for i := range destinations { - if matchService(destinations[i]) { - find = true - name := fmt.Sprintf("%s.%s.%s", item.Name, subRule.Name, destinations[i].Name) - entry := &apitraffic.Destination{ - Name: utils.NewStringValue(name), - Service: utils.NewStringValue(service), - Namespace: utils.NewStringValue(namespace), - Priority: utils.NewUInt32Value(destinations[i].GetPriority()), - Weight: utils.NewUInt32Value(destinations[i].GetWeight()), - Transfer: utils.NewStringValue(destinations[i].GetTransfer()), - Isolate: utils.NewBoolValue(destinations[i].GetIsolate()), - } - - v1labels := make(map[string]*apimodel.MatchString) - v2labels := destinations[i].GetLabels() - for index := range v2labels { - v1labels[index] = &apimodel.MatchString{ - Type: v2labels[index].GetType(), - Value: v2labels[index].GetValue(), - ValueType: v2labels[index].GetValueType(), - } - } - - entry.Metadata = v1labels - v1destinations = append(v1destinations, entry) - } - } - - if !find { - break - } - - sources := item.RuleRouting.Rules[i].Sources - v1sources := make([]*apitraffic.Source, 0, len(sources)) - for i := range sources { - entry := &apitraffic.Source{ - Service: utils.NewStringValue(sources[i].Service), - Namespace: utils.NewStringValue(sources[i].Namespace), - } - - entry.Metadata = RoutingArguments2Labels(sources[i].GetArguments()) - v1sources = append(v1sources, entry) - } - - routes = append(routes, &apitraffic.Route{ - Sources: v1sources, - Destinations: v1destinations, - ExtendInfo: map[string]string{ - V2RuleIDKey: item.ID, - }, - }) - } - - return routes -} - // RoutingLabels2Arguments Adapting the old label model into a list of parameters func RoutingLabels2Arguments(labels map[string]*apimodel.MatchString) []*apitraffic.SourceMatch { if len(labels) == 0 { @@ -729,7 +542,7 @@ func BuildV2ExtendRouting(req *apitraffic.Routing, route *apitraffic.Route) (*Ex Revision: req.GetRevision().GetValue(), Priority: 0, }, - RuleRouting: convertV1RouteToV2Route(route), + RuleRouting: parseSubRouteRule(convertV1RouteToV2Route(route)), } return routing, nil @@ -821,7 +634,7 @@ func ConvertRoutingV1ToExtendV2(svcName, svcNamespace string, routing.CreateTime = rule.CreateTime routing.ModifyTime = rule.ModifyTime routing.EnableTime = rule.CreateTime - routing.ExtendInfo = map[string]string{ + routing.Metadata = map[string]string{ V1RuleIDKey: rule.ID, V1RuleRouteIndexKey: fmt.Sprintf("%d", i), V1RuleRouteTypeKey: V1RuleInRoute, @@ -857,7 +670,7 @@ func ConvertRoutingV1ToExtendV2(svcName, svcNamespace string, routing.CreateTime = rule.CreateTime routing.ModifyTime = rule.ModifyTime routing.EnableTime = rule.CreateTime - routing.ExtendInfo = map[string]string{ + routing.Metadata = map[string]string{ V1RuleIDKey: rule.ID, V1RuleRouteIndexKey: fmt.Sprintf("%d", i), V1RuleRouteTypeKey: V1RuleOutRoute, @@ -876,3 +689,140 @@ func ConvertRoutingV1ToExtendV2(svcName, svcNamespace string, return inRet, outRet, nil } + +func BuildRoutes(item *ExtendRouterConfig, direction TrafficDirection) []*apitraffic.Route { + switch direction { + case TrafficDirection_INBOUND: + return BuildInBoundsRoute(item) + default: + return BuildOutBoundsRoutes(item) + } +} + +// BuildInBoundsRoute Convert the routing rules of V2 to the inbounds in the routing rule of V1 +func BuildInBoundsRoute(item *ExtendRouterConfig) []*apitraffic.Route { + if item.GetRoutingPolicy() != apitraffic.RoutingPolicy_RulePolicy { + return []*apitraffic.Route{} + } + + routes := make([]*apitraffic.Route, 0, 8) + + specRules := item.RuleRouting.RuleRouting.Rules + + for i := range specRules { + subRule := specRules[i] + destinations := specRules[i].Destinations + v1destinations := make([]*apitraffic.Destination, 0, len(destinations)) + for i := range destinations { + name := fmt.Sprintf("%s.%s.%s", item.Name, subRule.Name, destinations[i].Name) + entry := &apitraffic.Destination{ + Name: utils.NewStringValue(name), + Service: utils.NewStringValue(item.RuleRouting.Callee.Name), + Namespace: utils.NewStringValue(item.RuleRouting.Callee.Namespace), + Priority: utils.NewUInt32Value(destinations[i].GetPriority()), + Weight: utils.NewUInt32Value(destinations[i].GetWeight()), + Transfer: utils.NewStringValue(destinations[i].GetTransfer()), + Isolate: utils.NewBoolValue(destinations[i].GetIsolate()), + } + + v1labels := make(map[string]*apimodel.MatchString) + v2labels := destinations[i].GetLabels() + for index := range v2labels { + v1labels[index] = &apimodel.MatchString{ + Type: v2labels[index].GetType(), + Value: v2labels[index].GetValue(), + ValueType: v2labels[index].GetValueType(), + } + } + + entry.Metadata = v1labels + v1destinations = append(v1destinations, entry) + } + + sources := specRules[i].Sources + v1sources := make([]*apitraffic.Source, 0, len(sources)) + for i := range sources { + entry := &apitraffic.Source{ + Service: utils.NewStringValue(sources[i].Service), + Namespace: utils.NewStringValue(sources[i].Namespace), + } + + entry.Metadata = RoutingArguments2Labels(sources[i].GetArguments()) + v1sources = append(v1sources, entry) + } + + routes = append(routes, &apitraffic.Route{ + Sources: v1sources, + Destinations: v1destinations, + ExtendInfo: map[string]string{ + V2RuleIDKey: item.ID, + }, + }) + } + + return routes +} + +// BuildOutBoundsRoutes According to the routing rules of the V2 version, it is adapted to the +// outbounds in the routing rule of V1 version +func BuildOutBoundsRoutes(item *ExtendRouterConfig) []*apitraffic.Route { + if item.GetRoutingPolicy() != apitraffic.RoutingPolicy_RulePolicy { + return []*apitraffic.Route{} + } + + routes := make([]*apitraffic.Route, 0, 8) + + specRules := item.RuleRouting.RuleRouting.Rules + + for i := range specRules { + subRule := specRules[i] + sources := specRules[i].Sources + v1sources := make([]*apitraffic.Source, 0, len(sources)) + for i := range sources { + entry := &apitraffic.Source{ + Service: utils.NewStringValue(item.RuleRouting.Caller.Name), + Namespace: utils.NewStringValue(item.RuleRouting.Caller.Namespace), + } + entry.Metadata = RoutingArguments2Labels(sources[i].GetArguments()) + v1sources = append(v1sources, entry) + } + + destinations := specRules[i].Destinations + v1destinations := make([]*apitraffic.Destination, 0, len(destinations)) + for i := range destinations { + name := fmt.Sprintf("%s.%s.%s", item.Name, subRule.Name, destinations[i].Name) + entry := &apitraffic.Destination{ + Name: utils.NewStringValue(name), + Service: utils.NewStringValue(destinations[i].Service), + Namespace: utils.NewStringValue(destinations[i].Namespace), + Priority: utils.NewUInt32Value(destinations[i].GetPriority()), + Weight: utils.NewUInt32Value(destinations[i].GetWeight()), + Transfer: utils.NewStringValue(destinations[i].GetTransfer()), + Isolate: utils.NewBoolValue(destinations[i].GetIsolate()), + } + + v1labels := make(map[string]*apimodel.MatchString) + v2labels := destinations[i].GetLabels() + for index := range v2labels { + v1labels[index] = &apimodel.MatchString{ + Type: v2labels[index].GetType(), + Value: v2labels[index].GetValue(), + ValueType: v2labels[index].GetValueType(), + } + } + + entry.Metadata = v1labels + v1destinations = append(v1destinations, entry) + } + + routes = append(routes, &apitraffic.Route{ + Sources: v1sources, + Destinations: v1destinations, + ExtendInfo: map[string]string{ + V2RuleIDKey: item.ID, + }, + }) + } + + return routes +} diff --git a/common/model/routing_test.go b/common/model/routing_test.go index 1c2296582..4545b2f00 100644 --- a/common/model/routing_test.go +++ b/common/model/routing_test.go @@ -65,7 +65,7 @@ func TestToExpendRoutingConfig(t *testing.T) { rConfig.Policy = "RulePolicy" erConfig, err := rConfig.ToExpendRoutingConfig() assert.Nil(t, err) - assert.Equal(t, ruleRouting.Sources[0].Service, erConfig.RuleRouting.Sources[0].Service) + assert.Equal(t, ruleRouting.Sources[0].Service, erConfig.RuleRouting.RuleRouting.Sources[0].Service) // 2. check v1 binary anyValue, err := anypb.New(proto.MessageV2(ruleRouting)) @@ -74,7 +74,7 @@ func TestToExpendRoutingConfig(t *testing.T) { rConfig.Config = v1AnyStr erConfig, err = rConfig.ToExpendRoutingConfig() assert.Nil(t, err) - assert.Equal(t, ruleRouting.Sources[0].Service, erConfig.RuleRouting.Sources[0].Service) + assert.Equal(t, ruleRouting.Sources[0].Service, erConfig.RuleRouting.RuleRouting.Sources[0].Service) // 3. check v2 binary //ruleRoutingV2 := &v2.RuleRoutingConfig{ diff --git a/common/utils/collection.go b/common/utils/collection.go index 597cccf07..534edec67 100644 --- a/common/utils/collection.go +++ b/common/utils/collection.go @@ -57,6 +57,90 @@ func (set *Set[K]) Range(fn func(val K)) { } } +// NewRefSyncSet returns a new Set +func NewRefSyncSet[K comparable]() *RefSyncSet[K] { + return &RefSyncSet[K]{ + container: make(map[K]int), + } +} + +type RefSyncSet[K comparable] struct { + container map[K]int + lock sync.RWMutex +} + +// Add adds a string to the set +func (set *RefSyncSet[K]) Add(val K) { + set.lock.Lock() + defer set.lock.Unlock() + + ref, ok := set.container[val] + if ok { + ref++ + } + set.container[val] = ref +} + +// Remove removes a string from the set +func (set *RefSyncSet[K]) Remove(val K) { + set.lock.Lock() + defer set.lock.Unlock() + ref, ok := set.container[val] + if ok { + ref-- + } + if ref == 0 { + delete(set.container, val) + } else { + set.container[val] = ref + } +} + +func (set *RefSyncSet[K]) ToSlice() []K { + set.lock.RLock() + defer set.lock.RUnlock() + + ret := make([]K, 0, len(set.container)) + for k := range set.container { + ret = append(ret, k) + } + return ret +} + +func (set *RefSyncSet[K]) Range(fn func(val K)) { + set.lock.RLock() + snapshot := map[K]struct{}{} + for k := range set.container { + snapshot[k] = struct{}{} + } + set.lock.RUnlock() + + for k := range snapshot { + fn(k) + } +} + +func (set *RefSyncSet[K]) Len() int { + set.lock.RLock() + defer set.lock.RUnlock() + + return len(set.container) +} + +// Contains contains target value +func (set *RefSyncSet[K]) Contains(val K) bool { + set.lock.Lock() + defer set.lock.Unlock() + + _, exist := set.container[val] + return exist +} + +func (set *RefSyncSet[K]) String() string { + ret := set.ToSlice() + return MustJson(ret) +} + // NewSyncSet returns a new Set func NewSyncSet[K comparable]() *SyncSet[K] { return &SyncSet[K]{ @@ -77,6 +161,15 @@ func (set *SyncSet[K]) Add(val K) { set.container[val] = struct{}{} } +// Add adds a string to the set +func (set *SyncSet[K]) AddAll(vals *SyncSet[K]) { + vals.Range(func(val K) { + set.lock.Lock() + defer set.lock.Unlock() + set.container[val] = struct{}{} + }) +} + // Remove removes a string from the set func (set *SyncSet[K]) Remove(val K) { set.lock.Lock() diff --git a/common/utils/common.go b/common/utils/common.go index fcf79b6e6..49595eb1b 100644 --- a/common/utils/common.go +++ b/common/utils/common.go @@ -87,6 +87,8 @@ const ( MaxDbCircuitbreakerOwner = 1024 MaxDbCircuitbreakerVersion = 32 + MaxRuleName = 64 + MaxPlatformIDLength = 32 MaxPlatformNameLength = 128 MaxPlatformDomainLength = 1024 diff --git a/common/utils/const.go b/common/utils/const.go index 91f7eebb7..0320376d1 100644 --- a/common/utils/const.go +++ b/common/utils/const.go @@ -73,4 +73,6 @@ const ( ContextIsFromSystem = StringContext("from-system") // ContextOperator operator info ContextOperator = StringContext("operator") + // ContextRequestHeaders request headers + ContextRequestHeaders = StringContext("request-headers") ) diff --git a/common/utils/funcs.go b/common/utils/funcs.go index f4d3a90ae..3fe5928ff 100644 --- a/common/utils/funcs.go +++ b/common/utils/funcs.go @@ -249,6 +249,7 @@ func ConvertGRPCContext(ctx context.Context) context.Context { ctx = context.Background() ctx = context.WithValue(ctx, ContextGrpcHeader, meta) + ctx = context.WithValue(ctx, ContextRequestHeaders, meta) ctx = context.WithValue(ctx, StringContext("request-id"), requestID) ctx = context.WithValue(ctx, StringContext("client-ip"), clientIP) ctx = context.WithValue(ctx, ContextClientAddress, address) diff --git a/config/api.go b/config/api.go index 9769462a6..432871f77 100644 --- a/config/api.go +++ b/config/api.go @@ -79,12 +79,8 @@ type ConfigFileReleaseOperate interface { GetConfigFileRelease(ctx context.Context, req *apiconfig.ConfigFileRelease) *apiconfig.ConfigResponse // DeleteConfigFileReleases 批量删除配置文件发布内容 DeleteConfigFileReleases(ctx context.Context, reqs []*apiconfig.ConfigFileRelease) *apiconfig.ConfigBatchWriteResponse - // DeleteConfigFileRelease 删除配置文件发布 - DeleteConfigFileRelease(ctx context.Context, req *apiconfig.ConfigFileRelease) *apiconfig.ConfigResponse // RollbackConfigFileReleases 批量回滚配置到指定版本 RollbackConfigFileReleases(ctx context.Context, releases []*apiconfig.ConfigFileRelease) *apiconfig.ConfigBatchWriteResponse - // RollbackConfigFileRelease 回滚配置到指定版本 - RollbackConfigFileRelease(ctx context.Context, req *apiconfig.ConfigFileRelease) *apiconfig.ConfigResponse // GetConfigFileReleases 查询所有的配置发布版本信息 GetConfigFileReleases(ctx context.Context, filters map[string]string) *apiconfig.ConfigBatchQueryResponse // GetConfigFileReleaseVersions 查询所有的配置发布版本信息 diff --git a/config/config_file_group_test.go b/config/config_file_group_test.go index 5be79ee96..d5210f6b0 100644 --- a/config/config_file_group_test.go +++ b/config/config_file_group_test.go @@ -22,13 +22,14 @@ import ( "reflect" "testing" + apimodel "github.com/polarismesh/specification/source/go/api/v1/model" + "github.com/stretchr/testify/assert" + "google.golang.org/protobuf/types/known/wrapperspb" + api "github.com/polarismesh/polaris/common/api/v1" "github.com/polarismesh/polaris/common/model" "github.com/polarismesh/polaris/common/utils" "github.com/polarismesh/polaris/config" - apimodel "github.com/polarismesh/specification/source/go/api/v1/model" - "github.com/stretchr/testify/assert" - "google.golang.org/protobuf/types/known/wrapperspb" ) var ( diff --git a/config/interceptor/auth/client_authibility.go b/config/interceptor/auth/client.go similarity index 83% rename from config/interceptor/auth/client_authibility.go rename to config/interceptor/auth/client.go index fdef9c8d2..892b19e1c 100644 --- a/config/interceptor/auth/client_authibility.go +++ b/config/interceptor/auth/client.go @@ -23,7 +23,7 @@ import ( apiconfig "github.com/polarismesh/specification/source/go/api/v1/config_manage" api "github.com/polarismesh/polaris/common/api/v1" - "github.com/polarismesh/polaris/common/model" + "github.com/polarismesh/polaris/common/model/auth" "github.com/polarismesh/polaris/common/utils" "github.com/polarismesh/polaris/config" ) @@ -32,9 +32,9 @@ import ( func (s *ServerAuthability) UpsertAndReleaseConfigFileFromClient(ctx context.Context, req *apiconfig.ConfigFilePublishInfo) *apiconfig.ConfigResponse { authCtx := s.collectConfigFilePublishAuthContext(ctx, []*apiconfig.ConfigFilePublishInfo{req}, - model.Modify, "UpsertAndReleaseConfigFileFromClient") + auth.Modify, auth.PublishConfigFile) if _, err := s.policyMgr.GetAuthChecker().CheckClientPermission(authCtx); err != nil { - return api.NewConfigFileResponse(model.ConvertToErrCode(err), nil) + return api.NewConfigFileResponse(auth.ConvertToErrCode(err), nil) } ctx = authCtx.GetRequestContext() @@ -51,9 +51,9 @@ func (s *ServerAuthability) CreateConfigFileFromClient(ctx context.Context, Namespace: fileInfo.Namespace, Name: fileInfo.Name, Group: fileInfo.Group}, - }, model.Create, "CreateConfigFileFromClient") + }, auth.Create, auth.CreateConfigFile) if _, err := s.policyMgr.GetAuthChecker().CheckClientPermission(authCtx); err != nil { - return api.NewConfigClientResponseWithInfo(model.ConvertToErrCode(err), err.Error()) + return api.NewConfigClientResponseWithInfo(auth.ConvertToErrCode(err), err.Error()) } ctx = authCtx.GetRequestContext() @@ -66,9 +66,9 @@ func (s *ServerAuthability) CreateConfigFileFromClient(ctx context.Context, func (s *ServerAuthability) UpdateConfigFileFromClient(ctx context.Context, fileInfo *apiconfig.ConfigFile) *apiconfig.ConfigClientResponse { authCtx := s.collectClientConfigFileAuthContext(ctx, - []*apiconfig.ConfigFile{fileInfo}, model.Modify, "UpdateConfigFileFromClient") + []*apiconfig.ConfigFile{fileInfo}, auth.Modify, auth.UpdateConfigFile) if _, err := s.policyMgr.GetAuthChecker().CheckClientPermission(authCtx); err != nil { - return api.NewConfigClientResponseWithInfo(model.ConvertToErrCode(err), err.Error()) + return api.NewConfigClientResponseWithInfo(auth.ConvertToErrCode(err), err.Error()) } ctx = authCtx.GetRequestContext() @@ -82,9 +82,9 @@ func (s *ServerAuthability) DeleteConfigFileFromClient(ctx context.Context, req *apiconfig.ConfigFile) *apiconfig.ConfigResponse { authCtx := s.collectConfigFileAuthContext(ctx, - []*apiconfig.ConfigFile{req}, model.Delete, "DeleteConfigFileFromClient") + []*apiconfig.ConfigFile{req}, auth.Delete, auth.DeleteConfigFile) if _, err := s.policyMgr.GetAuthChecker().CheckClientPermission(authCtx); err != nil { - return api.NewConfigResponseWithInfo(model.ConvertToErrCode(err), err.Error()) + return api.NewConfigResponseWithInfo(auth.ConvertToErrCode(err), err.Error()) } ctx = authCtx.GetRequestContext() @@ -101,9 +101,9 @@ func (s *ServerAuthability) PublishConfigFileFromClient(ctx context.Context, Namespace: fileInfo.Namespace, Name: fileInfo.FileName, Group: fileInfo.Group}, - }, model.Create, "PublishConfigFileFromClient") + }, auth.Create, auth.PublishConfigFile) if _, err := s.policyMgr.GetAuthChecker().CheckClientPermission(authCtx); err != nil { - return api.NewConfigClientResponseWithInfo(model.ConvertToErrCode(err), err.Error()) + return api.NewConfigClientResponseWithInfo(auth.ConvertToErrCode(err), err.Error()) } ctx = authCtx.GetRequestContext() @@ -120,9 +120,9 @@ func (s *ServerAuthability) GetConfigFileWithCache(ctx context.Context, Namespace: fileInfo.Namespace, Name: fileInfo.FileName, Group: fileInfo.Group}, - }, model.Read, "GetConfigFileForClient") + }, auth.Read, auth.DiscoverConfigFile) if _, err := s.policyMgr.GetAuthChecker().CheckClientPermission(authCtx); err != nil { - return api.NewConfigClientResponseWithInfo(model.ConvertToErrCode(err), err.Error()) + return api.NewConfigClientResponseWithInfo(auth.ConvertToErrCode(err), err.Error()) } ctx = authCtx.GetRequestContext() @@ -133,10 +133,10 @@ func (s *ServerAuthability) GetConfigFileWithCache(ctx context.Context, // WatchConfigFiles 监听配置文件变化 func (s *ServerAuthability) LongPullWatchFile(ctx context.Context, request *apiconfig.ClientWatchConfigFileRequest) (config.WatchCallback, error) { - authCtx := s.collectClientWatchConfigFiles(ctx, request, model.Read, "LongPullWatchFile") + authCtx := s.collectClientWatchConfigFiles(ctx, request, auth.Read, auth.WatchConfigFile) if _, err := s.policyMgr.GetAuthChecker().CheckClientPermission(authCtx); err != nil { return func() *apiconfig.ConfigClientResponse { - return api.NewConfigClientResponseWithInfo(model.ConvertToErrCode(err), err.Error()) + return api.NewConfigClientResponseWithInfo(auth.ConvertToErrCode(err), err.Error()) }, nil } @@ -155,9 +155,9 @@ func (s *ServerAuthability) GetConfigFileNamesWithCache(ctx context.Context, Namespace: req.GetConfigFileGroup().GetNamespace(), Group: req.GetConfigFileGroup().GetName(), }, - }, model.Read, "GetConfigFileNamesWithCache") + }, auth.Read, auth.DiscoverConfigFileNames) if _, err := s.policyMgr.GetAuthChecker().CheckClientPermission(authCtx); err != nil { - out := api.NewConfigClientListResponse(model.ConvertToErrCode(err)) + out := api.NewConfigClientListResponse(auth.ConvertToErrCode(err)) return out } @@ -166,6 +166,7 @@ func (s *ServerAuthability) GetConfigFileNamesWithCache(ctx context.Context, return s.nextServer.GetConfigFileNamesWithCache(ctx, req) } +// GetConfigGroupsWithCache 获取某个命名空间下的配置分组列表 func (s *ServerAuthability) GetConfigGroupsWithCache(ctx context.Context, req *apiconfig.ClientConfigFileInfo) *apiconfig.ConfigDiscoverResponse { @@ -173,9 +174,9 @@ func (s *ServerAuthability) GetConfigGroupsWithCache(ctx context.Context, { Namespace: req.GetNamespace(), }, - }, model.Read, "GetConfigGroupsWithCache") + }, auth.Read, auth.DiscoverConfigGroups) if _, err := s.policyMgr.GetAuthChecker().CheckClientPermission(authCtx); err != nil { - out := api.NewConfigDiscoverResponse(model.ConvertToErrCode(err)) + out := api.NewConfigDiscoverResponse(auth.ConvertToErrCode(err)) return out } @@ -189,9 +190,9 @@ func (s *ServerAuthability) CasUpsertAndReleaseConfigFileFromClient(ctx context. req *apiconfig.ConfigFilePublishInfo) *apiconfig.ConfigResponse { authCtx := s.collectConfigFilePublishAuthContext(ctx, []*apiconfig.ConfigFilePublishInfo{req}, - model.Modify, "CasUpsertAndReleaseConfigFileFromClient") + auth.Modify, auth.UpsertAndReleaseConfigFile) if _, err := s.policyMgr.GetAuthChecker().CheckClientPermission(authCtx); err != nil { - return api.NewConfigFileResponse(model.ConvertToErrCode(err), nil) + return api.NewConfigFileResponse(auth.ConvertToErrCode(err), nil) } ctx = authCtx.GetRequestContext() diff --git a/config/interceptor/auth/config_file_authibility.go b/config/interceptor/auth/config_file.go similarity index 78% rename from config/interceptor/auth/config_file_authibility.go rename to config/interceptor/auth/config_file.go index ed6a8f5df..4a3fcaf8c 100644 --- a/config/interceptor/auth/config_file_authibility.go +++ b/config/interceptor/auth/config_file.go @@ -23,7 +23,7 @@ import ( apiconfig "github.com/polarismesh/specification/source/go/api/v1/config_manage" api "github.com/polarismesh/polaris/common/api/v1" - "github.com/polarismesh/polaris/common/model" + "github.com/polarismesh/polaris/common/model/auth" "github.com/polarismesh/polaris/common/utils" ) @@ -31,9 +31,9 @@ import ( func (s *ServerAuthability) CreateConfigFile(ctx context.Context, configFile *apiconfig.ConfigFile) *apiconfig.ConfigResponse { authCtx := s.collectConfigFileAuthContext( - ctx, []*apiconfig.ConfigFile{configFile}, model.Create, "CreateConfigFile") + ctx, []*apiconfig.ConfigFile{configFile}, auth.Create, auth.CreateConfigFile) if _, err := s.policyMgr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewConfigResponseWithInfo(model.ConvertToErrCode(err), err.Error()) + return api.NewConfigResponseWithInfo(auth.ConvertToErrCode(err), err.Error()) } ctx = authCtx.GetRequestContext() @@ -47,9 +47,9 @@ func (s *ServerAuthability) GetConfigFileRichInfo(ctx context.Context, req *apiconfig.ConfigFile) *apiconfig.ConfigResponse { authCtx := s.collectConfigFileAuthContext( - ctx, []*apiconfig.ConfigFile{req}, model.Read, "GetConfigFileRichInfo") + ctx, []*apiconfig.ConfigFile{req}, auth.Read, auth.DescribeConfigFileRichInfo) if _, err := s.policyMgr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewConfigResponseWithInfo(model.ConvertToErrCode(err), err.Error()) + return api.NewConfigResponseWithInfo(auth.ConvertToErrCode(err), err.Error()) } ctx = authCtx.GetRequestContext() ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) @@ -60,9 +60,9 @@ func (s *ServerAuthability) GetConfigFileRichInfo(ctx context.Context, func (s *ServerAuthability) SearchConfigFile(ctx context.Context, filter map[string]string) *apiconfig.ConfigBatchQueryResponse { - authCtx := s.collectConfigFileAuthContext(ctx, nil, model.Read, "SearchConfigFile") + authCtx := s.collectConfigFileAuthContext(ctx, nil, auth.Read, auth.DescribeConfigFiles) if _, err := s.policyMgr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewConfigFileBatchQueryResponseWithMessage(model.ConvertToErrCode(err), err.Error()) + return api.NewConfigFileBatchQueryResponseWithMessage(auth.ConvertToErrCode(err), err.Error()) } ctx = authCtx.GetRequestContext() ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) @@ -74,9 +74,9 @@ func (s *ServerAuthability) SearchConfigFile(ctx context.Context, func (s *ServerAuthability) UpdateConfigFile( ctx context.Context, configFile *apiconfig.ConfigFile) *apiconfig.ConfigResponse { authCtx := s.collectConfigFileAuthContext( - ctx, []*apiconfig.ConfigFile{configFile}, model.Modify, "UpdateConfigFile") + ctx, []*apiconfig.ConfigFile{configFile}, auth.Modify, auth.UpdateConfigFile) if _, err := s.policyMgr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewConfigResponseWithInfo(model.ConvertToErrCode(err), err.Error()) + return api.NewConfigResponseWithInfo(auth.ConvertToErrCode(err), err.Error()) } ctx = authCtx.GetRequestContext() @@ -90,9 +90,9 @@ func (s *ServerAuthability) DeleteConfigFile(ctx context.Context, req *apiconfig.ConfigFile) *apiconfig.ConfigResponse { authCtx := s.collectConfigFileAuthContext(ctx, - []*apiconfig.ConfigFile{req}, model.Delete, "DeleteConfigFile") + []*apiconfig.ConfigFile{req}, auth.Delete, auth.DeleteConfigFile) if _, err := s.policyMgr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewConfigResponseWithInfo(model.ConvertToErrCode(err), err.Error()) + return api.NewConfigResponseWithInfo(auth.ConvertToErrCode(err), err.Error()) } ctx = authCtx.GetRequestContext() @@ -105,9 +105,9 @@ func (s *ServerAuthability) DeleteConfigFile(ctx context.Context, func (s *ServerAuthability) BatchDeleteConfigFile(ctx context.Context, req []*apiconfig.ConfigFile) *apiconfig.ConfigResponse { - authCtx := s.collectConfigFileAuthContext(ctx, req, model.Delete, "BatchDeleteConfigFile") + authCtx := s.collectConfigFileAuthContext(ctx, req, auth.Delete, auth.BatchDeleteConfigFiles) if _, err := s.policyMgr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewConfigResponseWithInfo(model.ConvertToErrCode(err), err.Error()) + return api.NewConfigResponseWithInfo(auth.ConvertToErrCode(err), err.Error()) } ctx = authCtx.GetRequestContext() @@ -126,9 +126,9 @@ func (s *ServerAuthability) ExportConfigFile(ctx context.Context, } configFiles = append(configFiles, configFile) } - authCtx := s.collectConfigFileAuthContext(ctx, configFiles, model.Read, "ExportConfigFile") + authCtx := s.collectConfigFileAuthContext(ctx, configFiles, auth.Read, auth.ExportConfigFiles) if _, err := s.policyMgr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewConfigFileExportResponseWithMessage(model.ConvertToErrCode(err), err.Error()) + return api.NewConfigFileExportResponseWithMessage(auth.ConvertToErrCode(err), err.Error()) } ctx = authCtx.GetRequestContext() ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) @@ -138,9 +138,9 @@ func (s *ServerAuthability) ExportConfigFile(ctx context.Context, func (s *ServerAuthability) ImportConfigFile(ctx context.Context, configFiles []*apiconfig.ConfigFile, conflictHandling string) *apiconfig.ConfigImportResponse { - authCtx := s.collectConfigFileAuthContext(ctx, configFiles, model.Create, "ImportConfigFile") + authCtx := s.collectConfigFileAuthContext(ctx, configFiles, auth.Create, auth.ImportConfigFiles) if _, err := s.policyMgr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewConfigFileImportResponseWithMessage(model.ConvertToErrCode(err), err.Error()) + return api.NewConfigFileImportResponseWithMessage(auth.ConvertToErrCode(err), err.Error()) } ctx = authCtx.GetRequestContext() diff --git a/config/interceptor/auth/config_file_group_authibility.go b/config/interceptor/auth/config_file_group.go similarity index 77% rename from config/interceptor/auth/config_file_group_authibility.go rename to config/interceptor/auth/config_file_group.go index 17c7c3c7f..caa466c57 100644 --- a/config/interceptor/auth/config_file_group_authibility.go +++ b/config/interceptor/auth/config_file_group.go @@ -19,13 +19,12 @@ package config_auth import ( "context" - "fmt" apiconfig "github.com/polarismesh/specification/source/go/api/v1/config_manage" - apisecurity "github.com/polarismesh/specification/source/go/api/v1/security" api "github.com/polarismesh/polaris/common/api/v1" "github.com/polarismesh/polaris/common/model" + "github.com/polarismesh/polaris/common/model/auth" "github.com/polarismesh/polaris/common/utils" ) @@ -33,11 +32,11 @@ import ( func (s *ServerAuthability) CreateConfigFileGroup(ctx context.Context, configFileGroup *apiconfig.ConfigFileGroup) *apiconfig.ConfigResponse { authCtx := s.collectConfigGroupAuthContext(ctx, []*apiconfig.ConfigFileGroup{configFileGroup}, - model.Create, "CreateConfigFileGroup") + auth.Create, auth.CreateConfigFileGroup) // 验证 token 信息 if _, err := s.policyMgr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewConfigResponseWithInfo(model.ConvertToErrCode(err), err.Error()) + return api.NewConfigResponseWithInfo(auth.ConvertToErrCode(err), err.Error()) } ctx = authCtx.GetRequestContext() @@ -50,10 +49,10 @@ func (s *ServerAuthability) CreateConfigFileGroup(ctx context.Context, func (s *ServerAuthability) QueryConfigFileGroups(ctx context.Context, filter map[string]string) *apiconfig.ConfigBatchQueryResponse { - authCtx := s.collectConfigGroupAuthContext(ctx, nil, model.Read, "QueryConfigFileGroups") + authCtx := s.collectConfigGroupAuthContext(ctx, nil, auth.Read, auth.DescribeConfigFileGroups) if _, err := s.policyMgr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewConfigBatchQueryResponse(model.ConvertToErrCode(err)) + return api.NewConfigBatchQueryResponse(auth.ConvertToErrCode(err)) } ctx = authCtx.GetRequestContext() @@ -63,13 +62,7 @@ func (s *ServerAuthability) QueryConfigFileGroups(ctx context.Context, if len(resp.ConfigFileGroups) != 0 { for index := range resp.ConfigFileGroups { group := resp.ConfigFileGroups[index] - editable := s.policyMgr.GetAuthChecker().AllowResourceOperate(authCtx, &model.ResourceOpInfo{ - ResourceType: apisecurity.ResourceType_ConfigGroups, - Namespace: group.GetNamespace().GetValue(), - ResourceName: group.GetName().GetValue(), - ResourceID: fmt.Sprintf("%d", group.GetId().GetValue()), - Operation: authCtx.GetOperation(), - }) + editable := true // 如果包含特殊标签,也不允许修改 if _, ok := group.GetMetadata()[model.MetaKey3RdPlatform]; ok { editable = false @@ -84,10 +77,10 @@ func (s *ServerAuthability) QueryConfigFileGroups(ctx context.Context, func (s *ServerAuthability) DeleteConfigFileGroup( ctx context.Context, namespace, name string) *apiconfig.ConfigResponse { authCtx := s.collectConfigGroupAuthContext(ctx, []*apiconfig.ConfigFileGroup{{Name: utils.NewStringValue(name), - Namespace: utils.NewStringValue(namespace)}}, model.Delete, "DeleteConfigFileGroup") + Namespace: utils.NewStringValue(namespace)}}, auth.Delete, auth.DeleteConfigFileGroup) if _, err := s.policyMgr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewConfigResponseWithInfo(model.ConvertToErrCode(err), err.Error()) + return api.NewConfigResponseWithInfo(auth.ConvertToErrCode(err), err.Error()) } ctx = authCtx.GetRequestContext() @@ -100,10 +93,10 @@ func (s *ServerAuthability) DeleteConfigFileGroup( func (s *ServerAuthability) UpdateConfigFileGroup(ctx context.Context, configFileGroup *apiconfig.ConfigFileGroup) *apiconfig.ConfigResponse { authCtx := s.collectConfigGroupAuthContext(ctx, []*apiconfig.ConfigFileGroup{configFileGroup}, - model.Modify, "UpdateConfigFileGroup") + auth.Modify, auth.UpdateConfigFileGroup) if _, err := s.policyMgr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewConfigResponseWithInfo(model.ConvertToErrCode(err), err.Error()) + return api.NewConfigResponseWithInfo(auth.ConvertToErrCode(err), err.Error()) } ctx = authCtx.GetRequestContext() diff --git a/config/interceptor/auth/config_file_release_authibility.go b/config/interceptor/auth/config_file_release.go similarity index 64% rename from config/interceptor/auth/config_file_release_authibility.go rename to config/interceptor/auth/config_file_release.go index 09711aff5..3a2df3e04 100644 --- a/config/interceptor/auth/config_file_release_authibility.go +++ b/config/interceptor/auth/config_file_release.go @@ -23,7 +23,7 @@ import ( apiconfig "github.com/polarismesh/specification/source/go/api/v1/config_manage" api "github.com/polarismesh/polaris/common/api/v1" - "github.com/polarismesh/polaris/common/model" + "github.com/polarismesh/polaris/common/model/auth" "github.com/polarismesh/polaris/common/utils" ) @@ -32,10 +32,10 @@ func (s *ServerAuthability) PublishConfigFile(ctx context.Context, configFileRelease *apiconfig.ConfigFileRelease) *apiconfig.ConfigResponse { authCtx := s.collectConfigFileReleaseAuthContext(ctx, - []*apiconfig.ConfigFileRelease{configFileRelease}, model.Modify, "PublishConfigFile") + []*apiconfig.ConfigFileRelease{configFileRelease}, auth.Modify, "PublishConfigFile") if _, err := s.policyMgr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewConfigResponseWithInfo(model.ConvertToErrCode(err), err.Error()) + return api.NewConfigResponseWithInfo(auth.ConvertToErrCode(err), err.Error()) } ctx = authCtx.GetRequestContext() @@ -49,10 +49,10 @@ func (s *ServerAuthability) GetConfigFileRelease(ctx context.Context, req *apiconfig.ConfigFileRelease) *apiconfig.ConfigResponse { authCtx := s.collectConfigFileReleaseAuthContext(ctx, - []*apiconfig.ConfigFileRelease{req}, model.Read, "GetConfigFileRelease") + []*apiconfig.ConfigFileRelease{req}, auth.Read, auth.DescribeConfigFileRelease) if _, err := s.policyMgr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewConfigResponseWithInfo(model.ConvertToErrCode(err), err.Error()) + return api.NewConfigResponseWithInfo(auth.ConvertToErrCode(err), err.Error()) } ctx = authCtx.GetRequestContext() ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) @@ -63,38 +63,24 @@ func (s *ServerAuthability) GetConfigFileRelease(ctx context.Context, func (s *ServerAuthability) DeleteConfigFileReleases(ctx context.Context, reqs []*apiconfig.ConfigFileRelease) *apiconfig.ConfigBatchWriteResponse { - authCtx := s.collectConfigFileReleaseAuthContext(ctx, reqs, model.Delete, "DeleteConfigFileReleases") + authCtx := s.collectConfigFileReleaseAuthContext(ctx, reqs, auth.Delete, auth.DeleteConfigFileReleases) if _, err := s.policyMgr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewConfigBatchWriteResponseWithInfo(model.ConvertToErrCode(err), err.Error()) + return api.NewConfigBatchWriteResponseWithInfo(auth.ConvertToErrCode(err), err.Error()) } ctx = authCtx.GetRequestContext() ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) return s.nextServer.DeleteConfigFileReleases(ctx, reqs) } -// DeleteConfigFileRelease implements ConfigCenterServer. -func (s *ServerAuthability) DeleteConfigFileRelease(ctx context.Context, req *apiconfig.ConfigFileRelease) *apiconfig.ConfigResponse { - authCtx := s.collectConfigFileReleaseAuthContext(ctx, []*apiconfig.ConfigFileRelease{ - req, - }, model.Delete, "DeleteConfigFileRelease") - - if _, err := s.policyMgr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewConfigResponseWithInfo(model.ConvertToErrCode(err), err.Error()) - } - ctx = authCtx.GetRequestContext() - ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) - return s.nextServer.DeleteConfigFileRelease(ctx, req) -} - // GetConfigFileReleaseVersions implements ConfigCenterServer. func (s *ServerAuthability) GetConfigFileReleaseVersions(ctx context.Context, filters map[string]string) *apiconfig.ConfigBatchQueryResponse { - authCtx := s.collectConfigFileReleaseAuthContext(ctx, nil, model.Read, "GetConfigFileReleaseVersions") + authCtx := s.collectConfigFileReleaseAuthContext(ctx, nil, auth.Read, auth.DescribeConfigFileReleaseVersions) if _, err := s.policyMgr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewConfigBatchQueryResponseWithInfo(model.ConvertToErrCode(err), err.Error()) + return api.NewConfigBatchQueryResponseWithInfo(auth.ConvertToErrCode(err), err.Error()) } ctx = authCtx.GetRequestContext() ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) @@ -105,10 +91,10 @@ func (s *ServerAuthability) GetConfigFileReleaseVersions(ctx context.Context, func (s *ServerAuthability) GetConfigFileReleases(ctx context.Context, filters map[string]string) *apiconfig.ConfigBatchQueryResponse { - authCtx := s.collectConfigFileReleaseAuthContext(ctx, nil, model.Read, "GetConfigFileReleases") + authCtx := s.collectConfigFileReleaseAuthContext(ctx, nil, auth.Read, auth.DescribeConfigFileReleases) if _, err := s.policyMgr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewConfigBatchQueryResponseWithInfo(model.ConvertToErrCode(err), err.Error()) + return api.NewConfigBatchQueryResponseWithInfo(auth.ConvertToErrCode(err), err.Error()) } ctx = authCtx.GetRequestContext() ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) @@ -119,38 +105,23 @@ func (s *ServerAuthability) GetConfigFileReleases(ctx context.Context, func (s *ServerAuthability) RollbackConfigFileReleases(ctx context.Context, reqs []*apiconfig.ConfigFileRelease) *apiconfig.ConfigBatchWriteResponse { - authCtx := s.collectConfigFileReleaseAuthContext(ctx, reqs, model.Modify, "RollbackConfigFileReleases") + authCtx := s.collectConfigFileReleaseAuthContext(ctx, reqs, auth.Modify, auth.RollbackConfigFileReleases) if _, err := s.policyMgr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewConfigBatchWriteResponseWithInfo(model.ConvertToErrCode(err), err.Error()) + return api.NewConfigBatchWriteResponseWithInfo(auth.ConvertToErrCode(err), err.Error()) } ctx = authCtx.GetRequestContext() ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) return s.nextServer.RollbackConfigFileReleases(ctx, reqs) } -func (s *ServerAuthability) RollbackConfigFileRelease(ctx context.Context, - req *apiconfig.ConfigFileRelease) *apiconfig.ConfigResponse { - - authCtx := s.collectConfigFileReleaseAuthContext(ctx, []*apiconfig.ConfigFileRelease{ - req, - }, model.Modify, "RollbackConfigFileRelease") - - if _, err := s.policyMgr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewConfigResponseWithInfo(model.ConvertToErrCode(err), err.Error()) - } - ctx = authCtx.GetRequestContext() - ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) - return s.nextServer.RollbackConfigFileRelease(ctx, req) -} - // UpsertAndReleaseConfigFile . func (s *ServerAuthability) UpsertAndReleaseConfigFile(ctx context.Context, req *apiconfig.ConfigFilePublishInfo) *apiconfig.ConfigResponse { authCtx := s.collectConfigFilePublishAuthContext(ctx, []*apiconfig.ConfigFilePublishInfo{req}, - model.Modify, "UpsertAndReleaseConfigFile") + auth.Modify, auth.UpsertAndReleaseConfigFile) if _, err := s.policyMgr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewConfigFileResponse(model.ConvertToErrCode(err), nil) + return api.NewConfigFileResponse(auth.ConvertToErrCode(err), nil) } ctx = authCtx.GetRequestContext() @@ -163,9 +134,9 @@ func (s *ServerAuthability) StopGrayConfigFileReleases(ctx context.Context, reqs []*apiconfig.ConfigFileRelease) *apiconfig.ConfigBatchWriteResponse { authCtx := s.collectConfigFileReleaseAuthContext(ctx, reqs, - model.Modify, "StopGrayConfigFileReleases") + auth.Modify, auth.StopGrayConfigFileReleases) if _, err := s.policyMgr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewConfigBatchWriteResponse(model.ConvertToErrCode(err)) + return api.NewConfigBatchWriteResponse(auth.ConvertToErrCode(err)) } ctx = authCtx.GetRequestContext() diff --git a/config/interceptor/auth/config_file_release_history_authibility.go b/config/interceptor/auth/config_file_release_history.go similarity index 88% rename from config/interceptor/auth/config_file_release_history_authibility.go rename to config/interceptor/auth/config_file_release_history.go index 54860942c..3b4464c7e 100644 --- a/config/interceptor/auth/config_file_release_history_authibility.go +++ b/config/interceptor/auth/config_file_release_history.go @@ -23,7 +23,7 @@ import ( apiconfig "github.com/polarismesh/specification/source/go/api/v1/config_manage" api "github.com/polarismesh/polaris/common/api/v1" - "github.com/polarismesh/polaris/common/model" + "github.com/polarismesh/polaris/common/model/auth" "github.com/polarismesh/polaris/common/utils" ) @@ -31,10 +31,10 @@ import ( func (s *ServerAuthability) GetConfigFileReleaseHistories(ctx context.Context, filter map[string]string) *apiconfig.ConfigBatchQueryResponse { - authCtx := s.collectConfigFileReleaseHistoryAuthContext(ctx, nil, model.Read, "GetConfigFileReleaseHistories") + authCtx := s.collectConfigFileReleaseHistoryAuthContext(ctx, nil, auth.Read, auth.DescribeConfigFileReleaseHistories) if _, err := s.policyMgr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewConfigBatchQueryResponseWithInfo(model.ConvertToErrCode(err), err.Error()) + return api.NewConfigBatchQueryResponseWithInfo(auth.ConvertToErrCode(err), err.Error()) } ctx = authCtx.GetRequestContext() ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) diff --git a/config/interceptor/auth/config_file_template_authibility.go b/config/interceptor/auth/config_file_template.go similarity index 81% rename from config/interceptor/auth/config_file_template_authibility.go rename to config/interceptor/auth/config_file_template.go index c41cdf258..dae5f9ab9 100644 --- a/config/interceptor/auth/config_file_template_authibility.go +++ b/config/interceptor/auth/config_file_template.go @@ -23,16 +23,16 @@ import ( apiconfig "github.com/polarismesh/specification/source/go/api/v1/config_manage" api "github.com/polarismesh/polaris/common/api/v1" - "github.com/polarismesh/polaris/common/model" + "github.com/polarismesh/polaris/common/model/auth" "github.com/polarismesh/polaris/common/utils" ) // GetAllConfigFileTemplates get all config file templates func (s *ServerAuthability) GetAllConfigFileTemplates(ctx context.Context) *apiconfig.ConfigBatchQueryResponse { authCtx := s.collectConfigFileTemplateAuthContext(ctx, - []*apiconfig.ConfigFileTemplate{}, model.Read, "GetAllConfigFileTemplates") + []*apiconfig.ConfigFileTemplate{}, auth.Read, auth.DescribeAllConfigFileTemplates) if _, err := s.policyMgr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewConfigFileBatchQueryResponseWithMessage(model.ConvertToErrCode(err), err.Error()) + return api.NewConfigFileBatchQueryResponseWithMessage(auth.ConvertToErrCode(err), err.Error()) } ctx = authCtx.GetRequestContext() @@ -43,9 +43,9 @@ func (s *ServerAuthability) GetAllConfigFileTemplates(ctx context.Context) *apic // GetConfigFileTemplate get config file template func (s *ServerAuthability) GetConfigFileTemplate(ctx context.Context, name string) *apiconfig.ConfigResponse { authCtx := s.collectConfigFileTemplateAuthContext(ctx, - []*apiconfig.ConfigFileTemplate{}, model.Read, "GetAllConfigFileTemplates") + []*apiconfig.ConfigFileTemplate{}, auth.Read, auth.DescribeConfigFileTemplate) if _, err := s.policyMgr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewConfigResponseWithInfo(model.ConvertToErrCode(err), err.Error()) + return api.NewConfigResponseWithInfo(auth.ConvertToErrCode(err), err.Error()) } ctx = authCtx.GetRequestContext() @@ -58,9 +58,9 @@ func (s *ServerAuthability) CreateConfigFileTemplate(ctx context.Context, template *apiconfig.ConfigFileTemplate) *apiconfig.ConfigResponse { authCtx := s.collectConfigFileTemplateAuthContext(ctx, - []*apiconfig.ConfigFileTemplate{template}, model.Create, "CreateConfigFileTemplate") + []*apiconfig.ConfigFileTemplate{template}, auth.Create, auth.CreateConfigFileTemplate) if _, err := s.policyMgr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewConfigResponseWithInfo(model.ConvertToErrCode(err), err.Error()) + return api.NewConfigResponseWithInfo(auth.ConvertToErrCode(err), err.Error()) } ctx = authCtx.GetRequestContext() diff --git a/config/interceptor/auth/resource_listener.go b/config/interceptor/auth/resource_listener.go index 6e3e99858..f1fda9b46 100644 --- a/config/interceptor/auth/resource_listener.go +++ b/config/interceptor/auth/resource_listener.go @@ -24,6 +24,7 @@ import ( apisecurity "github.com/polarismesh/specification/source/go/api/v1/security" "github.com/polarismesh/polaris/common/model" + "github.com/polarismesh/polaris/common/model/auth" "github.com/polarismesh/polaris/common/utils" "github.com/polarismesh/polaris/config" ) @@ -45,9 +46,9 @@ func (s *ServerAuthability) After(ctx context.Context, resourceType model.Resour // onConfigGroupResource func (s *ServerAuthability) onConfigGroupResource(ctx context.Context, res *config.ResourceEvent) error { - authCtx := ctx.Value(utils.ContextAuthContextKey).(*model.AcquireContext) + authCtx := ctx.Value(utils.ContextAuthContextKey).(*auth.AcquireContext) - authCtx.SetAttachment(model.ResourceAttachmentKey, map[apisecurity.ResourceType][]model.ResourceEntry{ + authCtx.SetAttachment(auth.ResourceAttachmentKey, map[apisecurity.ResourceType][]auth.ResourceEntry{ apisecurity.ResourceType_ConfigGroups: { { ID: strconv.FormatUint(res.ConfigGroup.Id.GetValue(), 10), @@ -62,11 +63,11 @@ func (s *ServerAuthability) onConfigGroupResource(ctx context.Context, res *conf groups := utils.ConvertStringValuesToSlice(res.ConfigGroup.GroupIds) removeGroups := utils.ConvertStringValuesToSlice(res.ConfigGroup.RemoveGroupIds) - authCtx.SetAttachment(model.LinkUsersKey, utils.StringSliceDeDuplication(users)) - authCtx.SetAttachment(model.RemoveLinkUsersKey, utils.StringSliceDeDuplication(removeUses)) + authCtx.SetAttachment(auth.LinkUsersKey, utils.StringSliceDeDuplication(users)) + authCtx.SetAttachment(auth.RemoveLinkUsersKey, utils.StringSliceDeDuplication(removeUses)) - authCtx.SetAttachment(model.LinkGroupsKey, utils.StringSliceDeDuplication(groups)) - authCtx.SetAttachment(model.RemoveLinkGroupsKey, utils.StringSliceDeDuplication(removeGroups)) + authCtx.SetAttachment(auth.LinkGroupsKey, utils.StringSliceDeDuplication(groups)) + authCtx.SetAttachment(auth.RemoveLinkGroupsKey, utils.StringSliceDeDuplication(removeGroups)) return s.policyMgr.AfterResourceOperation(authCtx) } diff --git a/config/interceptor/auth/server_authability.go b/config/interceptor/auth/server.go similarity index 64% rename from config/interceptor/auth/server_authability.go rename to config/interceptor/auth/server.go index 53f2e1324..4fece2d2b 100644 --- a/config/interceptor/auth/server_authability.go +++ b/config/interceptor/auth/server.go @@ -28,6 +28,7 @@ import ( "github.com/polarismesh/polaris/auth" cachetypes "github.com/polarismesh/polaris/cache/api" "github.com/polarismesh/polaris/common/model" + authcommon "github.com/polarismesh/polaris/common/model/auth" "github.com/polarismesh/polaris/common/utils" "github.com/polarismesh/polaris/config" ) @@ -57,108 +58,108 @@ func New(nextServer config.ConfigCenterServer, cacheMgr cachetypes.CacheManager, } func (s *ServerAuthability) collectConfigFileAuthContext(ctx context.Context, req []*apiconfig.ConfigFile, - op model.ResourceOperation, methodName string) *model.AcquireContext { - return model.NewAcquireContext( - model.WithRequestContext(ctx), - model.WithModule(model.ConfigModule), - model.WithOperation(op), - model.WithMethod(methodName), - model.WithAccessResources(s.queryConfigFileResource(ctx, req)), + op authcommon.ResourceOperation, methodName authcommon.ServerFunctionName) *authcommon.AcquireContext { + return authcommon.NewAcquireContext( + authcommon.WithRequestContext(ctx), + authcommon.WithModule(authcommon.ConfigModule), + authcommon.WithOperation(op), + authcommon.WithMethod(methodName), + authcommon.WithAccessResources(s.queryConfigFileResource(ctx, req)), ) } func (s *ServerAuthability) collectClientConfigFileAuthContext(ctx context.Context, req []*apiconfig.ConfigFile, - op model.ResourceOperation, methodName string) *model.AcquireContext { - return model.NewAcquireContext( - model.WithRequestContext(ctx), - model.WithModule(model.ConfigModule), - model.WithOperation(op), - model.WithMethod(methodName), - model.WithFromClient(), - model.WithAccessResources(s.queryConfigFileResource(ctx, req)), + op authcommon.ResourceOperation, methodName authcommon.ServerFunctionName) *authcommon.AcquireContext { + return authcommon.NewAcquireContext( + authcommon.WithRequestContext(ctx), + authcommon.WithModule(authcommon.ConfigModule), + authcommon.WithOperation(op), + authcommon.WithMethod(methodName), + authcommon.WithFromClient(), + authcommon.WithAccessResources(s.queryConfigFileResource(ctx, req)), ) } func (s *ServerAuthability) collectClientWatchConfigFiles(ctx context.Context, - req *apiconfig.ClientWatchConfigFileRequest, op model.ResourceOperation, methodName string) *model.AcquireContext { - return model.NewAcquireContext( - model.WithRequestContext(ctx), - model.WithModule(model.ConfigModule), - model.WithOperation(op), - model.WithMethod(methodName), - model.WithFromClient(), - model.WithAccessResources(s.queryWatchConfigFilesResource(ctx, req)), + req *apiconfig.ClientWatchConfigFileRequest, op authcommon.ResourceOperation, methodName authcommon.ServerFunctionName) *authcommon.AcquireContext { + return authcommon.NewAcquireContext( + authcommon.WithRequestContext(ctx), + authcommon.WithModule(authcommon.ConfigModule), + authcommon.WithOperation(op), + authcommon.WithMethod(methodName), + authcommon.WithFromClient(), + authcommon.WithAccessResources(s.queryWatchConfigFilesResource(ctx, req)), ) } func (s *ServerAuthability) collectConfigFileReleaseAuthContext(ctx context.Context, req []*apiconfig.ConfigFileRelease, - op model.ResourceOperation, methodName string) *model.AcquireContext { - return model.NewAcquireContext( - model.WithRequestContext(ctx), - model.WithModule(model.ConfigModule), - model.WithOperation(op), - model.WithMethod(methodName), - model.WithAccessResources(s.queryConfigFileReleaseResource(ctx, req)), + op authcommon.ResourceOperation, methodName authcommon.ServerFunctionName) *authcommon.AcquireContext { + return authcommon.NewAcquireContext( + authcommon.WithRequestContext(ctx), + authcommon.WithModule(authcommon.ConfigModule), + authcommon.WithOperation(op), + authcommon.WithMethod(methodName), + authcommon.WithAccessResources(s.queryConfigFileReleaseResource(ctx, req)), ) } func (s *ServerAuthability) collectConfigFilePublishAuthContext(ctx context.Context, req []*apiconfig.ConfigFilePublishInfo, - op model.ResourceOperation, methodName string) *model.AcquireContext { - return model.NewAcquireContext( - model.WithRequestContext(ctx), - model.WithModule(model.ConfigModule), - model.WithOperation(op), - model.WithMethod(methodName), - model.WithAccessResources(s.queryConfigFilePublishResource(ctx, req)), + op authcommon.ResourceOperation, methodName authcommon.ServerFunctionName) *authcommon.AcquireContext { + return authcommon.NewAcquireContext( + authcommon.WithRequestContext(ctx), + authcommon.WithModule(authcommon.ConfigModule), + authcommon.WithOperation(op), + authcommon.WithMethod(methodName), + authcommon.WithAccessResources(s.queryConfigFilePublishResource(ctx, req)), ) } func (s *ServerAuthability) collectClientConfigFileReleaseAuthContext(ctx context.Context, - req []*apiconfig.ConfigFileRelease, op model.ResourceOperation, methodName string) *model.AcquireContext { - return model.NewAcquireContext( - model.WithRequestContext(ctx), - model.WithModule(model.ConfigModule), - model.WithOperation(op), - model.WithMethod(methodName), - model.WithFromClient(), - model.WithAccessResources(s.queryConfigFileReleaseResource(ctx, req)), + req []*apiconfig.ConfigFileRelease, op authcommon.ResourceOperation, methodName authcommon.ServerFunctionName) *authcommon.AcquireContext { + return authcommon.NewAcquireContext( + authcommon.WithRequestContext(ctx), + authcommon.WithModule(authcommon.ConfigModule), + authcommon.WithOperation(op), + authcommon.WithMethod(methodName), + authcommon.WithFromClient(), + authcommon.WithAccessResources(s.queryConfigFileReleaseResource(ctx, req)), ) } func (s *ServerAuthability) collectConfigFileReleaseHistoryAuthContext( ctx context.Context, req []*apiconfig.ConfigFileReleaseHistory, - op model.ResourceOperation, methodName string) *model.AcquireContext { - return model.NewAcquireContext( - model.WithRequestContext(ctx), - model.WithModule(model.ConfigModule), - model.WithOperation(op), - model.WithMethod(methodName), - model.WithAccessResources(s.queryConfigFileReleaseHistoryResource(ctx, req)), + op authcommon.ResourceOperation, methodName authcommon.ServerFunctionName) *authcommon.AcquireContext { + return authcommon.NewAcquireContext( + authcommon.WithRequestContext(ctx), + authcommon.WithModule(authcommon.ConfigModule), + authcommon.WithOperation(op), + authcommon.WithMethod(methodName), + authcommon.WithAccessResources(s.queryConfigFileReleaseHistoryResource(ctx, req)), ) } func (s *ServerAuthability) collectConfigGroupAuthContext(ctx context.Context, req []*apiconfig.ConfigFileGroup, - op model.ResourceOperation, methodName string) *model.AcquireContext { - return model.NewAcquireContext( - model.WithRequestContext(ctx), - model.WithModule(model.ConfigModule), - model.WithOperation(op), - model.WithMethod(methodName), - model.WithAccessResources(s.queryConfigGroupResource(ctx, req)), + op authcommon.ResourceOperation, methodName authcommon.ServerFunctionName) *authcommon.AcquireContext { + return authcommon.NewAcquireContext( + authcommon.WithRequestContext(ctx), + authcommon.WithModule(authcommon.ConfigModule), + authcommon.WithOperation(op), + authcommon.WithMethod(methodName), + authcommon.WithAccessResources(s.queryConfigGroupResource(ctx, req)), ) } func (s *ServerAuthability) collectConfigFileTemplateAuthContext(ctx context.Context, - req []*apiconfig.ConfigFileTemplate, op model.ResourceOperation, methodName string) *model.AcquireContext { - return model.NewAcquireContext( - model.WithRequestContext(ctx), - model.WithModule(model.ConfigModule), + req []*apiconfig.ConfigFileTemplate, op authcommon.ResourceOperation, methodName authcommon.ServerFunctionName) *authcommon.AcquireContext { + return authcommon.NewAcquireContext( + authcommon.WithRequestContext(ctx), + authcommon.WithModule(authcommon.ConfigModule), ) } func (s *ServerAuthability) queryConfigGroupResource(ctx context.Context, - req []*apiconfig.ConfigFileGroup) map[apisecurity.ResourceType][]model.ResourceEntry { + req []*apiconfig.ConfigFileGroup) map[apisecurity.ResourceType][]authcommon.ResourceEntry { if len(req) == 0 { return nil @@ -178,7 +179,7 @@ func (s *ServerAuthability) queryConfigGroupResource(ctx context.Context, utils.RequestID(ctx), zap.Error(err)) return nil } - ret := map[apisecurity.ResourceType][]model.ResourceEntry{ + ret := map[apisecurity.ResourceType][]authcommon.ResourceEntry{ apisecurity.ResourceType_ConfigGroups: entries, } authLog.Debug("[Config][Server] collect config_file_group access res", @@ -188,7 +189,7 @@ func (s *ServerAuthability) queryConfigGroupResource(ctx context.Context, // queryConfigFileResource config file资源的鉴权转换为config group的鉴权 func (s *ServerAuthability) queryConfigFileResource(ctx context.Context, - req []*apiconfig.ConfigFile) map[apisecurity.ResourceType][]model.ResourceEntry { + req []*apiconfig.ConfigFile) map[apisecurity.ResourceType][]authcommon.ResourceEntry { if len(req) == 0 { return nil @@ -205,7 +206,7 @@ func (s *ServerAuthability) queryConfigFileResource(ctx context.Context, utils.RequestID(ctx), zap.Error(err)) return nil } - ret := map[apisecurity.ResourceType][]model.ResourceEntry{ + ret := map[apisecurity.ResourceType][]authcommon.ResourceEntry{ apisecurity.ResourceType_ConfigGroups: entries, } authLog.Debug("[Config][Server] collect config_file access res", @@ -214,7 +215,7 @@ func (s *ServerAuthability) queryConfigFileResource(ctx context.Context, } func (s *ServerAuthability) queryConfigFileReleaseResource(ctx context.Context, - req []*apiconfig.ConfigFileRelease) map[apisecurity.ResourceType][]model.ResourceEntry { + req []*apiconfig.ConfigFileRelease) map[apisecurity.ResourceType][]authcommon.ResourceEntry { if len(req) == 0 { return nil @@ -231,7 +232,7 @@ func (s *ServerAuthability) queryConfigFileReleaseResource(ctx context.Context, utils.RequestID(ctx), zap.Error(err)) return nil } - ret := map[apisecurity.ResourceType][]model.ResourceEntry{ + ret := map[apisecurity.ResourceType][]authcommon.ResourceEntry{ apisecurity.ResourceType_ConfigGroups: entries, } authLog.Debug("[Config][Server] collect config_file access res", @@ -240,7 +241,7 @@ func (s *ServerAuthability) queryConfigFileReleaseResource(ctx context.Context, } func (s *ServerAuthability) queryConfigFilePublishResource(ctx context.Context, - req []*apiconfig.ConfigFilePublishInfo) map[apisecurity.ResourceType][]model.ResourceEntry { + req []*apiconfig.ConfigFilePublishInfo) map[apisecurity.ResourceType][]authcommon.ResourceEntry { if len(req) == 0 { return nil @@ -256,7 +257,7 @@ func (s *ServerAuthability) queryConfigFilePublishResource(ctx context.Context, authLog.Debug("[Config][Server] collect config_file res", utils.RequestID(ctx), zap.Error(err)) return nil } - ret := map[apisecurity.ResourceType][]model.ResourceEntry{ + ret := map[apisecurity.ResourceType][]authcommon.ResourceEntry{ apisecurity.ResourceType_ConfigGroups: entries, } authLog.Debug("[Config][Server] collect config_file access res", utils.RequestID(ctx), zap.Any("res", ret)) @@ -264,7 +265,7 @@ func (s *ServerAuthability) queryConfigFilePublishResource(ctx context.Context, } func (s *ServerAuthability) queryConfigFileReleaseHistoryResource(ctx context.Context, - req []*apiconfig.ConfigFileReleaseHistory) map[apisecurity.ResourceType][]model.ResourceEntry { + req []*apiconfig.ConfigFileReleaseHistory) map[apisecurity.ResourceType][]authcommon.ResourceEntry { if len(req) == 0 { return nil @@ -281,7 +282,7 @@ func (s *ServerAuthability) queryConfigFileReleaseHistoryResource(ctx context.Co utils.RequestID(ctx), zap.Error(err)) return nil } - ret := map[apisecurity.ResourceType][]model.ResourceEntry{ + ret := map[apisecurity.ResourceType][]authcommon.ResourceEntry{ apisecurity.ResourceType_ConfigGroups: entries, } authLog.Debug("[Config][Server] collect config_file access res", @@ -290,7 +291,7 @@ func (s *ServerAuthability) queryConfigFileReleaseHistoryResource(ctx context.Co } func (s *ServerAuthability) queryConfigGroupRsEntryByNames(ctx context.Context, namespace string, - names []string) ([]model.ResourceEntry, error) { + names []string) ([]authcommon.ResourceEntry, error) { configFileGroups := make([]*model.ConfigFileGroup, 0, len(names)) for i := range names { @@ -302,11 +303,11 @@ func (s *ServerAuthability) queryConfigGroupRsEntryByNames(ctx context.Context, configFileGroups = append(configFileGroups, data) } - entries := make([]model.ResourceEntry, 0, len(configFileGroups)) + entries := make([]authcommon.ResourceEntry, 0, len(configFileGroups)) for index := range configFileGroups { group := configFileGroups[index] - entries = append(entries, model.ResourceEntry{ + entries = append(entries, authcommon.ResourceEntry{ ID: strconv.FormatUint(group.Id, 10), Owner: group.Owner, }) @@ -315,13 +316,13 @@ func (s *ServerAuthability) queryConfigGroupRsEntryByNames(ctx context.Context, } func (s *ServerAuthability) queryWatchConfigFilesResource(ctx context.Context, - req *apiconfig.ClientWatchConfigFileRequest) map[apisecurity.ResourceType][]model.ResourceEntry { + req *apiconfig.ClientWatchConfigFileRequest) map[apisecurity.ResourceType][]authcommon.ResourceEntry { files := req.GetWatchFiles() if len(files) == 0 { return nil } temp := map[string]struct{}{} - entries := make([]model.ResourceEntry, 0, len(files)) + entries := make([]authcommon.ResourceEntry, 0, len(files)) for _, apiConfigFile := range files { namespace := apiConfigFile.GetNamespace().GetValue() groupName := apiConfigFile.GetGroup().GetValue() @@ -334,13 +335,13 @@ func (s *ServerAuthability) queryWatchConfigFilesResource(ctx context.Context, if data == nil { continue } - entries = append(entries, model.ResourceEntry{ + entries = append(entries, authcommon.ResourceEntry{ ID: strconv.FormatUint(data.Id, 10), Owner: data.Owner, }) } - ret := map[apisecurity.ResourceType][]model.ResourceEntry{ + ret := map[apisecurity.ResourceType][]authcommon.ResourceEntry{ apisecurity.ResourceType_ConfigGroups: entries, } authLog.Debug("[Config][Server] collect config_file watch access res", diff --git a/config/interceptor/paramcheck/config_file_release_check.go b/config/interceptor/paramcheck/config_file_release_check.go index a43337392..26db9a819 100644 --- a/config/interceptor/paramcheck/config_file_release_check.go +++ b/config/interceptor/paramcheck/config_file_release_check.go @@ -62,27 +62,7 @@ func (s *Server) GetConfigFileRelease(ctx context.Context, // DeleteConfigFileReleases implements ConfigCenterServer. func (s *Server) DeleteConfigFileReleases(ctx context.Context, reqs []*apiconfig.ConfigFileRelease) *apiconfig.ConfigBatchWriteResponse { - responses := api.NewConfigBatchWriteResponse(apimodel.Code_ExecuteSuccess) - chs := make([]chan *apiconfig.ConfigResponse, 0, len(reqs)) - for i, instance := range reqs { - chs = append(chs, make(chan *apiconfig.ConfigResponse)) - go func(index int, ins *apiconfig.ConfigFileRelease) { - chs[index] <- s.DeleteConfigFileRelease(ctx, ins) - }(i, instance) - } - - for _, ch := range chs { - resp := <-ch - api.ConfigCollect(responses, resp) - } - return responses -} - -func (s *Server) DeleteConfigFileRelease(ctx context.Context, req *apiconfig.ConfigFileRelease) *apiconfig.ConfigResponse { - if errCode, errMsg := checkBaseReleaseParam(req, true); errCode != apimodel.Code_ExecuteSuccess { - return api.NewConfigResponseWithInfo(errCode, errMsg) - } - return s.nextServer.DeleteConfigFileRelease(ctx, req) + return s.nextServer.DeleteConfigFileReleases(ctx, reqs) } // GetConfigFileReleaseVersions implements ConfigCenterServer. @@ -138,28 +118,7 @@ func (s *Server) GetConfigFileReleases(ctx context.Context, func (s *Server) RollbackConfigFileReleases(ctx context.Context, reqs []*apiconfig.ConfigFileRelease) *apiconfig.ConfigBatchWriteResponse { - responses := api.NewConfigBatchWriteResponse(apimodel.Code_ExecuteSuccess) - chs := make([]chan *apiconfig.ConfigResponse, 0, len(reqs)) - for i, instance := range reqs { - chs = append(chs, make(chan *apiconfig.ConfigResponse)) - go func(index int, ins *apiconfig.ConfigFileRelease) { - chs[index] <- s.RollbackConfigFileRelease(ctx, ins) - }(i, instance) - } - - for _, ch := range chs { - resp := <-ch - api.ConfigCollect(responses, resp) - } - return responses -} - -func (s *Server) RollbackConfigFileRelease(ctx context.Context, - req *apiconfig.ConfigFileRelease) *apiconfig.ConfigResponse { - if errCode, errMsg := checkBaseReleaseParam(req, true); errCode != apimodel.Code_ExecuteSuccess { - return api.NewConfigResponseWithInfo(errCode, errMsg) - } - return s.nextServer.RollbackConfigFileRelease(ctx, req) + return s.nextServer.RollbackConfigFileReleases(ctx, reqs) } // UpsertAndReleaseConfigFile . diff --git a/config/watcher.go b/config/watcher.go index cc6b6b4fe..3d8323d5f 100644 --- a/config/watcher.go +++ b/config/watcher.go @@ -53,25 +53,25 @@ type ( WatchContextFactory func(clientId string, matcher BetaReleaseMatcher) WatchContext WatchContext interface { - // ClientID . + // ClientID 客户端发起的 ClientID() string - // ClientLabels . + // ClientLabels 客户端的标识,用于灰度发布要做标签的匹配判断 ClientLabels() map[string]string - // AppendInterest . + // AppendInterest 客户端增加订阅列表 AppendInterest(item *apiconfig.ClientConfigFileInfo) - // RemoveInterest . + // RemoveInterest 客户端删除订阅列表 RemoveInterest(item *apiconfig.ClientConfigFileInfo) - // ShouldNotify . + // ShouldNotify 判断是不是需要通知客户端某个配置变动了 ShouldNotify(event *model.SimpleConfigFileRelease) bool - // Reply . + // Reply 真正的通知逻辑 Reply(rsp *apiconfig.ConfigClientResponse) // Close . Close() error - // ShouldExpire . + // ShouldExpire 是不是存在有效时间 ShouldExpire(now time.Time) bool - // ListWatchFiles + // ListWatchFiles 列举出当前订阅的所有配置文件 ListWatchFiles() []*apiconfig.ClientConfigFileInfo - // IsOnce + // IsOnce 是不是只能被通知一次 IsOnce() bool } ) diff --git a/go.mod b/go.mod index c3ac69915..f347292b3 100644 --- a/go.mod +++ b/go.mod @@ -80,7 +80,7 @@ require ( require ( github.com/DATA-DOG/go-sqlmock v1.5.0 - github.com/polarismesh/specification v1.5.0 + github.com/polarismesh/specification v1.5.2-0.20240722103923-1d9990d6f555 ) require github.com/matttproud/golang_protobuf_extensions/v2 v2.0.0 // indirect diff --git a/go.sum b/go.sum index a77ff7b70..9310d5282 100644 --- a/go.sum +++ b/go.sum @@ -296,8 +296,8 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/polarismesh/go-restful-openapi/v2 v2.0.0-20220928152401-083908d10219 h1:XnFyNUWnciM6zgXaz6tm+Egs35rhoD0KGMmKh4gCdi0= github.com/polarismesh/go-restful-openapi/v2 v2.0.0-20220928152401-083908d10219/go.mod h1:4WhwBysTom9Eoy0hQ4W69I0FmO+T0EpjEW9/5sgHoUk= -github.com/polarismesh/specification v1.5.0 h1:GzPtvqXCdiZ3tTKSenROrwSi0Bam2U2dM2opsBvP+mM= -github.com/polarismesh/specification v1.5.0/go.mod h1:rDvMMtl5qebPmqiBLNa5Ps0XtwkP31ZLirbH4kXA0YU= +github.com/polarismesh/specification v1.5.2-0.20240722103923-1d9990d6f555 h1:eLZzl6yaeuQbTRYeZDu9cqrkJt7qlrt+fE7hGB+R3bc= +github.com/polarismesh/specification v1.5.2-0.20240722103923-1d9990d6f555/go.mod h1:rDvMMtl5qebPmqiBLNa5Ps0XtwkP31ZLirbH4kXA0YU= github.com/posener/complete v1.1.1/go.mod h1:em0nMJCgc9GFtwrmVmEMR/ZL6WyhyjMBndrE9hABlRI= github.com/prometheus/client_golang v1.18.0 h1:HzFfmkOzH5Q8L8G+kSJKUx5dtG87sewO+FoDDqP5Tbk= github.com/prometheus/client_golang v1.18.0/go.mod h1:T+GXkCk5wSJyOqMIzVgvvjFDlkOQntgjkJWKrN5txjA= diff --git a/namespace/namespace.go b/namespace/namespace.go index b0b5dc7e8..07e9f10cc 100644 --- a/namespace/namespace.go +++ b/namespace/namespace.go @@ -122,16 +122,14 @@ func (s *Server) CreateNamespace(ctx context.Context, req *apimodel.Namespace) * return api.NewNamespaceResponse(commonstore.StoreCode2APICode(err), req) } - msg := fmt.Sprintf("create namespace: name=%s", namespaceName) - log.Info(msg, utils.ZapRequestID(requestID)) - + log.Info("create namespace", utils.RequestID(ctx), zap.String("name", namespaceName)) out := &apimodel.Namespace{ Name: req.GetName(), Token: utils.NewStringValue(data.Token), } + s.RecordHistory(namespaceRecordEntry(ctx, req, model.OCreate)) _ = s.afterNamespaceResource(ctx, req, data, false) - return api.NewNamespaceResponse(apimodel.Code_ExecuteSuccess, out) } @@ -229,8 +227,7 @@ func (s *Server) DeleteNamespace(ctx context.Context, req *apimodel.Namespace) * s.caches.Service().CleanNamespace(namespace.Name) - msg := fmt.Sprintf("delete namespace: name=%s", namespace.Name) - log.Info(msg, utils.ZapRequestID(requestID)) + log.Info("delete namespace", utils.RequestID(ctx), zap.String("name", namespace.Name)) s.RecordHistory(namespaceRecordEntry(ctx, req, model.ODelete)) _ = s.afterNamespaceResource(ctx, req, &model.Namespace{Name: req.GetName().GetValue()}, true) diff --git a/namespace/namespace_authability.go b/namespace/namespace_authability.go index 7e4b5cc21..98209e163 100644 --- a/namespace/namespace_authability.go +++ b/namespace/namespace_authability.go @@ -24,8 +24,10 @@ import ( apisecurity "github.com/polarismesh/specification/source/go/api/v1/security" apiservice "github.com/polarismesh/specification/source/go/api/v1/service_manage" + cachetypes "github.com/polarismesh/polaris/cache/api" api "github.com/polarismesh/polaris/common/api/v1" "github.com/polarismesh/polaris/common/model" + authcommon "github.com/polarismesh/polaris/common/model/auth" "github.com/polarismesh/polaris/common/utils" ) @@ -40,9 +42,9 @@ func (svr *serverAuthAbility) CreateNamespaceIfAbsent(ctx context.Context, // CreateNamespace 创建命名空间,只需要要后置鉴权,将数据添加到资源策略中 func (svr *serverAuthAbility) CreateNamespace(ctx context.Context, req *apimodel.Namespace) *apiservice.Response { authCtx := svr.collectNamespaceAuthContext( - ctx, []*apimodel.Namespace{req}, model.Create, "CreateNamespace") + ctx, []*apimodel.Namespace{req}, authcommon.Create, authcommon.CreateNamespace) // 验证 token 信息 - if _, err := svr.strategyMgn.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { return api.NewResponseWithMsg(convertToErrCode(err), err.Error()) } @@ -60,10 +62,10 @@ func (svr *serverAuthAbility) CreateNamespace(ctx context.Context, req *apimodel // CreateNamespaces 创建命名空间,只需要要后置鉴权,将数据添加到资源策略中 func (svr *serverAuthAbility) CreateNamespaces( ctx context.Context, reqs []*apimodel.Namespace) *apiservice.BatchWriteResponse { - authCtx := svr.collectNamespaceAuthContext(ctx, reqs, model.Create, "CreateNamespaces") + authCtx := svr.collectNamespaceAuthContext(ctx, reqs, authcommon.Create, authcommon.CreateNamespaces) // 验证 token 信息 - if _, err := svr.strategyMgn.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { return api.NewBatchWriteResponseWithMsg(convertToErrCode(err), err.Error()) } @@ -85,8 +87,8 @@ func (svr *serverAuthAbility) CreateNamespaces( // DeleteNamespace 删除命名空间,需要先走权限检查 func (svr *serverAuthAbility) DeleteNamespace(ctx context.Context, req *apimodel.Namespace) *apiservice.Response { authCtx := svr.collectNamespaceAuthContext( - ctx, []*apimodel.Namespace{req}, model.Delete, "DeleteNamespace") - if _, err := svr.strategyMgn.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + ctx, []*apimodel.Namespace{req}, authcommon.Delete, authcommon.DeleteNamespace) + if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { return api.NewResponseWithMsg(convertToErrCode(err), err.Error()) } @@ -99,8 +101,8 @@ func (svr *serverAuthAbility) DeleteNamespace(ctx context.Context, req *apimodel // DeleteNamespaces 删除命名空间,需要先走权限检查 func (svr *serverAuthAbility) DeleteNamespaces( ctx context.Context, reqs []*apimodel.Namespace) *apiservice.BatchWriteResponse { - authCtx := svr.collectNamespaceAuthContext(ctx, reqs, model.Delete, "DeleteNamespaces") - if _, err := svr.strategyMgn.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + authCtx := svr.collectNamespaceAuthContext(ctx, reqs, authcommon.Delete, authcommon.DeleteNamespaces) + if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { return api.NewBatchWriteResponseWithMsg(convertToErrCode(err), err.Error()) } @@ -113,8 +115,8 @@ func (svr *serverAuthAbility) DeleteNamespaces( // UpdateNamespaces 更新命名空间,需要先走权限检查 func (svr *serverAuthAbility) UpdateNamespaces( ctx context.Context, req []*apimodel.Namespace) *apiservice.BatchWriteResponse { - authCtx := svr.collectNamespaceAuthContext(ctx, req, model.Modify, "UpdateNamespaces") - if _, err := svr.strategyMgn.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + authCtx := svr.collectNamespaceAuthContext(ctx, req, authcommon.Modify, authcommon.UpdateNamespaces) + if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { return api.NewBatchWriteResponseWithMsg(convertToErrCode(err), err.Error()) } @@ -127,8 +129,8 @@ func (svr *serverAuthAbility) UpdateNamespaces( // UpdateNamespaceToken 更新命名空间的token信息,需要先走权限检查 func (svr *serverAuthAbility) UpdateNamespaceToken(ctx context.Context, req *apimodel.Namespace) *apiservice.Response { authCtx := svr.collectNamespaceAuthContext( - ctx, []*apimodel.Namespace{req}, model.Modify, "UpdateNamespaceToken") - if _, err := svr.strategyMgn.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + ctx, []*apimodel.Namespace{req}, authcommon.Modify, authcommon.UpdateNamespaceToken) + if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { return api.NewResponseWithMsg(convertToErrCode(err), err.Error()) } @@ -141,8 +143,8 @@ func (svr *serverAuthAbility) UpdateNamespaceToken(ctx context.Context, req *api // GetNamespaces 获取命名空间列表信息,暂时不走权限检查 func (svr *serverAuthAbility) GetNamespaces( ctx context.Context, query map[string][]string) *apiservice.BatchQueryResponse { - authCtx := svr.collectNamespaceAuthContext(ctx, nil, model.Read, "GetNamespaces") - _, err := svr.strategyMgn.GetAuthChecker().CheckConsolePermission(authCtx) + authCtx := svr.collectNamespaceAuthContext(ctx, nil, authcommon.Read, authcommon.DescribeNamespaces) + _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx) if err != nil { return api.NewBatchQueryResponseWithMsg(convertToErrCode(err), err.Error()) } @@ -150,29 +152,21 @@ func (svr *serverAuthAbility) GetNamespaces( ctx = authCtx.GetRequestContext() ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) - resp := svr.targetServer.GetNamespaces(ctx, query) - if len(resp.Namespaces) != 0 { - for index := range resp.Namespaces { - ns := resp.Namespaces[index] - editable := svr.strategyMgn.GetAuthChecker().AllowResourceOperate(authCtx, &model.ResourceOpInfo{ - ResourceType: apisecurity.ResourceType_Namespaces, - Namespace: ns.GetName().GetValue(), - ResourceName: ns.GetName().GetValue(), - ResourceID: ns.GetId().GetValue(), - Operation: authCtx.GetOperation(), - }) - ns.Editable = utils.NewBoolValue(editable) - } - } + cachetypes.AppendNamespacePredicate(ctx, func(ctx context.Context, n *model.Namespace) bool { + return svr.policySvr.GetAuthChecker().ResourcePredicate(authCtx, &authcommon.ResourceEntry{ + Type: apisecurity.ResourceType_Users, + ID: n.Name, + }) + }) - return resp + return svr.targetServer.GetNamespaces(ctx, query) } // GetNamespaceToken 获取命名空间的token信息,暂时不走权限检查 func (svr *serverAuthAbility) GetNamespaceToken(ctx context.Context, req *apimodel.Namespace) *apiservice.Response { authCtx := svr.collectNamespaceAuthContext( - ctx, []*apimodel.Namespace{req}, model.Read, "GetNamespaceToken") - _, err := svr.strategyMgn.GetAuthChecker().CheckConsolePermission(authCtx) + ctx, []*apimodel.Namespace{req}, authcommon.Read, authcommon.DescribeNamespaceToken) + _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx) if err != nil { return api.NewResponseWithMsg(convertToErrCode(err), err.Error()) } diff --git a/namespace/resource_listener.go b/namespace/resource_listener.go index 01e1985d3..0fb7268d3 100644 --- a/namespace/resource_listener.go +++ b/namespace/resource_listener.go @@ -24,6 +24,7 @@ import ( apisecurity "github.com/polarismesh/specification/source/go/api/v1/security" "github.com/polarismesh/polaris/common/model" + authcommon "github.com/polarismesh/polaris/common/model/auth" "github.com/polarismesh/polaris/common/utils" ) @@ -66,13 +67,13 @@ func (svr *serverAuthAbility) After(ctx context.Context, resourceType model.Reso // onNamespaceResource func (svr *serverAuthAbility) onNamespaceResource(ctx context.Context, res *ResourceEvent) error { - authCtx, _ := ctx.Value(utils.ContextAuthContextKey).(*model.AcquireContext) + authCtx, _ := ctx.Value(utils.ContextAuthContextKey).(*authcommon.AcquireContext) if authCtx == nil { log.Warn("[Namespace][ResourceHook] get auth context is nil, ignore", utils.RequestID(ctx)) return nil } - authCtx.SetAttachment(model.ResourceAttachmentKey, map[apisecurity.ResourceType][]model.ResourceEntry{ + authCtx.SetAttachment(authcommon.ResourceAttachmentKey, map[apisecurity.ResourceType][]authcommon.ResourceEntry{ apisecurity.ResourceType_Namespaces: { { ID: res.Namespace.Name, @@ -87,11 +88,11 @@ func (svr *serverAuthAbility) onNamespaceResource(ctx context.Context, res *Reso groups := utils.ConvertStringValuesToSlice(res.ReqNamespace.GroupIds) removeGroups := utils.ConvertStringValuesToSlice(res.ReqNamespace.RemoveGroupIds) - authCtx.SetAttachment(model.LinkUsersKey, utils.StringSliceDeDuplication(users)) - authCtx.SetAttachment(model.RemoveLinkUsersKey, utils.StringSliceDeDuplication(removeUses)) + authCtx.SetAttachment(authcommon.LinkUsersKey, utils.StringSliceDeDuplication(users)) + authCtx.SetAttachment(authcommon.RemoveLinkUsersKey, utils.StringSliceDeDuplication(removeUses)) - authCtx.SetAttachment(model.LinkGroupsKey, utils.StringSliceDeDuplication(groups)) - authCtx.SetAttachment(model.RemoveLinkGroupsKey, utils.StringSliceDeDuplication(removeGroups)) + authCtx.SetAttachment(authcommon.LinkGroupsKey, utils.StringSliceDeDuplication(groups)) + authCtx.SetAttachment(authcommon.RemoveLinkGroupsKey, utils.StringSliceDeDuplication(removeGroups)) - return svr.strategyMgn.AfterResourceOperation(authCtx) + return svr.policySvr.AfterResourceOperation(authCtx) } diff --git a/namespace/server_authability.go b/namespace/server_authability.go index 704c6b93c..aa7cb0a18 100644 --- a/namespace/server_authability.go +++ b/namespace/server_authability.go @@ -26,7 +26,7 @@ import ( "go.uber.org/zap" "github.com/polarismesh/polaris/auth" - "github.com/polarismesh/polaris/common/model" + authcommon "github.com/polarismesh/polaris/common/model/auth" "github.com/polarismesh/polaris/common/utils" ) @@ -36,15 +36,15 @@ import ( type serverAuthAbility struct { targetServer *Server userMgn auth.UserServer - strategyMgn auth.StrategyServer + policySvr auth.StrategyServer } func newServerAuthAbility(targetServer *Server, - userMgn auth.UserServer, strategyMgn auth.StrategyServer) NamespaceOperateServer { + userMgn auth.UserServer, policySvr auth.StrategyServer) NamespaceOperateServer { proxy := &serverAuthAbility{ targetServer: targetServer, userMgn: userMgn, - strategyMgn: strategyMgn, + policySvr: policySvr, } targetServer.SetResourceHooks(proxy) @@ -53,19 +53,19 @@ func newServerAuthAbility(targetServer *Server, // collectNamespaceAuthContext 对于命名空间的处理,收集所有的与鉴权的相关信息 func (svr *serverAuthAbility) collectNamespaceAuthContext(ctx context.Context, req []*apimodel.Namespace, - resourceOp model.ResourceOperation, methodName string) *model.AcquireContext { - return model.NewAcquireContext( - model.WithRequestContext(ctx), - model.WithOperation(resourceOp), - model.WithModule(model.CoreModule), - model.WithMethod(methodName), - model.WithAccessResources(svr.queryNamespaceResource(req)), + resourceOp authcommon.ResourceOperation, methodName authcommon.ServerFunctionName) *authcommon.AcquireContext { + return authcommon.NewAcquireContext( + authcommon.WithRequestContext(ctx), + authcommon.WithOperation(resourceOp), + authcommon.WithModule(authcommon.CoreModule), + authcommon.WithMethod(methodName), + authcommon.WithAccessResources(svr.queryNamespaceResource(req)), ) } // queryNamespaceResource 根据所给的 namespace 信息,收集对应的 ResourceEntry 列表 func (svr *serverAuthAbility) queryNamespaceResource( - req []*apimodel.Namespace) map[apisecurity.ResourceType][]model.ResourceEntry { + req []*apimodel.Namespace) map[apisecurity.ResourceType][]authcommon.ResourceEntry { names := utils.NewSet[string]() for index := range req { names.Add(req[index].Name.GetValue()) @@ -73,17 +73,18 @@ func (svr *serverAuthAbility) queryNamespaceResource( param := names.ToSlice() nsArr := svr.targetServer.caches.Namespace().GetNamespacesByName(param) - temp := make([]model.ResourceEntry, 0, len(nsArr)) + temp := make([]authcommon.ResourceEntry, 0, len(nsArr)) for index := range nsArr { ns := nsArr[index] - temp = append(temp, model.ResourceEntry{ + temp = append(temp, authcommon.ResourceEntry{ + Type: apisecurity.ResourceType_Namespaces, ID: ns.Name, Owner: ns.Owner, }) } - ret := map[apisecurity.ResourceType][]model.ResourceEntry{ + ret := map[apisecurity.ResourceType][]authcommon.ResourceEntry{ apisecurity.ResourceType_Namespaces: temp, } authLog.Debug("[Auth][Server] collect namespace access res", zap.Any("res", ret)) @@ -91,10 +92,10 @@ func (svr *serverAuthAbility) queryNamespaceResource( } func convertToErrCode(err error) apimodel.Code { - if errors.Is(err, model.ErrorTokenNotExist) { + if errors.Is(err, authcommon.ErrorTokenNotExist) { return apimodel.Code_TokenNotExisted } - if errors.Is(err, model.ErrorTokenDisabled) { + if errors.Is(err, authcommon.ErrorTokenDisabled) { return apimodel.Code_TokenDisabled } return apimodel.Code_NotAllowedAccess diff --git a/plugin/healthchecker/leader/debug_test.go b/plugin/healthchecker/leader/debug_test.go index 597bd0fa6..d4f0ac16e 100644 --- a/plugin/healthchecker/leader/debug_test.go +++ b/plugin/healthchecker/leader/debug_test.go @@ -25,9 +25,10 @@ import ( "sync/atomic" "testing" + "github.com/stretchr/testify/assert" + commonhash "github.com/polarismesh/polaris/common/hash" "github.com/polarismesh/polaris/common/utils" - "github.com/stretchr/testify/assert" ) func Test_LeaderCheckerDebugerHandler(t *testing.T) { diff --git a/release/conf/polaris-server.yaml b/release/conf/polaris-server.yaml index 5aac81bfa..91535a854 100644 --- a/release/conf/polaris-server.yaml +++ b/release/conf/polaris-server.yaml @@ -319,7 +319,7 @@ apiservers: option: listenIP: "0.0.0.0" listenPort: 8848 - # 设置 nacos 默认命名空间对应 Polaris 命名空间信息 + # Set the nacos default namespace to correspond to the Polaris namespace information defaultNamespace: default connLimit: openConnLimit: false diff --git a/service/batch/client_future_test.go b/service/batch/client_future_test.go index 4b569c9f7..983606c36 100644 --- a/service/batch/client_future_test.go +++ b/service/batch/client_future_test.go @@ -20,8 +20,9 @@ package batch import ( "testing" - "github.com/polarismesh/polaris/common/model" "github.com/stretchr/testify/assert" + + "github.com/polarismesh/polaris/common/model" ) func TestClientFuture_SetClient(t *testing.T) { diff --git a/service/batch/future.go b/service/batch/future.go index a690ad830..65b9fc4da 100644 --- a/service/batch/future.go +++ b/service/batch/future.go @@ -63,7 +63,7 @@ func (future *InstanceFuture) Reply(cur time.Time, code apimodel.Code, result er if !future.needWait { if result != nil { - log.Error("[Instance][Regis] receive future result", zap.String("service-id", future.serviceId), + log.Error("[Instance][Regis] receive future result", zap.String("instance-id", future.instance.ID()), zap.Error(result)) } return diff --git a/service/client_v1.go b/service/client_v1.go index 0c08fc0b5..ec5488852 100644 --- a/service/client_v1.go +++ b/service/client_v1.go @@ -250,7 +250,7 @@ func (s *Server) ServiceInstancesCache(ctx context.Context, filter *apiservice.D } ret := s.caches.Instance().DiscoverServiceInstances(specSvc.GetId().GetValue(), filter.GetOnlyHealthyInstance()) for i := range ret { - copyIns := s.getInstance(req, ret[i].Proto) + copyIns := s.getInstance(specSvc, ret[i].Proto) // 注意:这里的value是cache的,不修改cache的数据,通过getInstance,浅拷贝一份数据 finalInstances[copyIns.GetId().GetValue()] = copyIns } diff --git a/service/faultdetect_config.go b/service/faultdetect_config.go index 95785e43c..32ab8d5de 100644 --- a/service/faultdetect_config.go +++ b/service/faultdetect_config.go @@ -28,7 +28,9 @@ import ( apifault "github.com/polarismesh/specification/source/go/api/v1/fault_tolerance" apimodel "github.com/polarismesh/specification/source/go/api/v1/model" apiservice "github.com/polarismesh/specification/source/go/api/v1/service_manage" + "go.uber.org/zap" + cachetypes "github.com/polarismesh/polaris/cache/api" api "github.com/polarismesh/polaris/common/api/v1" "github.com/polarismesh/polaris/common/model" commonstore "github.com/polarismesh/polaris/common/store" @@ -36,25 +38,9 @@ import ( "github.com/polarismesh/polaris/common/utils" ) -func checkBatchFaultDetectRules(req []*apifault.FaultDetectRule) *apiservice.BatchWriteResponse { - if len(req) == 0 { - return api.NewBatchWriteResponse(apimodel.Code_EmptyRequest) - } - - if len(req) > MaxBatchSize { - return api.NewBatchWriteResponse(apimodel.Code_BatchSizeOverLimit) - } - - return nil -} - // CreateFaultDetectRules Create a FaultDetect rule func (s *Server) CreateFaultDetectRules( ctx context.Context, request []*apifault.FaultDetectRule) *apiservice.BatchWriteResponse { - if checkErr := checkBatchFaultDetectRules(request); checkErr != nil { - return checkErr - } - responses := api.NewBatchWriteResponse(apimodel.Code_ExecuteSuccess) for _, cbRule := range request { response := s.createFaultDetectRule(ctx, cbRule) @@ -66,9 +52,6 @@ func (s *Server) CreateFaultDetectRules( // DeleteFaultDetectRules Delete current Fault Detect rules func (s *Server) DeleteFaultDetectRules( ctx context.Context, request []*apifault.FaultDetectRule) *apiservice.BatchWriteResponse { - if checkErr := checkBatchFaultDetectRules(request); checkErr != nil { - return checkErr - } responses := api.NewBatchWriteResponse(apimodel.Code_ExecuteSuccess) for _, cbRule := range request { @@ -81,9 +64,6 @@ func (s *Server) DeleteFaultDetectRules( // UpdateFaultDetectRules Modify the FaultDetect rule func (s *Server) UpdateFaultDetectRules( ctx context.Context, request []*apifault.FaultDetectRule) *apiservice.BatchWriteResponse { - if checkErr := checkBatchFaultDetectRules(request); checkErr != nil { - return checkErr - } responses := api.NewBatchWriteResponse(apimodel.Code_ExecuteSuccess) for _, cbRule := range request { @@ -93,43 +73,6 @@ func (s *Server) UpdateFaultDetectRules( return api.FormatBatchWriteResponse(responses) } -func checkFaultDetectRuleParams( - req *apifault.FaultDetectRule, idRequired bool, nameRequired bool) *apiservice.Response { - if req == nil { - return api.NewResponse(apimodel.Code_EmptyRequest) - } - if resp := checkFaultDetectRuleParamsDbLen(req); nil != resp { - return resp - } - if nameRequired && len(req.GetName()) == 0 { - return api.NewResponse(apimodel.Code_InvalidCircuitBreakerName) - } - if idRequired && len(req.GetId()) == 0 { - return api.NewResponse(apimodel.Code_InvalidCircuitBreakerID) - } - return nil -} - -func checkFaultDetectRuleParamsDbLen(req *apifault.FaultDetectRule) *apiservice.Response { - if err := utils.CheckDbRawStrFieldLen(req.GetTargetService().GetService(), MaxDbServiceNameLength); err != nil { - return api.NewResponse(apimodel.Code_InvalidServiceName) - } - if err := utils.CheckDbRawStrFieldLen( - req.GetTargetService().GetNamespace(), MaxDbServiceNamespaceLength); err != nil { - return api.NewResponse(apimodel.Code_InvalidNamespaceName) - } - if err := utils.CheckDbRawStrFieldLen(req.GetName(), MaxRuleName); err != nil { - return api.NewResponse(apimodel.Code_InvalidRateLimitName) - } - if err := utils.CheckDbRawStrFieldLen(req.GetNamespace(), MaxDbServiceNamespaceLength); err != nil { - return api.NewResponse(apimodel.Code_InvalidNamespaceName) - } - if err := utils.CheckDbRawStrFieldLen(req.GetDescription(), MaxCommentLength); err != nil { - return api.NewResponse(apimodel.Code_InvalidServiceComment) - } - return nil -} - func faultDetectRuleRecordEntry(ctx context.Context, req *apifault.FaultDetectRule, md *model.FaultDetectRule, opt model.OperationType) *model.RecordEntry { marshaler := jsonpb.Marshaler{} @@ -148,18 +91,14 @@ func faultDetectRuleRecordEntry(ctx context.Context, req *apifault.FaultDetectRu // createFaultDetectRule Create a FaultDetect rule func (s *Server) createFaultDetectRule(ctx context.Context, request *apifault.FaultDetectRule) *apiservice.Response { - requestID := utils.ParseRequestID(ctx) - if resp := checkFaultDetectRuleParams(request, false, true); resp != nil { - return resp - } data, err := api2FaultDetectRule(request) if err != nil { - log.Error(err.Error(), utils.ZapRequestID(requestID)) + log.Error(err.Error(), utils.RequestID(ctx)) return api.NewResponse(apimodel.Code_ParseException) } exists, err := s.storage.HasFaultDetectRuleByName(data.Name, data.Namespace) if err != nil { - log.Error(err.Error(), utils.ZapRequestID(requestID)) + log.Error(err.Error(), utils.RequestID(ctx)) return api.NewResponseWithMsg(commonstore.StoreCode2APICode(err), err.Error()) } if exists { @@ -169,13 +108,13 @@ func (s *Server) createFaultDetectRule(ctx context.Context, request *apifault.Fa // 存储层操作 if err := s.storage.CreateFaultDetectRule(data); err != nil { - log.Error(err.Error(), utils.ZapRequestID(requestID)) + log.Error(err.Error(), utils.RequestID(ctx)) return api.NewResponseWithMsg(commonstore.StoreCode2APICode(err), err.Error()) } msg := fmt.Sprintf("create fault detect rule: id=%v, name=%v, namespace=%v", data.ID, request.GetName(), request.GetNamespace()) - log.Info(msg, utils.ZapRequestID(requestID)) + log.Info(msg, utils.RequestID(ctx)) s.RecordHistory(ctx, faultDetectRuleRecordEntry(ctx, request, data, model.OCreate)) @@ -185,37 +124,29 @@ func (s *Server) createFaultDetectRule(ctx context.Context, request *apifault.Fa // updateFaultDetectRule Update a FaultDetect rule func (s *Server) updateFaultDetectRule(ctx context.Context, request *apifault.FaultDetectRule) *apiservice.Response { - requestID := utils.ParseRequestID(ctx) - if resp := checkFaultDetectRuleParams(request, true, true); resp != nil { - return resp - } - resp := s.checkFaultDetectRuleExists(request.GetId(), requestID) - if resp != nil { - return resp - } fdRuleId := &apifault.FaultDetectRule{Id: request.GetId()} fdRule, err := api2FaultDetectRule(request) if err != nil { - log.Error(err.Error(), utils.ZapRequestID(requestID)) + log.Error(err.Error(), utils.RequestID(ctx)) return api.NewAnyDataResponse(apimodel.Code_ParseException, fdRuleId) } fdRule.ID = request.GetId() exists, err := s.storage.HasFaultDetectRuleByNameExcludeId(fdRule.Name, fdRule.Namespace, fdRule.ID) if err != nil { - log.Error(err.Error(), utils.ZapRequestID(requestID)) + log.Error(err.Error(), utils.RequestID(ctx)) return api.NewResponseWithMsg(commonstore.StoreCode2APICode(err), err.Error()) } if exists { return api.NewAnyDataResponse(apimodel.Code_FaultDetectRuleExisted, fdRuleId) } if err := s.storage.UpdateFaultDetectRule(fdRule); err != nil { - log.Error(err.Error(), utils.ZapRequestID(requestID)) + log.Error(err.Error(), utils.RequestID(ctx)) return storeError2AnyResponse(err, fdRuleId) } msg := fmt.Sprintf("update fault detect rule: id=%v, name=%v, namespace=%v", request.GetId(), request.GetName(), request.GetNamespace()) - log.Info(msg, utils.ZapRequestID(requestID)) + log.Info(msg, utils.RequestID(ctx)) s.RecordHistory(ctx, faultDetectRuleRecordEntry(ctx, request, fdRule, model.OUpdate)) return api.NewAnyDataResponse(apimodel.Code_ExecuteSuccess, fdRuleId) @@ -224,9 +155,6 @@ func (s *Server) updateFaultDetectRule(ctx context.Context, request *apifault.Fa // deleteFaultDetectRule Delete a FaultDetect rule func (s *Server) deleteFaultDetectRule(ctx context.Context, request *apifault.FaultDetectRule) *apiservice.Response { requestID := utils.ParseRequestID(ctx) - if resp := checkFaultDetectRuleParams(request, true, false); resp != nil { - return resp - } resp := s.checkFaultDetectRuleExists(request.GetId(), requestID) if resp != nil { if resp.GetCode().GetValue() == uint32(apimodel.Code_NotFoundResource) { @@ -261,36 +189,20 @@ func (s *Server) checkFaultDetectRuleExists(id, requestID string) *apiservice.Re return nil } -var ( - // FaultDetectRuleFilters filter fault detect rule query parameters - FaultDetectRuleFilters = map[string]bool{ - "brief": true, - "offset": true, - "limit": true, - "id": true, - "name": true, - "namespace": true, - "service": true, - "serviceNamespace": true, - "dstService": true, - "dstNamespace": true, - "dstMethod": true, - "description": true, - } -) - func (s *Server) GetFaultDetectRules(ctx context.Context, query map[string]string) *apiservice.BatchQueryResponse { - for key := range query { - if _, ok := FaultDetectRuleFilters[key]; !ok { - log.Errorf("params %s is not allowed in querying fault detect rule", key) - return api.NewBatchQueryResponse(apimodel.Code_InvalidParameter) - } - } - offset, limit, err := utils.ParseOffsetAndLimit(query) - if err != nil { - return api.NewBatchQueryResponse(apimodel.Code_InvalidParameter) - } - total, cbRules, err := s.storage.GetFaultDetectRules(query, offset, limit) + offset, limit, _ := utils.ParseOffsetAndLimit(query) + total, cbRules, err := s.caches.FaultDetector().Query(ctx, &cachetypes.FaultDetectArgs{ + ID: query["id"], + Name: query["name"], + Namespace: query["namespace"], + Service: query["service"], + ServiceNamespace: query["serviceNamespace"], + DstNamespace: query["dstNamespace"], + DstService: query["dstService"], + DstMethod: query["dstMethod"], + Offset: offset, + Limit: limit, + }) if err != nil { log.Errorf("get fault detect rules store err: %s", err.Error()) return api.NewBatchQueryResponse(commonstore.StoreCode2APICode(err)) @@ -301,7 +213,7 @@ func (s *Server) GetFaultDetectRules(ctx context.Context, query map[string]strin for _, cbRule := range cbRules { cbRuleProto, err := faultDetectRule2api(cbRule) if nil != err { - log.Errorf("marshal circuitbreaker rule fail: %v", err) + log.Error("marshal circuitbreaker rule fail", utils.RequestID(ctx), zap.Error(err)) continue } if nil == cbRuleProto { @@ -309,7 +221,7 @@ func (s *Server) GetFaultDetectRules(ctx context.Context, query map[string]strin } err = api.AddAnyDataIntoBatchQuery(out, cbRuleProto) if nil != err { - log.Errorf("add circuitbreaker rule as any data fail: %v", err) + log.Error("add circuitbreaker rule as any data fail", utils.RequestID(ctx), zap.Error(err)) continue } } diff --git a/service/healthcheck/check.go b/service/healthcheck/check.go index d2412b5b7..6a06cfe8d 100644 --- a/service/healthcheck/check.go +++ b/service/healthcheck/check.go @@ -55,8 +55,7 @@ type CheckScheduler struct { clientCheckIntervalSec int64 clientCheckTtlSec int64 - adoptInstancesChan chan AdoptEvent - ctx context.Context + ctx context.Context } // AdoptEvent is the event for adopt @@ -158,7 +157,6 @@ func newCheckScheduler(ctx context.Context, slotNum int, minCheckInterval time.D maxCheckIntervalSec: int64(maxCheckInterval.Seconds()), clientCheckIntervalSec: int64(clientCheckInterval.Seconds()), clientCheckTtlSec: int64(clientCheckTtl.Seconds()), - adoptInstancesChan: make(chan AdoptEvent, 1024), ctx: ctx, } return scheduler @@ -167,7 +165,6 @@ func newCheckScheduler(ctx context.Context, slotNum int, minCheckInterval time.D func (c *CheckScheduler) run(ctx context.Context) { go c.doCheckInstances(ctx) go c.doCheckClient(ctx) - go c.doAdopt(ctx) } func (c *CheckScheduler) doCheckInstances(ctx context.Context) { @@ -179,84 +176,6 @@ func (c *CheckScheduler) doCheckInstances(ctx context.Context) { log.Infof("[Health Check][Check]timeWheel has been stopped") } -const ( - batchAdoptInterval = 30 * time.Millisecond - batchAdoptCount = 30 -) - -func (c *CheckScheduler) doAdopt(ctx context.Context) { - instancesToAdd := make(map[string]bool) - instancesToRemove := make(map[string]bool) - var checker plugin.HealthChecker - ticker := time.NewTicker(batchAdoptInterval) - defer func() { - ticker.Stop() - }() - for { - select { - case event := <-c.adoptInstancesChan: - instanceId := event.InstanceId - if event.Add { - instancesToAdd[instanceId] = true - delete(instancesToRemove, instanceId) - } else { - instancesToRemove[instanceId] = true - delete(instancesToAdd, instanceId) - } - checker = event.Checker - if len(instancesToAdd) == batchAdoptCount { - instancesToAdd = c.processAdoptEvents(instancesToAdd, true, checker) - } - if len(instancesToRemove) == batchAdoptCount { - instancesToRemove = c.processAdoptEvents(instancesToRemove, false, checker) - } - case <-ticker.C: - if len(instancesToAdd) > 0 { - instancesToAdd = c.processAdoptEvents(instancesToAdd, true, checker) - } - if len(instancesToRemove) > 0 { - instancesToRemove = c.processAdoptEvents(instancesToRemove, false, checker) - } - case <-ctx.Done(): - log.Infof("[Health Check][Check]adopting routine has been stopped") - return - } - } -} - -func (c *CheckScheduler) processAdoptEvents( - instances map[string]bool, add bool, checker plugin.HealthChecker) map[string]bool { - instanceIds := make([]string, 0, len(instances)) - for id := range instances { - instanceIds = append(instanceIds, id) - } - log.Debug("[Health Check][Check] adopt event", zap.Any("instances", instanceIds), - zap.String("server", c.svr.localHost), zap.Bool("add", add)) - return instances -} - -func (c *CheckScheduler) addAdopting(instanceId string, checker plugin.HealthChecker) { - select { - case c.adoptInstancesChan <- AdoptEvent{ - InstanceId: instanceId, - Add: true, - Checker: checker}: - case <-c.ctx.Done(): - return - } -} - -func (c *CheckScheduler) removeAdopting(instanceId string, checker plugin.HealthChecker) { - select { - case c.adoptInstancesChan <- AdoptEvent{ - InstanceId: instanceId, - Add: false, - Checker: checker}: - case <-c.ctx.Done(): - return - } -} - func (c *CheckScheduler) upsertInstanceChecker(instanceWithChecker *InstanceWithChecker) (bool, *itemValue) { c.rwMutex.Lock() defer c.rwMutex.Unlock() @@ -343,7 +262,6 @@ func (c *CheckScheduler) UpsertInstance(instanceWithChecker *InstanceWithChecker if firstadd { return } - c.addAdopting(instValue.id, instValue.checker) instance := instanceWithChecker.instance log.Infof("[Health Check][Check]add check instance is %s, host is %s:%d", instance.ID(), instance.Host(), instance.Port()) @@ -352,11 +270,9 @@ func (c *CheckScheduler) UpsertInstance(instanceWithChecker *InstanceWithChecker // AddClient add client to check func (c *CheckScheduler) AddClient(clientWithChecker *ClientWithChecker) { - exists, instValue := c.putClientIfAbsent(clientWithChecker) - if exists { + if exists, _ := c.putClientIfAbsent(clientWithChecker); exists { return } - c.addAdopting(instValue.id, instValue.checker) client := clientWithChecker.client log.Infof("[Health Check][Check]add check client is %s, host is %s:%d", client.Proto().GetId().GetValue(), client.Proto().GetHost(), 0) @@ -529,9 +445,6 @@ func (c *CheckScheduler) DelClient(clientWithChecker *ClientWithChecker) { exists := c.delClientIfPresent(clientId) log.Infof("[Health Check][Check]remove check client is %s:%d, id is %s, exists is %v", client.Proto().GetHost().GetValue(), 0, clientId, exists) - if exists { - c.removeAdopting(clientId, clientWithChecker.checker) - } } // DelInstance del instance from check @@ -541,9 +454,6 @@ func (c *CheckScheduler) DelInstance(instanceWithChecker *InstanceWithChecker) { exists := c.delInstanceIfPresent(instanceId) log.Infof("[Health Check][Check]remove check instance is %s:%d, id is %s, exists is %v", instance.Host(), instance.Port(), instanceId, exists) - if exists { - c.removeAdopting(instanceId, instanceWithChecker.checker) - } } func (c *CheckScheduler) delInstanceIfPresent(instanceId string) bool { diff --git a/service/healthcheck/option.go b/service/healthcheck/option.go index 1c8c71da4..3ff47826f 100644 --- a/service/healthcheck/option.go +++ b/service/healthcheck/option.go @@ -81,15 +81,6 @@ func withChecker() serverOption { } } -// WithPlugins . -func WithPlugins() serverOption { - return func(svr *Server) error { - svr.history = plugin.GetHistory() - svr.discoverEvent = plugin.GetDiscoverEvent() - return nil - } -} - // withCacheProvider . func withCacheProvider() serverOption { return func(svr *Server) error { diff --git a/service/healthcheck/report.go b/service/healthcheck/report.go index 38471a8d3..2b297b50d 100644 --- a/service/healthcheck/report.go +++ b/service/healthcheck/report.go @@ -61,6 +61,7 @@ func (s *Server) checkInstanceExists(ctx context.Context, id string) (int64, *mo return -1, nil, apimodel.Code_ExecuteSuccess } if resp.Count > max404Count { + log.Errorf("[healthcheck] not found heartbeat record by id %s, count: %v", id, resp.Count) return resp.Count, nil, apimodel.Code_NotFoundResource } return resp.Count, nil, apimodel.Code_ExecuteSuccess diff --git a/service/healthcheck/server.go b/service/healthcheck/server.go index 466362a6e..9a75eaf5f 100644 --- a/service/healthcheck/server.go +++ b/service/healthcheck/server.go @@ -54,8 +54,6 @@ type Server struct { timeAdjuster *TimeAdjuster dispatcher *Dispatcher checkScheduler *CheckScheduler - history plugin.History - discoverEvent plugin.DiscoverChannel localHost string bc *batch.Controller serviceCache cachetypes.ServiceCache @@ -117,7 +115,6 @@ func initialize(ctx context.Context, hcOpt *Config, bc *batch.Controller) error } svr, err := NewHealthServer(ctx, hcOpt, - WithPlugins(), WithStore(storage), WithBatchController(bc), WithTimeAdjuster(newTimeAdjuster(ctx, storage)), @@ -212,16 +209,13 @@ func (s *Server) ListCheckerServer() []*model.Instance { // RecordHistory server对外提供history插件的简单封装 func (s *Server) RecordHistory(entry *model.RecordEntry) { // 如果插件没有初始化,那么不记录history - if s.history == nil { - return - } // 如果数据为空,则不需要打印了 if entry == nil { return } // 调用插件记录history - s.history.Record(entry) + plugin.GetHistory().Record(entry) } // publishInstanceEvent 发布服务事件 diff --git a/service/healthcheck/test_export.go b/service/healthcheck/test_export.go index 9920424e8..7f6210667 100644 --- a/service/healthcheck/test_export.go +++ b/service/healthcheck/test_export.go @@ -30,7 +30,6 @@ func TestInitialize(ctx context.Context, hcOpt *Config, bc *batch.Controller, testServer, err := NewHealthServer(ctx, hcOpt, WithStore(storage), WithBatchController(bc), - WithPlugins(), WithTimeAdjuster(newTimeAdjuster(ctx, storage)), ) if err != nil { diff --git a/service/instance_check_test.go b/service/instance_check_test.go index a6cf62b7b..ef179b5a9 100644 --- a/service/instance_check_test.go +++ b/service/instance_check_test.go @@ -23,11 +23,12 @@ import ( "testing" "time" - "github.com/polarismesh/polaris/cache" apimodel "github.com/polarismesh/specification/source/go/api/v1/model" apiservice "github.com/polarismesh/specification/source/go/api/v1/service_manage" "github.com/stretchr/testify/assert" "google.golang.org/protobuf/types/known/wrapperspb" + + "github.com/polarismesh/polaris/cache" ) func TestInstanceCheck(t *testing.T) { diff --git a/service/interceptor/auth/circuitbreaker_config_authability.go b/service/interceptor/auth/circuitbreaker_config.go similarity index 77% rename from service/interceptor/auth/circuitbreaker_config_authability.go rename to service/interceptor/auth/circuitbreaker_config.go index e8c215711..1067fd57a 100644 --- a/service/interceptor/auth/circuitbreaker_config_authability.go +++ b/service/interceptor/auth/circuitbreaker_config.go @@ -25,73 +25,73 @@ import ( ) // CreateCircuitBreakers creates circuit breakers -func (svr *ServerAuthAbility) CreateCircuitBreakers(ctx context.Context, +func (svr *Server) CreateCircuitBreakers(ctx context.Context, reqs []*apifault.CircuitBreaker) *apiservice.BatchWriteResponse { return svr.nextSvr.CreateCircuitBreakers(ctx, reqs) } // CreateCircuitBreakerVersions creates circuit breaker versions -func (svr *ServerAuthAbility) CreateCircuitBreakerVersions(ctx context.Context, +func (svr *Server) CreateCircuitBreakerVersions(ctx context.Context, reqs []*apifault.CircuitBreaker) *apiservice.BatchWriteResponse { return svr.nextSvr.CreateCircuitBreakerVersions(ctx, reqs) } // DeleteCircuitBreakers delete circuit breakers -func (svr *ServerAuthAbility) DeleteCircuitBreakers(ctx context.Context, +func (svr *Server) DeleteCircuitBreakers(ctx context.Context, reqs []*apifault.CircuitBreaker) *apiservice.BatchWriteResponse { return svr.nextSvr.DeleteCircuitBreakers(ctx, reqs) } // UpdateCircuitBreakers update circuit breakers -func (svr *ServerAuthAbility) UpdateCircuitBreakers(ctx context.Context, +func (svr *Server) UpdateCircuitBreakers(ctx context.Context, reqs []*apifault.CircuitBreaker) *apiservice.BatchWriteResponse { return svr.nextSvr.UpdateCircuitBreakers(ctx, reqs) } // ReleaseCircuitBreakers release circuit breakers -func (svr *ServerAuthAbility) ReleaseCircuitBreakers(ctx context.Context, +func (svr *Server) ReleaseCircuitBreakers(ctx context.Context, reqs []*apiservice.ConfigRelease) *apiservice.BatchWriteResponse { return svr.nextSvr.ReleaseCircuitBreakers(ctx, reqs) } // UnBindCircuitBreakers unbind circuit breakers -func (svr *ServerAuthAbility) UnBindCircuitBreakers(ctx context.Context, +func (svr *Server) UnBindCircuitBreakers(ctx context.Context, reqs []*apiservice.ConfigRelease) *apiservice.BatchWriteResponse { return svr.nextSvr.UnBindCircuitBreakers(ctx, reqs) } // GetCircuitBreaker get circuit breaker -func (svr *ServerAuthAbility) GetCircuitBreaker(ctx context.Context, +func (svr *Server) GetCircuitBreaker(ctx context.Context, query map[string]string) *apiservice.BatchQueryResponse { return svr.nextSvr.GetCircuitBreaker(ctx, query) } // GetCircuitBreakerVersions get circuit breaker versions -func (svr *ServerAuthAbility) GetCircuitBreakerVersions(ctx context.Context, +func (svr *Server) GetCircuitBreakerVersions(ctx context.Context, query map[string]string) *apiservice.BatchQueryResponse { return svr.nextSvr.GetCircuitBreakerVersions(ctx, query) } // GetMasterCircuitBreakers get master circuit breakers -func (svr *ServerAuthAbility) GetMasterCircuitBreakers(ctx context.Context, +func (svr *Server) GetMasterCircuitBreakers(ctx context.Context, query map[string]string) *apiservice.BatchQueryResponse { return svr.nextSvr.GetMasterCircuitBreakers(ctx, query) } // GetReleaseCircuitBreakers get release circuit breakers -func (svr *ServerAuthAbility) GetReleaseCircuitBreakers(ctx context.Context, +func (svr *Server) GetReleaseCircuitBreakers(ctx context.Context, query map[string]string) *apiservice.BatchQueryResponse { return svr.nextSvr.GetReleaseCircuitBreakers(ctx, query) } // GetCircuitBreakerByService get circuit breaker by service -func (svr *ServerAuthAbility) GetCircuitBreakerByService(ctx context.Context, +func (svr *Server) GetCircuitBreakerByService(ctx context.Context, query map[string]string) *apiservice.BatchQueryResponse { return svr.nextSvr.GetCircuitBreakerByService(ctx, query) } // GetCircuitBreakerToken get circuit breaker token -func (svr *ServerAuthAbility) GetCircuitBreakerToken( +func (svr *Server) GetCircuitBreakerToken( ctx context.Context, req *apifault.CircuitBreaker) *apiservice.Response { return svr.nextSvr.GetCircuitBreakerToken(ctx, req) } diff --git a/service/interceptor/auth/circuitbreaker_rule_authability.go b/service/interceptor/auth/circuitbreaker_rule.go similarity index 59% rename from service/interceptor/auth/circuitbreaker_rule_authability.go rename to service/interceptor/auth/circuitbreaker_rule.go index 37e5640d0..99201c302 100644 --- a/service/interceptor/auth/circuitbreaker_rule_authability.go +++ b/service/interceptor/auth/circuitbreaker_rule.go @@ -21,21 +21,22 @@ import ( "context" apifault "github.com/polarismesh/specification/source/go/api/v1/fault_tolerance" + "github.com/polarismesh/specification/source/go/api/v1/security" apiservice "github.com/polarismesh/specification/source/go/api/v1/service_manage" + cachetypes "github.com/polarismesh/polaris/cache/api" api "github.com/polarismesh/polaris/common/api/v1" "github.com/polarismesh/polaris/common/model" + authcommon "github.com/polarismesh/polaris/common/model/auth" "github.com/polarismesh/polaris/common/utils" ) -func (svr *ServerAuthAbility) CreateCircuitBreakerRules( +func (svr *Server) CreateCircuitBreakerRules( ctx context.Context, request []*apifault.CircuitBreakerRule) *apiservice.BatchWriteResponse { - // TODO not support CircuitBreaker resource auth, so we set op is read - authCtx := svr.collectCircuitBreakerRuleV2AuthContext(ctx, request, model.Read, "CreateCircuitBreakerRules") + authCtx := svr.collectCircuitBreakerRuleV2AuthContext(ctx, request, authcommon.Create, authcommon.CreateCircuitBreakerRules) - _, err := svr.policyMgr.GetAuthChecker().CheckConsolePermission(authCtx) - if err != nil { - return api.NewBatchWriteResponse(convertToErrCode(err)) + if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + return api.NewBatchWriteResponse(authcommon.ConvertToErrCode(err)) } ctx = authCtx.GetRequestContext() @@ -43,12 +44,11 @@ func (svr *ServerAuthAbility) CreateCircuitBreakerRules( return svr.nextSvr.CreateCircuitBreakerRules(ctx, request) } -func (svr *ServerAuthAbility) DeleteCircuitBreakerRules( +func (svr *Server) DeleteCircuitBreakerRules( ctx context.Context, request []*apifault.CircuitBreakerRule) *apiservice.BatchWriteResponse { - authCtx := svr.collectCircuitBreakerRuleV2AuthContext(ctx, request, model.Read, "DeleteCircuitBreakerRules") - _, err := svr.policyMgr.GetAuthChecker().CheckConsolePermission(authCtx) - if err != nil { - return api.NewBatchWriteResponse(convertToErrCode(err)) + authCtx := svr.collectCircuitBreakerRuleV2AuthContext(ctx, request, authcommon.Delete, authcommon.DeleteCircuitBreakerRules) + if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + return api.NewBatchWriteResponse(authcommon.ConvertToErrCode(err)) } ctx = authCtx.GetRequestContext() @@ -56,12 +56,11 @@ func (svr *ServerAuthAbility) DeleteCircuitBreakerRules( return svr.nextSvr.DeleteCircuitBreakerRules(ctx, request) } -func (svr *ServerAuthAbility) EnableCircuitBreakerRules( +func (svr *Server) EnableCircuitBreakerRules( ctx context.Context, request []*apifault.CircuitBreakerRule) *apiservice.BatchWriteResponse { - authCtx := svr.collectCircuitBreakerRuleV2AuthContext(ctx, request, model.Read, "EnableCircuitBreakerRules") - _, err := svr.policyMgr.GetAuthChecker().CheckConsolePermission(authCtx) - if err != nil { - return api.NewBatchWriteResponse(convertToErrCode(err)) + authCtx := svr.collectCircuitBreakerRuleV2AuthContext(ctx, request, authcommon.Modify, authcommon.EnableCircuitBreakerRules) + if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + return api.NewBatchWriteResponse(authcommon.ConvertToErrCode(err)) } ctx = authCtx.GetRequestContext() @@ -69,12 +68,11 @@ func (svr *ServerAuthAbility) EnableCircuitBreakerRules( return svr.nextSvr.EnableCircuitBreakerRules(ctx, request) } -func (svr *ServerAuthAbility) UpdateCircuitBreakerRules( +func (svr *Server) UpdateCircuitBreakerRules( ctx context.Context, request []*apifault.CircuitBreakerRule) *apiservice.BatchWriteResponse { - authCtx := svr.collectCircuitBreakerRuleV2AuthContext(ctx, request, model.Read, "UpdateCircuitBreakerRules") - _, err := svr.policyMgr.GetAuthChecker().CheckConsolePermission(authCtx) - if err != nil { - return api.NewBatchWriteResponse(convertToErrCode(err)) + authCtx := svr.collectCircuitBreakerRuleV2AuthContext(ctx, request, authcommon.Modify, authcommon.UpdateCircuitBreakerRules) + if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + return api.NewBatchWriteResponse(authcommon.ConvertToErrCode(err)) } ctx = authCtx.GetRequestContext() @@ -82,15 +80,23 @@ func (svr *ServerAuthAbility) UpdateCircuitBreakerRules( return svr.nextSvr.UpdateCircuitBreakerRules(ctx, request) } -func (svr *ServerAuthAbility) GetCircuitBreakerRules( +func (svr *Server) GetCircuitBreakerRules( ctx context.Context, query map[string]string) *apiservice.BatchQueryResponse { - authCtx := svr.collectCircuitBreakerRuleV2AuthContext(ctx, nil, model.Read, "GetCircuitBreakerRules") - _, err := svr.policyMgr.GetAuthChecker().CheckConsolePermission(authCtx) - if err != nil { - return api.NewBatchQueryResponse(convertToErrCode(err)) + authCtx := svr.collectCircuitBreakerRuleV2AuthContext(ctx, nil, authcommon.Read, authcommon.DescribeCircuitBreakerRules) + if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + return api.NewBatchQueryResponse(authcommon.ConvertToErrCode(err)) } ctx = authCtx.GetRequestContext() ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) + + cachetypes.AppendCircuitBreakerRulePredicate(ctx, func(ctx context.Context, cbr *model.CircuitBreakerRule) bool { + return svr.policySvr.GetAuthChecker().ResourcePredicate(authCtx, &authcommon.ResourceEntry{ + Type: security.ResourceType_CircuitBreakerRules, + ID: cbr.ID, + Metadata: cbr.Proto.Metadata, + }) + }) + return svr.nextSvr.GetCircuitBreakerRules(ctx, query) } diff --git a/service/interceptor/auth/client_v1_authability.go b/service/interceptor/auth/client_v1.go similarity index 62% rename from service/interceptor/auth/client_v1_authability.go rename to service/interceptor/auth/client_v1.go index fec85a0f6..fb1fc8dec 100644 --- a/service/interceptor/auth/client_v1_authability.go +++ b/service/interceptor/auth/client_v1.go @@ -25,17 +25,18 @@ import ( api "github.com/polarismesh/polaris/common/api/v1" "github.com/polarismesh/polaris/common/model" + authcommon "github.com/polarismesh/polaris/common/model/auth" "github.com/polarismesh/polaris/common/utils" ) // RegisterInstance create one instance -func (svr *ServerAuthAbility) RegisterInstance(ctx context.Context, req *apiservice.Instance) *apiservice.Response { +func (svr *Server) RegisterInstance(ctx context.Context, req *apiservice.Instance) *apiservice.Response { authCtx := svr.collectClientInstanceAuthContext( - ctx, []*apiservice.Instance{req}, model.Create, "RegisterInstance") + ctx, []*apiservice.Instance{req}, authcommon.Create, authcommon.RegisterInstance) - _, err := svr.policyMgr.GetAuthChecker().CheckClientPermission(authCtx) + _, err := svr.policySvr.GetAuthChecker().CheckClientPermission(authCtx) if err != nil { - resp := api.NewResponseWithMsg(convertToErrCode(err), err.Error()) + resp := api.NewResponseWithMsg(authcommon.ConvertToErrCode(err), err.Error()) return resp } @@ -46,13 +47,13 @@ func (svr *ServerAuthAbility) RegisterInstance(ctx context.Context, req *apiserv } // DeregisterInstance delete onr instance -func (svr *ServerAuthAbility) DeregisterInstance(ctx context.Context, req *apiservice.Instance) *apiservice.Response { +func (svr *Server) DeregisterInstance(ctx context.Context, req *apiservice.Instance) *apiservice.Response { authCtx := svr.collectClientInstanceAuthContext( - ctx, []*apiservice.Instance{req}, model.Create, "DeregisterInstance") + ctx, []*apiservice.Instance{req}, authcommon.Create, authcommon.DeregisterInstance) - _, err := svr.policyMgr.GetAuthChecker().CheckClientPermission(authCtx) + _, err := svr.policySvr.GetAuthChecker().CheckClientPermission(authCtx) if err != nil { - resp := api.NewResponseWithMsg(convertToErrCode(err), err.Error()) + resp := api.NewResponseWithMsg(authcommon.ConvertToErrCode(err), err.Error()) return resp } @@ -63,21 +64,21 @@ func (svr *ServerAuthAbility) DeregisterInstance(ctx context.Context, req *apise } // ReportClient is the interface for reporting client authability -func (svr *ServerAuthAbility) ReportClient(ctx context.Context, req *apiservice.Client) *apiservice.Response { +func (svr *Server) ReportClient(ctx context.Context, req *apiservice.Client) *apiservice.Response { return svr.nextSvr.ReportClient(ctx, req) } // ReportServiceContract . -func (svr *ServerAuthAbility) ReportServiceContract(ctx context.Context, req *apiservice.ServiceContract) *apiservice.Response { +func (svr *Server) ReportServiceContract(ctx context.Context, req *apiservice.ServiceContract) *apiservice.Response { authCtx := svr.collectServiceAuthContext( ctx, []*apiservice.Service{{ Name: wrapperspb.String(req.GetService()), Namespace: wrapperspb.String(req.GetNamespace()), - }}, model.Create, "ReportServiceContract") + }}, authcommon.Create, authcommon.ReportServiceContract) - _, err := svr.policyMgr.GetAuthChecker().CheckClientPermission(authCtx) + _, err := svr.policySvr.GetAuthChecker().CheckClientPermission(authCtx) if err != nil { - resp := api.NewResponseWithMsg(convertToErrCode(err), err.Error()) + resp := api.NewResponseWithMsg(authcommon.ConvertToErrCode(err), err.Error()) return resp } @@ -87,21 +88,21 @@ func (svr *ServerAuthAbility) ReportServiceContract(ctx context.Context, req *ap } // GetPrometheusTargets Used for client acquisition service information -func (svr *ServerAuthAbility) GetPrometheusTargets(ctx context.Context, +func (svr *Server) GetPrometheusTargets(ctx context.Context, query map[string]string) *model.PrometheusDiscoveryResponse { return svr.nextSvr.GetPrometheusTargets(ctx, query) } // GetServiceWithCache is the interface for getting service with cache -func (svr *ServerAuthAbility) GetServiceWithCache( +func (svr *Server) GetServiceWithCache( ctx context.Context, req *apiservice.Service) *apiservice.DiscoverResponse { authCtx := svr.collectServiceAuthContext( - ctx, []*apiservice.Service{req}, model.Read, "DiscoverServices") - _, err := svr.policyMgr.GetAuthChecker().CheckClientPermission(authCtx) + ctx, []*apiservice.Service{req}, authcommon.Read, authcommon.DiscoverServices) + _, err := svr.policySvr.GetAuthChecker().CheckClientPermission(authCtx) if err != nil { - resp := api.NewDiscoverResponse(convertToErrCode(err)) + resp := api.NewDiscoverResponse(authcommon.ConvertToErrCode(err)) resp.Info = utils.NewStringValue(err.Error()) return resp } @@ -112,14 +113,14 @@ func (svr *ServerAuthAbility) GetServiceWithCache( } // ServiceInstancesCache is the interface for getting service instances cache -func (svr *ServerAuthAbility) ServiceInstancesCache( +func (svr *Server) ServiceInstancesCache( ctx context.Context, filter *apiservice.DiscoverFilter, req *apiservice.Service) *apiservice.DiscoverResponse { authCtx := svr.collectServiceAuthContext( - ctx, []*apiservice.Service{req}, model.Read, "DiscoverInstances") - _, err := svr.policyMgr.GetAuthChecker().CheckClientPermission(authCtx) + ctx, []*apiservice.Service{req}, authcommon.Read, authcommon.DiscoverInstances) + _, err := svr.policySvr.GetAuthChecker().CheckClientPermission(authCtx) if err != nil { - resp := api.NewDiscoverResponse(convertToErrCode(err)) + resp := api.NewDiscoverResponse(authcommon.ConvertToErrCode(err)) resp.Info = utils.NewStringValue(err.Error()) return resp } @@ -130,14 +131,14 @@ func (svr *ServerAuthAbility) ServiceInstancesCache( } // GetRoutingConfigWithCache is the interface for getting routing config with cache -func (svr *ServerAuthAbility) GetRoutingConfigWithCache( +func (svr *Server) GetRoutingConfigWithCache( ctx context.Context, req *apiservice.Service) *apiservice.DiscoverResponse { authCtx := svr.collectServiceAuthContext( - ctx, []*apiservice.Service{req}, model.Read, "DiscoverRouterRule") - _, err := svr.policyMgr.GetAuthChecker().CheckClientPermission(authCtx) + ctx, []*apiservice.Service{req}, authcommon.Read, authcommon.DiscoverRouterRule) + _, err := svr.policySvr.GetAuthChecker().CheckClientPermission(authCtx) if err != nil { - resp := api.NewDiscoverResponse(convertToErrCode(err)) + resp := api.NewDiscoverResponse(authcommon.ConvertToErrCode(err)) resp.Info = utils.NewStringValue(err.Error()) return resp } @@ -148,14 +149,14 @@ func (svr *ServerAuthAbility) GetRoutingConfigWithCache( } // GetRateLimitWithCache is the interface for getting rate limit with cache -func (svr *ServerAuthAbility) GetRateLimitWithCache( +func (svr *Server) GetRateLimitWithCache( ctx context.Context, req *apiservice.Service) *apiservice.DiscoverResponse { authCtx := svr.collectServiceAuthContext( - ctx, []*apiservice.Service{req}, model.Read, "DiscoverRateLimit") - _, err := svr.policyMgr.GetAuthChecker().CheckClientPermission(authCtx) + ctx, []*apiservice.Service{req}, authcommon.Read, authcommon.DiscoverRateLimitRule) + _, err := svr.policySvr.GetAuthChecker().CheckClientPermission(authCtx) if err != nil { - resp := api.NewDiscoverResponse(convertToErrCode(err)) + resp := api.NewDiscoverResponse(authcommon.ConvertToErrCode(err)) resp.Info = utils.NewStringValue(err.Error()) return resp } @@ -166,14 +167,14 @@ func (svr *ServerAuthAbility) GetRateLimitWithCache( } // GetCircuitBreakerWithCache is the interface for getting a circuit breaker with cache -func (svr *ServerAuthAbility) GetCircuitBreakerWithCache( +func (svr *Server) GetCircuitBreakerWithCache( ctx context.Context, req *apiservice.Service) *apiservice.DiscoverResponse { authCtx := svr.collectServiceAuthContext( - ctx, []*apiservice.Service{req}, model.Read, "DiscoverCircuitBreaker") - _, err := svr.policyMgr.GetAuthChecker().CheckClientPermission(authCtx) + ctx, []*apiservice.Service{req}, authcommon.Read, authcommon.DiscoverCircuitBreakerRule) + _, err := svr.policySvr.GetAuthChecker().CheckClientPermission(authCtx) if err != nil { - resp := api.NewDiscoverResponse(convertToErrCode(err)) + resp := api.NewDiscoverResponse(authcommon.ConvertToErrCode(err)) resp.Info = utils.NewStringValue(err.Error()) return resp } @@ -183,14 +184,14 @@ func (svr *ServerAuthAbility) GetCircuitBreakerWithCache( return svr.nextSvr.GetCircuitBreakerWithCache(ctx, req) } -func (svr *ServerAuthAbility) GetFaultDetectWithCache( +func (svr *Server) GetFaultDetectWithCache( ctx context.Context, req *apiservice.Service) *apiservice.DiscoverResponse { authCtx := svr.collectServiceAuthContext( - ctx, []*apiservice.Service{req}, model.Read, "DiscoverFaultDetect") - _, err := svr.policyMgr.GetAuthChecker().CheckClientPermission(authCtx) + ctx, []*apiservice.Service{req}, authcommon.Read, authcommon.DiscoverFaultDetectRule) + _, err := svr.policySvr.GetAuthChecker().CheckClientPermission(authCtx) if err != nil { - resp := api.NewDiscoverResponse(convertToErrCode(err)) + resp := api.NewDiscoverResponse(authcommon.ConvertToErrCode(err)) resp.Info = utils.NewStringValue(err.Error()) return resp } @@ -201,13 +202,13 @@ func (svr *ServerAuthAbility) GetFaultDetectWithCache( } // UpdateInstance update single instance -func (svr *ServerAuthAbility) UpdateInstance(ctx context.Context, req *apiservice.Instance) *apiservice.Response { +func (svr *Server) UpdateInstance(ctx context.Context, req *apiservice.Instance) *apiservice.Response { authCtx := svr.collectClientInstanceAuthContext( - ctx, []*apiservice.Instance{req}, model.Modify, "UpdateInstance") + ctx, []*apiservice.Instance{req}, authcommon.Modify, authcommon.UpdateInstance) - _, err := svr.policyMgr.GetAuthChecker().CheckClientPermission(authCtx) + _, err := svr.policySvr.GetAuthChecker().CheckClientPermission(authCtx) if err != nil { - resp := api.NewResponseWithMsg(convertToErrCode(err), err.Error()) + resp := api.NewResponseWithMsg(authcommon.ConvertToErrCode(err), err.Error()) return resp } @@ -218,16 +219,16 @@ func (svr *ServerAuthAbility) UpdateInstance(ctx context.Context, req *apiservic } // GetServiceContractWithCache User Client Get ServiceContract Rule Information -func (svr *ServerAuthAbility) GetServiceContractWithCache(ctx context.Context, +func (svr *Server) GetServiceContractWithCache(ctx context.Context, req *apiservice.ServiceContract) *apiservice.Response { authCtx := svr.collectServiceAuthContext(ctx, []*apiservice.Service{{ Namespace: wrapperspb.String(req.Namespace), Name: wrapperspb.String(req.Service), - }}, model.Read, "GetServiceContractWithCache") + }}, authcommon.Read, authcommon.DiscoverServiceContract) - _, err := svr.policyMgr.GetAuthChecker().CheckClientPermission(authCtx) + _, err := svr.policySvr.GetAuthChecker().CheckClientPermission(authCtx) if err != nil { - resp := api.NewResponse(convertToErrCode(err)) + resp := api.NewResponse(authcommon.ConvertToErrCode(err)) resp.Info = utils.NewStringValue(err.Error()) return resp } @@ -239,12 +240,12 @@ func (svr *ServerAuthAbility) GetServiceContractWithCache(ctx context.Context, } // GetLaneRuleWithCache fetch lane rules by client -func (svr *ServerAuthAbility) GetLaneRuleWithCache(ctx context.Context, req *apiservice.Service) *apiservice.DiscoverResponse { +func (svr *Server) GetLaneRuleWithCache(ctx context.Context, req *apiservice.Service) *apiservice.DiscoverResponse { authCtx := svr.collectServiceAuthContext( - ctx, []*apiservice.Service{req}, model.Read, "DiscoverLaneRule") - _, err := svr.policyMgr.GetAuthChecker().CheckClientPermission(authCtx) + ctx, []*apiservice.Service{req}, authcommon.Read, authcommon.DiscoverLaneRule) + _, err := svr.policySvr.GetAuthChecker().CheckClientPermission(authCtx) if err != nil { - resp := api.NewDiscoverResponse(convertToErrCode(err)) + resp := api.NewDiscoverResponse(authcommon.ConvertToErrCode(err)) resp.Info = utils.NewStringValue(err.Error()) return resp } diff --git a/service/interceptor/auth/faultdetect_config_authability.go b/service/interceptor/auth/faultdetect_config.go similarity index 57% rename from service/interceptor/auth/faultdetect_config_authability.go rename to service/interceptor/auth/faultdetect_config.go index fce58ac5f..ef4b09d76 100644 --- a/service/interceptor/auth/faultdetect_config_authability.go +++ b/service/interceptor/auth/faultdetect_config.go @@ -21,56 +21,68 @@ import ( "context" apifault "github.com/polarismesh/specification/source/go/api/v1/fault_tolerance" + "github.com/polarismesh/specification/source/go/api/v1/security" apiservice "github.com/polarismesh/specification/source/go/api/v1/service_manage" + cachetypes "github.com/polarismesh/polaris/cache/api" api "github.com/polarismesh/polaris/common/api/v1" "github.com/polarismesh/polaris/common/model" + authcommon "github.com/polarismesh/polaris/common/model/auth" "github.com/polarismesh/polaris/common/utils" ) -func (svr *ServerAuthAbility) CreateFaultDetectRules( +func (svr *Server) CreateFaultDetectRules( ctx context.Context, request []*apifault.FaultDetectRule) *apiservice.BatchWriteResponse { - authCtx := svr.collectFaultDetectAuthContext(ctx, request, model.Read, "CreateFaultDetectRules") - if _, err := svr.policyMgr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewBatchWriteResponse(convertToErrCode(err)) + authCtx := svr.collectFaultDetectAuthContext(ctx, request, authcommon.Read, authcommon.CreateFaultDetectRules) + if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + return api.NewBatchWriteResponse(authcommon.ConvertToErrCode(err)) } ctx = authCtx.GetRequestContext() ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) return svr.nextSvr.CreateFaultDetectRules(ctx, request) } -func (svr *ServerAuthAbility) DeleteFaultDetectRules( +func (svr *Server) DeleteFaultDetectRules( ctx context.Context, request []*apifault.FaultDetectRule) *apiservice.BatchWriteResponse { - authCtx := svr.collectFaultDetectAuthContext(ctx, request, model.Read, "DeleteFaultDetectRules") - if _, err := svr.policyMgr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewBatchWriteResponse(convertToErrCode(err)) + authCtx := svr.collectFaultDetectAuthContext(ctx, request, authcommon.Read, authcommon.DeleteFaultDetectRules) + if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + return api.NewBatchWriteResponse(authcommon.ConvertToErrCode(err)) } ctx = authCtx.GetRequestContext() ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) return svr.nextSvr.DeleteFaultDetectRules(ctx, request) } -func (svr *ServerAuthAbility) UpdateFaultDetectRules( +func (svr *Server) UpdateFaultDetectRules( ctx context.Context, request []*apifault.FaultDetectRule) *apiservice.BatchWriteResponse { - authCtx := svr.collectFaultDetectAuthContext(ctx, request, model.Read, "UpdateFaultDetectRules") - if _, err := svr.policyMgr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewBatchWriteResponse(convertToErrCode(err)) + authCtx := svr.collectFaultDetectAuthContext(ctx, request, authcommon.Read, authcommon.UpdateFaultDetectRules) + if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + return api.NewBatchWriteResponse(authcommon.ConvertToErrCode(err)) } ctx = authCtx.GetRequestContext() ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) return svr.nextSvr.UpdateFaultDetectRules(ctx, request) } -func (svr *ServerAuthAbility) GetFaultDetectRules( +func (svr *Server) GetFaultDetectRules( ctx context.Context, query map[string]string) *apiservice.BatchQueryResponse { - authCtx := svr.collectFaultDetectAuthContext(ctx, nil, model.Read, "GetFaultDetectRules") - if _, err := svr.policyMgr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewBatchQueryResponse(convertToErrCode(err)) + authCtx := svr.collectFaultDetectAuthContext(ctx, nil, authcommon.Read, authcommon.DescribeFaultDetectRules) + if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + return api.NewBatchQueryResponse(authcommon.ConvertToErrCode(err)) } ctx = authCtx.GetRequestContext() ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) + + cachetypes.AppendFaultDetectRulePredicate(ctx, func(ctx context.Context, cbr *model.FaultDetectRule) bool { + return svr.policySvr.GetAuthChecker().ResourcePredicate(authCtx, &authcommon.ResourceEntry{ + Type: security.ResourceType_FaultDetectRules, + ID: cbr.ID, + Metadata: cbr.Proto.Metadata, + }) + }) + return svr.nextSvr.GetFaultDetectRules(ctx, query) } diff --git a/service/interceptor/auth/instance_authability.go b/service/interceptor/auth/instance.go similarity index 58% rename from service/interceptor/auth/instance_authability.go rename to service/interceptor/auth/instance.go index 6accfc712..071d9f856 100644 --- a/service/interceptor/auth/instance_authability.go +++ b/service/interceptor/auth/instance.go @@ -24,19 +24,18 @@ import ( apiservice "github.com/polarismesh/specification/source/go/api/v1/service_manage" api "github.com/polarismesh/polaris/common/api/v1" - "github.com/polarismesh/polaris/common/model" authcommon "github.com/polarismesh/polaris/common/model/auth" "github.com/polarismesh/polaris/common/utils" ) // CreateInstances create instances -func (svr *ServerAuthAbility) CreateInstances(ctx context.Context, +func (svr *Server) CreateInstances(ctx context.Context, reqs []*apiservice.Instance) *apiservice.BatchWriteResponse { - authCtx := svr.collectInstanceAuthContext(ctx, reqs, model.Create, "CreateInstances") + authCtx := svr.collectInstanceAuthContext(ctx, reqs, authcommon.Create, authcommon.CreateInstances) - _, err := svr.policyMgr.GetAuthChecker().CheckConsolePermission(authCtx) + _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx) if err != nil { - resp := api.NewResponseWithMsg(convertToErrCode(err), err.Error()) + resp := api.NewResponseWithMsg(authcommon.ConvertToErrCode(err), err.Error()) batchResp := api.NewBatchWriteResponse(apimodel.Code_ExecuteSuccess) api.Collect(batchResp, resp) return batchResp @@ -49,13 +48,13 @@ func (svr *ServerAuthAbility) CreateInstances(ctx context.Context, } // DeleteInstances delete instances -func (svr *ServerAuthAbility) DeleteInstances(ctx context.Context, +func (svr *Server) DeleteInstances(ctx context.Context, reqs []*apiservice.Instance) *apiservice.BatchWriteResponse { - authCtx := svr.collectInstanceAuthContext(ctx, reqs, model.Delete, "DeleteInstances") + authCtx := svr.collectInstanceAuthContext(ctx, reqs, authcommon.Delete, authcommon.DeleteInstances) - _, err := svr.policyMgr.GetAuthChecker().CheckConsolePermission(authCtx) + _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx) if err != nil { - resp := api.NewResponseWithMsg(convertToErrCode(err), err.Error()) + resp := api.NewResponseWithMsg(authcommon.ConvertToErrCode(err), err.Error()) batchResp := api.NewBatchWriteResponse(apimodel.Code_ExecuteSuccess) api.Collect(batchResp, resp) return batchResp @@ -68,16 +67,16 @@ func (svr *ServerAuthAbility) DeleteInstances(ctx context.Context, } // DeleteInstancesByHost 目前只允许 super account 进行数据删除 -func (svr *ServerAuthAbility) DeleteInstancesByHost(ctx context.Context, +func (svr *Server) DeleteInstancesByHost(ctx context.Context, reqs []*apiservice.Instance) *apiservice.BatchWriteResponse { - authCtx := svr.collectInstanceAuthContext(ctx, reqs, model.Delete, "DeleteInstancesByHost") + authCtx := svr.collectInstanceAuthContext(ctx, reqs, authcommon.Delete, authcommon.DeleteInstancesByHost) - if _, err := svr.policyMgr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewBatchWriteResponse(convertToErrCode(err)) + if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + return api.NewBatchWriteResponse(authcommon.ConvertToErrCode(err)) } ctx = authCtx.GetRequestContext() ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) - if authcommon.ParseUserRole(ctx) == model.SubAccountUserRole { + if authcommon.ParseUserRole(ctx) == authcommon.SubAccountUserRole { ret := api.NewBatchWriteResponse(apimodel.Code_ExecuteSuccess) api.Collect(ret, api.NewResponse(apimodel.Code_NotAllowedAccess)) return ret @@ -87,13 +86,13 @@ func (svr *ServerAuthAbility) DeleteInstancesByHost(ctx context.Context, } // UpdateInstances update instances -func (svr *ServerAuthAbility) UpdateInstances(ctx context.Context, +func (svr *Server) UpdateInstances(ctx context.Context, reqs []*apiservice.Instance) *apiservice.BatchWriteResponse { - authCtx := svr.collectInstanceAuthContext(ctx, reqs, model.Modify, "UpdateInstances") + authCtx := svr.collectInstanceAuthContext(ctx, reqs, authcommon.Modify, authcommon.UpdateInstances) - _, err := svr.policyMgr.GetAuthChecker().CheckConsolePermission(authCtx) + _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx) if err != nil { - return api.NewBatchWriteResponseWithMsg(convertToErrCode(err), err.Error()) + return api.NewBatchWriteResponseWithMsg(authcommon.ConvertToErrCode(err), err.Error()) } ctx = authCtx.GetRequestContext() @@ -103,13 +102,13 @@ func (svr *ServerAuthAbility) UpdateInstances(ctx context.Context, } // UpdateInstancesIsolate update instances -func (svr *ServerAuthAbility) UpdateInstancesIsolate(ctx context.Context, +func (svr *Server) UpdateInstancesIsolate(ctx context.Context, reqs []*apiservice.Instance) *apiservice.BatchWriteResponse { - authCtx := svr.collectInstanceAuthContext(ctx, reqs, model.Modify, "UpdateInstancesIsolate") + authCtx := svr.collectInstanceAuthContext(ctx, reqs, authcommon.Modify, authcommon.UpdateInstancesIsolate) - _, err := svr.policyMgr.GetAuthChecker().CheckConsolePermission(authCtx) + _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx) if err != nil { - return api.NewBatchWriteResponseWithMsg(convertToErrCode(err), err.Error()) + return api.NewBatchWriteResponseWithMsg(authcommon.ConvertToErrCode(err), err.Error()) } ctx = authCtx.GetRequestContext() @@ -119,12 +118,12 @@ func (svr *ServerAuthAbility) UpdateInstancesIsolate(ctx context.Context, } // GetInstances get instances -func (svr *ServerAuthAbility) GetInstances(ctx context.Context, +func (svr *Server) GetInstances(ctx context.Context, query map[string]string) *apiservice.BatchQueryResponse { - authCtx := svr.collectInstanceAuthContext(ctx, nil, model.Read, "GetInstances") - _, err := svr.policyMgr.GetAuthChecker().CheckConsolePermission(authCtx) + authCtx := svr.collectInstanceAuthContext(ctx, nil, authcommon.Read, authcommon.DescribeInstances) + _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx) if err != nil { - return api.NewBatchQueryResponseWithMsg(convertToErrCode(err), err.Error()) + return api.NewBatchQueryResponseWithMsg(authcommon.ConvertToErrCode(err), err.Error()) } ctx = authCtx.GetRequestContext() @@ -134,11 +133,11 @@ func (svr *ServerAuthAbility) GetInstances(ctx context.Context, } // GetInstancesCount get instances to count -func (svr *ServerAuthAbility) GetInstancesCount(ctx context.Context) *apiservice.BatchQueryResponse { - authCtx := svr.collectInstanceAuthContext(ctx, nil, model.Read, "GetInstancesCount") - _, err := svr.policyMgr.GetAuthChecker().CheckConsolePermission(authCtx) +func (svr *Server) GetInstancesCount(ctx context.Context) *apiservice.BatchQueryResponse { + authCtx := svr.collectInstanceAuthContext(ctx, nil, authcommon.Read, authcommon.DescribeInstancesCount) + _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx) if err != nil { - return api.NewBatchQueryResponseWithMsg(convertToErrCode(err), err.Error()) + return api.NewBatchQueryResponseWithMsg(authcommon.ConvertToErrCode(err), err.Error()) } ctx = authCtx.GetRequestContext() ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) @@ -146,13 +145,13 @@ func (svr *ServerAuthAbility) GetInstancesCount(ctx context.Context) *apiservice return svr.nextSvr.GetInstancesCount(ctx) } -func (svr *ServerAuthAbility) GetInstanceLabels(ctx context.Context, +func (svr *Server) GetInstanceLabels(ctx context.Context, query map[string]string) *apiservice.Response { - authCtx := svr.collectInstanceAuthContext(ctx, nil, model.Read, "GetInstanceLabels") - _, err := svr.policyMgr.GetAuthChecker().CheckConsolePermission(authCtx) + authCtx := svr.collectInstanceAuthContext(ctx, nil, authcommon.Read, authcommon.DescribeInstanceLabels) + _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx) if err != nil { - return api.NewResponseWithMsg(convertToErrCode(err), err.Error()) + return api.NewResponseWithMsg(authcommon.ConvertToErrCode(err), err.Error()) } ctx = authCtx.GetRequestContext() ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) diff --git a/service/interceptor/auth/l5_service_authability.go b/service/interceptor/auth/l5_service.go similarity index 85% rename from service/interceptor/auth/l5_service_authability.go rename to service/interceptor/auth/l5_service.go index 4b3eac1ee..bbef6ae1c 100644 --- a/service/interceptor/auth/l5_service_authability.go +++ b/service/interceptor/auth/l5_service.go @@ -28,12 +28,12 @@ import ( // Stat::instance()->inc_sync_req_cnt(); // 保存client的IP,该函数只是存储到本地的缓存中 // Stat::instance()->add_agent(sbac.agent_ip()); -func (svr *ServerAuthAbility) SyncByAgentCmd(ctx context.Context, sbac *l5.Cl5SyncByAgentCmd) ( +func (svr *Server) SyncByAgentCmd(ctx context.Context, sbac *l5.Cl5SyncByAgentCmd) ( *l5.Cl5SyncByAgentAckCmd, error) { return svr.nextSvr.SyncByAgentCmd(ctx, sbac) } // RegisterByNameCmd 根据名字获取sid信息 -func (svr *ServerAuthAbility) RegisterByNameCmd(rbnc *l5.Cl5RegisterByNameCmd) (*l5.Cl5RegisterByNameAckCmd, error) { +func (svr *Server) RegisterByNameCmd(rbnc *l5.Cl5RegisterByNameCmd) (*l5.Cl5RegisterByNameAckCmd, error) { return svr.nextSvr.RegisterByNameCmd(rbnc) } diff --git a/service/interceptor/auth/ratelimit_config_authability.go b/service/interceptor/auth/ratelimit_config.go similarity index 67% rename from service/interceptor/auth/ratelimit_config_authability.go rename to service/interceptor/auth/ratelimit_config.go index 7f2a30487..f41a09cac 100644 --- a/service/interceptor/auth/ratelimit_config_authability.go +++ b/service/interceptor/auth/ratelimit_config.go @@ -21,20 +21,23 @@ import ( "context" apimodel "github.com/polarismesh/specification/source/go/api/v1/model" + "github.com/polarismesh/specification/source/go/api/v1/security" apiservice "github.com/polarismesh/specification/source/go/api/v1/service_manage" apitraffic "github.com/polarismesh/specification/source/go/api/v1/traffic_manage" + cachetypes "github.com/polarismesh/polaris/cache/api" api "github.com/polarismesh/polaris/common/api/v1" "github.com/polarismesh/polaris/common/model" + authcommon "github.com/polarismesh/polaris/common/model/auth" "github.com/polarismesh/polaris/common/utils" ) // CreateRateLimits creates rate limits for a namespace. -func (svr *ServerAuthAbility) CreateRateLimits( +func (svr *Server) CreateRateLimits( ctx context.Context, reqs []*apitraffic.Rule) *apiservice.BatchWriteResponse { - authCtx := svr.collectRateLimitAuthContext(ctx, reqs, model.Create, "CreateRateLimits") + authCtx := svr.collectRateLimitAuthContext(ctx, reqs, authcommon.Create, authcommon.CreateRateLimitRules) - _, err := svr.policyMgr.GetAuthChecker().CheckConsolePermission(authCtx) + _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx) if err != nil { return api.NewBatchWriteResponseWithMsg(apimodel.Code_NotAllowedAccess, err.Error()) } @@ -46,11 +49,11 @@ func (svr *ServerAuthAbility) CreateRateLimits( } // DeleteRateLimits deletes rate limits for a namespace. -func (svr *ServerAuthAbility) DeleteRateLimits( +func (svr *Server) DeleteRateLimits( ctx context.Context, reqs []*apitraffic.Rule) *apiservice.BatchWriteResponse { - authCtx := svr.collectRateLimitAuthContext(ctx, reqs, model.Delete, "DeleteRateLimits") + authCtx := svr.collectRateLimitAuthContext(ctx, reqs, authcommon.Delete, authcommon.DeleteRateLimitRules) - _, err := svr.policyMgr.GetAuthChecker().CheckConsolePermission(authCtx) + _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx) if err != nil { return api.NewBatchWriteResponseWithMsg(apimodel.Code_NotAllowedAccess, err.Error()) } @@ -62,11 +65,11 @@ func (svr *ServerAuthAbility) DeleteRateLimits( } // UpdateRateLimits updates rate limits for a namespace. -func (svr *ServerAuthAbility) UpdateRateLimits( +func (svr *Server) UpdateRateLimits( ctx context.Context, reqs []*apitraffic.Rule) *apiservice.BatchWriteResponse { - authCtx := svr.collectRateLimitAuthContext(ctx, reqs, model.Modify, "UpdateRateLimits") + authCtx := svr.collectRateLimitAuthContext(ctx, reqs, authcommon.Modify, authcommon.UpdateRateLimitRules) - _, err := svr.policyMgr.GetAuthChecker().CheckConsolePermission(authCtx) + _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx) if err != nil { return api.NewBatchWriteResponseWithMsg(apimodel.Code_NotAllowedAccess, err.Error()) } @@ -78,11 +81,11 @@ func (svr *ServerAuthAbility) UpdateRateLimits( } // EnableRateLimits 启用限流规则 -func (svr *ServerAuthAbility) EnableRateLimits( +func (svr *Server) EnableRateLimits( ctx context.Context, reqs []*apitraffic.Rule) *apiservice.BatchWriteResponse { - authCtx := svr.collectRateLimitAuthContext(ctx, nil, model.Read, "EnableRateLimits") + authCtx := svr.collectRateLimitAuthContext(ctx, nil, authcommon.Read, authcommon.EnableRateLimitRules) - _, err := svr.policyMgr.GetAuthChecker().CheckConsolePermission(authCtx) + _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx) if err != nil { return api.NewBatchWriteResponseWithMsg(apimodel.Code_NotAllowedAccess, err.Error()) } @@ -94,11 +97,11 @@ func (svr *ServerAuthAbility) EnableRateLimits( } // GetRateLimits gets rate limits for a namespace. -func (svr *ServerAuthAbility) GetRateLimits( +func (svr *Server) GetRateLimits( ctx context.Context, query map[string]string) *apiservice.BatchQueryResponse { - authCtx := svr.collectRateLimitAuthContext(ctx, nil, model.Read, "GetRateLimits") + authCtx := svr.collectRateLimitAuthContext(ctx, nil, authcommon.Read, authcommon.DescribeRateLimitRules) - _, err := svr.policyMgr.GetAuthChecker().CheckConsolePermission(authCtx) + _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx) if err != nil { return api.NewBatchQueryResponseWithMsg(apimodel.Code_NotAllowedAccess, err.Error()) } @@ -106,5 +109,13 @@ func (svr *ServerAuthAbility) GetRateLimits( ctx = authCtx.GetRequestContext() ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) + cachetypes.AppendRatelimitRulePredicate(ctx, func(ctx context.Context, cbr *model.RateLimit) bool { + return svr.policySvr.GetAuthChecker().ResourcePredicate(authCtx, &authcommon.ResourceEntry{ + Type: security.ResourceType_RateLimitRules, + ID: cbr.ID, + Metadata: cbr.Proto.Metadata, + }) + }) + return svr.nextSvr.GetRateLimits(ctx, query) } diff --git a/service/interceptor/auth/resource_listen.go b/service/interceptor/auth/resource_listen.go index 1552dc430..080b364ff 100644 --- a/service/interceptor/auth/resource_listen.go +++ b/service/interceptor/auth/resource_listen.go @@ -23,17 +23,18 @@ import ( apisecurity "github.com/polarismesh/specification/source/go/api/v1/security" "github.com/polarismesh/polaris/common/model" + authcommon "github.com/polarismesh/polaris/common/model/auth" "github.com/polarismesh/polaris/common/utils" "github.com/polarismesh/polaris/service" ) // Before this function is called before the resource operation -func (svr *ServerAuthAbility) Before(ctx context.Context, resourceType model.Resource) { +func (svr *Server) Before(ctx context.Context, resourceType model.Resource) { // do nothing } // After this function is called after the resource operation -func (svr *ServerAuthAbility) After(ctx context.Context, resourceType model.Resource, res *service.ResourceEvent) error { +func (svr *Server) After(ctx context.Context, resourceType model.Resource, res *service.ResourceEvent) error { switch resourceType { case model.RService: return svr.onServiceResource(ctx, res) @@ -43,15 +44,16 @@ func (svr *ServerAuthAbility) After(ctx context.Context, resourceType model.Reso } // onServiceResource 服务资源的处理,只处理服务,namespace 只由 namespace 相关的进行处理, -func (svr *ServerAuthAbility) onServiceResource(ctx context.Context, res *service.ResourceEvent) error { - authCtx := ctx.Value(utils.ContextAuthContextKey).(*model.AcquireContext) +func (svr *Server) onServiceResource(ctx context.Context, res *service.ResourceEvent) error { + authCtx := ctx.Value(utils.ContextAuthContextKey).(*authcommon.AcquireContext) ownerId := utils.ParseOwnerID(ctx) - authCtx.SetAttachment(model.ResourceAttachmentKey, map[apisecurity.ResourceType][]model.ResourceEntry{ + authCtx.SetAttachment(authcommon.ResourceAttachmentKey, map[apisecurity.ResourceType][]authcommon.ResourceEntry{ apisecurity.ResourceType_Services: { { - ID: res.Service.ID, - Owner: ownerId, + ID: res.Service.ID, + Owner: ownerId, + Metadata: res.Service.Meta, }, }, }) @@ -62,11 +64,11 @@ func (svr *ServerAuthAbility) onServiceResource(ctx context.Context, res *servic groups := utils.ConvertStringValuesToSlice(res.ReqService.GroupIds) removeGroups := utils.ConvertStringValuesToSlice(res.ReqService.RemoveGroupIds) - authCtx.SetAttachment(model.LinkUsersKey, utils.StringSliceDeDuplication(users)) - authCtx.SetAttachment(model.RemoveLinkUsersKey, utils.StringSliceDeDuplication(removeUses)) + authCtx.SetAttachment(authcommon.LinkUsersKey, utils.StringSliceDeDuplication(users)) + authCtx.SetAttachment(authcommon.RemoveLinkUsersKey, utils.StringSliceDeDuplication(removeUses)) - authCtx.SetAttachment(model.LinkGroupsKey, utils.StringSliceDeDuplication(groups)) - authCtx.SetAttachment(model.RemoveLinkGroupsKey, utils.StringSliceDeDuplication(removeGroups)) + authCtx.SetAttachment(authcommon.LinkGroupsKey, utils.StringSliceDeDuplication(groups)) + authCtx.SetAttachment(authcommon.RemoveLinkGroupsKey, utils.StringSliceDeDuplication(removeGroups)) - return svr.policyMgr.AfterResourceOperation(authCtx) + return svr.policySvr.AfterResourceOperation(authCtx) } diff --git a/service/interceptor/auth/routing_config_v1_authability.go b/service/interceptor/auth/routing_config_v1.go similarity index 75% rename from service/interceptor/auth/routing_config_v1_authability.go rename to service/interceptor/auth/routing_config_v1.go index ae79d9051..32066dbbe 100644 --- a/service/interceptor/auth/routing_config_v1_authability.go +++ b/service/interceptor/auth/routing_config_v1.go @@ -25,16 +25,16 @@ import ( apitraffic "github.com/polarismesh/specification/source/go/api/v1/traffic_manage" api "github.com/polarismesh/polaris/common/api/v1" - "github.com/polarismesh/polaris/common/model" + authcommon "github.com/polarismesh/polaris/common/model/auth" "github.com/polarismesh/polaris/common/utils" ) // CreateRoutingConfigs creates routing configs -func (svr *ServerAuthAbility) CreateRoutingConfigs( +func (svr *Server) CreateRoutingConfigs( ctx context.Context, reqs []*apitraffic.Routing) *apiservice.BatchWriteResponse { - authCtx := svr.collectRouteRuleAuthContext(ctx, reqs, model.Create, "CreateRoutingConfigs") + authCtx := svr.collectRouteRuleAuthContext(ctx, reqs, authcommon.Create, "CreateRoutingConfigs") - _, err := svr.policyMgr.GetAuthChecker().CheckConsolePermission(authCtx) + _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx) if err != nil { return api.NewBatchWriteResponseWithMsg(apimodel.Code_NotAllowedAccess, err.Error()) } @@ -46,11 +46,11 @@ func (svr *ServerAuthAbility) CreateRoutingConfigs( } // DeleteRoutingConfigs deletes routing configs -func (svr *ServerAuthAbility) DeleteRoutingConfigs( +func (svr *Server) DeleteRoutingConfigs( ctx context.Context, reqs []*apitraffic.Routing) *apiservice.BatchWriteResponse { - authCtx := svr.collectRouteRuleAuthContext(ctx, reqs, model.Delete, "DeleteRoutingConfigs") + authCtx := svr.collectRouteRuleAuthContext(ctx, reqs, authcommon.Delete, "DeleteRoutingConfigs") - _, err := svr.policyMgr.GetAuthChecker().CheckConsolePermission(authCtx) + _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx) if err != nil { return api.NewBatchWriteResponseWithMsg(apimodel.Code_NotAllowedAccess, err.Error()) } @@ -62,11 +62,11 @@ func (svr *ServerAuthAbility) DeleteRoutingConfigs( } // UpdateRoutingConfigs updates routing configs -func (svr *ServerAuthAbility) UpdateRoutingConfigs( +func (svr *Server) UpdateRoutingConfigs( ctx context.Context, reqs []*apitraffic.Routing) *apiservice.BatchWriteResponse { - authCtx := svr.collectRouteRuleAuthContext(ctx, reqs, model.Modify, "UpdateRoutingConfigs") + authCtx := svr.collectRouteRuleAuthContext(ctx, reqs, authcommon.Modify, "UpdateRoutingConfigs") - _, err := svr.policyMgr.GetAuthChecker().CheckConsolePermission(authCtx) + _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx) if err != nil { return api.NewBatchWriteResponseWithMsg(apimodel.Code_NotAllowedAccess, err.Error()) } @@ -78,11 +78,11 @@ func (svr *ServerAuthAbility) UpdateRoutingConfigs( } // GetRoutingConfigs gets routing configs -func (svr *ServerAuthAbility) GetRoutingConfigs( +func (svr *Server) GetRoutingConfigs( ctx context.Context, query map[string]string) *apiservice.BatchQueryResponse { - authCtx := svr.collectRouteRuleAuthContext(ctx, nil, model.Read, "GetRoutingConfigs") + authCtx := svr.collectRouteRuleAuthContext(ctx, nil, authcommon.Read, "GetRoutingConfigs") - _, err := svr.policyMgr.GetAuthChecker().CheckConsolePermission(authCtx) + _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx) if err != nil { return api.NewBatchQueryResponseWithMsg(apimodel.Code_NotAllowedAccess, err.Error()) } diff --git a/service/interceptor/auth/routing_config_v2_authability.go b/service/interceptor/auth/routing_config_v2.go similarity index 54% rename from service/interceptor/auth/routing_config_v2_authability.go rename to service/interceptor/auth/routing_config_v2.go index 66470627f..3ed128dad 100644 --- a/service/interceptor/auth/routing_config_v2_authability.go +++ b/service/interceptor/auth/routing_config_v2.go @@ -20,22 +20,25 @@ package service_auth import ( "context" + "github.com/polarismesh/specification/source/go/api/v1/security" apiservice "github.com/polarismesh/specification/source/go/api/v1/service_manage" apitraffic "github.com/polarismesh/specification/source/go/api/v1/traffic_manage" + cachetypes "github.com/polarismesh/polaris/cache/api" api "github.com/polarismesh/polaris/common/api/v1" "github.com/polarismesh/polaris/common/model" + authcommon "github.com/polarismesh/polaris/common/model/auth" "github.com/polarismesh/polaris/common/utils" ) // CreateRoutingConfigsV2 批量创建路由配置 -func (svr *ServerAuthAbility) CreateRoutingConfigsV2(ctx context.Context, +func (svr *Server) CreateRoutingConfigsV2(ctx context.Context, req []*apitraffic.RouteRule) *apiservice.BatchWriteResponse { // TODO not support RouteRuleV2 resource auth, so we set op is read - authCtx := svr.collectRouteRuleV2AuthContext(ctx, req, model.Read, "CreateRoutingConfigsV2") - if _, err := svr.policyMgr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewBatchWriteResponse(convertToErrCode(err)) + authCtx := svr.collectRouteRuleV2AuthContext(ctx, req, authcommon.Read, authcommon.CreateRouteRules) + if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + return api.NewBatchWriteResponse(authcommon.ConvertToErrCode(err)) } ctx = authCtx.GetRequestContext() ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) @@ -43,12 +46,12 @@ func (svr *ServerAuthAbility) CreateRoutingConfigsV2(ctx context.Context, } // DeleteRoutingConfigsV2 批量删除路由配置 -func (svr *ServerAuthAbility) DeleteRoutingConfigsV2(ctx context.Context, +func (svr *Server) DeleteRoutingConfigsV2(ctx context.Context, req []*apitraffic.RouteRule) *apiservice.BatchWriteResponse { - authCtx := svr.collectRouteRuleV2AuthContext(ctx, req, model.Read, "DeleteRoutingConfigsV2") - if _, err := svr.policyMgr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewBatchWriteResponse(convertToErrCode(err)) + authCtx := svr.collectRouteRuleV2AuthContext(ctx, req, authcommon.Read, authcommon.DeleteRouteRules) + if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + return api.NewBatchWriteResponse(authcommon.ConvertToErrCode(err)) } ctx = authCtx.GetRequestContext() ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) @@ -56,12 +59,12 @@ func (svr *ServerAuthAbility) DeleteRoutingConfigsV2(ctx context.Context, } // UpdateRoutingConfigsV2 批量更新路由配置 -func (svr *ServerAuthAbility) UpdateRoutingConfigsV2(ctx context.Context, +func (svr *Server) UpdateRoutingConfigsV2(ctx context.Context, req []*apitraffic.RouteRule) *apiservice.BatchWriteResponse { - authCtx := svr.collectRouteRuleV2AuthContext(ctx, req, model.Read, "UpdateRoutingConfigsV2") - if _, err := svr.policyMgr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewBatchWriteResponse(convertToErrCode(err)) + authCtx := svr.collectRouteRuleV2AuthContext(ctx, req, authcommon.Read, authcommon.UpdateRouteRules) + if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + return api.NewBatchWriteResponse(authcommon.ConvertToErrCode(err)) } ctx = authCtx.GetRequestContext() ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) @@ -69,12 +72,12 @@ func (svr *ServerAuthAbility) UpdateRoutingConfigsV2(ctx context.Context, } // EnableRoutings batch enable routing rules -func (svr *ServerAuthAbility) EnableRoutings(ctx context.Context, +func (svr *Server) EnableRoutings(ctx context.Context, req []*apitraffic.RouteRule) *apiservice.BatchWriteResponse { - authCtx := svr.collectRouteRuleV2AuthContext(ctx, req, model.Read, "EnableRoutings") - if _, err := svr.policyMgr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewBatchWriteResponse(convertToErrCode(err)) + authCtx := svr.collectRouteRuleV2AuthContext(ctx, req, authcommon.Read, authcommon.EnableRouteRules) + if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + return api.NewBatchWriteResponse(authcommon.ConvertToErrCode(err)) } ctx = authCtx.GetRequestContext() ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) @@ -82,8 +85,22 @@ func (svr *ServerAuthAbility) EnableRoutings(ctx context.Context, } // QueryRoutingConfigsV2 提供给OSS的查询路由配置的接口 -func (svr *ServerAuthAbility) QueryRoutingConfigsV2(ctx context.Context, +func (svr *Server) QueryRoutingConfigsV2(ctx context.Context, query map[string]string) *apiservice.BatchQueryResponse { + authCtx := svr.collectRouteRuleV2AuthContext(ctx, nil, authcommon.Read, authcommon.DescribeRouteRules) + if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + return api.NewBatchQueryResponse(authcommon.ConvertToErrCode(err)) + } + ctx = authCtx.GetRequestContext() + ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) + + cachetypes.AppendRouterRulePredicate(ctx, func(ctx context.Context, cbr *model.ExtendRouterConfig) bool { + return svr.policySvr.GetAuthChecker().ResourcePredicate(authCtx, &authcommon.ResourceEntry{ + Type: security.ResourceType_RouteRules, + ID: cbr.ID, + Metadata: cbr.Metadata, + }) + }) return svr.nextSvr.QueryRoutingConfigsV2(ctx, query) } diff --git a/service/interceptor/auth/server.go b/service/interceptor/auth/server.go new file mode 100644 index 000000000..bb6859493 --- /dev/null +++ b/service/interceptor/auth/server.go @@ -0,0 +1,477 @@ +/** + * Tencent is pleased to support the open source community by making Polaris available. + * + * Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. + * + * Licensed under the BSD 3-Clause License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://opensource.org/licenses/BSD-3-Clause + * + * Unless required by applicable law or agreed to in writing, software distributed + * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR + * CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package service_auth + +import ( + "context" + + apifault "github.com/polarismesh/specification/source/go/api/v1/fault_tolerance" + apisecurity "github.com/polarismesh/specification/source/go/api/v1/security" + apiservice "github.com/polarismesh/specification/source/go/api/v1/service_manage" + apitraffic "github.com/polarismesh/specification/source/go/api/v1/traffic_manage" + "go.uber.org/zap" + + "github.com/polarismesh/polaris/auth" + cachetypes "github.com/polarismesh/polaris/cache/api" + "github.com/polarismesh/polaris/common/model" + authcommon "github.com/polarismesh/polaris/common/model/auth" + "github.com/polarismesh/polaris/common/utils" + "github.com/polarismesh/polaris/service" +) + +// Server 带有鉴权能力的 discoverServer +// +// 该层会对请求参数做一些调整,根据具体的请求发起人,设置为数据对应的 owner,不可为为别人进行创建资源 +type Server struct { + nextSvr service.DiscoverServer + userMgn auth.UserServer + policySvr auth.StrategyServer +} + +func NewServerAuthAbility(nextSvr service.DiscoverServer, + userMgn auth.UserServer, policySvr auth.StrategyServer) service.DiscoverServer { + proxy := &Server{ + nextSvr: nextSvr, + userMgn: userMgn, + policySvr: policySvr, + } + + actualSvr, ok := nextSvr.(*service.Server) + if ok { + actualSvr.SetResourceHooks(proxy) + } + return proxy +} + +// Cache Get cache management +func (svr *Server) Cache() cachetypes.CacheManager { + return svr.nextSvr.Cache() +} + +// GetServiceInstanceRevision 获取服务实例的版本号 +func (svr *Server) GetServiceInstanceRevision(serviceID string, + instances []*model.Instance) (string, error) { + return svr.nextSvr.GetServiceInstanceRevision(serviceID, instances) +} + +// collectServiceAuthContext 对于服务的处理,收集所有的与鉴权的相关信息 +// +// @receiver svr Server +// @param ctx 请求上下文 ctx +// @param req 实际请求对象 +// @param resourceOp 该接口的数据操作类型 +// @return *authcommon.AcquireContext 返回鉴权上下文 +func (svr *Server) collectServiceAuthContext(ctx context.Context, req []*apiservice.Service, + resourceOp authcommon.ResourceOperation, methodName authcommon.ServerFunctionName) *authcommon.AcquireContext { + return authcommon.NewAcquireContext( + authcommon.WithRequestContext(ctx), + authcommon.WithOperation(resourceOp), + authcommon.WithModule(authcommon.DiscoverModule), + authcommon.WithMethod(methodName), + authcommon.WithAccessResources(svr.queryServiceResource(req)), + ) +} + +// collectServiceAliasAuthContext 对于服务别名的处理,收集所有的与鉴权的相关信息 +// +// @receiver svr Server +// @param ctx 请求上下文 ctx +// @param req 实际请求对象 +// @param resourceOp 该接口的数据操作类型 +// @return *authcommon.AcquireContext 返回鉴权上下文 +func (svr *Server) collectServiceAliasAuthContext(ctx context.Context, req []*apiservice.ServiceAlias, + resourceOp authcommon.ResourceOperation, methodName authcommon.ServerFunctionName) *authcommon.AcquireContext { + return authcommon.NewAcquireContext( + authcommon.WithRequestContext(ctx), + authcommon.WithOperation(resourceOp), + authcommon.WithModule(authcommon.DiscoverModule), + authcommon.WithMethod(methodName), + authcommon.WithAccessResources(svr.queryServiceAliasResource(req)), + ) +} + +// collectInstanceAuthContext 对于服务实例的处理,收集所有的与鉴权的相关信息 +// +// @receiver svr Server +// @param ctx 请求上下文 ctx +// @param req 实际请求对象 +// @param resourceOp 该接口的数据操作类型 +// @return *authcommon.AcquireContext 返回鉴权上下文 +func (svr *Server) collectInstanceAuthContext(ctx context.Context, req []*apiservice.Instance, + resourceOp authcommon.ResourceOperation, methodName authcommon.ServerFunctionName) *authcommon.AcquireContext { + return authcommon.NewAcquireContext( + authcommon.WithRequestContext(ctx), + authcommon.WithOperation(resourceOp), + authcommon.WithModule(authcommon.DiscoverModule), + authcommon.WithMethod(methodName), + authcommon.WithAccessResources(svr.queryInstanceResource(req)), + ) +} + +// collectClientInstanceAuthContext 对于服务实例的处理,收集所有的与鉴权的相关信息 +func (svr *Server) collectClientInstanceAuthContext(ctx context.Context, req []*apiservice.Instance, + resourceOp authcommon.ResourceOperation, methodName authcommon.ServerFunctionName) *authcommon.AcquireContext { + return authcommon.NewAcquireContext( + authcommon.WithRequestContext(ctx), + authcommon.WithOperation(resourceOp), + authcommon.WithModule(authcommon.DiscoverModule), + authcommon.WithMethod(methodName), + authcommon.WithFromClient(), + authcommon.WithAccessResources(svr.queryInstanceResource(req)), + ) +} + +// collectRouteRuleAuthContext 对于服务路由规则的处理,收集所有的与鉴权的相关信息 +// +// @receiver svr Server +// @param ctx 请求上下文 ctx +// @param req 实际请求对象 +// @param resourceOp 该接口的数据操作类型 +// @return *authcommon.AcquireContext 返回鉴权上下文 +func (svr *Server) collectRouteRuleAuthContext(ctx context.Context, req []*apitraffic.Routing, + resourceOp authcommon.ResourceOperation, methodName authcommon.ServerFunctionName) *authcommon.AcquireContext { + return authcommon.NewAcquireContext( + authcommon.WithRequestContext(ctx), + authcommon.WithOperation(resourceOp), + authcommon.WithModule(authcommon.DiscoverModule), + authcommon.WithMethod(methodName), + authcommon.WithAccessResources(svr.queryRouteRuleResource(req)), + ) +} + +// collectRateLimitAuthContext 对于服务限流规则的处理,收集所有的与鉴权的相关信息 +// +// @receiver svr Server +// @param ctx 请求上下文 ctx +// @param req 实际请求对象 +// @param resourceOp 该接口的数据操作类型 +// @return *authcommon.AcquireContext 返回鉴权上下文 +func (svr *Server) collectRateLimitAuthContext(ctx context.Context, req []*apitraffic.Rule, + resourceOp authcommon.ResourceOperation, methodName authcommon.ServerFunctionName) *authcommon.AcquireContext { + + resources := make([]authcommon.ResourceEntry, 0, len(req)) + for i := range req { + saveRule := svr.Cache().RateLimit().GetRule(req[i].GetId().GetValue()) + if saveRule != nil { + resources = append(resources, authcommon.ResourceEntry{ + Type: apisecurity.ResourceType_RouteRules, + ID: saveRule.ID, + Metadata: saveRule.Proto.Metadata, + }) + } + } + + return authcommon.NewAcquireContext( + authcommon.WithRequestContext(ctx), + authcommon.WithOperation(resourceOp), + authcommon.WithModule(authcommon.DiscoverModule), + authcommon.WithMethod(methodName), + authcommon.WithAccessResources(map[apisecurity.ResourceType][]authcommon.ResourceEntry{ + apisecurity.ResourceType_RateLimitRules: resources, + }), + ) +} + +// collectRouteRuleV2AuthContext 收集路由v2规则 +func (svr *Server) collectRouteRuleV2AuthContext(ctx context.Context, req []*apitraffic.RouteRule, + resourceOp authcommon.ResourceOperation, methodName authcommon.ServerFunctionName) *authcommon.AcquireContext { + + resources := make([]authcommon.ResourceEntry, 0, len(req)) + for i := range req { + saveRule := svr.Cache().RoutingConfig().GetRule(req[i].GetId()) + if saveRule != nil { + resources = append(resources, authcommon.ResourceEntry{ + Type: apisecurity.ResourceType_RouteRules, + ID: saveRule.ID, + Metadata: saveRule.Metadata, + }) + } + } + + return authcommon.NewAcquireContext( + authcommon.WithRequestContext(ctx), + authcommon.WithOperation(resourceOp), + authcommon.WithModule(authcommon.DiscoverModule), + authcommon.WithMethod(methodName), + authcommon.WithAccessResources(map[apisecurity.ResourceType][]authcommon.ResourceEntry{ + apisecurity.ResourceType_RouteRules: resources, + }), + ) +} + +// collectCircuitBreakerRuleV2AuthContext 收集熔断v2规则 +func (svr *Server) collectCircuitBreakerRuleV2AuthContext(ctx context.Context, + req []*apifault.CircuitBreakerRule, resourceOp authcommon.ResourceOperation, methodName authcommon.ServerFunctionName) *authcommon.AcquireContext { + + resources := make([]authcommon.ResourceEntry, 0, len(req)) + for i := range req { + saveRule := svr.Cache().CircuitBreaker().GetRule(req[i].GetId()) + if saveRule != nil { + resources = append(resources, authcommon.ResourceEntry{ + Type: apisecurity.ResourceType_CircuitBreakerRules, + ID: saveRule.ID, + Metadata: saveRule.Proto.Metadata, + }) + } + } + + return authcommon.NewAcquireContext( + authcommon.WithRequestContext(ctx), + authcommon.WithOperation(resourceOp), + authcommon.WithModule(authcommon.DiscoverModule), + authcommon.WithMethod(methodName), + authcommon.WithAccessResources(map[apisecurity.ResourceType][]authcommon.ResourceEntry{ + apisecurity.ResourceType_CircuitBreakerRules: resources, + }), + ) +} + +// collectFaultDetectAuthContext 收集主动探测规则 +func (svr *Server) collectFaultDetectAuthContext(ctx context.Context, + req []*apifault.FaultDetectRule, resourceOp authcommon.ResourceOperation, methodName authcommon.ServerFunctionName) *authcommon.AcquireContext { + + resources := make([]authcommon.ResourceEntry, 0, len(req)) + for i := range req { + saveRule := svr.Cache().FaultDetector().GetRule(req[i].GetId()) + if saveRule != nil { + resources = append(resources, authcommon.ResourceEntry{ + Type: apisecurity.ResourceType_FaultDetectRules, + ID: saveRule.ID, + Metadata: saveRule.Proto.Metadata, + }) + } + } + + return authcommon.NewAcquireContext( + authcommon.WithRequestContext(ctx), + authcommon.WithOperation(resourceOp), + authcommon.WithModule(authcommon.DiscoverModule), + authcommon.WithMethod(methodName), + authcommon.WithAccessResources(map[apisecurity.ResourceType][]authcommon.ResourceEntry{ + apisecurity.ResourceType_FaultDetectRules: resources, + }), + ) +} + +// queryServiceResource 根据所给的 service 信息,收集对应的 ResourceEntry 列表 +func (svr *Server) queryServiceResource( + req []*apiservice.Service) map[apisecurity.ResourceType][]authcommon.ResourceEntry { + if len(req) == 0 { + return make(map[apisecurity.ResourceType][]authcommon.ResourceEntry) + } + + names := utils.NewSet[string]() + svcSet := utils.NewMap[string, *model.Service]() + + for index := range req { + svcName := req[index].GetName().GetValue() + svcNamespace := req[index].GetNamespace().GetValue() + names.Add(svcNamespace) + svc := svr.Cache().Service().GetServiceByName(svcName, svcNamespace) + if svc != nil { + svcSet.Store(svc.ID, svc) + } + } + + ret := svr.convertToDiscoverResourceEntryMaps(names, svcSet) + if authLog.DebugEnabled() { + authLog.Debug("[Auth][Server] collect service access res", zap.Any("res", ret)) + } + return ret +} + +// queryServiceAliasResource 根据所给的 servicealias 信息,收集对应的 ResourceEntry 列表 +func (svr *Server) queryServiceAliasResource( + req []*apiservice.ServiceAlias) map[apisecurity.ResourceType][]authcommon.ResourceEntry { + if len(req) == 0 { + return make(map[apisecurity.ResourceType][]authcommon.ResourceEntry) + } + + names := utils.NewSet[string]() + svcSet := utils.NewMap[string, *model.Service]() + + for index := range req { + refSvcName := req[index].GetService().GetValue() + refSvcNamespace := req[index].GetNamespace().GetValue() + svcNamespace := req[index].GetNamespace().GetValue() + names.Add(svcNamespace) + refSvc := svr.Cache().Service().GetServiceByName(refSvcName, refSvcNamespace) + if refSvc != nil { + svcSet.Store(refSvc.ID, refSvc) + } + } + + ret := svr.convertToDiscoverResourceEntryMaps(names, svcSet) + if authLog.DebugEnabled() { + authLog.Debug("[Auth][Server] collect service alias access res", zap.Any("res", ret)) + } + return ret +} + +// queryInstanceResource 根据所给的 instances 信息,收集对应的 ResourceEntry 列表 +// 由于实例是注册到服务下的,因此只需要判断,是否有对应服务的权限即可 +func (svr *Server) queryInstanceResource( + req []*apiservice.Instance) map[apisecurity.ResourceType][]authcommon.ResourceEntry { + if len(req) == 0 { + return make(map[apisecurity.ResourceType][]authcommon.ResourceEntry) + } + + names := utils.NewSet[string]() + svcSet := utils.NewMap[string, *model.Service]() + + for index := range req { + svcName := req[index].GetService().GetValue() + svcNamespace := req[index].GetNamespace().GetValue() + item := req[index] + if svcNamespace != "" && svcName != "" { + svc := svr.Cache().Service().GetServiceByName(svcName, svcNamespace) + if svc != nil { + svcSet.Store(svc.ID, svc) + } else { + names.Add(svcNamespace) + } + } else { + ins := svr.Cache().Instance().GetInstance(item.GetId().GetValue()) + if ins != nil { + svc := svr.Cache().Service().GetServiceByID(ins.ServiceID) + if svc != nil { + svcSet.Store(svc.ID, svc) + } else { + names.Add(svcNamespace) + } + } + } + } + + ret := svr.convertToDiscoverResourceEntryMaps(names, svcSet) + if authLog.DebugEnabled() { + authLog.Debug("[Auth][Server] collect instance access res", zap.Any("res", ret)) + } + return ret +} + +// queryCircuitBreakerResource 根据所给的 CircuitBreaker 信息,收集对应的 ResourceEntry 列表 +func (svr *Server) queryCircuitBreakerResource( + req []*apifault.CircuitBreaker) map[apisecurity.ResourceType][]authcommon.ResourceEntry { + if len(req) == 0 { + return make(map[apisecurity.ResourceType][]authcommon.ResourceEntry) + } + + names := utils.NewSet[string]() + svcSet := utils.NewMap[string, *model.Service]() + + for index := range req { + svcName := req[index].GetService().GetValue() + svcNamespace := req[index].GetNamespace().GetValue() + svc := svr.Cache().Service().GetServiceByName(svcName, svcNamespace) + if svc != nil { + svcSet.Store(svc.ID, svc) + } + } + ret := svr.convertToDiscoverResourceEntryMaps(names, svcSet) + if authLog.DebugEnabled() { + authLog.Debug("[Auth][Server] collect circuit-breaker access res", zap.Any("res", ret)) + } + return ret +} + +// queryRouteRuleResource 根据所给的 RouteRule 信息,收集对应的 ResourceEntry 列表 +func (svr *Server) queryRouteRuleResource( + req []*apitraffic.Routing) map[apisecurity.ResourceType][]authcommon.ResourceEntry { + if len(req) == 0 { + return make(map[apisecurity.ResourceType][]authcommon.ResourceEntry) + } + + names := utils.NewSet[string]() + svcSet := utils.NewMap[string, *model.Service]() + + for index := range req { + svcName := req[index].GetService().GetValue() + svcNamespace := req[index].GetNamespace().GetValue() + svc := svr.Cache().Service().GetServiceByName(svcName, svcNamespace) + if svc != nil { + svcSet.Store(svc.ID, svc) + } + } + + ret := svr.convertToDiscoverResourceEntryMaps(names, svcSet) + if authLog.DebugEnabled() { + authLog.Debug("[Auth][Server] collect route-rule access res", zap.Any("res", ret)) + } + return ret +} + +// queryRateLimitConfigResource 根据所给的 RateLimit 信息,收集对应的 ResourceEntry 列表 +func (svr *Server) queryRateLimitConfigResource( + req []*apitraffic.Rule) map[apisecurity.ResourceType][]authcommon.ResourceEntry { + if len(req) == 0 { + return make(map[apisecurity.ResourceType][]authcommon.ResourceEntry) + } + + names := utils.NewSet[string]() + svcSet := utils.NewMap[string, *model.Service]() + + for index := range req { + svcName := req[index].GetService().GetValue() + svcNamespace := req[index].GetNamespace().GetValue() + svc := svr.Cache().Service().GetServiceByName(svcName, svcNamespace) + if svc != nil { + svcSet.Store(svc.ID, svc) + } + } + + ret := svr.convertToDiscoverResourceEntryMaps(names, svcSet) + if authLog.DebugEnabled() { + authLog.Debug("[Auth][Server] collect rate-limit access res", zap.Any("res", ret)) + } + return ret +} + +// convertToDiscoverResourceEntryMaps 通用方法,进行转换为期望的、服务相关的 ResourceEntry +func (svr *Server) convertToDiscoverResourceEntryMaps(nsSet *utils.Set[string], + svcSet *utils.Map[string, *model.Service]) map[apisecurity.ResourceType][]authcommon.ResourceEntry { + var ( + param = nsSet.ToSlice() + nsArr = svr.Cache().Namespace().GetNamespacesByName(param) + nsRet = make([]authcommon.ResourceEntry, 0, len(nsArr)) + ) + for index := range nsArr { + ns := nsArr[index] + nsRet = append(nsRet, authcommon.ResourceEntry{ + Type: apisecurity.ResourceType_Namespaces, + ID: ns.Name, + Owner: ns.Owner, + Metadata: ns.Metadata, + }) + } + + svcRet := make([]authcommon.ResourceEntry, 0, svcSet.Len()) + svcSet.Range(func(key string, svc *model.Service) { + svcRet = append(svcRet, authcommon.ResourceEntry{ + Type: apisecurity.ResourceType_Services, + ID: svc.ID, + Owner: svc.Owner, + Metadata: svc.Meta, + }) + }) + + return map[apisecurity.ResourceType][]authcommon.ResourceEntry{ + apisecurity.ResourceType_Namespaces: nsRet, + apisecurity.ResourceType_Services: svcRet, + } +} diff --git a/service/interceptor/auth/server_authability.go b/service/interceptor/auth/server_authability.go deleted file mode 100644 index b70909509..000000000 --- a/service/interceptor/auth/server_authability.go +++ /dev/null @@ -1,488 +0,0 @@ -/** - * Tencent is pleased to support the open source community by making Polaris available. - * - * Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. - * - * Licensed under the BSD 3-Clause License (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://opensource.org/licenses/BSD-3-Clause - * - * Unless required by applicable law or agreed to in writing, software distributed - * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR - * CONDITIONS OF ANY KIND, either express or implied. See the License for the - * specific language governing permissions and limitations under the License. - */ - -package service_auth - -import ( - "context" - "errors" - - apifault "github.com/polarismesh/specification/source/go/api/v1/fault_tolerance" - apimodel "github.com/polarismesh/specification/source/go/api/v1/model" - apisecurity "github.com/polarismesh/specification/source/go/api/v1/security" - apiservice "github.com/polarismesh/specification/source/go/api/v1/service_manage" - apitraffic "github.com/polarismesh/specification/source/go/api/v1/traffic_manage" - "go.uber.org/zap" - - "github.com/polarismesh/polaris/auth" - cachetypes "github.com/polarismesh/polaris/cache/api" - "github.com/polarismesh/polaris/common/model" - "github.com/polarismesh/polaris/common/utils" - "github.com/polarismesh/polaris/service" -) - -// ServerAuthAbility 带有鉴权能力的 discoverServer -// -// 该层会对请求参数做一些调整,根据具体的请求发起人,设置为数据对应的 owner,不可为为别人进行创建资源 -type ServerAuthAbility struct { - nextSvr service.DiscoverServer - userMgn auth.UserServer - policyMgr auth.StrategyServer -} - -func NewServerAuthAbility(nextSvr service.DiscoverServer, - userMgn auth.UserServer, policyMgr auth.StrategyServer) service.DiscoverServer { - proxy := &ServerAuthAbility{ - nextSvr: nextSvr, - userMgn: userMgn, - policyMgr: policyMgr, - } - - actualSvr, ok := nextSvr.(*service.Server) - if ok { - actualSvr.SetResourceHooks(proxy) - } - return proxy -} - -// Cache Get cache management -func (svr *ServerAuthAbility) Cache() cachetypes.CacheManager { - return svr.nextSvr.Cache() -} - -// GetServiceInstanceRevision 获取服务实例的版本号 -func (svr *ServerAuthAbility) GetServiceInstanceRevision(serviceID string, - instances []*model.Instance) (string, error) { - return svr.nextSvr.GetServiceInstanceRevision(serviceID, instances) -} - -// collectServiceAuthContext 对于服务的处理,收集所有的与鉴权的相关信息 -// -// @receiver svr ServerAuthAbility -// @param ctx 请求上下文 ctx -// @param req 实际请求对象 -// @param resourceOp 该接口的数据操作类型 -// @return *model.AcquireContext 返回鉴权上下文 -func (svr *ServerAuthAbility) collectServiceAuthContext(ctx context.Context, req []*apiservice.Service, - resourceOp model.ResourceOperation, methodName string) *model.AcquireContext { - return model.NewAcquireContext( - model.WithRequestContext(ctx), - model.WithOperation(resourceOp), - model.WithModule(model.DiscoverModule), - model.WithMethod(methodName), - model.WithAccessResources(svr.queryServiceResource(req)), - ) -} - -// collectServiceAliasAuthContext 对于服务别名的处理,收集所有的与鉴权的相关信息 -// -// @receiver svr ServerAuthAbility -// @param ctx 请求上下文 ctx -// @param req 实际请求对象 -// @param resourceOp 该接口的数据操作类型 -// @return *model.AcquireContext 返回鉴权上下文 -func (svr *ServerAuthAbility) collectServiceAliasAuthContext(ctx context.Context, req []*apiservice.ServiceAlias, - resourceOp model.ResourceOperation, methodName string) *model.AcquireContext { - return model.NewAcquireContext( - model.WithRequestContext(ctx), - model.WithOperation(resourceOp), - model.WithModule(model.DiscoverModule), - model.WithMethod(methodName), - model.WithAccessResources(svr.queryServiceAliasResource(req)), - ) -} - -// collectInstanceAuthContext 对于服务实例的处理,收集所有的与鉴权的相关信息 -// -// @receiver svr ServerAuthAbility -// @param ctx 请求上下文 ctx -// @param req 实际请求对象 -// @param resourceOp 该接口的数据操作类型 -// @return *model.AcquireContext 返回鉴权上下文 -func (svr *ServerAuthAbility) collectInstanceAuthContext(ctx context.Context, req []*apiservice.Instance, - resourceOp model.ResourceOperation, methodName string) *model.AcquireContext { - return model.NewAcquireContext( - model.WithRequestContext(ctx), - model.WithOperation(resourceOp), - model.WithModule(model.DiscoverModule), - model.WithMethod(methodName), - model.WithAccessResources(svr.queryInstanceResource(req)), - ) -} - -// collectClientInstanceAuthContext 对于服务实例的处理,收集所有的与鉴权的相关信息 -func (svr *ServerAuthAbility) collectClientInstanceAuthContext(ctx context.Context, req []*apiservice.Instance, - resourceOp model.ResourceOperation, methodName string) *model.AcquireContext { - return model.NewAcquireContext( - model.WithRequestContext(ctx), - model.WithOperation(resourceOp), - model.WithModule(model.DiscoverModule), - model.WithMethod(methodName), - model.WithFromClient(), - model.WithAccessResources(svr.queryInstanceResource(req)), - ) -} - -// collectCircuitBreakerAuthContext 对于服务熔断的处理,收集所有的与鉴权的相关信息 -// -// @receiver svr ServerAuthAbility -// @param ctx 请求上下文 ctx -// @param req 实际请求对象 -// @param resourceOp 该接口的数据操作类型 -// @return *model.AcquireContext 返回鉴权上下文 -func (svr *ServerAuthAbility) collectCircuitBreakerAuthContext(ctx context.Context, req []*apifault.CircuitBreaker, - resourceOp model.ResourceOperation, methodName string) *model.AcquireContext { - return model.NewAcquireContext( - model.WithRequestContext(ctx), - model.WithOperation(resourceOp), - model.WithModule(model.DiscoverModule), - model.WithMethod(methodName), - model.WithAccessResources(svr.queryCircuitBreakerResource(req)), - ) -} - -// collectCircuitBreakerReleaseAuthContext -// -// @receiver svr -// @param ctx -// @param req -// @param resourceOp -// @return *model.AcquireContext -func (svr *ServerAuthAbility) collectCircuitBreakerReleaseAuthContext(ctx context.Context, - req []*apiservice.ConfigRelease, resourceOp model.ResourceOperation, methodName string) *model.AcquireContext { - return model.NewAcquireContext( - model.WithRequestContext(ctx), - model.WithOperation(resourceOp), - model.WithModule(model.DiscoverModule), - model.WithMethod(methodName), - model.WithAccessResources(svr.queryCircuitBreakerReleaseResource(req)), - ) -} - -// collectRouteRuleAuthContext 对于服务路由规则的处理,收集所有的与鉴权的相关信息 -// -// @receiver svr ServerAuthAbility -// @param ctx 请求上下文 ctx -// @param req 实际请求对象 -// @param resourceOp 该接口的数据操作类型 -// @return *model.AcquireContext 返回鉴权上下文 -func (svr *ServerAuthAbility) collectRouteRuleAuthContext(ctx context.Context, req []*apitraffic.Routing, - resourceOp model.ResourceOperation, methodName string) *model.AcquireContext { - return model.NewAcquireContext( - model.WithRequestContext(ctx), - model.WithOperation(resourceOp), - model.WithModule(model.DiscoverModule), - model.WithMethod(methodName), - model.WithAccessResources(svr.queryRouteRuleResource(req)), - ) -} - -// collectRateLimitAuthContext 对于服务限流规则的处理,收集所有的与鉴权的相关信息 -// -// @receiver svr ServerAuthAbility -// @param ctx 请求上下文 ctx -// @param req 实际请求对象 -// @param resourceOp 该接口的数据操作类型 -// @return *model.AcquireContext 返回鉴权上下文 -func (svr *ServerAuthAbility) collectRateLimitAuthContext(ctx context.Context, req []*apitraffic.Rule, - resourceOp model.ResourceOperation, methodName string) *model.AcquireContext { - return model.NewAcquireContext( - model.WithRequestContext(ctx), - model.WithOperation(resourceOp), - model.WithModule(model.DiscoverModule), - model.WithMethod(methodName), - model.WithAccessResources(svr.queryRateLimitConfigResource(req)), - ) -} - -// collectRouteRuleV2AuthContext 收集路由v2规则 -func (svr *ServerAuthAbility) collectRouteRuleV2AuthContext(ctx context.Context, req []*apitraffic.RouteRule, - resourceOp model.ResourceOperation, methodName string) *model.AcquireContext { - return model.NewAcquireContext( - model.WithRequestContext(ctx), - model.WithOperation(resourceOp), - model.WithModule(model.DiscoverModule), - model.WithMethod(methodName), - model.WithAccessResources(map[apisecurity.ResourceType][]model.ResourceEntry{}), - ) -} - -// collectRouteRuleV2AuthContext 收集熔断v2规则 -func (svr *ServerAuthAbility) collectCircuitBreakerRuleV2AuthContext(ctx context.Context, - req []*apifault.CircuitBreakerRule, - resourceOp model.ResourceOperation, methodName string) *model.AcquireContext { - return model.NewAcquireContext( - model.WithRequestContext(ctx), - model.WithOperation(resourceOp), - model.WithModule(model.DiscoverModule), - model.WithMethod(methodName), - model.WithAccessResources(map[apisecurity.ResourceType][]model.ResourceEntry{}), - ) -} - -// collectRouteRuleV2AuthContext 收集主动探测规则 -func (svr *ServerAuthAbility) collectFaultDetectAuthContext(ctx context.Context, - req []*apifault.FaultDetectRule, - resourceOp model.ResourceOperation, methodName string) *model.AcquireContext { - return model.NewAcquireContext( - model.WithRequestContext(ctx), - model.WithOperation(resourceOp), - model.WithModule(model.DiscoverModule), - model.WithMethod(methodName), - model.WithAccessResources(map[apisecurity.ResourceType][]model.ResourceEntry{}), - ) -} - -// queryServiceResource 根据所给的 service 信息,收集对应的 ResourceEntry 列表 -func (svr *ServerAuthAbility) queryServiceResource( - req []*apiservice.Service) map[apisecurity.ResourceType][]model.ResourceEntry { - if len(req) == 0 { - return make(map[apisecurity.ResourceType][]model.ResourceEntry) - } - - names := utils.NewSet[string]() - svcSet := utils.NewMap[string, *model.Service]() - - for index := range req { - svcName := req[index].GetName().GetValue() - svcNamespace := req[index].GetNamespace().GetValue() - names.Add(svcNamespace) - svc := svr.Cache().Service().GetServiceByName(svcName, svcNamespace) - if svc != nil { - svcSet.Store(svc.ID, svc) - } - } - - ret := svr.convertToDiscoverResourceEntryMaps(names, svcSet) - if authLog.DebugEnabled() { - authLog.Debug("[Auth][Server] collect service access res", zap.Any("res", ret)) - } - return ret -} - -// queryServiceAliasResource 根据所给的 servicealias 信息,收集对应的 ResourceEntry 列表 -func (svr *ServerAuthAbility) queryServiceAliasResource( - req []*apiservice.ServiceAlias) map[apisecurity.ResourceType][]model.ResourceEntry { - if len(req) == 0 { - return make(map[apisecurity.ResourceType][]model.ResourceEntry) - } - - names := utils.NewSet[string]() - svcSet := utils.NewMap[string, *model.Service]() - - for index := range req { - refSvcName := req[index].GetService().GetValue() - refSvcNamespace := req[index].GetNamespace().GetValue() - svcNamespace := req[index].GetNamespace().GetValue() - names.Add(svcNamespace) - refSvc := svr.Cache().Service().GetServiceByName(refSvcName, refSvcNamespace) - if refSvc != nil { - svcSet.Store(refSvc.ID, refSvc) - } - } - - ret := svr.convertToDiscoverResourceEntryMaps(names, svcSet) - if authLog.DebugEnabled() { - authLog.Debug("[Auth][Server] collect service alias access res", zap.Any("res", ret)) - } - return ret -} - -// queryInstanceResource 根据所给的 instances 信息,收集对应的 ResourceEntry 列表 -// 由于实例是注册到服务下的,因此只需要判断,是否有对应服务的权限即可 -func (svr *ServerAuthAbility) queryInstanceResource( - req []*apiservice.Instance) map[apisecurity.ResourceType][]model.ResourceEntry { - if len(req) == 0 { - return make(map[apisecurity.ResourceType][]model.ResourceEntry) - } - - names := utils.NewSet[string]() - svcSet := utils.NewMap[string, *model.Service]() - - for index := range req { - svcName := req[index].GetService().GetValue() - svcNamespace := req[index].GetNamespace().GetValue() - item := req[index] - if svcNamespace != "" && svcName != "" { - svc := svr.Cache().Service().GetServiceByName(svcName, svcNamespace) - if svc != nil { - svcSet.Store(svc.ID, svc) - } else { - names.Add(svcNamespace) - } - } else { - ins := svr.Cache().Instance().GetInstance(item.GetId().GetValue()) - if ins != nil { - svc := svr.Cache().Service().GetServiceByID(ins.ServiceID) - if svc != nil { - svcSet.Store(svc.ID, svc) - } else { - names.Add(svcNamespace) - } - } - } - } - - ret := svr.convertToDiscoverResourceEntryMaps(names, svcSet) - if authLog.DebugEnabled() { - authLog.Debug("[Auth][Server] collect instance access res", zap.Any("res", ret)) - } - return ret -} - -// queryCircuitBreakerResource 根据所给的 CircuitBreaker 信息,收集对应的 ResourceEntry 列表 -func (svr *ServerAuthAbility) queryCircuitBreakerResource( - req []*apifault.CircuitBreaker) map[apisecurity.ResourceType][]model.ResourceEntry { - if len(req) == 0 { - return make(map[apisecurity.ResourceType][]model.ResourceEntry) - } - - names := utils.NewSet[string]() - svcSet := utils.NewMap[string, *model.Service]() - - for index := range req { - svcName := req[index].GetService().GetValue() - svcNamespace := req[index].GetNamespace().GetValue() - svc := svr.Cache().Service().GetServiceByName(svcName, svcNamespace) - if svc != nil { - svcSet.Store(svc.ID, svc) - } - } - ret := svr.convertToDiscoverResourceEntryMaps(names, svcSet) - if authLog.DebugEnabled() { - authLog.Debug("[Auth][Server] collect circuit-breaker access res", zap.Any("res", ret)) - } - return ret -} - -// queryCircuitBreakerReleaseResource 根据所给的 CircuitBreakerRelease 信息,收集对应的 ResourceEntry 列表 -func (svr *ServerAuthAbility) queryCircuitBreakerReleaseResource( - req []*apiservice.ConfigRelease) map[apisecurity.ResourceType][]model.ResourceEntry { - if len(req) == 0 { - return make(map[apisecurity.ResourceType][]model.ResourceEntry) - } - - names := utils.NewSet[string]() - svcSet := utils.NewMap[string, *model.Service]() - - for index := range req { - svcName := req[index].GetService().GetName().GetValue() - svcNamespace := req[index].GetService().GetNamespace().GetValue() - svc := svr.Cache().Service().GetServiceByName(svcName, svcNamespace) - if svc != nil { - svcSet.Store(svc.ID, svc) - } - } - - ret := svr.convertToDiscoverResourceEntryMaps(names, svcSet) - if authLog.DebugEnabled() { - authLog.Debug("[Auth][Server] collect circuit-breaker-release access res", zap.Any("res", ret)) - } - return ret -} - -// queryRouteRuleResource 根据所给的 RouteRule 信息,收集对应的 ResourceEntry 列表 -func (svr *ServerAuthAbility) queryRouteRuleResource( - req []*apitraffic.Routing) map[apisecurity.ResourceType][]model.ResourceEntry { - if len(req) == 0 { - return make(map[apisecurity.ResourceType][]model.ResourceEntry) - } - - names := utils.NewSet[string]() - svcSet := utils.NewMap[string, *model.Service]() - - for index := range req { - svcName := req[index].GetService().GetValue() - svcNamespace := req[index].GetNamespace().GetValue() - svc := svr.Cache().Service().GetServiceByName(svcName, svcNamespace) - if svc != nil { - svcSet.Store(svc.ID, svc) - } - } - - ret := svr.convertToDiscoverResourceEntryMaps(names, svcSet) - if authLog.DebugEnabled() { - authLog.Debug("[Auth][Server] collect route-rule access res", zap.Any("res", ret)) - } - return ret -} - -// queryRateLimitConfigResource 根据所给的 RateLimit 信息,收集对应的 ResourceEntry 列表 -func (svr *ServerAuthAbility) queryRateLimitConfigResource( - req []*apitraffic.Rule) map[apisecurity.ResourceType][]model.ResourceEntry { - if len(req) == 0 { - return make(map[apisecurity.ResourceType][]model.ResourceEntry) - } - - names := utils.NewSet[string]() - svcSet := utils.NewMap[string, *model.Service]() - - for index := range req { - svcName := req[index].GetService().GetValue() - svcNamespace := req[index].GetNamespace().GetValue() - svc := svr.Cache().Service().GetServiceByName(svcName, svcNamespace) - if svc != nil { - svcSet.Store(svc.ID, svc) - } - } - - ret := svr.convertToDiscoverResourceEntryMaps(names, svcSet) - if authLog.DebugEnabled() { - authLog.Debug("[Auth][Server] collect rate-limit access res", zap.Any("res", ret)) - } - return ret -} - -// convertToDiscoverResourceEntryMaps 通用方法,进行转换为期望的、服务相关的 ResourceEntry -func (svr *ServerAuthAbility) convertToDiscoverResourceEntryMaps(nsSet *utils.Set[string], - svcSet *utils.Map[string, *model.Service]) map[apisecurity.ResourceType][]model.ResourceEntry { - var ( - param = nsSet.ToSlice() - nsArr = svr.Cache().Namespace().GetNamespacesByName(param) - nsRet = make([]model.ResourceEntry, 0, len(nsArr)) - ) - for index := range nsArr { - ns := nsArr[index] - nsRet = append(nsRet, model.ResourceEntry{ - ID: ns.Name, - Owner: ns.Owner, - }) - } - - svcRet := make([]model.ResourceEntry, 0, svcSet.Len()) - svcSet.Range(func(key string, svc *model.Service) { - svcRet = append(svcRet, model.ResourceEntry{ - ID: svc.ID, - Owner: svc.Owner, - }) - }) - - return map[apisecurity.ResourceType][]model.ResourceEntry{ - apisecurity.ResourceType_Namespaces: nsRet, - apisecurity.ResourceType_Services: svcRet, - } -} - -func convertToErrCode(err error) apimodel.Code { - if errors.Is(err, model.ErrorTokenNotExist) { - return apimodel.Code_TokenNotExisted - } - if errors.Is(err, model.ErrorTokenDisabled) { - return apimodel.Code_TokenDisabled - } - return apimodel.Code_NotAllowedAccess -} diff --git a/service/interceptor/auth/service_authability.go b/service/interceptor/auth/service.go similarity index 57% rename from service/interceptor/auth/service_authability.go rename to service/interceptor/auth/service.go index 84b611a3f..864a33368 100644 --- a/service/interceptor/auth/service_authability.go +++ b/service/interceptor/auth/service.go @@ -20,22 +20,25 @@ package service_auth import ( "context" + "github.com/polarismesh/specification/source/go/api/v1/security" apisecurity "github.com/polarismesh/specification/source/go/api/v1/security" apiservice "github.com/polarismesh/specification/source/go/api/v1/service_manage" + cachetypes "github.com/polarismesh/polaris/cache/api" api "github.com/polarismesh/polaris/common/api/v1" "github.com/polarismesh/polaris/common/model" + authcommon "github.com/polarismesh/polaris/common/model/auth" "github.com/polarismesh/polaris/common/utils" ) // CreateServices 批量创建服务 -func (svr *ServerAuthAbility) CreateServices( +func (svr *Server) CreateServices( ctx context.Context, reqs []*apiservice.Service) *apiservice.BatchWriteResponse { - authCtx := svr.collectServiceAuthContext(ctx, reqs, model.Create, "CreateServices") + authCtx := svr.collectServiceAuthContext(ctx, reqs, authcommon.Create, authcommon.CreateServices) - _, err := svr.policyMgr.GetAuthChecker().CheckConsolePermission(authCtx) + _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx) if err != nil { - return api.NewBatchWriteResponse(convertToErrCode(err)) + return api.NewBatchWriteResponse(authcommon.ConvertToErrCode(err)) } ctx = authCtx.GetRequestContext() @@ -55,16 +58,16 @@ func (svr *ServerAuthAbility) CreateServices( } // DeleteServices 批量删除服务 -func (svr *ServerAuthAbility) DeleteServices( +func (svr *Server) DeleteServices( ctx context.Context, reqs []*apiservice.Service) *apiservice.BatchWriteResponse { - authCtx := svr.collectServiceAuthContext(ctx, reqs, model.Delete, "DeleteServices") + authCtx := svr.collectServiceAuthContext(ctx, reqs, authcommon.Delete, authcommon.DeleteServices) accessRes := authCtx.GetAccessResources() delete(accessRes, apisecurity.ResourceType_Namespaces) authCtx.SetAccessResources(accessRes) - if _, err := svr.policyMgr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewBatchWriteResponseWithMsg(convertToErrCode(err), err.Error()) + if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + return api.NewBatchWriteResponseWithMsg(authcommon.ConvertToErrCode(err), err.Error()) } ctx = authCtx.GetRequestContext() @@ -74,17 +77,17 @@ func (svr *ServerAuthAbility) DeleteServices( } // UpdateServices 对于服务修改来说,只针对服务本身,而不需要检查命名空间 -func (svr *ServerAuthAbility) UpdateServices( +func (svr *Server) UpdateServices( ctx context.Context, reqs []*apiservice.Service) *apiservice.BatchWriteResponse { - authCtx := svr.collectServiceAuthContext(ctx, reqs, model.Modify, "UpdateServices") + authCtx := svr.collectServiceAuthContext(ctx, reqs, authcommon.Modify, authcommon.UpdateServices) accessRes := authCtx.GetAccessResources() delete(accessRes, apisecurity.ResourceType_Namespaces) authCtx.SetAccessResources(accessRes) - _, err := svr.policyMgr.GetAuthChecker().CheckConsolePermission(authCtx) + _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx) if err != nil { - return api.NewBatchWriteResponse(convertToErrCode(err)) + return api.NewBatchWriteResponse(authcommon.ConvertToErrCode(err)) } ctx = authCtx.GetRequestContext() @@ -93,17 +96,17 @@ func (svr *ServerAuthAbility) UpdateServices( } // UpdateServiceToken 更新服务的 token -func (svr *ServerAuthAbility) UpdateServiceToken( +func (svr *Server) UpdateServiceToken( ctx context.Context, req *apiservice.Service) *apiservice.Response { authCtx := svr.collectServiceAuthContext( - ctx, []*apiservice.Service{req}, model.Modify, "UpdateServiceToken") + ctx, []*apiservice.Service{req}, authcommon.Modify, authcommon.UpdateServiceToken) accessRes := authCtx.GetAccessResources() delete(accessRes, apisecurity.ResourceType_Namespaces) authCtx.SetAccessResources(accessRes) - if _, err := svr.policyMgr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewResponseWithMsg(convertToErrCode(err), err.Error()) + if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + return api.NewResponseWithMsg(authcommon.ConvertToErrCode(err), err.Error()) } ctx = authCtx.GetRequestContext() @@ -111,12 +114,12 @@ func (svr *ServerAuthAbility) UpdateServiceToken( return svr.nextSvr.UpdateServiceToken(ctx, req) } -func (svr *ServerAuthAbility) GetAllServices(ctx context.Context, +func (svr *Server) GetAllServices(ctx context.Context, query map[string]string) *apiservice.BatchQueryResponse { - authCtx := svr.collectServiceAuthContext(ctx, nil, model.Read, "GetAllServices") + authCtx := svr.collectServiceAuthContext(ctx, nil, authcommon.Read, authcommon.DescribeAllServices) - if _, err := svr.policyMgr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewBatchQueryResponseWithMsg(convertToErrCode(err), err.Error()) + if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + return api.NewBatchQueryResponseWithMsg(authcommon.ConvertToErrCode(err), err.Error()) } ctx = authCtx.GetRequestContext() @@ -126,40 +129,36 @@ func (svr *ServerAuthAbility) GetAllServices(ctx context.Context, } // GetServices 批量获取服务 -func (svr *ServerAuthAbility) GetServices( +func (svr *Server) GetServices( ctx context.Context, query map[string]string) *apiservice.BatchQueryResponse { - authCtx := svr.collectServiceAuthContext(ctx, nil, model.Read, "GetServices") + authCtx := svr.collectServiceAuthContext(ctx, nil, authcommon.Read, authcommon.DescribeServices) - if _, err := svr.policyMgr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewBatchQueryResponseWithMsg(convertToErrCode(err), err.Error()) + if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + return api.NewBatchQueryResponseWithMsg(authcommon.ConvertToErrCode(err), err.Error()) } ctx = authCtx.GetRequestContext() ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) + // 注入查询条件拦截器 + resp := svr.nextSvr.GetServices(ctx, query) if len(resp.Services) != 0 { for index := range resp.Services { svc := resp.Services[index] - editable := svr.policyMgr.GetAuthChecker().AllowResourceOperate(authCtx, &model.ResourceOpInfo{ - ResourceType: apisecurity.ResourceType_Services, - Namespace: svc.GetNamespace().GetValue(), - ResourceName: svc.GetName().GetValue(), - ResourceID: svc.GetId().GetValue(), - Operation: authCtx.GetOperation(), - }) - svc.Editable = utils.NewBoolValue(editable) + // TODO 需要配合 metadata 做调整 + svc.Editable = utils.NewBoolValue(true) } } return resp } // GetServicesCount 批量获取服务数量 -func (svr *ServerAuthAbility) GetServicesCount(ctx context.Context) *apiservice.BatchQueryResponse { - authCtx := svr.collectServiceAuthContext(ctx, nil, model.Read, "GetServicesCount") +func (svr *Server) GetServicesCount(ctx context.Context) *apiservice.BatchQueryResponse { + authCtx := svr.collectServiceAuthContext(ctx, nil, authcommon.Read, authcommon.DescribeServicesCount) - if _, err := svr.policyMgr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewBatchQueryResponseWithMsg(convertToErrCode(err), err.Error()) + if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + return api.NewBatchQueryResponseWithMsg(authcommon.ConvertToErrCode(err), err.Error()) } ctx = authCtx.GetRequestContext() @@ -168,11 +167,11 @@ func (svr *ServerAuthAbility) GetServicesCount(ctx context.Context) *apiservice. } // GetServiceToken 获取服务的 token -func (svr *ServerAuthAbility) GetServiceToken(ctx context.Context, req *apiservice.Service) *apiservice.Response { - authCtx := svr.collectServiceAuthContext(ctx, nil, model.Read, "GetServiceToken") +func (svr *Server) GetServiceToken(ctx context.Context, req *apiservice.Service) *apiservice.Response { + authCtx := svr.collectServiceAuthContext(ctx, nil, authcommon.Read, authcommon.DescribeServiceToken) - if _, err := svr.policyMgr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewResponseWithMsg(convertToErrCode(err), err.Error()) + if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + return api.NewResponseWithMsg(authcommon.ConvertToErrCode(err), err.Error()) } ctx = authCtx.GetRequestContext() @@ -181,15 +180,24 @@ func (svr *ServerAuthAbility) GetServiceToken(ctx context.Context, req *apiservi } // GetServiceOwner 获取服务的 owner -func (svr *ServerAuthAbility) GetServiceOwner( +func (svr *Server) GetServiceOwner( ctx context.Context, req []*apiservice.Service) *apiservice.BatchQueryResponse { - authCtx := svr.collectServiceAuthContext(ctx, nil, model.Read, "GetServiceOwner") + authCtx := svr.collectServiceAuthContext(ctx, nil, authcommon.Read, authcommon.DescribeServiceOwner) - if _, err := svr.policyMgr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewBatchQueryResponseWithMsg(convertToErrCode(err), err.Error()) + if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + return api.NewBatchQueryResponseWithMsg(authcommon.ConvertToErrCode(err), err.Error()) } ctx = authCtx.GetRequestContext() ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) + + cachetypes.AppendServicePredicate(ctx, func(ctx context.Context, cbr *model.Service) bool { + return svr.policySvr.GetAuthChecker().ResourcePredicate(authCtx, &authcommon.ResourceEntry{ + Type: security.ResourceType_Services, + ID: cbr.ID, + Metadata: cbr.Meta, + }) + }) + return svr.nextSvr.GetServiceOwner(ctx, req) } diff --git a/service/interceptor/auth/service_alias_authability.go b/service/interceptor/auth/service_alias.go similarity index 57% rename from service/interceptor/auth/service_alias_authability.go rename to service/interceptor/auth/service_alias.go index 5b8da8be4..d52f86538 100644 --- a/service/interceptor/auth/service_alias_authability.go +++ b/service/interceptor/auth/service_alias.go @@ -20,22 +20,24 @@ package service_auth import ( "context" - apisecurity "github.com/polarismesh/specification/source/go/api/v1/security" + "github.com/polarismesh/specification/source/go/api/v1/security" apiservice "github.com/polarismesh/specification/source/go/api/v1/service_manage" + cachetypes "github.com/polarismesh/polaris/cache/api" api "github.com/polarismesh/polaris/common/api/v1" "github.com/polarismesh/polaris/common/model" + authcommon "github.com/polarismesh/polaris/common/model/auth" "github.com/polarismesh/polaris/common/utils" ) // CreateServiceAlias creates a service alias -func (svr *ServerAuthAbility) CreateServiceAlias( +func (svr *Server) CreateServiceAlias( ctx context.Context, req *apiservice.ServiceAlias) *apiservice.Response { authCtx := svr.collectServiceAliasAuthContext( - ctx, []*apiservice.ServiceAlias{req}, model.Create, "CreateServiceAlias") + ctx, []*apiservice.ServiceAlias{req}, authcommon.Create, authcommon.CreateServiceAlias) - if _, err := svr.policyMgr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewServiceAliasResponse(convertToErrCode(err), req) + if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + return api.NewServiceAliasResponse(authcommon.ConvertToErrCode(err), req) } ctx = authCtx.GetRequestContext() @@ -51,12 +53,12 @@ func (svr *ServerAuthAbility) CreateServiceAlias( } // DeleteServiceAliases deletes service aliases -func (svr *ServerAuthAbility) DeleteServiceAliases(ctx context.Context, +func (svr *Server) DeleteServiceAliases(ctx context.Context, reqs []*apiservice.ServiceAlias) *apiservice.BatchWriteResponse { - authCtx := svr.collectServiceAliasAuthContext(ctx, reqs, model.Delete, "DeleteServiceAliases") + authCtx := svr.collectServiceAliasAuthContext(ctx, reqs, authcommon.Delete, authcommon.DeleteServiceAliases) - if _, err := svr.policyMgr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewBatchWriteResponse(convertToErrCode(err)) + if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + return api.NewBatchWriteResponse(authcommon.ConvertToErrCode(err)) } ctx = authCtx.GetRequestContext() @@ -66,13 +68,13 @@ func (svr *ServerAuthAbility) DeleteServiceAliases(ctx context.Context, } // UpdateServiceAlias updates service alias -func (svr *ServerAuthAbility) UpdateServiceAlias( +func (svr *Server) UpdateServiceAlias( ctx context.Context, req *apiservice.ServiceAlias) *apiservice.Response { authCtx := svr.collectServiceAliasAuthContext( - ctx, []*apiservice.ServiceAlias{req}, model.Modify, "UpdateServiceAlias") + ctx, []*apiservice.ServiceAlias{req}, authcommon.Modify, authcommon.UpdateServiceAlias) - if _, err := svr.policyMgr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewServiceAliasResponse(convertToErrCode(err), req) + if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + return api.NewServiceAliasResponse(authcommon.ConvertToErrCode(err), req) } ctx = authCtx.GetRequestContext() @@ -82,34 +84,24 @@ func (svr *ServerAuthAbility) UpdateServiceAlias( } // GetServiceAliases gets service aliases -func (svr *ServerAuthAbility) GetServiceAliases(ctx context.Context, +func (svr *Server) GetServiceAliases(ctx context.Context, query map[string]string) *apiservice.BatchQueryResponse { - authCtx := svr.collectServiceAliasAuthContext(ctx, nil, model.Read, "GetServiceAliases") + authCtx := svr.collectServiceAliasAuthContext(ctx, nil, authcommon.Read, authcommon.DescribeServiceAliases) - if _, err := svr.policyMgr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewBatchQueryResponse(convertToErrCode(err)) + if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + return api.NewBatchQueryResponse(authcommon.ConvertToErrCode(err)) } ctx = authCtx.GetRequestContext() ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) - resp := svr.nextSvr.GetServiceAliases(ctx, query) - if len(resp.Aliases) != 0 { - for index := range resp.Aliases { - alias := resp.Aliases[index] - svc := svr.Cache().Service().GetServiceByName(alias.Service.Value, alias.Namespace.Value) - if svc == nil { - continue - } - editable := svr.policyMgr.GetAuthChecker().AllowResourceOperate(authCtx, &model.ResourceOpInfo{ - ResourceType: apisecurity.ResourceType_Services, - Namespace: svc.Namespace, - ResourceName: svc.Name, - ResourceID: svc.ID, - Operation: authCtx.GetOperation(), - }) - alias.Editable = utils.NewBoolValue(editable) - } - } - return resp + cachetypes.AppendServicePredicate(ctx, func(ctx context.Context, cbr *model.Service) bool { + return svr.policySvr.GetAuthChecker().ResourcePredicate(authCtx, &authcommon.ResourceEntry{ + Type: security.ResourceType_Services, + ID: cbr.ID, + Metadata: cbr.Meta, + }) + }) + + return svr.nextSvr.GetServiceAliases(ctx, query) } diff --git a/service/interceptor/auth/service_contract_authability.go b/service/interceptor/auth/service_contract.go similarity index 66% rename from service/interceptor/auth/service_contract_authability.go rename to service/interceptor/auth/service_contract.go index b413c5614..15725418b 100644 --- a/service/interceptor/auth/service_contract_authability.go +++ b/service/interceptor/auth/service_contract.go @@ -23,12 +23,12 @@ import ( apiservice "github.com/polarismesh/specification/source/go/api/v1/service_manage" api "github.com/polarismesh/polaris/common/api/v1" - "github.com/polarismesh/polaris/common/model" + authcommon "github.com/polarismesh/polaris/common/model/auth" "github.com/polarismesh/polaris/common/utils" ) // CreateServiceContracts . -func (svr *ServerAuthAbility) CreateServiceContracts(ctx context.Context, +func (svr *Server) CreateServiceContracts(ctx context.Context, req []*apiservice.ServiceContract) *apiservice.BatchWriteResponse { services := make([]*apiservice.Service, 0, len(req)) for i := range req { @@ -38,9 +38,9 @@ func (svr *ServerAuthAbility) CreateServiceContracts(ctx context.Context, }) } - authCtx := svr.collectServiceAuthContext(ctx, services, model.Create, "CreateServiceContracts") - if _, err := svr.policyMgr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewBatchWriteResponse(convertToErrCode(err)) + authCtx := svr.collectServiceAuthContext(ctx, services, authcommon.Create, authcommon.CreateServiceContracts) + if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + return api.NewBatchWriteResponse(authcommon.ConvertToErrCode(err)) } ctx = authCtx.GetRequestContext() @@ -49,11 +49,11 @@ func (svr *ServerAuthAbility) CreateServiceContracts(ctx context.Context, } // GetServiceContracts . -func (svr *ServerAuthAbility) GetServiceContracts(ctx context.Context, +func (svr *Server) GetServiceContracts(ctx context.Context, query map[string]string) *apiservice.BatchQueryResponse { - authCtx := svr.collectServiceAuthContext(ctx, nil, model.Read, "GetServiceContracts") - if _, err := svr.policyMgr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewBatchQueryResponse(convertToErrCode(err)) + authCtx := svr.collectServiceAuthContext(ctx, nil, authcommon.Read, authcommon.DescribeServiceContracts) + if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + return api.NewBatchQueryResponse(authcommon.ConvertToErrCode(err)) } ctx = authCtx.GetRequestContext() @@ -62,12 +62,12 @@ func (svr *ServerAuthAbility) GetServiceContracts(ctx context.Context, } // GetServiceContractVersions . -func (svr *ServerAuthAbility) GetServiceContractVersions(ctx context.Context, +func (svr *Server) GetServiceContractVersions(ctx context.Context, filter map[string]string) *apiservice.BatchQueryResponse { - authCtx := svr.collectServiceAuthContext(ctx, nil, model.Read, "GetServiceContractVersions") - if _, err := svr.policyMgr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewBatchQueryResponse(convertToErrCode(err)) + authCtx := svr.collectServiceAuthContext(ctx, nil, authcommon.Read, authcommon.DescribeServiceContractVersions) + if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + return api.NewBatchQueryResponse(authcommon.ConvertToErrCode(err)) } ctx = authCtx.GetRequestContext() @@ -76,7 +76,7 @@ func (svr *ServerAuthAbility) GetServiceContractVersions(ctx context.Context, } // DeleteServiceContracts . -func (svr *ServerAuthAbility) DeleteServiceContracts(ctx context.Context, +func (svr *Server) DeleteServiceContracts(ctx context.Context, req []*apiservice.ServiceContract) *apiservice.BatchWriteResponse { services := make([]*apiservice.Service, 0, len(req)) for i := range req { @@ -86,9 +86,9 @@ func (svr *ServerAuthAbility) DeleteServiceContracts(ctx context.Context, }) } - authCtx := svr.collectServiceAuthContext(ctx, services, model.Delete, "DeleteServiceContracts") - if _, err := svr.policyMgr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewBatchWriteResponse(convertToErrCode(err)) + authCtx := svr.collectServiceAuthContext(ctx, services, authcommon.Delete, authcommon.DeleteServiceContracts) + if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + return api.NewBatchWriteResponse(authcommon.ConvertToErrCode(err)) } ctx = authCtx.GetRequestContext() @@ -97,16 +97,16 @@ func (svr *ServerAuthAbility) DeleteServiceContracts(ctx context.Context, } // CreateServiceContractInterfaces . -func (svr *ServerAuthAbility) CreateServiceContractInterfaces(ctx context.Context, contract *apiservice.ServiceContract, +func (svr *Server) CreateServiceContractInterfaces(ctx context.Context, contract *apiservice.ServiceContract, source apiservice.InterfaceDescriptor_Source) *apiservice.Response { authCtx := svr.collectServiceAuthContext(ctx, []*apiservice.Service{ { Namespace: utils.NewStringValue(contract.Namespace), Name: utils.NewStringValue(contract.Service), }, - }, model.Modify, "CreateServiceContractInterfaces") - if _, err := svr.policyMgr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewResponse(convertToErrCode(err)) + }, authcommon.Modify, authcommon.CreateServiceContractInterfaces) + if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + return api.NewResponse(authcommon.ConvertToErrCode(err)) } ctx = authCtx.GetRequestContext() @@ -115,16 +115,16 @@ func (svr *ServerAuthAbility) CreateServiceContractInterfaces(ctx context.Contex } // AppendServiceContractInterfaces . -func (svr *ServerAuthAbility) AppendServiceContractInterfaces(ctx context.Context, +func (svr *Server) AppendServiceContractInterfaces(ctx context.Context, contract *apiservice.ServiceContract, source apiservice.InterfaceDescriptor_Source) *apiservice.Response { authCtx := svr.collectServiceAuthContext(ctx, []*apiservice.Service{ { Namespace: utils.NewStringValue(contract.Namespace), Name: utils.NewStringValue(contract.Service), }, - }, model.Modify, "AppendServiceContractInterfaces") - if _, err := svr.policyMgr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewResponse(convertToErrCode(err)) + }, authcommon.Modify, authcommon.AppendServiceContractInterfaces) + if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + return api.NewResponse(authcommon.ConvertToErrCode(err)) } ctx = authCtx.GetRequestContext() @@ -133,16 +133,16 @@ func (svr *ServerAuthAbility) AppendServiceContractInterfaces(ctx context.Contex } // DeleteServiceContractInterfaces . -func (svr *ServerAuthAbility) DeleteServiceContractInterfaces(ctx context.Context, +func (svr *Server) DeleteServiceContractInterfaces(ctx context.Context, contract *apiservice.ServiceContract) *apiservice.Response { authCtx := svr.collectServiceAuthContext(ctx, []*apiservice.Service{ { Namespace: utils.NewStringValue(contract.Namespace), Name: utils.NewStringValue(contract.Service), }, - }, model.Modify, "DeleteServiceContractInterfaces") - if _, err := svr.policyMgr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewResponse(convertToErrCode(err)) + }, authcommon.Modify, authcommon.DeleteServiceContractInterfaces) + if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + return api.NewResponse(authcommon.ConvertToErrCode(err)) } ctx = authCtx.GetRequestContext() diff --git a/service/interceptor/paramcheck/circuit_breaker.go b/service/interceptor/paramcheck/circuit_breaker.go index d0c42d79d..452121325 100644 --- a/service/interceptor/paramcheck/circuit_breaker.go +++ b/service/interceptor/paramcheck/circuit_breaker.go @@ -21,15 +21,13 @@ import ( "context" "github.com/polarismesh/specification/source/go/api/v1/fault_tolerance" - "github.com/polarismesh/specification/source/go/api/v1/service_manage" - - "github.com/polarismesh/polaris/common/utils" - apifault "github.com/polarismesh/specification/source/go/api/v1/fault_tolerance" apimodel "github.com/polarismesh/specification/source/go/api/v1/model" + "github.com/polarismesh/specification/source/go/api/v1/service_manage" apiservice "github.com/polarismesh/specification/source/go/api/v1/service_manage" api "github.com/polarismesh/polaris/common/api/v1" + "github.com/polarismesh/polaris/common/utils" ) // GetMasterCircuitBreakers implements service.DiscoverServer. diff --git a/service/interceptor/paramcheck/fault_detect.go b/service/interceptor/paramcheck/fault_detect.go index a21a073a5..28a8d4ca4 100644 --- a/service/interceptor/paramcheck/fault_detect.go +++ b/service/interceptor/paramcheck/fault_detect.go @@ -19,30 +19,187 @@ package paramcheck import ( "context" + "strconv" "github.com/polarismesh/specification/source/go/api/v1/fault_tolerance" + apifault "github.com/polarismesh/specification/source/go/api/v1/fault_tolerance" + apimodel "github.com/polarismesh/specification/source/go/api/v1/model" "github.com/polarismesh/specification/source/go/api/v1/service_manage" + apiservice "github.com/polarismesh/specification/source/go/api/v1/service_manage" + + api "github.com/polarismesh/polaris/common/api/v1" + "github.com/polarismesh/polaris/common/log" + commonstore "github.com/polarismesh/polaris/common/store" + "github.com/polarismesh/polaris/common/utils" +) + +var ( + // FaultDetectRuleFilters filter fault detect rule query parameters + FaultDetectRuleFilters = map[string]bool{ + "brief": true, + "offset": true, + "limit": true, + "id": true, + "name": true, + "namespace": true, + "service": true, + "serviceNamespace": true, + "dstService": true, + "dstNamespace": true, + "dstMethod": true, + "description": true, + } ) // DeleteFaultDetectRules implements service.DiscoverServer. func (svr *Server) DeleteFaultDetectRules(ctx context.Context, request []*fault_tolerance.FaultDetectRule) *service_manage.BatchWriteResponse { + + if checkErr := checkBatchFaultDetectRules(request); checkErr != nil { + return checkErr + } + + batchRsp := api.NewBatchWriteResponse(apimodel.Code_ExecuteSuccess) + for _, cbRule := range request { + if resp := checkFaultDetectRuleParams(cbRule, false, true); resp != nil { + api.Collect(batchRsp, resp) + continue + } + } + + if !api.IsSuccess(batchRsp) { + return batchRsp + } + return svr.nextSvr.DeleteFaultDetectRules(ctx, request) } // GetFaultDetectRules implements service.DiscoverServer. func (svr *Server) GetFaultDetectRules(ctx context.Context, query map[string]string) *service_manage.BatchQueryResponse { + + for key := range query { + if _, ok := FaultDetectRuleFilters[key]; !ok { + log.Errorf("params %s is not allowed in querying fault detect rule", key) + return api.NewBatchQueryResponse(apimodel.Code_InvalidParameter) + } + } + offset, limit, err := utils.ParseOffsetAndLimit(query) + if err != nil { + return api.NewBatchQueryResponse(apimodel.Code_InvalidParameter) + } + + query["offset"] = strconv.FormatUint(uint64(offset), 10) + query["limit"] = strconv.FormatUint(uint64(limit), 10) + return svr.nextSvr.GetFaultDetectRules(ctx, query) } // CreateFaultDetectRules implements service.DiscoverServer. func (svr *Server) CreateFaultDetectRules(ctx context.Context, request []*fault_tolerance.FaultDetectRule) *service_manage.BatchWriteResponse { + + if checkErr := checkBatchFaultDetectRules(request); checkErr != nil { + return checkErr + } + + batchRsp := api.NewBatchWriteResponse(apimodel.Code_ExecuteSuccess) + for _, cbRule := range request { + if resp := checkFaultDetectRuleParams(cbRule, false, true); resp != nil { + api.Collect(batchRsp, resp) + continue + } + } + + if !api.IsSuccess(batchRsp) { + return batchRsp + } + return svr.nextSvr.CreateFaultDetectRules(ctx, request) } // UpdateFaultDetectRules implements service.DiscoverServer. func (svr *Server) UpdateFaultDetectRules(ctx context.Context, request []*fault_tolerance.FaultDetectRule) *service_manage.BatchWriteResponse { + if checkErr := checkBatchFaultDetectRules(request); checkErr != nil { + return checkErr + } + + batchRsp := api.NewBatchWriteResponse(apimodel.Code_ExecuteSuccess) + for _, cbRule := range request { + if resp := checkFaultDetectRuleParams(cbRule, false, true); resp != nil { + api.Collect(batchRsp, resp) + continue + } + if resp := svr.checkFaultDetectRuleExists(ctx, cbRule.GetId()); resp != nil { + api.Collect(batchRsp, resp) + continue + } + } + + if !api.IsSuccess(batchRsp) { + return batchRsp + } + return svr.nextSvr.UpdateFaultDetectRules(ctx, request) } + +func (svr *Server) checkFaultDetectRuleExists(ctx context.Context, id string) *apiservice.Response { + exists, err := svr.storage.HasFaultDetectRule(id) + if err != nil { + log.Error(err.Error(), utils.RequestID(ctx)) + return api.NewResponse(commonstore.StoreCode2APICode(err)) + } + if !exists { + return api.NewResponse(apimodel.Code_NotFoundResource) + } + return nil +} + +func checkBatchFaultDetectRules(req []*apifault.FaultDetectRule) *apiservice.BatchWriteResponse { + if len(req) == 0 { + return api.NewBatchWriteResponse(apimodel.Code_EmptyRequest) + } + + if len(req) > utils.MaxBatchSize { + return api.NewBatchWriteResponse(apimodel.Code_BatchSizeOverLimit) + } + + return nil +} + +func checkFaultDetectRuleParams( + req *apifault.FaultDetectRule, idRequired bool, nameRequired bool) *apiservice.Response { + if req == nil { + return api.NewResponse(apimodel.Code_EmptyRequest) + } + if resp := checkFaultDetectRuleParamsDbLen(req); nil != resp { + return resp + } + if nameRequired && len(req.GetName()) == 0 { + return api.NewResponse(apimodel.Code_InvalidCircuitBreakerName) + } + if idRequired && len(req.GetId()) == 0 { + return api.NewResponse(apimodel.Code_InvalidCircuitBreakerID) + } + return nil +} + +func checkFaultDetectRuleParamsDbLen(req *apifault.FaultDetectRule) *apiservice.Response { + if err := utils.CheckDbRawStrFieldLen(req.GetTargetService().GetService(), utils.MaxDbServiceNameLength); err != nil { + return api.NewResponse(apimodel.Code_InvalidServiceName) + } + if err := utils.CheckDbRawStrFieldLen( + req.GetTargetService().GetNamespace(), utils.MaxDbServiceNamespaceLength); err != nil { + return api.NewResponse(apimodel.Code_InvalidNamespaceName) + } + if err := utils.CheckDbRawStrFieldLen(req.GetName(), utils.MaxRuleName); err != nil { + return api.NewResponse(apimodel.Code_InvalidRateLimitName) + } + if err := utils.CheckDbRawStrFieldLen(req.GetNamespace(), utils.MaxDbServiceNamespaceLength); err != nil { + return api.NewResponse(apimodel.Code_InvalidNamespaceName) + } + if err := utils.CheckDbRawStrFieldLen(req.GetDescription(), utils.MaxCommentLength); err != nil { + return api.NewResponse(apimodel.Code_InvalidServiceComment) + } + return nil +} diff --git a/service/interceptor/paramcheck/server.go b/service/interceptor/paramcheck/server.go index 542f1264b..f46a54065 100644 --- a/service/interceptor/paramcheck/server.go +++ b/service/interceptor/paramcheck/server.go @@ -23,12 +23,14 @@ import ( "github.com/polarismesh/polaris/common/model" "github.com/polarismesh/polaris/plugin" "github.com/polarismesh/polaris/service" + "github.com/polarismesh/polaris/store" ) // Server 带有鉴权能力的 discoverServer // // 该层会对请求参数做一些调整,根据具体的请求发起人,设置为数据对应的 owner,不可为为别人进行创建资源 type Server struct { + storage store.Store nextSvr service.DiscoverServer ratelimit plugin.Ratelimit } diff --git a/service/interceptor/paramcheck/service_alias.go b/service/interceptor/paramcheck/service_alias.go index 68c946044..c01f58de8 100644 --- a/service/interceptor/paramcheck/service_alias.go +++ b/service/interceptor/paramcheck/service_alias.go @@ -130,6 +130,7 @@ func checkDeleteServiceAliasReq(ctx context.Context, req *apiservice.ServiceAlia return nil } + func preCheckAlias(req *apiservice.ServiceAlias) (*apiservice.Response, bool) { if req == nil { return api.NewServiceAliasResponse(apimodel.Code_EmptyRequest, req), true diff --git a/service/ratelimit_config.go b/service/ratelimit_config.go index 6ab5c0d0c..be4bf0734 100644 --- a/service/ratelimit_config.go +++ b/service/ratelimit_config.go @@ -264,7 +264,7 @@ func (s *Server) GetRateLimits(ctx context.Context, query map[string]string) *ap return errResp } - total, extendRateLimits, err := s.Cache().RateLimit().QueryRateLimitRules(*args) + total, extendRateLimits, err := s.Cache().RateLimit().QueryRateLimitRules(ctx, *args) if err != nil { log.Errorf("get rate limits store err: %s", err.Error()) return api.NewBatchQueryResponse(commonstore.StoreCode2APICode(err)) diff --git a/service/routing_config_v2.go b/service/routing_config_v2.go index 311b8860e..7e8040a1d 100644 --- a/service/routing_config_v2.go +++ b/service/routing_config_v2.go @@ -207,7 +207,7 @@ func (s *Server) QueryRoutingConfigsV2(ctx context.Context, query map[string]str return apiv1.NewBatchQueryResponse(apimodel.Code(presp.GetCode().GetValue())) } - total, ret, err := s.Cache().RoutingConfig().QueryRoutingConfigsV2(args) + total, ret, err := s.Cache().RoutingConfig().QueryRoutingConfigsV2(ctx, args) if err != nil { log.Error("[Routing][V2] query routing list from cache", utils.RequestID(ctx), zap.Error(err)) return apiv1.NewBatchQueryResponse(apimodel.Code_ExecuteException) diff --git a/service/routing_config_v2_test.go b/service/routing_config_v2_test.go index af000b722..8513f3ae0 100644 --- a/service/routing_config_v2_test.go +++ b/service/routing_config_v2_test.go @@ -100,8 +100,8 @@ func TestCreateRoutingConfigV2(t *testing.T) { out = discoverSuit.DiscoverServer().QueryRoutingConfigsV2(discoverSuit.DefaultCtx, map[string]string{ "limit": "100", "offset": "0", - "namespace": expendItem.RuleRouting.Rules[0].Sources[0].Namespace, - "service": expendItem.RuleRouting.Rules[0].Sources[0].Service, + "namespace": expendItem.RuleRouting.RuleRouting.Rules[0].Sources[0].Namespace, + "service": expendItem.RuleRouting.RuleRouting.Rules[0].Sources[0].Service, }) if !respSuccess(out) { t.Fatalf("error: %+v", out) @@ -601,3 +601,8 @@ func unmarshalRoutingV2toAnySlice(routings []*anypb.Any) ([]*apitraffic.RouteRul return ret, nil } + +func Test_PrintRouteRuleTypeUrl(t *testing.T) { + any, _ := ptypes.MarshalAny(&apitraffic.RuleRoutingConfig{}) + t.Log(any.TypeUrl) +} diff --git a/service/server.go b/service/server.go index 4ea474c49..8c1cd9b3e 100644 --- a/service/server.go +++ b/service/server.go @@ -48,8 +48,8 @@ type Server struct { healthServer *healthcheck.Server - cmdb plugin.CMDB - history plugin.History + cmdb plugin.CMDB + history plugin.History l5service *l5service diff --git a/service/service.go b/service/service.go index 0330b4895..e7b43d0a7 100644 --- a/service/service.go +++ b/service/service.go @@ -381,7 +381,7 @@ func (s *Server) GetServices(ctx context.Context, query map[string]string) *apis } serviceArgs := parseServiceArgs(serviceFilters, serviceMetas, ctx) - total, services, err := s.caches.Service().GetServicesByFilter(serviceArgs, instanceArgs, offset, limit) + total, services, err := s.caches.Service().GetServicesByFilter(ctx, serviceArgs, instanceArgs, offset, limit) if err != nil { log.Errorf("[Server][Service][Query] req(%+v) store err: %s", query, err.Error()) return api.NewBatchQueryResponse(commonstore.StoreCode2APICode(err)) diff --git a/store/admin_api.go b/store/admin_api.go index ca44e4fe1..4114019d0 100644 --- a/store/admin_api.go +++ b/store/admin_api.go @@ -20,7 +20,7 @@ package store import ( "time" - "github.com/polarismesh/polaris/common/model" + "github.com/polarismesh/polaris/common/model/admin" ) const ( @@ -34,7 +34,7 @@ type AdminStore interface { // IsLeader whether it is leader node IsLeader(key string) bool // ListLeaderElections list all leaderelection - ListLeaderElections() ([]*model.LeaderElection, error) + ListLeaderElections() ([]*admin.LeaderElection, error) // ReleaseLeaderElection force release leader status ReleaseLeaderElection(key string) error // BatchCleanDeletedInstances batch clean soft deleted instances diff --git a/store/api.go b/store/api.go index a5034c3d0..575233cca 100644 --- a/store/api.go +++ b/store/api.go @@ -49,6 +49,8 @@ type Store interface { AdminStore // GrayStore mgr gray resource GrayStore + // AuthStore Auth storage interface + AuthStore } // NamespaceStore Namespace storage interface diff --git a/store/auth_api.go b/store/auth_api.go index 7fd1d340f..625df4a34 100644 --- a/store/auth_api.go +++ b/store/auth_api.go @@ -20,92 +20,102 @@ package store import ( "time" - "github.com/polarismesh/polaris/common/model" + authcommon "github.com/polarismesh/polaris/common/model/auth" ) +type AuthStore interface { + // UserStore 用户接口 + UserStore + // GroupStore 用户组接口 + GroupStore + // StrategyStore 鉴权策略接口 + StrategyStore + // RoleStore 角色接口 + RoleStore +} + // UserStore User-related operation interface type UserStore interface { // AddUser Create a user - AddUser(user *model.User) error + AddUser(tx Tx, user *authcommon.User) error // UpdateUser Update user - UpdateUser(user *model.User) error + UpdateUser(user *authcommon.User) error // DeleteUser delete users - DeleteUser(user *model.User) error + DeleteUser(tx Tx, user *authcommon.User) error // GetSubCount Number of getting a child account - GetSubCount(user *model.User) (uint32, error) + GetSubCount(user *authcommon.User) (uint32, error) // GetUser Obtain user - GetUser(id string) (*model.User, error) + GetUser(id string) (*authcommon.User, error) // GetUserByName Get a unique user according to Name + Owner - GetUserByName(name, ownerId string) (*model.User, error) + GetUserByName(name, ownerId string) (*authcommon.User, error) // GetUserByIDS Get users according to USER IDS batch - GetUserByIds(ids []string) ([]*model.User, error) + GetUserByIds(ids []string) ([]*authcommon.User, error) // GetUsers Query user list - GetUsers(filters map[string]string, offset uint32, limit uint32) (uint32, []*model.User, error) + GetUsers(filters map[string]string, offset uint32, limit uint32) (uint32, []*authcommon.User, error) // GetUsersForCache Used to refresh user cache // 此方法用于 cache 增量更新,需要注意 mtime 应为数据库时间戳 - GetUsersForCache(mtime time.Time, firstUpdate bool) ([]*model.User, error) + GetUsersForCache(mtime time.Time, firstUpdate bool) ([]*authcommon.User, error) } // GroupStore User group storage operation interface type GroupStore interface { - // AddGroup Add a user group - AddGroup(group *model.UserGroupDetail) error - + AddGroup(tx Tx, group *authcommon.UserGroupDetail) error // UpdateGroup Update user group - UpdateGroup(group *model.ModifyUserGroup) error - + UpdateGroup(group *authcommon.ModifyUserGroup) error // DeleteGroup Delete user group - DeleteGroup(group *model.UserGroupDetail) error - + DeleteGroup(tx Tx, group *authcommon.UserGroupDetail) error // GetGroup Get user group details - GetGroup(id string) (*model.UserGroupDetail, error) - + GetGroup(id string) (*authcommon.UserGroupDetail, error) // GetGroupByName Get user groups according to Name and Owner - GetGroupByName(name, owner string) (*model.UserGroup, error) - + GetGroupByName(name, owner string) (*authcommon.UserGroup, error) // GetGroups Get a list of user groups - GetGroups(filters map[string]string, offset uint32, limit uint32) (uint32, []*model.UserGroup, error) - + GetGroups(filters map[string]string, offset uint32, limit uint32) (uint32, []*authcommon.UserGroup, error) // GetUserGroupsForCache Refresh of getting user groups for cache // 此方法用于 cache 增量更新,需要注意 mtime 应为数据库时间戳 - GetGroupsForCache(mtime time.Time, firstUpdate bool) ([]*model.UserGroupDetail, error) + GetGroupsForCache(mtime time.Time, firstUpdate bool) ([]*authcommon.UserGroupDetail, error) } // StrategyStore Authentication policy related storage operation interface type StrategyStore interface { - // AddStrategy Create authentication strategy - AddStrategy(strategy *model.StrategyDetail) error - + AddStrategy(tx Tx, strategy *authcommon.StrategyDetail) error // UpdateStrategy Update authentication strategy - UpdateStrategy(strategy *model.ModifyStrategyDetail) error - + UpdateStrategy(strategy *authcommon.ModifyStrategyDetail) error // DeleteStrategy Delete authentication strategy DeleteStrategy(id string) error - + // CleanPrincipalPolicies Clean all the policies associated with the principal + CleanPrincipalPolicies(tx Tx, p authcommon.Principal) error // LooseAddStrategyResources Song requires the resources of the authentication strategy, // allowing the issue of ignoring the primary key conflict - LooseAddStrategyResources(resources []model.StrategyResource) error - + LooseAddStrategyResources(resources []authcommon.StrategyResource) error // RemoveStrategyResources Clean all the strategies associated with corresponding resources - RemoveStrategyResources(resources []model.StrategyResource) error - + RemoveStrategyResources(resources []authcommon.StrategyResource) error // GetStrategyResources Gets a Principal's corresponding resource ID data information - GetStrategyResources(principalId string, principalRole model.PrincipalType) ([]model.StrategyResource, error) - + GetStrategyResources(principalId string, principalRole authcommon.PrincipalType) ([]authcommon.StrategyResource, error) // GetDefaultStrategyDetailByPrincipal Get a default policy for a Principal GetDefaultStrategyDetailByPrincipal(principalId string, - principalType model.PrincipalType) (*model.StrategyDetail, error) - + principalType authcommon.PrincipalType) (*authcommon.StrategyDetail, error) // GetStrategyDetail Get strategy details - GetStrategyDetail(id string) (*model.StrategyDetail, error) - + GetStrategyDetail(id string) (*authcommon.StrategyDetail, error) // GetStrategies Get a list of strategies GetStrategies(filters map[string]string, offset uint32, limit uint32) (uint32, - []*model.StrategyDetail, error) - - // GetStrategyDetailsForCache Used to refresh policy cache + []*authcommon.StrategyDetail, error) + // GetMoreStrategies Used to refresh policy cache // 此方法用于 cache 增量更新,需要注意 mtime 应为数据库时间戳 - GetStrategyDetailsForCache(mtime time.Time, firstUpdate bool) ([]*model.StrategyDetail, error) + GetMoreStrategies(mtime time.Time, firstUpdate bool) ([]*authcommon.StrategyDetail, error) +} + +// RoleStore Role related storage operation interface +type RoleStore interface { + // AddRole Add a role + AddRole(role *authcommon.Role) error + // UpdateRole Update a role + UpdateRole(role *authcommon.Role) error + // DeleteRole Delete a role + DeleteRole(role *authcommon.Role) error + // CleanPrincipalRoles Clean all the roles associated with the principal + CleanPrincipalRoles(tx Tx, p *authcommon.Principal) error + // GetRole get more role for cache update + GetMoreRoles(firstUpdate bool, modifyTime time.Time) ([]*authcommon.Role, error) } diff --git a/store/boltdb/admin.go b/store/boltdb/admin.go index 09eafc1d1..4d536c97a 100644 --- a/store/boltdb/admin.go +++ b/store/boltdb/admin.go @@ -26,6 +26,7 @@ import ( "github.com/polarismesh/polaris/common/eventhub" "github.com/polarismesh/polaris/common/model" + "github.com/polarismesh/polaris/common/model/admin" "github.com/polarismesh/polaris/common/utils" "github.com/polarismesh/polaris/store" ) @@ -63,13 +64,13 @@ func (m *adminStore) IsLeader(key string) bool { } // ListLeaderElections -func (m *adminStore) ListLeaderElections() ([]*model.LeaderElection, error) { +func (m *adminStore) ListLeaderElections() ([]*admin.LeaderElection, error) { m.mutex.Lock() defer m.mutex.Unlock() - var out []*model.LeaderElection + var out []*admin.LeaderElection for k, v := range m.leMap { - item := &model.LeaderElection{ + item := &admin.LeaderElection{ ElectKey: k, Host: utils.LocalHost, Ctime: 0, diff --git a/store/boltdb/default.go b/store/boltdb/default.go index 85ca9eea3..17907b923 100644 --- a/store/boltdb/default.go +++ b/store/boltdb/default.go @@ -25,6 +25,7 @@ import ( "go.uber.org/zap" "github.com/polarismesh/polaris/common/model" + authcommon "github.com/polarismesh/polaris/common/model/auth" "github.com/polarismesh/polaris/common/utils" "github.com/polarismesh/polaris/store" ) @@ -81,6 +82,8 @@ type boltStore struct { *configFileReleaseHistoryStore *configFileTemplateStore + *grayStore + // adminStore store *adminStore // 工具 @@ -89,7 +92,7 @@ type boltStore struct { *userStore *groupStore *strategyStore - *grayStore + *roleStore handler BoltHandler start bool @@ -141,7 +144,7 @@ var ( "polaris.checker": "fbca9bfa04ae4ead86e1ecf5811e32a9", } - mainUser = &model.User{ + mainUser = &authcommon.User{ ID: "65e4789a6d5b49669adf1e9e8387549c", Name: "polaris", Password: "$2a$10$3izWuZtE5SBdAtSZci.gs.iZ2pAn9I8hEqYrC6gwJp1dyjqQnrrum", @@ -158,21 +161,21 @@ var ( ModifyTime: time.Now(), } - superDefaultStrategy = &model.StrategyDetail{ + superDefaultStrategy = &authcommon.StrategyDetail{ ID: "super_user_default_strategy", Name: "(用户) polarissys@admin的默认策略", Action: "READ_WRITE", Comment: "default admin", - Principals: []model.Principal{ + Principals: []authcommon.Principal{ { StrategyID: "super_user_default_strategy", PrincipalID: "", - PrincipalRole: model.PrincipalUser, + PrincipalType: authcommon.PrincipalUser, }, }, Default: true, Owner: "", - Resources: []model.StrategyResource{ + Resources: []authcommon.StrategyResource{ { StrategyID: "super_user_default_strategy", ResType: int32(apisecurity.ResourceType_Namespaces), @@ -195,21 +198,21 @@ var ( ModifyTime: time.Now(), } - mainDefaultStrategy = &model.StrategyDetail{ + mainDefaultStrategy = &authcommon.StrategyDetail{ ID: "fbca9bfa04ae4ead86e1ecf5811e32a9", Name: "(用户) polaris的默认策略", Action: "READ_WRITE", Comment: "default admin", - Principals: []model.Principal{ + Principals: []authcommon.Principal{ { StrategyID: "fbca9bfa04ae4ead86e1ecf5811e32a9", PrincipalID: "65e4789a6d5b49669adf1e9e8387549c", - PrincipalRole: model.PrincipalUser, + PrincipalType: authcommon.PrincipalUser, }, }, Default: true, Owner: "65e4789a6d5b49669adf1e9e8387549c", - Resources: []model.StrategyResource{ + Resources: []authcommon.StrategyResource{ { StrategyID: "fbca9bfa04ae4ead86e1ecf5811e32a9", ResType: int32(apisecurity.ResourceType_Namespaces), diff --git a/store/boltdb/group.go b/store/boltdb/group.go index 4b5e6832d..09cf886f5 100644 --- a/store/boltdb/group.go +++ b/store/boltdb/group.go @@ -27,7 +27,7 @@ import ( bolt "go.etcd.io/bbolt" "go.uber.org/zap" - "github.com/polarismesh/polaris/common/model" + authcommon "github.com/polarismesh/polaris/common/model/auth" "github.com/polarismesh/polaris/common/utils" "github.com/polarismesh/polaris/store" ) @@ -73,66 +73,37 @@ type groupStore struct { } // AddGroup add a group -func (gs *groupStore) AddGroup(group *model.UserGroupDetail) error { +func (gs *groupStore) AddGroup(tx store.Tx, group *authcommon.UserGroupDetail) error { if group.ID == "" || group.Name == "" || group.Token == "" { return store.NewStatusError(store.EmptyParamsErr, fmt.Sprintf( "add usergroup missing some params, groupId is %s, name is %s", group.ID, group.Name)) } - proxy, err := gs.handler.StartTx() - if err != nil { - return err - } - tx := proxy.GetDelegateTx().(*bolt.Tx) + dbTx := tx.GetDelegateTx().(*bolt.Tx) - defer func() { - _ = tx.Rollback() - }() - - if err := gs.cleanInValidGroup(tx, group.Name, group.Owner); err != nil { + if err := gs.cleanInValidGroup(dbTx, group.Name, group.Owner); err != nil { log.Error("[Store][Group] clean invalid usergroup", zap.Error(err), zap.String("name", group.Name), zap.String("owner", group.Owner)) return err } - return gs.addGroup(tx, group) -} - -// addGroup to boltdb -func (gs *groupStore) addGroup(tx *bolt.Tx, group *model.UserGroupDetail) error { - group.Valid = true group.CreateTime = time.Now() group.ModifyTime = group.CreateTime data := convertForGroupStore(group) - if err := saveValue(tx, tblGroup, data.ID, data); err != nil { + if err := saveValue(dbTx, tblGroup, data.ID, data); err != nil { log.Error("[Store][Group] save usergroup", zap.Error(err), zap.String("name", group.Name), zap.String("owner", group.Owner)) return err } - - if err := createDefaultStrategy(tx, model.PrincipalGroup, data.ID, data.Name, - data.Owner); err != nil { - log.Error("[Store][Group] add usergroup default strategy", zap.Error(err), - zap.String("name", group.Name), zap.String("owner", group.Owner)) - - return err - } - - if err := tx.Commit(); err != nil { - log.Error("[Store][Group] add usergroup tx commit", zap.Error(err), - zap.String("name", group.Name), zap.String("owner", group.Owner)) - return err - } - return nil } // UpdateGroup update a group -func (gs *groupStore) UpdateGroup(group *model.ModifyUserGroup) error { +func (gs *groupStore) UpdateGroup(group *authcommon.ModifyUserGroup) error { if group.ID == "" { return store.NewStatusError(store.EmptyParamsErr, fmt.Sprintf( "update usergroup missing some params, groupId is %s", group.ID)) @@ -141,7 +112,7 @@ func (gs *groupStore) UpdateGroup(group *model.ModifyUserGroup) error { return gs.updateGroup(group) } -func (gs *groupStore) updateGroup(group *model.ModifyUserGroup) error { +func (gs *groupStore) updateGroup(group *authcommon.ModifyUserGroup) error { proxy, err := gs.handler.StartTx() if err != nil { return err @@ -166,7 +137,7 @@ func (gs *groupStore) updateGroup(group *model.ModifyUserGroup) error { return ErrorMultipleGroupFound } - var ret *model.UserGroupDetail + var ret *authcommon.UserGroupDetail for _, v := range values { ret = convertForGroupDetail(v.(*groupForStore)) break @@ -193,7 +164,7 @@ func (gs *groupStore) updateGroup(group *model.ModifyUserGroup) error { } // updateGroupRelation 更新用户组的关联关系数据 -func updateGroupRelation(group *model.UserGroupDetail, modify *model.ModifyUserGroup) { +func updateGroupRelation(group *authcommon.UserGroupDetail, modify *authcommon.ModifyUserGroup) { for i := range modify.AddUserIds { group.UserIds[modify.AddUserIds[i]] = struct{}{} } @@ -204,52 +175,26 @@ func updateGroupRelation(group *model.UserGroupDetail, modify *model.ModifyUserG } // DeleteGroup 删除用户组 -func (gs *groupStore) DeleteGroup(group *model.UserGroupDetail) error { +func (gs *groupStore) DeleteGroup(tx store.Tx, group *authcommon.UserGroupDetail) error { if group.ID == "" { return store.NewStatusError(store.EmptyParamsErr, fmt.Sprintf( "delete usergroup missing some params, groupId is %s", group.ID)) } - - return gs.deleteGroup(group) -} - -func (gs *groupStore) deleteGroup(group *model.UserGroupDetail) error { - proxy, err := gs.handler.StartTx() - if err != nil { - return err - } - tx := proxy.GetDelegateTx().(*bolt.Tx) - - defer func() { - _ = tx.Rollback() - }() + dbTx := tx.GetDelegateTx().(*bolt.Tx) properties := make(map[string]interface{}) properties[GroupFieldValid] = false properties[GroupFieldModifyTime] = time.Now() - if err := updateValue(tx, tblGroup, group.ID, properties); err != nil { + if err := updateValue(dbTx, tblGroup, group.ID, properties); err != nil { log.Error("[Store][Group] remove usergroup", zap.Error(err), zap.String("id", group.ID)) - return err - } - - if err := cleanLinkStrategy(tx, model.PrincipalGroup, group.ID, group.Owner); err != nil { - log.Error("[Store][Group] clean usergroup default strategy", - zap.Error(err), zap.String("id", group.ID)) - return err + return store.Error(err) } - - if err := tx.Commit(); err != nil { - log.Error("[Store][Group] delete usergroupr tx commit", - zap.Error(err), zap.String("id", group.ID)) - return err - } - return nil } // GetGroup get a group -func (gs *groupStore) GetGroup(groupID string) (*model.UserGroupDetail, error) { +func (gs *groupStore) GetGroup(groupID string) (*authcommon.UserGroupDetail, error) { if groupID == "" { return nil, store.NewStatusError(store.EmptyParamsErr, fmt.Sprintf( "get usergroup missing some params, groupID is %s", groupID)) @@ -269,7 +214,7 @@ func (gs *groupStore) GetGroup(groupID string) (*model.UserGroupDetail, error) { return nil, ErrorMultipleGroupFound } - var ret *model.UserGroupDetail + var ret *authcommon.UserGroupDetail for _, v := range values { ret = convertForGroupDetail(v.(*groupForStore)) break @@ -283,7 +228,7 @@ func (gs *groupStore) GetGroup(groupID string) (*model.UserGroupDetail, error) { } // GetGroupByName get a group by name -func (gs *groupStore) GetGroupByName(name, owner string) (*model.UserGroup, error) { +func (gs *groupStore) GetGroupByName(name, owner string) (*authcommon.UserGroup, error) { if name == "" || owner == "" { return nil, store.NewStatusError(store.EmptyParamsErr, fmt.Sprintf( "get usergroup missing some params, name=%s, owner=%s", name, owner)) @@ -314,7 +259,7 @@ func (gs *groupStore) GetGroupByName(name, owner string) (*model.UserGroup, erro if len(values) > 1 { return nil, ErrorMultipleGroupFound } - var ret *model.UserGroupDetail + var ret *authcommon.UserGroupDetail for _, v := range values { ret = convertForGroupDetail(v.(*groupForStore)) break @@ -324,7 +269,7 @@ func (gs *groupStore) GetGroupByName(name, owner string) (*model.UserGroup, erro // GetGroups get groups func (gs *groupStore) GetGroups(filters map[string]string, offset uint32, - limit uint32) (uint32, []*model.UserGroup, error) { + limit uint32) (uint32, []*authcommon.UserGroup, error) { // 如果本次请求参数携带了 user_id,那么就是查询这个用户所关联的所有用户组 if _, ok := filters["user_id"]; ok { @@ -336,7 +281,7 @@ func (gs *groupStore) GetGroups(filters map[string]string, offset uint32, // listSimpleGroups Normal user group query func (gs *groupStore) listSimpleGroups(filters map[string]string, offset uint32, limit uint32) (uint32, - []*model.UserGroup, error) { + []*authcommon.UserGroup, error) { fields := []string{GroupFieldID, GroupFieldOwner, GroupFieldName, GroupFieldValid} values, err := gs.handler.LoadValuesByFilter(tblGroup, fields, &groupForStore{}, func(m map[string]interface{}) bool { @@ -379,7 +324,7 @@ func (gs *groupStore) listSimpleGroups(filters map[string]string, offset uint32, // listGroupByUser 查询某个用户下所关联的用户组信息 func (gs *groupStore) listGroupByUser(filters map[string]string, offset uint32, - limit uint32) (uint32, []*model.UserGroup, error) { + limit uint32) (uint32, []*authcommon.UserGroup, error) { var ( userID = filters["user_id"] @@ -427,9 +372,9 @@ func (gs *groupStore) listGroupByUser(filters map[string]string, offset uint32, return total, doGroupPage(values, offset, limit), nil } -func doGroupPage(ret map[string]interface{}, offset uint32, limit uint32) []*model.UserGroup { +func doGroupPage(ret map[string]interface{}, offset uint32, limit uint32) []*authcommon.UserGroup { - groups := make([]*model.UserGroup, 0, len(ret)) + groups := make([]*authcommon.UserGroup, 0, len(ret)) beginIndex := offset endIndex := beginIndex + limit @@ -459,7 +404,7 @@ func doGroupPage(ret map[string]interface{}, offset uint32, limit uint32) []*mod } // GetGroupsForCache 查询用户分组数据,主要用于Cache更新 -func (gs *groupStore) GetGroupsForCache(mtime time.Time, firstUpdate bool) ([]*model.UserGroupDetail, error) { +func (gs *groupStore) GetGroupsForCache(mtime time.Time, firstUpdate bool) ([]*authcommon.UserGroupDetail, error) { ret, err := gs.handler.LoadValuesByFilter(tblGroup, []string{GroupFieldModifyTime}, &groupForStore{}, func(m map[string]interface{}) bool { mt := m[GroupFieldModifyTime].(time.Time) @@ -470,7 +415,7 @@ func (gs *groupStore) GetGroupsForCache(mtime time.Time, firstUpdate bool) ([]*m return nil, err } - groups := make([]*model.UserGroupDetail, 0, len(ret)) + groups := make([]*authcommon.UserGroupDetail, 0, len(ret)) for k := range ret { val := ret[k] @@ -518,7 +463,7 @@ func (gs *groupStore) cleanInValidGroup(tx *bolt.Tx, name, owner string) error { return deleteValues(tx, tblGroup, keys) } -func convertForGroupStore(group *model.UserGroupDetail) *groupForStore { +func convertForGroupStore(group *authcommon.UserGroupDetail) *groupForStore { userIds := make(map[string]string, len(group.UserIds)) @@ -540,14 +485,14 @@ func convertForGroupStore(group *model.UserGroupDetail) *groupForStore { } } -func convertForGroupDetail(group *groupForStore) *model.UserGroupDetail { +func convertForGroupDetail(group *groupForStore) *authcommon.UserGroupDetail { userIds := make(map[string]struct{}, len(group.UserIds)) for id := range group.UserIds { userIds[id] = struct{}{} } - return &model.UserGroupDetail{ - UserGroup: &model.UserGroup{ + return &authcommon.UserGroupDetail{ + UserGroup: &authcommon.UserGroup{ ID: group.ID, Name: group.Name, Owner: group.Owner, diff --git a/store/boltdb/group_test.go b/store/boltdb/group_test.go index d0f7b2ebb..264ba997f 100644 --- a/store/boltdb/group_test.go +++ b/store/boltdb/group_test.go @@ -24,10 +24,10 @@ import ( "github.com/stretchr/testify/assert" - "github.com/polarismesh/polaris/common/model" + authcommon "github.com/polarismesh/polaris/common/model/auth" ) -func buildUserIds(users []*model.User) map[string]struct{} { +func buildUserIds(users []*authcommon.User) map[string]struct{} { ret := make(map[string]struct{}, len(users)) for i := range users { @@ -38,14 +38,14 @@ func buildUserIds(users []*model.User) map[string]struct{} { return ret } -func createTestUserGroup(num int) []*model.UserGroupDetail { - ret := make([]*model.UserGroupDetail, 0, num) +func createTestUserGroup(num int) []*authcommon.UserGroupDetail { + ret := make([]*authcommon.UserGroupDetail, 0, num) users := createTestUsers(num) for i := 0; i < num; i++ { - ret = append(ret, &model.UserGroupDetail{ - UserGroup: &model.UserGroup{ + ret = append(ret, &authcommon.UserGroupDetail{ + UserGroup: &authcommon.UserGroup{ ID: fmt.Sprintf("test_group_%d", i), Name: fmt.Sprintf("test_group_%d", i), Owner: "polaris", @@ -69,9 +69,12 @@ func Test_groupStore_AddGroup(t *testing.T) { groups := createTestUserGroup(1) - if err := gs.AddGroup(groups[0]); err != nil { + tx, err := handler.StartTx() + assert.NoError(t, err) + if err := gs.AddGroup(tx, groups[0]); err != nil { t.Fatal(err) } + assert.NoError(t, tx.Commit()) ret, err := gs.GetGroup(groups[0].ID) if err != nil { @@ -94,13 +97,16 @@ func Test_groupStore_UpdateGroup(t *testing.T) { groups := createTestUserGroup(1) - if err := gs.AddGroup(groups[0]); err != nil { + tx, err := handler.StartTx() + assert.NoError(t, err) + if err := gs.AddGroup(tx, groups[0]); err != nil { t.Fatal(err) } + assert.NoError(t, tx.Commit()) groups[0].Comment = time.Now().String() - if err := gs.UpdateGroup(&model.ModifyUserGroup{ + if err := gs.UpdateGroup(&authcommon.ModifyUserGroup{ ID: groups[0].ID, Owner: groups[0].Owner, Token: groups[0].Token, @@ -131,15 +137,21 @@ func Test_groupStore_DeleteGroup(t *testing.T) { groups := createTestUserGroup(1) - if err := gs.AddGroup(groups[0]); err != nil { + tx, err := handler.StartTx() + assert.NoError(t, err) + if err := gs.AddGroup(tx, groups[0]); err != nil { t.Fatal(err) } + assert.NoError(t, tx.Commit()) groups[0].Comment = time.Now().String() - if err := gs.DeleteGroup(groups[0]); err != nil { + tx, err = handler.StartTx() + assert.NoError(t, err) + if err := gs.DeleteGroup(tx, groups[0]); err != nil { t.Fatal(err) } + assert.NoError(t, tx.Commit()) ret, err := gs.GetGroup(groups[0].ID) if err != nil { @@ -156,9 +168,12 @@ func Test_groupStore_GetGroupByName(t *testing.T) { groups := createTestUserGroup(1) - if err := gs.AddGroup(groups[0]); err != nil { + tx, err := handler.StartTx() + assert.NoError(t, err) + if err := gs.AddGroup(tx, groups[0]); err != nil { t.Fatal(err) } + assert.NoError(t, tx.Commit()) ret, err := gs.GetGroupByName(groups[0].Name, groups[0].Owner) if err != nil { @@ -182,9 +197,12 @@ func Test_groupStore_GetGroups(t *testing.T) { groups := createTestUserGroup(10) for i := range groups { - if err := gs.AddGroup(groups[i]); err != nil { + tx, err := handler.StartTx() + assert.NoError(t, err) + if err := gs.AddGroup(tx, groups[i]); err != nil { t.Fatal(err) } + assert.NoError(t, tx.Commit()) } total, ret, err := gs.GetGroups(map[string]string{ diff --git a/store/boltdb/load.go b/store/boltdb/load.go index 56514b257..41cfdbd6f 100644 --- a/store/boltdb/load.go +++ b/store/boltdb/load.go @@ -29,6 +29,7 @@ import ( "gopkg.in/yaml.v2" "github.com/polarismesh/polaris/common/model" + authcommon "github.com/polarismesh/polaris/common/model/auth" ) func (m *boltStore) loadByDefault() error { @@ -46,7 +47,7 @@ func (m *boltStore) loadByDefault() error { // DefaultData 默认数据信息 type DefaultData struct { Namespaces []*model.Namespace `yaml:"namespaces"` - Users []*model.User `yaml:"users"` + Users []*authcommon.User `yaml:"users"` } func (m *boltStore) loadByFile(loadFile string) error { @@ -90,16 +91,16 @@ func (m *boltStore) loadFromData(data *DefaultData) error { tn := time.Now() var ( - superUser, mainUser *model.User + superUser, mainUser *authcommon.User ) - if len(users) >= 2 && users[0].Type == model.AdminUserRole && users[1].Type == model.OwnerUserRole { + if len(users) >= 2 && users[0].Type == authcommon.AdminUserRole && users[1].Type == authcommon.OwnerUserRole { superUser = users[0] superUser.CreateTime = tn superUser.ModifyTime = tn mainUser = users[1] mainUser.CreateTime = tn mainUser.ModifyTime = tn - } else if users[0].Type == model.OwnerUserRole { + } else if users[0].Type == authcommon.OwnerUserRole { mainUser = users[0] mainUser.CreateTime = tn mainUser.ModifyTime = tn @@ -108,7 +109,7 @@ func (m *boltStore) loadFromData(data *DefaultData) error { } if err := m.handler.Execute(true, func(tx *bolt.Tx) error { - saveFunc := func(user *model.User, rule *model.StrategyDetail) error { + saveFunc := func(user *authcommon.User, rule *authcommon.StrategyDetail) error { rule.Owner = user.ID rule.Principals[0].PrincipalID = user.ID saveUser, err := m.getUser(tx, user.ID) @@ -154,11 +155,18 @@ func (m *boltStore) loadFromData(data *DefaultData) error { return err } + tx, err := m.handle.StartTx() + if err != nil { + return err + } + defer func() { + _ = tx.Rollback() + }() // 挨个处理其他用户数据信息 for i := 1; i < len(users); i++ { - if err := m.addUser(users[i]); err != nil { + if err := m.AddUser(tx, users[i]); err != nil { return nil } } - return nil + return tx.Commit() } diff --git a/store/boltdb/role.go b/store/boltdb/role.go new file mode 100644 index 000000000..21b8a5085 --- /dev/null +++ b/store/boltdb/role.go @@ -0,0 +1,271 @@ +/** + * Tencent is pleased to support the open source community by making Polaris available. + * + * Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. + * + * Licensed under the BSD 3-Clause License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://opensource.org/licenses/BSD-3-Clause + * + * Unless required by applicable law or agreed to in writing, software distributed + * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR + * CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package boltdb + +import ( + "encoding/json" + "time" + + "github.com/polarismesh/polaris/common/model" + authcommon "github.com/polarismesh/polaris/common/model/auth" + "github.com/polarismesh/polaris/common/utils" + "github.com/polarismesh/polaris/store" + bolt "go.etcd.io/bbolt" + "go.uber.org/zap" +) + +var _ store.RoleStore = (*roleStore)(nil) + +const ( + // tblRole 角色表数据 + tblRole string = "role" + + roleFieldUsers string = "Users" + roleFieldUserGroups string = "UserGroups" + roleFieldMetadata string = "Metadata" + roleFieldType string = "Type" + roleFieldSource string = "Source" +) + +type roleStore struct { + handle BoltHandler +} + +// AddRole Add a role +func (s *roleStore) AddRole(role *authcommon.Role) error { + if role.ID == "" || role.Name == "" { + log.Error("[Store][role] create role missing some params") + return ErrBadParam + } + + data := newRoleData(role) + data.CreateTime = time.Now() + data.ModifyTime = time.Now() + data.Valid = true + + err := s.handle.Execute(true, func(tx *bolt.Tx) error { + if err := s.cleanInvalidRole(tx, data.ID); err != nil { + return err + } + return saveValue(tx, tblRole, data.ID, data) + }) + if err != nil { + log.Error("[Store][role] create role failed", zap.String("name", role.Name), zap.Error(err)) + return store.Error(err) + } + return nil +} + +// cleanInvalidRole 删除无效的角色数据信息 +func (s *roleStore) cleanInvalidRole(tx *bolt.Tx, id string) error { + err := deleteValues(tx, tblRole, []string{id}) + if err != nil { + log.Errorf("[Store][role] delete invalid role error, %v", err) + return err + } + return nil +} + +// UpdateRole Update a role +func (s *roleStore) UpdateRole(role *authcommon.Role) error { + if role.ID == "" { + log.Error("[Store][role] update role missing some params") + return ErrBadParam + } + + data := newRoleData(role) + + err := s.handle.Execute(true, func(tx *bolt.Tx) error { + properties := map[string]interface{}{ + CommonFieldValid: true, + CommonFieldModifyTime: time.Now(), + CommonFieldDescription: data.Description, + roleFieldType: data.Type, + roleFieldMetadata: data.Metadata, + roleFieldSource: data.Source, + } + return updateValue(tx, tblRole, data.ID, properties) + }) + if err != nil { + log.Error("[Store][role] update role failed", zap.String("name", role.Name), zap.Error(err)) + return store.Error(err) + } + return nil +} + +// DeleteRole Delete a role +func (s *roleStore) DeleteRole(role *authcommon.Role) error { + if role.ID == "" { + log.Error("[Store][role] delete role missing some params") + return ErrBadParam + } + + data := newRoleData(role) + + err := s.handle.Execute(true, func(tx *bolt.Tx) error { + properties := map[string]interface{}{ + CommonFieldValid: false, + CommonFieldModifyTime: time.Now(), + } + return updateValue(tx, tblRole, data.ID, properties) + }) + if err != nil { + log.Error("[Store][role] delete role failed", zap.String("name", role.Name), zap.Error(err)) + return store.Error(err) + } + return nil +} + +// CleanPrincipalRoles clean principal roles +func (s *roleStore) CleanPrincipalRoles(tx store.Tx, p *authcommon.Principal) error { + dbTx := tx.GetDelegateTx().(*bolt.Tx) + fields := []string{roleFieldUsers, roleFieldUserGroups, CommonFieldValid, CommonFieldID} + values := map[string]interface{}{} + + updateDatas := map[string]map[string]interface{}{} + + err := loadValuesByFilter(dbTx, tblRole, fields, &roleData{}, + func(m map[string]interface{}) bool { + valid, _ := m[CommonFieldValid].(bool) + if !valid { + return false + } + switch p.PrincipalType { + case authcommon.PrincipalUser: + users := make([]*authcommon.User, 0, 4) + _ = json.Unmarshal([]byte(m[roleFieldUsers].(string)), &users) + finalUsers := make([]*authcommon.User, 0, len(users)) + for i := range users { + if users[i].ID == p.PrincipalID { + continue + } + finalUsers = append(finalUsers, users[i]) + } + updateDatas[m[CommonFieldID].(string)] = map[string]interface{}{ + roleFieldUsers: utils.MustJson(users), + CommonFieldModifyTime: time.Now(), + } + case authcommon.PrincipalGroup: + groups := make([]*authcommon.UserGroup, 0, 4) + _ = json.Unmarshal([]byte(m[roleFieldUserGroups].(string)), &groups) + finalGroups := make([]*authcommon.UserGroup, 0, len(groups)) + for i := range groups { + if groups[i].ID == p.PrincipalID { + continue + } + finalGroups = append(finalGroups, groups[i]) + } + updateDatas[m[CommonFieldID].(string)] = map[string]interface{}{ + roleFieldUserGroups: utils.MustJson(groups), + CommonFieldModifyTime: time.Now(), + } + } + return false + }, values) + + if err != nil { + log.Error("[Store][role] get principal all role", zap.String("principal", p.String()), zap.Error(err)) + return store.Error(err) + } + + for id := range updateDatas { + if err := updateValue(dbTx, tblRole, id, updateDatas[id]); err != nil { + log.Error("[store][role] clean principal all roles", zap.String("principal", p.String()), zap.Error(err)) + return store.Error(err) + } + } + return nil +} + +// GetMoreRoles get more role for cache update +func (s *roleStore) GetMoreRoles(firstUpdate bool, mtime time.Time) ([]*authcommon.Role, error) { + fields := []string{CommonFieldModifyTime, CommonFieldValid} + + ret, err := s.handle.LoadValuesByFilter(tblRole, fields, &model.RoutingConfig{}, + func(m map[string]interface{}) bool { + if firstUpdate { + valid, _ := m[CommonFieldValid].(bool) + if valid { + return true + } + } + saveMtime, _ := m[CommonFieldModifyTime].(time.Time) + return !saveMtime.Before(mtime) + }) + if err != nil { + log.Errorf("[Store][role] get more role for cache, %v", err) + return nil, store.Error(err) + } + + roles := make([]*authcommon.Role, 0, len(ret)) + for i := range ret { + roles = append(roles, newRole(ret[i].(*roleData))) + } + return roles, nil +} + +type roleData struct { + ID string + Name string + Owner string + Source string + Type string + Metadata map[string]string + Valid bool + Description string + CreateTime time.Time + ModifyTime time.Time + Users string + UserGroups string +} + +func newRoleData(r *authcommon.Role) *roleData { + return &roleData{ + ID: r.ID, + Name: r.Name, + Owner: r.Owner, + Source: r.Source, + Type: r.Type, + Metadata: r.Metadata, + Description: r.Comment, + Users: utils.MustJson(r.Users), + UserGroups: utils.MustJson(r.UserGroups), + } +} + +func newRole(r *roleData) *authcommon.Role { + users := make([]*authcommon.User, 0, 32) + groups := make([]*authcommon.UserGroup, 0, 32) + + _ = json.Unmarshal([]byte(r.Users), &users) + _ = json.Unmarshal([]byte(r.UserGroups), &groups) + + return &authcommon.Role{ + ID: r.ID, + Name: r.Name, + Owner: r.Owner, + Source: r.Source, + Type: r.Type, + Metadata: r.Metadata, + Comment: r.Description, + Users: users, + UserGroups: groups, + CreateTime: r.CreateTime, + ModifyTime: r.ModifyTime, + } +} diff --git a/store/boltdb/strategy.go b/store/boltdb/strategy.go index d9f6cc926..1a028c9a8 100644 --- a/store/boltdb/strategy.go +++ b/store/boltdb/strategy.go @@ -28,7 +28,7 @@ import ( bolt "go.etcd.io/bbolt" "go.uber.org/zap" - "github.com/polarismesh/polaris/common/model" + authcommon "github.com/polarismesh/polaris/common/model/auth" "github.com/polarismesh/polaris/common/utils" "github.com/polarismesh/polaris/store" ) @@ -82,7 +82,7 @@ type strategyStore struct { } // AddStrategy add a new strategy -func (ss *strategyStore) AddStrategy(strategy *model.StrategyDetail) error { +func (ss *strategyStore) AddStrategy(tx store.Tx, strategy *authcommon.StrategyDetail) error { if strategy.ID == "" || strategy.Name == "" || strategy.Owner == "" { return store.NewStatusError(store.EmptyParamsErr, fmt.Sprintf( "add auth_strategy missing some params, id is %s, name is %s, owner is %s", @@ -90,44 +90,24 @@ func (ss *strategyStore) AddStrategy(strategy *model.StrategyDetail) error { } initStrategy(strategy) + dbTx := tx.GetDelegateTx().(*bolt.Tx) - proxy, err := ss.handler.StartTx() - if err != nil { - return err - } - tx := proxy.GetDelegateTx().(*bolt.Tx) - - defer func() { - _ = tx.Rollback() - }() - - return ss.addStrategy(tx, strategy) -} - -func (ss *strategyStore) addStrategy(tx *bolt.Tx, strategy *model.StrategyDetail) error { - if err := ss.cleanInvalidStrategy(tx, strategy.Name, strategy.Owner); err != nil { + if err := ss.cleanInvalidStrategy(dbTx, strategy.Name, strategy.Owner); err != nil { log.Error("[Store][Strategy] clean invalid auth_strategy", zap.Error(err), zap.String("name", strategy.Name), zap.Any("owner", strategy.Owner)) return err } - if err := saveValue(tx, tblStrategy, strategy.ID, convertForStrategyStore(strategy)); err != nil { + if err := saveValue(dbTx, tblStrategy, strategy.ID, convertForStrategyStore(strategy)); err != nil { log.Error("[Store][Strategy] save auth_strategy", zap.Error(err), zap.String("name", strategy.Name), zap.String("owner", strategy.Owner)) return err } - - if err := tx.Commit(); err != nil { - log.Error("[Store][Strategy] clean invalid auth_strategy tx commit", zap.Error(err), - zap.String("name", strategy.Name), zap.String("owner", strategy.Owner)) - return err - } - return nil } // UpdateStrategy update a strategy -func (ss *strategyStore) UpdateStrategy(strategy *model.ModifyStrategyDetail) error { +func (ss *strategyStore) UpdateStrategy(strategy *authcommon.ModifyStrategyDetail) error { if strategy.ID == "" { return store.NewStatusError(store.EmptyParamsErr, fmt.Sprintf( "update auth_strategy missing some params, id is %s", strategy.ID)) @@ -155,7 +135,7 @@ func (ss *strategyStore) UpdateStrategy(strategy *model.ModifyStrategyDetail) er } // updateStrategy -func (ss *strategyStore) updateStrategy(tx *bolt.Tx, modify *model.ModifyStrategyDetail, +func (ss *strategyStore) updateStrategy(tx *bolt.Tx, modify *authcommon.ModifyStrategyDetail, saveVal *strategyForStore) error { saveVal.Action = modify.Action @@ -185,10 +165,10 @@ func (ss *strategyStore) updateStrategy(tx *bolt.Tx, modify *model.ModifyStrateg return nil } -func computePrincipals(remove bool, principals []model.Principal, saveVal *strategyForStore) { +func computePrincipals(remove bool, principals []authcommon.Principal, saveVal *strategyForStore) { for i := range principals { principal := principals[i] - if principal.PrincipalRole == model.PrincipalUser { + if principal.PrincipalType == authcommon.PrincipalUser { if remove { delete(saveVal.Users, principal.PrincipalID) } else { @@ -204,7 +184,7 @@ func computePrincipals(remove bool, principals []model.Principal, saveVal *strat } } -func computeResources(remove bool, resources []model.StrategyResource, saveVal *strategyForStore) { +func computeResources(remove bool, resources []authcommon.StrategyResource, saveVal *strategyForStore) { for i := range resources { resource := resources[i] if resource.ResType == int32(apisecurity.ResourceType_Namespaces) { @@ -254,16 +234,16 @@ func (ss *strategyStore) DeleteStrategy(id string) error { } // RemoveStrategyResources 删除策略的资源数据信息 -func (ss *strategyStore) RemoveStrategyResources(resources []model.StrategyResource) error { +func (ss *strategyStore) RemoveStrategyResources(resources []authcommon.StrategyResource) error { return ss.operateStrategyResources(true, resources) } // LooseAddStrategyResources 松要求的添加鉴权策略的资源,允许忽略主键冲突的问题 -func (ss *strategyStore) LooseAddStrategyResources(resources []model.StrategyResource) error { +func (ss *strategyStore) LooseAddStrategyResources(resources []authcommon.StrategyResource) error { return ss.operateStrategyResources(false, resources) } -func (ss *strategyStore) operateStrategyResources(remove bool, resources []model.StrategyResource) error { +func (ss *strategyStore) operateStrategyResources(remove bool, resources []authcommon.StrategyResource) error { proxy, err := ss.handler.StartTx() if err != nil { return err @@ -331,13 +311,13 @@ func loadStrategyById(tx *bolt.Tx, id string) (*strategyForStore, error) { return ret, nil } -func buildResMap(resources []model.StrategyResource) map[string][]model.StrategyResource { - ret := make(map[string][]model.StrategyResource) +func buildResMap(resources []authcommon.StrategyResource) map[string][]authcommon.StrategyResource { + ret := make(map[string][]authcommon.StrategyResource) for i := range resources { resource := resources[i] if _, exist := ret[resource.StrategyID]; !exist { - ret[resource.StrategyID] = make([]model.StrategyResource, 0, 4) + ret[resource.StrategyID] = make([]authcommon.StrategyResource, 0, 4) } val := ret[resource.StrategyID] @@ -350,7 +330,7 @@ func buildResMap(resources []model.StrategyResource) map[string][]model.Strategy } // GetStrategyDetail 获取策略详情 -func (ss *strategyStore) GetStrategyDetail(id string) (*model.StrategyDetail, error) { +func (ss *strategyStore) GetStrategyDetail(id string) (*authcommon.StrategyDetail, error) { proxy, err := ss.handler.StartTx() if err != nil { return nil, err @@ -364,7 +344,7 @@ func (ss *strategyStore) GetStrategyDetail(id string) (*model.StrategyDetail, er } // GetStrategyDetail -func (ss *strategyStore) getStrategyDetail(tx *bolt.Tx, id string) (*model.StrategyDetail, error) { +func (ss *strategyStore) getStrategyDetail(tx *bolt.Tx, id string) (*authcommon.StrategyDetail, error) { ret, err := loadStrategyById(tx, id) if err != nil { return nil, err @@ -378,11 +358,11 @@ func (ss *strategyStore) getStrategyDetail(tx *bolt.Tx, id string) (*model.Strat // GetStrategyResources 获取策略的资源 func (ss *strategyStore) GetStrategyResources(principalId string, - principalRole model.PrincipalType) ([]model.StrategyResource, error) { + principalRole authcommon.PrincipalType) ([]authcommon.StrategyResource, error) { fields := []string{StrategyFieldValid, StrategyFieldDefault, StrategyFieldUsersPrincipal} - if principalRole == model.PrincipalGroup { + if principalRole == authcommon.PrincipalGroup { fields = []string{StrategyFieldValid, StrategyFieldDefault, StrategyFieldGroupsPrincipal} } @@ -395,7 +375,7 @@ func (ss *strategyStore) GetStrategyResources(principalId string, var principals map[string]string - if principalRole == model.PrincipalUser { + if principalRole == authcommon.PrincipalUser { principals, _ = m[StrategyFieldUsersPrincipal].(map[string]string) } else { principals, _ = m[StrategyFieldGroupsPrincipal].(map[string]string) @@ -410,7 +390,7 @@ func (ss *strategyStore) GetStrategyResources(principalId string, return nil, err } - ret := make([]model.StrategyResource, 0, 4) + ret := make([]authcommon.StrategyResource, 0, 4) for _, item := range values { rule := item.(*strategyForStore) @@ -420,11 +400,11 @@ func (ss *strategyStore) GetStrategyResources(principalId string, return ret, nil } -func collectStrategyResources(rule *strategyForStore) []model.StrategyResource { - ret := make([]model.StrategyResource, 0, len(rule.NsResources)+len(rule.SvcResources)+len(rule.CfgResources)) +func collectStrategyResources(rule *strategyForStore) []authcommon.StrategyResource { + ret := make([]authcommon.StrategyResource, 0, len(rule.NsResources)+len(rule.SvcResources)+len(rule.CfgResources)) for id := range rule.NsResources { - ret = append(ret, model.StrategyResource{ + ret = append(ret, authcommon.StrategyResource{ StrategyID: rule.ID, ResType: int32(apisecurity.ResourceType_Namespaces), ResID: id, @@ -432,7 +412,7 @@ func collectStrategyResources(rule *strategyForStore) []model.StrategyResource { } for id := range rule.SvcResources { - ret = append(ret, model.StrategyResource{ + ret = append(ret, authcommon.StrategyResource{ StrategyID: rule.ID, ResType: int32(apisecurity.ResourceType_Services), ResID: id, @@ -440,7 +420,7 @@ func collectStrategyResources(rule *strategyForStore) []model.StrategyResource { } for id := range rule.CfgResources { - ret = append(ret, model.StrategyResource{ + ret = append(ret, authcommon.StrategyResource{ StrategyID: rule.ID, ResType: int32(apisecurity.ResourceType_ConfigGroups), ResID: id, @@ -452,11 +432,11 @@ func collectStrategyResources(rule *strategyForStore) []model.StrategyResource { // GetDefaultStrategyDetailByPrincipal 获取默认策略详情 func (ss *strategyStore) GetDefaultStrategyDetailByPrincipal(principalId string, - principalType model.PrincipalType) (*model.StrategyDetail, error) { + principalType authcommon.PrincipalType) (*authcommon.StrategyDetail, error) { fields := []string{StrategyFieldValid, StrategyFieldDefault, StrategyFieldUsersPrincipal} - if principalType == model.PrincipalGroup { + if principalType == authcommon.PrincipalGroup { fields = []string{StrategyFieldValid, StrategyFieldDefault, StrategyFieldGroupsPrincipal} } @@ -474,7 +454,7 @@ func (ss *strategyStore) GetDefaultStrategyDetailByPrincipal(principalId string, var principals map[string]string - if principalType == model.PrincipalUser { + if principalType == authcommon.PrincipalUser { principals, _ = m[StrategyFieldUsersPrincipal].(map[string]string) } else { principals, _ = m[StrategyFieldGroupsPrincipal].(map[string]string) @@ -508,7 +488,7 @@ func (ss *strategyStore) GetDefaultStrategyDetailByPrincipal(principalId string, // GetStrategies 查询鉴权策略列表 func (ss *strategyStore) GetStrategies(filters map[string]string, offset uint32, limit uint32) (uint32, - []*model.StrategyDetail, error) { + []*authcommon.StrategyDetail, error) { showDetail := filters["show_detail"] delete(filters, "show_detail") @@ -517,7 +497,7 @@ func (ss *strategyStore) GetStrategies(filters map[string]string, offset uint32, } func (ss *strategyStore) listStrategies(filters map[string]string, offset uint32, limit uint32, - showDetail bool) (uint32, []*model.StrategyDetail, error) { + showDetail bool) (uint32, []*authcommon.StrategyDetail, error) { fields := []string{StrategyFieldValid, StrategyFieldName, StrategyFieldUsersPrincipal, StrategyFieldGroupsPrincipal, StrategyFieldNsResources, StrategyFieldSvcResources, @@ -594,8 +574,8 @@ func (ss *strategyStore) listStrategies(filters map[string]string, offset uint32 return uint32(len(values)), doStrategyPage(values, offset, limit, showDetail), nil } -func doStrategyPage(ret map[string]interface{}, offset, limit uint32, showDetail bool) []*model.StrategyDetail { - rules := make([]*model.StrategyDetail, 0, len(ret)) +func doStrategyPage(ret map[string]interface{}, offset, limit uint32, showDetail bool) []*authcommon.StrategyDetail { + rules := make([]*authcommon.StrategyDetail, 0, len(ret)) beginIndex := offset endIndex := beginIndex + limit @@ -614,8 +594,8 @@ func doStrategyPage(ret map[string]interface{}, offset, limit uint32, showDetail endIndex = totalCount } - emptyPrincipals := make([]model.Principal, 0) - emptyResources := make([]model.StrategyResource, 0) + emptyPrincipals := make([]authcommon.Principal, 0) + emptyResources := make([]authcommon.StrategyResource, 0) for k := range ret { rule := convertForStrategyDetail(ret[k].(*strategyForStore)) @@ -673,9 +653,8 @@ func comparePrincipalExist(principalType, principalId string, m map[string]inter return true } -// GetStrategyDetailsForCache get strategy details for cache -func (ss *strategyStore) GetStrategyDetailsForCache(mtime time.Time, - firstUpdate bool) ([]*model.StrategyDetail, error) { +// GetMoreStrategies get strategy details for cache +func (ss *strategyStore) GetMoreStrategies(mtime time.Time, firstUpdate bool) ([]*authcommon.StrategyDetail, error) { ret, err := ss.handler.LoadValuesByFilter(tblStrategy, []string{StrategyFieldModifyTime}, &strategyForStore{}, func(m map[string]interface{}) bool { @@ -688,7 +667,7 @@ func (ss *strategyStore) GetStrategyDetailsForCache(mtime time.Time, return nil, err } - strategies := make([]*model.StrategyDetail, 0, len(ret)) + strategies := make([]*authcommon.StrategyDetail, 0, len(ret)) for k := range ret { val := ret[k] @@ -698,122 +677,97 @@ func (ss *strategyStore) GetStrategyDetailsForCache(mtime time.Time, return strategies, nil } -// cleanInvalidStrategy clean up authentication strategy by name -func (ss *strategyStore) cleanInvalidStrategy(tx *bolt.Tx, name, owner string) error { - - fields := []string{StrategyFieldName, StrategyFieldOwner, StrategyFieldValid} +func (ss *strategyStore) CleanPrincipalPolicies(tx store.Tx, p authcommon.Principal) error { + fields := []string{StrategyFieldDefault, StrategyFieldUsersPrincipal, StrategyFieldGroupsPrincipal} values := make(map[string]interface{}) - err := loadValuesByFilter(tx, tblStrategy, fields, &strategyForStore{}, + dbTx := tx.GetDelegateTx().(*bolt.Tx) + err := loadValuesByFilter(dbTx, tblStrategy, fields, &strategyForStore{}, func(m map[string]interface{}) bool { - valid, ok := m[StrategyFieldValid].(bool) - // 如果数据是 valid 的,则不能被清理 - if ok && valid { + isDefault := m[StrategyFieldDefault].(bool) + if !isDefault { return false } - saveName := m[StrategyFieldName] - saveOwner := m[StrategyFieldOwner] + var principals map[string]string + if p.PrincipalType == authcommon.PrincipalUser { + principals = m[StrategyFieldUsersPrincipal].(map[string]string) + } else { + principals = m[StrategyFieldGroupsPrincipal].(map[string]string) + } - return saveName == name && saveOwner == owner + if len(principals) != 1 { + return false + } + _, exist := principals[p.PrincipalID] + return exist }, values) if err != nil { - log.Error("[Store][Strategy] clean invalid auth_strategy", zap.Error(err), - zap.String("name", name), zap.Any("owner", owner)) + log.Error("[Store][Strategy] load link auth_strategy", zap.Error(err), zap.String("principal", p.String())) return err } if len(values) == 0 { return nil } + if len(values) > 1 { + return ErrorMultiDefaultStrategy + } - keys := make([]string, 0, len(values)) for k := range values { - keys = append(keys, k) - } - return deleteValues(tx, tblStrategy, keys) -} + properties := make(map[string]interface{}) + properties[StrategyFieldValid] = false + properties[StrategyFieldModifyTime] = time.Now() -func createDefaultStrategy(tx *bolt.Tx, role model.PrincipalType, principalId, name, owner string) error { - strategy := &model.StrategyDetail{ - ID: utils.NewUUID(), - Name: model.BuildDefaultStrategyName(role, name), - Action: apisecurity.AuthAction_READ_WRITE.String(), - Default: true, - Owner: owner, - Revision: utils.NewUUID(), - Resources: []model.StrategyResource{}, - Valid: true, - Principals: []model.Principal{ - { - PrincipalID: principalId, - PrincipalRole: role, - }, - }, - Comment: "Default Strategy", - } - - return saveValue(tx, tblStrategy, strategy.ID, convertForStrategyStore(strategy)) + if err := updateValue(dbTx, tblStrategy, k, properties); err != nil { + log.Error("[Store][Strategy] clean link auth_strategy", zap.String("principal", p.String()), zap.Error(err)) + return err + } + } + return nil } -func cleanLinkStrategy(tx *bolt.Tx, role model.PrincipalType, principalId, owner string) error { +// cleanInvalidStrategy clean up authentication strategy by name +func (ss *strategyStore) cleanInvalidStrategy(tx *bolt.Tx, name, owner string) error { - fields := []string{StrategyFieldDefault, StrategyFieldUsersPrincipal, StrategyFieldGroupsPrincipal} + fields := []string{StrategyFieldName, StrategyFieldOwner, StrategyFieldValid} values := make(map[string]interface{}) err := loadValuesByFilter(tx, tblStrategy, fields, &strategyForStore{}, func(m map[string]interface{}) bool { - isDefault := m[StrategyFieldDefault].(bool) - if !isDefault { + valid, ok := m[StrategyFieldValid].(bool) + // 如果数据是 valid 的,则不能被清理 + if ok && valid { return false } - var principals map[string]string - if role == model.PrincipalUser { - principals = m[StrategyFieldUsersPrincipal].(map[string]string) - } else { - principals = m[StrategyFieldGroupsPrincipal].(map[string]string) - } + saveName := m[StrategyFieldName] + saveOwner := m[StrategyFieldOwner] - if len(principals) != 1 { - return false - } - _, exist := principals[principalId] - return exist + return saveName == name && saveOwner == owner }, values) if err != nil { - log.Error("[Store][Strategy] load link auth_strategy", zap.Error(err), - zap.String("principal-id", principalId), zap.Any("principal-type", role)) + log.Error("[Store][Strategy] clean invalid auth_strategy", zap.Error(err), + zap.String("name", name), zap.Any("owner", owner)) return err } if len(values) == 0 { return nil } - if len(values) > 1 { - return ErrorMultiDefaultStrategy - } + keys := make([]string, 0, len(values)) for k := range values { - - properties := make(map[string]interface{}) - properties[StrategyFieldValid] = false - properties[StrategyFieldModifyTime] = time.Now() - - if err := updateValue(tx, tblStrategy, k, properties); err != nil { - log.Error("[Store][Strategy] clean link auth_strategy", zap.Error(err), - zap.String("principal-id", principalId), zap.Any("principal-type", role)) - return err - } + keys = append(keys, k) } - return nil + return deleteValues(tx, tblStrategy, keys) } -func convertForStrategyStore(strategy *model.StrategyDetail) *strategyForStore { +func convertForStrategyStore(strategy *authcommon.StrategyDetail) *strategyForStore { var ( users = make(map[string]string, 4) @@ -823,7 +777,7 @@ func convertForStrategyStore(strategy *model.StrategyDetail) *strategyForStore { for i := range principals { principal := principals[i] - if principal.PrincipalRole == model.PrincipalUser { + if principal.PrincipalType == authcommon.PrincipalUser { users[principal.PrincipalID] = "" } else { groups[principal.PrincipalID] = "" @@ -867,32 +821,32 @@ func convertForStrategyStore(strategy *model.StrategyDetail) *strategyForStore { } } -func convertForStrategyDetail(strategy *strategyForStore) *model.StrategyDetail { +func convertForStrategyDetail(strategy *strategyForStore) *authcommon.StrategyDetail { - principals := make([]model.Principal, 0, len(strategy.Users)+len(strategy.Groups)) - resources := make([]model.StrategyResource, 0, len(strategy.NsResources)+ + principals := make([]authcommon.Principal, 0, len(strategy.Users)+len(strategy.Groups)) + resources := make([]authcommon.StrategyResource, 0, len(strategy.NsResources)+ len(strategy.SvcResources)+len(strategy.CfgResources)) for id := range strategy.Users { - principals = append(principals, model.Principal{ + principals = append(principals, authcommon.Principal{ StrategyID: strategy.ID, PrincipalID: id, - PrincipalRole: model.PrincipalUser, + PrincipalType: authcommon.PrincipalUser, }) } for id := range strategy.Groups { - principals = append(principals, model.Principal{ + principals = append(principals, authcommon.Principal{ StrategyID: strategy.ID, PrincipalID: id, - PrincipalRole: model.PrincipalGroup, + PrincipalType: authcommon.PrincipalGroup, }) } - fillRes := func(idMap map[string]string, resType apisecurity.ResourceType) []model.StrategyResource { - res := make([]model.StrategyResource, 0, len(idMap)) + fillRes := func(idMap map[string]string, resType apisecurity.ResourceType) []authcommon.StrategyResource { + res := make([]authcommon.StrategyResource, 0, len(idMap)) for id := range idMap { - res = append(res, model.StrategyResource{ + res = append(res, authcommon.StrategyResource{ StrategyID: strategy.ID, ResType: int32(resType), ResID: id, @@ -906,7 +860,7 @@ func convertForStrategyDetail(strategy *strategyForStore) *model.StrategyDetail resources = append(resources, fillRes(strategy.SvcResources, apisecurity.ResourceType_Services)...) resources = append(resources, fillRes(strategy.CfgResources, apisecurity.ResourceType_ConfigGroups)...) - return &model.StrategyDetail{ + return &authcommon.StrategyDetail{ ID: strategy.ID, Name: strategy.Name, Action: strategy.Action, @@ -922,7 +876,7 @@ func convertForStrategyDetail(strategy *strategyForStore) *model.StrategyDetail } } -func initStrategy(rule *model.StrategyDetail) { +func initStrategy(rule *authcommon.StrategyDetail) { if rule != nil { rule.Valid = true diff --git a/store/boltdb/strategy_test.go b/store/boltdb/strategy_test.go index 0293c2dd5..4f687b8b9 100644 --- a/store/boltdb/strategy_test.go +++ b/store/boltdb/strategy_test.go @@ -26,29 +26,29 @@ import ( apisecurity "github.com/polarismesh/specification/source/go/api/v1/security" "github.com/stretchr/testify/assert" - "github.com/polarismesh/polaris/common/model" + authcommon "github.com/polarismesh/polaris/common/model/auth" "github.com/polarismesh/polaris/common/utils" ) -func createTestStrategy(num int) []*model.StrategyDetail { - ret := make([]*model.StrategyDetail, 0, num) +func createTestStrategy(num int) []*authcommon.StrategyDetail { + ret := make([]*authcommon.StrategyDetail, 0, num) for i := 0; i < num; i++ { - ret = append(ret, &model.StrategyDetail{ + ret = append(ret, &authcommon.StrategyDetail{ ID: fmt.Sprintf("strategy-%d", i), Name: fmt.Sprintf("strategy-%d", i), Action: apisecurity.AuthAction_READ_WRITE.String(), Comment: fmt.Sprintf("strategy-%d", i), - Principals: []model.Principal{ + Principals: []authcommon.Principal{ { StrategyID: fmt.Sprintf("strategy-%d", i), PrincipalID: fmt.Sprintf("user-%d", i), - PrincipalRole: model.PrincipalUser, + PrincipalType: authcommon.PrincipalUser, }, }, Default: true, Owner: "polaris", - Resources: []model.StrategyResource{ + Resources: []authcommon.StrategyResource{ { StrategyID: "", ResType: int32(apisecurity.ResourceType_Namespaces), @@ -70,9 +70,13 @@ func Test_strategyStore_AddStrategy(t *testing.T) { ss := &strategyStore{handler: handler} rules := createTestStrategy(1) - err := ss.AddStrategy(rules[0]) + tx, err := handler.StartTx() + assert.NoError(t, err) + err = ss.AddStrategy(tx, rules[0]) assert.Nil(t, err, "add strategy must success") + err = tx.Commit() + assert.NoError(t, err) }) } @@ -81,30 +85,35 @@ func Test_strategyStore_UpdateStrategy(t *testing.T) { ss := &strategyStore{handler: handler} rules := createTestStrategy(1) - err := ss.AddStrategy(rules[0]) + + tx, err := handler.StartTx() + assert.NoError(t, err) + err = ss.AddStrategy(tx, rules[0]) assert.Nil(t, err, "add strategy must success") + err = tx.Commit() + assert.NoError(t, err) - addPrincipals := []model.Principal{{ + addPrincipals := []authcommon.Principal{{ StrategyID: rules[0].ID, PrincipalID: utils.NewUUID(), - PrincipalRole: model.PrincipalGroup, + PrincipalType: authcommon.PrincipalGroup, }} - req := &model.ModifyStrategyDetail{ + req := &authcommon.ModifyStrategyDetail{ ID: rules[0].ID, Name: rules[0].Name, Action: rules[0].Action, Comment: "update-strategy", AddPrincipals: addPrincipals, - RemovePrincipals: []model.Principal{}, - AddResources: []model.StrategyResource{ + RemovePrincipals: []authcommon.Principal{}, + AddResources: []authcommon.StrategyResource{ { StrategyID: rules[0].ID, ResType: int32(apisecurity.ResourceType_Services), ResID: utils.NewUUID(), }, }, - RemoveResources: []model.StrategyResource{}, + RemoveResources: []authcommon.StrategyResource{}, ModifyTime: time.Time{}, } @@ -123,8 +132,12 @@ func Test_strategyStore_DeleteStrategy(t *testing.T) { ss := &strategyStore{handler: handler} rules := createTestStrategy(1) - err := ss.AddStrategy(rules[0]) + tx, err := handler.StartTx() + assert.NoError(t, err) + err = ss.AddStrategy(tx, rules[0]) assert.Nil(t, err, "add strategy must success") + err = tx.Commit() + assert.NoError(t, err) err = ss.DeleteStrategy(rules[0].ID) assert.Nil(t, err, "delete strategy must success") @@ -140,10 +153,14 @@ func Test_strategyStore_RemoveStrategyResources(t *testing.T) { ss := &strategyStore{handler: handler} rules := createTestStrategy(1) - err := ss.AddStrategy(rules[0]) + tx, err := handler.StartTx() + assert.NoError(t, err) + err = ss.AddStrategy(tx, rules[0]) assert.Nil(t, err, "add strategy must success") + err = tx.Commit() + assert.NoError(t, err) - err = ss.RemoveStrategyResources([]model.StrategyResource{ + err = ss.RemoveStrategyResources([]authcommon.StrategyResource{ { StrategyID: rules[0].ID, ResType: int32(apisecurity.ResourceType_Namespaces), @@ -157,7 +174,7 @@ func Test_strategyStore_RemoveStrategyResources(t *testing.T) { for i := range ret.Resources { res := ret.Resources[i] t.Logf("resource=%#v", res) - assert.NotEqual(t, res, model.StrategyResource{ + assert.NotEqual(t, res, authcommon.StrategyResource{ StrategyID: rules[0].ID, ResType: int32(apisecurity.ResourceType_Namespaces), ResID: "namespace_0", @@ -171,10 +188,14 @@ func Test_strategyStore_LooseAddStrategyResources(t *testing.T) { ss := &strategyStore{handler: handler} rules := createTestStrategy(1) - err := ss.AddStrategy(rules[0]) + tx, err := handler.StartTx() + assert.NoError(t, err) + err = ss.AddStrategy(tx, rules[0]) assert.Nil(t, err, "add strategy must success") + err = tx.Commit() + assert.NoError(t, err) - err = ss.LooseAddStrategyResources([]model.StrategyResource{ + err = ss.LooseAddStrategyResources([]authcommon.StrategyResource{ { StrategyID: rules[0].ID, ResType: int32(apisecurity.ResourceType_Namespaces), @@ -185,12 +206,12 @@ func Test_strategyStore_LooseAddStrategyResources(t *testing.T) { ret, err := ss.GetStrategyDetail(rules[0].ID) assert.Nil(t, err, "get strategy must success") - ans := make([]model.StrategyResource, 0) + ans := make([]authcommon.StrategyResource, 0) for i := range ret.Resources { res := ret.Resources[i] t.Logf("resource=%#v", res) res.StrategyID = rules[0].ID - if reflect.DeepEqual(res, model.StrategyResource{ + if reflect.DeepEqual(res, authcommon.StrategyResource{ StrategyID: rules[0].ID, ResType: int32(apisecurity.ResourceType_Namespaces), ResID: "namespace_1", @@ -208,8 +229,12 @@ func Test_strategyStore_GetStrategyDetail(t *testing.T) { ss := &strategyStore{handler: handler} rules := createTestStrategy(1) - err := ss.AddStrategy(rules[0]) + tx, err := handler.StartTx() + assert.NoError(t, err) + err = ss.AddStrategy(tx, rules[0]) assert.Nil(t, err, "add strategy must success") + err = tx.Commit() + assert.NoError(t, err) v, err := ss.GetStrategyDetail(rules[0].ID) assert.Nil(t, err, "get strategy-detail must success") @@ -229,14 +254,18 @@ func Test_strategyStore_GetStrategyResources(t *testing.T) { rules := createTestStrategy(2) for i := range rules { rule := rules[i] - err := ss.AddStrategy(rule) + tx, err := handler.StartTx() + assert.NoError(t, err) + err = ss.AddStrategy(tx, rule) assert.Nil(t, err, "add strategy must success") + err = tx.Commit() + assert.NoError(t, err) } - res, err := ss.GetStrategyResources("user-1", model.PrincipalUser) + res, err := ss.GetStrategyResources("user-1", authcommon.PrincipalUser) assert.Nil(t, err, "GetStrategyResources must success") - assert.ElementsMatch(t, []model.StrategyResource{ + assert.ElementsMatch(t, []authcommon.StrategyResource{ { StrategyID: "strategy-1", ResType: int32(apisecurity.ResourceType_Namespaces), @@ -255,11 +284,15 @@ func Test_strategyStore_GetDefaultStrategyDetailByPrincipal(t *testing.T) { rule := rules[i] rule.Default = i == 1 rules[i] = rule - err := ss.AddStrategy(rule) + tx, err := handler.StartTx() + assert.NoError(t, err) + err = ss.AddStrategy(tx, rule) assert.Nil(t, err, "add strategy must success") + err = tx.Commit() + assert.NoError(t, err) } - res, err := ss.GetDefaultStrategyDetailByPrincipal("user-1", model.PrincipalUser) + res, err := ss.GetDefaultStrategyDetailByPrincipal("user-1", authcommon.PrincipalUser) assert.Nil(t, err, "GetStrategyResources must success") rules[1].ModifyTime = rules[1].CreateTime diff --git a/store/boltdb/user.go b/store/boltdb/user.go index 161e85119..cc386a404 100644 --- a/store/boltdb/user.go +++ b/store/boltdb/user.go @@ -26,7 +26,7 @@ import ( bolt "go.etcd.io/bbolt" "go.uber.org/zap" - "github.com/polarismesh/polaris/common/model" + authcommon "github.com/polarismesh/polaris/common/model/auth" "github.com/polarismesh/polaris/common/utils" "github.com/polarismesh/polaris/store" ) @@ -76,7 +76,7 @@ type userStore struct { } // AddUser 添加用户 -func (us *userStore) AddUser(user *model.User) error { +func (us *userStore) AddUser(tx store.Tx, user *authcommon.User) error { initUser(user) @@ -85,48 +85,14 @@ func (us *userStore) AddUser(user *model.User) error { return store.NewStatusError(store.EmptyParamsErr, "add user missing some params") } - return us.addUser(user) -} - -func (us *userStore) addUser(user *model.User) error { - proxy, err := us.handler.StartTx() - if err != nil { - return err - } - tx := proxy.GetDelegateTx().(*bolt.Tx) - - defer func() { - _ = tx.Rollback() - }() - + dbTx := tx.GetDelegateTx().(*bolt.Tx) owner := user.Owner if owner == "" { owner = user.ID } // 添加用户信息 - if err := us.addUserMain(tx, user); err != nil { - return err - } - - // 添加用户的默认策略 - if err := createDefaultStrategy(tx, model.PrincipalUser, user.ID, user.Name, owner); err != nil { - log.Error("[Store][User] create user default strategy fail", zap.Error(err), - zap.String("name", user.Name)) - return err - } - - if err := tx.Commit(); err != nil { - log.Error("[Store][User] save user tx commit fail", zap.Error(err), - zap.String("name", user.Name)) - return err - } - return nil -} - -func (us *userStore) addUserMain(tx *bolt.Tx, user *model.User) error { - // 添加用户信息 - if err := saveValue(tx, tblUser, user.ID, converToUserStore(user)); err != nil { + if err := saveValue(dbTx, tblUser, user.ID, converToUserStore(user)); err != nil { log.Error("[Store][User] save user fail", zap.Error(err), zap.String("name", user.Name)) return err } @@ -134,7 +100,7 @@ func (us *userStore) addUserMain(tx *bolt.Tx, user *model.User) error { } // UpdateUser -func (us *userStore) UpdateUser(user *model.User) error { +func (us *userStore) UpdateUser(user *authcommon.User) error { if user.ID == "" || user.Token == "" { return store.NewStatusError(store.EmptyParamsErr, "update user missing some params") } @@ -158,52 +124,25 @@ func (us *userStore) UpdateUser(user *model.User) error { } // DeleteUser 删除用户 -func (us *userStore) DeleteUser(user *model.User) error { +func (us *userStore) DeleteUser(tx store.Tx, user *authcommon.User) error { if user.ID == "" { return store.NewStatusError(store.EmptyParamsErr, "delete user missing some params") } - - return us.deleteUser(user) -} - -func (us *userStore) deleteUser(user *model.User) error { - proxy, err := us.handler.StartTx() - if err != nil { - return err - } - tx := proxy.GetDelegateTx().(*bolt.Tx) - - defer func() { - _ = tx.Rollback() - }() + dbTx := tx.GetDelegateTx().(*bolt.Tx) properties := make(map[string]interface{}) properties[UserFieldValid] = false properties[UserFieldModifyTime] = time.Now() - if err := updateValue(tx, tblUser, user.ID, properties); err != nil { + if err := updateValue(dbTx, tblUser, user.ID, properties); err != nil { log.Error("[Store][User] delete user by id", zap.Error(err), zap.String("id", user.ID)) return err } - - owner := user.Owner - if owner == "" { - owner = user.ID - } - - if err := cleanLinkStrategy(tx, model.PrincipalUser, user.ID, user.Owner); err != nil { - return err - } - - if err := tx.Commit(); err != nil { - log.Error("[Store][User] delete user tx commit", zap.Error(err), zap.String("id", user.ID)) - return err - } return nil } // GetUser 获取用户 -func (us *userStore) GetUser(id string) (*model.User, error) { +func (us *userStore) GetUser(id string) (*authcommon.User, error) { if id == "" { return nil, store.NewStatusError(store.EmptyParamsErr, "get user missing some params") } @@ -221,7 +160,7 @@ func (us *userStore) GetUser(id string) (*model.User, error) { } // GetUser 获取用户 -func (us *userStore) getUser(tx *bolt.Tx, id string) (*model.User, error) { +func (us *userStore) getUser(tx *bolt.Tx, id string) (*authcommon.User, error) { if id == "" { return nil, store.NewStatusError(store.EmptyParamsErr, "get user missing id params") } @@ -243,7 +182,7 @@ func (us *userStore) getUser(tx *bolt.Tx, id string) (*model.User, error) { } // GetUserByName 获取用户 -func (us *userStore) GetUserByName(name, ownerId string) (*model.User, error) { +func (us *userStore) GetUserByName(name, ownerId string) (*authcommon.User, error) { if name == "" { return nil, store.NewStatusError(store.EmptyParamsErr, "get user missing name params") } @@ -285,7 +224,7 @@ func (us *userStore) GetUserByName(name, ownerId string) (*model.User, error) { } // GetUserByIds 通过用户ID批量获取用户 -func (us *userStore) GetUserByIds(ids []string) ([]*model.User, error) { +func (us *userStore) GetUserByIds(ids []string) ([]*authcommon.User, error) { if len(ids) == 0 { return nil, nil } @@ -299,7 +238,7 @@ func (us *userStore) GetUserByIds(ids []string) ([]*model.User, error) { return nil, nil } - users := make([]*model.User, 0, len(ids)) + users := make([]*authcommon.User, 0, len(ids)) for k := range ret { user := ret[k].(*userForStore) if !user.Valid { @@ -312,7 +251,7 @@ func (us *userStore) GetUserByIds(ids []string) ([]*model.User, error) { } // GetSubCount 获取子账户的个数 -func (us *userStore) GetSubCount(user *model.User) (uint32, error) { +func (us *userStore) GetSubCount(user *authcommon.User) (uint32, error) { ownerId := user.ID ret, err := us.handler.LoadValuesByFilter(tblUser, []string{UserFieldOwner, UserFieldValid}, &userForStore{}, func(m map[string]interface{}) bool { @@ -334,7 +273,7 @@ func (us *userStore) GetSubCount(user *model.User) (uint32, error) { } // GetUsers 获取用户列表 -func (us *userStore) GetUsers(filters map[string]string, offset uint32, limit uint32) (uint32, []*model.User, error) { +func (us *userStore) GetUsers(filters map[string]string, offset uint32, limit uint32) (uint32, []*authcommon.User, error) { if _, ok := filters["group_id"]; ok { return us.getGroupUsers(filters, offset, limit) } @@ -346,7 +285,7 @@ func (us *userStore) GetUsers(filters map[string]string, offset uint32, limit ui // "name": 1, // "owner": 1, // "source": 1, -func (us *userStore) getUsers(filters map[string]string, offset uint32, limit uint32) (uint32, []*model.User, error) { +func (us *userStore) getUsers(filters map[string]string, offset uint32, limit uint32) (uint32, []*authcommon.User, error) { fields := []string{UserFieldID, UserFieldName, UserFieldOwner, UserFieldSource, UserFieldValid, UserFieldType} ret, err := us.handler.LoadValuesByFilter(tblUser, fields, &userForStore{}, func(m map[string]interface{}) bool { @@ -363,7 +302,7 @@ func (us *userStore) getUsers(filters map[string]string, offset uint32, limit ui saveType, _ := m[UserFieldType].(int64) // 超级账户不做展示 - if model.UserRoleType(saveType) == model.AdminUserRole && + if authcommon.UserRoleType(saveType) == authcommon.AdminUserRole && strings.Compare("true", filters["hide_admin"]) == 0 { return false } @@ -414,7 +353,7 @@ func (us *userStore) getUsers(filters map[string]string, offset uint32, limit ui // getGroupUsers 获取某个用户组下的所有用户列表数据信息 func (us *userStore) getGroupUsers(filters map[string]string, offset uint32, limit uint32) (uint32, - []*model.User, error) { + []*authcommon.User, error) { groupId := filters["group_id"] delete(filters, "group_id") @@ -448,7 +387,7 @@ func (us *userStore) getGroupUsers(filters map[string]string, offset uint32, lim return false } - if model.UserRoleType(user.Type) == model.AdminUserRole { + if authcommon.UserRoleType(user.Type) == authcommon.AdminUserRole { return false } @@ -491,7 +430,7 @@ func (us *userStore) getGroupUsers(filters map[string]string, offset uint32, lim } // GetUsersForCache 获取所有用户信息 -func (us *userStore) GetUsersForCache(mtime time.Time, firstUpdate bool) ([]*model.User, error) { +func (us *userStore) GetUsersForCache(mtime time.Time, firstUpdate bool) ([]*authcommon.User, error) { ret, err := us.handler.LoadValuesByFilter(tblUser, []string{UserFieldModifyTime}, &userForStore{}, func(m map[string]interface{}) bool { mt := m[UserFieldModifyTime].(time.Time) @@ -503,7 +442,7 @@ func (us *userStore) GetUsersForCache(mtime time.Time, firstUpdate bool) ([]*mod return nil, err } - users := make([]*model.User, 0, len(ret)) + users := make([]*authcommon.User, 0, len(ret)) for k := range ret { val := ret[k] users = append(users, converToUserModel(val.(*userForStore))) @@ -513,8 +452,8 @@ func (us *userStore) GetUsersForCache(mtime time.Time, firstUpdate bool) ([]*mod } // doPage 进行分页 -func doUserPage(ret map[string]interface{}, offset, limit uint32) []*model.User { - users := make([]*model.User, 0, len(ret)) +func doUserPage(ret map[string]interface{}, offset, limit uint32) []*authcommon.User { + users := make([]*authcommon.User, 0, len(ret)) beginIndex := offset endIndex := beginIndex + limit totalCount := uint32(len(ret)) @@ -542,7 +481,7 @@ func doUserPage(ret map[string]interface{}, offset, limit uint32) []*model.User return users[beginIndex:endIndex] } -func converToUserStore(user *model.User) *userForStore { +func converToUserStore(user *authcommon.User) *userForStore { return &userForStore{ ID: user.ID, Name: user.Name, @@ -559,14 +498,14 @@ func converToUserStore(user *model.User) *userForStore { } } -func converToUserModel(user *userForStore) *model.User { - return &model.User{ +func converToUserModel(user *userForStore) *authcommon.User { + return &authcommon.User{ ID: user.ID, Name: user.Name, Password: user.Password, Owner: user.Owner, Source: user.Source, - Type: model.UserRoleType(user.Type), + Type: authcommon.UserRoleType(user.Type), Token: user.Token, TokenEnable: user.TokenEnable, Valid: user.Valid, @@ -576,7 +515,7 @@ func converToUserModel(user *userForStore) *model.User { } } -func initUser(user *model.User) { +func initUser(user *authcommon.User) { if user != nil { tn := time.Now() user.Valid = true diff --git a/store/boltdb/user_test.go b/store/boltdb/user_test.go index b0d5dfa72..1c5441513 100644 --- a/store/boltdb/user_test.go +++ b/store/boltdb/user_test.go @@ -26,20 +26,20 @@ import ( "github.com/stretchr/testify/assert" - "github.com/polarismesh/polaris/common/model" + authcommon "github.com/polarismesh/polaris/common/model/auth" ) -func createTestUsers(num int) []*model.User { - ret := make([]*model.User, 0, num) +func createTestUsers(num int) []*authcommon.User { + ret := make([]*authcommon.User, 0, num) for i := 0; i < num; i++ { - ret = append(ret, &model.User{ + ret = append(ret, &authcommon.User{ ID: fmt.Sprintf("user_%d", i), Name: fmt.Sprintf("user_%d", i), Password: fmt.Sprintf("user_%d", i), Owner: "polaris", Source: "Polaris", - Type: model.SubAccountUserRole, + Type: authcommon.SubAccountUserRole, Token: "polaris", TokenEnable: true, Valid: true, @@ -57,10 +57,12 @@ func Test_userStore_AddUser(t *testing.T) { us := &userStore{handler: handler} users := createTestUsers(1) - - if err := us.AddUser(users[0]); err != nil { + tx, err := handler.StartTx() + assert.NoError(t, err) + if err := us.AddUser(tx, users[0]); err != nil { t.Fatal(err) } + assert.NoError(t, tx.Commit()) ret, err := us.GetUser(users[0].ID) if err != nil { @@ -86,9 +88,12 @@ func Test_userStore_UpdateUser(t *testing.T) { users := createTestUsers(1) - if err := us.AddUser(users[0]); err != nil { + tx, err := handler.StartTx() + assert.NoError(t, err) + if err := us.AddUser(tx, users[0]); err != nil { t.Fatal(err) } + assert.NoError(t, tx.Commit()) users[0].Comment = "user update test" @@ -120,9 +125,12 @@ func Test_userStore_DeleteUser(t *testing.T) { users := createTestUsers(1) - if err := us.AddUser(users[0]); err != nil { + tx, err := handler.StartTx() + assert.NoError(t, err) + if err := us.AddUser(tx, users[0]); err != nil { t.Fatal(err) } + assert.NoError(t, tx.Commit()) ret, err := us.GetUser(users[0].ID) if err != nil { @@ -133,9 +141,12 @@ func Test_userStore_DeleteUser(t *testing.T) { t.FailNow() } - if err = us.DeleteUser(users[0]); err != nil { + tx, err = handler.StartTx() + assert.NoError(t, err) + if err = us.DeleteUser(tx, users[0]); err != nil { t.Fatal(err) } + assert.NoError(t, tx.Commit()) ret, err = us.GetUser(users[0].ID) if err != nil { @@ -154,9 +165,12 @@ func Test_userStore_GetUserByName(t *testing.T) { users := createTestUsers(1) - if err := us.AddUser(users[0]); err != nil { + tx, err := handler.StartTx() + assert.NoError(t, err) + if err := us.AddUser(tx, users[0]); err != nil { t.Fatal(err) } + assert.NoError(t, tx.Commit()) ret, err := us.GetUserByName(users[0].Name, users[0].Owner) if err != nil { @@ -184,9 +198,12 @@ func Test_userStore_GetUserByIds(t *testing.T) { ids := make([]string, 0, len(users)) for i := range users { - if err := us.AddUser(users[i]); err != nil { + tx, err := handler.StartTx() + assert.NoError(t, err) + if err := us.AddUser(tx, users[i]); err != nil { t.Fatal(err) } + assert.NoError(t, tx.Commit()) ids = append(ids, users[i].ID) } @@ -235,12 +252,15 @@ func Test_userStore_GetSubCount(t *testing.T) { users := createTestUsers(5) for i := range users { - if err := us.AddUser(users[i]); err != nil { + tx, err := handler.StartTx() + assert.NoError(t, err) + if err := us.AddUser(tx, users[i]); err != nil { t.Fatal(err) } + assert.NoError(t, tx.Commit()) } - total, err := us.GetSubCount(&model.User{ + total, err := us.GetSubCount(&authcommon.User{ ID: "polaris", }) @@ -262,9 +282,12 @@ func Test_userStore_GetUsers(t *testing.T) { users := createTestUsers(10) for i := range users { - if err := us.AddUser(users[i]); err != nil { + tx, err := handler.StartTx() + assert.NoError(t, err) + if err := us.AddUser(tx, users[i]); err != nil { t.Fatal(err) } + assert.NoError(t, tx.Commit()) } total, ret, err := us.GetUsers(map[string]string{ @@ -302,11 +325,14 @@ func Test_userStore_GetUsers(t *testing.T) { admins := createTestUsers(1) admins[0].ID = "admin" admins[0].Name = "admin" - admins[0].Type = model.AdminUserRole + admins[0].Type = authcommon.AdminUserRole - if err := us.AddUser(admins[0]); err != nil { + tx, err := handler.StartTx() + assert.NoError(t, err) + if err := us.AddUser(tx, admins[0]); err != nil { t.Fatal(err) } + assert.NoError(t, tx.Commit()) total, ret, err = us.GetUsers(map[string]string{ "hide_admin": "true", @@ -340,18 +366,21 @@ func Test_userStore_GetUsersByGroup(t *testing.T) { gs := &groupStore{handler: handler} groups := createTestUserGroup(1) + tx, err := handler.StartTx() + assert.NoError(t, err) for i := range groups { - if err := gs.AddGroup(groups[i]); err != nil { + if err := gs.AddGroup(tx, groups[i]); err != nil { t.Fatal(err) } } users := createTestUsers(10) for i := range users { - if err := us.AddUser(users[i]); err != nil { + if err := us.AddUser(tx, users[i]); err != nil { t.Fatal(err) } } + assert.NoError(t, tx.Commit()) total, ret, err := us.GetUsers(map[string]string{ "group_id": groups[0].ID, diff --git a/store/discover_api.go b/store/discover_api.go index 5e16ed666..20a702309 100644 --- a/store/discover_api.go +++ b/store/discover_api.go @@ -40,12 +40,6 @@ type NamingModuleStore interface { CircuitBreakerStore // ToolStore 函数及工具接口 ToolStore - // UserStore 用户接口 - UserStore - // GroupStore 用户组接口 - GroupStore - // StrategyStore 鉴权策略接口 - StrategyStore // RoutingConfigStoreV2 路由策略 v2 接口 RoutingConfigStoreV2 // FaultDetectRuleStore fault detect rule interface diff --git a/store/mock/admin_mock.go b/store/mock/admin_mock.go index bf12adcba..3d7fcbf72 100644 --- a/store/mock/admin_mock.go +++ b/store/mock/admin_mock.go @@ -8,7 +8,7 @@ import ( reflect "reflect" gomock "github.com/golang/mock/gomock" - model "github.com/polarismesh/polaris/common/model" + admin "github.com/polarismesh/polaris/common/model/admin" ) // MockLeaderElectionStore is a mock of LeaderElectionStore interface. @@ -95,10 +95,10 @@ func (mr *MockLeaderElectionStoreMockRecorder) GetVersion(key interface{}) *gomo } // ListLeaderElections mocks base method. -func (m *MockLeaderElectionStore) ListLeaderElections() ([]*model.LeaderElection, error) { +func (m *MockLeaderElectionStore) ListLeaderElections() ([]*admin.LeaderElection, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "ListLeaderElections") - ret0, _ := ret[0].([]*model.LeaderElection) + ret0, _ := ret[0].([]*admin.LeaderElection) ret1, _ := ret[1].(error) return ret0, ret1 } diff --git a/store/mock/api_mock.go b/store/mock/api_mock.go index 9fed72254..3b397247d 100644 --- a/store/mock/api_mock.go +++ b/store/mock/api_mock.go @@ -11,6 +11,8 @@ import ( gomock "github.com/golang/mock/gomock" model "github.com/polarismesh/polaris/common/model" + admin "github.com/polarismesh/polaris/common/model/admin" + auth "github.com/polarismesh/polaris/common/model/auth" store "github.com/polarismesh/polaris/store" ) @@ -52,17 +54,17 @@ func (mr *MockStoreMockRecorder) ActiveConfigFileReleaseTx(tx, release interface } // AddGroup mocks base method. -func (m *MockStore) AddGroup(group *model.UserGroupDetail) error { +func (m *MockStore) AddGroup(tx store.Tx, group *auth.UserGroupDetail) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "AddGroup", group) + ret := m.ctrl.Call(m, "AddGroup", tx, group) ret0, _ := ret[0].(error) return ret0 } // AddGroup indicates an expected call of AddGroup. -func (mr *MockStoreMockRecorder) AddGroup(group interface{}) *gomock.Call { +func (mr *MockStoreMockRecorder) AddGroup(tx, group interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddGroup", reflect.TypeOf((*MockStore)(nil).AddGroup), group) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddGroup", reflect.TypeOf((*MockStore)(nil).AddGroup), tx, group) } // AddInstance mocks base method. @@ -107,6 +109,20 @@ func (mr *MockStoreMockRecorder) AddNamespace(namespace interface{}) *gomock.Cal return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddNamespace", reflect.TypeOf((*MockStore)(nil).AddNamespace), namespace) } +// AddRole mocks base method. +func (m *MockStore) AddRole(role *auth.Role) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "AddRole", role) + ret0, _ := ret[0].(error) + return ret0 +} + +// AddRole indicates an expected call of AddRole. +func (mr *MockStoreMockRecorder) AddRole(role interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddRole", reflect.TypeOf((*MockStore)(nil).AddRole), role) +} + // AddService mocks base method. func (m *MockStore) AddService(service *model.Service) error { m.ctrl.T.Helper() @@ -136,31 +152,31 @@ func (mr *MockStoreMockRecorder) AddServiceContractInterfaces(contract interface } // AddStrategy mocks base method. -func (m *MockStore) AddStrategy(strategy *model.StrategyDetail) error { +func (m *MockStore) AddStrategy(tx store.Tx, strategy *auth.StrategyDetail) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "AddStrategy", strategy) + ret := m.ctrl.Call(m, "AddStrategy", tx, strategy) ret0, _ := ret[0].(error) return ret0 } // AddStrategy indicates an expected call of AddStrategy. -func (mr *MockStoreMockRecorder) AddStrategy(strategy interface{}) *gomock.Call { +func (mr *MockStoreMockRecorder) AddStrategy(tx, strategy interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddStrategy", reflect.TypeOf((*MockStore)(nil).AddStrategy), strategy) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddStrategy", reflect.TypeOf((*MockStore)(nil).AddStrategy), tx, strategy) } // AddUser mocks base method. -func (m *MockStore) AddUser(user *model.User) error { +func (m *MockStore) AddUser(tx store.Tx, user *auth.User) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "AddUser", user) + ret := m.ctrl.Call(m, "AddUser", tx, user) ret0, _ := ret[0].(error) return ret0 } // AddUser indicates an expected call of AddUser. -func (mr *MockStoreMockRecorder) AddUser(user interface{}) *gomock.Call { +func (mr *MockStoreMockRecorder) AddUser(tx, user interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddUser", reflect.TypeOf((*MockStore)(nil).AddUser), user) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddUser", reflect.TypeOf((*MockStore)(nil).AddUser), tx, user) } // AppendServiceContractInterfaces mocks base method. @@ -435,6 +451,34 @@ func (mr *MockStoreMockRecorder) CleanInstance(instanceID interface{}) *gomock.C return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CleanInstance", reflect.TypeOf((*MockStore)(nil).CleanInstance), instanceID) } +// CleanPrincipalPolicies mocks base method. +func (m *MockStore) CleanPrincipalPolicies(tx store.Tx, p auth.Principal) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CleanPrincipalPolicies", tx, p) + ret0, _ := ret[0].(error) + return ret0 +} + +// CleanPrincipalPolicies indicates an expected call of CleanPrincipalPolicies. +func (mr *MockStoreMockRecorder) CleanPrincipalPolicies(tx, p interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CleanPrincipalPolicies", reflect.TypeOf((*MockStore)(nil).CleanPrincipalPolicies), tx, p) +} + +// CleanPrincipalRoles mocks base method. +func (m *MockStore) CleanPrincipalRoles(tx store.Tx, p *auth.Principal) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CleanPrincipalRoles", tx, p) + ret0, _ := ret[0].(error) + return ret0 +} + +// CleanPrincipalRoles indicates an expected call of CleanPrincipalRoles. +func (mr *MockStoreMockRecorder) CleanPrincipalRoles(tx, p interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CleanPrincipalRoles", reflect.TypeOf((*MockStore)(nil).CleanPrincipalRoles), tx, p) +} + // CountConfigFileEachGroup mocks base method. func (m *MockStore) CountConfigFileEachGroup() (map[string]map[string]int64, error) { m.ctrl.T.Helper() @@ -765,17 +809,17 @@ func (mr *MockStoreMockRecorder) DeleteFaultDetectRule(id interface{}) *gomock.C } // DeleteGroup mocks base method. -func (m *MockStore) DeleteGroup(group *model.UserGroupDetail) error { +func (m *MockStore) DeleteGroup(tx store.Tx, group *auth.UserGroupDetail) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "DeleteGroup", group) + ret := m.ctrl.Call(m, "DeleteGroup", tx, group) ret0, _ := ret[0].(error) return ret0 } // DeleteGroup indicates an expected call of DeleteGroup. -func (mr *MockStoreMockRecorder) DeleteGroup(group interface{}) *gomock.Call { +func (mr *MockStoreMockRecorder) DeleteGroup(tx, group interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteGroup", reflect.TypeOf((*MockStore)(nil).DeleteGroup), group) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteGroup", reflect.TypeOf((*MockStore)(nil).DeleteGroup), tx, group) } // DeleteInstance mocks base method. @@ -820,6 +864,20 @@ func (mr *MockStoreMockRecorder) DeleteRateLimit(limiting interface{}) *gomock.C return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteRateLimit", reflect.TypeOf((*MockStore)(nil).DeleteRateLimit), limiting) } +// DeleteRole mocks base method. +func (m *MockStore) DeleteRole(role *auth.Role) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeleteRole", role) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeleteRole indicates an expected call of DeleteRole. +func (mr *MockStoreMockRecorder) DeleteRole(role interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteRole", reflect.TypeOf((*MockStore)(nil).DeleteRole), role) +} + // DeleteRoutingConfig mocks base method. func (m *MockStore) DeleteRoutingConfig(serviceID string) error { m.ctrl.T.Helper() @@ -933,17 +991,17 @@ func (mr *MockStoreMockRecorder) DeleteStrategy(id interface{}) *gomock.Call { } // DeleteUser mocks base method. -func (m *MockStore) DeleteUser(user *model.User) error { +func (m *MockStore) DeleteUser(tx store.Tx, user *auth.User) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "DeleteUser", user) + ret := m.ctrl.Call(m, "DeleteUser", tx, user) ret0, _ := ret[0].(error) return ret0 } // DeleteUser indicates an expected call of DeleteUser. -func (mr *MockStoreMockRecorder) DeleteUser(user interface{}) *gomock.Call { +func (mr *MockStoreMockRecorder) DeleteUser(tx, user interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteUser", reflect.TypeOf((*MockStore)(nil).DeleteUser), user) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteUser", reflect.TypeOf((*MockStore)(nil).DeleteUser), tx, user) } // Destroy mocks base method. @@ -1184,10 +1242,10 @@ func (mr *MockStoreMockRecorder) GetConfigFileTx(tx, namespace, group, name inte } // GetDefaultStrategyDetailByPrincipal mocks base method. -func (m *MockStore) GetDefaultStrategyDetailByPrincipal(principalId string, principalType model.PrincipalType) (*model.StrategyDetail, error) { +func (m *MockStore) GetDefaultStrategyDetailByPrincipal(principalId string, principalType auth.PrincipalType) (*auth.StrategyDetail, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetDefaultStrategyDetailByPrincipal", principalId, principalType) - ret0, _ := ret[0].(*model.StrategyDetail) + ret0, _ := ret[0].(*auth.StrategyDetail) ret1, _ := ret[1].(error) return ret0, ret1 } @@ -1262,10 +1320,10 @@ func (mr *MockStoreMockRecorder) GetFaultDetectRulesForCache(mtime, firstUpdate } // GetGroup mocks base method. -func (m *MockStore) GetGroup(id string) (*model.UserGroupDetail, error) { +func (m *MockStore) GetGroup(id string) (*auth.UserGroupDetail, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetGroup", id) - ret0, _ := ret[0].(*model.UserGroupDetail) + ret0, _ := ret[0].(*auth.UserGroupDetail) ret1, _ := ret[1].(error) return ret0, ret1 } @@ -1277,10 +1335,10 @@ func (mr *MockStoreMockRecorder) GetGroup(id interface{}) *gomock.Call { } // GetGroupByName mocks base method. -func (m *MockStore) GetGroupByName(name, owner string) (*model.UserGroup, error) { +func (m *MockStore) GetGroupByName(name, owner string) (*auth.UserGroup, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetGroupByName", name, owner) - ret0, _ := ret[0].(*model.UserGroup) + ret0, _ := ret[0].(*auth.UserGroup) ret1, _ := ret[1].(error) return ret0, ret1 } @@ -1292,11 +1350,11 @@ func (mr *MockStoreMockRecorder) GetGroupByName(name, owner interface{}) *gomock } // GetGroups mocks base method. -func (m *MockStore) GetGroups(filters map[string]string, offset, limit uint32) (uint32, []*model.UserGroup, error) { +func (m *MockStore) GetGroups(filters map[string]string, offset, limit uint32) (uint32, []*auth.UserGroup, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetGroups", filters, offset, limit) ret0, _ := ret[0].(uint32) - ret1, _ := ret[1].([]*model.UserGroup) + ret1, _ := ret[1].([]*auth.UserGroup) ret2, _ := ret[2].(error) return ret0, ret1, ret2 } @@ -1308,10 +1366,10 @@ func (mr *MockStoreMockRecorder) GetGroups(filters, offset, limit interface{}) * } // GetGroupsForCache mocks base method. -func (m *MockStore) GetGroupsForCache(mtime time.Time, firstUpdate bool) ([]*model.UserGroupDetail, error) { +func (m *MockStore) GetGroupsForCache(mtime time.Time, firstUpdate bool) ([]*auth.UserGroupDetail, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetGroupsForCache", mtime, firstUpdate) - ret0, _ := ret[0].([]*model.UserGroupDetail) + ret0, _ := ret[0].([]*auth.UserGroupDetail) ret1, _ := ret[1].(error) return ret0, ret1 } @@ -1669,6 +1727,21 @@ func (mr *MockStoreMockRecorder) GetMoreReleaseFile(firstUpdate, modifyTime inte return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMoreReleaseFile", reflect.TypeOf((*MockStore)(nil).GetMoreReleaseFile), firstUpdate, modifyTime) } +// GetMoreRoles mocks base method. +func (m *MockStore) GetMoreRoles(firstUpdate bool, modifyTime time.Time) ([]*auth.Role, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetMoreRoles", firstUpdate, modifyTime) + ret0, _ := ret[0].([]*auth.Role) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetMoreRoles indicates an expected call of GetMoreRoles. +func (mr *MockStoreMockRecorder) GetMoreRoles(firstUpdate, modifyTime interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMoreRoles", reflect.TypeOf((*MockStore)(nil).GetMoreRoles), firstUpdate, modifyTime) +} + // GetMoreServiceContracts mocks base method. func (m *MockStore) GetMoreServiceContracts(firstUpdate bool, mtime time.Time) ([]*model.EnrichServiceContract, error) { m.ctrl.T.Helper() @@ -1699,6 +1772,21 @@ func (mr *MockStoreMockRecorder) GetMoreServices(mtime, firstUpdate, disableBusi return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMoreServices", reflect.TypeOf((*MockStore)(nil).GetMoreServices), mtime, firstUpdate, disableBusiness, needMeta) } +// GetMoreStrategies mocks base method. +func (m *MockStore) GetMoreStrategies(mtime time.Time, firstUpdate bool) ([]*auth.StrategyDetail, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetMoreStrategies", mtime, firstUpdate) + ret0, _ := ret[0].([]*auth.StrategyDetail) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetMoreStrategies indicates an expected call of GetMoreStrategies. +func (mr *MockStoreMockRecorder) GetMoreStrategies(mtime, firstUpdate interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMoreStrategies", reflect.TypeOf((*MockStore)(nil).GetMoreStrategies), mtime, firstUpdate) +} + // GetNamespace mocks base method. func (m *MockStore) GetNamespace(name string) (*model.Namespace, error) { m.ctrl.T.Helper() @@ -2005,11 +2093,11 @@ func (mr *MockStoreMockRecorder) GetSourceServiceToken(name, namespace interface } // GetStrategies mocks base method. -func (m *MockStore) GetStrategies(filters map[string]string, offset, limit uint32) (uint32, []*model.StrategyDetail, error) { +func (m *MockStore) GetStrategies(filters map[string]string, offset, limit uint32) (uint32, []*auth.StrategyDetail, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetStrategies", filters, offset, limit) ret0, _ := ret[0].(uint32) - ret1, _ := ret[1].([]*model.StrategyDetail) + ret1, _ := ret[1].([]*auth.StrategyDetail) ret2, _ := ret[2].(error) return ret0, ret1, ret2 } @@ -2021,10 +2109,10 @@ func (mr *MockStoreMockRecorder) GetStrategies(filters, offset, limit interface{ } // GetStrategyDetail mocks base method. -func (m *MockStore) GetStrategyDetail(id string) (*model.StrategyDetail, error) { +func (m *MockStore) GetStrategyDetail(id string) (*auth.StrategyDetail, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetStrategyDetail", id) - ret0, _ := ret[0].(*model.StrategyDetail) + ret0, _ := ret[0].(*auth.StrategyDetail) ret1, _ := ret[1].(error) return ret0, ret1 } @@ -2035,26 +2123,11 @@ func (mr *MockStoreMockRecorder) GetStrategyDetail(id interface{}) *gomock.Call return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetStrategyDetail", reflect.TypeOf((*MockStore)(nil).GetStrategyDetail), id) } -// GetStrategyDetailsForCache mocks base method. -func (m *MockStore) GetStrategyDetailsForCache(mtime time.Time, firstUpdate bool) ([]*model.StrategyDetail, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetStrategyDetailsForCache", mtime, firstUpdate) - ret0, _ := ret[0].([]*model.StrategyDetail) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetStrategyDetailsForCache indicates an expected call of GetStrategyDetailsForCache. -func (mr *MockStoreMockRecorder) GetStrategyDetailsForCache(mtime, firstUpdate interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetStrategyDetailsForCache", reflect.TypeOf((*MockStore)(nil).GetStrategyDetailsForCache), mtime, firstUpdate) -} - // GetStrategyResources mocks base method. -func (m *MockStore) GetStrategyResources(principalId string, principalRole model.PrincipalType) ([]model.StrategyResource, error) { +func (m *MockStore) GetStrategyResources(principalId string, principalRole auth.PrincipalType) ([]auth.StrategyResource, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetStrategyResources", principalId, principalRole) - ret0, _ := ret[0].([]model.StrategyResource) + ret0, _ := ret[0].([]auth.StrategyResource) ret1, _ := ret[1].(error) return ret0, ret1 } @@ -2066,7 +2139,7 @@ func (mr *MockStoreMockRecorder) GetStrategyResources(principalId, principalRole } // GetSubCount mocks base method. -func (m *MockStore) GetSubCount(user *model.User) (uint32, error) { +func (m *MockStore) GetSubCount(user *auth.User) (uint32, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetSubCount", user) ret0, _ := ret[0].(uint32) @@ -2126,10 +2199,10 @@ func (mr *MockStoreMockRecorder) GetUnixSecond(maxWait interface{}) *gomock.Call } // GetUser mocks base method. -func (m *MockStore) GetUser(id string) (*model.User, error) { +func (m *MockStore) GetUser(id string) (*auth.User, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetUser", id) - ret0, _ := ret[0].(*model.User) + ret0, _ := ret[0].(*auth.User) ret1, _ := ret[1].(error) return ret0, ret1 } @@ -2141,10 +2214,10 @@ func (mr *MockStoreMockRecorder) GetUser(id interface{}) *gomock.Call { } // GetUserByIds mocks base method. -func (m *MockStore) GetUserByIds(ids []string) ([]*model.User, error) { +func (m *MockStore) GetUserByIds(ids []string) ([]*auth.User, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetUserByIds", ids) - ret0, _ := ret[0].([]*model.User) + ret0, _ := ret[0].([]*auth.User) ret1, _ := ret[1].(error) return ret0, ret1 } @@ -2156,10 +2229,10 @@ func (mr *MockStoreMockRecorder) GetUserByIds(ids interface{}) *gomock.Call { } // GetUserByName mocks base method. -func (m *MockStore) GetUserByName(name, ownerId string) (*model.User, error) { +func (m *MockStore) GetUserByName(name, ownerId string) (*auth.User, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetUserByName", name, ownerId) - ret0, _ := ret[0].(*model.User) + ret0, _ := ret[0].(*auth.User) ret1, _ := ret[1].(error) return ret0, ret1 } @@ -2171,11 +2244,11 @@ func (mr *MockStoreMockRecorder) GetUserByName(name, ownerId interface{}) *gomoc } // GetUsers mocks base method. -func (m *MockStore) GetUsers(filters map[string]string, offset, limit uint32) (uint32, []*model.User, error) { +func (m *MockStore) GetUsers(filters map[string]string, offset, limit uint32) (uint32, []*auth.User, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetUsers", filters, offset, limit) ret0, _ := ret[0].(uint32) - ret1, _ := ret[1].([]*model.User) + ret1, _ := ret[1].([]*auth.User) ret2, _ := ret[2].(error) return ret0, ret1, ret2 } @@ -2187,10 +2260,10 @@ func (mr *MockStoreMockRecorder) GetUsers(filters, offset, limit interface{}) *g } // GetUsersForCache mocks base method. -func (m *MockStore) GetUsersForCache(mtime time.Time, firstUpdate bool) ([]*model.User, error) { +func (m *MockStore) GetUsersForCache(mtime time.Time, firstUpdate bool) ([]*auth.User, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetUsersForCache", mtime, firstUpdate) - ret0, _ := ret[0].([]*model.User) + ret0, _ := ret[0].([]*auth.User) ret1, _ := ret[1].(error) return ret0, ret1 } @@ -2334,10 +2407,10 @@ func (mr *MockStoreMockRecorder) IsLeader(key interface{}) *gomock.Call { } // ListLeaderElections mocks base method. -func (m *MockStore) ListLeaderElections() ([]*model.LeaderElection, error) { +func (m *MockStore) ListLeaderElections() ([]*admin.LeaderElection, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "ListLeaderElections") - ret0, _ := ret[0].([]*model.LeaderElection) + ret0, _ := ret[0].([]*admin.LeaderElection) ret1, _ := ret[1].(error) return ret0, ret1 } @@ -2394,7 +2467,7 @@ func (mr *MockStoreMockRecorder) LockLaneGroup(tx, name interface{}) *gomock.Cal } // LooseAddStrategyResources mocks base method. -func (m *MockStore) LooseAddStrategyResources(resources []model.StrategyResource) error { +func (m *MockStore) LooseAddStrategyResources(resources []auth.StrategyResource) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "LooseAddStrategyResources", resources) ret0, _ := ret[0].(error) @@ -2483,7 +2556,7 @@ func (mr *MockStoreMockRecorder) ReleaseLeaderElection(key interface{}) *gomock. } // RemoveStrategyResources mocks base method. -func (m *MockStore) RemoveStrategyResources(resources []model.StrategyResource) error { +func (m *MockStore) RemoveStrategyResources(resources []auth.StrategyResource) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "RemoveStrategyResources", resources) ret0, _ := ret[0].(error) @@ -2626,7 +2699,7 @@ func (mr *MockStoreMockRecorder) UpdateFaultDetectRule(conf interface{}) *gomock } // UpdateGroup mocks base method. -func (m *MockStore) UpdateGroup(group *model.ModifyUserGroup) error { +func (m *MockStore) UpdateGroup(group *auth.ModifyUserGroup) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "UpdateGroup", group) ret0, _ := ret[0].(error) @@ -2709,6 +2782,20 @@ func (mr *MockStoreMockRecorder) UpdateRateLimit(limiting interface{}) *gomock.C return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateRateLimit", reflect.TypeOf((*MockStore)(nil).UpdateRateLimit), limiting) } +// UpdateRole mocks base method. +func (m *MockStore) UpdateRole(role *auth.Role) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpdateRole", role) + ret0, _ := ret[0].(error) + return ret0 +} + +// UpdateRole indicates an expected call of UpdateRole. +func (mr *MockStoreMockRecorder) UpdateRole(role interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateRole", reflect.TypeOf((*MockStore)(nil).UpdateRole), role) +} + // UpdateRoutingConfig mocks base method. func (m *MockStore) UpdateRoutingConfig(conf *model.RoutingConfig) error { m.ctrl.T.Helper() @@ -2808,7 +2895,7 @@ func (mr *MockStoreMockRecorder) UpdateServiceToken(serviceID, token, revision i } // UpdateStrategy mocks base method. -func (m *MockStore) UpdateStrategy(strategy *model.ModifyStrategyDetail) error { +func (m *MockStore) UpdateStrategy(strategy *auth.ModifyStrategyDetail) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "UpdateStrategy", strategy) ret0, _ := ret[0].(error) @@ -2822,7 +2909,7 @@ func (mr *MockStoreMockRecorder) UpdateStrategy(strategy interface{}) *gomock.Ca } // UpdateUser mocks base method. -func (m *MockStore) UpdateUser(user *model.User) error { +func (m *MockStore) UpdateUser(user *auth.User) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "UpdateUser", user) ret0, _ := ret[0].(error) diff --git a/store/mysql/admin.go b/store/mysql/admin.go index 2d9c56b3c..b3e775530 100644 --- a/store/mysql/admin.go +++ b/store/mysql/admin.go @@ -27,7 +27,7 @@ import ( "time" "github.com/polarismesh/polaris/common/eventhub" - "github.com/polarismesh/polaris/common/model" + "github.com/polarismesh/polaris/common/model/admin" "github.com/polarismesh/polaris/common/utils" "github.com/polarismesh/polaris/store" ) @@ -64,7 +64,7 @@ type LeaderElectionStore interface { // CheckMtimeExpired check mtime expired CheckMtimeExpired(key string, leaseTime int32) (string, bool, error) // ListLeaderElections list all leaderelection - ListLeaderElections() ([]*model.LeaderElection, error) + ListLeaderElections() ([]*admin.LeaderElection, error) } // leaderElectionStore @@ -148,7 +148,7 @@ func (l *leaderElectionStore) CheckMtimeExpired(key string, leaseTime int32) (st } // ListLeaderElections list the election records -func (l *leaderElectionStore) ListLeaderElections() ([]*model.LeaderElection, error) { +func (l *leaderElectionStore) ListLeaderElections() ([]*admin.LeaderElection, error) { log.Info("[Store][database] list leader election") mainStr := "select elect_key, leader, UNIX_TIMESTAMP(ctime), UNIX_TIMESTAMP(mtime) from leader_election" @@ -161,16 +161,16 @@ func (l *leaderElectionStore) ListLeaderElections() ([]*model.LeaderElection, er return fetchLeaderElectionRows(rows) } -func fetchLeaderElectionRows(rows *sql.Rows) ([]*model.LeaderElection, error) { +func fetchLeaderElectionRows(rows *sql.Rows) ([]*admin.LeaderElection, error) { if rows == nil { return nil, nil } defer rows.Close() - var out []*model.LeaderElection + var out []*admin.LeaderElection for rows.Next() { - space := &model.LeaderElection{} + space := &admin.LeaderElection{} if err := rows.Scan(&space.ElectKey, &space.Host, &space.Ctime, &space.Mtime); err != nil { log.Errorf("[Store][database] fetch leader election rows scan err: %s", err.Error()) return nil, err @@ -300,7 +300,7 @@ func (le *leaderElectionStateMachine) changeToLeader() { // changeToFollower func (le *leaderElectionStateMachine) changeToFollower(leader string) { - log.Infof("[Store][database] change from leader to follower (%s)", le.electKey) + log.Infof("[Store][database] change from leader(%s) to follower (%s)", leader, le.electKey) atomic.StoreInt32(&le.leaderFlag, 0) le.leader = leader le.publishLeaderChangeEvent() @@ -410,7 +410,7 @@ func (m *adminStore) IsLeader(key string) bool { } // ListLeaderElections list election records -func (m *adminStore) ListLeaderElections() ([]*model.LeaderElection, error) { +func (m *adminStore) ListLeaderElections() ([]*admin.LeaderElection, error) { return m.leStore.ListLeaderElections() } @@ -533,8 +533,7 @@ func (m *adminStore) BatchCleanDeletedClients(timeout time.Duration, batchSize u log.Infof("[Store][database] batch clean soft deleted clients(%d)", batchSize) var rows int64 err := m.master.processWithTransaction("batchCleanDeletedClients", func(tx *BaseTx) error { - mainStr := "delete from client where flag = 1 and " + - "mtime <= FROM_UNIXTIME(UNIX_TIMESTAMP(SYSDATE()) - ?) limit ?" + mainStr := "delete from client where flag = 1 limit ?" result, err := tx.Exec(mainStr, int32(timeout.Seconds()), batchSize) if err != nil { log.Errorf("[Store][database] batch clean soft deleted clients(%d), err: %s", batchSize, err.Error()) diff --git a/store/mysql/client.go b/store/mysql/client.go index 6a27433ab..d54b65460 100644 --- a/store/mysql/client.go +++ b/store/mysql/client.go @@ -107,7 +107,7 @@ func (cs *clientStore) GetMoreClients(mtime time.Time, firstUpdate bool) (map[st from client left join client_stat on client.id = client_stat.client_id ` str += " where client.mtime >= FROM_UNIXTIME(?)" if firstUpdate { - str += " and flag != 1" + str += " and flag = 0" } rows, err := cs.slave.Query(str, timeToTimestamp(mtime)) if err != nil { diff --git a/store/mysql/default.go b/store/mysql/default.go index 299985033..ae5d8c205 100644 --- a/store/mysql/default.go +++ b/store/mysql/default.go @@ -70,10 +70,12 @@ type stableStore struct { *clientStore *adminStore *toolStore + *grayStore + *userStore *groupStore *strategyStore - *grayStore + *roleStore // 主数据库,可以进行读写 master *BaseDB diff --git a/store/mysql/group.go b/store/mysql/group.go index 3b4667542..938a10087 100644 --- a/store/mysql/group.go +++ b/store/mysql/group.go @@ -24,7 +24,7 @@ import ( "go.uber.org/zap" - "github.com/polarismesh/polaris/common/model" + authcommon "github.com/polarismesh/polaris/common/model/auth" "github.com/polarismesh/polaris/common/utils" "github.com/polarismesh/polaris/store" ) @@ -57,29 +57,15 @@ type groupStore struct { } // AddGroup 创建一个用户组 -func (u *groupStore) AddGroup(group *model.UserGroupDetail) error { +func (u *groupStore) AddGroup(tx store.Tx, group *authcommon.UserGroupDetail) error { if group.ID == "" || group.Name == "" || group.Token == "" { return store.NewStatusError(store.EmptyParamsErr, fmt.Sprintf( "add usergroup missing some params, groupId is %s, name is %s", group.ID, group.Name)) } - - err := RetryTransaction("addGroup", func() error { - return u.addGroup(group) - }) - - return store.Error(err) -} - -func (u *groupStore) addGroup(group *model.UserGroupDetail) error { - tx, err := u.master.Begin() - if err != nil { - return err - } - - defer func() { _ = tx.Rollback() }() + dbTx := tx.GetDelegateTx().(*BaseTx) // 先清理无效数据 - if err := cleanInValidGroup(tx, group.Name, group.Owner); err != nil { + if err := cleanInValidGroup(dbTx, group.Name, group.Owner); err != nil { return store.Error(err) } @@ -93,7 +79,7 @@ func (u *groupStore) addGroup(group *model.UserGroupDetail) error { tokenEnable = 0 } - if _, err = tx.Exec(addSql, []interface{}{ + if _, err := dbTx.Exec(addSql, []interface{}{ group.ID, group.Name, group.Owner, @@ -106,25 +92,15 @@ func (u *groupStore) addGroup(group *model.UserGroupDetail) error { return err } - if err := u.addGroupRelation(tx, group.ID, group.ToUserIdSlice()); err != nil { + if err := u.addGroupRelation(dbTx, group.ID, group.ToUserIdSlice()); err != nil { log.Errorf("[Store][Group] add usergroup relation err: %s", err.Error()) return err } - - if err := createDefaultStrategy(tx, model.PrincipalGroup, group.ID, group.Name, group.Owner); err != nil { - log.Errorf("[Store][Group] add usergroup default strategy err: %s", err.Error()) - return err - } - - if err := tx.Commit(); err != nil { - log.Errorf("[Store][Group] add usergroup tx commit err: %s", err.Error()) - return err - } return nil } // UpdateGroup 更新用户组 -func (u *groupStore) UpdateGroup(group *model.ModifyUserGroup) error { +func (u *groupStore) UpdateGroup(group *authcommon.ModifyUserGroup) error { if group.ID == "" { return store.NewStatusError(store.EmptyParamsErr, fmt.Sprintf( "update usergroup missing some params, groupId is %s", group.ID)) @@ -137,7 +113,7 @@ func (u *groupStore) UpdateGroup(group *model.ModifyUserGroup) error { return store.Error(err) } -func (u *groupStore) updateGroup(group *model.ModifyUserGroup) error { +func (u *groupStore) updateGroup(group *authcommon.ModifyUserGroup) error { tx, err := u.master.Begin() if err != nil { return err @@ -186,55 +162,32 @@ func (u *groupStore) updateGroup(group *model.ModifyUserGroup) error { } // DeleteGroup 删除用户组 -func (u *groupStore) DeleteGroup(group *model.UserGroupDetail) error { +func (u *groupStore) DeleteGroup(tx store.Tx, group *authcommon.UserGroupDetail) error { if group.ID == "" || group.Name == "" { return store.NewStatusError(store.EmptyParamsErr, fmt.Sprintf( "delete usergroup missing some params, groupId is %s", group.ID)) } - err := RetryTransaction("deleteUserGroup", func() error { - return u.deleteUserGroup(group) - }) - - return store.Error(err) -} - -func (u *groupStore) deleteUserGroup(group *model.UserGroupDetail) error { - tx, err := u.master.Begin() - if err != nil { - return err - } + dbTx := tx.GetDelegateTx().(*BaseTx) - defer func() { _ = tx.Rollback() }() - - if _, err = tx.Exec("DELETE FROM user_group_relation WHERE group_id = ?", []interface{}{ + if _, err := dbTx.Exec("DELETE FROM user_group_relation WHERE group_id = ?", []interface{}{ group.ID, }...); err != nil { log.Errorf("[Store][Group] clean usergroup relation err: %s", err.Error()) return err } - if _, err = tx.Exec("UPDATE user_group SET flag = 1, mtime = sysdate() WHERE id = ?", []interface{}{ + if _, err := dbTx.Exec("UPDATE user_group SET flag = 1, mtime = sysdate() WHERE id = ?", []interface{}{ group.ID, }...); err != nil { log.Errorf("[Store][Group] remove usergroup err: %s", err.Error()) return err } - - if err := cleanLinkStrategy(tx, model.PrincipalGroup, group.ID, group.Owner); err != nil { - log.Errorf("[Store][Group] clean usergroup default strategy err: %s", err.Error()) - return err - } - - if err := tx.Commit(); err != nil { - log.Errorf("[Store][Group] delete usergroupr tx commit err: %s", err.Error()) - return err - } return nil } // GetGroup 根据用户组ID获取用户组 -func (u *groupStore) GetGroup(groupId string) (*model.UserGroupDetail, error) { +func (u *groupStore) GetGroup(groupId string) (*authcommon.UserGroupDetail, error) { if groupId == "" { return nil, store.NewStatusError(store.EmptyParamsErr, fmt.Sprintf( "get usergroup missing some params, groupId is %s", groupId)) @@ -249,8 +202,8 @@ func (u *groupStore) GetGroup(groupId string) (*model.UserGroupDetail, error) { ` row := u.master.QueryRow(getSql, groupId) - group := &model.UserGroupDetail{ - UserGroup: &model.UserGroup{}, + group := &authcommon.UserGroupDetail{ + UserGroup: &authcommon.UserGroup{}, } var ( ctime, mtime int64 @@ -280,7 +233,7 @@ func (u *groupStore) GetGroup(groupId string) (*model.UserGroupDetail, error) { } // GetGroupByName 根据 owner、name 获取用户组 -func (u *groupStore) GetGroupByName(name, owner string) (*model.UserGroup, error) { +func (u *groupStore) GetGroupByName(name, owner string) (*authcommon.UserGroup, error) { if name == "" || owner == "" { return nil, store.NewStatusError(store.EmptyParamsErr, fmt.Sprintf( "get usergroup missing some params, name=%s, owner=%s", name, owner)) @@ -298,7 +251,7 @@ func (u *groupStore) GetGroupByName(name, owner string) (*model.UserGroup, error ` row := u.master.QueryRow(getSql, name, owner) - group := new(model.UserGroup) + group := new(authcommon.UserGroup) if err := row.Scan(&group.ID, &group.Name, &group.Owner, &group.Comment, &group.Token, &ctime, &mtime); err != nil { switch err { @@ -317,7 +270,7 @@ func (u *groupStore) GetGroupByName(name, owner string) (*model.UserGroup, error // GetGroups 根据不同的请求情况进行不同的用户组列表查询 func (u *groupStore) GetGroups(filters map[string]string, offset uint32, limit uint32) (uint32, - []*model.UserGroup, error) { + []*authcommon.UserGroup, error) { // 如果本次请求参数携带了 user_id,那么就是查询这个用户所关联的所有用户组 if _, ok := filters["user_id"]; ok { @@ -329,7 +282,7 @@ func (u *groupStore) GetGroups(filters map[string]string, offset uint32, limit u // listSimpleGroups 正常的用户组查询 func (u *groupStore) listSimpleGroups(filters map[string]string, offset uint32, limit uint32) (uint32, - []*model.UserGroup, error) { + []*authcommon.UserGroup, error) { query := make(map[string]string) if _, ok := filters["id"]; ok { @@ -388,7 +341,7 @@ func (u *groupStore) listSimpleGroups(filters map[string]string, offset uint32, // listGroupByUser 查询某个用户下所关联的用户组信息 func (u *groupStore) listGroupByUser(filters map[string]string, offset uint32, limit uint32) (uint32, - []*model.UserGroup, error) { + []*authcommon.UserGroup, error) { countSql := "SELECT COUNT(*) FROM user_group_relation ul LEFT JOIN user_group ug ON " + " ul.group_id = ug.id WHERE ug.flag = 0 " getSql := "SELECT ug.id, ug.name, ug.owner, ug.comment, ug.token, ug.token_enable, UNIX_TIMESTAMP(ug.ctime), " + @@ -438,7 +391,7 @@ func (u *groupStore) listGroupByUser(filters map[string]string, offset uint32, l // collectGroupsFromRows 查询用户组列表 func (u *groupStore) collectGroupsFromRows(handler QueryHandler, querySql string, - args []interface{}) ([]*model.UserGroup, error) { + args []interface{}) ([]*authcommon.UserGroup, error) { rows, err := u.master.Query(querySql, args...) if err != nil { log.Error("[Store][Group] list group", zap.String("query sql", querySql), zap.Any("args", args)) @@ -446,7 +399,7 @@ func (u *groupStore) collectGroupsFromRows(handler QueryHandler, querySql string } defer rows.Close() - groups := make([]*model.UserGroup, 0) + groups := make([]*authcommon.UserGroup, 0) for rows.Next() { group, err := fetchRown2UserGroup(rows) if err != nil { @@ -460,7 +413,7 @@ func (u *groupStore) collectGroupsFromRows(handler QueryHandler, querySql string } // GetGroupsForCache . -func (u *groupStore) GetGroupsForCache(mtime time.Time, firstUpdate bool) ([]*model.UserGroupDetail, error) { +func (u *groupStore) GetGroupsForCache(mtime time.Time, firstUpdate bool) ([]*authcommon.UserGroupDetail, error) { tx, err := u.slave.Begin() if err != nil { return nil, store.Error(err) @@ -482,9 +435,9 @@ func (u *groupStore) GetGroupsForCache(mtime time.Time, firstUpdate bool) ([]*mo } defer rows.Close() - ret := make([]*model.UserGroupDetail, 0) + ret := make([]*authcommon.UserGroupDetail, 0) for rows.Next() { - detail := &model.UserGroupDetail{ + detail := &authcommon.UserGroupDetail{ UserIds: make(map[string]struct{}, 0), } group, err := fetchRown2UserGroup(rows) @@ -576,10 +529,10 @@ func (u *groupStore) getGroupLinkUserIds(groupId string) (map[string]struct{}, e return ids, nil } -func fetchRown2UserGroup(rows *sql.Rows) (*model.UserGroup, error) { +func fetchRown2UserGroup(rows *sql.Rows) (*authcommon.UserGroup, error) { var ctime, mtime int64 var flag, tokenEnable int - group := new(model.UserGroup) + group := new(authcommon.UserGroup) if err := rows.Scan(&group.ID, &group.Name, &group.Owner, &group.Comment, &group.Token, &tokenEnable, &ctime, &mtime, &flag); err != nil { return nil, err diff --git a/store/mysql/role.go b/store/mysql/role.go new file mode 100644 index 000000000..97732eed9 --- /dev/null +++ b/store/mysql/role.go @@ -0,0 +1,260 @@ +/** + * Tencent is pleased to support the open source community by making Polaris available. + * + * Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. + * + * Licensed under the BSD 3-Clause License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://opensource.org/licenses/BSD-3-Clause + * + * Unless required by applicable law or agreed to in writing, software distributed + * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR + * CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package sqldb + +import ( + "encoding/json" + "time" + + authcommon "github.com/polarismesh/polaris/common/model/auth" + "github.com/polarismesh/polaris/common/utils" + "github.com/polarismesh/polaris/store" + "go.uber.org/zap" +) + +type roleStore struct { + master *BaseDB + slave *BaseDB +} + +// AddRole Add a role +func (s *roleStore) AddRole(role *authcommon.Role) error { + if role.ID == "" || role.Name == "" { + return store.NewStatusError(store.EmptyParamsErr, "role id or name is empty") + } + err := s.master.processWithTransaction("add_role", func(tx *BaseTx) error { + if _, err := tx.Exec("DELETE FROM auth_role WHERE id = ? AND flag = 1", role.ID); err != nil { + log.Error("[store][role] delete invalid role", zap.String("name", role.Name), zap.Error(err)) + return err + } + addSql := ` +INSERT INTO auth_role (id, name, owner, source, role_type + , comment, flag, metadata, ctime, mtime) +VALUES (?, ?, ?, ?, ? + , ?, 0, ?, sysdate(), sysdate()) + ` + args := []interface{}{role.ID, role.Name, role.Owner, role.Source, role.Type, role.Comment, utils.MustJson(role.Metadata)} + if _, err := tx.Exec(addSql, args...); err != nil { + log.Error("[store][role] add role main info", zap.String("name", role.Name), zap.Error(err)) + return err + } + + if err := s.savePrincipals(tx, role); err != nil { + log.Error("[store][role] save role principal info", zap.String("name", role.Name), zap.Error(err)) + return err + } + return nil + }) + return store.Error(err) +} + +func (s *roleStore) savePrincipals(tx *BaseTx, role *authcommon.Role) error { + if _, err := tx.Exec("DELETE FROM auth_role_principal WHERE id = ?", role.ID); err != nil { + log.Error("[store][role] clean role principal info", zap.String("name", role.Name), zap.Error(err)) + return err + } + + insertTpl := "INSERT INTO auth_role_principal(role_id, principal_id, principal_role) VALUES (?, ?, ?)" + + for i := range role.Users { + args := []interface{}{role.ID, role.Users[i].ID, authcommon.PrincipalUser} + if _, err := tx.Exec(insertTpl, args...); err != nil { + return err + } + } + for i := range role.UserGroups { + args := []interface{}{role.ID, role.UserGroups[i].ID, authcommon.PrincipalGroup} + if _, err := tx.Exec(insertTpl, args...); err != nil { + return err + } + } + return nil +} + +// UpdateRole Update a role +func (s *roleStore) UpdateRole(role *authcommon.Role) error { + if role.ID == "" { + return store.NewStatusError(store.EmptyParamsErr, "role id is empty") + } + err := s.master.processWithTransaction("update_role", func(tx *BaseTx) error { + updateSql := ` +UPDATE auth_role +SET source = ?, role_type = ?, comment = ?, metadata = ?, mtime = sysdate() +WHERE id = ? + ` + args := []interface{}{role.Source, role.Type, role.Comment, utils.MustJson(role.Metadata), role.ID} + if _, err := tx.Exec(updateSql, args...); err != nil { + log.Error("[store][role] update role main info", zap.String("name", role.Name), zap.Error(err)) + return err + } + + if err := s.savePrincipals(tx, role); err != nil { + log.Error("[store][role] save role principal info", zap.String("name", role.Name), zap.Error(err)) + return err + } + return nil + }) + return store.Error(err) +} + +// DeleteRole Delete a role +func (s *roleStore) DeleteRole(role *authcommon.Role) error { + if role.ID == "" { + return store.NewStatusError(store.EmptyParamsErr, "role id is empty") + } + err := s.master.processWithTransaction("delete_role", func(tx *BaseTx) error { + if _, err := tx.Exec("UPDATE auth_role SET flag = 1 WHERE id = ?", role.ID); err != nil { + log.Error("[store][role] delete role", zap.String("name", role.Name), zap.Error(err)) + return err + } + return nil + }) + return store.Error(err) +} + +// CleanPrincipalRoles clean principal roles +func (s *roleStore) CleanPrincipalRoles(tx store.Tx, p *authcommon.Principal) error { + dbTx := tx.GetDelegateTx().(*BaseTx) + listSql := "SELECT role_id FROM auth_role_principal WHERE principal_id = ? AND principal_role = ?" + rows, err := dbTx.Query(listSql, p.PrincipalID, p.PrincipalType) + if err != nil { + log.Error("[store][role] list principal all roles", zap.String("principal", p.String()), zap.Error(err)) + return err + } + defer func() { + _ = rows.Close() + }() + + for rows.Next() { + var roleId string + if err := rows.Scan(&roleId); err != nil { + log.Error("[store][role] fetch one record principal role", zap.String("principal", p.String()), + zap.String("type", p.PrincipalType.String()), zap.Error(err)) + return err + } + + if _, err := dbTx.Exec("UPDATE auth_role SET mtime = sysdate() WHERE id = ?", roleId); err != nil { + log.Error("[store][role] update role when clean principal role", zap.String("id", roleId), + zap.String("principal", p.String()), zap.Error(err)) + return err + } + } + + if _, err := dbTx.Exec("DELETE FROM auth_role_principal WHERE principal_id = ? AND principal_role = ?", + p.PrincipalID, p.PrincipalType); err != nil { + log.Error("[store][role] clean principal all roles", zap.String("principal", p.String()), zap.Error(err)) + return store.Error(err) + } + return nil +} + +// GetRole get more role for cache update +func (s *roleStore) GetMoreRoles(firstUpdate bool, mtime time.Time) ([]*authcommon.Role, error) { + tx, err := s.slave.Begin() + if err != nil { + return nil, store.Error(err) + } + + defer func() { _ = tx.Commit() }() + + args := make([]interface{}, 0) + querySql := "SELECT id, name, owner, source, role_type, comment, flag, metadata, UNIX_TIMESTAMP(ctime), " + + " UNIX_TIMESTAMP(mtime) FROM auth_role " + if !firstUpdate { + querySql += " WHERE mtime >= FROM_UNIXTIME(?)" + args = append(args, timeToTimestamp(mtime)) + } else { + querySql += " WHERE flag = 0" + } + + rows, err := tx.Query(querySql, args...) + if err != nil { + log.Error("[store][role] get more role for cache", zap.String("query sql", querySql), + zap.Any("args", args), zap.Error(err)) + return nil, store.Error(err) + } + defer func() { + _ = rows.Close() + }() + + roles := make([]*authcommon.Role, 0, 32) + for rows.Next() { + var ( + ctime, mtime int64 + flag int16 + metadata string + ) + ret := &authcommon.Role{ + Metadata: map[string]string{}, + Users: make([]*authcommon.User, 0, 4), + UserGroups: make([]*authcommon.UserGroup, 0, 4), + } + + if err := rows.Scan(&ret.ID, &ret.Name, &ret.Owner, &ret.Source, &ret.Type, &ret.Comment, + &flag, &metadata, &ctime, &mtime); err != nil { + log.Error("[store][role] fetch one record role info", zap.Error(err)) + return nil, store.Error(err) + } + + ret.CreateTime = time.Unix(ctime, 0) + ret.ModifyTime = time.Unix(mtime, 0) + ret.Valid = flag == 0 + _ = json.Unmarshal([]byte(metadata), &ret.Metadata) + + if err := s.fetchRolePrincipals(tx, ret); err != nil { + return nil, store.Error(err) + } + + // fetch link user or groups + roles = append(roles, ret) + } + return roles, nil +} + +func (s *roleStore) fetchRolePrincipals(tx *BaseTx, role *authcommon.Role) error { + rows, err := tx.Query("SELECT role_id, principal_id, principal_role FROM auth_role_principal WHERE rold_id = ?", role.ID) + if err != nil { + log.Error("[store][role] fetch role principals", zap.String("name", role.Name), zap.Error(err)) + return store.Error(err) + } + defer func() { + _ = rows.Close() + }() + + for rows.Next() { + var ( + roleID, principalID string + principalRole int + ) + if err := rows.Scan(&roleID, &principalID, &principalRole); err != nil { + log.Error("[store][role] fetch one record role principal", zap.String("name", role.Name), zap.Error(err)) + return store.Error(err) + } + + if principalRole == int(authcommon.PrincipalUser) { + role.Users = append(role.Users, &authcommon.User{ + ID: principalID, + }) + } else { + role.UserGroups = append(role.UserGroups, &authcommon.UserGroup{ + ID: principalID, + }) + } + } + return nil +} diff --git a/store/mysql/scripts/delta/v1_18_1-v1_18_2.sql b/store/mysql/scripts/delta/v1_18_1-v1_18_2.sql new file mode 100644 index 000000000..49041073d --- /dev/null +++ b/store/mysql/scripts/delta/v1_18_1-v1_18_2.sql @@ -0,0 +1,43 @@ +/* 角色数据 */ +CREATE TABLE + `auth_role` ( + `id` VARCHAR(128) NOT NULL COMMENT 'role id', + `name` VARCHAR(100) NOT NULL COMMENT 'role name', + `owner` VARCHAR(128) NOT NULL COMMENT 'Main account ID', + `source` VARCHAR(32) NOT NULL COMMENT 'role source', + `role_type` INT NOT NULL DEFAULT 20 COMMENT 'role type', + `comment` VARCHAR(255) NOT NULL COMMENT 'describe', + `flag` TINYINT (4) NOT NULL DEFAULT '0' COMMENT 'Whether the rules are valid, 0 is valid, 1 is invalid, it is deleted', + `ctime` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP COMMENT 'Create time', + `mtime` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'Last updated time', + `metadata` TEXT COMMENT 'user metadata', + PRIMARY KEY (`id`), + UNIQUE KEY (`name`, `owner`), + KEY `owner` (`owner`), + KEY `mtime` (`mtime`) + ) ENGINE = InnoDB; + +/* 角色关联用户/用户组关系表 */ +CREATE TABLE + `auth_role_principal` ( + `role_id` VARCHAR(128) NOT NULL COMMENT 'role id', + `principal_id` VARCHAR(128) NOT NULL COMMENT 'principal id', + `principal_role` INT NOT NULL COMMENT 'PRINCIPAL type, 1 is User, 2 is Group', + PRIMARY KEY (`role_id`, `principal_id`, `principal_role`) + ) ENGINE = InnoDB; + +/* 鉴权策略中的资源标签关联信息 */ +CRAETE TABLE `auth_strategy_label` ( + `strategy_id` VARCHAR(128) NOT NULL COMMENT 'strategy id', + `key` VARCHAR(128) NOT NULL COMMENT 'tag key', + `value` TEXT NOT NULL COMMENT 'tag value', + `compare_type` VARCHAR(128) NOT NULL COMMENT 'tag kv compare func', + PRIMARY KEY (`strategy_id`, `key`) +) ENGINE = InnoDB; + +/* 鉴权策略中的资源标签关联信息 */ +CRAETE TABLE `auth_strategy_function` ( + `strategy_id` VARCHAR(128) NOT NULL COMMENT 'strategy id', + `function` VARCHAR(256) NOT NULL COMMENT 'server provider function name', + PRIMARY KEY (`strategy_id`, `function`) +) ENGINE = InnoDB; \ No newline at end of file diff --git a/store/mysql/scripts/polaris_server.sql b/store/mysql/scripts/polaris_server.sql index 24e0234a8..c49da920d 100644 --- a/store/mysql/scripts/polaris_server.sql +++ b/store/mysql/scripts/polaris_server.sql @@ -23,7 +23,9 @@ SET -- -- Database: `polaris_server` -- -CREATE DATABASE IF NOT EXISTS `polaris_server` DEFAULT CHARACTER SET utf8mb4 COLLATE utf8mb4_bin; +CREATE DATABASE IF NOT EXISTS `polaris_server` DEFAULT CHARACTER +SET + utf8mb4 COLLATE utf8mb4_bin; USE `polaris_server`; @@ -37,20 +39,20 @@ CREATE TABLE `service_id` VARCHAR(32) NOT NULL COMMENT 'Service ID', `vpc_id` VARCHAR(64) DEFAULT NULL COMMENT 'VPC ID', `host` VARCHAR(128) NOT NULL COMMENT 'instance Host Information', - `port` INT(11) NOT NULL COMMENT 'instance port information', + `port` INT (11) NOT NULL COMMENT 'instance port information', `protocol` VARCHAR(32) DEFAULT NULL COMMENT 'Listening protocols for corresponding ports, such as TPC, UDP, GRPC, DUBBO, etc.', `version` VARCHAR(32) DEFAULT NULL COMMENT 'The version of the instance can be used for version routing', - `health_status` TINYINT(4) NOT NULL DEFAULT '1' COMMENT 'The health status of the instance, 1 is health, 0 is unhealthy', - `isolate` TINYINT(4) NOT NULL DEFAULT '0' COMMENT 'Example isolation status flag, 0 is not isolated, 1 is isolated', - `weight` SMALLINT(6) NOT NULL DEFAULT '100' COMMENT 'The weight of the instance is mainly used for LoadBalance, default is 100', - `enable_health_check` TINYINT(4) NOT NULL DEFAULT '0' COMMENT 'Whether to open a heartbeat on an instance, check the logic, 0 is not open, 1 is open', + `health_status` TINYINT (4) NOT NULL DEFAULT '1' COMMENT 'The health status of the instance, 1 is health, 0 is unhealthy', + `isolate` TINYINT (4) NOT NULL DEFAULT '0' COMMENT 'Example isolation status flag, 0 is not isolated, 1 is isolated', + `weight` SMALLINT (6) NOT NULL DEFAULT '100' COMMENT 'The weight of the instance is mainly used for LoadBalance, default is 100', + `enable_health_check` TINYINT (4) NOT NULL DEFAULT '0' COMMENT 'Whether to open a heartbeat on an instance, check the logic, 0 is not open, 1 is open', `logic_set` VARCHAR(128) DEFAULT NULL COMMENT 'Example logic packet information', `cmdb_region` VARCHAR(128) DEFAULT NULL COMMENT 'The region information of the instance is mainly used to close the route', `cmdb_zone` VARCHAR(128) DEFAULT NULL COMMENT 'The ZONE information of the instance is mainly used to close the route.', `cmdb_idc` VARCHAR(128) DEFAULT NULL COMMENT 'The IDC information of the instance is mainly used to close the route', - `priority` TINYINT(4) NOT NULL DEFAULT '0' COMMENT 'Example priority, currently useless', + `priority` TINYINT (4) NOT NULL DEFAULT '0' COMMENT 'Example priority, currently useless', `revision` VARCHAR(32) NOT NULL COMMENT 'Instance version information', - `flag` TINYINT(4) NOT NULL DEFAULT '0' COMMENT 'Logic delete flag, 0 means visible, 1 means that it has been logically deleted', + `flag` TINYINT (4) NOT NULL DEFAULT '0' COMMENT 'Logic delete flag, 0 means visible, 1 means that it has been logically deleted', `ctime` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP COMMENT 'Create time', `mtime` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'Last updated time', PRIMARY KEY (`id`), @@ -66,8 +68,8 @@ CREATE TABLE CREATE TABLE `health_check` ( `id` VARCHAR(128) NOT NULL COMMENT 'Instance ID', - `type` TINYINT(4) NOT NULL DEFAULT '0' COMMENT 'Instance health check type', - `ttl` INT(11) NOT NULL COMMENT 'TTL time jumping', + `type` TINYINT (4) NOT NULL DEFAULT '0' COMMENT 'Instance health check type', + `ttl` INT (11) NOT NULL COMMENT 'TTL time jumping', PRIMARY KEY (`id`) /* CONSTRAINT `health_check_ibfk_1` FOREIGN KEY (`id`) REFERENCES `instance` (`id`) ON DELETE CASCADE ON UPDATE CASCADE */ ) ENGINE = InnoDB; @@ -98,7 +100,7 @@ CREATE TABLE `comment` VARCHAR(1024) DEFAULT NULL COMMENT 'Description of namespace', `token` VARCHAR(64) NOT NULL COMMENT 'TOKEN named space for write operation check', `owner` VARCHAR(1024) NOT NULL COMMENT 'Responsible for named space Owner', - `flag` TINYINT(4) NOT NULL DEFAULT '0' COMMENT 'Logic delete flag, 0 means visible, 1 means that it has been logically deleted', + `flag` TINYINT (4) NOT NULL DEFAULT '0' COMMENT 'Logic delete flag, 0 means visible, 1 means that it has been logically deleted', `ctime` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP COMMENT 'Create time', `mtime` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'Last updated time', `service_export_to` TEXT COMMENT 'namespace metadata', @@ -149,7 +151,7 @@ CREATE TABLE `in_bounds` TEXT COMMENT 'Service is routing rules', `out_bounds` TEXT COMMENT 'Service main routing rules', `revision` VARCHAR(40) NOT NULL COMMENT 'Routing rule version', - `flag` TINYINT(4) NOT NULL DEFAULT '0' COMMENT 'Logic delete flag, 0 means visible, 1 means that it has been logically deleted', + `flag` TINYINT (4) NOT NULL DEFAULT '0' COMMENT 'Logic delete flag, 0 means visible, 1 means that it has been logically deleted', `ctime` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP COMMENT 'Create time', `mtime` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'Last updated time', PRIMARY KEY (`id`), @@ -164,17 +166,18 @@ CREATE TABLE `ratelimit_config` ( `id` VARCHAR(32) NOT NULL COMMENT 'ratelimit rule ID', `name` VARCHAR(64) NOT NULL COMMENT 'ratelimt rule name', - `disable` TINYINT(4) NOT NULL DEFAULT '0' COMMENT 'ratelimit disable', + `disable` TINYINT (4) NOT NULL DEFAULT '0' COMMENT 'ratelimit disable', `service_id` VARCHAR(32) NOT NULL COMMENT 'Service ID', `method` VARCHAR(512) NOT NULL COMMENT 'ratelimit method', `labels` TEXT NOT NULL COMMENT 'Conductive flow for a specific label', - `priority` SMALLINT(6) NOT NULL DEFAULT '0' COMMENT 'ratelimit rule priority', + `priority` SMALLINT (6) NOT NULL DEFAULT '0' COMMENT 'ratelimit rule priority', `rule` TEXT NOT NULL COMMENT 'Current limiting rules', `revision` VARCHAR(32) NOT NULL COMMENT 'Limiting version', - `flag` TINYINT(4) NOT NULL DEFAULT '0' COMMENT 'Logic delete flag, 0 means visible, 1 means that it has been logically deleted', + `flag` TINYINT (4) NOT NULL DEFAULT '0' COMMENT 'Logic delete flag, 0 means visible, 1 means that it has been logically deleted', `ctime` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP COMMENT 'Create time', `mtime` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'Last updated time', `etime` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP COMMENT 'RateLimit rule enable time', + `metadata` TEXT COMMENT 'ratelimit rule metadata', PRIMARY KEY (`id`), KEY `mtime` (`mtime`), KEY `service_id` (`service_id`) @@ -213,7 +216,7 @@ CREATE TABLE `token` VARCHAR(2048) NOT NULL COMMENT 'Service token, used to handle all the services involved in the service', `revision` VARCHAR(32) NOT NULL COMMENT 'Service version information', `owner` VARCHAR(1024) NOT NULL COMMENT 'Owner information belonging to the service', - `flag` TINYINT(4) NOT NULL DEFAULT '0' COMMENT 'Logic delete flag, 0 means visible, 1 means that it has been logically deleted', + `flag` TINYINT (4) NOT NULL DEFAULT '0' COMMENT 'Logic delete flag, 0 means visible, 1 means that it has been logically deleted', `reference` VARCHAR(32) DEFAULT NULL COMMENT 'Service alias, what is the actual service name that the service is actually pointed out?', `refer_filter` VARCHAR(1024) DEFAULT NULL COMMENT '', `platform_id` VARCHAR(32) DEFAULT '' COMMENT 'The platform ID to which the service belongs', @@ -310,9 +313,10 @@ CREATE TABLE `token` VARCHAR(32) NOT NULL COMMENT 'Token, which is fucking, mainly for writing operation check', `owner` VARCHAR(1024) NOT NULL COMMENT 'Melting rule Owner information', `revision` VARCHAR(32) NOT NULL COMMENT 'Melt rule version information', - `flag` TINYINT(4) NOT NULL DEFAULT '0' COMMENT 'Logic delete flag, 0 means visible, 1 means that it has been logically deleted', + `flag` TINYINT (4) NOT NULL DEFAULT '0' COMMENT 'Logic delete flag, 0 means visible, 1 means that it has been logically deleted', `ctime` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP COMMENT 'Create time', `mtime` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'Last updated time', + `metadata` TEXT COMMENT 'circuit_breaker rule metadata', PRIMARY KEY (`id`, `version`), UNIQUE KEY `name` (`name`, `namespace`, `version`), KEY `mtime` (`mtime`) @@ -327,7 +331,7 @@ CREATE TABLE `service_id` VARCHAR(32) NOT NULL COMMENT 'Service ID', `rule_id` VARCHAR(97) NOT NULL COMMENT 'Melting rule ID', `rule_version` VARCHAR(32) NOT NULL COMMENT 'Melting rule version', - `flag` TINYINT(4) NOT NULL DEFAULT '0' COMMENT 'Logic delete flag, 0 means visible, 1 means that it has been logically deleted', + `flag` TINYINT (4) NOT NULL DEFAULT '0' COMMENT 'Logic delete flag, 0 means visible, 1 means that it has been logically deleted', `ctime` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP COMMENT 'Create time', `mtime` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'Last updated time', PRIMARY KEY (`service_id`), @@ -342,13 +346,13 @@ CREATE TABLE -- CREATE TABLE `t_ip_config` ( - `Fip` INT(10) UNSIGNED NOT NULL COMMENT 'Machine IP', - `FareaId` INT(10) UNSIGNED NOT NULL COMMENT 'Area number', - `FcityId` INT(10) UNSIGNED NOT NULL COMMENT 'City number', - `FidcId` INT(10) UNSIGNED NOT NULL COMMENT 'IDC number', - `Fflag` TINYINT(4) DEFAULT '0', + `Fip` INT (10) UNSIGNED NOT NULL COMMENT 'Machine IP', + `FareaId` INT (10) UNSIGNED NOT NULL COMMENT 'Area number', + `FcityId` INT (10) UNSIGNED NOT NULL COMMENT 'City number', + `FidcId` INT (10) UNSIGNED NOT NULL COMMENT 'IDC number', + `Fflag` TINYINT (4) DEFAULT '0', `Fstamp` DATETIME NOT NULL, - `Fflow` INT(10) UNSIGNED NOT NULL, + `Fflow` INT (10) UNSIGNED NOT NULL, PRIMARY KEY (`Fip`), KEY `idx_Fflow` (`Fflow`) ) ENGINE = InnoDB; @@ -359,12 +363,12 @@ CREATE TABLE -- CREATE TABLE `t_policy` ( - `FmodId` INT(10) UNSIGNED NOT NULL, - `Fdiv` INT(10) UNSIGNED NOT NULL, - `Fmod` INT(10) UNSIGNED NOT NULL, - `Fflag` TINYINT(4) DEFAULT '0', + `FmodId` INT (10) UNSIGNED NOT NULL, + `Fdiv` INT (10) UNSIGNED NOT NULL, + `Fmod` INT (10) UNSIGNED NOT NULL, + `Fflag` TINYINT (4) DEFAULT '0', `Fstamp` DATETIME NOT NULL, - `Fflow` INT(10) UNSIGNED NOT NULL, + `Fflow` INT (10) UNSIGNED NOT NULL, PRIMARY KEY (`FmodId`) ) ENGINE = InnoDB; @@ -374,13 +378,13 @@ CREATE TABLE -- CREATE TABLE `t_route` ( - `Fip` INT(10) UNSIGNED NOT NULL, - `FmodId` INT(10) UNSIGNED NOT NULL, - `FcmdId` INT(10) UNSIGNED NOT NULL, + `Fip` INT (10) UNSIGNED NOT NULL, + `FmodId` INT (10) UNSIGNED NOT NULL, + `FcmdId` INT (10) UNSIGNED NOT NULL, `FsetId` VARCHAR(32) NOT NULL, - `Fflag` TINYINT(4) DEFAULT '0', + `Fflag` TINYINT (4) DEFAULT '0', `Fstamp` DATETIME NOT NULL, - `Fflow` INT(10) UNSIGNED NOT NULL, + `Fflow` INT (10) UNSIGNED NOT NULL, PRIMARY KEY (`Fip`, `FmodId`, `FcmdId`), KEY `Fflow` (`Fflow`), KEY `idx1` (`FmodId`, `FcmdId`, `FsetId`) @@ -392,13 +396,13 @@ CREATE TABLE -- CREATE TABLE `t_section` ( - `FmodId` INT(10) UNSIGNED NOT NULL, - `Ffrom` INT(10) UNSIGNED NOT NULL, - `Fto` INT(10) UNSIGNED NOT NULL, - `Fxid` INT(10) UNSIGNED NOT NULL, - `Fflag` TINYINT(4) DEFAULT '0', + `FmodId` INT (10) UNSIGNED NOT NULL, + `Ffrom` INT (10) UNSIGNED NOT NULL, + `Fto` INT (10) UNSIGNED NOT NULL, + `Fxid` INT (10) UNSIGNED NOT NULL, + `Fflag` TINYINT (4) DEFAULT '0', `Fstamp` DATETIME NOT NULL, - `Fflow` INT(10) UNSIGNED NOT NULL, + `Fflow` INT (10) UNSIGNED NOT NULL, PRIMARY KEY (`FmodId`, `Ffrom`, `Fto`) ) ENGINE = InnoDB; @@ -408,7 +412,7 @@ CREATE TABLE -- CREATE TABLE `start_lock` ( - `lock_id` INT(11) NOT NULL COMMENT '锁序号', + `lock_id` INT (11) NOT NULL COMMENT '锁序号', `lock_key` VARCHAR(32) NOT NULL COMMENT 'Lock name', `server` VARCHAR(32) NOT NULL COMMENT 'SERVER holding launch lock', `mtime` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'Update time', @@ -429,9 +433,9 @@ VALUES -- CREATE TABLE `cl5_module` ( - `module_id` INT(11) NOT NULL COMMENT 'Module ID', - `interface_id` INT(11) NOT NULL COMMENT 'Interface ID', - `range_num` INT(11) NOT NULL, + `module_id` INT (11) NOT NULL COMMENT 'Module ID', + `interface_id` INT (11) NOT NULL COMMENT 'Interface ID', + `range_num` INT (11) NOT NULL, `mtime` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'Last updated time', PRIMARY KEY (`module_id`) ) ENGINE = InnoDB COMMENT = 'To generate SID'; @@ -457,7 +461,7 @@ CREATE TABLE `content` LONGTEXT NOT NULL COMMENT '文件内容', `format` VARCHAR(16) DEFAULT 'text' COMMENT '文件格式,枚举值', `comment` VARCHAR(512) DEFAULT NULL COMMENT '备注信息', - `flag` TINYINT(4) NOT NULL DEFAULT '0' COMMENT '软删除标记位', + `flag` TINYINT (4) NOT NULL DEFAULT '0' COMMENT '软删除标记位', `create_time` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP COMMENT '创建时间', `create_by` VARCHAR(32) DEFAULT NULL COMMENT '创建人', `modify_time` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT '最后更新时间', @@ -484,7 +488,7 @@ CREATE TABLE `business` VARCHAR(64) DEFAULT NULL COMMENT 'Service business information', `department` VARCHAR(1024) DEFAULT NULL COMMENT 'Service department information', `metadata` TEXT COMMENT '配置分组标签', - `flag` TINYINT(4) NOT NULL DEFAULT '0' COMMENT '是否被删除', + `flag` TINYINT (4) NOT NULL DEFAULT '0' COMMENT '是否被删除', PRIMARY KEY (`id`), UNIQUE KEY `uk_name` (`namespace`, `name`) ) ENGINE = InnoDB AUTO_INCREMENT = 1 COMMENT = '配置文件组表'; @@ -504,14 +508,14 @@ CREATE TABLE `content` LONGTEXT NOT NULL COMMENT '文件内容', `comment` VARCHAR(512) DEFAULT NULL COMMENT '备注信息', `md5` VARCHAR(128) NOT NULL COMMENT 'content的md5值', - `version` BIGINT(11) NOT NULL COMMENT '版本号,每次发布自增1', - `flag` TINYINT(4) NOT NULL DEFAULT '0' COMMENT '是否被删除', + `version` BIGINT (11) NOT NULL COMMENT '版本号,每次发布自增1', + `flag` TINYINT (4) NOT NULL DEFAULT '0' COMMENT '是否被删除', `create_time` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP COMMENT '创建时间', `create_by` VARCHAR(32) DEFAULT NULL COMMENT '创建人', `modify_time` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT '最后更新时间', `modify_by` VARCHAR(32) DEFAULT NULL COMMENT '最后更新人', `tags` TEXT COMMENT '文件标签', - `active` TINYINT(4) NOT NULL DEFAULT '0' COMMENT '是否处于使用中', + `active` TINYINT (4) NOT NULL DEFAULT '0' COMMENT '是否处于使用中', `description` VARCHAR(512) DEFAULT NULL COMMENT '发布描述', `release_type` VARCHAR(25) NOT NULL DEFAULT '' COMMENT '文件类型:"":全量 gray:灰度', PRIMARY KEY (`id`), @@ -541,7 +545,7 @@ CREATE TABLE `modify_time` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT '最后更新时间', `modify_by` VARCHAR(32) DEFAULT NULL COMMENT '最后更新人', `tags` TEXT COMMENT '文件标签', - `version` BIGINT(11) COMMENT '版本号,每次发布自增1', + `version` BIGINT (11) COMMENT '版本号,每次发布自增1', `reason` VARCHAR(3000) DEFAULT '' COMMENT '原因', `description` VARCHAR(512) DEFAULT NULL COMMENT '发布描述', PRIMARY KEY (`id`), @@ -579,12 +583,13 @@ CREATE TABLE `mobile` VARCHAR(12) NOT NULL DEFAULT '' COMMENT 'Account mobile phone number', `email` VARCHAR(64) NOT NULL DEFAULT '' COMMENT 'Account mailbox', `token` VARCHAR(255) NOT NULL COMMENT 'The token information owned by the account can be used for SDK access authentication', - `token_enable` TINYINT(4) NOT NULL DEFAULT 1, + `token_enable` TINYINT (4) NOT NULL DEFAULT 1, `user_type` INT NOT NULL DEFAULT 20 COMMENT 'Account type, 0 is the admin super account, 20 is the primary account, 50 for the child account', `comment` VARCHAR(255) NOT NULL COMMENT 'describe', - `flag` TINYINT(4) NOT NULL DEFAULT '0' COMMENT 'Whether the rules are valid, 0 is valid, 1 is invalid, it is deleted', + `flag` TINYINT (4) NOT NULL DEFAULT '0' COMMENT 'Whether the rules are valid, 0 is valid, 1 is invalid, it is deleted', `ctime` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP COMMENT 'Create time', `mtime` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'Last updated time', + `metadata` TEXT COMMENT 'user metadata', PRIMARY KEY (`id`), UNIQUE KEY (`name`, `owner`), KEY `owner` (`owner`), @@ -598,10 +603,11 @@ CREATE TABLE `owner` VARCHAR(128) NOT NULL COMMENT 'The main account ID of the user group', `token` VARCHAR(255) NOT NULL COMMENT 'TOKEN information of this user group', `comment` VARCHAR(255) NOT NULL COMMENT 'Description', - `token_enable` TINYINT(4) NOT NULL DEFAULT 1, - `flag` TINYINT(4) NOT NULL DEFAULT '0' COMMENT 'Whether the rules are valid, 0 is valid, 1 is invalid, it is deleted', + `token_enable` TINYINT (4) NOT NULL DEFAULT 1, + `flag` TINYINT (4) NOT NULL DEFAULT '0' COMMENT 'Whether the rules are valid, 0 is valid, 1 is invalid, it is deleted', `ctime` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP COMMENT 'Create time', `mtime` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'Last updated time', + `metadata` TEXT COMMENT 'user_group metadata', PRIMARY KEY (`id`), UNIQUE KEY (`name`, `owner`), KEY `owner` (`owner`), @@ -625,11 +631,13 @@ CREATE TABLE `action` VARCHAR(32) NOT NULL COMMENT 'Read and write permission for this policy, only_read = 0, read_write = 1', `owner` VARCHAR(128) NOT NULL COMMENT 'The account ID to which this policy is', `comment` VARCHAR(255) NOT NULL COMMENT 'describe', - `default` TINYINT(4) NOT NULL DEFAULT '0', + `default` TINYINT (4) NOT NULL DEFAULT '0', + `source` VARCHAR(32) NOT NULL COMMENT 'policy rule source', `revision` VARCHAR(128) NOT NULL COMMENT 'Authentication rule version', - `flag` TINYINT(4) NOT NULL DEFAULT '0' COMMENT 'Whether the rules are valid, 0 is valid, 1 is invalid, it is deleted', + `flag` TINYINT (4) NOT NULL DEFAULT '0' COMMENT 'Whether the rules are valid, 0 is valid, 1 is invalid, it is deleted', `ctime` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP COMMENT 'Create time', `mtime` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'Last updated time', + `metadata` TEXT COMMENT 'policy rule metadata', PRIMARY KEY (`id`), UNIQUE KEY (`name`, `owner`), KEY `owner` (`owner`), @@ -640,7 +648,7 @@ CREATE TABLE `auth_principal` ( `strategy_id` VARCHAR(128) NOT NULL COMMENT 'Strategy ID', `principal_id` VARCHAR(128) NOT NULL COMMENT 'Principal ID', - `principal_role` INT NOT NULL COMMENT 'PRINCIPAL type, 1 is User, 2 is Group', + `principal_role` INT NOT NULL COMMENT 'PRINCIPAL type, 1 is User, 2 is Group, 3 is Role', PRIMARY KEY (`strategy_id`, `principal_id`, `principal_role`) ) ENGINE = InnoDB; @@ -655,6 +663,50 @@ CREATE TABLE KEY `mtime` (`mtime`) ) ENGINE = InnoDB; +/* 角色数据 */ +CREATE TABLE + `auth_role` ( + `id` VARCHAR(128) NOT NULL COMMENT 'role id', + `name` VARCHAR(100) NOT NULL COMMENT 'role name', + `owner` VARCHAR(128) NOT NULL COMMENT 'Main account ID', + `source` VARCHAR(32) NOT NULL COMMENT 'role source', + `role_type` INT NOT NULL DEFAULT 20 COMMENT 'role type', + `comment` VARCHAR(255) NOT NULL COMMENT 'describe', + `flag` TINYINT (4) NOT NULL DEFAULT '0' COMMENT 'Whether the rules are valid, 0 is valid, 1 is invalid, it is deleted', + `ctime` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP COMMENT 'Create time', + `mtime` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'Last updated time', + `metadata` TEXT COMMENT 'user metadata', + PRIMARY KEY (`id`), + UNIQUE KEY (`name`, `owner`), + KEY `owner` (`owner`), + KEY `mtime` (`mtime`) + ) ENGINE = InnoDB; + +/* 角色关联用户/用户组关系表 */ +CREATE TABLE + `auth_role_principal` ( + `role_id` VARCHAR(128) NOT NULL COMMENT 'role id', + `principal_id` VARCHAR(128) NOT NULL COMMENT 'principal id', + `principal_role` INT NOT NULL COMMENT 'PRINCIPAL type, 1 is User, 2 is Group', + PRIMARY KEY (`role_id`, `principal_id`, `principal_role`) + ) ENGINE = InnoDB; + +/* 鉴权策略中的资源标签关联信息 */ +CRAETE TABLE `auth_strategy_label` ( + `strategy_id` VARCHAR(128) NOT NULL COMMENT 'strategy id', + `key` VARCHAR(128) NOT NULL COMMENT 'tag key', + `value` TEXT NOT NULL COMMENT 'tag value', + `compare_type` VARCHAR(128) NOT NULL COMMENT 'tag kv compare func', + PRIMARY KEY (`strategy_id`, `key`) +) ENGINE = InnoDB; + +/* 鉴权策略中的资源标签关联信息 */ +CRAETE TABLE `auth_strategy_function` ( + `strategy_id` VARCHAR(128) NOT NULL COMMENT 'strategy id', + `function` VARCHAR(256) NOT NULL COMMENT 'server provider function name', + PRIMARY KEY (`strategy_id`, `function`) +) ENGINE = InnoDB; + -- Create a default master account, password is Polarismesh @ 2021 INSERT INTO `user` ( @@ -709,8 +761,8 @@ VALUES 1, 'fbca9bfa04ae4ead86e1ecf5811e32a9', 0, - SYSDATE(), - SYSDATE() + SYSDATE (), + SYSDATE () ); -- Sport rules inserted into Polaris-Admin to access @@ -727,22 +779,22 @@ VALUES 'fbca9bfa04ae4ead86e1ecf5811e32a9', 0, '*', - SYSDATE(), - SYSDATE() + SYSDATE (), + SYSDATE () ), ( 'fbca9bfa04ae4ead86e1ecf5811e32a9', 1, '*', - SYSDATE(), - SYSDATE() + SYSDATE (), + SYSDATE () ), ( 'fbca9bfa04ae4ead86e1ecf5811e32a9', 2, '*', - SYSDATE(), - SYSDATE() + SYSDATE (), + SYSDATE () ); -- Insert permission policies and association relationships for Polaris-Admin accounts @@ -763,7 +815,7 @@ CREATE TABLE `region` VARCHAR(128) DEFAULT NULL COMMENT 'region info for client', `zone` VARCHAR(128) DEFAULT NULL COMMENT 'zone info for client', `campus` VARCHAR(128) DEFAULT NULL COMMENT 'campus info for client', - `flag` TINYINT(4) NOT NULL DEFAULT '0' COMMENT '0 is valid, 1 is invalid(deleted)', + `flag` TINYINT (4) NOT NULL DEFAULT '0' COMMENT '0 is valid, 1 is invalid(deleted)', `ctime` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP COMMENT 'create time', `mtime` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'last updated time', PRIMARY KEY (`id`), @@ -774,7 +826,7 @@ CREATE TABLE `client_stat` ( `client_id` VARCHAR(128) NOT NULL COMMENT 'client id', `target` VARCHAR(100) NOT NULL COMMENT 'target stat platform', - `port` INT(11) NOT NULL COMMENT 'client port to get stat information', + `port` INT (11) NOT NULL COMMENT 'client port to get stat information', `protocol` VARCHAR(100) NOT NULL COMMENT 'stat info transport protocol', `path` VARCHAR(128) NOT NULL COMMENT 'stat metric path', PRIMARY KEY (`client_id`, `target`, `port`) @@ -783,7 +835,7 @@ CREATE TABLE -- v1.9.0 CREATE TABLE `config_file_template` ( - `id` BIGINT(10) UNSIGNED NOT NULL AUTO_INCREMENT COMMENT '主键', + `id` BIGINT (10) UNSIGNED NOT NULL AUTO_INCREMENT COMMENT '主键', `name` VARCHAR(128) COLLATE utf8_bin NOT NULL COMMENT '配置文件模板名称', `content` LONGTEXT COLLATE utf8_bin NOT NULL COMMENT '配置文件模板内容', `format` VARCHAR(16) COLLATE utf8_bin DEFAULT 'text' COMMENT '模板文件格式', @@ -831,9 +883,9 @@ VALUES }', 'json', 'Spring Cloud Gateway 染色规则', - NOW(), + NOW (), 'polaris', - NOW(), + NOW (), 'polaris' ); @@ -848,12 +900,13 @@ CREATE TABLE `enable` INT NOT NULL DEFAULT 0, `revision` VARCHAR(40) NOT NULL, `description` VARCHAR(500) NOT NULL DEFAULT '', - `priority` SMALLINT(6) NOT NULL DEFAULT '0' COMMENT 'ratelimit rule priority', - `flag` TINYINT(4) NOT NULL DEFAULT '0', + `priority` SMALLINT (6) NOT NULL DEFAULT '0' COMMENT 'ratelimit rule priority', + `flag` TINYINT (4) NOT NULL DEFAULT '0', `ctime` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, `mtime` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, `etime` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, `extend_info` VARCHAR(1024) DEFAULT '', + `metadata` TEXT COMMENT 'route rule metadata', PRIMARY KEY (`id`), KEY `mtime` (`mtime`) ) ENGINE = innodb; @@ -885,10 +938,11 @@ CREATE TABLE `dst_namespace` VARCHAR(64) NOT NULL, `dst_method` VARCHAR(128) NOT NULL, `config` TEXT, - `flag` TINYINT(4) NOT NULL DEFAULT '0', + `flag` TINYINT (4) NOT NULL DEFAULT '0', `ctime` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, `mtime` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, `etime` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + `metadata` TEXT COMMENT 'circuit_breaker rule metadata', PRIMARY KEY (`id`), KEY `name` (`name`), KEY `mtime` (`mtime`) @@ -905,9 +959,10 @@ CREATE TABLE `dst_namespace` VARCHAR(64) NOT NULL, `dst_method` VARCHAR(128) NOT NULL, `config` TEXT, - `flag` TINYINT(4) NOT NULL DEFAULT '0', + `flag` TINYINT (4) NOT NULL DEFAULT '0', `ctime` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, `mtime` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + `metadata` TEXT COMMENT 'faultdetect rule metadata', PRIMARY KEY (`id`), KEY `name` (`name`), KEY `mtime` (`mtime`) @@ -923,7 +978,7 @@ CREATE TABLE `protocol` VARCHAR(32) NOT NULL COMMENT '当前契约对应的协议信息 e.g. http/dubbo/grpc/thrift', `version` VARCHAR(64) NOT NULL COMMENT '服务契约版本', `revision` VARCHAR(128) NOT NULL COMMENT '当前服务契约的全部内容版本摘要', - `flag` TINYINT(4) DEFAULT 0 COMMENT '逻辑删除标志位 , 0 位有效 , 1 为逻辑删除', + `flag` TINYINT (4) DEFAULT 0 COMMENT '逻辑删除标志位 , 0 位有效 , 1 为逻辑删除', `content` LONGTEXT COMMENT '描述信息', `ctime` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, `mtime` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, @@ -954,7 +1009,7 @@ CREATE TABLE `source` INT COMMENT '该条记录来源, 0:SDK/1:MANUAL', `content` LONGTEXT COMMENT '描述信息', `revision` VARCHAR(128) NOT NULL COMMENT '当前接口定义的全部内容版本摘要', - `flag` TINYINT(4) DEFAULT 0 COMMENT '逻辑删除标志位, 0 位有效, 1 为逻辑删除', + `flag` TINYINT (4) DEFAULT 0 COMMENT '逻辑删除标志位, 0 位有效, 1 为逻辑删除', `ctime` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, `mtime` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, PRIMARY KEY (`id`), @@ -971,7 +1026,7 @@ CREATE TABLE `create_by` VARCHAR(32) DEFAULT "" COMMENT '创建人', `modify_time` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT '最后更新时间', `modify_by` VARCHAR(32) DEFAULT "" COMMENT '最后更新人', - `flag` TINYINT(4) DEFAULT 0 COMMENT '逻辑删除标志位, 0 位有效, 1 为逻辑删除', + `flag` TINYINT (4) DEFAULT 0 COMMENT '逻辑删除标志位, 0 位有效, 1 为逻辑删除', PRIMARY KEY (`name`) ) ENGINE = InnoDB COMMENT = '灰度资源表'; @@ -985,6 +1040,7 @@ CREATE TABLE `flag` tinyint default 0 comment '软删除标识位', `ctime` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP, `mtime` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + `metadata` TEXT COMMENT 'lane rule metadata', PRIMARY KEY (`id`), UNIQUE KEY `name` (`name`) ) ENGINE = InnoDB; diff --git a/store/mysql/strategy.go b/store/mysql/strategy.go index 8f0f688d5..cd95c2fce 100644 --- a/store/mysql/strategy.go +++ b/store/mysql/strategy.go @@ -25,7 +25,7 @@ import ( "go.uber.org/zap" - "github.com/polarismesh/polaris/common/model" + authcommon "github.com/polarismesh/polaris/common/model/auth" "github.com/polarismesh/polaris/common/utils" "github.com/polarismesh/polaris/store" ) @@ -51,42 +51,36 @@ type strategyStore struct { slave *BaseDB } -func (s *strategyStore) AddStrategy(strategy *model.StrategyDetail) error { +func (s *strategyStore) AddStrategy(tx store.Tx, strategy *authcommon.StrategyDetail) error { if strategy.ID == "" || strategy.Name == "" || strategy.Owner == "" { return store.NewStatusError(store.EmptyParamsErr, fmt.Sprintf( "add auth_strategy missing some params, id is %s, name is %s, owner is %s", strategy.ID, strategy.Name, strategy.Owner)) } - // 先清理无效数据 - if err := s.cleanInvalidStrategy(strategy.Name, strategy.Owner); err != nil { - return store.Error(err) - } + dbTx := tx.GetDelegateTx().(*BaseTx) - err := RetryTransaction("addStrategy", func() error { - return s.addStrategy(strategy) - }) - return store.Error(err) -} + // 先清理无效数据 + log.Info("[Store][Strategy] clean invalid auth_strategy", zap.String("name", strategy.Name), + zap.String("owner", strategy.Owner)) -func (s *strategyStore) addStrategy(strategy *model.StrategyDetail) error { - tx, err := s.master.Begin() - if err != nil { + str := "delete from auth_strategy where name = ? and owner = ? and flag = 1" + if _, err := dbTx.Exec(str, strategy.Name, strategy.Owner); err != nil { + log.Errorf("[Store][Strategy] clean invalid auth_strategy(%s) err: %s", strategy.Name, err.Error()) return err } - defer func() { _ = tx.Rollback() }() isDefault := 0 if strategy.Default { isDefault = 1 } - if err := s.addStrategyPrincipals(tx, strategy.ID, strategy.Principals); err != nil { + if err := s.addStrategyPrincipals(dbTx, strategy.ID, strategy.Principals); err != nil { log.Error("[Store][Strategy] add auth_strategy principals", zap.Error(err)) return err } - if err := s.addStrategyResources(tx, strategy.ID, strategy.Resources); err != nil { + if err := s.addStrategyResources(dbTx, strategy.ID, strategy.Resources); err != nil { log.Error("[Store][Strategy] add auth_strategy resources", zap.Error(err)) return err } @@ -94,7 +88,7 @@ func (s *strategyStore) addStrategy(strategy *model.StrategyDetail) error { // 保存策略主信息 saveMainSql := "INSERT INTO auth_strategy(`id`, `name`, `action`, `owner`, `comment`, `flag`, " + " `default`, `revision`) VALUES (?,?,?,?,?,?,?,?)" - if _, err = tx.Exec(saveMainSql, + if _, err := dbTx.Exec(saveMainSql, []interface{}{ strategy.ID, strategy.Name, strategy.Action, strategy.Owner, strategy.Comment, 0, isDefault, strategy.Revision}..., @@ -102,17 +96,11 @@ func (s *strategyStore) addStrategy(strategy *model.StrategyDetail) error { log.Error("[Store][Strategy] add auth_strategy main info", zap.Error(err)) return err } - - if err := tx.Commit(); err != nil { - log.Errorf("[Store][Strategy] add auth_strategy tx commit err: %s", err.Error()) - return err - } - return nil } // UpdateStrategy 更新鉴权规则 -func (s *strategyStore) UpdateStrategy(strategy *model.ModifyStrategyDetail) error { +func (s *strategyStore) UpdateStrategy(strategy *authcommon.ModifyStrategyDetail) error { if strategy.ID == "" { return store.NewStatusError(store.EmptyParamsErr, fmt.Sprintf( "update auth_strategy missing some params, id is %s", strategy.ID)) @@ -124,7 +112,7 @@ func (s *strategyStore) UpdateStrategy(strategy *model.ModifyStrategyDetail) err return store.Error(err) } -func (s *strategyStore) updateStrategy(strategy *model.ModifyStrategyDetail) error { +func (s *strategyStore) updateStrategy(strategy *authcommon.ModifyStrategyDetail) error { tx, err := s.master.Begin() if err != nil { return err @@ -212,7 +200,7 @@ func (s *strategyStore) deleteStrategy(id string) error { } // addStrategyPrincipals -func (s *strategyStore) addStrategyPrincipals(tx *BaseTx, id string, principals []model.Principal) error { +func (s *strategyStore) addStrategyPrincipals(tx *BaseTx, id string, principals []authcommon.Principal) error { if len(principals) == 0 { return nil } @@ -224,7 +212,7 @@ func (s *strategyStore) addStrategyPrincipals(tx *BaseTx, id string, principals for i := range principals { principal := principals[i] values = append(values, "(?,?,?)") - args = append(args, id, principal.PrincipalID, principal.PrincipalRole) + args = append(args, id, principal.PrincipalID, principal.PrincipalType) } savePrincipalSql += strings.Join(values, ",") @@ -238,7 +226,7 @@ func (s *strategyStore) addStrategyPrincipals(tx *BaseTx, id string, principals // deleteStrategyPrincipals func (s *strategyStore) deleteStrategyPrincipals(tx *BaseTx, id string, - principals []model.Principal) error { + principals []authcommon.Principal) error { if len(principals) == 0 { return nil } @@ -248,7 +236,7 @@ func (s *strategyStore) deleteStrategyPrincipals(tx *BaseTx, id string, for i := range principals { principal := principals[i] if _, err := tx.Exec(savePrincipalSql, []interface{}{ - id, principal.PrincipalID, principal.PrincipalRole, + id, principal.PrincipalID, principal.PrincipalType, }...); err != nil { return err } @@ -257,7 +245,7 @@ func (s *strategyStore) deleteStrategyPrincipals(tx *BaseTx, id string, return nil } -func (s *strategyStore) addStrategyResources(tx *BaseTx, id string, resources []model.StrategyResource) error { +func (s *strategyStore) addStrategyResources(tx *BaseTx, id string, resources []authcommon.StrategyResource) error { if len(resources) == 0 { return nil } @@ -284,7 +272,7 @@ func (s *strategyStore) addStrategyResources(tx *BaseTx, id string, resources [] } func (s *strategyStore) deleteStrategyResources(tx *BaseTx, id string, - resources []model.StrategyResource) error { + resources []authcommon.StrategyResource) error { if len(resources) == 0 { return nil @@ -305,7 +293,7 @@ func (s *strategyStore) deleteStrategyResources(tx *BaseTx, id string, } // LooseAddStrategyResources loose add strategy resources -func (s *strategyStore) LooseAddStrategyResources(resources []model.StrategyResource) error { +func (s *strategyStore) LooseAddStrategyResources(resources []authcommon.StrategyResource) error { tx, err := s.master.Begin() if err != nil { return err @@ -344,7 +332,7 @@ func (s *strategyStore) LooseAddStrategyResources(resources []model.StrategyReso } // RemoveStrategyResources 删除策略的资源 -func (s *strategyStore) RemoveStrategyResources(resources []model.StrategyResource) error { +func (s *strategyStore) RemoveStrategyResources(resources []authcommon.StrategyResource) error { tx, err := s.master.Begin() if err != nil { return err @@ -379,7 +367,7 @@ func (s *strategyStore) RemoveStrategyResources(resources []model.StrategyResour } // GetStrategyDetail 获取策略详情 -func (s *strategyStore) GetStrategyDetail(id string) (*model.StrategyDetail, error) { +func (s *strategyStore) GetStrategyDetail(id string) (*authcommon.StrategyDetail, error) { if id == "" { return nil, store.NewStatusError(store.EmptyParamsErr, fmt.Sprintf( "get auth_strategy missing some params, id is %s", id)) @@ -395,7 +383,7 @@ func (s *strategyStore) GetStrategyDetail(id string) (*model.StrategyDetail, err // GetDefaultStrategyDetailByPrincipal 获取默认策略 func (s *strategyStore) GetDefaultStrategyDetailByPrincipal(principalId string, - principalType model.PrincipalType) (*model.StrategyDetail, error) { + principalType authcommon.PrincipalType) (*authcommon.StrategyDetail, error) { if principalId == "" { return nil, store.NewStatusError(store.EmptyParamsErr, fmt.Sprintf( @@ -422,12 +410,12 @@ func (s *strategyStore) GetDefaultStrategyDetailByPrincipal(principalId string, } // getStrategyDetail -func (s *strategyStore) getStrategyDetail(row *sql.Row) (*model.StrategyDetail, error) { +func (s *strategyStore) getStrategyDetail(row *sql.Row) (*authcommon.StrategyDetail, error) { var ( ctime, mtime int64 isDefault, flag int16 ) - ret := new(model.StrategyDetail) + ret := new(authcommon.StrategyDetail) if err := row.Scan(&ret.ID, &ret.Name, &ret.Action, &ret.Owner, &isDefault, &ret.Comment, &ret.Revision, &flag, &ctime, &mtime); err != nil { switch err { @@ -459,7 +447,7 @@ func (s *strategyStore) getStrategyDetail(row *sql.Row) (*model.StrategyDetail, // GetStrategies 获取策略列表 func (s *strategyStore) GetStrategies(filters map[string]string, offset uint32, limit uint32) (uint32, - []*model.StrategyDetail, error) { + []*authcommon.StrategyDetail, error) { showDetail := filters["show_detail"] delete(filters, "show_detail") @@ -470,7 +458,7 @@ func (s *strategyStore) GetStrategies(filters map[string]string, offset uint32, // listStrategies func (s *strategyStore) listStrategies(filters map[string]string, offset uint32, limit uint32, - showDetail bool) (uint32, []*model.StrategyDetail, error) { + showDetail bool) (uint32, []*authcommon.StrategyDetail, error) { querySql := `SELECT @@ -509,7 +497,7 @@ func (s *strategyStore) queryStrategies( handler QueryHandler, filters map[string]string, mapping map[string]string, querySqlPrefix string, countSqlPrefix string, - offset uint32, limit uint32, showDetail bool) (uint32, []*model.StrategyDetail, error) { + offset uint32, limit uint32, showDetail bool) (uint32, []*authcommon.StrategyDetail, error) { querySql := querySqlPrefix countSql := countSqlPrefix @@ -570,7 +558,7 @@ func (s *strategyStore) queryStrategies( // collectStrategies 执行真正的 sql 并从 rows 中获取策略列表 func (s *strategyStore) collectStrategies(handler QueryHandler, querySql string, - args []interface{}, showDetail bool) ([]*model.StrategyDetail, error) { + args []interface{}, showDetail bool) ([]*authcommon.StrategyDetail, error) { log.Debug("[Store][Strategy] get simple strategies", zap.String("query sql", querySql), zap.Any("args", args)) @@ -586,7 +574,7 @@ func (s *strategyStore) collectStrategies(handler QueryHandler, querySql string, idMap := make(map[string]struct{}) - ret := make([]*model.StrategyDetail, 0, 16) + ret := make([]*authcommon.StrategyDetail, 0, 16) for rows.Next() { detail, err := fetchRown2StrategyDetail(rows) if err != nil { @@ -619,8 +607,7 @@ func (s *strategyStore) collectStrategies(handler QueryHandler, querySql string, return ret, nil } -func (s *strategyStore) GetStrategyDetailsForCache(mtime time.Time, - firstUpdate bool) ([]*model.StrategyDetail, error) { +func (s *strategyStore) GetMoreStrategies(mtime time.Time, firstUpdate bool) ([]*authcommon.StrategyDetail, error) { tx, err := s.slave.Begin() if err != nil { return nil, store.Error(err) @@ -644,7 +631,7 @@ func (s *strategyStore) GetStrategyDetailsForCache(mtime time.Time, _ = rows.Close() }() - ret := make([]*model.StrategyDetail, 0) + ret := make([]*authcommon.StrategyDetail, 0) for rows.Next() { detail, err := fetchRown2StrategyDetail(rows) if err != nil { @@ -671,7 +658,7 @@ func (s *strategyStore) GetStrategyDetailsForCache(mtime time.Time, // GetStrategyResources 获取对应 principal 能操作的所有资源 func (s *strategyStore) GetStrategyResources(principalId string, - principalRole model.PrincipalType) ([]model.StrategyResource, error) { + principalRole authcommon.PrincipalType) ([]authcommon.StrategyResource, error) { querySql := "SELECT res_id, res_type FROM auth_strategy_resource WHERE strategy_id IN (SELECT DISTINCT " + " ap.strategy_id FROM auth_principal ap join auth_strategy ar ON ap.strategy_id = ar.id WHERE ar.flag = 0 " + @@ -691,10 +678,10 @@ func (s *strategyStore) GetStrategyResources(principalId string, defer rows.Close() - resArr := make([]model.StrategyResource, 0) + resArr := make([]authcommon.StrategyResource, 0) for rows.Next() { - res := new(model.StrategyResource) + res := new(authcommon.StrategyResource) if err := rows.Scan(&res.ResID, &res.ResType); err != nil { return nil, store.Error(err) } @@ -704,7 +691,7 @@ func (s *strategyStore) GetStrategyResources(principalId string, return resArr, nil } -func (s *strategyStore) getStrategyPrincipals(queryHander QueryHandler, id string) ([]model.Principal, error) { +func (s *strategyStore) getStrategyPrincipals(queryHander QueryHandler, id string) ([]authcommon.Principal, error) { rows, err := queryHander("SELECT principal_id, principal_role FROM auth_principal WHERE strategy_id = ?", id) if err != nil { @@ -718,11 +705,11 @@ func (s *strategyStore) getStrategyPrincipals(queryHander QueryHandler, id strin } defer rows.Close() - principals := make([]model.Principal, 0) + principals := make([]authcommon.Principal, 0) for rows.Next() { - res := new(model.Principal) - if err := rows.Scan(&res.PrincipalID, &res.PrincipalRole); err != nil { + res := new(authcommon.Principal) + if err := rows.Scan(&res.PrincipalID, &res.PrincipalType); err != nil { return nil, store.Error(err) } principals = append(principals, *res) @@ -731,7 +718,7 @@ func (s *strategyStore) getStrategyPrincipals(queryHander QueryHandler, id strin return principals, nil } -func (s *strategyStore) getStrategyResources(queryHander QueryHandler, id string) ([]model.StrategyResource, error) { +func (s *strategyStore) getStrategyResources(queryHander QueryHandler, id string) ([]authcommon.StrategyResource, error) { querySql := "SELECT res_id, res_type FROM auth_strategy_resource WHERE strategy_id = ?" rows, err := queryHander(querySql, id) if err != nil { @@ -745,10 +732,10 @@ func (s *strategyStore) getStrategyResources(queryHander QueryHandler, id string } defer rows.Close() - resArr := make([]model.StrategyResource, 0) + resArr := make([]authcommon.StrategyResource, 0) for rows.Next() { - res := new(model.StrategyResource) + res := new(authcommon.StrategyResource) if err := rows.Scan(&res.ResID, &res.ResType); err != nil { return nil, store.Error(err) } @@ -758,13 +745,13 @@ func (s *strategyStore) getStrategyResources(queryHander QueryHandler, id string return resArr, nil } -func fetchRown2StrategyDetail(rows *sql.Rows) (*model.StrategyDetail, error) { +func fetchRown2StrategyDetail(rows *sql.Rows) (*authcommon.StrategyDetail, error) { var ( ctime, mtime int64 isDefault, flag int16 ) - ret := &model.StrategyDetail{ - Resources: make([]model.StrategyResource, 0), + ret := &authcommon.StrategyDetail{ + Resources: make([]authcommon.StrategyResource, 0), } if err := rows.Scan(&ret.ID, &ret.Name, &ret.Action, &ret.Owner, &ret.Comment, &isDefault, &ret.Revision, &flag, @@ -783,34 +770,12 @@ func fetchRown2StrategyDetail(rows *sql.Rows) (*model.StrategyDetail, error) { return ret, nil } -// cleanInvalidStrategy 按名称清理鉴权策略 -func (s *strategyStore) cleanInvalidStrategy(name, owner string) error { - log.Info("[Store][Strategy] clean invalid auth_strategy", - zap.String("name", name), zap.String("owner", owner)) - - tx, err := s.master.Begin() - if err != nil { - return err - } - defer func() { _ = tx.Rollback() }() - - str := "delete from auth_strategy where name = ? and owner = ? and flag = 1" - if _, err = tx.Exec(str, name, owner); err != nil { - log.Errorf("[Store][Strategy] clean invalid auth_strategy(%s) err: %s", name, err.Error()) - return err - } - if err := tx.Commit(); err != nil { - log.Errorf("[Store][Strategy] clean invalid auth_strategy tx commit err: %s", err.Error()) - return err - } - return nil -} - -// cleanLinkStrategy 清理与自己相关联的鉴权信息 +// CleanPrincipalPolicies 清理与自己相关联的鉴权信息 // step 1. 清理用户/用户组默认策略所关联的所有资源信息(直接走delete删除) // step 2. 清理用户/用户组默认策略 // step 3. 清理用户/用户组所关联的其他鉴权策略的关联关系(直接走delete删除) -func cleanLinkStrategy(tx *BaseTx, role model.PrincipalType, principalId, owner string) error { +func (s *strategyStore) CleanPrincipalPolicies(tx store.Tx, p authcommon.Principal) error { + dbTx := tx.GetDelegateTx().(*BaseTx) // 清理默认策略对应的所有鉴权关联资源 removeResSql := ` @@ -829,7 +794,7 @@ func cleanLinkStrategy(tx *BaseTx, role model.PrincipalType, principalId, owner ) ` - if _, err := tx.Exec(removeResSql, []interface{}{owner, principalId, role}...); err != nil { + if _, err := dbTx.Exec(removeResSql, []interface{}{p.Owner, p.PrincipalID, p.PrincipalType}...); err != nil { return err } @@ -847,20 +812,20 @@ func cleanLinkStrategy(tx *BaseTx, role model.PrincipalType, principalId, owner AND ag.owner = ? ` - if _, err := tx.Exec(cleanaRuleSql, []interface{}{principalId, role, owner}...); err != nil { + if _, err := dbTx.Exec(cleanaRuleSql, []interface{}{p.PrincipalID, p.PrincipalType, p.Owner}...); err != nil { return err } // 调整所关联的鉴权策略的 mtime 数据,保证cache刷新可以获取到变更的数据信息 updateStrategySql := "UPDATE auth_strategy SET mtime = sysdate() WHERE id IN (SELECT DISTINCT " + " strategy_id FROM auth_principal WHERE principal_id = ? AND principal_role = ?)" - if _, err := tx.Exec(updateStrategySql, []interface{}{principalId, role}...); err != nil { + if _, err := dbTx.Exec(updateStrategySql, []interface{}{p.PrincipalID, p.PrincipalType}...); err != nil { return err } // 清理所在的所有鉴权principal cleanPrincipalSql := "DELETE FROM auth_principal WHERE principal_id = ? AND principal_role = ?" - if _, err := tx.Exec(cleanPrincipalSql, []interface{}{principalId, role}...); err != nil { + if _, err := dbTx.Exec(cleanPrincipalSql, []interface{}{p.PrincipalID, p.PrincipalType}...); err != nil { return err } diff --git a/store/mysql/user.go b/store/mysql/user.go index 22e6db4f3..2a94f51ef 100644 --- a/store/mysql/user.go +++ b/store/mysql/user.go @@ -20,13 +20,11 @@ package sqldb import ( "database/sql" "fmt" - "strings" "time" - apisecurity "github.com/polarismesh/specification/source/go/api/v1/security" "go.uber.org/zap" - "github.com/polarismesh/polaris/common/model" + authcommon "github.com/polarismesh/polaris/common/model/auth" "github.com/polarismesh/polaris/common/utils" "github.com/polarismesh/polaris/store" ) @@ -54,25 +52,23 @@ type userStore struct { } // AddUser 添加用户 -func (u *userStore) AddUser(user *model.User) error { +func (u *userStore) AddUser(tx store.Tx, user *authcommon.User) error { if user.ID == "" || user.Name == "" || user.Token == "" || user.Password == "" { return store.NewStatusError(store.EmptyParamsErr, fmt.Sprintf( "add user missing some params, id is %s, name is %s", user.ID, user.Name)) } + dbTx := tx.GetDelegateTx().(*BaseTx) // 先清理无效数据 - if err := u.cleanInValidUser(user.Name, user.Owner); err != nil { + if err := u.cleanInValidUser(dbTx, user.Name, user.Owner); err != nil { return err } - err := RetryTransaction("addUser", func() error { - return u.addUser(user) - }) - + err := u.addUser(dbTx, user) return store.Error(err) } -func (u *userStore) addUser(user *model.User) error { +func (u *userStore) addUser(tx *BaseTx, user *authcommon.User) error { tx, err := u.master.Begin() if err != nil { @@ -102,26 +98,11 @@ func (u *userStore) addUser(user *model.User) error { if err != nil { return store.Error(err) } - - owner := user.Owner - if owner == "" { - owner = user.ID - } - - if err := createDefaultStrategy(tx, model.PrincipalUser, user.ID, user.Name, user.Owner); err != nil { - log.Error("[Auth][User] create default strategy", zap.Error(err)) - return store.Error(err) - } - - if err := tx.Commit(); err != nil { - log.Errorf("[Store][User] add user tx commit err: %s", err.Error()) - return store.Error(err) - } return nil } // UpdateUser 更新用户信息 -func (u *userStore) UpdateUser(user *model.User) error { +func (u *userStore) UpdateUser(user *authcommon.User) error { if user.ID == "" || user.Name == "" || user.Token == "" || user.Password == "" { return store.NewStatusError(store.EmptyParamsErr, fmt.Sprintf( "update user missing some params, id is %s, name is %s", user.ID, user.Name)) @@ -134,7 +115,7 @@ func (u *userStore) UpdateUser(user *model.User) error { return store.Error(err) } -func (u *userStore) updateUser(user *model.User) error { +func (u *userStore) updateUser(user *authcommon.User) error { tx, err := u.master.Begin() if err != nil { @@ -174,63 +155,33 @@ func (u *userStore) updateUser(user *model.User) error { } // DeleteUser delete user by user id -func (u *userStore) DeleteUser(user *model.User) error { +func (u *userStore) DeleteUser(tx store.Tx, user *authcommon.User) error { if user.ID == "" || user.Name == "" { return store.NewStatusError(store.EmptyParamsErr, "delete user id parameter missing") } - err := RetryTransaction("deleteUser", func() error { - return u.deleteUser(user) - }) - - return store.Error(err) -} - -// deleteUser Specific deletion user steps -// step 1. Delete the user-associated policy information -// -// a. Delete the user's default policy -// b. Update the latest update time of related policies, make the Cache mechanism -// c. Delete the association relationship of the user and policy -// -// step 2. Delete the user group associated with this user -func (u *userStore) deleteUser(user *model.User) error { - tx, err := u.master.Begin() - if err != nil { - return err - } - - defer func() { _ = tx.Rollback() }() - - if err := cleanLinkStrategy(tx, model.PrincipalUser, user.ID, user.Owner); err != nil { - return err - } + dbTx := tx.GetDelegateTx().(*BaseTx) - if _, err = tx.Exec("UPDATE user SET flag = 1 WHERE id = ?", user.ID); err != nil { + if _, err := dbTx.Exec("UPDATE user SET flag = 1 WHERE id = ?", user.ID); err != nil { log.Error("[Store][User] update set user flag", zap.Error(err)) return err } - if _, err = tx.Exec("UPDATE user_group SET mtime = sysdate() WHERE id IN (SELECT DISTINCT group_id FROM "+ + if _, err := dbTx.Exec("UPDATE user_group SET mtime = sysdate() WHERE id IN (SELECT DISTINCT group_id FROM "+ " user_group_relation WHERE user_id = ?)", user.ID); err != nil { log.Error("[Store][User] update usergroup mtime", zap.Error(err)) return err } - if _, err = tx.Exec("DELETE FROM user_group_relation WHERE user_id = ?", user.ID); err != nil { + if _, err := dbTx.Exec("DELETE FROM user_group_relation WHERE user_id = ?", user.ID); err != nil { log.Error("[Store][User] delete usergroup relation", zap.Error(err)) return err } - - if err := tx.Commit(); err != nil { - log.Error("[Store][User] delete user tx commit", zap.Error(err)) - return err - } return nil } // GetSubCount get user's sub count -func (u *userStore) GetSubCount(user *model.User) (uint32, error) { +func (u *userStore) GetSubCount(user *authcommon.User) (uint32, error) { var ( countSql = "SELECT COUNT(*) FROM user WHERE owner = ? AND flag = 0" count, err = queryEntryCount(u.master, countSql, []interface{}{user.ID}) @@ -244,7 +195,7 @@ func (u *userStore) GetSubCount(user *model.User) (uint32, error) { } // GetUser get user by user id -func (u *userStore) GetUser(id string) (*model.User, error) { +func (u *userStore) GetUser(id string) (*authcommon.User, error) { var tokenEnable, userType int getSql := ` SELECT u.id, u.name, u.password, u.owner, u.comment, u.source, u.token, u.token_enable, @@ -254,7 +205,7 @@ func (u *userStore) GetUser(id string) (*model.User, error) { ` var ( row = u.master.QueryRow(getSql, id) - user = new(model.User) + user = new(authcommon.User) ) if err := row.Scan(&user.ID, &user.Name, &user.Password, &user.Owner, &user.Comment, &user.Source, @@ -268,7 +219,7 @@ func (u *userStore) GetUser(id string) (*model.User, error) { } user.TokenEnable = tokenEnable == 1 - user.Type = model.UserRoleType(userType) + user.Type = authcommon.UserRoleType(userType) // 北极星后续不在保存用户的 mobile 以及 email 信息,这里针对原来保存的数据也不进行对外展示,强制屏蔽数据 user.Mobile = "" user.Email = "" @@ -276,7 +227,7 @@ func (u *userStore) GetUser(id string) (*model.User, error) { } // GetUserByName 根据用户名、owner 获取用户 -func (u *userStore) GetUserByName(name, ownerId string) (*model.User, error) { +func (u *userStore) GetUserByName(name, ownerId string) (*authcommon.User, error) { getSql := ` SELECT u.id, u.name, u.password, u.owner, u.comment, u.source, u.token, u.token_enable, u.user_type, u.mobile, u.email @@ -288,7 +239,7 @@ func (u *userStore) GetUserByName(name, ownerId string) (*model.User, error) { var ( row = u.master.QueryRow(getSql, name, ownerId) - user = new(model.User) + user = new(authcommon.User) tokenEnable, userType int ) @@ -303,7 +254,7 @@ func (u *userStore) GetUserByName(name, ownerId string) (*model.User, error) { } user.TokenEnable = tokenEnable == 1 - user.Type = model.UserRoleType(userType) + user.Type = authcommon.UserRoleType(userType) // 北极星后续不在保存用户的 mobile 以及 email 信息,这里针对原来保存的数据也不进行对外展示,强制屏蔽数据 user.Mobile = "" user.Email = "" @@ -311,7 +262,7 @@ func (u *userStore) GetUserByName(name, ownerId string) (*model.User, error) { } // GetUserByIds Get user list data according to user ID -func (u *userStore) GetUserByIds(ids []string) ([]*model.User, error) { +func (u *userStore) GetUserByIds(ids []string) ([]*authcommon.User, error) { if len(ids) == 0 { return nil, nil } @@ -346,7 +297,7 @@ func (u *userStore) GetUserByIds(ids []string) ([]*model.User, error) { _ = rows.Close() }() - users := make([]*model.User, 0) + users := make([]*authcommon.User, 0) for rows.Next() { user, err := fetchRown2User(rows) if err != nil { @@ -363,7 +314,7 @@ func (u *userStore) GetUserByIds(ids []string) ([]*model.User, error) { // Case 1. From the user's perspective, normal query conditions // Case 2. From the perspective of the user group, query is the list of users involved under a user group. func (u *userStore) GetUsers(filters map[string]string, offset uint32, limit uint32) (uint32, - []*model.User, error) { + []*authcommon.User, error) { if _, ok := filters["group_id"]; ok { return u.listGroupUsers(filters, offset, limit) } @@ -372,7 +323,7 @@ func (u *userStore) GetUsers(filters map[string]string, offset uint32, limit uin // listUsers Query user list information func (u *userStore) listUsers(filters map[string]string, offset uint32, limit uint32) (uint32, - []*model.User, error) { + []*authcommon.User, error) { countSql := "SELECT COUNT(*) FROM user WHERE flag = 0 " getSql := ` SELECT id, name, password, owner, comment, source @@ -434,7 +385,7 @@ func (u *userStore) listUsers(filters map[string]string, offset uint32, limit ui // listGroupUsers Check the user information under a user group func (u *userStore) listGroupUsers(filters map[string]string, offset uint32, limit uint32) (uint32, - []*model.User, error) { + []*authcommon.User, error) { if _, ok := filters[GroupIDAttribute]; !ok { return 0, nil, store.NewStatusError(store.EmptyParamsErr, "group_id is missing") } @@ -498,7 +449,7 @@ func (u *userStore) listGroupUsers(filters map[string]string, offset uint32, lim } // GetUsersForCache Get user information, mainly for cache -func (u *userStore) GetUsersForCache(mtime time.Time, firstUpdate bool) ([]*model.User, error) { +func (u *userStore) GetUsersForCache(mtime time.Time, firstUpdate bool) ([]*authcommon.User, error) { args := make([]interface{}, 0) querySql := ` SELECT u.id, u.name, u.password, u.owner, u.comment, u.source @@ -521,7 +472,7 @@ func (u *userStore) GetUsersForCache(mtime time.Time, firstUpdate bool) ([]*mode } // collectUsers General query user list -func (u *userStore) collectUsers(handler QueryHandler, querySql string, args []interface{}) ([]*model.User, error) { +func (u *userStore) collectUsers(handler QueryHandler, querySql string, args []interface{}) ([]*authcommon.User, error) { rows, err := u.master.Query(querySql, args...) if err != nil { log.Error("[Store][User] list user ", zap.String("query sql", querySql), zap.Any("args", args), zap.Error(err)) @@ -530,7 +481,7 @@ func (u *userStore) collectUsers(handler QueryHandler, querySql string, args []i defer func() { _ = rows.Close() }() - users := make([]*model.User, 0) + users := make([]*authcommon.User, 0) for rows.Next() { user, err := fetchRown2User(rows) if err != nil { @@ -543,51 +494,11 @@ func (u *userStore) collectUsers(handler QueryHandler, querySql string, args []i return users, nil } -func createDefaultStrategy(tx *BaseTx, role model.PrincipalType, id, name, owner string) error { - if strings.Compare(owner, "") == 0 { - owner = id - } - - // Create the user's default weight policy - strategy := &model.StrategyDetail{ - ID: utils.NewUUID(), - Name: model.BuildDefaultStrategyName(role, name), - Action: apisecurity.AuthAction_READ_WRITE.String(), - Default: true, - Owner: owner, - Revision: utils.NewUUID(), - Resources: []model.StrategyResource{}, - Valid: true, - Comment: "Default Strategy", - } - - // 需要清理过期的 auth_strategy - cleanInvalidRule := "DELETE FROM auth_strategy WHERE name = ? AND owner = ? AND flag = 1 AND `default` = ?" - if _, err := tx.Exec(cleanInvalidRule, []interface{}{strategy.Name, strategy.Owner, - strategy.Default}...); err != nil { - return err - } - - // Save policy master information - saveMainSql := "INSERT INTO auth_strategy(`id`, `name`, `action`, `owner`, `comment`, `flag`, " + - " `default`, `revision`) VALUES (?,?,?,?,?,?,?,?)" - if _, err := tx.Exec(saveMainSql, []interface{}{strategy.ID, strategy.Name, strategy.Action, - strategy.Owner, strategy.Comment, - 0, strategy.Default, strategy.Revision}...); err != nil { - return err - } - - // Insert User / Group and Policy Association - savePrincipalSql := "INSERT INTO auth_principal(`strategy_id`, `principal_id`, `principal_role`) VALUES (?,?,?)" - _, err := tx.Exec(savePrincipalSql, []interface{}{strategy.ID, id, role}...) - return err -} - -func fetchRown2User(rows *sql.Rows) (*model.User, error) { +func fetchRown2User(rows *sql.Rows) (*authcommon.User, error) { var ( ctime, mtime int64 flag, tokenEnable, userType int - user = new(model.User) + user = new(authcommon.User) err = rows.Scan(&user.ID, &user.Name, &user.Password, &user.Owner, &user.Comment, &user.Source, &user.Token, &tokenEnable, &userType, &ctime, &mtime, &flag, &user.Mobile, &user.Email) @@ -601,7 +512,7 @@ func fetchRown2User(rows *sql.Rows) (*model.User, error) { user.TokenEnable = tokenEnable == 1 user.CreateTime = time.Unix(ctime, 0) user.ModifyTime = time.Unix(mtime, 0) - user.Type = model.UserRoleType(userType) + user.Type = authcommon.UserRoleType(userType) // 北极星后续不在保存用户的 mobile 以及 email 信息,这里针对原来保存的数据也不进行对外展示,强制屏蔽数据 user.Mobile = "" @@ -610,13 +521,12 @@ func fetchRown2User(rows *sql.Rows) (*model.User, error) { return user, nil } -func (u *userStore) cleanInValidUser(name, owner string) error { +func (u *userStore) cleanInValidUser(tx *BaseTx, name, owner string) error { log.Infof("[Store][User] clean user, name=(%s), owner=(%s)", name, owner) str := "delete from user where name = ? and owner = ? and flag = 1" - if _, err := u.master.Exec(str, name, owner); err != nil { + if _, err := tx.Exec(str, name, owner); err != nil { log.Errorf("[Store][User] clean user(%s) err: %s", name, err.Error()) return err } - return nil }