From 04e9574a223cd9c42615bfde18aa6ad3dc551251 Mon Sep 17 00:00:00 2001 From: "copybara-service[bot]" <56741989+copybara-service[bot]@users.noreply.github.com> Date: Tue, 27 Feb 2024 08:20:31 +0800 Subject: [PATCH] feat: [vertexai] add GenerateContentConfig class (#10413) PiperOrigin-RevId: 610472076 Co-authored-by: Jaycee Li --- .../generativeai/GenerateContentConfig.java | 172 ++++++++++++++++++ .../GenerateContentConfigTest.java | 96 ++++++++++ 2 files changed, 268 insertions(+) create mode 100644 java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/generativeai/GenerateContentConfig.java create mode 100644 java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/generativeai/GenerateContentConfigTest.java diff --git a/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/generativeai/GenerateContentConfig.java b/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/generativeai/GenerateContentConfig.java new file mode 100644 index 000000000000..378e04c17b7a --- /dev/null +++ b/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/generativeai/GenerateContentConfig.java @@ -0,0 +1,172 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.vertexai.generativeai; + +import com.google.api.core.BetaApi; +import com.google.cloud.vertexai.api.GenerationConfig; +import com.google.cloud.vertexai.api.SafetySetting; +import com.google.cloud.vertexai.api.Tool; +import com.google.common.collect.ImmutableList; +import java.util.List; + +/** This class holds all the configs when making a generate content API call */ +public class GenerateContentConfig { + private GenerationConfig generationConfig; + private ImmutableList safetySettings; + private ImmutableList tools; + + /** Creates a builder for the GenerateContentConfig. */ + public static Builder newBuilder() { + return new Builder(); + } + + private GenerateContentConfig(Builder builder) { + if (builder.generationConfig != null) { + this.generationConfig = builder.generationConfig; + } else { + this.generationConfig = null; + } + if (builder.safetySettings != null) { + this.safetySettings = builder.safetySettings; + } else { + this.safetySettings = ImmutableList.of(); + } + if (builder.tools != null) { + this.tools = builder.tools; + } else { + this.tools = ImmutableList.of(); + } + } + + /** Builder class for {@link GenerateContentConfig}. */ + public static class Builder { + private GenerationConfig generationConfig; + private ImmutableList safetySettings; + private ImmutableList tools; + + private Builder() {} + + /** Builds a GenerateContentConfig instance. */ + public GenerateContentConfig build() { + return new GenerateContentConfig(this); + } + + /** + * Set {@link com.google.cloud.vertexai.api.GenerationConfig} that will be used in the generate + * content API call. + * + * @return builder for the GenerateContentConfig + */ + @BetaApi + public Builder setGenerationConfig(GenerationConfig generationConfig) { + this.generationConfig = generationConfig; + return this; + } + + /** + * Set a list of {@link com.google.cloud.vertexai.api.SafetySetting} that will be used in the + * generate content API call. + * + * @return builder for the GenerateContentConfig + */ + @BetaApi + public Builder setSafetySettings(List safetySettings) { + ImmutableList.Builder builder = ImmutableList.builder(); + for (SafetySetting safetySetting : safetySettings) { + if (safetySetting != null) { + builder.add(safetySetting); + } + } + this.safetySettings = builder.build(); + + return this; + } + + /** + * Set a list of {@link com.google.cloud.vertexai.api.Tool} that will be used in the generate + * content API call. + * + * @return builder for the GenerateContentConfig + */ + @BetaApi + public Builder setTools(List tools) { + ImmutableList.Builder builder = ImmutableList.builder(); + for (Tool tool : tools) { + if (tool != null) { + builder.add(tool); + } + } + this.tools = builder.build(); + + return this; + } + } + + /** + * Sets the value for {@link #getGenerationConfig}, which will be used in the generate content API + * call. + */ + @BetaApi + public void setGenerationConfig(GenerationConfig generationConfig) { + this.generationConfig = generationConfig; + } + + /** + * Sets the value for {@link #getSafetySettings}, which will be used in the generate content API + * call. + */ + @BetaApi("safetySettings is a preview feature.") + public void setSafetySettings(List safetySettings) { + ImmutableList.Builder builder = ImmutableList.builder(); + for (SafetySetting safetySetting : safetySettings) { + if (safetySetting != null) { + builder.add(safetySetting); + } + } + this.safetySettings = builder.build(); + } + + /** Sets the value for {@link #getTools}, which will be used in the generate content API call. */ + @BetaApi("tools is a preview feature.") + public void setTools(List tools) { + ImmutableList.Builder builder = ImmutableList.builder(); + for (Tool tool : tools) { + if (tool != null) { + builder.add(tool); + } + } + this.tools = builder.build(); + } + + /** Returns the {@link com.google.cloud.vertexai.api.GenerationConfig} of this config. */ + @BetaApi + public GenerationConfig getGenerationConfig() { + return this.generationConfig; + } + + /** Returns a list of {@link com.google.cloud.vertexai.api.SafetySettings} of this config. */ + @BetaApi("safetySettings is a preview feature.") + public ImmutableList getSafetySettings() { + return this.safetySettings; + } + + /** Returns a list of {@link com.google.cloud.vertexai.api.Tool} of this config. */ + @BetaApi("tools is a preview feature.") + public ImmutableList getTools() { + return this.tools; + } +} diff --git a/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/generativeai/GenerateContentConfigTest.java b/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/generativeai/GenerateContentConfigTest.java new file mode 100644 index 000000000000..835b0604aab9 --- /dev/null +++ b/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/generativeai/GenerateContentConfigTest.java @@ -0,0 +1,96 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.vertexai.generativeai; + +import static com.google.common.truth.Truth.assertThat; + +import com.google.cloud.vertexai.api.FunctionDeclaration; +import com.google.cloud.vertexai.api.GenerationConfig; +import com.google.cloud.vertexai.api.HarmCategory; +import com.google.cloud.vertexai.api.SafetySetting; +import com.google.cloud.vertexai.api.SafetySetting.HarmBlockThreshold; +import com.google.cloud.vertexai.api.Schema; +import com.google.cloud.vertexai.api.Tool; +import com.google.cloud.vertexai.api.Type; +import java.util.Arrays; +import java.util.List; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public final class GenerateContentConfigTest { + private static final GenerationConfig GENERATION_CONFIG = + GenerationConfig.newBuilder().setCandidateCount(1).build(); + private static final SafetySetting SAFETY_SETTING = + SafetySetting.newBuilder() + .setCategory(HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT) + .setThreshold(HarmBlockThreshold.BLOCK_LOW_AND_ABOVE) + .build(); + private static final Tool TOOL = + Tool.newBuilder() + .addFunctionDeclarations( + FunctionDeclaration.newBuilder() + .setName("getCurrentWeather") + .setDescription("Get the current weather in a given location") + .setParameters( + Schema.newBuilder() + .setType(Type.OBJECT) + .putProperties( + "location", + Schema.newBuilder() + .setType(Type.STRING) + .setDescription("location") + .build()) + .addRequired("location"))) + .build(); + + private List safetySettings = Arrays.asList(SAFETY_SETTING); + private List tools = Arrays.asList(TOOL); + + private GenerateContentConfig config; + + @Test + public void testInstantiateGenerateContentConfigWithBuilder() { + config = + GenerateContentConfig.newBuilder() + .setGenerationConfig(GENERATION_CONFIG) + .setSafetySettings(safetySettings) + .setTools(tools) + .build(); + assertThat(config.getGenerationConfig()).isEqualTo(GENERATION_CONFIG); + assertThat(config.getSafetySettings()).isEqualTo(safetySettings); + assertThat(config.getTools()).isEqualTo(tools); + } + + @Test + public void testGenerateContentConfigSetters() { + config = GenerateContentConfig.newBuilder().build(); + + assertThat(config.getGenerationConfig()).isNull(); + assertThat(config.getSafetySettings()).isEmpty(); + assertThat(config.getTools()).isEmpty(); + + config.setGenerationConfig(GENERATION_CONFIG); + config.setSafetySettings(safetySettings); + config.setTools(tools); + + assertThat(config.getGenerationConfig()).isEqualTo(GENERATION_CONFIG); + assertThat(config.getSafetySettings()).isEqualTo(safetySettings); + assertThat(config.getTools()).isEqualTo(tools); + } +}