diff --git a/flyteidl/clients/go/admin/auth_interceptor.go b/flyteidl/clients/go/admin/auth_interceptor.go index b1ede68dbd..3403267847 100644 --- a/flyteidl/clients/go/admin/auth_interceptor.go +++ b/flyteidl/clients/go/admin/auth_interceptor.go @@ -3,10 +3,12 @@ package admin import ( "context" "fmt" + "net/http" "github.com/flyteorg/flyteidl/clients/go/admin/cache" "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/service" "github.com/flyteorg/flytestdlib/logger" + "golang.org/x/oauth2" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -50,6 +52,21 @@ func shouldAttemptToAuthenticate(errorCode codes.Code) bool { return errorCode == codes.Unauthenticated } +// Set up http client used in oauth2 +func setHTTPClientContext(ctx context.Context, cfg *Config) context.Context { + httpClient := &http.Client{} + + if len(cfg.HTTPProxyURL.String()) > 0 { + // create a transport that uses the proxy + transport := &http.Transport{ + Proxy: http.ProxyURL(&cfg.HTTPProxyURL.URL), + } + httpClient.Transport = transport + } + + return context.WithValue(ctx, oauth2.HTTPClient, httpClient) +} + // NewAuthInterceptor creates a new grpc.UnaryClientInterceptor that forwards the grpc call and inspects the error. // It will first invoke the grpc pipeline (to proceed with the request) with no modifications. It's expected for the grpc // pipeline to already have a grpc.WithPerRPCCredentials() DialOption. If the perRPCCredentials has already been initialized, @@ -62,6 +79,8 @@ func shouldAttemptToAuthenticate(errorCode codes.Code) bool { // be able to find and acquire a valid AccessToken to annotate the request with. func NewAuthInterceptor(cfg *Config, tokenCache cache.TokenCache, credentialsFuture *PerRPCCredentialsFuture) grpc.UnaryClientInterceptor { return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { + ctx = setHTTPClientContext(ctx, cfg) + err := invoker(ctx, method, req, reply, cc, opts...) if err != nil { logger.Debugf(ctx, "Request failed due to [%v]. If it's an unauthenticated error, we will attempt to establish an authenticated context.", err) diff --git a/flyteidl/clients/go/admin/config.go b/flyteidl/clients/go/admin/config.go index dd0652606b..e6c5ad06df 100644 --- a/flyteidl/clients/go/admin/config.go +++ b/flyteidl/clients/go/admin/config.go @@ -79,6 +79,9 @@ type Config struct { // find the full schema here https://github.com/grpc/grpc-proto/blob/master/grpc/service_config/service_config.proto#L625 // Note that required packages may need to be preloaded to support certain service config. For example "google.golang.org/grpc/balancer/roundrobin" should be preloaded to have round-robin policy supported. DefaultServiceConfig string `json:"defaultServiceConfig" pdflag:",Set the default service config for the admin gRPC client"` + + // HTTPProxyURL allows operators to access external OAuth2 servers using an external HTTP Proxy + HTTPProxyURL config.URL `json:"httpProxyURL" pflag:",OPTIONAL: HTTP Proxy to be used for OAuth requests."` } var ( diff --git a/flyteidl/clients/go/admin/token_source_provider.go b/flyteidl/clients/go/admin/token_source_provider.go index ba6bb0a469..e836bf81a6 100644 --- a/flyteidl/clients/go/admin/token_source_provider.go +++ b/flyteidl/clients/go/admin/token_source_provider.go @@ -227,13 +227,14 @@ func (s *customTokenSource) Token() (*oauth2.Token, error) { token, err := s.new.Token() if err != nil { + logger.Warnf(s.ctx, "failed to get token: %w", err) return nil, fmt.Errorf("failed to get token: %w", err) } logger.Infof(s.ctx, "retrieved token with expiry %v", token.Expiry) err = s.tokenCache.SaveToken(token) if err != nil { - logger.Warnf(s.ctx, "failed to cache token") + logger.Warnf(s.ctx, "failed to cache token: %w", err) } return token, nil