diff --git a/backend/src/apiserver/resource/resource_manager.go b/backend/src/apiserver/resource/resource_manager.go index e6d65248d0e..201f19d5e9a 100644 --- a/backend/src/apiserver/resource/resource_manager.go +++ b/backend/src/apiserver/resource/resource_manager.go @@ -1967,7 +1967,40 @@ func (r *ResourceManager) GetArtifactSessionInfo(ctx context.Context, artifact * } // Retrieve Session info - sessionInfoString := artifactCtx.CustomProperties["store_session_info"].GetStringValue() + storeSessionInfo, ok_session := artifactCtx.CustomProperties["store_session_info"] + + var sessionInfoString = "" + + if ok_session { + sessionInfoString = storeSessionInfo.GetStringValue() + } else { + // bucket_session_info is an old struct that needs to be converted to store_session_info + bucketSession := &objectstore.S3Params{} + err1 := json.Unmarshal([]byte(artifactCtx.CustomProperties["bucket_session_info"].GetStringValue()), bucketSession) + if err1 != nil { + return nil, "", err1 + } + sessionInfoParams := &map[string]string{ + "fromEnv": "false", + "endpoint": bucketSession.Endpoint, + "region": bucketSession.Region, + "disableSSL": strconv.FormatBool(bucketSession.DisableSSL), + "secretName": bucketSession.SecretName, + "accessKeyKey": bucketSession.AccessKeyKey, + "secretKeyKey": bucketSession.SecretKeyKey, + } + + sessionInfo := &objectstore.SessionInfo{ + Provider: "s3", + Params: *sessionInfoParams, + } + sessionInfoBytes, err2 := json.Marshal(*sessionInfo) + if err2 != nil { + return nil, "", err2 + } + sessionInfoString = string(sessionInfoBytes) + } + if sessionInfoString == "" { return nil, "", fmt.Errorf("Unable to retrieve artifact session info via context property.") } diff --git a/backend/src/apiserver/resource/resource_manager_test.go b/backend/src/apiserver/resource/resource_manager_test.go index 3ff75028d22..5049a6b2bae 100644 --- a/backend/src/apiserver/resource/resource_manager_test.go +++ b/backend/src/apiserver/resource/resource_manager_test.go @@ -18,11 +18,13 @@ import ( "context" "encoding/json" "fmt" - "github.com/kubeflow/pipelines/backend/src/v2/objectstore" "strings" "testing" "time" + "github.com/kubeflow/pipelines/backend/src/v2/objectstore" + "github.com/kubeflow/pipelines/third_party/ml-metadata/go/ml_metadata" + "github.com/argoproj/argo-workflows/v3/pkg/apis/workflow/v1alpha1" "github.com/argoproj/argo-workflows/v3/util/file" "github.com/kubeflow/pipelines/backend/src/apiserver/client" @@ -43,6 +45,14 @@ import ( "k8s.io/apimachinery/pkg/types" ) +func intPtr(i int64) *int64 { + return &i +} + +func strPtr(i string) *string { + return &i +} + func initEnvVars() { viper.Set(common.PodNamespace, "ns1") } @@ -4020,6 +4030,35 @@ func TestCreateTask(t *testing.T) { assert.Equal(t, expectedTask, storedTask, "The StoredTask return has unexpected value") } +func TestBackwardsCompatibilityForSessionInfo(t *testing.T) { + _, manager, _, _, _, _ := initWithExperimentAndPipelineAndRun(t) + + // First Artifact has assigned a bucket_session_info + artifact1 := &ml_metadata.Artifact{ + Id: intPtr(0), + Uri: strPtr("s3://test-bucket/pipeline/some-pipeline-id/task/key0"), + } + + config1, _, err := manager.GetArtifactSessionInfo(context.Background(), artifact1) + + // Assert the results + assert.NoError(t, err) + assert.NotNil(t, config1) + + // Second Artifact has assigned a store_session_info + artifact2 := &ml_metadata.Artifact{ + Id: intPtr(1), + Uri: strPtr("s3://test-bucket/pipeline/some-pipeline-id/task/key1"), + } + + // Call the function + config2, _, err := manager.GetArtifactSessionInfo(context.Background(), artifact2) + + // Assert the results + assert.NoError(t, err) + assert.NotNil(t, config2) +} + var v2SpecHelloWorld = ` components: comp-hello-world: diff --git a/backend/src/v2/metadata/client_fake.go b/backend/src/v2/metadata/client_fake.go index 7503ce2929a..286745189a1 100644 --- a/backend/src/v2/metadata/client_fake.go +++ b/backend/src/v2/metadata/client_fake.go @@ -174,25 +174,52 @@ func (c *FakeClient) createDummyData() { AccessKeyKey: "testsecretaccesskey", SecretKeyKey: "testsecretsecretkey", } - storeSessionInfo, err := json.Marshal(ctx1SessInfo) + bucketSessionInfo, err := json.Marshal(ctx1SessInfo) if err != nil { glog.Fatal("failed to marshal fake session info") } + ctx2SessInfo := map[string]string{ + "Region": "test2", + "Endpoint": "test2.endpoint2", + "DisableSSL": "false", + "SecretName": "testsecret2", + "AccessKeyKey": "testsecretaccesskey2", + "SecretKeyKey": "testsecretsecretkey2", + "FromEnv": "false", + } + sessInfo := &objectstore.SessionInfo{ + Provider: "s3", + Params: ctx2SessInfo, + } + storeSessionInfo2, err1 := json.Marshal(sessInfo) + if err1 != nil { + glog.Fatal("failed to marshal fake session info") + } ctx1 := &pb.Context{ Id: intPtr(0), Name: strPtr("ctx-0"), Type: strPtr("1"), + CustomProperties: map[string]*pb.Value{ + "pipeline_root": stringValue("s3://test-bucket"), + "bucket_session_info": stringValue(string(bucketSessionInfo)), + "namespace": stringValue("test-namespace"), + }, + } + ctx2 := &pb.Context{ + Id: intPtr(1), + Name: strPtr("ctx-1"), + Type: strPtr("1"), CustomProperties: map[string]*pb.Value{ "pipeline_root": stringValue("s3://test-bucket"), - "store_session_info": stringValue(string(storeSessionInfo)), + "store_session_info": stringValue(string(storeSessionInfo2)), "namespace": stringValue("test-namespace"), }, } - c.contexts = []*pb.Context{ctx1} + c.contexts = []*pb.Context{ctx1, ctx2} c.artifacts = []*pb.Artifact{art1, art2} c.artifactIdsToContext = map[int64]*pb.Context{ *art1.Id: ctx1, - *art2.Id: ctx1, + *art2.Id: ctx2, } }