Skip to content

Commit

Permalink
Merge remote-tracking branch 'upstream/master' into use-openapi-gener…
Browse files Browse the repository at this point in the history
…ator
  • Loading branch information
Bobgy committed Jun 3, 2020
2 parents 440941e + 5d302b6 commit 3119fbc
Show file tree
Hide file tree
Showing 30 changed files with 151 additions and 64 deletions.
2 changes: 1 addition & 1 deletion backend/api/experiment.proto
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ message ListExperimentsRequest {
// nextPageToken field you can use to fetch the next page.
int32 page_size = 2;

// Can be format of "field_name", "field_name asc" or "field_name des"
// Can be format of "field_name", "field_name asc" or "field_name desc"
// Ascending by default.
string sort_by = 3;

Expand Down
2 changes: 1 addition & 1 deletion backend/api/job.proto
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ message ListJobsRequest {
// to fetch the next page.
int32 page_size = 2;

// Can be format of "field_name", "field_name asc" or "field_name des".
// Can be format of "field_name", "field_name asc" or "field_name desc".
// Ascending by default.
string sort_by = 3;

Expand Down
4 changes: 2 additions & 2 deletions backend/api/pipeline.proto
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ message ListPipelinesRequest {
// nextPageToken field.
int32 page_size = 2;

// Can be format of "field_name", "field_name asc" or "field_name des"
// Can be format of "field_name", "field_name asc" or "field_name desc"
// Ascending by default.
string sort_by = 3;

Expand Down Expand Up @@ -236,7 +236,7 @@ message ListPipelineVersionsRequest {
// ListPipelineVersions call or can be omitted when fetching the first page.
string page_token = 3;

// Can be format of "field_name", "field_name asc" or "field_name des"
// Can be format of "field_name", "field_name asc" or "field_name desc"
// Ascending by default.
string sort_by = 4;
// A base-64 encoded, JSON-serialized Filter protocol buffer (see
Expand Down
4 changes: 2 additions & 2 deletions backend/api/run.proto
Original file line number Diff line number Diff line change
Expand Up @@ -158,8 +158,8 @@ message ListRunsRequest {
// to fetch the next page.
int32 page_size = 2;

// Can be format of "field_name", "field_name asc" or "field_name des"
// (Example, "name asc" or "id des"). Ascending by default.
// Can be format of "field_name", "field_name asc" or "field_name desc"
// (Example, "name asc" or "id desc"). Ascending by default.
string sort_by = 3;

// What resource reference to filter on.
Expand Down
2 changes: 1 addition & 1 deletion backend/src/apiserver/server/pipeline_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ func (s *PipelineServer) CreatePipeline(ctx context.Context, request *api.Create
return nil, util.Wrap(err, "Invalid pipeline name.")
}

pipeline, err := s.resourceManager.CreatePipeline(pipelineName, "", pipelineFile)
pipeline, err := s.resourceManager.CreatePipeline(pipelineName, request.Pipeline.Description, pipelineFile)
if err != nil {
return nil, util.Wrap(err, "Create pipeline failed.")
}
Expand Down
4 changes: 4 additions & 0 deletions backend/src/apiserver/server/pipeline_server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ func TestCreatePipeline_YAML(t *testing.T) {
Pipeline: &api.Pipeline{
Url: &api.Url{PipelineUrl: httpServer.URL + "/arguments-parameters.yaml"},
Name: "argument-parameters",
Description: "pipeline description",
}})

assert.Nil(t, err)
Expand All @@ -41,6 +42,7 @@ func TestCreatePipeline_YAML(t *testing.T) {
err = json.Unmarshal([]byte(newPipeline.Parameters), &params)
assert.Nil(t, err)
assert.Equal(t, []api.Parameter{{Name: "param1", Value: "hello"}, {Name: "param2"}}, params)
assert.Equal(t, "pipeline description", newPipeline.Description)
}

func TestCreatePipeline_Tarball(t *testing.T) {
Expand All @@ -56,6 +58,7 @@ func TestCreatePipeline_Tarball(t *testing.T) {
Pipeline: &api.Pipeline{
Url: &api.Url{PipelineUrl: httpServer.URL + "/arguments_tarball/arguments.tar.gz"},
Name: "argument-parameters",
Description: "pipeline description",
}})

assert.Nil(t, err)
Expand All @@ -68,6 +71,7 @@ func TestCreatePipeline_Tarball(t *testing.T) {
err = json.Unmarshal([]byte(newPipeline.Parameters), &params)
assert.Nil(t, err)
assert.Equal(t, []api.Parameter{{Name: "param1", Value: "hello"}, {Name: "param2"}}, params)
assert.Equal(t, "pipeline description", newPipeline.Description)
}

func TestCreatePipeline_InvalidYAML(t *testing.T) {
Expand Down
19 changes: 13 additions & 6 deletions backend/src/apiserver/server/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,18 @@ func CheckPipelineVersionReference(resourceManager *resource.ResourceManager, re
return &pipelineVersionId, nil
}

func getUserIdentityFromHeader(userIdentityHeader, prefix string) (string, error) {
if len(userIdentityHeader) > len(prefix) && userIdentityHeader[:len(prefix)] == prefix {
return userIdentityHeader[len(prefix):], nil
}
return "", util.NewBadRequestError(
errors.New("Request header error: user identity value is incorrectly formatted"),
"Request header error: user identity value is incorrectly formatted. Expected prefix '%s', but got the header '%s'",
prefix,
userIdentityHeader,
)
}

func getUserIdentity(ctx context.Context) (string, error) {
if ctx == nil {
return "", util.NewBadRequestError(errors.New("Request error: context is nil"), "Request error: context is nil.")
Expand All @@ -287,12 +299,7 @@ func getUserIdentity(ctx context.Context) (string, error) {
return "", util.NewBadRequestError(errors.New("Request header error: unexpected number of user identity header. Expect 1 got "+strconv.Itoa(len(userIdentityHeader))),
"Request header error: unexpected number of user identity header. Expect 1 got "+strconv.Itoa(len(userIdentityHeader)))
}
userIdentityHeaderFields := strings.Split(userIdentityHeader[0], ":")
if len(userIdentityHeaderFields) != 2 {
return "", util.NewBadRequestError(errors.New("Request header error: user identity value is incorrectly formatted"),
"Request header error: user identity value is incorrectly formatted")
}
return userIdentityHeaderFields[1], nil
return getUserIdentityFromHeader(userIdentityHeader[0], common.GetKubeflowUserIDPrefix())
}
return "", util.NewBadRequestError(errors.New("Request header error: there is no user identity header."), "Request header error: there is no user identity header.")
}
Expand Down
28 changes: 28 additions & 0 deletions backend/src/apiserver/server/util_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,34 @@ func TestGetUserIdentity(t *testing.T) {
assert.Equal(t, "[email protected]", userIdentity)
}

func TestGetUserIdentityError(t *testing.T) {
md := metadata.New(map[string]string{"no-identity-header": "user"})
ctx := metadata.NewIncomingContext(context.Background(), md)
_, err := getUserIdentity(ctx)
assert.NotNil(t, err)
assert.Contains(t, err.Error(), "Request header error: there is no user identity header.")
}

func TestGetUserIdentityFromHeaderGoogle(t *testing.T) {
userIdentity, err := getUserIdentityFromHeader(common.GoogleIAPUserIdentityPrefix+"[email protected]", common.GoogleIAPUserIdentityPrefix)
assert.Nil(t, err)
assert.Equal(t, "[email protected]", userIdentity)
}

func TestGetUserIdentityFromHeaderNonGoogle(t *testing.T) {
prefix := ""
userIdentity, err := getUserIdentityFromHeader(prefix+"user", prefix)
assert.Nil(t, err)
assert.Equal(t, "user", userIdentity)
}

func TestGetUserIdentityFromHeaderError(t *testing.T) {
prefix := "expected-prefix"
_, err := getUserIdentityFromHeader("user", prefix)
assert.NotNil(t, err)
assert.Contains(t, err.Error(), "Request header error: user identity value is incorrectly formatted")
}

func TestCanAccessNamespaceInResourceReferences_Unauthorized(t *testing.T) {
viper.Set(common.MultiUserMode, "true")
defer viper.Set(common.MultiUserMode, "false")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,12 @@ def test_workteamjob(
)

outputs = {"sagemaker-private-workforce": ["workteam_arn"]}
output_files = minio_utils.artifact_download_iterator(
workflow_json, outputs, download_dir
)

try:
try:
output_files = minio_utils.artifact_download_iterator(
workflow_json, outputs, download_dir
)

response = sagemaker_utils.describe_workteam(sagemaker_client, workteam_name)

# Verify WorkTeam was created in SageMaker
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,17 @@
from ._client import MLEngineClient
from .. import common as gcp_common

def create_job(project_id, job, job_id_prefix=None, wait_interval=30):
def create_job(project_id, job, job_id_prefix=None, job_id=None,
wait_interval=30):
"""Creates a MLEngine job.
Args:
project_id: the ID of the parent project of the job.
job: the payload of the job. Must have ``jobId``
and ``trainingInput`` or ``predictionInput`.
job_id_prefix: the prefix of the generated job id.
job_id: the created job_id, takes precedence over generated job
id if set.
wait_interval: optional wait interval between calls
to get job status. Defaults to 30.
Expand All @@ -42,15 +45,16 @@ def create_job(project_id, job, job_id_prefix=None, wait_interval=30):
/tmp/kfp/output/ml_engine/job_id.txt: The ID of the created job.
/tmp/kfp/output/ml_engine/job_dir.txt: The `jobDir` of the training job.
"""
return CreateJobOp(project_id, job, job_id_prefix,
wait_interval).execute_and_wait()
return CreateJobOp(project_id, job, job_id_prefix, job_id, wait_interval
).execute_and_wait()

class CreateJobOp:
def __init__(self, project_id, job, job_id_prefix=None, wait_interval=30):
def __init__(self,project_id, job, job_id_prefix=None, job_id=None,
wait_interval=30):
self._ml = MLEngineClient()
self._project_id = project_id
self._job_id_prefix = job_id_prefix
self._job_id = None
self._job_id = job_id
self._job = job
self._wait_interval = wait_interval

Expand All @@ -61,7 +65,9 @@ def execute_and_wait(self):
return wait_for_job_done(self._ml, self._project_id, self._job_id, self._wait_interval)

def _set_job_id(self, context_id):
if self._job_id_prefix:
if self._job_id:
job_id = self._job_id
elif self._job_id_prefix:
job_id = self._job_id_prefix + context_id[:16]
else:
job_id = 'job_' + context_id
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
def train(project_id, python_module=None, package_uris=None,
region=None, args=None, job_dir=None, python_version=None,
runtime_version=None, master_image_uri=None, worker_image_uri=None,
training_input=None, job_id_prefix=None, wait_interval=30):
training_input=None, job_id_prefix=None, job_id=None, wait_interval=30):
"""Creates a MLEngine training job.
Args:
Expand Down Expand Up @@ -50,6 +50,8 @@ def train(project_id, python_module=None, package_uris=None,
This image must be in Container Registry.
training_input (dict): Input parameters to create a training job.
job_id_prefix (str): the prefix of the generated job id.
job_id (str): the created job_id, takes precedence over generated job
id if set.
wait_interval (int): optional wait interval between calls
to get job status. Defaults to 30.
"""
Expand Down Expand Up @@ -80,4 +82,4 @@ def train(project_id, python_module=None, package_uris=None,
job = {
'trainingInput': training_input
}
return create_job(project_id, job, job_id_prefix, wait_interval)
return create_job(project_id, job, job_id_prefix, job_id, wait_interval)
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,27 @@ def test_create_job_with_job_id_prefix_succeed(self, mock_mlengine_client,
'jobId': 'mock_job_ctx1'
}
)

def test_create_job_with_job_id_succeed(self, mock_mlengine_client,
mock_kfp_context, mock_dump_json, mock_display):
mock_kfp_context().__enter__().context_id.return_value = 'ctx1'
job = {}
returned_job = {
'jobId': 'mock_job',
'state': 'SUCCEEDED'
}
mock_mlengine_client().get_job.return_value = (
returned_job)

result = create_job('mock_project', job, job_id='mock_job')

self.assertEqual(returned_job, result)
mock_mlengine_client().create_job.assert_called_with(
project_id = 'mock_project',
job = {
'jobId': 'mock_job'
}
)

def test_execute_retry_job_success(self, mock_mlengine_client,
mock_kfp_context, mock_dump_json, mock_display):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,16 @@
CREATE_JOB_MODULE = 'kfp_component.google.ml_engine._train'

@mock.patch(CREATE_JOB_MODULE + '.create_job')
class TestCreateTraingingJob(unittest.TestCase):
class TestCreateTrainingJob(unittest.TestCase):

def test_train_succeed(self, mock_create_job):
train('proj-1', 'mock.module', ['gs://test/package'],
'region-1', args=['arg-1', 'arg-2'], job_dir='gs://test/job/dir',
training_input={
'runtimeVersion': '1.10',
'pythonVersion': '2.7'
}, job_id_prefix='job-', master_image_uri='tensorflow:latest',
}, job_id_prefix='job-', job_id='job-1',
master_image_uri='tensorflow:latest',
worker_image_uri='debian:latest')

mock_create_job.assert_called_with('proj-1', {
Expand All @@ -48,4 +49,4 @@ def test_train_succeed(self, mock_create_job):
'imageUri': 'debian:latest'
}
}
}, 'job-', 30)
}, 'job-', 'job-1', 30)
3 changes: 3 additions & 0 deletions components/gcp/ml_engine/train/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ Use this component to submit a training job to AI Platform from a Kubeflow pipel
| worker_image_uri | The Docker image to run on the worker replica. This image must be in Container Registry. | Yes | GCRPath |- | None |
| training_input | The input parameters to create a training job. | Yes | Dict | [TrainingInput](https://cloud.google.com/ml-engine/reference/rest/v1/projects.jobs#TrainingInput) | None |
| job_id_prefix | The prefix of the job ID that is generated. | Yes | String | - | None |
| job_id | The ID of the job to create, takes precedence over generated job id if set. | Yes | String | - | None |
| wait_interval | The number of seconds to wait between API calls to get the status of the job. | Yes | Integer | - | 30 |


Expand Down Expand Up @@ -179,6 +180,7 @@ def pipeline(
worker_image_uri = '',
training_input = '',
job_id_prefix = '',
job_id = '',
wait_interval = '30'):
task = mlengine_train_op(
project_id=project_id,
Expand All @@ -193,6 +195,7 @@ def pipeline(
worker_image_uri=worker_image_uri,
training_input=training_input,
job_id_prefix=job_id_prefix,
job_id=job_id,
wait_interval=wait_interval)
```

Expand Down
7 changes: 7 additions & 0 deletions components/gcp/ml_engine/train/component.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,12 @@ inputs:
description: 'The prefix of the generated job id.'
default: ''
type: String
- name: job_id
description: >-
The ID of the job to create, takes precedence over generated
job id if set.
default: ''
type: String
- name: wait_interval
description: >-
Optional. A time-interval to wait for between calls to get the job status.
Expand Down Expand Up @@ -119,6 +125,7 @@ implementation:
--worker_image_uri, {inputValue: worker_image_uri},
--training_input, {inputValue: training_input},
--job_id_prefix, {inputValue: job_id_prefix},
--job_id, {inputValue: job_id},
--wait_interval, {inputValue: wait_interval},
]
env:
Expand Down
7 changes: 5 additions & 2 deletions components/gcp/ml_engine/train/sample.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
"| worker_image_uri | The Docker image to run on the worker replica. This image must be in Container Registry. | Yes | GCRPath | | None |\n",
"| training_input | The input parameters to create a training job. | Yes | Dict | [TrainingInput](https://cloud.google.com/ml-engine/reference/rest/v1/projects.jobs#TrainingInput) | None |\n",
"| job_id_prefix | The prefix of the job ID that is generated. | Yes | String | | None |\n",
"| job_id | The ID of the job to create, takes precedence over generated job id if set. | Yes | String | - | None |\n",
"| wait_interval | The number of seconds to wait between API calls to get the status of the job. | Yes | Integer | | 30 |\n",
"\n",
"\n",
Expand Down Expand Up @@ -238,6 +239,7 @@
" worker_image_uri = '',\n",
" training_input = '',\n",
" job_id_prefix = '',\n",
" job_id = '',\n",
" wait_interval = '30'):\n",
" task = mlengine_train_op(\n",
" project_id=project_id, \n",
Expand All @@ -251,7 +253,8 @@
" master_image_uri=master_image_uri, \n",
" worker_image_uri=worker_image_uri, \n",
" training_input=training_input, \n",
" job_id_prefix=job_id_prefix, \n",
" job_id_prefix=job_id_prefix,\n",
" job_id=job_id,\n",
" wait_interval=wait_interval)"
]
},
Expand Down Expand Up @@ -354,4 +357,4 @@
},
"nbformat": 4,
"nbformat_minor": 2
}
}
2 changes: 2 additions & 0 deletions components/pipeline_component_repository.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# A marker file that marks the location of a repository of Kubeflow Pipelines components
# This marker file makes it easier to find public repositories online
Original file line number Diff line number Diff line change
Expand Up @@ -147,9 +147,9 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.7"
"version": "3.7.4"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
}
Loading

0 comments on commit 3119fbc

Please sign in to comment.