feat: add Pandas DataFrame support to TabularDataset (#1185)
* add create_from_dataframe method

* add tests for create_from_dataframe

* update docstrings and run linter

* update docstrings and make display_name optional

* updates from Sasha's feedback: added integration test, updated validations

* remove some logging

* update error handling on bq_schema arg

* updates from Sasha's feedback

* update bq_schema docstring
sararob authored May 5, 2022
1 parent 5a2e2de commit 4fe4558
Showing 4 changed files with 439 additions and 21 deletions.
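
For orientation, a minimal usage sketch of the new method introduced by this commit; the project, location, and BigQuery table names are hypothetical placeholders, not values taken from the change itself:

import pandas as pd

from google.cloud import aiplatform

# Hypothetical project, location, and BigQuery staging table.
aiplatform.init(project="my-project", location="us-central1")

df = pd.DataFrame({"feature": [1.0, 2.0, 3.0], "label": ["a", "b", "c"]})

# Stages the DataFrame in the given BigQuery table, then creates a Vertex AI
# TabularDataset that references that table. Small DataFrames still work, but
# AutoML training requires at least 1,000 rows.
dataset = aiplatform.TabularDataset.create_from_dataframe(
    df_source=df,
    staging_path="bq://my-project.my_dataset.my_table",
    display_name="dataframe-backed-dataset",
)
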
114 changes: 113 additions & 1 deletion google/cloud/aiplatform/datasets/tabular_dataset.py
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-

# Copyright 2020 Google LLC
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -19,12 +19,18 @@

from google.auth import credentials as auth_credentials

from google.cloud import bigquery
from google.cloud.aiplatform import base
from google.cloud.aiplatform import datasets
from google.cloud.aiplatform.datasets import _datasources
from google.cloud.aiplatform import initializer
from google.cloud.aiplatform import schema
from google.cloud.aiplatform import utils

_AUTOML_TRAINING_MIN_ROWS = 1000

_LOGGER = base.Logger(__name__)


class TabularDataset(datasets._ColumnNamesDataset):
"""Managed tabular dataset resource for Vertex AI."""
@@ -146,6 +152,112 @@ def create(
create_request_timeout=create_request_timeout,
)

@classmethod
def create_from_dataframe(
cls,
df_source: "pd.DataFrame", # noqa: F821 - skip check for undefined name 'pd'
staging_path: str,
bq_schema: Optional[Union[str, bigquery.SchemaField]] = None,
display_name: Optional[str] = None,
project: Optional[str] = None,
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
) -> "TabularDataset":
"""Creates a new tabular dataset from a Pandas DataFrame.
Args:
df_source (pd.DataFrame):
Required. Pandas DataFrame containing the source data for
ingestion as a TabularDataset. This method will use the data
types from the provided DataFrame when creating the dataset.
staging_path (str):
Required. The BigQuery table to stage the data
for Vertex. Because Vertex maintains a reference to this source
to create the Vertex Dataset, this BigQuery table should
not be deleted. Example: `bq://my-project.my-dataset.my-table`.
If the provided BigQuery table doesn't exist, this method will
create the table. If the provided BigQuery table already exists,
and the schemas of the BigQuery table and your DataFrame match,
this method will append the data in your local DataFrame to the table.
The location of the provided BigQuery table should conform to the location requirements
specified here: https://cloud.google.com/vertex-ai/docs/general/locations#bq-locations.
bq_schema (Optional[Union[str, bigquery.SchemaField]]):
Optional. If not set, BigQuery will autodetect the schema using your DataFrame's column types.
If set, BigQuery will use the schema you provide when creating the staging table. For more details,
see: https://cloud.google.com/python/docs/reference/bigquery/latest/google.cloud.bigquery.job.LoadJobConfig#google_cloud_bigquery_job_LoadJobConfig_schema
display_name (str):
Optional. The user-defined name of the Dataset.
The name can be up to 128 characters long and can consist
of any UTF-8 characters.
project (str):
Optional. Project to upload this dataset to. Overrides project set in
aiplatform.init.
location (str):
Optional. Location to upload this dataset to. Overrides location set in
aiplatform.init.
credentials (auth_credentials.Credentials):
Optional. Custom credentials to use to upload this dataset. Overrides
credentials set in aiplatform.init.
Returns:
tabular_dataset (TabularDataset):
Instantiated representation of the managed tabular dataset resource.
"""

if staging_path.startswith("bq://"):
bq_staging_path = staging_path[len("bq://") :]
else:
raise ValueError(
"Only BigQuery staging paths are supported. Provide a staging path in the format `bq://your-project.your-dataset.your-table`."
)

try:
import pyarrow # noqa: F401 - skip check for 'pyarrow' which is required when using 'google.cloud.bigquery'
except ImportError:
raise ImportError(
"Pyarrow is not installed, and is required to use the BigQuery client."
'Please install the SDK using "pip install google-cloud-aiplatform[datasets]"'
)

if len(df_source) < _AUTOML_TRAINING_MIN_ROWS:
_LOGGER.info(
"Your DataFrame has %s rows and AutoML requires %s rows to train on tabular data. You can still train a custom model once your dataset has been uploaded to Vertex, but you will not be able to use AutoML for training."
% (len(df_source), _AUTOML_TRAINING_MIN_ROWS),
)

bigquery_client = bigquery.Client(
project=project or initializer.global_config.project,
credentials=credentials or initializer.global_config.credentials,
)

try:
parquet_options = bigquery.format_options.ParquetOptions()
parquet_options.enable_list_inference = True

job_config = bigquery.LoadJobConfig(
source_format=bigquery.SourceFormat.PARQUET,
parquet_options=parquet_options,
)

if bq_schema:
job_config.schema = bq_schema

job = bigquery_client.load_table_from_dataframe(
dataframe=df_source, destination=bq_staging_path, job_config=job_config
)

job.result()

finally:
dataset_from_dataframe = cls.create(
display_name=display_name,
bq_source=staging_path,
project=project,
location=location,
credentials=credentials,
)

return dataset_from_dataframe

def import_data(self):
raise NotImplementedError(
f"{self.__class__.__name__} class does not support 'import_data'"
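
For reference, a sketch of the same call with an explicit bq_schema instead of relying on BigQuery's schema autodetection; the column names, schema, and staging table below are hypothetical. Note that pyarrow must be installed (see the new `datasets` extra in setup.py below) because the BigQuery client loads the DataFrame via Parquet:

import pandas as pd

from google.cloud import aiplatform
from google.cloud import bigquery

aiplatform.init(project="my-project", location="us-central1")

df = pd.DataFrame(
    {"age": [42, 35], "income": [55000.0, 72000.0], "label": ["yes", "no"]}
)

# Explicit schema for the staging table; if bq_schema is omitted, BigQuery
# autodetects the schema from the DataFrame's column types.
schema = [
    bigquery.SchemaField(name="age", field_type="INTEGER"),
    bigquery.SchemaField(name="income", field_type="FLOAT"),
    bigquery.SchemaField(name="label", field_type="STRING"),
]

dataset = aiplatform.TabularDataset.create_from_dataframe(
    df_source=df,
    staging_path="bq://my-project.my_dataset.my_staging_table",
    bq_schema=schema,
    display_name="dataframe-dataset-with-schema",
)
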
4 changes: 4 additions & 0 deletions setup.py
@@ -55,6 +55,9 @@
pipelines_extra_requires = [
"pyyaml>=5.3,<6",
]
datasets_extra_require = [
"pyarrow >= 3.0.0, < 8.0dev",
]
full_extra_require = list(
set(
tensorboard_extra_require
@@ -63,6 +66,7 @@
+ lit_extra_require
+ featurestore_extra_require
+ pipelines_extra_requires
+ datasets_extra_require
)
)
testing_extra_require = (
165 changes: 145 additions & 20 deletions tests/system/aiplatform/test_dataset.py
@@ -20,10 +20,14 @@
import pytest
import importlib

import pandas as pd

from google import auth as google_auth
from google.api_core import exceptions
from google.api_core import client_options

from google.cloud import bigquery

from google.cloud import aiplatform
from google.cloud import storage
from google.cloud.aiplatform import utils
@@ -33,6 +37,8 @@

from test_utils.vpcsc_config import vpcsc_config

from tests.system.aiplatform import e2e_base

# TODO(vinnys): Replace with env var `BUILD_SPECIFIC_GCP_PROJECT` once supported
_, _TEST_PROJECT = google_auth.default()
TEST_BUCKET = os.environ.get(
@@ -55,40 +61,91 @@
_TEST_TEXT_ENTITY_IMPORT_SCHEMA = "gs://google-cloud-aiplatform/schema/dataset/ioformat/text_extraction_io_format_1.0.0.yaml"
_TEST_IMAGE_OBJ_DET_IMPORT_SCHEMA = "gs://google-cloud-aiplatform/schema/dataset/ioformat/image_bounding_box_io_format_1.0.0.yaml"

# create_from_dataframe
_TEST_BOOL_COL = "bool_col"
_TEST_BOOL_ARR_COL = "bool_array_col"
_TEST_DOUBLE_COL = "double_col"
_TEST_DOUBLE_ARR_COL = "double_array_col"
_TEST_INT_COL = "int64_col"
_TEST_INT_ARR_COL = "int64_array_col"
_TEST_STR_COL = "string_col"
_TEST_STR_ARR_COL = "string_array_col"
_TEST_BYTES_COL = "bytes_col"
_TEST_DF_COLUMN_NAMES = [
_TEST_BOOL_COL,
_TEST_BOOL_ARR_COL,
_TEST_DOUBLE_COL,
_TEST_DOUBLE_ARR_COL,
_TEST_INT_COL,
_TEST_INT_ARR_COL,
_TEST_STR_COL,
_TEST_STR_ARR_COL,
_TEST_BYTES_COL,
]
_TEST_DATAFRAME = pd.DataFrame(
data=[
[
False,
[True, False],
1.2,
[1.2, 3.4],
1,
[1, 2],
"test",
["test1", "test2"],
b"1",
],
[
True,
[True, True],
2.2,
[2.2, 4.4],
2,
[2, 3],
"test1",
["test2", "test3"],
b"0",
],
],
columns=_TEST_DF_COLUMN_NAMES,
)
_TEST_DATAFRAME_BQ_SCHEMA = [
bigquery.SchemaField(name="bool_col", field_type="BOOL"),
bigquery.SchemaField(name="bool_array_col", field_type="BOOL", mode="REPEATED"),
bigquery.SchemaField(name="double_col", field_type="FLOAT"),
bigquery.SchemaField(name="double_array_col", field_type="FLOAT", mode="REPEATED"),
bigquery.SchemaField(name="int64_col", field_type="INTEGER"),
bigquery.SchemaField(name="int64_array_col", field_type="INTEGER", mode="REPEATED"),
bigquery.SchemaField(name="string_col", field_type="STRING"),
bigquery.SchemaField(name="string_array_col", field_type="STRING", mode="REPEATED"),
bigquery.SchemaField(name="bytes_col", field_type="STRING"),
]


@pytest.mark.usefixtures(
"prepare_staging_bucket",
"delete_staging_bucket",
"prepare_bigquery_dataset",
"delete_bigquery_dataset",
"tear_down_resources",
)
class TestDataset(e2e_base.TestEndToEnd):

_temp_prefix = "temp-vertex-sdk-dataset-test"

class TestDataset:
def setup_method(self):
importlib.reload(initializer)
importlib.reload(aiplatform)

@pytest.fixture()
def shared_state(self):
shared_state = {}
yield shared_state

@pytest.fixture()
def create_staging_bucket(self, shared_state):
new_staging_bucket = f"temp-sdk-integration-{uuid.uuid4()}"

storage_client = storage.Client()
storage_client.create_bucket(new_staging_bucket)
shared_state["storage_client"] = storage_client
shared_state["staging_bucket"] = new_staging_bucket
yield

@pytest.fixture()
def delete_staging_bucket(self, shared_state):
yield
storage_client = shared_state["storage_client"]

# Delete temp staging bucket
bucket_to_delete = storage_client.get_bucket(shared_state["staging_bucket"])
bucket_to_delete.delete(force=True)

# Close Storage Client
storage_client._http._auth_request.session.close()
storage_client._http.close()

@pytest.fixture()
def dataset_gapic_client(self):
gapic_client = dataset_service.DatasetServiceClient(
@@ -253,6 +310,74 @@ def test_create_tabular_dataset(self, dataset_gapic_client, shared_state):
== aiplatform.schema.dataset.metadata.tabular
)

@pytest.mark.usefixtures("delete_new_dataset")
def test_create_tabular_dataset_from_dataframe(
self, dataset_gapic_client, shared_state
):
"""Use the Dataset.create_from_dataframe() method to create a new tabular dataset.
Then confirm the dataset was successfully created and references the BQ source."""

assert shared_state["bigquery_dataset"]

shared_state["resources"] = []

bigquery_dataset_id = shared_state["bigquery_dataset_id"]
bq_staging_table = f"bq://{bigquery_dataset_id}.test_table{uuid.uuid4()}"

aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)

tabular_dataset = aiplatform.TabularDataset.create_from_dataframe(
df_source=_TEST_DATAFRAME,
staging_path=bq_staging_table,
display_name=f"temp_sdk_integration_create_and_import_dataset_from_dataframe{uuid.uuid4()}",
)
shared_state["resources"].extend([tabular_dataset])
shared_state["dataset_name"] = tabular_dataset.resource_name

gapic_metadata = tabular_dataset.to_dict()["metadata"]
bq_source = gapic_metadata["inputConfig"]["bigquerySource"]["uri"]

assert bq_staging_table == bq_source
assert (
tabular_dataset.metadata_schema_uri
== aiplatform.schema.dataset.metadata.tabular
)

@pytest.mark.usefixtures("delete_new_dataset")
def test_create_tabular_dataset_from_dataframe_with_provided_schema(
self, dataset_gapic_client, shared_state
):
"""Use the Dataset.create_from_dataframe() method to create a new tabular dataset,
passing in the optional `bq_schema` argument. Then confirm the dataset was successfully
created and references the BQ source."""

assert shared_state["bigquery_dataset"]

shared_state["resources"] = []

bigquery_dataset_id = shared_state["bigquery_dataset_id"]
bq_staging_table = f"bq://{bigquery_dataset_id}.test_table{uuid.uuid4()}"

aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)

tabular_dataset = aiplatform.TabularDataset.create_from_dataframe(
df_source=_TEST_DATAFRAME,
staging_path=bq_staging_table,
display_name=f"temp_sdk_integration_create_and_import_dataset_from_dataframe{uuid.uuid4()}",
bq_schema=_TEST_DATAFRAME_BQ_SCHEMA,
)
shared_state["resources"].extend([tabular_dataset])
shared_state["dataset_name"] = tabular_dataset.resource_name

gapic_metadata = tabular_dataset.to_dict()["metadata"]
bq_source = gapic_metadata["inputConfig"]["bigquerySource"]["uri"]

assert bq_staging_table == bq_source
assert (
tabular_dataset.metadata_schema_uri
== aiplatform.schema.dataset.metadata.tabular
)

# TODO(vinnys): Remove pytest skip once persistent resources are accessible
@pytest.mark.skip(reason="System tests cannot access persistent test resources")
@pytest.mark.usefixtures("create_staging_bucket", "delete_staging_bucket")