diff --git a/.github/workflows/codecov.yaml b/.github/workflows/codecov.yaml index ad849bec5..8ae24e36c 100644 --- a/.github/workflows/codecov.yaml +++ b/.github/workflows/codecov.yaml @@ -61,25 +61,13 @@ jobs: with: fetch-depth: 2 - - uses: shogo82148/actions-setup-mysql@v1 - with: - mysql-version: "5.7" - auto-start: true - my-cnf: | - innodb_log_file_size=256MB - innodb_buffer_pool_size=512MB - max_allowed_packet=16MB - max_connections=50 - local_infile=1 - root-password: root - - - name: Initialize database env: MYSQL_DB_USER: root MYSQL_DB_PWD: root MYSQL_DATABASE: polaris_server run: | + sudo systemctl start mysql.service mysql -e 'CREATE DATABASE ${{ env.MYSQL_DATABASE }};' -u${{ env.MYSQL_DB_USER }} -p${{ env.MYSQL_DB_PWD }} mysql -e "ALTER USER '${{ env.MYSQL_DB_USER }}'@'localhost' IDENTIFIED WITH mysql_native_password BY 'root';" -u${{ env.MYSQL_DB_USER }} -p${{ env.MYSQL_DB_PWD }} diff --git a/.golangci.yml b/.golangci.yml index 33340b6fb..3ef63cc8b 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -167,7 +167,7 @@ linters-settings: disabled: false - name: max-public-structs severity: warning - disabled: true + disabled: false arguments: [35] - name: indent-error-flow severity: warning @@ -281,7 +281,7 @@ linters-settings: govet: # Report about shadowed variables. # Default: false - shadow: false + check-shadowing: true # Settings per analyzer. settings: # Analyzer name, run `go tool vet help` to see all analyzers. diff --git a/admin/api.go b/admin/api.go index 8f717c105..aec490231 100644 --- a/admin/api.go +++ b/admin/api.go @@ -25,7 +25,6 @@ import ( "github.com/polarismesh/polaris/common/model" "github.com/polarismesh/polaris/common/model/admin" - authcommon "github.com/polarismesh/polaris/common/model/auth" ) // AdminOperateServer Maintain related operation @@ -56,6 +55,4 @@ type AdminOperateServer interface { GetCMDBInfo(ctx context.Context) ([]model.LocationView, error) // InitMainUser InitMainUser(ctx context.Context, user apisecurity.User) error - // GetServerFunctions Get server functions - GetServerFunctions(ctx context.Context) []authcommon.ServerFunctionGroup } diff --git a/admin/config.go b/admin/config.go index 21bea1f52..891a68e3a 100644 --- a/admin/config.go +++ b/admin/config.go @@ -23,8 +23,7 @@ import ( // Config maintain configuration type Config struct { - Jobs []job.JobConfig `yaml:"jobs"` - Interceptors []string `yaml:"-"` + Jobs []job.JobConfig `yaml:"jobs"` } func DefaultConfig() *Config { diff --git a/admin/default.go b/admin/default.go index 5ebf3643b..c7b3ab59d 100644 --- a/admin/default.go +++ b/admin/default.go @@ -20,9 +20,9 @@ package admin import ( "context" "errors" - "fmt" "github.com/polarismesh/polaris/admin/job" + "github.com/polarismesh/polaris/auth" "github.com/polarismesh/polaris/cache" "github.com/polarismesh/polaris/service" "github.com/polarismesh/polaris/service/healthcheck" @@ -30,22 +30,11 @@ import ( ) var ( - server AdminOperateServer - maintainServer = &Server{} - finishInit bool - serverProxyFactories = map[string]ServerProxyFactory{} + server AdminOperateServer + maintainServer = &Server{} + finishInit bool ) -type ServerProxyFactory func(ctx context.Context, pre AdminOperateServer) (AdminOperateServer, error) - -func RegisterServerProxy(name string, factor ServerProxyFactory) error { - if _, ok := serverProxyFactories[name]; ok { - return fmt.Errorf("duplicate ServerProxyFactory, name(%s)", name) - } - serverProxyFactories[name] = factor - return nil -} - // Initialize 初始化 func Initialize(ctx context.Context, cfg *Config, namingService service.DiscoverServer, healthCheckServer *healthcheck.Server, cacheMgn *cache.CacheManager, storage store.Store) error { @@ -54,49 +43,40 @@ func Initialize(ctx context.Context, cfg *Config, namingService service.Discover return nil } - proxySvr, actualSvr, err := InitServer(ctx, cfg, namingService, healthCheckServer, cacheMgn, storage) + err := initialize(ctx, cfg, namingService, healthCheckServer, cacheMgn, storage) if err != nil { return err } - server = proxySvr - maintainServer = actualSvr finishInit = true return nil } -func InitServer(ctx context.Context, cfg *Config, namingService service.DiscoverServer, - healthCheckServer *healthcheck.Server, cacheMgn *cache.CacheManager, storage store.Store) (AdminOperateServer, *Server, error) { +func initialize(_ context.Context, cfg *Config, namingService service.DiscoverServer, + healthCheckServer *healthcheck.Server, cacheMgn *cache.CacheManager, storage store.Store) error { - actualSvr := new(Server) + userMgn, err := auth.GetUserServer() + if err != nil { + return err + } - actualSvr.namingServer = namingService - actualSvr.healthCheckServer = healthCheckServer - actualSvr.cacheMgn = cacheMgn - actualSvr.storage = storage + strategyMgn, err := auth.GetStrategyServer() + if err != nil { + return err + } + + maintainServer.namingServer = namingService + maintainServer.healthCheckServer = healthCheckServer + maintainServer.cacheMgn = cacheMgn + maintainServer.storage = storage maintainJobs := job.NewMaintainJobs(namingService, cacheMgn, storage) if err := maintainJobs.StartMaintianJobs(cfg.Jobs); err != nil { - return nil, nil, err - } - - var proxySvr AdminOperateServer - proxySvr = actualSvr - order := GetChainOrder() - for i := range order { - factory, exist := serverProxyFactories[order[i]] - if !exist { - return nil, nil, fmt.Errorf("name(%s) not exist in serverProxyFactories", order[i]) - } - - afterSvr, err := factory(ctx, proxySvr) - if err != nil { - return nil, nil, err - } - proxySvr = afterSvr + return err } - return proxySvr, actualSvr, nil + server = newServerAuthAbility(maintainServer, userMgn, strategyMgn) + return nil } // GetServer 获取已经初始化好的Server diff --git a/admin/interceptor/auth/log.go b/admin/interceptor/auth/log.go deleted file mode 100644 index 9ff5f1c28..000000000 --- a/admin/interceptor/auth/log.go +++ /dev/null @@ -1,24 +0,0 @@ -/** - * Tencent is pleased to support the open source community by making Polaris available. - * - * Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. - * - * Licensed under the BSD 3-Clause License (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://opensource.org/licenses/BSD-3-Clause - * - * Unless required by applicable law or agreed to in writing, software distributed - * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR - * CONDITIONS OF ANY KIND, either express or implied. See the License for the - * specific language governing permissions and limitations under the License. - */ - -package auth - -import ( - commonlog "github.com/polarismesh/polaris/common/log" -) - -var log = commonlog.GetScopeOrDefaultByName(commonlog.AuthLoggerName) diff --git a/admin/interceptor/auth/server.go b/admin/interceptor/auth/server.go deleted file mode 100644 index a5e67c578..000000000 --- a/admin/interceptor/auth/server.go +++ /dev/null @@ -1,214 +0,0 @@ -/** - * Tencent is pleased to support the open source community by making Polaris available. - * - * Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. - * - * Licensed under the BSD 3-Clause License (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://opensource.org/licenses/BSD-3-Clause - * - * Unless required by applicable law or agreed to in writing, software distributed - * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR - * CONDITIONS OF ANY KIND, either express or implied. See the License for the - * specific language governing permissions and limitations under the License. - */ - -package auth - -import ( - "context" - - apisecurity "github.com/polarismesh/specification/source/go/api/v1/security" - apiservice "github.com/polarismesh/specification/source/go/api/v1/service_manage" - - "github.com/polarismesh/polaris/admin" - "github.com/polarismesh/polaris/auth" - api "github.com/polarismesh/polaris/common/api/v1" - "github.com/polarismesh/polaris/common/model" - admincommon "github.com/polarismesh/polaris/common/model/admin" - authcommon "github.com/polarismesh/polaris/common/model/auth" - "github.com/polarismesh/polaris/common/utils" -) - -var _ admin.AdminOperateServer = (*Server)(nil) - -// Server 带有鉴权能力的 maintainServer -type Server struct { - nextSvr admin.AdminOperateServer - userSvr auth.UserServer - policySvr auth.StrategyServer -} - -func NewServer(nextSvr admin.AdminOperateServer, - userSvr auth.UserServer, policySvr auth.StrategyServer) admin.AdminOperateServer { - proxy := &Server{ - nextSvr: nextSvr, - userSvr: userSvr, - policySvr: policySvr, - } - - return proxy -} - -func (svr *Server) collectMaintainAuthContext(ctx context.Context, resourceOp authcommon.ResourceOperation, - methodName authcommon.ServerFunctionName) *authcommon.AcquireContext { - return authcommon.NewAcquireContext( - authcommon.WithRequestContext(ctx), - authcommon.WithOperation(resourceOp), - authcommon.WithModule(authcommon.MaintainModule), - authcommon.WithMethod(methodName), - ) -} - -func (s *Server) HasMainUser(ctx context.Context) (bool, error) { - return false, nil -} - -func (s *Server) InitMainUser(ctx context.Context, user apisecurity.User) error { - return nil -} - -func (svr *Server) GetServerConnections(ctx context.Context, req *admincommon.ConnReq) (*admincommon.ConnCountResp, error) { - authCtx := svr.collectMaintainAuthContext(ctx, authcommon.Read, authcommon.DescribeServerConnections) - if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return nil, err - } - - ctx = authCtx.GetRequestContext() - ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) - - return svr.nextSvr.GetServerConnections(ctx, req) -} - -func (svr *Server) GetServerConnStats(ctx context.Context, req *admincommon.ConnReq) (*admincommon.ConnStatsResp, error) { - authCtx := svr.collectMaintainAuthContext(ctx, authcommon.Read, authcommon.DescribeServerConnStats) - if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return nil, err - } - - ctx = authCtx.GetRequestContext() - ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) - - return svr.nextSvr.GetServerConnStats(ctx, req) -} - -func (svr *Server) CloseConnections(ctx context.Context, reqs []admincommon.ConnReq) error { - authCtx := svr.collectMaintainAuthContext(ctx, authcommon.Delete, authcommon.CloseConnections) - if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return err - } - - ctx = authCtx.GetRequestContext() - ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) - - return svr.nextSvr.CloseConnections(ctx, reqs) -} - -func (svr *Server) FreeOSMemory(ctx context.Context) error { - authCtx := svr.collectMaintainAuthContext(ctx, authcommon.Modify, authcommon.FreeOSMemory) - if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return err - } - - ctx = authCtx.GetRequestContext() - ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) - - return svr.nextSvr.FreeOSMemory(ctx) -} - -func (svr *Server) CleanInstance(ctx context.Context, req *apiservice.Instance) *apiservice.Response { - authCtx := svr.collectMaintainAuthContext(ctx, authcommon.Delete, authcommon.CleanInstance) - if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewResponse(authcommon.ConvertToErrCode(err)) - } - - ctx = authCtx.GetRequestContext() - ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) - - return svr.nextSvr.CleanInstance(ctx, req) -} - -func (svr *Server) BatchCleanInstances(ctx context.Context, batchSize uint32) (uint32, error) { - authCtx := svr.collectMaintainAuthContext(ctx, authcommon.Delete, authcommon.BatchCleanInstances) - if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return 0, err - } - - return svr.nextSvr.BatchCleanInstances(ctx, batchSize) -} - -func (svr *Server) GetLastHeartbeat(ctx context.Context, req *apiservice.Instance) *apiservice.Response { - authCtx := svr.collectMaintainAuthContext(ctx, authcommon.Read, authcommon.DescribeInstanceLastHeartbeat) - if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewResponse(authcommon.ConvertToErrCode(err)) - } - - ctx = authCtx.GetRequestContext() - ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) - - return svr.nextSvr.GetLastHeartbeat(ctx, req) -} - -func (svr *Server) GetLogOutputLevel(ctx context.Context) ([]admincommon.ScopeLevel, error) { - authCtx := svr.collectMaintainAuthContext(ctx, authcommon.Read, authcommon.DescribeGetLogOutputLevel) - if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return nil, err - } - - ctx = authCtx.GetRequestContext() - ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) - - return svr.nextSvr.GetLogOutputLevel(ctx) -} - -func (svr *Server) SetLogOutputLevel(ctx context.Context, scope string, level string) error { - authCtx := svr.collectMaintainAuthContext(ctx, authcommon.Modify, authcommon.UpdateLogOutputLevel) - if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return err - } - - return svr.nextSvr.SetLogOutputLevel(ctx, scope, level) -} - -func (svr *Server) ListLeaderElections(ctx context.Context) ([]*admincommon.LeaderElection, error) { - authCtx := svr.collectMaintainAuthContext(ctx, authcommon.Read, authcommon.DescribeLeaderElections) - if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return nil, err - } - - ctx = authCtx.GetRequestContext() - ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) - - return svr.nextSvr.ListLeaderElections(ctx) -} - -func (svr *Server) ReleaseLeaderElection(ctx context.Context, electKey string) error { - authCtx := svr.collectMaintainAuthContext(ctx, authcommon.Modify, authcommon.ReleaseLeaderElection) - if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return err - } - - ctx = authCtx.GetRequestContext() - ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) - - return svr.nextSvr.ReleaseLeaderElection(ctx, electKey) -} - -func (svr *Server) GetCMDBInfo(ctx context.Context) ([]model.LocationView, error) { - authCtx := svr.collectMaintainAuthContext(ctx, authcommon.Read, authcommon.DescribeCMDBInfo) - if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return nil, err - } - - ctx = authCtx.GetRequestContext() - ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) - - return svr.nextSvr.GetCMDBInfo(ctx) -} - -// GetServerFunctions . -func (svr *Server) GetServerFunctions(ctx context.Context) []authcommon.ServerFunctionGroup { - return svr.nextSvr.GetServerFunctions(ctx) -} diff --git a/admin/interceptor/register.go b/admin/interceptor/register.go deleted file mode 100644 index d151f615d..000000000 --- a/admin/interceptor/register.go +++ /dev/null @@ -1,67 +0,0 @@ -/** - * Tencent is pleased to support the open source community by making Polaris available. - * - * Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. - * - * Licensed under the BSD 3-Clause License (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://opensource.org/licenses/BSD-3-Clause - * - * Unless required by applicable law or agreed to in writing, software distributed - * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR - * CONDITIONS OF ANY KIND, either express or implied. See the License for the - * specific language governing permissions and limitations under the License. - */ - -package inteceptor - -import ( - "context" - - "github.com/polarismesh/polaris/admin" - admin_auth "github.com/polarismesh/polaris/admin/interceptor/auth" - "github.com/polarismesh/polaris/auth" -) - -type ( - ContextKeyUserSvr struct{} - ContextKeyPolicySvr struct{} -) - -func init() { - err := admin.RegisterServerProxy("auth", func(ctx context.Context, - pre admin.AdminOperateServer) (admin.AdminOperateServer, error) { - - var userSvr auth.UserServer - var policySvr auth.StrategyServer - - userSvrVal := ctx.Value(ContextKeyUserSvr{}) - if userSvrVal == nil { - svr, err := auth.GetUserServer() - if err != nil { - return nil, err - } - userSvr = svr - } else { - userSvr = userSvrVal.(auth.UserServer) - } - - policySvrVal := ctx.Value(ContextKeyPolicySvr{}) - if policySvrVal == nil { - svr, err := auth.GetStrategyServer() - if err != nil { - return nil, err - } - policySvr = svr - } else { - policySvr = policySvrVal.(auth.StrategyServer) - } - - return admin_auth.NewServer(pre, userSvr, policySvr), nil - }) - if err != nil { - panic(err) - } -} diff --git a/admin/maintain.go b/admin/maintain.go index 0b4fafa74..208eded34 100644 --- a/admin/maintain.go +++ b/admin/maintain.go @@ -33,7 +33,6 @@ import ( commonlog "github.com/polarismesh/polaris/common/log" "github.com/polarismesh/polaris/common/model" "github.com/polarismesh/polaris/common/model/admin" - authcommon "github.com/polarismesh/polaris/common/model/auth" commonstore "github.com/polarismesh/polaris/common/store" "github.com/polarismesh/polaris/common/utils" "github.com/polarismesh/polaris/plugin" @@ -167,11 +166,11 @@ func (s *Server) CleanInstance(ctx context.Context, req *apiservice.Instance) *a } if err := s.storage.CleanInstance(instanceID); err != nil { log.Error("Clean instance", - zap.String("err", err.Error()), utils.RequestID(ctx)) + zap.String("err", err.Error()), utils.ZapRequestID(utils.ParseRequestID(ctx))) return api.NewInstanceResponse(commonstore.StoreCode2APICode(err), req) } - log.Info("Clean instance", utils.RequestID(ctx), utils.ZapInstanceID(instanceID)) + log.Info("Clean instance", utils.ZapRequestID(utils.ParseRequestID(ctx)), utils.ZapInstanceID(instanceID)) return api.NewInstanceResponse(apimodel.Code_ExecuteSuccess, req) } @@ -206,6 +205,7 @@ func (s *Server) ListLeaderElections(_ context.Context) ([]*admin.LeaderElection func (s *Server) ReleaseLeaderElection(_ context.Context, electKey string) error { return s.storage.ReleaseLeaderElection(electKey) + } func (svr *Server) GetCMDBInfo(ctx context.Context) ([]model.LocationView, error) { @@ -230,8 +230,3 @@ func (svr *Server) GetCMDBInfo(ctx context.Context) ([]model.LocationView, error return ret, nil } - -// GetServerFunctions 获取服务端支持的功能列表 -func (svr *Server) GetServerFunctions(ctx context.Context) []authcommon.ServerFunctionGroup { - return authcommon.ServerFunctions -} diff --git a/admin/maintain_authability.go b/admin/maintain_authability.go new file mode 100644 index 000000000..d90e14a56 --- /dev/null +++ b/admin/maintain_authability.go @@ -0,0 +1,179 @@ +/** + * Tencent is pleased to support the open source community by making Polaris available. + * + * Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. + * + * Licensed under the BSD 3-Clause License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://opensource.org/licenses/BSD-3-Clause + * + * Unless required by applicable law or agreed to in writing, software distributed + * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR + * CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package admin + +import ( + "context" + + apisecurity "github.com/polarismesh/specification/source/go/api/v1/security" + apiservice "github.com/polarismesh/specification/source/go/api/v1/service_manage" + + api "github.com/polarismesh/polaris/common/api/v1" + "github.com/polarismesh/polaris/common/model" + "github.com/polarismesh/polaris/common/model/admin" + authcommon "github.com/polarismesh/polaris/common/model/auth" + "github.com/polarismesh/polaris/common/utils" +) + +var _ AdminOperateServer = (*serverAuthAbility)(nil) + +func (s *serverAuthAbility) HasMainUser(ctx context.Context) (bool, error) { + return false, nil +} + +func (s *serverAuthAbility) InitMainUser(ctx context.Context, user apisecurity.User) error { + return nil +} + +func (svr *serverAuthAbility) GetServerConnections(ctx context.Context, req *admin.ConnReq) (*admin.ConnCountResp, error) { + authCtx := svr.collectMaintainAuthContext(ctx, authcommon.Read, authcommon.DescribeServerConnections) + if _, err := svr.strategyMgn.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + return nil, err + } + + ctx = authCtx.GetRequestContext() + ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) + + return svr.targetServer.GetServerConnections(ctx, req) +} + +func (svr *serverAuthAbility) GetServerConnStats(ctx context.Context, req *admin.ConnReq) (*admin.ConnStatsResp, error) { + authCtx := svr.collectMaintainAuthContext(ctx, authcommon.Read, authcommon.DescribeServerConnStats) + if _, err := svr.strategyMgn.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + return nil, err + } + + ctx = authCtx.GetRequestContext() + ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) + + return svr.targetServer.GetServerConnStats(ctx, req) +} + +func (svr *serverAuthAbility) CloseConnections(ctx context.Context, reqs []admin.ConnReq) error { + authCtx := svr.collectMaintainAuthContext(ctx, authcommon.Delete, authcommon.CloseConnections) + if _, err := svr.strategyMgn.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + return err + } + + ctx = authCtx.GetRequestContext() + ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) + + return svr.targetServer.CloseConnections(ctx, reqs) +} + +func (svr *serverAuthAbility) FreeOSMemory(ctx context.Context) error { + authCtx := svr.collectMaintainAuthContext(ctx, authcommon.Modify, authcommon.FreeOSMemory) + if _, err := svr.strategyMgn.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + return err + } + + ctx = authCtx.GetRequestContext() + ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) + + return svr.targetServer.FreeOSMemory(ctx) +} + +func (svr *serverAuthAbility) CleanInstance(ctx context.Context, req *apiservice.Instance) *apiservice.Response { + authCtx := svr.collectMaintainAuthContext(ctx, authcommon.Delete, authcommon.CleanInstance) + if _, err := svr.strategyMgn.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + return api.NewResponseWithMsg(convertToErrCode(err), err.Error()) + } + + ctx = authCtx.GetRequestContext() + ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) + + return svr.targetServer.CleanInstance(ctx, req) +} + +func (svr *serverAuthAbility) BatchCleanInstances(ctx context.Context, batchSize uint32) (uint32, error) { + authCtx := svr.collectMaintainAuthContext(ctx, authcommon.Delete, authcommon.BatchCleanInstances) + if _, err := svr.strategyMgn.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + return 0, err + } + + return svr.targetServer.BatchCleanInstances(ctx, batchSize) +} + +func (svr *serverAuthAbility) GetLastHeartbeat(ctx context.Context, req *apiservice.Instance) *apiservice.Response { + authCtx := svr.collectMaintainAuthContext(ctx, authcommon.Read, authcommon.DescribeInstanceLastHeartbeat) + if _, err := svr.strategyMgn.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + return api.NewResponseWithMsg(convertToErrCode(err), err.Error()) + } + + ctx = authCtx.GetRequestContext() + ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) + + return svr.targetServer.GetLastHeartbeat(ctx, req) +} + +func (svr *serverAuthAbility) GetLogOutputLevel(ctx context.Context) ([]admin.ScopeLevel, error) { + authCtx := svr.collectMaintainAuthContext(ctx, authcommon.Read, authcommon.DescribeGetLogOutputLevel) + if _, err := svr.strategyMgn.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + return nil, err + } + + ctx = authCtx.GetRequestContext() + ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) + + return svr.targetServer.GetLogOutputLevel(ctx) +} + +func (svr *serverAuthAbility) SetLogOutputLevel(ctx context.Context, scope string, level string) error { + authCtx := svr.collectMaintainAuthContext(ctx, authcommon.Modify, authcommon.UpdateLogOutputLevel) + if _, err := svr.strategyMgn.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + return err + } + + return svr.targetServer.SetLogOutputLevel(ctx, scope, level) +} + +func (svr *serverAuthAbility) ListLeaderElections(ctx context.Context) ([]*admin.LeaderElection, error) { + authCtx := svr.collectMaintainAuthContext(ctx, authcommon.Read, authcommon.DescribeLeaderElections) + if _, err := svr.strategyMgn.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + return nil, err + } + + ctx = authCtx.GetRequestContext() + ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) + + return svr.targetServer.ListLeaderElections(ctx) +} + +func (svr *serverAuthAbility) ReleaseLeaderElection(ctx context.Context, electKey string) error { + authCtx := svr.collectMaintainAuthContext(ctx, authcommon.Modify, authcommon.ReleaseLeaderElection) + if _, err := svr.strategyMgn.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + return err + } + + ctx = authCtx.GetRequestContext() + ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) + + return svr.targetServer.ReleaseLeaderElection(ctx, electKey) +} + +func (svr *serverAuthAbility) GetCMDBInfo(ctx context.Context) ([]model.LocationView, error) { + authCtx := svr.collectMaintainAuthContext(ctx, authcommon.Read, authcommon.DescribeCMDBInfo) + if _, err := svr.strategyMgn.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + return nil, err + } + + ctx = authCtx.GetRequestContext() + ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) + + return svr.targetServer.GetCMDBInfo(ctx) +} diff --git a/admin/server.go b/admin/server.go index 9ab899f00..e2e682f76 100644 --- a/admin/server.go +++ b/admin/server.go @@ -35,9 +35,3 @@ type Server struct { cacheMgn *cache.CacheManager storage store.Store } - -func GetChainOrder() []string { - return []string{ - "auth", - } -} diff --git a/admin/server_authability.go b/admin/server_authability.go new file mode 100644 index 000000000..2ddfbcbde --- /dev/null +++ b/admin/server_authability.go @@ -0,0 +1,68 @@ +/** + * Tencent is pleased to support the open source community by making Polaris available. + * + * Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. + * + * Licensed under the BSD 3-Clause License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://opensource.org/licenses/BSD-3-Clause + * + * Unless required by applicable law or agreed to in writing, software distributed + * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR + * CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package admin + +import ( + "context" + "errors" + + apimodel "github.com/polarismesh/specification/source/go/api/v1/model" + + "github.com/polarismesh/polaris/auth" + authcommon "github.com/polarismesh/polaris/common/model/auth" +) + +// serverAuthAbility 带有鉴权能力的 maintainServer +type serverAuthAbility struct { + targetServer *Server + userMgn auth.UserServer + strategyMgn auth.StrategyServer +} + +func newServerAuthAbility(targetServer *Server, + userMgn auth.UserServer, strategyMgn auth.StrategyServer) AdminOperateServer { + proxy := &serverAuthAbility{ + targetServer: targetServer, + userMgn: userMgn, + strategyMgn: strategyMgn, + } + + return proxy +} + +func (svr *serverAuthAbility) collectMaintainAuthContext(ctx context.Context, resourceOp authcommon.ResourceOperation, + methodName authcommon.ServerFunctionName) *authcommon.AcquireContext { + return authcommon.NewAcquireContext( + authcommon.WithRequestContext(ctx), + authcommon.WithOperation(resourceOp), + authcommon.WithModule(authcommon.MaintainModule), + authcommon.WithMethod(methodName), + ) +} + +func convertToErrCode(err error) apimodel.Code { + if errors.Is(err, authcommon.ErrorTokenNotExist) { + return apimodel.Code_TokenNotExisted + } + + if errors.Is(err, authcommon.ErrorTokenDisabled) { + return apimodel.Code_TokenDisabled + } + + return apimodel.Code_NotAllowedAccess +} diff --git a/apiserver/httpserver/admin_access.go b/apiserver/httpserver/admin_access.go index 22784edac..80589b46a 100644 --- a/apiserver/httpserver/admin_access.go +++ b/apiserver/httpserver/admin_access.go @@ -67,7 +67,6 @@ func (h *HTTPServer) GetAdminAccessServer() *restful.WebService { ws.Route(docs.EnrichGetCMDBInfoApiDocs(ws.GET("/cmdb/info").To(h.GetCMDBInfo))) ws.Route(docs.EnrichGetReportClientsApiDocs(ws.GET("/report/clients").To(h.GetReportClients))) ws.Route(docs.EnrichEnablePprofApiDocs(ws.POST("/pprof/enable").To(h.EnablePprof))) - ws.Route(docs.EnrichGetServerFunctionsApiDocs(ws.GET("/server/functions").To(h.GetServerFunctions))) return ws } @@ -309,13 +308,6 @@ func (h *HTTPServer) EnablePprof(req *restful.Request, rsp *restful.Response) { _ = rsp.WriteEntity("ok") } -// GetServerFunctions . -func (h *HTTPServer) GetServerFunctions(req *restful.Request, rsp *restful.Response) { - ctx := initContext(req) - ret := h.maintainServer.GetServerFunctions(ctx) - _ = rsp.WriteAsJson(ret) -} - func initContext(req *restful.Request) context.Context { ctx := context.Background() diff --git a/apiserver/httpserver/auth_access.go b/apiserver/httpserver/auth_access.go index ce2a064f2..15c5fc0e2 100644 --- a/apiserver/httpserver/auth_access.go +++ b/apiserver/httpserver/auth_access.go @@ -35,7 +35,7 @@ import ( // GetAuthServer 运维接口 func (h *HTTPServer) GetAuthServer(ws *restful.WebService) error { ws.Route(docs.EnrichAuthStatusApiDocs(ws.GET("/auth/status").To(h.AuthStatus))) - // 用户 + // ws.Route(docs.EnrichLoginApiDocs(ws.POST("/user/login").To(h.Login))) ws.Route(docs.EnrichGetUsersApiDocs(ws.GET("/users").To(h.GetUsers))) ws.Route(docs.EnrichCreateUsersApiDocs(ws.POST("/users").To(h.CreateUsers))) @@ -45,8 +45,7 @@ func (h *HTTPServer) GetAuthServer(ws *restful.WebService) error { ws.Route(docs.EnrichGetUserTokenApiDocs(ws.GET("/user/token").To(h.GetUserToken))) ws.Route(docs.EnrichUpdateUserTokenApiDocs(ws.PUT("/user/token/status").To(h.EnableUserToken))) ws.Route(docs.EnrichResetUserTokenApiDocs(ws.PUT("/user/token/refresh").To(h.ResetUserToken))) - - // 用户组 + // ws.Route(docs.EnrichCreateGroupApiDocs(ws.POST("/usergroup").To(h.CreateGroup))) ws.Route(docs.EnrichUpdateGroupsApiDocs(ws.PUT("/usergroups").To(h.UpdateGroups))) ws.Route(docs.EnrichGetGroupsApiDocs(ws.GET("/usergroups").To(h.GetGroups))) @@ -56,7 +55,6 @@ func (h *HTTPServer) GetAuthServer(ws *restful.WebService) error { ws.Route(docs.EnrichUpdateGroupTokenApiDocs(ws.PUT("/usergroup/token/status").To(h.EnableGroupToken))) ws.Route(docs.EnrichResetGroupTokenApiDocs(ws.PUT("/usergroup/token/refresh").To(h.ResetGroupToken))) - // 鉴权策略 ws.Route(docs.EnrichCreateStrategyApiDocs(ws.POST("/auth/strategy").To(h.CreateStrategy))) ws.Route(docs.EnrichGetStrategyApiDocs(ws.GET("/auth/strategy/detail").To(h.GetStrategy))) ws.Route(docs.EnrichUpdateStrategiesApiDocs(ws.PUT("/auth/strategies").To(h.UpdateStrategies))) @@ -64,12 +62,6 @@ func (h *HTTPServer) GetAuthServer(ws *restful.WebService) error { ws.Route(docs.EnrichGetStrategiesApiDocs(ws.GET("/auth/strategies").To(h.GetStrategies))) ws.Route(docs.EnrichGetPrincipalResourcesApiDocs(ws.GET("/auth/principal/resources").To(h.GetPrincipalResources))) - // 角色 - ws.Route(docs.EnrichGetRolesApiDocs(ws.GET("/roles").To(h.GetRoles))) - ws.Route(docs.EnrichCreateRolesApiDocs(ws.POST("/roles").To(h.CreateRoles))) - ws.Route(docs.EnrichDeleteRolesApiDocs(ws.POST("/roles/delete").To(h.DeleteRoles))) - ws.Route(docs.EnrichUpdateRolesApiDocs(ws.PUT("/roles").To(h.UpdateRoles))) - return nil } @@ -506,79 +498,3 @@ func (h *HTTPServer) GetPrincipalResources(req *restful.Request, rsp *restful.Re handler.WriteHeaderAndProto(h.strategyMgn.GetPrincipalResources(ctx, queryParams)) } - -// CreateRoles . -func (h *HTTPServer) CreateRoles(req *restful.Request, rsp *restful.Response) { - handler := &httpcommon.Handler{ - Request: req, - Response: rsp, - } - - roles := make([]*apisecurity.Role, 0, 4) - ctx, err := handler.ParseArray(func() proto.Message { - msg := &apisecurity.Role{} - roles = append(roles, msg) - return msg - }) - if err != nil { - handler.WriteHeaderAndProto(api.NewBatchWriteResponseWithMsg(apimodel.Code_ParseException, err.Error())) - return - } - - handler.WriteHeaderAndProto(h.strategyMgn.CreateRoles(ctx, roles)) -} - -// UpdateRoles . -func (h *HTTPServer) UpdateRoles(req *restful.Request, rsp *restful.Response) { - handler := &httpcommon.Handler{ - Request: req, - Response: rsp, - } - - roles := make([]*apisecurity.Role, 0, 4) - ctx, err := handler.ParseArray(func() proto.Message { - msg := &apisecurity.Role{} - roles = append(roles, msg) - return msg - }) - if err != nil { - handler.WriteHeaderAndProto(api.NewBatchWriteResponseWithMsg(apimodel.Code_ParseException, err.Error())) - return - } - - handler.WriteHeaderAndProto(h.strategyMgn.UpdateRoles(ctx, roles)) -} - -// DeleteRoles . -func (h *HTTPServer) DeleteRoles(req *restful.Request, rsp *restful.Response) { - handler := &httpcommon.Handler{ - Request: req, - Response: rsp, - } - - roles := make([]*apisecurity.Role, 0, 4) - ctx, err := handler.ParseArray(func() proto.Message { - msg := &apisecurity.Role{} - roles = append(roles, msg) - return msg - }) - if err != nil { - handler.WriteHeaderAndProto(api.NewBatchWriteResponseWithMsg(apimodel.Code_ParseException, err.Error())) - return - } - - handler.WriteHeaderAndProto(h.strategyMgn.DeleteRoles(ctx, roles)) -} - -// GetRoles 查询角色列表 -func (h *HTTPServer) GetRoles(req *restful.Request, rsp *restful.Response) { - handler := &httpcommon.Handler{ - Request: req, - Response: rsp, - } - - queryParams := httpcommon.ParseQueryParams(req) - ctx := handler.ParseHeaderContext() - - handler.WriteHeaderAndProto(h.strategyMgn.GetRoles(ctx, queryParams)) -} diff --git a/apiserver/httpserver/config/client_access.go b/apiserver/httpserver/config/client_access.go index ffade0a6b..7de65ebc3 100644 --- a/apiserver/httpserver/config/client_access.go +++ b/apiserver/httpserver/config/client_access.go @@ -168,12 +168,7 @@ func (h *HTTPServer) Discover(req *restful.Request, rsp *restful.Response) { switch in.Type { case apiconfig.ConfigDiscoverRequest_CONFIG_FILE: action = metrics.ActionGetConfigFile - ret := h.configServer.GetConfigFileWithCache(ctx, &apiconfig.ClientConfigFileInfo{ - Version: in.GetConfigFile().Version, - Namespace: in.GetConfigFile().GetNamespace(), - Group: in.GetConfigFile().GetGroup(), - FileName: in.GetConfigFile().GetFileName(), - }) + ret := h.configServer.GetConfigFileWithCache(ctx, &apiconfig.ClientConfigFileInfo{}) out = api.NewConfigDiscoverResponse(apimodel.Code(ret.GetCode().GetValue())) out.ConfigFile = ret.GetConfigFile() out.Type = apiconfig.ConfigDiscoverResponse_CONFIG_FILE diff --git a/apiserver/httpserver/discover/v1/console_access.go b/apiserver/httpserver/discover/v1/console_access.go index 86e6dd5ac..12a136e6d 100644 --- a/apiserver/httpserver/discover/v1/console_access.go +++ b/apiserver/httpserver/discover/v1/console_access.go @@ -1303,78 +1303,3 @@ func (h *HTTPServerV1) DeleteServiceContractInterfaces(req *restful.Request, rsp ret := h.namingServer.DeleteServiceContractInterfaces(ctx, msg) handler.WriteHeaderAndProto(ret) } - -// CreateLaneGroups 批量创建泳道组 -func (h *HTTPServerV1) CreateLaneGroups(req *restful.Request, rsp *restful.Response) { - handler := &httpcommon.Handler{ - Request: req, - Response: rsp, - } - groups := make([]*apitraffic.LaneGroup, 0) - ctx, err := handler.ParseArray(func() proto.Message { - msg := &apitraffic.LaneGroup{} - groups = append(groups, msg) - return msg - }) - if err != nil { - handler.WriteHeaderAndProto(api.NewBatchWriteResponseWithMsg(apimodel.Code_ParseException, err.Error())) - return - } - - ret := h.namingServer.CreateLaneGroups(ctx, groups) - handler.WriteHeaderAndProto(ret) -} - -// UpdateLaneGroups 批量更新泳道组 -func (h *HTTPServerV1) UpdateLaneGroups(req *restful.Request, rsp *restful.Response) { - handler := &httpcommon.Handler{ - Request: req, - Response: rsp, - } - groups := make([]*apitraffic.LaneGroup, 0) - ctx, err := handler.ParseArray(func() proto.Message { - msg := &apitraffic.LaneGroup{} - groups = append(groups, msg) - return msg - }) - if err != nil { - handler.WriteHeaderAndProto(api.NewBatchWriteResponseWithMsg(apimodel.Code_ParseException, err.Error())) - return - } - - ret := h.namingServer.UpdateLaneGroups(ctx, groups) - handler.WriteHeaderAndProto(ret) -} - -// DeleteLaneGroups 批量删除泳道组 -func (h *HTTPServerV1) DeleteLaneGroups(req *restful.Request, rsp *restful.Response) { - handler := &httpcommon.Handler{ - Request: req, - Response: rsp, - } - groups := make([]*apitraffic.LaneGroup, 0) - ctx, err := handler.ParseArray(func() proto.Message { - msg := &apitraffic.LaneGroup{} - groups = append(groups, msg) - return msg - }) - if err != nil { - handler.WriteHeaderAndProto(api.NewBatchWriteResponseWithMsg(apimodel.Code_ParseException, err.Error())) - return - } - - ret := h.namingServer.DeleteLaneGroups(ctx, groups) - handler.WriteHeaderAndProto(ret) -} - -// GetLaneGroups 批量删除泳道组 -func (h *HTTPServerV1) GetLaneGroups(req *restful.Request, rsp *restful.Response) { - handler := &httpcommon.Handler{ - Request: req, - Response: rsp, - } - queryParams := httpcommon.ParseQueryParams(req) - ctx := handler.ParseHeaderContext() - ret := h.namingServer.GetLaneGroups(ctx, queryParams) - handler.WriteHeaderAndProto(ret) -} diff --git a/apiserver/httpserver/discover/v1/server.go b/apiserver/httpserver/discover/v1/server.go index 42df28ae3..8544e34d9 100644 --- a/apiserver/httpserver/discover/v1/server.go +++ b/apiserver/httpserver/discover/v1/server.go @@ -93,7 +93,6 @@ func (h *HTTPServerV1) GetConsoleAccessServer(include []string) (*restful.WebSer h.addCircuitBreakerRuleAccess(ws) case routingAccess: h.addRoutingRuleAccess(ws) - h.addLaneRuleAccess(ws) case rateLimitAccess: h.addRateLimitRuleAccess(ws) } @@ -140,7 +139,6 @@ func (h *HTTPServerV1) addDefaultAccess(ws *restful.WebService) { // 管理端接口:增删改查请求全部操作存储层 h.addServiceAccess(ws) h.addRoutingRuleAccess(ws) - h.addLaneRuleAccess(ws) h.addRateLimitRuleAccess(ws) h.addCircuitBreakerRuleAccess(ws) } @@ -209,14 +207,6 @@ func (h *HTTPServerV1) addRoutingRuleAccess(ws *restful.WebService) { // Deprecate -- end } -// addLaneRuleAccess 泳道规则 -func (h *HTTPServerV1) addLaneRuleAccess(ws *restful.WebService) { - ws.Route(ws.POST("/lane/groups").To(h.CreateLaneGroups)) - ws.Route(ws.POST("/lane/groups/delete").To(h.DeleteLaneGroups)) - ws.Route(ws.PUT("/lane/groups").To(h.UpdateLaneGroups)) - ws.Route(ws.GET("/lane/groups").To(h.GetLaneGroups)) -} - func (h *HTTPServerV1) addRateLimitRuleAccess(ws *restful.WebService) { ws.Route(docs.EnrichCreateRateLimitsApiDocs(ws.POST("/ratelimits").To(h.CreateRateLimits))) ws.Route(docs.EnrichDeleteRateLimitsApiDocs(ws.POST("/ratelimits/delete").To(h.DeleteRateLimits))) diff --git a/apiserver/httpserver/docs/admin_apidoc.go b/apiserver/httpserver/docs/admin_apidoc.go index 810134d49..928a663c2 100644 --- a/apiserver/httpserver/docs/admin_apidoc.go +++ b/apiserver/httpserver/docs/admin_apidoc.go @@ -149,10 +149,3 @@ func EnrichEnablePprofApiDocs(r *restful.RouteBuilder) *restful.RouteBuilder { Enable bool `json:"enable"` }{}) } - -func EnrichGetServerFunctionsApiDocs(r *restful.RouteBuilder) *restful.RouteBuilder { - return r. - Doc("查询服务端的接口名称列表"). - Metadata(restfulspec.KeyOpenAPITags, maintainApiTags). - Returns(0, "", map[string][]string{}) -} diff --git a/apiserver/httpserver/docs/auth_apidoc.go b/apiserver/httpserver/docs/auth_apidoc.go index 28b9dc11a..f25ba9a68 100644 --- a/apiserver/httpserver/docs/auth_apidoc.go +++ b/apiserver/httpserver/docs/auth_apidoc.go @@ -26,8 +26,7 @@ import ( var ( authApiTags = []string{"AuthRule"} usersApiTags = []string{"Users"} - userGroupApiTags = []string{"UserGroups"} - roleApiTags = []string{"Roles"} + userGroupApiTags = []string{"Users"} ) func EnrichAuthStatusApiDocs(r *restful.RouteBuilder) *restful.RouteBuilder { @@ -351,43 +350,3 @@ func EnrichResetGroupTokenApiDocs(r *restful.RouteBuilder) *restful.RouteBuilder UserGroup apisecurity.UserGroup `json:"userGroup"` }{}) } - -func EnrichCreateRolesApiDocs(r *restful.RouteBuilder) *restful.RouteBuilder { - return r. - Doc("批量创建角色"). - Metadata(restfulspec.KeyOpenAPITags, roleApiTags). - Reads([]apisecurity.Role{}, "batch create role"). - Returns(0, "", struct { - BaseResponse - }{}) -} - -func EnrichUpdateRolesApiDocs(r *restful.RouteBuilder) *restful.RouteBuilder { - return r. - Doc("批量更新角色"). - Metadata(restfulspec.KeyOpenAPITags, roleApiTags). - Reads([]apisecurity.Role{}, "batch update role"). - Returns(0, "", struct { - BaseResponse - }{}) -} - -func EnrichDeleteRolesApiDocs(r *restful.RouteBuilder) *restful.RouteBuilder { - return r. - Doc("批量删除角色"). - Metadata(restfulspec.KeyOpenAPITags, roleApiTags). - Reads([]apisecurity.Role{}, "batch delete role"). - Returns(0, "", struct { - BaseResponse - }{}) -} - -func EnrichGetRolesApiDocs(r *restful.RouteBuilder) *restful.RouteBuilder { - return r. - Doc("查询角色列表"). - Metadata(restfulspec.KeyOpenAPITags, roleApiTags). - Reads([]apisecurity.Role{}, "query roles"). - Returns(0, "", struct { - BaseResponse - }{}) -} diff --git a/apiserver/httpserver/utils/handler.go b/apiserver/httpserver/utils/handler.go index 02be2a841..4aa2ad794 100644 --- a/apiserver/httpserver/utils/handler.go +++ b/apiserver/httpserver/utils/handler.go @@ -71,13 +71,15 @@ func (h *Handler) ParseArrayByText(createMessage func() proto.Message, text stri func (h *Handler) parseArray(createMessage func() proto.Message, jsonDecoder *json.Decoder) (context.Context, error) { requestID := h.Request.HeaderParameter("Request-Id") // read open bracket - if _, err := jsonDecoder.Token(); err != nil { + _, err := jsonDecoder.Token() + if err != nil { accesslog.Error(err.Error(), utils.ZapRequestID(requestID)) return nil, err } for jsonDecoder.More() { protoMessage := createMessage() - if err := UnmarshalNext(jsonDecoder, protoMessage); err != nil { + err := UnmarshalNext(jsonDecoder, protoMessage) + if err != nil { accesslog.Error(err.Error(), utils.ZapRequestID(requestID)) return nil, err } @@ -284,7 +286,7 @@ func (h *Handler) WriteHeaderAndProto(obj api.ResponseMessage) { status := api.CalcCode(obj) if status != http.StatusOK { - accesslog.Error(h.Request.Request.RequestURI+" "+obj.String(), utils.ZapRequestID(requestID)) + accesslog.Error(obj.String(), utils.ZapRequestID(requestID)) } if code := obj.GetCode().GetValue(); code != api.ExecuteSuccess { h.Response.AddHeader(utils.PolarisCode, fmt.Sprintf("%d", code)) @@ -315,8 +317,9 @@ func (h *Handler) WriteHeaderAndProtoV2(obj api.ResponseMessageV2) { h.Response.AddHeader(utils.PolarisRequestID, requestID) h.Response.WriteHeader(status) - m := newJsonpbMarshaler() - if err := m.Marshal(h.Response, obj); err != nil { + m := jsonpb.Marshaler{Indent: " ", EmitDefaults: true} + err := m.Marshal(h.Response, obj) + if err != nil { accesslog.Error(err.Error(), utils.ZapRequestID(requestID)) } } @@ -377,18 +380,14 @@ func ParseJsonBody(req *restful.Request, value interface{}) error { return nil } -func newJsonpbMarshaler() jsonpb.Marshaler { - return jsonpb.Marshaler{Indent: " ", EmitDefaults: true} -} - func (h *Handler) handleResponse(obj api.ResponseMessage) error { if !enableProtoCache { - m := newJsonpbMarshaler() + m := jsonpb.Marshaler{Indent: " ", EmitDefaults: true} return m.Marshal(h.Response, obj) } cacheVal := convert(obj) if cacheVal == nil { - m := newJsonpbMarshaler() + m := jsonpb.Marshaler{Indent: " ", EmitDefaults: true} return m.Marshal(h.Response, obj) } if saveVal := protoCache.Get(cacheVal.CacheType, cacheVal.Key); saveVal != nil { @@ -402,7 +401,7 @@ func (h *Handler) handleResponse(obj api.ResponseMessage) error { if err := cacheVal.Marshal(obj); err != nil { accesslog.Warn("[Api-http][ProtoCache] prepare message fail, direct send msg", zap.String("key", cacheVal.Key), zap.Error(err)) - m := newJsonpbMarshaler() + m := jsonpb.Marshaler{Indent: " ", EmitDefaults: true} return m.Marshal(h.Response, obj) } @@ -410,7 +409,7 @@ func (h *Handler) handleResponse(obj api.ResponseMessage) error { if !ok || cacheVal == nil { accesslog.Warn("[Api-http][ProtoCache] put cache ignore", zap.String("key", cacheVal.Key), zap.String("cacheType", cacheVal.CacheType)) - m := newJsonpbMarshaler() + m := jsonpb.Marshaler{Indent: " ", EmitDefaults: true} return m.Marshal(h.Response, obj) } if len(cacheVal.GetBuf()) > 0 { diff --git a/apiserver/nacosserver/core/storage.go b/apiserver/nacosserver/core/storage.go index 746af9e6c..c7826443f 100644 --- a/apiserver/nacosserver/core/storage.go +++ b/apiserver/nacosserver/core/storage.go @@ -197,7 +197,7 @@ func (n *NacosDataStorage) syncTask() { // 计算需要 refresh 的服务信息列表 for _, ns := range nsList { - _, svcs := n.cacheMgr.Service().ListServices(context.Background(), ns.Name) + _, svcs := n.cacheMgr.Service().ListServices(ns.Name) for _, svc := range svcs { revision := n.cacheMgr.Service().GetRevisionWorker().GetServiceInstanceRevision(svc.ID) oldRevision, ok := n.revisions[svc.ID] diff --git a/apiserver/nacosserver/model/service.go b/apiserver/nacosserver/model/service.go index 86666d307..c3e8569c8 100644 --- a/apiserver/nacosserver/model/service.go +++ b/apiserver/nacosserver/model/service.go @@ -18,7 +18,6 @@ package model import ( - "context" "strings" "github.com/polarismesh/polaris/service" @@ -44,7 +43,7 @@ type ServiceMetadata struct { func HandleServiceListRequest(discoverSvr service.DiscoverServer, namespace string, groupName string, pageNo int, pageSize int) ([]string, int) { - _, services := discoverSvr.Cache().Service().ListServices(context.Background(), namespace) + _, services := discoverSvr.Cache().Service().ListServices(namespace) offset := (pageNo - 1) * pageSize limit := pageSize if offset < 0 { diff --git a/apiserver/nacosserver/server.go b/apiserver/nacosserver/server.go index 43437ebac..913cc3281 100644 --- a/apiserver/nacosserver/server.go +++ b/apiserver/nacosserver/server.go @@ -148,7 +148,7 @@ func copyOption(m map[string]interface{}) map[string]interface{} { func (n *NacosServer) initPolarisResource() error { var err error - n.namespaceSvr, err = namespace.GetOriginServer() + n.namespaceSvr, err = namespace.GetServer() if err != nil { return err } diff --git a/apiserver/nacosserver/v1/config/access.go b/apiserver/nacosserver/v1/config/access.go index 0d3a6d65f..ed484fdda 100644 --- a/apiserver/nacosserver/v1/config/access.go +++ b/apiserver/nacosserver/v1/config/access.go @@ -142,7 +142,7 @@ func (n *ConfigServer) ConfigImport(req *restful.Request, rsp *restful.Response) var metaDataItem *ZipItem var items = make([]*ZipItem, 0, 32) - err := handler.ProcessZip(func(f *zip.File, data []byte) { + handler.ProcessZip(func(f *zip.File, data []byte) { if (f.Name == ConfigExportMetadata || f.Name == ConfigExpotrMetadataV2) && metaDataItem == nil { metaDataItem = &ZipItem{ Name: f.Name, @@ -155,10 +155,6 @@ func (n *ConfigServer) ConfigImport(req *restful.Request, rsp *restful.Response) Data: data, }) }) - if err != nil { - nacoshttp.WrirteNacosErrorResponse(err, rsp) - return - } policy := req.QueryParameter("policy") diff --git a/apiserver/nacosserver/v1/discover/instance.go b/apiserver/nacosserver/v1/discover/instance.go index 9e632a25b..87db26324 100644 --- a/apiserver/nacosserver/v1/discover/instance.go +++ b/apiserver/nacosserver/v1/discover/instance.go @@ -77,8 +77,8 @@ func (n *DiscoverServer) handleUpdate(ctx context.Context, namespace, serviceNam return nil } -func (n *DiscoverServer) handleDeregister(ctx context.Context, namespace, svcName string, ins *model.Instance) error { - specIns := model.PrepareSpecInstance(namespace, svcName, ins) +func (n *DiscoverServer) handleDeregister(ctx context.Context, namespace, service string, ins *model.Instance) error { + specIns := model.PrepareSpecInstance(namespace, service, ins) resp := n.discoverSvr.DeregisterInstance(ctx, specIns) if apimodel.Code(resp.GetCode().GetValue()) != apimodel.Code_ExecuteSuccess { return &model.NacosError{ @@ -90,19 +90,19 @@ func (n *DiscoverServer) handleDeregister(ctx context.Context, namespace, svcNam } // handleBeat com.alibaba.nacos.naming.core.InstanceOperatorClientImpl#handleBeat -func (n *DiscoverServer) handleBeat(ctx context.Context, namespace, svcName string, +func (n *DiscoverServer) handleBeat(ctx context.Context, namespace, service string, clientBeat *model.ClientBeat) (map[string]interface{}, error) { - svcName = model.ReplaceNacosService(svcName) - svc := n.discoverSvr.Cache().Service().GetServiceByName(svcName, namespace) + service = model.ReplaceNacosService(service) + svc := n.discoverSvr.Cache().Service().GetServiceByName(service, namespace) if svc == nil { return nil, &model.NacosError{ ErrCode: int32(model.ExceptionCode_ServerError), - ErrMsg: "service not found: " + svcName + "@" + namespace, + ErrMsg: "service not found: " + service + "@" + namespace, } } resp := n.healthSvr.Report(ctx, &apiservice.Instance{ - Service: utils.NewStringValue(model.ReplaceNacosService(svcName)), + Service: utils.NewStringValue(model.ReplaceNacosService(service)), Namespace: utils.NewStringValue(namespace), Host: utils.NewStringValue(clientBeat.Ip), Port: utils.NewUInt32Value(uint32(clientBeat.Port)), @@ -136,7 +136,7 @@ func (n *DiscoverServer) handleBeat(ctx context.Context, namespace, svcName stri func (n *DiscoverServer) handleQueryInstances(ctx context.Context, params map[string]string) (interface{}, error) { namespace := params[model.ParamNamespaceID] group := model.GetGroupName(params[model.ParamServiceName]) - svcName := model.GetServiceName(params[model.ParamServiceName]) + service := model.GetServiceName(params[model.ParamServiceName]) clusters := params["clusters"] clientIP := params["clientIP"] udpPort, _ := strconv.ParseInt(params["udpPort"], 10, 32) @@ -151,14 +151,14 @@ func (n *DiscoverServer) handleQueryInstances(ctx context.Context, params map[st Port: int(udpPort), NamespaceId: namespace, Group: group, - Service: svcName, + Service: service, Cluster: clusters, Type: core.UDPCPush, }) } filterCtx := &core.FilterContext{ - Service: core.ToNacosService(n.discoverSvr.Cache(), namespace, svcName, group), + Service: core.ToNacosService(n.discoverSvr.Cache(), namespace, service, group), Clusters: strings.Split(clusters, ","), EnableOnly: true, HealthyOnly: healthyOnly, diff --git a/apiserver/nacosserver/v1/endpoints.go b/apiserver/nacosserver/v1/endpoints.go index a9b4cd58d..321679d7b 100644 --- a/apiserver/nacosserver/v1/endpoints.go +++ b/apiserver/nacosserver/v1/endpoints.go @@ -69,5 +69,5 @@ func (n *NacosV1Server) FetchNacosEndpoints(req *restful.Request, rsp *restful.R } rsp.WriteHeader(http.StatusOK) - _, _ = rsp.Write([]byte(strings.Join(ips, "\n"))) + rsp.Write([]byte(strings.Join(ips, "\n"))) } diff --git a/auth/api.go b/auth/api.go index ebc0227f3..56fe7fdf0 100644 --- a/auth/api.go +++ b/auth/api.go @@ -166,14 +166,10 @@ type UserHelper interface { // PolicyHelper . type PolicyHelper interface { - // GetPolicyRule . - GetPolicyRule(id string) *authcommon.StrategyDetail // CreatePrincipal 创建 principal 的默认 policy 资源 CreatePrincipal(ctx context.Context, tx store.Tx, p authcommon.Principal) error // CleanPrincipal 清理 principal 所关联的 policy、role 资源 CleanPrincipal(ctx context.Context, tx store.Tx, p authcommon.Principal) error - // GetRole . - GetRole(id string) *authcommon.Role } // OperatorInfo 根据 token 解析出来的具体额外信息 @@ -194,7 +190,7 @@ type OperatorInfo struct { Anonymous bool } -func NewAnonymousOperatorInfo() OperatorInfo { +func NewAnonymous() OperatorInfo { return OperatorInfo{ Origin: "", OwnerID: "", diff --git a/auth/policy/auth_checker.go b/auth/policy/auth_checker.go index c8efc8d32..0a25a5958 100644 --- a/auth/policy/auth_checker.go +++ b/auth/policy/auth_checker.go @@ -18,8 +18,6 @@ package policy import ( - "context" - "encoding/json" "strings" "github.com/pkg/errors" @@ -94,12 +92,8 @@ func (d *DefaultAuthChecker) IsOpenAuth() bool { // AllowResourceOperate 是否允许资源的操作 func (d *DefaultAuthChecker) ResourcePredicate(ctx *authcommon.AcquireContext, res *authcommon.ResourceEntry) bool { - // 如果是客户端请求,并且鉴权能力没有开启,那就默认都可以进行操作 - if ctx.IsFromClient() && !d.IsOpenClientAuth() { - return true - } - // 如果是控制台请求,并且鉴权能力没有开启,那就默认都可以进行操作 - if ctx.IsFromConsole() && !d.IsOpenConsoleAuth() { + // 如果鉴权能力没有开启,那就默认都可以进行操作 + if !d.IsOpenAuth() { return true } @@ -107,16 +101,7 @@ func (d *DefaultAuthChecker) ResourcePredicate(ctx *authcommon.AcquireContext, r if !ok { return false } - policyCache := d.cacheMgr.AuthStrategy() - - principals := d.listAllPrincipals(p.(authcommon.Principal)) - for i := range principals { - ret := policyCache.Hint(ctx.GetRequestContext(), principals[i], res) - if ret != apisecurity.AuthAction_DENY { - return true - } - } - return false + return d.cacheMgr.AuthStrategy().Hint(p.(authcommon.Principal), res) != apisecurity.AuthAction_DENY } // CheckClientPermission 执行检查客户端动作判断是否有权限,并且对 RequestContext 注入操作者数据 @@ -144,12 +129,22 @@ func (d *DefaultAuthChecker) CheckConsolePermission(preCtx *authcommon.AcquireCo } // CheckPermission 执行检查动作判断是否有权限 +// +// step 1. 判断是否开启了鉴权 +// step 2. 对token进行检查判断 +// case 1. 如果 token 被禁用 +// a. 读操作,直接放通 +// b. 写操作,快速失败 +// step 3. 拉取token对应的操作者相关信息,注入到请求上下文中 +// step 4. 进行权限检查 func (d *DefaultAuthChecker) CheckPermission(authCtx *authcommon.AcquireContext) (bool, error) { if err := d.userSvr.CheckCredential(authCtx); err != nil { return false, err } - log.Info("[Auth][Checker] check permission args", utils.RequestID(authCtx.GetRequestContext()), - zap.Any("method", authCtx.GetMethods()), zap.Any("resources", authCtx.GetAccessResources())) + if log.DebugEnabled() { + log.Debug("[Auth][Checker] check permission args", utils.RequestID(authCtx.GetRequestContext()), + zap.String("method", string(authCtx.GetMethod())), zap.Any("resources", authCtx.GetAccessResources())) + } if pass, _ := d.doCheckPermission(authCtx); pass { return true, nil @@ -164,11 +159,11 @@ func (d *DefaultAuthChecker) CheckPermission(authCtx *authcommon.AcquireContext) func (d *DefaultAuthChecker) resyncData(authCtx *authcommon.AcquireContext) error { if err := d.cacheMgr.AuthStrategy().Update(); err != nil { - log.Error("[Auth][Checker] force sync policy failed", utils.RequestID(authCtx.GetRequestContext()), zap.Error(err)) + log.Error("[Auth][Checker] force sync policy rule to cache failed", utils.RequestID(authCtx.GetRequestContext()), zap.Error(err)) return err } if err := d.cacheMgr.Role().Update(); err != nil { - log.Error("[Auth][Checker] force sync role failed", utils.RequestID(authCtx.GetRequestContext()), zap.Error(err)) + log.Error("[Auth][Checker] force sync role to cache failed", utils.RequestID(authCtx.GetRequestContext()), zap.Error(err)) return err } return nil @@ -176,63 +171,32 @@ func (d *DefaultAuthChecker) resyncData(authCtx *authcommon.AcquireContext) erro // doCheckPermission 执行权限检查 func (d *DefaultAuthChecker) doCheckPermission(authCtx *authcommon.AcquireContext) (bool, error) { + p, _ := authCtx.GetAttachments()[authcommon.PrincipalKey].(authcommon.Principal) if d.IsCredible(authCtx) { return true, nil } - cur := authCtx.GetAttachments()[authcommon.PrincipalKey].(authcommon.Principal) - - principals := d.listAllPrincipals(cur) - // 遍历所有的 principal,检查是否有一个符合要求 - for i := range principals { - principal := principals[i] - allowPolicies := d.cacheMgr.AuthStrategy().GetPrincipalPolicies("allow", principal) - denyPolicies := d.cacheMgr.AuthStrategy().GetPrincipalPolicies("deny", principal) + allowPolicies := d.cacheMgr.AuthStrategy().GetPrincipalPolicies("allow", p) + denyPolicies := d.cacheMgr.AuthStrategy().GetPrincipalPolicies("deny", p) - resources := authCtx.GetAccessResources() + resources := authCtx.GetAccessResources() - // 先执行 deny 策略 - for i := range denyPolicies { - item := denyPolicies[i] - if d.MatchPolicy(authCtx, item, principal, resources) { - return false, ErrorNotPermission - } - } - - // 处理 allow 策略,只要有一个放开,就可以认为通过 - for i := range allowPolicies { - item := allowPolicies[i] - if d.MatchPolicy(authCtx, item, principal, resources) { - return true, nil - } + // 先执行 deny 策略 + for i := range denyPolicies { + item := denyPolicies[i] + if d.MatchPolicy(authCtx, item, p, resources) { + return false, ErrorNotPermission } } - return false, ErrorNotPermission -} - -func (d *DefaultAuthChecker) listAllPrincipals(p authcommon.Principal) []authcommon.Principal { - principals := make([]authcommon.Principal, 0, 4) - principals = append(principals, p) - // 获取角色列表 - roles := d.cacheMgr.Role().GetPrincipalRoles(p) - for i := range roles { - principals = append(principals, authcommon.Principal{ - PrincipalID: roles[i].ID, - PrincipalType: authcommon.PrincipalRole, - }) - } - // 如果是用户,获取所在的用户组列表 - if p.PrincipalType == authcommon.PrincipalUser { - groups := d.cacheMgr.User().GetUserLinkGroupIds(p.PrincipalID) - for i := range groups { - principals = append(principals, authcommon.Principal{ - PrincipalID: groups[i], - PrincipalType: authcommon.PrincipalGroup, - }) + // 处理 allow 策略,只要有一个放开,就可以认为通过 + for i := range allowPolicies { + item := allowPolicies[i] + if d.MatchPolicy(authCtx, item, p, resources) { + return true, nil } } - return principals + return false, ErrorNotPermission } // IsCredible 检查是否是可信的请求 @@ -263,13 +227,12 @@ func (d *DefaultAuthChecker) IsCredible(authCtx *authcommon.AcquireContext) bool func (d *DefaultAuthChecker) MatchPolicy(authCtx *authcommon.AcquireContext, policy *authcommon.StrategyDetail, principal authcommon.Principal, resources map[apisecurity.ResourceType][]authcommon.ResourceEntry) bool { if !d.MatchCalleeFunctions(authCtx, principal, policy) { - log.Error("server function match policy fail", utils.RequestID(authCtx.GetRequestContext()), - zap.String("principal", principal.String()), zap.String("policy-id", policy.ID)) return false } if !d.MatchResourceOperateable(authCtx, principal, policy) { - log.Error("access resource match policy fail", utils.RequestID(authCtx.GetRequestContext()), - zap.String("principal", principal.String()), zap.String("policy-id", policy.ID)) + return false + } + if !d.MatchResourceConditions(authCtx, principal, policy) { return false } return true @@ -278,92 +241,66 @@ func (d *DefaultAuthChecker) MatchPolicy(authCtx *authcommon.AcquireContext, pol // MatchCalleeFunctions 检查操作方法是否和策略匹配 func (d *DefaultAuthChecker) MatchCalleeFunctions(authCtx *authcommon.AcquireContext, principal authcommon.Principal, policy *authcommon.StrategyDetail) bool { - - // 如果开启了兼容模式,并且策略没有对可调用方法的拦截,那么就认为匹配成功 - if d.conf.Compatible && len(policy.CalleeMethods) == 0 { - return true - } - functions := policy.CalleeMethods - - allMatch := 0 - for _, method := range authCtx.GetMethods() { - curMatch := false - for i := range functions { - if utils.IsMatchAll(functions[i]) { - return true - } - if functions[i] == string(method) { - curMatch = true - break - } - if utils.IsWildMatch(string(method), functions[i]) { - curMatch = true - break - } + for i := range functions { + if functions[i] == string(authCtx.GetMethod()) { + return true } - if curMatch { - allMatch++ + if utils.IsWildMatch(string(authCtx.GetMethod()), functions[i]) { + return true } } - return allMatch == len(authCtx.GetMethods()) + return false } -type ( - compatibleChecker func(ctx context.Context, cacheSvr cachetypes.CacheManager, resource *authcommon.ResourceEntry) bool -) - -var ( - compatibleResource = map[apisecurity.ResourceType]compatibleChecker{ - apisecurity.ResourceType_UserGroups: func(ctx context.Context, cacheSvr cachetypes.CacheManager, - resource *authcommon.ResourceEntry) bool { - saveVal := cacheSvr.User().GetGroup(resource.ID) - if saveVal == nil { - return false - } - operator := utils.ParseUserID(ctx) - _, exist := saveVal.UserIds[operator] - return exist - }, - apisecurity.ResourceType_PolicyRules: func(ctx context.Context, cacheSvr cachetypes.CacheManager, - resource *authcommon.ResourceEntry) bool { - saveVal := cacheSvr.AuthStrategy().GetPolicyRule(resource.ID) - if saveVal == nil { - return false - } - operator := utils.ParseUserID(ctx) - for i := range saveVal.Principals { - if saveVal.Principals[i].PrincipalID == operator { - return true - } - } - return false - }, - } -) - // checkAction 检查操作资源是否和策略匹配 func (d *DefaultAuthChecker) MatchResourceOperateable(authCtx *authcommon.AcquireContext, principal authcommon.Principal, policy *authcommon.StrategyDetail) bool { + matchCheck := func(resType apisecurity.ResourceType, resources []authcommon.ResourceEntry) bool { + for i := range resources { + actionResult := d.cacheMgr.AuthStrategy().Hint(principal, &resources[i]) + if actionResult.String() == policy.Action { + return true + } + } + return false + } - // 检查下 principal 有没有 condition 信息 - principalCondition := make([]authcommon.Condition, 0, 4) - // 这里主要兼容一些内部特殊场景,可能在 role/user/group 关联某个策略时,会有一些额外的关系属性,这里在 extend 统一查找 - _ = json.Unmarshal([]byte(principal.Extend["condition"]), &principalCondition) - - ctx := context.Background() - if len(principalCondition) != 0 { - ctx = context.WithValue(context.Background(), authcommon.ContextKeyConditions{}, principalCondition) + reqRes := authCtx.GetAccessResources() + isMatch := false + for k, v := range reqRes { + if isMatch = matchCheck(k, v); isMatch { + break + } } + return isMatch +} +// MatchResourceConditions 检查操作资源所拥有的标签是否和策略匹配 +func (d *DefaultAuthChecker) MatchResourceConditions(authCtx *authcommon.AcquireContext, + principal authcommon.Principal, policy *authcommon.StrategyDetail) bool { matchCheck := func(resType apisecurity.ResourceType, resources []authcommon.ResourceEntry) bool { + conditions := policy.Conditions + for i := range resources { - actionResult := d.cacheMgr.AuthStrategy().Hint(ctx, principal, &resources[i]) - if policy.IsMatchAction(actionResult.String()) { - return true + allMatch := true + for j := range conditions { + condition := conditions[j] + resVal, ok := resources[i].Metadata[condition.Key] + if !ok { + allMatch = false + break + } + compareFunc, ok := conditionCompareDict[condition.CompareFunc] + if !ok { + allMatch = false + break + } + if allMatch = compareFunc(resVal, condition.Value); !allMatch { + break + } } - // 兼容模式下,对于用户组和策略规则,走一遍兜底的检查逻辑 - if _, ok := compatibleResource[resType]; ok && d.conf.Compatible { + if allMatch { return true } } @@ -371,10 +308,19 @@ func (d *DefaultAuthChecker) MatchResourceOperateable(authCtx *authcommon.Acquir } reqRes := authCtx.GetAccessResources() - isMatch := true + isMatch := false for k, v := range reqRes { - subMatch := matchCheck(k, v) - isMatch = isMatch && subMatch + if isMatch = matchCheck(k, v); isMatch { + break + } } return isMatch } + +var ( + conditionCompareDict = map[string]func(string, string) bool{ + "for_any_value:string_equal": func(s1, s2 string) bool { + return s1 == s2 + }, + } +) diff --git a/auth/policy/auth_checker_test.go b/auth/policy/auth_checker_test.go new file mode 100644 index 000000000..4668fc661 --- /dev/null +++ b/auth/policy/auth_checker_test.go @@ -0,0 +1,1136 @@ +/** + * Tencent is pleased to support the open source community by making Polaris available. + * + * Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. + * + * Licensed under the BSD 3-Clause License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://opensource.org/licenses/BSD-3-Clause + * + * Unless required by applicable law or agreed to in writing, software distributed + * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR + * CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package policy_test + +import ( + "context" + "testing" + + "github.com/golang/mock/gomock" + apisecurity "github.com/polarismesh/specification/source/go/api/v1/security" + "github.com/stretchr/testify/assert" + + "github.com/polarismesh/polaris/auth" + "github.com/polarismesh/polaris/auth/policy" + defaultuser "github.com/polarismesh/polaris/auth/user" + "github.com/polarismesh/polaris/cache" + cachetypes "github.com/polarismesh/polaris/cache/api" + authcommon "github.com/polarismesh/polaris/common/model/auth" + "github.com/polarismesh/polaris/common/utils" +) + +func newPolicyServer() (*policy.Server, auth.StrategyServer, error) { + return policy.BuildServer() +} + +func Test_DefaultAuthChecker_CheckConsolePermission_Write_NoStrict(t *testing.T) { + reset(false) + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + users := createMockUser(10) + groups := createMockUserGroup(users) + + namespaces := createMockNamespace(len(users)+len(groups)+10, users[0].ID) + services := createMockService(namespaces) + serviceMap := convertServiceSliceToMap(services) + strategies, _ := createMockStrategy(users, groups, services[:len(users)+len(groups)]) + + cfg, storage := initCache(ctrl) + + storage.EXPECT().GetUsersForCache(gomock.Any(), gomock.Any()).AnyTimes().Return(users, nil) + storage.EXPECT().GetGroupsForCache(gomock.Any(), gomock.Any()).AnyTimes().Return(groups, nil) + storage.EXPECT().GetMoreStrategies(gomock.Any(), gomock.Any()).AnyTimes().Return(strategies, nil) + storage.EXPECT().GetMoreNamespaces(gomock.Any()).AnyTimes().Return(namespaces, nil) + storage.EXPECT().GetMoreServices(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes().Return(serviceMap, nil) + + ctx, cancel := context.WithCancel(context.Background()) + cacheMgr, err := cache.TestCacheInitialize(ctx, cfg, storage) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { + cancel() + cacheMgr.Close() + }) + + _, proxySvr, err := defaultuser.BuildServer() + if err != nil { + t.Fatal(err) + } + proxySvr.Initialize(&auth.Config{ + User: &auth.UserConfig{ + Name: auth.DefaultUserMgnPluginName, + Option: map[string]interface{}{ + "salt": "polarismesh@2021", + }, + }, + }, storage, nil, cacheMgr) + + _, svr, err := newPolicyServer() + if err != nil { + t.Fatal(err) + } + if err := svr.Initialize(&auth.Config{ + Strategy: &auth.StrategyConfig{ + Name: auth.DefaultPolicyPluginName, + }, + }, storage, cacheMgr, proxySvr); err != nil { + t.Fatal(err) + } + checker := svr.GetAuthChecker() + + _ = cacheMgr.TestUpdate() + + freeIndex := len(users) + len(groups) + 1 + + t.Run("权限检查非严格模式-主账户资源访问检查", func(t *testing.T) { + ctx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, users[0].Token) + authCtx := authcommon.NewAcquireContext( + authcommon.WithRequestContext(ctx), + authcommon.WithMethod("Test_DefaultAuthChecker_VerifyCredential"), + authcommon.WithOperation(authcommon.Create), + authcommon.WithModule(authcommon.DiscoverModule), + authcommon.WithAccessResources(map[apisecurity.ResourceType][]authcommon.ResourceEntry{ + apisecurity.ResourceType_Services: { + { + ID: services[0].ID, + Owner: services[0].Owner, + }, + }, + }), + ) + + _, err = checker.CheckConsolePermission(authCtx) + t.Logf("%+v", err) + assert.NoError(t, err, "Should be verify success") + }) + + t.Run("权限检查非严格模式-子账户资源访问检查(无操作权限)", func(t *testing.T) { + ctx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, users[1].Token) + authCtx := authcommon.NewAcquireContext( + authcommon.WithRequestContext(ctx), + authcommon.WithMethod("Test_DefaultAuthChecker_VerifyCredential"), + authcommon.WithOperation(authcommon.Create), + authcommon.WithModule(authcommon.DiscoverModule), + authcommon.WithAccessResources(map[apisecurity.ResourceType][]authcommon.ResourceEntry{ + apisecurity.ResourceType_Services: { + { + ID: services[0].ID, + Owner: services[0].Owner, + }, + }, + }), + ) + + _, err = checker.CheckConsolePermission(authCtx) + t.Logf("%+v", err) + assert.Error(t, err, "Should be verify fail") + }) + + t.Run("权限检查非严格模式-子账户资源访问检查(有操作权限)", func(t *testing.T) { + ctx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, users[1].Token) + authCtx := authcommon.NewAcquireContext( + authcommon.WithRequestContext(ctx), + authcommon.WithMethod("Test_DefaultAuthChecker_VerifyCredential"), + authcommon.WithOperation(authcommon.Create), + authcommon.WithModule(authcommon.DiscoverModule), + authcommon.WithAccessResources(map[apisecurity.ResourceType][]authcommon.ResourceEntry{ + apisecurity.ResourceType_Services: { + { + ID: services[1].ID, + Owner: services[1].Owner, + }, + }, + }), + ) + + _, err = checker.CheckConsolePermission(authCtx) + t.Logf("%+v", err) + assert.NoError(t, err, "Should be verify success") + }) + + t.Run("权限检查非严格模式-子账户资源访问检查(资源无绑定策略)", func(t *testing.T) { + ctx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, users[1].Token) + authCtx := authcommon.NewAcquireContext( + authcommon.WithRequestContext(ctx), + authcommon.WithMethod("Test_DefaultAuthChecker_VerifyCredential"), + authcommon.WithOperation(authcommon.Create), + authcommon.WithModule(authcommon.DiscoverModule), + authcommon.WithAccessResources(map[apisecurity.ResourceType][]authcommon.ResourceEntry{ + apisecurity.ResourceType_Services: { + { + ID: services[freeIndex].ID, + Owner: services[freeIndex].Owner, + }, + }, + }), + ) + + _, err = checker.CheckConsolePermission(authCtx) + t.Logf("%+v", err) + assert.NoError(t, err, "Should be verify success") + }) + + t.Run("权限检查非严格模式-子账户访问用户组资源检查(属于用户组成员)", func(t *testing.T) { + ctx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, users[1].Token) + authCtx := authcommon.NewAcquireContext( + authcommon.WithRequestContext(ctx), + authcommon.WithMethod("Test_DefaultAuthChecker_VerifyCredential"), + authcommon.WithOperation(authcommon.Create), + authcommon.WithModule(authcommon.DiscoverModule), + authcommon.WithAccessResources(map[apisecurity.ResourceType][]authcommon.ResourceEntry{ + apisecurity.ResourceType_Services: { + { + ID: services[(len(users)-1)+2].ID, + Owner: services[(len(users)-1)+2].Owner, + }, + }, + }), + ) + + _, err = checker.CheckConsolePermission(authCtx) + t.Logf("%+v", err) + assert.NoError(t, err, "Should be verify success") + }) + + t.Run("权限检查非严格模式-子账户访问用户组资源检查(不属于用户组成员)", func(t *testing.T) { + ctx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, users[1].Token) + authCtx := authcommon.NewAcquireContext( + authcommon.WithRequestContext(ctx), + authcommon.WithMethod("Test_DefaultAuthChecker_VerifyCredential"), + authcommon.WithOperation(authcommon.Create), + authcommon.WithModule(authcommon.DiscoverModule), + authcommon.WithAccessResources(map[apisecurity.ResourceType][]authcommon.ResourceEntry{ + apisecurity.ResourceType_Services: { + { + ID: services[(len(users)-1)+4].ID, + Owner: services[(len(users)-1)+4].Owner, + }, + }, + }), + ) + + _, err = checker.CheckConsolePermission(authCtx) + t.Logf("%+v", err) + assert.Error(t, err, "Should be verify fail") + }) + + t.Run("权限检查非严格模式-用户组访问组内成员资源检查", func(t *testing.T) { + ctx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, groups[1].Token) + authCtx := authcommon.NewAcquireContext( + authcommon.WithRequestContext(ctx), + authcommon.WithMethod("Test_DefaultAuthChecker_VerifyCredential"), + // authcommon.WithToken(groups[1].Token), + authcommon.WithOperation(authcommon.Create), + authcommon.WithModule(authcommon.DiscoverModule), + authcommon.WithAccessResources(map[apisecurity.ResourceType][]authcommon.ResourceEntry{ + apisecurity.ResourceType_Services: { + { + ID: services[0].ID, + Owner: services[0].Owner, + }, + }, + }), + ) + + _, err = checker.CheckConsolePermission(authCtx) + t.Logf("%+v", err) + assert.Error(t, err, "Should be verify fail") + }) + + t.Run("权限检查非严格模式-token非法-匿名账户资源访问检查(资源无绑定策略)", func(t *testing.T) { + ctx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, "users[1].Token") + authCtx := authcommon.NewAcquireContext( + authcommon.WithRequestContext(ctx), + authcommon.WithMethod("Test_DefaultAuthChecker_VerifyCredential"), + authcommon.WithOperation(authcommon.Create), + authcommon.WithModule(authcommon.DiscoverModule), + authcommon.WithAccessResources(map[apisecurity.ResourceType][]authcommon.ResourceEntry{ + apisecurity.ResourceType_Services: { + { + ID: services[freeIndex].ID, + Owner: services[freeIndex].Owner, + }, + }, + }), + ) + + _, err = checker.CheckConsolePermission(authCtx) + t.Logf("%+v", err) + assert.NoError(t, err, "Should be verify success") + }) + + t.Run("权限检查非严格模式-token为空-匿名账户资源访问检查(资源无绑定策略)", func(t *testing.T) { + ctx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, "") + authCtx := authcommon.NewAcquireContext( + authcommon.WithRequestContext(ctx), + authcommon.WithMethod("Test_DefaultAuthChecker_VerifyCredential"), + authcommon.WithOperation(authcommon.Create), + authcommon.WithModule(authcommon.DiscoverModule), + authcommon.WithAccessResources(map[apisecurity.ResourceType][]authcommon.ResourceEntry{ + apisecurity.ResourceType_Services: { + { + ID: services[freeIndex].ID, + Owner: services[freeIndex].Owner, + }, + }, + }), + ) + + _, err = checker.CheckConsolePermission(authCtx) + t.Logf("%+v", err) + assert.NoError(t, err, "Should be verify success") + }) +} + +func Test_DefaultAuthChecker_CheckConsolePermission_Write_Strict(t *testing.T) { + reset(true) + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + users := createMockUser(10) + groups := createMockUserGroup(users) + + namespaces := createMockNamespace(len(users)+len(groups)+10, users[0].ID) + services := createMockService(namespaces) + serviceMap := convertServiceSliceToMap(services) + strategies, _ := createMockStrategy(users, groups, services[:len(users)+len(groups)]) + + cfg, storage := initCache(ctrl) + + storage.EXPECT().GetUsersForCache(gomock.Any(), gomock.Any()).AnyTimes().Return(users, nil) + storage.EXPECT().GetGroupsForCache(gomock.Any(), gomock.Any()).AnyTimes().Return(groups, nil) + storage.EXPECT().GetMoreStrategies(gomock.Any(), gomock.Any()).AnyTimes().Return(strategies, nil) + storage.EXPECT().GetMoreNamespaces(gomock.Any()).AnyTimes().Return(namespaces, nil) + storage.EXPECT().GetMoreServices(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes().Return(serviceMap, nil) + + ctx, cancel := context.WithCancel(context.Background()) + cacheMgr, err := cache.TestCacheInitialize(ctx, cfg, storage) + if err != nil { + t.Fatal(err) + } + + t.Cleanup(func() { + cancel() + cacheMgr.Close() + }) + + _, proxySvr, err := defaultuser.BuildServer() + if err != nil { + t.Fatal(err) + } + proxySvr.Initialize(&auth.Config{ + User: &auth.UserConfig{ + Name: auth.DefaultUserMgnPluginName, + Option: map[string]interface{}{ + "salt": "polarismesh@2021", + }, + }, + }, storage, nil, cacheMgr) + + _, svr, err := newPolicyServer() + if err != nil { + t.Fatal(err) + } + if err := svr.Initialize(&auth.Config{ + Strategy: &auth.StrategyConfig{ + Name: auth.DefaultPolicyPluginName, + }, + }, storage, cacheMgr, proxySvr); err != nil { + t.Fatal(err) + } + checker := svr.GetAuthChecker() + + _ = cacheMgr.TestUpdate() + + freeIndex := len(users) + len(groups) + 1 + + t.Run("权限检查严格模式-主账户操作资源", func(t *testing.T) { + ctx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, users[0].Token) + authCtx := authcommon.NewAcquireContext( + authcommon.WithRequestContext(ctx), + authcommon.WithMethod("Test_DefaultAuthChecker_CheckConsolePermission_Write_Strict"), + // authcommon.WithToken(users[0].Token), + authcommon.WithOperation(authcommon.Create), + authcommon.WithModule(authcommon.DiscoverModule), + authcommon.WithAccessResources(map[apisecurity.ResourceType][]authcommon.ResourceEntry{ + apisecurity.ResourceType_Services: { + { + ID: services[0].ID, + Owner: services[0].Owner, + }, + }, + }), + ) + + _, err = checker.CheckConsolePermission(authCtx) + t.Logf("%+v", err) + assert.NoError(t, err, "Should be verify success") + }) + + t.Run("权限检查严格模式-子账户操作资源(无操作权限)", func(t *testing.T) { + ctx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, users[1].Token) + authCtx := authcommon.NewAcquireContext( + authcommon.WithRequestContext(ctx), + authcommon.WithMethod("Test_DefaultAuthChecker_CheckConsolePermission_Write_Strict"), + // authcommon.WithToken(users[1].Token), + authcommon.WithOperation(authcommon.Create), + authcommon.WithModule(authcommon.DiscoverModule), + authcommon.WithAccessResources(map[apisecurity.ResourceType][]authcommon.ResourceEntry{ + apisecurity.ResourceType_Services: { + { + ID: services[0].ID, + Owner: services[0].Owner, + }, + }, + }), + ) + _, err = checker.CheckConsolePermission(authCtx) + t.Logf("%+v", err) + assert.Error(t, err, "Should be verify fail") + }) + + t.Run("权限检查严格模式-子账户操作资源(有操作权限)", func(t *testing.T) { + ctx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, users[1].Token) + authCtx := authcommon.NewAcquireContext( + authcommon.WithRequestContext(ctx), + authcommon.WithMethod("Test_DefaultAuthChecker_CheckConsolePermission_Write_Strict"), + // authcommon.WithToken(users[1].Token), + authcommon.WithOperation(authcommon.Create), + authcommon.WithModule(authcommon.DiscoverModule), + authcommon.WithAccessResources(map[apisecurity.ResourceType][]authcommon.ResourceEntry{ + apisecurity.ResourceType_Services: { + { + ID: services[1].ID, + Owner: services[1].Owner, + }, + }, + }), + ) + _, err = checker.CheckConsolePermission(authCtx) + t.Logf("%+v", err) + assert.NoError(t, err, "Should be verify success") + }) + + t.Run("权限检查严格模式-token非法-匿名账户操作资源(资源有策略)", func(t *testing.T) { + ctx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, "Test_DefaultAuthChecker_CheckConsolePermission_Write_Strict") + authCtx := authcommon.NewAcquireContext( + authcommon.WithRequestContext(ctx), + authcommon.WithMethod("Test_DefaultAuthChecker_CheckConsolePermission_Write_Strict"), + // authcommon.WithToken("Test_DefaultAuthChecker_VerifyCredential"), + authcommon.WithOperation(authcommon.Create), + authcommon.WithModule(authcommon.DiscoverModule), + authcommon.WithAccessResources(map[apisecurity.ResourceType][]authcommon.ResourceEntry{ + apisecurity.ResourceType_Services: { + { + ID: services[1].ID, + Owner: services[1].Owner, + }, + }, + }), + ) + _, err = checker.CheckConsolePermission(authCtx) + t.Logf("%+v", err) + assert.Error(t, err, "Should be verify fail") + }) + + t.Run("权限检查严格模式-token为空-匿名账户操作资源(资源有策略)", func(t *testing.T) { + ctx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, "") + authCtx := authcommon.NewAcquireContext( + authcommon.WithRequestContext(ctx), + authcommon.WithMethod("Test_DefaultAuthChecker_CheckConsolePermission_Write_Strict"), + // authcommon.WithToken(""), + authcommon.WithModule(authcommon.DiscoverModule), + authcommon.WithOperation(authcommon.Create), + authcommon.WithAccessResources(map[apisecurity.ResourceType][]authcommon.ResourceEntry{ + apisecurity.ResourceType_Services: { + { + ID: services[1].ID, + Owner: services[1].Owner, + }, + }, + }), + ) + _, err = checker.CheckConsolePermission(authCtx) + t.Logf("%+v", err) + assert.Error(t, err, "Should be verify fail") + }) + + t.Run("权限检查严格模式-token非法-匿名账户操作资源(资源没有策略)", func(t *testing.T) { + ctx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, "Test_DefaultAuthChecker_CheckConsolePermission_Write_Strict") + authCtx := authcommon.NewAcquireContext( + authcommon.WithRequestContext(ctx), + authcommon.WithMethod("Test_DefaultAuthChecker_CheckConsolePermission_Write_Strict"), + // authcommon.WithToken("Test_DefaultAuthChecker_VerifyCredential"), + authcommon.WithOperation(authcommon.Create), + authcommon.WithModule(authcommon.DiscoverModule), + authcommon.WithAccessResources(map[apisecurity.ResourceType][]authcommon.ResourceEntry{ + apisecurity.ResourceType_Services: { + { + ID: services[freeIndex].ID, + Owner: services[freeIndex].Owner, + }, + }, + }), + ) + + dchecker := checker.(*policy.DefaultAuthChecker) + oldConf := dchecker.GetConfig() + defer func() { + dchecker.SetConfig(oldConf) + }() + dchecker.SetConfig(&policy.AuthConfig{ + ConsoleOpen: true, + ConsoleStrict: true, + }) + + _, err = checker.CheckConsolePermission(authCtx) + t.Logf("%+v", err) + assert.Error(t, err, "Should be verify fail") + }) + + t.Run("权限检查严格模式-token为空-匿名账户操作资源(资源没有策略)", func(t *testing.T) { + ctx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, "") + authCtx := authcommon.NewAcquireContext( + authcommon.WithRequestContext(ctx), + authcommon.WithMethod("Test_DefaultAuthChecker_CheckConsolePermission_Write_Strict"), + // authcommon.WithToken(""), + authcommon.WithOperation(authcommon.Create), + authcommon.WithModule(authcommon.DiscoverModule), + authcommon.WithAccessResources(map[apisecurity.ResourceType][]authcommon.ResourceEntry{ + apisecurity.ResourceType_Services: { + { + ID: services[freeIndex].ID, + Owner: services[freeIndex].Owner, + }, + }, + }), + ) + dchecker := checker.(*policy.DefaultAuthChecker) + oldConf := dchecker.GetConfig() + defer func() { + dchecker.SetConfig(oldConf) + }() + dchecker.SetConfig(&policy.AuthConfig{ + ConsoleOpen: true, + ConsoleStrict: true, + }) + + _, err = dchecker.CheckConsolePermission(authCtx) + t.Logf("%+v", err) + assert.Error(t, err, "Should be verify fail") + }) +} + +func Test_DefaultAuthChecker_CheckConsolePermission_Read_NoStrict(t *testing.T) { + reset(false) + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + users := createMockUser(10) + groups := createMockUserGroup(users) + + namespaces := createMockNamespace(len(users)+len(groups)+10, users[0].ID) + services := createMockService(namespaces) + serviceMap := convertServiceSliceToMap(services) + strategies, _ := createMockStrategy(users, groups, services[:len(users)+len(groups)]) + + cfg, storage := initCache(ctrl) + + storage.EXPECT().GetUsersForCache(gomock.Any(), gomock.Any()).AnyTimes().Return(users, nil) + storage.EXPECT().GetGroupsForCache(gomock.Any(), gomock.Any()).AnyTimes().Return(groups, nil) + storage.EXPECT().GetMoreStrategies(gomock.Any(), gomock.Any()).AnyTimes().Return(strategies, nil) + storage.EXPECT().GetMoreNamespaces(gomock.Any()).AnyTimes().Return(namespaces, nil) + storage.EXPECT().GetMoreServices(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes().Return(serviceMap, nil) + + ctx, cancel := context.WithCancel(context.Background()) + cacheMgr, err := cache.TestCacheInitialize(ctx, cfg, storage) + if err != nil { + t.Fatal(err) + } + + t.Cleanup(func() { + cancel() + cacheMgr.Close() + }) + + _, proxySvr, err := defaultuser.BuildServer() + if err != nil { + t.Fatal(err) + } + proxySvr.Initialize(&auth.Config{ + User: &auth.UserConfig{ + Name: auth.DefaultUserMgnPluginName, + Option: map[string]interface{}{ + "salt": "polarismesh@2021", + }, + }, + }, storage, nil, cacheMgr) + + _, svr, err := newPolicyServer() + if err != nil { + t.Fatal(err) + } + if err := svr.Initialize(&auth.Config{ + Strategy: &auth.StrategyConfig{ + Name: auth.DefaultPolicyPluginName, + }, + }, storage, cacheMgr, proxySvr); err != nil { + t.Fatal(err) + } + checker := svr.GetAuthChecker() + + _ = cacheMgr.TestUpdate() + + freeIndex := len(users) + len(groups) + 1 + + t.Run("权限检查非严格模式-主账户正常读操作", func(t *testing.T) { + ctx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, users[0].Token) + authCtx := authcommon.NewAcquireContext( + authcommon.WithRequestContext(ctx), + authcommon.WithMethod("Test_DefaultAuthChecker_VerifyCredential"), + // authcommon.WithToken(users[0].Token), + authcommon.WithOperation(authcommon.Read), + authcommon.WithModule(authcommon.DiscoverModule), + authcommon.WithAccessResources(map[apisecurity.ResourceType][]authcommon.ResourceEntry{ + apisecurity.ResourceType_Services: { + { + ID: services[0].ID, + Owner: services[0].Owner, + }, + }, + }), + ) + _, err = checker.CheckConsolePermission(authCtx) + t.Logf("%+v", err) + assert.NoError(t, err, "Should be verify success") + }) + + t.Run("权限检查非严格模式-子账户正常读操作-资源有权限", func(t *testing.T) { + ctx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, users[1].Token) + authCtx := authcommon.NewAcquireContext( + authcommon.WithRequestContext(ctx), + authcommon.WithMethod("Test_DefaultAuthChecker_VerifyCredential"), + // authcommon.WithToken(users[1].Token), + authcommon.WithOperation(authcommon.Read), + authcommon.WithModule(authcommon.DiscoverModule), + authcommon.WithAccessResources(map[apisecurity.ResourceType][]authcommon.ResourceEntry{ + apisecurity.ResourceType_Services: { + { + ID: services[1].ID, + Owner: services[1].Owner, + }, + }, + }), + ) + _, err = checker.CheckConsolePermission(authCtx) + t.Logf("%+v", err) + assert.NoError(t, err, "Should be verify success") + }) + + t.Run("权限检查非严格模式-子账户正常读操作-资源无权限", func(t *testing.T) { + ctx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, users[1].Token) + authCtx := authcommon.NewAcquireContext( + authcommon.WithRequestContext(ctx), + authcommon.WithMethod("Test_DefaultAuthChecker_VerifyCredential"), + // authcommon.WithToken(users[1].Token), + authcommon.WithOperation(authcommon.Read), + authcommon.WithModule(authcommon.DiscoverModule), + authcommon.WithAccessResources(map[apisecurity.ResourceType][]authcommon.ResourceEntry{ + apisecurity.ResourceType_Services: { + { + ID: services[0].ID, + Owner: services[0].Owner, + }, + }, + }), + ) + _, err = checker.CheckConsolePermission(authCtx) + t.Logf("%+v", err) + assert.NoError(t, err, "Should be verify success") + }) + + t.Run("权限检查非严格模式-子账户正常读操作-资源无绑定策略", func(t *testing.T) { + ctx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, users[1].Token) + authCtx := authcommon.NewAcquireContext( + authcommon.WithRequestContext(ctx), + authcommon.WithMethod("Test_DefaultAuthChecker_VerifyCredential"), + // authcommon.WithToken(users[1].Token), + authcommon.WithOperation(authcommon.Read), + authcommon.WithModule(authcommon.DiscoverModule), + authcommon.WithAccessResources(map[apisecurity.ResourceType][]authcommon.ResourceEntry{ + apisecurity.ResourceType_Services: { + { + ID: services[freeIndex].ID, + Owner: services[freeIndex].Owner, + }, + }, + }), + ) + _, err = checker.CheckConsolePermission(authCtx) + t.Logf("%+v", err) + assert.NoError(t, err, "Should be verify success") + }) + + t.Run("权限检查非严格模式-匿名账户正常读操作-token为空-资源有策略", func(t *testing.T) { + ctx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, "") + authCtx := authcommon.NewAcquireContext( + authcommon.WithRequestContext(ctx), + authcommon.WithMethod("Test_DefaultAuthChecker_VerifyCredential"), + // authcommon.WithToken(""), + authcommon.WithOperation(authcommon.Read), + authcommon.WithModule(authcommon.DiscoverModule), + authcommon.WithAccessResources(map[apisecurity.ResourceType][]authcommon.ResourceEntry{ + apisecurity.ResourceType_Services: { + { + ID: services[0].ID, + Owner: services[0].Owner, + }, + }, + }), + ) + _, err = checker.CheckConsolePermission(authCtx) + t.Logf("%+v", err) + assert.NoError(t, err, "Should be verify success") + }) + + t.Run("权限检查非严格模式-匿名账户正常读操作-token为空-资源无策略", func(t *testing.T) { + ctx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, "") + authCtx := authcommon.NewAcquireContext( + authcommon.WithRequestContext(ctx), + authcommon.WithMethod("Test_DefaultAuthChecker_VerifyCredential"), + authcommon.WithOperation(authcommon.Read), + authcommon.WithModule(authcommon.DiscoverModule), + authcommon.WithAccessResources(map[apisecurity.ResourceType][]authcommon.ResourceEntry{ + apisecurity.ResourceType_Services: { + { + ID: services[freeIndex].ID, + Owner: services[freeIndex].Owner, + }, + }, + }), + ) + _, err = checker.CheckConsolePermission(authCtx) + t.Logf("%+v", err) + assert.NoError(t, err, "Should be verify success") + }) + + t.Run("权限检查非严格模式-匿名账户正常读操作-token非法-资源有策略", func(t *testing.T) { + ctx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, "Test_DefaultAuthChecker_VerifyCredential") + authCtx := authcommon.NewAcquireContext( + authcommon.WithRequestContext(ctx), + authcommon.WithMethod("Test_DefaultAuthChecker_VerifyCredential"), + // authcommon.WithToken("Test_DefaultAuthChecker_VerifyCredential"), + authcommon.WithOperation(authcommon.Read), + authcommon.WithModule(authcommon.DiscoverModule), + authcommon.WithAccessResources(map[apisecurity.ResourceType][]authcommon.ResourceEntry{ + apisecurity.ResourceType_Services: { + { + ID: services[0].ID, + Owner: services[0].Owner, + }, + }, + }), + ) + _, err = checker.CheckConsolePermission(authCtx) + t.Logf("%+v", err) + assert.NoError(t, err, "Should be verify success") + }) + + t.Run("权限检查非严格模式-匿名账户正常读操作-token非法-资源无策略", func(t *testing.T) { + ctx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, "Test_DefaultAuthChecker_VerifyCredential") + authCtx := authcommon.NewAcquireContext( + authcommon.WithRequestContext(ctx), + authcommon.WithMethod("Test_DefaultAuthChecker_VerifyCredential"), + // authcommon.WithToken("Test_DefaultAuthChecker_VerifyCredential"), + authcommon.WithOperation(authcommon.Read), + authcommon.WithModule(authcommon.DiscoverModule), + authcommon.WithAccessResources(map[apisecurity.ResourceType][]authcommon.ResourceEntry{ + apisecurity.ResourceType_Services: { + { + ID: services[freeIndex].ID, + Owner: services[freeIndex].Owner, + }, + }, + }), + ) + _, err = checker.CheckConsolePermission(authCtx) + t.Logf("%+v", err) + assert.NoError(t, err, "Should be verify success") + }) +} + +func Test_DefaultAuthChecker_CheckConsolePermission_Read_Strict(t *testing.T) { + reset(true) + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + users := createMockUser(10) + groups := createMockUserGroup(users) + + namespaces := createMockNamespace(len(users)+len(groups)+10, users[0].ID) + services := createMockService(namespaces) + serviceMap := convertServiceSliceToMap(services) + strategies, _ := createMockStrategy(users, groups, services[:len(users)+len(groups)]) + + cfg, storage := initCache(ctrl) + + storage.EXPECT().GetUsersForCache(gomock.Any(), gomock.Any()).AnyTimes().Return(users, nil) + storage.EXPECT().GetGroupsForCache(gomock.Any(), gomock.Any()).AnyTimes().Return(groups, nil) + storage.EXPECT().GetMoreStrategies(gomock.Any(), gomock.Any()).AnyTimes().Return(strategies, nil) + storage.EXPECT().GetMoreNamespaces(gomock.Any()).AnyTimes().Return(namespaces, nil) + storage.EXPECT().GetMoreServices(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes().Return(serviceMap, nil) + + ctx, cancel := context.WithCancel(context.Background()) + cacheMgr, err := cache.TestCacheInitialize(ctx, cfg, storage) + if err != nil { + t.Fatal(err) + } + + t.Cleanup(func() { + cancel() + cacheMgr.Close() + }) + + _, proxySvr, err := defaultuser.BuildServer() + if err != nil { + t.Fatal(err) + } + proxySvr.Initialize(&auth.Config{ + User: &auth.UserConfig{ + Name: auth.DefaultUserMgnPluginName, + Option: map[string]interface{}{ + "salt": "polarismesh@2021", + }, + }, + }, storage, nil, cacheMgr) + + _, svr, err := newPolicyServer() + if err != nil { + t.Fatal(err) + } + if err := svr.Initialize(&auth.Config{ + Strategy: &auth.StrategyConfig{ + Name: auth.DefaultPolicyPluginName, + }, + }, storage, cacheMgr, proxySvr); err != nil { + t.Fatal(err) + } + checker := svr.GetAuthChecker() + dchecker := checker.(*policy.DefaultAuthChecker) + oldConf := dchecker.GetConfig() + defer func() { + dchecker.SetConfig(oldConf) + }() + dchecker.SetConfig(&policy.AuthConfig{ + ConsoleOpen: true, + ConsoleStrict: true, + }) + + _ = cacheMgr.TestUpdate() + + freeIndex := len(users) + len(groups) + 1 + + t.Run("权限检查严格模式-主账户正常读操作", func(t *testing.T) { + ctx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, users[0].Token) + authCtx := authcommon.NewAcquireContext( + authcommon.WithRequestContext(ctx), + authcommon.WithMethod("Test_DefaultAuthChecker_VerifyCredential"), + // authcommon.WithToken(users[0].Token), + authcommon.WithOperation(authcommon.Read), + authcommon.WithModule(authcommon.DiscoverModule), + authcommon.WithAccessResources(map[apisecurity.ResourceType][]authcommon.ResourceEntry{ + apisecurity.ResourceType_Services: { + { + ID: services[0].ID, + Owner: services[0].Owner, + }, + }, + }), + ) + _, err = checker.CheckConsolePermission(authCtx) + t.Logf("%+v", err) + assert.NoError(t, err, "Should be verify success") + }) + + t.Run("权限检查严格模式-子账户正常读操作-资源有权限", func(t *testing.T) { + ctx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, users[1].Token) + authCtx := authcommon.NewAcquireContext( + authcommon.WithRequestContext(ctx), + authcommon.WithMethod("Test_DefaultAuthChecker_VerifyCredential"), + // authcommon.WithToken(users[1].Token), + authcommon.WithOperation(authcommon.Read), + authcommon.WithModule(authcommon.DiscoverModule), + authcommon.WithAccessResources(map[apisecurity.ResourceType][]authcommon.ResourceEntry{ + apisecurity.ResourceType_Services: { + { + ID: services[1].ID, + Owner: services[1].Owner, + }, + }, + }), + ) + _, err = checker.CheckConsolePermission(authCtx) + t.Logf("%+v", err) + assert.NoError(t, err, "Should be verify success") + }) + + t.Run("权限检查严格模式-子账户正常读操作-资源无权限", func(t *testing.T) { + ctx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, users[1].Token) + authCtx := authcommon.NewAcquireContext( + authcommon.WithRequestContext(ctx), + authcommon.WithMethod("Test_DefaultAuthChecker_VerifyCredential"), + // authcommon.WithToken(users[1].Token), + authcommon.WithOperation(authcommon.Read), + authcommon.WithModule(authcommon.DiscoverModule), + authcommon.WithAccessResources(map[apisecurity.ResourceType][]authcommon.ResourceEntry{ + apisecurity.ResourceType_Services: { + { + ID: services[0].ID, + Owner: services[0].Owner, + }, + }, + }), + ) + _, err = checker.CheckConsolePermission(authCtx) + t.Logf("%+v", err) + assert.NoError(t, err, "Should be verify success") + }) + + t.Run("权限检查严格模式-子账户正常读操作-资源无绑定策略", func(t *testing.T) { + ctx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, users[1].Token) + authCtx := authcommon.NewAcquireContext( + authcommon.WithRequestContext(ctx), + authcommon.WithMethod("Test_DefaultAuthChecker_VerifyCredential"), + // authcommon.WithToken(users[1].Token), + authcommon.WithOperation(authcommon.Read), + authcommon.WithModule(authcommon.DiscoverModule), + authcommon.WithAccessResources(map[apisecurity.ResourceType][]authcommon.ResourceEntry{ + apisecurity.ResourceType_Services: { + { + ID: services[freeIndex].ID, + Owner: services[freeIndex].Owner, + }, + }, + }), + ) + _, err = checker.CheckConsolePermission(authCtx) + t.Logf("%+v", err) + assert.NoError(t, err, "Should be verify success") + }) + + t.Run("权限检查严格模式-匿名账户正常读操作-token为空-资源有策略", func(t *testing.T) { + ctx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, "") + authCtx := authcommon.NewAcquireContext( + authcommon.WithRequestContext(ctx), + authcommon.WithMethod("Test_DefaultAuthChecker_VerifyCredential"), + // authcommon.WithToken(""), + authcommon.WithOperation(authcommon.Read), + authcommon.WithModule(authcommon.DiscoverModule), + authcommon.WithAccessResources(map[apisecurity.ResourceType][]authcommon.ResourceEntry{ + apisecurity.ResourceType_Services: { + { + ID: services[0].ID, + Owner: services[0].Owner, + }, + }, + }), + ) + _, err = checker.CheckConsolePermission(authCtx) + t.Logf("%+v", err) + assert.Error(t, err, "Should be verify fail") + }) + + t.Run("权限检查严格模式-匿名账户正常读操作-token为空-资源无策略", func(t *testing.T) { + ctx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, "") + authCtx := authcommon.NewAcquireContext( + authcommon.WithRequestContext(ctx), + authcommon.WithMethod("Test_DefaultAuthChecker_VerifyCredential"), + // authcommon.WithToken(""), + authcommon.WithOperation(authcommon.Read), + authcommon.WithModule(authcommon.DiscoverModule), + authcommon.WithAccessResources(map[apisecurity.ResourceType][]authcommon.ResourceEntry{ + apisecurity.ResourceType_Services: { + { + ID: services[freeIndex].ID, + Owner: services[freeIndex].Owner, + }, + }, + }), + ) + _, err = checker.CheckConsolePermission(authCtx) + t.Logf("%+v", err) + assert.Error(t, err, "Should be verify fail") + }) + + t.Run("权限检查严格模式-匿名账户正常读操作-token非法-资源有策略", func(t *testing.T) { + ctx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, "Test_DefaultAuthChecker_VerifyCredential") + authCtx := authcommon.NewAcquireContext( + authcommon.WithRequestContext(ctx), + authcommon.WithMethod("Test_DefaultAuthChecker_VerifyCredential"), + // authcommon.WithToken("Test_DefaultAuthChecker_VerifyCredential"), + authcommon.WithOperation(authcommon.Read), + authcommon.WithModule(authcommon.DiscoverModule), + authcommon.WithAccessResources(map[apisecurity.ResourceType][]authcommon.ResourceEntry{ + apisecurity.ResourceType_Services: { + { + ID: services[0].ID, + Owner: services[0].Owner, + }, + }, + }), + ) + _, err = checker.CheckConsolePermission(authCtx) + t.Logf("%+v", err) + assert.Error(t, err, "Should be verify fail") + }) + + t.Run("权限检查严格模式-匿名账户正常读操作-token非法-资源无策略", func(t *testing.T) { + ctx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, "Test_DefaultAuthChecker_VerifyCredential") + authCtx := authcommon.NewAcquireContext( + authcommon.WithRequestContext(ctx), + authcommon.WithMethod("Test_DefaultAuthChecker_VerifyCredential"), + // authcommon.WithToken("Test_DefaultAuthChecker_VerifyCredential"), + authcommon.WithOperation(authcommon.Read), + authcommon.WithModule(authcommon.DiscoverModule), + authcommon.WithAccessResources(map[apisecurity.ResourceType][]authcommon.ResourceEntry{ + apisecurity.ResourceType_Services: { + { + ID: services[freeIndex].ID, + Owner: services[freeIndex].Owner, + }, + }, + }), + ) + _, err = checker.CheckConsolePermission(authCtx) + t.Logf("%+v", err) + assert.Error(t, err, "Should be verify fail") + }) +} + +func Test_DefaultAuthChecker_Initialize(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + t.Run("使用未迁移至auth.user.option及auth.strategy.option的配置", func(t *testing.T) { + reset(true) + authChecker := &policy.Server{} + cfg := &auth.Config{} + cfg.SetDefault() + cfg.Name = "" + cfg.Option = map[string]interface{}{ + "consoleOpen": true, + "clientOpen": true, + "salt": "polarismesh@2021", + "strict": false, + } + err := authChecker.ParseOptions(cfg) + assert.NoError(t, err) + assert.Equal(t, &policy.AuthConfig{ + ConsoleOpen: true, + ClientOpen: true, + Strict: false, + ConsoleStrict: false, + ClientStrict: false, + }, authChecker.GetOptions()) + }) + + t.Run("使用完全迁移至auth.user.option及auth.strategy.option的配置", func(t *testing.T) { + reset(true) + authChecker := &policy.Server{} + + cfg := &auth.Config{} + cfg.SetDefault() + cfg.User = &auth.UserConfig{ + Name: "", + Option: map[string]interface{}{"salt": "polarismesh@2021"}, + } + cfg.Strategy = &auth.StrategyConfig{ + Name: "", + Option: map[string]interface{}{ + "consoleOpen": true, + "clientOpen": true, + "strict": false, + }, + } + + err := authChecker.ParseOptions(cfg) + assert.NoError(t, err) + assert.Equal(t, &policy.AuthConfig{ + ConsoleOpen: true, + ConsoleStrict: false, + ClientOpen: true, + Strict: false, + }, authChecker.GetOptions()) + }) + + t.Run("使用部分迁移至auth.user.option及auth.strategy.option的配置(应当报错)", func(t *testing.T) { + reset(true) + authChecker := &policy.Server{} + cfg := &auth.Config{} + cfg.SetDefault() + cfg.Name = "" + cfg.Option = map[string]interface{}{ + "clientOpen": true, + "strict": false, + } + cfg.User = &auth.UserConfig{ + Name: "", + Option: map[string]interface{}{"salt": "polarismesh@2021"}, + } + cfg.Strategy = &auth.StrategyConfig{ + Name: "", + Option: map[string]interface{}{ + "consoleOpen": true, + }, + } + + err := authChecker.ParseOptions(cfg) + assert.NoError(t, err) + }) + +} + +func TestDefaultAuthChecker_isCredible(t *testing.T) { + type fields struct { + conf *policy.AuthConfig + cacheMgr cachetypes.CacheManager + userSvr auth.UserServer + } + type args struct { + authCtx *authcommon.AcquireContext + } + tests := []struct { + name string + fields fields + args args + want bool + }{ + {}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + d := &policy.DefaultAuthChecker{} + d.SetConfig(tt.fields.conf) + if got := d.IsCredible(tt.args.authCtx); got != tt.want { + t.Errorf("DefaultAuthChecker.isCredible() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/auth/policy/common_test.go b/auth/policy/common_test.go new file mode 100644 index 000000000..72b0e02ef --- /dev/null +++ b/auth/policy/common_test.go @@ -0,0 +1,403 @@ +/** + * Tencent is pleased to support the open source community by making Polaris available. + * + * Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. + * + * Licensed under the BSD 3-Clause License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://opensource.org/licenses/BSD-3-Clause + * + * Unless required by applicable law or agreed to in writing, software distributed + * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR + * CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package policy_test + +import ( + "fmt" + "time" + + "github.com/golang/mock/gomock" + "github.com/google/uuid" + apisecurity "github.com/polarismesh/specification/source/go/api/v1/security" + "github.com/polarismesh/specification/source/go/api/v1/service_manage" + "golang.org/x/crypto/bcrypt" + "google.golang.org/protobuf/types/known/wrapperspb" + + defaultuser "github.com/polarismesh/polaris/auth/user" + "github.com/polarismesh/polaris/cache" + "github.com/polarismesh/polaris/common/metrics" + "github.com/polarismesh/polaris/common/model" + authcommon "github.com/polarismesh/polaris/common/model/auth" + "github.com/polarismesh/polaris/common/utils" + storemock "github.com/polarismesh/polaris/store/mock" +) + +func reset(strict bool) { + +} + +func initCache(ctrl *gomock.Controller) (*cache.Config, *storemock.MockStore) { + metrics.InitMetrics() + /* + - name: service # 加载服务数据 + option: + disableBusiness: false # 不加载业务服务 + needMeta: true # 加载服务元数据 + - name: instance # 加载实例数据 + option: + disableBusiness: false # 不加载业务服务实例 + needMeta: true # 加载实例元数据 + - name: routingConfig # 加载路由数据 + - name: rateLimitConfig # 加载限流数据 + - name: circuitBreakerConfig # 加载熔断数据 + - name: l5 # 加载l5数据 + - name: users + - name: strategyRule + - name: namespace + */ + cfg := &cache.Config{} + storage := storemock.NewMockStore(ctrl) + + mockTx := storemock.NewMockTx(ctrl) + mockTx.EXPECT().Commit().Return(nil).AnyTimes() + mockTx.EXPECT().Rollback().Return(nil).AnyTimes() + mockTx.EXPECT().CreateReadView().Return(nil).AnyTimes() + + storage.EXPECT().StartReadTx().Return(mockTx, nil).AnyTimes() + storage.EXPECT().GetServicesCount().AnyTimes().Return(uint32(1), nil) + storage.EXPECT().GetInstancesCountTx(gomock.Any()).AnyTimes().Return(uint32(1), nil) + storage.EXPECT().GetMoreInstances(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(map[string]*model.Instance{ + "123": { + Proto: &service_manage.Instance{ + Id: wrapperspb.String(uuid.NewString()), + Host: wrapperspb.String("127.0.0.1"), + Port: wrapperspb.UInt32(8080), + }, + Valid: true, + }, + }, nil).AnyTimes() + storage.EXPECT().GetUnixSecond(gomock.Any()).AnyTimes().Return(time.Now().Unix(), nil) + + return cfg, storage +} + +func createMockNamespace(total int, owner string) []*model.Namespace { + namespaces := make([]*model.Namespace, 0, total) + + for i := 0; i < total; i++ { + namespaces = append(namespaces, &model.Namespace{ + Name: fmt.Sprintf("namespace_%d", i), + Owner: owner, + Valid: true, + }) + } + + return namespaces +} + +func createMockService(namespaces []*model.Namespace) []*model.Service { + services := make([]*model.Service, 0, len(namespaces)) + + for i := 0; i < len(namespaces); i++ { + ns := namespaces[i] + services = append(services, &model.Service{ + ID: utils.NewUUID(), + Namespace: ns.Name, + Owner: ns.Owner, + Name: fmt.Sprintf("service_%d", i), + Valid: true, + }) + } + + return services +} + +func createMockStrategy(users []*authcommon.User, groups []*authcommon.UserGroupDetail, services []*model.Service) ([]*authcommon.StrategyDetail, []*authcommon.StrategyDetail) { + strategies := make([]*authcommon.StrategyDetail, 0, len(users)+len(groups)) + defaultStrategies := make([]*authcommon.StrategyDetail, 0, len(users)+len(groups)) + + owner := "" + for i := 0; i < len(users); i++ { + user := users[i] + if user.Owner == "" { + owner = user.ID + break + } + } + + for i := 0; i < len(users); i++ { + user := users[i] + service := services[i] + id := utils.NewUUID() + strategies = append(strategies, &authcommon.StrategyDetail{ + ID: id, + Name: fmt.Sprintf("strategy_user_%s_%d", user.Name, i), + Action: apisecurity.AuthAction_READ_WRITE.String(), + Comment: "", + Principals: []authcommon.Principal{ + { + PrincipalID: user.ID, + PrincipalType: authcommon.PrincipalUser, + }, + }, + Default: false, + Owner: owner, + Resources: []authcommon.StrategyResource{ + { + StrategyID: id, + ResType: int32(apisecurity.ResourceType_Namespaces), + ResID: service.Namespace, + }, + { + StrategyID: id, + ResType: int32(apisecurity.ResourceType_Services), + ResID: service.ID, + }, + }, + Valid: true, + Revision: utils.NewUUID(), + CreateTime: time.Time{}, + ModifyTime: time.Time{}, + }) + + defaultStrategies = append(defaultStrategies, &authcommon.StrategyDetail{ + ID: id, + Name: fmt.Sprintf("strategy_default_user_%s_%d", user.Name, i), + Action: apisecurity.AuthAction_READ_WRITE.String(), + Comment: "", + Principals: []authcommon.Principal{ + { + PrincipalID: user.ID, + PrincipalType: authcommon.PrincipalUser, + }, + }, + Default: true, + Owner: owner, + Resources: []authcommon.StrategyResource{ + { + StrategyID: id, + ResType: int32(apisecurity.ResourceType_Namespaces), + ResID: service.Namespace, + }, + { + StrategyID: id, + ResType: int32(apisecurity.ResourceType_Services), + ResID: service.ID, + }, + }, + Valid: true, + Revision: utils.NewUUID(), + CreateTime: time.Time{}, + ModifyTime: time.Time{}, + }) + } + + for i := 0; i < len(groups); i++ { + group := groups[i] + service := services[len(users)+i] + id := utils.NewUUID() + strategies = append(strategies, &authcommon.StrategyDetail{ + ID: id, + Name: fmt.Sprintf("strategy_group_%s_%d", group.Name, i), + Action: apisecurity.AuthAction_READ_WRITE.String(), + Comment: "", + Principals: []authcommon.Principal{ + { + PrincipalID: group.ID, + PrincipalType: authcommon.PrincipalGroup, + }, + }, + Default: false, + Owner: owner, + Resources: []authcommon.StrategyResource{ + { + StrategyID: id, + ResType: int32(apisecurity.ResourceType_Namespaces), + ResID: service.Namespace, + }, + { + StrategyID: id, + ResType: int32(apisecurity.ResourceType_Services), + ResID: service.ID, + }, + }, + Valid: true, + Revision: utils.NewUUID(), + CreateTime: time.Time{}, + ModifyTime: time.Time{}, + }) + + defaultStrategies = append(defaultStrategies, &authcommon.StrategyDetail{ + ID: id, + Name: fmt.Sprintf("strategy_default_group_%s_%d", group.Name, i), + Action: apisecurity.AuthAction_READ_WRITE.String(), + Comment: "", + Principals: []authcommon.Principal{ + { + PrincipalID: group.ID, + PrincipalType: authcommon.PrincipalGroup, + }, + }, + Default: true, + Owner: owner, + Resources: []authcommon.StrategyResource{ + { + StrategyID: id, + ResType: int32(apisecurity.ResourceType_Namespaces), + ResID: service.Namespace, + }, + { + StrategyID: id, + ResType: int32(apisecurity.ResourceType_Services), + ResID: service.ID, + }, + }, + Valid: true, + Revision: utils.NewUUID(), + CreateTime: time.Time{}, + ModifyTime: time.Time{}, + }) + } + + return defaultStrategies, strategies +} + +func convertServiceSliceToMap(services []*model.Service) map[string]*model.Service { + ret := make(map[string]*model.Service) + + for i := range services { + service := services[i] + ret[service.ID] = service + } + + return ret +} + +// createMockUser 默认 users[0] 为 owner 用户 +func createMockUser(total int, prefix ...string) []*authcommon.User { + users := make([]*authcommon.User, 0, total) + + ownerId := utils.NewUUID() + + nameTemp := "user-%d" + if len(prefix) != 0 { + nameTemp = prefix[0] + nameTemp + } + + for i := 0; i < total; i++ { + id := fmt.Sprintf("fake-user-id-%d-%s", i, utils.NewUUID()) + if i == 0 { + id = ownerId + } + pwd, _ := bcrypt.GenerateFromPassword([]byte("polaris"), bcrypt.DefaultCost) + token, _ := defaultuser.CreateToken(id, "", "polarismesh@2021") + users = append(users, &authcommon.User{ + ID: id, + Name: fmt.Sprintf(nameTemp, i), + Password: string(pwd), + Owner: func() string { + if id == ownerId { + return "" + } + return ownerId + }(), + Source: "Polaris", + Mobile: "", + Email: "", + Type: func() authcommon.UserRoleType { + if id == ownerId { + return authcommon.OwnerUserRole + } + return authcommon.SubAccountUserRole + }(), + Token: token, + TokenEnable: true, + Valid: true, + CreateTime: time.Time{}, + ModifyTime: time.Time{}, + }) + } + return users +} + +func createApiMockUser(total int, prefix ...string) []*apisecurity.User { + users := make([]*apisecurity.User, 0, total) + + models := createMockUser(total, prefix...) + + for i := range models { + users = append(users, &apisecurity.User{ + Name: utils.NewStringValue("test-" + models[i].Name), + Password: utils.NewStringValue("123456"), + Source: utils.NewStringValue("Polaris"), + Comment: utils.NewStringValue(models[i].Comment), + Mobile: utils.NewStringValue(models[i].Mobile), + Email: utils.NewStringValue(models[i].Email), + }) + } + + return users +} + +func createMockUserGroup(users []*authcommon.User) []*authcommon.UserGroupDetail { + groups := make([]*authcommon.UserGroupDetail, 0, len(users)) + + for i := range users { + user := users[i] + id := utils.NewUUID() + + token, _ := defaultuser.CreateToken("", id, "polarismesh@2021") + groups = append(groups, &authcommon.UserGroupDetail{ + UserGroup: &authcommon.UserGroup{ + ID: id, + Name: fmt.Sprintf("test-group-%d", i), + Owner: users[0].ID, + Token: token, + TokenEnable: true, + Valid: true, + Comment: "", + CreateTime: time.Time{}, + ModifyTime: time.Time{}, + }, + UserIds: map[string]struct{}{ + user.ID: {}, + }, + }) + } + + return groups +} + +// createMockApiUserGroup +func createMockApiUserGroup(users []*apisecurity.User) []*apisecurity.UserGroup { + musers := make([]*authcommon.User, 0, len(users)) + for i := range users { + musers = append(musers, &authcommon.User{ + ID: users[i].GetId().GetValue(), + }) + } + + models := createMockUserGroup(musers) + ret := make([]*apisecurity.UserGroup, 0, len(models)) + + for i := range models { + ret = append(ret, &apisecurity.UserGroup{ + Name: utils.NewStringValue(models[i].Name), + Comment: utils.NewStringValue(models[i].Comment), + Relation: &apisecurity.UserGroupRelation{ + Users: []*apisecurity.User{ + { + Id: utils.NewStringValue(users[i].GetId().GetValue()), + }, + }, + }, + }) + } + + return ret +} diff --git a/auth/policy/default.go b/auth/policy/default.go index d9b372902..ccff9599b 100644 --- a/auth/policy/default.go +++ b/auth/policy/default.go @@ -29,7 +29,7 @@ import ( type ServerProxyFactory func(svr *Server, pre auth.StrategyServer) (auth.StrategyServer, error) var ( - // serverProxyFactories auth.StrategyServer API 代理工厂 + // serverProxyFactories auth.UserServer API 代理工厂 serverProxyFactories = map[string]ServerProxyFactory{} ) diff --git a/auth/policy/helper.go b/auth/policy/helper.go index 41eb1676d..e77f14de1 100644 --- a/auth/policy/helper.go +++ b/auth/policy/helper.go @@ -1,32 +1,14 @@ -/** - * Tencent is pleased to support the open source community by making Polaris available. - * - * Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. - * - * Licensed under the BSD 3-Clause License (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://opensource.org/licenses/BSD-3-Clause - * - * Unless required by applicable law or agreed to in writing, software distributed - * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR - * CONDITIONS OF ANY KIND, either express or implied. See the License for the - * specific language governing permissions and limitations under the License. - */ - package policy import ( "context" - apisecurity "github.com/polarismesh/specification/source/go/api/v1/security" - "github.com/polarismesh/polaris/auth" cachetypes "github.com/polarismesh/polaris/cache/api" authcommon "github.com/polarismesh/polaris/common/model/auth" "github.com/polarismesh/polaris/common/utils" "github.com/polarismesh/polaris/store" + apisecurity "github.com/polarismesh/specification/source/go/api/v1/security" ) type DefaultPolicyHelper struct { @@ -36,16 +18,12 @@ type DefaultPolicyHelper struct { checker auth.AuthChecker } -func (h *DefaultPolicyHelper) GetRole(id string) *authcommon.Role { - return h.cacheMgr.Role().GetRole(id) -} - -func (h *DefaultPolicyHelper) GetPolicyRule(id string) *authcommon.StrategyDetail { - return h.cacheMgr.AuthStrategy().GetPolicyRule(id) -} - // CreatePrincipal 创建 principal 的默认 policy 资源 func (h *DefaultPolicyHelper) CreatePrincipal(ctx context.Context, tx store.Tx, p authcommon.Principal) error { + if !h.options.OpenPrincipalDefaultPolicy { + return nil + } + if err := h.storage.AddStrategy(tx, defaultPrincipalPolicy(p)); err != nil { return err } @@ -54,50 +32,25 @@ func (h *DefaultPolicyHelper) CreatePrincipal(ctx context.Context, tx store.Tx, func defaultPrincipalPolicy(p authcommon.Principal) *authcommon.StrategyDetail { // Create the user's default weight policy - ruleId := utils.NewUUID() - - resources := []authcommon.StrategyResource{} - if p.PrincipalType == authcommon.PrincipalUser { - resources = append(resources, authcommon.StrategyResource{ - StrategyID: ruleId, - ResType: int32(apisecurity.ResourceType_Users), - ResID: p.PrincipalID, - }) - } - return &authcommon.StrategyDetail{ - ID: ruleId, - Name: authcommon.BuildDefaultStrategyName(authcommon.PrincipalUser, p.Name), - Action: apisecurity.AuthAction_ALLOW.String(), - Default: true, - Owner: p.Owner, - Revision: utils.NewUUID(), - Source: "Polaris", - Resources: resources, - Principals: []authcommon.Principal{p}, - CalleeMethods: []string{ - // 用户操作权限 - string(authcommon.DescribeUsers), - string(authcommon.DescribeUserToken), - string(authcommon.UpdateUser), - string(authcommon.UpdateUserPassword), - string(authcommon.EnableUserToken), - string(authcommon.ResetUserToken), - // 鉴权策略 - string(authcommon.DescribeAuthPolicies), - string(authcommon.DescribeAuthPolicyDetail), - // 角色 - string(authcommon.DescribeAuthRoles), - }, - Valid: true, - Comment: "default principal auth policy rule", + ID: utils.NewUUID(), + Name: authcommon.BuildDefaultStrategyName(authcommon.PrincipalUser, p.Name), + Action: apisecurity.AuthAction_READ_WRITE.String(), + Default: true, + Owner: p.Owner, + Revision: utils.NewUUID(), + Resources: []authcommon.StrategyResource{}, + Valid: true, + Comment: "Default Strategy", } } // CleanPrincipal 清理 principal 所关联的 policy、role 资源 func (h *DefaultPolicyHelper) CleanPrincipal(ctx context.Context, tx store.Tx, p authcommon.Principal) error { - if err := h.storage.CleanPrincipalPolicies(tx, p); err != nil { - return err + if h.options.OpenPrincipalDefaultPolicy { + if err := h.storage.CleanPrincipalPolicies(tx, p); err != nil { + return err + } } if err := h.storage.CleanPrincipalRoles(tx, &p); err != nil { diff --git a/auth/policy/inteceptor/auth/server.go b/auth/policy/inteceptor/auth/server.go index c80c3acf9..128046b18 100644 --- a/auth/policy/inteceptor/auth/server.go +++ b/auth/policy/inteceptor/auth/server.go @@ -27,9 +27,7 @@ import ( "github.com/polarismesh/polaris/auth" cachetypes "github.com/polarismesh/polaris/cache/api" api "github.com/polarismesh/polaris/common/api/v1" - "github.com/polarismesh/polaris/common/model" authcommon "github.com/polarismesh/polaris/common/model/auth" - "github.com/polarismesh/polaris/common/utils" "github.com/polarismesh/polaris/store" ) @@ -80,21 +78,17 @@ func (svr *Server) CreateStrategy(ctx context.Context, strategy *apisecurity.Aut resp := api.NewResponseWithMsg(authcommon.ConvertToErrCode(err), err.Error()) return resp } - return svr.nextSvr.CreateStrategy(authCtx.GetRequestContext(), strategy) + return svr.nextSvr.CreateStrategy(ctx, strategy) } // UpdateStrategies 批量更新策略 func (svr *Server) UpdateStrategies(ctx context.Context, reqs []*apisecurity.ModifyAuthStrategy) *apiservice.BatchWriteResponse { resources := make([]authcommon.ResourceEntry, 0, len(reqs)) for i := range reqs { - entry := authcommon.ResourceEntry{ - Type: apisecurity.ResourceType_PolicyRules, - ID: reqs[i].GetId().GetValue(), - } - if saveRule := svr.nextSvr.PolicyHelper().GetPolicyRule(reqs[i].GetId().GetValue()); saveRule != nil { - entry.Metadata = saveRule.Metadata - } - resources = append(resources, entry) + item := reqs[i] + resources = append(resources, authcommon.ResourceEntry{ + ID: item.GetId().GetValue(), + }) } authCtx := authcommon.NewAcquireContext( @@ -102,30 +96,23 @@ func (svr *Server) UpdateStrategies(ctx context.Context, reqs []*apisecurity.Mod authcommon.WithOperation(authcommon.Modify), authcommon.WithModule(authcommon.AuthModule), authcommon.WithMethod(authcommon.UpdateAuthPolicies), - authcommon.WithAccessResources(map[apisecurity.ResourceType][]authcommon.ResourceEntry{ - apisecurity.ResourceType_PolicyRules: resources, - }), ) if _, err := svr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { resp := api.NewBatchWriteResponseWithMsg(authcommon.ConvertToErrCode(err), err.Error()) return resp } - return svr.nextSvr.UpdateStrategies(authCtx.GetRequestContext(), reqs) + return svr.nextSvr.UpdateStrategies(ctx, reqs) } // DeleteStrategies 删除策略 func (svr *Server) DeleteStrategies(ctx context.Context, reqs []*apisecurity.AuthStrategy) *apiservice.BatchWriteResponse { resources := make([]authcommon.ResourceEntry, 0, len(reqs)) for i := range reqs { - entry := authcommon.ResourceEntry{ - Type: apisecurity.ResourceType_PolicyRules, - ID: reqs[i].GetId().GetValue(), - } - if saveRule := svr.nextSvr.PolicyHelper().GetPolicyRule(reqs[i].GetId().GetValue()); saveRule != nil { - entry.Metadata = saveRule.Metadata - } - resources = append(resources, entry) + item := reqs[i] + resources = append(resources, authcommon.ResourceEntry{ + ID: item.GetId().GetValue(), + }) } authCtx := authcommon.NewAcquireContext( @@ -133,16 +120,13 @@ func (svr *Server) DeleteStrategies(ctx context.Context, reqs []*apisecurity.Aut authcommon.WithOperation(authcommon.Delete), authcommon.WithModule(authcommon.AuthModule), authcommon.WithMethod(authcommon.DeleteAuthPolicies), - authcommon.WithAccessResources(map[apisecurity.ResourceType][]authcommon.ResourceEntry{ - apisecurity.ResourceType_PolicyRules: resources, - }), ) if _, err := svr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { resp := api.NewBatchWriteResponseWithMsg(authcommon.ConvertToErrCode(err), err.Error()) return resp } - return svr.nextSvr.DeleteStrategies(authCtx.GetRequestContext(), reqs) + return svr.nextSvr.DeleteStrategies(ctx, reqs) } // GetStrategies 获取资源列表 @@ -156,62 +140,44 @@ func (svr *Server) GetStrategies(ctx context.Context, query map[string]string) * authcommon.WithMethod(authcommon.DescribeAuthPolicies), ) - checker := svr.GetAuthChecker() - if _, err := checker.CheckConsolePermission(authCtx); err != nil { + if err := svr.userSvr.CheckCredential(authCtx); err != nil { return api.NewAuthBatchQueryResponseWithMsg(authcommon.ConvertToErrCode(err), err.Error()) } - ctx = authCtx.GetRequestContext() - ctx = cachetypes.AppendAuthPolicyPredicate(ctx, func(ctx context.Context, sd *authcommon.StrategyDetail) bool { - ok := checker.ResourcePredicate(authCtx, &authcommon.ResourceEntry{ - Type: apisecurity.ResourceType_PolicyRules, - ID: sd.ID, - Metadata: sd.Metadata, + checker := svr.GetAuthChecker() + cachetypes.AppendAuthPolicyPredicate(ctx, func(ctx context.Context, sd *authcommon.StrategyDetail) bool { + return checker.ResourcePredicate(authCtx, &authcommon.ResourceEntry{ + Type: apisecurity.ResourceType_PolicyRules, + ID: sd.ID, }) - if ok { - return true - } - // 兼容老版本的策略查询逻辑 - if compatible, _ := ctx.Value(model.ContextKeyCompatible{}).(bool); compatible { - for i := range sd.Principals { - if sd.Principals[i].PrincipalID == utils.ParseUserID(ctx) { - return true - } - } - } - return false }) - authCtx.SetRequestContext(ctx) return svr.nextSvr.GetStrategies(ctx, query) } // GetStrategy 获取策略详细 func (svr *Server) GetStrategy(ctx context.Context, strategy *apisecurity.AuthStrategy) *apiservice.Response { - entry := authcommon.ResourceEntry{ - Type: apisecurity.ResourceType_PolicyRules, - ID: strategy.GetId().GetValue(), - } - saveRule := svr.nextSvr.PolicyHelper().GetPolicyRule(strategy.GetId().GetValue()) - if saveRule != nil { - entry.Metadata = saveRule.Metadata - } - authCtx := authcommon.NewAcquireContext( authcommon.WithRequestContext(ctx), authcommon.WithOperation(authcommon.Read), authcommon.WithModule(authcommon.AuthModule), authcommon.WithMethod(authcommon.DescribeAuthPolicyDetail), - authcommon.WithAccessResources(map[apisecurity.ResourceType][]authcommon.ResourceEntry{ - apisecurity.ResourceType_PolicyRules: {entry}, - }), ) checker := svr.GetAuthChecker() + if _, err := checker.CheckConsolePermission(authCtx); err != nil { return api.NewResponseWithMsg(authcommon.ConvertToErrCode(err), err.Error()) } - return svr.nextSvr.GetStrategy(authCtx.GetRequestContext(), strategy) + + cachetypes.AppendAuthPolicyPredicate(ctx, func(ctx context.Context, sd *authcommon.StrategyDetail) bool { + return checker.ResourcePredicate(authCtx, &authcommon.ResourceEntry{ + Type: apisecurity.ResourceType_PolicyRules, + ID: sd.ID, + }) + }) + + return svr.nextSvr.GetStrategy(ctx, strategy) } // GetPrincipalResources 获取某个 principal 的所有可操作资源列表 @@ -228,7 +194,7 @@ func (svr *Server) GetPrincipalResources(ctx context.Context, query map[string]s if _, err := checker.CheckConsolePermission(authCtx); err != nil { return api.NewResponseWithMsg(authcommon.ConvertToErrCode(err), err.Error()) } - return svr.nextSvr.GetPrincipalResources(authCtx.GetRequestContext(), query) + return svr.nextSvr.GetPrincipalResources(ctx, query) } // GetAuthChecker 获取鉴权检查器 @@ -243,103 +209,20 @@ func (svr *Server) AfterResourceOperation(afterCtx *authcommon.AcquireContext) e // CreateRoles 批量创建角色 func (svr *Server) CreateRoles(ctx context.Context, reqs []*apisecurity.Role) *apiservice.BatchWriteResponse { - authCtx := authcommon.NewAcquireContext( - authcommon.WithRequestContext(ctx), - authcommon.WithOperation(authcommon.Create), - authcommon.WithModule(authcommon.AuthModule), - authcommon.WithMethod(authcommon.CreateAuthRoles), - ) - - if _, err := svr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - resp := api.NewBatchWriteResponseWithMsg(authcommon.ConvertToErrCode(err), err.Error()) - return resp - } - return svr.nextSvr.CreateRoles(authCtx.GetRequestContext(), reqs) + return nil } // UpdateRoles 批量更新角色 func (svr *Server) UpdateRoles(ctx context.Context, reqs []*apisecurity.Role) *apiservice.BatchWriteResponse { - resources := make([]authcommon.ResourceEntry, 0, len(reqs)) - for i := range reqs { - entry := authcommon.ResourceEntry{ - Type: apisecurity.ResourceType_Roles, - ID: reqs[i].GetId(), - } - if saveRule := svr.nextSvr.PolicyHelper().GetRole(reqs[i].GetId()); saveRule != nil { - entry.Metadata = saveRule.Metadata - } - resources = append(resources, entry) - } - - authCtx := authcommon.NewAcquireContext( - authcommon.WithRequestContext(ctx), - authcommon.WithOperation(authcommon.Modify), - authcommon.WithModule(authcommon.AuthModule), - authcommon.WithMethod(authcommon.UpdateAuthRoles), - authcommon.WithAccessResources(map[apisecurity.ResourceType][]authcommon.ResourceEntry{ - apisecurity.ResourceType_Roles: resources, - }), - ) - - if _, err := svr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - resp := api.NewBatchWriteResponseWithMsg(authcommon.ConvertToErrCode(err), err.Error()) - return resp - } - return svr.nextSvr.UpdateRoles(authCtx.GetRequestContext(), reqs) + return nil } // DeleteRoles 批量删除角色 func (svr *Server) DeleteRoles(ctx context.Context, reqs []*apisecurity.Role) *apiservice.BatchWriteResponse { - resources := make([]authcommon.ResourceEntry, 0, len(reqs)) - for i := range reqs { - entry := authcommon.ResourceEntry{ - Type: apisecurity.ResourceType_Roles, - ID: reqs[i].GetId(), - } - if saveRule := svr.nextSvr.PolicyHelper().GetRole(reqs[i].GetId()); saveRule != nil { - entry.Metadata = saveRule.Metadata - } - resources = append(resources, entry) - } - - authCtx := authcommon.NewAcquireContext( - authcommon.WithRequestContext(ctx), - authcommon.WithOperation(authcommon.Modify), - authcommon.WithModule(authcommon.AuthModule), - authcommon.WithMethod(authcommon.DeleteAuthRoles), - authcommon.WithAccessResources(map[apisecurity.ResourceType][]authcommon.ResourceEntry{ - apisecurity.ResourceType_Roles: resources, - }), - ) - - if _, err := svr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - resp := api.NewBatchWriteResponseWithMsg(authcommon.ConvertToErrCode(err), err.Error()) - return resp - } - return svr.nextSvr.DeleteRoles(authCtx.GetRequestContext(), reqs) + return nil } // GetRoles 查询角色列表 func (svr *Server) GetRoles(ctx context.Context, query map[string]string) *apiservice.BatchQueryResponse { - authCtx := authcommon.NewAcquireContext( - authcommon.WithRequestContext(ctx), - authcommon.WithOperation(authcommon.Read), - authcommon.WithModule(authcommon.AuthModule), - authcommon.WithMethod(authcommon.DescribeAuthRoles), - ) - - if _, err := svr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewAuthBatchQueryResponseWithMsg(authcommon.ConvertToErrCode(err), err.Error()) - } - - checker := svr.GetAuthChecker() - ctx = cachetypes.AppendAuthRolePredicate(ctx, func(ctx context.Context, sd *authcommon.Role) bool { - return checker.ResourcePredicate(authCtx, &authcommon.ResourceEntry{ - Type: apisecurity.ResourceType_Roles, - ID: sd.ID, - Metadata: sd.Metadata, - }) - }) - - return svr.nextSvr.GetRoles(ctx, query) + return nil } diff --git a/auth/policy/inteceptor/paramcheck/server.go b/auth/policy/inteceptor/paramcheck/server.go index 3e707201c..3ae25a934 100644 --- a/auth/policy/inteceptor/paramcheck/server.go +++ b/auth/policy/inteceptor/paramcheck/server.go @@ -31,7 +31,6 @@ import ( api "github.com/polarismesh/polaris/common/api/v1" "github.com/polarismesh/polaris/common/log" authcommon "github.com/polarismesh/polaris/common/model/auth" - commonstore "github.com/polarismesh/polaris/common/store" "github.com/polarismesh/polaris/common/utils" "github.com/polarismesh/polaris/store" ) @@ -60,10 +59,8 @@ func NewServer(nextSvr auth.StrategyServer) auth.StrategyServer { } type Server struct { - storage store.Store - cacheMgr cachetypes.CacheManager - nextSvr auth.StrategyServer - userSvr auth.UserServer + nextSvr auth.StrategyServer + userSvr auth.UserServer } // PolicyHelper implements auth.StrategyServer. @@ -74,8 +71,6 @@ func (svr *Server) PolicyHelper() auth.PolicyHelper { // Initialize 执行初始化动作 func (svr *Server) Initialize(options *auth.Config, storage store.Store, cacheMgr cachetypes.CacheManager, userSvr auth.UserServer) error { svr.userSvr = userSvr - svr.cacheMgr = cacheMgr - svr.storage = storage return svr.nextSvr.Initialize(options, storage, cacheMgr, userSvr) } @@ -85,30 +80,12 @@ func (svr *Server) Name() string { } // CreateStrategy 创建策略 -func (svr *Server) CreateStrategy(ctx context.Context, req *apisecurity.AuthStrategy) *apiservice.Response { - if err := svr.checkCreateStrategy(req); err != nil { - return err - } - return svr.nextSvr.CreateStrategy(ctx, req) +func (svr *Server) CreateStrategy(ctx context.Context, strategy *apisecurity.AuthStrategy) *apiservice.Response { + return svr.nextSvr.CreateStrategy(ctx, strategy) } // UpdateStrategies 批量更新策略 func (svr *Server) UpdateStrategies(ctx context.Context, reqs []*apisecurity.ModifyAuthStrategy) *apiservice.BatchWriteResponse { - batchResp := api.NewBatchWriteResponse(apimodel.Code_ExecuteSuccess) - for i := range reqs { - var rsp *apiservice.Response - strategy, err := svr.storage.GetStrategyDetail(reqs[i].GetId().GetValue()) - if err != nil { - log.Error("[Auth][Strategy] get strategy from store", utils.RequestID(ctx), zap.Error(err)) - rsp = api.NewModifyAuthStrategyResponse(commonstore.StoreCode2APICode(err), reqs[i]) - } - if strategy == nil { - continue - } else { - rsp = svr.checkUpdateStrategy(ctx, reqs[i], strategy) - } - api.Collect(batchResp, rsp) - } return svr.nextSvr.UpdateStrategies(ctx, reqs) } @@ -181,140 +158,3 @@ func (svr *Server) DeleteRoles(ctx context.Context, reqs []*apisecurity.Role) *a func (svr *Server) GetRoles(ctx context.Context, query map[string]string) *apiservice.BatchQueryResponse { return svr.nextSvr.GetRoles(ctx, query) } - -// checkCreateStrategy 检查创建鉴权策略的请求 -func (svr *Server) checkCreateStrategy(req *apisecurity.AuthStrategy) *apiservice.Response { - // 检查名称信息 - if err := CheckName(req.GetName()); err != nil { - return api.NewAuthStrategyResponse(apimodel.Code_InvalidUserName, req) - } - // 检查用户是否存在 - if err := svr.checkUserExist(convertPrincipalsToUsers(req.GetPrincipals())); err != nil { - return api.NewAuthStrategyResponse(apimodel.Code_NotFoundUser, req) - } - // 检查用户组是否存在 - if err := svr.checkGroupExist(convertPrincipalsToGroups(req.GetPrincipals())); err != nil { - return api.NewAuthStrategyResponse(apimodel.Code_NotFoundUserGroup, req) - } - // 检查资源是否存在 - if errResp := svr.checkResourceExist(req.GetResources()); errResp != nil { - return errResp - } - return nil -} - -// checkUpdateStrategy 检查更新鉴权策略的请求 -// Case 1. 修改的是默认鉴权策略的话,只能修改资源,不能添加用户 or 用户组 -// Case 2. 鉴权策略只能被自己的 owner 对应的用户修改 -func (svr *Server) checkUpdateStrategy(ctx context.Context, req *apisecurity.ModifyAuthStrategy, - saved *authcommon.StrategyDetail) *apiservice.Response { - if saved.Default { - if len(req.AddPrincipals.Users) != 0 || - len(req.AddPrincipals.Groups) != 0 || - len(req.RemovePrincipals.Groups) != 0 || - len(req.RemovePrincipals.Users) != 0 { - return api.NewModifyAuthStrategyResponse(apimodel.Code_NotAllowModifyDefaultStrategyPrincipal, req) - } - - // 主账户的默认策略禁止编辑 - if len(saved.Principals) == 1 && saved.Principals[0].PrincipalType == authcommon.PrincipalUser { - if saved.Principals[0].PrincipalID == utils.ParseOwnerID(ctx) { - return api.NewAuthResponse(apimodel.Code_NotAllowModifyOwnerDefaultStrategy) - } - } - } - - // 检查用户是否存在 - if err := svr.checkUserExist(convertPrincipalsToUsers(req.GetAddPrincipals())); err != nil { - return api.NewModifyAuthStrategyResponse(apimodel.Code_NotFoundUser, req) - } - - // 检查用户组是否存 - if err := svr.checkGroupExist(convertPrincipalsToGroups(req.GetAddPrincipals())); err != nil { - return api.NewModifyAuthStrategyResponse(apimodel.Code_NotFoundUserGroup, req) - } - - // 检查资源是否存在 - if errResp := svr.checkResourceExist(req.GetAddResources()); errResp != nil { - return errResp - } - return nil -} - -// checkUserExist 检查用户是否存在 -func (svr *Server) checkUserExist(users []*apisecurity.User) error { - if len(users) == 0 { - return nil - } - return svr.userSvr.GetUserHelper().CheckUsersExist(context.TODO(), users) -} - -// checkUserGroupExist 检查用户组是否存在 -func (svr *Server) checkGroupExist(groups []*apisecurity.UserGroup) error { - if len(groups) == 0 { - return nil - } - return svr.userSvr.GetUserHelper().CheckGroupsExist(context.TODO(), groups) -} - -// checkResourceExist 检查资源是否存在 -func (svr *Server) checkResourceExist(resources *apisecurity.StrategyResources) *apiservice.Response { - namespaces := resources.GetNamespaces() - - nsCache := svr.cacheMgr.Namespace() - for index := range namespaces { - val := namespaces[index] - if val.GetId().GetValue() == "*" { - break - } - if ns := nsCache.GetNamespace(val.GetId().GetValue()); ns == nil { - return api.NewAuthResponse(apimodel.Code_NotFoundNamespace) - } - } - - services := resources.GetServices() - svcCache := svr.cacheMgr.Service() - for index := range services { - val := services[index] - if val.GetId().GetValue() == "*" { - break - } - if svc := svcCache.GetServiceByID(val.GetId().GetValue()); svc == nil { - return api.NewAuthResponse(apimodel.Code_NotFoundService) - } - } - - return nil -} - -func convertPrincipalsToUsers(principals *apisecurity.Principals) []*apisecurity.User { - if principals == nil { - return make([]*apisecurity.User, 0) - } - - users := make([]*apisecurity.User, 0, len(principals.Users)) - for k := range principals.GetUsers() { - user := principals.GetUsers()[k] - users = append(users, &apisecurity.User{ - Id: user.Id, - }) - } - - return users -} - -func convertPrincipalsToGroups(principals *apisecurity.Principals) []*apisecurity.UserGroup { - if principals == nil { - return make([]*apisecurity.UserGroup, 0) - } - - groups := make([]*apisecurity.UserGroup, 0, len(principals.Groups)) - for k := range principals.GetGroups() { - group := principals.GetGroups()[k] - groups = append(groups, &apisecurity.UserGroup{ - Id: group.Id, - }) - } - - return groups -} diff --git a/auth/policy/inteceptor/paramcheck/utils.go b/auth/policy/inteceptor/paramcheck/utils.go deleted file mode 100644 index c055016a2..000000000 --- a/auth/policy/inteceptor/paramcheck/utils.go +++ /dev/null @@ -1,58 +0,0 @@ -/** - * Tencent is pleased to support the open source community by making Polaris available. - * - * Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. - * - * Licensed under the BSD 3-Clause License (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://opensource.org/licenses/BSD-3-Clause - * - * Unless required by applicable law or agreed to in writing, software distributed - * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR - * CONDITIONS OF ANY KIND, either express or implied. See the License for the - * specific language governing permissions and limitations under the License. - */ - -package paramcheck - -import ( - "errors" - "regexp" - "unicode/utf8" - - "github.com/golang/protobuf/ptypes/wrappers" - - "github.com/polarismesh/polaris/common/utils" -) - -var ( - regNameStr = regexp.MustCompile("^[\u4E00-\u9FA5A-Za-z0-9_\\-.]+$") - regEmail = regexp.MustCompile(`^\w+([-+.]\w+)*@\w+([-.]\w+)*\.\w+([-.]\w+)*$`) -) - -// CheckName 名称检查 -func CheckName(name *wrappers.StringValue) error { - if name == nil { - return errors.New(utils.NilErrString) - } - - if name.GetValue() == "" { - return errors.New(utils.EmptyErrString) - } - - if name.GetValue() == "polariadmin" { - return errors.New("illegal username") - } - - if utf8.RuneCountInString(name.GetValue()) > utils.MaxNameLength { - return errors.New("name too long") - } - - if ok := regNameStr.MatchString(name.GetValue()); !ok { - return errors.New("name contains invalid character") - } - - return nil -} diff --git a/auth/policy/main_test.go b/auth/policy/main_test.go new file mode 100644 index 000000000..4d5302ba7 --- /dev/null +++ b/auth/policy/main_test.go @@ -0,0 +1,215 @@ +/** + * Tencent is pleased to support the open source community by making Polaris available. + * + * Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. + * + * Licensed under the BSD 3-Clause License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://opensource.org/licenses/BSD-3-Clause + * + * Unless required by applicable law or agreed to in writing, software distributed + * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR + * CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package policy_test + +import ( + "errors" + + _ "github.com/go-sql-driver/mysql" + bolt "go.etcd.io/bbolt" + + "github.com/polarismesh/polaris/auth" + "github.com/polarismesh/polaris/cache" + _ "github.com/polarismesh/polaris/cache" + api "github.com/polarismesh/polaris/common/api/v1" + commonlog "github.com/polarismesh/polaris/common/log" + "github.com/polarismesh/polaris/namespace" + "github.com/polarismesh/polaris/plugin" + _ "github.com/polarismesh/polaris/plugin/cmdb/memory" + _ "github.com/polarismesh/polaris/plugin/discoverevent/local" + _ "github.com/polarismesh/polaris/plugin/healthchecker/memory" + _ "github.com/polarismesh/polaris/plugin/healthchecker/redis" + _ "github.com/polarismesh/polaris/plugin/history/logger" + _ "github.com/polarismesh/polaris/plugin/password" + _ "github.com/polarismesh/polaris/plugin/ratelimit/token" + _ "github.com/polarismesh/polaris/plugin/statis/logger" + _ "github.com/polarismesh/polaris/plugin/statis/prometheus" + "github.com/polarismesh/polaris/service/healthcheck" + "github.com/polarismesh/polaris/store" + "github.com/polarismesh/polaris/store/boltdb" + _ "github.com/polarismesh/polaris/store/boltdb" + _ "github.com/polarismesh/polaris/store/mysql" + sqldb "github.com/polarismesh/polaris/store/mysql" + testsuit "github.com/polarismesh/polaris/test/suit" +) + +const ( + tblUser string = "user" + tblStrategy string = "strategy" + tblGroup string = "group" +) + +type Bootstrap struct { + Logger map[string]*commonlog.Options +} + +type TestConfig struct { + Bootstrap Bootstrap `yaml:"bootstrap"` + Cache cache.Config `yaml:"cache"` + Namespace namespace.Config `yaml:"namespace"` + HealthChecks healthcheck.Config `yaml:"healthcheck"` + Store store.Config `yaml:"store"` + Auth auth.Config `yaml:"auth"` + Plugin plugin.Config `yaml:"plugin"` +} + +type AuthTestSuit struct { + testsuit.DiscoverTestSuit +} + +// 判断一个resp是否执行成功 +func respSuccess(resp api.ResponseMessage) bool { + + ret := api.CalcCode(resp) == 200 + + return ret +} + +type options func(cfg *TestConfig) + +func (d *AuthTestSuit) cleanAllUser() { + if d.Storage.Name() == sqldb.STORENAME { + func() { + tx, err := d.Storage.StartTx() + if err != nil { + panic(err) + } + + dbTx := tx.GetDelegateTx().(*sqldb.BaseTx) + + defer dbTx.Rollback() + + if _, err := dbTx.Exec("delete from user where name like 'test%'"); err != nil { + dbTx.Rollback() + panic(err) + } + + dbTx.Commit() + }() + } else if d.Storage.Name() == boltdb.STORENAME { + func() { + tx, err := d.Storage.StartTx() + if err != nil { + panic(err) + } + + dbTx := tx.GetDelegateTx().(*bolt.Tx) + defer dbTx.Rollback() + + if err := dbTx.DeleteBucket([]byte(tblUser)); err != nil { + if !errors.Is(err, bolt.ErrBucketNotFound) { + panic(err) + } + } + + dbTx.Commit() + }() + } +} + +func (d *AuthTestSuit) cleanAllUserGroup() { + if d.Storage.Name() == sqldb.STORENAME { + func() { + tx, err := d.Storage.StartTx() + if err != nil { + panic(err) + } + + dbTx := tx.GetDelegateTx().(*sqldb.BaseTx) + + defer dbTx.Rollback() + + if _, err := dbTx.Exec("delete from user_group where name like 'test%'"); err != nil { + dbTx.Rollback() + panic(err) + } + if _, err := dbTx.Exec("delete from user_group_relation"); err != nil { + dbTx.Rollback() + panic(err) + } + + dbTx.Commit() + }() + } else if d.Storage.Name() == boltdb.STORENAME { + func() { + tx, err := d.Storage.StartTx() + if err != nil { + panic(err) + } + + dbTx := tx.GetDelegateTx().(*bolt.Tx) + defer dbTx.Rollback() + + if err := dbTx.DeleteBucket([]byte(tblGroup)); err != nil { + if !errors.Is(err, bolt.ErrBucketNotFound) { + panic(err) + } + } + + dbTx.Commit() + }() + } +} + +func (d *AuthTestSuit) cleanAllAuthStrategy() { + if d.Storage.Name() == sqldb.STORENAME { + func() { + tx, err := d.Storage.StartTx() + if err != nil { + panic(err) + } + + dbTx := tx.GetDelegateTx().(*sqldb.BaseTx) + + defer dbTx.Rollback() + + if _, err := dbTx.Exec("delete from auth_strategy where id != 'fbca9bfa04ae4ead86e1ecf5811e32a9'"); err != nil { + dbTx.Rollback() + panic(err) + } + if _, err := dbTx.Exec("delete from auth_principal where strategy_id != 'fbca9bfa04ae4ead86e1ecf5811e32a9'"); err != nil { + dbTx.Rollback() + panic(err) + } + if _, err := dbTx.Exec("delete from auth_strategy_resource where strategy_id != 'fbca9bfa04ae4ead86e1ecf5811e32a9'"); err != nil { + dbTx.Rollback() + panic(err) + } + + dbTx.Commit() + }() + } else if d.Storage.Name() == boltdb.STORENAME { + func() { + tx, err := d.Storage.StartTx() + if err != nil { + panic(err) + } + + dbTx := tx.GetDelegateTx().(*bolt.Tx) + defer dbTx.Rollback() + + if err := dbTx.DeleteBucket([]byte(tblStrategy)); err != nil { + if !errors.Is(err, bolt.ErrBucketNotFound) { + panic(err) + } + } + + dbTx.Commit() + }() + } +} diff --git a/auth/policy/role.go b/auth/policy/role.go index 213167334..2417883ba 100644 --- a/auth/policy/role.go +++ b/auth/policy/role.go @@ -1,201 +1,28 @@ -/** - * Tencent is pleased to support the open source community by making Polaris available. - * - * Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. - * - * Licensed under the BSD 3-Clause License (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://opensource.org/licenses/BSD-3-Clause - * - * Unless required by applicable law or agreed to in writing, software distributed - * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR - * CONDITIONS OF ANY KIND, either express or implied. See the License for the - * specific language governing permissions and limitations under the License. - */ - package policy import ( "context" - "fmt" - "time" - "github.com/gogo/protobuf/jsonpb" - apimodel "github.com/polarismesh/specification/source/go/api/v1/model" apisecurity "github.com/polarismesh/specification/source/go/api/v1/security" apiservice "github.com/polarismesh/specification/source/go/api/v1/service_manage" - "go.uber.org/zap" - - cachetypes "github.com/polarismesh/polaris/cache/api" - api "github.com/polarismesh/polaris/common/api/v1" - "github.com/polarismesh/polaris/common/model" - authcommon "github.com/polarismesh/polaris/common/model/auth" - commonstore "github.com/polarismesh/polaris/common/store" - "github.com/polarismesh/polaris/common/utils" ) // CreateRoles 批量创建角色 func (svr *Server) CreateRoles(ctx context.Context, reqs []*apisecurity.Role) *apiservice.BatchWriteResponse { - responses := api.NewBatchWriteResponse(apimodel.Code_ExecuteSuccess) - for i := range reqs { - rsp := svr.CreateRole(ctx, reqs[i]) - api.Collect(responses, rsp) - } - return api.FormatBatchWriteResponse(responses) -} - -// CreateRole 创建角色 -func (svr *Server) CreateRole(ctx context.Context, req *apisecurity.Role) *apiservice.Response { - req.Owner = utils.ParseOwnerID(ctx) - - saveData := &authcommon.Role{} - saveData.FromSpec(req) - - if err := svr.storage.AddRole(saveData); err != nil { - log.Error("[Auth][Role] create role into store", utils.RequestID(ctx), - zap.Error(err)) - return api.NewAuthResponse(commonstore.StoreCode2APICode(err)) - } - - return api.NewResponse(apimodel.Code_ExecuteSuccess) + return nil } // UpdateRoles 批量更新角色 func (svr *Server) UpdateRoles(ctx context.Context, reqs []*apisecurity.Role) *apiservice.BatchWriteResponse { - responses := api.NewBatchWriteResponse(apimodel.Code_ExecuteSuccess) - for i := range reqs { - rsp := svr.UpdateRole(ctx, reqs[i]) - api.Collect(responses, rsp) - } - return api.FormatBatchWriteResponse(responses) -} - -// UpdateRole 批量更新角色 -func (svr *Server) UpdateRole(ctx context.Context, req *apisecurity.Role) *apiservice.Response { - newData := &authcommon.Role{} - newData.FromSpec(req) - - saveData, err := svr.storage.GetRole(newData.ID) - if err != nil { - log.Error("[Auth][Role] get one role from store", utils.RequestID(ctx), - zap.Error(err)) - return api.NewAuthResponse(commonstore.StoreCode2APICode(err)) - } - if saveData == nil { - log.Error("[Auth][Role] not find expect role", utils.RequestID(ctx), - zap.String("id", newData.ID)) - return api.NewAuthResponse(apimodel.Code_NotFoundResource) - } - - newData.Name = saveData.Name - newData.Owner = saveData.Owner - - if err := svr.storage.AddRole(newData); err != nil { - log.Error("[Auth][Role] update role into store", utils.RequestID(ctx), - zap.Error(err)) - return api.NewAuthResponse(commonstore.StoreCode2APICode(err)) - } - - return api.NewResponse(apimodel.Code_ExecuteSuccess) + return nil } // DeleteRoles 批量删除角色 func (svr *Server) DeleteRoles(ctx context.Context, reqs []*apisecurity.Role) *apiservice.BatchWriteResponse { - responses := api.NewBatchWriteResponse(apimodel.Code_ExecuteSuccess) - for i := range reqs { - rsp := svr.DeleteRole(ctx, reqs[i]) - api.Collect(responses, rsp) - } - return api.FormatBatchWriteResponse(responses) -} - -// DeleteRole 批量删除角色 -func (svr *Server) DeleteRole(ctx context.Context, req *apisecurity.Role) *apiservice.Response { - newData := &authcommon.Role{} - newData.FromSpec(req) - - saveData, err := svr.storage.GetRole(newData.ID) - if err != nil { - log.Error("[Auth][Role] get one role from store", utils.RequestID(ctx), - zap.Error(err)) - return api.NewAuthResponse(commonstore.StoreCode2APICode(err)) - } - if saveData == nil { - return api.NewAuthResponse(apimodel.Code_ExecuteSuccess) - } - - tx, err := svr.storage.StartTx() - if err != nil { - log.Error("[Auth][Role] start tx", utils.RequestID(ctx), zap.Error(err)) - return api.NewAuthResponse(commonstore.StoreCode2APICode(err)) - } - defer func() { - _ = tx.Rollback() - }() - - if err := svr.storage.DeleteRole(tx, newData); err != nil { - log.Error("[Auth][Role] update role into store", utils.RequestID(ctx), - zap.Error(err)) - return api.NewAuthResponse(commonstore.StoreCode2APICode(err)) - } - if err := svr.storage.CleanPrincipalPolicies(tx, authcommon.Principal{ - PrincipalID: saveData.ID, - PrincipalType: authcommon.PrincipalRole, - }); err != nil { - log.Error("[Auth][Role] clean role link policies", utils.RequestID(ctx), - zap.Error(err)) - return api.NewAuthResponse(commonstore.StoreCode2APICode(err)) - } - - if err := tx.Commit(); err != nil { - log.Error("[Auth][Role] delete role commit tx", utils.RequestID(ctx), zap.Error(err)) - return api.NewAuthResponse(commonstore.StoreCode2APICode(err)) - } - - return api.NewResponse(apimodel.Code_ExecuteSuccess) + return nil } // GetRoles 查询角色列表 -func (svr *Server) GetRoles(ctx context.Context, filters map[string]string) *apiservice.BatchQueryResponse { - offset, limit, _ := utils.ParseOffsetAndLimit(filters) - - total, ret, err := svr.cacheMgr.Role().Query(ctx, cachetypes.RoleSearchArgs{ - Filters: filters, - Offset: offset, - Limit: limit, - }) - if err != nil { - log.Error("[Auth][Role] query roles list", utils.RequestID(ctx), zap.Error(err)) - return api.NewBatchQueryResponse(commonstore.StoreCode2APICode(err)) - } - - rsp := api.NewBatchQueryResponse(apimodel.Code_ExecuteSuccess) - rsp.Amount = utils.NewUInt32Value(total) - rsp.Size = utils.NewUInt32Value(uint32(len(ret))) - - for i := range ret { - if err := api.AddAnyDataIntoBatchQuery(rsp, ret[i].ToSpec()); err != nil { - log.Error("[Auth][Role] add role to query list", utils.RequestID(ctx), zap.Error(err)) - return api.NewBatchQueryResponse(apimodel.Code_ExecuteException) - } - } - return rsp -} - -func recordRoleEntry(ctx context.Context, req *apisecurity.Role, data *authcommon.Role, op model.OperationType) *model.RecordEntry { - marshaler := jsonpb.Marshaler{} - detail, _ := marshaler.MarshalToString(req) - - entry := &model.RecordEntry{ - ResourceType: model.RAuthRole, - ResourceName: fmt.Sprintf("%s(%s)", data.Name, data.ID), - OperationType: op, - Operator: utils.ParseOperator(ctx), - Detail: detail, - HappenTime: time.Now(), - } - - return entry +func (svr *Server) GetRoles(ctx context.Context, query map[string]string) *apiservice.BatchQueryResponse { + return nil } diff --git a/auth/policy/server.go b/auth/policy/server.go index fffcae44d..5603f51c0 100644 --- a/auth/policy/server.go +++ b/auth/policy/server.go @@ -39,8 +39,6 @@ import ( // AuthConfig 鉴权配置 type AuthConfig struct { - // Compatible 兼容模式 - Compatible bool `json:"compatible" xml:"compatible"` // ConsoleOpen 控制台是否开启鉴权 ConsoleOpen bool `json:"consoleOpen" xml:"consoleOpen"` // ClientOpen 是否开启客户端接口鉴权 @@ -54,13 +52,13 @@ type AuthConfig struct { ClientStrict bool `json:"clientStrict"` // CredibleHeaders 可信请求 Header CredibleHeaders map[string]string + // OpenPrincipalDefaultPolicy 是否开启 principal 默认策略 + OpenPrincipalDefaultPolicy bool `json:"openPrincipalDefaultPolicy"` } // DefaultAuthConfig 返回一个默认的鉴权配置 func DefaultAuthConfig() *AuthConfig { return &AuthConfig{ - // 针对旧鉴权逻辑的兼容模式 - Compatible: true, // 针对控制台接口,默认开启鉴权操作 ConsoleOpen: true, // 这里默认开启 OpenAPI 的强 Token 检查模式 @@ -106,9 +104,7 @@ func (svr *Server) Initialize(options *auth.Config, storage store.Store, cacheMg checker := &DefaultAuthChecker{ policyMgr: svr, } - if err := checker.Initialize(svr.options, svr.storage, cacheMgr, userSvr); err != nil { - return err - } + checker.Initialize(svr.options, svr.storage, cacheMgr, userSvr) svr.checker = checker return nil } @@ -221,7 +217,8 @@ func (svr *Server) AfterResourceOperation(afterCtx *authcommon.AcquireContext) e log.Info("[Auth][Server] add resource to principal default strategy", zap.Any("resource", afterCtx.GetAttachments()[authcommon.ResourceAttachmentKey]), - zap.Any("add_user", addUserIds), zap.Any("add_group", addGroupIds), zap.Any("remove_user", removeUserIds), + zap.Any("add_user", addUserIds), + zap.Any("add_group", addGroupIds), zap.Any("remove_user", removeUserIds), zap.Any("remove_group", removeGroupIds), ) @@ -244,6 +241,7 @@ func (svr *Server) AfterResourceOperation(afterCtx *authcommon.AcquireContext) e log.Error("[Auth][Server] remove group link resource", zap.Error(err)) return err } + return nil } diff --git a/auth/policy/server_test.go b/auth/policy/server_test.go new file mode 100644 index 000000000..325113678 --- /dev/null +++ b/auth/policy/server_test.go @@ -0,0 +1,302 @@ +/** + * Tencent is pleased to support the open source community by making Polaris available. + * + * Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. + * + * Licensed under the BSD 3-Clause License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://opensource.org/licenses/BSD-3-Clause + * + * Unless required by applicable law or agreed to in writing, software distributed + * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR + * CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package policy_test + +import ( + "errors" + "testing" + + "github.com/golang/mock/gomock" + "github.com/polarismesh/specification/source/go/api/v1/security" + "github.com/stretchr/testify/assert" + "google.golang.org/protobuf/types/known/wrapperspb" + + "github.com/polarismesh/polaris/auth" + authmock "github.com/polarismesh/polaris/auth/mock" + "github.com/polarismesh/polaris/auth/policy" + authcommon "github.com/polarismesh/polaris/common/model/auth" + "github.com/polarismesh/polaris/common/utils" + storemock "github.com/polarismesh/polaris/store/mock" +) + +func Test_AfterResourceOperation(t *testing.T) { + svr := &policy.Server{} + + t.Run("not_need_auth", func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + mockChecker := authmock.NewMockAuthChecker(ctrl) + svr.MockAuthChecker(mockChecker) + mockChecker.EXPECT().IsOpenClientAuth().Return(false).AnyTimes() + mockChecker.EXPECT().IsOpenConsoleAuth().Return(false).AnyTimes() + + err := svr.AfterResourceOperation(authcommon.NewAcquireContext()) + assert.NoError(t, err) + }) + + t.Run("read_op", func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + mockChecker := authmock.NewMockAuthChecker(ctrl) + svr.MockAuthChecker(mockChecker) + mockChecker.EXPECT().IsOpenClientAuth().Return(true).AnyTimes() + mockChecker.EXPECT().IsOpenConsoleAuth().Return(false).AnyTimes() + + err := svr.AfterResourceOperation(authcommon.NewAcquireContext( + authcommon.WithOperation(authcommon.Read), + )) + assert.NoError(t, err) + }) + + t.Run("from_client_not_auth", func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + mockChecker := authmock.NewMockAuthChecker(ctrl) + svr.MockAuthChecker(mockChecker) + mockChecker.EXPECT().IsOpenClientAuth().Return(false).AnyTimes() + mockChecker.EXPECT().IsOpenConsoleAuth().Return(true).AnyTimes() + + err := svr.AfterResourceOperation(authcommon.NewAcquireContext( + authcommon.WithOperation(authcommon.Create), + authcommon.WithFromClient(), + )) + assert.NoError(t, err) + }) + + t.Run("from_console_not_auth", func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + mockChecker := authmock.NewMockAuthChecker(ctrl) + svr.MockAuthChecker(mockChecker) + mockChecker.EXPECT().IsOpenClientAuth().Return(true).AnyTimes() + mockChecker.EXPECT().IsOpenConsoleAuth().Return(false).AnyTimes() + + err := svr.AfterResourceOperation(authcommon.NewAcquireContext( + authcommon.WithOperation(authcommon.Create), + authcommon.WithFromConsole(), + )) + assert.NoError(t, err) + }) + + t.Run("not_token_detial", func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + mockChecker := authmock.NewMockAuthChecker(ctrl) + svr.MockAuthChecker(mockChecker) + mockChecker.EXPECT().IsOpenClientAuth().Return(true).AnyTimes() + mockChecker.EXPECT().IsOpenConsoleAuth().Return(false).AnyTimes() + + err := svr.AfterResourceOperation(authcommon.NewAcquireContext( + authcommon.WithOperation(authcommon.Create), + authcommon.WithFromClient(), + )) + assert.NoError(t, err) + }) + + t.Run("invalid_token_detial", func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + mockChecker := authmock.NewMockAuthChecker(ctrl) + svr.MockAuthChecker(mockChecker) + mockChecker.EXPECT().IsOpenClientAuth().Return(true).AnyTimes() + mockChecker.EXPECT().IsOpenConsoleAuth().Return(false).AnyTimes() + + ctx := authcommon.NewAcquireContext( + authcommon.WithOperation(authcommon.Create), + authcommon.WithFromClient(), + ) + ctx.SetAttachment(authcommon.TokenDetailInfoKey, map[string]string{}) + err := svr.AfterResourceOperation(ctx) + assert.NoError(t, err) + }) + + t.Run("empty_token_detial", func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + mockChecker := authmock.NewMockAuthChecker(ctrl) + svr.MockAuthChecker(mockChecker) + mockChecker.EXPECT().IsOpenClientAuth().Return(true).AnyTimes() + mockChecker.EXPECT().IsOpenConsoleAuth().Return(false).AnyTimes() + + ctx := authcommon.NewAcquireContext( + authcommon.WithOperation(authcommon.Create), + authcommon.WithFromClient(), + ) + t.Run("origin_empty", func(t *testing.T) { + ctx.SetAttachment(authcommon.TokenDetailInfoKey, auth.OperatorInfo{ + Origin: "", + }) + err := svr.AfterResourceOperation(ctx) + assert.NoError(t, err) + }) + + t.Run("is_anonymous", func(t *testing.T) { + ctx.SetAttachment(authcommon.TokenDetailInfoKey, auth.OperatorInfo{ + Origin: "123", + Anonymous: true, + }) + err := svr.AfterResourceOperation(ctx) + assert.NoError(t, err) + }) + }) + + t.Run("change_principal_policy", func(t *testing.T) { + ctrl := gomock.NewController(t) + mockChecker := authmock.NewMockAuthChecker(ctrl) + svr.MockAuthChecker(mockChecker) + mockChecker.EXPECT().IsOpenClientAuth().Return(true).AnyTimes() + mockChecker.EXPECT().IsOpenConsoleAuth().Return(true).AnyTimes() + + ctx := authcommon.NewAcquireContext( + authcommon.WithOperation(authcommon.Create), + authcommon.WithFromClient(), + ) + + ownerId := "mock_auth_owner" + curUserId := "123" + + t.Run("user", func(t *testing.T) { + ctx.SetAttachment(authcommon.TokenDetailInfoKey, auth.OperatorInfo{ + Origin: curUserId, + OperatorID: curUserId, + OwnerID: ownerId, + Role: authcommon.OwnerUserRole, + IsUserToken: true, + }) + + initMockAcquireContext(ctx) + + t.Run("not_found_user", func(t *testing.T) { + userSvr := authmock.NewMockUserServer(ctrl) + mockHelper := authmock.NewMockUserHelper(ctrl) + + userSvr.EXPECT().GetUserHelper().Return(mockHelper) + mockHelper.EXPECT().GetUser(gomock.Any(), gomock.Any()).Return(nil) + + svr.MockUserServer(userSvr) + + err := svr.AfterResourceOperation(ctx) + assert.Error(t, err) + assert.Equal(t, "not found target user", err.Error()) + }) + + t.Run("found_user", func(t *testing.T) { + userSvr := authmock.NewMockUserServer(ctrl) + mockHelper := authmock.NewMockUserHelper(ctrl) + + userSvr.EXPECT().GetUserHelper().Return(mockHelper).AnyTimes() + mockHelper.EXPECT().GetUser(gomock.Any(), gomock.Any()).Return(&security.User{ + Id: wrapperspb.String(curUserId), + Owner: wrapperspb.String(ownerId), + }).AnyTimes() + + svr.MockUserServer(userSvr) + + t.Run("store_has_err", func(t *testing.T) { + sctrl := gomock.NewController(t) + defer sctrl.Finish() + mockStore := storemock.NewMockStore(sctrl) + + mockStore.EXPECT().GetDefaultStrategyDetailByPrincipal(gomock.Any(), gomock.Any()).Return(nil, errors.New("mock_err")) + svr.MockStore(mockStore) + + err := svr.AfterResourceOperation(ctx) + assert.Error(t, err) + assert.Equal(t, "mock_err", err.Error()) + }) + + t.Run("not_found_default_policy", func(t *testing.T) { + sctrl := gomock.NewController(t) + defer sctrl.Finish() + mockStore := storemock.NewMockStore(sctrl) + + mockStore.EXPECT().GetDefaultStrategyDetailByPrincipal(gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes() + svr.MockStore(mockStore) + + err := svr.AfterResourceOperation(ctx) + assert.Error(t, err) + assert.Equal(t, "not found default strategy rule", err.Error()) + }) + + t.Run("not_op_resource", func(t *testing.T) { + sctrl := gomock.NewController(t) + defer sctrl.Finish() + mockStore := storemock.NewMockStore(sctrl) + + mockStore.EXPECT().GetDefaultStrategyDetailByPrincipal(gomock.Any(), gomock.Any()).Return(&authcommon.StrategyDetail{}, nil) + svr.MockStore(mockStore) + + err := svr.AfterResourceOperation(ctx) + assert.NoError(t, err) + }) + + t.Run("invalid_op_resource", func(t *testing.T) { + sctrl := gomock.NewController(t) + defer sctrl.Finish() + mockStore := storemock.NewMockStore(sctrl) + + mockStore.EXPECT().GetDefaultStrategyDetailByPrincipal(gomock.Any(), gomock.Any()).Return(&authcommon.StrategyDetail{}, nil) + svr.MockStore(mockStore) + + ctx.SetAttachment(authcommon.ResourceAttachmentKey, map[string]interface{}{}) + + err := svr.AfterResourceOperation(ctx) + assert.NoError(t, err) + }) + + t.Run("delete_resource", func(t *testing.T) { + delCtx := authcommon.NewAcquireContext( + authcommon.WithOperation(authcommon.Delete), + authcommon.WithFromClient(), + ) + delCtx.SetAttachment(authcommon.TokenDetailInfoKey, auth.OperatorInfo{ + Origin: curUserId, + OperatorID: curUserId, + OwnerID: ownerId, + Role: authcommon.OwnerUserRole, + IsUserToken: true, + }) + + initMockAcquireContext(delCtx) + + sctrl := gomock.NewController(t) + defer sctrl.Finish() + mockStore := storemock.NewMockStore(sctrl) + mockStore.EXPECT().RemoveStrategyResources(gomock.Any()).DoAndReturn(func(args interface{}) error { + resources := args.([]authcommon.StrategyResource) + for i := range resources { + assert.True(t, resources[i].StrategyID == "", utils.MustJson(resources[i])) + } + return nil + }).Times(1) + + svr.MockStore(mockStore) + err := svr.AfterResourceOperation(delCtx) + assert.NoError(t, err) + }) + }) + }) + }) +} + +func initMockAcquireContext(ctx *authcommon.AcquireContext) { + ctx.SetAttachment(authcommon.LinkUsersKey, []string{}) + ctx.SetAttachment(authcommon.LinkGroupsKey, []string{}) + ctx.SetAttachment(authcommon.RemoveLinkUsersKey, []string{}) + ctx.SetAttachment(authcommon.RemoveLinkGroupsKey, []string{}) +} diff --git a/auth/policy/strategy.go b/auth/policy/strategy.go index 022d39d2c..7e48ed12b 100644 --- a/auth/policy/strategy.go +++ b/auth/policy/strategy.go @@ -20,7 +20,6 @@ package policy import ( "context" "fmt" - "reflect" "strconv" "strings" "time" @@ -44,12 +43,16 @@ import ( type ( // StrategyDetail2Api strategy detail to *apisecurity.AuthStrategy func - StrategyDetail2Api func(ctx context.Context, user *authcommon.StrategyDetail) *apisecurity.AuthStrategy + StrategyDetail2Api func(user *authcommon.StrategyDetail) *apisecurity.AuthStrategy ) // CreateStrategy 创建鉴权策略 func (svr *Server) CreateStrategy(ctx context.Context, req *apisecurity.AuthStrategy) *apiservice.Response { req.Owner = utils.NewStringValue(utils.ParseOwnerID(ctx)) + if checkErrResp := svr.checkCreateStrategy(req); checkErrResp != nil { + return checkErrResp + } + req.Resources = svr.normalizeResource(req.Resources) data := svr.createAuthStrategyModel(req) @@ -97,9 +100,11 @@ func (svr *Server) UpdateStrategies( // Case 2. 鉴权策略只能被自己的 owner 对应的用户修改 // Case 3. 主账户的默认策略不得修改 func (svr *Server) UpdateStrategy(ctx context.Context, req *apisecurity.ModifyAuthStrategy) *apiservice.Response { + requestID := utils.ParseRequestID(ctx) + strategy, err := svr.storage.GetStrategyDetail(req.GetId().GetValue()) if err != nil { - log.Error("[Auth][Strategy] get strategy from store", utils.RequestID(ctx), + log.Error("[Auth][Strategy] get strategy from store", utils.ZapRequestID(requestID), zap.Error(err)) return api.NewModifyAuthStrategyResponse(commonstore.StoreCode2APICode(err), req) } @@ -107,6 +112,10 @@ func (svr *Server) UpdateStrategy(ctx context.Context, req *apisecurity.ModifyAu return api.NewModifyAuthStrategyResponse(apimodel.Code_NotFoundAuthStrategyRule, req) } + if checkErrResp := svr.checkUpdateStrategy(ctx, req, strategy); checkErrResp != nil { + return checkErrResp + } + req.AddResources = svr.normalizeResource(req.AddResources) data, needUpdate := svr.updateAuthStrategyAttribute(ctx, req, strategy) if !needUpdate { @@ -115,11 +124,11 @@ func (svr *Server) UpdateStrategy(ctx context.Context, req *apisecurity.ModifyAu if err := svr.storage.UpdateStrategy(data); err != nil { log.Error("[Auth][Strategy] update strategy into store", - utils.RequestID(ctx), zap.Error(err)) + utils.ZapRequestID(requestID), zap.Error(err)) return api.NewAuthResponseWithMsg(commonstore.StoreCode2APICode(err), err.Error()) } - log.Info("[Auth][Strategy] update strategy into store", utils.RequestID(ctx), + log.Info("[Auth][Strategy] update strategy into store", utils.ZapRequestID(requestID), zap.String("name", strategy.Name)) svr.RecordHistory(authModifyStrategyRecordEntry(ctx, req, data, model.OUpdate)) @@ -142,9 +151,11 @@ func (svr *Server) DeleteStrategies( // Case 1. 只有该策略的 owner 账户可以删除策略 // Case 2. 默认策略不能被删除,默认策略只能随着账户的删除而被清理 func (svr *Server) DeleteStrategy(ctx context.Context, req *apisecurity.AuthStrategy) *apiservice.Response { + requestID := utils.ParseRequestID(ctx) + strategy, err := svr.storage.GetStrategyDetail(req.GetId().GetValue()) if err != nil { - log.Error("[Auth][Strategy] get strategy from store", utils.RequestID(ctx), + log.Error("[Auth][Strategy] get strategy from store", utils.ZapRequestID(requestID), zap.Error(err)) return api.NewAuthStrategyResponse(commonstore.StoreCode2APICode(err), req) } @@ -154,7 +165,7 @@ func (svr *Server) DeleteStrategy(ctx context.Context, req *apisecurity.AuthStra } if strategy.Default { - log.Error("[Auth][Strategy] delete default strategy is denied", utils.RequestID(ctx)) + log.Error("[Auth][Strategy] delete default strategy is denied", utils.ZapRequestID(requestID)) return api.NewAuthStrategyResponseWithMsg(apimodel.Code_BadRequest, "default strategy can't delete", req) } @@ -164,11 +175,11 @@ func (svr *Server) DeleteStrategy(ctx context.Context, req *apisecurity.AuthStra if err := svr.storage.DeleteStrategy(req.GetId().GetValue()); err != nil { log.Error("[Auth][Strategy] delete strategy from store", - utils.RequestID(ctx), zap.Error(err)) + utils.ZapRequestID(requestID), zap.Error(err)) return api.NewAuthResponse(commonstore.StoreCode2APICode(err)) } - log.Info("[Auth][Strategy] delete strategy from store", utils.RequestID(ctx), + log.Info("[Auth][Strategy] delete strategy from store", utils.ZapRequestID(requestID), zap.String("name", req.Name.GetValue())) svr.RecordHistory(authStrategyRecordEntry(ctx, req, strategy, model.ODelete)) @@ -187,13 +198,10 @@ func (svr *Server) DeleteStrategy(ctx context.Context, req *apisecurity.AuthStra // a. 如果当前是超级管理账户,则按照传入的 query 进行查询即可 // b. 如果当前是主账户,则自动注入 owner 字段,即只能查看策略的 owner 是自己的策略 // c. 如果当前是子账户,则自动注入 principal_id 以及 principal_type 字段,即稚嫩查询与自己有关的策略 + func (svr *Server) GetStrategies(ctx context.Context, filters map[string]string) *apiservice.BatchQueryResponse { filters = ParseStrategySearchArgs(ctx, filters) offset, limit, _ := utils.ParseOffsetAndLimit(filters) - - // 透传兼容模式信息数据 - ctx = context.WithValue(ctx, model.ContextKeyCompatible{}, svr.options.Compatible) - total, strategies, err := svr.cacheMgr.AuthStrategy().Query(ctx, cachetypes.PolicySearchArgs{ Filters: filters, Offset: offset, @@ -211,9 +219,9 @@ func (svr *Server) GetStrategies(ctx context.Context, filters map[string]string) if strings.Compare(filters["show_detail"], "true") == 0 { log.Info("[Auth][Strategy] fill strategy detail", utils.RequestID(ctx)) - resp.AuthStrategies = enhancedAuthStrategy2Api(ctx, strategies, svr.authStrategyFull2Api) + resp.AuthStrategies = enhancedAuthStrategy2Api(strategies, svr.authStrategyFull2Api) } else { - resp.AuthStrategies = enhancedAuthStrategy2Api(ctx, strategies, svr.authStrategy2Api) + resp.AuthStrategies = enhancedAuthStrategy2Api(strategies, svr.authStrategy2Api) } return resp @@ -250,6 +258,22 @@ func ParseStrategySearchArgs(ctx context.Context, searchFilters map[string]strin searchFilters["principal_type"] = "1" } } + + if authcommon.ParseUserRole(ctx) != authcommon.AdminUserRole { + // 如果当前账户不是 admin 角色,既不是走资源视角查看,也不是指定principal查看,那么只能查询当前操作用户被关联到的鉴权策略, + if _, ok := searchFilters["res_id"]; !ok { + // 设置 owner 参数,只能查看对应 owner 下的策略 + searchFilters["owner"] = utils.ParseOwnerID(ctx) + if _, ok := searchFilters["principal_id"]; !ok { + // 如果当前不是 owner 角色,那么只能查询与自己有关的策略 + if !utils.ParseIsOwner(ctx) { + searchFilters["principal_id"] = utils.ParseUserID(ctx) + searchFilters["principal_type"] = strconv.Itoa(int(authcommon.PrincipalUser)) + } + } + } + } + return searchFilters } @@ -312,11 +336,12 @@ func (svr *Server) GetStrategy(ctx context.Context, req *apisecurity.AuthStrateg return api.NewAuthStrategyResponse(apimodel.Code_NotAllowedAccess, req) } - return api.NewAuthStrategyResponse(apimodel.Code_ExecuteSuccess, svr.authStrategyFull2Api(ctx, ret)) + return api.NewAuthStrategyResponse(apimodel.Code_ExecuteSuccess, svr.authStrategyFull2Api(ret)) } // GetPrincipalResources 获取某个principal可以获取到的所有资源ID数据信息 func (svr *Server) GetPrincipalResources(ctx context.Context, query map[string]string) *apiservice.Response { + requestID := utils.ParseRequestID(ctx) if len(query) == 0 { return api.NewAuthResponse(apimodel.Code_EmptyRequest) } @@ -352,7 +377,7 @@ func (svr *Server) GetPrincipalResources(ctx context.Context, query map[string]s item := groups[i] res, err := svr.storage.GetStrategyResources(item.GetId().GetValue(), authcommon.PrincipalGroup) if err != nil { - log.Error("[Auth][Strategy] get principal link resource", utils.RequestID(ctx), + log.Error("[Auth][Strategy] get principal link resource", utils.ZapRequestID(requestID), zap.String("principal-id", principalId), zap.Any("principal-role", principalRole), zap.Error(err)) return api.NewAuthResponse(commonstore.StoreCode2APICode(err)) } @@ -362,17 +387,21 @@ func (svr *Server) GetPrincipalResources(ctx context.Context, query map[string]s pResources, err := svr.storage.GetStrategyResources(principalId, authcommon.PrincipalType(principalRole)) if err != nil { - log.Error("[Auth][Strategy] get principal link resource", utils.RequestID(ctx), + log.Error("[Auth][Strategy] get principal link resource", utils.ZapRequestID(requestID), zap.String("principal-id", principalId), zap.Any("principal-role", principalRole), zap.Error(err)) return api.NewAuthResponse(commonstore.StoreCode2APICode(err)) } resources = append(resources, pResources...) tmp := &apisecurity.AuthStrategy{ - Resources: &apisecurity.StrategyResources{}, + Resources: &apisecurity.StrategyResources{ + Namespaces: make([]*apisecurity.StrategyResourceEntry, 0), + Services: make([]*apisecurity.StrategyResourceEntry, 0), + ConfigGroups: make([]*apisecurity.StrategyResourceEntry, 0), + }, } - svr.enrichResourceInfo(ctx, tmp, &authcommon.StrategyDetail{ + svr.fillResourceInfo(tmp, &authcommon.StrategyDetail{ Resources: resourceDeduplication(resources), }) @@ -380,17 +409,16 @@ func (svr *Server) GetPrincipalResources(ctx context.Context, query map[string]s } // enhancedAuthStrategy2Api -func enhancedAuthStrategy2Api(ctx context.Context, s []*authcommon.StrategyDetail, - fn StrategyDetail2Api) []*apisecurity.AuthStrategy { +func enhancedAuthStrategy2Api(s []*authcommon.StrategyDetail, fn StrategyDetail2Api) []*apisecurity.AuthStrategy { out := make([]*apisecurity.AuthStrategy, 0, len(s)) for k := range s { - out = append(out, fn(ctx, s[k])) + out = append(out, fn(s[k])) } return out } // authStrategy2Api -func (svr *Server) authStrategy2Api(ctx context.Context, s *authcommon.StrategyDetail) *apisecurity.AuthStrategy { +func (svr *Server) authStrategy2Api(s *authcommon.StrategyDetail) *apisecurity.AuthStrategy { if s == nil { return nil } @@ -411,7 +439,7 @@ func (svr *Server) authStrategy2Api(ctx context.Context, s *authcommon.StrategyD } // authStrategyFull2Api -func (svr *Server) authStrategyFull2Api(ctx context.Context, data *authcommon.StrategyDetail) *apisecurity.AuthStrategy { +func (svr *Server) authStrategyFull2Api(data *authcommon.StrategyDetail) *apisecurity.AuthStrategy { if data == nil { return nil } @@ -437,29 +465,36 @@ func (svr *Server) authStrategyFull2Api(ctx context.Context, data *authcommon.St Mtime: utils.NewStringValue(commontime.Time2String(data.ModifyTime)), Action: apisecurity.AuthAction(apisecurity.AuthAction_value[data.Action]), DefaultStrategy: utils.NewBoolValue(data.Default), - Functions: data.CalleeMethods, - Metadata: data.Metadata, } - svr.enrichPrincipalInfo(out, data) - svr.enrichResourceInfo(ctx, out, data) + svr.fillPrincipalInfo(out, data) + svr.fillResourceInfo(out, data) return out } // createAuthStrategyModel 创建鉴权策略的存储模型 func (svr *Server) createAuthStrategyModel(strategy *apisecurity.AuthStrategy) *authcommon.StrategyDetail { - ret := &authcommon.StrategyDetail{} - ret.FromSpec(strategy) + ret := &authcommon.StrategyDetail{ + ID: utils.NewUUID(), + Name: strategy.Name.GetValue(), + Action: apisecurity.AuthAction_READ_WRITE.String(), + Comment: strategy.Comment.GetValue(), + Default: false, + Owner: strategy.Owner.GetValue(), + Valid: true, + Revision: utils.NewUUID(), + CreateTime: time.Now(), + ModifyTime: time.Now(), + } // 收集涉及的资源信息 resEntry := make([]authcommon.StrategyResource, 0, 20) - for resType, ptrGetter := range resourceFieldPointerGetters { - slicePtr := ptrGetter(strategy.Resources) - if slicePtr.Elem().IsNil() { - continue - } - resEntry = append(resEntry, svr.collectResourceEntry(ret.ID, resType, slicePtr.Elem(), false)...) - } + resEntry = append(resEntry, svr.collectResEntry(ret.ID, apisecurity.ResourceType_Namespaces, + strategy.GetResources().GetNamespaces(), false)...) + resEntry = append(resEntry, svr.collectResEntry(ret.ID, apisecurity.ResourceType_Services, + strategy.GetResources().GetServices(), false)...) + resEntry = append(resEntry, svr.collectResEntry(ret.ID, apisecurity.ResourceType_ConfigGroups, + strategy.GetResources().GetConfigGroups(), false)...) // 收集涉及的 principal 信息 principals := make([]authcommon.Principal, 0, 20) @@ -467,8 +502,6 @@ func (svr *Server) createAuthStrategyModel(strategy *apisecurity.AuthStrategy) * strategy.GetPrincipals().GetUsers())...) principals = append(principals, collectPrincipalEntry(ret.ID, authcommon.PrincipalGroup, strategy.GetPrincipals().GetGroups())...) - principals = append(principals, collectPrincipalEntry(ret.ID, authcommon.PrincipalRole, - strategy.GetPrincipals().GetRoles())...) ret.Resources = resEntry ret.Principals = principals @@ -481,14 +514,11 @@ func (svr *Server) updateAuthStrategyAttribute(ctx context.Context, strategy *ap saved *authcommon.StrategyDetail) (*authcommon.ModifyStrategyDetail, bool) { var needUpdate bool ret := &authcommon.ModifyStrategyDetail{ - ID: strategy.Id.GetValue(), - Name: saved.Name, - Action: saved.Action, - Comment: saved.Comment, - ModifyTime: time.Now(), - CalleeMethods: saved.CalleeMethods, - Conditions: saved.Conditions, - Metadata: saved.Metadata, + ID: strategy.Id.GetValue(), + Name: saved.Name, + Action: saved.Action, + Comment: saved.Comment, + ModifyTime: time.Now(), } // 只有 owner 可以修改的属性 @@ -510,32 +540,6 @@ func (svr *Server) updateAuthStrategyAttribute(ctx context.Context, strategy *ap if computePrincipalChange(ret, strategy) { needUpdate = true } - if strategy.Functions != nil { - needUpdate = true - ret.CalleeMethods = strategy.Functions - } - if strategy.Metadata != nil { - needUpdate = true - ret.Metadata = strategy.Metadata - } - if strategy.Action != saved.GetAction() { - needUpdate = true - ret.Action = strategy.GetAction().String() - } - if strategy.ResourceLabels != nil { - needUpdate = true - ret.Conditions = func() []authcommon.Condition { - conditions := make([]authcommon.Condition, 0, len(strategy.GetResourceLabels())) - for index := range strategy.GetResourceLabels() { - conditions = append(conditions, authcommon.Condition{ - Key: strategy.GetResourceLabels()[index].GetKey(), - Value: strategy.GetResourceLabels()[index].GetValue(), - CompareFunc: strategy.GetResourceLabels()[index].GetCompareType(), - }) - } - return conditions - }() - } return ret, needUpdate } @@ -544,16 +548,13 @@ func (svr *Server) updateAuthStrategyAttribute(ctx context.Context, strategy *ap func (svr *Server) computeResourceChange( modify *authcommon.ModifyStrategyDetail, strategy *apisecurity.ModifyAuthStrategy) bool { var needUpdate bool - - // 收集涉及的资源信息 addResEntry := make([]authcommon.StrategyResource, 0) - for resType, ptrGetter := range resourceFieldPointerGetters { - slicePtr := ptrGetter(strategy.AddResources) - if slicePtr.Elem().IsNil() { - continue - } - addResEntry = append(addResEntry, svr.collectResourceEntry(modify.ID, resType, slicePtr.Elem(), false)...) - } + addResEntry = append(addResEntry, svr.collectResEntry(modify.ID, apisecurity.ResourceType_Namespaces, + strategy.GetAddResources().GetNamespaces(), false)...) + addResEntry = append(addResEntry, svr.collectResEntry(modify.ID, apisecurity.ResourceType_Services, + strategy.GetAddResources().GetServices(), false)...) + addResEntry = append(addResEntry, svr.collectResEntry(modify.ID, apisecurity.ResourceType_ConfigGroups, + strategy.GetAddResources().GetConfigGroups(), false)...) if len(addResEntry) != 0 { needUpdate = true @@ -561,13 +562,12 @@ func (svr *Server) computeResourceChange( } removeResEntry := make([]authcommon.StrategyResource, 0) - for resType, ptrGetter := range resourceFieldPointerGetters { - slicePtr := ptrGetter(strategy.RemoveResources) - if slicePtr.Elem().IsNil() { - continue - } - removeResEntry = append(removeResEntry, svr.collectResourceEntry(modify.ID, resType, slicePtr.Elem(), true)...) - } + removeResEntry = append(removeResEntry, svr.collectResEntry(modify.ID, apisecurity.ResourceType_Namespaces, + strategy.GetRemoveResources().GetNamespaces(), true)...) + removeResEntry = append(removeResEntry, svr.collectResEntry(modify.ID, apisecurity.ResourceType_Services, + strategy.GetRemoveResources().GetServices(), true)...) + removeResEntry = append(removeResEntry, svr.collectResEntry(modify.ID, apisecurity.ResourceType_ConfigGroups, + strategy.GetRemoveResources().GetConfigGroups(), true)...) if len(removeResEntry) != 0 { needUpdate = true @@ -585,8 +585,6 @@ func computePrincipalChange(modify *authcommon.ModifyStrategyDetail, strategy *a strategy.GetAddPrincipals().GetUsers())...) addPrincipals = append(addPrincipals, collectPrincipalEntry(modify.ID, authcommon.PrincipalGroup, strategy.GetAddPrincipals().GetGroups())...) - addPrincipals = append(addPrincipals, collectPrincipalEntry(modify.ID, authcommon.PrincipalRole, - strategy.GetAddPrincipals().GetRoles())...) if len(addPrincipals) != 0 { needUpdate = true @@ -598,8 +596,6 @@ func computePrincipalChange(modify *authcommon.ModifyStrategyDetail, strategy *a strategy.GetRemovePrincipals().GetUsers())...) removePrincipals = append(removePrincipals, collectPrincipalEntry(modify.ID, authcommon.PrincipalGroup, strategy.GetRemovePrincipals().GetGroups())...) - removePrincipals = append(removePrincipals, collectPrincipalEntry(modify.ID, authcommon.PrincipalRole, - strategy.GetRemovePrincipals().GetRoles())...) if len(removePrincipals) != 0 { needUpdate = true @@ -609,26 +605,19 @@ func computePrincipalChange(modify *authcommon.ModifyStrategyDetail, strategy *a return needUpdate } -type pbStringValue interface { - GetValue() string -} - // collectResEntry 将资源ID转换为对应的 []authcommon.StrategyResource 数组 -func (svr *Server) collectResourceEntry(ruleId string, resType apisecurity.ResourceType, - res reflect.Value, delete bool) []authcommon.StrategyResource { - if res.Kind() != reflect.Slice || res.Len() == 0 { - return []authcommon.StrategyResource{} +func (svr *Server) collectResEntry(ruleId string, resType apisecurity.ResourceType, + res []*apisecurity.StrategyResourceEntry, delete bool) []authcommon.StrategyResource { + resEntries := make([]authcommon.StrategyResource, 0, len(res)+1) + if len(res) == 0 { + return resEntries } - resEntries := make([]authcommon.StrategyResource, 0, res.Len()) - for i := 0; i < res.Len(); i++ { - item := res.Index(i).Elem() - resId := item.FieldByName("Id").Interface().(pbStringValue) - resName := item.FieldByName("Name").Interface().(pbStringValue) + for index := range res { // 如果是添加的动作,那么需要进行归一化处理 if !delete { // 归一化处理 - if resId.GetValue() == "*" || resName.GetValue() == "*" { + if res[index].GetId().GetValue() == "*" || res[index].GetName().GetValue() == "*" { return []authcommon.StrategyResource{ { StrategyID: ruleId, @@ -642,7 +631,7 @@ func (svr *Server) collectResourceEntry(ruleId string, resType apisecurity.Resou entry := authcommon.StrategyResource{ StrategyID: ruleId, ResType: int32(resType), - ResID: resId.GetValue(), + ResID: res[index].GetId().GetValue(), } resEntries = append(resEntries, entry) @@ -651,165 +640,94 @@ func (svr *Server) collectResourceEntry(ruleId string, resType apisecurity.Resou return resEntries } -// normalizeResource 对于资源进行归一化处理, 如果出现 * 的话,则该资源访问策略就是 * -func (svr *Server) normalizeResource(resources *apisecurity.StrategyResources) *apisecurity.StrategyResources { - if resources == nil { - return &apisecurity.StrategyResources{} - } - for _, ptrGetter := range resourceFieldPointerGetters { - slicePtr := ptrGetter(resources) - if slicePtr.Elem().IsNil() { - continue - } - sliceVal := slicePtr.Elem() - for i := 0; i < sliceVal.Len(); i++ { - item := sliceVal.Index(i).Elem() - resId := item.FieldByName("Id").Interface().(pbStringValue) - if resId.GetValue() == utils.MatchAll { - sliceVal.Set(reflect.ValueOf([]*apisecurity.StrategyResourceEntry{{ - Id: utils.NewStringValue("*"), - }})) - } - } +// collectPrincipalEntry 将 Principal 转换为对应的 []authcommon.Principal 数组 +func collectPrincipalEntry(ruleID string, uType authcommon.PrincipalType, res []*apisecurity.Principal) []authcommon.Principal { + principals := make([]authcommon.Principal, 0, len(res)+1) + if len(res) == 0 { + return principals } - return resources -} -// enrichPrincipalInfo 填充 principal 摘要信息 -func (svr *Server) enrichPrincipalInfo(resp *apisecurity.AuthStrategy, data *authcommon.StrategyDetail) { - users := make([]*apisecurity.Principal, 0, len(data.Principals)) - groups := make([]*apisecurity.Principal, 0, len(data.Principals)) - roles := make([]*apisecurity.Principal, 0, len(data.Principals)) - for index := range data.Principals { - principal := data.Principals[index] - switch principal.PrincipalType { - case authcommon.PrincipalUser: - if user := svr.userSvr.GetUserHelper().GetUser(context.TODO(), &apisecurity.User{ - Id: wrapperspb.String(principal.PrincipalID), - }); user != nil { - users = append(users, &apisecurity.Principal{ - Id: utils.NewStringValue(user.GetId().GetValue()), - Name: utils.NewStringValue(user.GetName().GetValue()), - }) - } - case authcommon.PrincipalGroup: - if group := svr.userSvr.GetUserHelper().GetGroup(context.TODO(), &apisecurity.UserGroup{ - Id: wrapperspb.String(principal.PrincipalID), - }); group != nil { - groups = append(groups, &apisecurity.Principal{ - Id: utils.NewStringValue(group.GetId().GetValue()), - Name: utils.NewStringValue(group.GetName().GetValue()), - }) - } - case authcommon.PrincipalRole: - if role := svr.PolicyHelper().GetRole(principal.PrincipalID); role != nil { - roles = append(roles, &apisecurity.Principal{ - Id: utils.NewStringValue(role.ID), - Name: utils.NewStringValue(role.Name), - }) - } - } + for index := range res { + principals = append(principals, authcommon.Principal{ + StrategyID: ruleID, + PrincipalID: res[index].GetId().GetValue(), + PrincipalType: uType, + }) } - resp.Principals = &apisecurity.Principals{ - Users: users, - Groups: groups, - Roles: roles, - } + return principals } -// enrichResourceInfo 填充资源摘要信息 -func (svr *Server) enrichResourceInfo(ctx context.Context, resp *apisecurity.AuthStrategy, data *authcommon.StrategyDetail) { - allMatch := map[apisecurity.ResourceType]struct{}{} - resp.Resources = &apisecurity.StrategyResources{ - Namespaces: make([]*apisecurity.StrategyResourceEntry, 0, 4), - ConfigGroups: make([]*apisecurity.StrategyResourceEntry, 0, 4), - Services: make([]*apisecurity.StrategyResourceEntry, 0, 4), - RouteRules: make([]*apisecurity.StrategyResourceEntry, 0, 4), - RatelimitRules: make([]*apisecurity.StrategyResourceEntry, 0, 4), - CircuitbreakerRules: make([]*apisecurity.StrategyResourceEntry, 0, 4), - FaultdetectRules: make([]*apisecurity.StrategyResourceEntry, 0, 4), - LaneRules: make([]*apisecurity.StrategyResourceEntry, 0, 4), - Users: make([]*apisecurity.StrategyResourceEntry, 0, 4), - UserGroups: make([]*apisecurity.StrategyResourceEntry, 0, 4), - Roles: make([]*apisecurity.StrategyResourceEntry, 0, 4), - AuthPolicies: make([]*apisecurity.StrategyResourceEntry, 0, 4), +// checkCreateStrategy 检查创建鉴权策略的请求 +func (svr *Server) checkCreateStrategy(req *apisecurity.AuthStrategy) *apiservice.Response { + // 检查名称信息 + if err := CheckName(req.GetName()); err != nil { + return api.NewAuthStrategyResponse(apimodel.Code_InvalidUserName, req) } - - for index := range data.Resources { - res := data.Resources[index] - svr.enrichResourceDetial(ctx, res, allMatch, resp) + // 检查用户是否存在 + if err := svr.checkUserExist(convertPrincipalsToUsers(req.GetPrincipals())); err != nil { + return api.NewAuthStrategyResponse(apimodel.Code_NotFoundUser, req) + } + // 检查用户组是否存在 + if err := svr.checkGroupExist(convertPrincipalsToGroups(req.GetPrincipals())); err != nil { + return api.NewAuthStrategyResponse(apimodel.Code_NotFoundUserGroup, req) } + // 检查资源是否存在 + if errResp := svr.checkResourceExist(req.GetResources()); errResp != nil { + return errResp + } + return nil } -func (svr *Server) enrichResourceDetial(ctx context.Context, item authcommon.StrategyResource, - allMatch map[apisecurity.ResourceType]struct{}, resp *apisecurity.AuthStrategy) { - - resType := apisecurity.ResourceType(item.ResType) - slicePtr := resourceFieldPointerGetters[resType](resp.Resources) - if slicePtr.Elem().IsNil() { - return - } - sliceVal := slicePtr.Elem() - - if item.ResID == "*" { - allMatch[resType] = struct{}{} - sliceVal.Set(reflect.ValueOf([]*apisecurity.StrategyResourceEntry{ - { - Id: utils.NewStringValue("*"), - Namespace: utils.NewStringValue("*"), - Name: utils.NewStringValue("*"), - }, - })) - return - } - if _, ok := allMatch[resType]; !ok { - if data := resourceConvert[resType](ctx, svr, item); data != nil { - // 创建一个新数组并把元素的值追加进去 - resArr := reflect.Append(sliceVal, reflect.ValueOf(data)) - sliceVal.Set(resArr) +// checkUpdateStrategy 检查更新鉴权策略的请求 +// Case 1. 修改的是默认鉴权策略的话,只能修改资源,不能添加用户 or 用户组 +// Case 2. 鉴权策略只能被自己的 owner 对应的用户修改 +func (svr *Server) checkUpdateStrategy(ctx context.Context, req *apisecurity.ModifyAuthStrategy, + saved *authcommon.StrategyDetail) *apiservice.Response { + userId := utils.ParseUserID(ctx) + if authcommon.ParseUserRole(ctx) != authcommon.AdminUserRole { + if !utils.ParseIsOwner(ctx) || userId != saved.Owner { + log.Error("[Auth][Strategy] modify strategy denied, current user not owner", + utils.ZapRequestID(utils.ParseRequestID(ctx)), + zap.String("user", userId), + zap.String("owner", saved.Owner), + zap.String("strategy", saved.ID)) + return api.NewModifyAuthStrategyResponse(apimodel.Code_NotAllowedAccess, req) } } -} -// filter different types of Strategy resources -func resourceDeduplication(resources []authcommon.StrategyResource) []authcommon.StrategyResource { - rLen := len(resources) - ret := make([]authcommon.StrategyResource, 0, rLen) - filters := map[apisecurity.ResourceType]map[string]struct{}{} - - est := struct{}{} - for i := range resources { - res := resources[i] - filter, ok := filters[apisecurity.ResourceType(res.ResType)] - if !ok { - filters[apisecurity.ResourceType(res.ResType)] = map[string]struct{}{} - filter = filters[apisecurity.ResourceType(res.ResType)] + if saved.Default { + if len(req.AddPrincipals.Users) != 0 || + len(req.AddPrincipals.Groups) != 0 || + len(req.RemovePrincipals.Groups) != 0 || + len(req.RemovePrincipals.Users) != 0 { + return api.NewModifyAuthStrategyResponse(apimodel.Code_NotAllowModifyDefaultStrategyPrincipal, req) } - if _, exist := filter[res.ResID]; !exist { - filter[res.ResID] = est - ret = append(ret, res) + + // 主账户的默认策略禁止编辑 + if len(saved.Principals) == 1 && saved.Principals[0].PrincipalType == authcommon.PrincipalUser { + if saved.Principals[0].PrincipalID == utils.ParseOwnerID(ctx) { + return api.NewAuthResponse(apimodel.Code_NotAllowModifyOwnerDefaultStrategy) + } } } - return ret -} -// collectPrincipalEntry 将 Principal 转换为对应的 []authcommon.Principal 数组 -func collectPrincipalEntry(ruleID string, uType authcommon.PrincipalType, res []*apisecurity.Principal) []authcommon.Principal { - principals := make([]authcommon.Principal, 0, len(res)+1) - if len(res) == 0 { - return principals + // 检查用户是否存在 + if err := svr.checkUserExist(convertPrincipalsToUsers(req.GetAddPrincipals())); err != nil { + return api.NewModifyAuthStrategyResponse(apimodel.Code_NotFoundUser, req) } - for index := range res { - principals = append(principals, authcommon.Principal{ - StrategyID: ruleID, - PrincipalID: res[index].GetId().GetValue(), - PrincipalType: uType, - }) + // 检查用户组是否存 + if err := svr.checkGroupExist(convertPrincipalsToGroups(req.GetAddPrincipals())); err != nil { + return api.NewModifyAuthStrategyResponse(apimodel.Code_NotFoundUserGroup, req) } - return principals + // 检查资源是否存在 + if errResp := svr.checkResourceExist(req.GetAddResources()); errResp != nil { + return errResp + } + + return nil } // authStrategyRecordEntry 转换为鉴权策略的记录结构体 @@ -851,266 +769,295 @@ func authModifyStrategyRecordEntry( return entry } -var ( - resourceFieldNames = map[string]apisecurity.ResourceType{ - "namespaces": apisecurity.ResourceType_Namespaces, - "service": apisecurity.ResourceType_Services, - "config_groups": apisecurity.ResourceType_ConfigGroups, - "route_rules": apisecurity.ResourceType_RouteRules, - "ratelimit_rules": apisecurity.ResourceType_RateLimitRules, - "circuitbreaker_rules": apisecurity.ResourceType_CircuitBreakerRules, - "faultdetect_rules": apisecurity.ResourceType_FaultDetectRules, - "lane_rules": apisecurity.ResourceType_LaneRules, - "users": apisecurity.ResourceType_Users, - "user_groups": apisecurity.ResourceType_UserGroups, - "roles": apisecurity.ResourceType_Roles, - "auth_policies": apisecurity.ResourceType_PolicyRules, - } - - resourceFieldPointerGetters = map[apisecurity.ResourceType]func(*apisecurity.StrategyResources) reflect.Value{ - apisecurity.ResourceType_Namespaces: func(as *apisecurity.StrategyResources) reflect.Value { - if as.GetNamespaces() == nil { - return reflect.ValueOf(&[]*apisecurity.StrategyResourceEntry{}) - } - return reflect.ValueOf(&as.Namespaces) - }, - apisecurity.ResourceType_Services: func(as *apisecurity.StrategyResources) reflect.Value { - if as.GetServices() == nil { - return reflect.ValueOf(&[]*apisecurity.StrategyResourceEntry{}) - } - return reflect.ValueOf(&as.Services) - }, - apisecurity.ResourceType_ConfigGroups: func(as *apisecurity.StrategyResources) reflect.Value { - if as.GetConfigGroups() == nil { - return reflect.ValueOf(&[]*apisecurity.StrategyResourceEntry{}) - } - return reflect.ValueOf(&as.ConfigGroups) - }, - apisecurity.ResourceType_RouteRules: func(as *apisecurity.StrategyResources) reflect.Value { - if as.GetRouteRules() == nil { - return reflect.ValueOf(&[]*apisecurity.StrategyResourceEntry{}) - } - return reflect.ValueOf(&as.RouteRules) - }, - apisecurity.ResourceType_RateLimitRules: func(as *apisecurity.StrategyResources) reflect.Value { - if as.GetRatelimitRules() == nil { - return reflect.ValueOf(&[]*apisecurity.StrategyResourceEntry{}) - } - return reflect.ValueOf(&as.RatelimitRules) - }, - apisecurity.ResourceType_CircuitBreakerRules: func(as *apisecurity.StrategyResources) reflect.Value { - if as.GetCircuitbreakerRules() == nil { - return reflect.ValueOf(&[]*apisecurity.StrategyResourceEntry{}) - } - return reflect.ValueOf(&as.CircuitbreakerRules) - }, - apisecurity.ResourceType_FaultDetectRules: func(as *apisecurity.StrategyResources) reflect.Value { - if as.GetFaultdetectRules() == nil { - return reflect.ValueOf(&[]*apisecurity.StrategyResourceEntry{}) - } - return reflect.ValueOf(&as.FaultdetectRules) - }, - apisecurity.ResourceType_LaneRules: func(as *apisecurity.StrategyResources) reflect.Value { - if as.GetLaneRules() == nil { - return reflect.ValueOf(&[]*apisecurity.StrategyResourceEntry{}) - } - return reflect.ValueOf(&as.LaneRules) - }, - apisecurity.ResourceType_Users: func(as *apisecurity.StrategyResources) reflect.Value { - if as.GetUsers() == nil { - return reflect.ValueOf(&[]*apisecurity.StrategyResourceEntry{}) - } - return reflect.ValueOf(&as.Users) - }, - apisecurity.ResourceType_UserGroups: func(as *apisecurity.StrategyResources) reflect.Value { - if as.GetUserGroups() == nil { - return reflect.ValueOf(&[]*apisecurity.StrategyResourceEntry{}) - } - return reflect.ValueOf(&as.UserGroups) - }, - apisecurity.ResourceType_Roles: func(as *apisecurity.StrategyResources) reflect.Value { - if as.GetRoles() == nil { - return reflect.ValueOf(&[]*apisecurity.StrategyResourceEntry{}) - } - return reflect.ValueOf(&as.Roles) - }, - apisecurity.ResourceType_PolicyRules: func(as *apisecurity.StrategyResources) reflect.Value { - if as.GetAuthPolicies() == nil { - return reflect.ValueOf(&[]*apisecurity.StrategyResourceEntry{}) - } - return reflect.ValueOf(&as.AuthPolicies) - }, +func convertPrincipalsToUsers(principals *apisecurity.Principals) []*apisecurity.User { + if principals == nil { + return make([]*apisecurity.User, 0) + } + + users := make([]*apisecurity.User, 0, len(principals.Users)) + for k := range principals.GetUsers() { + user := principals.GetUsers()[k] + users = append(users, &apisecurity.User{ + Id: user.Id, + }) } - resourceConvert = map[apisecurity.ResourceType]func(context.Context, - *Server, authcommon.StrategyResource) *apisecurity.StrategyResourceEntry{ + return users +} - // 注册、配置、治理 - apisecurity.ResourceType_Namespaces: func(ctx context.Context, svr *Server, - item authcommon.StrategyResource) *apisecurity.StrategyResourceEntry { - user := svr.cacheMgr.Namespace().GetNamespace(item.ResID) - if user == nil { - log.Warn("[Auth][Strategy] not found namespace in fill-info", - zap.String("id", item.StrategyID), zap.String("res-id", item.ResID), utils.RequestID(ctx)) - return nil - } - return &apisecurity.StrategyResourceEntry{ - Id: utils.NewStringValue(item.ResID), - Namespace: utils.NewStringValue(user.Name), - Name: utils.NewStringValue(user.Name), - } - }, - apisecurity.ResourceType_ConfigGroups: func(ctx context.Context, svr *Server, - item authcommon.StrategyResource) *apisecurity.StrategyResourceEntry { - id, _ := strconv.ParseUint(item.ResID, 10, 64) - user := svr.cacheMgr.ConfigGroup().GetGroupByID(id) - if user == nil { - log.Warn("[Auth][Strategy] not found config_group in fill-info", - zap.String("id", item.StrategyID), zap.String("res-id", item.ResID), utils.RequestID(ctx)) - return nil - } - return &apisecurity.StrategyResourceEntry{ - Id: utils.NewStringValue(item.ResID), - Namespace: utils.NewStringValue(user.Namespace), - Name: utils.NewStringValue(user.Name), - } - }, - apisecurity.ResourceType_Services: func(ctx context.Context, svr *Server, - item authcommon.StrategyResource) *apisecurity.StrategyResourceEntry { - user := svr.cacheMgr.Namespace().GetNamespace(item.ResID) - if user == nil { - log.Warn("[Auth][Strategy] not found namespace in fill-info", - zap.String("id", item.StrategyID), zap.String("res-id", item.ResID), utils.RequestID(ctx)) - return nil - } - return &apisecurity.StrategyResourceEntry{ - Id: utils.NewStringValue(item.ResID), - Namespace: utils.NewStringValue(user.Name), - Name: utils.NewStringValue(user.Name), - } - }, - apisecurity.ResourceType_RouteRules: func(ctx context.Context, svr *Server, - item authcommon.StrategyResource) *apisecurity.StrategyResourceEntry { - user := svr.cacheMgr.RoutingConfig().GetRule(item.ResID) - if user == nil { - log.Warn("[Auth][Strategy] not found route_rule in fill-info", - zap.String("id", item.StrategyID), zap.String("res-id", item.ResID), utils.RequestID(ctx)) - return nil - } - return &apisecurity.StrategyResourceEntry{ - Id: utils.NewStringValue(item.ResID), - Namespace: utils.NewStringValue(user.Name), - Name: utils.NewStringValue(user.Name), - } - }, - apisecurity.ResourceType_LaneRules: func(ctx context.Context, svr *Server, - item authcommon.StrategyResource) *apisecurity.StrategyResourceEntry { - user := svr.cacheMgr.LaneRule().GetRule(item.ResID) - if user == nil { - log.Warn("[Auth][Strategy] not found lane_rule in fill-info", - zap.String("id", item.StrategyID), zap.String("res-id", item.ResID), utils.RequestID(ctx)) - return nil - } - return &apisecurity.StrategyResourceEntry{ - Id: utils.NewStringValue(item.ResID), - Namespace: utils.NewStringValue(user.Name), - Name: utils.NewStringValue(user.Name), - } - }, - apisecurity.ResourceType_RateLimitRules: func(ctx context.Context, svr *Server, - item authcommon.StrategyResource) *apisecurity.StrategyResourceEntry { - user := svr.cacheMgr.RateLimit().GetRule(item.ResID) - if user == nil { - log.Warn("[Auth][Strategy] not found ratelimit_rule in fill-info", - zap.String("id", item.StrategyID), zap.String("res-id", item.ResID), utils.RequestID(ctx)) - return nil - } - return &apisecurity.StrategyResourceEntry{ - Id: utils.NewStringValue(item.ResID), - Namespace: utils.NewStringValue(user.Name), - Name: utils.NewStringValue(user.Name), - } - }, - apisecurity.ResourceType_CircuitBreakerRules: func(ctx context.Context, svr *Server, - item authcommon.StrategyResource) *apisecurity.StrategyResourceEntry { - user := svr.cacheMgr.CircuitBreaker().GetRule(item.ResID) - if user == nil { - log.Warn("[Auth][Strategy] not found circuitbreaker_rule in fill-info", - zap.String("id", item.StrategyID), zap.String("res-id", item.ResID), utils.RequestID(ctx)) - return nil - } - return &apisecurity.StrategyResourceEntry{ - Id: utils.NewStringValue(item.ResID), - Namespace: utils.NewStringValue(user.Name), - Name: utils.NewStringValue(user.Name), - } - }, - apisecurity.ResourceType_FaultDetectRules: func(ctx context.Context, svr *Server, - item authcommon.StrategyResource) *apisecurity.StrategyResourceEntry { - user := svr.cacheMgr.FaultDetector().GetRule(item.ResID) - if user == nil { - log.Warn("[Auth][Strategy] not found faultdetect_rule in fill-info", - zap.String("id", item.StrategyID), zap.String("res-id", item.ResID), utils.RequestID(ctx)) - return nil - } - return &apisecurity.StrategyResourceEntry{ - Id: utils.NewStringValue(item.ResID), - Namespace: utils.NewStringValue(user.Name), - Name: utils.NewStringValue(user.Name), - } - }, - // 鉴权资源 - apisecurity.ResourceType_Users: func(ctx context.Context, svr *Server, - item authcommon.StrategyResource) *apisecurity.StrategyResourceEntry { - user := svr.cacheMgr.User().GetUserByID(item.ResID) +func convertPrincipalsToGroups(principals *apisecurity.Principals) []*apisecurity.UserGroup { + if principals == nil { + return make([]*apisecurity.UserGroup, 0) + } + + groups := make([]*apisecurity.UserGroup, 0, len(principals.Groups)) + for k := range principals.GetGroups() { + group := principals.GetGroups()[k] + groups = append(groups, &apisecurity.UserGroup{ + Id: group.Id, + }) + } + + return groups +} + +// checkUserExist 检查用户是否存在 +func (svr *Server) checkUserExist(users []*apisecurity.User) error { + if len(users) == 0 { + return nil + } + return svr.userSvr.GetUserHelper().CheckUsersExist(context.TODO(), users) +} + +// checkUserGroupExist 检查用户组是否存在 +func (svr *Server) checkGroupExist(groups []*apisecurity.UserGroup) error { + if len(groups) == 0 { + return nil + } + return svr.userSvr.GetUserHelper().CheckGroupsExist(context.TODO(), groups) +} + +// checkResourceExist 检查资源是否存在 +func (svr *Server) checkResourceExist(resources *apisecurity.StrategyResources) *apiservice.Response { + namespaces := resources.GetNamespaces() + + nsCache := svr.cacheMgr.Namespace() + for index := range namespaces { + val := namespaces[index] + if val.GetId().GetValue() == "*" { + break + } + ns := nsCache.GetNamespace(val.GetId().GetValue()) + if ns == nil { + return api.NewAuthResponse(apimodel.Code_NotFoundNamespace) + } + } + + services := resources.GetServices() + svcCache := svr.cacheMgr.Service() + for index := range services { + val := services[index] + if val.GetId().GetValue() == "*" { + break + } + svc := svcCache.GetServiceByID(val.GetId().GetValue()) + if svc == nil { + return api.NewAuthResponse(apimodel.Code_NotFoundService) + } + } + + return nil +} + +// normalizeResource 对于资源进行归一化处理, 如果出现 * 的话,则该资源访问策略就是 * +func (svr *Server) normalizeResource(resources *apisecurity.StrategyResources) *apisecurity.StrategyResources { + namespaces := resources.GetNamespaces() + for index := range namespaces { + val := namespaces[index] + if val.GetId().GetValue() == "*" { + resources.Namespaces = []*apisecurity.StrategyResourceEntry{{ + Id: utils.NewStringValue("*"), + }} + break + } + } + services := resources.GetServices() + for index := range services { + val := services[index] + if val.GetId().GetValue() == "*" { + resources.Services = []*apisecurity.StrategyResourceEntry{{ + Id: utils.NewStringValue("*"), + }} + break + } + } + return resources +} + +// fillPrincipalInfo 填充 principal 摘要信息 +func (svr *Server) fillPrincipalInfo(resp *apisecurity.AuthStrategy, data *authcommon.StrategyDetail) { + users := make([]*apisecurity.Principal, 0, len(data.Principals)) + groups := make([]*apisecurity.Principal, 0, len(data.Principals)) + for index := range data.Principals { + principal := data.Principals[index] + if principal.PrincipalType == authcommon.PrincipalUser { + user := svr.userSvr.GetUserHelper().GetUser(context.TODO(), &apisecurity.User{ + Id: wrapperspb.String(principal.PrincipalID), + }) if user == nil { - log.Warn("[Auth][Strategy] not found user in fill-info", - zap.String("id", item.StrategyID), zap.String("res-id", item.ResID), utils.RequestID(ctx)) - return nil + continue } - return &apisecurity.StrategyResourceEntry{ - Id: utils.NewStringValue(item.ResID), - Name: utils.NewStringValue(user.Name), + users = append(users, &apisecurity.Principal{ + Id: utils.NewStringValue(user.GetId().GetValue()), + Name: utils.NewStringValue(user.GetName().GetValue()), + }) + } else { + group := svr.userSvr.GetUserHelper().GetGroup(context.TODO(), &apisecurity.UserGroup{ + Id: wrapperspb.String(principal.PrincipalID), + }) + if group == nil { + continue } - }, - apisecurity.ResourceType_UserGroups: func(ctx context.Context, svr *Server, - item authcommon.StrategyResource) *apisecurity.StrategyResourceEntry { - user := svr.cacheMgr.User().GetGroup(item.ResID) - if user == nil { - log.Warn("[Auth][Strategy] not found user_group in fill-info", - zap.String("id", item.StrategyID), zap.String("res-id", item.ResID), utils.RequestID(ctx)) - return nil + groups = append(groups, &apisecurity.Principal{ + Id: utils.NewStringValue(group.GetId().GetValue()), + Name: utils.NewStringValue(group.GetName().GetValue()), + }) + } + } + + resp.Principals = &apisecurity.Principals{ + Users: users, + Groups: groups, + } +} + +// fillResourceInfo 填充资源摘要信息 +func (svr *Server) fillResourceInfo(resp *apisecurity.AuthStrategy, data *authcommon.StrategyDetail) { + namespaces := make([]*apisecurity.StrategyResourceEntry, 0, len(data.Resources)) + services := make([]*apisecurity.StrategyResourceEntry, 0, len(data.Resources)) + configGroups := make([]*apisecurity.StrategyResourceEntry, 0, len(data.Resources)) + + var ( + autoAllNs bool + autoAllSvc bool + autoAllConfigGroup bool + ) + + for index := range data.Resources { + res := data.Resources[index] + switch res.ResType { + case int32(apisecurity.ResourceType_Namespaces): + if res.ResID == "*" { + autoAllNs = true + namespaces = []*apisecurity.StrategyResourceEntry{ + { + Id: utils.NewStringValue("*"), + Namespace: utils.NewStringValue("*"), + Name: utils.NewStringValue("*"), + }, + } + continue } - return &apisecurity.StrategyResourceEntry{ - Id: utils.NewStringValue(item.ResID), - Name: utils.NewStringValue(user.Name), + + if !autoAllNs { + ns := svr.cacheMgr.Namespace().GetNamespace(res.ResID) + if ns == nil { + log.Warn("[Auth][Strategy] not found namespace in fill-info", + zap.String("id", data.ID), zap.String("namespace", res.ResID)) + continue + } + namespaces = append(namespaces, &apisecurity.StrategyResourceEntry{ + Id: utils.NewStringValue(ns.Name), + Namespace: utils.NewStringValue(ns.Name), + Name: utils.NewStringValue(ns.Name), + }) } - }, - apisecurity.ResourceType_Roles: func(ctx context.Context, svr *Server, - item authcommon.StrategyResource) *apisecurity.StrategyResourceEntry { - user := svr.cacheMgr.Role().GetRole(item.ResID) - if user == nil { - log.Warn("[Auth][Strategy] not found role in fill-info", - zap.String("id", item.StrategyID), zap.String("res-id", item.ResID), utils.RequestID(ctx)) - return nil + case int32(apisecurity.ResourceType_Services): + if res.ResID == "*" { + autoAllSvc = true + services = []*apisecurity.StrategyResourceEntry{ + { + Id: utils.NewStringValue("*"), + Namespace: utils.NewStringValue("*"), + Name: utils.NewStringValue("*"), + }, + } + continue } - return &apisecurity.StrategyResourceEntry{ - Id: utils.NewStringValue(item.ResID), - Name: utils.NewStringValue(user.Name), + + if !autoAllSvc { + svc := svr.cacheMgr.Service().GetServiceByID(res.ResID) + if svc == nil { + log.Warn("[Auth][Strategy] not found service in fill-info", + zap.String("id", data.ID), zap.String("service", res.ResID)) + continue + } + services = append(services, &apisecurity.StrategyResourceEntry{ + Id: utils.NewStringValue(svc.ID), + Namespace: utils.NewStringValue(svc.Namespace), + Name: utils.NewStringValue(svc.Name), + }) } - }, - apisecurity.ResourceType_PolicyRules: func(ctx context.Context, svr *Server, - item authcommon.StrategyResource) *apisecurity.StrategyResourceEntry { - user := svr.cacheMgr.AuthStrategy().GetPolicyRule(item.ResID) - if user == nil { - log.Warn("[Auth][Strategy] not found auth_policy in fill-info", - zap.String("id", item.StrategyID), zap.String("res-id", item.ResID), utils.RequestID(ctx)) - return nil + case int32(apisecurity.ResourceType_ConfigGroups): + if res.ResID == "*" { + autoAllConfigGroup = true + configGroups = []*apisecurity.StrategyResourceEntry{ + { + Id: utils.NewStringValue("*"), + Namespace: utils.NewStringValue("*"), + Name: utils.NewStringValue("*"), + }, + } + continue } - return &apisecurity.StrategyResourceEntry{ - Id: utils.NewStringValue(item.ResID), - Name: utils.NewStringValue(user.Name), + if !autoAllConfigGroup { + groupId, err := strconv.ParseUint(res.ResID, 10, 64) + if err != nil { + log.Warn("[Auth][Strategy] invalid resource id", + zap.String("id", data.ID), zap.String("config_file_group", res.ResID)) + continue + } + group := svr.cacheMgr.ConfigGroup().GetGroupByID(groupId) + if group == nil { + log.Warn("[Auth][Strategy] not found config_file_group in fill-info", + zap.String("id", data.ID), zap.String("config_file_group", res.ResID)) + continue + } + configGroups = append(configGroups, &apisecurity.StrategyResourceEntry{ + Id: utils.NewStringValue(res.ResID), + Namespace: utils.NewStringValue(group.Namespace), + Name: utils.NewStringValue(group.Name), + }) } - }, + } } -) + + resp.Resources = &apisecurity.StrategyResources{ + Namespaces: namespaces, + Services: services, + ConfigGroups: configGroups, + } +} + +type resourceFilter struct { + ns map[string]struct{} + svc map[string]struct{} + conf map[string]struct{} +} + +func (f *resourceFilter) GetFilter(t apisecurity.ResourceType) (map[string]struct{}, bool) { + switch t { + case apisecurity.ResourceType_Namespaces: + return f.ns, true + case apisecurity.ResourceType_Services: + return f.svc, true + case apisecurity.ResourceType_ConfigGroups: + return f.conf, true + } + return nil, false +} + +// filter different types of Strategy resources +func resourceDeduplication(resources []authcommon.StrategyResource) []authcommon.StrategyResource { + rLen := len(resources) + ret := make([]authcommon.StrategyResource, 0, rLen) + rf := resourceFilter{ + ns: make(map[string]struct{}, rLen), + svc: make(map[string]struct{}, rLen), + conf: make(map[string]struct{}, rLen), + } + + est := struct{}{} + for i := range resources { + res := resources[i] + filter, ok := rf.GetFilter(apisecurity.ResourceType(res.ResType)) + if !ok { + continue + } + if _, exist := filter[res.ResID]; !exist { + rf.ns[res.ResID] = est + ret = append(ret, res) + } + } + return ret +} diff --git a/auth/policy/strategy_test.go b/auth/policy/strategy_test.go new file mode 100644 index 000000000..5626fd7d6 --- /dev/null +++ b/auth/policy/strategy_test.go @@ -0,0 +1,957 @@ +/** + * Tencent is pleased to support the open source community by making Polaris available. + * + * Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. + * + * Licensed under the BSD 3-Clause License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://opensource.org/licenses/BSD-3-Clause + * + * Unless required by applicable law or agreed to in writing, software distributed + * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR + * CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package policy_test + +import ( + "context" + "math/rand" + "reflect" + "testing" + "time" + + "github.com/golang/mock/gomock" + apisecurity "github.com/polarismesh/specification/source/go/api/v1/security" + "github.com/stretchr/testify/assert" + "google.golang.org/protobuf/types/known/wrapperspb" + + "github.com/polarismesh/polaris/auth" + "github.com/polarismesh/polaris/auth/policy" + defaultuser "github.com/polarismesh/polaris/auth/user" + "github.com/polarismesh/polaris/cache" + cachetypes "github.com/polarismesh/polaris/cache/api" + api "github.com/polarismesh/polaris/common/api/v1" + "github.com/polarismesh/polaris/common/eventhub" + "github.com/polarismesh/polaris/common/model" + authcommon "github.com/polarismesh/polaris/common/model/auth" + "github.com/polarismesh/polaris/common/utils" + storemock "github.com/polarismesh/polaris/store/mock" +) + +type StrategyTest struct { + admin *authcommon.User + ownerOne *authcommon.User + ownerTwo *authcommon.User + + namespaces []*model.Namespace + services []*model.Service + strategies []*authcommon.StrategyDetail + allStrategies []*authcommon.StrategyDetail + defaultStrategies []*authcommon.StrategyDetail + + users []*authcommon.User + groups []*authcommon.UserGroupDetail + + storage *storemock.MockStore + cacheMgn *cache.CacheManager + checker auth.AuthChecker + + svr auth.StrategyServer + + cancel context.CancelFunc + + ctrl *gomock.Controller +} + +func newStrategyTest(t *testing.T) *StrategyTest { + reset(false) + eventhub.InitEventHub() + + ctrl := gomock.NewController(t) + + users := createMockUser(10) + groups := createMockUserGroup(users) + + namespaces := createMockNamespace(len(users)+len(groups)+10, users[0].ID) + services := createMockService(namespaces) + serviceMap := convertServiceSliceToMap(services) + defaultStrategies, strategies := createMockStrategy(users, groups, services[:len(users)+len(groups)]) + + allStrategies := make([]*authcommon.StrategyDetail, 0, len(defaultStrategies)+len(strategies)) + allStrategies = append(allStrategies, defaultStrategies...) + allStrategies = append(allStrategies, strategies...) + + cfg, storage := initCache(ctrl) + + storage.EXPECT().GetServicesCount().AnyTimes().Return(uint32(1), nil) + storage.EXPECT().GetUnixSecond(gomock.Any()).AnyTimes().Return(time.Now().Unix(), nil) + storage.EXPECT().GetUsersForCache(gomock.Any(), gomock.Any()).AnyTimes().Return(users, nil) + storage.EXPECT().GetGroupsForCache(gomock.Any(), gomock.Any()).AnyTimes().Return(groups, nil) + storage.EXPECT().GetMoreStrategies(gomock.Any(), gomock.Any()).AnyTimes().Return(allStrategies, nil) + storage.EXPECT().GetMoreNamespaces(gomock.Any()).AnyTimes().Return(namespaces, nil) + storage.EXPECT().GetMoreServices(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes().Return(serviceMap, nil) + storage.EXPECT().GetStrategyResources(gomock.Eq(users[1].ID), gomock.Any()).AnyTimes().Return(strategies[1].Resources, nil) + storage.EXPECT().GetStrategyResources(gomock.Eq(groups[1].ID), gomock.Any()).AnyTimes().Return(strategies[len(users)-1+2].Resources, nil) + + ctx, cancel := context.WithCancel(context.Background()) + cacheMgn, err := cache.TestCacheInitialize(ctx, cfg, storage) + if err != nil { + t.Fatal(err) + } + err = cacheMgn.OpenResourceCache([]cachetypes.ConfigEntry{ + { + Name: cachetypes.ServiceName, + Option: map[string]interface{}{ + "disableBusiness": false, + "needMeta": true, + }, + }, + { + Name: cachetypes.InstanceName, + }, + { + Name: cachetypes.NamespaceName, + }, + { + Name: cachetypes.UsersName, + }, + { + Name: cachetypes.StrategyRuleName, + }, + }...) + if err != nil { + t.Fatal(err) + } + if err := cache.TestRun(ctx, cacheMgn); err != nil { + t.Fatal(err) + } + _ = cacheMgn.TestUpdate() + + _, proxySvr, err := defaultuser.BuildServer() + if err != nil { + t.Fatal(err) + } + proxySvr.Initialize(&auth.Config{ + User: &auth.UserConfig{ + Name: auth.DefaultUserMgnPluginName, + Option: map[string]interface{}{ + "salt": "polarismesh@2021", + }, + }, + }, storage, nil, cacheMgn) + + _, svr, err := newPolicyServer() + if err != nil { + t.Fatal(err) + } + if err := svr.Initialize(&auth.Config{ + Strategy: &auth.StrategyConfig{ + Name: auth.DefaultPolicyPluginName, + }, + }, storage, cacheMgn, proxySvr); err != nil { + t.Fatal(err) + } + checker := svr.GetAuthChecker() + + t.Cleanup(func() { + cacheMgn.Close() + }) + + return &StrategyTest{ + ownerOne: users[0], + + users: users, + groups: groups, + + namespaces: namespaces, + services: services, + strategies: strategies, + allStrategies: allStrategies, + defaultStrategies: defaultStrategies, + + storage: storage, + cacheMgn: cacheMgn, + checker: checker, + + cancel: cancel, + + svr: svr, + + ctrl: ctrl, + } +} + +func (g *StrategyTest) Clean() { + g.cancel() + _ = g.cacheMgn.Close() +} + +func Test_GetPrincipalResources(t *testing.T) { + + strategyTest := newStrategyTest(t) + defer strategyTest.Clean() + + _ = strategyTest.cacheMgn.TestUpdate() + + valCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, strategyTest.users[1].Token) + + ret := strategyTest.svr.GetPrincipalResources(valCtx, map[string]string{ + "principal_id": strategyTest.users[1].ID, + "principal_type": "user", + }) + + t.Logf("GetPrincipalResources resp : %+v", ret) + assert.EqualValues(t, api.ExecuteSuccess, ret.Code.GetValue(), "need query success") + resources := ret.Resources + assert.Equal(t, 2, len(resources.GetServices()), "need query 2 service resources") +} + +func Test_CreateStrategy(t *testing.T) { + + strategyTest := newStrategyTest(t) + defer strategyTest.Clean() + + _ = strategyTest.cacheMgn.TestUpdate() + + t.Run("正常创建鉴权策略", func(t *testing.T) { + strategyTest.storage.EXPECT().AddStrategy(gomock.Any(), gomock.Any()).Return(nil) + + valCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, strategyTest.users[0].Token) + strategyId := utils.NewUUID() + + resp := strategyTest.svr.CreateStrategy(valCtx, &apisecurity.AuthStrategy{ + Id: &wrapperspb.StringValue{Value: strategyId}, + Name: &wrapperspb.StringValue{ + Value: "正常创建鉴权策略", + }, + Principals: &apisecurity.Principals{ + Users: []*apisecurity.Principal{{ + Id: &wrapperspb.StringValue{ + Value: strategyTest.users[1].ID, + }, + Name: &wrapperspb.StringValue{ + Value: strategyTest.users[1].Name, + }, + }}, + Groups: []*apisecurity.Principal{}, + }, + Resources: &apisecurity.StrategyResources{ + StrategyId: &wrapperspb.StringValue{ + Value: strategyId, + }, + Namespaces: []*apisecurity.StrategyResourceEntry{}, + Services: []*apisecurity.StrategyResourceEntry{}, + ConfigGroups: []*apisecurity.StrategyResourceEntry{}, + }, + Action: 0, + }) + + assert.Equal(t, api.ExecuteSuccess, resp.Code.GetValue(), resp.Info.GetValue()) + }) + + t.Run("创建鉴权策略-非owner用户发起", func(t *testing.T) { + valCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, strategyTest.users[1].Token) + strategyId := utils.NewUUID() + + resp := strategyTest.svr.CreateStrategy(valCtx, &apisecurity.AuthStrategy{ + Id: &wrapperspb.StringValue{Value: strategyId}, + Name: &wrapperspb.StringValue{ + Value: "创建鉴权策略-非owner用户发起", + }, + Principals: &apisecurity.Principals{ + Users: []*apisecurity.Principal{{ + Id: &wrapperspb.StringValue{ + Value: strategyTest.users[1].ID, + }, + Name: &wrapperspb.StringValue{ + Value: strategyTest.users[1].Name, + }, + }}, + Groups: []*apisecurity.Principal{}, + }, + Resources: &apisecurity.StrategyResources{ + StrategyId: &wrapperspb.StringValue{ + Value: strategyId, + }, + Namespaces: []*apisecurity.StrategyResourceEntry{}, + Services: []*apisecurity.StrategyResourceEntry{}, + ConfigGroups: []*apisecurity.StrategyResourceEntry{}, + }, + Action: 0, + }) + + assert.Equal(t, api.OperationRoleException, resp.Code.GetValue(), resp.Info.GetValue()) + }) + + t.Run("创建鉴权策略-关联用户不存在", func(t *testing.T) { + valCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, strategyTest.users[0].Token) + strategyId := utils.NewUUID() + + resp := strategyTest.svr.CreateStrategy(valCtx, &apisecurity.AuthStrategy{ + Id: &wrapperspb.StringValue{Value: strategyId}, + Name: &wrapperspb.StringValue{ + Value: "创建鉴权策略-关联用户不存在", + }, + Principals: &apisecurity.Principals{ + Users: []*apisecurity.Principal{{ + Id: &wrapperspb.StringValue{ + Value: utils.NewUUID(), + }, + Name: &wrapperspb.StringValue{ + Value: "user-1", + }, + }}, + Groups: []*apisecurity.Principal{}, + }, + Resources: &apisecurity.StrategyResources{ + StrategyId: &wrapperspb.StringValue{ + Value: strategyId, + }, + Namespaces: []*apisecurity.StrategyResourceEntry{}, + Services: []*apisecurity.StrategyResourceEntry{}, + ConfigGroups: []*apisecurity.StrategyResourceEntry{}, + }, + Action: 0, + }) + + assert.Equal(t, api.NotFoundUser, resp.Code.GetValue(), resp.Info.GetValue()) + }) + + t.Run("创建鉴权策略-关联用户组不存在", func(t *testing.T) { + valCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, strategyTest.users[0].Token) + strategyId := utils.NewUUID() + + resp := strategyTest.svr.CreateStrategy(valCtx, &apisecurity.AuthStrategy{ + Id: &wrapperspb.StringValue{Value: strategyId}, + Name: &wrapperspb.StringValue{ + Value: "创建鉴权策略-关联用户组不存在", + }, + Principals: &apisecurity.Principals{ + Groups: []*apisecurity.Principal{{ + Id: &wrapperspb.StringValue{ + Value: utils.NewUUID(), + }, + Name: &wrapperspb.StringValue{ + Value: "user-1", + }, + }}, + }, + Resources: &apisecurity.StrategyResources{ + StrategyId: &wrapperspb.StringValue{ + Value: strategyId, + }, + Namespaces: []*apisecurity.StrategyResourceEntry{}, + Services: []*apisecurity.StrategyResourceEntry{}, + ConfigGroups: []*apisecurity.StrategyResourceEntry{}, + }, + Action: 0, + }) + + assert.Equal(t, api.NotFoundUserGroup, resp.Code.GetValue(), resp.Info.GetValue()) + }) + +} + +func Test_UpdateStrategy(t *testing.T) { + strategyTest := newStrategyTest(t) + defer strategyTest.Clean() + + _ = strategyTest.cacheMgn.TestUpdate() + + t.Run("正常更新鉴权策略", func(t *testing.T) { + strategyTest.storage.EXPECT().GetStrategyDetail(gomock.Any()).Return(strategyTest.strategies[0], nil) + strategyTest.storage.EXPECT().UpdateStrategy(gomock.Any()).Return(nil) + + valCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, strategyTest.users[0].Token) + strategyId := strategyTest.strategies[0].ID + + resp := strategyTest.svr.UpdateStrategies(valCtx, []*apisecurity.ModifyAuthStrategy{ + { + Id: &wrapperspb.StringValue{ + Value: strategyId, + }, + Name: &wrapperspb.StringValue{ + Value: strategyTest.strategies[0].Name, + }, + AddPrincipals: &apisecurity.Principals{ + Users: []*apisecurity.Principal{ + { + Id: &wrapperspb.StringValue{Value: strategyTest.users[2].ID}, + }, + }, + }, + RemovePrincipals: &apisecurity.Principals{ + Users: []*apisecurity.Principal{ + { + Id: &wrapperspb.StringValue{Value: strategyTest.users[3].ID}, + }, + }, + }, + AddResources: &apisecurity.StrategyResources{ + StrategyId: &wrapperspb.StringValue{ + Value: strategyId, + }, + Namespaces: []*apisecurity.StrategyResourceEntry{ + {Id: &wrapperspb.StringValue{Value: strategyTest.namespaces[0].Name}}, + }, + Services: []*apisecurity.StrategyResourceEntry{ + {Id: &wrapperspb.StringValue{Value: strategyTest.services[0].ID}}, + }, + ConfigGroups: []*apisecurity.StrategyResourceEntry{}, + }, + RemoveResources: &apisecurity.StrategyResources{ + StrategyId: &wrapperspb.StringValue{ + Value: strategyId, + }, + Namespaces: []*apisecurity.StrategyResourceEntry{ + {Id: &wrapperspb.StringValue{Value: strategyTest.namespaces[1].Name}}, + }, + Services: []*apisecurity.StrategyResourceEntry{ + {Id: &wrapperspb.StringValue{Value: strategyTest.services[1].ID}}, + }, + ConfigGroups: []*apisecurity.StrategyResourceEntry{}, + }, + }, + }) + + assert.Equal(t, api.ExecuteSuccess, resp.Responses[0].Code.GetValue(), resp.Responses[0].Info.GetValue()) + }) + + t.Run("更新鉴权策略-非owner用户发起", func(t *testing.T) { + valCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, strategyTest.users[1].Token) + strategyId := utils.NewUUID() + + resp := strategyTest.svr.UpdateStrategies(valCtx, []*apisecurity.ModifyAuthStrategy{ + { + Id: &wrapperspb.StringValue{ + Value: strategyId, + }, + Name: &wrapperspb.StringValue{ + Value: strategyTest.strategies[0].Name, + }, + AddPrincipals: &apisecurity.Principals{ + Users: []*apisecurity.Principal{}, + Groups: []*apisecurity.Principal{}, + }, + RemovePrincipals: &apisecurity.Principals{ + Users: []*apisecurity.Principal{}, + Groups: []*apisecurity.Principal{}, + }, + AddResources: &apisecurity.StrategyResources{ + StrategyId: &wrapperspb.StringValue{ + Value: "", + }, + Namespaces: []*apisecurity.StrategyResourceEntry{}, + Services: []*apisecurity.StrategyResourceEntry{}, + ConfigGroups: []*apisecurity.StrategyResourceEntry{}, + }, + RemoveResources: &apisecurity.StrategyResources{ + StrategyId: &wrapperspb.StringValue{ + Value: "", + }, + Namespaces: []*apisecurity.StrategyResourceEntry{}, + Services: []*apisecurity.StrategyResourceEntry{}, + ConfigGroups: []*apisecurity.StrategyResourceEntry{}, + }, + }, + }) + + assert.Equal(t, api.OperationRoleException, resp.Responses[0].Code.GetValue(), resp.Responses[0].Info.GetValue()) + }) + + t.Run("更新鉴权策略-目标策略不存在", func(t *testing.T) { + + strategyTest.storage.EXPECT().GetStrategyDetail(gomock.Any()).Return(nil, nil) + + valCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, strategyTest.users[0].Token) + + strategyId := strategyTest.defaultStrategies[0].ID + + resp := strategyTest.svr.UpdateStrategies(valCtx, []*apisecurity.ModifyAuthStrategy{ + { + Id: &wrapperspb.StringValue{Value: strategyId}, + AddPrincipals: &apisecurity.Principals{ + Users: []*apisecurity.Principal{ + { + Id: &wrapperspb.StringValue{Value: utils.NewUUID()}, + }, + }, + }, + }, + }) + + assert.Equal(t, api.NotFoundAuthStrategyRule, resp.Responses[0].Code.GetValue(), resp.Responses[0].Info.GetValue()) + }) + + t.Run("更新鉴权策略-owner不为自己", func(t *testing.T) { + oldOwner := strategyTest.strategies[2].Owner + + defer func() { + strategyTest.strategies[2].Owner = oldOwner + }() + + strategyTest.strategies[2].Owner = strategyTest.users[2].ID + strategyTest.storage.EXPECT().GetStrategyDetail(gomock.Any()).Return(strategyTest.strategies[2], nil) + valCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, strategyTest.users[0].Token) + strategyId := strategyTest.strategies[2].ID + resp := strategyTest.svr.UpdateStrategies(valCtx, []*apisecurity.ModifyAuthStrategy{ + { + Id: &wrapperspb.StringValue{Value: strategyId}, + AddPrincipals: &apisecurity.Principals{ + Users: []*apisecurity.Principal{ + { + Id: &wrapperspb.StringValue{Value: utils.NewUUID()}, + }, + }, + }, + }, + }) + + assert.Equal(t, api.NotAllowedAccess, resp.Responses[0].Code.GetValue(), resp.Responses[0].Info.GetValue()) + }) + + t.Run("更新鉴权策略-关联用户不存在", func(t *testing.T) { + + strategyTest.storage.EXPECT().GetStrategyDetail(gomock.Any()).Return(strategyTest.strategies[0], nil) + valCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, strategyTest.users[0].Token) + strategyId := strategyTest.strategies[0].ID + resp := strategyTest.svr.UpdateStrategies(valCtx, []*apisecurity.ModifyAuthStrategy{ + { + Id: &wrapperspb.StringValue{Value: strategyId}, + AddPrincipals: &apisecurity.Principals{ + Users: []*apisecurity.Principal{ + { + Id: &wrapperspb.StringValue{Value: utils.NewUUID()}, + }, + }, + }, + }, + }) + + assert.Equal(t, api.NotFoundUser, resp.Responses[0].Code.GetValue(), resp.Responses[0].Info.GetValue()) + }) + + t.Run("更新鉴权策略-关联用户组不存在", func(t *testing.T) { + + strategyTest.storage.EXPECT().GetStrategyDetail(gomock.Any()).Return(strategyTest.strategies[0], nil) + valCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, strategyTest.users[0].Token) + strategyId := strategyTest.strategies[0].ID + resp := strategyTest.svr.UpdateStrategies(valCtx, []*apisecurity.ModifyAuthStrategy{ + { + Id: &wrapperspb.StringValue{Value: strategyId}, + AddPrincipals: &apisecurity.Principals{ + Groups: []*apisecurity.Principal{ + { + Id: &wrapperspb.StringValue{Value: utils.NewUUID()}, + }, + }, + }, + }, + }) + + assert.Equal(t, api.NotFoundUserGroup, resp.Responses[0].Code.GetValue(), resp.Responses[0].Info.GetValue()) + }) + + t.Run("更新默认鉴权策略-不能更改principals成员", func(t *testing.T) { + + strategyTest.storage.EXPECT().GetStrategyDetail(gomock.Any()).Return(strategyTest.defaultStrategies[0], nil) + valCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, strategyTest.users[0].Token) + strategyId := strategyTest.defaultStrategies[0].ID + resp := strategyTest.svr.UpdateStrategies(valCtx, []*apisecurity.ModifyAuthStrategy{ + { + Id: &wrapperspb.StringValue{Value: strategyId}, + AddPrincipals: &apisecurity.Principals{ + Users: []*apisecurity.Principal{ + { + Id: &wrapperspb.StringValue{Value: strategyTest.users[3].ID}, + }, + }, + }, + }, + }) + + assert.Equal(t, api.NotAllowModifyDefaultStrategyPrincipal, resp.Responses[0].Code.GetValue(), resp.Responses[0].Info.GetValue()) + }) + +} + +func Test_DeleteStrategy(t *testing.T) { + strategyTest := newStrategyTest(t) + defer strategyTest.Clean() + + _ = strategyTest.cacheMgn.TestUpdate() + + t.Run("正常删除鉴权策略", func(t *testing.T) { + index := rand.Intn(len(strategyTest.strategies)) + + strategyTest.storage.EXPECT().GetStrategyDetail(gomock.Any()).Return(strategyTest.strategies[index], nil) + strategyTest.storage.EXPECT().DeleteStrategy(gomock.Any()).Return(nil) + + valCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, strategyTest.users[0].Token) + + resp := strategyTest.svr.DeleteStrategies(valCtx, []*apisecurity.AuthStrategy{ + {Id: &wrapperspb.StringValue{Value: strategyTest.strategies[index].ID}}, + }) + + assert.Equal(t, api.ExecuteSuccess, resp.Responses[0].Code.GetValue(), resp.Responses[0].Info.GetValue()) + }) + + t.Run("删除鉴权策略-非owner用户发起", func(t *testing.T) { + valCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, strategyTest.users[1].Token) + + resp := strategyTest.svr.DeleteStrategies(valCtx, []*apisecurity.AuthStrategy{ + {Id: &wrapperspb.StringValue{Value: strategyTest.strategies[rand.Intn(len(strategyTest.strategies))].ID}}, + }) + + assert.Equal(t, api.OperationRoleException, resp.Responses[0].Code.GetValue(), resp.Responses[0].Info.GetValue()) + }) + + t.Run("删除鉴权策略-目标策略不存在", func(t *testing.T) { + + index := rand.Intn(len(strategyTest.strategies)) + strategyTest.storage.EXPECT().GetStrategyDetail(gomock.Any()).Return(nil, nil) + valCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, strategyTest.users[0].Token) + resp := strategyTest.svr.DeleteStrategies(valCtx, []*apisecurity.AuthStrategy{ + {Id: &wrapperspb.StringValue{Value: strategyTest.strategies[index].ID}}, + }) + + assert.Equal(t, api.ExecuteSuccess, resp.Responses[0].Code.GetValue(), resp.Responses[0].Info.GetValue()) + }) + + t.Run("删除鉴权策略-目标为默认鉴权策略", func(t *testing.T) { + index := rand.Intn(len(strategyTest.defaultStrategies)) + + strategyTest.storage.EXPECT().GetStrategyDetail(gomock.Any()).Return(strategyTest.defaultStrategies[index], nil) + valCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, strategyTest.users[0].Token) + resp := strategyTest.svr.DeleteStrategies(valCtx, []*apisecurity.AuthStrategy{ + {Id: &wrapperspb.StringValue{Value: strategyTest.defaultStrategies[index].ID}}, + }) + + assert.Equal(t, api.BadRequest, resp.Responses[0].Code.GetValue(), resp.Responses[0].Info.GetValue()) + }) + + t.Run("删除鉴权策略-目标owner不为自己", func(t *testing.T) { + index := rand.Intn(len(strategyTest.defaultStrategies)) + oldOwner := strategyTest.strategies[index].Owner + + defer func() { + strategyTest.strategies[index].Owner = oldOwner + }() + + strategyTest.strategies[index].Owner = strategyTest.users[2].ID + strategyTest.storage.EXPECT().GetStrategyDetail(gomock.Any()).Return(strategyTest.strategies[index], nil) + valCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, strategyTest.users[0].Token) + resp := strategyTest.svr.DeleteStrategies(valCtx, []*apisecurity.AuthStrategy{ + {Id: &wrapperspb.StringValue{Value: strategyTest.strategies[index].ID}}, + }) + + assert.Equal(t, api.NotAllowedAccess, resp.Responses[0].Code.GetValue(), resp.Responses[0].Info.GetValue()) + }) + +} + +func Test_GetStrategy(t *testing.T) { + strategyTest := newStrategyTest(t) + defer strategyTest.Clean() + + t.Run("正常查询鉴权策略", func(t *testing.T) { + // 主账户查询自己的策略 + strategyTest.storage.EXPECT().GetStrategyDetail(gomock.Any()).Return(strategyTest.strategies[0], nil) + valCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, strategyTest.users[0].Token) + _ = strategyTest.cacheMgn.TestUpdate() + resp := strategyTest.svr.GetStrategy(valCtx, &apisecurity.AuthStrategy{ + Id: &wrapperspb.StringValue{Value: strategyTest.strategies[0].ID}, + }) + assert.Equal(t, api.ExecuteSuccess, resp.Code.GetValue(), resp.Info.GetValue()) + + // 主账户查询自己自账户的策略 + strategyTest.storage.EXPECT().GetStrategyDetail(gomock.Any()).Return(strategyTest.strategies[1], nil) + valCtx = context.WithValue(context.Background(), utils.ContextAuthTokenKey, strategyTest.users[0].Token) + _ = strategyTest.cacheMgn.TestUpdate() + resp = strategyTest.svr.GetStrategy(valCtx, &apisecurity.AuthStrategy{ + Id: &wrapperspb.StringValue{Value: strategyTest.strategies[1].ID}, + }) + assert.Equal(t, api.ExecuteSuccess, resp.Code.GetValue(), resp.Info.GetValue()) + }) + + t.Run("查询鉴权策略-目标owner不为自己", func(t *testing.T) { + t.Skip() + var index int + for { + index = rand.Intn(len(strategyTest.defaultStrategies)) + if index != 2 { + break + } + } + oldOwner := strategyTest.strategies[index].Owner + + defer func() { + strategyTest.strategies[index].Owner = oldOwner + }() + + strategyTest.strategies[index].Owner = strategyTest.users[2].ID + strategyTest.storage.EXPECT().GetStrategyDetail(gomock.Any()).Return(strategyTest.strategies[index], nil) + + valCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, strategyTest.users[0].Token) + + _ = strategyTest.cacheMgn.TestUpdate() + resp := strategyTest.svr.GetStrategy(valCtx, &apisecurity.AuthStrategy{ + Id: &wrapperspb.StringValue{Value: strategyTest.strategies[index].ID}, + }) + + assert.Equal(t, api.NotAllowedAccess, resp.Code.GetValue(), resp.Info.GetValue()) + }) + + t.Run("查询鉴权策略-非owner用户查询自己的", func(t *testing.T) { + + strategyTest.storage.EXPECT().GetStrategyDetail(gomock.Any()).Return(strategyTest.strategies[1], nil) + + valCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, strategyTest.users[1].Token) + + _ = strategyTest.cacheMgn.TestUpdate() + resp := strategyTest.svr.GetStrategy(valCtx, &apisecurity.AuthStrategy{ + Id: &wrapperspb.StringValue{Value: strategyTest.strategies[1].ID}, + }) + + assert.Equal(t, api.ExecuteSuccess, resp.Code.GetValue(), resp.Info.GetValue()) + }) + + t.Run("查询鉴权策略-非owner用户查询自己所在用户组的", func(t *testing.T) { + strategyTest.storage.EXPECT().GetStrategyDetail(gomock.Any()).Return(strategyTest.strategies[len(strategyTest.users)-1+2], nil) + + valCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, strategyTest.users[1].Token) + + _ = strategyTest.cacheMgn.TestUpdate() + resp := strategyTest.svr.GetStrategy(valCtx, &apisecurity.AuthStrategy{ + Id: &wrapperspb.StringValue{Value: strategyTest.strategies[len(strategyTest.users)-1+2].ID}, + }) + + assert.Equal(t, api.ExecuteSuccess, resp.Code.GetValue(), resp.Info.GetValue()) + }) + + t.Run("查询鉴权策略-非owner用户查询别人的", func(t *testing.T) { + strategyTest.storage.EXPECT().GetStrategyDetail(gomock.Any()).Return(strategyTest.strategies[2], nil) + + valCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, strategyTest.users[1].Token) + + _ = strategyTest.cacheMgn.TestUpdate() + resp := strategyTest.svr.GetStrategy(valCtx, &apisecurity.AuthStrategy{ + Id: &wrapperspb.StringValue{Value: strategyTest.strategies[2].ID}, + }) + + assert.Equal(t, api.NotAllowedAccess, resp.Code.GetValue(), resp.Info.GetValue()) + }) + + t.Run("查询鉴权策略-目标策略不存在", func(t *testing.T) { + strategyTest.storage.EXPECT().GetStrategyDetail(gomock.Any()).Return(nil, nil) + + valCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, strategyTest.users[1].Token) + + _ = strategyTest.cacheMgn.TestUpdate() + resp := strategyTest.svr.GetStrategy(valCtx, &apisecurity.AuthStrategy{ + Id: &wrapperspb.StringValue{Value: utils.NewUUID()}, + }) + + assert.Equal(t, api.NotFoundAuthStrategyRule, resp.Code.GetValue(), resp.Info.GetValue()) + }) + +} + +func Test_parseStrategySearchArgs(t *testing.T) { + type args struct { + ctx context.Context + searchFilters map[string]string + } + tests := []struct { + name string + args args + want map[string]string + }{ + { + name: "res_type(namespace) 查询", + args: args{ + ctx: func() context.Context { + ctx := context.WithValue(context.Background(), utils.ContextOwnerIDKey, "owner") + ctx = context.WithValue(ctx, utils.ContextUserIDKey, "user") + ctx = context.WithValue(ctx, utils.ContextUserRoleIDKey, authcommon.OwnerUserRole) + ctx = context.WithValue(ctx, utils.ContextIsOwnerKey, true) + return ctx + }(), + searchFilters: map[string]string{ + "res_type": "namespace", + }, + }, + want: map[string]string{ + "res_type": "0", + "owner": "owner", + }, + }, + { + name: "res_type(service) 查询", + args: args{ + ctx: func() context.Context { + ctx := context.WithValue(context.Background(), utils.ContextOwnerIDKey, "owner") + ctx = context.WithValue(ctx, utils.ContextUserIDKey, "user") + ctx = context.WithValue(ctx, utils.ContextUserRoleIDKey, authcommon.OwnerUserRole) + ctx = context.WithValue(ctx, utils.ContextIsOwnerKey, true) + return ctx + }(), + searchFilters: map[string]string{ + "res_type": "service", + }, + }, + want: map[string]string{ + "res_type": "1", + "owner": "owner", + }, + }, + { + name: "principal_type(user) 查询", + args: args{ + ctx: func() context.Context { + ctx := context.WithValue(context.Background(), utils.ContextOwnerIDKey, "owner") + ctx = context.WithValue(ctx, utils.ContextUserIDKey, "user") + ctx = context.WithValue(ctx, utils.ContextUserRoleIDKey, authcommon.SubAccountUserRole) + ctx = context.WithValue(ctx, utils.ContextIsOwnerKey, false) + return ctx + }(), + searchFilters: map[string]string{ + "principal_type": "user", + }, + }, + want: map[string]string{ + "principal_type": "1", + "owner": "owner", + "principal_id": "user", + }, + }, + { + name: "principal_type(group) 查询", + args: args{ + ctx: func() context.Context { + ctx := context.WithValue(context.Background(), utils.ContextOwnerIDKey, "owner") + ctx = context.WithValue(ctx, utils.ContextUserIDKey, "user") + ctx = context.WithValue(ctx, utils.ContextUserRoleIDKey, authcommon.OwnerUserRole) + ctx = context.WithValue(ctx, utils.ContextIsOwnerKey, true) + return ctx + }(), + searchFilters: map[string]string{ + "principal_type": "group", + }, + }, + want: map[string]string{ + "principal_type": "2", + "owner": "owner", + }, + }, + { + name: "按照资源ID查询", + args: args{ + ctx: func() context.Context { + ctx := context.WithValue(context.Background(), utils.ContextOwnerIDKey, "owner") + ctx = context.WithValue(ctx, utils.ContextUserIDKey, "user") + ctx = context.WithValue(ctx, utils.ContextUserRoleIDKey, authcommon.OwnerUserRole) + ctx = context.WithValue(ctx, utils.ContextIsOwnerKey, true) + return ctx + }(), + searchFilters: map[string]string{ + "res_id": "res_id", + }, + }, + want: map[string]string{ + "res_id": "res_id", + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := policy.ParseStrategySearchArgs(tt.args.ctx, tt.args.searchFilters); !reflect.DeepEqual(got, tt.want) { + t.Errorf("parseStrategySearchArgs() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_AuthServer_NormalOperateStrategy(t *testing.T) { + suit := &AuthTestSuit{} + if err := suit.Initialize(); err != nil { + t.Fatal(err) + } + t.Cleanup(func() { + suit.cleanAllAuthStrategy() + suit.cleanAllUser() + suit.cleanAllUserGroup() + suit.Destroy() + }) + + users := createApiMockUser(10, "test") + + t.Run("正常创建用户", func(t *testing.T) { + resp := suit.UserServer().CreateUsers(suit.DefaultCtx, users) + + if !respSuccess(resp) { + t.Fatal(resp.GetInfo().GetValue()) + } + }) + + t.Run("正常更新用户", func(t *testing.T) { + users[0].Comment = utils.NewStringValue("update user comment") + resp := suit.UserServer().UpdateUser(suit.DefaultCtx, users[0]) + + if !respSuccess(resp) { + t.Fatal(resp.GetInfo().GetValue()) + } + + qresp := suit.UserServer().GetUsers(suit.DefaultCtx, map[string]string{ + "id": users[0].GetId().GetValue(), + }) + + if !respSuccess(resp) { + t.Fatal(resp.GetInfo().GetValue()) + } + + assert.Equal(t, 1, int(qresp.Amount.GetValue())) + assert.Equal(t, 1, int(qresp.Size.GetValue())) + + retUsers := qresp.GetUsers()[0] + assert.Equal(t, users[0].GetComment().GetValue(), retUsers.GetComment().GetValue()) + }) + + t.Run("正常删除用户", func(t *testing.T) { + resp := suit.UserServer().DeleteUsers(suit.DefaultCtx, []*apisecurity.User{users[3]}) + + if !respSuccess(resp) { + t.Fatal(resp.GetInfo().GetValue()) + } + + qresp := suit.UserServer().GetUsers(suit.DefaultCtx, map[string]string{ + "id": users[3].GetId().GetValue(), + }) + + if !respSuccess(resp) { + t.Fatal(resp.GetInfo().GetValue()) + } + + assert.Equal(t, 0, int(qresp.Amount.GetValue())) + assert.Equal(t, 0, int(qresp.Size.GetValue())) + }) + + t.Run("正常更新用户Token", func(t *testing.T) { + resp := suit.UserServer().ResetUserToken(suit.DefaultCtx, users[0]) + if !api.IsSuccess(resp) { + t.Fatal(resp.GetInfo().GetValue()) + } + + _ = suit.CacheMgr().TestUpdate() + + qresp := suit.UserServer().GetUserToken(suit.DefaultCtx, users[0]) + if !api.IsSuccess(qresp) { + t.Fatal(qresp.String()) + } + assert.Equal(t, resp.GetUser().GetAuthToken().GetValue(), qresp.GetUser().GetAuthToken().GetValue()) + }) +} diff --git a/auth/user/common_test.go b/auth/user/common_test.go new file mode 100644 index 000000000..a5479c62a --- /dev/null +++ b/auth/user/common_test.go @@ -0,0 +1,407 @@ +/** + * Tencent is pleased to support the open source community by making Polaris available. + * + * Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. + * + * Licensed under the BSD 3-Clause License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://opensource.org/licenses/BSD-3-Clause + * + * Unless required by applicable law or agreed to in writing, software distributed + * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR + * CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package defaultuser_test + +import ( + "fmt" + "time" + + "github.com/golang/mock/gomock" + "github.com/google/uuid" + apisecurity "github.com/polarismesh/specification/source/go/api/v1/security" + "github.com/polarismesh/specification/source/go/api/v1/service_manage" + "golang.org/x/crypto/bcrypt" + "google.golang.org/protobuf/types/known/wrapperspb" + + defaultuser "github.com/polarismesh/polaris/auth/user" + "github.com/polarismesh/polaris/cache" + "github.com/polarismesh/polaris/common/metrics" + "github.com/polarismesh/polaris/common/model" + authcommon "github.com/polarismesh/polaris/common/model/auth" + "github.com/polarismesh/polaris/common/utils" + storemock "github.com/polarismesh/polaris/store/mock" +) + +var ( + _defaultSalt = "polarismesh@2021" +) + +func reset(strict bool) { +} + +func initCache(ctrl *gomock.Controller) (*cache.Config, *storemock.MockStore) { + metrics.InitMetrics() + /* + - name: service # 加载服务数据 + option: + disableBusiness: false # 不加载业务服务 + needMeta: true # 加载服务元数据 + - name: instance # 加载实例数据 + option: + disableBusiness: false # 不加载业务服务实例 + needMeta: true # 加载实例元数据 + - name: routingConfig # 加载路由数据 + - name: rateLimitConfig # 加载限流数据 + - name: circuitBreakerConfig # 加载熔断数据 + - name: l5 # 加载l5数据 + - name: users + - name: strategyRule + - name: namespace + */ + cfg := &cache.Config{} + storage := storemock.NewMockStore(ctrl) + + mockTx := storemock.NewMockTx(ctrl) + mockTx.EXPECT().Commit().Return(nil).AnyTimes() + mockTx.EXPECT().Rollback().Return(nil).AnyTimes() + mockTx.EXPECT().CreateReadView().Return(nil).AnyTimes() + + storage.EXPECT().StartReadTx().Return(mockTx, nil).AnyTimes() + storage.EXPECT().GetServicesCount().AnyTimes().Return(uint32(1), nil) + storage.EXPECT().GetInstancesCountTx(gomock.Any()).AnyTimes().Return(uint32(1), nil) + storage.EXPECT().GetMoreInstances(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(map[string]*model.Instance{ + "123": { + Proto: &service_manage.Instance{ + Id: wrapperspb.String(uuid.NewString()), + Host: wrapperspb.String("127.0.0.1"), + Port: wrapperspb.UInt32(8080), + }, + Valid: true, + }, + }, nil).AnyTimes() + storage.EXPECT().GetUnixSecond(gomock.Any()).AnyTimes().Return(time.Now().Unix(), nil) + + return cfg, storage +} + +func createMockNamespace(total int, owner string) []*model.Namespace { + namespaces := make([]*model.Namespace, 0, total) + + for i := 0; i < total; i++ { + namespaces = append(namespaces, &model.Namespace{ + Name: fmt.Sprintf("namespace_%d", i), + Owner: owner, + Valid: true, + }) + } + + return namespaces +} + +func createMockService(namespaces []*model.Namespace) []*model.Service { + services := make([]*model.Service, 0, len(namespaces)) + + for i := 0; i < len(namespaces); i++ { + ns := namespaces[i] + services = append(services, &model.Service{ + ID: utils.NewUUID(), + Namespace: ns.Name, + Owner: ns.Owner, + Name: fmt.Sprintf("service_%d", i), + Valid: true, + }) + } + + return services +} + +// createMockUser 默认 users[0] 为 owner 用户 +func createMockUser(total int, prefix ...string) []*authcommon.User { + users := make([]*authcommon.User, 0, total) + + ownerId := utils.NewUUID() + + nameTemp := "user-%d" + if len(prefix) != 0 { + nameTemp = prefix[0] + nameTemp + } + + for i := 0; i < total; i++ { + id := fmt.Sprintf("fake-user-id-%d-%s", i, utils.NewUUID()) + if i == 0 { + id = ownerId + } + pwd, _ := bcrypt.GenerateFromPassword([]byte("polaris"), bcrypt.DefaultCost) + token, _ := defaultuser.CreateToken(id, "", "polarismesh@2021") + users = append(users, &authcommon.User{ + ID: id, + Name: fmt.Sprintf(nameTemp, i), + Password: string(pwd), + Owner: func() string { + if id == ownerId { + return "" + } + return ownerId + }(), + Source: "Polaris", + Mobile: "", + Email: "", + Type: func() authcommon.UserRoleType { + if id == ownerId { + return authcommon.OwnerUserRole + } + return authcommon.SubAccountUserRole + }(), + Token: token, + TokenEnable: true, + Valid: true, + CreateTime: time.Time{}, + ModifyTime: time.Time{}, + }) + } + return users +} + +func createApiMockUser(total int, prefix ...string) []*apisecurity.User { + users := make([]*apisecurity.User, 0, total) + + models := createMockUser(total, prefix...) + + for i := range models { + users = append(users, &apisecurity.User{ + Name: utils.NewStringValue("test-" + models[i].Name), + Password: utils.NewStringValue("123456"), + Source: utils.NewStringValue("Polaris"), + Comment: utils.NewStringValue(models[i].Comment), + Mobile: utils.NewStringValue(models[i].Mobile), + Email: utils.NewStringValue(models[i].Email), + }) + } + + return users +} + +func createMockUserGroup(users []*authcommon.User) []*authcommon.UserGroupDetail { + groups := make([]*authcommon.UserGroupDetail, 0, len(users)) + + for i := range users { + user := users[i] + id := utils.NewUUID() + + token, _ := defaultuser.CreateToken("", id, _defaultSalt) + + groups = append(groups, &authcommon.UserGroupDetail{ + UserGroup: &authcommon.UserGroup{ + ID: id, + Name: fmt.Sprintf("test-group-%d", i), + Owner: users[0].ID, + Token: token, + TokenEnable: true, + Valid: true, + Comment: "", + CreateTime: time.Time{}, + ModifyTime: time.Time{}, + }, + UserIds: map[string]struct{}{ + user.ID: {}, + }, + }) + } + + return groups +} + +// createMockApiUserGroup +func createMockApiUserGroup(users []*apisecurity.User) []*apisecurity.UserGroup { + musers := make([]*authcommon.User, 0, len(users)) + for i := range users { + musers = append(musers, &authcommon.User{ + ID: users[i].GetId().GetValue(), + }) + } + + models := createMockUserGroup(musers) + ret := make([]*apisecurity.UserGroup, 0, len(models)) + + for i := range models { + ret = append(ret, &apisecurity.UserGroup{ + Name: utils.NewStringValue(models[i].Name), + Comment: utils.NewStringValue(models[i].Comment), + Relation: &apisecurity.UserGroupRelation{ + Users: []*apisecurity.User{ + { + Id: utils.NewStringValue(users[i].GetId().GetValue()), + }, + }, + }, + }) + } + + return ret +} + +func createMockStrategy(users []*authcommon.User, groups []*authcommon.UserGroupDetail, services []*model.Service) ([]*authcommon.StrategyDetail, []*authcommon.StrategyDetail) { + strategies := make([]*authcommon.StrategyDetail, 0, len(users)+len(groups)) + defaultStrategies := make([]*authcommon.StrategyDetail, 0, len(users)+len(groups)) + + owner := "" + for i := 0; i < len(users); i++ { + user := users[i] + if user.Owner == "" { + owner = user.ID + break + } + } + + for i := 0; i < len(users); i++ { + user := users[i] + service := services[i] + id := utils.NewUUID() + strategies = append(strategies, &authcommon.StrategyDetail{ + ID: id, + Name: fmt.Sprintf("strategy_user_%s_%d", user.Name, i), + Action: apisecurity.AuthAction_READ_WRITE.String(), + Comment: "", + Principals: []authcommon.Principal{ + { + PrincipalID: user.ID, + PrincipalType: authcommon.PrincipalUser, + }, + }, + Default: false, + Owner: owner, + Resources: []authcommon.StrategyResource{ + { + StrategyID: id, + ResType: int32(apisecurity.ResourceType_Namespaces), + ResID: service.Namespace, + }, + { + StrategyID: id, + ResType: int32(apisecurity.ResourceType_Services), + ResID: service.ID, + }, + }, + Valid: true, + Revision: utils.NewUUID(), + CreateTime: time.Time{}, + ModifyTime: time.Time{}, + }) + + defaultStrategies = append(defaultStrategies, &authcommon.StrategyDetail{ + ID: id, + Name: fmt.Sprintf("strategy_default_user_%s_%d", user.Name, i), + Action: apisecurity.AuthAction_READ_WRITE.String(), + Comment: "", + Principals: []authcommon.Principal{ + { + PrincipalID: user.ID, + PrincipalType: authcommon.PrincipalUser, + }, + }, + Default: true, + Owner: owner, + Resources: []authcommon.StrategyResource{ + { + StrategyID: id, + ResType: int32(apisecurity.ResourceType_Namespaces), + ResID: service.Namespace, + }, + { + StrategyID: id, + ResType: int32(apisecurity.ResourceType_Services), + ResID: service.ID, + }, + }, + Valid: true, + Revision: utils.NewUUID(), + CreateTime: time.Time{}, + ModifyTime: time.Time{}, + }) + } + + for i := 0; i < len(groups); i++ { + group := groups[i] + service := services[len(users)+i] + id := utils.NewUUID() + strategies = append(strategies, &authcommon.StrategyDetail{ + ID: id, + Name: fmt.Sprintf("strategy_group_%s_%d", group.Name, i), + Action: apisecurity.AuthAction_READ_WRITE.String(), + Comment: "", + Principals: []authcommon.Principal{ + { + PrincipalID: group.ID, + PrincipalType: authcommon.PrincipalGroup, + }, + }, + Default: false, + Owner: owner, + Resources: []authcommon.StrategyResource{ + { + StrategyID: id, + ResType: int32(apisecurity.ResourceType_Namespaces), + ResID: service.Namespace, + }, + { + StrategyID: id, + ResType: int32(apisecurity.ResourceType_Services), + ResID: service.ID, + }, + }, + Valid: true, + Revision: utils.NewUUID(), + CreateTime: time.Time{}, + ModifyTime: time.Time{}, + }) + + defaultStrategies = append(defaultStrategies, &authcommon.StrategyDetail{ + ID: id, + Name: fmt.Sprintf("strategy_default_group_%s_%d", group.Name, i), + Action: apisecurity.AuthAction_READ_WRITE.String(), + Comment: "", + Principals: []authcommon.Principal{ + { + PrincipalID: group.ID, + PrincipalType: authcommon.PrincipalGroup, + }, + }, + Default: true, + Owner: owner, + Resources: []authcommon.StrategyResource{ + { + StrategyID: id, + ResType: int32(apisecurity.ResourceType_Namespaces), + ResID: service.Namespace, + }, + { + StrategyID: id, + ResType: int32(apisecurity.ResourceType_Services), + ResID: service.ID, + }, + }, + Valid: true, + Revision: utils.NewUUID(), + CreateTime: time.Time{}, + ModifyTime: time.Time{}, + }) + } + + return defaultStrategies, strategies +} + +func convertServiceSliceToMap(services []*model.Service) map[string]*model.Service { + ret := make(map[string]*model.Service) + + for i := range services { + service := services[i] + ret[service.ID] = service + } + + return ret +} diff --git a/auth/user/group.go b/auth/user/group.go index 18045a3dd..15dbbe487 100644 --- a/auth/user/group.go +++ b/auth/user/group.go @@ -282,6 +282,8 @@ func (svr *Server) EnableGroupToken(ctx context.Context, req *apisecurity.UserGr // ResetGroupToken 刷新用户组的token func (svr *Server) ResetGroupToken(ctx context.Context, req *apisecurity.UserGroup) *apiservice.Response { var ( + requestID = utils.ParseRequestID(ctx) + platformID = utils.ParsePlatformID(ctx) group, errResp = svr.getGroupFromDB(req.Id.GetValue()) ) @@ -295,7 +297,8 @@ func (svr *Server) ResetGroupToken(ctx context.Context, req *apisecurity.UserGro newToken, err := createGroupToken(group.ID, svr.authOpt.Salt) if err != nil { - log.Error("reset group token", utils.RequestID(ctx), zap.Error(err)) + log.Error("reset group token", utils.ZapRequestID(requestID), + utils.ZapPlatformID(platformID), zap.Error(err)) return api.NewAuthResponseWithMsg(apimodel.Code_ExecuteException, err.Error()) } @@ -309,12 +312,12 @@ func (svr *Server) ResetGroupToken(ctx context.Context, req *apisecurity.UserGro } if err := svr.storage.UpdateGroup(modifyReq); err != nil { - log.Error(err.Error(), utils.RequestID(ctx)) + log.Error(err.Error(), utils.ZapRequestID(requestID), utils.ZapPlatformID(platformID)) return api.NewAuthResponseWithMsg(commonstore.StoreCode2APICode(err), err.Error()) } log.Info("reset group token", zap.String("group-id", req.Id.GetValue()), - utils.RequestID(ctx)) + utils.ZapRequestID(requestID), utils.ZapPlatformID(platformID)) svr.RecordHistory(userGroupRecordEntry(ctx, req, group.UserGroup, model.OUpdate)) req.AuthToken = utils.NewStringValue(newToken) diff --git a/auth/user/group_test.go b/auth/user/group_test.go new file mode 100644 index 000000000..087a95439 --- /dev/null +++ b/auth/user/group_test.go @@ -0,0 +1,822 @@ +/** + * Tencent is pleased to support the open source community by making Polaris available. + * + * Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. + * + * Licensed under the BSD 3-Clause License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://opensource.org/licenses/BSD-3-Clause + * + * Unless required by applicable law or agreed to in writing, software distributed + * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR + * CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package defaultuser_test + +import ( + "context" + "testing" + "time" + + "github.com/golang/mock/gomock" + apisecurity "github.com/polarismesh/specification/source/go/api/v1/security" + "github.com/stretchr/testify/assert" + "google.golang.org/protobuf/types/known/wrapperspb" + + "github.com/polarismesh/polaris/auth" + defaultauth "github.com/polarismesh/polaris/auth/user" + "github.com/polarismesh/polaris/cache" + cachetypes "github.com/polarismesh/polaris/cache/api" + v1 "github.com/polarismesh/polaris/common/api/v1" + authcommon "github.com/polarismesh/polaris/common/model/auth" + "github.com/polarismesh/polaris/common/utils" + storemock "github.com/polarismesh/polaris/store/mock" +) + +type GroupTest struct { + ctrl *gomock.Controller + + ownerOne *authcommon.User + ownerTwo *authcommon.User + + users []*authcommon.User + groups []*authcommon.UserGroupDetail + newGroups []*authcommon.UserGroupDetail + allGroups []*authcommon.UserGroupDetail + + storage *storemock.MockStore + cacheMgn *cache.CacheManager + checker auth.AuthChecker + cancel context.CancelFunc + + svr auth.UserServer +} + +func newGroupTest(t *testing.T) *GroupTest { + reset(false) + ctrl := gomock.NewController(t) + + users := createMockUser(10) + groups := createMockUserGroup(users) + + newUsers := createMockUser(10) + newGroups := createMockUserGroup(newUsers) + + allGroups := append(groups, newGroups...) + + storage := storemock.NewMockStore(ctrl) + + storage.EXPECT().GetServicesCount().AnyTimes().Return(uint32(1), nil) + storage.EXPECT().GetUnixSecond(gomock.Any()).AnyTimes().Return(time.Now().Unix(), nil) + storage.EXPECT().AddGroup(gomock.Any(), gomock.Any()).AnyTimes().Return(nil) + storage.EXPECT().UpdateUser(gomock.Any()).AnyTimes().Return(nil) + storage.EXPECT().GetUsersForCache(gomock.Any(), gomock.Any()).AnyTimes().Return(append(users, newUsers...), nil) + storage.EXPECT().GetGroupsForCache(gomock.Any(), gomock.Any()).AnyTimes().Return(allGroups, nil) + + cfg := &cache.Config{} + + ctx, cancel := context.WithCancel(context.Background()) + cacheMgn, err := cache.TestCacheInitialize(ctx, cfg, storage) + if err != nil { + t.Error(err) + } + _ = cacheMgn.OpenResourceCache([]cachetypes.ConfigEntry{ + { + Name: cachetypes.UsersName, + }, + }...) + t.Cleanup(func() { + _ = cacheMgn.Close() + }) + + _, proxySvr, err := defaultauth.BuildServer() + if err != nil { + t.Fatal(err) + } + proxySvr.Initialize(&auth.Config{ + User: &auth.UserConfig{ + Name: auth.DefaultUserMgnPluginName, + Option: map[string]interface{}{ + "salt": "polarismesh@2021", + }, + }, + }, storage, nil, cacheMgn) + _ = cacheMgn.TestUpdate() + return &GroupTest{ + ctrl: ctrl, + ownerOne: users[0], + ownerTwo: newUsers[0], + users: users, + groups: groups, + newGroups: newGroups, + allGroups: allGroups, + storage: storage, + cacheMgn: cacheMgn, + cancel: cancel, + svr: proxySvr, + } +} + +func (g *GroupTest) Clean() { + g.cancel() + g.cacheMgn.Close() + g.ctrl.Finish() +} + +func Test_server_CreateGroup(t *testing.T) { + groupTest := newGroupTest(t) + + defer groupTest.Clean() + + t.Run("正常创建用户组", func(t *testing.T) { + reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, groupTest.users[0].Token) + + groups := createMockUserGroup(groupTest.users[:1]) + groups[0].ID = utils.NewUUID() + + groupTest.storage.EXPECT().GetGroupByName(gomock.Any(), gomock.Any()).Return(nil, nil) + + resp := groupTest.svr.CreateGroup(reqCtx, &apisecurity.UserGroup{ + Id: utils.NewStringValue(groups[0].ID), + Name: utils.NewStringValue(groups[0].Name), + }) + + assert.True(t, resp.GetCode().Value == v1.ExecuteSuccess, resp.Info.GetValue()) + }) + + t.Run("用户组已存在", func(t *testing.T) { + reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, groupTest.users[0].Token) + + groups := createMockUserGroup(groupTest.users[:1]) + groups[0].ID = utils.NewUUID() + + groupTest.storage.EXPECT().GetGroupByName(gomock.Any(), gomock.Any()).Return(groups[0].UserGroup, nil) + + resp := groupTest.svr.CreateGroup(reqCtx, &apisecurity.UserGroup{ + Id: utils.NewStringValue(groups[0].ID), + Name: utils.NewStringValue(groups[0].Name), + }) + + assert.True(t, resp.GetCode().Value == v1.UserGroupExisted, resp.Info.GetValue()) + }) + + t.Run("子用户去创建用户组", func(t *testing.T) { + reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, groupTest.users[1].Token) + + groups := createMockUserGroup(groupTest.users[:1]) + groups[0].ID = utils.NewUUID() + + resp := groupTest.svr.CreateGroup(reqCtx, &apisecurity.UserGroup{ + Id: utils.NewStringValue(groups[0].ID), + Name: utils.NewStringValue(groups[0].Name), + }) + + assert.True(t, resp.GetCode().Value == v1.OperationRoleException, resp.Info.GetValue()) + }) + + t.Run("主账户去查询owner为自己的用户组", func(t *testing.T) { + reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, groupTest.users[0].Token) + + groupTest.storage.EXPECT().GetGroup(gomock.Any()).Return(groupTest.groups[1], nil) + + resp := groupTest.svr.GetGroup(reqCtx, &apisecurity.UserGroup{ + Id: utils.NewStringValue(groupTest.groups[1].ID), + }) + + assert.True(t, resp.GetCode().Value == v1.ExecuteSuccess, resp.Info.GetValue()) + }) + + t.Run("主账户去查询owner不是自己的用户组", func(t *testing.T) { + reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, groupTest.ownerTwo.Token) + + groupTest.storage.EXPECT().GetGroup(gomock.Any()).Return(groupTest.groups[3], nil) + + resp := groupTest.svr.GetGroup(reqCtx, &apisecurity.UserGroup{ + Id: utils.NewStringValue(groupTest.groups[3].ID), + }) + + assert.True(t, resp.GetCode().Value == v1.NotAllowedAccess, resp.Info.GetValue()) + }) + + t.Run("子账户去查询自己所在的用户组", func(t *testing.T) { + reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, groupTest.users[1].Token) + + groupTest.storage.EXPECT().GetGroup(gomock.Any()).Return(groupTest.groups[1], nil).AnyTimes() + + resp := groupTest.svr.GetGroup(reqCtx, &apisecurity.UserGroup{ + Id: utils.NewStringValue(groupTest.groups[1].ID), + }) + + assert.True(t, resp.GetCode().Value == v1.ExecuteSuccess, resp.Info.GetValue()) + }) + + t.Run("子账户去查询自己不在的用户组", func(t *testing.T) { + reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, groupTest.users[1].Token) + + groupTest.storage.EXPECT().GetGroup(gomock.Any()).Return(groupTest.groups[2], nil).AnyTimes() + + resp := groupTest.svr.GetGroup(reqCtx, &apisecurity.UserGroup{ + Id: utils.NewStringValue(groupTest.groups[2].ID), + }) + + assert.True(t, resp.GetCode().Value == v1.NotAllowedAccess, resp.Info.GetValue()) + }) +} + +func Test_server_GetGroup(t *testing.T) { + groupTest := newGroupTest(t) + + defer groupTest.Clean() + t.Run("主账户去查询owner为自己的用户组", func(t *testing.T) { + reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, groupTest.users[0].Token) + + groupTest.storage.EXPECT().GetGroup(gomock.Any()).Return(groupTest.groups[1], nil) + + resp := groupTest.svr.GetGroup(reqCtx, &apisecurity.UserGroup{ + Id: utils.NewStringValue(groupTest.groups[1].ID), + }) + + assert.True(t, resp.GetCode().Value == v1.ExecuteSuccess, resp.Info.GetValue()) + }) + + t.Run("主账户去查询owner不是自己的用户组", func(t *testing.T) { + reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, groupTest.ownerOne.Token) + + groupTest.storage.EXPECT().GetGroup(gomock.Any()).Return(groupTest.newGroups[0], nil) + + resp := groupTest.svr.GetGroup(reqCtx, &apisecurity.UserGroup{ + Id: utils.NewStringValue(groupTest.newGroups[0].ID), + }) + + assert.True(t, resp.GetCode().Value == v1.NotAllowedAccess, resp.Info.GetValue()) + }) + + t.Run("子账户去查询自己所在的用户组", func(t *testing.T) { + reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, groupTest.users[1].Token) + + groupTest.storage.EXPECT().GetGroup(gomock.Any()).Return(groupTest.groups[1], nil).AnyTimes() + + resp := groupTest.svr.GetGroup(reqCtx, &apisecurity.UserGroup{ + Id: utils.NewStringValue(groupTest.groups[1].ID), + }) + + assert.True(t, resp.GetCode().Value == v1.ExecuteSuccess, resp.Info.GetValue()) + }) + + t.Run("子账户去查询自己不在的用户组", func(t *testing.T) { + reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, groupTest.users[1].Token) + + groupTest.storage.EXPECT().GetGroup(gomock.Any()).Return(groupTest.groups[2], nil).AnyTimes() + + resp := groupTest.svr.GetGroup(reqCtx, &apisecurity.UserGroup{ + Id: utils.NewStringValue(groupTest.groups[2].ID), + }) + + assert.True(t, resp.GetCode().Value == v1.NotAllowedAccess, resp.Info.GetValue()) + }) +} + +func Test_server_UpdateGroup(t *testing.T) { + t.Run("主账户更新用户组", func(t *testing.T) { + groupTest := newGroupTest(t) + + t.Cleanup(func() { + groupTest.Clean() + }) + reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, groupTest.ownerOne.Token) + + groupTest.storage.EXPECT().GetGroup(gomock.Any()).Return(groupTest.groups[1], nil) + groupTest.storage.EXPECT().UpdateGroup(gomock.Any()).Return(nil) + + req := &apisecurity.ModifyUserGroup{ + Id: utils.NewStringValue(groupTest.groups[1].ID), + Comment: &wrapperspb.StringValue{ + Value: "new test group", + }, + AddRelations: &apisecurity.UserGroupRelation{ + GroupId: utils.NewStringValue(groupTest.groups[1].ID), + Users: []*apisecurity.User{ + { + Id: utils.NewStringValue(groupTest.users[2].ID), + }, + { + Id: utils.NewStringValue(groupTest.users[3].ID), + }, + }, + }, + RemoveRelations: &apisecurity.UserGroupRelation{ + GroupId: utils.NewStringValue(groupTest.groups[1].ID), + Users: []*apisecurity.User{ + { + Id: utils.NewStringValue(groupTest.users[5].ID), + }, + }, + }, + } + + resp := groupTest.svr.UpdateGroups(reqCtx, []*apisecurity.ModifyUserGroup{req}) + + assert.True(t, resp.Responses[0].Code.GetValue() == v1.ExecuteSuccess, resp.Responses[0].Info.GetValue()) + }) + + t.Run("主账户更新不是自己负责的用户组", func(t *testing.T) { + groupTest := newGroupTest(t) + + t.Cleanup(func() { + groupTest.Clean() + }) + reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, groupTest.ownerOne.Token) + + groupTest.storage.EXPECT().GetGroup(gomock.Any()).AnyTimes().Return(groupTest.newGroups[1], nil) + + req := &apisecurity.ModifyUserGroup{ + Id: utils.NewStringValue(groupTest.newGroups[0].ID), + Comment: &wrapperspb.StringValue{ + Value: "new test group", + }, + AddRelations: &apisecurity.UserGroupRelation{ + GroupId: utils.NewStringValue(groupTest.groups[0].ID), + Users: []*apisecurity.User{ + { + Id: utils.NewStringValue(groupTest.users[2].ID), + }, + { + Id: utils.NewStringValue(groupTest.users[3].ID), + }, + }, + }, + RemoveRelations: &apisecurity.UserGroupRelation{ + GroupId: utils.NewStringValue(groupTest.groups[0].ID), + Users: []*apisecurity.User{ + { + Id: utils.NewStringValue(groupTest.users[5].ID), + }, + }, + }, + } + + resp := groupTest.svr.UpdateGroups(reqCtx, []*apisecurity.ModifyUserGroup{req}) + assert.True(t, resp.Responses[0].Code.GetValue() == v1.NotAllowedAccess, resp.Responses[0].Info.GetValue()) + }) + + t.Run("子账户更新用户组", func(t *testing.T) { + groupTest := newGroupTest(t) + + t.Cleanup(func() { + groupTest.Clean() + }) + reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, groupTest.users[1].Token) + + req := &apisecurity.ModifyUserGroup{ + Id: utils.NewStringValue(groupTest.groups[2].ID), + Comment: &wrapperspb.StringValue{ + Value: "new test group", + }, + AddRelations: &apisecurity.UserGroupRelation{ + GroupId: utils.NewStringValue(groupTest.groups[2].ID), + Users: []*apisecurity.User{ + { + Id: utils.NewStringValue(groupTest.users[2].ID), + }, + { + Id: utils.NewStringValue(groupTest.users[3].ID), + }, + }, + }, + RemoveRelations: &apisecurity.UserGroupRelation{ + GroupId: utils.NewStringValue(groupTest.groups[2].ID), + Users: []*apisecurity.User{ + { + Id: utils.NewStringValue(groupTest.users[5].ID), + }, + }, + }, + } + + resp := groupTest.svr.UpdateGroups(reqCtx, []*apisecurity.ModifyUserGroup{req}) + assert.True(t, resp.Responses[0].GetCode().Value == v1.OperationRoleException, resp.Responses[0].Info.GetValue()) + }) + + t.Run("更新用户组-啥都没用动过", func(t *testing.T) { + groupTest := newGroupTest(t) + + t.Cleanup(func() { + groupTest.Clean() + }) + reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, groupTest.users[0].Token) + groupTest.storage.EXPECT().GetGroup(gomock.Any()).Return(groupTest.groups[2], nil) + + req := &apisecurity.ModifyUserGroup{ + Id: utils.NewStringValue(groupTest.groups[2].ID), + Comment: &wrapperspb.StringValue{Value: groupTest.groups[2].Comment}, + AddRelations: &apisecurity.UserGroupRelation{}, + RemoveRelations: &apisecurity.UserGroupRelation{}, + } + + resp := groupTest.svr.UpdateGroups(reqCtx, []*apisecurity.ModifyUserGroup{req}) + assert.True(t, resp.Responses[0].GetCode().Value == v1.NoNeedUpdate, resp.Responses[0].Info.GetValue()) + }) + +} + +func Test_server_GetGroupToken(t *testing.T) { + t.Run("主账户去查询owner为自己的用户组", func(t *testing.T) { + groupTest := newGroupTest(t) + + t.Cleanup(func() { + groupTest.Clean() + }) + reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, groupTest.users[0].Token) + + resp := groupTest.svr.GetGroupToken(reqCtx, &apisecurity.UserGroup{ + Id: utils.NewStringValue(groupTest.groups[1].ID), + }) + + assert.True(t, resp.GetCode().Value == v1.ExecuteSuccess, resp.Info.GetValue()) + }) + + t.Run("主账户去查询owner不是自己的用户组", func(t *testing.T) { + groupTest := newGroupTest(t) + + t.Cleanup(func() { + groupTest.Clean() + }) + reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, groupTest.ownerTwo.Token) + + resp := groupTest.svr.GetGroupToken(reqCtx, &apisecurity.UserGroup{ + Id: utils.NewStringValue(groupTest.groups[1].ID), + }) + + assert.True(t, resp.GetCode().Value == v1.NotAllowedAccess, resp.Info.GetValue()) + }) + + t.Run("子账户去查询自己所在的用户组", func(t *testing.T) { + groupTest := newGroupTest(t) + + t.Cleanup(func() { + groupTest.Clean() + }) + reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, groupTest.users[1].Token) + + resp := groupTest.svr.GetGroupToken(reqCtx, &apisecurity.UserGroup{ + Id: utils.NewStringValue(groupTest.groups[1].ID), + }) + + assert.True(t, resp.GetCode().Value == v1.ExecuteSuccess, resp.Info.GetValue()) + }) + + t.Run("子账户去查询自己不在的用户组", func(t *testing.T) { + groupTest := newGroupTest(t) + + t.Cleanup(func() { + groupTest.Clean() + }) + reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, groupTest.users[1].Token) + + resp := groupTest.svr.GetGroupToken(reqCtx, &apisecurity.UserGroup{ + Id: utils.NewStringValue(groupTest.groups[2].ID), + }) + + assert.True(t, resp.GetCode().Value == v1.NotAllowedAccess, resp.Info.GetValue()) + }) +} + +func Test_server_DeleteGroup(t *testing.T) { + + t.Run("正常删除用户组", func(t *testing.T) { + groupTest := newGroupTest(t) + + t.Cleanup(func() { + groupTest.Clean() + }) + reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, groupTest.ownerOne.Token) + + groupTest.storage.EXPECT().GetGroup(gomock.Any()).AnyTimes().Return(groupTest.groups[0], nil) + groupTest.storage.EXPECT().DeleteGroup(gomock.Any(), gomock.Any()).AnyTimes().Return(nil) + + batchResp := groupTest.svr.DeleteGroups(reqCtx, []*apisecurity.UserGroup{ + { + Id: utils.NewStringValue(groupTest.groups[0].ID), + }, + }) + + assert.True(t, batchResp.GetCode().Value == v1.ExecuteSuccess, batchResp.Info.GetValue()) + }) + + t.Run("删除用户组-用户组不存在", func(t *testing.T) { + groupTest := newGroupTest(t) + + t.Cleanup(func() { + groupTest.Clean() + }) + reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, groupTest.ownerOne.Token) + + groupTest.storage.EXPECT().GetGroup(gomock.Any()).AnyTimes().Return(nil, nil) + groupTest.storage.EXPECT().DeleteGroup(gomock.Any(), gomock.Any()).AnyTimes().Return(nil) + + batchResp := groupTest.svr.DeleteGroups(reqCtx, []*apisecurity.UserGroup{ + { + Id: utils.NewStringValue(groupTest.groups[0].ID), + }, + }) + + assert.True(t, batchResp.GetCode().Value == v1.ExecuteSuccess, batchResp.Info.GetValue()) + }) + + t.Run("删除用户组-不是用户组的owner", func(t *testing.T) { + groupTest := newGroupTest(t) + + t.Cleanup(func() { + groupTest.Clean() + }) + reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, groupTest.ownerTwo.Token) + + groupTest.storage.EXPECT().GetGroup(gomock.Any()).AnyTimes().Return(groupTest.groups[0], nil) + groupTest.storage.EXPECT().DeleteGroup(gomock.Any(), gomock.Any()).AnyTimes().Return(nil) + + batchResp := groupTest.svr.DeleteGroups(reqCtx, []*apisecurity.UserGroup{ + { + Id: utils.NewStringValue(groupTest.groups[0].ID), + }, + }) + + assert.True(t, batchResp.Responses[0].GetCode().Value == v1.NotAllowedAccess, batchResp.Responses[0].Info.GetValue()) + }) + + t.Run("删除用户组-非owner角色", func(t *testing.T) { + groupTest := newGroupTest(t) + + t.Cleanup(func() { + groupTest.Clean() + }) + reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, groupTest.users[1].Token) + + groupTest.storage.EXPECT().GetGroup(gomock.Any()).AnyTimes().Return(groupTest.groups[0], nil) + groupTest.storage.EXPECT().DeleteGroup(gomock.Any(), gomock.Any()).AnyTimes().Return(nil) + + batchResp := groupTest.svr.DeleteGroups(reqCtx, []*apisecurity.UserGroup{ + { + Id: utils.NewStringValue(groupTest.groups[0].ID), + }, + }) + + assert.True(t, batchResp.Responses[0].GetCode().Value == v1.OperationRoleException, batchResp.Responses[0].Info.GetValue()) + }) + +} + +func Test_server_UpdateGroupToken(t *testing.T) { + t.Run("正常更新用户组Token的Enable状态", func(t *testing.T) { + groupTest := newGroupTest(t) + t.Cleanup(func() { + groupTest.Clean() + }) + reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, groupTest.ownerOne.Token) + + groupTest.storage.EXPECT().GetGroup(gomock.Any()).AnyTimes().Return(groupTest.groups[0], nil) + groupTest.storage.EXPECT().UpdateGroup(gomock.Any()).AnyTimes().Return(nil) + + batchResp := groupTest.svr.EnableGroupToken(reqCtx, &apisecurity.UserGroup{ + Id: utils.NewStringValue(groupTest.groups[2].ID), + }) + + assert.True(t, batchResp.GetCode().Value == v1.ExecuteSuccess, batchResp.Info.GetValue()) + }) + + t.Run("非Owner角色更新用户组Token的Enable状态", func(t *testing.T) { + groupTest := newGroupTest(t) + t.Cleanup(func() { + groupTest.Clean() + }) + reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, groupTest.users[2].Token) + + groupTest.storage.EXPECT().GetGroup(gomock.Any()).AnyTimes().Return(groupTest.groups[0], nil) + + batchResp := groupTest.svr.EnableGroupToken(reqCtx, &apisecurity.UserGroup{ + Id: utils.NewStringValue(groupTest.groups[2].ID), + }) + + assert.True(t, batchResp.Code.Value == v1.OperationRoleException, batchResp.Info.GetValue()) + }) + + t.Run("更新用户组Token的Enable状态-非group的owner", func(t *testing.T) { + groupTest := newGroupTest(t) + t.Cleanup(func() { + groupTest.Clean() + }) + reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, groupTest.ownerTwo.Token) + + groupTest.storage.EXPECT().GetGroup(gomock.Any()).AnyTimes().Return(groupTest.groups[0], nil) + + batchResp := groupTest.svr.EnableGroupToken(reqCtx, &apisecurity.UserGroup{ + Id: utils.NewStringValue(groupTest.groups[2].ID), + }) + + assert.True(t, batchResp.Code.Value == v1.NotAllowedAccess, batchResp.Info.GetValue()) + }) +} + +func Test_server_RefreshGroupToken(t *testing.T) { + t.Run("正常更新用户组Token的Enable状态", func(t *testing.T) { + groupTest := newGroupTest(t) + t.Cleanup(func() { + groupTest.Clean() + }) + reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, groupTest.ownerOne.Token) + + groupTest.storage.EXPECT().GetGroup(gomock.Any()).Return(groupTest.groups[0], nil) + groupTest.storage.EXPECT().UpdateGroup(gomock.Any()).Return(nil) + + batchResp := groupTest.svr.ResetGroupToken(reqCtx, &apisecurity.UserGroup{ + Id: utils.NewStringValue(groupTest.groups[2].ID), + }) + + assert.True(t, batchResp.GetCode().Value == v1.ExecuteSuccess, batchResp.Info.GetValue()) + }) + + t.Run("非Owner角色更新用户组Token的Enable状态", func(t *testing.T) { + groupTest := newGroupTest(t) + t.Cleanup(func() { + groupTest.Clean() + }) + reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, groupTest.users[2].Token) + + groupTest.storage.EXPECT().GetGroup(gomock.Any()).AnyTimes().Return(groupTest.groups[0], nil) + + batchResp := groupTest.svr.ResetGroupToken(reqCtx, &apisecurity.UserGroup{ + Id: utils.NewStringValue(groupTest.groups[2].ID), + }) + + assert.True(t, batchResp.Code.Value == v1.OperationRoleException, batchResp.Info.GetValue()) + }) + + t.Run("更新用户组Token的Enable状态-非group的owner", func(t *testing.T) { + groupTest := newGroupTest(t) + t.Cleanup(func() { + groupTest.Clean() + }) + reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, groupTest.ownerTwo.Token) + + groupTest.storage.EXPECT().GetGroup(gomock.Any()).Return(groupTest.groups[0], nil).AnyTimes() + + batchResp := groupTest.svr.ResetGroupToken(reqCtx, &apisecurity.UserGroup{ + Id: utils.NewStringValue(groupTest.groups[2].ID), + }) + + assert.True(t, batchResp.Code.Value == v1.NotAllowedAccess, batchResp.Info.GetValue()) + }) +} + +func Test_AuthServer_NormalOperateUserGroup(t *testing.T) { + suit := &AuthTestSuit{} + if err := suit.Initialize(); err != nil { + t.Fatal(err) + } + t.Cleanup(func() { + suit.cleanAllAuthStrategy() + suit.cleanAllUser() + suit.cleanAllUserGroup() + suit.Destroy() + }) + + users := createApiMockUser(10, "test") + for i := range users { + users[i].Id = utils.NewStringValue(utils.NewUUID()) + } + + groups := createMockApiUserGroup([]*apisecurity.User{users[0]}) + + t.Run("正常创建用户组", func(t *testing.T) { + bresp := suit.UserServer().CreateUsers(suit.DefaultCtx, users) + if !respSuccess(bresp) { + t.Fatal(bresp.GetInfo().GetValue()) + } + + _ = suit.CacheMgr().TestUpdate() + + resp := suit.UserServer().CreateGroup(suit.DefaultCtx, groups[0]) + + if !respSuccess(resp) { + t.Fatal(resp.GetInfo().GetValue()) + } + + groups[0].Id = utils.NewStringValue(resp.GetUserGroup().Id.Value) + }) + + t.Run("正常更新用户组", func(t *testing.T) { + + time.Sleep(time.Second) + + req := []*apisecurity.ModifyUserGroup{ + { + Id: utils.NewStringValue(groups[0].GetId().GetValue()), + Name: utils.NewStringValue(groups[0].GetName().GetValue()), + Comment: &wrapperspb.StringValue{ + Value: "update user group", + }, + AddRelations: &apisecurity.UserGroupRelation{ + Users: users[3:], + }, + }, + } + + resp := suit.UserServer().UpdateGroups(suit.DefaultCtx, req) + if !respSuccess(resp) { + t.Fatal(resp.GetInfo().GetValue()) + } + + _ = suit.CacheMgr().TestUpdate() + + qresp := suit.UserServer().GetGroup(suit.DefaultCtx, groups[0]) + + if !respSuccess(resp) { + t.Fatal(resp.GetInfo().GetValue()) + } + + assert.Equal(t, req[0].GetComment().GetValue(), qresp.GetUserGroup().GetComment().GetValue()) + assert.Equal(t, len(users[3:])+1, len(qresp.GetUserGroup().GetRelation().GetUsers())) + }) + + t.Run("正常更新用户组Token", func(t *testing.T) { + resp := suit.UserServer().ResetGroupToken(suit.DefaultCtx, groups[0]) + + if !respSuccess(resp) { + t.Fatal(resp.GetInfo().GetValue()) + } + + _ = suit.CacheMgr().TestUpdate() + + qresp := suit.UserServer().GetGroupToken(suit.DefaultCtx, groups[0]) + if !respSuccess(qresp) { + t.Fatal(resp.GetInfo().GetValue()) + } + assert.Equal(t, resp.GetUserGroup().GetAuthToken().GetValue(), qresp.GetUserGroup().GetAuthToken().GetValue()) + }) + + t.Run("正常查询某个用户组下的用户列表", func(t *testing.T) { + qresp := suit.UserServer().GetUsers(suit.DefaultCtx, map[string]string{ + "group_id": groups[0].GetId().GetValue(), + }) + + if !respSuccess(qresp) { + t.Fatal(qresp.GetInfo().GetValue()) + } + + assert.Equal(t, 8, len(qresp.GetUsers())) + + expectUsers := []string{users[0].Id.Value} + for _, u := range users[3:] { + expectUsers = append(expectUsers, u.Id.Value) + } + + retUsers := []string{} + for i := range qresp.GetUsers() { + retUsers = append(retUsers, qresp.GetUsers()[i].Id.Value) + } + assert.ElementsMatch(t, expectUsers, retUsers) + }) + + t.Run("正常查询用户组列表", func(t *testing.T) { + qresp := suit.UserServer().GetGroups(suit.DefaultCtx, map[string]string{}) + + if !respSuccess(qresp) { + t.Fatal(qresp.GetInfo().GetValue()) + } + + assert.True(t, len(qresp.GetUserGroups()) == 1) + assert.Equal(t, groups[0].GetId().GetValue(), qresp.GetUserGroups()[0].Id.GetValue()) + }) + + t.Run("查询某个用户所在的所有分组", func(t *testing.T) { + qresp := suit.UserServer().GetGroups(suit.DefaultCtx, map[string]string{ + "user_id": users[0].GetId().GetValue(), + }) + + if !respSuccess(qresp) { + t.Fatal(qresp.GetInfo().GetValue()) + } + + assert.True(t, len(qresp.GetUserGroups()) == 1) + assert.Equal(t, groups[0].GetId().GetValue(), qresp.GetUserGroups()[0].Id.GetValue()) + }) + + t.Run("正常删除用户组", func(t *testing.T) { + resp := suit.UserServer().DeleteGroups(suit.DefaultCtx, groups) + + if !respSuccess(resp) { + t.Fatal(resp.GetInfo().GetValue()) + } + + qresp := suit.UserServer().GetGroup(suit.DefaultCtx, groups[0]) + + if respSuccess(qresp) { + t.Fatal(qresp.GetInfo().GetValue()) + } + + assert.Equal(t, v1.NotFoundUserGroup, qresp.GetCode().GetValue()) + }) +} diff --git a/auth/user/inteceptor/auth/server.go b/auth/user/inteceptor/auth/server.go index 17f6e877c..6cc411516 100644 --- a/auth/user/inteceptor/auth/server.go +++ b/auth/user/inteceptor/auth/server.go @@ -28,7 +28,6 @@ import ( "github.com/polarismesh/polaris/auth" cachetypes "github.com/polarismesh/polaris/cache/api" api "github.com/polarismesh/polaris/common/api/v1" - "github.com/polarismesh/polaris/common/model" authcommon "github.com/polarismesh/polaris/common/model/auth" authmodel "github.com/polarismesh/polaris/common/model/auth" "github.com/polarismesh/polaris/common/utils" @@ -47,10 +46,8 @@ type Server struct { } // Initialize 初始化 -func (svr *Server) Initialize(authOpt *auth.Config, storage store.Store, policySvr auth.StrategyServer, - cacheMgr cachetypes.CacheManager) error { - svr.policySvr = policySvr - return svr.nextSvr.Initialize(authOpt, storage, policySvr, cacheMgr) +func (svr *Server) Initialize(authOpt *auth.Config, storage store.Store, policyMgr auth.StrategyServer, cacheMgr cachetypes.CacheManager) error { + return svr.nextSvr.Initialize(authOpt, storage, policyMgr, cacheMgr) } // Name 用户数据管理server名称 @@ -83,7 +80,7 @@ func (svr *Server) CreateUsers(ctx context.Context, users []*apisecurity.User) * ) if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewBatchWriteResponse(authcommon.ConvertToErrCode(err)) + return api.NewBatchWriteResponseWithMsg(authcommon.ConvertToErrCode(err), err.Error()) } return svr.nextSvr.CreateUsers(authCtx.GetRequestContext(), users) } @@ -113,7 +110,7 @@ func (svr *Server) UpdateUser(ctx context.Context, user *apisecurity.User) *apis ) if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewResponse(authcommon.ConvertToErrCode(err)) + return api.NewResponseWithMsg(authcommon.ConvertToErrCode(err), err.Error()) } return svr.nextSvr.UpdateUser(authCtx.GetRequestContext(), user) } @@ -143,7 +140,7 @@ func (svr *Server) UpdateUserPassword(ctx context.Context, req *apisecurity.Modi ) if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewResponse(authcommon.ConvertToErrCode(err)) + return api.NewResponseWithMsg(authcommon.ConvertToErrCode(err), err.Error()) } return svr.nextSvr.UpdateUserPassword(authCtx.GetRequestContext(), req) } @@ -173,7 +170,7 @@ func (svr *Server) DeleteUsers(ctx context.Context, users []*apisecurity.User) * }), ) if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewBatchWriteResponse(authcommon.ConvertToErrCode(err)) + return api.NewBatchWriteResponseWithMsg(authcommon.ConvertToErrCode(err), err.Error()) } return svr.nextSvr.DeleteUsers(authCtx.GetRequestContext(), users) } @@ -187,7 +184,7 @@ func (svr *Server) GetUsers(ctx context.Context, query map[string]string) *apise authcommon.WithMethod(authcommon.DescribeUsers), ) if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewBatchQueryResponse(authcommon.ConvertToErrCode(err)) + return api.NewBatchQueryResponseWithMsg(authcommon.ConvertToErrCode(err), err.Error()) } ctx = authCtx.GetRequestContext() query["hide_admin"] = strconv.FormatBool(true) @@ -197,7 +194,7 @@ func (svr *Server) GetUsers(ctx context.Context, query map[string]string) *apise query["owner"] = utils.ParseOwnerID(ctx) } - ctx = cachetypes.AppendUserPredicate(ctx, func(ctx context.Context, u *authcommon.User) bool { + cachetypes.AppendUserPredicate(ctx, func(ctx context.Context, u *authcommon.User) bool { return svr.policySvr.GetAuthChecker().ResourcePredicate(authCtx, &authmodel.ResourceEntry{ Type: apisecurity.ResourceType_Users, ID: u.ID, @@ -232,7 +229,7 @@ func (svr *Server) GetUserToken(ctx context.Context, user *apisecurity.User) *ap ) if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewResponse(authcommon.ConvertToErrCode(err)) + return api.NewResponseWithMsg(authcommon.ConvertToErrCode(err), err.Error()) } return svr.nextSvr.GetUserToken(authCtx.GetRequestContext(), user) } @@ -261,9 +258,9 @@ func (svr *Server) EnableUserToken(ctx context.Context, user *apisecurity.User) ) if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewResponse(authcommon.ConvertToErrCode(err)) + return api.NewResponseWithMsg(authcommon.ConvertToErrCode(err), err.Error()) } - return svr.nextSvr.EnableUserToken(authCtx.GetRequestContext(), user) + return svr.nextSvr.EnableUserToken(ctx, user) } // ResetUserToken 重置用户的token @@ -292,7 +289,7 @@ func (svr *Server) ResetUserToken(ctx context.Context, user *apisecurity.User) * if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { return api.NewResponseWithMsg(authcommon.ConvertToErrCode(err), err.Error()) } - return svr.nextSvr.ResetUserToken(authCtx.GetRequestContext(), user) + return svr.nextSvr.ResetUserToken(ctx, user) } // CreateGroup 创建用户组 @@ -305,7 +302,7 @@ func (svr *Server) CreateGroup(ctx context.Context, group *apisecurity.UserGroup ) if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewResponse(authcommon.ConvertToErrCode(err)) + return api.NewResponseWithMsg(authcommon.ConvertToErrCode(err), err.Error()) } return svr.nextSvr.CreateGroup(authCtx.GetRequestContext(), group) } @@ -337,7 +334,7 @@ func (svr *Server) UpdateGroups(ctx context.Context, groups []*apisecurity.Modif ) if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewBatchWriteResponse(authcommon.ConvertToErrCode(err)) + return api.NewBatchWriteResponseWithMsg(authcommon.ConvertToErrCode(err), err.Error()) } return svr.nextSvr.UpdateGroups(authCtx.GetRequestContext(), groups) } @@ -367,9 +364,9 @@ func (svr *Server) DeleteGroups(ctx context.Context, groups []*apisecurity.UserG ) if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewBatchWriteResponse(authcommon.ConvertToErrCode(err)) + return api.NewBatchWriteResponseWithMsg(authcommon.ConvertToErrCode(err), err.Error()) } - return svr.nextSvr.DeleteGroups(authCtx.GetRequestContext(), groups) + return svr.nextSvr.DeleteGroups(ctx, groups) } // GetGroups 查询用户组列表(不带用户详细信息) @@ -381,27 +378,23 @@ func (svr *Server) GetGroups(ctx context.Context, query map[string]string) *apis authcommon.WithMethod(authcommon.DescribeUserGroups), ) if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewBatchQueryResponse(authcommon.ConvertToErrCode(err)) + return api.NewBatchQueryResponseWithMsg(authcommon.ConvertToErrCode(err), err.Error()) } ctx = authCtx.GetRequestContext() - ctx = cachetypes.AppendUserGroupPredicate(ctx, func(ctx context.Context, u *authcommon.UserGroupDetail) bool { - ok := svr.policySvr.GetAuthChecker().ResourcePredicate(authCtx, &authmodel.ResourceEntry{ + if authcommon.ParseUserRole(ctx) != authmodel.AdminUserRole { + // step 1: 设置 owner 信息,只能查看归属主帐户下的用户组 + query["owner"] = utils.ParseOwnerID(ctx) + } + + cachetypes.AppendUserPredicate(ctx, func(ctx context.Context, u *authcommon.User) bool { + return svr.policySvr.GetAuthChecker().ResourcePredicate(authCtx, &authmodel.ResourceEntry{ Type: apisecurity.ResourceType_UserGroups, ID: u.ID, Metadata: u.Metadata, }) - if ok { - return true - } - // 兼容老版本的策略查询逻辑 - if compatible, _ := ctx.Value(model.ContextKeyCompatible{}).(bool); compatible { - _, exist := u.UserIds[utils.ParseUserID(ctx)] - return exist - } - return false }) - authCtx.SetRequestContext(ctx) - return svr.nextSvr.GetGroups(authCtx.GetRequestContext(), query) + delete(query, "owner") + return svr.nextSvr.GetGroups(ctx, query) } // GetGroup 根据用户组信息,查询该用户组下的用户相信 @@ -428,9 +421,9 @@ func (svr *Server) GetGroup(ctx context.Context, req *apisecurity.UserGroup) *ap ) if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewResponse(authcommon.ConvertToErrCode(err)) + return api.NewResponseWithMsg(authcommon.ConvertToErrCode(err), err.Error()) } - return svr.nextSvr.GetGroup(authCtx.GetRequestContext(), req) + return svr.nextSvr.GetGroup(ctx, req) } // GetGroupToken 获取用户组的 token @@ -457,9 +450,9 @@ func (svr *Server) GetGroupToken(ctx context.Context, group *apisecurity.UserGro ) if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewResponse(authcommon.ConvertToErrCode(err)) + return api.NewResponseWithMsg(authcommon.ConvertToErrCode(err), err.Error()) } - return svr.nextSvr.GetGroupToken(authCtx.GetRequestContext(), group) + return svr.nextSvr.GetGroupToken(ctx, group) } // EnableGroupToken 取消用户组的 token 使用 @@ -486,9 +479,9 @@ func (svr *Server) EnableGroupToken(ctx context.Context, group *apisecurity.User ) if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewResponse(authcommon.ConvertToErrCode(err)) + return api.NewResponseWithMsg(authcommon.ConvertToErrCode(err), err.Error()) } - return svr.nextSvr.EnableGroupToken(authCtx.GetRequestContext(), group) + return svr.nextSvr.EnableGroupToken(ctx, group) } // ResetGroupToken 重置用户组的 token @@ -515,7 +508,7 @@ func (svr *Server) ResetGroupToken(ctx context.Context, group *apisecurity.UserG ) if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewResponse(authcommon.ConvertToErrCode(err)) + return api.NewResponseWithMsg(authcommon.ConvertToErrCode(err), err.Error()) } - return svr.nextSvr.ResetGroupToken(authCtx.GetRequestContext(), group) + return svr.nextSvr.ResetGroupToken(ctx, group) } diff --git a/auth/user/main_test.go b/auth/user/main_test.go new file mode 100644 index 000000000..6ccb1a835 --- /dev/null +++ b/auth/user/main_test.go @@ -0,0 +1,215 @@ +/** + * Tencent is pleased to support the open source community by making Polaris available. + * + * Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. + * + * Licensed under the BSD 3-Clause License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://opensource.org/licenses/BSD-3-Clause + * + * Unless required by applicable law or agreed to in writing, software distributed + * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR + * CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package defaultuser_test + +import ( + "errors" + + _ "github.com/go-sql-driver/mysql" + bolt "go.etcd.io/bbolt" + + "github.com/polarismesh/polaris/auth" + "github.com/polarismesh/polaris/cache" + _ "github.com/polarismesh/polaris/cache" + api "github.com/polarismesh/polaris/common/api/v1" + commonlog "github.com/polarismesh/polaris/common/log" + "github.com/polarismesh/polaris/namespace" + "github.com/polarismesh/polaris/plugin" + _ "github.com/polarismesh/polaris/plugin/cmdb/memory" + _ "github.com/polarismesh/polaris/plugin/discoverevent/local" + _ "github.com/polarismesh/polaris/plugin/healthchecker/memory" + _ "github.com/polarismesh/polaris/plugin/healthchecker/redis" + _ "github.com/polarismesh/polaris/plugin/history/logger" + _ "github.com/polarismesh/polaris/plugin/password" + _ "github.com/polarismesh/polaris/plugin/ratelimit/token" + _ "github.com/polarismesh/polaris/plugin/statis/logger" + _ "github.com/polarismesh/polaris/plugin/statis/prometheus" + "github.com/polarismesh/polaris/service/healthcheck" + "github.com/polarismesh/polaris/store" + "github.com/polarismesh/polaris/store/boltdb" + _ "github.com/polarismesh/polaris/store/boltdb" + _ "github.com/polarismesh/polaris/store/mysql" + sqldb "github.com/polarismesh/polaris/store/mysql" + testsuit "github.com/polarismesh/polaris/test/suit" +) + +const ( + tblUser string = "user" + tblStrategy string = "strategy" + tblGroup string = "group" +) + +type Bootstrap struct { + Logger map[string]*commonlog.Options +} + +type TestConfig struct { + Bootstrap Bootstrap `yaml:"bootstrap"` + Cache cache.Config `yaml:"cache"` + Namespace namespace.Config `yaml:"namespace"` + HealthChecks healthcheck.Config `yaml:"healthcheck"` + Store store.Config `yaml:"store"` + Auth auth.Config `yaml:"auth"` + Plugin plugin.Config `yaml:"plugin"` +} + +type AuthTestSuit struct { + testsuit.DiscoverTestSuit +} + +// 判断一个resp是否执行成功 +func respSuccess(resp api.ResponseMessage) bool { + + ret := api.CalcCode(resp) == 200 + + return ret +} + +type options func(cfg *TestConfig) + +func (d *AuthTestSuit) cleanAllUser() { + if d.Storage.Name() == sqldb.STORENAME { + func() { + tx, err := d.Storage.StartTx() + if err != nil { + panic(err) + } + + dbTx := tx.GetDelegateTx().(*sqldb.BaseTx) + + defer dbTx.Rollback() + + if _, err := dbTx.Exec("delete from user where name like 'test%'"); err != nil { + dbTx.Rollback() + panic(err) + } + + dbTx.Commit() + }() + } else if d.Storage.Name() == boltdb.STORENAME { + func() { + tx, err := d.Storage.StartTx() + if err != nil { + panic(err) + } + + dbTx := tx.GetDelegateTx().(*bolt.Tx) + defer dbTx.Rollback() + + if err := dbTx.DeleteBucket([]byte(tblUser)); err != nil { + if !errors.Is(err, bolt.ErrBucketNotFound) { + panic(err) + } + } + + dbTx.Commit() + }() + } +} + +func (d *AuthTestSuit) cleanAllUserGroup() { + if d.Storage.Name() == sqldb.STORENAME { + func() { + tx, err := d.Storage.StartTx() + if err != nil { + panic(err) + } + + dbTx := tx.GetDelegateTx().(*sqldb.BaseTx) + + defer dbTx.Rollback() + + if _, err := dbTx.Exec("delete from user_group where name like 'test%'"); err != nil { + dbTx.Rollback() + panic(err) + } + if _, err := dbTx.Exec("delete from user_group_relation"); err != nil { + dbTx.Rollback() + panic(err) + } + + dbTx.Commit() + }() + } else if d.Storage.Name() == boltdb.STORENAME { + func() { + tx, err := d.Storage.StartTx() + if err != nil { + panic(err) + } + + dbTx := tx.GetDelegateTx().(*bolt.Tx) + defer dbTx.Rollback() + + if err := dbTx.DeleteBucket([]byte(tblGroup)); err != nil { + if !errors.Is(err, bolt.ErrBucketNotFound) { + panic(err) + } + } + + dbTx.Commit() + }() + } +} + +func (d *AuthTestSuit) cleanAllAuthStrategy() { + if d.Storage.Name() == sqldb.STORENAME { + func() { + tx, err := d.Storage.StartTx() + if err != nil { + panic(err) + } + + dbTx := tx.GetDelegateTx().(*sqldb.BaseTx) + + defer dbTx.Rollback() + + if _, err := dbTx.Exec("delete from auth_strategy where id != 'fbca9bfa04ae4ead86e1ecf5811e32a9'"); err != nil { + dbTx.Rollback() + panic(err) + } + if _, err := dbTx.Exec("delete from auth_principal where strategy_id != 'fbca9bfa04ae4ead86e1ecf5811e32a9'"); err != nil { + dbTx.Rollback() + panic(err) + } + if _, err := dbTx.Exec("delete from auth_strategy_resource where strategy_id != 'fbca9bfa04ae4ead86e1ecf5811e32a9'"); err != nil { + dbTx.Rollback() + panic(err) + } + + dbTx.Commit() + }() + } else if d.Storage.Name() == boltdb.STORENAME { + func() { + tx, err := d.Storage.StartTx() + if err != nil { + panic(err) + } + + dbTx := tx.GetDelegateTx().(*bolt.Tx) + defer dbTx.Rollback() + + if err := dbTx.DeleteBucket([]byte(tblStrategy)); err != nil { + if !errors.Is(err, bolt.ErrBucketNotFound) { + panic(err) + } + } + + dbTx.Commit() + }() + } +} diff --git a/auth/user/server.go b/auth/user/server.go index dd7254a3f..db1484327 100644 --- a/auth/user/server.go +++ b/auth/user/server.go @@ -76,8 +76,7 @@ func (svr *Server) Name() string { return auth.DefaultUserMgnPluginName } -func (svr *Server) Initialize(authOpt *auth.Config, storage store.Store, policySvr auth.StrategyServer, - cacheMgr cachetypes.CacheManager) error { +func (svr *Server) Initialize(authOpt *auth.Config, storage store.Store, policySvr auth.StrategyServer, cacheMgr cachetypes.CacheManager) error { svr.cacheMgr = cacheMgr svr.storage = storage svr.policySvr = policySvr diff --git a/auth/user/user.go b/auth/user/user.go index 5852dacd7..54535f9f0 100644 --- a/auth/user/user.go +++ b/auth/user/user.go @@ -423,13 +423,7 @@ func (svr *Server) ResetUserToken(ctx context.Context, req *apisecurity.User) *a // step 3. 兜底措施:如果开启了鉴权的非严格模式,则根据错误的类型,判断是否转为匿名用户进行访问 // - 如果是访问权限控制相关模块(用户、用户组、权限策略),不得转为匿名用户 func (svr *Server) CheckCredential(authCtx *authcommon.AcquireContext) error { - // 如果已经存在了解析出的 token 信息,则直接返回 - if _, ok := authCtx.GetAttachment(authcommon.TokenDetailInfoKey); ok { - return nil - } - checkErr := func() error { - // authToken := utils.ParseAuthToken(authCtx.GetRequestContext()) operator, err := svr.decodeToken(authToken) if err != nil { @@ -465,8 +459,7 @@ func (svr *Server) CheckCredential(authCtx *authcommon.AcquireContext) error { log.Warn("[Auth][Checker] parse operator info, downgrade to anonymous", utils.RequestID(authCtx.GetRequestContext()), zap.Error(checkErr)) // 操作者信息解析失败,降级为匿名用户 - authCtx.SetAttachment(authcommon.PrincipalKey, authcommon.NewAnonymousPrincipal()) - authCtx.SetAttachment(authcommon.TokenDetailInfoKey, auth.NewAnonymousOperatorInfo()) + authCtx.SetAttachment(authcommon.TokenDetailInfoKey, auth.NewAnonymous()) } return nil } diff --git a/auth/user/user_test.go b/auth/user/user_test.go new file mode 100644 index 000000000..3df71acd2 --- /dev/null +++ b/auth/user/user_test.go @@ -0,0 +1,1009 @@ +/** + * Tencent is pleased to support the open source community by making Polaris available. + * + * Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. + * + * Licensed under the BSD 3-Clause License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://opensource.org/licenses/BSD-3-Clause + * + * Unless required by applicable law or agreed to in writing, software distributed + * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR + * CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package defaultuser_test + +import ( + "context" + "testing" + "time" + + "github.com/golang/mock/gomock" + "github.com/golang/protobuf/ptypes/wrappers" + apimodel "github.com/polarismesh/specification/source/go/api/v1/model" + apisecurity "github.com/polarismesh/specification/source/go/api/v1/security" + "github.com/stretchr/testify/assert" + + "github.com/polarismesh/polaris/auth" + defaultuser "github.com/polarismesh/polaris/auth/user" + "github.com/polarismesh/polaris/cache" + cachetypes "github.com/polarismesh/polaris/cache/api" + api "github.com/polarismesh/polaris/common/api/v1" + commonlog "github.com/polarismesh/polaris/common/log" + authcommon "github.com/polarismesh/polaris/common/model/auth" + "github.com/polarismesh/polaris/common/utils" + storemock "github.com/polarismesh/polaris/store/mock" +) + +type UserTest struct { + admin *authcommon.User + ownerOne *authcommon.User + ownerTwo *authcommon.User + + users []*authcommon.User + newUsers []*authcommon.User + groups []*authcommon.UserGroupDetail + newGroups []*authcommon.UserGroupDetail + allGroups []*authcommon.UserGroupDetail + + storage *storemock.MockStore + cacheMgn *cache.CacheManager + + svr auth.UserServer + + cancel context.CancelFunc + ctrl *gomock.Controller +} + +func newUserTest(t *testing.T) *UserTest { + ctrl := gomock.NewController(t) + + commonlog.GetScopeOrDefaultByName(commonlog.AuthLoggerName).SetOutputLevel(commonlog.DebugLevel) + commonlog.GetScopeOrDefaultByName(commonlog.ConfigLoggerName).SetOutputLevel(commonlog.DebugLevel) + + users := createMockUser(10, "one") + newUsers := createMockUser(10, "two") + admin := createMockUser(1, "admin")[0] + admin.Type = authcommon.AdminUserRole + admin.Owner = "" + groups := createMockUserGroup(users) + + storage := storemock.NewMockStore(ctrl) + storage.EXPECT().GetUnixSecond(gomock.Any()).AnyTimes().Return(time.Now().Unix(), nil) + storage.EXPECT().GetServicesCount().AnyTimes().Return(uint32(1), nil) + storage.EXPECT().AddUser(gomock.Any(), gomock.Any()).AnyTimes().Return(nil) + storage.EXPECT().GetUserByName(gomock.Eq("create-user-1"), gomock.Any()).AnyTimes().Return(nil, nil) + storage.EXPECT().GetUserByName(gomock.Eq("create-user-2"), gomock.Any()).AnyTimes().Return(&authcommon.User{ + Name: "create-user-2", + }, nil) + + allUsers := append(append(users, newUsers...), admin) + + storage.EXPECT().GetUsersForCache(gomock.Any(), gomock.Any()).AnyTimes().Return(allUsers, nil) + storage.EXPECT().GetGroupsForCache(gomock.Any(), gomock.Any()).AnyTimes().Return(groups, nil) + storage.EXPECT().UpdateUser(gomock.Any()).AnyTimes().Return(nil) + storage.EXPECT().DeleteUser(gomock.Any(), gomock.Any()).AnyTimes().Return(nil) + + cfg := &cache.Config{} + ctx, cancel := context.WithCancel(context.Background()) + cacheMgn, err := cache.TestCacheInitialize(ctx, cfg, storage) + if err != nil { + t.Fatal(err) + } + + _ = cacheMgn.OpenResourceCache( + []cachetypes.ConfigEntry{ + { + Name: cachetypes.UsersName, + }, + }..., + ) + time.Sleep(5 * time.Second) + + _ = cache.TestRun(ctx, cacheMgn) + + _, proxySvr, err := defaultuser.BuildServer() + if err != nil { + t.Fatal(err) + } + proxySvr.Initialize(&auth.Config{ + User: &auth.UserConfig{ + Name: auth.DefaultUserMgnPluginName, + Option: map[string]interface{}{ + "salt": "polarismesh@2021", + }, + }, + }, storage, nil, cacheMgn) + + _ = cacheMgn.TestUpdate() + + return &UserTest{ + admin: admin, + ownerOne: users[0], + ownerTwo: newUsers[0], + + users: users, + newUsers: newUsers, + groups: groups, + + storage: storage, + cacheMgn: cacheMgn, + svr: proxySvr, + + cancel: cancel, + ctrl: ctrl, + } +} + +func (g *UserTest) Clean() { + g.ctrl.Finish() + g.cancel() + _ = g.cacheMgn.Close() + time.Sleep(2 * time.Second) +} + +func Test_server_CreateUsers(t *testing.T) { + userTest := newUserTest(t) + defer userTest.Clean() + + t.Run("主账户创建账户-成功", func(t *testing.T) { + createUsersReq := []*apisecurity.User{ + { + Id: &wrappers.StringValue{Value: utils.NewUUID()}, + Name: &wrappers.StringValue{Value: "create-user-1"}, + Password: &wrappers.StringValue{Value: "create-user-1"}, + }, + } + + userTest.storage.EXPECT().GetUser(gomock.Eq(userTest.ownerOne.ID)).Return(userTest.ownerOne, nil) + + reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, userTest.ownerOne.Token) + resp := userTest.svr.CreateUsers(reqCtx, createUsersReq) + + t.Logf("CreateUsers resp : %+v", resp) + assert.Equal(t, api.ExecuteSuccess, resp.Code.GetValue(), "create users must success") + }) + + t.Run("主账户创建账户-无用户名-失败", func(t *testing.T) { + createUsersReq := []*apisecurity.User{ + { + Id: &wrappers.StringValue{Value: utils.NewUUID()}, + }, + } + + reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, userTest.users[0].Token) + resp := userTest.svr.CreateUsers(reqCtx, createUsersReq) + + t.Logf("CreateUsers resp : %+v", resp) + assert.Equal(t, api.InvalidUserName, resp.Responses[0].Code.GetValue(), "create users must fail") + }) + + t.Run("主账户创建账户-密码错误-失败", func(t *testing.T) { + createUsersReq := []*apisecurity.User{ + { + Id: &wrappers.StringValue{Value: utils.NewUUID()}, + Name: &wrappers.StringValue{Value: "create-user-1"}, + Password: &wrappers.StringValue{Value: ""}, + }, + } + + reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, userTest.users[0].Token) + resp := userTest.svr.CreateUsers(reqCtx, createUsersReq) + + t.Logf("CreateUsers resp : %+v", resp) + assert.Equal(t, api.InvalidUserPassword, resp.Responses[0].Code.GetValue(), "create users must fail") + }) + + t.Run("主账户创建账户-同名用户-失败", func(t *testing.T) { + createUsersReq := []*apisecurity.User{ + { + Id: &wrappers.StringValue{Value: utils.NewUUID()}, + Name: &wrappers.StringValue{Value: "create-user-2"}, + Password: &wrappers.StringValue{Value: "create-user-2"}, + }, + } + + userTest.storage.EXPECT().GetUser(gomock.Eq(userTest.ownerOne.ID)).Return(userTest.ownerOne, nil).AnyTimes() + + reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, userTest.users[0].Token) + resp := userTest.svr.CreateUsers(reqCtx, createUsersReq) + + t.Logf("CreateUsers resp : %+v", resp) + assert.Equal(t, api.UserExisted, resp.Responses[0].Code.GetValue(), "create users must fail") + }) + + t.Run("主账户创建账户-与主账户同名", func(t *testing.T) { + createUsersReq := []*apisecurity.User{ + { + Id: &wrappers.StringValue{Value: utils.NewUUID()}, + Name: &wrappers.StringValue{Value: userTest.ownerOne.Name}, + Password: &wrappers.StringValue{Value: "create-user-2"}, + }, + } + + userTest.storage.EXPECT().GetUser(gomock.Eq(userTest.ownerOne.ID)).Return(userTest.ownerOne, nil).AnyTimes() + + reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, userTest.ownerOne.Token) + resp := userTest.svr.CreateUsers(reqCtx, createUsersReq) + + t.Logf("CreateUsers resp : %+v", resp) + assert.Equal(t, api.UserExisted, resp.Responses[0].Code.GetValue(), "create users must fail") + }) + + t.Run("主账户创建账户-token为空-失败", func(t *testing.T) { + createUsersReq := []*apisecurity.User{ + { + Id: &wrappers.StringValue{Value: utils.NewUUID()}, + Name: &wrappers.StringValue{Value: "create-user-2"}, + Password: &wrappers.StringValue{Value: "create-user-2"}, + }, + } + + resp := userTest.svr.CreateUsers(context.Background(), createUsersReq) + t.Logf("CreateUsers resp : %+v", resp) + assert.Equal(t, api.EmptyAutToken, resp.Responses[0].Code.GetValue(), "create users must fail") + }) + + t.Run("主账户创建账户-token非法-失败", func(t *testing.T) { + createUsersReq := []*apisecurity.User{ + { + Id: &wrappers.StringValue{Value: utils.NewUUID()}, + Name: &wrappers.StringValue{Value: "create-user-2"}, + Password: &wrappers.StringValue{Value: "create-user-2"}, + }, + } + + reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, "utils.ContextAuthTokenKey") + resp := userTest.svr.CreateUsers(reqCtx, createUsersReq) + t.Logf("CreateUsers resp : %+v", resp) + assert.Equal(t, api.AuthTokenVerifyException, resp.Responses[0].Code.GetValue(), "create users must fail") + }) + + t.Run("主账户创建账户-token被禁用-失败", func(t *testing.T) { + userTest.users[0].TokenEnable = false + // 让 cache 可以刷新到 + time.Sleep(time.Second) + + createUsersReq := []*apisecurity.User{ + { + Id: &wrappers.StringValue{Value: utils.NewUUID()}, + Name: &wrappers.StringValue{Value: "create-user-2"}, + Password: &wrappers.StringValue{Value: "create-user-2"}, + }, + } + + reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, userTest.users[0].Token) + resp := userTest.svr.CreateUsers(reqCtx, createUsersReq) + + t.Logf("CreateUsers resp : %+v", resp) + assert.Equal(t, api.TokenDisabled, resp.Responses[0].Code.GetValue(), "create users must fail") + + userTest.users[0].TokenEnable = true + time.Sleep(time.Second) + }) + + t.Run("子主账户创建账户-失败", func(t *testing.T) { + createUsersReq := []*apisecurity.User{ + { + Id: &wrappers.StringValue{Value: utils.NewUUID()}, + Name: &wrappers.StringValue{Value: "create-user-1"}, + Password: &wrappers.StringValue{Value: "create-user-1"}, + }, + } + + reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, userTest.users[1].Token) + resp := userTest.svr.CreateUsers(reqCtx, createUsersReq) + + t.Logf("CreateUsers resp : %+v", resp) + assert.Equal(t, api.OperationRoleException, resp.Responses[0].Code.GetValue(), "create users must fail") + }) + + t.Run("用户组token创建账户-失败", func(t *testing.T) { + createUsersReq := []*apisecurity.User{ + { + Id: &wrappers.StringValue{Value: utils.NewUUID()}, + Name: &wrappers.StringValue{Value: "create-user-1"}, + Password: &wrappers.StringValue{Value: "create-user-1"}, + }, + } + + reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, userTest.groups[1].Token) + resp := userTest.svr.CreateUsers(reqCtx, createUsersReq) + + t.Logf("CreateUsers resp : %+v", resp) + assert.Equal(t, api.OperationRoleException, resp.Responses[0].Code.GetValue(), "create users must fail") + }) +} + +func Test_server_Login(t *testing.T) { + + userTest := newUserTest(t) + defer userTest.Clean() + + t.Run("正常登陆", func(t *testing.T) { + rsp := userTest.svr.Login(&apisecurity.LoginRequest{ + Name: &wrappers.StringValue{Value: userTest.users[0].Name}, + Password: &wrappers.StringValue{Value: "polaris"}, + }) + + assert.True(t, api.IsSuccess(rsp), rsp.GetInfo().GetValue()) + }) + + t.Run("错误的密码", func(t *testing.T) { + rsp := userTest.svr.Login(&apisecurity.LoginRequest{ + Name: &wrappers.StringValue{Value: userTest.users[0].Name}, + Password: &wrappers.StringValue{Value: "polaris_123"}, + }) + + assert.False(t, api.IsSuccess(rsp), rsp.GetInfo().GetValue()) + assert.Equal(t, uint32(apimodel.Code_NotAllowedAccess), rsp.GetCode().GetValue()) + assert.Contains(t, rsp.GetInfo().GetValue(), authcommon.ErrorWrongUsernameOrPassword.Error()) + }) +} + +func Test_server_UpdateUser(t *testing.T) { + + userTest := newUserTest(t) + defer userTest.Clean() + + t.Run("主账户更新账户信息-正常更新自己的信息", func(t *testing.T) { + req := &apisecurity.User{ + Id: &wrappers.StringValue{Value: userTest.users[0].ID}, + Comment: &wrappers.StringValue{Value: "update owner account info"}, + } + + userTest.storage.EXPECT().GetUser(gomock.Any()).Return(userTest.users[0], nil) + + reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, userTest.users[0].Token) + resp := userTest.svr.UpdateUser(reqCtx, req) + + t.Logf("UpdateUsers resp : %+v", resp) + assert.Equal(t, api.ExecuteSuccess, resp.Code.GetValue(), "update user must success") + }) + + t.Run("主账户更新账户信息-更新不存在的子账户", func(t *testing.T) { + uid := utils.NewUUID() + req := &apisecurity.User{ + Id: &wrappers.StringValue{Value: uid}, + Comment: &wrappers.StringValue{Value: "update owner account info"}, + } + + userTest.storage.EXPECT().GetUser(gomock.Any()).Return(nil, nil) + + reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, userTest.users[0].Token) + resp := userTest.svr.UpdateUser(reqCtx, req) + + t.Logf("UpdateUsers resp : %+v", resp) + assert.Equal(t, api.NotFoundUser, resp.Code.GetValue(), "update user must fail") + }) + + t.Run("主账户更新账户信息-更新不属于自己的子账户", func(t *testing.T) { + uid := utils.NewUUID() + req := &apisecurity.User{ + Id: &wrappers.StringValue{Value: uid}, + Comment: &wrappers.StringValue{Value: "update owner account info"}, + } + + userTest.storage.EXPECT().GetUser(gomock.Any()).Return(&authcommon.User{ + ID: uid, + Owner: utils.NewUUID(), + }, nil) + + reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, userTest.users[0].Token) + resp := userTest.svr.UpdateUser(reqCtx, req) + + t.Logf("UpdateUsers resp : %+v", resp) + assert.Equal(t, api.NotAllowedAccess, resp.Code.GetValue(), "update user must fail") + }) + + t.Run("子账户更新账户信息-正常更新自己的信息", func(t *testing.T) { + req := &apisecurity.User{ + Id: &wrappers.StringValue{Value: userTest.users[1].ID}, + Comment: &wrappers.StringValue{Value: "update owner account info"}, + } + + userTest.storage.EXPECT().GetUser(gomock.Any()).Return(userTest.users[1], nil) + + reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, userTest.users[1].Token) + resp := userTest.svr.UpdateUser(reqCtx, req) + + t.Logf("UpdateUsers resp : %+v", resp) + assert.Equal(t, api.ExecuteSuccess, resp.Code.GetValue(), "update user must fail") + }) + + t.Run("子账户更新账户信息-更新别的账户", func(t *testing.T) { + req := &apisecurity.User{ + Id: &wrappers.StringValue{Value: userTest.users[2].ID}, + Comment: &wrappers.StringValue{Value: "update owner account info"}, + } + + userTest.storage.EXPECT().GetUser(gomock.Any()).Return(userTest.users[2], nil).AnyTimes() + + reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, userTest.users[1].Token) + resp := userTest.svr.UpdateUser(reqCtx, req) + + t.Logf("UpdateUsers resp : %+v", resp) + assert.Equal(t, api.NotAllowedAccess, resp.Code.GetValue(), "update user must fail") + }) + + t.Run("用户组Token更新账户信息-更新别的账户", func(t *testing.T) { + req := &apisecurity.User{ + Id: &wrappers.StringValue{Value: userTest.users[2].ID}, + Comment: &wrappers.StringValue{Value: "update owner account info"}, + } + + reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, userTest.groups[1].Token) + resp := userTest.svr.UpdateUser(reqCtx, req) + + t.Logf("UpdateUsers resp : %+v", resp) + assert.Equal(t, api.OperationRoleException, resp.Code.GetValue(), "update user must fail") + }) +} + +func Test_server_UpdateUserPassword(t *testing.T) { + + userTest := newUserTest(t) + defer userTest.Clean() + + t.Run("主账户正常更新自身账户密码", func(t *testing.T) { + req := &apisecurity.ModifyUserPassword{ + Id: &wrappers.StringValue{Value: userTest.users[0].ID}, + OldPassword: &wrappers.StringValue{Value: "polaris"}, + NewPassword: &wrappers.StringValue{Value: "polaris@2021"}, + } + + userTest.storage.EXPECT().GetUser(gomock.Any()).Return(userTest.users[0], nil) + + reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, userTest.users[0].Token) + resp := userTest.svr.UpdateUserPassword(reqCtx, req) + t.Logf("CreateUsers resp : %+v", resp) + assert.Equal(t, api.ExecuteSuccess, resp.Code.GetValue(), "update user must success") + }) + + t.Run("主账户正常更新自身账户密码-新密码非法", func(t *testing.T) { + req := &apisecurity.ModifyUserPassword{ + Id: &wrappers.StringValue{Value: userTest.users[0].ID}, + OldPassword: &wrappers.StringValue{Value: "polaris"}, + NewPassword: &wrappers.StringValue{Value: "pola"}, + } + + userTest.storage.EXPECT().GetUser(gomock.Any()).Return(userTest.users[0], nil) + + reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, userTest.users[0].Token) + resp := userTest.svr.UpdateUserPassword(reqCtx, req) + t.Logf("CreateUsers resp : %+v", resp) + assert.Equal(t, api.ExecuteException, resp.Code.GetValue(), "update user must fail") + + req = &apisecurity.ModifyUserPassword{ + Id: &wrappers.StringValue{Value: userTest.users[0].ID}, + OldPassword: &wrappers.StringValue{Value: "polaris"}, + NewPassword: &wrappers.StringValue{Value: ""}, + } + + userTest.storage.EXPECT().GetUser(gomock.Any()).Return(userTest.users[0], nil) + + reqCtx = context.WithValue(context.Background(), utils.ContextAuthTokenKey, userTest.users[0].Token) + resp = userTest.svr.UpdateUserPassword(reqCtx, req) + t.Logf("CreateUsers resp : %+v", resp) + assert.Equal(t, api.ExecuteException, resp.Code.GetValue(), "update user must fail") + + req = &apisecurity.ModifyUserPassword{ + Id: &wrappers.StringValue{Value: userTest.users[0].ID}, + OldPassword: &wrappers.StringValue{Value: "polaris"}, + NewPassword: &wrappers.StringValue{Value: "polarispolarispolarispolaris"}, + } + + userTest.storage.EXPECT().GetUser(gomock.Any()).Return(userTest.users[0], nil) + + reqCtx = context.WithValue(context.Background(), utils.ContextAuthTokenKey, userTest.users[0].Token) + resp = userTest.svr.UpdateUserPassword(reqCtx, req) + t.Logf("CreateUsers resp : %+v", resp) + assert.Equal(t, api.ExecuteException, resp.Code.GetValue(), "update user must fail") + }) + + t.Run("主账户正常更新子账户密码", func(t *testing.T) { + req := &apisecurity.ModifyUserPassword{ + Id: &wrappers.StringValue{Value: userTest.users[1].ID}, + NewPassword: &wrappers.StringValue{Value: "polaris@sub"}, + } + + userTest.storage.EXPECT().GetUser(gomock.Any()).Return(userTest.users[1], nil) + + reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, userTest.users[0].Token) + resp := userTest.svr.UpdateUserPassword(reqCtx, req) + t.Logf("CreateUsers resp : %+v", resp) + assert.Equal(t, api.ExecuteSuccess, resp.Code.GetValue(), "update user must success") + }) + + t.Run("主账户正常更新子账户密码-子账户非自己", func(t *testing.T) { + + uid := utils.NewUUID() + + req := &apisecurity.ModifyUserPassword{ + Id: &wrappers.StringValue{Value: uid}, + NewPassword: &wrappers.StringValue{Value: "polaris@subaccount"}, + } + + userTest.storage.EXPECT().GetUser(gomock.Any()).Return(&authcommon.User{ + ID: uid, + Owner: utils.NewUUID(), + }, nil) + + reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, userTest.users[0].Token) + resp := userTest.svr.UpdateUserPassword(reqCtx, req) + t.Logf("CreateUsers resp : %+v", resp) + assert.Equal(t, api.NotAllowedAccess, resp.Code.GetValue(), "update user must fail") + }) + + t.Run("子账户更新账户密码-自身-携带正确原密码", func(t *testing.T) { + req := &apisecurity.ModifyUserPassword{ + Id: &wrappers.StringValue{Value: userTest.users[2].ID}, + OldPassword: &wrappers.StringValue{Value: "polaris"}, + NewPassword: &wrappers.StringValue{Value: "users[1].Password"}, + } + + userTest.storage.EXPECT().GetUser(gomock.Any()).Return(userTest.users[2], nil) + + reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, userTest.users[2].Token) + resp := userTest.svr.UpdateUserPassword(reqCtx, req) + t.Logf("CreateUsers resp : %+v", resp) + assert.Equal(t, api.ExecuteSuccess, resp.Code.GetValue(), "update user must fail") + }) + + t.Run("子账户更新账户密码-自身-携带错误原密码", func(t *testing.T) { + req := &apisecurity.ModifyUserPassword{ + Id: &wrappers.StringValue{Value: userTest.users[1].ID}, + OldPassword: &wrappers.StringValue{Value: "users[1].Password"}, + NewPassword: &wrappers.StringValue{Value: "users[1].Password"}, + } + + userTest.storage.EXPECT().GetUser(gomock.Any()).Return(userTest.users[1], nil) + + reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, userTest.users[1].Token) + resp := userTest.svr.UpdateUserPassword(reqCtx, req) + t.Logf("CreateUsers resp : %+v", resp) + assert.Equal(t, api.ExecuteException, resp.Code.GetValue(), "update user must fail") + }) + + t.Run("子账户更新账户密码-自身-无携带原密码", func(t *testing.T) { + req := &apisecurity.ModifyUserPassword{ + Id: &wrappers.StringValue{Value: userTest.users[1].ID}, + NewPassword: &wrappers.StringValue{Value: "users[1].Password"}, + } + + userTest.storage.EXPECT().GetUser(gomock.Any()).Return(userTest.users[1], nil) + + reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, userTest.users[1].Token) + resp := userTest.svr.UpdateUserPassword(reqCtx, req) + t.Logf("CreateUsers resp : %+v", resp) + assert.Equal(t, api.ExecuteException, resp.Code.GetValue(), "update user must fail") + }) + + t.Run("子账户更新账户密码-不是自己", func(t *testing.T) { + req := &apisecurity.ModifyUserPassword{ + Id: &wrappers.StringValue{Value: userTest.users[2].ID}, + NewPassword: &wrappers.StringValue{Value: "users[2].Password"}, + } + + userTest.storage.EXPECT().GetUser(gomock.Any()).Return(userTest.users[2], nil).AnyTimes() + + reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, userTest.users[1].Token) + resp := userTest.svr.UpdateUserPassword(reqCtx, req) + t.Logf("CreateUsers resp : %+v", resp) + assert.Equal(t, api.NotAllowedAccess, resp.Code.GetValue(), "update user must fail") + }) +} + +func Test_server_DeleteUser(t *testing.T) { + t.Run("主账户删除自己", func(t *testing.T) { + userTest := newUserTest(t) + t.Cleanup(func() { + userTest.Clean() + }) + + userTest.storage.EXPECT().GetUser(gomock.Any()).Return(userTest.users[0], nil) + + reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, userTest.users[0].Token) + resp := userTest.svr.DeleteUsers(reqCtx, []*apisecurity.User{ + &apisecurity.User{ + Id: utils.NewStringValue(userTest.users[0].ID), + }, + }) + + assert.True(t, resp.GetCode().Value == api.NotAllowedAccess, resp.Info.GetValue()) + }) + + t.Run("主账户删除另外一个主账户", func(t *testing.T) { + userTest := newUserTest(t) + t.Cleanup(func() { + userTest.Clean() + }) + + uid := utils.NewUUID() + userTest.storage.EXPECT().GetUser(gomock.Any()).Return(&authcommon.User{ + ID: uid, + Type: authcommon.OwnerUserRole, + Owner: "", + }, nil) + + reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, userTest.users[0].Token) + resp := userTest.svr.DeleteUsers(reqCtx, []*apisecurity.User{ + { + Id: utils.NewStringValue(uid), + }, + }) + + assert.True(t, resp.GetCode().Value == api.NotAllowedAccess, resp.Info.GetValue()) + }) + + t.Run("主账户删除自己的子账户", func(t *testing.T) { + userTest := newUserTest(t) + t.Cleanup(func() { + userTest.Clean() + }) + + userTest.storage.EXPECT().GetUser(gomock.Eq(userTest.users[1].ID)).Return(userTest.users[1], nil).AnyTimes() + + reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, userTest.users[0].Token) + resp := userTest.svr.DeleteUsers(reqCtx, []*apisecurity.User{ + { + Id: utils.NewStringValue(userTest.users[1].ID), + }, + }) + + assert.True(t, resp.GetCode().Value == api.ExecuteSuccess, resp.Info.GetValue()) + }) + + t.Run("主账户删除不是自己的子账户", func(t *testing.T) { + userTest := newUserTest(t) + t.Cleanup(func() { + userTest.Clean() + }) + + uid := utils.NewUUID() + oid := utils.NewUUID() + userTest.storage.EXPECT().GetUser(gomock.Any()).Return(&authcommon.User{ + ID: uid, + Type: authcommon.OwnerUserRole, + Owner: oid, + }, nil).AnyTimes() + + reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, userTest.users[0].Token) + resp := userTest.svr.DeleteUsers(reqCtx, []*apisecurity.User{ + &apisecurity.User{ + Id: utils.NewStringValue(uid), + }, + }) + + assert.True(t, resp.GetCode().Value == api.NotAllowedAccess, resp.Info.GetValue()) + }) + + t.Run("管理员删除主账户-主账户下没有子账户", func(t *testing.T) { + userTest := newUserTest(t) + t.Cleanup(func() { + userTest.Clean() + }) + + userTest.storage.EXPECT().GetUser(gomock.Any()).Return(userTest.users[0], nil).AnyTimes() + userTest.storage.EXPECT().GetSubCount(gomock.Any()).Return(uint32(0), nil).AnyTimes() + + reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, userTest.admin.Token) + resp := userTest.svr.DeleteUsers(reqCtx, []*apisecurity.User{ + &apisecurity.User{ + Id: utils.NewStringValue(userTest.users[0].ID), + }, + }) + + assert.True(t, resp.GetCode().Value == api.ExecuteSuccess, resp.Info.GetValue()) + }) + + t.Run("管理员删除主账户-主账户下还有子账户", func(t *testing.T) { + userTest := newUserTest(t) + t.Cleanup(func() { + userTest.Clean() + }) + + userTest.storage.EXPECT().GetUser(gomock.Any()).Return(userTest.ownerOne, nil).AnyTimes() + userTest.storage.EXPECT().GetSubCount(gomock.Any()).Return(uint32(1), nil).AnyTimes() + + reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, userTest.admin.Token) + resp := userTest.svr.DeleteUsers(reqCtx, []*apisecurity.User{ + &apisecurity.User{ + Id: utils.NewStringValue(userTest.users[0].ID), + }, + }) + + assert.True(t, resp.GetCode().Value == api.SubAccountExisted, resp.Info.GetValue()) + }) + + t.Run("子账户删除用户", func(t *testing.T) { + userTest := newUserTest(t) + t.Cleanup(func() { + userTest.Clean() + }) + + reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, userTest.users[1].Token) + resp := userTest.svr.DeleteUsers(reqCtx, []*apisecurity.User{ + &apisecurity.User{ + Id: utils.NewStringValue(userTest.users[0].ID), + }, + }) + + assert.True(t, resp.GetCode().Value == api.OperationRoleException, resp.Info.GetValue()) + }) +} + +func Test_server_GetUserToken(t *testing.T) { + + userTest := newUserTest(t) + defer userTest.Clean() + + t.Run("主账户查询自己的Token", func(t *testing.T) { + reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, userTest.ownerOne.Token) + + resp := userTest.svr.GetUserToken(reqCtx, &apisecurity.User{ + Id: utils.NewStringValue(userTest.users[0].ID), + }) + + assert.True(t, resp.GetCode().Value == api.ExecuteSuccess, resp.Info.GetValue()) + }) + + t.Run("子账户查询自己的Token", func(t *testing.T) { + reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, userTest.users[1].Token) + + resp := userTest.svr.GetUserToken(reqCtx, &apisecurity.User{ + Id: utils.NewStringValue(userTest.users[1].ID), + }) + + assert.True(t, resp.GetCode().Value == api.ExecuteSuccess, resp.Info.GetValue()) + }) + + t.Run("主账户查询子账户的Token", func(t *testing.T) { + reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, userTest.ownerOne.Token) + + resp := userTest.svr.GetUserToken(reqCtx, &apisecurity.User{ + Id: utils.NewStringValue(userTest.users[1].ID), + }) + + assert.True(t, resp.GetCode().Value == api.ExecuteSuccess, resp.Info.GetValue()) + }) + + t.Run("主账户查询别的主账户的Token", func(t *testing.T) { + reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, userTest.ownerOne.Token) + + resp := userTest.svr.GetUserToken(reqCtx, &apisecurity.User{ + Id: utils.NewStringValue(userTest.ownerTwo.ID), + }) + + assert.True(t, resp.GetCode().Value == api.NotAllowedAccess, resp.Info.GetValue()) + }) + + t.Run("主账户查询不属于自己子账户的Token", func(t *testing.T) { + reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, userTest.ownerOne.Token) + + resp := userTest.svr.GetUserToken(reqCtx, &apisecurity.User{ + Id: utils.NewStringValue(userTest.newUsers[1].ID), + }) + + assert.True(t, resp.GetCode().Value == api.NotAllowedAccess, resp.Info.GetValue()) + }) +} + +func Test_server_RefreshUserToken(t *testing.T) { + + userTest := newUserTest(t) + defer userTest.Clean() + + t.Run("主账户刷新自己的Token", func(t *testing.T) { + reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, userTest.ownerOne.Token) + + userTest.storage.EXPECT().GetUser(gomock.Any()).Return(userTest.users[0], nil) + + resp := userTest.svr.ResetUserToken(reqCtx, &apisecurity.User{ + Id: utils.NewStringValue(userTest.users[0].ID), + }) + + assert.True(t, resp.GetCode().Value == api.ExecuteSuccess, resp.Info.GetValue()) + }) + + t.Run("子账户刷新自己的Token", func(t *testing.T) { + reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, userTest.users[1].Token) + userTest.storage.EXPECT().GetUser(gomock.Any()).Return(userTest.users[1], nil) + resp := userTest.svr.ResetUserToken(reqCtx, &apisecurity.User{ + Id: utils.NewStringValue(userTest.users[1].ID), + }) + + assert.True(t, resp.GetCode().Value == api.ExecuteSuccess, resp.Info.GetValue()) + }) + + t.Run("主账户刷新子账户的Token", func(t *testing.T) { + reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, userTest.ownerOne.Token) + userTest.storage.EXPECT().GetUser(gomock.Any()).Return(userTest.users[1], nil) + resp := userTest.svr.ResetUserToken(reqCtx, &apisecurity.User{ + Id: utils.NewStringValue(userTest.users[1].ID), + }) + + assert.True(t, resp.GetCode().Value == api.ExecuteSuccess, resp.Info.GetValue()) + }) + + t.Run("主账户刷新别的主账户的Token", func(t *testing.T) { + reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, userTest.ownerOne.Token) + userTest.storage.EXPECT().GetUser(gomock.Any()).Return(userTest.ownerTwo, nil).AnyTimes() + resp := userTest.svr.ResetUserToken(reqCtx, &apisecurity.User{ + Id: utils.NewStringValue(userTest.ownerTwo.ID), + }) + + assert.True(t, resp.GetCode().Value == api.NotAllowedAccess, resp.Info.GetValue()) + }) + + t.Run("主账户刷新不属于自己子账户的Token", func(t *testing.T) { + reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, userTest.ownerOne.Token) + userTest.storage.EXPECT().GetUser(gomock.Any()).Return(userTest.newUsers[1], nil).AnyTimes() + resp := userTest.svr.ResetUserToken(reqCtx, &apisecurity.User{ + Id: utils.NewStringValue(userTest.newUsers[1].ID), + }) + + assert.True(t, resp.GetCode().Value == api.NotAllowedAccess, resp.Info.GetValue()) + }) +} + +func Test_server_UpdateUserToken(t *testing.T) { + t.Run("主账户刷新自己的Token状态", func(t *testing.T) { + userTest := newUserTest(t) + defer userTest.Clean() + _ = userTest.cacheMgn.TestUpdate() + reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, userTest.ownerOne.Token) + resp := userTest.svr.EnableUserToken(reqCtx, &apisecurity.User{ + Id: utils.NewStringValue(userTest.ownerOne.ID), + }) + + assert.True(t, resp.GetCode().Value == api.NotAllowedAccess, resp.Info.GetValue()) + }) + + t.Run("子账户刷新自己的Token状态", func(t *testing.T) { + userTest := newUserTest(t) + defer userTest.Clean() + _ = userTest.cacheMgn.TestUpdate() + reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, userTest.users[4].Token) + + userTest.storage.EXPECT().GetUser(gomock.Any()).Return(&authcommon.User{}, nil).AnyTimes() + userTest.storage.EXPECT().UpdateUser(gomock.Any()).Return(nil).AnyTimes() + + resp := userTest.svr.EnableUserToken(reqCtx, &apisecurity.User{ + Id: utils.NewStringValue(userTest.users[4].ID), + }) + + assert.True(t, resp.GetCode().Value == api.ExecuteSuccess, resp.Info.GetValue()) + }) + + t.Run("主账户刷新子账户的Token状态", func(t *testing.T) { + userTest := newUserTest(t) + defer userTest.Clean() + reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, userTest.ownerOne.Token) + userTest.storage.EXPECT().GetUser(gomock.Eq(userTest.users[3].ID)).Return(userTest.users[3], nil) + resp := userTest.svr.EnableUserToken(reqCtx, &apisecurity.User{ + Id: utils.NewStringValue(userTest.users[3].ID), + }) + + assert.True(t, resp.GetCode().Value == api.ExecuteSuccess, resp.Info.GetValue()) + }) + + t.Run("主账户刷新别的主账户的Token状态", func(t *testing.T) { + userTest := newUserTest(t) + defer userTest.Clean() + reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, userTest.ownerOne.Token) + + t.Logf("operator-id : %s, user-two-owner : %s", userTest.ownerOne.ID, userTest.ownerTwo.ID) + + userTest.storage.EXPECT().GetUser(gomock.Eq(userTest.ownerTwo.ID)).Return(userTest.ownerTwo, nil).AnyTimes() + resp := userTest.svr.EnableUserToken(reqCtx, &apisecurity.User{ + Id: utils.NewStringValue(userTest.ownerTwo.ID), + }) + + assert.Truef(t, resp.GetCode().Value == api.NotAllowedAccess, "code=%d, msg=%s", resp.Code.GetValue(), resp.Info.GetValue()) + }) + + t.Run("主账户刷新不属于自己子账户的Token状态", func(t *testing.T) { + userTest := newUserTest(t) + defer userTest.Clean() + + _ = userTest.cacheMgn.TestUpdate() + reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, userTest.ownerOne.Token) + userTest.storage.EXPECT().GetUser(gomock.Eq(userTest.newUsers[3].ID)).Return(userTest.newUsers[3], nil).AnyTimes() + resp := userTest.svr.EnableUserToken(reqCtx, &apisecurity.User{ + Id: utils.NewStringValue(userTest.newUsers[3].ID), + }) + + assert.True(t, resp.GetCode().Value == api.NotAllowedAccess, resp.Info.GetValue()) + }) +} + +func Test_AuthServer_NormalOperateUser(t *testing.T) { + suit := &AuthTestSuit{} + if err := suit.Initialize(); err != nil { + t.Fatal(err) + } + t.Cleanup(func() { + suit.cleanAllAuthStrategy() + suit.cleanAllUser() + suit.cleanAllUserGroup() + suit.Destroy() + }) + + users := createApiMockUser(10, "test") + + t.Run("正常创建用户", func(t *testing.T) { + resp := suit.UserServer().CreateUsers(suit.DefaultCtx, users) + + if !respSuccess(resp) { + t.Fatal(resp.GetInfo().GetValue()) + } + }) + + t.Run("非正常创建用户-直接操作存储层", func(t *testing.T) { + err := suit.Storage.AddUser(nil, &authcommon.User{}) + assert.Error(t, err) + }) + + t.Run("正常更新用户", func(t *testing.T) { + users[0].Comment = utils.NewStringValue("update user comment") + resp := suit.UserServer().UpdateUser(suit.DefaultCtx, users[0]) + + if !respSuccess(resp) { + t.Fatal(resp.GetInfo().GetValue()) + } + + qresp := suit.UserServer().GetUsers(suit.DefaultCtx, map[string]string{ + "id": users[0].GetId().GetValue(), + }) + + if !respSuccess(resp) { + t.Fatal(resp.GetInfo().GetValue()) + } + + assert.Equal(t, 1, int(qresp.Amount.GetValue())) + assert.Equal(t, 1, int(qresp.Size.GetValue())) + + retUsers := qresp.GetUsers()[0] + assert.Equal(t, users[0].GetComment().GetValue(), retUsers.GetComment().GetValue()) + }) + + t.Run("正常删除用户", func(t *testing.T) { + resp := suit.UserServer().DeleteUsers(suit.DefaultCtx, []*apisecurity.User{users[3]}) + + if !respSuccess(resp) { + t.Fatal(resp.GetInfo().GetValue()) + } + + qresp := suit.UserServer().GetUsers(suit.DefaultCtx, map[string]string{ + "id": users[3].GetId().GetValue(), + }) + + if !respSuccess(resp) { + t.Fatal(resp.GetInfo().GetValue()) + } + + assert.Equal(t, 0, int(qresp.Amount.GetValue())) + assert.Equal(t, 0, int(qresp.Size.GetValue())) + }) + + t.Run("正常更新用户Token", func(t *testing.T) { + resp := suit.UserServer().ResetUserToken(suit.DefaultCtx, users[0]) + + if !respSuccess(resp) { + t.Fatal(resp.GetInfo().GetValue()) + } + + _ = suit.CacheMgr().TestUpdate() + + qresp := suit.UserServer().GetUserToken(suit.DefaultCtx, users[0]) + if !respSuccess(qresp) { + t.Fatal(resp.GetInfo().GetValue()) + } + assert.Equal(t, resp.GetUser().GetAuthToken().GetValue(), qresp.GetUser().GetAuthToken().GetValue()) + }) +} diff --git a/bootstrap/server.go b/bootstrap/server.go index f3f190632..117679ddb 100644 --- a/bootstrap/server.go +++ b/bootstrap/server.go @@ -76,7 +76,7 @@ func Start(configFilePath string) { fmt.Printf("[ERROR] config yaml marshal fail\n") return } - _, _ = fmt.Println(string(c)) + fmt.Printf(string(c)) // 初始化日志打印 err = log.Configure(cfg.Bootstrap.Logger) diff --git a/cache/api/funcs.go b/cache/api/funcs.go index 3a4e42ef4..0fa3d26a6 100644 --- a/cache/api/funcs.go +++ b/cache/api/funcs.go @@ -314,7 +314,7 @@ func AppendAuthPolicyPredicate(ctx context.Context, p AuthPolicyPredicate) conte func LoadAuthPolicyPredicates(ctx context.Context) []AuthPolicyPredicate { var predicates []AuthPolicyPredicate - val := ctx.Value(authPolicyPredicateCtxKey{}) + val := ctx.Value(userGroupPredicateCtxKey{}) if val != nil { predicates, _ = val.([]AuthPolicyPredicate) } diff --git a/cache/api/types.go b/cache/api/types.go index 506d63aac..d000f1828 100644 --- a/cache/api/types.go +++ b/cache/api/types.go @@ -172,15 +172,6 @@ type CacheManager interface { type ( // NamespacePredicate . NamespacePredicate func(context.Context, *model.Namespace) bool - // NamespaceArgs - NamespaceArgs struct { - // Filter extend filter params - Filter map[string][]string - // Offset - Offset int - // Limit - Limit int - } // NamespaceCache 命名空间的 Cache 接口 NamespaceCache interface { @@ -193,8 +184,6 @@ type ( GetNamespaceList() []*model.Namespace // GetVisibleNamespaces list target namespace can visible other namespaces GetVisibleNamespaces(namespace string) []*model.Namespace - // Query . - Query(context.Context, *NamespaceArgs) (uint32, []*model.Namespace, error) } ) @@ -257,9 +246,9 @@ type ( GetServicesByFilter(ctx context.Context, serviceFilters *ServiceArgs, instanceFilters *store.InstanceArgs, offset, limit uint32) (uint32, []*model.EnhancedService, error) // ListServices get service list and revision by namespace - ListServices(ctx context.Context, ns string) (string, []*model.Service) + ListServices(ns string) (string, []*model.Service) // ListAllServices get all service and revision - ListAllServices(ctx context.Context) (string, []*model.Service) + ListAllServices() (string, []*model.Service) // ListServiceAlias list service link alias list ListServiceAlias(namespace, name string) []*model.Service // GetAliasFor get alias reference service info @@ -328,10 +317,30 @@ type ( FaultDetectArgs struct { // Filter extend filter params Filter map[string]string + // ID route rule id + ID string + // Name route rule name + Name string + // Service service name + Service string + // Namespace namesapce + Namespace string + ServiceNamespace string + DstNamespace string + DstService string + DstMethod string + // Enable + Enable *bool // Offset Offset uint32 // Limit Limit uint32 + // OrderField Sort field + OrderField string + // OrderType Sorting rules + OrderType string + // Predicates 额外的数据检查 + Predicates []FaultDetectPredicate } // FaultDetectCache fault detect rule cache service @@ -353,10 +362,26 @@ type ( LaneGroupArgs struct { // Filter extend filter params Filter map[string]string + // ID route rule id + ID string + // Name route rule name + Name string + // Service service name + Service string + // Namespace namesapce + Namespace string + // Enable + Enable *bool // Offset Offset uint32 // Limit Limit uint32 + // OrderField Sort field + OrderField string + // OrderType Sorting rules + OrderType string + // Predicates 额外的数据检查 + Predicates []FaultDetectPredicate } // LaneCache . LaneCache interface { @@ -403,6 +428,8 @@ type ( OrderField string // OrderType Sorting rules OrderType string + // Predicates 额外的数据检查 + Predicates []RouteRulePredicate } // RouterRuleIterProc Method definition of routing rules @@ -457,6 +484,8 @@ type ( OrderField string // OrderType Sorting rules OrderType string + // Predicates . + Predicates []RateLimitRulePredicate } // RateLimitIterProc rate limit iter func @@ -502,10 +531,34 @@ type ( CircuitBreakerRuleArgs struct { // Filter extend filter params Filter map[string]string + // ID route rule id + ID string + // Name route rule name + Name string + // Service service name + Service string + // Namespace namesapce + Namespace string + // SourceService source service name + SourceService string + // SourceNamespace source service namespace + SourceNamespace string + // DestinationService destination service name + DestinationService string + // DestinationNamespace destination service namespace + DestinationNamespace string + // Enable + Enable *bool // Offset Offset uint32 // Limit Limit uint32 + // OrderField Sort field + OrderField string + // OrderType Sorting rules + OrderType string + // Predicates 额外的数据检查 + Predicates []CircuitBreakerPredicate } // CircuitBreakerCache circuitBreaker配置的cache接口 CircuitBreakerCache interface { @@ -655,12 +708,10 @@ type ( // StrategyCache is a cache for strategy rules. StrategyCache interface { Cache - // GetPolicyRule 获取策略信息 - GetPolicyRule(id string) *authcommon.StrategyDetail // GetPrincipalPolicies 根据 effect 获取 principal 的策略信息 GetPrincipalPolicies(effect string, p authcommon.Principal) []*authcommon.StrategyDetail // Hint 确认某个 principal 对于资源的访问权限 - Hint(ctx context.Context, p authcommon.Principal, r *authcommon.ResourceEntry) apisecurity.AuthAction + Hint(p authcommon.Principal, r *authcommon.ResourceEntry) apisecurity.AuthAction // Query . Query(context.Context, PolicySearchArgs) (uint32, []*authcommon.StrategyDetail, error) } diff --git a/cache/auth/policy.go b/cache/auth/policy.go index a7262dd65..771d1bb86 100644 --- a/cache/auth/policy.go +++ b/cache/auth/policy.go @@ -50,6 +50,9 @@ type policyCache struct { // principalResources principalResources map[authcommon.PrincipalType]*utils.SyncMap[string, *authcommon.PrincipalResourceContainer] + allowResourceLabels *utils.SyncMap[string, *utils.RefSyncSet[string]] + denyResourceLabels *utils.SyncMap[string, *utils.RefSyncSet[string]] + singleFlight *singleflight.Group } @@ -86,6 +89,8 @@ func (sc *policyCache) initContainers() { authcommon.PrincipalUser: utils.NewSyncMap[string, *authcommon.PrincipalResourceContainer](), authcommon.PrincipalGroup: utils.NewSyncMap[string, *authcommon.PrincipalResourceContainer](), } + sc.allowResourceLabels = utils.NewSyncMap[string, *utils.RefSyncSet[string]]() + sc.denyResourceLabels = utils.NewSyncMap[string, *utils.RefSyncSet[string]]() } func (sc *policyCache) Name() string { @@ -128,8 +133,7 @@ func (sc *policyCache) setStrategys(strategies []*authcommon.StrategyDetail) (ma for index := range strategies { rule := strategies[index] - cacheData := authcommon.NewPolicyDetailCache(rule) - sc.handlePrincipalPolicies(cacheData) + sc.handlePrincipalPolicies(rule) if !rule.Valid { sc.rules.Delete(rule.ID) remove++ @@ -139,16 +143,17 @@ func (sc *policyCache) setStrategys(strategies []*authcommon.StrategyDetail) (ma } else { update++ } - sc.rules.Store(rule.ID, cacheData) + sc.rules.Store(rule.ID, authcommon.NewPolicyDetailCache(rule)) } lastMtime = int64(math.Max(float64(lastMtime), float64(rule.ModifyTime.Unix()))) } + return map[string]time.Time{sc.Name(): time.Unix(lastMtime, 0)}, add, update, remove } // handlePrincipalPolicies -func (sc *policyCache) handlePrincipalPolicies(rule *authcommon.PolicyDetailCache) { +func (sc *policyCache) handlePrincipalPolicies(rule *authcommon.StrategyDetail) { // 计算 uid -> auth rule principals := rule.Principals @@ -187,7 +192,7 @@ func (sc *policyCache) handlePrincipalPolicies(rule *authcommon.PolicyDetailCach } } -func (sc *policyCache) writePrincipalLink(principal authcommon.Principal, rule *authcommon.PolicyDetailCache, del bool) { +func (sc *policyCache) writePrincipalLink(principal authcommon.Principal, rule *authcommon.StrategyDetail, del bool) { linkContainers := sc.allowPolicies[principal.PrincipalType] if rule.Action == apisecurity.AuthAction_DENY.String() { linkContainers = sc.denyPolicies[principal.PrincipalType] @@ -205,45 +210,27 @@ func (sc *policyCache) writePrincipalLink(principal authcommon.Principal, rule * values.Add(rule.ID) } - principalResources, _ := sc.principalResources[principal.PrincipalType].ComputeIfAbsent(principal.PrincipalID, - func(k string) *authcommon.PrincipalResourceContainer { - return authcommon.NewPrincipalResourceContainer() - }) + principalResources, _ := sc.principalResources[principal.PrincipalType].ComputeIfAbsent(principal.PrincipalID, func(k string) *authcommon.PrincipalResourceContainer { + return authcommon.NewPrincipalResourceContainer() + }) - if oldRule, ok := sc.rules.Load(rule.ID); ok { - // 如果 action 不一致,则需要先清理掉之前的 - if oldRule.GetAction() != rule.GetAction() { - for i := range oldRule.Resources { - principalResources.DelResource(oldRule.GetAction(), oldRule.Resources[i]) - } - } else { - // 如果 action 一致,那么需要 diff 出移除的资源,然后移除 - waitRemove := make([]*authcommon.StrategyResource, 0, 8) - for i := range oldRule.Resources { - item := oldRule.Resources[i] - resContainer, ok := rule.ResourceDict[apisecurity.ResourceType(item.ResType)] - if !ok { - waitRemove = append(waitRemove, &item) - continue - } - if ok := resContainer.Contains(item.ResID); !ok { - waitRemove = append(waitRemove, &item) - } - } - for i := range waitRemove { - item := waitRemove[i] - principalResources.DelResource(rule.GetAction(), *item) + if rule.IsDeny() { + for i := range rule.Resources { + item := rule.Resources[i] + if rule.Valid { + principalResources.SaveDenyResource(item) + } else { + principalResources.DelDenyResource(item) } } + return } - - // 处理新的资源 for i := range rule.Resources { item := rule.Resources[i] if rule.Valid { - principalResources.SaveResource(rule.GetAction(), item) + principalResources.SaveAllowResource(item) } else { - principalResources.DelResource(rule.GetAction(), item) + principalResources.DelAllowResource(item) } } } @@ -284,17 +271,8 @@ func (sc *policyCache) GetPrincipalPolicies(effect string, p authcommon.Principa return result } -func (sc *policyCache) GetPolicyRule(id string) *authcommon.StrategyDetail { - strategy, ok := sc.rules.Load(id) - if !ok { - return nil - } - return strategy.StrategyDetail -} - // GetPrincipalResources 返回 principal 的资源信息,返回顺序为 (allow, deny) -func (sc *policyCache) Hint(ctx context.Context, p authcommon.Principal, r *authcommon.ResourceEntry) apisecurity.AuthAction { - // 先比较下资源是否存在于某些鉴权规则中 +func (sc *policyCache) Hint(p authcommon.Principal, r *authcommon.ResourceEntry) apisecurity.AuthAction { resources, ok := sc.principalResources[p.PrincipalType].Load(p.PrincipalID) if !ok { return apisecurity.AuthAction_DENY @@ -303,52 +281,30 @@ func (sc *policyCache) Hint(ctx context.Context, p authcommon.Principal, r *auth if ok { return action } - // 如果没办法从直接的 resource 中判断出来,那就根据资源标签在确认下,注意,这里必须 allMatch 才可以 - if sc.hintLabels(ctx, p, r, sc.GetPrincipalPolicies("deny", p)) { + if sc.hintLabels(p, r, sc.denyResourceLabels) { return apisecurity.AuthAction_DENY } - if sc.hintLabels(ctx, p, r, sc.GetPrincipalPolicies("allow", p)) { + if sc.hintLabels(p, r, sc.allowResourceLabels) { return apisecurity.AuthAction_ALLOW } return apisecurity.AuthAction_DENY } -func (sc *policyCache) hintLabels(ctx context.Context, p authcommon.Principal, r *authcommon.ResourceEntry, - policies []*authcommon.StrategyDetail) bool { - var principalCondition []authcommon.Condition - if val, ok := ctx.Value(authcommon.ContextKeyConditions{}).([]authcommon.Condition); ok { - principalCondition = val - } - - for i := range policies { - item := policies[i] - conditions := item.Conditions - if len(conditions) == 0 { - conditions = principalCondition +func (sc *policyCache) hintLabels(p authcommon.Principal, r *authcommon.ResourceEntry, + containers *utils.SyncMap[string, *utils.RefSyncSet[string]]) bool { + allMatch := true + for k, v := range r.Metadata { + labelVals, ok := sc.denyResourceLabels.Load(k) + if !ok { + allMatch = false } - allMatch := len(conditions) != 0 - for j := range conditions { - condition := conditions[j] - val, ok := r.Metadata[condition.Key] - if !ok { - allMatch = false - break - } - if compareFunc, ok := authcommon.ConditionCompareDict[condition.CompareFunc]; ok { - if allMatch = compareFunc(val, condition.Value); !allMatch { - break - } - } else { - allMatch = false - break - } - } - if allMatch { - return true + allMatch = labelVals.Contains(v) + if !allMatch { + break } } - return false + return allMatch } // Query implements api.StrategyCache. @@ -362,9 +318,9 @@ func (sc *policyCache) Query(ctx context.Context, args types.PolicySearchArgs) ( searchOwner, hasOwner := args.Filters["owner"] searchDefault, hasDefault := args.Filters["default"] searchResType, hasResType := args.Filters["res_type"] - searchResID := args.Filters["res_id"] + searchResID, _ := args.Filters["res_id"] searchPrincipalId, hasPrincipalId := args.Filters["principal_id"] - searchPrincipalType := args.Filters["principal_type"] + searchPrincipalType, _ := args.Filters["principal_type"] predicates := types.LoadAuthPolicyPredicates(ctx) @@ -424,21 +380,32 @@ func (sc *policyCache) Query(ctx context.Context, args types.PolicySearchArgs) ( } rules = append(rules, val.StrategyDetail) }) - sort.Slice(rules, func(i, j int) bool { - return rules[i].ModifyTime.After(rules[j].ModifyTime) - }) + total, ret := sc.toPage(rules, args) return total, ret, nil } func (sc *policyCache) toPage(rules []*authcommon.StrategyDetail, args types.PolicySearchArgs) (uint32, []*authcommon.StrategyDetail) { - total := uint32(len(rules)) - if args.Offset >= total || args.Limit == 0 { - return total, nil + beginIndex := args.Offset + endIndex := beginIndex + args.Limit + totalCount := uint32(len(rules)) + + if totalCount == 0 { + return totalCount, []*authcommon.StrategyDetail{} + } + if beginIndex >= endIndex { + return totalCount, []*authcommon.StrategyDetail{} } - endIdx := args.Offset + args.Limit - if endIdx > total { - endIdx = total + if beginIndex >= totalCount { + return totalCount, []*authcommon.StrategyDetail{} } - return total, rules[args.Offset:endIdx] + if endIndex > totalCount { + endIndex = totalCount + } + + sort.Slice(rules, func(i, j int) bool { + return rules[i].ModifyTime.After(rules[j].ModifyTime) + }) + + return totalCount, rules[beginIndex:endIndex] } diff --git a/cache/auth/role.go b/cache/auth/role.go index 80555412b..f249c5b66 100644 --- a/cache/auth/role.go +++ b/cache/auth/role.go @@ -21,13 +21,12 @@ import ( "context" "time" - "go.uber.org/zap" - "golang.org/x/sync/singleflight" - types "github.com/polarismesh/polaris/cache/api" authcommon "github.com/polarismesh/polaris/common/model/auth" "github.com/polarismesh/polaris/common/utils" "github.com/polarismesh/polaris/store" + "go.uber.org/zap" + "golang.org/x/sync/singleflight" ) // NewRoleCache @@ -147,7 +146,7 @@ func (r *roleCache) dealPrincipalRoles(role *authcommon.Role, isDel bool) { if isDel { users := role.Users for i := range users { - container, _ := r.principalRoles[authcommon.PrincipalUser].ComputeIfAbsent(users[i].PrincipalID, + container, _ := r.principalRoles[authcommon.PrincipalUser].ComputeIfAbsent(users[i].SelfID(), func(k string) *utils.SyncSet[string] { return utils.NewSyncSet[string]() }) @@ -155,7 +154,7 @@ func (r *roleCache) dealPrincipalRoles(role *authcommon.Role, isDel bool) { } groups := role.UserGroups for i := range groups { - container, _ := r.principalRoles[authcommon.PrincipalGroup].ComputeIfAbsent(groups[i].PrincipalID, + container, _ := r.principalRoles[authcommon.PrincipalGroup].ComputeIfAbsent(groups[i].SelfID(), func(k string) *utils.SyncSet[string] { return utils.NewSyncSet[string]() }) @@ -165,7 +164,7 @@ func (r *roleCache) dealPrincipalRoles(role *authcommon.Role, isDel bool) { } users := role.Users for i := range users { - container, _ := r.principalRoles[authcommon.PrincipalUser].ComputeIfAbsent(users[i].PrincipalID, + container, _ := r.principalRoles[authcommon.PrincipalUser].ComputeIfAbsent(users[i].SelfID(), func(k string) *utils.SyncSet[string] { return utils.NewSyncSet[string]() }) @@ -173,7 +172,7 @@ func (r *roleCache) dealPrincipalRoles(role *authcommon.Role, isDel bool) { } groups := role.UserGroups for i := range groups { - container, _ := r.principalRoles[authcommon.PrincipalGroup].ComputeIfAbsent(groups[i].PrincipalID, + container, _ := r.principalRoles[authcommon.PrincipalGroup].ComputeIfAbsent(groups[i].SelfID(), func(k string) *utils.SyncSet[string] { return utils.NewSyncSet[string]() }) @@ -227,8 +226,8 @@ func (r *roleCache) toPage(total uint32, roles []*authcommon.Role, args types.Ro if args.Limit == 0 { return total, roles } - start := args.Limit * args.Offset - end := args.Limit * (args.Offset + 1) + start := args.Limit * (args.Offset - 1) + end := args.Limit * args.Offset if start > total { return total, nil } @@ -240,11 +239,7 @@ func (r *roleCache) toPage(total uint32, roles []*authcommon.Role, args types.Ro // GetPrincipalRoles implements api.RoleCache. func (r *roleCache) GetPrincipalRoles(p authcommon.Principal) []*authcommon.Role { - roleContainers, ok := r.principalRoles[p.PrincipalType] - if !ok { - return nil - } - containers, ok := roleContainers.Load(p.PrincipalID) + containers, ok := r.principalRoles[p.PrincipalType].Load(p.PrincipalID) if !ok { return nil } diff --git a/cache/auth/user.go b/cache/auth/user.go index c8f989d04..1e7084875 100644 --- a/cache/auth/user.go +++ b/cache/auth/user.go @@ -21,7 +21,6 @@ import ( "context" "fmt" "math" - "sort" "sync/atomic" "time" @@ -377,7 +376,7 @@ func (uc *userCache) QueryUsers(ctx context.Context, args types.UserSearchArgs) if hasId && searchId != key { return } - if hasOwner && (val.Owner != searchOwner && val.ID != searchOwner) { + if hasOwner && val.Owner != searchOwner { return } if hasName && !utils.IsWildMatch(val.Name, searchName) { @@ -394,31 +393,28 @@ func (uc *userCache) QueryUsers(ctx context.Context, args types.UserSearchArgs) result = append(result, val) }) - sort.Slice(result, func(i, j int) bool { - return result[i].ModifyTime.After(result[j].ModifyTime) - }) total, ret := uc.listUsersPage(result, args) return total, ret, nil } func (uc *userCache) listUsersPage(users []*authcommon.User, args types.UserSearchArgs) (uint32, []*authcommon.User) { total := uint32(len(users)) - if args.Offset >= total || args.Limit == 0 { + if args.Limit == 0 { + return total, nil + } + start := args.Limit * (args.Offset - 1) + end := args.Limit * args.Offset + if start > total { return total, nil } - endIdx := args.Offset + args.Limit - if endIdx > total { - endIdx = total + if end > total { + end = total } - return total, users[args.Offset:endIdx] + return total, users[start:end] } // QueryUserGroups . func (uc *userCache) QueryUserGroups(ctx context.Context, args types.UserGroupSearchArgs) (uint32, []*authcommon.UserGroupDetail, error) { - if err := uc.Update(); err != nil { - return 0, nil, err - } - searchId, hasId := args.Filters["id"] searchName, hasName := args.Filters["name"] searchOwner, hasOwner := args.Filters["owner"] @@ -464,14 +460,13 @@ func (uc *userCache) QueryUserGroups(ctx context.Context, args types.UserGroupSe return total, ret, nil } -func (uc *userCache) listUserGroupsPage(groups []*authcommon.UserGroupDetail, - args types.UserGroupSearchArgs) (uint32, []*authcommon.UserGroupDetail) { +func (uc *userCache) listUserGroupsPage(groups []*authcommon.UserGroupDetail, args types.UserGroupSearchArgs) (uint32, []*authcommon.UserGroupDetail) { total := uint32(len(groups)) if args.Limit == 0 { return total, nil } - start := args.Limit * args.Offset - end := args.Limit * (args.Offset + 1) + start := args.Limit * (args.Offset - 1) + end := args.Limit * args.Offset if start > total { return total, nil } diff --git a/cache/default.go b/cache/default.go index bde20521e..d89e03fce 100644 --- a/cache/default.go +++ b/cache/default.go @@ -110,7 +110,6 @@ func newCacheManager(ctx context.Context, cacheOpt *Config, storage store.Store) mgr.RegisterCacher(types.CacheRole, cacheauth.NewRoleCache(storage, mgr)) // 北极星SDK Client mgr.RegisterCacher(types.CacheClient, cacheclient.NewClientCache(storage, mgr)) - // 灰度规则 mgr.RegisterCacher(types.CacheGray, cachegray.NewGrayCache(storage, mgr)) if len(mgr.caches) != int(types.CacheLast) { diff --git a/cache/mock/cache_mock.go b/cache/mock/cache_mock.go index 6a8441864..330596ad5 100644 --- a/cache/mock/cache_mock.go +++ b/cache/mock/cache_mock.go @@ -581,22 +581,6 @@ func (mr *MockNamespaceCacheMockRecorder) Name() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Name", reflect.TypeOf((*MockNamespaceCache)(nil).Name)) } -// Query mocks base method. -func (m *MockNamespaceCache) Query(arg0 context.Context, arg1 *api.NamespaceArgs) (uint32, []*model.Namespace, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Query", arg0, arg1) - ret0, _ := ret[0].(uint32) - ret1, _ := ret[1].([]*model.Namespace) - ret2, _ := ret[2].(error) - return ret0, ret1, ret2 -} - -// Query indicates an expected call of Query. -func (mr *MockNamespaceCacheMockRecorder) Query(arg0, arg1 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Query", reflect.TypeOf((*MockNamespaceCache)(nil).Query), arg0, arg1) -} - // Update mocks base method. func (m *MockNamespaceCache) Update() error { m.ctrl.T.Helper() @@ -845,18 +829,18 @@ func (mr *MockServiceCacheMockRecorder) IteratorServices(iterProc interface{}) * } // ListAllServices mocks base method. -func (m *MockServiceCache) ListAllServices(ctx context.Context) (string, []*model.Service) { +func (m *MockServiceCache) ListAllServices() (string, []*model.Service) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ListAllServices", ctx) + ret := m.ctrl.Call(m, "ListAllServices") ret0, _ := ret[0].(string) ret1, _ := ret[1].([]*model.Service) return ret0, ret1 } // ListAllServices indicates an expected call of ListAllServices. -func (mr *MockServiceCacheMockRecorder) ListAllServices(ctx interface{}) *gomock.Call { +func (mr *MockServiceCacheMockRecorder) ListAllServices() *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListAllServices", reflect.TypeOf((*MockServiceCache)(nil).ListAllServices), ctx) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListAllServices", reflect.TypeOf((*MockServiceCache)(nil).ListAllServices)) } // ListServiceAlias mocks base method. @@ -874,18 +858,18 @@ func (mr *MockServiceCacheMockRecorder) ListServiceAlias(namespace, name interfa } // ListServices mocks base method. -func (m *MockServiceCache) ListServices(ctx context.Context, ns string) (string, []*model.Service) { +func (m *MockServiceCache) ListServices(ns string) (string, []*model.Service) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ListServices", ctx, ns) + ret := m.ctrl.Call(m, "ListServices", ns) ret0, _ := ret[0].(string) ret1, _ := ret[1].([]*model.Service) return ret0, ret1 } // ListServices indicates an expected call of ListServices. -func (mr *MockServiceCacheMockRecorder) ListServices(ctx, ns interface{}) *gomock.Call { +func (mr *MockServiceCacheMockRecorder) ListServices(ns interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListServices", reflect.TypeOf((*MockServiceCache)(nil).ListServices), ctx, ns) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListServices", reflect.TypeOf((*MockServiceCache)(nil).ListServices), ns) } // Name mocks base method. @@ -2890,20 +2874,6 @@ func (mr *MockStrategyCacheMockRecorder) Close() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockStrategyCache)(nil).Close)) } -// GetPolicyRule mocks base method. -func (m *MockStrategyCache) GetPolicyRule(id string) *auth.StrategyDetail { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetPolicyRule", id) - ret0, _ := ret[0].(*auth.StrategyDetail) - return ret0 -} - -// GetPolicyRule indicates an expected call of GetPolicyRule. -func (mr *MockStrategyCacheMockRecorder) GetPolicyRule(id interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPolicyRule", reflect.TypeOf((*MockStrategyCache)(nil).GetPolicyRule), id) -} - // GetPrincipalPolicies mocks base method. func (m *MockStrategyCache) GetPrincipalPolicies(effect string, p auth.Principal) []*auth.StrategyDetail { m.ctrl.T.Helper() diff --git a/cache/namespace/namespace.go b/cache/namespace/namespace.go index 41c8b931f..1447ddb36 100644 --- a/cache/namespace/namespace.go +++ b/cache/namespace/namespace.go @@ -18,9 +18,7 @@ package namespace import ( - "context" "math" - "sort" "time" "go.uber.org/zap" @@ -65,7 +63,9 @@ func (nsCache *namespaceCache) Initialize(c map[string]interface{}) error { // Update func (nsCache *namespaceCache) Update() error { // 多个线程竞争,只有一个线程进行更新 - err, _ := nsCache.singleUpdate() + _, err, _ := nsCache.updater.Do(nsCache.Name(), func() (interface{}, error) { + return nil, nsCache.DoCacheUpdate(nsCache.Name(), nsCache.realUpdate) + }) return err } @@ -83,6 +83,7 @@ func (nsCache *namespaceCache) realUpdate() (map[string]time.Time, int64, error) } func (nsCache *namespaceCache) setNamespaces(nsSlice []*model.Namespace) map[string]time.Time { + lastMtime := nsCache.LastMtime(nsCache.Name()).Unix() for index := range nsSlice { @@ -199,96 +200,3 @@ func (nsCache *namespaceCache) GetNamespaceList() []*model.Namespace { return nsArr } - -// forceQueryUpdate 为了确保读取的数据是最新的,这里需要做一个强制 update 的动作进行数据读取处理 -func (nsCache *namespaceCache) forceQueryUpdate() error { - err, shared := nsCache.singleUpdate() - // shared == true,表示当前已经有正在 update 执行的任务,这个任务不一定能够读取到最新的数据 - // 为了避免读取到脏数据,在发起一次 singleUpdate - if shared { - log.Debug("[Cache][Namespace] force query update from store") - err, _ = nsCache.singleUpdate() - } - return err -} - -func (nsCache *namespaceCache) singleUpdate() (error, bool) { - // 多个线程竞争,只有一个线程进行更新 - _, err, shared := nsCache.updater.Do(nsCache.Name(), func() (interface{}, error) { - return nil, nsCache.DoCacheUpdate(nsCache.Name(), nsCache.realUpdate) - }) - return err, shared -} - -func (nsCache *namespaceCache) Query(ctx context.Context, args *types.NamespaceArgs) (uint32, []*model.Namespace, error) { - if err := nsCache.forceQueryUpdate(); err != nil { - return 0, nil, err - } - - ret := make([]*model.Namespace, 0, 32) - - predicates := types.LoadNamespacePredicates(ctx) - - searchName, hasName := args.Filter["name"] - searchOwner, hasOwner := args.Filter["owner"] - - nsCache.ids.ReadRange(func(key string, val *model.Namespace) { - for i := range predicates { - if !predicates[i](ctx, val) { - return - } - } - - if hasName { - matchOne := false - for i := range searchName { - if utils.IsWildMatch(val.Name, searchName[i]) { - matchOne = true - break - } - } - // 如果没有匹配到,直接返回 - if !matchOne { - return - } - } - - if hasOwner { - matchOne := false - for i := range searchOwner { - if utils.IsWildMatch(val.Owner, searchOwner[i]) { - matchOne = true - break - } - } - // 如果没有匹配到,直接返回 - if !matchOne { - return - } - } - - ret = append(ret, val) - }) - - sort.Slice(ret, func(i, j int) bool { - return ret[i].ModifyTime.After(ret[j].ModifyTime) - }) - - total, ret := nsCache.toPage(len(ret), ret, args) - return uint32(total), ret, nil -} - -func (c *namespaceCache) toPage(total int, items []*model.Namespace, - args *types.NamespaceArgs) (int, []*model.Namespace) { - if len(items) == 0 { - return 0, []*model.Namespace{} - } - if args.Limit == 0 { - return total, items - } - endIdx := args.Offset + args.Limit - if endIdx > total { - endIdx = total - } - return total, items[args.Offset:endIdx] -} diff --git a/cache/service/circuitbreaker.go b/cache/service/circuitbreaker.go index 03699fa8e..ff02db390 100644 --- a/cache/service/circuitbreaker.go +++ b/cache/service/circuitbreaker.go @@ -22,8 +22,6 @@ import ( "crypto/sha1" "fmt" "sort" - "strconv" - "strings" "sync" "time" @@ -394,155 +392,13 @@ func (c *circuitBreakerCache) GetCircuitBreakerCount() int { return len(names) } -var ( - ignoreCircuitBreakerRuleFilter = map[string]struct{}{ - "brief": {}, - "service": {}, - "serviceNamespace": {}, - "exactName": {}, - "excludeId": {}, - } - - cbBlurSearchFields = map[string]func(*model.CircuitBreakerRule) string{ - "name": func(cbr *model.CircuitBreakerRule) string { - return cbr.Name - }, - "description": func(cbr *model.CircuitBreakerRule) string { - return cbr.Description - }, - "srcservice": func(cbr *model.CircuitBreakerRule) string { - return cbr.SrcService - }, - "dstservice": func(cbr *model.CircuitBreakerRule) string { - return cbr.DstService - }, - "dstmethod": func(cbr *model.CircuitBreakerRule) string { - return cbr.DstMethod - }, - } - - circuitBreakerSort = map[string]func(asc bool, a, b *model.CircuitBreakerRule) bool{ - "mtime": func(asc bool, a, b *model.CircuitBreakerRule) bool { - ret := a.ModifyTime.Before(b.ModifyTime) - return ret && asc - }, - "id": func(asc bool, a, b *model.CircuitBreakerRule) bool { - ret := a.ID < b.ID - return ret && asc - }, - "name": func(asc bool, a, b *model.CircuitBreakerRule) bool { - ret := a.Name < b.Name - return ret && asc - }, - } -) - // Query implements api.CircuitBreakerCache. -func (c *circuitBreakerCache) Query(ctx context.Context, args *types.CircuitBreakerRuleArgs) (uint32, []*model.CircuitBreakerRule, error) { - if err := c.Update(); err != nil { - return 0, nil, err - } - - predicates := types.LoadCircuitBreakerRulePredicates(ctx) - - searchSvc, hasSvc := args.Filter["service"] - searchNs, hasSvcNs := args.Filter["serviceNamespace"] - exactNameValue, hasExactName := args.Filter["exactName"] - excludeIdValue, hasExcludeId := args.Filter["excludeId"] - - lowerFilter := make(map[string]string, len(args.Filter)) - for k, v := range args.Filter { - if _, ok := ignoreCircuitBreakerRuleFilter[k]; ok { - continue - } - lowerFilter[strings.ToLower(k)] = v - } - - results := make([]*model.CircuitBreakerRule, 0, 32) - c.rules.ReadRange(func(key string, val *model.CircuitBreakerRule) { - if hasSvcNs { - srcNsValue := val.SrcNamespace - dstNsValue := val.DstNamespace - if !((srcNsValue == "*" || srcNsValue == searchNs) || (dstNsValue == "*" || dstNsValue == searchNs)) { - return - } - } - if hasSvc { - srcSvcValue := val.SrcService - dstSvcValue := val.DstService - if !((srcSvcValue == searchSvc || srcSvcValue == "*") || (dstSvcValue == searchSvc || dstSvcValue == "*")) { - return - } - } - if hasExactName && exactNameValue != val.Name { - return - } - if hasExcludeId && excludeIdValue != val.ID { - return - } - for fieldKey, filterValue := range lowerFilter { - getter, isBlur := cbBlurSearchFields[fieldKey] - if isBlur { - if utils.IsWildMatch(getter(val), filterValue) { - return - } - } else if fieldKey == "enable" { - if filterValue != strconv.FormatBool(val.Enable) { - return - } - } else if fieldKey == "level" { - levels := strings.Split(filterValue, ",") - var inLevel = false - for _, level := range levels { - levelInt, _ := strconv.Atoi(level) - if int64(levelInt) == int64(val.Level) { - inLevel = true - break - } - } - if !inLevel { - return - } - } else { - // FIXME 暂时不知道还有什么字段查询需要适配,等待自测验证 - } - } - for i := range predicates { - if !predicates[i](ctx, val) { - return - } - } - - results = append(results, val) - }) - - sortFunc, ok := circuitBreakerSort[args.Filter["order_field"]] - if !ok { - sortFunc = circuitBreakerSort["mtime"] - } - asc := "asc" == strings.ToLower(args.Filter["order_type"]) - sort.Slice(results, func(i, j int) bool { - return sortFunc(asc, results[i], results[j]) - }) - - total, ret := c.toPage(uint32(len(results)), results, args) - return total, ret, nil -} - -func (c *circuitBreakerCache) toPage(total uint32, items []*model.CircuitBreakerRule, - args *types.CircuitBreakerRuleArgs) (uint32, []*model.CircuitBreakerRule) { - if args.Limit == 0 { - return total, items - } - endIdx := args.Offset + args.Limit - if endIdx > total { - endIdx = total - } - return total, items[args.Offset:endIdx] +func (c *circuitBreakerCache) Query(context.Context, *types.CircuitBreakerRuleArgs) (uint32, []*model.CircuitBreakerRule, error) { + panic("unimplemented") } -// GetRule implements api.CircuitBreakerCache. -func (c *circuitBreakerCache) GetRule(id string) *model.CircuitBreakerRule { - rule, _ := c.rules.Load(id) +// GetRule implements api.FaultDetectCache. +func (f *circuitBreakerCache) GetRule(id string) *model.CircuitBreakerRule { + rule, _ := f.rules.Load(id) return rule } diff --git a/cache/service/faultdetect.go b/cache/service/faultdetect.go index 8e59792d6..92a88d9b7 100644 --- a/cache/service/faultdetect.go +++ b/cache/service/faultdetect.go @@ -22,7 +22,6 @@ import ( "crypto/sha1" "fmt" "sort" - "strings" "sync" "time" @@ -351,125 +350,9 @@ func (f *faultDetectCache) GetFaultDetectRuleCount(fun func(k, v interface{}) bo } } -var ( - ignoreFaultDetectRuleFilter = map[string]struct{}{ - "brief": {}, - "service": {}, - "serviceNamespace": {}, - "exactName": {}, - "excludeId": {}, - } - - fdBlurSearchFields = map[string]func(*model.FaultDetectRule) string{ - "name": func(cbr *model.FaultDetectRule) string { - return cbr.Name - }, - "description": func(cbr *model.FaultDetectRule) string { - return cbr.Description - }, - "dstservice": func(cbr *model.FaultDetectRule) string { - return cbr.DstService - }, - "dstmethod": func(cbr *model.FaultDetectRule) string { - return cbr.DstMethod - }, - } - - faultDetectSort = map[string]func(asc bool, a, b *model.FaultDetectRule) bool{ - "mtime": func(asc bool, a, b *model.FaultDetectRule) bool { - ret := a.ModifyTime.Before(b.ModifyTime) - return ret && asc - }, - "id": func(asc bool, a, b *model.FaultDetectRule) bool { - ret := a.ID < b.ID - return ret && asc - }, - "name": func(asc bool, a, b *model.FaultDetectRule) bool { - ret := a.Name < b.Name - return ret && asc - }, - } -) - // Query implements api.FaultDetectCache. -func (f *faultDetectCache) Query(ctx context.Context, args *types.FaultDetectArgs) (uint32, []*model.FaultDetectRule, error) { - if err := f.Update(); err != nil { - return 0, nil, err - } - - results := make([]*model.FaultDetectRule, 0, 32) - - predicates := types.LoadFaultDetectRulePredicates(ctx) - - searchSvc, hasSvc := args.Filter["service"] - searchNs, hasSvcNs := args.Filter["serviceNamespace"] - exactNameValue, hasExactName := args.Filter["exactName"] - excludeIdValue, hasExcludeId := args.Filter["excludeId"] - - lowerFilter := make(map[string]string, len(args.Filter)) - for k, v := range args.Filter { - if _, ok := ignoreCircuitBreakerRuleFilter[k]; ok { - continue - } - lowerFilter[strings.ToLower(k)] = v - } - - f.rules.ReadRange(func(key string, val *model.FaultDetectRule) { - if hasSvc && hasSvcNs { - dstServiceValue := val.DstService - dstNamespaceValue := val.DstNamespace - if !(dstServiceValue == searchSvc && dstNamespaceValue == searchNs) { - return - } - } - if hasExactName && exactNameValue != val.Name { - return - } - if hasExcludeId && excludeIdValue != val.ID { - return - } - for fieldKey, filterValue := range lowerFilter { - getter, isBlur := fdBlurSearchFields[fieldKey] - if isBlur { - if utils.IsWildMatch(getter(val), filterValue) { - return - } - } else { - // FIXME 暂时不知道还有什么字段查询需要适配,等待自测验证 - } - } - for i := range predicates { - if !predicates[i](ctx, val) { - return - } - } - - results = append(results, val) - }) - - sortFunc, ok := faultDetectSort[args.Filter["order_field"]] - if !ok { - sortFunc = faultDetectSort["mtime"] - } - asc := "asc" == strings.ToLower(args.Filter["order_type"]) - sort.Slice(results, func(i, j int) bool { - return sortFunc(asc, results[i], results[j]) - }) - - total, ret := f.toPage(uint32(len(results)), results, args) - return total, ret, nil -} - -func (f *faultDetectCache) toPage(total uint32, items []*model.FaultDetectRule, - args *types.FaultDetectArgs) (uint32, []*model.FaultDetectRule) { - if args.Limit == 0 { - return total, items - } - endIdx := args.Offset + args.Limit - if endIdx > total { - endIdx = total - } - return total, items[args.Offset:endIdx] +func (f *faultDetectCache) Query(context.Context, *types.FaultDetectArgs) (uint32, []*model.FaultDetectRule, error) { + panic("unimplemented") } // GetRule implements api.FaultDetectCache. diff --git a/cache/service/lane.go b/cache/service/lane.go index cb54cb805..7d5b31973 100644 --- a/cache/service/lane.go +++ b/cache/service/lane.go @@ -19,8 +19,6 @@ package service import ( "context" - "sort" - "strings" "time" "github.com/golang/protobuf/proto" @@ -147,8 +145,12 @@ func (lc *LaneCache) processLaneRuleUpsert(old, item *model.LaneGroupProto, affe waitDelServices[ns][svc] = struct{}{} } removeServiceIfExist := func(ns, svc string) { - waitDelServices[ns] = map[string]struct{}{} - delete(waitDelServices[ns], svc) + if _, ok := waitDelServices[ns]; !ok { + waitDelServices[ns] = map[string]struct{}{} + } + if _, ok := waitDelServices[ns][svc]; ok { + delete(waitDelServices[ns], svc) + } } handle := func(rule *model.LaneGroupProto, serviceOp func(ns, svc string), ruleOp func(string, string, *model.LaneGroupProto)) { @@ -354,80 +356,13 @@ func anyToSelector(data *anypb.Any, msg proto.Message) error { return nil } -var ( - laneGroupSort = map[string]func(asc bool, a, b *model.LaneGroupProto) bool{ - "mtime": func(asc bool, a, b *model.LaneGroupProto) bool { - ret := a.ModifyTime.Before(b.ModifyTime) - return ret && asc - }, - "id": func(asc bool, a, b *model.LaneGroupProto) bool { - ret := a.ID < b.ID - return ret && asc - }, - "name": func(asc bool, a, b *model.LaneGroupProto) bool { - ret := a.Name < b.Name - return ret && asc - }, - } -) - // Query implements api.LaneCache. -func (lc *LaneCache) Query(ctx context.Context, args *types.LaneGroupArgs) (uint32, []*model.LaneGroupProto, error) { - if err := lc.Update(); err != nil { - return 0, nil, err - } - - predicates := types.LoadLaneRulePredicates(ctx) - - searchName, hasName := args.Filter["name"] - searchId, hasId := args.Filter["id"] - - results := make([]*model.LaneGroupProto, 0, 32) - - lc.rules.ReadRange(func(key string, val *model.LaneGroupProto) { - if hasName && !utils.IsWildMatch(val.Name, searchName) { - return - } - if hasId && val.ID != searchId { - return - } - - for i := range predicates { - if !predicates[i](ctx, val) { - return - } - } - - results = append(results, val) - }) - - sortFunc, ok := laneGroupSort[args.Filter["order_field"]] - if !ok { - sortFunc = laneGroupSort["mtime"] - } - asc := "asc" == strings.ToLower(args.Filter["order_type"]) - sort.Slice(results, func(i, j int) bool { - return sortFunc(asc, results[i], results[j]) - }) - - total, ret := lc.toPage(uint32(len(results)), results, args) - return total, ret, nil -} - -func (lc *LaneCache) toPage(total uint32, items []*model.LaneGroupProto, - args *types.LaneGroupArgs) (uint32, []*model.LaneGroupProto) { - if args.Limit == 0 { - return total, items - } - endIdx := args.Offset + args.Limit - if endIdx > total { - endIdx = total - } - return total, items[args.Offset:endIdx] +func (lc *LaneCache) Query(context.Context, *types.LaneGroupArgs) (uint32, []*model.LaneGroupProto, error) { + panic("unimplemented") } // GetRule implements api.LaneCache. -func (lc *LaneCache) GetRule(id string) *model.LaneGroup { - rule, _ := lc.rules.Load(id) +func (f *LaneCache) GetRule(id string) *model.LaneGroup { + rule, _ := f.rules.Load(id) return rule.LaneGroup } diff --git a/cache/service/ratelimit_query.go b/cache/service/ratelimit_query.go index 6673aff4e..53f643073 100644 --- a/cache/service/ratelimit_query.go +++ b/cache/service/ratelimit_query.go @@ -27,14 +27,20 @@ import ( "github.com/polarismesh/polaris/common/utils" ) +// forceUpdate 更新配置 +func (rlc *rateLimitCache) forceUpdate() error { + if err := rlc.Update(); err != nil { + return err + } + return nil +} + // QueryRateLimitRules func (rlc *rateLimitCache) QueryRateLimitRules(ctx context.Context, args types.RateLimitRuleArgs) (uint32, []*model.RateLimit, error) { - if err := rlc.Update(); err != nil { + if err := rlc.forceUpdate(); err != nil { return 0, nil, err } - predicates := types.LoadRatelimitRulePredicates(ctx) - hasService := len(args.Service) != 0 hasNamespace := len(args.Namespace) != 0 @@ -59,13 +65,6 @@ func (rlc *rateLimitCache) QueryRateLimitRules(ctx context.Context, args types.R if args.Disable != nil && *args.Disable != rule.Disable { return } - - for i := range predicates { - if !predicates[i](ctx, rule) { - return - } - } - res = append(res, rule) } rlc.IteratorRateLimit(process) diff --git a/cache/service/router_rule.go b/cache/service/router_rule.go index b146aec70..1529494bf 100644 --- a/cache/service/router_rule.go +++ b/cache/service/router_rule.go @@ -41,9 +41,13 @@ type ( container *RouteRuleContainer - lastMtime time.Time + lastMtimeV1 time.Time + lastMtimeV2 time.Time singleFlight singleflight.Group + + // waitDealV1RuleIds Records need to be converted from V1 to V2 routing rules ID + waitDealV1RuleIds *utils.SyncMap[string, *model.RoutingConfig] } ) @@ -57,7 +61,9 @@ func NewRouteRuleCache(s store.Store, cacheMgr types.CacheManager) types.Routing // initialize The function of implementing the cache interface func (rc *RouteRuleCache) Initialize(_ map[string]interface{}) error { - rc.lastMtime = time.Unix(0, 0) + rc.lastMtimeV1 = time.Unix(0, 0) + rc.lastMtimeV2 = time.Unix(0, 0) + rc.waitDealV1RuleIds = utils.NewSyncMap[string, *model.RoutingConfig]() rc.container = newRouteRuleContainer() rc.serviceCache = rc.BaseCache.CacheMgr.GetCacher(types.CacheService).(*serviceCache) return nil @@ -74,6 +80,12 @@ func (rc *RouteRuleCache) Update() error { // update The function of implementing the cache interface func (rc *RouteRuleCache) realUpdate() (map[string]time.Time, int64, error) { + outV1, err := rc.storage.GetRoutingConfigsForCache(rc.LastFetchTime(), rc.IsFirstUpdate()) + if err != nil { + log.Errorf("[Cache] routing config v1 cache get from store err: %s", err.Error()) + return nil, -1, err + } + outV2, err := rc.storage.GetRoutingConfigsV2ForCache(rc.LastFetchTime(), rc.IsFirstUpdate()) if err != nil { log.Errorf("[Cache] routing config v2 cache get from store err: %s", err.Error()) @@ -81,16 +93,19 @@ func (rc *RouteRuleCache) realUpdate() (map[string]time.Time, int64, error) { } lastMtimes := map[string]time.Time{} - rc.setRouterRules(lastMtimes, outV2) + rc.setRoutingConfigV1(lastMtimes, outV1) + rc.setRoutingConfigV2(lastMtimes, outV2) rc.container.reload() - return lastMtimes, int64(len(outV2)), err + return lastMtimes, int64(len(outV1) + len(outV2)), err } // Clear The function of implementing the cache interface func (rc *RouteRuleCache) Clear() error { rc.BaseCache.Clear() + rc.waitDealV1RuleIds = utils.NewSyncMap[string, *model.RoutingConfig]() rc.container = newRouteRuleContainer() - rc.lastMtime = time.Unix(0, 0) + rc.lastMtimeV1 = time.Unix(0, 0) + rc.lastMtimeV2 = time.Unix(0, 0) return nil } @@ -217,8 +232,50 @@ func (rc *RouteRuleCache) GetRule(id string) *model.ExtendRouterConfig { return rule } -// setRouterRules Store V2 Router Caches -func (rc *RouteRuleCache) setRouterRules(lastMtimes map[string]time.Time, cs []*model.RouterConfig) { +// setRoutingConfigV1 Update the data of the store to the cache and convert to v2 model +func (rc *RouteRuleCache) setRoutingConfigV1(lastMtimes map[string]time.Time, cs []*model.RoutingConfig) { + if len(cs) == 0 { + return + } + lastMtimeV1 := rc.LastMtime(rc.Name()).Unix() + for _, entry := range cs { + if entry.ID == "" { + continue + } + if entry.ModifyTime.Unix() > lastMtimeV1 { + lastMtimeV1 = entry.ModifyTime.Unix() + } + if !entry.Valid { + // Delete the cache converted to V2 + rc.container.deleteV1(entry.ID) + continue + } + rc.waitDealV1RuleIds.Store(entry.ID, entry) + } + + rc.waitDealV1RuleIds.Range(func(key string, val *model.RoutingConfig) { + // Save to the new V2 cache + ok, rules, err := rc.convertV1toV2(val) + if err != nil { + log.Warn("[Cache] routing parse v1 => v2 fail, will try again next", + zap.String("rule-id", val.ID), zap.Error(err)) + return + } + if !ok { + log.Warn("[Cache] routing parse v1 => v2 is nil, will try again next", zap.String("rule-id", val.ID)) + return + } + if ok && len(rules) != 0 { + rc.waitDealV1RuleIds.Delete(key) + rc.container.saveV1(val, rules) + } + }) + lastMtimes[rc.Name()] = time.Unix(lastMtimeV1, 0) + log.Infof("[Cache] convert routing parse v1 => v2 count : %d", rc.container.convertV2Size()) +} + +// setRoutingConfigV2 Store V2 Router Caches +func (rc *RouteRuleCache) setRoutingConfigV2(lastMtimes map[string]time.Time, cs []*model.RouterConfig) { if len(cs) == 0 { return } diff --git a/cache/service/router_rule_bucket.go b/cache/service/router_rule_bucket.go index f60a41eb4..5b32ea53f 100644 --- a/cache/service/router_rule_bucket.go +++ b/cache/service/router_rule_bucket.go @@ -38,7 +38,7 @@ type ServiceWithRouterRules struct { rules map[string]*model.ExtendRouterConfig revision string - customv1RuleRef *utils.AtomicValue[*apitraffic.Routing] + customv1Rules *apitraffic.Routing } func NewServiceWithRouterRules(svcKey model.ServiceKey, direction model.TrafficDirection) *ServiceWithRouterRules { @@ -51,20 +51,19 @@ func NewServiceWithRouterRules(svcKey model.ServiceKey, direction model.TrafficD // AddRouterRule 添加路由规则,注意,这里只会保留处于 Enable 状态的路由规则 func (s *ServiceWithRouterRules) AddRouterRule(rule *model.ExtendRouterConfig) { + if !rule.Enable { + return + } if rule.GetRoutingPolicy() == apitraffic.RoutingPolicy_RulePolicy { - s.customv1RuleRef = utils.NewAtomicValue[*apitraffic.Routing](&apitraffic.Routing{ + s.customv1Rules = &apitraffic.Routing{ Inbounds: []*apitraffic.Route{}, Outbounds: []*apitraffic.Route{}, - }) + } } s.mutex.Lock() defer s.mutex.Unlock() - if !rule.Enable { - delete(s.rules, rule.ID) - } else { - s.rules[rule.ID] = rule - } + s.rules[rule.ID] = rule } func (s *ServiceWithRouterRules) DelRouterRule(id string) { @@ -93,13 +92,6 @@ func (s *ServiceWithRouterRules) CountRouterRules() int { return len(s.rules) } -func (s *ServiceWithRouterRules) GetRouteRuleV1() *apitraffic.Routing { - if !s.customv1RuleRef.HasValue() { - return nil - } - return s.customv1RuleRef.Load() -} - func (s *ServiceWithRouterRules) Clear() { s.mutex.Lock() defer s.mutex.Unlock() @@ -143,7 +135,7 @@ func (s *ServiceWithRouterRules) reloadRevision() { } func (s *ServiceWithRouterRules) reloadV1Rules() { - if !s.customv1RuleRef.HasValue() { + if s.customv1Rules == nil { return } @@ -159,21 +151,19 @@ func (s *ServiceWithRouterRules) reloadV1Rules() { routes := make([]*apitraffic.Route, 0, 32) for i := range rules { - if rules[i].Policy != apitraffic.RoutingPolicy_RulePolicy.String() { + if rules[i].Priority != uint32(apitraffic.RoutingPolicy_RulePolicy) { continue } routes = append(routes, model.BuildRoutes(rules[i], s.direction)...) } - customv1Rules := &apitraffic.Routing{} + s.customv1Rules = &apitraffic.Routing{} switch s.direction { case model.TrafficDirection_INBOUND: - customv1Rules.Inbounds = routes + s.customv1Rules.Inbounds = routes case model.TrafficDirection_OUTBOUND: - customv1Rules.Outbounds = routes + s.customv1Rules.Outbounds = routes } - - s.customv1RuleRef.Store(customv1Rules) } func newClientRouteRuleContainer(direction model.TrafficDirection) *ClientRouteRuleContainer { @@ -186,9 +176,6 @@ func newClientRouteRuleContainer(direction model.TrafficDirection) *ClientRouteR } type ClientRouteRuleContainer struct { - // lock . - lock sync.RWMutex - direction model.TrafficDirection // key1: namespace, key2: service exactRules *utils.SyncMap[string, *ServiceWithRouterRules] @@ -201,9 +188,6 @@ type ClientRouteRuleContainer struct { func (c *ClientRouteRuleContainer) SearchRouteRuleV2(svc model.ServiceKey) []*model.ExtendRouterConfig { ret := make([]*model.ExtendRouterConfig, 0, 32) - c.lock.RLock() - defer c.lock.RUnlock() - exactRule, existExactRule := c.exactRules.Load(svc.Domain()) if existExactRule { exactRule.IterateRouterRules(func(erc *model.ExtendRouterConfig) { @@ -221,10 +205,6 @@ func (c *ClientRouteRuleContainer) SearchRouteRuleV2(svc model.ServiceKey) []*mo c.allWildcardRules.IterateRouterRules(func(erc *model.ExtendRouterConfig) { ret = append(ret, erc) }) - - sort.Slice(ret, func(i, j int) bool { - return ret[i].Priority < ret[j].Priority - }) return ret } @@ -242,18 +222,18 @@ func (c *ClientRouteRuleContainer) SearchCustomRuleV1(svc model.ServiceKey) (*ap switch c.direction { case model.TrafficDirection_INBOUND: if existExactRule { - ret.Inbounds = append(ret.Inbounds, exactRule.GetRouteRuleV1().GetInbounds()...) + ret.Inbounds = append(ret.Inbounds, exactRule.customv1Rules.Inbounds...) } if existNsWildcardRule { - ret.Inbounds = append(ret.Inbounds, nsWildcardRule.GetRouteRuleV1().GetInbounds()...) + ret.Inbounds = append(ret.Inbounds, nsWildcardRule.customv1Rules.Inbounds...) } default: if existExactRule { - ret.Outbounds = append(ret.Outbounds, exactRule.GetRouteRuleV1().GetOutbounds()...) + ret.Outbounds = append(ret.Outbounds, exactRule.customv1Rules.Outbounds...) revisions = append(revisions, exactRule.revision) } if existNsWildcardRule { - ret.Outbounds = append(ret.Outbounds, nsWildcardRule.GetRouteRuleV1().GetOutbounds()...) + ret.Outbounds = append(ret.Outbounds, nsWildcardRule.customv1Rules.Outbounds...) } } if existExactRule { @@ -263,14 +243,6 @@ func (c *ClientRouteRuleContainer) SearchCustomRuleV1(svc model.ServiceKey) (*ap revisions = append(revisions, nsWildcardRule.revision) } - // 最终在做一次排序 - sort.Slice(ret.Inbounds, func(i, j int) bool { - return model.CompareRoutingV1(ret.Inbounds[i], ret.Inbounds[j]) - }) - sort.Slice(ret.Outbounds, func(i, j int) bool { - return model.CompareRoutingV1(ret.Outbounds[i], ret.Outbounds[j]) - }) - return ret, revisions } @@ -321,19 +293,6 @@ func (c *ClientRouteRuleContainer) RemoveRule(svcKey model.ServiceKey, ruleId st } } -func (c *ClientRouteRuleContainer) CleanAllRule(ruleId string) { - // level1 级别 cache 处理 - c.exactRules.Range(func(key string, svcContainer *ServiceWithRouterRules) { - svcContainer.DelRouterRule(ruleId) - }) - // level2 级别 cache 处理 - c.nsWildcardRules.Range(func(key string, svcContainer *ServiceWithRouterRules) { - svcContainer.DelRouterRule(ruleId) - }) - // level3 级别 cache 处理 - c.allWildcardRules.DelRouterRule(ruleId) -} - func newRouteRuleContainer() *RouteRuleContainer { return &RouteRuleContainer{ rules: utils.NewSyncMap[string, *model.ExtendRouterConfig](), @@ -372,13 +331,7 @@ type RouteRuleContainer struct { func (b *RouteRuleContainer) saveV2(conf *model.ExtendRouterConfig) { b.rules.Store(conf.ID, conf) handler := func(container *ClientRouteRuleContainer, svcKey model.ServiceKey) { - // 避免读取到中间状态数据 - container.lock.Lock() - defer container.lock.Unlock() - b.effect.Add(svcKey) - // 先删除,再保存 - container.CleanAllRule(conf.ID) container.SaveRule(svcKey, conf) } diff --git a/cache/service/router_rule_query.go b/cache/service/router_rule_query.go index 88473f8c8..b5262c287 100644 --- a/cache/service/router_rule_query.go +++ b/cache/service/router_rule_query.go @@ -29,6 +29,14 @@ import ( "github.com/polarismesh/polaris/common/utils" ) +// forceUpdate 更新配置 +func (rc *RouteRuleCache) forceUpdate() error { + if err := rc.Update(); err != nil { + return err + } + return nil +} + func queryRoutingRuleV2ByService(rule *model.ExtendRouterConfig, sourceNamespace, sourceService, destNamespace, destService string, both bool) bool { var ( @@ -113,7 +121,7 @@ func queryRoutingRuleV2ByService(rule *model.ExtendRouterConfig, sourceNamespace // QueryRoutingConfigsV2 Query Route Configuration List func (rc *RouteRuleCache) QueryRoutingConfigsV2(ctx context.Context, args *types.RoutingArgs) (uint32, []*model.ExtendRouterConfig, error) { - if err := rc.Update(); err != nil { + if err := rc.forceUpdate(); err != nil { return 0, nil, err } hasSvcQuery := len(args.Service) != 0 || len(args.Namespace) != 0 @@ -173,14 +181,7 @@ func (rc *RouteRuleCache) QueryRoutingConfigsV2(ctx context.Context, args *types res = append(res, routeRule) } - predicates := types.LoadRouterRulePredicates(ctx) - rc.IteratorRouterRule(func(key string, value *model.ExtendRouterConfig) { - for i := range predicates { - if !predicates[i](ctx, value) { - return - } - } process(key, value) }) diff --git a/cache/service/service.go b/cache/service/service.go index 6b3ae7e56..03a5d5c4a 100644 --- a/cache/service/service.go +++ b/cache/service/service.go @@ -340,45 +340,13 @@ func (sc *serviceCache) GetServicesCount() int { } // ListServices get service list and revision by namespace -func (sc *serviceCache) ListServices(ctx context.Context, ns string) (string, []*model.Service) { - revision, matchServices := sc.serviceList.ListServices(ns) - predicates := types.LoadServicePredicates(ctx) - ret := make([]*model.Service, 0, len(matchServices)) - for i := range matchServices { - allMatch := true - for j := range predicates { - if !predicates[j](ctx, matchServices[i]) { - allMatch = false - break - } - } - if allMatch { - ret = append(ret, matchServices[i]) - } - } - matchServices = ret - return revision, matchServices +func (sc *serviceCache) ListServices(ns string) (string, []*model.Service) { + return sc.serviceList.ListServices(ns) } // ListAllServices get all service and revision -func (sc *serviceCache) ListAllServices(ctx context.Context) (string, []*model.Service) { - revision, matchServices := sc.serviceList.ListAllServices() - predicates := types.LoadServicePredicates(ctx) - ret := make([]*model.Service, 0, len(matchServices)) - for i := range matchServices { - pass := true - for j := range predicates { - if !predicates[j](ctx, matchServices[i]) { - pass = false - break - } - } - if pass { - ret = append(ret, matchServices[i]) - } - } - matchServices = ret - return revision, matchServices +func (sc *serviceCache) ListAllServices() (string, []*model.Service) { + return sc.serviceList.ListAllServices() } // ListServiceAlias get all service alias by target service diff --git a/cache/service/service_contract.go b/cache/service/service_contract.go index b3dc30a37..a5eddece4 100644 --- a/cache/service/service_contract.go +++ b/cache/service/service_contract.go @@ -116,11 +116,11 @@ func (sc *ServiceContractCache) setContracts(values []*model.EnrichServiceContra item := values[i] if !item.Valid { del++ - _ = sc.upsertValueCache(item, true) + sc.upsertValueCache(item, true) continue } upsert++ - _ = sc.upsertValueCache(item, false) + sc.upsertValueCache(item, false) } return map[string]time.Time{ sc.Name(): lastMtime, diff --git a/cache/service/service_query.go b/cache/service/service_query.go index 25e26006c..27e6f4ee4 100644 --- a/cache/service/service_query.go +++ b/cache/service/service_query.go @@ -74,23 +74,6 @@ func (sc *serviceCache) GetServicesByFilter(ctx context.Context, serviceFilters matchServices = tmpSvcs } - // 这里需要额外做过滤判断 - predicates := types.LoadServicePredicates(ctx) - ret := make([]*model.Service, 0, len(matchServices)) - for i := range matchServices { - pass := true - for pi := range predicates { - if !predicates[pi](ctx, matchServices[i]) { - pass = false - break - } - } - if pass { - ret = append(ret, matchServices[i]) - } - } - matchServices = ret - amount, services := sortBeforeTrim(matchServices, offset, limit) var enhancedServices []*model.EnhancedService diff --git a/common/api/v1/config_response.go b/common/api/v1/config_response.go index d9d1c94b5..ccf34820b 100644 --- a/common/api/v1/config_response.go +++ b/common/api/v1/config_response.go @@ -226,13 +226,6 @@ func NewConfigFileReleaseHistoryResponse( } } -func NewSimpleConfigFileImportResponse(code apimodel.Code) *apiconfig.ConfigImportResponse { - return &apiconfig.ConfigImportResponse{ - Code: &wrappers.UInt32Value{Value: uint32(code)}, - Info: &wrappers.StringValue{Value: code2info[uint32(code)]}, - } -} - func NewConfigFileImportResponse(code apimodel.Code, createConfigFiles, skipConfigFiles, overwriteConfigFiles []*apiconfig.ConfigFile) *apiconfig.ConfigImportResponse { return &apiconfig.ConfigImportResponse{ diff --git a/common/model/auth/acquire_context.go b/common/model/auth/acquire_context.go index 09c78b79a..4a8c7a92d 100644 --- a/common/model/auth/acquire_context.go +++ b/common/model/auth/acquire_context.go @@ -38,7 +38,7 @@ type AcquireContext struct { // Module 来自那个业务层(服务注册与服务治理、配置模块) module BzModule // Method 操作函数 - methods []ServerFunctionName + method ServerFunctionName // Operation 本次操作涉及的动作 operation ResourceOperation // Resources 本次 @@ -98,7 +98,7 @@ func WithModule(module BzModule) acquireContextOption { // WithMethod 本次操作函数名称 func WithMethod(method ServerFunctionName) acquireContextOption { return func(authCtx *AcquireContext) { - authCtx.methods = []ServerFunctionName{method} + authCtx.method = method } } @@ -118,7 +118,7 @@ func WithOperation(operation ResourceOperation) acquireContextOption { // @return acquireContextOption func WithAccessResources(accessResources map[apisecurity.ResourceType][]ResourceEntry) acquireContextOption { return func(authCtx *AcquireContext) { - authCtx.SetAccessResources(accessResources) + authCtx.accessResources = accessResources } } @@ -180,11 +180,6 @@ func (authCtx *AcquireContext) GetOperation() ResourceOperation { return authCtx.operation } -// SetOperation 设置本次操作的类型 -func (authCtx *AcquireContext) SetOperation(op ResourceOperation) { - authCtx.operation = op -} - // GetAccessResources 获取本次请求的资源 // // @receiver authCtx @@ -198,20 +193,7 @@ func (authCtx *AcquireContext) GetAccessResources() map[apisecurity.ResourceType // @receiver authCtx // @param accessRes func (authCtx *AcquireContext) SetAccessResources(accessRes map[apisecurity.ResourceType][]ResourceEntry) { - copyM := map[apisecurity.ResourceType][]ResourceEntry{} - for k, v := range accessRes { - if len(v) == 0 { - continue - } - copyM[k] = v - } - - authCtx.accessResources = copyM -} - -// SetMethod 设置本次请求涉及的操作函数 -func (authCtx *AcquireContext) SetMethod(methods []ServerFunctionName) { - authCtx.methods = methods + authCtx.accessResources = accessRes } // GetAttachments 获取本次请求的额外携带信息 @@ -231,8 +213,8 @@ func (authCtx *AcquireContext) SetAttachment(key string, val interface{}) { } // GetMethod 获取本次请求涉及的操作函数 -func (authCtx *AcquireContext) GetMethods() []ServerFunctionName { - return authCtx.methods +func (authCtx *AcquireContext) GetMethod() ServerFunctionName { + return authCtx.method } // SetFromClient 本次请求来自客户端 diff --git a/common/model/auth/auth.go b/common/model/auth/auth.go index 270fc8574..0f879dc17 100644 --- a/common/model/auth/auth.go +++ b/common/model/auth/auth.go @@ -371,8 +371,6 @@ type StrategyDetail struct { Comment string Default bool Owner string - // 来源 - Source string // CalleeMethods 允许访问的服务端接口 CalleeMethods []string Resources []StrategyResource @@ -385,61 +383,6 @@ type StrategyDetail struct { ModifyTime time.Time } -func (s *StrategyDetail) GetAction() apisecurity.AuthAction { - if s.Action == apisecurity.AuthAction_ALLOW.String() { - return apisecurity.AuthAction_ALLOW - } - if s.Action == apisecurity.AuthAction_READ_WRITE.String() { - return apisecurity.AuthAction_ALLOW - } - return apisecurity.AuthAction_DENY -} - -func (s *StrategyDetail) FromSpec(req *apisecurity.AuthStrategy) { - s.ID = utils.NewUUID() - s.Name = req.Name.GetValue() - s.Action = req.GetAction().String() - s.Comment = req.Comment.GetValue() - s.Default = false - s.Owner = req.Owner.GetValue() - s.Valid = true - s.Source = req.GetSource().GetValue() - s.Revision = utils.NewUUID() - s.CreateTime = time.Now() - s.ModifyTime = time.Now() - s.CalleeMethods = req.GetFunctions() - s.Conditions = make([]Condition, 0, len(req.GetResourceLabels())) - for i := range req.GetResourceLabels() { - item := req.GetResourceLabels()[i] - s.Conditions = append(s.Conditions, Condition{ - Key: item.Key, - Value: item.Value, - CompareFunc: item.CompareType, - }) - } - -} - -func (s *StrategyDetail) IsMatchAction(a string) bool { - saveAction := s.Action - if isAllowAction(saveAction) { - saveAction = apisecurity.AuthAction_ALLOW.String() - } - if isAllowAction(a) { - a = apisecurity.AuthAction_ALLOW.String() - } - return saveAction == a -} - -func isAllowAction(s string) bool { - switch s { - case apisecurity.AuthAction_ALLOW.String(), apisecurity.AuthAction_READ_WRITE.String(): - return true - default: - return false - } -} - func (s *StrategyDetail) IsDeny() bool { return s.Action == apisecurity.AuthAction_DENY.String() } @@ -490,8 +433,6 @@ type ModifyStrategyDetail struct { Action string Comment string Metadata map[string]string - CalleeMethods []string - Conditions []Condition AddPrincipals []Principal RemovePrincipals []Principal AddResources []StrategyResource @@ -520,31 +461,6 @@ type StrategyResource struct { ResID string } -func (s StrategyResource) Key() string { - return strconv.Itoa(int(s.ResType)) + "/" + s.ResID -} - -func forUserPrincipal(id string) Principal { - return Principal{ - PrincipalID: id, - PrincipalType: PrincipalUser, - } -} - -func forUserGroupPrincipal(id string) Principal { - return Principal{ - PrincipalID: id, - PrincipalType: PrincipalGroup, - } -} - -func forRolePrincipal(id string) Principal { - return Principal{ - PrincipalID: id, - PrincipalType: PrincipalRole, - } -} - // Principal 策略相关人 type Principal struct { StrategyID string @@ -552,16 +468,6 @@ type Principal struct { Owner string PrincipalID string PrincipalType PrincipalType - Extend map[string]string -} - -func NewAnonymousPrincipal() Principal { - return Principal{ - Name: "__anonymous__", - PrincipalType: PrincipalUser, - PrincipalID: "__anonymous__", - Extend: map[string]string{}, - } } func (p Principal) String() string { @@ -589,60 +495,6 @@ type Role struct { Comment string CreateTime time.Time ModifyTime time.Time - Users []Principal - UserGroups []Principal -} - -func (r *Role) FromSpec(d *apisecurity.Role) { - r.Name = d.Name - r.Owner = d.Owner - r.Source = d.Source - r.Metadata = d.Metadata - - if len(d.Users) != 0 { - users := make([]Principal, 0, len(d.Users)) - for i := range d.Users { - users = append(users, Principal{PrincipalID: d.Users[i].GetId().GetValue()}) - } - r.Users = users - } - - if len(d.UserGroups) != 0 { - groups := make([]Principal, 0, len(d.UserGroups)) - for i := range d.UserGroups { - groups = append(groups, Principal{PrincipalID: d.UserGroups[i].GetId().GetValue()}) - } - r.UserGroups = groups - } -} - -func (r *Role) ToSpec() *apisecurity.Role { - d := &apisecurity.Role{} - - d.Name = r.Name - d.Owner = r.Owner - d.Source = r.Source - d.Metadata = r.Metadata - - if len(r.Users) != 0 { - users := make([]*apisecurity.User, 0, len(r.Users)) - for i := range r.Users { - users = append(users, &apisecurity.User{ - Id: utils.NewStringValue(r.Users[i].PrincipalID), - }) - } - d.Users = users - } - - if len(d.UserGroups) != 0 { - groups := make([]*apisecurity.UserGroup, 0, len(d.UserGroups)) - for i := range r.UserGroups { - groups = append(groups, &apisecurity.UserGroup{ - Id: utils.NewStringValue(r.UserGroups[i].PrincipalID), - }) - } - d.UserGroups = groups - } - - return d + Users []*User + UserGroups []*UserGroup } diff --git a/common/model/auth/const.go b/common/model/auth/const.go index d58a621fb..88ecf139d 100644 --- a/common/model/auth/const.go +++ b/common/model/auth/const.go @@ -52,6 +52,7 @@ const ( const ( CreateNamespace ServerFunctionName = "CreateNamespace" CreateNamespaces ServerFunctionName = "CreateNamespaces" + DeleteNamespace ServerFunctionName = "DeleteNamespace" DeleteNamespaces ServerFunctionName = "DeleteNamespaces" UpdateNamespaces ServerFunctionName = "UpdateNamespaces" UpdateNamespaceToken ServerFunctionName = "UpdateNamespaceToken" @@ -178,13 +179,7 @@ const ( ) // 全链路灰度 -const ( - CreateLaneGroups ServerFunctionName = "CreateLaneGroups" - DeleteLaneGroups ServerFunctionName = "DeleteLaneGroups" - EnableLaneGroups ServerFunctionName = "EnableLaneGroups" - UpdateLaneGroups ServerFunctionName = "UpdateLaneGroups" - DescribeLaneGroups ServerFunctionName = "DescribeLaneGroups" -) +const () // 用户/用户组 const ( @@ -240,226 +235,6 @@ const ( DescribeCMDBInfo ServerFunctionName = "DescribeCMDBInfo" ) -type ServerFunctionGroup struct { - Name string `json:"name"` - Functions []ServerFunctionName `json:"functions"` -} - -var ServerFunctions = []ServerFunctionGroup{ - { - Name: "Client", - Functions: []ServerFunctionName{ - RegisterInstance, - DeregisterInstance, - ReportServiceContract, - DiscoverServices, - DiscoverInstances, - UpdateInstance, - DiscoverRouterRule, - DiscoverRateLimitRule, - DiscoverCircuitBreakerRule, - DiscoverFaultDetectRule, - DiscoverServiceContract, - DiscoverLaneRule, - DiscoverConfigFile, - WatchConfigFile, - DiscoverConfigFileNames, - DiscoverConfigGroups, - }, - }, - { - Name: "Namespace", - Functions: []ServerFunctionName{ - CreateNamespace, - CreateNamespaces, - DeleteNamespaces, - UpdateNamespaces, - // UpdateNamespaceToken, - DescribeNamespaces, - // DescribeNamespaceToken, - }, - }, - { - Name: "Service", - Functions: []ServerFunctionName{ - CreateServices, - DeleteServices, - UpdateServices, - // UpdateServiceToken, - DescribeAllServices, - DescribeServices, - DescribeServicesCount, - // DescribeServiceToken, - // DescribeServiceOwner, - CreateServiceAlias, - DeleteServiceAliases, - UpdateServiceAlias, - DescribeServiceAliases, - }, - }, - { - Name: "ServiceContract", - Functions: []ServerFunctionName{ - CreateServiceContracts, - DescribeServiceContracts, - DescribeServiceContractVersions, - DeleteServiceContracts, - CreateServiceContractInterfaces, - AppendServiceContractInterfaces, - DeleteServiceContractInterfaces, - }, - }, - { - Name: "Instance", - Functions: []ServerFunctionName{ - CreateInstances, - DeleteInstances, - DeleteInstancesByHost, - UpdateInstances, - UpdateInstancesIsolate, - DescribeInstances, - DescribeInstancesCount, - DescribeInstanceLabels, - CleanInstance, - BatchCleanInstances, - DescribeInstanceLastHeartbeat, - }, - }, - { - Name: "RouteRule", - Functions: []ServerFunctionName{ - CreateRouteRules, - DeleteRouteRules, - UpdateRouteRules, - EnableRouteRules, - DescribeRouteRules, - }, - }, - { - Name: "RateLimitRule", - Functions: []ServerFunctionName{ - CreateRateLimitRules, - DeleteRateLimitRules, - UpdateRateLimitRules, - EnableRateLimitRules, - DescribeRateLimitRules, - }, - }, - { - Name: "CircuitBreakerRule", - Functions: []ServerFunctionName{ - CreateCircuitBreakerRules, - DeleteCircuitBreakerRules, - EnableCircuitBreakerRules, - UpdateCircuitBreakerRules, - DescribeCircuitBreakerRules, - }, - }, - { - Name: "FaultDetectRule", - Functions: []ServerFunctionName{ - CreateFaultDetectRules, - DeleteFaultDetectRules, - EnableFaultDetectRules, - UpdateFaultDetectRules, - DescribeFaultDetectRules, - }, - }, - { - Name: "LaneRule", - Functions: []ServerFunctionName{ - CreateLaneGroups, - DeleteLaneGroups, - EnableLaneGroups, - UpdateLaneGroups, - DescribeLaneGroups, - }, - }, - { - Name: "ConfigGroup", - Functions: []ServerFunctionName{ - CreateConfigFileGroup, - DeleteConfigFileGroup, - UpdateConfigFileGroup, - DescribeConfigFileGroups, - }, - }, - { - Name: "ConfigFile", - Functions: []ServerFunctionName{ - PublishConfigFile, - CreateConfigFile, - UpdateConfigFile, - DeleteConfigFile, - DescribeConfigFileRichInfo, - DescribeConfigFiles, - BatchDeleteConfigFiles, - ExportConfigFiles, - ImportConfigFiles, - DescribeConfigFileReleaseHistories, - DescribeAllConfigFileTemplates, - DescribeConfigFileTemplate, - CreateConfigFileTemplate, - }, - }, - { - Name: "ConfigRelease", - Functions: []ServerFunctionName{ - RollbackConfigFileReleases, - DeleteConfigFileReleases, - StopGrayConfigFileReleases, - DescribeConfigFileRelease, - DescribeConfigFileReleases, - DescribeConfigFileReleaseVersions, - UpsertAndReleaseConfigFile, - }, - }, - { - Name: "User", - Functions: []ServerFunctionName{ - CreateUsers, - DeleteUsers, - DescribeUsers, - DescribeUserToken, - EnableUserToken, - ResetUserToken, - UpdateUser, - UpdateUserPassword, - }, - }, - { - Name: "UserGroup", - Functions: []ServerFunctionName{ - CreateUserGroup, - UpdateUserGroups, - DeleteUserGroups, - DescribeUserGroups, - DescribeUserGroupDetail, - DescribeUserGroupToken, - EnableUserGroupToken, - ResetUserGroupToken, - }, - }, - { - Name: "AuthPolicy", - Functions: []ServerFunctionName{ - CreateAuthPolicy, - UpdateAuthPolicies, - DeleteAuthPolicies, - DescribeAuthPolicies, - DescribeAuthPolicyDetail, - DescribePrincipalResources, - }, - }, - // "AuthRole": { - // CreateAuthRoles, - // UpdateAuthRoles, - // DeleteAuthRoles, - // DescribeAuthRoles, - // DescribeAuthRoleDetail, - // }, -} - var ( SearchTypeMapping = map[string]apisecurity.ResourceType{ "0": apisecurity.ResourceType_Namespaces, diff --git a/common/model/auth/container.go b/common/model/auth/container.go index cea3d4f52..4833b3108 100644 --- a/common/model/auth/container.go +++ b/common/model/auth/container.go @@ -18,22 +18,21 @@ package auth import ( - apisecurity "github.com/polarismesh/specification/source/go/api/v1/security" - "github.com/polarismesh/polaris/common/utils" + apisecurity "github.com/polarismesh/specification/source/go/api/v1/security" ) // PrincipalResourceContainer principal 资源容器 type PrincipalResourceContainer struct { - denyResources *utils.SyncMap[apisecurity.ResourceType, *utils.RefSyncSet[string, string]] - allowResources *utils.SyncMap[apisecurity.ResourceType, *utils.RefSyncSet[string, string]] + denyResources *utils.SyncMap[apisecurity.ResourceType, *utils.RefSyncSet[string]] + allowResources *utils.SyncMap[apisecurity.ResourceType, *utils.RefSyncSet[string]] } // NewPrincipalResourceContainer 创建 PrincipalResourceContainer 对象 func NewPrincipalResourceContainer() *PrincipalResourceContainer { return &PrincipalResourceContainer{ - allowResources: utils.NewSyncMap[apisecurity.ResourceType, *utils.RefSyncSet[string, string]](), - denyResources: utils.NewSyncMap[apisecurity.ResourceType, *utils.RefSyncSet[string, string]](), + allowResources: utils.NewSyncMap[apisecurity.ResourceType, *utils.RefSyncSet[string]](), + denyResources: utils.NewSyncMap[apisecurity.ResourceType, *utils.RefSyncSet[string]](), } } @@ -41,18 +40,12 @@ func NewPrincipalResourceContainer() *PrincipalResourceContainer { func (p *PrincipalResourceContainer) Hint(rt apisecurity.ResourceType, resId string) (apisecurity.AuthAction, bool) { ids, ok := p.denyResources.Load(rt) if ok { - if ids.Contains(utils.MatchAll) { - return apisecurity.AuthAction_DENY, true - } if ids.Contains(resId) { return apisecurity.AuthAction_DENY, true } } ids, ok = p.allowResources.Load(rt) if ok { - if ids.Contains(utils.MatchAll) { - return apisecurity.AuthAction_ALLOW, true - } if ids.Contains(resId) { return apisecurity.AuthAction_ALLOW, true } @@ -60,50 +53,46 @@ func (p *PrincipalResourceContainer) Hint(rt apisecurity.ResourceType, resId str return 0, false } -// SaveResource 保存资源 -func (p *PrincipalResourceContainer) SaveResource(a apisecurity.AuthAction, r StrategyResource) { - if a == apisecurity.AuthAction_ALLOW { - p.saveResource(p.allowResources, r) - } else { - p.saveResource(p.denyResources, r) - } +// SaveAllowResource 保存允许的资源 +func (p *PrincipalResourceContainer) SaveAllowResource(r StrategyResource) { + p.saveResource(p.allowResources, r) } -// DelResource 删除资源 -func (p *PrincipalResourceContainer) DelResource(a apisecurity.AuthAction, r StrategyResource) { - if a == apisecurity.AuthAction_ALLOW { - p.delResource(p.allowResources, r) - } else { - p.delResource(p.denyResources, r) - } +// DelAllowResource 删除允许的资源 +func (p *PrincipalResourceContainer) DelAllowResource(r StrategyResource) { + p.delResource(p.allowResources, r) +} + +// SaveDenyResource 保存拒绝的资源 +func (p *PrincipalResourceContainer) SaveDenyResource(r StrategyResource) { + p.saveResource(p.denyResources, r) +} + +// DelDenyResource 删除拒绝的资源 +func (p *PrincipalResourceContainer) DelDenyResource(r StrategyResource) { + p.delResource(p.denyResources, r) } func (p *PrincipalResourceContainer) saveResource( - container *utils.SyncMap[apisecurity.ResourceType, *utils.RefSyncSet[string, string]], res StrategyResource) { + container *utils.SyncMap[apisecurity.ResourceType, *utils.RefSyncSet[string]], res StrategyResource) { resType := apisecurity.ResourceType(res.ResType) - container.ComputeIfAbsent(resType, func(k apisecurity.ResourceType) *utils.RefSyncSet[string, string] { - return utils.NewRefSyncSet[string, string]() + container.ComputeIfAbsent(resType, func(k apisecurity.ResourceType) *utils.RefSyncSet[string] { + return utils.NewRefSyncSet[string]() }) ids, _ := container.Load(resType) - ids.Add(utils.Reference[string, string]{ - Key: res.ResID, - Referencer: res.StrategyID, - }) + ids.Add(res.ResID) } func (p *PrincipalResourceContainer) delResource( - container *utils.SyncMap[apisecurity.ResourceType, *utils.RefSyncSet[string, string]], res StrategyResource) { + container *utils.SyncMap[apisecurity.ResourceType, *utils.RefSyncSet[string]], r StrategyResource) { - resType := apisecurity.ResourceType(res.ResType) - container.ComputeIfAbsent(resType, func(k apisecurity.ResourceType) *utils.RefSyncSet[string, string] { - return utils.NewRefSyncSet[string, string]() + resType := apisecurity.ResourceType(r.ResType) + container.ComputeIfAbsent(resType, func(k apisecurity.ResourceType) *utils.RefSyncSet[string] { + return utils.NewRefSyncSet[string]() }) ids, _ := container.Load(resType) - ids.Remove(utils.Reference[string, string]{ - Key: res.ResID, - Referencer: res.StrategyID, - }) + ids.Remove(r.ResID) } diff --git a/common/model/auth/context.go b/common/model/auth/context.go deleted file mode 100644 index d4ccbae19..000000000 --- a/common/model/auth/context.go +++ /dev/null @@ -1,22 +0,0 @@ -/** - * Tencent is pleased to support the open source community by making Polaris available. - * - * Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. - * - * Licensed under the BSD 3-Clause License (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://opensource.org/licenses/BSD-3-Clause - * - * Unless required by applicable law or agreed to in writing, software distributed - * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR - * CONDITIONS OF ANY KIND, either express or implied. See the License for the - * specific language governing permissions and limitations under the License. - */ - -package auth - -type ( - ContextKeyConditions struct{} -) diff --git a/common/model/auth/funcs.go b/common/model/auth/funcs.go deleted file mode 100644 index fb458bfd8..000000000 --- a/common/model/auth/funcs.go +++ /dev/null @@ -1,66 +0,0 @@ -/** - * Tencent is pleased to support the open source community by making Polaris available. - * - * Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. - * - * Licensed under the BSD 3-Clause License (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://opensource.org/licenses/BSD-3-Clause - * - * Unless required by applicable law or agreed to in writing, software distributed - * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR - * CONDITIONS OF ANY KIND, either express or implied. See the License for the - * specific language governing permissions and limitations under the License. - */ - -package auth - -import "strings" - -/* -* -string_equal -string_not_equal -string_equal_ignore_case -string_not_equal_ignore_case -string_like -string_not_like -date_equal -date_not_equal -date_greater_than -date_greater_than_equal -date_less_than -date_less_than_equal -ip_equal -ip_not_equal -*/ -var ( - ConditionCompareDict = map[string]func(string, string) bool{ - // string_equal - "string_equal": func(s1, s2 string) bool { - return s1 == s2 - }, - "for_any_value:string_equal": func(s1, s2 string) bool { - return s1 == s2 - }, - // string_not_equal - "string_not_equal": func(s1, s2 string) bool { - return s1 != s2 - }, - "for_any_value:string_not_equal": func(s1, s2 string) bool { - return s1 != s2 - }, - // string_equal_ignore_case - "string_equal_ignore_case": strings.EqualFold, - "for_any_value:string_equal_ignore_case": strings.EqualFold, - // string_not_equal_ignore_case - "string_not_equal_ignore_case": func(s1, s2 string) bool { - return !strings.EqualFold(s1, s2) - }, - "for_any_value:string_not_equal_ignore_case": func(s1, s2 string) bool { - return !strings.EqualFold(s1, s2) - }, - } -) diff --git a/common/model/config_file.go b/common/model/config_file.go index af9ce1347..f894901a0 100644 --- a/common/model/config_file.go +++ b/common/model/config_file.go @@ -492,8 +492,6 @@ func ToConfigGroupAPI(group *ConfigFileGroup) *config_manage.ConfigFileGroup { Business: utils.NewStringValue(group.Business), Department: utils.NewStringValue(group.Department), Metadata: group.Metadata, - Editable: utils.NewBoolValue(true), - Deleteable: utils.NewBoolValue(true), } } diff --git a/common/model/context_key.go b/common/model/context_key.go index 2c7cdcc16..9c8377e7d 100644 --- a/common/model/context_key.go +++ b/common/model/context_key.go @@ -22,6 +22,4 @@ type ( ContextKeyAutoCreateNamespace struct{} // ContextKeyAutoCreateService . ContextKeyAutoCreateService struct{} - // ContextKeyCompatible . - ContextKeyCompatible struct{} ) diff --git a/common/model/lane.go b/common/model/lane.go index 9d79320d6..db5539a7f 100644 --- a/common/model/lane.go +++ b/common/model/lane.go @@ -49,7 +49,6 @@ type LaneGroup struct { Revision string Description string Valid bool - Labels map[string]string CreateTime time.Time ModifyTime time.Time // LaneRules id -> *LaneRule diff --git a/common/model/naming.go b/common/model/naming.go index 9394ca3f9..aa3af6ab9 100644 --- a/common/model/naming.go +++ b/common/model/naming.go @@ -521,7 +521,6 @@ type FaultDetectRule struct { DstMethod string Rule string Revision string - Metadata map[string]string Valid bool CreateTime time.Time ModifyTime time.Time diff --git a/common/model/operation.go b/common/model/operation.go index 5c3e2c0e6..c8b520ed2 100644 --- a/common/model/operation.go +++ b/common/model/operation.go @@ -62,15 +62,12 @@ const ( RUserGroup Resource = "UserGroup" RUserGroupRelation Resource = "UserGroupRelation" RAuthStrategy Resource = "AuthStrategy" - RAuthRole Resource = "Role" RConfigGroup Resource = "ConfigGroup" RConfigFile Resource = "ConfigFile" RConfigFileRelease Resource = "ConfigFileRelease" RCircuitBreakerRule Resource = "CircuitBreakerRule" RFaultDetectRule Resource = "FaultDetectRule" RServiceContract Resource = "ServiceContract" - RLaneGroup Resource = "LaneGroup" - RLaneRule Resource = "LaneRule" ) // RecordEntry Operation records diff --git a/common/model/ratelimit.go b/common/model/ratelimit.go index 7de236d5e..a344fe178 100644 --- a/common/model/ratelimit.go +++ b/common/model/ratelimit.go @@ -42,26 +42,6 @@ type RateLimit struct { CreateTime time.Time ModifyTime time.Time EnableTime time.Time - Metadata map[string]string -} - -func (r *RateLimit) CopyNoProto() *RateLimit { - return &RateLimit{ - ID: r.ID, - ServiceID: r.ServiceID, - Name: r.Name, - Method: r.Method, - Labels: r.Labels, - Proto: r.Proto, - Priority: r.Priority, - Rule: r.Rule, - Revision: r.Revision, - Disable: r.Disable, - Valid: r.Valid, - CreateTime: r.CreateTime, - ModifyTime: r.ModifyTime, - EnableTime: r.EnableTime, - } } // Labels2Arguments 适配老的标签到新的参数列表 diff --git a/common/model/routing.go b/common/model/routing.go index 290036b87..85a3b9ff8 100644 --- a/common/model/routing.go +++ b/common/model/routing.go @@ -44,8 +44,6 @@ const ( const ( // V2RuleIDKey v2 版本的规则路由 ID V2RuleIDKey = "__routing_v2_id__" - // V2RuleIDPriority v2 版本的规则路由优先级 - V2RuleIDPriority = "__routing_v2_priority__" // V1RuleIDKey v1 版本的路由规则 ID V1RuleIDKey = "__routing_v1_id__" // V1RuleRouteIndexKey v1 版本 route 规则在自己 route 链中的 index 信息 @@ -118,7 +116,7 @@ func (r *ExtendRouterConfig) ToApi() (*apitraffic.RouteRule, error) { ) switch r.GetRoutingPolicy() { - case apitraffic.RoutingPolicy_NearbyPolicy: + case apitraffic.RoutingPolicy_RulePolicy: anyValue, err = ptypes.MarshalAny(r.NearbyRouting) if err != nil { return nil, err @@ -148,8 +146,6 @@ func (r *ExtendRouterConfig) ToApi() (*apitraffic.RouteRule, error) { Etime: commontime.Time2String(r.EnableTime), Priority: r.Priority, Description: r.Description, - Editable: true, - Deleteable: true, } if r.EnableTime.Year() > 2000 { rule.Etime = commontime.Time2String(r.EnableTime) @@ -194,8 +190,6 @@ type RouterConfig struct { ModifyTime time.Time `json:"mtime"` // enabletime The last time the rules enabled EnableTime time.Time `json:"etime"` - // Metadata. - Metadata map[string]string `json:"metadata"` } // GetRoutingPolicy Query routing rules type @@ -311,7 +305,6 @@ func (r *RouterConfig) ParseRouteRuleFromAPI(routing *apitraffic.RouteRule) erro r.Policy = routing.GetRoutingPolicy().String() r.Priority = routing.Priority r.Description = routing.Description - r.Metadata = routing.Metadata // Priority range range [0, 10] if r.Priority > 10 { @@ -385,21 +378,20 @@ func parseSubRouteRule(ruleRouting *apitraffic.RuleRoutingConfig) *RuleRoutingCo for i := range ruleRouting.Rules { item := ruleRouting.Rules[i] - if len(item.Sources) != 0 { - source := item.Sources[0] - wrapper.Caller = ServiceKey{ - Namespace: source.Namespace, - Name: source.Service, - } + source := item.Sources[0] + destination := item.Destinations[0] + + wrapper.Caller = ServiceKey{ + Namespace: source.Namespace, + Name: source.Service, } - if len(item.Destinations) != 0 { - destination := item.Destinations[0] - wrapper.Callee = ServiceKey{ - Namespace: destination.Namespace, - Name: destination.Service, - } + wrapper.Callee = ServiceKey{ + Namespace: destination.Namespace, + Name: destination.Service, } + break } + return wrapper } @@ -611,15 +603,7 @@ func CompareRoutingV2(a, b *ExtendRouterConfig) bool { if a.Priority != b.Priority { return a.Priority < b.Priority } - // 如果优先级相同,则比较规则 ID - return a.ID < b.ID -} - -// CompareRoutingV1 Compare the priority of two routing. -func CompareRoutingV1(a, b *apitraffic.Route) bool { - ap := a.ExtendInfo[V2RuleIDPriority] - bp := b.ExtendInfo[V2RuleIDPriority] - return ap < bp + return a.CreateTime.Before(b.CreateTime) } // ConvertRoutingV1ToExtendV2 The routing rules of the V1 version are converted to V2 version for storage @@ -771,8 +755,7 @@ func BuildInBoundsRoute(item *ExtendRouterConfig) []*apitraffic.Route { Sources: v1sources, Destinations: v1destinations, ExtendInfo: map[string]string{ - V2RuleIDKey: item.ID, - V2RuleIDPriority: fmt.Sprintf("%04d", item.Priority), + V2RuleIDKey: item.ID, }, }) } diff --git a/common/utils/atomic.go b/common/utils/atomic.go index 8ab0b15eb..6cc8eea62 100644 --- a/common/utils/atomic.go +++ b/common/utils/atomic.go @@ -21,7 +21,6 @@ import "sync/atomic" func NewAtomicValue[V any](v V) *AtomicValue[V] { a := new(AtomicValue[V]) - a.a = atomic.Value{} a.Store(v) return a } @@ -30,14 +29,6 @@ type AtomicValue[V any] struct { a atomic.Value } -func (a *AtomicValue[V]) HasValue() bool { - if a == nil { - return false - } - v := a.a.Load() - return v != nil -} - func (a *AtomicValue[V]) Store(val V) { a.a.Store(val) } diff --git a/common/utils/collection.go b/common/utils/collection.go index 7d85de169..534edec67 100644 --- a/common/utils/collection.go +++ b/common/utils/collection.go @@ -57,52 +57,46 @@ func (set *Set[K]) Range(fn func(val K)) { } } -type Reference[K, R comparable] struct { - Key K - Referencer R -} - // NewRefSyncSet returns a new Set -func NewRefSyncSet[K, R comparable]() *RefSyncSet[K, R] { - return &RefSyncSet[K, R]{ - container: map[K]map[R]struct{}{}, +func NewRefSyncSet[K comparable]() *RefSyncSet[K] { + return &RefSyncSet[K]{ + container: make(map[K]int), } } -type RefSyncSet[K, R comparable] struct { - container map[K]map[R]struct{} +type RefSyncSet[K comparable] struct { + container map[K]int lock sync.RWMutex } // Add adds a string to the set -func (set *RefSyncSet[K, R]) Add(val Reference[K, R]) { +func (set *RefSyncSet[K]) Add(val K) { set.lock.Lock() defer set.lock.Unlock() - if _, ok := set.container[val.Key]; !ok { - set.container[val.Key] = map[R]struct{}{} + ref, ok := set.container[val] + if ok { + ref++ } - refs := set.container[val.Key] - refs[val.Referencer] = struct{}{} + set.container[val] = ref } // Remove removes a string from the set -func (set *RefSyncSet[K, R]) Remove(val Reference[K, R]) { +func (set *RefSyncSet[K]) Remove(val K) { set.lock.Lock() defer set.lock.Unlock() - if _, ok := set.container[val.Key]; !ok { - return + ref, ok := set.container[val] + if ok { + ref-- } - refs := set.container[val.Key] - delete(refs, val.Referencer) - if len(refs) == 0 { - delete(set.container, val.Key) + if ref == 0 { + delete(set.container, val) } else { - set.container[val.Key] = refs + set.container[val] = ref } } -func (set *RefSyncSet[K, R]) ToSlice() []K { +func (set *RefSyncSet[K]) ToSlice() []K { set.lock.RLock() defer set.lock.RUnlock() @@ -113,7 +107,7 @@ func (set *RefSyncSet[K, R]) ToSlice() []K { return ret } -func (set *RefSyncSet[K, R]) Range(fn func(val K)) { +func (set *RefSyncSet[K]) Range(fn func(val K)) { set.lock.RLock() snapshot := map[K]struct{}{} for k := range set.container { @@ -126,7 +120,7 @@ func (set *RefSyncSet[K, R]) Range(fn func(val K)) { } } -func (set *RefSyncSet[K, R]) Len() int { +func (set *RefSyncSet[K]) Len() int { set.lock.RLock() defer set.lock.RUnlock() @@ -134,7 +128,7 @@ func (set *RefSyncSet[K, R]) Len() int { } // Contains contains target value -func (set *RefSyncSet[K, R]) Contains(val K) bool { +func (set *RefSyncSet[K]) Contains(val K) bool { set.lock.Lock() defer set.lock.Unlock() @@ -142,7 +136,7 @@ func (set *RefSyncSet[K, R]) Contains(val K) bool { return exist } -func (set *RefSyncSet[K, R]) String() string { +func (set *RefSyncSet[K]) String() string { ret := set.ToSlice() return MustJson(ret) } diff --git a/common/utils/common.go b/common/utils/common.go index f08c1f852..49595eb1b 100644 --- a/common/utils/common.go +++ b/common/utils/common.go @@ -87,9 +87,6 @@ const ( MaxDbCircuitbreakerOwner = 1024 MaxDbCircuitbreakerVersion = 32 - // ratelimit表 - MaxDbRateLimitName = MaxRuleName - MaxRuleName = 64 MaxPlatformIDLength = 32 diff --git a/common/utils/const.go b/common/utils/const.go index 8e01a8994..0320376d1 100644 --- a/common/utils/const.go +++ b/common/utils/const.go @@ -73,6 +73,6 @@ const ( ContextIsFromSystem = StringContext("from-system") // ContextOperator operator info ContextOperator = StringContext("operator") - // ContextRequestHeaders request headers, save value type is map[string][]string + // ContextRequestHeaders request headers ContextRequestHeaders = StringContext("request-headers") ) diff --git a/config/client_test.go b/config/client_test.go index 33a51472e..c95c99b48 100644 --- a/config/client_test.go +++ b/config/client_test.go @@ -792,10 +792,8 @@ func TestServer_GetConfigGroupsWithCache(t *testing.T) { } t.Cleanup(func() { for k := range mockFiles { - testSuit.NamespaceServer().DeleteNamespaces(testSuit.DefaultCtx, []*apimodel.Namespace{ - &apimodel.Namespace{ - Name: wrapperspb.String(k), - }, + testSuit.NamespaceServer().DeleteNamespace(testSuit.DefaultCtx, &apimodel.Namespace{ + Name: wrapperspb.String(k), }) items := mockFiles[k] for _, item := range items { diff --git a/config/config_file_group.go b/config/config_file_group.go index e7b82c9c0..e81f57aae 100644 --- a/config/config_file_group.go +++ b/config/config_file_group.go @@ -256,12 +256,6 @@ func (s *Server) QueryConfigFileGroups(ctx context.Context, log.Error("[Config][Service] get config file count for group error.", utils.RequestID(ctx), utils.ZapNamespace(ret[i].Namespace), utils.ZapGroup(ret[i].Name), zap.Error(err)) } - - // 如果包含特殊标签,也不允许修改 - if _, ok := item.GetMetadata()[model.MetaKey3RdPlatform]; ok { - item.Editable = utils.NewBoolValue(false) - } - item.FileCount = wrapperspb.UInt64(fileCount) values = append(values, item) } diff --git a/config/interceptor/auth/client.go b/config/interceptor/auth/client.go index 06a7c6d50..892b19e1c 100644 --- a/config/interceptor/auth/client.go +++ b/config/interceptor/auth/client.go @@ -29,11 +29,11 @@ import ( ) // UpsertAndReleaseConfigFileFromClient 创建/更新配置文件并发布 -func (s *Server) UpsertAndReleaseConfigFileFromClient(ctx context.Context, +func (s *ServerAuthability) UpsertAndReleaseConfigFileFromClient(ctx context.Context, req *apiconfig.ConfigFilePublishInfo) *apiconfig.ConfigResponse { authCtx := s.collectConfigFilePublishAuthContext(ctx, []*apiconfig.ConfigFilePublishInfo{req}, auth.Modify, auth.PublishConfigFile) - if _, err := s.policySvr.GetAuthChecker().CheckClientPermission(authCtx); err != nil { + if _, err := s.policyMgr.GetAuthChecker().CheckClientPermission(authCtx); err != nil { return api.NewConfigFileResponse(auth.ConvertToErrCode(err), nil) } @@ -44,7 +44,7 @@ func (s *Server) UpsertAndReleaseConfigFileFromClient(ctx context.Context, } // CreateConfigFileFromClient 调用config_file的方法创建配置文件 -func (s *Server) CreateConfigFileFromClient(ctx context.Context, +func (s *ServerAuthability) CreateConfigFileFromClient(ctx context.Context, fileInfo *apiconfig.ConfigFile) *apiconfig.ConfigClientResponse { authCtx := s.collectClientConfigFileAuthContext(ctx, []*apiconfig.ConfigFile{{ @@ -52,8 +52,8 @@ func (s *Server) CreateConfigFileFromClient(ctx context.Context, Name: fileInfo.Name, Group: fileInfo.Group}, }, auth.Create, auth.CreateConfigFile) - if _, err := s.policySvr.GetAuthChecker().CheckClientPermission(authCtx); err != nil { - return api.NewConfigClientResponse(auth.ConvertToErrCode(err), nil) + if _, err := s.policyMgr.GetAuthChecker().CheckClientPermission(authCtx); err != nil { + return api.NewConfigClientResponseWithInfo(auth.ConvertToErrCode(err), err.Error()) } ctx = authCtx.GetRequestContext() @@ -63,12 +63,12 @@ func (s *Server) CreateConfigFileFromClient(ctx context.Context, } // UpdateConfigFileFromClient 调用config_file的方法更新配置文件 -func (s *Server) UpdateConfigFileFromClient(ctx context.Context, +func (s *ServerAuthability) UpdateConfigFileFromClient(ctx context.Context, fileInfo *apiconfig.ConfigFile) *apiconfig.ConfigClientResponse { authCtx := s.collectClientConfigFileAuthContext(ctx, []*apiconfig.ConfigFile{fileInfo}, auth.Modify, auth.UpdateConfigFile) - if _, err := s.policySvr.GetAuthChecker().CheckClientPermission(authCtx); err != nil { - return api.NewConfigClientResponse(auth.ConvertToErrCode(err), nil) + if _, err := s.policyMgr.GetAuthChecker().CheckClientPermission(authCtx); err != nil { + return api.NewConfigClientResponseWithInfo(auth.ConvertToErrCode(err), err.Error()) } ctx = authCtx.GetRequestContext() @@ -78,13 +78,13 @@ func (s *Server) UpdateConfigFileFromClient(ctx context.Context, } // DeleteConfigFileFromClient 删除配置文件,删除配置文件同时会通知客户端 Not_Found -func (s *Server) DeleteConfigFileFromClient(ctx context.Context, +func (s *ServerAuthability) DeleteConfigFileFromClient(ctx context.Context, req *apiconfig.ConfigFile) *apiconfig.ConfigResponse { authCtx := s.collectConfigFileAuthContext(ctx, []*apiconfig.ConfigFile{req}, auth.Delete, auth.DeleteConfigFile) - if _, err := s.policySvr.GetAuthChecker().CheckClientPermission(authCtx); err != nil { - return api.NewConfigResponse(auth.ConvertToErrCode(err)) + if _, err := s.policyMgr.GetAuthChecker().CheckClientPermission(authCtx); err != nil { + return api.NewConfigResponseWithInfo(auth.ConvertToErrCode(err), err.Error()) } ctx = authCtx.GetRequestContext() @@ -94,16 +94,16 @@ func (s *Server) DeleteConfigFileFromClient(ctx context.Context, } // PublishConfigFileFromClient 调用config_file_release的方法发布配置文件 -func (s *Server) PublishConfigFileFromClient(ctx context.Context, +func (s *ServerAuthability) PublishConfigFileFromClient(ctx context.Context, fileInfo *apiconfig.ConfigFileRelease) *apiconfig.ConfigClientResponse { - authCtx := s.collectClientConfigFileRelease(ctx, + authCtx := s.collectClientConfigFileReleaseAuthContext(ctx, []*apiconfig.ConfigFileRelease{{ Namespace: fileInfo.Namespace, Name: fileInfo.FileName, Group: fileInfo.Group}, }, auth.Create, auth.PublishConfigFile) - if _, err := s.policySvr.GetAuthChecker().CheckClientPermission(authCtx); err != nil { - return api.NewConfigClientResponse(auth.ConvertToErrCode(err), nil) + if _, err := s.policyMgr.GetAuthChecker().CheckClientPermission(authCtx); err != nil { + return api.NewConfigClientResponseWithInfo(auth.ConvertToErrCode(err), err.Error()) } ctx = authCtx.GetRequestContext() @@ -113,7 +113,7 @@ func (s *Server) PublishConfigFileFromClient(ctx context.Context, } // GetConfigFileWithCache 从缓存中获取配置文件,如果客户端的版本号大于服务端,则服务端重新加载缓存 -func (s *Server) GetConfigFileWithCache(ctx context.Context, +func (s *ServerAuthability) GetConfigFileWithCache(ctx context.Context, fileInfo *apiconfig.ClientConfigFileInfo) *apiconfig.ConfigClientResponse { authCtx := s.collectClientConfigFileAuthContext(ctx, []*apiconfig.ConfigFile{{ @@ -121,8 +121,8 @@ func (s *Server) GetConfigFileWithCache(ctx context.Context, Name: fileInfo.FileName, Group: fileInfo.Group}, }, auth.Read, auth.DiscoverConfigFile) - if _, err := s.policySvr.GetAuthChecker().CheckClientPermission(authCtx); err != nil { - return api.NewConfigClientResponse(auth.ConvertToErrCode(err), nil) + if _, err := s.policyMgr.GetAuthChecker().CheckClientPermission(authCtx); err != nil { + return api.NewConfigClientResponseWithInfo(auth.ConvertToErrCode(err), err.Error()) } ctx = authCtx.GetRequestContext() @@ -131,12 +131,12 @@ func (s *Server) GetConfigFileWithCache(ctx context.Context, } // WatchConfigFiles 监听配置文件变化 -func (s *Server) LongPullWatchFile(ctx context.Context, +func (s *ServerAuthability) LongPullWatchFile(ctx context.Context, request *apiconfig.ClientWatchConfigFileRequest) (config.WatchCallback, error) { authCtx := s.collectClientWatchConfigFiles(ctx, request, auth.Read, auth.WatchConfigFile) - if _, err := s.policySvr.GetAuthChecker().CheckClientPermission(authCtx); err != nil { + if _, err := s.policyMgr.GetAuthChecker().CheckClientPermission(authCtx); err != nil { return func() *apiconfig.ConfigClientResponse { - return api.NewConfigClientResponse(auth.ConvertToErrCode(err), nil) + return api.NewConfigClientResponseWithInfo(auth.ConvertToErrCode(err), err.Error()) }, nil } @@ -147,16 +147,16 @@ func (s *Server) LongPullWatchFile(ctx context.Context, } // GetConfigFileNamesWithCache 获取某个配置分组下的配置文件 -func (s *Server) GetConfigFileNamesWithCache(ctx context.Context, +func (s *ServerAuthability) GetConfigFileNamesWithCache(ctx context.Context, req *apiconfig.ConfigFileGroupRequest) *apiconfig.ConfigClientListResponse { - authCtx := s.collectClientConfigFileRelease(ctx, []*apiconfig.ConfigFileRelease{ + authCtx := s.collectClientConfigFileReleaseAuthContext(ctx, []*apiconfig.ConfigFileRelease{ { Namespace: req.GetConfigFileGroup().GetNamespace(), Group: req.GetConfigFileGroup().GetName(), }, }, auth.Read, auth.DiscoverConfigFileNames) - if _, err := s.policySvr.GetAuthChecker().CheckClientPermission(authCtx); err != nil { + if _, err := s.policyMgr.GetAuthChecker().CheckClientPermission(authCtx); err != nil { out := api.NewConfigClientListResponse(auth.ConvertToErrCode(err)) return out } @@ -167,15 +167,15 @@ func (s *Server) GetConfigFileNamesWithCache(ctx context.Context, } // GetConfigGroupsWithCache 获取某个命名空间下的配置分组列表 -func (s *Server) GetConfigGroupsWithCache(ctx context.Context, +func (s *ServerAuthability) GetConfigGroupsWithCache(ctx context.Context, req *apiconfig.ClientConfigFileInfo) *apiconfig.ConfigDiscoverResponse { - authCtx := s.collectClientConfigFileRelease(ctx, []*apiconfig.ConfigFileRelease{ + authCtx := s.collectClientConfigFileReleaseAuthContext(ctx, []*apiconfig.ConfigFileRelease{ { Namespace: req.GetNamespace(), }, }, auth.Read, auth.DiscoverConfigGroups) - if _, err := s.policySvr.GetAuthChecker().CheckClientPermission(authCtx); err != nil { + if _, err := s.policyMgr.GetAuthChecker().CheckClientPermission(authCtx); err != nil { out := api.NewConfigDiscoverResponse(auth.ConvertToErrCode(err)) return out } @@ -186,12 +186,12 @@ func (s *Server) GetConfigGroupsWithCache(ctx context.Context, } // CasUpsertAndReleaseConfigFileFromClient 创建/更新配置文件并发布 -func (s *Server) CasUpsertAndReleaseConfigFileFromClient(ctx context.Context, +func (s *ServerAuthability) CasUpsertAndReleaseConfigFileFromClient(ctx context.Context, req *apiconfig.ConfigFilePublishInfo) *apiconfig.ConfigResponse { authCtx := s.collectConfigFilePublishAuthContext(ctx, []*apiconfig.ConfigFilePublishInfo{req}, auth.Modify, auth.UpsertAndReleaseConfigFile) - if _, err := s.policySvr.GetAuthChecker().CheckClientPermission(authCtx); err != nil { + if _, err := s.policyMgr.GetAuthChecker().CheckClientPermission(authCtx); err != nil { return api.NewConfigFileResponse(auth.ConvertToErrCode(err), nil) } diff --git a/config/interceptor/auth/config_file.go b/config/interceptor/auth/config_file.go index f515b8f55..4a3fcaf8c 100644 --- a/config/interceptor/auth/config_file.go +++ b/config/interceptor/auth/config_file.go @@ -28,12 +28,12 @@ import ( ) // CreateConfigFile 创建配置文件 -func (s *Server) CreateConfigFile(ctx context.Context, +func (s *ServerAuthability) CreateConfigFile(ctx context.Context, configFile *apiconfig.ConfigFile) *apiconfig.ConfigResponse { authCtx := s.collectConfigFileAuthContext( ctx, []*apiconfig.ConfigFile{configFile}, auth.Create, auth.CreateConfigFile) - if _, err := s.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewConfigResponse(auth.ConvertToErrCode(err)) + if _, err := s.policyMgr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + return api.NewConfigResponseWithInfo(auth.ConvertToErrCode(err), err.Error()) } ctx = authCtx.GetRequestContext() @@ -43,13 +43,13 @@ func (s *Server) CreateConfigFile(ctx context.Context, } // GetConfigFileRichInfo 获取单个配置文件基础信息,包含发布状态等信息 -func (s *Server) GetConfigFileRichInfo(ctx context.Context, +func (s *ServerAuthability) GetConfigFileRichInfo(ctx context.Context, req *apiconfig.ConfigFile) *apiconfig.ConfigResponse { authCtx := s.collectConfigFileAuthContext( ctx, []*apiconfig.ConfigFile{req}, auth.Read, auth.DescribeConfigFileRichInfo) - if _, err := s.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewConfigResponse(auth.ConvertToErrCode(err)) + if _, err := s.policyMgr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + return api.NewConfigResponseWithInfo(auth.ConvertToErrCode(err), err.Error()) } ctx = authCtx.GetRequestContext() ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) @@ -57,12 +57,12 @@ func (s *Server) GetConfigFileRichInfo(ctx context.Context, } // SearchConfigFile 查询配置文件 -func (s *Server) SearchConfigFile(ctx context.Context, +func (s *ServerAuthability) SearchConfigFile(ctx context.Context, filter map[string]string) *apiconfig.ConfigBatchQueryResponse { authCtx := s.collectConfigFileAuthContext(ctx, nil, auth.Read, auth.DescribeConfigFiles) - if _, err := s.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewConfigBatchQueryResponse(auth.ConvertToErrCode(err)) + if _, err := s.policyMgr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + return api.NewConfigFileBatchQueryResponseWithMessage(auth.ConvertToErrCode(err), err.Error()) } ctx = authCtx.GetRequestContext() ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) @@ -71,12 +71,12 @@ func (s *Server) SearchConfigFile(ctx context.Context, } // UpdateConfigFile 更新配置文件 -func (s *Server) UpdateConfigFile( +func (s *ServerAuthability) UpdateConfigFile( ctx context.Context, configFile *apiconfig.ConfigFile) *apiconfig.ConfigResponse { authCtx := s.collectConfigFileAuthContext( ctx, []*apiconfig.ConfigFile{configFile}, auth.Modify, auth.UpdateConfigFile) - if _, err := s.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewConfigResponse(auth.ConvertToErrCode(err)) + if _, err := s.policyMgr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + return api.NewConfigResponseWithInfo(auth.ConvertToErrCode(err), err.Error()) } ctx = authCtx.GetRequestContext() @@ -86,13 +86,13 @@ func (s *Server) UpdateConfigFile( } // DeleteConfigFile 删除配置文件,删除配置文件同时会通知客户端 Not_Found -func (s *Server) DeleteConfigFile(ctx context.Context, +func (s *ServerAuthability) DeleteConfigFile(ctx context.Context, req *apiconfig.ConfigFile) *apiconfig.ConfigResponse { authCtx := s.collectConfigFileAuthContext(ctx, []*apiconfig.ConfigFile{req}, auth.Delete, auth.DeleteConfigFile) - if _, err := s.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewConfigResponse(auth.ConvertToErrCode(err)) + if _, err := s.policyMgr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + return api.NewConfigResponseWithInfo(auth.ConvertToErrCode(err), err.Error()) } ctx = authCtx.GetRequestContext() @@ -102,12 +102,12 @@ func (s *Server) DeleteConfigFile(ctx context.Context, } // BatchDeleteConfigFile 批量删除配置文件 -func (s *Server) BatchDeleteConfigFile(ctx context.Context, +func (s *ServerAuthability) BatchDeleteConfigFile(ctx context.Context, req []*apiconfig.ConfigFile) *apiconfig.ConfigResponse { authCtx := s.collectConfigFileAuthContext(ctx, req, auth.Delete, auth.BatchDeleteConfigFiles) - if _, err := s.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewConfigResponse(auth.ConvertToErrCode(err)) + if _, err := s.policyMgr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + return api.NewConfigResponseWithInfo(auth.ConvertToErrCode(err), err.Error()) } ctx = authCtx.GetRequestContext() @@ -116,7 +116,7 @@ func (s *Server) BatchDeleteConfigFile(ctx context.Context, return s.nextServer.BatchDeleteConfigFile(ctx, req) } -func (s *Server) ExportConfigFile(ctx context.Context, +func (s *ServerAuthability) ExportConfigFile(ctx context.Context, configFileExport *apiconfig.ConfigFileExportRequest) *apiconfig.ConfigExportResponse { var configFiles []*apiconfig.ConfigFile for _, group := range configFileExport.Groups { @@ -127,8 +127,8 @@ func (s *Server) ExportConfigFile(ctx context.Context, configFiles = append(configFiles, configFile) } authCtx := s.collectConfigFileAuthContext(ctx, configFiles, auth.Read, auth.ExportConfigFiles) - if _, err := s.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewConfigFileExportResponse(auth.ConvertToErrCode(err), nil) + if _, err := s.policyMgr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + return api.NewConfigFileExportResponseWithMessage(auth.ConvertToErrCode(err), err.Error()) } ctx = authCtx.GetRequestContext() ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) @@ -136,11 +136,11 @@ func (s *Server) ExportConfigFile(ctx context.Context, return s.nextServer.ExportConfigFile(ctx, configFileExport) } -func (s *Server) ImportConfigFile(ctx context.Context, +func (s *ServerAuthability) ImportConfigFile(ctx context.Context, configFiles []*apiconfig.ConfigFile, conflictHandling string) *apiconfig.ConfigImportResponse { authCtx := s.collectConfigFileAuthContext(ctx, configFiles, auth.Create, auth.ImportConfigFiles) - if _, err := s.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewSimpleConfigFileImportResponse(auth.ConvertToErrCode(err)) + if _, err := s.policyMgr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + return api.NewConfigFileImportResponseWithMessage(auth.ConvertToErrCode(err), err.Error()) } ctx = authCtx.GetRequestContext() @@ -148,7 +148,7 @@ func (s *Server) ImportConfigFile(ctx context.Context, return s.nextServer.ImportConfigFile(ctx, configFiles, conflictHandling) } -func (s *Server) GetAllConfigEncryptAlgorithms( +func (s *ServerAuthability) GetAllConfigEncryptAlgorithms( ctx context.Context) *apiconfig.ConfigEncryptAlgorithmResponse { return s.nextServer.GetAllConfigEncryptAlgorithms(ctx) } diff --git a/config/interceptor/auth/config_file_group.go b/config/interceptor/auth/config_file_group.go index dcad994d7..caa466c57 100644 --- a/config/interceptor/auth/config_file_group.go +++ b/config/interceptor/auth/config_file_group.go @@ -19,29 +19,24 @@ package config_auth import ( "context" - "strconv" apiconfig "github.com/polarismesh/specification/source/go/api/v1/config_manage" - "github.com/polarismesh/specification/source/go/api/v1/security" - apisecurity "github.com/polarismesh/specification/source/go/api/v1/security" - cachetypes "github.com/polarismesh/polaris/cache/api" api "github.com/polarismesh/polaris/common/api/v1" "github.com/polarismesh/polaris/common/model" "github.com/polarismesh/polaris/common/model/auth" - authcommon "github.com/polarismesh/polaris/common/model/auth" "github.com/polarismesh/polaris/common/utils" ) // CreateConfigFileGroup 创建配置文件组 -func (s *Server) CreateConfigFileGroup(ctx context.Context, +func (s *ServerAuthability) CreateConfigFileGroup(ctx context.Context, configFileGroup *apiconfig.ConfigFileGroup) *apiconfig.ConfigResponse { authCtx := s.collectConfigGroupAuthContext(ctx, []*apiconfig.ConfigFileGroup{configFileGroup}, - authcommon.Create, authcommon.CreateConfigFileGroup) + auth.Create, auth.CreateConfigFileGroup) // 验证 token 信息 - if _, err := s.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewConfigResponse(authcommon.ConvertToErrCode(err)) + if _, err := s.policyMgr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + return api.NewConfigResponseWithInfo(auth.ConvertToErrCode(err), err.Error()) } ctx = authCtx.GetRequestContext() @@ -51,79 +46,41 @@ func (s *Server) CreateConfigFileGroup(ctx context.Context, } // QueryConfigFileGroups 查询配置文件组 -func (s *Server) QueryConfigFileGroups(ctx context.Context, +func (s *ServerAuthability) QueryConfigFileGroups(ctx context.Context, filter map[string]string) *apiconfig.ConfigBatchQueryResponse { - authCtx := s.collectConfigGroupAuthContext(ctx, nil, authcommon.Read, authcommon.DescribeConfigFileGroups) - if _, err := s.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewConfigBatchQueryResponse(authcommon.ConvertToErrCode(err)) + authCtx := s.collectConfigGroupAuthContext(ctx, nil, auth.Read, auth.DescribeConfigFileGroups) + + if _, err := s.policyMgr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + return api.NewConfigBatchQueryResponse(auth.ConvertToErrCode(err)) } ctx = authCtx.GetRequestContext() ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) - ctx = cachetypes.AppendConfigGroupPredicate(ctx, func(ctx context.Context, cfg *model.ConfigFileGroup) bool { - ok := s.policySvr.GetAuthChecker().ResourcePredicate(authCtx, &authcommon.ResourceEntry{ - Type: apisecurity.ResourceType_ConfigGroups, - ID: strconv.FormatUint(cfg.Id, 10), - Metadata: cfg.Metadata, - }) - if ok { - return true - } - saveNs := s.cacheMgr.Namespace().GetNamespace(cfg.Namespace) - if saveNs == nil { - return false - } - // 检查下是否可以访问对应的 namespace - return s.policySvr.GetAuthChecker().ResourcePredicate(authCtx, &authcommon.ResourceEntry{ - Type: security.ResourceType_Namespaces, - ID: saveNs.Name, - Metadata: saveNs.Metadata, - }) - }) - authCtx.SetRequestContext(ctx) - resp := s.nextServer.QueryConfigFileGroups(ctx, filter) if len(resp.ConfigFileGroups) != 0 { for index := range resp.ConfigFileGroups { - item := resp.ConfigFileGroups[index] - authCtx.SetAccessResources(map[apisecurity.ResourceType][]authcommon.ResourceEntry{ - apisecurity.ResourceType_ConfigGroups: { - { - Type: apisecurity.ResourceType_ConfigGroups, - ID: strconv.FormatUint(item.GetId().GetValue(), 10), - Metadata: item.Metadata, - }, - }, - }) - - // 检查 write 操作权限 - authCtx.SetMethod([]authcommon.ServerFunctionName{authcommon.UpdateConfigFileGroup}) - // 如果检查不通过,设置 editable 为 false - if _, err := s.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - item.Editable = utils.NewBoolValue(false) - } - - // 检查 delete 操作权限 - authCtx.SetMethod([]authcommon.ServerFunctionName{authcommon.DeleteConfigFileGroup}) - // 如果检查不通过,设置 editable 为 false - if _, err := s.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - item.Deleteable = utils.NewBoolValue(false) + group := resp.ConfigFileGroups[index] + editable := true + // 如果包含特殊标签,也不允许修改 + if _, ok := group.GetMetadata()[model.MetaKey3RdPlatform]; ok { + editable = false } + group.Editable = utils.NewBoolValue(editable) } } return resp } // DeleteConfigFileGroup 删除配置文件组 -func (s *Server) DeleteConfigFileGroup( +func (s *ServerAuthability) DeleteConfigFileGroup( ctx context.Context, namespace, name string) *apiconfig.ConfigResponse { authCtx := s.collectConfigGroupAuthContext(ctx, []*apiconfig.ConfigFileGroup{{Name: utils.NewStringValue(name), Namespace: utils.NewStringValue(namespace)}}, auth.Delete, auth.DeleteConfigFileGroup) - if _, err := s.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewConfigResponse(auth.ConvertToErrCode(err)) + if _, err := s.policyMgr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + return api.NewConfigResponseWithInfo(auth.ConvertToErrCode(err), err.Error()) } ctx = authCtx.GetRequestContext() @@ -133,13 +90,13 @@ func (s *Server) DeleteConfigFileGroup( } // UpdateConfigFileGroup 更新配置文件组 -func (s *Server) UpdateConfigFileGroup(ctx context.Context, +func (s *ServerAuthability) UpdateConfigFileGroup(ctx context.Context, configFileGroup *apiconfig.ConfigFileGroup) *apiconfig.ConfigResponse { authCtx := s.collectConfigGroupAuthContext(ctx, []*apiconfig.ConfigFileGroup{configFileGroup}, auth.Modify, auth.UpdateConfigFileGroup) - if _, err := s.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewConfigResponse(auth.ConvertToErrCode(err)) + if _, err := s.policyMgr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + return api.NewConfigResponseWithInfo(auth.ConvertToErrCode(err), err.Error()) } ctx = authCtx.GetRequestContext() diff --git a/config/interceptor/auth/config_file_release.go b/config/interceptor/auth/config_file_release.go index 599d281fa..3a2df3e04 100644 --- a/config/interceptor/auth/config_file_release.go +++ b/config/interceptor/auth/config_file_release.go @@ -28,14 +28,14 @@ import ( ) // PublishConfigFile 发布配置文件 -func (s *Server) PublishConfigFile(ctx context.Context, +func (s *ServerAuthability) PublishConfigFile(ctx context.Context, configFileRelease *apiconfig.ConfigFileRelease) *apiconfig.ConfigResponse { authCtx := s.collectConfigFileReleaseAuthContext(ctx, []*apiconfig.ConfigFileRelease{configFileRelease}, auth.Modify, "PublishConfigFile") - if _, err := s.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewConfigResponse(auth.ConvertToErrCode(err)) + if _, err := s.policyMgr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + return api.NewConfigResponseWithInfo(auth.ConvertToErrCode(err), err.Error()) } ctx = authCtx.GetRequestContext() @@ -45,14 +45,14 @@ func (s *Server) PublishConfigFile(ctx context.Context, } // GetConfigFileRelease 获取配置文件发布内容 -func (s *Server) GetConfigFileRelease(ctx context.Context, +func (s *ServerAuthability) GetConfigFileRelease(ctx context.Context, req *apiconfig.ConfigFileRelease) *apiconfig.ConfigResponse { authCtx := s.collectConfigFileReleaseAuthContext(ctx, []*apiconfig.ConfigFileRelease{req}, auth.Read, auth.DescribeConfigFileRelease) - if _, err := s.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewConfigResponse(auth.ConvertToErrCode(err)) + if _, err := s.policyMgr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + return api.NewConfigResponseWithInfo(auth.ConvertToErrCode(err), err.Error()) } ctx = authCtx.GetRequestContext() ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) @@ -60,13 +60,13 @@ func (s *Server) GetConfigFileRelease(ctx context.Context, } // DeleteConfigFileReleases implements ConfigCenterServer. -func (s *Server) DeleteConfigFileReleases(ctx context.Context, +func (s *ServerAuthability) DeleteConfigFileReleases(ctx context.Context, reqs []*apiconfig.ConfigFileRelease) *apiconfig.ConfigBatchWriteResponse { authCtx := s.collectConfigFileReleaseAuthContext(ctx, reqs, auth.Delete, auth.DeleteConfigFileReleases) - if _, err := s.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewConfigBatchWriteResponse(auth.ConvertToErrCode(err)) + if _, err := s.policyMgr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + return api.NewConfigBatchWriteResponseWithInfo(auth.ConvertToErrCode(err), err.Error()) } ctx = authCtx.GetRequestContext() ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) @@ -74,12 +74,13 @@ func (s *Server) DeleteConfigFileReleases(ctx context.Context, } // GetConfigFileReleaseVersions implements ConfigCenterServer. -func (s *Server) GetConfigFileReleaseVersions(ctx context.Context, +func (s *ServerAuthability) GetConfigFileReleaseVersions(ctx context.Context, filters map[string]string) *apiconfig.ConfigBatchQueryResponse { + authCtx := s.collectConfigFileReleaseAuthContext(ctx, nil, auth.Read, auth.DescribeConfigFileReleaseVersions) - if _, err := s.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewConfigBatchQueryResponse(auth.ConvertToErrCode(err)) + if _, err := s.policyMgr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + return api.NewConfigBatchQueryResponseWithInfo(auth.ConvertToErrCode(err), err.Error()) } ctx = authCtx.GetRequestContext() ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) @@ -87,28 +88,27 @@ func (s *Server) GetConfigFileReleaseVersions(ctx context.Context, } // GetConfigFileReleases implements ConfigCenterServer. -func (s *Server) GetConfigFileReleases(ctx context.Context, +func (s *ServerAuthability) GetConfigFileReleases(ctx context.Context, filters map[string]string) *apiconfig.ConfigBatchQueryResponse { authCtx := s.collectConfigFileReleaseAuthContext(ctx, nil, auth.Read, auth.DescribeConfigFileReleases) - if _, err := s.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewConfigBatchQueryResponse(auth.ConvertToErrCode(err)) + if _, err := s.policyMgr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + return api.NewConfigBatchQueryResponseWithInfo(auth.ConvertToErrCode(err), err.Error()) } ctx = authCtx.GetRequestContext() ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) - return s.nextServer.GetConfigFileReleases(ctx, filters) } // RollbackConfigFileReleases implements ConfigCenterServer. -func (s *Server) RollbackConfigFileReleases(ctx context.Context, +func (s *ServerAuthability) RollbackConfigFileReleases(ctx context.Context, reqs []*apiconfig.ConfigFileRelease) *apiconfig.ConfigBatchWriteResponse { authCtx := s.collectConfigFileReleaseAuthContext(ctx, reqs, auth.Modify, auth.RollbackConfigFileReleases) - if _, err := s.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewConfigBatchWriteResponse(auth.ConvertToErrCode(err)) + if _, err := s.policyMgr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + return api.NewConfigBatchWriteResponseWithInfo(auth.ConvertToErrCode(err), err.Error()) } ctx = authCtx.GetRequestContext() ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) @@ -116,12 +116,12 @@ func (s *Server) RollbackConfigFileReleases(ctx context.Context, } // UpsertAndReleaseConfigFile . -func (s *Server) UpsertAndReleaseConfigFile(ctx context.Context, +func (s *ServerAuthability) UpsertAndReleaseConfigFile(ctx context.Context, req *apiconfig.ConfigFilePublishInfo) *apiconfig.ConfigResponse { authCtx := s.collectConfigFilePublishAuthContext(ctx, []*apiconfig.ConfigFilePublishInfo{req}, auth.Modify, auth.UpsertAndReleaseConfigFile) - if _, err := s.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewConfigResponse(auth.ConvertToErrCode(err)) + if _, err := s.policyMgr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + return api.NewConfigFileResponse(auth.ConvertToErrCode(err), nil) } ctx = authCtx.GetRequestContext() @@ -130,12 +130,12 @@ func (s *Server) UpsertAndReleaseConfigFile(ctx context.Context, return s.nextServer.UpsertAndReleaseConfigFile(ctx, req) } -func (s *Server) StopGrayConfigFileReleases(ctx context.Context, +func (s *ServerAuthability) StopGrayConfigFileReleases(ctx context.Context, reqs []*apiconfig.ConfigFileRelease) *apiconfig.ConfigBatchWriteResponse { authCtx := s.collectConfigFileReleaseAuthContext(ctx, reqs, auth.Modify, auth.StopGrayConfigFileReleases) - if _, err := s.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + if _, err := s.policyMgr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { return api.NewConfigBatchWriteResponse(auth.ConvertToErrCode(err)) } diff --git a/config/interceptor/auth/config_file_release_history.go b/config/interceptor/auth/config_file_release_history.go index c2c0bfc75..3b4464c7e 100644 --- a/config/interceptor/auth/config_file_release_history.go +++ b/config/interceptor/auth/config_file_release_history.go @@ -28,14 +28,14 @@ import ( ) // GetConfigFileReleaseHistory 获取配置文件发布历史记录 -func (s *Server) GetConfigFileReleaseHistories(ctx context.Context, +func (s *ServerAuthability) GetConfigFileReleaseHistories(ctx context.Context, filter map[string]string) *apiconfig.ConfigBatchQueryResponse { + authCtx := s.collectConfigFileReleaseHistoryAuthContext(ctx, nil, auth.Read, auth.DescribeConfigFileReleaseHistories) - if _, err := s.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewConfigBatchQueryResponse(auth.ConvertToErrCode(err)) + if _, err := s.policyMgr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + return api.NewConfigBatchQueryResponseWithInfo(auth.ConvertToErrCode(err), err.Error()) } - ctx = authCtx.GetRequestContext() ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) return s.nextServer.GetConfigFileReleaseHistories(ctx, filter) diff --git a/config/interceptor/auth/config_file_template.go b/config/interceptor/auth/config_file_template.go index 7d83e465f..dae5f9ab9 100644 --- a/config/interceptor/auth/config_file_template.go +++ b/config/interceptor/auth/config_file_template.go @@ -28,10 +28,10 @@ import ( ) // GetAllConfigFileTemplates get all config file templates -func (s *Server) GetAllConfigFileTemplates(ctx context.Context) *apiconfig.ConfigBatchQueryResponse { +func (s *ServerAuthability) GetAllConfigFileTemplates(ctx context.Context) *apiconfig.ConfigBatchQueryResponse { authCtx := s.collectConfigFileTemplateAuthContext(ctx, []*apiconfig.ConfigFileTemplate{}, auth.Read, auth.DescribeAllConfigFileTemplates) - if _, err := s.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + if _, err := s.policyMgr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { return api.NewConfigFileBatchQueryResponseWithMessage(auth.ConvertToErrCode(err), err.Error()) } @@ -41,10 +41,10 @@ func (s *Server) GetAllConfigFileTemplates(ctx context.Context) *apiconfig.Confi } // GetConfigFileTemplate get config file template -func (s *Server) GetConfigFileTemplate(ctx context.Context, name string) *apiconfig.ConfigResponse { +func (s *ServerAuthability) GetConfigFileTemplate(ctx context.Context, name string) *apiconfig.ConfigResponse { authCtx := s.collectConfigFileTemplateAuthContext(ctx, []*apiconfig.ConfigFileTemplate{}, auth.Read, auth.DescribeConfigFileTemplate) - if _, err := s.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + if _, err := s.policyMgr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { return api.NewConfigResponseWithInfo(auth.ConvertToErrCode(err), err.Error()) } @@ -54,12 +54,12 @@ func (s *Server) GetConfigFileTemplate(ctx context.Context, name string) *apicon } // CreateConfigFileTemplate create config file template -func (s *Server) CreateConfigFileTemplate(ctx context.Context, +func (s *ServerAuthability) CreateConfigFileTemplate(ctx context.Context, template *apiconfig.ConfigFileTemplate) *apiconfig.ConfigResponse { authCtx := s.collectConfigFileTemplateAuthContext(ctx, []*apiconfig.ConfigFileTemplate{template}, auth.Create, auth.CreateConfigFileTemplate) - if _, err := s.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + if _, err := s.policyMgr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { return api.NewConfigResponseWithInfo(auth.ConvertToErrCode(err), err.Error()) } diff --git a/config/interceptor/auth/resource_listener.go b/config/interceptor/auth/resource_listener.go index 23d73792f..f1fda9b46 100644 --- a/config/interceptor/auth/resource_listener.go +++ b/config/interceptor/auth/resource_listener.go @@ -30,12 +30,12 @@ import ( ) // Before this function is called before the resource operation -func (s *Server) Before(ctx context.Context, resourceType model.Resource) { +func (s *ServerAuthability) Before(ctx context.Context, resourceType model.Resource) { // do nothing } // After this function is called after the resource operation -func (s *Server) After(ctx context.Context, resourceType model.Resource, res *config.ResourceEvent) error { +func (s *ServerAuthability) After(ctx context.Context, resourceType model.Resource, res *config.ResourceEvent) error { switch resourceType { case model.RConfigGroup: return s.onConfigGroupResource(ctx, res) @@ -45,7 +45,7 @@ func (s *Server) After(ctx context.Context, resourceType model.Resource, res *co } // onConfigGroupResource -func (s *Server) onConfigGroupResource(ctx context.Context, res *config.ResourceEvent) error { +func (s *ServerAuthability) onConfigGroupResource(ctx context.Context, res *config.ResourceEvent) error { authCtx := ctx.Value(utils.ContextAuthContextKey).(*auth.AcquireContext) authCtx.SetAttachment(auth.ResourceAttachmentKey, map[apisecurity.ResourceType][]auth.ResourceEntry{ @@ -69,5 +69,5 @@ func (s *Server) onConfigGroupResource(ctx context.Context, res *config.Resource authCtx.SetAttachment(auth.LinkGroupsKey, utils.StringSliceDeDuplication(groups)) authCtx.SetAttachment(auth.RemoveLinkGroupsKey, utils.StringSliceDeDuplication(removeGroups)) - return s.policySvr.AfterResourceOperation(authCtx) + return s.policyMgr.AfterResourceOperation(authCtx) } diff --git a/config/interceptor/auth/server.go b/config/interceptor/auth/server.go index f7b808d99..4fece2d2b 100644 --- a/config/interceptor/auth/server.go +++ b/config/interceptor/auth/server.go @@ -33,23 +33,23 @@ import ( "github.com/polarismesh/polaris/config" ) -var _ config.ConfigCenterServer = (*Server)(nil) +var _ config.ConfigCenterServer = (*ServerAuthability)(nil) // Server 配置中心核心服务 -type Server struct { +type ServerAuthability struct { cacheMgr cachetypes.CacheManager nextServer config.ConfigCenterServer - userSvr auth.UserServer - policySvr auth.StrategyServer + userMgn auth.UserServer + policyMgr auth.StrategyServer } func New(nextServer config.ConfigCenterServer, cacheMgr cachetypes.CacheManager, - userSvr auth.UserServer, policySvr auth.StrategyServer) config.ConfigCenterServer { - proxy := &Server{ + userMgr auth.UserServer, strategyMgr auth.StrategyServer) config.ConfigCenterServer { + proxy := &ServerAuthability{ nextServer: nextServer, cacheMgr: cacheMgr, - userSvr: userSvr, - policySvr: policySvr, + userMgn: userMgr, + policyMgr: strategyMgr, } if val, ok := nextServer.(*config.Server); ok { val.SetResourceHooks(proxy) @@ -57,7 +57,7 @@ func New(nextServer config.ConfigCenterServer, cacheMgr cachetypes.CacheManager, return proxy } -func (s *Server) collectConfigFileAuthContext(ctx context.Context, req []*apiconfig.ConfigFile, +func (s *ServerAuthability) collectConfigFileAuthContext(ctx context.Context, req []*apiconfig.ConfigFile, op authcommon.ResourceOperation, methodName authcommon.ServerFunctionName) *authcommon.AcquireContext { return authcommon.NewAcquireContext( authcommon.WithRequestContext(ctx), @@ -68,7 +68,7 @@ func (s *Server) collectConfigFileAuthContext(ctx context.Context, req []*apicon ) } -func (s *Server) collectClientConfigFileAuthContext(ctx context.Context, req []*apiconfig.ConfigFile, +func (s *ServerAuthability) collectClientConfigFileAuthContext(ctx context.Context, req []*apiconfig.ConfigFile, op authcommon.ResourceOperation, methodName authcommon.ServerFunctionName) *authcommon.AcquireContext { return authcommon.NewAcquireContext( authcommon.WithRequestContext(ctx), @@ -80,8 +80,8 @@ func (s *Server) collectClientConfigFileAuthContext(ctx context.Context, req []* ) } -func (s *Server) collectClientWatchConfigFiles(ctx context.Context, req *apiconfig.ClientWatchConfigFileRequest, - op authcommon.ResourceOperation, methodName authcommon.ServerFunctionName) *authcommon.AcquireContext { +func (s *ServerAuthability) collectClientWatchConfigFiles(ctx context.Context, + req *apiconfig.ClientWatchConfigFileRequest, op authcommon.ResourceOperation, methodName authcommon.ServerFunctionName) *authcommon.AcquireContext { return authcommon.NewAcquireContext( authcommon.WithRequestContext(ctx), authcommon.WithModule(authcommon.ConfigModule), @@ -92,7 +92,7 @@ func (s *Server) collectClientWatchConfigFiles(ctx context.Context, req *apiconf ) } -func (s *Server) collectConfigFileReleaseAuthContext(ctx context.Context, req []*apiconfig.ConfigFileRelease, +func (s *ServerAuthability) collectConfigFileReleaseAuthContext(ctx context.Context, req []*apiconfig.ConfigFileRelease, op authcommon.ResourceOperation, methodName authcommon.ServerFunctionName) *authcommon.AcquireContext { return authcommon.NewAcquireContext( authcommon.WithRequestContext(ctx), @@ -103,7 +103,7 @@ func (s *Server) collectConfigFileReleaseAuthContext(ctx context.Context, req [] ) } -func (s *Server) collectConfigFilePublishAuthContext(ctx context.Context, req []*apiconfig.ConfigFilePublishInfo, +func (s *ServerAuthability) collectConfigFilePublishAuthContext(ctx context.Context, req []*apiconfig.ConfigFilePublishInfo, op authcommon.ResourceOperation, methodName authcommon.ServerFunctionName) *authcommon.AcquireContext { return authcommon.NewAcquireContext( authcommon.WithRequestContext(ctx), @@ -114,8 +114,8 @@ func (s *Server) collectConfigFilePublishAuthContext(ctx context.Context, req [] ) } -func (s *Server) collectClientConfigFileRelease(ctx context.Context, req []*apiconfig.ConfigFileRelease, - op authcommon.ResourceOperation, methodName authcommon.ServerFunctionName) *authcommon.AcquireContext { +func (s *ServerAuthability) collectClientConfigFileReleaseAuthContext(ctx context.Context, + req []*apiconfig.ConfigFileRelease, op authcommon.ResourceOperation, methodName authcommon.ServerFunctionName) *authcommon.AcquireContext { return authcommon.NewAcquireContext( authcommon.WithRequestContext(ctx), authcommon.WithModule(authcommon.ConfigModule), @@ -126,7 +126,7 @@ func (s *Server) collectClientConfigFileRelease(ctx context.Context, req []*apic ) } -func (s *Server) collectConfigFileReleaseHistoryAuthContext( +func (s *ServerAuthability) collectConfigFileReleaseHistoryAuthContext( ctx context.Context, req []*apiconfig.ConfigFileReleaseHistory, op authcommon.ResourceOperation, methodName authcommon.ServerFunctionName) *authcommon.AcquireContext { @@ -139,7 +139,7 @@ func (s *Server) collectConfigFileReleaseHistoryAuthContext( ) } -func (s *Server) collectConfigGroupAuthContext(ctx context.Context, req []*apiconfig.ConfigFileGroup, +func (s *ServerAuthability) collectConfigGroupAuthContext(ctx context.Context, req []*apiconfig.ConfigFileGroup, op authcommon.ResourceOperation, methodName authcommon.ServerFunctionName) *authcommon.AcquireContext { return authcommon.NewAcquireContext( authcommon.WithRequestContext(ctx), @@ -150,15 +150,15 @@ func (s *Server) collectConfigGroupAuthContext(ctx context.Context, req []*apico ) } -func (s *Server) collectConfigFileTemplateAuthContext(ctx context.Context, req []*apiconfig.ConfigFileTemplate, - op authcommon.ResourceOperation, methodName authcommon.ServerFunctionName) *authcommon.AcquireContext { +func (s *ServerAuthability) collectConfigFileTemplateAuthContext(ctx context.Context, + req []*apiconfig.ConfigFileTemplate, op authcommon.ResourceOperation, methodName authcommon.ServerFunctionName) *authcommon.AcquireContext { return authcommon.NewAcquireContext( authcommon.WithRequestContext(ctx), authcommon.WithModule(authcommon.ConfigModule), ) } -func (s *Server) queryConfigGroupResource(ctx context.Context, +func (s *ServerAuthability) queryConfigGroupResource(ctx context.Context, req []*apiconfig.ConfigFileGroup) map[apisecurity.ResourceType][]authcommon.ResourceEntry { if len(req) == 0 { @@ -188,7 +188,7 @@ func (s *Server) queryConfigGroupResource(ctx context.Context, } // queryConfigFileResource config file资源的鉴权转换为config group的鉴权 -func (s *Server) queryConfigFileResource(ctx context.Context, +func (s *ServerAuthability) queryConfigFileResource(ctx context.Context, req []*apiconfig.ConfigFile) map[apisecurity.ResourceType][]authcommon.ResourceEntry { if len(req) == 0 { @@ -214,7 +214,7 @@ func (s *Server) queryConfigFileResource(ctx context.Context, return ret } -func (s *Server) queryConfigFileReleaseResource(ctx context.Context, +func (s *ServerAuthability) queryConfigFileReleaseResource(ctx context.Context, req []*apiconfig.ConfigFileRelease) map[apisecurity.ResourceType][]authcommon.ResourceEntry { if len(req) == 0 { @@ -240,7 +240,7 @@ func (s *Server) queryConfigFileReleaseResource(ctx context.Context, return ret } -func (s *Server) queryConfigFilePublishResource(ctx context.Context, +func (s *ServerAuthability) queryConfigFilePublishResource(ctx context.Context, req []*apiconfig.ConfigFilePublishInfo) map[apisecurity.ResourceType][]authcommon.ResourceEntry { if len(req) == 0 { @@ -264,7 +264,7 @@ func (s *Server) queryConfigFilePublishResource(ctx context.Context, return ret } -func (s *Server) queryConfigFileReleaseHistoryResource(ctx context.Context, +func (s *ServerAuthability) queryConfigFileReleaseHistoryResource(ctx context.Context, req []*apiconfig.ConfigFileReleaseHistory) map[apisecurity.ResourceType][]authcommon.ResourceEntry { if len(req) == 0 { @@ -290,7 +290,7 @@ func (s *Server) queryConfigFileReleaseHistoryResource(ctx context.Context, return ret } -func (s *Server) queryConfigGroupRsEntryByNames(ctx context.Context, namespace string, +func (s *ServerAuthability) queryConfigGroupRsEntryByNames(ctx context.Context, namespace string, names []string) ([]authcommon.ResourceEntry, error) { configFileGroups := make([]*model.ConfigFileGroup, 0, len(names)) @@ -315,7 +315,7 @@ func (s *Server) queryConfigGroupRsEntryByNames(ctx context.Context, namespace s return entries, nil } -func (s *Server) queryWatchConfigFilesResource(ctx context.Context, +func (s *ServerAuthability) queryWatchConfigFilesResource(ctx context.Context, req *apiconfig.ClientWatchConfigFileRequest) map[apisecurity.ResourceType][]authcommon.ResourceEntry { files := req.GetWatchFiles() if len(files) == 0 { diff --git a/go.mod b/go.mod index f2590eca3..f347292b3 100644 --- a/go.mod +++ b/go.mod @@ -5,14 +5,14 @@ go 1.21 require ( github.com/BurntSushi/toml v1.2.0 github.com/emicklei/go-restful/v3 v3.9.0 - github.com/envoyproxy/go-control-plane v0.13.0 + github.com/envoyproxy/go-control-plane v0.12.0 github.com/go-openapi/spec v0.20.7 github.com/go-redis/redis/v8 v8.11.5 github.com/go-sql-driver/mysql v1.6.0 github.com/gogo/protobuf v1.3.2 github.com/golang/mock v1.6.0 - github.com/golang/protobuf v1.5.4 - github.com/google/uuid v1.6.0 + github.com/golang/protobuf v1.5.3 + github.com/google/uuid v1.3.0 github.com/hashicorp/go-multierror v1.1.1 github.com/hashicorp/golang-lru v0.5.4 github.com/json-iterator/go v1.1.12 // indirect @@ -24,17 +24,17 @@ require ( github.com/prometheus/client_golang v1.18.0 github.com/smartystreets/goconvey v1.6.4 github.com/spf13/cobra v1.2.1 - github.com/stretchr/testify v1.9.0 + github.com/stretchr/testify v1.8.4 go.uber.org/atomic v1.10.0 go.uber.org/automaxprocs v1.4.0 go.uber.org/zap v1.23.0 - golang.org/x/crypto v0.23.0 - golang.org/x/net v0.25.0 - golang.org/x/sync v0.7.0 - golang.org/x/text v0.15.0 + golang.org/x/crypto v0.21.0 + golang.org/x/net v0.23.0 + golang.org/x/sync v0.6.0 + golang.org/x/text v0.14.0 golang.org/x/time v0.1.1-0.20221020023724-80b9fac54d29 - google.golang.org/grpc v1.65.0 - google.golang.org/protobuf v1.34.1 + google.golang.org/grpc v1.58.3 + google.golang.org/protobuf v1.33.0 gopkg.in/yaml.v2 v2.4.0 ) @@ -48,11 +48,11 @@ require ( require ( github.com/beorn7/perks v1.0.1 // indirect github.com/census-instrumentation/opencensus-proto v0.4.1 // indirect - github.com/cespare/xxhash/v2 v2.3.0 // indirect - github.com/cncf/xds/go v0.0.0-20240423153145-555b57ec207b // indirect + github.com/cespare/xxhash/v2 v2.2.0 // indirect + github.com/cncf/xds/go v0.0.0-20230607035331-e9ce68804cb4 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect - github.com/envoyproxy/protoc-gen-validate v1.0.4 // indirect + github.com/envoyproxy/protoc-gen-validate v1.0.2 // indirect github.com/go-openapi/jsonpointer v0.19.5 // indirect github.com/go-openapi/jsonreference v0.20.0 // indirect github.com/go-openapi/swag v0.19.15 // indirect @@ -65,33 +65,31 @@ require ( github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect - github.com/prometheus/client_model v0.6.0 // indirect + github.com/prometheus/client_model v0.5.0 // indirect github.com/prometheus/common v0.45.0 // indirect github.com/prometheus/procfs v0.12.0 // indirect github.com/smartystreets/assertions v1.0.1 // indirect github.com/spf13/pflag v1.0.5 // indirect + go.uber.org/goleak v1.1.12 // indirect go.uber.org/multierr v1.8.0 // indirect - golang.org/x/sys v0.20.0 // indirect + golang.org/x/sys v0.18.0 // indirect + google.golang.org/genproto v0.0.0-20230711160842-782d3b101e98 // indirect gopkg.in/natefinch/lumberjack.v2 v2.0.0 // indirect - gopkg.in/yaml.v3 v3.0.1 + gopkg.in/yaml.v3 v3.0.1 // indirect ) require ( github.com/DATA-DOG/go-sqlmock v1.5.0 - github.com/polarismesh/specification v1.5.3-alpha.2 + github.com/polarismesh/specification v1.5.2-0.20240722103923-1d9990d6f555 ) -require ( - cel.dev/expr v0.15.0 // indirect - github.com/matttproud/golang_protobuf_extensions/v2 v2.0.0 // indirect - github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10 // indirect -) +require github.com/matttproud/golang_protobuf_extensions/v2 v2.0.0 // indirect require ( github.com/dlclark/regexp2 v1.10.0 go.etcd.io/bbolt v1.3.7 - google.golang.org/genproto/googleapis/api v0.0.0-20240528184218-531527333157 // indirect - google.golang.org/genproto/googleapis/rpc v0.0.0-20240528184218-531527333157 // indirect + google.golang.org/genproto/googleapis/api v0.0.0-20230711160842-782d3b101e98 // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20230711160842-782d3b101e98 // indirect ) replace gopkg.in/yaml.v2 => gopkg.in/yaml.v2 v2.2.2 diff --git a/go.sum b/go.sum index ef02521f5..9310d5282 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,3 @@ -cel.dev/expr v0.15.0 h1:O1jzfJCQBfL5BFoYktaxwIhuttaQPsVWerH9/EEKx0w= -cel.dev/expr v0.15.0/go.mod h1:TRSuuV7DlVCE/uwv5QbAiW/v8l5O8C4eEPHeu7gf7Sg= cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= cloud.google.com/go v0.38.0/go.mod h1:990N+gfupTy94rShfmMCWGDn0LpTmnzTp2qbd1dvSRU= @@ -63,8 +61,8 @@ github.com/bketelsen/crypt v0.0.4/go.mod h1:aI6NrJ0pMGgvZKL1iVgXLnfIFJtfV+bKCoqO github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= github.com/census-instrumentation/opencensus-proto v0.4.1 h1:iKLQ0xPNFxR/2hzXZMrBo8f1j86j5WHzznCCQxV/b8g= github.com/census-instrumentation/opencensus-proto v0.4.1/go.mod h1:4T9NM4+4Vw91VeyqjLS6ao50K5bOcLKN6Q42XnYaRYw= -github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= -github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44= +github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI= github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI= github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU= @@ -72,8 +70,8 @@ github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDk github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= github.com/cncf/udpa/go v0.0.0-20200629203442-efcf912fb354/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk= github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk= -github.com/cncf/xds/go v0.0.0-20240423153145-555b57ec207b h1:ga8SEFjZ60pxLcmhnThWgvH2wg8376yUJmPhEH4H3kw= -github.com/cncf/xds/go v0.0.0-20240423153145-555b57ec207b/go.mod h1:W+zGtBO5Y1IgJhy4+A9GOqVhqLpfZi+vwmdNXUehLA8= +github.com/cncf/xds/go v0.0.0-20230607035331-e9ce68804cb4 h1:/inchEIKaYC1Akx+H+gqO04wryn5h75LSazbRlnya1k= +github.com/cncf/xds/go v0.0.0-20230607035331-e9ce68804cb4/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/coreos/go-semver v0.3.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk= github.com/coreos/go-systemd/v22 v22.3.2/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/cpuguy83/go-md2man/v2 v2.0.0/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU= @@ -94,11 +92,11 @@ github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1m github.com/envoyproxy/go-control-plane v0.9.7/go.mod h1:cwu0lG7PUMfa9snN8LXBig5ynNVH9qI8YYLbd1fK2po= github.com/envoyproxy/go-control-plane v0.9.9-0.20201210154907-fd9021fe5dad/go.mod h1:cXg6YxExXjJnVBQHBLXeUAgxn2UodCpnH306RInaBQk= github.com/envoyproxy/go-control-plane v0.9.9-0.20210217033140-668b12f5399d/go.mod h1:cXg6YxExXjJnVBQHBLXeUAgxn2UodCpnH306RInaBQk= -github.com/envoyproxy/go-control-plane v0.13.0 h1:HzkeUz1Knt+3bK+8LG1bxOO/jzWZmdxpwC51i202les= -github.com/envoyproxy/go-control-plane v0.13.0/go.mod h1:GRaKG3dwvFoTg4nj7aXdZnvMg4d7nvT/wl9WgVXn3Q8= +github.com/envoyproxy/go-control-plane v0.12.0 h1:4X+VP1GHd1Mhj6IB5mMeGbLCleqxjletLK6K0rbxyZI= +github.com/envoyproxy/go-control-plane v0.12.0/go.mod h1:ZBTaoJ23lqITozF0M6G4/IragXCQKCnYbmlmtHvwRG0= github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= -github.com/envoyproxy/protoc-gen-validate v1.0.4 h1:gVPz/FMfvh57HdSJQyvBtF00j8JU4zdyUgIUNhlgg0A= -github.com/envoyproxy/protoc-gen-validate v1.0.4/go.mod h1:qys6tmnRsYrQqIhm2bvKZH4Blx/1gTIZ2UKVY1M+Yew= +github.com/envoyproxy/protoc-gen-validate v1.0.2 h1:QkIBuU5k+x7/QXPvPPnWXWlCdaBFApVqftFV6k087DA= +github.com/envoyproxy/protoc-gen-validate v1.0.2/go.mod h1:GpiZQP3dDbg4JouG/NNS7QWXpgx6x8QiMKdmN72jogE= github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5KwzbycvMj4= github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4= github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= @@ -156,8 +154,8 @@ github.com/golang/protobuf v1.4.3/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= github.com/golang/protobuf v1.5.1/go.mod h1:DopwsBzvsk0Fs44TXzsVbJyPhcCPeIwnvohx4u74HPM= github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= -github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= -github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= +github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg= +github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= @@ -171,8 +169,8 @@ github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/ github.com/google/go-cmp v0.5.3/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.4/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= -github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= +github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs= github.com/google/martian/v3 v3.0.0/go.mod h1:y5Zk1BBys9G+gd6Jrk0W3cC1+ELVxBWuIGO+w/tUAp0= @@ -190,8 +188,8 @@ github.com/google/pprof v0.0.0-20210122040257-d980be63207e/go.mod h1:kpwsk12EmLe github.com/google/pprof v0.0.0-20210226084205-cbba55b83ad5/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE= github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= -github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= -github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= +github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg= github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk= github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= @@ -294,20 +292,18 @@ github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINE github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/sftp v1.10.1/go.mod h1:lYOWFsE0bwd1+KfKJaKeuokY15vzFx25BLbzYYoAxZI= -github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10 h1:GFCKgmp0tecUJ0sJuv4pzYCqS9+RGSn52M3FUwPs+uo= -github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10/go.mod h1:t/avpk3KcrXxUnYOhZhMXJlSEyie6gQbtLq5NM3loB8= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/polarismesh/go-restful-openapi/v2 v2.0.0-20220928152401-083908d10219 h1:XnFyNUWnciM6zgXaz6tm+Egs35rhoD0KGMmKh4gCdi0= github.com/polarismesh/go-restful-openapi/v2 v2.0.0-20220928152401-083908d10219/go.mod h1:4WhwBysTom9Eoy0hQ4W69I0FmO+T0EpjEW9/5sgHoUk= -github.com/polarismesh/specification v1.5.3-alpha.2 h1:QSgpGmx5VfPcDPAq7qnTOkMVFNpmBMgLSDhtyMlS6/g= -github.com/polarismesh/specification v1.5.3-alpha.2/go.mod h1:rDvMMtl5qebPmqiBLNa5Ps0XtwkP31ZLirbH4kXA0YU= +github.com/polarismesh/specification v1.5.2-0.20240722103923-1d9990d6f555 h1:eLZzl6yaeuQbTRYeZDu9cqrkJt7qlrt+fE7hGB+R3bc= +github.com/polarismesh/specification v1.5.2-0.20240722103923-1d9990d6f555/go.mod h1:rDvMMtl5qebPmqiBLNa5Ps0XtwkP31ZLirbH4kXA0YU= github.com/posener/complete v1.1.1/go.mod h1:em0nMJCgc9GFtwrmVmEMR/ZL6WyhyjMBndrE9hABlRI= github.com/prometheus/client_golang v1.18.0 h1:HzFfmkOzH5Q8L8G+kSJKUx5dtG87sewO+FoDDqP5Tbk= github.com/prometheus/client_golang v1.18.0/go.mod h1:T+GXkCk5wSJyOqMIzVgvvjFDlkOQntgjkJWKrN5txjA= github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= -github.com/prometheus/client_model v0.6.0 h1:k1v3CzpSRUTrKMppY35TLwPvxHqBu0bYgxZzqGIgaos= -github.com/prometheus/client_model v0.6.0/go.mod h1:NTQHnmxFpouOD0DpvP4XujX3CdOAGQPoaGhyTchlyt8= +github.com/prometheus/client_model v0.5.0 h1:VQw1hfvPvk3Uv6Qf29VrPF32JB6rtbgI6cYPYQjL0Qw= +github.com/prometheus/client_model v0.5.0/go.mod h1:dTiFglRmd66nLR9Pv9f0mZi7B7fk5Pm3gvsjB5tr+kI= github.com/prometheus/common v0.45.0 h1:2BGz0eBc2hdMDLnO/8n0jeB3oPrt2D08CekT0lneoxM= github.com/prometheus/common v0.45.0/go.mod h1:YJmSTw9BoKxJplESWWxlbyttQR4uaEcGyv9MZjVOJsY= github.com/prometheus/procfs v0.12.0 h1:jluTpSng7V9hY0O2R9DzzJHYb2xULk9VTR1V1R/k6Bo= @@ -344,8 +340,8 @@ github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81P github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= -github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/subosito/gotenv v1.2.0/go.mod h1:N0PQaV/YGNqwC0u51sEeR/aUtSLEXKX9iv69rRypqCw= github.com/yuin/goldmark v1.1.25/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= @@ -369,8 +365,8 @@ go.uber.org/atomic v1.10.0 h1:9qC72Qh0+3MqyJbAn8YU5xVq1frD8bn3JtD2oXtafVQ= go.uber.org/atomic v1.10.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= go.uber.org/automaxprocs v1.4.0 h1:CpDZl6aOlLhReez+8S3eEotD7Jx0Os++lemPlMULQP0= go.uber.org/automaxprocs v1.4.0/go.mod h1:/mTEdr7LvHhs0v7mjdxDreTz1OG5zdZGqgOnhWiR/+Q= -go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= -go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= +go.uber.org/goleak v1.1.12 h1:gZAh5/EyT/HQwlpkCy6wTpqfH9H8Lz8zbm3dZh+OyzA= +go.uber.org/goleak v1.1.12/go.mod h1:cwTWslyiVhfpKIDGSZEM2HlOvcqm+tG4zioyIeLoqMQ= go.uber.org/multierr v1.6.0/go.mod h1:cdWPpRnG4AhwMwsgIHip0KRBQjJy5kYEpYjJxpXp9iU= go.uber.org/multierr v1.8.0 h1:dg6GjLku4EH+249NNmoIciG9N/jURbDG+pFlTkhzIC8= go.uber.org/multierr v1.8.0/go.mod h1:7EAYxJLBy9rStEaz58O2t4Uvip6FSURkq8/ppBp95ak= @@ -384,8 +380,8 @@ golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5/go.mod h1:yigFU9vqHzYiE8U golang.org/x/crypto v0.0.0-20190820162420-60c769a6c586/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= -golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI= -golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8= +golang.org/x/crypto v0.21.0 h1:X31++rzVUdKhX5sWmSOFZxx8UW/ldWx55cbf08iNAMA= +golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8= @@ -458,8 +454,8 @@ golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v golang.org/x/net v0.0.0-20210316092652-d523dce5a7f4/go.mod h1:RBQZq4jEuRlivfhVLdyRGr576XBO4/greRjx4P4O3yc= golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= golang.org/x/net v0.0.0-20210421230115-4e50805a0758/go.mod h1:72T/g9IO56b78aLF+1Kcs5dz7/ng1VjMUvfKvpfy+jM= -golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac= -golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM= +golang.org/x/net v0.23.0 h1:7EYJ93RZ9vYSZAIb2x3lnuvqO5zneoD6IvWjuhfxjTs= +golang.org/x/net v0.23.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= @@ -483,8 +479,8 @@ golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= -golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sync v0.6.0 h1:5BMeUDZ7vkXGfEr1x9B4bRcTH4lpkTkpdh0T/J+qjbQ= +golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20181026203630-95b1ffbd15a5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -528,8 +524,8 @@ golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20210403161142-5e06dd20ab57/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210420072515-93ed5bcd2bfe/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y= -golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4= +golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= @@ -540,8 +536,8 @@ golang.org/x/text v0.3.4/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= -golang.org/x/text v0.15.0 h1:h1V/4gjBv8v9cjcR6+AR5+/cIYK5N/WAgiv4xlsEtAk= -golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= +golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= +golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= @@ -599,6 +595,7 @@ golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4f golang.org/x/tools v0.1.0/go.mod h1:xkSsbof2nBLbhDlRMhhhyNLN/zl3eTqcnHD5viDpcZ0= golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= golang.org/x/tools v0.1.2/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= +golang.org/x/tools v0.1.5/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= @@ -673,10 +670,12 @@ google.golang.org/genproto v0.0.0-20210310155132-4ce2db91004e/go.mod h1:FWY/as6D google.golang.org/genproto v0.0.0-20210319143718-93e7006c17a6/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= google.golang.org/genproto v0.0.0-20210402141018-6c239bbf2bb1/go.mod h1:9lPAdzaEmUacj36I+k7YKbEc5CXzPIeORRgDAUOu28A= google.golang.org/genproto v0.0.0-20210602131652-f16073e35f0c/go.mod h1:UODoCrxHCcBojKKwX1terBiRUaqAsFqJiF615XL43r0= -google.golang.org/genproto/googleapis/api v0.0.0-20240528184218-531527333157 h1:7whR9kGa5LUwFtpLm2ArCEejtnxlGeLbAyjFY8sGNFw= -google.golang.org/genproto/googleapis/api v0.0.0-20240528184218-531527333157/go.mod h1:99sLkeliLXfdj2J75X3Ho+rrVCaJze0uwN7zDDkjPVU= -google.golang.org/genproto/googleapis/rpc v0.0.0-20240528184218-531527333157 h1:Zy9XzmMEflZ/MAaA7vNcoebnRAld7FsPW1EeBB7V0m8= -google.golang.org/genproto/googleapis/rpc v0.0.0-20240528184218-531527333157/go.mod h1:EfXuqaE1J41VCDicxHzUDm+8rk+7ZdXzHV0IhO/I6s0= +google.golang.org/genproto v0.0.0-20230711160842-782d3b101e98 h1:Z0hjGZePRE0ZBWotvtrwxFNrNE9CUAGtplaDK5NNI/g= +google.golang.org/genproto v0.0.0-20230711160842-782d3b101e98/go.mod h1:S7mY02OqCJTD0E1OiQy1F72PWFB4bZJ87cAtLPYgDR0= +google.golang.org/genproto/googleapis/api v0.0.0-20230711160842-782d3b101e98 h1:FmF5cCW94Ij59cfpoLiwTgodWmm60eEV0CjlsVg2fuw= +google.golang.org/genproto/googleapis/api v0.0.0-20230711160842-782d3b101e98/go.mod h1:rsr7RhLuwsDKL7RmgDDCUc6yaGr1iqceVb5Wv6f6YvQ= +google.golang.org/genproto/googleapis/rpc v0.0.0-20230711160842-782d3b101e98 h1:bVf09lpb+OJbByTj913DRJioFFAjf/ZGxEz7MajTp2U= +google.golang.org/genproto/googleapis/rpc v0.0.0-20230711160842-782d3b101e98/go.mod h1:TUfxEVdsvPg18p6AslUXFoLdpED4oBnGwyqk3dV1XzM= google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= google.golang.org/grpc v1.20.1/go.mod h1:10oTOabMzJvdu6/UiuZezV6QK5dSlG84ov/aaiqXj38= google.golang.org/grpc v1.21.1/go.mod h1:oYelfM1adQP15Ek0mdvEgi9Df8B9CZIaU1084ijfRaM= @@ -697,8 +696,8 @@ google.golang.org/grpc v1.35.0/go.mod h1:qjiiYl8FncCW8feJPdyg3v6XW24KsRHe+dy9BAG google.golang.org/grpc v1.36.0/go.mod h1:qjiiYl8FncCW8feJPdyg3v6XW24KsRHe+dy9BAGRRjU= google.golang.org/grpc v1.36.1/go.mod h1:qjiiYl8FncCW8feJPdyg3v6XW24KsRHe+dy9BAGRRjU= google.golang.org/grpc v1.38.0/go.mod h1:NREThFqKR1f3iQ6oBuvc5LadQuXVGo9rkm5ZGrQdJfM= -google.golang.org/grpc v1.65.0 h1:bs/cUb4lp1G5iImFFd3u5ixQzweKizoZJAwBNLR42lc= -google.golang.org/grpc v1.65.0/go.mod h1:WgYC2ypjlB0EiQi6wdKixMqukr6lBc0Vo+oOgjrM5ZQ= +google.golang.org/grpc v1.58.3 h1:BjnpXut1btbtgN/6sp+brB2Kbm2LjNXnidYujAVbSoQ= +google.golang.org/grpc v1.58.3/go.mod h1:tgX3ZQDlNJGU96V6yHh1T/JeoBQ2TXdr43YbYSsCJk0= google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= @@ -711,8 +710,8 @@ google.golang.org/protobuf v1.24.0/go.mod h1:r/3tXBNzIEhYS9I1OUVjXDlt8tc493IdKGj google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= -google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg= -google.golang.org/protobuf v1.34.1/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= +google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI= +google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/import-format.sh b/import-format.sh index 9807c399e..a6335c62c 100644 --- a/import-format.sh +++ b/import-format.sh @@ -17,7 +17,7 @@ # 格式化 go.mod go mod tidy -compat=1.17 -docker run -t --rm -v $(pwd):/app -w /app golangci/golangci-lint:v1.61.0 golangci-lint run -v --timeout 30m +docker run -t --rm -v $(pwd):/app -w /app golangci/golangci-lint:v1.55.2 golangci-lint run -v # 处理 go imports 的格式化 rm -rf style_tool diff --git a/namespace/api.go b/namespace/api.go index b6732c24a..84edaacb2 100644 --- a/namespace/api.go +++ b/namespace/api.go @@ -30,6 +30,8 @@ type NamespaceOperateServer interface { CreateNamespace(ctx context.Context, req *apimodel.Namespace) *apiservice.Response // CreateNamespaces Batch creation namespace CreateNamespaces(ctx context.Context, req []*apimodel.Namespace) *apiservice.BatchWriteResponse + // DeleteNamespace Delete a single namespace + DeleteNamespace(ctx context.Context, req *apimodel.Namespace) *apiservice.Response // DeleteNamespaces Batch delete namespace DeleteNamespaces(ctx context.Context, req []*apimodel.Namespace) *apiservice.BatchWriteResponse // UpdateNamespaces Batch update naming space diff --git a/namespace/default.go b/namespace/default.go index a172c730b..bebab4e8b 100644 --- a/namespace/default.go +++ b/namespace/default.go @@ -20,13 +20,14 @@ package namespace import ( "context" "errors" - "fmt" "sync" "golang.org/x/sync/singleflight" + "github.com/polarismesh/polaris/auth" "github.com/polarismesh/polaris/cache" cachetypes "github.com/polarismesh/polaris/cache/api" + "github.com/polarismesh/polaris/plugin" "github.com/polarismesh/polaris/store" ) @@ -35,36 +36,17 @@ var ( namespaceServer = &Server{} once sync.Once finishInit bool - // serverProxyFactories Service Server API 代理工厂 - serverProxyFactories = map[string]ServerProxyFactory{} ) -type ServerProxyFactory func(context.Context, NamespaceOperateServer, cachetypes.CacheManager) (NamespaceOperateServer, error) - -func RegisterServerProxy(name string, factor ServerProxyFactory) error { - if _, ok := serverProxyFactories[name]; ok { - return fmt.Errorf("duplicate ServerProxyFactory, name(%s)", name) - } - serverProxyFactories[name] = factor - return nil -} - type Config struct { - AutoCreate bool `yaml:"autoCreate"` - Interceptors []string `yaml:"-"` + AutoCreate bool `yaml:"autoCreate"` } // Initialize 初始化 -func Initialize(ctx context.Context, nsOpt *Config, storage store.Store, cacheMgr *cache.CacheManager) error { +func Initialize(ctx context.Context, nsOpt *Config, storage store.Store, cacheMgn *cache.CacheManager) error { var err error once.Do(func() { - actualSvr, proxySvr, err := InitServer(ctx, nsOpt, storage, cacheMgr) - if err != nil { - return - } - namespaceServer = actualSvr - server = proxySvr - return + err = initialize(ctx, nsOpt, storage, cacheMgn) }) if err != nil { @@ -75,36 +57,35 @@ func Initialize(ctx context.Context, nsOpt *Config, storage store.Store, cacheMg return nil } -func InitServer(ctx context.Context, nsOpt *Config, storage store.Store, - cacheMgr *cache.CacheManager) (*Server, NamespaceOperateServer, error) { - if err := cacheMgr.OpenResourceCache(cachetypes.ConfigEntry{ +func initialize(_ context.Context, nsOpt *Config, storage store.Store, cacheMgn *cache.CacheManager) error { + if err := cacheMgn.OpenResourceCache(cachetypes.ConfigEntry{ Name: cachetypes.NamespaceName, }); err != nil { - return nil, nil, err + return err + } + namespaceServer.caches = cacheMgn + namespaceServer.storage = storage + namespaceServer.cfg = *nsOpt + namespaceServer.createNamespaceSingle = &singleflight.Group{} + + // 获取History插件,注意:插件的配置在bootstrap已经设置好 + namespaceServer.history = plugin.GetHistory() + if namespaceServer.history == nil { + log.Warn("Not Found History Log Plugin") + } + + userMgn, err := auth.GetUserServer() + if err != nil { + return err } - actualSvr := new(Server) - actualSvr.caches = cacheMgr - actualSvr.storage = storage - actualSvr.cfg = *nsOpt - actualSvr.createNamespaceSingle = &singleflight.Group{} - - var proxySvr NamespaceOperateServer - proxySvr = actualSvr - order := GetChainOrder() - for i := range order { - factory, exist := serverProxyFactories[order[i]] - if !exist { - return nil, nil, fmt.Errorf("name(%s) not exist in serverProxyFactories", order[i]) - } - - afterSvr, err := factory(ctx, proxySvr, cacheMgr) - if err != nil { - return nil, nil, err - } - proxySvr = afterSvr + strategyMgn, err := auth.GetStrategyServer() + if err != nil { + return err } - return actualSvr, proxySvr, nil + + server = newServerAuthAbility(namespaceServer, userMgn, strategyMgn) + return nil } // GetServer 获取已经初始化好的Server @@ -124,9 +105,3 @@ func GetOriginServer() (*Server, error) { return namespaceServer, nil } - -func GetChainOrder() []string { - return []string{ - "auth", - } -} diff --git a/namespace/interceptor/auth/log.go b/namespace/interceptor/auth/log.go deleted file mode 100644 index 78a77e607..000000000 --- a/namespace/interceptor/auth/log.go +++ /dev/null @@ -1,26 +0,0 @@ -/** - * Tencent is pleased to support the open source community by making Polaris available. - * - * Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. - * - * Licensed under the BSD 3-Clause License (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://opensource.org/licenses/BSD-3-Clause - * - * Unless required by applicable law or agreed to in writing, software distributed - * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR - * CONDITIONS OF ANY KIND, either express or implied. See the License for the - * specific language governing permissions and limitations under the License. - */ - -package auth - -import ( - commonlog "github.com/polarismesh/polaris/common/log" -) - -var ( - authLog = commonlog.GetScopeOrDefaultByName(commonlog.AuthLoggerName) -) diff --git a/namespace/interceptor/register.go b/namespace/interceptor/register.go deleted file mode 100644 index f3925a25b..000000000 --- a/namespace/interceptor/register.go +++ /dev/null @@ -1,68 +0,0 @@ -/** - * Tencent is pleased to support the open source community by making Polaris available. - * - * Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. - * - * Licensed under the BSD 3-Clause License (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://opensource.org/licenses/BSD-3-Clause - * - * Unless required by applicable law or agreed to in writing, software distributed - * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR - * CONDITIONS OF ANY KIND, either express or implied. See the License for the - * specific language governing permissions and limitations under the License. - */ - -package interceptor - -import ( - "context" - - "github.com/polarismesh/polaris/auth" - cachetypes "github.com/polarismesh/polaris/cache/api" - "github.com/polarismesh/polaris/namespace" - ns_auth "github.com/polarismesh/polaris/namespace/interceptor/auth" -) - -type ( - ContextKeyUserSvr struct{} - ContextKeyPolicySvr struct{} -) - -func init() { - err := namespace.RegisterServerProxy("auth", func(ctx context.Context, - pre namespace.NamespaceOperateServer, cacheSvr cachetypes.CacheManager) (namespace.NamespaceOperateServer, error) { - - var userSvr auth.UserServer - var policySvr auth.StrategyServer - - userSvrVal := ctx.Value(ContextKeyUserSvr{}) - if userSvrVal == nil { - svr, err := auth.GetUserServer() - if err != nil { - return nil, err - } - userSvr = svr - } else { - userSvr = userSvrVal.(auth.UserServer) - } - - policySvrVal := ctx.Value(ContextKeyPolicySvr{}) - if policySvrVal == nil { - svr, err := auth.GetStrategyServer() - if err != nil { - return nil, err - } - policySvr = svr - } else { - policySvr = policySvrVal.(auth.StrategyServer) - } - - return ns_auth.NewServer(pre, userSvr, policySvr, cacheSvr), nil - }) - if err != nil { - panic(err) - } -} diff --git a/namespace/namespace.go b/namespace/namespace.go index 83643a04e..07e9f10cc 100644 --- a/namespace/namespace.go +++ b/namespace/namespace.go @@ -19,6 +19,7 @@ package namespace import ( "context" + "fmt" "time" "github.com/golang/protobuf/jsonpb" @@ -26,7 +27,6 @@ import ( apiservice "github.com/polarismesh/specification/source/go/api/v1/service_manage" "go.uber.org/zap" - cachetypes "github.com/polarismesh/polaris/cache/api" api "github.com/polarismesh/polaris/common/api/v1" "github.com/polarismesh/polaris/common/model" commonstore "github.com/polarismesh/polaris/common/store" @@ -94,6 +94,8 @@ func (s *Server) CreateNamespaceIfAbsent(ctx context.Context, req *apimodel.Name // CreateNamespace 创建单个命名空间 func (s *Server) CreateNamespace(ctx context.Context, req *apimodel.Namespace) *apiservice.Response { + requestID, _ := ctx.Value(utils.StringContext("request-id")).(string) + // 参数检查 if checkError := checkCreateNamespace(req); checkError != nil { return checkError @@ -104,18 +106,19 @@ func (s *Server) CreateNamespace(ctx context.Context, req *apimodel.Namespace) * // 检查是否存在 namespace, err := s.storage.GetNamespace(namespaceName) if err != nil { - log.Error(err.Error(), utils.RequestID(ctx)) + log.Error(err.Error(), utils.ZapRequestID(requestID)) return api.NewNamespaceResponse(commonstore.StoreCode2APICode(err), req) } if namespace != nil { return api.NewNamespaceResponse(apimodel.Code_ExistedResource, req) } + // data := s.createNamespaceModel(req) // 存储层操作 if err := s.storage.AddNamespace(data); err != nil { - log.Error(err.Error(), utils.RequestID(ctx)) + log.Error(err.Error(), utils.ZapRequestID(requestID)) return api.NewNamespaceResponse(commonstore.StoreCode2APICode(err), req) } @@ -140,7 +143,6 @@ func (s *Server) createNamespaceModel(req *apimodel.Namespace) *model.Namespace Owner: req.GetOwners().GetValue(), Token: utils.NewUUID(), ServiceExportTo: model.ExportToMap(req.GetServiceExportTo()), - Metadata: req.GetMetadata(), } return namespace } @@ -162,6 +164,8 @@ func (s *Server) DeleteNamespaces(ctx context.Context, req []*apimodel.Namespace // DeleteNamespace 删除单个命名空间 func (s *Server) DeleteNamespace(ctx context.Context, req *apimodel.Namespace) *apiservice.Response { + requestID, _ := ctx.Value(utils.StringContext("request-id")).(string) + // 参数检查 if checkError := checkReviseNamespace(ctx, req); checkError != nil { return checkError @@ -169,7 +173,7 @@ func (s *Server) DeleteNamespace(ctx context.Context, req *apimodel.Namespace) * tx, err := s.storage.CreateTransaction() if err != nil { - log.Error(err.Error(), utils.RequestID(ctx)) + log.Error(err.Error(), utils.ZapRequestID(requestID)) return api.NewNamespaceResponse(commonstore.StoreCode2APICode(err), req) } defer func() { _ = tx.Commit() }() @@ -177,38 +181,47 @@ func (s *Server) DeleteNamespace(ctx context.Context, req *apimodel.Namespace) * // 检查是否存在 namespace, err := tx.LockNamespace(req.GetName().GetValue()) if err != nil { - log.Error(err.Error(), utils.RequestID(ctx)) + log.Error(err.Error(), utils.ZapRequestID(requestID)) return api.NewNamespaceResponse(commonstore.StoreCode2APICode(err), req) } if namespace == nil { return api.NewNamespaceResponse(apimodel.Code_ExecuteSuccess, req) } + // // 鉴权 + // if ok := s.authority.VerifyNamespace(namespace.Token, parseNamespaceToken(ctx, req)); !ok { + // return api.NewNamespaceResponse(api.Unauthorized, req) + // } + // 判断属于该命名空间的服务是否都已经被删除 total, err := s.getServicesCountWithNamespace(namespace.Name) if err != nil { - log.Error("get services count with namespace err", utils.RequestID(ctx), zap.Error(err)) + log.Error("get services count with namespace err", + utils.ZapRequestID(requestID), + zap.String("err", err.Error())) return api.NewNamespaceResponse(commonstore.StoreCode2APICode(err), req) } if total != 0 { - log.Error("the removed namespace has remain services", utils.RequestID(ctx)) + log.Error("the removed namespace has remain services", utils.ZapRequestID(requestID)) return api.NewNamespaceResponse(apimodel.Code_NamespaceExistedServices, req) } // 判断属于该命名空间的服务是否都已经被删除 total, err = s.getConfigGroupCountWithNamespace(namespace.Name) if err != nil { - log.Error("get config group count with namespace err", utils.RequestID(ctx), zap.Error(err)) + log.Error("get config group count with namespace err", + utils.ZapRequestID(requestID), + zap.String("err", err.Error())) return api.NewNamespaceResponse(commonstore.StoreCode2APICode(err), req) } if total != 0 { - log.Error("the removed namespace has remain config-group", utils.RequestID(ctx)) + log.Error("the removed namespace has remain config-group", utils.ZapRequestID(requestID)) return api.NewNamespaceResponse(apimodel.Code_NamespaceExistedConfigGroups, req) } // 存储层操作 if err := tx.DeleteNamespace(namespace.Name); err != nil { - log.Error(err.Error(), utils.RequestID(ctx)) + log.Error(err.Error(), utils.ZapRequestID(requestID)) return api.NewNamespaceResponse(commonstore.StoreCode2APICode(err), req) } @@ -249,16 +262,19 @@ func (s *Server) UpdateNamespace(ctx context.Context, req *apimodel.Namespace) * if resp != nil { return resp } + + rid := utils.ParseRequestID(ctx) // 修改 s.updateNamespaceAttribute(req, namespace) // 存储层操作 if err := s.storage.UpdateNamespace(namespace); err != nil { - log.Error(err.Error(), utils.RequestID(ctx)) + log.Error(err.Error(), utils.ZapRequestID(rid)) return api.NewNamespaceResponse(commonstore.StoreCode2APICode(err), req) } - log.Info("update namespace", zap.String("name", namespace.Name), utils.RequestID(ctx)) + msg := fmt.Sprintf("update namespace: name=%s", namespace.Name) + log.Info(msg, utils.ZapRequestID(rid)) s.RecordHistory(namespaceRecordEntry(ctx, req, model.OUpdate)) if err := s.afterNamespaceResource(ctx, req, namespace, false); err != nil { @@ -285,7 +301,6 @@ func (s *Server) updateNamespaceAttribute(req *apimodel.Namespace, namespace *mo exportTo[req.GetServiceExportTo()[i].GetValue()] = struct{}{} } - namespace.Metadata = req.GetMetadata() namespace.ServiceExportTo = exportTo } @@ -299,16 +314,18 @@ func (s *Server) UpdateNamespaceToken(ctx context.Context, req *apimodel.Namespa return resp } + rid := utils.ParseRequestID(ctx) // 生成token token := utils.NewUUID() // 存储层操作 if err := s.storage.UpdateNamespaceToken(namespace.Name, token); err != nil { - log.Error(err.Error(), utils.RequestID(ctx)) + log.Error(err.Error(), utils.ZapRequestID(rid)) return api.NewNamespaceResponse(commonstore.StoreCode2APICode(err), req) } - log.Info("update namespace token", zap.String("name", namespace.Name), utils.RequestID(ctx)) + msg := fmt.Sprintf("update namespace token: name=%s", namespace.Name) + log.Info(msg, utils.ZapRequestID(rid)) s.RecordHistory(namespaceRecordEntry(ctx, req, model.OUpdateToken)) out := &apimodel.Namespace{ @@ -326,11 +343,7 @@ func (s *Server) GetNamespaces(ctx context.Context, query map[string][]string) * return checkError } - amount, namespaces, err := s.caches.Namespace().Query(ctx, &cachetypes.NamespaceArgs{ - Filter: filter, - Offset: offset, - Limit: limit, - }) + namespaces, amount, err := s.storage.GetNamespaces(filter, offset, limit) if err != nil { return api.NewBatchQueryResponse(commonstore.StoreCode2APICode(err)) } @@ -352,9 +365,6 @@ func (s *Server) GetNamespaces(ctx context.Context, query map[string][]string) * TotalInstanceCount: utils.NewUInt32Value(nsCntInfo.InstanceCnt.TotalInstanceCount), TotalHealthInstanceCount: utils.NewUInt32Value(nsCntInfo.InstanceCnt.HealthyInstanceCount), ServiceExportTo: namespace.ListServiceExportTo(), - Editable: utils.NewBoolValue(true), - Deleteable: utils.NewBoolValue(true), - Metadata: namespace.Metadata, }) totalServiceCount += nsCntInfo.ServiceCount totalInstanceCount += nsCntInfo.InstanceCnt.TotalInstanceCount @@ -425,18 +435,25 @@ func (s *Server) loadNamespace(name string) (string, error) { // 检查namespace的权限,并且返回namespace func (s *Server) checkNamespaceAuthority( ctx context.Context, req *apimodel.Namespace) (*model.Namespace, *apiservice.Response) { + rid := utils.ParseRequestID(ctx) namespaceName := req.GetName().GetValue() // namespaceToken := parseNamespaceToken(ctx, req) // 检查是否存在 namespace, err := s.storage.GetNamespace(namespaceName) if err != nil { - log.Error(err.Error(), utils.RequestID(ctx)) + log.Error(err.Error(), utils.ZapRequestID(rid)) return nil, api.NewNamespaceResponse(commonstore.StoreCode2APICode(err), req) } if namespace == nil { return nil, api.NewNamespaceResponse(apimodel.Code_NotFoundResource, req) } + + // 鉴权 + // if ok := s.authority.VerifyNamespace(namespace.Token, namespaceToken); !ok { + // return nil, api.NewNamespaceResponse(api.Unauthorized, req) + // } + return namespace, nil } diff --git a/namespace/interceptor/auth/server.go b/namespace/namespace_authability.go similarity index 50% rename from namespace/interceptor/auth/server.go rename to namespace/namespace_authability.go index 2421a14f0..98209e163 100644 --- a/namespace/interceptor/auth/server.go +++ b/namespace/namespace_authability.go @@ -15,7 +15,7 @@ * specific language governing permissions and limitations under the License. */ -package auth +package namespace import ( "context" @@ -23,56 +23,29 @@ import ( apimodel "github.com/polarismesh/specification/source/go/api/v1/model" apisecurity "github.com/polarismesh/specification/source/go/api/v1/security" apiservice "github.com/polarismesh/specification/source/go/api/v1/service_manage" - "go.uber.org/zap" - "github.com/polarismesh/polaris/auth" cachetypes "github.com/polarismesh/polaris/cache/api" api "github.com/polarismesh/polaris/common/api/v1" "github.com/polarismesh/polaris/common/model" authcommon "github.com/polarismesh/polaris/common/model/auth" "github.com/polarismesh/polaris/common/utils" - "github.com/polarismesh/polaris/namespace" ) -var _ namespace.NamespaceOperateServer = (*Server)(nil) - -// Server 带有鉴权能力的 NamespaceOperateServer -// 该层会对请求参数做一些调整,根据具体的请求发起人,设置为数据对应的 owner,不可为为别人进行创建资源 -type Server struct { - nextSvr namespace.NamespaceOperateServer - userSvr auth.UserServer - policySvr auth.StrategyServer - cacheSvr cachetypes.CacheManager -} - -func NewServer(nextSvr namespace.NamespaceOperateServer, userSvr auth.UserServer, - policySvr auth.StrategyServer, cacheSvr cachetypes.CacheManager) namespace.NamespaceOperateServer { - proxy := &Server{ - nextSvr: nextSvr, - userSvr: userSvr, - policySvr: policySvr, - cacheSvr: cacheSvr, - } - - if actualSvr, ok := nextSvr.(*namespace.Server); ok { - actualSvr.SetResourceHooks(proxy) - } - return proxy -} +var _ NamespaceOperateServer = (*serverAuthAbility)(nil) // CreateNamespaceIfAbsent Create a single name space -func (svr *Server) CreateNamespaceIfAbsent(ctx context.Context, +func (svr *serverAuthAbility) CreateNamespaceIfAbsent(ctx context.Context, req *apimodel.Namespace) (string, *apiservice.Response) { - return svr.nextSvr.CreateNamespaceIfAbsent(ctx, req) + return svr.targetServer.CreateNamespaceIfAbsent(ctx, req) } // CreateNamespace 创建命名空间,只需要要后置鉴权,将数据添加到资源策略中 -func (svr *Server) CreateNamespace(ctx context.Context, req *apimodel.Namespace) *apiservice.Response { +func (svr *serverAuthAbility) CreateNamespace(ctx context.Context, req *apimodel.Namespace) *apiservice.Response { authCtx := svr.collectNamespaceAuthContext( ctx, []*apimodel.Namespace{req}, authcommon.Create, authcommon.CreateNamespace) // 验证 token 信息 if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewResponse(authcommon.ConvertToErrCode(err)) + return api.NewResponseWithMsg(convertToErrCode(err), err.Error()) } ctx = authCtx.GetRequestContext() @@ -83,17 +56,17 @@ func (svr *Server) CreateNamespace(ctx context.Context, req *apimodel.Namespace) req.Owners = utils.NewStringValue(ownerId) } - return svr.nextSvr.CreateNamespace(ctx, req) + return svr.targetServer.CreateNamespace(ctx, req) } // CreateNamespaces 创建命名空间,只需要要后置鉴权,将数据添加到资源策略中 -func (svr *Server) CreateNamespaces( +func (svr *serverAuthAbility) CreateNamespaces( ctx context.Context, reqs []*apimodel.Namespace) *apiservice.BatchWriteResponse { authCtx := svr.collectNamespaceAuthContext(ctx, reqs, authcommon.Create, authcommon.CreateNamespaces) // 验证 token 信息 if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewBatchWriteResponse(authcommon.ConvertToErrCode(err)) + return api.NewBatchWriteResponseWithMsg(convertToErrCode(err), err.Error()) } ctx = authCtx.GetRequestContext() @@ -107,155 +80,98 @@ func (svr *Server) CreateNamespaces( req.Owners = utils.NewStringValue(ownerId) } } - return svr.nextSvr.CreateNamespaces(ctx, reqs) + + return svr.targetServer.CreateNamespaces(ctx, reqs) +} + +// DeleteNamespace 删除命名空间,需要先走权限检查 +func (svr *serverAuthAbility) DeleteNamespace(ctx context.Context, req *apimodel.Namespace) *apiservice.Response { + authCtx := svr.collectNamespaceAuthContext( + ctx, []*apimodel.Namespace{req}, authcommon.Delete, authcommon.DeleteNamespace) + if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + return api.NewResponseWithMsg(convertToErrCode(err), err.Error()) + } + + ctx = authCtx.GetRequestContext() + ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) + + return svr.targetServer.DeleteNamespace(ctx, req) } // DeleteNamespaces 删除命名空间,需要先走权限检查 -func (svr *Server) DeleteNamespaces( +func (svr *serverAuthAbility) DeleteNamespaces( ctx context.Context, reqs []*apimodel.Namespace) *apiservice.BatchWriteResponse { authCtx := svr.collectNamespaceAuthContext(ctx, reqs, authcommon.Delete, authcommon.DeleteNamespaces) if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewBatchWriteResponse(authcommon.ConvertToErrCode(err)) + return api.NewBatchWriteResponseWithMsg(convertToErrCode(err), err.Error()) } ctx = authCtx.GetRequestContext() ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) - return svr.nextSvr.DeleteNamespaces(ctx, reqs) + return svr.targetServer.DeleteNamespaces(ctx, reqs) } // UpdateNamespaces 更新命名空间,需要先走权限检查 -func (svr *Server) UpdateNamespaces( +func (svr *serverAuthAbility) UpdateNamespaces( ctx context.Context, req []*apimodel.Namespace) *apiservice.BatchWriteResponse { authCtx := svr.collectNamespaceAuthContext(ctx, req, authcommon.Modify, authcommon.UpdateNamespaces) if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewBatchWriteResponse(authcommon.ConvertToErrCode(err)) + return api.NewBatchWriteResponseWithMsg(convertToErrCode(err), err.Error()) } ctx = authCtx.GetRequestContext() ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) - return svr.nextSvr.UpdateNamespaces(ctx, req) + return svr.targetServer.UpdateNamespaces(ctx, req) } // UpdateNamespaceToken 更新命名空间的token信息,需要先走权限检查 -func (svr *Server) UpdateNamespaceToken(ctx context.Context, req *apimodel.Namespace) *apiservice.Response { +func (svr *serverAuthAbility) UpdateNamespaceToken(ctx context.Context, req *apimodel.Namespace) *apiservice.Response { authCtx := svr.collectNamespaceAuthContext( ctx, []*apimodel.Namespace{req}, authcommon.Modify, authcommon.UpdateNamespaceToken) if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewResponse(authcommon.ConvertToErrCode(err)) + return api.NewResponseWithMsg(convertToErrCode(err), err.Error()) } ctx = authCtx.GetRequestContext() ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) - return svr.nextSvr.UpdateNamespaceToken(ctx, req) + return svr.targetServer.UpdateNamespaceToken(ctx, req) } // GetNamespaces 获取命名空间列表信息,暂时不走权限检查 -func (svr *Server) GetNamespaces( +func (svr *serverAuthAbility) GetNamespaces( ctx context.Context, query map[string][]string) *apiservice.BatchQueryResponse { authCtx := svr.collectNamespaceAuthContext(ctx, nil, authcommon.Read, authcommon.DescribeNamespaces) - if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewBatchQueryResponse(authcommon.ConvertToErrCode(err)) + _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx) + if err != nil { + return api.NewBatchQueryResponseWithMsg(convertToErrCode(err), err.Error()) } ctx = authCtx.GetRequestContext() ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) - ctx = cachetypes.AppendNamespacePredicate(ctx, func(ctx context.Context, n *model.Namespace) bool { + cachetypes.AppendNamespacePredicate(ctx, func(ctx context.Context, n *model.Namespace) bool { return svr.policySvr.GetAuthChecker().ResourcePredicate(authCtx, &authcommon.ResourceEntry{ - Type: apisecurity.ResourceType_Namespaces, - ID: n.Name, - Metadata: n.Metadata, + Type: apisecurity.ResourceType_Users, + ID: n.Name, }) }) - authCtx.SetRequestContext(ctx) - resp := svr.nextSvr.GetNamespaces(ctx, query) - for i := range resp.Namespaces { - item := resp.Namespaces[i] - authCtx.SetAccessResources(map[apisecurity.ResourceType][]authcommon.ResourceEntry{ - apisecurity.ResourceType_Namespaces: { - { - Type: apisecurity.ResourceType_Namespaces, - ID: item.GetId().GetValue(), - Metadata: item.GetMetadata(), - }, - }, - }) - - // 检查 write 操作权限 - authCtx.SetMethod([]authcommon.ServerFunctionName{authcommon.UpdateNamespaces}) - // 如果检查不通过,设置 editable 为 false - if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - item.Editable = utils.NewBoolValue(false) - } - - // 检查 delete 操作权限 - authCtx.SetMethod([]authcommon.ServerFunctionName{authcommon.DeleteNamespaces}) - // 如果检查不通过,设置 editable 为 false - if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - item.Deleteable = utils.NewBoolValue(false) - } - } - return resp + return svr.targetServer.GetNamespaces(ctx, query) } // GetNamespaceToken 获取命名空间的token信息,暂时不走权限检查 -func (svr *Server) GetNamespaceToken(ctx context.Context, req *apimodel.Namespace) *apiservice.Response { +func (svr *serverAuthAbility) GetNamespaceToken(ctx context.Context, req *apimodel.Namespace) *apiservice.Response { authCtx := svr.collectNamespaceAuthContext( ctx, []*apimodel.Namespace{req}, authcommon.Read, authcommon.DescribeNamespaceToken) _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx) if err != nil { - return api.NewResponse(authcommon.ConvertToErrCode(err)) + return api.NewResponseWithMsg(convertToErrCode(err), err.Error()) } ctx = authCtx.GetRequestContext() ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) - return svr.nextSvr.GetNamespaceToken(ctx, req) -} - -// collectNamespaceAuthContext 对于命名空间的处理,收集所有的与鉴权的相关信息 -func (svr *Server) collectNamespaceAuthContext(ctx context.Context, req []*apimodel.Namespace, - resourceOp authcommon.ResourceOperation, methodName authcommon.ServerFunctionName) *authcommon.AcquireContext { - return authcommon.NewAcquireContext( - authcommon.WithRequestContext(ctx), - authcommon.WithOperation(resourceOp), - authcommon.WithModule(authcommon.CoreModule), - authcommon.WithMethod(methodName), - authcommon.WithAccessResources(svr.queryNamespaceResource(req)), - ) -} - -// queryNamespaceResource 根据所给的 namespace 信息,收集对应的 ResourceEntry 列表 -func (svr *Server) queryNamespaceResource( - req []*apimodel.Namespace) map[apisecurity.ResourceType][]authcommon.ResourceEntry { - if len(req) == 0 { - return map[apisecurity.ResourceType][]authcommon.ResourceEntry{} - } - - names := utils.NewSet[string]() - for index := range req { - names.Add(req[index].Name.GetValue()) - } - param := names.ToSlice() - nsArr := svr.cacheSvr.Namespace().GetNamespacesByName(param) - - temp := make([]authcommon.ResourceEntry, 0, len(nsArr)) - - for index := range nsArr { - ns := nsArr[index] - temp = append(temp, authcommon.ResourceEntry{ - Type: apisecurity.ResourceType_Namespaces, - ID: ns.Name, - Owner: ns.Owner, - }) - } - - ret := map[apisecurity.ResourceType][]authcommon.ResourceEntry{ - apisecurity.ResourceType_Namespaces: temp, - } - authLog.Debug("[Auth][Server] collect namespace access res", zap.Any("res", ret)) - return ret + return svr.targetServer.GetNamespaceToken(ctx, req) } diff --git a/namespace/interceptor/auth/resource_listener.go b/namespace/resource_listener.go similarity index 72% rename from namespace/interceptor/auth/resource_listener.go rename to namespace/resource_listener.go index 9ec120d94..0fb7268d3 100644 --- a/namespace/interceptor/auth/resource_listener.go +++ b/namespace/resource_listener.go @@ -15,27 +15,48 @@ * specific language governing permissions and limitations under the License. */ -package auth +package namespace import ( "context" + apimodel "github.com/polarismesh/specification/source/go/api/v1/model" apisecurity "github.com/polarismesh/specification/source/go/api/v1/security" - "github.com/polarismesh/polaris/common/log" "github.com/polarismesh/polaris/common/model" authcommon "github.com/polarismesh/polaris/common/model/auth" "github.com/polarismesh/polaris/common/utils" - "github.com/polarismesh/polaris/namespace" ) +// ResourceHook The listener is placed before and after the resource operation, only normal flow +type ResourceHook interface { + + // Before + // @param ctx + // @param resourceType + Before(ctx context.Context, resourceType model.Resource) + + // After + // @param ctx + // @param resourceType + // @param res + After(ctx context.Context, resourceType model.Resource, res *ResourceEvent) error +} + +// ResourceEvent 资源事件 +type ResourceEvent struct { + ReqNamespace *apimodel.Namespace + Namespace *model.Namespace + IsRemove bool +} + // Before this function is called before the resource operation -func (svr *Server) Before(ctx context.Context, resourceType model.Resource) { +func (svr *serverAuthAbility) Before(ctx context.Context, resourceType model.Resource) { // do nothing } // After this function is called after the resource operation -func (svr *Server) After(ctx context.Context, resourceType model.Resource, res *namespace.ResourceEvent) error { +func (svr *serverAuthAbility) After(ctx context.Context, resourceType model.Resource, res *ResourceEvent) error { switch resourceType { case model.RNamespace: return svr.onNamespaceResource(ctx, res) @@ -45,7 +66,7 @@ func (svr *Server) After(ctx context.Context, resourceType model.Resource, res * } // onNamespaceResource -func (svr *Server) onNamespaceResource(ctx context.Context, res *namespace.ResourceEvent) error { +func (svr *serverAuthAbility) onNamespaceResource(ctx context.Context, res *ResourceEvent) error { authCtx, _ := ctx.Value(utils.ContextAuthContextKey).(*authcommon.AcquireContext) if authCtx == nil { log.Warn("[Namespace][ResourceHook] get auth context is nil, ignore", utils.RequestID(ctx)) diff --git a/namespace/server.go b/namespace/server.go index f3f313642..cefedd11f 100644 --- a/namespace/server.go +++ b/namespace/server.go @@ -36,6 +36,7 @@ type Server struct { caches *cache.CacheManager createNamespaceSingle *singleflight.Group cfg Config + history plugin.History hooks []ResourceHook } @@ -60,7 +61,7 @@ func (s *Server) afterNamespaceResource(ctx context.Context, req *apimodel.Names // RecordHistory server对外提供history插件的简单封装 func (s *Server) RecordHistory(entry *model.RecordEntry) { // 如果插件没有初始化,那么不记录history - if plugin.GetHistory() == nil { + if s.history == nil { return } // 如果数据为空,则不需要打印了 @@ -69,32 +70,10 @@ func (s *Server) RecordHistory(entry *model.RecordEntry) { } // 调用插件记录history - plugin.GetHistory().Record(entry) + s.history.Record(entry) } // SetResourceHooks 返回Cache func (s *Server) SetResourceHooks(hooks ...ResourceHook) { s.hooks = hooks } - -// ResourceHook The listener is placed before and after the resource operation, only normal flow -type ResourceHook interface { - - // Before - // @param ctx - // @param resourceType - Before(ctx context.Context, resourceType model.Resource) - - // After - // @param ctx - // @param resourceType - // @param res - After(ctx context.Context, resourceType model.Resource, res *ResourceEvent) error -} - -// ResourceEvent 资源事件 -type ResourceEvent struct { - ReqNamespace *apimodel.Namespace - Namespace *model.Namespace - IsRemove bool -} diff --git a/namespace/server_authability.go b/namespace/server_authability.go new file mode 100644 index 000000000..aa7cb0a18 --- /dev/null +++ b/namespace/server_authability.go @@ -0,0 +1,102 @@ +/** + * Tencent is pleased to support the open source community by making Polaris available. + * + * Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. + * + * Licensed under the BSD 3-Clause License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://opensource.org/licenses/BSD-3-Clause + * + * Unless required by applicable law or agreed to in writing, software distributed + * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR + * CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package namespace + +import ( + "context" + "errors" + + apimodel "github.com/polarismesh/specification/source/go/api/v1/model" + apisecurity "github.com/polarismesh/specification/source/go/api/v1/security" + "go.uber.org/zap" + + "github.com/polarismesh/polaris/auth" + authcommon "github.com/polarismesh/polaris/common/model/auth" + "github.com/polarismesh/polaris/common/utils" +) + +// serverAuthAbility 带有鉴权能力的 discoverServer +// +// 该层会对请求参数做一些调整,根据具体的请求发起人,设置为数据对应的 owner,不可为为别人进行创建资源 +type serverAuthAbility struct { + targetServer *Server + userMgn auth.UserServer + policySvr auth.StrategyServer +} + +func newServerAuthAbility(targetServer *Server, + userMgn auth.UserServer, policySvr auth.StrategyServer) NamespaceOperateServer { + proxy := &serverAuthAbility{ + targetServer: targetServer, + userMgn: userMgn, + policySvr: policySvr, + } + + targetServer.SetResourceHooks(proxy) + return proxy +} + +// collectNamespaceAuthContext 对于命名空间的处理,收集所有的与鉴权的相关信息 +func (svr *serverAuthAbility) collectNamespaceAuthContext(ctx context.Context, req []*apimodel.Namespace, + resourceOp authcommon.ResourceOperation, methodName authcommon.ServerFunctionName) *authcommon.AcquireContext { + return authcommon.NewAcquireContext( + authcommon.WithRequestContext(ctx), + authcommon.WithOperation(resourceOp), + authcommon.WithModule(authcommon.CoreModule), + authcommon.WithMethod(methodName), + authcommon.WithAccessResources(svr.queryNamespaceResource(req)), + ) +} + +// queryNamespaceResource 根据所给的 namespace 信息,收集对应的 ResourceEntry 列表 +func (svr *serverAuthAbility) queryNamespaceResource( + req []*apimodel.Namespace) map[apisecurity.ResourceType][]authcommon.ResourceEntry { + names := utils.NewSet[string]() + for index := range req { + names.Add(req[index].Name.GetValue()) + } + param := names.ToSlice() + nsArr := svr.targetServer.caches.Namespace().GetNamespacesByName(param) + + temp := make([]authcommon.ResourceEntry, 0, len(nsArr)) + + for index := range nsArr { + ns := nsArr[index] + temp = append(temp, authcommon.ResourceEntry{ + Type: apisecurity.ResourceType_Namespaces, + ID: ns.Name, + Owner: ns.Owner, + }) + } + + ret := map[apisecurity.ResourceType][]authcommon.ResourceEntry{ + apisecurity.ResourceType_Namespaces: temp, + } + authLog.Debug("[Auth][Server] collect namespace access res", zap.Any("res", ret)) + return ret +} + +func convertToErrCode(err error) apimodel.Code { + if errors.Is(err, authcommon.ErrorTokenNotExist) { + return apimodel.Code_TokenNotExisted + } + if errors.Is(err, authcommon.ErrorTokenDisabled) { + return apimodel.Code_TokenDisabled + } + return apimodel.Code_NotAllowedAccess +} diff --git a/namespace/test_export.go b/namespace/test_export.go new file mode 100644 index 000000000..fbb98a16b --- /dev/null +++ b/namespace/test_export.go @@ -0,0 +1,51 @@ +/** + * Tencent is pleased to support the open source community by making Polaris available. + * + * Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. + * + * Licensed under the BSD 3-Clause License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://opensource.org/licenses/BSD-3-Clause + * + * Unless required by applicable law or agreed to in writing, software distributed + * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR + * CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package namespace + +import ( + "context" + + "golang.org/x/sync/singleflight" + + "github.com/polarismesh/polaris/auth" + "github.com/polarismesh/polaris/cache" + cachetypes "github.com/polarismesh/polaris/cache/api" + "github.com/polarismesh/polaris/plugin" + "github.com/polarismesh/polaris/store" +) + +func TestInitialize(_ context.Context, nsOpt *Config, storage store.Store, cacheMgn *cache.CacheManager, + userMgn auth.UserServer, strategyMgn auth.StrategyServer) (NamespaceOperateServer, error) { + _ = cacheMgn.OpenResourceCache(cachetypes.ConfigEntry{ + Name: cachetypes.NamespaceName, + }) + nsOpt.AutoCreate = true + namespaceServer := &Server{} + namespaceServer.caches = cacheMgn + namespaceServer.storage = storage + namespaceServer.cfg = *nsOpt + namespaceServer.createNamespaceSingle = &singleflight.Group{} + + // 获取History插件,注意:插件的配置在bootstrap已经设置好 + namespaceServer.history = plugin.GetHistory() + if namespaceServer.history == nil { + log.Warn("Not Found History Log Plugin") + } + + return newServerAuthAbility(namespaceServer, userMgn, strategyMgn), nil +} diff --git a/plugin.go b/plugin.go index 8b30e70c8..9494a723d 100644 --- a/plugin.go +++ b/plugin.go @@ -18,7 +18,6 @@ package main import ( - _ "github.com/polarismesh/polaris/admin/interceptor" _ "github.com/polarismesh/polaris/apiserver/eurekaserver" _ "github.com/polarismesh/polaris/apiserver/grpcserver/config" _ "github.com/polarismesh/polaris/apiserver/grpcserver/discover" @@ -35,7 +34,6 @@ import ( _ "github.com/polarismesh/polaris/cache/namespace" _ "github.com/polarismesh/polaris/cache/service" _ "github.com/polarismesh/polaris/config/interceptor" - _ "github.com/polarismesh/polaris/namespace/interceptor" _ "github.com/polarismesh/polaris/plugin/cmdb/memory" _ "github.com/polarismesh/polaris/plugin/crypto/aes" _ "github.com/polarismesh/polaris/plugin/discoverevent/local" diff --git a/plugin/statis/logger/statis.go b/plugin/statis/logger/statis.go index 8ce86b59a..6a8121c09 100644 --- a/plugin/statis/logger/statis.go +++ b/plugin/statis/logger/statis.go @@ -97,7 +97,7 @@ func (s *StatisWorker) ReportConfigMetrics(metric ...metrics.ConfigMetrics) { // ReportDiscoverCall report discover service times func (s *StatisWorker) ReportDiscoverCall(metric metrics.ClientDiscoverMetric) { - discoverlog.Info(metric.String()) + discoverlog.Infof(metric.String()) } func (a *StatisWorker) metricsHandle(mt metrics.CallMetricType, start time.Time, diff --git a/plugin/sync.go b/plugin/sync.go deleted file mode 100644 index 01a0ef11a..000000000 --- a/plugin/sync.go +++ /dev/null @@ -1,29 +0,0 @@ -/** - * Tencent is pleased to support the open source community by making Polaris available. - * - * Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. - * - * Licensed under the BSD 3-Clause License (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://opensource.org/licenses/BSD-3-Clause - * - * Unless required by applicable law or agreed to in writing, software distributed - * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR - * CONDITIONS OF ANY KIND, either express or implied. See the License for the - * specific language governing permissions and limitations under the License. - */ - -package plugin - -import "github.com/polarismesh/polaris/common/model" - -// Syncer . -type Syncer interface { - Plugin - // Run - Run() error - // DebugHandlers return debug handlers - DebugHandlers() []model.DebugHandler -} diff --git a/release/build.sh b/release/build.sh index 76e262f84..19b8e7f0b 100755 --- a/release/build.sh +++ b/release/build.sh @@ -66,7 +66,7 @@ rm -f ${bin_name} export CGO_ENABLED=0 build_date=$(date "+%Y%m%d.%H%M%S") -package="github.com/polarismesh/polaris/common/version" +package="github.com/polarismesh/polaris-server/common/version" sqldb_res="store/mysql" GOARCH=${GOARCH} GOOS=${GOOS} go build -o ${bin_name} -ldflags="-X ${package}.Version=${version} -X ${package}.BuildDate=${build_date}" diff --git a/release/conf/bolt-data.yaml b/release/conf/bolt-data.yaml deleted file mode 100644 index 2c7827204..000000000 --- a/release/conf/bolt-data.yaml +++ /dev/null @@ -1,174 +0,0 @@ -# Tencent is pleased to support the open source community by making Polaris available. -# -# Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. -# -# Licensed under the BSD 3-Clause License (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://opensource.org/licenses/BSD-3-Clause -# -# Unless required by applicable law or agreed to in writing, software distributed -# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR -# CONDITIONS OF ANY KIND, either express or implied. See the License for the -# specific language governing permissions and limitations under the License. - -users: - - name: polaris - token: nu/0WRA4EqSR1FagrjRj0fZwPXuGlMpX+zCuWu4uMqy8xr1vRjisSbA25aAC3mtU8MeeRsKhQiDAynUR09I= - password: $2a$10$3izWuZtE5SBdAtSZci.gs.iZ2pAn9I8hEqYrC6gwJp1dyjqQnrrum - id: 65e4789a6d5b49669adf1e9e8387549c - tokenenable: true - type: 20 - valid: true -policies: - - id: fbca9bfa04ae4ead86e1ecf5811e32a9 - name: (用户) polaris的默认策略 - action: READ_WRITE - comment: default admin - default: true - owner: 65e4789a6d5b49669adf1e9e8387549c - calleemethods: ["*"] - resources: - - strategyid: fbca9bfa04ae4ead86e1ecf5811e32a9 - restype: 6 - resid: "*" - - strategyid: fbca9bfa04ae4ead86e1ecf5811e32a9 - restype: 7 - resid: "*" - - strategyid: fbca9bfa04ae4ead86e1ecf5811e32a9 - restype: 20 - resid: "*" - - strategyid: fbca9bfa04ae4ead86e1ecf5811e32a9 - restype: 0 - resid: "*" - - strategyid: fbca9bfa04ae4ead86e1ecf5811e32a9 - restype: 3 - resid: "*" - - strategyid: fbca9bfa04ae4ead86e1ecf5811e32a9 - restype: 4 - resid: "*" - - strategyid: fbca9bfa04ae4ead86e1ecf5811e32a9 - restype: 5 - resid: "*" - - strategyid: fbca9bfa04ae4ead86e1ecf5811e32a9 - restype: 21 - resid: "*" - - strategyid: fbca9bfa04ae4ead86e1ecf5811e32a9 - restype: 22 - resid: "*" - - strategyid: fbca9bfa04ae4ead86e1ecf5811e32a9 - restype: 23 - resid: "*" - - strategyid: fbca9bfa04ae4ead86e1ecf5811e32a9 - restype: 1 - resid: "*" - - strategyid: fbca9bfa04ae4ead86e1ecf5811e32a9 - restype: 2 - resid: "*" - conditions: [] - principals: - - strategyid: fbca9bfa04ae4ead86e1ecf5811e32a9 - principalid: 65e4789a6d5b49669adf1e9e8387549c - principaltype: 1 - valid: true - revision: fbca9bfa04ae4ead86e1ecf5811e32a9 - metadata: {} - - id: bfa04ae1e32a94fbca9ead86e1ecf581 - name: 全局只读策略 - action: ALLOW - comment: global resources read onyly - default: false - owner: 65e4789a6d5b49669adf1e9e8387549c - calleemethods: ["Describe*", "List*", "Get*"] - resources: - - strategyid: bfa04ae1e32a94fbca9ead86e1ecf581 - restype: 6 - resid: "*" - - strategyid: bfa04ae1e32a94fbca9ead86e1ecf581 - restype: 7 - resid: "*" - - strategyid: bfa04ae1e32a94fbca9ead86e1ecf581 - restype: 20 - resid: "*" - - strategyid: bfa04ae1e32a94fbca9ead86e1ecf581 - restype: 0 - resid: "*" - - strategyid: bfa04ae1e32a94fbca9ead86e1ecf581 - restype: 3 - resid: "*" - - strategyid: bfa04ae1e32a94fbca9ead86e1ecf581 - restype: 4 - resid: "*" - - strategyid: bfa04ae1e32a94fbca9ead86e1ecf581 - restype: 5 - resid: "*" - - strategyid: bfa04ae1e32a94fbca9ead86e1ecf581 - restype: 21 - resid: "*" - - strategyid: bfa04ae1e32a94fbca9ead86e1ecf581 - restype: 22 - resid: "*" - - strategyid: bfa04ae1e32a94fbca9ead86e1ecf581 - restype: 23 - resid: "*" - - strategyid: bfa04ae1e32a94fbca9ead86e1ecf581 - restype: 1 - resid: "*" - - strategyid: bfa04ae1e32a94fbca9ead86e1ecf581 - restype: 2 - resid: "*" - conditions: [] - principals: [] - valid: true - revision: 2a04ae4ead86e1e9bfacf59fbca811e3 - metadata: {} - - id: e3d86e1ecf5812bfa04ae1a94fbca9ea - name: 全局读写策略 - action: ALLOW - comment: global resources read and write - default: false - owner: 65e4789a6d5b49669adf1e9e8387549c - calleemethods: ["*"] - resources: - - strategyid: e3d86e1ecf5812bfa04ae1a94fbca9ea - restype: 6 - resid: "*" - - strategyid: e3d86e1ecf5812bfa04ae1a94fbca9ea - restype: 7 - resid: "*" - - strategyid: e3d86e1ecf5812bfa04ae1a94fbca9ea - restype: 20 - resid: "*" - - strategyid: e3d86e1ecf5812bfa04ae1a94fbca9ea - restype: 0 - resid: "*" - - strategyid: e3d86e1ecf5812bfa04ae1a94fbca9ea - restype: 3 - resid: "*" - - strategyid: e3d86e1ecf5812bfa04ae1a94fbca9ea - restype: 4 - resid: "*" - - strategyid: e3d86e1ecf5812bfa04ae1a94fbca9ea - restype: 5 - resid: "*" - - strategyid: e3d86e1ecf5812bfa04ae1a94fbca9ea - restype: 21 - resid: "*" - - strategyid: e3d86e1ecf5812bfa04ae1a94fbca9ea - restype: 22 - resid: "*" - - strategyid: e3d86e1ecf5812bfa04ae1a94fbca9ea - restype: 23 - resid: "*" - - strategyid: e3d86e1ecf5812bfa04ae1a94fbca9ea - restype: 1 - resid: "*" - - strategyid: e3d86e1ecf5812bfa04ae1a94fbca9ea - restype: 2 - resid: "*" - conditions: [] - principals: [] - valid: true - revision: 4ead86e1e9bfac2a04aef59fbca811e3 - metadata: {} diff --git a/release/conf/polaris-server.yaml b/release/conf/polaris-server.yaml index 2d1558136..91535a854 100644 --- a/release/conf/polaris-server.yaml +++ b/release/conf/polaris-server.yaml @@ -451,7 +451,6 @@ store: name: boltdbStore option: path: ./polaris.bolt - loadFile: ./conf/bolt-data.yaml ## Database storage plugin # name: defaultStore # option: diff --git a/service/api.go b/service/api.go index 676f806cb..b37f19f86 100644 --- a/service/api.go +++ b/service/api.go @@ -20,38 +20,22 @@ package service import ( "context" - apifault "github.com/polarismesh/specification/source/go/api/v1/fault_tolerance" apiservice "github.com/polarismesh/specification/source/go/api/v1/service_manage" - apitraffic "github.com/polarismesh/specification/source/go/api/v1/traffic_manage" cachetypes "github.com/polarismesh/polaris/cache/api" - "github.com/polarismesh/polaris/common/api/l5" "github.com/polarismesh/polaris/common/model" - authcommon "github.com/polarismesh/polaris/common/model/auth" ) // DiscoverServer Server discovered by the service type DiscoverServer interface { - // CircuitBreakerOperateServer Fuse rule operation interface definition - CircuitBreakerOperateServer - // RateLimitOperateServer Lamflow rule operation interface definition - RateLimitOperateServer - // RouteRuleOperateServer Routing rules operation interface definition - RouteRuleOperateServer - // RouterRuleOperateServer Routing rules operation interface definition - RouterRuleOperateServer - // FaultDetectRuleOperateServer fault detect rules operation interface definition - FaultDetectRuleOperateServer - // ServiceContractOperateServer service contract rules operation inerface definition - ServiceContractOperateServer + // DiscoverServerV1 DiscoverServerV1 + DiscoverServerV1 // ServiceAliasOperateServer Service alias operation interface definition ServiceAliasOperateServer // ServiceOperateServer Service operation interface definition ServiceOperateServer // InstanceOperateServer Instance Operation Interface Definition InstanceOperateServer - // LaneOperateServer lane rule operation interface definition - LaneOperateServer // ClientServer Client operation interface definition ClientServer // Cache Get cache management @@ -62,240 +46,6 @@ type DiscoverServer interface { GetServiceInstanceRevision(serviceID string, instances []*model.Instance) (string, error) } -// CircuitBreakerOperateServer Melting rule related treatment -type CircuitBreakerOperateServer interface { - // CreateCircuitBreakers Create a CircuitBreaker rule - // Deprecated: not support from 1.14.x - CreateCircuitBreakers(ctx context.Context, req []*apifault.CircuitBreaker) *apiservice.BatchWriteResponse - // CreateCircuitBreakerVersions Create a melt rule version - // Deprecated: not support from 1.14.x - CreateCircuitBreakerVersions(ctx context.Context, req []*apifault.CircuitBreaker) *apiservice.BatchWriteResponse - // DeleteCircuitBreakers Delete CircuitBreaker rules - // Deprecated: not support from 1.14.x - DeleteCircuitBreakers(ctx context.Context, req []*apifault.CircuitBreaker) *apiservice.BatchWriteResponse - // UpdateCircuitBreakers Modify the CircuitBreaker rule - // Deprecated: not support from 1.14.x - UpdateCircuitBreakers(ctx context.Context, req []*apifault.CircuitBreaker) *apiservice.BatchWriteResponse - // ReleaseCircuitBreakers Release CircuitBreaker rule - // Deprecated: not support from 1.14.x - ReleaseCircuitBreakers(ctx context.Context, req []*apiservice.ConfigRelease) *apiservice.BatchWriteResponse - // UnBindCircuitBreakers Solution CircuitBreaker rule - // Deprecated: not support from 1.14.x - UnBindCircuitBreakers(ctx context.Context, req []*apiservice.ConfigRelease) *apiservice.BatchWriteResponse - // GetCircuitBreaker Get CircuitBreaker regular according to ID and VERSION - // Deprecated: not support from 1.14.x - GetCircuitBreaker(ctx context.Context, query map[string]string) *apiservice.BatchQueryResponse - // GetCircuitBreakerVersions Query all versions of the CircuitBreaker rule - // Deprecated: not support from 1.14.x - GetCircuitBreakerVersions(ctx context.Context, query map[string]string) *apiservice.BatchQueryResponse - // GetMasterCircuitBreakers Query Master CircuitBreaker rules - // Deprecated: not support from 1.14.x - GetMasterCircuitBreakers(ctx context.Context, query map[string]string) *apiservice.BatchQueryResponse - // GetReleaseCircuitBreakers Query the released CircuitBreaker rule according to the rule ID - // Deprecated: not support from 1.14.x - GetReleaseCircuitBreakers(ctx context.Context, query map[string]string) *apiservice.BatchQueryResponse - // GetCircuitBreakerByService Binding CircuitBreaker rule based on service query - // Deprecated: not support from 1.14.x - GetCircuitBreakerByService(ctx context.Context, query map[string]string) *apiservice.BatchQueryResponse - // GetCircuitBreakerToken Get CircuitBreaker rules token - // Deprecated: not support from 1.14.x - GetCircuitBreakerToken(ctx context.Context, req *apifault.CircuitBreaker) *apiservice.Response - // CreateCircuitBreakerRules Create a CircuitBreaker rule - CreateCircuitBreakerRules(ctx context.Context, request []*apifault.CircuitBreakerRule) *apiservice.BatchWriteResponse - // DeleteCircuitBreakerRules Delete current CircuitBreaker rules - DeleteCircuitBreakerRules(ctx context.Context, request []*apifault.CircuitBreakerRule) *apiservice.BatchWriteResponse - // EnableCircuitBreakerRules Enable the CircuitBreaker rule - EnableCircuitBreakerRules(ctx context.Context, request []*apifault.CircuitBreakerRule) *apiservice.BatchWriteResponse - // UpdateCircuitBreakerRules Modify the CircuitBreaker rule - UpdateCircuitBreakerRules(ctx context.Context, request []*apifault.CircuitBreakerRule) *apiservice.BatchWriteResponse - // GetCircuitBreakerRules Query CircuitBreaker rules - GetCircuitBreakerRules(ctx context.Context, query map[string]string) *apiservice.BatchQueryResponse -} - -// RateLimitOperateServer Lamflow rule related operation -type RateLimitOperateServer interface { - // CreateRateLimits Create a RateLimit rule - CreateRateLimits(ctx context.Context, request []*apitraffic.Rule) *apiservice.BatchWriteResponse - // DeleteRateLimits Delete current RateLimit rules - DeleteRateLimits(ctx context.Context, request []*apitraffic.Rule) *apiservice.BatchWriteResponse - // EnableRateLimits Enable the RateLimit rule - EnableRateLimits(ctx context.Context, request []*apitraffic.Rule) *apiservice.BatchWriteResponse - // UpdateRateLimits Modify the RateLimit rule - UpdateRateLimits(ctx context.Context, request []*apitraffic.Rule) *apiservice.BatchWriteResponse - // GetRateLimits Query RateLimit rules - GetRateLimits(ctx context.Context, query map[string]string) *apiservice.BatchQueryResponse -} - -// RouteRuleOperateServer Routing rules related operations -type RouteRuleOperateServer interface { - // CreateRoutingConfigs Batch creation routing configuration - CreateRoutingConfigs(ctx context.Context, req []*apitraffic.Routing) *apiservice.BatchWriteResponse - // DeleteRoutingConfigs Batch delete routing configuration - DeleteRoutingConfigs(ctx context.Context, req []*apitraffic.Routing) *apiservice.BatchWriteResponse - // UpdateRoutingConfigs Batch update routing configuration - UpdateRoutingConfigs(ctx context.Context, req []*apitraffic.Routing) *apiservice.BatchWriteResponse - // GetRoutingConfigs Inquiry route configuration to OSS - GetRoutingConfigs(ctx context.Context, query map[string]string) *apiservice.BatchQueryResponse -} - -// ServiceOperateServer Service related operations -type ServiceOperateServer interface { - // CreateServices Batch creation service - CreateServices(ctx context.Context, req []*apiservice.Service) *apiservice.BatchWriteResponse - // DeleteServices Batch delete service - DeleteServices(ctx context.Context, req []*apiservice.Service) *apiservice.BatchWriteResponse - // UpdateServices Batch update service - UpdateServices(ctx context.Context, req []*apiservice.Service) *apiservice.BatchWriteResponse - // UpdateServiceToken Update service token - UpdateServiceToken(ctx context.Context, req *apiservice.Service) *apiservice.Response - // GetServices Get a list of service - GetServices(ctx context.Context, query map[string]string) *apiservice.BatchQueryResponse - // GetAllServices Get all service list - GetAllServices(ctx context.Context, query map[string]string) *apiservice.BatchQueryResponse - // GetServicesCount Total number of services - GetServicesCount(ctx context.Context) *apiservice.BatchQueryResponse - // GetServiceToken Get service token - GetServiceToken(ctx context.Context, req *apiservice.Service) *apiservice.Response - // GetServiceOwner Owner for obtaining service - GetServiceOwner(ctx context.Context, req []*apiservice.Service) *apiservice.BatchQueryResponse -} - -// ServiceAliasOperateServer Service alias related operations -type ServiceAliasOperateServer interface { - // CreateServiceAlias Create a service alias - CreateServiceAlias(ctx context.Context, req *apiservice.ServiceAlias) *apiservice.Response - // DeleteServiceAliases Batch delete service alias - DeleteServiceAliases(ctx context.Context, req []*apiservice.ServiceAlias) *apiservice.BatchWriteResponse - // UpdateServiceAlias Update service alias - UpdateServiceAlias(ctx context.Context, req *apiservice.ServiceAlias) *apiservice.Response - // GetServiceAliases Get a list of service alias - GetServiceAliases(ctx context.Context, query map[string]string) *apiservice.BatchQueryResponse -} - -// InstanceOperateServer Example related operations -type InstanceOperateServer interface { - // CreateInstances Batch creation instance - CreateInstances(ctx context.Context, reqs []*apiservice.Instance) *apiservice.BatchWriteResponse - // DeleteInstances Batch delete instance - DeleteInstances(ctx context.Context, req []*apiservice.Instance) *apiservice.BatchWriteResponse - // DeleteInstancesByHost Delete instance according to HOST information batch - DeleteInstancesByHost(ctx context.Context, req []*apiservice.Instance) *apiservice.BatchWriteResponse - // UpdateInstances Batch update instance - UpdateInstances(ctx context.Context, req []*apiservice.Instance) *apiservice.BatchWriteResponse - // UpdateInstancesIsolate Batch update instance isolation state - UpdateInstancesIsolate(ctx context.Context, req []*apiservice.Instance) *apiservice.BatchWriteResponse - // GetInstances Get an instance list - GetInstances(ctx context.Context, query map[string]string) *apiservice.BatchQueryResponse - // GetInstancesCount Get an instance quantity - GetInstancesCount(ctx context.Context) *apiservice.BatchQueryResponse - // GetInstanceLabels Get an instance tag under a service - GetInstanceLabels(ctx context.Context, query map[string]string) *apiservice.Response -} - -// ClientServer Client related operation Client operation interface definition -type ClientServer interface { - // RegisterInstance create one instance by client - RegisterInstance(ctx context.Context, req *apiservice.Instance) *apiservice.Response - // DeregisterInstance delete onr instance by client - DeregisterInstance(ctx context.Context, req *apiservice.Instance) *apiservice.Response - // ReportClient Client gets geographic location information - ReportClient(ctx context.Context, req *apiservice.Client) *apiservice.Response - // GetPrometheusTargets Used to obtain the ReportClient information and serve as the SD result of Prometheus - GetPrometheusTargets(ctx context.Context, query map[string]string) *model.PrometheusDiscoveryResponse - // GetServiceWithCache Used for client acquisition service information - GetServiceWithCache(ctx context.Context, req *apiservice.Service) *apiservice.DiscoverResponse - // ServiceInstancesCache Used for client acquisition service instance information - ServiceInstancesCache(ctx context.Context, filter *apiservice.DiscoverFilter, req *apiservice.Service) *apiservice.DiscoverResponse - // GetRoutingConfigWithCache User Client Get Service Routing Configuration Information - GetRoutingConfigWithCache(ctx context.Context, req *apiservice.Service) *apiservice.DiscoverResponse - // GetRateLimitWithCache User Client Get Service Limit Configuration Information - GetRateLimitWithCache(ctx context.Context, req *apiservice.Service) *apiservice.DiscoverResponse - // GetCircuitBreakerWithCache Fuse configuration information for obtaining services for clients - GetCircuitBreakerWithCache(ctx context.Context, req *apiservice.Service) *apiservice.DiscoverResponse - // GetFaultDetectWithCache User Client Get FaultDetect Rule Information - GetFaultDetectWithCache(ctx context.Context, req *apiservice.Service) *apiservice.DiscoverResponse - // GetServiceContractWithCache User Client Get ServiceContract Rule Information - GetServiceContractWithCache(ctx context.Context, req *apiservice.ServiceContract) *apiservice.Response - // UpdateInstance update one instance by client - UpdateInstance(ctx context.Context, req *apiservice.Instance) *apiservice.Response - // ReportServiceContract client report service_contract - ReportServiceContract(ctx context.Context, req *apiservice.ServiceContract) *apiservice.Response - // GetLaneRuleWithCache fetch lane rules by client - GetLaneRuleWithCache(ctx context.Context, req *apiservice.Service) *apiservice.DiscoverResponse -} - -// L5OperateServer L5 related operations -type L5OperateServer interface { - // SyncByAgentCmd Get routing information according to SID list - SyncByAgentCmd(ctx context.Context, sbac *l5.Cl5SyncByAgentCmd) (*l5.Cl5SyncByAgentAckCmd, error) - // RegisterByNameCmd Look for the corresponding SID list according to the list of service names - RegisterByNameCmd(rbnc *l5.Cl5RegisterByNameCmd) (*l5.Cl5RegisterByNameAckCmd, error) -} - -// ReportClientOperateServer Report information operation interface on the client -type ReportClientOperateServer interface { - // GetReportClients Query the client information reported - GetReportClients(ctx context.Context, query map[string]string) *apiservice.BatchQueryResponse -} - -// RouterRuleOperateServer Routing rules related operations -type RouterRuleOperateServer interface { - // CreateRoutingConfigsV2 Batch creation routing configuration - CreateRoutingConfigsV2(ctx context.Context, req []*apitraffic.RouteRule) *apiservice.BatchWriteResponse - // DeleteRoutingConfigsV2 Batch delete routing configuration - DeleteRoutingConfigsV2(ctx context.Context, req []*apitraffic.RouteRule) *apiservice.BatchWriteResponse - // UpdateRoutingConfigsV2 Batch update routing configuration - UpdateRoutingConfigsV2(ctx context.Context, req []*apitraffic.RouteRule) *apiservice.BatchWriteResponse - // QueryRoutingConfigsV2 Inquiry route configuration to OSS - QueryRoutingConfigsV2(ctx context.Context, query map[string]string) *apiservice.BatchQueryResponse - // EnableRoutings batch enable routing rules - EnableRoutings(ctx context.Context, req []*apitraffic.RouteRule) *apiservice.BatchWriteResponse -} - -// FaultDetectRuleOperateServer Fault detect rules related operations -type FaultDetectRuleOperateServer interface { - // CreateFaultDetectRules create the fault detect rule by request - CreateFaultDetectRules(ctx context.Context, request []*apifault.FaultDetectRule) *apiservice.BatchWriteResponse - // DeleteFaultDetectRules delete the fault detect rule by request - DeleteFaultDetectRules(ctx context.Context, request []*apifault.FaultDetectRule) *apiservice.BatchWriteResponse - // UpdateFaultDetectRules update the fault detect rule by request - UpdateFaultDetectRules(ctx context.Context, request []*apifault.FaultDetectRule) *apiservice.BatchWriteResponse - // GetFaultDetectRules get the fault detect rule by request - GetFaultDetectRules(ctx context.Context, query map[string]string) *apiservice.BatchQueryResponse -} - -// ServiceContractOperateServer service contract operations -type ServiceContractOperateServer interface { - // CreateServiceContracts . - CreateServiceContracts(ctx context.Context, req []*apiservice.ServiceContract) *apiservice.BatchWriteResponse - // DeleteServiceContracts . - DeleteServiceContracts(ctx context.Context, req []*apiservice.ServiceContract) *apiservice.BatchWriteResponse - // GetServiceContracts . - GetServiceContracts(ctx context.Context, query map[string]string) *apiservice.BatchQueryResponse - // CreateServiceContractInterfaces . - CreateServiceContractInterfaces(ctx context.Context, contract *apiservice.ServiceContract, - source apiservice.InterfaceDescriptor_Source) *apiservice.Response - // AppendServiceContractInterfaces . - AppendServiceContractInterfaces(ctx context.Context, contract *apiservice.ServiceContract, - source apiservice.InterfaceDescriptor_Source) *apiservice.Response - // DeleteServiceContractInterfaces . - DeleteServiceContractInterfaces(ctx context.Context, contract *apiservice.ServiceContract) *apiservice.Response - // GetServiceContractVersions . - GetServiceContractVersions(ctx context.Context, filter map[string]string) *apiservice.BatchQueryResponse -} - -// LaneOperateServer lane operations -type LaneOperateServer interface { - // CreateLaneGroups 批量创建泳道组 - CreateLaneGroups(ctx context.Context, req []*apitraffic.LaneGroup) *apiservice.BatchWriteResponse - // UpdateLaneGroups 批量更新泳道组 - UpdateLaneGroups(ctx context.Context, req []*apitraffic.LaneGroup) *apiservice.BatchWriteResponse - // DeleteLaneGroups 批量删除泳道组 - DeleteLaneGroups(ctx context.Context, req []*apitraffic.LaneGroup) *apiservice.BatchWriteResponse - // GetLaneGroups 查询泳道组列表 - GetLaneGroups(ctx context.Context, filter map[string]string) *apiservice.BatchQueryResponse -} - // ResourceHook The listener is placed before and after the resource operation, only normal flow type ResourceHook interface { @@ -313,9 +63,7 @@ type ResourceHook interface { // ResourceEvent 资源事件 type ResourceEvent struct { - Resource authcommon.ResourceEntry - - AddPrincipals []authcommon.Principal - DelPrincipals []authcommon.Principal - IsRemove bool + ReqService *apiservice.Service + Service *model.Service + IsRemove bool } diff --git a/service/api_v1.go b/service/api_v1.go new file mode 100644 index 000000000..20ac16ccd --- /dev/null +++ b/service/api_v1.go @@ -0,0 +1,266 @@ +/** + * Tencent is pleased to support the open source community by making Polaris available. + * + * Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. + * + * Licensed under the BSD 3-Clause License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://opensource.org/licenses/BSD-3-Clause + * + * Unless required by applicable law or agreed to in writing, software distributed + * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR + * CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package service + +import ( + "context" + + apifault "github.com/polarismesh/specification/source/go/api/v1/fault_tolerance" + apiservice "github.com/polarismesh/specification/source/go/api/v1/service_manage" + apitraffic "github.com/polarismesh/specification/source/go/api/v1/traffic_manage" + + "github.com/polarismesh/polaris/common/api/l5" + "github.com/polarismesh/polaris/common/model" +) + +// CircuitBreakerOperateServer Melting rule related treatment +type CircuitBreakerOperateServer interface { + // CreateCircuitBreakers Create a CircuitBreaker rule + // Deprecated: not support from 1.14.x + CreateCircuitBreakers(ctx context.Context, req []*apifault.CircuitBreaker) *apiservice.BatchWriteResponse + // CreateCircuitBreakerVersions Create a melt rule version + // Deprecated: not support from 1.14.x + CreateCircuitBreakerVersions(ctx context.Context, req []*apifault.CircuitBreaker) *apiservice.BatchWriteResponse + // DeleteCircuitBreakers Delete CircuitBreaker rules + // Deprecated: not support from 1.14.x + DeleteCircuitBreakers(ctx context.Context, req []*apifault.CircuitBreaker) *apiservice.BatchWriteResponse + // UpdateCircuitBreakers Modify the CircuitBreaker rule + // Deprecated: not support from 1.14.x + UpdateCircuitBreakers(ctx context.Context, req []*apifault.CircuitBreaker) *apiservice.BatchWriteResponse + // ReleaseCircuitBreakers Release CircuitBreaker rule + // Deprecated: not support from 1.14.x + ReleaseCircuitBreakers(ctx context.Context, req []*apiservice.ConfigRelease) *apiservice.BatchWriteResponse + // UnBindCircuitBreakers Solution CircuitBreaker rule + // Deprecated: not support from 1.14.x + UnBindCircuitBreakers(ctx context.Context, req []*apiservice.ConfigRelease) *apiservice.BatchWriteResponse + // GetCircuitBreaker Get CircuitBreaker regular according to ID and VERSION + // Deprecated: not support from 1.14.x + GetCircuitBreaker(ctx context.Context, query map[string]string) *apiservice.BatchQueryResponse + // GetCircuitBreakerVersions Query all versions of the CircuitBreaker rule + // Deprecated: not support from 1.14.x + GetCircuitBreakerVersions(ctx context.Context, query map[string]string) *apiservice.BatchQueryResponse + // GetMasterCircuitBreakers Query Master CircuitBreaker rules + // Deprecated: not support from 1.14.x + GetMasterCircuitBreakers(ctx context.Context, query map[string]string) *apiservice.BatchQueryResponse + // GetReleaseCircuitBreakers Query the released CircuitBreaker rule according to the rule ID + // Deprecated: not support from 1.14.x + GetReleaseCircuitBreakers(ctx context.Context, query map[string]string) *apiservice.BatchQueryResponse + // GetCircuitBreakerByService Binding CircuitBreaker rule based on service query + // Deprecated: not support from 1.14.x + GetCircuitBreakerByService(ctx context.Context, query map[string]string) *apiservice.BatchQueryResponse + // GetCircuitBreakerToken Get CircuitBreaker rules token + // Deprecated: not support from 1.14.x + GetCircuitBreakerToken(ctx context.Context, req *apifault.CircuitBreaker) *apiservice.Response + // CreateCircuitBreakerRules Create a CircuitBreaker rule + CreateCircuitBreakerRules(ctx context.Context, request []*apifault.CircuitBreakerRule) *apiservice.BatchWriteResponse + // DeleteCircuitBreakerRules Delete current CircuitBreaker rules + DeleteCircuitBreakerRules(ctx context.Context, request []*apifault.CircuitBreakerRule) *apiservice.BatchWriteResponse + // EnableCircuitBreakerRules Enable the CircuitBreaker rule + EnableCircuitBreakerRules(ctx context.Context, request []*apifault.CircuitBreakerRule) *apiservice.BatchWriteResponse + // UpdateCircuitBreakerRules Modify the CircuitBreaker rule + UpdateCircuitBreakerRules(ctx context.Context, request []*apifault.CircuitBreakerRule) *apiservice.BatchWriteResponse + // GetCircuitBreakerRules Query CircuitBreaker rules + GetCircuitBreakerRules(ctx context.Context, query map[string]string) *apiservice.BatchQueryResponse +} + +// RateLimitOperateServer Lamflow rule related operation +type RateLimitOperateServer interface { + // CreateRateLimits Create a RateLimit rule + CreateRateLimits(ctx context.Context, request []*apitraffic.Rule) *apiservice.BatchWriteResponse + // DeleteRateLimits Delete current RateLimit rules + DeleteRateLimits(ctx context.Context, request []*apitraffic.Rule) *apiservice.BatchWriteResponse + // EnableRateLimits Enable the RateLimit rule + EnableRateLimits(ctx context.Context, request []*apitraffic.Rule) *apiservice.BatchWriteResponse + // UpdateRateLimits Modify the RateLimit rule + UpdateRateLimits(ctx context.Context, request []*apitraffic.Rule) *apiservice.BatchWriteResponse + // GetRateLimits Query RateLimit rules + GetRateLimits(ctx context.Context, query map[string]string) *apiservice.BatchQueryResponse +} + +// RouteRuleOperateServer Routing rules related operations +type RouteRuleOperateServer interface { + // CreateRoutingConfigs Batch creation routing configuration + CreateRoutingConfigs(ctx context.Context, req []*apitraffic.Routing) *apiservice.BatchWriteResponse + // DeleteRoutingConfigs Batch delete routing configuration + DeleteRoutingConfigs(ctx context.Context, req []*apitraffic.Routing) *apiservice.BatchWriteResponse + // UpdateRoutingConfigs Batch update routing configuration + UpdateRoutingConfigs(ctx context.Context, req []*apitraffic.Routing) *apiservice.BatchWriteResponse + // GetRoutingConfigs Inquiry route configuration to OSS + GetRoutingConfigs(ctx context.Context, query map[string]string) *apiservice.BatchQueryResponse +} + +// ServiceOperateServer Service related operations +type ServiceOperateServer interface { + // CreateServices Batch creation service + CreateServices(ctx context.Context, req []*apiservice.Service) *apiservice.BatchWriteResponse + // DeleteServices Batch delete service + DeleteServices(ctx context.Context, req []*apiservice.Service) *apiservice.BatchWriteResponse + // UpdateServices Batch update service + UpdateServices(ctx context.Context, req []*apiservice.Service) *apiservice.BatchWriteResponse + // UpdateServiceToken Update service token + UpdateServiceToken(ctx context.Context, req *apiservice.Service) *apiservice.Response + // GetServices Get a list of service + GetServices(ctx context.Context, query map[string]string) *apiservice.BatchQueryResponse + // GetAllServices Get all service list + GetAllServices(ctx context.Context, query map[string]string) *apiservice.BatchQueryResponse + // GetServicesCount Total number of services + GetServicesCount(ctx context.Context) *apiservice.BatchQueryResponse + // GetServiceToken Get service token + GetServiceToken(ctx context.Context, req *apiservice.Service) *apiservice.Response + // GetServiceOwner Owner for obtaining service + GetServiceOwner(ctx context.Context, req []*apiservice.Service) *apiservice.BatchQueryResponse +} + +// ServiceAliasOperateServer Service alias related operations +type ServiceAliasOperateServer interface { + // CreateServiceAlias Create a service alias + CreateServiceAlias(ctx context.Context, req *apiservice.ServiceAlias) *apiservice.Response + // DeleteServiceAliases Batch delete service alias + DeleteServiceAliases(ctx context.Context, req []*apiservice.ServiceAlias) *apiservice.BatchWriteResponse + // UpdateServiceAlias Update service alias + UpdateServiceAlias(ctx context.Context, req *apiservice.ServiceAlias) *apiservice.Response + // GetServiceAliases Get a list of service alias + GetServiceAliases(ctx context.Context, query map[string]string) *apiservice.BatchQueryResponse +} + +// InstanceOperateServer Example related operations +type InstanceOperateServer interface { + // CreateInstances Batch creation instance + CreateInstances(ctx context.Context, reqs []*apiservice.Instance) *apiservice.BatchWriteResponse + // DeleteInstances Batch delete instance + DeleteInstances(ctx context.Context, req []*apiservice.Instance) *apiservice.BatchWriteResponse + // DeleteInstancesByHost Delete instance according to HOST information batch + DeleteInstancesByHost(ctx context.Context, req []*apiservice.Instance) *apiservice.BatchWriteResponse + // UpdateInstances Batch update instance + UpdateInstances(ctx context.Context, req []*apiservice.Instance) *apiservice.BatchWriteResponse + // UpdateInstancesIsolate Batch update instance isolation state + UpdateInstancesIsolate(ctx context.Context, req []*apiservice.Instance) *apiservice.BatchWriteResponse + // GetInstances Get an instance list + GetInstances(ctx context.Context, query map[string]string) *apiservice.BatchQueryResponse + // GetInstancesCount Get an instance quantity + GetInstancesCount(ctx context.Context) *apiservice.BatchQueryResponse + // GetInstanceLabels Get an instance tag under a service + GetInstanceLabels(ctx context.Context, query map[string]string) *apiservice.Response +} + +// ClientServer Client related operation Client operation interface definition +type ClientServer interface { + // RegisterInstance create one instance by client + RegisterInstance(ctx context.Context, req *apiservice.Instance) *apiservice.Response + // DeregisterInstance delete onr instance by client + DeregisterInstance(ctx context.Context, req *apiservice.Instance) *apiservice.Response + // ReportClient Client gets geographic location information + ReportClient(ctx context.Context, req *apiservice.Client) *apiservice.Response + // GetPrometheusTargets Used to obtain the ReportClient information and serve as the SD result of Prometheus + GetPrometheusTargets(ctx context.Context, query map[string]string) *model.PrometheusDiscoveryResponse + // GetServiceWithCache Used for client acquisition service information + GetServiceWithCache(ctx context.Context, req *apiservice.Service) *apiservice.DiscoverResponse + // ServiceInstancesCache Used for client acquisition service instance information + ServiceInstancesCache(ctx context.Context, filter *apiservice.DiscoverFilter, req *apiservice.Service) *apiservice.DiscoverResponse + // GetRoutingConfigWithCache User Client Get Service Routing Configuration Information + GetRoutingConfigWithCache(ctx context.Context, req *apiservice.Service) *apiservice.DiscoverResponse + // GetRateLimitWithCache User Client Get Service Limit Configuration Information + GetRateLimitWithCache(ctx context.Context, req *apiservice.Service) *apiservice.DiscoverResponse + // GetCircuitBreakerWithCache Fuse configuration information for obtaining services for clients + GetCircuitBreakerWithCache(ctx context.Context, req *apiservice.Service) *apiservice.DiscoverResponse + // GetFaultDetectWithCache User Client Get FaultDetect Rule Information + GetFaultDetectWithCache(ctx context.Context, req *apiservice.Service) *apiservice.DiscoverResponse + // GetServiceContractWithCache User Client Get ServiceContract Rule Information + GetServiceContractWithCache(ctx context.Context, req *apiservice.ServiceContract) *apiservice.Response + // UpdateInstance update one instance by client + UpdateInstance(ctx context.Context, req *apiservice.Instance) *apiservice.Response + // ReportServiceContract client report service_contract + ReportServiceContract(ctx context.Context, req *apiservice.ServiceContract) *apiservice.Response + // GetLaneRuleWithCache fetch lane rules by client + GetLaneRuleWithCache(ctx context.Context, req *apiservice.Service) *apiservice.DiscoverResponse +} + +// L5OperateServer L5 related operations +type L5OperateServer interface { + // SyncByAgentCmd Get routing information according to SID list + SyncByAgentCmd(ctx context.Context, sbac *l5.Cl5SyncByAgentCmd) (*l5.Cl5SyncByAgentAckCmd, error) + // RegisterByNameCmd Look for the corresponding SID list according to the list of service names + RegisterByNameCmd(rbnc *l5.Cl5RegisterByNameCmd) (*l5.Cl5RegisterByNameAckCmd, error) +} + +// ReportClientOperateServer Report information operation interface on the client +type ReportClientOperateServer interface { + // GetReportClients Query the client information reported + GetReportClients(ctx context.Context, query map[string]string) *apiservice.BatchQueryResponse +} + +// RouterRuleOperateServer Routing rules related operations +type RouterRuleOperateServer interface { + // CreateRoutingConfigsV2 Batch creation routing configuration + CreateRoutingConfigsV2(ctx context.Context, req []*apitraffic.RouteRule) *apiservice.BatchWriteResponse + // DeleteRoutingConfigsV2 Batch delete routing configuration + DeleteRoutingConfigsV2(ctx context.Context, req []*apitraffic.RouteRule) *apiservice.BatchWriteResponse + // UpdateRoutingConfigsV2 Batch update routing configuration + UpdateRoutingConfigsV2(ctx context.Context, req []*apitraffic.RouteRule) *apiservice.BatchWriteResponse + // QueryRoutingConfigsV2 Inquiry route configuration to OSS + QueryRoutingConfigsV2(ctx context.Context, query map[string]string) *apiservice.BatchQueryResponse + // EnableRoutings batch enable routing rules + EnableRoutings(ctx context.Context, req []*apitraffic.RouteRule) *apiservice.BatchWriteResponse +} + +// FaultDetectRuleOperateServer Fault detect rules related operations +type FaultDetectRuleOperateServer interface { + // CreateFaultDetectRules create the fault detect rule by request + CreateFaultDetectRules(ctx context.Context, request []*apifault.FaultDetectRule) *apiservice.BatchWriteResponse + // DeleteFaultDetectRules delete the fault detect rule by request + DeleteFaultDetectRules(ctx context.Context, request []*apifault.FaultDetectRule) *apiservice.BatchWriteResponse + // UpdateFaultDetectRules update the fault detect rule by request + UpdateFaultDetectRules(ctx context.Context, request []*apifault.FaultDetectRule) *apiservice.BatchWriteResponse + // GetFaultDetectRules get the fault detect rule by request + GetFaultDetectRules(ctx context.Context, query map[string]string) *apiservice.BatchQueryResponse +} + +// ServiceContractOperateServer service contract operations +type ServiceContractOperateServer interface { + // CreateServiceContracts . + CreateServiceContracts(ctx context.Context, req []*apiservice.ServiceContract) *apiservice.BatchWriteResponse + // DeleteServiceContracts . + DeleteServiceContracts(ctx context.Context, req []*apiservice.ServiceContract) *apiservice.BatchWriteResponse + // GetServiceContracts . + GetServiceContracts(ctx context.Context, query map[string]string) *apiservice.BatchQueryResponse + // CreateServiceContractInterfaces . + CreateServiceContractInterfaces(ctx context.Context, contract *apiservice.ServiceContract, + source apiservice.InterfaceDescriptor_Source) *apiservice.Response + // AppendServiceContractInterfaces . + AppendServiceContractInterfaces(ctx context.Context, contract *apiservice.ServiceContract, + source apiservice.InterfaceDescriptor_Source) *apiservice.Response + // DeleteServiceContractInterfaces . + DeleteServiceContractInterfaces(ctx context.Context, contract *apiservice.ServiceContract) *apiservice.Response + // GetServiceContractVersions . + GetServiceContractVersions(ctx context.Context, filter map[string]string) *apiservice.BatchQueryResponse +} + +type DiscoverServerV1 interface { + // CircuitBreakerOperateServer Fuse rule operation interface definition + CircuitBreakerOperateServer + // RateLimitOperateServer Lamflow rule operation interface definition + RateLimitOperateServer + // RouteRuleOperateServer Routing rules operation interface definition + RouteRuleOperateServer + // RouterRuleOperateServer Routing rules operation interface definition + RouterRuleOperateServer + // FaultDetectRuleOperateServer fault detect rules operation interface definition + FaultDetectRuleOperateServer + // ServiceContractOperateServer service contract rules operation inerface definition + ServiceContractOperateServer +} diff --git a/service/circuitbreaker_rule.go b/service/circuitbreaker_rule.go index 8b3045362..fbc54b9f4 100644 --- a/service/circuitbreaker_rule.go +++ b/service/circuitbreaker_rule.go @@ -27,13 +27,10 @@ import ( "github.com/golang/protobuf/ptypes/wrappers" apifault "github.com/polarismesh/specification/source/go/api/v1/fault_tolerance" apimodel "github.com/polarismesh/specification/source/go/api/v1/model" - "github.com/polarismesh/specification/source/go/api/v1/security" apiservice "github.com/polarismesh/specification/source/go/api/v1/service_manage" - "go.uber.org/zap" api "github.com/polarismesh/polaris/common/api/v1" "github.com/polarismesh/polaris/common/model" - authcommon "github.com/polarismesh/polaris/common/model/auth" commonstore "github.com/polarismesh/polaris/common/store" commontime "github.com/polarismesh/polaris/common/time" "github.com/polarismesh/polaris/common/utils" @@ -53,15 +50,20 @@ func (s *Server) CreateCircuitBreakerRules( // CreateCircuitBreakerRule Create a CircuitBreaker rule func (s *Server) createCircuitBreakerRule( ctx context.Context, request *apifault.CircuitBreakerRule) *apiservice.Response { + requestID := utils.ParseRequestID(ctx) + if resp := checkCircuitBreakerRuleParams(request, false, true); resp != nil { + return resp + } + // 构造底层数据结构 data, err := api2CircuitBreakerRule(request) if err != nil { - log.Error(err.Error(), utils.RequestID(ctx)) + log.Error(err.Error(), utils.ZapRequestID(requestID)) return api.NewResponse(apimodel.Code_ParseCircuitBreakerException) } exists, err := s.storage.HasCircuitBreakerRuleByName(data.Name, data.Namespace) if err != nil { - log.Error(err.Error(), utils.RequestID(ctx)) + log.Error(err.Error(), utils.ZapRequestID(requestID)) return api.NewResponseWithMsg(commonstore.StoreCode2APICode(err), err.Error()) } if exists { @@ -71,23 +73,96 @@ func (s *Server) createCircuitBreakerRule( // 存储层操作 if err := s.storage.CreateCircuitBreakerRule(data); err != nil { - log.Error(err.Error(), utils.RequestID(ctx)) + log.Error(err.Error(), utils.ZapRequestID(requestID)) return api.NewResponseWithMsg(commonstore.StoreCode2APICode(err), err.Error()) } msg := fmt.Sprintf("create circuitBreaker rule: id=%v, name=%v, namespace=%v", data.ID, request.GetName(), request.GetNamespace()) - log.Info(msg, utils.RequestID(ctx)) + log.Info(msg, utils.ZapRequestID(requestID)) s.RecordHistory(ctx, circuitBreakerRuleRecordEntry(ctx, request, data, model.OCreate)) - _ = s.afterRuleResource(ctx, model.RRouting, authcommon.ResourceEntry{ - ID: request.GetId(), - Type: security.ResourceType_CircuitBreakerRules, - }, false) + request.Id = data.ID return api.NewAnyDataResponse(apimodel.Code_ExecuteSuccess, request) } +func checkCircuitBreakerRuleParams( + req *apifault.CircuitBreakerRule, idRequired bool, nameRequired bool) *apiservice.Response { + if req == nil { + return api.NewResponse(apimodel.Code_EmptyRequest) + } + if resp := checkCircuitBreakerRuleParamsDbLen(req); nil != resp { + return resp + } + if nameRequired && len(req.GetName()) == 0 { + return api.NewResponse(apimodel.Code_InvalidCircuitBreakerName) + } + if idRequired && len(req.GetId()) == 0 { + return api.NewResponse(apimodel.Code_InvalidCircuitBreakerID) + } + return nil +} + +func checkCircuitBreakerRuleParamsDbLen(req *apifault.CircuitBreakerRule) *apiservice.Response { + if err := utils.CheckDbRawStrFieldLen( + req.RuleMatcher.GetSource().GetService(), MaxDbServiceNameLength); err != nil { + return api.NewResponse(apimodel.Code_InvalidServiceName) + } + if err := utils.CheckDbRawStrFieldLen( + req.RuleMatcher.GetSource().GetNamespace(), MaxDbServiceNamespaceLength); err != nil { + return api.NewResponse(apimodel.Code_InvalidNamespaceName) + } + if err := utils.CheckDbRawStrFieldLen(req.GetName(), MaxRuleName); err != nil { + return api.NewResponse(apimodel.Code_InvalidCircuitBreakerName) + } + if err := utils.CheckDbRawStrFieldLen(req.GetNamespace(), MaxDbServiceNamespaceLength); err != nil { + return api.NewResponse(apimodel.Code_InvalidNamespaceName) + } + if err := utils.CheckDbRawStrFieldLen(req.GetDescription(), MaxCommentLength); err != nil { + return api.NewResponse(apimodel.Code_InvalidServiceComment) + } + return nil +} + +func circuitBreakerRuleRecordEntry(ctx context.Context, req *apifault.CircuitBreakerRule, md *model.CircuitBreakerRule, + opt model.OperationType) *model.RecordEntry { + marshaler := jsonpb.Marshaler{} + detail, _ := marshaler.MarshalToString(req) + entry := &model.RecordEntry{ + ResourceType: model.RCircuitBreakerRule, + ResourceName: fmt.Sprintf("%s(%s)", md.Name, md.ID), + Namespace: req.GetNamespace(), + OperationType: opt, + Operator: utils.ParseOperator(ctx), + Detail: detail, + HappenTime: time.Now(), + } + return entry +} + +var ( + // CircuitBreakerRuleFilters filter circuitbreaker rule query parameters + CircuitBreakerRuleFilters = map[string]bool{ + "brief": true, + "offset": true, + "limit": true, + "id": true, + "name": true, + "namespace": true, + "enable": true, + "level": true, + "service": true, + "serviceNamespace": true, + "srcService": true, + "srcNamespace": true, + "dstService": true, + "dstNamespace": true, + "dstMethod": true, + "description": true, + } +) + // DeleteCircuitBreakerRules Delete current CircuitBreaker rules func (s *Server) DeleteCircuitBreakerRules( ctx context.Context, request []*apifault.CircuitBreakerRule) *apiservice.BatchWriteResponse { @@ -102,7 +177,11 @@ func (s *Server) DeleteCircuitBreakerRules( // deleteCircuitBreakerRule delete current CircuitBreaker rule func (s *Server) deleteCircuitBreakerRule( ctx context.Context, request *apifault.CircuitBreakerRule) *apiservice.Response { - resp := s.checkCircuitBreakerRuleExists(ctx, request.GetId()) + requestID := utils.ParseRequestID(ctx) + if resp := checkCircuitBreakerRuleParams(request, true, false); resp != nil { + return resp + } + resp := s.checkCircuitBreakerRuleExists(request.GetId(), requestID) if resp != nil { if resp.GetCode().GetValue() == uint32(apimodel.Code_NotFoundCircuitBreaker) { resp.Code = &wrappers.UInt32Value{Value: uint32(apimodel.Code_ExecuteSuccess)} @@ -112,20 +191,16 @@ func (s *Server) deleteCircuitBreakerRule( cbRuleId := &apifault.CircuitBreakerRule{Id: request.GetId()} err := s.storage.DeleteCircuitBreakerRule(request.GetId()) if err != nil { - log.Error(err.Error(), utils.RequestID(ctx)) + log.Error(err.Error(), utils.ZapRequestID(requestID)) return api.NewAnyDataResponse(apimodel.Code_ParseCircuitBreakerException, cbRuleId) } msg := fmt.Sprintf("delete circuitbreaker rule: id=%v, name=%v, namespace=%v", request.GetId(), request.GetName(), request.GetNamespace()) - log.Info(msg, utils.RequestID(ctx)) + log.Info(msg, utils.ZapRequestID(requestID)) cbRule := &model.CircuitBreakerRule{ ID: request.GetId(), Name: request.GetName(), Namespace: request.GetNamespace()} s.RecordHistory(ctx, circuitBreakerRuleRecordEntry(ctx, request, cbRule, model.ODelete)) - _ = s.afterRuleResource(ctx, model.RRouting, authcommon.ResourceEntry{ - ID: request.GetId(), - Type: security.ResourceType_CircuitBreakerRules, - }, true) return api.NewAnyDataResponse(apimodel.Code_ExecuteSuccess, cbRuleId) } @@ -142,7 +217,11 @@ func (s *Server) EnableCircuitBreakerRules( func (s *Server) enableCircuitBreakerRule( ctx context.Context, request *apifault.CircuitBreakerRule) *apiservice.Response { - resp := s.checkCircuitBreakerRuleExists(ctx, request.GetId()) + requestID := utils.ParseRequestID(ctx) + if resp := checkCircuitBreakerRuleParams(request, true, false); resp != nil { + return resp + } + resp := s.checkCircuitBreakerRuleExists(request.GetId(), requestID) if resp != nil { return resp } @@ -155,13 +234,13 @@ func (s *Server) enableCircuitBreakerRule( Revision: utils.NewUUID(), } if err := s.storage.EnableCircuitBreakerRule(cbRule); err != nil { - log.Error(err.Error(), utils.RequestID(ctx)) + log.Error(err.Error(), utils.ZapRequestID(requestID)) return storeError2AnyResponse(err, cbRuleId) } msg := fmt.Sprintf("enable circuitbreaker rule: id=%v, name=%v, namespace=%v", request.GetId(), request.GetName(), request.GetNamespace()) - log.Info(msg, utils.RequestID(ctx)) + log.Info(msg, utils.ZapRequestID(requestID)) s.RecordHistory(ctx, circuitBreakerRuleRecordEntry(ctx, request, cbRule, model.OUpdate)) return api.NewAnyDataResponse(apimodel.Code_ExecuteSuccess, cbRuleId) @@ -180,42 +259,46 @@ func (s *Server) UpdateCircuitBreakerRules( func (s *Server) updateCircuitBreakerRule( ctx context.Context, request *apifault.CircuitBreakerRule) *apiservice.Response { - resp := s.checkCircuitBreakerRuleExists(ctx, request.GetId()) + requestID := utils.ParseRequestID(ctx) + if resp := checkCircuitBreakerRuleParams(request, true, true); resp != nil { + return resp + } + resp := s.checkCircuitBreakerRuleExists(request.GetId(), requestID) if resp != nil { return resp } cbRuleId := &apifault.CircuitBreakerRule{Id: request.GetId()} cbRule, err := api2CircuitBreakerRule(request) if err != nil { - log.Error(err.Error(), utils.RequestID(ctx)) + log.Error(err.Error(), utils.ZapRequestID(requestID)) return api.NewAnyDataResponse(apimodel.Code_ParseCircuitBreakerException, cbRuleId) } cbRule.ID = request.GetId() exists, err := s.storage.HasCircuitBreakerRuleByNameExcludeId(cbRule.Name, cbRule.Namespace, cbRule.ID) if err != nil { - log.Error(err.Error(), utils.RequestID(ctx)) + log.Error(err.Error(), utils.ZapRequestID(requestID)) return api.NewResponseWithMsg(commonstore.StoreCode2APICode(err), err.Error()) } if exists { return api.NewResponse(apimodel.Code_ServiceExistedCircuitBreakers) } if err := s.storage.UpdateCircuitBreakerRule(cbRule); err != nil { - log.Error(err.Error(), utils.RequestID(ctx)) + log.Error(err.Error(), utils.ZapRequestID(requestID)) return storeError2AnyResponse(err, cbRuleId) } msg := fmt.Sprintf("update circuitbreaker rule: id=%v, name=%v, namespace=%v", request.GetId(), request.GetName(), request.GetNamespace()) - log.Info(msg, utils.RequestID(ctx)) + log.Info(msg, utils.ZapRequestID(requestID)) s.RecordHistory(ctx, circuitBreakerRuleRecordEntry(ctx, request, cbRule, model.OUpdate)) return api.NewAnyDataResponse(apimodel.Code_ExecuteSuccess, cbRuleId) } -func (s *Server) checkCircuitBreakerRuleExists(ctx context.Context, id string) *apiservice.Response { +func (s *Server) checkCircuitBreakerRuleExists(id, requestID string) *apiservice.Response { exists, err := s.storage.HasCircuitBreakerRule(id) if err != nil { - log.Error(err.Error(), utils.RequestID(ctx)) + log.Error(err.Error(), utils.ZapRequestID(requestID)) return api.NewResponse(commonstore.StoreCode2APICode(err)) } if !exists { @@ -226,10 +309,24 @@ func (s *Server) checkCircuitBreakerRuleExists(ctx context.Context, id string) * // GetCircuitBreakerRules Query CircuitBreaker rules func (s *Server) GetCircuitBreakerRules(ctx context.Context, query map[string]string) *apiservice.BatchQueryResponse { - offset, limit, _ := utils.ParseOffsetAndLimit(query) - total, cbRules, err := s.storage.GetCircuitBreakerRules(query, offset, limit) + offset, limit, err := utils.ParseOffsetAndLimit(query) if err != nil { - log.Error("get circuitbreaker rules store", utils.RequestID(ctx), zap.Error(err)) + return api.NewBatchQueryResponse(apimodel.Code_InvalidParameter) + } + searchFilter := make(map[string]string, len(query)) + for key, value := range query { + if _, ok := CircuitBreakerRuleFilters[key]; !ok { + log.Errorf("params %s is not allowed in querying circuitbreaker rule", key) + return api.NewBatchQueryResponse(apimodel.Code_InvalidParameter) + } + if value == "" { + continue + } + searchFilter[key] = value + } + total, cbRules, err := s.storage.GetCircuitBreakerRules(searchFilter, offset, limit) + if err != nil { + log.Errorf("get circuitbreaker rules store err: %s", err.Error()) return api.NewBatchQueryResponse(commonstore.StoreCode2APICode(err)) } out := api.NewBatchQueryResponse(apimodel.Code_ExecuteSuccess) @@ -238,7 +335,7 @@ func (s *Server) GetCircuitBreakerRules(ctx context.Context, query map[string]st for _, cbRule := range cbRules { cbRuleProto, err := circuitBreakerRule2api(cbRule) if nil != err { - log.Error("marshal circuitbreaker rule fail", utils.RequestID(ctx), zap.Error(err)) + log.Errorf("marshal circuitbreaker rule fail: %v", err) continue } if nil == cbRuleProto { @@ -246,34 +343,13 @@ func (s *Server) GetCircuitBreakerRules(ctx context.Context, query map[string]st } err = api.AddAnyDataIntoBatchQuery(out, cbRuleProto) if nil != err { - log.Error("add circuitbreaker rule as any data fail", utils.RequestID(ctx), zap.Error(err)) + log.Errorf("add circuitbreaker rule as any data fail: %v", err) continue } } return out } -// GetAllCircuitBreakerRules Query all router_rule rules -func (s *Server) GetAllCircuitBreakerRules(ctx context.Context) *apiservice.BatchQueryResponse { - return nil -} - -func circuitBreakerRuleRecordEntry(ctx context.Context, req *apifault.CircuitBreakerRule, md *model.CircuitBreakerRule, - opt model.OperationType) *model.RecordEntry { - marshaler := jsonpb.Marshaler{} - detail, _ := marshaler.MarshalToString(req) - entry := &model.RecordEntry{ - ResourceType: model.RCircuitBreakerRule, - ResourceName: fmt.Sprintf("%s(%s)", md.Name, md.ID), - Namespace: req.GetNamespace(), - OperationType: opt, - Operator: utils.ParseOperator(ctx), - Detail: detail, - HappenTime: time.Now(), - } - return entry -} - func marshalCircuitBreakerRuleV2(req *apifault.CircuitBreakerRule) (string, error) { r := &apifault.CircuitBreakerRule{ RuleMatcher: req.RuleMatcher, diff --git a/service/circuitbreaker_rule_test.go b/service/circuitbreaker_rule_test.go index 9e8ec5bab..07d9df31a 100644 --- a/service/circuitbreaker_rule_test.go +++ b/service/circuitbreaker_rule_test.go @@ -316,7 +316,7 @@ func TestUpdateCircuitBreakerRule(t *testing.T) { testRule := cbRules[0] testRule.Description = mockDescr resp = discoverSuit.DiscoverServer().UpdateCircuitBreakerRules(discoverSuit.DefaultCtx, []*apifault.CircuitBreakerRule{testRule}) - assert.Equal(t, uint32(apimodel.Code_ExecuteSuccess), resp.GetCode().GetValue(), resp.GetInfo().GetValue()) + assert.Equal(t, uint32(apimodel.Code_ExecuteSuccess), resp.GetCode().GetValue()) qResp = queryCircuitBreakerRules(discoverSuit, map[string]string{"id": testRule.Id}) assert.Equal(t, uint32(apimodel.Code_ExecuteSuccess), qResp.GetCode().GetValue()) diff --git a/service/client_info.go b/service/client_info.go index 105b79e57..bf2ea3854 100644 --- a/service/client_info.go +++ b/service/client_info.go @@ -80,9 +80,12 @@ func (s *Server) createClient(ctx context.Context, req *apiservice.Client) (*mod // req 原始请求 // ins 包含了req数据与instanceID,serviceToken func (s *Server) asyncCreateClient(ctx context.Context, req *apiservice.Client) (*model.Client, *apiservice.Response) { + rid := utils.ParseRequestID(ctx) + pid := utils.ParsePlatformID(ctx) future := s.bc.AsyncRegisterClient(req) if err := future.Wait(); err != nil { - log.Error("[Server][ReportClient] async create client", zap.Error(err), utils.RequestID(ctx)) + log.Error("[Server][ReportClient] async create client", zap.Error(err), utils.ZapRequestID(rid), + utils.ZapPlatformID(pid)) if future.Code() == apimodel.Code_ExistedResource { req.Id = utils.NewStringValue(req.GetId().GetValue()) } diff --git a/service/client_v1.go b/service/client_v1.go index 03924acc2..ec5488852 100644 --- a/service/client_v1.go +++ b/service/client_v1.go @@ -175,9 +175,9 @@ func (s *Server) GetServiceWithCache(ctx context.Context, req *apiservice.Servic ) if req.GetNamespace().GetValue() != "" { - revision, svcs = s.Cache().Service().ListServices(ctx, req.GetNamespace().GetValue()) + revision, svcs = s.Cache().Service().ListServices(req.GetNamespace().GetValue()) } else { - revision, svcs = s.Cache().Service().ListAllServices(ctx) + revision, svcs = s.Cache().Service().ListAllServices() } if revision == "" { return resp @@ -226,29 +226,12 @@ func (s *Server) ServiceInstancesCache(ctx context.Context, filter *apiservice.D revisions := make([]string, 0, len(visibleServices)+1) finalInstances := make(map[string]*apiservice.Instance, 128) for _, svc := range visibleServices { - specSvc := &apiservice.Service{ - Id: utils.NewStringValue(svc.ID), - Name: utils.NewStringValue(svc.Name), - Namespace: utils.NewStringValue(svc.Namespace), - } - ret := s.caches.Instance().DiscoverServiceInstances(specSvc.GetId().GetValue(), filter.GetOnlyHealthyInstance()) - // 如果是空实例,则直接跳过,不处理实例列表以及 revision 信息 - if len(ret) == 0 { - continue - } revision := s.caches.Service().GetRevisionWorker().GetServiceInstanceRevision(svc.ID) if revision == "" { revision = utils.NewUUID() } revisions = append(revisions, revision) - - for i := range ret { - copyIns := s.getInstance(specSvc, ret[i].Proto) - // 注意:这里的value是cache的,不修改cache的数据,通过getInstance,浅拷贝一份数据 - finalInstances[copyIns.GetId().GetValue()] = copyIns - } } - aggregateRevision, err := cachetypes.CompositeComputeRevision(revisions) if err != nil { log.Errorf("[Server][Service][Instance] compute multi revision service(%s) err: %s", @@ -259,6 +242,20 @@ func (s *Server) ServiceInstancesCache(ctx context.Context, filter *apiservice.D return api.NewDiscoverInstanceResponse(apimodel.Code_DataNoChange, req) } + for _, svc := range visibleServices { + specSvc := &apiservice.Service{ + Id: utils.NewStringValue(svc.ID), + Name: utils.NewStringValue(svc.Name), + Namespace: utils.NewStringValue(svc.Namespace), + } + ret := s.caches.Instance().DiscoverServiceInstances(specSvc.GetId().GetValue(), filter.GetOnlyHealthyInstance()) + for i := range ret { + copyIns := s.getInstance(specSvc, ret[i].Proto) + // 注意:这里的value是cache的,不修改cache的数据,通过getInstance,浅拷贝一份数据 + finalInstances[copyIns.GetId().GetValue()] = copyIns + } + } + // 填充service数据 resp.Service = service2Api(aliasFor) // 这里需要把服务信息改为用户请求的服务名以及命名空间 @@ -280,13 +277,20 @@ func (s *Server) findVisibleServices(serviceName, namespaceName string, req *api visibleServices := make([]*model.Service, 0, 4) // 数据源都来自Cache,这里拿到的service,已经是源服务 aliasFor := s.getServiceCache(serviceName, namespaceName) - if aliasFor != nil { - visibleServices = append(visibleServices, aliasFor) - } - ret := s.caches.Service().GetVisibleServicesInOtherNamespace(serviceName, namespaceName) - if len(ret) > 0 { + if aliasFor == nil { + aliasFor = &model.Service{ + Name: serviceName, + Namespace: namespaceName, + } + ret := s.caches.Service().GetVisibleServicesInOtherNamespace(serviceName, namespaceName) + if len(ret) == 0 { + return nil, nil + } visibleServices = append(visibleServices, ret...) + } else { + visibleServices = append(visibleServices, aliasFor) } + return aliasFor, visibleServices } diff --git a/service/common_test.go b/service/common_test.go index c17ddc342..56262c522 100644 --- a/service/common_test.go +++ b/service/common_test.go @@ -350,7 +350,7 @@ func (d *DiscoverTestSuit) createCommonRoutingConfig( // TODO 是否应该先删除routing resp := d.DiscoverServer().CreateRoutingConfigs(d.DefaultCtx, []*apitraffic.Routing{conf}) - if respSuccess(resp) { + if !respSuccess(resp) { t.Fatalf("error: %+v", resp) } @@ -400,7 +400,7 @@ func (d *DiscoverTestSuit) createCommonRoutingConfigV1IntoOldStore(t *testing.T, } resp := d.OriginDiscoverServer().(*service.Server).CreateRoutingConfig(d.DefaultCtx, conf) - if respSuccess(resp) { + if !respSuccess(resp) { t.Fatalf("error: %+v", resp) } diff --git a/service/default.go b/service/default.go index 1efa4bc14..5609dc9e9 100644 --- a/service/default.go +++ b/service/default.go @@ -28,7 +28,6 @@ import ( "github.com/polarismesh/polaris/common/eventhub" "github.com/polarismesh/polaris/common/model" "github.com/polarismesh/polaris/plugin" - "github.com/polarismesh/polaris/store" ) const ( @@ -49,7 +48,7 @@ const ( DefaultTLL = 5 ) -type ServerProxyFactory func(pre DiscoverServer, s store.Store) (DiscoverServer, error) +type ServerProxyFactory func(pre DiscoverServer) (DiscoverServer, error) var ( server DiscoverServer @@ -136,7 +135,7 @@ func InitServer(ctx context.Context, namingOpt *Config, opts ...InitOption) (*Se return nil, nil, fmt.Errorf("name(%s) not exist in serverProxyFactories", order[i]) } - afterSvr, err := factory(proxySvr, actualSvr.storage) + afterSvr, err := factory(proxySvr) if err != nil { return nil, nil, err } @@ -145,10 +144,6 @@ func InitServer(ctx context.Context, namingOpt *Config, opts ...InitOption) (*Se return actualSvr, proxySvr, nil } -func (svr *Server) Initialize(context.Context, store.Store) error { - return nil -} - type PluginInstanceEventHandler struct { *BaseInstanceEventHandler subscriber plugin.DiscoverChannel diff --git a/service/faultdetect_config.go b/service/faultdetect_config.go index 23305de47..32ab8d5de 100644 --- a/service/faultdetect_config.go +++ b/service/faultdetect_config.go @@ -27,14 +27,12 @@ import ( "github.com/golang/protobuf/ptypes/wrappers" apifault "github.com/polarismesh/specification/source/go/api/v1/fault_tolerance" apimodel "github.com/polarismesh/specification/source/go/api/v1/model" - "github.com/polarismesh/specification/source/go/api/v1/security" apiservice "github.com/polarismesh/specification/source/go/api/v1/service_manage" "go.uber.org/zap" cachetypes "github.com/polarismesh/polaris/cache/api" api "github.com/polarismesh/polaris/common/api/v1" "github.com/polarismesh/polaris/common/model" - authcommon "github.com/polarismesh/polaris/common/model/auth" commonstore "github.com/polarismesh/polaris/common/store" commontime "github.com/polarismesh/polaris/common/time" "github.com/polarismesh/polaris/common/utils" @@ -42,9 +40,9 @@ import ( // CreateFaultDetectRules Create a FaultDetect rule func (s *Server) CreateFaultDetectRules( - ctx context.Context, reqs []*apifault.FaultDetectRule) *apiservice.BatchWriteResponse { + ctx context.Context, request []*apifault.FaultDetectRule) *apiservice.BatchWriteResponse { responses := api.NewBatchWriteResponse(apimodel.Code_ExecuteSuccess) - for _, cbRule := range reqs { + for _, cbRule := range request { response := s.createFaultDetectRule(ctx, cbRule) api.Collect(responses, response) } @@ -53,10 +51,10 @@ func (s *Server) CreateFaultDetectRules( // DeleteFaultDetectRules Delete current Fault Detect rules func (s *Server) DeleteFaultDetectRules( - ctx context.Context, reqs []*apifault.FaultDetectRule) *apiservice.BatchWriteResponse { + ctx context.Context, request []*apifault.FaultDetectRule) *apiservice.BatchWriteResponse { responses := api.NewBatchWriteResponse(apimodel.Code_ExecuteSuccess) - for _, cbRule := range reqs { + for _, cbRule := range request { response := s.deleteFaultDetectRule(ctx, cbRule) api.Collect(responses, response) } @@ -65,10 +63,10 @@ func (s *Server) DeleteFaultDetectRules( // UpdateFaultDetectRules Modify the FaultDetect rule func (s *Server) UpdateFaultDetectRules( - ctx context.Context, reqs []*apifault.FaultDetectRule) *apiservice.BatchWriteResponse { + ctx context.Context, request []*apifault.FaultDetectRule) *apiservice.BatchWriteResponse { responses := api.NewBatchWriteResponse(apimodel.Code_ExecuteSuccess) - for _, cbRule := range reqs { + for _, cbRule := range request { response := s.updateFaultDetectRule(ctx, cbRule) api.Collect(responses, response) } @@ -119,10 +117,7 @@ func (s *Server) createFaultDetectRule(ctx context.Context, request *apifault.Fa log.Info(msg, utils.RequestID(ctx)) s.RecordHistory(ctx, faultDetectRuleRecordEntry(ctx, request, data, model.OCreate)) - _ = s.afterRuleResource(ctx, model.RRouting, authcommon.ResourceEntry{ - ID: request.GetId(), - Type: security.ResourceType_FaultDetectRules, - }, false) + request.Id = data.ID return api.NewAnyDataResponse(apimodel.Code_ExecuteSuccess, request) } @@ -159,31 +154,54 @@ func (s *Server) updateFaultDetectRule(ctx context.Context, request *apifault.Fa // deleteFaultDetectRule Delete a FaultDetect rule func (s *Server) deleteFaultDetectRule(ctx context.Context, request *apifault.FaultDetectRule) *apiservice.Response { + requestID := utils.ParseRequestID(ctx) + resp := s.checkFaultDetectRuleExists(request.GetId(), requestID) + if resp != nil { + if resp.GetCode().GetValue() == uint32(apimodel.Code_NotFoundResource) { + resp.Code = &wrappers.UInt32Value{Value: uint32(apimodel.Code_ExecuteSuccess)} + } + return resp + } cbRuleId := &apifault.FaultDetectRule{Id: request.GetId()} err := s.storage.DeleteFaultDetectRule(request.GetId()) if err != nil { - log.Error(err.Error(), utils.RequestID(ctx)) + log.Error(err.Error(), utils.ZapRequestID(requestID)) return api.NewAnyDataResponse(apimodel.Code_ParseException, cbRuleId) } msg := fmt.Sprintf("delete fault detect rule: id=%v, name=%v, namespace=%v", request.GetId(), request.GetName(), request.GetNamespace()) - log.Info(msg, utils.RequestID(ctx)) + log.Info(msg, utils.ZapRequestID(requestID)) cbRule := &model.FaultDetectRule{ID: request.GetId(), Name: request.GetName(), Namespace: request.GetNamespace()} s.RecordHistory(ctx, faultDetectRuleRecordEntry(ctx, request, cbRule, model.ODelete)) - _ = s.afterRuleResource(ctx, model.RRouting, authcommon.ResourceEntry{ - ID: request.GetId(), - Type: security.ResourceType_FaultDetectRules, - }, true) return api.NewAnyDataResponse(apimodel.Code_ExecuteSuccess, cbRuleId) } +func (s *Server) checkFaultDetectRuleExists(id, requestID string) *apiservice.Response { + exists, err := s.storage.HasFaultDetectRule(id) + if err != nil { + log.Error(err.Error(), utils.ZapRequestID(requestID)) + return api.NewResponse(commonstore.StoreCode2APICode(err)) + } + if !exists { + return api.NewResponse(apimodel.Code_NotFoundResource) + } + return nil +} + func (s *Server) GetFaultDetectRules(ctx context.Context, query map[string]string) *apiservice.BatchQueryResponse { offset, limit, _ := utils.ParseOffsetAndLimit(query) total, cbRules, err := s.caches.FaultDetector().Query(ctx, &cachetypes.FaultDetectArgs{ - Filter: query, - Offset: offset, - Limit: limit, + ID: query["id"], + Name: query["name"], + Namespace: query["namespace"], + Service: query["service"], + ServiceNamespace: query["serviceNamespace"], + DstNamespace: query["dstNamespace"], + DstService: query["dstService"], + DstMethod: query["dstMethod"], + Offset: offset, + Limit: limit, }) if err != nil { log.Errorf("get fault detect rules store err: %s", err.Error()) @@ -201,7 +219,8 @@ func (s *Server) GetFaultDetectRules(ctx context.Context, query map[string]strin if nil == cbRuleProto { continue } - if err = api.AddAnyDataIntoBatchQuery(out, cbRuleProto); nil != err { + err = api.AddAnyDataIntoBatchQuery(out, cbRuleProto) + if nil != err { log.Error("add circuitbreaker rule as any data fail", utils.RequestID(ctx), zap.Error(err)) continue } @@ -209,11 +228,6 @@ func (s *Server) GetFaultDetectRules(ctx context.Context, query map[string]strin return out } -// GetAllFaultDetectRules Query all router_rule rules -func (s *Server) GetAllFaultDetectRules(ctx context.Context) *apiservice.BatchQueryResponse { - return nil -} - func marshalFaultDetectRule(req *apifault.FaultDetectRule) (string, error) { r := &apifault.FaultDetectRule{ TargetService: req.TargetService, @@ -248,7 +262,6 @@ func api2FaultDetectRule(req *apifault.FaultDetectRule) (*model.FaultDetectRule, DstMethod: req.GetTargetService().GetMethod().GetValue().GetValue(), Rule: rule, Revision: utils.NewUUID(), - Metadata: req.Metadata, } if out.Namespace == "" { out.Namespace = DefaultNamespace @@ -260,29 +273,27 @@ func faultDetectRule2api(fdRule *model.FaultDetectRule) (*apifault.FaultDetectRu if fdRule == nil { return nil, nil } - specData := &apifault.FaultDetectRule{} + fdRule.Proto = &apifault.FaultDetectRule{} if len(fdRule.Rule) > 0 { - if err := json.Unmarshal([]byte(fdRule.Rule), specData); err != nil { + if err := json.Unmarshal([]byte(fdRule.Rule), fdRule.Proto); err != nil { return nil, err } } else { // brief search, to display the services in list result - specData.TargetService = &apifault.FaultDetectRule_DestinationService{ + fdRule.Proto.TargetService = &apifault.FaultDetectRule_DestinationService{ Service: fdRule.DstService, Namespace: fdRule.DstNamespace, Method: &apimodel.MatchString{Value: &wrappers.StringValue{Value: fdRule.DstMethod}}, } } - specData.Id = fdRule.ID - specData.Name = fdRule.Name - specData.Namespace = fdRule.Namespace - specData.Description = fdRule.Description - specData.Revision = fdRule.Revision - specData.Ctime = commontime.Time2String(fdRule.CreateTime) - specData.Mtime = commontime.Time2String(fdRule.ModifyTime) - specData.Editable = true - specData.Deleteable = true - return specData, nil + fdRule.Proto.Id = fdRule.ID + fdRule.Proto.Name = fdRule.Name + fdRule.Proto.Namespace = fdRule.Namespace + fdRule.Proto.Description = fdRule.Description + fdRule.Proto.Revision = fdRule.Revision + fdRule.Proto.Ctime = commontime.Time2String(fdRule.CreateTime) + fdRule.Proto.Mtime = commontime.Time2String(fdRule.ModifyTime) + return fdRule.Proto, nil } // faultDetectRule2ClientAPI 把内部数据结构转化为客户端API参数 diff --git a/service/instance.go b/service/instance.go index aec69fa05..13681b7f7 100644 --- a/service/instance.go +++ b/service/instance.go @@ -592,10 +592,6 @@ func (s *Server) updateInstanceAttribute( } func instanceLocationNeedUpdate(req *apimodel.Location, old *apimodel.Location) bool { - // 如果没有带上,则不进行更新 - if req == nil { - return false - } if req.GetRegion().GetValue() != old.GetRegion().GetValue() { return true } @@ -867,7 +863,7 @@ func (s *Server) instanceAuth(ctx context.Context, req *apiservice.Instance, ser *model.Service, *apiservice.Response) { service, err := s.storage.GetServiceByID(serviceID) if err != nil { - log.Error(err.Error(), utils.RequestID(ctx)) + log.Error(err.Error(), utils.ZapRequestID(utils.ParseRequestID(ctx))) return nil, api.NewInstanceResponse(commonstore.StoreCode2APICode(err), req) } if service == nil { diff --git a/service/instance_test.go b/service/instance_test.go index 69d5b2bd2..a5d25008b 100644 --- a/service/instance_test.go +++ b/service/instance_test.go @@ -324,7 +324,15 @@ func TestUpdateInstanceManyTimes(t *testing.T) { go func(index int) { defer wg.Done() for c := 0; c < 16; c++ { - ret := proto.Clone(instanceReq).(*apiservice.Instance) + marshalVal, err := proto.Marshal(instanceReq) + if err != nil { + errs <- err + return + } + + ret := &apiservice.Instance{} + proto.Unmarshal(marshalVal, ret) + ret.Weight = wrapperspb.UInt32(uint32(rand.Int() % 32767)) if updateResp := discoverSuit.DiscoverServer().UpdateInstances(discoverSuit.DefaultCtx, []*apiservice.Instance{instanceReq}); !respSuccess(updateResp) { errs <- fmt.Errorf("error: %+v", updateResp) diff --git a/service/interceptor/auth/circuitbreaker_rule.go b/service/interceptor/auth/circuitbreaker_rule.go index 99018c85b..99201c302 100644 --- a/service/interceptor/auth/circuitbreaker_rule.go +++ b/service/interceptor/auth/circuitbreaker_rule.go @@ -22,10 +22,7 @@ import ( apifault "github.com/polarismesh/specification/source/go/api/v1/fault_tolerance" "github.com/polarismesh/specification/source/go/api/v1/security" - apisecurity "github.com/polarismesh/specification/source/go/api/v1/security" apiservice "github.com/polarismesh/specification/source/go/api/v1/service_manage" - "google.golang.org/protobuf/proto" - "google.golang.org/protobuf/types/known/anypb" cachetypes "github.com/polarismesh/polaris/cache/api" api "github.com/polarismesh/polaris/common/api/v1" @@ -36,8 +33,7 @@ import ( func (svr *Server) CreateCircuitBreakerRules( ctx context.Context, request []*apifault.CircuitBreakerRule) *apiservice.BatchWriteResponse { - authCtx := svr.collectCircuitBreakerRuleV2(ctx, request, authcommon.Create, - authcommon.CreateCircuitBreakerRules) + authCtx := svr.collectCircuitBreakerRuleV2AuthContext(ctx, request, authcommon.Create, authcommon.CreateCircuitBreakerRules) if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { return api.NewBatchWriteResponse(authcommon.ConvertToErrCode(err)) @@ -50,8 +46,7 @@ func (svr *Server) CreateCircuitBreakerRules( func (svr *Server) DeleteCircuitBreakerRules( ctx context.Context, request []*apifault.CircuitBreakerRule) *apiservice.BatchWriteResponse { - authCtx := svr.collectCircuitBreakerRuleV2(ctx, request, authcommon.Delete, - authcommon.DeleteCircuitBreakerRules) + authCtx := svr.collectCircuitBreakerRuleV2AuthContext(ctx, request, authcommon.Delete, authcommon.DeleteCircuitBreakerRules) if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { return api.NewBatchWriteResponse(authcommon.ConvertToErrCode(err)) } @@ -63,8 +58,7 @@ func (svr *Server) DeleteCircuitBreakerRules( func (svr *Server) EnableCircuitBreakerRules( ctx context.Context, request []*apifault.CircuitBreakerRule) *apiservice.BatchWriteResponse { - authCtx := svr.collectCircuitBreakerRuleV2(ctx, request, authcommon.Modify, - authcommon.EnableCircuitBreakerRules) + authCtx := svr.collectCircuitBreakerRuleV2AuthContext(ctx, request, authcommon.Modify, authcommon.EnableCircuitBreakerRules) if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { return api.NewBatchWriteResponse(authcommon.ConvertToErrCode(err)) } @@ -76,8 +70,7 @@ func (svr *Server) EnableCircuitBreakerRules( func (svr *Server) UpdateCircuitBreakerRules( ctx context.Context, request []*apifault.CircuitBreakerRule) *apiservice.BatchWriteResponse { - authCtx := svr.collectCircuitBreakerRuleV2(ctx, request, authcommon.Modify, - authcommon.UpdateCircuitBreakerRules) + authCtx := svr.collectCircuitBreakerRuleV2AuthContext(ctx, request, authcommon.Modify, authcommon.UpdateCircuitBreakerRules) if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { return api.NewBatchWriteResponse(authcommon.ConvertToErrCode(err)) } @@ -89,8 +82,7 @@ func (svr *Server) UpdateCircuitBreakerRules( func (svr *Server) GetCircuitBreakerRules( ctx context.Context, query map[string]string) *apiservice.BatchQueryResponse { - authCtx := svr.collectCircuitBreakerRuleV2(ctx, nil, authcommon.Read, - authcommon.DescribeCircuitBreakerRules) + authCtx := svr.collectCircuitBreakerRuleV2AuthContext(ctx, nil, authcommon.Read, authcommon.DescribeCircuitBreakerRules) if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { return api.NewBatchQueryResponse(authcommon.ConvertToErrCode(err)) } @@ -98,48 +90,13 @@ func (svr *Server) GetCircuitBreakerRules( ctx = authCtx.GetRequestContext() ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) - ctx = cachetypes.AppendCircuitBreakerRulePredicate(ctx, - func(ctx context.Context, cbr *model.CircuitBreakerRule) bool { - return svr.policySvr.GetAuthChecker().ResourcePredicate(authCtx, &authcommon.ResourceEntry{ - Type: security.ResourceType_CircuitBreakerRules, - ID: cbr.ID, - Metadata: cbr.Proto.Metadata, - }) - }) - authCtx.SetRequestContext(ctx) - - resp := svr.nextSvr.GetCircuitBreakerRules(ctx, query) - - for index := range resp.Data { - item := &apifault.CircuitBreakerRule{} - _ = anypb.UnmarshalTo(resp.Data[index], item, proto.UnmarshalOptions{}) - authCtx.SetAccessResources(map[security.ResourceType][]authcommon.ResourceEntry{ - security.ResourceType_CircuitBreakerRules: { - { - Type: apisecurity.ResourceType_CircuitBreakerRules, - ID: item.GetId(), - Metadata: item.Metadata, - }, - }, + cachetypes.AppendCircuitBreakerRulePredicate(ctx, func(ctx context.Context, cbr *model.CircuitBreakerRule) bool { + return svr.policySvr.GetAuthChecker().ResourcePredicate(authCtx, &authcommon.ResourceEntry{ + Type: security.ResourceType_CircuitBreakerRules, + ID: cbr.ID, + Metadata: cbr.Proto.Metadata, }) + }) - // 检查 write 操作权限 - authCtx.SetMethod([]authcommon.ServerFunctionName{ - authcommon.UpdateCircuitBreakerRules, - authcommon.EnableCircuitBreakerRules, - }) - // 如果检查不通过,设置 editable 为 false - if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - item.Editable = false - } - - // 检查 delete 操作权限 - authCtx.SetMethod([]authcommon.ServerFunctionName{authcommon.DeleteCircuitBreakerRules}) - // 如果检查不通过,设置 editable 为 false - if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - item.Deleteable = false - } - _ = anypb.MarshalFrom(resp.Data[index], item, proto.MarshalOptions{}) - } - return resp + return svr.nextSvr.GetCircuitBreakerRules(ctx, query) } diff --git a/service/interceptor/auth/client_v1.go b/service/interceptor/auth/client_v1.go index 4dd6a0fb3..fb1fc8dec 100644 --- a/service/interceptor/auth/client_v1.go +++ b/service/interceptor/auth/client_v1.go @@ -20,11 +20,9 @@ package service_auth import ( "context" - "github.com/polarismesh/specification/source/go/api/v1/security" apiservice "github.com/polarismesh/specification/source/go/api/v1/service_manage" "google.golang.org/protobuf/types/known/wrapperspb" - cachetypes "github.com/polarismesh/polaris/cache/api" api "github.com/polarismesh/polaris/common/api/v1" "github.com/polarismesh/polaris/common/model" authcommon "github.com/polarismesh/polaris/common/model/auth" @@ -53,8 +51,10 @@ func (svr *Server) DeregisterInstance(ctx context.Context, req *apiservice.Insta authCtx := svr.collectClientInstanceAuthContext( ctx, []*apiservice.Instance{req}, authcommon.Create, authcommon.DeregisterInstance) - if _, err := svr.policySvr.GetAuthChecker().CheckClientPermission(authCtx); err != nil { - return api.NewResponse(authcommon.ConvertToErrCode(err)) + _, err := svr.policySvr.GetAuthChecker().CheckClientPermission(authCtx) + if err != nil { + resp := api.NewResponseWithMsg(authcommon.ConvertToErrCode(err), err.Error()) + return resp } ctx = authCtx.GetRequestContext() @@ -76,8 +76,10 @@ func (svr *Server) ReportServiceContract(ctx context.Context, req *apiservice.Se Namespace: wrapperspb.String(req.GetNamespace()), }}, authcommon.Create, authcommon.ReportServiceContract) - if _, err := svr.policySvr.GetAuthChecker().CheckClientPermission(authCtx); err != nil { - return api.NewResponse(authcommon.ConvertToErrCode(err)) + _, err := svr.policySvr.GetAuthChecker().CheckClientPermission(authCtx) + if err != nil { + resp := api.NewResponseWithMsg(authcommon.ConvertToErrCode(err), err.Error()) + return resp } ctx = authCtx.GetRequestContext() @@ -98,21 +100,15 @@ func (svr *Server) GetServiceWithCache( authCtx := svr.collectServiceAuthContext( ctx, []*apiservice.Service{req}, authcommon.Read, authcommon.DiscoverServices) - if _, err := svr.policySvr.GetAuthChecker().CheckClientPermission(authCtx); err != nil { - return api.NewDiscoverResponse(authcommon.ConvertToErrCode(err)) + _, err := svr.policySvr.GetAuthChecker().CheckClientPermission(authCtx) + if err != nil { + resp := api.NewDiscoverResponse(authcommon.ConvertToErrCode(err)) + resp.Info = utils.NewStringValue(err.Error()) + return resp } ctx = authCtx.GetRequestContext() ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) - ctx = cachetypes.AppendServicePredicate(ctx, func(ctx context.Context, cbr *model.Service) bool { - return svr.policySvr.GetAuthChecker().ResourcePredicate(authCtx, &authcommon.ResourceEntry{ - Type: security.ResourceType_Services, - ID: cbr.ID, - Metadata: cbr.Meta, - }) - }) - authCtx.SetRequestContext(ctx) - return svr.nextSvr.GetServiceWithCache(ctx, req) } @@ -122,8 +118,11 @@ func (svr *Server) ServiceInstancesCache( authCtx := svr.collectServiceAuthContext( ctx, []*apiservice.Service{req}, authcommon.Read, authcommon.DiscoverInstances) - if _, err := svr.policySvr.GetAuthChecker().CheckClientPermission(authCtx); err != nil { - return api.NewDiscoverResponse(authcommon.ConvertToErrCode(err)) + _, err := svr.policySvr.GetAuthChecker().CheckClientPermission(authCtx) + if err != nil { + resp := api.NewDiscoverResponse(authcommon.ConvertToErrCode(err)) + resp.Info = utils.NewStringValue(err.Error()) + return resp } ctx = authCtx.GetRequestContext() ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) @@ -137,8 +136,11 @@ func (svr *Server) GetRoutingConfigWithCache( authCtx := svr.collectServiceAuthContext( ctx, []*apiservice.Service{req}, authcommon.Read, authcommon.DiscoverRouterRule) - if _, err := svr.policySvr.GetAuthChecker().CheckClientPermission(authCtx); err != nil { - return api.NewDiscoverResponse(authcommon.ConvertToErrCode(err)) + _, err := svr.policySvr.GetAuthChecker().CheckClientPermission(authCtx) + if err != nil { + resp := api.NewDiscoverResponse(authcommon.ConvertToErrCode(err)) + resp.Info = utils.NewStringValue(err.Error()) + return resp } ctx = authCtx.GetRequestContext() ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) @@ -152,8 +154,11 @@ func (svr *Server) GetRateLimitWithCache( authCtx := svr.collectServiceAuthContext( ctx, []*apiservice.Service{req}, authcommon.Read, authcommon.DiscoverRateLimitRule) - if _, err := svr.policySvr.GetAuthChecker().CheckClientPermission(authCtx); err != nil { - return api.NewDiscoverResponse(authcommon.ConvertToErrCode(err)) + _, err := svr.policySvr.GetAuthChecker().CheckClientPermission(authCtx) + if err != nil { + resp := api.NewDiscoverResponse(authcommon.ConvertToErrCode(err)) + resp.Info = utils.NewStringValue(err.Error()) + return resp } ctx = authCtx.GetRequestContext() ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) @@ -167,8 +172,11 @@ func (svr *Server) GetCircuitBreakerWithCache( authCtx := svr.collectServiceAuthContext( ctx, []*apiservice.Service{req}, authcommon.Read, authcommon.DiscoverCircuitBreakerRule) - if _, err := svr.policySvr.GetAuthChecker().CheckClientPermission(authCtx); err != nil { - return api.NewDiscoverResponse(authcommon.ConvertToErrCode(err)) + _, err := svr.policySvr.GetAuthChecker().CheckClientPermission(authCtx) + if err != nil { + resp := api.NewDiscoverResponse(authcommon.ConvertToErrCode(err)) + resp.Info = utils.NewStringValue(err.Error()) + return resp } ctx = authCtx.GetRequestContext() ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) @@ -176,14 +184,16 @@ func (svr *Server) GetCircuitBreakerWithCache( return svr.nextSvr.GetCircuitBreakerWithCache(ctx, req) } -// GetFaultDetectWithCache 获取主动探测规则列表 func (svr *Server) GetFaultDetectWithCache( ctx context.Context, req *apiservice.Service) *apiservice.DiscoverResponse { authCtx := svr.collectServiceAuthContext( ctx, []*apiservice.Service{req}, authcommon.Read, authcommon.DiscoverFaultDetectRule) - if _, err := svr.policySvr.GetAuthChecker().CheckClientPermission(authCtx); err != nil { - return api.NewDiscoverResponse(authcommon.ConvertToErrCode(err)) + _, err := svr.policySvr.GetAuthChecker().CheckClientPermission(authCtx) + if err != nil { + resp := api.NewDiscoverResponse(authcommon.ConvertToErrCode(err)) + resp.Info = utils.NewStringValue(err.Error()) + return resp } ctx = authCtx.GetRequestContext() ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) @@ -196,8 +206,10 @@ func (svr *Server) UpdateInstance(ctx context.Context, req *apiservice.Instance) authCtx := svr.collectClientInstanceAuthContext( ctx, []*apiservice.Instance{req}, authcommon.Modify, authcommon.UpdateInstance) - if _, err := svr.policySvr.GetAuthChecker().CheckClientPermission(authCtx); err != nil { - return api.NewResponse(authcommon.ConvertToErrCode(err)) + _, err := svr.policySvr.GetAuthChecker().CheckClientPermission(authCtx) + if err != nil { + resp := api.NewResponseWithMsg(authcommon.ConvertToErrCode(err), err.Error()) + return resp } ctx = authCtx.GetRequestContext() @@ -214,8 +226,11 @@ func (svr *Server) GetServiceContractWithCache(ctx context.Context, Name: wrapperspb.String(req.Service), }}, authcommon.Read, authcommon.DiscoverServiceContract) - if _, err := svr.policySvr.GetAuthChecker().CheckClientPermission(authCtx); err != nil { - return api.NewResponse(authcommon.ConvertToErrCode(err)) + _, err := svr.policySvr.GetAuthChecker().CheckClientPermission(authCtx) + if err != nil { + resp := api.NewResponse(authcommon.ConvertToErrCode(err)) + resp.Info = utils.NewStringValue(err.Error()) + return resp } ctx = authCtx.GetRequestContext() @@ -228,8 +243,11 @@ func (svr *Server) GetServiceContractWithCache(ctx context.Context, func (svr *Server) GetLaneRuleWithCache(ctx context.Context, req *apiservice.Service) *apiservice.DiscoverResponse { authCtx := svr.collectServiceAuthContext( ctx, []*apiservice.Service{req}, authcommon.Read, authcommon.DiscoverLaneRule) - if _, err := svr.policySvr.GetAuthChecker().CheckClientPermission(authCtx); err != nil { - return api.NewDiscoverResponse(authcommon.ConvertToErrCode(err)) + _, err := svr.policySvr.GetAuthChecker().CheckClientPermission(authCtx) + if err != nil { + resp := api.NewDiscoverResponse(authcommon.ConvertToErrCode(err)) + resp.Info = utils.NewStringValue(err.Error()) + return resp } ctx = authCtx.GetRequestContext() ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) diff --git a/service/interceptor/auth/faultdetect_config.go b/service/interceptor/auth/faultdetect_config.go index e3401be52..ef4b09d76 100644 --- a/service/interceptor/auth/faultdetect_config.go +++ b/service/interceptor/auth/faultdetect_config.go @@ -22,10 +22,7 @@ import ( apifault "github.com/polarismesh/specification/source/go/api/v1/fault_tolerance" "github.com/polarismesh/specification/source/go/api/v1/security" - apisecurity "github.com/polarismesh/specification/source/go/api/v1/security" apiservice "github.com/polarismesh/specification/source/go/api/v1/service_manage" - "google.golang.org/protobuf/proto" - "google.golang.org/protobuf/types/known/anypb" cachetypes "github.com/polarismesh/polaris/cache/api" api "github.com/polarismesh/polaris/common/api/v1" @@ -37,7 +34,7 @@ import ( func (svr *Server) CreateFaultDetectRules( ctx context.Context, request []*apifault.FaultDetectRule) *apiservice.BatchWriteResponse { - authCtx := svr.collectFaultDetectAuthContext(ctx, request, authcommon.Create, authcommon.CreateFaultDetectRules) + authCtx := svr.collectFaultDetectAuthContext(ctx, request, authcommon.Read, authcommon.CreateFaultDetectRules) if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { return api.NewBatchWriteResponse(authcommon.ConvertToErrCode(err)) } @@ -49,7 +46,7 @@ func (svr *Server) CreateFaultDetectRules( func (svr *Server) DeleteFaultDetectRules( ctx context.Context, request []*apifault.FaultDetectRule) *apiservice.BatchWriteResponse { - authCtx := svr.collectFaultDetectAuthContext(ctx, request, authcommon.Delete, authcommon.DeleteFaultDetectRules) + authCtx := svr.collectFaultDetectAuthContext(ctx, request, authcommon.Read, authcommon.DeleteFaultDetectRules) if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { return api.NewBatchWriteResponse(authcommon.ConvertToErrCode(err)) } @@ -61,7 +58,7 @@ func (svr *Server) DeleteFaultDetectRules( func (svr *Server) UpdateFaultDetectRules( ctx context.Context, request []*apifault.FaultDetectRule) *apiservice.BatchWriteResponse { - authCtx := svr.collectFaultDetectAuthContext(ctx, request, authcommon.Modify, authcommon.UpdateFaultDetectRules) + authCtx := svr.collectFaultDetectAuthContext(ctx, request, authcommon.Read, authcommon.UpdateFaultDetectRules) if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { return api.NewBatchWriteResponse(authcommon.ConvertToErrCode(err)) } @@ -79,44 +76,13 @@ func (svr *Server) GetFaultDetectRules( ctx = authCtx.GetRequestContext() ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) - ctx = cachetypes.AppendFaultDetectRulePredicate(ctx, func(ctx context.Context, cbr *model.FaultDetectRule) bool { + cachetypes.AppendFaultDetectRulePredicate(ctx, func(ctx context.Context, cbr *model.FaultDetectRule) bool { return svr.policySvr.GetAuthChecker().ResourcePredicate(authCtx, &authcommon.ResourceEntry{ Type: security.ResourceType_FaultDetectRules, ID: cbr.ID, - Metadata: cbr.Proto.GetMetadata(), + Metadata: cbr.Proto.Metadata, }) }) - authCtx.SetRequestContext(ctx) - resp := svr.nextSvr.GetFaultDetectRules(ctx, query) - - for index := range resp.Data { - item := &apifault.FaultDetectRule{} - _ = anypb.UnmarshalTo(resp.Data[index], item, proto.UnmarshalOptions{}) - authCtx.SetAccessResources(map[security.ResourceType][]authcommon.ResourceEntry{ - security.ResourceType_FaultDetectRules: { - { - Type: apisecurity.ResourceType_FaultDetectRules, - ID: item.GetId(), - Metadata: item.Metadata, - }, - }, - }) - - // 检查 write 操作权限 - authCtx.SetMethod([]authcommon.ServerFunctionName{authcommon.UpdateFaultDetectRules, authcommon.EnableFaultDetectRules}) - // 如果检查不通过,设置 editable 为 false - if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - item.Editable = false - } - - // 检查 delete 操作权限 - authCtx.SetMethod([]authcommon.ServerFunctionName{authcommon.DeleteFaultDetectRules}) - // 如果检查不通过,设置 editable 为 false - if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - item.Deleteable = false - } - _ = anypb.MarshalFrom(resp.Data[index], item, proto.MarshalOptions{}) - } - return resp + return svr.nextSvr.GetFaultDetectRules(ctx, query) } diff --git a/service/interceptor/auth/instance.go b/service/interceptor/auth/instance.go index 58fff9a8a..071d9f856 100644 --- a/service/interceptor/auth/instance.go +++ b/service/interceptor/auth/instance.go @@ -26,7 +26,6 @@ import ( api "github.com/polarismesh/polaris/common/api/v1" authcommon "github.com/polarismesh/polaris/common/model/auth" "github.com/polarismesh/polaris/common/utils" - "github.com/polarismesh/polaris/service" ) // CreateInstances create instances @@ -34,7 +33,8 @@ func (svr *Server) CreateInstances(ctx context.Context, reqs []*apiservice.Instance) *apiservice.BatchWriteResponse { authCtx := svr.collectInstanceAuthContext(ctx, reqs, authcommon.Create, authcommon.CreateInstances) - if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx) + if err != nil { resp := api.NewResponseWithMsg(authcommon.ConvertToErrCode(err), err.Error()) batchResp := api.NewBatchWriteResponse(apimodel.Code_ExecuteSuccess) api.Collect(batchResp, resp) @@ -66,7 +66,7 @@ func (svr *Server) DeleteInstances(ctx context.Context, return svr.nextSvr.DeleteInstances(ctx, reqs) } -// DeleteInstancesByHost 根据 host 信息进行数据删除 +// DeleteInstancesByHost 目前只允许 super account 进行数据删除 func (svr *Server) DeleteInstancesByHost(ctx context.Context, reqs []*apiservice.Instance) *apiservice.BatchWriteResponse { authCtx := svr.collectInstanceAuthContext(ctx, reqs, authcommon.Delete, authcommon.DeleteInstancesByHost) @@ -145,38 +145,10 @@ func (svr *Server) GetInstancesCount(ctx context.Context) *apiservice.BatchQuery return svr.nextSvr.GetInstancesCount(ctx) } -// GetInstanceLabels 获取某个服务下的实例标签集合 func (svr *Server) GetInstanceLabels(ctx context.Context, query map[string]string) *apiservice.Response { - var ( - serviceId string - namespace = service.DefaultNamespace - ) - - if val, ok := query["namespace"]; ok { - namespace = val - } - - if svcName, ok := query["service"]; ok { - if svc := svr.Cache().Service().GetServiceByName(svcName, namespace); svc != nil { - serviceId = svc.ID - } - } - - if id, ok := query["service_id"]; ok { - serviceId = id - } - - // TODO 如果在鉴权的时候发现资源不存在,怎么处理? - svc := svr.Cache().Service().GetServiceByID(serviceId) - if svc == nil { - return api.NewResponse(apimodel.Code_NotFoundResource) - } - - authCtx := svr.collectServiceAuthContext(ctx, []*apiservice.Service{ - svc.ToSpec(), - }, authcommon.Read, authcommon.DescribeInstanceLabels) + authCtx := svr.collectInstanceAuthContext(ctx, nil, authcommon.Read, authcommon.DescribeInstanceLabels) _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx) if err != nil { return api.NewResponseWithMsg(authcommon.ConvertToErrCode(err), err.Error()) diff --git a/service/interceptor/auth/lane.go b/service/interceptor/auth/lane.go deleted file mode 100644 index e5d7f109e..000000000 --- a/service/interceptor/auth/lane.go +++ /dev/null @@ -1,147 +0,0 @@ -/** - * Tencent is pleased to support the open source community by making Polaris available. - * - * Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. - * - * Licensed under the BSD 3-Clause License (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://opensource.org/licenses/BSD-3-Clause - * - * Unless required by applicable law or agreed to in writing, software distributed - * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR - * CONDITIONS OF ANY KIND, either express or implied. See the License for the - * specific language governing permissions and limitations under the License. - */ - -package service_auth - -import ( - "context" - - "github.com/polarismesh/specification/source/go/api/v1/security" - apisecurity "github.com/polarismesh/specification/source/go/api/v1/security" - apiservice "github.com/polarismesh/specification/source/go/api/v1/service_manage" - apitraffic "github.com/polarismesh/specification/source/go/api/v1/traffic_manage" - "google.golang.org/protobuf/proto" - "google.golang.org/protobuf/types/known/anypb" - - cachetypes "github.com/polarismesh/polaris/cache/api" - api "github.com/polarismesh/polaris/common/api/v1" - "github.com/polarismesh/polaris/common/model" - authcommon "github.com/polarismesh/polaris/common/model/auth" - "github.com/polarismesh/polaris/common/utils" -) - -// CreateLaneGroups 批量创建泳道组 -func (svr *Server) CreateLaneGroups(ctx context.Context, reqs []*apitraffic.LaneGroup) *apiservice.BatchWriteResponse { - - authCtx := svr.collectLaneRuleAuthContext(ctx, reqs, authcommon.Create, authcommon.CreateLaneGroups) - if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewBatchWriteResponse(authcommon.ConvertToErrCode(err)) - } - ctx = authCtx.GetRequestContext() - ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) - return svr.nextSvr.CreateLaneGroups(ctx, reqs) -} - -// UpdateLaneGroups 批量更新泳道组 -func (svr *Server) UpdateLaneGroups(ctx context.Context, reqs []*apitraffic.LaneGroup) *apiservice.BatchWriteResponse { - authCtx := svr.collectLaneRuleAuthContext(ctx, reqs, authcommon.Modify, authcommon.UpdateLaneGroups) - if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewBatchWriteResponse(authcommon.ConvertToErrCode(err)) - } - ctx = authCtx.GetRequestContext() - ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) - return svr.nextSvr.UpdateLaneGroups(ctx, reqs) -} - -// DeleteLaneGroups 批量删除泳道组 -func (svr *Server) DeleteLaneGroups(ctx context.Context, reqs []*apitraffic.LaneGroup) *apiservice.BatchWriteResponse { - authCtx := svr.collectLaneRuleAuthContext(ctx, reqs, authcommon.Delete, authcommon.DeleteLaneGroups) - if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewBatchWriteResponse(authcommon.ConvertToErrCode(err)) - } - ctx = authCtx.GetRequestContext() - ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) - return svr.nextSvr.DeleteLaneGroups(ctx, reqs) -} - -// GetLaneGroups 查询泳道组列表 -func (svr *Server) GetLaneGroups(ctx context.Context, filter map[string]string) *apiservice.BatchQueryResponse { - authCtx := svr.collectFaultDetectAuthContext(ctx, nil, authcommon.Read, authcommon.DescribeFaultDetectRules) - if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewBatchQueryResponse(authcommon.ConvertToErrCode(err)) - } - ctx = authCtx.GetRequestContext() - ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) - - ctx = cachetypes.AppendLaneRulePredicate(ctx, func(ctx context.Context, cbr *model.LaneGroupProto) bool { - return svr.policySvr.GetAuthChecker().ResourcePredicate(authCtx, &authcommon.ResourceEntry{ - Type: security.ResourceType_LaneRules, - ID: cbr.ID, - Metadata: cbr.Proto.Metadata, - }) - }) - authCtx.SetRequestContext(ctx) - - resp := svr.nextSvr.GetLaneGroups(ctx, filter) - - for index := range resp.Data { - item := &apitraffic.LaneGroup{} - _ = anypb.UnmarshalTo(resp.Data[index], item, proto.UnmarshalOptions{}) - authCtx.SetAccessResources(map[security.ResourceType][]authcommon.ResourceEntry{ - security.ResourceType_LaneRules: { - { - Type: apisecurity.ResourceType_LaneRules, - ID: item.GetId(), - Metadata: item.Metadata, - }, - }, - }) - - // 检查 write 操作权限 - authCtx.SetMethod([]authcommon.ServerFunctionName{authcommon.UpdateLaneGroups, authcommon.EnableLaneGroups}) - // 如果检查不通过,设置 editable 为 false - if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - item.Editable = false - } - - // 检查 delete 操作权限 - authCtx.SetMethod([]authcommon.ServerFunctionName{authcommon.DeleteLaneGroups}) - // 如果检查不通过,设置 editable 为 false - if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - item.Deleteable = false - } - _ = anypb.MarshalFrom(resp.Data[index], item, proto.MarshalOptions{}) - } - return resp -} - -// collectLaneRuleAuthContext 收集全链路灰度规则 -func (svr *Server) collectLaneRuleAuthContext(ctx context.Context, req []*apitraffic.LaneGroup, - op authcommon.ResourceOperation, methodName authcommon.ServerFunctionName) *authcommon.AcquireContext { - - resources := make([]authcommon.ResourceEntry, 0, len(req)) - for i := range req { - saveRule := svr.Cache().LaneRule().GetRule(req[i].GetId()) - if saveRule != nil { - resources = append(resources, authcommon.ResourceEntry{ - Type: apisecurity.ResourceType_LaneRules, - ID: saveRule.ID, - Metadata: saveRule.Labels, - }) - } - } - - return authcommon.NewAcquireContext( - authcommon.WithRequestContext(ctx), - authcommon.WithOperation(op), - authcommon.WithModule(authcommon.DiscoverModule), - authcommon.WithMethod(methodName), - authcommon.WithAccessResources(map[apisecurity.ResourceType][]authcommon.ResourceEntry{ - apisecurity.ResourceType_LaneRules: resources, - }), - ) -} diff --git a/service/interceptor/auth/ratelimit_config.go b/service/interceptor/auth/ratelimit_config.go index 2eea1fddb..f41a09cac 100644 --- a/service/interceptor/auth/ratelimit_config.go +++ b/service/interceptor/auth/ratelimit_config.go @@ -20,8 +20,8 @@ package service_auth import ( "context" + apimodel "github.com/polarismesh/specification/source/go/api/v1/model" "github.com/polarismesh/specification/source/go/api/v1/security" - apisecurity "github.com/polarismesh/specification/source/go/api/v1/security" apiservice "github.com/polarismesh/specification/source/go/api/v1/service_manage" apitraffic "github.com/polarismesh/specification/source/go/api/v1/traffic_manage" @@ -37,8 +37,9 @@ func (svr *Server) CreateRateLimits( ctx context.Context, reqs []*apitraffic.Rule) *apiservice.BatchWriteResponse { authCtx := svr.collectRateLimitAuthContext(ctx, reqs, authcommon.Create, authcommon.CreateRateLimitRules) - if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewBatchWriteResponse(authcommon.ConvertToErrCode(err)) + _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx) + if err != nil { + return api.NewBatchWriteResponseWithMsg(apimodel.Code_NotAllowedAccess, err.Error()) } ctx = authCtx.GetRequestContext() @@ -52,8 +53,9 @@ func (svr *Server) DeleteRateLimits( ctx context.Context, reqs []*apitraffic.Rule) *apiservice.BatchWriteResponse { authCtx := svr.collectRateLimitAuthContext(ctx, reqs, authcommon.Delete, authcommon.DeleteRateLimitRules) - if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewBatchWriteResponse(authcommon.ConvertToErrCode(err)) + _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx) + if err != nil { + return api.NewBatchWriteResponseWithMsg(apimodel.Code_NotAllowedAccess, err.Error()) } ctx = authCtx.GetRequestContext() @@ -67,8 +69,9 @@ func (svr *Server) UpdateRateLimits( ctx context.Context, reqs []*apitraffic.Rule) *apiservice.BatchWriteResponse { authCtx := svr.collectRateLimitAuthContext(ctx, reqs, authcommon.Modify, authcommon.UpdateRateLimitRules) - if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewBatchWriteResponse(authcommon.ConvertToErrCode(err)) + _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx) + if err != nil { + return api.NewBatchWriteResponseWithMsg(apimodel.Code_NotAllowedAccess, err.Error()) } ctx = authCtx.GetRequestContext() @@ -80,10 +83,11 @@ func (svr *Server) UpdateRateLimits( // EnableRateLimits 启用限流规则 func (svr *Server) EnableRateLimits( ctx context.Context, reqs []*apitraffic.Rule) *apiservice.BatchWriteResponse { - authCtx := svr.collectRateLimitAuthContext(ctx, reqs, authcommon.Read, authcommon.EnableRateLimitRules) + authCtx := svr.collectRateLimitAuthContext(ctx, nil, authcommon.Read, authcommon.EnableRateLimitRules) - if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewBatchWriteResponse(authcommon.ConvertToErrCode(err)) + _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx) + if err != nil { + return api.NewBatchWriteResponseWithMsg(apimodel.Code_NotAllowedAccess, err.Error()) } ctx = authCtx.GetRequestContext() @@ -97,50 +101,21 @@ func (svr *Server) GetRateLimits( ctx context.Context, query map[string]string) *apiservice.BatchQueryResponse { authCtx := svr.collectRateLimitAuthContext(ctx, nil, authcommon.Read, authcommon.DescribeRateLimitRules) - if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewAuthBatchQueryResponse(authcommon.ConvertToErrCode(err)) + _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx) + if err != nil { + return api.NewBatchQueryResponseWithMsg(apimodel.Code_NotAllowedAccess, err.Error()) } ctx = authCtx.GetRequestContext() ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) - ctx = cachetypes.AppendRatelimitRulePredicate(ctx, func(ctx context.Context, cbr *model.RateLimit) bool { + cachetypes.AppendRatelimitRulePredicate(ctx, func(ctx context.Context, cbr *model.RateLimit) bool { return svr.policySvr.GetAuthChecker().ResourcePredicate(authCtx, &authcommon.ResourceEntry{ Type: security.ResourceType_RateLimitRules, ID: cbr.ID, Metadata: cbr.Proto.Metadata, }) }) - authCtx.SetRequestContext(ctx) - - resp := svr.nextSvr.GetRateLimits(ctx, query) - - for index := range resp.RateLimits { - item := resp.RateLimits[index] - authCtx.SetAccessResources(map[security.ResourceType][]authcommon.ResourceEntry{ - security.ResourceType_RateLimitRules: { - { - Type: apisecurity.ResourceType_RateLimitRules, - ID: item.GetId().GetValue(), - Metadata: item.Metadata, - }, - }, - }) - - // 检查 write 操作权限 - authCtx.SetMethod([]authcommon.ServerFunctionName{authcommon.UpdateRateLimitRules, authcommon.EnableRateLimitRules}) - // 如果检查不通过,设置 editable 为 false - if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - item.Editable = false - } - - // 检查 delete 操作权限 - authCtx.SetMethod([]authcommon.ServerFunctionName{authcommon.DeleteRateLimitRules}) - // 如果检查不通过,设置 editable 为 false - if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - item.Deleteable = false - } - } - return resp + return svr.nextSvr.GetRateLimits(ctx, query) } diff --git a/service/interceptor/auth/resource_listen.go b/service/interceptor/auth/resource_listen.go index c29e464b4..080b364ff 100644 --- a/service/interceptor/auth/resource_listen.go +++ b/service/interceptor/auth/resource_listen.go @@ -35,45 +35,40 @@ func (svr *Server) Before(ctx context.Context, resourceType model.Resource) { // After this function is called after the resource operation func (svr *Server) After(ctx context.Context, resourceType model.Resource, res *service.ResourceEvent) error { - // 资源删除,触发所有关联的策略进行一个 update 操作更新 - return svr.onChangeResource(ctx, res) + switch resourceType { + case model.RService: + return svr.onServiceResource(ctx, res) + default: + return nil + } } -// onChangeResource 服务资源的处理,只处理服务,namespace 只由 namespace 相关的进行处理, -func (svr *Server) onChangeResource(ctx context.Context, res *service.ResourceEvent) error { +// onServiceResource 服务资源的处理,只处理服务,namespace 只由 namespace 相关的进行处理, +func (svr *Server) onServiceResource(ctx context.Context, res *service.ResourceEvent) error { authCtx := ctx.Value(utils.ContextAuthContextKey).(*authcommon.AcquireContext) + ownerId := utils.ParseOwnerID(ctx) authCtx.SetAttachment(authcommon.ResourceAttachmentKey, map[apisecurity.ResourceType][]authcommon.ResourceEntry{ - res.Resource.Type: { - res.Resource, + apisecurity.ResourceType_Services: { + { + ID: res.Service.ID, + Owner: ownerId, + Metadata: res.Service.Meta, + }, }, }) - var users, removeUsers []string - var groups, removeGroups []string + users := utils.ConvertStringValuesToSlice(res.ReqService.UserIds) + removeUses := utils.ConvertStringValuesToSlice(res.ReqService.RemoveUserIds) - for i := range res.AddPrincipals { - switch res.AddPrincipals[i].PrincipalType { - case authcommon.PrincipalUser: - users = append(users, res.AddPrincipals[i].PrincipalID) - case authcommon.PrincipalGroup: - groups = append(groups, res.AddPrincipals[i].PrincipalID) - } - } - for i := range res.DelPrincipals { - switch res.DelPrincipals[i].PrincipalType { - case authcommon.PrincipalUser: - removeUsers = append(removeUsers, res.DelPrincipals[i].PrincipalID) - case authcommon.PrincipalGroup: - removeGroups = append(removeGroups, res.DelPrincipals[i].PrincipalID) - } - } + groups := utils.ConvertStringValuesToSlice(res.ReqService.GroupIds) + removeGroups := utils.ConvertStringValuesToSlice(res.ReqService.RemoveGroupIds) - authCtx.SetAttachment(authcommon.LinkUsersKey, users) - authCtx.SetAttachment(authcommon.RemoveLinkUsersKey, removeUsers) + authCtx.SetAttachment(authcommon.LinkUsersKey, utils.StringSliceDeDuplication(users)) + authCtx.SetAttachment(authcommon.RemoveLinkUsersKey, utils.StringSliceDeDuplication(removeUses)) - authCtx.SetAttachment(authcommon.LinkGroupsKey, groups) - authCtx.SetAttachment(authcommon.RemoveLinkGroupsKey, removeGroups) + authCtx.SetAttachment(authcommon.LinkGroupsKey, utils.StringSliceDeDuplication(groups)) + authCtx.SetAttachment(authcommon.RemoveLinkGroupsKey, utils.StringSliceDeDuplication(removeGroups)) return svr.policySvr.AfterResourceOperation(authCtx) } diff --git a/service/interceptor/auth/routing_config_v1.go b/service/interceptor/auth/routing_config_v1.go index b29287aa0..32066dbbe 100644 --- a/service/interceptor/auth/routing_config_v1.go +++ b/service/interceptor/auth/routing_config_v1.go @@ -20,30 +20,75 @@ package service_auth import ( "context" + apimodel "github.com/polarismesh/specification/source/go/api/v1/model" apiservice "github.com/polarismesh/specification/source/go/api/v1/service_manage" apitraffic "github.com/polarismesh/specification/source/go/api/v1/traffic_manage" + + api "github.com/polarismesh/polaris/common/api/v1" + authcommon "github.com/polarismesh/polaris/common/model/auth" + "github.com/polarismesh/polaris/common/utils" ) // CreateRoutingConfigs creates routing configs func (svr *Server) CreateRoutingConfigs( ctx context.Context, reqs []*apitraffic.Routing) *apiservice.BatchWriteResponse { + authCtx := svr.collectRouteRuleAuthContext(ctx, reqs, authcommon.Create, "CreateRoutingConfigs") + + _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx) + if err != nil { + return api.NewBatchWriteResponseWithMsg(apimodel.Code_NotAllowedAccess, err.Error()) + } + + ctx = authCtx.GetRequestContext() + ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) + return svr.nextSvr.CreateRoutingConfigs(ctx, reqs) } // DeleteRoutingConfigs deletes routing configs func (svr *Server) DeleteRoutingConfigs( ctx context.Context, reqs []*apitraffic.Routing) *apiservice.BatchWriteResponse { + authCtx := svr.collectRouteRuleAuthContext(ctx, reqs, authcommon.Delete, "DeleteRoutingConfigs") + + _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx) + if err != nil { + return api.NewBatchWriteResponseWithMsg(apimodel.Code_NotAllowedAccess, err.Error()) + } + + ctx = authCtx.GetRequestContext() + ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) + return svr.nextSvr.DeleteRoutingConfigs(ctx, reqs) } // UpdateRoutingConfigs updates routing configs func (svr *Server) UpdateRoutingConfigs( ctx context.Context, reqs []*apitraffic.Routing) *apiservice.BatchWriteResponse { + authCtx := svr.collectRouteRuleAuthContext(ctx, reqs, authcommon.Modify, "UpdateRoutingConfigs") + + _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx) + if err != nil { + return api.NewBatchWriteResponseWithMsg(apimodel.Code_NotAllowedAccess, err.Error()) + } + + ctx = authCtx.GetRequestContext() + ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) + return svr.nextSvr.UpdateRoutingConfigs(ctx, reqs) } // GetRoutingConfigs gets routing configs func (svr *Server) GetRoutingConfigs( ctx context.Context, query map[string]string) *apiservice.BatchQueryResponse { + authCtx := svr.collectRouteRuleAuthContext(ctx, nil, authcommon.Read, "GetRoutingConfigs") + + _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx) + if err != nil { + return api.NewBatchQueryResponseWithMsg(apimodel.Code_NotAllowedAccess, err.Error()) + } + + ctx = authCtx.GetRequestContext() + ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) + return svr.nextSvr.GetRoutingConfigs(ctx, query) } diff --git a/service/interceptor/auth/routing_config_v2.go b/service/interceptor/auth/routing_config_v2.go index 6da142365..3ed128dad 100644 --- a/service/interceptor/auth/routing_config_v2.go +++ b/service/interceptor/auth/routing_config_v2.go @@ -21,11 +21,8 @@ import ( "context" "github.com/polarismesh/specification/source/go/api/v1/security" - apisecurity "github.com/polarismesh/specification/source/go/api/v1/security" apiservice "github.com/polarismesh/specification/source/go/api/v1/service_manage" apitraffic "github.com/polarismesh/specification/source/go/api/v1/traffic_manage" - "google.golang.org/protobuf/proto" - "google.golang.org/protobuf/types/known/anypb" cachetypes "github.com/polarismesh/polaris/cache/api" api "github.com/polarismesh/polaris/common/api/v1" @@ -39,7 +36,7 @@ func (svr *Server) CreateRoutingConfigsV2(ctx context.Context, req []*apitraffic.RouteRule) *apiservice.BatchWriteResponse { // TODO not support RouteRuleV2 resource auth, so we set op is read - authCtx := svr.collectRouteRuleV2AuthContext(ctx, req, authcommon.Create, authcommon.CreateRouteRules) + authCtx := svr.collectRouteRuleV2AuthContext(ctx, req, authcommon.Read, authcommon.CreateRouteRules) if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { return api.NewBatchWriteResponse(authcommon.ConvertToErrCode(err)) } @@ -52,7 +49,7 @@ func (svr *Server) CreateRoutingConfigsV2(ctx context.Context, func (svr *Server) DeleteRoutingConfigsV2(ctx context.Context, req []*apitraffic.RouteRule) *apiservice.BatchWriteResponse { - authCtx := svr.collectRouteRuleV2AuthContext(ctx, req, authcommon.Delete, authcommon.DeleteRouteRules) + authCtx := svr.collectRouteRuleV2AuthContext(ctx, req, authcommon.Read, authcommon.DeleteRouteRules) if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { return api.NewBatchWriteResponse(authcommon.ConvertToErrCode(err)) } @@ -65,7 +62,7 @@ func (svr *Server) DeleteRoutingConfigsV2(ctx context.Context, func (svr *Server) UpdateRoutingConfigsV2(ctx context.Context, req []*apitraffic.RouteRule) *apiservice.BatchWriteResponse { - authCtx := svr.collectRouteRuleV2AuthContext(ctx, req, authcommon.Modify, authcommon.UpdateRouteRules) + authCtx := svr.collectRouteRuleV2AuthContext(ctx, req, authcommon.Read, authcommon.UpdateRouteRules) if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { return api.NewBatchWriteResponse(authcommon.ConvertToErrCode(err)) } @@ -78,7 +75,7 @@ func (svr *Server) UpdateRoutingConfigsV2(ctx context.Context, func (svr *Server) EnableRoutings(ctx context.Context, req []*apitraffic.RouteRule) *apiservice.BatchWriteResponse { - authCtx := svr.collectRouteRuleV2AuthContext(ctx, req, authcommon.Modify, authcommon.EnableRouteRules) + authCtx := svr.collectRouteRuleV2AuthContext(ctx, req, authcommon.Read, authcommon.EnableRouteRules) if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { return api.NewBatchWriteResponse(authcommon.ConvertToErrCode(err)) } @@ -97,43 +94,13 @@ func (svr *Server) QueryRoutingConfigsV2(ctx context.Context, ctx = authCtx.GetRequestContext() ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) - ctx = cachetypes.AppendRouterRulePredicate(ctx, func(ctx context.Context, cbr *model.ExtendRouterConfig) bool { + cachetypes.AppendRouterRulePredicate(ctx, func(ctx context.Context, cbr *model.ExtendRouterConfig) bool { return svr.policySvr.GetAuthChecker().ResourcePredicate(authCtx, &authcommon.ResourceEntry{ Type: security.ResourceType_RouteRules, ID: cbr.ID, Metadata: cbr.Metadata, }) }) - authCtx.SetRequestContext(ctx) - resp := svr.nextSvr.QueryRoutingConfigsV2(ctx, query) - for index := range resp.Data { - item := &apitraffic.RouteRule{} - _ = anypb.UnmarshalTo(resp.Data[index], item, proto.UnmarshalOptions{}) - authCtx.SetAccessResources(map[security.ResourceType][]authcommon.ResourceEntry{ - security.ResourceType_RouteRules: { - { - Type: apisecurity.ResourceType_RouteRules, - ID: item.GetId(), - Metadata: item.Metadata, - }, - }, - }) - - // 检查 write 操作权限 - authCtx.SetMethod([]authcommon.ServerFunctionName{authcommon.UpdateRouteRules, authcommon.EnableRouteRules}) - // 如果检查不通过,设置 editable 为 false - if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - item.Editable = false - } - - // 检查 delete 操作权限 - authCtx.SetMethod([]authcommon.ServerFunctionName{authcommon.DeleteRouteRules}) - // 如果检查不通过,设置 editable 为 false - if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - item.Deleteable = false - } - _ = anypb.MarshalFrom(resp.Data[index], item, proto.MarshalOptions{}) - } - return resp + return svr.nextSvr.QueryRoutingConfigsV2(ctx, query) } diff --git a/service/interceptor/auth/server.go b/service/interceptor/auth/server.go index 08750f50f..bb6859493 100644 --- a/service/interceptor/auth/server.go +++ b/service/interceptor/auth/server.go @@ -39,19 +39,20 @@ import ( // 该层会对请求参数做一些调整,根据具体的请求发起人,设置为数据对应的 owner,不可为为别人进行创建资源 type Server struct { nextSvr service.DiscoverServer - userSvr auth.UserServer + userMgn auth.UserServer policySvr auth.StrategyServer } -func NewServer(nextSvr service.DiscoverServer, - userSvr auth.UserServer, policySvr auth.StrategyServer) service.DiscoverServer { +func NewServerAuthAbility(nextSvr service.DiscoverServer, + userMgn auth.UserServer, policySvr auth.StrategyServer) service.DiscoverServer { proxy := &Server{ nextSvr: nextSvr, - userSvr: userSvr, + userMgn: userMgn, policySvr: policySvr, } - if actualSvr, ok := nextSvr.(*service.Server); ok { + actualSvr, ok := nextSvr.(*service.Server) + if ok { actualSvr.SetResourceHooks(proxy) } return proxy @@ -202,23 +203,20 @@ func (svr *Server) collectRouteRuleV2AuthContext(ctx context.Context, req []*api } } - accessResources := map[apisecurity.ResourceType][]authcommon.ResourceEntry{} - if len(resources) != 0 { - accessResources[apisecurity.ResourceType_RouteRules] = resources - } - return authcommon.NewAcquireContext( authcommon.WithRequestContext(ctx), authcommon.WithOperation(resourceOp), authcommon.WithModule(authcommon.DiscoverModule), authcommon.WithMethod(methodName), - authcommon.WithAccessResources(accessResources), + authcommon.WithAccessResources(map[apisecurity.ResourceType][]authcommon.ResourceEntry{ + apisecurity.ResourceType_RouteRules: resources, + }), ) } -// collectCircuitBreakerRuleV2 收集熔断v2规则 -func (svr *Server) collectCircuitBreakerRuleV2(ctx context.Context, req []*apifault.CircuitBreakerRule, - op authcommon.ResourceOperation, methodName authcommon.ServerFunctionName) *authcommon.AcquireContext { +// collectCircuitBreakerRuleV2AuthContext 收集熔断v2规则 +func (svr *Server) collectCircuitBreakerRuleV2AuthContext(ctx context.Context, + req []*apifault.CircuitBreakerRule, resourceOp authcommon.ResourceOperation, methodName authcommon.ServerFunctionName) *authcommon.AcquireContext { resources := make([]authcommon.ResourceEntry, 0, len(req)) for i := range req { @@ -227,14 +225,14 @@ func (svr *Server) collectCircuitBreakerRuleV2(ctx context.Context, req []*apifa resources = append(resources, authcommon.ResourceEntry{ Type: apisecurity.ResourceType_CircuitBreakerRules, ID: saveRule.ID, - Metadata: saveRule.Proto.GetMetadata(), + Metadata: saveRule.Proto.Metadata, }) } } return authcommon.NewAcquireContext( authcommon.WithRequestContext(ctx), - authcommon.WithOperation(op), + authcommon.WithOperation(resourceOp), authcommon.WithModule(authcommon.DiscoverModule), authcommon.WithMethod(methodName), authcommon.WithAccessResources(map[apisecurity.ResourceType][]authcommon.ResourceEntry{ @@ -244,8 +242,8 @@ func (svr *Server) collectCircuitBreakerRuleV2(ctx context.Context, req []*apifa } // collectFaultDetectAuthContext 收集主动探测规则 -func (svr *Server) collectFaultDetectAuthContext(ctx context.Context, req []*apifault.FaultDetectRule, - op authcommon.ResourceOperation, methodName authcommon.ServerFunctionName) *authcommon.AcquireContext { +func (svr *Server) collectFaultDetectAuthContext(ctx context.Context, + req []*apifault.FaultDetectRule, resourceOp authcommon.ResourceOperation, methodName authcommon.ServerFunctionName) *authcommon.AcquireContext { resources := make([]authcommon.ResourceEntry, 0, len(req)) for i := range req { @@ -254,14 +252,14 @@ func (svr *Server) collectFaultDetectAuthContext(ctx context.Context, req []*api resources = append(resources, authcommon.ResourceEntry{ Type: apisecurity.ResourceType_FaultDetectRules, ID: saveRule.ID, - Metadata: saveRule.Proto.GetMetadata(), + Metadata: saveRule.Proto.Metadata, }) } } return authcommon.NewAcquireContext( authcommon.WithRequestContext(ctx), - authcommon.WithOperation(op), + authcommon.WithOperation(resourceOp), authcommon.WithModule(authcommon.DiscoverModule), authcommon.WithMethod(methodName), authcommon.WithAccessResources(map[apisecurity.ResourceType][]authcommon.ResourceEntry{ diff --git a/service/interceptor/auth/service.go b/service/interceptor/auth/service.go index cb94fe0a4..864a33368 100644 --- a/service/interceptor/auth/service.go +++ b/service/interceptor/auth/service.go @@ -67,7 +67,7 @@ func (svr *Server) DeleteServices( authCtx.SetAccessResources(accessRes) if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewBatchWriteResponse(authcommon.ConvertToErrCode(err)) + return api.NewBatchWriteResponseWithMsg(authcommon.ConvertToErrCode(err), err.Error()) } ctx = authCtx.GetRequestContext() @@ -106,7 +106,7 @@ func (svr *Server) UpdateServiceToken( authCtx.SetAccessResources(accessRes) if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewResponse(authcommon.ConvertToErrCode(err)) + return api.NewResponseWithMsg(authcommon.ConvertToErrCode(err), err.Error()) } ctx = authCtx.GetRequestContext() @@ -119,34 +119,12 @@ func (svr *Server) GetAllServices(ctx context.Context, authCtx := svr.collectServiceAuthContext(ctx, nil, authcommon.Read, authcommon.DescribeAllServices) if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewBatchQueryResponse(authcommon.ConvertToErrCode(err)) + return api.NewBatchQueryResponseWithMsg(authcommon.ConvertToErrCode(err), err.Error()) } ctx = authCtx.GetRequestContext() ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) - ctx = cachetypes.AppendServicePredicate(ctx, func(ctx context.Context, cbr *model.Service) bool { - ok := svr.policySvr.GetAuthChecker().ResourcePredicate(authCtx, &authcommon.ResourceEntry{ - Type: security.ResourceType_Services, - ID: cbr.ID, - Metadata: cbr.Meta, - }) - if ok { - return true - } - saveNs := svr.Cache().Namespace().GetNamespace(cbr.Namespace) - if saveNs == nil { - return false - } - // 检查下是否可以访问对应的 namespace - return svr.policySvr.GetAuthChecker().ResourcePredicate(authCtx, &authcommon.ResourceEntry{ - Type: security.ResourceType_Namespaces, - ID: saveNs.Name, - Metadata: saveNs.Metadata, - }) - }) - authCtx.SetRequestContext(ctx) - return svr.nextSvr.GetAllServices(ctx, query) } @@ -156,60 +134,20 @@ func (svr *Server) GetServices( authCtx := svr.collectServiceAuthContext(ctx, nil, authcommon.Read, authcommon.DescribeServices) if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewBatchQueryResponse(authcommon.ConvertToErrCode(err)) + return api.NewBatchQueryResponseWithMsg(authcommon.ConvertToErrCode(err), err.Error()) } ctx = authCtx.GetRequestContext() ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) // 注入查询条件拦截器 - ctx = cachetypes.AppendServicePredicate(ctx, func(ctx context.Context, cbr *model.Service) bool { - ok := svr.policySvr.GetAuthChecker().ResourcePredicate(authCtx, &authcommon.ResourceEntry{ - Type: security.ResourceType_Services, - ID: cbr.ID, - Metadata: cbr.Meta, - }) - if ok { - return true - } - saveNs := svr.Cache().Namespace().GetNamespace(cbr.Namespace) - if saveNs == nil { - return false - } - // 检查下是否可以访问对应的 namespace - return svr.policySvr.GetAuthChecker().ResourcePredicate(authCtx, &authcommon.ResourceEntry{ - Type: security.ResourceType_Namespaces, - ID: saveNs.Name, - Metadata: saveNs.Metadata, - }) - }) - authCtx.SetRequestContext(ctx) resp := svr.nextSvr.GetServices(ctx, query) - for index := range resp.Services { - item := resp.Services[index] - authCtx.SetAccessResources(map[apisecurity.ResourceType][]authcommon.ResourceEntry{ - apisecurity.ResourceType_Services: { - { - Type: apisecurity.ResourceType_Services, - ID: item.GetId().GetValue(), - Metadata: item.Metadata, - }, - }, - }) - - // 检查 write 操作权限 - authCtx.SetMethod([]authcommon.ServerFunctionName{authcommon.UpdateServices}) - // 如果检查不通过,设置 editable 为 false - if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - item.Editable = utils.NewBoolValue(false) - } - - // 检查 delete 操作权限 - authCtx.SetMethod([]authcommon.ServerFunctionName{authcommon.DeleteServices}) - // 如果检查不通过,设置 editable 为 false - if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - item.Deleteable = utils.NewBoolValue(false) + if len(resp.Services) != 0 { + for index := range resp.Services { + svc := resp.Services[index] + // TODO 需要配合 metadata 做调整 + svc.Editable = utils.NewBoolValue(true) } } return resp @@ -220,7 +158,7 @@ func (svr *Server) GetServicesCount(ctx context.Context) *apiservice.BatchQueryR authCtx := svr.collectServiceAuthContext(ctx, nil, authcommon.Read, authcommon.DescribeServicesCount) if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewBatchQueryResponse(authcommon.ConvertToErrCode(err)) + return api.NewBatchQueryResponseWithMsg(authcommon.ConvertToErrCode(err), err.Error()) } ctx = authCtx.GetRequestContext() @@ -230,11 +168,10 @@ func (svr *Server) GetServicesCount(ctx context.Context) *apiservice.BatchQueryR // GetServiceToken 获取服务的 token func (svr *Server) GetServiceToken(ctx context.Context, req *apiservice.Service) *apiservice.Response { - authCtx := svr.collectServiceAuthContext(ctx, []*apiservice.Service{req}, authcommon.Read, - authcommon.DescribeServiceToken) + authCtx := svr.collectServiceAuthContext(ctx, nil, authcommon.Read, authcommon.DescribeServiceToken) if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewResponse(authcommon.ConvertToErrCode(err)) + return api.NewResponseWithMsg(authcommon.ConvertToErrCode(err), err.Error()) } ctx = authCtx.GetRequestContext() @@ -245,14 +182,22 @@ func (svr *Server) GetServiceToken(ctx context.Context, req *apiservice.Service) // GetServiceOwner 获取服务的 owner func (svr *Server) GetServiceOwner( ctx context.Context, req []*apiservice.Service) *apiservice.BatchQueryResponse { - authCtx := svr.collectServiceAuthContext(ctx, req, authcommon.Read, authcommon.DescribeServiceOwner) + authCtx := svr.collectServiceAuthContext(ctx, nil, authcommon.Read, authcommon.DescribeServiceOwner) if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewBatchQueryResponse(authcommon.ConvertToErrCode(err)) + return api.NewBatchQueryResponseWithMsg(authcommon.ConvertToErrCode(err), err.Error()) } ctx = authCtx.GetRequestContext() ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) + cachetypes.AppendServicePredicate(ctx, func(ctx context.Context, cbr *model.Service) bool { + return svr.policySvr.GetAuthChecker().ResourcePredicate(authCtx, &authcommon.ResourceEntry{ + Type: security.ResourceType_Services, + ID: cbr.ID, + Metadata: cbr.Meta, + }) + }) + return svr.nextSvr.GetServiceOwner(ctx, req) } diff --git a/service/interceptor/auth/service_alias.go b/service/interceptor/auth/service_alias.go index c0c523269..d52f86538 100644 --- a/service/interceptor/auth/service_alias.go +++ b/service/interceptor/auth/service_alias.go @@ -21,7 +21,6 @@ import ( "context" "github.com/polarismesh/specification/source/go/api/v1/security" - apisecurity "github.com/polarismesh/specification/source/go/api/v1/security" apiservice "github.com/polarismesh/specification/source/go/api/v1/service_manage" cachetypes "github.com/polarismesh/polaris/cache/api" @@ -96,53 +95,13 @@ func (svr *Server) GetServiceAliases(ctx context.Context, ctx = authCtx.GetRequestContext() ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) - ctx = cachetypes.AppendServicePredicate(ctx, func(ctx context.Context, cbr *model.Service) bool { - sourceSvc := svr.Cache().Service().GetServiceByID(cbr.Reference) - if sourceSvc == nil { - return false - } + cachetypes.AppendServicePredicate(ctx, func(ctx context.Context, cbr *model.Service) bool { return svr.policySvr.GetAuthChecker().ResourcePredicate(authCtx, &authcommon.ResourceEntry{ Type: security.ResourceType_Services, - ID: sourceSvc.ID, - Metadata: sourceSvc.Meta, + ID: cbr.ID, + Metadata: cbr.Meta, }) }) - authCtx.SetRequestContext(ctx) - - resp := svr.nextSvr.GetServiceAliases(ctx, query) - for i := range resp.Aliases { - item := resp.Aliases[i] - sourceSvc := svr.Cache().Service().GetServiceByName(item.GetAlias().GetValue(), item.GetAliasNamespace().GetValue()) - if sourceSvc == nil { - item.Editable = utils.NewBoolValue(false) - item.Deleteable = utils.NewBoolValue(false) - continue - } - authCtx.SetAccessResources(map[security.ResourceType][]authcommon.ResourceEntry{ - security.ResourceType_Services: { - { - Type: apisecurity.ResourceType_Services, - ID: sourceSvc.ID, - Metadata: sourceSvc.Meta, - }, - }, - }) - - // 检查 write 操作权限 - authCtx.SetMethod([]authcommon.ServerFunctionName{authcommon.UpdateRateLimitRules, authcommon.EnableRateLimitRules}) - // 如果检查不通过,设置 editable 为 false - if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - item.Editable = utils.NewBoolValue(false) - } - - // 检查 delete 操作权限 - authCtx.SetMethod([]authcommon.ServerFunctionName{authcommon.DeleteRateLimitRules}) - // 如果检查不通过,设置 editable 为 false - if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - item.Deleteable = utils.NewBoolValue(false) - } - } - - return resp + return svr.nextSvr.GetServiceAliases(ctx, query) } diff --git a/service/interceptor/paramcheck/circuit_breaker.go b/service/interceptor/paramcheck/circuit_breaker.go index 8c69aa98b..452121325 100644 --- a/service/interceptor/paramcheck/circuit_breaker.go +++ b/service/interceptor/paramcheck/circuit_breaker.go @@ -19,7 +19,6 @@ package paramcheck import ( "context" - "strconv" "github.com/polarismesh/specification/source/go/api/v1/fault_tolerance" apifault "github.com/polarismesh/specification/source/go/api/v1/fault_tolerance" @@ -28,208 +27,117 @@ import ( apiservice "github.com/polarismesh/specification/source/go/api/v1/service_manage" api "github.com/polarismesh/polaris/common/api/v1" - "github.com/polarismesh/polaris/common/log" "github.com/polarismesh/polaris/common/utils" ) -var ( - // CircuitBreakerRuleFilters filter circuitbreaker rule query parameters - CircuitBreakerRuleFilters = map[string]bool{ - "brief": true, - "offset": true, - "limit": true, - "id": true, - "name": true, - "namespace": true, - "enable": true, - "level": true, - "service": true, - "serviceNamespace": true, - "srcService": true, - "srcNamespace": true, - "dstService": true, - "dstNamespace": true, - "dstMethod": true, - "description": true, - } -) - // GetMasterCircuitBreakers implements service.DiscoverServer. -// Deprecated: not support from 1.14.x func (svr *Server) GetMasterCircuitBreakers(ctx context.Context, query map[string]string) *service_manage.BatchQueryResponse { return svr.nextSvr.GetMasterCircuitBreakers(ctx, query) } // GetReleaseCircuitBreakers implements service.DiscoverServer. -// Deprecated: not support from 1.14.x func (svr *Server) GetReleaseCircuitBreakers(ctx context.Context, query map[string]string) *service_manage.BatchQueryResponse { return svr.nextSvr.GetReleaseCircuitBreakers(ctx, query) } // GetCircuitBreaker implements service.DiscoverServer. -// Deprecated: not support from 1.14.x func (svr *Server) GetCircuitBreaker(ctx context.Context, query map[string]string) *service_manage.BatchQueryResponse { return svr.nextSvr.GetCircuitBreaker(ctx, query) } // GetCircuitBreakerByService implements service.DiscoverServer. -// Deprecated: not support from 1.14.x func (svr *Server) GetCircuitBreakerByService(ctx context.Context, query map[string]string) *service_manage.BatchQueryResponse { return svr.nextSvr.GetCircuitBreakerByService(ctx, query) } // DeleteCircuitBreakers implements service.DiscoverServer. -// Deprecated: not support from 1.14.x func (svr *Server) DeleteCircuitBreakers(ctx context.Context, req []*fault_tolerance.CircuitBreaker) *service_manage.BatchWriteResponse { return svr.nextSvr.DeleteCircuitBreakers(ctx, req) } // GetCircuitBreakerToken implements service.DiscoverServer. -// Deprecated: not support from 1.14.x func (svr *Server) GetCircuitBreakerToken(ctx context.Context, req *fault_tolerance.CircuitBreaker) *service_manage.Response { return svr.nextSvr.GetCircuitBreakerToken(ctx, req) } // GetCircuitBreakerVersions implements service.DiscoverServer. -// Deprecated: not support from 1.14.x func (svr *Server) GetCircuitBreakerVersions(ctx context.Context, query map[string]string) *service_manage.BatchQueryResponse { return svr.nextSvr.GetCircuitBreakerVersions(ctx, query) } -// ReleaseCircuitBreakers implements service.DiscoverServer. -// Deprecated: not support from 1.14.x -func (svr *Server) ReleaseCircuitBreakers(ctx context.Context, req []*service_manage.ConfigRelease) *service_manage.BatchWriteResponse { - return svr.nextSvr.ReleaseCircuitBreakers(ctx, req) -} - -// UnBindCircuitBreakers implements service.DiscoverServer. -// Deprecated: not support from 1.14.x -func (svr *Server) UnBindCircuitBreakers(ctx context.Context, req []*service_manage.ConfigRelease) *service_manage.BatchWriteResponse { - return svr.nextSvr.UnBindCircuitBreakers(ctx, req) -} - -// CreateCircuitBreakerVersions implements service.DiscoverServer. -// Deprecated: not support from 1.14.x -func (svr *Server) CreateCircuitBreakerVersions(ctx context.Context, - req []*fault_tolerance.CircuitBreaker) *service_manage.BatchWriteResponse { - return svr.nextSvr.CreateCircuitBreakerVersions(ctx, req) -} - -// CreateCircuitBreakers implements service.DiscoverServer. -// Deprecated: not support from 1.14.x -func (svr *Server) CreateCircuitBreakers(ctx context.Context, - req []*fault_tolerance.CircuitBreaker) *service_manage.BatchWriteResponse { - return svr.nextSvr.CreateCircuitBreakers(ctx, req) -} - -// UpdateCircuitBreakers implements service.DiscoverServer. -// Deprecated: not support from 1.14.x -func (svr *Server) UpdateCircuitBreakers(ctx context.Context, req []*fault_tolerance.CircuitBreaker) *service_manage.BatchWriteResponse { - return svr.nextSvr.UpdateCircuitBreakers(ctx, req) -} - -// ------------- 这里开始接口实现才是正式有效的 ------------- - // GetCircuitBreakerRules implements service.DiscoverServer. func (svr *Server) GetCircuitBreakerRules(ctx context.Context, query map[string]string) *service_manage.BatchQueryResponse { - - offset, limit, err := utils.ParseOffsetAndLimit(query) - if err != nil { - return api.NewBatchQueryResponse(apimodel.Code_InvalidParameter) - } - searchFilter := make(map[string]string, len(query)) - for key, value := range query { - if _, ok := CircuitBreakerRuleFilters[key]; !ok { - log.Errorf("params %s is not allowed in querying circuitbreaker rule", key) - return api.NewBatchQueryResponse(apimodel.Code_InvalidParameter) - } - if value == "" { - continue - } - searchFilter[key] = value - } - - searchFilter["offset"] = strconv.FormatUint(uint64(offset), 10) - searchFilter["limit"] = strconv.FormatUint(uint64(limit), 10) - - return svr.nextSvr.GetCircuitBreakerRules(ctx, searchFilter) + return svr.nextSvr.GetCircuitBreakerRules(ctx, query) } // DeleteCircuitBreakerRules implements service.DiscoverServer. func (svr *Server) DeleteCircuitBreakerRules(ctx context.Context, - reqs []*fault_tolerance.CircuitBreakerRule) *service_manage.BatchWriteResponse { - if err := checkBatchCircuitBreakerRules(reqs); err != nil { + request []*fault_tolerance.CircuitBreakerRule) *service_manage.BatchWriteResponse { + if err := checkBatchCircuitBreakerRules(request); err != nil { return err } - batchRsp := api.NewBatchWriteResponse(apimodel.Code_ExecuteSuccess) - for i := range reqs { - rsp := checkCircuitBreakerRuleParams(reqs[i], true, false) - api.Collect(batchRsp, rsp) - } - if !api.IsSuccess(batchRsp) { - return batchRsp - } - return svr.nextSvr.DeleteCircuitBreakerRules(ctx, reqs) + return svr.nextSvr.DeleteCircuitBreakerRules(ctx, request) } // EnableCircuitBreakerRules implements service.DiscoverServer. func (svr *Server) EnableCircuitBreakerRules(ctx context.Context, - reqs []*fault_tolerance.CircuitBreakerRule) *service_manage.BatchWriteResponse { - if err := checkBatchCircuitBreakerRules(reqs); err != nil { + request []*fault_tolerance.CircuitBreakerRule) *service_manage.BatchWriteResponse { + if err := checkBatchCircuitBreakerRules(request); err != nil { return err } - batchRsp := api.NewBatchWriteResponse(apimodel.Code_ExecuteSuccess) - for i := range reqs { - rsp := checkCircuitBreakerRuleParams(reqs[i], true, false) - api.Collect(batchRsp, rsp) - } - if !api.IsSuccess(batchRsp) { - return batchRsp + return svr.nextSvr.EnableCircuitBreakerRules(ctx, request) +} + +// ReleaseCircuitBreakers implements service.DiscoverServer. +func (svr *Server) ReleaseCircuitBreakers(ctx context.Context, req []*service_manage.ConfigRelease) *service_manage.BatchWriteResponse { + return svr.nextSvr.ReleaseCircuitBreakers(ctx, req) +} + +// UnBindCircuitBreakers implements service.DiscoverServer. +func (svr *Server) UnBindCircuitBreakers(ctx context.Context, req []*service_manage.ConfigRelease) *service_manage.BatchWriteResponse { + return svr.nextSvr.UnBindCircuitBreakers(ctx, req) +} + +// UpdateCircuitBreakerRules implements service.DiscoverServer. +func (svr *Server) UpdateCircuitBreakerRules(ctx context.Context, request []*fault_tolerance.CircuitBreakerRule) *service_manage.BatchWriteResponse { + if err := checkBatchCircuitBreakerRules(request); err != nil { + return err } - return svr.nextSvr.EnableCircuitBreakerRules(ctx, reqs) + return svr.nextSvr.UpdateCircuitBreakerRules(ctx, request) } // CreateCircuitBreakerRules implements service.DiscoverServer. func (svr *Server) CreateCircuitBreakerRules(ctx context.Context, - reqs []*fault_tolerance.CircuitBreakerRule) *service_manage.BatchWriteResponse { - if err := checkBatchCircuitBreakerRules(reqs); err != nil { + request []*fault_tolerance.CircuitBreakerRule) *service_manage.BatchWriteResponse { + if err := checkBatchCircuitBreakerRules(request); err != nil { return err } + return svr.nextSvr.CreateCircuitBreakerRules(ctx, request) +} - batchRsp := api.NewBatchWriteResponse(apimodel.Code_ExecuteSuccess) - for i := range reqs { - rsp := checkCircuitBreakerRuleParams(reqs[i], false, true) - api.Collect(batchRsp, rsp) - } - if !api.IsSuccess(batchRsp) { - return batchRsp - } - return svr.nextSvr.CreateCircuitBreakerRules(ctx, reqs) +// CreateCircuitBreakerVersions implements service.DiscoverServer. +func (svr *Server) CreateCircuitBreakerVersions(ctx context.Context, + req []*fault_tolerance.CircuitBreaker) *service_manage.BatchWriteResponse { + return svr.nextSvr.CreateCircuitBreakerVersions(ctx, req) } -// UpdateCircuitBreakerRules implements service.DiscoverServer. -func (svr *Server) UpdateCircuitBreakerRules(ctx context.Context, - reqs []*fault_tolerance.CircuitBreakerRule) *service_manage.BatchWriteResponse { - if err := checkBatchCircuitBreakerRules(reqs); err != nil { - return err - } - batchRsp := api.NewBatchWriteResponse(apimodel.Code_ExecuteSuccess) - for i := range reqs { - rsp := checkCircuitBreakerRuleParams(reqs[i], true, true) - api.Collect(batchRsp, rsp) - } - if !api.IsSuccess(batchRsp) { - return batchRsp - } - return svr.nextSvr.UpdateCircuitBreakerRules(ctx, reqs) +// CreateCircuitBreakers implements service.DiscoverServer. +func (svr *Server) CreateCircuitBreakers(ctx context.Context, + req []*fault_tolerance.CircuitBreaker) *service_manage.BatchWriteResponse { + return svr.nextSvr.CreateCircuitBreakers(ctx, req) +} + +// UpdateCircuitBreakers implements service.DiscoverServer. +func (svr *Server) UpdateCircuitBreakers(ctx context.Context, req []*fault_tolerance.CircuitBreaker) *service_manage.BatchWriteResponse { + return svr.nextSvr.UpdateCircuitBreakers(ctx, req) } func checkBatchCircuitBreakerRules(req []*apifault.CircuitBreakerRule) *apiservice.BatchWriteResponse { @@ -242,41 +150,3 @@ func checkBatchCircuitBreakerRules(req []*apifault.CircuitBreakerRule) *apiservi } return nil } - -func checkCircuitBreakerRuleParams( - req *apifault.CircuitBreakerRule, idRequired bool, nameRequired bool) *apiservice.Response { - if req == nil { - return api.NewResponse(apimodel.Code_EmptyRequest) - } - if resp := checkCircuitBreakerRuleParamsDbLen(req); nil != resp { - return resp - } - if nameRequired && len(req.GetName()) == 0 { - return api.NewResponse(apimodel.Code_InvalidCircuitBreakerName) - } - if idRequired && len(req.GetId()) == 0 { - return api.NewResponse(apimodel.Code_InvalidCircuitBreakerID) - } - return nil -} - -func checkCircuitBreakerRuleParamsDbLen(req *apifault.CircuitBreakerRule) *apiservice.Response { - if err := utils.CheckDbRawStrFieldLen( - req.RuleMatcher.GetSource().GetService(), utils.MaxDbServiceNameLength); err != nil { - return api.NewResponse(apimodel.Code_InvalidServiceName) - } - if err := utils.CheckDbRawStrFieldLen( - req.RuleMatcher.GetSource().GetNamespace(), utils.MaxDbServiceNamespaceLength); err != nil { - return api.NewResponse(apimodel.Code_InvalidNamespaceName) - } - if err := utils.CheckDbRawStrFieldLen(req.GetName(), utils.MaxRuleName); err != nil { - return api.NewResponse(apimodel.Code_InvalidCircuitBreakerName) - } - if err := utils.CheckDbRawStrFieldLen(req.GetNamespace(), utils.MaxDbServiceNamespaceLength); err != nil { - return api.NewResponse(apimodel.Code_InvalidNamespaceName) - } - if err := utils.CheckDbRawStrFieldLen(req.GetDescription(), utils.MaxCommentLength); err != nil { - return api.NewResponse(apimodel.Code_InvalidServiceComment) - } - return nil -} diff --git a/service/interceptor/paramcheck/fault_detect.go b/service/interceptor/paramcheck/fault_detect.go index 2aac60ecf..28a8d4ca4 100644 --- a/service/interceptor/paramcheck/fault_detect.go +++ b/service/interceptor/paramcheck/fault_detect.go @@ -119,8 +119,7 @@ func (svr *Server) CreateFaultDetectRules(ctx context.Context, } // UpdateFaultDetectRules implements service.DiscoverServer. -func (svr *Server) UpdateFaultDetectRules(ctx context.Context, - request []*fault_tolerance.FaultDetectRule) *service_manage.BatchWriteResponse { +func (svr *Server) UpdateFaultDetectRules(ctx context.Context, request []*fault_tolerance.FaultDetectRule) *service_manage.BatchWriteResponse { if checkErr := checkBatchFaultDetectRules(request); checkErr != nil { return checkErr } diff --git a/service/interceptor/paramcheck/lane.go b/service/interceptor/paramcheck/lane.go deleted file mode 100644 index db4293804..000000000 --- a/service/interceptor/paramcheck/lane.go +++ /dev/null @@ -1,156 +0,0 @@ -/** - * Tencent is pleased to support the open source community by making Polaris available. - * - * Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. - * - * Licensed under the BSD 3-Clause License (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://opensource.org/licenses/BSD-3-Clause - * - * Unless required by applicable law or agreed to in writing, software distributed - * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR - * CONDITIONS OF ANY KIND, either express or implied. See the License for the - * specific language governing permissions and limitations under the License. - */ - -package paramcheck - -import ( - "context" - "strconv" - - apimodel "github.com/polarismesh/specification/source/go/api/v1/model" - apiservice "github.com/polarismesh/specification/source/go/api/v1/service_manage" - apitraffic "github.com/polarismesh/specification/source/go/api/v1/traffic_manage" - "go.uber.org/zap" - "google.golang.org/protobuf/types/known/wrapperspb" - - api "github.com/polarismesh/polaris/common/api/v1" - "github.com/polarismesh/polaris/common/log" - "github.com/polarismesh/polaris/common/utils" -) - -var ( - laneGroupSearchAttributes = map[string]struct{}{ - "id": {}, - "name": {}, - "offset": {}, - "brief": {}, - "limit": {}, - "order_type": {}, - "order_field": {}, - } -) - -// CreateLaneGroups 批量创建泳道组 -func (svr *Server) CreateLaneGroups(ctx context.Context, reqs []*apitraffic.LaneGroup) *apiservice.BatchWriteResponse { - if err := checkBatchLaneGroupRules(reqs); err != nil { - return err - } - batchRsp := api.NewBatchWriteResponse(apimodel.Code_ExecuteSuccess) - for i := range reqs { - rsp := checkLaneGroupParam(reqs[i], false) - api.Collect(batchRsp, rsp) - } - - if !api.IsSuccess(batchRsp) { - return batchRsp - } - return svr.nextSvr.CreateLaneGroups(ctx, reqs) -} - -// UpdateLaneGroups 批量更新泳道组 -func (svr *Server) UpdateLaneGroups(ctx context.Context, reqs []*apitraffic.LaneGroup) *apiservice.BatchWriteResponse { - if err := checkBatchLaneGroupRules(reqs); err != nil { - return err - } - batchRsp := api.NewBatchWriteResponse(apimodel.Code_ExecuteSuccess) - for i := range reqs { - rsp := checkLaneGroupParam(reqs[i], true) - api.Collect(batchRsp, rsp) - } - - if !api.IsSuccess(batchRsp) { - return batchRsp - } - return svr.nextSvr.UpdateLaneGroups(ctx, reqs) -} - -// DeleteLaneGroups 批量删除泳道组 -func (svr *Server) DeleteLaneGroups(ctx context.Context, reqs []*apitraffic.LaneGroup) *apiservice.BatchWriteResponse { - if err := checkBatchLaneGroupRules(reqs); err != nil { - return err - } - return svr.nextSvr.DeleteLaneGroups(ctx, reqs) -} - -// GetLaneGroups 查询泳道组列表 -func (svr *Server) GetLaneGroups(ctx context.Context, filter map[string]string) *apiservice.BatchQueryResponse { - offset, limit, err := utils.ParseOffsetAndLimit(filter) - if err != nil { - return api.NewBatchQueryResponseWithMsg(apimodel.Code_BadRequest, err.Error()) - } - - for k := range filter { - if _, ok := laneGroupSearchAttributes[k]; !ok { - log.Error("[Server][LaneGroup][Query] not allowed", zap.String("attribute", k), utils.RequestID(ctx)) - return api.NewBatchQueryResponseWithMsg(apimodel.Code_InvalidParameter, k+" is not allowed") - } - if filter[k] == "" { - delete(filter, k) - } - } - - if _, ok := filter["order_field"]; !ok { - filter["order_field"] = "mtime" - } - if _, ok := filter["order_type"]; !ok { - filter["order_type"] = "desc" - } - - filter["offset"] = strconv.FormatUint(uint64(offset), 10) - filter["limit"] = strconv.FormatUint(uint64(limit), 10) - - return svr.nextSvr.GetLaneGroups(ctx, filter) -} - -func checkBatchLaneGroupRules(req []*apitraffic.LaneGroup) *apiservice.BatchWriteResponse { - if len(req) == 0 { - return api.NewBatchWriteResponse(apimodel.Code_EmptyRequest) - } - - if len(req) > utils.MaxBatchSize { - return api.NewBatchWriteResponse(apimodel.Code_BatchSizeOverLimit) - } - return nil -} - -func checkLaneGroupParam(req *apitraffic.LaneGroup, update bool) *apiservice.Response { - if len(req.GetName()) >= utils.MaxRuleName { - return api.NewResponseWithMsg(apimodel.Code_InvalidParameter, "lane_group name size must be <= 64") - } - if err := utils.CheckResourceName(wrapperspb.String(req.GetName())); err != nil { - return api.NewResponseWithMsg(apimodel.Code_InvalidParameter, err.Error()) - } - if len(req.Rules) > utils.MaxBatchSize { - return api.NewResponseWithMsg(apimodel.Code_InvalidParameter, "lane_rule size must be <= 100") - } - for i := range req.Rules { - rule := req.Rules[i] - if err := utils.CheckResourceName(wrapperspb.String(rule.GetName())); err != nil { - return api.NewResponseWithMsg(apimodel.Code_InvalidParameter, err.Error()) - } - if len(rule.GetName()) >= utils.MaxRuleName { - return api.NewResponseWithMsg(apimodel.Code_InvalidParameter, "lane_rule name size must be <= 64") - } - } - - if update { - if req.GetId() == "" { - return api.NewResponseWithMsg(apimodel.Code_InvalidParameter, "lane_group id is empty") - } - } - return nil -} diff --git a/service/interceptor/paramcheck/ratelimit.go b/service/interceptor/paramcheck/ratelimit.go index 0c210a1a1..4720b462f 100644 --- a/service/interceptor/paramcheck/ratelimit.go +++ b/service/interceptor/paramcheck/ratelimit.go @@ -19,80 +19,27 @@ package paramcheck import ( "context" - "time" - "github.com/golang/protobuf/ptypes" - apimodel "github.com/polarismesh/specification/source/go/api/v1/model" "github.com/polarismesh/specification/source/go/api/v1/service_manage" - apiservice "github.com/polarismesh/specification/source/go/api/v1/service_manage" "github.com/polarismesh/specification/source/go/api/v1/traffic_manage" - apitraffic "github.com/polarismesh/specification/source/go/api/v1/traffic_manage" - - api "github.com/polarismesh/polaris/common/api/v1" - "github.com/polarismesh/polaris/common/log" - "github.com/polarismesh/polaris/common/utils" ) // CreateRateLimits implements service.DiscoverServer. func (svr *Server) CreateRateLimits(ctx context.Context, - reqs []*traffic_manage.Rule) *service_manage.BatchWriteResponse { - if err := checkBatchRateLimits(reqs); err != nil { - return err - } - - batchRsp := api.NewBatchWriteResponse(apimodel.Code_ExecuteSuccess) - for i := range reqs { - // 参数校验 - // 参数校验 - if resp := checkRateLimitParams(reqs[i]); resp != nil { - api.Collect(batchRsp, resp) - continue - } - if resp := checkRateLimitRuleParams(ctx, reqs[i]); resp != nil { - api.Collect(batchRsp, resp) - continue - } - } - if !api.IsSuccess(batchRsp) { - return batchRsp - } - - return svr.nextSvr.CreateRateLimits(ctx, reqs) + request []*traffic_manage.Rule) *service_manage.BatchWriteResponse { + return svr.nextSvr.CreateRateLimits(ctx, request) } // DeleteRateLimits implements service.DiscoverServer. -func (svr *Server) DeleteRateLimits(ctx context.Context, reqs []*traffic_manage.Rule) *service_manage.BatchWriteResponse { - if err := checkBatchRateLimits(reqs); err != nil { - return err - } - batchRsp := api.NewBatchWriteResponse(apimodel.Code_ExecuteSuccess) - for i := range reqs { - // 参数校验 - resp := checkRevisedRateLimitParams(reqs[i]) - api.Collect(batchRsp, resp) - } - if !api.IsSuccess(batchRsp) { - return batchRsp - } - return svr.nextSvr.DeleteRateLimits(ctx, reqs) +func (svr *Server) DeleteRateLimits(ctx context.Context, + request []*traffic_manage.Rule) *service_manage.BatchWriteResponse { + return svr.nextSvr.DeleteRateLimits(ctx, request) } // EnableRateLimits implements service.DiscoverServer. func (svr *Server) EnableRateLimits(ctx context.Context, - reqs []*traffic_manage.Rule) *service_manage.BatchWriteResponse { - if err := checkBatchRateLimits(reqs); err != nil { - return err - } - batchRsp := api.NewBatchWriteResponse(apimodel.Code_ExecuteSuccess) - for i := range reqs { - // 参数校验 - resp := checkRevisedRateLimitParams(reqs[i]) - api.Collect(batchRsp, resp) - } - if !api.IsSuccess(batchRsp) { - return batchRsp - } - return svr.nextSvr.EnableRateLimits(ctx, reqs) + request []*traffic_manage.Rule) *service_manage.BatchWriteResponse { + return svr.nextSvr.EnableRateLimits(ctx, request) } // GetRateLimits implements service.DiscoverServer. @@ -102,104 +49,6 @@ func (svr *Server) GetRateLimits(ctx context.Context, } // UpdateRateLimits implements service.DiscoverServer. -func (svr *Server) UpdateRateLimits(ctx context.Context, reqs []*traffic_manage.Rule) *service_manage.BatchWriteResponse { - if err := checkBatchRateLimits(reqs); err != nil { - return err - } - batchRsp := api.NewBatchWriteResponse(apimodel.Code_ExecuteSuccess) - for i := range reqs { - // 参数校验 - if resp := checkRevisedRateLimitParams(reqs[i]); resp != nil { - api.Collect(batchRsp, resp) - continue - } - if resp := checkRateLimitRuleParams(ctx, reqs[i]); resp != nil { - api.Collect(batchRsp, resp) - continue - } - if resp := checkRateLimitParamsDbLen(reqs[i]); resp != nil { - api.Collect(batchRsp, resp) - continue - } - } - if !api.IsSuccess(batchRsp) { - return batchRsp - } - - return svr.nextSvr.UpdateRateLimits(ctx, reqs) -} - -// checkBatchRateLimits 检查批量请求的限流规则 -func checkBatchRateLimits(req []*apitraffic.Rule) *apiservice.BatchWriteResponse { - if len(req) == 0 { - return api.NewBatchWriteResponse(apimodel.Code_EmptyRequest) - } - - if len(req) > utils.MaxBatchSize { - return api.NewBatchWriteResponse(apimodel.Code_BatchSizeOverLimit) - } - - return nil -} - -// checkRateLimitParams 检查限流规则基础参数 -func checkRateLimitParams(req *apitraffic.Rule) *apiservice.Response { - if req == nil { - return api.NewRateLimitResponse(apimodel.Code_EmptyRequest, req) - } - if err := utils.CheckResourceName(req.GetNamespace()); err != nil { - return api.NewRateLimitResponse(apimodel.Code_InvalidNamespaceName, req) - } - if err := utils.CheckResourceName(req.GetService()); err != nil { - return api.NewRateLimitResponse(apimodel.Code_InvalidServiceName, req) - } - if resp := checkRateLimitParamsDbLen(req); nil != resp { - return resp - } - return nil -} - -// checkRateLimitParams 检查限流规则基础参数 -func checkRateLimitParamsDbLen(req *apitraffic.Rule) *apiservice.Response { - if err := utils.CheckDbStrFieldLen(req.GetService(), utils.MaxDbServiceNameLength); err != nil { - return api.NewRateLimitResponse(apimodel.Code_InvalidServiceName, req) - } - if err := utils.CheckDbStrFieldLen(req.GetNamespace(), utils.MaxDbServiceNamespaceLength); err != nil { - return api.NewRateLimitResponse(apimodel.Code_InvalidNamespaceName, req) - } - if err := utils.CheckDbStrFieldLen(req.GetName(), utils.MaxDbRateLimitName); err != nil { - return api.NewRateLimitResponse(apimodel.Code_InvalidRateLimitName, req) - } - return nil -} - -// checkRateLimitRuleParams 检查限流规则其他参数 -func checkRateLimitRuleParams(ctx context.Context, req *apitraffic.Rule) *apiservice.Response { - // 检查amounts是否有重复周期 - amounts := req.GetAmounts() - durations := make(map[time.Duration]bool) - for _, amount := range amounts { - d := amount.GetValidDuration() - duration, err := ptypes.Duration(d) - if err != nil { - log.Error(err.Error(), utils.RequestID(ctx)) - return api.NewRateLimitResponse(apimodel.Code_InvalidRateLimitAmounts, req) - } - durations[duration] = true - } - if len(amounts) != len(durations) { - return api.NewRateLimitResponse(apimodel.Code_InvalidRateLimitAmounts, req) - } - return nil -} - -// checkRevisedRateLimitParams 检查修改/删除限流规则基础参数 -func checkRevisedRateLimitParams(req *apitraffic.Rule) *apiservice.Response { - if req == nil { - return api.NewRateLimitResponse(apimodel.Code_EmptyRequest, req) - } - if req.GetId().GetValue() == "" { - return api.NewRateLimitResponse(apimodel.Code_InvalidRateLimitID, req) - } - return nil +func (svr *Server) UpdateRateLimits(ctx context.Context, request []*traffic_manage.Rule) *service_manage.BatchWriteResponse { + return svr.nextSvr.UpdateRateLimits(ctx, request) } diff --git a/service/interceptor/paramcheck/route_rule.go b/service/interceptor/paramcheck/route_rule.go index 46eab06eb..4da1fedad 100644 --- a/service/interceptor/paramcheck/route_rule.go +++ b/service/interceptor/paramcheck/route_rule.go @@ -19,61 +19,45 @@ package paramcheck import ( "context" - "strconv" - apimodel "github.com/polarismesh/specification/source/go/api/v1/model" "github.com/polarismesh/specification/source/go/api/v1/service_manage" - apiservice "github.com/polarismesh/specification/source/go/api/v1/service_manage" "github.com/polarismesh/specification/source/go/api/v1/traffic_manage" - apitraffic "github.com/polarismesh/specification/source/go/api/v1/traffic_manage" - - apiv1 "github.com/polarismesh/polaris/common/api/v1" - "github.com/polarismesh/polaris/common/log" - "github.com/polarismesh/polaris/common/model" - "github.com/polarismesh/polaris/common/utils" -) - -var ( - // RoutingConfigV2FilterAttrs router config filter attrs - RoutingConfigV2FilterAttrs = map[string]bool{ - "id": true, - "name": true, - "service": true, - "namespace": true, - "source_service": true, - "destination_service": true, - "source_namespace": true, - "destination_namespace": true, - "enable": true, - "offset": true, - "limit": true, - "order_field": true, - "order_type": true, - } ) // UpdateRoutingConfigs implements service.DiscoverServer. -// Deprecated: not support from 1.19.x func (svr *Server) UpdateRoutingConfigs(ctx context.Context, req []*traffic_manage.Routing) *service_manage.BatchWriteResponse { return svr.nextSvr.UpdateRoutingConfigs(ctx, req) } +// UpdateRoutingConfigsV2 implements service.DiscoverServer. +func (svr *Server) UpdateRoutingConfigsV2(ctx context.Context, req []*traffic_manage.RouteRule) *service_manage.BatchWriteResponse { + return svr.nextSvr.UpdateRoutingConfigsV2(ctx, req) +} + +// QueryRoutingConfigsV2 implements service.DiscoverServer. +func (svr *Server) QueryRoutingConfigsV2(ctx context.Context, query map[string]string) *service_manage.BatchQueryResponse { + return svr.nextSvr.QueryRoutingConfigsV2(ctx, query) +} + // GetRoutingConfigs implements service.DiscoverServer. -// Deprecated: not support from 1.19.x func (svr *Server) GetRoutingConfigs(ctx context.Context, query map[string]string) *service_manage.BatchQueryResponse { return svr.nextSvr.GetRoutingConfigs(ctx, query) } +// EnableRoutings implements service.DiscoverServer. +func (svr *Server) EnableRoutings(ctx context.Context, + req []*traffic_manage.RouteRule) *service_manage.BatchWriteResponse { + return svr.nextSvr.EnableRoutings(ctx, req) +} + // CreateRoutingConfigs implements service.DiscoverServer. -// Deprecated: not support from 1.19.x func (svr *Server) CreateRoutingConfigs(ctx context.Context, req []*traffic_manage.Routing) *service_manage.BatchWriteResponse { return svr.nextSvr.CreateRoutingConfigs(ctx, req) } // DeleteRoutingConfigs implements service.DiscoverServer. -// Deprecated: not support from 1.19.x func (svr *Server) DeleteRoutingConfigs(ctx context.Context, req []*traffic_manage.Routing) *service_manage.BatchWriteResponse { return svr.nextSvr.DeleteRoutingConfigs(ctx, req) @@ -82,211 +66,11 @@ func (svr *Server) DeleteRoutingConfigs(ctx context.Context, // CreateRoutingConfigsV2 implements service.DiscoverServer. func (svr *Server) CreateRoutingConfigsV2(ctx context.Context, req []*traffic_manage.RouteRule) *service_manage.BatchWriteResponse { - if err := checkBatchRoutingConfigV2(req); err != nil { - return err - } - batchRsp := apiv1.NewBatchWriteResponse(apimodel.Code_ExecuteSuccess) - for _, item := range req { - if resp := checkRoutingConfigV2(item); resp != nil { - apiv1.Collect(batchRsp, resp) - } - } - if !apiv1.IsSuccess(batchRsp) { - return batchRsp - } return svr.nextSvr.CreateRoutingConfigsV2(ctx, req) } -// EnableRoutings implements service.DiscoverServer. -func (svr *Server) EnableRoutings(ctx context.Context, - req []*traffic_manage.RouteRule) *service_manage.BatchWriteResponse { - if err := checkBatchRoutingConfigV2(req); err != nil { - return err - } - batchRsp := apiv1.NewBatchWriteResponse(apimodel.Code_ExecuteSuccess) - for _, item := range req { - if resp := checkRoutingConfigIDV2(item); resp != nil { - apiv1.Collect(batchRsp, resp) - } - } - if !apiv1.IsSuccess(batchRsp) { - return batchRsp - } - return svr.nextSvr.EnableRoutings(ctx, req) -} - -// UpdateRoutingConfigsV2 implements service.DiscoverServer. -func (svr *Server) UpdateRoutingConfigsV2(ctx context.Context, - req []*traffic_manage.RouteRule) *service_manage.BatchWriteResponse { - if err := checkBatchRoutingConfigV2(req); err != nil { - return err - } - batchRsp := apiv1.NewBatchWriteResponse(apimodel.Code_ExecuteSuccess) - for _, item := range req { - if resp := checkUpdateRoutingConfigV2(item); resp != nil { - apiv1.Collect(batchRsp, resp) - } - } - if !apiv1.IsSuccess(batchRsp) { - return batchRsp - } - return svr.nextSvr.UpdateRoutingConfigsV2(ctx, req) -} - -// QueryRoutingConfigsV2 implements service.DiscoverServer. -func (svr *Server) QueryRoutingConfigsV2(ctx context.Context, - query map[string]string) *service_manage.BatchQueryResponse { - - offset, limit, err := utils.ParseOffsetAndLimit(query) - if err != nil { - return apiv1.NewBatchQueryResponse(apimodel.Code_InvalidParameter) - } - - filter := make(map[string]string) - for key, value := range query { - if _, ok := RoutingConfigV2FilterAttrs[key]; !ok { - log.Errorf("[Routing][V2][Query] attribute(%s) is not allowed", key) - return apiv1.NewBatchQueryResponse(apimodel.Code_InvalidParameter) - } - filter[key] = value - } - filter["offset"] = strconv.FormatUint(uint64(offset), 10) - filter["limit"] = strconv.FormatUint(uint64(limit), 10) - - return svr.nextSvr.QueryRoutingConfigsV2(ctx, filter) -} - // DeleteRoutingConfigsV2 implements service.DiscoverServer. func (svr *Server) DeleteRoutingConfigsV2(ctx context.Context, req []*traffic_manage.RouteRule) *service_manage.BatchWriteResponse { - if err := checkBatchRoutingConfigV2(req); err != nil { - return err - } - batchRsp := apiv1.NewBatchWriteResponse(apimodel.Code_ExecuteSuccess) - for _, item := range req { - if resp := checkRoutingConfigIDV2(item); resp != nil { - apiv1.Collect(batchRsp, resp) - } - } - if !apiv1.IsSuccess(batchRsp) { - return batchRsp - } return svr.nextSvr.DeleteRoutingConfigsV2(ctx, req) } - -// checkBatchRoutingConfig Check batch request -func checkBatchRoutingConfigV2(req []*apitraffic.RouteRule) *apiservice.BatchWriteResponse { - if len(req) == 0 { - return apiv1.NewBatchWriteResponse(apimodel.Code_EmptyRequest) - } - - if len(req) > utils.MaxBatchSize { - return apiv1.NewBatchWriteResponse(apimodel.Code_BatchSizeOverLimit) - } - - return nil -} - -// checkRoutingConfig Check the validity of the basic parameter of the routing configuration -func checkRoutingConfigV2(req *apitraffic.RouteRule) *apiservice.Response { - if req == nil { - return apiv1.NewRouterResponse(apimodel.Code_EmptyRequest, req) - } - - if err := checkRoutingNameAndNamespace(req); err != nil { - return err - } - - if err := checkRoutingConfigPriorityV2(req); err != nil { - return err - } - - if err := checkRoutingPolicyV2(req); err != nil { - return err - } - - return nil -} - -// checkUpdateRoutingConfigV2 Check the validity of the basic parameter of the routing configuration -func checkUpdateRoutingConfigV2(req *apitraffic.RouteRule) *apiservice.Response { - if resp := checkRoutingConfigIDV2(req); resp != nil { - return resp - } - - if err := checkRoutingNameAndNamespace(req); err != nil { - return err - } - - if err := checkRoutingConfigPriorityV2(req); err != nil { - return err - } - - if err := checkRoutingPolicyV2(req); err != nil { - return err - } - - return nil -} - -func checkRoutingNameAndNamespace(req *apitraffic.RouteRule) *apiservice.Response { - if err := utils.CheckDbStrFieldLen(utils.NewStringValue(req.GetName()), utils.MaxRuleName); err != nil { - return apiv1.NewRouterResponse(apimodel.Code_InvalidRoutingName, req) - } - - if err := utils.CheckDbStrFieldLen(utils.NewStringValue(req.GetNamespace()), - utils.MaxDbServiceNamespaceLength); err != nil { - return apiv1.NewRouterResponse(apimodel.Code_InvalidNamespaceName, req) - } - - return nil -} - -func checkRoutingConfigIDV2(req *apitraffic.RouteRule) *apiservice.Response { - if req == nil { - return apiv1.NewRouterResponse(apimodel.Code_EmptyRequest, req) - } - - if req.Id == "" { - return apiv1.NewResponse(apimodel.Code_InvalidRoutingID) - } - - return nil -} - -func checkRoutingConfigPriorityV2(req *apitraffic.RouteRule) *apiservice.Response { - if req == nil { - return apiv1.NewRouterResponse(apimodel.Code_EmptyRequest, req) - } - - if req.Priority > 10 { - return apiv1.NewResponse(apimodel.Code_InvalidRoutingPriority) - } - - return nil -} - -func checkRoutingPolicyV2(req *apitraffic.RouteRule) *apiservice.Response { - if req == nil { - return apiv1.NewRouterResponse(apimodel.Code_EmptyRequest, req) - } - - if req.GetRoutingPolicy() != apitraffic.RoutingPolicy_RulePolicy { - return apiv1.NewRouterResponse(apimodel.Code_InvalidRoutingPolicy, req) - } - - // Automatically supplement @Type attribute according to Policy - if req.RoutingConfig.TypeUrl == "" { - if req.GetRoutingPolicy() == apitraffic.RoutingPolicy_RulePolicy { - req.RoutingConfig.TypeUrl = model.RuleRoutingTypeUrl - } - if req.GetRoutingPolicy() == apitraffic.RoutingPolicy_MetadataPolicy { - req.RoutingConfig.TypeUrl = model.MetaRoutingTypeUrl - } - if req.GetRoutingPolicy() == apitraffic.RoutingPolicy_NearbyPolicy { - req.RoutingConfig.TypeUrl = model.NearbyRoutingTypeUrl - } - } - - return nil -} diff --git a/service/interceptor/paramcheck/server.go b/service/interceptor/paramcheck/server.go index 680e14bf1..f46a54065 100644 --- a/service/interceptor/paramcheck/server.go +++ b/service/interceptor/paramcheck/server.go @@ -35,10 +35,9 @@ type Server struct { ratelimit plugin.Ratelimit } -func NewServer(nextSvr service.DiscoverServer, s store.Store) service.DiscoverServer { +func NewServer(nextSvr service.DiscoverServer) service.DiscoverServer { proxy := &Server{ nextSvr: nextSvr, - storage: s, } // 获取限流插件 proxy.ratelimit = plugin.GetRatelimit() diff --git a/service/interceptor/paramcheck/service.go b/service/interceptor/paramcheck/service.go index 2dc917322..d46394f22 100644 --- a/service/interceptor/paramcheck/service.go +++ b/service/interceptor/paramcheck/service.go @@ -20,7 +20,6 @@ package paramcheck import ( "context" "errors" - "strconv" "strings" apimodel "github.com/polarismesh/specification/source/go/api/v1/model" @@ -176,14 +175,6 @@ func (svr *Server) GetServices(ctx context.Context, query map[string]string) *se } } - // 判断offset和limit是否为int,并从filters清除offset/limit参数 - offset, limit, err := utils.ParseOffsetAndLimit(query) - if err != nil { - return api.NewBatchQueryResponse(apimodel.Code_InvalidParameter) - } - query["offset"] = strconv.FormatUint(uint64(offset), 10) - query["limit"] = strconv.FormatUint(uint64(limit), 10) - return svr.nextSvr.GetServices(ctx, query) } diff --git a/service/interceptor/paramcheck/service_alias.go b/service/interceptor/paramcheck/service_alias.go index 7435cfbec..c01f58de8 100644 --- a/service/interceptor/paramcheck/service_alias.go +++ b/service/interceptor/paramcheck/service_alias.go @@ -40,8 +40,12 @@ func (svr *Server) CreateServiceAlias(ctx context.Context, // DeleteServiceAliases implements service.DiscoverServer. func (svr *Server) DeleteServiceAliases(ctx context.Context, req []*service_manage.ServiceAlias) *service_manage.BatchWriteResponse { - if checkError := checkBatchAlias(req); checkError != nil { - return checkError + if len(req) == 0 { + return api.NewBatchWriteResponse(apimodel.Code_EmptyRequest) + } + + if len(req) > utils.MaxBatchSize { + return api.NewBatchWriteResponse(apimodel.Code_BatchSizeOverLimit) } batchRsp := api.NewBatchWriteResponse(apimodel.Code_ExecuteSuccess) @@ -70,18 +74,6 @@ func (svr *Server) GetServiceAliases(ctx context.Context, return svr.nextSvr.GetServiceAliases(ctx, query) } -func checkBatchAlias(req []*apiservice.ServiceAlias) *apiservice.BatchWriteResponse { - if len(req) == 0 { - return api.NewBatchWriteResponse(apimodel.Code_EmptyRequest) - } - - if len(req) > utils.MaxBatchSize { - return api.NewBatchWriteResponse(apimodel.Code_BatchSizeOverLimit) - } - - return nil -} - // checkCreateServiceAliasReq 检查别名请求 func checkCreateServiceAliasReq(ctx context.Context, req *apiservice.ServiceAlias) *apiservice.Response { response, done := preCheckAlias(req) diff --git a/service/interceptor/register.go b/service/interceptor/register.go index 87a8ac433..399ba8294 100644 --- a/service/interceptor/register.go +++ b/service/interceptor/register.go @@ -22,30 +22,27 @@ import ( "github.com/polarismesh/polaris/service" service_auth "github.com/polarismesh/polaris/service/interceptor/auth" "github.com/polarismesh/polaris/service/interceptor/paramcheck" - "github.com/polarismesh/polaris/store" ) func init() { - err := service.RegisterServerProxy("paramcheck", func(pre service.DiscoverServer, - s store.Store) (service.DiscoverServer, error) { - return paramcheck.NewServer(pre, s), nil + err := service.RegisterServerProxy("paramcheck", func(pre service.DiscoverServer) (service.DiscoverServer, error) { + return paramcheck.NewServer(pre), nil }) if err != nil { panic(err) } - err = service.RegisterServerProxy("auth", func(pre service.DiscoverServer, - s store.Store) (service.DiscoverServer, error) { - userSvr, err := auth.GetUserServer() + err = service.RegisterServerProxy("auth", func(pre service.DiscoverServer) (service.DiscoverServer, error) { + userMgn, err := auth.GetUserServer() if err != nil { return nil, err } - policySvr, err := auth.GetStrategyServer() + strategyMgn, err := auth.GetStrategyServer() if err != nil { return nil, err } - return service_auth.NewServer(pre, userSvr, policySvr), nil + return service_auth.NewServerAuthAbility(pre, userMgn, strategyMgn), nil }) if err != nil { panic(err) diff --git a/service/lane.go b/service/lane.go deleted file mode 100644 index 78b08f39c..000000000 --- a/service/lane.go +++ /dev/null @@ -1,295 +0,0 @@ -/** - * Tencent is pleased to support the open source community by making Polaris available. - * - * Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. - * - * Licensed under the BSD 3-Clause License (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://opensource.org/licenses/BSD-3-Clause - * - * Unless required by applicable law or agreed to in writing, software distributed - * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR - * CONDITIONS OF ANY KIND, either express or implied. See the License for the - * specific language governing permissions and limitations under the License. - */ - -package service - -import ( - "context" - "fmt" - "time" - - "github.com/golang/protobuf/jsonpb" - "github.com/golang/protobuf/proto" - apimodel "github.com/polarismesh/specification/source/go/api/v1/model" - "github.com/polarismesh/specification/source/go/api/v1/security" - apiservice "github.com/polarismesh/specification/source/go/api/v1/service_manage" - apitraffic "github.com/polarismesh/specification/source/go/api/v1/traffic_manage" - "go.uber.org/zap" - "google.golang.org/protobuf/types/known/anypb" - "google.golang.org/protobuf/types/known/wrapperspb" - - cachetypes "github.com/polarismesh/polaris/cache/api" - api "github.com/polarismesh/polaris/common/api/v1" - "github.com/polarismesh/polaris/common/model" - authcommon "github.com/polarismesh/polaris/common/model/auth" - commonstore "github.com/polarismesh/polaris/common/store" - "github.com/polarismesh/polaris/common/utils" -) - -// CreateLaneGroups 批量创建泳道组 -func (s *Server) CreateLaneGroups(ctx context.Context, req []*apitraffic.LaneGroup) *apiservice.BatchWriteResponse { - responses := api.NewBatchWriteResponse(apimodel.Code_ExecuteSuccess) - for i := range req { - resp := s.CreateLaneGroup(ctx, req[i]) - api.Collect(responses, resp) - } - return api.FormatBatchWriteResponse(responses) -} - -// CreateLaneGroup 创建泳道组 -func (s *Server) CreateLaneGroup(ctx context.Context, req *apitraffic.LaneGroup) *apiservice.Response { - tx, err := s.storage.StartTx() - if err != nil { - log.Error("[Service][Lane] open store transaction fail", utils.RequestID(ctx), zap.Error(err)) - return api.NewResponse(commonstore.StoreCode2APICode(err)) - } - defer func() { - _ = tx.Rollback() - }() - - saveVal, err := s.storage.LockLaneGroup(tx, req.GetName()) - if err != nil { - log.Error("[Service][Lane] lock one lane_group", utils.RequestID(ctx), - zap.String("name", req.GetName()), zap.Error(err)) - return api.NewResponse(commonstore.StoreCode2APICode(err)) - } - if saveVal != nil { - return api.NewResponse(apimodel.Code_ExistedResource) - } - saveData := &model.LaneGroup{} - if err := saveData.FromSpec(req); err != nil { - log.Error("[Service][Lane] create lane_group transfer spec to model", utils.RequestID(ctx), zap.Error(err)) - return api.NewResponse(apimodel.Code_ExecuteException) - } - saveData.ID = utils.DefaultString(req.GetId(), utils.NewUUID()) - saveData.Revision = utils.DefaultString(req.GetRevision(), utils.NewUUID()) - - // 由于这里是新建,所以需要手动再把两个 flag 字段设置为 true 状态 - for i := range saveData.LaneRules { - saveData.LaneRules[i].SetAddFlag(true) - saveData.LaneRules[i].SetChangeEnable(true) - } - - if err := s.storage.AddLaneGroup(tx, saveData); err != nil { - log.Error("[Service][Lane] save lane_group", utils.RequestID(ctx), zap.String("name", saveData.Name), zap.Error(err)) - return api.NewResponse(commonstore.StoreCode2APICode(err)) - } - req.Id = saveData.ID - - if err := tx.Commit(); err != nil { - log.Error("[Service][Lane] commit store transaction fail", utils.RequestID(ctx), zap.Error(err)) - return api.NewResponse(commonstore.StoreCode2APICode(err)) - } - - s.RecordHistory(ctx, laneGroupRecordEntry(ctx, req, saveData, model.OCreate)) - _ = s.afterRuleResource(ctx, model.RRouting, authcommon.ResourceEntry{ - ID: req.GetId(), - Type: security.ResourceType_LaneRules, - }, false) - return api.NewAnyDataResponse(apimodel.Code_ExecuteSuccess, req) -} - -// UpdateLaneGroups 批量更新泳道组 -func (s *Server) UpdateLaneGroups(ctx context.Context, req []*apitraffic.LaneGroup) *apiservice.BatchWriteResponse { - responses := api.NewBatchWriteResponse(apimodel.Code_ExecuteSuccess) - for i := range req { - resp := s.UpdateLaneGroup(ctx, req[i]) - api.Collect(responses, resp) - } - return api.FormatBatchWriteResponse(responses) -} - -// UpdateLaneGroup 更新泳道组 -func (s *Server) UpdateLaneGroup(ctx context.Context, req *apitraffic.LaneGroup) *apiservice.Response { - tx, err := s.storage.StartTx() - if err != nil { - log.Error("[Service][Lane] open store transaction fail", utils.RequestID(ctx), zap.Error(err)) - return api.NewResponse(commonstore.StoreCode2APICode(err)) - } - defer func() { - _ = tx.Rollback() - }() - - saveData, err := s.storage.LockLaneGroup(tx, req.GetName()) - if err != nil { - log.Error("[Service][Lane] lock one lane_group", utils.RequestID(ctx), - zap.String("name", req.GetName()), zap.Error(err)) - return api.NewResponse(commonstore.StoreCode2APICode(err)) - } - if saveData == nil { - log.Error("[Service][Lane] lock one lane_group not found", utils.RequestID(ctx), - zap.String("name", req.GetName())) - return api.NewResponse(apimodel.Code_NotFoundResource) - } - - needUpdate, err := updateLaneGroupAttribute(req, saveData) - if err != nil { - log.Error("[Service][Lane] update lane_group transfer spec to model", utils.RequestID(ctx), zap.Error(err)) - return api.NewResponse(apimodel.Code_ExecuteException) - } - if !needUpdate { - return api.NewResponse(apimodel.Code_NoNeedUpdate) - } - - saveData.Revision = utils.DefaultString(req.GetRevision(), utils.NewUUID()) - if err := s.storage.UpdateLaneGroup(tx, saveData); err != nil { - log.Error("[Service][Lane] update lane_group", utils.RequestID(ctx), zap.String("name", saveData.Name), zap.Error(err)) - return api.NewResponse(commonstore.StoreCode2APICode(err)) - } - req.Id = saveData.ID - - if err := tx.Commit(); err != nil { - log.Error("[Service][Lane] commit store transaction fail", utils.RequestID(ctx), zap.Error(err)) - return api.NewResponse(commonstore.StoreCode2APICode(err)) - } - - s.RecordHistory(ctx, laneGroupRecordEntry(ctx, req, saveData, model.OUpdate)) - return api.NewAnyDataResponse(apimodel.Code_ExecuteSuccess, req) -} - -// DeleteLaneGroups 批量删除泳道组 -func (s *Server) DeleteLaneGroups(ctx context.Context, req []*apitraffic.LaneGroup) *apiservice.BatchWriteResponse { - responses := api.NewBatchWriteResponse(apimodel.Code_ExecuteSuccess) - for i := range req { - resp := s.DeleteLaneGroup(ctx, req[i]) - api.Collect(responses, resp) - } - return api.FormatBatchWriteResponse(responses) -} - -// DeleteLaneGroup 删除泳道组 -func (s *Server) DeleteLaneGroup(ctx context.Context, req *apitraffic.LaneGroup) *apiservice.Response { - var saveData *model.LaneGroup - var err error - if req.GetId() != "" { - saveData, err = s.storage.GetLaneGroupByID(req.GetId()) - } else { - saveData, err = s.storage.GetLaneGroup(req.GetName()) - } - if err != nil { - log.Error("[Server][LaneGroup] get target lane_group when delete", zap.String("id", req.GetId()), - zap.String("name", req.GetName()), utils.RequestID(ctx), zap.Error(err)) - return api.NewResponse(commonstore.StoreCode2APICode(err)) - } - if saveData == nil { - log.Info("[Server][LaneGroup] delete target lane_group but not found", zap.String("id", req.GetId()), - zap.String("name", req.GetName()), utils.RequestID(ctx)) - return api.NewResponse(apimodel.Code_ExecuteSuccess) - } - - saveData.Revision = utils.DefaultString(req.GetRevision(), utils.NewUUID()) - if err := s.storage.DeleteLaneGroup(saveData.ID); err != nil { - return api.NewResponse(commonstore.StoreCode2APICode(err)) - } - req.Id = saveData.ID - s.RecordHistory(ctx, laneGroupRecordEntry(ctx, req, saveData, model.ODelete)) - _ = s.afterRuleResource(ctx, model.RRouting, authcommon.ResourceEntry{ - ID: req.GetId(), - Type: security.ResourceType_LaneRules, - }, true) - return api.NewAnyDataResponse(apimodel.Code_ExecuteSuccess, req) -} - -// GetLaneGroups 查询泳道组列表 -func (s *Server) GetLaneGroups(ctx context.Context, filter map[string]string) *apiservice.BatchQueryResponse { - offset, limit, _ := utils.ParseOffsetAndLimit(filter) - total, ret, err := s.caches.LaneRule().Query(ctx, &cachetypes.LaneGroupArgs{ - Filter: filter, - Offset: offset, - Limit: limit, - }) - if err != nil { - log.Error("[Server][LaneGroup][Query] get lane_groups from store", utils.RequestID(ctx), zap.Error(err)) - return api.NewBatchQueryResponse(commonstore.StoreCode2APICode(err)) - } - - rsp := api.NewBatchQueryResponse(apimodel.Code_ExecuteSuccess) - rsp.Amount = wrapperspb.UInt32(total) - rsp.Size = wrapperspb.UInt32(uint32(len(ret))) - rsp.Data = make([]*anypb.Any, 0, len(ret)) - - for i := range ret { - data, err := ret[i].ToProto() - if err != nil { - log.Error("[Server][LaneGroup][Query] lane_group convert to proto", utils.RequestID(ctx), zap.Error(err)) - return api.NewBatchQueryResponse(apimodel.Code_ExecuteException) - } - anyData, err := anypb.New(proto.MessageV2(data.Proto)) - if err != nil { - log.Error("[Server][LaneGroup][Query] lane_group convert to anypb", utils.RequestID(ctx), zap.Error(err)) - return api.NewBatchQueryResponse(apimodel.Code_ExecuteException) - } - rsp.Data = append(rsp.Data, anyData) - } - return rsp -} - -// GetAllLaneGroups Query all router_rule rules -func (s *Server) GetAllLaneGroups(ctx context.Context) *apiservice.BatchQueryResponse { - return nil -} - -func updateLaneGroupAttribute(req *apitraffic.LaneGroup, saveData *model.LaneGroup) (bool, error) { - updateData := &model.LaneGroup{} - if err := updateData.FromSpec(req); err != nil { - return false, err - } - - saveData.Description = updateData.Description - saveData.Rule = updateData.Rule - - for ruleId := range updateData.LaneRules { - // 默认所有规则 enable 状态都出现了变更 - updateData.LaneRules[ruleId].SetChangeEnable(true) - updateData.LaneRules[ruleId].SetAddFlag(false) - } - - for ruleId := range updateData.LaneRules { - newRule := updateData.LaneRules[ruleId] - oldRule, ok := saveData.LaneRules[ruleId] - if !ok { - // 在原来的规则当中不存在,认为是新增的 - newRule.SetAddFlag(true) - continue - } - newRule.Revision = utils.DefaultString(newRule.Revision, utils.NewUUID()) - // 如果 Enable 字段比较发现没有变化,则设置为 nil - if oldRule.Enable == newRule.Enable { - newRule.SetChangeEnable(false) - } - } - saveData.LaneRules = updateData.LaneRules - return true, nil -} - -// laneGroupRecordEntry 转换为鉴权策略的记录结构体 -func laneGroupRecordEntry(ctx context.Context, req *apitraffic.LaneGroup, md *model.LaneGroup, - operationType model.OperationType) *model.RecordEntry { - - marshaler := jsonpb.Marshaler{} - detail, _ := marshaler.MarshalToString(req) - - entry := &model.RecordEntry{ - ResourceType: model.RLaneGroup, - ResourceName: fmt.Sprintf("%s(%s)", md.Name, md.ID), - OperationType: operationType, - Operator: utils.ParseOperator(ctx), - Detail: detail, - HappenTime: time.Now(), - } - return entry -} diff --git a/service/namespace_test.go b/service/namespace_test.go index abd0b5999..6451492ee 100644 --- a/service/namespace_test.go +++ b/service/namespace_test.go @@ -158,7 +158,7 @@ func TestRemoveNamespace(t *testing.T) { } defer discoverSuit.cleanServiceName(serviceReq.GetName().GetValue(), serviceReq.GetNamespace().GetValue()) - resp := discoverSuit.NamespaceServer().DeleteNamespaces(discoverSuit.DefaultCtx, []*apimodel.Namespace{namespaceResp}) + resp := discoverSuit.NamespaceServer().DeleteNamespace(discoverSuit.DefaultCtx, namespaceResp) if resp.GetCode().GetValue() != uint32(apimodel.Code_NamespaceExistedServices) { t.Fatalf("error: %s", resp.GetInfo().GetValue()) } diff --git a/service/ratelimit_config.go b/service/ratelimit_config.go index 37004b7e8..be4bf0734 100644 --- a/service/ratelimit_config.go +++ b/service/ratelimit_config.go @@ -25,16 +25,14 @@ import ( "time" "github.com/gogo/protobuf/jsonpb" + "github.com/golang/protobuf/ptypes" apimodel "github.com/polarismesh/specification/source/go/api/v1/model" - "github.com/polarismesh/specification/source/go/api/v1/security" apiservice "github.com/polarismesh/specification/source/go/api/v1/service_manage" apitraffic "github.com/polarismesh/specification/source/go/api/v1/traffic_manage" - "go.uber.org/zap" cachetypes "github.com/polarismesh/polaris/cache/api" api "github.com/polarismesh/polaris/common/api/v1" "github.com/polarismesh/polaris/common/model" - authcommon "github.com/polarismesh/polaris/common/model/auth" commonstore "github.com/polarismesh/polaris/common/store" commontime "github.com/polarismesh/polaris/common/time" "github.com/polarismesh/polaris/common/utils" @@ -58,6 +56,10 @@ var ( // CreateRateLimits 批量创建限流规则 func (s *Server) CreateRateLimits(ctx context.Context, request []*apitraffic.Rule) *apiservice.BatchWriteResponse { + if err := checkBatchRateLimits(request); err != nil { + return err + } + responses := api.NewBatchWriteResponse(apimodel.Code_ExecuteSuccess) for _, rateLimit := range request { response := s.CreateRateLimit(ctx, rateLimit) @@ -68,34 +70,45 @@ func (s *Server) CreateRateLimits(ctx context.Context, request []*apitraffic.Rul // CreateRateLimit 创建限流规则 func (s *Server) CreateRateLimit(ctx context.Context, req *apitraffic.Rule) *apiservice.Response { + requestID := utils.ParseRequestID(ctx) + + // 参数校验 + if resp := checkRateLimitParams(req); resp != nil { + return resp + } + if resp := checkRateLimitRuleParams(requestID, req); resp != nil { + return resp + } + // 构造底层数据结构 data, err := api2RateLimit(req, nil) if err != nil { - log.Error(err.Error(), utils.RequestID(ctx)) + log.Error(err.Error(), utils.ZapRequestID(requestID)) return api.NewRateLimitResponse(apimodel.Code_ParseRateLimitException, req) } // 存储层操作 if err := s.storage.CreateRateLimit(data); err != nil { - log.Error(err.Error(), utils.RequestID(ctx)) + log.Error(err.Error(), utils.ZapRequestID(requestID)) return wrapperRateLimitStoreResponse(req, err) } msg := fmt.Sprintf("create rate limit rule: id=%v, namespace=%v, service=%v, name=%v", data.ID, req.GetNamespace().GetValue(), req.GetService().GetValue(), req.GetName().GetValue()) - log.Info(msg, utils.RequestID(ctx)) + log.Info(msg, utils.ZapRequestID(requestID)) s.RecordHistory(ctx, rateLimitRecordEntry(ctx, req, data, model.OCreate)) - _ = s.afterRuleResource(ctx, model.RRouting, authcommon.ResourceEntry{ - ID: req.GetId().GetValue(), - Type: security.ResourceType_RateLimitRules, - }, false) + req.Id = utils.NewStringValue(data.ID) return api.NewRateLimitResponse(apimodel.Code_ExecuteSuccess, req) } // DeleteRateLimits 批量删除限流规则 func (s *Server) DeleteRateLimits(ctx context.Context, request []*apitraffic.Rule) *apiservice.BatchWriteResponse { + if err := checkBatchRateLimits(request); err != nil { + return err + } + responses := api.NewBatchWriteResponse(apimodel.Code_ExecuteSuccess) for _, entry := range request { resp := s.DeleteRateLimit(ctx, entry) @@ -106,8 +119,16 @@ func (s *Server) DeleteRateLimits(ctx context.Context, request []*apitraffic.Rul // DeleteRateLimit 删除单个限流规则 func (s *Server) DeleteRateLimit(ctx context.Context, req *apitraffic.Rule) *apiservice.Response { + requestID := utils.ParseRequestID(ctx) + platformID := utils.ParsePlatformID(ctx) + + // 参数校验 + if resp := checkRevisedRateLimitParams(req); resp != nil { + return resp + } + // 检查限流规则是否存在 - rateLimit, resp := s.checkRateLimitExisted(ctx, req.GetId().GetValue(), req) + rateLimit, resp := s.checkRateLimitExisted(req.GetId().GetValue(), requestID, req) if resp != nil { if resp.GetCode().GetValue() == uint32(apimodel.Code_NotFoundRateLimit) { return api.NewRateLimitResponse(apimodel.Code_ExecuteSuccess, req) @@ -120,24 +141,23 @@ func (s *Server) DeleteRateLimit(ctx context.Context, req *apitraffic.Rule) *api // 存储层操作 if err := s.storage.DeleteRateLimit(rateLimit); err != nil { - log.Error(err.Error(), utils.RequestID(ctx)) + log.Error(err.Error(), utils.ZapRequestID(requestID), utils.ZapPlatformID(platformID)) return wrapperRateLimitStoreResponse(req, err) } msg := fmt.Sprintf("delete rate limit rule: id=%v, namespace=%v, service=%v, name=%v", rateLimit.ID, req.GetNamespace().GetValue(), req.GetService().GetValue(), rateLimit.Labels) - log.Info(msg, utils.RequestID(ctx)) + log.Info(msg, utils.ZapRequestID(requestID), utils.ZapPlatformID(platformID)) s.RecordHistory(ctx, rateLimitRecordEntry(ctx, req, rateLimit, model.ODelete)) - _ = s.afterRuleResource(ctx, model.RRouting, authcommon.ResourceEntry{ - ID: req.GetId().GetValue(), - Type: security.ResourceType_RateLimitRules, - }, true) return api.NewRateLimitResponse(apimodel.Code_ExecuteSuccess, req) } func (s *Server) EnableRateLimits(ctx context.Context, request []*apitraffic.Rule) *apiservice.BatchWriteResponse { + if err := checkBatchRateLimits(request); err != nil { + return err + } responses := api.NewBatchWriteResponse(apimodel.Code_ExecuteSuccess) for _, entry := range request { response := s.EnableRateLimit(ctx, entry) @@ -148,8 +168,16 @@ func (s *Server) EnableRateLimits(ctx context.Context, request []*apitraffic.Rul // EnableRateLimit 启用限流规则 func (s *Server) EnableRateLimit(ctx context.Context, req *apitraffic.Rule) *apiservice.Response { + requestID := utils.ParseRequestID(ctx) + platformID := utils.ParsePlatformID(ctx) + + // 参数校验 + if resp := checkRevisedRateLimitParams(req); resp != nil { + return resp + } + // 检查限流规则是否存在 - data, resp := s.checkRateLimitExisted(ctx, req.GetId().GetValue(), req) + data, resp := s.checkRateLimitExisted(req.GetId().GetValue(), requestID, req) if resp != nil { return resp } @@ -162,13 +190,13 @@ func (s *Server) EnableRateLimit(ctx context.Context, req *apitraffic.Rule) *api rateLimit.Revision = utils.NewUUID() if err := s.storage.EnableRateLimit(rateLimit); err != nil { - log.Error(err.Error(), utils.RequestID(ctx)) + log.Error(err.Error(), utils.ZapRequestID(requestID), utils.ZapPlatformID(platformID)) return wrapperRateLimitStoreResponse(req, err) } msg := fmt.Sprintf("enable rate limit: id=%v, disable=%v", rateLimit.ID, rateLimit.Disable) - log.Info(msg, utils.RequestID(ctx)) + log.Info(msg, utils.ZapRequestID(requestID), utils.ZapPlatformID(platformID)) s.RecordHistory(ctx, rateLimitRecordEntry(ctx, req, rateLimit, model.OUpdateEnable)) return api.NewRateLimitResponse(apimodel.Code_ExecuteSuccess, req) @@ -176,6 +204,10 @@ func (s *Server) EnableRateLimit(ctx context.Context, req *apitraffic.Rule) *api // UpdateRateLimits 批量更新限流规则 func (s *Server) UpdateRateLimits(ctx context.Context, request []*apitraffic.Rule) *apiservice.BatchWriteResponse { + if err := checkBatchRateLimits(request); err != nil { + return err + } + responses := api.NewBatchWriteResponse(apimodel.Code_ExecuteSuccess) for _, entry := range request { response := s.UpdateRateLimit(ctx, entry) @@ -186,8 +218,20 @@ func (s *Server) UpdateRateLimits(ctx context.Context, request []*apitraffic.Rul // UpdateRateLimit 更新限流规则 func (s *Server) UpdateRateLimit(ctx context.Context, req *apitraffic.Rule) *apiservice.Response { + requestID := utils.ParseRequestID(ctx) + // 参数校验 + if resp := checkRevisedRateLimitParams(req); resp != nil { + return resp + } + if resp := checkRateLimitRuleParams(requestID, req); resp != nil { + return resp + } + if resp := checkRateLimitParamsDbLen(req); resp != nil { + return resp + } + // 检查限流规则是否存在 - data, resp := s.checkRateLimitExisted(ctx, req.GetId().GetValue(), req) + data, resp := s.checkRateLimitExisted(req.GetId().GetValue(), requestID, req) if resp != nil { return resp } @@ -195,18 +239,18 @@ func (s *Server) UpdateRateLimit(ctx context.Context, req *apitraffic.Rule) *api // 构造底层数据结构 rateLimit, err := api2RateLimit(req, data) if err != nil { - log.Error(err.Error(), utils.RequestID(ctx)) + log.Error(err.Error(), utils.ZapRequestID(requestID)) return api.NewRateLimitResponse(apimodel.Code_ParseRateLimitException, req) } rateLimit.ID = data.ID if err := s.storage.UpdateRateLimit(rateLimit); err != nil { - log.Error(err.Error(), utils.RequestID(ctx)) + log.Error(err.Error(), utils.ZapRequestID(requestID)) return wrapperRateLimitStoreResponse(req, err) } msg := fmt.Sprintf("update rate limit: id=%v, namespace=%v, service=%v, name=%v", rateLimit.ID, req.GetNamespace().GetValue(), req.GetService().GetValue(), rateLimit.Name) - log.Info(msg, utils.RequestID(ctx)) + log.Info(msg, utils.ZapRequestID(requestID)) s.RecordHistory(ctx, rateLimitRecordEntry(ctx, req, rateLimit, model.OUpdate)) return api.NewRateLimitResponse(apimodel.Code_ExecuteSuccess, req) @@ -222,7 +266,7 @@ func (s *Server) GetRateLimits(ctx context.Context, query map[string]string) *ap total, extendRateLimits, err := s.Cache().RateLimit().QueryRateLimitRules(ctx, *args) if err != nil { - log.Error("get rate limits store", zap.Error(err), utils.RequestID(ctx)) + log.Errorf("get rate limits store err: %s", err.Error()) return api.NewBatchQueryResponse(commonstore.StoreCode2APICode(err)) } @@ -233,7 +277,7 @@ func (s *Server) GetRateLimits(ctx context.Context, query map[string]string) *ap for _, item := range extendRateLimits { limit, err := rateLimit2Console(item) if err != nil { - log.Error("get rate limits convert", zap.Error(err), utils.RequestID(ctx)) + log.Errorf("get rate limits convert err: %s", err.Error()) return api.NewBatchQueryResponse(apimodel.Code_ParseRateLimitException) } out.RateLimits = append(out.RateLimits, limit) @@ -242,11 +286,6 @@ func (s *Server) GetRateLimits(ctx context.Context, query map[string]string) *ap return out } -// GetAllRateLimits Query all router_rule rules -func (s *Server) GetAllRateLimits(ctx context.Context) *apiservice.BatchQueryResponse { - return nil -} - func parseRateLimitArgs(query map[string]string) (*cachetypes.RateLimitRuleArgs, *apiservice.BatchQueryResponse) { for key := range query { if _, ok := RateLimitFilters[key]; !ok { @@ -279,6 +318,19 @@ func parseRateLimitArgs(query map[string]string) (*cachetypes.RateLimitRuleArgs, return args, nil } +// checkBatchRateLimits 检查批量请求的限流规则 +func checkBatchRateLimits(req []*apitraffic.Rule) *apiservice.BatchWriteResponse { + if len(req) == 0 { + return api.NewBatchWriteResponse(apimodel.Code_EmptyRequest) + } + + if len(req) > MaxBatchSize { + return api.NewBatchWriteResponse(apimodel.Code_BatchSizeOverLimit) + } + + return nil +} + // checkRateLimitValid 检查限流规则是否允许修改/删除 func (s *Server) checkRateLimitValid(ctx context.Context, serviceID string, req *apitraffic.Rule) ( *model.Service, *apiservice.Response) { @@ -293,13 +345,74 @@ func (s *Server) checkRateLimitValid(ctx context.Context, serviceID string, req return service, nil } -// checkRateLimitExisted 检查限流规则是否存在 -func (s *Server) checkRateLimitExisted(ctx context.Context, id string, - req *apitraffic.Rule) (*model.RateLimit, *apiservice.Response) { +// checkRateLimitParams 检查限流规则基础参数 +func checkRateLimitParams(req *apitraffic.Rule) *apiservice.Response { + if req == nil { + return api.NewRateLimitResponse(apimodel.Code_EmptyRequest, req) + } + if err := utils.CheckResourceName(req.GetNamespace()); err != nil { + return api.NewRateLimitResponse(apimodel.Code_InvalidNamespaceName, req) + } + if err := utils.CheckResourceName(req.GetService()); err != nil { + return api.NewRateLimitResponse(apimodel.Code_InvalidServiceName, req) + } + if resp := checkRateLimitParamsDbLen(req); nil != resp { + return resp + } + return nil +} +// checkRateLimitParams 检查限流规则基础参数 +func checkRateLimitParamsDbLen(req *apitraffic.Rule) *apiservice.Response { + if err := utils.CheckDbStrFieldLen(req.GetService(), MaxDbServiceNameLength); err != nil { + return api.NewRateLimitResponse(apimodel.Code_InvalidServiceName, req) + } + if err := utils.CheckDbStrFieldLen(req.GetNamespace(), MaxDbServiceNamespaceLength); err != nil { + return api.NewRateLimitResponse(apimodel.Code_InvalidNamespaceName, req) + } + if err := utils.CheckDbStrFieldLen(req.GetName(), MaxDbRateLimitName); err != nil { + return api.NewRateLimitResponse(apimodel.Code_InvalidRateLimitName, req) + } + return nil +} + +// checkRateLimitRuleParams 检查限流规则其他参数 +func checkRateLimitRuleParams(requestID string, req *apitraffic.Rule) *apiservice.Response { + // 检查amounts是否有重复周期 + amounts := req.GetAmounts() + durations := make(map[time.Duration]bool) + for _, amount := range amounts { + d := amount.GetValidDuration() + duration, err := ptypes.Duration(d) + if err != nil { + log.Error(err.Error(), utils.ZapRequestID(requestID)) + return api.NewRateLimitResponse(apimodel.Code_InvalidRateLimitAmounts, req) + } + durations[duration] = true + } + if len(amounts) != len(durations) { + return api.NewRateLimitResponse(apimodel.Code_InvalidRateLimitAmounts, req) + } + return nil +} + +// checkRevisedRateLimitParams 检查修改/删除限流规则基础参数 +func checkRevisedRateLimitParams(req *apitraffic.Rule) *apiservice.Response { + if req == nil { + return api.NewRateLimitResponse(apimodel.Code_EmptyRequest, req) + } + if req.GetId().GetValue() == "" { + return api.NewRateLimitResponse(apimodel.Code_InvalidRateLimitID, req) + } + return nil +} + +// checkRateLimitExisted 检查限流规则是否存在 +func (s *Server) checkRateLimitExisted( + id, requestID string, req *apitraffic.Rule) (*model.RateLimit, *apiservice.Response) { rateLimit, err := s.storage.GetRateLimitWithID(id) if err != nil { - log.Error(err.Error(), utils.RequestID(ctx)) + log.Error(err.Error(), utils.ZapRequestID(requestID)) return nil, api.NewRateLimitResponse(commonstore.StoreCode2APICode(err), req) } if rateLimit == nil { @@ -334,7 +447,6 @@ func api2RateLimit(req *apitraffic.Rule, old *model.RateLimit) (*model.RateLimit Labels: string(labelStr), Rule: rule, Revision: utils.NewUUID(), - Metadata: req.Metadata, } return out, nil } @@ -345,7 +457,6 @@ func rateLimit2Console(rateLimit *model.RateLimit) (*apitraffic.Rule, error) { return nil, nil } if len(rateLimit.Rule) > 0 { - rateLimit = rateLimit.CopyNoProto() rateLimit.Proto = &apitraffic.Rule{} // 控制台查询的请求 if err := json.Unmarshal([]byte(rateLimit.Rule), rateLimit.Proto); err != nil { @@ -363,7 +474,6 @@ func rateLimit2Console(rateLimit *model.RateLimit) (*apitraffic.Rule, error) { rule.Ctime = utils.NewStringValue(commontime.Time2String(rateLimit.CreateTime)) rule.Mtime = utils.NewStringValue(commontime.Time2String(rateLimit.ModifyTime)) rule.Disable = utils.NewBoolValue(rateLimit.Disable) - rule.Metadata = rateLimit.Metadata if rateLimit.EnableTime.Year() > 2000 { rule.Etime = utils.NewStringValue(commontime.Time2String(rateLimit.EnableTime)) } else { @@ -418,7 +528,6 @@ func rateLimit2Client( rule.Priority = utils.NewUInt32Value(rateLimit.Priority) rule.Revision = utils.NewStringValue(rateLimit.Revision) rule.Disable = utils.NewBoolValue(rateLimit.Disable) - rule.Metadata = rateLimit.Metadata copyRateLimitProto(rateLimit, rule) return rule, nil } diff --git a/service/ratelimit_config_test.go b/service/ratelimit_config_test.go index a004c963d..0e93d0653 100644 --- a/service/ratelimit_config_test.go +++ b/service/ratelimit_config_test.go @@ -331,7 +331,7 @@ func TestDeleteRateLimit(t *testing.T) { }() resp := discoverSuit.DiscoverServer().DeleteRateLimits(discoverSuit.DefaultCtx, []*apitraffic.Rule{rateLimitReq}) - assert.False(t, api.IsSuccess(resp), resp.GetInfo().GetValue()) + assert.True(t, api.IsSuccess(resp), resp.GetInfo().GetValue()) }) t.Run("并发删除限流规则,可以正常删除", func(t *testing.T) { diff --git a/service/routing_config_v1.go b/service/routing_config_v1.go index d8a29fa2e..6457f8dbd 100644 --- a/service/routing_config_v1.go +++ b/service/routing_config_v1.go @@ -19,12 +19,20 @@ package service import ( "context" + "encoding/json" + "fmt" + "time" + "github.com/gogo/protobuf/jsonpb" apimodel "github.com/polarismesh/specification/source/go/api/v1/model" apiservice "github.com/polarismesh/specification/source/go/api/v1/service_manage" apitraffic "github.com/polarismesh/specification/source/go/api/v1/traffic_manage" api "github.com/polarismesh/polaris/common/api/v1" + "github.com/polarismesh/polaris/common/model" + commonstore "github.com/polarismesh/polaris/common/store" + commontime "github.com/polarismesh/polaris/common/time" + "github.com/polarismesh/polaris/common/utils" ) var ( @@ -39,10 +47,15 @@ var ( // CreateRoutingConfigs Create a routing configuration func (s *Server) CreateRoutingConfigs(ctx context.Context, req []*apitraffic.Routing) *apiservice.BatchWriteResponse { + if err := checkBatchRoutingConfig(req); err != nil { + return err + } + resp := api.NewBatchWriteResponse(apimodel.Code_ExecuteSuccess) for _, entry := range req { - api.Collect(resp, s.CreateRoutingConfig(ctx, entry)) + api.Collect(resp, s.createRoutingConfigV1toV2(ctx, entry)) } + return api.FormatBatchWriteResponse(resp) } @@ -50,32 +63,92 @@ func (s *Server) CreateRoutingConfigs(ctx context.Context, req []*apitraffic.Rou // services to prevent the service from being deleted // Deprecated: This method is ready to abandon func (s *Server) CreateRoutingConfig(ctx context.Context, req *apitraffic.Routing) *apiservice.Response { - resps := api.NewResponseWithMsg(apimodel.Code_BadRequest, "API is Deprecated") - return resps + rid := utils.ParseRequestID(ctx) + pid := utils.ParsePlatformID(ctx) + if resp := checkRoutingConfig(req); resp != nil { + return resp + } + + serviceName := req.GetService().GetValue() + namespaceName := req.GetNamespace().GetValue() + service, errResp := s.loadService(namespaceName, serviceName) + if errResp != nil { + log.Error(errResp.GetInfo().GetValue(), utils.ZapRequestID(rid), utils.ZapPlatformID(pid)) + return api.NewRoutingResponse(apimodel.Code(errResp.GetCode().GetValue()), req) + } + if service == nil { + return api.NewRoutingResponse(apimodel.Code_NotFoundService, req) + } + if service.IsAlias() { + return api.NewRoutingResponse(apimodel.Code_NotAllowAliasCreateRouting, req) + } + + routingConfig, err := s.storage.GetRoutingConfigWithService(service.Name, service.Namespace) + if err != nil { + log.Error(err.Error(), utils.ZapRequestID(rid), utils.ZapPlatformID(pid)) + return api.NewRoutingResponse(commonstore.StoreCode2APICode(err), req) + } + if routingConfig != nil { + return api.NewRoutingResponse(apimodel.Code_ExistedResource, req) + } + + conf, err := api2RoutingConfig(service.ID, req) + if err != nil { + log.Error(err.Error(), utils.ZapRequestID(rid), utils.ZapPlatformID(pid)) + return api.NewRoutingResponse(apimodel.Code_ExecuteException, req) + } + if err := s.storage.CreateRoutingConfig(conf); err != nil { + log.Error(err.Error(), utils.ZapRequestID(rid), utils.ZapPlatformID(pid)) + return wrapperRoutingStoreResponse(req, err) + } + + s.RecordHistory(ctx, routingRecordEntry(ctx, req, service, conf, model.OCreate)) + return api.NewRoutingResponse(apimodel.Code_ExecuteSuccess, req) } // DeleteRoutingConfigs Batch delete routing configuration func (s *Server) DeleteRoutingConfigs(ctx context.Context, req []*apitraffic.Routing) *apiservice.BatchWriteResponse { + if err := checkBatchRoutingConfig(req); err != nil { + return err + } + out := api.NewBatchWriteResponse(apimodel.Code_ExecuteSuccess) for _, entry := range req { resp := s.DeleteRoutingConfig(ctx, entry) api.Collect(out, resp) } + return api.FormatBatchWriteResponse(out) } // DeleteRoutingConfig Delete a routing configuration // Deprecated: This method is ready to abandon func (s *Server) DeleteRoutingConfig(ctx context.Context, req *apitraffic.Routing) *apiservice.Response { - resps := api.NewResponseWithMsg(apimodel.Code_BadRequest, "API is Deprecated") - return resps + rid := utils.ParseRequestID(ctx) + pid := utils.ParsePlatformID(ctx) + service, resp := s.routingConfigCommonCheck(ctx, req) + if resp != nil { + return resp + } + + if err := s.storage.DeleteRoutingConfig(service.ID); err != nil { + log.Error(err.Error(), utils.ZapRequestID(rid), utils.ZapPlatformID(pid)) + return wrapperRoutingStoreResponse(req, err) + } + + s.RecordHistory(ctx, routingRecordEntry(ctx, req, service, nil, model.ODelete)) + return api.NewResponse(apimodel.Code_ExecuteSuccess) } // UpdateRoutingConfigs Batch update routing configuration func (s *Server) UpdateRoutingConfigs(ctx context.Context, req []*apitraffic.Routing) *apiservice.BatchWriteResponse { + if err := checkBatchRoutingConfig(req); err != nil { + return err + } + out := api.NewBatchWriteResponse(apimodel.Code_ExecuteSuccess) for _, entry := range req { - resp := s.UpdateRoutingConfig(ctx, entry) + resp := s.updateRoutingConfigV1toV2(ctx, entry) api.Collect(out, resp) } @@ -85,14 +158,267 @@ func (s *Server) UpdateRoutingConfigs(ctx context.Context, req []*apitraffic.Rou // UpdateRoutingConfig Update a routing configuration // Deprecated: 该方法准备舍弃 func (s *Server) UpdateRoutingConfig(ctx context.Context, req *apitraffic.Routing) *apiservice.Response { - resps := api.NewResponseWithMsg(apimodel.Code_BadRequest, "API is Deprecated") - return resps + rid := utils.ParseRequestID(ctx) + pid := utils.ParsePlatformID(ctx) + service, resp := s.routingConfigCommonCheck(ctx, req) + if resp != nil { + return resp + } + + conf, err := s.storage.GetRoutingConfigWithService(service.Name, service.Namespace) + if err != nil { + log.Error(err.Error(), utils.ZapRequestID(rid), utils.ZapPlatformID(pid)) + return api.NewRoutingResponse(commonstore.StoreCode2APICode(err), req) + } + if conf == nil { + return api.NewRoutingResponse(apimodel.Code_NotFoundRouting, req) + } + + reqModel, err := api2RoutingConfig(service.ID, req) + if err != nil { + log.Error(err.Error(), utils.ZapRequestID(rid), utils.ZapPlatformID(pid)) + return api.NewRoutingResponse(apimodel.Code_ParseRoutingException, req) + } + + if err := s.storage.UpdateRoutingConfig(reqModel); err != nil { + log.Error(err.Error(), utils.ZapRequestID(rid), utils.ZapPlatformID(pid)) + return wrapperRoutingStoreResponse(req, err) + } + + s.RecordHistory(ctx, routingRecordEntry(ctx, req, service, reqModel, model.OUpdate)) + return api.NewRoutingResponse(apimodel.Code_ExecuteSuccess, req) } // GetRoutingConfigs Get the routing configuration in batches, and provide the interface of // the query routing configuration to the OSS // Deprecated: This method is ready to abandon func (s *Server) GetRoutingConfigs(ctx context.Context, query map[string]string) *apiservice.BatchQueryResponse { - resps := api.NewBatchQueryResponseWithMsg(apimodel.Code_BadRequest, "API is Deprecated") - return resps + rid := utils.ParseRequestID(ctx) + pid := utils.ParsePlatformID(ctx) + + offset, limit, err := utils.ParseOffsetAndLimit(query) + if err != nil { + return api.NewBatchQueryResponse(apimodel.Code_InvalidParameter) + } + + filter := make(map[string]string) + for key, value := range query { + if _, ok := RoutingConfigFilterAttrs[key]; !ok { + log.Errorf("[Server][RoutingConfig][Query] attribute(%s) is not allowed", key) + return api.NewBatchQueryResponse(apimodel.Code_InvalidParameter) + } + filter[key] = value + } + // service -- > name This special treatment + if service, ok := filter["service"]; ok { + filter["name"] = service + delete(filter, "service") + } + + // Can be filtered according to name and namespace + total, routings, err := s.storage.GetRoutingConfigs(filter, offset, limit) + if err != nil { + log.Error(err.Error(), utils.ZapRequestID(rid), utils.ZapPlatformID(pid)) + return api.NewBatchQueryResponse(commonstore.StoreCode2APICode(err)) + } + + resp := api.NewBatchQueryResponse(apimodel.Code_ExecuteSuccess) + resp.Amount = utils.NewUInt32Value(total) + resp.Size = utils.NewUInt32Value(uint32(len(routings))) + resp.Routings = make([]*apitraffic.Routing, 0, len(routings)) + for _, entry := range routings { + routing, err := routingConfig2API(entry.Config, entry.ServiceName, entry.NamespaceName) + if err != nil { + log.Error(err.Error(), utils.ZapRequestID(rid), utils.ZapPlatformID(pid)) + return api.NewBatchQueryResponse(apimodel.Code_ParseRoutingException) + } + resp.Routings = append(resp.Routings, routing) + } + + return resp +} + +// routingConfigCommonCheck Public examination of routing configuration operation +func (s *Server) routingConfigCommonCheck( + ctx context.Context, req *apitraffic.Routing) (*model.Service, *apiservice.Response) { + if resp := checkRoutingConfig(req); resp != nil { + return nil, resp + } + + rid := utils.ParseRequestID(ctx) + pid := utils.ParsePlatformID(ctx) + serviceName := req.GetService().GetValue() + namespaceName := req.GetNamespace().GetValue() + + service, err := s.storage.GetService(serviceName, namespaceName) + if err != nil { + log.Error(err.Error(), utils.ZapRequestID(rid), utils.ZapPlatformID(pid)) + return nil, api.NewRoutingResponse(commonstore.StoreCode2APICode(err), req) + } + if service == nil { + return nil, api.NewRoutingResponse(apimodel.Code_NotFoundService, req) + } + + return service, nil +} + +// checkRoutingConfig Check the validity of the basic parameter of the routing configuration +func checkRoutingConfig(req *apitraffic.Routing) *apiservice.Response { + if req == nil { + return api.NewRoutingResponse(apimodel.Code_EmptyRequest, req) + } + if err := utils.CheckResourceName(req.GetService()); err != nil { + return api.NewRoutingResponse(apimodel.Code_InvalidServiceName, req) + } + + if err := utils.CheckResourceName(req.GetNamespace()); err != nil { + return api.NewRoutingResponse(apimodel.Code_InvalidNamespaceName, req) + } + + if err := utils.CheckDbStrFieldLen(req.GetService(), MaxDbServiceNameLength); err != nil { + return api.NewRoutingResponse(apimodel.Code_InvalidServiceName, req) + } + if err := utils.CheckDbStrFieldLen(req.GetNamespace(), MaxDbServiceNamespaceLength); err != nil { + return api.NewRoutingResponse(apimodel.Code_InvalidNamespaceName, req) + } + if err := utils.CheckDbStrFieldLen(req.GetServiceToken(), MaxDbServiceToken); err != nil { + return api.NewRoutingResponse(apimodel.Code_InvalidServiceToken, req) + } + + return nil +} + +// parseServiceRoutingToken Get token from RoutingConfig request parameters +func parseServiceRoutingToken(ctx context.Context, req *apitraffic.Routing) string { + if reqToken := req.GetServiceToken().GetValue(); reqToken != "" { + return reqToken + } + + return utils.ParseToken(ctx) +} + +// api2RoutingConfig Convert the API parameter to internal data structure +func api2RoutingConfig(serviceID string, req *apitraffic.Routing) (*model.RoutingConfig, error) { + inBounds, outBounds, err := marshalRoutingConfig(req.GetInbounds(), req.GetOutbounds()) + if err != nil { + return nil, err + } + + out := &model.RoutingConfig{ + ID: serviceID, + InBounds: string(inBounds), + OutBounds: string(outBounds), + Revision: utils.NewUUID(), + } + + return out, nil +} + +// routingConfig2API Convert the internal data structure to API parameter to pass out +func routingConfig2API(req *model.RoutingConfig, service string, namespace string) (*apitraffic.Routing, error) { + if req == nil { + return nil, nil + } + + out := &apitraffic.Routing{ + Service: utils.NewStringValue(service), + Namespace: utils.NewStringValue(namespace), + Revision: utils.NewStringValue(req.Revision), + Ctime: utils.NewStringValue(commontime.Time2String(req.CreateTime)), + Mtime: utils.NewStringValue(commontime.Time2String(req.ModifyTime)), + } + + if req.InBounds != "" { + var inBounds []*apitraffic.Route + if err := json.Unmarshal([]byte(req.InBounds), &inBounds); err != nil { + return nil, err + } + out.Inbounds = inBounds + } + if req.OutBounds != "" { + var outBounds []*apitraffic.Route + if err := json.Unmarshal([]byte(req.OutBounds), &outBounds); err != nil { + return nil, err + } + out.Outbounds = outBounds + } + + return out, nil +} + +// marshalRoutingConfig Formulate Inbounds and OUTBOUNDS +func marshalRoutingConfig(in []*apitraffic.Route, out []*apitraffic.Route) ([]byte, []byte, error) { + inBounds, err := json.Marshal(in) + if err != nil { + return nil, nil, err + } + + outBounds, err := json.Marshal(out) + if err != nil { + return nil, nil, err + } + + return inBounds, outBounds, nil +} + +// checkBatchRoutingConfig Check batch request +func checkBatchRoutingConfig(req []*apitraffic.Routing) *apiservice.BatchWriteResponse { + if len(req) == 0 { + return api.NewBatchWriteResponse(apimodel.Code_EmptyRequest) + } + + if len(req) > MaxBatchSize { + return api.NewBatchWriteResponse(apimodel.Code_BatchSizeOverLimit) + } + + return nil +} + +// routingRecordEntry Construction of RoutingConfig's record Entry +func routingRecordEntry(ctx context.Context, req *apitraffic.Routing, svc *model.Service, md *model.RoutingConfig, + opt model.OperationType) *model.RecordEntry { + + marshaler := jsonpb.Marshaler{} + detail, _ := marshaler.MarshalToString(req) + + entry := &model.RecordEntry{ + ResourceType: model.RRouting, + ResourceName: fmt.Sprintf("%s(%s)", svc.Name, svc.ID), + Namespace: svc.Namespace, + OperationType: opt, + Operator: utils.ParseOperator(ctx), + Detail: detail, + HappenTime: time.Now(), + } + + return entry +} + +// routingV2RecordEntry Construction of RoutingConfig's record Entry +func routingV2RecordEntry(ctx context.Context, req *apitraffic.RouteRule, md *model.RouterConfig, + opt model.OperationType) *model.RecordEntry { + + marshaler := jsonpb.Marshaler{} + detail, _ := marshaler.MarshalToString(req) + + entry := &model.RecordEntry{ + ResourceType: model.RRouting, + ResourceName: fmt.Sprintf("%s(%s)", md.Name, md.ID), + Namespace: req.GetNamespace(), + OperationType: opt, + Operator: utils.ParseOperator(ctx), + Detail: detail, + HappenTime: time.Now(), + } + return entry +} + +// wrapperRoutingStoreResponse Packing routing storage layer error +func wrapperRoutingStoreResponse(routing *apitraffic.Routing, err error) *apiservice.Response { + if err == nil { + return nil + } + resp := api.NewResponseWithMsg(commonstore.StoreCode2APICode(err), err.Error()) + resp.Routing = routing + return resp } diff --git a/service/routing_config_v1_test.go b/service/routing_config_v1_test.go index 2595a0022..a38857670 100644 --- a/service/routing_config_v1_test.go +++ b/service/routing_config_v1_test.go @@ -24,6 +24,7 @@ import ( "testing" "github.com/golang/protobuf/ptypes" + "github.com/polarismesh/specification/source/go/api/v1/model" apimodel "github.com/polarismesh/specification/source/go/api/v1/model" apiservice "github.com/polarismesh/specification/source/go/api/v1/service_manage" apitraffic "github.com/polarismesh/specification/source/go/api/v1/traffic_manage" @@ -91,8 +92,161 @@ func checkSameRoutingConfig(t *testing.T, lhs *apitraffic.Routing, rhs *apitraff checkFunc("Outbounds", lhs.Outbounds, rhs.Outbounds) } +// 测试创建路由配置 +func TestCreateRoutingConfig(t *testing.T) { + t.Run("正常创建路由配置配置请求", func(t *testing.T) { + discoverSuit := &DiscoverTestSuit{} + if err := discoverSuit.Initialize(); err != nil { + t.Fatal(err) + } + + defer discoverSuit.Destroy() + _, serviceResp := discoverSuit.createCommonService(t, 200) + defer discoverSuit.cleanServiceName(serviceResp.GetName().GetValue(), serviceResp.GetNamespace().GetValue()) + _, _ = discoverSuit.createCommonRoutingConfig(t, serviceResp, 3, 0) + + // 对写进去的数据进行查询 + _ = discoverSuit.CacheMgr().TestUpdate() + out := discoverSuit.DiscoverServer().GetRoutingConfigWithCache(discoverSuit.DefaultCtx, serviceResp) + defer discoverSuit.cleanCommonRoutingConfig(serviceResp.GetName().GetValue(), serviceResp.GetNamespace().GetValue()) + if !respSuccess(out) { + t.Fatalf("error: %+v", out) + } + }) + + t.Run("参数缺失,报错", func(t *testing.T) { + discoverSuit := &DiscoverTestSuit{} + if err := discoverSuit.Initialize(); err != nil { + t.Fatal(err) + } + defer discoverSuit.Destroy() + + _, serviceResp := discoverSuit.createCommonService(t, 20) + defer discoverSuit.cleanServiceName(serviceResp.GetName().GetValue(), serviceResp.GetNamespace().GetValue()) + + req := &apitraffic.Routing{} + resp := discoverSuit.DiscoverServer().CreateRoutingConfigs(discoverSuit.DefaultCtx, []*apitraffic.Routing{req}) + assert.False(t, respSuccess(resp)) + t.Logf("%s", resp.GetInfo().GetValue()) + + req.Service = serviceResp.Name + resp = discoverSuit.DiscoverServer().CreateRoutingConfigs(discoverSuit.DefaultCtx, []*apitraffic.Routing{req}) + assert.False(t, respSuccess(resp)) + t.Logf("%s", resp.GetInfo().GetValue()) + + req.Namespace = serviceResp.Namespace + resp = discoverSuit.DiscoverServer().CreateRoutingConfigs(discoverSuit.DefaultCtx, []*apitraffic.Routing{req}) + defer discoverSuit.cleanCommonRoutingConfig(req.GetService().GetValue(), req.GetNamespace().GetValue()) + assert.True(t, respSuccess(resp)) + t.Logf("%s", resp.GetInfo().GetValue()) + }) + + t.Run("服务不存在,创建路由配置不报错", func(t *testing.T) { + discoverSuit := &DiscoverTestSuit{} + if err := discoverSuit.Initialize(); err != nil { + t.Fatal(err) + } + defer discoverSuit.Destroy() + + _, serviceResp := discoverSuit.createCommonService(t, 120) + discoverSuit.cleanServiceName(serviceResp.GetName().GetValue(), serviceResp.GetNamespace().GetValue()) + + _ = discoverSuit.CacheMgr().TestUpdate() + req := &apitraffic.Routing{} + req.Service = serviceResp.Name + req.Namespace = serviceResp.Namespace + req.ServiceToken = serviceResp.Token + resp := discoverSuit.DiscoverServer().CreateRoutingConfigs(discoverSuit.DefaultCtx, []*apitraffic.Routing{req}) + assert.False(t, respSuccess(resp)) + t.Logf("%s", resp.GetInfo().GetValue()) + }) +} + +// 测试创建路由配置 +func TestUpdateRoutingConfig(t *testing.T) { + t.Run("更新V1路由规则, 成功转为V2规则", func(t *testing.T) { + discoverSuit := &DiscoverTestSuit{} + if err := discoverSuit.Initialize(); err != nil { + t.Fatal(err) + } + + _, svc := discoverSuit.createCommonService(t, 200) + v1Rule, _ := discoverSuit.createCommonRoutingConfigV1IntoOldStore(t, svc, 3, 0) + t.Cleanup(func() { + discoverSuit.cleanServiceName(svc.GetName().GetValue(), svc.GetNamespace().GetValue()) + discoverSuit.cleanCommonRoutingConfig(svc.GetName().GetValue(), svc.GetNamespace().GetValue()) + discoverSuit.truncateCommonRoutingConfigV2() + discoverSuit.Destroy() + }) + + v1Rule.Outbounds = v1Rule.Inbounds + uResp := discoverSuit.DiscoverServer().UpdateRoutingConfigs(discoverSuit.DefaultCtx, []*apitraffic.Routing{v1Rule}) + assert.True(t, respSuccess(uResp)) + + // 等缓存层更新 + _ = discoverSuit.CacheMgr().TestUpdate() + + // 直接查询存储无法查询到 v1 的路由规则 + total, routingsV1, err := discoverSuit.Storage.GetRoutingConfigs(map[string]string{}, 0, 100) + assert.NoError(t, err, err) + assert.Equal(t, uint32(0), total, "v1 routing must delete and transfer to v1") + assert.Equal(t, 0, len(routingsV1), "v1 routing ret len need zero") + + // 从缓存中查询应该查到 6 条 v2 的路由规则 + out := discoverSuit.DiscoverServer().QueryRoutingConfigsV2(discoverSuit.DefaultCtx, map[string]string{ + "limit": "100", + "offset": "0", + }) + if !respSuccess(out) { + t.Fatalf("error: %+v", out) + } + assert.Equal(t, int(6), int(out.GetAmount().GetValue()), "query routing size") + rulesV2, err := unmarshalRoutingV2toAnySlice(out.GetData()) + assert.NoError(t, err) + for i := range rulesV2 { + item := rulesV2[i] + assert.True(t, item.Enable, "v1 to v2 need default open enable") + msg := &apitraffic.RuleRoutingConfig{} + err := ptypes.UnmarshalAny(item.GetRoutingConfig(), msg) + assert.NoError(t, err) + assert.True(t, len(msg.GetSources()) == 0, "RuleRoutingConfig.Sources len != 0") + assert.True(t, len(msg.GetDestinations()) == 0, "RuleRoutingConfig.Destinations len != 0") + assert.True(t, len(msg.GetRules()) != 0, "RuleRoutingConfig.Rules len == 0") + } + }) +} + // 测试缓存获取路由配置 func TestGetRoutingConfigWithCache(t *testing.T) { + + t.Run("多个服务的,多个路由配置,都可以查询到", func(t *testing.T) { + discoverSuit := &DiscoverTestSuit{} + if err := discoverSuit.Initialize(); err != nil { + t.Fatal(err) + } + defer discoverSuit.Destroy() + + total := 20 + serviceResps := make([]*apiservice.Service, 0, total) + routingResps := make([]*apitraffic.Routing, 0, total) + for i := 0; i < total; i++ { + _, resp := discoverSuit.createCommonService(t, i) + defer discoverSuit.cleanServiceName(resp.GetName().GetValue(), resp.GetNamespace().GetValue()) + serviceResps = append(serviceResps, resp) + + _, routingResp := discoverSuit.createCommonRoutingConfig(t, resp, 2, 0) + defer discoverSuit.cleanCommonRoutingConfig(resp.GetName().GetValue(), resp.GetNamespace().GetValue()) + routingResps = append(routingResps, routingResp) + } + + _ = discoverSuit.CacheMgr().TestUpdate() + for i := 0; i < total; i++ { + t.Logf("service : name=%s namespace=%s", serviceResps[i].GetName().GetValue(), serviceResps[i].GetNamespace().GetValue()) + out := discoverSuit.DiscoverServer().GetRoutingConfigWithCache(discoverSuit.DefaultCtx, serviceResps[i]) + checkSameRoutingConfig(t, routingResps[i], out.GetRouting()) + } + }) + t.Run("走v2接口创建路由规则,不启用查不到,启用可以查到", func(t *testing.T) { discoverSuit := &DiscoverTestSuit{} if err := discoverSuit.Initialize(); err != nil { @@ -248,6 +402,31 @@ func TestGetRoutingConfigWithCache(t *testing.T) { assert.True(t, len(out.GetRouting().GetOutbounds()) == 0, "inBounds must be zero") }) + + t.Run("服务路由数据不改变,传递了路由revision,不返回数据", func(t *testing.T) { + discoverSuit := &DiscoverTestSuit{} + if err := discoverSuit.Initialize(); err != nil { + t.Fatal(err) + } + defer discoverSuit.Destroy() + + _, serviceResp := discoverSuit.createCommonService(t, 10) + defer discoverSuit.cleanServiceName(serviceResp.GetName().GetValue(), serviceResp.GetNamespace().GetValue()) + + _, routingResp := discoverSuit.createCommonRoutingConfig(t, serviceResp, 2, 0) + defer discoverSuit.cleanCommonRoutingConfig(serviceResp.GetName().GetValue(), serviceResp.GetNamespace().GetValue()) + + _ = discoverSuit.CacheMgr().TestUpdate() + firstResp := discoverSuit.DiscoverServer().GetRoutingConfigWithCache(discoverSuit.DefaultCtx, serviceResp) + checkSameRoutingConfig(t, routingResp, firstResp.GetRouting()) + + serviceResp.Revision = firstResp.Service.Revision + secondResp := discoverSuit.DiscoverServer().GetRoutingConfigWithCache(discoverSuit.DefaultCtx, serviceResp) + if secondResp.GetService().GetRevision().GetValue() != serviceResp.GetRevision().GetValue() { + t.Fatalf("error") + } + assert.Equal(t, model.Code(secondResp.GetCode().GetValue()), apimodel.Code_DataNoChange) + }) t.Run("路由不存在,不会出异常", func(t *testing.T) { discoverSuit := &DiscoverTestSuit{} if err := discoverSuit.Initialize(); err != nil { @@ -264,8 +443,8 @@ func TestGetRoutingConfigWithCache(t *testing.T) { }) } -// Test_RouteRule_V1_Server -func Test_RouteRule_V1_Server(t *testing.T) { +// test对routing字段进行校验 +func TestCheckRoutingFieldLen(t *testing.T) { discoverSuit := &DiscoverTestSuit{} if err := discoverSuit.Initialize(); err != nil { @@ -273,29 +452,40 @@ func Test_RouteRule_V1_Server(t *testing.T) { } defer discoverSuit.Destroy() - t.Run("Create", func(t *testing.T) { - rsp := discoverSuit.DiscoverServer().CreateRoutingConfigs(discoverSuit.DefaultCtx, []*apitraffic.Routing{ - &apitraffic.Routing{}, - }) - assert.False(t, api.IsSuccess(rsp), rsp.GetInfo().GetValue()) - }) + req := &apitraffic.Routing{ + ServiceToken: utils.NewStringValue("test"), + Service: utils.NewStringValue("test"), + Namespace: utils.NewStringValue("default"), + } - t.Run("Update", func(t *testing.T) { - rsp := discoverSuit.DiscoverServer().UpdateRoutingConfigs(discoverSuit.DefaultCtx, []*apitraffic.Routing{ - &apitraffic.Routing{}, - }) - assert.False(t, api.IsSuccess(rsp), rsp.GetInfo().GetValue()) + t.Run("创建路由规则,服务名超长", func(t *testing.T) { + str := genSpecialStr(129) + oldName := req.Service + req.Service = utils.NewStringValue(str) + resp := discoverSuit.DiscoverServer().CreateRoutingConfigs(discoverSuit.DefaultCtx, []*apitraffic.Routing{req}) + req.Service = oldName + if resp.Code.Value != api.InvalidServiceName { + t.Fatalf("%+v", resp) + } }) - - t.Run("Delete", func(t *testing.T) { - rsp := discoverSuit.DiscoverServer().UpdateRoutingConfigs(discoverSuit.DefaultCtx, []*apitraffic.Routing{ - &apitraffic.Routing{}, - }) - assert.False(t, api.IsSuccess(rsp), rsp.GetInfo().GetValue()) + t.Run("创建路由规则,命名空间超长", func(t *testing.T) { + str := genSpecialStr(129) + oldNamespace := req.Namespace + req.Namespace = utils.NewStringValue(str) + resp := discoverSuit.DiscoverServer().CreateRoutingConfigs(discoverSuit.DefaultCtx, []*apitraffic.Routing{req}) + req.Namespace = oldNamespace + if resp.Code.Value != api.InvalidNamespaceName { + t.Fatalf("%+v", resp) + } }) - - t.Run("Get", func(t *testing.T) { - rsp := discoverSuit.DiscoverServer().GetRoutingConfigs(discoverSuit.DefaultCtx, map[string]string{}) - assert.False(t, api.IsSuccess(rsp), rsp.GetInfo().GetValue()) + t.Run("创建路由规则,toeken超长", func(t *testing.T) { + str := genSpecialStr(2049) + oldServiceToken := req.ServiceToken + req.ServiceToken = utils.NewStringValue(str) + resp := discoverSuit.DiscoverServer().CreateRoutingConfigs(discoverSuit.DefaultCtx, []*apitraffic.Routing{req}) + req.ServiceToken = oldServiceToken + if resp.Code.Value != api.InvalidServiceToken { + t.Fatalf("%+v", resp) + } }) } diff --git a/service/routing_config_v1tov2.go b/service/routing_config_v1tov2.go new file mode 100644 index 000000000..bebcfe146 --- /dev/null +++ b/service/routing_config_v1tov2.go @@ -0,0 +1,208 @@ +/** + * Tencent is pleased to support the open source community by making Polaris available. + * + * Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. + * + * Licensed under the BSD 3-Clause License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://opensource.org/licenses/BSD-3-Clause + * + * Unless required by applicable law or agreed to in writing, software distributed + * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR + * CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package service + +import ( + "context" + + apimodel "github.com/polarismesh/specification/source/go/api/v1/model" + apiservice "github.com/polarismesh/specification/source/go/api/v1/service_manage" + apitraffic "github.com/polarismesh/specification/source/go/api/v1/traffic_manage" + "go.uber.org/zap" + + apiv1 "github.com/polarismesh/polaris/common/api/v1" + "github.com/polarismesh/polaris/common/model" + commonstore "github.com/polarismesh/polaris/common/store" + "github.com/polarismesh/polaris/common/utils" +) + +// createRoutingConfigV1toV2 Compatible with V1 version of the creation routing rules, convert V1 to V2 for storage +func (s *Server) createRoutingConfigV1toV2(ctx context.Context, req *apitraffic.Routing) *apiservice.Response { + if resp := checkRoutingConfig(req); resp != nil { + return resp + } + + serviceName := req.GetService().GetValue() + namespaceName := req.GetNamespace().GetValue() + svc, errResp := s.loadService(namespaceName, serviceName) + if errResp != nil { + log.Error("[Service][Routing] get read lock for service", zap.String("service", serviceName), + zap.String("namespace", namespaceName), utils.RequestID(ctx), zap.Any("err", errResp)) + return apiv1.NewRoutingResponse(apimodel.Code(errResp.GetCode().GetValue()), req) + } + if svc == nil { + return apiv1.NewRoutingResponse(apimodel.Code_NotFoundService, req) + } + if svc.IsAlias() { + return apiv1.NewRoutingResponse(apimodel.Code_NotAllowAliasCreateRouting, req) + } + + inDatas, outDatas, resp := batchBuildV2Routings(req) + if resp != nil { + return resp + } + + resp = s.saveRoutingV1toV2(ctx, svc.ID, inDatas, outDatas) + if resp.GetCode().GetValue() != uint32(apimodel.Code_ExecuteSuccess) { + return resp + } + + return apiv1.NewRoutingResponse(apimodel.Code_ExecuteSuccess, req) +} + +// updateRoutingConfigV1toV2 Compatible with V1 version update routing rules, convert the data of V1 to V2 for storage +// Once the V1 rule is converted to V2 rules, the original V1 rules will be removed from storage +func (s *Server) updateRoutingConfigV1toV2(ctx context.Context, req *apitraffic.Routing) *apiservice.Response { + svc, resp := s.routingConfigCommonCheck(ctx, req) + if resp != nil { + return resp + } + + serviceTx, err := s.storage.CreateTransaction() + if err != nil { + log.Error(err.Error(), utils.RequestID(ctx)) + return apiv1.NewRoutingResponse(commonstore.StoreCode2APICode(err), req) + } + // Release the lock for the service + defer func() { + _ = serviceTx.Commit() + }() + + // Need to prohibit the concurrent modification of the V1 rules + if _, err = serviceTx.LockService(svc.Name, svc.Namespace); err != nil { + log.Error("[Service][Routing] get service x-lock", zap.String("service", svc.Name), + zap.String("namespace", svc.Namespace), utils.RequestID(ctx), zap.Error(err)) + return apiv1.NewRoutingResponse(commonstore.StoreCode2APICode(err), req) + } + + conf, err := s.storage.GetRoutingConfigWithService(svc.Name, svc.Namespace) + if err != nil { + log.Error(err.Error(), utils.RequestID(ctx)) + return apiv1.NewRoutingResponse(commonstore.StoreCode2APICode(err), req) + } + if conf == nil { + return apiv1.NewRoutingResponse(apimodel.Code_NotFoundRouting, req) + } + + inDatas, outDatas, resp := batchBuildV2Routings(req) + if resp != nil { + return resp + } + + if resp := s.saveRoutingV1toV2(ctx, svc.ID, inDatas, outDatas); resp.GetCode().GetValue() != uint32( + apimodel.Code_ExecuteSuccess) { + return resp + } + + return apiv1.NewRoutingResponse(apimodel.Code_ExecuteSuccess, req) +} + +// saveRoutingV1toV2 Convert the V1 rules of the target to V2 rule +func (s *Server) saveRoutingV1toV2(ctx context.Context, svcId string, + inRules, outRules []*apitraffic.RouteRule) *apiservice.Response { + tx, err := s.storage.StartTx() + if err != nil { + log.Error("[Service][Routing] create routing v2 from v1 open tx", + utils.RequestID(ctx), zap.Error(err)) + return apiv1.NewResponse(commonstore.StoreCode2APICode(err)) + } + defer func() { + _ = tx.Rollback() + }() + + // Need to delete the routing rules of V1 first + if err := s.storage.DeleteRoutingConfigTx(tx, svcId); err != nil { + log.Error("[Service][Routing] clean routing v1 from store", + utils.RequestID(ctx), zap.Error(err)) + return apiv1.NewResponse(commonstore.StoreCode2APICode(err)) + } + + saveOperation := func(routings []*apitraffic.RouteRule) *apiservice.Response { + priorityMax := 0 + for i := range routings { + item := routings[i] + if item.Id == "" { + item.Id = utils.NewRoutingV2UUID() + } + item.Revision = utils.NewV2Revision() + data := &model.RouterConfig{} + if err := data.ParseRouteRuleFromAPI(item); err != nil { + return apiv1.NewResponse(apimodel.Code_ExecuteException) + } + + data.Valid = true + data.Enable = true + if priorityMax > 10 { + priorityMax = 10 + } + + data.Priority = uint32(priorityMax) + priorityMax++ + + if err := s.storage.CreateRoutingConfigV2Tx(tx, data); err != nil { + log.Error("[Routing][V2] create routing v2 from v1 into store", + utils.RequestID(ctx), zap.Error(err)) + return apiv1.NewResponse(commonstore.StoreCode2APICode(err)) + } + s.RecordHistory(ctx, routingV2RecordEntry(ctx, item, data, model.OCreate)) + } + + return nil + } + + if resp := saveOperation(inRules); resp != nil { + return resp + } + if resp := saveOperation(outRules); resp != nil { + return resp + } + + if err := tx.Commit(); err != nil { + log.Error("[Service][Routing] create routing v2 from v1 commit", + utils.RequestID(ctx), zap.Error(err)) + return apiv1.NewResponse(apimodel.Code_ExecuteException) + } + + return apiv1.NewResponse(apimodel.Code_ExecuteSuccess) +} + +func batchBuildV2Routings( + req *apitraffic.Routing) ([]*apitraffic.RouteRule, []*apitraffic.RouteRule, *apiservice.Response) { + inBounds := req.GetInbounds() + outBounds := req.GetOutbounds() + inRoutings := make([]*apitraffic.RouteRule, 0, len(inBounds)) + outRoutings := make([]*apitraffic.RouteRule, 0, len(outBounds)) + for i := range inBounds { + routing, err := model.BuildV2RoutingFromV1Route(req, inBounds[i]) + if err != nil { + return nil, nil, apiv1.NewResponse(apimodel.Code_ExecuteException) + } + routing.Name = req.GetNamespace().GetValue() + "." + req.GetService().GetValue() + inRoutings = append(inRoutings, routing) + } + + for i := range outBounds { + routing, err := model.BuildV2RoutingFromV1Route(req, outBounds[i]) + if err != nil { + return nil, nil, apiv1.NewResponse(apimodel.Code_ExecuteException) + } + outRoutings = append(outRoutings, routing) + } + + return inRoutings, outRoutings, nil +} diff --git a/service/routing_config_v2.go b/service/routing_config_v2.go index 687354855..7e8040a1d 100644 --- a/service/routing_config_v2.go +++ b/service/routing_config_v2.go @@ -19,16 +19,12 @@ package service import ( "context" - "fmt" "strconv" - "time" - "github.com/golang/protobuf/jsonpb" "github.com/golang/protobuf/ptypes" "github.com/golang/protobuf/ptypes/any" "github.com/golang/protobuf/ptypes/wrappers" apimodel "github.com/polarismesh/specification/source/go/api/v1/model" - "github.com/polarismesh/specification/source/go/api/v1/security" apiservice "github.com/polarismesh/specification/source/go/api/v1/service_manage" apitraffic "github.com/polarismesh/specification/source/go/api/v1/traffic_manage" "go.uber.org/zap" @@ -36,14 +32,36 @@ import ( cachetypes "github.com/polarismesh/polaris/cache/api" apiv1 "github.com/polarismesh/polaris/common/api/v1" "github.com/polarismesh/polaris/common/model" - authcommon "github.com/polarismesh/polaris/common/model/auth" commonstore "github.com/polarismesh/polaris/common/store" "github.com/polarismesh/polaris/common/utils" ) +var ( + // RoutingConfigV2FilterAttrs router config filter attrs + RoutingConfigV2FilterAttrs = map[string]bool{ + "id": true, + "name": true, + "service": true, + "namespace": true, + "source_service": true, + "destination_service": true, + "source_namespace": true, + "destination_namespace": true, + "enable": true, + "offset": true, + "limit": true, + "order_field": true, + "order_type": true, + } +) + // CreateRoutingConfigsV2 Create a routing configuration func (s *Server) CreateRoutingConfigsV2( ctx context.Context, req []*apitraffic.RouteRule) *apiservice.BatchWriteResponse { + if err := checkBatchRoutingConfigV2(req); err != nil { + return err + } + resp := apiv1.NewBatchWriteResponse(apimodel.Code_ExecuteSuccess) for _, entry := range req { apiv1.Collect(resp, s.createRoutingConfigV2(ctx, entry)) @@ -54,6 +72,10 @@ func (s *Server) CreateRoutingConfigsV2( // createRoutingConfigV2 Create a routing configuration func (s *Server) createRoutingConfigV2(ctx context.Context, req *apitraffic.RouteRule) *apiservice.Response { + if resp := checkRoutingConfigV2(req); resp != nil { + return resp + } + conf, err := Api2RoutingConfigV2(req) if err != nil { log.Error("[Routing][V2] parse routing config v2 from request for create", @@ -67,11 +89,8 @@ func (s *Server) createRoutingConfigV2(ctx context.Context, req *apitraffic.Rout return apiv1.NewResponse(commonstore.StoreCode2APICode(err)) } - s.RecordHistory(ctx, routeRuleRecordEntry(ctx, req, conf, model.OCreate)) - _ = s.afterRuleResource(ctx, model.RRouting, authcommon.ResourceEntry{ - ID: req.GetId(), - Type: security.ResourceType_RouteRules, - }, false) + s.RecordHistory(ctx, routingV2RecordEntry(ctx, req, conf, model.OCreate)) + req.Id = conf.ID return apiv1.NewRouterResponse(apimodel.Code_ExecuteSuccess, req) } @@ -79,6 +98,10 @@ func (s *Server) createRoutingConfigV2(ctx context.Context, req *apitraffic.Rout // DeleteRoutingConfigsV2 Batch delete routing configuration func (s *Server) DeleteRoutingConfigsV2( ctx context.Context, req []*apitraffic.RouteRule) *apiservice.BatchWriteResponse { + if err := checkBatchRoutingConfigV2(req); err != nil { + return err + } + out := apiv1.NewBatchWriteResponse(apimodel.Code_ExecuteSuccess) for _, entry := range req { resp := s.deleteRoutingConfigV2(ctx, entry) @@ -90,27 +113,38 @@ func (s *Server) DeleteRoutingConfigsV2( // DeleteRoutingConfigV2 Delete a routing configuration func (s *Server) deleteRoutingConfigV2(ctx context.Context, req *apitraffic.RouteRule) *apiservice.Response { + if resp := checkRoutingConfigIDV2(req); resp != nil { + return resp + } + + // Determine whether the current routing rules are only converted from the memory transmission in the V1 version + if _, ok := s.Cache().RoutingConfig().IsConvertFromV1(req.Id); ok { + resp := s.transferV1toV2OnModify(ctx, req) + if resp.GetCode().GetValue() != uint32(apimodel.Code_ExecuteSuccess) { + return resp + } + } + if err := s.storage.DeleteRoutingConfigV2(req.Id); err != nil { log.Error("[Routing][V2] delete routing config v2 store layer", utils.RequestID(ctx), zap.Error(err)) return apiv1.NewResponse(commonstore.StoreCode2APICode(err)) } - s.RecordHistory(ctx, routeRuleRecordEntry(ctx, req, &model.RouterConfig{ + s.RecordHistory(ctx, routingV2RecordEntry(ctx, req, &model.RouterConfig{ ID: req.GetId(), Name: req.GetName(), }, model.ODelete)) - - _ = s.afterRuleResource(ctx, model.RRouting, authcommon.ResourceEntry{ - ID: req.GetId(), - Type: security.ResourceType_RouteRules, - }, true) return apiv1.NewRouterResponse(apimodel.Code_ExecuteSuccess, req) } // UpdateRoutingConfigsV2 Batch update routing configuration func (s *Server) UpdateRoutingConfigsV2( ctx context.Context, req []*apitraffic.RouteRule) *apiservice.BatchWriteResponse { + if err := checkBatchRoutingConfigV2(req); err != nil { + return err + } + out := apiv1.NewBatchWriteResponse(apimodel.Code_ExecuteSuccess) for _, entry := range req { resp := s.updateRoutingConfigV2(ctx, entry) @@ -122,6 +156,21 @@ func (s *Server) UpdateRoutingConfigsV2( // updateRoutingConfigV2 Update a single routing configuration func (s *Server) updateRoutingConfigV2(ctx context.Context, req *apitraffic.RouteRule) *apiservice.Response { + // If V2 routing rules to be modified are from the V1 rule in the cache, need to do the following steps first + // step 1: Turn the V1 rule to the real V2 rule + // step 2: Find the corresponding route to the V2 rules to be modified in the V1 rules, set their rules ID + // step 3: Store persistence + if _, ok := s.Cache().RoutingConfig().IsConvertFromV1(req.Id); ok { + resp := s.transferV1toV2OnModify(ctx, req) + if resp.GetCode().GetValue() != uint32(apimodel.Code_ExecuteSuccess) { + return resp + } + } + + if resp := checkUpdateRoutingConfigV2(req); resp != nil { + return resp + } + // Check whether the routing configuration exists conf, err := s.storage.GetRoutingConfigV2WithID(req.Id) if err != nil { @@ -147,7 +196,7 @@ func (s *Server) updateRoutingConfigV2(ctx context.Context, req *apitraffic.Rout return apiv1.NewResponse(commonstore.StoreCode2APICode(err)) } - s.RecordHistory(ctx, routeRuleRecordEntry(ctx, req, reqModel, model.OUpdate)) + s.RecordHistory(ctx, routingV2RecordEntry(ctx, req, reqModel, model.OUpdate)) return apiv1.NewResponse(apimodel.Code_ExecuteSuccess) } @@ -178,11 +227,6 @@ func (s *Server) QueryRoutingConfigsV2(ctx context.Context, query map[string]str return resp } -// GetAllRouterRules Query all router_rule rules -func (s *Server) GetAllRouterRules(ctx context.Context) *apiservice.BatchQueryResponse { - return nil -} - // EnableRoutings batch enable routing rules func (s *Server) EnableRoutings(ctx context.Context, req []*apitraffic.RouteRule) *apiservice.BatchWriteResponse { out := apiv1.NewBatchWriteResponse(apimodel.Code_ExecuteSuccess) @@ -195,6 +239,17 @@ func (s *Server) EnableRoutings(ctx context.Context, req []*apitraffic.RouteRule } func (s *Server) enableRoutings(ctx context.Context, req *apitraffic.RouteRule) *apiservice.Response { + if resp := checkRoutingConfigIDV2(req); resp != nil { + return resp + } + + if _, ok := s.Cache().RoutingConfig().IsConvertFromV1(req.Id); ok { + resp := s.transferV1toV2OnModify(ctx, req) + if resp.GetCode().GetValue() != uint32(apimodel.Code_ExecuteSuccess) { + return resp + } + } + conf, err := s.storage.GetRoutingConfigV2WithID(req.Id) if err != nil { log.Error("[Routing][V2] get routing config v2 store layer", @@ -214,13 +269,83 @@ func (s *Server) enableRoutings(ctx context.Context, req *apitraffic.RouteRule) return apiv1.NewResponse(commonstore.StoreCode2APICode(err)) } - s.RecordHistory(ctx, routeRuleRecordEntry(ctx, req, conf, model.OUpdate)) + s.RecordHistory(ctx, routingV2RecordEntry(ctx, req, conf, model.OUpdate)) + return apiv1.NewResponse(apimodel.Code_ExecuteSuccess) +} + +// transferV1toV2OnModify When enabled or prohibited for the V2 rules, the V1 rules need to be converted to V2 rules +// and execute persistent storage +func (s *Server) transferV1toV2OnModify(ctx context.Context, req *apitraffic.RouteRule) *apiservice.Response { + svcId, _ := s.Cache().RoutingConfig().IsConvertFromV1(req.Id) + v1conf, err := s.storage.GetRoutingConfigWithID(svcId) + if err != nil { + log.Error("[Routing][V2] get routing config v1 store layer", + utils.RequestID(ctx), zap.Error(err)) + return apiv1.NewResponse(commonstore.StoreCode2APICode(err)) + } + if v1conf != nil { + svc, err := s.loadServiceByID(svcId) + if svc == nil { + log.Error("[Routing][V2] convert routing config v1 to v2 find svc", + utils.RequestID(ctx), zap.Error(err)) + return apiv1.NewResponse(apimodel.Code_NotFoundService) + } + + inV2, outV2, err := model.ConvertRoutingV1ToExtendV2(svc.Name, svc.Namespace, v1conf) + if err != nil { + log.Error("[Routing][V2] convert routing config v1 to v2", + utils.RequestID(ctx), zap.Error(err)) + return apiv1.NewResponse(apimodel.Code_ExecuteException) + } + + formatApi := func(rules []*model.ExtendRouterConfig) ([]*apitraffic.RouteRule, *apiservice.Response) { + ret := make([]*apitraffic.RouteRule, 0, len(rules)) + for i := range rules { + item, err := rules[i].ToApi() + if err != nil { + log.Error("[Routing][V2] convert routing config v1 to v2, format v2 to api", + utils.RequestID(ctx), zap.Error(err)) + return nil, apiv1.NewResponse(apimodel.Code_ExecuteException) + } + ret = append(ret, item) + } + + return ret, nil + } + + inDatas, resp := formatApi(inV2) + if resp != nil { + return resp + } + outDatas, resp := formatApi(outV2) + if resp != nil { + return resp + } + + if resp := s.saveRoutingV1toV2(ctx, svcId, inDatas, outDatas); resp.GetCode().GetValue() != apiv1.ExecuteSuccess { + return apiv1.NewResponse(apimodel.Code(resp.GetCode().GetValue())) + } + } + return apiv1.NewResponse(apimodel.Code_ExecuteSuccess) } // parseServiceArgs The query conditions of the analysis service -func parseRoutingArgs(filter map[string]string, ctx context.Context) (*cachetypes.RoutingArgs, *apiservice.Response) { - offset, limit, _ := utils.ParseOffsetAndLimit(filter) +func parseRoutingArgs(query map[string]string, ctx context.Context) (*cachetypes.RoutingArgs, *apiservice.Response) { + offset, limit, err := utils.ParseOffsetAndLimit(query) + if err != nil { + return nil, apiv1.NewResponse(apimodel.Code_InvalidParameter) + } + + filter := make(map[string]string) + for key, value := range query { + if _, ok := RoutingConfigV2FilterAttrs[key]; !ok { + log.Errorf("[Routing][V2][Query] attribute(%s) is not allowed", key) + return nil, apiv1.NewResponse(apimodel.Code_InvalidParameter) + } + filter[key] = value + } + res := &cachetypes.RoutingArgs{ Filter: filter, Name: filter["name"], @@ -254,6 +379,120 @@ func parseRoutingArgs(filter map[string]string, ctx context.Context) (*cachetype return res, nil } +// checkBatchRoutingConfig Check batch request +func checkBatchRoutingConfigV2(req []*apitraffic.RouteRule) *apiservice.BatchWriteResponse { + if len(req) == 0 { + return apiv1.NewBatchWriteResponse(apimodel.Code_EmptyRequest) + } + + if len(req) > MaxBatchSize { + return apiv1.NewBatchWriteResponse(apimodel.Code_BatchSizeOverLimit) + } + + return nil +} + +// checkRoutingConfig Check the validity of the basic parameter of the routing configuration +func checkRoutingConfigV2(req *apitraffic.RouteRule) *apiservice.Response { + if req == nil { + return apiv1.NewRouterResponse(apimodel.Code_EmptyRequest, req) + } + + if err := checkRoutingNameAndNamespace(req); err != nil { + return err + } + + if err := checkRoutingConfigPriorityV2(req); err != nil { + return err + } + + if err := checkRoutingPolicyV2(req); err != nil { + return err + } + + return nil +} + +// checkUpdateRoutingConfigV2 Check the validity of the basic parameter of the routing configuration +func checkUpdateRoutingConfigV2(req *apitraffic.RouteRule) *apiservice.Response { + if resp := checkRoutingConfigIDV2(req); resp != nil { + return resp + } + + if err := checkRoutingNameAndNamespace(req); err != nil { + return err + } + + if err := checkRoutingConfigPriorityV2(req); err != nil { + return err + } + + if err := checkRoutingPolicyV2(req); err != nil { + return err + } + + return nil +} + +func checkRoutingNameAndNamespace(req *apitraffic.RouteRule) *apiservice.Response { + if err := utils.CheckDbStrFieldLen(utils.NewStringValue(req.GetName()), MaxDbRoutingName); err != nil { + return apiv1.NewRouterResponse(apimodel.Code_InvalidRoutingName, req) + } + + if err := utils.CheckDbStrFieldLen(utils.NewStringValue(req.GetNamespace()), + MaxDbServiceNamespaceLength); err != nil { + return apiv1.NewRouterResponse(apimodel.Code_InvalidNamespaceName, req) + } + + return nil +} + +func checkRoutingConfigIDV2(req *apitraffic.RouteRule) *apiservice.Response { + if req == nil { + return apiv1.NewRouterResponse(apimodel.Code_EmptyRequest, req) + } + + if req.Id == "" { + return apiv1.NewResponse(apimodel.Code_InvalidRoutingID) + } + + return nil +} + +func checkRoutingConfigPriorityV2(req *apitraffic.RouteRule) *apiservice.Response { + if req == nil { + return apiv1.NewRouterResponse(apimodel.Code_EmptyRequest, req) + } + + if req.Priority > 10 { + return apiv1.NewResponse(apimodel.Code_InvalidRoutingPriority) + } + + return nil +} + +func checkRoutingPolicyV2(req *apitraffic.RouteRule) *apiservice.Response { + if req == nil { + return apiv1.NewRouterResponse(apimodel.Code_EmptyRequest, req) + } + + if req.GetRoutingPolicy() != apitraffic.RoutingPolicy_RulePolicy { + return apiv1.NewRouterResponse(apimodel.Code_InvalidRoutingPolicy, req) + } + + // Automatically supplement @Type attribute according to Policy + if req.RoutingConfig.TypeUrl == "" { + if req.GetRoutingPolicy() == apitraffic.RoutingPolicy_RulePolicy { + req.RoutingConfig.TypeUrl = model.RuleRoutingTypeUrl + } + if req.GetRoutingPolicy() == apitraffic.RoutingPolicy_MetadataPolicy { + req.RoutingConfig.TypeUrl = model.MetaRoutingTypeUrl + } + } + + return nil +} + // Api2RoutingConfigV2 Convert the API parameter to internal data structure func Api2RoutingConfigV2(req *apitraffic.RouteRule) (*model.RouterConfig, error) { out := &model.RouterConfig{ @@ -292,22 +531,3 @@ func marshalRoutingV2toAnySlice(routings []*model.ExtendRouterConfig) ([]*any.An return ret, nil } - -// routeRuleRecordEntry Construction of RoutingConfig's record Entry -func routeRuleRecordEntry(ctx context.Context, req *apitraffic.RouteRule, md *model.RouterConfig, - opt model.OperationType) *model.RecordEntry { - - marshaler := jsonpb.Marshaler{} - detail, _ := marshaler.MarshalToString(req) - - entry := &model.RecordEntry{ - ResourceType: model.RRouting, - ResourceName: fmt.Sprintf("%s(%s)", md.Name, md.ID), - Namespace: req.GetNamespace(), - OperationType: opt, - Operator: utils.ParseOperator(ctx), - Detail: detail, - HappenTime: time.Now(), - } - return entry -} diff --git a/service/routing_config_v2_test.go b/service/routing_config_v2_test.go index 3297489df..8513f3ae0 100644 --- a/service/routing_config_v2_test.go +++ b/service/routing_config_v2_test.go @@ -120,6 +120,250 @@ func TestCreateRoutingConfigV2(t *testing.T) { }) } +// TestCompatibleRoutingConfigV2AndV1 测试V2版本的路由规则和V1版本的路由规则 +func TestCompatibleRoutingConfigV2AndV1(t *testing.T) { + + svc := &apiservice.Service{ + Name: utils.NewStringValue("compatible-routing"), + Namespace: utils.NewStringValue("compatible"), + } + + initSuitFunc := func(t *testing.T) *DiscoverTestSuit { + discoverSuit := &DiscoverTestSuit{} + if err := discoverSuit.Initialize(); err != nil { + t.Fatal(err) + } + t.Cleanup(func() { + discoverSuit.Destroy() + }) + + createSvcResp := discoverSuit.DiscoverServer().CreateServices(discoverSuit.DefaultCtx, []*apiservice.Service{svc}) + if !respSuccess(createSvcResp) { + t.Fatalf("error: %s", createSvcResp.GetInfo().GetValue()) + } + + _ = createSvcResp.Responses[0].GetService() + t.Cleanup(func() { + discoverSuit.cleanServices([]*apiservice.Service{svc}) + }) + return discoverSuit + } + + t.Run("V1的存量规则-走V2接口可以查询到,ExtendInfo符合要求", func(t *testing.T) { + discoverSuit := initSuitFunc(t) + _, _ = discoverSuit.createCommonRoutingConfigV1IntoOldStore(t, svc, 3, 0) + t.Cleanup(func() { + discoverSuit.cleanCommonRoutingConfig(svc.GetName().GetValue(), svc.GetNamespace().GetValue()) + discoverSuit.truncateCommonRoutingConfigV2() + }) + + _ = discoverSuit.CacheMgr().TestUpdate() + // 从缓存中查询应该查到 3+3 条 v2 的路由规则 + out := discoverSuit.DiscoverServer().QueryRoutingConfigsV2(discoverSuit.DefaultCtx, map[string]string{ + "limit": "100", + "offset": "0", + }) + if !respSuccess(out) { + t.Fatalf("error: %+v", out) + } + assert.Equal(t, int(3), int(out.GetAmount().GetValue()), "query routing size") + + rulesV2, err := unmarshalRoutingV2toAnySlice(out.GetData()) + assert.NoError(t, err) + for i := range rulesV2 { + item := rulesV2[i] + assert.True(t, item.Enable, "v1 to v2 need default open enable") + msg := &apitraffic.RuleRoutingConfig{} + err := ptypes.UnmarshalAny(item.GetRoutingConfig(), msg) + assert.NoError(t, err) + assert.True(t, len(msg.GetSources()) == 0, "RuleRoutingConfig.Sources len != 0") + assert.True(t, len(msg.GetDestinations()) == 0, "RuleRoutingConfig.Destinations len != 0") + assert.True(t, len(msg.GetRules()) != 0, "RuleRoutingConfig.Rules len == 0") + } + }) + + t.Run("V1的存量规则-走v2规则的启用可正常迁移v1规则", func(t *testing.T) { + discoverSuit := initSuitFunc(t) + _, _ = discoverSuit.createCommonRoutingConfigV1IntoOldStore(t, svc, 3, 0) + t.Cleanup(func() { + discoverSuit.cleanCommonRoutingConfig(svc.GetName().GetValue(), svc.GetNamespace().GetValue()) + discoverSuit.truncateCommonRoutingConfigV2() + }) + + _ = discoverSuit.CacheMgr().TestUpdate() + // 从缓存中查询应该查到 3 条 v2 的路由规则 + out := discoverSuit.DiscoverServer().QueryRoutingConfigsV2(discoverSuit.DefaultCtx, map[string]string{ + "limit": "100", + "offset": "0", + }) + if !respSuccess(out) { + t.Fatalf("error: %+v", out) + } + assert.Equal(t, int(3), int(out.GetAmount().GetValue()), "query routing size") + rulesV2, err := unmarshalRoutingV2toAnySlice(out.GetData()) + assert.NoError(t, err) + + // 选择其中一条规则进行enable操作 + v2resp := discoverSuit.DiscoverServer().EnableRoutings(discoverSuit.DefaultCtx, []*apitraffic.RouteRule{rulesV2[0]}) + if !respSuccess(v2resp) { + t.Fatalf("error: %+v", v2resp) + } + // 直接查询存储无法查询到 v1 的路由规则 + total, routingsV1, err := discoverSuit.Storage.GetRoutingConfigs(map[string]string{}, 0, 100) + assert.NoError(t, err, err) + assert.Equal(t, uint32(0), total, "v1 routing must delete and transfer to v1") + assert.Equal(t, 0, len(routingsV1), "v1 routing ret len need zero") + + // 从缓存中查询应该查到 3 条 v2 的路由规则 + out = discoverSuit.DiscoverServer().QueryRoutingConfigsV2(discoverSuit.DefaultCtx, map[string]string{ + "limit": "100", + "offset": "0", + }) + if !respSuccess(out) { + t.Fatalf("error: %+v", out) + } + assert.Equal(t, int(3), int(out.GetAmount().GetValue()), "query routing size") + rulesV2, err = unmarshalRoutingV2toAnySlice(out.GetData()) + assert.NoError(t, err) + for i := range rulesV2 { + item := rulesV2[i] + assert.True(t, item.Enable, "v1 to v2 need default open enable") + msg := &apitraffic.RuleRoutingConfig{} + err := ptypes.UnmarshalAny(item.GetRoutingConfig(), msg) + assert.NoError(t, err) + assert.True(t, len(msg.GetSources()) == 0, "RuleRoutingConfig.Sources len != 0") + assert.True(t, len(msg.GetDestinations()) == 0, "RuleRoutingConfig.Destinations len != 0") + assert.True(t, len(msg.GetRules()) != 0, "RuleRoutingConfig.Rules len == 0") + } + }) + + t.Run("V1的存量规则-走v2规则的删除可正常迁移v1规则", func(t *testing.T) { + discoverSuit := initSuitFunc(t) + _, _ = discoverSuit.createCommonRoutingConfigV1IntoOldStore(t, svc, 3, 0) + t.Cleanup(func() { + discoverSuit.cleanCommonRoutingConfig(svc.GetName().GetValue(), svc.GetNamespace().GetValue()) + discoverSuit.truncateCommonRoutingConfigV2() + }) + + _ = discoverSuit.CacheMgr().TestUpdate() + // 从缓存中查询应该查到 3+3 条 v2 的路由规则 + out := discoverSuit.DiscoverServer().QueryRoutingConfigsV2(discoverSuit.DefaultCtx, map[string]string{ + "limit": "100", + "offset": "0", + }) + if !respSuccess(out) { + t.Fatalf("error: %+v", out) + } + assert.Equal(t, int(3), int(out.GetAmount().GetValue()), "query routing size") + + rulesV2, err := unmarshalRoutingV2toAnySlice(out.GetData()) + assert.NoError(t, err) + + // 选择其中一条规则进行删除操作 + v2resp := discoverSuit.DiscoverServer().DeleteRoutingConfigsV2(discoverSuit.DefaultCtx, []*apitraffic.RouteRule{rulesV2[0]}) + if !respSuccess(v2resp) { + t.Fatalf("error: %+v", v2resp) + } + // 直接查询存储无法查询到 v1 的路由规则 + total, routingsV1, err := discoverSuit.Storage.GetRoutingConfigs(map[string]string{}, 0, 100) + assert.NoError(t, err, err) + assert.Equal(t, uint32(0), total, "v1 routing must delete and transfer to v1") + assert.Equal(t, 0, len(routingsV1), "v1 routing ret len need zero") + + // 查询对应的 v2 规则也查询不到 + ruleV2, err := discoverSuit.Storage.GetRoutingConfigV2WithID(rulesV2[0].Id) + assert.NoError(t, err, err) + assert.Nil(t, ruleV2, "v2 routing must delete") + + _ = discoverSuit.CacheMgr().TestUpdate() + // 从缓存中查询应该查到 2 条 v2 的路由规则 + out = discoverSuit.DiscoverServer().QueryRoutingConfigsV2(discoverSuit.DefaultCtx, map[string]string{ + "limit": "100", + "offset": "0", + }) + if !respSuccess(out) { + t.Fatalf("error: %+v", out) + } + assert.Equal(t, int(2), int(out.GetAmount().GetValue()), "query routing size") + rulesV2, err = unmarshalRoutingV2toAnySlice(out.GetData()) + assert.NoError(t, err) + for i := range rulesV2 { + item := rulesV2[i] + assert.True(t, item.Enable, "v1 to v2 need default open enable") + msg := &apitraffic.RuleRoutingConfig{} + err := ptypes.UnmarshalAny(item.GetRoutingConfig(), msg) + assert.NoError(t, err) + assert.True(t, len(msg.GetSources()) == 0, "RuleRoutingConfig.Sources len != 0") + assert.True(t, len(msg.GetDestinations()) == 0, "RuleRoutingConfig.Destinations len != 0") + assert.True(t, len(msg.GetRules()) != 0, "RuleRoutingConfig.Rules len == 0") + } + }) + + t.Run("V1的存量规则-走v2规则的编辑可正常迁移v1规则", func(t *testing.T) { + discoverSuit := initSuitFunc(t) + _, _ = discoverSuit.createCommonRoutingConfigV1IntoOldStore(t, svc, 3, 0) + t.Cleanup(func() { + discoverSuit.cleanCommonRoutingConfig(svc.GetName().GetValue(), svc.GetNamespace().GetValue()) + discoverSuit.truncateCommonRoutingConfigV2() + }) + + _ = discoverSuit.CacheMgr().TestUpdate() + // 从缓存中查询应该查到 3+3 条 v2 的路由规则 + out := discoverSuit.DiscoverServer().QueryRoutingConfigsV2(discoverSuit.DefaultCtx, map[string]string{ + "limit": "100", + "offset": "0", + }) + if !respSuccess(out) { + t.Fatalf("error: %+v", out) + } + assert.Equal(t, int(3), int(out.GetAmount().GetValue()), "query routing size") + + rulesV2, err := unmarshalRoutingV2toAnySlice(out.GetData()) + assert.NoError(t, err) + + // 需要将 v2 规则的 extendInfo 规则清理掉 + // 选择其中一条规则进行enable操作 + rulesV2[0].Description = "update v2 rule and transfer v1 to v2" + v2resp := discoverSuit.DiscoverServer().UpdateRoutingConfigsV2(discoverSuit.DefaultCtx, []*apitraffic.RouteRule{rulesV2[0]}) + if !respSuccess(v2resp) { + t.Fatalf("error: %+v", v2resp) + } + // 直接查询存储无法查询到 v1 的路由规则 + total, routingsV1, err := discoverSuit.Storage.GetRoutingConfigs(map[string]string{}, 0, 100) + assert.NoError(t, err, err) + assert.Equal(t, uint32(0), total, "v1 routing must delete and transfer to v1") + assert.Equal(t, 0, len(routingsV1), "v1 routing ret len need zero") + + // 查询对应的 v2 规则能够查询到 + ruleV2, err := discoverSuit.Storage.GetRoutingConfigV2WithID(rulesV2[0].Id) + assert.NoError(t, err, err) + assert.NotNil(t, ruleV2, "v2 routing must exist") + assert.Equal(t, rulesV2[0].Description, ruleV2.Description) + + _ = discoverSuit.CacheMgr().TestUpdate() + out = discoverSuit.DiscoverServer().QueryRoutingConfigsV2(discoverSuit.DefaultCtx, map[string]string{ + "limit": "100", + "offset": "0", + }) + if !respSuccess(out) { + t.Fatalf("error: %+v", out) + } + assert.Equal(t, int(3), int(out.GetAmount().GetValue()), "query routing size") + rulesV2, err = unmarshalRoutingV2toAnySlice(out.GetData()) + assert.NoError(t, err) + for i := range rulesV2 { + item := rulesV2[i] + assert.True(t, item.Enable, "v1 to v2 need default open enable") + msg := &apitraffic.RuleRoutingConfig{} + err := ptypes.UnmarshalAny(item.GetRoutingConfig(), msg) + assert.NoError(t, err) + assert.True(t, len(msg.GetSources()) == 0, "RuleRoutingConfig.Sources len != 0") + assert.True(t, len(msg.GetDestinations()) == 0, "RuleRoutingConfig.Destinations len != 0") + assert.True(t, len(msg.GetRules()) != 0, "RuleRoutingConfig.Rules len == 0") + } + }) +} + // TestDeleteRoutingConfigV2 测试删除路由配置 func TestDeleteRoutingConfigV2(t *testing.T) { diff --git a/service/server.go b/service/server.go index a2bb50a21..8c1cd9b3e 100644 --- a/service/server.go +++ b/service/server.go @@ -20,7 +20,6 @@ package service import ( "context" - "github.com/polarismesh/specification/source/go/api/v1/security" apiservice "github.com/polarismesh/specification/source/go/api/v1/service_manage" "golang.org/x/sync/singleflight" @@ -28,7 +27,6 @@ import ( cacheservice "github.com/polarismesh/polaris/cache/service" "github.com/polarismesh/polaris/common/eventhub" "github.com/polarismesh/polaris/common/model" - "github.com/polarismesh/polaris/common/model/auth" "github.com/polarismesh/polaris/common/utils" "github.com/polarismesh/polaris/namespace" "github.com/polarismesh/polaris/plugin" @@ -159,62 +157,12 @@ func (s *Server) getLocation(host string) *model.Location { return location } -func (s *Server) afterRuleResource(ctx context.Context, r model.Resource, res auth.ResourceEntry, remove bool) error { - event := &ResourceEvent{ - Resource: res, - IsRemove: remove, - } - - for index := range s.hooks { - hook := s.hooks[index] - if err := hook.After(ctx, r, event); err != nil { - return err - } - } - return nil -} - func (s *Server) afterServiceResource(ctx context.Context, req *apiservice.Service, save *model.Service, remove bool) error { event := &ResourceEvent{ - Resource: auth.ResourceEntry{ - Type: security.ResourceType_Services, - ID: save.ID, - Metadata: save.Meta, - }, - AddPrincipals: func() []auth.Principal { - ret := make([]auth.Principal, 0, 4) - for i := range req.UserIds { - ret = append(ret, auth.Principal{ - PrincipalType: auth.PrincipalUser, - PrincipalID: req.UserIds[i].GetValue(), - }) - } - for i := range req.GroupIds { - ret = append(ret, auth.Principal{ - PrincipalType: auth.PrincipalGroup, - PrincipalID: req.GroupIds[i].GetValue(), - }) - } - return ret - }(), - DelPrincipals: func() []auth.Principal { - ret := make([]auth.Principal, 0, 4) - for i := range req.RemoveUserIds { - ret = append(ret, auth.Principal{ - PrincipalType: auth.PrincipalUser, - PrincipalID: req.RemoveUserIds[i].GetValue(), - }) - } - for i := range req.RemoveGroupIds { - ret = append(ret, auth.Principal{ - PrincipalType: auth.PrincipalGroup, - PrincipalID: req.RemoveGroupIds[i].GetValue(), - }) - } - return ret - }(), - IsRemove: remove, + ReqService: req, + Service: save, + IsRemove: remove, } for index := range s.hooks { @@ -223,6 +171,7 @@ func (s *Server) afterServiceResource(ctx context.Context, req *apiservice.Servi return err } } + return nil } diff --git a/service/service.go b/service/service.go index 57758d57b..e7b43d0a7 100644 --- a/service/service.go +++ b/service/service.go @@ -299,9 +299,9 @@ func (s *Server) GetAllServices(ctx context.Context, query map[string]string) *a ) if ns, ok := query["namespace"]; ok && len(ns) > 0 { - _, svcs = s.Cache().Service().ListServices(ctx, ns) + _, svcs = s.Cache().Service().ListServices(ns) } else { - _, svcs = s.Cache().Service().ListAllServices(ctx) + _, svcs = s.Cache().Service().ListAllServices() } ret := make([]*apiservice.Service, 0, len(svcs)) @@ -332,7 +332,7 @@ func (s *Server) GetServices(ctx context.Context, query map[string]string) *apis inputInstMetaKeys, inputInstMetaValues string ) for key, value := range query { - typ := ServiceFilterAttributes[key] + typ, _ := ServiceFilterAttributes[key] switch { case typ == serviceFilter: serviceFilters[key] = value @@ -375,7 +375,10 @@ func (s *Server) GetServices(ctx context.Context, query map[string]string) *apis } // 判断offset和limit是否为int,并从filters清除offset/limit参数 - offset, limit, _ := utils.ParseOffsetAndLimit(serviceFilters) + offset, limit, err := utils.ParseOffsetAndLimit(serviceFilters) + if err != nil { + return api.NewBatchQueryResponse(apimodel.Code_InvalidParameter) + } serviceArgs := parseServiceArgs(serviceFilters, serviceMetas, ctx) total, services, err := s.caches.Service().GetServicesByFilter(ctx, serviceArgs, instanceArgs, offset, limit) @@ -753,8 +756,6 @@ func service2Api(service *model.Service) *apiservice.Service { Ctime: utils.NewStringValue(commontime.Time2String(service.CreateTime)), Mtime: utils.NewStringValue(commontime.Time2String(service.ModifyTime)), ExportTo: service.ListExportTo(), - Editable: utils.NewBoolValue(true), - Deleteable: utils.NewBoolValue(true), } return out diff --git a/service/service_alias.go b/service/service_alias.go index d97363a13..e95a6400f 100644 --- a/service/service_alias.go +++ b/service/service_alias.go @@ -142,16 +142,28 @@ func (s *Server) DeleteServiceAlias(ctx context.Context, req *apiservice.Service return api.NewServiceAliasResponse(commonstore.StoreCode2APICode(err), req) } - s.RecordHistory(ctx, serviceRecordEntry(ctx, &apiservice.Service{ - Name: req.GetAlias(), - Namespace: req.GetAliasNamespace(), - }, alias, model.ODelete)) return api.NewServiceAliasResponse(apimodel.Code_ExecuteSuccess, req) } +func checkBatchAlias(req []*apiservice.ServiceAlias) *apiservice.BatchWriteResponse { + if len(req) == 0 { + return api.NewBatchWriteResponse(apimodel.Code_EmptyRequest) + } + + if len(req) > MaxBatchSize { + return api.NewBatchWriteResponse(apimodel.Code_BatchSizeOverLimit) + } + + return nil +} + // DeleteServiceAliases 删除服务别名列表 func (s *Server) DeleteServiceAliases( ctx context.Context, req []*apiservice.ServiceAlias) *apiservice.BatchWriteResponse { + if checkError := checkBatchAlias(req); checkError != nil { + return checkError + } + responses := api.NewBatchWriteResponse(apimodel.Code_ExecuteSuccess) for _, alias := range req { response := s.DeleteServiceAlias(ctx, alias) @@ -253,8 +265,6 @@ func (s *Server) GetServiceAliases(ctx context.Context, query map[string]string) Comment: utils.NewStringValue(entry.Comment), Ctime: utils.NewStringValue(commontime.Time2String(entry.CreateTime)), Mtime: utils.NewStringValue(commontime.Time2String(entry.ModifyTime)), - Editable: utils.NewBoolValue(true), - Deleteable: utils.NewBoolValue(true), } resp.Aliases = append(resp.Aliases, item) } diff --git a/service/service_contract.go b/service/service_contract.go index 475c90964..a9e9f0eab 100644 --- a/service/service_contract.go +++ b/service/service_contract.go @@ -117,6 +117,7 @@ func (s *Server) CreateServiceContract(ctx context.Context, contract *apiservice } func (s *Server) GetServiceContracts(ctx context.Context, query map[string]string) *apiservice.BatchQueryResponse { + out := api.NewBatchQueryResponse(apimodel.Code_ExecuteSuccess) out.Amount = utils.NewUInt32Value(0) out.Size = utils.NewUInt32Value(0) diff --git a/service/service_test.go b/service/service_test.go index 53d00062f..3aea3496a 100644 --- a/service/service_test.go +++ b/service/service_test.go @@ -45,7 +45,6 @@ import ( "github.com/polarismesh/polaris/service" "github.com/polarismesh/polaris/store" "github.com/polarismesh/polaris/store/mock" - testsuit "github.com/polarismesh/polaris/test/suit" ) // 测试新增服务 @@ -1384,7 +1383,7 @@ func TestConcurrencyCreateSameService(t *testing.T) { userMgn, strategyMgn, err := auth.TestInitialize(ctx, &auth.Config{}, mockStore, cacheMgr) assert.NoError(t, err) - nsSvr, err = testsuit.TestNamespaceInitialize(ctx, &namespace.Config{ + nsSvr, err = namespace.TestInitialize(ctx, &namespace.Config{ AutoCreate: true, }, mockStore, cacheMgr, userMgn, strategyMgn) assert.NoError(t, err) diff --git a/store/auth_api.go b/store/auth_api.go index 0db55fec0..625df4a34 100644 --- a/store/auth_api.go +++ b/store/auth_api.go @@ -98,6 +98,9 @@ type StrategyStore interface { principalType authcommon.PrincipalType) (*authcommon.StrategyDetail, error) // GetStrategyDetail Get strategy details GetStrategyDetail(id string) (*authcommon.StrategyDetail, error) + // GetStrategies Get a list of strategies + GetStrategies(filters map[string]string, offset uint32, limit uint32) (uint32, + []*authcommon.StrategyDetail, error) // GetMoreStrategies Used to refresh policy cache // 此方法用于 cache 增量更新,需要注意 mtime 应为数据库时间戳 GetMoreStrategies(mtime time.Time, firstUpdate bool) ([]*authcommon.StrategyDetail, error) @@ -105,14 +108,12 @@ type StrategyStore interface { // RoleStore Role related storage operation interface type RoleStore interface { - // GetRole - GetRole(id string) (*authcommon.Role, error) // AddRole Add a role AddRole(role *authcommon.Role) error // UpdateRole Update a role UpdateRole(role *authcommon.Role) error // DeleteRole Delete a role - DeleteRole(tx Tx, role *authcommon.Role) error + DeleteRole(role *authcommon.Role) error // CleanPrincipalRoles Clean all the roles associated with the principal CleanPrincipalRoles(tx Tx, p *authcommon.Principal) error // GetRole get more role for cache update diff --git a/store/boltdb/default.go b/store/boltdb/default.go index 17f143e59..17907b923 100644 --- a/store/boltdb/default.go +++ b/store/boltdb/default.go @@ -18,10 +18,11 @@ package boltdb import ( - "os" "time" apisecurity "github.com/polarismesh/specification/source/go/api/v1/security" + bolt "go.etcd.io/bbolt" + "go.uber.org/zap" "github.com/polarismesh/polaris/common/model" authcommon "github.com/polarismesh/polaris/common/model/auth" @@ -120,14 +121,13 @@ func (m *boltStore) Initialize(c *store.Config) error { } if loadFile, ok := c.Option["loadFile"].(string); ok { - // 仅用于本地测试验证单机数据 - loadFileName := os.Getenv("POLARIS_DEV_BOLT_INIT_DATA_FILA") - if loadFileName != "" { - loadFile = loadFileName - } if err := m.loadByFile(loadFile); err != nil { return err } + } else { + if err := m.loadByDefault(); err != nil { + return err + } } m.start = true return nil @@ -143,52 +143,166 @@ var ( servicesToInit = map[string]string{ "polaris.checker": "fbca9bfa04ae4ead86e1ecf5811e32a9", } + + mainUser = &authcommon.User{ + ID: "65e4789a6d5b49669adf1e9e8387549c", + Name: "polaris", + Password: "$2a$10$3izWuZtE5SBdAtSZci.gs.iZ2pAn9I8hEqYrC6gwJp1dyjqQnrrum", + Owner: "", + Source: "Polaris", + Mobile: "", + Email: "", + Type: 20, + Token: "nu/0WRA4EqSR1FagrjRj0fZwPXuGlMpX+zCuWu4uMqy8xr1vRjisSbA25aAC3mtU8MeeRsKhQiDAynUR09I=", + TokenEnable: true, + Valid: true, + Comment: "default polaris admin account", + CreateTime: time.Now(), + ModifyTime: time.Now(), + } + + superDefaultStrategy = &authcommon.StrategyDetail{ + ID: "super_user_default_strategy", + Name: "(用户) polarissys@admin的默认策略", + Action: "READ_WRITE", + Comment: "default admin", + Principals: []authcommon.Principal{ + { + StrategyID: "super_user_default_strategy", + PrincipalID: "", + PrincipalType: authcommon.PrincipalUser, + }, + }, + Default: true, + Owner: "", + Resources: []authcommon.StrategyResource{ + { + StrategyID: "super_user_default_strategy", + ResType: int32(apisecurity.ResourceType_Namespaces), + ResID: "*", + }, + { + StrategyID: "super_user_default_strategy", + ResType: int32(apisecurity.ResourceType_Services), + ResID: "*", + }, + { + StrategyID: "super_user_default_strategy", + ResType: int32(apisecurity.ResourceType_ConfigGroups), + ResID: "*", + }, + }, + Valid: true, + Revision: "fbca9bfa04ae4ead86e1ecf5811e32a9", + CreateTime: time.Now(), + ModifyTime: time.Now(), + } + + mainDefaultStrategy = &authcommon.StrategyDetail{ + ID: "fbca9bfa04ae4ead86e1ecf5811e32a9", + Name: "(用户) polaris的默认策略", + Action: "READ_WRITE", + Comment: "default admin", + Principals: []authcommon.Principal{ + { + StrategyID: "fbca9bfa04ae4ead86e1ecf5811e32a9", + PrincipalID: "65e4789a6d5b49669adf1e9e8387549c", + PrincipalType: authcommon.PrincipalUser, + }, + }, + Default: true, + Owner: "65e4789a6d5b49669adf1e9e8387549c", + Resources: []authcommon.StrategyResource{ + { + StrategyID: "fbca9bfa04ae4ead86e1ecf5811e32a9", + ResType: int32(apisecurity.ResourceType_Namespaces), + ResID: "*", + }, + { + StrategyID: "fbca9bfa04ae4ead86e1ecf5811e32a9", + ResType: int32(apisecurity.ResourceType_Services), + ResID: "*", + }, + { + StrategyID: "fbca9bfa04ae4ead86e1ecf5811e32a9", + ResType: int32(apisecurity.ResourceType_ConfigGroups), + ResID: "*", + }, + }, + Valid: true, + Revision: "fbca9bfa04ae4ead86e1ecf5811e32a9", + CreateTime: time.Now(), + ModifyTime: time.Now(), + } ) func (m *boltStore) initNamingStoreData() error { for _, namespace := range namespacesToInit { curTime := time.Now() - val, err := m.GetNamespace(namespace) + err := m.AddNamespace(&model.Namespace{ + Name: namespace, + Token: utils.NewUUID(), + Owner: ownerToInit, + Valid: true, + CreateTime: curTime, + ModifyTime: curTime, + }) if err != nil { return err } - if val == nil { - if err := m.AddNamespace(&model.Namespace{ - Name: namespace, - Token: utils.NewUUID(), - Owner: ownerToInit, - Valid: true, - CreateTime: curTime, - ModifyTime: curTime, - }); err != nil { - return err - } - } } for svc, id := range servicesToInit { curTime := time.Now() - val, err := m.getServiceByNameAndNs(svc, namespacePolaris) + err := m.AddService(&model.Service{ + ID: id, + Name: svc, + Namespace: namespacePolaris, + Token: utils.NewUUID(), + Owner: ownerToInit, + Revision: utils.NewUUID(), + Valid: true, + CreateTime: curTime, + ModifyTime: curTime, + }) + if err != nil { + return err + } + } + return nil +} + +func (m *boltStore) initAuthStoreData() error { + return m.handler.Execute(true, func(tx *bolt.Tx) error { + user, err := m.getUser(tx, mainUser.ID) if err != nil { return err } - if val != nil { - if err := m.AddService(&model.Service{ - ID: id, - Name: svc, - Namespace: namespacePolaris, - Token: utils.NewUUID(), - Owner: ownerToInit, - Revision: utils.NewUUID(), - Valid: true, - CreateTime: curTime, - ModifyTime: curTime, - }); err != nil { + + if user == nil { + user = mainUser + // 添加主账户主体信息 + if err := saveValue(tx, tblUser, user.ID, converToUserStore(user)); err != nil { + authLog.Error("[Store][User] save user fail", zap.Error(err), zap.String("name", user.Name)) return err } } - } - return nil + rule, err := m.getStrategyDetail(tx, mainDefaultStrategy.ID) + if err != nil { + return err + } + + if rule == nil { + strategy := mainDefaultStrategy + // 添加主账户的默认鉴权策略信息 + if err := saveValue(tx, tblStrategy, strategy.ID, convertForStrategyStore(strategy)); err != nil { + authLog.Error("[Store][Strategy] save auth_strategy", zap.Error(err), + zap.String("name", strategy.Name), zap.String("owner", strategy.Owner)) + return err + } + } + return nil + }) } func (m *boltStore) newStore() error { @@ -227,7 +341,6 @@ func (m *boltStore) newAuthModuleStore() { m.userStore = &userStore{handler: m.handler} m.strategyStore = &strategyStore{handler: m.handler} m.groupStore = &groupStore{handler: m.handler} - m.roleStore = &roleStore{handle: m.handler} } func (m *boltStore) newConfigModuleStore() { @@ -269,15 +382,3 @@ func init() { s := &boltStore{} _ = store.RegisterStore(s) } - -func buildAllResAllow(id string) []authcommon.StrategyResource { - ret := make([]authcommon.StrategyResource, 0, 8) - for i := range apisecurity.ResourceType_value { - ret = append(ret, authcommon.StrategyResource{ - StrategyID: id, - ResType: apisecurity.ResourceType_value[i], - ResID: "*", - }) - } - return ret -} diff --git a/store/boltdb/handler_test.go b/store/boltdb/handler_test.go index 676811b45..33263bc59 100644 --- a/store/boltdb/handler_test.go +++ b/store/boltdb/handler_test.go @@ -27,10 +27,8 @@ import ( "github.com/golang/protobuf/ptypes/wrappers" apimodel "github.com/polarismesh/specification/source/go/api/v1/model" - "gopkg.in/yaml.v3" "github.com/polarismesh/polaris/common/model" - authcommon "github.com/polarismesh/polaris/common/model/auth" ) func CreateTableDBHandlerAndRun(t *testing.T, tableName string, tf func(t *testing.T, handler BoltHandler)) { @@ -432,10 +430,3 @@ func TestBoltHandler_UpdateValue(t *testing.T) { } } - -func Test_PrintBoltInitData(t *testing.T) { - d, _ := yaml.Marshal(&authcommon.User{ - TokenEnable: true, - }) - t.Log(string(d)) -} diff --git a/store/boltdb/instance.go b/store/boltdb/instance.go index e97ecbeb2..0b61e055f 100644 --- a/store/boltdb/instance.go +++ b/store/boltdb/instance.go @@ -19,6 +19,7 @@ package boltdb import ( "errors" + "fmt" "sort" "strconv" "strings" @@ -27,7 +28,6 @@ import ( "github.com/golang/protobuf/ptypes/wrappers" apiservice "github.com/polarismesh/specification/source/go/api/v1/service_manage" bolt "go.etcd.io/bbolt" - "go.uber.org/zap" "google.golang.org/protobuf/types/known/wrapperspb" "github.com/polarismesh/polaris/common/model" @@ -603,7 +603,8 @@ func (i *instanceStore) SetInstanceHealthStatus(instanceID string, flag int, rev return err } if len(instances) == 0 { - log.Errorf("cant not find instance in kv, %s", instanceID) + msg := fmt.Sprintf("cant not find instance in kv, %s", instanceID) + log.Errorf(msg) return nil } @@ -666,7 +667,8 @@ func (i *instanceStore) BatchSetInstanceIsolate(ids []interface{}, isolate int, return err } if len(instances) == 0 { - log.Errorf("cant not find instance in kv, %v", ids) + msg := fmt.Sprintf("cant not find instance in kv, %v", ids) + log.Errorf(msg) return nil } @@ -682,7 +684,7 @@ func (i *instanceStore) BatchSetInstanceIsolate(ids []interface{}, isolate int, instance.Mtime = &wrappers.StringValue{Value: commontime.Time2String(curr)} err = i.handler.UpdateValue(tblNameInstance, id, properties) if err != nil { - log.Error("[Store][boltdb] update instance in set instance isolate error", zap.Error(err)) + log.Errorf("[Store][boltdb] update instance in set instance isolate error, %v", err) return err } } diff --git a/store/boltdb/load.go b/store/boltdb/load.go index 0779c01d0..41cfdbd6f 100644 --- a/store/boltdb/load.go +++ b/store/boltdb/load.go @@ -21,6 +21,7 @@ import ( "errors" "fmt" "os" + "sort" "time" bolt "go.etcd.io/bbolt" @@ -31,11 +32,22 @@ import ( authcommon "github.com/polarismesh/polaris/common/model/auth" ) +func (m *boltStore) loadByDefault() error { + if err := m.initAuthStoreData(); err != nil { + _ = m.handler.Close() + return err + } + if err := m.initNamingStoreData(); err != nil { + _ = m.handler.Close() + return err + } + return nil +} + // DefaultData 默认数据信息 type DefaultData struct { - Namespaces []*model.Namespace `yaml:"namespaces"` - Users []*authcommon.User `yaml:"users"` - Policies []*authcommon.StrategyDetail `yaml:"policies"` + Namespaces []*model.Namespace `yaml:"namespaces"` + Users []*authcommon.User `yaml:"users"` } func (m *boltStore) loadByFile(loadFile string) error { @@ -59,47 +71,102 @@ func (m *boltStore) loadByFile(loadFile string) error { return err } } + if len(data.Users) == 0 { + if err := m.initAuthStoreData(); err != nil { + _ = m.handler.Close() + return err + } + return nil + } return m.loadFromData(data) } func (m *boltStore) loadFromData(data *DefaultData) error { + users := data.Users + + // 确保排序为 admin -> main -> sub + sort.Slice(users, func(i, j int) bool { + return users[i].Type < users[j].Type + }) + + tn := time.Now() + var ( + superUser, mainUser *authcommon.User + ) + if len(users) >= 2 && users[0].Type == authcommon.AdminUserRole && users[1].Type == authcommon.OwnerUserRole { + superUser = users[0] + superUser.CreateTime = tn + superUser.ModifyTime = tn + mainUser = users[1] + mainUser.CreateTime = tn + mainUser.ModifyTime = tn + } else if users[0].Type == authcommon.OwnerUserRole { + mainUser = users[0] + mainUser.CreateTime = tn + mainUser.ModifyTime = tn + } else { + return errors.New("invalid init users info, must be have main user info") + } + if err := m.handler.Execute(true, func(tx *bolt.Tx) error { - for i := range data.Users { - saveUser, err := m.getUser(tx, data.Users[i].ID) + saveFunc := func(user *authcommon.User, rule *authcommon.StrategyDetail) error { + rule.Owner = user.ID + rule.Principals[0].PrincipalID = user.ID + saveUser, err := m.getUser(tx, user.ID) if err != nil { return err } + if saveUser == nil { - data.Users[i].CreateTime = time.Now() - data.Users[i].ModifyTime = time.Now() // 添加主账户主体信息 - if err := saveValue(tx, tblUser, data.Users[i].ID, converToUserStore(data.Users[i])); err != nil { - log.Error("[Store][User] save user fail", zap.Error(err), zap.String("name", data.Users[i].Name)) + if err := saveValue(tx, tblUser, user.ID, converToUserStore(user)); err != nil { + log.Error("[Store][User] save user fail", zap.Error(err), zap.String("name", user.Name)) return err } } - } - for i := range data.Policies { - saveRule, err := m.getStrategyDetail(tx, data.Policies[i].ID) + saveRule, err := m.getStrategyDetail(tx, rule.ID) if err != nil { return err } if saveRule == nil { - data.Policies[i].CreateTime = time.Now() - data.Policies[i].ModifyTime = time.Now() // 添加主账户的默认鉴权策略信息 - if err := saveValue(tx, tblStrategy, data.Policies[i].ID, convertForStrategyStore(data.Policies[i])); err != nil { + if err := saveValue(tx, tblStrategy, rule.ID, convertForStrategyStore(rule)); err != nil { log.Error("[Store][Strategy] save auth_strategy", zap.Error(err), - zap.String("name", data.Policies[i].Name), zap.String("owner", data.Policies[i].Owner)) + zap.String("name", rule.Name), zap.String("owner", rule.Owner)) return err } } + + return nil + } + + if superUser != nil { + if err := saveFunc(superUser, superDefaultStrategy); err != nil { + return err + } + } + if err := saveFunc(mainUser, mainDefaultStrategy); err != nil { + return err } return nil }); err != nil { return err } - return nil + + tx, err := m.handle.StartTx() + if err != nil { + return err + } + defer func() { + _ = tx.Rollback() + }() + // 挨个处理其他用户数据信息 + for i := 1; i < len(users); i++ { + if err := m.AddUser(tx, users[i]); err != nil { + return nil + } + } + return tx.Commit() } diff --git a/store/boltdb/namespace.go b/store/boltdb/namespace.go index 2090770b7..a463ee4c1 100644 --- a/store/boltdb/namespace.go +++ b/store/boltdb/namespace.go @@ -24,8 +24,6 @@ import ( "sort" "time" - "go.uber.org/zap" - "github.com/polarismesh/polaris/common/model" "github.com/polarismesh/polaris/common/utils" ) @@ -103,7 +101,7 @@ func (n *namespaceStore) AddNamespace(namespace *model.Namespace) error { func (n *namespaceStore) cleanNamespace(name string) error { if err := n.handler.DeleteValues(tblNameNamespace, []string{name}); err != nil { - log.Error("[Store][boltdb] delete invalid namespace error", zap.Error(err)) + log.Errorf("[Store][boltdb] delete invalid namespace error, %+v", err) return err } @@ -120,12 +118,7 @@ func (n *namespaceStore) UpdateNamespace(namespace *model.Namespace) error { properties["Comment"] = namespace.Comment properties["ModifyTime"] = time.Now() properties["ServiceExportTo"] = utils.MustJson(namespace.ServiceExportTo) - properties["Metadata"] = utils.MustJson(namespace.Metadata) - if err := n.handler.UpdateValue(tblNameNamespace, namespace.Name, properties); err != nil { - log.Error("[Store][boltdb] update namespace error", zap.Error(err)) - return err - } - return nil + return n.handler.UpdateValue(tblNameNamespace, namespace.Name, properties) } // UpdateNamespaceToken update the token of a namespace @@ -151,9 +144,6 @@ func (n *namespaceStore) GetNamespace(name string) (*model.Namespace, error) { return nil, nil } ns := nsValue.(*Namespace) - if !ns.Valid { - return nil, nil - } return n.toModel(ns), nil } @@ -253,7 +243,7 @@ func (n *namespaceStore) GetMoreNamespaces(mtime time.Time) ([]*model.Namespace, if !ok { return false } - return !mTimeValue.(time.Time).Before(mtime) + return mTimeValue.(time.Time).After(mtime) }) if err != nil { return nil, err @@ -266,14 +256,8 @@ func (n *namespaceStore) toModel(data *Namespace) *model.Namespace { } func toModelNamespace(data *Namespace) *model.Namespace { - if !data.Valid { - return nil - } export := make(map[string]struct{}) _ = json.Unmarshal([]byte(data.ServiceExportTo), &export) - - metadata := make(map[string]string) - _ = json.Unmarshal([]byte(data.Metadata), &metadata) return &model.Namespace{ Name: data.Name, Comment: data.Comment, @@ -283,7 +267,6 @@ func toModelNamespace(data *Namespace) *model.Namespace { CreateTime: data.CreateTime, ModifyTime: data.ModifyTime, Valid: data.Valid, - Metadata: metadata, } } @@ -297,7 +280,6 @@ func (n *namespaceStore) toStore(data *model.Namespace) *Namespace { CreateTime: data.CreateTime, ModifyTime: data.ModifyTime, Valid: data.Valid, - Metadata: utils.MustJson(data.Metadata), } } @@ -312,5 +294,4 @@ type Namespace struct { ServiceExportTo string CreateTime time.Time ModifyTime time.Time - Metadata string } diff --git a/store/boltdb/ratelimit_test.go b/store/boltdb/ratelimit_test.go index bad397335..8531f38b1 100644 --- a/store/boltdb/ratelimit_test.go +++ b/store/boltdb/ratelimit_test.go @@ -57,7 +57,6 @@ func createTestRateLimit(id string, createId bool) *model.RateLimit { CreateTime: time.Now(), ModifyTime: time.Now(), EnableTime: time.Now(), - Metadata: map[string]string{}, } } @@ -78,7 +77,7 @@ func Test_rateLimitStore_CreateRateLimit(t *testing.T) { t.Fatal(err) } - tN := time.Time{} + tN := time.Now() tVal := testVal tVal.ModifyTime = tN tVal.CreateTime = tN @@ -240,7 +239,7 @@ func Test_rateLimitStore_GetExtendRateLimits(t *testing.T) { got1Limits = append(got1Limits, got1[i].RateLimit) } - tN := time.Time{} + tN := time.Now() sort.Slice(got1, func(i, j int) bool { got1Limits[i].CreateTime = tN diff --git a/store/boltdb/role.go b/store/boltdb/role.go index adb41b84d..21b8a5085 100644 --- a/store/boltdb/role.go +++ b/store/boltdb/role.go @@ -21,13 +21,12 @@ import ( "encoding/json" "time" - bolt "go.etcd.io/bbolt" - "go.uber.org/zap" - "github.com/polarismesh/polaris/common/model" authcommon "github.com/polarismesh/polaris/common/model/auth" "github.com/polarismesh/polaris/common/utils" "github.com/polarismesh/polaris/store" + bolt "go.etcd.io/bbolt" + "go.uber.org/zap" ) var _ store.RoleStore = (*roleStore)(nil) @@ -110,19 +109,22 @@ func (s *roleStore) UpdateRole(role *authcommon.Role) error { } // DeleteRole Delete a role -func (s *roleStore) DeleteRole(tx store.Tx, role *authcommon.Role) error { +func (s *roleStore) DeleteRole(role *authcommon.Role) error { if role.ID == "" { log.Error("[Store][role] delete role missing some params") return ErrBadParam } - dbTx := tx.GetDelegateTx().(*bolt.Tx) data := newRoleData(role) - properties := map[string]interface{}{ - CommonFieldValid: false, - CommonFieldModifyTime: time.Now(), - } - if err := updateValue(dbTx, tblRole, data.ID, properties); err != nil { + + err := s.handle.Execute(true, func(tx *bolt.Tx) error { + properties := map[string]interface{}{ + CommonFieldValid: false, + CommonFieldModifyTime: time.Now(), + } + return updateValue(tx, tblRole, data.ID, properties) + }) + if err != nil { log.Error("[Store][role] delete role failed", zap.String("name", role.Name), zap.Error(err)) return store.Error(err) } @@ -190,20 +192,6 @@ func (s *roleStore) CleanPrincipalRoles(tx store.Tx, p *authcommon.Principal) er return nil } -// GetRole get more role for cache update -func (s *roleStore) GetRole(id string) (*authcommon.Role, error) { - ret, err := s.handle.LoadValues(tblRole, []string{id}, &model.RoutingConfig{}) - if err != nil { - log.Errorf("[Store][role] get one role, %v", err) - return nil, store.Error(err) - } - - for i := range ret { - return newRole(ret[i].(*roleData)), nil - } - return nil, nil -} - // GetMoreRoles get more role for cache update func (s *roleStore) GetMoreRoles(firstUpdate bool, mtime time.Time) ([]*authcommon.Role, error) { fields := []string{CommonFieldModifyTime, CommonFieldValid} @@ -261,8 +249,8 @@ func newRoleData(r *authcommon.Role) *roleData { } func newRole(r *roleData) *authcommon.Role { - users := make([]authcommon.Principal, 0, 32) - groups := make([]authcommon.Principal, 0, 32) + users := make([]*authcommon.User, 0, 32) + groups := make([]*authcommon.UserGroup, 0, 32) _ = json.Unmarshal([]byte(r.Users), &users) _ = json.Unmarshal([]byte(r.UserGroups), &groups) diff --git a/store/boltdb/strategy.go b/store/boltdb/strategy.go index a47f19807..1a028c9a8 100644 --- a/store/boltdb/strategy.go +++ b/store/boltdb/strategy.go @@ -18,9 +18,9 @@ package boltdb import ( - "encoding/json" "errors" "fmt" + "sort" "strings" "time" @@ -58,6 +58,24 @@ var ( ErrorStrategyNotFound error = errors.New("strategy not fonud") ) +type strategyForStore struct { + ID string + Name string + Action string + Comment string + Users map[string]string + Groups map[string]string + Default bool + Owner string + NsResources map[string]string + SvcResources map[string]string + CfgResources map[string]string + Valid bool + Revision string + CreateTime time.Time + ModifyTime time.Time +} + // StrategyStore type strategyStore struct { handler BoltHandler @@ -118,19 +136,17 @@ func (ss *strategyStore) UpdateStrategy(strategy *authcommon.ModifyStrategyDetai // updateStrategy func (ss *strategyStore) updateStrategy(tx *bolt.Tx, modify *authcommon.ModifyStrategyDetail, - saveVal *strategyData) error { + saveVal *strategyForStore) error { saveVal.Action = modify.Action saveVal.Comment = modify.Comment saveVal.Revision = utils.NewUUID() - saveVal.CalleeFunctions = utils.MustJson(modify.CalleeMethods) - saveVal.Conditions = utils.MustJson(modify.Conditions) computePrincipals(false, modify.AddPrincipals, saveVal) computePrincipals(true, modify.RemovePrincipals, saveVal) - saveVal.computeResources(false, modify.AddResources) - saveVal.computeResources(true, modify.RemoveResources) + computeResources(false, modify.AddResources, saveVal) + computeResources(true, modify.RemoveResources, saveVal) saveVal.ModifyTime = time.Now() @@ -149,7 +165,7 @@ func (ss *strategyStore) updateStrategy(tx *bolt.Tx, modify *authcommon.ModifySt return nil } -func computePrincipals(remove bool, principals []authcommon.Principal, saveVal *strategyData) { +func computePrincipals(remove bool, principals []authcommon.Principal, saveVal *strategyForStore) { for i := range principals { principal := principals[i] if principal.PrincipalType == authcommon.PrincipalUser { @@ -168,6 +184,36 @@ func computePrincipals(remove bool, principals []authcommon.Principal, saveVal * } } +func computeResources(remove bool, resources []authcommon.StrategyResource, saveVal *strategyForStore) { + for i := range resources { + resource := resources[i] + if resource.ResType == int32(apisecurity.ResourceType_Namespaces) { + if remove { + delete(saveVal.NsResources, resource.ResID) + } else { + saveVal.NsResources[resource.ResID] = "" + } + continue + } + if resource.ResType == int32(apisecurity.ResourceType_Services) { + if remove { + delete(saveVal.SvcResources, resource.ResID) + } else { + saveVal.SvcResources[resource.ResID] = "" + } + continue + } + if resource.ResType == int32(apisecurity.ResourceType_ConfigGroups) { + if remove { + delete(saveVal.CfgResources, resource.ResID) + } else { + saveVal.CfgResources[resource.ResID] = "" + } + continue + } + } +} + // DeleteStrategy delete a strategy func (ss *strategyStore) DeleteStrategy(id string) error { if id == "" { @@ -218,7 +264,7 @@ func (ss *strategyStore) operateStrategyResources(remove bool, resources []authc return ErrorStrategyNotFound } - rule.computeResources(remove, ress) + computeResources(remove, ress, rule) rule.ModifyTime = time.Now() if err := saveValue(tx, tblStrategy, rule.ID, rule); err != nil { log.Error("[Store][Strategy] operate strategy resource", zap.Error(err), @@ -236,10 +282,10 @@ func (ss *strategyStore) operateStrategyResources(remove bool, resources []authc return nil } -func loadStrategyById(tx *bolt.Tx, id string) (*strategyData, error) { +func loadStrategyById(tx *bolt.Tx, id string) (*strategyForStore, error) { values := make(map[string]interface{}) - if err := loadValues(tx, tblStrategy, []string{id}, &strategyData{}, values); err != nil { + if err := loadValues(tx, tblStrategy, []string{id}, &strategyForStore{}, values); err != nil { log.Error("[Store][Strategy] get auth_strategy by id", zap.Error(err), zap.String("id", id)) return nil, err @@ -252,9 +298,9 @@ func loadStrategyById(tx *bolt.Tx, id string) (*strategyData, error) { return nil, ErrorMultiDefaultStrategy } - var ret *strategyData + var ret *strategyForStore for _, v := range values { - ret = v.(*strategyData) + ret = v.(*strategyForStore) break } @@ -320,7 +366,7 @@ func (ss *strategyStore) GetStrategyResources(principalId string, fields = []string{StrategyFieldValid, StrategyFieldDefault, StrategyFieldGroupsPrincipal} } - values, err := ss.handler.LoadValuesByFilter(tblStrategy, fields, &strategyData{}, + values, err := ss.handler.LoadValuesByFilter(tblStrategy, fields, &strategyForStore{}, func(m map[string]interface{}) bool { valid, ok := m[StrategyFieldValid].(bool) if ok && !valid { @@ -347,13 +393,43 @@ func (ss *strategyStore) GetStrategyResources(principalId string, ret := make([]authcommon.StrategyResource, 0, 4) for _, item := range values { - rule := item.(*strategyData) - ret = append(ret, rule.GetResources()...) + rule := item.(*strategyForStore) + ret = append(ret, collectStrategyResources(rule)...) } return ret, nil } +func collectStrategyResources(rule *strategyForStore) []authcommon.StrategyResource { + ret := make([]authcommon.StrategyResource, 0, len(rule.NsResources)+len(rule.SvcResources)+len(rule.CfgResources)) + + for id := range rule.NsResources { + ret = append(ret, authcommon.StrategyResource{ + StrategyID: rule.ID, + ResType: int32(apisecurity.ResourceType_Namespaces), + ResID: id, + }) + } + + for id := range rule.SvcResources { + ret = append(ret, authcommon.StrategyResource{ + StrategyID: rule.ID, + ResType: int32(apisecurity.ResourceType_Services), + ResID: id, + }) + } + + for id := range rule.CfgResources { + ret = append(ret, authcommon.StrategyResource{ + StrategyID: rule.ID, + ResType: int32(apisecurity.ResourceType_ConfigGroups), + ResID: id, + }) + } + + return ret +} + // GetDefaultStrategyDetailByPrincipal 获取默认策略详情 func (ss *strategyStore) GetDefaultStrategyDetailByPrincipal(principalId string, principalType authcommon.PrincipalType) (*authcommon.StrategyDetail, error) { @@ -364,7 +440,7 @@ func (ss *strategyStore) GetDefaultStrategyDetailByPrincipal(principalId string, fields = []string{StrategyFieldValid, StrategyFieldDefault, StrategyFieldGroupsPrincipal} } - values, err := ss.handler.LoadValuesByFilter(tblStrategy, fields, &strategyData{}, + values, err := ss.handler.LoadValuesByFilter(tblStrategy, fields, &strategyForStore{}, func(m map[string]interface{}) bool { valid, ok := m[StrategyFieldValid].(bool) if ok && !valid { @@ -401,15 +477,142 @@ func (ss *strategyStore) GetDefaultStrategyDetailByPrincipal(principalId string, return nil, ErrorMultiDefaultStrategy } - var ret *strategyData + var ret *strategyForStore for _, v := range values { - ret = v.(*strategyData) + ret = v.(*strategyForStore) break } return convertForStrategyDetail(ret), nil } +// GetStrategies 查询鉴权策略列表 +func (ss *strategyStore) GetStrategies(filters map[string]string, offset uint32, limit uint32) (uint32, + []*authcommon.StrategyDetail, error) { + + showDetail := filters["show_detail"] + delete(filters, "show_detail") + + return ss.listStrategies(filters, offset, limit, showDetail == "true") +} + +func (ss *strategyStore) listStrategies(filters map[string]string, offset uint32, limit uint32, + showDetail bool) (uint32, []*authcommon.StrategyDetail, error) { + + fields := []string{StrategyFieldValid, StrategyFieldName, StrategyFieldUsersPrincipal, + StrategyFieldGroupsPrincipal, StrategyFieldNsResources, StrategyFieldSvcResources, + StrategyFieldCfgResources, StrategyFieldOwner, StrategyFieldDefault} + + values, err := ss.handler.LoadValuesByFilter(tblStrategy, fields, &strategyForStore{}, + func(m map[string]interface{}) bool { + valid, ok := m[StrategyFieldValid].(bool) + if ok && !valid { + return false + } + + saveName, _ := m[StrategyFieldName].(string) + saveDefault, _ := m[StrategyFieldDefault].(bool) + saveOwner, _ := m[StrategyFieldOwner].(string) + + if name, ok := filters["name"]; ok { + if utils.IsPrefixWildName(name) { + name = name[:len(name)-1] + } + if !strings.Contains(saveName, name) { + return false + } + } + + if owner, ok := filters["owner"]; ok { + if strings.Compare(saveOwner, owner) != 0 { + if principalId, ok := filters["principal_id"]; ok { + principalType := filters["principal_type"] + if !comparePrincipalExist(principalType, principalId, m) { + return false + } + } + } + } + + if isDefault, ok := filters["default"]; ok { + compareParam2BoolNotEqual := func(param string, b bool) bool { + if param == "0" && !b { + return true + } + if param == "1" && b { + return true + } + return false + } + if !compareParam2BoolNotEqual(isDefault, saveDefault) { + return false + } + } + + if resType, ok := filters["res_type"]; ok { + resId := filters["res_id"] + if !compareResExist(resType, resId, m) { + return false + } + } + + if principalId, ok := filters["principal_id"]; ok { + principalType := filters["principal_type"] + if !comparePrincipalExist(principalType, principalId, m) { + return false + } + } + + return true + }) + + if err != nil { + log.Error("[Store][Strategy] get auth_strategy for list", zap.Error(err)) + return 0, nil, err + } + + return uint32(len(values)), doStrategyPage(values, offset, limit, showDetail), nil +} + +func doStrategyPage(ret map[string]interface{}, offset, limit uint32, showDetail bool) []*authcommon.StrategyDetail { + rules := make([]*authcommon.StrategyDetail, 0, len(ret)) + + beginIndex := offset + endIndex := beginIndex + limit + totalCount := uint32(len(ret)) + + if totalCount == 0 { + return rules + } + if beginIndex >= endIndex { + return rules + } + if beginIndex >= totalCount { + return rules + } + if endIndex > totalCount { + endIndex = totalCount + } + + emptyPrincipals := make([]authcommon.Principal, 0) + emptyResources := make([]authcommon.StrategyResource, 0) + + for k := range ret { + rule := convertForStrategyDetail(ret[k].(*strategyForStore)) + if !showDetail { + rule.Principals = emptyPrincipals + rule.Resources = emptyResources + } + rules = append(rules, rule) + } + + sort.Slice(rules, func(i, j int) bool { + return rules[i].ModifyTime.After(rules[j].ModifyTime) + }) + + return rules[beginIndex:endIndex] +} + func compareResExist(resType, resId string, m map[string]interface{}) bool { saveNsRes, _ := m[StrategyFieldNsResources].(map[string]string) saveSvcRes, _ := m[StrategyFieldSvcResources].(map[string]string) @@ -453,7 +656,7 @@ func comparePrincipalExist(principalType, principalId string, m map[string]inter // GetMoreStrategies get strategy details for cache func (ss *strategyStore) GetMoreStrategies(mtime time.Time, firstUpdate bool) ([]*authcommon.StrategyDetail, error) { - ret, err := ss.handler.LoadValuesByFilter(tblStrategy, []string{StrategyFieldModifyTime}, &strategyData{}, + ret, err := ss.handler.LoadValuesByFilter(tblStrategy, []string{StrategyFieldModifyTime}, &strategyForStore{}, func(m map[string]interface{}) bool { mt := m[StrategyFieldModifyTime].(time.Time) isAfter := mt.After(mtime) @@ -468,7 +671,7 @@ func (ss *strategyStore) GetMoreStrategies(mtime time.Time, firstUpdate bool) ([ for k := range ret { val := ret[k] - strategies = append(strategies, convertForStrategyDetail(val.(*strategyData))) + strategies = append(strategies, convertForStrategyDetail(val.(*strategyForStore))) } return strategies, nil @@ -479,7 +682,7 @@ func (ss *strategyStore) CleanPrincipalPolicies(tx store.Tx, p authcommon.Princi values := make(map[string]interface{}) dbTx := tx.GetDelegateTx().(*bolt.Tx) - err := loadValuesByFilter(dbTx, tblStrategy, fields, &strategyData{}, + err := loadValuesByFilter(dbTx, tblStrategy, fields, &strategyForStore{}, func(m map[string]interface{}) bool { isDefault := m[StrategyFieldDefault].(bool) if !isDefault { @@ -532,7 +735,7 @@ func (ss *strategyStore) cleanInvalidStrategy(tx *bolt.Tx, name, owner string) e fields := []string{StrategyFieldName, StrategyFieldOwner, StrategyFieldValid} values := make(map[string]interface{}) - err := loadValuesByFilter(tx, tblStrategy, fields, &strategyData{}, + err := loadValuesByFilter(tx, tblStrategy, fields, &strategyForStore{}, func(m map[string]interface{}) bool { valid, ok := m[StrategyFieldValid].(bool) // 如果数据是 valid 的,则不能被清理 @@ -564,85 +767,7 @@ func (ss *strategyStore) cleanInvalidStrategy(tx *bolt.Tx, name, owner string) e return deleteValues(tx, tblStrategy, keys) } -type strategyData struct { - ID string - Name string - Action string - Comment string - Users map[string]string - Groups map[string]string - Default bool - Owner string - NsResources map[string]string - SvcResources map[string]string - CfgResources map[string]string - AllResources string - CalleeFunctions string - Conditions string - Valid bool - Revision string - CreateTime time.Time - ModifyTime time.Time -} - -func (s *strategyData) computeResources(remove bool, resources []authcommon.StrategyResource) { - saveVal := s.GetResources() - - tmp := make(map[string]authcommon.StrategyResource, 8) - for i := range saveVal { - tmp[saveVal[i].Key()] = saveVal[i] - } - for i := range resources { - resource := resources[i] - if remove { - delete(tmp, resource.Key()) - } else { - tmp[resource.Key()] = resource - } - } - - ret := make([]authcommon.StrategyResource, 0, 8) - for i := range tmp { - ret = append(ret, tmp[i]) - } - - s.AllResources = utils.MustJson(ret) -} - -func (s *strategyData) GetResources() []authcommon.StrategyResource { - ret := make([]authcommon.StrategyResource, 0, len(s.NsResources)+len(s.SvcResources)+len(s.CfgResources)) - - for id := range s.NsResources { - ret = append(ret, authcommon.StrategyResource{ - StrategyID: s.ID, - ResType: int32(apisecurity.ResourceType_Namespaces), - ResID: id, - }) - } - - for id := range s.SvcResources { - ret = append(ret, authcommon.StrategyResource{ - StrategyID: s.ID, - ResType: int32(apisecurity.ResourceType_Services), - ResID: id, - }) - } - - for id := range s.CfgResources { - ret = append(ret, authcommon.StrategyResource{ - StrategyID: s.ID, - ResType: int32(apisecurity.ResourceType_ConfigGroups), - ResID: id, - }) - } - if len(s.AllResources) != 0 { - ret = make([]authcommon.StrategyResource, 0, 4) - _ = json.Unmarshal([]byte(s.AllResources), &ret) - } - return ret -} - -func convertForStrategyStore(strategy *authcommon.StrategyDetail) *strategyData { +func convertForStrategyStore(strategy *authcommon.StrategyDetail) *strategyForStore { var ( users = make(map[string]string, 4) @@ -659,28 +784,48 @@ func convertForStrategyStore(strategy *authcommon.StrategyDetail) *strategyData } } - return &strategyData{ - ID: strategy.ID, - Name: strategy.Name, - Action: strategy.Action, - Comment: strategy.Comment, - Users: users, - Groups: groups, - Default: strategy.Default, - Owner: strategy.Owner, - AllResources: utils.MustJson(strategy.Resources), - CalleeFunctions: utils.MustJson(strategy.CalleeMethods), - Conditions: utils.MustJson(strategy.Conditions), - Valid: strategy.Valid, - Revision: strategy.Revision, - CreateTime: strategy.CreateTime, - ModifyTime: strategy.ModifyTime, + ns := make(map[string]string, 4) + svc := make(map[string]string, 4) + cfg := make(map[string]string, 4) + + resources := strategy.Resources + + for i := range resources { + res := resources[i] + switch res.ResType { + case int32(apisecurity.ResourceType_Namespaces): + ns[res.ResID] = "" + case int32(apisecurity.ResourceType_Services): + svc[res.ResID] = "" + case int32(apisecurity.ResourceType_ConfigGroups): + cfg[res.ResID] = "" + } + } + + return &strategyForStore{ + ID: strategy.ID, + Name: strategy.Name, + Action: strategy.Action, + Comment: strategy.Comment, + Users: users, + Groups: groups, + Default: strategy.Default, + Owner: strategy.Owner, + NsResources: ns, + SvcResources: svc, + CfgResources: cfg, + Valid: strategy.Valid, + Revision: strategy.Revision, + CreateTime: strategy.CreateTime, + ModifyTime: strategy.ModifyTime, } } -func convertForStrategyDetail(strategy *strategyData) *authcommon.StrategyDetail { +func convertForStrategyDetail(strategy *strategyForStore) *authcommon.StrategyDetail { principals := make([]authcommon.Principal, 0, len(strategy.Users)+len(strategy.Groups)) + resources := make([]authcommon.StrategyResource, 0, len(strategy.NsResources)+ + len(strategy.SvcResources)+len(strategy.CfgResources)) for id := range strategy.Users { principals = append(principals, authcommon.Principal{ @@ -697,13 +842,31 @@ func convertForStrategyDetail(strategy *strategyData) *authcommon.StrategyDetail }) } - ret := &authcommon.StrategyDetail{ + fillRes := func(idMap map[string]string, resType apisecurity.ResourceType) []authcommon.StrategyResource { + res := make([]authcommon.StrategyResource, 0, len(idMap)) + + for id := range idMap { + res = append(res, authcommon.StrategyResource{ + StrategyID: strategy.ID, + ResType: int32(resType), + ResID: id, + }) + } + + return res + } + + resources = append(resources, fillRes(strategy.NsResources, apisecurity.ResourceType_Namespaces)...) + resources = append(resources, fillRes(strategy.SvcResources, apisecurity.ResourceType_Services)...) + resources = append(resources, fillRes(strategy.CfgResources, apisecurity.ResourceType_ConfigGroups)...) + + return &authcommon.StrategyDetail{ ID: strategy.ID, Name: strategy.Name, Action: strategy.Action, Comment: strategy.Comment, Principals: principals, - Resources: strategy.GetResources(), + Resources: resources, Default: strategy.Default, Owner: strategy.Owner, Valid: strategy.Valid, @@ -711,18 +874,6 @@ func convertForStrategyDetail(strategy *strategyData) *authcommon.StrategyDetail CreateTime: strategy.CreateTime, ModifyTime: strategy.ModifyTime, } - - if len(strategy.CalleeFunctions) != 0 { - functions := make([]string, 0, 4) - _ = json.Unmarshal([]byte(strategy.CalleeFunctions), &functions) - ret.CalleeMethods = functions - } - if len(strategy.Conditions) != 0 { - condition := make([]authcommon.Condition, 0, 4) - _ = json.Unmarshal([]byte(strategy.Conditions), &condition) - ret.Conditions = condition - } - return ret } func initStrategy(rule *authcommon.StrategyDetail) { diff --git a/store/boltdb/transaction.go b/store/boltdb/transaction.go index ca818103e..60780cbdf 100644 --- a/store/boltdb/transaction.go +++ b/store/boltdb/transaction.go @@ -18,8 +18,6 @@ package boltdb import ( - "time" - "github.com/polarismesh/polaris/common/model" ) @@ -61,12 +59,7 @@ func (t *transaction) RLockNamespace(name string) (*model.Namespace, error) { // DeleteNamespace 删除namespace func (t *transaction) DeleteNamespace(name string) error { - properties := map[string]interface{}{ - CommonFieldValid: false, - CommonFieldModifyTime: time.Now(), - } - - return t.handler.UpdateValue(tblNameNamespace, name, properties) + return t.handler.DeleteValues(tblNameNamespace, []string{name}) } const ( diff --git a/store/boltdb/user.go b/store/boltdb/user.go index 5e8eeb804..cc386a404 100644 --- a/store/boltdb/user.go +++ b/store/boltdb/user.go @@ -90,6 +90,7 @@ func (us *userStore) AddUser(tx store.Tx, user *authcommon.User) error { if owner == "" { owner = user.ID } + // 添加用户信息 if err := saveValue(dbTx, tblUser, user.ID, converToUserStore(user)); err != nil { log.Error("[Store][User] save user fail", zap.Error(err), zap.String("name", user.Name)) @@ -430,15 +431,8 @@ func (us *userStore) getGroupUsers(filters map[string]string, offset uint32, lim // GetUsersForCache 获取所有用户信息 func (us *userStore) GetUsersForCache(mtime time.Time, firstUpdate bool) ([]*authcommon.User, error) { - fields := []string{UserFieldModifyTime, UserFieldValid} - ret, err := us.handler.LoadValuesByFilter(tblUser, fields, &userForStore{}, + ret, err := us.handler.LoadValuesByFilter(tblUser, []string{UserFieldModifyTime}, &userForStore{}, func(m map[string]interface{}) bool { - if firstUpdate { - valid, _ := m[UserFieldValid].(bool) - if !valid { - return false - } - } mt := m[UserFieldModifyTime].(time.Time) isBefore := mt.Before(mtime) return !isBefore diff --git a/store/mock/api_mock.go b/store/mock/api_mock.go index 57ffee0af..3b397247d 100644 --- a/store/mock/api_mock.go +++ b/store/mock/api_mock.go @@ -865,17 +865,17 @@ func (mr *MockStoreMockRecorder) DeleteRateLimit(limiting interface{}) *gomock.C } // DeleteRole mocks base method. -func (m *MockStore) DeleteRole(tx store.Tx, role *auth.Role) error { +func (m *MockStore) DeleteRole(role *auth.Role) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "DeleteRole", tx, role) + ret := m.ctrl.Call(m, "DeleteRole", role) ret0, _ := ret[0].(error) return ret0 } // DeleteRole indicates an expected call of DeleteRole. -func (mr *MockStoreMockRecorder) DeleteRole(tx, role interface{}) *gomock.Call { +func (mr *MockStoreMockRecorder) DeleteRole(role interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteRole", reflect.TypeOf((*MockStore)(nil).DeleteRole), tx, role) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteRole", reflect.TypeOf((*MockStore)(nil).DeleteRole), role) } // DeleteRoutingConfig mocks base method. @@ -1848,21 +1848,6 @@ func (mr *MockStoreMockRecorder) GetRateLimitsForCache(mtime, firstUpdate interf return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRateLimitsForCache", reflect.TypeOf((*MockStore)(nil).GetRateLimitsForCache), mtime, firstUpdate) } -// GetRole mocks base method. -func (m *MockStore) GetRole(id string) (*auth.Role, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetRole", id) - ret0, _ := ret[0].(*auth.Role) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetRole indicates an expected call of GetRole. -func (mr *MockStoreMockRecorder) GetRole(id interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRole", reflect.TypeOf((*MockStore)(nil).GetRole), id) -} - // GetRoutingConfigV2WithID mocks base method. func (m *MockStore) GetRoutingConfigV2WithID(id string) (*model.RouterConfig, error) { m.ctrl.T.Helper() diff --git a/store/mysql/admin.go b/store/mysql/admin.go index 658c9d9cf..b3e775530 100644 --- a/store/mysql/admin.go +++ b/store/mysql/admin.go @@ -468,7 +468,7 @@ func (m *adminStore) BatchCleanDeletedInstances(timeout time.Duration, batchSize } cleanCheckStr := fmt.Sprintf("delete from health_check where id in (%s)", inSql) - if _, err = tx.Exec(cleanCheckStr, waitDelIds...); err != nil { + if _, err := tx.Exec(cleanCheckStr, waitDelIds...); err != nil { log.Errorf("[Store][database] batch clean soft deleted instances(%d), err: %s", batchSize, err.Error()) return store.Error(err) } @@ -487,7 +487,7 @@ func (m *adminStore) BatchCleanDeletedInstances(timeout time.Duration, batchSize return store.Error(err) } - if err = tx.Commit(); err != nil { + if err := tx.Commit(); err != nil { log.Errorf("[Store][database] batch clean soft deleted instances(%d) commit tx err: %s", batchSize, err.Error()) return err diff --git a/store/mysql/config_file.go b/store/mysql/config_file.go index 5854437e8..bc8c35d3d 100644 --- a/store/mysql/config_file.go +++ b/store/mysql/config_file.go @@ -300,7 +300,7 @@ func (cf *configFileStore) QueryConfigFiles(filter map[string]string, offset, li err = cf.slave.processWithTransaction("batch-load-file-tags", func(tx *BaseTx) error { for i := range files { item := files[i] - if err = cf.loadFileTags(tx, item); err != nil { + if err := cf.loadFileTags(tx, item); err != nil { return err } } diff --git a/store/mysql/config_file_release.go b/store/mysql/config_file_release.go index f2fe05457..55e511f5c 100644 --- a/store/mysql/config_file_release.go +++ b/store/mysql/config_file_release.go @@ -52,7 +52,7 @@ func (cfr *configFileReleaseStore) CreateConfigFileReleaseTx(tx store.Tx, data * } clean := "DELETE FROM config_file_release WHERE namespace = ? AND `group` = ? AND file_name = ? AND name = ? AND flag = 1" - if _, err = dbTx.Exec(clean, data.Namespace, data.Group, data.FileName, data.Name); err != nil { + if _, err := dbTx.Exec(clean, data.Namespace, data.Group, data.FileName, data.Name); err != nil { return store.Error(err) } diff --git a/store/mysql/default.go b/store/mysql/default.go index 3d2bb5b39..ae5d8c205 100644 --- a/store/mysql/default.go +++ b/store/mysql/default.go @@ -272,15 +272,12 @@ func (s *stableStore) newStore() { s.configFileTemplateStore = &configFileTemplateStore{master: s.master, slave: s.slave} s.clientStore = &clientStore{master: s.master, slave: s.slave} - s.grayStore = &grayStore{master: s.master, slave: s.slave} - s.adminStore = newAdminStore(s.master) s.toolStore = &toolStore{db: s.master} - s.userStore = &userStore{master: s.master, slave: s.slave} s.groupStore = &groupStore{master: s.master, slave: s.slave} s.strategyStore = &strategyStore{master: s.master, slave: s.slave} - s.roleStore = &roleStore{master: s.master, slave: s.slave} + s.grayStore = &grayStore{master: s.master, slave: s.slave} } func buildEtimeStr(enable bool) string { diff --git a/store/mysql/group.go b/store/mysql/group.go index 93df9d99a..938a10087 100644 --- a/store/mysql/group.go +++ b/store/mysql/group.go @@ -128,14 +128,14 @@ func (u *groupStore) updateGroup(group *authcommon.ModifyUserGroup) error { // 更新用户-用户组关联数据 if len(group.AddUserIds) != 0 { - if err = u.addGroupRelation(tx, group.ID, group.AddUserIds); err != nil { + if err := u.addGroupRelation(tx, group.ID, group.AddUserIds); err != nil { log.Errorf("[Store][Group] add usergroup relation err: %s", err.Error()) return err } } if len(group.RemoveUserIds) != 0 { - if err = u.removeGroupRelation(tx, group.ID, group.RemoveUserIds); err != nil { + if err := u.removeGroupRelation(tx, group.ID, group.RemoveUserIds); err != nil { log.Errorf("[Store][Group] remove usergroup relation err: %s", err.Error()) return err } @@ -153,7 +153,7 @@ func (u *groupStore) updateGroup(group *authcommon.ModifyUserGroup) error { return err } - if err = tx.Commit(); err != nil { + if err := tx.Commit(); err != nil { log.Errorf("[Store][Group] update usergroup tx commit err: %s", err.Error()) return err } diff --git a/store/mysql/lane.go b/store/mysql/lane.go index 705f0b2ab..068cd3603 100644 --- a/store/mysql/lane.go +++ b/store/mysql/lane.go @@ -200,9 +200,9 @@ SELECT id, name, rule, description, revision, flag, UNIX_TIMESTAMP(ctime), UNIX_ for k, v := range filter { switch k { case "name": - if pv, ok := utils.ParseWildName(v); ok { + if v, ok := utils.ParseWildName(v); ok { conditions = append(conditions, "name = ?") - args = append(args, pv) + args = append(args, v) } else { conditions = append(conditions, "name LIKE ?") args = append(args, "%"+v+"%") @@ -295,8 +295,8 @@ func (l *laneStore) getLaneRulesByGroup(tx *BaseTx, names []string) (map[string] } querySql := ` -SELECT id, name, group_name, rule, revision, priority, description, enable, flag, UNIX_TIMESTAMP(ctime), -UNIX_TIMESTAMP(etime), UNIX_TIMESTAMP(mtime) FROM lane_rule WHERE flag = 0 AND group_name IN (%s) +SELECT id, name, group_name, rule, revision, priority, description, enable, flag, UNIX_TIMESTAMP(ctime), UNIX_TIMESTAMP(etime), UNIX_TIMESTAMP(mtime) + FROM lane_rule WHERE flag = 0 AND group_name IN (%s) ` querySql = fmt.Sprintf(querySql, placeholders(len(names))) diff --git a/store/mysql/role.go b/store/mysql/role.go index dfe353626..97732eed9 100644 --- a/store/mysql/role.go +++ b/store/mysql/role.go @@ -21,11 +21,10 @@ import ( "encoding/json" "time" - "go.uber.org/zap" - authcommon "github.com/polarismesh/polaris/common/model/auth" "github.com/polarismesh/polaris/common/utils" "github.com/polarismesh/polaris/store" + "go.uber.org/zap" ) type roleStore struct { @@ -70,16 +69,16 @@ func (s *roleStore) savePrincipals(tx *BaseTx, role *authcommon.Role) error { return err } - insertTpl := "INSERT INTO auth_role_principal(role_id, principal_id, principal_role, extend_info) VALUES (?, ?, ?)" + insertTpl := "INSERT INTO auth_role_principal(role_id, principal_id, principal_role) VALUES (?, ?, ?)" for i := range role.Users { - args := []interface{}{role.ID, role.Users[i].PrincipalID, authcommon.PrincipalUser, utils.MustJson(role.Users[i].Extend)} + args := []interface{}{role.ID, role.Users[i].ID, authcommon.PrincipalUser} if _, err := tx.Exec(insertTpl, args...); err != nil { return err } } for i := range role.UserGroups { - args := []interface{}{role.ID, role.UserGroups[i].PrincipalID, authcommon.PrincipalGroup, utils.MustJson(role.UserGroups[i].Extend)} + args := []interface{}{role.ID, role.UserGroups[i].ID, authcommon.PrincipalGroup} if _, err := tx.Exec(insertTpl, args...); err != nil { return err } @@ -114,16 +113,18 @@ WHERE id = ? } // DeleteRole Delete a role -func (s *roleStore) DeleteRole(tx store.Tx, role *authcommon.Role) error { +func (s *roleStore) DeleteRole(role *authcommon.Role) error { if role.ID == "" { return store.NewStatusError(store.EmptyParamsErr, "role id is empty") } - dbTx := tx.GetDelegateTx().(*BaseTx) - if _, err := dbTx.Exec("UPDATE auth_role SET flag = 1 WHERE id = ?", role.ID); err != nil { - log.Error("[store][role] delete role", zap.String("name", role.Name), zap.Error(err)) - return store.Error(err) - } - return nil + err := s.master.processWithTransaction("delete_role", func(tx *BaseTx) error { + if _, err := tx.Exec("UPDATE auth_role SET flag = 1 WHERE id = ?", role.ID); err != nil { + log.Error("[store][role] delete role", zap.String("name", role.Name), zap.Error(err)) + return err + } + return nil + }) + return store.Error(err) } // CleanPrincipalRoles clean principal roles @@ -162,47 +163,6 @@ func (s *roleStore) CleanPrincipalRoles(tx store.Tx, p *authcommon.Principal) er return nil } -func (s *roleStore) GetRole(id string) (*authcommon.Role, error) { - tx, err := s.master.Begin() - if err != nil { - return nil, store.Error(err) - } - - defer func() { _ = tx.Commit() }() - - querySql := "SELECT id, name, owner, source, role_type, comment, flag, metadata, UNIX_TIMESTAMP(ctime), " + - " UNIX_TIMESTAMP(mtime) FROM auth_role WHERE flag = 0 AND id = ?" - args := []interface{}{id} - - row := tx.QueryRow(querySql, args...) - var ( - ctime, mtime int64 - flag int16 - metadata string - ) - ret := &authcommon.Role{ - Metadata: map[string]string{}, - Users: make([]authcommon.Principal, 0, 4), - UserGroups: make([]authcommon.Principal, 0, 4), - } - - if err := row.Scan(&ret.ID, &ret.Name, &ret.Owner, &ret.Source, &ret.Type, &ret.Comment, - &flag, &metadata, &ctime, &mtime); err != nil { - log.Error("[store][role] fetch one record role info", zap.Error(err)) - return nil, store.Error(err) - } - - ret.CreateTime = time.Unix(ctime, 0) - ret.ModifyTime = time.Unix(mtime, 0) - ret.Valid = flag == 0 - _ = json.Unmarshal([]byte(metadata), &ret.Metadata) - - if err := s.fetchRolePrincipals(tx, ret); err != nil { - return nil, store.Error(err) - } - return ret, nil -} - // GetRole get more role for cache update func (s *roleStore) GetMoreRoles(firstUpdate bool, mtime time.Time) ([]*authcommon.Role, error) { tx, err := s.slave.Begin() @@ -241,8 +201,8 @@ func (s *roleStore) GetMoreRoles(firstUpdate bool, mtime time.Time) ([]*authcomm ) ret := &authcommon.Role{ Metadata: map[string]string{}, - Users: make([]authcommon.Principal, 0, 4), - UserGroups: make([]authcommon.Principal, 0, 4), + Users: make([]*authcommon.User, 0, 4), + UserGroups: make([]*authcommon.UserGroup, 0, 4), } if err := rows.Scan(&ret.ID, &ret.Name, &ret.Owner, &ret.Source, &ret.Type, &ret.Comment, @@ -267,7 +227,7 @@ func (s *roleStore) GetMoreRoles(firstUpdate bool, mtime time.Time) ([]*authcomm } func (s *roleStore) fetchRolePrincipals(tx *BaseTx, role *authcommon.Role) error { - rows, err := tx.Query("SELECT role_id, principal_id, principal_role, extend_info FROM auth_role_principal WHERE rold_id = ?", role.ID) + rows, err := tx.Query("SELECT role_id, principal_id, principal_role FROM auth_role_principal WHERE rold_id = ?", role.ID) if err != nil { log.Error("[store][role] fetch role principals", zap.String("name", role.Name), zap.Error(err)) return store.Error(err) @@ -278,28 +238,21 @@ func (s *roleStore) fetchRolePrincipals(tx *BaseTx, role *authcommon.Role) error for rows.Next() { var ( - roleID, principalID, extendStr string - principalRole int + roleID, principalID string + principalRole int ) - if err := rows.Scan(&roleID, &principalID, &principalRole, &extendStr); err != nil { + if err := rows.Scan(&roleID, &principalID, &principalRole); err != nil { log.Error("[store][role] fetch one record role principal", zap.String("name", role.Name), zap.Error(err)) return store.Error(err) } - extend := map[string]string{} - _ = json.Unmarshal([]byte(extendStr), &extend) - if principalRole == int(authcommon.PrincipalUser) { - role.Users = append(role.Users, authcommon.Principal{ - PrincipalID: principalID, - PrincipalType: authcommon.PrincipalUser, - Extend: extend, + role.Users = append(role.Users, &authcommon.User{ + ID: principalID, }) } else { - role.UserGroups = append(role.UserGroups, authcommon.Principal{ - PrincipalID: principalID, - PrincipalType: authcommon.PrincipalGroup, - Extend: extend, + role.UserGroups = append(role.UserGroups, &authcommon.UserGroup{ + ID: principalID, }) } } diff --git a/store/mysql/scripts/delta/v1_18_1-v1_19_0.sql b/store/mysql/scripts/delta/v1_18_1-v1_18_2.sql similarity index 73% rename from store/mysql/scripts/delta/v1_18_1-v1_19_0.sql rename to store/mysql/scripts/delta/v1_18_1-v1_18_2.sql index 04548adc0..49041073d 100644 --- a/store/mysql/scripts/delta/v1_18_1-v1_19_0.sql +++ b/store/mysql/scripts/delta/v1_18_1-v1_18_2.sql @@ -1,19 +1,3 @@ -/* - * Tencent is pleased to support the open source community by making Polaris available. - * - * Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. - * - * Licensed under the BSD 3-Clause License (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://opensource.org/licenses/BSD-3-Clause - * - * Unless required by applicable law or agreed to in writing, software distributed - * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR - * CONDITIONS OF ANY KIND, either express or implied. See the License for the - * specific language governing permissions and limitations under the License. - */ /* 角色数据 */ CREATE TABLE `auth_role` ( diff --git a/store/mysql/scripts/polaris_server.sql b/store/mysql/scripts/polaris_server.sql index 4d4afb714..da5351911 100644 --- a/store/mysql/scripts/polaris_server.sql +++ b/store/mysql/scripts/polaris_server.sql @@ -649,7 +649,6 @@ CREATE TABLE `strategy_id` VARCHAR(128) NOT NULL COMMENT 'Strategy ID', `principal_id` VARCHAR(128) NOT NULL COMMENT 'Principal ID', `principal_role` INT NOT NULL COMMENT 'PRINCIPAL type, 1 is User, 2 is Group, 3 is Role', - `extend_info` TEXT COMMENT 'link principal extend info', PRIMARY KEY (`strategy_id`, `principal_id`, `principal_role`) ) ENGINE = InnoDB; @@ -689,7 +688,6 @@ CREATE TABLE `role_id` VARCHAR(128) NOT NULL COMMENT 'role id', `principal_id` VARCHAR(128) NOT NULL COMMENT 'principal id', `principal_role` INT NOT NULL COMMENT 'PRINCIPAL type, 1 is User, 2 is Group', - `extend_info` TEXT COMMENT 'link principal extend info', PRIMARY KEY (`role_id`, `principal_id`, `principal_role`) ) ENGINE = InnoDB; @@ -709,6 +707,104 @@ CREATE TABLE `auth_strategy_function` ( PRIMARY KEY (`strategy_id`, `function`) ) ENGINE = InnoDB; +-- Create a default master account, password is Polarismesh @ 2021 +INSERT INTO + `user` ( + `id`, + `name`, + `password`, + `source`, + `token`, + `token_enable`, + `user_type`, + `comment`, + `mobile`, + `email`, + `owner` + ) +VALUES + ( + '65e4789a6d5b49669adf1e9e8387549c', + 'polaris', + '$2a$10$3izWuZtE5SBdAtSZci.gs.iZ2pAn9I8hEqYrC6gwJp1dyjqQnrrum', + 'Polaris', + 'nu/0WRA4EqSR1FagrjRj0fZwPXuGlMpX+zCuWu4uMqy8xr1vRjisSbA25aAC3mtU8MeeRsKhQiDAynUR09I=', + 1, + 20, + 'default polaris admin account', + '12345678910', + '12345678910', + '' + ); + +-- Permissions policy inserted into Polaris-Admin +INSERT INTO + `auth_strategy` ( + `id`, + `name`, + `action`, + `owner`, + `comment`, + `default`, + `revision`, + `flag`, + `ctime`, + `mtime` + ) +VALUES + ( + 'fbca9bfa04ae4ead86e1ecf5811e32a9', + '(用户) polaris的默认策略', + 'READ_WRITE', + '65e4789a6d5b49669adf1e9e8387549c', + 'default admin', + 1, + 'fbca9bfa04ae4ead86e1ecf5811e32a9', + 0, + SYSDATE (), + SYSDATE () + ); + +-- Sport rules inserted into Polaris-Admin to access +INSERT INTO + `auth_strategy_resource` ( + `strategy_id`, + `res_type`, + `res_id`, + `ctime`, + `mtime` + ) +VALUES + ( + 'fbca9bfa04ae4ead86e1ecf5811e32a9', + 0, + '*', + SYSDATE (), + SYSDATE () + ), + ( + 'fbca9bfa04ae4ead86e1ecf5811e32a9', + 1, + '*', + SYSDATE (), + SYSDATE () + ), + ( + 'fbca9bfa04ae4ead86e1ecf5811e32a9', + 2, + '*', + SYSDATE (), + SYSDATE () + ); + +-- Insert permission policies and association relationships for Polaris-Admin accounts +INSERT INTO + auth_principal (`strategy_id`, `principal_id`, `principal_role`) VALUE ( + 'fbca9bfa04ae4ead86e1ecf5811e32a9', + '65e4789a6d5b49669adf1e9e8387549c', + 1 + ); + -- v1.8.0, support client info storage CREATE TABLE `client` ( @@ -787,9 +883,9 @@ VALUES }', 'json', 'Spring Cloud Gateway 染色规则', - NOW(), + NOW (), 'polaris', - NOW(), + NOW (), 'polaris' ); @@ -966,439 +1062,3 @@ CREATE TABLE PRIMARY KEY (`id`), UNIQUE KEY `name` (`group_name`, `name`) ) ENGINE = InnoDB; - - -/* 默认资源信息数据插入 */ - --- Create a default master account, password is Polarismesh @ 2021 -INSERT INTO - `user` ( - `id`, - `name`, - `password`, - `source`, - `token`, - `token_enable`, - `user_type`, - `comment`, - `mobile`, - `email`, - `owner` - ) -VALUES - ( - '65e4789a6d5b49669adf1e9e8387549c', - 'polaris', - '$2a$10$3izWuZtE5SBdAtSZci.gs.iZ2pAn9I8hEqYrC6gwJp1dyjqQnrrum', - 'Polaris', - 'nu/0WRA4EqSR1FagrjRj0fZwPXuGlMpX+zCuWu4uMqy8xr1vRjisSbA25aAC3mtU8MeeRsKhQiDAynUR09I=', - 1, - 20, - 'default polaris admin account', - '12345678910', - '12345678910', - '' - ); - --- Permissions policy inserted into Polaris-Admin -INSERT INTO - `auth_strategy` ( - `id`, - `name`, - `action`, - `owner`, - `comment`, - `default`, - `revision`, - `flag`, - `ctime`, - `mtime` - ) -VALUES - ( - 'fbca9bfa04ae4ead86e1ecf5811e32a9', - '(用户) polaris的默认策略', - 'READ_WRITE', - '65e4789a6d5b49669adf1e9e8387549c', - 'default admin', - 1, - 'fbca9bfa04ae4ead86e1ecf5811e32a9', - 0, - sysdate(), - sysdate() - ); - --- Sport rules inserted into Polaris-Admin to access -INSERT INTO - `auth_strategy_resource` ( - `strategy_id`, - `res_type`, - `res_id`, - `ctime`, - `mtime` - ) -VALUES - ( - 'fbca9bfa04ae4ead86e1ecf5811e32a9', - 0, - '*', - sysdate(), - sysdate() - ), - ( - 'fbca9bfa04ae4ead86e1ecf5811e32a9', - 1, - '*', - sysdate(), - sysdate() - ), - ( - 'fbca9bfa04ae4ead86e1ecf5811e32a9', - 2, - '*', - sysdate(), - sysdate() - ), - ( - 'fbca9bfa04ae4ead86e1ecf5811e32a9', - 3, - '*', - sysdate(), - sysdate() - ), - ( - 'fbca9bfa04ae4ead86e1ecf5811e32a9', - 4, - '*', - sysdate(), - sysdate() - ), - ( - 'fbca9bfa04ae4ead86e1ecf5811e32a9', - 5, - '*', - sysdate(), - sysdate() - ), - ( - 'fbca9bfa04ae4ead86e1ecf5811e32a9', - 6, - '*', - sysdate(), - sysdate() - ), - ( - 'fbca9bfa04ae4ead86e1ecf5811e32a9', - 7, - '*', - sysdate(), - sysdate() - ), - ( - 'fbca9bfa04ae4ead86e1ecf5811e32a9', - 20, - '*', - sysdate(), - sysdate() - ), - ( - 'fbca9bfa04ae4ead86e1ecf5811e32a9', - 21, - '*', - sysdate(), - sysdate() - ), - ( - 'fbca9bfa04ae4ead86e1ecf5811e32a9', - 22, - '*', - sysdate(), - sysdate() - ), - ( - 'fbca9bfa04ae4ead86e1ecf5811e32a9', - 23, - '*', - sysdate(), - sysdate() - ); - --- Insert permission policies and association relationships for Polaris-Admin accounts -INSERT INTO - auth_principal (`strategy_id`, `principal_id`, `principal_role`) VALUES ( - 'fbca9bfa04ae4ead86e1ecf5811e32a9', - '65e4789a6d5b49669adf1e9e8387549c', - 1 - ); - -INSERT INTO - auth_strategy_function (`strategy_id`, `function`) VALUES ( - 'fbca9bfa04ae4ead86e1ecf5811e32a9', - '*' - ); - -/* 默认的全局只读策略 */ -INSERT INTO - `auth_strategy` ( - `id`, - `name`, - `action`, - `owner`, - `comment`, - `default`, - `revision`, - `flag`, - `ctime`, - `mtime` - ) -VALUES - ( - 'bfa04ae1e32a94fbca9ead86e1ecf581', - '全局只读策略', - 'ALLOW', - '65e4789a6d5b49669adf1e9e8387549c', - 'global resources read onyly', - 1, - 'fbca9bfa04ae4ead86e1ecf5811e32a9', - 0, - sysdate(), - sysdate() - ); - -INSERT INTO - `auth_strategy_resource` ( - `strategy_id`, - `res_type`, - `res_id`, - `ctime`, - `mtime` - ) -VALUES - ( - 'bfa04ae1e32a94fbca9ead86e1ecf581', - 0, - '*', - sysdate(), - sysdate() - ), - ( - 'bfa04ae1e32a94fbca9ead86e1ecf581', - 1, - '*', - sysdate(), - sysdate() - ), - ( - 'bfa04ae1e32a94fbca9ead86e1ecf581', - 2, - '*', - sysdate(), - sysdate() - ), - ( - 'bfa04ae1e32a94fbca9ead86e1ecf581', - 3, - '*', - sysdate(), - sysdate() - ), - ( - 'bfa04ae1e32a94fbca9ead86e1ecf581', - 4, - '*', - sysdate(), - sysdate() - ), - ( - 'bfa04ae1e32a94fbca9ead86e1ecf581', - 5, - '*', - sysdate(), - sysdate() - ), - ( - 'bfa04ae1e32a94fbca9ead86e1ecf581', - 6, - '*', - sysdate(), - sysdate() - ), - ( - 'bfa04ae1e32a94fbca9ead86e1ecf581', - 7, - '*', - sysdate(), - sysdate() - ), - ( - 'bfa04ae1e32a94fbca9ead86e1ecf581', - 20, - '*', - sysdate(), - sysdate() - ), - ( - 'bfa04ae1e32a94fbca9ead86e1ecf581', - 21, - '*', - sysdate(), - sysdate() - ), - ( - 'bfa04ae1e32a94fbca9ead86e1ecf581', - 22, - '*', - sysdate(), - sysdate() - ), - ( - 'bfa04ae1e32a94fbca9ead86e1ecf581', - 23, - '*', - sysdate(), - sysdate() - ); - -INSERT INTO - auth_strategy_function (`strategy_id`, `function`) VALUES ( - 'fbca9bfa04ae4ead86e1ecf5811e32a9', - 'Describe*' - ), - ( - 'fbca9bfa04ae4ead86e1ecf5811e32a9', - 'List*' - ), - ( - 'fbca9bfa04ae4ead86e1ecf5811e32a9', - 'Get*' - ); - - -/* 默认的全局读写策略 */ -INSERT INTO - `auth_strategy` ( - `id`, - `name`, - `action`, - `owner`, - `comment`, - `default`, - `revision`, - `flag`, - `ctime`, - `mtime` - ) -VALUES - ( - 'e3d86e1ecf5812bfa04ae1a94fbca9ea', - '全局读写策略', - 'ALLOW', - '65e4789a6d5b49669adf1e9e8387549c', - 'global resources read and write', - 1, - 'fbca9bfa04ae4ead86e1ecf5811e32a9', - 0, - sysdate(), - sysdate() - ); - -INSERT INTO - `auth_strategy_resource` ( - `strategy_id`, - `res_type`, - `res_id`, - `ctime`, - `mtime` - ) -VALUES - ( - 'e3d86e1ecf5812bfa04ae1a94fbca9ea', - 0, - '*', - sysdate(), - sysdate() - ), - ( - 'e3d86e1ecf5812bfa04ae1a94fbca9ea', - 1, - '*', - sysdate(), - sysdate() - ), - ( - 'e3d86e1ecf5812bfa04ae1a94fbca9ea', - 2, - '*', - sysdate(), - sysdate() - ), - ( - 'e3d86e1ecf5812bfa04ae1a94fbca9ea', - 3, - '*', - sysdate(), - sysdate() - ), - ( - 'e3d86e1ecf5812bfa04ae1a94fbca9ea', - 4, - '*', - sysdate(), - sysdate() - ), - ( - 'e3d86e1ecf5812bfa04ae1a94fbca9ea', - 5, - '*', - sysdate(), - sysdate() - ), - ( - 'e3d86e1ecf5812bfa04ae1a94fbca9ea', - 6, - '*', - sysdate(), - sysdate() - ), - ( - 'e3d86e1ecf5812bfa04ae1a94fbca9ea', - 7, - '*', - sysdate(), - sysdate() - ), - ( - 'e3d86e1ecf5812bfa04ae1a94fbca9ea', - 20, - '*', - sysdate(), - sysdate() - ), - ( - 'e3d86e1ecf5812bfa04ae1a94fbca9ea', - 21, - '*', - sysdate(), - sysdate() - ), - ( - 'e3d86e1ecf5812bfa04ae1a94fbca9ea', - 22, - '*', - sysdate(), - sysdate() - ), - ( - 'e3d86e1ecf5812bfa04ae1a94fbca9ea', - 23, - '*', - sysdate(), - sysdate() - ); - -INSERT INTO - auth_strategy_function (`strategy_id`, `function`) VALUES ( - 'e3d86e1ecf5812bfa04ae1a94fbca9ea', - '*' - ); - diff --git a/store/mysql/strategy.go b/store/mysql/strategy.go index 23c8e096d..cd95c2fce 100644 --- a/store/mysql/strategy.go +++ b/store/mysql/strategy.go @@ -19,7 +19,6 @@ package sqldb import ( "database/sql" - "encoding/json" "fmt" "strings" "time" @@ -76,30 +75,23 @@ func (s *strategyStore) AddStrategy(tx store.Tx, strategy *authcommon.StrategyDe isDefault = 1 } - if err := s.addPolicyPrincipals(dbTx, strategy.ID, strategy.Principals); err != nil { + if err := s.addStrategyPrincipals(dbTx, strategy.ID, strategy.Principals); err != nil { log.Error("[Store][Strategy] add auth_strategy principals", zap.Error(err)) return err } - if err := s.addPolicyResources(dbTx, strategy.ID, strategy.Resources); err != nil { + + if err := s.addStrategyResources(dbTx, strategy.ID, strategy.Resources); err != nil { log.Error("[Store][Strategy] add auth_strategy resources", zap.Error(err)) return err } - if err := s.savePolicyFunctions(dbTx, strategy.ID, strategy.CalleeMethods); err != nil { - log.Error("[Store][Strategy] save auth_strategy functions", zap.Error(err)) - return err - } - if err := s.savePolicyConditions(dbTx, strategy.ID, strategy.Conditions); err != nil { - log.Error("[Store][Strategy] save auth_strategy conditions", zap.Error(err)) - return err - } // 保存策略主信息 saveMainSql := "INSERT INTO auth_strategy(`id`, `name`, `action`, `owner`, `comment`, `flag`, " + - " `default`, `revision`, `source`, `metadata`) VALUES (?,?,?,?,?,?,?,?,?,?)" + " `default`, `revision`) VALUES (?,?,?,?,?,?,?,?)" if _, err := dbTx.Exec(saveMainSql, []interface{}{ strategy.ID, strategy.Name, strategy.Action, strategy.Owner, strategy.Comment, - 0, isDefault, strategy.Revision, strategy.Source, utils.MustJson(strategy.Metadata)}..., + 0, isDefault, strategy.Revision}..., ); err != nil { log.Error("[Store][Strategy] add auth_strategy main info", zap.Error(err)) return err @@ -128,34 +120,25 @@ func (s *strategyStore) updateStrategy(strategy *authcommon.ModifyStrategyDetail defer func() { _ = tx.Rollback() }() // 调整 principal 信息 - if err = s.addPolicyPrincipals(tx, strategy.ID, strategy.AddPrincipals); err != nil { + if err := s.addStrategyPrincipals(tx, strategy.ID, strategy.AddPrincipals); err != nil { log.Errorf("[Store][Strategy] add strategy principal err: %s", err.Error()) return err } - if err = s.deletePolicyPrincipals(tx, strategy.ID, strategy.RemovePrincipals); err != nil { + if err := s.deleteStrategyPrincipals(tx, strategy.ID, strategy.RemovePrincipals); err != nil { log.Errorf("[Store][Strategy] remove strategy principal err: %s", err.Error()) return err } // 调整鉴权资源信息 - if err = s.addPolicyResources(tx, strategy.ID, strategy.AddResources); err != nil { + if err := s.addStrategyResources(tx, strategy.ID, strategy.AddResources); err != nil { log.Errorf("[Store][Strategy] add strategy resource err: %s", err.Error()) return err } - if err = s.deletePolicyResources(tx, strategy.ID, strategy.RemoveResources); err != nil { + if err := s.deleteStrategyResources(tx, strategy.ID, strategy.RemoveResources); err != nil { log.Errorf("[Store][Strategy] remove strategy resource err: %s", err.Error()) return err } - if err = s.savePolicyFunctions(tx, strategy.ID, strategy.CalleeMethods); err != nil { - log.Error("[Store][Strategy] save auth_strategy functions", zap.Error(err)) - return err - } - if err = s.savePolicyConditions(tx, strategy.ID, strategy.Conditions); err != nil { - log.Error("[Store][Strategy] save auth_strategy conditions", zap.Error(err)) - return err - } - // 保存策略主信息 saveMainSql := "UPDATE auth_strategy SET action = ?, comment = ?, mtime = sysdate() WHERE id = ?" if _, err = tx.Exec(saveMainSql, []interface{}{strategy.Action, strategy.Comment, strategy.ID}...); err != nil { @@ -163,7 +146,7 @@ func (s *strategyStore) updateStrategy(strategy *authcommon.ModifyStrategyDetail return err } - if err = tx.Commit(); err != nil { + if err := tx.Commit(); err != nil { log.Errorf("[Store][Strategy] update auth_strategy tx commit err: %s", err.Error()) return err } @@ -191,19 +174,21 @@ func (s *strategyStore) deleteStrategy(id string) error { defer func() { _ = tx.Rollback() }() - if _, err = tx.Exec("UPDATE auth_strategy SET flag = 1, mtime = sysdate() WHERE id = ?", id); err != nil { - return err - } - if _, err = tx.Exec("DELETE FROM auth_strategy_resource WHERE strategy_id = ?", id); err != nil { - return err - } - if _, err = tx.Exec("DELETE FROM auth_principal WHERE strategy_id = ?", id); err != nil { + if _, err = tx.Exec("UPDATE auth_strategy SET flag = 1, mtime = sysdate() WHERE id = ?", []interface{}{ + id, + }...); err != nil { return err } - if _, err = tx.Exec("DELETE FROM auth_strategy_function WHERE strategy_id = ?", id); err != nil { + + if _, err = tx.Exec("DELETE FROM auth_strategy_resource WHERE strategy_id = ?", []interface{}{ + id, + }...); err != nil { return err } - if _, err = tx.Exec("DELETE FROM auth_strategy_label WHERE strategy_id = ?", id); err != nil { + + if _, err = tx.Exec("DELETE FROM auth_principal WHERE strategy_id = ?", []interface{}{ + id, + }...); err != nil { return err } @@ -214,77 +199,20 @@ func (s *strategyStore) deleteStrategy(id string) error { return nil } -// savePolicyFunctions -func (s *strategyStore) savePolicyFunctions(tx *BaseTx, id string, functions []string) error { - if len(functions) == 0 { - return nil - } - - if _, err := tx.Exec("DELETE FROM auth_strategy_function WHERE strategy_id = ?", id); err != nil { - return err - } - - savePrincipalSql := "INSERT IGNORE INTO auth_strategy_function(`strategy_id`, `function`) VALUES " - values := make([]string, 0) - args := make([]interface{}, 0) - - for i := range functions { - values = append(values, "(?,?)") - args = append(args, id, functions[i]) - } - - savePrincipalSql += strings.Join(values, ",") - - log.Debug("[Store][Strategy] save policy functions", zap.String("sql", savePrincipalSql), - zap.Any("args", args)) - - _, err := tx.Exec(savePrincipalSql, args...) - return err -} - -// savePolicyConditions -func (s *strategyStore) savePolicyConditions(tx *BaseTx, id string, conditions []authcommon.Condition) error { - if len(conditions) == 0 { - return nil - } - - if _, err := tx.Exec("DELETE FROM auth_strategy_label WHERE strategy_id = ?", id); err != nil { - return err - } - - savePrincipalSql := "INSERT IGNORE INTO auth_strategy_label(`strategy_id`, `key`, `value`, `compare_type`) VALUES " - values := make([]string, 0) - args := make([]interface{}, 0) - - for i := range conditions { - item := conditions[i] - values = append(values, "(?,?,?,?)") - args = append(args, id, item.Key, item.Value, item.CompareFunc) - } - - savePrincipalSql += strings.Join(values, ",") - - log.Debug("[Store][Strategy] save policy conditions", zap.String("sql", savePrincipalSql), - zap.Any("args", args)) - - _, err := tx.Exec(savePrincipalSql, args...) - return err -} - -// addPolicyPrincipals -func (s *strategyStore) addPolicyPrincipals(tx *BaseTx, id string, principals []authcommon.Principal) error { +// addStrategyPrincipals +func (s *strategyStore) addStrategyPrincipals(tx *BaseTx, id string, principals []authcommon.Principal) error { if len(principals) == 0 { return nil } - savePrincipalSql := "INSERT IGNORE INTO auth_principal(strategy_id, principal_id, principal_role, extend_info) VALUES " + savePrincipalSql := "INSERT IGNORE INTO auth_principal(strategy_id, principal_id, principal_role) VALUES " values := make([]string, 0) args := make([]interface{}, 0) for i := range principals { principal := principals[i] - values = append(values, "(?,?,?,?)") - args = append(args, id, principal.PrincipalID, principal.PrincipalType, utils.MustJson(principal.Extend)) + values = append(values, "(?,?,?)") + args = append(args, id, principal.PrincipalID, principal.PrincipalType) } savePrincipalSql += strings.Join(values, ",") @@ -296,8 +224,8 @@ func (s *strategyStore) addPolicyPrincipals(tx *BaseTx, id string, principals [] return err } -// deletePolicyPrincipals -func (s *strategyStore) deletePolicyPrincipals(tx *BaseTx, id string, +// deleteStrategyPrincipals +func (s *strategyStore) deleteStrategyPrincipals(tx *BaseTx, id string, principals []authcommon.Principal) error { if len(principals) == 0 { return nil @@ -317,8 +245,7 @@ func (s *strategyStore) deletePolicyPrincipals(tx *BaseTx, id string, return nil } -// addPolicyResources . -func (s *strategyStore) addPolicyResources(tx *BaseTx, id string, resources []authcommon.StrategyResource) error { +func (s *strategyStore) addStrategyResources(tx *BaseTx, id string, resources []authcommon.StrategyResource) error { if len(resources) == 0 { return nil } @@ -344,8 +271,7 @@ func (s *strategyStore) addPolicyResources(tx *BaseTx, id string, resources []au return err } -// deletePolicyResources . -func (s *strategyStore) deletePolicyResources(tx *BaseTx, id string, +func (s *strategyStore) deleteStrategyResources(tx *BaseTx, id string, resources []authcommon.StrategyResource) error { if len(resources) == 0 { @@ -354,9 +280,11 @@ func (s *strategyStore) deletePolicyResources(tx *BaseTx, id string, for i := range resources { resource := resources[i] + saveResSql := "DELETE FROM auth_strategy_resource WHERE strategy_id = ? AND res_id = ? AND res_type = ?" if _, err := tx.Exec( - saveResSql, []interface{}{resource.StrategyID, resource.ResID, resource.ResType}..., + saveResSql, + []interface{}{resource.StrategyID, resource.ResID, resource.ResType}..., ); err != nil { return err } @@ -420,7 +348,7 @@ func (s *strategyStore) RemoveStrategyResources(resources []authcommon.StrategyR saveResSql = "DELETE FROM auth_strategy_resource WHERE res_id = ? AND res_type = ?" args = append(args, resource.ResID, resource.ResType) } - if _, err = tx.Exec(saveResSql, args...); err != nil { + if _, err := tx.Exec(saveResSql, args...); err != nil { return err } // 主要是为了能够触发 StrategyCache 的刷新逻辑 @@ -430,7 +358,7 @@ func (s *strategyStore) RemoveStrategyResources(resources []authcommon.StrategyR } } - if err = tx.Commit(); err != nil { + if err := tx.Commit(); err != nil { log.Errorf("[Store][Strategy] add auth_strategy tx commit err: %s", err.Error()) return err } @@ -517,6 +445,168 @@ func (s *strategyStore) getStrategyDetail(row *sql.Row) (*authcommon.StrategyDet return ret, nil } +// GetStrategies 获取策略列表 +func (s *strategyStore) GetStrategies(filters map[string]string, offset uint32, limit uint32) (uint32, + []*authcommon.StrategyDetail, error) { + showDetail := filters["show_detail"] + delete(filters, "show_detail") + + filters["ag.flag"] = "0" + + return s.listStrategies(filters, offset, limit, showDetail == "true") +} + +// listStrategies +func (s *strategyStore) listStrategies(filters map[string]string, offset uint32, limit uint32, + showDetail bool) (uint32, []*authcommon.StrategyDetail, error) { + + querySql := + `SELECT + ag.id, + ag.name, + ag.action, + ag.owner, + ag.comment, + ag.default, + ag.revision, + ag.flag, + UNIX_TIMESTAMP(ag.ctime), + UNIX_TIMESTAMP(ag.mtime) + FROM + ( + auth_strategy ag + LEFT JOIN auth_strategy_resource ar ON ag.id = ar.strategy_id + ) + LEFT JOIN auth_principal ap ON ag.id = ap.strategy_id ` + countSql := ` + SELECT COUNT(DISTINCT ag.id) + FROM + ( + auth_strategy ag + LEFT JOIN auth_strategy_resource ar ON ag.id = ar.strategy_id + ) + LEFT JOIN auth_principal ap ON ag.id = ap.strategy_id + ` + + return s.queryStrategies(s.master.Query, filters, RuleFilters, querySql, countSql, + offset, limit, showDetail) +} + +// queryStrategies 通用的查询策略列表 +func (s *strategyStore) queryStrategies( + handler QueryHandler, + filters map[string]string, mapping map[string]string, + querySqlPrefix string, countSqlPrefix string, + offset uint32, limit uint32, showDetail bool) (uint32, []*authcommon.StrategyDetail, error) { + querySql := querySqlPrefix + countSql := countSqlPrefix + + args := make([]interface{}, 0) + if len(filters) != 0 { + querySql += " WHERE " + countSql += " WHERE " + firstIndex := true + for k, v := range filters { + needLike := false + if !firstIndex { + querySql += " AND " + countSql += " AND " + } + firstIndex = false + + if val, ok := mapping[k]; ok { + if _, exist := RuleNeedLikeFilters[k]; exist { + needLike = true + } + k = val + } + + if needLike { + if utils.IsPrefixWildName(v) { + v = v[:len(v)-1] + } + querySql += (" " + k + " like ? ") + countSql += (" " + k + " like ? ") + args = append(args, "%"+v+"%") + } else if k == "ag.owner" { + querySql += " (ag.owner = ? OR (ap.principal_id = ? AND ap.principal_role = 1 )) " + countSql += " (ag.owner = ? OR (ap.principal_id = ? AND ap.principal_role = 1 )) " + args = append(args, v, v) + } else { + querySql += (" " + k + " = ? ") + countSql += (" " + k + " = ? ") + args = append(args, v) + } + } + } + + count, err := queryEntryCount(s.master, countSql, args) + if err != nil { + return 0, nil, store.Error(err) + } + + querySql += " GROUP BY ag.id ORDER BY ag.mtime LIMIT ?, ? " + args = append(args, offset, limit) + + ret, err := s.collectStrategies(s.master.Query, querySql, args, showDetail) + if err != nil { + return 0, nil, err + } + + return count, ret, nil +} + +// collectStrategies 执行真正的 sql 并从 rows 中获取策略列表 +func (s *strategyStore) collectStrategies(handler QueryHandler, querySql string, + args []interface{}, showDetail bool) ([]*authcommon.StrategyDetail, error) { + log.Debug("[Store][Strategy] get simple strategies", zap.String("query sql", querySql), + zap.Any("args", args)) + + rows, err := handler(querySql, args...) + if err != nil { + log.Error("[Store][Strategy] get simple strategies", zap.String("query sql", querySql), + zap.Any("args", args)) + return nil, store.Error(err) + } + defer func() { + _ = rows.Close() + }() + + idMap := make(map[string]struct{}) + + ret := make([]*authcommon.StrategyDetail, 0, 16) + for rows.Next() { + detail, err := fetchRown2StrategyDetail(rows) + if err != nil { + return nil, store.Error(err) + } + + // 为了避免数据重复被加入到 slice 中,做一个 map 去重 + if _, ok := idMap[detail.ID]; ok { + continue + } + idMap[detail.ID] = struct{}{} + + if showDetail { + resArr, err := s.getStrategyResources(s.slave.Query, detail.ID) + if err != nil { + return nil, store.Error(err) + } + principals, err := s.getStrategyPrincipals(s.slave.Query, detail.ID) + if err != nil { + return nil, store.Error(err) + } + + detail.Resources = resArr + detail.Principals = principals + } + + ret = append(ret, detail) + } + + return ret, nil +} + func (s *strategyStore) GetMoreStrategies(mtime time.Time, firstUpdate bool) ([]*authcommon.StrategyDetail, error) { tx, err := s.slave.Begin() if err != nil { @@ -556,19 +646,9 @@ func (s *strategyStore) GetMoreStrategies(mtime time.Time, firstUpdate bool) ([] if err != nil { return nil, store.Error(err) } - conditions, err := s.getStrategyConditions(s.slave.Query, detail.ID) - if err != nil { - return nil, store.Error(err) - } - functions, err := s.getStrategyFunctions(s.slave.Query, detail.ID) - if err != nil { - return nil, store.Error(err) - } detail.Resources = resArr detail.Principals = principals - detail.CalleeMethods = functions - detail.Conditions = conditions ret = append(ret, detail) } @@ -613,7 +693,7 @@ func (s *strategyStore) GetStrategyResources(principalId string, func (s *strategyStore) getStrategyPrincipals(queryHander QueryHandler, id string) ([]authcommon.Principal, error) { - rows, err := queryHander("SELECT principal_id, principal_role, extend_info FROM auth_principal WHERE strategy_id = ?", id) + rows, err := queryHander("SELECT principal_id, principal_role FROM auth_principal WHERE strategy_id = ?", id) if err != nil { switch err { case sql.ErrNoRows: @@ -629,72 +709,15 @@ func (s *strategyStore) getStrategyPrincipals(queryHander QueryHandler, id strin for rows.Next() { res := new(authcommon.Principal) - var extend string - if err := rows.Scan(&res.PrincipalID, &res.PrincipalType, &extend); err != nil { + if err := rows.Scan(&res.PrincipalID, &res.PrincipalType); err != nil { return nil, store.Error(err) } - res.Extend = map[string]string{} - _ = json.Unmarshal([]byte(extend), &res.Extend) principals = append(principals, *res) } return principals, nil } -func (s *strategyStore) getStrategyConditions(queryHander QueryHandler, id string) ([]authcommon.Condition, error) { - - rows, err := queryHander("SELECT `key`, `value`, `compare_type` FROM auth_strategy_label WHERE strategy_id = ?", id) - if err != nil { - switch err { - case sql.ErrNoRows: - log.Info("[Store][Strategy] not found link condition", zap.String("strategy-id", id)) - return nil, nil - default: - return nil, store.Error(err) - } - } - defer rows.Close() - - conditions := make([]authcommon.Condition, 0) - - for rows.Next() { - res := new(authcommon.Condition) - if err := rows.Scan(&res.Key, &res.Value, &res.CompareFunc); err != nil { - return nil, store.Error(err) - } - conditions = append(conditions, *res) - } - - return conditions, nil -} - -func (s *strategyStore) getStrategyFunctions(queryHander QueryHandler, id string) ([]string, error) { - - rows, err := queryHander("SELECT `function` FROM auth_strategy_label WHERE strategy_id = ?", id) - if err != nil { - switch err { - case sql.ErrNoRows: - log.Info("[Store][Strategy] not found link functions", zap.String("strategy-id", id)) - return nil, nil - default: - return nil, store.Error(err) - } - } - defer rows.Close() - - functions := make([]string, 0) - - for rows.Next() { - var item string - if err := rows.Scan(&item); err != nil { - return nil, store.Error(err) - } - functions = append(functions, item) - } - - return functions, nil -} - func (s *strategyStore) getStrategyResources(queryHander QueryHandler, id string) ([]authcommon.StrategyResource, error) { querySql := "SELECT res_id, res_type FROM auth_strategy_resource WHERE strategy_id = ?" rows, err := queryHander(querySql, id) diff --git a/store/mysql/tool.go b/store/mysql/tool.go index 55c3c4897..62a6fd52a 100644 --- a/store/mysql/tool.go +++ b/store/mysql/tool.go @@ -42,7 +42,7 @@ func (t *toolStore) GetUnixSecond(maxWait time.Duration) (int64, error) { defer rows.Close() timePass := time.Since(startTime) if maxWait != 0 && timePass > maxWait { - log.Warnf("[Store][database] query now spend %s, exceed %s, skip", timePass, maxWait) + log.Infof("[Store][database] query now spend %s, exceed %s, skip", timePass, maxWait) return 0, nil } var value int64 diff --git a/store/mysql/user.go b/store/mysql/user.go index cdc11f611..2a94f51ef 100644 --- a/store/mysql/user.go +++ b/store/mysql/user.go @@ -69,11 +69,19 @@ func (u *userStore) AddUser(tx store.Tx, user *authcommon.User) error { } func (u *userStore) addUser(tx *BaseTx, user *authcommon.User) error { + + tx, err := u.master.Begin() + if err != nil { + return err + } + + defer func() { _ = tx.Rollback() }() + addSql := "INSERT INTO user(`id`, `name`, `password`, `owner`, `source`, `token`, " + " `comment`, `flag`, `user_type`, " + " `ctime`, `mtime`, `mobile`, `email`) VALUES (?,?,?,?,?,?,?,?,?,sysdate(),sysdate(),?,?)" - _, err := tx.Exec(addSql, []interface{}{ + _, err = tx.Exec(addSql, []interface{}{ user.ID, user.Name, user.Password, diff --git a/test/data/bolt-data.yaml b/test/data/bolt-data.yaml deleted file mode 100644 index 2c7827204..000000000 --- a/test/data/bolt-data.yaml +++ /dev/null @@ -1,174 +0,0 @@ -# Tencent is pleased to support the open source community by making Polaris available. -# -# Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. -# -# Licensed under the BSD 3-Clause License (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://opensource.org/licenses/BSD-3-Clause -# -# Unless required by applicable law or agreed to in writing, software distributed -# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR -# CONDITIONS OF ANY KIND, either express or implied. See the License for the -# specific language governing permissions and limitations under the License. - -users: - - name: polaris - token: nu/0WRA4EqSR1FagrjRj0fZwPXuGlMpX+zCuWu4uMqy8xr1vRjisSbA25aAC3mtU8MeeRsKhQiDAynUR09I= - password: $2a$10$3izWuZtE5SBdAtSZci.gs.iZ2pAn9I8hEqYrC6gwJp1dyjqQnrrum - id: 65e4789a6d5b49669adf1e9e8387549c - tokenenable: true - type: 20 - valid: true -policies: - - id: fbca9bfa04ae4ead86e1ecf5811e32a9 - name: (用户) polaris的默认策略 - action: READ_WRITE - comment: default admin - default: true - owner: 65e4789a6d5b49669adf1e9e8387549c - calleemethods: ["*"] - resources: - - strategyid: fbca9bfa04ae4ead86e1ecf5811e32a9 - restype: 6 - resid: "*" - - strategyid: fbca9bfa04ae4ead86e1ecf5811e32a9 - restype: 7 - resid: "*" - - strategyid: fbca9bfa04ae4ead86e1ecf5811e32a9 - restype: 20 - resid: "*" - - strategyid: fbca9bfa04ae4ead86e1ecf5811e32a9 - restype: 0 - resid: "*" - - strategyid: fbca9bfa04ae4ead86e1ecf5811e32a9 - restype: 3 - resid: "*" - - strategyid: fbca9bfa04ae4ead86e1ecf5811e32a9 - restype: 4 - resid: "*" - - strategyid: fbca9bfa04ae4ead86e1ecf5811e32a9 - restype: 5 - resid: "*" - - strategyid: fbca9bfa04ae4ead86e1ecf5811e32a9 - restype: 21 - resid: "*" - - strategyid: fbca9bfa04ae4ead86e1ecf5811e32a9 - restype: 22 - resid: "*" - - strategyid: fbca9bfa04ae4ead86e1ecf5811e32a9 - restype: 23 - resid: "*" - - strategyid: fbca9bfa04ae4ead86e1ecf5811e32a9 - restype: 1 - resid: "*" - - strategyid: fbca9bfa04ae4ead86e1ecf5811e32a9 - restype: 2 - resid: "*" - conditions: [] - principals: - - strategyid: fbca9bfa04ae4ead86e1ecf5811e32a9 - principalid: 65e4789a6d5b49669adf1e9e8387549c - principaltype: 1 - valid: true - revision: fbca9bfa04ae4ead86e1ecf5811e32a9 - metadata: {} - - id: bfa04ae1e32a94fbca9ead86e1ecf581 - name: 全局只读策略 - action: ALLOW - comment: global resources read onyly - default: false - owner: 65e4789a6d5b49669adf1e9e8387549c - calleemethods: ["Describe*", "List*", "Get*"] - resources: - - strategyid: bfa04ae1e32a94fbca9ead86e1ecf581 - restype: 6 - resid: "*" - - strategyid: bfa04ae1e32a94fbca9ead86e1ecf581 - restype: 7 - resid: "*" - - strategyid: bfa04ae1e32a94fbca9ead86e1ecf581 - restype: 20 - resid: "*" - - strategyid: bfa04ae1e32a94fbca9ead86e1ecf581 - restype: 0 - resid: "*" - - strategyid: bfa04ae1e32a94fbca9ead86e1ecf581 - restype: 3 - resid: "*" - - strategyid: bfa04ae1e32a94fbca9ead86e1ecf581 - restype: 4 - resid: "*" - - strategyid: bfa04ae1e32a94fbca9ead86e1ecf581 - restype: 5 - resid: "*" - - strategyid: bfa04ae1e32a94fbca9ead86e1ecf581 - restype: 21 - resid: "*" - - strategyid: bfa04ae1e32a94fbca9ead86e1ecf581 - restype: 22 - resid: "*" - - strategyid: bfa04ae1e32a94fbca9ead86e1ecf581 - restype: 23 - resid: "*" - - strategyid: bfa04ae1e32a94fbca9ead86e1ecf581 - restype: 1 - resid: "*" - - strategyid: bfa04ae1e32a94fbca9ead86e1ecf581 - restype: 2 - resid: "*" - conditions: [] - principals: [] - valid: true - revision: 2a04ae4ead86e1e9bfacf59fbca811e3 - metadata: {} - - id: e3d86e1ecf5812bfa04ae1a94fbca9ea - name: 全局读写策略 - action: ALLOW - comment: global resources read and write - default: false - owner: 65e4789a6d5b49669adf1e9e8387549c - calleemethods: ["*"] - resources: - - strategyid: e3d86e1ecf5812bfa04ae1a94fbca9ea - restype: 6 - resid: "*" - - strategyid: e3d86e1ecf5812bfa04ae1a94fbca9ea - restype: 7 - resid: "*" - - strategyid: e3d86e1ecf5812bfa04ae1a94fbca9ea - restype: 20 - resid: "*" - - strategyid: e3d86e1ecf5812bfa04ae1a94fbca9ea - restype: 0 - resid: "*" - - strategyid: e3d86e1ecf5812bfa04ae1a94fbca9ea - restype: 3 - resid: "*" - - strategyid: e3d86e1ecf5812bfa04ae1a94fbca9ea - restype: 4 - resid: "*" - - strategyid: e3d86e1ecf5812bfa04ae1a94fbca9ea - restype: 5 - resid: "*" - - strategyid: e3d86e1ecf5812bfa04ae1a94fbca9ea - restype: 21 - resid: "*" - - strategyid: e3d86e1ecf5812bfa04ae1a94fbca9ea - restype: 22 - resid: "*" - - strategyid: e3d86e1ecf5812bfa04ae1a94fbca9ea - restype: 23 - resid: "*" - - strategyid: e3d86e1ecf5812bfa04ae1a94fbca9ea - restype: 1 - resid: "*" - - strategyid: e3d86e1ecf5812bfa04ae1a94fbca9ea - restype: 2 - resid: "*" - conditions: [] - principals: [] - valid: true - revision: 4ead86e1e9bfac2a04aef59fbca811e3 - metadata: {} diff --git a/test/integrate/client_grpc_test.go b/test/integrate/client_grpc_test.go index eaddaa625..fe40a5ad7 100644 --- a/test/integrate/client_grpc_test.go +++ b/test/integrate/client_grpc_test.go @@ -203,12 +203,12 @@ func TestClientGRPC_DiscoverServices(t *testing.T) { }, }) if err != nil { - t.Fatalf("discover services fail: %+v", err) + t.Fatalf("discover services fail") } - assert.False(t, len(resp.GetServices()) == 0, "discover services response not empty") - assert.Truef(t, len(newSvcs) == len(resp.GetServices()), - "discover services size not equal, expect : %d, actual : %s", len(newSvcs), len(resp.GetServices())) + assert.False(t, len(resp.Services) == 0, "discover services response not empty") + assert.Truef(t, len(newSvcs) == len(resp.Services), + "discover services size not equal, expect : %d, actual : %s", len(newSvcs), len(resp.Services)) }) }) diff --git a/test/integrate/http/client.go b/test/integrate/http/client.go index cac98ae8b..db868b94d 100644 --- a/test/integrate/http/client.go +++ b/test/integrate/http/client.go @@ -25,7 +25,6 @@ import ( "net/http" "github.com/golang/protobuf/jsonpb" - "github.com/google/uuid" apiconfig "github.com/polarismesh/specification/source/go/api/v1/config_manage" apiservice "github.com/polarismesh/specification/source/go/api/v1/service_manage" ) @@ -48,10 +47,6 @@ type Client struct { // SendRequest 发送请求 HTTP Post/Put func (c *Client) SendRequest(method string, url string, body *bytes.Buffer) (*http.Response, error) { - return c.SendRequestWithRequestID(uuid.New().String(), method, url, body) -} - -func (c *Client) SendRequestWithRequestID(requestId, method string, url string, body *bytes.Buffer) (*http.Response, error) { var request *http.Request var err error @@ -66,7 +61,7 @@ func (c *Client) SendRequestWithRequestID(requestId, method string, url string, } request.Header.Add("Content-Type", "application/json") - request.Header.Add("Request-Id", requestId) + request.Header.Add("Request-Id", "test") request.Header.Add("X-Polaris-Token", "nu/0WRA4EqSR1FagrjRj0fZwPXuGlMpX+zCuWu4uMqy8xr1vRjisSbA25aAC3mtU8MeeRsKhQiDAynUR09I=") response, err := c.Worker.Do(request) diff --git a/test/integrate/http/namespace.go b/test/integrate/http/namespace.go index 6cb00fa57..93fe85a80 100644 --- a/test/integrate/http/namespace.go +++ b/test/integrate/http/namespace.go @@ -64,7 +64,7 @@ func (c *Client) CreateNamespaces(namespaces []*apimodel.Namespace) (*apiservice return nil, err } - response, err := c.SendRequestWithRequestID("CreateNamespaces", "POST", url, body) + response, err := c.SendRequest("POST", url, body) if err != nil { fmt.Printf("%v\n", err) return nil, err @@ -179,7 +179,7 @@ func (c *Client) GetNamespaces(namespaces []*apimodel.Namespace) ([]*apimodel.Na } url = c.CompleteURL(url, params) - response, err := c.SendRequestWithRequestID("GetNamespaces", "GET", url, nil) + response, err := c.SendRequest("GET", url, nil) if err != nil { return nil, err } @@ -197,7 +197,7 @@ func (c *Client) GetNamespaces(namespaces []*apimodel.Namespace) ([]*apimodel.Na namespacesSize := len(namespaces) if ret.GetAmount() == nil || ret.GetAmount().GetValue() != uint32(namespacesSize) { - return nil, fmt.Errorf("invalid batch amount: %d %d", ret.GetAmount().GetValue(), namespacesSize) + return nil, errors.New("invalid batch amount") } if ret.GetSize() == nil || ret.GetSize().GetValue() != uint32(namespacesSize) { diff --git a/test/integrate/namespace_test.go b/test/integrate/namespace_test.go index 60352e7d3..361a8bdce 100644 --- a/test/integrate/namespace_test.go +++ b/test/integrate/namespace_test.go @@ -29,7 +29,6 @@ import ( v1 "github.com/polarismesh/polaris/common/api/v1" "github.com/polarismesh/polaris/common/model" - "github.com/polarismesh/polaris/common/utils" "github.com/polarismesh/polaris/test/integrate/http" "github.com/polarismesh/polaris/test/integrate/resource" ) @@ -57,7 +56,7 @@ func TestNamespace(t *testing.T) { // 查询命名空间 _, err = client.GetNamespaces(namespaces) if err != nil { - t.Fatalf("get namespaces: %#v fail: %s", utils.MustJson(namespaces), err.Error()) + t.Fatalf("get namespaces fail: %s", err.Error()) } t.Log("get namespaces success") diff --git a/test/suit/test_suit.go b/test/suit/test_suit.go index f51a98f95..00b8d3ff1 100644 --- a/test/suit/test_suit.go +++ b/test/suit/test_suit.go @@ -44,9 +44,7 @@ import ( "github.com/polarismesh/polaris/common/metrics" "github.com/polarismesh/polaris/common/utils" "github.com/polarismesh/polaris/config" - "github.com/polarismesh/polaris/namespace" ns "github.com/polarismesh/polaris/namespace" - "github.com/polarismesh/polaris/namespace/interceptor" "github.com/polarismesh/polaris/plugin" "github.com/polarismesh/polaris/service" "github.com/polarismesh/polaris/service/batch" @@ -221,9 +219,6 @@ func (d *DiscoverTestSuit) loadConfig() error { fmt.Printf("[ERROR] %v\n", err) return err } - if os.Getenv("STORE_MODE") != "sqldb" { - d.cfg.Store.Option["loadFile"] = testdata.Path("bolt-data.yaml") - } d.cfg.Naming.Interceptors = service.GetChainOrder() d.cfg.Config.Interceptors = config.GetChainOrder() return err @@ -326,7 +321,7 @@ func (d *DiscoverTestSuit) initialize(opts ...options) error { } // 初始化命名空间模块 - namespaceSvr, err := TestNamespaceInitialize(ctx, &d.cfg.Namespace, d.Storage, cacheMgn, d.userMgn, d.strategyMgn) + namespaceSvr, err := ns.TestInitialize(ctx, &d.cfg.Namespace, d.Storage, cacheMgn, d.userMgn, d.strategyMgn) if err != nil { panic(err) } @@ -396,19 +391,6 @@ func (d *DiscoverTestSuit) initialize(opts ...options) error { return nil } -func TestNamespaceInitialize(ctx context.Context, nsOpt *namespace.Config, storage store.Store, cacheMgr *cache.CacheManager, - userMgn auth.UserServer, strategyMgn auth.StrategyServer) (namespace.NamespaceOperateServer, error) { - - ctx = context.WithValue(ctx, interceptor.ContextKeyUserSvr{}, userMgn) - ctx = context.WithValue(ctx, interceptor.ContextKeyPolicySvr{}, strategyMgn) - - _, proxySvr, err := namespace.InitServer(ctx, nsOpt, storage, cacheMgr) - if err != nil { - return nil, err - } - return proxySvr, nil -} - func (d *DiscoverTestSuit) Destroy() { d.cancel() if svr, ok := d.configOriginSvr.(*config.Server); ok { diff --git a/version b/version index e9c1a1883..13e94ce5c 100644 --- a/version +++ b/version @@ -1 +1 @@ -v1.19.0-alpha.0 \ No newline at end of file +v1.18.0 \ No newline at end of file