Skip to content

Commit

Permalink
Plugins for flyteadmin server middleware (try2) (flyteorg#428)
Browse files Browse the repository at this point in the history
  • Loading branch information
katrogan authored May 20, 2022
1 parent ede2297 commit 77b83f6
Show file tree
Hide file tree
Showing 5 changed files with 98 additions and 25 deletions.
24 changes: 24 additions & 0 deletions auth/interceptor.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
package auth

import (
"context"

"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)

func BlanketAuthorization(ctx context.Context, req interface{}, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (
resp interface{}, err error) {

identityContext := IdentityContextFromContext(ctx)
if identityContext.IsEmpty() {
return handler(ctx, req)
}

if !identityContext.Scopes().Has(ScopeAll) {
return nil, status.Errorf(codes.Unauthenticated, "authenticated user doesn't have required scope")
}

return handler(ctx, req)
}
61 changes: 61 additions & 0 deletions auth/interceptor_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
package auth

import (
"context"
"testing"

"github.com/stretchr/testify/assert"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"k8s.io/apimachinery/pkg/util/sets"
)

func TestBlanketAuthorization(t *testing.T) {
t.Run("authenticated and authorized", func(t *testing.T) {
allScopes := sets.NewString(ScopeAll)
identityCtx := IdentityContext{
audience: "aud",
userID: "uid",
appID: "appid",
scopes: &allScopes,
}
handlerCalled := false
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
handlerCalled = true
return nil, nil
}
ctx := context.WithValue(context.TODO(), ContextKeyIdentityContext, identityCtx)
_, err := BlanketAuthorization(ctx, nil, nil, handler)
assert.NoError(t, err)
assert.True(t, handlerCalled)
})
t.Run("unauthenticated", func(t *testing.T) {
handlerCalled := false
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
handlerCalled = true
return nil, nil
}
ctx := context.TODO()
_, err := BlanketAuthorization(ctx, nil, nil, handler)
assert.NoError(t, err)
assert.True(t, handlerCalled)
})
t.Run("authenticated and not authorized", func(t *testing.T) {
identityCtx := IdentityContext{
audience: "aud",
userID: "uid",
appID: "appid",
}
handlerCalled := false
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
handlerCalled = true
return nil, nil
}
ctx := context.WithValue(context.TODO(), ContextKeyIdentityContext, identityCtx)
_, err := BlanketAuthorization(ctx, nil, nil, handler)
asStatus, ok := status.FromError(err)
assert.True(t, ok)
assert.Equal(t, asStatus.Code(), codes.Unauthenticated)
assert.False(t, handlerCalled)
})
}
24 changes: 6 additions & 18 deletions pkg/server/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,10 @@ import (
"github.com/grpc-ecosystem/grpc-gateway/runtime"
"github.com/pkg/errors"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/health"
"google.golang.org/grpc/health/grpc_health_v1"
"google.golang.org/grpc/reflection"
"google.golang.org/grpc/status"
)

var defaultCorsHeaders = []string{"Content-Type"}
Expand All @@ -55,34 +53,24 @@ func Serve(ctx context.Context, pluginRegistry *plugins.Registry, additionalHand
return serveGatewayInsecure(ctx, pluginRegistry, serverConfig, authConfig.GetConfig(), storage.GetConfig(), additionalHandlers, adminScope)
}

func blanketAuthorization(ctx context.Context, req interface{}, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (
resp interface{}, err error) {

identityContext := auth.IdentityContextFromContext(ctx)
if identityContext.IsEmpty() {
return handler(ctx, req)
}

if !identityContext.Scopes().Has(auth.ScopeAll) {
return nil, status.Errorf(codes.Unauthenticated, "authenticated user doesn't have required scope")
}

return handler(ctx, req)
}

// Creates a new gRPC Server with all the configuration
func newGRPCServer(ctx context.Context, pluginRegistry *plugins.Registry, cfg *config.ServerConfig,
storageCfg *storage.Config, authCtx interfaces.AuthenticationContext,
scope promutils.Scope, opts ...grpc.ServerOption) (*grpc.Server, error) {

logger.Infof(ctx, "Registering default middleware with blanket auth validation")
pluginRegistry.RegisterDefault(plugins.PluginIDUnaryServiceMiddleware, grpcmiddleware.ChainUnaryServer(auth.BlanketAuthorization))

// Not yet implemented for streaming
var chainedUnaryInterceptors grpc.UnaryServerInterceptor
if cfg.Security.UseAuth {
logger.Infof(ctx, "Creating gRPC server with authentication")
middlewareInterceptors := plugins.Get[grpc.UnaryServerInterceptor](pluginRegistry, plugins.PluginIDUnaryServiceMiddleware)
chainedUnaryInterceptors = grpcmiddleware.ChainUnaryServer(grpcprometheus.UnaryServerInterceptor,
auth.GetAuthenticationCustomMetadataInterceptor(authCtx),
grpcauth.UnaryServerInterceptor(auth.GetAuthenticationInterceptor(authCtx)),
auth.AuthenticationLoggingInterceptor,
blanketAuthorization,
middlewareInterceptors,
)
} else {
logger.Infof(ctx, "Creating gRPC server without authentication")
Expand Down
5 changes: 3 additions & 2 deletions plugins/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@ import (
type PluginID = string

const (
PluginIDWorkflowExecutor PluginID = "WorkflowExecutor"
PluginIDDataProxy PluginID = "DataProxy"
PluginIDWorkflowExecutor PluginID = "WorkflowExecutor"
PluginIDDataProxy PluginID = "DataProxy"
PluginIDUnaryServiceMiddleware PluginID = "UnaryServiceMiddleware"
)

type AtomicRegistry struct {
Expand Down
9 changes: 4 additions & 5 deletions tests/bootstrap.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,10 @@ package tests
import (
"context"
"fmt"

"github.com/flyteorg/flytestdlib/database"

"github.com/flyteorg/flyteadmin/pkg/repositories"
runtimeInterfaces "github.com/flyteorg/flyteadmin/pkg/runtime/interfaces"

"gorm.io/gorm"

"github.com/flyteorg/flytestdlib/logger"
Expand All @@ -34,9 +33,9 @@ func getDbConfig() *database.DbConfig {
}
}

func getLocalDbConfig() *runtimeInterfaces.DbConfig {
return &runtimeInterfaces.DbConfig{
PostgresConfig: &runtimeInterfaces.PostgresConfig{
func getLocalDbConfig() *database.DbConfig {
return &database.DbConfig{
Postgres: database.PostgresConfig{
Host: "localhost",
Port: 5432,
DbName: "flyteadmin",
Expand Down

0 comments on commit 77b83f6

Please sign in to comment.