diff --git a/google/cloud/aiplatform/initializer.py b/google/cloud/aiplatform/initializer.py
index 95287a32da..b618c63677 100644
--- a/google/cloud/aiplatform/initializer.py
+++ b/google/cloud/aiplatform/initializer.py
@@ -98,6 +98,7 @@ def __init__(self):
         self._credentials = None
         self._encryption_spec_key_name = None
         self._network = None
+        self._service_account = None
 
     def init(
         self,
@@ -113,6 +114,7 @@ def init(
         credentials: Optional[auth_credentials.Credentials] = None,
         encryption_spec_key_name: Optional[str] = None,
         network: Optional[str] = None,
+        service_account: Optional[str] = None,
     ):
         """Updates common initialization parameters with provided options.
 
@@ -155,6 +157,12 @@ def init(
                 Private services access must already be configured for the network.
                 If specified, all eligible jobs and resources created will be peered
                 with this VPC.
+            service_account (str):
+                Optional. The service account used to launch jobs and deploy models.
+                Jobs that use service_account: BatchPredictionJob, CustomJob,
+                PipelineJob, HyperparameterTuningJob, CustomTrainingJob,
+                CustomPythonPackageTrainingJob, CustomContainerTrainingJob,
+                ModelEvaluationJob.
         Raises:
             ValueError:
                 If experiment_description is provided but experiment is not.
@@ -194,6 +202,8 @@ def init(
             self._encryption_spec_key_name = encryption_spec_key_name
         if network is not None:
             self._network = network
+        if service_account is not None:
+            self._service_account = service_account
 
         if experiment:
             metadata._experiment_tracker.set_experiment(
@@ -297,6 +307,11 @@ def network(self) -> Optional[str]:
         """Default Compute Engine network to peer to, if provided."""
         return self._network
 
+    @property
+    def service_account(self) -> Optional[str]:
+        """Default service account, if provided."""
+        return self._service_account
+
     @property
     def experiment_name(self) -> Optional[str]:
         """Default experiment name, if provided."""
diff --git a/google/cloud/aiplatform/jobs.py b/google/cloud/aiplatform/jobs.py
index 90e7a8471f..6c0afcec27 100644
--- a/google/cloud/aiplatform/jobs.py
+++ b/google/cloud/aiplatform/jobs.py
@@ -761,6 +761,7 @@ def create(
             )
             gapic_batch_prediction_job.explanation_spec = explanation_spec
 
+        service_account = service_account or initializer.global_config.service_account
         if service_account:
             gapic_batch_prediction_job.service_account = service_account
 
@@ -1693,6 +1694,7 @@ def run(
                 `restart_job_on_worker_restart` to False.
         """
         network = network or initializer.global_config.network
+        service_account = service_account or initializer.global_config.service_account
 
         self._run(
             service_account=service_account,
@@ -1880,6 +1882,8 @@ def submit(
             raise ValueError(
                 "'experiment' is required since you've enabled autolog in 'from_local_script'."
             )
+
+        service_account = service_account or initializer.global_config.service_account
         if service_account:
             self._gca_resource.job_spec.service_account = service_account
 
@@ -2356,6 +2360,7 @@ def run(
                 `restart_job_on_worker_restart` to False.
         """
         network = network or initializer.global_config.network
+        service_account = service_account or initializer.global_config.service_account
 
         self._run(
             service_account=service_account,
diff --git a/google/cloud/aiplatform/model_evaluation/model_evaluation_job.py b/google/cloud/aiplatform/model_evaluation/model_evaluation_job.py
index f8ee2b9e21..bf19d8009b 100644
--- a/google/cloud/aiplatform/model_evaluation/model_evaluation_job.py
+++ b/google/cloud/aiplatform/model_evaluation/model_evaluation_job.py
@@ -278,6 +278,7 @@ def submit(
         Returns:
             (ModelEvaluationJob): Instantiated represnetation of the model evaluation job.
         """
+        service_account = service_account or initializer.global_config.service_account
 
         if isinstance(model_name, aiplatform.Model):
             model_resource_name = model_name.versioned_resource_name
diff --git a/google/cloud/aiplatform/models.py b/google/cloud/aiplatform/models.py
index 4694bf2250..2a755a5288 100644
--- a/google/cloud/aiplatform/models.py
+++ b/google/cloud/aiplatform/models.py
@@ -1096,6 +1096,7 @@ def _deploy_call(
                 to the resource project.
                 Users deploying the Model must have the `iam.serviceAccounts.actAs`
                 permission on this service account.
+                If not specified, uses the service account set in aiplatform.init.
             explanation_spec (aiplatform.explain.ExplanationSpec):
                 Optional. Specification of Model explanation.
             metadata (Sequence[Tuple[str, str]]):
@@ -1120,6 +1121,8 @@ def _deploy_call(
                 is not 0 or 100.
         """
 
+        service_account = service_account or initializer.global_config.service_account
+
         max_replica_count = max(min_replica_count, max_replica_count)
 
         if bool(accelerator_type) != bool(accelerator_count):
diff --git a/google/cloud/aiplatform/pipeline_job_schedules.py b/google/cloud/aiplatform/pipeline_job_schedules.py
index 1e66fd8809..e07ca80b9c 100644
--- a/google/cloud/aiplatform/pipeline_job_schedules.py
+++ b/google/cloud/aiplatform/pipeline_job_schedules.py
@@ -226,6 +226,7 @@ def _create(
         if max_concurrent_run_count:
             self._gca_resource.max_concurrent_run_count = max_concurrent_run_count
 
+        service_account = service_account or initializer.global_config.service_account
         network = network or initializer.global_config.network
 
         if service_account:
diff --git a/google/cloud/aiplatform/pipeline_jobs.py b/google/cloud/aiplatform/pipeline_jobs.py
index 47c33f79b5..378eff6017 100644
--- a/google/cloud/aiplatform/pipeline_jobs.py
+++ b/google/cloud/aiplatform/pipeline_jobs.py
@@ -382,6 +382,7 @@ def submit(
                 current Experiment Run.
         """
         network = network or initializer.global_config.network
+        service_account = service_account or initializer.global_config.service_account
 
         if service_account:
             self._gca_resource.service_account = service_account
diff --git a/google/cloud/aiplatform/training_jobs.py b/google/cloud/aiplatform/training_jobs.py
index 7af003d185..268ab8fdf4 100644
--- a/google/cloud/aiplatform/training_jobs.py
+++ b/google/cloud/aiplatform/training_jobs.py
@@ -3223,6 +3223,7 @@ def run(
                 produce a Vertex AI Model.
         """
         network = network or initializer.global_config.network
+        service_account = service_account or initializer.global_config.service_account
 
         worker_pool_specs, managed_model = self._prepare_and_validate_run(
             model_display_name=model_display_name,
@@ -4579,6 +4580,7 @@ def run(
                 were not provided in constructor.
         """
         network = network or initializer.global_config.network
+        service_account = service_account or initializer.global_config.service_account
 
         worker_pool_specs, managed_model = self._prepare_and_validate_run(
             model_display_name=model_display_name,
@@ -7348,6 +7350,7 @@ def run(
             service_account (str):
                 Specifies the service account for workload run-as account.
                 Users submitting jobs must have act-as permission on this run-as account.
+                If not specified, uses the service account set in aiplatform.init.
             network (str):
                 The full name of the Compute Engine network to which the job
                 should be peered. For example, projects/12345/global/networks/myVPC.
@@ -7501,6 +7504,7 @@ def run(
                 produce a Vertex AI Model.
         """
         network = network or initializer.global_config.network
+        service_account = service_account or initializer.global_config.service_account
 
         worker_pool_specs, managed_model = self._prepare_and_validate_run(
             model_display_name=model_display_name,
diff --git a/google/cloud/aiplatform/utils/gcs_utils.py b/google/cloud/aiplatform/utils/gcs_utils.py
index 486fddd152..13e4eb3cfc 100644
--- a/google/cloud/aiplatform/utils/gcs_utils.py
+++ b/google/cloud/aiplatform/utils/gcs_utils.py
@@ -217,6 +217,7 @@ def create_gcs_bucket_for_pipeline_artifacts_if_it_does_not_exist(
     """
     project = project or initializer.global_config.project
    location = location or initializer.global_config.location
+    service_account = service_account or initializer.global_config.service_account
     credentials = credentials or initializer.global_config.credentials
 
     output_artifacts_gcs_dir = (
diff --git a/samples/model-builder/init_sample.py b/samples/model-builder/init_sample.py
index 2567e04f00..7ec00d84dd 100644
--- a/samples/model-builder/init_sample.py
+++ b/samples/model-builder/init_sample.py
@@ -25,6 +25,7 @@ def init_sample(
     staging_bucket: Optional[str] = None,
     credentials: Optional[auth_credentials.Credentials] = None,
     encryption_spec_key_name: Optional[str] = None,
+    service_account: Optional[str] = None,
 ):
 
     from google.cloud import aiplatform
@@ -36,6 +37,7 @@ def init_sample(
         staging_bucket=staging_bucket,
         credentials=credentials,
         encryption_spec_key_name=encryption_spec_key_name,
+        service_account=service_account,
     )
diff --git a/samples/model-builder/init_sample_test.py b/samples/model-builder/init_sample_test.py
index 3c4684a255..2ac5284e09 100644
--- a/samples/model-builder/init_sample_test.py
+++ b/samples/model-builder/init_sample_test.py
@@ -26,6 +26,7 @@ def test_init_sample(mock_sdk_init):
         staging_bucket=constants.STAGING_BUCKET,
         credentials=constants.CREDENTIALS,
         encryption_spec_key_name=constants.ENCRYPTION_SPEC_KEY_NAME,
+        service_account=constants.SERVICE_ACCOUNT,
     )
 
     mock_sdk_init.assert_called_once_with(
@@ -35,4 +36,5 @@ def test_init_sample(mock_sdk_init):
         staging_bucket=constants.STAGING_BUCKET,
         credentials=constants.CREDENTIALS,
         encryption_spec_key_name=constants.ENCRYPTION_SPEC_KEY_NAME,
+        service_account=constants.SERVICE_ACCOUNT,
     )
diff --git a/tests/unit/aiplatform/test_custom_job.py b/tests/unit/aiplatform/test_custom_job.py
index ea43c42a4f..93a8717ba3 100644
--- a/tests/unit/aiplatform/test_custom_job.py
+++ b/tests/unit/aiplatform/test_custom_job.py
@@ -406,6 +406,8 @@ def test_create_custom_job(self, create_custom_job_mock, get_custom_job_mock, sy
             location=_TEST_LOCATION,
             staging_bucket=_TEST_STAGING_BUCKET,
             encryption_spec_key_name=_TEST_DEFAULT_ENCRYPTION_KEY_NAME,
+            network=_TEST_NETWORK,
+            service_account=_TEST_SERVICE_ACCOUNT,
         )
 
         job = aiplatform.CustomJob(
@@ -416,8 +418,6 @@ def test_create_custom_job(self, create_custom_job_mock, get_custom_job_mock, sy
         )
 
         job.run(
-            service_account=_TEST_SERVICE_ACCOUNT,
-            network=_TEST_NETWORK,
             timeout=_TEST_TIMEOUT,
             restart_job_on_worker_restart=_TEST_RESTART_JOB_ON_WORKER_RESTART,
             sync=sync,
diff --git a/tests/unit/aiplatform/test_initializer.py b/tests/unit/aiplatform/test_initializer.py
index bd7adb5075..80c48d32bf 100644
--- a/tests/unit/aiplatform/test_initializer.py
+++ b/tests/unit/aiplatform/test_initializer.py
@@ -44,6 +44,7 @@
 _TEST_DESCRIPTION = "test-description"
 _TEST_STAGING_BUCKET = "test-bucket"
 _TEST_NETWORK = "projects/12345/global/networks/myVPC"
+_TEST_SERVICE_ACCOUNT = "test-service-account@test-project.iam.gserviceaccount.com"
 
 # tensorboard
 _TEST_TENSORBOARD_ID = "1028944691210842416"
@@ -105,6 +106,10 @@ def test_init_network_sets_network(self):
         initializer.global_config.init(network=_TEST_NETWORK)
         assert initializer.global_config.network == _TEST_NETWORK
 
+    def test_init_service_account_sets_service_account(self):
+        initializer.global_config.init(service_account=_TEST_SERVICE_ACCOUNT)
+        assert initializer.global_config.service_account == _TEST_SERVICE_ACCOUNT
+
     @patch.object(_experiment_tracker, "set_experiment")
     def test_init_experiment_sets_experiment(self, set_experiment_mock):
         initializer.global_config.init(experiment=_TEST_EXPERIMENT)
diff --git a/tests/unit/aiplatform/test_models.py b/tests/unit/aiplatform/test_models.py
index 32a294b7a9..bcbdcbd413 100644
--- a/tests/unit/aiplatform/test_models.py
+++ b/tests/unit/aiplatform/test_models.py
@@ -2215,6 +2215,61 @@ def test_init_aiplatform_with_encryption_key_name_and_batch_predict_gcs_source_a
             timeout=None,
         )
 
+    @pytest.mark.parametrize("sync", [True, False])
+    @pytest.mark.usefixtures("get_model_mock", "get_batch_prediction_job_mock")
+    def test_init_aiplatform_with_service_account_and_batch_predict_gcs_source_and_dest(
+        self, create_batch_prediction_job_mock, sync
+    ):
+        aiplatform.init(
+            project=_TEST_PROJECT,
+            location=_TEST_LOCATION,
+            encryption_spec_key_name=_TEST_ENCRYPTION_KEY_NAME,
+            service_account=_TEST_SERVICE_ACCOUNT,
+        )
+        test_model = models.Model(_TEST_ID)
+
+        # Make SDK batch_predict method call
+        batch_prediction_job = test_model.batch_predict(
+            job_display_name=_TEST_BATCH_PREDICTION_DISPLAY_NAME,
+            gcs_source=_TEST_BATCH_PREDICTION_GCS_SOURCE,
+            gcs_destination_prefix=_TEST_BATCH_PREDICTION_GCS_DEST_PREFIX,
+            sync=sync,
+            create_request_timeout=None,
+        )
+
+        if not sync:
+            batch_prediction_job.wait()
+
+        # Construct expected request
+        expected_gapic_batch_prediction_job = (
+            gca_batch_prediction_job.BatchPredictionJob(
+                display_name=_TEST_BATCH_PREDICTION_DISPLAY_NAME,
+                model=model_service_client.ModelServiceClient.model_path(
+                    _TEST_PROJECT, _TEST_LOCATION, _TEST_ID
+                ),
+                input_config=gca_batch_prediction_job.BatchPredictionJob.InputConfig(
+                    instances_format="jsonl",
+                    gcs_source=gca_io.GcsSource(
+                        uris=[_TEST_BATCH_PREDICTION_GCS_SOURCE]
+                    ),
+                ),
+                output_config=gca_batch_prediction_job.BatchPredictionJob.OutputConfig(
+                    gcs_destination=gca_io.GcsDestination(
+                        output_uri_prefix=_TEST_BATCH_PREDICTION_GCS_DEST_PREFIX
+                    ),
+                    predictions_format="jsonl",
+                ),
+                encryption_spec=_TEST_ENCRYPTION_SPEC,
+                service_account=_TEST_SERVICE_ACCOUNT,
+            )
+        )
+
+        create_batch_prediction_job_mock.assert_called_once_with(
+            parent=_TEST_PARENT,
+            batch_prediction_job=expected_gapic_batch_prediction_job,
+            timeout=None,
+        )
+
     @pytest.mark.parametrize("sync", [True, False])
     @pytest.mark.usefixtures("get_model_mock", "get_batch_prediction_job_mock")
     def test_batch_predict_gcs_source_and_dest(
diff --git a/tests/unit/aiplatform/test_pipeline_jobs.py b/tests/unit/aiplatform/test_pipeline_jobs.py
index 0011758085..2b97589f8e 100644
--- a/tests/unit/aiplatform/test_pipeline_jobs.py
+++ b/tests/unit/aiplatform/test_pipeline_jobs.py
@@ -485,6 +485,8 @@ def test_run_call_pipeline_service_create(
             staging_bucket=_TEST_GCS_BUCKET_NAME,
             location=_TEST_LOCATION,
             credentials=_TEST_CREDENTIALS,
+            service_account=_TEST_SERVICE_ACCOUNT,
+            network=_TEST_NETWORK,
         )
 
         job = pipeline_jobs.PipelineJob(
@@ -497,8 +499,6 @@ def test_run_call_pipeline_service_create(
         )
 
         job.run(
-            service_account=_TEST_SERVICE_ACCOUNT,
-            network=_TEST_NETWORK,
             sync=sync,
             create_request_timeout=None,
         )
diff --git a/tests/unit/aiplatform/test_training_jobs.py b/tests/unit/aiplatform/test_training_jobs.py
index a35e644b46..4a4ab7ee45 100644
--- a/tests/unit/aiplatform/test_training_jobs.py
+++ b/tests/unit/aiplatform/test_training_jobs.py
@@ -1056,6 +1056,7 @@ def test_run_call_pipeline_service_create_with_tabular_dataset(
             project=_TEST_PROJECT,
             staging_bucket=_TEST_BUCKET_NAME,
             credentials=_TEST_CREDENTIALS,
+            service_account=_TEST_SERVICE_ACCOUNT,
             encryption_spec_key_name=_TEST_DEFAULT_ENCRYPTION_KEY_NAME,
         )
 
@@ -1082,7 +1083,6 @@ def test_run_call_pipeline_service_create_with_tabular_dataset(
         model_from_job = job.run(
             dataset=mock_tabular_dataset,
             base_output_dir=_TEST_BASE_OUTPUT_DIR,
-            service_account=_TEST_SERVICE_ACCOUNT,
             network=_TEST_NETWORK,
             args=_TEST_RUN_ARGS,
             environment_variables=_TEST_ENVIRONMENT_VARIABLES,
@@ -3181,6 +3181,7 @@ def test_run_call_pipeline_service_create_with_tabular_dataset(
         aiplatform.init(
             project=_TEST_PROJECT,
             staging_bucket=_TEST_BUCKET_NAME,
+            service_account=_TEST_SERVICE_ACCOUNT,
             encryption_spec_key_name=_TEST_DEFAULT_ENCRYPTION_KEY_NAME,
         )
 
@@ -3215,7 +3216,6 @@ def test_run_call_pipeline_service_create_with_tabular_dataset(
             model_display_name=_TEST_MODEL_DISPLAY_NAME,
             model_labels=_TEST_MODEL_LABELS,
             predefined_split_column_name=_TEST_PREDEFINED_SPLIT_COLUMN_NAME,
-            service_account=_TEST_SERVICE_ACCOUNT,
             tensorboard=_TEST_TENSORBOARD_RESOURCE_NAME,
             sync=sync,
             create_request_timeout=None,
@@ -5242,6 +5242,7 @@ def test_run_call_pipeline_service_create_with_tabular_dataset(
        aiplatform.init(
             project=_TEST_PROJECT,
             staging_bucket=_TEST_BUCKET_NAME,
+            service_account=_TEST_SERVICE_ACCOUNT,
             encryption_spec_key_name=_TEST_DEFAULT_ENCRYPTION_KEY_NAME,
         )
 
@@ -5271,7 +5272,6 @@ def test_run_call_pipeline_service_create_with_tabular_dataset(
             model_display_name=_TEST_MODEL_DISPLAY_NAME,
             model_labels=_TEST_MODEL_LABELS,
             base_output_dir=_TEST_BASE_OUTPUT_DIR,
-            service_account=_TEST_SERVICE_ACCOUNT,
             network=_TEST_NETWORK,
             args=_TEST_RUN_ARGS,
             environment_variables=_TEST_ENVIRONMENT_VARIABLES,