From 95b2cfb03e3688ccf70cd66bee687d04c667933c Mon Sep 17 00:00:00 2001 From: Nate Bosch Date: Tue, 16 Apr 2024 17:25:25 -0700 Subject: [PATCH] Implement constrained function calling (#121) Function calling configuration allows specifying that the model must, or must not, call a function on the current response. By default the model is allowed to choose whether to call a function or reply in natural language. Add a `FunctionCallingConfig` class and a `functionCallingConfig` argument to the `GenerativeModel` constructor and the methods for generating content. The classes and fields match the API model directly. Add support for overriding the `tools` for a specific generate content call. --- pkgs/google_generative_ai/CHANGELOG.md | 7 + .../lib/google_generative_ai.dart | 9 +- .../lib/src/function_calling.dart | 58 +++++++ pkgs/google_generative_ai/lib/src/model.dart | 68 ++++++--- .../google_generative_ai/lib/src/version.dart | 2 +- pkgs/google_generative_ai/pubspec.yaml | 2 +- .../test/generative_model_test.dart | 144 ++++++++++++++++++ 7 files changed, 269 insertions(+), 21 deletions(-) diff --git a/pkgs/google_generative_ai/CHANGELOG.md b/pkgs/google_generative_ai/CHANGELOG.md index 700f567..f8a0a16 100644 --- a/pkgs/google_generative_ai/CHANGELOG.md +++ b/pkgs/google_generative_ai/CHANGELOG.md @@ -1,3 +1,10 @@ +## 0.3.1-wip + +- Add support on content generating methods for overriding "tools" passed when + the generative model was instantiated. +- Add support for forcing the model to use or not use function calls to generate + content. 
+ ## 0.3.0 - Allow specifying an API version in a `requestOptions` argument when diff --git a/pkgs/google_generative_ai/lib/google_generative_ai.dart b/pkgs/google_generative_ai/lib/google_generative_ai.dart index 2b7c4d9..545fc12 100644 --- a/pkgs/google_generative_ai/lib/google_generative_ai.dart +++ b/pkgs/google_generative_ai/lib/google_generative_ai.dart @@ -75,5 +75,12 @@ export 'src/error.dart' ServerException, UnsupportedUserLocation; export 'src/function_calling.dart' - show FunctionDeclaration, Schema, SchemaType, Tool; + show + FunctionCallingConfig, + FunctionCallingMode, + FunctionDeclaration, + Schema, + SchemaType, + Tool, + ToolConfig; export 'src/model.dart' show GenerativeModel, RequestOptions; diff --git a/pkgs/google_generative_ai/lib/src/function_calling.dart b/pkgs/google_generative_ai/lib/src/function_calling.dart index 7f6a429..fa4a725 100644 --- a/pkgs/google_generative_ai/lib/src/function_calling.dart +++ b/pkgs/google_generative_ai/lib/src/function_calling.dart @@ -66,6 +66,64 @@ final class FunctionDeclaration { }; } +final class ToolConfig { + final FunctionCallingConfig? functionCallingConfig; + ToolConfig({this.functionCallingConfig}); + + Map toJson() => { + if (functionCallingConfig case final config?) + 'functionCallingConfig': config.toJson(), + }; +} + +/// Configuration specifying how the model should use the functions provided as +/// tools. +final class FunctionCallingConfig { + /// The mode in which function calling should execute. + /// + /// If null, the default behavior will match [FunctionCallingMode.auto]. + final FunctionCallingMode? mode; + + /// A set of function names that, when provided, limits the functions the + /// model will call. + /// + /// This should only be set when the Mode is [FunctionCallingMode.any]. + /// Function names should match [FunctionDeclaration.name]. With mode set to + /// `any`, model will predict a function call from the set of function names + /// provided. + final Set? 
allowedFunctionNames; + FunctionCallingConfig({this.mode, this.allowedFunctionNames}); + + Object toJson() => { + if (mode case final mode?) 'mode': mode.toJson(), + if (allowedFunctionNames case final allowedFunctionNames?) + 'allowedFunctionNames': allowedFunctionNames.toList(), + }; +} + +enum FunctionCallingMode { + /// The mode with default model behavior. + /// + /// Model decides to predict either a function call or a natural language + /// response. + auto, + + /// A mode where the Model is constrained to always predicting a function + /// call only. + any, + + /// A mode where the model will not predict any function call. + /// + /// Model behavior is same as when not passing any function declarations. + none; + + String toJson() => switch (this) { + auto => 'AUTO', + any => 'ANY', + none => 'NONE', + }; +} + /// The definition of an input or output data types. /// /// These types can be objects, but also primitives and arrays. diff --git a/pkgs/google_generative_ai/lib/src/model.dart b/pkgs/google_generative_ai/lib/src/model.dart index e7392df..aaed822 100644 --- a/pkgs/google_generative_ai/lib/src/model.dart +++ b/pkgs/google_generative_ai/lib/src/model.dart @@ -60,6 +60,7 @@ final class GenerativeModel { final ApiClient _client; final Uri _baseUri; final Content? _systemInstruction; + final ToolConfig? _toolConfig; /// Create a [GenerativeModel] backed by the generative model named [model]. /// @@ -86,7 +87,10 @@ final class GenerativeModel { /// /// Functions that the model may call while generating content can be passed /// in [tools]. When using tools [requestOptions] must be passed to - /// override the `apiVersion` to `v1beta`. + /// override the `apiVersion` to `v1beta`. Tool usage by the model can be + /// configured with [toolConfig]. Tools and tool configuration can be + /// overridden for individual requests with arguments to [generateContent] or + /// [generateContentStream]. 
/// /// A [Content.system] can be passed to [systemInstruction] to give /// high priority instructions to the model. When using system instructions @@ -100,6 +104,7 @@ final class GenerativeModel { http.Client? httpClient, RequestOptions? requestOptions, Content? systemInstruction, + ToolConfig? toolConfig, }) => GenerativeModel._withClient( client: HttpApiClient(apiKey: apiKey, httpClient: httpClient), @@ -109,6 +114,7 @@ final class GenerativeModel { baseUri: _googleAIBaseUri(requestOptions), tools: tools, systemInstruction: systemInstruction, + toolConfig: toolConfig, ); GenerativeModel._withClient({ @@ -119,12 +125,14 @@ final class GenerativeModel { required Uri baseUri, required List? tools, required Content? systemInstruction, + required ToolConfig? toolConfig, }) : _model = _normalizeModelName(model), _baseUri = baseUri, _safetySettings = safetySettings, _generationConfig = generationConfig, _tools = tools, _systemInstruction = systemInstruction, + _toolConfig = toolConfig, _client = client; /// Returns the model code for a user friendly model name. @@ -146,24 +154,35 @@ final class GenerativeModel { /// Sends a "generateContent" API request for the configured model, /// and waits for the response. /// + /// The [safetySettings], [generationConfig], [tools], and [toolConfig], + /// override the arguments of the same name passed to the + /// [GenerativeModel.new] constructor. Each argument, when non-null, + /// overrides the model level configuration in its entirety. + /// /// Example: /// ```dart /// final response = await model.generateContent([Content.text(prompt)]); /// print(response.text); /// ``` - Future generateContent(Iterable prompt, - {List? safetySettings, - GenerationConfig? generationConfig}) async { + Future generateContent( + Iterable prompt, { + List? safetySettings, + GenerationConfig? generationConfig, + List? tools, + ToolConfig? 
toolConfig, + }) async { safetySettings ??= _safetySettings; generationConfig ??= _generationConfig; + tools ??= _tools; + toolConfig ??= _toolConfig; final parameters = { 'contents': prompt.map((p) => p.toJson()).toList(), if (safetySettings.isNotEmpty) 'safetySettings': safetySettings.map((s) => s.toJson()).toList(), - if (generationConfig case final config?) - 'generationConfig': config.toJson(), - if (_tools case final tools?) - 'tools': tools.map((t) => t.toJson()).toList(), + if (generationConfig != null) + 'generationConfig': generationConfig.toJson(), + if (tools != null) 'tools': tools.map((t) => t.toJson()).toList(), + if (toolConfig != null) 'toolConfig': toolConfig.toJson(), if (_systemInstruction case final systemInstruction?) 'systemInstruction': systemInstruction.toJson(), }; @@ -177,6 +196,11 @@ final class GenerativeModel { /// Sends a "streamGenerateContent" API request for the configured model, /// and waits for the response. /// + /// The [safetySettings], [generationConfig], [tools], and [toolConfig], + /// override the arguments of the same name passed to the + /// [GenerativeModel.new] constructor. Each argument, when non-null, + /// overrides the model level configuration in its entirety. + /// /// Example: /// ```dart /// final responses = await model.generateContent([Content.text(prompt)]); @@ -185,19 +209,24 @@ final class GenerativeModel { /// } /// ``` Stream generateContentStream( - Iterable prompt, - {List? safetySettings, - GenerationConfig? generationConfig}) { + Iterable prompt, { + List? safetySettings, + GenerationConfig? generationConfig, + List? tools, + ToolConfig? toolConfig, + }) { safetySettings ??= _safetySettings; generationConfig ??= _generationConfig; + tools ??= _tools; + toolConfig ??= _toolConfig; final parameters = { 'contents': prompt.map((p) => p.toJson()).toList(), if (safetySettings.isNotEmpty) 'safetySettings': safetySettings.map((s) => s.toJson()).toList(), - if (generationConfig case final config?) 
- 'generationConfig': config.toJson(), - if (_tools case final tools?) - 'tools': tools.map((t) => t.toJson()).toList(), + if (generationConfig != null) + 'generationConfig': generationConfig.toJson(), + if (tools != null) 'tools': tools.map((t) => t.toJson()).toList(), + if (toolConfig != null) 'toolConfig': toolConfig.toJson(), if (_systemInstruction case final systemInstruction?) 'systemInstruction': systemInstruction.toJson(), }; @@ -290,8 +319,9 @@ GenerativeModel createModelWithClient({ List safetySettings = const [], GenerationConfig? generationConfig, RequestOptions? requestOptions, - List? tools, Content? systemInstruction, + List? tools, + ToolConfig? toolConfig, }) => GenerativeModel._withClient( client: client, @@ -299,8 +329,9 @@ GenerativeModel createModelWithClient({ safetySettings: safetySettings, generationConfig: generationConfig, baseUri: _googleAIBaseUri(requestOptions), - tools: tools, systemInstruction: systemInstruction, + tools: tools, + toolConfig: toolConfig, ); /// Creates a model with an overridden base URL to communicate with a different @@ -323,6 +354,7 @@ GenerativeModel createModelWithBaseUri({ safetySettings: safetySettings, generationConfig: generationConfig, baseUri: baseUri, - tools: null, systemInstruction: systemInstruction, + tools: null, + toolConfig: null, ); diff --git a/pkgs/google_generative_ai/lib/src/version.dart b/pkgs/google_generative_ai/lib/src/version.dart index 13c2dd9..9f231bf 100644 --- a/pkgs/google_generative_ai/lib/src/version.dart +++ b/pkgs/google_generative_ai/lib/src/version.dart @@ -12,4 +12,4 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-const packageVersion = '0.3.0'; +const packageVersion = '0.3.1-wip'; diff --git a/pkgs/google_generative_ai/pubspec.yaml b/pkgs/google_generative_ai/pubspec.yaml index 0c795b2..735b76b 100644 --- a/pkgs/google_generative_ai/pubspec.yaml +++ b/pkgs/google_generative_ai/pubspec.yaml @@ -1,6 +1,6 @@ name: google_generative_ai # Update `lib/version.dart` when changing version. -version: 0.3.0 +version: 0.3.1-wip description: >- The Google AI Dart SDK enables developers to use Google's state-of-the-art generative AI models (like Gemini). diff --git a/pkgs/google_generative_ai/test/generative_model_test.dart b/pkgs/google_generative_ai/test/generative_model_test.dart index 26b4412..e526916 100644 --- a/pkgs/google_generative_ai/test/generative_model_test.dart +++ b/pkgs/google_generative_ai/test/generative_model_test.dart @@ -27,6 +27,8 @@ void main() { String modelName = defaultModelName, RequestOptions? requestOptions, Content? systemInstruction, + List? tools, + ToolConfig? toolConfig, }) { final client = StubClient(); final model = createModelWithClient( @@ -34,6 +36,8 @@ void main() { client: client, requestOptions: requestOptions, systemInstruction: systemInstruction, + tools: tools, + toolConfig: toolConfig, ); return (client, model); } @@ -338,6 +342,146 @@ void main() { Content('model', [TextPart(result)]), null, null, null, null), ], null))); }); + + test('can pass tools and function calling config', () async { + final (client, model) = createModel( + tools: [ + Tool(functionDeclarations: [ + FunctionDeclaration('someFunction', 'Some cool function.', + Schema(SchemaType.string, description: 'Some parameter.')) + ]) + ], + toolConfig: ToolConfig( + functionCallingConfig: FunctionCallingConfig( + mode: FunctionCallingMode.any, + allowedFunctionNames: {'someFunction'}))); + final prompt = 'Some prompt'; + final result = 'Some response'; + client.stub( + Uri.parse('https://generativelanguage.googleapis.com/v1/' + 'models/some-model:generateContent'), + { + 
'contents': [ + { + 'role': 'user', + 'parts': [ + {'text': prompt} + ] + } + ], + 'tools': [ + { + 'functionDeclarations': [ + { + 'name': 'someFunction', + 'description': 'Some cool function.', + 'parameters': { + 'type': 'STRING', + 'description': 'Some parameter.' + } + } + ] + } + ], + 'toolConfig': { + 'functionCallingConfig': { + 'mode': 'ANY', + 'allowedFunctionNames': ['someFunction'], + } + }, + }, + { + 'candidates': [ + { + 'content': { + 'role': 'model', + 'parts': [ + {'text': result} + ] + } + } + ] + }, + ); + final response = await model.generateContent([Content.text(prompt)]); + expect( + response, + matchesGenerateContentResponse(GenerateContentResponse([ + Candidate( + Content('model', [TextPart(result)]), null, null, null, null), + ], null))); + }); + + test('can override tools and function calling config', () async { + final (client, model) = createModel(); + final prompt = 'Some prompt'; + final result = 'Some response'; + client.stub( + Uri.parse('https://generativelanguage.googleapis.com/v1/' + 'models/some-model:generateContent'), + { + 'contents': [ + { + 'role': 'user', + 'parts': [ + {'text': prompt} + ] + } + ], + 'tools': [ + { + 'functionDeclarations': [ + { + 'name': 'someFunction', + 'description': 'Some cool function.', + 'parameters': { + 'type': 'STRING', + 'description': 'Some parameter.' 
+ } + } + ] + } + ], + 'toolConfig': { + 'functionCallingConfig': { + 'mode': 'ANY', + 'allowedFunctionNames': ['someFunction'], + } + }, + }, + { + 'candidates': [ + { + 'content': { + 'role': 'model', + 'parts': [ + {'text': result} + ] + } + } + ] + }, + ); + final response = await model.generateContent([ + Content.text(prompt) + ], + tools: [ + Tool(functionDeclarations: [ + FunctionDeclaration('someFunction', 'Some cool function.', + Schema(SchemaType.string, description: 'Some parameter.')) + ]) + ], + toolConfig: ToolConfig( + functionCallingConfig: FunctionCallingConfig( + mode: FunctionCallingMode.any, + allowedFunctionNames: {'someFunction'}))); + expect( + response, + matchesGenerateContentResponse(GenerateContentResponse([ + Candidate( + Content('model', [TextPart(result)]), null, null, null, null), + ], null))); + }); }); group('generate content stream', () {