Skip to content

Commit

Permalink
Merge branch 'main' into fields_iter
Browse files Browse the repository at this point in the history
  • Loading branch information
TrevorBergeron authored Oct 8, 2024
2 parents 5de9352 + d1b87e2 commit abea505
Show file tree
Hide file tree
Showing 12 changed files with 1,135 additions and 145 deletions.
2 changes: 1 addition & 1 deletion bigframes/_config/bigquery_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def __init__(
):
self._credentials = credentials
self._project = project
self._location = location
self._location = _get_validated_location(location)
self._bq_connection = bq_connection
self._use_regional_endpoints = use_regional_endpoints
self._application_name = application_name
Expand Down
4 changes: 3 additions & 1 deletion bigframes/core/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,6 +510,7 @@ class GbqTable:
table_id: str = field()
physical_schema: Tuple[bq.SchemaField, ...] = field()
n_rows: int = field()
is_physical_table: bool = field()
cluster_cols: typing.Optional[Tuple[str, ...]]

@staticmethod
Expand All @@ -525,6 +526,7 @@ def from_table(table: bq.Table, columns: Sequence[str] = ()) -> GbqTable:
table_id=table.table_id,
physical_schema=schema,
n_rows=table.num_rows,
is_physical_table=(table.table_type == "TABLE"),
cluster_cols=None
if table.clustering_fields is None
else tuple(table.clustering_fields),
Expand Down Expand Up @@ -605,7 +607,7 @@ def variables_introduced(self) -> int:

@property
def row_count(self) -> typing.Optional[int]:
if self.source.sql_predicate is None:
if self.source.sql_predicate is None and self.source.table.is_physical_table:
return self.source.table.n_rows
return None

Expand Down
42 changes: 32 additions & 10 deletions bigframes/ml/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,13 @@
_ML_EMBED_TEXT_STATUS = "ml_embed_text_status"
_ML_GENERATE_EMBEDDING_STATUS = "ml_generate_embedding_status"

# Message template passed to warnings.warn() by the _create_bqml_model methods
# when self.model_name is not in that model family's known-endpoints collection.
# Fill it with str.format(model_name=..., known_models=...), where known_models
# is the comma-joined list of endpoint names this module knows about.
_MODEL_NOT_SUPPORTED_WARNING = (
"Model name '{model_name}' is not supported. "
"We are currently aware of the following models: {known_models}. "
"However, model names can change, and the supported models may be outdated. "
"You should use this model name only if you are sure that it is supported in BigQuery."
)


@typing_extensions.deprecated(
"PaLM2TextGenerator is going to be deprecated. Use GeminiTextGenerator(https://cloud.google.com/python/docs/reference/bigframes/latest/bigframes.ml.llm.GeminiTextGenerator) instead. ",
Expand Down Expand Up @@ -154,8 +161,11 @@ def _create_bqml_model(self):
)

if self.model_name not in _TEXT_GENERATOR_ENDPOINTS:
raise ValueError(
f"Model name {self.model_name} is not supported. We only support {', '.join(_TEXT_GENERATOR_ENDPOINTS)}."
warnings.warn(
_MODEL_NOT_SUPPORTED_WARNING.format(
model_name=self.model_name,
known_models=", ".join(_TEXT_GENERATOR_ENDPOINTS),
)
)

options = {
Expand Down Expand Up @@ -484,8 +494,11 @@ def _create_bqml_model(self):
)

if self.model_name not in _PALM2_EMBEDDING_GENERATOR_ENDPOINTS:
raise ValueError(
f"Model name {self.model_name} is not supported. We only support {', '.join(_PALM2_EMBEDDING_GENERATOR_ENDPOINTS)}."
warnings.warn(
_MODEL_NOT_SUPPORTED_WARNING.format(
model_name=self.model_name,
known_models=", ".join(_PALM2_EMBEDDING_GENERATOR_ENDPOINTS),
)
)

endpoint = (
Expand Down Expand Up @@ -644,8 +657,11 @@ def _create_bqml_model(self):
)

if self.model_name not in _TEXT_EMBEDDING_ENDPOINTS:
raise ValueError(
f"Model name {self.model_name} is not supported. We only support {', '.join(_TEXT_EMBEDDING_ENDPOINTS)}."
warnings.warn(
_MODEL_NOT_SUPPORTED_WARNING.format(
model_name=self.model_name,
known_models=", ".join(_TEXT_EMBEDDING_ENDPOINTS),
)
)

options = {
Expand Down Expand Up @@ -801,8 +817,11 @@ def _create_bqml_model(self):
)

if self.model_name not in _GEMINI_ENDPOINTS:
raise ValueError(
f"Model name {self.model_name} is not supported. We only support {', '.join(_GEMINI_ENDPOINTS)}."
warnings.warn(
_MODEL_NOT_SUPPORTED_WARNING.format(
model_name=self.model_name,
known_models=", ".join(_GEMINI_ENDPOINTS),
)
)

options = {"endpoint": self.model_name}
Expand Down Expand Up @@ -1118,8 +1137,11 @@ def _create_bqml_model(self):
)

if self.model_name not in _CLAUDE_3_ENDPOINTS:
raise ValueError(
f"Model name {self.model_name} is not supported. We only support {', '.join(_CLAUDE_3_ENDPOINTS)}."
warnings.warn(
_MODEL_NOT_SUPPORTED_WARNING.format(
model_name=self.model_name,
known_models=", ".join(_CLAUDE_3_ENDPOINTS),
)
)

options = {
Expand Down
Loading

0 comments on commit abea505

Please sign in to comment.