Skip to content

Commit

Permalink
feat: Made display_name parameter optional for most calls (#882)
Browse files Browse the repository at this point in the history
Thank you for opening a Pull Request! Before submitting your PR, there are a few things you can do to make sure it goes smoothly:
- [x] Make sure to open an issue as a [bug/issue](https://github.com/googleapis/python-aiplatform/issues/new/choose) before writing your code!  That way we can discuss the change, evaluate designs, and agree on the general idea
- [x] Ensure the tests and linter pass
- [x] Code coverage does not decrease (if any source code was changed)
- [x] Appropriate docs were updated (if necessary)

Fixes #853 🦕
  • Loading branch information
Ark-kun authored Apr 5, 2022
1 parent 647d31f commit 400b760
Show file tree
Hide file tree
Showing 13 changed files with 123 additions and 33 deletions.
7 changes: 7 additions & 0 deletions google/cloud/aiplatform/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -686,6 +686,13 @@ def to_dict(self) -> Dict[str, Any]:
"""Returns the resource proto as a dictionary."""
return json_format.MessageToDict(self.gca_resource._pb)

@classmethod
def _generate_display_name(cls, prefix: Optional[str] = None) -> str:
"""Returns a display name containing class name and time string."""
if not prefix:
prefix = cls.__name__
return prefix + " " + datetime.datetime.now().isoformat(sep=" ")


def optional_sync(
construct_object_on_arg: Optional[str] = None,
Expand Down
4 changes: 3 additions & 1 deletion google/cloud/aiplatform/datasets/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ def _validate_metadata_schema_uri(self) -> None:
@classmethod
def create(
cls,
# TODO(b/223262536): Make the display_name parameter optional in the next major release
display_name: str,
metadata_schema_uri: str,
gcs_source: Optional[Union[str, Sequence[str]]] = None,
Expand Down Expand Up @@ -211,7 +212,8 @@ def create(
dataset (Dataset):
Instantiated representation of the managed dataset resource.
"""

if not display_name:
display_name = cls._generate_display_name()
utils.validate_display_name(display_name)
if labels:
utils.validate_labels(labels)
Expand Down
6 changes: 4 additions & 2 deletions google/cloud/aiplatform/datasets/image_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class ImageDataset(datasets._Dataset):
@classmethod
def create(
cls,
display_name: str,
display_name: Optional[str] = None,
gcs_source: Optional[Union[str, Sequence[str]]] = None,
import_schema_uri: Optional[str] = None,
data_item_labels: Optional[Dict] = None,
Expand All @@ -54,7 +54,7 @@ def create(
Args:
display_name (str):
Required. The user-defined name of the Dataset.
Optional. The user-defined name of the Dataset.
The name can be up to 128 characters long and can be consist
of any UTF-8 characters.
gcs_source (Union[str, Sequence[str]]):
Expand Down Expand Up @@ -129,6 +129,8 @@ def create(
image_dataset (ImageDataset):
Instantiated representation of the managed image dataset resource.
"""
if not display_name:
display_name = cls._generate_display_name()

utils.validate_display_name(display_name)
if labels:
Expand Down
7 changes: 4 additions & 3 deletions google/cloud/aiplatform/datasets/tabular_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class TabularDataset(datasets._ColumnNamesDataset):
@classmethod
def create(
cls,
display_name: str,
display_name: Optional[str] = None,
gcs_source: Optional[Union[str, Sequence[str]]] = None,
bq_source: Optional[str] = None,
project: Optional[str] = None,
Expand All @@ -52,7 +52,7 @@ def create(
Args:
display_name (str):
Required. The user-defined name of the Dataset.
Optional. The user-defined name of the Dataset.
The name can be up to 128 characters long and can be consist
of any UTF-8 characters.
gcs_source (Union[str, Sequence[str]]):
Expand Down Expand Up @@ -110,7 +110,8 @@ def create(
tabular_dataset (TabularDataset):
Instantiated representation of the managed tabular dataset resource.
"""

if not display_name:
display_name = cls._generate_display_name()
utils.validate_display_name(display_name)
if labels:
utils.validate_labels(labels)
Expand Down
7 changes: 4 additions & 3 deletions google/cloud/aiplatform/datasets/text_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class TextDataset(datasets._Dataset):
@classmethod
def create(
cls,
display_name: str,
display_name: Optional[str] = None,
gcs_source: Optional[Union[str, Sequence[str]]] = None,
import_schema_uri: Optional[str] = None,
data_item_labels: Optional[Dict] = None,
Expand All @@ -61,7 +61,7 @@ def create(
Args:
display_name (str):
Required. The user-defined name of the Dataset.
Optional. The user-defined name of the Dataset.
The name can be up to 128 characters long and can be consist
of any UTF-8 characters.
gcs_source (Union[str, Sequence[str]]):
Expand Down Expand Up @@ -136,7 +136,8 @@ def create(
text_dataset (TextDataset):
Instantiated representation of the managed text dataset resource.
"""

if not display_name:
display_name = cls._generate_display_name()
utils.validate_display_name(display_name)
if labels:
utils.validate_labels(labels)
Expand Down
7 changes: 4 additions & 3 deletions google/cloud/aiplatform/datasets/time_series_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class TimeSeriesDataset(datasets._ColumnNamesDataset):
@classmethod
def create(
cls,
display_name: str,
display_name: Optional[str] = None,
gcs_source: Optional[Union[str, Sequence[str]]] = None,
bq_source: Optional[str] = None,
project: Optional[str] = None,
Expand All @@ -51,7 +51,7 @@ def create(
Args:
display_name (str):
Required. The user-defined name of the Dataset.
Optional. The user-defined name of the Dataset.
The name can be up to 128 characters long and can be consist
of any UTF-8 characters.
gcs_source (Union[str, Sequence[str]]):
Expand Down Expand Up @@ -108,7 +108,8 @@ def create(
Instantiated representation of the managed time series dataset resource.
"""

if not display_name:
display_name = cls._generate_display_name()
utils.validate_display_name(display_name)
if labels:
utils.validate_labels(labels)
Expand Down
7 changes: 4 additions & 3 deletions google/cloud/aiplatform/datasets/video_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class VideoDataset(datasets._Dataset):
@classmethod
def create(
cls,
display_name: str,
display_name: Optional[str] = None,
gcs_source: Optional[Union[str, Sequence[str]]] = None,
import_schema_uri: Optional[str] = None,
data_item_labels: Optional[Dict] = None,
Expand All @@ -54,7 +54,7 @@ def create(
Args:
display_name (str):
Required. The user-defined name of the Dataset.
Optional. The user-defined name of the Dataset.
The name can be up to 128 characters long and can be consist
of any UTF-8 characters.
gcs_source (Union[str, Sequence[str]]):
Expand Down Expand Up @@ -129,7 +129,8 @@ def create(
video_dataset (VideoDataset):
Instantiated representation of the managed video dataset resource.
"""

if not display_name:
display_name = cls._generate_display_name()
utils.validate_display_name(display_name)
if labels:
utils.validate_labels(labels)
Expand Down
12 changes: 12 additions & 0 deletions google/cloud/aiplatform/jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,7 @@ def completion_stats(self) -> Optional[gca_completion_stats.CompletionStats]:
@classmethod
def create(
cls,
# TODO(b/223262536): Make the job_display_name parameter optional in the next major release
job_display_name: str,
model_name: Union[str, "aiplatform.Model"],
instances_format: str = "jsonl",
Expand Down Expand Up @@ -537,6 +538,8 @@ def create(
(jobs.BatchPredictionJob):
Instantiated representation of the created batch prediction job.
"""
if not job_display_name:
job_display_name = cls._generate_display_name()

utils.validate_display_name(job_display_name)

Expand Down Expand Up @@ -1032,6 +1035,7 @@ class CustomJob(_RunnableJob):

def __init__(
self,
# TODO(b/223262536): Make display_name parameter fully optional in next major release
display_name: str,
worker_pool_specs: Union[List[Dict], List[aiplatform.gapic.WorkerPoolSpec]],
base_output_dir: Optional[str] = None,
Expand Down Expand Up @@ -1136,6 +1140,9 @@ def __init__(
staging_bucket, "aiplatform-custom-job"
)

if not display_name:
display_name = self.__class__._generate_display_name()

self._gca_resource = gca_custom_job_compat.CustomJob(
display_name=display_name,
job_spec=gca_custom_job_compat.CustomJobSpec(
Expand Down Expand Up @@ -1193,6 +1200,7 @@ def _log_web_access_uris(self):
@classmethod
def from_local_script(
cls,
# TODO(b/223262536): Make display_name parameter fully optional in next major release
display_name: str,
script_path: str,
container_uri: str,
Expand Down Expand Up @@ -1521,6 +1529,7 @@ class HyperparameterTuningJob(_RunnableJob):

def __init__(
self,
# TODO(b/223262536): Make display_name parameter fully optional in next major release
display_name: str,
custom_job: CustomJob,
metric_spec: Dict[str, str],
Expand Down Expand Up @@ -1717,6 +1726,9 @@ def __init__(
],
)

if not display_name:
display_name = self.__class__._generate_display_name()

self._gca_resource = (
gca_hyperparameter_tuning_job_compat.HyperparameterTuningJob(
display_name=display_name,
Expand Down
32 changes: 23 additions & 9 deletions google/cloud/aiplatform/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ def network(self) -> Optional[str]:
@classmethod
def create(
cls,
display_name: str,
display_name: Optional[str] = None,
description: Optional[str] = None,
labels: Optional[Dict[str, str]] = None,
metadata: Optional[Sequence[Tuple[str, str]]] = (),
Expand All @@ -212,7 +212,7 @@ def create(
Args:
display_name (str):
Required. The user-defined name of the Endpoint.
Optional. The user-defined name of the Endpoint.
The name can be up to 128 characters long and can be consist
of any UTF-8 characters.
project (str):
Expand Down Expand Up @@ -263,6 +263,9 @@ def create(

api_client = cls._instantiate_client(location=location, credentials=credentials)

if not display_name:
display_name = cls._generate_display_name()

utils.validate_display_name(display_name)
if labels:
utils.validate_labels(labels)
Expand Down Expand Up @@ -1654,7 +1657,6 @@ def update(
@base.optional_sync()
def upload(
cls,
display_name: str,
serving_container_image_uri: str,
*,
artifact_uri: Optional[str] = None,
Expand All @@ -1670,6 +1672,7 @@ def upload(
prediction_schema_uri: Optional[str] = None,
explanation_metadata: Optional[explain.ExplanationMetadata] = None,
explanation_parameters: Optional[explain.ExplanationParameters] = None,
display_name: Optional[str] = None,
project: Optional[str] = None,
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
Expand All @@ -1692,7 +1695,7 @@ def upload(
Args:
display_name (str):
Required. The display name of the Model. The name can be up to 128
Optional. The display name of the Model. The name can be up to 128
characters long and can be consist of any UTF-8 characters.
serving_container_image_uri (str):
Required. The URI of the Model serving container.
Expand Down Expand Up @@ -1832,6 +1835,8 @@ def upload(
is specified.
Also if model directory does not contain a supported model file.
"""
if not display_name:
display_name = cls._generate_display_name()
utils.validate_display_name(display_name)
if labels:
utils.validate_labels(labels)
Expand Down Expand Up @@ -2231,7 +2236,7 @@ def _deploy(

def batch_predict(
self,
job_display_name: str,
job_display_name: Optional[str] = None,
gcs_source: Optional[Union[str, Sequence[str]]] = None,
bigquery_source: Optional[str] = None,
instances_format: str = "jsonl",
Expand Down Expand Up @@ -2269,7 +2274,7 @@ def batch_predict(
Args:
job_display_name (str):
Required. The user-defined name of the BatchPredictionJob.
Optional. The user-defined name of the BatchPredictionJob.
The name can be up to 128 characters long and can be consist
of any UTF-8 characters.
gcs_source: Optional[Sequence[str]] = None
Expand Down Expand Up @@ -2636,7 +2641,7 @@ def upload_xgboost_model_file(
cls,
model_file_path: str,
xgboost_version: Optional[str] = None,
display_name: str = "XGBoost model",
display_name: Optional[str] = None,
description: Optional[str] = None,
instance_schema_uri: Optional[str] = None,
parameters_schema_uri: Optional[str] = None,
Expand Down Expand Up @@ -2769,6 +2774,9 @@ def upload_xgboost_model_file(
is specified.
Also if model directory does not contain a supported model file.
"""
if not display_name:
display_name = cls.__class__.__generate_display_name("XGBoost model")

XGBOOST_SUPPORTED_MODEL_FILE_EXTENSIONS = [
".pkl",
".joblib",
Expand Down Expand Up @@ -2835,7 +2843,7 @@ def upload_scikit_learn_model_file(
cls,
model_file_path: str,
sklearn_version: Optional[str] = None,
display_name: str = "Scikit-learn model",
display_name: Optional[str] = None,
description: Optional[str] = None,
instance_schema_uri: Optional[str] = None,
parameters_schema_uri: Optional[str] = None,
Expand Down Expand Up @@ -2969,6 +2977,9 @@ def upload_scikit_learn_model_file(
is specified.
Also if model directory does not contain a supported model file.
"""
if not display_name:
display_name = cls._generate_display_name("Scikit-Learn model")

SKLEARN_SUPPORTED_MODEL_FILE_EXTENSIONS = [
".pkl",
".joblib",
Expand Down Expand Up @@ -3034,7 +3045,7 @@ def upload_tensorflow_saved_model(
saved_model_dir: str,
tensorflow_version: Optional[str] = None,
use_gpu: bool = False,
display_name: str = "Tensorflow model",
display_name: Optional[str] = None,
description: Optional[str] = None,
instance_schema_uri: Optional[str] = None,
parameters_schema_uri: Optional[str] = None,
Expand Down Expand Up @@ -3170,6 +3181,9 @@ def upload_tensorflow_saved_model(
is specified.
Also if model directory does not contain a supported model file.
"""
if not display_name:
display_name = cls._generate_display_name("Tensorflow model")

container_image_uri = aiplatform.helpers.get_prebuilt_prediction_container_uri(
region=location,
framework="tensorflow",
Expand Down
3 changes: 3 additions & 0 deletions google/cloud/aiplatform/pipeline_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ class PipelineJob(base.VertexAiStatefulResource):

def __init__(
self,
# TODO(b/223262536): Make the display_name parameter optional in the next major release
display_name: str,
template_path: str,
job_id: Optional[str] = None,
Expand Down Expand Up @@ -160,6 +161,8 @@ def __init__(
Raises:
ValueError: If job_id or labels have incorrect format.
"""
if not display_name:
display_name = self.__class__._generate_display_name()
utils.validate_display_name(display_name)

if labels:
Expand Down
Loading

0 comments on commit 400b760

Please sign in to comment.