Skip to content

Commit

Permalink
Add support for system instructions (#119)
Browse files Browse the repository at this point in the history
- Add a `systemInstruction` constructor argument on `GenerativeModel`.
  Accept the `Content` type instead of a String to directly match the
  REST API and to avoid breaking changes if the backend gains the
  flexibility to honor more than a single TextPart for system
  instructions.
- Add a `Content.system` utility with the single string argument to
  match the usage which works with the backend today.
  • Loading branch information
natebosch authored Apr 11, 2024
1 parent af723c8 commit dbc812b
Show file tree
Hide file tree
Showing 4 changed files with 85 additions and 7 deletions.
1 change: 1 addition & 0 deletions pkgs/google_generative_ai/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
- Add support for referring to uploaded files in request contents.
- Add support for passing tools with functions the model may call while
generating responses.
- Add support for passing a system instruction when creating the model.
- **Breaking** Added new subclasses `FilePart`, `FunctionCall`, and
`FunctionResponse` of the sealed class `Part`.

Expand Down
2 changes: 2 additions & 0 deletions pkgs/google_generative_ai/lib/src/content.dart
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ final class Content {
static Content functionResponse(
String name, Map<String, Object?>? response) =>
Content('function', [FunctionResponse(name, response)]);
static Content system(String instructions) =>
Content('system', [TextPart(instructions)]);

Map<String, Object?> toJson() => {
if (role case final role?) 'role': role,
Expand Down
19 changes: 18 additions & 1 deletion pkgs/google_generative_ai/lib/src/model.dart
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ final class GenerativeModel {
final List<Tool>? _tools;
final ApiClient _client;
final Uri _baseUri;
final Content? _systemInstruction;

/// Create a [GenerativeModel] backed by the generative model named [model].
///
Expand All @@ -84,8 +85,12 @@ final class GenerativeModel {
/// request.
///
/// Functions that the model may call while generating content can be passed
/// in [tools]. When using tools the [requestOptions] must be passed to
/// in [tools]. When using tools [requestOptions] must be passed to
/// override the `apiVersion` to `v1beta`.
///
/// A [Content.system] can be passed to [systemInstruction] to give
/// high priority instructions to the model. When using system instructions
/// [requestOptions] must be passed to override the `apiVersion` to `v1beta`.
factory GenerativeModel({
required String model,
required String apiKey,
Expand All @@ -94,6 +99,7 @@ final class GenerativeModel {
List<Tool>? tools,
http.Client? httpClient,
RequestOptions? requestOptions,
Content? systemInstruction,
}) =>
GenerativeModel._withClient(
client: HttpApiClient(apiKey: apiKey, httpClient: httpClient),
Expand All @@ -102,6 +108,7 @@ final class GenerativeModel {
generationConfig: generationConfig,
baseUri: _googleAIBaseUri(requestOptions),
tools: tools,
systemInstruction: systemInstruction,
);

GenerativeModel._withClient({
Expand All @@ -111,11 +118,13 @@ final class GenerativeModel {
required GenerationConfig? generationConfig,
required Uri baseUri,
required List<Tool>? tools,
required Content? systemInstruction,
}) : _model = _normalizeModelName(model),
_baseUri = baseUri,
_safetySettings = safetySettings,
_generationConfig = generationConfig,
_tools = tools,
_systemInstruction = systemInstruction,
_client = client;

/// Returns the model code for a user friendly model name.
Expand Down Expand Up @@ -155,6 +164,8 @@ final class GenerativeModel {
'generationConfig': config.toJson(),
if (_tools case final tools?)
'tools': tools.map((t) => t.toJson()).toList(),
if (_systemInstruction case final systemInstruction?)
'systemInstruction': systemInstruction.toJson(),
};
final response =
await _client.makeRequest(_taskUri(Task.generateContent), parameters);
Expand Down Expand Up @@ -187,6 +198,8 @@ final class GenerativeModel {
'generationConfig': config.toJson(),
if (_tools case final tools?)
'tools': tools.map((t) => t.toJson()).toList(),
if (_systemInstruction case final systemInstruction?)
'systemInstruction': systemInstruction.toJson(),
};
final response =
_client.streamRequest(_taskUri(Task.streamGenerateContent), parameters);
Expand Down Expand Up @@ -278,6 +291,7 @@ GenerativeModel createModelWithClient({
GenerationConfig? generationConfig,
RequestOptions? requestOptions,
List<Tool>? tools,
Content? systemInstruction,
}) =>
GenerativeModel._withClient(
client: client,
Expand All @@ -286,6 +300,7 @@ GenerativeModel createModelWithClient({
generationConfig: generationConfig,
baseUri: _googleAIBaseUri(requestOptions),
tools: tools,
systemInstruction: systemInstruction,
);

/// Creates a model with an overridden base URL to communicate with a different
Expand All @@ -300,6 +315,7 @@ GenerativeModel createModelWithBaseUri({
required Uri baseUri,
List<SafetySetting> safetySettings = const [],
GenerationConfig? generationConfig,
Content? systemInstruction,
}) =>
GenerativeModel._withClient(
client: HttpApiClient(apiKey: apiKey),
Expand All @@ -308,4 +324,5 @@ GenerativeModel createModelWithBaseUri({
generationConfig: generationConfig,
baseUri: baseUri,
tools: null,
systemInstruction: systemInstruction,
);
70 changes: 64 additions & 6 deletions pkgs/google_generative_ai/test/generative_model_test.dart
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,24 @@ void main() {
group('GenerativeModel', () {
const defaultModelName = 'some-model';

(StubClient, GenerativeModel) createModel(
[String modelName = defaultModelName, RequestOptions? requestOptions]) {
(StubClient, GenerativeModel) createModel({
String modelName = defaultModelName,
RequestOptions? requestOptions,
Content? systemInstruction,
}) {
final client = StubClient();
final model = createModelWithClient(
model: modelName, client: client, requestOptions: requestOptions);
model: modelName,
client: client,
requestOptions: requestOptions,
systemInstruction: systemInstruction,
);
return (client, model);
}

test('strips leading "models/" from model name', () async {
final (client, model) = createModel('models/$defaultModelName');
final (client, model) =
createModel(modelName: 'models/$defaultModelName');
final prompt = 'Some prompt';
final result = 'Some response';
client.stub(
Expand Down Expand Up @@ -71,7 +79,8 @@ void main() {
});

test('allows specifying a tuned model', () async {
final (client, model) = createModel('tunedModels/$defaultModelName');
final (client, model) =
createModel(modelName: 'tunedModels/$defaultModelName');
final prompt = 'Some prompt';
final result = 'Some response';
client.stub(
Expand Down Expand Up @@ -111,7 +120,7 @@ void main() {

test('allows specifying an API version', () async {
final (client, model) = createModel(
defaultModelName, RequestOptions(apiVersion: 'override_version'));
requestOptions: RequestOptions(apiVersion: 'override_version'));
final prompt = 'Some prompt';
final result = 'Some response';
client.stub(
Expand Down Expand Up @@ -280,6 +289,55 @@ void main() {
Content('model', [TextPart(result)]), null, null, null, null),
], null)));
});

test('can pass system instructions', () async {
final instructions = 'Do a good job';
final (client, model) =
createModel(systemInstruction: Content.system(instructions));
final prompt = 'Some prompt';
final result = 'Some response';
client.stub(
Uri.parse('https://generativelanguage.googleapis.com/v1/'
'models/some-model:generateContent'),
{
'contents': [
{
'role': 'user',
'parts': [
{'text': prompt}
]
}
],
'systemInstruction': {
'role': 'system',
'parts': [
{'text': instructions}
],
},
},
{
'candidates': [
{
'content': {
'role': 'model',
'parts': [
{'text': result}
]
}
}
]
},
);
final response = await model.generateContent(
[Content.text(prompt)],
);
expect(
response,
matchesGenerateContentResponse(GenerateContentResponse([
Candidate(
Content('model', [TextPart(result)]), null, null, null, null),
], null)));
});
});

group('generate content stream', () {
Expand Down

0 comments on commit dbc812b

Please sign in to comment.