Skip to content

Commit

Permalink
Merge pull request #15 from teamdatatonic/feature/add-tests
Browse files Browse the repository at this point in the history
feature: add tests
  • Loading branch information
ariadnafer authored May 26, 2023
2 parents fe3d9d9 + 6d9c43f commit 3238801
Show file tree
Hide file tree
Showing 10 changed files with 458 additions and 187 deletions.
12 changes: 12 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,18 @@ test-all-components: ## Run unit tests for all pipeline components
$(MAKE) test-components GROUP=$$(basename $$component_group) ; \
done

test-components-coverage: ## Run unit tests with coverage for one component group. Must specify GROUP=<component group>
	@cd "components/${GROUP}" && \
	pipenv run coverage run -m pytest && \
	pipenv run coverage report -m

test-all-components-coverage: ## Run unit tests with coverage for all component groups
	@set -e && \
	for component_group in components/*/ ; do \
		echo "Test components under $$component_group" && \
		$(MAKE) test-components-coverage GROUP=$$(basename $$component_group) ; \
	done

sync-assets: ## Sync assets folder to GCS. Must specify pipeline=<training|prediction>
@if [ -d "./pipelines/src/pipelines/${PIPELINE_TEMPLATE}/$(pipeline)/assets/" ] ; then \
echo "Syncing assets to GCS" && \
Expand Down
64 changes: 64 additions & 0 deletions components/bigquery-components/tests/test_extract_bq_to_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import google.cloud.bigquery # noqa
from kfp.v2.dsl import Dataset
from unittest import mock

import bigquery_components

# Unwrap the KFP component to its underlying Python function so the logic
# can be unit-tested directly, without a pipeline runtime.
extract_bq_to_dataset = bigquery_components.extract_bq_to_dataset.python_func


@mock.patch("google.cloud.bigquery.client.Client")
@mock.patch("google.cloud.bigquery.table.Table")
@mock.patch("google.cloud.bigquery.job.ExtractJobConfig")
def test_extract_bq_to_dataset(mock_job_config, mock_table, mock_client, tmpdir):
    """
    Checks that extract_bq_to_dataset starts a BigQuery extract job with the
    expected table, destination URI and job config.

    Args:
        mock_job_config: mocked ExtractJobConfig class
        mock_table: mocked Table class
        mock_client: mocked bigquery Client class
        tmpdir: built-in pytest tmpdir fixture
    """
    mock_path = tmpdir
    # The component calls Client(...).extract_table(...), i.e. a method on
    # the *instance*, so the stubbed return value must be configured on
    # mock_client.return_value (previously it was set on the class mock
    # itself and was never used by the code under test).
    mock_client.return_value.extract_table.return_value = "my-job"
    mock_table.return_value.table_ref = "my-table"
    mock_job_config.return_value = "mock-job-config"

    extract_bq_to_dataset(
        bq_client_project_id="my-project-id",
        source_project_id="source-project-id",
        dataset_id="dataset-id",
        table_name="table-name",
        dataset=Dataset(uri=mock_path),
        destination_gcs_uri="gs://mock_bucket",
        dataset_location="EU",
        extract_job_config=None,
        skip_if_exists=False,
    )

    # Exactly one extract job, on the client instance, with the mocked
    # table, the GCS destination and the mocked job config.
    mock_client.return_value.extract_table.assert_called_once_with(
        mock_table.return_value, "gs://mock_bucket", job_config="mock-job-config"
    )


@mock.patch("google.cloud.bigquery.client.Client")
@mock.patch("google.cloud.bigquery.table.Table")
@mock.patch("google.cloud.bigquery.job.ExtractJobConfig")
@mock.patch("pathlib.Path.exists")
def test_extract_bq_to_dataset_skip_existing(
    mock_path_exists, mock_job_config, mock_table, mock_client, tmpdir
):
    """
    When skip_if_exists is set and the target dataset already exists,
    no BigQuery extract job must be started.

    Args:
        mock_path_exists: mocked pathlib.Path.exists
        mock_job_config: mocked ExtractJobConfig class
        mock_table: mocked Table class
        mock_client: mocked bigquery Client class
        tmpdir: built-in pytest tmpdir fixture
    """
    # Every Path.exists() check reports that the dataset is already present.
    mock_path_exists.return_value = True

    extract_bq_to_dataset(
        bq_client_project_id="my-project-id",
        source_project_id="source-project-id",
        dataset_id="dataset-id",
        table_name="table-name",
        dataset=Dataset(uri=tmpdir),
        destination_gcs_uri="gs://mock_bucket",
        dataset_location="EU",
        extract_job_config=None,
        skip_if_exists=True,
    )

    # No extract job should have been started on the client instance.
    mock_client.return_value.extract_table.assert_not_called()
93 changes: 93 additions & 0 deletions components/vertex-components/tests/test_custom_training_job.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
import google.cloud.aiplatform as aip # noqa
from kfp.v2.dsl import Dataset, Metrics, Artifact
from unittest import mock
import pytest


import vertex_components

# Unwrap the KFP component to its underlying Python function so the logic
# can be unit-tested directly, without a pipeline runtime.
custom_train_job = vertex_components.custom_train_job.python_func


@mock.patch("google.cloud.aiplatform.CustomTrainingJob")
@mock.patch("os.path.exists")
@mock.patch("builtins.open", new_callable=mock.mock_open, read_data="{}")
def test_custom_train_job(mock_open, mock_exists, mock_job, tmpdir):
    """
    The component should construct a Vertex AI CustomTrainingJob with the
    expected arguments and read the metrics file afterwards.

    Args:
        mock_open: mocked builtins.open (returns empty JSON)
        mock_exists: mocked os.path.exists
        mock_job: mocked CustomTrainingJob class
        tmpdir: built-in pytest tmpdir fixture
    """
    # Pretend the training script is present.
    mock_exists.return_value = True

    # Minimal KFP artifacts, all backed by the pytest tmpdir.
    train_ds, valid_ds, test_ds = (Dataset(uri=tmpdir) for _ in range(3))
    model_artifact = Artifact(uri=tmpdir, metadata={"resourceName": ""})
    metrics_artifact = Metrics(uri=tmpdir)

    custom_train_job(
        train_script_uri="gs://my-bucket/train_script.py",
        train_data=train_ds,
        valid_data=valid_ds,
        test_data=test_ds,
        project_id="my-project-id",
        project_location="europe-west4",
        model_display_name="my-model",
        train_container_uri="gcr.io/my-project/my-image:latest",
        serving_container_uri="gcr.io/my-project/my-serving-image:latest",
        model=model_artifact,
        metrics=metrics_artifact,
        staging_bucket="gs://my-bucket",
        job_name="my-job",
    )

    # The job must be created exactly once, with the gs:// script URI
    # translated to its /gcs/ fuse path.
    mock_job.assert_called_once_with(
        project="my-project-id",
        location="europe-west4",
        staging_bucket="gs://my-bucket",
        display_name="my-job",
        script_path="/gcs/my-bucket/train_script.py",
        container_uri="gcr.io/my-project/my-image:latest",
        requirements=None,
        model_serving_container_image_uri="gcr.io/my-project/my-serving-image:latest",  # noqa: E501
    )

    # Metrics are loaded back exactly once.
    mock_open.assert_called_once_with(tmpdir, "r")


@mock.patch("google.cloud.aiplatform.CustomTrainingJob")
@mock.patch("os.path.exists")
@mock.patch("builtins.open", new_callable=mock.mock_open, read_data="{}")
def test_custom_train_script_not_found(mock_open, mock_exists, mock_job, tmpdir):
    """
    A missing training script must raise ValueError and abort before any
    Vertex AI job is created or any metrics file is opened.

    Args:
        mock_open: mocked builtins.open (returns empty JSON)
        mock_exists: mocked os.path.exists
        mock_job: mocked CustomTrainingJob class
        tmpdir: built-in pytest tmpdir fixture
    """
    # Simulate the training script being absent.
    mock_exists.return_value = False

    train_ds, valid_ds, test_ds = (Dataset(uri=tmpdir) for _ in range(3))
    model_artifact = Artifact(uri=tmpdir, metadata={"resourceName": ""})
    metrics_artifact = Metrics(uri=tmpdir)

    with pytest.raises(ValueError):
        custom_train_job(
            train_script_uri="gs://my-bucket/train_script.py",
            train_data=train_ds,
            valid_data=valid_ds,
            test_data=test_ds,
            project_id="my-project-id",
            project_location="europe-west4",
            model_display_name="my-model",
            train_container_uri="gcr.io/my-project/my-image:latest",
            serving_container_uri="gcr.io/my-project/my-serving-image:latest",
            model=model_artifact,
            metrics=metrics_artifact,
            staging_bucket="gs://my-bucket",
            job_name="my-job",
        )

    # Neither job construction nor the metrics read should have happened.
    mock_job.assert_not_called()
    mock_open.assert_not_called()
57 changes: 57 additions & 0 deletions components/vertex-components/tests/test_import_model_evaluation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
from unittest import mock
from kfp.v2.dsl import Model, Metrics, Dataset

import vertex_components
from google.cloud.aiplatform_v1 import ModelEvaluation


# Unwrap the KFP component to its underlying Python function so the logic
# can be unit-tested directly, without a pipeline runtime.
import_model_evaluation = vertex_components.import_model_evaluation.python_func


@mock.patch("google.cloud.aiplatform_v1.ModelServiceClient")
@mock.patch(
    "builtins.open",
    new_callable=mock.mock_open,
    read_data='{"accuracy": 0.85, "problemType": "classification"}',
)
@mock.patch("google.protobuf.json_format.ParseDict")
def test_import_model_evaluation(
    mock_parse_dict, mock_open, mock_service_client, tmpdir
):
    """
    The component should read the metrics file, import it as a model
    evaluation under the model resource, and return the new evaluation's
    resource name.

    Args:
        mock_parse_dict: mocked google.protobuf.json_format.ParseDict
        mock_open: mocked builtins.open (returns canned metrics JSON)
        mock_service_client: mocked ModelServiceClient class
        tmpdir: built-in pytest tmpdir fixture
    """
    model = Model(uri=tmpdir, metadata={"resourceName": ""})
    metrics = Metrics(uri=tmpdir)
    dataset = Dataset(uri=tmpdir)

    # The mocked ModelServiceClient instance answers the import call with a
    # ModelEvaluation carrying a known resource name.
    client_instance = mock.MagicMock()
    mock_service_client.return_value = client_instance
    client_instance.import_model_evaluation.return_value = ModelEvaluation(
        name="model_evaluation_name"
    )

    # ParseDict yields a mock ModelEvaluation built from the metrics JSON.
    mock_parse_dict.return_value = mock.MagicMock(spec=ModelEvaluation)

    model_evaluation_name = import_model_evaluation(
        model=model,
        metrics=metrics,
        test_dataset=dataset,
        pipeline_job_id="1234",
        project_location="my-location",
        evaluation_name="Imported evaluation",
    )

    # The parsed evaluation is imported once, parented on the model resource.
    client_instance.import_model_evaluation.assert_called_once_with(
        parent=model.metadata["resourceName"],
        model_evaluation=mock_parse_dict.return_value,
    )

    # The metrics file is opened via its artifact URI.
    mock_open.assert_called_once_with(metrics.uri)

    assert model_evaluation_name[0] == "model_evaluation_name"
126 changes: 54 additions & 72 deletions components/vertex-components/tests/test_lookup_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,93 +22,75 @@
lookup_model = vertex_components.lookup_model.python_func


def test_lookup_model(tmpdir):
@mock.patch("google.cloud.aiplatform.Model")
def test_lookup_model(mock_model, tmpdir):
"""
Assert lookup_model produces expected resource name, and that list method is
called with the correct arguments
Args:
tmpdir: built-in pytest tmpdir fixture
Returns:
None
"""
with mock.patch("google.cloud.aiplatform.Model") as mock_model:

# Mock attribute and method

mock_path = tmpdir
mock_model.resource_name = "my-model-resource-name"
mock_model.uri = mock_path
mock_model.list.return_value = [mock_model]

# Invoke the model look up
found_model_resource_name, _ = lookup_model(
model_name="my-model",
project_location="europe-west4",
project_id="my-project-id",
order_models_by="create_time desc",
fail_on_model_not_found=False,
model=Model(uri=mock_path),
)

assert found_model_resource_name == "my-model-resource-name"

# Check the list method was called once with the correct arguments
mock_model.list.assert_called_once_with(
filter='display_name="my-model"',
order_by="create_time desc",
location="europe-west4",
project="my-project-id",
)


def test_lookup_model_when_no_models(tmpdir):
# Mock attribute and method
mock_path = tmpdir
mock_model.resource_name = "my-model-resource-name"
mock_model.uri = mock_path
mock_model.list.return_value = [mock_model]

# Invoke the model look up
found_model_resource_name, _ = lookup_model(
model_name="my-model",
project_location="europe-west4",
project_id="my-project-id",
order_models_by="create_time desc",
fail_on_model_not_found=False,
model=Model(uri=mock_path),
)

assert found_model_resource_name == "my-model-resource-name"

# Check the list method was called once with the correct arguments
mock_model.list.assert_called_once_with(
filter='display_name="my-model"',
order_by="create_time desc",
location="europe-west4",
project="my-project-id",
)


@mock.patch("google.cloud.aiplatform.Model")
def test_lookup_model_when_no_models(mock_model, tmpdir):
"""
Checks that when there are no models and fail_on_model_found = False,
lookup_model returns an empty string.
Args:
tmpdir: built-in pytest tmpdir fixture
Returns:
None
"""
with mock.patch("google.cloud.aiplatform.Model") as mock_model:
mock_model.list.return_value = []
exported_model_resource_name, _ = lookup_model(
model_name="my-model",
project_location="europe-west4",
project_id="my-project-id",
order_models_by="create_time desc",
fail_on_model_not_found=False,
model=Model(uri=str(tmpdir)),
)
mock_model.list.return_value = []
exported_model_resource_name, _ = lookup_model(
model_name="my-model",
project_location="europe-west4",
project_id="my-project-id",
order_models_by="create_time desc",
fail_on_model_not_found=False,
model=Model(uri=str(tmpdir)),
)

print(exported_model_resource_name)
assert exported_model_resource_name == ""


def test_lookup_model_when_no_models_fail(tmpdir):
@mock.patch("google.cloud.aiplatform.Model")
def test_lookup_model_when_no_models_fail(mock_model, tmpdir):
"""
Checks that when there are no models and fail_on_model_found = True,
lookup_model raises a RuntimeError.
Args:
tmpdir: built-in pytest tmpdir fixture
Returns:
None
"""
with mock.patch("google.cloud.aiplatform.Model") as mock_model:
mock_model.list.return_value = []
mock_model.list.return_value = []

# Verify that a RuntimeError is raised
with pytest.raises(RuntimeError):
lookup_model(
model_name="my-model",
project_location="europe-west4",
project_id="my-project-id",
order_models_by="create_time desc",
fail_on_model_not_found=True,
model=Model(uri=str(tmpdir)),
)
# Verify that a RuntimeError is raised
with pytest.raises(RuntimeError):
lookup_model(
model_name="my-model",
project_location="europe-west4",
project_id="my-project-id",
order_models_by="create_time desc",
fail_on_model_not_found=True,
model=Model(uri=str(tmpdir)),
)
Loading

0 comments on commit 3238801

Please sign in to comment.