diff --git a/tests/unit/vertexai/test_generative_models.py b/tests/unit/vertexai/test_generative_models.py index fac9055968..191d4e4640 100644 --- a/tests/unit/vertexai/test_generative_models.py +++ b/tests/unit/vertexai/test_generative_models.py @@ -698,6 +698,34 @@ def test_generate_content(self, generative_models: generative_models): ) assert response3.text + model4 = generative_models.GenerativeModel("gemini-1.5-pro-preview-0409") + response4 = model4.generate_content( + "Why is sky blue? Respond in JSON.", + generation_config=generative_models.GenerationConfig( + temperature=0.2, + top_p=0.9, + top_k=20, + candidate_count=1, + max_output_tokens=200, + stop_sequences=["\n\n\n"], + response_mime_type="application/json", + ), + safety_settings=[ + generative_models.SafetySetting( + category=generative_models.SafetySetting.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, + threshold=generative_models.SafetySetting.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE, + method=generative_models.SafetySetting.HarmBlockMethod.SEVERITY, + ), + generative_models.SafetySetting( + category=generative_models.SafetySetting.HarmCategory.HARM_CATEGORY_HATE_SPEECH, + threshold=generative_models.SafetySetting.HarmBlockThreshold.BLOCK_ONLY_HIGH, + method=generative_models.SafetySetting.HarmBlockMethod.PROBABILITY, + ), + ], + labels={"label1": "value1", "label2": "value2"}, + ) + assert response4.text + @mock.patch.object( target=prediction_service.PredictionServiceClient, attribute="generate_content", diff --git a/vertexai/generative_models/_generative_models.py b/vertexai/generative_models/_generative_models.py index cf68c01a9b..805123bc1d 100644 --- a/vertexai/generative_models/_generative_models.py +++ b/vertexai/generative_models/_generative_models.py @@ -160,6 +160,7 @@ def _validate_generate_content_parameters( tool_config: Optional["ToolConfig"] = None, system_instruction: Optional[PartsType] = None, cached_content: Optional["caching.CachedContent"] = None, + labels: Optional[Dict[str, str]] = None, ) -> None: """Validates the parameters for a generate_content call.""" if not contents: @@ -190,6 +191,10 @@ def _validate_generate_content_parameters( if tool_config: _validate_tool_config_type(tool_config) + if labels: + if not isinstance(labels, Dict): + raise TypeError("labels must be a dictionary.") + def _validate_contents_type_as_valid_sequence(contents: ContentsType) -> None: """Makes sure that individual elements of contents are of valid type.""" @@ -323,6 +328,7 @@ def __init__( tools: Optional[List["Tool"]] = None, tool_config: Optional["ToolConfig"] = None, system_instruction: Optional[PartsType] = None, + labels: Optional[Dict[str, str]] = None, ): r"""Initializes GenerativeModel. @@ -342,6 +348,7 @@ def __init__( system_instruction: Default system instruction to use in generate_content. Note: Only text should be used in parts. Content of each part will become a separate paragraph. + labels: labels that will be passed to billing for cost tracking. """ project = aiplatform_initializer.global_config.project location = aiplatform_initializer.global_config.location @@ -364,6 +371,7 @@ def __init__( self._tool_config = tool_config self._system_instruction = system_instruction self._cached_content: Optional["caching.CachedContent"] = None + self._labels = labels # Validating the parameters _validate_generate_content_parameters( @@ -373,6 +381,7 @@ def __init__( tools=tools, tool_config=tool_config, system_instruction=system_instruction, + labels=labels, ) @property @@ -440,6 +449,7 @@ def _prepare_request( tools: Optional[List["Tool"]] = None, tool_config: Optional["ToolConfig"] = None, system_instruction: Optional[PartsType] = None, + labels: Optional[Dict[str, str]] = None, ) -> gapic_prediction_service_types.GenerateContentRequest: """Prepares a GAPIC GenerateContentRequest.""" if not contents: @@ -451,6 +461,7 @@ def _prepare_request( tool_config = tool_config or self._tool_config system_instruction = system_instruction or self._system_instruction cached_content = self._cached_content + labels = labels or self._labels _validate_generate_content_parameters( contents=contents, @@ -460,6 +471,7 @@ def _prepare_request( tool_config=tool_config, system_instruction=system_instruction, cached_content=cached_content, + labels=labels, ) contents = _content_types_to_gapic_contents(contents) @@ -519,6 +531,7 @@ def _prepare_request( tool_config=gapic_tool_config, system_instruction=gapic_system_instruction, cached_content=cached_content.resource_name if cached_content else None, + labels=labels, ) def _parse_response( @@ -536,6 +549,7 @@ def generate_content( safety_settings: Optional[SafetySettingsType] = None, tools: Optional[List["Tool"]] = None, tool_config: Optional["ToolConfig"] = None, + labels: Optional[Dict[str, str]] = None, stream: Literal[False] = False, ) -> "GenerationResponse": ... @@ -549,6 +563,7 @@ def generate_content( safety_settings: Optional[SafetySettingsType] = None, tools: Optional[List["Tool"]] = None, tool_config: Optional["ToolConfig"] = None, + labels: Optional[Dict[str, str]] = None, stream: Literal[True], ) -> Iterable["GenerationResponse"]: ... @@ -561,6 +576,7 @@ def generate_content( safety_settings: Optional[SafetySettingsType] = None, tools: Optional[List["Tool"]] = None, tool_config: Optional["ToolConfig"] = None, + labels: Optional[Dict[str, str]] = None, stream: bool = False, ) -> Union["GenerationResponse", Iterable["GenerationResponse"]]: """Generates content. @@ -577,6 +593,7 @@ def generate_content( safety_settings: Safety settings as a mapping from HarmCategory to HarmBlockThreshold. tools: A list of tools (functions) that the model can try calling. tool_config: Config shared for all tools provided in the request. + labels: Labels that will be passed to billing for cost tracking. stream: Whether to stream the response. Returns: @@ -591,6 +608,7 @@ def generate_content( safety_settings=safety_settings, tools=tools, tool_config=tool_config, + labels=labels, ) else: return self._generate_content( @@ -599,6 +617,7 @@ def generate_content( safety_settings=safety_settings, tools=tools, tool_config=tool_config, + labels=labels, ) @overload @@ -610,6 +629,7 @@ async def generate_content_async( safety_settings: Optional[SafetySettingsType] = None, tools: Optional[List["Tool"]] = None, tool_config: Optional["ToolConfig"] = None, + labels: Optional[Dict[str, str]] = None, stream: Literal[False] = False, ) -> "GenerationResponse": ... @@ -623,6 +643,7 @@ async def generate_content_async( safety_settings: Optional[SafetySettingsType] = None, tools: Optional[List["Tool"]] = None, tool_config: Optional["ToolConfig"] = None, + labels: Optional[Dict[str, str]] = None, stream: Literal[True] = True, ) -> AsyncIterable["GenerationResponse"]: ... @@ -635,6 +656,7 @@ async def generate_content_async( safety_settings: Optional[SafetySettingsType] = None, tools: Optional[List["Tool"]] = None, tool_config: Optional["ToolConfig"] = None, + labels: Optional[Dict[str, str]] = None, stream: bool = False, ) -> Union["GenerationResponse", AsyncIterable["GenerationResponse"]]: """Generates content asynchronously. @@ -651,6 +673,7 @@ async def generate_content_async( safety_settings: Safety settings as a mapping from HarmCategory to HarmBlockThreshold. tools: A list of tools (functions) that the model can try calling. tool_config: Config shared for all tools provided in the request. + labels: Labels that will be passed to billing for cost tracking. stream: Whether to stream the response. Returns: @@ -664,6 +687,7 @@ async def generate_content_async( safety_settings=safety_settings, tools=tools, tool_config=tool_config, + labels=labels, ) else: return await self._generate_content_async( @@ -672,6 +696,7 @@ async def generate_content_async( safety_settings=safety_settings, tools=tools, tool_config=tool_config, + labels=labels, ) def _generate_content( @@ -682,6 +707,7 @@ def _generate_content( safety_settings: Optional[SafetySettingsType] = None, tools: Optional[List["Tool"]] = None, tool_config: Optional["ToolConfig"] = None, + labels: Optional[Dict[str, str]] = None, ) -> "GenerationResponse": """Generates content. @@ -697,6 +723,7 @@ def _generate_content( safety_settings: Safety settings as a mapping from HarmCategory to HarmBlockThreshold. tools: A list of tools (functions) that the model can try calling. tool_config: Config shared for all tools provided in the request. + labels: Labels that will be passed to billing for cost tracking. Returns: A single GenerationResponse object @@ -707,6 +734,7 @@ def _generate_content( safety_settings=safety_settings, tools=tools, tool_config=tool_config, + labels=labels, ) gapic_response = self._prediction_client.generate_content(request=request) return self._parse_response(gapic_response) @@ -719,6 +747,7 @@ async def _generate_content_async( safety_settings: Optional[SafetySettingsType] = None, tools: Optional[List["Tool"]] = None, tool_config: Optional["ToolConfig"] = None, + labels: Optional[Dict[str, str]] = None, ) -> "GenerationResponse": """Generates content asynchronously. @@ -734,6 +763,7 @@ async def _generate_content_async( safety_settings: Safety settings as a mapping from HarmCategory to HarmBlockThreshold. tools: A list of tools (functions) that the model can try calling. tool_config: Config shared for all tools provided in the request. + labels: Labels that will be passed to billing for cost tracking. Returns: An awaitable for a single GenerationResponse object @@ -744,6 +774,7 @@ async def _generate_content_async( safety_settings=safety_settings, tools=tools, tool_config=tool_config, + labels=labels, ) gapic_response = await self._prediction_async_client.generate_content( request=request @@ -758,6 +789,7 @@ def _generate_content_streaming( safety_settings: Optional[SafetySettingsType] = None, tools: Optional[List["Tool"]] = None, tool_config: Optional["ToolConfig"] = None, + labels: Optional[Dict[str, str]] = None, ) -> Iterable["GenerationResponse"]: """Generates content. @@ -773,6 +805,7 @@ def _generate_content_streaming( safety_settings: Safety settings as a mapping from HarmCategory to HarmBlockThreshold. tools: A list of tools (functions) that the model can try calling. tool_config: Config shared for all tools provided in the request. + labels: Labels that will be passed to billing for cost tracking. Yields: A stream of GenerationResponse objects @@ -783,6 +816,7 @@ def _generate_content_streaming( safety_settings=safety_settings, tools=tools, tool_config=tool_config, + labels=labels, ) response_stream = self._prediction_client.stream_generate_content( request=request @@ -798,6 +832,7 @@ async def _generate_content_streaming_async( safety_settings: Optional[SafetySettingsType] = None, tools: Optional[List["Tool"]] = None, tool_config: Optional["ToolConfig"] = None, + labels: Optional[Dict[str, str]] = None, ) -> AsyncIterable["GenerationResponse"]: """Generates content asynchronously. @@ -813,6 +848,7 @@ async def _generate_content_streaming_async( safety_settings: Safety settings as a mapping from HarmCategory to HarmBlockThreshold. tools: A list of tools (functions) that the model can try calling. tool_config: Config shared for all tools provided in the request. + labels: Labels that will be passed to billing for cost tracking. Returns: An awaitable for a stream of GenerationResponse objects @@ -823,6 +859,7 @@ async def _generate_content_streaming_async( safety_settings=safety_settings, tools=tools, tool_config=tool_config, + labels=labels, ) response_stream = await self._prediction_async_client.stream_generate_content( request=request @@ -1115,6 +1152,7 @@ def send_message( generation_config: Optional[GenerationConfigType] = None, safety_settings: Optional[SafetySettingsType] = None, tools: Optional[List["Tool"]] = None, + labels: Optional[Dict[str, str]] = None, stream: Literal[False] = False, ) -> "GenerationResponse": ... @@ -1127,6 +1165,7 @@ def send_message( generation_config: Optional[GenerationConfigType] = None, safety_settings: Optional[SafetySettingsType] = None, tools: Optional[List["Tool"]] = None, + labels: Optional[Dict[str, str]] = None, stream: Literal[True] = True, ) -> Iterable["GenerationResponse"]: ... @@ -1138,6 +1177,7 @@ def send_message( generation_config: Optional[GenerationConfigType] = None, safety_settings: Optional[SafetySettingsType] = None, tools: Optional[List["Tool"]] = None, + labels: Optional[Dict[str, str]] = None, stream: bool = False, ) -> Union["GenerationResponse", Iterable["GenerationResponse"]]: """Generates content. @@ -1151,6 +1191,7 @@ def send_message( generation_config: Parameters for the generation. safety_settings: Safety settings as a mapping from HarmCategory to HarmBlockThreshold. tools: A list of tools (functions) that the model can try calling. + labels: Labels that will be passed to billing for cost tracking. stream: Whether to stream the response. Returns: @@ -1166,6 +1207,7 @@ def send_message( generation_config=generation_config, safety_settings=safety_settings, tools=tools, + labels=labels, ) else: return self._send_message( @@ -1173,6 +1215,7 @@ def send_message( generation_config=generation_config, safety_settings=safety_settings, tools=tools, + labels=labels, ) @overload @@ -1183,6 +1226,7 @@ def send_message_async( generation_config: Optional[GenerationConfigType] = None, safety_settings: Optional[SafetySettingsType] = None, tools: Optional[List["Tool"]] = None, + labels: Optional[Dict[str, str]] = None, stream: Literal[False] = False, ) -> Awaitable["GenerationResponse"]: ... @@ -1195,6 +1239,7 @@ def send_message_async( generation_config: Optional[GenerationConfigType] = None, safety_settings: Optional[SafetySettingsType] = None, tools: Optional[List["Tool"]] = None, + labels: Optional[Dict[str, str]] = None, stream: Literal[True] = True, ) -> Awaitable[AsyncIterable["GenerationResponse"]]: ... @@ -1206,6 +1251,7 @@ def send_message_async( generation_config: Optional[GenerationConfigType] = None, safety_settings: Optional[SafetySettingsType] = None, tools: Optional[List["Tool"]] = None, + labels: Optional[Dict[str, str]] = None, stream: bool = False, ) -> Union[ Awaitable["GenerationResponse"], @@ -1222,6 +1268,7 @@ def send_message_async( generation_config: Parameters for the generation. safety_settings: Safety settings as a mapping from HarmCategory to HarmBlockThreshold. tools: A list of tools (functions) that the model can try calling. + labels: Labels that will be passed to billing for cost tracking. stream: Whether to stream the response. Returns: @@ -1237,6 +1284,7 @@ def send_message_async( generation_config=generation_config, safety_settings=safety_settings, tools=tools, + labels=labels, ) else: return self._send_message_async( @@ -1244,6 +1292,7 @@ def send_message_async( generation_config=generation_config, safety_settings=safety_settings, tools=tools, + labels=labels, ) def _send_message( @@ -1253,6 +1302,7 @@ def _send_message( generation_config: Optional[GenerationConfigType] = None, safety_settings: Optional[SafetySettingsType] = None, tools: Optional[List["Tool"]] = None, + labels: Optional[Dict[str, str]] = None, ) -> "GenerationResponse": """Generates content. @@ -1265,6 +1315,7 @@ def _send_message( generation_config: Parameters for the generation. safety_settings: Safety settings as a mapping from HarmCategory to HarmBlockThreshold. tools: A list of tools (functions) that the model can try calling. + labels: Labels that will be passed to billing for cost tracking. Returns: A single GenerationResponse object @@ -1293,6 +1344,7 @@ def _send_message( generation_config=generation_config, safety_settings=safety_settings, tools=tools, + labels=labels, ) # By default we're not adding incomplete interactions to history. if self._response_validator is not None: @@ -1329,6 +1381,7 @@ async def _send_message_async( generation_config: Optional[GenerationConfigType] = None, safety_settings: Optional[SafetySettingsType] = None, tools: Optional[List["Tool"]] = None, + labels: Optional[Dict[str, str]] = None, ) -> "GenerationResponse": """Generates content asynchronously. @@ -1341,6 +1394,7 @@ async def _send_message_async( generation_config: Parameters for the generation. safety_settings: Safety settings as a mapping from HarmCategory to HarmBlockThreshold. tools: A list of tools (functions) that the model can try calling. + labels: Labels that will be passed to billing for cost tracking. Returns: An awaitable for a single GenerationResponse object @@ -1361,6 +1415,7 @@ async def _send_message_async( generation_config=generation_config, safety_settings=safety_settings, tools=tools, + labels=labels, ) # By default we're not adding incomplete interactions to history. if self._response_validator is not None: @@ -1385,6 +1440,7 @@ def _send_message_streaming( generation_config: Optional[GenerationConfigType] = None, safety_settings: Optional[SafetySettingsType] = None, tools: Optional[List["Tool"]] = None, + labels: Optional[Dict[str, str]] = None, ) -> Iterable["GenerationResponse"]: """Generates content. @@ -1397,6 +1453,7 @@ def _send_message_streaming( generation_config: Parameters for the generation. safety_settings: Safety settings as a mapping from HarmCategory to HarmBlockThreshold. tools: A list of tools (functions) that the model can try calling. + labels: Labels that will be passed to billing for cost tracking. Yields: A stream of GenerationResponse objects @@ -1417,6 +1474,7 @@ def _send_message_streaming( generation_config=generation_config, safety_settings=safety_settings, tools=tools, + labels=labels, ) chunks = [] full_response = None @@ -1451,6 +1509,7 @@ async def _send_message_streaming_async( generation_config: Optional[GenerationConfigType] = None, safety_settings: Optional[SafetySettingsType] = None, tools: Optional[List["Tool"]] = None, + labels: Optional[Dict[str, str]] = None, ) -> AsyncIterable["GenerationResponse"]: """Generates content asynchronously. @@ -1463,6 +1522,7 @@ async def _send_message_streaming_async( generation_config: Parameters for the generation. safety_settings: Safety settings as a mapping from HarmCategory to HarmBlockThreshold. tools: A list of tools (functions) that the model can try calling. + labels: Labels that will be passed to billing for cost tracking. Returns: An awaitable for a stream of GenerationResponse objects @@ -1482,6 +1542,7 @@ async def _send_message_streaming_async( generation_config=generation_config, safety_settings=safety_settings, tools=tools, + labels=labels, ) async def async_generator(): @@ -3081,6 +3142,7 @@ def _prepare_request( tools: Optional[List["Tool"]] = None, tool_config: Optional["ToolConfig"] = None, system_instruction: Optional[PartsType] = None, + labels: Optional[Dict[str, str]] = None, ) -> types_v1.GenerateContentRequest: """Prepares a GAPIC GenerateContentRequest.""" request_v1beta1 = super()._prepare_request( @@ -3090,6 +3152,7 @@ def _prepare_request( tools=tools, tool_config=tool_config, system_instruction=system_instruction, + labels=labels, ) serialized_message_v1beta1 = type(request_v1beta1).serialize(request_v1beta1) try: