Add support for overriding API version
Authors will need this API to try function calling while the backend
support is still in version `v1beta`.

Add a `RequestOptions` class. For now it supports only the `apiVersion`
configuration. In the future it may add a `timeout` (#44) if we implement
deeper HTTP timeout functionality than `Future.timeout`.
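For example, a caller could opt into the `v1beta` surface like this (a
minimal sketch; the model name, API key handling, and prompt are
placeholders, not part of this change):

```dart
import 'package:google_generative_ai/google_generative_ai.dart';

void main() async {
  // Hypothetical usage: pin the request path to `v1beta`, e.g. to try
  // function calling before it is available on `v1`.
  final model = GenerativeModel(
    model: 'gemini-pro',
    apiKey: const String.fromEnvironment('API_KEY'),
    requestOptions: const RequestOptions(apiVersion: 'v1beta'),
  );
  final response = await model.generateContent([Content.text('Hello')]);
  print(response.text);
}
```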

Add a utility for the Vertex AI SDK to pass a base URI callback instead
of a single URI.
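A rough sketch of how a consumer such as the Vertex AI SDK might use the
callback-based helper through a `src/` import (the helper name comes from
this change; the function name and host below are placeholders, not the
real Vertex endpoint):

```dart
// Package-private helper; only intended for use by the Vertex AI SDK.
import 'package:google_generative_ai/src/model.dart';

GenerativeModel buildBackendModel(String apiKey) =>
    createModelWithVersionedBaseUri(
      model: 'gemini-pro',
      apiKey: apiKey,
      // The callback lets the caller fold the requested API version into
      // its own URI scheme instead of receiving a fixed base URI.
      baseUri: (options) => Uri.https(
          'example-backend.googleapis.com', options?.apiVersion ?? 'v1'),
      requestOptions: const RequestOptions(apiVersion: 'v1beta'),
    );
```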
natebosch committed Apr 9, 2024
1 parent 5ea809c commit ccb1a65
Showing 5 changed files with 122 additions and 33 deletions.
4 changes: 4 additions & 0 deletions pkgs/google_generative_ai/CHANGELOG.md
@@ -1,3 +1,7 @@
## 0.2.4-wip

- Allow specifying an API version.

## 0.2.3

- Update the package version that is sent with the HTTP client name.
2 changes: 1 addition & 1 deletion pkgs/google_generative_ai/lib/google_generative_ai.dart
@@ -66,4 +66,4 @@ export 'src/error.dart'
InvalidApiKey,
ServerException,
UnsupportedUserLocation;
export 'src/model.dart' show GenerativeModel;
export 'src/model.dart' show GenerativeModel, RequestOptions;
102 changes: 73 additions & 29 deletions pkgs/google_generative_ai/lib/src/model.dart
@@ -21,8 +21,8 @@ import 'client.dart';
import 'content.dart';

const _apiVersion = 'v1';
final _googleAIBaseUri =
Uri.https('generativelanguage.googleapis.com', _apiVersion);
Uri _googleAIBaseUri(RequestOptions? options) => Uri.https(
'generativelanguage.googleapis.com', options?.apiVersion ?? _apiVersion);

enum Task {
generateContent,
@@ -32,6 +32,11 @@ enum Task {
batchEmbedContents;
}

class RequestOptions {
final String? apiVersion;
const RequestOptions({this.apiVersion});
}

/// A multimodal generative model (like Gemini).
///
/// Allows generating content, creating embeddings, and counting the number of
@@ -43,7 +48,8 @@ final class GenerativeModel {
final List<SafetySetting> _safetySettings;
final GenerationConfig? _generationConfig;
final ApiClient _client;
final Uri _baseUri;
final RequestOptions? _requestOptions;
final Uri Function(RequestOptions?) _baseUri;

/// Create a [GenerativeModel] backed by the generative model named [model].
///
@@ -73,26 +79,30 @@
List<SafetySetting> safetySettings = const [],
GenerationConfig? generationConfig,
http.Client? httpClient,
RequestOptions? requestOptions,
}) =>
GenerativeModel._withClient(
client: HttpApiClient(apiKey: apiKey, httpClient: httpClient),
model: model,
safetySettings: safetySettings,
generationConfig: generationConfig,
baseUri: _googleAIBaseUri,
requestOptions: requestOptions,
);

GenerativeModel._withClient({
required ApiClient client,
required String model,
required List<SafetySetting> safetySettings,
required GenerationConfig? generationConfig,
required Uri baseUri,
required Uri Function(RequestOptions?) baseUri,
required RequestOptions? requestOptions,
}) : _model = _normalizeModelName(model),
_baseUri = baseUri,
_safetySettings = safetySettings,
_generationConfig = generationConfig,
_client = client;
_client = client,
_requestOptions = requestOptions;

/// Returns the model code for a user friendly model name.
///
@@ -104,9 +114,12 @@
return (prefix: parts.first, name: parts.skip(1).join('/'));
}

Uri _taskUri(Task task) => _baseUri.replace(
pathSegments: _baseUri.pathSegments
.followedBy([_model.prefix, '${_model.name}:${task.name}']));
Uri _taskUri(Task task) {
final baseUri = _baseUri(_requestOptions);
return baseUri.replace(
pathSegments: baseUri.pathSegments
.followedBy([_model.prefix, '${_model.name}:${task.name}']));
}

/// Generates content responding to [prompt].
///
@@ -243,33 +256,64 @@
/// Creates a model with an overridden [ApiClient] for testing.
///
/// Package private test-only method.
GenerativeModel createModelWithClient(
{required String model,
required ApiClient client,
List<SafetySetting> safetySettings = const [],
GenerationConfig? generationConfig}) =>
GenerativeModel createModelWithClient({
required String model,
required ApiClient client,
List<SafetySetting> safetySettings = const [],
GenerationConfig? generationConfig,
RequestOptions? requestOptions,
}) =>
GenerativeModel._withClient(
client: client,
model: model,
safetySettings: safetySettings,
generationConfig: generationConfig,
baseUri: _googleAIBaseUri);
client: client,
model: model,
safetySettings: safetySettings,
generationConfig: generationConfig,
baseUri: _googleAIBaseUri,
requestOptions: requestOptions,
);

/// Creates a model with an overridden base URL to communicate with a different
/// backend.
///
/// Used from a `src/` import in the Vertex AI SDK.
// TODO: https://github.com/google/generative-ai-dart/issues/111 - Changes to
// this API need to be coordinated with the vertex AI SDK.
GenerativeModel createModelWithBaseUri(
{required String model,
required String apiKey,
required Uri baseUri,
List<SafetySetting> safetySettings = const [],
GenerationConfig? generationConfig}) =>
GenerativeModel createModelWithBaseUri({
required String model,
required String apiKey,
required Uri baseUri,
List<SafetySetting> safetySettings = const [],
GenerationConfig? generationConfig,
RequestOptions? requestOptions,
}) =>
GenerativeModel._withClient(
client: HttpApiClient(apiKey: apiKey),
model: model,
safetySettings: safetySettings,
generationConfig: generationConfig,
baseUri: baseUri);
client: HttpApiClient(apiKey: apiKey),
model: model,
safetySettings: safetySettings,
generationConfig: generationConfig,
baseUri: (_) => baseUri,
requestOptions: requestOptions,
);

/// Creates a model with an overridden base URL callback to communicate with a
/// different backend.
///
/// Used from a `src/` import in the Vertex AI SDK.
// TODO: https://github.com/google/generative-ai-dart/issues/111 - Changes to
// this API need to be coordinated with the vertex AI SDK.
GenerativeModel createModelWithVersionedBaseUri({
required String model,
required String apiKey,
required Uri Function(RequestOptions?) baseUri,
List<SafetySetting> safetySettings = const [],
GenerationConfig? generationConfig,
RequestOptions? requestOptions,
}) =>
GenerativeModel._withClient(
client: HttpApiClient(apiKey: apiKey),
model: model,
safetySettings: safetySettings,
generationConfig: generationConfig,
baseUri: baseUri,
requestOptions: requestOptions,
);
2 changes: 1 addition & 1 deletion pkgs/google_generative_ai/pubspec.yaml
@@ -1,5 +1,5 @@
name: google_generative_ai
version: 0.2.3
version: 0.2.4-wip
description: >-
The Google AI Dart SDK enables developers to use Google's state-of-the-art
generative AI models (like Gemini).
45 changes: 43 additions & 2 deletions pkgs/google_generative_ai/test/generative_model_test.dart
@@ -24,9 +24,10 @@ void main() {
const defaultModelName = 'some-model';

(StubClient, GenerativeModel) createModel(
[String modelName = defaultModelName]) {
[String modelName = defaultModelName, RequestOptions? requestOptions]) {
final client = StubClient();
final model = createModelWithClient(model: modelName, client: client);
final model = createModelWithClient(
model: modelName, client: client, requestOptions: requestOptions);
return (client, model);
}

@@ -108,6 +109,46 @@ void main() {
], null)));
});

test('allows specifying an API version', () async {
final (client, model) = createModel(
defaultModelName, RequestOptions(apiVersion: 'override_version'));
final prompt = 'Some prompt';
final result = 'Some response';
client.stub(
Uri.parse('https://generativelanguage.googleapis.com/override_version/'
'models/some-model:generateContent'),
{
'contents': [
{
'role': 'user',
'parts': [
{'text': prompt}
]
}
]
},
{
'candidates': [
{
'content': {
'role': 'model',
'parts': [
{'text': result}
]
}
}
]
},
);
final response = await model.generateContent([Content.text(prompt)]);
expect(
response,
matchesGenerateContentResponse(GenerateContentResponse([
Candidate(
Content('model', [TextPart(result)]), null, null, null, null),
], null)));
});

group('generate unary content', () {
test('can make successful request', () async {
final (client, model) = createModel();
