From 8bc8adbf0dae047d508ca591c659bfb91dfc02f0 Mon Sep 17 00:00:00 2001 From: "copybara-service[bot]" <56741989+copybara-service[bot]@users.noreply.github.com> Date: Wed, 20 Mar 2024 18:05:33 -0700 Subject: [PATCH] feat: [vertexai] add fluent API in GenerativeModel (#10585) PiperOrigin-RevId: 617585215 Co-authored-by: Jaycee Li --- .../generativeai/GenerativeModel.java | 35 +++++++++++ .../generativeai/GenerativeModelTest.java | 58 +++++++++++++++++++ 2 files changed, 93 insertions(+) diff --git a/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/generativeai/GenerativeModel.java b/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/generativeai/GenerativeModel.java index 12ebe51bf2b1..60a1b9d75d78 100644 --- a/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/generativeai/GenerativeModel.java +++ b/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/generativeai/GenerativeModel.java @@ -181,6 +181,41 @@ public Builder setTools(List tools) { } } + /** + * Creates a copy of the current model with updated GenerationConfig. + * + * @param generationConfig a {@link com.google.cloud.vertexai.api.GenerationConfig} that will be + * used in the new model. + * @return a new {@link GenerativeModel} instance with the specified GenerationConfig. + */ + public GenerativeModel withGenerationConfig(GenerationConfig generationConfig) { + return new GenerativeModel(modelName, generationConfig, safetySettings, tools, vertexAi); + } + + /** + * Creates a copy of the current model with updated safetySettings. + * + * @param safetySettings a list of {@link com.google.cloud.vertexai.api.SafetySetting} that will + * be used in the new model. + * @return a new {@link GenerativeModel} instance with the specified safetySettings. + */ + public GenerativeModel withSafetySettings(List safetySettings) { + return new GenerativeModel( + modelName, generationConfig, ImmutableList.copyOf(safetySettings), tools, vertexAi); + } + + /** + * Creates a copy of the current model with updated tools. + * + * @param safetySettings a list of {@link com.google.cloud.vertexai.api.Tool} that will be used in + * the new model. + * @return a new {@link GenerativeModel} instance with the specified tools. + */ + public GenerativeModel withTools(List tools) { + return new GenerativeModel( + modelName, generationConfig, safetySettings, ImmutableList.copyOf(tools), vertexAi); + } + /** * Counts tokens in a text message. * diff --git a/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/generativeai/GenerativeModelTest.java b/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/generativeai/GenerativeModelTest.java index aba97637b0af..0f74e84fc768 100644 --- a/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/generativeai/GenerativeModelTest.java +++ b/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/generativeai/GenerativeModelTest.java @@ -416,6 +416,34 @@ public void testGenerateContentwithDefaultTools() throws Exception { assertThat(request.getValue().getTools(0)).isEqualTo(TOOL); } + @Test + public void testGenerateContentwithFluentApi() throws Exception { + model = new GenerativeModel(MODEL_NAME, vertexAi); + + Field field = VertexAI.class.getDeclaredField("predictionServiceClient"); + field.setAccessible(true); + field.set(vertexAi, mockPredictionServiceClient); + + when(mockPredictionServiceClient.generateContentCallable()).thenReturn(mockUnaryCallable); + when(mockUnaryCallable.call(any(GenerateContentRequest.class))) + .thenReturn(mockGenerateContentResponse); + + GenerateContentResponse unused = + model + .withGenerationConfig(GENERATION_CONFIG) + .withSafetySettings(safetySettings) + .withTools(tools) + .generateContent(TEXT); + + ArgumentCaptor request = + ArgumentCaptor.forClass(GenerateContentRequest.class); + verify(mockUnaryCallable).call(request.capture()); + assertThat(request.getValue().getContents(0).getParts(0).getText()).isEqualTo(TEXT); + assertThat(request.getValue().getGenerationConfig()).isEqualTo(GENERATION_CONFIG); + assertThat(request.getValue().getSafetySettings(0)).isEqualTo(SAFETY_SETTING); + assertThat(request.getValue().getTools(0)).isEqualTo(TOOL); + } + @Test public void testGenerateContentStreamwithText() throws Exception { model = new GenerativeModel(MODEL_NAME, vertexAi); @@ -569,4 +597,34 @@ public void testGenerateContentStreamwithDefaultTools() throws Exception { verify(mockServerStreamCallable).call(request.capture()); assertThat(request.getValue().getTools(0)).isEqualTo(TOOL); } + + @Test + public void testGenerateContentStreamwithFluentApi() throws Exception { + model = new GenerativeModel(MODEL_NAME, vertexAi); + + Field field = VertexAI.class.getDeclaredField("predictionServiceClient"); + field.setAccessible(true); + field.set(vertexAi, mockPredictionServiceClient); + + when(mockPredictionServiceClient.streamGenerateContentCallable()) + .thenReturn(mockServerStreamCallable); + when(mockServerStreamCallable.call(any(GenerateContentRequest.class))) + .thenReturn(mockServerStream); + when(mockServerStream.iterator()).thenReturn(mockServerStreamIterator); + + ResponseStream unused = + model + .withGenerationConfig(GENERATION_CONFIG) + .withSafetySettings(safetySettings) + .withTools(tools) + .generateContentStream(TEXT); + + ArgumentCaptor request = + ArgumentCaptor.forClass(GenerateContentRequest.class); + verify(mockServerStreamCallable).call(request.capture()); + assertThat(request.getValue().getContents(0).getParts(0).getText()).isEqualTo(TEXT); + assertThat(request.getValue().getGenerationConfig()).isEqualTo(GENERATION_CONFIG); + assertThat(request.getValue().getSafetySettings(0)).isEqualTo(SAFETY_SETTING); + assertThat(request.getValue().getTools(0)).isEqualTo(TOOL); + } }