Skip to content
This repository has been archived by the owner on Oct 23, 2023. It is now read-only.

Commit

Permalink
Moved the 3legged auth files from flytectl and also added config opti…
Browse files Browse the repository at this point in the history
…on for it (#156)

* Moved the 3legged auth files from flytectl and also added config option for it

Signed-off-by: Prafulla Mahindrakar <[email protected]>

* Fixed the expiry bug

Signed-off-by: Prafulla Mahindrakar <[email protected]>

* Changed logic to refresh the token

Signed-off-by: Prafulla Mahindrakar <[email protected]>

* Refactored the getAuthenticationDialOption

Signed-off-by: Prafulla Mahindrakar <[email protected]>

* Added more unit tests

Signed-off-by: Prafulla Mahindrakar <[email protected]>

* Fixed unit tests

Signed-off-by: Prafulla Mahindrakar <[email protected]>

* Added more unit tests

Signed-off-by: Prafulla Mahindrakar <[email protected]>
Signed-off-by: Haytham Abuelfutuh <[email protected]>
  • Loading branch information
pmahindrakar-oss authored and EngHabu committed Apr 29, 2021
1 parent c1a8b34 commit 89044a9
Show file tree
Hide file tree
Showing 22 changed files with 1,245 additions and 9 deletions.
85 changes: 85 additions & 0 deletions clients/go/admin/authtype_enumer.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

60 changes: 52 additions & 8 deletions clients/go/admin/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,15 @@ import (
"sync"

"github.com/flyteorg/flyteidl/clients/go/admin/mocks"
leggedoauth "github.com/flyteorg/flyteidl/clients/go/admin/threelegauth"
"github.com/flyteorg/flyteidl/clients/go/admin/threelegauth/interfaces"
"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/service"
"github.com/flyteorg/flytestdlib/logger"

grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
grpc_retry "github.com/grpc-ecosystem/go-grpc-middleware/retry"
grpc_prometheus "github.com/grpc-ecosystem/go-grpc-prometheus"
"golang.org/x/oauth2"
"golang.org/x/oauth2/clientcredentials"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
Expand All @@ -28,6 +32,14 @@ var (
authMetadataConnection *grpc.ClientConn
)

var (
defaultTokenOrchestrator = NewTokenOrchestrator()
)

func NewTokenOrchestrator() interfaces.FetchTokenOrchestrator {
return leggedoauth.TokenOrchestrator{}
}

// Clientset contains the clients exposed to communicate with various admin services.
type Clientset struct {
adminServiceClient service.AdminServiceClient
Expand All @@ -40,7 +52,7 @@ func (c Clientset) AdminClient() service.AdminServiceClient {
return c.adminServiceClient
}

// AuthServiceClient retrieves the AuthServiceClient
// AuthMetadataClient retrieves the AuthMetadataServiceClient
func (c Clientset) AuthMetadataClient() service.AuthMetadataServiceClient {
return c.authMetadataServiceClient
}
Expand Down Expand Up @@ -90,35 +102,67 @@ func getAuthenticationDialOption(ctx context.Context, cfg *Config, authClient se

tokenURL = metadata.TokenEndpoint
}

clientMetadata, err := authClient.FlyteClient(ctx, &service.FlyteClientRequest{})
if err != nil {
return nil, fmt.Errorf("failed to fetch client metadata. Error: %v", err)
}
var tSource oauth2.TokenSource
if cfg.MustAuthTypeConfig() == AuthTypeCLIENTSECRET {
tSource, err = getClientCredentialsTokenSource(ctx, cfg, clientMetadata, tokenURL)
if err != nil {
return nil, err
}
} else if cfg.MustAuthTypeConfig() == AuthTypeTHREELEGGEDAUTH {
tSource, err = getThreeLeggedAuthTokenSource(ctx, authClient)
if err != nil {
return nil, err
}
} else {
return nil, fmt.Errorf("unsupported type %v", cfg.AuthType)
}
oauthTokenSource := NewCustomHeaderTokenSource(tSource, cfg.UseInsecureConnection, clientMetadata.AuthorizationMetadataKey)
return grpc.WithPerRPCCredentials(oauthTokenSource), nil
}

// Returns the client credentials token source to be used eg by flytepropeller to communicate with admin/ by CI
func getClientCredentialsTokenSource(ctx context.Context, cfg *Config, clientMetadata *service.FlyteClientResponse, tokenURL string) (oauth2.TokenSource, error) {
secretBytes, err := ioutil.ReadFile(cfg.ClientSecretLocation)
if err != nil {
logger.Errorf(ctx, "Error reading secret from location %s", cfg.ClientSecretLocation)
return nil, err
}

secret := strings.TrimSpace(string(secretBytes))

scopes := cfg.Scopes
if len(scopes) == 0 {
scopes = clientMetadata.Scopes
}

ccConfig := clientcredentials.Config{
ClientID: cfg.ClientID,
ClientSecret: secret,
TokenURL: tokenURL,
Scopes: scopes,
}
return ccConfig.TokenSource(ctx), nil
}

tSource := ccConfig.TokenSource(ctx)
oauthTokenSource := NewCustomHeaderTokenSource(tSource, cfg.UseInsecureConnection, clientMetadata.AuthorizationMetadataKey)
return grpc.WithPerRPCCredentials(oauthTokenSource), nil
// Returns the token source which would be used for three legged oauth. eg : for admin to authorize access to flytectl
func getThreeLeggedAuthTokenSource(ctx context.Context, authClient service.AuthMetadataServiceClient) (oauth2.TokenSource, error) {
var err error
// explicitly ignore error while fetching token from cache.
authToken, err := defaultTokenOrchestrator.FetchTokenFromCacheOrRefreshIt(ctx)
if err != nil {
logger.Warnf(ctx, "failed fetching from cache due to %v", err)
}
if authToken == nil {
// Fetch using auth flow
if authToken, err = defaultTokenOrchestrator.FetchTokenFromAuthFlow(ctx, authClient); err != nil {
logger.Errorf(ctx, "Error fetching token using auth flow due to %v", err)
return nil, err
}
}
return &leggedoauth.DefaultHeaderTokenSource{
DefaultHeaderToken: authToken,
}, nil
}

// InitializeAuthMetadataClient creates a new anonymously Auth Metadata Service client.
Expand Down
134 changes: 133 additions & 1 deletion clients/go/admin/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,23 @@ package admin

import (
"context"
"encoding/json"
"fmt"
"io/ioutil"
"net/url"
"sync"
"testing"
"time"

"github.com/flyteorg/flyteidl/clients/go/admin/mocks"
mcokauthinterfaces "github.com/flyteorg/flyteidl/clients/go/admin/threelegauth/interfaces/mocks"
"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/service"
"github.com/flyteorg/flytestdlib/config"
"github.com/flyteorg/flytestdlib/logger"

"github.com/flyteorg/flytestdlib/config"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"golang.org/x/oauth2"
)

func TestInitializeAndGetAdminClient(t *testing.T) {
Expand Down Expand Up @@ -73,6 +81,130 @@ func TestGetAdditionalAdminClientConfigOptions(t *testing.T) {
})
}

func TestGetAuthenticationDialOptionClientSecret(t *testing.T) {
ctx := context.Background()

u, _ := url.Parse("localhost:8089")
adminServiceConfig := &Config{
ClientSecretLocation: "testdata/secret_key",
Endpoint: config.URL{URL: *u},
UseInsecureConnection: true,
AuthType: "CLIENTSECRET",
PerRetryTimeout: config.Duration{Duration: 1 * time.Second},
}
t.Run("legal", func(t *testing.T) {
metatdata := &service.OAuth2MetadataResponse{
TokenEndpoint: "/token",
ScopesSupported: []string{"code", "all"},
}
clientMetatadata := &service.FlyteClientResponse{
AuthorizationMetadataKey: "flyte_authorization",
}
mockAuthClient := new(mocks.AuthMetadataServiceClient)
mockAuthClient.OnOAuth2MetadataMatch(mock.Anything, mock.Anything).Return(metatdata, nil)
mockAuthClient.OnFlyteClientMatch(mock.Anything, mock.Anything).Return(clientMetatadata, nil)
dialOption, err := getAuthenticationDialOption(ctx, adminServiceConfig, mockAuthClient)
assert.NotNil(t, dialOption)
assert.Nil(t, err)
})
t.Run("error during oauth2Metatdata", func(t *testing.T) {
mockAuthClient := new(mocks.AuthMetadataServiceClient)
mockAuthClient.OnOAuth2MetadataMatch(mock.Anything, mock.Anything).Return(nil, fmt.Errorf("failed"))
dialOption, err := getAuthenticationDialOption(ctx, adminServiceConfig, mockAuthClient)
assert.Nil(t, dialOption)
assert.NotNil(t, err)
})
t.Run("error during flyte client", func(t *testing.T) {
metatdata := &service.OAuth2MetadataResponse{
TokenEndpoint: "/token",
ScopesSupported: []string{"code", "all"},
}
mockAuthClient := new(mocks.AuthMetadataServiceClient)
mockAuthClient.OnOAuth2MetadataMatch(mock.Anything, mock.Anything).Return(metatdata, nil)
mockAuthClient.OnFlyteClientMatch(mock.Anything, mock.Anything).Return(nil, fmt.Errorf("failed"))
dialOption, err := getAuthenticationDialOption(ctx, adminServiceConfig, mockAuthClient)
assert.Nil(t, dialOption)
assert.NotNil(t, err)
})
incorrectSecretLocConfig := &Config{
ClientSecretLocation: "testdata/secret_key_invalid",
Endpoint: config.URL{URL: *u},
UseInsecureConnection: true,
AuthType: "CLIENTSECRET",
PerRetryTimeout: config.Duration{Duration: 1 * time.Second},
}
t.Run("incorrect client secret loc", func(t *testing.T) {
metatdata := &service.OAuth2MetadataResponse{
TokenEndpoint: "/token",
ScopesSupported: []string{"code", "all"},
}
clientMetatadata := &service.FlyteClientResponse{
AuthorizationMetadataKey: "flyte_authorization",
}
mockAuthClient := new(mocks.AuthMetadataServiceClient)
mockAuthClient.OnOAuth2MetadataMatch(mock.Anything, mock.Anything).Return(metatdata, nil)
mockAuthClient.OnFlyteClientMatch(mock.Anything, mock.Anything).Return(clientMetatadata, nil)
dialOption, err := getAuthenticationDialOption(ctx, incorrectSecretLocConfig, mockAuthClient)
assert.Nil(t, dialOption)
assert.NotNil(t, err)
})
}

func TestGetAuthenticationDialOptionThreeLegAuth(t *testing.T) {
ctx := context.Background()
mockAuthClient := new(mocks.AuthMetadataServiceClient)
u, _ := url.Parse("localhost:8089")
adminServiceConfig := &Config{
Endpoint: config.URL{URL: *u},
UseInsecureConnection: true,
AuthType: "THREELEGGEDAUTH",
PerRetryTimeout: config.Duration{Duration: 1 * time.Second},
}
metatdata := &service.OAuth2MetadataResponse{
TokenEndpoint: "/token",
ScopesSupported: []string{"code", "all"},
}
clientMetatadata := &service.FlyteClientResponse{
AuthorizationMetadataKey: "flyte_authorization",
}
mockAuthClient.OnOAuth2MetadataMatch(mock.Anything, mock.Anything).Return(metatdata, nil)
mockAuthClient.OnFlyteClientMatch(mock.Anything, mock.Anything).Return(clientMetatadata, nil)
t.Run("legal from cache", func(t *testing.T) {
mockTokenOrchestrator := &mcokauthinterfaces.FetchTokenOrchestrator{}
defaultTokenOrchestrator = mockTokenOrchestrator
plan, _ := ioutil.ReadFile("threelegauth/testdata/token.json")
var tokenData oauth2.Token
err := json.Unmarshal(plan, &tokenData)
assert.Nil(t, err)
mockTokenOrchestrator.OnFetchTokenFromCacheOrRefreshItMatch(mock.Anything, mock.Anything).Return(&tokenData, nil)
dialOption, err := getAuthenticationDialOption(ctx, adminServiceConfig, mockAuthClient)
assert.NotNil(t, dialOption)
assert.Nil(t, err)
})
t.Run("legal from auth flow", func(t *testing.T) {
mockTokenOrchestrator := &mcokauthinterfaces.FetchTokenOrchestrator{}
defaultTokenOrchestrator = mockTokenOrchestrator
plan, _ := ioutil.ReadFile("threelegauth/testdata/token.json")
var tokenData oauth2.Token
err := json.Unmarshal(plan, &tokenData)
assert.Nil(t, err)
mockTokenOrchestrator.OnFetchTokenFromCacheOrRefreshItMatch(mock.Anything, mock.Anything).Return(nil, nil)
mockTokenOrchestrator.OnFetchTokenFromAuthFlowMatch(mock.Anything, mock.Anything).Return(&tokenData, nil)
dialOption, err := getAuthenticationDialOption(ctx, adminServiceConfig, mockAuthClient)
assert.NotNil(t, dialOption)
assert.Nil(t, err)
})
t.Run("illegal", func(t *testing.T) {
mockTokenOrchestrator := &mcokauthinterfaces.FetchTokenOrchestrator{}
defaultTokenOrchestrator = mockTokenOrchestrator
mockTokenOrchestrator.OnFetchTokenFromCacheOrRefreshItMatch(mock.Anything, mock.Anything).Return(nil, nil)
mockTokenOrchestrator.OnFetchTokenFromAuthFlowMatch(mock.Anything, mock.Anything).Return(nil, fmt.Errorf("fail"))
dialOption, err := getAuthenticationDialOption(ctx, adminServiceConfig, mockAuthClient)
assert.Nil(t, dialOption)
assert.NotNil(t, err)
})
}

func ExampleInitializeClients() {
// Create an AuthClient
ctx := context.Background()
Expand Down
Loading

0 comments on commit 89044a9

Please sign in to comment.