feat: GenAI - Forced function calling feature
PiperOrigin-RevId: 622223390
matthew29tang authored and copybara-github committed Apr 5, 2024
1 parent 13493a4 commit 806ef9f
Showing 4 changed files with 127 additions and 6 deletions.
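For orientation, here is a minimal end-to-end sketch of the forced function calling feature this commit adds, stitched together from the API surface in the diffs below (`ToolConfig`, `FunctionCallingConfig`, and the new `tool_config` parameter). The weather function schema, project, and location are illustrative assumptions, not part of the change:

```python
import vertexai
from vertexai.generative_models import (
    FunctionDeclaration,
    GenerativeModel,
    Tool,
    ToolConfig,
)

# Illustrative setup; project and location are placeholders.
vertexai.init(project="my-project", location="us-central1")

# Hypothetical function declaration matching the one used in the unit test below.
get_current_weather_func = FunctionDeclaration(
    name="get_current_weather",
    description="Returns the current weather in a given location.",
    parameters={
        "type": "object",
        "properties": {"location": {"type": "string"}},
        "required": ["location"],
    },
)
weather_tool = Tool(function_declarations=[get_current_weather_func])

# Force the model to respond with a call to one of the allowed functions.
tool_config = ToolConfig(
    function_calling_config=ToolConfig.FunctionCallingConfig(
        mode=ToolConfig.FunctionCallingConfig.Mode.ANY,
        allowed_function_names=["get_current_weather"],
    )
)

model = GenerativeModel(
    "gemini-pro",
    tools=[weather_tool],
    tool_config=tool_config,
)
print(model.generate_content("What is the weather like in Boston?"))
```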
8 changes: 8 additions & 0 deletions tests/unit/vertexai/test_generative_models.py
@@ -428,10 +428,18 @@ def test_chat_function_calling(self, generative_models: generative_models):
function_declarations=[get_current_weather_func],
)

tool_config = generative_models.ToolConfig(
function_calling_config=generative_models.ToolConfig.FunctionCallingConfig(
mode=generative_models.ToolConfig.FunctionCallingConfig.Mode.ANY,
allowed_function_names=["get_current_weather"],
)
)

model = generative_models.GenerativeModel(
"gemini-pro",
# Specifying the tools once to avoid specifying them in every request
tools=[weather_tool],
tool_config=tool_config,
)
chat = model.start_chat()

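With `Mode.ANY` and `allowed_function_names` set as above, the model's reply to the chat message is expected to be a `get_current_weather` function call rather than free text. A sketch of how calling code might then complete the turn, mirroring the `ToolConfig` docstring example later in this diff (the response payload is illustrative):

```python
from vertexai.generative_models import Part

response = chat.send_message("What is the weather like in Boston?")
# The forced call arrives as a function_call part; its args carry the location.
function_call = response.candidates[0].content.parts[0].function_call
assert function_call.name == "get_current_weather"

# Run the real function (or a stub) and send its result back to the model.
print(chat.send_message(
    Part.from_function_response(
        name="get_current_weather",
        response={"content": {"weather_there": "super nice"}},  # illustrative payload
    )
))
```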
2 changes: 2 additions & 0 deletions vertexai/generative_models/__init__.py
@@ -32,6 +32,7 @@
ResponseValidationError,
SafetySetting,
Tool,
ToolConfig,
)

__all__ = [
@@ -50,4 +51,5 @@
"ResponseValidationError",
"SafetySetting",
"Tool",
"ToolConfig",
]
121 changes: 115 additions & 6 deletions vertexai/generative_models/_generative_models.py
@@ -133,6 +133,7 @@ def __init__(
generation_config: Optional[GenerationConfigType] = None,
safety_settings: Optional[SafetySettingsType] = None,
tools: Optional[List["Tool"]] = None,
tool_config: Optional["ToolConfig"] = None,
system_instruction: Optional[PartsType] = None,
):
r"""Initializes GenerativeModel.
@@ -149,6 +150,7 @@
generation_config: Default generation config to use in generate_content.
safety_settings: Default safety settings to use in generate_content.
tools: Default tools to use in generate_content.
tool_config: Default tool config to use in generate_content.
system_instruction: Default system instruction to use in generate_content.
Note: Only text should be used in parts.
Content of each part will become a separate paragraph.
@@ -164,7 +166,9 @@
location = aiplatform_initializer.global_config.location

if model_name.startswith("publishers/"):
prediction_resource_name = f"projects/{project}/locations/{location}/{model_name}"
prediction_resource_name = (
f"projects/{project}/locations/{location}/{model_name}"
)
else:
prediction_resource_name = model_name

@@ -173,6 +177,7 @@
self._generation_config = generation_config
self._safety_settings = safety_settings
self._tools = tools
self._tool_config = tool_config
self._system_instruction = system_instruction

# Validating the parameters
@@ -181,6 +186,7 @@
generation_config=generation_config,
safety_settings=safety_settings,
tools=tools,
tool_config=tool_config,
system_instruction=system_instruction,
)

@@ -217,6 +223,7 @@ def _prepare_request(
generation_config: Optional[GenerationConfigType] = None,
safety_settings: Optional[SafetySettingsType] = None,
tools: Optional[List["Tool"]] = None,
tool_config: Optional["ToolConfig"] = None,
system_instruction: Optional[PartsType] = None,
) -> gapic_prediction_service_types.GenerateContentRequest:
"""Prepares a GAPIC GenerateContentRequest."""
@@ -226,6 +233,7 @@
generation_config = generation_config or self._generation_config
safety_settings = safety_settings or self._safety_settings
tools = tools or self._tools
tool_config = tool_config or self._tool_config
system_instruction = system_instruction or self._system_instruction

# contents can either be a list of Content objects (most generic case)
@@ -316,6 +324,13 @@ def _prepare_request(
else:
raise TypeError(f"Unexpected tool type: {tool}.")

gapic_tool_config = None
if tool_config:
if isinstance(tool_config, ToolConfig):
gapic_tool_config = tool_config._gapic_tool_config
else:
raise TypeError("tool_config must be a ToolConfig object.")

return gapic_prediction_service_types.GenerateContentRequest(
# The `model` parameter now needs to be set for the vision models.
# Always need to pass the resource via the `model` parameter.
@@ -325,6 +340,7 @@
generation_config=gapic_generation_config,
safety_settings=gapic_safety_settings,
tools=gapic_tools,
tool_config=gapic_tool_config,
system_instruction=gapic_system_instruction,
)

@@ -341,6 +357,7 @@ def generate_content(
generation_config: Optional[GenerationConfigType] = None,
safety_settings: Optional[SafetySettingsType] = None,
tools: Optional[List["Tool"]] = None,
tool_config: Optional["ToolConfig"] = None,
stream: bool = False,
) -> Union["GenerationResponse", Iterable["GenerationResponse"],]:
"""Generates content.
@@ -356,6 +373,7 @@
generation_config: Parameters for the generation.
safety_settings: Safety settings as a mapping from HarmCategory to HarmBlockThreshold.
tools: A list of tools (functions) that the model can try calling.
tool_config: Config shared for all tools provided in the request.
stream: Whether to stream the response.
Returns:
@@ -369,13 +387,15 @@
generation_config=generation_config,
safety_settings=safety_settings,
tools=tools,
tool_config=tool_config,
)
else:
return self._generate_content(
contents=contents,
generation_config=generation_config,
safety_settings=safety_settings,
tools=tools,
tool_config=tool_config,
)

async def generate_content_async(
@@ -385,6 +405,7 @@ async def generate_content_async(
generation_config: Optional[GenerationConfigType] = None,
safety_settings: Optional[SafetySettingsType] = None,
tools: Optional[List["Tool"]] = None,
tool_config: Optional["ToolConfig"] = None,
stream: bool = False,
) -> Union["GenerationResponse", AsyncIterable["GenerationResponse"],]:
"""Generates content asynchronously.
@@ -400,6 +421,7 @@
generation_config: Parameters for the generation.
safety_settings: Safety settings as a mapping from HarmCategory to HarmBlockThreshold.
tools: A list of tools (functions) that the model can try calling.
tool_config: Config shared for all tools provided in the request.
stream: Whether to stream the response.
Returns:
@@ -412,13 +434,15 @@
generation_config=generation_config,
safety_settings=safety_settings,
tools=tools,
tool_config=tool_config,
)
else:
return await self._generate_content_async(
contents=contents,
generation_config=generation_config,
safety_settings=safety_settings,
tools=tools,
tool_config=tool_config,
)

def _generate_content(
@@ -428,6 +452,7 @@ def _generate_content(
generation_config: Optional[GenerationConfigType] = None,
safety_settings: Optional[SafetySettingsType] = None,
tools: Optional[List["Tool"]] = None,
tool_config: Optional["ToolConfig"] = None,
) -> "GenerationResponse":
"""Generates content.
@@ -442,6 +467,7 @@
generation_config: Parameters for the generation.
safety_settings: Safety settings as a mapping from HarmCategory to HarmBlockThreshold.
tools: A list of tools (functions) that the model can try calling.
tool_config: Config shared for all tools provided in the request.
Returns:
A single GenerationResponse object
@@ -451,6 +477,7 @@
generation_config=generation_config,
safety_settings=safety_settings,
tools=tools,
tool_config=tool_config,
)
gapic_response = self._prediction_client.generate_content(request=request)
return self._parse_response(gapic_response)
@@ -462,6 +489,7 @@ async def _generate_content_async(
generation_config: Optional[GenerationConfigType] = None,
safety_settings: Optional[SafetySettingsType] = None,
tools: Optional[List["Tool"]] = None,
tool_config: Optional["ToolConfig"] = None,
) -> "GenerationResponse":
"""Generates content asynchronously.
@@ -476,6 +504,7 @@
generation_config: Parameters for the generation.
safety_settings: Safety settings as a mapping from HarmCategory to HarmBlockThreshold.
tools: A list of tools (functions) that the model can try calling.
tool_config: Config shared for all tools provided in the request.
Returns:
An awaitable for a single GenerationResponse object
@@ -485,6 +514,7 @@
generation_config=generation_config,
safety_settings=safety_settings,
tools=tools,
tool_config=tool_config,
)
gapic_response = await self._prediction_async_client.generate_content(
request=request
@@ -498,6 +528,7 @@ def _generate_content_streaming(
generation_config: Optional[GenerationConfigType] = None,
safety_settings: Optional[SafetySettingsType] = None,
tools: Optional[List["Tool"]] = None,
tool_config: Optional["ToolConfig"] = None,
) -> Iterable["GenerationResponse"]:
"""Generates content.
@@ -512,6 +543,7 @@
generation_config: Parameters for the generation.
safety_settings: Safety settings as a mapping from HarmCategory to HarmBlockThreshold.
tools: A list of tools (functions) that the model can try calling.
tool_config: Config shared for all tools provided in the request.
Yields:
A stream of GenerationResponse objects
@@ -521,6 +553,7 @@
generation_config=generation_config,
safety_settings=safety_settings,
tools=tools,
tool_config=tool_config,
)
response_stream = self._prediction_client.stream_generate_content(
request=request
@@ -535,6 +568,7 @@ async def _generate_content_streaming_async(
generation_config: Optional[GenerationConfigType] = None,
safety_settings: Optional[SafetySettingsType] = None,
tools: Optional[List["Tool"]] = None,
tool_config: Optional["ToolConfig"] = None,
) -> AsyncIterable["GenerationResponse"]:
"""Generates content asynchronously.
@@ -549,6 +583,7 @@
generation_config: Parameters for the generation.
safety_settings: Safety settings as a mapping from HarmCategory to HarmBlockThreshold.
tools: A list of tools (functions) that the model can try calling.
tool_config: Config shared for all tools provided in the request.
Returns:
An awaitable for a stream of GenerationResponse objects
@@ -558,6 +593,7 @@
generation_config=generation_config,
safety_settings=safety_settings,
tools=tools,
tool_config=tool_config,
)
response_stream = await self._prediction_async_client.stream_generate_content(
request=request
@@ -1312,6 +1348,77 @@ def __repr__(self) -> str:
return self._raw_tool.__repr__()


class ToolConfig:
r"""Config shared for all tools provided in the request.
Usage:
Create ToolConfig
```
tool_config = ToolConfig(
function_calling_config=ToolConfig.FunctionCallingConfig(
mode=ToolConfig.FunctionCallingConfig.Mode.ANY,
allowed_function_names=["get_current_weather"],
))
```
Use ToolConfig in `GenerativeModel.generate_content`:
```
model = GenerativeModel("gemini-pro")
print(model.generate_content(
"What is the weather like in Boston?",
# You can specify tools when creating a model to avoid having to send them with every request.
tools=[weather_tool],
tool_config=tool_config,
))
```
Use ToolConfig in chat:
```
model = GenerativeModel(
"gemini-pro",
# You can specify tools when creating a model to avoid having to send them with every request.
tools=[weather_tool],
tool_config=tool_config,
)
chat = model.start_chat()
print(chat.send_message("What is the weather like in Boston?"))
print(chat.send_message(
Part.from_function_response(
name="get_current_weather",
response={
"content": {"weather_there": "super nice"},
}
),
))
```
"""

class FunctionCallingConfig:
Mode = gapic_tool_types.FunctionCallingConfig.Mode

def __init__(
self,
mode: "ToolConfig.FunctionCallingConfig.Mode",
allowed_function_names: Optional[List[str]] = None,
):
"""Constructs FunctionCallingConfig.
Args:
mode: Enum describing the function calling mode
allowed_function_names: A list of allowed function names
(must match from Tool). Only set when the Mode is ANY.
"""
self._gapic_function_calling_config = (
gapic_tool_types.FunctionCallingConfig(
mode=mode,
allowed_function_names=allowed_function_names,
)
)

def __init__(self, function_calling_config: "ToolConfig.FunctionCallingConfig"):
self._gapic_tool_config = gapic_tool_types.ToolConfig(
function_calling_config=function_calling_config._gapic_function_calling_config
)


class FunctionDeclaration:
r"""A representation of a function declaration.
@@ -1454,9 +1561,13 @@ def from_func(cls, func: Callable[..., Any]) -> "CallableFunctionDeclaration":
Returns:
CallableFunctionDeclaration.
"""
from vertexai.generative_models import _function_calling_utils
from vertexai.generative_models import (
_function_calling_utils,
)

function_schema = _function_calling_utils.generate_json_schema_from_function(func)
function_schema = _function_calling_utils.generate_json_schema_from_function(
func
)
# Getting out the description first since it will be removed from the schema.
function_description = function_schema["description"]
function_schema = (
@@ -1804,9 +1915,7 @@ def _from_gapic(

@classmethod
def from_dict(cls, safety_setting_dict: Dict[str, Any]) -> "SafetySetting":
raw_safety_setting = gapic_content_types.SafetySetting(
safety_setting_dict
)
raw_safety_setting = gapic_content_types.SafetySetting(safety_setting_dict)
return cls._from_gapic(raw_safety_setting=raw_safety_setting)

def to_dict(self) -> Dict[str, Any]:
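Because `FunctionCallingConfig.Mode` is an alias for the GAPIC enum, the same `ToolConfig` shape should also cover the non-forced behaviours. A sketch, assuming the underlying enum exposes `AUTO` and `NONE` alongside the `ANY` value exercised in this change:

```python
from vertexai.generative_models import ToolConfig

Mode = ToolConfig.FunctionCallingConfig.Mode

# AUTO (assumed default): the model decides between calling a function and plain text.
auto_config = ToolConfig(
    function_calling_config=ToolConfig.FunctionCallingConfig(mode=Mode.AUTO)
)

# ANY: forced function calling, the feature added in this commit.
forced_config = ToolConfig(
    function_calling_config=ToolConfig.FunctionCallingConfig(
        mode=Mode.ANY,
        allowed_function_names=["get_current_weather"],
    )
)

# NONE: keep tools attached but disable function calling for these requests.
disabled_config = ToolConfig(
    function_calling_config=ToolConfig.FunctionCallingConfig(mode=Mode.NONE)
)
```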
2 changes: 2 additions & 0 deletions vertexai/preview/generative_models.py
@@ -36,6 +36,7 @@
ResponseValidationError,
SafetySetting,
Tool,
ToolConfig,
)


@@ -67,5 +68,6 @@ class ChatSession(_PreviewChatSession):
"ResponseValidationError",
"SafetySetting",
"Tool",
"ToolConfig",
#
]
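The preview namespace re-exports the same class, so preview users can adopt forced function calling without changing the rest of their imports. A small sketch, reusing the `weather_tool` object assumed in the first example:

```python
from vertexai.preview.generative_models import GenerativeModel, ToolConfig

# Same ToolConfig class, re-exported under the preview namespace.
tool_config = ToolConfig(
    function_calling_config=ToolConfig.FunctionCallingConfig(
        mode=ToolConfig.FunctionCallingConfig.Mode.ANY,
        allowed_function_names=["get_current_weather"],
    )
)
model = GenerativeModel("gemini-pro", tools=[weather_tool], tool_config=tool_config)
print(model.generate_content("What is the weather like in Boston?"))
```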
