From 8193abe395c5648db8169818eca29aee76c46478 Mon Sep 17 00:00:00 2001
From: Garrett Wu <6505921+GarrettWu@users.noreply.github.com>
Date: Fri, 27 Dec 2024 14:43:20 -0800
Subject: [PATCH] feat: add client side retry to GeminiTextGenerator (#1242)

* feat: add client side retry to GeminiTextGenerator

* test

* test

* test

* test

* fix

* max_retries

* fix

* fix
---
 bigframes/ml/llm.py               |  39 +++++-
 tests/system/small/ml/test_llm.py | 221 +++++++++++++++++++++++++++++-
 2 files changed, 256 insertions(+), 4 deletions(-)

diff --git a/bigframes/ml/llm.py b/bigframes/ml/llm.py
index 9b7228fe83..2427009cf1 100644
--- a/bigframes/ml/llm.py
+++ b/bigframes/ml/llm.py
@@ -26,6 +26,7 @@
 import bigframes
 from bigframes import clients, exceptions
 from bigframes.core import blocks, log_adapter
+import bigframes.dataframe
 from bigframes.ml import base, core, globals, utils
 import bigframes.pandas as bpd
 
@@ -945,6 +946,7 @@ def predict(
         top_k: int = 40,
         top_p: float = 1.0,
         ground_with_google_search: bool = False,
+        max_retries: int = 0,
     ) -> bpd.DataFrame:
         """Predict the result from input DataFrame.
 
@@ -983,6 +985,10 @@
                 page for details:
                 https://cloud.google.com/vertex-ai/generative-ai/pricing#google_models
                 The default is `False`.
+            max_retries (int, default 0):
+                Max number of retry rounds to run if any rows fail in the prediction. A retry round runs only if the previous round made progress (i.e. some rows succeeded).
+                Each round appends the newly succeeded rows to the result. When the retry limit is reached, the remaining failed rows are appended to the end of the result.
+
         Returns:
             bigframes.dataframe.DataFrame: DataFrame of shape (n_samples, n_input_columns + n_prediction_columns). Returns predicted values.
         """
@@ -1002,6 +1008,11 @@
         if top_p < 0.0 or top_p > 1.0:
             raise ValueError(f"top_p must be [0.0, 1.0], but is {top_p}.")
 
+        if max_retries < 0:
+            raise ValueError(
+                f"max_retries must be larger than or equal to 0, but is {max_retries}."
+            )
+
         (X,) = utils.batch_convert_to_dataframe(X, session=self._bqml_model.session)
 
         if len(X.columns) == 1:
@@ -1018,15 +1029,37 @@
             "ground_with_google_search": ground_with_google_search,
         }
 
-        df = self._bqml_model.generate_text(X, options)
+        df_result = bpd.DataFrame(session=self._bqml_model.session)  # placeholder
+        df_fail = X
+        for _ in range(max_retries + 1):
+            df = self._bqml_model.generate_text(df_fail, options)
 
-        if (df[_ML_GENERATE_TEXT_STATUS] != "").any():
+            df_succ = df[df[_ML_GENERATE_TEXT_STATUS].str.len() == 0]
+            df_fail = df[df[_ML_GENERATE_TEXT_STATUS].str.len() > 0]
+
+            if df_succ.empty:
+                warnings.warn("Can't make any progress, stop retrying.", RuntimeWarning)
+                break
+
+            df_result = (
+                bpd.concat([df_result, df_succ]) if not df_result.empty else df_succ
+            )
+
+            if df_fail.empty:
+                break
+
+        if not df_fail.empty:
             warnings.warn(
                 f"Some predictions failed. Check column {_ML_GENERATE_TEXT_STATUS} for detailed status. You may want to filter the failed rows and retry.",
                 RuntimeWarning,
             )
 
-        return df
+            df_result = cast(
+                bpd.DataFrame,
+                bpd.concat([df_result, df_fail]) if not df_result.empty else df_fail,
+            )
+
+        return df_result
 
     def score(
         self,
diff --git a/tests/system/small/ml/test_llm.py b/tests/system/small/ml/test_llm.py
index dca3c35d84..304204cc7b 100644
--- a/tests/system/small/ml/test_llm.py
+++ b/tests/system/small/ml/test_llm.py
@@ -12,10 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from unittest import mock
+
+import pandas as pd
 import pytest
 
 from bigframes import exceptions
-from bigframes.ml import llm
+from bigframes.ml import core, llm
 import bigframes.pandas as bpd
 from tests.system import utils
 
@@ -372,6 +375,222 @@ def test_gemini_text_generator_multi_cols_predict_success(
     )
 
 
+# Overrides __eq__ function for comparing as mock.call parameter
+class EqCmpAllDataFrame(bpd.DataFrame):
+    def __eq__(self, other):
+        return self.equals(other)
+
+
+def test_gemini_text_generator_retry_success(session, bq_connection):
+    # Requests.
+    df0 = EqCmpAllDataFrame(
+        {
+            "prompt": [
+                "What is BigQuery?",
+                "What is BQML?",
+                "What is BigQuery DataFrame?",
+            ]
+        },
+        index=[0, 1, 2],
+        session=session,
+    )
+    df1 = EqCmpAllDataFrame(
+        {
+            "ml_generate_text_status": ["error", "error"],
+            "prompt": [
+                "What is BQML?",
+                "What is BigQuery DataFrame?",
+            ],
+        },
+        index=[1, 2],
+        session=session,
+    )
+    df2 = EqCmpAllDataFrame(
+        {
+            "ml_generate_text_status": ["error"],
+            "prompt": [
+                "What is BQML?",
+            ],
+        },
+        index=[1],
+        session=session,
+    )
+
+    mock_bqml_model = mock.create_autospec(spec=core.BqmlModel)
+    type(mock_bqml_model).session = mock.PropertyMock(return_value=session)
+
+    # Responses. Retry twice then all succeeded.
+    mock_bqml_model.generate_text.side_effect = [
+        EqCmpAllDataFrame(
+            {
+                "ml_generate_text_status": ["", "error", "error"],
+                "prompt": [
+                    "What is BigQuery?",
+                    "What is BQML?",
+                    "What is BigQuery DataFrame?",
+                ],
+            },
+            index=[0, 1, 2],
+            session=session,
+        ),
+        EqCmpAllDataFrame(
+            {
+                "ml_generate_text_status": ["error", ""],
+                "prompt": [
+                    "What is BQML?",
+                    "What is BigQuery DataFrame?",
+                ],
+            },
+            index=[1, 2],
+            session=session,
+        ),
+        EqCmpAllDataFrame(
+            {
+                "ml_generate_text_status": [""],
+                "prompt": [
+                    "What is BQML?",
+                ],
+            },
+            index=[1],
+            session=session,
+        ),
+    ]
+    options = {
+        "temperature": 0.9,
+        "max_output_tokens": 8192,
+        "top_k": 40,
+        "top_p": 1.0,
+        "flatten_json_output": True,
+        "ground_with_google_search": False,
+    }
+
+    gemini_text_generator_model = llm.GeminiTextGenerator(
+        connection_name=bq_connection, session=session
+    )
+    gemini_text_generator_model._bqml_model = mock_bqml_model
+
+    # 3rd retry isn't triggered
+    result = gemini_text_generator_model.predict(df0, max_retries=3)
+
+    mock_bqml_model.generate_text.assert_has_calls(
+        [
+            mock.call(df0, options),
+            mock.call(df1, options),
+            mock.call(df2, options),
+        ]
+    )
+    pd.testing.assert_frame_equal(
+        result.to_pandas(),
+        pd.DataFrame(
+            {
+                "ml_generate_text_status": ["", "", ""],
+                "prompt": [
+                    "What is BigQuery?",
+                    "What is BigQuery DataFrame?",
+                    "What is BQML?",
+                ],
+            },
+            index=[0, 2, 1],
+        ),
+        check_dtype=False,
+        check_index_type=False,
+    )
+
+
+def test_gemini_text_generator_retry_no_progress(session, bq_connection):
+    # Requests.
+    df0 = EqCmpAllDataFrame(
+        {
+            "prompt": [
+                "What is BigQuery?",
+                "What is BQML?",
+                "What is BigQuery DataFrame?",
+            ]
+        },
+        index=[0, 1, 2],
+        session=session,
+    )
+    df1 = EqCmpAllDataFrame(
+        {
+            "ml_generate_text_status": ["error", "error"],
+            "prompt": [
+                "What is BQML?",
+                "What is BigQuery DataFrame?",
+            ],
+        },
+        index=[1, 2],
+        session=session,
+    )
+
+    mock_bqml_model = mock.create_autospec(spec=core.BqmlModel)
+    type(mock_bqml_model).session = mock.PropertyMock(return_value=session)
+    # Responses. Retry once, no progress, just stop.
+    mock_bqml_model.generate_text.side_effect = [
+        EqCmpAllDataFrame(
+            {
+                "ml_generate_text_status": ["", "error", "error"],
+                "prompt": [
+                    "What is BigQuery?",
+                    "What is BQML?",
+                    "What is BigQuery DataFrame?",
+                ],
+            },
+            index=[0, 1, 2],
+            session=session,
+        ),
+        EqCmpAllDataFrame(
+            {
+                "ml_generate_text_status": ["error", "error"],
+                "prompt": [
+                    "What is BQML?",
+                    "What is BigQuery DataFrame?",
+                ],
+            },
+            index=[1, 2],
+            session=session,
+        ),
+    ]
+    options = {
+        "temperature": 0.9,
+        "max_output_tokens": 8192,
+        "top_k": 40,
+        "top_p": 1.0,
+        "flatten_json_output": True,
+        "ground_with_google_search": False,
+    }
+
+    gemini_text_generator_model = llm.GeminiTextGenerator(
+        connection_name=bq_connection, session=session
+    )
+    gemini_text_generator_model._bqml_model = mock_bqml_model
+
+    # No progress, only conduct retry once
+    result = gemini_text_generator_model.predict(df0, max_retries=3)
+
+    mock_bqml_model.generate_text.assert_has_calls(
+        [
+            mock.call(df0, options),
+            mock.call(df1, options),
+        ]
+    )
+    pd.testing.assert_frame_equal(
+        result.to_pandas(),
+        pd.DataFrame(
+            {
+                "ml_generate_text_status": ["", "error", "error"],
+                "prompt": [
+                    "What is BigQuery?",
+                    "What is BQML?",
+                    "What is BigQuery DataFrame?",
+                ],
+            },
+            index=[0, 1, 2],
+        ),
+        check_dtype=False,
+        check_index_type=False,
+    )
+
+
 @pytest.mark.flaky(retries=2)
 def test_llm_palm_score(llm_fine_tune_df_default_index):
     model = llm.PaLM2TextGenerator(model_name="text-bison")
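
For reference, a minimal sketch of how the new max_retries option is exercised from user code; not part of the patch. The prompts are hypothetical, and it assumes a session with a BigQuery connection already configured (as the GeminiTextGenerator constructor defaults expect):

    import bigframes.pandas as bpd
    from bigframes.ml import llm

    # Hypothetical prompts; any single-column DataFrame of prompts works.
    df = bpd.DataFrame(
        {
            "prompt": [
                "What is BigQuery?",
                "What is BQML?",
                "What is BigQuery DataFrame?",
            ]
        }
    )

    # Assumes a default BigQuery connection is configured for the session.
    model = llm.GeminiTextGenerator()

    # Rows whose ml_generate_text_status is non-empty are re-sent for up to
    # three more rounds; rows that still fail after that are appended to the
    # end of the result so callers can inspect their status column.
    result = model.predict(df, max_retries=3)

Because each round re-sends only the failed rows to ML.GENERATE_TEXT, the retry cost scales with the number of failures rather than with the full input.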