Skip to content

Commit

Permalink
feat: Add display experiment button for tuning in Ipython environments
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 622285246
  • Loading branch information
matthew29tang authored and copybara-github committed Apr 5, 2024
1 parent 806ef9f commit 9bb687c
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 1 deletion.
12 changes: 11 additions & 1 deletion tests/unit/vertexai/test_tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from google.cloud.aiplatform import compat
from google.cloud.aiplatform import initializer
from google.cloud.aiplatform import utils as aiplatform_utils
from google.cloud.aiplatform.metadata import experiment_resources
from google.cloud.aiplatform_v1.services import gen_ai_tuning_service
from google.cloud.aiplatform_v1.types import job_state
from google.cloud.aiplatform_v1.types import tuning_job as gca_tuning_job
Expand All @@ -35,6 +36,8 @@

import pytest

from unittest.mock import patch

from google.rpc import status_pb2


Expand Down Expand Up @@ -136,7 +139,14 @@ class MockTuningJobClientWithOverride(aiplatform_utils.ClientWithOverride):
)


@pytest.mark.usefixtures("google_auth_mock")
@pytest.fixture()
def experiment_init_mock():
with patch.object(experiment_resources.Experiment, "__init__") as experiment_mock:
experiment_mock.return_value = None
yield experiment_mock


@pytest.mark.usefixtures("google_auth_mock", "experiment_init_mock")
class TestgenerativeModelTuning:
"""Unit tests for generative model tuning."""

Expand Down
5 changes: 5 additions & 0 deletions vertexai/tuning/_tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from google.cloud.aiplatform import initializer as aiplatform_initializer
from google.cloud.aiplatform import jobs
from google.cloud.aiplatform import utils as aiplatform_utils
from google.cloud.aiplatform.utils import _ipython_utils
from google.cloud.aiplatform_v1.services import gen_ai_tuning_service as gen_ai_tuning_service_v1
from google.cloud.aiplatform_v1.types import tuning_job as gca_tuning_job_types
from google.cloud.aiplatform_v1 import types as gca_types
Expand Down Expand Up @@ -57,6 +58,7 @@ class TuningJob(aiplatform_base._VertexAiResourceNounPlus):
_parse_resource_name_method = "parse_tuning_job_path"
_format_resource_name_method = "tuning_job_path"
_job_type = "tuning/tuningJob"
_has_displayed_experiments_button = False

client_class = TuningJobClientWithOverride

Expand All @@ -74,6 +76,9 @@ def refresh(self) -> "TuningJob":
self._gca_resource: gca_tuning_job_types.TuningJob = (
self._get_gca_resource(resource_name=self.resource_name)
)
if self.experiment and not self._has_displayed_experiments_button:
self._has_displayed_experiments_button = True
_ipython_utils.display_experiment_button(self.experiment)
return self

@property
Expand Down

0 comments on commit 9bb687c

Please sign in to comment.