Skip to content

Commit

Permalink
Add support for function calling (#116)
Browse files Browse the repository at this point in the history
- Add `FunctionCall` and `FunctionResponse` subclasses of the sealed
  class `Part`.
- Add a `Content.functionResponse` utility to cover the common case of
  constructing a content with a single function response part.
- Add `Tool`, `FunctionDeclaration`, `Schema`, and `SchemaType` types
  with `toJson()` support matching the REST API to declare the tools and
  the JSON schema for the function signatures.
- Add `tools` constructor argument to `GenerativeModel`.
- Add an example showing how authors can dispatch to Dart callbacks.
  • Loading branch information
natebosch authored Apr 10, 2024
1 parent 1c34637 commit cf0babc
Show file tree
Hide file tree
Showing 7 changed files with 357 additions and 13 deletions.
5 changes: 4 additions & 1 deletion pkgs/google_generative_ai/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@
- Allow specifying an API version in a `requestOptions` argument when
constructing a model.
- Add support for referring to uploaded files in request contents.
- **Breaking** Added a new subclass `FilePart` of the sealed class `Part`.
- Add support for passing tools with functions the model may call while
generating responses.
- **Breaking** Added new subclasses `FilePart`, `FunctionCall`, and
`FunctionResponse` of the sealed class `Part`.

## 0.2.3

Expand Down
12 changes: 11 additions & 1 deletion pkgs/google_generative_ai/lib/google_generative_ai.dart
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,21 @@ export 'src/api.dart'
SafetySetting,
TaskType;
export 'src/chat.dart' show ChatSession, StartChatExtension;
export 'src/content.dart' show Content, DataPart, FilePart, Part, TextPart;
export 'src/content.dart'
show
Content,
DataPart,
FilePart,
FunctionCall,
FunctionResponse,
Part,
TextPart;
export 'src/error.dart'
show
GenerativeAIException,
InvalidApiKey,
ServerException,
UnsupportedUserLocation;
export 'src/function_calling.dart'
show FunctionDeclaration, Schema, SchemaType, Tool;
export 'src/model.dart' show GenerativeModel, RequestOptions;
51 changes: 51 additions & 0 deletions pkgs/google_generative_ai/lib/src/content.dart
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ final class Content {
Content('user', [DataPart(mimeType, bytes)]);
static Content multi(Iterable<Part> parts) => Content('user', [...parts]);
static Content model(Iterable<Part> parts) => Content('model', [...parts]);
static Content functionResponse(
String name, Map<String, Object?>? response) =>
Content('function', [FunctionResponse(name, response)]);

Map<String, Object?> toJson() => {
if (role case final role?) 'role': role,
Expand All @@ -57,6 +60,17 @@ Content parseContent(Object jsonObject) {
Part _parsePart(Object? jsonObject) {
return switch (jsonObject) {
{'text': final String text} => TextPart(text),
{
'functionCall': {
'name': final String name,
'args': final Map<String, Object?> args
}
} =>
FunctionCall(name, args),
{
'functionResponse': {'name': String _, 'response': Map<String, Object?> _}
} =>
throw UnimplementedError('FunctionResponse part not yet supported'),
{'inlineData': {'mimeType': String _, 'data': String _}} =>
throw UnimplementedError('inlineData content part not yet supported'),
_ => throw FormatException('Unhandled Part format', jsonObject),
Expand Down Expand Up @@ -97,3 +111,40 @@ final class FilePart implements Part {
'file_data': {'file_uri': '$uri'}
};
}

/// A predicted `FunctionCall` returned from the model that contains
/// a string representing the `FunctionDeclaration.name` with the
/// arguments and their values.
final class FunctionCall implements Part {
/// The name of the function to call.
final String name;

/// The function parameters and values.
final Map<String, Object?> args;

FunctionCall(this.name, this.args);

@override
// TODO: Do we need the wrapper object?
Object toJson() => {
'functionCall': {'name': name, 'args': args}
};
}

final class FunctionResponse implements Part {
/// The name of the function that was called.
final String name;

/// The function response.
///
/// The values must be JSON compatible types; `String`, `num`, `bool`, `List`
/// of JSON compatibles types, or `Map` from String to JSON compatible types.
final Map<String, Object?>? response;

FunctionResponse(this.name, this.response);

@override
Object toJson() => {
'functionResponse': {'name': name, 'response': response}
};
}
156 changes: 156 additions & 0 deletions pkgs/google_generative_ai/lib/src/function_calling.dart
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

import 'content.dart';

/// Tool details that the model may use to generate a response.
///
/// A `Tool` is a piece of code that enables the system to interact with
/// external systems to perform an action, or set of actions, outside of
/// knowledge and scope of the model.
final class Tool {
/// A list of `FunctionDeclarations` available to the model that can be used
/// for function calling.
///
/// The model or system does not execute the function. Instead the defined
/// function may be returned as a [FunctionCall] with arguments to the client
/// side for execution. The next conversation turn may contain a
/// [FunctionResponse]
/// with the role "function" generation context for the next model turn.
final List<FunctionDeclaration>? functionDeclarations;

Tool({this.functionDeclarations});

Map<String, Object> toJson() => {
if (functionDeclarations case final functionDeclarations?)
'functionDeclarations':
functionDeclarations.map((f) => f.toJson()).toList(),
};
}

/// Structured representation of a function declaration as defined by the
/// [OpenAPI 3.03 specification](https://spec.openapis.org/oas/v3.0.3).
///
/// Included in this declaration are the function name and parameters. This
/// FunctionDeclaration is a representation of a block of code that can be used
/// as a `Tool` by the model and executed by the client.
final class FunctionDeclaration {
/// The name of the function.
///
/// Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum
/// length of 63.
final String name;

/// A brief description of the function.
final String description;

final Schema? parameters;

FunctionDeclaration(this.name, this.description, this.parameters);

Map<String, Object?> toJson() => {
'name': name,
'description': description,
if (parameters case final parameters?) 'parameters': parameters.toJson()
};
}

/// The definition of an input or output data types.
///
/// These types can be objects, but also primitives and arrays.
/// Represents a select subset of an
/// [OpenAPI 3.0 schema object](https://spec.openapis.org/oas/v3.0.3#schema).
final class Schema {
/// The type of this value.
SchemaType type;

/// The format of the data.
///
/// This is used only for primitive datatypes.
///
/// Supported formats:
/// for [SchemaType.number] type: float, double
/// for [SchemaType.integer] type: int32, int64
/// for [SchemaType.string] type: enum. See [enumValues]
String? format;

/// A brief description of the parameter.
///
/// This could contain examples of use.
/// Parameter description may be formatted as Markdown.
String? description;

/// Whether the value mey be null.
bool? nullable;

/// Possible values if this is a [SchemaType.string] with an enum format.
List<String>? enumValues;

/// Schema for the elements if this is a [SchemaType.array].
Schema? items;

/// Properties of this type if this is a [SchemaType.object].
Map<String, Schema>? properties;

/// The keys from [properties] for properties that are required if this is a
/// [SchemaType.object].
List<String>? requiredProperties;

// TODO: Add named constructors for the types?
Schema(
this.type, {
this.format,
this.description,
this.nullable,
this.enumValues,
this.items,
this.properties,
this.requiredProperties,
});

Map<String, Object> toJson() => {
'type': type.toJson(),
if (format case final format?) 'format': format,
if (description case final description?) 'description': description,
if (nullable case final nullable?) 'nullable': nullable,
if (enumValues case final enumValues?) 'enum': enumValues,
if (items case final items?) 'items': items.toJson(),
if (properties case final properties?)
'properties': {
for (final MapEntry(:key, :value) in properties.entries)
key: value.toJson()
},
if (requiredProperties case final requiredProperties?)
'required': requiredProperties
};
}

/// The value type of a [Schema].
enum SchemaType {
string,
number,
integer,
boolean,
array,
object;

String toJson() => switch (this) {
string => 'STRING',
number => 'NUMBER',
integer => 'INTEGER',
boolean => 'BOOLEAN',
array => 'ARRAY',
object => 'OBJECT',
};
}
41 changes: 30 additions & 11 deletions pkgs/google_generative_ai/lib/src/model.dart
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import 'package:http/http.dart' as http;
import 'api.dart';
import 'client.dart';
import 'content.dart';
import 'function_calling.dart';

const _apiVersion = 'v1';
Uri _googleAIBaseUri(RequestOptions? options) => Uri.https(
Expand Down Expand Up @@ -55,6 +56,7 @@ final class GenerativeModel {
final ({String prefix, String name}) _model;
final List<SafetySetting> _safetySettings;
final GenerationConfig? _generationConfig;
final List<Tool>? _tools;
final ApiClient _client;
final Uri _baseUri;

Expand All @@ -80,11 +82,16 @@ final class GenerativeModel {
/// concurrent requests.
/// If the `httpClient` is omitted, a new [http.Client] is created for each
/// request.
///
/// Functions that the model may call while generating content can be passed
/// in [tools]. When using tools the [requestOptions] must be passed to
/// override the `apiVersion` to `v1beta`.
factory GenerativeModel({
required String model,
required String apiKey,
List<SafetySetting> safetySettings = const [],
GenerationConfig? generationConfig,
List<Tool>? tools,
http.Client? httpClient,
RequestOptions? requestOptions,
}) =>
Expand All @@ -94,6 +101,7 @@ final class GenerativeModel {
safetySettings: safetySettings,
generationConfig: generationConfig,
baseUri: _googleAIBaseUri(requestOptions),
tools: tools,
);

GenerativeModel._withClient({
Expand All @@ -102,10 +110,12 @@ final class GenerativeModel {
required List<SafetySetting> safetySettings,
required GenerationConfig? generationConfig,
required Uri baseUri,
required List<Tool>? tools,
}) : _model = _normalizeModelName(model),
_baseUri = baseUri,
_safetySettings = safetySettings,
_generationConfig = generationConfig,
_tools = tools,
_client = client;

/// Returns the model code for a user friendly model name.
Expand Down Expand Up @@ -143,6 +153,8 @@ final class GenerativeModel {
'safetySettings': safetySettings.map((s) => s.toJson()).toList(),
if (generationConfig case final config?)
'generationConfig': config.toJson(),
if (_tools case final tools?)
'tools': tools.map((t) => t.toJson()).toList(),
};
final response =
await _client.makeRequest(_taskUri(Task.generateContent), parameters);
Expand Down Expand Up @@ -173,6 +185,8 @@ final class GenerativeModel {
'safetySettings': safetySettings.map((s) => s.toJson()).toList(),
if (generationConfig case final config?)
'generationConfig': config.toJson(),
if (_tools case final tools?)
'tools': tools.map((t) => t.toJson()).toList(),
};
final response =
_client.streamRequest(_taskUri(Task.streamGenerateContent), parameters);
Expand Down Expand Up @@ -263,13 +277,15 @@ GenerativeModel createModelWithClient({
List<SafetySetting> safetySettings = const [],
GenerationConfig? generationConfig,
RequestOptions? requestOptions,
List<Tool>? tools,
}) =>
GenerativeModel._withClient(
client: client,
model: model,
safetySettings: safetySettings,
generationConfig: generationConfig,
baseUri: _googleAIBaseUri(requestOptions),
tools: tools,
);

/// Creates a model with an overridden base URL to communicate with a different
Expand All @@ -278,15 +294,18 @@ GenerativeModel createModelWithClient({
/// Used from a `src/` import in the Vertex AI SDK.
// TODO: https://github.com/google/generative-ai-dart/issues/111 - Changes to
// this API need to be coordinated with the vertex AI SDK.
GenerativeModel createModelWithBaseUri(
{required String model,
required String apiKey,
required Uri baseUri,
List<SafetySetting> safetySettings = const [],
GenerationConfig? generationConfig}) =>
GenerativeModel createModelWithBaseUri({
required String model,
required String apiKey,
required Uri baseUri,
List<SafetySetting> safetySettings = const [],
GenerationConfig? generationConfig,
}) =>
GenerativeModel._withClient(
client: HttpApiClient(apiKey: apiKey),
model: model,
safetySettings: safetySettings,
generationConfig: generationConfig,
baseUri: baseUri);
client: HttpApiClient(apiKey: apiKey),
model: model,
safetySettings: safetySettings,
generationConfig: generationConfig,
baseUri: baseUri,
tools: null,
);
11 changes: 11 additions & 0 deletions pkgs/google_generative_ai/test/utils/matchers.dart
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,17 @@ Matcher matchesPart(Part part) => switch (part) {
// TODO: When updating min SDK remove ignore.
// ignore: unused_result, implementation bug
.having((p) => p.uri, 'uri', uri),
FunctionCall(name: final name, args: final args) => isA<FunctionCall>()
.having((p) => p.name, 'name', name)
// TODO: When updating min SDK remove ignore.
// ignore: unused_result, implementation bug
.having((p) => p.args, 'args', args),
FunctionResponse(name: final name, response: final response) =>
isA<FunctionResponse>()
.having((p) => p.name, 'name', name)
// TODO: When updating min SDK remove ignore.
// ignore: unused_result, implementation bug
.having((p) => p.response, 'args', response),
};

Matcher matchesContent(Content content) => isA<Content>()
Expand Down
Loading

0 comments on commit cf0babc

Please sign in to comment.