From 48b7be22bd8d193f45f978e114fcff8daba608c6 Mon Sep 17 00:00:00 2001
From: ariadnafer
Date: Mon, 22 May 2023 13:01:42 +0200
Subject: [PATCH 1/3] feat: add missing tests

---
 Makefile                                      |  12 ++
 .../tests/test_extract_bq_to_dataset.py       |  73 +++++++
 .../tests/test_custom_training_job.py         |  91 +++++++++
 .../tests/test_import_model_evaluation.py     |  51 +++++
 pipelines/Pipfile                             |   1 +
 pipelines/Pipfile.lock                        | 189 ++++++++++++------
 pipelines/pyproject.toml                      |   1 +
 7 files changed, 352 insertions(+), 66 deletions(-)
 create mode 100644 components/bigquery-components/tests/test_extract_bq_to_dataset.py
 create mode 100644 components/vertex-components/tests/test_custom_training_job.py
 create mode 100644 components/vertex-components/tests/test_import_model_evaluation.py

diff --git a/Makefile b/Makefile
index 05a3c1e8..faa0fa43 100644
--- a/Makefile
+++ b/Makefile
@@ -57,6 +57,18 @@ test-all-components: ## Run unit tests for all pipeline components
 		$(MAKE) test-components GROUP=$$(basename $$component_group) ; \
 	done
 
+test-components-coverage: ## Run unit tests with coverage for one component group
+	@cd "components/${GROUP}" && \
+	pipenv run coverage run -m pytest && \
+	pipenv run coverage report -m
+
+test-all-components-coverage: ## Run unit tests with coverage for all component groups
+	@set -e && \
+	for component_group in components/*/ ; do \
+		echo "Test components under $$component_group" && \
+		$(MAKE) test-components-coverage GROUP=$$(basename $$component_group) ; \
+	done
+
 sync-assets: ## Sync assets folder to GCS. Must specify pipeline=
 	@if [ -d "./pipelines/src/pipelines/${PIPELINE_TEMPLATE}/$(pipeline)/assets/" ] ; then \
 		echo "Syncing assets to GCS" && \
diff --git a/components/bigquery-components/tests/test_extract_bq_to_dataset.py b/components/bigquery-components/tests/test_extract_bq_to_dataset.py
new file mode 100644
index 00000000..8d83a62b
--- /dev/null
+++ b/components/bigquery-components/tests/test_extract_bq_to_dataset.py
@@ -0,0 +1,73 @@
+import google.cloud.bigquery  # noqa
+from kfp.v2.dsl import Dataset
+from unittest import mock
+
+import bigquery_components
+
+extract_bq_to_dataset = bigquery_components.extract_bq_to_dataset.python_func
+
+
+def test_extract_bq_to_dataset(tmpdir):
+    with mock.patch("google.cloud.bigquery.client.Client") as mock_client, mock.patch(
+        "google.cloud.bigquery.job.ExtractJobConfig"
+    ) as mock_job_config, mock.patch("google.cloud.bigquery.table.Table") as mock_table:
+
+        # Mock the Dataset path
+        mock_path = tmpdir
+
+        # Set up the mock Client
+        mock_client.extract_table.return_value = "my-job"
+
+        # Set up the mock Table
+        mock_table.return_value.table_ref = "my-table"
+
+        # Set up the mock ExtractJob
+        mock_job_config.return_value = "mock-job-config"
+
+        # Call the function
+        extract_bq_to_dataset(
+            bq_client_project_id="my-project-id",
+            source_project_id="source-project-id",
+            dataset_id="dataset-id",
+            table_name="table-name",
+            dataset=Dataset(uri=mock_path),
+            destination_gcs_uri="gs://mock_bucket",
+            dataset_location="EU",
+            extract_job_config=None,
+            skip_if_exists=False,
+        )
+
+        # Check that Client.extract_table was called correctly
+        mock_client.return_value.extract_table.assert_called_once_with(
+            mock_table.return_value, "gs://mock_bucket", job_config="mock-job-config"
+        )
+
+
+def test_extract_bq_to_dataset_skip_existing(tmpdir):
+    with mock.patch("google.cloud.bigquery.client.Client") as mock_client, mock.patch(
+        "google.cloud.bigquery.table.Table"
+    ), mock.patch("google.cloud.bigquery.job.ExtractJobConfig"), mock.patch(
+        "pathlib.Path.exists"
+    ) as mock_path_exists:
+
+        # # Mock the 
Dataset path + mock_path = tmpdir + + # Mock that the destination already exists + mock_path_exists.return_value = True + + # Call the function with skip_if_exists set to True + extract_bq_to_dataset( + bq_client_project_id="my-project-id", + source_project_id="source-project-id", + dataset_id="dataset-id", + table_name="table-name", + dataset=Dataset(uri=mock_path), + destination_gcs_uri="gs://mock_bucket", + dataset_location="EU", + extract_job_config=None, + skip_if_exists=True, + ) + + # Ensure that Client.extract_table was not called + assert not mock_client.return_value.extract_table.called diff --git a/components/vertex-components/tests/test_custom_training_job.py b/components/vertex-components/tests/test_custom_training_job.py new file mode 100644 index 00000000..2658f318 --- /dev/null +++ b/components/vertex-components/tests/test_custom_training_job.py @@ -0,0 +1,91 @@ +import google.cloud.aiplatform as aip # noqa +from kfp.v2.dsl import Dataset, Metrics, Artifact +from unittest import mock +import pytest + + +import vertex_components + +custom_train_job = vertex_components.custom_train_job.python_func + + +def test_custom_train_job(tmpdir): + with mock.patch( + "google.cloud.aiplatform.CustomTrainingJob" + ) as mock_job, mock.patch("os.path.exists") as mock_exists, mock.patch( + "builtins.open", mock.mock_open(read_data="{}") + ) as mock_open: + + mock_exists.return_value = True + + mock_train_data = Dataset(uri=tmpdir) + mock_valid_data = Dataset(uri=tmpdir) + mock_test_data = Dataset(uri=tmpdir) + + mock_model = Artifact(uri=tmpdir, metadata={"resourceName": ""}) + mock_metrics = Metrics(uri=tmpdir) + + custom_train_job( + train_script_uri="gs://my-bucket/train_script.py", + train_data=mock_train_data, + valid_data=mock_valid_data, + test_data=mock_test_data, + project_id="my-project-id", + project_location="europe-west4", + model_display_name="my-model", + train_container_uri="gcr.io/my-project/my-image:latest", + serving_container_uri="gcr.io/my-project/my-serving-image:latest", + model=mock_model, + metrics=mock_metrics, + staging_bucket="gs://my-bucket", + job_name="my-job", + ) + + mock_job.assert_called_once_with( + project="my-project-id", + location="europe-west4", + staging_bucket="gs://my-bucket", + display_name="my-job", + script_path="/gcs/my-bucket/train_script.py", + container_uri="gcr.io/my-project/my-image:latest", + requirements=None, + model_serving_container_image_uri="gcr.io/my-project/my-serving-image:latest", # noqa: E501 + ) + + mock_open.assert_called_once_with(tmpdir, "r") + + +def test_custom_train_script_not_found(tmpdir): + with pytest.raises(ValueError), mock.patch( + "google.cloud.aiplatform.CustomTrainingJob" + ) as mock_job, mock.patch("os.path.exists") as mock_exists, mock.patch( + "builtins.open", mock.mock_open(read_data="{}") + ) as mock_open: + + mock_exists.return_value = False # Simulate script path not found + + mock_train_data = Dataset(uri=tmpdir) + mock_valid_data = Dataset(uri=tmpdir) + mock_test_data = Dataset(uri=tmpdir) + + mock_model = Artifact(uri=tmpdir, metadata={"resourceName": ""}) + mock_metrics = Metrics(uri=tmpdir) + + custom_train_job( + train_script_uri="gs://my-bucket/train_script.py", + train_data=mock_train_data, + valid_data=mock_valid_data, + test_data=mock_test_data, + project_id="my-project-id", + project_location="europe-west4", + model_display_name="my-model", + train_container_uri="gcr.io/my-project/my-image:latest", + serving_container_uri="gcr.io/my-project/my-serving-image:latest", + model=mock_model, + 
metrics=mock_metrics, + staging_bucket="gs://my-bucket", + job_name="my-job", + ) + + mock_job.assert_not_called() + mock_open.assert_not_called() diff --git a/components/vertex-components/tests/test_import_model_evaluation.py b/components/vertex-components/tests/test_import_model_evaluation.py new file mode 100644 index 00000000..494c78a2 --- /dev/null +++ b/components/vertex-components/tests/test_import_model_evaluation.py @@ -0,0 +1,51 @@ +from unittest import mock +from kfp.v2.dsl import Model, Metrics, Dataset + +import vertex_components +from google.cloud.aiplatform_v1 import ModelEvaluation + + +import_model_evaluation = vertex_components.import_model_evaluation.python_func + + +def test_import_model_evaluation(tmpdir): + with mock.patch( + "google.cloud.aiplatform_v1.ModelServiceClient" + ) as mock_service_client, mock.patch( + "builtins.open", + mock.mock_open(read_data='{"accuracy": 0.85, "problemType": "classification"}'), + create=True, + ) as mock_open, mock.patch( + "google.protobuf.json_format.ParseDict" + ) as mock_parse_dict: + + mock_model = Model(uri=tmpdir, metadata={"resourceName": ""}) + mock_metrics = Metrics(uri=tmpdir) + mock_dataset = Dataset(uri=tmpdir) + + mock_service_client_instance = mock_service_client.return_value + mock_service_client_instance.import_model_evaluation.return_value = ( + ModelEvaluation(name="model_evaluation_name") + ) + + # Set the return value for ParseDict to be a mock ModelEvaluation + mock_parse_dict.return_value = mock.MagicMock(spec=ModelEvaluation) + + model_evaluation_name = import_model_evaluation( + model=mock_model, + metrics=mock_metrics, + test_dataset=mock_dataset, + pipeline_job_id="1234", + project_location="my-location", + evaluation_name="Imported evaluation", + ) + + mock_service_client_instance.import_model_evaluation.assert_called_once_with( + parent=mock_model.metadata["resourceName"], + model_evaluation=mock_parse_dict.return_value, + ) + + # Check that open was called with the correct path + mock_open.assert_called_once_with(mock_metrics.uri) + + assert model_evaluation_name[0] == "model_evaluation_name" diff --git a/pipelines/Pipfile b/pipelines/Pipfile index 507b4e6d..9538e7fc 100644 --- a/pipelines/Pipfile +++ b/pipelines/Pipfile @@ -14,6 +14,7 @@ bigquery-components = {editable = true, path = "./../components/bigquery-compone [dev-packages] pytest = ">=7.3.1,<8.0.0" pre-commit = ">=2.14.1,<3.0.0" +coverage = "==7.2.5" [requires] python_version = "3.7" diff --git a/pipelines/Pipfile.lock b/pipelines/Pipfile.lock index 92d8d064..42ce01df 100644 --- a/pipelines/Pipfile.lock +++ b/pipelines/Pipfile.lock @@ -1,7 +1,7 @@ { "_meta": { "hash": { - "sha256": "c889b828537547fe3e76c0ca6f812f331bffdf13bfd8c72104145db94eb555e0" + "sha256": "2350f00a78fbfb02ef816a1ba3f5316c810e711a7cc685895cb135cb293d38cd" }, "pipfile-spec": 6, "requires": { @@ -192,11 +192,11 @@ }, "google-auth": { "hashes": [ - "sha256:ce311e2bc58b130fddf316df57c9b3943c2a7b4f6ec31de9663a9333e4064efc", - "sha256:f586b274d3eb7bd932ea424b1c702a30e0393a2e2bc4ca3eae8263ffd8be229f" + "sha256:55a395cdfd3f3dd3f649131d41f97c17b4ed8a2aac1be3502090c716314e8a37", + "sha256:d7a3249027e7f464fbbfd7ee8319a08ad09d2eea51578575c4bd360ffa049ccb" ], "markers": "python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3, 3.4, 3.5'", - "version": "==2.17.3" + "version": "==2.18.1" }, "google-auth-httplib2": { "hashes": [ @@ -246,11 +246,11 @@ }, "google-cloud-resource-manager": { "hashes": [ - "sha256:26beb595b957972df50173f1d0fd51c00d280551eac73566017ebdda62b1616a", - 
"sha256:bfc3e60eb92e25ac562a9248bb8fc17e9bef04c3dc9f031ffbe0dfe28d919287" + "sha256:41a2204532f084c707fde0bc1a9bc95c7e0b739d7072dd0b8a25106667a56184", + "sha256:c974fb6f9810476cf7b63ea89394c1a8df47f7f2dc2303e728bb74b500bcde67" ], "markers": "python_version >= '3.7'", - "version": "==1.10.0" + "version": "==1.10.1" }, "google-cloud-storage": { "hashes": [ @@ -360,53 +360,53 @@ }, "grpcio": { "hashes": [ - "sha256:02000b005bc8b72ff50c477b6431e8886b29961159e8b8d03c00b3dd9139baed", - "sha256:031bbd26656e0739e4b2c81c172155fb26e274b8d0312d67aefc730bcba915b6", - "sha256:1209d6b002b26e939e4c8ea37a3d5b4028eb9555394ea69fb1adbd4b61a10bb8", - "sha256:125ed35aa3868efa82eabffece6264bf638cfdc9f0cd58ddb17936684aafd0f8", - "sha256:1382bc499af92901c2240c4d540c74eae8a671e4fe9839bfeefdfcc3a106b5e2", - "sha256:16bca8092dd994f2864fdab278ae052fad4913f36f35238b2dd11af2d55a87db", - "sha256:1c59d899ee7160638613a452f9a4931de22623e7ba17897d8e3e348c2e9d8d0b", - "sha256:1d109df30641d050e009105f9c9ca5a35d01e34d2ee2a4e9c0984d392fd6d704", - "sha256:1fa7d6ddd33abbd3c8b3d7d07c56c40ea3d1891ce3cd2aa9fa73105ed5331866", - "sha256:21c4a1aae861748d6393a3ff7867473996c139a77f90326d9f4104bebb22d8b8", - "sha256:224166f06ccdaf884bf35690bf4272997c1405de3035d61384ccb5b25a4c1ca8", - "sha256:2262bd3512ba9e9f0e91d287393df6f33c18999317de45629b7bd46c40f16ba9", - "sha256:2585b3c294631a39b33f9f967a59b0fad23b1a71a212eba6bc1e3ca6e6eec9ee", - "sha256:27fb030a4589d2536daec5ff5ba2a128f4f155149efab578fe2de2cb21596d3d", - "sha256:30fbbce11ffeb4f9f91c13fe04899aaf3e9a81708bedf267bf447596b95df26b", - "sha256:3930669c9e6f08a2eed824738c3d5699d11cd47a0ecc13b68ed11595710b1133", - "sha256:3b170e441e91e4f321e46d3cc95a01cb307a4596da54aca59eb78ab0fc03754d", - "sha256:3db71c6f1ab688d8dfc102271cedc9828beac335a3a4372ec54b8bf11b43fd29", - "sha256:48cb7af77238ba16c77879009003f6b22c23425e5ee59cb2c4c103ec040638a5", - "sha256:49eace8ea55fbc42c733defbda1e4feb6d3844ecd875b01bb8b923709e0f5ec8", - "sha256:533eaf5b2a79a3c6f35cbd6a095ae99cac7f4f9c0e08bdcf86c130efd3c32adf", - "sha256:5942a3e05630e1ef5b7b5752e5da6582460a2e4431dae603de89fc45f9ec5aa9", - "sha256:62117486460c83acd3b5d85c12edd5fe20a374630475388cfc89829831d3eb79", - "sha256:650f5f2c9ab1275b4006707411bb6d6bc927886874a287661c3c6f332d4c068b", - "sha256:6dc1e2c9ac292c9a484ef900c568ccb2d6b4dfe26dfa0163d5bc815bb836c78d", - "sha256:73c238ef6e4b64272df7eec976bb016c73d3ab5a6c7e9cd906ab700523d312f3", - "sha256:775a2f70501370e5ba54e1ee3464413bff9bd85bd9a0b25c989698c44a6fb52f", - "sha256:860fcd6db7dce80d0a673a1cc898ce6bc3d4783d195bbe0e911bf8a62c93ff3f", - "sha256:87f47bf9520bba4083d65ab911f8f4c0ac3efa8241993edd74c8dd08ae87552f", - "sha256:960b176e0bb2b4afeaa1cd2002db1e82ae54c9b6e27ea93570a42316524e77cf", - "sha256:a7caf553ccaf715ec05b28c9b2ab2ee3fdb4036626d779aa09cf7cbf54b71445", - "sha256:a947d5298a0bbdd4d15671024bf33e2b7da79a70de600ed29ba7e0fef0539ebb", - "sha256:a97b0d01ae595c997c1d9d8249e2d2da829c2d8a4bdc29bb8f76c11a94915c9a", - "sha256:b7655f809e3420f80ce3bf89737169a9dce73238af594049754a1128132c0da4", - "sha256:c33744d0d1a7322da445c0fe726ea6d4e3ef2dfb0539eadf23dce366f52f546c", - "sha256:c55a9cf5cba80fb88c850915c865b8ed78d5e46e1f2ec1b27692f3eaaf0dca7e", - "sha256:d2f62fb1c914a038921677cfa536d645cb80e3dd07dc4859a3c92d75407b90a5", - "sha256:d8ae6e0df3a608e99ee1acafaafd7db0830106394d54571c1ece57f650124ce9", - "sha256:e355ee9da9c1c03f174efea59292b17a95e0b7b4d7d2a389265f731a9887d5a9", - "sha256:e3e526062c690517b42bba66ffe38aaf8bc99a180a78212e7b22baa86902f690", - 
"sha256:eb0807323572642ab73fd86fe53d88d843ce617dd1ddf430351ad0759809a0ae", - "sha256:ebff0738be0499d7db74d20dca9f22a7b27deae31e1bf92ea44924fd69eb6251", - "sha256:ed36e854449ff6c2f8ee145f94851fe171298e1e793f44d4f672c4a0d78064e7", - "sha256:ed3d458ded32ff3a58f157b60cc140c88f7ac8c506a1c567b2a9ee8a2fd2ce54", - "sha256:f4a7dca8ccd8023d916b900aa3c626f1bd181bd5b70159479b142f957ff420e4" - ], - "version": "==1.54.0" + "sha256:0212e2f7fdf7592e4b9d365087da30cb4d71e16a6f213120c89b4f8fb35a3ab3", + "sha256:09d4bfd84686cd36fd11fd45a0732c7628308d094b14d28ea74a81db0bce2ed3", + "sha256:1e623e0cf99a0ac114f091b3083a1848dbc64b0b99e181473b5a4a68d4f6f821", + "sha256:2288d76e4d4aa7ef3fe7a73c1c470b66ea68e7969930e746a8cd8eca6ef2a2ea", + "sha256:2296356b5c9605b73ed6a52660b538787094dae13786ba53080595d52df13a98", + "sha256:2a1e601ee31ef30a9e2c601d0867e236ac54c922d32ed9f727b70dd5d82600d5", + "sha256:2be88c081e33f20630ac3343d8ad9f1125f32987968e9c8c75c051c9800896e8", + "sha256:33d40954199bddbb6a78f8f6f2b2082660f381cd2583ec860a6c2fa7c8400c08", + "sha256:40e1cbf69d6741b40f750f3cccc64326f927ac6145a9914d33879e586002350c", + "sha256:46a057329938b08e5f0e12ea3d7aed3ecb20a0c34c4a324ef34e00cecdb88a12", + "sha256:4864f99aac207e3e45c5e26c6cbb0ad82917869abc2f156283be86c05286485c", + "sha256:4c44e1a765b31e175c391f22e8fc73b2a2ece0e5e6ff042743d8109b5d2eff9f", + "sha256:4cb283f630624ebb16c834e5ac3d7880831b07cbe76cb08ab7a271eeaeb8943e", + "sha256:5008964885e8d23313c8e5ea0d44433be9bfd7e24482574e8cc43c02c02fc796", + "sha256:50a9f075eeda5097aa9a182bb3877fe1272875e45370368ac0ee16ab9e22d019", + "sha256:51630c92591d6d3fe488a7c706bd30a61594d144bac7dee20c8e1ce78294f474", + "sha256:5cc928cfe6c360c1df636cf7991ab96f059666ac7b40b75a769410cc6217df9c", + "sha256:61f7203e2767800edee7a1e1040aaaf124a35ce0c7fe0883965c6b762defe598", + "sha256:66233ccd2a9371158d96e05d082043d47dadb18cbb294dc5accfdafc2e6b02a7", + "sha256:70fcac7b94f4c904152809a050164650ac81c08e62c27aa9f156ac518029ebbe", + "sha256:714242ad0afa63a2e6dabd522ae22e1d76e07060b5af2ddda5474ba4f14c2c94", + "sha256:782f4f8662a2157c4190d0f99eaaebc602899e84fb1e562a944e5025929e351c", + "sha256:7fc2b4edb938c8faa4b3c3ea90ca0dd89b7565a049e8e4e11b77e60e4ed2cc05", + "sha256:881d058c5ccbea7cc2c92085a11947b572498a27ef37d3eef4887f499054dca8", + "sha256:89dde0ac72a858a44a2feb8e43dc68c0c66f7857a23f806e81e1b7cc7044c9cf", + "sha256:8cdbcbd687e576d48f7886157c95052825ca9948c0ed2afdc0134305067be88b", + "sha256:8d6192c37a30a115f4663592861f50e130caed33efc4eec24d92ec881c92d771", + "sha256:96a41817d2c763b1d0b32675abeb9179aa2371c72aefdf74b2d2b99a1b92417b", + "sha256:9bdbb7624d65dc0ed2ed8e954e79ab1724526f09b1efa88dcd9a1815bf28be5f", + "sha256:9bf88004fe086c786dc56ef8dd6cb49c026833fdd6f42cb853008bce3f907148", + "sha256:a08920fa1a97d4b8ee5db2f31195de4a9def1a91bc003544eb3c9e6b8977960a", + "sha256:a2f5a1f1080ccdc7cbaf1171b2cf384d852496fe81ddedeb882d42b85727f610", + "sha256:b04202453941a63b36876a7172b45366dc0cde10d5fd7855c0f4a4e673c0357a", + "sha256:b38b3de8cff5bc70f8f9c615f51b48eff7313fc9aca354f09f81b73036e7ddfa", + "sha256:b52d00d1793d290c81ad6a27058f5224a7d5f527867e5b580742e1bd211afeee", + "sha256:b74ae837368cfffeb3f6b498688a123e6b960951be4dec0e869de77e7fa0439e", + "sha256:be48496b0e00460717225e7680de57c38be1d8629dc09dadcd1b3389d70d942b", + "sha256:c0e3155fc5335ec7b3b70f15230234e529ca3607b20a562b6c75fb1b1218874c", + "sha256:c2392f5b5d84b71d853918687d806c1aa4308109e5ca158a16e16a6be71041eb", + "sha256:c72956972e4b508dd39fdc7646637a791a9665b478e768ffa5f4fe42123d5de1", + 
"sha256:dc80c9c6b608bf98066a038e0172013a49cfa9a08d53335aefefda2c64fc68f4", + "sha256:e416c8baf925b5a1aff31f7f5aecc0060b25d50cce3a5a7255dc5cf2f1d4e5eb", + "sha256:f8da84bbc61a4e92af54dc96344f328e5822d574f767e9b08e1602bb5ddc254a", + "sha256:f900ed4ad7a0f1f05d35f955e0943944d5a75f607a836958c6b8ab2a81730ef2", + "sha256:fd6c6c29717724acf9fc1847c4515d57e4dc12762452457b9cb37461f30a81bb" + ], + "version": "==1.54.2" }, "grpcio-status": { "hashes": [ @@ -784,11 +784,11 @@ }, "setuptools": { "hashes": [ - "sha256:23aaf86b85ca52ceb801d32703f12d77517b2556af839621c641fca11287952b", - "sha256:f104fa03692a2602fa0fec6c6a9e63b6c8a968de13e17c026957dd1f53d80990" + "sha256:5df61bf30bb10c6f756eb19e7c9f3b473051f48db77fddbe06ff2ca307df9a6f", + "sha256:62642358adc77ffa87233bc4d2354c4b2682d214048f500964dbe760ccedf102" ], "markers": "python_version >= '3.7'", - "version": "==67.7.2" + "version": "==67.8.0" }, "shapely": { "hashes": [ @@ -1019,6 +1019,63 @@ "markers": "python_full_version >= '3.6.1'", "version": "==3.3.1" }, + "coverage": { + "hashes": [ + "sha256:0342a28617e63ad15d96dca0f7ae9479a37b7d8a295f749c14f3436ea59fdcb3", + "sha256:066b44897c493e0dcbc9e6a6d9f8bbb6607ef82367cf6810d387c09f0cd4fe9a", + "sha256:10b15394c13544fce02382360cab54e51a9e0fd1bd61ae9ce012c0d1e103c813", + "sha256:12580845917b1e59f8a1c2ffa6af6d0908cb39220f3019e36c110c943dc875b0", + "sha256:156192e5fd3dbbcb11cd777cc469cf010a294f4c736a2b2c891c77618cb1379a", + "sha256:1637253b11a18f453e34013c665d8bf15904c9e3c44fbda34c643fbdc9d452cd", + "sha256:292300f76440651529b8ceec283a9370532f4ecba9ad67d120617021bb5ef139", + "sha256:30dcaf05adfa69c2a7b9f7dfd9f60bc8e36b282d7ed25c308ef9e114de7fc23b", + "sha256:338aa9d9883aaaad53695cb14ccdeb36d4060485bb9388446330bef9c361c252", + "sha256:373ea34dca98f2fdb3e5cb33d83b6d801007a8074f992b80311fc589d3e6b790", + "sha256:38c0a497a000d50491055805313ed83ddba069353d102ece8aef5d11b5faf045", + "sha256:40cc0f91c6cde033da493227797be2826cbf8f388eaa36a0271a97a332bfd7ce", + "sha256:4436cc9ba5414c2c998eaedee5343f49c02ca93b21769c5fdfa4f9d799e84200", + "sha256:509ecd8334c380000d259dc66feb191dd0a93b21f2453faa75f7f9cdcefc0718", + "sha256:5c587f52c81211d4530fa6857884d37f514bcf9453bdeee0ff93eaaf906a5c1b", + "sha256:5f3671662dc4b422b15776cdca89c041a6349b4864a43aa2350b6b0b03bbcc7f", + "sha256:6599bf92f33ab041e36e06d25890afbdf12078aacfe1f1d08c713906e49a3fe5", + "sha256:6e8a95f243d01ba572341c52f89f3acb98a3b6d1d5d830efba86033dd3687ade", + "sha256:706ec567267c96717ab9363904d846ec009a48d5f832140b6ad08aad3791b1f5", + "sha256:780551e47d62095e088f251f5db428473c26db7829884323e56d9c0c3118791a", + "sha256:7ff8f3fb38233035028dbc93715551d81eadc110199e14bbbfa01c5c4a43f8d8", + "sha256:828189fcdda99aae0d6bf718ea766b2e715eabc1868670a0a07bf8404bf58c33", + "sha256:857abe2fa6a4973f8663e039ead8d22215d31db613ace76e4a98f52ec919068e", + "sha256:883123d0bbe1c136f76b56276074b0c79b5817dd4238097ffa64ac67257f4b6c", + "sha256:8877d9b437b35a85c18e3c6499b23674684bf690f5d96c1006a1ef61f9fdf0f3", + "sha256:8e575a59315a91ccd00c7757127f6b2488c2f914096077c745c2f1ba5b8c0969", + "sha256:97072cc90f1009386c8a5b7de9d4fc1a9f91ba5ef2146c55c1f005e7b5c5e068", + "sha256:9a22cbb5ede6fade0482111fa7f01115ff04039795d7092ed0db43522431b4f2", + "sha256:a063aad9f7b4c9f9da7b2550eae0a582ffc7623dca1c925e50c3fbde7a579771", + "sha256:a08c7401d0b24e8c2982f4e307124b671c6736d40d1c39e09d7a8687bddf83ed", + "sha256:a0b273fe6dc655b110e8dc89b8ec7f1a778d78c9fd9b4bda7c384c8906072212", + "sha256:a2b3b05e22a77bb0ae1a3125126a4e08535961c946b62f30985535ed40e26614", + 
"sha256:a66e055254a26c82aead7ff420d9fa8dc2da10c82679ea850d8feebf11074d88", + "sha256:aa387bd7489f3e1787ff82068b295bcaafbf6f79c3dad3cbc82ef88ce3f48ad3", + "sha256:ae453f655640157d76209f42c62c64c4d4f2c7f97256d3567e3b439bd5c9b06c", + "sha256:b5016e331b75310610c2cf955d9f58a9749943ed5f7b8cfc0bb89c6134ab0a84", + "sha256:b9a4ee55174b04f6af539218f9f8083140f61a46eabcaa4234f3c2a452c4ed11", + "sha256:bd3b4b8175c1db502adf209d06136c000df4d245105c8839e9d0be71c94aefe1", + "sha256:bebea5f5ed41f618797ce3ffb4606c64a5de92e9c3f26d26c2e0aae292f015c1", + "sha256:c10fbc8a64aa0f3ed136b0b086b6b577bc64d67d5581acd7cc129af52654384e", + "sha256:c2c41c1b1866b670573657d584de413df701f482574bad7e28214a2362cb1fd1", + "sha256:cf97ed82ca986e5c637ea286ba2793c85325b30f869bf64d3009ccc1a31ae3fd", + "sha256:d1f25ee9de21a39b3a8516f2c5feb8de248f17da7eead089c2e04aa097936b47", + "sha256:d2fbc2a127e857d2f8898aaabcc34c37771bf78a4d5e17d3e1f5c30cd0cbc62a", + "sha256:dc945064a8783b86fcce9a0a705abd7db2117d95e340df8a4333f00be5efb64c", + "sha256:ddc5a54edb653e9e215f75de377354e2455376f416c4378e1d43b08ec50acc31", + "sha256:e8834e5f17d89e05697c3c043d3e58a8b19682bf365048837383abfe39adaed5", + "sha256:ef9659d1cda9ce9ac9585c045aaa1e59223b143f2407db0eaee0b61a4f266fb6", + "sha256:f6f5cab2d7f0c12f8187a376cc6582c477d2df91d63f75341307fcdcb5d60303", + "sha256:f81c9b4bd8aa747d417407a7f6f0b1469a43b36a85748145e144ac4e8d303cb5", + "sha256:f99ef080288f09ffc687423b8d60978cf3a465d3f404a18d1a05474bd8575a47" + ], + "index": "pypi", + "version": "==7.2.5" + }, "distlib": { "hashes": [ "sha256:14bad2d9b04d3a36127ac97f30b12a19268f211063d8f8ee4f47108896e11b46", @@ -1068,11 +1125,11 @@ }, "nodeenv": { "hashes": [ - "sha256:27083a7b96a25f2f5e1d8cb4b6317ee8aeda3bdd121394e5ac54e498028a042e", - "sha256:e0e7f7dfb85fc5394c6fe1e8fa98131a2473e04311a45afb6508f7cf1836fa2b" + "sha256:d51e0c37e64fbf47d017feac3145cdbb58836d7eee8c6f6d3b6880c5456227d2", + "sha256:df865724bb3c3adc86b3876fa209771517b0cfe596beff01a92700e0e8be4cec" ], "markers": "python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3, 3.4, 3.5, 3.6'", - "version": "==1.7.0" + "version": "==1.8.0" }, "packaging": { "hashes": [ @@ -1084,11 +1141,11 @@ }, "platformdirs": { "hashes": [ - "sha256:47692bc24c1958e8b0f13dd727307cff1db103fca36399f457da8e05f222fdc4", - "sha256:7954a68d0ba23558d753f73437c55f89027cf8f5108c19844d4b82e5af396335" + "sha256:412dae91f52a6f84830f39a8078cecd0e866cb72294a5c66808e74d5e88d251f", + "sha256:e2378146f1964972c03c085bb5662ae80b2b8c06226c54b2ff4aa9483e8a13a5" ], "markers": "python_version >= '3.7'", - "version": "==3.5.0" + "version": "==3.5.1" }, "pluggy": { "hashes": [ @@ -1151,11 +1208,11 @@ }, "setuptools": { "hashes": [ - "sha256:23aaf86b85ca52ceb801d32703f12d77517b2556af839621c641fca11287952b", - "sha256:f104fa03692a2602fa0fec6c6a9e63b6c8a968de13e17c026957dd1f53d80990" + "sha256:5df61bf30bb10c6f756eb19e7c9f3b473051f48db77fddbe06ff2ca307df9a6f", + "sha256:62642358adc77ffa87233bc4d2354c4b2682d214048f500964dbe760ccedf102" ], "markers": "python_version >= '3.7'", - "version": "==67.7.2" + "version": "==67.8.0" }, "tomli": { "hashes": [ diff --git a/pipelines/pyproject.toml b/pipelines/pyproject.toml index a5e2cbb7..693a3bd8 100644 --- a/pipelines/pyproject.toml +++ b/pipelines/pyproject.toml @@ -20,6 +20,7 @@ dependencies = [ tests = [ "google-cloud-bigquery == 2.30.0", "pytest >= 7.3.1,<8.0.0", + "coverage = ==7.2.5" ] [build-system] From 86b47b23e361fb90d6e44239335fbfc9c455433f Mon Sep 17 00:00:00 2001 From: ariadnafer Date: Mon, 22 May 2023 13:30:25 +0200 Subject: [PATCH 2/3] 
docs: add comments in tests --- .../tests/test_extract_bq_to_dataset.py | 6 +++--- .../vertex-components/tests/test_custom_training_job.py | 9 ++++++++- .../tests/test_import_model_evaluation.py | 9 +++++++++ 3 files changed, 20 insertions(+), 4 deletions(-) diff --git a/components/bigquery-components/tests/test_extract_bq_to_dataset.py b/components/bigquery-components/tests/test_extract_bq_to_dataset.py index 8d83a62b..db00da21 100644 --- a/components/bigquery-components/tests/test_extract_bq_to_dataset.py +++ b/components/bigquery-components/tests/test_extract_bq_to_dataset.py @@ -37,7 +37,7 @@ def test_extract_bq_to_dataset(tmpdir): skip_if_exists=False, ) - # Check that Client.extract_table was called correctly + # Check that client.extract_table was called correctly mock_client.return_value.extract_table.assert_called_once_with( mock_table.return_value, "gs://mock_bucket", job_config="mock-job-config" ) @@ -50,13 +50,13 @@ def test_extract_bq_to_dataset_skip_existing(tmpdir): "pathlib.Path.exists" ) as mock_path_exists: - # # Mock the Dataset path + # Mock the Dataset path mock_path = tmpdir # Mock that the destination already exists mock_path_exists.return_value = True - # Call the function with skip_if_exists set to True + # Call the function extract_bq_to_dataset( bq_client_project_id="my-project-id", source_project_id="source-project-id", diff --git a/components/vertex-components/tests/test_custom_training_job.py b/components/vertex-components/tests/test_custom_training_job.py index 2658f318..cb075c16 100644 --- a/components/vertex-components/tests/test_custom_training_job.py +++ b/components/vertex-components/tests/test_custom_training_job.py @@ -16,8 +16,10 @@ def test_custom_train_job(tmpdir): "builtins.open", mock.mock_open(read_data="{}") ) as mock_open: + # Mock that the training script exists mock_exists.return_value = True + # Mock Artifacts mock_train_data = Dataset(uri=tmpdir) mock_valid_data = Dataset(uri=tmpdir) mock_test_data = Dataset(uri=tmpdir) @@ -25,6 +27,7 @@ def test_custom_train_job(tmpdir): mock_model = Artifact(uri=tmpdir, metadata={"resourceName": ""}) mock_metrics = Metrics(uri=tmpdir) + # Call function custom_train_job( train_script_uri="gs://my-bucket/train_script.py", train_data=mock_train_data, @@ -41,6 +44,7 @@ def test_custom_train_job(tmpdir): job_name="my-job", ) + # Assert custom training job is called mock_job.assert_called_once_with( project="my-project-id", location="europe-west4", @@ -52,6 +56,7 @@ def test_custom_train_job(tmpdir): model_serving_container_image_uri="gcr.io/my-project/my-serving-image:latest", # noqa: E501 ) + # Assert metrics loading mock_open.assert_called_once_with(tmpdir, "r") @@ -62,7 +67,8 @@ def test_custom_train_script_not_found(tmpdir): "builtins.open", mock.mock_open(read_data="{}") ) as mock_open: - mock_exists.return_value = False # Simulate script path not found + # Mock that the training script is not found + mock_exists.return_value = False mock_train_data = Dataset(uri=tmpdir) mock_valid_data = Dataset(uri=tmpdir) @@ -87,5 +93,6 @@ def test_custom_train_script_not_found(tmpdir): job_name="my-job", ) + # Assert the custom training job is not executed mock_job.assert_not_called() mock_open.assert_not_called() diff --git a/components/vertex-components/tests/test_import_model_evaluation.py b/components/vertex-components/tests/test_import_model_evaluation.py index 494c78a2..fe1f85eb 100644 --- a/components/vertex-components/tests/test_import_model_evaluation.py +++ 
b/components/vertex-components/tests/test_import_model_evaluation.py @@ -19,11 +19,15 @@ def test_import_model_evaluation(tmpdir): "google.protobuf.json_format.ParseDict" ) as mock_parse_dict: + # Mock Artifacts mock_model = Model(uri=tmpdir, metadata={"resourceName": ""}) mock_metrics = Metrics(uri=tmpdir) mock_dataset = Dataset(uri=tmpdir) + # Create an instance of the mocked ModelServiceClient. mock_service_client_instance = mock_service_client.return_value + # When import_model_evaluation is called during the test, + # it will return a new ModelEvaluation with the specified name. mock_service_client_instance.import_model_evaluation.return_value = ( ModelEvaluation(name="model_evaluation_name") ) @@ -31,6 +35,7 @@ def test_import_model_evaluation(tmpdir): # Set the return value for ParseDict to be a mock ModelEvaluation mock_parse_dict.return_value = mock.MagicMock(spec=ModelEvaluation) + # Call the function model_evaluation_name = import_model_evaluation( model=mock_model, metrics=mock_metrics, @@ -40,6 +45,8 @@ def test_import_model_evaluation(tmpdir): evaluation_name="Imported evaluation", ) + # Assert that the import_model_evaluation method of + # the mocked ModelServiceClient was called mock_service_client_instance.import_model_evaluation.assert_called_once_with( parent=mock_model.metadata["resourceName"], model_evaluation=mock_parse_dict.return_value, @@ -48,4 +55,6 @@ def test_import_model_evaluation(tmpdir): # Check that open was called with the correct path mock_open.assert_called_once_with(mock_metrics.uri) + # Assert that the return value of the import_model_evaluation + # function is as expected. assert model_evaluation_name[0] == "model_evaluation_name" From 6d9c43fce26fb6c9d4bf0645d526ee7a4bd6aa22 Mon Sep 17 00:00:00 2001 From: ariadnafer Date: Tue, 23 May 2023 11:16:11 +0200 Subject: [PATCH 3/3] refactor: change patch formatting in all tests --- .../tests/test_extract_bq_to_dataset.py | 119 +++++++-------- .../tests/test_custom_training_job.py | 137 +++++++++--------- .../tests/test_import_model_evaluation.py | 97 ++++++------- .../tests/test_lookup_model.py | 126 +++++++--------- .../tests/test_model_batch_predict.py | 59 ++++---- .../tests/test_update_best_model.py | 43 +++--- 6 files changed, 275 insertions(+), 306 deletions(-) diff --git a/components/bigquery-components/tests/test_extract_bq_to_dataset.py b/components/bigquery-components/tests/test_extract_bq_to_dataset.py index db00da21..be3084e6 100644 --- a/components/bigquery-components/tests/test_extract_bq_to_dataset.py +++ b/components/bigquery-components/tests/test_extract_bq_to_dataset.py @@ -7,67 +7,58 @@ extract_bq_to_dataset = bigquery_components.extract_bq_to_dataset.python_func -def test_extract_bq_to_dataset(tmpdir): - with mock.patch("google.cloud.bigquery.client.Client") as mock_client, mock.patch( - "google.cloud.bigquery.job.ExtractJobConfig" - ) as mock_job_config, mock.patch("google.cloud.bigquery.table.Table") as mock_table: - - # Mock the Dataset path - mock_path = tmpdir - - # Set up the mock Client - mock_client.extract_table.return_value = "my-job" - - # Set up the mock Table - mock_table.return_value.table_ref = "my-table" - - # Set up the mock ExtractJob - mock_job_config.return_value = "mock-job-config" - - # Call the function - extract_bq_to_dataset( - bq_client_project_id="my-project-id", - source_project_id="source-project-id", - dataset_id="dataset-id", - table_name="table-name", - dataset=Dataset(uri=mock_path), - destination_gcs_uri="gs://mock_bucket", - dataset_location="EU", - 
extract_job_config=None,
-            skip_if_exists=False,
-        )
-
-        # Check that client.extract_table was called correctly
-        mock_client.return_value.extract_table.assert_called_once_with(
-            mock_table.return_value, "gs://mock_bucket", job_config="mock-job-config"
-        )
-
-
-def test_extract_bq_to_dataset_skip_existing(tmpdir):
-    with mock.patch("google.cloud.bigquery.client.Client") as mock_client, mock.patch(
-        "google.cloud.bigquery.table.Table"
-    ), mock.patch("google.cloud.bigquery.job.ExtractJobConfig"), mock.patch(
-        "pathlib.Path.exists"
-    ) as mock_path_exists:
-
-        # Mock the Dataset path
-        mock_path = tmpdir
-
-        # Mock that the destination already exists
-        mock_path_exists.return_value = True
-
-        # Call the function
-        extract_bq_to_dataset(
-            bq_client_project_id="my-project-id",
-            source_project_id="source-project-id",
-            dataset_id="dataset-id",
-            table_name="table-name",
-            dataset=Dataset(uri=mock_path),
-            destination_gcs_uri="gs://mock_bucket",
-            dataset_location="EU",
-            extract_job_config=None,
-            skip_if_exists=True,
-        )
-
-        # Ensure that Client.extract_table was not called
-        assert not mock_client.return_value.extract_table.called
+@mock.patch("google.cloud.bigquery.client.Client")
+@mock.patch("google.cloud.bigquery.table.Table")
+@mock.patch("google.cloud.bigquery.job.ExtractJobConfig")
+def test_extract_bq_to_dataset(mock_job_config, mock_table, mock_client, tmpdir):
+    """
+    Checks that extract_table is called with the expected arguments
+    """
+    mock_path = tmpdir
+    mock_client.extract_table.return_value = "my-job"
+    mock_table.return_value.table_ref = "my-table"
+    mock_job_config.return_value = "mock-job-config"
+
+    extract_bq_to_dataset(
+        bq_client_project_id="my-project-id",
+        source_project_id="source-project-id",
+        dataset_id="dataset-id",
+        table_name="table-name",
+        dataset=Dataset(uri=mock_path),
+        destination_gcs_uri="gs://mock_bucket",
+        dataset_location="EU",
+        extract_job_config=None,
+        skip_if_exists=False,
+    )
+
+    mock_client.return_value.extract_table.assert_called_once_with(
+        mock_table.return_value, "gs://mock_bucket", job_config="mock-job-config"
+    )
+
+
+@mock.patch("google.cloud.bigquery.client.Client")
+@mock.patch("google.cloud.bigquery.table.Table")
+@mock.patch("google.cloud.bigquery.job.ExtractJobConfig")
+@mock.patch("pathlib.Path.exists")
+def test_extract_bq_to_dataset_skip_existing(
+    mock_path_exists, mock_job_config, mock_table, mock_client, tmpdir
+):
+    """
+    Checks that extract_table is not called when the dataset already exists
+    """
+    mock_path = tmpdir
+    mock_path_exists.return_value = True
+
+    extract_bq_to_dataset(
+        bq_client_project_id="my-project-id",
+        source_project_id="source-project-id",
+        dataset_id="dataset-id",
+        table_name="table-name",
+        dataset=Dataset(uri=mock_path),
+        destination_gcs_uri="gs://mock_bucket",
+        dataset_location="EU",
+        extract_job_config=None,
+        skip_if_exists=True,
+    )
+
+    assert not mock_client.return_value.extract_table.called
diff --git a/components/vertex-components/tests/test_custom_training_job.py b/components/vertex-components/tests/test_custom_training_job.py
index cb075c16..46ec46ca 100644
--- a/components/vertex-components/tests/test_custom_training_job.py
+++ b/components/vertex-components/tests/test_custom_training_job.py
@@ -9,74 +9,69 @@
 custom_train_job = vertex_components.custom_train_job.python_func
 
 
-def test_custom_train_job(tmpdir):
-    with mock.patch(
-        "google.cloud.aiplatform.CustomTrainingJob"
-    ) as mock_job, mock.patch("os.path.exists") as mock_exists, mock.patch(
-        "builtins.open", 
mock.mock_open(read_data="{}") - ) as mock_open: - - # Mock that the training script exists - mock_exists.return_value = True - - # Mock Artifacts - mock_train_data = Dataset(uri=tmpdir) - mock_valid_data = Dataset(uri=tmpdir) - mock_test_data = Dataset(uri=tmpdir) - - mock_model = Artifact(uri=tmpdir, metadata={"resourceName": ""}) - mock_metrics = Metrics(uri=tmpdir) - - # Call function - custom_train_job( - train_script_uri="gs://my-bucket/train_script.py", - train_data=mock_train_data, - valid_data=mock_valid_data, - test_data=mock_test_data, - project_id="my-project-id", - project_location="europe-west4", - model_display_name="my-model", - train_container_uri="gcr.io/my-project/my-image:latest", - serving_container_uri="gcr.io/my-project/my-serving-image:latest", - model=mock_model, - metrics=mock_metrics, - staging_bucket="gs://my-bucket", - job_name="my-job", - ) - - # Assert custom training job is called - mock_job.assert_called_once_with( - project="my-project-id", - location="europe-west4", - staging_bucket="gs://my-bucket", - display_name="my-job", - script_path="/gcs/my-bucket/train_script.py", - container_uri="gcr.io/my-project/my-image:latest", - requirements=None, - model_serving_container_image_uri="gcr.io/my-project/my-serving-image:latest", # noqa: E501 - ) - - # Assert metrics loading - mock_open.assert_called_once_with(tmpdir, "r") - - -def test_custom_train_script_not_found(tmpdir): - with pytest.raises(ValueError), mock.patch( - "google.cloud.aiplatform.CustomTrainingJob" - ) as mock_job, mock.patch("os.path.exists") as mock_exists, mock.patch( - "builtins.open", mock.mock_open(read_data="{}") - ) as mock_open: - - # Mock that the training script is not found - mock_exists.return_value = False - - mock_train_data = Dataset(uri=tmpdir) - mock_valid_data = Dataset(uri=tmpdir) - mock_test_data = Dataset(uri=tmpdir) - - mock_model = Artifact(uri=tmpdir, metadata={"resourceName": ""}) - mock_metrics = Metrics(uri=tmpdir) - +@mock.patch("google.cloud.aiplatform.CustomTrainingJob") +@mock.patch("os.path.exists") +@mock.patch("builtins.open", new_callable=mock.mock_open, read_data="{}") +def test_custom_train_job(mock_open, mock_exists, mock_job, tmpdir): + """ + Checks that the custom job method is called + """ + mock_exists.return_value = True + + mock_train_data = Dataset(uri=tmpdir) + mock_valid_data = Dataset(uri=tmpdir) + mock_test_data = Dataset(uri=tmpdir) + mock_model = Artifact(uri=tmpdir, metadata={"resourceName": ""}) + mock_metrics = Metrics(uri=tmpdir) + + custom_train_job( + train_script_uri="gs://my-bucket/train_script.py", + train_data=mock_train_data, + valid_data=mock_valid_data, + test_data=mock_test_data, + project_id="my-project-id", + project_location="europe-west4", + model_display_name="my-model", + train_container_uri="gcr.io/my-project/my-image:latest", + serving_container_uri="gcr.io/my-project/my-serving-image:latest", + model=mock_model, + metrics=mock_metrics, + staging_bucket="gs://my-bucket", + job_name="my-job", + ) + + mock_job.assert_called_once_with( + project="my-project-id", + location="europe-west4", + staging_bucket="gs://my-bucket", + display_name="my-job", + script_path="/gcs/my-bucket/train_script.py", + container_uri="gcr.io/my-project/my-image:latest", + requirements=None, + model_serving_container_image_uri="gcr.io/my-project/my-serving-image:latest", # noqa: E501 + ) + + # Assert metrics loading + mock_open.assert_called_once_with(tmpdir, "r") + + +@mock.patch("google.cloud.aiplatform.CustomTrainingJob") 
+@mock.patch("os.path.exists") +@mock.patch("builtins.open", new_callable=mock.mock_open, read_data="{}") +def test_custom_train_script_not_found(mock_open, mock_exists, mock_job, tmpdir): + """ + Checks that when the training script is not found + the method fails + """ + mock_exists.return_value = False + + mock_train_data = Dataset(uri=tmpdir) + mock_valid_data = Dataset(uri=tmpdir) + mock_test_data = Dataset(uri=tmpdir) + mock_model = Artifact(uri=tmpdir, metadata={"resourceName": ""}) + mock_metrics = Metrics(uri=tmpdir) + + with pytest.raises(ValueError): custom_train_job( train_script_uri="gs://my-bucket/train_script.py", train_data=mock_train_data, @@ -93,6 +88,6 @@ def test_custom_train_script_not_found(tmpdir): job_name="my-job", ) - # Assert the custom training job is not executed - mock_job.assert_not_called() - mock_open.assert_not_called() + # Assert the custom training job is not executed + mock_job.assert_not_called() + mock_open.assert_not_called() diff --git a/components/vertex-components/tests/test_import_model_evaluation.py b/components/vertex-components/tests/test_import_model_evaluation.py index fe1f85eb..56cdf8bd 100644 --- a/components/vertex-components/tests/test_import_model_evaluation.py +++ b/components/vertex-components/tests/test_import_model_evaluation.py @@ -8,53 +8,50 @@ import_model_evaluation = vertex_components.import_model_evaluation.python_func -def test_import_model_evaluation(tmpdir): - with mock.patch( - "google.cloud.aiplatform_v1.ModelServiceClient" - ) as mock_service_client, mock.patch( - "builtins.open", - mock.mock_open(read_data='{"accuracy": 0.85, "problemType": "classification"}'), - create=True, - ) as mock_open, mock.patch( - "google.protobuf.json_format.ParseDict" - ) as mock_parse_dict: - - # Mock Artifacts - mock_model = Model(uri=tmpdir, metadata={"resourceName": ""}) - mock_metrics = Metrics(uri=tmpdir) - mock_dataset = Dataset(uri=tmpdir) - - # Create an instance of the mocked ModelServiceClient. - mock_service_client_instance = mock_service_client.return_value - # When import_model_evaluation is called during the test, - # it will return a new ModelEvaluation with the specified name. - mock_service_client_instance.import_model_evaluation.return_value = ( - ModelEvaluation(name="model_evaluation_name") - ) - - # Set the return value for ParseDict to be a mock ModelEvaluation - mock_parse_dict.return_value = mock.MagicMock(spec=ModelEvaluation) - - # Call the function - model_evaluation_name = import_model_evaluation( - model=mock_model, - metrics=mock_metrics, - test_dataset=mock_dataset, - pipeline_job_id="1234", - project_location="my-location", - evaluation_name="Imported evaluation", - ) - - # Assert that the import_model_evaluation method of - # the mocked ModelServiceClient was called - mock_service_client_instance.import_model_evaluation.assert_called_once_with( - parent=mock_model.metadata["resourceName"], - model_evaluation=mock_parse_dict.return_value, - ) - - # Check that open was called with the correct path - mock_open.assert_called_once_with(mock_metrics.uri) - - # Assert that the return value of the import_model_evaluation - # function is as expected. 
- assert model_evaluation_name[0] == "model_evaluation_name" +@mock.patch("google.cloud.aiplatform_v1.ModelServiceClient") +@mock.patch( + "builtins.open", + new_callable=mock.mock_open, + read_data='{"accuracy": 0.85, "problemType": "classification"}', +) +@mock.patch("google.protobuf.json_format.ParseDict") +def test_import_model_evaluation( + mock_parse_dict, mock_open, mock_service_client, tmpdir +): + """ + Checks that when the model evaluation is running and it is writing the metrics + """ + mock_model = Model(uri=tmpdir, metadata={"resourceName": ""}) + mock_metrics = Metrics(uri=tmpdir) + mock_dataset = Dataset(uri=tmpdir) + + # Create an instance of the mocked ModelServiceClient. + service_client_instance = mock.MagicMock() + mock_service_client.return_value = service_client_instance + # When import_model_evaluation is called during the test, + # it will return a new ModelEvaluation with the specified name. + service_client_instance.import_model_evaluation.return_value = ModelEvaluation( + name="model_evaluation_name" + ) + + # Set the return value for ParseDict to be a mock ModelEvaluation + mock_parse_dict.return_value = mock.MagicMock(spec=ModelEvaluation) + + model_evaluation_name = import_model_evaluation( + model=mock_model, + metrics=mock_metrics, + test_dataset=mock_dataset, + pipeline_job_id="1234", + project_location="my-location", + evaluation_name="Imported evaluation", + ) + + service_client_instance.import_model_evaluation.assert_called_once_with( + parent=mock_model.metadata["resourceName"], + model_evaluation=mock_parse_dict.return_value, + ) + + # Check that open was called with the correct path + mock_open.assert_called_once_with(mock_metrics.uri) + + assert model_evaluation_name[0] == "model_evaluation_name" diff --git a/components/vertex-components/tests/test_lookup_model.py b/components/vertex-components/tests/test_lookup_model.py index 1ba31f81..6950dd1b 100644 --- a/components/vertex-components/tests/test_lookup_model.py +++ b/components/vertex-components/tests/test_lookup_model.py @@ -22,93 +22,75 @@ lookup_model = vertex_components.lookup_model.python_func -def test_lookup_model(tmpdir): +@mock.patch("google.cloud.aiplatform.Model") +def test_lookup_model(mock_model, tmpdir): """ Assert lookup_model produces expected resource name, and that list method is called with the correct arguemnts - - Args: - tmpdir: built-in pytest tmpdir fixture - - Returns: - None """ - with mock.patch("google.cloud.aiplatform.Model") as mock_model: - - # Mock attribute and method - - mock_path = tmpdir - mock_model.resource_name = "my-model-resource-name" - mock_model.uri = mock_path - mock_model.list.return_value = [mock_model] - - # Invoke the model look up - found_model_resource_name, _ = lookup_model( - model_name="my-model", - project_location="europe-west4", - project_id="my-project-id", - order_models_by="create_time desc", - fail_on_model_not_found=False, - model=Model(uri=mock_path), - ) - - assert found_model_resource_name == "my-model-resource-name" - - # Check the list method was called once with the correct arguments - mock_model.list.assert_called_once_with( - filter='display_name="my-model"', - order_by="create_time desc", - location="europe-west4", - project="my-project-id", - ) - -def test_lookup_model_when_no_models(tmpdir): + # Mock attribute and method + mock_path = tmpdir + mock_model.resource_name = "my-model-resource-name" + mock_model.uri = mock_path + mock_model.list.return_value = [mock_model] + + # Invoke the model look up + found_model_resource_name, 
_ = lookup_model( + model_name="my-model", + project_location="europe-west4", + project_id="my-project-id", + order_models_by="create_time desc", + fail_on_model_not_found=False, + model=Model(uri=mock_path), + ) + + assert found_model_resource_name == "my-model-resource-name" + + # Check the list method was called once with the correct arguments + mock_model.list.assert_called_once_with( + filter='display_name="my-model"', + order_by="create_time desc", + location="europe-west4", + project="my-project-id", + ) + + +@mock.patch("google.cloud.aiplatform.Model") +def test_lookup_model_when_no_models(mock_model, tmpdir): """ Checks that when there are no models and fail_on_model_found = False, lookup_model returns an empty string. - - Args: - tmpdir: built-in pytest tmpdir fixture - - Returns: - None """ - with mock.patch("google.cloud.aiplatform.Model") as mock_model: - mock_model.list.return_value = [] - exported_model_resource_name, _ = lookup_model( - model_name="my-model", - project_location="europe-west4", - project_id="my-project-id", - order_models_by="create_time desc", - fail_on_model_not_found=False, - model=Model(uri=str(tmpdir)), - ) + mock_model.list.return_value = [] + exported_model_resource_name, _ = lookup_model( + model_name="my-model", + project_location="europe-west4", + project_id="my-project-id", + order_models_by="create_time desc", + fail_on_model_not_found=False, + model=Model(uri=str(tmpdir)), + ) + print(exported_model_resource_name) assert exported_model_resource_name == "" -def test_lookup_model_when_no_models_fail(tmpdir): +@mock.patch("google.cloud.aiplatform.Model") +def test_lookup_model_when_no_models_fail(mock_model, tmpdir): """ Checks that when there are no models and fail_on_model_found = True, lookup_model raises a RuntimeError. - - Args: - tmpdir: built-in pytest tmpdir fixture - - Returns: - None """ - with mock.patch("google.cloud.aiplatform.Model") as mock_model: - mock_model.list.return_value = [] + mock_model.list.return_value = [] - # Verify that a ValueError is raised - with pytest.raises(RuntimeError): - lookup_model( - model_name="my-model", - project_location="europe-west4", - project_id="my-project-id", - order_models_by="create_time desc", - fail_on_model_not_found=True, - model=Model(uri=str(tmpdir)), - ) + # Verify that a ValueError is raised + with pytest.raises(RuntimeError): + lookup_model( + model_name="my-model", + project_location="europe-west4", + project_id="my-project-id", + order_models_by="create_time desc", + fail_on_model_not_found=True, + model=Model(uri=str(tmpdir)), + ) diff --git a/components/vertex-components/tests/test_model_batch_predict.py b/components/vertex-components/tests/test_model_batch_predict.py index 10d9e808..76bca67d 100644 --- a/components/vertex-components/tests/test_model_batch_predict.py +++ b/components/vertex-components/tests/test_model_batch_predict.py @@ -13,7 +13,7 @@ # limitations under the License. 
import json import pytest -from unittest.mock import Mock, patch +from unittest import mock from kfp.v2.dsl import Model from google.cloud.aiplatform_v1beta1.types.job_state import JobState @@ -30,7 +30,19 @@ "targetField": "col", } +mock_job1 = mock.Mock() +mock_job1.name = "mock-batch-job" +mock_job1.state = JobState.JOB_STATE_SUCCEEDED + +@mock.patch( + "google.cloud.aiplatform_v1beta1.services.job_service.JobServiceClient.create_batch_prediction_job", # noqa : E501 + return_value=mock_job1, +) +@mock.patch( + "google.cloud.aiplatform_v1beta1.services.job_service.JobServiceClient.get_batch_prediction_job", # noqa : E501 + return_value=mock_job1, +) @pytest.mark.parametrize( ( "source_format,destination_format,source_uri,monitoring_training_dataset," @@ -44,6 +56,8 @@ ], ) def test_model_batch_predict( + create_job, + get_job, tmpdir, source_format, destination_format, @@ -55,37 +69,22 @@ def test_model_batch_predict( """ Asserts model_batch_predict successfully creates requests given different arguments. """ - mock_resource_name = "mock-batch-job" - - mock_job1 = Mock() - mock_job1.name = mock_resource_name - mock_job1.state = JobState.JOB_STATE_SUCCEEDED - mock_model = Model(uri=tmpdir, metadata={"resourceName": ""}) - with patch( - "google.cloud.aiplatform_v1beta1.services.job_service.JobServiceClient.create_batch_prediction_job", # noqa: E501 - return_value=mock_job1, - ) as create_job, patch( - "google.cloud.aiplatform_v1beta1.services.job_service.JobServiceClient.get_batch_prediction_job", # noqa: E501 - return_value=mock_job1, - ) as get_job: - (gcp_resources,) = model_batch_predict( - model=mock_model, - job_display_name="", - project_location="", - project_id="", - source_uri=source_uri, - destination_uri=destination_format, - source_format=source_format, - destination_format=destination_format, - monitoring_training_dataset=monitoring_training_dataset, - monitoring_alert_email_addresses=monitoring_alert_email_addresses, - monitoring_skew_config=monitoring_skew_config, - ) + (gcp_resources,) = model_batch_predict( + model=mock_model, + job_display_name="", + project_location="", + project_id="", + source_uri=source_uri, + destination_uri=destination_format, + source_format=source_format, + destination_format=destination_format, + monitoring_training_dataset=monitoring_training_dataset, + monitoring_alert_email_addresses=monitoring_alert_email_addresses, + monitoring_skew_config=monitoring_skew_config, + ) create_job.assert_called_once() get_job.assert_called_once() - assert ( - json.loads(gcp_resources)["resources"][0]["resourceUri"] == mock_resource_name - ) + assert json.loads(gcp_resources)["resources"][0]["resourceUri"] == mock_job1.name diff --git a/components/vertex-components/tests/test_update_best_model.py b/components/vertex-components/tests/test_update_best_model.py index bce9a9df..4ef6ba21 100644 --- a/components/vertex-components/tests/test_update_best_model.py +++ b/components/vertex-components/tests/test_update_best_model.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from unittest.mock import patch
+from unittest import mock
 
 from kfp.v2.dsl import Model
 
@@ -21,26 +21,31 @@
 update_best_model = vertex_components.update_best_model.python_func
 
 
-def test_model_batch_predict(tmpdir):
+@mock.patch("google.cloud.aiplatform.Model")
+@mock.patch("google.cloud.aiplatform.model_evaluation.ModelEvaluation")
+@mock.patch("google.cloud.aiplatform.models.ModelRegistry")
+@mock.patch("google.protobuf.json_format.MessageToDict")
+def test_update_best_model(
+    mock_message_to_dict,
+    mock_model_registry,
+    mock_model_evaluation,
+    mock_model_class,
+    tmpdir,
+):
     """
-    Asserts model_batch_predict successfully creates requests given different arguments.
+    Asserts update_best_model correctly determines whether the challenger model wins.
     """
     mock_model = Model(uri=tmpdir, metadata={"resourceName": ""})
     mock_message = {"metrics": {"rmse": 0.01}}
-
-    with patch("google.cloud.aiplatform.Model",), patch(
-        "google.cloud.aiplatform.model_evaluation.ModelEvaluation",
-    ), patch("google.cloud.aiplatform.models.ModelRegistry",), patch(
-        "google.protobuf.json_format.MessageToDict", return_value=mock_message
-    ):
-
-        (challenger_wins,) = update_best_model(
-            challenger=mock_model,
-            challenger_evaluation="",
-            parent_model="",
-            project_id="",
-            project_location="",
-            eval_metric="rmse",
-            eval_lower_is_better=True,
-        )
-        assert not challenger_wins
+    mock_message_to_dict.return_value = mock_message
+
+    (challenger_wins,) = update_best_model(
+        challenger=mock_model,
+        challenger_evaluation="",
+        parent_model="",
+        project_id="",
+        project_location="",
+        eval_metric="rmse",
+        eval_lower_is_better=True,
+    )
+    assert not challenger_wins
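
--
Notes for reviewers (not part of the patches above):

The new Makefile targets from patch 1 can be exercised locally with
`make test-components-coverage GROUP=<component group>` for a single group, or
`make test-all-components-coverage` for every group under components/.

Patch 3 moves every test from nested `with mock.patch(...)` context managers to
stacked `@mock.patch` decorators. When reading the new signatures, keep in mind
that decorators are applied bottom-up, so the decorator closest to the `def`
supplies the first mock argument. A minimal, self-contained sketch of the
convention (a hypothetical test, not taken from the patches):

    from unittest import mock

    @mock.patch("os.path.exists")  # outermost decorator -> last mock argument
    @mock.patch("builtins.open", new_callable=mock.mock_open, read_data="{}")
    def test_example(mock_open, mock_exists):  # innermost decorator -> first argument
        mock_exists.return_value = True
        with open("config.json") as f:  # hits the mocked open, not the filesystem
            assert f.read() == "{}"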