Skip to content

Commit

Permalink
Merge branch 'main' into sycai_sem_ops
Browse files Browse the repository at this point in the history
  • Loading branch information
sycai authored Dec 27, 2024
2 parents 315f75c + 8193abe commit 41e2cb3
Show file tree
Hide file tree
Showing 2 changed files with 256 additions and 4 deletions.
39 changes: 36 additions & 3 deletions bigframes/ml/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import bigframes
from bigframes import clients, exceptions
from bigframes.core import blocks, log_adapter
import bigframes.dataframe
from bigframes.ml import base, core, globals, utils
import bigframes.pandas as bpd

Expand Down Expand Up @@ -945,6 +946,7 @@ def predict(
top_k: int = 40,
top_p: float = 1.0,
ground_with_google_search: bool = False,
max_retries: int = 0,
) -> bpd.DataFrame:
"""Predict the result from input DataFrame.
Expand Down Expand Up @@ -983,6 +985,10 @@ def predict(
page for details: https://cloud.google.com/vertex-ai/generative-ai/pricing#google_models
The default is `False`.
max_retries (int, default 0):
Maximum number of retry rounds if any rows fail in the prediction. Each round needs to make progress (i.e., produce newly succeeded rows) for the next retry round to run.
Each round appends its newly succeeded rows to the result. When the maximum number of retry rounds is reached, any remaining failed rows are appended to the end of the result.
Returns:
bigframes.dataframe.DataFrame: DataFrame of shape (n_samples, n_input_columns + n_prediction_columns). Returns predicted values.
"""
Expand All @@ -1002,6 +1008,11 @@ def predict(
if top_p < 0.0 or top_p > 1.0:
raise ValueError(f"top_p must be [0.0, 1.0], but is {top_p}.")

if max_retries < 0:
raise ValueError(
f"max_retries must be larger than or equal to 0, but is {max_retries}."
)

(X,) = utils.batch_convert_to_dataframe(X, session=self._bqml_model.session)

if len(X.columns) == 1:
Expand All @@ -1018,15 +1029,37 @@ def predict(
"ground_with_google_search": ground_with_google_search,
}

df = self._bqml_model.generate_text(X, options)
df_result = bpd.DataFrame(session=self._bqml_model.session) # placeholder
df_fail = X
for _ in range(max_retries + 1):
df = self._bqml_model.generate_text(df_fail, options)

if (df[_ML_GENERATE_TEXT_STATUS] != "").any():
df_succ = df[df[_ML_GENERATE_TEXT_STATUS].str.len() == 0]
df_fail = df[df[_ML_GENERATE_TEXT_STATUS].str.len() > 0]

if df_succ.empty:
warnings.warn("Can't make any progress, stop retrying.", RuntimeWarning)
break

df_result = (
bpd.concat([df_result, df_succ]) if not df_result.empty else df_succ
)

if df_fail.empty:
break

if not df_fail.empty:
warnings.warn(
f"Some predictions failed. Check column {_ML_GENERATE_TEXT_STATUS} for detailed status. You may want to filter the failed rows and retry.",
RuntimeWarning,
)

return df
df_result = cast(
bpd.DataFrame,
bpd.concat([df_result, df_fail]) if not df_result.empty else df_fail,
)

return df_result

def score(
self,
Expand Down
221 changes: 220 additions & 1 deletion tests/system/small/ml/test_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from unittest import mock

import pandas as pd
import pytest

from bigframes import exceptions
from bigframes.ml import llm
from bigframes.ml import core, llm
import bigframes.pandas as bpd
from tests.system import utils

Expand Down Expand Up @@ -372,6 +375,222 @@ def test_gemini_text_generator_multi_cols_predict_success(
)


# Overrides __eq__ function for comparing as mock.call parameter
class EqCmpAllDataFrame(bpd.DataFrame):
    """DataFrame whose ``==`` returns a single bool via ``equals``.

    ``mock.call`` compares recorded call arguments with ``==``; a regular
    DataFrame ``==`` is elementwise and would not yield a usable truth
    value there, so this subclass compares whole-frame contents instead.
    """

    def __eq__(self, other):
        # Whole-frame value comparison rather than elementwise ==.
        return self.equals(other)


def test_gemini_text_generator_retry_success(session, bq_connection):
    """Failed rows are retried until all succeed, within ``max_retries``.

    Two retry rounds fix every row, so the allowed third round must not be
    issued; each round resubmits only the still-failing rows.
    """
    # Requests. df0 is the initial input; df1 and df2 are the failed
    # subsets that predict() is expected to resubmit on retry rounds 1 and 2
    # (matched by value thanks to EqCmpAllDataFrame.__eq__).
    df0 = EqCmpAllDataFrame(
        {
            "prompt": [
                "What is BigQuery?",
                "What is BQML?",
                "What is BigQuery DataFrame?",
            ]
        },
        index=[0, 1, 2],
        session=session,
    )
    df1 = EqCmpAllDataFrame(
        {
            "ml_generate_text_status": ["error", "error"],
            "prompt": [
                "What is BQML?",
                "What is BigQuery DataFrame?",
            ],
        },
        index=[1, 2],
        session=session,
    )
    df2 = EqCmpAllDataFrame(
        {
            "ml_generate_text_status": ["error"],
            "prompt": [
                "What is BQML?",
            ],
        },
        index=[1],
        session=session,
    )

    mock_bqml_model = mock.create_autospec(spec=core.BqmlModel)
    type(mock_bqml_model).session = mock.PropertyMock(return_value=session)

    # Responses. Retry twice then all succeeded.
    # An empty ml_generate_text_status marks a succeeded row; "error" marks
    # a failed one that should be retried.
    mock_bqml_model.generate_text.side_effect = [
        EqCmpAllDataFrame(
            {
                "ml_generate_text_status": ["", "error", "error"],
                "prompt": [
                    "What is BigQuery?",
                    "What is BQML?",
                    "What is BigQuery DataFrame?",
                ],
            },
            index=[0, 1, 2],
            session=session,
        ),
        EqCmpAllDataFrame(
            {
                "ml_generate_text_status": ["error", ""],
                "prompt": [
                    "What is BQML?",
                    "What is BigQuery DataFrame?",
                ],
            },
            index=[1, 2],
            session=session,
        ),
        EqCmpAllDataFrame(
            {
                "ml_generate_text_status": [""],
                "prompt": [
                    "What is BQML?",
                ],
            },
            index=[1],
            session=session,
        ),
    ]
    # Options predict() is expected to forward to generate_text on every call.
    options = {
        "temperature": 0.9,
        "max_output_tokens": 8192,
        "top_k": 40,
        "top_p": 1.0,
        "flatten_json_output": True,
        "ground_with_google_search": False,
    }

    gemini_text_generator_model = llm.GeminiTextGenerator(
        connection_name=bq_connection, session=session
    )
    # Swap in the mocked BQML model so no real BigQuery calls are made.
    gemini_text_generator_model._bqml_model = mock_bqml_model

    # 3rd retry isn't triggered
    result = gemini_text_generator_model.predict(df0, max_retries=3)

    # Exactly three calls: the initial attempt plus two shrinking retries.
    mock_bqml_model.generate_text.assert_has_calls(
        [
            mock.call(df0, options),
            mock.call(df1, options),
            mock.call(df2, options),
        ]
    )
    # Index order is [0, 2, 1]: succeeded rows are appended in the round
    # they first succeed (row 0 in round 0, row 2 in round 1, row 1 in
    # round 2), not in the original input order.
    pd.testing.assert_frame_equal(
        result.to_pandas(),
        pd.DataFrame(
            {
                "ml_generate_text_status": ["", "", ""],
                "prompt": [
                    "What is BigQuery?",
                    "What is BigQuery DataFrame?",
                    "What is BQML?",
                ],
            },
            index=[0, 2, 1],
        ),
        check_dtype=False,
        check_index_type=False,
    )


def test_gemini_text_generator_retry_no_progress(session, bq_connection):
    """Retrying stops early when a round makes no progress.

    The first retry round returns the same failures as before, so predict()
    must stop after that round (even though max_retries=3 would allow more)
    and append the still-failed rows to the result.
    """
    # Requests. df0 is the initial input; df1 is the failed subset expected
    # to be resubmitted on the single retry round.
    df0 = EqCmpAllDataFrame(
        {
            "prompt": [
                "What is BigQuery?",
                "What is BQML?",
                "What is BigQuery DataFrame?",
            ]
        },
        index=[0, 1, 2],
        session=session,
    )
    df1 = EqCmpAllDataFrame(
        {
            "ml_generate_text_status": ["error", "error"],
            "prompt": [
                "What is BQML?",
                "What is BigQuery DataFrame?",
            ],
        },
        index=[1, 2],
        session=session,
    )

    mock_bqml_model = mock.create_autospec(spec=core.BqmlModel)
    type(mock_bqml_model).session = mock.PropertyMock(return_value=session)
    # Responses. Retry once, no progress, just stop.
    # The retry round reproduces the same two failures, i.e. zero newly
    # succeeded rows.
    mock_bqml_model.generate_text.side_effect = [
        EqCmpAllDataFrame(
            {
                "ml_generate_text_status": ["", "error", "error"],
                "prompt": [
                    "What is BigQuery?",
                    "What is BQML?",
                    "What is BigQuery DataFrame?",
                ],
            },
            index=[0, 1, 2],
            session=session,
        ),
        EqCmpAllDataFrame(
            {
                "ml_generate_text_status": ["error", "error"],
                "prompt": [
                    "What is BQML?",
                    "What is BigQuery DataFrame?",
                ],
            },
            index=[1, 2],
            session=session,
        ),
    ]
    # Options predict() is expected to forward to generate_text on every call.
    options = {
        "temperature": 0.9,
        "max_output_tokens": 8192,
        "top_k": 40,
        "top_p": 1.0,
        "flatten_json_output": True,
        "ground_with_google_search": False,
    }

    gemini_text_generator_model = llm.GeminiTextGenerator(
        connection_name=bq_connection, session=session
    )
    # Swap in the mocked BQML model so no real BigQuery calls are made.
    gemini_text_generator_model._bqml_model = mock_bqml_model

    # No progress, only conduct retry once
    result = gemini_text_generator_model.predict(df0, max_retries=3)

    # Exactly two calls: the initial attempt and the single no-progress retry.
    mock_bqml_model.generate_text.assert_has_calls(
        [
            mock.call(df0, options),
            mock.call(df1, options),
        ]
    )
    # Failed rows are appended after the succeeded ones with their error
    # status preserved for the caller to inspect.
    pd.testing.assert_frame_equal(
        result.to_pandas(),
        pd.DataFrame(
            {
                "ml_generate_text_status": ["", "error", "error"],
                "prompt": [
                    "What is BigQuery?",
                    "What is BQML?",
                    "What is BigQuery DataFrame?",
                ],
            },
            index=[0, 1, 2],
        ),
        check_dtype=False,
        check_index_type=False,
    )


@pytest.mark.flaky(retries=2)
def test_llm_palm_score(llm_fine_tune_df_default_index):
model = llm.PaLM2TextGenerator(model_name="text-bison")
Expand Down

0 comments on commit 41e2cb3

Please sign in to comment.