From 54a2cfca6327872fd623387349ac22bb2a008977 Mon Sep 17 00:00:00 2001 From: Nicholas Cook Date: Fri, 2 Aug 2024 13:58:45 -0700 Subject: [PATCH] feat(aiplatform): add initial Imagen code sample and test (#9422) * feat(aiplatform): add initial Imagen code sample and test * Update aiplatform/src/main/java/aiplatform/imagen/GenerateImageSample.java Co-authored-by: Sita Lakshmi Sangameswaran * address feedback * update to Imagen 3 * trigger build * trigger build * trigger build * trigger build * trigger build * trigger build * change test project ID var --------- Co-authored-by: Sita Lakshmi Sangameswaran --- .../imagen/GenerateImageSample.java | 105 ++++++++++++++++++ .../imagen/GenerateImageSampleTest.java | 73 ++++++++++++ 2 files changed, 178 insertions(+) create mode 100644 aiplatform/src/main/java/aiplatform/imagen/GenerateImageSample.java create mode 100644 aiplatform/src/test/java/aiplatform/imagen/GenerateImageSampleTest.java diff --git a/aiplatform/src/main/java/aiplatform/imagen/GenerateImageSample.java b/aiplatform/src/main/java/aiplatform/imagen/GenerateImageSample.java new file mode 100644 index 00000000000..b3719a35c5b --- /dev/null +++ b/aiplatform/src/main/java/aiplatform/imagen/GenerateImageSample.java @@ -0,0 +1,105 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform.imagen; + +// [START generativeaionvertexai_imagen_generate_image] + +import com.google.api.gax.rpc.ApiException; +import com.google.cloud.aiplatform.v1.EndpointName; +import com.google.cloud.aiplatform.v1.PredictResponse; +import com.google.cloud.aiplatform.v1.PredictionServiceClient; +import com.google.cloud.aiplatform.v1.PredictionServiceSettings; +import com.google.gson.Gson; +import com.google.protobuf.InvalidProtocolBufferException; +import com.google.protobuf.Value; +import com.google.protobuf.util.JsonFormat; +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.Base64; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +public class GenerateImageSample { + + public static void main(String[] args) throws IOException { + // TODO(developer): Replace these variables before running the sample. + String projectId = "my-project-id"; + String location = "us-central1"; + String prompt = ""; // The text prompt describing what you want to see. + + generateImage(projectId, location, prompt); + } + + // Generate an image using a text prompt using an Imagen model + public static PredictResponse generateImage(String projectId, String location, String prompt) + throws ApiException, IOException { + final String endpoint = String.format("%s-aiplatform.googleapis.com:443", location); + PredictionServiceSettings predictionServiceSettings = + PredictionServiceSettings.newBuilder().setEndpoint(endpoint).build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. + try (PredictionServiceClient predictionServiceClient = + PredictionServiceClient.create(predictionServiceSettings)) { + + final EndpointName endpointName = + EndpointName.ofProjectLocationPublisherModelName( + projectId, location, "google", "imagen-3.0-generate-001"); + + Map instancesMap = new HashMap<>(); + instancesMap.put("prompt", prompt); + Value instances = mapToValue(instancesMap); + + Map paramsMap = new HashMap<>(); + paramsMap.put("sampleCount", 1); + // You can't use a seed value and watermark at the same time. + // paramsMap.put("seed", 100); + // paramsMap.put("addWatermark", true); + paramsMap.put("aspectRatio", "1:1"); + paramsMap.put("safetyFilterLevel", "block_some"); + paramsMap.put("personGeneration", "allow_adult"); + Value parameters = mapToValue(paramsMap); + + PredictResponse predictResponse = + predictionServiceClient.predict( + endpointName, Collections.singletonList(instances), parameters); + + for (Value prediction : predictResponse.getPredictionsList()) { + Map fieldsMap = prediction.getStructValue().getFieldsMap(); + if (fieldsMap.containsKey("bytesBase64Encoded")) { + String bytesBase64Encoded = fieldsMap.get("bytesBase64Encoded").getStringValue(); + Path tmpPath = Files.createTempFile("imagen-", ".png"); + Files.write(tmpPath, Base64.getDecoder().decode(bytesBase64Encoded)); + System.out.format("Image file written to: %s\n", tmpPath.toUri()); + } + } + return predictResponse; + } + } + + private static Value mapToValue(Map map) throws InvalidProtocolBufferException { + Gson gson = new Gson(); + String json = gson.toJson(map); + Value.Builder builder = Value.newBuilder(); + JsonFormat.parser().merge(json, builder); + return builder.build(); + } +} + +// [END generativeaionvertexai_imagen_generate_image] diff --git a/aiplatform/src/test/java/aiplatform/imagen/GenerateImageSampleTest.java b/aiplatform/src/test/java/aiplatform/imagen/GenerateImageSampleTest.java new file mode 100644 index 00000000000..0112df5b324 --- /dev/null +++ b/aiplatform/src/test/java/aiplatform/imagen/GenerateImageSampleTest.java @@ -0,0 +1,73 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform.imagen; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.assertNotNull; + +import com.google.cloud.aiplatform.v1.PredictResponse; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class GenerateImageSampleTest { + + private static final String PROJECT = System.getenv("GOOGLE_CLOUD_PROJECT"); + private static final String PROMPT = "a dog reading a newspaper"; + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + + private static void requireEnvVar(String varName) { + String errorMessage = + String.format("Environment variable '%s' is required to perform these tests.", varName); + assertNotNull(errorMessage, System.getenv(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("GOOGLE_CLOUD_PROJECT"); + } + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + } + + @After + public void tearDown() { + System.out.flush(); + System.setOut(originalPrintStream); + } + + @Test + public void testGenerateImageSample() throws IOException { + PredictResponse response = GenerateImageSample.generateImage(PROJECT, "us-central1", PROMPT); + assertThat(response).isNotNull(); + } +}