Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feature: add tests #15

Merged
merged 3 commits into from
May 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,18 @@ test-all-components: ## Run unit tests for all pipeline components
$(MAKE) test-components GROUP=$$(basename $$component_group) ; \
done

# Run unit tests with coverage reporting for a single component group.
# Usage: make test-components-coverage GROUP=<component-group-name>
# Requires a Pipfile under components/<GROUP>/ (run via pipenv).
test-components-coverage: ## Run tests with coverage for one component group
@cd "components/${GROUP}" && \
pipenv run coverage run -m pytest && \
pipenv run coverage report -m

# Run coverage tests for every component group under components/.
# Iterates over components/*/ and delegates each group to the
# test-components-coverage target; `set -e` aborts on the first failure.
test-all-components-coverage: ## Run tests with coverage for all component groups
@set -e && \
for component_group in components/*/ ; do \
echo "Test components under $$component_group" && \
$(MAKE) test-components-coverage GROUP=$$(basename $$component_group) ; \
done

sync-assets: ## Sync assets folder to GCS. Must specify pipeline=<training|prediction>
@if [ -d "./pipelines/src/pipelines/${PIPELINE_TEMPLATE}/$(pipeline)/assets/" ] ; then \
echo "Syncing assets to GCS" && \
Expand Down
64 changes: 64 additions & 0 deletions components/bigquery-components/tests/test_extract_bq_to_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import google.cloud.bigquery # noqa
from kfp.v2.dsl import Dataset
from unittest import mock

import bigquery_components

extract_bq_to_dataset = bigquery_components.extract_bq_to_dataset.python_func


@mock.patch("google.cloud.bigquery.client.Client")
@mock.patch("google.cloud.bigquery.table.Table")
@mock.patch("google.cloud.bigquery.job.ExtractJobConfig")
def test_extract_bq_to_dataset(mock_job_config, mock_table, mock_client, tmpdir):
    """
    Checks that extract_bq_to_dataset triggers exactly one BigQuery extract
    job with the mocked table, destination URI and job config.

    Args:
        mock_job_config: mocked google.cloud.bigquery.job.ExtractJobConfig
        mock_table: mocked google.cloud.bigquery.table.Table
        mock_client: mocked google.cloud.bigquery.client.Client
        tmpdir: built-in pytest tmpdir fixture
    """
    mock_path = tmpdir
    # extract_table is invoked on the Client *instance*, so its return value
    # must be configured on mock_client.return_value — the original set it on
    # the class mock (mock_client.extract_table), which the component never
    # touches, making that setup line ineffective.
    mock_client.return_value.extract_table.return_value = "my-job"
    mock_table.return_value.table_ref = "my-table"
    mock_job_config.return_value = "mock-job-config"

    extract_bq_to_dataset(
        bq_client_project_id="my-project-id",
        source_project_id="source-project-id",
        dataset_id="dataset-id",
        table_name="table-name",
        dataset=Dataset(uri=mock_path),
        destination_gcs_uri="gs://mock_bucket",
        dataset_location="EU",
        extract_job_config=None,
        skip_if_exists=False,
    )

    # The component must extract the mocked table to the GCS destination
    # using the mocked job config, exactly once.
    mock_client.return_value.extract_table.assert_called_once_with(
        mock_table.return_value, "gs://mock_bucket", job_config="mock-job-config"
    )


@mock.patch("google.cloud.bigquery.client.Client")
@mock.patch("google.cloud.bigquery.table.Table")
@mock.patch("google.cloud.bigquery.job.ExtractJobConfig")
@mock.patch("pathlib.Path.exists")
def test_extract_bq_to_dataset_skip_existing(
    mock_path_exists, mock_job_config, mock_table, mock_client, tmpdir
):
    """
    Checks that no extract job is launched when the destination already
    exists and skip_if_exists is enabled.

    Args:
        mock_path_exists: mocked pathlib.Path.exists
        mock_job_config: mocked google.cloud.bigquery.job.ExtractJobConfig
        mock_table: mocked google.cloud.bigquery.table.Table
        mock_client: mocked google.cloud.bigquery.client.Client
        tmpdir: built-in pytest tmpdir fixture
    """
    # Pretend the output dataset is already present so the component
    # short-circuits before reaching BigQuery.
    mock_path_exists.return_value = True

    extract_bq_to_dataset(
        bq_client_project_id="my-project-id",
        source_project_id="source-project-id",
        dataset_id="dataset-id",
        table_name="table-name",
        dataset=Dataset(uri=tmpdir),
        destination_gcs_uri="gs://mock_bucket",
        dataset_location="EU",
        extract_job_config=None,
        skip_if_exists=True,
    )

    # The extract must never have been attempted on the client instance.
    mock_client.return_value.extract_table.assert_not_called()
93 changes: 93 additions & 0 deletions components/vertex-components/tests/test_custom_training_job.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
import google.cloud.aiplatform as aip # noqa
from kfp.v2.dsl import Dataset, Metrics, Artifact
from unittest import mock
import pytest


import vertex_components

custom_train_job = vertex_components.custom_train_job.python_func


@mock.patch("google.cloud.aiplatform.CustomTrainingJob")
@mock.patch("os.path.exists")
@mock.patch("builtins.open", new_callable=mock.mock_open, read_data="{}")
def test_custom_train_job(mock_open, mock_exists, mock_job, tmpdir):
    """
    Checks that CustomTrainingJob is constructed with the expected arguments
    and that the metrics file is opened for reading afterwards.

    Args:
        mock_open: mocked builtins.open returning an empty JSON object
        mock_exists: mocked os.path.exists
        mock_job: mocked google.cloud.aiplatform.CustomTrainingJob
        tmpdir: built-in pytest tmpdir fixture
    """
    # Report the training script as present on GCS.
    mock_exists.return_value = True

    # Input/output artifacts all share the pytest temp dir as their URI.
    train_ds, valid_ds, test_ds = (Dataset(uri=tmpdir) for _ in range(3))
    model_artifact = Artifact(uri=tmpdir, metadata={"resourceName": ""})
    metrics_artifact = Metrics(uri=tmpdir)

    custom_train_job(
        train_script_uri="gs://my-bucket/train_script.py",
        train_data=train_ds,
        valid_data=valid_ds,
        test_data=test_ds,
        project_id="my-project-id",
        project_location="europe-west4",
        model_display_name="my-model",
        train_container_uri="gcr.io/my-project/my-image:latest",
        serving_container_uri="gcr.io/my-project/my-serving-image:latest",
        model=model_artifact,
        metrics=metrics_artifact,
        staging_bucket="gs://my-bucket",
        job_name="my-job",
    )

    # The gs:// script URI should be remapped to the /gcs/ fuse mount path.
    mock_job.assert_called_once_with(
        project="my-project-id",
        location="europe-west4",
        staging_bucket="gs://my-bucket",
        display_name="my-job",
        script_path="/gcs/my-bucket/train_script.py",
        container_uri="gcr.io/my-project/my-image:latest",
        requirements=None,
        model_serving_container_image_uri="gcr.io/my-project/my-serving-image:latest",  # noqa: E501
    )

    # The metrics JSON is read back once from the artifact URI.
    mock_open.assert_called_once_with(tmpdir, "r")


@mock.patch("google.cloud.aiplatform.CustomTrainingJob")
@mock.patch("os.path.exists")
@mock.patch("builtins.open", new_callable=mock.mock_open, read_data="{}")
def test_custom_train_script_not_found(mock_open, mock_exists, mock_job, tmpdir):
    """
    Checks that a missing training script makes custom_train_job raise a
    ValueError before any job is created or any file is opened.

    Args:
        mock_open: mocked builtins.open returning an empty JSON object
        mock_exists: mocked os.path.exists
        mock_job: mocked google.cloud.aiplatform.CustomTrainingJob
        tmpdir: built-in pytest tmpdir fixture
    """
    # Report the training script as absent.
    mock_exists.return_value = False

    train_ds, valid_ds, test_ds = (Dataset(uri=tmpdir) for _ in range(3))
    model_artifact = Artifact(uri=tmpdir, metadata={"resourceName": ""})
    metrics_artifact = Metrics(uri=tmpdir)

    with pytest.raises(ValueError):
        custom_train_job(
            train_script_uri="gs://my-bucket/train_script.py",
            train_data=train_ds,
            valid_data=valid_ds,
            test_data=test_ds,
            project_id="my-project-id",
            project_location="europe-west4",
            model_display_name="my-model",
            train_container_uri="gcr.io/my-project/my-image:latest",
            serving_container_uri="gcr.io/my-project/my-serving-image:latest",
            model=model_artifact,
            metrics=metrics_artifact,
            staging_bucket="gs://my-bucket",
            job_name="my-job",
        )

    # Neither the training job nor the metrics file may have been touched.
    mock_job.assert_not_called()
    mock_open.assert_not_called()
57 changes: 57 additions & 0 deletions components/vertex-components/tests/test_import_model_evaluation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
from unittest import mock
from kfp.v2.dsl import Model, Metrics, Dataset

import vertex_components
from google.cloud.aiplatform_v1 import ModelEvaluation


import_model_evaluation = vertex_components.import_model_evaluation.python_func


@mock.patch("google.cloud.aiplatform_v1.ModelServiceClient")
@mock.patch(
    "builtins.open",
    new_callable=mock.mock_open,
    read_data='{"accuracy": 0.85, "problemType": "classification"}',
)
@mock.patch("google.protobuf.json_format.ParseDict")
def test_import_model_evaluation(
    mock_parse_dict, mock_open, mock_service_client, tmpdir
):
    """
    Checks that the metrics file is read, parsed into a ModelEvaluation proto
    and imported against the model's resource name.

    Args:
        mock_parse_dict: mocked google.protobuf.json_format.ParseDict
        mock_open: mocked builtins.open returning sample evaluation JSON
        mock_service_client: mocked google.cloud.aiplatform_v1.ModelServiceClient
        tmpdir: built-in pytest tmpdir fixture
    """
    model_artifact = Model(uri=tmpdir, metadata={"resourceName": ""})
    metrics_artifact = Metrics(uri=tmpdir)
    dataset_artifact = Dataset(uri=tmpdir)

    # The ModelServiceClient instance answers the import call with a
    # ModelEvaluation carrying a known name.
    service_client = mock_service_client.return_value
    service_client.import_model_evaluation.return_value = ModelEvaluation(
        name="model_evaluation_name"
    )

    # ParseDict yields the proto that the component should forward verbatim.
    mock_parse_dict.return_value = mock.MagicMock(spec=ModelEvaluation)

    result = import_model_evaluation(
        model=model_artifact,
        metrics=metrics_artifact,
        test_dataset=dataset_artifact,
        pipeline_job_id="1234",
        project_location="my-location",
        evaluation_name="Imported evaluation",
    )

    # The parsed evaluation is imported under the model's resource name.
    service_client.import_model_evaluation.assert_called_once_with(
        parent=model_artifact.metadata["resourceName"],
        model_evaluation=mock_parse_dict.return_value,
    )

    # The metrics file is opened via the metrics artifact URI.
    mock_open.assert_called_once_with(metrics_artifact.uri)

    # The component returns the evaluation name as its first output.
    assert result[0] == "model_evaluation_name"
126 changes: 54 additions & 72 deletions components/vertex-components/tests/test_lookup_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,93 +22,75 @@
lookup_model = vertex_components.lookup_model.python_func


def test_lookup_model(tmpdir):
@mock.patch("google.cloud.aiplatform.Model")
def test_lookup_model(mock_model, tmpdir):
"""
Assert lookup_model produces expected resource name, and that list method is
called with the correct arguments

Args:
tmpdir: built-in pytest tmpdir fixture

Returns:
None
"""
with mock.patch("google.cloud.aiplatform.Model") as mock_model:

# Mock attribute and method

mock_path = tmpdir
mock_model.resource_name = "my-model-resource-name"
mock_model.uri = mock_path
mock_model.list.return_value = [mock_model]

# Invoke the model look up
found_model_resource_name, _ = lookup_model(
model_name="my-model",
project_location="europe-west4",
project_id="my-project-id",
order_models_by="create_time desc",
fail_on_model_not_found=False,
model=Model(uri=mock_path),
)

assert found_model_resource_name == "my-model-resource-name"

# Check the list method was called once with the correct arguments
mock_model.list.assert_called_once_with(
filter='display_name="my-model"',
order_by="create_time desc",
location="europe-west4",
project="my-project-id",
)


def test_lookup_model_when_no_models(tmpdir):
# Mock attribute and method
mock_path = tmpdir
mock_model.resource_name = "my-model-resource-name"
mock_model.uri = mock_path
mock_model.list.return_value = [mock_model]

# Invoke the model look up
found_model_resource_name, _ = lookup_model(
model_name="my-model",
project_location="europe-west4",
project_id="my-project-id",
order_models_by="create_time desc",
fail_on_model_not_found=False,
model=Model(uri=mock_path),
)

assert found_model_resource_name == "my-model-resource-name"

# Check the list method was called once with the correct arguments
mock_model.list.assert_called_once_with(
filter='display_name="my-model"',
order_by="create_time desc",
location="europe-west4",
project="my-project-id",
)


@mock.patch("google.cloud.aiplatform.Model")
def test_lookup_model_when_no_models(mock_model, tmpdir):
"""
Checks that when there are no models and fail_on_model_found = False,
lookup_model returns an empty string.

Args:
tmpdir: built-in pytest tmpdir fixture

Returns:
None
"""
with mock.patch("google.cloud.aiplatform.Model") as mock_model:
mock_model.list.return_value = []
exported_model_resource_name, _ = lookup_model(
model_name="my-model",
project_location="europe-west4",
project_id="my-project-id",
order_models_by="create_time desc",
fail_on_model_not_found=False,
model=Model(uri=str(tmpdir)),
)
mock_model.list.return_value = []
exported_model_resource_name, _ = lookup_model(
model_name="my-model",
project_location="europe-west4",
project_id="my-project-id",
order_models_by="create_time desc",
fail_on_model_not_found=False,
model=Model(uri=str(tmpdir)),
)

print(exported_model_resource_name)
assert exported_model_resource_name == ""


def test_lookup_model_when_no_models_fail(tmpdir):
@mock.patch("google.cloud.aiplatform.Model")
def test_lookup_model_when_no_models_fail(mock_model, tmpdir):
"""
Checks that when there are no models and fail_on_model_found = True,
lookup_model raises a RuntimeError.

Args:
tmpdir: built-in pytest tmpdir fixture

Returns:
None
"""
with mock.patch("google.cloud.aiplatform.Model") as mock_model:
mock_model.list.return_value = []
mock_model.list.return_value = []

# Verify that a ValueError is raised
with pytest.raises(RuntimeError):
lookup_model(
model_name="my-model",
project_location="europe-west4",
project_id="my-project-id",
order_models_by="create_time desc",
fail_on_model_not_found=True,
model=Model(uri=str(tmpdir)),
)
# Verify that a ValueError is raised
with pytest.raises(RuntimeError):
lookup_model(
model_name="my-model",
project_location="europe-west4",
project_id="my-project-id",
order_models_by="create_time desc",
fail_on_model_not_found=True,
model=Model(uri=str(tmpdir)),
)
Loading