Skip to content

Commit

Permalink
feat: [vertexai] add fluent API in GenerativeModel (#10585)
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 617585215

Co-authored-by: Jaycee Li <[email protected]>
  • Loading branch information
copybara-service[bot] and jaycee-li authored Mar 21, 2024
1 parent bedcddf commit 8bc8adb
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,41 @@ public Builder setTools(List<Tool> tools) {
}
}

/**
* Creates a copy of the current model with updated GenerationConfig.
*
* @param generationConfig a {@link com.google.cloud.vertexai.api.GenerationConfig} that will be
* used in the new model.
* @return a new {@link GenerativeModel} instance with the specified GenerationConfig.
*/
public GenerativeModel withGenerationConfig(GenerationConfig generationConfig) {
return new GenerativeModel(modelName, generationConfig, safetySettings, tools, vertexAi);
}

/**
* Creates a copy of the current model with updated safetySettings.
*
* @param safetySettings a list of {@link com.google.cloud.vertexai.api.SafetySetting} that will
* be used in the new model.
* @return a new {@link GenerativeModel} instance with the specified safetySettings.
*/
public GenerativeModel withSafetySettings(List<SafetySetting> safetySettings) {
return new GenerativeModel(
modelName, generationConfig, ImmutableList.copyOf(safetySettings), tools, vertexAi);
}

/**
* Creates a copy of the current model with updated tools.
*
* @param safetySettings a list of {@link com.google.cloud.vertexai.api.Tool} that will be used in
* the new model.
* @return a new {@link GenerativeModel} instance with the specified tools.
*/
public GenerativeModel withTools(List<Tool> tools) {
return new GenerativeModel(
modelName, generationConfig, safetySettings, ImmutableList.copyOf(tools), vertexAi);
}

/**
* Counts tokens in a text message.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,34 @@ public void testGenerateContentwithDefaultTools() throws Exception {
assertThat(request.getValue().getTools(0)).isEqualTo(TOOL);
}

@Test
public void testGenerateContentwithFluentApi() throws Exception {
model = new GenerativeModel(MODEL_NAME, vertexAi);

Field field = VertexAI.class.getDeclaredField("predictionServiceClient");
field.setAccessible(true);
field.set(vertexAi, mockPredictionServiceClient);

when(mockPredictionServiceClient.generateContentCallable()).thenReturn(mockUnaryCallable);
when(mockUnaryCallable.call(any(GenerateContentRequest.class)))
.thenReturn(mockGenerateContentResponse);

GenerateContentResponse unused =
model
.withGenerationConfig(GENERATION_CONFIG)
.withSafetySettings(safetySettings)
.withTools(tools)
.generateContent(TEXT);

ArgumentCaptor<GenerateContentRequest> request =
ArgumentCaptor.forClass(GenerateContentRequest.class);
verify(mockUnaryCallable).call(request.capture());
assertThat(request.getValue().getContents(0).getParts(0).getText()).isEqualTo(TEXT);
assertThat(request.getValue().getGenerationConfig()).isEqualTo(GENERATION_CONFIG);
assertThat(request.getValue().getSafetySettings(0)).isEqualTo(SAFETY_SETTING);
assertThat(request.getValue().getTools(0)).isEqualTo(TOOL);
}

@Test
public void testGenerateContentStreamwithText() throws Exception {
model = new GenerativeModel(MODEL_NAME, vertexAi);
Expand Down Expand Up @@ -569,4 +597,34 @@ public void testGenerateContentStreamwithDefaultTools() throws Exception {
verify(mockServerStreamCallable).call(request.capture());
assertThat(request.getValue().getTools(0)).isEqualTo(TOOL);
}

@Test
public void testGenerateContentStreamwithFluentApi() throws Exception {
model = new GenerativeModel(MODEL_NAME, vertexAi);

Field field = VertexAI.class.getDeclaredField("predictionServiceClient");
field.setAccessible(true);
field.set(vertexAi, mockPredictionServiceClient);

when(mockPredictionServiceClient.streamGenerateContentCallable())
.thenReturn(mockServerStreamCallable);
when(mockServerStreamCallable.call(any(GenerateContentRequest.class)))
.thenReturn(mockServerStream);
when(mockServerStream.iterator()).thenReturn(mockServerStreamIterator);

ResponseStream unused =
model
.withGenerationConfig(GENERATION_CONFIG)
.withSafetySettings(safetySettings)
.withTools(tools)
.generateContentStream(TEXT);

ArgumentCaptor<GenerateContentRequest> request =
ArgumentCaptor.forClass(GenerateContentRequest.class);
verify(mockServerStreamCallable).call(request.capture());
assertThat(request.getValue().getContents(0).getParts(0).getText()).isEqualTo(TEXT);
assertThat(request.getValue().getGenerationConfig()).isEqualTo(GENERATION_CONFIG);
assertThat(request.getValue().getSafetySettings(0)).isEqualTo(SAFETY_SETTING);
assertThat(request.getValue().getTools(0)).isEqualTo(TOOL);
}
}

0 comments on commit 8bc8adb

Please sign in to comment.