Skip to content

Commit

Permalink
Add support for overriding API version (#114)
Browse files Browse the repository at this point in the history
This API will be required for authors to use to try function calling
while the backend support is still in version `v1beta`.

Add a RequestOptions class. For now this supports only the `apiVersion`
configuration. In the future this may add a `timeout` (#44) if implement
a deeper HTTP timeout functionality than `Future.timeout`.
  • Loading branch information
natebosch authored Apr 9, 2024
1 parent 5ea809c commit 20e96b9
Show file tree
Hide file tree
Showing 6 changed files with 81 additions and 18 deletions.
5 changes: 5 additions & 0 deletions pkgs/google_generative_ai/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
## 0.2.4-wip

- Allow specifying an API version in a `requestOptions` argument when
constructing a model.

## 0.2.3

- Update the package version that is sent with the HTTP client name.
Expand Down
2 changes: 1 addition & 1 deletion pkgs/google_generative_ai/lib/google_generative_ai.dart
Original file line number Diff line number Diff line change
Expand Up @@ -66,4 +66,4 @@ export 'src/error.dart'
InvalidApiKey,
ServerException,
UnsupportedUserLocation;
export 'src/model.dart' show GenerativeModel;
export 'src/model.dart' show GenerativeModel, RequestOptions;
43 changes: 30 additions & 13 deletions pkgs/google_generative_ai/lib/src/model.dart
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ import 'client.dart';
import 'content.dart';

const _apiVersion = 'v1';
final _googleAIBaseUri =
Uri.https('generativelanguage.googleapis.com', _apiVersion);
Uri _googleAIBaseUri(RequestOptions? options) => Uri.https(
'generativelanguage.googleapis.com', options?.apiVersion ?? _apiVersion);

enum Task {
generateContent,
Expand All @@ -32,6 +32,19 @@ enum Task {
batchEmbedContents;
}

/// Configuration for how a [GenerativeModel] makes requests.
///
/// This allows overriding the API version in use which may be required to use
/// some beta features.
final class RequestOptions {
/// The API version used to make requests.
///
/// By default the version is `v1`. This may be specified as `v1beta` to use
/// beta features.
final String? apiVersion;
const RequestOptions({this.apiVersion});
}

/// A multimodel generative model (like Gemini).
///
/// Allows generating content, creating embeddings, and counting the number of
Expand Down Expand Up @@ -73,13 +86,14 @@ final class GenerativeModel {
List<SafetySetting> safetySettings = const [],
GenerationConfig? generationConfig,
http.Client? httpClient,
RequestOptions? requestOptions,
}) =>
GenerativeModel._withClient(
client: HttpApiClient(apiKey: apiKey, httpClient: httpClient),
model: model,
safetySettings: safetySettings,
generationConfig: generationConfig,
baseUri: _googleAIBaseUri,
baseUri: _googleAIBaseUri(requestOptions),
);

GenerativeModel._withClient({
Expand Down Expand Up @@ -243,17 +257,20 @@ final class GenerativeModel {
/// Creates a model with an overridden [ApiClient] for testing.
///
/// Package private test-only method.
GenerativeModel createModelWithClient(
{required String model,
required ApiClient client,
List<SafetySetting> safetySettings = const [],
GenerationConfig? generationConfig}) =>
GenerativeModel createModelWithClient({
required String model,
required ApiClient client,
List<SafetySetting> safetySettings = const [],
GenerationConfig? generationConfig,
RequestOptions? requestOptions,
}) =>
GenerativeModel._withClient(
client: client,
model: model,
safetySettings: safetySettings,
generationConfig: generationConfig,
baseUri: _googleAIBaseUri);
client: client,
model: model,
safetySettings: safetySettings,
generationConfig: generationConfig,
baseUri: _googleAIBaseUri(requestOptions),
);

/// Creates a model with an overridden base URL to communicate with a different
/// backend.
Expand Down
2 changes: 1 addition & 1 deletion pkgs/google_generative_ai/lib/src/version.dart
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,4 @@
// See the License for the specific language governing permissions and
// limitations under the License.

const packageVersion = '0.2.3';
const packageVersion = '0.2.4-wip';
2 changes: 1 addition & 1 deletion pkgs/google_generative_ai/pubspec.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
name: google_generative_ai
version: 0.2.3
version: 0.2.4-wip
description: >-
The Google AI Dart SDK enables developers to use Google's state-of-the-art
generative AI models (like Gemini).
Expand Down
45 changes: 43 additions & 2 deletions pkgs/google_generative_ai/test/generative_model_test.dart
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,10 @@ void main() {
const defaultModelName = 'some-model';

(StubClient, GenerativeModel) createModel(
[String modelName = defaultModelName]) {
[String modelName = defaultModelName, RequestOptions? requestOptions]) {
final client = StubClient();
final model = createModelWithClient(model: modelName, client: client);
final model = createModelWithClient(
model: modelName, client: client, requestOptions: requestOptions);
return (client, model);
}

Expand Down Expand Up @@ -108,6 +109,46 @@ void main() {
], null)));
});

test('allows specifying an API version', () async {
final (client, model) = createModel(
defaultModelName, RequestOptions(apiVersion: 'override_version'));
final prompt = 'Some prompt';
final result = 'Some response';
client.stub(
Uri.parse('https://generativelanguage.googleapis.com/override_version/'
'models/some-model:generateContent'),
{
'contents': [
{
'role': 'user',
'parts': [
{'text': prompt}
]
}
]
},
{
'candidates': [
{
'content': {
'role': 'model',
'parts': [
{'text': result}
]
}
}
]
},
);
final response = await model.generateContent([Content.text(prompt)]);
expect(
response,
matchesGenerateContentResponse(GenerateContentResponse([
Candidate(
Content('model', [TextPart(result)]), null, null, null, null),
], null)));
});

group('generate unary content', () {
test('can make successful request', () async {
final (client, model) = createModel();
Expand Down

0 comments on commit 20e96b9

Please sign in to comment.