Skip to content

Commit

Permalink
Add a ClientController test utility (#155)
Browse files Browse the repository at this point in the history
With `StubClient` the responsibilities for checking arguments and
stubbing return values are overloaded to an `EqualityMap` checking the
full content of the Uri and JSON payload and using it as a key to choose
a stubbed response. This is verbose; details are repeated in every test
that are not relevant to the test. It also has hard to diagnose
failures; when an argument is incorrect it surfaces as a missing key
with no hint about what value in the payload is wrong.

Add a `ClientController` class which more directly addresses stubbing
and argument checking individually. When only stubbing is necessary the
`verifyRequest` argument can be omitted, and any request gets the
stubbed response. Expectations checked against the request URI and JSON
payload can be as shallow or deep as necessary for the given test.

Add an `arbitraryGenerateContentResponse` variable to fill in for
responses where the parsed output isn't tested.
  • Loading branch information
natebosch authored May 10, 2024
1 parent aae64f1 commit 5be1ca2
Show file tree
Hide file tree
Showing 4 changed files with 344 additions and 895 deletions.
1 change: 0 additions & 1 deletion pkgs/google_generative_ai/pubspec.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ dependencies:
http: ^1.1.0

dev_dependencies:
collection: ^1.18.0
lints: ^3.0.0
matcher: ^0.12.16
test: ^1.24.0
Expand Down
225 changes: 31 additions & 194 deletions pkgs/google_generative_ai/test/chat_test.dart
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,14 @@ void main() {
group('Chat', () {
const defaultModelName = 'some-model';

(StubClient, GenerativeModel) createModel([
(ClientController, GenerativeModel) createModel([
String modelName = defaultModelName,
]) {
final client = StubClient();
final model = createModelWithClient(model: modelName, client: client);
final client = ClientController();
final model = createModelWithClient(
model: modelName,
client: client.client,
);
return (client, model);
}

Expand All @@ -38,61 +41,13 @@ void main() {
Content.model([TextPart('Hello, how can I help you today?')]),
]);
final prompt = 'Some prompt';
final result = 'Some response';
client.stub(
Uri.parse(
'https://generativelanguage.googleapis.com/v1beta/'
'models/some-model:generateContent',
),
{
'contents': [
{
'role': 'user',
'parts': [
{'text': 'Hi!'},
],
},
{
'role': 'model',
'parts': [
{'text': 'Hello, how can I help you today?'},
],
},
{
'role': 'user',
'parts': [
{'text': prompt},
],
},
],
final response = await client.checkRequest(
() => chat.sendMessage(Content.text(prompt)),
verifyRequest: (_, request) {
final contents = request['contents'];
expect(contents, hasLength(3));
},
{
'candidates': [
{
'content': {
'role': 'model',
'parts': [
{'text': result},
],
},
},
],
},
);
final response = await chat.sendMessage(Content.text(prompt));
expect(
response,
matchesGenerateContentResponse(
GenerateContentResponse([
Candidate(
Content('model', [TextPart(result)]),
null,
null,
null,
null,
),
], null),
),
response: arbitraryGenerateContentResponse,
);
expect(
chat.history.last,
Expand All @@ -106,55 +61,17 @@ void main() {
SafetySetting(HarmCategory.dangerousContent, HarmBlockThreshold.high),
]);
final prompt = 'Some prompt';
final result = 'Some response';
client.stub(
Uri.parse(
'https://generativelanguage.googleapis.com/v1beta/'
'models/some-model:generateContent',
),
{
'contents': [
{
'role': 'user',
'parts': [
{'text': prompt},
],
},
],
'safetySettings': [
await client.checkRequest(
() => chat.sendMessage(Content.text(prompt)),
verifyRequest: (_, request) {
expect(request['safetySettings'], [
{
'category': 'HARM_CATEGORY_DANGEROUS_CONTENT',
'threshold': 'BLOCK_ONLY_HIGH',
},
],
},
{
'candidates': [
{
'content': {
'role': 'model',
'parts': [
{'text': result},
],
},
},
],
]);
},
);
final response = await chat.sendMessage(Content.text(prompt));
expect(
response,
matchesGenerateContentResponse(
GenerateContentResponse([
Candidate(
Content('model', [TextPart(result)]),
null,
null,
null,
null,
),
], null),
),
response: arbitraryGenerateContentResponse,
);
});

Expand All @@ -164,61 +81,19 @@ void main() {
SafetySetting(HarmCategory.dangerousContent, HarmBlockThreshold.high),
], generationConfig: GenerationConfig(stopSequences: ['a']));
final prompt = 'Some prompt';
final result = 'Some response';
client.stubStream(
Uri.parse(
'https://generativelanguage.googleapis.com/v1beta/'
'models/some-model:streamGenerateContent',
),
{
'contents': [
{
'role': 'user',
'parts': [
{'text': prompt},
],
},
],
'safetySettings': [
final responses = await client.checkStreamRequest(
() async => chat.sendMessageStream(Content.text(prompt)),
verifyRequest: (_, request) {
expect(request['safetySettings'], [
{
'category': 'HARM_CATEGORY_DANGEROUS_CONTENT',
'threshold': 'BLOCK_ONLY_HIGH',
},
],
'generationConfig': {
'stopSequences': ['a'],
},
]);
},
[
{
'candidates': [
{
'content': {
'role': 'model',
'parts': [
{'text': result},
],
},
},
],
},
],
responses: [arbitraryGenerateContentResponse],
);
final responses =
await chat.sendMessageStream(Content.text(prompt)).toList();
expect(responses, [
matchesGenerateContentResponse(
GenerateContentResponse([
Candidate(
Content('model', [TextPart(result)]),
null,
null,
null,
null,
),
], null),
),
]);
await responses.drain<void>();
});

test('forwards generation config', () async {
Expand All @@ -227,52 +102,14 @@ void main() {
generationConfig: GenerationConfig(stopSequences: ['a']),
);
final prompt = 'Some prompt';
final result = 'Some response';
client.stub(
Uri.parse(
'https://generativelanguage.googleapis.com/v1beta/'
'models/some-model:generateContent',
),
{
'contents': [
{
'role': 'user',
'parts': [
{'text': prompt},
],
},
],
'generationConfig': {
await client.checkRequest(
() => chat.sendMessage(Content.text(prompt)),
verifyRequest: (_, request) {
expect(request['generationConfig'], {
'stopSequences': ['a'],
},
});
},
{
'candidates': [
{
'content': {
'role': 'model',
'parts': [
{'text': result},
],
},
},
],
},
);
final response = await chat.sendMessage(Content.text(prompt));
expect(
response,
matchesGenerateContentResponse(
GenerateContentResponse([
Candidate(
Content('model', [TextPart(result)]),
null,
null,
null,
null,
),
], null),
),
response: arbitraryGenerateContentResponse,
);
});
});
Expand Down
Loading

0 comments on commit 5be1ca2

Please sign in to comment.