-
Notifications
You must be signed in to change notification settings - Fork 3.3k
/
kernel_function_from_method.py
187 lines (169 loc) · 8.76 KB
/
kernel_function_from_method.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
# Copyright (c) Microsoft. All rights reserved.
import inspect
import logging
from collections.abc import Callable
from inspect import isasyncgen, isasyncgenfunction, isawaitable, iscoroutinefunction, isgenerator, isgeneratorfunction
from typing import Any
from pydantic import ValidationError
from semantic_kernel.exceptions import FunctionExecutionException, FunctionInitializationError
from semantic_kernel.filters.functions.function_invocation_context import FunctionInvocationContext
from semantic_kernel.functions.function_result import FunctionResult
from semantic_kernel.functions.kernel_function import KernelFunction
from semantic_kernel.functions.kernel_function_metadata import KernelFunctionMetadata
from semantic_kernel.functions.kernel_parameter_metadata import KernelParameterMetadata
logger: logging.Logger = logging.getLogger(__name__)
class KernelFunctionFromMethod(KernelFunction):
"""Semantic Kernel Function from a method."""
method: Callable[..., Any]
stream_method: Callable[..., Any] | None = None
def __init__(
self,
method: Callable[..., Any],
plugin_name: str | None = None,
stream_method: Callable[..., Any] | None = None,
parameters: list[KernelParameterMetadata] | None = None,
return_parameter: KernelParameterMetadata | None = None,
additional_metadata: dict[str, Any] | None = None,
) -> None:
"""Initializes a new instance of the KernelFunctionFromMethod class.
Args:
method (Callable[..., Any]): The method to be called
plugin_name (str | None): The name of the plugin
stream_method (Callable[..., Any] | None): The stream method for the function
parameters (list[KernelParameterMetadata] | None): The parameters of the function
return_parameter (KernelParameterMetadata | None): The return parameter of the function
additional_metadata (dict[str, Any] | None): Additional metadata for the function
"""
if method is None:
raise FunctionInitializationError("Method cannot be `None`")
if not hasattr(method, "__kernel_function__") or method.__kernel_function__ is None:
raise FunctionInitializationError("Method is not a Kernel function")
# all these fields are created when the kernel function decorator is used,
# so no need to check before using, will raise an exception if not set
function_name = method.__kernel_function_name__ # type: ignore
description = method.__kernel_function_description__ # type: ignore
if parameters is None:
parameters = [KernelParameterMetadata(**param) for param in method.__kernel_function_parameters__] # type: ignore
if return_parameter is None:
return_parameter = KernelParameterMetadata(
name="return",
description=method.__kernel_function_return_description__, # type: ignore
default_value=None,
type_=method.__kernel_function_return_type__, # type: ignore
type_object=method.__kernel_function_return_type_object__, # type: ignore
is_required=method.__kernel_function_return_required__, # type: ignore
)
try:
metadata = KernelFunctionMetadata(
name=function_name,
description=description,
parameters=parameters,
return_parameter=return_parameter,
is_prompt=False,
is_asynchronous=isasyncgenfunction(method) or iscoroutinefunction(method),
plugin_name=plugin_name,
additional_properties=additional_metadata if additional_metadata is not None else {},
)
except ValidationError as exc:
# reraise the exception to clarify it comes from KernelFunction init
raise FunctionInitializationError("Failed to create KernelFunctionMetadata") from exc
args: dict[str, Any] = {
"metadata": metadata,
"method": method,
"stream_method": (
stream_method
if stream_method is not None
else method
if isasyncgenfunction(method) or isgeneratorfunction(method)
else None
),
}
super().__init__(**args)
async def _invoke_internal(
self,
context: FunctionInvocationContext,
) -> None:
"""Invoke the function with the given arguments."""
function_arguments = self.gather_function_parameters(context)
result = self.method(**function_arguments)
if isasyncgen(result):
result = [x async for x in result]
elif isawaitable(result):
result = await result
elif isgenerator(result):
result = list(result)
if not isinstance(result, FunctionResult):
result = FunctionResult(
function=self.metadata,
value=result,
metadata={"arguments": context.arguments, "used_arguments": function_arguments},
)
context.result = result
async def _invoke_internal_stream(self, context: FunctionInvocationContext) -> None:
if self.stream_method is None:
raise NotImplementedError("Stream method not implemented")
function_arguments = self.gather_function_parameters(context)
context.result = FunctionResult(function=self.metadata, value=self.stream_method(**function_arguments))
def _parse_parameter(self, value: Any, param_type: Any) -> Any:
"""Parses the value into the specified param_type, including handling lists of types."""
if isinstance(param_type, type) and hasattr(param_type, "model_validate"):
try:
return param_type.model_validate(value)
except Exception as exc:
raise FunctionExecutionException(
f"Parameter is expected to be parsed to {param_type} but is not."
) from exc
elif hasattr(param_type, "__origin__") and param_type.__origin__ is list:
if isinstance(value, list):
item_type = param_type.__args__[0]
return [self._parse_parameter(item, item_type) for item in value]
raise FunctionExecutionException(f"Expected a list for {param_type}, but got {type(value)}")
else:
try:
if isinstance(value, dict) and hasattr(param_type, "__init__"):
return param_type(**value)
return param_type(value)
except Exception as exc:
raise FunctionExecutionException(
f"Parameter is expected to be parsed to {param_type} but is not."
) from exc
def gather_function_parameters(self, context: FunctionInvocationContext) -> dict[str, Any]:
"""Gathers the function parameters from the arguments."""
function_arguments: dict[str, Any] = {}
for param in self.parameters:
if param.name is None:
raise FunctionExecutionException("Parameter name cannot be None")
if param.name == "kernel":
function_arguments[param.name] = context.kernel
continue
if param.name == "service":
function_arguments[param.name] = context.kernel.select_ai_service(self, context.arguments)[0]
continue
if param.name == "execution_settings":
function_arguments[param.name] = context.kernel.select_ai_service(self, context.arguments)[1]
continue
if param.name == "arguments":
function_arguments[param.name] = context.arguments
continue
if param.name in context.arguments:
value: Any = context.arguments[param.name]
if (
param.type_
and "," not in param.type_
and param.type_object
and param.type_object is not inspect._empty
):
try:
value = self._parse_parameter(value, param.type_object)
except Exception as exc:
raise FunctionExecutionException(
f"Parameter {param.name} is expected to be parsed to {param.type_object} but is not."
) from exc
function_arguments[param.name] = value
continue
if param.is_required:
raise FunctionExecutionException(
f"Parameter {param.name} is required but not provided in the arguments."
)
logger.debug(f"Parameter {param.name} is not provided, using default value {param.default_value}")
return function_arguments