Skip to content

Commit

Permalink
feat: Added forecasting snippets and fixed bugs with existing snippets (#1210)
Browse files Browse the repository at this point in the history

* Added dataset snippets

* Fixed typehint and missing parameter bugs as well as added new samples

* Fixed lint issues

* Added bq batch_prediction bq snippets

* Removed unneeded fixture

* Renamed bq_source to bigquery_source

* Added back explain_tabular_sample.py for now

* Fixed tests

* Fixed lint issues
  • Loading branch information
ivanmkc authored May 31, 2022
1 parent 0036ab0 commit 4e4bff5
Show file tree
Hide file tree
Showing 24 changed files with 727 additions and 31 deletions.
41 changes: 41 additions & 0 deletions samples/model-builder/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,12 @@ def mock_tabular_dataset():
yield mock


@pytest.fixture
def mock_time_series_dataset():
    """Yield a ``MagicMock`` specced on ``aiplatform.datasets.TimeSeriesDataset``."""
    time_series_dataset_double = MagicMock(
        spec=aiplatform.datasets.TimeSeriesDataset
    )
    yield time_series_dataset_double


@pytest.fixture
def mock_text_dataset():
mock = MagicMock(aiplatform.datasets.TextDataset)
Expand Down Expand Up @@ -74,6 +80,13 @@ def mock_get_tabular_dataset(mock_tabular_dataset):
yield mock_get_tabular_dataset


@pytest.fixture
def mock_get_time_series_dataset(mock_time_series_dataset):
    """Patch ``aiplatform.TimeSeriesDataset`` for the duration of a test.

    Calling the patched attribute returns ``mock_time_series_dataset``; the
    patched attribute itself is yielded so tests can assert on how it was
    called.
    """
    with patch.object(aiplatform, "TimeSeriesDataset") as patched_dataset_cls:
        patched_dataset_cls.return_value = mock_time_series_dataset
        yield patched_dataset_cls


@pytest.fixture
def mock_get_text_dataset(mock_text_dataset):
with patch.object(aiplatform, "TextDataset") as mock_get_text_dataset:
Expand Down Expand Up @@ -107,6 +120,15 @@ def mock_create_tabular_dataset(mock_tabular_dataset):
yield mock_create_tabular_dataset


@pytest.fixture
def mock_create_time_series_dataset(mock_time_series_dataset):
    """Patch ``aiplatform.TimeSeriesDataset.create`` for the duration of a test.

    The patched ``create`` returns ``mock_time_series_dataset``; the patch
    object is yielded so tests can assert on the arguments it received.
    """
    target_cls = aiplatform.TimeSeriesDataset
    with patch.object(target_cls, "create") as patched_create:
        patched_create.return_value = mock_time_series_dataset
        yield patched_create


@pytest.fixture
def mock_create_text_dataset(mock_text_dataset):
with patch.object(aiplatform.TextDataset, "create") as mock_create_text_dataset:
Expand Down Expand Up @@ -183,6 +205,12 @@ def mock_tabular_training_job():
yield mock


@pytest.fixture
def mock_forecasting_training_job():
    """Yield a ``MagicMock`` specced on ``AutoMLForecastingTrainingJob``."""
    forecasting_job_double = MagicMock(
        spec=aiplatform.training_jobs.AutoMLForecastingTrainingJob
    )
    yield forecasting_job_double


@pytest.fixture
def mock_text_training_job():
mock = MagicMock(aiplatform.training_jobs.AutoMLTextTrainingJob)
Expand All @@ -208,6 +236,19 @@ def mock_run_automl_tabular_training_job(mock_tabular_training_job):
yield mock


@pytest.fixture
def mock_get_automl_forecasting_training_job(mock_forecasting_training_job):
    """Patch ``aiplatform.AutoMLForecastingTrainingJob`` for the duration of a test.

    Calling the patched attribute returns ``mock_forecasting_training_job``;
    the patch object is yielded so tests can assert on the call.
    """
    with patch.object(aiplatform, "AutoMLForecastingTrainingJob") as patched_job_cls:
        patched_job_cls.return_value = mock_forecasting_training_job
        yield patched_job_cls


@pytest.fixture
def mock_run_automl_forecasting_training_job(mock_forecasting_training_job):
    """Patch ``run`` on the forecasting training job mock and yield the patch."""
    with patch.object(mock_forecasting_training_job, "run") as patched_run:
        yield patched_run


@pytest.fixture
def mock_get_automl_image_training_job(mock_image_training_job):
with patch.object(aiplatform, "AutoMLImageTrainingJob") as mock:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,17 @@

# [START aiplatform_sdk_create_and_import_dataset_tabular_bigquery_sample]
def create_and_import_dataset_tabular_bigquery_sample(
display_name: str, project: str, location: str, bq_source: str,
display_name: str,
project: str,
location: str,
bigquery_source: str,
):

aiplatform.init(project=project, location=location)

dataset = aiplatform.TabularDataset.create(
display_name=display_name, bq_source=bq_source,
display_name=display_name,
bigquery_source=bigquery_source,
)

dataset.wait()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,14 @@ def test_create_and_import_dataset_tabular_bigquery_sample(
create_and_import_dataset_tabular_bigquery_sample.create_and_import_dataset_tabular_bigquery_sample(
project=constants.PROJECT,
location=constants.LOCATION,
bq_source=constants.BIGQUERY_SOURCE,
bigquery_source=constants.BIGQUERY_SOURCE,
display_name=constants.DISPLAY_NAME,
)

mock_sdk_init.assert_called_once_with(
project=constants.PROJECT, location=constants.LOCATION
)
mock_create_tabular_dataset.assert_called_once_with(
display_name=constants.DISPLAY_NAME, bq_source=constants.BIGQUERY_SOURCE,
display_name=constants.DISPLAY_NAME,
bigquery_source=constants.BIGQUERY_SOURCE,
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from google.cloud import aiplatform


# [START aiplatform_sdk_create_and_import_dataset_time_series_bigquery_sample]
def create_and_import_dataset_time_series_bigquery_sample(
    display_name: str,
    project: str,
    location: str,
    bigquery_source: str,
):
    """Create a time series dataset from a BigQuery source and print its identifiers.

    Args:
        display_name: Display name forwarded to ``TimeSeriesDataset.create``.
        project: Project forwarded to ``aiplatform.init``.
        location: Location forwarded to ``aiplatform.init``.
        bigquery_source: BigQuery source forwarded to
            ``TimeSeriesDataset.create``.
    """
    aiplatform.init(location=location, project=project)

    dataset = aiplatform.TimeSeriesDataset.create(
        bigquery_source=bigquery_source,
        display_name=display_name,
    )

    # Wait on the dataset before reading its attributes for the printout.
    dataset.wait()

    print(f'\tDataset: "{dataset.display_name}"')
    print(f'\tname: "{dataset.resource_name}"')


# [END aiplatform_sdk_create_and_import_dataset_time_series_bigquery_sample]
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import create_and_import_dataset_time_series_bigquery_sample
import test_constants as constants


def test_create_and_import_dataset_time_series_bigquery_sample(
    mock_sdk_init, mock_create_time_series_dataset
):
    """Check the BigQuery sample's calls to SDK init and dataset ``create``."""
    sample_module = create_and_import_dataset_time_series_bigquery_sample

    sample_module.create_and_import_dataset_time_series_bigquery_sample(
        display_name=constants.DISPLAY_NAME,
        project=constants.PROJECT,
        location=constants.LOCATION,
        bigquery_source=constants.BIGQUERY_SOURCE,
    )

    # Each SDK entry point must be hit exactly once with the sample's inputs.
    mock_sdk_init.assert_called_once_with(
        project=constants.PROJECT,
        location=constants.LOCATION,
    )
    mock_create_time_series_dataset.assert_called_once_with(
        display_name=constants.DISPLAY_NAME,
        bigquery_source=constants.BIGQUERY_SOURCE,
    )
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Union

from google.cloud import aiplatform


# [START aiplatform_sdk_create_and_import_dataset_time_series_gcs_sample]
def create_and_import_dataset_time_series_gcs_sample(
    display_name: str,
    project: str,
    location: str,
    gcs_source: Union[str, List[str]],
):
    """Create a time series dataset from a GCS source and print its identifiers.

    Args:
        display_name: Display name forwarded to ``TimeSeriesDataset.create``.
        project: Project forwarded to ``aiplatform.init``.
        location: Location forwarded to ``aiplatform.init``.
        gcs_source: A single source string or a list of them, forwarded
            unchanged to ``TimeSeriesDataset.create``.
    """
    aiplatform.init(location=location, project=project)

    dataset = aiplatform.TimeSeriesDataset.create(
        gcs_source=gcs_source,
        display_name=display_name,
    )

    # Wait on the dataset before reading its attributes for the printout.
    dataset.wait()

    print(f'\tDataset: "{dataset.display_name}"')
    print(f'\tname: "{dataset.resource_name}"')


# [END aiplatform_sdk_create_and_import_dataset_time_series_gcs_sample]
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import create_and_import_dataset_time_series_gcs_sample
import test_constants as constants


def test_create_and_import_dataset_time_series_gcs_sample(
    mock_sdk_init, mock_create_time_series_dataset
):
    """Check the GCS sample's calls to SDK init and dataset ``create``."""
    sample_module = create_and_import_dataset_time_series_gcs_sample

    sample_module.create_and_import_dataset_time_series_gcs_sample(
        display_name=constants.DISPLAY_NAME,
        project=constants.PROJECT,
        location=constants.LOCATION,
        gcs_source=constants.GCS_SOURCES,
    )

    # Each SDK entry point must be hit exactly once with the sample's inputs.
    mock_sdk_init.assert_called_once_with(
        project=constants.PROJECT,
        location=constants.LOCATION,
    )
    mock_create_time_series_dataset.assert_called_once_with(
        display_name=constants.DISPLAY_NAME,
        gcs_source=constants.GCS_SOURCES,
    )
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from google.cloud import aiplatform


# [START aiplatform_sdk_create_batch_prediction_job_bigquery_sample]
def create_batch_prediction_job_bigquery_sample(
    project: str,
    location: str,
    model_resource_name: str,
    job_display_name: str,
    bigquery_source: str,
    bigquery_destination_prefix: str,
    sync: bool = True,
):
    """Run a batch prediction job that reads from and writes to BigQuery.

    Args:
        project: Project forwarded to ``aiplatform.init``.
        location: Location forwarded to ``aiplatform.init``.
        model_resource_name: Passed positionally to ``aiplatform.Model``.
        job_display_name: Display name forwarded to ``batch_predict``.
        bigquery_source: BigQuery source forwarded to ``batch_predict``.
        bigquery_destination_prefix: Destination prefix forwarded to
            ``batch_predict``.
        sync: Forwarded to ``batch_predict`` as ``sync``.

    Returns:
        The object returned by ``Model.batch_predict``.
    """
    aiplatform.init(location=location, project=project)

    model = aiplatform.Model(model_resource_name)
    job = model.batch_predict(
        sync=sync,
        job_display_name=job_display_name,
        bigquery_source=bigquery_source,
        bigquery_destination_prefix=bigquery_destination_prefix,
    )

    # wait() is called regardless of the value of ``sync``.
    job.wait()

    for attribute in ("display_name", "resource_name", "state"):
        print(getattr(job, attribute))
    return job


# [END aiplatform_sdk_create_batch_prediction_job_bigquery_sample]
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import create_batch_prediction_job_bigquery_sample
import test_constants as constants


def test_create_batch_prediction_job_bigquery_sample(
    mock_sdk_init, mock_model, mock_init_model, mock_batch_predict_model
):
    """Check the batch prediction sample's calls to init, ``Model`` and ``batch_predict``."""
    sample_module = create_batch_prediction_job_bigquery_sample

    sample_module.create_batch_prediction_job_bigquery_sample(
        job_display_name=constants.DISPLAY_NAME,
        model_resource_name=constants.MODEL_NAME,
        project=constants.PROJECT,
        location=constants.LOCATION,
        bigquery_source=constants.BIGQUERY_SOURCE,
        bigquery_destination_prefix=constants.BIGQUERY_DESTINATION_PREFIX,
    )

    mock_sdk_init.assert_called_once_with(
        project=constants.PROJECT,
        location=constants.LOCATION,
    )
    mock_init_model.assert_called_once_with(constants.MODEL_NAME)
    # ``sync`` was not passed above, so the sample's default (True) is expected.
    mock_batch_predict_model.assert_called_once_with(
        job_display_name=constants.DISPLAY_NAME,
        bigquery_source=constants.BIGQUERY_SOURCE,
        bigquery_destination_prefix=constants.BIGQUERY_DESTINATION_PREFIX,
        sync=True,
    )
Loading

0 comments on commit 4e4bff5

Please sign in to comment.