-
Notifications
You must be signed in to change notification settings - Fork 43
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: add Gemini-pro-1.5 to GeminiTextGenerator Tuning and Support score() method in Gemini-pro-1.5 #1208
feat: add Gemini-pro-1.5 to GeminiTextGenerator Tuning and Support score() method in Gemini-pro-1.5 #1208
Changes from 11 commits
97d9259
c9318d0
4f9370c
8d5e0ed
e5413a1
7b65227
9d6376b
f315f54
e9f28f4
5d2a807
dc765ec
7a40315
8be16d3
8a39a02
361a734
b213753
9de2c0e
ed001b8
9928f10
ba80d10
241ae73
205e173
6a44e7b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -874,7 +874,7 @@ def fit( | |
X: utils.ArrayType, | ||
y: utils.ArrayType, | ||
) -> GeminiTextGenerator: | ||
"""Fine tune GeminiTextGenerator model. Only support "gemini-pro" model for now. | ||
"""Fine tune GeminiTextGenerator model. Only support "gemini-pro" and "gemini-1.5" models for now. | ||
|
||
.. note:: | ||
|
||
|
@@ -892,8 +892,12 @@ def fit( | |
Returns: | ||
GeminiTextGenerator: Fitted estimator. | ||
""" | ||
if self._bqml_model.model_name.startswith("gemini-1.5"): | ||
raise NotImplementedError("Fit is not supported for gemini-1.5 model.") | ||
# Support gemini-1.5 and gemini-pro | ||
supported_models = ["gemini-pro", "gemini-1.5-pro-002", "gemini-1.5-flash-002"] | ||
if self.model_name not in supported_models: | ||
raise NotImplementedError( | ||
"Score is not supported models other than gemini-pro or gemini-1.5 model." | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. be more explicit to "gemini-1.5-pro-002" and "gemini-1.5-flash-002". Since other gemini 1.5 endpoints aren't supported. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is in fit() in stead of score(). There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I am working on two bugs, b/381936588 and b/344891364, both fit() and score() needed to be updated. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I mean the message is incorrect There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I couldn't leave a comment at unchanged lines. Need to update line 905 to each endpoint respectively. (still use gemini-1.0-pro-002 for gemini-pro, but issuing a warning for that case maybe more appropriate) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
|
||
X, y = utils.batch_convert_to_dataframe(X, y) | ||
|
||
|
@@ -1009,7 +1013,7 @@ def score( | |
"text_generation", "classification", "summarization", "question_answering" | ||
] = "text_generation", | ||
) -> bpd.DataFrame: | ||
"""Calculate evaluation metrics of the model. Only "gemini-pro" model is supported for now. | ||
"""Calculate evaluation metrics of the model. Only "gemini-pro" and "gemini-1.5" models are supported for now. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. same, more explicit to "gemini-1.5-pro-002" and "gemini-1.5-flash-002" There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
|
||
.. note:: | ||
|
||
|
@@ -1041,9 +1045,12 @@ def score( | |
if not self._bqml_model: | ||
raise RuntimeError("A model must be fitted before score") | ||
|
||
# TODO(ashleyxu): Support gemini-1.5 when the rollout is ready. b/344891364. | ||
if self._bqml_model.model_name.startswith("gemini-1.5"): | ||
raise NotImplementedError("Score is not supported for gemini-1.5 model.") | ||
# Support gemini-1.5 and gemini-pro | ||
supported_models = ["gemini-pro", "gemini-1.5-pro-002", "gemini-1.5-flash-002"] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Move and share consts to top. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not seeing the change? I mean add const variables to the top of the file like https://github.com/googleapis/python-bigquery-dataframes/blob/main/bigframes/ml/llm.py#L67 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
if self.model_name not in supported_models: | ||
raise NotImplementedError( | ||
"Score is not supported models other than gemini-pro or gemini-1.5 model." | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same, more explicit. Also the sentence seems awkward. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The message is ill-formed. Maybe just "score() method only supports xxx models". There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
) | ||
|
||
X, y = utils.batch_convert_to_dataframe(X, y, session=self._bqml_model.session) | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -413,9 +413,17 @@ def test_llm_palm_score_params(llm_fine_tune_df_default_index): | |
) | ||
|
||
|
||
@pytest.mark.flaky(retries=2) | ||
def test_llm_gemini_pro_score(llm_fine_tune_df_default_index): | ||
model = llm.GeminiTextGenerator(model_name="gemini-pro") | ||
# test score() function for "gemini-pro" and "gemini-1.5" model | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: test name already indicates it. The comment seems redundant. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
@pytest.mark.parametrize( | ||
"model_name", | ||
( | ||
"gemini-pro", | ||
"gemini-1.5-pro-002", | ||
"gemini-1.5-flash-002", | ||
), | ||
) | ||
def test_llm_gemini_score(llm_fine_tune_df_default_index, model_name): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. add for next test "test_llm_gemini_pro_score_params" as well There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
model = llm.GeminiTextGenerator(model_name=model_name) | ||
|
||
# Check score to ensure the model was fitted | ||
score_result = model.score( | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
move the consts to the top of the file for better organization.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done