From 1cc22c3c3561f7c6374d32fafd45839256064958 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ila=C3=AF=20Deutel?= Date: Thu, 23 May 2024 11:02:47 -0700 Subject: [PATCH] fix: GenAI - Tuning - Supervised - Fix `adapter_size` parameter handling to match enum values. PiperOrigin-RevId: 636608417 --- tests/unit/vertexai/test_tuning.py | 1 + vertexai/tuning/_supervised_tuning.py | 80 +++++++++++++++++---------- 2 files changed, 53 insertions(+), 28 deletions(-) diff --git a/tests/unit/vertexai/test_tuning.py b/tests/unit/vertexai/test_tuning.py index 73506ad417..1d64b9a489 100644 --- a/tests/unit/vertexai/test_tuning.py +++ b/tests/unit/vertexai/test_tuning.py @@ -172,6 +172,7 @@ def test_genai_tuning_service_supervised_tuning_tune_model(self): validation_dataset="gs://some-bucket/some_dataset.jsonl", epochs=300, learning_rate_multiplier=1.0, + adapter_size=8, ) assert sft_tuning_job.state == job_state.JobState.JOB_STATE_PENDING assert not sft_tuning_job.has_ended diff --git a/vertexai/tuning/_supervised_tuning.py b/vertexai/tuning/_supervised_tuning.py index db542c6bf2..e7018879b0 100644 --- a/vertexai/tuning/_supervised_tuning.py +++ b/vertexai/tuning/_supervised_tuning.py @@ -15,7 +15,9 @@ from typing import Dict, Literal, Optional, Union -from google.cloud.aiplatform_v1.types import tuning_job as gca_tuning_job_types +from google.cloud.aiplatform_v1.types import ( + tuning_job as gca_tuning_job_types, +) from vertexai import generative_models from vertexai.tuning import _tuning @@ -31,44 +33,66 @@ def train( adapter_size: Optional[Literal[1, 4, 8, 16]] = None, labels: Optional[Dict[str, str]] = None, ) -> "SupervisedTuningJob": - """Tunes a model using supervised training. + """Tunes a model using supervised training. - Args: - source_model (str): Model name for tuning, e.g., "gemini-1.0-pro-002". - train_dataset: Cloud Storage path to file containing training dataset for - tuning. The dataset should be in JSONL format. - validation_dataset: Cloud Storage path to file containing validation - dataset for tuning. The dataset should be in JSONL format. - tuned_model_display_name: The display name of the - [TunedModel][google.cloud.aiplatform.v1.Model]. The name can be up to - 128 characters long and can consist of any UTF-8 characters. - epochs: Number of training epoches for this tuning job. - learning_rate_multiplier: Learning rate multiplier for tuning. - adapter_size: Adapter size for tuning. - labels: User-defined metadata to be associated with trained models + Args: + source_model (str): Model name for tuning, e.g., "gemini-1.0-pro-002". + train_dataset: Cloud Storage path to file containing training dataset for + tuning. The dataset should be in JSONL format. + validation_dataset: Cloud Storage path to file containing validation + dataset for tuning. The dataset should be in JSONL format. + tuned_model_display_name: The display name of the + [TunedModel][google.cloud.aiplatform.v1.Model]. The name can be up to + 128 characters long and can consist of any UTF-8 characters. + epochs: Number of training epoches for this tuning job. + learning_rate_multiplier: Learning rate multiplier for tuning. + adapter_size: Adapter size for tuning. + labels: User-defined metadata to be associated with trained models - Returns: - A `TuningJob` object. - """ - supervised_tuning_spec = gca_tuning_job_types.SupervisedTuningSpec( + Returns: + A `TuningJob` object. + """ + if adapter_size is None: + adapter_size_value = None + elif adapter_size == 1: + adapter_size_value = ( + gca_tuning_job_types.SupervisedHyperParameters.AdapterSize.ADAPTER_SIZE_ONE + ) + elif adapter_size == 4: + adapter_size_value = ( + gca_tuning_job_types.SupervisedHyperParameters.AdapterSize.ADAPTER_SIZE_FOUR + ) + elif adapter_size == 8: + adapter_size_value = ( + gca_tuning_job_types.SupervisedHyperParameters.AdapterSize.ADAPTER_SIZE_EIGHT + ) + elif adapter_size == 16: + adapter_size_value = ( + gca_tuning_job_types.SupervisedHyperParameters.AdapterSize.ADAPTER_SIZE_SIXTEEN + ) + else: + raise ValueError( + f"Unsupported adapter size: {adapter_size}. The supported sizes are [1, 4, 8, 16]" + ) + supervised_tuning_spec = gca_tuning_job_types.SupervisedTuningSpec( training_dataset_uri=train_dataset, validation_dataset_uri=validation_dataset, hyper_parameters=gca_tuning_job_types.SupervisedHyperParameters( epoch_count=epochs, learning_rate_multiplier=learning_rate_multiplier, - adapter_size=adapter_size, + adapter_size=adapter_size_value, ), ) - if isinstance(source_model, generative_models.GenerativeModel): - source_model = source_model._prediction_resource_name.rpartition('/')[-1] + if isinstance(source_model, generative_models.GenerativeModel): + source_model = source_model._prediction_resource_name.rpartition("/")[-1] - return SupervisedTuningJob._create( # pylint: disable=protected-access - base_model=source_model, - tuning_spec=supervised_tuning_spec, - tuned_model_display_name=tuned_model_display_name, - labels=labels, - ) + return SupervisedTuningJob._create( # pylint: disable=protected-access + base_model=source_model, + tuning_spec=supervised_tuning_spec, + tuned_model_display_name=tuned_model_display_name, + labels=labels, + ) class SupervisedTuningJob(_tuning.TuningJob):