Skip to content

Commit

Permalink
fix: Add DeprecationWarning to vertexai.preview predictive models SDK
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 630461931
  • Loading branch information
yinghsienwu authored and copybara-github committed May 3, 2024
1 parent e8fe28d commit 3c3727b
Show file tree
Hide file tree
Showing 7 changed files with 49 additions and 10 deletions.
6 changes: 5 additions & 1 deletion vertexai/preview/_workflow/driver/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import abc
import inspect
from typing import Any, Callable, Dict, Optional, Type

import warnings
from vertexai.preview._workflow import driver
from vertexai.preview._workflow.executor import (
training,
Expand All @@ -27,6 +27,7 @@
any_serializer,
)
from vertexai.preview._workflow.shared import (
constants,
supported_frameworks,
)
from vertexai.preview.developer import remote_specs
Expand All @@ -41,6 +42,9 @@ def remote_method_decorator(
return driver.VertexRemoteFunctor(method, remote_executor, remote_executor_kwargs)


warnings.warn(constants._V2_0_WARNING_MSG, DeprecationWarning, stacklevel=1)


def remote_class_decorator(cls: Type) -> Type:
"""Add Vertex attributes to a class object."""

Expand Down
18 changes: 18 additions & 0 deletions vertexai/preview/_workflow/shared/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,21 @@

_START_EXECUTION_MSG = "Start remote execution on Vertex..."
_END_EXECUTION_MSG = "Remote execution is completed."

_V2_0_WARNING_MSG = """
After May 30, 2024, importing any code below will result in an error.
Please verify that you are explicitly pinning to a version of `google-cloud-aiplatform`
(e.g., google-cloud-aiplatform==[1.32.0, 1.49.0]) if you need to continue using this
library.
from vertexai.preview import (
init,
remote,
VertexModel,
register,
from_pretrained,
developer,
hyperparameter_tuning,
tabular_models,
)
"""
10 changes: 5 additions & 5 deletions vertexai/preview/_workflow/shared/model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import os
import re
from typing import Any, Dict, Optional, Union
import warnings

from google.cloud import aiplatform
from google.cloud.aiplatform import base
Expand All @@ -35,14 +36,12 @@
any_serializer,
serializers_base,
)
from vertexai.preview._workflow.shared import constants

# These need to be imported to be included in _ModelGardenModel.__init_subclass__
from vertexai.language_models import (
_language_models,
) # pylint:disable=unused-import
from vertexai.vision_models import (
_vision_models,
) # pylint:disable=unused-import
from vertexai._model_garden import _model_garden_models
from google.cloud.aiplatform import _publisher_models
from vertexai.preview._workflow.executor import training
Expand All @@ -60,9 +59,10 @@
_OUTPUT_ESTIMATOR_DIR = "output_estimator"
_OUTPUT_PREDICTIONS_DIR = "output_predictions"


_LOGGER = base.Logger("vertexai.remote_execution")

warnings.warn(constants._V2_0_WARNING_MSG, DeprecationWarning, stacklevel=1)


def _get_model_file_from_image_uri(container_image_uri: str) -> str:
"""Gets the model file from the container image URI.
Expand Down Expand Up @@ -121,7 +121,7 @@ def _generate_remote_job_output_path(base_gcs_dir: str) -> str:

def _get_model_from_successful_custom_job(
job_dir: str,
) -> Union["sklearn.base.BaseEstimator", "tf.Module", "torch.nn.Module"]:
) -> Union["sklearn.base.BaseEstimator", "tf.Module", "torch.nn.Module"]: # noqa: F821

serializer = any_serializer.AnySerializer()

Expand Down
8 changes: 7 additions & 1 deletion vertexai/preview/developer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,23 @@
# limitations under the License.
#

import warnings
from vertexai.preview._workflow.serialization_engine import (
any_serializer,
)
from vertexai.preview._workflow.serialization_engine import (
serializers_base,
)
from vertexai.preview._workflow.shared import configs
from vertexai.preview._workflow.shared import (
configs,
constants,
)
from vertexai.preview.developer import mark
from vertexai.preview.developer import remote_specs


warnings.warn(constants._V2_0_WARNING_MSG, DeprecationWarning, stacklevel=1)

PersistentResourceConfig = configs.PersistentResourceConfig
Serializer = serializers_base.Serializer
SerializationMetadata = serializers_base.SerializationMetadata
Expand Down
5 changes: 4 additions & 1 deletion vertexai/preview/hyperparameter_tuning/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,14 @@
# limitations under the License.
#


import warnings
from vertexai.preview.hyperparameter_tuning import (
vizier_hyperparameter_tuner,
)
from vertexai.preview._workflow.shared import constants


warnings.warn(constants._V2_0_WARNING_MSG, DeprecationWarning, stacklevel=1)

VizierHyperparameterTuner = vizier_hyperparameter_tuner.VizierHyperparameterTuner

Expand Down
8 changes: 6 additions & 2 deletions vertexai/preview/initializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,16 @@
#

from typing import Optional

import warnings
from google.cloud import aiplatform
from google.cloud.aiplatform import base
from vertexai.preview._workflow.executor import (
persistent_resource_util,
)
from vertexai.preview._workflow.shared import configs
from vertexai.preview._workflow.shared import (
configs,
constants,
)


_LOGGER = base.Logger(__name__)
Expand All @@ -30,6 +33,7 @@ class _Config:
"""Store common configurations and current workflow for remote execution."""

def __init__(self):
warnings.warn(constants._V2_0_WARNING_MSG, DeprecationWarning, stacklevel=1)
self._remote = False
self._cluster = None

Expand Down
4 changes: 4 additions & 0 deletions vertexai/preview/tabular_models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,14 @@
# limitations under the License.
#

import warnings

from vertexai.preview._workflow.shared import constants
from vertexai.preview.tabular_models import tabnet_trainer


warnings.warn(constants._V2_0_WARNING_MSG, DeprecationWarning, stacklevel=1)

TabNetTrainer = tabnet_trainer.TabNetTrainer


Expand Down

0 comments on commit 3c3727b

Please sign in to comment.