diff --git a/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/generativeai/GenerativeModel.java b/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/generativeai/GenerativeModel.java index f467f6b8b59c..8c7b809f1f49 100644 --- a/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/generativeai/GenerativeModel.java +++ b/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/generativeai/GenerativeModel.java @@ -409,6 +409,22 @@ public GenerateContentResponse generateContent(String text) throws IOException { return generateContent(text, null, null); } + /** + * Generates content from generative model given a text and configs. + * + * @param text a text message to send to the generative model + * @param config a {@link GenerateContentConfig} that contains all the configs in making a + * generate content api call + * @return a {@link com.google.cloud.vertexai.api.GenerateContentResponse} instance that contains + * response contents and other metadata + * @throws IOException if an I/O error occurs while making the API call + */ + @BetaApi + public GenerateContentResponse generateContent(String text, GenerateContentConfig config) + throws IOException { + return generateContent(ContentMaker.fromString(text), config); + } + /** * Generate content from generative model given a text and generation config. * @@ -511,6 +527,41 @@ public GenerateContentResponse generateContent( return generateContent(contents, null, safetySettings); } + /** + * Generates content from generative model given a list of contents and configs. + * + * @param contents a list of {@link com.google.cloud.vertexai.api.Content} to send to the + * generative model + * @param config a {@link GenerateContentConfig} that contains all the configs in making a + * generate content api call + * @return a {@link com.google.cloud.vertexai.api.GenerateContentResponse} instance that contains + * response contents and other metadata + * @throws IOException if an I/O error occurs while making the API call + */ + @BetaApi + public GenerateContentResponse generateContent( + List contents, GenerateContentConfig config) throws IOException { + GenerateContentRequest.Builder requestBuilder = + GenerateContentRequest.newBuilder().addAllContents(contents); + if (config.getGenerationConfig() != null) { + requestBuilder.setGenerationConfig(config.getGenerationConfig()); + } else if (this.generationConfig != null) { + requestBuilder.setGenerationConfig(this.generationConfig); + } + if (config.getSafetySettings().isEmpty() == false) { + requestBuilder.addAllSafetySettings(config.getSafetySettings()); + } else if (this.safetySettings != null) { + requestBuilder.addAllSafetySettings(this.safetySettings); + } + if (config.getTools().isEmpty() == false) { + requestBuilder.addAllTools(config.getTools()); + } else if (this.tools != null) { + requestBuilder.addAllTools(this.tools); + } + + return generateContent(requestBuilder); + } + /** * Generate content from generative model given a list of contents, generation config, and safety * settings. @@ -581,6 +632,22 @@ public GenerateContentResponse generateContent(Content content) throws IOExcepti return generateContent(content, null, null); } + /** + * Generates content from generative model given a single content and configs. + * + * @param content a {@link com.google.cloud.vertexai.api.Content} to send to the generative model + * @param config a {@link GenerateContentConfig} that contains all the configs in making a + * generate content api call + * @return a {@link com.google.cloud.vertexai.api.GenerateContentResponse} instance that contains + * response contents and other metadata + * @throws IOException if an I/O error occurs while making the API call + */ + @BetaApi + public GenerateContentResponse generateContent(Content content, GenerateContentConfig config) + throws IOException { + return generateContent(Arrays.asList(content), config); + } + /** * Generate content from this model given a single content and generation config. * diff --git a/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/generativeai/GenerativeModelTest.java b/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/generativeai/GenerativeModelTest.java index 1ce55bf25c9c..f6930b38fca5 100644 --- a/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/generativeai/GenerativeModelTest.java +++ b/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/generativeai/GenerativeModelTest.java @@ -452,6 +452,36 @@ public void testGenerateContentwithDefaultTools() throws Exception { assertThat(request.getValue().getTools(0)).isEqualTo(TOOL); } + @Test + public void testGenerateContentwithGenerateContentConfig() throws Exception { + model = new GenerativeModel(MODEL_NAME, vertexAi); + GenerateContentConfig config = + GenerateContentConfig.newBuilder() + .setGenerationConfig(GENERATION_CONFIG) + .setSafetySettings(safetySettings) + .setTools(tools) + .build(); + + Field field = VertexAI.class.getDeclaredField("predictionServiceClient"); + field.setAccessible(true); + field.set(vertexAi, mockPredictionServiceClient); + + when(mockPredictionServiceClient.generateContentCallable()).thenReturn(mockUnaryCallable); + when(mockUnaryCallable.call(any(GenerateContentRequest.class))) + .thenReturn(mockGenerateContentResponse); + + GenerateContentResponse unused = model.generateContent(TEXT, config); + + ArgumentCaptor request = + ArgumentCaptor.forClass(GenerateContentRequest.class); + verify(mockUnaryCallable).call(request.capture()); + + assertThat(request.getValue().getContents(0).getParts(0).getText()).isEqualTo(TEXT); + assertThat(request.getValue().getGenerationConfig()).isEqualTo(GENERATION_CONFIG); + assertThat(request.getValue().getSafetySettings(0)).isEqualTo(SAFETY_SETTING); + assertThat(request.getValue().getTools(0)).isEqualTo(TOOL); + } + @Test public void testGenerateContentStreamwithText() throws Exception { model = new GenerativeModel(MODEL_NAME, vertexAi);