diff --git a/tests/unit/vertexai/test_generative_models.py b/tests/unit/vertexai/test_generative_models.py index f638dffc7ef..74cb9fb7c27 100644 --- a/tests/unit/vertexai/test_generative_models.py +++ b/tests/unit/vertexai/test_generative_models.py @@ -379,9 +379,29 @@ def test_chat_function_calling(self, generative_models: generative_models): [generative_models, preview_generative_models], ) def test_conversion_methods(self, generative_models: generative_models): - """Tests the .to_dict, .from_dict and __repr__ methods""" - model = generative_models.GenerativeModel("gemini-pro") - response = model.generate_content("Why is sky blue?") + """Tests the .to_dict, .from_dict and __repr__ methods.""" + # Testing on a full chat conversation which includes function calling + get_current_weather_func = generative_models.FunctionDeclaration( + name="get_current_weather", + description="Get the current weather in a given location", + parameters=_REQUEST_FUNCTION_PARAMETER_SCHEMA_STRUCT, + ) + weather_tool = generative_models.Tool( + function_declarations=[get_current_weather_func], + ) + + model = generative_models.GenerativeModel("gemini-pro", tools=[weather_tool]) + chat = model.start_chat() + response = chat.send_message("What is the weather like in Boston?") + chat.send_message( + generative_models.Part.from_function_response( + name="get_current_weather", + response={ + "location": "Boston", + "weather": "super nice", + }, + ), + ) response_new = generative_models.GenerationResponse.from_dict( response.to_dict() @@ -400,6 +420,12 @@ def test_conversion_methods(self, generative_models: generative_models): part_new = generative_models.Part.from_dict(part.to_dict()) assert repr(part_new) == repr(part) + # Checking the history which contains different Part types + for content in chat.history: + for part in content.parts: + part_new = generative_models.Part.from_dict(part.to_dict()) + assert repr(part_new) == repr(part) + @mock.patch.object( target=prediction_service.PredictionServiceClient, attribute="generate_content", diff --git a/vertexai/generative_models/_generative_models.py b/vertexai/generative_models/_generative_models.py index 215db7c9f07..15daca062a3 100644 --- a/vertexai/generative_models/_generative_models.py +++ b/vertexai/generative_models/_generative_models.py @@ -45,6 +45,7 @@ from vertexai.language_models import ( _language_models as tunable_models, ) +from google.protobuf import json_format import warnings try: @@ -1377,9 +1378,8 @@ def _from_gapic( @classmethod def from_dict(cls, response_dict: Dict[str, Any]) -> "GenerationResponse": - raw_response = gapic_prediction_service_types.GenerateContentResponse( - response_dict - ) + raw_response = gapic_prediction_service_types.GenerateContentResponse() + json_format.ParseDict(response_dict, raw_response._pb) return cls._from_gapic(raw_response=raw_response) def to_dict(self) -> Dict[str, Any]: @@ -1418,7 +1418,8 @@ def _from_gapic(cls, raw_candidate: gapic_content_types.Candidate) -> "Candidate @classmethod def from_dict(cls, candidate_dict: Dict[str, Any]) -> "Candidate": - raw_candidate = gapic_content_types.Candidate(candidate_dict) + raw_candidate = gapic_content_types.Candidate() + json_format.ParseDict(candidate_dict, raw_candidate._pb) return cls._from_gapic(raw_candidate=raw_candidate) def to_dict(self) -> Dict[str, Any]: @@ -1497,7 +1498,8 @@ def _from_gapic(cls, raw_content: gapic_content_types.Content) -> "Content": @classmethod def from_dict(cls, content_dict: Dict[str, Any]) -> "Content": - raw_content = gapic_content_types.Content(content_dict) + raw_content = gapic_content_types.Content() + json_format.ParseDict(content_dict, raw_content._pb) return cls._from_gapic(raw_content=raw_content) def to_dict(self) -> Dict[str, Any]: @@ -1563,7 +1565,8 @@ def _from_gapic(cls, raw_part: gapic_content_types.Part) -> "Part": @classmethod def from_dict(cls, part_dict: Dict[str, Any]) -> "Part": - raw_part = gapic_content_types.Part(part_dict) + raw_part = gapic_content_types.Part() + json_format.ParseDict(part_dict, raw_part._pb) return cls._from_gapic(raw_part=raw_part) def __repr__(self):