Skip to content

Commit

Permalink
chore: decrease unit test running time
Browse files Browse the repository at this point in the history
FUTURE_COPYBARA_INTEGRATE_REVIEW=#2450 from googleapis:release-please--branches--main d38964a
PiperOrigin-RevId: 558800800
  • Loading branch information
sararob authored and copybara-github committed Aug 22, 2023
1 parent 7eaa1d4 commit bab0488
Show file tree
Hide file tree
Showing 6 changed files with 57 additions and 7 deletions.
14 changes: 10 additions & 4 deletions google/cloud/aiplatform/pipeline_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,12 @@

_READ_MASK_FIELDS = pipeline_constants._READ_MASK_FIELDS

# _block_until_complete wait times
_JOB_WAIT_TIME = 5 # start at five seconds
_LOG_WAIT_TIME = 5
_MAX_WAIT_TIME = 60 * 5 # 5 minute wait
_WAIT_TIME_MULTIPLIER = 2 # scale wait by 2 every iteration


def _get_current_time() -> datetime.datetime:
"""Gets the current timestamp."""
Expand Down Expand Up @@ -561,10 +567,10 @@ def _dashboard_uri(self) -> str:
def _block_until_complete(self):
"""Helper method to block and check on job until complete."""
# Used these numbers so failures surface fast
wait = 5 # start at five seconds
log_wait = 5
max_wait = 60 * 5 # 5 minute wait
multiplier = 2 # scale wait by 2 every iteration
wait = _JOB_WAIT_TIME # start at five seconds
log_wait = _LOG_WAIT_TIME
max_wait = _MAX_WAIT_TIME # 5 minute wait
multiplier = _WAIT_TIME_MULTIPLIER # scale wait by 2 every iteration

previous_time = time.time()
while self.state not in _PIPELINE_COMPLETE_STATES:
Expand Down
23 changes: 23 additions & 0 deletions tests/unit/aiplatform/test_custom_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@

from google.cloud import aiplatform
from google.cloud.aiplatform import base
from google.cloud.aiplatform import jobs
from google.cloud.aiplatform.compat.types import (
custom_job as gca_custom_job_compat,
)
Expand Down Expand Up @@ -399,6 +400,8 @@ def teardown_method(self):
aiplatform.initializer.global_pool.shutdown(wait=True)

@pytest.mark.parametrize("sync", [True, False])
@mock.patch.object(jobs, "_JOB_WAIT_TIME", 1)
@mock.patch.object(jobs, "_LOG_WAIT_TIME", 1)
def test_create_custom_job(self, create_custom_job_mock, get_custom_job_mock, sync):

aiplatform.init(
Expand Down Expand Up @@ -548,6 +551,8 @@ def test_submit_custom_job_with_experiments(
)

@pytest.mark.parametrize("sync", [True, False])
@mock.patch.object(jobs, "_JOB_WAIT_TIME", 1)
@mock.patch.object(jobs, "_LOG_WAIT_TIME", 1)
def test_create_custom_job_with_timeout(
self, create_custom_job_mock, get_custom_job_mock, sync
):
Expand Down Expand Up @@ -591,6 +596,8 @@ def test_create_custom_job_with_timeout(
)

@pytest.mark.parametrize("sync", [True, False])
@mock.patch.object(jobs, "_JOB_WAIT_TIME", 1)
@mock.patch.object(jobs, "_LOG_WAIT_TIME", 1)
def test_create_custom_job_with_timeout_not_explicitly_set(
self, create_custom_job_mock, get_custom_job_mock, sync
):
Expand Down Expand Up @@ -633,6 +640,8 @@ def test_create_custom_job_with_timeout_not_explicitly_set(
)

@pytest.mark.parametrize("sync", [True, False])
@mock.patch.object(jobs, "_JOB_WAIT_TIME", 1)
@mock.patch.object(jobs, "_LOG_WAIT_TIME", 1)
def test_run_custom_job_with_fail_raises(
self, create_custom_job_mock, get_custom_job_mock_with_fail, sync
):
Expand Down Expand Up @@ -763,6 +772,8 @@ def test_get_custom_job(self, get_custom_job_mock):

@pytest.mark.usefixtures("mock_python_package_to_gcs")
@pytest.mark.parametrize("sync", [True, False])
@mock.patch.object(jobs, "_JOB_WAIT_TIME", 1)
@mock.patch.object(jobs, "_LOG_WAIT_TIME", 1)
def test_create_from_local_script_prebuilt_container(
self, get_custom_job_mock, create_custom_job_mock, sync
):
Expand Down Expand Up @@ -797,6 +808,8 @@ def test_create_from_local_script_prebuilt_container(

@pytest.mark.usefixtures("mock_python_package_to_gcs")
@pytest.mark.parametrize("sync", [True, False])
@mock.patch.object(jobs, "_JOB_WAIT_TIME", 1)
@mock.patch.object(jobs, "_LOG_WAIT_TIME", 1)
def test_create_from_local_script_custom_container(
self, get_custom_job_mock, create_custom_job_mock, sync
):
Expand Down Expand Up @@ -852,6 +865,8 @@ def test_create_from_local_script_raises_with_no_staging_bucket(
"update_context_mock",
)
@pytest.mark.parametrize("sync", [True, False])
@mock.patch.object(jobs, "_JOB_WAIT_TIME", 1)
@mock.patch.object(jobs, "_LOG_WAIT_TIME", 1)
def test_create_from_local_script_prebuilt_container_with_all_args(
self, get_custom_job_mock, create_custom_job_mock, sync
):
Expand Down Expand Up @@ -914,6 +929,8 @@ def test_create_from_local_script_prebuilt_container_with_all_args(
"update_context_mock",
)
@pytest.mark.parametrize("sync", [True, False])
@mock.patch.object(jobs, "_JOB_WAIT_TIME", 1)
@mock.patch.object(jobs, "_LOG_WAIT_TIME", 1)
def test_create_from_local_script_custom_container_with_all_args(
self, get_custom_job_mock, create_custom_job_mock, sync
):
Expand Down Expand Up @@ -989,6 +1006,8 @@ def test_create_from_local_script_enable_autolog_no_experiment_error(self):
job.run()

@pytest.mark.parametrize("sync", [True, False])
@mock.patch.object(jobs, "_JOB_WAIT_TIME", 1)
@mock.patch.object(jobs, "_LOG_WAIT_TIME", 1)
def test_create_custom_job_with_enable_web_access(
self,
create_custom_job_mock_with_enable_web_access,
Expand Down Expand Up @@ -1052,6 +1071,8 @@ def test_get_web_access_uris(self, get_custom_job_mock_with_enable_web_access):
assert job.web_access_uris == _TEST_WEB_ACCESS_URIS
break

@mock.patch.object(jobs, "_JOB_WAIT_TIME", 1)
@mock.patch.object(jobs, "_LOG_WAIT_TIME", 1)
def test_log_access_web_uris_after_get(
self, get_custom_job_mock_with_enable_web_access
):
Expand All @@ -1066,6 +1087,8 @@ def test_get_web_access_uris_job_succeeded(
assert not job.web_access_uris

@pytest.mark.parametrize("sync", [True, False])
@mock.patch.object(jobs, "_JOB_WAIT_TIME", 1)
@mock.patch.object(jobs, "_LOG_WAIT_TIME", 1)
def test_create_custom_job_with_tensorboard(
self, create_custom_job_mock_with_tensorboard, get_custom_job_mock, sync
):
Expand Down
13 changes: 13 additions & 0 deletions tests/unit/aiplatform/test_hyperparameter_tuning_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from google.cloud import aiplatform
from google.cloud.aiplatform import base
from google.cloud.aiplatform import hyperparameter_tuning as hpt
from google.cloud.aiplatform import jobs
from google.cloud.aiplatform.compat.types import (
encryption_spec as gca_encryption_spec_compat,
hyperparameter_tuning_job as gca_hyperparameter_tuning_job_compat,
Expand Down Expand Up @@ -394,6 +395,8 @@ def setup_method(self):
def teardown_method(self):
aiplatform.initializer.global_pool.shutdown(wait=True)

@mock.patch.object(jobs, "_JOB_WAIT_TIME", 1)
@mock.patch.object(jobs, "_LOG_WAIT_TIME", 1)
@pytest.mark.parametrize("sync", [True, False])
def test_create_hyperparameter_tuning_job(
self,
Expand Down Expand Up @@ -467,6 +470,8 @@ def test_create_hyperparameter_tuning_job(
assert job.trials == []

@pytest.mark.parametrize("sync", [True, False])
@mock.patch.object(jobs, "_JOB_WAIT_TIME", 1)
@mock.patch.object(jobs, "_LOG_WAIT_TIME", 1)
def test_create_hyperparameter_tuning_job_with_timeout(
self,
create_hyperparameter_tuning_job_mock,
Expand Down Expand Up @@ -535,6 +540,8 @@ def test_create_hyperparameter_tuning_job_with_timeout(
)

@pytest.mark.parametrize("sync", [True, False])
@mock.patch.object(jobs, "_JOB_WAIT_TIME", 1)
@mock.patch.object(jobs, "_LOG_WAIT_TIME", 1)
def test_run_hyperparameter_tuning_job_with_fail_raises(
self,
create_hyperparameter_tuning_job_mock,
Expand Down Expand Up @@ -733,6 +740,8 @@ def test_get_hyperparameter_tuning_job(self, get_hyperparameter_tuning_job_mock)
)

@pytest.mark.parametrize("sync", [True, False])
@mock.patch.object(jobs, "_JOB_WAIT_TIME", 1)
@mock.patch.object(jobs, "_LOG_WAIT_TIME", 1)
def test_create_hyperparameter_tuning_job_with_tensorboard(
self,
create_hyperparameter_tuning_job_mock_with_tensorboard,
Expand Down Expand Up @@ -809,6 +818,8 @@ def test_create_hyperparameter_tuning_job_with_tensorboard(
)

@pytest.mark.parametrize("sync", [True, False])
@mock.patch.object(jobs, "_JOB_WAIT_TIME", 1)
@mock.patch.object(jobs, "_LOG_WAIT_TIME", 1)
def test_create_hyperparameter_tuning_job_with_enable_web_access(
self,
create_hyperparameter_tuning_job_mock_with_enable_web_access,
Expand Down Expand Up @@ -889,6 +900,8 @@ def test_create_hyperparameter_tuning_job_with_enable_web_access(

caplog.clear()

@mock.patch.object(jobs, "_JOB_WAIT_TIME", 1)
@mock.patch.object(jobs, "_LOG_WAIT_TIME", 1)
def test_log_enable_web_access_after_get_hyperparameter_tuning_job(
self,
get_hyperparameter_tuning_job_mock_with_enable_web_access,
Expand Down
6 changes: 6 additions & 0 deletions tests/unit/aiplatform/test_pipeline_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,6 +469,8 @@ def teardown_method(self):
[_TEST_PIPELINE_SPEC_JSON, _TEST_PIPELINE_SPEC_YAML, _TEST_PIPELINE_JOB],
)
@pytest.mark.parametrize("sync", [True, False])
@mock.patch.object(pipeline_jobs, "_JOB_WAIT_TIME", 1)
@mock.patch.object(pipeline_jobs, "_LOG_WAIT_TIME", 1)
def test_run_call_pipeline_service_create(
self,
mock_pipeline_service_create,
Expand Down Expand Up @@ -887,6 +889,8 @@ def test_run_call_pipeline_service_create_with_timeout_not_explicitly_set(
],
)
@pytest.mark.parametrize("sync", [True, False])
@mock.patch.object(pipeline_jobs, "_JOB_WAIT_TIME", 1)
@mock.patch.object(pipeline_jobs, "_LOG_WAIT_TIME", 1)
def test_run_call_pipeline_service_create_with_failure_policy(
self,
mock_pipeline_service_create,
Expand Down Expand Up @@ -1541,6 +1545,8 @@ def test_cancel_pipeline_job_without_running(
[_TEST_PIPELINE_SPEC_JSON, _TEST_PIPELINE_SPEC_YAML, _TEST_PIPELINE_JOB],
)
@pytest.mark.parametrize("sync", [True, False])
@mock.patch.object(pipeline_jobs, "_JOB_WAIT_TIME", 1)
@mock.patch.object(pipeline_jobs, "_LOG_WAIT_TIME", 1)
def test_pipeline_failure_raises(self, mock_load_yaml_and_json, sync):
aiplatform.init(
project=_TEST_PROJECT,
Expand Down
2 changes: 2 additions & 0 deletions tests/unit/aiplatform/test_training_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -7227,6 +7227,8 @@ class TestVersionedTrainingJobs:
training_jobs.CustomPythonPackageTrainingJob,
],
)
@mock.patch.object(training_jobs, "_JOB_WAIT_TIME", 1)
@mock.patch.object(training_jobs, "_LOG_WAIT_TIME", 1)
def test_run_pipeline_for_versioned_model(
self,
mock_pipeline_service_create_with_version,
Expand Down
6 changes: 3 additions & 3 deletions tests/unit/aiplatform/test_uploader.py
Original file line number Diff line number Diff line change
Expand Up @@ -1126,7 +1126,7 @@ def test_thread_continuously_uploads(self):
writer_a.flush()
uploader_thread = threading.Thread(target=uploader.start_uploading)
uploader_thread.start()
time.sleep(10)
time.sleep(5)
self.assertEqual(3, mock_client.create_tensorboard_time_series.call_count)
call_args_list = mock_client.create_tensorboard_time_series.call_args_list
request = call_args_list[1][1]["tensorboard_time_series"]
Expand Down Expand Up @@ -1170,15 +1170,15 @@ def test_thread_continuously_uploads(self):
self.assertProtoEquals(expected_request2[0], request2[0])

uploader._end_uploading()
time.sleep(2)
time.sleep(1)
self.assertFalse(uploader_thread.is_alive())
mock_client.write_tensorboard_experiment_data.reset_mock()

# Empty directory
uploader._upload_once()
mock_client.write_tensorboard_experiment_data.assert_not_called()
uploader._end_uploading()
time.sleep(2)
time.sleep(1)
self.assertFalse(uploader_thread.is_alive())


Expand Down

0 comments on commit bab0488

Please sign in to comment.