
Commit

ADD support for PodIdentityProvider config in azure pipeline trigger (#4867)

Co-authored-by: Zbynek Roubalik <[email protected]>
toniiiik and zroubalik authored Jan 16, 2024
1 parent b220384 commit 6d66962
Showing 4 changed files with 448 additions and 44 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -110,6 +110,7 @@ Here is an overview of all new **experimental** features:
- **General**: Request all ScaledObject/ScaledJob triggers in parallel ([#5276](https://github.com/kedacore/keda/issues/5276))
- **General**: Support TriggerAuthentication properties from ConfigMap ([#4830](https://github.com/kedacore/keda/issues/4830))
- **General**: Use client-side round-robin load balancing for grpc calls ([#5224](https://github.com/kedacore/keda/issues/5224))
- **Azure Pipelines Scaler**: Add support for workload identity authentication ([#5013](https://github.com/kedacore/keda/issues/5013))
- **GCP pubsub scaler**: Support distribution-valued metrics and metrics from topics ([#5070](https://github.com/kedacore/keda/issues/5070))
- **GCP stackdriver scaler**: Support valueIfNull parameter ([#5345](https://github.com/kedacore/keda/pull/5345))
- **Hashicorp Vault**: Add support to get secret that needs write operation (e.g. pki) ([#5067](https://github.com/kedacore/keda/issues/5067))
158 changes: 119 additions & 39 deletions pkg/scalers/azure_pipelines_scaler.go
@@ -11,10 +11,15 @@ import (
"strings"
"time"

"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
"github.com/go-logr/logr"
v2 "k8s.io/api/autoscaling/v2"
"k8s.io/metrics/pkg/apis/external_metrics"

kedav1alpha1 "github.com/kedacore/keda/v2/apis/keda/v1alpha1"
"github.com/kedacore/keda/v2/pkg/scalers/azure"
kedautil "github.com/kedacore/keda/v2/pkg/util"
)

@@ -27,6 +32,10 @@ type JobRequests struct {
Value []JobRequest `json:"value"`
}

const (
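// 499b84ac-1321-427f-aa17-267ca6975798 is the well-known application (resource) ID
// of Azure DevOps; the "/.default" suffix requests its default scope when an
// Entra ID access token is acquired.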
devopsResource = "499b84ac-1321-427f-aa17-267ca6975798/.default"
)

type JobRequest struct {
RequestID int `json:"requestId"`
QueueTime time.Time `json:"queueTime"`
@@ -119,16 +128,17 @@ type azurePipelinesPoolIDResponse struct {
}

type azurePipelinesScaler struct {
metricType v2.MetricTargetType
metadata *azurePipelinesMetadata
httpClient *http.Client
logger logr.Logger
metricType v2.MetricTargetType
metadata *azurePipelinesMetadata
httpClient *http.Client
podIdentity kedav1alpha1.AuthPodIdentity
logger logr.Logger
}

type azurePipelinesMetadata struct {
organizationURL string
organizationName string
personalAccessToken string
authContext authContext
parent string
demands string
poolID int
@@ -139,36 +149,68 @@ type azurePipelinesMetadata struct {
requireAllDemands bool
}

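// authContext carries whichever credential the scaler resolved: a personal
// access token, or a chained Azure credential together with the most recently
// acquired (cached) access token.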
type authContext struct {
cred *azidentity.ChainedTokenCredential
pat string
token *azcore.AccessToken
}

// NewAzurePipelinesScaler creates a new AzurePipelinesScaler
func NewAzurePipelinesScaler(ctx context.Context, config *ScalerConfig) (Scaler, error) {
httpClient := kedautil.CreateHTTPClient(config.GlobalHTTPTimeout, false)

logger := InitializeLogger(config, "azure_pipelines_scaler")
metricType, err := GetMetricTargetType(config)
if err != nil {
return nil, fmt.Errorf("error getting scaler metric type: %w", err)
}

meta, err := parseAzurePipelinesMetadata(ctx, config, httpClient)
meta, podIdentity, err := parseAzurePipelinesMetadata(ctx, logger, config, httpClient)
if err != nil {
return nil, fmt.Errorf("error parsing azure Pipelines metadata: %w", err)
}

return &azurePipelinesScaler{
metricType: metricType,
metadata: meta,
httpClient: httpClient,
logger: InitializeLogger(config, "azure_pipelines_scaler"),
metricType: metricType,
metadata: meta,
httpClient: httpClient,
podIdentity: podIdentity,
logger: logger,
}, nil
}

func parseAzurePipelinesMetadata(ctx context.Context, config *ScalerConfig, httpClient *http.Client) (*azurePipelinesMetadata, error) {
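// getAuthMethod resolves how the scaler authenticates against Azure DevOps, in
// priority order: a personalAccessToken supplied via TriggerAuthentication, a
// PAT referenced through personalAccessTokenFromEnv, and finally the configured
// pod identity provider (azure or azure-workload), for which a chained token
// credential is created.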
func getAuthMethod(logger logr.Logger, config *ScalerConfig) (string, *azidentity.ChainedTokenCredential, kedav1alpha1.AuthPodIdentity, error) {
pat := ""
if val, ok := config.AuthParams["personalAccessToken"]; ok && val != "" {
// Found the personalAccessToken in a parameter from TriggerAuthentication
pat = config.AuthParams["personalAccessToken"]
} else if val, ok := config.TriggerMetadata["personalAccessTokenFromEnv"]; ok && val != "" {
pat = config.ResolvedEnv[config.TriggerMetadata["personalAccessTokenFromEnv"]]
} else {
switch config.PodIdentity.Provider {
case "", kedav1alpha1.PodIdentityProviderNone:
return "", nil, kedav1alpha1.AuthPodIdentity{}, fmt.Errorf("no personalAccessToken given or PodIdentity provider configured")
case kedav1alpha1.PodIdentityProviderAzure, kedav1alpha1.PodIdentityProviderAzureWorkload:
cred, err := azure.NewChainedCredential(logger, config.PodIdentity.GetIdentityID(), config.PodIdentity.Provider)
if err != nil {
return "", nil, kedav1alpha1.AuthPodIdentity{}, err
}
return "", cred, kedav1alpha1.AuthPodIdentity{Provider: config.PodIdentity.Provider}, nil
default:
return "", nil, kedav1alpha1.AuthPodIdentity{}, fmt.Errorf("pod identity %s not supported for azure pipelines", config.PodIdentity.Provider)
}
}
return pat, nil, kedav1alpha1.AuthPodIdentity{}, nil
}

func parseAzurePipelinesMetadata(ctx context.Context, logger logr.Logger, config *ScalerConfig, httpClient *http.Client) (*azurePipelinesMetadata, kedav1alpha1.AuthPodIdentity, error) {
meta := azurePipelinesMetadata{}
meta.targetPipelinesQueueLength = defaultTargetPipelinesQueueLength

if val, ok := config.TriggerMetadata["targetPipelinesQueueLength"]; ok {
queueLength, err := strconv.ParseInt(val, 10, 64)
if err != nil {
return nil, fmt.Errorf("error parsing azure pipelines metadata targetPipelinesQueueLength: %w", err)
return nil, kedav1alpha1.AuthPodIdentity{}, fmt.Errorf("error parsing azure pipelines metadata targetPipelinesQueueLength: %w", err)
}

meta.targetPipelinesQueueLength = queueLength
@@ -178,7 +220,7 @@ func parseAzurePipelinesMetadata(ctx context.Context, config *ScalerConfig, http
if val, ok := config.TriggerMetadata["activationTargetPipelinesQueueLength"]; ok {
activationQueueLength, err := strconv.ParseInt(val, 10, 64)
if err != nil {
return nil, fmt.Errorf("error parsing azure pipelines metadata activationTargetPipelinesQueueLength: %w", err)
return nil, kedav1alpha1.AuthPodIdentity{}, fmt.Errorf("error parsing azure pipelines metadata activationTargetPipelinesQueueLength: %w", err)
}

meta.activationTargetPipelinesQueueLength = activationQueueLength
@@ -190,22 +232,24 @@ func parseAzurePipelinesMetadata(ctx context.Context, config *ScalerConfig, http
} else if val, ok := config.TriggerMetadata["organizationURLFromEnv"]; ok && val != "" {
meta.organizationURL = config.ResolvedEnv[val]
} else {
return nil, fmt.Errorf("no organizationURL given")
return nil, kedav1alpha1.AuthPodIdentity{}, fmt.Errorf("no organizationURL given")
}

if val := meta.organizationURL[strings.LastIndex(meta.organizationURL, "/")+1:]; val != "" {
meta.organizationName = meta.organizationURL[strings.LastIndex(meta.organizationURL, "/")+1:]
} else {
return nil, fmt.Errorf("failed to extract organization name from organizationURL")
return nil, kedav1alpha1.AuthPodIdentity{}, fmt.Errorf("failed to extract organization name from organizationURL")
}

if val, ok := config.AuthParams["personalAccessToken"]; ok && val != "" {
// Found the personalAccessToken in a parameter from TriggerAuthentication
meta.personalAccessToken = config.AuthParams["personalAccessToken"]
} else if val, ok := config.TriggerMetadata["personalAccessTokenFromEnv"]; ok && val != "" {
meta.personalAccessToken = config.ResolvedEnv[config.TriggerMetadata["personalAccessTokenFromEnv"]]
} else {
return nil, fmt.Errorf("no personalAccessToken given")
pat, cred, podIdentity, err := getAuthMethod(logger, config)
if err != nil {
return nil, kedav1alpha1.AuthPodIdentity{}, err
}
// Trim any trailing new lines from the Azure Pipelines PAT
meta.authContext = authContext{
pat: strings.TrimSuffix(pat, "\n"),
cred: cred,
token: nil,
}

if val, ok := config.TriggerMetadata["parent"]; ok && val != "" {
@@ -224,7 +268,7 @@ func parseAzurePipelinesMetadata(ctx context.Context, config *ScalerConfig, http
if val, ok := config.TriggerMetadata["jobsToFetch"]; ok && val != "" {
jobsToFetch, err := strconv.ParseInt(val, 10, 64)
if err != nil {
return nil, fmt.Errorf("error parsing jobsToFetch: %w", err)
return nil, kedav1alpha1.AuthPodIdentity{}, fmt.Errorf("error parsing jobsToFetch: %w", err)
}
meta.jobsToFetch = jobsToFetch
}
@@ -233,41 +277,40 @@ func parseAzurePipelinesMetadata(ctx context.Context, config *ScalerConfig, http
if val, ok := config.TriggerMetadata["requireAllDemands"]; ok && val != "" {
requireAllDemands, err := strconv.ParseBool(val)
if err != nil {
return nil, err
return nil, kedav1alpha1.AuthPodIdentity{}, err
}
meta.requireAllDemands = requireAllDemands
}

if val, ok := config.TriggerMetadata["poolName"]; ok && val != "" {
var err error
poolID, err := getPoolIDFromName(ctx, val, &meta, httpClient)
poolID, err := getPoolIDFromName(ctx, logger, val, &meta, podIdentity, httpClient)
if err != nil {
return nil, err
return nil, kedav1alpha1.AuthPodIdentity{}, err
}
meta.poolID = poolID
} else {
if val, ok := config.TriggerMetadata["poolID"]; ok && val != "" {
var err error
poolID, err := validatePoolID(ctx, val, &meta, httpClient)
poolID, err := validatePoolID(ctx, logger, val, &meta, podIdentity, httpClient)
if err != nil {
return nil, err
return nil, kedav1alpha1.AuthPodIdentity{}, err
}
meta.poolID = poolID
} else {
return nil, fmt.Errorf("no poolName or poolID given")
return nil, kedav1alpha1.AuthPodIdentity{}, fmt.Errorf("no poolName or poolID given")
}
}

// Trim any trailing new lines from the Azure Pipelines PAT
meta.personalAccessToken = strings.TrimSuffix(meta.personalAccessToken, "\n")
meta.triggerIndex = config.TriggerIndex

return &meta, nil
return &meta, podIdentity, nil
}

func getPoolIDFromName(ctx context.Context, poolName string, metadata *azurePipelinesMetadata, httpClient *http.Client) (int, error) {
func getPoolIDFromName(ctx context.Context, logger logr.Logger, poolName string, metadata *azurePipelinesMetadata, podIdentity kedav1alpha1.AuthPodIdentity, httpClient *http.Client) (int, error) {
urlString := fmt.Sprintf("%s/_apis/distributedtask/pools?poolName=%s", metadata.organizationURL, url.QueryEscape(poolName))
body, err := getAzurePipelineRequest(ctx, urlString, metadata, httpClient)
body, err := getAzurePipelineRequest(ctx, logger, urlString, metadata, podIdentity, httpClient)

if err != nil {
return -1, err
}
@@ -290,9 +333,10 @@ func getPoolIDFromName(ctx context.Context, poolName string, metadata *azurePipe
return result.Value[0].ID, nil
}

func validatePoolID(ctx context.Context, poolID string, metadata *azurePipelinesMetadata, httpClient *http.Client) (int, error) {
func validatePoolID(ctx context.Context, logger logr.Logger, poolID string, metadata *azurePipelinesMetadata, podIdentity kedav1alpha1.AuthPodIdentity, httpClient *http.Client) (int, error) {
urlString := fmt.Sprintf("%s/_apis/distributedtask/pools?poolID=%s", metadata.organizationURL, poolID)
body, err := getAzurePipelineRequest(ctx, urlString, metadata, httpClient)
body, err := getAzurePipelineRequest(ctx, logger, urlString, metadata, podIdentity, httpClient)

if err != nil {
return -1, fmt.Errorf("agent pool with id `%s` not found: %w", poolID, err)
}
@@ -306,13 +350,49 @@ func validatePoolID(ctx context.Context, poolID string, metadata *azurePipelines
return result.ID, nil
}

func getAzurePipelineRequest(ctx context.Context, urlString string, metadata *azurePipelinesMetadata, httpClient *http.Client) ([]byte, error) {
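// getToken returns an Entra ID access token for the given scope, reusing the
// cached token while it remains valid for at least another minute and
// refreshing it through the chained credential otherwise.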
func getToken(ctx context.Context, metadata *azurePipelinesMetadata, scope string) (string, error) {
if metadata.authContext.token != nil {
// if the token expires more than a minute from now, reuse it
if metadata.authContext.token.ExpiresOn.After(time.Now().Add(time.Second * 60)) {
return metadata.authContext.token.Token, nil
}
}
token, err := metadata.authContext.cred.GetToken(ctx, policy.TokenRequestOptions{
Scopes: []string{
scope,
},
})

if err != nil {
return "", err
}

metadata.authContext.token = &token

return metadata.authContext.token.Token, nil
}

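// getAzurePipelineRequest performs a GET against the Azure DevOps REST API,
// using Basic auth with the PAT when no pod identity is configured, and a
// Bearer token obtained via getToken for the azure-workload provider.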
func getAzurePipelineRequest(ctx context.Context, logger logr.Logger, urlString string, metadata *azurePipelinesMetadata, podIdentity kedav1alpha1.AuthPodIdentity, httpClient *http.Client) ([]byte, error) {
req, err := http.NewRequestWithContext(ctx, "GET", urlString, nil)
if err != nil {
return []byte{}, err
}

req.SetBasicAuth("", metadata.personalAccessToken)
switch podIdentity.Provider {
case "", kedav1alpha1.PodIdentityProviderNone:
//PAT
logger.V(1).Info("making request to ADO REST API using PAT")
req.SetBasicAuth("", metadata.authContext.pat)
case kedav1alpha1.PodIdentityProviderAzureWorkload:
//ADO Resource token
logger.V(1).Info("making request to ADO REST API using managed identity")
aadToken, err := getToken(ctx, metadata, devopsResource)
if err != nil {
return []byte{}, fmt.Errorf("cannot create workload identity credentials: %w", err)
}
logger.V(1).Info("token acquired setting auth header as 'bearer XXXXXX'")
req.Header.Set("Authorization", "Bearer "+aadToken)
}

r, err := httpClient.Do(req)
if err != nil {
Expand Down Expand Up @@ -340,7 +420,7 @@ func (s *azurePipelinesScaler) GetAzurePipelinesQueueLength(ctx context.Context)
} else {
urlString = fmt.Sprintf("%s/_apis/distributedtask/pools/%d/jobrequests?$top=%d", s.metadata.organizationURL, s.metadata.poolID, s.metadata.jobsToFetch)
}
body, err := getAzurePipelineRequest(ctx, urlString, s.metadata, s.httpClient)
body, err := getAzurePipelineRequest(ctx, s.logger, urlString, s.metadata, s.podIdentity, s.httpClient)
if err != nil {
return -1, err
}
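For illustration only, a minimal standalone sketch of the workload identity flow this file now implements: acquire an Entra ID token for the Azure DevOps resource and call the ADO REST API with a Bearer header. It is not part of the commit; the organization and pool names are placeholders, and DefaultAzureCredential stands in for the chained credential KEDA builds via azure.NewChainedCredential.

```go
package main

import (
	"context"
	"fmt"
	"io"
	"net/http"

	"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
	"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
)

func main() {
	ctx := context.Background()

	// DefaultAzureCredential picks up workload identity, managed identity or
	// environment credentials; it stands in here for the chained credential
	// that KEDA builds via azure.NewChainedCredential.
	cred, err := azidentity.NewDefaultAzureCredential(nil)
	if err != nil {
		panic(err)
	}

	// Same scope as the devopsResource constant introduced in this commit.
	tok, err := cred.GetToken(ctx, policy.TokenRequestOptions{
		Scopes: []string{"499b84ac-1321-427f-aa17-267ca6975798/.default"},
	})
	if err != nil {
		panic(err)
	}

	// Placeholder organization and pool name; any ADO REST endpoint works the
	// same way once the Bearer header is set.
	req, err := http.NewRequestWithContext(ctx, http.MethodGet,
		"https://dev.azure.com/my-org/_apis/distributedtask/pools?poolName=my-pool", nil)
	if err != nil {
		panic(err)
	}
	req.Header.Set("Authorization", "Bearer "+tok.Token)

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	body, _ := io.ReadAll(resp.Body)
	fmt.Println(resp.Status, len(body), "bytes")
}
```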
16 changes: 11 additions & 5 deletions pkg/scalers/azure_pipelines_scaler_test.go
@@ -6,6 +6,8 @@ import (
"net/http/httptest"
"strings"
"testing"

"github.com/go-logr/logr"
)

const loadCount = 1000 // the size of the pretend pool completed of job requests
@@ -68,7 +70,9 @@ func TestParseAzurePipelinesMetadata(t *testing.T) {
testData.authParams["organizationURL"] = apiStub.URL
}

_, err := parseAzurePipelinesMetadata(context.TODO(), &ScalerConfig{TriggerMetadata: testData.metadata, ResolvedEnv: testData.resolvedEnv, AuthParams: testData.authParams}, http.DefaultClient)
logger := logr.Discard()

_, _, err := parseAzurePipelinesMetadata(context.TODO(), logger, &ScalerConfig{TriggerMetadata: testData.metadata, ResolvedEnv: testData.resolvedEnv, AuthParams: testData.authParams}, http.DefaultClient)
if err != nil && !testData.isError {
t.Error("Expected success but got error", err)
}
@@ -121,8 +125,8 @@ func TestValidateAzurePipelinesPool(t *testing.T) {
"organizationURL": apiStub.URL,
"personalAccessToken": "PAT",
}

_, err := parseAzurePipelinesMetadata(context.TODO(), &ScalerConfig{TriggerMetadata: testData.metadata, ResolvedEnv: nil, AuthParams: authParams}, http.DefaultClient)
logger := logr.Discard()
_, _, err := parseAzurePipelinesMetadata(context.TODO(), logger, &ScalerConfig{TriggerMetadata: testData.metadata, ResolvedEnv: nil, AuthParams: authParams}, http.DefaultClient)
if err != nil && !testData.isError {
t.Error("Expected success but got error", err)
}
@@ -160,7 +164,9 @@ func TestAzurePipelinesGetMetricSpecForScaling(t *testing.T) {
"targetPipelinesQueueLength": "1",
}

meta, err := parseAzurePipelinesMetadata(context.TODO(), &ScalerConfig{TriggerMetadata: metadata, ResolvedEnv: nil, AuthParams: authParams, TriggerIndex: testData.triggerIndex}, http.DefaultClient)
logger := logr.Discard()

meta, _, err := parseAzurePipelinesMetadata(context.TODO(), logger, &ScalerConfig{TriggerMetadata: metadata, ResolvedEnv: nil, AuthParams: authParams, TriggerIndex: testData.triggerIndex}, http.DefaultClient)
if err != nil {
t.Fatal("Could not parse metadata:", err)
}
@@ -183,7 +189,7 @@ func getMatchedAgentMetaData(url string) *azurePipelinesMetadata {
meta.organizationName = "testOrg"
meta.organizationURL = url
meta.parent = "dotnet60-keda-template"
meta.personalAccessToken = "testPAT"
meta.authContext.pat = "testPAT"
meta.poolID = 1
meta.targetPipelinesQueueLength = 1

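A hypothetical additional unit test (not part of this commit) sketching how the new getAuthMethod could be exercised for the PAT-from-environment path; it only uses fields already present on ScalerConfig in the tests above.

```go
func TestGetAuthMethodResolvesPATFromEnv(t *testing.T) {
	logger := logr.Discard()
	cfg := &ScalerConfig{
		TriggerMetadata: map[string]string{"personalAccessTokenFromEnv": "ADO_PAT"},
		ResolvedEnv:     map[string]string{"ADO_PAT": "testPAT"},
	}

	// With a PAT resolved from the environment, no credential or pod identity
	// provider should be returned.
	pat, cred, podIdentity, err := getAuthMethod(logger, cfg)
	if err != nil {
		t.Fatal("expected PAT resolution to succeed, got error:", err)
	}
	if pat != "testPAT" || cred != nil || podIdentity.Provider != "" {
		t.Errorf("unexpected resolution: pat=%q cred=%v provider=%q", pat, cred, podIdentity.Provider)
	}
}
```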
