feat: Allow setting default service account
PiperOrigin-RevId: 559266585
Ark-kun authored and copybara-github committed Aug 23, 2023
1 parent 7eaa1d4 commit d11b8e6
Showing 15 changed files with 102 additions and 7 deletions.
15 changes: 15 additions & 0 deletions google/cloud/aiplatform/initializer.py
@@ -98,6 +98,7 @@ def __init__(self):
self._credentials = None
self._encryption_spec_key_name = None
self._network = None
self._service_account = None

def init(
self,
@@ -113,6 +114,7 @@ def init(
credentials: Optional[auth_credentials.Credentials] = None,
encryption_spec_key_name: Optional[str] = None,
network: Optional[str] = None,
service_account: Optional[str] = None,
):
"""Updates common initialization parameters with provided options.
@@ -155,6 +157,12 @@ def init(
Private services access must already be configured for the network.
If specified, all eligible jobs and resources created will be peered
with this VPC.
service_account (str):
Optional. The service account used to launch jobs and deploy models.
Jobs that use service_account: BatchPredictionJob, CustomJob,
PipelineJob, HyperparameterTuningJob, CustomTrainingJob,
CustomPythonPackageTrainingJob, CustomContainerTrainingJob,
ModelEvaluationJob.
Raises:
ValueError:
If experiment_description is provided but experiment is not.
@@ -194,6 +202,8 @@ def init(
self._encryption_spec_key_name = encryption_spec_key_name
if network is not None:
self._network = network
if service_account is not None:
self._service_account = service_account

if experiment:
metadata._experiment_tracker.set_experiment(
@@ -297,6 +307,11 @@ def network(self) -> Optional[str]:
"""Default Compute Engine network to peer to, if provided."""
return self._network

@property
def service_account(self) -> Optional[str]:
"""Default service account, if provided."""
return self._service_account

@property
def experiment_name(self) -> Optional[str]:
"""Default experiment name, if provided."""
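For orientation, a minimal usage sketch of the new default as set and read through the global config. The project, location, and service-account values below are illustrative placeholders, not taken from this commit:

from google.cloud import aiplatform
from google.cloud.aiplatform import initializer

# Illustrative values only.
aiplatform.init(
    project="my-project",
    location="us-central1",
    service_account="default-sa@my-project.iam.gserviceaccount.com",
)

# The default is exposed on the global config, mirroring the existing
# `network` property shown above.
assert initializer.global_config.service_account == (
    "default-sa@my-project.iam.gserviceaccount.com"
)
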
5 changes: 5 additions & 0 deletions google/cloud/aiplatform/jobs.py
@@ -761,6 +761,7 @@ def create(
)
gapic_batch_prediction_job.explanation_spec = explanation_spec

service_account = service_account or initializer.global_config.service_account
if service_account:
gapic_batch_prediction_job.service_account = service_account

@@ -1693,6 +1694,7 @@ def run(
`restart_job_on_worker_restart` to False.
"""
network = network or initializer.global_config.network
service_account = service_account or initializer.global_config.service_account

self._run(
service_account=service_account,
@@ -1880,6 +1882,8 @@ def submit(
raise ValueError(
"'experiment' is required since you've enabled autolog in 'from_local_script'."
)

service_account = service_account or initializer.global_config.service_account
if service_account:
self._gca_resource.job_spec.service_account = service_account

@@ -2356,6 +2360,7 @@ def run(
`restart_job_on_worker_restart` to False.
"""
network = network or initializer.global_config.network
service_account = service_account or initializer.global_config.service_account

self._run(
service_account=service_account,
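The job classes above all resolve the account the same way: an explicitly passed `service_account` wins, otherwise the value from `aiplatform.init` is used. A minimal sketch of that precedence; the helper name and sample emails are illustrative, not part of the SDK:

from typing import Optional

def resolve_service_account(
    explicit: Optional[str], init_default: Optional[str]
) -> Optional[str]:
    # Mirrors `service_account or initializer.global_config.service_account`:
    # a falsy per-call value (None or "") falls back to the init-level default.
    return explicit or init_default

assert resolve_service_account(
    "job-sa@example.iam.gserviceaccount.com",
    "default-sa@example.iam.gserviceaccount.com",
) == "job-sa@example.iam.gserviceaccount.com"
assert resolve_service_account(
    None, "default-sa@example.iam.gserviceaccount.com"
) == "default-sa@example.iam.gserviceaccount.com"
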
1 change: 1 addition & 0 deletions google/cloud/aiplatform/model_evaluation/model_evaluation_job.py
@@ -278,6 +278,7 @@ def submit(
Returns:
(ModelEvaluationJob): Instantiated representation of the model evaluation job.
"""
service_account = service_account or initializer.global_config.service_account

if isinstance(model_name, aiplatform.Model):
model_resource_name = model_name.versioned_resource_name
3 changes: 3 additions & 0 deletions google/cloud/aiplatform/models.py
@@ -1096,6 +1096,7 @@ def _deploy_call(
to the resource project.
Users deploying the Model must have the `iam.serviceAccounts.actAs`
permission on this service account.
If not specified, uses the service account set in aiplatform.init.
explanation_spec (aiplatform.explain.ExplanationSpec):
Optional. Specification of Model explanation.
metadata (Sequence[Tuple[str, str]]):
@@ -1120,6 +1121,8 @@
is not 0 or 100.
"""

service_account = service_account or initializer.global_config.service_account

max_replica_count = max(min_replica_count, max_replica_count)

if bool(accelerator_type) != bool(accelerator_count):
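With the `_deploy_call` change, a deployment made without an explicit `service_account` inherits the init-level default. A hedged sketch; the resource name, machine type, and account email are placeholders and not taken from this commit:

from google.cloud import aiplatform

aiplatform.init(
    project="my-project",
    location="us-central1",
    service_account="default-sa@my-project.iam.gserviceaccount.com",
)

model = aiplatform.Model(
    "projects/my-project/locations/us-central1/models/1234567890"
)
# No service_account argument here, so _deploy_call falls back to the
# default set in aiplatform.init.
endpoint = model.deploy(machine_type="n1-standard-4")
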
1 change: 1 addition & 0 deletions google/cloud/aiplatform/pipeline_job_schedules.py
@@ -226,6 +226,7 @@ def _create(
if max_concurrent_run_count:
self._gca_resource.max_concurrent_run_count = max_concurrent_run_count

service_account = service_account or initializer.global_config.service_account
network = network or initializer.global_config.network

if service_account:
1 change: 1 addition & 0 deletions google/cloud/aiplatform/pipeline_jobs.py
@@ -382,6 +382,7 @@ def submit(
current Experiment Run.
"""
network = network or initializer.global_config.network
service_account = service_account or initializer.global_config.service_account

if service_account:
self._gca_resource.service_account = service_account
4 changes: 4 additions & 0 deletions google/cloud/aiplatform/training_jobs.py
@@ -3223,6 +3223,7 @@ def run(
produce a Vertex AI Model.
"""
network = network or initializer.global_config.network
service_account = service_account or initializer.global_config.service_account

worker_pool_specs, managed_model = self._prepare_and_validate_run(
model_display_name=model_display_name,
@@ -4579,6 +4580,7 @@ def run(
were not provided in constructor.
"""
network = network or initializer.global_config.network
service_account = service_account or initializer.global_config.service_account

worker_pool_specs, managed_model = self._prepare_and_validate_run(
model_display_name=model_display_name,
@@ -7348,6 +7350,7 @@ def run(
service_account (str):
Specifies the service account for workload run-as account.
Users submitting jobs must have act-as permission on this run-as account.
If not specified, uses the service account set in aiplatform.init.
network (str):
The full name of the Compute Engine network to which the job
should be peered. For example, projects/12345/global/networks/myVPC.
@@ -7501,6 +7504,7 @@ def run(
produce a Vertex AI Model.
"""
network = network or initializer.global_config.network
service_account = service_account or initializer.global_config.service_account

worker_pool_specs, managed_model = self._prepare_and_validate_run(
model_display_name=model_display_name,
1 change: 1 addition & 0 deletions google/cloud/aiplatform/utils/gcs_utils.py
@@ -217,6 +217,7 @@ def create_gcs_bucket_for_pipeline_artifacts_if_it_does_not_exist(
"""
project = project or initializer.global_config.project
location = location or initializer.global_config.location
service_account = service_account or initializer.global_config.service_account
credentials = credentials or initializer.global_config.credentials

output_artifacts_gcs_dir = (
2 changes: 2 additions & 0 deletions samples/model-builder/init_sample.py
@@ -25,6 +25,7 @@ def init_sample(
staging_bucket: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
encryption_spec_key_name: Optional[str] = None,
service_account: Optional[str] = None,
):

from google.cloud import aiplatform
@@ -36,6 +37,7 @@
staging_bucket=staging_bucket,
credentials=credentials,
encryption_spec_key_name=encryption_spec_key_name,
service_account=service_account,
)


2 changes: 2 additions & 0 deletions samples/model-builder/init_sample_test.py
@@ -26,6 +26,7 @@ def test_init_sample(mock_sdk_init):
staging_bucket=constants.STAGING_BUCKET,
credentials=constants.CREDENTIALS,
encryption_spec_key_name=constants.ENCRYPTION_SPEC_KEY_NAME,
service_account=constants.SERVICE_ACCOUNT,
)

mock_sdk_init.assert_called_once_with(
@@ -35,4 +36,5 @@
staging_bucket=constants.STAGING_BUCKET,
credentials=constants.CREDENTIALS,
encryption_spec_key_name=constants.ENCRYPTION_SPEC_KEY_NAME,
service_account=constants.SERVICE_ACCOUNT,
)
4 changes: 2 additions & 2 deletions tests/unit/aiplatform/test_custom_job.py
@@ -406,6 +406,8 @@ def test_create_custom_job(self, create_custom_job_mock, get_custom_job_mock, sy
location=_TEST_LOCATION,
staging_bucket=_TEST_STAGING_BUCKET,
encryption_spec_key_name=_TEST_DEFAULT_ENCRYPTION_KEY_NAME,
network=_TEST_NETWORK,
service_account=_TEST_SERVICE_ACCOUNT,
)

job = aiplatform.CustomJob(
@@ -416,8 +418,6 @@
)

job.run(
service_account=_TEST_SERVICE_ACCOUNT,
network=_TEST_NETWORK,
timeout=_TEST_TIMEOUT,
restart_job_on_worker_restart=_TEST_RESTART_JOB_ON_WORKER_RESTART,
sync=sync,
5 changes: 5 additions & 0 deletions tests/unit/aiplatform/test_initializer.py
@@ -44,6 +44,7 @@
_TEST_DESCRIPTION = "test-description"
_TEST_STAGING_BUCKET = "test-bucket"
_TEST_NETWORK = "projects/12345/global/networks/myVPC"
_TEST_SERVICE_ACCOUNT = "[email protected]"

# tensorboard
_TEST_TENSORBOARD_ID = "1028944691210842416"
@@ -105,6 +106,10 @@ def test_init_network_sets_network(self):
initializer.global_config.init(network=_TEST_NETWORK)
assert initializer.global_config.network == _TEST_NETWORK

def test_init_service_account_sets_service_account(self):
initializer.global_config.init(service_account=_TEST_SERVICE_ACCOUNT)
assert initializer.global_config.service_account == _TEST_SERVICE_ACCOUNT

@patch.object(_experiment_tracker, "set_experiment")
def test_init_experiment_sets_experiment(self, set_experiment_mock):
initializer.global_config.init(experiment=_TEST_EXPERIMENT)
55 changes: 55 additions & 0 deletions tests/unit/aiplatform/test_models.py
@@ -2215,6 +2215,61 @@ def test_init_aiplatform_with_encryption_key_name_and_batch_predict_gcs_source_a
timeout=None,
)

@pytest.mark.parametrize("sync", [True, False])
@pytest.mark.usefixtures("get_model_mock", "get_batch_prediction_job_mock")
def test_init_aiplatform_with_service_account_and_batch_predict_gcs_source_and_dest(
self, create_batch_prediction_job_mock, sync
):
aiplatform.init(
project=_TEST_PROJECT,
location=_TEST_LOCATION,
encryption_spec_key_name=_TEST_ENCRYPTION_KEY_NAME,
service_account=_TEST_SERVICE_ACCOUNT,
)
test_model = models.Model(_TEST_ID)

# Make SDK batch_predict method call
batch_prediction_job = test_model.batch_predict(
job_display_name=_TEST_BATCH_PREDICTION_DISPLAY_NAME,
gcs_source=_TEST_BATCH_PREDICTION_GCS_SOURCE,
gcs_destination_prefix=_TEST_BATCH_PREDICTION_GCS_DEST_PREFIX,
sync=sync,
create_request_timeout=None,
)

if not sync:
batch_prediction_job.wait()

# Construct expected request
expected_gapic_batch_prediction_job = (
gca_batch_prediction_job.BatchPredictionJob(
display_name=_TEST_BATCH_PREDICTION_DISPLAY_NAME,
model=model_service_client.ModelServiceClient.model_path(
_TEST_PROJECT, _TEST_LOCATION, _TEST_ID
),
input_config=gca_batch_prediction_job.BatchPredictionJob.InputConfig(
instances_format="jsonl",
gcs_source=gca_io.GcsSource(
uris=[_TEST_BATCH_PREDICTION_GCS_SOURCE]
),
),
output_config=gca_batch_prediction_job.BatchPredictionJob.OutputConfig(
gcs_destination=gca_io.GcsDestination(
output_uri_prefix=_TEST_BATCH_PREDICTION_GCS_DEST_PREFIX
),
predictions_format="jsonl",
),
encryption_spec=_TEST_ENCRYPTION_SPEC,
service_account=_TEST_SERVICE_ACCOUNT,
)
)

create_batch_prediction_job_mock.assert_called_once_with(
parent=_TEST_PARENT,
batch_prediction_job=expected_gapic_batch_prediction_job,
timeout=None,
)

@pytest.mark.parametrize("sync", [True, False])
@pytest.mark.usefixtures("get_model_mock", "get_batch_prediction_job_mock")
def test_batch_predict_gcs_source_and_dest(
4 changes: 2 additions & 2 deletions tests/unit/aiplatform/test_pipeline_jobs.py
@@ -485,6 +485,8 @@ def test_run_call_pipeline_service_create(
staging_bucket=_TEST_GCS_BUCKET_NAME,
location=_TEST_LOCATION,
credentials=_TEST_CREDENTIALS,
service_account=_TEST_SERVICE_ACCOUNT,
network=_TEST_NETWORK,
)

job = pipeline_jobs.PipelineJob(
@@ -497,8 +499,6 @@
)

job.run(
service_account=_TEST_SERVICE_ACCOUNT,
network=_TEST_NETWORK,
sync=sync,
create_request_timeout=None,
)
6 changes: 3 additions & 3 deletions tests/unit/aiplatform/test_training_jobs.py
@@ -1056,6 +1056,7 @@ def test_run_call_pipeline_service_create_with_tabular_dataset(
project=_TEST_PROJECT,
staging_bucket=_TEST_BUCKET_NAME,
credentials=_TEST_CREDENTIALS,
service_account=_TEST_SERVICE_ACCOUNT,
encryption_spec_key_name=_TEST_DEFAULT_ENCRYPTION_KEY_NAME,
)

@@ -1082,7 +1083,6 @@ def test_run_call_pipeline_service_create_with_tabular_dataset(
model_from_job = job.run(
dataset=mock_tabular_dataset,
base_output_dir=_TEST_BASE_OUTPUT_DIR,
service_account=_TEST_SERVICE_ACCOUNT,
network=_TEST_NETWORK,
args=_TEST_RUN_ARGS,
environment_variables=_TEST_ENVIRONMENT_VARIABLES,
@@ -3181,6 +3181,7 @@ def test_run_call_pipeline_service_create_with_tabular_dataset(
aiplatform.init(
project=_TEST_PROJECT,
staging_bucket=_TEST_BUCKET_NAME,
service_account=_TEST_SERVICE_ACCOUNT,
encryption_spec_key_name=_TEST_DEFAULT_ENCRYPTION_KEY_NAME,
)

@@ -3215,7 +3216,6 @@ def test_run_call_pipeline_service_create_with_tabular_dataset(
model_display_name=_TEST_MODEL_DISPLAY_NAME,
model_labels=_TEST_MODEL_LABELS,
predefined_split_column_name=_TEST_PREDEFINED_SPLIT_COLUMN_NAME,
service_account=_TEST_SERVICE_ACCOUNT,
tensorboard=_TEST_TENSORBOARD_RESOURCE_NAME,
sync=sync,
create_request_timeout=None,
@@ -5242,6 +5242,7 @@ def test_run_call_pipeline_service_create_with_tabular_dataset(
aiplatform.init(
project=_TEST_PROJECT,
staging_bucket=_TEST_BUCKET_NAME,
service_account=_TEST_SERVICE_ACCOUNT,
encryption_spec_key_name=_TEST_DEFAULT_ENCRYPTION_KEY_NAME,
)

@@ -5271,7 +5272,6 @@ def test_run_call_pipeline_service_create_with_tabular_dataset(
model_display_name=_TEST_MODEL_DISPLAY_NAME,
model_labels=_TEST_MODEL_LABELS,
base_output_dir=_TEST_BASE_OUTPUT_DIR,
service_account=_TEST_SERVICE_ACCOUNT,
network=_TEST_NETWORK,
args=_TEST_RUN_ARGS,
environment_variables=_TEST_ENVIRONMENT_VARIABLES,
