feat: update LLM generators to warn the user about the model name instead of raising an error. (#1048)

* feat: update LLM generators to warn the user about the model name instead of raising an error.

* update message and format

* update message and format
Genesis929 authored Oct 8, 2024
1 parent 02c2da7 commit 650d80d
Showing 1 changed file with 32 additions and 10 deletions.
42 changes: 32 additions & 10 deletions bigframes/ml/llm.py
@@ -83,6 +83,13 @@
 _ML_EMBED_TEXT_STATUS = "ml_embed_text_status"
 _ML_GENERATE_EMBEDDING_STATUS = "ml_generate_embedding_status"
 
+_MODEL_NOT_SUPPORTED_WARNING = (
+    "Model name '{model_name}' is not supported. "
+    "We are currently aware of the following models: {known_models}. "
+    "However, model names can change, and the supported models may be outdated. "
+    "You should use this model name only if you are sure that it is supported in BigQuery."
+)
+
 
 @typing_extensions.deprecated(
     "PaLM2TextGenerator is going to be deprecated. Use GeminiTextGenerator(https://cloud.google.com/python/docs/reference/bigframes/latest/bigframes.ml.llm.GeminiTextGenerator) instead. ",
@@ -154,8 +161,11 @@ def _create_bqml_model(self):
             )
 
         if self.model_name not in _TEXT_GENERATOR_ENDPOINTS:
-            raise ValueError(
-                f"Model name {self.model_name} is not supported. We only support {', '.join(_TEXT_GENERATOR_ENDPOINTS)}."
+            warnings.warn(
+                _MODEL_NOT_SUPPORTED_WARNING.format(
+                    model_name=self.model_name,
+                    known_models=", ".join(_TEXT_GENERATOR_ENDPOINTS),
+                )
             )
 
         options = {
@@ -484,8 +494,11 @@ def _create_bqml_model(self):
             )
 
         if self.model_name not in _PALM2_EMBEDDING_GENERATOR_ENDPOINTS:
-            raise ValueError(
-                f"Model name {self.model_name} is not supported. We only support {', '.join(_PALM2_EMBEDDING_GENERATOR_ENDPOINTS)}."
+            warnings.warn(
+                _MODEL_NOT_SUPPORTED_WARNING.format(
+                    model_name=self.model_name,
+                    known_models=", ".join(_PALM2_EMBEDDING_GENERATOR_ENDPOINTS),
+                )
             )
 
         endpoint = (
@@ -644,8 +657,11 @@ def _create_bqml_model(self):
             )
 
         if self.model_name not in _TEXT_EMBEDDING_ENDPOINTS:
-            raise ValueError(
-                f"Model name {self.model_name} is not supported. We only support {', '.join(_TEXT_EMBEDDING_ENDPOINTS)}."
+            warnings.warn(
+                _MODEL_NOT_SUPPORTED_WARNING.format(
+                    model_name=self.model_name,
+                    known_models=", ".join(_TEXT_EMBEDDING_ENDPOINTS),
+                )
             )
 
         options = {
@@ -801,8 +817,11 @@ def _create_bqml_model(self):
             )
 
         if self.model_name not in _GEMINI_ENDPOINTS:
-            raise ValueError(
-                f"Model name {self.model_name} is not supported. We only support {', '.join(_GEMINI_ENDPOINTS)}."
+            warnings.warn(
+                _MODEL_NOT_SUPPORTED_WARNING.format(
+                    model_name=self.model_name,
+                    known_models=", ".join(_GEMINI_ENDPOINTS),
+                )
             )
 
         options = {"endpoint": self.model_name}
@@ -1118,8 +1137,11 @@ def _create_bqml_model(self):
             )
 
         if self.model_name not in _CLAUDE_3_ENDPOINTS:
-            raise ValueError(
-                f"Model name {self.model_name} is not supported. We only support {', '.join(_CLAUDE_3_ENDPOINTS)}."
+            warnings.warn(
+                _MODEL_NOT_SUPPORTED_WARNING.format(
+                    model_name=self.model_name,
+                    known_models=", ".join(_CLAUDE_3_ENDPOINTS),
+                )
             )
 
         options = {
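The change is identical at every call site: replace the hard ValueError with a warning built from the shared template, then continue creating the model. Below is a minimal, standalone sketch of that pattern; the helper name create_text_generator_model and the _KNOWN_ENDPOINTS list are illustrative stand-ins, not the bigframes API.

import warnings

# Message template copied from the diff above.
_MODEL_NOT_SUPPORTED_WARNING = (
    "Model name '{model_name}' is not supported. "
    "We are currently aware of the following models: {known_models}. "
    "However, model names can change, and the supported models may be outdated. "
    "You should use this model name only if you are sure that it is supported in BigQuery."
)

_KNOWN_ENDPOINTS = ("text-bison", "text-bison-32k")  # illustrative endpoint list


def create_text_generator_model(model_name: str) -> dict:
    # Warn (instead of raising) when the name is not in the known list,
    # then proceed so newer, not-yet-listed endpoints remain usable.
    if model_name not in _KNOWN_ENDPOINTS:
        warnings.warn(
            _MODEL_NOT_SUPPORTED_WARNING.format(
                model_name=model_name,
                known_models=", ".join(_KNOWN_ENDPOINTS),
            )
        )
    return {"endpoint": model_name}


if __name__ == "__main__":
    # Emits a UserWarning (the default category for warnings.warn) but still
    # returns usable options for the unknown endpoint.
    print(create_text_generator_model("text-unicorn-preview"))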
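One consequence of warning instead of raising: callers who relied on the old strict behavior can restore it themselves with the standard library warnings filters. Below is a small user-side sketch, not part of the commit; since warnings.warn is called without an explicit category in the diff above, the emitted warning is a UserWarning.

import warnings


def _maybe_warn(model_name: str) -> None:
    # Stand-in for the guard in the diff: warn rather than raise when the
    # endpoint is not recognized.
    warnings.warn(f"Model name '{model_name}' is not supported.")


# Escalate the warning back into an exception (the pre-commit behavior).
with warnings.catch_warnings():
    warnings.simplefilter("error", UserWarning)
    try:
        _maybe_warn("definitely-not-a-model")
    except UserWarning as exc:
        print(f"rejected: {exc}")

# Or silence it for an endpoint you know is valid but not yet listed.
with warnings.catch_warnings():
    warnings.simplefilter("ignore", UserWarning)
    _maybe_warn("brand-new-endpoint")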