Skip to content

Commit

Permalink
feat: add TextEmbedding model version support (#394)
Browse files Browse the repository at this point in the history
Thank you for opening a Pull Request! Before submitting your PR, there are a few things you can do to make sure it goes smoothly:
- [ ] Make sure to open an issue as a [bug/issue](https://togithub.com/googleapis/python-bigquery-dataframes/issues/new/choose) before writing your code!  That way we can discuss the change, evaluate designs, and agree on the general idea
- [ ] Ensure the tests and linter pass
- [ ] Code coverage does not decrease (if any source code was changed)
- [ ] Appropriate docs were updated (if necessary)

Fixes #<issue_number_goes_here> 🦕
  • Loading branch information
GarrettWu authored Feb 29, 2024
1 parent 1726588 commit e0f1ab0
Show file tree
Hide file tree
Showing 5 changed files with 61 additions and 5 deletions.
18 changes: 16 additions & 2 deletions bigframes/ml/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,9 @@ class PaLM2TextEmbeddingGenerator(base.Predictor):
The model for text embedding. “textembedding-gecko” returns model embeddings for text inputs.
"textembedding-gecko-multilingual" returns model embeddings for text inputs which support over 100 languages
Default to "textembedding-gecko".
version (str or None):
Model version. Accepted values are "001", "002", "003", "latest" etc. Will use the default version if unset.
See https://cloud.google.com/vertex-ai/docs/generative-ai/learn/model-versioning for details.
session (bigframes.Session or None):
BQ session to create the model. If None, use the global default session.
connection_name (str or None):
Expand All @@ -279,10 +282,12 @@ def __init__(
model_name: Literal[
"textembedding-gecko", "textembedding-gecko-multilingual"
] = "textembedding-gecko",
version: Optional[str] = None,
session: Optional[bigframes.Session] = None,
connection_name: Optional[str] = None,
):
self.model_name = model_name
self.version = version
self.session = session or bpd.get_global_session()
self._bq_connection_manager = clients.BqConnectionManager(
self.session.bqconnectionclient, self.session.resourcemanagerclient
Expand Down Expand Up @@ -321,8 +326,11 @@ def _create_bqml_model(self):
f"Model name {self.model_name} is not supported. We only support {', '.join(_EMBEDDING_GENERATOR_ENDPOINTS)}."
)

endpoint = (
self.model_name + "@" + self.version if self.version else self.model_name
)
options = {
"endpoint": self.model_name,
"endpoint": endpoint,
}
return self._bqml_model_factory.create_remote_model(
session=self.session, connection_name=self.connection_name, options=options
Expand All @@ -342,8 +350,14 @@ def _from_bq(
model_connection = model._properties["remoteModelInfo"]["connection"]
model_endpoint = bqml_endpoint.split("/")[-1]

model_name, version = utils.parse_model_endpoint(model_endpoint)

embedding_generator_model = cls(
session=session, model_name=model_endpoint, connection_name=model_connection
session=session,
# str to literals
model_name=model_name, # type: ignore
version=version,
connection_name=model_connection,
)
embedding_generator_model._bqml_model = core.BqmlModel(session, model)
return embedding_generator_model
Expand Down
7 changes: 5 additions & 2 deletions bigframes/ml/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
linear_model,
llm,
pipeline,
utils,
)

_BQML_MODEL_TYPE_MAPPING = MappingProxyType(
Expand Down Expand Up @@ -106,8 +107,10 @@ def _model_from_bq(session: bigframes.Session, bq_model: bigquery.Model):
):
# Parse the remote model endpoint
bqml_endpoint = bq_model._properties["remoteModelInfo"]["endpoint"]
endpoint_model = bqml_endpoint.split("/")[-1]
return _BQML_ENDPOINT_TYPE_MAPPING[endpoint_model]._from_bq( # type: ignore
model_endpoint = bqml_endpoint.split("/")[-1]
model_name, _ = utils.parse_model_endpoint(model_endpoint)

return _BQML_ENDPOINT_TYPE_MAPPING[model_name]._from_bq( # type: ignore
session=session, model=bq_model
)

Expand Down
15 changes: 14 additions & 1 deletion bigframes/ml/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

import typing
from typing import Iterable, Union
from typing import Iterable, Optional, Union

import bigframes.constants as constants
from bigframes.core import blocks
Expand Down Expand Up @@ -56,3 +56,16 @@ def _convert_to_series(frame: ArrayType) -> bpd.Series:
raise ValueError(
f"Unsupported type {type(frame)} to convert to Series. {constants.FEEDBACK_LINK}"
)


def parse_model_endpoint(model_endpoint: str) -> tuple[str, Optional[str]]:
"""Parse model endpoint string to model_name and version."""
model_name = model_endpoint
version = None

at_idx = model_endpoint.find("@")
if at_idx != -1:
version = model_endpoint[at_idx + 1 :]
model_name = model_endpoint[:at_idx]

return model_name, version
9 changes: 9 additions & 0 deletions tests/system/small/ml/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,15 @@ def palm2_embedding_generator_model(
)


@pytest.fixture(scope="session")
def palm2_embedding_generator_model_002(
session, bq_connection
) -> llm.PaLM2TextEmbeddingGenerator:
return llm.PaLM2TextEmbeddingGenerator(
version="002", session=session, connection_name=bq_connection
)


@pytest.fixture(scope="session")
def palm2_embedding_generator_multilingual_model(
session, bq_connection
Expand Down
17 changes: 17 additions & 0 deletions tests/system/small/ml/test_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,23 @@ def test_create_embedding_generator_model(
assert reloaded_model.connection_name == bq_connection


def test_create_embedding_generator_model_002(
palm2_embedding_generator_model_002, dataset_id, bq_connection
):
# Model creation doesn't return error
assert palm2_embedding_generator_model_002 is not None
assert palm2_embedding_generator_model_002._bqml_model is not None

# save, load to ensure configuration was kept
reloaded_model = palm2_embedding_generator_model_002.to_gbq(
f"{dataset_id}.temp_embedding_model", replace=True
)
assert f"{dataset_id}.temp_embedding_model" == reloaded_model._bqml_model.model_name
assert reloaded_model.model_name == "textembedding-gecko"
assert reloaded_model.version == "002"
assert reloaded_model.connection_name == bq_connection


def test_create_embedding_generator_multilingual_model(
palm2_embedding_generator_multilingual_model,
dataset_id,
Expand Down

0 comments on commit e0f1ab0

Please sign in to comment.