Skip to content

Commit

Permalink
Insert formatting trailing commas in tests (#157)
Browse files Browse the repository at this point in the history
An upcoming test refactor looks better when it's formatted with trailing
commas in more places. Pre-emptively format with inserted trailing
commas to make it easier to automate the formatting and reduce the diff
in the test refactor.

Trailing commas are inserted automatically by first formatting with the
`tall-style` experiment against the latest commit of `dart_style`, then
reformatting with the stable SDK to match CI expectations. The result
is the current formatting, with the trailing commas inferred where the
tall style has improved splitting. Blank lines are restored following
license header comments.
  • Loading branch information
natebosch authored May 10, 2024
1 parent 0583f15 commit aae64f1
Show file tree
Hide file tree
Showing 6 changed files with 951 additions and 604 deletions.
185 changes: 114 additions & 71 deletions pkgs/google_generative_ai/test/chat_test.dart
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,9 @@ void main() {
group('Chat', () {
const defaultModelName = 'some-model';

(StubClient, GenerativeModel) createModel(
[String modelName = defaultModelName]) {
(StubClient, GenerativeModel) createModel([
String modelName = defaultModelName,
]) {
final client = StubClient();
final model = createModelWithClient(model: modelName, client: client);
return (client, model);
Expand All @@ -34,83 +35,97 @@ void main() {
final (client, model) = createModel('models/$defaultModelName');
final chat = model.startChat(history: [
Content.text('Hi!'),
Content.model([TextPart('Hello, how can I help you today?')])
Content.model([TextPart('Hello, how can I help you today?')]),
]);
final prompt = 'Some prompt';
final result = 'Some response';
client.stub(
Uri.parse('https://generativelanguage.googleapis.com/v1beta/'
'models/some-model:generateContent'),
Uri.parse(
'https://generativelanguage.googleapis.com/v1beta/'
'models/some-model:generateContent',
),
{
'contents': [
{
'role': 'user',
'parts': [
{'text': 'Hi!'}
]
{'text': 'Hi!'},
],
},
{
'role': 'model',
'parts': [
{'text': 'Hello, how can I help you today?'}
]
{'text': 'Hello, how can I help you today?'},
],
},
{
'role': 'user',
'parts': [
{'text': prompt}
]
{'text': prompt},
],
},
]
],
},
{
'candidates': [
{
'content': {
'role': 'model',
'parts': [
{'text': result}
]
}
}
]
{'text': result},
],
},
},
],
},
);
final response = await chat.sendMessage(Content.text(prompt));
expect(
response,
matchesGenerateContentResponse(GenerateContentResponse([
response,
matchesGenerateContentResponse(
GenerateContentResponse([
Candidate(
Content('model', [TextPart(result)]), null, null, null, null),
], null)));
Content('model', [TextPart(result)]),
null,
null,
null,
null,
),
], null),
),
);
expect(
chat.history.last, matchesContent(response.candidates.first.content));
chat.history.last,
matchesContent(response.candidates.first.content),
);
});

test('forwards safety settings', () async {
final (client, model) = createModel('models/$defaultModelName');
final chat = model.startChat(safetySettings: [
SafetySetting(HarmCategory.dangerousContent, HarmBlockThreshold.high)
SafetySetting(HarmCategory.dangerousContent, HarmBlockThreshold.high),
]);
final prompt = 'Some prompt';
final result = 'Some response';
client.stub(
Uri.parse('https://generativelanguage.googleapis.com/v1beta/'
'models/some-model:generateContent'),
Uri.parse(
'https://generativelanguage.googleapis.com/v1beta/'
'models/some-model:generateContent',
),
{
'contents': [
{
'role': 'user',
'parts': [
{'text': prompt}
]
{'text': prompt},
],
},
],
'safetySettings': [
{
'category': 'HARM_CATEGORY_DANGEROUS_CONTENT',
'threshold': 'BLOCK_ONLY_HIGH'
}
'threshold': 'BLOCK_ONLY_HIGH',
},
],
},
{
Expand All @@ -119,49 +134,59 @@ void main() {
'content': {
'role': 'model',
'parts': [
{'text': result}
]
}
}
]
{'text': result},
],
},
},
],
},
);
final response = await chat.sendMessage(Content.text(prompt));
expect(
response,
matchesGenerateContentResponse(GenerateContentResponse([
response,
matchesGenerateContentResponse(
GenerateContentResponse([
Candidate(
Content('model', [TextPart(result)]), null, null, null, null),
], null)));
Content('model', [TextPart(result)]),
null,
null,
null,
null,
),
], null),
),
);
});

test('forwards safety settings and config when streaming', () async {
final (client, model) = createModel('models/$defaultModelName');
final chat = model.startChat(safetySettings: [
SafetySetting(HarmCategory.dangerousContent, HarmBlockThreshold.high)
SafetySetting(HarmCategory.dangerousContent, HarmBlockThreshold.high),
], generationConfig: GenerationConfig(stopSequences: ['a']));
final prompt = 'Some prompt';
final result = 'Some response';
client.stubStream(
Uri.parse('https://generativelanguage.googleapis.com/v1beta/'
'models/some-model:streamGenerateContent'),
Uri.parse(
'https://generativelanguage.googleapis.com/v1beta/'
'models/some-model:streamGenerateContent',
),
{
'contents': [
{
'role': 'user',
'parts': [
{'text': prompt}
]
{'text': prompt},
],
},
],
'safetySettings': [
{
'category': 'HARM_CATEGORY_DANGEROUS_CONTENT',
'threshold': 'BLOCK_ONLY_HIGH'
}
'threshold': 'BLOCK_ONLY_HIGH',
},
],
'generationConfig': {
'stopSequences': ['a']
'stopSequences': ['a'],
},
},
[
Expand All @@ -171,44 +196,54 @@ void main() {
'content': {
'role': 'model',
'parts': [
{'text': result}
]
}
}
]
}
{'text': result},
],
},
},
],
},
],
);
final responses =
await chat.sendMessageStream(Content.text(prompt)).toList();
expect(responses, [
matchesGenerateContentResponse(GenerateContentResponse([
Candidate(
Content('model', [TextPart(result)]), null, null, null, null),
], null))
matchesGenerateContentResponse(
GenerateContentResponse([
Candidate(
Content('model', [TextPart(result)]),
null,
null,
null,
null,
),
], null),
),
]);
});

test('forwards generation config', () async {
final (client, model) = createModel('models/$defaultModelName');
final chat = model.startChat(
generationConfig: GenerationConfig(stopSequences: ['a']));
generationConfig: GenerationConfig(stopSequences: ['a']),
);
final prompt = 'Some prompt';
final result = 'Some response';
client.stub(
Uri.parse('https://generativelanguage.googleapis.com/v1beta/'
'models/some-model:generateContent'),
Uri.parse(
'https://generativelanguage.googleapis.com/v1beta/'
'models/some-model:generateContent',
),
{
'contents': [
{
'role': 'user',
'parts': [
{'text': prompt}
]
{'text': prompt},
],
},
],
'generationConfig': {
'stopSequences': ['a']
'stopSequences': ['a'],
},
},
{
Expand All @@ -217,20 +252,28 @@ void main() {
'content': {
'role': 'model',
'parts': [
{'text': result}
]
}
}
]
{'text': result},
],
},
},
],
},
);
final response = await chat.sendMessage(Content.text(prompt));
expect(
response,
matchesGenerateContentResponse(GenerateContentResponse([
response,
matchesGenerateContentResponse(
GenerateContentResponse([
Candidate(
Content('model', [TextPart(result)]), null, null, null, null),
], null)));
Content('model', [TextPart(result)]),
null,
null,
null,
null,
),
], null),
),
);
});
});
}
Loading

0 comments on commit aae64f1

Please sign in to comment.