diff --git a/.github/workflows/codecov.yaml b/.github/workflows/codecov.yaml index 8ae24e36c..ad849bec5 100644 --- a/.github/workflows/codecov.yaml +++ b/.github/workflows/codecov.yaml @@ -61,13 +61,25 @@ jobs: with: fetch-depth: 2 + - uses: shogo82148/actions-setup-mysql@v1 + with: + mysql-version: "5.7" + auto-start: true + my-cnf: | + innodb_log_file_size=256MB + innodb_buffer_pool_size=512MB + max_allowed_packet=16MB + max_connections=50 + local_infile=1 + root-password: root + + - name: Initialize database env: MYSQL_DB_USER: root MYSQL_DB_PWD: root MYSQL_DATABASE: polaris_server run: | - sudo systemctl start mysql.service mysql -e 'CREATE DATABASE ${{ env.MYSQL_DATABASE }};' -u${{ env.MYSQL_DB_USER }} -p${{ env.MYSQL_DB_PWD }} mysql -e "ALTER USER '${{ env.MYSQL_DB_USER }}'@'localhost' IDENTIFIED WITH mysql_native_password BY 'root';" -u${{ env.MYSQL_DB_USER }} -p${{ env.MYSQL_DB_PWD }} diff --git a/.golangci.yml b/.golangci.yml index 3ef63cc8b..33340b6fb 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -167,7 +167,7 @@ linters-settings: disabled: false - name: max-public-structs severity: warning - disabled: false + disabled: true arguments: [35] - name: indent-error-flow severity: warning @@ -281,7 +281,7 @@ linters-settings: govet: # Report about shadowed variables. # Default: false - check-shadowing: true + shadow: false # Settings per analyzer. settings: # Analyzer name, run `go tool vet help` to see all analyzers. diff --git a/admin/api.go b/admin/api.go index aec490231..8f717c105 100644 --- a/admin/api.go +++ b/admin/api.go @@ -25,6 +25,7 @@ import ( "github.com/polarismesh/polaris/common/model" "github.com/polarismesh/polaris/common/model/admin" + authcommon "github.com/polarismesh/polaris/common/model/auth" ) // AdminOperateServer Maintain related operation @@ -55,4 +56,6 @@ type AdminOperateServer interface { GetCMDBInfo(ctx context.Context) ([]model.LocationView, error) // InitMainUser InitMainUser(ctx context.Context, user apisecurity.User) error + // GetServerFunctions Get server functions + GetServerFunctions(ctx context.Context) []authcommon.ServerFunctionGroup } diff --git a/admin/config.go b/admin/config.go index 891a68e3a..21bea1f52 100644 --- a/admin/config.go +++ b/admin/config.go @@ -23,7 +23,8 @@ import ( // Config maintain configuration type Config struct { - Jobs []job.JobConfig `yaml:"jobs"` + Jobs []job.JobConfig `yaml:"jobs"` + Interceptors []string `yaml:"-"` } func DefaultConfig() *Config { diff --git a/admin/default.go b/admin/default.go index c7b3ab59d..5ebf3643b 100644 --- a/admin/default.go +++ b/admin/default.go @@ -20,9 +20,9 @@ package admin import ( "context" "errors" + "fmt" "github.com/polarismesh/polaris/admin/job" - "github.com/polarismesh/polaris/auth" "github.com/polarismesh/polaris/cache" "github.com/polarismesh/polaris/service" "github.com/polarismesh/polaris/service/healthcheck" @@ -30,11 +30,22 @@ import ( ) var ( - server AdminOperateServer - maintainServer = &Server{} - finishInit bool + server AdminOperateServer + maintainServer = &Server{} + finishInit bool + serverProxyFactories = map[string]ServerProxyFactory{} ) +type ServerProxyFactory func(ctx context.Context, pre AdminOperateServer) (AdminOperateServer, error) + +func RegisterServerProxy(name string, factor ServerProxyFactory) error { + if _, ok := serverProxyFactories[name]; ok { + return fmt.Errorf("duplicate ServerProxyFactory, name(%s)", name) + } + serverProxyFactories[name] = factor + return nil +} + // Initialize 初始化 func Initialize(ctx context.Context, cfg *Config, namingService service.DiscoverServer, healthCheckServer *healthcheck.Server, cacheMgn *cache.CacheManager, storage store.Store) error { @@ -43,40 +54,49 @@ func Initialize(ctx context.Context, cfg *Config, namingService service.Discover return nil } - err := initialize(ctx, cfg, namingService, healthCheckServer, cacheMgn, storage) + proxySvr, actualSvr, err := InitServer(ctx, cfg, namingService, healthCheckServer, cacheMgn, storage) if err != nil { return err } + server = proxySvr + maintainServer = actualSvr finishInit = true return nil } -func initialize(_ context.Context, cfg *Config, namingService service.DiscoverServer, - healthCheckServer *healthcheck.Server, cacheMgn *cache.CacheManager, storage store.Store) error { +func InitServer(ctx context.Context, cfg *Config, namingService service.DiscoverServer, + healthCheckServer *healthcheck.Server, cacheMgn *cache.CacheManager, storage store.Store) (AdminOperateServer, *Server, error) { - userMgn, err := auth.GetUserServer() - if err != nil { - return err - } + actualSvr := new(Server) - strategyMgn, err := auth.GetStrategyServer() - if err != nil { - return err - } - - maintainServer.namingServer = namingService - maintainServer.healthCheckServer = healthCheckServer - maintainServer.cacheMgn = cacheMgn - maintainServer.storage = storage + actualSvr.namingServer = namingService + actualSvr.healthCheckServer = healthCheckServer + actualSvr.cacheMgn = cacheMgn + actualSvr.storage = storage maintainJobs := job.NewMaintainJobs(namingService, cacheMgn, storage) if err := maintainJobs.StartMaintianJobs(cfg.Jobs); err != nil { - return err + return nil, nil, err } - server = newServerAuthAbility(maintainServer, userMgn, strategyMgn) - return nil + var proxySvr AdminOperateServer + proxySvr = actualSvr + order := GetChainOrder() + for i := range order { + factory, exist := serverProxyFactories[order[i]] + if !exist { + return nil, nil, fmt.Errorf("name(%s) not exist in serverProxyFactories", order[i]) + } + + afterSvr, err := factory(ctx, proxySvr) + if err != nil { + return nil, nil, err + } + proxySvr = afterSvr + } + + return proxySvr, actualSvr, nil } // GetServer 获取已经初始化好的Server diff --git a/admin/interceptor/auth/log.go b/admin/interceptor/auth/log.go new file mode 100644 index 000000000..9ff5f1c28 --- /dev/null +++ b/admin/interceptor/auth/log.go @@ -0,0 +1,24 @@ +/** + * Tencent is pleased to support the open source community by making Polaris available. + * + * Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. + * + * Licensed under the BSD 3-Clause License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://opensource.org/licenses/BSD-3-Clause + * + * Unless required by applicable law or agreed to in writing, software distributed + * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR + * CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package auth + +import ( + commonlog "github.com/polarismesh/polaris/common/log" +) + +var log = commonlog.GetScopeOrDefaultByName(commonlog.AuthLoggerName) diff --git a/admin/interceptor/auth/server.go b/admin/interceptor/auth/server.go new file mode 100644 index 000000000..a5e67c578 --- /dev/null +++ b/admin/interceptor/auth/server.go @@ -0,0 +1,214 @@ +/** + * Tencent is pleased to support the open source community by making Polaris available. + * + * Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. + * + * Licensed under the BSD 3-Clause License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://opensource.org/licenses/BSD-3-Clause + * + * Unless required by applicable law or agreed to in writing, software distributed + * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR + * CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package auth + +import ( + "context" + + apisecurity "github.com/polarismesh/specification/source/go/api/v1/security" + apiservice "github.com/polarismesh/specification/source/go/api/v1/service_manage" + + "github.com/polarismesh/polaris/admin" + "github.com/polarismesh/polaris/auth" + api "github.com/polarismesh/polaris/common/api/v1" + "github.com/polarismesh/polaris/common/model" + admincommon "github.com/polarismesh/polaris/common/model/admin" + authcommon "github.com/polarismesh/polaris/common/model/auth" + "github.com/polarismesh/polaris/common/utils" +) + +var _ admin.AdminOperateServer = (*Server)(nil) + +// Server 带有鉴权能力的 maintainServer +type Server struct { + nextSvr admin.AdminOperateServer + userSvr auth.UserServer + policySvr auth.StrategyServer +} + +func NewServer(nextSvr admin.AdminOperateServer, + userSvr auth.UserServer, policySvr auth.StrategyServer) admin.AdminOperateServer { + proxy := &Server{ + nextSvr: nextSvr, + userSvr: userSvr, + policySvr: policySvr, + } + + return proxy +} + +func (svr *Server) collectMaintainAuthContext(ctx context.Context, resourceOp authcommon.ResourceOperation, + methodName authcommon.ServerFunctionName) *authcommon.AcquireContext { + return authcommon.NewAcquireContext( + authcommon.WithRequestContext(ctx), + authcommon.WithOperation(resourceOp), + authcommon.WithModule(authcommon.MaintainModule), + authcommon.WithMethod(methodName), + ) +} + +func (s *Server) HasMainUser(ctx context.Context) (bool, error) { + return false, nil +} + +func (s *Server) InitMainUser(ctx context.Context, user apisecurity.User) error { + return nil +} + +func (svr *Server) GetServerConnections(ctx context.Context, req *admincommon.ConnReq) (*admincommon.ConnCountResp, error) { + authCtx := svr.collectMaintainAuthContext(ctx, authcommon.Read, authcommon.DescribeServerConnections) + if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + return nil, err + } + + ctx = authCtx.GetRequestContext() + ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) + + return svr.nextSvr.GetServerConnections(ctx, req) +} + +func (svr *Server) GetServerConnStats(ctx context.Context, req *admincommon.ConnReq) (*admincommon.ConnStatsResp, error) { + authCtx := svr.collectMaintainAuthContext(ctx, authcommon.Read, authcommon.DescribeServerConnStats) + if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + return nil, err + } + + ctx = authCtx.GetRequestContext() + ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) + + return svr.nextSvr.GetServerConnStats(ctx, req) +} + +func (svr *Server) CloseConnections(ctx context.Context, reqs []admincommon.ConnReq) error { + authCtx := svr.collectMaintainAuthContext(ctx, authcommon.Delete, authcommon.CloseConnections) + if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + return err + } + + ctx = authCtx.GetRequestContext() + ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) + + return svr.nextSvr.CloseConnections(ctx, reqs) +} + +func (svr *Server) FreeOSMemory(ctx context.Context) error { + authCtx := svr.collectMaintainAuthContext(ctx, authcommon.Modify, authcommon.FreeOSMemory) + if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + return err + } + + ctx = authCtx.GetRequestContext() + ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) + + return svr.nextSvr.FreeOSMemory(ctx) +} + +func (svr *Server) CleanInstance(ctx context.Context, req *apiservice.Instance) *apiservice.Response { + authCtx := svr.collectMaintainAuthContext(ctx, authcommon.Delete, authcommon.CleanInstance) + if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + return api.NewResponse(authcommon.ConvertToErrCode(err)) + } + + ctx = authCtx.GetRequestContext() + ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) + + return svr.nextSvr.CleanInstance(ctx, req) +} + +func (svr *Server) BatchCleanInstances(ctx context.Context, batchSize uint32) (uint32, error) { + authCtx := svr.collectMaintainAuthContext(ctx, authcommon.Delete, authcommon.BatchCleanInstances) + if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + return 0, err + } + + return svr.nextSvr.BatchCleanInstances(ctx, batchSize) +} + +func (svr *Server) GetLastHeartbeat(ctx context.Context, req *apiservice.Instance) *apiservice.Response { + authCtx := svr.collectMaintainAuthContext(ctx, authcommon.Read, authcommon.DescribeInstanceLastHeartbeat) + if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + return api.NewResponse(authcommon.ConvertToErrCode(err)) + } + + ctx = authCtx.GetRequestContext() + ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) + + return svr.nextSvr.GetLastHeartbeat(ctx, req) +} + +func (svr *Server) GetLogOutputLevel(ctx context.Context) ([]admincommon.ScopeLevel, error) { + authCtx := svr.collectMaintainAuthContext(ctx, authcommon.Read, authcommon.DescribeGetLogOutputLevel) + if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + return nil, err + } + + ctx = authCtx.GetRequestContext() + ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) + + return svr.nextSvr.GetLogOutputLevel(ctx) +} + +func (svr *Server) SetLogOutputLevel(ctx context.Context, scope string, level string) error { + authCtx := svr.collectMaintainAuthContext(ctx, authcommon.Modify, authcommon.UpdateLogOutputLevel) + if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + return err + } + + return svr.nextSvr.SetLogOutputLevel(ctx, scope, level) +} + +func (svr *Server) ListLeaderElections(ctx context.Context) ([]*admincommon.LeaderElection, error) { + authCtx := svr.collectMaintainAuthContext(ctx, authcommon.Read, authcommon.DescribeLeaderElections) + if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + return nil, err + } + + ctx = authCtx.GetRequestContext() + ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) + + return svr.nextSvr.ListLeaderElections(ctx) +} + +func (svr *Server) ReleaseLeaderElection(ctx context.Context, electKey string) error { + authCtx := svr.collectMaintainAuthContext(ctx, authcommon.Modify, authcommon.ReleaseLeaderElection) + if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + return err + } + + ctx = authCtx.GetRequestContext() + ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) + + return svr.nextSvr.ReleaseLeaderElection(ctx, electKey) +} + +func (svr *Server) GetCMDBInfo(ctx context.Context) ([]model.LocationView, error) { + authCtx := svr.collectMaintainAuthContext(ctx, authcommon.Read, authcommon.DescribeCMDBInfo) + if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + return nil, err + } + + ctx = authCtx.GetRequestContext() + ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) + + return svr.nextSvr.GetCMDBInfo(ctx) +} + +// GetServerFunctions . +func (svr *Server) GetServerFunctions(ctx context.Context) []authcommon.ServerFunctionGroup { + return svr.nextSvr.GetServerFunctions(ctx) +} diff --git a/admin/interceptor/register.go b/admin/interceptor/register.go new file mode 100644 index 000000000..d151f615d --- /dev/null +++ b/admin/interceptor/register.go @@ -0,0 +1,67 @@ +/** + * Tencent is pleased to support the open source community by making Polaris available. + * + * Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. + * + * Licensed under the BSD 3-Clause License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://opensource.org/licenses/BSD-3-Clause + * + * Unless required by applicable law or agreed to in writing, software distributed + * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR + * CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package inteceptor + +import ( + "context" + + "github.com/polarismesh/polaris/admin" + admin_auth "github.com/polarismesh/polaris/admin/interceptor/auth" + "github.com/polarismesh/polaris/auth" +) + +type ( + ContextKeyUserSvr struct{} + ContextKeyPolicySvr struct{} +) + +func init() { + err := admin.RegisterServerProxy("auth", func(ctx context.Context, + pre admin.AdminOperateServer) (admin.AdminOperateServer, error) { + + var userSvr auth.UserServer + var policySvr auth.StrategyServer + + userSvrVal := ctx.Value(ContextKeyUserSvr{}) + if userSvrVal == nil { + svr, err := auth.GetUserServer() + if err != nil { + return nil, err + } + userSvr = svr + } else { + userSvr = userSvrVal.(auth.UserServer) + } + + policySvrVal := ctx.Value(ContextKeyPolicySvr{}) + if policySvrVal == nil { + svr, err := auth.GetStrategyServer() + if err != nil { + return nil, err + } + policySvr = svr + } else { + policySvr = policySvrVal.(auth.StrategyServer) + } + + return admin_auth.NewServer(pre, userSvr, policySvr), nil + }) + if err != nil { + panic(err) + } +} diff --git a/admin/maintain.go b/admin/maintain.go index 208eded34..0b4fafa74 100644 --- a/admin/maintain.go +++ b/admin/maintain.go @@ -33,6 +33,7 @@ import ( commonlog "github.com/polarismesh/polaris/common/log" "github.com/polarismesh/polaris/common/model" "github.com/polarismesh/polaris/common/model/admin" + authcommon "github.com/polarismesh/polaris/common/model/auth" commonstore "github.com/polarismesh/polaris/common/store" "github.com/polarismesh/polaris/common/utils" "github.com/polarismesh/polaris/plugin" @@ -166,11 +167,11 @@ func (s *Server) CleanInstance(ctx context.Context, req *apiservice.Instance) *a } if err := s.storage.CleanInstance(instanceID); err != nil { log.Error("Clean instance", - zap.String("err", err.Error()), utils.ZapRequestID(utils.ParseRequestID(ctx))) + zap.String("err", err.Error()), utils.RequestID(ctx)) return api.NewInstanceResponse(commonstore.StoreCode2APICode(err), req) } - log.Info("Clean instance", utils.ZapRequestID(utils.ParseRequestID(ctx)), utils.ZapInstanceID(instanceID)) + log.Info("Clean instance", utils.RequestID(ctx), utils.ZapInstanceID(instanceID)) return api.NewInstanceResponse(apimodel.Code_ExecuteSuccess, req) } @@ -205,7 +206,6 @@ func (s *Server) ListLeaderElections(_ context.Context) ([]*admin.LeaderElection func (s *Server) ReleaseLeaderElection(_ context.Context, electKey string) error { return s.storage.ReleaseLeaderElection(electKey) - } func (svr *Server) GetCMDBInfo(ctx context.Context) ([]model.LocationView, error) { @@ -230,3 +230,8 @@ func (svr *Server) GetCMDBInfo(ctx context.Context) ([]model.LocationView, error return ret, nil } + +// GetServerFunctions 获取服务端支持的功能列表 +func (svr *Server) GetServerFunctions(ctx context.Context) []authcommon.ServerFunctionGroup { + return authcommon.ServerFunctions +} diff --git a/admin/maintain_authability.go b/admin/maintain_authability.go deleted file mode 100644 index d90e14a56..000000000 --- a/admin/maintain_authability.go +++ /dev/null @@ -1,179 +0,0 @@ -/** - * Tencent is pleased to support the open source community by making Polaris available. - * - * Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. - * - * Licensed under the BSD 3-Clause License (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://opensource.org/licenses/BSD-3-Clause - * - * Unless required by applicable law or agreed to in writing, software distributed - * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR - * CONDITIONS OF ANY KIND, either express or implied. See the License for the - * specific language governing permissions and limitations under the License. - */ - -package admin - -import ( - "context" - - apisecurity "github.com/polarismesh/specification/source/go/api/v1/security" - apiservice "github.com/polarismesh/specification/source/go/api/v1/service_manage" - - api "github.com/polarismesh/polaris/common/api/v1" - "github.com/polarismesh/polaris/common/model" - "github.com/polarismesh/polaris/common/model/admin" - authcommon "github.com/polarismesh/polaris/common/model/auth" - "github.com/polarismesh/polaris/common/utils" -) - -var _ AdminOperateServer = (*serverAuthAbility)(nil) - -func (s *serverAuthAbility) HasMainUser(ctx context.Context) (bool, error) { - return false, nil -} - -func (s *serverAuthAbility) InitMainUser(ctx context.Context, user apisecurity.User) error { - return nil -} - -func (svr *serverAuthAbility) GetServerConnections(ctx context.Context, req *admin.ConnReq) (*admin.ConnCountResp, error) { - authCtx := svr.collectMaintainAuthContext(ctx, authcommon.Read, authcommon.DescribeServerConnections) - if _, err := svr.strategyMgn.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return nil, err - } - - ctx = authCtx.GetRequestContext() - ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) - - return svr.targetServer.GetServerConnections(ctx, req) -} - -func (svr *serverAuthAbility) GetServerConnStats(ctx context.Context, req *admin.ConnReq) (*admin.ConnStatsResp, error) { - authCtx := svr.collectMaintainAuthContext(ctx, authcommon.Read, authcommon.DescribeServerConnStats) - if _, err := svr.strategyMgn.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return nil, err - } - - ctx = authCtx.GetRequestContext() - ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) - - return svr.targetServer.GetServerConnStats(ctx, req) -} - -func (svr *serverAuthAbility) CloseConnections(ctx context.Context, reqs []admin.ConnReq) error { - authCtx := svr.collectMaintainAuthContext(ctx, authcommon.Delete, authcommon.CloseConnections) - if _, err := svr.strategyMgn.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return err - } - - ctx = authCtx.GetRequestContext() - ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) - - return svr.targetServer.CloseConnections(ctx, reqs) -} - -func (svr *serverAuthAbility) FreeOSMemory(ctx context.Context) error { - authCtx := svr.collectMaintainAuthContext(ctx, authcommon.Modify, authcommon.FreeOSMemory) - if _, err := svr.strategyMgn.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return err - } - - ctx = authCtx.GetRequestContext() - ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) - - return svr.targetServer.FreeOSMemory(ctx) -} - -func (svr *serverAuthAbility) CleanInstance(ctx context.Context, req *apiservice.Instance) *apiservice.Response { - authCtx := svr.collectMaintainAuthContext(ctx, authcommon.Delete, authcommon.CleanInstance) - if _, err := svr.strategyMgn.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewResponseWithMsg(convertToErrCode(err), err.Error()) - } - - ctx = authCtx.GetRequestContext() - ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) - - return svr.targetServer.CleanInstance(ctx, req) -} - -func (svr *serverAuthAbility) BatchCleanInstances(ctx context.Context, batchSize uint32) (uint32, error) { - authCtx := svr.collectMaintainAuthContext(ctx, authcommon.Delete, authcommon.BatchCleanInstances) - if _, err := svr.strategyMgn.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return 0, err - } - - return svr.targetServer.BatchCleanInstances(ctx, batchSize) -} - -func (svr *serverAuthAbility) GetLastHeartbeat(ctx context.Context, req *apiservice.Instance) *apiservice.Response { - authCtx := svr.collectMaintainAuthContext(ctx, authcommon.Read, authcommon.DescribeInstanceLastHeartbeat) - if _, err := svr.strategyMgn.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewResponseWithMsg(convertToErrCode(err), err.Error()) - } - - ctx = authCtx.GetRequestContext() - ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) - - return svr.targetServer.GetLastHeartbeat(ctx, req) -} - -func (svr *serverAuthAbility) GetLogOutputLevel(ctx context.Context) ([]admin.ScopeLevel, error) { - authCtx := svr.collectMaintainAuthContext(ctx, authcommon.Read, authcommon.DescribeGetLogOutputLevel) - if _, err := svr.strategyMgn.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return nil, err - } - - ctx = authCtx.GetRequestContext() - ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) - - return svr.targetServer.GetLogOutputLevel(ctx) -} - -func (svr *serverAuthAbility) SetLogOutputLevel(ctx context.Context, scope string, level string) error { - authCtx := svr.collectMaintainAuthContext(ctx, authcommon.Modify, authcommon.UpdateLogOutputLevel) - if _, err := svr.strategyMgn.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return err - } - - return svr.targetServer.SetLogOutputLevel(ctx, scope, level) -} - -func (svr *serverAuthAbility) ListLeaderElections(ctx context.Context) ([]*admin.LeaderElection, error) { - authCtx := svr.collectMaintainAuthContext(ctx, authcommon.Read, authcommon.DescribeLeaderElections) - if _, err := svr.strategyMgn.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return nil, err - } - - ctx = authCtx.GetRequestContext() - ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) - - return svr.targetServer.ListLeaderElections(ctx) -} - -func (svr *serverAuthAbility) ReleaseLeaderElection(ctx context.Context, electKey string) error { - authCtx := svr.collectMaintainAuthContext(ctx, authcommon.Modify, authcommon.ReleaseLeaderElection) - if _, err := svr.strategyMgn.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return err - } - - ctx = authCtx.GetRequestContext() - ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) - - return svr.targetServer.ReleaseLeaderElection(ctx, electKey) -} - -func (svr *serverAuthAbility) GetCMDBInfo(ctx context.Context) ([]model.LocationView, error) { - authCtx := svr.collectMaintainAuthContext(ctx, authcommon.Read, authcommon.DescribeCMDBInfo) - if _, err := svr.strategyMgn.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return nil, err - } - - ctx = authCtx.GetRequestContext() - ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) - - return svr.targetServer.GetCMDBInfo(ctx) -} diff --git a/admin/server.go b/admin/server.go index e2e682f76..9ab899f00 100644 --- a/admin/server.go +++ b/admin/server.go @@ -35,3 +35,9 @@ type Server struct { cacheMgn *cache.CacheManager storage store.Store } + +func GetChainOrder() []string { + return []string{ + "auth", + } +} diff --git a/admin/server_authability.go b/admin/server_authability.go deleted file mode 100644 index 2ddfbcbde..000000000 --- a/admin/server_authability.go +++ /dev/null @@ -1,68 +0,0 @@ -/** - * Tencent is pleased to support the open source community by making Polaris available. - * - * Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. - * - * Licensed under the BSD 3-Clause License (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://opensource.org/licenses/BSD-3-Clause - * - * Unless required by applicable law or agreed to in writing, software distributed - * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR - * CONDITIONS OF ANY KIND, either express or implied. See the License for the - * specific language governing permissions and limitations under the License. - */ - -package admin - -import ( - "context" - "errors" - - apimodel "github.com/polarismesh/specification/source/go/api/v1/model" - - "github.com/polarismesh/polaris/auth" - authcommon "github.com/polarismesh/polaris/common/model/auth" -) - -// serverAuthAbility 带有鉴权能力的 maintainServer -type serverAuthAbility struct { - targetServer *Server - userMgn auth.UserServer - strategyMgn auth.StrategyServer -} - -func newServerAuthAbility(targetServer *Server, - userMgn auth.UserServer, strategyMgn auth.StrategyServer) AdminOperateServer { - proxy := &serverAuthAbility{ - targetServer: targetServer, - userMgn: userMgn, - strategyMgn: strategyMgn, - } - - return proxy -} - -func (svr *serverAuthAbility) collectMaintainAuthContext(ctx context.Context, resourceOp authcommon.ResourceOperation, - methodName authcommon.ServerFunctionName) *authcommon.AcquireContext { - return authcommon.NewAcquireContext( - authcommon.WithRequestContext(ctx), - authcommon.WithOperation(resourceOp), - authcommon.WithModule(authcommon.MaintainModule), - authcommon.WithMethod(methodName), - ) -} - -func convertToErrCode(err error) apimodel.Code { - if errors.Is(err, authcommon.ErrorTokenNotExist) { - return apimodel.Code_TokenNotExisted - } - - if errors.Is(err, authcommon.ErrorTokenDisabled) { - return apimodel.Code_TokenDisabled - } - - return apimodel.Code_NotAllowedAccess -} diff --git a/apiserver/httpserver/admin_access.go b/apiserver/httpserver/admin_access.go index 80589b46a..22784edac 100644 --- a/apiserver/httpserver/admin_access.go +++ b/apiserver/httpserver/admin_access.go @@ -67,6 +67,7 @@ func (h *HTTPServer) GetAdminAccessServer() *restful.WebService { ws.Route(docs.EnrichGetCMDBInfoApiDocs(ws.GET("/cmdb/info").To(h.GetCMDBInfo))) ws.Route(docs.EnrichGetReportClientsApiDocs(ws.GET("/report/clients").To(h.GetReportClients))) ws.Route(docs.EnrichEnablePprofApiDocs(ws.POST("/pprof/enable").To(h.EnablePprof))) + ws.Route(docs.EnrichGetServerFunctionsApiDocs(ws.GET("/server/functions").To(h.GetServerFunctions))) return ws } @@ -308,6 +309,13 @@ func (h *HTTPServer) EnablePprof(req *restful.Request, rsp *restful.Response) { _ = rsp.WriteEntity("ok") } +// GetServerFunctions . +func (h *HTTPServer) GetServerFunctions(req *restful.Request, rsp *restful.Response) { + ctx := initContext(req) + ret := h.maintainServer.GetServerFunctions(ctx) + _ = rsp.WriteAsJson(ret) +} + func initContext(req *restful.Request) context.Context { ctx := context.Background() diff --git a/apiserver/httpserver/auth_access.go b/apiserver/httpserver/auth_access.go index 15c5fc0e2..ce2a064f2 100644 --- a/apiserver/httpserver/auth_access.go +++ b/apiserver/httpserver/auth_access.go @@ -35,7 +35,7 @@ import ( // GetAuthServer 运维接口 func (h *HTTPServer) GetAuthServer(ws *restful.WebService) error { ws.Route(docs.EnrichAuthStatusApiDocs(ws.GET("/auth/status").To(h.AuthStatus))) - // + // 用户 ws.Route(docs.EnrichLoginApiDocs(ws.POST("/user/login").To(h.Login))) ws.Route(docs.EnrichGetUsersApiDocs(ws.GET("/users").To(h.GetUsers))) ws.Route(docs.EnrichCreateUsersApiDocs(ws.POST("/users").To(h.CreateUsers))) @@ -45,7 +45,8 @@ func (h *HTTPServer) GetAuthServer(ws *restful.WebService) error { ws.Route(docs.EnrichGetUserTokenApiDocs(ws.GET("/user/token").To(h.GetUserToken))) ws.Route(docs.EnrichUpdateUserTokenApiDocs(ws.PUT("/user/token/status").To(h.EnableUserToken))) ws.Route(docs.EnrichResetUserTokenApiDocs(ws.PUT("/user/token/refresh").To(h.ResetUserToken))) - // + + // 用户组 ws.Route(docs.EnrichCreateGroupApiDocs(ws.POST("/usergroup").To(h.CreateGroup))) ws.Route(docs.EnrichUpdateGroupsApiDocs(ws.PUT("/usergroups").To(h.UpdateGroups))) ws.Route(docs.EnrichGetGroupsApiDocs(ws.GET("/usergroups").To(h.GetGroups))) @@ -55,6 +56,7 @@ func (h *HTTPServer) GetAuthServer(ws *restful.WebService) error { ws.Route(docs.EnrichUpdateGroupTokenApiDocs(ws.PUT("/usergroup/token/status").To(h.EnableGroupToken))) ws.Route(docs.EnrichResetGroupTokenApiDocs(ws.PUT("/usergroup/token/refresh").To(h.ResetGroupToken))) + // 鉴权策略 ws.Route(docs.EnrichCreateStrategyApiDocs(ws.POST("/auth/strategy").To(h.CreateStrategy))) ws.Route(docs.EnrichGetStrategyApiDocs(ws.GET("/auth/strategy/detail").To(h.GetStrategy))) ws.Route(docs.EnrichUpdateStrategiesApiDocs(ws.PUT("/auth/strategies").To(h.UpdateStrategies))) @@ -62,6 +64,12 @@ func (h *HTTPServer) GetAuthServer(ws *restful.WebService) error { ws.Route(docs.EnrichGetStrategiesApiDocs(ws.GET("/auth/strategies").To(h.GetStrategies))) ws.Route(docs.EnrichGetPrincipalResourcesApiDocs(ws.GET("/auth/principal/resources").To(h.GetPrincipalResources))) + // 角色 + ws.Route(docs.EnrichGetRolesApiDocs(ws.GET("/roles").To(h.GetRoles))) + ws.Route(docs.EnrichCreateRolesApiDocs(ws.POST("/roles").To(h.CreateRoles))) + ws.Route(docs.EnrichDeleteRolesApiDocs(ws.POST("/roles/delete").To(h.DeleteRoles))) + ws.Route(docs.EnrichUpdateRolesApiDocs(ws.PUT("/roles").To(h.UpdateRoles))) + return nil } @@ -498,3 +506,79 @@ func (h *HTTPServer) GetPrincipalResources(req *restful.Request, rsp *restful.Re handler.WriteHeaderAndProto(h.strategyMgn.GetPrincipalResources(ctx, queryParams)) } + +// CreateRoles . +func (h *HTTPServer) CreateRoles(req *restful.Request, rsp *restful.Response) { + handler := &httpcommon.Handler{ + Request: req, + Response: rsp, + } + + roles := make([]*apisecurity.Role, 0, 4) + ctx, err := handler.ParseArray(func() proto.Message { + msg := &apisecurity.Role{} + roles = append(roles, msg) + return msg + }) + if err != nil { + handler.WriteHeaderAndProto(api.NewBatchWriteResponseWithMsg(apimodel.Code_ParseException, err.Error())) + return + } + + handler.WriteHeaderAndProto(h.strategyMgn.CreateRoles(ctx, roles)) +} + +// UpdateRoles . +func (h *HTTPServer) UpdateRoles(req *restful.Request, rsp *restful.Response) { + handler := &httpcommon.Handler{ + Request: req, + Response: rsp, + } + + roles := make([]*apisecurity.Role, 0, 4) + ctx, err := handler.ParseArray(func() proto.Message { + msg := &apisecurity.Role{} + roles = append(roles, msg) + return msg + }) + if err != nil { + handler.WriteHeaderAndProto(api.NewBatchWriteResponseWithMsg(apimodel.Code_ParseException, err.Error())) + return + } + + handler.WriteHeaderAndProto(h.strategyMgn.UpdateRoles(ctx, roles)) +} + +// DeleteRoles . +func (h *HTTPServer) DeleteRoles(req *restful.Request, rsp *restful.Response) { + handler := &httpcommon.Handler{ + Request: req, + Response: rsp, + } + + roles := make([]*apisecurity.Role, 0, 4) + ctx, err := handler.ParseArray(func() proto.Message { + msg := &apisecurity.Role{} + roles = append(roles, msg) + return msg + }) + if err != nil { + handler.WriteHeaderAndProto(api.NewBatchWriteResponseWithMsg(apimodel.Code_ParseException, err.Error())) + return + } + + handler.WriteHeaderAndProto(h.strategyMgn.DeleteRoles(ctx, roles)) +} + +// GetRoles 查询角色列表 +func (h *HTTPServer) GetRoles(req *restful.Request, rsp *restful.Response) { + handler := &httpcommon.Handler{ + Request: req, + Response: rsp, + } + + queryParams := httpcommon.ParseQueryParams(req) + ctx := handler.ParseHeaderContext() + + handler.WriteHeaderAndProto(h.strategyMgn.GetRoles(ctx, queryParams)) +} diff --git a/apiserver/httpserver/config/client_access.go b/apiserver/httpserver/config/client_access.go index 7de65ebc3..ffade0a6b 100644 --- a/apiserver/httpserver/config/client_access.go +++ b/apiserver/httpserver/config/client_access.go @@ -168,7 +168,12 @@ func (h *HTTPServer) Discover(req *restful.Request, rsp *restful.Response) { switch in.Type { case apiconfig.ConfigDiscoverRequest_CONFIG_FILE: action = metrics.ActionGetConfigFile - ret := h.configServer.GetConfigFileWithCache(ctx, &apiconfig.ClientConfigFileInfo{}) + ret := h.configServer.GetConfigFileWithCache(ctx, &apiconfig.ClientConfigFileInfo{ + Version: in.GetConfigFile().Version, + Namespace: in.GetConfigFile().GetNamespace(), + Group: in.GetConfigFile().GetGroup(), + FileName: in.GetConfigFile().GetFileName(), + }) out = api.NewConfigDiscoverResponse(apimodel.Code(ret.GetCode().GetValue())) out.ConfigFile = ret.GetConfigFile() out.Type = apiconfig.ConfigDiscoverResponse_CONFIG_FILE diff --git a/apiserver/httpserver/discover/v1/console_access.go b/apiserver/httpserver/discover/v1/console_access.go index 12a136e6d..86e6dd5ac 100644 --- a/apiserver/httpserver/discover/v1/console_access.go +++ b/apiserver/httpserver/discover/v1/console_access.go @@ -1303,3 +1303,78 @@ func (h *HTTPServerV1) DeleteServiceContractInterfaces(req *restful.Request, rsp ret := h.namingServer.DeleteServiceContractInterfaces(ctx, msg) handler.WriteHeaderAndProto(ret) } + +// CreateLaneGroups 批量创建泳道组 +func (h *HTTPServerV1) CreateLaneGroups(req *restful.Request, rsp *restful.Response) { + handler := &httpcommon.Handler{ + Request: req, + Response: rsp, + } + groups := make([]*apitraffic.LaneGroup, 0) + ctx, err := handler.ParseArray(func() proto.Message { + msg := &apitraffic.LaneGroup{} + groups = append(groups, msg) + return msg + }) + if err != nil { + handler.WriteHeaderAndProto(api.NewBatchWriteResponseWithMsg(apimodel.Code_ParseException, err.Error())) + return + } + + ret := h.namingServer.CreateLaneGroups(ctx, groups) + handler.WriteHeaderAndProto(ret) +} + +// UpdateLaneGroups 批量更新泳道组 +func (h *HTTPServerV1) UpdateLaneGroups(req *restful.Request, rsp *restful.Response) { + handler := &httpcommon.Handler{ + Request: req, + Response: rsp, + } + groups := make([]*apitraffic.LaneGroup, 0) + ctx, err := handler.ParseArray(func() proto.Message { + msg := &apitraffic.LaneGroup{} + groups = append(groups, msg) + return msg + }) + if err != nil { + handler.WriteHeaderAndProto(api.NewBatchWriteResponseWithMsg(apimodel.Code_ParseException, err.Error())) + return + } + + ret := h.namingServer.UpdateLaneGroups(ctx, groups) + handler.WriteHeaderAndProto(ret) +} + +// DeleteLaneGroups 批量删除泳道组 +func (h *HTTPServerV1) DeleteLaneGroups(req *restful.Request, rsp *restful.Response) { + handler := &httpcommon.Handler{ + Request: req, + Response: rsp, + } + groups := make([]*apitraffic.LaneGroup, 0) + ctx, err := handler.ParseArray(func() proto.Message { + msg := &apitraffic.LaneGroup{} + groups = append(groups, msg) + return msg + }) + if err != nil { + handler.WriteHeaderAndProto(api.NewBatchWriteResponseWithMsg(apimodel.Code_ParseException, err.Error())) + return + } + + ret := h.namingServer.DeleteLaneGroups(ctx, groups) + handler.WriteHeaderAndProto(ret) +} + +// GetLaneGroups 批量删除泳道组 +func (h *HTTPServerV1) GetLaneGroups(req *restful.Request, rsp *restful.Response) { + handler := &httpcommon.Handler{ + Request: req, + Response: rsp, + } + queryParams := httpcommon.ParseQueryParams(req) + ctx := handler.ParseHeaderContext() + ret := h.namingServer.GetLaneGroups(ctx, queryParams) + handler.WriteHeaderAndProto(ret) +} diff --git a/apiserver/httpserver/discover/v1/server.go b/apiserver/httpserver/discover/v1/server.go index 8544e34d9..42df28ae3 100644 --- a/apiserver/httpserver/discover/v1/server.go +++ b/apiserver/httpserver/discover/v1/server.go @@ -93,6 +93,7 @@ func (h *HTTPServerV1) GetConsoleAccessServer(include []string) (*restful.WebSer h.addCircuitBreakerRuleAccess(ws) case routingAccess: h.addRoutingRuleAccess(ws) + h.addLaneRuleAccess(ws) case rateLimitAccess: h.addRateLimitRuleAccess(ws) } @@ -139,6 +140,7 @@ func (h *HTTPServerV1) addDefaultAccess(ws *restful.WebService) { // 管理端接口:增删改查请求全部操作存储层 h.addServiceAccess(ws) h.addRoutingRuleAccess(ws) + h.addLaneRuleAccess(ws) h.addRateLimitRuleAccess(ws) h.addCircuitBreakerRuleAccess(ws) } @@ -207,6 +209,14 @@ func (h *HTTPServerV1) addRoutingRuleAccess(ws *restful.WebService) { // Deprecate -- end } +// addLaneRuleAccess 泳道规则 +func (h *HTTPServerV1) addLaneRuleAccess(ws *restful.WebService) { + ws.Route(ws.POST("/lane/groups").To(h.CreateLaneGroups)) + ws.Route(ws.POST("/lane/groups/delete").To(h.DeleteLaneGroups)) + ws.Route(ws.PUT("/lane/groups").To(h.UpdateLaneGroups)) + ws.Route(ws.GET("/lane/groups").To(h.GetLaneGroups)) +} + func (h *HTTPServerV1) addRateLimitRuleAccess(ws *restful.WebService) { ws.Route(docs.EnrichCreateRateLimitsApiDocs(ws.POST("/ratelimits").To(h.CreateRateLimits))) ws.Route(docs.EnrichDeleteRateLimitsApiDocs(ws.POST("/ratelimits/delete").To(h.DeleteRateLimits))) diff --git a/apiserver/httpserver/docs/admin_apidoc.go b/apiserver/httpserver/docs/admin_apidoc.go index 928a663c2..810134d49 100644 --- a/apiserver/httpserver/docs/admin_apidoc.go +++ b/apiserver/httpserver/docs/admin_apidoc.go @@ -149,3 +149,10 @@ func EnrichEnablePprofApiDocs(r *restful.RouteBuilder) *restful.RouteBuilder { Enable bool `json:"enable"` }{}) } + +func EnrichGetServerFunctionsApiDocs(r *restful.RouteBuilder) *restful.RouteBuilder { + return r. + Doc("查询服务端的接口名称列表"). + Metadata(restfulspec.KeyOpenAPITags, maintainApiTags). + Returns(0, "", map[string][]string{}) +} diff --git a/apiserver/httpserver/docs/auth_apidoc.go b/apiserver/httpserver/docs/auth_apidoc.go index f25ba9a68..28b9dc11a 100644 --- a/apiserver/httpserver/docs/auth_apidoc.go +++ b/apiserver/httpserver/docs/auth_apidoc.go @@ -26,7 +26,8 @@ import ( var ( authApiTags = []string{"AuthRule"} usersApiTags = []string{"Users"} - userGroupApiTags = []string{"Users"} + userGroupApiTags = []string{"UserGroups"} + roleApiTags = []string{"Roles"} ) func EnrichAuthStatusApiDocs(r *restful.RouteBuilder) *restful.RouteBuilder { @@ -350,3 +351,43 @@ func EnrichResetGroupTokenApiDocs(r *restful.RouteBuilder) *restful.RouteBuilder UserGroup apisecurity.UserGroup `json:"userGroup"` }{}) } + +func EnrichCreateRolesApiDocs(r *restful.RouteBuilder) *restful.RouteBuilder { + return r. + Doc("批量创建角色"). + Metadata(restfulspec.KeyOpenAPITags, roleApiTags). + Reads([]apisecurity.Role{}, "batch create role"). + Returns(0, "", struct { + BaseResponse + }{}) +} + +func EnrichUpdateRolesApiDocs(r *restful.RouteBuilder) *restful.RouteBuilder { + return r. + Doc("批量更新角色"). + Metadata(restfulspec.KeyOpenAPITags, roleApiTags). + Reads([]apisecurity.Role{}, "batch update role"). + Returns(0, "", struct { + BaseResponse + }{}) +} + +func EnrichDeleteRolesApiDocs(r *restful.RouteBuilder) *restful.RouteBuilder { + return r. + Doc("批量删除角色"). + Metadata(restfulspec.KeyOpenAPITags, roleApiTags). + Reads([]apisecurity.Role{}, "batch delete role"). + Returns(0, "", struct { + BaseResponse + }{}) +} + +func EnrichGetRolesApiDocs(r *restful.RouteBuilder) *restful.RouteBuilder { + return r. + Doc("查询角色列表"). + Metadata(restfulspec.KeyOpenAPITags, roleApiTags). + Reads([]apisecurity.Role{}, "query roles"). + Returns(0, "", struct { + BaseResponse + }{}) +} diff --git a/apiserver/httpserver/utils/handler.go b/apiserver/httpserver/utils/handler.go index 4aa2ad794..02be2a841 100644 --- a/apiserver/httpserver/utils/handler.go +++ b/apiserver/httpserver/utils/handler.go @@ -71,15 +71,13 @@ func (h *Handler) ParseArrayByText(createMessage func() proto.Message, text stri func (h *Handler) parseArray(createMessage func() proto.Message, jsonDecoder *json.Decoder) (context.Context, error) { requestID := h.Request.HeaderParameter("Request-Id") // read open bracket - _, err := jsonDecoder.Token() - if err != nil { + if _, err := jsonDecoder.Token(); err != nil { accesslog.Error(err.Error(), utils.ZapRequestID(requestID)) return nil, err } for jsonDecoder.More() { protoMessage := createMessage() - err := UnmarshalNext(jsonDecoder, protoMessage) - if err != nil { + if err := UnmarshalNext(jsonDecoder, protoMessage); err != nil { accesslog.Error(err.Error(), utils.ZapRequestID(requestID)) return nil, err } @@ -286,7 +284,7 @@ func (h *Handler) WriteHeaderAndProto(obj api.ResponseMessage) { status := api.CalcCode(obj) if status != http.StatusOK { - accesslog.Error(obj.String(), utils.ZapRequestID(requestID)) + accesslog.Error(h.Request.Request.RequestURI+" "+obj.String(), utils.ZapRequestID(requestID)) } if code := obj.GetCode().GetValue(); code != api.ExecuteSuccess { h.Response.AddHeader(utils.PolarisCode, fmt.Sprintf("%d", code)) @@ -317,9 +315,8 @@ func (h *Handler) WriteHeaderAndProtoV2(obj api.ResponseMessageV2) { h.Response.AddHeader(utils.PolarisRequestID, requestID) h.Response.WriteHeader(status) - m := jsonpb.Marshaler{Indent: " ", EmitDefaults: true} - err := m.Marshal(h.Response, obj) - if err != nil { + m := newJsonpbMarshaler() + if err := m.Marshal(h.Response, obj); err != nil { accesslog.Error(err.Error(), utils.ZapRequestID(requestID)) } } @@ -380,14 +377,18 @@ func ParseJsonBody(req *restful.Request, value interface{}) error { return nil } +func newJsonpbMarshaler() jsonpb.Marshaler { + return jsonpb.Marshaler{Indent: " ", EmitDefaults: true} +} + func (h *Handler) handleResponse(obj api.ResponseMessage) error { if !enableProtoCache { - m := jsonpb.Marshaler{Indent: " ", EmitDefaults: true} + m := newJsonpbMarshaler() return m.Marshal(h.Response, obj) } cacheVal := convert(obj) if cacheVal == nil { - m := jsonpb.Marshaler{Indent: " ", EmitDefaults: true} + m := newJsonpbMarshaler() return m.Marshal(h.Response, obj) } if saveVal := protoCache.Get(cacheVal.CacheType, cacheVal.Key); saveVal != nil { @@ -401,7 +402,7 @@ func (h *Handler) handleResponse(obj api.ResponseMessage) error { if err := cacheVal.Marshal(obj); err != nil { accesslog.Warn("[Api-http][ProtoCache] prepare message fail, direct send msg", zap.String("key", cacheVal.Key), zap.Error(err)) - m := jsonpb.Marshaler{Indent: " ", EmitDefaults: true} + m := newJsonpbMarshaler() return m.Marshal(h.Response, obj) } @@ -409,7 +410,7 @@ func (h *Handler) handleResponse(obj api.ResponseMessage) error { if !ok || cacheVal == nil { accesslog.Warn("[Api-http][ProtoCache] put cache ignore", zap.String("key", cacheVal.Key), zap.String("cacheType", cacheVal.CacheType)) - m := jsonpb.Marshaler{Indent: " ", EmitDefaults: true} + m := newJsonpbMarshaler() return m.Marshal(h.Response, obj) } if len(cacheVal.GetBuf()) > 0 { diff --git a/apiserver/nacosserver/core/storage.go b/apiserver/nacosserver/core/storage.go index c7826443f..746af9e6c 100644 --- a/apiserver/nacosserver/core/storage.go +++ b/apiserver/nacosserver/core/storage.go @@ -197,7 +197,7 @@ func (n *NacosDataStorage) syncTask() { // 计算需要 refresh 的服务信息列表 for _, ns := range nsList { - _, svcs := n.cacheMgr.Service().ListServices(ns.Name) + _, svcs := n.cacheMgr.Service().ListServices(context.Background(), ns.Name) for _, svc := range svcs { revision := n.cacheMgr.Service().GetRevisionWorker().GetServiceInstanceRevision(svc.ID) oldRevision, ok := n.revisions[svc.ID] diff --git a/apiserver/nacosserver/model/service.go b/apiserver/nacosserver/model/service.go index c3e8569c8..86666d307 100644 --- a/apiserver/nacosserver/model/service.go +++ b/apiserver/nacosserver/model/service.go @@ -18,6 +18,7 @@ package model import ( + "context" "strings" "github.com/polarismesh/polaris/service" @@ -43,7 +44,7 @@ type ServiceMetadata struct { func HandleServiceListRequest(discoverSvr service.DiscoverServer, namespace string, groupName string, pageNo int, pageSize int) ([]string, int) { - _, services := discoverSvr.Cache().Service().ListServices(namespace) + _, services := discoverSvr.Cache().Service().ListServices(context.Background(), namespace) offset := (pageNo - 1) * pageSize limit := pageSize if offset < 0 { diff --git a/apiserver/nacosserver/server.go b/apiserver/nacosserver/server.go index 913cc3281..43437ebac 100644 --- a/apiserver/nacosserver/server.go +++ b/apiserver/nacosserver/server.go @@ -148,7 +148,7 @@ func copyOption(m map[string]interface{}) map[string]interface{} { func (n *NacosServer) initPolarisResource() error { var err error - n.namespaceSvr, err = namespace.GetServer() + n.namespaceSvr, err = namespace.GetOriginServer() if err != nil { return err } diff --git a/apiserver/nacosserver/v1/config/access.go b/apiserver/nacosserver/v1/config/access.go index ed484fdda..0d3a6d65f 100644 --- a/apiserver/nacosserver/v1/config/access.go +++ b/apiserver/nacosserver/v1/config/access.go @@ -142,7 +142,7 @@ func (n *ConfigServer) ConfigImport(req *restful.Request, rsp *restful.Response) var metaDataItem *ZipItem var items = make([]*ZipItem, 0, 32) - handler.ProcessZip(func(f *zip.File, data []byte) { + err := handler.ProcessZip(func(f *zip.File, data []byte) { if (f.Name == ConfigExportMetadata || f.Name == ConfigExpotrMetadataV2) && metaDataItem == nil { metaDataItem = &ZipItem{ Name: f.Name, @@ -155,6 +155,10 @@ func (n *ConfigServer) ConfigImport(req *restful.Request, rsp *restful.Response) Data: data, }) }) + if err != nil { + nacoshttp.WrirteNacosErrorResponse(err, rsp) + return + } policy := req.QueryParameter("policy") diff --git a/apiserver/nacosserver/v1/discover/instance.go b/apiserver/nacosserver/v1/discover/instance.go index 87db26324..9e632a25b 100644 --- a/apiserver/nacosserver/v1/discover/instance.go +++ b/apiserver/nacosserver/v1/discover/instance.go @@ -77,8 +77,8 @@ func (n *DiscoverServer) handleUpdate(ctx context.Context, namespace, serviceNam return nil } -func (n *DiscoverServer) handleDeregister(ctx context.Context, namespace, service string, ins *model.Instance) error { - specIns := model.PrepareSpecInstance(namespace, service, ins) +func (n *DiscoverServer) handleDeregister(ctx context.Context, namespace, svcName string, ins *model.Instance) error { + specIns := model.PrepareSpecInstance(namespace, svcName, ins) resp := n.discoverSvr.DeregisterInstance(ctx, specIns) if apimodel.Code(resp.GetCode().GetValue()) != apimodel.Code_ExecuteSuccess { return &model.NacosError{ @@ -90,19 +90,19 @@ func (n *DiscoverServer) handleDeregister(ctx context.Context, namespace, servic } // handleBeat com.alibaba.nacos.naming.core.InstanceOperatorClientImpl#handleBeat -func (n *DiscoverServer) handleBeat(ctx context.Context, namespace, service string, +func (n *DiscoverServer) handleBeat(ctx context.Context, namespace, svcName string, clientBeat *model.ClientBeat) (map[string]interface{}, error) { - service = model.ReplaceNacosService(service) - svc := n.discoverSvr.Cache().Service().GetServiceByName(service, namespace) + svcName = model.ReplaceNacosService(svcName) + svc := n.discoverSvr.Cache().Service().GetServiceByName(svcName, namespace) if svc == nil { return nil, &model.NacosError{ ErrCode: int32(model.ExceptionCode_ServerError), - ErrMsg: "service not found: " + service + "@" + namespace, + ErrMsg: "service not found: " + svcName + "@" + namespace, } } resp := n.healthSvr.Report(ctx, &apiservice.Instance{ - Service: utils.NewStringValue(model.ReplaceNacosService(service)), + Service: utils.NewStringValue(model.ReplaceNacosService(svcName)), Namespace: utils.NewStringValue(namespace), Host: utils.NewStringValue(clientBeat.Ip), Port: utils.NewUInt32Value(uint32(clientBeat.Port)), @@ -136,7 +136,7 @@ func (n *DiscoverServer) handleBeat(ctx context.Context, namespace, service stri func (n *DiscoverServer) handleQueryInstances(ctx context.Context, params map[string]string) (interface{}, error) { namespace := params[model.ParamNamespaceID] group := model.GetGroupName(params[model.ParamServiceName]) - service := model.GetServiceName(params[model.ParamServiceName]) + svcName := model.GetServiceName(params[model.ParamServiceName]) clusters := params["clusters"] clientIP := params["clientIP"] udpPort, _ := strconv.ParseInt(params["udpPort"], 10, 32) @@ -151,14 +151,14 @@ func (n *DiscoverServer) handleQueryInstances(ctx context.Context, params map[st Port: int(udpPort), NamespaceId: namespace, Group: group, - Service: service, + Service: svcName, Cluster: clusters, Type: core.UDPCPush, }) } filterCtx := &core.FilterContext{ - Service: core.ToNacosService(n.discoverSvr.Cache(), namespace, service, group), + Service: core.ToNacosService(n.discoverSvr.Cache(), namespace, svcName, group), Clusters: strings.Split(clusters, ","), EnableOnly: true, HealthyOnly: healthyOnly, diff --git a/apiserver/nacosserver/v1/endpoints.go b/apiserver/nacosserver/v1/endpoints.go index 321679d7b..a9b4cd58d 100644 --- a/apiserver/nacosserver/v1/endpoints.go +++ b/apiserver/nacosserver/v1/endpoints.go @@ -69,5 +69,5 @@ func (n *NacosV1Server) FetchNacosEndpoints(req *restful.Request, rsp *restful.R } rsp.WriteHeader(http.StatusOK) - rsp.Write([]byte(strings.Join(ips, "\n"))) + _, _ = rsp.Write([]byte(strings.Join(ips, "\n"))) } diff --git a/auth/api.go b/auth/api.go index 56fe7fdf0..ebc0227f3 100644 --- a/auth/api.go +++ b/auth/api.go @@ -166,10 +166,14 @@ type UserHelper interface { // PolicyHelper . type PolicyHelper interface { + // GetPolicyRule . + GetPolicyRule(id string) *authcommon.StrategyDetail // CreatePrincipal 创建 principal 的默认 policy 资源 CreatePrincipal(ctx context.Context, tx store.Tx, p authcommon.Principal) error // CleanPrincipal 清理 principal 所关联的 policy、role 资源 CleanPrincipal(ctx context.Context, tx store.Tx, p authcommon.Principal) error + // GetRole . + GetRole(id string) *authcommon.Role } // OperatorInfo 根据 token 解析出来的具体额外信息 @@ -190,7 +194,7 @@ type OperatorInfo struct { Anonymous bool } -func NewAnonymous() OperatorInfo { +func NewAnonymousOperatorInfo() OperatorInfo { return OperatorInfo{ Origin: "", OwnerID: "", diff --git a/auth/policy/auth_checker.go b/auth/policy/auth_checker.go index 0a25a5958..c8efc8d32 100644 --- a/auth/policy/auth_checker.go +++ b/auth/policy/auth_checker.go @@ -18,6 +18,8 @@ package policy import ( + "context" + "encoding/json" "strings" "github.com/pkg/errors" @@ -92,8 +94,12 @@ func (d *DefaultAuthChecker) IsOpenAuth() bool { // AllowResourceOperate 是否允许资源的操作 func (d *DefaultAuthChecker) ResourcePredicate(ctx *authcommon.AcquireContext, res *authcommon.ResourceEntry) bool { - // 如果鉴权能力没有开启,那就默认都可以进行操作 - if !d.IsOpenAuth() { + // 如果是客户端请求,并且鉴权能力没有开启,那就默认都可以进行操作 + if ctx.IsFromClient() && !d.IsOpenClientAuth() { + return true + } + // 如果是控制台请求,并且鉴权能力没有开启,那就默认都可以进行操作 + if ctx.IsFromConsole() && !d.IsOpenConsoleAuth() { return true } @@ -101,7 +107,16 @@ func (d *DefaultAuthChecker) ResourcePredicate(ctx *authcommon.AcquireContext, r if !ok { return false } - return d.cacheMgr.AuthStrategy().Hint(p.(authcommon.Principal), res) != apisecurity.AuthAction_DENY + policyCache := d.cacheMgr.AuthStrategy() + + principals := d.listAllPrincipals(p.(authcommon.Principal)) + for i := range principals { + ret := policyCache.Hint(ctx.GetRequestContext(), principals[i], res) + if ret != apisecurity.AuthAction_DENY { + return true + } + } + return false } // CheckClientPermission 执行检查客户端动作判断是否有权限,并且对 RequestContext 注入操作者数据 @@ -129,22 +144,12 @@ func (d *DefaultAuthChecker) CheckConsolePermission(preCtx *authcommon.AcquireCo } // CheckPermission 执行检查动作判断是否有权限 -// -// step 1. 判断是否开启了鉴权 -// step 2. 对token进行检查判断 -// case 1. 如果 token 被禁用 -// a. 读操作,直接放通 -// b. 写操作,快速失败 -// step 3. 拉取token对应的操作者相关信息,注入到请求上下文中 -// step 4. 进行权限检查 func (d *DefaultAuthChecker) CheckPermission(authCtx *authcommon.AcquireContext) (bool, error) { if err := d.userSvr.CheckCredential(authCtx); err != nil { return false, err } - if log.DebugEnabled() { - log.Debug("[Auth][Checker] check permission args", utils.RequestID(authCtx.GetRequestContext()), - zap.String("method", string(authCtx.GetMethod())), zap.Any("resources", authCtx.GetAccessResources())) - } + log.Info("[Auth][Checker] check permission args", utils.RequestID(authCtx.GetRequestContext()), + zap.Any("method", authCtx.GetMethods()), zap.Any("resources", authCtx.GetAccessResources())) if pass, _ := d.doCheckPermission(authCtx); pass { return true, nil @@ -159,11 +164,11 @@ func (d *DefaultAuthChecker) CheckPermission(authCtx *authcommon.AcquireContext) func (d *DefaultAuthChecker) resyncData(authCtx *authcommon.AcquireContext) error { if err := d.cacheMgr.AuthStrategy().Update(); err != nil { - log.Error("[Auth][Checker] force sync policy rule to cache failed", utils.RequestID(authCtx.GetRequestContext()), zap.Error(err)) + log.Error("[Auth][Checker] force sync policy failed", utils.RequestID(authCtx.GetRequestContext()), zap.Error(err)) return err } if err := d.cacheMgr.Role().Update(); err != nil { - log.Error("[Auth][Checker] force sync role to cache failed", utils.RequestID(authCtx.GetRequestContext()), zap.Error(err)) + log.Error("[Auth][Checker] force sync role failed", utils.RequestID(authCtx.GetRequestContext()), zap.Error(err)) return err } return nil @@ -171,34 +176,65 @@ func (d *DefaultAuthChecker) resyncData(authCtx *authcommon.AcquireContext) erro // doCheckPermission 执行权限检查 func (d *DefaultAuthChecker) doCheckPermission(authCtx *authcommon.AcquireContext) (bool, error) { - p, _ := authCtx.GetAttachments()[authcommon.PrincipalKey].(authcommon.Principal) if d.IsCredible(authCtx) { return true, nil } + cur := authCtx.GetAttachments()[authcommon.PrincipalKey].(authcommon.Principal) + + principals := d.listAllPrincipals(cur) - allowPolicies := d.cacheMgr.AuthStrategy().GetPrincipalPolicies("allow", p) - denyPolicies := d.cacheMgr.AuthStrategy().GetPrincipalPolicies("deny", p) + // 遍历所有的 principal,检查是否有一个符合要求 + for i := range principals { + principal := principals[i] + allowPolicies := d.cacheMgr.AuthStrategy().GetPrincipalPolicies("allow", principal) + denyPolicies := d.cacheMgr.AuthStrategy().GetPrincipalPolicies("deny", principal) - resources := authCtx.GetAccessResources() + resources := authCtx.GetAccessResources() - // 先执行 deny 策略 - for i := range denyPolicies { - item := denyPolicies[i] - if d.MatchPolicy(authCtx, item, p, resources) { - return false, ErrorNotPermission + // 先执行 deny 策略 + for i := range denyPolicies { + item := denyPolicies[i] + if d.MatchPolicy(authCtx, item, principal, resources) { + return false, ErrorNotPermission + } } - } - // 处理 allow 策略,只要有一个放开,就可以认为通过 - for i := range allowPolicies { - item := allowPolicies[i] - if d.MatchPolicy(authCtx, item, p, resources) { - return true, nil + // 处理 allow 策略,只要有一个放开,就可以认为通过 + for i := range allowPolicies { + item := allowPolicies[i] + if d.MatchPolicy(authCtx, item, principal, resources) { + return true, nil + } } } return false, ErrorNotPermission } +func (d *DefaultAuthChecker) listAllPrincipals(p authcommon.Principal) []authcommon.Principal { + principals := make([]authcommon.Principal, 0, 4) + principals = append(principals, p) + // 获取角色列表 + roles := d.cacheMgr.Role().GetPrincipalRoles(p) + for i := range roles { + principals = append(principals, authcommon.Principal{ + PrincipalID: roles[i].ID, + PrincipalType: authcommon.PrincipalRole, + }) + } + + // 如果是用户,获取所在的用户组列表 + if p.PrincipalType == authcommon.PrincipalUser { + groups := d.cacheMgr.User().GetUserLinkGroupIds(p.PrincipalID) + for i := range groups { + principals = append(principals, authcommon.Principal{ + PrincipalID: groups[i], + PrincipalType: authcommon.PrincipalGroup, + }) + } + } + return principals +} + // IsCredible 检查是否是可信的请求 func (d *DefaultAuthChecker) IsCredible(authCtx *authcommon.AcquireContext) bool { reqHeaders, ok := authCtx.GetRequestContext().Value(utils.ContextRequestHeaders).(map[string][]string) @@ -227,12 +263,13 @@ func (d *DefaultAuthChecker) IsCredible(authCtx *authcommon.AcquireContext) bool func (d *DefaultAuthChecker) MatchPolicy(authCtx *authcommon.AcquireContext, policy *authcommon.StrategyDetail, principal authcommon.Principal, resources map[apisecurity.ResourceType][]authcommon.ResourceEntry) bool { if !d.MatchCalleeFunctions(authCtx, principal, policy) { + log.Error("server function match policy fail", utils.RequestID(authCtx.GetRequestContext()), + zap.String("principal", principal.String()), zap.String("policy-id", policy.ID)) return false } if !d.MatchResourceOperateable(authCtx, principal, policy) { - return false - } - if !d.MatchResourceConditions(authCtx, principal, policy) { + log.Error("access resource match policy fail", utils.RequestID(authCtx.GetRequestContext()), + zap.String("principal", principal.String()), zap.String("policy-id", policy.ID)) return false } return true @@ -241,66 +278,92 @@ func (d *DefaultAuthChecker) MatchPolicy(authCtx *authcommon.AcquireContext, pol // MatchCalleeFunctions 检查操作方法是否和策略匹配 func (d *DefaultAuthChecker) MatchCalleeFunctions(authCtx *authcommon.AcquireContext, principal authcommon.Principal, policy *authcommon.StrategyDetail) bool { + + // 如果开启了兼容模式,并且策略没有对可调用方法的拦截,那么就认为匹配成功 + if d.conf.Compatible && len(policy.CalleeMethods) == 0 { + return true + } + functions := policy.CalleeMethods - for i := range functions { - if functions[i] == string(authCtx.GetMethod()) { - return true + + allMatch := 0 + for _, method := range authCtx.GetMethods() { + curMatch := false + for i := range functions { + if utils.IsMatchAll(functions[i]) { + return true + } + if functions[i] == string(method) { + curMatch = true + break + } + if utils.IsWildMatch(string(method), functions[i]) { + curMatch = true + break + } } - if utils.IsWildMatch(string(authCtx.GetMethod()), functions[i]) { - return true + if curMatch { + allMatch++ } } - return false + return allMatch == len(authCtx.GetMethods()) } +type ( + compatibleChecker func(ctx context.Context, cacheSvr cachetypes.CacheManager, resource *authcommon.ResourceEntry) bool +) + +var ( + compatibleResource = map[apisecurity.ResourceType]compatibleChecker{ + apisecurity.ResourceType_UserGroups: func(ctx context.Context, cacheSvr cachetypes.CacheManager, + resource *authcommon.ResourceEntry) bool { + saveVal := cacheSvr.User().GetGroup(resource.ID) + if saveVal == nil { + return false + } + operator := utils.ParseUserID(ctx) + _, exist := saveVal.UserIds[operator] + return exist + }, + apisecurity.ResourceType_PolicyRules: func(ctx context.Context, cacheSvr cachetypes.CacheManager, + resource *authcommon.ResourceEntry) bool { + saveVal := cacheSvr.AuthStrategy().GetPolicyRule(resource.ID) + if saveVal == nil { + return false + } + operator := utils.ParseUserID(ctx) + for i := range saveVal.Principals { + if saveVal.Principals[i].PrincipalID == operator { + return true + } + } + return false + }, + } +) + // checkAction 检查操作资源是否和策略匹配 func (d *DefaultAuthChecker) MatchResourceOperateable(authCtx *authcommon.AcquireContext, principal authcommon.Principal, policy *authcommon.StrategyDetail) bool { - matchCheck := func(resType apisecurity.ResourceType, resources []authcommon.ResourceEntry) bool { - for i := range resources { - actionResult := d.cacheMgr.AuthStrategy().Hint(principal, &resources[i]) - if actionResult.String() == policy.Action { - return true - } - } - return false - } - reqRes := authCtx.GetAccessResources() - isMatch := false - for k, v := range reqRes { - if isMatch = matchCheck(k, v); isMatch { - break - } + // 检查下 principal 有没有 condition 信息 + principalCondition := make([]authcommon.Condition, 0, 4) + // 这里主要兼容一些内部特殊场景,可能在 role/user/group 关联某个策略时,会有一些额外的关系属性,这里在 extend 统一查找 + _ = json.Unmarshal([]byte(principal.Extend["condition"]), &principalCondition) + + ctx := context.Background() + if len(principalCondition) != 0 { + ctx = context.WithValue(context.Background(), authcommon.ContextKeyConditions{}, principalCondition) } - return isMatch -} -// MatchResourceConditions 检查操作资源所拥有的标签是否和策略匹配 -func (d *DefaultAuthChecker) MatchResourceConditions(authCtx *authcommon.AcquireContext, - principal authcommon.Principal, policy *authcommon.StrategyDetail) bool { matchCheck := func(resType apisecurity.ResourceType, resources []authcommon.ResourceEntry) bool { - conditions := policy.Conditions - for i := range resources { - allMatch := true - for j := range conditions { - condition := conditions[j] - resVal, ok := resources[i].Metadata[condition.Key] - if !ok { - allMatch = false - break - } - compareFunc, ok := conditionCompareDict[condition.CompareFunc] - if !ok { - allMatch = false - break - } - if allMatch = compareFunc(resVal, condition.Value); !allMatch { - break - } + actionResult := d.cacheMgr.AuthStrategy().Hint(ctx, principal, &resources[i]) + if policy.IsMatchAction(actionResult.String()) { + return true } - if allMatch { + // 兼容模式下,对于用户组和策略规则,走一遍兜底的检查逻辑 + if _, ok := compatibleResource[resType]; ok && d.conf.Compatible { return true } } @@ -308,19 +371,10 @@ func (d *DefaultAuthChecker) MatchResourceConditions(authCtx *authcommon.Acquire } reqRes := authCtx.GetAccessResources() - isMatch := false + isMatch := true for k, v := range reqRes { - if isMatch = matchCheck(k, v); isMatch { - break - } + subMatch := matchCheck(k, v) + isMatch = isMatch && subMatch } return isMatch } - -var ( - conditionCompareDict = map[string]func(string, string) bool{ - "for_any_value:string_equal": func(s1, s2 string) bool { - return s1 == s2 - }, - } -) diff --git a/auth/policy/auth_checker_test.go b/auth/policy/auth_checker_test.go deleted file mode 100644 index 4668fc661..000000000 --- a/auth/policy/auth_checker_test.go +++ /dev/null @@ -1,1136 +0,0 @@ -/** - * Tencent is pleased to support the open source community by making Polaris available. - * - * Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. - * - * Licensed under the BSD 3-Clause License (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://opensource.org/licenses/BSD-3-Clause - * - * Unless required by applicable law or agreed to in writing, software distributed - * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR - * CONDITIONS OF ANY KIND, either express or implied. See the License for the - * specific language governing permissions and limitations under the License. - */ - -package policy_test - -import ( - "context" - "testing" - - "github.com/golang/mock/gomock" - apisecurity "github.com/polarismesh/specification/source/go/api/v1/security" - "github.com/stretchr/testify/assert" - - "github.com/polarismesh/polaris/auth" - "github.com/polarismesh/polaris/auth/policy" - defaultuser "github.com/polarismesh/polaris/auth/user" - "github.com/polarismesh/polaris/cache" - cachetypes "github.com/polarismesh/polaris/cache/api" - authcommon "github.com/polarismesh/polaris/common/model/auth" - "github.com/polarismesh/polaris/common/utils" -) - -func newPolicyServer() (*policy.Server, auth.StrategyServer, error) { - return policy.BuildServer() -} - -func Test_DefaultAuthChecker_CheckConsolePermission_Write_NoStrict(t *testing.T) { - reset(false) - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - users := createMockUser(10) - groups := createMockUserGroup(users) - - namespaces := createMockNamespace(len(users)+len(groups)+10, users[0].ID) - services := createMockService(namespaces) - serviceMap := convertServiceSliceToMap(services) - strategies, _ := createMockStrategy(users, groups, services[:len(users)+len(groups)]) - - cfg, storage := initCache(ctrl) - - storage.EXPECT().GetUsersForCache(gomock.Any(), gomock.Any()).AnyTimes().Return(users, nil) - storage.EXPECT().GetGroupsForCache(gomock.Any(), gomock.Any()).AnyTimes().Return(groups, nil) - storage.EXPECT().GetMoreStrategies(gomock.Any(), gomock.Any()).AnyTimes().Return(strategies, nil) - storage.EXPECT().GetMoreNamespaces(gomock.Any()).AnyTimes().Return(namespaces, nil) - storage.EXPECT().GetMoreServices(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes().Return(serviceMap, nil) - - ctx, cancel := context.WithCancel(context.Background()) - cacheMgr, err := cache.TestCacheInitialize(ctx, cfg, storage) - if err != nil { - t.Fatal(err) - } - t.Cleanup(func() { - cancel() - cacheMgr.Close() - }) - - _, proxySvr, err := defaultuser.BuildServer() - if err != nil { - t.Fatal(err) - } - proxySvr.Initialize(&auth.Config{ - User: &auth.UserConfig{ - Name: auth.DefaultUserMgnPluginName, - Option: map[string]interface{}{ - "salt": "polarismesh@2021", - }, - }, - }, storage, nil, cacheMgr) - - _, svr, err := newPolicyServer() - if err != nil { - t.Fatal(err) - } - if err := svr.Initialize(&auth.Config{ - Strategy: &auth.StrategyConfig{ - Name: auth.DefaultPolicyPluginName, - }, - }, storage, cacheMgr, proxySvr); err != nil { - t.Fatal(err) - } - checker := svr.GetAuthChecker() - - _ = cacheMgr.TestUpdate() - - freeIndex := len(users) + len(groups) + 1 - - t.Run("权限检查非严格模式-主账户资源访问检查", func(t *testing.T) { - ctx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, users[0].Token) - authCtx := authcommon.NewAcquireContext( - authcommon.WithRequestContext(ctx), - authcommon.WithMethod("Test_DefaultAuthChecker_VerifyCredential"), - authcommon.WithOperation(authcommon.Create), - authcommon.WithModule(authcommon.DiscoverModule), - authcommon.WithAccessResources(map[apisecurity.ResourceType][]authcommon.ResourceEntry{ - apisecurity.ResourceType_Services: { - { - ID: services[0].ID, - Owner: services[0].Owner, - }, - }, - }), - ) - - _, err = checker.CheckConsolePermission(authCtx) - t.Logf("%+v", err) - assert.NoError(t, err, "Should be verify success") - }) - - t.Run("权限检查非严格模式-子账户资源访问检查(无操作权限)", func(t *testing.T) { - ctx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, users[1].Token) - authCtx := authcommon.NewAcquireContext( - authcommon.WithRequestContext(ctx), - authcommon.WithMethod("Test_DefaultAuthChecker_VerifyCredential"), - authcommon.WithOperation(authcommon.Create), - authcommon.WithModule(authcommon.DiscoverModule), - authcommon.WithAccessResources(map[apisecurity.ResourceType][]authcommon.ResourceEntry{ - apisecurity.ResourceType_Services: { - { - ID: services[0].ID, - Owner: services[0].Owner, - }, - }, - }), - ) - - _, err = checker.CheckConsolePermission(authCtx) - t.Logf("%+v", err) - assert.Error(t, err, "Should be verify fail") - }) - - t.Run("权限检查非严格模式-子账户资源访问检查(有操作权限)", func(t *testing.T) { - ctx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, users[1].Token) - authCtx := authcommon.NewAcquireContext( - authcommon.WithRequestContext(ctx), - authcommon.WithMethod("Test_DefaultAuthChecker_VerifyCredential"), - authcommon.WithOperation(authcommon.Create), - authcommon.WithModule(authcommon.DiscoverModule), - authcommon.WithAccessResources(map[apisecurity.ResourceType][]authcommon.ResourceEntry{ - apisecurity.ResourceType_Services: { - { - ID: services[1].ID, - Owner: services[1].Owner, - }, - }, - }), - ) - - _, err = checker.CheckConsolePermission(authCtx) - t.Logf("%+v", err) - assert.NoError(t, err, "Should be verify success") - }) - - t.Run("权限检查非严格模式-子账户资源访问检查(资源无绑定策略)", func(t *testing.T) { - ctx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, users[1].Token) - authCtx := authcommon.NewAcquireContext( - authcommon.WithRequestContext(ctx), - authcommon.WithMethod("Test_DefaultAuthChecker_VerifyCredential"), - authcommon.WithOperation(authcommon.Create), - authcommon.WithModule(authcommon.DiscoverModule), - authcommon.WithAccessResources(map[apisecurity.ResourceType][]authcommon.ResourceEntry{ - apisecurity.ResourceType_Services: { - { - ID: services[freeIndex].ID, - Owner: services[freeIndex].Owner, - }, - }, - }), - ) - - _, err = checker.CheckConsolePermission(authCtx) - t.Logf("%+v", err) - assert.NoError(t, err, "Should be verify success") - }) - - t.Run("权限检查非严格模式-子账户访问用户组资源检查(属于用户组成员)", func(t *testing.T) { - ctx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, users[1].Token) - authCtx := authcommon.NewAcquireContext( - authcommon.WithRequestContext(ctx), - authcommon.WithMethod("Test_DefaultAuthChecker_VerifyCredential"), - authcommon.WithOperation(authcommon.Create), - authcommon.WithModule(authcommon.DiscoverModule), - authcommon.WithAccessResources(map[apisecurity.ResourceType][]authcommon.ResourceEntry{ - apisecurity.ResourceType_Services: { - { - ID: services[(len(users)-1)+2].ID, - Owner: services[(len(users)-1)+2].Owner, - }, - }, - }), - ) - - _, err = checker.CheckConsolePermission(authCtx) - t.Logf("%+v", err) - assert.NoError(t, err, "Should be verify success") - }) - - t.Run("权限检查非严格模式-子账户访问用户组资源检查(不属于用户组成员)", func(t *testing.T) { - ctx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, users[1].Token) - authCtx := authcommon.NewAcquireContext( - authcommon.WithRequestContext(ctx), - authcommon.WithMethod("Test_DefaultAuthChecker_VerifyCredential"), - authcommon.WithOperation(authcommon.Create), - authcommon.WithModule(authcommon.DiscoverModule), - authcommon.WithAccessResources(map[apisecurity.ResourceType][]authcommon.ResourceEntry{ - apisecurity.ResourceType_Services: { - { - ID: services[(len(users)-1)+4].ID, - Owner: services[(len(users)-1)+4].Owner, - }, - }, - }), - ) - - _, err = checker.CheckConsolePermission(authCtx) - t.Logf("%+v", err) - assert.Error(t, err, "Should be verify fail") - }) - - t.Run("权限检查非严格模式-用户组访问组内成员资源检查", func(t *testing.T) { - ctx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, groups[1].Token) - authCtx := authcommon.NewAcquireContext( - authcommon.WithRequestContext(ctx), - authcommon.WithMethod("Test_DefaultAuthChecker_VerifyCredential"), - // authcommon.WithToken(groups[1].Token), - authcommon.WithOperation(authcommon.Create), - authcommon.WithModule(authcommon.DiscoverModule), - authcommon.WithAccessResources(map[apisecurity.ResourceType][]authcommon.ResourceEntry{ - apisecurity.ResourceType_Services: { - { - ID: services[0].ID, - Owner: services[0].Owner, - }, - }, - }), - ) - - _, err = checker.CheckConsolePermission(authCtx) - t.Logf("%+v", err) - assert.Error(t, err, "Should be verify fail") - }) - - t.Run("权限检查非严格模式-token非法-匿名账户资源访问检查(资源无绑定策略)", func(t *testing.T) { - ctx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, "users[1].Token") - authCtx := authcommon.NewAcquireContext( - authcommon.WithRequestContext(ctx), - authcommon.WithMethod("Test_DefaultAuthChecker_VerifyCredential"), - authcommon.WithOperation(authcommon.Create), - authcommon.WithModule(authcommon.DiscoverModule), - authcommon.WithAccessResources(map[apisecurity.ResourceType][]authcommon.ResourceEntry{ - apisecurity.ResourceType_Services: { - { - ID: services[freeIndex].ID, - Owner: services[freeIndex].Owner, - }, - }, - }), - ) - - _, err = checker.CheckConsolePermission(authCtx) - t.Logf("%+v", err) - assert.NoError(t, err, "Should be verify success") - }) - - t.Run("权限检查非严格模式-token为空-匿名账户资源访问检查(资源无绑定策略)", func(t *testing.T) { - ctx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, "") - authCtx := authcommon.NewAcquireContext( - authcommon.WithRequestContext(ctx), - authcommon.WithMethod("Test_DefaultAuthChecker_VerifyCredential"), - authcommon.WithOperation(authcommon.Create), - authcommon.WithModule(authcommon.DiscoverModule), - authcommon.WithAccessResources(map[apisecurity.ResourceType][]authcommon.ResourceEntry{ - apisecurity.ResourceType_Services: { - { - ID: services[freeIndex].ID, - Owner: services[freeIndex].Owner, - }, - }, - }), - ) - - _, err = checker.CheckConsolePermission(authCtx) - t.Logf("%+v", err) - assert.NoError(t, err, "Should be verify success") - }) -} - -func Test_DefaultAuthChecker_CheckConsolePermission_Write_Strict(t *testing.T) { - reset(true) - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - users := createMockUser(10) - groups := createMockUserGroup(users) - - namespaces := createMockNamespace(len(users)+len(groups)+10, users[0].ID) - services := createMockService(namespaces) - serviceMap := convertServiceSliceToMap(services) - strategies, _ := createMockStrategy(users, groups, services[:len(users)+len(groups)]) - - cfg, storage := initCache(ctrl) - - storage.EXPECT().GetUsersForCache(gomock.Any(), gomock.Any()).AnyTimes().Return(users, nil) - storage.EXPECT().GetGroupsForCache(gomock.Any(), gomock.Any()).AnyTimes().Return(groups, nil) - storage.EXPECT().GetMoreStrategies(gomock.Any(), gomock.Any()).AnyTimes().Return(strategies, nil) - storage.EXPECT().GetMoreNamespaces(gomock.Any()).AnyTimes().Return(namespaces, nil) - storage.EXPECT().GetMoreServices(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes().Return(serviceMap, nil) - - ctx, cancel := context.WithCancel(context.Background()) - cacheMgr, err := cache.TestCacheInitialize(ctx, cfg, storage) - if err != nil { - t.Fatal(err) - } - - t.Cleanup(func() { - cancel() - cacheMgr.Close() - }) - - _, proxySvr, err := defaultuser.BuildServer() - if err != nil { - t.Fatal(err) - } - proxySvr.Initialize(&auth.Config{ - User: &auth.UserConfig{ - Name: auth.DefaultUserMgnPluginName, - Option: map[string]interface{}{ - "salt": "polarismesh@2021", - }, - }, - }, storage, nil, cacheMgr) - - _, svr, err := newPolicyServer() - if err != nil { - t.Fatal(err) - } - if err := svr.Initialize(&auth.Config{ - Strategy: &auth.StrategyConfig{ - Name: auth.DefaultPolicyPluginName, - }, - }, storage, cacheMgr, proxySvr); err != nil { - t.Fatal(err) - } - checker := svr.GetAuthChecker() - - _ = cacheMgr.TestUpdate() - - freeIndex := len(users) + len(groups) + 1 - - t.Run("权限检查严格模式-主账户操作资源", func(t *testing.T) { - ctx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, users[0].Token) - authCtx := authcommon.NewAcquireContext( - authcommon.WithRequestContext(ctx), - authcommon.WithMethod("Test_DefaultAuthChecker_CheckConsolePermission_Write_Strict"), - // authcommon.WithToken(users[0].Token), - authcommon.WithOperation(authcommon.Create), - authcommon.WithModule(authcommon.DiscoverModule), - authcommon.WithAccessResources(map[apisecurity.ResourceType][]authcommon.ResourceEntry{ - apisecurity.ResourceType_Services: { - { - ID: services[0].ID, - Owner: services[0].Owner, - }, - }, - }), - ) - - _, err = checker.CheckConsolePermission(authCtx) - t.Logf("%+v", err) - assert.NoError(t, err, "Should be verify success") - }) - - t.Run("权限检查严格模式-子账户操作资源(无操作权限)", func(t *testing.T) { - ctx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, users[1].Token) - authCtx := authcommon.NewAcquireContext( - authcommon.WithRequestContext(ctx), - authcommon.WithMethod("Test_DefaultAuthChecker_CheckConsolePermission_Write_Strict"), - // authcommon.WithToken(users[1].Token), - authcommon.WithOperation(authcommon.Create), - authcommon.WithModule(authcommon.DiscoverModule), - authcommon.WithAccessResources(map[apisecurity.ResourceType][]authcommon.ResourceEntry{ - apisecurity.ResourceType_Services: { - { - ID: services[0].ID, - Owner: services[0].Owner, - }, - }, - }), - ) - _, err = checker.CheckConsolePermission(authCtx) - t.Logf("%+v", err) - assert.Error(t, err, "Should be verify fail") - }) - - t.Run("权限检查严格模式-子账户操作资源(有操作权限)", func(t *testing.T) { - ctx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, users[1].Token) - authCtx := authcommon.NewAcquireContext( - authcommon.WithRequestContext(ctx), - authcommon.WithMethod("Test_DefaultAuthChecker_CheckConsolePermission_Write_Strict"), - // authcommon.WithToken(users[1].Token), - authcommon.WithOperation(authcommon.Create), - authcommon.WithModule(authcommon.DiscoverModule), - authcommon.WithAccessResources(map[apisecurity.ResourceType][]authcommon.ResourceEntry{ - apisecurity.ResourceType_Services: { - { - ID: services[1].ID, - Owner: services[1].Owner, - }, - }, - }), - ) - _, err = checker.CheckConsolePermission(authCtx) - t.Logf("%+v", err) - assert.NoError(t, err, "Should be verify success") - }) - - t.Run("权限检查严格模式-token非法-匿名账户操作资源(资源有策略)", func(t *testing.T) { - ctx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, "Test_DefaultAuthChecker_CheckConsolePermission_Write_Strict") - authCtx := authcommon.NewAcquireContext( - authcommon.WithRequestContext(ctx), - authcommon.WithMethod("Test_DefaultAuthChecker_CheckConsolePermission_Write_Strict"), - // authcommon.WithToken("Test_DefaultAuthChecker_VerifyCredential"), - authcommon.WithOperation(authcommon.Create), - authcommon.WithModule(authcommon.DiscoverModule), - authcommon.WithAccessResources(map[apisecurity.ResourceType][]authcommon.ResourceEntry{ - apisecurity.ResourceType_Services: { - { - ID: services[1].ID, - Owner: services[1].Owner, - }, - }, - }), - ) - _, err = checker.CheckConsolePermission(authCtx) - t.Logf("%+v", err) - assert.Error(t, err, "Should be verify fail") - }) - - t.Run("权限检查严格模式-token为空-匿名账户操作资源(资源有策略)", func(t *testing.T) { - ctx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, "") - authCtx := authcommon.NewAcquireContext( - authcommon.WithRequestContext(ctx), - authcommon.WithMethod("Test_DefaultAuthChecker_CheckConsolePermission_Write_Strict"), - // authcommon.WithToken(""), - authcommon.WithModule(authcommon.DiscoverModule), - authcommon.WithOperation(authcommon.Create), - authcommon.WithAccessResources(map[apisecurity.ResourceType][]authcommon.ResourceEntry{ - apisecurity.ResourceType_Services: { - { - ID: services[1].ID, - Owner: services[1].Owner, - }, - }, - }), - ) - _, err = checker.CheckConsolePermission(authCtx) - t.Logf("%+v", err) - assert.Error(t, err, "Should be verify fail") - }) - - t.Run("权限检查严格模式-token非法-匿名账户操作资源(资源没有策略)", func(t *testing.T) { - ctx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, "Test_DefaultAuthChecker_CheckConsolePermission_Write_Strict") - authCtx := authcommon.NewAcquireContext( - authcommon.WithRequestContext(ctx), - authcommon.WithMethod("Test_DefaultAuthChecker_CheckConsolePermission_Write_Strict"), - // authcommon.WithToken("Test_DefaultAuthChecker_VerifyCredential"), - authcommon.WithOperation(authcommon.Create), - authcommon.WithModule(authcommon.DiscoverModule), - authcommon.WithAccessResources(map[apisecurity.ResourceType][]authcommon.ResourceEntry{ - apisecurity.ResourceType_Services: { - { - ID: services[freeIndex].ID, - Owner: services[freeIndex].Owner, - }, - }, - }), - ) - - dchecker := checker.(*policy.DefaultAuthChecker) - oldConf := dchecker.GetConfig() - defer func() { - dchecker.SetConfig(oldConf) - }() - dchecker.SetConfig(&policy.AuthConfig{ - ConsoleOpen: true, - ConsoleStrict: true, - }) - - _, err = checker.CheckConsolePermission(authCtx) - t.Logf("%+v", err) - assert.Error(t, err, "Should be verify fail") - }) - - t.Run("权限检查严格模式-token为空-匿名账户操作资源(资源没有策略)", func(t *testing.T) { - ctx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, "") - authCtx := authcommon.NewAcquireContext( - authcommon.WithRequestContext(ctx), - authcommon.WithMethod("Test_DefaultAuthChecker_CheckConsolePermission_Write_Strict"), - // authcommon.WithToken(""), - authcommon.WithOperation(authcommon.Create), - authcommon.WithModule(authcommon.DiscoverModule), - authcommon.WithAccessResources(map[apisecurity.ResourceType][]authcommon.ResourceEntry{ - apisecurity.ResourceType_Services: { - { - ID: services[freeIndex].ID, - Owner: services[freeIndex].Owner, - }, - }, - }), - ) - dchecker := checker.(*policy.DefaultAuthChecker) - oldConf := dchecker.GetConfig() - defer func() { - dchecker.SetConfig(oldConf) - }() - dchecker.SetConfig(&policy.AuthConfig{ - ConsoleOpen: true, - ConsoleStrict: true, - }) - - _, err = dchecker.CheckConsolePermission(authCtx) - t.Logf("%+v", err) - assert.Error(t, err, "Should be verify fail") - }) -} - -func Test_DefaultAuthChecker_CheckConsolePermission_Read_NoStrict(t *testing.T) { - reset(false) - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - users := createMockUser(10) - groups := createMockUserGroup(users) - - namespaces := createMockNamespace(len(users)+len(groups)+10, users[0].ID) - services := createMockService(namespaces) - serviceMap := convertServiceSliceToMap(services) - strategies, _ := createMockStrategy(users, groups, services[:len(users)+len(groups)]) - - cfg, storage := initCache(ctrl) - - storage.EXPECT().GetUsersForCache(gomock.Any(), gomock.Any()).AnyTimes().Return(users, nil) - storage.EXPECT().GetGroupsForCache(gomock.Any(), gomock.Any()).AnyTimes().Return(groups, nil) - storage.EXPECT().GetMoreStrategies(gomock.Any(), gomock.Any()).AnyTimes().Return(strategies, nil) - storage.EXPECT().GetMoreNamespaces(gomock.Any()).AnyTimes().Return(namespaces, nil) - storage.EXPECT().GetMoreServices(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes().Return(serviceMap, nil) - - ctx, cancel := context.WithCancel(context.Background()) - cacheMgr, err := cache.TestCacheInitialize(ctx, cfg, storage) - if err != nil { - t.Fatal(err) - } - - t.Cleanup(func() { - cancel() - cacheMgr.Close() - }) - - _, proxySvr, err := defaultuser.BuildServer() - if err != nil { - t.Fatal(err) - } - proxySvr.Initialize(&auth.Config{ - User: &auth.UserConfig{ - Name: auth.DefaultUserMgnPluginName, - Option: map[string]interface{}{ - "salt": "polarismesh@2021", - }, - }, - }, storage, nil, cacheMgr) - - _, svr, err := newPolicyServer() - if err != nil { - t.Fatal(err) - } - if err := svr.Initialize(&auth.Config{ - Strategy: &auth.StrategyConfig{ - Name: auth.DefaultPolicyPluginName, - }, - }, storage, cacheMgr, proxySvr); err != nil { - t.Fatal(err) - } - checker := svr.GetAuthChecker() - - _ = cacheMgr.TestUpdate() - - freeIndex := len(users) + len(groups) + 1 - - t.Run("权限检查非严格模式-主账户正常读操作", func(t *testing.T) { - ctx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, users[0].Token) - authCtx := authcommon.NewAcquireContext( - authcommon.WithRequestContext(ctx), - authcommon.WithMethod("Test_DefaultAuthChecker_VerifyCredential"), - // authcommon.WithToken(users[0].Token), - authcommon.WithOperation(authcommon.Read), - authcommon.WithModule(authcommon.DiscoverModule), - authcommon.WithAccessResources(map[apisecurity.ResourceType][]authcommon.ResourceEntry{ - apisecurity.ResourceType_Services: { - { - ID: services[0].ID, - Owner: services[0].Owner, - }, - }, - }), - ) - _, err = checker.CheckConsolePermission(authCtx) - t.Logf("%+v", err) - assert.NoError(t, err, "Should be verify success") - }) - - t.Run("权限检查非严格模式-子账户正常读操作-资源有权限", func(t *testing.T) { - ctx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, users[1].Token) - authCtx := authcommon.NewAcquireContext( - authcommon.WithRequestContext(ctx), - authcommon.WithMethod("Test_DefaultAuthChecker_VerifyCredential"), - // authcommon.WithToken(users[1].Token), - authcommon.WithOperation(authcommon.Read), - authcommon.WithModule(authcommon.DiscoverModule), - authcommon.WithAccessResources(map[apisecurity.ResourceType][]authcommon.ResourceEntry{ - apisecurity.ResourceType_Services: { - { - ID: services[1].ID, - Owner: services[1].Owner, - }, - }, - }), - ) - _, err = checker.CheckConsolePermission(authCtx) - t.Logf("%+v", err) - assert.NoError(t, err, "Should be verify success") - }) - - t.Run("权限检查非严格模式-子账户正常读操作-资源无权限", func(t *testing.T) { - ctx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, users[1].Token) - authCtx := authcommon.NewAcquireContext( - authcommon.WithRequestContext(ctx), - authcommon.WithMethod("Test_DefaultAuthChecker_VerifyCredential"), - // authcommon.WithToken(users[1].Token), - authcommon.WithOperation(authcommon.Read), - authcommon.WithModule(authcommon.DiscoverModule), - authcommon.WithAccessResources(map[apisecurity.ResourceType][]authcommon.ResourceEntry{ - apisecurity.ResourceType_Services: { - { - ID: services[0].ID, - Owner: services[0].Owner, - }, - }, - }), - ) - _, err = checker.CheckConsolePermission(authCtx) - t.Logf("%+v", err) - assert.NoError(t, err, "Should be verify success") - }) - - t.Run("权限检查非严格模式-子账户正常读操作-资源无绑定策略", func(t *testing.T) { - ctx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, users[1].Token) - authCtx := authcommon.NewAcquireContext( - authcommon.WithRequestContext(ctx), - authcommon.WithMethod("Test_DefaultAuthChecker_VerifyCredential"), - // authcommon.WithToken(users[1].Token), - authcommon.WithOperation(authcommon.Read), - authcommon.WithModule(authcommon.DiscoverModule), - authcommon.WithAccessResources(map[apisecurity.ResourceType][]authcommon.ResourceEntry{ - apisecurity.ResourceType_Services: { - { - ID: services[freeIndex].ID, - Owner: services[freeIndex].Owner, - }, - }, - }), - ) - _, err = checker.CheckConsolePermission(authCtx) - t.Logf("%+v", err) - assert.NoError(t, err, "Should be verify success") - }) - - t.Run("权限检查非严格模式-匿名账户正常读操作-token为空-资源有策略", func(t *testing.T) { - ctx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, "") - authCtx := authcommon.NewAcquireContext( - authcommon.WithRequestContext(ctx), - authcommon.WithMethod("Test_DefaultAuthChecker_VerifyCredential"), - // authcommon.WithToken(""), - authcommon.WithOperation(authcommon.Read), - authcommon.WithModule(authcommon.DiscoverModule), - authcommon.WithAccessResources(map[apisecurity.ResourceType][]authcommon.ResourceEntry{ - apisecurity.ResourceType_Services: { - { - ID: services[0].ID, - Owner: services[0].Owner, - }, - }, - }), - ) - _, err = checker.CheckConsolePermission(authCtx) - t.Logf("%+v", err) - assert.NoError(t, err, "Should be verify success") - }) - - t.Run("权限检查非严格模式-匿名账户正常读操作-token为空-资源无策略", func(t *testing.T) { - ctx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, "") - authCtx := authcommon.NewAcquireContext( - authcommon.WithRequestContext(ctx), - authcommon.WithMethod("Test_DefaultAuthChecker_VerifyCredential"), - authcommon.WithOperation(authcommon.Read), - authcommon.WithModule(authcommon.DiscoverModule), - authcommon.WithAccessResources(map[apisecurity.ResourceType][]authcommon.ResourceEntry{ - apisecurity.ResourceType_Services: { - { - ID: services[freeIndex].ID, - Owner: services[freeIndex].Owner, - }, - }, - }), - ) - _, err = checker.CheckConsolePermission(authCtx) - t.Logf("%+v", err) - assert.NoError(t, err, "Should be verify success") - }) - - t.Run("权限检查非严格模式-匿名账户正常读操作-token非法-资源有策略", func(t *testing.T) { - ctx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, "Test_DefaultAuthChecker_VerifyCredential") - authCtx := authcommon.NewAcquireContext( - authcommon.WithRequestContext(ctx), - authcommon.WithMethod("Test_DefaultAuthChecker_VerifyCredential"), - // authcommon.WithToken("Test_DefaultAuthChecker_VerifyCredential"), - authcommon.WithOperation(authcommon.Read), - authcommon.WithModule(authcommon.DiscoverModule), - authcommon.WithAccessResources(map[apisecurity.ResourceType][]authcommon.ResourceEntry{ - apisecurity.ResourceType_Services: { - { - ID: services[0].ID, - Owner: services[0].Owner, - }, - }, - }), - ) - _, err = checker.CheckConsolePermission(authCtx) - t.Logf("%+v", err) - assert.NoError(t, err, "Should be verify success") - }) - - t.Run("权限检查非严格模式-匿名账户正常读操作-token非法-资源无策略", func(t *testing.T) { - ctx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, "Test_DefaultAuthChecker_VerifyCredential") - authCtx := authcommon.NewAcquireContext( - authcommon.WithRequestContext(ctx), - authcommon.WithMethod("Test_DefaultAuthChecker_VerifyCredential"), - // authcommon.WithToken("Test_DefaultAuthChecker_VerifyCredential"), - authcommon.WithOperation(authcommon.Read), - authcommon.WithModule(authcommon.DiscoverModule), - authcommon.WithAccessResources(map[apisecurity.ResourceType][]authcommon.ResourceEntry{ - apisecurity.ResourceType_Services: { - { - ID: services[freeIndex].ID, - Owner: services[freeIndex].Owner, - }, - }, - }), - ) - _, err = checker.CheckConsolePermission(authCtx) - t.Logf("%+v", err) - assert.NoError(t, err, "Should be verify success") - }) -} - -func Test_DefaultAuthChecker_CheckConsolePermission_Read_Strict(t *testing.T) { - reset(true) - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - users := createMockUser(10) - groups := createMockUserGroup(users) - - namespaces := createMockNamespace(len(users)+len(groups)+10, users[0].ID) - services := createMockService(namespaces) - serviceMap := convertServiceSliceToMap(services) - strategies, _ := createMockStrategy(users, groups, services[:len(users)+len(groups)]) - - cfg, storage := initCache(ctrl) - - storage.EXPECT().GetUsersForCache(gomock.Any(), gomock.Any()).AnyTimes().Return(users, nil) - storage.EXPECT().GetGroupsForCache(gomock.Any(), gomock.Any()).AnyTimes().Return(groups, nil) - storage.EXPECT().GetMoreStrategies(gomock.Any(), gomock.Any()).AnyTimes().Return(strategies, nil) - storage.EXPECT().GetMoreNamespaces(gomock.Any()).AnyTimes().Return(namespaces, nil) - storage.EXPECT().GetMoreServices(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes().Return(serviceMap, nil) - - ctx, cancel := context.WithCancel(context.Background()) - cacheMgr, err := cache.TestCacheInitialize(ctx, cfg, storage) - if err != nil { - t.Fatal(err) - } - - t.Cleanup(func() { - cancel() - cacheMgr.Close() - }) - - _, proxySvr, err := defaultuser.BuildServer() - if err != nil { - t.Fatal(err) - } - proxySvr.Initialize(&auth.Config{ - User: &auth.UserConfig{ - Name: auth.DefaultUserMgnPluginName, - Option: map[string]interface{}{ - "salt": "polarismesh@2021", - }, - }, - }, storage, nil, cacheMgr) - - _, svr, err := newPolicyServer() - if err != nil { - t.Fatal(err) - } - if err := svr.Initialize(&auth.Config{ - Strategy: &auth.StrategyConfig{ - Name: auth.DefaultPolicyPluginName, - }, - }, storage, cacheMgr, proxySvr); err != nil { - t.Fatal(err) - } - checker := svr.GetAuthChecker() - dchecker := checker.(*policy.DefaultAuthChecker) - oldConf := dchecker.GetConfig() - defer func() { - dchecker.SetConfig(oldConf) - }() - dchecker.SetConfig(&policy.AuthConfig{ - ConsoleOpen: true, - ConsoleStrict: true, - }) - - _ = cacheMgr.TestUpdate() - - freeIndex := len(users) + len(groups) + 1 - - t.Run("权限检查严格模式-主账户正常读操作", func(t *testing.T) { - ctx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, users[0].Token) - authCtx := authcommon.NewAcquireContext( - authcommon.WithRequestContext(ctx), - authcommon.WithMethod("Test_DefaultAuthChecker_VerifyCredential"), - // authcommon.WithToken(users[0].Token), - authcommon.WithOperation(authcommon.Read), - authcommon.WithModule(authcommon.DiscoverModule), - authcommon.WithAccessResources(map[apisecurity.ResourceType][]authcommon.ResourceEntry{ - apisecurity.ResourceType_Services: { - { - ID: services[0].ID, - Owner: services[0].Owner, - }, - }, - }), - ) - _, err = checker.CheckConsolePermission(authCtx) - t.Logf("%+v", err) - assert.NoError(t, err, "Should be verify success") - }) - - t.Run("权限检查严格模式-子账户正常读操作-资源有权限", func(t *testing.T) { - ctx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, users[1].Token) - authCtx := authcommon.NewAcquireContext( - authcommon.WithRequestContext(ctx), - authcommon.WithMethod("Test_DefaultAuthChecker_VerifyCredential"), - // authcommon.WithToken(users[1].Token), - authcommon.WithOperation(authcommon.Read), - authcommon.WithModule(authcommon.DiscoverModule), - authcommon.WithAccessResources(map[apisecurity.ResourceType][]authcommon.ResourceEntry{ - apisecurity.ResourceType_Services: { - { - ID: services[1].ID, - Owner: services[1].Owner, - }, - }, - }), - ) - _, err = checker.CheckConsolePermission(authCtx) - t.Logf("%+v", err) - assert.NoError(t, err, "Should be verify success") - }) - - t.Run("权限检查严格模式-子账户正常读操作-资源无权限", func(t *testing.T) { - ctx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, users[1].Token) - authCtx := authcommon.NewAcquireContext( - authcommon.WithRequestContext(ctx), - authcommon.WithMethod("Test_DefaultAuthChecker_VerifyCredential"), - // authcommon.WithToken(users[1].Token), - authcommon.WithOperation(authcommon.Read), - authcommon.WithModule(authcommon.DiscoverModule), - authcommon.WithAccessResources(map[apisecurity.ResourceType][]authcommon.ResourceEntry{ - apisecurity.ResourceType_Services: { - { - ID: services[0].ID, - Owner: services[0].Owner, - }, - }, - }), - ) - _, err = checker.CheckConsolePermission(authCtx) - t.Logf("%+v", err) - assert.NoError(t, err, "Should be verify success") - }) - - t.Run("权限检查严格模式-子账户正常读操作-资源无绑定策略", func(t *testing.T) { - ctx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, users[1].Token) - authCtx := authcommon.NewAcquireContext( - authcommon.WithRequestContext(ctx), - authcommon.WithMethod("Test_DefaultAuthChecker_VerifyCredential"), - // authcommon.WithToken(users[1].Token), - authcommon.WithOperation(authcommon.Read), - authcommon.WithModule(authcommon.DiscoverModule), - authcommon.WithAccessResources(map[apisecurity.ResourceType][]authcommon.ResourceEntry{ - apisecurity.ResourceType_Services: { - { - ID: services[freeIndex].ID, - Owner: services[freeIndex].Owner, - }, - }, - }), - ) - _, err = checker.CheckConsolePermission(authCtx) - t.Logf("%+v", err) - assert.NoError(t, err, "Should be verify success") - }) - - t.Run("权限检查严格模式-匿名账户正常读操作-token为空-资源有策略", func(t *testing.T) { - ctx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, "") - authCtx := authcommon.NewAcquireContext( - authcommon.WithRequestContext(ctx), - authcommon.WithMethod("Test_DefaultAuthChecker_VerifyCredential"), - // authcommon.WithToken(""), - authcommon.WithOperation(authcommon.Read), - authcommon.WithModule(authcommon.DiscoverModule), - authcommon.WithAccessResources(map[apisecurity.ResourceType][]authcommon.ResourceEntry{ - apisecurity.ResourceType_Services: { - { - ID: services[0].ID, - Owner: services[0].Owner, - }, - }, - }), - ) - _, err = checker.CheckConsolePermission(authCtx) - t.Logf("%+v", err) - assert.Error(t, err, "Should be verify fail") - }) - - t.Run("权限检查严格模式-匿名账户正常读操作-token为空-资源无策略", func(t *testing.T) { - ctx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, "") - authCtx := authcommon.NewAcquireContext( - authcommon.WithRequestContext(ctx), - authcommon.WithMethod("Test_DefaultAuthChecker_VerifyCredential"), - // authcommon.WithToken(""), - authcommon.WithOperation(authcommon.Read), - authcommon.WithModule(authcommon.DiscoverModule), - authcommon.WithAccessResources(map[apisecurity.ResourceType][]authcommon.ResourceEntry{ - apisecurity.ResourceType_Services: { - { - ID: services[freeIndex].ID, - Owner: services[freeIndex].Owner, - }, - }, - }), - ) - _, err = checker.CheckConsolePermission(authCtx) - t.Logf("%+v", err) - assert.Error(t, err, "Should be verify fail") - }) - - t.Run("权限检查严格模式-匿名账户正常读操作-token非法-资源有策略", func(t *testing.T) { - ctx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, "Test_DefaultAuthChecker_VerifyCredential") - authCtx := authcommon.NewAcquireContext( - authcommon.WithRequestContext(ctx), - authcommon.WithMethod("Test_DefaultAuthChecker_VerifyCredential"), - // authcommon.WithToken("Test_DefaultAuthChecker_VerifyCredential"), - authcommon.WithOperation(authcommon.Read), - authcommon.WithModule(authcommon.DiscoverModule), - authcommon.WithAccessResources(map[apisecurity.ResourceType][]authcommon.ResourceEntry{ - apisecurity.ResourceType_Services: { - { - ID: services[0].ID, - Owner: services[0].Owner, - }, - }, - }), - ) - _, err = checker.CheckConsolePermission(authCtx) - t.Logf("%+v", err) - assert.Error(t, err, "Should be verify fail") - }) - - t.Run("权限检查严格模式-匿名账户正常读操作-token非法-资源无策略", func(t *testing.T) { - ctx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, "Test_DefaultAuthChecker_VerifyCredential") - authCtx := authcommon.NewAcquireContext( - authcommon.WithRequestContext(ctx), - authcommon.WithMethod("Test_DefaultAuthChecker_VerifyCredential"), - // authcommon.WithToken("Test_DefaultAuthChecker_VerifyCredential"), - authcommon.WithOperation(authcommon.Read), - authcommon.WithModule(authcommon.DiscoverModule), - authcommon.WithAccessResources(map[apisecurity.ResourceType][]authcommon.ResourceEntry{ - apisecurity.ResourceType_Services: { - { - ID: services[freeIndex].ID, - Owner: services[freeIndex].Owner, - }, - }, - }), - ) - _, err = checker.CheckConsolePermission(authCtx) - t.Logf("%+v", err) - assert.Error(t, err, "Should be verify fail") - }) -} - -func Test_DefaultAuthChecker_Initialize(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - t.Run("使用未迁移至auth.user.option及auth.strategy.option的配置", func(t *testing.T) { - reset(true) - authChecker := &policy.Server{} - cfg := &auth.Config{} - cfg.SetDefault() - cfg.Name = "" - cfg.Option = map[string]interface{}{ - "consoleOpen": true, - "clientOpen": true, - "salt": "polarismesh@2021", - "strict": false, - } - err := authChecker.ParseOptions(cfg) - assert.NoError(t, err) - assert.Equal(t, &policy.AuthConfig{ - ConsoleOpen: true, - ClientOpen: true, - Strict: false, - ConsoleStrict: false, - ClientStrict: false, - }, authChecker.GetOptions()) - }) - - t.Run("使用完全迁移至auth.user.option及auth.strategy.option的配置", func(t *testing.T) { - reset(true) - authChecker := &policy.Server{} - - cfg := &auth.Config{} - cfg.SetDefault() - cfg.User = &auth.UserConfig{ - Name: "", - Option: map[string]interface{}{"salt": "polarismesh@2021"}, - } - cfg.Strategy = &auth.StrategyConfig{ - Name: "", - Option: map[string]interface{}{ - "consoleOpen": true, - "clientOpen": true, - "strict": false, - }, - } - - err := authChecker.ParseOptions(cfg) - assert.NoError(t, err) - assert.Equal(t, &policy.AuthConfig{ - ConsoleOpen: true, - ConsoleStrict: false, - ClientOpen: true, - Strict: false, - }, authChecker.GetOptions()) - }) - - t.Run("使用部分迁移至auth.user.option及auth.strategy.option的配置(应当报错)", func(t *testing.T) { - reset(true) - authChecker := &policy.Server{} - cfg := &auth.Config{} - cfg.SetDefault() - cfg.Name = "" - cfg.Option = map[string]interface{}{ - "clientOpen": true, - "strict": false, - } - cfg.User = &auth.UserConfig{ - Name: "", - Option: map[string]interface{}{"salt": "polarismesh@2021"}, - } - cfg.Strategy = &auth.StrategyConfig{ - Name: "", - Option: map[string]interface{}{ - "consoleOpen": true, - }, - } - - err := authChecker.ParseOptions(cfg) - assert.NoError(t, err) - }) - -} - -func TestDefaultAuthChecker_isCredible(t *testing.T) { - type fields struct { - conf *policy.AuthConfig - cacheMgr cachetypes.CacheManager - userSvr auth.UserServer - } - type args struct { - authCtx *authcommon.AcquireContext - } - tests := []struct { - name string - fields fields - args args - want bool - }{ - {}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - d := &policy.DefaultAuthChecker{} - d.SetConfig(tt.fields.conf) - if got := d.IsCredible(tt.args.authCtx); got != tt.want { - t.Errorf("DefaultAuthChecker.isCredible() = %v, want %v", got, tt.want) - } - }) - } -} diff --git a/auth/policy/common_test.go b/auth/policy/common_test.go deleted file mode 100644 index 72b0e02ef..000000000 --- a/auth/policy/common_test.go +++ /dev/null @@ -1,403 +0,0 @@ -/** - * Tencent is pleased to support the open source community by making Polaris available. - * - * Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. - * - * Licensed under the BSD 3-Clause License (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://opensource.org/licenses/BSD-3-Clause - * - * Unless required by applicable law or agreed to in writing, software distributed - * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR - * CONDITIONS OF ANY KIND, either express or implied. See the License for the - * specific language governing permissions and limitations under the License. - */ - -package policy_test - -import ( - "fmt" - "time" - - "github.com/golang/mock/gomock" - "github.com/google/uuid" - apisecurity "github.com/polarismesh/specification/source/go/api/v1/security" - "github.com/polarismesh/specification/source/go/api/v1/service_manage" - "golang.org/x/crypto/bcrypt" - "google.golang.org/protobuf/types/known/wrapperspb" - - defaultuser "github.com/polarismesh/polaris/auth/user" - "github.com/polarismesh/polaris/cache" - "github.com/polarismesh/polaris/common/metrics" - "github.com/polarismesh/polaris/common/model" - authcommon "github.com/polarismesh/polaris/common/model/auth" - "github.com/polarismesh/polaris/common/utils" - storemock "github.com/polarismesh/polaris/store/mock" -) - -func reset(strict bool) { - -} - -func initCache(ctrl *gomock.Controller) (*cache.Config, *storemock.MockStore) { - metrics.InitMetrics() - /* - - name: service # 加载服务数据 - option: - disableBusiness: false # 不加载业务服务 - needMeta: true # 加载服务元数据 - - name: instance # 加载实例数据 - option: - disableBusiness: false # 不加载业务服务实例 - needMeta: true # 加载实例元数据 - - name: routingConfig # 加载路由数据 - - name: rateLimitConfig # 加载限流数据 - - name: circuitBreakerConfig # 加载熔断数据 - - name: l5 # 加载l5数据 - - name: users - - name: strategyRule - - name: namespace - */ - cfg := &cache.Config{} - storage := storemock.NewMockStore(ctrl) - - mockTx := storemock.NewMockTx(ctrl) - mockTx.EXPECT().Commit().Return(nil).AnyTimes() - mockTx.EXPECT().Rollback().Return(nil).AnyTimes() - mockTx.EXPECT().CreateReadView().Return(nil).AnyTimes() - - storage.EXPECT().StartReadTx().Return(mockTx, nil).AnyTimes() - storage.EXPECT().GetServicesCount().AnyTimes().Return(uint32(1), nil) - storage.EXPECT().GetInstancesCountTx(gomock.Any()).AnyTimes().Return(uint32(1), nil) - storage.EXPECT().GetMoreInstances(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(map[string]*model.Instance{ - "123": { - Proto: &service_manage.Instance{ - Id: wrapperspb.String(uuid.NewString()), - Host: wrapperspb.String("127.0.0.1"), - Port: wrapperspb.UInt32(8080), - }, - Valid: true, - }, - }, nil).AnyTimes() - storage.EXPECT().GetUnixSecond(gomock.Any()).AnyTimes().Return(time.Now().Unix(), nil) - - return cfg, storage -} - -func createMockNamespace(total int, owner string) []*model.Namespace { - namespaces := make([]*model.Namespace, 0, total) - - for i := 0; i < total; i++ { - namespaces = append(namespaces, &model.Namespace{ - Name: fmt.Sprintf("namespace_%d", i), - Owner: owner, - Valid: true, - }) - } - - return namespaces -} - -func createMockService(namespaces []*model.Namespace) []*model.Service { - services := make([]*model.Service, 0, len(namespaces)) - - for i := 0; i < len(namespaces); i++ { - ns := namespaces[i] - services = append(services, &model.Service{ - ID: utils.NewUUID(), - Namespace: ns.Name, - Owner: ns.Owner, - Name: fmt.Sprintf("service_%d", i), - Valid: true, - }) - } - - return services -} - -func createMockStrategy(users []*authcommon.User, groups []*authcommon.UserGroupDetail, services []*model.Service) ([]*authcommon.StrategyDetail, []*authcommon.StrategyDetail) { - strategies := make([]*authcommon.StrategyDetail, 0, len(users)+len(groups)) - defaultStrategies := make([]*authcommon.StrategyDetail, 0, len(users)+len(groups)) - - owner := "" - for i := 0; i < len(users); i++ { - user := users[i] - if user.Owner == "" { - owner = user.ID - break - } - } - - for i := 0; i < len(users); i++ { - user := users[i] - service := services[i] - id := utils.NewUUID() - strategies = append(strategies, &authcommon.StrategyDetail{ - ID: id, - Name: fmt.Sprintf("strategy_user_%s_%d", user.Name, i), - Action: apisecurity.AuthAction_READ_WRITE.String(), - Comment: "", - Principals: []authcommon.Principal{ - { - PrincipalID: user.ID, - PrincipalType: authcommon.PrincipalUser, - }, - }, - Default: false, - Owner: owner, - Resources: []authcommon.StrategyResource{ - { - StrategyID: id, - ResType: int32(apisecurity.ResourceType_Namespaces), - ResID: service.Namespace, - }, - { - StrategyID: id, - ResType: int32(apisecurity.ResourceType_Services), - ResID: service.ID, - }, - }, - Valid: true, - Revision: utils.NewUUID(), - CreateTime: time.Time{}, - ModifyTime: time.Time{}, - }) - - defaultStrategies = append(defaultStrategies, &authcommon.StrategyDetail{ - ID: id, - Name: fmt.Sprintf("strategy_default_user_%s_%d", user.Name, i), - Action: apisecurity.AuthAction_READ_WRITE.String(), - Comment: "", - Principals: []authcommon.Principal{ - { - PrincipalID: user.ID, - PrincipalType: authcommon.PrincipalUser, - }, - }, - Default: true, - Owner: owner, - Resources: []authcommon.StrategyResource{ - { - StrategyID: id, - ResType: int32(apisecurity.ResourceType_Namespaces), - ResID: service.Namespace, - }, - { - StrategyID: id, - ResType: int32(apisecurity.ResourceType_Services), - ResID: service.ID, - }, - }, - Valid: true, - Revision: utils.NewUUID(), - CreateTime: time.Time{}, - ModifyTime: time.Time{}, - }) - } - - for i := 0; i < len(groups); i++ { - group := groups[i] - service := services[len(users)+i] - id := utils.NewUUID() - strategies = append(strategies, &authcommon.StrategyDetail{ - ID: id, - Name: fmt.Sprintf("strategy_group_%s_%d", group.Name, i), - Action: apisecurity.AuthAction_READ_WRITE.String(), - Comment: "", - Principals: []authcommon.Principal{ - { - PrincipalID: group.ID, - PrincipalType: authcommon.PrincipalGroup, - }, - }, - Default: false, - Owner: owner, - Resources: []authcommon.StrategyResource{ - { - StrategyID: id, - ResType: int32(apisecurity.ResourceType_Namespaces), - ResID: service.Namespace, - }, - { - StrategyID: id, - ResType: int32(apisecurity.ResourceType_Services), - ResID: service.ID, - }, - }, - Valid: true, - Revision: utils.NewUUID(), - CreateTime: time.Time{}, - ModifyTime: time.Time{}, - }) - - defaultStrategies = append(defaultStrategies, &authcommon.StrategyDetail{ - ID: id, - Name: fmt.Sprintf("strategy_default_group_%s_%d", group.Name, i), - Action: apisecurity.AuthAction_READ_WRITE.String(), - Comment: "", - Principals: []authcommon.Principal{ - { - PrincipalID: group.ID, - PrincipalType: authcommon.PrincipalGroup, - }, - }, - Default: true, - Owner: owner, - Resources: []authcommon.StrategyResource{ - { - StrategyID: id, - ResType: int32(apisecurity.ResourceType_Namespaces), - ResID: service.Namespace, - }, - { - StrategyID: id, - ResType: int32(apisecurity.ResourceType_Services), - ResID: service.ID, - }, - }, - Valid: true, - Revision: utils.NewUUID(), - CreateTime: time.Time{}, - ModifyTime: time.Time{}, - }) - } - - return defaultStrategies, strategies -} - -func convertServiceSliceToMap(services []*model.Service) map[string]*model.Service { - ret := make(map[string]*model.Service) - - for i := range services { - service := services[i] - ret[service.ID] = service - } - - return ret -} - -// createMockUser 默认 users[0] 为 owner 用户 -func createMockUser(total int, prefix ...string) []*authcommon.User { - users := make([]*authcommon.User, 0, total) - - ownerId := utils.NewUUID() - - nameTemp := "user-%d" - if len(prefix) != 0 { - nameTemp = prefix[0] + nameTemp - } - - for i := 0; i < total; i++ { - id := fmt.Sprintf("fake-user-id-%d-%s", i, utils.NewUUID()) - if i == 0 { - id = ownerId - } - pwd, _ := bcrypt.GenerateFromPassword([]byte("polaris"), bcrypt.DefaultCost) - token, _ := defaultuser.CreateToken(id, "", "polarismesh@2021") - users = append(users, &authcommon.User{ - ID: id, - Name: fmt.Sprintf(nameTemp, i), - Password: string(pwd), - Owner: func() string { - if id == ownerId { - return "" - } - return ownerId - }(), - Source: "Polaris", - Mobile: "", - Email: "", - Type: func() authcommon.UserRoleType { - if id == ownerId { - return authcommon.OwnerUserRole - } - return authcommon.SubAccountUserRole - }(), - Token: token, - TokenEnable: true, - Valid: true, - CreateTime: time.Time{}, - ModifyTime: time.Time{}, - }) - } - return users -} - -func createApiMockUser(total int, prefix ...string) []*apisecurity.User { - users := make([]*apisecurity.User, 0, total) - - models := createMockUser(total, prefix...) - - for i := range models { - users = append(users, &apisecurity.User{ - Name: utils.NewStringValue("test-" + models[i].Name), - Password: utils.NewStringValue("123456"), - Source: utils.NewStringValue("Polaris"), - Comment: utils.NewStringValue(models[i].Comment), - Mobile: utils.NewStringValue(models[i].Mobile), - Email: utils.NewStringValue(models[i].Email), - }) - } - - return users -} - -func createMockUserGroup(users []*authcommon.User) []*authcommon.UserGroupDetail { - groups := make([]*authcommon.UserGroupDetail, 0, len(users)) - - for i := range users { - user := users[i] - id := utils.NewUUID() - - token, _ := defaultuser.CreateToken("", id, "polarismesh@2021") - groups = append(groups, &authcommon.UserGroupDetail{ - UserGroup: &authcommon.UserGroup{ - ID: id, - Name: fmt.Sprintf("test-group-%d", i), - Owner: users[0].ID, - Token: token, - TokenEnable: true, - Valid: true, - Comment: "", - CreateTime: time.Time{}, - ModifyTime: time.Time{}, - }, - UserIds: map[string]struct{}{ - user.ID: {}, - }, - }) - } - - return groups -} - -// createMockApiUserGroup -func createMockApiUserGroup(users []*apisecurity.User) []*apisecurity.UserGroup { - musers := make([]*authcommon.User, 0, len(users)) - for i := range users { - musers = append(musers, &authcommon.User{ - ID: users[i].GetId().GetValue(), - }) - } - - models := createMockUserGroup(musers) - ret := make([]*apisecurity.UserGroup, 0, len(models)) - - for i := range models { - ret = append(ret, &apisecurity.UserGroup{ - Name: utils.NewStringValue(models[i].Name), - Comment: utils.NewStringValue(models[i].Comment), - Relation: &apisecurity.UserGroupRelation{ - Users: []*apisecurity.User{ - { - Id: utils.NewStringValue(users[i].GetId().GetValue()), - }, - }, - }, - }) - } - - return ret -} diff --git a/auth/policy/default.go b/auth/policy/default.go index ccff9599b..d9b372902 100644 --- a/auth/policy/default.go +++ b/auth/policy/default.go @@ -29,7 +29,7 @@ import ( type ServerProxyFactory func(svr *Server, pre auth.StrategyServer) (auth.StrategyServer, error) var ( - // serverProxyFactories auth.UserServer API 代理工厂 + // serverProxyFactories auth.StrategyServer API 代理工厂 serverProxyFactories = map[string]ServerProxyFactory{} ) diff --git a/auth/policy/helper.go b/auth/policy/helper.go index e77f14de1..41eb1676d 100644 --- a/auth/policy/helper.go +++ b/auth/policy/helper.go @@ -1,14 +1,32 @@ +/** + * Tencent is pleased to support the open source community by making Polaris available. + * + * Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. + * + * Licensed under the BSD 3-Clause License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://opensource.org/licenses/BSD-3-Clause + * + * Unless required by applicable law or agreed to in writing, software distributed + * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR + * CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + package policy import ( "context" + apisecurity "github.com/polarismesh/specification/source/go/api/v1/security" + "github.com/polarismesh/polaris/auth" cachetypes "github.com/polarismesh/polaris/cache/api" authcommon "github.com/polarismesh/polaris/common/model/auth" "github.com/polarismesh/polaris/common/utils" "github.com/polarismesh/polaris/store" - apisecurity "github.com/polarismesh/specification/source/go/api/v1/security" ) type DefaultPolicyHelper struct { @@ -18,12 +36,16 @@ type DefaultPolicyHelper struct { checker auth.AuthChecker } +func (h *DefaultPolicyHelper) GetRole(id string) *authcommon.Role { + return h.cacheMgr.Role().GetRole(id) +} + +func (h *DefaultPolicyHelper) GetPolicyRule(id string) *authcommon.StrategyDetail { + return h.cacheMgr.AuthStrategy().GetPolicyRule(id) +} + // CreatePrincipal 创建 principal 的默认 policy 资源 func (h *DefaultPolicyHelper) CreatePrincipal(ctx context.Context, tx store.Tx, p authcommon.Principal) error { - if !h.options.OpenPrincipalDefaultPolicy { - return nil - } - if err := h.storage.AddStrategy(tx, defaultPrincipalPolicy(p)); err != nil { return err } @@ -32,25 +54,50 @@ func (h *DefaultPolicyHelper) CreatePrincipal(ctx context.Context, tx store.Tx, func defaultPrincipalPolicy(p authcommon.Principal) *authcommon.StrategyDetail { // Create the user's default weight policy + ruleId := utils.NewUUID() + + resources := []authcommon.StrategyResource{} + if p.PrincipalType == authcommon.PrincipalUser { + resources = append(resources, authcommon.StrategyResource{ + StrategyID: ruleId, + ResType: int32(apisecurity.ResourceType_Users), + ResID: p.PrincipalID, + }) + } + return &authcommon.StrategyDetail{ - ID: utils.NewUUID(), - Name: authcommon.BuildDefaultStrategyName(authcommon.PrincipalUser, p.Name), - Action: apisecurity.AuthAction_READ_WRITE.String(), - Default: true, - Owner: p.Owner, - Revision: utils.NewUUID(), - Resources: []authcommon.StrategyResource{}, - Valid: true, - Comment: "Default Strategy", + ID: ruleId, + Name: authcommon.BuildDefaultStrategyName(authcommon.PrincipalUser, p.Name), + Action: apisecurity.AuthAction_ALLOW.String(), + Default: true, + Owner: p.Owner, + Revision: utils.NewUUID(), + Source: "Polaris", + Resources: resources, + Principals: []authcommon.Principal{p}, + CalleeMethods: []string{ + // 用户操作权限 + string(authcommon.DescribeUsers), + string(authcommon.DescribeUserToken), + string(authcommon.UpdateUser), + string(authcommon.UpdateUserPassword), + string(authcommon.EnableUserToken), + string(authcommon.ResetUserToken), + // 鉴权策略 + string(authcommon.DescribeAuthPolicies), + string(authcommon.DescribeAuthPolicyDetail), + // 角色 + string(authcommon.DescribeAuthRoles), + }, + Valid: true, + Comment: "default principal auth policy rule", } } // CleanPrincipal 清理 principal 所关联的 policy、role 资源 func (h *DefaultPolicyHelper) CleanPrincipal(ctx context.Context, tx store.Tx, p authcommon.Principal) error { - if h.options.OpenPrincipalDefaultPolicy { - if err := h.storage.CleanPrincipalPolicies(tx, p); err != nil { - return err - } + if err := h.storage.CleanPrincipalPolicies(tx, p); err != nil { + return err } if err := h.storage.CleanPrincipalRoles(tx, &p); err != nil { diff --git a/auth/policy/inteceptor/auth/server.go b/auth/policy/inteceptor/auth/server.go index 128046b18..c80c3acf9 100644 --- a/auth/policy/inteceptor/auth/server.go +++ b/auth/policy/inteceptor/auth/server.go @@ -27,7 +27,9 @@ import ( "github.com/polarismesh/polaris/auth" cachetypes "github.com/polarismesh/polaris/cache/api" api "github.com/polarismesh/polaris/common/api/v1" + "github.com/polarismesh/polaris/common/model" authcommon "github.com/polarismesh/polaris/common/model/auth" + "github.com/polarismesh/polaris/common/utils" "github.com/polarismesh/polaris/store" ) @@ -78,17 +80,21 @@ func (svr *Server) CreateStrategy(ctx context.Context, strategy *apisecurity.Aut resp := api.NewResponseWithMsg(authcommon.ConvertToErrCode(err), err.Error()) return resp } - return svr.nextSvr.CreateStrategy(ctx, strategy) + return svr.nextSvr.CreateStrategy(authCtx.GetRequestContext(), strategy) } // UpdateStrategies 批量更新策略 func (svr *Server) UpdateStrategies(ctx context.Context, reqs []*apisecurity.ModifyAuthStrategy) *apiservice.BatchWriteResponse { resources := make([]authcommon.ResourceEntry, 0, len(reqs)) for i := range reqs { - item := reqs[i] - resources = append(resources, authcommon.ResourceEntry{ - ID: item.GetId().GetValue(), - }) + entry := authcommon.ResourceEntry{ + Type: apisecurity.ResourceType_PolicyRules, + ID: reqs[i].GetId().GetValue(), + } + if saveRule := svr.nextSvr.PolicyHelper().GetPolicyRule(reqs[i].GetId().GetValue()); saveRule != nil { + entry.Metadata = saveRule.Metadata + } + resources = append(resources, entry) } authCtx := authcommon.NewAcquireContext( @@ -96,23 +102,30 @@ func (svr *Server) UpdateStrategies(ctx context.Context, reqs []*apisecurity.Mod authcommon.WithOperation(authcommon.Modify), authcommon.WithModule(authcommon.AuthModule), authcommon.WithMethod(authcommon.UpdateAuthPolicies), + authcommon.WithAccessResources(map[apisecurity.ResourceType][]authcommon.ResourceEntry{ + apisecurity.ResourceType_PolicyRules: resources, + }), ) if _, err := svr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { resp := api.NewBatchWriteResponseWithMsg(authcommon.ConvertToErrCode(err), err.Error()) return resp } - return svr.nextSvr.UpdateStrategies(ctx, reqs) + return svr.nextSvr.UpdateStrategies(authCtx.GetRequestContext(), reqs) } // DeleteStrategies 删除策略 func (svr *Server) DeleteStrategies(ctx context.Context, reqs []*apisecurity.AuthStrategy) *apiservice.BatchWriteResponse { resources := make([]authcommon.ResourceEntry, 0, len(reqs)) for i := range reqs { - item := reqs[i] - resources = append(resources, authcommon.ResourceEntry{ - ID: item.GetId().GetValue(), - }) + entry := authcommon.ResourceEntry{ + Type: apisecurity.ResourceType_PolicyRules, + ID: reqs[i].GetId().GetValue(), + } + if saveRule := svr.nextSvr.PolicyHelper().GetPolicyRule(reqs[i].GetId().GetValue()); saveRule != nil { + entry.Metadata = saveRule.Metadata + } + resources = append(resources, entry) } authCtx := authcommon.NewAcquireContext( @@ -120,13 +133,16 @@ func (svr *Server) DeleteStrategies(ctx context.Context, reqs []*apisecurity.Aut authcommon.WithOperation(authcommon.Delete), authcommon.WithModule(authcommon.AuthModule), authcommon.WithMethod(authcommon.DeleteAuthPolicies), + authcommon.WithAccessResources(map[apisecurity.ResourceType][]authcommon.ResourceEntry{ + apisecurity.ResourceType_PolicyRules: resources, + }), ) if _, err := svr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { resp := api.NewBatchWriteResponseWithMsg(authcommon.ConvertToErrCode(err), err.Error()) return resp } - return svr.nextSvr.DeleteStrategies(ctx, reqs) + return svr.nextSvr.DeleteStrategies(authCtx.GetRequestContext(), reqs) } // GetStrategies 获取资源列表 @@ -140,44 +156,62 @@ func (svr *Server) GetStrategies(ctx context.Context, query map[string]string) * authcommon.WithMethod(authcommon.DescribeAuthPolicies), ) - if err := svr.userSvr.CheckCredential(authCtx); err != nil { + checker := svr.GetAuthChecker() + if _, err := checker.CheckConsolePermission(authCtx); err != nil { return api.NewAuthBatchQueryResponseWithMsg(authcommon.ConvertToErrCode(err), err.Error()) } - checker := svr.GetAuthChecker() - cachetypes.AppendAuthPolicyPredicate(ctx, func(ctx context.Context, sd *authcommon.StrategyDetail) bool { - return checker.ResourcePredicate(authCtx, &authcommon.ResourceEntry{ - Type: apisecurity.ResourceType_PolicyRules, - ID: sd.ID, + ctx = authCtx.GetRequestContext() + ctx = cachetypes.AppendAuthPolicyPredicate(ctx, func(ctx context.Context, sd *authcommon.StrategyDetail) bool { + ok := checker.ResourcePredicate(authCtx, &authcommon.ResourceEntry{ + Type: apisecurity.ResourceType_PolicyRules, + ID: sd.ID, + Metadata: sd.Metadata, }) + if ok { + return true + } + // 兼容老版本的策略查询逻辑 + if compatible, _ := ctx.Value(model.ContextKeyCompatible{}).(bool); compatible { + for i := range sd.Principals { + if sd.Principals[i].PrincipalID == utils.ParseUserID(ctx) { + return true + } + } + } + return false }) + authCtx.SetRequestContext(ctx) return svr.nextSvr.GetStrategies(ctx, query) } // GetStrategy 获取策略详细 func (svr *Server) GetStrategy(ctx context.Context, strategy *apisecurity.AuthStrategy) *apiservice.Response { + entry := authcommon.ResourceEntry{ + Type: apisecurity.ResourceType_PolicyRules, + ID: strategy.GetId().GetValue(), + } + saveRule := svr.nextSvr.PolicyHelper().GetPolicyRule(strategy.GetId().GetValue()) + if saveRule != nil { + entry.Metadata = saveRule.Metadata + } + authCtx := authcommon.NewAcquireContext( authcommon.WithRequestContext(ctx), authcommon.WithOperation(authcommon.Read), authcommon.WithModule(authcommon.AuthModule), authcommon.WithMethod(authcommon.DescribeAuthPolicyDetail), + authcommon.WithAccessResources(map[apisecurity.ResourceType][]authcommon.ResourceEntry{ + apisecurity.ResourceType_PolicyRules: {entry}, + }), ) checker := svr.GetAuthChecker() - if _, err := checker.CheckConsolePermission(authCtx); err != nil { return api.NewResponseWithMsg(authcommon.ConvertToErrCode(err), err.Error()) } - - cachetypes.AppendAuthPolicyPredicate(ctx, func(ctx context.Context, sd *authcommon.StrategyDetail) bool { - return checker.ResourcePredicate(authCtx, &authcommon.ResourceEntry{ - Type: apisecurity.ResourceType_PolicyRules, - ID: sd.ID, - }) - }) - - return svr.nextSvr.GetStrategy(ctx, strategy) + return svr.nextSvr.GetStrategy(authCtx.GetRequestContext(), strategy) } // GetPrincipalResources 获取某个 principal 的所有可操作资源列表 @@ -194,7 +228,7 @@ func (svr *Server) GetPrincipalResources(ctx context.Context, query map[string]s if _, err := checker.CheckConsolePermission(authCtx); err != nil { return api.NewResponseWithMsg(authcommon.ConvertToErrCode(err), err.Error()) } - return svr.nextSvr.GetPrincipalResources(ctx, query) + return svr.nextSvr.GetPrincipalResources(authCtx.GetRequestContext(), query) } // GetAuthChecker 获取鉴权检查器 @@ -209,20 +243,103 @@ func (svr *Server) AfterResourceOperation(afterCtx *authcommon.AcquireContext) e // CreateRoles 批量创建角色 func (svr *Server) CreateRoles(ctx context.Context, reqs []*apisecurity.Role) *apiservice.BatchWriteResponse { - return nil + authCtx := authcommon.NewAcquireContext( + authcommon.WithRequestContext(ctx), + authcommon.WithOperation(authcommon.Create), + authcommon.WithModule(authcommon.AuthModule), + authcommon.WithMethod(authcommon.CreateAuthRoles), + ) + + if _, err := svr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + resp := api.NewBatchWriteResponseWithMsg(authcommon.ConvertToErrCode(err), err.Error()) + return resp + } + return svr.nextSvr.CreateRoles(authCtx.GetRequestContext(), reqs) } // UpdateRoles 批量更新角色 func (svr *Server) UpdateRoles(ctx context.Context, reqs []*apisecurity.Role) *apiservice.BatchWriteResponse { - return nil + resources := make([]authcommon.ResourceEntry, 0, len(reqs)) + for i := range reqs { + entry := authcommon.ResourceEntry{ + Type: apisecurity.ResourceType_Roles, + ID: reqs[i].GetId(), + } + if saveRule := svr.nextSvr.PolicyHelper().GetRole(reqs[i].GetId()); saveRule != nil { + entry.Metadata = saveRule.Metadata + } + resources = append(resources, entry) + } + + authCtx := authcommon.NewAcquireContext( + authcommon.WithRequestContext(ctx), + authcommon.WithOperation(authcommon.Modify), + authcommon.WithModule(authcommon.AuthModule), + authcommon.WithMethod(authcommon.UpdateAuthRoles), + authcommon.WithAccessResources(map[apisecurity.ResourceType][]authcommon.ResourceEntry{ + apisecurity.ResourceType_Roles: resources, + }), + ) + + if _, err := svr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + resp := api.NewBatchWriteResponseWithMsg(authcommon.ConvertToErrCode(err), err.Error()) + return resp + } + return svr.nextSvr.UpdateRoles(authCtx.GetRequestContext(), reqs) } // DeleteRoles 批量删除角色 func (svr *Server) DeleteRoles(ctx context.Context, reqs []*apisecurity.Role) *apiservice.BatchWriteResponse { - return nil + resources := make([]authcommon.ResourceEntry, 0, len(reqs)) + for i := range reqs { + entry := authcommon.ResourceEntry{ + Type: apisecurity.ResourceType_Roles, + ID: reqs[i].GetId(), + } + if saveRule := svr.nextSvr.PolicyHelper().GetRole(reqs[i].GetId()); saveRule != nil { + entry.Metadata = saveRule.Metadata + } + resources = append(resources, entry) + } + + authCtx := authcommon.NewAcquireContext( + authcommon.WithRequestContext(ctx), + authcommon.WithOperation(authcommon.Modify), + authcommon.WithModule(authcommon.AuthModule), + authcommon.WithMethod(authcommon.DeleteAuthRoles), + authcommon.WithAccessResources(map[apisecurity.ResourceType][]authcommon.ResourceEntry{ + apisecurity.ResourceType_Roles: resources, + }), + ) + + if _, err := svr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + resp := api.NewBatchWriteResponseWithMsg(authcommon.ConvertToErrCode(err), err.Error()) + return resp + } + return svr.nextSvr.DeleteRoles(authCtx.GetRequestContext(), reqs) } // GetRoles 查询角色列表 func (svr *Server) GetRoles(ctx context.Context, query map[string]string) *apiservice.BatchQueryResponse { - return nil + authCtx := authcommon.NewAcquireContext( + authcommon.WithRequestContext(ctx), + authcommon.WithOperation(authcommon.Read), + authcommon.WithModule(authcommon.AuthModule), + authcommon.WithMethod(authcommon.DescribeAuthRoles), + ) + + if _, err := svr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + return api.NewAuthBatchQueryResponseWithMsg(authcommon.ConvertToErrCode(err), err.Error()) + } + + checker := svr.GetAuthChecker() + ctx = cachetypes.AppendAuthRolePredicate(ctx, func(ctx context.Context, sd *authcommon.Role) bool { + return checker.ResourcePredicate(authCtx, &authcommon.ResourceEntry{ + Type: apisecurity.ResourceType_Roles, + ID: sd.ID, + Metadata: sd.Metadata, + }) + }) + + return svr.nextSvr.GetRoles(ctx, query) } diff --git a/auth/policy/inteceptor/paramcheck/server.go b/auth/policy/inteceptor/paramcheck/server.go index 3ae25a934..3e707201c 100644 --- a/auth/policy/inteceptor/paramcheck/server.go +++ b/auth/policy/inteceptor/paramcheck/server.go @@ -31,6 +31,7 @@ import ( api "github.com/polarismesh/polaris/common/api/v1" "github.com/polarismesh/polaris/common/log" authcommon "github.com/polarismesh/polaris/common/model/auth" + commonstore "github.com/polarismesh/polaris/common/store" "github.com/polarismesh/polaris/common/utils" "github.com/polarismesh/polaris/store" ) @@ -59,8 +60,10 @@ func NewServer(nextSvr auth.StrategyServer) auth.StrategyServer { } type Server struct { - nextSvr auth.StrategyServer - userSvr auth.UserServer + storage store.Store + cacheMgr cachetypes.CacheManager + nextSvr auth.StrategyServer + userSvr auth.UserServer } // PolicyHelper implements auth.StrategyServer. @@ -71,6 +74,8 @@ func (svr *Server) PolicyHelper() auth.PolicyHelper { // Initialize 执行初始化动作 func (svr *Server) Initialize(options *auth.Config, storage store.Store, cacheMgr cachetypes.CacheManager, userSvr auth.UserServer) error { svr.userSvr = userSvr + svr.cacheMgr = cacheMgr + svr.storage = storage return svr.nextSvr.Initialize(options, storage, cacheMgr, userSvr) } @@ -80,12 +85,30 @@ func (svr *Server) Name() string { } // CreateStrategy 创建策略 -func (svr *Server) CreateStrategy(ctx context.Context, strategy *apisecurity.AuthStrategy) *apiservice.Response { - return svr.nextSvr.CreateStrategy(ctx, strategy) +func (svr *Server) CreateStrategy(ctx context.Context, req *apisecurity.AuthStrategy) *apiservice.Response { + if err := svr.checkCreateStrategy(req); err != nil { + return err + } + return svr.nextSvr.CreateStrategy(ctx, req) } // UpdateStrategies 批量更新策略 func (svr *Server) UpdateStrategies(ctx context.Context, reqs []*apisecurity.ModifyAuthStrategy) *apiservice.BatchWriteResponse { + batchResp := api.NewBatchWriteResponse(apimodel.Code_ExecuteSuccess) + for i := range reqs { + var rsp *apiservice.Response + strategy, err := svr.storage.GetStrategyDetail(reqs[i].GetId().GetValue()) + if err != nil { + log.Error("[Auth][Strategy] get strategy from store", utils.RequestID(ctx), zap.Error(err)) + rsp = api.NewModifyAuthStrategyResponse(commonstore.StoreCode2APICode(err), reqs[i]) + } + if strategy == nil { + continue + } else { + rsp = svr.checkUpdateStrategy(ctx, reqs[i], strategy) + } + api.Collect(batchResp, rsp) + } return svr.nextSvr.UpdateStrategies(ctx, reqs) } @@ -158,3 +181,140 @@ func (svr *Server) DeleteRoles(ctx context.Context, reqs []*apisecurity.Role) *a func (svr *Server) GetRoles(ctx context.Context, query map[string]string) *apiservice.BatchQueryResponse { return svr.nextSvr.GetRoles(ctx, query) } + +// checkCreateStrategy 检查创建鉴权策略的请求 +func (svr *Server) checkCreateStrategy(req *apisecurity.AuthStrategy) *apiservice.Response { + // 检查名称信息 + if err := CheckName(req.GetName()); err != nil { + return api.NewAuthStrategyResponse(apimodel.Code_InvalidUserName, req) + } + // 检查用户是否存在 + if err := svr.checkUserExist(convertPrincipalsToUsers(req.GetPrincipals())); err != nil { + return api.NewAuthStrategyResponse(apimodel.Code_NotFoundUser, req) + } + // 检查用户组是否存在 + if err := svr.checkGroupExist(convertPrincipalsToGroups(req.GetPrincipals())); err != nil { + return api.NewAuthStrategyResponse(apimodel.Code_NotFoundUserGroup, req) + } + // 检查资源是否存在 + if errResp := svr.checkResourceExist(req.GetResources()); errResp != nil { + return errResp + } + return nil +} + +// checkUpdateStrategy 检查更新鉴权策略的请求 +// Case 1. 修改的是默认鉴权策略的话,只能修改资源,不能添加用户 or 用户组 +// Case 2. 鉴权策略只能被自己的 owner 对应的用户修改 +func (svr *Server) checkUpdateStrategy(ctx context.Context, req *apisecurity.ModifyAuthStrategy, + saved *authcommon.StrategyDetail) *apiservice.Response { + if saved.Default { + if len(req.AddPrincipals.Users) != 0 || + len(req.AddPrincipals.Groups) != 0 || + len(req.RemovePrincipals.Groups) != 0 || + len(req.RemovePrincipals.Users) != 0 { + return api.NewModifyAuthStrategyResponse(apimodel.Code_NotAllowModifyDefaultStrategyPrincipal, req) + } + + // 主账户的默认策略禁止编辑 + if len(saved.Principals) == 1 && saved.Principals[0].PrincipalType == authcommon.PrincipalUser { + if saved.Principals[0].PrincipalID == utils.ParseOwnerID(ctx) { + return api.NewAuthResponse(apimodel.Code_NotAllowModifyOwnerDefaultStrategy) + } + } + } + + // 检查用户是否存在 + if err := svr.checkUserExist(convertPrincipalsToUsers(req.GetAddPrincipals())); err != nil { + return api.NewModifyAuthStrategyResponse(apimodel.Code_NotFoundUser, req) + } + + // 检查用户组是否存 + if err := svr.checkGroupExist(convertPrincipalsToGroups(req.GetAddPrincipals())); err != nil { + return api.NewModifyAuthStrategyResponse(apimodel.Code_NotFoundUserGroup, req) + } + + // 检查资源是否存在 + if errResp := svr.checkResourceExist(req.GetAddResources()); errResp != nil { + return errResp + } + return nil +} + +// checkUserExist 检查用户是否存在 +func (svr *Server) checkUserExist(users []*apisecurity.User) error { + if len(users) == 0 { + return nil + } + return svr.userSvr.GetUserHelper().CheckUsersExist(context.TODO(), users) +} + +// checkUserGroupExist 检查用户组是否存在 +func (svr *Server) checkGroupExist(groups []*apisecurity.UserGroup) error { + if len(groups) == 0 { + return nil + } + return svr.userSvr.GetUserHelper().CheckGroupsExist(context.TODO(), groups) +} + +// checkResourceExist 检查资源是否存在 +func (svr *Server) checkResourceExist(resources *apisecurity.StrategyResources) *apiservice.Response { + namespaces := resources.GetNamespaces() + + nsCache := svr.cacheMgr.Namespace() + for index := range namespaces { + val := namespaces[index] + if val.GetId().GetValue() == "*" { + break + } + if ns := nsCache.GetNamespace(val.GetId().GetValue()); ns == nil { + return api.NewAuthResponse(apimodel.Code_NotFoundNamespace) + } + } + + services := resources.GetServices() + svcCache := svr.cacheMgr.Service() + for index := range services { + val := services[index] + if val.GetId().GetValue() == "*" { + break + } + if svc := svcCache.GetServiceByID(val.GetId().GetValue()); svc == nil { + return api.NewAuthResponse(apimodel.Code_NotFoundService) + } + } + + return nil +} + +func convertPrincipalsToUsers(principals *apisecurity.Principals) []*apisecurity.User { + if principals == nil { + return make([]*apisecurity.User, 0) + } + + users := make([]*apisecurity.User, 0, len(principals.Users)) + for k := range principals.GetUsers() { + user := principals.GetUsers()[k] + users = append(users, &apisecurity.User{ + Id: user.Id, + }) + } + + return users +} + +func convertPrincipalsToGroups(principals *apisecurity.Principals) []*apisecurity.UserGroup { + if principals == nil { + return make([]*apisecurity.UserGroup, 0) + } + + groups := make([]*apisecurity.UserGroup, 0, len(principals.Groups)) + for k := range principals.GetGroups() { + group := principals.GetGroups()[k] + groups = append(groups, &apisecurity.UserGroup{ + Id: group.Id, + }) + } + + return groups +} diff --git a/auth/policy/inteceptor/paramcheck/utils.go b/auth/policy/inteceptor/paramcheck/utils.go new file mode 100644 index 000000000..c055016a2 --- /dev/null +++ b/auth/policy/inteceptor/paramcheck/utils.go @@ -0,0 +1,58 @@ +/** + * Tencent is pleased to support the open source community by making Polaris available. + * + * Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. + * + * Licensed under the BSD 3-Clause License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://opensource.org/licenses/BSD-3-Clause + * + * Unless required by applicable law or agreed to in writing, software distributed + * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR + * CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package paramcheck + +import ( + "errors" + "regexp" + "unicode/utf8" + + "github.com/golang/protobuf/ptypes/wrappers" + + "github.com/polarismesh/polaris/common/utils" +) + +var ( + regNameStr = regexp.MustCompile("^[\u4E00-\u9FA5A-Za-z0-9_\\-.]+$") + regEmail = regexp.MustCompile(`^\w+([-+.]\w+)*@\w+([-.]\w+)*\.\w+([-.]\w+)*$`) +) + +// CheckName 名称检查 +func CheckName(name *wrappers.StringValue) error { + if name == nil { + return errors.New(utils.NilErrString) + } + + if name.GetValue() == "" { + return errors.New(utils.EmptyErrString) + } + + if name.GetValue() == "polariadmin" { + return errors.New("illegal username") + } + + if utf8.RuneCountInString(name.GetValue()) > utils.MaxNameLength { + return errors.New("name too long") + } + + if ok := regNameStr.MatchString(name.GetValue()); !ok { + return errors.New("name contains invalid character") + } + + return nil +} diff --git a/auth/policy/main_test.go b/auth/policy/main_test.go deleted file mode 100644 index 4d5302ba7..000000000 --- a/auth/policy/main_test.go +++ /dev/null @@ -1,215 +0,0 @@ -/** - * Tencent is pleased to support the open source community by making Polaris available. - * - * Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. - * - * Licensed under the BSD 3-Clause License (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://opensource.org/licenses/BSD-3-Clause - * - * Unless required by applicable law or agreed to in writing, software distributed - * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR - * CONDITIONS OF ANY KIND, either express or implied. See the License for the - * specific language governing permissions and limitations under the License. - */ - -package policy_test - -import ( - "errors" - - _ "github.com/go-sql-driver/mysql" - bolt "go.etcd.io/bbolt" - - "github.com/polarismesh/polaris/auth" - "github.com/polarismesh/polaris/cache" - _ "github.com/polarismesh/polaris/cache" - api "github.com/polarismesh/polaris/common/api/v1" - commonlog "github.com/polarismesh/polaris/common/log" - "github.com/polarismesh/polaris/namespace" - "github.com/polarismesh/polaris/plugin" - _ "github.com/polarismesh/polaris/plugin/cmdb/memory" - _ "github.com/polarismesh/polaris/plugin/discoverevent/local" - _ "github.com/polarismesh/polaris/plugin/healthchecker/memory" - _ "github.com/polarismesh/polaris/plugin/healthchecker/redis" - _ "github.com/polarismesh/polaris/plugin/history/logger" - _ "github.com/polarismesh/polaris/plugin/password" - _ "github.com/polarismesh/polaris/plugin/ratelimit/token" - _ "github.com/polarismesh/polaris/plugin/statis/logger" - _ "github.com/polarismesh/polaris/plugin/statis/prometheus" - "github.com/polarismesh/polaris/service/healthcheck" - "github.com/polarismesh/polaris/store" - "github.com/polarismesh/polaris/store/boltdb" - _ "github.com/polarismesh/polaris/store/boltdb" - _ "github.com/polarismesh/polaris/store/mysql" - sqldb "github.com/polarismesh/polaris/store/mysql" - testsuit "github.com/polarismesh/polaris/test/suit" -) - -const ( - tblUser string = "user" - tblStrategy string = "strategy" - tblGroup string = "group" -) - -type Bootstrap struct { - Logger map[string]*commonlog.Options -} - -type TestConfig struct { - Bootstrap Bootstrap `yaml:"bootstrap"` - Cache cache.Config `yaml:"cache"` - Namespace namespace.Config `yaml:"namespace"` - HealthChecks healthcheck.Config `yaml:"healthcheck"` - Store store.Config `yaml:"store"` - Auth auth.Config `yaml:"auth"` - Plugin plugin.Config `yaml:"plugin"` -} - -type AuthTestSuit struct { - testsuit.DiscoverTestSuit -} - -// 判断一个resp是否执行成功 -func respSuccess(resp api.ResponseMessage) bool { - - ret := api.CalcCode(resp) == 200 - - return ret -} - -type options func(cfg *TestConfig) - -func (d *AuthTestSuit) cleanAllUser() { - if d.Storage.Name() == sqldb.STORENAME { - func() { - tx, err := d.Storage.StartTx() - if err != nil { - panic(err) - } - - dbTx := tx.GetDelegateTx().(*sqldb.BaseTx) - - defer dbTx.Rollback() - - if _, err := dbTx.Exec("delete from user where name like 'test%'"); err != nil { - dbTx.Rollback() - panic(err) - } - - dbTx.Commit() - }() - } else if d.Storage.Name() == boltdb.STORENAME { - func() { - tx, err := d.Storage.StartTx() - if err != nil { - panic(err) - } - - dbTx := tx.GetDelegateTx().(*bolt.Tx) - defer dbTx.Rollback() - - if err := dbTx.DeleteBucket([]byte(tblUser)); err != nil { - if !errors.Is(err, bolt.ErrBucketNotFound) { - panic(err) - } - } - - dbTx.Commit() - }() - } -} - -func (d *AuthTestSuit) cleanAllUserGroup() { - if d.Storage.Name() == sqldb.STORENAME { - func() { - tx, err := d.Storage.StartTx() - if err != nil { - panic(err) - } - - dbTx := tx.GetDelegateTx().(*sqldb.BaseTx) - - defer dbTx.Rollback() - - if _, err := dbTx.Exec("delete from user_group where name like 'test%'"); err != nil { - dbTx.Rollback() - panic(err) - } - if _, err := dbTx.Exec("delete from user_group_relation"); err != nil { - dbTx.Rollback() - panic(err) - } - - dbTx.Commit() - }() - } else if d.Storage.Name() == boltdb.STORENAME { - func() { - tx, err := d.Storage.StartTx() - if err != nil { - panic(err) - } - - dbTx := tx.GetDelegateTx().(*bolt.Tx) - defer dbTx.Rollback() - - if err := dbTx.DeleteBucket([]byte(tblGroup)); err != nil { - if !errors.Is(err, bolt.ErrBucketNotFound) { - panic(err) - } - } - - dbTx.Commit() - }() - } -} - -func (d *AuthTestSuit) cleanAllAuthStrategy() { - if d.Storage.Name() == sqldb.STORENAME { - func() { - tx, err := d.Storage.StartTx() - if err != nil { - panic(err) - } - - dbTx := tx.GetDelegateTx().(*sqldb.BaseTx) - - defer dbTx.Rollback() - - if _, err := dbTx.Exec("delete from auth_strategy where id != 'fbca9bfa04ae4ead86e1ecf5811e32a9'"); err != nil { - dbTx.Rollback() - panic(err) - } - if _, err := dbTx.Exec("delete from auth_principal where strategy_id != 'fbca9bfa04ae4ead86e1ecf5811e32a9'"); err != nil { - dbTx.Rollback() - panic(err) - } - if _, err := dbTx.Exec("delete from auth_strategy_resource where strategy_id != 'fbca9bfa04ae4ead86e1ecf5811e32a9'"); err != nil { - dbTx.Rollback() - panic(err) - } - - dbTx.Commit() - }() - } else if d.Storage.Name() == boltdb.STORENAME { - func() { - tx, err := d.Storage.StartTx() - if err != nil { - panic(err) - } - - dbTx := tx.GetDelegateTx().(*bolt.Tx) - defer dbTx.Rollback() - - if err := dbTx.DeleteBucket([]byte(tblStrategy)); err != nil { - if !errors.Is(err, bolt.ErrBucketNotFound) { - panic(err) - } - } - - dbTx.Commit() - }() - } -} diff --git a/auth/policy/role.go b/auth/policy/role.go index 2417883ba..213167334 100644 --- a/auth/policy/role.go +++ b/auth/policy/role.go @@ -1,28 +1,201 @@ +/** + * Tencent is pleased to support the open source community by making Polaris available. + * + * Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. + * + * Licensed under the BSD 3-Clause License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://opensource.org/licenses/BSD-3-Clause + * + * Unless required by applicable law or agreed to in writing, software distributed + * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR + * CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + package policy import ( "context" + "fmt" + "time" + "github.com/gogo/protobuf/jsonpb" + apimodel "github.com/polarismesh/specification/source/go/api/v1/model" apisecurity "github.com/polarismesh/specification/source/go/api/v1/security" apiservice "github.com/polarismesh/specification/source/go/api/v1/service_manage" + "go.uber.org/zap" + + cachetypes "github.com/polarismesh/polaris/cache/api" + api "github.com/polarismesh/polaris/common/api/v1" + "github.com/polarismesh/polaris/common/model" + authcommon "github.com/polarismesh/polaris/common/model/auth" + commonstore "github.com/polarismesh/polaris/common/store" + "github.com/polarismesh/polaris/common/utils" ) // CreateRoles 批量创建角色 func (svr *Server) CreateRoles(ctx context.Context, reqs []*apisecurity.Role) *apiservice.BatchWriteResponse { - return nil + responses := api.NewBatchWriteResponse(apimodel.Code_ExecuteSuccess) + for i := range reqs { + rsp := svr.CreateRole(ctx, reqs[i]) + api.Collect(responses, rsp) + } + return api.FormatBatchWriteResponse(responses) +} + +// CreateRole 创建角色 +func (svr *Server) CreateRole(ctx context.Context, req *apisecurity.Role) *apiservice.Response { + req.Owner = utils.ParseOwnerID(ctx) + + saveData := &authcommon.Role{} + saveData.FromSpec(req) + + if err := svr.storage.AddRole(saveData); err != nil { + log.Error("[Auth][Role] create role into store", utils.RequestID(ctx), + zap.Error(err)) + return api.NewAuthResponse(commonstore.StoreCode2APICode(err)) + } + + return api.NewResponse(apimodel.Code_ExecuteSuccess) } // UpdateRoles 批量更新角色 func (svr *Server) UpdateRoles(ctx context.Context, reqs []*apisecurity.Role) *apiservice.BatchWriteResponse { - return nil + responses := api.NewBatchWriteResponse(apimodel.Code_ExecuteSuccess) + for i := range reqs { + rsp := svr.UpdateRole(ctx, reqs[i]) + api.Collect(responses, rsp) + } + return api.FormatBatchWriteResponse(responses) +} + +// UpdateRole 批量更新角色 +func (svr *Server) UpdateRole(ctx context.Context, req *apisecurity.Role) *apiservice.Response { + newData := &authcommon.Role{} + newData.FromSpec(req) + + saveData, err := svr.storage.GetRole(newData.ID) + if err != nil { + log.Error("[Auth][Role] get one role from store", utils.RequestID(ctx), + zap.Error(err)) + return api.NewAuthResponse(commonstore.StoreCode2APICode(err)) + } + if saveData == nil { + log.Error("[Auth][Role] not find expect role", utils.RequestID(ctx), + zap.String("id", newData.ID)) + return api.NewAuthResponse(apimodel.Code_NotFoundResource) + } + + newData.Name = saveData.Name + newData.Owner = saveData.Owner + + if err := svr.storage.AddRole(newData); err != nil { + log.Error("[Auth][Role] update role into store", utils.RequestID(ctx), + zap.Error(err)) + return api.NewAuthResponse(commonstore.StoreCode2APICode(err)) + } + + return api.NewResponse(apimodel.Code_ExecuteSuccess) } // DeleteRoles 批量删除角色 func (svr *Server) DeleteRoles(ctx context.Context, reqs []*apisecurity.Role) *apiservice.BatchWriteResponse { - return nil + responses := api.NewBatchWriteResponse(apimodel.Code_ExecuteSuccess) + for i := range reqs { + rsp := svr.DeleteRole(ctx, reqs[i]) + api.Collect(responses, rsp) + } + return api.FormatBatchWriteResponse(responses) +} + +// DeleteRole 批量删除角色 +func (svr *Server) DeleteRole(ctx context.Context, req *apisecurity.Role) *apiservice.Response { + newData := &authcommon.Role{} + newData.FromSpec(req) + + saveData, err := svr.storage.GetRole(newData.ID) + if err != nil { + log.Error("[Auth][Role] get one role from store", utils.RequestID(ctx), + zap.Error(err)) + return api.NewAuthResponse(commonstore.StoreCode2APICode(err)) + } + if saveData == nil { + return api.NewAuthResponse(apimodel.Code_ExecuteSuccess) + } + + tx, err := svr.storage.StartTx() + if err != nil { + log.Error("[Auth][Role] start tx", utils.RequestID(ctx), zap.Error(err)) + return api.NewAuthResponse(commonstore.StoreCode2APICode(err)) + } + defer func() { + _ = tx.Rollback() + }() + + if err := svr.storage.DeleteRole(tx, newData); err != nil { + log.Error("[Auth][Role] update role into store", utils.RequestID(ctx), + zap.Error(err)) + return api.NewAuthResponse(commonstore.StoreCode2APICode(err)) + } + if err := svr.storage.CleanPrincipalPolicies(tx, authcommon.Principal{ + PrincipalID: saveData.ID, + PrincipalType: authcommon.PrincipalRole, + }); err != nil { + log.Error("[Auth][Role] clean role link policies", utils.RequestID(ctx), + zap.Error(err)) + return api.NewAuthResponse(commonstore.StoreCode2APICode(err)) + } + + if err := tx.Commit(); err != nil { + log.Error("[Auth][Role] delete role commit tx", utils.RequestID(ctx), zap.Error(err)) + return api.NewAuthResponse(commonstore.StoreCode2APICode(err)) + } + + return api.NewResponse(apimodel.Code_ExecuteSuccess) } // GetRoles 查询角色列表 -func (svr *Server) GetRoles(ctx context.Context, query map[string]string) *apiservice.BatchQueryResponse { - return nil +func (svr *Server) GetRoles(ctx context.Context, filters map[string]string) *apiservice.BatchQueryResponse { + offset, limit, _ := utils.ParseOffsetAndLimit(filters) + + total, ret, err := svr.cacheMgr.Role().Query(ctx, cachetypes.RoleSearchArgs{ + Filters: filters, + Offset: offset, + Limit: limit, + }) + if err != nil { + log.Error("[Auth][Role] query roles list", utils.RequestID(ctx), zap.Error(err)) + return api.NewBatchQueryResponse(commonstore.StoreCode2APICode(err)) + } + + rsp := api.NewBatchQueryResponse(apimodel.Code_ExecuteSuccess) + rsp.Amount = utils.NewUInt32Value(total) + rsp.Size = utils.NewUInt32Value(uint32(len(ret))) + + for i := range ret { + if err := api.AddAnyDataIntoBatchQuery(rsp, ret[i].ToSpec()); err != nil { + log.Error("[Auth][Role] add role to query list", utils.RequestID(ctx), zap.Error(err)) + return api.NewBatchQueryResponse(apimodel.Code_ExecuteException) + } + } + return rsp +} + +func recordRoleEntry(ctx context.Context, req *apisecurity.Role, data *authcommon.Role, op model.OperationType) *model.RecordEntry { + marshaler := jsonpb.Marshaler{} + detail, _ := marshaler.MarshalToString(req) + + entry := &model.RecordEntry{ + ResourceType: model.RAuthRole, + ResourceName: fmt.Sprintf("%s(%s)", data.Name, data.ID), + OperationType: op, + Operator: utils.ParseOperator(ctx), + Detail: detail, + HappenTime: time.Now(), + } + + return entry } diff --git a/auth/policy/server.go b/auth/policy/server.go index 5603f51c0..fffcae44d 100644 --- a/auth/policy/server.go +++ b/auth/policy/server.go @@ -39,6 +39,8 @@ import ( // AuthConfig 鉴权配置 type AuthConfig struct { + // Compatible 兼容模式 + Compatible bool `json:"compatible" xml:"compatible"` // ConsoleOpen 控制台是否开启鉴权 ConsoleOpen bool `json:"consoleOpen" xml:"consoleOpen"` // ClientOpen 是否开启客户端接口鉴权 @@ -52,13 +54,13 @@ type AuthConfig struct { ClientStrict bool `json:"clientStrict"` // CredibleHeaders 可信请求 Header CredibleHeaders map[string]string - // OpenPrincipalDefaultPolicy 是否开启 principal 默认策略 - OpenPrincipalDefaultPolicy bool `json:"openPrincipalDefaultPolicy"` } // DefaultAuthConfig 返回一个默认的鉴权配置 func DefaultAuthConfig() *AuthConfig { return &AuthConfig{ + // 针对旧鉴权逻辑的兼容模式 + Compatible: true, // 针对控制台接口,默认开启鉴权操作 ConsoleOpen: true, // 这里默认开启 OpenAPI 的强 Token 检查模式 @@ -104,7 +106,9 @@ func (svr *Server) Initialize(options *auth.Config, storage store.Store, cacheMg checker := &DefaultAuthChecker{ policyMgr: svr, } - checker.Initialize(svr.options, svr.storage, cacheMgr, userSvr) + if err := checker.Initialize(svr.options, svr.storage, cacheMgr, userSvr); err != nil { + return err + } svr.checker = checker return nil } @@ -217,8 +221,7 @@ func (svr *Server) AfterResourceOperation(afterCtx *authcommon.AcquireContext) e log.Info("[Auth][Server] add resource to principal default strategy", zap.Any("resource", afterCtx.GetAttachments()[authcommon.ResourceAttachmentKey]), - zap.Any("add_user", addUserIds), - zap.Any("add_group", addGroupIds), zap.Any("remove_user", removeUserIds), + zap.Any("add_user", addUserIds), zap.Any("add_group", addGroupIds), zap.Any("remove_user", removeUserIds), zap.Any("remove_group", removeGroupIds), ) @@ -241,7 +244,6 @@ func (svr *Server) AfterResourceOperation(afterCtx *authcommon.AcquireContext) e log.Error("[Auth][Server] remove group link resource", zap.Error(err)) return err } - return nil } diff --git a/auth/policy/server_test.go b/auth/policy/server_test.go deleted file mode 100644 index 325113678..000000000 --- a/auth/policy/server_test.go +++ /dev/null @@ -1,302 +0,0 @@ -/** - * Tencent is pleased to support the open source community by making Polaris available. - * - * Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. - * - * Licensed under the BSD 3-Clause License (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://opensource.org/licenses/BSD-3-Clause - * - * Unless required by applicable law or agreed to in writing, software distributed - * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR - * CONDITIONS OF ANY KIND, either express or implied. See the License for the - * specific language governing permissions and limitations under the License. - */ - -package policy_test - -import ( - "errors" - "testing" - - "github.com/golang/mock/gomock" - "github.com/polarismesh/specification/source/go/api/v1/security" - "github.com/stretchr/testify/assert" - "google.golang.org/protobuf/types/known/wrapperspb" - - "github.com/polarismesh/polaris/auth" - authmock "github.com/polarismesh/polaris/auth/mock" - "github.com/polarismesh/polaris/auth/policy" - authcommon "github.com/polarismesh/polaris/common/model/auth" - "github.com/polarismesh/polaris/common/utils" - storemock "github.com/polarismesh/polaris/store/mock" -) - -func Test_AfterResourceOperation(t *testing.T) { - svr := &policy.Server{} - - t.Run("not_need_auth", func(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - mockChecker := authmock.NewMockAuthChecker(ctrl) - svr.MockAuthChecker(mockChecker) - mockChecker.EXPECT().IsOpenClientAuth().Return(false).AnyTimes() - mockChecker.EXPECT().IsOpenConsoleAuth().Return(false).AnyTimes() - - err := svr.AfterResourceOperation(authcommon.NewAcquireContext()) - assert.NoError(t, err) - }) - - t.Run("read_op", func(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - mockChecker := authmock.NewMockAuthChecker(ctrl) - svr.MockAuthChecker(mockChecker) - mockChecker.EXPECT().IsOpenClientAuth().Return(true).AnyTimes() - mockChecker.EXPECT().IsOpenConsoleAuth().Return(false).AnyTimes() - - err := svr.AfterResourceOperation(authcommon.NewAcquireContext( - authcommon.WithOperation(authcommon.Read), - )) - assert.NoError(t, err) - }) - - t.Run("from_client_not_auth", func(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - mockChecker := authmock.NewMockAuthChecker(ctrl) - svr.MockAuthChecker(mockChecker) - mockChecker.EXPECT().IsOpenClientAuth().Return(false).AnyTimes() - mockChecker.EXPECT().IsOpenConsoleAuth().Return(true).AnyTimes() - - err := svr.AfterResourceOperation(authcommon.NewAcquireContext( - authcommon.WithOperation(authcommon.Create), - authcommon.WithFromClient(), - )) - assert.NoError(t, err) - }) - - t.Run("from_console_not_auth", func(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - mockChecker := authmock.NewMockAuthChecker(ctrl) - svr.MockAuthChecker(mockChecker) - mockChecker.EXPECT().IsOpenClientAuth().Return(true).AnyTimes() - mockChecker.EXPECT().IsOpenConsoleAuth().Return(false).AnyTimes() - - err := svr.AfterResourceOperation(authcommon.NewAcquireContext( - authcommon.WithOperation(authcommon.Create), - authcommon.WithFromConsole(), - )) - assert.NoError(t, err) - }) - - t.Run("not_token_detial", func(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - mockChecker := authmock.NewMockAuthChecker(ctrl) - svr.MockAuthChecker(mockChecker) - mockChecker.EXPECT().IsOpenClientAuth().Return(true).AnyTimes() - mockChecker.EXPECT().IsOpenConsoleAuth().Return(false).AnyTimes() - - err := svr.AfterResourceOperation(authcommon.NewAcquireContext( - authcommon.WithOperation(authcommon.Create), - authcommon.WithFromClient(), - )) - assert.NoError(t, err) - }) - - t.Run("invalid_token_detial", func(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - mockChecker := authmock.NewMockAuthChecker(ctrl) - svr.MockAuthChecker(mockChecker) - mockChecker.EXPECT().IsOpenClientAuth().Return(true).AnyTimes() - mockChecker.EXPECT().IsOpenConsoleAuth().Return(false).AnyTimes() - - ctx := authcommon.NewAcquireContext( - authcommon.WithOperation(authcommon.Create), - authcommon.WithFromClient(), - ) - ctx.SetAttachment(authcommon.TokenDetailInfoKey, map[string]string{}) - err := svr.AfterResourceOperation(ctx) - assert.NoError(t, err) - }) - - t.Run("empty_token_detial", func(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - mockChecker := authmock.NewMockAuthChecker(ctrl) - svr.MockAuthChecker(mockChecker) - mockChecker.EXPECT().IsOpenClientAuth().Return(true).AnyTimes() - mockChecker.EXPECT().IsOpenConsoleAuth().Return(false).AnyTimes() - - ctx := authcommon.NewAcquireContext( - authcommon.WithOperation(authcommon.Create), - authcommon.WithFromClient(), - ) - t.Run("origin_empty", func(t *testing.T) { - ctx.SetAttachment(authcommon.TokenDetailInfoKey, auth.OperatorInfo{ - Origin: "", - }) - err := svr.AfterResourceOperation(ctx) - assert.NoError(t, err) - }) - - t.Run("is_anonymous", func(t *testing.T) { - ctx.SetAttachment(authcommon.TokenDetailInfoKey, auth.OperatorInfo{ - Origin: "123", - Anonymous: true, - }) - err := svr.AfterResourceOperation(ctx) - assert.NoError(t, err) - }) - }) - - t.Run("change_principal_policy", func(t *testing.T) { - ctrl := gomock.NewController(t) - mockChecker := authmock.NewMockAuthChecker(ctrl) - svr.MockAuthChecker(mockChecker) - mockChecker.EXPECT().IsOpenClientAuth().Return(true).AnyTimes() - mockChecker.EXPECT().IsOpenConsoleAuth().Return(true).AnyTimes() - - ctx := authcommon.NewAcquireContext( - authcommon.WithOperation(authcommon.Create), - authcommon.WithFromClient(), - ) - - ownerId := "mock_auth_owner" - curUserId := "123" - - t.Run("user", func(t *testing.T) { - ctx.SetAttachment(authcommon.TokenDetailInfoKey, auth.OperatorInfo{ - Origin: curUserId, - OperatorID: curUserId, - OwnerID: ownerId, - Role: authcommon.OwnerUserRole, - IsUserToken: true, - }) - - initMockAcquireContext(ctx) - - t.Run("not_found_user", func(t *testing.T) { - userSvr := authmock.NewMockUserServer(ctrl) - mockHelper := authmock.NewMockUserHelper(ctrl) - - userSvr.EXPECT().GetUserHelper().Return(mockHelper) - mockHelper.EXPECT().GetUser(gomock.Any(), gomock.Any()).Return(nil) - - svr.MockUserServer(userSvr) - - err := svr.AfterResourceOperation(ctx) - assert.Error(t, err) - assert.Equal(t, "not found target user", err.Error()) - }) - - t.Run("found_user", func(t *testing.T) { - userSvr := authmock.NewMockUserServer(ctrl) - mockHelper := authmock.NewMockUserHelper(ctrl) - - userSvr.EXPECT().GetUserHelper().Return(mockHelper).AnyTimes() - mockHelper.EXPECT().GetUser(gomock.Any(), gomock.Any()).Return(&security.User{ - Id: wrapperspb.String(curUserId), - Owner: wrapperspb.String(ownerId), - }).AnyTimes() - - svr.MockUserServer(userSvr) - - t.Run("store_has_err", func(t *testing.T) { - sctrl := gomock.NewController(t) - defer sctrl.Finish() - mockStore := storemock.NewMockStore(sctrl) - - mockStore.EXPECT().GetDefaultStrategyDetailByPrincipal(gomock.Any(), gomock.Any()).Return(nil, errors.New("mock_err")) - svr.MockStore(mockStore) - - err := svr.AfterResourceOperation(ctx) - assert.Error(t, err) - assert.Equal(t, "mock_err", err.Error()) - }) - - t.Run("not_found_default_policy", func(t *testing.T) { - sctrl := gomock.NewController(t) - defer sctrl.Finish() - mockStore := storemock.NewMockStore(sctrl) - - mockStore.EXPECT().GetDefaultStrategyDetailByPrincipal(gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes() - svr.MockStore(mockStore) - - err := svr.AfterResourceOperation(ctx) - assert.Error(t, err) - assert.Equal(t, "not found default strategy rule", err.Error()) - }) - - t.Run("not_op_resource", func(t *testing.T) { - sctrl := gomock.NewController(t) - defer sctrl.Finish() - mockStore := storemock.NewMockStore(sctrl) - - mockStore.EXPECT().GetDefaultStrategyDetailByPrincipal(gomock.Any(), gomock.Any()).Return(&authcommon.StrategyDetail{}, nil) - svr.MockStore(mockStore) - - err := svr.AfterResourceOperation(ctx) - assert.NoError(t, err) - }) - - t.Run("invalid_op_resource", func(t *testing.T) { - sctrl := gomock.NewController(t) - defer sctrl.Finish() - mockStore := storemock.NewMockStore(sctrl) - - mockStore.EXPECT().GetDefaultStrategyDetailByPrincipal(gomock.Any(), gomock.Any()).Return(&authcommon.StrategyDetail{}, nil) - svr.MockStore(mockStore) - - ctx.SetAttachment(authcommon.ResourceAttachmentKey, map[string]interface{}{}) - - err := svr.AfterResourceOperation(ctx) - assert.NoError(t, err) - }) - - t.Run("delete_resource", func(t *testing.T) { - delCtx := authcommon.NewAcquireContext( - authcommon.WithOperation(authcommon.Delete), - authcommon.WithFromClient(), - ) - delCtx.SetAttachment(authcommon.TokenDetailInfoKey, auth.OperatorInfo{ - Origin: curUserId, - OperatorID: curUserId, - OwnerID: ownerId, - Role: authcommon.OwnerUserRole, - IsUserToken: true, - }) - - initMockAcquireContext(delCtx) - - sctrl := gomock.NewController(t) - defer sctrl.Finish() - mockStore := storemock.NewMockStore(sctrl) - mockStore.EXPECT().RemoveStrategyResources(gomock.Any()).DoAndReturn(func(args interface{}) error { - resources := args.([]authcommon.StrategyResource) - for i := range resources { - assert.True(t, resources[i].StrategyID == "", utils.MustJson(resources[i])) - } - return nil - }).Times(1) - - svr.MockStore(mockStore) - err := svr.AfterResourceOperation(delCtx) - assert.NoError(t, err) - }) - }) - }) - }) -} - -func initMockAcquireContext(ctx *authcommon.AcquireContext) { - ctx.SetAttachment(authcommon.LinkUsersKey, []string{}) - ctx.SetAttachment(authcommon.LinkGroupsKey, []string{}) - ctx.SetAttachment(authcommon.RemoveLinkUsersKey, []string{}) - ctx.SetAttachment(authcommon.RemoveLinkGroupsKey, []string{}) -} diff --git a/auth/policy/strategy.go b/auth/policy/strategy.go index 7e48ed12b..022d39d2c 100644 --- a/auth/policy/strategy.go +++ b/auth/policy/strategy.go @@ -20,6 +20,7 @@ package policy import ( "context" "fmt" + "reflect" "strconv" "strings" "time" @@ -43,16 +44,12 @@ import ( type ( // StrategyDetail2Api strategy detail to *apisecurity.AuthStrategy func - StrategyDetail2Api func(user *authcommon.StrategyDetail) *apisecurity.AuthStrategy + StrategyDetail2Api func(ctx context.Context, user *authcommon.StrategyDetail) *apisecurity.AuthStrategy ) // CreateStrategy 创建鉴权策略 func (svr *Server) CreateStrategy(ctx context.Context, req *apisecurity.AuthStrategy) *apiservice.Response { req.Owner = utils.NewStringValue(utils.ParseOwnerID(ctx)) - if checkErrResp := svr.checkCreateStrategy(req); checkErrResp != nil { - return checkErrResp - } - req.Resources = svr.normalizeResource(req.Resources) data := svr.createAuthStrategyModel(req) @@ -100,11 +97,9 @@ func (svr *Server) UpdateStrategies( // Case 2. 鉴权策略只能被自己的 owner 对应的用户修改 // Case 3. 主账户的默认策略不得修改 func (svr *Server) UpdateStrategy(ctx context.Context, req *apisecurity.ModifyAuthStrategy) *apiservice.Response { - requestID := utils.ParseRequestID(ctx) - strategy, err := svr.storage.GetStrategyDetail(req.GetId().GetValue()) if err != nil { - log.Error("[Auth][Strategy] get strategy from store", utils.ZapRequestID(requestID), + log.Error("[Auth][Strategy] get strategy from store", utils.RequestID(ctx), zap.Error(err)) return api.NewModifyAuthStrategyResponse(commonstore.StoreCode2APICode(err), req) } @@ -112,10 +107,6 @@ func (svr *Server) UpdateStrategy(ctx context.Context, req *apisecurity.ModifyAu return api.NewModifyAuthStrategyResponse(apimodel.Code_NotFoundAuthStrategyRule, req) } - if checkErrResp := svr.checkUpdateStrategy(ctx, req, strategy); checkErrResp != nil { - return checkErrResp - } - req.AddResources = svr.normalizeResource(req.AddResources) data, needUpdate := svr.updateAuthStrategyAttribute(ctx, req, strategy) if !needUpdate { @@ -124,11 +115,11 @@ func (svr *Server) UpdateStrategy(ctx context.Context, req *apisecurity.ModifyAu if err := svr.storage.UpdateStrategy(data); err != nil { log.Error("[Auth][Strategy] update strategy into store", - utils.ZapRequestID(requestID), zap.Error(err)) + utils.RequestID(ctx), zap.Error(err)) return api.NewAuthResponseWithMsg(commonstore.StoreCode2APICode(err), err.Error()) } - log.Info("[Auth][Strategy] update strategy into store", utils.ZapRequestID(requestID), + log.Info("[Auth][Strategy] update strategy into store", utils.RequestID(ctx), zap.String("name", strategy.Name)) svr.RecordHistory(authModifyStrategyRecordEntry(ctx, req, data, model.OUpdate)) @@ -151,11 +142,9 @@ func (svr *Server) DeleteStrategies( // Case 1. 只有该策略的 owner 账户可以删除策略 // Case 2. 默认策略不能被删除,默认策略只能随着账户的删除而被清理 func (svr *Server) DeleteStrategy(ctx context.Context, req *apisecurity.AuthStrategy) *apiservice.Response { - requestID := utils.ParseRequestID(ctx) - strategy, err := svr.storage.GetStrategyDetail(req.GetId().GetValue()) if err != nil { - log.Error("[Auth][Strategy] get strategy from store", utils.ZapRequestID(requestID), + log.Error("[Auth][Strategy] get strategy from store", utils.RequestID(ctx), zap.Error(err)) return api.NewAuthStrategyResponse(commonstore.StoreCode2APICode(err), req) } @@ -165,7 +154,7 @@ func (svr *Server) DeleteStrategy(ctx context.Context, req *apisecurity.AuthStra } if strategy.Default { - log.Error("[Auth][Strategy] delete default strategy is denied", utils.ZapRequestID(requestID)) + log.Error("[Auth][Strategy] delete default strategy is denied", utils.RequestID(ctx)) return api.NewAuthStrategyResponseWithMsg(apimodel.Code_BadRequest, "default strategy can't delete", req) } @@ -175,11 +164,11 @@ func (svr *Server) DeleteStrategy(ctx context.Context, req *apisecurity.AuthStra if err := svr.storage.DeleteStrategy(req.GetId().GetValue()); err != nil { log.Error("[Auth][Strategy] delete strategy from store", - utils.ZapRequestID(requestID), zap.Error(err)) + utils.RequestID(ctx), zap.Error(err)) return api.NewAuthResponse(commonstore.StoreCode2APICode(err)) } - log.Info("[Auth][Strategy] delete strategy from store", utils.ZapRequestID(requestID), + log.Info("[Auth][Strategy] delete strategy from store", utils.RequestID(ctx), zap.String("name", req.Name.GetValue())) svr.RecordHistory(authStrategyRecordEntry(ctx, req, strategy, model.ODelete)) @@ -198,10 +187,13 @@ func (svr *Server) DeleteStrategy(ctx context.Context, req *apisecurity.AuthStra // a. 如果当前是超级管理账户,则按照传入的 query 进行查询即可 // b. 如果当前是主账户,则自动注入 owner 字段,即只能查看策略的 owner 是自己的策略 // c. 如果当前是子账户,则自动注入 principal_id 以及 principal_type 字段,即稚嫩查询与自己有关的策略 - func (svr *Server) GetStrategies(ctx context.Context, filters map[string]string) *apiservice.BatchQueryResponse { filters = ParseStrategySearchArgs(ctx, filters) offset, limit, _ := utils.ParseOffsetAndLimit(filters) + + // 透传兼容模式信息数据 + ctx = context.WithValue(ctx, model.ContextKeyCompatible{}, svr.options.Compatible) + total, strategies, err := svr.cacheMgr.AuthStrategy().Query(ctx, cachetypes.PolicySearchArgs{ Filters: filters, Offset: offset, @@ -219,9 +211,9 @@ func (svr *Server) GetStrategies(ctx context.Context, filters map[string]string) if strings.Compare(filters["show_detail"], "true") == 0 { log.Info("[Auth][Strategy] fill strategy detail", utils.RequestID(ctx)) - resp.AuthStrategies = enhancedAuthStrategy2Api(strategies, svr.authStrategyFull2Api) + resp.AuthStrategies = enhancedAuthStrategy2Api(ctx, strategies, svr.authStrategyFull2Api) } else { - resp.AuthStrategies = enhancedAuthStrategy2Api(strategies, svr.authStrategy2Api) + resp.AuthStrategies = enhancedAuthStrategy2Api(ctx, strategies, svr.authStrategy2Api) } return resp @@ -258,22 +250,6 @@ func ParseStrategySearchArgs(ctx context.Context, searchFilters map[string]strin searchFilters["principal_type"] = "1" } } - - if authcommon.ParseUserRole(ctx) != authcommon.AdminUserRole { - // 如果当前账户不是 admin 角色,既不是走资源视角查看,也不是指定principal查看,那么只能查询当前操作用户被关联到的鉴权策略, - if _, ok := searchFilters["res_id"]; !ok { - // 设置 owner 参数,只能查看对应 owner 下的策略 - searchFilters["owner"] = utils.ParseOwnerID(ctx) - if _, ok := searchFilters["principal_id"]; !ok { - // 如果当前不是 owner 角色,那么只能查询与自己有关的策略 - if !utils.ParseIsOwner(ctx) { - searchFilters["principal_id"] = utils.ParseUserID(ctx) - searchFilters["principal_type"] = strconv.Itoa(int(authcommon.PrincipalUser)) - } - } - } - } - return searchFilters } @@ -336,12 +312,11 @@ func (svr *Server) GetStrategy(ctx context.Context, req *apisecurity.AuthStrateg return api.NewAuthStrategyResponse(apimodel.Code_NotAllowedAccess, req) } - return api.NewAuthStrategyResponse(apimodel.Code_ExecuteSuccess, svr.authStrategyFull2Api(ret)) + return api.NewAuthStrategyResponse(apimodel.Code_ExecuteSuccess, svr.authStrategyFull2Api(ctx, ret)) } // GetPrincipalResources 获取某个principal可以获取到的所有资源ID数据信息 func (svr *Server) GetPrincipalResources(ctx context.Context, query map[string]string) *apiservice.Response { - requestID := utils.ParseRequestID(ctx) if len(query) == 0 { return api.NewAuthResponse(apimodel.Code_EmptyRequest) } @@ -377,7 +352,7 @@ func (svr *Server) GetPrincipalResources(ctx context.Context, query map[string]s item := groups[i] res, err := svr.storage.GetStrategyResources(item.GetId().GetValue(), authcommon.PrincipalGroup) if err != nil { - log.Error("[Auth][Strategy] get principal link resource", utils.ZapRequestID(requestID), + log.Error("[Auth][Strategy] get principal link resource", utils.RequestID(ctx), zap.String("principal-id", principalId), zap.Any("principal-role", principalRole), zap.Error(err)) return api.NewAuthResponse(commonstore.StoreCode2APICode(err)) } @@ -387,21 +362,17 @@ func (svr *Server) GetPrincipalResources(ctx context.Context, query map[string]s pResources, err := svr.storage.GetStrategyResources(principalId, authcommon.PrincipalType(principalRole)) if err != nil { - log.Error("[Auth][Strategy] get principal link resource", utils.ZapRequestID(requestID), + log.Error("[Auth][Strategy] get principal link resource", utils.RequestID(ctx), zap.String("principal-id", principalId), zap.Any("principal-role", principalRole), zap.Error(err)) return api.NewAuthResponse(commonstore.StoreCode2APICode(err)) } resources = append(resources, pResources...) tmp := &apisecurity.AuthStrategy{ - Resources: &apisecurity.StrategyResources{ - Namespaces: make([]*apisecurity.StrategyResourceEntry, 0), - Services: make([]*apisecurity.StrategyResourceEntry, 0), - ConfigGroups: make([]*apisecurity.StrategyResourceEntry, 0), - }, + Resources: &apisecurity.StrategyResources{}, } - svr.fillResourceInfo(tmp, &authcommon.StrategyDetail{ + svr.enrichResourceInfo(ctx, tmp, &authcommon.StrategyDetail{ Resources: resourceDeduplication(resources), }) @@ -409,16 +380,17 @@ func (svr *Server) GetPrincipalResources(ctx context.Context, query map[string]s } // enhancedAuthStrategy2Api -func enhancedAuthStrategy2Api(s []*authcommon.StrategyDetail, fn StrategyDetail2Api) []*apisecurity.AuthStrategy { +func enhancedAuthStrategy2Api(ctx context.Context, s []*authcommon.StrategyDetail, + fn StrategyDetail2Api) []*apisecurity.AuthStrategy { out := make([]*apisecurity.AuthStrategy, 0, len(s)) for k := range s { - out = append(out, fn(s[k])) + out = append(out, fn(ctx, s[k])) } return out } // authStrategy2Api -func (svr *Server) authStrategy2Api(s *authcommon.StrategyDetail) *apisecurity.AuthStrategy { +func (svr *Server) authStrategy2Api(ctx context.Context, s *authcommon.StrategyDetail) *apisecurity.AuthStrategy { if s == nil { return nil } @@ -439,7 +411,7 @@ func (svr *Server) authStrategy2Api(s *authcommon.StrategyDetail) *apisecurity.A } // authStrategyFull2Api -func (svr *Server) authStrategyFull2Api(data *authcommon.StrategyDetail) *apisecurity.AuthStrategy { +func (svr *Server) authStrategyFull2Api(ctx context.Context, data *authcommon.StrategyDetail) *apisecurity.AuthStrategy { if data == nil { return nil } @@ -465,36 +437,29 @@ func (svr *Server) authStrategyFull2Api(data *authcommon.StrategyDetail) *apisec Mtime: utils.NewStringValue(commontime.Time2String(data.ModifyTime)), Action: apisecurity.AuthAction(apisecurity.AuthAction_value[data.Action]), DefaultStrategy: utils.NewBoolValue(data.Default), + Functions: data.CalleeMethods, + Metadata: data.Metadata, } - svr.fillPrincipalInfo(out, data) - svr.fillResourceInfo(out, data) + svr.enrichPrincipalInfo(out, data) + svr.enrichResourceInfo(ctx, out, data) return out } // createAuthStrategyModel 创建鉴权策略的存储模型 func (svr *Server) createAuthStrategyModel(strategy *apisecurity.AuthStrategy) *authcommon.StrategyDetail { - ret := &authcommon.StrategyDetail{ - ID: utils.NewUUID(), - Name: strategy.Name.GetValue(), - Action: apisecurity.AuthAction_READ_WRITE.String(), - Comment: strategy.Comment.GetValue(), - Default: false, - Owner: strategy.Owner.GetValue(), - Valid: true, - Revision: utils.NewUUID(), - CreateTime: time.Now(), - ModifyTime: time.Now(), - } + ret := &authcommon.StrategyDetail{} + ret.FromSpec(strategy) // 收集涉及的资源信息 resEntry := make([]authcommon.StrategyResource, 0, 20) - resEntry = append(resEntry, svr.collectResEntry(ret.ID, apisecurity.ResourceType_Namespaces, - strategy.GetResources().GetNamespaces(), false)...) - resEntry = append(resEntry, svr.collectResEntry(ret.ID, apisecurity.ResourceType_Services, - strategy.GetResources().GetServices(), false)...) - resEntry = append(resEntry, svr.collectResEntry(ret.ID, apisecurity.ResourceType_ConfigGroups, - strategy.GetResources().GetConfigGroups(), false)...) + for resType, ptrGetter := range resourceFieldPointerGetters { + slicePtr := ptrGetter(strategy.Resources) + if slicePtr.Elem().IsNil() { + continue + } + resEntry = append(resEntry, svr.collectResourceEntry(ret.ID, resType, slicePtr.Elem(), false)...) + } // 收集涉及的 principal 信息 principals := make([]authcommon.Principal, 0, 20) @@ -502,6 +467,8 @@ func (svr *Server) createAuthStrategyModel(strategy *apisecurity.AuthStrategy) * strategy.GetPrincipals().GetUsers())...) principals = append(principals, collectPrincipalEntry(ret.ID, authcommon.PrincipalGroup, strategy.GetPrincipals().GetGroups())...) + principals = append(principals, collectPrincipalEntry(ret.ID, authcommon.PrincipalRole, + strategy.GetPrincipals().GetRoles())...) ret.Resources = resEntry ret.Principals = principals @@ -514,11 +481,14 @@ func (svr *Server) updateAuthStrategyAttribute(ctx context.Context, strategy *ap saved *authcommon.StrategyDetail) (*authcommon.ModifyStrategyDetail, bool) { var needUpdate bool ret := &authcommon.ModifyStrategyDetail{ - ID: strategy.Id.GetValue(), - Name: saved.Name, - Action: saved.Action, - Comment: saved.Comment, - ModifyTime: time.Now(), + ID: strategy.Id.GetValue(), + Name: saved.Name, + Action: saved.Action, + Comment: saved.Comment, + ModifyTime: time.Now(), + CalleeMethods: saved.CalleeMethods, + Conditions: saved.Conditions, + Metadata: saved.Metadata, } // 只有 owner 可以修改的属性 @@ -540,6 +510,32 @@ func (svr *Server) updateAuthStrategyAttribute(ctx context.Context, strategy *ap if computePrincipalChange(ret, strategy) { needUpdate = true } + if strategy.Functions != nil { + needUpdate = true + ret.CalleeMethods = strategy.Functions + } + if strategy.Metadata != nil { + needUpdate = true + ret.Metadata = strategy.Metadata + } + if strategy.Action != saved.GetAction() { + needUpdate = true + ret.Action = strategy.GetAction().String() + } + if strategy.ResourceLabels != nil { + needUpdate = true + ret.Conditions = func() []authcommon.Condition { + conditions := make([]authcommon.Condition, 0, len(strategy.GetResourceLabels())) + for index := range strategy.GetResourceLabels() { + conditions = append(conditions, authcommon.Condition{ + Key: strategy.GetResourceLabels()[index].GetKey(), + Value: strategy.GetResourceLabels()[index].GetValue(), + CompareFunc: strategy.GetResourceLabels()[index].GetCompareType(), + }) + } + return conditions + }() + } return ret, needUpdate } @@ -548,13 +544,16 @@ func (svr *Server) updateAuthStrategyAttribute(ctx context.Context, strategy *ap func (svr *Server) computeResourceChange( modify *authcommon.ModifyStrategyDetail, strategy *apisecurity.ModifyAuthStrategy) bool { var needUpdate bool + + // 收集涉及的资源信息 addResEntry := make([]authcommon.StrategyResource, 0) - addResEntry = append(addResEntry, svr.collectResEntry(modify.ID, apisecurity.ResourceType_Namespaces, - strategy.GetAddResources().GetNamespaces(), false)...) - addResEntry = append(addResEntry, svr.collectResEntry(modify.ID, apisecurity.ResourceType_Services, - strategy.GetAddResources().GetServices(), false)...) - addResEntry = append(addResEntry, svr.collectResEntry(modify.ID, apisecurity.ResourceType_ConfigGroups, - strategy.GetAddResources().GetConfigGroups(), false)...) + for resType, ptrGetter := range resourceFieldPointerGetters { + slicePtr := ptrGetter(strategy.AddResources) + if slicePtr.Elem().IsNil() { + continue + } + addResEntry = append(addResEntry, svr.collectResourceEntry(modify.ID, resType, slicePtr.Elem(), false)...) + } if len(addResEntry) != 0 { needUpdate = true @@ -562,12 +561,13 @@ func (svr *Server) computeResourceChange( } removeResEntry := make([]authcommon.StrategyResource, 0) - removeResEntry = append(removeResEntry, svr.collectResEntry(modify.ID, apisecurity.ResourceType_Namespaces, - strategy.GetRemoveResources().GetNamespaces(), true)...) - removeResEntry = append(removeResEntry, svr.collectResEntry(modify.ID, apisecurity.ResourceType_Services, - strategy.GetRemoveResources().GetServices(), true)...) - removeResEntry = append(removeResEntry, svr.collectResEntry(modify.ID, apisecurity.ResourceType_ConfigGroups, - strategy.GetRemoveResources().GetConfigGroups(), true)...) + for resType, ptrGetter := range resourceFieldPointerGetters { + slicePtr := ptrGetter(strategy.RemoveResources) + if slicePtr.Elem().IsNil() { + continue + } + removeResEntry = append(removeResEntry, svr.collectResourceEntry(modify.ID, resType, slicePtr.Elem(), true)...) + } if len(removeResEntry) != 0 { needUpdate = true @@ -585,6 +585,8 @@ func computePrincipalChange(modify *authcommon.ModifyStrategyDetail, strategy *a strategy.GetAddPrincipals().GetUsers())...) addPrincipals = append(addPrincipals, collectPrincipalEntry(modify.ID, authcommon.PrincipalGroup, strategy.GetAddPrincipals().GetGroups())...) + addPrincipals = append(addPrincipals, collectPrincipalEntry(modify.ID, authcommon.PrincipalRole, + strategy.GetAddPrincipals().GetRoles())...) if len(addPrincipals) != 0 { needUpdate = true @@ -596,6 +598,8 @@ func computePrincipalChange(modify *authcommon.ModifyStrategyDetail, strategy *a strategy.GetRemovePrincipals().GetUsers())...) removePrincipals = append(removePrincipals, collectPrincipalEntry(modify.ID, authcommon.PrincipalGroup, strategy.GetRemovePrincipals().GetGroups())...) + removePrincipals = append(removePrincipals, collectPrincipalEntry(modify.ID, authcommon.PrincipalRole, + strategy.GetRemovePrincipals().GetRoles())...) if len(removePrincipals) != 0 { needUpdate = true @@ -605,19 +609,26 @@ func computePrincipalChange(modify *authcommon.ModifyStrategyDetail, strategy *a return needUpdate } +type pbStringValue interface { + GetValue() string +} + // collectResEntry 将资源ID转换为对应的 []authcommon.StrategyResource 数组 -func (svr *Server) collectResEntry(ruleId string, resType apisecurity.ResourceType, - res []*apisecurity.StrategyResourceEntry, delete bool) []authcommon.StrategyResource { - resEntries := make([]authcommon.StrategyResource, 0, len(res)+1) - if len(res) == 0 { - return resEntries +func (svr *Server) collectResourceEntry(ruleId string, resType apisecurity.ResourceType, + res reflect.Value, delete bool) []authcommon.StrategyResource { + if res.Kind() != reflect.Slice || res.Len() == 0 { + return []authcommon.StrategyResource{} } - for index := range res { + resEntries := make([]authcommon.StrategyResource, 0, res.Len()) + for i := 0; i < res.Len(); i++ { + item := res.Index(i).Elem() + resId := item.FieldByName("Id").Interface().(pbStringValue) + resName := item.FieldByName("Name").Interface().(pbStringValue) // 如果是添加的动作,那么需要进行归一化处理 if !delete { // 归一化处理 - if res[index].GetId().GetValue() == "*" || res[index].GetName().GetValue() == "*" { + if resId.GetValue() == "*" || resName.GetValue() == "*" { return []authcommon.StrategyResource{ { StrategyID: ruleId, @@ -631,7 +642,7 @@ func (svr *Server) collectResEntry(ruleId string, resType apisecurity.ResourceTy entry := authcommon.StrategyResource{ StrategyID: ruleId, ResType: int32(resType), - ResID: res[index].GetId().GetValue(), + ResID: resId.GetValue(), } resEntries = append(resEntries, entry) @@ -640,94 +651,165 @@ func (svr *Server) collectResEntry(ruleId string, resType apisecurity.ResourceTy return resEntries } -// collectPrincipalEntry 将 Principal 转换为对应的 []authcommon.Principal 数组 -func collectPrincipalEntry(ruleID string, uType authcommon.PrincipalType, res []*apisecurity.Principal) []authcommon.Principal { - principals := make([]authcommon.Principal, 0, len(res)+1) - if len(res) == 0 { - return principals +// normalizeResource 对于资源进行归一化处理, 如果出现 * 的话,则该资源访问策略就是 * +func (svr *Server) normalizeResource(resources *apisecurity.StrategyResources) *apisecurity.StrategyResources { + if resources == nil { + return &apisecurity.StrategyResources{} } - - for index := range res { - principals = append(principals, authcommon.Principal{ - StrategyID: ruleID, - PrincipalID: res[index].GetId().GetValue(), - PrincipalType: uType, - }) + for _, ptrGetter := range resourceFieldPointerGetters { + slicePtr := ptrGetter(resources) + if slicePtr.Elem().IsNil() { + continue + } + sliceVal := slicePtr.Elem() + for i := 0; i < sliceVal.Len(); i++ { + item := sliceVal.Index(i).Elem() + resId := item.FieldByName("Id").Interface().(pbStringValue) + if resId.GetValue() == utils.MatchAll { + sliceVal.Set(reflect.ValueOf([]*apisecurity.StrategyResourceEntry{{ + Id: utils.NewStringValue("*"), + }})) + } + } } - - return principals + return resources } -// checkCreateStrategy 检查创建鉴权策略的请求 -func (svr *Server) checkCreateStrategy(req *apisecurity.AuthStrategy) *apiservice.Response { - // 检查名称信息 - if err := CheckName(req.GetName()); err != nil { - return api.NewAuthStrategyResponse(apimodel.Code_InvalidUserName, req) +// enrichPrincipalInfo 填充 principal 摘要信息 +func (svr *Server) enrichPrincipalInfo(resp *apisecurity.AuthStrategy, data *authcommon.StrategyDetail) { + users := make([]*apisecurity.Principal, 0, len(data.Principals)) + groups := make([]*apisecurity.Principal, 0, len(data.Principals)) + roles := make([]*apisecurity.Principal, 0, len(data.Principals)) + for index := range data.Principals { + principal := data.Principals[index] + switch principal.PrincipalType { + case authcommon.PrincipalUser: + if user := svr.userSvr.GetUserHelper().GetUser(context.TODO(), &apisecurity.User{ + Id: wrapperspb.String(principal.PrincipalID), + }); user != nil { + users = append(users, &apisecurity.Principal{ + Id: utils.NewStringValue(user.GetId().GetValue()), + Name: utils.NewStringValue(user.GetName().GetValue()), + }) + } + case authcommon.PrincipalGroup: + if group := svr.userSvr.GetUserHelper().GetGroup(context.TODO(), &apisecurity.UserGroup{ + Id: wrapperspb.String(principal.PrincipalID), + }); group != nil { + groups = append(groups, &apisecurity.Principal{ + Id: utils.NewStringValue(group.GetId().GetValue()), + Name: utils.NewStringValue(group.GetName().GetValue()), + }) + } + case authcommon.PrincipalRole: + if role := svr.PolicyHelper().GetRole(principal.PrincipalID); role != nil { + roles = append(roles, &apisecurity.Principal{ + Id: utils.NewStringValue(role.ID), + Name: utils.NewStringValue(role.Name), + }) + } + } } - // 检查用户是否存在 - if err := svr.checkUserExist(convertPrincipalsToUsers(req.GetPrincipals())); err != nil { - return api.NewAuthStrategyResponse(apimodel.Code_NotFoundUser, req) + + resp.Principals = &apisecurity.Principals{ + Users: users, + Groups: groups, + Roles: roles, } - // 检查用户组是否存在 - if err := svr.checkGroupExist(convertPrincipalsToGroups(req.GetPrincipals())); err != nil { - return api.NewAuthStrategyResponse(apimodel.Code_NotFoundUserGroup, req) +} + +// enrichResourceInfo 填充资源摘要信息 +func (svr *Server) enrichResourceInfo(ctx context.Context, resp *apisecurity.AuthStrategy, data *authcommon.StrategyDetail) { + allMatch := map[apisecurity.ResourceType]struct{}{} + resp.Resources = &apisecurity.StrategyResources{ + Namespaces: make([]*apisecurity.StrategyResourceEntry, 0, 4), + ConfigGroups: make([]*apisecurity.StrategyResourceEntry, 0, 4), + Services: make([]*apisecurity.StrategyResourceEntry, 0, 4), + RouteRules: make([]*apisecurity.StrategyResourceEntry, 0, 4), + RatelimitRules: make([]*apisecurity.StrategyResourceEntry, 0, 4), + CircuitbreakerRules: make([]*apisecurity.StrategyResourceEntry, 0, 4), + FaultdetectRules: make([]*apisecurity.StrategyResourceEntry, 0, 4), + LaneRules: make([]*apisecurity.StrategyResourceEntry, 0, 4), + Users: make([]*apisecurity.StrategyResourceEntry, 0, 4), + UserGroups: make([]*apisecurity.StrategyResourceEntry, 0, 4), + Roles: make([]*apisecurity.StrategyResourceEntry, 0, 4), + AuthPolicies: make([]*apisecurity.StrategyResourceEntry, 0, 4), } - // 检查资源是否存在 - if errResp := svr.checkResourceExist(req.GetResources()); errResp != nil { - return errResp + + for index := range data.Resources { + res := data.Resources[index] + svr.enrichResourceDetial(ctx, res, allMatch, resp) } - return nil } -// checkUpdateStrategy 检查更新鉴权策略的请求 -// Case 1. 修改的是默认鉴权策略的话,只能修改资源,不能添加用户 or 用户组 -// Case 2. 鉴权策略只能被自己的 owner 对应的用户修改 -func (svr *Server) checkUpdateStrategy(ctx context.Context, req *apisecurity.ModifyAuthStrategy, - saved *authcommon.StrategyDetail) *apiservice.Response { - userId := utils.ParseUserID(ctx) - if authcommon.ParseUserRole(ctx) != authcommon.AdminUserRole { - if !utils.ParseIsOwner(ctx) || userId != saved.Owner { - log.Error("[Auth][Strategy] modify strategy denied, current user not owner", - utils.ZapRequestID(utils.ParseRequestID(ctx)), - zap.String("user", userId), - zap.String("owner", saved.Owner), - zap.String("strategy", saved.ID)) - return api.NewModifyAuthStrategyResponse(apimodel.Code_NotAllowedAccess, req) +func (svr *Server) enrichResourceDetial(ctx context.Context, item authcommon.StrategyResource, + allMatch map[apisecurity.ResourceType]struct{}, resp *apisecurity.AuthStrategy) { + + resType := apisecurity.ResourceType(item.ResType) + slicePtr := resourceFieldPointerGetters[resType](resp.Resources) + if slicePtr.Elem().IsNil() { + return + } + sliceVal := slicePtr.Elem() + + if item.ResID == "*" { + allMatch[resType] = struct{}{} + sliceVal.Set(reflect.ValueOf([]*apisecurity.StrategyResourceEntry{ + { + Id: utils.NewStringValue("*"), + Namespace: utils.NewStringValue("*"), + Name: utils.NewStringValue("*"), + }, + })) + return + } + if _, ok := allMatch[resType]; !ok { + if data := resourceConvert[resType](ctx, svr, item); data != nil { + // 创建一个新数组并把元素的值追加进去 + resArr := reflect.Append(sliceVal, reflect.ValueOf(data)) + sliceVal.Set(resArr) } } +} - if saved.Default { - if len(req.AddPrincipals.Users) != 0 || - len(req.AddPrincipals.Groups) != 0 || - len(req.RemovePrincipals.Groups) != 0 || - len(req.RemovePrincipals.Users) != 0 { - return api.NewModifyAuthStrategyResponse(apimodel.Code_NotAllowModifyDefaultStrategyPrincipal, req) - } +// filter different types of Strategy resources +func resourceDeduplication(resources []authcommon.StrategyResource) []authcommon.StrategyResource { + rLen := len(resources) + ret := make([]authcommon.StrategyResource, 0, rLen) + filters := map[apisecurity.ResourceType]map[string]struct{}{} - // 主账户的默认策略禁止编辑 - if len(saved.Principals) == 1 && saved.Principals[0].PrincipalType == authcommon.PrincipalUser { - if saved.Principals[0].PrincipalID == utils.ParseOwnerID(ctx) { - return api.NewAuthResponse(apimodel.Code_NotAllowModifyOwnerDefaultStrategy) - } + est := struct{}{} + for i := range resources { + res := resources[i] + filter, ok := filters[apisecurity.ResourceType(res.ResType)] + if !ok { + filters[apisecurity.ResourceType(res.ResType)] = map[string]struct{}{} + filter = filters[apisecurity.ResourceType(res.ResType)] + } + if _, exist := filter[res.ResID]; !exist { + filter[res.ResID] = est + ret = append(ret, res) } } + return ret +} - // 检查用户是否存在 - if err := svr.checkUserExist(convertPrincipalsToUsers(req.GetAddPrincipals())); err != nil { - return api.NewModifyAuthStrategyResponse(apimodel.Code_NotFoundUser, req) - } - - // 检查用户组是否存 - if err := svr.checkGroupExist(convertPrincipalsToGroups(req.GetAddPrincipals())); err != nil { - return api.NewModifyAuthStrategyResponse(apimodel.Code_NotFoundUserGroup, req) +// collectPrincipalEntry 将 Principal 转换为对应的 []authcommon.Principal 数组 +func collectPrincipalEntry(ruleID string, uType authcommon.PrincipalType, res []*apisecurity.Principal) []authcommon.Principal { + principals := make([]authcommon.Principal, 0, len(res)+1) + if len(res) == 0 { + return principals } - // 检查资源是否存在 - if errResp := svr.checkResourceExist(req.GetAddResources()); errResp != nil { - return errResp + for index := range res { + principals = append(principals, authcommon.Principal{ + StrategyID: ruleID, + PrincipalID: res[index].GetId().GetValue(), + PrincipalType: uType, + }) } - return nil + return principals } // authStrategyRecordEntry 转换为鉴权策略的记录结构体 @@ -769,295 +851,266 @@ func authModifyStrategyRecordEntry( return entry } -func convertPrincipalsToUsers(principals *apisecurity.Principals) []*apisecurity.User { - if principals == nil { - return make([]*apisecurity.User, 0) - } - - users := make([]*apisecurity.User, 0, len(principals.Users)) - for k := range principals.GetUsers() { - user := principals.GetUsers()[k] - users = append(users, &apisecurity.User{ - Id: user.Id, - }) - } - - return users -} - -func convertPrincipalsToGroups(principals *apisecurity.Principals) []*apisecurity.UserGroup { - if principals == nil { - return make([]*apisecurity.UserGroup, 0) - } - - groups := make([]*apisecurity.UserGroup, 0, len(principals.Groups)) - for k := range principals.GetGroups() { - group := principals.GetGroups()[k] - groups = append(groups, &apisecurity.UserGroup{ - Id: group.Id, - }) - } - - return groups -} - -// checkUserExist 检查用户是否存在 -func (svr *Server) checkUserExist(users []*apisecurity.User) error { - if len(users) == 0 { - return nil - } - return svr.userSvr.GetUserHelper().CheckUsersExist(context.TODO(), users) -} - -// checkUserGroupExist 检查用户组是否存在 -func (svr *Server) checkGroupExist(groups []*apisecurity.UserGroup) error { - if len(groups) == 0 { - return nil - } - return svr.userSvr.GetUserHelper().CheckGroupsExist(context.TODO(), groups) -} - -// checkResourceExist 检查资源是否存在 -func (svr *Server) checkResourceExist(resources *apisecurity.StrategyResources) *apiservice.Response { - namespaces := resources.GetNamespaces() - - nsCache := svr.cacheMgr.Namespace() - for index := range namespaces { - val := namespaces[index] - if val.GetId().GetValue() == "*" { - break - } - ns := nsCache.GetNamespace(val.GetId().GetValue()) - if ns == nil { - return api.NewAuthResponse(apimodel.Code_NotFoundNamespace) - } - } - - services := resources.GetServices() - svcCache := svr.cacheMgr.Service() - for index := range services { - val := services[index] - if val.GetId().GetValue() == "*" { - break - } - svc := svcCache.GetServiceByID(val.GetId().GetValue()) - if svc == nil { - return api.NewAuthResponse(apimodel.Code_NotFoundService) - } +var ( + resourceFieldNames = map[string]apisecurity.ResourceType{ + "namespaces": apisecurity.ResourceType_Namespaces, + "service": apisecurity.ResourceType_Services, + "config_groups": apisecurity.ResourceType_ConfigGroups, + "route_rules": apisecurity.ResourceType_RouteRules, + "ratelimit_rules": apisecurity.ResourceType_RateLimitRules, + "circuitbreaker_rules": apisecurity.ResourceType_CircuitBreakerRules, + "faultdetect_rules": apisecurity.ResourceType_FaultDetectRules, + "lane_rules": apisecurity.ResourceType_LaneRules, + "users": apisecurity.ResourceType_Users, + "user_groups": apisecurity.ResourceType_UserGroups, + "roles": apisecurity.ResourceType_Roles, + "auth_policies": apisecurity.ResourceType_PolicyRules, + } + + resourceFieldPointerGetters = map[apisecurity.ResourceType]func(*apisecurity.StrategyResources) reflect.Value{ + apisecurity.ResourceType_Namespaces: func(as *apisecurity.StrategyResources) reflect.Value { + if as.GetNamespaces() == nil { + return reflect.ValueOf(&[]*apisecurity.StrategyResourceEntry{}) + } + return reflect.ValueOf(&as.Namespaces) + }, + apisecurity.ResourceType_Services: func(as *apisecurity.StrategyResources) reflect.Value { + if as.GetServices() == nil { + return reflect.ValueOf(&[]*apisecurity.StrategyResourceEntry{}) + } + return reflect.ValueOf(&as.Services) + }, + apisecurity.ResourceType_ConfigGroups: func(as *apisecurity.StrategyResources) reflect.Value { + if as.GetConfigGroups() == nil { + return reflect.ValueOf(&[]*apisecurity.StrategyResourceEntry{}) + } + return reflect.ValueOf(&as.ConfigGroups) + }, + apisecurity.ResourceType_RouteRules: func(as *apisecurity.StrategyResources) reflect.Value { + if as.GetRouteRules() == nil { + return reflect.ValueOf(&[]*apisecurity.StrategyResourceEntry{}) + } + return reflect.ValueOf(&as.RouteRules) + }, + apisecurity.ResourceType_RateLimitRules: func(as *apisecurity.StrategyResources) reflect.Value { + if as.GetRatelimitRules() == nil { + return reflect.ValueOf(&[]*apisecurity.StrategyResourceEntry{}) + } + return reflect.ValueOf(&as.RatelimitRules) + }, + apisecurity.ResourceType_CircuitBreakerRules: func(as *apisecurity.StrategyResources) reflect.Value { + if as.GetCircuitbreakerRules() == nil { + return reflect.ValueOf(&[]*apisecurity.StrategyResourceEntry{}) + } + return reflect.ValueOf(&as.CircuitbreakerRules) + }, + apisecurity.ResourceType_FaultDetectRules: func(as *apisecurity.StrategyResources) reflect.Value { + if as.GetFaultdetectRules() == nil { + return reflect.ValueOf(&[]*apisecurity.StrategyResourceEntry{}) + } + return reflect.ValueOf(&as.FaultdetectRules) + }, + apisecurity.ResourceType_LaneRules: func(as *apisecurity.StrategyResources) reflect.Value { + if as.GetLaneRules() == nil { + return reflect.ValueOf(&[]*apisecurity.StrategyResourceEntry{}) + } + return reflect.ValueOf(&as.LaneRules) + }, + apisecurity.ResourceType_Users: func(as *apisecurity.StrategyResources) reflect.Value { + if as.GetUsers() == nil { + return reflect.ValueOf(&[]*apisecurity.StrategyResourceEntry{}) + } + return reflect.ValueOf(&as.Users) + }, + apisecurity.ResourceType_UserGroups: func(as *apisecurity.StrategyResources) reflect.Value { + if as.GetUserGroups() == nil { + return reflect.ValueOf(&[]*apisecurity.StrategyResourceEntry{}) + } + return reflect.ValueOf(&as.UserGroups) + }, + apisecurity.ResourceType_Roles: func(as *apisecurity.StrategyResources) reflect.Value { + if as.GetRoles() == nil { + return reflect.ValueOf(&[]*apisecurity.StrategyResourceEntry{}) + } + return reflect.ValueOf(&as.Roles) + }, + apisecurity.ResourceType_PolicyRules: func(as *apisecurity.StrategyResources) reflect.Value { + if as.GetAuthPolicies() == nil { + return reflect.ValueOf(&[]*apisecurity.StrategyResourceEntry{}) + } + return reflect.ValueOf(&as.AuthPolicies) + }, } - return nil -} - -// normalizeResource 对于资源进行归一化处理, 如果出现 * 的话,则该资源访问策略就是 * -func (svr *Server) normalizeResource(resources *apisecurity.StrategyResources) *apisecurity.StrategyResources { - namespaces := resources.GetNamespaces() - for index := range namespaces { - val := namespaces[index] - if val.GetId().GetValue() == "*" { - resources.Namespaces = []*apisecurity.StrategyResourceEntry{{ - Id: utils.NewStringValue("*"), - }} - break - } - } - services := resources.GetServices() - for index := range services { - val := services[index] - if val.GetId().GetValue() == "*" { - resources.Services = []*apisecurity.StrategyResourceEntry{{ - Id: utils.NewStringValue("*"), - }} - break - } - } - return resources -} + resourceConvert = map[apisecurity.ResourceType]func(context.Context, + *Server, authcommon.StrategyResource) *apisecurity.StrategyResourceEntry{ -// fillPrincipalInfo 填充 principal 摘要信息 -func (svr *Server) fillPrincipalInfo(resp *apisecurity.AuthStrategy, data *authcommon.StrategyDetail) { - users := make([]*apisecurity.Principal, 0, len(data.Principals)) - groups := make([]*apisecurity.Principal, 0, len(data.Principals)) - for index := range data.Principals { - principal := data.Principals[index] - if principal.PrincipalType == authcommon.PrincipalUser { - user := svr.userSvr.GetUserHelper().GetUser(context.TODO(), &apisecurity.User{ - Id: wrapperspb.String(principal.PrincipalID), - }) + // 注册、配置、治理 + apisecurity.ResourceType_Namespaces: func(ctx context.Context, svr *Server, + item authcommon.StrategyResource) *apisecurity.StrategyResourceEntry { + user := svr.cacheMgr.Namespace().GetNamespace(item.ResID) if user == nil { - continue + log.Warn("[Auth][Strategy] not found namespace in fill-info", + zap.String("id", item.StrategyID), zap.String("res-id", item.ResID), utils.RequestID(ctx)) + return nil } - users = append(users, &apisecurity.Principal{ - Id: utils.NewStringValue(user.GetId().GetValue()), - Name: utils.NewStringValue(user.GetName().GetValue()), - }) - } else { - group := svr.userSvr.GetUserHelper().GetGroup(context.TODO(), &apisecurity.UserGroup{ - Id: wrapperspb.String(principal.PrincipalID), - }) - if group == nil { - continue + return &apisecurity.StrategyResourceEntry{ + Id: utils.NewStringValue(item.ResID), + Namespace: utils.NewStringValue(user.Name), + Name: utils.NewStringValue(user.Name), } - groups = append(groups, &apisecurity.Principal{ - Id: utils.NewStringValue(group.GetId().GetValue()), - Name: utils.NewStringValue(group.GetName().GetValue()), - }) - } - } - - resp.Principals = &apisecurity.Principals{ - Users: users, - Groups: groups, - } -} - -// fillResourceInfo 填充资源摘要信息 -func (svr *Server) fillResourceInfo(resp *apisecurity.AuthStrategy, data *authcommon.StrategyDetail) { - namespaces := make([]*apisecurity.StrategyResourceEntry, 0, len(data.Resources)) - services := make([]*apisecurity.StrategyResourceEntry, 0, len(data.Resources)) - configGroups := make([]*apisecurity.StrategyResourceEntry, 0, len(data.Resources)) - - var ( - autoAllNs bool - autoAllSvc bool - autoAllConfigGroup bool - ) - - for index := range data.Resources { - res := data.Resources[index] - switch res.ResType { - case int32(apisecurity.ResourceType_Namespaces): - if res.ResID == "*" { - autoAllNs = true - namespaces = []*apisecurity.StrategyResourceEntry{ - { - Id: utils.NewStringValue("*"), - Namespace: utils.NewStringValue("*"), - Name: utils.NewStringValue("*"), - }, - } - continue + }, + apisecurity.ResourceType_ConfigGroups: func(ctx context.Context, svr *Server, + item authcommon.StrategyResource) *apisecurity.StrategyResourceEntry { + id, _ := strconv.ParseUint(item.ResID, 10, 64) + user := svr.cacheMgr.ConfigGroup().GetGroupByID(id) + if user == nil { + log.Warn("[Auth][Strategy] not found config_group in fill-info", + zap.String("id", item.StrategyID), zap.String("res-id", item.ResID), utils.RequestID(ctx)) + return nil } - - if !autoAllNs { - ns := svr.cacheMgr.Namespace().GetNamespace(res.ResID) - if ns == nil { - log.Warn("[Auth][Strategy] not found namespace in fill-info", - zap.String("id", data.ID), zap.String("namespace", res.ResID)) - continue - } - namespaces = append(namespaces, &apisecurity.StrategyResourceEntry{ - Id: utils.NewStringValue(ns.Name), - Namespace: utils.NewStringValue(ns.Name), - Name: utils.NewStringValue(ns.Name), - }) + return &apisecurity.StrategyResourceEntry{ + Id: utils.NewStringValue(item.ResID), + Namespace: utils.NewStringValue(user.Namespace), + Name: utils.NewStringValue(user.Name), } - case int32(apisecurity.ResourceType_Services): - if res.ResID == "*" { - autoAllSvc = true - services = []*apisecurity.StrategyResourceEntry{ - { - Id: utils.NewStringValue("*"), - Namespace: utils.NewStringValue("*"), - Name: utils.NewStringValue("*"), - }, - } - continue + }, + apisecurity.ResourceType_Services: func(ctx context.Context, svr *Server, + item authcommon.StrategyResource) *apisecurity.StrategyResourceEntry { + user := svr.cacheMgr.Namespace().GetNamespace(item.ResID) + if user == nil { + log.Warn("[Auth][Strategy] not found namespace in fill-info", + zap.String("id", item.StrategyID), zap.String("res-id", item.ResID), utils.RequestID(ctx)) + return nil } - - if !autoAllSvc { - svc := svr.cacheMgr.Service().GetServiceByID(res.ResID) - if svc == nil { - log.Warn("[Auth][Strategy] not found service in fill-info", - zap.String("id", data.ID), zap.String("service", res.ResID)) - continue - } - services = append(services, &apisecurity.StrategyResourceEntry{ - Id: utils.NewStringValue(svc.ID), - Namespace: utils.NewStringValue(svc.Namespace), - Name: utils.NewStringValue(svc.Name), - }) + return &apisecurity.StrategyResourceEntry{ + Id: utils.NewStringValue(item.ResID), + Namespace: utils.NewStringValue(user.Name), + Name: utils.NewStringValue(user.Name), } - case int32(apisecurity.ResourceType_ConfigGroups): - if res.ResID == "*" { - autoAllConfigGroup = true - configGroups = []*apisecurity.StrategyResourceEntry{ - { - Id: utils.NewStringValue("*"), - Namespace: utils.NewStringValue("*"), - Name: utils.NewStringValue("*"), - }, - } - continue + }, + apisecurity.ResourceType_RouteRules: func(ctx context.Context, svr *Server, + item authcommon.StrategyResource) *apisecurity.StrategyResourceEntry { + user := svr.cacheMgr.RoutingConfig().GetRule(item.ResID) + if user == nil { + log.Warn("[Auth][Strategy] not found route_rule in fill-info", + zap.String("id", item.StrategyID), zap.String("res-id", item.ResID), utils.RequestID(ctx)) + return nil } - if !autoAllConfigGroup { - groupId, err := strconv.ParseUint(res.ResID, 10, 64) - if err != nil { - log.Warn("[Auth][Strategy] invalid resource id", - zap.String("id", data.ID), zap.String("config_file_group", res.ResID)) - continue - } - group := svr.cacheMgr.ConfigGroup().GetGroupByID(groupId) - if group == nil { - log.Warn("[Auth][Strategy] not found config_file_group in fill-info", - zap.String("id", data.ID), zap.String("config_file_group", res.ResID)) - continue - } - configGroups = append(configGroups, &apisecurity.StrategyResourceEntry{ - Id: utils.NewStringValue(res.ResID), - Namespace: utils.NewStringValue(group.Namespace), - Name: utils.NewStringValue(group.Name), - }) + return &apisecurity.StrategyResourceEntry{ + Id: utils.NewStringValue(item.ResID), + Namespace: utils.NewStringValue(user.Name), + Name: utils.NewStringValue(user.Name), } - } - } - - resp.Resources = &apisecurity.StrategyResources{ - Namespaces: namespaces, - Services: services, - ConfigGroups: configGroups, - } -} - -type resourceFilter struct { - ns map[string]struct{} - svc map[string]struct{} - conf map[string]struct{} -} - -func (f *resourceFilter) GetFilter(t apisecurity.ResourceType) (map[string]struct{}, bool) { - switch t { - case apisecurity.ResourceType_Namespaces: - return f.ns, true - case apisecurity.ResourceType_Services: - return f.svc, true - case apisecurity.ResourceType_ConfigGroups: - return f.conf, true - } - return nil, false -} - -// filter different types of Strategy resources -func resourceDeduplication(resources []authcommon.StrategyResource) []authcommon.StrategyResource { - rLen := len(resources) - ret := make([]authcommon.StrategyResource, 0, rLen) - rf := resourceFilter{ - ns: make(map[string]struct{}, rLen), - svc: make(map[string]struct{}, rLen), - conf: make(map[string]struct{}, rLen), - } - - est := struct{}{} - for i := range resources { - res := resources[i] - filter, ok := rf.GetFilter(apisecurity.ResourceType(res.ResType)) - if !ok { - continue - } - if _, exist := filter[res.ResID]; !exist { - rf.ns[res.ResID] = est - ret = append(ret, res) - } + }, + apisecurity.ResourceType_LaneRules: func(ctx context.Context, svr *Server, + item authcommon.StrategyResource) *apisecurity.StrategyResourceEntry { + user := svr.cacheMgr.LaneRule().GetRule(item.ResID) + if user == nil { + log.Warn("[Auth][Strategy] not found lane_rule in fill-info", + zap.String("id", item.StrategyID), zap.String("res-id", item.ResID), utils.RequestID(ctx)) + return nil + } + return &apisecurity.StrategyResourceEntry{ + Id: utils.NewStringValue(item.ResID), + Namespace: utils.NewStringValue(user.Name), + Name: utils.NewStringValue(user.Name), + } + }, + apisecurity.ResourceType_RateLimitRules: func(ctx context.Context, svr *Server, + item authcommon.StrategyResource) *apisecurity.StrategyResourceEntry { + user := svr.cacheMgr.RateLimit().GetRule(item.ResID) + if user == nil { + log.Warn("[Auth][Strategy] not found ratelimit_rule in fill-info", + zap.String("id", item.StrategyID), zap.String("res-id", item.ResID), utils.RequestID(ctx)) + return nil + } + return &apisecurity.StrategyResourceEntry{ + Id: utils.NewStringValue(item.ResID), + Namespace: utils.NewStringValue(user.Name), + Name: utils.NewStringValue(user.Name), + } + }, + apisecurity.ResourceType_CircuitBreakerRules: func(ctx context.Context, svr *Server, + item authcommon.StrategyResource) *apisecurity.StrategyResourceEntry { + user := svr.cacheMgr.CircuitBreaker().GetRule(item.ResID) + if user == nil { + log.Warn("[Auth][Strategy] not found circuitbreaker_rule in fill-info", + zap.String("id", item.StrategyID), zap.String("res-id", item.ResID), utils.RequestID(ctx)) + return nil + } + return &apisecurity.StrategyResourceEntry{ + Id: utils.NewStringValue(item.ResID), + Namespace: utils.NewStringValue(user.Name), + Name: utils.NewStringValue(user.Name), + } + }, + apisecurity.ResourceType_FaultDetectRules: func(ctx context.Context, svr *Server, + item authcommon.StrategyResource) *apisecurity.StrategyResourceEntry { + user := svr.cacheMgr.FaultDetector().GetRule(item.ResID) + if user == nil { + log.Warn("[Auth][Strategy] not found faultdetect_rule in fill-info", + zap.String("id", item.StrategyID), zap.String("res-id", item.ResID), utils.RequestID(ctx)) + return nil + } + return &apisecurity.StrategyResourceEntry{ + Id: utils.NewStringValue(item.ResID), + Namespace: utils.NewStringValue(user.Name), + Name: utils.NewStringValue(user.Name), + } + }, + // 鉴权资源 + apisecurity.ResourceType_Users: func(ctx context.Context, svr *Server, + item authcommon.StrategyResource) *apisecurity.StrategyResourceEntry { + user := svr.cacheMgr.User().GetUserByID(item.ResID) + if user == nil { + log.Warn("[Auth][Strategy] not found user in fill-info", + zap.String("id", item.StrategyID), zap.String("res-id", item.ResID), utils.RequestID(ctx)) + return nil + } + return &apisecurity.StrategyResourceEntry{ + Id: utils.NewStringValue(item.ResID), + Name: utils.NewStringValue(user.Name), + } + }, + apisecurity.ResourceType_UserGroups: func(ctx context.Context, svr *Server, + item authcommon.StrategyResource) *apisecurity.StrategyResourceEntry { + user := svr.cacheMgr.User().GetGroup(item.ResID) + if user == nil { + log.Warn("[Auth][Strategy] not found user_group in fill-info", + zap.String("id", item.StrategyID), zap.String("res-id", item.ResID), utils.RequestID(ctx)) + return nil + } + return &apisecurity.StrategyResourceEntry{ + Id: utils.NewStringValue(item.ResID), + Name: utils.NewStringValue(user.Name), + } + }, + apisecurity.ResourceType_Roles: func(ctx context.Context, svr *Server, + item authcommon.StrategyResource) *apisecurity.StrategyResourceEntry { + user := svr.cacheMgr.Role().GetRole(item.ResID) + if user == nil { + log.Warn("[Auth][Strategy] not found role in fill-info", + zap.String("id", item.StrategyID), zap.String("res-id", item.ResID), utils.RequestID(ctx)) + return nil + } + return &apisecurity.StrategyResourceEntry{ + Id: utils.NewStringValue(item.ResID), + Name: utils.NewStringValue(user.Name), + } + }, + apisecurity.ResourceType_PolicyRules: func(ctx context.Context, svr *Server, + item authcommon.StrategyResource) *apisecurity.StrategyResourceEntry { + user := svr.cacheMgr.AuthStrategy().GetPolicyRule(item.ResID) + if user == nil { + log.Warn("[Auth][Strategy] not found auth_policy in fill-info", + zap.String("id", item.StrategyID), zap.String("res-id", item.ResID), utils.RequestID(ctx)) + return nil + } + return &apisecurity.StrategyResourceEntry{ + Id: utils.NewStringValue(item.ResID), + Name: utils.NewStringValue(user.Name), + } + }, } - return ret -} +) diff --git a/auth/policy/strategy_test.go b/auth/policy/strategy_test.go deleted file mode 100644 index 5626fd7d6..000000000 --- a/auth/policy/strategy_test.go +++ /dev/null @@ -1,957 +0,0 @@ -/** - * Tencent is pleased to support the open source community by making Polaris available. - * - * Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. - * - * Licensed under the BSD 3-Clause License (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://opensource.org/licenses/BSD-3-Clause - * - * Unless required by applicable law or agreed to in writing, software distributed - * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR - * CONDITIONS OF ANY KIND, either express or implied. See the License for the - * specific language governing permissions and limitations under the License. - */ - -package policy_test - -import ( - "context" - "math/rand" - "reflect" - "testing" - "time" - - "github.com/golang/mock/gomock" - apisecurity "github.com/polarismesh/specification/source/go/api/v1/security" - "github.com/stretchr/testify/assert" - "google.golang.org/protobuf/types/known/wrapperspb" - - "github.com/polarismesh/polaris/auth" - "github.com/polarismesh/polaris/auth/policy" - defaultuser "github.com/polarismesh/polaris/auth/user" - "github.com/polarismesh/polaris/cache" - cachetypes "github.com/polarismesh/polaris/cache/api" - api "github.com/polarismesh/polaris/common/api/v1" - "github.com/polarismesh/polaris/common/eventhub" - "github.com/polarismesh/polaris/common/model" - authcommon "github.com/polarismesh/polaris/common/model/auth" - "github.com/polarismesh/polaris/common/utils" - storemock "github.com/polarismesh/polaris/store/mock" -) - -type StrategyTest struct { - admin *authcommon.User - ownerOne *authcommon.User - ownerTwo *authcommon.User - - namespaces []*model.Namespace - services []*model.Service - strategies []*authcommon.StrategyDetail - allStrategies []*authcommon.StrategyDetail - defaultStrategies []*authcommon.StrategyDetail - - users []*authcommon.User - groups []*authcommon.UserGroupDetail - - storage *storemock.MockStore - cacheMgn *cache.CacheManager - checker auth.AuthChecker - - svr auth.StrategyServer - - cancel context.CancelFunc - - ctrl *gomock.Controller -} - -func newStrategyTest(t *testing.T) *StrategyTest { - reset(false) - eventhub.InitEventHub() - - ctrl := gomock.NewController(t) - - users := createMockUser(10) - groups := createMockUserGroup(users) - - namespaces := createMockNamespace(len(users)+len(groups)+10, users[0].ID) - services := createMockService(namespaces) - serviceMap := convertServiceSliceToMap(services) - defaultStrategies, strategies := createMockStrategy(users, groups, services[:len(users)+len(groups)]) - - allStrategies := make([]*authcommon.StrategyDetail, 0, len(defaultStrategies)+len(strategies)) - allStrategies = append(allStrategies, defaultStrategies...) - allStrategies = append(allStrategies, strategies...) - - cfg, storage := initCache(ctrl) - - storage.EXPECT().GetServicesCount().AnyTimes().Return(uint32(1), nil) - storage.EXPECT().GetUnixSecond(gomock.Any()).AnyTimes().Return(time.Now().Unix(), nil) - storage.EXPECT().GetUsersForCache(gomock.Any(), gomock.Any()).AnyTimes().Return(users, nil) - storage.EXPECT().GetGroupsForCache(gomock.Any(), gomock.Any()).AnyTimes().Return(groups, nil) - storage.EXPECT().GetMoreStrategies(gomock.Any(), gomock.Any()).AnyTimes().Return(allStrategies, nil) - storage.EXPECT().GetMoreNamespaces(gomock.Any()).AnyTimes().Return(namespaces, nil) - storage.EXPECT().GetMoreServices(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes().Return(serviceMap, nil) - storage.EXPECT().GetStrategyResources(gomock.Eq(users[1].ID), gomock.Any()).AnyTimes().Return(strategies[1].Resources, nil) - storage.EXPECT().GetStrategyResources(gomock.Eq(groups[1].ID), gomock.Any()).AnyTimes().Return(strategies[len(users)-1+2].Resources, nil) - - ctx, cancel := context.WithCancel(context.Background()) - cacheMgn, err := cache.TestCacheInitialize(ctx, cfg, storage) - if err != nil { - t.Fatal(err) - } - err = cacheMgn.OpenResourceCache([]cachetypes.ConfigEntry{ - { - Name: cachetypes.ServiceName, - Option: map[string]interface{}{ - "disableBusiness": false, - "needMeta": true, - }, - }, - { - Name: cachetypes.InstanceName, - }, - { - Name: cachetypes.NamespaceName, - }, - { - Name: cachetypes.UsersName, - }, - { - Name: cachetypes.StrategyRuleName, - }, - }...) - if err != nil { - t.Fatal(err) - } - if err := cache.TestRun(ctx, cacheMgn); err != nil { - t.Fatal(err) - } - _ = cacheMgn.TestUpdate() - - _, proxySvr, err := defaultuser.BuildServer() - if err != nil { - t.Fatal(err) - } - proxySvr.Initialize(&auth.Config{ - User: &auth.UserConfig{ - Name: auth.DefaultUserMgnPluginName, - Option: map[string]interface{}{ - "salt": "polarismesh@2021", - }, - }, - }, storage, nil, cacheMgn) - - _, svr, err := newPolicyServer() - if err != nil { - t.Fatal(err) - } - if err := svr.Initialize(&auth.Config{ - Strategy: &auth.StrategyConfig{ - Name: auth.DefaultPolicyPluginName, - }, - }, storage, cacheMgn, proxySvr); err != nil { - t.Fatal(err) - } - checker := svr.GetAuthChecker() - - t.Cleanup(func() { - cacheMgn.Close() - }) - - return &StrategyTest{ - ownerOne: users[0], - - users: users, - groups: groups, - - namespaces: namespaces, - services: services, - strategies: strategies, - allStrategies: allStrategies, - defaultStrategies: defaultStrategies, - - storage: storage, - cacheMgn: cacheMgn, - checker: checker, - - cancel: cancel, - - svr: svr, - - ctrl: ctrl, - } -} - -func (g *StrategyTest) Clean() { - g.cancel() - _ = g.cacheMgn.Close() -} - -func Test_GetPrincipalResources(t *testing.T) { - - strategyTest := newStrategyTest(t) - defer strategyTest.Clean() - - _ = strategyTest.cacheMgn.TestUpdate() - - valCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, strategyTest.users[1].Token) - - ret := strategyTest.svr.GetPrincipalResources(valCtx, map[string]string{ - "principal_id": strategyTest.users[1].ID, - "principal_type": "user", - }) - - t.Logf("GetPrincipalResources resp : %+v", ret) - assert.EqualValues(t, api.ExecuteSuccess, ret.Code.GetValue(), "need query success") - resources := ret.Resources - assert.Equal(t, 2, len(resources.GetServices()), "need query 2 service resources") -} - -func Test_CreateStrategy(t *testing.T) { - - strategyTest := newStrategyTest(t) - defer strategyTest.Clean() - - _ = strategyTest.cacheMgn.TestUpdate() - - t.Run("正常创建鉴权策略", func(t *testing.T) { - strategyTest.storage.EXPECT().AddStrategy(gomock.Any(), gomock.Any()).Return(nil) - - valCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, strategyTest.users[0].Token) - strategyId := utils.NewUUID() - - resp := strategyTest.svr.CreateStrategy(valCtx, &apisecurity.AuthStrategy{ - Id: &wrapperspb.StringValue{Value: strategyId}, - Name: &wrapperspb.StringValue{ - Value: "正常创建鉴权策略", - }, - Principals: &apisecurity.Principals{ - Users: []*apisecurity.Principal{{ - Id: &wrapperspb.StringValue{ - Value: strategyTest.users[1].ID, - }, - Name: &wrapperspb.StringValue{ - Value: strategyTest.users[1].Name, - }, - }}, - Groups: []*apisecurity.Principal{}, - }, - Resources: &apisecurity.StrategyResources{ - StrategyId: &wrapperspb.StringValue{ - Value: strategyId, - }, - Namespaces: []*apisecurity.StrategyResourceEntry{}, - Services: []*apisecurity.StrategyResourceEntry{}, - ConfigGroups: []*apisecurity.StrategyResourceEntry{}, - }, - Action: 0, - }) - - assert.Equal(t, api.ExecuteSuccess, resp.Code.GetValue(), resp.Info.GetValue()) - }) - - t.Run("创建鉴权策略-非owner用户发起", func(t *testing.T) { - valCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, strategyTest.users[1].Token) - strategyId := utils.NewUUID() - - resp := strategyTest.svr.CreateStrategy(valCtx, &apisecurity.AuthStrategy{ - Id: &wrapperspb.StringValue{Value: strategyId}, - Name: &wrapperspb.StringValue{ - Value: "创建鉴权策略-非owner用户发起", - }, - Principals: &apisecurity.Principals{ - Users: []*apisecurity.Principal{{ - Id: &wrapperspb.StringValue{ - Value: strategyTest.users[1].ID, - }, - Name: &wrapperspb.StringValue{ - Value: strategyTest.users[1].Name, - }, - }}, - Groups: []*apisecurity.Principal{}, - }, - Resources: &apisecurity.StrategyResources{ - StrategyId: &wrapperspb.StringValue{ - Value: strategyId, - }, - Namespaces: []*apisecurity.StrategyResourceEntry{}, - Services: []*apisecurity.StrategyResourceEntry{}, - ConfigGroups: []*apisecurity.StrategyResourceEntry{}, - }, - Action: 0, - }) - - assert.Equal(t, api.OperationRoleException, resp.Code.GetValue(), resp.Info.GetValue()) - }) - - t.Run("创建鉴权策略-关联用户不存在", func(t *testing.T) { - valCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, strategyTest.users[0].Token) - strategyId := utils.NewUUID() - - resp := strategyTest.svr.CreateStrategy(valCtx, &apisecurity.AuthStrategy{ - Id: &wrapperspb.StringValue{Value: strategyId}, - Name: &wrapperspb.StringValue{ - Value: "创建鉴权策略-关联用户不存在", - }, - Principals: &apisecurity.Principals{ - Users: []*apisecurity.Principal{{ - Id: &wrapperspb.StringValue{ - Value: utils.NewUUID(), - }, - Name: &wrapperspb.StringValue{ - Value: "user-1", - }, - }}, - Groups: []*apisecurity.Principal{}, - }, - Resources: &apisecurity.StrategyResources{ - StrategyId: &wrapperspb.StringValue{ - Value: strategyId, - }, - Namespaces: []*apisecurity.StrategyResourceEntry{}, - Services: []*apisecurity.StrategyResourceEntry{}, - ConfigGroups: []*apisecurity.StrategyResourceEntry{}, - }, - Action: 0, - }) - - assert.Equal(t, api.NotFoundUser, resp.Code.GetValue(), resp.Info.GetValue()) - }) - - t.Run("创建鉴权策略-关联用户组不存在", func(t *testing.T) { - valCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, strategyTest.users[0].Token) - strategyId := utils.NewUUID() - - resp := strategyTest.svr.CreateStrategy(valCtx, &apisecurity.AuthStrategy{ - Id: &wrapperspb.StringValue{Value: strategyId}, - Name: &wrapperspb.StringValue{ - Value: "创建鉴权策略-关联用户组不存在", - }, - Principals: &apisecurity.Principals{ - Groups: []*apisecurity.Principal{{ - Id: &wrapperspb.StringValue{ - Value: utils.NewUUID(), - }, - Name: &wrapperspb.StringValue{ - Value: "user-1", - }, - }}, - }, - Resources: &apisecurity.StrategyResources{ - StrategyId: &wrapperspb.StringValue{ - Value: strategyId, - }, - Namespaces: []*apisecurity.StrategyResourceEntry{}, - Services: []*apisecurity.StrategyResourceEntry{}, - ConfigGroups: []*apisecurity.StrategyResourceEntry{}, - }, - Action: 0, - }) - - assert.Equal(t, api.NotFoundUserGroup, resp.Code.GetValue(), resp.Info.GetValue()) - }) - -} - -func Test_UpdateStrategy(t *testing.T) { - strategyTest := newStrategyTest(t) - defer strategyTest.Clean() - - _ = strategyTest.cacheMgn.TestUpdate() - - t.Run("正常更新鉴权策略", func(t *testing.T) { - strategyTest.storage.EXPECT().GetStrategyDetail(gomock.Any()).Return(strategyTest.strategies[0], nil) - strategyTest.storage.EXPECT().UpdateStrategy(gomock.Any()).Return(nil) - - valCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, strategyTest.users[0].Token) - strategyId := strategyTest.strategies[0].ID - - resp := strategyTest.svr.UpdateStrategies(valCtx, []*apisecurity.ModifyAuthStrategy{ - { - Id: &wrapperspb.StringValue{ - Value: strategyId, - }, - Name: &wrapperspb.StringValue{ - Value: strategyTest.strategies[0].Name, - }, - AddPrincipals: &apisecurity.Principals{ - Users: []*apisecurity.Principal{ - { - Id: &wrapperspb.StringValue{Value: strategyTest.users[2].ID}, - }, - }, - }, - RemovePrincipals: &apisecurity.Principals{ - Users: []*apisecurity.Principal{ - { - Id: &wrapperspb.StringValue{Value: strategyTest.users[3].ID}, - }, - }, - }, - AddResources: &apisecurity.StrategyResources{ - StrategyId: &wrapperspb.StringValue{ - Value: strategyId, - }, - Namespaces: []*apisecurity.StrategyResourceEntry{ - {Id: &wrapperspb.StringValue{Value: strategyTest.namespaces[0].Name}}, - }, - Services: []*apisecurity.StrategyResourceEntry{ - {Id: &wrapperspb.StringValue{Value: strategyTest.services[0].ID}}, - }, - ConfigGroups: []*apisecurity.StrategyResourceEntry{}, - }, - RemoveResources: &apisecurity.StrategyResources{ - StrategyId: &wrapperspb.StringValue{ - Value: strategyId, - }, - Namespaces: []*apisecurity.StrategyResourceEntry{ - {Id: &wrapperspb.StringValue{Value: strategyTest.namespaces[1].Name}}, - }, - Services: []*apisecurity.StrategyResourceEntry{ - {Id: &wrapperspb.StringValue{Value: strategyTest.services[1].ID}}, - }, - ConfigGroups: []*apisecurity.StrategyResourceEntry{}, - }, - }, - }) - - assert.Equal(t, api.ExecuteSuccess, resp.Responses[0].Code.GetValue(), resp.Responses[0].Info.GetValue()) - }) - - t.Run("更新鉴权策略-非owner用户发起", func(t *testing.T) { - valCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, strategyTest.users[1].Token) - strategyId := utils.NewUUID() - - resp := strategyTest.svr.UpdateStrategies(valCtx, []*apisecurity.ModifyAuthStrategy{ - { - Id: &wrapperspb.StringValue{ - Value: strategyId, - }, - Name: &wrapperspb.StringValue{ - Value: strategyTest.strategies[0].Name, - }, - AddPrincipals: &apisecurity.Principals{ - Users: []*apisecurity.Principal{}, - Groups: []*apisecurity.Principal{}, - }, - RemovePrincipals: &apisecurity.Principals{ - Users: []*apisecurity.Principal{}, - Groups: []*apisecurity.Principal{}, - }, - AddResources: &apisecurity.StrategyResources{ - StrategyId: &wrapperspb.StringValue{ - Value: "", - }, - Namespaces: []*apisecurity.StrategyResourceEntry{}, - Services: []*apisecurity.StrategyResourceEntry{}, - ConfigGroups: []*apisecurity.StrategyResourceEntry{}, - }, - RemoveResources: &apisecurity.StrategyResources{ - StrategyId: &wrapperspb.StringValue{ - Value: "", - }, - Namespaces: []*apisecurity.StrategyResourceEntry{}, - Services: []*apisecurity.StrategyResourceEntry{}, - ConfigGroups: []*apisecurity.StrategyResourceEntry{}, - }, - }, - }) - - assert.Equal(t, api.OperationRoleException, resp.Responses[0].Code.GetValue(), resp.Responses[0].Info.GetValue()) - }) - - t.Run("更新鉴权策略-目标策略不存在", func(t *testing.T) { - - strategyTest.storage.EXPECT().GetStrategyDetail(gomock.Any()).Return(nil, nil) - - valCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, strategyTest.users[0].Token) - - strategyId := strategyTest.defaultStrategies[0].ID - - resp := strategyTest.svr.UpdateStrategies(valCtx, []*apisecurity.ModifyAuthStrategy{ - { - Id: &wrapperspb.StringValue{Value: strategyId}, - AddPrincipals: &apisecurity.Principals{ - Users: []*apisecurity.Principal{ - { - Id: &wrapperspb.StringValue{Value: utils.NewUUID()}, - }, - }, - }, - }, - }) - - assert.Equal(t, api.NotFoundAuthStrategyRule, resp.Responses[0].Code.GetValue(), resp.Responses[0].Info.GetValue()) - }) - - t.Run("更新鉴权策略-owner不为自己", func(t *testing.T) { - oldOwner := strategyTest.strategies[2].Owner - - defer func() { - strategyTest.strategies[2].Owner = oldOwner - }() - - strategyTest.strategies[2].Owner = strategyTest.users[2].ID - strategyTest.storage.EXPECT().GetStrategyDetail(gomock.Any()).Return(strategyTest.strategies[2], nil) - valCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, strategyTest.users[0].Token) - strategyId := strategyTest.strategies[2].ID - resp := strategyTest.svr.UpdateStrategies(valCtx, []*apisecurity.ModifyAuthStrategy{ - { - Id: &wrapperspb.StringValue{Value: strategyId}, - AddPrincipals: &apisecurity.Principals{ - Users: []*apisecurity.Principal{ - { - Id: &wrapperspb.StringValue{Value: utils.NewUUID()}, - }, - }, - }, - }, - }) - - assert.Equal(t, api.NotAllowedAccess, resp.Responses[0].Code.GetValue(), resp.Responses[0].Info.GetValue()) - }) - - t.Run("更新鉴权策略-关联用户不存在", func(t *testing.T) { - - strategyTest.storage.EXPECT().GetStrategyDetail(gomock.Any()).Return(strategyTest.strategies[0], nil) - valCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, strategyTest.users[0].Token) - strategyId := strategyTest.strategies[0].ID - resp := strategyTest.svr.UpdateStrategies(valCtx, []*apisecurity.ModifyAuthStrategy{ - { - Id: &wrapperspb.StringValue{Value: strategyId}, - AddPrincipals: &apisecurity.Principals{ - Users: []*apisecurity.Principal{ - { - Id: &wrapperspb.StringValue{Value: utils.NewUUID()}, - }, - }, - }, - }, - }) - - assert.Equal(t, api.NotFoundUser, resp.Responses[0].Code.GetValue(), resp.Responses[0].Info.GetValue()) - }) - - t.Run("更新鉴权策略-关联用户组不存在", func(t *testing.T) { - - strategyTest.storage.EXPECT().GetStrategyDetail(gomock.Any()).Return(strategyTest.strategies[0], nil) - valCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, strategyTest.users[0].Token) - strategyId := strategyTest.strategies[0].ID - resp := strategyTest.svr.UpdateStrategies(valCtx, []*apisecurity.ModifyAuthStrategy{ - { - Id: &wrapperspb.StringValue{Value: strategyId}, - AddPrincipals: &apisecurity.Principals{ - Groups: []*apisecurity.Principal{ - { - Id: &wrapperspb.StringValue{Value: utils.NewUUID()}, - }, - }, - }, - }, - }) - - assert.Equal(t, api.NotFoundUserGroup, resp.Responses[0].Code.GetValue(), resp.Responses[0].Info.GetValue()) - }) - - t.Run("更新默认鉴权策略-不能更改principals成员", func(t *testing.T) { - - strategyTest.storage.EXPECT().GetStrategyDetail(gomock.Any()).Return(strategyTest.defaultStrategies[0], nil) - valCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, strategyTest.users[0].Token) - strategyId := strategyTest.defaultStrategies[0].ID - resp := strategyTest.svr.UpdateStrategies(valCtx, []*apisecurity.ModifyAuthStrategy{ - { - Id: &wrapperspb.StringValue{Value: strategyId}, - AddPrincipals: &apisecurity.Principals{ - Users: []*apisecurity.Principal{ - { - Id: &wrapperspb.StringValue{Value: strategyTest.users[3].ID}, - }, - }, - }, - }, - }) - - assert.Equal(t, api.NotAllowModifyDefaultStrategyPrincipal, resp.Responses[0].Code.GetValue(), resp.Responses[0].Info.GetValue()) - }) - -} - -func Test_DeleteStrategy(t *testing.T) { - strategyTest := newStrategyTest(t) - defer strategyTest.Clean() - - _ = strategyTest.cacheMgn.TestUpdate() - - t.Run("正常删除鉴权策略", func(t *testing.T) { - index := rand.Intn(len(strategyTest.strategies)) - - strategyTest.storage.EXPECT().GetStrategyDetail(gomock.Any()).Return(strategyTest.strategies[index], nil) - strategyTest.storage.EXPECT().DeleteStrategy(gomock.Any()).Return(nil) - - valCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, strategyTest.users[0].Token) - - resp := strategyTest.svr.DeleteStrategies(valCtx, []*apisecurity.AuthStrategy{ - {Id: &wrapperspb.StringValue{Value: strategyTest.strategies[index].ID}}, - }) - - assert.Equal(t, api.ExecuteSuccess, resp.Responses[0].Code.GetValue(), resp.Responses[0].Info.GetValue()) - }) - - t.Run("删除鉴权策略-非owner用户发起", func(t *testing.T) { - valCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, strategyTest.users[1].Token) - - resp := strategyTest.svr.DeleteStrategies(valCtx, []*apisecurity.AuthStrategy{ - {Id: &wrapperspb.StringValue{Value: strategyTest.strategies[rand.Intn(len(strategyTest.strategies))].ID}}, - }) - - assert.Equal(t, api.OperationRoleException, resp.Responses[0].Code.GetValue(), resp.Responses[0].Info.GetValue()) - }) - - t.Run("删除鉴权策略-目标策略不存在", func(t *testing.T) { - - index := rand.Intn(len(strategyTest.strategies)) - strategyTest.storage.EXPECT().GetStrategyDetail(gomock.Any()).Return(nil, nil) - valCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, strategyTest.users[0].Token) - resp := strategyTest.svr.DeleteStrategies(valCtx, []*apisecurity.AuthStrategy{ - {Id: &wrapperspb.StringValue{Value: strategyTest.strategies[index].ID}}, - }) - - assert.Equal(t, api.ExecuteSuccess, resp.Responses[0].Code.GetValue(), resp.Responses[0].Info.GetValue()) - }) - - t.Run("删除鉴权策略-目标为默认鉴权策略", func(t *testing.T) { - index := rand.Intn(len(strategyTest.defaultStrategies)) - - strategyTest.storage.EXPECT().GetStrategyDetail(gomock.Any()).Return(strategyTest.defaultStrategies[index], nil) - valCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, strategyTest.users[0].Token) - resp := strategyTest.svr.DeleteStrategies(valCtx, []*apisecurity.AuthStrategy{ - {Id: &wrapperspb.StringValue{Value: strategyTest.defaultStrategies[index].ID}}, - }) - - assert.Equal(t, api.BadRequest, resp.Responses[0].Code.GetValue(), resp.Responses[0].Info.GetValue()) - }) - - t.Run("删除鉴权策略-目标owner不为自己", func(t *testing.T) { - index := rand.Intn(len(strategyTest.defaultStrategies)) - oldOwner := strategyTest.strategies[index].Owner - - defer func() { - strategyTest.strategies[index].Owner = oldOwner - }() - - strategyTest.strategies[index].Owner = strategyTest.users[2].ID - strategyTest.storage.EXPECT().GetStrategyDetail(gomock.Any()).Return(strategyTest.strategies[index], nil) - valCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, strategyTest.users[0].Token) - resp := strategyTest.svr.DeleteStrategies(valCtx, []*apisecurity.AuthStrategy{ - {Id: &wrapperspb.StringValue{Value: strategyTest.strategies[index].ID}}, - }) - - assert.Equal(t, api.NotAllowedAccess, resp.Responses[0].Code.GetValue(), resp.Responses[0].Info.GetValue()) - }) - -} - -func Test_GetStrategy(t *testing.T) { - strategyTest := newStrategyTest(t) - defer strategyTest.Clean() - - t.Run("正常查询鉴权策略", func(t *testing.T) { - // 主账户查询自己的策略 - strategyTest.storage.EXPECT().GetStrategyDetail(gomock.Any()).Return(strategyTest.strategies[0], nil) - valCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, strategyTest.users[0].Token) - _ = strategyTest.cacheMgn.TestUpdate() - resp := strategyTest.svr.GetStrategy(valCtx, &apisecurity.AuthStrategy{ - Id: &wrapperspb.StringValue{Value: strategyTest.strategies[0].ID}, - }) - assert.Equal(t, api.ExecuteSuccess, resp.Code.GetValue(), resp.Info.GetValue()) - - // 主账户查询自己自账户的策略 - strategyTest.storage.EXPECT().GetStrategyDetail(gomock.Any()).Return(strategyTest.strategies[1], nil) - valCtx = context.WithValue(context.Background(), utils.ContextAuthTokenKey, strategyTest.users[0].Token) - _ = strategyTest.cacheMgn.TestUpdate() - resp = strategyTest.svr.GetStrategy(valCtx, &apisecurity.AuthStrategy{ - Id: &wrapperspb.StringValue{Value: strategyTest.strategies[1].ID}, - }) - assert.Equal(t, api.ExecuteSuccess, resp.Code.GetValue(), resp.Info.GetValue()) - }) - - t.Run("查询鉴权策略-目标owner不为自己", func(t *testing.T) { - t.Skip() - var index int - for { - index = rand.Intn(len(strategyTest.defaultStrategies)) - if index != 2 { - break - } - } - oldOwner := strategyTest.strategies[index].Owner - - defer func() { - strategyTest.strategies[index].Owner = oldOwner - }() - - strategyTest.strategies[index].Owner = strategyTest.users[2].ID - strategyTest.storage.EXPECT().GetStrategyDetail(gomock.Any()).Return(strategyTest.strategies[index], nil) - - valCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, strategyTest.users[0].Token) - - _ = strategyTest.cacheMgn.TestUpdate() - resp := strategyTest.svr.GetStrategy(valCtx, &apisecurity.AuthStrategy{ - Id: &wrapperspb.StringValue{Value: strategyTest.strategies[index].ID}, - }) - - assert.Equal(t, api.NotAllowedAccess, resp.Code.GetValue(), resp.Info.GetValue()) - }) - - t.Run("查询鉴权策略-非owner用户查询自己的", func(t *testing.T) { - - strategyTest.storage.EXPECT().GetStrategyDetail(gomock.Any()).Return(strategyTest.strategies[1], nil) - - valCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, strategyTest.users[1].Token) - - _ = strategyTest.cacheMgn.TestUpdate() - resp := strategyTest.svr.GetStrategy(valCtx, &apisecurity.AuthStrategy{ - Id: &wrapperspb.StringValue{Value: strategyTest.strategies[1].ID}, - }) - - assert.Equal(t, api.ExecuteSuccess, resp.Code.GetValue(), resp.Info.GetValue()) - }) - - t.Run("查询鉴权策略-非owner用户查询自己所在用户组的", func(t *testing.T) { - strategyTest.storage.EXPECT().GetStrategyDetail(gomock.Any()).Return(strategyTest.strategies[len(strategyTest.users)-1+2], nil) - - valCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, strategyTest.users[1].Token) - - _ = strategyTest.cacheMgn.TestUpdate() - resp := strategyTest.svr.GetStrategy(valCtx, &apisecurity.AuthStrategy{ - Id: &wrapperspb.StringValue{Value: strategyTest.strategies[len(strategyTest.users)-1+2].ID}, - }) - - assert.Equal(t, api.ExecuteSuccess, resp.Code.GetValue(), resp.Info.GetValue()) - }) - - t.Run("查询鉴权策略-非owner用户查询别人的", func(t *testing.T) { - strategyTest.storage.EXPECT().GetStrategyDetail(gomock.Any()).Return(strategyTest.strategies[2], nil) - - valCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, strategyTest.users[1].Token) - - _ = strategyTest.cacheMgn.TestUpdate() - resp := strategyTest.svr.GetStrategy(valCtx, &apisecurity.AuthStrategy{ - Id: &wrapperspb.StringValue{Value: strategyTest.strategies[2].ID}, - }) - - assert.Equal(t, api.NotAllowedAccess, resp.Code.GetValue(), resp.Info.GetValue()) - }) - - t.Run("查询鉴权策略-目标策略不存在", func(t *testing.T) { - strategyTest.storage.EXPECT().GetStrategyDetail(gomock.Any()).Return(nil, nil) - - valCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, strategyTest.users[1].Token) - - _ = strategyTest.cacheMgn.TestUpdate() - resp := strategyTest.svr.GetStrategy(valCtx, &apisecurity.AuthStrategy{ - Id: &wrapperspb.StringValue{Value: utils.NewUUID()}, - }) - - assert.Equal(t, api.NotFoundAuthStrategyRule, resp.Code.GetValue(), resp.Info.GetValue()) - }) - -} - -func Test_parseStrategySearchArgs(t *testing.T) { - type args struct { - ctx context.Context - searchFilters map[string]string - } - tests := []struct { - name string - args args - want map[string]string - }{ - { - name: "res_type(namespace) 查询", - args: args{ - ctx: func() context.Context { - ctx := context.WithValue(context.Background(), utils.ContextOwnerIDKey, "owner") - ctx = context.WithValue(ctx, utils.ContextUserIDKey, "user") - ctx = context.WithValue(ctx, utils.ContextUserRoleIDKey, authcommon.OwnerUserRole) - ctx = context.WithValue(ctx, utils.ContextIsOwnerKey, true) - return ctx - }(), - searchFilters: map[string]string{ - "res_type": "namespace", - }, - }, - want: map[string]string{ - "res_type": "0", - "owner": "owner", - }, - }, - { - name: "res_type(service) 查询", - args: args{ - ctx: func() context.Context { - ctx := context.WithValue(context.Background(), utils.ContextOwnerIDKey, "owner") - ctx = context.WithValue(ctx, utils.ContextUserIDKey, "user") - ctx = context.WithValue(ctx, utils.ContextUserRoleIDKey, authcommon.OwnerUserRole) - ctx = context.WithValue(ctx, utils.ContextIsOwnerKey, true) - return ctx - }(), - searchFilters: map[string]string{ - "res_type": "service", - }, - }, - want: map[string]string{ - "res_type": "1", - "owner": "owner", - }, - }, - { - name: "principal_type(user) 查询", - args: args{ - ctx: func() context.Context { - ctx := context.WithValue(context.Background(), utils.ContextOwnerIDKey, "owner") - ctx = context.WithValue(ctx, utils.ContextUserIDKey, "user") - ctx = context.WithValue(ctx, utils.ContextUserRoleIDKey, authcommon.SubAccountUserRole) - ctx = context.WithValue(ctx, utils.ContextIsOwnerKey, false) - return ctx - }(), - searchFilters: map[string]string{ - "principal_type": "user", - }, - }, - want: map[string]string{ - "principal_type": "1", - "owner": "owner", - "principal_id": "user", - }, - }, - { - name: "principal_type(group) 查询", - args: args{ - ctx: func() context.Context { - ctx := context.WithValue(context.Background(), utils.ContextOwnerIDKey, "owner") - ctx = context.WithValue(ctx, utils.ContextUserIDKey, "user") - ctx = context.WithValue(ctx, utils.ContextUserRoleIDKey, authcommon.OwnerUserRole) - ctx = context.WithValue(ctx, utils.ContextIsOwnerKey, true) - return ctx - }(), - searchFilters: map[string]string{ - "principal_type": "group", - }, - }, - want: map[string]string{ - "principal_type": "2", - "owner": "owner", - }, - }, - { - name: "按照资源ID查询", - args: args{ - ctx: func() context.Context { - ctx := context.WithValue(context.Background(), utils.ContextOwnerIDKey, "owner") - ctx = context.WithValue(ctx, utils.ContextUserIDKey, "user") - ctx = context.WithValue(ctx, utils.ContextUserRoleIDKey, authcommon.OwnerUserRole) - ctx = context.WithValue(ctx, utils.ContextIsOwnerKey, true) - return ctx - }(), - searchFilters: map[string]string{ - "res_id": "res_id", - }, - }, - want: map[string]string{ - "res_id": "res_id", - }, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if got := policy.ParseStrategySearchArgs(tt.args.ctx, tt.args.searchFilters); !reflect.DeepEqual(got, tt.want) { - t.Errorf("parseStrategySearchArgs() = %v, want %v", got, tt.want) - } - }) - } -} - -func Test_AuthServer_NormalOperateStrategy(t *testing.T) { - suit := &AuthTestSuit{} - if err := suit.Initialize(); err != nil { - t.Fatal(err) - } - t.Cleanup(func() { - suit.cleanAllAuthStrategy() - suit.cleanAllUser() - suit.cleanAllUserGroup() - suit.Destroy() - }) - - users := createApiMockUser(10, "test") - - t.Run("正常创建用户", func(t *testing.T) { - resp := suit.UserServer().CreateUsers(suit.DefaultCtx, users) - - if !respSuccess(resp) { - t.Fatal(resp.GetInfo().GetValue()) - } - }) - - t.Run("正常更新用户", func(t *testing.T) { - users[0].Comment = utils.NewStringValue("update user comment") - resp := suit.UserServer().UpdateUser(suit.DefaultCtx, users[0]) - - if !respSuccess(resp) { - t.Fatal(resp.GetInfo().GetValue()) - } - - qresp := suit.UserServer().GetUsers(suit.DefaultCtx, map[string]string{ - "id": users[0].GetId().GetValue(), - }) - - if !respSuccess(resp) { - t.Fatal(resp.GetInfo().GetValue()) - } - - assert.Equal(t, 1, int(qresp.Amount.GetValue())) - assert.Equal(t, 1, int(qresp.Size.GetValue())) - - retUsers := qresp.GetUsers()[0] - assert.Equal(t, users[0].GetComment().GetValue(), retUsers.GetComment().GetValue()) - }) - - t.Run("正常删除用户", func(t *testing.T) { - resp := suit.UserServer().DeleteUsers(suit.DefaultCtx, []*apisecurity.User{users[3]}) - - if !respSuccess(resp) { - t.Fatal(resp.GetInfo().GetValue()) - } - - qresp := suit.UserServer().GetUsers(suit.DefaultCtx, map[string]string{ - "id": users[3].GetId().GetValue(), - }) - - if !respSuccess(resp) { - t.Fatal(resp.GetInfo().GetValue()) - } - - assert.Equal(t, 0, int(qresp.Amount.GetValue())) - assert.Equal(t, 0, int(qresp.Size.GetValue())) - }) - - t.Run("正常更新用户Token", func(t *testing.T) { - resp := suit.UserServer().ResetUserToken(suit.DefaultCtx, users[0]) - if !api.IsSuccess(resp) { - t.Fatal(resp.GetInfo().GetValue()) - } - - _ = suit.CacheMgr().TestUpdate() - - qresp := suit.UserServer().GetUserToken(suit.DefaultCtx, users[0]) - if !api.IsSuccess(qresp) { - t.Fatal(qresp.String()) - } - assert.Equal(t, resp.GetUser().GetAuthToken().GetValue(), qresp.GetUser().GetAuthToken().GetValue()) - }) -} diff --git a/auth/user/common_test.go b/auth/user/common_test.go deleted file mode 100644 index a5479c62a..000000000 --- a/auth/user/common_test.go +++ /dev/null @@ -1,407 +0,0 @@ -/** - * Tencent is pleased to support the open source community by making Polaris available. - * - * Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. - * - * Licensed under the BSD 3-Clause License (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://opensource.org/licenses/BSD-3-Clause - * - * Unless required by applicable law or agreed to in writing, software distributed - * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR - * CONDITIONS OF ANY KIND, either express or implied. See the License for the - * specific language governing permissions and limitations under the License. - */ - -package defaultuser_test - -import ( - "fmt" - "time" - - "github.com/golang/mock/gomock" - "github.com/google/uuid" - apisecurity "github.com/polarismesh/specification/source/go/api/v1/security" - "github.com/polarismesh/specification/source/go/api/v1/service_manage" - "golang.org/x/crypto/bcrypt" - "google.golang.org/protobuf/types/known/wrapperspb" - - defaultuser "github.com/polarismesh/polaris/auth/user" - "github.com/polarismesh/polaris/cache" - "github.com/polarismesh/polaris/common/metrics" - "github.com/polarismesh/polaris/common/model" - authcommon "github.com/polarismesh/polaris/common/model/auth" - "github.com/polarismesh/polaris/common/utils" - storemock "github.com/polarismesh/polaris/store/mock" -) - -var ( - _defaultSalt = "polarismesh@2021" -) - -func reset(strict bool) { -} - -func initCache(ctrl *gomock.Controller) (*cache.Config, *storemock.MockStore) { - metrics.InitMetrics() - /* - - name: service # 加载服务数据 - option: - disableBusiness: false # 不加载业务服务 - needMeta: true # 加载服务元数据 - - name: instance # 加载实例数据 - option: - disableBusiness: false # 不加载业务服务实例 - needMeta: true # 加载实例元数据 - - name: routingConfig # 加载路由数据 - - name: rateLimitConfig # 加载限流数据 - - name: circuitBreakerConfig # 加载熔断数据 - - name: l5 # 加载l5数据 - - name: users - - name: strategyRule - - name: namespace - */ - cfg := &cache.Config{} - storage := storemock.NewMockStore(ctrl) - - mockTx := storemock.NewMockTx(ctrl) - mockTx.EXPECT().Commit().Return(nil).AnyTimes() - mockTx.EXPECT().Rollback().Return(nil).AnyTimes() - mockTx.EXPECT().CreateReadView().Return(nil).AnyTimes() - - storage.EXPECT().StartReadTx().Return(mockTx, nil).AnyTimes() - storage.EXPECT().GetServicesCount().AnyTimes().Return(uint32(1), nil) - storage.EXPECT().GetInstancesCountTx(gomock.Any()).AnyTimes().Return(uint32(1), nil) - storage.EXPECT().GetMoreInstances(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(map[string]*model.Instance{ - "123": { - Proto: &service_manage.Instance{ - Id: wrapperspb.String(uuid.NewString()), - Host: wrapperspb.String("127.0.0.1"), - Port: wrapperspb.UInt32(8080), - }, - Valid: true, - }, - }, nil).AnyTimes() - storage.EXPECT().GetUnixSecond(gomock.Any()).AnyTimes().Return(time.Now().Unix(), nil) - - return cfg, storage -} - -func createMockNamespace(total int, owner string) []*model.Namespace { - namespaces := make([]*model.Namespace, 0, total) - - for i := 0; i < total; i++ { - namespaces = append(namespaces, &model.Namespace{ - Name: fmt.Sprintf("namespace_%d", i), - Owner: owner, - Valid: true, - }) - } - - return namespaces -} - -func createMockService(namespaces []*model.Namespace) []*model.Service { - services := make([]*model.Service, 0, len(namespaces)) - - for i := 0; i < len(namespaces); i++ { - ns := namespaces[i] - services = append(services, &model.Service{ - ID: utils.NewUUID(), - Namespace: ns.Name, - Owner: ns.Owner, - Name: fmt.Sprintf("service_%d", i), - Valid: true, - }) - } - - return services -} - -// createMockUser 默认 users[0] 为 owner 用户 -func createMockUser(total int, prefix ...string) []*authcommon.User { - users := make([]*authcommon.User, 0, total) - - ownerId := utils.NewUUID() - - nameTemp := "user-%d" - if len(prefix) != 0 { - nameTemp = prefix[0] + nameTemp - } - - for i := 0; i < total; i++ { - id := fmt.Sprintf("fake-user-id-%d-%s", i, utils.NewUUID()) - if i == 0 { - id = ownerId - } - pwd, _ := bcrypt.GenerateFromPassword([]byte("polaris"), bcrypt.DefaultCost) - token, _ := defaultuser.CreateToken(id, "", "polarismesh@2021") - users = append(users, &authcommon.User{ - ID: id, - Name: fmt.Sprintf(nameTemp, i), - Password: string(pwd), - Owner: func() string { - if id == ownerId { - return "" - } - return ownerId - }(), - Source: "Polaris", - Mobile: "", - Email: "", - Type: func() authcommon.UserRoleType { - if id == ownerId { - return authcommon.OwnerUserRole - } - return authcommon.SubAccountUserRole - }(), - Token: token, - TokenEnable: true, - Valid: true, - CreateTime: time.Time{}, - ModifyTime: time.Time{}, - }) - } - return users -} - -func createApiMockUser(total int, prefix ...string) []*apisecurity.User { - users := make([]*apisecurity.User, 0, total) - - models := createMockUser(total, prefix...) - - for i := range models { - users = append(users, &apisecurity.User{ - Name: utils.NewStringValue("test-" + models[i].Name), - Password: utils.NewStringValue("123456"), - Source: utils.NewStringValue("Polaris"), - Comment: utils.NewStringValue(models[i].Comment), - Mobile: utils.NewStringValue(models[i].Mobile), - Email: utils.NewStringValue(models[i].Email), - }) - } - - return users -} - -func createMockUserGroup(users []*authcommon.User) []*authcommon.UserGroupDetail { - groups := make([]*authcommon.UserGroupDetail, 0, len(users)) - - for i := range users { - user := users[i] - id := utils.NewUUID() - - token, _ := defaultuser.CreateToken("", id, _defaultSalt) - - groups = append(groups, &authcommon.UserGroupDetail{ - UserGroup: &authcommon.UserGroup{ - ID: id, - Name: fmt.Sprintf("test-group-%d", i), - Owner: users[0].ID, - Token: token, - TokenEnable: true, - Valid: true, - Comment: "", - CreateTime: time.Time{}, - ModifyTime: time.Time{}, - }, - UserIds: map[string]struct{}{ - user.ID: {}, - }, - }) - } - - return groups -} - -// createMockApiUserGroup -func createMockApiUserGroup(users []*apisecurity.User) []*apisecurity.UserGroup { - musers := make([]*authcommon.User, 0, len(users)) - for i := range users { - musers = append(musers, &authcommon.User{ - ID: users[i].GetId().GetValue(), - }) - } - - models := createMockUserGroup(musers) - ret := make([]*apisecurity.UserGroup, 0, len(models)) - - for i := range models { - ret = append(ret, &apisecurity.UserGroup{ - Name: utils.NewStringValue(models[i].Name), - Comment: utils.NewStringValue(models[i].Comment), - Relation: &apisecurity.UserGroupRelation{ - Users: []*apisecurity.User{ - { - Id: utils.NewStringValue(users[i].GetId().GetValue()), - }, - }, - }, - }) - } - - return ret -} - -func createMockStrategy(users []*authcommon.User, groups []*authcommon.UserGroupDetail, services []*model.Service) ([]*authcommon.StrategyDetail, []*authcommon.StrategyDetail) { - strategies := make([]*authcommon.StrategyDetail, 0, len(users)+len(groups)) - defaultStrategies := make([]*authcommon.StrategyDetail, 0, len(users)+len(groups)) - - owner := "" - for i := 0; i < len(users); i++ { - user := users[i] - if user.Owner == "" { - owner = user.ID - break - } - } - - for i := 0; i < len(users); i++ { - user := users[i] - service := services[i] - id := utils.NewUUID() - strategies = append(strategies, &authcommon.StrategyDetail{ - ID: id, - Name: fmt.Sprintf("strategy_user_%s_%d", user.Name, i), - Action: apisecurity.AuthAction_READ_WRITE.String(), - Comment: "", - Principals: []authcommon.Principal{ - { - PrincipalID: user.ID, - PrincipalType: authcommon.PrincipalUser, - }, - }, - Default: false, - Owner: owner, - Resources: []authcommon.StrategyResource{ - { - StrategyID: id, - ResType: int32(apisecurity.ResourceType_Namespaces), - ResID: service.Namespace, - }, - { - StrategyID: id, - ResType: int32(apisecurity.ResourceType_Services), - ResID: service.ID, - }, - }, - Valid: true, - Revision: utils.NewUUID(), - CreateTime: time.Time{}, - ModifyTime: time.Time{}, - }) - - defaultStrategies = append(defaultStrategies, &authcommon.StrategyDetail{ - ID: id, - Name: fmt.Sprintf("strategy_default_user_%s_%d", user.Name, i), - Action: apisecurity.AuthAction_READ_WRITE.String(), - Comment: "", - Principals: []authcommon.Principal{ - { - PrincipalID: user.ID, - PrincipalType: authcommon.PrincipalUser, - }, - }, - Default: true, - Owner: owner, - Resources: []authcommon.StrategyResource{ - { - StrategyID: id, - ResType: int32(apisecurity.ResourceType_Namespaces), - ResID: service.Namespace, - }, - { - StrategyID: id, - ResType: int32(apisecurity.ResourceType_Services), - ResID: service.ID, - }, - }, - Valid: true, - Revision: utils.NewUUID(), - CreateTime: time.Time{}, - ModifyTime: time.Time{}, - }) - } - - for i := 0; i < len(groups); i++ { - group := groups[i] - service := services[len(users)+i] - id := utils.NewUUID() - strategies = append(strategies, &authcommon.StrategyDetail{ - ID: id, - Name: fmt.Sprintf("strategy_group_%s_%d", group.Name, i), - Action: apisecurity.AuthAction_READ_WRITE.String(), - Comment: "", - Principals: []authcommon.Principal{ - { - PrincipalID: group.ID, - PrincipalType: authcommon.PrincipalGroup, - }, - }, - Default: false, - Owner: owner, - Resources: []authcommon.StrategyResource{ - { - StrategyID: id, - ResType: int32(apisecurity.ResourceType_Namespaces), - ResID: service.Namespace, - }, - { - StrategyID: id, - ResType: int32(apisecurity.ResourceType_Services), - ResID: service.ID, - }, - }, - Valid: true, - Revision: utils.NewUUID(), - CreateTime: time.Time{}, - ModifyTime: time.Time{}, - }) - - defaultStrategies = append(defaultStrategies, &authcommon.StrategyDetail{ - ID: id, - Name: fmt.Sprintf("strategy_default_group_%s_%d", group.Name, i), - Action: apisecurity.AuthAction_READ_WRITE.String(), - Comment: "", - Principals: []authcommon.Principal{ - { - PrincipalID: group.ID, - PrincipalType: authcommon.PrincipalGroup, - }, - }, - Default: true, - Owner: owner, - Resources: []authcommon.StrategyResource{ - { - StrategyID: id, - ResType: int32(apisecurity.ResourceType_Namespaces), - ResID: service.Namespace, - }, - { - StrategyID: id, - ResType: int32(apisecurity.ResourceType_Services), - ResID: service.ID, - }, - }, - Valid: true, - Revision: utils.NewUUID(), - CreateTime: time.Time{}, - ModifyTime: time.Time{}, - }) - } - - return defaultStrategies, strategies -} - -func convertServiceSliceToMap(services []*model.Service) map[string]*model.Service { - ret := make(map[string]*model.Service) - - for i := range services { - service := services[i] - ret[service.ID] = service - } - - return ret -} diff --git a/auth/user/group.go b/auth/user/group.go index 15dbbe487..18045a3dd 100644 --- a/auth/user/group.go +++ b/auth/user/group.go @@ -282,8 +282,6 @@ func (svr *Server) EnableGroupToken(ctx context.Context, req *apisecurity.UserGr // ResetGroupToken 刷新用户组的token func (svr *Server) ResetGroupToken(ctx context.Context, req *apisecurity.UserGroup) *apiservice.Response { var ( - requestID = utils.ParseRequestID(ctx) - platformID = utils.ParsePlatformID(ctx) group, errResp = svr.getGroupFromDB(req.Id.GetValue()) ) @@ -297,8 +295,7 @@ func (svr *Server) ResetGroupToken(ctx context.Context, req *apisecurity.UserGro newToken, err := createGroupToken(group.ID, svr.authOpt.Salt) if err != nil { - log.Error("reset group token", utils.ZapRequestID(requestID), - utils.ZapPlatformID(platformID), zap.Error(err)) + log.Error("reset group token", utils.RequestID(ctx), zap.Error(err)) return api.NewAuthResponseWithMsg(apimodel.Code_ExecuteException, err.Error()) } @@ -312,12 +309,12 @@ func (svr *Server) ResetGroupToken(ctx context.Context, req *apisecurity.UserGro } if err := svr.storage.UpdateGroup(modifyReq); err != nil { - log.Error(err.Error(), utils.ZapRequestID(requestID), utils.ZapPlatformID(platformID)) + log.Error(err.Error(), utils.RequestID(ctx)) return api.NewAuthResponseWithMsg(commonstore.StoreCode2APICode(err), err.Error()) } log.Info("reset group token", zap.String("group-id", req.Id.GetValue()), - utils.ZapRequestID(requestID), utils.ZapPlatformID(platformID)) + utils.RequestID(ctx)) svr.RecordHistory(userGroupRecordEntry(ctx, req, group.UserGroup, model.OUpdate)) req.AuthToken = utils.NewStringValue(newToken) diff --git a/auth/user/group_test.go b/auth/user/group_test.go deleted file mode 100644 index 087a95439..000000000 --- a/auth/user/group_test.go +++ /dev/null @@ -1,822 +0,0 @@ -/** - * Tencent is pleased to support the open source community by making Polaris available. - * - * Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. - * - * Licensed under the BSD 3-Clause License (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://opensource.org/licenses/BSD-3-Clause - * - * Unless required by applicable law or agreed to in writing, software distributed - * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR - * CONDITIONS OF ANY KIND, either express or implied. See the License for the - * specific language governing permissions and limitations under the License. - */ - -package defaultuser_test - -import ( - "context" - "testing" - "time" - - "github.com/golang/mock/gomock" - apisecurity "github.com/polarismesh/specification/source/go/api/v1/security" - "github.com/stretchr/testify/assert" - "google.golang.org/protobuf/types/known/wrapperspb" - - "github.com/polarismesh/polaris/auth" - defaultauth "github.com/polarismesh/polaris/auth/user" - "github.com/polarismesh/polaris/cache" - cachetypes "github.com/polarismesh/polaris/cache/api" - v1 "github.com/polarismesh/polaris/common/api/v1" - authcommon "github.com/polarismesh/polaris/common/model/auth" - "github.com/polarismesh/polaris/common/utils" - storemock "github.com/polarismesh/polaris/store/mock" -) - -type GroupTest struct { - ctrl *gomock.Controller - - ownerOne *authcommon.User - ownerTwo *authcommon.User - - users []*authcommon.User - groups []*authcommon.UserGroupDetail - newGroups []*authcommon.UserGroupDetail - allGroups []*authcommon.UserGroupDetail - - storage *storemock.MockStore - cacheMgn *cache.CacheManager - checker auth.AuthChecker - cancel context.CancelFunc - - svr auth.UserServer -} - -func newGroupTest(t *testing.T) *GroupTest { - reset(false) - ctrl := gomock.NewController(t) - - users := createMockUser(10) - groups := createMockUserGroup(users) - - newUsers := createMockUser(10) - newGroups := createMockUserGroup(newUsers) - - allGroups := append(groups, newGroups...) - - storage := storemock.NewMockStore(ctrl) - - storage.EXPECT().GetServicesCount().AnyTimes().Return(uint32(1), nil) - storage.EXPECT().GetUnixSecond(gomock.Any()).AnyTimes().Return(time.Now().Unix(), nil) - storage.EXPECT().AddGroup(gomock.Any(), gomock.Any()).AnyTimes().Return(nil) - storage.EXPECT().UpdateUser(gomock.Any()).AnyTimes().Return(nil) - storage.EXPECT().GetUsersForCache(gomock.Any(), gomock.Any()).AnyTimes().Return(append(users, newUsers...), nil) - storage.EXPECT().GetGroupsForCache(gomock.Any(), gomock.Any()).AnyTimes().Return(allGroups, nil) - - cfg := &cache.Config{} - - ctx, cancel := context.WithCancel(context.Background()) - cacheMgn, err := cache.TestCacheInitialize(ctx, cfg, storage) - if err != nil { - t.Error(err) - } - _ = cacheMgn.OpenResourceCache([]cachetypes.ConfigEntry{ - { - Name: cachetypes.UsersName, - }, - }...) - t.Cleanup(func() { - _ = cacheMgn.Close() - }) - - _, proxySvr, err := defaultauth.BuildServer() - if err != nil { - t.Fatal(err) - } - proxySvr.Initialize(&auth.Config{ - User: &auth.UserConfig{ - Name: auth.DefaultUserMgnPluginName, - Option: map[string]interface{}{ - "salt": "polarismesh@2021", - }, - }, - }, storage, nil, cacheMgn) - _ = cacheMgn.TestUpdate() - return &GroupTest{ - ctrl: ctrl, - ownerOne: users[0], - ownerTwo: newUsers[0], - users: users, - groups: groups, - newGroups: newGroups, - allGroups: allGroups, - storage: storage, - cacheMgn: cacheMgn, - cancel: cancel, - svr: proxySvr, - } -} - -func (g *GroupTest) Clean() { - g.cancel() - g.cacheMgn.Close() - g.ctrl.Finish() -} - -func Test_server_CreateGroup(t *testing.T) { - groupTest := newGroupTest(t) - - defer groupTest.Clean() - - t.Run("正常创建用户组", func(t *testing.T) { - reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, groupTest.users[0].Token) - - groups := createMockUserGroup(groupTest.users[:1]) - groups[0].ID = utils.NewUUID() - - groupTest.storage.EXPECT().GetGroupByName(gomock.Any(), gomock.Any()).Return(nil, nil) - - resp := groupTest.svr.CreateGroup(reqCtx, &apisecurity.UserGroup{ - Id: utils.NewStringValue(groups[0].ID), - Name: utils.NewStringValue(groups[0].Name), - }) - - assert.True(t, resp.GetCode().Value == v1.ExecuteSuccess, resp.Info.GetValue()) - }) - - t.Run("用户组已存在", func(t *testing.T) { - reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, groupTest.users[0].Token) - - groups := createMockUserGroup(groupTest.users[:1]) - groups[0].ID = utils.NewUUID() - - groupTest.storage.EXPECT().GetGroupByName(gomock.Any(), gomock.Any()).Return(groups[0].UserGroup, nil) - - resp := groupTest.svr.CreateGroup(reqCtx, &apisecurity.UserGroup{ - Id: utils.NewStringValue(groups[0].ID), - Name: utils.NewStringValue(groups[0].Name), - }) - - assert.True(t, resp.GetCode().Value == v1.UserGroupExisted, resp.Info.GetValue()) - }) - - t.Run("子用户去创建用户组", func(t *testing.T) { - reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, groupTest.users[1].Token) - - groups := createMockUserGroup(groupTest.users[:1]) - groups[0].ID = utils.NewUUID() - - resp := groupTest.svr.CreateGroup(reqCtx, &apisecurity.UserGroup{ - Id: utils.NewStringValue(groups[0].ID), - Name: utils.NewStringValue(groups[0].Name), - }) - - assert.True(t, resp.GetCode().Value == v1.OperationRoleException, resp.Info.GetValue()) - }) - - t.Run("主账户去查询owner为自己的用户组", func(t *testing.T) { - reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, groupTest.users[0].Token) - - groupTest.storage.EXPECT().GetGroup(gomock.Any()).Return(groupTest.groups[1], nil) - - resp := groupTest.svr.GetGroup(reqCtx, &apisecurity.UserGroup{ - Id: utils.NewStringValue(groupTest.groups[1].ID), - }) - - assert.True(t, resp.GetCode().Value == v1.ExecuteSuccess, resp.Info.GetValue()) - }) - - t.Run("主账户去查询owner不是自己的用户组", func(t *testing.T) { - reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, groupTest.ownerTwo.Token) - - groupTest.storage.EXPECT().GetGroup(gomock.Any()).Return(groupTest.groups[3], nil) - - resp := groupTest.svr.GetGroup(reqCtx, &apisecurity.UserGroup{ - Id: utils.NewStringValue(groupTest.groups[3].ID), - }) - - assert.True(t, resp.GetCode().Value == v1.NotAllowedAccess, resp.Info.GetValue()) - }) - - t.Run("子账户去查询自己所在的用户组", func(t *testing.T) { - reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, groupTest.users[1].Token) - - groupTest.storage.EXPECT().GetGroup(gomock.Any()).Return(groupTest.groups[1], nil).AnyTimes() - - resp := groupTest.svr.GetGroup(reqCtx, &apisecurity.UserGroup{ - Id: utils.NewStringValue(groupTest.groups[1].ID), - }) - - assert.True(t, resp.GetCode().Value == v1.ExecuteSuccess, resp.Info.GetValue()) - }) - - t.Run("子账户去查询自己不在的用户组", func(t *testing.T) { - reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, groupTest.users[1].Token) - - groupTest.storage.EXPECT().GetGroup(gomock.Any()).Return(groupTest.groups[2], nil).AnyTimes() - - resp := groupTest.svr.GetGroup(reqCtx, &apisecurity.UserGroup{ - Id: utils.NewStringValue(groupTest.groups[2].ID), - }) - - assert.True(t, resp.GetCode().Value == v1.NotAllowedAccess, resp.Info.GetValue()) - }) -} - -func Test_server_GetGroup(t *testing.T) { - groupTest := newGroupTest(t) - - defer groupTest.Clean() - t.Run("主账户去查询owner为自己的用户组", func(t *testing.T) { - reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, groupTest.users[0].Token) - - groupTest.storage.EXPECT().GetGroup(gomock.Any()).Return(groupTest.groups[1], nil) - - resp := groupTest.svr.GetGroup(reqCtx, &apisecurity.UserGroup{ - Id: utils.NewStringValue(groupTest.groups[1].ID), - }) - - assert.True(t, resp.GetCode().Value == v1.ExecuteSuccess, resp.Info.GetValue()) - }) - - t.Run("主账户去查询owner不是自己的用户组", func(t *testing.T) { - reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, groupTest.ownerOne.Token) - - groupTest.storage.EXPECT().GetGroup(gomock.Any()).Return(groupTest.newGroups[0], nil) - - resp := groupTest.svr.GetGroup(reqCtx, &apisecurity.UserGroup{ - Id: utils.NewStringValue(groupTest.newGroups[0].ID), - }) - - assert.True(t, resp.GetCode().Value == v1.NotAllowedAccess, resp.Info.GetValue()) - }) - - t.Run("子账户去查询自己所在的用户组", func(t *testing.T) { - reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, groupTest.users[1].Token) - - groupTest.storage.EXPECT().GetGroup(gomock.Any()).Return(groupTest.groups[1], nil).AnyTimes() - - resp := groupTest.svr.GetGroup(reqCtx, &apisecurity.UserGroup{ - Id: utils.NewStringValue(groupTest.groups[1].ID), - }) - - assert.True(t, resp.GetCode().Value == v1.ExecuteSuccess, resp.Info.GetValue()) - }) - - t.Run("子账户去查询自己不在的用户组", func(t *testing.T) { - reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, groupTest.users[1].Token) - - groupTest.storage.EXPECT().GetGroup(gomock.Any()).Return(groupTest.groups[2], nil).AnyTimes() - - resp := groupTest.svr.GetGroup(reqCtx, &apisecurity.UserGroup{ - Id: utils.NewStringValue(groupTest.groups[2].ID), - }) - - assert.True(t, resp.GetCode().Value == v1.NotAllowedAccess, resp.Info.GetValue()) - }) -} - -func Test_server_UpdateGroup(t *testing.T) { - t.Run("主账户更新用户组", func(t *testing.T) { - groupTest := newGroupTest(t) - - t.Cleanup(func() { - groupTest.Clean() - }) - reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, groupTest.ownerOne.Token) - - groupTest.storage.EXPECT().GetGroup(gomock.Any()).Return(groupTest.groups[1], nil) - groupTest.storage.EXPECT().UpdateGroup(gomock.Any()).Return(nil) - - req := &apisecurity.ModifyUserGroup{ - Id: utils.NewStringValue(groupTest.groups[1].ID), - Comment: &wrapperspb.StringValue{ - Value: "new test group", - }, - AddRelations: &apisecurity.UserGroupRelation{ - GroupId: utils.NewStringValue(groupTest.groups[1].ID), - Users: []*apisecurity.User{ - { - Id: utils.NewStringValue(groupTest.users[2].ID), - }, - { - Id: utils.NewStringValue(groupTest.users[3].ID), - }, - }, - }, - RemoveRelations: &apisecurity.UserGroupRelation{ - GroupId: utils.NewStringValue(groupTest.groups[1].ID), - Users: []*apisecurity.User{ - { - Id: utils.NewStringValue(groupTest.users[5].ID), - }, - }, - }, - } - - resp := groupTest.svr.UpdateGroups(reqCtx, []*apisecurity.ModifyUserGroup{req}) - - assert.True(t, resp.Responses[0].Code.GetValue() == v1.ExecuteSuccess, resp.Responses[0].Info.GetValue()) - }) - - t.Run("主账户更新不是自己负责的用户组", func(t *testing.T) { - groupTest := newGroupTest(t) - - t.Cleanup(func() { - groupTest.Clean() - }) - reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, groupTest.ownerOne.Token) - - groupTest.storage.EXPECT().GetGroup(gomock.Any()).AnyTimes().Return(groupTest.newGroups[1], nil) - - req := &apisecurity.ModifyUserGroup{ - Id: utils.NewStringValue(groupTest.newGroups[0].ID), - Comment: &wrapperspb.StringValue{ - Value: "new test group", - }, - AddRelations: &apisecurity.UserGroupRelation{ - GroupId: utils.NewStringValue(groupTest.groups[0].ID), - Users: []*apisecurity.User{ - { - Id: utils.NewStringValue(groupTest.users[2].ID), - }, - { - Id: utils.NewStringValue(groupTest.users[3].ID), - }, - }, - }, - RemoveRelations: &apisecurity.UserGroupRelation{ - GroupId: utils.NewStringValue(groupTest.groups[0].ID), - Users: []*apisecurity.User{ - { - Id: utils.NewStringValue(groupTest.users[5].ID), - }, - }, - }, - } - - resp := groupTest.svr.UpdateGroups(reqCtx, []*apisecurity.ModifyUserGroup{req}) - assert.True(t, resp.Responses[0].Code.GetValue() == v1.NotAllowedAccess, resp.Responses[0].Info.GetValue()) - }) - - t.Run("子账户更新用户组", func(t *testing.T) { - groupTest := newGroupTest(t) - - t.Cleanup(func() { - groupTest.Clean() - }) - reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, groupTest.users[1].Token) - - req := &apisecurity.ModifyUserGroup{ - Id: utils.NewStringValue(groupTest.groups[2].ID), - Comment: &wrapperspb.StringValue{ - Value: "new test group", - }, - AddRelations: &apisecurity.UserGroupRelation{ - GroupId: utils.NewStringValue(groupTest.groups[2].ID), - Users: []*apisecurity.User{ - { - Id: utils.NewStringValue(groupTest.users[2].ID), - }, - { - Id: utils.NewStringValue(groupTest.users[3].ID), - }, - }, - }, - RemoveRelations: &apisecurity.UserGroupRelation{ - GroupId: utils.NewStringValue(groupTest.groups[2].ID), - Users: []*apisecurity.User{ - { - Id: utils.NewStringValue(groupTest.users[5].ID), - }, - }, - }, - } - - resp := groupTest.svr.UpdateGroups(reqCtx, []*apisecurity.ModifyUserGroup{req}) - assert.True(t, resp.Responses[0].GetCode().Value == v1.OperationRoleException, resp.Responses[0].Info.GetValue()) - }) - - t.Run("更新用户组-啥都没用动过", func(t *testing.T) { - groupTest := newGroupTest(t) - - t.Cleanup(func() { - groupTest.Clean() - }) - reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, groupTest.users[0].Token) - groupTest.storage.EXPECT().GetGroup(gomock.Any()).Return(groupTest.groups[2], nil) - - req := &apisecurity.ModifyUserGroup{ - Id: utils.NewStringValue(groupTest.groups[2].ID), - Comment: &wrapperspb.StringValue{Value: groupTest.groups[2].Comment}, - AddRelations: &apisecurity.UserGroupRelation{}, - RemoveRelations: &apisecurity.UserGroupRelation{}, - } - - resp := groupTest.svr.UpdateGroups(reqCtx, []*apisecurity.ModifyUserGroup{req}) - assert.True(t, resp.Responses[0].GetCode().Value == v1.NoNeedUpdate, resp.Responses[0].Info.GetValue()) - }) - -} - -func Test_server_GetGroupToken(t *testing.T) { - t.Run("主账户去查询owner为自己的用户组", func(t *testing.T) { - groupTest := newGroupTest(t) - - t.Cleanup(func() { - groupTest.Clean() - }) - reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, groupTest.users[0].Token) - - resp := groupTest.svr.GetGroupToken(reqCtx, &apisecurity.UserGroup{ - Id: utils.NewStringValue(groupTest.groups[1].ID), - }) - - assert.True(t, resp.GetCode().Value == v1.ExecuteSuccess, resp.Info.GetValue()) - }) - - t.Run("主账户去查询owner不是自己的用户组", func(t *testing.T) { - groupTest := newGroupTest(t) - - t.Cleanup(func() { - groupTest.Clean() - }) - reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, groupTest.ownerTwo.Token) - - resp := groupTest.svr.GetGroupToken(reqCtx, &apisecurity.UserGroup{ - Id: utils.NewStringValue(groupTest.groups[1].ID), - }) - - assert.True(t, resp.GetCode().Value == v1.NotAllowedAccess, resp.Info.GetValue()) - }) - - t.Run("子账户去查询自己所在的用户组", func(t *testing.T) { - groupTest := newGroupTest(t) - - t.Cleanup(func() { - groupTest.Clean() - }) - reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, groupTest.users[1].Token) - - resp := groupTest.svr.GetGroupToken(reqCtx, &apisecurity.UserGroup{ - Id: utils.NewStringValue(groupTest.groups[1].ID), - }) - - assert.True(t, resp.GetCode().Value == v1.ExecuteSuccess, resp.Info.GetValue()) - }) - - t.Run("子账户去查询自己不在的用户组", func(t *testing.T) { - groupTest := newGroupTest(t) - - t.Cleanup(func() { - groupTest.Clean() - }) - reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, groupTest.users[1].Token) - - resp := groupTest.svr.GetGroupToken(reqCtx, &apisecurity.UserGroup{ - Id: utils.NewStringValue(groupTest.groups[2].ID), - }) - - assert.True(t, resp.GetCode().Value == v1.NotAllowedAccess, resp.Info.GetValue()) - }) -} - -func Test_server_DeleteGroup(t *testing.T) { - - t.Run("正常删除用户组", func(t *testing.T) { - groupTest := newGroupTest(t) - - t.Cleanup(func() { - groupTest.Clean() - }) - reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, groupTest.ownerOne.Token) - - groupTest.storage.EXPECT().GetGroup(gomock.Any()).AnyTimes().Return(groupTest.groups[0], nil) - groupTest.storage.EXPECT().DeleteGroup(gomock.Any(), gomock.Any()).AnyTimes().Return(nil) - - batchResp := groupTest.svr.DeleteGroups(reqCtx, []*apisecurity.UserGroup{ - { - Id: utils.NewStringValue(groupTest.groups[0].ID), - }, - }) - - assert.True(t, batchResp.GetCode().Value == v1.ExecuteSuccess, batchResp.Info.GetValue()) - }) - - t.Run("删除用户组-用户组不存在", func(t *testing.T) { - groupTest := newGroupTest(t) - - t.Cleanup(func() { - groupTest.Clean() - }) - reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, groupTest.ownerOne.Token) - - groupTest.storage.EXPECT().GetGroup(gomock.Any()).AnyTimes().Return(nil, nil) - groupTest.storage.EXPECT().DeleteGroup(gomock.Any(), gomock.Any()).AnyTimes().Return(nil) - - batchResp := groupTest.svr.DeleteGroups(reqCtx, []*apisecurity.UserGroup{ - { - Id: utils.NewStringValue(groupTest.groups[0].ID), - }, - }) - - assert.True(t, batchResp.GetCode().Value == v1.ExecuteSuccess, batchResp.Info.GetValue()) - }) - - t.Run("删除用户组-不是用户组的owner", func(t *testing.T) { - groupTest := newGroupTest(t) - - t.Cleanup(func() { - groupTest.Clean() - }) - reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, groupTest.ownerTwo.Token) - - groupTest.storage.EXPECT().GetGroup(gomock.Any()).AnyTimes().Return(groupTest.groups[0], nil) - groupTest.storage.EXPECT().DeleteGroup(gomock.Any(), gomock.Any()).AnyTimes().Return(nil) - - batchResp := groupTest.svr.DeleteGroups(reqCtx, []*apisecurity.UserGroup{ - { - Id: utils.NewStringValue(groupTest.groups[0].ID), - }, - }) - - assert.True(t, batchResp.Responses[0].GetCode().Value == v1.NotAllowedAccess, batchResp.Responses[0].Info.GetValue()) - }) - - t.Run("删除用户组-非owner角色", func(t *testing.T) { - groupTest := newGroupTest(t) - - t.Cleanup(func() { - groupTest.Clean() - }) - reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, groupTest.users[1].Token) - - groupTest.storage.EXPECT().GetGroup(gomock.Any()).AnyTimes().Return(groupTest.groups[0], nil) - groupTest.storage.EXPECT().DeleteGroup(gomock.Any(), gomock.Any()).AnyTimes().Return(nil) - - batchResp := groupTest.svr.DeleteGroups(reqCtx, []*apisecurity.UserGroup{ - { - Id: utils.NewStringValue(groupTest.groups[0].ID), - }, - }) - - assert.True(t, batchResp.Responses[0].GetCode().Value == v1.OperationRoleException, batchResp.Responses[0].Info.GetValue()) - }) - -} - -func Test_server_UpdateGroupToken(t *testing.T) { - t.Run("正常更新用户组Token的Enable状态", func(t *testing.T) { - groupTest := newGroupTest(t) - t.Cleanup(func() { - groupTest.Clean() - }) - reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, groupTest.ownerOne.Token) - - groupTest.storage.EXPECT().GetGroup(gomock.Any()).AnyTimes().Return(groupTest.groups[0], nil) - groupTest.storage.EXPECT().UpdateGroup(gomock.Any()).AnyTimes().Return(nil) - - batchResp := groupTest.svr.EnableGroupToken(reqCtx, &apisecurity.UserGroup{ - Id: utils.NewStringValue(groupTest.groups[2].ID), - }) - - assert.True(t, batchResp.GetCode().Value == v1.ExecuteSuccess, batchResp.Info.GetValue()) - }) - - t.Run("非Owner角色更新用户组Token的Enable状态", func(t *testing.T) { - groupTest := newGroupTest(t) - t.Cleanup(func() { - groupTest.Clean() - }) - reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, groupTest.users[2].Token) - - groupTest.storage.EXPECT().GetGroup(gomock.Any()).AnyTimes().Return(groupTest.groups[0], nil) - - batchResp := groupTest.svr.EnableGroupToken(reqCtx, &apisecurity.UserGroup{ - Id: utils.NewStringValue(groupTest.groups[2].ID), - }) - - assert.True(t, batchResp.Code.Value == v1.OperationRoleException, batchResp.Info.GetValue()) - }) - - t.Run("更新用户组Token的Enable状态-非group的owner", func(t *testing.T) { - groupTest := newGroupTest(t) - t.Cleanup(func() { - groupTest.Clean() - }) - reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, groupTest.ownerTwo.Token) - - groupTest.storage.EXPECT().GetGroup(gomock.Any()).AnyTimes().Return(groupTest.groups[0], nil) - - batchResp := groupTest.svr.EnableGroupToken(reqCtx, &apisecurity.UserGroup{ - Id: utils.NewStringValue(groupTest.groups[2].ID), - }) - - assert.True(t, batchResp.Code.Value == v1.NotAllowedAccess, batchResp.Info.GetValue()) - }) -} - -func Test_server_RefreshGroupToken(t *testing.T) { - t.Run("正常更新用户组Token的Enable状态", func(t *testing.T) { - groupTest := newGroupTest(t) - t.Cleanup(func() { - groupTest.Clean() - }) - reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, groupTest.ownerOne.Token) - - groupTest.storage.EXPECT().GetGroup(gomock.Any()).Return(groupTest.groups[0], nil) - groupTest.storage.EXPECT().UpdateGroup(gomock.Any()).Return(nil) - - batchResp := groupTest.svr.ResetGroupToken(reqCtx, &apisecurity.UserGroup{ - Id: utils.NewStringValue(groupTest.groups[2].ID), - }) - - assert.True(t, batchResp.GetCode().Value == v1.ExecuteSuccess, batchResp.Info.GetValue()) - }) - - t.Run("非Owner角色更新用户组Token的Enable状态", func(t *testing.T) { - groupTest := newGroupTest(t) - t.Cleanup(func() { - groupTest.Clean() - }) - reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, groupTest.users[2].Token) - - groupTest.storage.EXPECT().GetGroup(gomock.Any()).AnyTimes().Return(groupTest.groups[0], nil) - - batchResp := groupTest.svr.ResetGroupToken(reqCtx, &apisecurity.UserGroup{ - Id: utils.NewStringValue(groupTest.groups[2].ID), - }) - - assert.True(t, batchResp.Code.Value == v1.OperationRoleException, batchResp.Info.GetValue()) - }) - - t.Run("更新用户组Token的Enable状态-非group的owner", func(t *testing.T) { - groupTest := newGroupTest(t) - t.Cleanup(func() { - groupTest.Clean() - }) - reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, groupTest.ownerTwo.Token) - - groupTest.storage.EXPECT().GetGroup(gomock.Any()).Return(groupTest.groups[0], nil).AnyTimes() - - batchResp := groupTest.svr.ResetGroupToken(reqCtx, &apisecurity.UserGroup{ - Id: utils.NewStringValue(groupTest.groups[2].ID), - }) - - assert.True(t, batchResp.Code.Value == v1.NotAllowedAccess, batchResp.Info.GetValue()) - }) -} - -func Test_AuthServer_NormalOperateUserGroup(t *testing.T) { - suit := &AuthTestSuit{} - if err := suit.Initialize(); err != nil { - t.Fatal(err) - } - t.Cleanup(func() { - suit.cleanAllAuthStrategy() - suit.cleanAllUser() - suit.cleanAllUserGroup() - suit.Destroy() - }) - - users := createApiMockUser(10, "test") - for i := range users { - users[i].Id = utils.NewStringValue(utils.NewUUID()) - } - - groups := createMockApiUserGroup([]*apisecurity.User{users[0]}) - - t.Run("正常创建用户组", func(t *testing.T) { - bresp := suit.UserServer().CreateUsers(suit.DefaultCtx, users) - if !respSuccess(bresp) { - t.Fatal(bresp.GetInfo().GetValue()) - } - - _ = suit.CacheMgr().TestUpdate() - - resp := suit.UserServer().CreateGroup(suit.DefaultCtx, groups[0]) - - if !respSuccess(resp) { - t.Fatal(resp.GetInfo().GetValue()) - } - - groups[0].Id = utils.NewStringValue(resp.GetUserGroup().Id.Value) - }) - - t.Run("正常更新用户组", func(t *testing.T) { - - time.Sleep(time.Second) - - req := []*apisecurity.ModifyUserGroup{ - { - Id: utils.NewStringValue(groups[0].GetId().GetValue()), - Name: utils.NewStringValue(groups[0].GetName().GetValue()), - Comment: &wrapperspb.StringValue{ - Value: "update user group", - }, - AddRelations: &apisecurity.UserGroupRelation{ - Users: users[3:], - }, - }, - } - - resp := suit.UserServer().UpdateGroups(suit.DefaultCtx, req) - if !respSuccess(resp) { - t.Fatal(resp.GetInfo().GetValue()) - } - - _ = suit.CacheMgr().TestUpdate() - - qresp := suit.UserServer().GetGroup(suit.DefaultCtx, groups[0]) - - if !respSuccess(resp) { - t.Fatal(resp.GetInfo().GetValue()) - } - - assert.Equal(t, req[0].GetComment().GetValue(), qresp.GetUserGroup().GetComment().GetValue()) - assert.Equal(t, len(users[3:])+1, len(qresp.GetUserGroup().GetRelation().GetUsers())) - }) - - t.Run("正常更新用户组Token", func(t *testing.T) { - resp := suit.UserServer().ResetGroupToken(suit.DefaultCtx, groups[0]) - - if !respSuccess(resp) { - t.Fatal(resp.GetInfo().GetValue()) - } - - _ = suit.CacheMgr().TestUpdate() - - qresp := suit.UserServer().GetGroupToken(suit.DefaultCtx, groups[0]) - if !respSuccess(qresp) { - t.Fatal(resp.GetInfo().GetValue()) - } - assert.Equal(t, resp.GetUserGroup().GetAuthToken().GetValue(), qresp.GetUserGroup().GetAuthToken().GetValue()) - }) - - t.Run("正常查询某个用户组下的用户列表", func(t *testing.T) { - qresp := suit.UserServer().GetUsers(suit.DefaultCtx, map[string]string{ - "group_id": groups[0].GetId().GetValue(), - }) - - if !respSuccess(qresp) { - t.Fatal(qresp.GetInfo().GetValue()) - } - - assert.Equal(t, 8, len(qresp.GetUsers())) - - expectUsers := []string{users[0].Id.Value} - for _, u := range users[3:] { - expectUsers = append(expectUsers, u.Id.Value) - } - - retUsers := []string{} - for i := range qresp.GetUsers() { - retUsers = append(retUsers, qresp.GetUsers()[i].Id.Value) - } - assert.ElementsMatch(t, expectUsers, retUsers) - }) - - t.Run("正常查询用户组列表", func(t *testing.T) { - qresp := suit.UserServer().GetGroups(suit.DefaultCtx, map[string]string{}) - - if !respSuccess(qresp) { - t.Fatal(qresp.GetInfo().GetValue()) - } - - assert.True(t, len(qresp.GetUserGroups()) == 1) - assert.Equal(t, groups[0].GetId().GetValue(), qresp.GetUserGroups()[0].Id.GetValue()) - }) - - t.Run("查询某个用户所在的所有分组", func(t *testing.T) { - qresp := suit.UserServer().GetGroups(suit.DefaultCtx, map[string]string{ - "user_id": users[0].GetId().GetValue(), - }) - - if !respSuccess(qresp) { - t.Fatal(qresp.GetInfo().GetValue()) - } - - assert.True(t, len(qresp.GetUserGroups()) == 1) - assert.Equal(t, groups[0].GetId().GetValue(), qresp.GetUserGroups()[0].Id.GetValue()) - }) - - t.Run("正常删除用户组", func(t *testing.T) { - resp := suit.UserServer().DeleteGroups(suit.DefaultCtx, groups) - - if !respSuccess(resp) { - t.Fatal(resp.GetInfo().GetValue()) - } - - qresp := suit.UserServer().GetGroup(suit.DefaultCtx, groups[0]) - - if respSuccess(qresp) { - t.Fatal(qresp.GetInfo().GetValue()) - } - - assert.Equal(t, v1.NotFoundUserGroup, qresp.GetCode().GetValue()) - }) -} diff --git a/auth/user/inteceptor/auth/server.go b/auth/user/inteceptor/auth/server.go index 6cc411516..17f6e877c 100644 --- a/auth/user/inteceptor/auth/server.go +++ b/auth/user/inteceptor/auth/server.go @@ -28,6 +28,7 @@ import ( "github.com/polarismesh/polaris/auth" cachetypes "github.com/polarismesh/polaris/cache/api" api "github.com/polarismesh/polaris/common/api/v1" + "github.com/polarismesh/polaris/common/model" authcommon "github.com/polarismesh/polaris/common/model/auth" authmodel "github.com/polarismesh/polaris/common/model/auth" "github.com/polarismesh/polaris/common/utils" @@ -46,8 +47,10 @@ type Server struct { } // Initialize 初始化 -func (svr *Server) Initialize(authOpt *auth.Config, storage store.Store, policyMgr auth.StrategyServer, cacheMgr cachetypes.CacheManager) error { - return svr.nextSvr.Initialize(authOpt, storage, policyMgr, cacheMgr) +func (svr *Server) Initialize(authOpt *auth.Config, storage store.Store, policySvr auth.StrategyServer, + cacheMgr cachetypes.CacheManager) error { + svr.policySvr = policySvr + return svr.nextSvr.Initialize(authOpt, storage, policySvr, cacheMgr) } // Name 用户数据管理server名称 @@ -80,7 +83,7 @@ func (svr *Server) CreateUsers(ctx context.Context, users []*apisecurity.User) * ) if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewBatchWriteResponseWithMsg(authcommon.ConvertToErrCode(err), err.Error()) + return api.NewBatchWriteResponse(authcommon.ConvertToErrCode(err)) } return svr.nextSvr.CreateUsers(authCtx.GetRequestContext(), users) } @@ -110,7 +113,7 @@ func (svr *Server) UpdateUser(ctx context.Context, user *apisecurity.User) *apis ) if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewResponseWithMsg(authcommon.ConvertToErrCode(err), err.Error()) + return api.NewResponse(authcommon.ConvertToErrCode(err)) } return svr.nextSvr.UpdateUser(authCtx.GetRequestContext(), user) } @@ -140,7 +143,7 @@ func (svr *Server) UpdateUserPassword(ctx context.Context, req *apisecurity.Modi ) if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewResponseWithMsg(authcommon.ConvertToErrCode(err), err.Error()) + return api.NewResponse(authcommon.ConvertToErrCode(err)) } return svr.nextSvr.UpdateUserPassword(authCtx.GetRequestContext(), req) } @@ -170,7 +173,7 @@ func (svr *Server) DeleteUsers(ctx context.Context, users []*apisecurity.User) * }), ) if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewBatchWriteResponseWithMsg(authcommon.ConvertToErrCode(err), err.Error()) + return api.NewBatchWriteResponse(authcommon.ConvertToErrCode(err)) } return svr.nextSvr.DeleteUsers(authCtx.GetRequestContext(), users) } @@ -184,7 +187,7 @@ func (svr *Server) GetUsers(ctx context.Context, query map[string]string) *apise authcommon.WithMethod(authcommon.DescribeUsers), ) if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewBatchQueryResponseWithMsg(authcommon.ConvertToErrCode(err), err.Error()) + return api.NewBatchQueryResponse(authcommon.ConvertToErrCode(err)) } ctx = authCtx.GetRequestContext() query["hide_admin"] = strconv.FormatBool(true) @@ -194,7 +197,7 @@ func (svr *Server) GetUsers(ctx context.Context, query map[string]string) *apise query["owner"] = utils.ParseOwnerID(ctx) } - cachetypes.AppendUserPredicate(ctx, func(ctx context.Context, u *authcommon.User) bool { + ctx = cachetypes.AppendUserPredicate(ctx, func(ctx context.Context, u *authcommon.User) bool { return svr.policySvr.GetAuthChecker().ResourcePredicate(authCtx, &authmodel.ResourceEntry{ Type: apisecurity.ResourceType_Users, ID: u.ID, @@ -229,7 +232,7 @@ func (svr *Server) GetUserToken(ctx context.Context, user *apisecurity.User) *ap ) if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewResponseWithMsg(authcommon.ConvertToErrCode(err), err.Error()) + return api.NewResponse(authcommon.ConvertToErrCode(err)) } return svr.nextSvr.GetUserToken(authCtx.GetRequestContext(), user) } @@ -258,9 +261,9 @@ func (svr *Server) EnableUserToken(ctx context.Context, user *apisecurity.User) ) if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewResponseWithMsg(authcommon.ConvertToErrCode(err), err.Error()) + return api.NewResponse(authcommon.ConvertToErrCode(err)) } - return svr.nextSvr.EnableUserToken(ctx, user) + return svr.nextSvr.EnableUserToken(authCtx.GetRequestContext(), user) } // ResetUserToken 重置用户的token @@ -289,7 +292,7 @@ func (svr *Server) ResetUserToken(ctx context.Context, user *apisecurity.User) * if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { return api.NewResponseWithMsg(authcommon.ConvertToErrCode(err), err.Error()) } - return svr.nextSvr.ResetUserToken(ctx, user) + return svr.nextSvr.ResetUserToken(authCtx.GetRequestContext(), user) } // CreateGroup 创建用户组 @@ -302,7 +305,7 @@ func (svr *Server) CreateGroup(ctx context.Context, group *apisecurity.UserGroup ) if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewResponseWithMsg(authcommon.ConvertToErrCode(err), err.Error()) + return api.NewResponse(authcommon.ConvertToErrCode(err)) } return svr.nextSvr.CreateGroup(authCtx.GetRequestContext(), group) } @@ -334,7 +337,7 @@ func (svr *Server) UpdateGroups(ctx context.Context, groups []*apisecurity.Modif ) if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewBatchWriteResponseWithMsg(authcommon.ConvertToErrCode(err), err.Error()) + return api.NewBatchWriteResponse(authcommon.ConvertToErrCode(err)) } return svr.nextSvr.UpdateGroups(authCtx.GetRequestContext(), groups) } @@ -364,9 +367,9 @@ func (svr *Server) DeleteGroups(ctx context.Context, groups []*apisecurity.UserG ) if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewBatchWriteResponseWithMsg(authcommon.ConvertToErrCode(err), err.Error()) + return api.NewBatchWriteResponse(authcommon.ConvertToErrCode(err)) } - return svr.nextSvr.DeleteGroups(ctx, groups) + return svr.nextSvr.DeleteGroups(authCtx.GetRequestContext(), groups) } // GetGroups 查询用户组列表(不带用户详细信息) @@ -378,23 +381,27 @@ func (svr *Server) GetGroups(ctx context.Context, query map[string]string) *apis authcommon.WithMethod(authcommon.DescribeUserGroups), ) if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewBatchQueryResponseWithMsg(authcommon.ConvertToErrCode(err), err.Error()) + return api.NewBatchQueryResponse(authcommon.ConvertToErrCode(err)) } ctx = authCtx.GetRequestContext() - if authcommon.ParseUserRole(ctx) != authmodel.AdminUserRole { - // step 1: 设置 owner 信息,只能查看归属主帐户下的用户组 - query["owner"] = utils.ParseOwnerID(ctx) - } - - cachetypes.AppendUserPredicate(ctx, func(ctx context.Context, u *authcommon.User) bool { - return svr.policySvr.GetAuthChecker().ResourcePredicate(authCtx, &authmodel.ResourceEntry{ + ctx = cachetypes.AppendUserGroupPredicate(ctx, func(ctx context.Context, u *authcommon.UserGroupDetail) bool { + ok := svr.policySvr.GetAuthChecker().ResourcePredicate(authCtx, &authmodel.ResourceEntry{ Type: apisecurity.ResourceType_UserGroups, ID: u.ID, Metadata: u.Metadata, }) + if ok { + return true + } + // 兼容老版本的策略查询逻辑 + if compatible, _ := ctx.Value(model.ContextKeyCompatible{}).(bool); compatible { + _, exist := u.UserIds[utils.ParseUserID(ctx)] + return exist + } + return false }) - delete(query, "owner") - return svr.nextSvr.GetGroups(ctx, query) + authCtx.SetRequestContext(ctx) + return svr.nextSvr.GetGroups(authCtx.GetRequestContext(), query) } // GetGroup 根据用户组信息,查询该用户组下的用户相信 @@ -421,9 +428,9 @@ func (svr *Server) GetGroup(ctx context.Context, req *apisecurity.UserGroup) *ap ) if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewResponseWithMsg(authcommon.ConvertToErrCode(err), err.Error()) + return api.NewResponse(authcommon.ConvertToErrCode(err)) } - return svr.nextSvr.GetGroup(ctx, req) + return svr.nextSvr.GetGroup(authCtx.GetRequestContext(), req) } // GetGroupToken 获取用户组的 token @@ -450,9 +457,9 @@ func (svr *Server) GetGroupToken(ctx context.Context, group *apisecurity.UserGro ) if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewResponseWithMsg(authcommon.ConvertToErrCode(err), err.Error()) + return api.NewResponse(authcommon.ConvertToErrCode(err)) } - return svr.nextSvr.GetGroupToken(ctx, group) + return svr.nextSvr.GetGroupToken(authCtx.GetRequestContext(), group) } // EnableGroupToken 取消用户组的 token 使用 @@ -479,9 +486,9 @@ func (svr *Server) EnableGroupToken(ctx context.Context, group *apisecurity.User ) if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewResponseWithMsg(authcommon.ConvertToErrCode(err), err.Error()) + return api.NewResponse(authcommon.ConvertToErrCode(err)) } - return svr.nextSvr.EnableGroupToken(ctx, group) + return svr.nextSvr.EnableGroupToken(authCtx.GetRequestContext(), group) } // ResetGroupToken 重置用户组的 token @@ -508,7 +515,7 @@ func (svr *Server) ResetGroupToken(ctx context.Context, group *apisecurity.UserG ) if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewResponseWithMsg(authcommon.ConvertToErrCode(err), err.Error()) + return api.NewResponse(authcommon.ConvertToErrCode(err)) } - return svr.nextSvr.ResetGroupToken(ctx, group) + return svr.nextSvr.ResetGroupToken(authCtx.GetRequestContext(), group) } diff --git a/auth/user/main_test.go b/auth/user/main_test.go deleted file mode 100644 index 6ccb1a835..000000000 --- a/auth/user/main_test.go +++ /dev/null @@ -1,215 +0,0 @@ -/** - * Tencent is pleased to support the open source community by making Polaris available. - * - * Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. - * - * Licensed under the BSD 3-Clause License (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://opensource.org/licenses/BSD-3-Clause - * - * Unless required by applicable law or agreed to in writing, software distributed - * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR - * CONDITIONS OF ANY KIND, either express or implied. See the License for the - * specific language governing permissions and limitations under the License. - */ - -package defaultuser_test - -import ( - "errors" - - _ "github.com/go-sql-driver/mysql" - bolt "go.etcd.io/bbolt" - - "github.com/polarismesh/polaris/auth" - "github.com/polarismesh/polaris/cache" - _ "github.com/polarismesh/polaris/cache" - api "github.com/polarismesh/polaris/common/api/v1" - commonlog "github.com/polarismesh/polaris/common/log" - "github.com/polarismesh/polaris/namespace" - "github.com/polarismesh/polaris/plugin" - _ "github.com/polarismesh/polaris/plugin/cmdb/memory" - _ "github.com/polarismesh/polaris/plugin/discoverevent/local" - _ "github.com/polarismesh/polaris/plugin/healthchecker/memory" - _ "github.com/polarismesh/polaris/plugin/healthchecker/redis" - _ "github.com/polarismesh/polaris/plugin/history/logger" - _ "github.com/polarismesh/polaris/plugin/password" - _ "github.com/polarismesh/polaris/plugin/ratelimit/token" - _ "github.com/polarismesh/polaris/plugin/statis/logger" - _ "github.com/polarismesh/polaris/plugin/statis/prometheus" - "github.com/polarismesh/polaris/service/healthcheck" - "github.com/polarismesh/polaris/store" - "github.com/polarismesh/polaris/store/boltdb" - _ "github.com/polarismesh/polaris/store/boltdb" - _ "github.com/polarismesh/polaris/store/mysql" - sqldb "github.com/polarismesh/polaris/store/mysql" - testsuit "github.com/polarismesh/polaris/test/suit" -) - -const ( - tblUser string = "user" - tblStrategy string = "strategy" - tblGroup string = "group" -) - -type Bootstrap struct { - Logger map[string]*commonlog.Options -} - -type TestConfig struct { - Bootstrap Bootstrap `yaml:"bootstrap"` - Cache cache.Config `yaml:"cache"` - Namespace namespace.Config `yaml:"namespace"` - HealthChecks healthcheck.Config `yaml:"healthcheck"` - Store store.Config `yaml:"store"` - Auth auth.Config `yaml:"auth"` - Plugin plugin.Config `yaml:"plugin"` -} - -type AuthTestSuit struct { - testsuit.DiscoverTestSuit -} - -// 判断一个resp是否执行成功 -func respSuccess(resp api.ResponseMessage) bool { - - ret := api.CalcCode(resp) == 200 - - return ret -} - -type options func(cfg *TestConfig) - -func (d *AuthTestSuit) cleanAllUser() { - if d.Storage.Name() == sqldb.STORENAME { - func() { - tx, err := d.Storage.StartTx() - if err != nil { - panic(err) - } - - dbTx := tx.GetDelegateTx().(*sqldb.BaseTx) - - defer dbTx.Rollback() - - if _, err := dbTx.Exec("delete from user where name like 'test%'"); err != nil { - dbTx.Rollback() - panic(err) - } - - dbTx.Commit() - }() - } else if d.Storage.Name() == boltdb.STORENAME { - func() { - tx, err := d.Storage.StartTx() - if err != nil { - panic(err) - } - - dbTx := tx.GetDelegateTx().(*bolt.Tx) - defer dbTx.Rollback() - - if err := dbTx.DeleteBucket([]byte(tblUser)); err != nil { - if !errors.Is(err, bolt.ErrBucketNotFound) { - panic(err) - } - } - - dbTx.Commit() - }() - } -} - -func (d *AuthTestSuit) cleanAllUserGroup() { - if d.Storage.Name() == sqldb.STORENAME { - func() { - tx, err := d.Storage.StartTx() - if err != nil { - panic(err) - } - - dbTx := tx.GetDelegateTx().(*sqldb.BaseTx) - - defer dbTx.Rollback() - - if _, err := dbTx.Exec("delete from user_group where name like 'test%'"); err != nil { - dbTx.Rollback() - panic(err) - } - if _, err := dbTx.Exec("delete from user_group_relation"); err != nil { - dbTx.Rollback() - panic(err) - } - - dbTx.Commit() - }() - } else if d.Storage.Name() == boltdb.STORENAME { - func() { - tx, err := d.Storage.StartTx() - if err != nil { - panic(err) - } - - dbTx := tx.GetDelegateTx().(*bolt.Tx) - defer dbTx.Rollback() - - if err := dbTx.DeleteBucket([]byte(tblGroup)); err != nil { - if !errors.Is(err, bolt.ErrBucketNotFound) { - panic(err) - } - } - - dbTx.Commit() - }() - } -} - -func (d *AuthTestSuit) cleanAllAuthStrategy() { - if d.Storage.Name() == sqldb.STORENAME { - func() { - tx, err := d.Storage.StartTx() - if err != nil { - panic(err) - } - - dbTx := tx.GetDelegateTx().(*sqldb.BaseTx) - - defer dbTx.Rollback() - - if _, err := dbTx.Exec("delete from auth_strategy where id != 'fbca9bfa04ae4ead86e1ecf5811e32a9'"); err != nil { - dbTx.Rollback() - panic(err) - } - if _, err := dbTx.Exec("delete from auth_principal where strategy_id != 'fbca9bfa04ae4ead86e1ecf5811e32a9'"); err != nil { - dbTx.Rollback() - panic(err) - } - if _, err := dbTx.Exec("delete from auth_strategy_resource where strategy_id != 'fbca9bfa04ae4ead86e1ecf5811e32a9'"); err != nil { - dbTx.Rollback() - panic(err) - } - - dbTx.Commit() - }() - } else if d.Storage.Name() == boltdb.STORENAME { - func() { - tx, err := d.Storage.StartTx() - if err != nil { - panic(err) - } - - dbTx := tx.GetDelegateTx().(*bolt.Tx) - defer dbTx.Rollback() - - if err := dbTx.DeleteBucket([]byte(tblStrategy)); err != nil { - if !errors.Is(err, bolt.ErrBucketNotFound) { - panic(err) - } - } - - dbTx.Commit() - }() - } -} diff --git a/auth/user/server.go b/auth/user/server.go index db1484327..dd7254a3f 100644 --- a/auth/user/server.go +++ b/auth/user/server.go @@ -76,7 +76,8 @@ func (svr *Server) Name() string { return auth.DefaultUserMgnPluginName } -func (svr *Server) Initialize(authOpt *auth.Config, storage store.Store, policySvr auth.StrategyServer, cacheMgr cachetypes.CacheManager) error { +func (svr *Server) Initialize(authOpt *auth.Config, storage store.Store, policySvr auth.StrategyServer, + cacheMgr cachetypes.CacheManager) error { svr.cacheMgr = cacheMgr svr.storage = storage svr.policySvr = policySvr diff --git a/auth/user/user.go b/auth/user/user.go index 54535f9f0..5852dacd7 100644 --- a/auth/user/user.go +++ b/auth/user/user.go @@ -423,7 +423,13 @@ func (svr *Server) ResetUserToken(ctx context.Context, req *apisecurity.User) *a // step 3. 兜底措施:如果开启了鉴权的非严格模式,则根据错误的类型,判断是否转为匿名用户进行访问 // - 如果是访问权限控制相关模块(用户、用户组、权限策略),不得转为匿名用户 func (svr *Server) CheckCredential(authCtx *authcommon.AcquireContext) error { + // 如果已经存在了解析出的 token 信息,则直接返回 + if _, ok := authCtx.GetAttachment(authcommon.TokenDetailInfoKey); ok { + return nil + } + checkErr := func() error { + // authToken := utils.ParseAuthToken(authCtx.GetRequestContext()) operator, err := svr.decodeToken(authToken) if err != nil { @@ -459,7 +465,8 @@ func (svr *Server) CheckCredential(authCtx *authcommon.AcquireContext) error { log.Warn("[Auth][Checker] parse operator info, downgrade to anonymous", utils.RequestID(authCtx.GetRequestContext()), zap.Error(checkErr)) // 操作者信息解析失败,降级为匿名用户 - authCtx.SetAttachment(authcommon.TokenDetailInfoKey, auth.NewAnonymous()) + authCtx.SetAttachment(authcommon.PrincipalKey, authcommon.NewAnonymousPrincipal()) + authCtx.SetAttachment(authcommon.TokenDetailInfoKey, auth.NewAnonymousOperatorInfo()) } return nil } diff --git a/auth/user/user_test.go b/auth/user/user_test.go deleted file mode 100644 index 3df71acd2..000000000 --- a/auth/user/user_test.go +++ /dev/null @@ -1,1009 +0,0 @@ -/** - * Tencent is pleased to support the open source community by making Polaris available. - * - * Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. - * - * Licensed under the BSD 3-Clause License (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://opensource.org/licenses/BSD-3-Clause - * - * Unless required by applicable law or agreed to in writing, software distributed - * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR - * CONDITIONS OF ANY KIND, either express or implied. See the License for the - * specific language governing permissions and limitations under the License. - */ - -package defaultuser_test - -import ( - "context" - "testing" - "time" - - "github.com/golang/mock/gomock" - "github.com/golang/protobuf/ptypes/wrappers" - apimodel "github.com/polarismesh/specification/source/go/api/v1/model" - apisecurity "github.com/polarismesh/specification/source/go/api/v1/security" - "github.com/stretchr/testify/assert" - - "github.com/polarismesh/polaris/auth" - defaultuser "github.com/polarismesh/polaris/auth/user" - "github.com/polarismesh/polaris/cache" - cachetypes "github.com/polarismesh/polaris/cache/api" - api "github.com/polarismesh/polaris/common/api/v1" - commonlog "github.com/polarismesh/polaris/common/log" - authcommon "github.com/polarismesh/polaris/common/model/auth" - "github.com/polarismesh/polaris/common/utils" - storemock "github.com/polarismesh/polaris/store/mock" -) - -type UserTest struct { - admin *authcommon.User - ownerOne *authcommon.User - ownerTwo *authcommon.User - - users []*authcommon.User - newUsers []*authcommon.User - groups []*authcommon.UserGroupDetail - newGroups []*authcommon.UserGroupDetail - allGroups []*authcommon.UserGroupDetail - - storage *storemock.MockStore - cacheMgn *cache.CacheManager - - svr auth.UserServer - - cancel context.CancelFunc - ctrl *gomock.Controller -} - -func newUserTest(t *testing.T) *UserTest { - ctrl := gomock.NewController(t) - - commonlog.GetScopeOrDefaultByName(commonlog.AuthLoggerName).SetOutputLevel(commonlog.DebugLevel) - commonlog.GetScopeOrDefaultByName(commonlog.ConfigLoggerName).SetOutputLevel(commonlog.DebugLevel) - - users := createMockUser(10, "one") - newUsers := createMockUser(10, "two") - admin := createMockUser(1, "admin")[0] - admin.Type = authcommon.AdminUserRole - admin.Owner = "" - groups := createMockUserGroup(users) - - storage := storemock.NewMockStore(ctrl) - storage.EXPECT().GetUnixSecond(gomock.Any()).AnyTimes().Return(time.Now().Unix(), nil) - storage.EXPECT().GetServicesCount().AnyTimes().Return(uint32(1), nil) - storage.EXPECT().AddUser(gomock.Any(), gomock.Any()).AnyTimes().Return(nil) - storage.EXPECT().GetUserByName(gomock.Eq("create-user-1"), gomock.Any()).AnyTimes().Return(nil, nil) - storage.EXPECT().GetUserByName(gomock.Eq("create-user-2"), gomock.Any()).AnyTimes().Return(&authcommon.User{ - Name: "create-user-2", - }, nil) - - allUsers := append(append(users, newUsers...), admin) - - storage.EXPECT().GetUsersForCache(gomock.Any(), gomock.Any()).AnyTimes().Return(allUsers, nil) - storage.EXPECT().GetGroupsForCache(gomock.Any(), gomock.Any()).AnyTimes().Return(groups, nil) - storage.EXPECT().UpdateUser(gomock.Any()).AnyTimes().Return(nil) - storage.EXPECT().DeleteUser(gomock.Any(), gomock.Any()).AnyTimes().Return(nil) - - cfg := &cache.Config{} - ctx, cancel := context.WithCancel(context.Background()) - cacheMgn, err := cache.TestCacheInitialize(ctx, cfg, storage) - if err != nil { - t.Fatal(err) - } - - _ = cacheMgn.OpenResourceCache( - []cachetypes.ConfigEntry{ - { - Name: cachetypes.UsersName, - }, - }..., - ) - time.Sleep(5 * time.Second) - - _ = cache.TestRun(ctx, cacheMgn) - - _, proxySvr, err := defaultuser.BuildServer() - if err != nil { - t.Fatal(err) - } - proxySvr.Initialize(&auth.Config{ - User: &auth.UserConfig{ - Name: auth.DefaultUserMgnPluginName, - Option: map[string]interface{}{ - "salt": "polarismesh@2021", - }, - }, - }, storage, nil, cacheMgn) - - _ = cacheMgn.TestUpdate() - - return &UserTest{ - admin: admin, - ownerOne: users[0], - ownerTwo: newUsers[0], - - users: users, - newUsers: newUsers, - groups: groups, - - storage: storage, - cacheMgn: cacheMgn, - svr: proxySvr, - - cancel: cancel, - ctrl: ctrl, - } -} - -func (g *UserTest) Clean() { - g.ctrl.Finish() - g.cancel() - _ = g.cacheMgn.Close() - time.Sleep(2 * time.Second) -} - -func Test_server_CreateUsers(t *testing.T) { - userTest := newUserTest(t) - defer userTest.Clean() - - t.Run("主账户创建账户-成功", func(t *testing.T) { - createUsersReq := []*apisecurity.User{ - { - Id: &wrappers.StringValue{Value: utils.NewUUID()}, - Name: &wrappers.StringValue{Value: "create-user-1"}, - Password: &wrappers.StringValue{Value: "create-user-1"}, - }, - } - - userTest.storage.EXPECT().GetUser(gomock.Eq(userTest.ownerOne.ID)).Return(userTest.ownerOne, nil) - - reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, userTest.ownerOne.Token) - resp := userTest.svr.CreateUsers(reqCtx, createUsersReq) - - t.Logf("CreateUsers resp : %+v", resp) - assert.Equal(t, api.ExecuteSuccess, resp.Code.GetValue(), "create users must success") - }) - - t.Run("主账户创建账户-无用户名-失败", func(t *testing.T) { - createUsersReq := []*apisecurity.User{ - { - Id: &wrappers.StringValue{Value: utils.NewUUID()}, - }, - } - - reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, userTest.users[0].Token) - resp := userTest.svr.CreateUsers(reqCtx, createUsersReq) - - t.Logf("CreateUsers resp : %+v", resp) - assert.Equal(t, api.InvalidUserName, resp.Responses[0].Code.GetValue(), "create users must fail") - }) - - t.Run("主账户创建账户-密码错误-失败", func(t *testing.T) { - createUsersReq := []*apisecurity.User{ - { - Id: &wrappers.StringValue{Value: utils.NewUUID()}, - Name: &wrappers.StringValue{Value: "create-user-1"}, - Password: &wrappers.StringValue{Value: ""}, - }, - } - - reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, userTest.users[0].Token) - resp := userTest.svr.CreateUsers(reqCtx, createUsersReq) - - t.Logf("CreateUsers resp : %+v", resp) - assert.Equal(t, api.InvalidUserPassword, resp.Responses[0].Code.GetValue(), "create users must fail") - }) - - t.Run("主账户创建账户-同名用户-失败", func(t *testing.T) { - createUsersReq := []*apisecurity.User{ - { - Id: &wrappers.StringValue{Value: utils.NewUUID()}, - Name: &wrappers.StringValue{Value: "create-user-2"}, - Password: &wrappers.StringValue{Value: "create-user-2"}, - }, - } - - userTest.storage.EXPECT().GetUser(gomock.Eq(userTest.ownerOne.ID)).Return(userTest.ownerOne, nil).AnyTimes() - - reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, userTest.users[0].Token) - resp := userTest.svr.CreateUsers(reqCtx, createUsersReq) - - t.Logf("CreateUsers resp : %+v", resp) - assert.Equal(t, api.UserExisted, resp.Responses[0].Code.GetValue(), "create users must fail") - }) - - t.Run("主账户创建账户-与主账户同名", func(t *testing.T) { - createUsersReq := []*apisecurity.User{ - { - Id: &wrappers.StringValue{Value: utils.NewUUID()}, - Name: &wrappers.StringValue{Value: userTest.ownerOne.Name}, - Password: &wrappers.StringValue{Value: "create-user-2"}, - }, - } - - userTest.storage.EXPECT().GetUser(gomock.Eq(userTest.ownerOne.ID)).Return(userTest.ownerOne, nil).AnyTimes() - - reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, userTest.ownerOne.Token) - resp := userTest.svr.CreateUsers(reqCtx, createUsersReq) - - t.Logf("CreateUsers resp : %+v", resp) - assert.Equal(t, api.UserExisted, resp.Responses[0].Code.GetValue(), "create users must fail") - }) - - t.Run("主账户创建账户-token为空-失败", func(t *testing.T) { - createUsersReq := []*apisecurity.User{ - { - Id: &wrappers.StringValue{Value: utils.NewUUID()}, - Name: &wrappers.StringValue{Value: "create-user-2"}, - Password: &wrappers.StringValue{Value: "create-user-2"}, - }, - } - - resp := userTest.svr.CreateUsers(context.Background(), createUsersReq) - t.Logf("CreateUsers resp : %+v", resp) - assert.Equal(t, api.EmptyAutToken, resp.Responses[0].Code.GetValue(), "create users must fail") - }) - - t.Run("主账户创建账户-token非法-失败", func(t *testing.T) { - createUsersReq := []*apisecurity.User{ - { - Id: &wrappers.StringValue{Value: utils.NewUUID()}, - Name: &wrappers.StringValue{Value: "create-user-2"}, - Password: &wrappers.StringValue{Value: "create-user-2"}, - }, - } - - reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, "utils.ContextAuthTokenKey") - resp := userTest.svr.CreateUsers(reqCtx, createUsersReq) - t.Logf("CreateUsers resp : %+v", resp) - assert.Equal(t, api.AuthTokenVerifyException, resp.Responses[0].Code.GetValue(), "create users must fail") - }) - - t.Run("主账户创建账户-token被禁用-失败", func(t *testing.T) { - userTest.users[0].TokenEnable = false - // 让 cache 可以刷新到 - time.Sleep(time.Second) - - createUsersReq := []*apisecurity.User{ - { - Id: &wrappers.StringValue{Value: utils.NewUUID()}, - Name: &wrappers.StringValue{Value: "create-user-2"}, - Password: &wrappers.StringValue{Value: "create-user-2"}, - }, - } - - reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, userTest.users[0].Token) - resp := userTest.svr.CreateUsers(reqCtx, createUsersReq) - - t.Logf("CreateUsers resp : %+v", resp) - assert.Equal(t, api.TokenDisabled, resp.Responses[0].Code.GetValue(), "create users must fail") - - userTest.users[0].TokenEnable = true - time.Sleep(time.Second) - }) - - t.Run("子主账户创建账户-失败", func(t *testing.T) { - createUsersReq := []*apisecurity.User{ - { - Id: &wrappers.StringValue{Value: utils.NewUUID()}, - Name: &wrappers.StringValue{Value: "create-user-1"}, - Password: &wrappers.StringValue{Value: "create-user-1"}, - }, - } - - reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, userTest.users[1].Token) - resp := userTest.svr.CreateUsers(reqCtx, createUsersReq) - - t.Logf("CreateUsers resp : %+v", resp) - assert.Equal(t, api.OperationRoleException, resp.Responses[0].Code.GetValue(), "create users must fail") - }) - - t.Run("用户组token创建账户-失败", func(t *testing.T) { - createUsersReq := []*apisecurity.User{ - { - Id: &wrappers.StringValue{Value: utils.NewUUID()}, - Name: &wrappers.StringValue{Value: "create-user-1"}, - Password: &wrappers.StringValue{Value: "create-user-1"}, - }, - } - - reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, userTest.groups[1].Token) - resp := userTest.svr.CreateUsers(reqCtx, createUsersReq) - - t.Logf("CreateUsers resp : %+v", resp) - assert.Equal(t, api.OperationRoleException, resp.Responses[0].Code.GetValue(), "create users must fail") - }) -} - -func Test_server_Login(t *testing.T) { - - userTest := newUserTest(t) - defer userTest.Clean() - - t.Run("正常登陆", func(t *testing.T) { - rsp := userTest.svr.Login(&apisecurity.LoginRequest{ - Name: &wrappers.StringValue{Value: userTest.users[0].Name}, - Password: &wrappers.StringValue{Value: "polaris"}, - }) - - assert.True(t, api.IsSuccess(rsp), rsp.GetInfo().GetValue()) - }) - - t.Run("错误的密码", func(t *testing.T) { - rsp := userTest.svr.Login(&apisecurity.LoginRequest{ - Name: &wrappers.StringValue{Value: userTest.users[0].Name}, - Password: &wrappers.StringValue{Value: "polaris_123"}, - }) - - assert.False(t, api.IsSuccess(rsp), rsp.GetInfo().GetValue()) - assert.Equal(t, uint32(apimodel.Code_NotAllowedAccess), rsp.GetCode().GetValue()) - assert.Contains(t, rsp.GetInfo().GetValue(), authcommon.ErrorWrongUsernameOrPassword.Error()) - }) -} - -func Test_server_UpdateUser(t *testing.T) { - - userTest := newUserTest(t) - defer userTest.Clean() - - t.Run("主账户更新账户信息-正常更新自己的信息", func(t *testing.T) { - req := &apisecurity.User{ - Id: &wrappers.StringValue{Value: userTest.users[0].ID}, - Comment: &wrappers.StringValue{Value: "update owner account info"}, - } - - userTest.storage.EXPECT().GetUser(gomock.Any()).Return(userTest.users[0], nil) - - reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, userTest.users[0].Token) - resp := userTest.svr.UpdateUser(reqCtx, req) - - t.Logf("UpdateUsers resp : %+v", resp) - assert.Equal(t, api.ExecuteSuccess, resp.Code.GetValue(), "update user must success") - }) - - t.Run("主账户更新账户信息-更新不存在的子账户", func(t *testing.T) { - uid := utils.NewUUID() - req := &apisecurity.User{ - Id: &wrappers.StringValue{Value: uid}, - Comment: &wrappers.StringValue{Value: "update owner account info"}, - } - - userTest.storage.EXPECT().GetUser(gomock.Any()).Return(nil, nil) - - reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, userTest.users[0].Token) - resp := userTest.svr.UpdateUser(reqCtx, req) - - t.Logf("UpdateUsers resp : %+v", resp) - assert.Equal(t, api.NotFoundUser, resp.Code.GetValue(), "update user must fail") - }) - - t.Run("主账户更新账户信息-更新不属于自己的子账户", func(t *testing.T) { - uid := utils.NewUUID() - req := &apisecurity.User{ - Id: &wrappers.StringValue{Value: uid}, - Comment: &wrappers.StringValue{Value: "update owner account info"}, - } - - userTest.storage.EXPECT().GetUser(gomock.Any()).Return(&authcommon.User{ - ID: uid, - Owner: utils.NewUUID(), - }, nil) - - reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, userTest.users[0].Token) - resp := userTest.svr.UpdateUser(reqCtx, req) - - t.Logf("UpdateUsers resp : %+v", resp) - assert.Equal(t, api.NotAllowedAccess, resp.Code.GetValue(), "update user must fail") - }) - - t.Run("子账户更新账户信息-正常更新自己的信息", func(t *testing.T) { - req := &apisecurity.User{ - Id: &wrappers.StringValue{Value: userTest.users[1].ID}, - Comment: &wrappers.StringValue{Value: "update owner account info"}, - } - - userTest.storage.EXPECT().GetUser(gomock.Any()).Return(userTest.users[1], nil) - - reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, userTest.users[1].Token) - resp := userTest.svr.UpdateUser(reqCtx, req) - - t.Logf("UpdateUsers resp : %+v", resp) - assert.Equal(t, api.ExecuteSuccess, resp.Code.GetValue(), "update user must fail") - }) - - t.Run("子账户更新账户信息-更新别的账户", func(t *testing.T) { - req := &apisecurity.User{ - Id: &wrappers.StringValue{Value: userTest.users[2].ID}, - Comment: &wrappers.StringValue{Value: "update owner account info"}, - } - - userTest.storage.EXPECT().GetUser(gomock.Any()).Return(userTest.users[2], nil).AnyTimes() - - reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, userTest.users[1].Token) - resp := userTest.svr.UpdateUser(reqCtx, req) - - t.Logf("UpdateUsers resp : %+v", resp) - assert.Equal(t, api.NotAllowedAccess, resp.Code.GetValue(), "update user must fail") - }) - - t.Run("用户组Token更新账户信息-更新别的账户", func(t *testing.T) { - req := &apisecurity.User{ - Id: &wrappers.StringValue{Value: userTest.users[2].ID}, - Comment: &wrappers.StringValue{Value: "update owner account info"}, - } - - reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, userTest.groups[1].Token) - resp := userTest.svr.UpdateUser(reqCtx, req) - - t.Logf("UpdateUsers resp : %+v", resp) - assert.Equal(t, api.OperationRoleException, resp.Code.GetValue(), "update user must fail") - }) -} - -func Test_server_UpdateUserPassword(t *testing.T) { - - userTest := newUserTest(t) - defer userTest.Clean() - - t.Run("主账户正常更新自身账户密码", func(t *testing.T) { - req := &apisecurity.ModifyUserPassword{ - Id: &wrappers.StringValue{Value: userTest.users[0].ID}, - OldPassword: &wrappers.StringValue{Value: "polaris"}, - NewPassword: &wrappers.StringValue{Value: "polaris@2021"}, - } - - userTest.storage.EXPECT().GetUser(gomock.Any()).Return(userTest.users[0], nil) - - reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, userTest.users[0].Token) - resp := userTest.svr.UpdateUserPassword(reqCtx, req) - t.Logf("CreateUsers resp : %+v", resp) - assert.Equal(t, api.ExecuteSuccess, resp.Code.GetValue(), "update user must success") - }) - - t.Run("主账户正常更新自身账户密码-新密码非法", func(t *testing.T) { - req := &apisecurity.ModifyUserPassword{ - Id: &wrappers.StringValue{Value: userTest.users[0].ID}, - OldPassword: &wrappers.StringValue{Value: "polaris"}, - NewPassword: &wrappers.StringValue{Value: "pola"}, - } - - userTest.storage.EXPECT().GetUser(gomock.Any()).Return(userTest.users[0], nil) - - reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, userTest.users[0].Token) - resp := userTest.svr.UpdateUserPassword(reqCtx, req) - t.Logf("CreateUsers resp : %+v", resp) - assert.Equal(t, api.ExecuteException, resp.Code.GetValue(), "update user must fail") - - req = &apisecurity.ModifyUserPassword{ - Id: &wrappers.StringValue{Value: userTest.users[0].ID}, - OldPassword: &wrappers.StringValue{Value: "polaris"}, - NewPassword: &wrappers.StringValue{Value: ""}, - } - - userTest.storage.EXPECT().GetUser(gomock.Any()).Return(userTest.users[0], nil) - - reqCtx = context.WithValue(context.Background(), utils.ContextAuthTokenKey, userTest.users[0].Token) - resp = userTest.svr.UpdateUserPassword(reqCtx, req) - t.Logf("CreateUsers resp : %+v", resp) - assert.Equal(t, api.ExecuteException, resp.Code.GetValue(), "update user must fail") - - req = &apisecurity.ModifyUserPassword{ - Id: &wrappers.StringValue{Value: userTest.users[0].ID}, - OldPassword: &wrappers.StringValue{Value: "polaris"}, - NewPassword: &wrappers.StringValue{Value: "polarispolarispolarispolaris"}, - } - - userTest.storage.EXPECT().GetUser(gomock.Any()).Return(userTest.users[0], nil) - - reqCtx = context.WithValue(context.Background(), utils.ContextAuthTokenKey, userTest.users[0].Token) - resp = userTest.svr.UpdateUserPassword(reqCtx, req) - t.Logf("CreateUsers resp : %+v", resp) - assert.Equal(t, api.ExecuteException, resp.Code.GetValue(), "update user must fail") - }) - - t.Run("主账户正常更新子账户密码", func(t *testing.T) { - req := &apisecurity.ModifyUserPassword{ - Id: &wrappers.StringValue{Value: userTest.users[1].ID}, - NewPassword: &wrappers.StringValue{Value: "polaris@sub"}, - } - - userTest.storage.EXPECT().GetUser(gomock.Any()).Return(userTest.users[1], nil) - - reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, userTest.users[0].Token) - resp := userTest.svr.UpdateUserPassword(reqCtx, req) - t.Logf("CreateUsers resp : %+v", resp) - assert.Equal(t, api.ExecuteSuccess, resp.Code.GetValue(), "update user must success") - }) - - t.Run("主账户正常更新子账户密码-子账户非自己", func(t *testing.T) { - - uid := utils.NewUUID() - - req := &apisecurity.ModifyUserPassword{ - Id: &wrappers.StringValue{Value: uid}, - NewPassword: &wrappers.StringValue{Value: "polaris@subaccount"}, - } - - userTest.storage.EXPECT().GetUser(gomock.Any()).Return(&authcommon.User{ - ID: uid, - Owner: utils.NewUUID(), - }, nil) - - reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, userTest.users[0].Token) - resp := userTest.svr.UpdateUserPassword(reqCtx, req) - t.Logf("CreateUsers resp : %+v", resp) - assert.Equal(t, api.NotAllowedAccess, resp.Code.GetValue(), "update user must fail") - }) - - t.Run("子账户更新账户密码-自身-携带正确原密码", func(t *testing.T) { - req := &apisecurity.ModifyUserPassword{ - Id: &wrappers.StringValue{Value: userTest.users[2].ID}, - OldPassword: &wrappers.StringValue{Value: "polaris"}, - NewPassword: &wrappers.StringValue{Value: "users[1].Password"}, - } - - userTest.storage.EXPECT().GetUser(gomock.Any()).Return(userTest.users[2], nil) - - reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, userTest.users[2].Token) - resp := userTest.svr.UpdateUserPassword(reqCtx, req) - t.Logf("CreateUsers resp : %+v", resp) - assert.Equal(t, api.ExecuteSuccess, resp.Code.GetValue(), "update user must fail") - }) - - t.Run("子账户更新账户密码-自身-携带错误原密码", func(t *testing.T) { - req := &apisecurity.ModifyUserPassword{ - Id: &wrappers.StringValue{Value: userTest.users[1].ID}, - OldPassword: &wrappers.StringValue{Value: "users[1].Password"}, - NewPassword: &wrappers.StringValue{Value: "users[1].Password"}, - } - - userTest.storage.EXPECT().GetUser(gomock.Any()).Return(userTest.users[1], nil) - - reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, userTest.users[1].Token) - resp := userTest.svr.UpdateUserPassword(reqCtx, req) - t.Logf("CreateUsers resp : %+v", resp) - assert.Equal(t, api.ExecuteException, resp.Code.GetValue(), "update user must fail") - }) - - t.Run("子账户更新账户密码-自身-无携带原密码", func(t *testing.T) { - req := &apisecurity.ModifyUserPassword{ - Id: &wrappers.StringValue{Value: userTest.users[1].ID}, - NewPassword: &wrappers.StringValue{Value: "users[1].Password"}, - } - - userTest.storage.EXPECT().GetUser(gomock.Any()).Return(userTest.users[1], nil) - - reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, userTest.users[1].Token) - resp := userTest.svr.UpdateUserPassword(reqCtx, req) - t.Logf("CreateUsers resp : %+v", resp) - assert.Equal(t, api.ExecuteException, resp.Code.GetValue(), "update user must fail") - }) - - t.Run("子账户更新账户密码-不是自己", func(t *testing.T) { - req := &apisecurity.ModifyUserPassword{ - Id: &wrappers.StringValue{Value: userTest.users[2].ID}, - NewPassword: &wrappers.StringValue{Value: "users[2].Password"}, - } - - userTest.storage.EXPECT().GetUser(gomock.Any()).Return(userTest.users[2], nil).AnyTimes() - - reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, userTest.users[1].Token) - resp := userTest.svr.UpdateUserPassword(reqCtx, req) - t.Logf("CreateUsers resp : %+v", resp) - assert.Equal(t, api.NotAllowedAccess, resp.Code.GetValue(), "update user must fail") - }) -} - -func Test_server_DeleteUser(t *testing.T) { - t.Run("主账户删除自己", func(t *testing.T) { - userTest := newUserTest(t) - t.Cleanup(func() { - userTest.Clean() - }) - - userTest.storage.EXPECT().GetUser(gomock.Any()).Return(userTest.users[0], nil) - - reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, userTest.users[0].Token) - resp := userTest.svr.DeleteUsers(reqCtx, []*apisecurity.User{ - &apisecurity.User{ - Id: utils.NewStringValue(userTest.users[0].ID), - }, - }) - - assert.True(t, resp.GetCode().Value == api.NotAllowedAccess, resp.Info.GetValue()) - }) - - t.Run("主账户删除另外一个主账户", func(t *testing.T) { - userTest := newUserTest(t) - t.Cleanup(func() { - userTest.Clean() - }) - - uid := utils.NewUUID() - userTest.storage.EXPECT().GetUser(gomock.Any()).Return(&authcommon.User{ - ID: uid, - Type: authcommon.OwnerUserRole, - Owner: "", - }, nil) - - reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, userTest.users[0].Token) - resp := userTest.svr.DeleteUsers(reqCtx, []*apisecurity.User{ - { - Id: utils.NewStringValue(uid), - }, - }) - - assert.True(t, resp.GetCode().Value == api.NotAllowedAccess, resp.Info.GetValue()) - }) - - t.Run("主账户删除自己的子账户", func(t *testing.T) { - userTest := newUserTest(t) - t.Cleanup(func() { - userTest.Clean() - }) - - userTest.storage.EXPECT().GetUser(gomock.Eq(userTest.users[1].ID)).Return(userTest.users[1], nil).AnyTimes() - - reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, userTest.users[0].Token) - resp := userTest.svr.DeleteUsers(reqCtx, []*apisecurity.User{ - { - Id: utils.NewStringValue(userTest.users[1].ID), - }, - }) - - assert.True(t, resp.GetCode().Value == api.ExecuteSuccess, resp.Info.GetValue()) - }) - - t.Run("主账户删除不是自己的子账户", func(t *testing.T) { - userTest := newUserTest(t) - t.Cleanup(func() { - userTest.Clean() - }) - - uid := utils.NewUUID() - oid := utils.NewUUID() - userTest.storage.EXPECT().GetUser(gomock.Any()).Return(&authcommon.User{ - ID: uid, - Type: authcommon.OwnerUserRole, - Owner: oid, - }, nil).AnyTimes() - - reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, userTest.users[0].Token) - resp := userTest.svr.DeleteUsers(reqCtx, []*apisecurity.User{ - &apisecurity.User{ - Id: utils.NewStringValue(uid), - }, - }) - - assert.True(t, resp.GetCode().Value == api.NotAllowedAccess, resp.Info.GetValue()) - }) - - t.Run("管理员删除主账户-主账户下没有子账户", func(t *testing.T) { - userTest := newUserTest(t) - t.Cleanup(func() { - userTest.Clean() - }) - - userTest.storage.EXPECT().GetUser(gomock.Any()).Return(userTest.users[0], nil).AnyTimes() - userTest.storage.EXPECT().GetSubCount(gomock.Any()).Return(uint32(0), nil).AnyTimes() - - reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, userTest.admin.Token) - resp := userTest.svr.DeleteUsers(reqCtx, []*apisecurity.User{ - &apisecurity.User{ - Id: utils.NewStringValue(userTest.users[0].ID), - }, - }) - - assert.True(t, resp.GetCode().Value == api.ExecuteSuccess, resp.Info.GetValue()) - }) - - t.Run("管理员删除主账户-主账户下还有子账户", func(t *testing.T) { - userTest := newUserTest(t) - t.Cleanup(func() { - userTest.Clean() - }) - - userTest.storage.EXPECT().GetUser(gomock.Any()).Return(userTest.ownerOne, nil).AnyTimes() - userTest.storage.EXPECT().GetSubCount(gomock.Any()).Return(uint32(1), nil).AnyTimes() - - reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, userTest.admin.Token) - resp := userTest.svr.DeleteUsers(reqCtx, []*apisecurity.User{ - &apisecurity.User{ - Id: utils.NewStringValue(userTest.users[0].ID), - }, - }) - - assert.True(t, resp.GetCode().Value == api.SubAccountExisted, resp.Info.GetValue()) - }) - - t.Run("子账户删除用户", func(t *testing.T) { - userTest := newUserTest(t) - t.Cleanup(func() { - userTest.Clean() - }) - - reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, userTest.users[1].Token) - resp := userTest.svr.DeleteUsers(reqCtx, []*apisecurity.User{ - &apisecurity.User{ - Id: utils.NewStringValue(userTest.users[0].ID), - }, - }) - - assert.True(t, resp.GetCode().Value == api.OperationRoleException, resp.Info.GetValue()) - }) -} - -func Test_server_GetUserToken(t *testing.T) { - - userTest := newUserTest(t) - defer userTest.Clean() - - t.Run("主账户查询自己的Token", func(t *testing.T) { - reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, userTest.ownerOne.Token) - - resp := userTest.svr.GetUserToken(reqCtx, &apisecurity.User{ - Id: utils.NewStringValue(userTest.users[0].ID), - }) - - assert.True(t, resp.GetCode().Value == api.ExecuteSuccess, resp.Info.GetValue()) - }) - - t.Run("子账户查询自己的Token", func(t *testing.T) { - reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, userTest.users[1].Token) - - resp := userTest.svr.GetUserToken(reqCtx, &apisecurity.User{ - Id: utils.NewStringValue(userTest.users[1].ID), - }) - - assert.True(t, resp.GetCode().Value == api.ExecuteSuccess, resp.Info.GetValue()) - }) - - t.Run("主账户查询子账户的Token", func(t *testing.T) { - reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, userTest.ownerOne.Token) - - resp := userTest.svr.GetUserToken(reqCtx, &apisecurity.User{ - Id: utils.NewStringValue(userTest.users[1].ID), - }) - - assert.True(t, resp.GetCode().Value == api.ExecuteSuccess, resp.Info.GetValue()) - }) - - t.Run("主账户查询别的主账户的Token", func(t *testing.T) { - reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, userTest.ownerOne.Token) - - resp := userTest.svr.GetUserToken(reqCtx, &apisecurity.User{ - Id: utils.NewStringValue(userTest.ownerTwo.ID), - }) - - assert.True(t, resp.GetCode().Value == api.NotAllowedAccess, resp.Info.GetValue()) - }) - - t.Run("主账户查询不属于自己子账户的Token", func(t *testing.T) { - reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, userTest.ownerOne.Token) - - resp := userTest.svr.GetUserToken(reqCtx, &apisecurity.User{ - Id: utils.NewStringValue(userTest.newUsers[1].ID), - }) - - assert.True(t, resp.GetCode().Value == api.NotAllowedAccess, resp.Info.GetValue()) - }) -} - -func Test_server_RefreshUserToken(t *testing.T) { - - userTest := newUserTest(t) - defer userTest.Clean() - - t.Run("主账户刷新自己的Token", func(t *testing.T) { - reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, userTest.ownerOne.Token) - - userTest.storage.EXPECT().GetUser(gomock.Any()).Return(userTest.users[0], nil) - - resp := userTest.svr.ResetUserToken(reqCtx, &apisecurity.User{ - Id: utils.NewStringValue(userTest.users[0].ID), - }) - - assert.True(t, resp.GetCode().Value == api.ExecuteSuccess, resp.Info.GetValue()) - }) - - t.Run("子账户刷新自己的Token", func(t *testing.T) { - reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, userTest.users[1].Token) - userTest.storage.EXPECT().GetUser(gomock.Any()).Return(userTest.users[1], nil) - resp := userTest.svr.ResetUserToken(reqCtx, &apisecurity.User{ - Id: utils.NewStringValue(userTest.users[1].ID), - }) - - assert.True(t, resp.GetCode().Value == api.ExecuteSuccess, resp.Info.GetValue()) - }) - - t.Run("主账户刷新子账户的Token", func(t *testing.T) { - reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, userTest.ownerOne.Token) - userTest.storage.EXPECT().GetUser(gomock.Any()).Return(userTest.users[1], nil) - resp := userTest.svr.ResetUserToken(reqCtx, &apisecurity.User{ - Id: utils.NewStringValue(userTest.users[1].ID), - }) - - assert.True(t, resp.GetCode().Value == api.ExecuteSuccess, resp.Info.GetValue()) - }) - - t.Run("主账户刷新别的主账户的Token", func(t *testing.T) { - reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, userTest.ownerOne.Token) - userTest.storage.EXPECT().GetUser(gomock.Any()).Return(userTest.ownerTwo, nil).AnyTimes() - resp := userTest.svr.ResetUserToken(reqCtx, &apisecurity.User{ - Id: utils.NewStringValue(userTest.ownerTwo.ID), - }) - - assert.True(t, resp.GetCode().Value == api.NotAllowedAccess, resp.Info.GetValue()) - }) - - t.Run("主账户刷新不属于自己子账户的Token", func(t *testing.T) { - reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, userTest.ownerOne.Token) - userTest.storage.EXPECT().GetUser(gomock.Any()).Return(userTest.newUsers[1], nil).AnyTimes() - resp := userTest.svr.ResetUserToken(reqCtx, &apisecurity.User{ - Id: utils.NewStringValue(userTest.newUsers[1].ID), - }) - - assert.True(t, resp.GetCode().Value == api.NotAllowedAccess, resp.Info.GetValue()) - }) -} - -func Test_server_UpdateUserToken(t *testing.T) { - t.Run("主账户刷新自己的Token状态", func(t *testing.T) { - userTest := newUserTest(t) - defer userTest.Clean() - _ = userTest.cacheMgn.TestUpdate() - reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, userTest.ownerOne.Token) - resp := userTest.svr.EnableUserToken(reqCtx, &apisecurity.User{ - Id: utils.NewStringValue(userTest.ownerOne.ID), - }) - - assert.True(t, resp.GetCode().Value == api.NotAllowedAccess, resp.Info.GetValue()) - }) - - t.Run("子账户刷新自己的Token状态", func(t *testing.T) { - userTest := newUserTest(t) - defer userTest.Clean() - _ = userTest.cacheMgn.TestUpdate() - reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, userTest.users[4].Token) - - userTest.storage.EXPECT().GetUser(gomock.Any()).Return(&authcommon.User{}, nil).AnyTimes() - userTest.storage.EXPECT().UpdateUser(gomock.Any()).Return(nil).AnyTimes() - - resp := userTest.svr.EnableUserToken(reqCtx, &apisecurity.User{ - Id: utils.NewStringValue(userTest.users[4].ID), - }) - - assert.True(t, resp.GetCode().Value == api.ExecuteSuccess, resp.Info.GetValue()) - }) - - t.Run("主账户刷新子账户的Token状态", func(t *testing.T) { - userTest := newUserTest(t) - defer userTest.Clean() - reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, userTest.ownerOne.Token) - userTest.storage.EXPECT().GetUser(gomock.Eq(userTest.users[3].ID)).Return(userTest.users[3], nil) - resp := userTest.svr.EnableUserToken(reqCtx, &apisecurity.User{ - Id: utils.NewStringValue(userTest.users[3].ID), - }) - - assert.True(t, resp.GetCode().Value == api.ExecuteSuccess, resp.Info.GetValue()) - }) - - t.Run("主账户刷新别的主账户的Token状态", func(t *testing.T) { - userTest := newUserTest(t) - defer userTest.Clean() - reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, userTest.ownerOne.Token) - - t.Logf("operator-id : %s, user-two-owner : %s", userTest.ownerOne.ID, userTest.ownerTwo.ID) - - userTest.storage.EXPECT().GetUser(gomock.Eq(userTest.ownerTwo.ID)).Return(userTest.ownerTwo, nil).AnyTimes() - resp := userTest.svr.EnableUserToken(reqCtx, &apisecurity.User{ - Id: utils.NewStringValue(userTest.ownerTwo.ID), - }) - - assert.Truef(t, resp.GetCode().Value == api.NotAllowedAccess, "code=%d, msg=%s", resp.Code.GetValue(), resp.Info.GetValue()) - }) - - t.Run("主账户刷新不属于自己子账户的Token状态", func(t *testing.T) { - userTest := newUserTest(t) - defer userTest.Clean() - - _ = userTest.cacheMgn.TestUpdate() - reqCtx := context.WithValue(context.Background(), utils.ContextAuthTokenKey, userTest.ownerOne.Token) - userTest.storage.EXPECT().GetUser(gomock.Eq(userTest.newUsers[3].ID)).Return(userTest.newUsers[3], nil).AnyTimes() - resp := userTest.svr.EnableUserToken(reqCtx, &apisecurity.User{ - Id: utils.NewStringValue(userTest.newUsers[3].ID), - }) - - assert.True(t, resp.GetCode().Value == api.NotAllowedAccess, resp.Info.GetValue()) - }) -} - -func Test_AuthServer_NormalOperateUser(t *testing.T) { - suit := &AuthTestSuit{} - if err := suit.Initialize(); err != nil { - t.Fatal(err) - } - t.Cleanup(func() { - suit.cleanAllAuthStrategy() - suit.cleanAllUser() - suit.cleanAllUserGroup() - suit.Destroy() - }) - - users := createApiMockUser(10, "test") - - t.Run("正常创建用户", func(t *testing.T) { - resp := suit.UserServer().CreateUsers(suit.DefaultCtx, users) - - if !respSuccess(resp) { - t.Fatal(resp.GetInfo().GetValue()) - } - }) - - t.Run("非正常创建用户-直接操作存储层", func(t *testing.T) { - err := suit.Storage.AddUser(nil, &authcommon.User{}) - assert.Error(t, err) - }) - - t.Run("正常更新用户", func(t *testing.T) { - users[0].Comment = utils.NewStringValue("update user comment") - resp := suit.UserServer().UpdateUser(suit.DefaultCtx, users[0]) - - if !respSuccess(resp) { - t.Fatal(resp.GetInfo().GetValue()) - } - - qresp := suit.UserServer().GetUsers(suit.DefaultCtx, map[string]string{ - "id": users[0].GetId().GetValue(), - }) - - if !respSuccess(resp) { - t.Fatal(resp.GetInfo().GetValue()) - } - - assert.Equal(t, 1, int(qresp.Amount.GetValue())) - assert.Equal(t, 1, int(qresp.Size.GetValue())) - - retUsers := qresp.GetUsers()[0] - assert.Equal(t, users[0].GetComment().GetValue(), retUsers.GetComment().GetValue()) - }) - - t.Run("正常删除用户", func(t *testing.T) { - resp := suit.UserServer().DeleteUsers(suit.DefaultCtx, []*apisecurity.User{users[3]}) - - if !respSuccess(resp) { - t.Fatal(resp.GetInfo().GetValue()) - } - - qresp := suit.UserServer().GetUsers(suit.DefaultCtx, map[string]string{ - "id": users[3].GetId().GetValue(), - }) - - if !respSuccess(resp) { - t.Fatal(resp.GetInfo().GetValue()) - } - - assert.Equal(t, 0, int(qresp.Amount.GetValue())) - assert.Equal(t, 0, int(qresp.Size.GetValue())) - }) - - t.Run("正常更新用户Token", func(t *testing.T) { - resp := suit.UserServer().ResetUserToken(suit.DefaultCtx, users[0]) - - if !respSuccess(resp) { - t.Fatal(resp.GetInfo().GetValue()) - } - - _ = suit.CacheMgr().TestUpdate() - - qresp := suit.UserServer().GetUserToken(suit.DefaultCtx, users[0]) - if !respSuccess(qresp) { - t.Fatal(resp.GetInfo().GetValue()) - } - assert.Equal(t, resp.GetUser().GetAuthToken().GetValue(), qresp.GetUser().GetAuthToken().GetValue()) - }) -} diff --git a/bootstrap/server.go b/bootstrap/server.go index 117679ddb..f3f190632 100644 --- a/bootstrap/server.go +++ b/bootstrap/server.go @@ -76,7 +76,7 @@ func Start(configFilePath string) { fmt.Printf("[ERROR] config yaml marshal fail\n") return } - fmt.Printf(string(c)) + _, _ = fmt.Println(string(c)) // 初始化日志打印 err = log.Configure(cfg.Bootstrap.Logger) diff --git a/cache/api/funcs.go b/cache/api/funcs.go index 0fa3d26a6..3a4e42ef4 100644 --- a/cache/api/funcs.go +++ b/cache/api/funcs.go @@ -314,7 +314,7 @@ func AppendAuthPolicyPredicate(ctx context.Context, p AuthPolicyPredicate) conte func LoadAuthPolicyPredicates(ctx context.Context) []AuthPolicyPredicate { var predicates []AuthPolicyPredicate - val := ctx.Value(userGroupPredicateCtxKey{}) + val := ctx.Value(authPolicyPredicateCtxKey{}) if val != nil { predicates, _ = val.([]AuthPolicyPredicate) } diff --git a/cache/api/types.go b/cache/api/types.go index d000f1828..506d63aac 100644 --- a/cache/api/types.go +++ b/cache/api/types.go @@ -172,6 +172,15 @@ type CacheManager interface { type ( // NamespacePredicate . NamespacePredicate func(context.Context, *model.Namespace) bool + // NamespaceArgs + NamespaceArgs struct { + // Filter extend filter params + Filter map[string][]string + // Offset + Offset int + // Limit + Limit int + } // NamespaceCache 命名空间的 Cache 接口 NamespaceCache interface { @@ -184,6 +193,8 @@ type ( GetNamespaceList() []*model.Namespace // GetVisibleNamespaces list target namespace can visible other namespaces GetVisibleNamespaces(namespace string) []*model.Namespace + // Query . + Query(context.Context, *NamespaceArgs) (uint32, []*model.Namespace, error) } ) @@ -246,9 +257,9 @@ type ( GetServicesByFilter(ctx context.Context, serviceFilters *ServiceArgs, instanceFilters *store.InstanceArgs, offset, limit uint32) (uint32, []*model.EnhancedService, error) // ListServices get service list and revision by namespace - ListServices(ns string) (string, []*model.Service) + ListServices(ctx context.Context, ns string) (string, []*model.Service) // ListAllServices get all service and revision - ListAllServices() (string, []*model.Service) + ListAllServices(ctx context.Context) (string, []*model.Service) // ListServiceAlias list service link alias list ListServiceAlias(namespace, name string) []*model.Service // GetAliasFor get alias reference service info @@ -317,30 +328,10 @@ type ( FaultDetectArgs struct { // Filter extend filter params Filter map[string]string - // ID route rule id - ID string - // Name route rule name - Name string - // Service service name - Service string - // Namespace namesapce - Namespace string - ServiceNamespace string - DstNamespace string - DstService string - DstMethod string - // Enable - Enable *bool // Offset Offset uint32 // Limit Limit uint32 - // OrderField Sort field - OrderField string - // OrderType Sorting rules - OrderType string - // Predicates 额外的数据检查 - Predicates []FaultDetectPredicate } // FaultDetectCache fault detect rule cache service @@ -362,26 +353,10 @@ type ( LaneGroupArgs struct { // Filter extend filter params Filter map[string]string - // ID route rule id - ID string - // Name route rule name - Name string - // Service service name - Service string - // Namespace namesapce - Namespace string - // Enable - Enable *bool // Offset Offset uint32 // Limit Limit uint32 - // OrderField Sort field - OrderField string - // OrderType Sorting rules - OrderType string - // Predicates 额外的数据检查 - Predicates []FaultDetectPredicate } // LaneCache . LaneCache interface { @@ -428,8 +403,6 @@ type ( OrderField string // OrderType Sorting rules OrderType string - // Predicates 额外的数据检查 - Predicates []RouteRulePredicate } // RouterRuleIterProc Method definition of routing rules @@ -484,8 +457,6 @@ type ( OrderField string // OrderType Sorting rules OrderType string - // Predicates . - Predicates []RateLimitRulePredicate } // RateLimitIterProc rate limit iter func @@ -531,34 +502,10 @@ type ( CircuitBreakerRuleArgs struct { // Filter extend filter params Filter map[string]string - // ID route rule id - ID string - // Name route rule name - Name string - // Service service name - Service string - // Namespace namesapce - Namespace string - // SourceService source service name - SourceService string - // SourceNamespace source service namespace - SourceNamespace string - // DestinationService destination service name - DestinationService string - // DestinationNamespace destination service namespace - DestinationNamespace string - // Enable - Enable *bool // Offset Offset uint32 // Limit Limit uint32 - // OrderField Sort field - OrderField string - // OrderType Sorting rules - OrderType string - // Predicates 额外的数据检查 - Predicates []CircuitBreakerPredicate } // CircuitBreakerCache circuitBreaker配置的cache接口 CircuitBreakerCache interface { @@ -708,10 +655,12 @@ type ( // StrategyCache is a cache for strategy rules. StrategyCache interface { Cache + // GetPolicyRule 获取策略信息 + GetPolicyRule(id string) *authcommon.StrategyDetail // GetPrincipalPolicies 根据 effect 获取 principal 的策略信息 GetPrincipalPolicies(effect string, p authcommon.Principal) []*authcommon.StrategyDetail // Hint 确认某个 principal 对于资源的访问权限 - Hint(p authcommon.Principal, r *authcommon.ResourceEntry) apisecurity.AuthAction + Hint(ctx context.Context, p authcommon.Principal, r *authcommon.ResourceEntry) apisecurity.AuthAction // Query . Query(context.Context, PolicySearchArgs) (uint32, []*authcommon.StrategyDetail, error) } diff --git a/cache/auth/policy.go b/cache/auth/policy.go index 771d1bb86..a7262dd65 100644 --- a/cache/auth/policy.go +++ b/cache/auth/policy.go @@ -50,9 +50,6 @@ type policyCache struct { // principalResources principalResources map[authcommon.PrincipalType]*utils.SyncMap[string, *authcommon.PrincipalResourceContainer] - allowResourceLabels *utils.SyncMap[string, *utils.RefSyncSet[string]] - denyResourceLabels *utils.SyncMap[string, *utils.RefSyncSet[string]] - singleFlight *singleflight.Group } @@ -89,8 +86,6 @@ func (sc *policyCache) initContainers() { authcommon.PrincipalUser: utils.NewSyncMap[string, *authcommon.PrincipalResourceContainer](), authcommon.PrincipalGroup: utils.NewSyncMap[string, *authcommon.PrincipalResourceContainer](), } - sc.allowResourceLabels = utils.NewSyncMap[string, *utils.RefSyncSet[string]]() - sc.denyResourceLabels = utils.NewSyncMap[string, *utils.RefSyncSet[string]]() } func (sc *policyCache) Name() string { @@ -133,7 +128,8 @@ func (sc *policyCache) setStrategys(strategies []*authcommon.StrategyDetail) (ma for index := range strategies { rule := strategies[index] - sc.handlePrincipalPolicies(rule) + cacheData := authcommon.NewPolicyDetailCache(rule) + sc.handlePrincipalPolicies(cacheData) if !rule.Valid { sc.rules.Delete(rule.ID) remove++ @@ -143,17 +139,16 @@ func (sc *policyCache) setStrategys(strategies []*authcommon.StrategyDetail) (ma } else { update++ } - sc.rules.Store(rule.ID, authcommon.NewPolicyDetailCache(rule)) + sc.rules.Store(rule.ID, cacheData) } lastMtime = int64(math.Max(float64(lastMtime), float64(rule.ModifyTime.Unix()))) } - return map[string]time.Time{sc.Name(): time.Unix(lastMtime, 0)}, add, update, remove } // handlePrincipalPolicies -func (sc *policyCache) handlePrincipalPolicies(rule *authcommon.StrategyDetail) { +func (sc *policyCache) handlePrincipalPolicies(rule *authcommon.PolicyDetailCache) { // 计算 uid -> auth rule principals := rule.Principals @@ -192,7 +187,7 @@ func (sc *policyCache) handlePrincipalPolicies(rule *authcommon.StrategyDetail) } } -func (sc *policyCache) writePrincipalLink(principal authcommon.Principal, rule *authcommon.StrategyDetail, del bool) { +func (sc *policyCache) writePrincipalLink(principal authcommon.Principal, rule *authcommon.PolicyDetailCache, del bool) { linkContainers := sc.allowPolicies[principal.PrincipalType] if rule.Action == apisecurity.AuthAction_DENY.String() { linkContainers = sc.denyPolicies[principal.PrincipalType] @@ -210,27 +205,45 @@ func (sc *policyCache) writePrincipalLink(principal authcommon.Principal, rule * values.Add(rule.ID) } - principalResources, _ := sc.principalResources[principal.PrincipalType].ComputeIfAbsent(principal.PrincipalID, func(k string) *authcommon.PrincipalResourceContainer { - return authcommon.NewPrincipalResourceContainer() - }) + principalResources, _ := sc.principalResources[principal.PrincipalType].ComputeIfAbsent(principal.PrincipalID, + func(k string) *authcommon.PrincipalResourceContainer { + return authcommon.NewPrincipalResourceContainer() + }) - if rule.IsDeny() { - for i := range rule.Resources { - item := rule.Resources[i] - if rule.Valid { - principalResources.SaveDenyResource(item) - } else { - principalResources.DelDenyResource(item) + if oldRule, ok := sc.rules.Load(rule.ID); ok { + // 如果 action 不一致,则需要先清理掉之前的 + if oldRule.GetAction() != rule.GetAction() { + for i := range oldRule.Resources { + principalResources.DelResource(oldRule.GetAction(), oldRule.Resources[i]) + } + } else { + // 如果 action 一致,那么需要 diff 出移除的资源,然后移除 + waitRemove := make([]*authcommon.StrategyResource, 0, 8) + for i := range oldRule.Resources { + item := oldRule.Resources[i] + resContainer, ok := rule.ResourceDict[apisecurity.ResourceType(item.ResType)] + if !ok { + waitRemove = append(waitRemove, &item) + continue + } + if ok := resContainer.Contains(item.ResID); !ok { + waitRemove = append(waitRemove, &item) + } + } + for i := range waitRemove { + item := waitRemove[i] + principalResources.DelResource(rule.GetAction(), *item) } } - return } + + // 处理新的资源 for i := range rule.Resources { item := rule.Resources[i] if rule.Valid { - principalResources.SaveAllowResource(item) + principalResources.SaveResource(rule.GetAction(), item) } else { - principalResources.DelAllowResource(item) + principalResources.DelResource(rule.GetAction(), item) } } } @@ -271,8 +284,17 @@ func (sc *policyCache) GetPrincipalPolicies(effect string, p authcommon.Principa return result } +func (sc *policyCache) GetPolicyRule(id string) *authcommon.StrategyDetail { + strategy, ok := sc.rules.Load(id) + if !ok { + return nil + } + return strategy.StrategyDetail +} + // GetPrincipalResources 返回 principal 的资源信息,返回顺序为 (allow, deny) -func (sc *policyCache) Hint(p authcommon.Principal, r *authcommon.ResourceEntry) apisecurity.AuthAction { +func (sc *policyCache) Hint(ctx context.Context, p authcommon.Principal, r *authcommon.ResourceEntry) apisecurity.AuthAction { + // 先比较下资源是否存在于某些鉴权规则中 resources, ok := sc.principalResources[p.PrincipalType].Load(p.PrincipalID) if !ok { return apisecurity.AuthAction_DENY @@ -281,30 +303,52 @@ func (sc *policyCache) Hint(p authcommon.Principal, r *authcommon.ResourceEntry) if ok { return action } + // 如果没办法从直接的 resource 中判断出来,那就根据资源标签在确认下,注意,这里必须 allMatch 才可以 - if sc.hintLabels(p, r, sc.denyResourceLabels) { + if sc.hintLabels(ctx, p, r, sc.GetPrincipalPolicies("deny", p)) { return apisecurity.AuthAction_DENY } - if sc.hintLabels(p, r, sc.allowResourceLabels) { + if sc.hintLabels(ctx, p, r, sc.GetPrincipalPolicies("allow", p)) { return apisecurity.AuthAction_ALLOW } return apisecurity.AuthAction_DENY } -func (sc *policyCache) hintLabels(p authcommon.Principal, r *authcommon.ResourceEntry, - containers *utils.SyncMap[string, *utils.RefSyncSet[string]]) bool { - allMatch := true - for k, v := range r.Metadata { - labelVals, ok := sc.denyResourceLabels.Load(k) - if !ok { - allMatch = false +func (sc *policyCache) hintLabels(ctx context.Context, p authcommon.Principal, r *authcommon.ResourceEntry, + policies []*authcommon.StrategyDetail) bool { + var principalCondition []authcommon.Condition + if val, ok := ctx.Value(authcommon.ContextKeyConditions{}).([]authcommon.Condition); ok { + principalCondition = val + } + + for i := range policies { + item := policies[i] + conditions := item.Conditions + if len(conditions) == 0 { + conditions = principalCondition } - allMatch = labelVals.Contains(v) - if !allMatch { - break + allMatch := len(conditions) != 0 + for j := range conditions { + condition := conditions[j] + val, ok := r.Metadata[condition.Key] + if !ok { + allMatch = false + break + } + if compareFunc, ok := authcommon.ConditionCompareDict[condition.CompareFunc]; ok { + if allMatch = compareFunc(val, condition.Value); !allMatch { + break + } + } else { + allMatch = false + break + } + } + if allMatch { + return true } } - return allMatch + return false } // Query implements api.StrategyCache. @@ -318,9 +362,9 @@ func (sc *policyCache) Query(ctx context.Context, args types.PolicySearchArgs) ( searchOwner, hasOwner := args.Filters["owner"] searchDefault, hasDefault := args.Filters["default"] searchResType, hasResType := args.Filters["res_type"] - searchResID, _ := args.Filters["res_id"] + searchResID := args.Filters["res_id"] searchPrincipalId, hasPrincipalId := args.Filters["principal_id"] - searchPrincipalType, _ := args.Filters["principal_type"] + searchPrincipalType := args.Filters["principal_type"] predicates := types.LoadAuthPolicyPredicates(ctx) @@ -380,32 +424,21 @@ func (sc *policyCache) Query(ctx context.Context, args types.PolicySearchArgs) ( } rules = append(rules, val.StrategyDetail) }) - + sort.Slice(rules, func(i, j int) bool { + return rules[i].ModifyTime.After(rules[j].ModifyTime) + }) total, ret := sc.toPage(rules, args) return total, ret, nil } func (sc *policyCache) toPage(rules []*authcommon.StrategyDetail, args types.PolicySearchArgs) (uint32, []*authcommon.StrategyDetail) { - beginIndex := args.Offset - endIndex := beginIndex + args.Limit - totalCount := uint32(len(rules)) - - if totalCount == 0 { - return totalCount, []*authcommon.StrategyDetail{} - } - if beginIndex >= endIndex { - return totalCount, []*authcommon.StrategyDetail{} + total := uint32(len(rules)) + if args.Offset >= total || args.Limit == 0 { + return total, nil } - if beginIndex >= totalCount { - return totalCount, []*authcommon.StrategyDetail{} + endIdx := args.Offset + args.Limit + if endIdx > total { + endIdx = total } - if endIndex > totalCount { - endIndex = totalCount - } - - sort.Slice(rules, func(i, j int) bool { - return rules[i].ModifyTime.After(rules[j].ModifyTime) - }) - - return totalCount, rules[beginIndex:endIndex] + return total, rules[args.Offset:endIdx] } diff --git a/cache/auth/role.go b/cache/auth/role.go index f249c5b66..80555412b 100644 --- a/cache/auth/role.go +++ b/cache/auth/role.go @@ -21,12 +21,13 @@ import ( "context" "time" + "go.uber.org/zap" + "golang.org/x/sync/singleflight" + types "github.com/polarismesh/polaris/cache/api" authcommon "github.com/polarismesh/polaris/common/model/auth" "github.com/polarismesh/polaris/common/utils" "github.com/polarismesh/polaris/store" - "go.uber.org/zap" - "golang.org/x/sync/singleflight" ) // NewRoleCache @@ -146,7 +147,7 @@ func (r *roleCache) dealPrincipalRoles(role *authcommon.Role, isDel bool) { if isDel { users := role.Users for i := range users { - container, _ := r.principalRoles[authcommon.PrincipalUser].ComputeIfAbsent(users[i].SelfID(), + container, _ := r.principalRoles[authcommon.PrincipalUser].ComputeIfAbsent(users[i].PrincipalID, func(k string) *utils.SyncSet[string] { return utils.NewSyncSet[string]() }) @@ -154,7 +155,7 @@ func (r *roleCache) dealPrincipalRoles(role *authcommon.Role, isDel bool) { } groups := role.UserGroups for i := range groups { - container, _ := r.principalRoles[authcommon.PrincipalGroup].ComputeIfAbsent(groups[i].SelfID(), + container, _ := r.principalRoles[authcommon.PrincipalGroup].ComputeIfAbsent(groups[i].PrincipalID, func(k string) *utils.SyncSet[string] { return utils.NewSyncSet[string]() }) @@ -164,7 +165,7 @@ func (r *roleCache) dealPrincipalRoles(role *authcommon.Role, isDel bool) { } users := role.Users for i := range users { - container, _ := r.principalRoles[authcommon.PrincipalUser].ComputeIfAbsent(users[i].SelfID(), + container, _ := r.principalRoles[authcommon.PrincipalUser].ComputeIfAbsent(users[i].PrincipalID, func(k string) *utils.SyncSet[string] { return utils.NewSyncSet[string]() }) @@ -172,7 +173,7 @@ func (r *roleCache) dealPrincipalRoles(role *authcommon.Role, isDel bool) { } groups := role.UserGroups for i := range groups { - container, _ := r.principalRoles[authcommon.PrincipalGroup].ComputeIfAbsent(groups[i].SelfID(), + container, _ := r.principalRoles[authcommon.PrincipalGroup].ComputeIfAbsent(groups[i].PrincipalID, func(k string) *utils.SyncSet[string] { return utils.NewSyncSet[string]() }) @@ -226,8 +227,8 @@ func (r *roleCache) toPage(total uint32, roles []*authcommon.Role, args types.Ro if args.Limit == 0 { return total, roles } - start := args.Limit * (args.Offset - 1) - end := args.Limit * args.Offset + start := args.Limit * args.Offset + end := args.Limit * (args.Offset + 1) if start > total { return total, nil } @@ -239,7 +240,11 @@ func (r *roleCache) toPage(total uint32, roles []*authcommon.Role, args types.Ro // GetPrincipalRoles implements api.RoleCache. func (r *roleCache) GetPrincipalRoles(p authcommon.Principal) []*authcommon.Role { - containers, ok := r.principalRoles[p.PrincipalType].Load(p.PrincipalID) + roleContainers, ok := r.principalRoles[p.PrincipalType] + if !ok { + return nil + } + containers, ok := roleContainers.Load(p.PrincipalID) if !ok { return nil } diff --git a/cache/auth/user.go b/cache/auth/user.go index 1e7084875..c8f989d04 100644 --- a/cache/auth/user.go +++ b/cache/auth/user.go @@ -21,6 +21,7 @@ import ( "context" "fmt" "math" + "sort" "sync/atomic" "time" @@ -376,7 +377,7 @@ func (uc *userCache) QueryUsers(ctx context.Context, args types.UserSearchArgs) if hasId && searchId != key { return } - if hasOwner && val.Owner != searchOwner { + if hasOwner && (val.Owner != searchOwner && val.ID != searchOwner) { return } if hasName && !utils.IsWildMatch(val.Name, searchName) { @@ -393,28 +394,31 @@ func (uc *userCache) QueryUsers(ctx context.Context, args types.UserSearchArgs) result = append(result, val) }) + sort.Slice(result, func(i, j int) bool { + return result[i].ModifyTime.After(result[j].ModifyTime) + }) total, ret := uc.listUsersPage(result, args) return total, ret, nil } func (uc *userCache) listUsersPage(users []*authcommon.User, args types.UserSearchArgs) (uint32, []*authcommon.User) { total := uint32(len(users)) - if args.Limit == 0 { - return total, nil - } - start := args.Limit * (args.Offset - 1) - end := args.Limit * args.Offset - if start > total { + if args.Offset >= total || args.Limit == 0 { return total, nil } - if end > total { - end = total + endIdx := args.Offset + args.Limit + if endIdx > total { + endIdx = total } - return total, users[start:end] + return total, users[args.Offset:endIdx] } // QueryUserGroups . func (uc *userCache) QueryUserGroups(ctx context.Context, args types.UserGroupSearchArgs) (uint32, []*authcommon.UserGroupDetail, error) { + if err := uc.Update(); err != nil { + return 0, nil, err + } + searchId, hasId := args.Filters["id"] searchName, hasName := args.Filters["name"] searchOwner, hasOwner := args.Filters["owner"] @@ -460,13 +464,14 @@ func (uc *userCache) QueryUserGroups(ctx context.Context, args types.UserGroupSe return total, ret, nil } -func (uc *userCache) listUserGroupsPage(groups []*authcommon.UserGroupDetail, args types.UserGroupSearchArgs) (uint32, []*authcommon.UserGroupDetail) { +func (uc *userCache) listUserGroupsPage(groups []*authcommon.UserGroupDetail, + args types.UserGroupSearchArgs) (uint32, []*authcommon.UserGroupDetail) { total := uint32(len(groups)) if args.Limit == 0 { return total, nil } - start := args.Limit * (args.Offset - 1) - end := args.Limit * args.Offset + start := args.Limit * args.Offset + end := args.Limit * (args.Offset + 1) if start > total { return total, nil } diff --git a/cache/default.go b/cache/default.go index d89e03fce..bde20521e 100644 --- a/cache/default.go +++ b/cache/default.go @@ -110,6 +110,7 @@ func newCacheManager(ctx context.Context, cacheOpt *Config, storage store.Store) mgr.RegisterCacher(types.CacheRole, cacheauth.NewRoleCache(storage, mgr)) // 北极星SDK Client mgr.RegisterCacher(types.CacheClient, cacheclient.NewClientCache(storage, mgr)) + // 灰度规则 mgr.RegisterCacher(types.CacheGray, cachegray.NewGrayCache(storage, mgr)) if len(mgr.caches) != int(types.CacheLast) { diff --git a/cache/mock/cache_mock.go b/cache/mock/cache_mock.go index 330596ad5..6a8441864 100644 --- a/cache/mock/cache_mock.go +++ b/cache/mock/cache_mock.go @@ -581,6 +581,22 @@ func (mr *MockNamespaceCacheMockRecorder) Name() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Name", reflect.TypeOf((*MockNamespaceCache)(nil).Name)) } +// Query mocks base method. +func (m *MockNamespaceCache) Query(arg0 context.Context, arg1 *api.NamespaceArgs) (uint32, []*model.Namespace, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Query", arg0, arg1) + ret0, _ := ret[0].(uint32) + ret1, _ := ret[1].([]*model.Namespace) + ret2, _ := ret[2].(error) + return ret0, ret1, ret2 +} + +// Query indicates an expected call of Query. +func (mr *MockNamespaceCacheMockRecorder) Query(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Query", reflect.TypeOf((*MockNamespaceCache)(nil).Query), arg0, arg1) +} + // Update mocks base method. func (m *MockNamespaceCache) Update() error { m.ctrl.T.Helper() @@ -829,18 +845,18 @@ func (mr *MockServiceCacheMockRecorder) IteratorServices(iterProc interface{}) * } // ListAllServices mocks base method. -func (m *MockServiceCache) ListAllServices() (string, []*model.Service) { +func (m *MockServiceCache) ListAllServices(ctx context.Context) (string, []*model.Service) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ListAllServices") + ret := m.ctrl.Call(m, "ListAllServices", ctx) ret0, _ := ret[0].(string) ret1, _ := ret[1].([]*model.Service) return ret0, ret1 } // ListAllServices indicates an expected call of ListAllServices. -func (mr *MockServiceCacheMockRecorder) ListAllServices() *gomock.Call { +func (mr *MockServiceCacheMockRecorder) ListAllServices(ctx interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListAllServices", reflect.TypeOf((*MockServiceCache)(nil).ListAllServices)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListAllServices", reflect.TypeOf((*MockServiceCache)(nil).ListAllServices), ctx) } // ListServiceAlias mocks base method. @@ -858,18 +874,18 @@ func (mr *MockServiceCacheMockRecorder) ListServiceAlias(namespace, name interfa } // ListServices mocks base method. -func (m *MockServiceCache) ListServices(ns string) (string, []*model.Service) { +func (m *MockServiceCache) ListServices(ctx context.Context, ns string) (string, []*model.Service) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ListServices", ns) + ret := m.ctrl.Call(m, "ListServices", ctx, ns) ret0, _ := ret[0].(string) ret1, _ := ret[1].([]*model.Service) return ret0, ret1 } // ListServices indicates an expected call of ListServices. -func (mr *MockServiceCacheMockRecorder) ListServices(ns interface{}) *gomock.Call { +func (mr *MockServiceCacheMockRecorder) ListServices(ctx, ns interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListServices", reflect.TypeOf((*MockServiceCache)(nil).ListServices), ns) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListServices", reflect.TypeOf((*MockServiceCache)(nil).ListServices), ctx, ns) } // Name mocks base method. @@ -2874,6 +2890,20 @@ func (mr *MockStrategyCacheMockRecorder) Close() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockStrategyCache)(nil).Close)) } +// GetPolicyRule mocks base method. +func (m *MockStrategyCache) GetPolicyRule(id string) *auth.StrategyDetail { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetPolicyRule", id) + ret0, _ := ret[0].(*auth.StrategyDetail) + return ret0 +} + +// GetPolicyRule indicates an expected call of GetPolicyRule. +func (mr *MockStrategyCacheMockRecorder) GetPolicyRule(id interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPolicyRule", reflect.TypeOf((*MockStrategyCache)(nil).GetPolicyRule), id) +} + // GetPrincipalPolicies mocks base method. func (m *MockStrategyCache) GetPrincipalPolicies(effect string, p auth.Principal) []*auth.StrategyDetail { m.ctrl.T.Helper() diff --git a/cache/namespace/namespace.go b/cache/namespace/namespace.go index 1447ddb36..41c8b931f 100644 --- a/cache/namespace/namespace.go +++ b/cache/namespace/namespace.go @@ -18,7 +18,9 @@ package namespace import ( + "context" "math" + "sort" "time" "go.uber.org/zap" @@ -63,9 +65,7 @@ func (nsCache *namespaceCache) Initialize(c map[string]interface{}) error { // Update func (nsCache *namespaceCache) Update() error { // 多个线程竞争,只有一个线程进行更新 - _, err, _ := nsCache.updater.Do(nsCache.Name(), func() (interface{}, error) { - return nil, nsCache.DoCacheUpdate(nsCache.Name(), nsCache.realUpdate) - }) + err, _ := nsCache.singleUpdate() return err } @@ -83,7 +83,6 @@ func (nsCache *namespaceCache) realUpdate() (map[string]time.Time, int64, error) } func (nsCache *namespaceCache) setNamespaces(nsSlice []*model.Namespace) map[string]time.Time { - lastMtime := nsCache.LastMtime(nsCache.Name()).Unix() for index := range nsSlice { @@ -200,3 +199,96 @@ func (nsCache *namespaceCache) GetNamespaceList() []*model.Namespace { return nsArr } + +// forceQueryUpdate 为了确保读取的数据是最新的,这里需要做一个强制 update 的动作进行数据读取处理 +func (nsCache *namespaceCache) forceQueryUpdate() error { + err, shared := nsCache.singleUpdate() + // shared == true,表示当前已经有正在 update 执行的任务,这个任务不一定能够读取到最新的数据 + // 为了避免读取到脏数据,在发起一次 singleUpdate + if shared { + log.Debug("[Cache][Namespace] force query update from store") + err, _ = nsCache.singleUpdate() + } + return err +} + +func (nsCache *namespaceCache) singleUpdate() (error, bool) { + // 多个线程竞争,只有一个线程进行更新 + _, err, shared := nsCache.updater.Do(nsCache.Name(), func() (interface{}, error) { + return nil, nsCache.DoCacheUpdate(nsCache.Name(), nsCache.realUpdate) + }) + return err, shared +} + +func (nsCache *namespaceCache) Query(ctx context.Context, args *types.NamespaceArgs) (uint32, []*model.Namespace, error) { + if err := nsCache.forceQueryUpdate(); err != nil { + return 0, nil, err + } + + ret := make([]*model.Namespace, 0, 32) + + predicates := types.LoadNamespacePredicates(ctx) + + searchName, hasName := args.Filter["name"] + searchOwner, hasOwner := args.Filter["owner"] + + nsCache.ids.ReadRange(func(key string, val *model.Namespace) { + for i := range predicates { + if !predicates[i](ctx, val) { + return + } + } + + if hasName { + matchOne := false + for i := range searchName { + if utils.IsWildMatch(val.Name, searchName[i]) { + matchOne = true + break + } + } + // 如果没有匹配到,直接返回 + if !matchOne { + return + } + } + + if hasOwner { + matchOne := false + for i := range searchOwner { + if utils.IsWildMatch(val.Owner, searchOwner[i]) { + matchOne = true + break + } + } + // 如果没有匹配到,直接返回 + if !matchOne { + return + } + } + + ret = append(ret, val) + }) + + sort.Slice(ret, func(i, j int) bool { + return ret[i].ModifyTime.After(ret[j].ModifyTime) + }) + + total, ret := nsCache.toPage(len(ret), ret, args) + return uint32(total), ret, nil +} + +func (c *namespaceCache) toPage(total int, items []*model.Namespace, + args *types.NamespaceArgs) (int, []*model.Namespace) { + if len(items) == 0 { + return 0, []*model.Namespace{} + } + if args.Limit == 0 { + return total, items + } + endIdx := args.Offset + args.Limit + if endIdx > total { + endIdx = total + } + return total, items[args.Offset:endIdx] +} diff --git a/cache/service/circuitbreaker.go b/cache/service/circuitbreaker.go index ff02db390..03699fa8e 100644 --- a/cache/service/circuitbreaker.go +++ b/cache/service/circuitbreaker.go @@ -22,6 +22,8 @@ import ( "crypto/sha1" "fmt" "sort" + "strconv" + "strings" "sync" "time" @@ -392,13 +394,155 @@ func (c *circuitBreakerCache) GetCircuitBreakerCount() int { return len(names) } +var ( + ignoreCircuitBreakerRuleFilter = map[string]struct{}{ + "brief": {}, + "service": {}, + "serviceNamespace": {}, + "exactName": {}, + "excludeId": {}, + } + + cbBlurSearchFields = map[string]func(*model.CircuitBreakerRule) string{ + "name": func(cbr *model.CircuitBreakerRule) string { + return cbr.Name + }, + "description": func(cbr *model.CircuitBreakerRule) string { + return cbr.Description + }, + "srcservice": func(cbr *model.CircuitBreakerRule) string { + return cbr.SrcService + }, + "dstservice": func(cbr *model.CircuitBreakerRule) string { + return cbr.DstService + }, + "dstmethod": func(cbr *model.CircuitBreakerRule) string { + return cbr.DstMethod + }, + } + + circuitBreakerSort = map[string]func(asc bool, a, b *model.CircuitBreakerRule) bool{ + "mtime": func(asc bool, a, b *model.CircuitBreakerRule) bool { + ret := a.ModifyTime.Before(b.ModifyTime) + return ret && asc + }, + "id": func(asc bool, a, b *model.CircuitBreakerRule) bool { + ret := a.ID < b.ID + return ret && asc + }, + "name": func(asc bool, a, b *model.CircuitBreakerRule) bool { + ret := a.Name < b.Name + return ret && asc + }, + } +) + // Query implements api.CircuitBreakerCache. -func (c *circuitBreakerCache) Query(context.Context, *types.CircuitBreakerRuleArgs) (uint32, []*model.CircuitBreakerRule, error) { - panic("unimplemented") +func (c *circuitBreakerCache) Query(ctx context.Context, args *types.CircuitBreakerRuleArgs) (uint32, []*model.CircuitBreakerRule, error) { + if err := c.Update(); err != nil { + return 0, nil, err + } + + predicates := types.LoadCircuitBreakerRulePredicates(ctx) + + searchSvc, hasSvc := args.Filter["service"] + searchNs, hasSvcNs := args.Filter["serviceNamespace"] + exactNameValue, hasExactName := args.Filter["exactName"] + excludeIdValue, hasExcludeId := args.Filter["excludeId"] + + lowerFilter := make(map[string]string, len(args.Filter)) + for k, v := range args.Filter { + if _, ok := ignoreCircuitBreakerRuleFilter[k]; ok { + continue + } + lowerFilter[strings.ToLower(k)] = v + } + + results := make([]*model.CircuitBreakerRule, 0, 32) + c.rules.ReadRange(func(key string, val *model.CircuitBreakerRule) { + if hasSvcNs { + srcNsValue := val.SrcNamespace + dstNsValue := val.DstNamespace + if !((srcNsValue == "*" || srcNsValue == searchNs) || (dstNsValue == "*" || dstNsValue == searchNs)) { + return + } + } + if hasSvc { + srcSvcValue := val.SrcService + dstSvcValue := val.DstService + if !((srcSvcValue == searchSvc || srcSvcValue == "*") || (dstSvcValue == searchSvc || dstSvcValue == "*")) { + return + } + } + if hasExactName && exactNameValue != val.Name { + return + } + if hasExcludeId && excludeIdValue != val.ID { + return + } + for fieldKey, filterValue := range lowerFilter { + getter, isBlur := cbBlurSearchFields[fieldKey] + if isBlur { + if utils.IsWildMatch(getter(val), filterValue) { + return + } + } else if fieldKey == "enable" { + if filterValue != strconv.FormatBool(val.Enable) { + return + } + } else if fieldKey == "level" { + levels := strings.Split(filterValue, ",") + var inLevel = false + for _, level := range levels { + levelInt, _ := strconv.Atoi(level) + if int64(levelInt) == int64(val.Level) { + inLevel = true + break + } + } + if !inLevel { + return + } + } else { + // FIXME 暂时不知道还有什么字段查询需要适配,等待自测验证 + } + } + for i := range predicates { + if !predicates[i](ctx, val) { + return + } + } + + results = append(results, val) + }) + + sortFunc, ok := circuitBreakerSort[args.Filter["order_field"]] + if !ok { + sortFunc = circuitBreakerSort["mtime"] + } + asc := "asc" == strings.ToLower(args.Filter["order_type"]) + sort.Slice(results, func(i, j int) bool { + return sortFunc(asc, results[i], results[j]) + }) + + total, ret := c.toPage(uint32(len(results)), results, args) + return total, ret, nil +} + +func (c *circuitBreakerCache) toPage(total uint32, items []*model.CircuitBreakerRule, + args *types.CircuitBreakerRuleArgs) (uint32, []*model.CircuitBreakerRule) { + if args.Limit == 0 { + return total, items + } + endIdx := args.Offset + args.Limit + if endIdx > total { + endIdx = total + } + return total, items[args.Offset:endIdx] } -// GetRule implements api.FaultDetectCache. -func (f *circuitBreakerCache) GetRule(id string) *model.CircuitBreakerRule { - rule, _ := f.rules.Load(id) +// GetRule implements api.CircuitBreakerCache. +func (c *circuitBreakerCache) GetRule(id string) *model.CircuitBreakerRule { + rule, _ := c.rules.Load(id) return rule } diff --git a/cache/service/faultdetect.go b/cache/service/faultdetect.go index 92a88d9b7..8e59792d6 100644 --- a/cache/service/faultdetect.go +++ b/cache/service/faultdetect.go @@ -22,6 +22,7 @@ import ( "crypto/sha1" "fmt" "sort" + "strings" "sync" "time" @@ -350,9 +351,125 @@ func (f *faultDetectCache) GetFaultDetectRuleCount(fun func(k, v interface{}) bo } } +var ( + ignoreFaultDetectRuleFilter = map[string]struct{}{ + "brief": {}, + "service": {}, + "serviceNamespace": {}, + "exactName": {}, + "excludeId": {}, + } + + fdBlurSearchFields = map[string]func(*model.FaultDetectRule) string{ + "name": func(cbr *model.FaultDetectRule) string { + return cbr.Name + }, + "description": func(cbr *model.FaultDetectRule) string { + return cbr.Description + }, + "dstservice": func(cbr *model.FaultDetectRule) string { + return cbr.DstService + }, + "dstmethod": func(cbr *model.FaultDetectRule) string { + return cbr.DstMethod + }, + } + + faultDetectSort = map[string]func(asc bool, a, b *model.FaultDetectRule) bool{ + "mtime": func(asc bool, a, b *model.FaultDetectRule) bool { + ret := a.ModifyTime.Before(b.ModifyTime) + return ret && asc + }, + "id": func(asc bool, a, b *model.FaultDetectRule) bool { + ret := a.ID < b.ID + return ret && asc + }, + "name": func(asc bool, a, b *model.FaultDetectRule) bool { + ret := a.Name < b.Name + return ret && asc + }, + } +) + // Query implements api.FaultDetectCache. -func (f *faultDetectCache) Query(context.Context, *types.FaultDetectArgs) (uint32, []*model.FaultDetectRule, error) { - panic("unimplemented") +func (f *faultDetectCache) Query(ctx context.Context, args *types.FaultDetectArgs) (uint32, []*model.FaultDetectRule, error) { + if err := f.Update(); err != nil { + return 0, nil, err + } + + results := make([]*model.FaultDetectRule, 0, 32) + + predicates := types.LoadFaultDetectRulePredicates(ctx) + + searchSvc, hasSvc := args.Filter["service"] + searchNs, hasSvcNs := args.Filter["serviceNamespace"] + exactNameValue, hasExactName := args.Filter["exactName"] + excludeIdValue, hasExcludeId := args.Filter["excludeId"] + + lowerFilter := make(map[string]string, len(args.Filter)) + for k, v := range args.Filter { + if _, ok := ignoreCircuitBreakerRuleFilter[k]; ok { + continue + } + lowerFilter[strings.ToLower(k)] = v + } + + f.rules.ReadRange(func(key string, val *model.FaultDetectRule) { + if hasSvc && hasSvcNs { + dstServiceValue := val.DstService + dstNamespaceValue := val.DstNamespace + if !(dstServiceValue == searchSvc && dstNamespaceValue == searchNs) { + return + } + } + if hasExactName && exactNameValue != val.Name { + return + } + if hasExcludeId && excludeIdValue != val.ID { + return + } + for fieldKey, filterValue := range lowerFilter { + getter, isBlur := fdBlurSearchFields[fieldKey] + if isBlur { + if utils.IsWildMatch(getter(val), filterValue) { + return + } + } else { + // FIXME 暂时不知道还有什么字段查询需要适配,等待自测验证 + } + } + for i := range predicates { + if !predicates[i](ctx, val) { + return + } + } + + results = append(results, val) + }) + + sortFunc, ok := faultDetectSort[args.Filter["order_field"]] + if !ok { + sortFunc = faultDetectSort["mtime"] + } + asc := "asc" == strings.ToLower(args.Filter["order_type"]) + sort.Slice(results, func(i, j int) bool { + return sortFunc(asc, results[i], results[j]) + }) + + total, ret := f.toPage(uint32(len(results)), results, args) + return total, ret, nil +} + +func (f *faultDetectCache) toPage(total uint32, items []*model.FaultDetectRule, + args *types.FaultDetectArgs) (uint32, []*model.FaultDetectRule) { + if args.Limit == 0 { + return total, items + } + endIdx := args.Offset + args.Limit + if endIdx > total { + endIdx = total + } + return total, items[args.Offset:endIdx] } // GetRule implements api.FaultDetectCache. diff --git a/cache/service/lane.go b/cache/service/lane.go index 7d5b31973..cb54cb805 100644 --- a/cache/service/lane.go +++ b/cache/service/lane.go @@ -19,6 +19,8 @@ package service import ( "context" + "sort" + "strings" "time" "github.com/golang/protobuf/proto" @@ -145,12 +147,8 @@ func (lc *LaneCache) processLaneRuleUpsert(old, item *model.LaneGroupProto, affe waitDelServices[ns][svc] = struct{}{} } removeServiceIfExist := func(ns, svc string) { - if _, ok := waitDelServices[ns]; !ok { - waitDelServices[ns] = map[string]struct{}{} - } - if _, ok := waitDelServices[ns][svc]; ok { - delete(waitDelServices[ns], svc) - } + waitDelServices[ns] = map[string]struct{}{} + delete(waitDelServices[ns], svc) } handle := func(rule *model.LaneGroupProto, serviceOp func(ns, svc string), ruleOp func(string, string, *model.LaneGroupProto)) { @@ -356,13 +354,80 @@ func anyToSelector(data *anypb.Any, msg proto.Message) error { return nil } +var ( + laneGroupSort = map[string]func(asc bool, a, b *model.LaneGroupProto) bool{ + "mtime": func(asc bool, a, b *model.LaneGroupProto) bool { + ret := a.ModifyTime.Before(b.ModifyTime) + return ret && asc + }, + "id": func(asc bool, a, b *model.LaneGroupProto) bool { + ret := a.ID < b.ID + return ret && asc + }, + "name": func(asc bool, a, b *model.LaneGroupProto) bool { + ret := a.Name < b.Name + return ret && asc + }, + } +) + // Query implements api.LaneCache. -func (lc *LaneCache) Query(context.Context, *types.LaneGroupArgs) (uint32, []*model.LaneGroupProto, error) { - panic("unimplemented") +func (lc *LaneCache) Query(ctx context.Context, args *types.LaneGroupArgs) (uint32, []*model.LaneGroupProto, error) { + if err := lc.Update(); err != nil { + return 0, nil, err + } + + predicates := types.LoadLaneRulePredicates(ctx) + + searchName, hasName := args.Filter["name"] + searchId, hasId := args.Filter["id"] + + results := make([]*model.LaneGroupProto, 0, 32) + + lc.rules.ReadRange(func(key string, val *model.LaneGroupProto) { + if hasName && !utils.IsWildMatch(val.Name, searchName) { + return + } + if hasId && val.ID != searchId { + return + } + + for i := range predicates { + if !predicates[i](ctx, val) { + return + } + } + + results = append(results, val) + }) + + sortFunc, ok := laneGroupSort[args.Filter["order_field"]] + if !ok { + sortFunc = laneGroupSort["mtime"] + } + asc := "asc" == strings.ToLower(args.Filter["order_type"]) + sort.Slice(results, func(i, j int) bool { + return sortFunc(asc, results[i], results[j]) + }) + + total, ret := lc.toPage(uint32(len(results)), results, args) + return total, ret, nil +} + +func (lc *LaneCache) toPage(total uint32, items []*model.LaneGroupProto, + args *types.LaneGroupArgs) (uint32, []*model.LaneGroupProto) { + if args.Limit == 0 { + return total, items + } + endIdx := args.Offset + args.Limit + if endIdx > total { + endIdx = total + } + return total, items[args.Offset:endIdx] } // GetRule implements api.LaneCache. -func (f *LaneCache) GetRule(id string) *model.LaneGroup { - rule, _ := f.rules.Load(id) +func (lc *LaneCache) GetRule(id string) *model.LaneGroup { + rule, _ := lc.rules.Load(id) return rule.LaneGroup } diff --git a/cache/service/ratelimit_query.go b/cache/service/ratelimit_query.go index 53f643073..6673aff4e 100644 --- a/cache/service/ratelimit_query.go +++ b/cache/service/ratelimit_query.go @@ -27,20 +27,14 @@ import ( "github.com/polarismesh/polaris/common/utils" ) -// forceUpdate 更新配置 -func (rlc *rateLimitCache) forceUpdate() error { - if err := rlc.Update(); err != nil { - return err - } - return nil -} - // QueryRateLimitRules func (rlc *rateLimitCache) QueryRateLimitRules(ctx context.Context, args types.RateLimitRuleArgs) (uint32, []*model.RateLimit, error) { - if err := rlc.forceUpdate(); err != nil { + if err := rlc.Update(); err != nil { return 0, nil, err } + predicates := types.LoadRatelimitRulePredicates(ctx) + hasService := len(args.Service) != 0 hasNamespace := len(args.Namespace) != 0 @@ -65,6 +59,13 @@ func (rlc *rateLimitCache) QueryRateLimitRules(ctx context.Context, args types.R if args.Disable != nil && *args.Disable != rule.Disable { return } + + for i := range predicates { + if !predicates[i](ctx, rule) { + return + } + } + res = append(res, rule) } rlc.IteratorRateLimit(process) diff --git a/cache/service/router_rule.go b/cache/service/router_rule.go index 1529494bf..b146aec70 100644 --- a/cache/service/router_rule.go +++ b/cache/service/router_rule.go @@ -41,13 +41,9 @@ type ( container *RouteRuleContainer - lastMtimeV1 time.Time - lastMtimeV2 time.Time + lastMtime time.Time singleFlight singleflight.Group - - // waitDealV1RuleIds Records need to be converted from V1 to V2 routing rules ID - waitDealV1RuleIds *utils.SyncMap[string, *model.RoutingConfig] } ) @@ -61,9 +57,7 @@ func NewRouteRuleCache(s store.Store, cacheMgr types.CacheManager) types.Routing // initialize The function of implementing the cache interface func (rc *RouteRuleCache) Initialize(_ map[string]interface{}) error { - rc.lastMtimeV1 = time.Unix(0, 0) - rc.lastMtimeV2 = time.Unix(0, 0) - rc.waitDealV1RuleIds = utils.NewSyncMap[string, *model.RoutingConfig]() + rc.lastMtime = time.Unix(0, 0) rc.container = newRouteRuleContainer() rc.serviceCache = rc.BaseCache.CacheMgr.GetCacher(types.CacheService).(*serviceCache) return nil @@ -80,12 +74,6 @@ func (rc *RouteRuleCache) Update() error { // update The function of implementing the cache interface func (rc *RouteRuleCache) realUpdate() (map[string]time.Time, int64, error) { - outV1, err := rc.storage.GetRoutingConfigsForCache(rc.LastFetchTime(), rc.IsFirstUpdate()) - if err != nil { - log.Errorf("[Cache] routing config v1 cache get from store err: %s", err.Error()) - return nil, -1, err - } - outV2, err := rc.storage.GetRoutingConfigsV2ForCache(rc.LastFetchTime(), rc.IsFirstUpdate()) if err != nil { log.Errorf("[Cache] routing config v2 cache get from store err: %s", err.Error()) @@ -93,19 +81,16 @@ func (rc *RouteRuleCache) realUpdate() (map[string]time.Time, int64, error) { } lastMtimes := map[string]time.Time{} - rc.setRoutingConfigV1(lastMtimes, outV1) - rc.setRoutingConfigV2(lastMtimes, outV2) + rc.setRouterRules(lastMtimes, outV2) rc.container.reload() - return lastMtimes, int64(len(outV1) + len(outV2)), err + return lastMtimes, int64(len(outV2)), err } // Clear The function of implementing the cache interface func (rc *RouteRuleCache) Clear() error { rc.BaseCache.Clear() - rc.waitDealV1RuleIds = utils.NewSyncMap[string, *model.RoutingConfig]() rc.container = newRouteRuleContainer() - rc.lastMtimeV1 = time.Unix(0, 0) - rc.lastMtimeV2 = time.Unix(0, 0) + rc.lastMtime = time.Unix(0, 0) return nil } @@ -232,50 +217,8 @@ func (rc *RouteRuleCache) GetRule(id string) *model.ExtendRouterConfig { return rule } -// setRoutingConfigV1 Update the data of the store to the cache and convert to v2 model -func (rc *RouteRuleCache) setRoutingConfigV1(lastMtimes map[string]time.Time, cs []*model.RoutingConfig) { - if len(cs) == 0 { - return - } - lastMtimeV1 := rc.LastMtime(rc.Name()).Unix() - for _, entry := range cs { - if entry.ID == "" { - continue - } - if entry.ModifyTime.Unix() > lastMtimeV1 { - lastMtimeV1 = entry.ModifyTime.Unix() - } - if !entry.Valid { - // Delete the cache converted to V2 - rc.container.deleteV1(entry.ID) - continue - } - rc.waitDealV1RuleIds.Store(entry.ID, entry) - } - - rc.waitDealV1RuleIds.Range(func(key string, val *model.RoutingConfig) { - // Save to the new V2 cache - ok, rules, err := rc.convertV1toV2(val) - if err != nil { - log.Warn("[Cache] routing parse v1 => v2 fail, will try again next", - zap.String("rule-id", val.ID), zap.Error(err)) - return - } - if !ok { - log.Warn("[Cache] routing parse v1 => v2 is nil, will try again next", zap.String("rule-id", val.ID)) - return - } - if ok && len(rules) != 0 { - rc.waitDealV1RuleIds.Delete(key) - rc.container.saveV1(val, rules) - } - }) - lastMtimes[rc.Name()] = time.Unix(lastMtimeV1, 0) - log.Infof("[Cache] convert routing parse v1 => v2 count : %d", rc.container.convertV2Size()) -} - -// setRoutingConfigV2 Store V2 Router Caches -func (rc *RouteRuleCache) setRoutingConfigV2(lastMtimes map[string]time.Time, cs []*model.RouterConfig) { +// setRouterRules Store V2 Router Caches +func (rc *RouteRuleCache) setRouterRules(lastMtimes map[string]time.Time, cs []*model.RouterConfig) { if len(cs) == 0 { return } diff --git a/cache/service/router_rule_bucket.go b/cache/service/router_rule_bucket.go index 5b32ea53f..f60a41eb4 100644 --- a/cache/service/router_rule_bucket.go +++ b/cache/service/router_rule_bucket.go @@ -38,7 +38,7 @@ type ServiceWithRouterRules struct { rules map[string]*model.ExtendRouterConfig revision string - customv1Rules *apitraffic.Routing + customv1RuleRef *utils.AtomicValue[*apitraffic.Routing] } func NewServiceWithRouterRules(svcKey model.ServiceKey, direction model.TrafficDirection) *ServiceWithRouterRules { @@ -51,19 +51,20 @@ func NewServiceWithRouterRules(svcKey model.ServiceKey, direction model.TrafficD // AddRouterRule 添加路由规则,注意,这里只会保留处于 Enable 状态的路由规则 func (s *ServiceWithRouterRules) AddRouterRule(rule *model.ExtendRouterConfig) { - if !rule.Enable { - return - } if rule.GetRoutingPolicy() == apitraffic.RoutingPolicy_RulePolicy { - s.customv1Rules = &apitraffic.Routing{ + s.customv1RuleRef = utils.NewAtomicValue[*apitraffic.Routing](&apitraffic.Routing{ Inbounds: []*apitraffic.Route{}, Outbounds: []*apitraffic.Route{}, - } + }) } s.mutex.Lock() defer s.mutex.Unlock() - s.rules[rule.ID] = rule + if !rule.Enable { + delete(s.rules, rule.ID) + } else { + s.rules[rule.ID] = rule + } } func (s *ServiceWithRouterRules) DelRouterRule(id string) { @@ -92,6 +93,13 @@ func (s *ServiceWithRouterRules) CountRouterRules() int { return len(s.rules) } +func (s *ServiceWithRouterRules) GetRouteRuleV1() *apitraffic.Routing { + if !s.customv1RuleRef.HasValue() { + return nil + } + return s.customv1RuleRef.Load() +} + func (s *ServiceWithRouterRules) Clear() { s.mutex.Lock() defer s.mutex.Unlock() @@ -135,7 +143,7 @@ func (s *ServiceWithRouterRules) reloadRevision() { } func (s *ServiceWithRouterRules) reloadV1Rules() { - if s.customv1Rules == nil { + if !s.customv1RuleRef.HasValue() { return } @@ -151,19 +159,21 @@ func (s *ServiceWithRouterRules) reloadV1Rules() { routes := make([]*apitraffic.Route, 0, 32) for i := range rules { - if rules[i].Priority != uint32(apitraffic.RoutingPolicy_RulePolicy) { + if rules[i].Policy != apitraffic.RoutingPolicy_RulePolicy.String() { continue } routes = append(routes, model.BuildRoutes(rules[i], s.direction)...) } - s.customv1Rules = &apitraffic.Routing{} + customv1Rules := &apitraffic.Routing{} switch s.direction { case model.TrafficDirection_INBOUND: - s.customv1Rules.Inbounds = routes + customv1Rules.Inbounds = routes case model.TrafficDirection_OUTBOUND: - s.customv1Rules.Outbounds = routes + customv1Rules.Outbounds = routes } + + s.customv1RuleRef.Store(customv1Rules) } func newClientRouteRuleContainer(direction model.TrafficDirection) *ClientRouteRuleContainer { @@ -176,6 +186,9 @@ func newClientRouteRuleContainer(direction model.TrafficDirection) *ClientRouteR } type ClientRouteRuleContainer struct { + // lock . + lock sync.RWMutex + direction model.TrafficDirection // key1: namespace, key2: service exactRules *utils.SyncMap[string, *ServiceWithRouterRules] @@ -188,6 +201,9 @@ type ClientRouteRuleContainer struct { func (c *ClientRouteRuleContainer) SearchRouteRuleV2(svc model.ServiceKey) []*model.ExtendRouterConfig { ret := make([]*model.ExtendRouterConfig, 0, 32) + c.lock.RLock() + defer c.lock.RUnlock() + exactRule, existExactRule := c.exactRules.Load(svc.Domain()) if existExactRule { exactRule.IterateRouterRules(func(erc *model.ExtendRouterConfig) { @@ -205,6 +221,10 @@ func (c *ClientRouteRuleContainer) SearchRouteRuleV2(svc model.ServiceKey) []*mo c.allWildcardRules.IterateRouterRules(func(erc *model.ExtendRouterConfig) { ret = append(ret, erc) }) + + sort.Slice(ret, func(i, j int) bool { + return ret[i].Priority < ret[j].Priority + }) return ret } @@ -222,18 +242,18 @@ func (c *ClientRouteRuleContainer) SearchCustomRuleV1(svc model.ServiceKey) (*ap switch c.direction { case model.TrafficDirection_INBOUND: if existExactRule { - ret.Inbounds = append(ret.Inbounds, exactRule.customv1Rules.Inbounds...) + ret.Inbounds = append(ret.Inbounds, exactRule.GetRouteRuleV1().GetInbounds()...) } if existNsWildcardRule { - ret.Inbounds = append(ret.Inbounds, nsWildcardRule.customv1Rules.Inbounds...) + ret.Inbounds = append(ret.Inbounds, nsWildcardRule.GetRouteRuleV1().GetInbounds()...) } default: if existExactRule { - ret.Outbounds = append(ret.Outbounds, exactRule.customv1Rules.Outbounds...) + ret.Outbounds = append(ret.Outbounds, exactRule.GetRouteRuleV1().GetOutbounds()...) revisions = append(revisions, exactRule.revision) } if existNsWildcardRule { - ret.Outbounds = append(ret.Outbounds, nsWildcardRule.customv1Rules.Outbounds...) + ret.Outbounds = append(ret.Outbounds, nsWildcardRule.GetRouteRuleV1().GetOutbounds()...) } } if existExactRule { @@ -243,6 +263,14 @@ func (c *ClientRouteRuleContainer) SearchCustomRuleV1(svc model.ServiceKey) (*ap revisions = append(revisions, nsWildcardRule.revision) } + // 最终在做一次排序 + sort.Slice(ret.Inbounds, func(i, j int) bool { + return model.CompareRoutingV1(ret.Inbounds[i], ret.Inbounds[j]) + }) + sort.Slice(ret.Outbounds, func(i, j int) bool { + return model.CompareRoutingV1(ret.Outbounds[i], ret.Outbounds[j]) + }) + return ret, revisions } @@ -293,6 +321,19 @@ func (c *ClientRouteRuleContainer) RemoveRule(svcKey model.ServiceKey, ruleId st } } +func (c *ClientRouteRuleContainer) CleanAllRule(ruleId string) { + // level1 级别 cache 处理 + c.exactRules.Range(func(key string, svcContainer *ServiceWithRouterRules) { + svcContainer.DelRouterRule(ruleId) + }) + // level2 级别 cache 处理 + c.nsWildcardRules.Range(func(key string, svcContainer *ServiceWithRouterRules) { + svcContainer.DelRouterRule(ruleId) + }) + // level3 级别 cache 处理 + c.allWildcardRules.DelRouterRule(ruleId) +} + func newRouteRuleContainer() *RouteRuleContainer { return &RouteRuleContainer{ rules: utils.NewSyncMap[string, *model.ExtendRouterConfig](), @@ -331,7 +372,13 @@ type RouteRuleContainer struct { func (b *RouteRuleContainer) saveV2(conf *model.ExtendRouterConfig) { b.rules.Store(conf.ID, conf) handler := func(container *ClientRouteRuleContainer, svcKey model.ServiceKey) { + // 避免读取到中间状态数据 + container.lock.Lock() + defer container.lock.Unlock() + b.effect.Add(svcKey) + // 先删除,再保存 + container.CleanAllRule(conf.ID) container.SaveRule(svcKey, conf) } diff --git a/cache/service/router_rule_query.go b/cache/service/router_rule_query.go index b5262c287..88473f8c8 100644 --- a/cache/service/router_rule_query.go +++ b/cache/service/router_rule_query.go @@ -29,14 +29,6 @@ import ( "github.com/polarismesh/polaris/common/utils" ) -// forceUpdate 更新配置 -func (rc *RouteRuleCache) forceUpdate() error { - if err := rc.Update(); err != nil { - return err - } - return nil -} - func queryRoutingRuleV2ByService(rule *model.ExtendRouterConfig, sourceNamespace, sourceService, destNamespace, destService string, both bool) bool { var ( @@ -121,7 +113,7 @@ func queryRoutingRuleV2ByService(rule *model.ExtendRouterConfig, sourceNamespace // QueryRoutingConfigsV2 Query Route Configuration List func (rc *RouteRuleCache) QueryRoutingConfigsV2(ctx context.Context, args *types.RoutingArgs) (uint32, []*model.ExtendRouterConfig, error) { - if err := rc.forceUpdate(); err != nil { + if err := rc.Update(); err != nil { return 0, nil, err } hasSvcQuery := len(args.Service) != 0 || len(args.Namespace) != 0 @@ -181,7 +173,14 @@ func (rc *RouteRuleCache) QueryRoutingConfigsV2(ctx context.Context, args *types res = append(res, routeRule) } + predicates := types.LoadRouterRulePredicates(ctx) + rc.IteratorRouterRule(func(key string, value *model.ExtendRouterConfig) { + for i := range predicates { + if !predicates[i](ctx, value) { + return + } + } process(key, value) }) diff --git a/cache/service/service.go b/cache/service/service.go index 03a5d5c4a..6b3ae7e56 100644 --- a/cache/service/service.go +++ b/cache/service/service.go @@ -340,13 +340,45 @@ func (sc *serviceCache) GetServicesCount() int { } // ListServices get service list and revision by namespace -func (sc *serviceCache) ListServices(ns string) (string, []*model.Service) { - return sc.serviceList.ListServices(ns) +func (sc *serviceCache) ListServices(ctx context.Context, ns string) (string, []*model.Service) { + revision, matchServices := sc.serviceList.ListServices(ns) + predicates := types.LoadServicePredicates(ctx) + ret := make([]*model.Service, 0, len(matchServices)) + for i := range matchServices { + allMatch := true + for j := range predicates { + if !predicates[j](ctx, matchServices[i]) { + allMatch = false + break + } + } + if allMatch { + ret = append(ret, matchServices[i]) + } + } + matchServices = ret + return revision, matchServices } // ListAllServices get all service and revision -func (sc *serviceCache) ListAllServices() (string, []*model.Service) { - return sc.serviceList.ListAllServices() +func (sc *serviceCache) ListAllServices(ctx context.Context) (string, []*model.Service) { + revision, matchServices := sc.serviceList.ListAllServices() + predicates := types.LoadServicePredicates(ctx) + ret := make([]*model.Service, 0, len(matchServices)) + for i := range matchServices { + pass := true + for j := range predicates { + if !predicates[j](ctx, matchServices[i]) { + pass = false + break + } + } + if pass { + ret = append(ret, matchServices[i]) + } + } + matchServices = ret + return revision, matchServices } // ListServiceAlias get all service alias by target service diff --git a/cache/service/service_contract.go b/cache/service/service_contract.go index a5eddece4..b3dc30a37 100644 --- a/cache/service/service_contract.go +++ b/cache/service/service_contract.go @@ -116,11 +116,11 @@ func (sc *ServiceContractCache) setContracts(values []*model.EnrichServiceContra item := values[i] if !item.Valid { del++ - sc.upsertValueCache(item, true) + _ = sc.upsertValueCache(item, true) continue } upsert++ - sc.upsertValueCache(item, false) + _ = sc.upsertValueCache(item, false) } return map[string]time.Time{ sc.Name(): lastMtime, diff --git a/cache/service/service_query.go b/cache/service/service_query.go index 27e6f4ee4..25e26006c 100644 --- a/cache/service/service_query.go +++ b/cache/service/service_query.go @@ -74,6 +74,23 @@ func (sc *serviceCache) GetServicesByFilter(ctx context.Context, serviceFilters matchServices = tmpSvcs } + // 这里需要额外做过滤判断 + predicates := types.LoadServicePredicates(ctx) + ret := make([]*model.Service, 0, len(matchServices)) + for i := range matchServices { + pass := true + for pi := range predicates { + if !predicates[pi](ctx, matchServices[i]) { + pass = false + break + } + } + if pass { + ret = append(ret, matchServices[i]) + } + } + matchServices = ret + amount, services := sortBeforeTrim(matchServices, offset, limit) var enhancedServices []*model.EnhancedService diff --git a/common/api/v1/config_response.go b/common/api/v1/config_response.go index ccf34820b..d9d1c94b5 100644 --- a/common/api/v1/config_response.go +++ b/common/api/v1/config_response.go @@ -226,6 +226,13 @@ func NewConfigFileReleaseHistoryResponse( } } +func NewSimpleConfigFileImportResponse(code apimodel.Code) *apiconfig.ConfigImportResponse { + return &apiconfig.ConfigImportResponse{ + Code: &wrappers.UInt32Value{Value: uint32(code)}, + Info: &wrappers.StringValue{Value: code2info[uint32(code)]}, + } +} + func NewConfigFileImportResponse(code apimodel.Code, createConfigFiles, skipConfigFiles, overwriteConfigFiles []*apiconfig.ConfigFile) *apiconfig.ConfigImportResponse { return &apiconfig.ConfigImportResponse{ diff --git a/common/model/auth/acquire_context.go b/common/model/auth/acquire_context.go index 4a8c7a92d..09c78b79a 100644 --- a/common/model/auth/acquire_context.go +++ b/common/model/auth/acquire_context.go @@ -38,7 +38,7 @@ type AcquireContext struct { // Module 来自那个业务层(服务注册与服务治理、配置模块) module BzModule // Method 操作函数 - method ServerFunctionName + methods []ServerFunctionName // Operation 本次操作涉及的动作 operation ResourceOperation // Resources 本次 @@ -98,7 +98,7 @@ func WithModule(module BzModule) acquireContextOption { // WithMethod 本次操作函数名称 func WithMethod(method ServerFunctionName) acquireContextOption { return func(authCtx *AcquireContext) { - authCtx.method = method + authCtx.methods = []ServerFunctionName{method} } } @@ -118,7 +118,7 @@ func WithOperation(operation ResourceOperation) acquireContextOption { // @return acquireContextOption func WithAccessResources(accessResources map[apisecurity.ResourceType][]ResourceEntry) acquireContextOption { return func(authCtx *AcquireContext) { - authCtx.accessResources = accessResources + authCtx.SetAccessResources(accessResources) } } @@ -180,6 +180,11 @@ func (authCtx *AcquireContext) GetOperation() ResourceOperation { return authCtx.operation } +// SetOperation 设置本次操作的类型 +func (authCtx *AcquireContext) SetOperation(op ResourceOperation) { + authCtx.operation = op +} + // GetAccessResources 获取本次请求的资源 // // @receiver authCtx @@ -193,7 +198,20 @@ func (authCtx *AcquireContext) GetAccessResources() map[apisecurity.ResourceType // @receiver authCtx // @param accessRes func (authCtx *AcquireContext) SetAccessResources(accessRes map[apisecurity.ResourceType][]ResourceEntry) { - authCtx.accessResources = accessRes + copyM := map[apisecurity.ResourceType][]ResourceEntry{} + for k, v := range accessRes { + if len(v) == 0 { + continue + } + copyM[k] = v + } + + authCtx.accessResources = copyM +} + +// SetMethod 设置本次请求涉及的操作函数 +func (authCtx *AcquireContext) SetMethod(methods []ServerFunctionName) { + authCtx.methods = methods } // GetAttachments 获取本次请求的额外携带信息 @@ -213,8 +231,8 @@ func (authCtx *AcquireContext) SetAttachment(key string, val interface{}) { } // GetMethod 获取本次请求涉及的操作函数 -func (authCtx *AcquireContext) GetMethod() ServerFunctionName { - return authCtx.method +func (authCtx *AcquireContext) GetMethods() []ServerFunctionName { + return authCtx.methods } // SetFromClient 本次请求来自客户端 diff --git a/common/model/auth/auth.go b/common/model/auth/auth.go index 0f879dc17..270fc8574 100644 --- a/common/model/auth/auth.go +++ b/common/model/auth/auth.go @@ -371,6 +371,8 @@ type StrategyDetail struct { Comment string Default bool Owner string + // 来源 + Source string // CalleeMethods 允许访问的服务端接口 CalleeMethods []string Resources []StrategyResource @@ -383,6 +385,61 @@ type StrategyDetail struct { ModifyTime time.Time } +func (s *StrategyDetail) GetAction() apisecurity.AuthAction { + if s.Action == apisecurity.AuthAction_ALLOW.String() { + return apisecurity.AuthAction_ALLOW + } + if s.Action == apisecurity.AuthAction_READ_WRITE.String() { + return apisecurity.AuthAction_ALLOW + } + return apisecurity.AuthAction_DENY +} + +func (s *StrategyDetail) FromSpec(req *apisecurity.AuthStrategy) { + s.ID = utils.NewUUID() + s.Name = req.Name.GetValue() + s.Action = req.GetAction().String() + s.Comment = req.Comment.GetValue() + s.Default = false + s.Owner = req.Owner.GetValue() + s.Valid = true + s.Source = req.GetSource().GetValue() + s.Revision = utils.NewUUID() + s.CreateTime = time.Now() + s.ModifyTime = time.Now() + s.CalleeMethods = req.GetFunctions() + s.Conditions = make([]Condition, 0, len(req.GetResourceLabels())) + for i := range req.GetResourceLabels() { + item := req.GetResourceLabels()[i] + s.Conditions = append(s.Conditions, Condition{ + Key: item.Key, + Value: item.Value, + CompareFunc: item.CompareType, + }) + } + +} + +func (s *StrategyDetail) IsMatchAction(a string) bool { + saveAction := s.Action + if isAllowAction(saveAction) { + saveAction = apisecurity.AuthAction_ALLOW.String() + } + if isAllowAction(a) { + a = apisecurity.AuthAction_ALLOW.String() + } + return saveAction == a +} + +func isAllowAction(s string) bool { + switch s { + case apisecurity.AuthAction_ALLOW.String(), apisecurity.AuthAction_READ_WRITE.String(): + return true + default: + return false + } +} + func (s *StrategyDetail) IsDeny() bool { return s.Action == apisecurity.AuthAction_DENY.String() } @@ -433,6 +490,8 @@ type ModifyStrategyDetail struct { Action string Comment string Metadata map[string]string + CalleeMethods []string + Conditions []Condition AddPrincipals []Principal RemovePrincipals []Principal AddResources []StrategyResource @@ -461,6 +520,31 @@ type StrategyResource struct { ResID string } +func (s StrategyResource) Key() string { + return strconv.Itoa(int(s.ResType)) + "/" + s.ResID +} + +func forUserPrincipal(id string) Principal { + return Principal{ + PrincipalID: id, + PrincipalType: PrincipalUser, + } +} + +func forUserGroupPrincipal(id string) Principal { + return Principal{ + PrincipalID: id, + PrincipalType: PrincipalGroup, + } +} + +func forRolePrincipal(id string) Principal { + return Principal{ + PrincipalID: id, + PrincipalType: PrincipalRole, + } +} + // Principal 策略相关人 type Principal struct { StrategyID string @@ -468,6 +552,16 @@ type Principal struct { Owner string PrincipalID string PrincipalType PrincipalType + Extend map[string]string +} + +func NewAnonymousPrincipal() Principal { + return Principal{ + Name: "__anonymous__", + PrincipalType: PrincipalUser, + PrincipalID: "__anonymous__", + Extend: map[string]string{}, + } } func (p Principal) String() string { @@ -495,6 +589,60 @@ type Role struct { Comment string CreateTime time.Time ModifyTime time.Time - Users []*User - UserGroups []*UserGroup + Users []Principal + UserGroups []Principal +} + +func (r *Role) FromSpec(d *apisecurity.Role) { + r.Name = d.Name + r.Owner = d.Owner + r.Source = d.Source + r.Metadata = d.Metadata + + if len(d.Users) != 0 { + users := make([]Principal, 0, len(d.Users)) + for i := range d.Users { + users = append(users, Principal{PrincipalID: d.Users[i].GetId().GetValue()}) + } + r.Users = users + } + + if len(d.UserGroups) != 0 { + groups := make([]Principal, 0, len(d.UserGroups)) + for i := range d.UserGroups { + groups = append(groups, Principal{PrincipalID: d.UserGroups[i].GetId().GetValue()}) + } + r.UserGroups = groups + } +} + +func (r *Role) ToSpec() *apisecurity.Role { + d := &apisecurity.Role{} + + d.Name = r.Name + d.Owner = r.Owner + d.Source = r.Source + d.Metadata = r.Metadata + + if len(r.Users) != 0 { + users := make([]*apisecurity.User, 0, len(r.Users)) + for i := range r.Users { + users = append(users, &apisecurity.User{ + Id: utils.NewStringValue(r.Users[i].PrincipalID), + }) + } + d.Users = users + } + + if len(d.UserGroups) != 0 { + groups := make([]*apisecurity.UserGroup, 0, len(d.UserGroups)) + for i := range r.UserGroups { + groups = append(groups, &apisecurity.UserGroup{ + Id: utils.NewStringValue(r.UserGroups[i].PrincipalID), + }) + } + d.UserGroups = groups + } + + return d } diff --git a/common/model/auth/const.go b/common/model/auth/const.go index 88ecf139d..d58a621fb 100644 --- a/common/model/auth/const.go +++ b/common/model/auth/const.go @@ -52,7 +52,6 @@ const ( const ( CreateNamespace ServerFunctionName = "CreateNamespace" CreateNamespaces ServerFunctionName = "CreateNamespaces" - DeleteNamespace ServerFunctionName = "DeleteNamespace" DeleteNamespaces ServerFunctionName = "DeleteNamespaces" UpdateNamespaces ServerFunctionName = "UpdateNamespaces" UpdateNamespaceToken ServerFunctionName = "UpdateNamespaceToken" @@ -179,7 +178,13 @@ const ( ) // 全链路灰度 -const () +const ( + CreateLaneGroups ServerFunctionName = "CreateLaneGroups" + DeleteLaneGroups ServerFunctionName = "DeleteLaneGroups" + EnableLaneGroups ServerFunctionName = "EnableLaneGroups" + UpdateLaneGroups ServerFunctionName = "UpdateLaneGroups" + DescribeLaneGroups ServerFunctionName = "DescribeLaneGroups" +) // 用户/用户组 const ( @@ -235,6 +240,226 @@ const ( DescribeCMDBInfo ServerFunctionName = "DescribeCMDBInfo" ) +type ServerFunctionGroup struct { + Name string `json:"name"` + Functions []ServerFunctionName `json:"functions"` +} + +var ServerFunctions = []ServerFunctionGroup{ + { + Name: "Client", + Functions: []ServerFunctionName{ + RegisterInstance, + DeregisterInstance, + ReportServiceContract, + DiscoverServices, + DiscoverInstances, + UpdateInstance, + DiscoverRouterRule, + DiscoverRateLimitRule, + DiscoverCircuitBreakerRule, + DiscoverFaultDetectRule, + DiscoverServiceContract, + DiscoverLaneRule, + DiscoverConfigFile, + WatchConfigFile, + DiscoverConfigFileNames, + DiscoverConfigGroups, + }, + }, + { + Name: "Namespace", + Functions: []ServerFunctionName{ + CreateNamespace, + CreateNamespaces, + DeleteNamespaces, + UpdateNamespaces, + // UpdateNamespaceToken, + DescribeNamespaces, + // DescribeNamespaceToken, + }, + }, + { + Name: "Service", + Functions: []ServerFunctionName{ + CreateServices, + DeleteServices, + UpdateServices, + // UpdateServiceToken, + DescribeAllServices, + DescribeServices, + DescribeServicesCount, + // DescribeServiceToken, + // DescribeServiceOwner, + CreateServiceAlias, + DeleteServiceAliases, + UpdateServiceAlias, + DescribeServiceAliases, + }, + }, + { + Name: "ServiceContract", + Functions: []ServerFunctionName{ + CreateServiceContracts, + DescribeServiceContracts, + DescribeServiceContractVersions, + DeleteServiceContracts, + CreateServiceContractInterfaces, + AppendServiceContractInterfaces, + DeleteServiceContractInterfaces, + }, + }, + { + Name: "Instance", + Functions: []ServerFunctionName{ + CreateInstances, + DeleteInstances, + DeleteInstancesByHost, + UpdateInstances, + UpdateInstancesIsolate, + DescribeInstances, + DescribeInstancesCount, + DescribeInstanceLabels, + CleanInstance, + BatchCleanInstances, + DescribeInstanceLastHeartbeat, + }, + }, + { + Name: "RouteRule", + Functions: []ServerFunctionName{ + CreateRouteRules, + DeleteRouteRules, + UpdateRouteRules, + EnableRouteRules, + DescribeRouteRules, + }, + }, + { + Name: "RateLimitRule", + Functions: []ServerFunctionName{ + CreateRateLimitRules, + DeleteRateLimitRules, + UpdateRateLimitRules, + EnableRateLimitRules, + DescribeRateLimitRules, + }, + }, + { + Name: "CircuitBreakerRule", + Functions: []ServerFunctionName{ + CreateCircuitBreakerRules, + DeleteCircuitBreakerRules, + EnableCircuitBreakerRules, + UpdateCircuitBreakerRules, + DescribeCircuitBreakerRules, + }, + }, + { + Name: "FaultDetectRule", + Functions: []ServerFunctionName{ + CreateFaultDetectRules, + DeleteFaultDetectRules, + EnableFaultDetectRules, + UpdateFaultDetectRules, + DescribeFaultDetectRules, + }, + }, + { + Name: "LaneRule", + Functions: []ServerFunctionName{ + CreateLaneGroups, + DeleteLaneGroups, + EnableLaneGroups, + UpdateLaneGroups, + DescribeLaneGroups, + }, + }, + { + Name: "ConfigGroup", + Functions: []ServerFunctionName{ + CreateConfigFileGroup, + DeleteConfigFileGroup, + UpdateConfigFileGroup, + DescribeConfigFileGroups, + }, + }, + { + Name: "ConfigFile", + Functions: []ServerFunctionName{ + PublishConfigFile, + CreateConfigFile, + UpdateConfigFile, + DeleteConfigFile, + DescribeConfigFileRichInfo, + DescribeConfigFiles, + BatchDeleteConfigFiles, + ExportConfigFiles, + ImportConfigFiles, + DescribeConfigFileReleaseHistories, + DescribeAllConfigFileTemplates, + DescribeConfigFileTemplate, + CreateConfigFileTemplate, + }, + }, + { + Name: "ConfigRelease", + Functions: []ServerFunctionName{ + RollbackConfigFileReleases, + DeleteConfigFileReleases, + StopGrayConfigFileReleases, + DescribeConfigFileRelease, + DescribeConfigFileReleases, + DescribeConfigFileReleaseVersions, + UpsertAndReleaseConfigFile, + }, + }, + { + Name: "User", + Functions: []ServerFunctionName{ + CreateUsers, + DeleteUsers, + DescribeUsers, + DescribeUserToken, + EnableUserToken, + ResetUserToken, + UpdateUser, + UpdateUserPassword, + }, + }, + { + Name: "UserGroup", + Functions: []ServerFunctionName{ + CreateUserGroup, + UpdateUserGroups, + DeleteUserGroups, + DescribeUserGroups, + DescribeUserGroupDetail, + DescribeUserGroupToken, + EnableUserGroupToken, + ResetUserGroupToken, + }, + }, + { + Name: "AuthPolicy", + Functions: []ServerFunctionName{ + CreateAuthPolicy, + UpdateAuthPolicies, + DeleteAuthPolicies, + DescribeAuthPolicies, + DescribeAuthPolicyDetail, + DescribePrincipalResources, + }, + }, + // "AuthRole": { + // CreateAuthRoles, + // UpdateAuthRoles, + // DeleteAuthRoles, + // DescribeAuthRoles, + // DescribeAuthRoleDetail, + // }, +} + var ( SearchTypeMapping = map[string]apisecurity.ResourceType{ "0": apisecurity.ResourceType_Namespaces, diff --git a/common/model/auth/container.go b/common/model/auth/container.go index 4833b3108..cea3d4f52 100644 --- a/common/model/auth/container.go +++ b/common/model/auth/container.go @@ -18,21 +18,22 @@ package auth import ( - "github.com/polarismesh/polaris/common/utils" apisecurity "github.com/polarismesh/specification/source/go/api/v1/security" + + "github.com/polarismesh/polaris/common/utils" ) // PrincipalResourceContainer principal 资源容器 type PrincipalResourceContainer struct { - denyResources *utils.SyncMap[apisecurity.ResourceType, *utils.RefSyncSet[string]] - allowResources *utils.SyncMap[apisecurity.ResourceType, *utils.RefSyncSet[string]] + denyResources *utils.SyncMap[apisecurity.ResourceType, *utils.RefSyncSet[string, string]] + allowResources *utils.SyncMap[apisecurity.ResourceType, *utils.RefSyncSet[string, string]] } // NewPrincipalResourceContainer 创建 PrincipalResourceContainer 对象 func NewPrincipalResourceContainer() *PrincipalResourceContainer { return &PrincipalResourceContainer{ - allowResources: utils.NewSyncMap[apisecurity.ResourceType, *utils.RefSyncSet[string]](), - denyResources: utils.NewSyncMap[apisecurity.ResourceType, *utils.RefSyncSet[string]](), + allowResources: utils.NewSyncMap[apisecurity.ResourceType, *utils.RefSyncSet[string, string]](), + denyResources: utils.NewSyncMap[apisecurity.ResourceType, *utils.RefSyncSet[string, string]](), } } @@ -40,12 +41,18 @@ func NewPrincipalResourceContainer() *PrincipalResourceContainer { func (p *PrincipalResourceContainer) Hint(rt apisecurity.ResourceType, resId string) (apisecurity.AuthAction, bool) { ids, ok := p.denyResources.Load(rt) if ok { + if ids.Contains(utils.MatchAll) { + return apisecurity.AuthAction_DENY, true + } if ids.Contains(resId) { return apisecurity.AuthAction_DENY, true } } ids, ok = p.allowResources.Load(rt) if ok { + if ids.Contains(utils.MatchAll) { + return apisecurity.AuthAction_ALLOW, true + } if ids.Contains(resId) { return apisecurity.AuthAction_ALLOW, true } @@ -53,46 +60,50 @@ func (p *PrincipalResourceContainer) Hint(rt apisecurity.ResourceType, resId str return 0, false } -// SaveAllowResource 保存允许的资源 -func (p *PrincipalResourceContainer) SaveAllowResource(r StrategyResource) { - p.saveResource(p.allowResources, r) -} - -// DelAllowResource 删除允许的资源 -func (p *PrincipalResourceContainer) DelAllowResource(r StrategyResource) { - p.delResource(p.allowResources, r) -} - -// SaveDenyResource 保存拒绝的资源 -func (p *PrincipalResourceContainer) SaveDenyResource(r StrategyResource) { - p.saveResource(p.denyResources, r) +// SaveResource 保存资源 +func (p *PrincipalResourceContainer) SaveResource(a apisecurity.AuthAction, r StrategyResource) { + if a == apisecurity.AuthAction_ALLOW { + p.saveResource(p.allowResources, r) + } else { + p.saveResource(p.denyResources, r) + } } -// DelDenyResource 删除拒绝的资源 -func (p *PrincipalResourceContainer) DelDenyResource(r StrategyResource) { - p.delResource(p.denyResources, r) +// DelResource 删除资源 +func (p *PrincipalResourceContainer) DelResource(a apisecurity.AuthAction, r StrategyResource) { + if a == apisecurity.AuthAction_ALLOW { + p.delResource(p.allowResources, r) + } else { + p.delResource(p.denyResources, r) + } } func (p *PrincipalResourceContainer) saveResource( - container *utils.SyncMap[apisecurity.ResourceType, *utils.RefSyncSet[string]], res StrategyResource) { + container *utils.SyncMap[apisecurity.ResourceType, *utils.RefSyncSet[string, string]], res StrategyResource) { resType := apisecurity.ResourceType(res.ResType) - container.ComputeIfAbsent(resType, func(k apisecurity.ResourceType) *utils.RefSyncSet[string] { - return utils.NewRefSyncSet[string]() + container.ComputeIfAbsent(resType, func(k apisecurity.ResourceType) *utils.RefSyncSet[string, string] { + return utils.NewRefSyncSet[string, string]() }) ids, _ := container.Load(resType) - ids.Add(res.ResID) + ids.Add(utils.Reference[string, string]{ + Key: res.ResID, + Referencer: res.StrategyID, + }) } func (p *PrincipalResourceContainer) delResource( - container *utils.SyncMap[apisecurity.ResourceType, *utils.RefSyncSet[string]], r StrategyResource) { + container *utils.SyncMap[apisecurity.ResourceType, *utils.RefSyncSet[string, string]], res StrategyResource) { - resType := apisecurity.ResourceType(r.ResType) - container.ComputeIfAbsent(resType, func(k apisecurity.ResourceType) *utils.RefSyncSet[string] { - return utils.NewRefSyncSet[string]() + resType := apisecurity.ResourceType(res.ResType) + container.ComputeIfAbsent(resType, func(k apisecurity.ResourceType) *utils.RefSyncSet[string, string] { + return utils.NewRefSyncSet[string, string]() }) ids, _ := container.Load(resType) - ids.Remove(r.ResID) + ids.Remove(utils.Reference[string, string]{ + Key: res.ResID, + Referencer: res.StrategyID, + }) } diff --git a/common/model/auth/context.go b/common/model/auth/context.go new file mode 100644 index 000000000..d4ccbae19 --- /dev/null +++ b/common/model/auth/context.go @@ -0,0 +1,22 @@ +/** + * Tencent is pleased to support the open source community by making Polaris available. + * + * Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. + * + * Licensed under the BSD 3-Clause License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://opensource.org/licenses/BSD-3-Clause + * + * Unless required by applicable law or agreed to in writing, software distributed + * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR + * CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package auth + +type ( + ContextKeyConditions struct{} +) diff --git a/common/model/auth/funcs.go b/common/model/auth/funcs.go new file mode 100644 index 000000000..fb458bfd8 --- /dev/null +++ b/common/model/auth/funcs.go @@ -0,0 +1,66 @@ +/** + * Tencent is pleased to support the open source community by making Polaris available. + * + * Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. + * + * Licensed under the BSD 3-Clause License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://opensource.org/licenses/BSD-3-Clause + * + * Unless required by applicable law or agreed to in writing, software distributed + * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR + * CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package auth + +import "strings" + +/* +* +string_equal +string_not_equal +string_equal_ignore_case +string_not_equal_ignore_case +string_like +string_not_like +date_equal +date_not_equal +date_greater_than +date_greater_than_equal +date_less_than +date_less_than_equal +ip_equal +ip_not_equal +*/ +var ( + ConditionCompareDict = map[string]func(string, string) bool{ + // string_equal + "string_equal": func(s1, s2 string) bool { + return s1 == s2 + }, + "for_any_value:string_equal": func(s1, s2 string) bool { + return s1 == s2 + }, + // string_not_equal + "string_not_equal": func(s1, s2 string) bool { + return s1 != s2 + }, + "for_any_value:string_not_equal": func(s1, s2 string) bool { + return s1 != s2 + }, + // string_equal_ignore_case + "string_equal_ignore_case": strings.EqualFold, + "for_any_value:string_equal_ignore_case": strings.EqualFold, + // string_not_equal_ignore_case + "string_not_equal_ignore_case": func(s1, s2 string) bool { + return !strings.EqualFold(s1, s2) + }, + "for_any_value:string_not_equal_ignore_case": func(s1, s2 string) bool { + return !strings.EqualFold(s1, s2) + }, + } +) diff --git a/common/model/config_file.go b/common/model/config_file.go index f894901a0..af9ce1347 100644 --- a/common/model/config_file.go +++ b/common/model/config_file.go @@ -492,6 +492,8 @@ func ToConfigGroupAPI(group *ConfigFileGroup) *config_manage.ConfigFileGroup { Business: utils.NewStringValue(group.Business), Department: utils.NewStringValue(group.Department), Metadata: group.Metadata, + Editable: utils.NewBoolValue(true), + Deleteable: utils.NewBoolValue(true), } } diff --git a/common/model/context_key.go b/common/model/context_key.go index 9c8377e7d..2c7cdcc16 100644 --- a/common/model/context_key.go +++ b/common/model/context_key.go @@ -22,4 +22,6 @@ type ( ContextKeyAutoCreateNamespace struct{} // ContextKeyAutoCreateService . ContextKeyAutoCreateService struct{} + // ContextKeyCompatible . + ContextKeyCompatible struct{} ) diff --git a/common/model/lane.go b/common/model/lane.go index db5539a7f..9d79320d6 100644 --- a/common/model/lane.go +++ b/common/model/lane.go @@ -49,6 +49,7 @@ type LaneGroup struct { Revision string Description string Valid bool + Labels map[string]string CreateTime time.Time ModifyTime time.Time // LaneRules id -> *LaneRule diff --git a/common/model/naming.go b/common/model/naming.go index aa3af6ab9..9394ca3f9 100644 --- a/common/model/naming.go +++ b/common/model/naming.go @@ -521,6 +521,7 @@ type FaultDetectRule struct { DstMethod string Rule string Revision string + Metadata map[string]string Valid bool CreateTime time.Time ModifyTime time.Time diff --git a/common/model/operation.go b/common/model/operation.go index c8b520ed2..5c3e2c0e6 100644 --- a/common/model/operation.go +++ b/common/model/operation.go @@ -62,12 +62,15 @@ const ( RUserGroup Resource = "UserGroup" RUserGroupRelation Resource = "UserGroupRelation" RAuthStrategy Resource = "AuthStrategy" + RAuthRole Resource = "Role" RConfigGroup Resource = "ConfigGroup" RConfigFile Resource = "ConfigFile" RConfigFileRelease Resource = "ConfigFileRelease" RCircuitBreakerRule Resource = "CircuitBreakerRule" RFaultDetectRule Resource = "FaultDetectRule" RServiceContract Resource = "ServiceContract" + RLaneGroup Resource = "LaneGroup" + RLaneRule Resource = "LaneRule" ) // RecordEntry Operation records diff --git a/common/model/ratelimit.go b/common/model/ratelimit.go index a344fe178..7de236d5e 100644 --- a/common/model/ratelimit.go +++ b/common/model/ratelimit.go @@ -42,6 +42,26 @@ type RateLimit struct { CreateTime time.Time ModifyTime time.Time EnableTime time.Time + Metadata map[string]string +} + +func (r *RateLimit) CopyNoProto() *RateLimit { + return &RateLimit{ + ID: r.ID, + ServiceID: r.ServiceID, + Name: r.Name, + Method: r.Method, + Labels: r.Labels, + Proto: r.Proto, + Priority: r.Priority, + Rule: r.Rule, + Revision: r.Revision, + Disable: r.Disable, + Valid: r.Valid, + CreateTime: r.CreateTime, + ModifyTime: r.ModifyTime, + EnableTime: r.EnableTime, + } } // Labels2Arguments 适配老的标签到新的参数列表 diff --git a/common/model/routing.go b/common/model/routing.go index 85a3b9ff8..290036b87 100644 --- a/common/model/routing.go +++ b/common/model/routing.go @@ -44,6 +44,8 @@ const ( const ( // V2RuleIDKey v2 版本的规则路由 ID V2RuleIDKey = "__routing_v2_id__" + // V2RuleIDPriority v2 版本的规则路由优先级 + V2RuleIDPriority = "__routing_v2_priority__" // V1RuleIDKey v1 版本的路由规则 ID V1RuleIDKey = "__routing_v1_id__" // V1RuleRouteIndexKey v1 版本 route 规则在自己 route 链中的 index 信息 @@ -116,7 +118,7 @@ func (r *ExtendRouterConfig) ToApi() (*apitraffic.RouteRule, error) { ) switch r.GetRoutingPolicy() { - case apitraffic.RoutingPolicy_RulePolicy: + case apitraffic.RoutingPolicy_NearbyPolicy: anyValue, err = ptypes.MarshalAny(r.NearbyRouting) if err != nil { return nil, err @@ -146,6 +148,8 @@ func (r *ExtendRouterConfig) ToApi() (*apitraffic.RouteRule, error) { Etime: commontime.Time2String(r.EnableTime), Priority: r.Priority, Description: r.Description, + Editable: true, + Deleteable: true, } if r.EnableTime.Year() > 2000 { rule.Etime = commontime.Time2String(r.EnableTime) @@ -190,6 +194,8 @@ type RouterConfig struct { ModifyTime time.Time `json:"mtime"` // enabletime The last time the rules enabled EnableTime time.Time `json:"etime"` + // Metadata. + Metadata map[string]string `json:"metadata"` } // GetRoutingPolicy Query routing rules type @@ -305,6 +311,7 @@ func (r *RouterConfig) ParseRouteRuleFromAPI(routing *apitraffic.RouteRule) erro r.Policy = routing.GetRoutingPolicy().String() r.Priority = routing.Priority r.Description = routing.Description + r.Metadata = routing.Metadata // Priority range range [0, 10] if r.Priority > 10 { @@ -378,20 +385,21 @@ func parseSubRouteRule(ruleRouting *apitraffic.RuleRoutingConfig) *RuleRoutingCo for i := range ruleRouting.Rules { item := ruleRouting.Rules[i] - source := item.Sources[0] - destination := item.Destinations[0] - - wrapper.Caller = ServiceKey{ - Namespace: source.Namespace, - Name: source.Service, + if len(item.Sources) != 0 { + source := item.Sources[0] + wrapper.Caller = ServiceKey{ + Namespace: source.Namespace, + Name: source.Service, + } } - wrapper.Callee = ServiceKey{ - Namespace: destination.Namespace, - Name: destination.Service, + if len(item.Destinations) != 0 { + destination := item.Destinations[0] + wrapper.Callee = ServiceKey{ + Namespace: destination.Namespace, + Name: destination.Service, + } } - break } - return wrapper } @@ -603,7 +611,15 @@ func CompareRoutingV2(a, b *ExtendRouterConfig) bool { if a.Priority != b.Priority { return a.Priority < b.Priority } - return a.CreateTime.Before(b.CreateTime) + // 如果优先级相同,则比较规则 ID + return a.ID < b.ID +} + +// CompareRoutingV1 Compare the priority of two routing. +func CompareRoutingV1(a, b *apitraffic.Route) bool { + ap := a.ExtendInfo[V2RuleIDPriority] + bp := b.ExtendInfo[V2RuleIDPriority] + return ap < bp } // ConvertRoutingV1ToExtendV2 The routing rules of the V1 version are converted to V2 version for storage @@ -755,7 +771,8 @@ func BuildInBoundsRoute(item *ExtendRouterConfig) []*apitraffic.Route { Sources: v1sources, Destinations: v1destinations, ExtendInfo: map[string]string{ - V2RuleIDKey: item.ID, + V2RuleIDKey: item.ID, + V2RuleIDPriority: fmt.Sprintf("%04d", item.Priority), }, }) } diff --git a/common/utils/atomic.go b/common/utils/atomic.go index 6cc8eea62..8ab0b15eb 100644 --- a/common/utils/atomic.go +++ b/common/utils/atomic.go @@ -21,6 +21,7 @@ import "sync/atomic" func NewAtomicValue[V any](v V) *AtomicValue[V] { a := new(AtomicValue[V]) + a.a = atomic.Value{} a.Store(v) return a } @@ -29,6 +30,14 @@ type AtomicValue[V any] struct { a atomic.Value } +func (a *AtomicValue[V]) HasValue() bool { + if a == nil { + return false + } + v := a.a.Load() + return v != nil +} + func (a *AtomicValue[V]) Store(val V) { a.a.Store(val) } diff --git a/common/utils/collection.go b/common/utils/collection.go index 534edec67..7d85de169 100644 --- a/common/utils/collection.go +++ b/common/utils/collection.go @@ -57,46 +57,52 @@ func (set *Set[K]) Range(fn func(val K)) { } } +type Reference[K, R comparable] struct { + Key K + Referencer R +} + // NewRefSyncSet returns a new Set -func NewRefSyncSet[K comparable]() *RefSyncSet[K] { - return &RefSyncSet[K]{ - container: make(map[K]int), +func NewRefSyncSet[K, R comparable]() *RefSyncSet[K, R] { + return &RefSyncSet[K, R]{ + container: map[K]map[R]struct{}{}, } } -type RefSyncSet[K comparable] struct { - container map[K]int +type RefSyncSet[K, R comparable] struct { + container map[K]map[R]struct{} lock sync.RWMutex } // Add adds a string to the set -func (set *RefSyncSet[K]) Add(val K) { +func (set *RefSyncSet[K, R]) Add(val Reference[K, R]) { set.lock.Lock() defer set.lock.Unlock() - ref, ok := set.container[val] - if ok { - ref++ + if _, ok := set.container[val.Key]; !ok { + set.container[val.Key] = map[R]struct{}{} } - set.container[val] = ref + refs := set.container[val.Key] + refs[val.Referencer] = struct{}{} } // Remove removes a string from the set -func (set *RefSyncSet[K]) Remove(val K) { +func (set *RefSyncSet[K, R]) Remove(val Reference[K, R]) { set.lock.Lock() defer set.lock.Unlock() - ref, ok := set.container[val] - if ok { - ref-- + if _, ok := set.container[val.Key]; !ok { + return } - if ref == 0 { - delete(set.container, val) + refs := set.container[val.Key] + delete(refs, val.Referencer) + if len(refs) == 0 { + delete(set.container, val.Key) } else { - set.container[val] = ref + set.container[val.Key] = refs } } -func (set *RefSyncSet[K]) ToSlice() []K { +func (set *RefSyncSet[K, R]) ToSlice() []K { set.lock.RLock() defer set.lock.RUnlock() @@ -107,7 +113,7 @@ func (set *RefSyncSet[K]) ToSlice() []K { return ret } -func (set *RefSyncSet[K]) Range(fn func(val K)) { +func (set *RefSyncSet[K, R]) Range(fn func(val K)) { set.lock.RLock() snapshot := map[K]struct{}{} for k := range set.container { @@ -120,7 +126,7 @@ func (set *RefSyncSet[K]) Range(fn func(val K)) { } } -func (set *RefSyncSet[K]) Len() int { +func (set *RefSyncSet[K, R]) Len() int { set.lock.RLock() defer set.lock.RUnlock() @@ -128,7 +134,7 @@ func (set *RefSyncSet[K]) Len() int { } // Contains contains target value -func (set *RefSyncSet[K]) Contains(val K) bool { +func (set *RefSyncSet[K, R]) Contains(val K) bool { set.lock.Lock() defer set.lock.Unlock() @@ -136,7 +142,7 @@ func (set *RefSyncSet[K]) Contains(val K) bool { return exist } -func (set *RefSyncSet[K]) String() string { +func (set *RefSyncSet[K, R]) String() string { ret := set.ToSlice() return MustJson(ret) } diff --git a/common/utils/common.go b/common/utils/common.go index 49595eb1b..f08c1f852 100644 --- a/common/utils/common.go +++ b/common/utils/common.go @@ -87,6 +87,9 @@ const ( MaxDbCircuitbreakerOwner = 1024 MaxDbCircuitbreakerVersion = 32 + // ratelimit表 + MaxDbRateLimitName = MaxRuleName + MaxRuleName = 64 MaxPlatformIDLength = 32 diff --git a/common/utils/const.go b/common/utils/const.go index 0320376d1..8e01a8994 100644 --- a/common/utils/const.go +++ b/common/utils/const.go @@ -73,6 +73,6 @@ const ( ContextIsFromSystem = StringContext("from-system") // ContextOperator operator info ContextOperator = StringContext("operator") - // ContextRequestHeaders request headers + // ContextRequestHeaders request headers, save value type is map[string][]string ContextRequestHeaders = StringContext("request-headers") ) diff --git a/config/client_test.go b/config/client_test.go index c95c99b48..33a51472e 100644 --- a/config/client_test.go +++ b/config/client_test.go @@ -792,8 +792,10 @@ func TestServer_GetConfigGroupsWithCache(t *testing.T) { } t.Cleanup(func() { for k := range mockFiles { - testSuit.NamespaceServer().DeleteNamespace(testSuit.DefaultCtx, &apimodel.Namespace{ - Name: wrapperspb.String(k), + testSuit.NamespaceServer().DeleteNamespaces(testSuit.DefaultCtx, []*apimodel.Namespace{ + &apimodel.Namespace{ + Name: wrapperspb.String(k), + }, }) items := mockFiles[k] for _, item := range items { diff --git a/config/config_file_group.go b/config/config_file_group.go index e81f57aae..e7b82c9c0 100644 --- a/config/config_file_group.go +++ b/config/config_file_group.go @@ -256,6 +256,12 @@ func (s *Server) QueryConfigFileGroups(ctx context.Context, log.Error("[Config][Service] get config file count for group error.", utils.RequestID(ctx), utils.ZapNamespace(ret[i].Namespace), utils.ZapGroup(ret[i].Name), zap.Error(err)) } + + // 如果包含特殊标签,也不允许修改 + if _, ok := item.GetMetadata()[model.MetaKey3RdPlatform]; ok { + item.Editable = utils.NewBoolValue(false) + } + item.FileCount = wrapperspb.UInt64(fileCount) values = append(values, item) } diff --git a/config/interceptor/auth/client.go b/config/interceptor/auth/client.go index 892b19e1c..06a7c6d50 100644 --- a/config/interceptor/auth/client.go +++ b/config/interceptor/auth/client.go @@ -29,11 +29,11 @@ import ( ) // UpsertAndReleaseConfigFileFromClient 创建/更新配置文件并发布 -func (s *ServerAuthability) UpsertAndReleaseConfigFileFromClient(ctx context.Context, +func (s *Server) UpsertAndReleaseConfigFileFromClient(ctx context.Context, req *apiconfig.ConfigFilePublishInfo) *apiconfig.ConfigResponse { authCtx := s.collectConfigFilePublishAuthContext(ctx, []*apiconfig.ConfigFilePublishInfo{req}, auth.Modify, auth.PublishConfigFile) - if _, err := s.policyMgr.GetAuthChecker().CheckClientPermission(authCtx); err != nil { + if _, err := s.policySvr.GetAuthChecker().CheckClientPermission(authCtx); err != nil { return api.NewConfigFileResponse(auth.ConvertToErrCode(err), nil) } @@ -44,7 +44,7 @@ func (s *ServerAuthability) UpsertAndReleaseConfigFileFromClient(ctx context.Con } // CreateConfigFileFromClient 调用config_file的方法创建配置文件 -func (s *ServerAuthability) CreateConfigFileFromClient(ctx context.Context, +func (s *Server) CreateConfigFileFromClient(ctx context.Context, fileInfo *apiconfig.ConfigFile) *apiconfig.ConfigClientResponse { authCtx := s.collectClientConfigFileAuthContext(ctx, []*apiconfig.ConfigFile{{ @@ -52,8 +52,8 @@ func (s *ServerAuthability) CreateConfigFileFromClient(ctx context.Context, Name: fileInfo.Name, Group: fileInfo.Group}, }, auth.Create, auth.CreateConfigFile) - if _, err := s.policyMgr.GetAuthChecker().CheckClientPermission(authCtx); err != nil { - return api.NewConfigClientResponseWithInfo(auth.ConvertToErrCode(err), err.Error()) + if _, err := s.policySvr.GetAuthChecker().CheckClientPermission(authCtx); err != nil { + return api.NewConfigClientResponse(auth.ConvertToErrCode(err), nil) } ctx = authCtx.GetRequestContext() @@ -63,12 +63,12 @@ func (s *ServerAuthability) CreateConfigFileFromClient(ctx context.Context, } // UpdateConfigFileFromClient 调用config_file的方法更新配置文件 -func (s *ServerAuthability) UpdateConfigFileFromClient(ctx context.Context, +func (s *Server) UpdateConfigFileFromClient(ctx context.Context, fileInfo *apiconfig.ConfigFile) *apiconfig.ConfigClientResponse { authCtx := s.collectClientConfigFileAuthContext(ctx, []*apiconfig.ConfigFile{fileInfo}, auth.Modify, auth.UpdateConfigFile) - if _, err := s.policyMgr.GetAuthChecker().CheckClientPermission(authCtx); err != nil { - return api.NewConfigClientResponseWithInfo(auth.ConvertToErrCode(err), err.Error()) + if _, err := s.policySvr.GetAuthChecker().CheckClientPermission(authCtx); err != nil { + return api.NewConfigClientResponse(auth.ConvertToErrCode(err), nil) } ctx = authCtx.GetRequestContext() @@ -78,13 +78,13 @@ func (s *ServerAuthability) UpdateConfigFileFromClient(ctx context.Context, } // DeleteConfigFileFromClient 删除配置文件,删除配置文件同时会通知客户端 Not_Found -func (s *ServerAuthability) DeleteConfigFileFromClient(ctx context.Context, +func (s *Server) DeleteConfigFileFromClient(ctx context.Context, req *apiconfig.ConfigFile) *apiconfig.ConfigResponse { authCtx := s.collectConfigFileAuthContext(ctx, []*apiconfig.ConfigFile{req}, auth.Delete, auth.DeleteConfigFile) - if _, err := s.policyMgr.GetAuthChecker().CheckClientPermission(authCtx); err != nil { - return api.NewConfigResponseWithInfo(auth.ConvertToErrCode(err), err.Error()) + if _, err := s.policySvr.GetAuthChecker().CheckClientPermission(authCtx); err != nil { + return api.NewConfigResponse(auth.ConvertToErrCode(err)) } ctx = authCtx.GetRequestContext() @@ -94,16 +94,16 @@ func (s *ServerAuthability) DeleteConfigFileFromClient(ctx context.Context, } // PublishConfigFileFromClient 调用config_file_release的方法发布配置文件 -func (s *ServerAuthability) PublishConfigFileFromClient(ctx context.Context, +func (s *Server) PublishConfigFileFromClient(ctx context.Context, fileInfo *apiconfig.ConfigFileRelease) *apiconfig.ConfigClientResponse { - authCtx := s.collectClientConfigFileReleaseAuthContext(ctx, + authCtx := s.collectClientConfigFileRelease(ctx, []*apiconfig.ConfigFileRelease{{ Namespace: fileInfo.Namespace, Name: fileInfo.FileName, Group: fileInfo.Group}, }, auth.Create, auth.PublishConfigFile) - if _, err := s.policyMgr.GetAuthChecker().CheckClientPermission(authCtx); err != nil { - return api.NewConfigClientResponseWithInfo(auth.ConvertToErrCode(err), err.Error()) + if _, err := s.policySvr.GetAuthChecker().CheckClientPermission(authCtx); err != nil { + return api.NewConfigClientResponse(auth.ConvertToErrCode(err), nil) } ctx = authCtx.GetRequestContext() @@ -113,7 +113,7 @@ func (s *ServerAuthability) PublishConfigFileFromClient(ctx context.Context, } // GetConfigFileWithCache 从缓存中获取配置文件,如果客户端的版本号大于服务端,则服务端重新加载缓存 -func (s *ServerAuthability) GetConfigFileWithCache(ctx context.Context, +func (s *Server) GetConfigFileWithCache(ctx context.Context, fileInfo *apiconfig.ClientConfigFileInfo) *apiconfig.ConfigClientResponse { authCtx := s.collectClientConfigFileAuthContext(ctx, []*apiconfig.ConfigFile{{ @@ -121,8 +121,8 @@ func (s *ServerAuthability) GetConfigFileWithCache(ctx context.Context, Name: fileInfo.FileName, Group: fileInfo.Group}, }, auth.Read, auth.DiscoverConfigFile) - if _, err := s.policyMgr.GetAuthChecker().CheckClientPermission(authCtx); err != nil { - return api.NewConfigClientResponseWithInfo(auth.ConvertToErrCode(err), err.Error()) + if _, err := s.policySvr.GetAuthChecker().CheckClientPermission(authCtx); err != nil { + return api.NewConfigClientResponse(auth.ConvertToErrCode(err), nil) } ctx = authCtx.GetRequestContext() @@ -131,12 +131,12 @@ func (s *ServerAuthability) GetConfigFileWithCache(ctx context.Context, } // WatchConfigFiles 监听配置文件变化 -func (s *ServerAuthability) LongPullWatchFile(ctx context.Context, +func (s *Server) LongPullWatchFile(ctx context.Context, request *apiconfig.ClientWatchConfigFileRequest) (config.WatchCallback, error) { authCtx := s.collectClientWatchConfigFiles(ctx, request, auth.Read, auth.WatchConfigFile) - if _, err := s.policyMgr.GetAuthChecker().CheckClientPermission(authCtx); err != nil { + if _, err := s.policySvr.GetAuthChecker().CheckClientPermission(authCtx); err != nil { return func() *apiconfig.ConfigClientResponse { - return api.NewConfigClientResponseWithInfo(auth.ConvertToErrCode(err), err.Error()) + return api.NewConfigClientResponse(auth.ConvertToErrCode(err), nil) }, nil } @@ -147,16 +147,16 @@ func (s *ServerAuthability) LongPullWatchFile(ctx context.Context, } // GetConfigFileNamesWithCache 获取某个配置分组下的配置文件 -func (s *ServerAuthability) GetConfigFileNamesWithCache(ctx context.Context, +func (s *Server) GetConfigFileNamesWithCache(ctx context.Context, req *apiconfig.ConfigFileGroupRequest) *apiconfig.ConfigClientListResponse { - authCtx := s.collectClientConfigFileReleaseAuthContext(ctx, []*apiconfig.ConfigFileRelease{ + authCtx := s.collectClientConfigFileRelease(ctx, []*apiconfig.ConfigFileRelease{ { Namespace: req.GetConfigFileGroup().GetNamespace(), Group: req.GetConfigFileGroup().GetName(), }, }, auth.Read, auth.DiscoverConfigFileNames) - if _, err := s.policyMgr.GetAuthChecker().CheckClientPermission(authCtx); err != nil { + if _, err := s.policySvr.GetAuthChecker().CheckClientPermission(authCtx); err != nil { out := api.NewConfigClientListResponse(auth.ConvertToErrCode(err)) return out } @@ -167,15 +167,15 @@ func (s *ServerAuthability) GetConfigFileNamesWithCache(ctx context.Context, } // GetConfigGroupsWithCache 获取某个命名空间下的配置分组列表 -func (s *ServerAuthability) GetConfigGroupsWithCache(ctx context.Context, +func (s *Server) GetConfigGroupsWithCache(ctx context.Context, req *apiconfig.ClientConfigFileInfo) *apiconfig.ConfigDiscoverResponse { - authCtx := s.collectClientConfigFileReleaseAuthContext(ctx, []*apiconfig.ConfigFileRelease{ + authCtx := s.collectClientConfigFileRelease(ctx, []*apiconfig.ConfigFileRelease{ { Namespace: req.GetNamespace(), }, }, auth.Read, auth.DiscoverConfigGroups) - if _, err := s.policyMgr.GetAuthChecker().CheckClientPermission(authCtx); err != nil { + if _, err := s.policySvr.GetAuthChecker().CheckClientPermission(authCtx); err != nil { out := api.NewConfigDiscoverResponse(auth.ConvertToErrCode(err)) return out } @@ -186,12 +186,12 @@ func (s *ServerAuthability) GetConfigGroupsWithCache(ctx context.Context, } // CasUpsertAndReleaseConfigFileFromClient 创建/更新配置文件并发布 -func (s *ServerAuthability) CasUpsertAndReleaseConfigFileFromClient(ctx context.Context, +func (s *Server) CasUpsertAndReleaseConfigFileFromClient(ctx context.Context, req *apiconfig.ConfigFilePublishInfo) *apiconfig.ConfigResponse { authCtx := s.collectConfigFilePublishAuthContext(ctx, []*apiconfig.ConfigFilePublishInfo{req}, auth.Modify, auth.UpsertAndReleaseConfigFile) - if _, err := s.policyMgr.GetAuthChecker().CheckClientPermission(authCtx); err != nil { + if _, err := s.policySvr.GetAuthChecker().CheckClientPermission(authCtx); err != nil { return api.NewConfigFileResponse(auth.ConvertToErrCode(err), nil) } diff --git a/config/interceptor/auth/config_file.go b/config/interceptor/auth/config_file.go index 4a3fcaf8c..f515b8f55 100644 --- a/config/interceptor/auth/config_file.go +++ b/config/interceptor/auth/config_file.go @@ -28,12 +28,12 @@ import ( ) // CreateConfigFile 创建配置文件 -func (s *ServerAuthability) CreateConfigFile(ctx context.Context, +func (s *Server) CreateConfigFile(ctx context.Context, configFile *apiconfig.ConfigFile) *apiconfig.ConfigResponse { authCtx := s.collectConfigFileAuthContext( ctx, []*apiconfig.ConfigFile{configFile}, auth.Create, auth.CreateConfigFile) - if _, err := s.policyMgr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewConfigResponseWithInfo(auth.ConvertToErrCode(err), err.Error()) + if _, err := s.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + return api.NewConfigResponse(auth.ConvertToErrCode(err)) } ctx = authCtx.GetRequestContext() @@ -43,13 +43,13 @@ func (s *ServerAuthability) CreateConfigFile(ctx context.Context, } // GetConfigFileRichInfo 获取单个配置文件基础信息,包含发布状态等信息 -func (s *ServerAuthability) GetConfigFileRichInfo(ctx context.Context, +func (s *Server) GetConfigFileRichInfo(ctx context.Context, req *apiconfig.ConfigFile) *apiconfig.ConfigResponse { authCtx := s.collectConfigFileAuthContext( ctx, []*apiconfig.ConfigFile{req}, auth.Read, auth.DescribeConfigFileRichInfo) - if _, err := s.policyMgr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewConfigResponseWithInfo(auth.ConvertToErrCode(err), err.Error()) + if _, err := s.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + return api.NewConfigResponse(auth.ConvertToErrCode(err)) } ctx = authCtx.GetRequestContext() ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) @@ -57,12 +57,12 @@ func (s *ServerAuthability) GetConfigFileRichInfo(ctx context.Context, } // SearchConfigFile 查询配置文件 -func (s *ServerAuthability) SearchConfigFile(ctx context.Context, +func (s *Server) SearchConfigFile(ctx context.Context, filter map[string]string) *apiconfig.ConfigBatchQueryResponse { authCtx := s.collectConfigFileAuthContext(ctx, nil, auth.Read, auth.DescribeConfigFiles) - if _, err := s.policyMgr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewConfigFileBatchQueryResponseWithMessage(auth.ConvertToErrCode(err), err.Error()) + if _, err := s.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + return api.NewConfigBatchQueryResponse(auth.ConvertToErrCode(err)) } ctx = authCtx.GetRequestContext() ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) @@ -71,12 +71,12 @@ func (s *ServerAuthability) SearchConfigFile(ctx context.Context, } // UpdateConfigFile 更新配置文件 -func (s *ServerAuthability) UpdateConfigFile( +func (s *Server) UpdateConfigFile( ctx context.Context, configFile *apiconfig.ConfigFile) *apiconfig.ConfigResponse { authCtx := s.collectConfigFileAuthContext( ctx, []*apiconfig.ConfigFile{configFile}, auth.Modify, auth.UpdateConfigFile) - if _, err := s.policyMgr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewConfigResponseWithInfo(auth.ConvertToErrCode(err), err.Error()) + if _, err := s.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + return api.NewConfigResponse(auth.ConvertToErrCode(err)) } ctx = authCtx.GetRequestContext() @@ -86,13 +86,13 @@ func (s *ServerAuthability) UpdateConfigFile( } // DeleteConfigFile 删除配置文件,删除配置文件同时会通知客户端 Not_Found -func (s *ServerAuthability) DeleteConfigFile(ctx context.Context, +func (s *Server) DeleteConfigFile(ctx context.Context, req *apiconfig.ConfigFile) *apiconfig.ConfigResponse { authCtx := s.collectConfigFileAuthContext(ctx, []*apiconfig.ConfigFile{req}, auth.Delete, auth.DeleteConfigFile) - if _, err := s.policyMgr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewConfigResponseWithInfo(auth.ConvertToErrCode(err), err.Error()) + if _, err := s.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + return api.NewConfigResponse(auth.ConvertToErrCode(err)) } ctx = authCtx.GetRequestContext() @@ -102,12 +102,12 @@ func (s *ServerAuthability) DeleteConfigFile(ctx context.Context, } // BatchDeleteConfigFile 批量删除配置文件 -func (s *ServerAuthability) BatchDeleteConfigFile(ctx context.Context, +func (s *Server) BatchDeleteConfigFile(ctx context.Context, req []*apiconfig.ConfigFile) *apiconfig.ConfigResponse { authCtx := s.collectConfigFileAuthContext(ctx, req, auth.Delete, auth.BatchDeleteConfigFiles) - if _, err := s.policyMgr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewConfigResponseWithInfo(auth.ConvertToErrCode(err), err.Error()) + if _, err := s.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + return api.NewConfigResponse(auth.ConvertToErrCode(err)) } ctx = authCtx.GetRequestContext() @@ -116,7 +116,7 @@ func (s *ServerAuthability) BatchDeleteConfigFile(ctx context.Context, return s.nextServer.BatchDeleteConfigFile(ctx, req) } -func (s *ServerAuthability) ExportConfigFile(ctx context.Context, +func (s *Server) ExportConfigFile(ctx context.Context, configFileExport *apiconfig.ConfigFileExportRequest) *apiconfig.ConfigExportResponse { var configFiles []*apiconfig.ConfigFile for _, group := range configFileExport.Groups { @@ -127,8 +127,8 @@ func (s *ServerAuthability) ExportConfigFile(ctx context.Context, configFiles = append(configFiles, configFile) } authCtx := s.collectConfigFileAuthContext(ctx, configFiles, auth.Read, auth.ExportConfigFiles) - if _, err := s.policyMgr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewConfigFileExportResponseWithMessage(auth.ConvertToErrCode(err), err.Error()) + if _, err := s.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + return api.NewConfigFileExportResponse(auth.ConvertToErrCode(err), nil) } ctx = authCtx.GetRequestContext() ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) @@ -136,11 +136,11 @@ func (s *ServerAuthability) ExportConfigFile(ctx context.Context, return s.nextServer.ExportConfigFile(ctx, configFileExport) } -func (s *ServerAuthability) ImportConfigFile(ctx context.Context, +func (s *Server) ImportConfigFile(ctx context.Context, configFiles []*apiconfig.ConfigFile, conflictHandling string) *apiconfig.ConfigImportResponse { authCtx := s.collectConfigFileAuthContext(ctx, configFiles, auth.Create, auth.ImportConfigFiles) - if _, err := s.policyMgr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewConfigFileImportResponseWithMessage(auth.ConvertToErrCode(err), err.Error()) + if _, err := s.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + return api.NewSimpleConfigFileImportResponse(auth.ConvertToErrCode(err)) } ctx = authCtx.GetRequestContext() @@ -148,7 +148,7 @@ func (s *ServerAuthability) ImportConfigFile(ctx context.Context, return s.nextServer.ImportConfigFile(ctx, configFiles, conflictHandling) } -func (s *ServerAuthability) GetAllConfigEncryptAlgorithms( +func (s *Server) GetAllConfigEncryptAlgorithms( ctx context.Context) *apiconfig.ConfigEncryptAlgorithmResponse { return s.nextServer.GetAllConfigEncryptAlgorithms(ctx) } diff --git a/config/interceptor/auth/config_file_group.go b/config/interceptor/auth/config_file_group.go index caa466c57..dcad994d7 100644 --- a/config/interceptor/auth/config_file_group.go +++ b/config/interceptor/auth/config_file_group.go @@ -19,24 +19,29 @@ package config_auth import ( "context" + "strconv" apiconfig "github.com/polarismesh/specification/source/go/api/v1/config_manage" + "github.com/polarismesh/specification/source/go/api/v1/security" + apisecurity "github.com/polarismesh/specification/source/go/api/v1/security" + cachetypes "github.com/polarismesh/polaris/cache/api" api "github.com/polarismesh/polaris/common/api/v1" "github.com/polarismesh/polaris/common/model" "github.com/polarismesh/polaris/common/model/auth" + authcommon "github.com/polarismesh/polaris/common/model/auth" "github.com/polarismesh/polaris/common/utils" ) // CreateConfigFileGroup 创建配置文件组 -func (s *ServerAuthability) CreateConfigFileGroup(ctx context.Context, +func (s *Server) CreateConfigFileGroup(ctx context.Context, configFileGroup *apiconfig.ConfigFileGroup) *apiconfig.ConfigResponse { authCtx := s.collectConfigGroupAuthContext(ctx, []*apiconfig.ConfigFileGroup{configFileGroup}, - auth.Create, auth.CreateConfigFileGroup) + authcommon.Create, authcommon.CreateConfigFileGroup) // 验证 token 信息 - if _, err := s.policyMgr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewConfigResponseWithInfo(auth.ConvertToErrCode(err), err.Error()) + if _, err := s.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + return api.NewConfigResponse(authcommon.ConvertToErrCode(err)) } ctx = authCtx.GetRequestContext() @@ -46,41 +51,79 @@ func (s *ServerAuthability) CreateConfigFileGroup(ctx context.Context, } // QueryConfigFileGroups 查询配置文件组 -func (s *ServerAuthability) QueryConfigFileGroups(ctx context.Context, +func (s *Server) QueryConfigFileGroups(ctx context.Context, filter map[string]string) *apiconfig.ConfigBatchQueryResponse { + authCtx := s.collectConfigGroupAuthContext(ctx, nil, authcommon.Read, authcommon.DescribeConfigFileGroups) - authCtx := s.collectConfigGroupAuthContext(ctx, nil, auth.Read, auth.DescribeConfigFileGroups) - - if _, err := s.policyMgr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewConfigBatchQueryResponse(auth.ConvertToErrCode(err)) + if _, err := s.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + return api.NewConfigBatchQueryResponse(authcommon.ConvertToErrCode(err)) } ctx = authCtx.GetRequestContext() ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) + ctx = cachetypes.AppendConfigGroupPredicate(ctx, func(ctx context.Context, cfg *model.ConfigFileGroup) bool { + ok := s.policySvr.GetAuthChecker().ResourcePredicate(authCtx, &authcommon.ResourceEntry{ + Type: apisecurity.ResourceType_ConfigGroups, + ID: strconv.FormatUint(cfg.Id, 10), + Metadata: cfg.Metadata, + }) + if ok { + return true + } + saveNs := s.cacheMgr.Namespace().GetNamespace(cfg.Namespace) + if saveNs == nil { + return false + } + // 检查下是否可以访问对应的 namespace + return s.policySvr.GetAuthChecker().ResourcePredicate(authCtx, &authcommon.ResourceEntry{ + Type: security.ResourceType_Namespaces, + ID: saveNs.Name, + Metadata: saveNs.Metadata, + }) + }) + authCtx.SetRequestContext(ctx) + resp := s.nextServer.QueryConfigFileGroups(ctx, filter) if len(resp.ConfigFileGroups) != 0 { for index := range resp.ConfigFileGroups { - group := resp.ConfigFileGroups[index] - editable := true - // 如果包含特殊标签,也不允许修改 - if _, ok := group.GetMetadata()[model.MetaKey3RdPlatform]; ok { - editable = false + item := resp.ConfigFileGroups[index] + authCtx.SetAccessResources(map[apisecurity.ResourceType][]authcommon.ResourceEntry{ + apisecurity.ResourceType_ConfigGroups: { + { + Type: apisecurity.ResourceType_ConfigGroups, + ID: strconv.FormatUint(item.GetId().GetValue(), 10), + Metadata: item.Metadata, + }, + }, + }) + + // 检查 write 操作权限 + authCtx.SetMethod([]authcommon.ServerFunctionName{authcommon.UpdateConfigFileGroup}) + // 如果检查不通过,设置 editable 为 false + if _, err := s.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + item.Editable = utils.NewBoolValue(false) + } + + // 检查 delete 操作权限 + authCtx.SetMethod([]authcommon.ServerFunctionName{authcommon.DeleteConfigFileGroup}) + // 如果检查不通过,设置 editable 为 false + if _, err := s.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + item.Deleteable = utils.NewBoolValue(false) } - group.Editable = utils.NewBoolValue(editable) } } return resp } // DeleteConfigFileGroup 删除配置文件组 -func (s *ServerAuthability) DeleteConfigFileGroup( +func (s *Server) DeleteConfigFileGroup( ctx context.Context, namespace, name string) *apiconfig.ConfigResponse { authCtx := s.collectConfigGroupAuthContext(ctx, []*apiconfig.ConfigFileGroup{{Name: utils.NewStringValue(name), Namespace: utils.NewStringValue(namespace)}}, auth.Delete, auth.DeleteConfigFileGroup) - if _, err := s.policyMgr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewConfigResponseWithInfo(auth.ConvertToErrCode(err), err.Error()) + if _, err := s.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + return api.NewConfigResponse(auth.ConvertToErrCode(err)) } ctx = authCtx.GetRequestContext() @@ -90,13 +133,13 @@ func (s *ServerAuthability) DeleteConfigFileGroup( } // UpdateConfigFileGroup 更新配置文件组 -func (s *ServerAuthability) UpdateConfigFileGroup(ctx context.Context, +func (s *Server) UpdateConfigFileGroup(ctx context.Context, configFileGroup *apiconfig.ConfigFileGroup) *apiconfig.ConfigResponse { authCtx := s.collectConfigGroupAuthContext(ctx, []*apiconfig.ConfigFileGroup{configFileGroup}, auth.Modify, auth.UpdateConfigFileGroup) - if _, err := s.policyMgr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewConfigResponseWithInfo(auth.ConvertToErrCode(err), err.Error()) + if _, err := s.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + return api.NewConfigResponse(auth.ConvertToErrCode(err)) } ctx = authCtx.GetRequestContext() diff --git a/config/interceptor/auth/config_file_release.go b/config/interceptor/auth/config_file_release.go index 3a2df3e04..599d281fa 100644 --- a/config/interceptor/auth/config_file_release.go +++ b/config/interceptor/auth/config_file_release.go @@ -28,14 +28,14 @@ import ( ) // PublishConfigFile 发布配置文件 -func (s *ServerAuthability) PublishConfigFile(ctx context.Context, +func (s *Server) PublishConfigFile(ctx context.Context, configFileRelease *apiconfig.ConfigFileRelease) *apiconfig.ConfigResponse { authCtx := s.collectConfigFileReleaseAuthContext(ctx, []*apiconfig.ConfigFileRelease{configFileRelease}, auth.Modify, "PublishConfigFile") - if _, err := s.policyMgr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewConfigResponseWithInfo(auth.ConvertToErrCode(err), err.Error()) + if _, err := s.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + return api.NewConfigResponse(auth.ConvertToErrCode(err)) } ctx = authCtx.GetRequestContext() @@ -45,14 +45,14 @@ func (s *ServerAuthability) PublishConfigFile(ctx context.Context, } // GetConfigFileRelease 获取配置文件发布内容 -func (s *ServerAuthability) GetConfigFileRelease(ctx context.Context, +func (s *Server) GetConfigFileRelease(ctx context.Context, req *apiconfig.ConfigFileRelease) *apiconfig.ConfigResponse { authCtx := s.collectConfigFileReleaseAuthContext(ctx, []*apiconfig.ConfigFileRelease{req}, auth.Read, auth.DescribeConfigFileRelease) - if _, err := s.policyMgr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewConfigResponseWithInfo(auth.ConvertToErrCode(err), err.Error()) + if _, err := s.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + return api.NewConfigResponse(auth.ConvertToErrCode(err)) } ctx = authCtx.GetRequestContext() ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) @@ -60,13 +60,13 @@ func (s *ServerAuthability) GetConfigFileRelease(ctx context.Context, } // DeleteConfigFileReleases implements ConfigCenterServer. -func (s *ServerAuthability) DeleteConfigFileReleases(ctx context.Context, +func (s *Server) DeleteConfigFileReleases(ctx context.Context, reqs []*apiconfig.ConfigFileRelease) *apiconfig.ConfigBatchWriteResponse { authCtx := s.collectConfigFileReleaseAuthContext(ctx, reqs, auth.Delete, auth.DeleteConfigFileReleases) - if _, err := s.policyMgr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewConfigBatchWriteResponseWithInfo(auth.ConvertToErrCode(err), err.Error()) + if _, err := s.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + return api.NewConfigBatchWriteResponse(auth.ConvertToErrCode(err)) } ctx = authCtx.GetRequestContext() ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) @@ -74,13 +74,12 @@ func (s *ServerAuthability) DeleteConfigFileReleases(ctx context.Context, } // GetConfigFileReleaseVersions implements ConfigCenterServer. -func (s *ServerAuthability) GetConfigFileReleaseVersions(ctx context.Context, +func (s *Server) GetConfigFileReleaseVersions(ctx context.Context, filters map[string]string) *apiconfig.ConfigBatchQueryResponse { - authCtx := s.collectConfigFileReleaseAuthContext(ctx, nil, auth.Read, auth.DescribeConfigFileReleaseVersions) - if _, err := s.policyMgr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewConfigBatchQueryResponseWithInfo(auth.ConvertToErrCode(err), err.Error()) + if _, err := s.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + return api.NewConfigBatchQueryResponse(auth.ConvertToErrCode(err)) } ctx = authCtx.GetRequestContext() ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) @@ -88,27 +87,28 @@ func (s *ServerAuthability) GetConfigFileReleaseVersions(ctx context.Context, } // GetConfigFileReleases implements ConfigCenterServer. -func (s *ServerAuthability) GetConfigFileReleases(ctx context.Context, +func (s *Server) GetConfigFileReleases(ctx context.Context, filters map[string]string) *apiconfig.ConfigBatchQueryResponse { authCtx := s.collectConfigFileReleaseAuthContext(ctx, nil, auth.Read, auth.DescribeConfigFileReleases) - if _, err := s.policyMgr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewConfigBatchQueryResponseWithInfo(auth.ConvertToErrCode(err), err.Error()) + if _, err := s.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + return api.NewConfigBatchQueryResponse(auth.ConvertToErrCode(err)) } ctx = authCtx.GetRequestContext() ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) + return s.nextServer.GetConfigFileReleases(ctx, filters) } // RollbackConfigFileReleases implements ConfigCenterServer. -func (s *ServerAuthability) RollbackConfigFileReleases(ctx context.Context, +func (s *Server) RollbackConfigFileReleases(ctx context.Context, reqs []*apiconfig.ConfigFileRelease) *apiconfig.ConfigBatchWriteResponse { authCtx := s.collectConfigFileReleaseAuthContext(ctx, reqs, auth.Modify, auth.RollbackConfigFileReleases) - if _, err := s.policyMgr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewConfigBatchWriteResponseWithInfo(auth.ConvertToErrCode(err), err.Error()) + if _, err := s.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + return api.NewConfigBatchWriteResponse(auth.ConvertToErrCode(err)) } ctx = authCtx.GetRequestContext() ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) @@ -116,12 +116,12 @@ func (s *ServerAuthability) RollbackConfigFileReleases(ctx context.Context, } // UpsertAndReleaseConfigFile . -func (s *ServerAuthability) UpsertAndReleaseConfigFile(ctx context.Context, +func (s *Server) UpsertAndReleaseConfigFile(ctx context.Context, req *apiconfig.ConfigFilePublishInfo) *apiconfig.ConfigResponse { authCtx := s.collectConfigFilePublishAuthContext(ctx, []*apiconfig.ConfigFilePublishInfo{req}, auth.Modify, auth.UpsertAndReleaseConfigFile) - if _, err := s.policyMgr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewConfigFileResponse(auth.ConvertToErrCode(err), nil) + if _, err := s.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + return api.NewConfigResponse(auth.ConvertToErrCode(err)) } ctx = authCtx.GetRequestContext() @@ -130,12 +130,12 @@ func (s *ServerAuthability) UpsertAndReleaseConfigFile(ctx context.Context, return s.nextServer.UpsertAndReleaseConfigFile(ctx, req) } -func (s *ServerAuthability) StopGrayConfigFileReleases(ctx context.Context, +func (s *Server) StopGrayConfigFileReleases(ctx context.Context, reqs []*apiconfig.ConfigFileRelease) *apiconfig.ConfigBatchWriteResponse { authCtx := s.collectConfigFileReleaseAuthContext(ctx, reqs, auth.Modify, auth.StopGrayConfigFileReleases) - if _, err := s.policyMgr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + if _, err := s.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { return api.NewConfigBatchWriteResponse(auth.ConvertToErrCode(err)) } diff --git a/config/interceptor/auth/config_file_release_history.go b/config/interceptor/auth/config_file_release_history.go index 3b4464c7e..c2c0bfc75 100644 --- a/config/interceptor/auth/config_file_release_history.go +++ b/config/interceptor/auth/config_file_release_history.go @@ -28,14 +28,14 @@ import ( ) // GetConfigFileReleaseHistory 获取配置文件发布历史记录 -func (s *ServerAuthability) GetConfigFileReleaseHistories(ctx context.Context, +func (s *Server) GetConfigFileReleaseHistories(ctx context.Context, filter map[string]string) *apiconfig.ConfigBatchQueryResponse { - authCtx := s.collectConfigFileReleaseHistoryAuthContext(ctx, nil, auth.Read, auth.DescribeConfigFileReleaseHistories) - if _, err := s.policyMgr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewConfigBatchQueryResponseWithInfo(auth.ConvertToErrCode(err), err.Error()) + if _, err := s.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + return api.NewConfigBatchQueryResponse(auth.ConvertToErrCode(err)) } + ctx = authCtx.GetRequestContext() ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) return s.nextServer.GetConfigFileReleaseHistories(ctx, filter) diff --git a/config/interceptor/auth/config_file_template.go b/config/interceptor/auth/config_file_template.go index dae5f9ab9..7d83e465f 100644 --- a/config/interceptor/auth/config_file_template.go +++ b/config/interceptor/auth/config_file_template.go @@ -28,10 +28,10 @@ import ( ) // GetAllConfigFileTemplates get all config file templates -func (s *ServerAuthability) GetAllConfigFileTemplates(ctx context.Context) *apiconfig.ConfigBatchQueryResponse { +func (s *Server) GetAllConfigFileTemplates(ctx context.Context) *apiconfig.ConfigBatchQueryResponse { authCtx := s.collectConfigFileTemplateAuthContext(ctx, []*apiconfig.ConfigFileTemplate{}, auth.Read, auth.DescribeAllConfigFileTemplates) - if _, err := s.policyMgr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + if _, err := s.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { return api.NewConfigFileBatchQueryResponseWithMessage(auth.ConvertToErrCode(err), err.Error()) } @@ -41,10 +41,10 @@ func (s *ServerAuthability) GetAllConfigFileTemplates(ctx context.Context) *apic } // GetConfigFileTemplate get config file template -func (s *ServerAuthability) GetConfigFileTemplate(ctx context.Context, name string) *apiconfig.ConfigResponse { +func (s *Server) GetConfigFileTemplate(ctx context.Context, name string) *apiconfig.ConfigResponse { authCtx := s.collectConfigFileTemplateAuthContext(ctx, []*apiconfig.ConfigFileTemplate{}, auth.Read, auth.DescribeConfigFileTemplate) - if _, err := s.policyMgr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + if _, err := s.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { return api.NewConfigResponseWithInfo(auth.ConvertToErrCode(err), err.Error()) } @@ -54,12 +54,12 @@ func (s *ServerAuthability) GetConfigFileTemplate(ctx context.Context, name stri } // CreateConfigFileTemplate create config file template -func (s *ServerAuthability) CreateConfigFileTemplate(ctx context.Context, +func (s *Server) CreateConfigFileTemplate(ctx context.Context, template *apiconfig.ConfigFileTemplate) *apiconfig.ConfigResponse { authCtx := s.collectConfigFileTemplateAuthContext(ctx, []*apiconfig.ConfigFileTemplate{template}, auth.Create, auth.CreateConfigFileTemplate) - if _, err := s.policyMgr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + if _, err := s.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { return api.NewConfigResponseWithInfo(auth.ConvertToErrCode(err), err.Error()) } diff --git a/config/interceptor/auth/resource_listener.go b/config/interceptor/auth/resource_listener.go index f1fda9b46..23d73792f 100644 --- a/config/interceptor/auth/resource_listener.go +++ b/config/interceptor/auth/resource_listener.go @@ -30,12 +30,12 @@ import ( ) // Before this function is called before the resource operation -func (s *ServerAuthability) Before(ctx context.Context, resourceType model.Resource) { +func (s *Server) Before(ctx context.Context, resourceType model.Resource) { // do nothing } // After this function is called after the resource operation -func (s *ServerAuthability) After(ctx context.Context, resourceType model.Resource, res *config.ResourceEvent) error { +func (s *Server) After(ctx context.Context, resourceType model.Resource, res *config.ResourceEvent) error { switch resourceType { case model.RConfigGroup: return s.onConfigGroupResource(ctx, res) @@ -45,7 +45,7 @@ func (s *ServerAuthability) After(ctx context.Context, resourceType model.Resour } // onConfigGroupResource -func (s *ServerAuthability) onConfigGroupResource(ctx context.Context, res *config.ResourceEvent) error { +func (s *Server) onConfigGroupResource(ctx context.Context, res *config.ResourceEvent) error { authCtx := ctx.Value(utils.ContextAuthContextKey).(*auth.AcquireContext) authCtx.SetAttachment(auth.ResourceAttachmentKey, map[apisecurity.ResourceType][]auth.ResourceEntry{ @@ -69,5 +69,5 @@ func (s *ServerAuthability) onConfigGroupResource(ctx context.Context, res *conf authCtx.SetAttachment(auth.LinkGroupsKey, utils.StringSliceDeDuplication(groups)) authCtx.SetAttachment(auth.RemoveLinkGroupsKey, utils.StringSliceDeDuplication(removeGroups)) - return s.policyMgr.AfterResourceOperation(authCtx) + return s.policySvr.AfterResourceOperation(authCtx) } diff --git a/config/interceptor/auth/server.go b/config/interceptor/auth/server.go index 4fece2d2b..f7b808d99 100644 --- a/config/interceptor/auth/server.go +++ b/config/interceptor/auth/server.go @@ -33,23 +33,23 @@ import ( "github.com/polarismesh/polaris/config" ) -var _ config.ConfigCenterServer = (*ServerAuthability)(nil) +var _ config.ConfigCenterServer = (*Server)(nil) // Server 配置中心核心服务 -type ServerAuthability struct { +type Server struct { cacheMgr cachetypes.CacheManager nextServer config.ConfigCenterServer - userMgn auth.UserServer - policyMgr auth.StrategyServer + userSvr auth.UserServer + policySvr auth.StrategyServer } func New(nextServer config.ConfigCenterServer, cacheMgr cachetypes.CacheManager, - userMgr auth.UserServer, strategyMgr auth.StrategyServer) config.ConfigCenterServer { - proxy := &ServerAuthability{ + userSvr auth.UserServer, policySvr auth.StrategyServer) config.ConfigCenterServer { + proxy := &Server{ nextServer: nextServer, cacheMgr: cacheMgr, - userMgn: userMgr, - policyMgr: strategyMgr, + userSvr: userSvr, + policySvr: policySvr, } if val, ok := nextServer.(*config.Server); ok { val.SetResourceHooks(proxy) @@ -57,7 +57,7 @@ func New(nextServer config.ConfigCenterServer, cacheMgr cachetypes.CacheManager, return proxy } -func (s *ServerAuthability) collectConfigFileAuthContext(ctx context.Context, req []*apiconfig.ConfigFile, +func (s *Server) collectConfigFileAuthContext(ctx context.Context, req []*apiconfig.ConfigFile, op authcommon.ResourceOperation, methodName authcommon.ServerFunctionName) *authcommon.AcquireContext { return authcommon.NewAcquireContext( authcommon.WithRequestContext(ctx), @@ -68,7 +68,7 @@ func (s *ServerAuthability) collectConfigFileAuthContext(ctx context.Context, re ) } -func (s *ServerAuthability) collectClientConfigFileAuthContext(ctx context.Context, req []*apiconfig.ConfigFile, +func (s *Server) collectClientConfigFileAuthContext(ctx context.Context, req []*apiconfig.ConfigFile, op authcommon.ResourceOperation, methodName authcommon.ServerFunctionName) *authcommon.AcquireContext { return authcommon.NewAcquireContext( authcommon.WithRequestContext(ctx), @@ -80,8 +80,8 @@ func (s *ServerAuthability) collectClientConfigFileAuthContext(ctx context.Conte ) } -func (s *ServerAuthability) collectClientWatchConfigFiles(ctx context.Context, - req *apiconfig.ClientWatchConfigFileRequest, op authcommon.ResourceOperation, methodName authcommon.ServerFunctionName) *authcommon.AcquireContext { +func (s *Server) collectClientWatchConfigFiles(ctx context.Context, req *apiconfig.ClientWatchConfigFileRequest, + op authcommon.ResourceOperation, methodName authcommon.ServerFunctionName) *authcommon.AcquireContext { return authcommon.NewAcquireContext( authcommon.WithRequestContext(ctx), authcommon.WithModule(authcommon.ConfigModule), @@ -92,7 +92,7 @@ func (s *ServerAuthability) collectClientWatchConfigFiles(ctx context.Context, ) } -func (s *ServerAuthability) collectConfigFileReleaseAuthContext(ctx context.Context, req []*apiconfig.ConfigFileRelease, +func (s *Server) collectConfigFileReleaseAuthContext(ctx context.Context, req []*apiconfig.ConfigFileRelease, op authcommon.ResourceOperation, methodName authcommon.ServerFunctionName) *authcommon.AcquireContext { return authcommon.NewAcquireContext( authcommon.WithRequestContext(ctx), @@ -103,7 +103,7 @@ func (s *ServerAuthability) collectConfigFileReleaseAuthContext(ctx context.Cont ) } -func (s *ServerAuthability) collectConfigFilePublishAuthContext(ctx context.Context, req []*apiconfig.ConfigFilePublishInfo, +func (s *Server) collectConfigFilePublishAuthContext(ctx context.Context, req []*apiconfig.ConfigFilePublishInfo, op authcommon.ResourceOperation, methodName authcommon.ServerFunctionName) *authcommon.AcquireContext { return authcommon.NewAcquireContext( authcommon.WithRequestContext(ctx), @@ -114,8 +114,8 @@ func (s *ServerAuthability) collectConfigFilePublishAuthContext(ctx context.Cont ) } -func (s *ServerAuthability) collectClientConfigFileReleaseAuthContext(ctx context.Context, - req []*apiconfig.ConfigFileRelease, op authcommon.ResourceOperation, methodName authcommon.ServerFunctionName) *authcommon.AcquireContext { +func (s *Server) collectClientConfigFileRelease(ctx context.Context, req []*apiconfig.ConfigFileRelease, + op authcommon.ResourceOperation, methodName authcommon.ServerFunctionName) *authcommon.AcquireContext { return authcommon.NewAcquireContext( authcommon.WithRequestContext(ctx), authcommon.WithModule(authcommon.ConfigModule), @@ -126,7 +126,7 @@ func (s *ServerAuthability) collectClientConfigFileReleaseAuthContext(ctx contex ) } -func (s *ServerAuthability) collectConfigFileReleaseHistoryAuthContext( +func (s *Server) collectConfigFileReleaseHistoryAuthContext( ctx context.Context, req []*apiconfig.ConfigFileReleaseHistory, op authcommon.ResourceOperation, methodName authcommon.ServerFunctionName) *authcommon.AcquireContext { @@ -139,7 +139,7 @@ func (s *ServerAuthability) collectConfigFileReleaseHistoryAuthContext( ) } -func (s *ServerAuthability) collectConfigGroupAuthContext(ctx context.Context, req []*apiconfig.ConfigFileGroup, +func (s *Server) collectConfigGroupAuthContext(ctx context.Context, req []*apiconfig.ConfigFileGroup, op authcommon.ResourceOperation, methodName authcommon.ServerFunctionName) *authcommon.AcquireContext { return authcommon.NewAcquireContext( authcommon.WithRequestContext(ctx), @@ -150,15 +150,15 @@ func (s *ServerAuthability) collectConfigGroupAuthContext(ctx context.Context, r ) } -func (s *ServerAuthability) collectConfigFileTemplateAuthContext(ctx context.Context, - req []*apiconfig.ConfigFileTemplate, op authcommon.ResourceOperation, methodName authcommon.ServerFunctionName) *authcommon.AcquireContext { +func (s *Server) collectConfigFileTemplateAuthContext(ctx context.Context, req []*apiconfig.ConfigFileTemplate, + op authcommon.ResourceOperation, methodName authcommon.ServerFunctionName) *authcommon.AcquireContext { return authcommon.NewAcquireContext( authcommon.WithRequestContext(ctx), authcommon.WithModule(authcommon.ConfigModule), ) } -func (s *ServerAuthability) queryConfigGroupResource(ctx context.Context, +func (s *Server) queryConfigGroupResource(ctx context.Context, req []*apiconfig.ConfigFileGroup) map[apisecurity.ResourceType][]authcommon.ResourceEntry { if len(req) == 0 { @@ -188,7 +188,7 @@ func (s *ServerAuthability) queryConfigGroupResource(ctx context.Context, } // queryConfigFileResource config file资源的鉴权转换为config group的鉴权 -func (s *ServerAuthability) queryConfigFileResource(ctx context.Context, +func (s *Server) queryConfigFileResource(ctx context.Context, req []*apiconfig.ConfigFile) map[apisecurity.ResourceType][]authcommon.ResourceEntry { if len(req) == 0 { @@ -214,7 +214,7 @@ func (s *ServerAuthability) queryConfigFileResource(ctx context.Context, return ret } -func (s *ServerAuthability) queryConfigFileReleaseResource(ctx context.Context, +func (s *Server) queryConfigFileReleaseResource(ctx context.Context, req []*apiconfig.ConfigFileRelease) map[apisecurity.ResourceType][]authcommon.ResourceEntry { if len(req) == 0 { @@ -240,7 +240,7 @@ func (s *ServerAuthability) queryConfigFileReleaseResource(ctx context.Context, return ret } -func (s *ServerAuthability) queryConfigFilePublishResource(ctx context.Context, +func (s *Server) queryConfigFilePublishResource(ctx context.Context, req []*apiconfig.ConfigFilePublishInfo) map[apisecurity.ResourceType][]authcommon.ResourceEntry { if len(req) == 0 { @@ -264,7 +264,7 @@ func (s *ServerAuthability) queryConfigFilePublishResource(ctx context.Context, return ret } -func (s *ServerAuthability) queryConfigFileReleaseHistoryResource(ctx context.Context, +func (s *Server) queryConfigFileReleaseHistoryResource(ctx context.Context, req []*apiconfig.ConfigFileReleaseHistory) map[apisecurity.ResourceType][]authcommon.ResourceEntry { if len(req) == 0 { @@ -290,7 +290,7 @@ func (s *ServerAuthability) queryConfigFileReleaseHistoryResource(ctx context.Co return ret } -func (s *ServerAuthability) queryConfigGroupRsEntryByNames(ctx context.Context, namespace string, +func (s *Server) queryConfigGroupRsEntryByNames(ctx context.Context, namespace string, names []string) ([]authcommon.ResourceEntry, error) { configFileGroups := make([]*model.ConfigFileGroup, 0, len(names)) @@ -315,7 +315,7 @@ func (s *ServerAuthability) queryConfigGroupRsEntryByNames(ctx context.Context, return entries, nil } -func (s *ServerAuthability) queryWatchConfigFilesResource(ctx context.Context, +func (s *Server) queryWatchConfigFilesResource(ctx context.Context, req *apiconfig.ClientWatchConfigFileRequest) map[apisecurity.ResourceType][]authcommon.ResourceEntry { files := req.GetWatchFiles() if len(files) == 0 { diff --git a/go.mod b/go.mod index f347292b3..f2590eca3 100644 --- a/go.mod +++ b/go.mod @@ -5,14 +5,14 @@ go 1.21 require ( github.com/BurntSushi/toml v1.2.0 github.com/emicklei/go-restful/v3 v3.9.0 - github.com/envoyproxy/go-control-plane v0.12.0 + github.com/envoyproxy/go-control-plane v0.13.0 github.com/go-openapi/spec v0.20.7 github.com/go-redis/redis/v8 v8.11.5 github.com/go-sql-driver/mysql v1.6.0 github.com/gogo/protobuf v1.3.2 github.com/golang/mock v1.6.0 - github.com/golang/protobuf v1.5.3 - github.com/google/uuid v1.3.0 + github.com/golang/protobuf v1.5.4 + github.com/google/uuid v1.6.0 github.com/hashicorp/go-multierror v1.1.1 github.com/hashicorp/golang-lru v0.5.4 github.com/json-iterator/go v1.1.12 // indirect @@ -24,17 +24,17 @@ require ( github.com/prometheus/client_golang v1.18.0 github.com/smartystreets/goconvey v1.6.4 github.com/spf13/cobra v1.2.1 - github.com/stretchr/testify v1.8.4 + github.com/stretchr/testify v1.9.0 go.uber.org/atomic v1.10.0 go.uber.org/automaxprocs v1.4.0 go.uber.org/zap v1.23.0 - golang.org/x/crypto v0.21.0 - golang.org/x/net v0.23.0 - golang.org/x/sync v0.6.0 - golang.org/x/text v0.14.0 + golang.org/x/crypto v0.23.0 + golang.org/x/net v0.25.0 + golang.org/x/sync v0.7.0 + golang.org/x/text v0.15.0 golang.org/x/time v0.1.1-0.20221020023724-80b9fac54d29 - google.golang.org/grpc v1.58.3 - google.golang.org/protobuf v1.33.0 + google.golang.org/grpc v1.65.0 + google.golang.org/protobuf v1.34.1 gopkg.in/yaml.v2 v2.4.0 ) @@ -48,11 +48,11 @@ require ( require ( github.com/beorn7/perks v1.0.1 // indirect github.com/census-instrumentation/opencensus-proto v0.4.1 // indirect - github.com/cespare/xxhash/v2 v2.2.0 // indirect - github.com/cncf/xds/go v0.0.0-20230607035331-e9ce68804cb4 // indirect + github.com/cespare/xxhash/v2 v2.3.0 // indirect + github.com/cncf/xds/go v0.0.0-20240423153145-555b57ec207b // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect - github.com/envoyproxy/protoc-gen-validate v1.0.2 // indirect + github.com/envoyproxy/protoc-gen-validate v1.0.4 // indirect github.com/go-openapi/jsonpointer v0.19.5 // indirect github.com/go-openapi/jsonreference v0.20.0 // indirect github.com/go-openapi/swag v0.19.15 // indirect @@ -65,31 +65,33 @@ require ( github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect - github.com/prometheus/client_model v0.5.0 // indirect + github.com/prometheus/client_model v0.6.0 // indirect github.com/prometheus/common v0.45.0 // indirect github.com/prometheus/procfs v0.12.0 // indirect github.com/smartystreets/assertions v1.0.1 // indirect github.com/spf13/pflag v1.0.5 // indirect - go.uber.org/goleak v1.1.12 // indirect go.uber.org/multierr v1.8.0 // indirect - golang.org/x/sys v0.18.0 // indirect - google.golang.org/genproto v0.0.0-20230711160842-782d3b101e98 // indirect + golang.org/x/sys v0.20.0 // indirect gopkg.in/natefinch/lumberjack.v2 v2.0.0 // indirect - gopkg.in/yaml.v3 v3.0.1 // indirect + gopkg.in/yaml.v3 v3.0.1 ) require ( github.com/DATA-DOG/go-sqlmock v1.5.0 - github.com/polarismesh/specification v1.5.2-0.20240722103923-1d9990d6f555 + github.com/polarismesh/specification v1.5.3-alpha.2 ) -require github.com/matttproud/golang_protobuf_extensions/v2 v2.0.0 // indirect +require ( + cel.dev/expr v0.15.0 // indirect + github.com/matttproud/golang_protobuf_extensions/v2 v2.0.0 // indirect + github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10 // indirect +) require ( github.com/dlclark/regexp2 v1.10.0 go.etcd.io/bbolt v1.3.7 - google.golang.org/genproto/googleapis/api v0.0.0-20230711160842-782d3b101e98 // indirect - google.golang.org/genproto/googleapis/rpc v0.0.0-20230711160842-782d3b101e98 // indirect + google.golang.org/genproto/googleapis/api v0.0.0-20240528184218-531527333157 // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20240528184218-531527333157 // indirect ) replace gopkg.in/yaml.v2 => gopkg.in/yaml.v2 v2.2.2 diff --git a/go.sum b/go.sum index 9310d5282..ef02521f5 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,5 @@ +cel.dev/expr v0.15.0 h1:O1jzfJCQBfL5BFoYktaxwIhuttaQPsVWerH9/EEKx0w= +cel.dev/expr v0.15.0/go.mod h1:TRSuuV7DlVCE/uwv5QbAiW/v8l5O8C4eEPHeu7gf7Sg= cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= cloud.google.com/go v0.38.0/go.mod h1:990N+gfupTy94rShfmMCWGDn0LpTmnzTp2qbd1dvSRU= @@ -61,8 +63,8 @@ github.com/bketelsen/crypt v0.0.4/go.mod h1:aI6NrJ0pMGgvZKL1iVgXLnfIFJtfV+bKCoqO github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= github.com/census-instrumentation/opencensus-proto v0.4.1 h1:iKLQ0xPNFxR/2hzXZMrBo8f1j86j5WHzznCCQxV/b8g= github.com/census-instrumentation/opencensus-proto v0.4.1/go.mod h1:4T9NM4+4Vw91VeyqjLS6ao50K5bOcLKN6Q42XnYaRYw= -github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44= -github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= +github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI= github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI= github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU= @@ -70,8 +72,8 @@ github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDk github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= github.com/cncf/udpa/go v0.0.0-20200629203442-efcf912fb354/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk= github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk= -github.com/cncf/xds/go v0.0.0-20230607035331-e9ce68804cb4 h1:/inchEIKaYC1Akx+H+gqO04wryn5h75LSazbRlnya1k= -github.com/cncf/xds/go v0.0.0-20230607035331-e9ce68804cb4/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= +github.com/cncf/xds/go v0.0.0-20240423153145-555b57ec207b h1:ga8SEFjZ60pxLcmhnThWgvH2wg8376yUJmPhEH4H3kw= +github.com/cncf/xds/go v0.0.0-20240423153145-555b57ec207b/go.mod h1:W+zGtBO5Y1IgJhy4+A9GOqVhqLpfZi+vwmdNXUehLA8= github.com/coreos/go-semver v0.3.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk= github.com/coreos/go-systemd/v22 v22.3.2/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/cpuguy83/go-md2man/v2 v2.0.0/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU= @@ -92,11 +94,11 @@ github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1m github.com/envoyproxy/go-control-plane v0.9.7/go.mod h1:cwu0lG7PUMfa9snN8LXBig5ynNVH9qI8YYLbd1fK2po= github.com/envoyproxy/go-control-plane v0.9.9-0.20201210154907-fd9021fe5dad/go.mod h1:cXg6YxExXjJnVBQHBLXeUAgxn2UodCpnH306RInaBQk= github.com/envoyproxy/go-control-plane v0.9.9-0.20210217033140-668b12f5399d/go.mod h1:cXg6YxExXjJnVBQHBLXeUAgxn2UodCpnH306RInaBQk= -github.com/envoyproxy/go-control-plane v0.12.0 h1:4X+VP1GHd1Mhj6IB5mMeGbLCleqxjletLK6K0rbxyZI= -github.com/envoyproxy/go-control-plane v0.12.0/go.mod h1:ZBTaoJ23lqITozF0M6G4/IragXCQKCnYbmlmtHvwRG0= +github.com/envoyproxy/go-control-plane v0.13.0 h1:HzkeUz1Knt+3bK+8LG1bxOO/jzWZmdxpwC51i202les= +github.com/envoyproxy/go-control-plane v0.13.0/go.mod h1:GRaKG3dwvFoTg4nj7aXdZnvMg4d7nvT/wl9WgVXn3Q8= github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= -github.com/envoyproxy/protoc-gen-validate v1.0.2 h1:QkIBuU5k+x7/QXPvPPnWXWlCdaBFApVqftFV6k087DA= -github.com/envoyproxy/protoc-gen-validate v1.0.2/go.mod h1:GpiZQP3dDbg4JouG/NNS7QWXpgx6x8QiMKdmN72jogE= +github.com/envoyproxy/protoc-gen-validate v1.0.4 h1:gVPz/FMfvh57HdSJQyvBtF00j8JU4zdyUgIUNhlgg0A= +github.com/envoyproxy/protoc-gen-validate v1.0.4/go.mod h1:qys6tmnRsYrQqIhm2bvKZH4Blx/1gTIZ2UKVY1M+Yew= github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5KwzbycvMj4= github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4= github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= @@ -154,8 +156,8 @@ github.com/golang/protobuf v1.4.3/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= github.com/golang/protobuf v1.5.1/go.mod h1:DopwsBzvsk0Fs44TXzsVbJyPhcCPeIwnvohx4u74HPM= github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= -github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg= -github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= +github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= +github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= @@ -169,8 +171,8 @@ github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/ github.com/google/go-cmp v0.5.3/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.4/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= -github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= +github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs= github.com/google/martian/v3 v3.0.0/go.mod h1:y5Zk1BBys9G+gd6Jrk0W3cC1+ELVxBWuIGO+w/tUAp0= @@ -188,8 +190,8 @@ github.com/google/pprof v0.0.0-20210122040257-d980be63207e/go.mod h1:kpwsk12EmLe github.com/google/pprof v0.0.0-20210226084205-cbba55b83ad5/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE= github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= -github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= -github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg= github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk= github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= @@ -292,18 +294,20 @@ github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINE github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/sftp v1.10.1/go.mod h1:lYOWFsE0bwd1+KfKJaKeuokY15vzFx25BLbzYYoAxZI= +github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10 h1:GFCKgmp0tecUJ0sJuv4pzYCqS9+RGSn52M3FUwPs+uo= +github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10/go.mod h1:t/avpk3KcrXxUnYOhZhMXJlSEyie6gQbtLq5NM3loB8= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/polarismesh/go-restful-openapi/v2 v2.0.0-20220928152401-083908d10219 h1:XnFyNUWnciM6zgXaz6tm+Egs35rhoD0KGMmKh4gCdi0= github.com/polarismesh/go-restful-openapi/v2 v2.0.0-20220928152401-083908d10219/go.mod h1:4WhwBysTom9Eoy0hQ4W69I0FmO+T0EpjEW9/5sgHoUk= -github.com/polarismesh/specification v1.5.2-0.20240722103923-1d9990d6f555 h1:eLZzl6yaeuQbTRYeZDu9cqrkJt7qlrt+fE7hGB+R3bc= -github.com/polarismesh/specification v1.5.2-0.20240722103923-1d9990d6f555/go.mod h1:rDvMMtl5qebPmqiBLNa5Ps0XtwkP31ZLirbH4kXA0YU= +github.com/polarismesh/specification v1.5.3-alpha.2 h1:QSgpGmx5VfPcDPAq7qnTOkMVFNpmBMgLSDhtyMlS6/g= +github.com/polarismesh/specification v1.5.3-alpha.2/go.mod h1:rDvMMtl5qebPmqiBLNa5Ps0XtwkP31ZLirbH4kXA0YU= github.com/posener/complete v1.1.1/go.mod h1:em0nMJCgc9GFtwrmVmEMR/ZL6WyhyjMBndrE9hABlRI= github.com/prometheus/client_golang v1.18.0 h1:HzFfmkOzH5Q8L8G+kSJKUx5dtG87sewO+FoDDqP5Tbk= github.com/prometheus/client_golang v1.18.0/go.mod h1:T+GXkCk5wSJyOqMIzVgvvjFDlkOQntgjkJWKrN5txjA= github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= -github.com/prometheus/client_model v0.5.0 h1:VQw1hfvPvk3Uv6Qf29VrPF32JB6rtbgI6cYPYQjL0Qw= -github.com/prometheus/client_model v0.5.0/go.mod h1:dTiFglRmd66nLR9Pv9f0mZi7B7fk5Pm3gvsjB5tr+kI= +github.com/prometheus/client_model v0.6.0 h1:k1v3CzpSRUTrKMppY35TLwPvxHqBu0bYgxZzqGIgaos= +github.com/prometheus/client_model v0.6.0/go.mod h1:NTQHnmxFpouOD0DpvP4XujX3CdOAGQPoaGhyTchlyt8= github.com/prometheus/common v0.45.0 h1:2BGz0eBc2hdMDLnO/8n0jeB3oPrt2D08CekT0lneoxM= github.com/prometheus/common v0.45.0/go.mod h1:YJmSTw9BoKxJplESWWxlbyttQR4uaEcGyv9MZjVOJsY= github.com/prometheus/procfs v0.12.0 h1:jluTpSng7V9hY0O2R9DzzJHYb2xULk9VTR1V1R/k6Bo= @@ -340,8 +344,8 @@ github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81P github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= -github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/subosito/gotenv v1.2.0/go.mod h1:N0PQaV/YGNqwC0u51sEeR/aUtSLEXKX9iv69rRypqCw= github.com/yuin/goldmark v1.1.25/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= @@ -365,8 +369,8 @@ go.uber.org/atomic v1.10.0 h1:9qC72Qh0+3MqyJbAn8YU5xVq1frD8bn3JtD2oXtafVQ= go.uber.org/atomic v1.10.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= go.uber.org/automaxprocs v1.4.0 h1:CpDZl6aOlLhReez+8S3eEotD7Jx0Os++lemPlMULQP0= go.uber.org/automaxprocs v1.4.0/go.mod h1:/mTEdr7LvHhs0v7mjdxDreTz1OG5zdZGqgOnhWiR/+Q= -go.uber.org/goleak v1.1.12 h1:gZAh5/EyT/HQwlpkCy6wTpqfH9H8Lz8zbm3dZh+OyzA= -go.uber.org/goleak v1.1.12/go.mod h1:cwTWslyiVhfpKIDGSZEM2HlOvcqm+tG4zioyIeLoqMQ= +go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= +go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= go.uber.org/multierr v1.6.0/go.mod h1:cdWPpRnG4AhwMwsgIHip0KRBQjJy5kYEpYjJxpXp9iU= go.uber.org/multierr v1.8.0 h1:dg6GjLku4EH+249NNmoIciG9N/jURbDG+pFlTkhzIC8= go.uber.org/multierr v1.8.0/go.mod h1:7EAYxJLBy9rStEaz58O2t4Uvip6FSURkq8/ppBp95ak= @@ -380,8 +384,8 @@ golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5/go.mod h1:yigFU9vqHzYiE8U golang.org/x/crypto v0.0.0-20190820162420-60c769a6c586/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= -golang.org/x/crypto v0.21.0 h1:X31++rzVUdKhX5sWmSOFZxx8UW/ldWx55cbf08iNAMA= -golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs= +golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI= +golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8= @@ -454,8 +458,8 @@ golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v golang.org/x/net v0.0.0-20210316092652-d523dce5a7f4/go.mod h1:RBQZq4jEuRlivfhVLdyRGr576XBO4/greRjx4P4O3yc= golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= golang.org/x/net v0.0.0-20210421230115-4e50805a0758/go.mod h1:72T/g9IO56b78aLF+1Kcs5dz7/ng1VjMUvfKvpfy+jM= -golang.org/x/net v0.23.0 h1:7EYJ93RZ9vYSZAIb2x3lnuvqO5zneoD6IvWjuhfxjTs= -golang.org/x/net v0.23.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg= +golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac= +golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= @@ -479,8 +483,8 @@ golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.6.0 h1:5BMeUDZ7vkXGfEr1x9B4bRcTH4lpkTkpdh0T/J+qjbQ= -golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= +golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20181026203630-95b1ffbd15a5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -524,8 +528,8 @@ golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20210403161142-5e06dd20ab57/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210420072515-93ed5bcd2bfe/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4= -golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y= +golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= @@ -536,8 +540,8 @@ golang.org/x/text v0.3.4/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= -golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= -golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= +golang.org/x/text v0.15.0 h1:h1V/4gjBv8v9cjcR6+AR5+/cIYK5N/WAgiv4xlsEtAk= +golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= @@ -595,7 +599,6 @@ golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4f golang.org/x/tools v0.1.0/go.mod h1:xkSsbof2nBLbhDlRMhhhyNLN/zl3eTqcnHD5viDpcZ0= golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= golang.org/x/tools v0.1.2/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= -golang.org/x/tools v0.1.5/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= @@ -670,12 +673,10 @@ google.golang.org/genproto v0.0.0-20210310155132-4ce2db91004e/go.mod h1:FWY/as6D google.golang.org/genproto v0.0.0-20210319143718-93e7006c17a6/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= google.golang.org/genproto v0.0.0-20210402141018-6c239bbf2bb1/go.mod h1:9lPAdzaEmUacj36I+k7YKbEc5CXzPIeORRgDAUOu28A= google.golang.org/genproto v0.0.0-20210602131652-f16073e35f0c/go.mod h1:UODoCrxHCcBojKKwX1terBiRUaqAsFqJiF615XL43r0= -google.golang.org/genproto v0.0.0-20230711160842-782d3b101e98 h1:Z0hjGZePRE0ZBWotvtrwxFNrNE9CUAGtplaDK5NNI/g= -google.golang.org/genproto v0.0.0-20230711160842-782d3b101e98/go.mod h1:S7mY02OqCJTD0E1OiQy1F72PWFB4bZJ87cAtLPYgDR0= -google.golang.org/genproto/googleapis/api v0.0.0-20230711160842-782d3b101e98 h1:FmF5cCW94Ij59cfpoLiwTgodWmm60eEV0CjlsVg2fuw= -google.golang.org/genproto/googleapis/api v0.0.0-20230711160842-782d3b101e98/go.mod h1:rsr7RhLuwsDKL7RmgDDCUc6yaGr1iqceVb5Wv6f6YvQ= -google.golang.org/genproto/googleapis/rpc v0.0.0-20230711160842-782d3b101e98 h1:bVf09lpb+OJbByTj913DRJioFFAjf/ZGxEz7MajTp2U= -google.golang.org/genproto/googleapis/rpc v0.0.0-20230711160842-782d3b101e98/go.mod h1:TUfxEVdsvPg18p6AslUXFoLdpED4oBnGwyqk3dV1XzM= +google.golang.org/genproto/googleapis/api v0.0.0-20240528184218-531527333157 h1:7whR9kGa5LUwFtpLm2ArCEejtnxlGeLbAyjFY8sGNFw= +google.golang.org/genproto/googleapis/api v0.0.0-20240528184218-531527333157/go.mod h1:99sLkeliLXfdj2J75X3Ho+rrVCaJze0uwN7zDDkjPVU= +google.golang.org/genproto/googleapis/rpc v0.0.0-20240528184218-531527333157 h1:Zy9XzmMEflZ/MAaA7vNcoebnRAld7FsPW1EeBB7V0m8= +google.golang.org/genproto/googleapis/rpc v0.0.0-20240528184218-531527333157/go.mod h1:EfXuqaE1J41VCDicxHzUDm+8rk+7ZdXzHV0IhO/I6s0= google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= google.golang.org/grpc v1.20.1/go.mod h1:10oTOabMzJvdu6/UiuZezV6QK5dSlG84ov/aaiqXj38= google.golang.org/grpc v1.21.1/go.mod h1:oYelfM1adQP15Ek0mdvEgi9Df8B9CZIaU1084ijfRaM= @@ -696,8 +697,8 @@ google.golang.org/grpc v1.35.0/go.mod h1:qjiiYl8FncCW8feJPdyg3v6XW24KsRHe+dy9BAG google.golang.org/grpc v1.36.0/go.mod h1:qjiiYl8FncCW8feJPdyg3v6XW24KsRHe+dy9BAGRRjU= google.golang.org/grpc v1.36.1/go.mod h1:qjiiYl8FncCW8feJPdyg3v6XW24KsRHe+dy9BAGRRjU= google.golang.org/grpc v1.38.0/go.mod h1:NREThFqKR1f3iQ6oBuvc5LadQuXVGo9rkm5ZGrQdJfM= -google.golang.org/grpc v1.58.3 h1:BjnpXut1btbtgN/6sp+brB2Kbm2LjNXnidYujAVbSoQ= -google.golang.org/grpc v1.58.3/go.mod h1:tgX3ZQDlNJGU96V6yHh1T/JeoBQ2TXdr43YbYSsCJk0= +google.golang.org/grpc v1.65.0 h1:bs/cUb4lp1G5iImFFd3u5ixQzweKizoZJAwBNLR42lc= +google.golang.org/grpc v1.65.0/go.mod h1:WgYC2ypjlB0EiQi6wdKixMqukr6lBc0Vo+oOgjrM5ZQ= google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= @@ -710,8 +711,8 @@ google.golang.org/protobuf v1.24.0/go.mod h1:r/3tXBNzIEhYS9I1OUVjXDlt8tc493IdKGj google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= -google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI= -google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= +google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg= +google.golang.org/protobuf v1.34.1/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/import-format.sh b/import-format.sh index a6335c62c..9807c399e 100644 --- a/import-format.sh +++ b/import-format.sh @@ -17,7 +17,7 @@ # 格式化 go.mod go mod tidy -compat=1.17 -docker run -t --rm -v $(pwd):/app -w /app golangci/golangci-lint:v1.55.2 golangci-lint run -v +docker run -t --rm -v $(pwd):/app -w /app golangci/golangci-lint:v1.61.0 golangci-lint run -v --timeout 30m # 处理 go imports 的格式化 rm -rf style_tool diff --git a/namespace/api.go b/namespace/api.go index 84edaacb2..b6732c24a 100644 --- a/namespace/api.go +++ b/namespace/api.go @@ -30,8 +30,6 @@ type NamespaceOperateServer interface { CreateNamespace(ctx context.Context, req *apimodel.Namespace) *apiservice.Response // CreateNamespaces Batch creation namespace CreateNamespaces(ctx context.Context, req []*apimodel.Namespace) *apiservice.BatchWriteResponse - // DeleteNamespace Delete a single namespace - DeleteNamespace(ctx context.Context, req *apimodel.Namespace) *apiservice.Response // DeleteNamespaces Batch delete namespace DeleteNamespaces(ctx context.Context, req []*apimodel.Namespace) *apiservice.BatchWriteResponse // UpdateNamespaces Batch update naming space diff --git a/namespace/default.go b/namespace/default.go index bebab4e8b..a172c730b 100644 --- a/namespace/default.go +++ b/namespace/default.go @@ -20,14 +20,13 @@ package namespace import ( "context" "errors" + "fmt" "sync" "golang.org/x/sync/singleflight" - "github.com/polarismesh/polaris/auth" "github.com/polarismesh/polaris/cache" cachetypes "github.com/polarismesh/polaris/cache/api" - "github.com/polarismesh/polaris/plugin" "github.com/polarismesh/polaris/store" ) @@ -36,17 +35,36 @@ var ( namespaceServer = &Server{} once sync.Once finishInit bool + // serverProxyFactories Service Server API 代理工厂 + serverProxyFactories = map[string]ServerProxyFactory{} ) +type ServerProxyFactory func(context.Context, NamespaceOperateServer, cachetypes.CacheManager) (NamespaceOperateServer, error) + +func RegisterServerProxy(name string, factor ServerProxyFactory) error { + if _, ok := serverProxyFactories[name]; ok { + return fmt.Errorf("duplicate ServerProxyFactory, name(%s)", name) + } + serverProxyFactories[name] = factor + return nil +} + type Config struct { - AutoCreate bool `yaml:"autoCreate"` + AutoCreate bool `yaml:"autoCreate"` + Interceptors []string `yaml:"-"` } // Initialize 初始化 -func Initialize(ctx context.Context, nsOpt *Config, storage store.Store, cacheMgn *cache.CacheManager) error { +func Initialize(ctx context.Context, nsOpt *Config, storage store.Store, cacheMgr *cache.CacheManager) error { var err error once.Do(func() { - err = initialize(ctx, nsOpt, storage, cacheMgn) + actualSvr, proxySvr, err := InitServer(ctx, nsOpt, storage, cacheMgr) + if err != nil { + return + } + namespaceServer = actualSvr + server = proxySvr + return }) if err != nil { @@ -57,35 +75,36 @@ func Initialize(ctx context.Context, nsOpt *Config, storage store.Store, cacheMg return nil } -func initialize(_ context.Context, nsOpt *Config, storage store.Store, cacheMgn *cache.CacheManager) error { - if err := cacheMgn.OpenResourceCache(cachetypes.ConfigEntry{ +func InitServer(ctx context.Context, nsOpt *Config, storage store.Store, + cacheMgr *cache.CacheManager) (*Server, NamespaceOperateServer, error) { + if err := cacheMgr.OpenResourceCache(cachetypes.ConfigEntry{ Name: cachetypes.NamespaceName, }); err != nil { - return err - } - namespaceServer.caches = cacheMgn - namespaceServer.storage = storage - namespaceServer.cfg = *nsOpt - namespaceServer.createNamespaceSingle = &singleflight.Group{} - - // 获取History插件,注意:插件的配置在bootstrap已经设置好 - namespaceServer.history = plugin.GetHistory() - if namespaceServer.history == nil { - log.Warn("Not Found History Log Plugin") - } - - userMgn, err := auth.GetUserServer() - if err != nil { - return err + return nil, nil, err } - strategyMgn, err := auth.GetStrategyServer() - if err != nil { - return err + actualSvr := new(Server) + actualSvr.caches = cacheMgr + actualSvr.storage = storage + actualSvr.cfg = *nsOpt + actualSvr.createNamespaceSingle = &singleflight.Group{} + + var proxySvr NamespaceOperateServer + proxySvr = actualSvr + order := GetChainOrder() + for i := range order { + factory, exist := serverProxyFactories[order[i]] + if !exist { + return nil, nil, fmt.Errorf("name(%s) not exist in serverProxyFactories", order[i]) + } + + afterSvr, err := factory(ctx, proxySvr, cacheMgr) + if err != nil { + return nil, nil, err + } + proxySvr = afterSvr } - - server = newServerAuthAbility(namespaceServer, userMgn, strategyMgn) - return nil + return actualSvr, proxySvr, nil } // GetServer 获取已经初始化好的Server @@ -105,3 +124,9 @@ func GetOriginServer() (*Server, error) { return namespaceServer, nil } + +func GetChainOrder() []string { + return []string{ + "auth", + } +} diff --git a/namespace/interceptor/auth/log.go b/namespace/interceptor/auth/log.go new file mode 100644 index 000000000..78a77e607 --- /dev/null +++ b/namespace/interceptor/auth/log.go @@ -0,0 +1,26 @@ +/** + * Tencent is pleased to support the open source community by making Polaris available. + * + * Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. + * + * Licensed under the BSD 3-Clause License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://opensource.org/licenses/BSD-3-Clause + * + * Unless required by applicable law or agreed to in writing, software distributed + * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR + * CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package auth + +import ( + commonlog "github.com/polarismesh/polaris/common/log" +) + +var ( + authLog = commonlog.GetScopeOrDefaultByName(commonlog.AuthLoggerName) +) diff --git a/namespace/resource_listener.go b/namespace/interceptor/auth/resource_listener.go similarity index 72% rename from namespace/resource_listener.go rename to namespace/interceptor/auth/resource_listener.go index 0fb7268d3..9ec120d94 100644 --- a/namespace/resource_listener.go +++ b/namespace/interceptor/auth/resource_listener.go @@ -15,48 +15,27 @@ * specific language governing permissions and limitations under the License. */ -package namespace +package auth import ( "context" - apimodel "github.com/polarismesh/specification/source/go/api/v1/model" apisecurity "github.com/polarismesh/specification/source/go/api/v1/security" + "github.com/polarismesh/polaris/common/log" "github.com/polarismesh/polaris/common/model" authcommon "github.com/polarismesh/polaris/common/model/auth" "github.com/polarismesh/polaris/common/utils" + "github.com/polarismesh/polaris/namespace" ) -// ResourceHook The listener is placed before and after the resource operation, only normal flow -type ResourceHook interface { - - // Before - // @param ctx - // @param resourceType - Before(ctx context.Context, resourceType model.Resource) - - // After - // @param ctx - // @param resourceType - // @param res - After(ctx context.Context, resourceType model.Resource, res *ResourceEvent) error -} - -// ResourceEvent 资源事件 -type ResourceEvent struct { - ReqNamespace *apimodel.Namespace - Namespace *model.Namespace - IsRemove bool -} - // Before this function is called before the resource operation -func (svr *serverAuthAbility) Before(ctx context.Context, resourceType model.Resource) { +func (svr *Server) Before(ctx context.Context, resourceType model.Resource) { // do nothing } // After this function is called after the resource operation -func (svr *serverAuthAbility) After(ctx context.Context, resourceType model.Resource, res *ResourceEvent) error { +func (svr *Server) After(ctx context.Context, resourceType model.Resource, res *namespace.ResourceEvent) error { switch resourceType { case model.RNamespace: return svr.onNamespaceResource(ctx, res) @@ -66,7 +45,7 @@ func (svr *serverAuthAbility) After(ctx context.Context, resourceType model.Reso } // onNamespaceResource -func (svr *serverAuthAbility) onNamespaceResource(ctx context.Context, res *ResourceEvent) error { +func (svr *Server) onNamespaceResource(ctx context.Context, res *namespace.ResourceEvent) error { authCtx, _ := ctx.Value(utils.ContextAuthContextKey).(*authcommon.AcquireContext) if authCtx == nil { log.Warn("[Namespace][ResourceHook] get auth context is nil, ignore", utils.RequestID(ctx)) diff --git a/namespace/namespace_authability.go b/namespace/interceptor/auth/server.go similarity index 50% rename from namespace/namespace_authability.go rename to namespace/interceptor/auth/server.go index 98209e163..2421a14f0 100644 --- a/namespace/namespace_authability.go +++ b/namespace/interceptor/auth/server.go @@ -15,7 +15,7 @@ * specific language governing permissions and limitations under the License. */ -package namespace +package auth import ( "context" @@ -23,29 +23,56 @@ import ( apimodel "github.com/polarismesh/specification/source/go/api/v1/model" apisecurity "github.com/polarismesh/specification/source/go/api/v1/security" apiservice "github.com/polarismesh/specification/source/go/api/v1/service_manage" + "go.uber.org/zap" + "github.com/polarismesh/polaris/auth" cachetypes "github.com/polarismesh/polaris/cache/api" api "github.com/polarismesh/polaris/common/api/v1" "github.com/polarismesh/polaris/common/model" authcommon "github.com/polarismesh/polaris/common/model/auth" "github.com/polarismesh/polaris/common/utils" + "github.com/polarismesh/polaris/namespace" ) -var _ NamespaceOperateServer = (*serverAuthAbility)(nil) +var _ namespace.NamespaceOperateServer = (*Server)(nil) + +// Server 带有鉴权能力的 NamespaceOperateServer +// 该层会对请求参数做一些调整,根据具体的请求发起人,设置为数据对应的 owner,不可为为别人进行创建资源 +type Server struct { + nextSvr namespace.NamespaceOperateServer + userSvr auth.UserServer + policySvr auth.StrategyServer + cacheSvr cachetypes.CacheManager +} + +func NewServer(nextSvr namespace.NamespaceOperateServer, userSvr auth.UserServer, + policySvr auth.StrategyServer, cacheSvr cachetypes.CacheManager) namespace.NamespaceOperateServer { + proxy := &Server{ + nextSvr: nextSvr, + userSvr: userSvr, + policySvr: policySvr, + cacheSvr: cacheSvr, + } + + if actualSvr, ok := nextSvr.(*namespace.Server); ok { + actualSvr.SetResourceHooks(proxy) + } + return proxy +} // CreateNamespaceIfAbsent Create a single name space -func (svr *serverAuthAbility) CreateNamespaceIfAbsent(ctx context.Context, +func (svr *Server) CreateNamespaceIfAbsent(ctx context.Context, req *apimodel.Namespace) (string, *apiservice.Response) { - return svr.targetServer.CreateNamespaceIfAbsent(ctx, req) + return svr.nextSvr.CreateNamespaceIfAbsent(ctx, req) } // CreateNamespace 创建命名空间,只需要要后置鉴权,将数据添加到资源策略中 -func (svr *serverAuthAbility) CreateNamespace(ctx context.Context, req *apimodel.Namespace) *apiservice.Response { +func (svr *Server) CreateNamespace(ctx context.Context, req *apimodel.Namespace) *apiservice.Response { authCtx := svr.collectNamespaceAuthContext( ctx, []*apimodel.Namespace{req}, authcommon.Create, authcommon.CreateNamespace) // 验证 token 信息 if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewResponseWithMsg(convertToErrCode(err), err.Error()) + return api.NewResponse(authcommon.ConvertToErrCode(err)) } ctx = authCtx.GetRequestContext() @@ -56,17 +83,17 @@ func (svr *serverAuthAbility) CreateNamespace(ctx context.Context, req *apimodel req.Owners = utils.NewStringValue(ownerId) } - return svr.targetServer.CreateNamespace(ctx, req) + return svr.nextSvr.CreateNamespace(ctx, req) } // CreateNamespaces 创建命名空间,只需要要后置鉴权,将数据添加到资源策略中 -func (svr *serverAuthAbility) CreateNamespaces( +func (svr *Server) CreateNamespaces( ctx context.Context, reqs []*apimodel.Namespace) *apiservice.BatchWriteResponse { authCtx := svr.collectNamespaceAuthContext(ctx, reqs, authcommon.Create, authcommon.CreateNamespaces) // 验证 token 信息 if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewBatchWriteResponseWithMsg(convertToErrCode(err), err.Error()) + return api.NewBatchWriteResponse(authcommon.ConvertToErrCode(err)) } ctx = authCtx.GetRequestContext() @@ -80,98 +107,155 @@ func (svr *serverAuthAbility) CreateNamespaces( req.Owners = utils.NewStringValue(ownerId) } } - - return svr.targetServer.CreateNamespaces(ctx, reqs) -} - -// DeleteNamespace 删除命名空间,需要先走权限检查 -func (svr *serverAuthAbility) DeleteNamespace(ctx context.Context, req *apimodel.Namespace) *apiservice.Response { - authCtx := svr.collectNamespaceAuthContext( - ctx, []*apimodel.Namespace{req}, authcommon.Delete, authcommon.DeleteNamespace) - if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewResponseWithMsg(convertToErrCode(err), err.Error()) - } - - ctx = authCtx.GetRequestContext() - ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) - - return svr.targetServer.DeleteNamespace(ctx, req) + return svr.nextSvr.CreateNamespaces(ctx, reqs) } // DeleteNamespaces 删除命名空间,需要先走权限检查 -func (svr *serverAuthAbility) DeleteNamespaces( +func (svr *Server) DeleteNamespaces( ctx context.Context, reqs []*apimodel.Namespace) *apiservice.BatchWriteResponse { authCtx := svr.collectNamespaceAuthContext(ctx, reqs, authcommon.Delete, authcommon.DeleteNamespaces) if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewBatchWriteResponseWithMsg(convertToErrCode(err), err.Error()) + return api.NewBatchWriteResponse(authcommon.ConvertToErrCode(err)) } ctx = authCtx.GetRequestContext() ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) - return svr.targetServer.DeleteNamespaces(ctx, reqs) + return svr.nextSvr.DeleteNamespaces(ctx, reqs) } // UpdateNamespaces 更新命名空间,需要先走权限检查 -func (svr *serverAuthAbility) UpdateNamespaces( +func (svr *Server) UpdateNamespaces( ctx context.Context, req []*apimodel.Namespace) *apiservice.BatchWriteResponse { authCtx := svr.collectNamespaceAuthContext(ctx, req, authcommon.Modify, authcommon.UpdateNamespaces) if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewBatchWriteResponseWithMsg(convertToErrCode(err), err.Error()) + return api.NewBatchWriteResponse(authcommon.ConvertToErrCode(err)) } ctx = authCtx.GetRequestContext() ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) - return svr.targetServer.UpdateNamespaces(ctx, req) + return svr.nextSvr.UpdateNamespaces(ctx, req) } // UpdateNamespaceToken 更新命名空间的token信息,需要先走权限检查 -func (svr *serverAuthAbility) UpdateNamespaceToken(ctx context.Context, req *apimodel.Namespace) *apiservice.Response { +func (svr *Server) UpdateNamespaceToken(ctx context.Context, req *apimodel.Namespace) *apiservice.Response { authCtx := svr.collectNamespaceAuthContext( ctx, []*apimodel.Namespace{req}, authcommon.Modify, authcommon.UpdateNamespaceToken) if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewResponseWithMsg(convertToErrCode(err), err.Error()) + return api.NewResponse(authcommon.ConvertToErrCode(err)) } ctx = authCtx.GetRequestContext() ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) - return svr.targetServer.UpdateNamespaceToken(ctx, req) + return svr.nextSvr.UpdateNamespaceToken(ctx, req) } // GetNamespaces 获取命名空间列表信息,暂时不走权限检查 -func (svr *serverAuthAbility) GetNamespaces( +func (svr *Server) GetNamespaces( ctx context.Context, query map[string][]string) *apiservice.BatchQueryResponse { authCtx := svr.collectNamespaceAuthContext(ctx, nil, authcommon.Read, authcommon.DescribeNamespaces) - _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx) - if err != nil { - return api.NewBatchQueryResponseWithMsg(convertToErrCode(err), err.Error()) + if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + return api.NewBatchQueryResponse(authcommon.ConvertToErrCode(err)) } ctx = authCtx.GetRequestContext() ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) - cachetypes.AppendNamespacePredicate(ctx, func(ctx context.Context, n *model.Namespace) bool { + ctx = cachetypes.AppendNamespacePredicate(ctx, func(ctx context.Context, n *model.Namespace) bool { return svr.policySvr.GetAuthChecker().ResourcePredicate(authCtx, &authcommon.ResourceEntry{ - Type: apisecurity.ResourceType_Users, - ID: n.Name, + Type: apisecurity.ResourceType_Namespaces, + ID: n.Name, + Metadata: n.Metadata, }) }) - return svr.targetServer.GetNamespaces(ctx, query) + authCtx.SetRequestContext(ctx) + resp := svr.nextSvr.GetNamespaces(ctx, query) + for i := range resp.Namespaces { + item := resp.Namespaces[i] + authCtx.SetAccessResources(map[apisecurity.ResourceType][]authcommon.ResourceEntry{ + apisecurity.ResourceType_Namespaces: { + { + Type: apisecurity.ResourceType_Namespaces, + ID: item.GetId().GetValue(), + Metadata: item.GetMetadata(), + }, + }, + }) + + // 检查 write 操作权限 + authCtx.SetMethod([]authcommon.ServerFunctionName{authcommon.UpdateNamespaces}) + // 如果检查不通过,设置 editable 为 false + if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + item.Editable = utils.NewBoolValue(false) + } + + // 检查 delete 操作权限 + authCtx.SetMethod([]authcommon.ServerFunctionName{authcommon.DeleteNamespaces}) + // 如果检查不通过,设置 editable 为 false + if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + item.Deleteable = utils.NewBoolValue(false) + } + } + return resp } // GetNamespaceToken 获取命名空间的token信息,暂时不走权限检查 -func (svr *serverAuthAbility) GetNamespaceToken(ctx context.Context, req *apimodel.Namespace) *apiservice.Response { +func (svr *Server) GetNamespaceToken(ctx context.Context, req *apimodel.Namespace) *apiservice.Response { authCtx := svr.collectNamespaceAuthContext( ctx, []*apimodel.Namespace{req}, authcommon.Read, authcommon.DescribeNamespaceToken) _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx) if err != nil { - return api.NewResponseWithMsg(convertToErrCode(err), err.Error()) + return api.NewResponse(authcommon.ConvertToErrCode(err)) } ctx = authCtx.GetRequestContext() ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) - return svr.targetServer.GetNamespaceToken(ctx, req) + return svr.nextSvr.GetNamespaceToken(ctx, req) +} + +// collectNamespaceAuthContext 对于命名空间的处理,收集所有的与鉴权的相关信息 +func (svr *Server) collectNamespaceAuthContext(ctx context.Context, req []*apimodel.Namespace, + resourceOp authcommon.ResourceOperation, methodName authcommon.ServerFunctionName) *authcommon.AcquireContext { + return authcommon.NewAcquireContext( + authcommon.WithRequestContext(ctx), + authcommon.WithOperation(resourceOp), + authcommon.WithModule(authcommon.CoreModule), + authcommon.WithMethod(methodName), + authcommon.WithAccessResources(svr.queryNamespaceResource(req)), + ) +} + +// queryNamespaceResource 根据所给的 namespace 信息,收集对应的 ResourceEntry 列表 +func (svr *Server) queryNamespaceResource( + req []*apimodel.Namespace) map[apisecurity.ResourceType][]authcommon.ResourceEntry { + if len(req) == 0 { + return map[apisecurity.ResourceType][]authcommon.ResourceEntry{} + } + + names := utils.NewSet[string]() + for index := range req { + names.Add(req[index].Name.GetValue()) + } + param := names.ToSlice() + nsArr := svr.cacheSvr.Namespace().GetNamespacesByName(param) + + temp := make([]authcommon.ResourceEntry, 0, len(nsArr)) + + for index := range nsArr { + ns := nsArr[index] + temp = append(temp, authcommon.ResourceEntry{ + Type: apisecurity.ResourceType_Namespaces, + ID: ns.Name, + Owner: ns.Owner, + }) + } + + ret := map[apisecurity.ResourceType][]authcommon.ResourceEntry{ + apisecurity.ResourceType_Namespaces: temp, + } + authLog.Debug("[Auth][Server] collect namespace access res", zap.Any("res", ret)) + return ret } diff --git a/namespace/interceptor/register.go b/namespace/interceptor/register.go new file mode 100644 index 000000000..f3925a25b --- /dev/null +++ b/namespace/interceptor/register.go @@ -0,0 +1,68 @@ +/** + * Tencent is pleased to support the open source community by making Polaris available. + * + * Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. + * + * Licensed under the BSD 3-Clause License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://opensource.org/licenses/BSD-3-Clause + * + * Unless required by applicable law or agreed to in writing, software distributed + * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR + * CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package interceptor + +import ( + "context" + + "github.com/polarismesh/polaris/auth" + cachetypes "github.com/polarismesh/polaris/cache/api" + "github.com/polarismesh/polaris/namespace" + ns_auth "github.com/polarismesh/polaris/namespace/interceptor/auth" +) + +type ( + ContextKeyUserSvr struct{} + ContextKeyPolicySvr struct{} +) + +func init() { + err := namespace.RegisterServerProxy("auth", func(ctx context.Context, + pre namespace.NamespaceOperateServer, cacheSvr cachetypes.CacheManager) (namespace.NamespaceOperateServer, error) { + + var userSvr auth.UserServer + var policySvr auth.StrategyServer + + userSvrVal := ctx.Value(ContextKeyUserSvr{}) + if userSvrVal == nil { + svr, err := auth.GetUserServer() + if err != nil { + return nil, err + } + userSvr = svr + } else { + userSvr = userSvrVal.(auth.UserServer) + } + + policySvrVal := ctx.Value(ContextKeyPolicySvr{}) + if policySvrVal == nil { + svr, err := auth.GetStrategyServer() + if err != nil { + return nil, err + } + policySvr = svr + } else { + policySvr = policySvrVal.(auth.StrategyServer) + } + + return ns_auth.NewServer(pre, userSvr, policySvr, cacheSvr), nil + }) + if err != nil { + panic(err) + } +} diff --git a/namespace/namespace.go b/namespace/namespace.go index 07e9f10cc..83643a04e 100644 --- a/namespace/namespace.go +++ b/namespace/namespace.go @@ -19,7 +19,6 @@ package namespace import ( "context" - "fmt" "time" "github.com/golang/protobuf/jsonpb" @@ -27,6 +26,7 @@ import ( apiservice "github.com/polarismesh/specification/source/go/api/v1/service_manage" "go.uber.org/zap" + cachetypes "github.com/polarismesh/polaris/cache/api" api "github.com/polarismesh/polaris/common/api/v1" "github.com/polarismesh/polaris/common/model" commonstore "github.com/polarismesh/polaris/common/store" @@ -94,8 +94,6 @@ func (s *Server) CreateNamespaceIfAbsent(ctx context.Context, req *apimodel.Name // CreateNamespace 创建单个命名空间 func (s *Server) CreateNamespace(ctx context.Context, req *apimodel.Namespace) *apiservice.Response { - requestID, _ := ctx.Value(utils.StringContext("request-id")).(string) - // 参数检查 if checkError := checkCreateNamespace(req); checkError != nil { return checkError @@ -106,19 +104,18 @@ func (s *Server) CreateNamespace(ctx context.Context, req *apimodel.Namespace) * // 检查是否存在 namespace, err := s.storage.GetNamespace(namespaceName) if err != nil { - log.Error(err.Error(), utils.ZapRequestID(requestID)) + log.Error(err.Error(), utils.RequestID(ctx)) return api.NewNamespaceResponse(commonstore.StoreCode2APICode(err), req) } if namespace != nil { return api.NewNamespaceResponse(apimodel.Code_ExistedResource, req) } - // data := s.createNamespaceModel(req) // 存储层操作 if err := s.storage.AddNamespace(data); err != nil { - log.Error(err.Error(), utils.ZapRequestID(requestID)) + log.Error(err.Error(), utils.RequestID(ctx)) return api.NewNamespaceResponse(commonstore.StoreCode2APICode(err), req) } @@ -143,6 +140,7 @@ func (s *Server) createNamespaceModel(req *apimodel.Namespace) *model.Namespace Owner: req.GetOwners().GetValue(), Token: utils.NewUUID(), ServiceExportTo: model.ExportToMap(req.GetServiceExportTo()), + Metadata: req.GetMetadata(), } return namespace } @@ -164,8 +162,6 @@ func (s *Server) DeleteNamespaces(ctx context.Context, req []*apimodel.Namespace // DeleteNamespace 删除单个命名空间 func (s *Server) DeleteNamespace(ctx context.Context, req *apimodel.Namespace) *apiservice.Response { - requestID, _ := ctx.Value(utils.StringContext("request-id")).(string) - // 参数检查 if checkError := checkReviseNamespace(ctx, req); checkError != nil { return checkError @@ -173,7 +169,7 @@ func (s *Server) DeleteNamespace(ctx context.Context, req *apimodel.Namespace) * tx, err := s.storage.CreateTransaction() if err != nil { - log.Error(err.Error(), utils.ZapRequestID(requestID)) + log.Error(err.Error(), utils.RequestID(ctx)) return api.NewNamespaceResponse(commonstore.StoreCode2APICode(err), req) } defer func() { _ = tx.Commit() }() @@ -181,47 +177,38 @@ func (s *Server) DeleteNamespace(ctx context.Context, req *apimodel.Namespace) * // 检查是否存在 namespace, err := tx.LockNamespace(req.GetName().GetValue()) if err != nil { - log.Error(err.Error(), utils.ZapRequestID(requestID)) + log.Error(err.Error(), utils.RequestID(ctx)) return api.NewNamespaceResponse(commonstore.StoreCode2APICode(err), req) } if namespace == nil { return api.NewNamespaceResponse(apimodel.Code_ExecuteSuccess, req) } - // // 鉴权 - // if ok := s.authority.VerifyNamespace(namespace.Token, parseNamespaceToken(ctx, req)); !ok { - // return api.NewNamespaceResponse(api.Unauthorized, req) - // } - // 判断属于该命名空间的服务是否都已经被删除 total, err := s.getServicesCountWithNamespace(namespace.Name) if err != nil { - log.Error("get services count with namespace err", - utils.ZapRequestID(requestID), - zap.String("err", err.Error())) + log.Error("get services count with namespace err", utils.RequestID(ctx), zap.Error(err)) return api.NewNamespaceResponse(commonstore.StoreCode2APICode(err), req) } if total != 0 { - log.Error("the removed namespace has remain services", utils.ZapRequestID(requestID)) + log.Error("the removed namespace has remain services", utils.RequestID(ctx)) return api.NewNamespaceResponse(apimodel.Code_NamespaceExistedServices, req) } // 判断属于该命名空间的服务是否都已经被删除 total, err = s.getConfigGroupCountWithNamespace(namespace.Name) if err != nil { - log.Error("get config group count with namespace err", - utils.ZapRequestID(requestID), - zap.String("err", err.Error())) + log.Error("get config group count with namespace err", utils.RequestID(ctx), zap.Error(err)) return api.NewNamespaceResponse(commonstore.StoreCode2APICode(err), req) } if total != 0 { - log.Error("the removed namespace has remain config-group", utils.ZapRequestID(requestID)) + log.Error("the removed namespace has remain config-group", utils.RequestID(ctx)) return api.NewNamespaceResponse(apimodel.Code_NamespaceExistedConfigGroups, req) } // 存储层操作 if err := tx.DeleteNamespace(namespace.Name); err != nil { - log.Error(err.Error(), utils.ZapRequestID(requestID)) + log.Error(err.Error(), utils.RequestID(ctx)) return api.NewNamespaceResponse(commonstore.StoreCode2APICode(err), req) } @@ -262,19 +249,16 @@ func (s *Server) UpdateNamespace(ctx context.Context, req *apimodel.Namespace) * if resp != nil { return resp } - - rid := utils.ParseRequestID(ctx) // 修改 s.updateNamespaceAttribute(req, namespace) // 存储层操作 if err := s.storage.UpdateNamespace(namespace); err != nil { - log.Error(err.Error(), utils.ZapRequestID(rid)) + log.Error(err.Error(), utils.RequestID(ctx)) return api.NewNamespaceResponse(commonstore.StoreCode2APICode(err), req) } - msg := fmt.Sprintf("update namespace: name=%s", namespace.Name) - log.Info(msg, utils.ZapRequestID(rid)) + log.Info("update namespace", zap.String("name", namespace.Name), utils.RequestID(ctx)) s.RecordHistory(namespaceRecordEntry(ctx, req, model.OUpdate)) if err := s.afterNamespaceResource(ctx, req, namespace, false); err != nil { @@ -301,6 +285,7 @@ func (s *Server) updateNamespaceAttribute(req *apimodel.Namespace, namespace *mo exportTo[req.GetServiceExportTo()[i].GetValue()] = struct{}{} } + namespace.Metadata = req.GetMetadata() namespace.ServiceExportTo = exportTo } @@ -314,18 +299,16 @@ func (s *Server) UpdateNamespaceToken(ctx context.Context, req *apimodel.Namespa return resp } - rid := utils.ParseRequestID(ctx) // 生成token token := utils.NewUUID() // 存储层操作 if err := s.storage.UpdateNamespaceToken(namespace.Name, token); err != nil { - log.Error(err.Error(), utils.ZapRequestID(rid)) + log.Error(err.Error(), utils.RequestID(ctx)) return api.NewNamespaceResponse(commonstore.StoreCode2APICode(err), req) } - msg := fmt.Sprintf("update namespace token: name=%s", namespace.Name) - log.Info(msg, utils.ZapRequestID(rid)) + log.Info("update namespace token", zap.String("name", namespace.Name), utils.RequestID(ctx)) s.RecordHistory(namespaceRecordEntry(ctx, req, model.OUpdateToken)) out := &apimodel.Namespace{ @@ -343,7 +326,11 @@ func (s *Server) GetNamespaces(ctx context.Context, query map[string][]string) * return checkError } - namespaces, amount, err := s.storage.GetNamespaces(filter, offset, limit) + amount, namespaces, err := s.caches.Namespace().Query(ctx, &cachetypes.NamespaceArgs{ + Filter: filter, + Offset: offset, + Limit: limit, + }) if err != nil { return api.NewBatchQueryResponse(commonstore.StoreCode2APICode(err)) } @@ -365,6 +352,9 @@ func (s *Server) GetNamespaces(ctx context.Context, query map[string][]string) * TotalInstanceCount: utils.NewUInt32Value(nsCntInfo.InstanceCnt.TotalInstanceCount), TotalHealthInstanceCount: utils.NewUInt32Value(nsCntInfo.InstanceCnt.HealthyInstanceCount), ServiceExportTo: namespace.ListServiceExportTo(), + Editable: utils.NewBoolValue(true), + Deleteable: utils.NewBoolValue(true), + Metadata: namespace.Metadata, }) totalServiceCount += nsCntInfo.ServiceCount totalInstanceCount += nsCntInfo.InstanceCnt.TotalInstanceCount @@ -435,25 +425,18 @@ func (s *Server) loadNamespace(name string) (string, error) { // 检查namespace的权限,并且返回namespace func (s *Server) checkNamespaceAuthority( ctx context.Context, req *apimodel.Namespace) (*model.Namespace, *apiservice.Response) { - rid := utils.ParseRequestID(ctx) namespaceName := req.GetName().GetValue() // namespaceToken := parseNamespaceToken(ctx, req) // 检查是否存在 namespace, err := s.storage.GetNamespace(namespaceName) if err != nil { - log.Error(err.Error(), utils.ZapRequestID(rid)) + log.Error(err.Error(), utils.RequestID(ctx)) return nil, api.NewNamespaceResponse(commonstore.StoreCode2APICode(err), req) } if namespace == nil { return nil, api.NewNamespaceResponse(apimodel.Code_NotFoundResource, req) } - - // 鉴权 - // if ok := s.authority.VerifyNamespace(namespace.Token, namespaceToken); !ok { - // return nil, api.NewNamespaceResponse(api.Unauthorized, req) - // } - return namespace, nil } diff --git a/namespace/server.go b/namespace/server.go index cefedd11f..f3f313642 100644 --- a/namespace/server.go +++ b/namespace/server.go @@ -36,7 +36,6 @@ type Server struct { caches *cache.CacheManager createNamespaceSingle *singleflight.Group cfg Config - history plugin.History hooks []ResourceHook } @@ -61,7 +60,7 @@ func (s *Server) afterNamespaceResource(ctx context.Context, req *apimodel.Names // RecordHistory server对外提供history插件的简单封装 func (s *Server) RecordHistory(entry *model.RecordEntry) { // 如果插件没有初始化,那么不记录history - if s.history == nil { + if plugin.GetHistory() == nil { return } // 如果数据为空,则不需要打印了 @@ -70,10 +69,32 @@ func (s *Server) RecordHistory(entry *model.RecordEntry) { } // 调用插件记录history - s.history.Record(entry) + plugin.GetHistory().Record(entry) } // SetResourceHooks 返回Cache func (s *Server) SetResourceHooks(hooks ...ResourceHook) { s.hooks = hooks } + +// ResourceHook The listener is placed before and after the resource operation, only normal flow +type ResourceHook interface { + + // Before + // @param ctx + // @param resourceType + Before(ctx context.Context, resourceType model.Resource) + + // After + // @param ctx + // @param resourceType + // @param res + After(ctx context.Context, resourceType model.Resource, res *ResourceEvent) error +} + +// ResourceEvent 资源事件 +type ResourceEvent struct { + ReqNamespace *apimodel.Namespace + Namespace *model.Namespace + IsRemove bool +} diff --git a/namespace/server_authability.go b/namespace/server_authability.go deleted file mode 100644 index aa7cb0a18..000000000 --- a/namespace/server_authability.go +++ /dev/null @@ -1,102 +0,0 @@ -/** - * Tencent is pleased to support the open source community by making Polaris available. - * - * Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. - * - * Licensed under the BSD 3-Clause License (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://opensource.org/licenses/BSD-3-Clause - * - * Unless required by applicable law or agreed to in writing, software distributed - * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR - * CONDITIONS OF ANY KIND, either express or implied. See the License for the - * specific language governing permissions and limitations under the License. - */ - -package namespace - -import ( - "context" - "errors" - - apimodel "github.com/polarismesh/specification/source/go/api/v1/model" - apisecurity "github.com/polarismesh/specification/source/go/api/v1/security" - "go.uber.org/zap" - - "github.com/polarismesh/polaris/auth" - authcommon "github.com/polarismesh/polaris/common/model/auth" - "github.com/polarismesh/polaris/common/utils" -) - -// serverAuthAbility 带有鉴权能力的 discoverServer -// -// 该层会对请求参数做一些调整,根据具体的请求发起人,设置为数据对应的 owner,不可为为别人进行创建资源 -type serverAuthAbility struct { - targetServer *Server - userMgn auth.UserServer - policySvr auth.StrategyServer -} - -func newServerAuthAbility(targetServer *Server, - userMgn auth.UserServer, policySvr auth.StrategyServer) NamespaceOperateServer { - proxy := &serverAuthAbility{ - targetServer: targetServer, - userMgn: userMgn, - policySvr: policySvr, - } - - targetServer.SetResourceHooks(proxy) - return proxy -} - -// collectNamespaceAuthContext 对于命名空间的处理,收集所有的与鉴权的相关信息 -func (svr *serverAuthAbility) collectNamespaceAuthContext(ctx context.Context, req []*apimodel.Namespace, - resourceOp authcommon.ResourceOperation, methodName authcommon.ServerFunctionName) *authcommon.AcquireContext { - return authcommon.NewAcquireContext( - authcommon.WithRequestContext(ctx), - authcommon.WithOperation(resourceOp), - authcommon.WithModule(authcommon.CoreModule), - authcommon.WithMethod(methodName), - authcommon.WithAccessResources(svr.queryNamespaceResource(req)), - ) -} - -// queryNamespaceResource 根据所给的 namespace 信息,收集对应的 ResourceEntry 列表 -func (svr *serverAuthAbility) queryNamespaceResource( - req []*apimodel.Namespace) map[apisecurity.ResourceType][]authcommon.ResourceEntry { - names := utils.NewSet[string]() - for index := range req { - names.Add(req[index].Name.GetValue()) - } - param := names.ToSlice() - nsArr := svr.targetServer.caches.Namespace().GetNamespacesByName(param) - - temp := make([]authcommon.ResourceEntry, 0, len(nsArr)) - - for index := range nsArr { - ns := nsArr[index] - temp = append(temp, authcommon.ResourceEntry{ - Type: apisecurity.ResourceType_Namespaces, - ID: ns.Name, - Owner: ns.Owner, - }) - } - - ret := map[apisecurity.ResourceType][]authcommon.ResourceEntry{ - apisecurity.ResourceType_Namespaces: temp, - } - authLog.Debug("[Auth][Server] collect namespace access res", zap.Any("res", ret)) - return ret -} - -func convertToErrCode(err error) apimodel.Code { - if errors.Is(err, authcommon.ErrorTokenNotExist) { - return apimodel.Code_TokenNotExisted - } - if errors.Is(err, authcommon.ErrorTokenDisabled) { - return apimodel.Code_TokenDisabled - } - return apimodel.Code_NotAllowedAccess -} diff --git a/namespace/test_export.go b/namespace/test_export.go deleted file mode 100644 index fbb98a16b..000000000 --- a/namespace/test_export.go +++ /dev/null @@ -1,51 +0,0 @@ -/** - * Tencent is pleased to support the open source community by making Polaris available. - * - * Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. - * - * Licensed under the BSD 3-Clause License (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://opensource.org/licenses/BSD-3-Clause - * - * Unless required by applicable law or agreed to in writing, software distributed - * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR - * CONDITIONS OF ANY KIND, either express or implied. See the License for the - * specific language governing permissions and limitations under the License. - */ - -package namespace - -import ( - "context" - - "golang.org/x/sync/singleflight" - - "github.com/polarismesh/polaris/auth" - "github.com/polarismesh/polaris/cache" - cachetypes "github.com/polarismesh/polaris/cache/api" - "github.com/polarismesh/polaris/plugin" - "github.com/polarismesh/polaris/store" -) - -func TestInitialize(_ context.Context, nsOpt *Config, storage store.Store, cacheMgn *cache.CacheManager, - userMgn auth.UserServer, strategyMgn auth.StrategyServer) (NamespaceOperateServer, error) { - _ = cacheMgn.OpenResourceCache(cachetypes.ConfigEntry{ - Name: cachetypes.NamespaceName, - }) - nsOpt.AutoCreate = true - namespaceServer := &Server{} - namespaceServer.caches = cacheMgn - namespaceServer.storage = storage - namespaceServer.cfg = *nsOpt - namespaceServer.createNamespaceSingle = &singleflight.Group{} - - // 获取History插件,注意:插件的配置在bootstrap已经设置好 - namespaceServer.history = plugin.GetHistory() - if namespaceServer.history == nil { - log.Warn("Not Found History Log Plugin") - } - - return newServerAuthAbility(namespaceServer, userMgn, strategyMgn), nil -} diff --git a/plugin.go b/plugin.go index 9494a723d..8b30e70c8 100644 --- a/plugin.go +++ b/plugin.go @@ -18,6 +18,7 @@ package main import ( + _ "github.com/polarismesh/polaris/admin/interceptor" _ "github.com/polarismesh/polaris/apiserver/eurekaserver" _ "github.com/polarismesh/polaris/apiserver/grpcserver/config" _ "github.com/polarismesh/polaris/apiserver/grpcserver/discover" @@ -34,6 +35,7 @@ import ( _ "github.com/polarismesh/polaris/cache/namespace" _ "github.com/polarismesh/polaris/cache/service" _ "github.com/polarismesh/polaris/config/interceptor" + _ "github.com/polarismesh/polaris/namespace/interceptor" _ "github.com/polarismesh/polaris/plugin/cmdb/memory" _ "github.com/polarismesh/polaris/plugin/crypto/aes" _ "github.com/polarismesh/polaris/plugin/discoverevent/local" diff --git a/plugin/statis/logger/statis.go b/plugin/statis/logger/statis.go index 6a8121c09..8ce86b59a 100644 --- a/plugin/statis/logger/statis.go +++ b/plugin/statis/logger/statis.go @@ -97,7 +97,7 @@ func (s *StatisWorker) ReportConfigMetrics(metric ...metrics.ConfigMetrics) { // ReportDiscoverCall report discover service times func (s *StatisWorker) ReportDiscoverCall(metric metrics.ClientDiscoverMetric) { - discoverlog.Infof(metric.String()) + discoverlog.Info(metric.String()) } func (a *StatisWorker) metricsHandle(mt metrics.CallMetricType, start time.Time, diff --git a/plugin/sync.go b/plugin/sync.go new file mode 100644 index 000000000..01a0ef11a --- /dev/null +++ b/plugin/sync.go @@ -0,0 +1,29 @@ +/** + * Tencent is pleased to support the open source community by making Polaris available. + * + * Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. + * + * Licensed under the BSD 3-Clause License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://opensource.org/licenses/BSD-3-Clause + * + * Unless required by applicable law or agreed to in writing, software distributed + * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR + * CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package plugin + +import "github.com/polarismesh/polaris/common/model" + +// Syncer . +type Syncer interface { + Plugin + // Run + Run() error + // DebugHandlers return debug handlers + DebugHandlers() []model.DebugHandler +} diff --git a/release/build.sh b/release/build.sh index 19b8e7f0b..76e262f84 100755 --- a/release/build.sh +++ b/release/build.sh @@ -66,7 +66,7 @@ rm -f ${bin_name} export CGO_ENABLED=0 build_date=$(date "+%Y%m%d.%H%M%S") -package="github.com/polarismesh/polaris-server/common/version" +package="github.com/polarismesh/polaris/common/version" sqldb_res="store/mysql" GOARCH=${GOARCH} GOOS=${GOOS} go build -o ${bin_name} -ldflags="-X ${package}.Version=${version} -X ${package}.BuildDate=${build_date}" diff --git a/release/conf/bolt-data.yaml b/release/conf/bolt-data.yaml new file mode 100644 index 000000000..2c7827204 --- /dev/null +++ b/release/conf/bolt-data.yaml @@ -0,0 +1,174 @@ +# Tencent is pleased to support the open source community by making Polaris available. +# +# Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +users: + - name: polaris + token: nu/0WRA4EqSR1FagrjRj0fZwPXuGlMpX+zCuWu4uMqy8xr1vRjisSbA25aAC3mtU8MeeRsKhQiDAynUR09I= + password: $2a$10$3izWuZtE5SBdAtSZci.gs.iZ2pAn9I8hEqYrC6gwJp1dyjqQnrrum + id: 65e4789a6d5b49669adf1e9e8387549c + tokenenable: true + type: 20 + valid: true +policies: + - id: fbca9bfa04ae4ead86e1ecf5811e32a9 + name: (用户) polaris的默认策略 + action: READ_WRITE + comment: default admin + default: true + owner: 65e4789a6d5b49669adf1e9e8387549c + calleemethods: ["*"] + resources: + - strategyid: fbca9bfa04ae4ead86e1ecf5811e32a9 + restype: 6 + resid: "*" + - strategyid: fbca9bfa04ae4ead86e1ecf5811e32a9 + restype: 7 + resid: "*" + - strategyid: fbca9bfa04ae4ead86e1ecf5811e32a9 + restype: 20 + resid: "*" + - strategyid: fbca9bfa04ae4ead86e1ecf5811e32a9 + restype: 0 + resid: "*" + - strategyid: fbca9bfa04ae4ead86e1ecf5811e32a9 + restype: 3 + resid: "*" + - strategyid: fbca9bfa04ae4ead86e1ecf5811e32a9 + restype: 4 + resid: "*" + - strategyid: fbca9bfa04ae4ead86e1ecf5811e32a9 + restype: 5 + resid: "*" + - strategyid: fbca9bfa04ae4ead86e1ecf5811e32a9 + restype: 21 + resid: "*" + - strategyid: fbca9bfa04ae4ead86e1ecf5811e32a9 + restype: 22 + resid: "*" + - strategyid: fbca9bfa04ae4ead86e1ecf5811e32a9 + restype: 23 + resid: "*" + - strategyid: fbca9bfa04ae4ead86e1ecf5811e32a9 + restype: 1 + resid: "*" + - strategyid: fbca9bfa04ae4ead86e1ecf5811e32a9 + restype: 2 + resid: "*" + conditions: [] + principals: + - strategyid: fbca9bfa04ae4ead86e1ecf5811e32a9 + principalid: 65e4789a6d5b49669adf1e9e8387549c + principaltype: 1 + valid: true + revision: fbca9bfa04ae4ead86e1ecf5811e32a9 + metadata: {} + - id: bfa04ae1e32a94fbca9ead86e1ecf581 + name: 全局只读策略 + action: ALLOW + comment: global resources read onyly + default: false + owner: 65e4789a6d5b49669adf1e9e8387549c + calleemethods: ["Describe*", "List*", "Get*"] + resources: + - strategyid: bfa04ae1e32a94fbca9ead86e1ecf581 + restype: 6 + resid: "*" + - strategyid: bfa04ae1e32a94fbca9ead86e1ecf581 + restype: 7 + resid: "*" + - strategyid: bfa04ae1e32a94fbca9ead86e1ecf581 + restype: 20 + resid: "*" + - strategyid: bfa04ae1e32a94fbca9ead86e1ecf581 + restype: 0 + resid: "*" + - strategyid: bfa04ae1e32a94fbca9ead86e1ecf581 + restype: 3 + resid: "*" + - strategyid: bfa04ae1e32a94fbca9ead86e1ecf581 + restype: 4 + resid: "*" + - strategyid: bfa04ae1e32a94fbca9ead86e1ecf581 + restype: 5 + resid: "*" + - strategyid: bfa04ae1e32a94fbca9ead86e1ecf581 + restype: 21 + resid: "*" + - strategyid: bfa04ae1e32a94fbca9ead86e1ecf581 + restype: 22 + resid: "*" + - strategyid: bfa04ae1e32a94fbca9ead86e1ecf581 + restype: 23 + resid: "*" + - strategyid: bfa04ae1e32a94fbca9ead86e1ecf581 + restype: 1 + resid: "*" + - strategyid: bfa04ae1e32a94fbca9ead86e1ecf581 + restype: 2 + resid: "*" + conditions: [] + principals: [] + valid: true + revision: 2a04ae4ead86e1e9bfacf59fbca811e3 + metadata: {} + - id: e3d86e1ecf5812bfa04ae1a94fbca9ea + name: 全局读写策略 + action: ALLOW + comment: global resources read and write + default: false + owner: 65e4789a6d5b49669adf1e9e8387549c + calleemethods: ["*"] + resources: + - strategyid: e3d86e1ecf5812bfa04ae1a94fbca9ea + restype: 6 + resid: "*" + - strategyid: e3d86e1ecf5812bfa04ae1a94fbca9ea + restype: 7 + resid: "*" + - strategyid: e3d86e1ecf5812bfa04ae1a94fbca9ea + restype: 20 + resid: "*" + - strategyid: e3d86e1ecf5812bfa04ae1a94fbca9ea + restype: 0 + resid: "*" + - strategyid: e3d86e1ecf5812bfa04ae1a94fbca9ea + restype: 3 + resid: "*" + - strategyid: e3d86e1ecf5812bfa04ae1a94fbca9ea + restype: 4 + resid: "*" + - strategyid: e3d86e1ecf5812bfa04ae1a94fbca9ea + restype: 5 + resid: "*" + - strategyid: e3d86e1ecf5812bfa04ae1a94fbca9ea + restype: 21 + resid: "*" + - strategyid: e3d86e1ecf5812bfa04ae1a94fbca9ea + restype: 22 + resid: "*" + - strategyid: e3d86e1ecf5812bfa04ae1a94fbca9ea + restype: 23 + resid: "*" + - strategyid: e3d86e1ecf5812bfa04ae1a94fbca9ea + restype: 1 + resid: "*" + - strategyid: e3d86e1ecf5812bfa04ae1a94fbca9ea + restype: 2 + resid: "*" + conditions: [] + principals: [] + valid: true + revision: 4ead86e1e9bfac2a04aef59fbca811e3 + metadata: {} diff --git a/release/conf/polaris-server.yaml b/release/conf/polaris-server.yaml index 91535a854..2d1558136 100644 --- a/release/conf/polaris-server.yaml +++ b/release/conf/polaris-server.yaml @@ -451,6 +451,7 @@ store: name: boltdbStore option: path: ./polaris.bolt + loadFile: ./conf/bolt-data.yaml ## Database storage plugin # name: defaultStore # option: diff --git a/service/api.go b/service/api.go index b37f19f86..676f806cb 100644 --- a/service/api.go +++ b/service/api.go @@ -20,22 +20,38 @@ package service import ( "context" + apifault "github.com/polarismesh/specification/source/go/api/v1/fault_tolerance" apiservice "github.com/polarismesh/specification/source/go/api/v1/service_manage" + apitraffic "github.com/polarismesh/specification/source/go/api/v1/traffic_manage" cachetypes "github.com/polarismesh/polaris/cache/api" + "github.com/polarismesh/polaris/common/api/l5" "github.com/polarismesh/polaris/common/model" + authcommon "github.com/polarismesh/polaris/common/model/auth" ) // DiscoverServer Server discovered by the service type DiscoverServer interface { - // DiscoverServerV1 DiscoverServerV1 - DiscoverServerV1 + // CircuitBreakerOperateServer Fuse rule operation interface definition + CircuitBreakerOperateServer + // RateLimitOperateServer Lamflow rule operation interface definition + RateLimitOperateServer + // RouteRuleOperateServer Routing rules operation interface definition + RouteRuleOperateServer + // RouterRuleOperateServer Routing rules operation interface definition + RouterRuleOperateServer + // FaultDetectRuleOperateServer fault detect rules operation interface definition + FaultDetectRuleOperateServer + // ServiceContractOperateServer service contract rules operation inerface definition + ServiceContractOperateServer // ServiceAliasOperateServer Service alias operation interface definition ServiceAliasOperateServer // ServiceOperateServer Service operation interface definition ServiceOperateServer // InstanceOperateServer Instance Operation Interface Definition InstanceOperateServer + // LaneOperateServer lane rule operation interface definition + LaneOperateServer // ClientServer Client operation interface definition ClientServer // Cache Get cache management @@ -46,6 +62,240 @@ type DiscoverServer interface { GetServiceInstanceRevision(serviceID string, instances []*model.Instance) (string, error) } +// CircuitBreakerOperateServer Melting rule related treatment +type CircuitBreakerOperateServer interface { + // CreateCircuitBreakers Create a CircuitBreaker rule + // Deprecated: not support from 1.14.x + CreateCircuitBreakers(ctx context.Context, req []*apifault.CircuitBreaker) *apiservice.BatchWriteResponse + // CreateCircuitBreakerVersions Create a melt rule version + // Deprecated: not support from 1.14.x + CreateCircuitBreakerVersions(ctx context.Context, req []*apifault.CircuitBreaker) *apiservice.BatchWriteResponse + // DeleteCircuitBreakers Delete CircuitBreaker rules + // Deprecated: not support from 1.14.x + DeleteCircuitBreakers(ctx context.Context, req []*apifault.CircuitBreaker) *apiservice.BatchWriteResponse + // UpdateCircuitBreakers Modify the CircuitBreaker rule + // Deprecated: not support from 1.14.x + UpdateCircuitBreakers(ctx context.Context, req []*apifault.CircuitBreaker) *apiservice.BatchWriteResponse + // ReleaseCircuitBreakers Release CircuitBreaker rule + // Deprecated: not support from 1.14.x + ReleaseCircuitBreakers(ctx context.Context, req []*apiservice.ConfigRelease) *apiservice.BatchWriteResponse + // UnBindCircuitBreakers Solution CircuitBreaker rule + // Deprecated: not support from 1.14.x + UnBindCircuitBreakers(ctx context.Context, req []*apiservice.ConfigRelease) *apiservice.BatchWriteResponse + // GetCircuitBreaker Get CircuitBreaker regular according to ID and VERSION + // Deprecated: not support from 1.14.x + GetCircuitBreaker(ctx context.Context, query map[string]string) *apiservice.BatchQueryResponse + // GetCircuitBreakerVersions Query all versions of the CircuitBreaker rule + // Deprecated: not support from 1.14.x + GetCircuitBreakerVersions(ctx context.Context, query map[string]string) *apiservice.BatchQueryResponse + // GetMasterCircuitBreakers Query Master CircuitBreaker rules + // Deprecated: not support from 1.14.x + GetMasterCircuitBreakers(ctx context.Context, query map[string]string) *apiservice.BatchQueryResponse + // GetReleaseCircuitBreakers Query the released CircuitBreaker rule according to the rule ID + // Deprecated: not support from 1.14.x + GetReleaseCircuitBreakers(ctx context.Context, query map[string]string) *apiservice.BatchQueryResponse + // GetCircuitBreakerByService Binding CircuitBreaker rule based on service query + // Deprecated: not support from 1.14.x + GetCircuitBreakerByService(ctx context.Context, query map[string]string) *apiservice.BatchQueryResponse + // GetCircuitBreakerToken Get CircuitBreaker rules token + // Deprecated: not support from 1.14.x + GetCircuitBreakerToken(ctx context.Context, req *apifault.CircuitBreaker) *apiservice.Response + // CreateCircuitBreakerRules Create a CircuitBreaker rule + CreateCircuitBreakerRules(ctx context.Context, request []*apifault.CircuitBreakerRule) *apiservice.BatchWriteResponse + // DeleteCircuitBreakerRules Delete current CircuitBreaker rules + DeleteCircuitBreakerRules(ctx context.Context, request []*apifault.CircuitBreakerRule) *apiservice.BatchWriteResponse + // EnableCircuitBreakerRules Enable the CircuitBreaker rule + EnableCircuitBreakerRules(ctx context.Context, request []*apifault.CircuitBreakerRule) *apiservice.BatchWriteResponse + // UpdateCircuitBreakerRules Modify the CircuitBreaker rule + UpdateCircuitBreakerRules(ctx context.Context, request []*apifault.CircuitBreakerRule) *apiservice.BatchWriteResponse + // GetCircuitBreakerRules Query CircuitBreaker rules + GetCircuitBreakerRules(ctx context.Context, query map[string]string) *apiservice.BatchQueryResponse +} + +// RateLimitOperateServer Lamflow rule related operation +type RateLimitOperateServer interface { + // CreateRateLimits Create a RateLimit rule + CreateRateLimits(ctx context.Context, request []*apitraffic.Rule) *apiservice.BatchWriteResponse + // DeleteRateLimits Delete current RateLimit rules + DeleteRateLimits(ctx context.Context, request []*apitraffic.Rule) *apiservice.BatchWriteResponse + // EnableRateLimits Enable the RateLimit rule + EnableRateLimits(ctx context.Context, request []*apitraffic.Rule) *apiservice.BatchWriteResponse + // UpdateRateLimits Modify the RateLimit rule + UpdateRateLimits(ctx context.Context, request []*apitraffic.Rule) *apiservice.BatchWriteResponse + // GetRateLimits Query RateLimit rules + GetRateLimits(ctx context.Context, query map[string]string) *apiservice.BatchQueryResponse +} + +// RouteRuleOperateServer Routing rules related operations +type RouteRuleOperateServer interface { + // CreateRoutingConfigs Batch creation routing configuration + CreateRoutingConfigs(ctx context.Context, req []*apitraffic.Routing) *apiservice.BatchWriteResponse + // DeleteRoutingConfigs Batch delete routing configuration + DeleteRoutingConfigs(ctx context.Context, req []*apitraffic.Routing) *apiservice.BatchWriteResponse + // UpdateRoutingConfigs Batch update routing configuration + UpdateRoutingConfigs(ctx context.Context, req []*apitraffic.Routing) *apiservice.BatchWriteResponse + // GetRoutingConfigs Inquiry route configuration to OSS + GetRoutingConfigs(ctx context.Context, query map[string]string) *apiservice.BatchQueryResponse +} + +// ServiceOperateServer Service related operations +type ServiceOperateServer interface { + // CreateServices Batch creation service + CreateServices(ctx context.Context, req []*apiservice.Service) *apiservice.BatchWriteResponse + // DeleteServices Batch delete service + DeleteServices(ctx context.Context, req []*apiservice.Service) *apiservice.BatchWriteResponse + // UpdateServices Batch update service + UpdateServices(ctx context.Context, req []*apiservice.Service) *apiservice.BatchWriteResponse + // UpdateServiceToken Update service token + UpdateServiceToken(ctx context.Context, req *apiservice.Service) *apiservice.Response + // GetServices Get a list of service + GetServices(ctx context.Context, query map[string]string) *apiservice.BatchQueryResponse + // GetAllServices Get all service list + GetAllServices(ctx context.Context, query map[string]string) *apiservice.BatchQueryResponse + // GetServicesCount Total number of services + GetServicesCount(ctx context.Context) *apiservice.BatchQueryResponse + // GetServiceToken Get service token + GetServiceToken(ctx context.Context, req *apiservice.Service) *apiservice.Response + // GetServiceOwner Owner for obtaining service + GetServiceOwner(ctx context.Context, req []*apiservice.Service) *apiservice.BatchQueryResponse +} + +// ServiceAliasOperateServer Service alias related operations +type ServiceAliasOperateServer interface { + // CreateServiceAlias Create a service alias + CreateServiceAlias(ctx context.Context, req *apiservice.ServiceAlias) *apiservice.Response + // DeleteServiceAliases Batch delete service alias + DeleteServiceAliases(ctx context.Context, req []*apiservice.ServiceAlias) *apiservice.BatchWriteResponse + // UpdateServiceAlias Update service alias + UpdateServiceAlias(ctx context.Context, req *apiservice.ServiceAlias) *apiservice.Response + // GetServiceAliases Get a list of service alias + GetServiceAliases(ctx context.Context, query map[string]string) *apiservice.BatchQueryResponse +} + +// InstanceOperateServer Example related operations +type InstanceOperateServer interface { + // CreateInstances Batch creation instance + CreateInstances(ctx context.Context, reqs []*apiservice.Instance) *apiservice.BatchWriteResponse + // DeleteInstances Batch delete instance + DeleteInstances(ctx context.Context, req []*apiservice.Instance) *apiservice.BatchWriteResponse + // DeleteInstancesByHost Delete instance according to HOST information batch + DeleteInstancesByHost(ctx context.Context, req []*apiservice.Instance) *apiservice.BatchWriteResponse + // UpdateInstances Batch update instance + UpdateInstances(ctx context.Context, req []*apiservice.Instance) *apiservice.BatchWriteResponse + // UpdateInstancesIsolate Batch update instance isolation state + UpdateInstancesIsolate(ctx context.Context, req []*apiservice.Instance) *apiservice.BatchWriteResponse + // GetInstances Get an instance list + GetInstances(ctx context.Context, query map[string]string) *apiservice.BatchQueryResponse + // GetInstancesCount Get an instance quantity + GetInstancesCount(ctx context.Context) *apiservice.BatchQueryResponse + // GetInstanceLabels Get an instance tag under a service + GetInstanceLabels(ctx context.Context, query map[string]string) *apiservice.Response +} + +// ClientServer Client related operation Client operation interface definition +type ClientServer interface { + // RegisterInstance create one instance by client + RegisterInstance(ctx context.Context, req *apiservice.Instance) *apiservice.Response + // DeregisterInstance delete onr instance by client + DeregisterInstance(ctx context.Context, req *apiservice.Instance) *apiservice.Response + // ReportClient Client gets geographic location information + ReportClient(ctx context.Context, req *apiservice.Client) *apiservice.Response + // GetPrometheusTargets Used to obtain the ReportClient information and serve as the SD result of Prometheus + GetPrometheusTargets(ctx context.Context, query map[string]string) *model.PrometheusDiscoveryResponse + // GetServiceWithCache Used for client acquisition service information + GetServiceWithCache(ctx context.Context, req *apiservice.Service) *apiservice.DiscoverResponse + // ServiceInstancesCache Used for client acquisition service instance information + ServiceInstancesCache(ctx context.Context, filter *apiservice.DiscoverFilter, req *apiservice.Service) *apiservice.DiscoverResponse + // GetRoutingConfigWithCache User Client Get Service Routing Configuration Information + GetRoutingConfigWithCache(ctx context.Context, req *apiservice.Service) *apiservice.DiscoverResponse + // GetRateLimitWithCache User Client Get Service Limit Configuration Information + GetRateLimitWithCache(ctx context.Context, req *apiservice.Service) *apiservice.DiscoverResponse + // GetCircuitBreakerWithCache Fuse configuration information for obtaining services for clients + GetCircuitBreakerWithCache(ctx context.Context, req *apiservice.Service) *apiservice.DiscoverResponse + // GetFaultDetectWithCache User Client Get FaultDetect Rule Information + GetFaultDetectWithCache(ctx context.Context, req *apiservice.Service) *apiservice.DiscoverResponse + // GetServiceContractWithCache User Client Get ServiceContract Rule Information + GetServiceContractWithCache(ctx context.Context, req *apiservice.ServiceContract) *apiservice.Response + // UpdateInstance update one instance by client + UpdateInstance(ctx context.Context, req *apiservice.Instance) *apiservice.Response + // ReportServiceContract client report service_contract + ReportServiceContract(ctx context.Context, req *apiservice.ServiceContract) *apiservice.Response + // GetLaneRuleWithCache fetch lane rules by client + GetLaneRuleWithCache(ctx context.Context, req *apiservice.Service) *apiservice.DiscoverResponse +} + +// L5OperateServer L5 related operations +type L5OperateServer interface { + // SyncByAgentCmd Get routing information according to SID list + SyncByAgentCmd(ctx context.Context, sbac *l5.Cl5SyncByAgentCmd) (*l5.Cl5SyncByAgentAckCmd, error) + // RegisterByNameCmd Look for the corresponding SID list according to the list of service names + RegisterByNameCmd(rbnc *l5.Cl5RegisterByNameCmd) (*l5.Cl5RegisterByNameAckCmd, error) +} + +// ReportClientOperateServer Report information operation interface on the client +type ReportClientOperateServer interface { + // GetReportClients Query the client information reported + GetReportClients(ctx context.Context, query map[string]string) *apiservice.BatchQueryResponse +} + +// RouterRuleOperateServer Routing rules related operations +type RouterRuleOperateServer interface { + // CreateRoutingConfigsV2 Batch creation routing configuration + CreateRoutingConfigsV2(ctx context.Context, req []*apitraffic.RouteRule) *apiservice.BatchWriteResponse + // DeleteRoutingConfigsV2 Batch delete routing configuration + DeleteRoutingConfigsV2(ctx context.Context, req []*apitraffic.RouteRule) *apiservice.BatchWriteResponse + // UpdateRoutingConfigsV2 Batch update routing configuration + UpdateRoutingConfigsV2(ctx context.Context, req []*apitraffic.RouteRule) *apiservice.BatchWriteResponse + // QueryRoutingConfigsV2 Inquiry route configuration to OSS + QueryRoutingConfigsV2(ctx context.Context, query map[string]string) *apiservice.BatchQueryResponse + // EnableRoutings batch enable routing rules + EnableRoutings(ctx context.Context, req []*apitraffic.RouteRule) *apiservice.BatchWriteResponse +} + +// FaultDetectRuleOperateServer Fault detect rules related operations +type FaultDetectRuleOperateServer interface { + // CreateFaultDetectRules create the fault detect rule by request + CreateFaultDetectRules(ctx context.Context, request []*apifault.FaultDetectRule) *apiservice.BatchWriteResponse + // DeleteFaultDetectRules delete the fault detect rule by request + DeleteFaultDetectRules(ctx context.Context, request []*apifault.FaultDetectRule) *apiservice.BatchWriteResponse + // UpdateFaultDetectRules update the fault detect rule by request + UpdateFaultDetectRules(ctx context.Context, request []*apifault.FaultDetectRule) *apiservice.BatchWriteResponse + // GetFaultDetectRules get the fault detect rule by request + GetFaultDetectRules(ctx context.Context, query map[string]string) *apiservice.BatchQueryResponse +} + +// ServiceContractOperateServer service contract operations +type ServiceContractOperateServer interface { + // CreateServiceContracts . + CreateServiceContracts(ctx context.Context, req []*apiservice.ServiceContract) *apiservice.BatchWriteResponse + // DeleteServiceContracts . + DeleteServiceContracts(ctx context.Context, req []*apiservice.ServiceContract) *apiservice.BatchWriteResponse + // GetServiceContracts . + GetServiceContracts(ctx context.Context, query map[string]string) *apiservice.BatchQueryResponse + // CreateServiceContractInterfaces . + CreateServiceContractInterfaces(ctx context.Context, contract *apiservice.ServiceContract, + source apiservice.InterfaceDescriptor_Source) *apiservice.Response + // AppendServiceContractInterfaces . + AppendServiceContractInterfaces(ctx context.Context, contract *apiservice.ServiceContract, + source apiservice.InterfaceDescriptor_Source) *apiservice.Response + // DeleteServiceContractInterfaces . + DeleteServiceContractInterfaces(ctx context.Context, contract *apiservice.ServiceContract) *apiservice.Response + // GetServiceContractVersions . + GetServiceContractVersions(ctx context.Context, filter map[string]string) *apiservice.BatchQueryResponse +} + +// LaneOperateServer lane operations +type LaneOperateServer interface { + // CreateLaneGroups 批量创建泳道组 + CreateLaneGroups(ctx context.Context, req []*apitraffic.LaneGroup) *apiservice.BatchWriteResponse + // UpdateLaneGroups 批量更新泳道组 + UpdateLaneGroups(ctx context.Context, req []*apitraffic.LaneGroup) *apiservice.BatchWriteResponse + // DeleteLaneGroups 批量删除泳道组 + DeleteLaneGroups(ctx context.Context, req []*apitraffic.LaneGroup) *apiservice.BatchWriteResponse + // GetLaneGroups 查询泳道组列表 + GetLaneGroups(ctx context.Context, filter map[string]string) *apiservice.BatchQueryResponse +} + // ResourceHook The listener is placed before and after the resource operation, only normal flow type ResourceHook interface { @@ -63,7 +313,9 @@ type ResourceHook interface { // ResourceEvent 资源事件 type ResourceEvent struct { - ReqService *apiservice.Service - Service *model.Service - IsRemove bool + Resource authcommon.ResourceEntry + + AddPrincipals []authcommon.Principal + DelPrincipals []authcommon.Principal + IsRemove bool } diff --git a/service/api_v1.go b/service/api_v1.go deleted file mode 100644 index 20ac16ccd..000000000 --- a/service/api_v1.go +++ /dev/null @@ -1,266 +0,0 @@ -/** - * Tencent is pleased to support the open source community by making Polaris available. - * - * Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. - * - * Licensed under the BSD 3-Clause License (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://opensource.org/licenses/BSD-3-Clause - * - * Unless required by applicable law or agreed to in writing, software distributed - * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR - * CONDITIONS OF ANY KIND, either express or implied. See the License for the - * specific language governing permissions and limitations under the License. - */ - -package service - -import ( - "context" - - apifault "github.com/polarismesh/specification/source/go/api/v1/fault_tolerance" - apiservice "github.com/polarismesh/specification/source/go/api/v1/service_manage" - apitraffic "github.com/polarismesh/specification/source/go/api/v1/traffic_manage" - - "github.com/polarismesh/polaris/common/api/l5" - "github.com/polarismesh/polaris/common/model" -) - -// CircuitBreakerOperateServer Melting rule related treatment -type CircuitBreakerOperateServer interface { - // CreateCircuitBreakers Create a CircuitBreaker rule - // Deprecated: not support from 1.14.x - CreateCircuitBreakers(ctx context.Context, req []*apifault.CircuitBreaker) *apiservice.BatchWriteResponse - // CreateCircuitBreakerVersions Create a melt rule version - // Deprecated: not support from 1.14.x - CreateCircuitBreakerVersions(ctx context.Context, req []*apifault.CircuitBreaker) *apiservice.BatchWriteResponse - // DeleteCircuitBreakers Delete CircuitBreaker rules - // Deprecated: not support from 1.14.x - DeleteCircuitBreakers(ctx context.Context, req []*apifault.CircuitBreaker) *apiservice.BatchWriteResponse - // UpdateCircuitBreakers Modify the CircuitBreaker rule - // Deprecated: not support from 1.14.x - UpdateCircuitBreakers(ctx context.Context, req []*apifault.CircuitBreaker) *apiservice.BatchWriteResponse - // ReleaseCircuitBreakers Release CircuitBreaker rule - // Deprecated: not support from 1.14.x - ReleaseCircuitBreakers(ctx context.Context, req []*apiservice.ConfigRelease) *apiservice.BatchWriteResponse - // UnBindCircuitBreakers Solution CircuitBreaker rule - // Deprecated: not support from 1.14.x - UnBindCircuitBreakers(ctx context.Context, req []*apiservice.ConfigRelease) *apiservice.BatchWriteResponse - // GetCircuitBreaker Get CircuitBreaker regular according to ID and VERSION - // Deprecated: not support from 1.14.x - GetCircuitBreaker(ctx context.Context, query map[string]string) *apiservice.BatchQueryResponse - // GetCircuitBreakerVersions Query all versions of the CircuitBreaker rule - // Deprecated: not support from 1.14.x - GetCircuitBreakerVersions(ctx context.Context, query map[string]string) *apiservice.BatchQueryResponse - // GetMasterCircuitBreakers Query Master CircuitBreaker rules - // Deprecated: not support from 1.14.x - GetMasterCircuitBreakers(ctx context.Context, query map[string]string) *apiservice.BatchQueryResponse - // GetReleaseCircuitBreakers Query the released CircuitBreaker rule according to the rule ID - // Deprecated: not support from 1.14.x - GetReleaseCircuitBreakers(ctx context.Context, query map[string]string) *apiservice.BatchQueryResponse - // GetCircuitBreakerByService Binding CircuitBreaker rule based on service query - // Deprecated: not support from 1.14.x - GetCircuitBreakerByService(ctx context.Context, query map[string]string) *apiservice.BatchQueryResponse - // GetCircuitBreakerToken Get CircuitBreaker rules token - // Deprecated: not support from 1.14.x - GetCircuitBreakerToken(ctx context.Context, req *apifault.CircuitBreaker) *apiservice.Response - // CreateCircuitBreakerRules Create a CircuitBreaker rule - CreateCircuitBreakerRules(ctx context.Context, request []*apifault.CircuitBreakerRule) *apiservice.BatchWriteResponse - // DeleteCircuitBreakerRules Delete current CircuitBreaker rules - DeleteCircuitBreakerRules(ctx context.Context, request []*apifault.CircuitBreakerRule) *apiservice.BatchWriteResponse - // EnableCircuitBreakerRules Enable the CircuitBreaker rule - EnableCircuitBreakerRules(ctx context.Context, request []*apifault.CircuitBreakerRule) *apiservice.BatchWriteResponse - // UpdateCircuitBreakerRules Modify the CircuitBreaker rule - UpdateCircuitBreakerRules(ctx context.Context, request []*apifault.CircuitBreakerRule) *apiservice.BatchWriteResponse - // GetCircuitBreakerRules Query CircuitBreaker rules - GetCircuitBreakerRules(ctx context.Context, query map[string]string) *apiservice.BatchQueryResponse -} - -// RateLimitOperateServer Lamflow rule related operation -type RateLimitOperateServer interface { - // CreateRateLimits Create a RateLimit rule - CreateRateLimits(ctx context.Context, request []*apitraffic.Rule) *apiservice.BatchWriteResponse - // DeleteRateLimits Delete current RateLimit rules - DeleteRateLimits(ctx context.Context, request []*apitraffic.Rule) *apiservice.BatchWriteResponse - // EnableRateLimits Enable the RateLimit rule - EnableRateLimits(ctx context.Context, request []*apitraffic.Rule) *apiservice.BatchWriteResponse - // UpdateRateLimits Modify the RateLimit rule - UpdateRateLimits(ctx context.Context, request []*apitraffic.Rule) *apiservice.BatchWriteResponse - // GetRateLimits Query RateLimit rules - GetRateLimits(ctx context.Context, query map[string]string) *apiservice.BatchQueryResponse -} - -// RouteRuleOperateServer Routing rules related operations -type RouteRuleOperateServer interface { - // CreateRoutingConfigs Batch creation routing configuration - CreateRoutingConfigs(ctx context.Context, req []*apitraffic.Routing) *apiservice.BatchWriteResponse - // DeleteRoutingConfigs Batch delete routing configuration - DeleteRoutingConfigs(ctx context.Context, req []*apitraffic.Routing) *apiservice.BatchWriteResponse - // UpdateRoutingConfigs Batch update routing configuration - UpdateRoutingConfigs(ctx context.Context, req []*apitraffic.Routing) *apiservice.BatchWriteResponse - // GetRoutingConfigs Inquiry route configuration to OSS - GetRoutingConfigs(ctx context.Context, query map[string]string) *apiservice.BatchQueryResponse -} - -// ServiceOperateServer Service related operations -type ServiceOperateServer interface { - // CreateServices Batch creation service - CreateServices(ctx context.Context, req []*apiservice.Service) *apiservice.BatchWriteResponse - // DeleteServices Batch delete service - DeleteServices(ctx context.Context, req []*apiservice.Service) *apiservice.BatchWriteResponse - // UpdateServices Batch update service - UpdateServices(ctx context.Context, req []*apiservice.Service) *apiservice.BatchWriteResponse - // UpdateServiceToken Update service token - UpdateServiceToken(ctx context.Context, req *apiservice.Service) *apiservice.Response - // GetServices Get a list of service - GetServices(ctx context.Context, query map[string]string) *apiservice.BatchQueryResponse - // GetAllServices Get all service list - GetAllServices(ctx context.Context, query map[string]string) *apiservice.BatchQueryResponse - // GetServicesCount Total number of services - GetServicesCount(ctx context.Context) *apiservice.BatchQueryResponse - // GetServiceToken Get service token - GetServiceToken(ctx context.Context, req *apiservice.Service) *apiservice.Response - // GetServiceOwner Owner for obtaining service - GetServiceOwner(ctx context.Context, req []*apiservice.Service) *apiservice.BatchQueryResponse -} - -// ServiceAliasOperateServer Service alias related operations -type ServiceAliasOperateServer interface { - // CreateServiceAlias Create a service alias - CreateServiceAlias(ctx context.Context, req *apiservice.ServiceAlias) *apiservice.Response - // DeleteServiceAliases Batch delete service alias - DeleteServiceAliases(ctx context.Context, req []*apiservice.ServiceAlias) *apiservice.BatchWriteResponse - // UpdateServiceAlias Update service alias - UpdateServiceAlias(ctx context.Context, req *apiservice.ServiceAlias) *apiservice.Response - // GetServiceAliases Get a list of service alias - GetServiceAliases(ctx context.Context, query map[string]string) *apiservice.BatchQueryResponse -} - -// InstanceOperateServer Example related operations -type InstanceOperateServer interface { - // CreateInstances Batch creation instance - CreateInstances(ctx context.Context, reqs []*apiservice.Instance) *apiservice.BatchWriteResponse - // DeleteInstances Batch delete instance - DeleteInstances(ctx context.Context, req []*apiservice.Instance) *apiservice.BatchWriteResponse - // DeleteInstancesByHost Delete instance according to HOST information batch - DeleteInstancesByHost(ctx context.Context, req []*apiservice.Instance) *apiservice.BatchWriteResponse - // UpdateInstances Batch update instance - UpdateInstances(ctx context.Context, req []*apiservice.Instance) *apiservice.BatchWriteResponse - // UpdateInstancesIsolate Batch update instance isolation state - UpdateInstancesIsolate(ctx context.Context, req []*apiservice.Instance) *apiservice.BatchWriteResponse - // GetInstances Get an instance list - GetInstances(ctx context.Context, query map[string]string) *apiservice.BatchQueryResponse - // GetInstancesCount Get an instance quantity - GetInstancesCount(ctx context.Context) *apiservice.BatchQueryResponse - // GetInstanceLabels Get an instance tag under a service - GetInstanceLabels(ctx context.Context, query map[string]string) *apiservice.Response -} - -// ClientServer Client related operation Client operation interface definition -type ClientServer interface { - // RegisterInstance create one instance by client - RegisterInstance(ctx context.Context, req *apiservice.Instance) *apiservice.Response - // DeregisterInstance delete onr instance by client - DeregisterInstance(ctx context.Context, req *apiservice.Instance) *apiservice.Response - // ReportClient Client gets geographic location information - ReportClient(ctx context.Context, req *apiservice.Client) *apiservice.Response - // GetPrometheusTargets Used to obtain the ReportClient information and serve as the SD result of Prometheus - GetPrometheusTargets(ctx context.Context, query map[string]string) *model.PrometheusDiscoveryResponse - // GetServiceWithCache Used for client acquisition service information - GetServiceWithCache(ctx context.Context, req *apiservice.Service) *apiservice.DiscoverResponse - // ServiceInstancesCache Used for client acquisition service instance information - ServiceInstancesCache(ctx context.Context, filter *apiservice.DiscoverFilter, req *apiservice.Service) *apiservice.DiscoverResponse - // GetRoutingConfigWithCache User Client Get Service Routing Configuration Information - GetRoutingConfigWithCache(ctx context.Context, req *apiservice.Service) *apiservice.DiscoverResponse - // GetRateLimitWithCache User Client Get Service Limit Configuration Information - GetRateLimitWithCache(ctx context.Context, req *apiservice.Service) *apiservice.DiscoverResponse - // GetCircuitBreakerWithCache Fuse configuration information for obtaining services for clients - GetCircuitBreakerWithCache(ctx context.Context, req *apiservice.Service) *apiservice.DiscoverResponse - // GetFaultDetectWithCache User Client Get FaultDetect Rule Information - GetFaultDetectWithCache(ctx context.Context, req *apiservice.Service) *apiservice.DiscoverResponse - // GetServiceContractWithCache User Client Get ServiceContract Rule Information - GetServiceContractWithCache(ctx context.Context, req *apiservice.ServiceContract) *apiservice.Response - // UpdateInstance update one instance by client - UpdateInstance(ctx context.Context, req *apiservice.Instance) *apiservice.Response - // ReportServiceContract client report service_contract - ReportServiceContract(ctx context.Context, req *apiservice.ServiceContract) *apiservice.Response - // GetLaneRuleWithCache fetch lane rules by client - GetLaneRuleWithCache(ctx context.Context, req *apiservice.Service) *apiservice.DiscoverResponse -} - -// L5OperateServer L5 related operations -type L5OperateServer interface { - // SyncByAgentCmd Get routing information according to SID list - SyncByAgentCmd(ctx context.Context, sbac *l5.Cl5SyncByAgentCmd) (*l5.Cl5SyncByAgentAckCmd, error) - // RegisterByNameCmd Look for the corresponding SID list according to the list of service names - RegisterByNameCmd(rbnc *l5.Cl5RegisterByNameCmd) (*l5.Cl5RegisterByNameAckCmd, error) -} - -// ReportClientOperateServer Report information operation interface on the client -type ReportClientOperateServer interface { - // GetReportClients Query the client information reported - GetReportClients(ctx context.Context, query map[string]string) *apiservice.BatchQueryResponse -} - -// RouterRuleOperateServer Routing rules related operations -type RouterRuleOperateServer interface { - // CreateRoutingConfigsV2 Batch creation routing configuration - CreateRoutingConfigsV2(ctx context.Context, req []*apitraffic.RouteRule) *apiservice.BatchWriteResponse - // DeleteRoutingConfigsV2 Batch delete routing configuration - DeleteRoutingConfigsV2(ctx context.Context, req []*apitraffic.RouteRule) *apiservice.BatchWriteResponse - // UpdateRoutingConfigsV2 Batch update routing configuration - UpdateRoutingConfigsV2(ctx context.Context, req []*apitraffic.RouteRule) *apiservice.BatchWriteResponse - // QueryRoutingConfigsV2 Inquiry route configuration to OSS - QueryRoutingConfigsV2(ctx context.Context, query map[string]string) *apiservice.BatchQueryResponse - // EnableRoutings batch enable routing rules - EnableRoutings(ctx context.Context, req []*apitraffic.RouteRule) *apiservice.BatchWriteResponse -} - -// FaultDetectRuleOperateServer Fault detect rules related operations -type FaultDetectRuleOperateServer interface { - // CreateFaultDetectRules create the fault detect rule by request - CreateFaultDetectRules(ctx context.Context, request []*apifault.FaultDetectRule) *apiservice.BatchWriteResponse - // DeleteFaultDetectRules delete the fault detect rule by request - DeleteFaultDetectRules(ctx context.Context, request []*apifault.FaultDetectRule) *apiservice.BatchWriteResponse - // UpdateFaultDetectRules update the fault detect rule by request - UpdateFaultDetectRules(ctx context.Context, request []*apifault.FaultDetectRule) *apiservice.BatchWriteResponse - // GetFaultDetectRules get the fault detect rule by request - GetFaultDetectRules(ctx context.Context, query map[string]string) *apiservice.BatchQueryResponse -} - -// ServiceContractOperateServer service contract operations -type ServiceContractOperateServer interface { - // CreateServiceContracts . - CreateServiceContracts(ctx context.Context, req []*apiservice.ServiceContract) *apiservice.BatchWriteResponse - // DeleteServiceContracts . - DeleteServiceContracts(ctx context.Context, req []*apiservice.ServiceContract) *apiservice.BatchWriteResponse - // GetServiceContracts . - GetServiceContracts(ctx context.Context, query map[string]string) *apiservice.BatchQueryResponse - // CreateServiceContractInterfaces . - CreateServiceContractInterfaces(ctx context.Context, contract *apiservice.ServiceContract, - source apiservice.InterfaceDescriptor_Source) *apiservice.Response - // AppendServiceContractInterfaces . - AppendServiceContractInterfaces(ctx context.Context, contract *apiservice.ServiceContract, - source apiservice.InterfaceDescriptor_Source) *apiservice.Response - // DeleteServiceContractInterfaces . - DeleteServiceContractInterfaces(ctx context.Context, contract *apiservice.ServiceContract) *apiservice.Response - // GetServiceContractVersions . - GetServiceContractVersions(ctx context.Context, filter map[string]string) *apiservice.BatchQueryResponse -} - -type DiscoverServerV1 interface { - // CircuitBreakerOperateServer Fuse rule operation interface definition - CircuitBreakerOperateServer - // RateLimitOperateServer Lamflow rule operation interface definition - RateLimitOperateServer - // RouteRuleOperateServer Routing rules operation interface definition - RouteRuleOperateServer - // RouterRuleOperateServer Routing rules operation interface definition - RouterRuleOperateServer - // FaultDetectRuleOperateServer fault detect rules operation interface definition - FaultDetectRuleOperateServer - // ServiceContractOperateServer service contract rules operation inerface definition - ServiceContractOperateServer -} diff --git a/service/circuitbreaker_rule.go b/service/circuitbreaker_rule.go index fbc54b9f4..8b3045362 100644 --- a/service/circuitbreaker_rule.go +++ b/service/circuitbreaker_rule.go @@ -27,10 +27,13 @@ import ( "github.com/golang/protobuf/ptypes/wrappers" apifault "github.com/polarismesh/specification/source/go/api/v1/fault_tolerance" apimodel "github.com/polarismesh/specification/source/go/api/v1/model" + "github.com/polarismesh/specification/source/go/api/v1/security" apiservice "github.com/polarismesh/specification/source/go/api/v1/service_manage" + "go.uber.org/zap" api "github.com/polarismesh/polaris/common/api/v1" "github.com/polarismesh/polaris/common/model" + authcommon "github.com/polarismesh/polaris/common/model/auth" commonstore "github.com/polarismesh/polaris/common/store" commontime "github.com/polarismesh/polaris/common/time" "github.com/polarismesh/polaris/common/utils" @@ -50,20 +53,15 @@ func (s *Server) CreateCircuitBreakerRules( // CreateCircuitBreakerRule Create a CircuitBreaker rule func (s *Server) createCircuitBreakerRule( ctx context.Context, request *apifault.CircuitBreakerRule) *apiservice.Response { - requestID := utils.ParseRequestID(ctx) - if resp := checkCircuitBreakerRuleParams(request, false, true); resp != nil { - return resp - } - // 构造底层数据结构 data, err := api2CircuitBreakerRule(request) if err != nil { - log.Error(err.Error(), utils.ZapRequestID(requestID)) + log.Error(err.Error(), utils.RequestID(ctx)) return api.NewResponse(apimodel.Code_ParseCircuitBreakerException) } exists, err := s.storage.HasCircuitBreakerRuleByName(data.Name, data.Namespace) if err != nil { - log.Error(err.Error(), utils.ZapRequestID(requestID)) + log.Error(err.Error(), utils.RequestID(ctx)) return api.NewResponseWithMsg(commonstore.StoreCode2APICode(err), err.Error()) } if exists { @@ -73,96 +71,23 @@ func (s *Server) createCircuitBreakerRule( // 存储层操作 if err := s.storage.CreateCircuitBreakerRule(data); err != nil { - log.Error(err.Error(), utils.ZapRequestID(requestID)) + log.Error(err.Error(), utils.RequestID(ctx)) return api.NewResponseWithMsg(commonstore.StoreCode2APICode(err), err.Error()) } msg := fmt.Sprintf("create circuitBreaker rule: id=%v, name=%v, namespace=%v", data.ID, request.GetName(), request.GetNamespace()) - log.Info(msg, utils.ZapRequestID(requestID)) + log.Info(msg, utils.RequestID(ctx)) s.RecordHistory(ctx, circuitBreakerRuleRecordEntry(ctx, request, data, model.OCreate)) - + _ = s.afterRuleResource(ctx, model.RRouting, authcommon.ResourceEntry{ + ID: request.GetId(), + Type: security.ResourceType_CircuitBreakerRules, + }, false) request.Id = data.ID return api.NewAnyDataResponse(apimodel.Code_ExecuteSuccess, request) } -func checkCircuitBreakerRuleParams( - req *apifault.CircuitBreakerRule, idRequired bool, nameRequired bool) *apiservice.Response { - if req == nil { - return api.NewResponse(apimodel.Code_EmptyRequest) - } - if resp := checkCircuitBreakerRuleParamsDbLen(req); nil != resp { - return resp - } - if nameRequired && len(req.GetName()) == 0 { - return api.NewResponse(apimodel.Code_InvalidCircuitBreakerName) - } - if idRequired && len(req.GetId()) == 0 { - return api.NewResponse(apimodel.Code_InvalidCircuitBreakerID) - } - return nil -} - -func checkCircuitBreakerRuleParamsDbLen(req *apifault.CircuitBreakerRule) *apiservice.Response { - if err := utils.CheckDbRawStrFieldLen( - req.RuleMatcher.GetSource().GetService(), MaxDbServiceNameLength); err != nil { - return api.NewResponse(apimodel.Code_InvalidServiceName) - } - if err := utils.CheckDbRawStrFieldLen( - req.RuleMatcher.GetSource().GetNamespace(), MaxDbServiceNamespaceLength); err != nil { - return api.NewResponse(apimodel.Code_InvalidNamespaceName) - } - if err := utils.CheckDbRawStrFieldLen(req.GetName(), MaxRuleName); err != nil { - return api.NewResponse(apimodel.Code_InvalidCircuitBreakerName) - } - if err := utils.CheckDbRawStrFieldLen(req.GetNamespace(), MaxDbServiceNamespaceLength); err != nil { - return api.NewResponse(apimodel.Code_InvalidNamespaceName) - } - if err := utils.CheckDbRawStrFieldLen(req.GetDescription(), MaxCommentLength); err != nil { - return api.NewResponse(apimodel.Code_InvalidServiceComment) - } - return nil -} - -func circuitBreakerRuleRecordEntry(ctx context.Context, req *apifault.CircuitBreakerRule, md *model.CircuitBreakerRule, - opt model.OperationType) *model.RecordEntry { - marshaler := jsonpb.Marshaler{} - detail, _ := marshaler.MarshalToString(req) - entry := &model.RecordEntry{ - ResourceType: model.RCircuitBreakerRule, - ResourceName: fmt.Sprintf("%s(%s)", md.Name, md.ID), - Namespace: req.GetNamespace(), - OperationType: opt, - Operator: utils.ParseOperator(ctx), - Detail: detail, - HappenTime: time.Now(), - } - return entry -} - -var ( - // CircuitBreakerRuleFilters filter circuitbreaker rule query parameters - CircuitBreakerRuleFilters = map[string]bool{ - "brief": true, - "offset": true, - "limit": true, - "id": true, - "name": true, - "namespace": true, - "enable": true, - "level": true, - "service": true, - "serviceNamespace": true, - "srcService": true, - "srcNamespace": true, - "dstService": true, - "dstNamespace": true, - "dstMethod": true, - "description": true, - } -) - // DeleteCircuitBreakerRules Delete current CircuitBreaker rules func (s *Server) DeleteCircuitBreakerRules( ctx context.Context, request []*apifault.CircuitBreakerRule) *apiservice.BatchWriteResponse { @@ -177,11 +102,7 @@ func (s *Server) DeleteCircuitBreakerRules( // deleteCircuitBreakerRule delete current CircuitBreaker rule func (s *Server) deleteCircuitBreakerRule( ctx context.Context, request *apifault.CircuitBreakerRule) *apiservice.Response { - requestID := utils.ParseRequestID(ctx) - if resp := checkCircuitBreakerRuleParams(request, true, false); resp != nil { - return resp - } - resp := s.checkCircuitBreakerRuleExists(request.GetId(), requestID) + resp := s.checkCircuitBreakerRuleExists(ctx, request.GetId()) if resp != nil { if resp.GetCode().GetValue() == uint32(apimodel.Code_NotFoundCircuitBreaker) { resp.Code = &wrappers.UInt32Value{Value: uint32(apimodel.Code_ExecuteSuccess)} @@ -191,16 +112,20 @@ func (s *Server) deleteCircuitBreakerRule( cbRuleId := &apifault.CircuitBreakerRule{Id: request.GetId()} err := s.storage.DeleteCircuitBreakerRule(request.GetId()) if err != nil { - log.Error(err.Error(), utils.ZapRequestID(requestID)) + log.Error(err.Error(), utils.RequestID(ctx)) return api.NewAnyDataResponse(apimodel.Code_ParseCircuitBreakerException, cbRuleId) } msg := fmt.Sprintf("delete circuitbreaker rule: id=%v, name=%v, namespace=%v", request.GetId(), request.GetName(), request.GetNamespace()) - log.Info(msg, utils.ZapRequestID(requestID)) + log.Info(msg, utils.RequestID(ctx)) cbRule := &model.CircuitBreakerRule{ ID: request.GetId(), Name: request.GetName(), Namespace: request.GetNamespace()} s.RecordHistory(ctx, circuitBreakerRuleRecordEntry(ctx, request, cbRule, model.ODelete)) + _ = s.afterRuleResource(ctx, model.RRouting, authcommon.ResourceEntry{ + ID: request.GetId(), + Type: security.ResourceType_CircuitBreakerRules, + }, true) return api.NewAnyDataResponse(apimodel.Code_ExecuteSuccess, cbRuleId) } @@ -217,11 +142,7 @@ func (s *Server) EnableCircuitBreakerRules( func (s *Server) enableCircuitBreakerRule( ctx context.Context, request *apifault.CircuitBreakerRule) *apiservice.Response { - requestID := utils.ParseRequestID(ctx) - if resp := checkCircuitBreakerRuleParams(request, true, false); resp != nil { - return resp - } - resp := s.checkCircuitBreakerRuleExists(request.GetId(), requestID) + resp := s.checkCircuitBreakerRuleExists(ctx, request.GetId()) if resp != nil { return resp } @@ -234,13 +155,13 @@ func (s *Server) enableCircuitBreakerRule( Revision: utils.NewUUID(), } if err := s.storage.EnableCircuitBreakerRule(cbRule); err != nil { - log.Error(err.Error(), utils.ZapRequestID(requestID)) + log.Error(err.Error(), utils.RequestID(ctx)) return storeError2AnyResponse(err, cbRuleId) } msg := fmt.Sprintf("enable circuitbreaker rule: id=%v, name=%v, namespace=%v", request.GetId(), request.GetName(), request.GetNamespace()) - log.Info(msg, utils.ZapRequestID(requestID)) + log.Info(msg, utils.RequestID(ctx)) s.RecordHistory(ctx, circuitBreakerRuleRecordEntry(ctx, request, cbRule, model.OUpdate)) return api.NewAnyDataResponse(apimodel.Code_ExecuteSuccess, cbRuleId) @@ -259,46 +180,42 @@ func (s *Server) UpdateCircuitBreakerRules( func (s *Server) updateCircuitBreakerRule( ctx context.Context, request *apifault.CircuitBreakerRule) *apiservice.Response { - requestID := utils.ParseRequestID(ctx) - if resp := checkCircuitBreakerRuleParams(request, true, true); resp != nil { - return resp - } - resp := s.checkCircuitBreakerRuleExists(request.GetId(), requestID) + resp := s.checkCircuitBreakerRuleExists(ctx, request.GetId()) if resp != nil { return resp } cbRuleId := &apifault.CircuitBreakerRule{Id: request.GetId()} cbRule, err := api2CircuitBreakerRule(request) if err != nil { - log.Error(err.Error(), utils.ZapRequestID(requestID)) + log.Error(err.Error(), utils.RequestID(ctx)) return api.NewAnyDataResponse(apimodel.Code_ParseCircuitBreakerException, cbRuleId) } cbRule.ID = request.GetId() exists, err := s.storage.HasCircuitBreakerRuleByNameExcludeId(cbRule.Name, cbRule.Namespace, cbRule.ID) if err != nil { - log.Error(err.Error(), utils.ZapRequestID(requestID)) + log.Error(err.Error(), utils.RequestID(ctx)) return api.NewResponseWithMsg(commonstore.StoreCode2APICode(err), err.Error()) } if exists { return api.NewResponse(apimodel.Code_ServiceExistedCircuitBreakers) } if err := s.storage.UpdateCircuitBreakerRule(cbRule); err != nil { - log.Error(err.Error(), utils.ZapRequestID(requestID)) + log.Error(err.Error(), utils.RequestID(ctx)) return storeError2AnyResponse(err, cbRuleId) } msg := fmt.Sprintf("update circuitbreaker rule: id=%v, name=%v, namespace=%v", request.GetId(), request.GetName(), request.GetNamespace()) - log.Info(msg, utils.ZapRequestID(requestID)) + log.Info(msg, utils.RequestID(ctx)) s.RecordHistory(ctx, circuitBreakerRuleRecordEntry(ctx, request, cbRule, model.OUpdate)) return api.NewAnyDataResponse(apimodel.Code_ExecuteSuccess, cbRuleId) } -func (s *Server) checkCircuitBreakerRuleExists(id, requestID string) *apiservice.Response { +func (s *Server) checkCircuitBreakerRuleExists(ctx context.Context, id string) *apiservice.Response { exists, err := s.storage.HasCircuitBreakerRule(id) if err != nil { - log.Error(err.Error(), utils.ZapRequestID(requestID)) + log.Error(err.Error(), utils.RequestID(ctx)) return api.NewResponse(commonstore.StoreCode2APICode(err)) } if !exists { @@ -309,24 +226,10 @@ func (s *Server) checkCircuitBreakerRuleExists(id, requestID string) *apiservice // GetCircuitBreakerRules Query CircuitBreaker rules func (s *Server) GetCircuitBreakerRules(ctx context.Context, query map[string]string) *apiservice.BatchQueryResponse { - offset, limit, err := utils.ParseOffsetAndLimit(query) + offset, limit, _ := utils.ParseOffsetAndLimit(query) + total, cbRules, err := s.storage.GetCircuitBreakerRules(query, offset, limit) if err != nil { - return api.NewBatchQueryResponse(apimodel.Code_InvalidParameter) - } - searchFilter := make(map[string]string, len(query)) - for key, value := range query { - if _, ok := CircuitBreakerRuleFilters[key]; !ok { - log.Errorf("params %s is not allowed in querying circuitbreaker rule", key) - return api.NewBatchQueryResponse(apimodel.Code_InvalidParameter) - } - if value == "" { - continue - } - searchFilter[key] = value - } - total, cbRules, err := s.storage.GetCircuitBreakerRules(searchFilter, offset, limit) - if err != nil { - log.Errorf("get circuitbreaker rules store err: %s", err.Error()) + log.Error("get circuitbreaker rules store", utils.RequestID(ctx), zap.Error(err)) return api.NewBatchQueryResponse(commonstore.StoreCode2APICode(err)) } out := api.NewBatchQueryResponse(apimodel.Code_ExecuteSuccess) @@ -335,7 +238,7 @@ func (s *Server) GetCircuitBreakerRules(ctx context.Context, query map[string]st for _, cbRule := range cbRules { cbRuleProto, err := circuitBreakerRule2api(cbRule) if nil != err { - log.Errorf("marshal circuitbreaker rule fail: %v", err) + log.Error("marshal circuitbreaker rule fail", utils.RequestID(ctx), zap.Error(err)) continue } if nil == cbRuleProto { @@ -343,13 +246,34 @@ func (s *Server) GetCircuitBreakerRules(ctx context.Context, query map[string]st } err = api.AddAnyDataIntoBatchQuery(out, cbRuleProto) if nil != err { - log.Errorf("add circuitbreaker rule as any data fail: %v", err) + log.Error("add circuitbreaker rule as any data fail", utils.RequestID(ctx), zap.Error(err)) continue } } return out } +// GetAllCircuitBreakerRules Query all router_rule rules +func (s *Server) GetAllCircuitBreakerRules(ctx context.Context) *apiservice.BatchQueryResponse { + return nil +} + +func circuitBreakerRuleRecordEntry(ctx context.Context, req *apifault.CircuitBreakerRule, md *model.CircuitBreakerRule, + opt model.OperationType) *model.RecordEntry { + marshaler := jsonpb.Marshaler{} + detail, _ := marshaler.MarshalToString(req) + entry := &model.RecordEntry{ + ResourceType: model.RCircuitBreakerRule, + ResourceName: fmt.Sprintf("%s(%s)", md.Name, md.ID), + Namespace: req.GetNamespace(), + OperationType: opt, + Operator: utils.ParseOperator(ctx), + Detail: detail, + HappenTime: time.Now(), + } + return entry +} + func marshalCircuitBreakerRuleV2(req *apifault.CircuitBreakerRule) (string, error) { r := &apifault.CircuitBreakerRule{ RuleMatcher: req.RuleMatcher, diff --git a/service/circuitbreaker_rule_test.go b/service/circuitbreaker_rule_test.go index 07d9df31a..9e8ec5bab 100644 --- a/service/circuitbreaker_rule_test.go +++ b/service/circuitbreaker_rule_test.go @@ -316,7 +316,7 @@ func TestUpdateCircuitBreakerRule(t *testing.T) { testRule := cbRules[0] testRule.Description = mockDescr resp = discoverSuit.DiscoverServer().UpdateCircuitBreakerRules(discoverSuit.DefaultCtx, []*apifault.CircuitBreakerRule{testRule}) - assert.Equal(t, uint32(apimodel.Code_ExecuteSuccess), resp.GetCode().GetValue()) + assert.Equal(t, uint32(apimodel.Code_ExecuteSuccess), resp.GetCode().GetValue(), resp.GetInfo().GetValue()) qResp = queryCircuitBreakerRules(discoverSuit, map[string]string{"id": testRule.Id}) assert.Equal(t, uint32(apimodel.Code_ExecuteSuccess), qResp.GetCode().GetValue()) diff --git a/service/client_info.go b/service/client_info.go index bf2ea3854..105b79e57 100644 --- a/service/client_info.go +++ b/service/client_info.go @@ -80,12 +80,9 @@ func (s *Server) createClient(ctx context.Context, req *apiservice.Client) (*mod // req 原始请求 // ins 包含了req数据与instanceID,serviceToken func (s *Server) asyncCreateClient(ctx context.Context, req *apiservice.Client) (*model.Client, *apiservice.Response) { - rid := utils.ParseRequestID(ctx) - pid := utils.ParsePlatformID(ctx) future := s.bc.AsyncRegisterClient(req) if err := future.Wait(); err != nil { - log.Error("[Server][ReportClient] async create client", zap.Error(err), utils.ZapRequestID(rid), - utils.ZapPlatformID(pid)) + log.Error("[Server][ReportClient] async create client", zap.Error(err), utils.RequestID(ctx)) if future.Code() == apimodel.Code_ExistedResource { req.Id = utils.NewStringValue(req.GetId().GetValue()) } diff --git a/service/client_v1.go b/service/client_v1.go index ec5488852..03924acc2 100644 --- a/service/client_v1.go +++ b/service/client_v1.go @@ -175,9 +175,9 @@ func (s *Server) GetServiceWithCache(ctx context.Context, req *apiservice.Servic ) if req.GetNamespace().GetValue() != "" { - revision, svcs = s.Cache().Service().ListServices(req.GetNamespace().GetValue()) + revision, svcs = s.Cache().Service().ListServices(ctx, req.GetNamespace().GetValue()) } else { - revision, svcs = s.Cache().Service().ListAllServices() + revision, svcs = s.Cache().Service().ListAllServices(ctx) } if revision == "" { return resp @@ -226,12 +226,29 @@ func (s *Server) ServiceInstancesCache(ctx context.Context, filter *apiservice.D revisions := make([]string, 0, len(visibleServices)+1) finalInstances := make(map[string]*apiservice.Instance, 128) for _, svc := range visibleServices { + specSvc := &apiservice.Service{ + Id: utils.NewStringValue(svc.ID), + Name: utils.NewStringValue(svc.Name), + Namespace: utils.NewStringValue(svc.Namespace), + } + ret := s.caches.Instance().DiscoverServiceInstances(specSvc.GetId().GetValue(), filter.GetOnlyHealthyInstance()) + // 如果是空实例,则直接跳过,不处理实例列表以及 revision 信息 + if len(ret) == 0 { + continue + } revision := s.caches.Service().GetRevisionWorker().GetServiceInstanceRevision(svc.ID) if revision == "" { revision = utils.NewUUID() } revisions = append(revisions, revision) + + for i := range ret { + copyIns := s.getInstance(specSvc, ret[i].Proto) + // 注意:这里的value是cache的,不修改cache的数据,通过getInstance,浅拷贝一份数据 + finalInstances[copyIns.GetId().GetValue()] = copyIns + } } + aggregateRevision, err := cachetypes.CompositeComputeRevision(revisions) if err != nil { log.Errorf("[Server][Service][Instance] compute multi revision service(%s) err: %s", @@ -242,20 +259,6 @@ func (s *Server) ServiceInstancesCache(ctx context.Context, filter *apiservice.D return api.NewDiscoverInstanceResponse(apimodel.Code_DataNoChange, req) } - for _, svc := range visibleServices { - specSvc := &apiservice.Service{ - Id: utils.NewStringValue(svc.ID), - Name: utils.NewStringValue(svc.Name), - Namespace: utils.NewStringValue(svc.Namespace), - } - ret := s.caches.Instance().DiscoverServiceInstances(specSvc.GetId().GetValue(), filter.GetOnlyHealthyInstance()) - for i := range ret { - copyIns := s.getInstance(specSvc, ret[i].Proto) - // 注意:这里的value是cache的,不修改cache的数据,通过getInstance,浅拷贝一份数据 - finalInstances[copyIns.GetId().GetValue()] = copyIns - } - } - // 填充service数据 resp.Service = service2Api(aliasFor) // 这里需要把服务信息改为用户请求的服务名以及命名空间 @@ -277,20 +280,13 @@ func (s *Server) findVisibleServices(serviceName, namespaceName string, req *api visibleServices := make([]*model.Service, 0, 4) // 数据源都来自Cache,这里拿到的service,已经是源服务 aliasFor := s.getServiceCache(serviceName, namespaceName) - if aliasFor == nil { - aliasFor = &model.Service{ - Name: serviceName, - Namespace: namespaceName, - } - ret := s.caches.Service().GetVisibleServicesInOtherNamespace(serviceName, namespaceName) - if len(ret) == 0 { - return nil, nil - } - visibleServices = append(visibleServices, ret...) - } else { + if aliasFor != nil { visibleServices = append(visibleServices, aliasFor) } - + ret := s.caches.Service().GetVisibleServicesInOtherNamespace(serviceName, namespaceName) + if len(ret) > 0 { + visibleServices = append(visibleServices, ret...) + } return aliasFor, visibleServices } diff --git a/service/common_test.go b/service/common_test.go index 56262c522..c17ddc342 100644 --- a/service/common_test.go +++ b/service/common_test.go @@ -350,7 +350,7 @@ func (d *DiscoverTestSuit) createCommonRoutingConfig( // TODO 是否应该先删除routing resp := d.DiscoverServer().CreateRoutingConfigs(d.DefaultCtx, []*apitraffic.Routing{conf}) - if !respSuccess(resp) { + if respSuccess(resp) { t.Fatalf("error: %+v", resp) } @@ -400,7 +400,7 @@ func (d *DiscoverTestSuit) createCommonRoutingConfigV1IntoOldStore(t *testing.T, } resp := d.OriginDiscoverServer().(*service.Server).CreateRoutingConfig(d.DefaultCtx, conf) - if !respSuccess(resp) { + if respSuccess(resp) { t.Fatalf("error: %+v", resp) } diff --git a/service/default.go b/service/default.go index 5609dc9e9..1efa4bc14 100644 --- a/service/default.go +++ b/service/default.go @@ -28,6 +28,7 @@ import ( "github.com/polarismesh/polaris/common/eventhub" "github.com/polarismesh/polaris/common/model" "github.com/polarismesh/polaris/plugin" + "github.com/polarismesh/polaris/store" ) const ( @@ -48,7 +49,7 @@ const ( DefaultTLL = 5 ) -type ServerProxyFactory func(pre DiscoverServer) (DiscoverServer, error) +type ServerProxyFactory func(pre DiscoverServer, s store.Store) (DiscoverServer, error) var ( server DiscoverServer @@ -135,7 +136,7 @@ func InitServer(ctx context.Context, namingOpt *Config, opts ...InitOption) (*Se return nil, nil, fmt.Errorf("name(%s) not exist in serverProxyFactories", order[i]) } - afterSvr, err := factory(proxySvr) + afterSvr, err := factory(proxySvr, actualSvr.storage) if err != nil { return nil, nil, err } @@ -144,6 +145,10 @@ func InitServer(ctx context.Context, namingOpt *Config, opts ...InitOption) (*Se return actualSvr, proxySvr, nil } +func (svr *Server) Initialize(context.Context, store.Store) error { + return nil +} + type PluginInstanceEventHandler struct { *BaseInstanceEventHandler subscriber plugin.DiscoverChannel diff --git a/service/faultdetect_config.go b/service/faultdetect_config.go index 32ab8d5de..23305de47 100644 --- a/service/faultdetect_config.go +++ b/service/faultdetect_config.go @@ -27,12 +27,14 @@ import ( "github.com/golang/protobuf/ptypes/wrappers" apifault "github.com/polarismesh/specification/source/go/api/v1/fault_tolerance" apimodel "github.com/polarismesh/specification/source/go/api/v1/model" + "github.com/polarismesh/specification/source/go/api/v1/security" apiservice "github.com/polarismesh/specification/source/go/api/v1/service_manage" "go.uber.org/zap" cachetypes "github.com/polarismesh/polaris/cache/api" api "github.com/polarismesh/polaris/common/api/v1" "github.com/polarismesh/polaris/common/model" + authcommon "github.com/polarismesh/polaris/common/model/auth" commonstore "github.com/polarismesh/polaris/common/store" commontime "github.com/polarismesh/polaris/common/time" "github.com/polarismesh/polaris/common/utils" @@ -40,9 +42,9 @@ import ( // CreateFaultDetectRules Create a FaultDetect rule func (s *Server) CreateFaultDetectRules( - ctx context.Context, request []*apifault.FaultDetectRule) *apiservice.BatchWriteResponse { + ctx context.Context, reqs []*apifault.FaultDetectRule) *apiservice.BatchWriteResponse { responses := api.NewBatchWriteResponse(apimodel.Code_ExecuteSuccess) - for _, cbRule := range request { + for _, cbRule := range reqs { response := s.createFaultDetectRule(ctx, cbRule) api.Collect(responses, response) } @@ -51,10 +53,10 @@ func (s *Server) CreateFaultDetectRules( // DeleteFaultDetectRules Delete current Fault Detect rules func (s *Server) DeleteFaultDetectRules( - ctx context.Context, request []*apifault.FaultDetectRule) *apiservice.BatchWriteResponse { + ctx context.Context, reqs []*apifault.FaultDetectRule) *apiservice.BatchWriteResponse { responses := api.NewBatchWriteResponse(apimodel.Code_ExecuteSuccess) - for _, cbRule := range request { + for _, cbRule := range reqs { response := s.deleteFaultDetectRule(ctx, cbRule) api.Collect(responses, response) } @@ -63,10 +65,10 @@ func (s *Server) DeleteFaultDetectRules( // UpdateFaultDetectRules Modify the FaultDetect rule func (s *Server) UpdateFaultDetectRules( - ctx context.Context, request []*apifault.FaultDetectRule) *apiservice.BatchWriteResponse { + ctx context.Context, reqs []*apifault.FaultDetectRule) *apiservice.BatchWriteResponse { responses := api.NewBatchWriteResponse(apimodel.Code_ExecuteSuccess) - for _, cbRule := range request { + for _, cbRule := range reqs { response := s.updateFaultDetectRule(ctx, cbRule) api.Collect(responses, response) } @@ -117,7 +119,10 @@ func (s *Server) createFaultDetectRule(ctx context.Context, request *apifault.Fa log.Info(msg, utils.RequestID(ctx)) s.RecordHistory(ctx, faultDetectRuleRecordEntry(ctx, request, data, model.OCreate)) - + _ = s.afterRuleResource(ctx, model.RRouting, authcommon.ResourceEntry{ + ID: request.GetId(), + Type: security.ResourceType_FaultDetectRules, + }, false) request.Id = data.ID return api.NewAnyDataResponse(apimodel.Code_ExecuteSuccess, request) } @@ -154,54 +159,31 @@ func (s *Server) updateFaultDetectRule(ctx context.Context, request *apifault.Fa // deleteFaultDetectRule Delete a FaultDetect rule func (s *Server) deleteFaultDetectRule(ctx context.Context, request *apifault.FaultDetectRule) *apiservice.Response { - requestID := utils.ParseRequestID(ctx) - resp := s.checkFaultDetectRuleExists(request.GetId(), requestID) - if resp != nil { - if resp.GetCode().GetValue() == uint32(apimodel.Code_NotFoundResource) { - resp.Code = &wrappers.UInt32Value{Value: uint32(apimodel.Code_ExecuteSuccess)} - } - return resp - } cbRuleId := &apifault.FaultDetectRule{Id: request.GetId()} err := s.storage.DeleteFaultDetectRule(request.GetId()) if err != nil { - log.Error(err.Error(), utils.ZapRequestID(requestID)) + log.Error(err.Error(), utils.RequestID(ctx)) return api.NewAnyDataResponse(apimodel.Code_ParseException, cbRuleId) } msg := fmt.Sprintf("delete fault detect rule: id=%v, name=%v, namespace=%v", request.GetId(), request.GetName(), request.GetNamespace()) - log.Info(msg, utils.ZapRequestID(requestID)) + log.Info(msg, utils.RequestID(ctx)) cbRule := &model.FaultDetectRule{ID: request.GetId(), Name: request.GetName(), Namespace: request.GetNamespace()} s.RecordHistory(ctx, faultDetectRuleRecordEntry(ctx, request, cbRule, model.ODelete)) + _ = s.afterRuleResource(ctx, model.RRouting, authcommon.ResourceEntry{ + ID: request.GetId(), + Type: security.ResourceType_FaultDetectRules, + }, true) return api.NewAnyDataResponse(apimodel.Code_ExecuteSuccess, cbRuleId) } -func (s *Server) checkFaultDetectRuleExists(id, requestID string) *apiservice.Response { - exists, err := s.storage.HasFaultDetectRule(id) - if err != nil { - log.Error(err.Error(), utils.ZapRequestID(requestID)) - return api.NewResponse(commonstore.StoreCode2APICode(err)) - } - if !exists { - return api.NewResponse(apimodel.Code_NotFoundResource) - } - return nil -} - func (s *Server) GetFaultDetectRules(ctx context.Context, query map[string]string) *apiservice.BatchQueryResponse { offset, limit, _ := utils.ParseOffsetAndLimit(query) total, cbRules, err := s.caches.FaultDetector().Query(ctx, &cachetypes.FaultDetectArgs{ - ID: query["id"], - Name: query["name"], - Namespace: query["namespace"], - Service: query["service"], - ServiceNamespace: query["serviceNamespace"], - DstNamespace: query["dstNamespace"], - DstService: query["dstService"], - DstMethod: query["dstMethod"], - Offset: offset, - Limit: limit, + Filter: query, + Offset: offset, + Limit: limit, }) if err != nil { log.Errorf("get fault detect rules store err: %s", err.Error()) @@ -219,8 +201,7 @@ func (s *Server) GetFaultDetectRules(ctx context.Context, query map[string]strin if nil == cbRuleProto { continue } - err = api.AddAnyDataIntoBatchQuery(out, cbRuleProto) - if nil != err { + if err = api.AddAnyDataIntoBatchQuery(out, cbRuleProto); nil != err { log.Error("add circuitbreaker rule as any data fail", utils.RequestID(ctx), zap.Error(err)) continue } @@ -228,6 +209,11 @@ func (s *Server) GetFaultDetectRules(ctx context.Context, query map[string]strin return out } +// GetAllFaultDetectRules Query all router_rule rules +func (s *Server) GetAllFaultDetectRules(ctx context.Context) *apiservice.BatchQueryResponse { + return nil +} + func marshalFaultDetectRule(req *apifault.FaultDetectRule) (string, error) { r := &apifault.FaultDetectRule{ TargetService: req.TargetService, @@ -262,6 +248,7 @@ func api2FaultDetectRule(req *apifault.FaultDetectRule) (*model.FaultDetectRule, DstMethod: req.GetTargetService().GetMethod().GetValue().GetValue(), Rule: rule, Revision: utils.NewUUID(), + Metadata: req.Metadata, } if out.Namespace == "" { out.Namespace = DefaultNamespace @@ -273,27 +260,29 @@ func faultDetectRule2api(fdRule *model.FaultDetectRule) (*apifault.FaultDetectRu if fdRule == nil { return nil, nil } - fdRule.Proto = &apifault.FaultDetectRule{} + specData := &apifault.FaultDetectRule{} if len(fdRule.Rule) > 0 { - if err := json.Unmarshal([]byte(fdRule.Rule), fdRule.Proto); err != nil { + if err := json.Unmarshal([]byte(fdRule.Rule), specData); err != nil { return nil, err } } else { // brief search, to display the services in list result - fdRule.Proto.TargetService = &apifault.FaultDetectRule_DestinationService{ + specData.TargetService = &apifault.FaultDetectRule_DestinationService{ Service: fdRule.DstService, Namespace: fdRule.DstNamespace, Method: &apimodel.MatchString{Value: &wrappers.StringValue{Value: fdRule.DstMethod}}, } } - fdRule.Proto.Id = fdRule.ID - fdRule.Proto.Name = fdRule.Name - fdRule.Proto.Namespace = fdRule.Namespace - fdRule.Proto.Description = fdRule.Description - fdRule.Proto.Revision = fdRule.Revision - fdRule.Proto.Ctime = commontime.Time2String(fdRule.CreateTime) - fdRule.Proto.Mtime = commontime.Time2String(fdRule.ModifyTime) - return fdRule.Proto, nil + specData.Id = fdRule.ID + specData.Name = fdRule.Name + specData.Namespace = fdRule.Namespace + specData.Description = fdRule.Description + specData.Revision = fdRule.Revision + specData.Ctime = commontime.Time2String(fdRule.CreateTime) + specData.Mtime = commontime.Time2String(fdRule.ModifyTime) + specData.Editable = true + specData.Deleteable = true + return specData, nil } // faultDetectRule2ClientAPI 把内部数据结构转化为客户端API参数 diff --git a/service/instance.go b/service/instance.go index 13681b7f7..aec69fa05 100644 --- a/service/instance.go +++ b/service/instance.go @@ -592,6 +592,10 @@ func (s *Server) updateInstanceAttribute( } func instanceLocationNeedUpdate(req *apimodel.Location, old *apimodel.Location) bool { + // 如果没有带上,则不进行更新 + if req == nil { + return false + } if req.GetRegion().GetValue() != old.GetRegion().GetValue() { return true } @@ -863,7 +867,7 @@ func (s *Server) instanceAuth(ctx context.Context, req *apiservice.Instance, ser *model.Service, *apiservice.Response) { service, err := s.storage.GetServiceByID(serviceID) if err != nil { - log.Error(err.Error(), utils.ZapRequestID(utils.ParseRequestID(ctx))) + log.Error(err.Error(), utils.RequestID(ctx)) return nil, api.NewInstanceResponse(commonstore.StoreCode2APICode(err), req) } if service == nil { diff --git a/service/instance_test.go b/service/instance_test.go index a5d25008b..69d5b2bd2 100644 --- a/service/instance_test.go +++ b/service/instance_test.go @@ -324,15 +324,7 @@ func TestUpdateInstanceManyTimes(t *testing.T) { go func(index int) { defer wg.Done() for c := 0; c < 16; c++ { - marshalVal, err := proto.Marshal(instanceReq) - if err != nil { - errs <- err - return - } - - ret := &apiservice.Instance{} - proto.Unmarshal(marshalVal, ret) - + ret := proto.Clone(instanceReq).(*apiservice.Instance) ret.Weight = wrapperspb.UInt32(uint32(rand.Int() % 32767)) if updateResp := discoverSuit.DiscoverServer().UpdateInstances(discoverSuit.DefaultCtx, []*apiservice.Instance{instanceReq}); !respSuccess(updateResp) { errs <- fmt.Errorf("error: %+v", updateResp) diff --git a/service/interceptor/auth/circuitbreaker_rule.go b/service/interceptor/auth/circuitbreaker_rule.go index 99201c302..99018c85b 100644 --- a/service/interceptor/auth/circuitbreaker_rule.go +++ b/service/interceptor/auth/circuitbreaker_rule.go @@ -22,7 +22,10 @@ import ( apifault "github.com/polarismesh/specification/source/go/api/v1/fault_tolerance" "github.com/polarismesh/specification/source/go/api/v1/security" + apisecurity "github.com/polarismesh/specification/source/go/api/v1/security" apiservice "github.com/polarismesh/specification/source/go/api/v1/service_manage" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/types/known/anypb" cachetypes "github.com/polarismesh/polaris/cache/api" api "github.com/polarismesh/polaris/common/api/v1" @@ -33,7 +36,8 @@ import ( func (svr *Server) CreateCircuitBreakerRules( ctx context.Context, request []*apifault.CircuitBreakerRule) *apiservice.BatchWriteResponse { - authCtx := svr.collectCircuitBreakerRuleV2AuthContext(ctx, request, authcommon.Create, authcommon.CreateCircuitBreakerRules) + authCtx := svr.collectCircuitBreakerRuleV2(ctx, request, authcommon.Create, + authcommon.CreateCircuitBreakerRules) if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { return api.NewBatchWriteResponse(authcommon.ConvertToErrCode(err)) @@ -46,7 +50,8 @@ func (svr *Server) CreateCircuitBreakerRules( func (svr *Server) DeleteCircuitBreakerRules( ctx context.Context, request []*apifault.CircuitBreakerRule) *apiservice.BatchWriteResponse { - authCtx := svr.collectCircuitBreakerRuleV2AuthContext(ctx, request, authcommon.Delete, authcommon.DeleteCircuitBreakerRules) + authCtx := svr.collectCircuitBreakerRuleV2(ctx, request, authcommon.Delete, + authcommon.DeleteCircuitBreakerRules) if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { return api.NewBatchWriteResponse(authcommon.ConvertToErrCode(err)) } @@ -58,7 +63,8 @@ func (svr *Server) DeleteCircuitBreakerRules( func (svr *Server) EnableCircuitBreakerRules( ctx context.Context, request []*apifault.CircuitBreakerRule) *apiservice.BatchWriteResponse { - authCtx := svr.collectCircuitBreakerRuleV2AuthContext(ctx, request, authcommon.Modify, authcommon.EnableCircuitBreakerRules) + authCtx := svr.collectCircuitBreakerRuleV2(ctx, request, authcommon.Modify, + authcommon.EnableCircuitBreakerRules) if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { return api.NewBatchWriteResponse(authcommon.ConvertToErrCode(err)) } @@ -70,7 +76,8 @@ func (svr *Server) EnableCircuitBreakerRules( func (svr *Server) UpdateCircuitBreakerRules( ctx context.Context, request []*apifault.CircuitBreakerRule) *apiservice.BatchWriteResponse { - authCtx := svr.collectCircuitBreakerRuleV2AuthContext(ctx, request, authcommon.Modify, authcommon.UpdateCircuitBreakerRules) + authCtx := svr.collectCircuitBreakerRuleV2(ctx, request, authcommon.Modify, + authcommon.UpdateCircuitBreakerRules) if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { return api.NewBatchWriteResponse(authcommon.ConvertToErrCode(err)) } @@ -82,7 +89,8 @@ func (svr *Server) UpdateCircuitBreakerRules( func (svr *Server) GetCircuitBreakerRules( ctx context.Context, query map[string]string) *apiservice.BatchQueryResponse { - authCtx := svr.collectCircuitBreakerRuleV2AuthContext(ctx, nil, authcommon.Read, authcommon.DescribeCircuitBreakerRules) + authCtx := svr.collectCircuitBreakerRuleV2(ctx, nil, authcommon.Read, + authcommon.DescribeCircuitBreakerRules) if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { return api.NewBatchQueryResponse(authcommon.ConvertToErrCode(err)) } @@ -90,13 +98,48 @@ func (svr *Server) GetCircuitBreakerRules( ctx = authCtx.GetRequestContext() ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) - cachetypes.AppendCircuitBreakerRulePredicate(ctx, func(ctx context.Context, cbr *model.CircuitBreakerRule) bool { - return svr.policySvr.GetAuthChecker().ResourcePredicate(authCtx, &authcommon.ResourceEntry{ - Type: security.ResourceType_CircuitBreakerRules, - ID: cbr.ID, - Metadata: cbr.Proto.Metadata, + ctx = cachetypes.AppendCircuitBreakerRulePredicate(ctx, + func(ctx context.Context, cbr *model.CircuitBreakerRule) bool { + return svr.policySvr.GetAuthChecker().ResourcePredicate(authCtx, &authcommon.ResourceEntry{ + Type: security.ResourceType_CircuitBreakerRules, + ID: cbr.ID, + Metadata: cbr.Proto.Metadata, + }) + }) + authCtx.SetRequestContext(ctx) + + resp := svr.nextSvr.GetCircuitBreakerRules(ctx, query) + + for index := range resp.Data { + item := &apifault.CircuitBreakerRule{} + _ = anypb.UnmarshalTo(resp.Data[index], item, proto.UnmarshalOptions{}) + authCtx.SetAccessResources(map[security.ResourceType][]authcommon.ResourceEntry{ + security.ResourceType_CircuitBreakerRules: { + { + Type: apisecurity.ResourceType_CircuitBreakerRules, + ID: item.GetId(), + Metadata: item.Metadata, + }, + }, }) - }) - return svr.nextSvr.GetCircuitBreakerRules(ctx, query) + // 检查 write 操作权限 + authCtx.SetMethod([]authcommon.ServerFunctionName{ + authcommon.UpdateCircuitBreakerRules, + authcommon.EnableCircuitBreakerRules, + }) + // 如果检查不通过,设置 editable 为 false + if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + item.Editable = false + } + + // 检查 delete 操作权限 + authCtx.SetMethod([]authcommon.ServerFunctionName{authcommon.DeleteCircuitBreakerRules}) + // 如果检查不通过,设置 editable 为 false + if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + item.Deleteable = false + } + _ = anypb.MarshalFrom(resp.Data[index], item, proto.MarshalOptions{}) + } + return resp } diff --git a/service/interceptor/auth/client_v1.go b/service/interceptor/auth/client_v1.go index fb1fc8dec..4dd6a0fb3 100644 --- a/service/interceptor/auth/client_v1.go +++ b/service/interceptor/auth/client_v1.go @@ -20,9 +20,11 @@ package service_auth import ( "context" + "github.com/polarismesh/specification/source/go/api/v1/security" apiservice "github.com/polarismesh/specification/source/go/api/v1/service_manage" "google.golang.org/protobuf/types/known/wrapperspb" + cachetypes "github.com/polarismesh/polaris/cache/api" api "github.com/polarismesh/polaris/common/api/v1" "github.com/polarismesh/polaris/common/model" authcommon "github.com/polarismesh/polaris/common/model/auth" @@ -51,10 +53,8 @@ func (svr *Server) DeregisterInstance(ctx context.Context, req *apiservice.Insta authCtx := svr.collectClientInstanceAuthContext( ctx, []*apiservice.Instance{req}, authcommon.Create, authcommon.DeregisterInstance) - _, err := svr.policySvr.GetAuthChecker().CheckClientPermission(authCtx) - if err != nil { - resp := api.NewResponseWithMsg(authcommon.ConvertToErrCode(err), err.Error()) - return resp + if _, err := svr.policySvr.GetAuthChecker().CheckClientPermission(authCtx); err != nil { + return api.NewResponse(authcommon.ConvertToErrCode(err)) } ctx = authCtx.GetRequestContext() @@ -76,10 +76,8 @@ func (svr *Server) ReportServiceContract(ctx context.Context, req *apiservice.Se Namespace: wrapperspb.String(req.GetNamespace()), }}, authcommon.Create, authcommon.ReportServiceContract) - _, err := svr.policySvr.GetAuthChecker().CheckClientPermission(authCtx) - if err != nil { - resp := api.NewResponseWithMsg(authcommon.ConvertToErrCode(err), err.Error()) - return resp + if _, err := svr.policySvr.GetAuthChecker().CheckClientPermission(authCtx); err != nil { + return api.NewResponse(authcommon.ConvertToErrCode(err)) } ctx = authCtx.GetRequestContext() @@ -100,15 +98,21 @@ func (svr *Server) GetServiceWithCache( authCtx := svr.collectServiceAuthContext( ctx, []*apiservice.Service{req}, authcommon.Read, authcommon.DiscoverServices) - _, err := svr.policySvr.GetAuthChecker().CheckClientPermission(authCtx) - if err != nil { - resp := api.NewDiscoverResponse(authcommon.ConvertToErrCode(err)) - resp.Info = utils.NewStringValue(err.Error()) - return resp + if _, err := svr.policySvr.GetAuthChecker().CheckClientPermission(authCtx); err != nil { + return api.NewDiscoverResponse(authcommon.ConvertToErrCode(err)) } ctx = authCtx.GetRequestContext() ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) + ctx = cachetypes.AppendServicePredicate(ctx, func(ctx context.Context, cbr *model.Service) bool { + return svr.policySvr.GetAuthChecker().ResourcePredicate(authCtx, &authcommon.ResourceEntry{ + Type: security.ResourceType_Services, + ID: cbr.ID, + Metadata: cbr.Meta, + }) + }) + authCtx.SetRequestContext(ctx) + return svr.nextSvr.GetServiceWithCache(ctx, req) } @@ -118,11 +122,8 @@ func (svr *Server) ServiceInstancesCache( authCtx := svr.collectServiceAuthContext( ctx, []*apiservice.Service{req}, authcommon.Read, authcommon.DiscoverInstances) - _, err := svr.policySvr.GetAuthChecker().CheckClientPermission(authCtx) - if err != nil { - resp := api.NewDiscoverResponse(authcommon.ConvertToErrCode(err)) - resp.Info = utils.NewStringValue(err.Error()) - return resp + if _, err := svr.policySvr.GetAuthChecker().CheckClientPermission(authCtx); err != nil { + return api.NewDiscoverResponse(authcommon.ConvertToErrCode(err)) } ctx = authCtx.GetRequestContext() ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) @@ -136,11 +137,8 @@ func (svr *Server) GetRoutingConfigWithCache( authCtx := svr.collectServiceAuthContext( ctx, []*apiservice.Service{req}, authcommon.Read, authcommon.DiscoverRouterRule) - _, err := svr.policySvr.GetAuthChecker().CheckClientPermission(authCtx) - if err != nil { - resp := api.NewDiscoverResponse(authcommon.ConvertToErrCode(err)) - resp.Info = utils.NewStringValue(err.Error()) - return resp + if _, err := svr.policySvr.GetAuthChecker().CheckClientPermission(authCtx); err != nil { + return api.NewDiscoverResponse(authcommon.ConvertToErrCode(err)) } ctx = authCtx.GetRequestContext() ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) @@ -154,11 +152,8 @@ func (svr *Server) GetRateLimitWithCache( authCtx := svr.collectServiceAuthContext( ctx, []*apiservice.Service{req}, authcommon.Read, authcommon.DiscoverRateLimitRule) - _, err := svr.policySvr.GetAuthChecker().CheckClientPermission(authCtx) - if err != nil { - resp := api.NewDiscoverResponse(authcommon.ConvertToErrCode(err)) - resp.Info = utils.NewStringValue(err.Error()) - return resp + if _, err := svr.policySvr.GetAuthChecker().CheckClientPermission(authCtx); err != nil { + return api.NewDiscoverResponse(authcommon.ConvertToErrCode(err)) } ctx = authCtx.GetRequestContext() ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) @@ -172,11 +167,8 @@ func (svr *Server) GetCircuitBreakerWithCache( authCtx := svr.collectServiceAuthContext( ctx, []*apiservice.Service{req}, authcommon.Read, authcommon.DiscoverCircuitBreakerRule) - _, err := svr.policySvr.GetAuthChecker().CheckClientPermission(authCtx) - if err != nil { - resp := api.NewDiscoverResponse(authcommon.ConvertToErrCode(err)) - resp.Info = utils.NewStringValue(err.Error()) - return resp + if _, err := svr.policySvr.GetAuthChecker().CheckClientPermission(authCtx); err != nil { + return api.NewDiscoverResponse(authcommon.ConvertToErrCode(err)) } ctx = authCtx.GetRequestContext() ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) @@ -184,16 +176,14 @@ func (svr *Server) GetCircuitBreakerWithCache( return svr.nextSvr.GetCircuitBreakerWithCache(ctx, req) } +// GetFaultDetectWithCache 获取主动探测规则列表 func (svr *Server) GetFaultDetectWithCache( ctx context.Context, req *apiservice.Service) *apiservice.DiscoverResponse { authCtx := svr.collectServiceAuthContext( ctx, []*apiservice.Service{req}, authcommon.Read, authcommon.DiscoverFaultDetectRule) - _, err := svr.policySvr.GetAuthChecker().CheckClientPermission(authCtx) - if err != nil { - resp := api.NewDiscoverResponse(authcommon.ConvertToErrCode(err)) - resp.Info = utils.NewStringValue(err.Error()) - return resp + if _, err := svr.policySvr.GetAuthChecker().CheckClientPermission(authCtx); err != nil { + return api.NewDiscoverResponse(authcommon.ConvertToErrCode(err)) } ctx = authCtx.GetRequestContext() ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) @@ -206,10 +196,8 @@ func (svr *Server) UpdateInstance(ctx context.Context, req *apiservice.Instance) authCtx := svr.collectClientInstanceAuthContext( ctx, []*apiservice.Instance{req}, authcommon.Modify, authcommon.UpdateInstance) - _, err := svr.policySvr.GetAuthChecker().CheckClientPermission(authCtx) - if err != nil { - resp := api.NewResponseWithMsg(authcommon.ConvertToErrCode(err), err.Error()) - return resp + if _, err := svr.policySvr.GetAuthChecker().CheckClientPermission(authCtx); err != nil { + return api.NewResponse(authcommon.ConvertToErrCode(err)) } ctx = authCtx.GetRequestContext() @@ -226,11 +214,8 @@ func (svr *Server) GetServiceContractWithCache(ctx context.Context, Name: wrapperspb.String(req.Service), }}, authcommon.Read, authcommon.DiscoverServiceContract) - _, err := svr.policySvr.GetAuthChecker().CheckClientPermission(authCtx) - if err != nil { - resp := api.NewResponse(authcommon.ConvertToErrCode(err)) - resp.Info = utils.NewStringValue(err.Error()) - return resp + if _, err := svr.policySvr.GetAuthChecker().CheckClientPermission(authCtx); err != nil { + return api.NewResponse(authcommon.ConvertToErrCode(err)) } ctx = authCtx.GetRequestContext() @@ -243,11 +228,8 @@ func (svr *Server) GetServiceContractWithCache(ctx context.Context, func (svr *Server) GetLaneRuleWithCache(ctx context.Context, req *apiservice.Service) *apiservice.DiscoverResponse { authCtx := svr.collectServiceAuthContext( ctx, []*apiservice.Service{req}, authcommon.Read, authcommon.DiscoverLaneRule) - _, err := svr.policySvr.GetAuthChecker().CheckClientPermission(authCtx) - if err != nil { - resp := api.NewDiscoverResponse(authcommon.ConvertToErrCode(err)) - resp.Info = utils.NewStringValue(err.Error()) - return resp + if _, err := svr.policySvr.GetAuthChecker().CheckClientPermission(authCtx); err != nil { + return api.NewDiscoverResponse(authcommon.ConvertToErrCode(err)) } ctx = authCtx.GetRequestContext() ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) diff --git a/service/interceptor/auth/faultdetect_config.go b/service/interceptor/auth/faultdetect_config.go index ef4b09d76..e3401be52 100644 --- a/service/interceptor/auth/faultdetect_config.go +++ b/service/interceptor/auth/faultdetect_config.go @@ -22,7 +22,10 @@ import ( apifault "github.com/polarismesh/specification/source/go/api/v1/fault_tolerance" "github.com/polarismesh/specification/source/go/api/v1/security" + apisecurity "github.com/polarismesh/specification/source/go/api/v1/security" apiservice "github.com/polarismesh/specification/source/go/api/v1/service_manage" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/types/known/anypb" cachetypes "github.com/polarismesh/polaris/cache/api" api "github.com/polarismesh/polaris/common/api/v1" @@ -34,7 +37,7 @@ import ( func (svr *Server) CreateFaultDetectRules( ctx context.Context, request []*apifault.FaultDetectRule) *apiservice.BatchWriteResponse { - authCtx := svr.collectFaultDetectAuthContext(ctx, request, authcommon.Read, authcommon.CreateFaultDetectRules) + authCtx := svr.collectFaultDetectAuthContext(ctx, request, authcommon.Create, authcommon.CreateFaultDetectRules) if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { return api.NewBatchWriteResponse(authcommon.ConvertToErrCode(err)) } @@ -46,7 +49,7 @@ func (svr *Server) CreateFaultDetectRules( func (svr *Server) DeleteFaultDetectRules( ctx context.Context, request []*apifault.FaultDetectRule) *apiservice.BatchWriteResponse { - authCtx := svr.collectFaultDetectAuthContext(ctx, request, authcommon.Read, authcommon.DeleteFaultDetectRules) + authCtx := svr.collectFaultDetectAuthContext(ctx, request, authcommon.Delete, authcommon.DeleteFaultDetectRules) if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { return api.NewBatchWriteResponse(authcommon.ConvertToErrCode(err)) } @@ -58,7 +61,7 @@ func (svr *Server) DeleteFaultDetectRules( func (svr *Server) UpdateFaultDetectRules( ctx context.Context, request []*apifault.FaultDetectRule) *apiservice.BatchWriteResponse { - authCtx := svr.collectFaultDetectAuthContext(ctx, request, authcommon.Read, authcommon.UpdateFaultDetectRules) + authCtx := svr.collectFaultDetectAuthContext(ctx, request, authcommon.Modify, authcommon.UpdateFaultDetectRules) if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { return api.NewBatchWriteResponse(authcommon.ConvertToErrCode(err)) } @@ -76,13 +79,44 @@ func (svr *Server) GetFaultDetectRules( ctx = authCtx.GetRequestContext() ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) - cachetypes.AppendFaultDetectRulePredicate(ctx, func(ctx context.Context, cbr *model.FaultDetectRule) bool { + ctx = cachetypes.AppendFaultDetectRulePredicate(ctx, func(ctx context.Context, cbr *model.FaultDetectRule) bool { return svr.policySvr.GetAuthChecker().ResourcePredicate(authCtx, &authcommon.ResourceEntry{ Type: security.ResourceType_FaultDetectRules, ID: cbr.ID, - Metadata: cbr.Proto.Metadata, + Metadata: cbr.Proto.GetMetadata(), }) }) + authCtx.SetRequestContext(ctx) - return svr.nextSvr.GetFaultDetectRules(ctx, query) + resp := svr.nextSvr.GetFaultDetectRules(ctx, query) + + for index := range resp.Data { + item := &apifault.FaultDetectRule{} + _ = anypb.UnmarshalTo(resp.Data[index], item, proto.UnmarshalOptions{}) + authCtx.SetAccessResources(map[security.ResourceType][]authcommon.ResourceEntry{ + security.ResourceType_FaultDetectRules: { + { + Type: apisecurity.ResourceType_FaultDetectRules, + ID: item.GetId(), + Metadata: item.Metadata, + }, + }, + }) + + // 检查 write 操作权限 + authCtx.SetMethod([]authcommon.ServerFunctionName{authcommon.UpdateFaultDetectRules, authcommon.EnableFaultDetectRules}) + // 如果检查不通过,设置 editable 为 false + if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + item.Editable = false + } + + // 检查 delete 操作权限 + authCtx.SetMethod([]authcommon.ServerFunctionName{authcommon.DeleteFaultDetectRules}) + // 如果检查不通过,设置 editable 为 false + if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + item.Deleteable = false + } + _ = anypb.MarshalFrom(resp.Data[index], item, proto.MarshalOptions{}) + } + return resp } diff --git a/service/interceptor/auth/instance.go b/service/interceptor/auth/instance.go index 071d9f856..58fff9a8a 100644 --- a/service/interceptor/auth/instance.go +++ b/service/interceptor/auth/instance.go @@ -26,6 +26,7 @@ import ( api "github.com/polarismesh/polaris/common/api/v1" authcommon "github.com/polarismesh/polaris/common/model/auth" "github.com/polarismesh/polaris/common/utils" + "github.com/polarismesh/polaris/service" ) // CreateInstances create instances @@ -33,8 +34,7 @@ func (svr *Server) CreateInstances(ctx context.Context, reqs []*apiservice.Instance) *apiservice.BatchWriteResponse { authCtx := svr.collectInstanceAuthContext(ctx, reqs, authcommon.Create, authcommon.CreateInstances) - _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx) - if err != nil { + if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { resp := api.NewResponseWithMsg(authcommon.ConvertToErrCode(err), err.Error()) batchResp := api.NewBatchWriteResponse(apimodel.Code_ExecuteSuccess) api.Collect(batchResp, resp) @@ -66,7 +66,7 @@ func (svr *Server) DeleteInstances(ctx context.Context, return svr.nextSvr.DeleteInstances(ctx, reqs) } -// DeleteInstancesByHost 目前只允许 super account 进行数据删除 +// DeleteInstancesByHost 根据 host 信息进行数据删除 func (svr *Server) DeleteInstancesByHost(ctx context.Context, reqs []*apiservice.Instance) *apiservice.BatchWriteResponse { authCtx := svr.collectInstanceAuthContext(ctx, reqs, authcommon.Delete, authcommon.DeleteInstancesByHost) @@ -145,10 +145,38 @@ func (svr *Server) GetInstancesCount(ctx context.Context) *apiservice.BatchQuery return svr.nextSvr.GetInstancesCount(ctx) } +// GetInstanceLabels 获取某个服务下的实例标签集合 func (svr *Server) GetInstanceLabels(ctx context.Context, query map[string]string) *apiservice.Response { - authCtx := svr.collectInstanceAuthContext(ctx, nil, authcommon.Read, authcommon.DescribeInstanceLabels) + var ( + serviceId string + namespace = service.DefaultNamespace + ) + + if val, ok := query["namespace"]; ok { + namespace = val + } + + if svcName, ok := query["service"]; ok { + if svc := svr.Cache().Service().GetServiceByName(svcName, namespace); svc != nil { + serviceId = svc.ID + } + } + + if id, ok := query["service_id"]; ok { + serviceId = id + } + + // TODO 如果在鉴权的时候发现资源不存在,怎么处理? + svc := svr.Cache().Service().GetServiceByID(serviceId) + if svc == nil { + return api.NewResponse(apimodel.Code_NotFoundResource) + } + + authCtx := svr.collectServiceAuthContext(ctx, []*apiservice.Service{ + svc.ToSpec(), + }, authcommon.Read, authcommon.DescribeInstanceLabels) _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx) if err != nil { return api.NewResponseWithMsg(authcommon.ConvertToErrCode(err), err.Error()) diff --git a/service/interceptor/auth/lane.go b/service/interceptor/auth/lane.go new file mode 100644 index 000000000..e5d7f109e --- /dev/null +++ b/service/interceptor/auth/lane.go @@ -0,0 +1,147 @@ +/** + * Tencent is pleased to support the open source community by making Polaris available. + * + * Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. + * + * Licensed under the BSD 3-Clause License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://opensource.org/licenses/BSD-3-Clause + * + * Unless required by applicable law or agreed to in writing, software distributed + * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR + * CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package service_auth + +import ( + "context" + + "github.com/polarismesh/specification/source/go/api/v1/security" + apisecurity "github.com/polarismesh/specification/source/go/api/v1/security" + apiservice "github.com/polarismesh/specification/source/go/api/v1/service_manage" + apitraffic "github.com/polarismesh/specification/source/go/api/v1/traffic_manage" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/types/known/anypb" + + cachetypes "github.com/polarismesh/polaris/cache/api" + api "github.com/polarismesh/polaris/common/api/v1" + "github.com/polarismesh/polaris/common/model" + authcommon "github.com/polarismesh/polaris/common/model/auth" + "github.com/polarismesh/polaris/common/utils" +) + +// CreateLaneGroups 批量创建泳道组 +func (svr *Server) CreateLaneGroups(ctx context.Context, reqs []*apitraffic.LaneGroup) *apiservice.BatchWriteResponse { + + authCtx := svr.collectLaneRuleAuthContext(ctx, reqs, authcommon.Create, authcommon.CreateLaneGroups) + if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + return api.NewBatchWriteResponse(authcommon.ConvertToErrCode(err)) + } + ctx = authCtx.GetRequestContext() + ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) + return svr.nextSvr.CreateLaneGroups(ctx, reqs) +} + +// UpdateLaneGroups 批量更新泳道组 +func (svr *Server) UpdateLaneGroups(ctx context.Context, reqs []*apitraffic.LaneGroup) *apiservice.BatchWriteResponse { + authCtx := svr.collectLaneRuleAuthContext(ctx, reqs, authcommon.Modify, authcommon.UpdateLaneGroups) + if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + return api.NewBatchWriteResponse(authcommon.ConvertToErrCode(err)) + } + ctx = authCtx.GetRequestContext() + ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) + return svr.nextSvr.UpdateLaneGroups(ctx, reqs) +} + +// DeleteLaneGroups 批量删除泳道组 +func (svr *Server) DeleteLaneGroups(ctx context.Context, reqs []*apitraffic.LaneGroup) *apiservice.BatchWriteResponse { + authCtx := svr.collectLaneRuleAuthContext(ctx, reqs, authcommon.Delete, authcommon.DeleteLaneGroups) + if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + return api.NewBatchWriteResponse(authcommon.ConvertToErrCode(err)) + } + ctx = authCtx.GetRequestContext() + ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) + return svr.nextSvr.DeleteLaneGroups(ctx, reqs) +} + +// GetLaneGroups 查询泳道组列表 +func (svr *Server) GetLaneGroups(ctx context.Context, filter map[string]string) *apiservice.BatchQueryResponse { + authCtx := svr.collectFaultDetectAuthContext(ctx, nil, authcommon.Read, authcommon.DescribeFaultDetectRules) + if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + return api.NewBatchQueryResponse(authcommon.ConvertToErrCode(err)) + } + ctx = authCtx.GetRequestContext() + ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) + + ctx = cachetypes.AppendLaneRulePredicate(ctx, func(ctx context.Context, cbr *model.LaneGroupProto) bool { + return svr.policySvr.GetAuthChecker().ResourcePredicate(authCtx, &authcommon.ResourceEntry{ + Type: security.ResourceType_LaneRules, + ID: cbr.ID, + Metadata: cbr.Proto.Metadata, + }) + }) + authCtx.SetRequestContext(ctx) + + resp := svr.nextSvr.GetLaneGroups(ctx, filter) + + for index := range resp.Data { + item := &apitraffic.LaneGroup{} + _ = anypb.UnmarshalTo(resp.Data[index], item, proto.UnmarshalOptions{}) + authCtx.SetAccessResources(map[security.ResourceType][]authcommon.ResourceEntry{ + security.ResourceType_LaneRules: { + { + Type: apisecurity.ResourceType_LaneRules, + ID: item.GetId(), + Metadata: item.Metadata, + }, + }, + }) + + // 检查 write 操作权限 + authCtx.SetMethod([]authcommon.ServerFunctionName{authcommon.UpdateLaneGroups, authcommon.EnableLaneGroups}) + // 如果检查不通过,设置 editable 为 false + if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + item.Editable = false + } + + // 检查 delete 操作权限 + authCtx.SetMethod([]authcommon.ServerFunctionName{authcommon.DeleteLaneGroups}) + // 如果检查不通过,设置 editable 为 false + if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + item.Deleteable = false + } + _ = anypb.MarshalFrom(resp.Data[index], item, proto.MarshalOptions{}) + } + return resp +} + +// collectLaneRuleAuthContext 收集全链路灰度规则 +func (svr *Server) collectLaneRuleAuthContext(ctx context.Context, req []*apitraffic.LaneGroup, + op authcommon.ResourceOperation, methodName authcommon.ServerFunctionName) *authcommon.AcquireContext { + + resources := make([]authcommon.ResourceEntry, 0, len(req)) + for i := range req { + saveRule := svr.Cache().LaneRule().GetRule(req[i].GetId()) + if saveRule != nil { + resources = append(resources, authcommon.ResourceEntry{ + Type: apisecurity.ResourceType_LaneRules, + ID: saveRule.ID, + Metadata: saveRule.Labels, + }) + } + } + + return authcommon.NewAcquireContext( + authcommon.WithRequestContext(ctx), + authcommon.WithOperation(op), + authcommon.WithModule(authcommon.DiscoverModule), + authcommon.WithMethod(methodName), + authcommon.WithAccessResources(map[apisecurity.ResourceType][]authcommon.ResourceEntry{ + apisecurity.ResourceType_LaneRules: resources, + }), + ) +} diff --git a/service/interceptor/auth/ratelimit_config.go b/service/interceptor/auth/ratelimit_config.go index f41a09cac..2eea1fddb 100644 --- a/service/interceptor/auth/ratelimit_config.go +++ b/service/interceptor/auth/ratelimit_config.go @@ -20,8 +20,8 @@ package service_auth import ( "context" - apimodel "github.com/polarismesh/specification/source/go/api/v1/model" "github.com/polarismesh/specification/source/go/api/v1/security" + apisecurity "github.com/polarismesh/specification/source/go/api/v1/security" apiservice "github.com/polarismesh/specification/source/go/api/v1/service_manage" apitraffic "github.com/polarismesh/specification/source/go/api/v1/traffic_manage" @@ -37,9 +37,8 @@ func (svr *Server) CreateRateLimits( ctx context.Context, reqs []*apitraffic.Rule) *apiservice.BatchWriteResponse { authCtx := svr.collectRateLimitAuthContext(ctx, reqs, authcommon.Create, authcommon.CreateRateLimitRules) - _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx) - if err != nil { - return api.NewBatchWriteResponseWithMsg(apimodel.Code_NotAllowedAccess, err.Error()) + if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + return api.NewBatchWriteResponse(authcommon.ConvertToErrCode(err)) } ctx = authCtx.GetRequestContext() @@ -53,9 +52,8 @@ func (svr *Server) DeleteRateLimits( ctx context.Context, reqs []*apitraffic.Rule) *apiservice.BatchWriteResponse { authCtx := svr.collectRateLimitAuthContext(ctx, reqs, authcommon.Delete, authcommon.DeleteRateLimitRules) - _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx) - if err != nil { - return api.NewBatchWriteResponseWithMsg(apimodel.Code_NotAllowedAccess, err.Error()) + if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + return api.NewBatchWriteResponse(authcommon.ConvertToErrCode(err)) } ctx = authCtx.GetRequestContext() @@ -69,9 +67,8 @@ func (svr *Server) UpdateRateLimits( ctx context.Context, reqs []*apitraffic.Rule) *apiservice.BatchWriteResponse { authCtx := svr.collectRateLimitAuthContext(ctx, reqs, authcommon.Modify, authcommon.UpdateRateLimitRules) - _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx) - if err != nil { - return api.NewBatchWriteResponseWithMsg(apimodel.Code_NotAllowedAccess, err.Error()) + if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + return api.NewBatchWriteResponse(authcommon.ConvertToErrCode(err)) } ctx = authCtx.GetRequestContext() @@ -83,11 +80,10 @@ func (svr *Server) UpdateRateLimits( // EnableRateLimits 启用限流规则 func (svr *Server) EnableRateLimits( ctx context.Context, reqs []*apitraffic.Rule) *apiservice.BatchWriteResponse { - authCtx := svr.collectRateLimitAuthContext(ctx, nil, authcommon.Read, authcommon.EnableRateLimitRules) + authCtx := svr.collectRateLimitAuthContext(ctx, reqs, authcommon.Read, authcommon.EnableRateLimitRules) - _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx) - if err != nil { - return api.NewBatchWriteResponseWithMsg(apimodel.Code_NotAllowedAccess, err.Error()) + if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + return api.NewBatchWriteResponse(authcommon.ConvertToErrCode(err)) } ctx = authCtx.GetRequestContext() @@ -101,21 +97,50 @@ func (svr *Server) GetRateLimits( ctx context.Context, query map[string]string) *apiservice.BatchQueryResponse { authCtx := svr.collectRateLimitAuthContext(ctx, nil, authcommon.Read, authcommon.DescribeRateLimitRules) - _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx) - if err != nil { - return api.NewBatchQueryResponseWithMsg(apimodel.Code_NotAllowedAccess, err.Error()) + if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + return api.NewAuthBatchQueryResponse(authcommon.ConvertToErrCode(err)) } ctx = authCtx.GetRequestContext() ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) - cachetypes.AppendRatelimitRulePredicate(ctx, func(ctx context.Context, cbr *model.RateLimit) bool { + ctx = cachetypes.AppendRatelimitRulePredicate(ctx, func(ctx context.Context, cbr *model.RateLimit) bool { return svr.policySvr.GetAuthChecker().ResourcePredicate(authCtx, &authcommon.ResourceEntry{ Type: security.ResourceType_RateLimitRules, ID: cbr.ID, Metadata: cbr.Proto.Metadata, }) }) + authCtx.SetRequestContext(ctx) + + resp := svr.nextSvr.GetRateLimits(ctx, query) + + for index := range resp.RateLimits { + item := resp.RateLimits[index] + authCtx.SetAccessResources(map[security.ResourceType][]authcommon.ResourceEntry{ + security.ResourceType_RateLimitRules: { + { + Type: apisecurity.ResourceType_RateLimitRules, + ID: item.GetId().GetValue(), + Metadata: item.Metadata, + }, + }, + }) + + // 检查 write 操作权限 + authCtx.SetMethod([]authcommon.ServerFunctionName{authcommon.UpdateRateLimitRules, authcommon.EnableRateLimitRules}) + // 如果检查不通过,设置 editable 为 false + if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + item.Editable = false + } + + // 检查 delete 操作权限 + authCtx.SetMethod([]authcommon.ServerFunctionName{authcommon.DeleteRateLimitRules}) + // 如果检查不通过,设置 editable 为 false + if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + item.Deleteable = false + } + } - return svr.nextSvr.GetRateLimits(ctx, query) + return resp } diff --git a/service/interceptor/auth/resource_listen.go b/service/interceptor/auth/resource_listen.go index 080b364ff..c29e464b4 100644 --- a/service/interceptor/auth/resource_listen.go +++ b/service/interceptor/auth/resource_listen.go @@ -35,40 +35,45 @@ func (svr *Server) Before(ctx context.Context, resourceType model.Resource) { // After this function is called after the resource operation func (svr *Server) After(ctx context.Context, resourceType model.Resource, res *service.ResourceEvent) error { - switch resourceType { - case model.RService: - return svr.onServiceResource(ctx, res) - default: - return nil - } + // 资源删除,触发所有关联的策略进行一个 update 操作更新 + return svr.onChangeResource(ctx, res) } -// onServiceResource 服务资源的处理,只处理服务,namespace 只由 namespace 相关的进行处理, -func (svr *Server) onServiceResource(ctx context.Context, res *service.ResourceEvent) error { +// onChangeResource 服务资源的处理,只处理服务,namespace 只由 namespace 相关的进行处理, +func (svr *Server) onChangeResource(ctx context.Context, res *service.ResourceEvent) error { authCtx := ctx.Value(utils.ContextAuthContextKey).(*authcommon.AcquireContext) - ownerId := utils.ParseOwnerID(ctx) authCtx.SetAttachment(authcommon.ResourceAttachmentKey, map[apisecurity.ResourceType][]authcommon.ResourceEntry{ - apisecurity.ResourceType_Services: { - { - ID: res.Service.ID, - Owner: ownerId, - Metadata: res.Service.Meta, - }, + res.Resource.Type: { + res.Resource, }, }) - users := utils.ConvertStringValuesToSlice(res.ReqService.UserIds) - removeUses := utils.ConvertStringValuesToSlice(res.ReqService.RemoveUserIds) + var users, removeUsers []string + var groups, removeGroups []string - groups := utils.ConvertStringValuesToSlice(res.ReqService.GroupIds) - removeGroups := utils.ConvertStringValuesToSlice(res.ReqService.RemoveGroupIds) + for i := range res.AddPrincipals { + switch res.AddPrincipals[i].PrincipalType { + case authcommon.PrincipalUser: + users = append(users, res.AddPrincipals[i].PrincipalID) + case authcommon.PrincipalGroup: + groups = append(groups, res.AddPrincipals[i].PrincipalID) + } + } + for i := range res.DelPrincipals { + switch res.DelPrincipals[i].PrincipalType { + case authcommon.PrincipalUser: + removeUsers = append(removeUsers, res.DelPrincipals[i].PrincipalID) + case authcommon.PrincipalGroup: + removeGroups = append(removeGroups, res.DelPrincipals[i].PrincipalID) + } + } - authCtx.SetAttachment(authcommon.LinkUsersKey, utils.StringSliceDeDuplication(users)) - authCtx.SetAttachment(authcommon.RemoveLinkUsersKey, utils.StringSliceDeDuplication(removeUses)) + authCtx.SetAttachment(authcommon.LinkUsersKey, users) + authCtx.SetAttachment(authcommon.RemoveLinkUsersKey, removeUsers) - authCtx.SetAttachment(authcommon.LinkGroupsKey, utils.StringSliceDeDuplication(groups)) - authCtx.SetAttachment(authcommon.RemoveLinkGroupsKey, utils.StringSliceDeDuplication(removeGroups)) + authCtx.SetAttachment(authcommon.LinkGroupsKey, groups) + authCtx.SetAttachment(authcommon.RemoveLinkGroupsKey, removeGroups) return svr.policySvr.AfterResourceOperation(authCtx) } diff --git a/service/interceptor/auth/routing_config_v1.go b/service/interceptor/auth/routing_config_v1.go index 32066dbbe..b29287aa0 100644 --- a/service/interceptor/auth/routing_config_v1.go +++ b/service/interceptor/auth/routing_config_v1.go @@ -20,75 +20,30 @@ package service_auth import ( "context" - apimodel "github.com/polarismesh/specification/source/go/api/v1/model" apiservice "github.com/polarismesh/specification/source/go/api/v1/service_manage" apitraffic "github.com/polarismesh/specification/source/go/api/v1/traffic_manage" - - api "github.com/polarismesh/polaris/common/api/v1" - authcommon "github.com/polarismesh/polaris/common/model/auth" - "github.com/polarismesh/polaris/common/utils" ) // CreateRoutingConfigs creates routing configs func (svr *Server) CreateRoutingConfigs( ctx context.Context, reqs []*apitraffic.Routing) *apiservice.BatchWriteResponse { - authCtx := svr.collectRouteRuleAuthContext(ctx, reqs, authcommon.Create, "CreateRoutingConfigs") - - _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx) - if err != nil { - return api.NewBatchWriteResponseWithMsg(apimodel.Code_NotAllowedAccess, err.Error()) - } - - ctx = authCtx.GetRequestContext() - ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) - return svr.nextSvr.CreateRoutingConfigs(ctx, reqs) } // DeleteRoutingConfigs deletes routing configs func (svr *Server) DeleteRoutingConfigs( ctx context.Context, reqs []*apitraffic.Routing) *apiservice.BatchWriteResponse { - authCtx := svr.collectRouteRuleAuthContext(ctx, reqs, authcommon.Delete, "DeleteRoutingConfigs") - - _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx) - if err != nil { - return api.NewBatchWriteResponseWithMsg(apimodel.Code_NotAllowedAccess, err.Error()) - } - - ctx = authCtx.GetRequestContext() - ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) - return svr.nextSvr.DeleteRoutingConfigs(ctx, reqs) } // UpdateRoutingConfigs updates routing configs func (svr *Server) UpdateRoutingConfigs( ctx context.Context, reqs []*apitraffic.Routing) *apiservice.BatchWriteResponse { - authCtx := svr.collectRouteRuleAuthContext(ctx, reqs, authcommon.Modify, "UpdateRoutingConfigs") - - _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx) - if err != nil { - return api.NewBatchWriteResponseWithMsg(apimodel.Code_NotAllowedAccess, err.Error()) - } - - ctx = authCtx.GetRequestContext() - ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) - return svr.nextSvr.UpdateRoutingConfigs(ctx, reqs) } // GetRoutingConfigs gets routing configs func (svr *Server) GetRoutingConfigs( ctx context.Context, query map[string]string) *apiservice.BatchQueryResponse { - authCtx := svr.collectRouteRuleAuthContext(ctx, nil, authcommon.Read, "GetRoutingConfigs") - - _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx) - if err != nil { - return api.NewBatchQueryResponseWithMsg(apimodel.Code_NotAllowedAccess, err.Error()) - } - - ctx = authCtx.GetRequestContext() - ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) - return svr.nextSvr.GetRoutingConfigs(ctx, query) } diff --git a/service/interceptor/auth/routing_config_v2.go b/service/interceptor/auth/routing_config_v2.go index 3ed128dad..6da142365 100644 --- a/service/interceptor/auth/routing_config_v2.go +++ b/service/interceptor/auth/routing_config_v2.go @@ -21,8 +21,11 @@ import ( "context" "github.com/polarismesh/specification/source/go/api/v1/security" + apisecurity "github.com/polarismesh/specification/source/go/api/v1/security" apiservice "github.com/polarismesh/specification/source/go/api/v1/service_manage" apitraffic "github.com/polarismesh/specification/source/go/api/v1/traffic_manage" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/types/known/anypb" cachetypes "github.com/polarismesh/polaris/cache/api" api "github.com/polarismesh/polaris/common/api/v1" @@ -36,7 +39,7 @@ func (svr *Server) CreateRoutingConfigsV2(ctx context.Context, req []*apitraffic.RouteRule) *apiservice.BatchWriteResponse { // TODO not support RouteRuleV2 resource auth, so we set op is read - authCtx := svr.collectRouteRuleV2AuthContext(ctx, req, authcommon.Read, authcommon.CreateRouteRules) + authCtx := svr.collectRouteRuleV2AuthContext(ctx, req, authcommon.Create, authcommon.CreateRouteRules) if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { return api.NewBatchWriteResponse(authcommon.ConvertToErrCode(err)) } @@ -49,7 +52,7 @@ func (svr *Server) CreateRoutingConfigsV2(ctx context.Context, func (svr *Server) DeleteRoutingConfigsV2(ctx context.Context, req []*apitraffic.RouteRule) *apiservice.BatchWriteResponse { - authCtx := svr.collectRouteRuleV2AuthContext(ctx, req, authcommon.Read, authcommon.DeleteRouteRules) + authCtx := svr.collectRouteRuleV2AuthContext(ctx, req, authcommon.Delete, authcommon.DeleteRouteRules) if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { return api.NewBatchWriteResponse(authcommon.ConvertToErrCode(err)) } @@ -62,7 +65,7 @@ func (svr *Server) DeleteRoutingConfigsV2(ctx context.Context, func (svr *Server) UpdateRoutingConfigsV2(ctx context.Context, req []*apitraffic.RouteRule) *apiservice.BatchWriteResponse { - authCtx := svr.collectRouteRuleV2AuthContext(ctx, req, authcommon.Read, authcommon.UpdateRouteRules) + authCtx := svr.collectRouteRuleV2AuthContext(ctx, req, authcommon.Modify, authcommon.UpdateRouteRules) if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { return api.NewBatchWriteResponse(authcommon.ConvertToErrCode(err)) } @@ -75,7 +78,7 @@ func (svr *Server) UpdateRoutingConfigsV2(ctx context.Context, func (svr *Server) EnableRoutings(ctx context.Context, req []*apitraffic.RouteRule) *apiservice.BatchWriteResponse { - authCtx := svr.collectRouteRuleV2AuthContext(ctx, req, authcommon.Read, authcommon.EnableRouteRules) + authCtx := svr.collectRouteRuleV2AuthContext(ctx, req, authcommon.Modify, authcommon.EnableRouteRules) if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { return api.NewBatchWriteResponse(authcommon.ConvertToErrCode(err)) } @@ -94,13 +97,43 @@ func (svr *Server) QueryRoutingConfigsV2(ctx context.Context, ctx = authCtx.GetRequestContext() ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) - cachetypes.AppendRouterRulePredicate(ctx, func(ctx context.Context, cbr *model.ExtendRouterConfig) bool { + ctx = cachetypes.AppendRouterRulePredicate(ctx, func(ctx context.Context, cbr *model.ExtendRouterConfig) bool { return svr.policySvr.GetAuthChecker().ResourcePredicate(authCtx, &authcommon.ResourceEntry{ Type: security.ResourceType_RouteRules, ID: cbr.ID, Metadata: cbr.Metadata, }) }) + authCtx.SetRequestContext(ctx) - return svr.nextSvr.QueryRoutingConfigsV2(ctx, query) + resp := svr.nextSvr.QueryRoutingConfigsV2(ctx, query) + for index := range resp.Data { + item := &apitraffic.RouteRule{} + _ = anypb.UnmarshalTo(resp.Data[index], item, proto.UnmarshalOptions{}) + authCtx.SetAccessResources(map[security.ResourceType][]authcommon.ResourceEntry{ + security.ResourceType_RouteRules: { + { + Type: apisecurity.ResourceType_RouteRules, + ID: item.GetId(), + Metadata: item.Metadata, + }, + }, + }) + + // 检查 write 操作权限 + authCtx.SetMethod([]authcommon.ServerFunctionName{authcommon.UpdateRouteRules, authcommon.EnableRouteRules}) + // 如果检查不通过,设置 editable 为 false + if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + item.Editable = false + } + + // 检查 delete 操作权限 + authCtx.SetMethod([]authcommon.ServerFunctionName{authcommon.DeleteRouteRules}) + // 如果检查不通过,设置 editable 为 false + if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + item.Deleteable = false + } + _ = anypb.MarshalFrom(resp.Data[index], item, proto.MarshalOptions{}) + } + return resp } diff --git a/service/interceptor/auth/server.go b/service/interceptor/auth/server.go index bb6859493..08750f50f 100644 --- a/service/interceptor/auth/server.go +++ b/service/interceptor/auth/server.go @@ -39,20 +39,19 @@ import ( // 该层会对请求参数做一些调整,根据具体的请求发起人,设置为数据对应的 owner,不可为为别人进行创建资源 type Server struct { nextSvr service.DiscoverServer - userMgn auth.UserServer + userSvr auth.UserServer policySvr auth.StrategyServer } -func NewServerAuthAbility(nextSvr service.DiscoverServer, - userMgn auth.UserServer, policySvr auth.StrategyServer) service.DiscoverServer { +func NewServer(nextSvr service.DiscoverServer, + userSvr auth.UserServer, policySvr auth.StrategyServer) service.DiscoverServer { proxy := &Server{ nextSvr: nextSvr, - userMgn: userMgn, + userSvr: userSvr, policySvr: policySvr, } - actualSvr, ok := nextSvr.(*service.Server) - if ok { + if actualSvr, ok := nextSvr.(*service.Server); ok { actualSvr.SetResourceHooks(proxy) } return proxy @@ -203,20 +202,23 @@ func (svr *Server) collectRouteRuleV2AuthContext(ctx context.Context, req []*api } } + accessResources := map[apisecurity.ResourceType][]authcommon.ResourceEntry{} + if len(resources) != 0 { + accessResources[apisecurity.ResourceType_RouteRules] = resources + } + return authcommon.NewAcquireContext( authcommon.WithRequestContext(ctx), authcommon.WithOperation(resourceOp), authcommon.WithModule(authcommon.DiscoverModule), authcommon.WithMethod(methodName), - authcommon.WithAccessResources(map[apisecurity.ResourceType][]authcommon.ResourceEntry{ - apisecurity.ResourceType_RouteRules: resources, - }), + authcommon.WithAccessResources(accessResources), ) } -// collectCircuitBreakerRuleV2AuthContext 收集熔断v2规则 -func (svr *Server) collectCircuitBreakerRuleV2AuthContext(ctx context.Context, - req []*apifault.CircuitBreakerRule, resourceOp authcommon.ResourceOperation, methodName authcommon.ServerFunctionName) *authcommon.AcquireContext { +// collectCircuitBreakerRuleV2 收集熔断v2规则 +func (svr *Server) collectCircuitBreakerRuleV2(ctx context.Context, req []*apifault.CircuitBreakerRule, + op authcommon.ResourceOperation, methodName authcommon.ServerFunctionName) *authcommon.AcquireContext { resources := make([]authcommon.ResourceEntry, 0, len(req)) for i := range req { @@ -225,14 +227,14 @@ func (svr *Server) collectCircuitBreakerRuleV2AuthContext(ctx context.Context, resources = append(resources, authcommon.ResourceEntry{ Type: apisecurity.ResourceType_CircuitBreakerRules, ID: saveRule.ID, - Metadata: saveRule.Proto.Metadata, + Metadata: saveRule.Proto.GetMetadata(), }) } } return authcommon.NewAcquireContext( authcommon.WithRequestContext(ctx), - authcommon.WithOperation(resourceOp), + authcommon.WithOperation(op), authcommon.WithModule(authcommon.DiscoverModule), authcommon.WithMethod(methodName), authcommon.WithAccessResources(map[apisecurity.ResourceType][]authcommon.ResourceEntry{ @@ -242,8 +244,8 @@ func (svr *Server) collectCircuitBreakerRuleV2AuthContext(ctx context.Context, } // collectFaultDetectAuthContext 收集主动探测规则 -func (svr *Server) collectFaultDetectAuthContext(ctx context.Context, - req []*apifault.FaultDetectRule, resourceOp authcommon.ResourceOperation, methodName authcommon.ServerFunctionName) *authcommon.AcquireContext { +func (svr *Server) collectFaultDetectAuthContext(ctx context.Context, req []*apifault.FaultDetectRule, + op authcommon.ResourceOperation, methodName authcommon.ServerFunctionName) *authcommon.AcquireContext { resources := make([]authcommon.ResourceEntry, 0, len(req)) for i := range req { @@ -252,14 +254,14 @@ func (svr *Server) collectFaultDetectAuthContext(ctx context.Context, resources = append(resources, authcommon.ResourceEntry{ Type: apisecurity.ResourceType_FaultDetectRules, ID: saveRule.ID, - Metadata: saveRule.Proto.Metadata, + Metadata: saveRule.Proto.GetMetadata(), }) } } return authcommon.NewAcquireContext( authcommon.WithRequestContext(ctx), - authcommon.WithOperation(resourceOp), + authcommon.WithOperation(op), authcommon.WithModule(authcommon.DiscoverModule), authcommon.WithMethod(methodName), authcommon.WithAccessResources(map[apisecurity.ResourceType][]authcommon.ResourceEntry{ diff --git a/service/interceptor/auth/service.go b/service/interceptor/auth/service.go index 864a33368..cb94fe0a4 100644 --- a/service/interceptor/auth/service.go +++ b/service/interceptor/auth/service.go @@ -67,7 +67,7 @@ func (svr *Server) DeleteServices( authCtx.SetAccessResources(accessRes) if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewBatchWriteResponseWithMsg(authcommon.ConvertToErrCode(err), err.Error()) + return api.NewBatchWriteResponse(authcommon.ConvertToErrCode(err)) } ctx = authCtx.GetRequestContext() @@ -106,7 +106,7 @@ func (svr *Server) UpdateServiceToken( authCtx.SetAccessResources(accessRes) if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewResponseWithMsg(authcommon.ConvertToErrCode(err), err.Error()) + return api.NewResponse(authcommon.ConvertToErrCode(err)) } ctx = authCtx.GetRequestContext() @@ -119,12 +119,34 @@ func (svr *Server) GetAllServices(ctx context.Context, authCtx := svr.collectServiceAuthContext(ctx, nil, authcommon.Read, authcommon.DescribeAllServices) if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewBatchQueryResponseWithMsg(authcommon.ConvertToErrCode(err), err.Error()) + return api.NewBatchQueryResponse(authcommon.ConvertToErrCode(err)) } ctx = authCtx.GetRequestContext() ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) + ctx = cachetypes.AppendServicePredicate(ctx, func(ctx context.Context, cbr *model.Service) bool { + ok := svr.policySvr.GetAuthChecker().ResourcePredicate(authCtx, &authcommon.ResourceEntry{ + Type: security.ResourceType_Services, + ID: cbr.ID, + Metadata: cbr.Meta, + }) + if ok { + return true + } + saveNs := svr.Cache().Namespace().GetNamespace(cbr.Namespace) + if saveNs == nil { + return false + } + // 检查下是否可以访问对应的 namespace + return svr.policySvr.GetAuthChecker().ResourcePredicate(authCtx, &authcommon.ResourceEntry{ + Type: security.ResourceType_Namespaces, + ID: saveNs.Name, + Metadata: saveNs.Metadata, + }) + }) + authCtx.SetRequestContext(ctx) + return svr.nextSvr.GetAllServices(ctx, query) } @@ -134,20 +156,60 @@ func (svr *Server) GetServices( authCtx := svr.collectServiceAuthContext(ctx, nil, authcommon.Read, authcommon.DescribeServices) if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewBatchQueryResponseWithMsg(authcommon.ConvertToErrCode(err), err.Error()) + return api.NewBatchQueryResponse(authcommon.ConvertToErrCode(err)) } ctx = authCtx.GetRequestContext() ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) // 注入查询条件拦截器 + ctx = cachetypes.AppendServicePredicate(ctx, func(ctx context.Context, cbr *model.Service) bool { + ok := svr.policySvr.GetAuthChecker().ResourcePredicate(authCtx, &authcommon.ResourceEntry{ + Type: security.ResourceType_Services, + ID: cbr.ID, + Metadata: cbr.Meta, + }) + if ok { + return true + } + saveNs := svr.Cache().Namespace().GetNamespace(cbr.Namespace) + if saveNs == nil { + return false + } + // 检查下是否可以访问对应的 namespace + return svr.policySvr.GetAuthChecker().ResourcePredicate(authCtx, &authcommon.ResourceEntry{ + Type: security.ResourceType_Namespaces, + ID: saveNs.Name, + Metadata: saveNs.Metadata, + }) + }) + authCtx.SetRequestContext(ctx) resp := svr.nextSvr.GetServices(ctx, query) - if len(resp.Services) != 0 { - for index := range resp.Services { - svc := resp.Services[index] - // TODO 需要配合 metadata 做调整 - svc.Editable = utils.NewBoolValue(true) + for index := range resp.Services { + item := resp.Services[index] + authCtx.SetAccessResources(map[apisecurity.ResourceType][]authcommon.ResourceEntry{ + apisecurity.ResourceType_Services: { + { + Type: apisecurity.ResourceType_Services, + ID: item.GetId().GetValue(), + Metadata: item.Metadata, + }, + }, + }) + + // 检查 write 操作权限 + authCtx.SetMethod([]authcommon.ServerFunctionName{authcommon.UpdateServices}) + // 如果检查不通过,设置 editable 为 false + if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + item.Editable = utils.NewBoolValue(false) + } + + // 检查 delete 操作权限 + authCtx.SetMethod([]authcommon.ServerFunctionName{authcommon.DeleteServices}) + // 如果检查不通过,设置 editable 为 false + if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + item.Deleteable = utils.NewBoolValue(false) } } return resp @@ -158,7 +220,7 @@ func (svr *Server) GetServicesCount(ctx context.Context) *apiservice.BatchQueryR authCtx := svr.collectServiceAuthContext(ctx, nil, authcommon.Read, authcommon.DescribeServicesCount) if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewBatchQueryResponseWithMsg(authcommon.ConvertToErrCode(err), err.Error()) + return api.NewBatchQueryResponse(authcommon.ConvertToErrCode(err)) } ctx = authCtx.GetRequestContext() @@ -168,10 +230,11 @@ func (svr *Server) GetServicesCount(ctx context.Context) *apiservice.BatchQueryR // GetServiceToken 获取服务的 token func (svr *Server) GetServiceToken(ctx context.Context, req *apiservice.Service) *apiservice.Response { - authCtx := svr.collectServiceAuthContext(ctx, nil, authcommon.Read, authcommon.DescribeServiceToken) + authCtx := svr.collectServiceAuthContext(ctx, []*apiservice.Service{req}, authcommon.Read, + authcommon.DescribeServiceToken) if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewResponseWithMsg(authcommon.ConvertToErrCode(err), err.Error()) + return api.NewResponse(authcommon.ConvertToErrCode(err)) } ctx = authCtx.GetRequestContext() @@ -182,22 +245,14 @@ func (svr *Server) GetServiceToken(ctx context.Context, req *apiservice.Service) // GetServiceOwner 获取服务的 owner func (svr *Server) GetServiceOwner( ctx context.Context, req []*apiservice.Service) *apiservice.BatchQueryResponse { - authCtx := svr.collectServiceAuthContext(ctx, nil, authcommon.Read, authcommon.DescribeServiceOwner) + authCtx := svr.collectServiceAuthContext(ctx, req, authcommon.Read, authcommon.DescribeServiceOwner) if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { - return api.NewBatchQueryResponseWithMsg(authcommon.ConvertToErrCode(err), err.Error()) + return api.NewBatchQueryResponse(authcommon.ConvertToErrCode(err)) } ctx = authCtx.GetRequestContext() ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) - cachetypes.AppendServicePredicate(ctx, func(ctx context.Context, cbr *model.Service) bool { - return svr.policySvr.GetAuthChecker().ResourcePredicate(authCtx, &authcommon.ResourceEntry{ - Type: security.ResourceType_Services, - ID: cbr.ID, - Metadata: cbr.Meta, - }) - }) - return svr.nextSvr.GetServiceOwner(ctx, req) } diff --git a/service/interceptor/auth/service_alias.go b/service/interceptor/auth/service_alias.go index d52f86538..c0c523269 100644 --- a/service/interceptor/auth/service_alias.go +++ b/service/interceptor/auth/service_alias.go @@ -21,6 +21,7 @@ import ( "context" "github.com/polarismesh/specification/source/go/api/v1/security" + apisecurity "github.com/polarismesh/specification/source/go/api/v1/security" apiservice "github.com/polarismesh/specification/source/go/api/v1/service_manage" cachetypes "github.com/polarismesh/polaris/cache/api" @@ -95,13 +96,53 @@ func (svr *Server) GetServiceAliases(ctx context.Context, ctx = authCtx.GetRequestContext() ctx = context.WithValue(ctx, utils.ContextAuthContextKey, authCtx) - cachetypes.AppendServicePredicate(ctx, func(ctx context.Context, cbr *model.Service) bool { + ctx = cachetypes.AppendServicePredicate(ctx, func(ctx context.Context, cbr *model.Service) bool { + sourceSvc := svr.Cache().Service().GetServiceByID(cbr.Reference) + if sourceSvc == nil { + return false + } return svr.policySvr.GetAuthChecker().ResourcePredicate(authCtx, &authcommon.ResourceEntry{ Type: security.ResourceType_Services, - ID: cbr.ID, - Metadata: cbr.Meta, + ID: sourceSvc.ID, + Metadata: sourceSvc.Meta, }) }) - return svr.nextSvr.GetServiceAliases(ctx, query) + authCtx.SetRequestContext(ctx) + + resp := svr.nextSvr.GetServiceAliases(ctx, query) + for i := range resp.Aliases { + item := resp.Aliases[i] + sourceSvc := svr.Cache().Service().GetServiceByName(item.GetAlias().GetValue(), item.GetAliasNamespace().GetValue()) + if sourceSvc == nil { + item.Editable = utils.NewBoolValue(false) + item.Deleteable = utils.NewBoolValue(false) + continue + } + authCtx.SetAccessResources(map[security.ResourceType][]authcommon.ResourceEntry{ + security.ResourceType_Services: { + { + Type: apisecurity.ResourceType_Services, + ID: sourceSvc.ID, + Metadata: sourceSvc.Meta, + }, + }, + }) + + // 检查 write 操作权限 + authCtx.SetMethod([]authcommon.ServerFunctionName{authcommon.UpdateRateLimitRules, authcommon.EnableRateLimitRules}) + // 如果检查不通过,设置 editable 为 false + if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + item.Editable = utils.NewBoolValue(false) + } + + // 检查 delete 操作权限 + authCtx.SetMethod([]authcommon.ServerFunctionName{authcommon.DeleteRateLimitRules}) + // 如果检查不通过,设置 editable 为 false + if _, err := svr.policySvr.GetAuthChecker().CheckConsolePermission(authCtx); err != nil { + item.Deleteable = utils.NewBoolValue(false) + } + } + + return resp } diff --git a/service/interceptor/paramcheck/circuit_breaker.go b/service/interceptor/paramcheck/circuit_breaker.go index 452121325..8c69aa98b 100644 --- a/service/interceptor/paramcheck/circuit_breaker.go +++ b/service/interceptor/paramcheck/circuit_breaker.go @@ -19,6 +19,7 @@ package paramcheck import ( "context" + "strconv" "github.com/polarismesh/specification/source/go/api/v1/fault_tolerance" apifault "github.com/polarismesh/specification/source/go/api/v1/fault_tolerance" @@ -27,119 +28,210 @@ import ( apiservice "github.com/polarismesh/specification/source/go/api/v1/service_manage" api "github.com/polarismesh/polaris/common/api/v1" + "github.com/polarismesh/polaris/common/log" "github.com/polarismesh/polaris/common/utils" ) +var ( + // CircuitBreakerRuleFilters filter circuitbreaker rule query parameters + CircuitBreakerRuleFilters = map[string]bool{ + "brief": true, + "offset": true, + "limit": true, + "id": true, + "name": true, + "namespace": true, + "enable": true, + "level": true, + "service": true, + "serviceNamespace": true, + "srcService": true, + "srcNamespace": true, + "dstService": true, + "dstNamespace": true, + "dstMethod": true, + "description": true, + } +) + // GetMasterCircuitBreakers implements service.DiscoverServer. +// Deprecated: not support from 1.14.x func (svr *Server) GetMasterCircuitBreakers(ctx context.Context, query map[string]string) *service_manage.BatchQueryResponse { return svr.nextSvr.GetMasterCircuitBreakers(ctx, query) } // GetReleaseCircuitBreakers implements service.DiscoverServer. +// Deprecated: not support from 1.14.x func (svr *Server) GetReleaseCircuitBreakers(ctx context.Context, query map[string]string) *service_manage.BatchQueryResponse { return svr.nextSvr.GetReleaseCircuitBreakers(ctx, query) } // GetCircuitBreaker implements service.DiscoverServer. +// Deprecated: not support from 1.14.x func (svr *Server) GetCircuitBreaker(ctx context.Context, query map[string]string) *service_manage.BatchQueryResponse { return svr.nextSvr.GetCircuitBreaker(ctx, query) } // GetCircuitBreakerByService implements service.DiscoverServer. +// Deprecated: not support from 1.14.x func (svr *Server) GetCircuitBreakerByService(ctx context.Context, query map[string]string) *service_manage.BatchQueryResponse { return svr.nextSvr.GetCircuitBreakerByService(ctx, query) } // DeleteCircuitBreakers implements service.DiscoverServer. +// Deprecated: not support from 1.14.x func (svr *Server) DeleteCircuitBreakers(ctx context.Context, req []*fault_tolerance.CircuitBreaker) *service_manage.BatchWriteResponse { return svr.nextSvr.DeleteCircuitBreakers(ctx, req) } // GetCircuitBreakerToken implements service.DiscoverServer. +// Deprecated: not support from 1.14.x func (svr *Server) GetCircuitBreakerToken(ctx context.Context, req *fault_tolerance.CircuitBreaker) *service_manage.Response { return svr.nextSvr.GetCircuitBreakerToken(ctx, req) } // GetCircuitBreakerVersions implements service.DiscoverServer. +// Deprecated: not support from 1.14.x func (svr *Server) GetCircuitBreakerVersions(ctx context.Context, query map[string]string) *service_manage.BatchQueryResponse { return svr.nextSvr.GetCircuitBreakerVersions(ctx, query) } -// GetCircuitBreakerRules implements service.DiscoverServer. -func (svr *Server) GetCircuitBreakerRules(ctx context.Context, - query map[string]string) *service_manage.BatchQueryResponse { - return svr.nextSvr.GetCircuitBreakerRules(ctx, query) -} - -// DeleteCircuitBreakerRules implements service.DiscoverServer. -func (svr *Server) DeleteCircuitBreakerRules(ctx context.Context, - request []*fault_tolerance.CircuitBreakerRule) *service_manage.BatchWriteResponse { - if err := checkBatchCircuitBreakerRules(request); err != nil { - return err - } - return svr.nextSvr.DeleteCircuitBreakerRules(ctx, request) -} - -// EnableCircuitBreakerRules implements service.DiscoverServer. -func (svr *Server) EnableCircuitBreakerRules(ctx context.Context, - request []*fault_tolerance.CircuitBreakerRule) *service_manage.BatchWriteResponse { - if err := checkBatchCircuitBreakerRules(request); err != nil { - return err - } - return svr.nextSvr.EnableCircuitBreakerRules(ctx, request) -} - // ReleaseCircuitBreakers implements service.DiscoverServer. +// Deprecated: not support from 1.14.x func (svr *Server) ReleaseCircuitBreakers(ctx context.Context, req []*service_manage.ConfigRelease) *service_manage.BatchWriteResponse { return svr.nextSvr.ReleaseCircuitBreakers(ctx, req) } // UnBindCircuitBreakers implements service.DiscoverServer. +// Deprecated: not support from 1.14.x func (svr *Server) UnBindCircuitBreakers(ctx context.Context, req []*service_manage.ConfigRelease) *service_manage.BatchWriteResponse { return svr.nextSvr.UnBindCircuitBreakers(ctx, req) } -// UpdateCircuitBreakerRules implements service.DiscoverServer. -func (svr *Server) UpdateCircuitBreakerRules(ctx context.Context, request []*fault_tolerance.CircuitBreakerRule) *service_manage.BatchWriteResponse { - if err := checkBatchCircuitBreakerRules(request); err != nil { - return err - } - return svr.nextSvr.UpdateCircuitBreakerRules(ctx, request) -} - -// CreateCircuitBreakerRules implements service.DiscoverServer. -func (svr *Server) CreateCircuitBreakerRules(ctx context.Context, - request []*fault_tolerance.CircuitBreakerRule) *service_manage.BatchWriteResponse { - if err := checkBatchCircuitBreakerRules(request); err != nil { - return err - } - return svr.nextSvr.CreateCircuitBreakerRules(ctx, request) -} - // CreateCircuitBreakerVersions implements service.DiscoverServer. +// Deprecated: not support from 1.14.x func (svr *Server) CreateCircuitBreakerVersions(ctx context.Context, req []*fault_tolerance.CircuitBreaker) *service_manage.BatchWriteResponse { return svr.nextSvr.CreateCircuitBreakerVersions(ctx, req) } // CreateCircuitBreakers implements service.DiscoverServer. +// Deprecated: not support from 1.14.x func (svr *Server) CreateCircuitBreakers(ctx context.Context, req []*fault_tolerance.CircuitBreaker) *service_manage.BatchWriteResponse { return svr.nextSvr.CreateCircuitBreakers(ctx, req) } // UpdateCircuitBreakers implements service.DiscoverServer. +// Deprecated: not support from 1.14.x func (svr *Server) UpdateCircuitBreakers(ctx context.Context, req []*fault_tolerance.CircuitBreaker) *service_manage.BatchWriteResponse { return svr.nextSvr.UpdateCircuitBreakers(ctx, req) } +// ------------- 这里开始接口实现才是正式有效的 ------------- + +// GetCircuitBreakerRules implements service.DiscoverServer. +func (svr *Server) GetCircuitBreakerRules(ctx context.Context, + query map[string]string) *service_manage.BatchQueryResponse { + + offset, limit, err := utils.ParseOffsetAndLimit(query) + if err != nil { + return api.NewBatchQueryResponse(apimodel.Code_InvalidParameter) + } + searchFilter := make(map[string]string, len(query)) + for key, value := range query { + if _, ok := CircuitBreakerRuleFilters[key]; !ok { + log.Errorf("params %s is not allowed in querying circuitbreaker rule", key) + return api.NewBatchQueryResponse(apimodel.Code_InvalidParameter) + } + if value == "" { + continue + } + searchFilter[key] = value + } + + searchFilter["offset"] = strconv.FormatUint(uint64(offset), 10) + searchFilter["limit"] = strconv.FormatUint(uint64(limit), 10) + + return svr.nextSvr.GetCircuitBreakerRules(ctx, searchFilter) +} + +// DeleteCircuitBreakerRules implements service.DiscoverServer. +func (svr *Server) DeleteCircuitBreakerRules(ctx context.Context, + reqs []*fault_tolerance.CircuitBreakerRule) *service_manage.BatchWriteResponse { + if err := checkBatchCircuitBreakerRules(reqs); err != nil { + return err + } + batchRsp := api.NewBatchWriteResponse(apimodel.Code_ExecuteSuccess) + for i := range reqs { + rsp := checkCircuitBreakerRuleParams(reqs[i], true, false) + api.Collect(batchRsp, rsp) + } + if !api.IsSuccess(batchRsp) { + return batchRsp + } + return svr.nextSvr.DeleteCircuitBreakerRules(ctx, reqs) +} + +// EnableCircuitBreakerRules implements service.DiscoverServer. +func (svr *Server) EnableCircuitBreakerRules(ctx context.Context, + reqs []*fault_tolerance.CircuitBreakerRule) *service_manage.BatchWriteResponse { + if err := checkBatchCircuitBreakerRules(reqs); err != nil { + return err + } + batchRsp := api.NewBatchWriteResponse(apimodel.Code_ExecuteSuccess) + for i := range reqs { + rsp := checkCircuitBreakerRuleParams(reqs[i], true, false) + api.Collect(batchRsp, rsp) + } + if !api.IsSuccess(batchRsp) { + return batchRsp + } + return svr.nextSvr.EnableCircuitBreakerRules(ctx, reqs) +} + +// CreateCircuitBreakerRules implements service.DiscoverServer. +func (svr *Server) CreateCircuitBreakerRules(ctx context.Context, + reqs []*fault_tolerance.CircuitBreakerRule) *service_manage.BatchWriteResponse { + if err := checkBatchCircuitBreakerRules(reqs); err != nil { + return err + } + + batchRsp := api.NewBatchWriteResponse(apimodel.Code_ExecuteSuccess) + for i := range reqs { + rsp := checkCircuitBreakerRuleParams(reqs[i], false, true) + api.Collect(batchRsp, rsp) + } + if !api.IsSuccess(batchRsp) { + return batchRsp + } + return svr.nextSvr.CreateCircuitBreakerRules(ctx, reqs) +} + +// UpdateCircuitBreakerRules implements service.DiscoverServer. +func (svr *Server) UpdateCircuitBreakerRules(ctx context.Context, + reqs []*fault_tolerance.CircuitBreakerRule) *service_manage.BatchWriteResponse { + if err := checkBatchCircuitBreakerRules(reqs); err != nil { + return err + } + batchRsp := api.NewBatchWriteResponse(apimodel.Code_ExecuteSuccess) + for i := range reqs { + rsp := checkCircuitBreakerRuleParams(reqs[i], true, true) + api.Collect(batchRsp, rsp) + } + if !api.IsSuccess(batchRsp) { + return batchRsp + } + return svr.nextSvr.UpdateCircuitBreakerRules(ctx, reqs) +} + func checkBatchCircuitBreakerRules(req []*apifault.CircuitBreakerRule) *apiservice.BatchWriteResponse { if len(req) == 0 { return api.NewBatchWriteResponse(apimodel.Code_EmptyRequest) @@ -150,3 +242,41 @@ func checkBatchCircuitBreakerRules(req []*apifault.CircuitBreakerRule) *apiservi } return nil } + +func checkCircuitBreakerRuleParams( + req *apifault.CircuitBreakerRule, idRequired bool, nameRequired bool) *apiservice.Response { + if req == nil { + return api.NewResponse(apimodel.Code_EmptyRequest) + } + if resp := checkCircuitBreakerRuleParamsDbLen(req); nil != resp { + return resp + } + if nameRequired && len(req.GetName()) == 0 { + return api.NewResponse(apimodel.Code_InvalidCircuitBreakerName) + } + if idRequired && len(req.GetId()) == 0 { + return api.NewResponse(apimodel.Code_InvalidCircuitBreakerID) + } + return nil +} + +func checkCircuitBreakerRuleParamsDbLen(req *apifault.CircuitBreakerRule) *apiservice.Response { + if err := utils.CheckDbRawStrFieldLen( + req.RuleMatcher.GetSource().GetService(), utils.MaxDbServiceNameLength); err != nil { + return api.NewResponse(apimodel.Code_InvalidServiceName) + } + if err := utils.CheckDbRawStrFieldLen( + req.RuleMatcher.GetSource().GetNamespace(), utils.MaxDbServiceNamespaceLength); err != nil { + return api.NewResponse(apimodel.Code_InvalidNamespaceName) + } + if err := utils.CheckDbRawStrFieldLen(req.GetName(), utils.MaxRuleName); err != nil { + return api.NewResponse(apimodel.Code_InvalidCircuitBreakerName) + } + if err := utils.CheckDbRawStrFieldLen(req.GetNamespace(), utils.MaxDbServiceNamespaceLength); err != nil { + return api.NewResponse(apimodel.Code_InvalidNamespaceName) + } + if err := utils.CheckDbRawStrFieldLen(req.GetDescription(), utils.MaxCommentLength); err != nil { + return api.NewResponse(apimodel.Code_InvalidServiceComment) + } + return nil +} diff --git a/service/interceptor/paramcheck/fault_detect.go b/service/interceptor/paramcheck/fault_detect.go index 28a8d4ca4..2aac60ecf 100644 --- a/service/interceptor/paramcheck/fault_detect.go +++ b/service/interceptor/paramcheck/fault_detect.go @@ -119,7 +119,8 @@ func (svr *Server) CreateFaultDetectRules(ctx context.Context, } // UpdateFaultDetectRules implements service.DiscoverServer. -func (svr *Server) UpdateFaultDetectRules(ctx context.Context, request []*fault_tolerance.FaultDetectRule) *service_manage.BatchWriteResponse { +func (svr *Server) UpdateFaultDetectRules(ctx context.Context, + request []*fault_tolerance.FaultDetectRule) *service_manage.BatchWriteResponse { if checkErr := checkBatchFaultDetectRules(request); checkErr != nil { return checkErr } diff --git a/service/interceptor/paramcheck/lane.go b/service/interceptor/paramcheck/lane.go new file mode 100644 index 000000000..db4293804 --- /dev/null +++ b/service/interceptor/paramcheck/lane.go @@ -0,0 +1,156 @@ +/** + * Tencent is pleased to support the open source community by making Polaris available. + * + * Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. + * + * Licensed under the BSD 3-Clause License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://opensource.org/licenses/BSD-3-Clause + * + * Unless required by applicable law or agreed to in writing, software distributed + * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR + * CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package paramcheck + +import ( + "context" + "strconv" + + apimodel "github.com/polarismesh/specification/source/go/api/v1/model" + apiservice "github.com/polarismesh/specification/source/go/api/v1/service_manage" + apitraffic "github.com/polarismesh/specification/source/go/api/v1/traffic_manage" + "go.uber.org/zap" + "google.golang.org/protobuf/types/known/wrapperspb" + + api "github.com/polarismesh/polaris/common/api/v1" + "github.com/polarismesh/polaris/common/log" + "github.com/polarismesh/polaris/common/utils" +) + +var ( + laneGroupSearchAttributes = map[string]struct{}{ + "id": {}, + "name": {}, + "offset": {}, + "brief": {}, + "limit": {}, + "order_type": {}, + "order_field": {}, + } +) + +// CreateLaneGroups 批量创建泳道组 +func (svr *Server) CreateLaneGroups(ctx context.Context, reqs []*apitraffic.LaneGroup) *apiservice.BatchWriteResponse { + if err := checkBatchLaneGroupRules(reqs); err != nil { + return err + } + batchRsp := api.NewBatchWriteResponse(apimodel.Code_ExecuteSuccess) + for i := range reqs { + rsp := checkLaneGroupParam(reqs[i], false) + api.Collect(batchRsp, rsp) + } + + if !api.IsSuccess(batchRsp) { + return batchRsp + } + return svr.nextSvr.CreateLaneGroups(ctx, reqs) +} + +// UpdateLaneGroups 批量更新泳道组 +func (svr *Server) UpdateLaneGroups(ctx context.Context, reqs []*apitraffic.LaneGroup) *apiservice.BatchWriteResponse { + if err := checkBatchLaneGroupRules(reqs); err != nil { + return err + } + batchRsp := api.NewBatchWriteResponse(apimodel.Code_ExecuteSuccess) + for i := range reqs { + rsp := checkLaneGroupParam(reqs[i], true) + api.Collect(batchRsp, rsp) + } + + if !api.IsSuccess(batchRsp) { + return batchRsp + } + return svr.nextSvr.UpdateLaneGroups(ctx, reqs) +} + +// DeleteLaneGroups 批量删除泳道组 +func (svr *Server) DeleteLaneGroups(ctx context.Context, reqs []*apitraffic.LaneGroup) *apiservice.BatchWriteResponse { + if err := checkBatchLaneGroupRules(reqs); err != nil { + return err + } + return svr.nextSvr.DeleteLaneGroups(ctx, reqs) +} + +// GetLaneGroups 查询泳道组列表 +func (svr *Server) GetLaneGroups(ctx context.Context, filter map[string]string) *apiservice.BatchQueryResponse { + offset, limit, err := utils.ParseOffsetAndLimit(filter) + if err != nil { + return api.NewBatchQueryResponseWithMsg(apimodel.Code_BadRequest, err.Error()) + } + + for k := range filter { + if _, ok := laneGroupSearchAttributes[k]; !ok { + log.Error("[Server][LaneGroup][Query] not allowed", zap.String("attribute", k), utils.RequestID(ctx)) + return api.NewBatchQueryResponseWithMsg(apimodel.Code_InvalidParameter, k+" is not allowed") + } + if filter[k] == "" { + delete(filter, k) + } + } + + if _, ok := filter["order_field"]; !ok { + filter["order_field"] = "mtime" + } + if _, ok := filter["order_type"]; !ok { + filter["order_type"] = "desc" + } + + filter["offset"] = strconv.FormatUint(uint64(offset), 10) + filter["limit"] = strconv.FormatUint(uint64(limit), 10) + + return svr.nextSvr.GetLaneGroups(ctx, filter) +} + +func checkBatchLaneGroupRules(req []*apitraffic.LaneGroup) *apiservice.BatchWriteResponse { + if len(req) == 0 { + return api.NewBatchWriteResponse(apimodel.Code_EmptyRequest) + } + + if len(req) > utils.MaxBatchSize { + return api.NewBatchWriteResponse(apimodel.Code_BatchSizeOverLimit) + } + return nil +} + +func checkLaneGroupParam(req *apitraffic.LaneGroup, update bool) *apiservice.Response { + if len(req.GetName()) >= utils.MaxRuleName { + return api.NewResponseWithMsg(apimodel.Code_InvalidParameter, "lane_group name size must be <= 64") + } + if err := utils.CheckResourceName(wrapperspb.String(req.GetName())); err != nil { + return api.NewResponseWithMsg(apimodel.Code_InvalidParameter, err.Error()) + } + if len(req.Rules) > utils.MaxBatchSize { + return api.NewResponseWithMsg(apimodel.Code_InvalidParameter, "lane_rule size must be <= 100") + } + for i := range req.Rules { + rule := req.Rules[i] + if err := utils.CheckResourceName(wrapperspb.String(rule.GetName())); err != nil { + return api.NewResponseWithMsg(apimodel.Code_InvalidParameter, err.Error()) + } + if len(rule.GetName()) >= utils.MaxRuleName { + return api.NewResponseWithMsg(apimodel.Code_InvalidParameter, "lane_rule name size must be <= 64") + } + } + + if update { + if req.GetId() == "" { + return api.NewResponseWithMsg(apimodel.Code_InvalidParameter, "lane_group id is empty") + } + } + return nil +} diff --git a/service/interceptor/paramcheck/ratelimit.go b/service/interceptor/paramcheck/ratelimit.go index 4720b462f..0c210a1a1 100644 --- a/service/interceptor/paramcheck/ratelimit.go +++ b/service/interceptor/paramcheck/ratelimit.go @@ -19,27 +19,80 @@ package paramcheck import ( "context" + "time" + "github.com/golang/protobuf/ptypes" + apimodel "github.com/polarismesh/specification/source/go/api/v1/model" "github.com/polarismesh/specification/source/go/api/v1/service_manage" + apiservice "github.com/polarismesh/specification/source/go/api/v1/service_manage" "github.com/polarismesh/specification/source/go/api/v1/traffic_manage" + apitraffic "github.com/polarismesh/specification/source/go/api/v1/traffic_manage" + + api "github.com/polarismesh/polaris/common/api/v1" + "github.com/polarismesh/polaris/common/log" + "github.com/polarismesh/polaris/common/utils" ) // CreateRateLimits implements service.DiscoverServer. func (svr *Server) CreateRateLimits(ctx context.Context, - request []*traffic_manage.Rule) *service_manage.BatchWriteResponse { - return svr.nextSvr.CreateRateLimits(ctx, request) + reqs []*traffic_manage.Rule) *service_manage.BatchWriteResponse { + if err := checkBatchRateLimits(reqs); err != nil { + return err + } + + batchRsp := api.NewBatchWriteResponse(apimodel.Code_ExecuteSuccess) + for i := range reqs { + // 参数校验 + // 参数校验 + if resp := checkRateLimitParams(reqs[i]); resp != nil { + api.Collect(batchRsp, resp) + continue + } + if resp := checkRateLimitRuleParams(ctx, reqs[i]); resp != nil { + api.Collect(batchRsp, resp) + continue + } + } + if !api.IsSuccess(batchRsp) { + return batchRsp + } + + return svr.nextSvr.CreateRateLimits(ctx, reqs) } // DeleteRateLimits implements service.DiscoverServer. -func (svr *Server) DeleteRateLimits(ctx context.Context, - request []*traffic_manage.Rule) *service_manage.BatchWriteResponse { - return svr.nextSvr.DeleteRateLimits(ctx, request) +func (svr *Server) DeleteRateLimits(ctx context.Context, reqs []*traffic_manage.Rule) *service_manage.BatchWriteResponse { + if err := checkBatchRateLimits(reqs); err != nil { + return err + } + batchRsp := api.NewBatchWriteResponse(apimodel.Code_ExecuteSuccess) + for i := range reqs { + // 参数校验 + resp := checkRevisedRateLimitParams(reqs[i]) + api.Collect(batchRsp, resp) + } + if !api.IsSuccess(batchRsp) { + return batchRsp + } + return svr.nextSvr.DeleteRateLimits(ctx, reqs) } // EnableRateLimits implements service.DiscoverServer. func (svr *Server) EnableRateLimits(ctx context.Context, - request []*traffic_manage.Rule) *service_manage.BatchWriteResponse { - return svr.nextSvr.EnableRateLimits(ctx, request) + reqs []*traffic_manage.Rule) *service_manage.BatchWriteResponse { + if err := checkBatchRateLimits(reqs); err != nil { + return err + } + batchRsp := api.NewBatchWriteResponse(apimodel.Code_ExecuteSuccess) + for i := range reqs { + // 参数校验 + resp := checkRevisedRateLimitParams(reqs[i]) + api.Collect(batchRsp, resp) + } + if !api.IsSuccess(batchRsp) { + return batchRsp + } + return svr.nextSvr.EnableRateLimits(ctx, reqs) } // GetRateLimits implements service.DiscoverServer. @@ -49,6 +102,104 @@ func (svr *Server) GetRateLimits(ctx context.Context, } // UpdateRateLimits implements service.DiscoverServer. -func (svr *Server) UpdateRateLimits(ctx context.Context, request []*traffic_manage.Rule) *service_manage.BatchWriteResponse { - return svr.nextSvr.UpdateRateLimits(ctx, request) +func (svr *Server) UpdateRateLimits(ctx context.Context, reqs []*traffic_manage.Rule) *service_manage.BatchWriteResponse { + if err := checkBatchRateLimits(reqs); err != nil { + return err + } + batchRsp := api.NewBatchWriteResponse(apimodel.Code_ExecuteSuccess) + for i := range reqs { + // 参数校验 + if resp := checkRevisedRateLimitParams(reqs[i]); resp != nil { + api.Collect(batchRsp, resp) + continue + } + if resp := checkRateLimitRuleParams(ctx, reqs[i]); resp != nil { + api.Collect(batchRsp, resp) + continue + } + if resp := checkRateLimitParamsDbLen(reqs[i]); resp != nil { + api.Collect(batchRsp, resp) + continue + } + } + if !api.IsSuccess(batchRsp) { + return batchRsp + } + + return svr.nextSvr.UpdateRateLimits(ctx, reqs) +} + +// checkBatchRateLimits 检查批量请求的限流规则 +func checkBatchRateLimits(req []*apitraffic.Rule) *apiservice.BatchWriteResponse { + if len(req) == 0 { + return api.NewBatchWriteResponse(apimodel.Code_EmptyRequest) + } + + if len(req) > utils.MaxBatchSize { + return api.NewBatchWriteResponse(apimodel.Code_BatchSizeOverLimit) + } + + return nil +} + +// checkRateLimitParams 检查限流规则基础参数 +func checkRateLimitParams(req *apitraffic.Rule) *apiservice.Response { + if req == nil { + return api.NewRateLimitResponse(apimodel.Code_EmptyRequest, req) + } + if err := utils.CheckResourceName(req.GetNamespace()); err != nil { + return api.NewRateLimitResponse(apimodel.Code_InvalidNamespaceName, req) + } + if err := utils.CheckResourceName(req.GetService()); err != nil { + return api.NewRateLimitResponse(apimodel.Code_InvalidServiceName, req) + } + if resp := checkRateLimitParamsDbLen(req); nil != resp { + return resp + } + return nil +} + +// checkRateLimitParams 检查限流规则基础参数 +func checkRateLimitParamsDbLen(req *apitraffic.Rule) *apiservice.Response { + if err := utils.CheckDbStrFieldLen(req.GetService(), utils.MaxDbServiceNameLength); err != nil { + return api.NewRateLimitResponse(apimodel.Code_InvalidServiceName, req) + } + if err := utils.CheckDbStrFieldLen(req.GetNamespace(), utils.MaxDbServiceNamespaceLength); err != nil { + return api.NewRateLimitResponse(apimodel.Code_InvalidNamespaceName, req) + } + if err := utils.CheckDbStrFieldLen(req.GetName(), utils.MaxDbRateLimitName); err != nil { + return api.NewRateLimitResponse(apimodel.Code_InvalidRateLimitName, req) + } + return nil +} + +// checkRateLimitRuleParams 检查限流规则其他参数 +func checkRateLimitRuleParams(ctx context.Context, req *apitraffic.Rule) *apiservice.Response { + // 检查amounts是否有重复周期 + amounts := req.GetAmounts() + durations := make(map[time.Duration]bool) + for _, amount := range amounts { + d := amount.GetValidDuration() + duration, err := ptypes.Duration(d) + if err != nil { + log.Error(err.Error(), utils.RequestID(ctx)) + return api.NewRateLimitResponse(apimodel.Code_InvalidRateLimitAmounts, req) + } + durations[duration] = true + } + if len(amounts) != len(durations) { + return api.NewRateLimitResponse(apimodel.Code_InvalidRateLimitAmounts, req) + } + return nil +} + +// checkRevisedRateLimitParams 检查修改/删除限流规则基础参数 +func checkRevisedRateLimitParams(req *apitraffic.Rule) *apiservice.Response { + if req == nil { + return api.NewRateLimitResponse(apimodel.Code_EmptyRequest, req) + } + if req.GetId().GetValue() == "" { + return api.NewRateLimitResponse(apimodel.Code_InvalidRateLimitID, req) + } + return nil } diff --git a/service/interceptor/paramcheck/route_rule.go b/service/interceptor/paramcheck/route_rule.go index 4da1fedad..46eab06eb 100644 --- a/service/interceptor/paramcheck/route_rule.go +++ b/service/interceptor/paramcheck/route_rule.go @@ -19,45 +19,61 @@ package paramcheck import ( "context" + "strconv" + apimodel "github.com/polarismesh/specification/source/go/api/v1/model" "github.com/polarismesh/specification/source/go/api/v1/service_manage" + apiservice "github.com/polarismesh/specification/source/go/api/v1/service_manage" "github.com/polarismesh/specification/source/go/api/v1/traffic_manage" + apitraffic "github.com/polarismesh/specification/source/go/api/v1/traffic_manage" + + apiv1 "github.com/polarismesh/polaris/common/api/v1" + "github.com/polarismesh/polaris/common/log" + "github.com/polarismesh/polaris/common/model" + "github.com/polarismesh/polaris/common/utils" +) + +var ( + // RoutingConfigV2FilterAttrs router config filter attrs + RoutingConfigV2FilterAttrs = map[string]bool{ + "id": true, + "name": true, + "service": true, + "namespace": true, + "source_service": true, + "destination_service": true, + "source_namespace": true, + "destination_namespace": true, + "enable": true, + "offset": true, + "limit": true, + "order_field": true, + "order_type": true, + } ) // UpdateRoutingConfigs implements service.DiscoverServer. +// Deprecated: not support from 1.19.x func (svr *Server) UpdateRoutingConfigs(ctx context.Context, req []*traffic_manage.Routing) *service_manage.BatchWriteResponse { return svr.nextSvr.UpdateRoutingConfigs(ctx, req) } -// UpdateRoutingConfigsV2 implements service.DiscoverServer. -func (svr *Server) UpdateRoutingConfigsV2(ctx context.Context, req []*traffic_manage.RouteRule) *service_manage.BatchWriteResponse { - return svr.nextSvr.UpdateRoutingConfigsV2(ctx, req) -} - -// QueryRoutingConfigsV2 implements service.DiscoverServer. -func (svr *Server) QueryRoutingConfigsV2(ctx context.Context, query map[string]string) *service_manage.BatchQueryResponse { - return svr.nextSvr.QueryRoutingConfigsV2(ctx, query) -} - // GetRoutingConfigs implements service.DiscoverServer. +// Deprecated: not support from 1.19.x func (svr *Server) GetRoutingConfigs(ctx context.Context, query map[string]string) *service_manage.BatchQueryResponse { return svr.nextSvr.GetRoutingConfigs(ctx, query) } -// EnableRoutings implements service.DiscoverServer. -func (svr *Server) EnableRoutings(ctx context.Context, - req []*traffic_manage.RouteRule) *service_manage.BatchWriteResponse { - return svr.nextSvr.EnableRoutings(ctx, req) -} - // CreateRoutingConfigs implements service.DiscoverServer. +// Deprecated: not support from 1.19.x func (svr *Server) CreateRoutingConfigs(ctx context.Context, req []*traffic_manage.Routing) *service_manage.BatchWriteResponse { return svr.nextSvr.CreateRoutingConfigs(ctx, req) } // DeleteRoutingConfigs implements service.DiscoverServer. +// Deprecated: not support from 1.19.x func (svr *Server) DeleteRoutingConfigs(ctx context.Context, req []*traffic_manage.Routing) *service_manage.BatchWriteResponse { return svr.nextSvr.DeleteRoutingConfigs(ctx, req) @@ -66,11 +82,211 @@ func (svr *Server) DeleteRoutingConfigs(ctx context.Context, // CreateRoutingConfigsV2 implements service.DiscoverServer. func (svr *Server) CreateRoutingConfigsV2(ctx context.Context, req []*traffic_manage.RouteRule) *service_manage.BatchWriteResponse { + if err := checkBatchRoutingConfigV2(req); err != nil { + return err + } + batchRsp := apiv1.NewBatchWriteResponse(apimodel.Code_ExecuteSuccess) + for _, item := range req { + if resp := checkRoutingConfigV2(item); resp != nil { + apiv1.Collect(batchRsp, resp) + } + } + if !apiv1.IsSuccess(batchRsp) { + return batchRsp + } return svr.nextSvr.CreateRoutingConfigsV2(ctx, req) } +// EnableRoutings implements service.DiscoverServer. +func (svr *Server) EnableRoutings(ctx context.Context, + req []*traffic_manage.RouteRule) *service_manage.BatchWriteResponse { + if err := checkBatchRoutingConfigV2(req); err != nil { + return err + } + batchRsp := apiv1.NewBatchWriteResponse(apimodel.Code_ExecuteSuccess) + for _, item := range req { + if resp := checkRoutingConfigIDV2(item); resp != nil { + apiv1.Collect(batchRsp, resp) + } + } + if !apiv1.IsSuccess(batchRsp) { + return batchRsp + } + return svr.nextSvr.EnableRoutings(ctx, req) +} + +// UpdateRoutingConfigsV2 implements service.DiscoverServer. +func (svr *Server) UpdateRoutingConfigsV2(ctx context.Context, + req []*traffic_manage.RouteRule) *service_manage.BatchWriteResponse { + if err := checkBatchRoutingConfigV2(req); err != nil { + return err + } + batchRsp := apiv1.NewBatchWriteResponse(apimodel.Code_ExecuteSuccess) + for _, item := range req { + if resp := checkUpdateRoutingConfigV2(item); resp != nil { + apiv1.Collect(batchRsp, resp) + } + } + if !apiv1.IsSuccess(batchRsp) { + return batchRsp + } + return svr.nextSvr.UpdateRoutingConfigsV2(ctx, req) +} + +// QueryRoutingConfigsV2 implements service.DiscoverServer. +func (svr *Server) QueryRoutingConfigsV2(ctx context.Context, + query map[string]string) *service_manage.BatchQueryResponse { + + offset, limit, err := utils.ParseOffsetAndLimit(query) + if err != nil { + return apiv1.NewBatchQueryResponse(apimodel.Code_InvalidParameter) + } + + filter := make(map[string]string) + for key, value := range query { + if _, ok := RoutingConfigV2FilterAttrs[key]; !ok { + log.Errorf("[Routing][V2][Query] attribute(%s) is not allowed", key) + return apiv1.NewBatchQueryResponse(apimodel.Code_InvalidParameter) + } + filter[key] = value + } + filter["offset"] = strconv.FormatUint(uint64(offset), 10) + filter["limit"] = strconv.FormatUint(uint64(limit), 10) + + return svr.nextSvr.QueryRoutingConfigsV2(ctx, filter) +} + // DeleteRoutingConfigsV2 implements service.DiscoverServer. func (svr *Server) DeleteRoutingConfigsV2(ctx context.Context, req []*traffic_manage.RouteRule) *service_manage.BatchWriteResponse { + if err := checkBatchRoutingConfigV2(req); err != nil { + return err + } + batchRsp := apiv1.NewBatchWriteResponse(apimodel.Code_ExecuteSuccess) + for _, item := range req { + if resp := checkRoutingConfigIDV2(item); resp != nil { + apiv1.Collect(batchRsp, resp) + } + } + if !apiv1.IsSuccess(batchRsp) { + return batchRsp + } return svr.nextSvr.DeleteRoutingConfigsV2(ctx, req) } + +// checkBatchRoutingConfig Check batch request +func checkBatchRoutingConfigV2(req []*apitraffic.RouteRule) *apiservice.BatchWriteResponse { + if len(req) == 0 { + return apiv1.NewBatchWriteResponse(apimodel.Code_EmptyRequest) + } + + if len(req) > utils.MaxBatchSize { + return apiv1.NewBatchWriteResponse(apimodel.Code_BatchSizeOverLimit) + } + + return nil +} + +// checkRoutingConfig Check the validity of the basic parameter of the routing configuration +func checkRoutingConfigV2(req *apitraffic.RouteRule) *apiservice.Response { + if req == nil { + return apiv1.NewRouterResponse(apimodel.Code_EmptyRequest, req) + } + + if err := checkRoutingNameAndNamespace(req); err != nil { + return err + } + + if err := checkRoutingConfigPriorityV2(req); err != nil { + return err + } + + if err := checkRoutingPolicyV2(req); err != nil { + return err + } + + return nil +} + +// checkUpdateRoutingConfigV2 Check the validity of the basic parameter of the routing configuration +func checkUpdateRoutingConfigV2(req *apitraffic.RouteRule) *apiservice.Response { + if resp := checkRoutingConfigIDV2(req); resp != nil { + return resp + } + + if err := checkRoutingNameAndNamespace(req); err != nil { + return err + } + + if err := checkRoutingConfigPriorityV2(req); err != nil { + return err + } + + if err := checkRoutingPolicyV2(req); err != nil { + return err + } + + return nil +} + +func checkRoutingNameAndNamespace(req *apitraffic.RouteRule) *apiservice.Response { + if err := utils.CheckDbStrFieldLen(utils.NewStringValue(req.GetName()), utils.MaxRuleName); err != nil { + return apiv1.NewRouterResponse(apimodel.Code_InvalidRoutingName, req) + } + + if err := utils.CheckDbStrFieldLen(utils.NewStringValue(req.GetNamespace()), + utils.MaxDbServiceNamespaceLength); err != nil { + return apiv1.NewRouterResponse(apimodel.Code_InvalidNamespaceName, req) + } + + return nil +} + +func checkRoutingConfigIDV2(req *apitraffic.RouteRule) *apiservice.Response { + if req == nil { + return apiv1.NewRouterResponse(apimodel.Code_EmptyRequest, req) + } + + if req.Id == "" { + return apiv1.NewResponse(apimodel.Code_InvalidRoutingID) + } + + return nil +} + +func checkRoutingConfigPriorityV2(req *apitraffic.RouteRule) *apiservice.Response { + if req == nil { + return apiv1.NewRouterResponse(apimodel.Code_EmptyRequest, req) + } + + if req.Priority > 10 { + return apiv1.NewResponse(apimodel.Code_InvalidRoutingPriority) + } + + return nil +} + +func checkRoutingPolicyV2(req *apitraffic.RouteRule) *apiservice.Response { + if req == nil { + return apiv1.NewRouterResponse(apimodel.Code_EmptyRequest, req) + } + + if req.GetRoutingPolicy() != apitraffic.RoutingPolicy_RulePolicy { + return apiv1.NewRouterResponse(apimodel.Code_InvalidRoutingPolicy, req) + } + + // Automatically supplement @Type attribute according to Policy + if req.RoutingConfig.TypeUrl == "" { + if req.GetRoutingPolicy() == apitraffic.RoutingPolicy_RulePolicy { + req.RoutingConfig.TypeUrl = model.RuleRoutingTypeUrl + } + if req.GetRoutingPolicy() == apitraffic.RoutingPolicy_MetadataPolicy { + req.RoutingConfig.TypeUrl = model.MetaRoutingTypeUrl + } + if req.GetRoutingPolicy() == apitraffic.RoutingPolicy_NearbyPolicy { + req.RoutingConfig.TypeUrl = model.NearbyRoutingTypeUrl + } + } + + return nil +} diff --git a/service/interceptor/paramcheck/server.go b/service/interceptor/paramcheck/server.go index f46a54065..680e14bf1 100644 --- a/service/interceptor/paramcheck/server.go +++ b/service/interceptor/paramcheck/server.go @@ -35,9 +35,10 @@ type Server struct { ratelimit plugin.Ratelimit } -func NewServer(nextSvr service.DiscoverServer) service.DiscoverServer { +func NewServer(nextSvr service.DiscoverServer, s store.Store) service.DiscoverServer { proxy := &Server{ nextSvr: nextSvr, + storage: s, } // 获取限流插件 proxy.ratelimit = plugin.GetRatelimit() diff --git a/service/interceptor/paramcheck/service.go b/service/interceptor/paramcheck/service.go index d46394f22..2dc917322 100644 --- a/service/interceptor/paramcheck/service.go +++ b/service/interceptor/paramcheck/service.go @@ -20,6 +20,7 @@ package paramcheck import ( "context" "errors" + "strconv" "strings" apimodel "github.com/polarismesh/specification/source/go/api/v1/model" @@ -175,6 +176,14 @@ func (svr *Server) GetServices(ctx context.Context, query map[string]string) *se } } + // 判断offset和limit是否为int,并从filters清除offset/limit参数 + offset, limit, err := utils.ParseOffsetAndLimit(query) + if err != nil { + return api.NewBatchQueryResponse(apimodel.Code_InvalidParameter) + } + query["offset"] = strconv.FormatUint(uint64(offset), 10) + query["limit"] = strconv.FormatUint(uint64(limit), 10) + return svr.nextSvr.GetServices(ctx, query) } diff --git a/service/interceptor/paramcheck/service_alias.go b/service/interceptor/paramcheck/service_alias.go index c01f58de8..7435cfbec 100644 --- a/service/interceptor/paramcheck/service_alias.go +++ b/service/interceptor/paramcheck/service_alias.go @@ -40,12 +40,8 @@ func (svr *Server) CreateServiceAlias(ctx context.Context, // DeleteServiceAliases implements service.DiscoverServer. func (svr *Server) DeleteServiceAliases(ctx context.Context, req []*service_manage.ServiceAlias) *service_manage.BatchWriteResponse { - if len(req) == 0 { - return api.NewBatchWriteResponse(apimodel.Code_EmptyRequest) - } - - if len(req) > utils.MaxBatchSize { - return api.NewBatchWriteResponse(apimodel.Code_BatchSizeOverLimit) + if checkError := checkBatchAlias(req); checkError != nil { + return checkError } batchRsp := api.NewBatchWriteResponse(apimodel.Code_ExecuteSuccess) @@ -74,6 +70,18 @@ func (svr *Server) GetServiceAliases(ctx context.Context, return svr.nextSvr.GetServiceAliases(ctx, query) } +func checkBatchAlias(req []*apiservice.ServiceAlias) *apiservice.BatchWriteResponse { + if len(req) == 0 { + return api.NewBatchWriteResponse(apimodel.Code_EmptyRequest) + } + + if len(req) > utils.MaxBatchSize { + return api.NewBatchWriteResponse(apimodel.Code_BatchSizeOverLimit) + } + + return nil +} + // checkCreateServiceAliasReq 检查别名请求 func checkCreateServiceAliasReq(ctx context.Context, req *apiservice.ServiceAlias) *apiservice.Response { response, done := preCheckAlias(req) diff --git a/service/interceptor/register.go b/service/interceptor/register.go index 399ba8294..87a8ac433 100644 --- a/service/interceptor/register.go +++ b/service/interceptor/register.go @@ -22,27 +22,30 @@ import ( "github.com/polarismesh/polaris/service" service_auth "github.com/polarismesh/polaris/service/interceptor/auth" "github.com/polarismesh/polaris/service/interceptor/paramcheck" + "github.com/polarismesh/polaris/store" ) func init() { - err := service.RegisterServerProxy("paramcheck", func(pre service.DiscoverServer) (service.DiscoverServer, error) { - return paramcheck.NewServer(pre), nil + err := service.RegisterServerProxy("paramcheck", func(pre service.DiscoverServer, + s store.Store) (service.DiscoverServer, error) { + return paramcheck.NewServer(pre, s), nil }) if err != nil { panic(err) } - err = service.RegisterServerProxy("auth", func(pre service.DiscoverServer) (service.DiscoverServer, error) { - userMgn, err := auth.GetUserServer() + err = service.RegisterServerProxy("auth", func(pre service.DiscoverServer, + s store.Store) (service.DiscoverServer, error) { + userSvr, err := auth.GetUserServer() if err != nil { return nil, err } - strategyMgn, err := auth.GetStrategyServer() + policySvr, err := auth.GetStrategyServer() if err != nil { return nil, err } - return service_auth.NewServerAuthAbility(pre, userMgn, strategyMgn), nil + return service_auth.NewServer(pre, userSvr, policySvr), nil }) if err != nil { panic(err) diff --git a/service/lane.go b/service/lane.go new file mode 100644 index 000000000..78b08f39c --- /dev/null +++ b/service/lane.go @@ -0,0 +1,295 @@ +/** + * Tencent is pleased to support the open source community by making Polaris available. + * + * Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. + * + * Licensed under the BSD 3-Clause License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://opensource.org/licenses/BSD-3-Clause + * + * Unless required by applicable law or agreed to in writing, software distributed + * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR + * CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package service + +import ( + "context" + "fmt" + "time" + + "github.com/golang/protobuf/jsonpb" + "github.com/golang/protobuf/proto" + apimodel "github.com/polarismesh/specification/source/go/api/v1/model" + "github.com/polarismesh/specification/source/go/api/v1/security" + apiservice "github.com/polarismesh/specification/source/go/api/v1/service_manage" + apitraffic "github.com/polarismesh/specification/source/go/api/v1/traffic_manage" + "go.uber.org/zap" + "google.golang.org/protobuf/types/known/anypb" + "google.golang.org/protobuf/types/known/wrapperspb" + + cachetypes "github.com/polarismesh/polaris/cache/api" + api "github.com/polarismesh/polaris/common/api/v1" + "github.com/polarismesh/polaris/common/model" + authcommon "github.com/polarismesh/polaris/common/model/auth" + commonstore "github.com/polarismesh/polaris/common/store" + "github.com/polarismesh/polaris/common/utils" +) + +// CreateLaneGroups 批量创建泳道组 +func (s *Server) CreateLaneGroups(ctx context.Context, req []*apitraffic.LaneGroup) *apiservice.BatchWriteResponse { + responses := api.NewBatchWriteResponse(apimodel.Code_ExecuteSuccess) + for i := range req { + resp := s.CreateLaneGroup(ctx, req[i]) + api.Collect(responses, resp) + } + return api.FormatBatchWriteResponse(responses) +} + +// CreateLaneGroup 创建泳道组 +func (s *Server) CreateLaneGroup(ctx context.Context, req *apitraffic.LaneGroup) *apiservice.Response { + tx, err := s.storage.StartTx() + if err != nil { + log.Error("[Service][Lane] open store transaction fail", utils.RequestID(ctx), zap.Error(err)) + return api.NewResponse(commonstore.StoreCode2APICode(err)) + } + defer func() { + _ = tx.Rollback() + }() + + saveVal, err := s.storage.LockLaneGroup(tx, req.GetName()) + if err != nil { + log.Error("[Service][Lane] lock one lane_group", utils.RequestID(ctx), + zap.String("name", req.GetName()), zap.Error(err)) + return api.NewResponse(commonstore.StoreCode2APICode(err)) + } + if saveVal != nil { + return api.NewResponse(apimodel.Code_ExistedResource) + } + saveData := &model.LaneGroup{} + if err := saveData.FromSpec(req); err != nil { + log.Error("[Service][Lane] create lane_group transfer spec to model", utils.RequestID(ctx), zap.Error(err)) + return api.NewResponse(apimodel.Code_ExecuteException) + } + saveData.ID = utils.DefaultString(req.GetId(), utils.NewUUID()) + saveData.Revision = utils.DefaultString(req.GetRevision(), utils.NewUUID()) + + // 由于这里是新建,所以需要手动再把两个 flag 字段设置为 true 状态 + for i := range saveData.LaneRules { + saveData.LaneRules[i].SetAddFlag(true) + saveData.LaneRules[i].SetChangeEnable(true) + } + + if err := s.storage.AddLaneGroup(tx, saveData); err != nil { + log.Error("[Service][Lane] save lane_group", utils.RequestID(ctx), zap.String("name", saveData.Name), zap.Error(err)) + return api.NewResponse(commonstore.StoreCode2APICode(err)) + } + req.Id = saveData.ID + + if err := tx.Commit(); err != nil { + log.Error("[Service][Lane] commit store transaction fail", utils.RequestID(ctx), zap.Error(err)) + return api.NewResponse(commonstore.StoreCode2APICode(err)) + } + + s.RecordHistory(ctx, laneGroupRecordEntry(ctx, req, saveData, model.OCreate)) + _ = s.afterRuleResource(ctx, model.RRouting, authcommon.ResourceEntry{ + ID: req.GetId(), + Type: security.ResourceType_LaneRules, + }, false) + return api.NewAnyDataResponse(apimodel.Code_ExecuteSuccess, req) +} + +// UpdateLaneGroups 批量更新泳道组 +func (s *Server) UpdateLaneGroups(ctx context.Context, req []*apitraffic.LaneGroup) *apiservice.BatchWriteResponse { + responses := api.NewBatchWriteResponse(apimodel.Code_ExecuteSuccess) + for i := range req { + resp := s.UpdateLaneGroup(ctx, req[i]) + api.Collect(responses, resp) + } + return api.FormatBatchWriteResponse(responses) +} + +// UpdateLaneGroup 更新泳道组 +func (s *Server) UpdateLaneGroup(ctx context.Context, req *apitraffic.LaneGroup) *apiservice.Response { + tx, err := s.storage.StartTx() + if err != nil { + log.Error("[Service][Lane] open store transaction fail", utils.RequestID(ctx), zap.Error(err)) + return api.NewResponse(commonstore.StoreCode2APICode(err)) + } + defer func() { + _ = tx.Rollback() + }() + + saveData, err := s.storage.LockLaneGroup(tx, req.GetName()) + if err != nil { + log.Error("[Service][Lane] lock one lane_group", utils.RequestID(ctx), + zap.String("name", req.GetName()), zap.Error(err)) + return api.NewResponse(commonstore.StoreCode2APICode(err)) + } + if saveData == nil { + log.Error("[Service][Lane] lock one lane_group not found", utils.RequestID(ctx), + zap.String("name", req.GetName())) + return api.NewResponse(apimodel.Code_NotFoundResource) + } + + needUpdate, err := updateLaneGroupAttribute(req, saveData) + if err != nil { + log.Error("[Service][Lane] update lane_group transfer spec to model", utils.RequestID(ctx), zap.Error(err)) + return api.NewResponse(apimodel.Code_ExecuteException) + } + if !needUpdate { + return api.NewResponse(apimodel.Code_NoNeedUpdate) + } + + saveData.Revision = utils.DefaultString(req.GetRevision(), utils.NewUUID()) + if err := s.storage.UpdateLaneGroup(tx, saveData); err != nil { + log.Error("[Service][Lane] update lane_group", utils.RequestID(ctx), zap.String("name", saveData.Name), zap.Error(err)) + return api.NewResponse(commonstore.StoreCode2APICode(err)) + } + req.Id = saveData.ID + + if err := tx.Commit(); err != nil { + log.Error("[Service][Lane] commit store transaction fail", utils.RequestID(ctx), zap.Error(err)) + return api.NewResponse(commonstore.StoreCode2APICode(err)) + } + + s.RecordHistory(ctx, laneGroupRecordEntry(ctx, req, saveData, model.OUpdate)) + return api.NewAnyDataResponse(apimodel.Code_ExecuteSuccess, req) +} + +// DeleteLaneGroups 批量删除泳道组 +func (s *Server) DeleteLaneGroups(ctx context.Context, req []*apitraffic.LaneGroup) *apiservice.BatchWriteResponse { + responses := api.NewBatchWriteResponse(apimodel.Code_ExecuteSuccess) + for i := range req { + resp := s.DeleteLaneGroup(ctx, req[i]) + api.Collect(responses, resp) + } + return api.FormatBatchWriteResponse(responses) +} + +// DeleteLaneGroup 删除泳道组 +func (s *Server) DeleteLaneGroup(ctx context.Context, req *apitraffic.LaneGroup) *apiservice.Response { + var saveData *model.LaneGroup + var err error + if req.GetId() != "" { + saveData, err = s.storage.GetLaneGroupByID(req.GetId()) + } else { + saveData, err = s.storage.GetLaneGroup(req.GetName()) + } + if err != nil { + log.Error("[Server][LaneGroup] get target lane_group when delete", zap.String("id", req.GetId()), + zap.String("name", req.GetName()), utils.RequestID(ctx), zap.Error(err)) + return api.NewResponse(commonstore.StoreCode2APICode(err)) + } + if saveData == nil { + log.Info("[Server][LaneGroup] delete target lane_group but not found", zap.String("id", req.GetId()), + zap.String("name", req.GetName()), utils.RequestID(ctx)) + return api.NewResponse(apimodel.Code_ExecuteSuccess) + } + + saveData.Revision = utils.DefaultString(req.GetRevision(), utils.NewUUID()) + if err := s.storage.DeleteLaneGroup(saveData.ID); err != nil { + return api.NewResponse(commonstore.StoreCode2APICode(err)) + } + req.Id = saveData.ID + s.RecordHistory(ctx, laneGroupRecordEntry(ctx, req, saveData, model.ODelete)) + _ = s.afterRuleResource(ctx, model.RRouting, authcommon.ResourceEntry{ + ID: req.GetId(), + Type: security.ResourceType_LaneRules, + }, true) + return api.NewAnyDataResponse(apimodel.Code_ExecuteSuccess, req) +} + +// GetLaneGroups 查询泳道组列表 +func (s *Server) GetLaneGroups(ctx context.Context, filter map[string]string) *apiservice.BatchQueryResponse { + offset, limit, _ := utils.ParseOffsetAndLimit(filter) + total, ret, err := s.caches.LaneRule().Query(ctx, &cachetypes.LaneGroupArgs{ + Filter: filter, + Offset: offset, + Limit: limit, + }) + if err != nil { + log.Error("[Server][LaneGroup][Query] get lane_groups from store", utils.RequestID(ctx), zap.Error(err)) + return api.NewBatchQueryResponse(commonstore.StoreCode2APICode(err)) + } + + rsp := api.NewBatchQueryResponse(apimodel.Code_ExecuteSuccess) + rsp.Amount = wrapperspb.UInt32(total) + rsp.Size = wrapperspb.UInt32(uint32(len(ret))) + rsp.Data = make([]*anypb.Any, 0, len(ret)) + + for i := range ret { + data, err := ret[i].ToProto() + if err != nil { + log.Error("[Server][LaneGroup][Query] lane_group convert to proto", utils.RequestID(ctx), zap.Error(err)) + return api.NewBatchQueryResponse(apimodel.Code_ExecuteException) + } + anyData, err := anypb.New(proto.MessageV2(data.Proto)) + if err != nil { + log.Error("[Server][LaneGroup][Query] lane_group convert to anypb", utils.RequestID(ctx), zap.Error(err)) + return api.NewBatchQueryResponse(apimodel.Code_ExecuteException) + } + rsp.Data = append(rsp.Data, anyData) + } + return rsp +} + +// GetAllLaneGroups Query all router_rule rules +func (s *Server) GetAllLaneGroups(ctx context.Context) *apiservice.BatchQueryResponse { + return nil +} + +func updateLaneGroupAttribute(req *apitraffic.LaneGroup, saveData *model.LaneGroup) (bool, error) { + updateData := &model.LaneGroup{} + if err := updateData.FromSpec(req); err != nil { + return false, err + } + + saveData.Description = updateData.Description + saveData.Rule = updateData.Rule + + for ruleId := range updateData.LaneRules { + // 默认所有规则 enable 状态都出现了变更 + updateData.LaneRules[ruleId].SetChangeEnable(true) + updateData.LaneRules[ruleId].SetAddFlag(false) + } + + for ruleId := range updateData.LaneRules { + newRule := updateData.LaneRules[ruleId] + oldRule, ok := saveData.LaneRules[ruleId] + if !ok { + // 在原来的规则当中不存在,认为是新增的 + newRule.SetAddFlag(true) + continue + } + newRule.Revision = utils.DefaultString(newRule.Revision, utils.NewUUID()) + // 如果 Enable 字段比较发现没有变化,则设置为 nil + if oldRule.Enable == newRule.Enable { + newRule.SetChangeEnable(false) + } + } + saveData.LaneRules = updateData.LaneRules + return true, nil +} + +// laneGroupRecordEntry 转换为鉴权策略的记录结构体 +func laneGroupRecordEntry(ctx context.Context, req *apitraffic.LaneGroup, md *model.LaneGroup, + operationType model.OperationType) *model.RecordEntry { + + marshaler := jsonpb.Marshaler{} + detail, _ := marshaler.MarshalToString(req) + + entry := &model.RecordEntry{ + ResourceType: model.RLaneGroup, + ResourceName: fmt.Sprintf("%s(%s)", md.Name, md.ID), + OperationType: operationType, + Operator: utils.ParseOperator(ctx), + Detail: detail, + HappenTime: time.Now(), + } + return entry +} diff --git a/service/namespace_test.go b/service/namespace_test.go index 6451492ee..abd0b5999 100644 --- a/service/namespace_test.go +++ b/service/namespace_test.go @@ -158,7 +158,7 @@ func TestRemoveNamespace(t *testing.T) { } defer discoverSuit.cleanServiceName(serviceReq.GetName().GetValue(), serviceReq.GetNamespace().GetValue()) - resp := discoverSuit.NamespaceServer().DeleteNamespace(discoverSuit.DefaultCtx, namespaceResp) + resp := discoverSuit.NamespaceServer().DeleteNamespaces(discoverSuit.DefaultCtx, []*apimodel.Namespace{namespaceResp}) if resp.GetCode().GetValue() != uint32(apimodel.Code_NamespaceExistedServices) { t.Fatalf("error: %s", resp.GetInfo().GetValue()) } diff --git a/service/ratelimit_config.go b/service/ratelimit_config.go index be4bf0734..37004b7e8 100644 --- a/service/ratelimit_config.go +++ b/service/ratelimit_config.go @@ -25,14 +25,16 @@ import ( "time" "github.com/gogo/protobuf/jsonpb" - "github.com/golang/protobuf/ptypes" apimodel "github.com/polarismesh/specification/source/go/api/v1/model" + "github.com/polarismesh/specification/source/go/api/v1/security" apiservice "github.com/polarismesh/specification/source/go/api/v1/service_manage" apitraffic "github.com/polarismesh/specification/source/go/api/v1/traffic_manage" + "go.uber.org/zap" cachetypes "github.com/polarismesh/polaris/cache/api" api "github.com/polarismesh/polaris/common/api/v1" "github.com/polarismesh/polaris/common/model" + authcommon "github.com/polarismesh/polaris/common/model/auth" commonstore "github.com/polarismesh/polaris/common/store" commontime "github.com/polarismesh/polaris/common/time" "github.com/polarismesh/polaris/common/utils" @@ -56,10 +58,6 @@ var ( // CreateRateLimits 批量创建限流规则 func (s *Server) CreateRateLimits(ctx context.Context, request []*apitraffic.Rule) *apiservice.BatchWriteResponse { - if err := checkBatchRateLimits(request); err != nil { - return err - } - responses := api.NewBatchWriteResponse(apimodel.Code_ExecuteSuccess) for _, rateLimit := range request { response := s.CreateRateLimit(ctx, rateLimit) @@ -70,45 +68,34 @@ func (s *Server) CreateRateLimits(ctx context.Context, request []*apitraffic.Rul // CreateRateLimit 创建限流规则 func (s *Server) CreateRateLimit(ctx context.Context, req *apitraffic.Rule) *apiservice.Response { - requestID := utils.ParseRequestID(ctx) - - // 参数校验 - if resp := checkRateLimitParams(req); resp != nil { - return resp - } - if resp := checkRateLimitRuleParams(requestID, req); resp != nil { - return resp - } - // 构造底层数据结构 data, err := api2RateLimit(req, nil) if err != nil { - log.Error(err.Error(), utils.ZapRequestID(requestID)) + log.Error(err.Error(), utils.RequestID(ctx)) return api.NewRateLimitResponse(apimodel.Code_ParseRateLimitException, req) } // 存储层操作 if err := s.storage.CreateRateLimit(data); err != nil { - log.Error(err.Error(), utils.ZapRequestID(requestID)) + log.Error(err.Error(), utils.RequestID(ctx)) return wrapperRateLimitStoreResponse(req, err) } msg := fmt.Sprintf("create rate limit rule: id=%v, namespace=%v, service=%v, name=%v", data.ID, req.GetNamespace().GetValue(), req.GetService().GetValue(), req.GetName().GetValue()) - log.Info(msg, utils.ZapRequestID(requestID)) + log.Info(msg, utils.RequestID(ctx)) s.RecordHistory(ctx, rateLimitRecordEntry(ctx, req, data, model.OCreate)) - + _ = s.afterRuleResource(ctx, model.RRouting, authcommon.ResourceEntry{ + ID: req.GetId().GetValue(), + Type: security.ResourceType_RateLimitRules, + }, false) req.Id = utils.NewStringValue(data.ID) return api.NewRateLimitResponse(apimodel.Code_ExecuteSuccess, req) } // DeleteRateLimits 批量删除限流规则 func (s *Server) DeleteRateLimits(ctx context.Context, request []*apitraffic.Rule) *apiservice.BatchWriteResponse { - if err := checkBatchRateLimits(request); err != nil { - return err - } - responses := api.NewBatchWriteResponse(apimodel.Code_ExecuteSuccess) for _, entry := range request { resp := s.DeleteRateLimit(ctx, entry) @@ -119,16 +106,8 @@ func (s *Server) DeleteRateLimits(ctx context.Context, request []*apitraffic.Rul // DeleteRateLimit 删除单个限流规则 func (s *Server) DeleteRateLimit(ctx context.Context, req *apitraffic.Rule) *apiservice.Response { - requestID := utils.ParseRequestID(ctx) - platformID := utils.ParsePlatformID(ctx) - - // 参数校验 - if resp := checkRevisedRateLimitParams(req); resp != nil { - return resp - } - // 检查限流规则是否存在 - rateLimit, resp := s.checkRateLimitExisted(req.GetId().GetValue(), requestID, req) + rateLimit, resp := s.checkRateLimitExisted(ctx, req.GetId().GetValue(), req) if resp != nil { if resp.GetCode().GetValue() == uint32(apimodel.Code_NotFoundRateLimit) { return api.NewRateLimitResponse(apimodel.Code_ExecuteSuccess, req) @@ -141,23 +120,24 @@ func (s *Server) DeleteRateLimit(ctx context.Context, req *apitraffic.Rule) *api // 存储层操作 if err := s.storage.DeleteRateLimit(rateLimit); err != nil { - log.Error(err.Error(), utils.ZapRequestID(requestID), utils.ZapPlatformID(platformID)) + log.Error(err.Error(), utils.RequestID(ctx)) return wrapperRateLimitStoreResponse(req, err) } msg := fmt.Sprintf("delete rate limit rule: id=%v, namespace=%v, service=%v, name=%v", rateLimit.ID, req.GetNamespace().GetValue(), req.GetService().GetValue(), rateLimit.Labels) - log.Info(msg, utils.ZapRequestID(requestID), utils.ZapPlatformID(platformID)) + log.Info(msg, utils.RequestID(ctx)) s.RecordHistory(ctx, rateLimitRecordEntry(ctx, req, rateLimit, model.ODelete)) + _ = s.afterRuleResource(ctx, model.RRouting, authcommon.ResourceEntry{ + ID: req.GetId().GetValue(), + Type: security.ResourceType_RateLimitRules, + }, true) return api.NewRateLimitResponse(apimodel.Code_ExecuteSuccess, req) } func (s *Server) EnableRateLimits(ctx context.Context, request []*apitraffic.Rule) *apiservice.BatchWriteResponse { - if err := checkBatchRateLimits(request); err != nil { - return err - } responses := api.NewBatchWriteResponse(apimodel.Code_ExecuteSuccess) for _, entry := range request { response := s.EnableRateLimit(ctx, entry) @@ -168,16 +148,8 @@ func (s *Server) EnableRateLimits(ctx context.Context, request []*apitraffic.Rul // EnableRateLimit 启用限流规则 func (s *Server) EnableRateLimit(ctx context.Context, req *apitraffic.Rule) *apiservice.Response { - requestID := utils.ParseRequestID(ctx) - platformID := utils.ParsePlatformID(ctx) - - // 参数校验 - if resp := checkRevisedRateLimitParams(req); resp != nil { - return resp - } - // 检查限流规则是否存在 - data, resp := s.checkRateLimitExisted(req.GetId().GetValue(), requestID, req) + data, resp := s.checkRateLimitExisted(ctx, req.GetId().GetValue(), req) if resp != nil { return resp } @@ -190,13 +162,13 @@ func (s *Server) EnableRateLimit(ctx context.Context, req *apitraffic.Rule) *api rateLimit.Revision = utils.NewUUID() if err := s.storage.EnableRateLimit(rateLimit); err != nil { - log.Error(err.Error(), utils.ZapRequestID(requestID), utils.ZapPlatformID(platformID)) + log.Error(err.Error(), utils.RequestID(ctx)) return wrapperRateLimitStoreResponse(req, err) } msg := fmt.Sprintf("enable rate limit: id=%v, disable=%v", rateLimit.ID, rateLimit.Disable) - log.Info(msg, utils.ZapRequestID(requestID), utils.ZapPlatformID(platformID)) + log.Info(msg, utils.RequestID(ctx)) s.RecordHistory(ctx, rateLimitRecordEntry(ctx, req, rateLimit, model.OUpdateEnable)) return api.NewRateLimitResponse(apimodel.Code_ExecuteSuccess, req) @@ -204,10 +176,6 @@ func (s *Server) EnableRateLimit(ctx context.Context, req *apitraffic.Rule) *api // UpdateRateLimits 批量更新限流规则 func (s *Server) UpdateRateLimits(ctx context.Context, request []*apitraffic.Rule) *apiservice.BatchWriteResponse { - if err := checkBatchRateLimits(request); err != nil { - return err - } - responses := api.NewBatchWriteResponse(apimodel.Code_ExecuteSuccess) for _, entry := range request { response := s.UpdateRateLimit(ctx, entry) @@ -218,20 +186,8 @@ func (s *Server) UpdateRateLimits(ctx context.Context, request []*apitraffic.Rul // UpdateRateLimit 更新限流规则 func (s *Server) UpdateRateLimit(ctx context.Context, req *apitraffic.Rule) *apiservice.Response { - requestID := utils.ParseRequestID(ctx) - // 参数校验 - if resp := checkRevisedRateLimitParams(req); resp != nil { - return resp - } - if resp := checkRateLimitRuleParams(requestID, req); resp != nil { - return resp - } - if resp := checkRateLimitParamsDbLen(req); resp != nil { - return resp - } - // 检查限流规则是否存在 - data, resp := s.checkRateLimitExisted(req.GetId().GetValue(), requestID, req) + data, resp := s.checkRateLimitExisted(ctx, req.GetId().GetValue(), req) if resp != nil { return resp } @@ -239,18 +195,18 @@ func (s *Server) UpdateRateLimit(ctx context.Context, req *apitraffic.Rule) *api // 构造底层数据结构 rateLimit, err := api2RateLimit(req, data) if err != nil { - log.Error(err.Error(), utils.ZapRequestID(requestID)) + log.Error(err.Error(), utils.RequestID(ctx)) return api.NewRateLimitResponse(apimodel.Code_ParseRateLimitException, req) } rateLimit.ID = data.ID if err := s.storage.UpdateRateLimit(rateLimit); err != nil { - log.Error(err.Error(), utils.ZapRequestID(requestID)) + log.Error(err.Error(), utils.RequestID(ctx)) return wrapperRateLimitStoreResponse(req, err) } msg := fmt.Sprintf("update rate limit: id=%v, namespace=%v, service=%v, name=%v", rateLimit.ID, req.GetNamespace().GetValue(), req.GetService().GetValue(), rateLimit.Name) - log.Info(msg, utils.ZapRequestID(requestID)) + log.Info(msg, utils.RequestID(ctx)) s.RecordHistory(ctx, rateLimitRecordEntry(ctx, req, rateLimit, model.OUpdate)) return api.NewRateLimitResponse(apimodel.Code_ExecuteSuccess, req) @@ -266,7 +222,7 @@ func (s *Server) GetRateLimits(ctx context.Context, query map[string]string) *ap total, extendRateLimits, err := s.Cache().RateLimit().QueryRateLimitRules(ctx, *args) if err != nil { - log.Errorf("get rate limits store err: %s", err.Error()) + log.Error("get rate limits store", zap.Error(err), utils.RequestID(ctx)) return api.NewBatchQueryResponse(commonstore.StoreCode2APICode(err)) } @@ -277,7 +233,7 @@ func (s *Server) GetRateLimits(ctx context.Context, query map[string]string) *ap for _, item := range extendRateLimits { limit, err := rateLimit2Console(item) if err != nil { - log.Errorf("get rate limits convert err: %s", err.Error()) + log.Error("get rate limits convert", zap.Error(err), utils.RequestID(ctx)) return api.NewBatchQueryResponse(apimodel.Code_ParseRateLimitException) } out.RateLimits = append(out.RateLimits, limit) @@ -286,6 +242,11 @@ func (s *Server) GetRateLimits(ctx context.Context, query map[string]string) *ap return out } +// GetAllRateLimits Query all router_rule rules +func (s *Server) GetAllRateLimits(ctx context.Context) *apiservice.BatchQueryResponse { + return nil +} + func parseRateLimitArgs(query map[string]string) (*cachetypes.RateLimitRuleArgs, *apiservice.BatchQueryResponse) { for key := range query { if _, ok := RateLimitFilters[key]; !ok { @@ -318,19 +279,6 @@ func parseRateLimitArgs(query map[string]string) (*cachetypes.RateLimitRuleArgs, return args, nil } -// checkBatchRateLimits 检查批量请求的限流规则 -func checkBatchRateLimits(req []*apitraffic.Rule) *apiservice.BatchWriteResponse { - if len(req) == 0 { - return api.NewBatchWriteResponse(apimodel.Code_EmptyRequest) - } - - if len(req) > MaxBatchSize { - return api.NewBatchWriteResponse(apimodel.Code_BatchSizeOverLimit) - } - - return nil -} - // checkRateLimitValid 检查限流规则是否允许修改/删除 func (s *Server) checkRateLimitValid(ctx context.Context, serviceID string, req *apitraffic.Rule) ( *model.Service, *apiservice.Response) { @@ -345,74 +293,13 @@ func (s *Server) checkRateLimitValid(ctx context.Context, serviceID string, req return service, nil } -// checkRateLimitParams 检查限流规则基础参数 -func checkRateLimitParams(req *apitraffic.Rule) *apiservice.Response { - if req == nil { - return api.NewRateLimitResponse(apimodel.Code_EmptyRequest, req) - } - if err := utils.CheckResourceName(req.GetNamespace()); err != nil { - return api.NewRateLimitResponse(apimodel.Code_InvalidNamespaceName, req) - } - if err := utils.CheckResourceName(req.GetService()); err != nil { - return api.NewRateLimitResponse(apimodel.Code_InvalidServiceName, req) - } - if resp := checkRateLimitParamsDbLen(req); nil != resp { - return resp - } - return nil -} - -// checkRateLimitParams 检查限流规则基础参数 -func checkRateLimitParamsDbLen(req *apitraffic.Rule) *apiservice.Response { - if err := utils.CheckDbStrFieldLen(req.GetService(), MaxDbServiceNameLength); err != nil { - return api.NewRateLimitResponse(apimodel.Code_InvalidServiceName, req) - } - if err := utils.CheckDbStrFieldLen(req.GetNamespace(), MaxDbServiceNamespaceLength); err != nil { - return api.NewRateLimitResponse(apimodel.Code_InvalidNamespaceName, req) - } - if err := utils.CheckDbStrFieldLen(req.GetName(), MaxDbRateLimitName); err != nil { - return api.NewRateLimitResponse(apimodel.Code_InvalidRateLimitName, req) - } - return nil -} - -// checkRateLimitRuleParams 检查限流规则其他参数 -func checkRateLimitRuleParams(requestID string, req *apitraffic.Rule) *apiservice.Response { - // 检查amounts是否有重复周期 - amounts := req.GetAmounts() - durations := make(map[time.Duration]bool) - for _, amount := range amounts { - d := amount.GetValidDuration() - duration, err := ptypes.Duration(d) - if err != nil { - log.Error(err.Error(), utils.ZapRequestID(requestID)) - return api.NewRateLimitResponse(apimodel.Code_InvalidRateLimitAmounts, req) - } - durations[duration] = true - } - if len(amounts) != len(durations) { - return api.NewRateLimitResponse(apimodel.Code_InvalidRateLimitAmounts, req) - } - return nil -} - -// checkRevisedRateLimitParams 检查修改/删除限流规则基础参数 -func checkRevisedRateLimitParams(req *apitraffic.Rule) *apiservice.Response { - if req == nil { - return api.NewRateLimitResponse(apimodel.Code_EmptyRequest, req) - } - if req.GetId().GetValue() == "" { - return api.NewRateLimitResponse(apimodel.Code_InvalidRateLimitID, req) - } - return nil -} - // checkRateLimitExisted 检查限流规则是否存在 -func (s *Server) checkRateLimitExisted( - id, requestID string, req *apitraffic.Rule) (*model.RateLimit, *apiservice.Response) { +func (s *Server) checkRateLimitExisted(ctx context.Context, id string, + req *apitraffic.Rule) (*model.RateLimit, *apiservice.Response) { + rateLimit, err := s.storage.GetRateLimitWithID(id) if err != nil { - log.Error(err.Error(), utils.ZapRequestID(requestID)) + log.Error(err.Error(), utils.RequestID(ctx)) return nil, api.NewRateLimitResponse(commonstore.StoreCode2APICode(err), req) } if rateLimit == nil { @@ -447,6 +334,7 @@ func api2RateLimit(req *apitraffic.Rule, old *model.RateLimit) (*model.RateLimit Labels: string(labelStr), Rule: rule, Revision: utils.NewUUID(), + Metadata: req.Metadata, } return out, nil } @@ -457,6 +345,7 @@ func rateLimit2Console(rateLimit *model.RateLimit) (*apitraffic.Rule, error) { return nil, nil } if len(rateLimit.Rule) > 0 { + rateLimit = rateLimit.CopyNoProto() rateLimit.Proto = &apitraffic.Rule{} // 控制台查询的请求 if err := json.Unmarshal([]byte(rateLimit.Rule), rateLimit.Proto); err != nil { @@ -474,6 +363,7 @@ func rateLimit2Console(rateLimit *model.RateLimit) (*apitraffic.Rule, error) { rule.Ctime = utils.NewStringValue(commontime.Time2String(rateLimit.CreateTime)) rule.Mtime = utils.NewStringValue(commontime.Time2String(rateLimit.ModifyTime)) rule.Disable = utils.NewBoolValue(rateLimit.Disable) + rule.Metadata = rateLimit.Metadata if rateLimit.EnableTime.Year() > 2000 { rule.Etime = utils.NewStringValue(commontime.Time2String(rateLimit.EnableTime)) } else { @@ -528,6 +418,7 @@ func rateLimit2Client( rule.Priority = utils.NewUInt32Value(rateLimit.Priority) rule.Revision = utils.NewStringValue(rateLimit.Revision) rule.Disable = utils.NewBoolValue(rateLimit.Disable) + rule.Metadata = rateLimit.Metadata copyRateLimitProto(rateLimit, rule) return rule, nil } diff --git a/service/ratelimit_config_test.go b/service/ratelimit_config_test.go index 0e93d0653..a004c963d 100644 --- a/service/ratelimit_config_test.go +++ b/service/ratelimit_config_test.go @@ -331,7 +331,7 @@ func TestDeleteRateLimit(t *testing.T) { }() resp := discoverSuit.DiscoverServer().DeleteRateLimits(discoverSuit.DefaultCtx, []*apitraffic.Rule{rateLimitReq}) - assert.True(t, api.IsSuccess(resp), resp.GetInfo().GetValue()) + assert.False(t, api.IsSuccess(resp), resp.GetInfo().GetValue()) }) t.Run("并发删除限流规则,可以正常删除", func(t *testing.T) { diff --git a/service/routing_config_v1.go b/service/routing_config_v1.go index 6457f8dbd..d8a29fa2e 100644 --- a/service/routing_config_v1.go +++ b/service/routing_config_v1.go @@ -19,20 +19,12 @@ package service import ( "context" - "encoding/json" - "fmt" - "time" - "github.com/gogo/protobuf/jsonpb" apimodel "github.com/polarismesh/specification/source/go/api/v1/model" apiservice "github.com/polarismesh/specification/source/go/api/v1/service_manage" apitraffic "github.com/polarismesh/specification/source/go/api/v1/traffic_manage" api "github.com/polarismesh/polaris/common/api/v1" - "github.com/polarismesh/polaris/common/model" - commonstore "github.com/polarismesh/polaris/common/store" - commontime "github.com/polarismesh/polaris/common/time" - "github.com/polarismesh/polaris/common/utils" ) var ( @@ -47,15 +39,10 @@ var ( // CreateRoutingConfigs Create a routing configuration func (s *Server) CreateRoutingConfigs(ctx context.Context, req []*apitraffic.Routing) *apiservice.BatchWriteResponse { - if err := checkBatchRoutingConfig(req); err != nil { - return err - } - resp := api.NewBatchWriteResponse(apimodel.Code_ExecuteSuccess) for _, entry := range req { - api.Collect(resp, s.createRoutingConfigV1toV2(ctx, entry)) + api.Collect(resp, s.CreateRoutingConfig(ctx, entry)) } - return api.FormatBatchWriteResponse(resp) } @@ -63,92 +50,32 @@ func (s *Server) CreateRoutingConfigs(ctx context.Context, req []*apitraffic.Rou // services to prevent the service from being deleted // Deprecated: This method is ready to abandon func (s *Server) CreateRoutingConfig(ctx context.Context, req *apitraffic.Routing) *apiservice.Response { - rid := utils.ParseRequestID(ctx) - pid := utils.ParsePlatformID(ctx) - if resp := checkRoutingConfig(req); resp != nil { - return resp - } - - serviceName := req.GetService().GetValue() - namespaceName := req.GetNamespace().GetValue() - service, errResp := s.loadService(namespaceName, serviceName) - if errResp != nil { - log.Error(errResp.GetInfo().GetValue(), utils.ZapRequestID(rid), utils.ZapPlatformID(pid)) - return api.NewRoutingResponse(apimodel.Code(errResp.GetCode().GetValue()), req) - } - if service == nil { - return api.NewRoutingResponse(apimodel.Code_NotFoundService, req) - } - if service.IsAlias() { - return api.NewRoutingResponse(apimodel.Code_NotAllowAliasCreateRouting, req) - } - - routingConfig, err := s.storage.GetRoutingConfigWithService(service.Name, service.Namespace) - if err != nil { - log.Error(err.Error(), utils.ZapRequestID(rid), utils.ZapPlatformID(pid)) - return api.NewRoutingResponse(commonstore.StoreCode2APICode(err), req) - } - if routingConfig != nil { - return api.NewRoutingResponse(apimodel.Code_ExistedResource, req) - } - - conf, err := api2RoutingConfig(service.ID, req) - if err != nil { - log.Error(err.Error(), utils.ZapRequestID(rid), utils.ZapPlatformID(pid)) - return api.NewRoutingResponse(apimodel.Code_ExecuteException, req) - } - if err := s.storage.CreateRoutingConfig(conf); err != nil { - log.Error(err.Error(), utils.ZapRequestID(rid), utils.ZapPlatformID(pid)) - return wrapperRoutingStoreResponse(req, err) - } - - s.RecordHistory(ctx, routingRecordEntry(ctx, req, service, conf, model.OCreate)) - return api.NewRoutingResponse(apimodel.Code_ExecuteSuccess, req) + resps := api.NewResponseWithMsg(apimodel.Code_BadRequest, "API is Deprecated") + return resps } // DeleteRoutingConfigs Batch delete routing configuration func (s *Server) DeleteRoutingConfigs(ctx context.Context, req []*apitraffic.Routing) *apiservice.BatchWriteResponse { - if err := checkBatchRoutingConfig(req); err != nil { - return err - } - out := api.NewBatchWriteResponse(apimodel.Code_ExecuteSuccess) for _, entry := range req { resp := s.DeleteRoutingConfig(ctx, entry) api.Collect(out, resp) } - return api.FormatBatchWriteResponse(out) } // DeleteRoutingConfig Delete a routing configuration // Deprecated: This method is ready to abandon func (s *Server) DeleteRoutingConfig(ctx context.Context, req *apitraffic.Routing) *apiservice.Response { - rid := utils.ParseRequestID(ctx) - pid := utils.ParsePlatformID(ctx) - service, resp := s.routingConfigCommonCheck(ctx, req) - if resp != nil { - return resp - } - - if err := s.storage.DeleteRoutingConfig(service.ID); err != nil { - log.Error(err.Error(), utils.ZapRequestID(rid), utils.ZapPlatformID(pid)) - return wrapperRoutingStoreResponse(req, err) - } - - s.RecordHistory(ctx, routingRecordEntry(ctx, req, service, nil, model.ODelete)) - return api.NewResponse(apimodel.Code_ExecuteSuccess) + resps := api.NewResponseWithMsg(apimodel.Code_BadRequest, "API is Deprecated") + return resps } // UpdateRoutingConfigs Batch update routing configuration func (s *Server) UpdateRoutingConfigs(ctx context.Context, req []*apitraffic.Routing) *apiservice.BatchWriteResponse { - if err := checkBatchRoutingConfig(req); err != nil { - return err - } - out := api.NewBatchWriteResponse(apimodel.Code_ExecuteSuccess) for _, entry := range req { - resp := s.updateRoutingConfigV1toV2(ctx, entry) + resp := s.UpdateRoutingConfig(ctx, entry) api.Collect(out, resp) } @@ -158,267 +85,14 @@ func (s *Server) UpdateRoutingConfigs(ctx context.Context, req []*apitraffic.Rou // UpdateRoutingConfig Update a routing configuration // Deprecated: 该方法准备舍弃 func (s *Server) UpdateRoutingConfig(ctx context.Context, req *apitraffic.Routing) *apiservice.Response { - rid := utils.ParseRequestID(ctx) - pid := utils.ParsePlatformID(ctx) - service, resp := s.routingConfigCommonCheck(ctx, req) - if resp != nil { - return resp - } - - conf, err := s.storage.GetRoutingConfigWithService(service.Name, service.Namespace) - if err != nil { - log.Error(err.Error(), utils.ZapRequestID(rid), utils.ZapPlatformID(pid)) - return api.NewRoutingResponse(commonstore.StoreCode2APICode(err), req) - } - if conf == nil { - return api.NewRoutingResponse(apimodel.Code_NotFoundRouting, req) - } - - reqModel, err := api2RoutingConfig(service.ID, req) - if err != nil { - log.Error(err.Error(), utils.ZapRequestID(rid), utils.ZapPlatformID(pid)) - return api.NewRoutingResponse(apimodel.Code_ParseRoutingException, req) - } - - if err := s.storage.UpdateRoutingConfig(reqModel); err != nil { - log.Error(err.Error(), utils.ZapRequestID(rid), utils.ZapPlatformID(pid)) - return wrapperRoutingStoreResponse(req, err) - } - - s.RecordHistory(ctx, routingRecordEntry(ctx, req, service, reqModel, model.OUpdate)) - return api.NewRoutingResponse(apimodel.Code_ExecuteSuccess, req) + resps := api.NewResponseWithMsg(apimodel.Code_BadRequest, "API is Deprecated") + return resps } // GetRoutingConfigs Get the routing configuration in batches, and provide the interface of // the query routing configuration to the OSS // Deprecated: This method is ready to abandon func (s *Server) GetRoutingConfigs(ctx context.Context, query map[string]string) *apiservice.BatchQueryResponse { - rid := utils.ParseRequestID(ctx) - pid := utils.ParsePlatformID(ctx) - - offset, limit, err := utils.ParseOffsetAndLimit(query) - if err != nil { - return api.NewBatchQueryResponse(apimodel.Code_InvalidParameter) - } - - filter := make(map[string]string) - for key, value := range query { - if _, ok := RoutingConfigFilterAttrs[key]; !ok { - log.Errorf("[Server][RoutingConfig][Query] attribute(%s) is not allowed", key) - return api.NewBatchQueryResponse(apimodel.Code_InvalidParameter) - } - filter[key] = value - } - // service -- > name This special treatment - if service, ok := filter["service"]; ok { - filter["name"] = service - delete(filter, "service") - } - - // Can be filtered according to name and namespace - total, routings, err := s.storage.GetRoutingConfigs(filter, offset, limit) - if err != nil { - log.Error(err.Error(), utils.ZapRequestID(rid), utils.ZapPlatformID(pid)) - return api.NewBatchQueryResponse(commonstore.StoreCode2APICode(err)) - } - - resp := api.NewBatchQueryResponse(apimodel.Code_ExecuteSuccess) - resp.Amount = utils.NewUInt32Value(total) - resp.Size = utils.NewUInt32Value(uint32(len(routings))) - resp.Routings = make([]*apitraffic.Routing, 0, len(routings)) - for _, entry := range routings { - routing, err := routingConfig2API(entry.Config, entry.ServiceName, entry.NamespaceName) - if err != nil { - log.Error(err.Error(), utils.ZapRequestID(rid), utils.ZapPlatformID(pid)) - return api.NewBatchQueryResponse(apimodel.Code_ParseRoutingException) - } - resp.Routings = append(resp.Routings, routing) - } - - return resp -} - -// routingConfigCommonCheck Public examination of routing configuration operation -func (s *Server) routingConfigCommonCheck( - ctx context.Context, req *apitraffic.Routing) (*model.Service, *apiservice.Response) { - if resp := checkRoutingConfig(req); resp != nil { - return nil, resp - } - - rid := utils.ParseRequestID(ctx) - pid := utils.ParsePlatformID(ctx) - serviceName := req.GetService().GetValue() - namespaceName := req.GetNamespace().GetValue() - - service, err := s.storage.GetService(serviceName, namespaceName) - if err != nil { - log.Error(err.Error(), utils.ZapRequestID(rid), utils.ZapPlatformID(pid)) - return nil, api.NewRoutingResponse(commonstore.StoreCode2APICode(err), req) - } - if service == nil { - return nil, api.NewRoutingResponse(apimodel.Code_NotFoundService, req) - } - - return service, nil -} - -// checkRoutingConfig Check the validity of the basic parameter of the routing configuration -func checkRoutingConfig(req *apitraffic.Routing) *apiservice.Response { - if req == nil { - return api.NewRoutingResponse(apimodel.Code_EmptyRequest, req) - } - if err := utils.CheckResourceName(req.GetService()); err != nil { - return api.NewRoutingResponse(apimodel.Code_InvalidServiceName, req) - } - - if err := utils.CheckResourceName(req.GetNamespace()); err != nil { - return api.NewRoutingResponse(apimodel.Code_InvalidNamespaceName, req) - } - - if err := utils.CheckDbStrFieldLen(req.GetService(), MaxDbServiceNameLength); err != nil { - return api.NewRoutingResponse(apimodel.Code_InvalidServiceName, req) - } - if err := utils.CheckDbStrFieldLen(req.GetNamespace(), MaxDbServiceNamespaceLength); err != nil { - return api.NewRoutingResponse(apimodel.Code_InvalidNamespaceName, req) - } - if err := utils.CheckDbStrFieldLen(req.GetServiceToken(), MaxDbServiceToken); err != nil { - return api.NewRoutingResponse(apimodel.Code_InvalidServiceToken, req) - } - - return nil -} - -// parseServiceRoutingToken Get token from RoutingConfig request parameters -func parseServiceRoutingToken(ctx context.Context, req *apitraffic.Routing) string { - if reqToken := req.GetServiceToken().GetValue(); reqToken != "" { - return reqToken - } - - return utils.ParseToken(ctx) -} - -// api2RoutingConfig Convert the API parameter to internal data structure -func api2RoutingConfig(serviceID string, req *apitraffic.Routing) (*model.RoutingConfig, error) { - inBounds, outBounds, err := marshalRoutingConfig(req.GetInbounds(), req.GetOutbounds()) - if err != nil { - return nil, err - } - - out := &model.RoutingConfig{ - ID: serviceID, - InBounds: string(inBounds), - OutBounds: string(outBounds), - Revision: utils.NewUUID(), - } - - return out, nil -} - -// routingConfig2API Convert the internal data structure to API parameter to pass out -func routingConfig2API(req *model.RoutingConfig, service string, namespace string) (*apitraffic.Routing, error) { - if req == nil { - return nil, nil - } - - out := &apitraffic.Routing{ - Service: utils.NewStringValue(service), - Namespace: utils.NewStringValue(namespace), - Revision: utils.NewStringValue(req.Revision), - Ctime: utils.NewStringValue(commontime.Time2String(req.CreateTime)), - Mtime: utils.NewStringValue(commontime.Time2String(req.ModifyTime)), - } - - if req.InBounds != "" { - var inBounds []*apitraffic.Route - if err := json.Unmarshal([]byte(req.InBounds), &inBounds); err != nil { - return nil, err - } - out.Inbounds = inBounds - } - if req.OutBounds != "" { - var outBounds []*apitraffic.Route - if err := json.Unmarshal([]byte(req.OutBounds), &outBounds); err != nil { - return nil, err - } - out.Outbounds = outBounds - } - - return out, nil -} - -// marshalRoutingConfig Formulate Inbounds and OUTBOUNDS -func marshalRoutingConfig(in []*apitraffic.Route, out []*apitraffic.Route) ([]byte, []byte, error) { - inBounds, err := json.Marshal(in) - if err != nil { - return nil, nil, err - } - - outBounds, err := json.Marshal(out) - if err != nil { - return nil, nil, err - } - - return inBounds, outBounds, nil -} - -// checkBatchRoutingConfig Check batch request -func checkBatchRoutingConfig(req []*apitraffic.Routing) *apiservice.BatchWriteResponse { - if len(req) == 0 { - return api.NewBatchWriteResponse(apimodel.Code_EmptyRequest) - } - - if len(req) > MaxBatchSize { - return api.NewBatchWriteResponse(apimodel.Code_BatchSizeOverLimit) - } - - return nil -} - -// routingRecordEntry Construction of RoutingConfig's record Entry -func routingRecordEntry(ctx context.Context, req *apitraffic.Routing, svc *model.Service, md *model.RoutingConfig, - opt model.OperationType) *model.RecordEntry { - - marshaler := jsonpb.Marshaler{} - detail, _ := marshaler.MarshalToString(req) - - entry := &model.RecordEntry{ - ResourceType: model.RRouting, - ResourceName: fmt.Sprintf("%s(%s)", svc.Name, svc.ID), - Namespace: svc.Namespace, - OperationType: opt, - Operator: utils.ParseOperator(ctx), - Detail: detail, - HappenTime: time.Now(), - } - - return entry -} - -// routingV2RecordEntry Construction of RoutingConfig's record Entry -func routingV2RecordEntry(ctx context.Context, req *apitraffic.RouteRule, md *model.RouterConfig, - opt model.OperationType) *model.RecordEntry { - - marshaler := jsonpb.Marshaler{} - detail, _ := marshaler.MarshalToString(req) - - entry := &model.RecordEntry{ - ResourceType: model.RRouting, - ResourceName: fmt.Sprintf("%s(%s)", md.Name, md.ID), - Namespace: req.GetNamespace(), - OperationType: opt, - Operator: utils.ParseOperator(ctx), - Detail: detail, - HappenTime: time.Now(), - } - return entry -} - -// wrapperRoutingStoreResponse Packing routing storage layer error -func wrapperRoutingStoreResponse(routing *apitraffic.Routing, err error) *apiservice.Response { - if err == nil { - return nil - } - resp := api.NewResponseWithMsg(commonstore.StoreCode2APICode(err), err.Error()) - resp.Routing = routing - return resp + resps := api.NewBatchQueryResponseWithMsg(apimodel.Code_BadRequest, "API is Deprecated") + return resps } diff --git a/service/routing_config_v1_test.go b/service/routing_config_v1_test.go index a38857670..2595a0022 100644 --- a/service/routing_config_v1_test.go +++ b/service/routing_config_v1_test.go @@ -24,7 +24,6 @@ import ( "testing" "github.com/golang/protobuf/ptypes" - "github.com/polarismesh/specification/source/go/api/v1/model" apimodel "github.com/polarismesh/specification/source/go/api/v1/model" apiservice "github.com/polarismesh/specification/source/go/api/v1/service_manage" apitraffic "github.com/polarismesh/specification/source/go/api/v1/traffic_manage" @@ -92,161 +91,8 @@ func checkSameRoutingConfig(t *testing.T, lhs *apitraffic.Routing, rhs *apitraff checkFunc("Outbounds", lhs.Outbounds, rhs.Outbounds) } -// 测试创建路由配置 -func TestCreateRoutingConfig(t *testing.T) { - t.Run("正常创建路由配置配置请求", func(t *testing.T) { - discoverSuit := &DiscoverTestSuit{} - if err := discoverSuit.Initialize(); err != nil { - t.Fatal(err) - } - - defer discoverSuit.Destroy() - _, serviceResp := discoverSuit.createCommonService(t, 200) - defer discoverSuit.cleanServiceName(serviceResp.GetName().GetValue(), serviceResp.GetNamespace().GetValue()) - _, _ = discoverSuit.createCommonRoutingConfig(t, serviceResp, 3, 0) - - // 对写进去的数据进行查询 - _ = discoverSuit.CacheMgr().TestUpdate() - out := discoverSuit.DiscoverServer().GetRoutingConfigWithCache(discoverSuit.DefaultCtx, serviceResp) - defer discoverSuit.cleanCommonRoutingConfig(serviceResp.GetName().GetValue(), serviceResp.GetNamespace().GetValue()) - if !respSuccess(out) { - t.Fatalf("error: %+v", out) - } - }) - - t.Run("参数缺失,报错", func(t *testing.T) { - discoverSuit := &DiscoverTestSuit{} - if err := discoverSuit.Initialize(); err != nil { - t.Fatal(err) - } - defer discoverSuit.Destroy() - - _, serviceResp := discoverSuit.createCommonService(t, 20) - defer discoverSuit.cleanServiceName(serviceResp.GetName().GetValue(), serviceResp.GetNamespace().GetValue()) - - req := &apitraffic.Routing{} - resp := discoverSuit.DiscoverServer().CreateRoutingConfigs(discoverSuit.DefaultCtx, []*apitraffic.Routing{req}) - assert.False(t, respSuccess(resp)) - t.Logf("%s", resp.GetInfo().GetValue()) - - req.Service = serviceResp.Name - resp = discoverSuit.DiscoverServer().CreateRoutingConfigs(discoverSuit.DefaultCtx, []*apitraffic.Routing{req}) - assert.False(t, respSuccess(resp)) - t.Logf("%s", resp.GetInfo().GetValue()) - - req.Namespace = serviceResp.Namespace - resp = discoverSuit.DiscoverServer().CreateRoutingConfigs(discoverSuit.DefaultCtx, []*apitraffic.Routing{req}) - defer discoverSuit.cleanCommonRoutingConfig(req.GetService().GetValue(), req.GetNamespace().GetValue()) - assert.True(t, respSuccess(resp)) - t.Logf("%s", resp.GetInfo().GetValue()) - }) - - t.Run("服务不存在,创建路由配置不报错", func(t *testing.T) { - discoverSuit := &DiscoverTestSuit{} - if err := discoverSuit.Initialize(); err != nil { - t.Fatal(err) - } - defer discoverSuit.Destroy() - - _, serviceResp := discoverSuit.createCommonService(t, 120) - discoverSuit.cleanServiceName(serviceResp.GetName().GetValue(), serviceResp.GetNamespace().GetValue()) - - _ = discoverSuit.CacheMgr().TestUpdate() - req := &apitraffic.Routing{} - req.Service = serviceResp.Name - req.Namespace = serviceResp.Namespace - req.ServiceToken = serviceResp.Token - resp := discoverSuit.DiscoverServer().CreateRoutingConfigs(discoverSuit.DefaultCtx, []*apitraffic.Routing{req}) - assert.False(t, respSuccess(resp)) - t.Logf("%s", resp.GetInfo().GetValue()) - }) -} - -// 测试创建路由配置 -func TestUpdateRoutingConfig(t *testing.T) { - t.Run("更新V1路由规则, 成功转为V2规则", func(t *testing.T) { - discoverSuit := &DiscoverTestSuit{} - if err := discoverSuit.Initialize(); err != nil { - t.Fatal(err) - } - - _, svc := discoverSuit.createCommonService(t, 200) - v1Rule, _ := discoverSuit.createCommonRoutingConfigV1IntoOldStore(t, svc, 3, 0) - t.Cleanup(func() { - discoverSuit.cleanServiceName(svc.GetName().GetValue(), svc.GetNamespace().GetValue()) - discoverSuit.cleanCommonRoutingConfig(svc.GetName().GetValue(), svc.GetNamespace().GetValue()) - discoverSuit.truncateCommonRoutingConfigV2() - discoverSuit.Destroy() - }) - - v1Rule.Outbounds = v1Rule.Inbounds - uResp := discoverSuit.DiscoverServer().UpdateRoutingConfigs(discoverSuit.DefaultCtx, []*apitraffic.Routing{v1Rule}) - assert.True(t, respSuccess(uResp)) - - // 等缓存层更新 - _ = discoverSuit.CacheMgr().TestUpdate() - - // 直接查询存储无法查询到 v1 的路由规则 - total, routingsV1, err := discoverSuit.Storage.GetRoutingConfigs(map[string]string{}, 0, 100) - assert.NoError(t, err, err) - assert.Equal(t, uint32(0), total, "v1 routing must delete and transfer to v1") - assert.Equal(t, 0, len(routingsV1), "v1 routing ret len need zero") - - // 从缓存中查询应该查到 6 条 v2 的路由规则 - out := discoverSuit.DiscoverServer().QueryRoutingConfigsV2(discoverSuit.DefaultCtx, map[string]string{ - "limit": "100", - "offset": "0", - }) - if !respSuccess(out) { - t.Fatalf("error: %+v", out) - } - assert.Equal(t, int(6), int(out.GetAmount().GetValue()), "query routing size") - rulesV2, err := unmarshalRoutingV2toAnySlice(out.GetData()) - assert.NoError(t, err) - for i := range rulesV2 { - item := rulesV2[i] - assert.True(t, item.Enable, "v1 to v2 need default open enable") - msg := &apitraffic.RuleRoutingConfig{} - err := ptypes.UnmarshalAny(item.GetRoutingConfig(), msg) - assert.NoError(t, err) - assert.True(t, len(msg.GetSources()) == 0, "RuleRoutingConfig.Sources len != 0") - assert.True(t, len(msg.GetDestinations()) == 0, "RuleRoutingConfig.Destinations len != 0") - assert.True(t, len(msg.GetRules()) != 0, "RuleRoutingConfig.Rules len == 0") - } - }) -} - // 测试缓存获取路由配置 func TestGetRoutingConfigWithCache(t *testing.T) { - - t.Run("多个服务的,多个路由配置,都可以查询到", func(t *testing.T) { - discoverSuit := &DiscoverTestSuit{} - if err := discoverSuit.Initialize(); err != nil { - t.Fatal(err) - } - defer discoverSuit.Destroy() - - total := 20 - serviceResps := make([]*apiservice.Service, 0, total) - routingResps := make([]*apitraffic.Routing, 0, total) - for i := 0; i < total; i++ { - _, resp := discoverSuit.createCommonService(t, i) - defer discoverSuit.cleanServiceName(resp.GetName().GetValue(), resp.GetNamespace().GetValue()) - serviceResps = append(serviceResps, resp) - - _, routingResp := discoverSuit.createCommonRoutingConfig(t, resp, 2, 0) - defer discoverSuit.cleanCommonRoutingConfig(resp.GetName().GetValue(), resp.GetNamespace().GetValue()) - routingResps = append(routingResps, routingResp) - } - - _ = discoverSuit.CacheMgr().TestUpdate() - for i := 0; i < total; i++ { - t.Logf("service : name=%s namespace=%s", serviceResps[i].GetName().GetValue(), serviceResps[i].GetNamespace().GetValue()) - out := discoverSuit.DiscoverServer().GetRoutingConfigWithCache(discoverSuit.DefaultCtx, serviceResps[i]) - checkSameRoutingConfig(t, routingResps[i], out.GetRouting()) - } - }) - t.Run("走v2接口创建路由规则,不启用查不到,启用可以查到", func(t *testing.T) { discoverSuit := &DiscoverTestSuit{} if err := discoverSuit.Initialize(); err != nil { @@ -402,31 +248,6 @@ func TestGetRoutingConfigWithCache(t *testing.T) { assert.True(t, len(out.GetRouting().GetOutbounds()) == 0, "inBounds must be zero") }) - - t.Run("服务路由数据不改变,传递了路由revision,不返回数据", func(t *testing.T) { - discoverSuit := &DiscoverTestSuit{} - if err := discoverSuit.Initialize(); err != nil { - t.Fatal(err) - } - defer discoverSuit.Destroy() - - _, serviceResp := discoverSuit.createCommonService(t, 10) - defer discoverSuit.cleanServiceName(serviceResp.GetName().GetValue(), serviceResp.GetNamespace().GetValue()) - - _, routingResp := discoverSuit.createCommonRoutingConfig(t, serviceResp, 2, 0) - defer discoverSuit.cleanCommonRoutingConfig(serviceResp.GetName().GetValue(), serviceResp.GetNamespace().GetValue()) - - _ = discoverSuit.CacheMgr().TestUpdate() - firstResp := discoverSuit.DiscoverServer().GetRoutingConfigWithCache(discoverSuit.DefaultCtx, serviceResp) - checkSameRoutingConfig(t, routingResp, firstResp.GetRouting()) - - serviceResp.Revision = firstResp.Service.Revision - secondResp := discoverSuit.DiscoverServer().GetRoutingConfigWithCache(discoverSuit.DefaultCtx, serviceResp) - if secondResp.GetService().GetRevision().GetValue() != serviceResp.GetRevision().GetValue() { - t.Fatalf("error") - } - assert.Equal(t, model.Code(secondResp.GetCode().GetValue()), apimodel.Code_DataNoChange) - }) t.Run("路由不存在,不会出异常", func(t *testing.T) { discoverSuit := &DiscoverTestSuit{} if err := discoverSuit.Initialize(); err != nil { @@ -443,8 +264,8 @@ func TestGetRoutingConfigWithCache(t *testing.T) { }) } -// test对routing字段进行校验 -func TestCheckRoutingFieldLen(t *testing.T) { +// Test_RouteRule_V1_Server +func Test_RouteRule_V1_Server(t *testing.T) { discoverSuit := &DiscoverTestSuit{} if err := discoverSuit.Initialize(); err != nil { @@ -452,40 +273,29 @@ func TestCheckRoutingFieldLen(t *testing.T) { } defer discoverSuit.Destroy() - req := &apitraffic.Routing{ - ServiceToken: utils.NewStringValue("test"), - Service: utils.NewStringValue("test"), - Namespace: utils.NewStringValue("default"), - } + t.Run("Create", func(t *testing.T) { + rsp := discoverSuit.DiscoverServer().CreateRoutingConfigs(discoverSuit.DefaultCtx, []*apitraffic.Routing{ + &apitraffic.Routing{}, + }) + assert.False(t, api.IsSuccess(rsp), rsp.GetInfo().GetValue()) + }) - t.Run("创建路由规则,服务名超长", func(t *testing.T) { - str := genSpecialStr(129) - oldName := req.Service - req.Service = utils.NewStringValue(str) - resp := discoverSuit.DiscoverServer().CreateRoutingConfigs(discoverSuit.DefaultCtx, []*apitraffic.Routing{req}) - req.Service = oldName - if resp.Code.Value != api.InvalidServiceName { - t.Fatalf("%+v", resp) - } + t.Run("Update", func(t *testing.T) { + rsp := discoverSuit.DiscoverServer().UpdateRoutingConfigs(discoverSuit.DefaultCtx, []*apitraffic.Routing{ + &apitraffic.Routing{}, + }) + assert.False(t, api.IsSuccess(rsp), rsp.GetInfo().GetValue()) }) - t.Run("创建路由规则,命名空间超长", func(t *testing.T) { - str := genSpecialStr(129) - oldNamespace := req.Namespace - req.Namespace = utils.NewStringValue(str) - resp := discoverSuit.DiscoverServer().CreateRoutingConfigs(discoverSuit.DefaultCtx, []*apitraffic.Routing{req}) - req.Namespace = oldNamespace - if resp.Code.Value != api.InvalidNamespaceName { - t.Fatalf("%+v", resp) - } + + t.Run("Delete", func(t *testing.T) { + rsp := discoverSuit.DiscoverServer().UpdateRoutingConfigs(discoverSuit.DefaultCtx, []*apitraffic.Routing{ + &apitraffic.Routing{}, + }) + assert.False(t, api.IsSuccess(rsp), rsp.GetInfo().GetValue()) }) - t.Run("创建路由规则,toeken超长", func(t *testing.T) { - str := genSpecialStr(2049) - oldServiceToken := req.ServiceToken - req.ServiceToken = utils.NewStringValue(str) - resp := discoverSuit.DiscoverServer().CreateRoutingConfigs(discoverSuit.DefaultCtx, []*apitraffic.Routing{req}) - req.ServiceToken = oldServiceToken - if resp.Code.Value != api.InvalidServiceToken { - t.Fatalf("%+v", resp) - } + + t.Run("Get", func(t *testing.T) { + rsp := discoverSuit.DiscoverServer().GetRoutingConfigs(discoverSuit.DefaultCtx, map[string]string{}) + assert.False(t, api.IsSuccess(rsp), rsp.GetInfo().GetValue()) }) } diff --git a/service/routing_config_v1tov2.go b/service/routing_config_v1tov2.go deleted file mode 100644 index bebcfe146..000000000 --- a/service/routing_config_v1tov2.go +++ /dev/null @@ -1,208 +0,0 @@ -/** - * Tencent is pleased to support the open source community by making Polaris available. - * - * Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. - * - * Licensed under the BSD 3-Clause License (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://opensource.org/licenses/BSD-3-Clause - * - * Unless required by applicable law or agreed to in writing, software distributed - * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR - * CONDITIONS OF ANY KIND, either express or implied. See the License for the - * specific language governing permissions and limitations under the License. - */ - -package service - -import ( - "context" - - apimodel "github.com/polarismesh/specification/source/go/api/v1/model" - apiservice "github.com/polarismesh/specification/source/go/api/v1/service_manage" - apitraffic "github.com/polarismesh/specification/source/go/api/v1/traffic_manage" - "go.uber.org/zap" - - apiv1 "github.com/polarismesh/polaris/common/api/v1" - "github.com/polarismesh/polaris/common/model" - commonstore "github.com/polarismesh/polaris/common/store" - "github.com/polarismesh/polaris/common/utils" -) - -// createRoutingConfigV1toV2 Compatible with V1 version of the creation routing rules, convert V1 to V2 for storage -func (s *Server) createRoutingConfigV1toV2(ctx context.Context, req *apitraffic.Routing) *apiservice.Response { - if resp := checkRoutingConfig(req); resp != nil { - return resp - } - - serviceName := req.GetService().GetValue() - namespaceName := req.GetNamespace().GetValue() - svc, errResp := s.loadService(namespaceName, serviceName) - if errResp != nil { - log.Error("[Service][Routing] get read lock for service", zap.String("service", serviceName), - zap.String("namespace", namespaceName), utils.RequestID(ctx), zap.Any("err", errResp)) - return apiv1.NewRoutingResponse(apimodel.Code(errResp.GetCode().GetValue()), req) - } - if svc == nil { - return apiv1.NewRoutingResponse(apimodel.Code_NotFoundService, req) - } - if svc.IsAlias() { - return apiv1.NewRoutingResponse(apimodel.Code_NotAllowAliasCreateRouting, req) - } - - inDatas, outDatas, resp := batchBuildV2Routings(req) - if resp != nil { - return resp - } - - resp = s.saveRoutingV1toV2(ctx, svc.ID, inDatas, outDatas) - if resp.GetCode().GetValue() != uint32(apimodel.Code_ExecuteSuccess) { - return resp - } - - return apiv1.NewRoutingResponse(apimodel.Code_ExecuteSuccess, req) -} - -// updateRoutingConfigV1toV2 Compatible with V1 version update routing rules, convert the data of V1 to V2 for storage -// Once the V1 rule is converted to V2 rules, the original V1 rules will be removed from storage -func (s *Server) updateRoutingConfigV1toV2(ctx context.Context, req *apitraffic.Routing) *apiservice.Response { - svc, resp := s.routingConfigCommonCheck(ctx, req) - if resp != nil { - return resp - } - - serviceTx, err := s.storage.CreateTransaction() - if err != nil { - log.Error(err.Error(), utils.RequestID(ctx)) - return apiv1.NewRoutingResponse(commonstore.StoreCode2APICode(err), req) - } - // Release the lock for the service - defer func() { - _ = serviceTx.Commit() - }() - - // Need to prohibit the concurrent modification of the V1 rules - if _, err = serviceTx.LockService(svc.Name, svc.Namespace); err != nil { - log.Error("[Service][Routing] get service x-lock", zap.String("service", svc.Name), - zap.String("namespace", svc.Namespace), utils.RequestID(ctx), zap.Error(err)) - return apiv1.NewRoutingResponse(commonstore.StoreCode2APICode(err), req) - } - - conf, err := s.storage.GetRoutingConfigWithService(svc.Name, svc.Namespace) - if err != nil { - log.Error(err.Error(), utils.RequestID(ctx)) - return apiv1.NewRoutingResponse(commonstore.StoreCode2APICode(err), req) - } - if conf == nil { - return apiv1.NewRoutingResponse(apimodel.Code_NotFoundRouting, req) - } - - inDatas, outDatas, resp := batchBuildV2Routings(req) - if resp != nil { - return resp - } - - if resp := s.saveRoutingV1toV2(ctx, svc.ID, inDatas, outDatas); resp.GetCode().GetValue() != uint32( - apimodel.Code_ExecuteSuccess) { - return resp - } - - return apiv1.NewRoutingResponse(apimodel.Code_ExecuteSuccess, req) -} - -// saveRoutingV1toV2 Convert the V1 rules of the target to V2 rule -func (s *Server) saveRoutingV1toV2(ctx context.Context, svcId string, - inRules, outRules []*apitraffic.RouteRule) *apiservice.Response { - tx, err := s.storage.StartTx() - if err != nil { - log.Error("[Service][Routing] create routing v2 from v1 open tx", - utils.RequestID(ctx), zap.Error(err)) - return apiv1.NewResponse(commonstore.StoreCode2APICode(err)) - } - defer func() { - _ = tx.Rollback() - }() - - // Need to delete the routing rules of V1 first - if err := s.storage.DeleteRoutingConfigTx(tx, svcId); err != nil { - log.Error("[Service][Routing] clean routing v1 from store", - utils.RequestID(ctx), zap.Error(err)) - return apiv1.NewResponse(commonstore.StoreCode2APICode(err)) - } - - saveOperation := func(routings []*apitraffic.RouteRule) *apiservice.Response { - priorityMax := 0 - for i := range routings { - item := routings[i] - if item.Id == "" { - item.Id = utils.NewRoutingV2UUID() - } - item.Revision = utils.NewV2Revision() - data := &model.RouterConfig{} - if err := data.ParseRouteRuleFromAPI(item); err != nil { - return apiv1.NewResponse(apimodel.Code_ExecuteException) - } - - data.Valid = true - data.Enable = true - if priorityMax > 10 { - priorityMax = 10 - } - - data.Priority = uint32(priorityMax) - priorityMax++ - - if err := s.storage.CreateRoutingConfigV2Tx(tx, data); err != nil { - log.Error("[Routing][V2] create routing v2 from v1 into store", - utils.RequestID(ctx), zap.Error(err)) - return apiv1.NewResponse(commonstore.StoreCode2APICode(err)) - } - s.RecordHistory(ctx, routingV2RecordEntry(ctx, item, data, model.OCreate)) - } - - return nil - } - - if resp := saveOperation(inRules); resp != nil { - return resp - } - if resp := saveOperation(outRules); resp != nil { - return resp - } - - if err := tx.Commit(); err != nil { - log.Error("[Service][Routing] create routing v2 from v1 commit", - utils.RequestID(ctx), zap.Error(err)) - return apiv1.NewResponse(apimodel.Code_ExecuteException) - } - - return apiv1.NewResponse(apimodel.Code_ExecuteSuccess) -} - -func batchBuildV2Routings( - req *apitraffic.Routing) ([]*apitraffic.RouteRule, []*apitraffic.RouteRule, *apiservice.Response) { - inBounds := req.GetInbounds() - outBounds := req.GetOutbounds() - inRoutings := make([]*apitraffic.RouteRule, 0, len(inBounds)) - outRoutings := make([]*apitraffic.RouteRule, 0, len(outBounds)) - for i := range inBounds { - routing, err := model.BuildV2RoutingFromV1Route(req, inBounds[i]) - if err != nil { - return nil, nil, apiv1.NewResponse(apimodel.Code_ExecuteException) - } - routing.Name = req.GetNamespace().GetValue() + "." + req.GetService().GetValue() - inRoutings = append(inRoutings, routing) - } - - for i := range outBounds { - routing, err := model.BuildV2RoutingFromV1Route(req, outBounds[i]) - if err != nil { - return nil, nil, apiv1.NewResponse(apimodel.Code_ExecuteException) - } - outRoutings = append(outRoutings, routing) - } - - return inRoutings, outRoutings, nil -} diff --git a/service/routing_config_v2.go b/service/routing_config_v2.go index 7e8040a1d..687354855 100644 --- a/service/routing_config_v2.go +++ b/service/routing_config_v2.go @@ -19,12 +19,16 @@ package service import ( "context" + "fmt" "strconv" + "time" + "github.com/golang/protobuf/jsonpb" "github.com/golang/protobuf/ptypes" "github.com/golang/protobuf/ptypes/any" "github.com/golang/protobuf/ptypes/wrappers" apimodel "github.com/polarismesh/specification/source/go/api/v1/model" + "github.com/polarismesh/specification/source/go/api/v1/security" apiservice "github.com/polarismesh/specification/source/go/api/v1/service_manage" apitraffic "github.com/polarismesh/specification/source/go/api/v1/traffic_manage" "go.uber.org/zap" @@ -32,36 +36,14 @@ import ( cachetypes "github.com/polarismesh/polaris/cache/api" apiv1 "github.com/polarismesh/polaris/common/api/v1" "github.com/polarismesh/polaris/common/model" + authcommon "github.com/polarismesh/polaris/common/model/auth" commonstore "github.com/polarismesh/polaris/common/store" "github.com/polarismesh/polaris/common/utils" ) -var ( - // RoutingConfigV2FilterAttrs router config filter attrs - RoutingConfigV2FilterAttrs = map[string]bool{ - "id": true, - "name": true, - "service": true, - "namespace": true, - "source_service": true, - "destination_service": true, - "source_namespace": true, - "destination_namespace": true, - "enable": true, - "offset": true, - "limit": true, - "order_field": true, - "order_type": true, - } -) - // CreateRoutingConfigsV2 Create a routing configuration func (s *Server) CreateRoutingConfigsV2( ctx context.Context, req []*apitraffic.RouteRule) *apiservice.BatchWriteResponse { - if err := checkBatchRoutingConfigV2(req); err != nil { - return err - } - resp := apiv1.NewBatchWriteResponse(apimodel.Code_ExecuteSuccess) for _, entry := range req { apiv1.Collect(resp, s.createRoutingConfigV2(ctx, entry)) @@ -72,10 +54,6 @@ func (s *Server) CreateRoutingConfigsV2( // createRoutingConfigV2 Create a routing configuration func (s *Server) createRoutingConfigV2(ctx context.Context, req *apitraffic.RouteRule) *apiservice.Response { - if resp := checkRoutingConfigV2(req); resp != nil { - return resp - } - conf, err := Api2RoutingConfigV2(req) if err != nil { log.Error("[Routing][V2] parse routing config v2 from request for create", @@ -89,8 +67,11 @@ func (s *Server) createRoutingConfigV2(ctx context.Context, req *apitraffic.Rout return apiv1.NewResponse(commonstore.StoreCode2APICode(err)) } - s.RecordHistory(ctx, routingV2RecordEntry(ctx, req, conf, model.OCreate)) - + s.RecordHistory(ctx, routeRuleRecordEntry(ctx, req, conf, model.OCreate)) + _ = s.afterRuleResource(ctx, model.RRouting, authcommon.ResourceEntry{ + ID: req.GetId(), + Type: security.ResourceType_RouteRules, + }, false) req.Id = conf.ID return apiv1.NewRouterResponse(apimodel.Code_ExecuteSuccess, req) } @@ -98,10 +79,6 @@ func (s *Server) createRoutingConfigV2(ctx context.Context, req *apitraffic.Rout // DeleteRoutingConfigsV2 Batch delete routing configuration func (s *Server) DeleteRoutingConfigsV2( ctx context.Context, req []*apitraffic.RouteRule) *apiservice.BatchWriteResponse { - if err := checkBatchRoutingConfigV2(req); err != nil { - return err - } - out := apiv1.NewBatchWriteResponse(apimodel.Code_ExecuteSuccess) for _, entry := range req { resp := s.deleteRoutingConfigV2(ctx, entry) @@ -113,38 +90,27 @@ func (s *Server) DeleteRoutingConfigsV2( // DeleteRoutingConfigV2 Delete a routing configuration func (s *Server) deleteRoutingConfigV2(ctx context.Context, req *apitraffic.RouteRule) *apiservice.Response { - if resp := checkRoutingConfigIDV2(req); resp != nil { - return resp - } - - // Determine whether the current routing rules are only converted from the memory transmission in the V1 version - if _, ok := s.Cache().RoutingConfig().IsConvertFromV1(req.Id); ok { - resp := s.transferV1toV2OnModify(ctx, req) - if resp.GetCode().GetValue() != uint32(apimodel.Code_ExecuteSuccess) { - return resp - } - } - if err := s.storage.DeleteRoutingConfigV2(req.Id); err != nil { log.Error("[Routing][V2] delete routing config v2 store layer", utils.RequestID(ctx), zap.Error(err)) return apiv1.NewResponse(commonstore.StoreCode2APICode(err)) } - s.RecordHistory(ctx, routingV2RecordEntry(ctx, req, &model.RouterConfig{ + s.RecordHistory(ctx, routeRuleRecordEntry(ctx, req, &model.RouterConfig{ ID: req.GetId(), Name: req.GetName(), }, model.ODelete)) + + _ = s.afterRuleResource(ctx, model.RRouting, authcommon.ResourceEntry{ + ID: req.GetId(), + Type: security.ResourceType_RouteRules, + }, true) return apiv1.NewRouterResponse(apimodel.Code_ExecuteSuccess, req) } // UpdateRoutingConfigsV2 Batch update routing configuration func (s *Server) UpdateRoutingConfigsV2( ctx context.Context, req []*apitraffic.RouteRule) *apiservice.BatchWriteResponse { - if err := checkBatchRoutingConfigV2(req); err != nil { - return err - } - out := apiv1.NewBatchWriteResponse(apimodel.Code_ExecuteSuccess) for _, entry := range req { resp := s.updateRoutingConfigV2(ctx, entry) @@ -156,21 +122,6 @@ func (s *Server) UpdateRoutingConfigsV2( // updateRoutingConfigV2 Update a single routing configuration func (s *Server) updateRoutingConfigV2(ctx context.Context, req *apitraffic.RouteRule) *apiservice.Response { - // If V2 routing rules to be modified are from the V1 rule in the cache, need to do the following steps first - // step 1: Turn the V1 rule to the real V2 rule - // step 2: Find the corresponding route to the V2 rules to be modified in the V1 rules, set their rules ID - // step 3: Store persistence - if _, ok := s.Cache().RoutingConfig().IsConvertFromV1(req.Id); ok { - resp := s.transferV1toV2OnModify(ctx, req) - if resp.GetCode().GetValue() != uint32(apimodel.Code_ExecuteSuccess) { - return resp - } - } - - if resp := checkUpdateRoutingConfigV2(req); resp != nil { - return resp - } - // Check whether the routing configuration exists conf, err := s.storage.GetRoutingConfigV2WithID(req.Id) if err != nil { @@ -196,7 +147,7 @@ func (s *Server) updateRoutingConfigV2(ctx context.Context, req *apitraffic.Rout return apiv1.NewResponse(commonstore.StoreCode2APICode(err)) } - s.RecordHistory(ctx, routingV2RecordEntry(ctx, req, reqModel, model.OUpdate)) + s.RecordHistory(ctx, routeRuleRecordEntry(ctx, req, reqModel, model.OUpdate)) return apiv1.NewResponse(apimodel.Code_ExecuteSuccess) } @@ -227,6 +178,11 @@ func (s *Server) QueryRoutingConfigsV2(ctx context.Context, query map[string]str return resp } +// GetAllRouterRules Query all router_rule rules +func (s *Server) GetAllRouterRules(ctx context.Context) *apiservice.BatchQueryResponse { + return nil +} + // EnableRoutings batch enable routing rules func (s *Server) EnableRoutings(ctx context.Context, req []*apitraffic.RouteRule) *apiservice.BatchWriteResponse { out := apiv1.NewBatchWriteResponse(apimodel.Code_ExecuteSuccess) @@ -239,17 +195,6 @@ func (s *Server) EnableRoutings(ctx context.Context, req []*apitraffic.RouteRule } func (s *Server) enableRoutings(ctx context.Context, req *apitraffic.RouteRule) *apiservice.Response { - if resp := checkRoutingConfigIDV2(req); resp != nil { - return resp - } - - if _, ok := s.Cache().RoutingConfig().IsConvertFromV1(req.Id); ok { - resp := s.transferV1toV2OnModify(ctx, req) - if resp.GetCode().GetValue() != uint32(apimodel.Code_ExecuteSuccess) { - return resp - } - } - conf, err := s.storage.GetRoutingConfigV2WithID(req.Id) if err != nil { log.Error("[Routing][V2] get routing config v2 store layer", @@ -269,83 +214,13 @@ func (s *Server) enableRoutings(ctx context.Context, req *apitraffic.RouteRule) return apiv1.NewResponse(commonstore.StoreCode2APICode(err)) } - s.RecordHistory(ctx, routingV2RecordEntry(ctx, req, conf, model.OUpdate)) - return apiv1.NewResponse(apimodel.Code_ExecuteSuccess) -} - -// transferV1toV2OnModify When enabled or prohibited for the V2 rules, the V1 rules need to be converted to V2 rules -// and execute persistent storage -func (s *Server) transferV1toV2OnModify(ctx context.Context, req *apitraffic.RouteRule) *apiservice.Response { - svcId, _ := s.Cache().RoutingConfig().IsConvertFromV1(req.Id) - v1conf, err := s.storage.GetRoutingConfigWithID(svcId) - if err != nil { - log.Error("[Routing][V2] get routing config v1 store layer", - utils.RequestID(ctx), zap.Error(err)) - return apiv1.NewResponse(commonstore.StoreCode2APICode(err)) - } - if v1conf != nil { - svc, err := s.loadServiceByID(svcId) - if svc == nil { - log.Error("[Routing][V2] convert routing config v1 to v2 find svc", - utils.RequestID(ctx), zap.Error(err)) - return apiv1.NewResponse(apimodel.Code_NotFoundService) - } - - inV2, outV2, err := model.ConvertRoutingV1ToExtendV2(svc.Name, svc.Namespace, v1conf) - if err != nil { - log.Error("[Routing][V2] convert routing config v1 to v2", - utils.RequestID(ctx), zap.Error(err)) - return apiv1.NewResponse(apimodel.Code_ExecuteException) - } - - formatApi := func(rules []*model.ExtendRouterConfig) ([]*apitraffic.RouteRule, *apiservice.Response) { - ret := make([]*apitraffic.RouteRule, 0, len(rules)) - for i := range rules { - item, err := rules[i].ToApi() - if err != nil { - log.Error("[Routing][V2] convert routing config v1 to v2, format v2 to api", - utils.RequestID(ctx), zap.Error(err)) - return nil, apiv1.NewResponse(apimodel.Code_ExecuteException) - } - ret = append(ret, item) - } - - return ret, nil - } - - inDatas, resp := formatApi(inV2) - if resp != nil { - return resp - } - outDatas, resp := formatApi(outV2) - if resp != nil { - return resp - } - - if resp := s.saveRoutingV1toV2(ctx, svcId, inDatas, outDatas); resp.GetCode().GetValue() != apiv1.ExecuteSuccess { - return apiv1.NewResponse(apimodel.Code(resp.GetCode().GetValue())) - } - } - + s.RecordHistory(ctx, routeRuleRecordEntry(ctx, req, conf, model.OUpdate)) return apiv1.NewResponse(apimodel.Code_ExecuteSuccess) } // parseServiceArgs The query conditions of the analysis service -func parseRoutingArgs(query map[string]string, ctx context.Context) (*cachetypes.RoutingArgs, *apiservice.Response) { - offset, limit, err := utils.ParseOffsetAndLimit(query) - if err != nil { - return nil, apiv1.NewResponse(apimodel.Code_InvalidParameter) - } - - filter := make(map[string]string) - for key, value := range query { - if _, ok := RoutingConfigV2FilterAttrs[key]; !ok { - log.Errorf("[Routing][V2][Query] attribute(%s) is not allowed", key) - return nil, apiv1.NewResponse(apimodel.Code_InvalidParameter) - } - filter[key] = value - } - +func parseRoutingArgs(filter map[string]string, ctx context.Context) (*cachetypes.RoutingArgs, *apiservice.Response) { + offset, limit, _ := utils.ParseOffsetAndLimit(filter) res := &cachetypes.RoutingArgs{ Filter: filter, Name: filter["name"], @@ -379,120 +254,6 @@ func parseRoutingArgs(query map[string]string, ctx context.Context) (*cachetypes return res, nil } -// checkBatchRoutingConfig Check batch request -func checkBatchRoutingConfigV2(req []*apitraffic.RouteRule) *apiservice.BatchWriteResponse { - if len(req) == 0 { - return apiv1.NewBatchWriteResponse(apimodel.Code_EmptyRequest) - } - - if len(req) > MaxBatchSize { - return apiv1.NewBatchWriteResponse(apimodel.Code_BatchSizeOverLimit) - } - - return nil -} - -// checkRoutingConfig Check the validity of the basic parameter of the routing configuration -func checkRoutingConfigV2(req *apitraffic.RouteRule) *apiservice.Response { - if req == nil { - return apiv1.NewRouterResponse(apimodel.Code_EmptyRequest, req) - } - - if err := checkRoutingNameAndNamespace(req); err != nil { - return err - } - - if err := checkRoutingConfigPriorityV2(req); err != nil { - return err - } - - if err := checkRoutingPolicyV2(req); err != nil { - return err - } - - return nil -} - -// checkUpdateRoutingConfigV2 Check the validity of the basic parameter of the routing configuration -func checkUpdateRoutingConfigV2(req *apitraffic.RouteRule) *apiservice.Response { - if resp := checkRoutingConfigIDV2(req); resp != nil { - return resp - } - - if err := checkRoutingNameAndNamespace(req); err != nil { - return err - } - - if err := checkRoutingConfigPriorityV2(req); err != nil { - return err - } - - if err := checkRoutingPolicyV2(req); err != nil { - return err - } - - return nil -} - -func checkRoutingNameAndNamespace(req *apitraffic.RouteRule) *apiservice.Response { - if err := utils.CheckDbStrFieldLen(utils.NewStringValue(req.GetName()), MaxDbRoutingName); err != nil { - return apiv1.NewRouterResponse(apimodel.Code_InvalidRoutingName, req) - } - - if err := utils.CheckDbStrFieldLen(utils.NewStringValue(req.GetNamespace()), - MaxDbServiceNamespaceLength); err != nil { - return apiv1.NewRouterResponse(apimodel.Code_InvalidNamespaceName, req) - } - - return nil -} - -func checkRoutingConfigIDV2(req *apitraffic.RouteRule) *apiservice.Response { - if req == nil { - return apiv1.NewRouterResponse(apimodel.Code_EmptyRequest, req) - } - - if req.Id == "" { - return apiv1.NewResponse(apimodel.Code_InvalidRoutingID) - } - - return nil -} - -func checkRoutingConfigPriorityV2(req *apitraffic.RouteRule) *apiservice.Response { - if req == nil { - return apiv1.NewRouterResponse(apimodel.Code_EmptyRequest, req) - } - - if req.Priority > 10 { - return apiv1.NewResponse(apimodel.Code_InvalidRoutingPriority) - } - - return nil -} - -func checkRoutingPolicyV2(req *apitraffic.RouteRule) *apiservice.Response { - if req == nil { - return apiv1.NewRouterResponse(apimodel.Code_EmptyRequest, req) - } - - if req.GetRoutingPolicy() != apitraffic.RoutingPolicy_RulePolicy { - return apiv1.NewRouterResponse(apimodel.Code_InvalidRoutingPolicy, req) - } - - // Automatically supplement @Type attribute according to Policy - if req.RoutingConfig.TypeUrl == "" { - if req.GetRoutingPolicy() == apitraffic.RoutingPolicy_RulePolicy { - req.RoutingConfig.TypeUrl = model.RuleRoutingTypeUrl - } - if req.GetRoutingPolicy() == apitraffic.RoutingPolicy_MetadataPolicy { - req.RoutingConfig.TypeUrl = model.MetaRoutingTypeUrl - } - } - - return nil -} - // Api2RoutingConfigV2 Convert the API parameter to internal data structure func Api2RoutingConfigV2(req *apitraffic.RouteRule) (*model.RouterConfig, error) { out := &model.RouterConfig{ @@ -531,3 +292,22 @@ func marshalRoutingV2toAnySlice(routings []*model.ExtendRouterConfig) ([]*any.An return ret, nil } + +// routeRuleRecordEntry Construction of RoutingConfig's record Entry +func routeRuleRecordEntry(ctx context.Context, req *apitraffic.RouteRule, md *model.RouterConfig, + opt model.OperationType) *model.RecordEntry { + + marshaler := jsonpb.Marshaler{} + detail, _ := marshaler.MarshalToString(req) + + entry := &model.RecordEntry{ + ResourceType: model.RRouting, + ResourceName: fmt.Sprintf("%s(%s)", md.Name, md.ID), + Namespace: req.GetNamespace(), + OperationType: opt, + Operator: utils.ParseOperator(ctx), + Detail: detail, + HappenTime: time.Now(), + } + return entry +} diff --git a/service/routing_config_v2_test.go b/service/routing_config_v2_test.go index 8513f3ae0..3297489df 100644 --- a/service/routing_config_v2_test.go +++ b/service/routing_config_v2_test.go @@ -120,250 +120,6 @@ func TestCreateRoutingConfigV2(t *testing.T) { }) } -// TestCompatibleRoutingConfigV2AndV1 测试V2版本的路由规则和V1版本的路由规则 -func TestCompatibleRoutingConfigV2AndV1(t *testing.T) { - - svc := &apiservice.Service{ - Name: utils.NewStringValue("compatible-routing"), - Namespace: utils.NewStringValue("compatible"), - } - - initSuitFunc := func(t *testing.T) *DiscoverTestSuit { - discoverSuit := &DiscoverTestSuit{} - if err := discoverSuit.Initialize(); err != nil { - t.Fatal(err) - } - t.Cleanup(func() { - discoverSuit.Destroy() - }) - - createSvcResp := discoverSuit.DiscoverServer().CreateServices(discoverSuit.DefaultCtx, []*apiservice.Service{svc}) - if !respSuccess(createSvcResp) { - t.Fatalf("error: %s", createSvcResp.GetInfo().GetValue()) - } - - _ = createSvcResp.Responses[0].GetService() - t.Cleanup(func() { - discoverSuit.cleanServices([]*apiservice.Service{svc}) - }) - return discoverSuit - } - - t.Run("V1的存量规则-走V2接口可以查询到,ExtendInfo符合要求", func(t *testing.T) { - discoverSuit := initSuitFunc(t) - _, _ = discoverSuit.createCommonRoutingConfigV1IntoOldStore(t, svc, 3, 0) - t.Cleanup(func() { - discoverSuit.cleanCommonRoutingConfig(svc.GetName().GetValue(), svc.GetNamespace().GetValue()) - discoverSuit.truncateCommonRoutingConfigV2() - }) - - _ = discoverSuit.CacheMgr().TestUpdate() - // 从缓存中查询应该查到 3+3 条 v2 的路由规则 - out := discoverSuit.DiscoverServer().QueryRoutingConfigsV2(discoverSuit.DefaultCtx, map[string]string{ - "limit": "100", - "offset": "0", - }) - if !respSuccess(out) { - t.Fatalf("error: %+v", out) - } - assert.Equal(t, int(3), int(out.GetAmount().GetValue()), "query routing size") - - rulesV2, err := unmarshalRoutingV2toAnySlice(out.GetData()) - assert.NoError(t, err) - for i := range rulesV2 { - item := rulesV2[i] - assert.True(t, item.Enable, "v1 to v2 need default open enable") - msg := &apitraffic.RuleRoutingConfig{} - err := ptypes.UnmarshalAny(item.GetRoutingConfig(), msg) - assert.NoError(t, err) - assert.True(t, len(msg.GetSources()) == 0, "RuleRoutingConfig.Sources len != 0") - assert.True(t, len(msg.GetDestinations()) == 0, "RuleRoutingConfig.Destinations len != 0") - assert.True(t, len(msg.GetRules()) != 0, "RuleRoutingConfig.Rules len == 0") - } - }) - - t.Run("V1的存量规则-走v2规则的启用可正常迁移v1规则", func(t *testing.T) { - discoverSuit := initSuitFunc(t) - _, _ = discoverSuit.createCommonRoutingConfigV1IntoOldStore(t, svc, 3, 0) - t.Cleanup(func() { - discoverSuit.cleanCommonRoutingConfig(svc.GetName().GetValue(), svc.GetNamespace().GetValue()) - discoverSuit.truncateCommonRoutingConfigV2() - }) - - _ = discoverSuit.CacheMgr().TestUpdate() - // 从缓存中查询应该查到 3 条 v2 的路由规则 - out := discoverSuit.DiscoverServer().QueryRoutingConfigsV2(discoverSuit.DefaultCtx, map[string]string{ - "limit": "100", - "offset": "0", - }) - if !respSuccess(out) { - t.Fatalf("error: %+v", out) - } - assert.Equal(t, int(3), int(out.GetAmount().GetValue()), "query routing size") - rulesV2, err := unmarshalRoutingV2toAnySlice(out.GetData()) - assert.NoError(t, err) - - // 选择其中一条规则进行enable操作 - v2resp := discoverSuit.DiscoverServer().EnableRoutings(discoverSuit.DefaultCtx, []*apitraffic.RouteRule{rulesV2[0]}) - if !respSuccess(v2resp) { - t.Fatalf("error: %+v", v2resp) - } - // 直接查询存储无法查询到 v1 的路由规则 - total, routingsV1, err := discoverSuit.Storage.GetRoutingConfigs(map[string]string{}, 0, 100) - assert.NoError(t, err, err) - assert.Equal(t, uint32(0), total, "v1 routing must delete and transfer to v1") - assert.Equal(t, 0, len(routingsV1), "v1 routing ret len need zero") - - // 从缓存中查询应该查到 3 条 v2 的路由规则 - out = discoverSuit.DiscoverServer().QueryRoutingConfigsV2(discoverSuit.DefaultCtx, map[string]string{ - "limit": "100", - "offset": "0", - }) - if !respSuccess(out) { - t.Fatalf("error: %+v", out) - } - assert.Equal(t, int(3), int(out.GetAmount().GetValue()), "query routing size") - rulesV2, err = unmarshalRoutingV2toAnySlice(out.GetData()) - assert.NoError(t, err) - for i := range rulesV2 { - item := rulesV2[i] - assert.True(t, item.Enable, "v1 to v2 need default open enable") - msg := &apitraffic.RuleRoutingConfig{} - err := ptypes.UnmarshalAny(item.GetRoutingConfig(), msg) - assert.NoError(t, err) - assert.True(t, len(msg.GetSources()) == 0, "RuleRoutingConfig.Sources len != 0") - assert.True(t, len(msg.GetDestinations()) == 0, "RuleRoutingConfig.Destinations len != 0") - assert.True(t, len(msg.GetRules()) != 0, "RuleRoutingConfig.Rules len == 0") - } - }) - - t.Run("V1的存量规则-走v2规则的删除可正常迁移v1规则", func(t *testing.T) { - discoverSuit := initSuitFunc(t) - _, _ = discoverSuit.createCommonRoutingConfigV1IntoOldStore(t, svc, 3, 0) - t.Cleanup(func() { - discoverSuit.cleanCommonRoutingConfig(svc.GetName().GetValue(), svc.GetNamespace().GetValue()) - discoverSuit.truncateCommonRoutingConfigV2() - }) - - _ = discoverSuit.CacheMgr().TestUpdate() - // 从缓存中查询应该查到 3+3 条 v2 的路由规则 - out := discoverSuit.DiscoverServer().QueryRoutingConfigsV2(discoverSuit.DefaultCtx, map[string]string{ - "limit": "100", - "offset": "0", - }) - if !respSuccess(out) { - t.Fatalf("error: %+v", out) - } - assert.Equal(t, int(3), int(out.GetAmount().GetValue()), "query routing size") - - rulesV2, err := unmarshalRoutingV2toAnySlice(out.GetData()) - assert.NoError(t, err) - - // 选择其中一条规则进行删除操作 - v2resp := discoverSuit.DiscoverServer().DeleteRoutingConfigsV2(discoverSuit.DefaultCtx, []*apitraffic.RouteRule{rulesV2[0]}) - if !respSuccess(v2resp) { - t.Fatalf("error: %+v", v2resp) - } - // 直接查询存储无法查询到 v1 的路由规则 - total, routingsV1, err := discoverSuit.Storage.GetRoutingConfigs(map[string]string{}, 0, 100) - assert.NoError(t, err, err) - assert.Equal(t, uint32(0), total, "v1 routing must delete and transfer to v1") - assert.Equal(t, 0, len(routingsV1), "v1 routing ret len need zero") - - // 查询对应的 v2 规则也查询不到 - ruleV2, err := discoverSuit.Storage.GetRoutingConfigV2WithID(rulesV2[0].Id) - assert.NoError(t, err, err) - assert.Nil(t, ruleV2, "v2 routing must delete") - - _ = discoverSuit.CacheMgr().TestUpdate() - // 从缓存中查询应该查到 2 条 v2 的路由规则 - out = discoverSuit.DiscoverServer().QueryRoutingConfigsV2(discoverSuit.DefaultCtx, map[string]string{ - "limit": "100", - "offset": "0", - }) - if !respSuccess(out) { - t.Fatalf("error: %+v", out) - } - assert.Equal(t, int(2), int(out.GetAmount().GetValue()), "query routing size") - rulesV2, err = unmarshalRoutingV2toAnySlice(out.GetData()) - assert.NoError(t, err) - for i := range rulesV2 { - item := rulesV2[i] - assert.True(t, item.Enable, "v1 to v2 need default open enable") - msg := &apitraffic.RuleRoutingConfig{} - err := ptypes.UnmarshalAny(item.GetRoutingConfig(), msg) - assert.NoError(t, err) - assert.True(t, len(msg.GetSources()) == 0, "RuleRoutingConfig.Sources len != 0") - assert.True(t, len(msg.GetDestinations()) == 0, "RuleRoutingConfig.Destinations len != 0") - assert.True(t, len(msg.GetRules()) != 0, "RuleRoutingConfig.Rules len == 0") - } - }) - - t.Run("V1的存量规则-走v2规则的编辑可正常迁移v1规则", func(t *testing.T) { - discoverSuit := initSuitFunc(t) - _, _ = discoverSuit.createCommonRoutingConfigV1IntoOldStore(t, svc, 3, 0) - t.Cleanup(func() { - discoverSuit.cleanCommonRoutingConfig(svc.GetName().GetValue(), svc.GetNamespace().GetValue()) - discoverSuit.truncateCommonRoutingConfigV2() - }) - - _ = discoverSuit.CacheMgr().TestUpdate() - // 从缓存中查询应该查到 3+3 条 v2 的路由规则 - out := discoverSuit.DiscoverServer().QueryRoutingConfigsV2(discoverSuit.DefaultCtx, map[string]string{ - "limit": "100", - "offset": "0", - }) - if !respSuccess(out) { - t.Fatalf("error: %+v", out) - } - assert.Equal(t, int(3), int(out.GetAmount().GetValue()), "query routing size") - - rulesV2, err := unmarshalRoutingV2toAnySlice(out.GetData()) - assert.NoError(t, err) - - // 需要将 v2 规则的 extendInfo 规则清理掉 - // 选择其中一条规则进行enable操作 - rulesV2[0].Description = "update v2 rule and transfer v1 to v2" - v2resp := discoverSuit.DiscoverServer().UpdateRoutingConfigsV2(discoverSuit.DefaultCtx, []*apitraffic.RouteRule{rulesV2[0]}) - if !respSuccess(v2resp) { - t.Fatalf("error: %+v", v2resp) - } - // 直接查询存储无法查询到 v1 的路由规则 - total, routingsV1, err := discoverSuit.Storage.GetRoutingConfigs(map[string]string{}, 0, 100) - assert.NoError(t, err, err) - assert.Equal(t, uint32(0), total, "v1 routing must delete and transfer to v1") - assert.Equal(t, 0, len(routingsV1), "v1 routing ret len need zero") - - // 查询对应的 v2 规则能够查询到 - ruleV2, err := discoverSuit.Storage.GetRoutingConfigV2WithID(rulesV2[0].Id) - assert.NoError(t, err, err) - assert.NotNil(t, ruleV2, "v2 routing must exist") - assert.Equal(t, rulesV2[0].Description, ruleV2.Description) - - _ = discoverSuit.CacheMgr().TestUpdate() - out = discoverSuit.DiscoverServer().QueryRoutingConfigsV2(discoverSuit.DefaultCtx, map[string]string{ - "limit": "100", - "offset": "0", - }) - if !respSuccess(out) { - t.Fatalf("error: %+v", out) - } - assert.Equal(t, int(3), int(out.GetAmount().GetValue()), "query routing size") - rulesV2, err = unmarshalRoutingV2toAnySlice(out.GetData()) - assert.NoError(t, err) - for i := range rulesV2 { - item := rulesV2[i] - assert.True(t, item.Enable, "v1 to v2 need default open enable") - msg := &apitraffic.RuleRoutingConfig{} - err := ptypes.UnmarshalAny(item.GetRoutingConfig(), msg) - assert.NoError(t, err) - assert.True(t, len(msg.GetSources()) == 0, "RuleRoutingConfig.Sources len != 0") - assert.True(t, len(msg.GetDestinations()) == 0, "RuleRoutingConfig.Destinations len != 0") - assert.True(t, len(msg.GetRules()) != 0, "RuleRoutingConfig.Rules len == 0") - } - }) -} - // TestDeleteRoutingConfigV2 测试删除路由配置 func TestDeleteRoutingConfigV2(t *testing.T) { diff --git a/service/server.go b/service/server.go index 8c1cd9b3e..a2bb50a21 100644 --- a/service/server.go +++ b/service/server.go @@ -20,6 +20,7 @@ package service import ( "context" + "github.com/polarismesh/specification/source/go/api/v1/security" apiservice "github.com/polarismesh/specification/source/go/api/v1/service_manage" "golang.org/x/sync/singleflight" @@ -27,6 +28,7 @@ import ( cacheservice "github.com/polarismesh/polaris/cache/service" "github.com/polarismesh/polaris/common/eventhub" "github.com/polarismesh/polaris/common/model" + "github.com/polarismesh/polaris/common/model/auth" "github.com/polarismesh/polaris/common/utils" "github.com/polarismesh/polaris/namespace" "github.com/polarismesh/polaris/plugin" @@ -157,12 +159,62 @@ func (s *Server) getLocation(host string) *model.Location { return location } +func (s *Server) afterRuleResource(ctx context.Context, r model.Resource, res auth.ResourceEntry, remove bool) error { + event := &ResourceEvent{ + Resource: res, + IsRemove: remove, + } + + for index := range s.hooks { + hook := s.hooks[index] + if err := hook.After(ctx, r, event); err != nil { + return err + } + } + return nil +} + func (s *Server) afterServiceResource(ctx context.Context, req *apiservice.Service, save *model.Service, remove bool) error { event := &ResourceEvent{ - ReqService: req, - Service: save, - IsRemove: remove, + Resource: auth.ResourceEntry{ + Type: security.ResourceType_Services, + ID: save.ID, + Metadata: save.Meta, + }, + AddPrincipals: func() []auth.Principal { + ret := make([]auth.Principal, 0, 4) + for i := range req.UserIds { + ret = append(ret, auth.Principal{ + PrincipalType: auth.PrincipalUser, + PrincipalID: req.UserIds[i].GetValue(), + }) + } + for i := range req.GroupIds { + ret = append(ret, auth.Principal{ + PrincipalType: auth.PrincipalGroup, + PrincipalID: req.GroupIds[i].GetValue(), + }) + } + return ret + }(), + DelPrincipals: func() []auth.Principal { + ret := make([]auth.Principal, 0, 4) + for i := range req.RemoveUserIds { + ret = append(ret, auth.Principal{ + PrincipalType: auth.PrincipalUser, + PrincipalID: req.RemoveUserIds[i].GetValue(), + }) + } + for i := range req.RemoveGroupIds { + ret = append(ret, auth.Principal{ + PrincipalType: auth.PrincipalGroup, + PrincipalID: req.RemoveGroupIds[i].GetValue(), + }) + } + return ret + }(), + IsRemove: remove, } for index := range s.hooks { @@ -171,7 +223,6 @@ func (s *Server) afterServiceResource(ctx context.Context, req *apiservice.Servi return err } } - return nil } diff --git a/service/service.go b/service/service.go index e7b43d0a7..57758d57b 100644 --- a/service/service.go +++ b/service/service.go @@ -299,9 +299,9 @@ func (s *Server) GetAllServices(ctx context.Context, query map[string]string) *a ) if ns, ok := query["namespace"]; ok && len(ns) > 0 { - _, svcs = s.Cache().Service().ListServices(ns) + _, svcs = s.Cache().Service().ListServices(ctx, ns) } else { - _, svcs = s.Cache().Service().ListAllServices() + _, svcs = s.Cache().Service().ListAllServices(ctx) } ret := make([]*apiservice.Service, 0, len(svcs)) @@ -332,7 +332,7 @@ func (s *Server) GetServices(ctx context.Context, query map[string]string) *apis inputInstMetaKeys, inputInstMetaValues string ) for key, value := range query { - typ, _ := ServiceFilterAttributes[key] + typ := ServiceFilterAttributes[key] switch { case typ == serviceFilter: serviceFilters[key] = value @@ -375,10 +375,7 @@ func (s *Server) GetServices(ctx context.Context, query map[string]string) *apis } // 判断offset和limit是否为int,并从filters清除offset/limit参数 - offset, limit, err := utils.ParseOffsetAndLimit(serviceFilters) - if err != nil { - return api.NewBatchQueryResponse(apimodel.Code_InvalidParameter) - } + offset, limit, _ := utils.ParseOffsetAndLimit(serviceFilters) serviceArgs := parseServiceArgs(serviceFilters, serviceMetas, ctx) total, services, err := s.caches.Service().GetServicesByFilter(ctx, serviceArgs, instanceArgs, offset, limit) @@ -756,6 +753,8 @@ func service2Api(service *model.Service) *apiservice.Service { Ctime: utils.NewStringValue(commontime.Time2String(service.CreateTime)), Mtime: utils.NewStringValue(commontime.Time2String(service.ModifyTime)), ExportTo: service.ListExportTo(), + Editable: utils.NewBoolValue(true), + Deleteable: utils.NewBoolValue(true), } return out diff --git a/service/service_alias.go b/service/service_alias.go index e95a6400f..d97363a13 100644 --- a/service/service_alias.go +++ b/service/service_alias.go @@ -142,28 +142,16 @@ func (s *Server) DeleteServiceAlias(ctx context.Context, req *apiservice.Service return api.NewServiceAliasResponse(commonstore.StoreCode2APICode(err), req) } + s.RecordHistory(ctx, serviceRecordEntry(ctx, &apiservice.Service{ + Name: req.GetAlias(), + Namespace: req.GetAliasNamespace(), + }, alias, model.ODelete)) return api.NewServiceAliasResponse(apimodel.Code_ExecuteSuccess, req) } -func checkBatchAlias(req []*apiservice.ServiceAlias) *apiservice.BatchWriteResponse { - if len(req) == 0 { - return api.NewBatchWriteResponse(apimodel.Code_EmptyRequest) - } - - if len(req) > MaxBatchSize { - return api.NewBatchWriteResponse(apimodel.Code_BatchSizeOverLimit) - } - - return nil -} - // DeleteServiceAliases 删除服务别名列表 func (s *Server) DeleteServiceAliases( ctx context.Context, req []*apiservice.ServiceAlias) *apiservice.BatchWriteResponse { - if checkError := checkBatchAlias(req); checkError != nil { - return checkError - } - responses := api.NewBatchWriteResponse(apimodel.Code_ExecuteSuccess) for _, alias := range req { response := s.DeleteServiceAlias(ctx, alias) @@ -265,6 +253,8 @@ func (s *Server) GetServiceAliases(ctx context.Context, query map[string]string) Comment: utils.NewStringValue(entry.Comment), Ctime: utils.NewStringValue(commontime.Time2String(entry.CreateTime)), Mtime: utils.NewStringValue(commontime.Time2String(entry.ModifyTime)), + Editable: utils.NewBoolValue(true), + Deleteable: utils.NewBoolValue(true), } resp.Aliases = append(resp.Aliases, item) } diff --git a/service/service_contract.go b/service/service_contract.go index a9e9f0eab..475c90964 100644 --- a/service/service_contract.go +++ b/service/service_contract.go @@ -117,7 +117,6 @@ func (s *Server) CreateServiceContract(ctx context.Context, contract *apiservice } func (s *Server) GetServiceContracts(ctx context.Context, query map[string]string) *apiservice.BatchQueryResponse { - out := api.NewBatchQueryResponse(apimodel.Code_ExecuteSuccess) out.Amount = utils.NewUInt32Value(0) out.Size = utils.NewUInt32Value(0) diff --git a/service/service_test.go b/service/service_test.go index 3aea3496a..53d00062f 100644 --- a/service/service_test.go +++ b/service/service_test.go @@ -45,6 +45,7 @@ import ( "github.com/polarismesh/polaris/service" "github.com/polarismesh/polaris/store" "github.com/polarismesh/polaris/store/mock" + testsuit "github.com/polarismesh/polaris/test/suit" ) // 测试新增服务 @@ -1383,7 +1384,7 @@ func TestConcurrencyCreateSameService(t *testing.T) { userMgn, strategyMgn, err := auth.TestInitialize(ctx, &auth.Config{}, mockStore, cacheMgr) assert.NoError(t, err) - nsSvr, err = namespace.TestInitialize(ctx, &namespace.Config{ + nsSvr, err = testsuit.TestNamespaceInitialize(ctx, &namespace.Config{ AutoCreate: true, }, mockStore, cacheMgr, userMgn, strategyMgn) assert.NoError(t, err) diff --git a/store/auth_api.go b/store/auth_api.go index 625df4a34..0db55fec0 100644 --- a/store/auth_api.go +++ b/store/auth_api.go @@ -98,9 +98,6 @@ type StrategyStore interface { principalType authcommon.PrincipalType) (*authcommon.StrategyDetail, error) // GetStrategyDetail Get strategy details GetStrategyDetail(id string) (*authcommon.StrategyDetail, error) - // GetStrategies Get a list of strategies - GetStrategies(filters map[string]string, offset uint32, limit uint32) (uint32, - []*authcommon.StrategyDetail, error) // GetMoreStrategies Used to refresh policy cache // 此方法用于 cache 增量更新,需要注意 mtime 应为数据库时间戳 GetMoreStrategies(mtime time.Time, firstUpdate bool) ([]*authcommon.StrategyDetail, error) @@ -108,12 +105,14 @@ type StrategyStore interface { // RoleStore Role related storage operation interface type RoleStore interface { + // GetRole + GetRole(id string) (*authcommon.Role, error) // AddRole Add a role AddRole(role *authcommon.Role) error // UpdateRole Update a role UpdateRole(role *authcommon.Role) error // DeleteRole Delete a role - DeleteRole(role *authcommon.Role) error + DeleteRole(tx Tx, role *authcommon.Role) error // CleanPrincipalRoles Clean all the roles associated with the principal CleanPrincipalRoles(tx Tx, p *authcommon.Principal) error // GetRole get more role for cache update diff --git a/store/boltdb/default.go b/store/boltdb/default.go index 17907b923..17f143e59 100644 --- a/store/boltdb/default.go +++ b/store/boltdb/default.go @@ -18,11 +18,10 @@ package boltdb import ( + "os" "time" apisecurity "github.com/polarismesh/specification/source/go/api/v1/security" - bolt "go.etcd.io/bbolt" - "go.uber.org/zap" "github.com/polarismesh/polaris/common/model" authcommon "github.com/polarismesh/polaris/common/model/auth" @@ -121,11 +120,12 @@ func (m *boltStore) Initialize(c *store.Config) error { } if loadFile, ok := c.Option["loadFile"].(string); ok { - if err := m.loadByFile(loadFile); err != nil { - return err + // 仅用于本地测试验证单机数据 + loadFileName := os.Getenv("POLARIS_DEV_BOLT_INIT_DATA_FILA") + if loadFileName != "" { + loadFile = loadFileName } - } else { - if err := m.loadByDefault(); err != nil { + if err := m.loadByFile(loadFile); err != nil { return err } } @@ -143,166 +143,52 @@ var ( servicesToInit = map[string]string{ "polaris.checker": "fbca9bfa04ae4ead86e1ecf5811e32a9", } - - mainUser = &authcommon.User{ - ID: "65e4789a6d5b49669adf1e9e8387549c", - Name: "polaris", - Password: "$2a$10$3izWuZtE5SBdAtSZci.gs.iZ2pAn9I8hEqYrC6gwJp1dyjqQnrrum", - Owner: "", - Source: "Polaris", - Mobile: "", - Email: "", - Type: 20, - Token: "nu/0WRA4EqSR1FagrjRj0fZwPXuGlMpX+zCuWu4uMqy8xr1vRjisSbA25aAC3mtU8MeeRsKhQiDAynUR09I=", - TokenEnable: true, - Valid: true, - Comment: "default polaris admin account", - CreateTime: time.Now(), - ModifyTime: time.Now(), - } - - superDefaultStrategy = &authcommon.StrategyDetail{ - ID: "super_user_default_strategy", - Name: "(用户) polarissys@admin的默认策略", - Action: "READ_WRITE", - Comment: "default admin", - Principals: []authcommon.Principal{ - { - StrategyID: "super_user_default_strategy", - PrincipalID: "", - PrincipalType: authcommon.PrincipalUser, - }, - }, - Default: true, - Owner: "", - Resources: []authcommon.StrategyResource{ - { - StrategyID: "super_user_default_strategy", - ResType: int32(apisecurity.ResourceType_Namespaces), - ResID: "*", - }, - { - StrategyID: "super_user_default_strategy", - ResType: int32(apisecurity.ResourceType_Services), - ResID: "*", - }, - { - StrategyID: "super_user_default_strategy", - ResType: int32(apisecurity.ResourceType_ConfigGroups), - ResID: "*", - }, - }, - Valid: true, - Revision: "fbca9bfa04ae4ead86e1ecf5811e32a9", - CreateTime: time.Now(), - ModifyTime: time.Now(), - } - - mainDefaultStrategy = &authcommon.StrategyDetail{ - ID: "fbca9bfa04ae4ead86e1ecf5811e32a9", - Name: "(用户) polaris的默认策略", - Action: "READ_WRITE", - Comment: "default admin", - Principals: []authcommon.Principal{ - { - StrategyID: "fbca9bfa04ae4ead86e1ecf5811e32a9", - PrincipalID: "65e4789a6d5b49669adf1e9e8387549c", - PrincipalType: authcommon.PrincipalUser, - }, - }, - Default: true, - Owner: "65e4789a6d5b49669adf1e9e8387549c", - Resources: []authcommon.StrategyResource{ - { - StrategyID: "fbca9bfa04ae4ead86e1ecf5811e32a9", - ResType: int32(apisecurity.ResourceType_Namespaces), - ResID: "*", - }, - { - StrategyID: "fbca9bfa04ae4ead86e1ecf5811e32a9", - ResType: int32(apisecurity.ResourceType_Services), - ResID: "*", - }, - { - StrategyID: "fbca9bfa04ae4ead86e1ecf5811e32a9", - ResType: int32(apisecurity.ResourceType_ConfigGroups), - ResID: "*", - }, - }, - Valid: true, - Revision: "fbca9bfa04ae4ead86e1ecf5811e32a9", - CreateTime: time.Now(), - ModifyTime: time.Now(), - } ) func (m *boltStore) initNamingStoreData() error { for _, namespace := range namespacesToInit { curTime := time.Now() - err := m.AddNamespace(&model.Namespace{ - Name: namespace, - Token: utils.NewUUID(), - Owner: ownerToInit, - Valid: true, - CreateTime: curTime, - ModifyTime: curTime, - }) + val, err := m.GetNamespace(namespace) if err != nil { return err } + if val == nil { + if err := m.AddNamespace(&model.Namespace{ + Name: namespace, + Token: utils.NewUUID(), + Owner: ownerToInit, + Valid: true, + CreateTime: curTime, + ModifyTime: curTime, + }); err != nil { + return err + } + } } for svc, id := range servicesToInit { curTime := time.Now() - err := m.AddService(&model.Service{ - ID: id, - Name: svc, - Namespace: namespacePolaris, - Token: utils.NewUUID(), - Owner: ownerToInit, - Revision: utils.NewUUID(), - Valid: true, - CreateTime: curTime, - ModifyTime: curTime, - }) - if err != nil { - return err - } - } - return nil -} - -func (m *boltStore) initAuthStoreData() error { - return m.handler.Execute(true, func(tx *bolt.Tx) error { - user, err := m.getUser(tx, mainUser.ID) + val, err := m.getServiceByNameAndNs(svc, namespacePolaris) if err != nil { return err } - - if user == nil { - user = mainUser - // 添加主账户主体信息 - if err := saveValue(tx, tblUser, user.ID, converToUserStore(user)); err != nil { - authLog.Error("[Store][User] save user fail", zap.Error(err), zap.String("name", user.Name)) + if val != nil { + if err := m.AddService(&model.Service{ + ID: id, + Name: svc, + Namespace: namespacePolaris, + Token: utils.NewUUID(), + Owner: ownerToInit, + Revision: utils.NewUUID(), + Valid: true, + CreateTime: curTime, + ModifyTime: curTime, + }); err != nil { return err } } - rule, err := m.getStrategyDetail(tx, mainDefaultStrategy.ID) - if err != nil { - return err - } - - if rule == nil { - strategy := mainDefaultStrategy - // 添加主账户的默认鉴权策略信息 - if err := saveValue(tx, tblStrategy, strategy.ID, convertForStrategyStore(strategy)); err != nil { - authLog.Error("[Store][Strategy] save auth_strategy", zap.Error(err), - zap.String("name", strategy.Name), zap.String("owner", strategy.Owner)) - return err - } - } - return nil - }) + } + return nil } func (m *boltStore) newStore() error { @@ -341,6 +227,7 @@ func (m *boltStore) newAuthModuleStore() { m.userStore = &userStore{handler: m.handler} m.strategyStore = &strategyStore{handler: m.handler} m.groupStore = &groupStore{handler: m.handler} + m.roleStore = &roleStore{handle: m.handler} } func (m *boltStore) newConfigModuleStore() { @@ -382,3 +269,15 @@ func init() { s := &boltStore{} _ = store.RegisterStore(s) } + +func buildAllResAllow(id string) []authcommon.StrategyResource { + ret := make([]authcommon.StrategyResource, 0, 8) + for i := range apisecurity.ResourceType_value { + ret = append(ret, authcommon.StrategyResource{ + StrategyID: id, + ResType: apisecurity.ResourceType_value[i], + ResID: "*", + }) + } + return ret +} diff --git a/store/boltdb/handler_test.go b/store/boltdb/handler_test.go index 33263bc59..676811b45 100644 --- a/store/boltdb/handler_test.go +++ b/store/boltdb/handler_test.go @@ -27,8 +27,10 @@ import ( "github.com/golang/protobuf/ptypes/wrappers" apimodel "github.com/polarismesh/specification/source/go/api/v1/model" + "gopkg.in/yaml.v3" "github.com/polarismesh/polaris/common/model" + authcommon "github.com/polarismesh/polaris/common/model/auth" ) func CreateTableDBHandlerAndRun(t *testing.T, tableName string, tf func(t *testing.T, handler BoltHandler)) { @@ -430,3 +432,10 @@ func TestBoltHandler_UpdateValue(t *testing.T) { } } + +func Test_PrintBoltInitData(t *testing.T) { + d, _ := yaml.Marshal(&authcommon.User{ + TokenEnable: true, + }) + t.Log(string(d)) +} diff --git a/store/boltdb/instance.go b/store/boltdb/instance.go index 0b61e055f..e97ecbeb2 100644 --- a/store/boltdb/instance.go +++ b/store/boltdb/instance.go @@ -19,7 +19,6 @@ package boltdb import ( "errors" - "fmt" "sort" "strconv" "strings" @@ -28,6 +27,7 @@ import ( "github.com/golang/protobuf/ptypes/wrappers" apiservice "github.com/polarismesh/specification/source/go/api/v1/service_manage" bolt "go.etcd.io/bbolt" + "go.uber.org/zap" "google.golang.org/protobuf/types/known/wrapperspb" "github.com/polarismesh/polaris/common/model" @@ -603,8 +603,7 @@ func (i *instanceStore) SetInstanceHealthStatus(instanceID string, flag int, rev return err } if len(instances) == 0 { - msg := fmt.Sprintf("cant not find instance in kv, %s", instanceID) - log.Errorf(msg) + log.Errorf("cant not find instance in kv, %s", instanceID) return nil } @@ -667,8 +666,7 @@ func (i *instanceStore) BatchSetInstanceIsolate(ids []interface{}, isolate int, return err } if len(instances) == 0 { - msg := fmt.Sprintf("cant not find instance in kv, %v", ids) - log.Errorf(msg) + log.Errorf("cant not find instance in kv, %v", ids) return nil } @@ -684,7 +682,7 @@ func (i *instanceStore) BatchSetInstanceIsolate(ids []interface{}, isolate int, instance.Mtime = &wrappers.StringValue{Value: commontime.Time2String(curr)} err = i.handler.UpdateValue(tblNameInstance, id, properties) if err != nil { - log.Errorf("[Store][boltdb] update instance in set instance isolate error, %v", err) + log.Error("[Store][boltdb] update instance in set instance isolate error", zap.Error(err)) return err } } diff --git a/store/boltdb/load.go b/store/boltdb/load.go index 41cfdbd6f..0779c01d0 100644 --- a/store/boltdb/load.go +++ b/store/boltdb/load.go @@ -21,7 +21,6 @@ import ( "errors" "fmt" "os" - "sort" "time" bolt "go.etcd.io/bbolt" @@ -32,22 +31,11 @@ import ( authcommon "github.com/polarismesh/polaris/common/model/auth" ) -func (m *boltStore) loadByDefault() error { - if err := m.initAuthStoreData(); err != nil { - _ = m.handler.Close() - return err - } - if err := m.initNamingStoreData(); err != nil { - _ = m.handler.Close() - return err - } - return nil -} - // DefaultData 默认数据信息 type DefaultData struct { - Namespaces []*model.Namespace `yaml:"namespaces"` - Users []*authcommon.User `yaml:"users"` + Namespaces []*model.Namespace `yaml:"namespaces"` + Users []*authcommon.User `yaml:"users"` + Policies []*authcommon.StrategyDetail `yaml:"policies"` } func (m *boltStore) loadByFile(loadFile string) error { @@ -71,102 +59,47 @@ func (m *boltStore) loadByFile(loadFile string) error { return err } } - if len(data.Users) == 0 { - if err := m.initAuthStoreData(); err != nil { - _ = m.handler.Close() - return err - } - return nil - } return m.loadFromData(data) } func (m *boltStore) loadFromData(data *DefaultData) error { - users := data.Users - - // 确保排序为 admin -> main -> sub - sort.Slice(users, func(i, j int) bool { - return users[i].Type < users[j].Type - }) - - tn := time.Now() - var ( - superUser, mainUser *authcommon.User - ) - if len(users) >= 2 && users[0].Type == authcommon.AdminUserRole && users[1].Type == authcommon.OwnerUserRole { - superUser = users[0] - superUser.CreateTime = tn - superUser.ModifyTime = tn - mainUser = users[1] - mainUser.CreateTime = tn - mainUser.ModifyTime = tn - } else if users[0].Type == authcommon.OwnerUserRole { - mainUser = users[0] - mainUser.CreateTime = tn - mainUser.ModifyTime = tn - } else { - return errors.New("invalid init users info, must be have main user info") - } - if err := m.handler.Execute(true, func(tx *bolt.Tx) error { - saveFunc := func(user *authcommon.User, rule *authcommon.StrategyDetail) error { - rule.Owner = user.ID - rule.Principals[0].PrincipalID = user.ID - saveUser, err := m.getUser(tx, user.ID) + for i := range data.Users { + saveUser, err := m.getUser(tx, data.Users[i].ID) if err != nil { return err } - if saveUser == nil { + data.Users[i].CreateTime = time.Now() + data.Users[i].ModifyTime = time.Now() // 添加主账户主体信息 - if err := saveValue(tx, tblUser, user.ID, converToUserStore(user)); err != nil { - log.Error("[Store][User] save user fail", zap.Error(err), zap.String("name", user.Name)) + if err := saveValue(tx, tblUser, data.Users[i].ID, converToUserStore(data.Users[i])); err != nil { + log.Error("[Store][User] save user fail", zap.Error(err), zap.String("name", data.Users[i].Name)) return err } } + } - saveRule, err := m.getStrategyDetail(tx, rule.ID) + for i := range data.Policies { + saveRule, err := m.getStrategyDetail(tx, data.Policies[i].ID) if err != nil { return err } if saveRule == nil { + data.Policies[i].CreateTime = time.Now() + data.Policies[i].ModifyTime = time.Now() // 添加主账户的默认鉴权策略信息 - if err := saveValue(tx, tblStrategy, rule.ID, convertForStrategyStore(rule)); err != nil { + if err := saveValue(tx, tblStrategy, data.Policies[i].ID, convertForStrategyStore(data.Policies[i])); err != nil { log.Error("[Store][Strategy] save auth_strategy", zap.Error(err), - zap.String("name", rule.Name), zap.String("owner", rule.Owner)) + zap.String("name", data.Policies[i].Name), zap.String("owner", data.Policies[i].Owner)) return err } } - - return nil - } - - if superUser != nil { - if err := saveFunc(superUser, superDefaultStrategy); err != nil { - return err - } - } - if err := saveFunc(mainUser, mainDefaultStrategy); err != nil { - return err } return nil }); err != nil { return err } - - tx, err := m.handle.StartTx() - if err != nil { - return err - } - defer func() { - _ = tx.Rollback() - }() - // 挨个处理其他用户数据信息 - for i := 1; i < len(users); i++ { - if err := m.AddUser(tx, users[i]); err != nil { - return nil - } - } - return tx.Commit() + return nil } diff --git a/store/boltdb/namespace.go b/store/boltdb/namespace.go index a463ee4c1..2090770b7 100644 --- a/store/boltdb/namespace.go +++ b/store/boltdb/namespace.go @@ -24,6 +24,8 @@ import ( "sort" "time" + "go.uber.org/zap" + "github.com/polarismesh/polaris/common/model" "github.com/polarismesh/polaris/common/utils" ) @@ -101,7 +103,7 @@ func (n *namespaceStore) AddNamespace(namespace *model.Namespace) error { func (n *namespaceStore) cleanNamespace(name string) error { if err := n.handler.DeleteValues(tblNameNamespace, []string{name}); err != nil { - log.Errorf("[Store][boltdb] delete invalid namespace error, %+v", err) + log.Error("[Store][boltdb] delete invalid namespace error", zap.Error(err)) return err } @@ -118,7 +120,12 @@ func (n *namespaceStore) UpdateNamespace(namespace *model.Namespace) error { properties["Comment"] = namespace.Comment properties["ModifyTime"] = time.Now() properties["ServiceExportTo"] = utils.MustJson(namespace.ServiceExportTo) - return n.handler.UpdateValue(tblNameNamespace, namespace.Name, properties) + properties["Metadata"] = utils.MustJson(namespace.Metadata) + if err := n.handler.UpdateValue(tblNameNamespace, namespace.Name, properties); err != nil { + log.Error("[Store][boltdb] update namespace error", zap.Error(err)) + return err + } + return nil } // UpdateNamespaceToken update the token of a namespace @@ -144,6 +151,9 @@ func (n *namespaceStore) GetNamespace(name string) (*model.Namespace, error) { return nil, nil } ns := nsValue.(*Namespace) + if !ns.Valid { + return nil, nil + } return n.toModel(ns), nil } @@ -243,7 +253,7 @@ func (n *namespaceStore) GetMoreNamespaces(mtime time.Time) ([]*model.Namespace, if !ok { return false } - return mTimeValue.(time.Time).After(mtime) + return !mTimeValue.(time.Time).Before(mtime) }) if err != nil { return nil, err @@ -256,8 +266,14 @@ func (n *namespaceStore) toModel(data *Namespace) *model.Namespace { } func toModelNamespace(data *Namespace) *model.Namespace { + if !data.Valid { + return nil + } export := make(map[string]struct{}) _ = json.Unmarshal([]byte(data.ServiceExportTo), &export) + + metadata := make(map[string]string) + _ = json.Unmarshal([]byte(data.Metadata), &metadata) return &model.Namespace{ Name: data.Name, Comment: data.Comment, @@ -267,6 +283,7 @@ func toModelNamespace(data *Namespace) *model.Namespace { CreateTime: data.CreateTime, ModifyTime: data.ModifyTime, Valid: data.Valid, + Metadata: metadata, } } @@ -280,6 +297,7 @@ func (n *namespaceStore) toStore(data *model.Namespace) *Namespace { CreateTime: data.CreateTime, ModifyTime: data.ModifyTime, Valid: data.Valid, + Metadata: utils.MustJson(data.Metadata), } } @@ -294,4 +312,5 @@ type Namespace struct { ServiceExportTo string CreateTime time.Time ModifyTime time.Time + Metadata string } diff --git a/store/boltdb/ratelimit_test.go b/store/boltdb/ratelimit_test.go index 8531f38b1..bad397335 100644 --- a/store/boltdb/ratelimit_test.go +++ b/store/boltdb/ratelimit_test.go @@ -57,6 +57,7 @@ func createTestRateLimit(id string, createId bool) *model.RateLimit { CreateTime: time.Now(), ModifyTime: time.Now(), EnableTime: time.Now(), + Metadata: map[string]string{}, } } @@ -77,7 +78,7 @@ func Test_rateLimitStore_CreateRateLimit(t *testing.T) { t.Fatal(err) } - tN := time.Now() + tN := time.Time{} tVal := testVal tVal.ModifyTime = tN tVal.CreateTime = tN @@ -239,7 +240,7 @@ func Test_rateLimitStore_GetExtendRateLimits(t *testing.T) { got1Limits = append(got1Limits, got1[i].RateLimit) } - tN := time.Now() + tN := time.Time{} sort.Slice(got1, func(i, j int) bool { got1Limits[i].CreateTime = tN diff --git a/store/boltdb/role.go b/store/boltdb/role.go index 21b8a5085..adb41b84d 100644 --- a/store/boltdb/role.go +++ b/store/boltdb/role.go @@ -21,12 +21,13 @@ import ( "encoding/json" "time" + bolt "go.etcd.io/bbolt" + "go.uber.org/zap" + "github.com/polarismesh/polaris/common/model" authcommon "github.com/polarismesh/polaris/common/model/auth" "github.com/polarismesh/polaris/common/utils" "github.com/polarismesh/polaris/store" - bolt "go.etcd.io/bbolt" - "go.uber.org/zap" ) var _ store.RoleStore = (*roleStore)(nil) @@ -109,22 +110,19 @@ func (s *roleStore) UpdateRole(role *authcommon.Role) error { } // DeleteRole Delete a role -func (s *roleStore) DeleteRole(role *authcommon.Role) error { +func (s *roleStore) DeleteRole(tx store.Tx, role *authcommon.Role) error { if role.ID == "" { log.Error("[Store][role] delete role missing some params") return ErrBadParam } + dbTx := tx.GetDelegateTx().(*bolt.Tx) data := newRoleData(role) - - err := s.handle.Execute(true, func(tx *bolt.Tx) error { - properties := map[string]interface{}{ - CommonFieldValid: false, - CommonFieldModifyTime: time.Now(), - } - return updateValue(tx, tblRole, data.ID, properties) - }) - if err != nil { + properties := map[string]interface{}{ + CommonFieldValid: false, + CommonFieldModifyTime: time.Now(), + } + if err := updateValue(dbTx, tblRole, data.ID, properties); err != nil { log.Error("[Store][role] delete role failed", zap.String("name", role.Name), zap.Error(err)) return store.Error(err) } @@ -192,6 +190,20 @@ func (s *roleStore) CleanPrincipalRoles(tx store.Tx, p *authcommon.Principal) er return nil } +// GetRole get more role for cache update +func (s *roleStore) GetRole(id string) (*authcommon.Role, error) { + ret, err := s.handle.LoadValues(tblRole, []string{id}, &model.RoutingConfig{}) + if err != nil { + log.Errorf("[Store][role] get one role, %v", err) + return nil, store.Error(err) + } + + for i := range ret { + return newRole(ret[i].(*roleData)), nil + } + return nil, nil +} + // GetMoreRoles get more role for cache update func (s *roleStore) GetMoreRoles(firstUpdate bool, mtime time.Time) ([]*authcommon.Role, error) { fields := []string{CommonFieldModifyTime, CommonFieldValid} @@ -249,8 +261,8 @@ func newRoleData(r *authcommon.Role) *roleData { } func newRole(r *roleData) *authcommon.Role { - users := make([]*authcommon.User, 0, 32) - groups := make([]*authcommon.UserGroup, 0, 32) + users := make([]authcommon.Principal, 0, 32) + groups := make([]authcommon.Principal, 0, 32) _ = json.Unmarshal([]byte(r.Users), &users) _ = json.Unmarshal([]byte(r.UserGroups), &groups) diff --git a/store/boltdb/strategy.go b/store/boltdb/strategy.go index 1a028c9a8..a47f19807 100644 --- a/store/boltdb/strategy.go +++ b/store/boltdb/strategy.go @@ -18,9 +18,9 @@ package boltdb import ( + "encoding/json" "errors" "fmt" - "sort" "strings" "time" @@ -58,24 +58,6 @@ var ( ErrorStrategyNotFound error = errors.New("strategy not fonud") ) -type strategyForStore struct { - ID string - Name string - Action string - Comment string - Users map[string]string - Groups map[string]string - Default bool - Owner string - NsResources map[string]string - SvcResources map[string]string - CfgResources map[string]string - Valid bool - Revision string - CreateTime time.Time - ModifyTime time.Time -} - // StrategyStore type strategyStore struct { handler BoltHandler @@ -136,17 +118,19 @@ func (ss *strategyStore) UpdateStrategy(strategy *authcommon.ModifyStrategyDetai // updateStrategy func (ss *strategyStore) updateStrategy(tx *bolt.Tx, modify *authcommon.ModifyStrategyDetail, - saveVal *strategyForStore) error { + saveVal *strategyData) error { saveVal.Action = modify.Action saveVal.Comment = modify.Comment saveVal.Revision = utils.NewUUID() + saveVal.CalleeFunctions = utils.MustJson(modify.CalleeMethods) + saveVal.Conditions = utils.MustJson(modify.Conditions) computePrincipals(false, modify.AddPrincipals, saveVal) computePrincipals(true, modify.RemovePrincipals, saveVal) - computeResources(false, modify.AddResources, saveVal) - computeResources(true, modify.RemoveResources, saveVal) + saveVal.computeResources(false, modify.AddResources) + saveVal.computeResources(true, modify.RemoveResources) saveVal.ModifyTime = time.Now() @@ -165,7 +149,7 @@ func (ss *strategyStore) updateStrategy(tx *bolt.Tx, modify *authcommon.ModifySt return nil } -func computePrincipals(remove bool, principals []authcommon.Principal, saveVal *strategyForStore) { +func computePrincipals(remove bool, principals []authcommon.Principal, saveVal *strategyData) { for i := range principals { principal := principals[i] if principal.PrincipalType == authcommon.PrincipalUser { @@ -184,36 +168,6 @@ func computePrincipals(remove bool, principals []authcommon.Principal, saveVal * } } -func computeResources(remove bool, resources []authcommon.StrategyResource, saveVal *strategyForStore) { - for i := range resources { - resource := resources[i] - if resource.ResType == int32(apisecurity.ResourceType_Namespaces) { - if remove { - delete(saveVal.NsResources, resource.ResID) - } else { - saveVal.NsResources[resource.ResID] = "" - } - continue - } - if resource.ResType == int32(apisecurity.ResourceType_Services) { - if remove { - delete(saveVal.SvcResources, resource.ResID) - } else { - saveVal.SvcResources[resource.ResID] = "" - } - continue - } - if resource.ResType == int32(apisecurity.ResourceType_ConfigGroups) { - if remove { - delete(saveVal.CfgResources, resource.ResID) - } else { - saveVal.CfgResources[resource.ResID] = "" - } - continue - } - } -} - // DeleteStrategy delete a strategy func (ss *strategyStore) DeleteStrategy(id string) error { if id == "" { @@ -264,7 +218,7 @@ func (ss *strategyStore) operateStrategyResources(remove bool, resources []authc return ErrorStrategyNotFound } - computeResources(remove, ress, rule) + rule.computeResources(remove, ress) rule.ModifyTime = time.Now() if err := saveValue(tx, tblStrategy, rule.ID, rule); err != nil { log.Error("[Store][Strategy] operate strategy resource", zap.Error(err), @@ -282,10 +236,10 @@ func (ss *strategyStore) operateStrategyResources(remove bool, resources []authc return nil } -func loadStrategyById(tx *bolt.Tx, id string) (*strategyForStore, error) { +func loadStrategyById(tx *bolt.Tx, id string) (*strategyData, error) { values := make(map[string]interface{}) - if err := loadValues(tx, tblStrategy, []string{id}, &strategyForStore{}, values); err != nil { + if err := loadValues(tx, tblStrategy, []string{id}, &strategyData{}, values); err != nil { log.Error("[Store][Strategy] get auth_strategy by id", zap.Error(err), zap.String("id", id)) return nil, err @@ -298,9 +252,9 @@ func loadStrategyById(tx *bolt.Tx, id string) (*strategyForStore, error) { return nil, ErrorMultiDefaultStrategy } - var ret *strategyForStore + var ret *strategyData for _, v := range values { - ret = v.(*strategyForStore) + ret = v.(*strategyData) break } @@ -366,7 +320,7 @@ func (ss *strategyStore) GetStrategyResources(principalId string, fields = []string{StrategyFieldValid, StrategyFieldDefault, StrategyFieldGroupsPrincipal} } - values, err := ss.handler.LoadValuesByFilter(tblStrategy, fields, &strategyForStore{}, + values, err := ss.handler.LoadValuesByFilter(tblStrategy, fields, &strategyData{}, func(m map[string]interface{}) bool { valid, ok := m[StrategyFieldValid].(bool) if ok && !valid { @@ -393,43 +347,13 @@ func (ss *strategyStore) GetStrategyResources(principalId string, ret := make([]authcommon.StrategyResource, 0, 4) for _, item := range values { - rule := item.(*strategyForStore) - ret = append(ret, collectStrategyResources(rule)...) + rule := item.(*strategyData) + ret = append(ret, rule.GetResources()...) } return ret, nil } -func collectStrategyResources(rule *strategyForStore) []authcommon.StrategyResource { - ret := make([]authcommon.StrategyResource, 0, len(rule.NsResources)+len(rule.SvcResources)+len(rule.CfgResources)) - - for id := range rule.NsResources { - ret = append(ret, authcommon.StrategyResource{ - StrategyID: rule.ID, - ResType: int32(apisecurity.ResourceType_Namespaces), - ResID: id, - }) - } - - for id := range rule.SvcResources { - ret = append(ret, authcommon.StrategyResource{ - StrategyID: rule.ID, - ResType: int32(apisecurity.ResourceType_Services), - ResID: id, - }) - } - - for id := range rule.CfgResources { - ret = append(ret, authcommon.StrategyResource{ - StrategyID: rule.ID, - ResType: int32(apisecurity.ResourceType_ConfigGroups), - ResID: id, - }) - } - - return ret -} - // GetDefaultStrategyDetailByPrincipal 获取默认策略详情 func (ss *strategyStore) GetDefaultStrategyDetailByPrincipal(principalId string, principalType authcommon.PrincipalType) (*authcommon.StrategyDetail, error) { @@ -440,7 +364,7 @@ func (ss *strategyStore) GetDefaultStrategyDetailByPrincipal(principalId string, fields = []string{StrategyFieldValid, StrategyFieldDefault, StrategyFieldGroupsPrincipal} } - values, err := ss.handler.LoadValuesByFilter(tblStrategy, fields, &strategyForStore{}, + values, err := ss.handler.LoadValuesByFilter(tblStrategy, fields, &strategyData{}, func(m map[string]interface{}) bool { valid, ok := m[StrategyFieldValid].(bool) if ok && !valid { @@ -477,142 +401,15 @@ func (ss *strategyStore) GetDefaultStrategyDetailByPrincipal(principalId string, return nil, ErrorMultiDefaultStrategy } - var ret *strategyForStore + var ret *strategyData for _, v := range values { - ret = v.(*strategyForStore) + ret = v.(*strategyData) break } return convertForStrategyDetail(ret), nil } -// GetStrategies 查询鉴权策略列表 -func (ss *strategyStore) GetStrategies(filters map[string]string, offset uint32, limit uint32) (uint32, - []*authcommon.StrategyDetail, error) { - - showDetail := filters["show_detail"] - delete(filters, "show_detail") - - return ss.listStrategies(filters, offset, limit, showDetail == "true") -} - -func (ss *strategyStore) listStrategies(filters map[string]string, offset uint32, limit uint32, - showDetail bool) (uint32, []*authcommon.StrategyDetail, error) { - - fields := []string{StrategyFieldValid, StrategyFieldName, StrategyFieldUsersPrincipal, - StrategyFieldGroupsPrincipal, StrategyFieldNsResources, StrategyFieldSvcResources, - StrategyFieldCfgResources, StrategyFieldOwner, StrategyFieldDefault} - - values, err := ss.handler.LoadValuesByFilter(tblStrategy, fields, &strategyForStore{}, - func(m map[string]interface{}) bool { - valid, ok := m[StrategyFieldValid].(bool) - if ok && !valid { - return false - } - - saveName, _ := m[StrategyFieldName].(string) - saveDefault, _ := m[StrategyFieldDefault].(bool) - saveOwner, _ := m[StrategyFieldOwner].(string) - - if name, ok := filters["name"]; ok { - if utils.IsPrefixWildName(name) { - name = name[:len(name)-1] - } - if !strings.Contains(saveName, name) { - return false - } - } - - if owner, ok := filters["owner"]; ok { - if strings.Compare(saveOwner, owner) != 0 { - if principalId, ok := filters["principal_id"]; ok { - principalType := filters["principal_type"] - if !comparePrincipalExist(principalType, principalId, m) { - return false - } - } - } - } - - if isDefault, ok := filters["default"]; ok { - compareParam2BoolNotEqual := func(param string, b bool) bool { - if param == "0" && !b { - return true - } - if param == "1" && b { - return true - } - return false - } - if !compareParam2BoolNotEqual(isDefault, saveDefault) { - return false - } - } - - if resType, ok := filters["res_type"]; ok { - resId := filters["res_id"] - if !compareResExist(resType, resId, m) { - return false - } - } - - if principalId, ok := filters["principal_id"]; ok { - principalType := filters["principal_type"] - if !comparePrincipalExist(principalType, principalId, m) { - return false - } - } - - return true - }) - - if err != nil { - log.Error("[Store][Strategy] get auth_strategy for list", zap.Error(err)) - return 0, nil, err - } - - return uint32(len(values)), doStrategyPage(values, offset, limit, showDetail), nil -} - -func doStrategyPage(ret map[string]interface{}, offset, limit uint32, showDetail bool) []*authcommon.StrategyDetail { - rules := make([]*authcommon.StrategyDetail, 0, len(ret)) - - beginIndex := offset - endIndex := beginIndex + limit - totalCount := uint32(len(ret)) - - if totalCount == 0 { - return rules - } - if beginIndex >= endIndex { - return rules - } - if beginIndex >= totalCount { - return rules - } - if endIndex > totalCount { - endIndex = totalCount - } - - emptyPrincipals := make([]authcommon.Principal, 0) - emptyResources := make([]authcommon.StrategyResource, 0) - - for k := range ret { - rule := convertForStrategyDetail(ret[k].(*strategyForStore)) - if !showDetail { - rule.Principals = emptyPrincipals - rule.Resources = emptyResources - } - rules = append(rules, rule) - } - - sort.Slice(rules, func(i, j int) bool { - return rules[i].ModifyTime.After(rules[j].ModifyTime) - }) - - return rules[beginIndex:endIndex] -} - func compareResExist(resType, resId string, m map[string]interface{}) bool { saveNsRes, _ := m[StrategyFieldNsResources].(map[string]string) saveSvcRes, _ := m[StrategyFieldSvcResources].(map[string]string) @@ -656,7 +453,7 @@ func comparePrincipalExist(principalType, principalId string, m map[string]inter // GetMoreStrategies get strategy details for cache func (ss *strategyStore) GetMoreStrategies(mtime time.Time, firstUpdate bool) ([]*authcommon.StrategyDetail, error) { - ret, err := ss.handler.LoadValuesByFilter(tblStrategy, []string{StrategyFieldModifyTime}, &strategyForStore{}, + ret, err := ss.handler.LoadValuesByFilter(tblStrategy, []string{StrategyFieldModifyTime}, &strategyData{}, func(m map[string]interface{}) bool { mt := m[StrategyFieldModifyTime].(time.Time) isAfter := mt.After(mtime) @@ -671,7 +468,7 @@ func (ss *strategyStore) GetMoreStrategies(mtime time.Time, firstUpdate bool) ([ for k := range ret { val := ret[k] - strategies = append(strategies, convertForStrategyDetail(val.(*strategyForStore))) + strategies = append(strategies, convertForStrategyDetail(val.(*strategyData))) } return strategies, nil @@ -682,7 +479,7 @@ func (ss *strategyStore) CleanPrincipalPolicies(tx store.Tx, p authcommon.Princi values := make(map[string]interface{}) dbTx := tx.GetDelegateTx().(*bolt.Tx) - err := loadValuesByFilter(dbTx, tblStrategy, fields, &strategyForStore{}, + err := loadValuesByFilter(dbTx, tblStrategy, fields, &strategyData{}, func(m map[string]interface{}) bool { isDefault := m[StrategyFieldDefault].(bool) if !isDefault { @@ -735,7 +532,7 @@ func (ss *strategyStore) cleanInvalidStrategy(tx *bolt.Tx, name, owner string) e fields := []string{StrategyFieldName, StrategyFieldOwner, StrategyFieldValid} values := make(map[string]interface{}) - err := loadValuesByFilter(tx, tblStrategy, fields, &strategyForStore{}, + err := loadValuesByFilter(tx, tblStrategy, fields, &strategyData{}, func(m map[string]interface{}) bool { valid, ok := m[StrategyFieldValid].(bool) // 如果数据是 valid 的,则不能被清理 @@ -767,7 +564,85 @@ func (ss *strategyStore) cleanInvalidStrategy(tx *bolt.Tx, name, owner string) e return deleteValues(tx, tblStrategy, keys) } -func convertForStrategyStore(strategy *authcommon.StrategyDetail) *strategyForStore { +type strategyData struct { + ID string + Name string + Action string + Comment string + Users map[string]string + Groups map[string]string + Default bool + Owner string + NsResources map[string]string + SvcResources map[string]string + CfgResources map[string]string + AllResources string + CalleeFunctions string + Conditions string + Valid bool + Revision string + CreateTime time.Time + ModifyTime time.Time +} + +func (s *strategyData) computeResources(remove bool, resources []authcommon.StrategyResource) { + saveVal := s.GetResources() + + tmp := make(map[string]authcommon.StrategyResource, 8) + for i := range saveVal { + tmp[saveVal[i].Key()] = saveVal[i] + } + for i := range resources { + resource := resources[i] + if remove { + delete(tmp, resource.Key()) + } else { + tmp[resource.Key()] = resource + } + } + + ret := make([]authcommon.StrategyResource, 0, 8) + for i := range tmp { + ret = append(ret, tmp[i]) + } + + s.AllResources = utils.MustJson(ret) +} + +func (s *strategyData) GetResources() []authcommon.StrategyResource { + ret := make([]authcommon.StrategyResource, 0, len(s.NsResources)+len(s.SvcResources)+len(s.CfgResources)) + + for id := range s.NsResources { + ret = append(ret, authcommon.StrategyResource{ + StrategyID: s.ID, + ResType: int32(apisecurity.ResourceType_Namespaces), + ResID: id, + }) + } + + for id := range s.SvcResources { + ret = append(ret, authcommon.StrategyResource{ + StrategyID: s.ID, + ResType: int32(apisecurity.ResourceType_Services), + ResID: id, + }) + } + + for id := range s.CfgResources { + ret = append(ret, authcommon.StrategyResource{ + StrategyID: s.ID, + ResType: int32(apisecurity.ResourceType_ConfigGroups), + ResID: id, + }) + } + if len(s.AllResources) != 0 { + ret = make([]authcommon.StrategyResource, 0, 4) + _ = json.Unmarshal([]byte(s.AllResources), &ret) + } + return ret +} + +func convertForStrategyStore(strategy *authcommon.StrategyDetail) *strategyData { var ( users = make(map[string]string, 4) @@ -784,48 +659,28 @@ func convertForStrategyStore(strategy *authcommon.StrategyDetail) *strategyForSt } } - ns := make(map[string]string, 4) - svc := make(map[string]string, 4) - cfg := make(map[string]string, 4) - - resources := strategy.Resources - - for i := range resources { - res := resources[i] - switch res.ResType { - case int32(apisecurity.ResourceType_Namespaces): - ns[res.ResID] = "" - case int32(apisecurity.ResourceType_Services): - svc[res.ResID] = "" - case int32(apisecurity.ResourceType_ConfigGroups): - cfg[res.ResID] = "" - } - } - - return &strategyForStore{ - ID: strategy.ID, - Name: strategy.Name, - Action: strategy.Action, - Comment: strategy.Comment, - Users: users, - Groups: groups, - Default: strategy.Default, - Owner: strategy.Owner, - NsResources: ns, - SvcResources: svc, - CfgResources: cfg, - Valid: strategy.Valid, - Revision: strategy.Revision, - CreateTime: strategy.CreateTime, - ModifyTime: strategy.ModifyTime, + return &strategyData{ + ID: strategy.ID, + Name: strategy.Name, + Action: strategy.Action, + Comment: strategy.Comment, + Users: users, + Groups: groups, + Default: strategy.Default, + Owner: strategy.Owner, + AllResources: utils.MustJson(strategy.Resources), + CalleeFunctions: utils.MustJson(strategy.CalleeMethods), + Conditions: utils.MustJson(strategy.Conditions), + Valid: strategy.Valid, + Revision: strategy.Revision, + CreateTime: strategy.CreateTime, + ModifyTime: strategy.ModifyTime, } } -func convertForStrategyDetail(strategy *strategyForStore) *authcommon.StrategyDetail { +func convertForStrategyDetail(strategy *strategyData) *authcommon.StrategyDetail { principals := make([]authcommon.Principal, 0, len(strategy.Users)+len(strategy.Groups)) - resources := make([]authcommon.StrategyResource, 0, len(strategy.NsResources)+ - len(strategy.SvcResources)+len(strategy.CfgResources)) for id := range strategy.Users { principals = append(principals, authcommon.Principal{ @@ -842,31 +697,13 @@ func convertForStrategyDetail(strategy *strategyForStore) *authcommon.StrategyDe }) } - fillRes := func(idMap map[string]string, resType apisecurity.ResourceType) []authcommon.StrategyResource { - res := make([]authcommon.StrategyResource, 0, len(idMap)) - - for id := range idMap { - res = append(res, authcommon.StrategyResource{ - StrategyID: strategy.ID, - ResType: int32(resType), - ResID: id, - }) - } - - return res - } - - resources = append(resources, fillRes(strategy.NsResources, apisecurity.ResourceType_Namespaces)...) - resources = append(resources, fillRes(strategy.SvcResources, apisecurity.ResourceType_Services)...) - resources = append(resources, fillRes(strategy.CfgResources, apisecurity.ResourceType_ConfigGroups)...) - - return &authcommon.StrategyDetail{ + ret := &authcommon.StrategyDetail{ ID: strategy.ID, Name: strategy.Name, Action: strategy.Action, Comment: strategy.Comment, Principals: principals, - Resources: resources, + Resources: strategy.GetResources(), Default: strategy.Default, Owner: strategy.Owner, Valid: strategy.Valid, @@ -874,6 +711,18 @@ func convertForStrategyDetail(strategy *strategyForStore) *authcommon.StrategyDe CreateTime: strategy.CreateTime, ModifyTime: strategy.ModifyTime, } + + if len(strategy.CalleeFunctions) != 0 { + functions := make([]string, 0, 4) + _ = json.Unmarshal([]byte(strategy.CalleeFunctions), &functions) + ret.CalleeMethods = functions + } + if len(strategy.Conditions) != 0 { + condition := make([]authcommon.Condition, 0, 4) + _ = json.Unmarshal([]byte(strategy.Conditions), &condition) + ret.Conditions = condition + } + return ret } func initStrategy(rule *authcommon.StrategyDetail) { diff --git a/store/boltdb/transaction.go b/store/boltdb/transaction.go index 60780cbdf..ca818103e 100644 --- a/store/boltdb/transaction.go +++ b/store/boltdb/transaction.go @@ -18,6 +18,8 @@ package boltdb import ( + "time" + "github.com/polarismesh/polaris/common/model" ) @@ -59,7 +61,12 @@ func (t *transaction) RLockNamespace(name string) (*model.Namespace, error) { // DeleteNamespace 删除namespace func (t *transaction) DeleteNamespace(name string) error { - return t.handler.DeleteValues(tblNameNamespace, []string{name}) + properties := map[string]interface{}{ + CommonFieldValid: false, + CommonFieldModifyTime: time.Now(), + } + + return t.handler.UpdateValue(tblNameNamespace, name, properties) } const ( diff --git a/store/boltdb/user.go b/store/boltdb/user.go index cc386a404..5e8eeb804 100644 --- a/store/boltdb/user.go +++ b/store/boltdb/user.go @@ -90,7 +90,6 @@ func (us *userStore) AddUser(tx store.Tx, user *authcommon.User) error { if owner == "" { owner = user.ID } - // 添加用户信息 if err := saveValue(dbTx, tblUser, user.ID, converToUserStore(user)); err != nil { log.Error("[Store][User] save user fail", zap.Error(err), zap.String("name", user.Name)) @@ -431,8 +430,15 @@ func (us *userStore) getGroupUsers(filters map[string]string, offset uint32, lim // GetUsersForCache 获取所有用户信息 func (us *userStore) GetUsersForCache(mtime time.Time, firstUpdate bool) ([]*authcommon.User, error) { - ret, err := us.handler.LoadValuesByFilter(tblUser, []string{UserFieldModifyTime}, &userForStore{}, + fields := []string{UserFieldModifyTime, UserFieldValid} + ret, err := us.handler.LoadValuesByFilter(tblUser, fields, &userForStore{}, func(m map[string]interface{}) bool { + if firstUpdate { + valid, _ := m[UserFieldValid].(bool) + if !valid { + return false + } + } mt := m[UserFieldModifyTime].(time.Time) isBefore := mt.Before(mtime) return !isBefore diff --git a/store/mock/api_mock.go b/store/mock/api_mock.go index 3b397247d..57ffee0af 100644 --- a/store/mock/api_mock.go +++ b/store/mock/api_mock.go @@ -865,17 +865,17 @@ func (mr *MockStoreMockRecorder) DeleteRateLimit(limiting interface{}) *gomock.C } // DeleteRole mocks base method. -func (m *MockStore) DeleteRole(role *auth.Role) error { +func (m *MockStore) DeleteRole(tx store.Tx, role *auth.Role) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "DeleteRole", role) + ret := m.ctrl.Call(m, "DeleteRole", tx, role) ret0, _ := ret[0].(error) return ret0 } // DeleteRole indicates an expected call of DeleteRole. -func (mr *MockStoreMockRecorder) DeleteRole(role interface{}) *gomock.Call { +func (mr *MockStoreMockRecorder) DeleteRole(tx, role interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteRole", reflect.TypeOf((*MockStore)(nil).DeleteRole), role) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteRole", reflect.TypeOf((*MockStore)(nil).DeleteRole), tx, role) } // DeleteRoutingConfig mocks base method. @@ -1848,6 +1848,21 @@ func (mr *MockStoreMockRecorder) GetRateLimitsForCache(mtime, firstUpdate interf return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRateLimitsForCache", reflect.TypeOf((*MockStore)(nil).GetRateLimitsForCache), mtime, firstUpdate) } +// GetRole mocks base method. +func (m *MockStore) GetRole(id string) (*auth.Role, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetRole", id) + ret0, _ := ret[0].(*auth.Role) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetRole indicates an expected call of GetRole. +func (mr *MockStoreMockRecorder) GetRole(id interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRole", reflect.TypeOf((*MockStore)(nil).GetRole), id) +} + // GetRoutingConfigV2WithID mocks base method. func (m *MockStore) GetRoutingConfigV2WithID(id string) (*model.RouterConfig, error) { m.ctrl.T.Helper() diff --git a/store/mysql/admin.go b/store/mysql/admin.go index b3e775530..658c9d9cf 100644 --- a/store/mysql/admin.go +++ b/store/mysql/admin.go @@ -468,7 +468,7 @@ func (m *adminStore) BatchCleanDeletedInstances(timeout time.Duration, batchSize } cleanCheckStr := fmt.Sprintf("delete from health_check where id in (%s)", inSql) - if _, err := tx.Exec(cleanCheckStr, waitDelIds...); err != nil { + if _, err = tx.Exec(cleanCheckStr, waitDelIds...); err != nil { log.Errorf("[Store][database] batch clean soft deleted instances(%d), err: %s", batchSize, err.Error()) return store.Error(err) } @@ -487,7 +487,7 @@ func (m *adminStore) BatchCleanDeletedInstances(timeout time.Duration, batchSize return store.Error(err) } - if err := tx.Commit(); err != nil { + if err = tx.Commit(); err != nil { log.Errorf("[Store][database] batch clean soft deleted instances(%d) commit tx err: %s", batchSize, err.Error()) return err diff --git a/store/mysql/config_file.go b/store/mysql/config_file.go index bc8c35d3d..5854437e8 100644 --- a/store/mysql/config_file.go +++ b/store/mysql/config_file.go @@ -300,7 +300,7 @@ func (cf *configFileStore) QueryConfigFiles(filter map[string]string, offset, li err = cf.slave.processWithTransaction("batch-load-file-tags", func(tx *BaseTx) error { for i := range files { item := files[i] - if err := cf.loadFileTags(tx, item); err != nil { + if err = cf.loadFileTags(tx, item); err != nil { return err } } diff --git a/store/mysql/config_file_release.go b/store/mysql/config_file_release.go index 55e511f5c..f2fe05457 100644 --- a/store/mysql/config_file_release.go +++ b/store/mysql/config_file_release.go @@ -52,7 +52,7 @@ func (cfr *configFileReleaseStore) CreateConfigFileReleaseTx(tx store.Tx, data * } clean := "DELETE FROM config_file_release WHERE namespace = ? AND `group` = ? AND file_name = ? AND name = ? AND flag = 1" - if _, err := dbTx.Exec(clean, data.Namespace, data.Group, data.FileName, data.Name); err != nil { + if _, err = dbTx.Exec(clean, data.Namespace, data.Group, data.FileName, data.Name); err != nil { return store.Error(err) } diff --git a/store/mysql/default.go b/store/mysql/default.go index ae5d8c205..3d2bb5b39 100644 --- a/store/mysql/default.go +++ b/store/mysql/default.go @@ -272,12 +272,15 @@ func (s *stableStore) newStore() { s.configFileTemplateStore = &configFileTemplateStore{master: s.master, slave: s.slave} s.clientStore = &clientStore{master: s.master, slave: s.slave} + s.grayStore = &grayStore{master: s.master, slave: s.slave} + s.adminStore = newAdminStore(s.master) s.toolStore = &toolStore{db: s.master} + s.userStore = &userStore{master: s.master, slave: s.slave} s.groupStore = &groupStore{master: s.master, slave: s.slave} s.strategyStore = &strategyStore{master: s.master, slave: s.slave} - s.grayStore = &grayStore{master: s.master, slave: s.slave} + s.roleStore = &roleStore{master: s.master, slave: s.slave} } func buildEtimeStr(enable bool) string { diff --git a/store/mysql/group.go b/store/mysql/group.go index 938a10087..93df9d99a 100644 --- a/store/mysql/group.go +++ b/store/mysql/group.go @@ -128,14 +128,14 @@ func (u *groupStore) updateGroup(group *authcommon.ModifyUserGroup) error { // 更新用户-用户组关联数据 if len(group.AddUserIds) != 0 { - if err := u.addGroupRelation(tx, group.ID, group.AddUserIds); err != nil { + if err = u.addGroupRelation(tx, group.ID, group.AddUserIds); err != nil { log.Errorf("[Store][Group] add usergroup relation err: %s", err.Error()) return err } } if len(group.RemoveUserIds) != 0 { - if err := u.removeGroupRelation(tx, group.ID, group.RemoveUserIds); err != nil { + if err = u.removeGroupRelation(tx, group.ID, group.RemoveUserIds); err != nil { log.Errorf("[Store][Group] remove usergroup relation err: %s", err.Error()) return err } @@ -153,7 +153,7 @@ func (u *groupStore) updateGroup(group *authcommon.ModifyUserGroup) error { return err } - if err := tx.Commit(); err != nil { + if err = tx.Commit(); err != nil { log.Errorf("[Store][Group] update usergroup tx commit err: %s", err.Error()) return err } diff --git a/store/mysql/lane.go b/store/mysql/lane.go index 068cd3603..705f0b2ab 100644 --- a/store/mysql/lane.go +++ b/store/mysql/lane.go @@ -200,9 +200,9 @@ SELECT id, name, rule, description, revision, flag, UNIX_TIMESTAMP(ctime), UNIX_ for k, v := range filter { switch k { case "name": - if v, ok := utils.ParseWildName(v); ok { + if pv, ok := utils.ParseWildName(v); ok { conditions = append(conditions, "name = ?") - args = append(args, v) + args = append(args, pv) } else { conditions = append(conditions, "name LIKE ?") args = append(args, "%"+v+"%") @@ -295,8 +295,8 @@ func (l *laneStore) getLaneRulesByGroup(tx *BaseTx, names []string) (map[string] } querySql := ` -SELECT id, name, group_name, rule, revision, priority, description, enable, flag, UNIX_TIMESTAMP(ctime), UNIX_TIMESTAMP(etime), UNIX_TIMESTAMP(mtime) - FROM lane_rule WHERE flag = 0 AND group_name IN (%s) +SELECT id, name, group_name, rule, revision, priority, description, enable, flag, UNIX_TIMESTAMP(ctime), +UNIX_TIMESTAMP(etime), UNIX_TIMESTAMP(mtime) FROM lane_rule WHERE flag = 0 AND group_name IN (%s) ` querySql = fmt.Sprintf(querySql, placeholders(len(names))) diff --git a/store/mysql/role.go b/store/mysql/role.go index 97732eed9..dfe353626 100644 --- a/store/mysql/role.go +++ b/store/mysql/role.go @@ -21,10 +21,11 @@ import ( "encoding/json" "time" + "go.uber.org/zap" + authcommon "github.com/polarismesh/polaris/common/model/auth" "github.com/polarismesh/polaris/common/utils" "github.com/polarismesh/polaris/store" - "go.uber.org/zap" ) type roleStore struct { @@ -69,16 +70,16 @@ func (s *roleStore) savePrincipals(tx *BaseTx, role *authcommon.Role) error { return err } - insertTpl := "INSERT INTO auth_role_principal(role_id, principal_id, principal_role) VALUES (?, ?, ?)" + insertTpl := "INSERT INTO auth_role_principal(role_id, principal_id, principal_role, extend_info) VALUES (?, ?, ?)" for i := range role.Users { - args := []interface{}{role.ID, role.Users[i].ID, authcommon.PrincipalUser} + args := []interface{}{role.ID, role.Users[i].PrincipalID, authcommon.PrincipalUser, utils.MustJson(role.Users[i].Extend)} if _, err := tx.Exec(insertTpl, args...); err != nil { return err } } for i := range role.UserGroups { - args := []interface{}{role.ID, role.UserGroups[i].ID, authcommon.PrincipalGroup} + args := []interface{}{role.ID, role.UserGroups[i].PrincipalID, authcommon.PrincipalGroup, utils.MustJson(role.UserGroups[i].Extend)} if _, err := tx.Exec(insertTpl, args...); err != nil { return err } @@ -113,18 +114,16 @@ WHERE id = ? } // DeleteRole Delete a role -func (s *roleStore) DeleteRole(role *authcommon.Role) error { +func (s *roleStore) DeleteRole(tx store.Tx, role *authcommon.Role) error { if role.ID == "" { return store.NewStatusError(store.EmptyParamsErr, "role id is empty") } - err := s.master.processWithTransaction("delete_role", func(tx *BaseTx) error { - if _, err := tx.Exec("UPDATE auth_role SET flag = 1 WHERE id = ?", role.ID); err != nil { - log.Error("[store][role] delete role", zap.String("name", role.Name), zap.Error(err)) - return err - } - return nil - }) - return store.Error(err) + dbTx := tx.GetDelegateTx().(*BaseTx) + if _, err := dbTx.Exec("UPDATE auth_role SET flag = 1 WHERE id = ?", role.ID); err != nil { + log.Error("[store][role] delete role", zap.String("name", role.Name), zap.Error(err)) + return store.Error(err) + } + return nil } // CleanPrincipalRoles clean principal roles @@ -163,6 +162,47 @@ func (s *roleStore) CleanPrincipalRoles(tx store.Tx, p *authcommon.Principal) er return nil } +func (s *roleStore) GetRole(id string) (*authcommon.Role, error) { + tx, err := s.master.Begin() + if err != nil { + return nil, store.Error(err) + } + + defer func() { _ = tx.Commit() }() + + querySql := "SELECT id, name, owner, source, role_type, comment, flag, metadata, UNIX_TIMESTAMP(ctime), " + + " UNIX_TIMESTAMP(mtime) FROM auth_role WHERE flag = 0 AND id = ?" + args := []interface{}{id} + + row := tx.QueryRow(querySql, args...) + var ( + ctime, mtime int64 + flag int16 + metadata string + ) + ret := &authcommon.Role{ + Metadata: map[string]string{}, + Users: make([]authcommon.Principal, 0, 4), + UserGroups: make([]authcommon.Principal, 0, 4), + } + + if err := row.Scan(&ret.ID, &ret.Name, &ret.Owner, &ret.Source, &ret.Type, &ret.Comment, + &flag, &metadata, &ctime, &mtime); err != nil { + log.Error("[store][role] fetch one record role info", zap.Error(err)) + return nil, store.Error(err) + } + + ret.CreateTime = time.Unix(ctime, 0) + ret.ModifyTime = time.Unix(mtime, 0) + ret.Valid = flag == 0 + _ = json.Unmarshal([]byte(metadata), &ret.Metadata) + + if err := s.fetchRolePrincipals(tx, ret); err != nil { + return nil, store.Error(err) + } + return ret, nil +} + // GetRole get more role for cache update func (s *roleStore) GetMoreRoles(firstUpdate bool, mtime time.Time) ([]*authcommon.Role, error) { tx, err := s.slave.Begin() @@ -201,8 +241,8 @@ func (s *roleStore) GetMoreRoles(firstUpdate bool, mtime time.Time) ([]*authcomm ) ret := &authcommon.Role{ Metadata: map[string]string{}, - Users: make([]*authcommon.User, 0, 4), - UserGroups: make([]*authcommon.UserGroup, 0, 4), + Users: make([]authcommon.Principal, 0, 4), + UserGroups: make([]authcommon.Principal, 0, 4), } if err := rows.Scan(&ret.ID, &ret.Name, &ret.Owner, &ret.Source, &ret.Type, &ret.Comment, @@ -227,7 +267,7 @@ func (s *roleStore) GetMoreRoles(firstUpdate bool, mtime time.Time) ([]*authcomm } func (s *roleStore) fetchRolePrincipals(tx *BaseTx, role *authcommon.Role) error { - rows, err := tx.Query("SELECT role_id, principal_id, principal_role FROM auth_role_principal WHERE rold_id = ?", role.ID) + rows, err := tx.Query("SELECT role_id, principal_id, principal_role, extend_info FROM auth_role_principal WHERE rold_id = ?", role.ID) if err != nil { log.Error("[store][role] fetch role principals", zap.String("name", role.Name), zap.Error(err)) return store.Error(err) @@ -238,21 +278,28 @@ func (s *roleStore) fetchRolePrincipals(tx *BaseTx, role *authcommon.Role) error for rows.Next() { var ( - roleID, principalID string - principalRole int + roleID, principalID, extendStr string + principalRole int ) - if err := rows.Scan(&roleID, &principalID, &principalRole); err != nil { + if err := rows.Scan(&roleID, &principalID, &principalRole, &extendStr); err != nil { log.Error("[store][role] fetch one record role principal", zap.String("name", role.Name), zap.Error(err)) return store.Error(err) } + extend := map[string]string{} + _ = json.Unmarshal([]byte(extendStr), &extend) + if principalRole == int(authcommon.PrincipalUser) { - role.Users = append(role.Users, &authcommon.User{ - ID: principalID, + role.Users = append(role.Users, authcommon.Principal{ + PrincipalID: principalID, + PrincipalType: authcommon.PrincipalUser, + Extend: extend, }) } else { - role.UserGroups = append(role.UserGroups, &authcommon.UserGroup{ - ID: principalID, + role.UserGroups = append(role.UserGroups, authcommon.Principal{ + PrincipalID: principalID, + PrincipalType: authcommon.PrincipalGroup, + Extend: extend, }) } } diff --git a/store/mysql/scripts/delta/v1_18_1-v1_18_2.sql b/store/mysql/scripts/delta/v1_18_1-v1_19_0.sql similarity index 73% rename from store/mysql/scripts/delta/v1_18_1-v1_18_2.sql rename to store/mysql/scripts/delta/v1_18_1-v1_19_0.sql index 49041073d..04548adc0 100644 --- a/store/mysql/scripts/delta/v1_18_1-v1_18_2.sql +++ b/store/mysql/scripts/delta/v1_18_1-v1_19_0.sql @@ -1,3 +1,19 @@ +/* + * Tencent is pleased to support the open source community by making Polaris available. + * + * Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. + * + * Licensed under the BSD 3-Clause License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://opensource.org/licenses/BSD-3-Clause + * + * Unless required by applicable law or agreed to in writing, software distributed + * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR + * CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ /* 角色数据 */ CREATE TABLE `auth_role` ( diff --git a/store/mysql/scripts/polaris_server.sql b/store/mysql/scripts/polaris_server.sql index da5351911..4d4afb714 100644 --- a/store/mysql/scripts/polaris_server.sql +++ b/store/mysql/scripts/polaris_server.sql @@ -649,6 +649,7 @@ CREATE TABLE `strategy_id` VARCHAR(128) NOT NULL COMMENT 'Strategy ID', `principal_id` VARCHAR(128) NOT NULL COMMENT 'Principal ID', `principal_role` INT NOT NULL COMMENT 'PRINCIPAL type, 1 is User, 2 is Group, 3 is Role', + `extend_info` TEXT COMMENT 'link principal extend info', PRIMARY KEY (`strategy_id`, `principal_id`, `principal_role`) ) ENGINE = InnoDB; @@ -688,6 +689,7 @@ CREATE TABLE `role_id` VARCHAR(128) NOT NULL COMMENT 'role id', `principal_id` VARCHAR(128) NOT NULL COMMENT 'principal id', `principal_role` INT NOT NULL COMMENT 'PRINCIPAL type, 1 is User, 2 is Group', + `extend_info` TEXT COMMENT 'link principal extend info', PRIMARY KEY (`role_id`, `principal_id`, `principal_role`) ) ENGINE = InnoDB; @@ -707,104 +709,6 @@ CREATE TABLE `auth_strategy_function` ( PRIMARY KEY (`strategy_id`, `function`) ) ENGINE = InnoDB; --- Create a default master account, password is Polarismesh @ 2021 -INSERT INTO - `user` ( - `id`, - `name`, - `password`, - `source`, - `token`, - `token_enable`, - `user_type`, - `comment`, - `mobile`, - `email`, - `owner` - ) -VALUES - ( - '65e4789a6d5b49669adf1e9e8387549c', - 'polaris', - '$2a$10$3izWuZtE5SBdAtSZci.gs.iZ2pAn9I8hEqYrC6gwJp1dyjqQnrrum', - 'Polaris', - 'nu/0WRA4EqSR1FagrjRj0fZwPXuGlMpX+zCuWu4uMqy8xr1vRjisSbA25aAC3mtU8MeeRsKhQiDAynUR09I=', - 1, - 20, - 'default polaris admin account', - '12345678910', - '12345678910', - '' - ); - --- Permissions policy inserted into Polaris-Admin -INSERT INTO - `auth_strategy` ( - `id`, - `name`, - `action`, - `owner`, - `comment`, - `default`, - `revision`, - `flag`, - `ctime`, - `mtime` - ) -VALUES - ( - 'fbca9bfa04ae4ead86e1ecf5811e32a9', - '(用户) polaris的默认策略', - 'READ_WRITE', - '65e4789a6d5b49669adf1e9e8387549c', - 'default admin', - 1, - 'fbca9bfa04ae4ead86e1ecf5811e32a9', - 0, - SYSDATE (), - SYSDATE () - ); - --- Sport rules inserted into Polaris-Admin to access -INSERT INTO - `auth_strategy_resource` ( - `strategy_id`, - `res_type`, - `res_id`, - `ctime`, - `mtime` - ) -VALUES - ( - 'fbca9bfa04ae4ead86e1ecf5811e32a9', - 0, - '*', - SYSDATE (), - SYSDATE () - ), - ( - 'fbca9bfa04ae4ead86e1ecf5811e32a9', - 1, - '*', - SYSDATE (), - SYSDATE () - ), - ( - 'fbca9bfa04ae4ead86e1ecf5811e32a9', - 2, - '*', - SYSDATE (), - SYSDATE () - ); - --- Insert permission policies and association relationships for Polaris-Admin accounts -INSERT INTO - auth_principal (`strategy_id`, `principal_id`, `principal_role`) VALUE ( - 'fbca9bfa04ae4ead86e1ecf5811e32a9', - '65e4789a6d5b49669adf1e9e8387549c', - 1 - ); - -- v1.8.0, support client info storage CREATE TABLE `client` ( @@ -883,9 +787,9 @@ VALUES }', 'json', 'Spring Cloud Gateway 染色规则', - NOW (), + NOW(), 'polaris', - NOW (), + NOW(), 'polaris' ); @@ -1062,3 +966,439 @@ CREATE TABLE PRIMARY KEY (`id`), UNIQUE KEY `name` (`group_name`, `name`) ) ENGINE = InnoDB; + + +/* 默认资源信息数据插入 */ + +-- Create a default master account, password is Polarismesh @ 2021 +INSERT INTO + `user` ( + `id`, + `name`, + `password`, + `source`, + `token`, + `token_enable`, + `user_type`, + `comment`, + `mobile`, + `email`, + `owner` + ) +VALUES + ( + '65e4789a6d5b49669adf1e9e8387549c', + 'polaris', + '$2a$10$3izWuZtE5SBdAtSZci.gs.iZ2pAn9I8hEqYrC6gwJp1dyjqQnrrum', + 'Polaris', + 'nu/0WRA4EqSR1FagrjRj0fZwPXuGlMpX+zCuWu4uMqy8xr1vRjisSbA25aAC3mtU8MeeRsKhQiDAynUR09I=', + 1, + 20, + 'default polaris admin account', + '12345678910', + '12345678910', + '' + ); + +-- Permissions policy inserted into Polaris-Admin +INSERT INTO + `auth_strategy` ( + `id`, + `name`, + `action`, + `owner`, + `comment`, + `default`, + `revision`, + `flag`, + `ctime`, + `mtime` + ) +VALUES + ( + 'fbca9bfa04ae4ead86e1ecf5811e32a9', + '(用户) polaris的默认策略', + 'READ_WRITE', + '65e4789a6d5b49669adf1e9e8387549c', + 'default admin', + 1, + 'fbca9bfa04ae4ead86e1ecf5811e32a9', + 0, + sysdate(), + sysdate() + ); + +-- Sport rules inserted into Polaris-Admin to access +INSERT INTO + `auth_strategy_resource` ( + `strategy_id`, + `res_type`, + `res_id`, + `ctime`, + `mtime` + ) +VALUES + ( + 'fbca9bfa04ae4ead86e1ecf5811e32a9', + 0, + '*', + sysdate(), + sysdate() + ), + ( + 'fbca9bfa04ae4ead86e1ecf5811e32a9', + 1, + '*', + sysdate(), + sysdate() + ), + ( + 'fbca9bfa04ae4ead86e1ecf5811e32a9', + 2, + '*', + sysdate(), + sysdate() + ), + ( + 'fbca9bfa04ae4ead86e1ecf5811e32a9', + 3, + '*', + sysdate(), + sysdate() + ), + ( + 'fbca9bfa04ae4ead86e1ecf5811e32a9', + 4, + '*', + sysdate(), + sysdate() + ), + ( + 'fbca9bfa04ae4ead86e1ecf5811e32a9', + 5, + '*', + sysdate(), + sysdate() + ), + ( + 'fbca9bfa04ae4ead86e1ecf5811e32a9', + 6, + '*', + sysdate(), + sysdate() + ), + ( + 'fbca9bfa04ae4ead86e1ecf5811e32a9', + 7, + '*', + sysdate(), + sysdate() + ), + ( + 'fbca9bfa04ae4ead86e1ecf5811e32a9', + 20, + '*', + sysdate(), + sysdate() + ), + ( + 'fbca9bfa04ae4ead86e1ecf5811e32a9', + 21, + '*', + sysdate(), + sysdate() + ), + ( + 'fbca9bfa04ae4ead86e1ecf5811e32a9', + 22, + '*', + sysdate(), + sysdate() + ), + ( + 'fbca9bfa04ae4ead86e1ecf5811e32a9', + 23, + '*', + sysdate(), + sysdate() + ); + +-- Insert permission policies and association relationships for Polaris-Admin accounts +INSERT INTO + auth_principal (`strategy_id`, `principal_id`, `principal_role`) VALUES ( + 'fbca9bfa04ae4ead86e1ecf5811e32a9', + '65e4789a6d5b49669adf1e9e8387549c', + 1 + ); + +INSERT INTO + auth_strategy_function (`strategy_id`, `function`) VALUES ( + 'fbca9bfa04ae4ead86e1ecf5811e32a9', + '*' + ); + +/* 默认的全局只读策略 */ +INSERT INTO + `auth_strategy` ( + `id`, + `name`, + `action`, + `owner`, + `comment`, + `default`, + `revision`, + `flag`, + `ctime`, + `mtime` + ) +VALUES + ( + 'bfa04ae1e32a94fbca9ead86e1ecf581', + '全局只读策略', + 'ALLOW', + '65e4789a6d5b49669adf1e9e8387549c', + 'global resources read onyly', + 1, + 'fbca9bfa04ae4ead86e1ecf5811e32a9', + 0, + sysdate(), + sysdate() + ); + +INSERT INTO + `auth_strategy_resource` ( + `strategy_id`, + `res_type`, + `res_id`, + `ctime`, + `mtime` + ) +VALUES + ( + 'bfa04ae1e32a94fbca9ead86e1ecf581', + 0, + '*', + sysdate(), + sysdate() + ), + ( + 'bfa04ae1e32a94fbca9ead86e1ecf581', + 1, + '*', + sysdate(), + sysdate() + ), + ( + 'bfa04ae1e32a94fbca9ead86e1ecf581', + 2, + '*', + sysdate(), + sysdate() + ), + ( + 'bfa04ae1e32a94fbca9ead86e1ecf581', + 3, + '*', + sysdate(), + sysdate() + ), + ( + 'bfa04ae1e32a94fbca9ead86e1ecf581', + 4, + '*', + sysdate(), + sysdate() + ), + ( + 'bfa04ae1e32a94fbca9ead86e1ecf581', + 5, + '*', + sysdate(), + sysdate() + ), + ( + 'bfa04ae1e32a94fbca9ead86e1ecf581', + 6, + '*', + sysdate(), + sysdate() + ), + ( + 'bfa04ae1e32a94fbca9ead86e1ecf581', + 7, + '*', + sysdate(), + sysdate() + ), + ( + 'bfa04ae1e32a94fbca9ead86e1ecf581', + 20, + '*', + sysdate(), + sysdate() + ), + ( + 'bfa04ae1e32a94fbca9ead86e1ecf581', + 21, + '*', + sysdate(), + sysdate() + ), + ( + 'bfa04ae1e32a94fbca9ead86e1ecf581', + 22, + '*', + sysdate(), + sysdate() + ), + ( + 'bfa04ae1e32a94fbca9ead86e1ecf581', + 23, + '*', + sysdate(), + sysdate() + ); + +INSERT INTO + auth_strategy_function (`strategy_id`, `function`) VALUES ( + 'fbca9bfa04ae4ead86e1ecf5811e32a9', + 'Describe*' + ), + ( + 'fbca9bfa04ae4ead86e1ecf5811e32a9', + 'List*' + ), + ( + 'fbca9bfa04ae4ead86e1ecf5811e32a9', + 'Get*' + ); + + +/* 默认的全局读写策略 */ +INSERT INTO + `auth_strategy` ( + `id`, + `name`, + `action`, + `owner`, + `comment`, + `default`, + `revision`, + `flag`, + `ctime`, + `mtime` + ) +VALUES + ( + 'e3d86e1ecf5812bfa04ae1a94fbca9ea', + '全局读写策略', + 'ALLOW', + '65e4789a6d5b49669adf1e9e8387549c', + 'global resources read and write', + 1, + 'fbca9bfa04ae4ead86e1ecf5811e32a9', + 0, + sysdate(), + sysdate() + ); + +INSERT INTO + `auth_strategy_resource` ( + `strategy_id`, + `res_type`, + `res_id`, + `ctime`, + `mtime` + ) +VALUES + ( + 'e3d86e1ecf5812bfa04ae1a94fbca9ea', + 0, + '*', + sysdate(), + sysdate() + ), + ( + 'e3d86e1ecf5812bfa04ae1a94fbca9ea', + 1, + '*', + sysdate(), + sysdate() + ), + ( + 'e3d86e1ecf5812bfa04ae1a94fbca9ea', + 2, + '*', + sysdate(), + sysdate() + ), + ( + 'e3d86e1ecf5812bfa04ae1a94fbca9ea', + 3, + '*', + sysdate(), + sysdate() + ), + ( + 'e3d86e1ecf5812bfa04ae1a94fbca9ea', + 4, + '*', + sysdate(), + sysdate() + ), + ( + 'e3d86e1ecf5812bfa04ae1a94fbca9ea', + 5, + '*', + sysdate(), + sysdate() + ), + ( + 'e3d86e1ecf5812bfa04ae1a94fbca9ea', + 6, + '*', + sysdate(), + sysdate() + ), + ( + 'e3d86e1ecf5812bfa04ae1a94fbca9ea', + 7, + '*', + sysdate(), + sysdate() + ), + ( + 'e3d86e1ecf5812bfa04ae1a94fbca9ea', + 20, + '*', + sysdate(), + sysdate() + ), + ( + 'e3d86e1ecf5812bfa04ae1a94fbca9ea', + 21, + '*', + sysdate(), + sysdate() + ), + ( + 'e3d86e1ecf5812bfa04ae1a94fbca9ea', + 22, + '*', + sysdate(), + sysdate() + ), + ( + 'e3d86e1ecf5812bfa04ae1a94fbca9ea', + 23, + '*', + sysdate(), + sysdate() + ); + +INSERT INTO + auth_strategy_function (`strategy_id`, `function`) VALUES ( + 'e3d86e1ecf5812bfa04ae1a94fbca9ea', + '*' + ); + diff --git a/store/mysql/strategy.go b/store/mysql/strategy.go index cd95c2fce..23c8e096d 100644 --- a/store/mysql/strategy.go +++ b/store/mysql/strategy.go @@ -19,6 +19,7 @@ package sqldb import ( "database/sql" + "encoding/json" "fmt" "strings" "time" @@ -75,23 +76,30 @@ func (s *strategyStore) AddStrategy(tx store.Tx, strategy *authcommon.StrategyDe isDefault = 1 } - if err := s.addStrategyPrincipals(dbTx, strategy.ID, strategy.Principals); err != nil { + if err := s.addPolicyPrincipals(dbTx, strategy.ID, strategy.Principals); err != nil { log.Error("[Store][Strategy] add auth_strategy principals", zap.Error(err)) return err } - - if err := s.addStrategyResources(dbTx, strategy.ID, strategy.Resources); err != nil { + if err := s.addPolicyResources(dbTx, strategy.ID, strategy.Resources); err != nil { log.Error("[Store][Strategy] add auth_strategy resources", zap.Error(err)) return err } + if err := s.savePolicyFunctions(dbTx, strategy.ID, strategy.CalleeMethods); err != nil { + log.Error("[Store][Strategy] save auth_strategy functions", zap.Error(err)) + return err + } + if err := s.savePolicyConditions(dbTx, strategy.ID, strategy.Conditions); err != nil { + log.Error("[Store][Strategy] save auth_strategy conditions", zap.Error(err)) + return err + } // 保存策略主信息 saveMainSql := "INSERT INTO auth_strategy(`id`, `name`, `action`, `owner`, `comment`, `flag`, " + - " `default`, `revision`) VALUES (?,?,?,?,?,?,?,?)" + " `default`, `revision`, `source`, `metadata`) VALUES (?,?,?,?,?,?,?,?,?,?)" if _, err := dbTx.Exec(saveMainSql, []interface{}{ strategy.ID, strategy.Name, strategy.Action, strategy.Owner, strategy.Comment, - 0, isDefault, strategy.Revision}..., + 0, isDefault, strategy.Revision, strategy.Source, utils.MustJson(strategy.Metadata)}..., ); err != nil { log.Error("[Store][Strategy] add auth_strategy main info", zap.Error(err)) return err @@ -120,25 +128,34 @@ func (s *strategyStore) updateStrategy(strategy *authcommon.ModifyStrategyDetail defer func() { _ = tx.Rollback() }() // 调整 principal 信息 - if err := s.addStrategyPrincipals(tx, strategy.ID, strategy.AddPrincipals); err != nil { + if err = s.addPolicyPrincipals(tx, strategy.ID, strategy.AddPrincipals); err != nil { log.Errorf("[Store][Strategy] add strategy principal err: %s", err.Error()) return err } - if err := s.deleteStrategyPrincipals(tx, strategy.ID, strategy.RemovePrincipals); err != nil { + if err = s.deletePolicyPrincipals(tx, strategy.ID, strategy.RemovePrincipals); err != nil { log.Errorf("[Store][Strategy] remove strategy principal err: %s", err.Error()) return err } // 调整鉴权资源信息 - if err := s.addStrategyResources(tx, strategy.ID, strategy.AddResources); err != nil { + if err = s.addPolicyResources(tx, strategy.ID, strategy.AddResources); err != nil { log.Errorf("[Store][Strategy] add strategy resource err: %s", err.Error()) return err } - if err := s.deleteStrategyResources(tx, strategy.ID, strategy.RemoveResources); err != nil { + if err = s.deletePolicyResources(tx, strategy.ID, strategy.RemoveResources); err != nil { log.Errorf("[Store][Strategy] remove strategy resource err: %s", err.Error()) return err } + if err = s.savePolicyFunctions(tx, strategy.ID, strategy.CalleeMethods); err != nil { + log.Error("[Store][Strategy] save auth_strategy functions", zap.Error(err)) + return err + } + if err = s.savePolicyConditions(tx, strategy.ID, strategy.Conditions); err != nil { + log.Error("[Store][Strategy] save auth_strategy conditions", zap.Error(err)) + return err + } + // 保存策略主信息 saveMainSql := "UPDATE auth_strategy SET action = ?, comment = ?, mtime = sysdate() WHERE id = ?" if _, err = tx.Exec(saveMainSql, []interface{}{strategy.Action, strategy.Comment, strategy.ID}...); err != nil { @@ -146,7 +163,7 @@ func (s *strategyStore) updateStrategy(strategy *authcommon.ModifyStrategyDetail return err } - if err := tx.Commit(); err != nil { + if err = tx.Commit(); err != nil { log.Errorf("[Store][Strategy] update auth_strategy tx commit err: %s", err.Error()) return err } @@ -174,21 +191,19 @@ func (s *strategyStore) deleteStrategy(id string) error { defer func() { _ = tx.Rollback() }() - if _, err = tx.Exec("UPDATE auth_strategy SET flag = 1, mtime = sysdate() WHERE id = ?", []interface{}{ - id, - }...); err != nil { + if _, err = tx.Exec("UPDATE auth_strategy SET flag = 1, mtime = sysdate() WHERE id = ?", id); err != nil { return err } - - if _, err = tx.Exec("DELETE FROM auth_strategy_resource WHERE strategy_id = ?", []interface{}{ - id, - }...); err != nil { + if _, err = tx.Exec("DELETE FROM auth_strategy_resource WHERE strategy_id = ?", id); err != nil { return err } - - if _, err = tx.Exec("DELETE FROM auth_principal WHERE strategy_id = ?", []interface{}{ - id, - }...); err != nil { + if _, err = tx.Exec("DELETE FROM auth_principal WHERE strategy_id = ?", id); err != nil { + return err + } + if _, err = tx.Exec("DELETE FROM auth_strategy_function WHERE strategy_id = ?", id); err != nil { + return err + } + if _, err = tx.Exec("DELETE FROM auth_strategy_label WHERE strategy_id = ?", id); err != nil { return err } @@ -199,20 +214,77 @@ func (s *strategyStore) deleteStrategy(id string) error { return nil } -// addStrategyPrincipals -func (s *strategyStore) addStrategyPrincipals(tx *BaseTx, id string, principals []authcommon.Principal) error { +// savePolicyFunctions +func (s *strategyStore) savePolicyFunctions(tx *BaseTx, id string, functions []string) error { + if len(functions) == 0 { + return nil + } + + if _, err := tx.Exec("DELETE FROM auth_strategy_function WHERE strategy_id = ?", id); err != nil { + return err + } + + savePrincipalSql := "INSERT IGNORE INTO auth_strategy_function(`strategy_id`, `function`) VALUES " + values := make([]string, 0) + args := make([]interface{}, 0) + + for i := range functions { + values = append(values, "(?,?)") + args = append(args, id, functions[i]) + } + + savePrincipalSql += strings.Join(values, ",") + + log.Debug("[Store][Strategy] save policy functions", zap.String("sql", savePrincipalSql), + zap.Any("args", args)) + + _, err := tx.Exec(savePrincipalSql, args...) + return err +} + +// savePolicyConditions +func (s *strategyStore) savePolicyConditions(tx *BaseTx, id string, conditions []authcommon.Condition) error { + if len(conditions) == 0 { + return nil + } + + if _, err := tx.Exec("DELETE FROM auth_strategy_label WHERE strategy_id = ?", id); err != nil { + return err + } + + savePrincipalSql := "INSERT IGNORE INTO auth_strategy_label(`strategy_id`, `key`, `value`, `compare_type`) VALUES " + values := make([]string, 0) + args := make([]interface{}, 0) + + for i := range conditions { + item := conditions[i] + values = append(values, "(?,?,?,?)") + args = append(args, id, item.Key, item.Value, item.CompareFunc) + } + + savePrincipalSql += strings.Join(values, ",") + + log.Debug("[Store][Strategy] save policy conditions", zap.String("sql", savePrincipalSql), + zap.Any("args", args)) + + _, err := tx.Exec(savePrincipalSql, args...) + return err +} + +// addPolicyPrincipals +func (s *strategyStore) addPolicyPrincipals(tx *BaseTx, id string, principals []authcommon.Principal) error { if len(principals) == 0 { return nil } - savePrincipalSql := "INSERT IGNORE INTO auth_principal(strategy_id, principal_id, principal_role) VALUES " + savePrincipalSql := "INSERT IGNORE INTO auth_principal(strategy_id, principal_id, principal_role, extend_info) VALUES " values := make([]string, 0) args := make([]interface{}, 0) for i := range principals { principal := principals[i] - values = append(values, "(?,?,?)") - args = append(args, id, principal.PrincipalID, principal.PrincipalType) + values = append(values, "(?,?,?,?)") + args = append(args, id, principal.PrincipalID, principal.PrincipalType, utils.MustJson(principal.Extend)) } savePrincipalSql += strings.Join(values, ",") @@ -224,8 +296,8 @@ func (s *strategyStore) addStrategyPrincipals(tx *BaseTx, id string, principals return err } -// deleteStrategyPrincipals -func (s *strategyStore) deleteStrategyPrincipals(tx *BaseTx, id string, +// deletePolicyPrincipals +func (s *strategyStore) deletePolicyPrincipals(tx *BaseTx, id string, principals []authcommon.Principal) error { if len(principals) == 0 { return nil @@ -245,7 +317,8 @@ func (s *strategyStore) deleteStrategyPrincipals(tx *BaseTx, id string, return nil } -func (s *strategyStore) addStrategyResources(tx *BaseTx, id string, resources []authcommon.StrategyResource) error { +// addPolicyResources . +func (s *strategyStore) addPolicyResources(tx *BaseTx, id string, resources []authcommon.StrategyResource) error { if len(resources) == 0 { return nil } @@ -271,7 +344,8 @@ func (s *strategyStore) addStrategyResources(tx *BaseTx, id string, resources [] return err } -func (s *strategyStore) deleteStrategyResources(tx *BaseTx, id string, +// deletePolicyResources . +func (s *strategyStore) deletePolicyResources(tx *BaseTx, id string, resources []authcommon.StrategyResource) error { if len(resources) == 0 { @@ -280,11 +354,9 @@ func (s *strategyStore) deleteStrategyResources(tx *BaseTx, id string, for i := range resources { resource := resources[i] - saveResSql := "DELETE FROM auth_strategy_resource WHERE strategy_id = ? AND res_id = ? AND res_type = ?" if _, err := tx.Exec( - saveResSql, - []interface{}{resource.StrategyID, resource.ResID, resource.ResType}..., + saveResSql, []interface{}{resource.StrategyID, resource.ResID, resource.ResType}..., ); err != nil { return err } @@ -348,7 +420,7 @@ func (s *strategyStore) RemoveStrategyResources(resources []authcommon.StrategyR saveResSql = "DELETE FROM auth_strategy_resource WHERE res_id = ? AND res_type = ?" args = append(args, resource.ResID, resource.ResType) } - if _, err := tx.Exec(saveResSql, args...); err != nil { + if _, err = tx.Exec(saveResSql, args...); err != nil { return err } // 主要是为了能够触发 StrategyCache 的刷新逻辑 @@ -358,7 +430,7 @@ func (s *strategyStore) RemoveStrategyResources(resources []authcommon.StrategyR } } - if err := tx.Commit(); err != nil { + if err = tx.Commit(); err != nil { log.Errorf("[Store][Strategy] add auth_strategy tx commit err: %s", err.Error()) return err } @@ -445,168 +517,6 @@ func (s *strategyStore) getStrategyDetail(row *sql.Row) (*authcommon.StrategyDet return ret, nil } -// GetStrategies 获取策略列表 -func (s *strategyStore) GetStrategies(filters map[string]string, offset uint32, limit uint32) (uint32, - []*authcommon.StrategyDetail, error) { - showDetail := filters["show_detail"] - delete(filters, "show_detail") - - filters["ag.flag"] = "0" - - return s.listStrategies(filters, offset, limit, showDetail == "true") -} - -// listStrategies -func (s *strategyStore) listStrategies(filters map[string]string, offset uint32, limit uint32, - showDetail bool) (uint32, []*authcommon.StrategyDetail, error) { - - querySql := - `SELECT - ag.id, - ag.name, - ag.action, - ag.owner, - ag.comment, - ag.default, - ag.revision, - ag.flag, - UNIX_TIMESTAMP(ag.ctime), - UNIX_TIMESTAMP(ag.mtime) - FROM - ( - auth_strategy ag - LEFT JOIN auth_strategy_resource ar ON ag.id = ar.strategy_id - ) - LEFT JOIN auth_principal ap ON ag.id = ap.strategy_id ` - countSql := ` - SELECT COUNT(DISTINCT ag.id) - FROM - ( - auth_strategy ag - LEFT JOIN auth_strategy_resource ar ON ag.id = ar.strategy_id - ) - LEFT JOIN auth_principal ap ON ag.id = ap.strategy_id - ` - - return s.queryStrategies(s.master.Query, filters, RuleFilters, querySql, countSql, - offset, limit, showDetail) -} - -// queryStrategies 通用的查询策略列表 -func (s *strategyStore) queryStrategies( - handler QueryHandler, - filters map[string]string, mapping map[string]string, - querySqlPrefix string, countSqlPrefix string, - offset uint32, limit uint32, showDetail bool) (uint32, []*authcommon.StrategyDetail, error) { - querySql := querySqlPrefix - countSql := countSqlPrefix - - args := make([]interface{}, 0) - if len(filters) != 0 { - querySql += " WHERE " - countSql += " WHERE " - firstIndex := true - for k, v := range filters { - needLike := false - if !firstIndex { - querySql += " AND " - countSql += " AND " - } - firstIndex = false - - if val, ok := mapping[k]; ok { - if _, exist := RuleNeedLikeFilters[k]; exist { - needLike = true - } - k = val - } - - if needLike { - if utils.IsPrefixWildName(v) { - v = v[:len(v)-1] - } - querySql += (" " + k + " like ? ") - countSql += (" " + k + " like ? ") - args = append(args, "%"+v+"%") - } else if k == "ag.owner" { - querySql += " (ag.owner = ? OR (ap.principal_id = ? AND ap.principal_role = 1 )) " - countSql += " (ag.owner = ? OR (ap.principal_id = ? AND ap.principal_role = 1 )) " - args = append(args, v, v) - } else { - querySql += (" " + k + " = ? ") - countSql += (" " + k + " = ? ") - args = append(args, v) - } - } - } - - count, err := queryEntryCount(s.master, countSql, args) - if err != nil { - return 0, nil, store.Error(err) - } - - querySql += " GROUP BY ag.id ORDER BY ag.mtime LIMIT ?, ? " - args = append(args, offset, limit) - - ret, err := s.collectStrategies(s.master.Query, querySql, args, showDetail) - if err != nil { - return 0, nil, err - } - - return count, ret, nil -} - -// collectStrategies 执行真正的 sql 并从 rows 中获取策略列表 -func (s *strategyStore) collectStrategies(handler QueryHandler, querySql string, - args []interface{}, showDetail bool) ([]*authcommon.StrategyDetail, error) { - log.Debug("[Store][Strategy] get simple strategies", zap.String("query sql", querySql), - zap.Any("args", args)) - - rows, err := handler(querySql, args...) - if err != nil { - log.Error("[Store][Strategy] get simple strategies", zap.String("query sql", querySql), - zap.Any("args", args)) - return nil, store.Error(err) - } - defer func() { - _ = rows.Close() - }() - - idMap := make(map[string]struct{}) - - ret := make([]*authcommon.StrategyDetail, 0, 16) - for rows.Next() { - detail, err := fetchRown2StrategyDetail(rows) - if err != nil { - return nil, store.Error(err) - } - - // 为了避免数据重复被加入到 slice 中,做一个 map 去重 - if _, ok := idMap[detail.ID]; ok { - continue - } - idMap[detail.ID] = struct{}{} - - if showDetail { - resArr, err := s.getStrategyResources(s.slave.Query, detail.ID) - if err != nil { - return nil, store.Error(err) - } - principals, err := s.getStrategyPrincipals(s.slave.Query, detail.ID) - if err != nil { - return nil, store.Error(err) - } - - detail.Resources = resArr - detail.Principals = principals - } - - ret = append(ret, detail) - } - - return ret, nil -} - func (s *strategyStore) GetMoreStrategies(mtime time.Time, firstUpdate bool) ([]*authcommon.StrategyDetail, error) { tx, err := s.slave.Begin() if err != nil { @@ -646,9 +556,19 @@ func (s *strategyStore) GetMoreStrategies(mtime time.Time, firstUpdate bool) ([] if err != nil { return nil, store.Error(err) } + conditions, err := s.getStrategyConditions(s.slave.Query, detail.ID) + if err != nil { + return nil, store.Error(err) + } + functions, err := s.getStrategyFunctions(s.slave.Query, detail.ID) + if err != nil { + return nil, store.Error(err) + } detail.Resources = resArr detail.Principals = principals + detail.CalleeMethods = functions + detail.Conditions = conditions ret = append(ret, detail) } @@ -693,7 +613,7 @@ func (s *strategyStore) GetStrategyResources(principalId string, func (s *strategyStore) getStrategyPrincipals(queryHander QueryHandler, id string) ([]authcommon.Principal, error) { - rows, err := queryHander("SELECT principal_id, principal_role FROM auth_principal WHERE strategy_id = ?", id) + rows, err := queryHander("SELECT principal_id, principal_role, extend_info FROM auth_principal WHERE strategy_id = ?", id) if err != nil { switch err { case sql.ErrNoRows: @@ -709,15 +629,72 @@ func (s *strategyStore) getStrategyPrincipals(queryHander QueryHandler, id strin for rows.Next() { res := new(authcommon.Principal) - if err := rows.Scan(&res.PrincipalID, &res.PrincipalType); err != nil { + var extend string + if err := rows.Scan(&res.PrincipalID, &res.PrincipalType, &extend); err != nil { return nil, store.Error(err) } + res.Extend = map[string]string{} + _ = json.Unmarshal([]byte(extend), &res.Extend) principals = append(principals, *res) } return principals, nil } +func (s *strategyStore) getStrategyConditions(queryHander QueryHandler, id string) ([]authcommon.Condition, error) { + + rows, err := queryHander("SELECT `key`, `value`, `compare_type` FROM auth_strategy_label WHERE strategy_id = ?", id) + if err != nil { + switch err { + case sql.ErrNoRows: + log.Info("[Store][Strategy] not found link condition", zap.String("strategy-id", id)) + return nil, nil + default: + return nil, store.Error(err) + } + } + defer rows.Close() + + conditions := make([]authcommon.Condition, 0) + + for rows.Next() { + res := new(authcommon.Condition) + if err := rows.Scan(&res.Key, &res.Value, &res.CompareFunc); err != nil { + return nil, store.Error(err) + } + conditions = append(conditions, *res) + } + + return conditions, nil +} + +func (s *strategyStore) getStrategyFunctions(queryHander QueryHandler, id string) ([]string, error) { + + rows, err := queryHander("SELECT `function` FROM auth_strategy_label WHERE strategy_id = ?", id) + if err != nil { + switch err { + case sql.ErrNoRows: + log.Info("[Store][Strategy] not found link functions", zap.String("strategy-id", id)) + return nil, nil + default: + return nil, store.Error(err) + } + } + defer rows.Close() + + functions := make([]string, 0) + + for rows.Next() { + var item string + if err := rows.Scan(&item); err != nil { + return nil, store.Error(err) + } + functions = append(functions, item) + } + + return functions, nil +} + func (s *strategyStore) getStrategyResources(queryHander QueryHandler, id string) ([]authcommon.StrategyResource, error) { querySql := "SELECT res_id, res_type FROM auth_strategy_resource WHERE strategy_id = ?" rows, err := queryHander(querySql, id) diff --git a/store/mysql/tool.go b/store/mysql/tool.go index 62a6fd52a..55c3c4897 100644 --- a/store/mysql/tool.go +++ b/store/mysql/tool.go @@ -42,7 +42,7 @@ func (t *toolStore) GetUnixSecond(maxWait time.Duration) (int64, error) { defer rows.Close() timePass := time.Since(startTime) if maxWait != 0 && timePass > maxWait { - log.Infof("[Store][database] query now spend %s, exceed %s, skip", timePass, maxWait) + log.Warnf("[Store][database] query now spend %s, exceed %s, skip", timePass, maxWait) return 0, nil } var value int64 diff --git a/store/mysql/user.go b/store/mysql/user.go index 2a94f51ef..cdc11f611 100644 --- a/store/mysql/user.go +++ b/store/mysql/user.go @@ -69,19 +69,11 @@ func (u *userStore) AddUser(tx store.Tx, user *authcommon.User) error { } func (u *userStore) addUser(tx *BaseTx, user *authcommon.User) error { - - tx, err := u.master.Begin() - if err != nil { - return err - } - - defer func() { _ = tx.Rollback() }() - addSql := "INSERT INTO user(`id`, `name`, `password`, `owner`, `source`, `token`, " + " `comment`, `flag`, `user_type`, " + " `ctime`, `mtime`, `mobile`, `email`) VALUES (?,?,?,?,?,?,?,?,?,sysdate(),sysdate(),?,?)" - _, err = tx.Exec(addSql, []interface{}{ + _, err := tx.Exec(addSql, []interface{}{ user.ID, user.Name, user.Password, diff --git a/test/data/bolt-data.yaml b/test/data/bolt-data.yaml new file mode 100644 index 000000000..2c7827204 --- /dev/null +++ b/test/data/bolt-data.yaml @@ -0,0 +1,174 @@ +# Tencent is pleased to support the open source community by making Polaris available. +# +# Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +users: + - name: polaris + token: nu/0WRA4EqSR1FagrjRj0fZwPXuGlMpX+zCuWu4uMqy8xr1vRjisSbA25aAC3mtU8MeeRsKhQiDAynUR09I= + password: $2a$10$3izWuZtE5SBdAtSZci.gs.iZ2pAn9I8hEqYrC6gwJp1dyjqQnrrum + id: 65e4789a6d5b49669adf1e9e8387549c + tokenenable: true + type: 20 + valid: true +policies: + - id: fbca9bfa04ae4ead86e1ecf5811e32a9 + name: (用户) polaris的默认策略 + action: READ_WRITE + comment: default admin + default: true + owner: 65e4789a6d5b49669adf1e9e8387549c + calleemethods: ["*"] + resources: + - strategyid: fbca9bfa04ae4ead86e1ecf5811e32a9 + restype: 6 + resid: "*" + - strategyid: fbca9bfa04ae4ead86e1ecf5811e32a9 + restype: 7 + resid: "*" + - strategyid: fbca9bfa04ae4ead86e1ecf5811e32a9 + restype: 20 + resid: "*" + - strategyid: fbca9bfa04ae4ead86e1ecf5811e32a9 + restype: 0 + resid: "*" + - strategyid: fbca9bfa04ae4ead86e1ecf5811e32a9 + restype: 3 + resid: "*" + - strategyid: fbca9bfa04ae4ead86e1ecf5811e32a9 + restype: 4 + resid: "*" + - strategyid: fbca9bfa04ae4ead86e1ecf5811e32a9 + restype: 5 + resid: "*" + - strategyid: fbca9bfa04ae4ead86e1ecf5811e32a9 + restype: 21 + resid: "*" + - strategyid: fbca9bfa04ae4ead86e1ecf5811e32a9 + restype: 22 + resid: "*" + - strategyid: fbca9bfa04ae4ead86e1ecf5811e32a9 + restype: 23 + resid: "*" + - strategyid: fbca9bfa04ae4ead86e1ecf5811e32a9 + restype: 1 + resid: "*" + - strategyid: fbca9bfa04ae4ead86e1ecf5811e32a9 + restype: 2 + resid: "*" + conditions: [] + principals: + - strategyid: fbca9bfa04ae4ead86e1ecf5811e32a9 + principalid: 65e4789a6d5b49669adf1e9e8387549c + principaltype: 1 + valid: true + revision: fbca9bfa04ae4ead86e1ecf5811e32a9 + metadata: {} + - id: bfa04ae1e32a94fbca9ead86e1ecf581 + name: 全局只读策略 + action: ALLOW + comment: global resources read onyly + default: false + owner: 65e4789a6d5b49669adf1e9e8387549c + calleemethods: ["Describe*", "List*", "Get*"] + resources: + - strategyid: bfa04ae1e32a94fbca9ead86e1ecf581 + restype: 6 + resid: "*" + - strategyid: bfa04ae1e32a94fbca9ead86e1ecf581 + restype: 7 + resid: "*" + - strategyid: bfa04ae1e32a94fbca9ead86e1ecf581 + restype: 20 + resid: "*" + - strategyid: bfa04ae1e32a94fbca9ead86e1ecf581 + restype: 0 + resid: "*" + - strategyid: bfa04ae1e32a94fbca9ead86e1ecf581 + restype: 3 + resid: "*" + - strategyid: bfa04ae1e32a94fbca9ead86e1ecf581 + restype: 4 + resid: "*" + - strategyid: bfa04ae1e32a94fbca9ead86e1ecf581 + restype: 5 + resid: "*" + - strategyid: bfa04ae1e32a94fbca9ead86e1ecf581 + restype: 21 + resid: "*" + - strategyid: bfa04ae1e32a94fbca9ead86e1ecf581 + restype: 22 + resid: "*" + - strategyid: bfa04ae1e32a94fbca9ead86e1ecf581 + restype: 23 + resid: "*" + - strategyid: bfa04ae1e32a94fbca9ead86e1ecf581 + restype: 1 + resid: "*" + - strategyid: bfa04ae1e32a94fbca9ead86e1ecf581 + restype: 2 + resid: "*" + conditions: [] + principals: [] + valid: true + revision: 2a04ae4ead86e1e9bfacf59fbca811e3 + metadata: {} + - id: e3d86e1ecf5812bfa04ae1a94fbca9ea + name: 全局读写策略 + action: ALLOW + comment: global resources read and write + default: false + owner: 65e4789a6d5b49669adf1e9e8387549c + calleemethods: ["*"] + resources: + - strategyid: e3d86e1ecf5812bfa04ae1a94fbca9ea + restype: 6 + resid: "*" + - strategyid: e3d86e1ecf5812bfa04ae1a94fbca9ea + restype: 7 + resid: "*" + - strategyid: e3d86e1ecf5812bfa04ae1a94fbca9ea + restype: 20 + resid: "*" + - strategyid: e3d86e1ecf5812bfa04ae1a94fbca9ea + restype: 0 + resid: "*" + - strategyid: e3d86e1ecf5812bfa04ae1a94fbca9ea + restype: 3 + resid: "*" + - strategyid: e3d86e1ecf5812bfa04ae1a94fbca9ea + restype: 4 + resid: "*" + - strategyid: e3d86e1ecf5812bfa04ae1a94fbca9ea + restype: 5 + resid: "*" + - strategyid: e3d86e1ecf5812bfa04ae1a94fbca9ea + restype: 21 + resid: "*" + - strategyid: e3d86e1ecf5812bfa04ae1a94fbca9ea + restype: 22 + resid: "*" + - strategyid: e3d86e1ecf5812bfa04ae1a94fbca9ea + restype: 23 + resid: "*" + - strategyid: e3d86e1ecf5812bfa04ae1a94fbca9ea + restype: 1 + resid: "*" + - strategyid: e3d86e1ecf5812bfa04ae1a94fbca9ea + restype: 2 + resid: "*" + conditions: [] + principals: [] + valid: true + revision: 4ead86e1e9bfac2a04aef59fbca811e3 + metadata: {} diff --git a/test/integrate/client_grpc_test.go b/test/integrate/client_grpc_test.go index fe40a5ad7..eaddaa625 100644 --- a/test/integrate/client_grpc_test.go +++ b/test/integrate/client_grpc_test.go @@ -203,12 +203,12 @@ func TestClientGRPC_DiscoverServices(t *testing.T) { }, }) if err != nil { - t.Fatalf("discover services fail") + t.Fatalf("discover services fail: %+v", err) } - assert.False(t, len(resp.Services) == 0, "discover services response not empty") - assert.Truef(t, len(newSvcs) == len(resp.Services), - "discover services size not equal, expect : %d, actual : %s", len(newSvcs), len(resp.Services)) + assert.False(t, len(resp.GetServices()) == 0, "discover services response not empty") + assert.Truef(t, len(newSvcs) == len(resp.GetServices()), + "discover services size not equal, expect : %d, actual : %s", len(newSvcs), len(resp.GetServices())) }) }) diff --git a/test/integrate/http/client.go b/test/integrate/http/client.go index db868b94d..cac98ae8b 100644 --- a/test/integrate/http/client.go +++ b/test/integrate/http/client.go @@ -25,6 +25,7 @@ import ( "net/http" "github.com/golang/protobuf/jsonpb" + "github.com/google/uuid" apiconfig "github.com/polarismesh/specification/source/go/api/v1/config_manage" apiservice "github.com/polarismesh/specification/source/go/api/v1/service_manage" ) @@ -47,6 +48,10 @@ type Client struct { // SendRequest 发送请求 HTTP Post/Put func (c *Client) SendRequest(method string, url string, body *bytes.Buffer) (*http.Response, error) { + return c.SendRequestWithRequestID(uuid.New().String(), method, url, body) +} + +func (c *Client) SendRequestWithRequestID(requestId, method string, url string, body *bytes.Buffer) (*http.Response, error) { var request *http.Request var err error @@ -61,7 +66,7 @@ func (c *Client) SendRequest(method string, url string, body *bytes.Buffer) (*ht } request.Header.Add("Content-Type", "application/json") - request.Header.Add("Request-Id", "test") + request.Header.Add("Request-Id", requestId) request.Header.Add("X-Polaris-Token", "nu/0WRA4EqSR1FagrjRj0fZwPXuGlMpX+zCuWu4uMqy8xr1vRjisSbA25aAC3mtU8MeeRsKhQiDAynUR09I=") response, err := c.Worker.Do(request) diff --git a/test/integrate/http/namespace.go b/test/integrate/http/namespace.go index 93fe85a80..6cb00fa57 100644 --- a/test/integrate/http/namespace.go +++ b/test/integrate/http/namespace.go @@ -64,7 +64,7 @@ func (c *Client) CreateNamespaces(namespaces []*apimodel.Namespace) (*apiservice return nil, err } - response, err := c.SendRequest("POST", url, body) + response, err := c.SendRequestWithRequestID("CreateNamespaces", "POST", url, body) if err != nil { fmt.Printf("%v\n", err) return nil, err @@ -179,7 +179,7 @@ func (c *Client) GetNamespaces(namespaces []*apimodel.Namespace) ([]*apimodel.Na } url = c.CompleteURL(url, params) - response, err := c.SendRequest("GET", url, nil) + response, err := c.SendRequestWithRequestID("GetNamespaces", "GET", url, nil) if err != nil { return nil, err } @@ -197,7 +197,7 @@ func (c *Client) GetNamespaces(namespaces []*apimodel.Namespace) ([]*apimodel.Na namespacesSize := len(namespaces) if ret.GetAmount() == nil || ret.GetAmount().GetValue() != uint32(namespacesSize) { - return nil, errors.New("invalid batch amount") + return nil, fmt.Errorf("invalid batch amount: %d %d", ret.GetAmount().GetValue(), namespacesSize) } if ret.GetSize() == nil || ret.GetSize().GetValue() != uint32(namespacesSize) { diff --git a/test/integrate/namespace_test.go b/test/integrate/namespace_test.go index 361a8bdce..60352e7d3 100644 --- a/test/integrate/namespace_test.go +++ b/test/integrate/namespace_test.go @@ -29,6 +29,7 @@ import ( v1 "github.com/polarismesh/polaris/common/api/v1" "github.com/polarismesh/polaris/common/model" + "github.com/polarismesh/polaris/common/utils" "github.com/polarismesh/polaris/test/integrate/http" "github.com/polarismesh/polaris/test/integrate/resource" ) @@ -56,7 +57,7 @@ func TestNamespace(t *testing.T) { // 查询命名空间 _, err = client.GetNamespaces(namespaces) if err != nil { - t.Fatalf("get namespaces fail: %s", err.Error()) + t.Fatalf("get namespaces: %#v fail: %s", utils.MustJson(namespaces), err.Error()) } t.Log("get namespaces success") diff --git a/test/suit/test_suit.go b/test/suit/test_suit.go index 00b8d3ff1..f51a98f95 100644 --- a/test/suit/test_suit.go +++ b/test/suit/test_suit.go @@ -44,7 +44,9 @@ import ( "github.com/polarismesh/polaris/common/metrics" "github.com/polarismesh/polaris/common/utils" "github.com/polarismesh/polaris/config" + "github.com/polarismesh/polaris/namespace" ns "github.com/polarismesh/polaris/namespace" + "github.com/polarismesh/polaris/namespace/interceptor" "github.com/polarismesh/polaris/plugin" "github.com/polarismesh/polaris/service" "github.com/polarismesh/polaris/service/batch" @@ -219,6 +221,9 @@ func (d *DiscoverTestSuit) loadConfig() error { fmt.Printf("[ERROR] %v\n", err) return err } + if os.Getenv("STORE_MODE") != "sqldb" { + d.cfg.Store.Option["loadFile"] = testdata.Path("bolt-data.yaml") + } d.cfg.Naming.Interceptors = service.GetChainOrder() d.cfg.Config.Interceptors = config.GetChainOrder() return err @@ -321,7 +326,7 @@ func (d *DiscoverTestSuit) initialize(opts ...options) error { } // 初始化命名空间模块 - namespaceSvr, err := ns.TestInitialize(ctx, &d.cfg.Namespace, d.Storage, cacheMgn, d.userMgn, d.strategyMgn) + namespaceSvr, err := TestNamespaceInitialize(ctx, &d.cfg.Namespace, d.Storage, cacheMgn, d.userMgn, d.strategyMgn) if err != nil { panic(err) } @@ -391,6 +396,19 @@ func (d *DiscoverTestSuit) initialize(opts ...options) error { return nil } +func TestNamespaceInitialize(ctx context.Context, nsOpt *namespace.Config, storage store.Store, cacheMgr *cache.CacheManager, + userMgn auth.UserServer, strategyMgn auth.StrategyServer) (namespace.NamespaceOperateServer, error) { + + ctx = context.WithValue(ctx, interceptor.ContextKeyUserSvr{}, userMgn) + ctx = context.WithValue(ctx, interceptor.ContextKeyPolicySvr{}, strategyMgn) + + _, proxySvr, err := namespace.InitServer(ctx, nsOpt, storage, cacheMgr) + if err != nil { + return nil, err + } + return proxySvr, nil +} + func (d *DiscoverTestSuit) Destroy() { d.cancel() if svr, ok := d.configOriginSvr.(*config.Server); ok { diff --git a/version b/version index 13e94ce5c..e9c1a1883 100644 --- a/version +++ b/version @@ -1 +1 @@ -v1.18.0 \ No newline at end of file +v1.19.0-alpha.0 \ No newline at end of file