Skip to content

Commit

Permalink
chore: [vertexai] pass in immutable object in generateContentStream p…
Browse files Browse the repository at this point in the history
…rivate method (#10500)

PiperOrigin-RevId: 613330411

Co-authored-by: Jaycee Li <[email protected]>
  • Loading branch information
copybara-service[bot] and jaycee-li authored Mar 8, 2024
1 parent 3166f32 commit 88453fd
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 42 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -968,7 +968,7 @@ public ResponseStream<GenerateContentResponse> generateContentStream(
List<Content> contents, GenerationConfig generationConfig, List<SafetySetting> safetySettings)
throws IOException {
GenerateContentRequest.Builder requestBuilder =
GenerateContentRequest.newBuilder().addAllContents(contents);
GenerateContentRequest.newBuilder().setModel(this.resourceName).addAllContents(contents);
if (generationConfig != null) {
requestBuilder.setGenerationConfig(generationConfig);
} else if (this.generationConfig != null) {
Expand All @@ -982,7 +982,7 @@ public ResponseStream<GenerateContentResponse> generateContentStream(
if (this.tools != null) {
requestBuilder.addAllTools(this.tools);
}
return generateContentStream(requestBuilder);
return generateContentStream(requestBuilder.build());
}

/**
Expand All @@ -1000,7 +1000,7 @@ public ResponseStream<GenerateContentResponse> generateContentStream(
public ResponseStream<GenerateContentResponse> generateContentStream(
List<Content> contents, GenerateContentConfig config) throws IOException {
GenerateContentRequest.Builder requestBuilder =
GenerateContentRequest.newBuilder().addAllContents(contents);
GenerateContentRequest.newBuilder().setModel(this.resourceName).addAllContents(contents);
if (config.getGenerationConfig() != null) {
requestBuilder.setGenerationConfig(config.getGenerationConfig());
} else if (this.generationConfig != null) {
Expand All @@ -1017,42 +1017,36 @@ public ResponseStream<GenerateContentResponse> generateContentStream(
requestBuilder.addAllTools(this.tools);
}

return generateContentStream(requestBuilder);
return generateContentStream(requestBuilder.build());
}

/**
* A base generateContentStream method that will be used internally.
*
* @param requestBuilder a {@link com.google.cloud.vertexai.api.GenerateContentRequest.Builder}
* instance
* @param request a {@link com.google.cloud.vertexai.api.GenerateContentRequest} instance
* @return a {@link ResponseStream} that contains a streaming of {@link
* com.google.cloud.vertexai.api.GenerateContentResponse}
* @throws IOException if an I/O error occurs while making the API call
*/
private ResponseStream<GenerateContentResponse> generateContentStream(
GenerateContentRequest.Builder requestBuilder) throws IOException {
GenerateContentRequest request = requestBuilder.setModel(this.resourceName).build();
ResponseStream<GenerateContentResponse> responseStream = null;
GenerateContentRequest request) throws IOException {
if (this.transport == Transport.REST) {
responseStream =
new ResponseStream(
new ResponseStreamIteratorWithHistory(
vertexAi
.getPredictionServiceRestClient()
.streamGenerateContentCallable()
.call(request)
.iterator()));
return new ResponseStream(
new ResponseStreamIteratorWithHistory(
vertexAi
.getPredictionServiceRestClient()
.streamGenerateContentCallable()
.call(request)
.iterator()));
} else {
responseStream =
new ResponseStream(
new ResponseStreamIteratorWithHistory(
vertexAi
.getPredictionServiceClient()
.streamGenerateContentCallable()
.call(request)
.iterator()));
return new ResponseStream(
new ResponseStreamIteratorWithHistory(
vertexAi
.getPredictionServiceClient()
.streamGenerateContentCallable()
.call(request)
.iterator()));
}
return responseStream;
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,10 @@
import com.google.cloud.vertexai.api.Citation;
import com.google.cloud.vertexai.api.CitationMetadata;
import com.google.cloud.vertexai.api.Content;
import com.google.cloud.vertexai.api.FunctionCall;
import com.google.cloud.vertexai.api.GenerateContentResponse;
import com.google.cloud.vertexai.api.Part;
import com.google.common.collect.ImmutableList;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
Expand All @@ -33,20 +35,15 @@
public class ResponseHandler {

/**
* Get the text message in a GenerateContentResponse.
* Gets the text message in a GenerateContentResponse.
*
* @param response a {@link com.google.cloud.vertexai.api.GenerateContentResponse} instance
* @return a String that aggregates all the text parts in the response
* @throws IllegalArgumentException if the response has 0 or more than 1 candidates, or if the
* response is blocked by safety reason or unauthorized citations
*/
public static String getText(GenerateContentResponse response) {
FinishReason finishReason = getFinishReason(response);
if (finishReason == FinishReason.SAFETY) {
throw new IllegalArgumentException("The response is blocked due to safety reason.");
} else if (finishReason == FinishReason.RECITATION) {
throw new IllegalArgumentException("The response is blocked due to unauthorized citations.");
}
checkFinishReason(getFinishReason(response));

String text = "";
List<Part> parts = response.getCandidates(0).getContent().getPartsList();
Expand All @@ -58,26 +55,40 @@ public static String getText(GenerateContentResponse response) {
}

/**
* Get the content in a GenerateContentResponse.
* Gets the list of function calls in a GenerateContentResponse.
*
* @param response a {@link com.google.cloud.vertexai.api.GenerateContentResponse} instance
* @return a list of {@link com.google.cloud.vertexai.api.FunctionCall} in the response
* @throws IllegalArgumentException if the response has 0 or more than 1 candidates, or if the
* response is blocked by safety reason or unauthorized citations
*/
public static ImmutableList<FunctionCall> getFunctionCalls(GenerateContentResponse response) {
checkFinishReason(getFinishReason(response));
if (response.getCandidatesCount() == 0) {
return ImmutableList.of();
}
return response.getCandidates(0).getContent().getPartsList().stream()
.filter((part) -> part.hasFunctionCall())
.map((part) -> part.getFunctionCall())
.collect(ImmutableList.toImmutableList());
}

/**
* Gets the content in a GenerateContentResponse.
*
* @param response a {@link com.google.cloud.vertexai.api.GenerateContentResponse} instance
* @return the {@link com.google.cloud.vertexai.api.Content} in the response
* @throws IllegalArgumentException if the response has 0 or more than 1 candidates, or if the
* response is blocked by safety reason or unauthorized citations
*/
public static Content getContent(GenerateContentResponse response) {
FinishReason finishReason = getFinishReason(response);
if (finishReason == FinishReason.SAFETY) {
throw new IllegalArgumentException("The response is blocked due to safety reason.");
} else if (finishReason == FinishReason.RECITATION) {
throw new IllegalArgumentException("The response is blocked due to unauthorized citations.");
}
checkFinishReason(getFinishReason(response));

return response.getCandidates(0).getContent();
}

/**
* Get the finish reason in a GenerateContentResponse.
* Gets the finish reason in a GenerateContentResponse.
*
* @param response a {@link com.google.cloud.vertexai.api.GenerateContentResponse} instance
* @return the {@link com.google.cloud.vertexai.api.FinishReason} in the response
Expand All @@ -93,7 +104,7 @@ public static FinishReason getFinishReason(GenerateContentResponse response) {
return response.getCandidates(0).getFinishReason();
}

/** Aggregate a stream of responses into a single GenerateContentResponse. */
/** Aggregates a stream of responses into a single GenerateContentResponse. */
static GenerateContentResponse aggregateStreamIntoResponse(
ResponseStream<GenerateContentResponse> responseStream) {
GenerateContentResponse res = GenerateContentResponse.getDefaultInstance();
Expand Down Expand Up @@ -170,4 +181,12 @@ static GenerateContentResponse aggregateStreamIntoResponse(

return res;
}

private static void checkFinishReason(FinishReason finishReason) {
if (finishReason == FinishReason.SAFETY) {
throw new IllegalArgumentException("The response is blocked due to safety reason.");
} else if (finishReason == FinishReason.RECITATION) {
throw new IllegalArgumentException("The response is blocked due to unauthorized citations.");
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,10 @@
import com.google.cloud.vertexai.api.Citation;
import com.google.cloud.vertexai.api.CitationMetadata;
import com.google.cloud.vertexai.api.Content;
import com.google.cloud.vertexai.api.FunctionCall;
import com.google.cloud.vertexai.api.GenerateContentResponse;
import com.google.cloud.vertexai.api.Part;
import com.google.common.collect.ImmutableList;
import java.util.Arrays;
import java.util.Iterator;
import org.junit.Rule;
Expand All @@ -47,6 +49,13 @@ public final class ResponseHandlerTest {
.addParts(Part.newBuilder().setText(TEXT_1))
.addParts(Part.newBuilder().setText(TEXT_2))
.build();
private static final Content CONTENT_WITH_FNCTION_CALL =
Content.newBuilder()
.addParts(Part.newBuilder().setText(TEXT_1))
.addParts(Part.newBuilder().setFunctionCall(FunctionCall.getDefaultInstance()))
.addParts(Part.newBuilder().setText(TEXT_2))
.addParts(Part.newBuilder().setFunctionCall(FunctionCall.getDefaultInstance()))
.build();
private static final Citation CITATION_1 =
Citation.newBuilder().setUri("gs://citation1").setStartIndex(1).setEndIndex(2).build();
private static final Citation CITATION_2 =
Expand All @@ -61,10 +70,14 @@ public final class ResponseHandlerTest {
.setContent(CONTENT)
.setCitationMetadata(CitationMetadata.newBuilder().addCitations(CITATION_2))
.build();
private static final Candidate CANDIDATE_3 =
Candidate.newBuilder().setContent(CONTENT_WITH_FNCTION_CALL).build();
private static final GenerateContentResponse RESPONSE_1 =
GenerateContentResponse.newBuilder().addCandidates(CANDIDATE_1).build();
private static final GenerateContentResponse RESPONSE_2 =
GenerateContentResponse.newBuilder().addCandidates(CANDIDATE_2).build();
private static final GenerateContentResponse RESPONSE_3 =
GenerateContentResponse.newBuilder().addCandidates(CANDIDATE_3).build();
private static final GenerateContentResponse INVALID_RESPONSE =
GenerateContentResponse.newBuilder()
.addCandidates(CANDIDATE_1)
Expand Down Expand Up @@ -94,6 +107,28 @@ public void testGetTextFromInvalidResponse() {
INVALID_RESPONSE.getCandidatesCount()));
}

@Test
public void testGetFunctionCallsFromResponse() {
ImmutableList<FunctionCall> functionCalls = ResponseHandler.getFunctionCalls(RESPONSE_3);
assertThat(functionCalls.size()).isEqualTo(2);
assertThat(functionCalls.get(0)).isEqualTo(FunctionCall.getDefaultInstance());
assertThat(functionCalls.get(1)).isEqualTo(FunctionCall.getDefaultInstance());
}

@Test
public void testGetFunctionCallsFromInvalidResponse() {
IllegalArgumentException thrown =
assertThrows(
IllegalArgumentException.class,
() -> ResponseHandler.getFunctionCalls(INVALID_RESPONSE));
assertThat(thrown)
.hasMessageThat()
.isEqualTo(
String.format(
"This response should have exactly 1 candidate, but it has %s.",
INVALID_RESPONSE.getCandidatesCount()));
}

@Test
public void testGetContentFromResponse() {
Content content = ResponseHandler.getContent(RESPONSE_1);
Expand Down

0 comments on commit 88453fd

Please sign in to comment.