Skip to content

Commit

Permalink
fix: GenAI - Tuning - Supervised - Fix adapter_size parameter handl…
Browse files Browse the repository at this point in the history
…ing to match enum values.

PiperOrigin-RevId: 636608417
  • Loading branch information
ilai-deutel authored and copybara-github committed May 23, 2024
1 parent bed3dec commit 1cc22c3
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 28 deletions.
1 change: 1 addition & 0 deletions tests/unit/vertexai/test_tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@ def test_genai_tuning_service_supervised_tuning_tune_model(self):
validation_dataset="gs://some-bucket/some_dataset.jsonl",
epochs=300,
learning_rate_multiplier=1.0,
adapter_size=8,
)
assert sft_tuning_job.state == job_state.JobState.JOB_STATE_PENDING
assert not sft_tuning_job.has_ended
Expand Down
80 changes: 52 additions & 28 deletions vertexai/tuning/_supervised_tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@

from typing import Dict, Literal, Optional, Union

from google.cloud.aiplatform_v1.types import tuning_job as gca_tuning_job_types
from google.cloud.aiplatform_v1.types import (
tuning_job as gca_tuning_job_types,
)
from vertexai import generative_models
from vertexai.tuning import _tuning

Expand All @@ -31,44 +33,66 @@ def train(
adapter_size: Optional[Literal[1, 4, 8, 16]] = None,
labels: Optional[Dict[str, str]] = None,
) -> "SupervisedTuningJob":
"""Tunes a model using supervised training.
"""Tunes a model using supervised training.
Args:
source_model (str): Model name for tuning, e.g., "gemini-1.0-pro-002".
train_dataset: Cloud Storage path to file containing training dataset for
tuning. The dataset should be in JSONL format.
validation_dataset: Cloud Storage path to file containing validation
dataset for tuning. The dataset should be in JSONL format.
tuned_model_display_name: The display name of the
[TunedModel][google.cloud.aiplatform.v1.Model]. The name can be up to
128 characters long and can consist of any UTF-8 characters.
epochs: Number of training epoches for this tuning job.
learning_rate_multiplier: Learning rate multiplier for tuning.
adapter_size: Adapter size for tuning.
labels: User-defined metadata to be associated with trained models
Args:
source_model (str): Model name for tuning, e.g., "gemini-1.0-pro-002".
train_dataset: Cloud Storage path to file containing training dataset for
tuning. The dataset should be in JSONL format.
validation_dataset: Cloud Storage path to file containing validation
dataset for tuning. The dataset should be in JSONL format.
tuned_model_display_name: The display name of the
[TunedModel][google.cloud.aiplatform.v1.Model]. The name can be up to
128 characters long and can consist of any UTF-8 characters.
epochs: Number of training epoches for this tuning job.
learning_rate_multiplier: Learning rate multiplier for tuning.
adapter_size: Adapter size for tuning.
labels: User-defined metadata to be associated with trained models
Returns:
A `TuningJob` object.
"""
supervised_tuning_spec = gca_tuning_job_types.SupervisedTuningSpec(
Returns:
A `TuningJob` object.
"""
if adapter_size is None:
adapter_size_value = None
elif adapter_size == 1:
adapter_size_value = (
gca_tuning_job_types.SupervisedHyperParameters.AdapterSize.ADAPTER_SIZE_ONE
)
elif adapter_size == 4:
adapter_size_value = (
gca_tuning_job_types.SupervisedHyperParameters.AdapterSize.ADAPTER_SIZE_FOUR
)
elif adapter_size == 8:
adapter_size_value = (
gca_tuning_job_types.SupervisedHyperParameters.AdapterSize.ADAPTER_SIZE_EIGHT
)
elif adapter_size == 16:
adapter_size_value = (
gca_tuning_job_types.SupervisedHyperParameters.AdapterSize.ADAPTER_SIZE_SIXTEEN
)
else:
raise ValueError(
f"Unsupported adapter size: {adapter_size}. The supported sizes are [1, 4, 8, 16]"
)
supervised_tuning_spec = gca_tuning_job_types.SupervisedTuningSpec(
training_dataset_uri=train_dataset,
validation_dataset_uri=validation_dataset,
hyper_parameters=gca_tuning_job_types.SupervisedHyperParameters(
epoch_count=epochs,
learning_rate_multiplier=learning_rate_multiplier,
adapter_size=adapter_size,
adapter_size=adapter_size_value,
),
)

if isinstance(source_model, generative_models.GenerativeModel):
source_model = source_model._prediction_resource_name.rpartition('/')[-1]
if isinstance(source_model, generative_models.GenerativeModel):
source_model = source_model._prediction_resource_name.rpartition("/")[-1]

return SupervisedTuningJob._create( # pylint: disable=protected-access
base_model=source_model,
tuning_spec=supervised_tuning_spec,
tuned_model_display_name=tuned_model_display_name,
labels=labels,
)
return SupervisedTuningJob._create( # pylint: disable=protected-access
base_model=source_model,
tuning_spec=supervised_tuning_spec,
tuned_model_display_name=tuned_model_display_name,
labels=labels,
)


class SupervisedTuningJob(_tuning.TuningJob):
Expand Down

0 comments on commit 1cc22c3

Please sign in to comment.