Add usageMetadata to GenerateContentResponse (#143)
Add `UsageMetadata` and relevant parsing. Add a `usageMetadata` field
on `GenerateContentResponse`.

Add usage of the new field to the advanced text sample.

Refactor the `GenerateContentResponse` parse method to handle each
field individually at the top level. This changes behavior in one error
case (a payload that carries both an `error` and other fields now always
throws), but the difference is not visible through the message formats
returned from the backend in practice.
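As a rough illustration of that error case, here is a hand-written payload
(not real backend traffic) run through the internal parse helper; the shape
is hypothetical, and only the ordering of the checks comes from this change:

// Hypothetical mixed payload; the backend does not produce this shape.
final payload = <String, Object?>{
  'candidates': <Object?>[],
  'error': <String, Object?>{'message': 'Invalid request'},
};
// Before this refactor the `candidates` case matched first and the `error`
// was silently ignored; now the `error` field is checked first and throws.
parseGenerateContentResponse(payload);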

Prepare to publish.
natebosch authored May 7, 2024
1 parent 6ef825b commit c3d87ab
Showing 5 changed files with 76 additions and 16 deletions.
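For context, a minimal sketch of how the new field surfaces to callers of the
public API. The model name, API-key lookup, and prompt below are illustrative
assumptions; only `usageMetadata` and its token-count fields come from this
change.

import 'dart:io';

import 'package:google_generative_ai/google_generative_ai.dart';

void main() async {
  // Illustrative setup; any configured model exposes the field the same way.
  final model = GenerativeModel(
      model: 'gemini-pro',
      apiKey: Platform.environment['GOOGLE_API_KEY']!);
  final response = await model.generateContent([Content.text('Hello!')]);

  // The metadata and each count are nullable; the backend may omit them.
  final usage = response.usageMetadata;
  if (usage != null) {
    print('prompt: ${usage.promptTokenCount}, '
        'candidates: ${usage.candidatesTokenCount}, '
        'total: ${usage.totalTokenCount}');
  }
}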
5 changes: 4 additions & 1 deletion pkgs/google_generative_ai/CHANGELOG.md
@@ -1,4 +1,7 @@
-## 0.3.3-wip
+## 0.3.3
+
+- Add support for parsing the `usageMetadata` field in `GenerateContentResponse`
+  messages.

## 0.3.2

78 changes: 65 additions & 13 deletions pkgs/google_generative_ai/lib/src/api.dart
@@ -33,7 +33,14 @@ final class GenerateContentResponse {
  /// Returns the prompt's feedback related to the content filters.
  final PromptFeedback? promptFeedback;

-  GenerateContentResponse(this.candidates, this.promptFeedback);
+  final UsageMetadata? usageMetadata;
+
+  // TODO(natebosch): Change `promptFeedback` to a named argument.
+  GenerateContentResponse(
+    this.candidates,
+    this.promptFeedback, {
+    this.usageMetadata,
+  });

  /// The text content of the first part of the first of [candidates], if any.
  ///
@@ -145,6 +152,24 @@ final class PromptFeedback {
  PromptFeedback(this.blockReason, this.blockReasonMessage, this.safetyRatings);
}

+/// Metadata on the generation request's token usage.
+final class UsageMetadata {
+  /// Number of tokens in the prompt.
+  final int? promptTokenCount;
+
+  /// Total number of tokens across the generated candidates.
+  final int? candidatesTokenCount;
+
+  /// Total token count for the generation request (prompt + candidates).
+  final int? totalTokenCount;
+
+  UsageMetadata({
+    this.promptTokenCount,
+    this.candidatesTokenCount,
+    this.totalTokenCount,
+  });
+}
+
/// Response candidate generated from a [GenerativeModel].
final class Candidate {
  /// Generated content returned from the model.
@@ -496,20 +521,24 @@ enum TaskType {
}

GenerateContentResponse parseGenerateContentResponse(Object jsonObject) {
-  return switch (jsonObject) {
-    {'candidates': final List<Object?> candidates} => GenerateContentResponse(
-        candidates.map(_parseCandidate).toList(),
-        switch (jsonObject) {
-          {'promptFeedback': final promptFeedback?} =>
-            _parsePromptFeedback(promptFeedback),
-          _ => null
-        }),
+  if (jsonObject case {'error': final Object error}) throw parseError(error);
+  final candidates = switch (jsonObject) {
+    {'candidates': final List<Object?> candidates} =>
+      candidates.map(_parseCandidate).toList(),
+    _ => <Candidate>[]
+  };
+  final promptFeedback = switch (jsonObject) {
    {'promptFeedback': final promptFeedback?} =>
-      GenerateContentResponse([], _parsePromptFeedback(promptFeedback)),
-    {'error': final Object error} => throw parseError(error),
-    _ => throw FormatException(
-        'Unhandled GenerateContentResponse format', jsonObject)
+      _parsePromptFeedback(promptFeedback),
+    _ => null,
  };
+  final usageMetadata = switch (jsonObject) {
+    {'usageMetadata': final usageMetadata?} =>
+      _parseUsageMetadata(usageMetadata),
+    _ => null,
+  };
+  return GenerateContentResponse(candidates, promptFeedback,
+      usageMetadata: usageMetadata);
}

CountTokensResponse parseCountTokensResponse(Object jsonObject) {
@@ -594,6 +623,29 @@ PromptFeedback _parsePromptFeedback(Object jsonObject) {
};
}

+UsageMetadata _parseUsageMetadata(Object jsonObject) {
+  if (jsonObject is! Map<String, Object?>) {
+    throw FormatException('Unhandled UsageMetadata format', jsonObject);
+  }
+  final promptTokenCount = switch (jsonObject) {
+    {'promptTokenCount': final int promptTokenCount} => promptTokenCount,
+    _ => null,
+  };
+  final candidatesTokenCount = switch (jsonObject) {
+    {'candidatesTokenCount': final int candidatesTokenCount} =>
+      candidatesTokenCount,
+    _ => null,
+  };
+  final totalTokenCount = switch (jsonObject) {
+    {'totalTokenCount': final int totalTokenCount} => totalTokenCount,
+    _ => null,
+  };
+  return UsageMetadata(
+      promptTokenCount: promptTokenCount,
+      candidatesTokenCount: candidatesTokenCount,
+      totalTokenCount: totalTokenCount);
+}

SafetyRating _parseSafetyRating(Object? jsonObject) {
return switch (jsonObject) {
{
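A short sketch of how the new parsing behaves end to end (values invented; as
of this commit the parse helpers in `lib/src/api.dart` do not appear to be part
of the package's public surface, so treat this as code living inside the
package or its tests): token counts the backend omits come back as `null`,
while a malformed `usageMetadata` value is rejected.

// Hand-written response map with a partial usageMetadata block.
final response = parseGenerateContentResponse({
  'candidates': <Object?>[],
  'usageMetadata': {'promptTokenCount': 7, 'totalTokenCount': 7},
});
print(response.usageMetadata?.promptTokenCount); // 7
print(response.usageMetadata?.candidatesTokenCount); // null, not sent
print(response.usageMetadata?.totalTokenCount); // 7

// A non-map value is rejected rather than silently dropped:
// parseGenerateContentResponse({'candidates': <Object?>[], 'usageMetadata': 42})
// would throw FormatException('Unhandled UsageMetadata format', 42).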
2 changes: 1 addition & 1 deletion pkgs/google_generative_ai/lib/src/version.dart
@@ -12,4 +12,4 @@
// See the License for the specific language governing permissions and
// limitations under the License.

-const packageVersion = '0.3.3-wip';
+const packageVersion = '0.3.3';
2 changes: 1 addition & 1 deletion pkgs/google_generative_ai/pubspec.yaml
@@ -1,6 +1,6 @@
name: google_generative_ai
# Update `lib/version.dart` when changing version.
-version: 0.3.3-wip
+version: 0.3.3
description: >-
The Google AI Dart SDK enables developers to use Google's state-of-the-art
generative AI models (like Gemini).
5 changes: 5 additions & 0 deletions samples/dart/bin/advanced_text.dart
@@ -36,6 +36,11 @@ void main() async {

  final responses = model.generateContentStream(content);
  await for (final response in responses) {
+    if (response.usageMetadata case final usageMetadata?) {
+      stdout.writeln('(Usage: prompt - ${usageMetadata.promptTokenCount}, '
+          'candidates - ${usageMetadata.candidatesTokenCount}, '
+          'total - ${usageMetadata.totalTokenCount})');
+    }
    stdout.write(response.text);
  }
  stdout.writeln();
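With this change each streamed response that carries usage information is
preceded by a line such as `(Usage: prompt - 10, candidates - 42, total - 52)`
before its text; the token counts shown here are invented for illustration.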
