Skip to content

Commit

Permalink
feat(backend): enforce SA Token based auth b/w Persistence Agent and …
Browse files Browse the repository at this point in the history
…Pipeline API Server (#9957)

* Enforece SA-Toben auth b/n Persistence agent & Pipeline server for all reqs

Signed-off-by: Diana Atanasova <[email protected]>

* Fix persistence agent license file

Signed-off-by: Diana Atanasova <[email protected]>

---------

Signed-off-by: Diana Atanasova <[email protected]>
  • Loading branch information
difince authored Sep 7, 2023
1 parent 3b8cea0 commit 760c158
Show file tree
Hide file tree
Showing 21 changed files with 81 additions and 379 deletions.
85 changes: 0 additions & 85 deletions backend/src/agent/persistence/client/fake_namespace.go

This file was deleted.

87 changes: 0 additions & 87 deletions backend/src/agent/persistence/client/kubernetes_core.go

This file was deleted.

37 changes: 0 additions & 37 deletions backend/src/agent/persistence/client/kubernetes_core_fake.go

This file was deleted.

60 changes: 30 additions & 30 deletions backend/src/agent/persistence/client/pipeline_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,9 @@ package client
import (
"context"
"fmt"
"os"
"strings"
"time"

"github.com/kubeflow/pipelines/backend/src/apiserver/common"
"google.golang.org/grpc/metadata"

api "github.com/kubeflow/pipelines/backend/api/v1beta1/go_client"
Expand All @@ -38,8 +36,8 @@ const (
type PipelineClientInterface interface {
ReportWorkflow(workflow util.ExecutionSpec) error
ReportScheduledWorkflow(swf *util.ScheduledWorkflow) error
ReadArtifact(request *api.ReadArtifactRequest, user string) (*api.ReadArtifactResponse, error)
ReportRunMetrics(request *api.ReportRunMetricsRequest, user string) (*api.ReportRunMetricsResponse, error)
ReadArtifact(request *api.ReadArtifactRequest) (*api.ReadArtifactResponse, error)
ReportRunMetrics(request *api.ReportRunMetricsRequest) (*api.ReportRunMetricsResponse, error)
}

type PipelineClient struct {
Expand Down Expand Up @@ -173,17 +171,26 @@ func (p *PipelineClient) ReportScheduledWorkflow(swf *util.ScheduledWorkflow) er

// ReadArtifact reads artifact content from run service. If the artifact is not present, returns
// nil response.
func (p *PipelineClient) ReadArtifact(request *api.ReadArtifactRequest, user string) (*api.ReadArtifactResponse, error) {
func (p *PipelineClient) ReadArtifact(request *api.ReadArtifactRequest) (*api.ReadArtifactResponse, error) {
pctx := context.Background()
if user != "" {
pctx = metadata.AppendToOutgoingContext(pctx, getKubeflowUserIDHeader(),
getKubeflowUserIDPrefix()+user)
}
pctx = metadata.AppendToOutgoingContext(pctx, "Authorization",
"Bearer "+p.tokenRefresher.GetToken())

ctx, cancel := context.WithTimeout(pctx, time.Minute)
defer cancel()

response, err := p.runServiceClient.ReadArtifactV1(ctx, request)
if err != nil {
statusCode, _ := status.FromError(err)
if statusCode.Code() == codes.Unauthenticated && strings.Contains(err.Error(), "service account token has expired") {
// If unauthenticated because SA token is expired, re-read/refresh the token and try again
p.tokenRefresher.RefreshToken()
return nil, util.NewCustomError(err, util.CUSTOM_CODE_TRANSIENT,
"Error while reporting workflow resource (code: %v, message: %v): %v",
statusCode.Code(),
statusCode.Message(),
err.Error())
}
// TODO(hongyes): check NotFound error code before skip the error.
return nil, nil
}
Expand All @@ -192,37 +199,30 @@ func (p *PipelineClient) ReadArtifact(request *api.ReadArtifactRequest, user str
}

// ReportRunMetrics reports run metrics to run service.
func (p *PipelineClient) ReportRunMetrics(request *api.ReportRunMetricsRequest, user string) (*api.ReportRunMetricsResponse, error) {
func (p *PipelineClient) ReportRunMetrics(request *api.ReportRunMetricsRequest) (*api.ReportRunMetricsResponse, error) {
pctx := context.Background()
if user != "" {
pctx = metadata.AppendToOutgoingContext(pctx, getKubeflowUserIDHeader(),
getKubeflowUserIDPrefix()+user)
}
pctx = metadata.AppendToOutgoingContext(pctx, "Authorization",
"Bearer "+p.tokenRefresher.GetToken())

ctx, cancel := context.WithTimeout(pctx, time.Minute)
defer cancel()

response, err := p.runServiceClient.ReportRunMetricsV1(ctx, request)
if err != nil {
statusCode, _ := status.FromError(err)
if statusCode.Code() == codes.Unauthenticated && strings.Contains(err.Error(), "service account token has expired") {
// If unauthenticated because SA token is expired, re-read/refresh the token and try again
p.tokenRefresher.RefreshToken()
return nil, util.NewCustomError(err, util.CUSTOM_CODE_TRANSIENT,
"Error while reporting workflow resource (code: %v, message: %v): %v",
statusCode.Code(),
statusCode.Message(),
err.Error())
}
// This call should always succeed unless the run doesn't exist or server is broken. In
// either cases, the job should retry at a later time.
return nil, util.NewCustomError(err, util.CUSTOM_CODE_TRANSIENT,
"Error while reporting metrics (%+v): %+v", request, err)
}
return response, nil
}

// TODO use config file & viper and "github.com/kubeflow/pipelines/backend/src/apiserver/common.GetKubeflowUserIDHeader()"
func getKubeflowUserIDHeader() string {
if value, ok := os.LookupEnv(common.KubeflowUserIDHeader); ok {
return value
}
return common.GoogleIAPUserIdentityHeader
}

// TODO use of viper & viper and "github.com/kubeflow/pipelines/backend/src/apiserver/common.GetKubeflowUserIDPrefix()"
func getKubeflowUserIDPrefix() string {
if value, ok := os.LookupEnv(common.KubeflowUserIDPrefix); ok {
return value
}
return common.GoogleIAPUserIdentityPrefix
}
4 changes: 2 additions & 2 deletions backend/src/agent/persistence/client/pipeline_client_fake.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,15 +57,15 @@ func (p *PipelineClientFake) ReportScheduledWorkflow(swf *util.ScheduledWorkflow
return nil
}

func (p *PipelineClientFake) ReadArtifact(request *api.ReadArtifactRequest, user string) (*api.ReadArtifactResponse, error) {
func (p *PipelineClientFake) ReadArtifact(request *api.ReadArtifactRequest) (*api.ReadArtifactResponse, error) {
if p.err != nil {
return nil, p.err
}
p.readArtifactRequest = request
return p.artifacts[request.String()], nil
}

func (p *PipelineClientFake) ReportRunMetrics(request *api.ReportRunMetricsRequest, user string) (*api.ReportRunMetricsResponse, error) {
func (p *PipelineClientFake) ReportRunMetrics(request *api.ReportRunMetricsRequest) (*api.ReportRunMetricsResponse, error) {
p.reportedMetricsRequest = request
return p.reportMetricsResponseStub, p.reportMetricsErrorStub
}
Expand Down
5 changes: 0 additions & 5 deletions backend/src/agent/persistence/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,10 +95,6 @@ func main() {
} else {
swfInformerFactory = swfinformers.NewFilteredSharedInformerFactory(swfClient, time.Second*30, namespace, nil)
}
k8sCoreClient := client.CreateKubernetesCoreOrFatal(DefaultConnectionTimeout, util.ClientParameters{
QPS: clientQPS,
Burst: clientBurst,
})

tokenRefresher := client.NewTokenRefresher(time.Duration(saTokenRefreshIntervalInSecs)*time.Second, nil)
err = tokenRefresher.StartTokenRefreshTicker()
Expand All @@ -122,7 +118,6 @@ func main() {
swfInformerFactory,
execInformer,
pipelineClient,
k8sCoreClient,
util.NewRealTime())

go swfInformerFactory.Start(stopCh)
Expand Down
3 changes: 1 addition & 2 deletions backend/src/agent/persistence/persistence_agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ func NewPersistenceAgent(
swfInformerFactory swfinformers.SharedInformerFactory,
execInformer util.ExecutionInformer,
pipelineClient *client.PipelineClient,
k8sCoreClient client.KubernetesCoreInterface,
time util.TimeInterface) *PersistenceAgent {
// obtain references to shared informers
swfInformer := swfInformerFactory.Scheduledworkflow().V1beta1().ScheduledWorkflows()
Expand All @@ -63,7 +62,7 @@ func NewPersistenceAgent(

workflowWorker := worker.NewPersistenceWorker(time, workflowregister.WorkflowKind,
execInformer, true,
worker.NewWorkflowSaver(workflowClient, pipelineClient, k8sCoreClient, ttlSecondsAfterWorkflowFinish))
worker.NewWorkflowSaver(workflowClient, pipelineClient, ttlSecondsAfterWorkflowFinish))

agent := &PersistenceAgent{
swfClient: swfClient,
Expand Down
6 changes: 3 additions & 3 deletions backend/src/agent/persistence/worker/metrics_reporter.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ func NewMetricsReporter(pipelineClient client.PipelineClientInterface) *MetricsR
}

// ReportMetrics reports workflow metrics to pipeline server.
func (r MetricsReporter) ReportMetrics(workflow util.ExecutionSpec, user string) error {
func (r MetricsReporter) ReportMetrics(workflow util.ExecutionSpec) error {
if !workflow.ExecutionStatus().HasMetrics() {
return nil
}
Expand All @@ -52,14 +52,14 @@ func (r MetricsReporter) ReportMetrics(workflow util.ExecutionSpec, user string)
// Skip reporting if the workflow doesn't have the run id label
return nil
}
runMetrics, partialFailures := workflow.ExecutionStatus().CollectionMetrics(r.pipelineClient.ReadArtifact, user)
runMetrics, partialFailures := workflow.ExecutionStatus().CollectionMetrics(r.pipelineClient.ReadArtifact)
if len(runMetrics) == 0 {
return aggregateErrors(partialFailures)
}
reportMetricsResponse, err := r.pipelineClient.ReportRunMetrics(&api.ReportRunMetricsRequest{
RunId: runID,
Metrics: runMetrics,
}, user)
})
if err != nil {
return err
}
Expand Down
Loading

0 comments on commit 760c158

Please sign in to comment.