From c3d87ab550fe4b5ac33c572829eb0ae82f549b88 Mon Sep 17 00:00:00 2001 From: Nate Bosch Date: Mon, 6 May 2024 17:47:31 -0700 Subject: [PATCH] Add usageMetadata to GenerateContentResponse (#143) Add `UsageMetadata` and relevante parsing. Add a `usageMetadata` field on `GenerateContentResponse`. Add usage of the new field to the advanced text sample. Refactor the `GenerateContentResponse` parse method to handled each field individually at the top level. There is a behavior change for an error case, but it is not visible through the message formats that are returned from the backend in practice. Prepare to publish. --- pkgs/google_generative_ai/CHANGELOG.md | 5 +- pkgs/google_generative_ai/lib/src/api.dart | 78 +++++++++++++++---- .../google_generative_ai/lib/src/version.dart | 2 +- pkgs/google_generative_ai/pubspec.yaml | 2 +- samples/dart/bin/advanced_text.dart | 5 ++ 5 files changed, 76 insertions(+), 16 deletions(-) diff --git a/pkgs/google_generative_ai/CHANGELOG.md b/pkgs/google_generative_ai/CHANGELOG.md index d648248..6e5fad9 100644 --- a/pkgs/google_generative_ai/CHANGELOG.md +++ b/pkgs/google_generative_ai/CHANGELOG.md @@ -1,4 +1,7 @@ -## 0.3.3-wip +## 0.3.3 + +- Add support for parsing the `usageMetadata` field in `GenerateContentResponse` + messages. ## 0.3.2 diff --git a/pkgs/google_generative_ai/lib/src/api.dart b/pkgs/google_generative_ai/lib/src/api.dart index 4e4b5f2..4d9e2e3 100644 --- a/pkgs/google_generative_ai/lib/src/api.dart +++ b/pkgs/google_generative_ai/lib/src/api.dart @@ -33,7 +33,14 @@ final class GenerateContentResponse { /// Returns the prompt's feedback related to the content filters. final PromptFeedback? promptFeedback; - GenerateContentResponse(this.candidates, this.promptFeedback); + final UsageMetadata? usageMetadata; + + // TODO(natebosch): Change `promptFeedback` to a named argument. + GenerateContentResponse( + this.candidates, + this.promptFeedback, { + this.usageMetadata, + }); /// The text content of the first part of the first of [candidates], if any. /// @@ -145,6 +152,24 @@ final class PromptFeedback { PromptFeedback(this.blockReason, this.blockReasonMessage, this.safetyRatings); } +/// Metadata on the generation request's token usage. +final class UsageMetadata { + /// Number of tokens in the prompt. + final int? promptTokenCount; + + /// Total number of tokens across the generated candidates. + final int? candidatesTokenCount; + + /// Total token count for the generation request (prompt + candidates). + final int? totalTokenCount; + + UsageMetadata({ + this.promptTokenCount, + this.candidatesTokenCount, + this.totalTokenCount, + }); +} + /// Response candidate generated from a [GenerativeModel]. final class Candidate { /// Generated content returned from the model. @@ -496,20 +521,24 @@ enum TaskType { } GenerateContentResponse parseGenerateContentResponse(Object jsonObject) { - return switch (jsonObject) { - {'candidates': final List candidates} => GenerateContentResponse( - candidates.map(_parseCandidate).toList(), - switch (jsonObject) { - {'promptFeedback': final promptFeedback?} => - _parsePromptFeedback(promptFeedback), - _ => null - }), + if (jsonObject case {'error': final Object error}) throw parseError(error); + final candidates = switch (jsonObject) { + {'candidates': final List candidates} => + candidates.map(_parseCandidate).toList(), + _ => [] + }; + final promptFeedback = switch (jsonObject) { {'promptFeedback': final promptFeedback?} => - GenerateContentResponse([], _parsePromptFeedback(promptFeedback)), - {'error': final Object error} => throw parseError(error), - _ => throw FormatException( - 'Unhandled GenerateContentResponse format', jsonObject) + _parsePromptFeedback(promptFeedback), + _ => null, + }; + final usageMedata = switch (jsonObject) { + {'usageMetadata': final usageMetadata?} => + _parseUsageMetadata(usageMetadata), + _ => null, }; + return GenerateContentResponse(candidates, promptFeedback, + usageMetadata: usageMedata); } CountTokensResponse parseCountTokensResponse(Object jsonObject) { @@ -594,6 +623,29 @@ PromptFeedback _parsePromptFeedback(Object jsonObject) { }; } +UsageMetadata _parseUsageMetadata(Object jsonObject) { + if (jsonObject is! Map) { + throw FormatException('Unhandled UsageMetadata format', jsonObject); + } + final promptTokenCount = switch (jsonObject) { + {'promptTokenCount': final int promptTokenCount} => promptTokenCount, + _ => null, + }; + final candidatesTokenCount = switch (jsonObject) { + {'candidatesTokenCount': final int candidatesTokenCount} => + candidatesTokenCount, + _ => null, + }; + final totalTokenCount = switch (jsonObject) { + {'totalTokenCount': final int totalTokenCount} => totalTokenCount, + _ => null, + }; + return UsageMetadata( + promptTokenCount: promptTokenCount, + candidatesTokenCount: candidatesTokenCount, + totalTokenCount: totalTokenCount); +} + SafetyRating _parseSafetyRating(Object? jsonObject) { return switch (jsonObject) { { diff --git a/pkgs/google_generative_ai/lib/src/version.dart b/pkgs/google_generative_ai/lib/src/version.dart index 4c7d074..0d42719 100644 --- a/pkgs/google_generative_ai/lib/src/version.dart +++ b/pkgs/google_generative_ai/lib/src/version.dart @@ -12,4 +12,4 @@ // See the License for the specific language governing permissions and // limitations under the License. -const packageVersion = '0.3.3-wip'; +const packageVersion = '0.3.3'; diff --git a/pkgs/google_generative_ai/pubspec.yaml b/pkgs/google_generative_ai/pubspec.yaml index ac8c1fa..7cc4e3f 100644 --- a/pkgs/google_generative_ai/pubspec.yaml +++ b/pkgs/google_generative_ai/pubspec.yaml @@ -1,6 +1,6 @@ name: google_generative_ai # Update `lib/version.dart` when changing version. -version: 0.3.3-wip +version: 0.3.3 description: >- The Google AI Dart SDK enables developers to use Google's state-of-the-art generative AI models (like Gemini). diff --git a/samples/dart/bin/advanced_text.dart b/samples/dart/bin/advanced_text.dart index 7388857..1368911 100644 --- a/samples/dart/bin/advanced_text.dart +++ b/samples/dart/bin/advanced_text.dart @@ -36,6 +36,11 @@ void main() async { final responses = model.generateContentStream(content); await for (final response in responses) { + if (response.usageMetadata case final usageMetadata?) { + stdout.writeln('(Usage: prompt - ${usageMetadata.promptTokenCount}), ' + 'candidates - ${usageMetadata.candidatesTokenCount}, ' + 'total - ${usageMetadata.totalTokenCount}'); + } stdout.write(response.text); } stdout.writeln();