diff --git a/auth/interceptor.go b/auth/interceptor.go new file mode 100644 index 000000000..a4d78d4d6 --- /dev/null +++ b/auth/interceptor.go @@ -0,0 +1,24 @@ +package auth + +import ( + "context" + + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +func BlanketAuthorization(ctx context.Context, req interface{}, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) ( + resp interface{}, err error) { + + identityContext := IdentityContextFromContext(ctx) + if identityContext.IsEmpty() { + return handler(ctx, req) + } + + if !identityContext.Scopes().Has(ScopeAll) { + return nil, status.Errorf(codes.Unauthenticated, "authenticated user doesn't have required scope") + } + + return handler(ctx, req) +} diff --git a/auth/interceptor_test.go b/auth/interceptor_test.go new file mode 100644 index 000000000..862f76a13 --- /dev/null +++ b/auth/interceptor_test.go @@ -0,0 +1,61 @@ +package auth + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "k8s.io/apimachinery/pkg/util/sets" +) + +func TestBlanketAuthorization(t *testing.T) { + t.Run("authenticated and authorized", func(t *testing.T) { + allScopes := sets.NewString(ScopeAll) + identityCtx := IdentityContext{ + audience: "aud", + userID: "uid", + appID: "appid", + scopes: &allScopes, + } + handlerCalled := false + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + handlerCalled = true + return nil, nil + } + ctx := context.WithValue(context.TODO(), ContextKeyIdentityContext, identityCtx) + _, err := BlanketAuthorization(ctx, nil, nil, handler) + assert.NoError(t, err) + assert.True(t, handlerCalled) + }) + t.Run("unauthenticated", func(t *testing.T) { + handlerCalled := false + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + handlerCalled = true + return nil, nil + } + ctx := context.TODO() + _, err := BlanketAuthorization(ctx, nil, nil, handler) + assert.NoError(t, err) + assert.True(t, handlerCalled) + }) + t.Run("authenticated and not authorized", func(t *testing.T) { + identityCtx := IdentityContext{ + audience: "aud", + userID: "uid", + appID: "appid", + } + handlerCalled := false + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + handlerCalled = true + return nil, nil + } + ctx := context.WithValue(context.TODO(), ContextKeyIdentityContext, identityCtx) + _, err := BlanketAuthorization(ctx, nil, nil, handler) + asStatus, ok := status.FromError(err) + assert.True(t, ok) + assert.Equal(t, asStatus.Code(), codes.Unauthenticated) + assert.False(t, handlerCalled) + }) +} diff --git a/pkg/server/service.go b/pkg/server/service.go index 77dc460cf..86c552541 100644 --- a/pkg/server/service.go +++ b/pkg/server/service.go @@ -32,12 +32,10 @@ import ( "github.com/grpc-ecosystem/grpc-gateway/runtime" "github.com/pkg/errors" "google.golang.org/grpc" - "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials" "google.golang.org/grpc/health" "google.golang.org/grpc/health/grpc_health_v1" "google.golang.org/grpc/reflection" - "google.golang.org/grpc/status" ) var defaultCorsHeaders = []string{"Content-Type"} @@ -55,34 +53,24 @@ func Serve(ctx context.Context, pluginRegistry *plugins.Registry, additionalHand return serveGatewayInsecure(ctx, pluginRegistry, serverConfig, authConfig.GetConfig(), storage.GetConfig(), additionalHandlers, adminScope) } -func blanketAuthorization(ctx context.Context, req interface{}, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) ( - resp interface{}, err error) { - - identityContext := auth.IdentityContextFromContext(ctx) - if identityContext.IsEmpty() { - return handler(ctx, req) - } - - if !identityContext.Scopes().Has(auth.ScopeAll) { - return nil, status.Errorf(codes.Unauthenticated, "authenticated user doesn't have required scope") - } - - return handler(ctx, req) -} - // Creates a new gRPC Server with all the configuration func newGRPCServer(ctx context.Context, pluginRegistry *plugins.Registry, cfg *config.ServerConfig, storageCfg *storage.Config, authCtx interfaces.AuthenticationContext, scope promutils.Scope, opts ...grpc.ServerOption) (*grpc.Server, error) { + + logger.Infof(ctx, "Registering default middleware with blanket auth validation") + pluginRegistry.RegisterDefault(plugins.PluginIDUnaryServiceMiddleware, grpcmiddleware.ChainUnaryServer(auth.BlanketAuthorization)) + // Not yet implemented for streaming var chainedUnaryInterceptors grpc.UnaryServerInterceptor if cfg.Security.UseAuth { logger.Infof(ctx, "Creating gRPC server with authentication") + middlewareInterceptors := plugins.Get[grpc.UnaryServerInterceptor](pluginRegistry, plugins.PluginIDUnaryServiceMiddleware) chainedUnaryInterceptors = grpcmiddleware.ChainUnaryServer(grpcprometheus.UnaryServerInterceptor, auth.GetAuthenticationCustomMetadataInterceptor(authCtx), grpcauth.UnaryServerInterceptor(auth.GetAuthenticationInterceptor(authCtx)), auth.AuthenticationLoggingInterceptor, - blanketAuthorization, + middlewareInterceptors, ) } else { logger.Infof(ctx, "Creating gRPC server without authentication") diff --git a/plugins/registry.go b/plugins/registry.go index 1ab0def69..3c2186326 100644 --- a/plugins/registry.go +++ b/plugins/registry.go @@ -9,8 +9,9 @@ import ( type PluginID = string const ( - PluginIDWorkflowExecutor PluginID = "WorkflowExecutor" - PluginIDDataProxy PluginID = "DataProxy" + PluginIDWorkflowExecutor PluginID = "WorkflowExecutor" + PluginIDDataProxy PluginID = "DataProxy" + PluginIDUnaryServiceMiddleware PluginID = "UnaryServiceMiddleware" ) type AtomicRegistry struct { diff --git a/tests/bootstrap.go b/tests/bootstrap.go index 54160b637..e9804aa70 100644 --- a/tests/bootstrap.go +++ b/tests/bootstrap.go @@ -6,11 +6,10 @@ package tests import ( "context" "fmt" + "github.com/flyteorg/flytestdlib/database" "github.com/flyteorg/flyteadmin/pkg/repositories" - runtimeInterfaces "github.com/flyteorg/flyteadmin/pkg/runtime/interfaces" - "gorm.io/gorm" "github.com/flyteorg/flytestdlib/logger" @@ -34,9 +33,9 @@ func getDbConfig() *database.DbConfig { } } -func getLocalDbConfig() *runtimeInterfaces.DbConfig { - return &runtimeInterfaces.DbConfig{ - PostgresConfig: &runtimeInterfaces.PostgresConfig{ +func getLocalDbConfig() *database.DbConfig { + return &database.DbConfig{ + Postgres: database.PostgresConfig{ Host: "localhost", Port: 5432, DbName: "flyteadmin",