Skip to content

Commit

Permalink
Use grpc client interceptors to properly check for auth requirement (#…
Browse files Browse the repository at this point in the history
…315)

* Use grpc client interceptors to properly check for auth requirement

Signed-off-by: Haytham Abuelfutuh <[email protected]>

* Some refactor and add unit tests

Signed-off-by: Haytham Abuelfutuh <[email protected]>

* PR Comments

Signed-off-by: Haytham Abuelfutuh <[email protected]>

* lint

Signed-off-by: Haytham Abuelfutuh <[email protected]>

* unit tests

Signed-off-by: Haytham Abuelfutuh <[email protected]>

* Attempt a random port

Signed-off-by: Haytham Abuelfutuh <[email protected]>

* Listen to localhost only

Signed-off-by: Haytham Abuelfutuh <[email protected]>

* PR Comments

Signed-off-by: Haytham Abuelfutuh <[email protected]>

* use chain unary interceptor instead

Signed-off-by: Haytham Abuelfutuh <[email protected]>

* only log on errors

Signed-off-by: Haytham Abuelfutuh <[email protected]>

* Attempt to disable error check

Signed-off-by: Haytham Abuelfutuh <[email protected]>

Signed-off-by: Haytham Abuelfutuh <[email protected]>
  • Loading branch information
EngHabu authored Sep 9, 2022
1 parent 8b176dd commit 3e5066a
Show file tree
Hide file tree
Showing 7 changed files with 487 additions and 52 deletions.
89 changes: 89 additions & 0 deletions flyteidl/clients/go/admin/atomic_credentials.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
package admin

import (
"context"
"sync/atomic"

"google.golang.org/grpc/credentials"

stdlibAtomic "github.com/flyteorg/flytestdlib/atomic"
)

// atomicPerRPCCredentials provides a convenience on top of atomic.Value and credentials.PerRPCCredentials to be thread-safe.
type atomicPerRPCCredentials struct {
atomic.Value
}

func (t *atomicPerRPCCredentials) Store(properties credentials.PerRPCCredentials) {
t.Value.Store(properties)
}

func (t *atomicPerRPCCredentials) Load() credentials.PerRPCCredentials {
val := t.Value.Load()
if val == nil {
return CustomHeaderTokenSource{}
}

return val.(credentials.PerRPCCredentials)
}

func newAtomicPerPRCCredentials() *atomicPerRPCCredentials {
return &atomicPerRPCCredentials{
Value: atomic.Value{},
}
}

// PerRPCCredentialsFuture is a future wrapper for credentials.PerRPCCredentials that can act as one and also be
// materialized later.
type PerRPCCredentialsFuture struct {
perRPCCredentials *atomicPerRPCCredentials
initialized stdlibAtomic.Bool
}

// GetRequestMetadata gets the authorization metadata as a map using a TokenSource to generate a token
func (ts *PerRPCCredentialsFuture) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) {
if ts.initialized.Load() {
tp := ts.perRPCCredentials.Load()
return tp.GetRequestMetadata(ctx, uri...)
}

return map[string]string{}, nil
}

// RequireTransportSecurity returns whether this credentials class requires TLS/SSL. OAuth uses Bearer tokens that are
// susceptible to MITM (Man-In-The-Middle) attacks that are mitigated by TLS/SSL. We may return false here to make it
// easier to setup auth. However, in a production environment, TLS for OAuth2 is a requirement.
// see also: https://tools.ietf.org/html/rfc6749#section-3.1
func (ts *PerRPCCredentialsFuture) RequireTransportSecurity() bool {
if ts.initialized.Load() {
return ts.perRPCCredentials.Load().RequireTransportSecurity()
}

return false
}

func (ts *PerRPCCredentialsFuture) Store(tokenSource credentials.PerRPCCredentials) {
ts.perRPCCredentials.Store(tokenSource)
ts.initialized.Store(true)
}

func (ts *PerRPCCredentialsFuture) Get() credentials.PerRPCCredentials {
return ts.perRPCCredentials.Load()
}

func (ts *PerRPCCredentialsFuture) IsInitialized() bool {
return ts.initialized.Load()
}

// NewPerRPCCredentialsFuture initializes a new PerRPCCredentialsFuture that can act as a credentials.PerRPCCredentials
// and can also be resolved in the future. Users of the future can check if it has been initialized before by calling
// PerRPCCredentialsFuture.IsInitialized(). Calling PerRPCCredentialsFuture.Get() multiple times will return
// the same stored object (unless it changed in between calls). Calling PerRPCCredentialsFuture.Store() multiple
// times is supported and will result in overriding the old value atomically.
func NewPerRPCCredentialsFuture() *PerRPCCredentialsFuture {
tokenSource := PerRPCCredentialsFuture{
perRPCCredentials: newAtomicPerPRCCredentials(),
}

return &tokenSource
}
57 changes: 57 additions & 0 deletions flyteidl/clients/go/admin/atomic_credentials_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
package admin

import (
"context"
"fmt"
"testing"

"github.com/stretchr/testify/assert"
)

func TestAtomicPerRPCCredentials(t *testing.T) {
a := atomicPerRPCCredentials{}
assert.True(t, a.Load().RequireTransportSecurity())

tokenSource := DummyTestTokenSource{}
chTokenSource := NewCustomHeaderTokenSource(tokenSource, true, "my_custom_header")
a.Store(chTokenSource)

assert.False(t, a.Load().RequireTransportSecurity())
}

func TestNewPerRPCCredentialsFuture(t *testing.T) {
f := NewPerRPCCredentialsFuture()
assert.False(t, f.RequireTransportSecurity())
assert.Equal(t, CustomHeaderTokenSource{}, f.Get())

tokenSource := DummyTestTokenSource{}
chTokenSource := NewCustomHeaderTokenSource(tokenSource, false, "my_custom_header")
f.Store(chTokenSource)

assert.True(t, f.Get().RequireTransportSecurity())
assert.True(t, f.RequireTransportSecurity())
}

func ExampleNewPerRPCCredentialsFuture() {
f := NewPerRPCCredentialsFuture()

// Starts uninitialized
fmt.Println("Initialized:", f.IsInitialized())

// Implements credentials.PerRPCCredentials so can be used as one
m, err := f.GetRequestMetadata(context.TODO(), "")
fmt.Println("GetRequestMetadata:", m, "Error:", err)

// Materialize the value later and populate
tokenSource := DummyTestTokenSource{}
f.Store(NewCustomHeaderTokenSource(tokenSource, false, "my_custom_header"))

// Future calls to credentials.PerRPCCredentials methods will use the new instance
m, err = f.GetRequestMetadata(context.TODO(), "")
fmt.Println("GetRequestMetadata:", m, "Error:", err)

// Output:
// Initialized: false
// GetRequestMetadata: map[] Error: <nil>
// GetRequestMetadata: map[my_custom_header:Bearer abc] Error: <nil>
}
81 changes: 81 additions & 0 deletions flyteidl/clients/go/admin/auth_interceptor.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
package admin

import (
"context"
"fmt"

"github.com/flyteorg/flyteidl/clients/go/admin/cache"
"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/service"
"github.com/flyteorg/flytestdlib/logger"

"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"

"google.golang.org/grpc"
)

// MaterializeCredentials will attempt to build a TokenSource given the anonymously available information exposed by the server.
// Once established, it'll invoke PerRPCCredentialsFuture.Store() on perRPCCredentials to populate it with the appropriate values.
func MaterializeCredentials(ctx context.Context, cfg *Config, tokenCache cache.TokenCache, perRPCCredentials *PerRPCCredentialsFuture) error {
authMetadataClient, err := InitializeAuthMetadataClient(ctx, cfg)
if err != nil {
return fmt.Errorf("failed to initialized Auth Metadata Client. Error: %w", err)
}

tokenSourceProvider, err := NewTokenSourceProvider(ctx, cfg, tokenCache, authMetadataClient)
if err != nil {
return fmt.Errorf("failed to initialized token source provider. Err: %w", err)
}

clientMetadata, err := authMetadataClient.GetPublicClientConfig(ctx, &service.PublicClientAuthConfigRequest{})
if err != nil {
return fmt.Errorf("failed to fetch client metadata. Error: %v", err)
}

tokenSource, err := tokenSourceProvider.GetTokenSource(ctx)
if err != nil {
return err
}

wrappedTokenSource := NewCustomHeaderTokenSource(tokenSource, cfg.UseInsecureConnection, clientMetadata.AuthorizationMetadataKey)
perRPCCredentials.Store(wrappedTokenSource)
return nil
}

func shouldAttemptToAuthenticate(errorCode codes.Code) bool {
return errorCode == codes.Unauthenticated
}

// newAuthInterceptor creates a new grpc.UnaryClientInterceptor that forwards the grpc call and inspects the error.
// It will first invoke the grpc pipeline (to proceed with the request) with no modifications. It's expected for the grpc
// pipeline to already have a grpc.WithPerRPCCredentials() DialOption. If the perRPCCredentials has already been initialized,
// it'll take care of refreshing when tokens expire... etc.
// If the first invocation succeeds (either due to grpc.PerRPCCredentials setting the right tokens or the server not
// requiring authentication), the interceptor will be no-op.
// If the first invocation fails with an auth error, this interceptor will then attempt to establish a token source once
// more. It'll fail hard if it couldn't do so (i.e. it will no longer attempt to send an unauthenticated request). Once
// a token source has been created, it'll invoke the grpc pipeline again, this time the grpc.PerRPCCredentials should
// be able to find and acquire a valid AccessToken to annotate the request with.
func newAuthInterceptor(cfg *Config, tokenCache cache.TokenCache, credentialsFuture *PerRPCCredentialsFuture) grpc.UnaryClientInterceptor {
return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
err := invoker(ctx, method, req, reply, cc, opts...)
if err != nil {
logger.Debugf(ctx, "Request failed due to [%v]. If it's an unauthenticated error, we will attempt to establish an authenticated context.", err)

if st, ok := status.FromError(err); ok {
// If the error we receive from executing the request expects
if shouldAttemptToAuthenticate(st.Code()) {
logger.Debugf(ctx, "Request failed due to [%v]. Attempting to establish an authenticated connection and trying again.", st.Code())
newErr := MaterializeCredentials(ctx, cfg, tokenCache, credentialsFuture)
if newErr != nil {
return fmt.Errorf("authentication error! Original Error: %v, Auth Error: %w", err, newErr)
}

return invoker(ctx, method, req, reply, cc, opts...)
}
}
}

return err
}
}
Loading

0 comments on commit 3e5066a

Please sign in to comment.