diff --git a/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/generativeai/GenerativeModel.java b/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/generativeai/GenerativeModel.java index 361b9185314d..cdb3311bb35d 100644 --- a/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/generativeai/GenerativeModel.java +++ b/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/generativeai/GenerativeModel.java @@ -968,7 +968,7 @@ public ResponseStream generateContentStream( List contents, GenerationConfig generationConfig, List safetySettings) throws IOException { GenerateContentRequest.Builder requestBuilder = - GenerateContentRequest.newBuilder().addAllContents(contents); + GenerateContentRequest.newBuilder().setModel(this.resourceName).addAllContents(contents); if (generationConfig != null) { requestBuilder.setGenerationConfig(generationConfig); } else if (this.generationConfig != null) { @@ -982,7 +982,7 @@ public ResponseStream generateContentStream( if (this.tools != null) { requestBuilder.addAllTools(this.tools); } - return generateContentStream(requestBuilder); + return generateContentStream(requestBuilder.build()); } /** @@ -1000,7 +1000,7 @@ public ResponseStream generateContentStream( public ResponseStream generateContentStream( List contents, GenerateContentConfig config) throws IOException { GenerateContentRequest.Builder requestBuilder = - GenerateContentRequest.newBuilder().addAllContents(contents); + GenerateContentRequest.newBuilder().setModel(this.resourceName).addAllContents(contents); if (config.getGenerationConfig() != null) { requestBuilder.setGenerationConfig(config.getGenerationConfig()); } else if (this.generationConfig != null) { @@ -1017,42 +1017,36 @@ public ResponseStream generateContentStream( requestBuilder.addAllTools(this.tools); } - return generateContentStream(requestBuilder); + return generateContentStream(requestBuilder.build()); } /** * A base generateContentStream method that will be used internally. * - * @param requestBuilder a {@link com.google.cloud.vertexai.api.GenerateContentRequest.Builder} - * instance + * @param request a {@link com.google.cloud.vertexai.api.GenerateContentRequest} instance * @return a {@link ResponseStream} that contains a streaming of {@link * com.google.cloud.vertexai.api.GenerateContentResponse} * @throws IOException if an I/O error occurs while making the API call */ private ResponseStream generateContentStream( - GenerateContentRequest.Builder requestBuilder) throws IOException { - GenerateContentRequest request = requestBuilder.setModel(this.resourceName).build(); - ResponseStream responseStream = null; + GenerateContentRequest request) throws IOException { if (this.transport == Transport.REST) { - responseStream = - new ResponseStream( - new ResponseStreamIteratorWithHistory( - vertexAi - .getPredictionServiceRestClient() - .streamGenerateContentCallable() - .call(request) - .iterator())); + return new ResponseStream( + new ResponseStreamIteratorWithHistory( + vertexAi + .getPredictionServiceRestClient() + .streamGenerateContentCallable() + .call(request) + .iterator())); } else { - responseStream = - new ResponseStream( - new ResponseStreamIteratorWithHistory( - vertexAi - .getPredictionServiceClient() - .streamGenerateContentCallable() - .call(request) - .iterator())); + return new ResponseStream( + new ResponseStreamIteratorWithHistory( + vertexAi + .getPredictionServiceClient() + .streamGenerateContentCallable() + .call(request) + .iterator())); } - return responseStream; } /** diff --git a/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/generativeai/ResponseHandler.java b/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/generativeai/ResponseHandler.java index 7f9e5f616da5..44cca9bd4925 100644 --- a/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/generativeai/ResponseHandler.java +++ b/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/generativeai/ResponseHandler.java @@ -22,8 +22,10 @@ import com.google.cloud.vertexai.api.Citation; import com.google.cloud.vertexai.api.CitationMetadata; import com.google.cloud.vertexai.api.Content; +import com.google.cloud.vertexai.api.FunctionCall; import com.google.cloud.vertexai.api.GenerateContentResponse; import com.google.cloud.vertexai.api.Part; +import com.google.common.collect.ImmutableList; import java.util.ArrayList; import java.util.HashMap; import java.util.List; @@ -33,7 +35,7 @@ public class ResponseHandler { /** - * Get the text message in a GenerateContentResponse. + * Gets the text message in a GenerateContentResponse. * * @param response a {@link com.google.cloud.vertexai.api.GenerateContentResponse} instance * @return a String that aggregates all the text parts in the response @@ -41,12 +43,7 @@ public class ResponseHandler { * response is blocked by safety reason or unauthorized citations */ public static String getText(GenerateContentResponse response) { - FinishReason finishReason = getFinishReason(response); - if (finishReason == FinishReason.SAFETY) { - throw new IllegalArgumentException("The response is blocked due to safety reason."); - } else if (finishReason == FinishReason.RECITATION) { - throw new IllegalArgumentException("The response is blocked due to unauthorized citations."); - } + checkFinishReason(getFinishReason(response)); String text = ""; List parts = response.getCandidates(0).getContent().getPartsList(); @@ -58,7 +55,26 @@ public static String getText(GenerateContentResponse response) { } /** - * Get the content in a GenerateContentResponse. + * Gets the list of function calls in a GenerateContentResponse. + * + * @param response a {@link com.google.cloud.vertexai.api.GenerateContentResponse} instance + * @return a list of {@link com.google.cloud.vertexai.api.FunctionCall} in the response + * @throws IllegalArgumentException if the response has 0 or more than 1 candidates, or if the + * response is blocked by safety reason or unauthorized citations + */ + public static ImmutableList getFunctionCalls(GenerateContentResponse response) { + checkFinishReason(getFinishReason(response)); + if (response.getCandidatesCount() == 0) { + return ImmutableList.of(); + } + return response.getCandidates(0).getContent().getPartsList().stream() + .filter((part) -> part.hasFunctionCall()) + .map((part) -> part.getFunctionCall()) + .collect(ImmutableList.toImmutableList()); + } + + /** + * Gets the content in a GenerateContentResponse. * * @param response a {@link com.google.cloud.vertexai.api.GenerateContentResponse} instance * @return the {@link com.google.cloud.vertexai.api.Content} in the response @@ -66,18 +82,13 @@ public static String getText(GenerateContentResponse response) { * response is blocked by safety reason or unauthorized citations */ public static Content getContent(GenerateContentResponse response) { - FinishReason finishReason = getFinishReason(response); - if (finishReason == FinishReason.SAFETY) { - throw new IllegalArgumentException("The response is blocked due to safety reason."); - } else if (finishReason == FinishReason.RECITATION) { - throw new IllegalArgumentException("The response is blocked due to unauthorized citations."); - } + checkFinishReason(getFinishReason(response)); return response.getCandidates(0).getContent(); } /** - * Get the finish reason in a GenerateContentResponse. + * Gets the finish reason in a GenerateContentResponse. * * @param response a {@link com.google.cloud.vertexai.api.GenerateContentResponse} instance * @return the {@link com.google.cloud.vertexai.api.FinishReason} in the response @@ -93,7 +104,7 @@ public static FinishReason getFinishReason(GenerateContentResponse response) { return response.getCandidates(0).getFinishReason(); } - /** Aggregate a stream of responses into a single GenerateContentResponse. */ + /** Aggregates a stream of responses into a single GenerateContentResponse. */ static GenerateContentResponse aggregateStreamIntoResponse( ResponseStream responseStream) { GenerateContentResponse res = GenerateContentResponse.getDefaultInstance(); @@ -170,4 +181,12 @@ static GenerateContentResponse aggregateStreamIntoResponse( return res; } + + private static void checkFinishReason(FinishReason finishReason) { + if (finishReason == FinishReason.SAFETY) { + throw new IllegalArgumentException("The response is blocked due to safety reason."); + } else if (finishReason == FinishReason.RECITATION) { + throw new IllegalArgumentException("The response is blocked due to unauthorized citations."); + } + } } diff --git a/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/generativeai/ResponseHandlerTest.java b/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/generativeai/ResponseHandlerTest.java index 74c633779565..80487e7355fe 100644 --- a/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/generativeai/ResponseHandlerTest.java +++ b/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/generativeai/ResponseHandlerTest.java @@ -25,8 +25,10 @@ import com.google.cloud.vertexai.api.Citation; import com.google.cloud.vertexai.api.CitationMetadata; import com.google.cloud.vertexai.api.Content; +import com.google.cloud.vertexai.api.FunctionCall; import com.google.cloud.vertexai.api.GenerateContentResponse; import com.google.cloud.vertexai.api.Part; +import com.google.common.collect.ImmutableList; import java.util.Arrays; import java.util.Iterator; import org.junit.Rule; @@ -47,6 +49,13 @@ public final class ResponseHandlerTest { .addParts(Part.newBuilder().setText(TEXT_1)) .addParts(Part.newBuilder().setText(TEXT_2)) .build(); + private static final Content CONTENT_WITH_FNCTION_CALL = + Content.newBuilder() + .addParts(Part.newBuilder().setText(TEXT_1)) + .addParts(Part.newBuilder().setFunctionCall(FunctionCall.getDefaultInstance())) + .addParts(Part.newBuilder().setText(TEXT_2)) + .addParts(Part.newBuilder().setFunctionCall(FunctionCall.getDefaultInstance())) + .build(); private static final Citation CITATION_1 = Citation.newBuilder().setUri("gs://citation1").setStartIndex(1).setEndIndex(2).build(); private static final Citation CITATION_2 = @@ -61,10 +70,14 @@ public final class ResponseHandlerTest { .setContent(CONTENT) .setCitationMetadata(CitationMetadata.newBuilder().addCitations(CITATION_2)) .build(); + private static final Candidate CANDIDATE_3 = + Candidate.newBuilder().setContent(CONTENT_WITH_FNCTION_CALL).build(); private static final GenerateContentResponse RESPONSE_1 = GenerateContentResponse.newBuilder().addCandidates(CANDIDATE_1).build(); private static final GenerateContentResponse RESPONSE_2 = GenerateContentResponse.newBuilder().addCandidates(CANDIDATE_2).build(); + private static final GenerateContentResponse RESPONSE_3 = + GenerateContentResponse.newBuilder().addCandidates(CANDIDATE_3).build(); private static final GenerateContentResponse INVALID_RESPONSE = GenerateContentResponse.newBuilder() .addCandidates(CANDIDATE_1) @@ -94,6 +107,28 @@ public void testGetTextFromInvalidResponse() { INVALID_RESPONSE.getCandidatesCount())); } + @Test + public void testGetFunctionCallsFromResponse() { + ImmutableList functionCalls = ResponseHandler.getFunctionCalls(RESPONSE_3); + assertThat(functionCalls.size()).isEqualTo(2); + assertThat(functionCalls.get(0)).isEqualTo(FunctionCall.getDefaultInstance()); + assertThat(functionCalls.get(1)).isEqualTo(FunctionCall.getDefaultInstance()); + } + + @Test + public void testGetFunctionCallsFromInvalidResponse() { + IllegalArgumentException thrown = + assertThrows( + IllegalArgumentException.class, + () -> ResponseHandler.getFunctionCalls(INVALID_RESPONSE)); + assertThat(thrown) + .hasMessageThat() + .isEqualTo( + String.format( + "This response should have exactly 1 candidate, but it has %s.", + INVALID_RESPONSE.getCandidatesCount())); + } + @Test public void testGetContentFromResponse() { Content content = ResponseHandler.getContent(RESPONSE_1);