Skip to content

Commit

Permalink
Implement constrained function calling (#121)
Browse files Browse the repository at this point in the history
Function calling configuration allows specifying that the model must, or
must not, call a function on the current response. By default the model
is allowed to choose whether to call a function or reply in natural
language.

Add a `FunctionCallingConfig` class and a `functionCallingConfig`
argument to the `GenerativeModel` constructor and the methods for
generating content. The classes are fields match the API model directly.

Add support for overriding the `tools` for a specific generate content
call.
  • Loading branch information
natebosch authored Apr 17, 2024
1 parent fe0ffa8 commit 95b2cfb
Show file tree
Hide file tree
Showing 7 changed files with 269 additions and 21 deletions.
7 changes: 7 additions & 0 deletions pkgs/google_generative_ai/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
## 0.3.1-wip

- Add support on content generating methods for overriding "tools" passed when
the generative model was instantiated.
- Add support for forcing the model to use or not use function calls to generate
content.

## 0.3.0

- Allow specifying an API version in a `requestOptions` argument when
Expand Down
9 changes: 8 additions & 1 deletion pkgs/google_generative_ai/lib/google_generative_ai.dart
Original file line number Diff line number Diff line change
Expand Up @@ -75,5 +75,12 @@ export 'src/error.dart'
ServerException,
UnsupportedUserLocation;
export 'src/function_calling.dart'
show FunctionDeclaration, Schema, SchemaType, Tool;
show
FunctionCallingConfig,
FunctionCallingMode,
FunctionDeclaration,
Schema,
SchemaType,
Tool,
ToolConfig;
export 'src/model.dart' show GenerativeModel, RequestOptions;
58 changes: 58 additions & 0 deletions pkgs/google_generative_ai/lib/src/function_calling.dart
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,64 @@ final class FunctionDeclaration {
};
}

final class ToolConfig {
final FunctionCallingConfig? functionCallingConfig;
ToolConfig({this.functionCallingConfig});

Map<String, Object?> toJson() => {
if (functionCallingConfig case final config?)
'functionCallingConfig': config.toJson(),
};
}

/// Configuration specifying how the model should use the functions provided as
/// tools.
final class FunctionCallingConfig {
/// The mode in which function calling should execute.
///
/// If null, the default behavior will match [FunctionCallingMode.auto].
final FunctionCallingMode? mode;

/// A set of function names that, when provided, limits the functions the
/// model will call.
///
/// This should only be set when the Mode is [FunctionCallingMode.any].
/// Function names should match [FunctionDeclaration.name]. With mode set to
/// `any`, model will predict a function call from the set of function names
/// provided.
final Set<String>? allowedFunctionNames;
FunctionCallingConfig({this.mode, this.allowedFunctionNames});

Object toJson() => {
if (mode case final mode?) 'mode': mode.toJson(),
if (allowedFunctionNames case final allowedFunctionNames?)
'allowedFunctionNames': allowedFunctionNames.toList(),
};
}

enum FunctionCallingMode {
/// The mode with default model behavior.
///
/// Model decides to predict either a function call or a natural language
/// repspose.
auto,

/// A mode where the Model is constrained to always predicting a function
/// call only.
any,

/// A mode where the model will not predict any function call.
///
/// Model behavior is same as when not passing any function declarations.
none;

String toJson() => switch (this) {
auto => 'AUTO',
any => 'ANY',
none => 'NONE',
};
}

/// The definition of an input or output data types.
///
/// These types can be objects, but also primitives and arrays.
Expand Down
68 changes: 50 additions & 18 deletions pkgs/google_generative_ai/lib/src/model.dart
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ final class GenerativeModel {
final ApiClient _client;
final Uri _baseUri;
final Content? _systemInstruction;
final ToolConfig? _toolConfig;

/// Create a [GenerativeModel] backed by the generative model named [model].
///
Expand All @@ -86,7 +87,10 @@ final class GenerativeModel {
///
/// Functions that the model may call while generating content can be passed
/// in [tools]. When using tools [requestOptions] must be passed to
/// override the `apiVersion` to `v1beta`.
/// override the `apiVersion` to `v1beta`. Tool usage by the model can be
/// configured with [toolConfig]. Tools and tool configuration can be
/// overridden for individual requests with arguments to [generateContent] or
/// [generateContentStream].
///
/// A [Content.system] can be passed to [systemInstruction] to give
/// high priority instructions to the model. When using system instructions
Expand All @@ -100,6 +104,7 @@ final class GenerativeModel {
http.Client? httpClient,
RequestOptions? requestOptions,
Content? systemInstruction,
ToolConfig? toolConfig,
}) =>
GenerativeModel._withClient(
client: HttpApiClient(apiKey: apiKey, httpClient: httpClient),
Expand All @@ -109,6 +114,7 @@ final class GenerativeModel {
baseUri: _googleAIBaseUri(requestOptions),
tools: tools,
systemInstruction: systemInstruction,
toolConfig: toolConfig,
);

GenerativeModel._withClient({
Expand All @@ -119,12 +125,14 @@ final class GenerativeModel {
required Uri baseUri,
required List<Tool>? tools,
required Content? systemInstruction,
required ToolConfig? toolConfig,
}) : _model = _normalizeModelName(model),
_baseUri = baseUri,
_safetySettings = safetySettings,
_generationConfig = generationConfig,
_tools = tools,
_systemInstruction = systemInstruction,
_toolConfig = toolConfig,
_client = client;

/// Returns the model code for a user friendly model name.
Expand All @@ -146,24 +154,35 @@ final class GenerativeModel {
/// Sends a "generateContent" API request for the configured model,
/// and waits for the response.
///
/// The [safetySettings], [generationConfig], [tools], and [toolConfig],
/// override the arguments of the same name passed to the
/// [GenerativeModel.new] constructor. Each argument, when non-null,
/// overrides the model level configuration in its entirety.
///
/// Example:
/// ```dart
/// final response = await model.generateContent([Content.text(prompt)]);
/// print(response.text);
/// ```
Future<GenerateContentResponse> generateContent(Iterable<Content> prompt,
{List<SafetySetting>? safetySettings,
GenerationConfig? generationConfig}) async {
Future<GenerateContentResponse> generateContent(
Iterable<Content> prompt, {
List<SafetySetting>? safetySettings,
GenerationConfig? generationConfig,
List<Tool>? tools,
ToolConfig? toolConfig,
}) async {
safetySettings ??= _safetySettings;
generationConfig ??= _generationConfig;
tools ??= _tools;
toolConfig ??= _toolConfig;
final parameters = {
'contents': prompt.map((p) => p.toJson()).toList(),
if (safetySettings.isNotEmpty)
'safetySettings': safetySettings.map((s) => s.toJson()).toList(),
if (generationConfig case final config?)
'generationConfig': config.toJson(),
if (_tools case final tools?)
'tools': tools.map((t) => t.toJson()).toList(),
if (generationConfig != null)
'generationConfig': generationConfig.toJson(),
if (tools != null) 'tools': tools.map((t) => t.toJson()).toList(),
if (toolConfig != null) 'toolConfig': toolConfig.toJson(),
if (_systemInstruction case final systemInstruction?)
'systemInstruction': systemInstruction.toJson(),
};
Expand All @@ -177,6 +196,11 @@ final class GenerativeModel {
/// Sends a "streamGenerateContent" API request for the configured model,
/// and waits for the response.
///
/// The [safetySettings], [generationConfig], [tools], and [toolConfig],
/// override the arguments of the same name passed to the
/// [GenerativeModel.new] constructor. Each argument, when non-null,
/// overrides the model level configuration in its entirety.
///
/// Example:
/// ```dart
/// final responses = await model.generateContent([Content.text(prompt)]);
Expand All @@ -185,19 +209,24 @@ final class GenerativeModel {
/// }
/// ```
Stream<GenerateContentResponse> generateContentStream(
Iterable<Content> prompt,
{List<SafetySetting>? safetySettings,
GenerationConfig? generationConfig}) {
Iterable<Content> prompt, {
List<SafetySetting>? safetySettings,
GenerationConfig? generationConfig,
List<Tool>? tools,
ToolConfig? toolConfig,
}) {
safetySettings ??= _safetySettings;
generationConfig ??= _generationConfig;
tools ??= _tools;
toolConfig ??= _toolConfig;
final parameters = <String, Object?>{
'contents': prompt.map((p) => p.toJson()).toList(),
if (safetySettings.isNotEmpty)
'safetySettings': safetySettings.map((s) => s.toJson()).toList(),
if (generationConfig case final config?)
'generationConfig': config.toJson(),
if (_tools case final tools?)
'tools': tools.map((t) => t.toJson()).toList(),
if (generationConfig != null)
'generationConfig': generationConfig.toJson(),
if (tools != null) 'tools': tools.map((t) => t.toJson()).toList(),
if (toolConfig != null) 'toolConfig': toolConfig.toJson(),
if (_systemInstruction case final systemInstruction?)
'systemInstruction': systemInstruction.toJson(),
};
Expand Down Expand Up @@ -290,17 +319,19 @@ GenerativeModel createModelWithClient({
List<SafetySetting> safetySettings = const [],
GenerationConfig? generationConfig,
RequestOptions? requestOptions,
List<Tool>? tools,
Content? systemInstruction,
List<Tool>? tools,
ToolConfig? toolConfig,
}) =>
GenerativeModel._withClient(
client: client,
model: model,
safetySettings: safetySettings,
generationConfig: generationConfig,
baseUri: _googleAIBaseUri(requestOptions),
tools: tools,
systemInstruction: systemInstruction,
tools: tools,
toolConfig: toolConfig,
);

/// Creates a model with an overridden base URL to communicate with a different
Expand All @@ -323,6 +354,7 @@ GenerativeModel createModelWithBaseUri({
safetySettings: safetySettings,
generationConfig: generationConfig,
baseUri: baseUri,
tools: null,
systemInstruction: systemInstruction,
tools: null,
toolConfig: null,
);
2 changes: 1 addition & 1 deletion pkgs/google_generative_ai/lib/src/version.dart
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,4 @@
// See the License for the specific language governing permissions and
// limitations under the License.

const packageVersion = '0.3.0';
const packageVersion = '0.3.1-wip';
2 changes: 1 addition & 1 deletion pkgs/google_generative_ai/pubspec.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name: google_generative_ai
# Update `lib/version.dart` when changing version.
version: 0.3.0
version: 0.3.1-wip
description: >-
The Google AI Dart SDK enables developers to use Google's state-of-the-art
generative AI models (like Gemini).
Expand Down
Loading

0 comments on commit 95b2cfb

Please sign in to comment.