Skip to content

Commit

Permalink
chore: fix wordings of Gemini max_retries
Browse files Browse the repository at this point in the history
  • Loading branch information
GarrettWu committed Dec 31, 2024
1 parent 3068e19 commit f003deb
Showing 1 changed file with 9 additions and 6 deletions.
15 changes: 9 additions & 6 deletions bigframes/ml/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -986,9 +986,8 @@ def predict(
The default is `False`.
max_retries (int, default 0):
Max number of retry rounds if any rows failed in the prediction. Each round need to make progress (has succeeded rows) to continue the next retry round.
Each round will append newly succeeded rows. When the max retry rounds is reached, the remaining failed rows will be appended to the end of the result.
Max number of retries if the prediction for any rows failed. Each try needs to make progress (i.e. has successfully predicted rows) to continue the retry.
Each retry will append newly succeeded rows. When the max retries are reached, the remaining rows (the ones without successful predictions) will be appended to the end of the result.
Returns:
bigframes.dataframe.DataFrame: DataFrame of shape (n_samples, n_input_columns + n_prediction_columns). Returns predicted values.
"""
Expand Down Expand Up @@ -1034,11 +1033,15 @@ def predict(
for _ in range(max_retries + 1):
df = self._bqml_model.generate_text(df_fail, options)

df_succ = df[df[_ML_GENERATE_TEXT_STATUS].str.len() == 0]
df_fail = df[df[_ML_GENERATE_TEXT_STATUS].str.len() > 0]
success = df[_ML_GENERATE_TEXT_STATUS].str.len() == 0
df_succ = df[success]
df_fail = df[~success]

if df_succ.empty:
warnings.warn("Can't make any progress, stop retrying.", RuntimeWarning)
if max_retries > 0:
warnings.warn(
"Can't make any progress, stop retrying.", RuntimeWarning
)
break

df_result = (
Expand Down

0 comments on commit f003deb

Please sign in to comment.