diff --git a/retail/interactive-tutorials/src/main/java/prediction/FilteringPrediction.java b/retail/interactive-tutorials/src/main/java/prediction/FilteringPrediction.java new file mode 100644 index 00000000000..25dd7400a6b --- /dev/null +++ b/retail/interactive-tutorials/src/main/java/prediction/FilteringPrediction.java @@ -0,0 +1,89 @@ +/* + * Copyright 2022 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * [START retail_prediction_get_prediction_with_filtering] + * Call Retail API to get predictions from Recommendation AI using filtering. + */ + +package prediction; + +import com.google.cloud.retail.v2.PredictRequest; +import com.google.cloud.retail.v2.PredictResponse; +import com.google.cloud.retail.v2.PredictionServiceClient; +import com.google.cloud.retail.v2.Product; +import com.google.cloud.retail.v2.ProductDetail; +import com.google.cloud.retail.v2.UserEvent; +import com.google.protobuf.Value; +import java.io.IOException; + +public class FilteringPrediction { + + public static void main(String[] args) { + String projectId = System.getenv("GOOGLE_CLOUD_PROJECT"); + String placementId = System.getenv("GOOGLE_CLOUD_PLACEMENT"); + String predictPlacement = + String.format( + "projects/%s/locations/global/catalogs/default_catalog/placements/%s", + projectId, placementId); + + predict(predictPlacement); + } + + public static void predict(String predictPlacement) { + try (PredictionServiceClient predictionServiceClient = PredictionServiceClient.create()) { + PredictResponse predictResponse = + predictionServiceClient.predict(getPredictRequest(predictPlacement)); + System.out.printf("Predict response: %n%s", predictResponse); + } catch (IOException e) { + e.printStackTrace(); + } + } + + private static PredictRequest getPredictRequest(String predictPlacement) { + // create product object + Product product = + Product.newBuilder() + .setId("55106") // Id of real product + .build(); + + // create product detail object + ProductDetail productDetail = ProductDetail.newBuilder().setProduct(product).build(); + + // create user event object + UserEvent userEvent = + UserEvent.newBuilder() + .setEventType("detail-page-view") + .setVisitorId("281790") // Unique identifier to track visitors + .addProductDetails(productDetail) + .build(); + + PredictRequest predictRequest = + PredictRequest.newBuilder() + .setPlacement(predictPlacement) + .setUserEvent(userEvent) + // TRY DIFFERENT FILTER HERE: + .setFilter("filterOutOfStockItems") + // TRY TO UPDATE `strictFiltering` HERE: + .putParams("strictFiltering", Value.newBuilder().setBoolValue(true).build()) + .build(); + System.out.printf("Predict request: %n%s", predictRequest); + + return predictRequest; + } +} + +// [END retail_prediction_get_prediction_with_filtering] diff --git a/retail/interactive-tutorials/src/main/java/prediction/PredictionWithParameters.java b/retail/interactive-tutorials/src/main/java/prediction/PredictionWithParameters.java new file mode 100644 index 00000000000..03ffe8139b4 --- /dev/null +++ b/retail/interactive-tutorials/src/main/java/prediction/PredictionWithParameters.java @@ -0,0 +1,89 @@ +/* + * Copyright 2022 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * [START retail_prediction_get_prediction_with_params] + * Call Retail API to get predictions from Recommendation AI using parameters. + */ + +package prediction; + +import com.google.cloud.retail.v2.PredictRequest; +import com.google.cloud.retail.v2.PredictResponse; +import com.google.cloud.retail.v2.PredictionServiceClient; +import com.google.cloud.retail.v2.Product; +import com.google.cloud.retail.v2.ProductDetail; +import com.google.cloud.retail.v2.UserEvent; +import com.google.protobuf.Value; +import java.io.IOException; + +public class PredictionWithParameters { + + public static void main(String[] args) { + String projectId = System.getenv("GOOGLE_CLOUD_PROJECT"); + String placementId = System.getenv("GOOGLE_CLOUD_PLACEMENT"); + String predictPlacement = + String.format( + "projects/%s/locations/global/catalogs/default_catalog/placements/%s", + projectId, placementId); + + predict(predictPlacement); + } + + public static void predict(String predictPlacement) { + try (PredictionServiceClient predictionServiceClient = PredictionServiceClient.create()) { + PredictResponse predictResponse = + predictionServiceClient.predict(getPredictRequest(predictPlacement)); + System.out.printf("Predict response: %n%s", predictResponse); + } catch (IOException e) { + e.printStackTrace(); + } + } + + private static PredictRequest getPredictRequest(String predictPlacement) { + // create product object + Product product = + Product.newBuilder() + .setId("55106") // Id of real product + .build(); + + // create product detail object + ProductDetail productDetail = ProductDetail.newBuilder().setProduct(product).build(); + + // create user event object + UserEvent userEvent = + UserEvent.newBuilder() + .setEventType("detail-page-view") + .setVisitorId("281790") // Unique identifier to track visitors + .addProductDetails(productDetail) + .build(); + + PredictRequest predictRequest = + PredictRequest.newBuilder() + .setPlacement(predictPlacement) // Placement is used to identify the Serving Config name + .setUserEvent(userEvent) // Context about the user is required for event logging + // TRY TO ADD/UPDATE PARAMETERS `priceRerankLevel` OR `diversityLevel` HERE: + .putParams( + "priceRerankLevel", + Value.newBuilder().setStringValue("low-price-reranking").build()) + .build(); + System.out.printf("Predict request: %n%s", predictRequest); + + return predictRequest; + } +} + +// [END retail_prediction_get_prediction_with_params] diff --git a/retail/interactive-tutorials/src/main/java/prediction/SimplePrediction.java b/retail/interactive-tutorials/src/main/java/prediction/SimplePrediction.java new file mode 100644 index 00000000000..a7fcabd9309 --- /dev/null +++ b/retail/interactive-tutorials/src/main/java/prediction/SimplePrediction.java @@ -0,0 +1,86 @@ +/* + * Copyright 2022 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * [START retail_prediction_get_simple_prediction] + * Call Retail API to get predictions from Recommendation AI using simple request. + */ + +package prediction; + +import com.google.cloud.retail.v2.PredictRequest; +import com.google.cloud.retail.v2.PredictResponse; +import com.google.cloud.retail.v2.PredictionServiceClient; +import com.google.cloud.retail.v2.Product; +import com.google.cloud.retail.v2.ProductDetail; +import com.google.cloud.retail.v2.UserEvent; +import com.google.protobuf.Value; +import java.io.IOException; + +public class SimplePrediction { + + public static void main(String[] args) { + String projectId = System.getenv("GOOGLE_CLOUD_PROJECT"); + String placementId = System.getenv("GOOGLE_CLOUD_PLACEMENT"); + String predictPlacement = + String.format( + "projects/%s/locations/global/catalogs/default_catalog/placements/%s", + projectId, placementId); + + predict(predictPlacement); + } + + public static void predict(String predictPlacement) { + try (PredictionServiceClient predictionServiceClient = PredictionServiceClient.create()) { + PredictResponse predictResponse = + predictionServiceClient.predict(getPredictRequest(predictPlacement)); + System.out.printf("Predict response: %n%s", predictResponse); + } catch (IOException e) { + e.printStackTrace(); + } + } + + private static PredictRequest getPredictRequest(String predictPlacement) { + // create product object + Product product = + Product.newBuilder() + .setId("55106") // Id of real product + .build(); + + // create product detail object + ProductDetail productDetail = ProductDetail.newBuilder().setProduct(product).build(); + + // create user event object + UserEvent userEvent = + UserEvent.newBuilder() + .setEventType("detail-page-view") + .setVisitorId("281790") // Unique identifier to track visitors + .addProductDetails(productDetail) + .build(); + + PredictRequest predictRequest = + PredictRequest.newBuilder() + .setPlacement(predictPlacement) // Placement is used to identify the Serving Config name + .setUserEvent(userEvent) // Context about the user is required for event logging + .putParams("returnProduct", Value.newBuilder().setBoolValue(true).build()) + .build(); + System.out.printf("Predict request: %n%s", predictRequest); + + return predictRequest; + } +} + +// [END retail_prediction_get_simple_prediction] diff --git a/retail/interactive-tutorials/src/test/java/prediction/FilteringPredictionTest.java b/retail/interactive-tutorials/src/test/java/prediction/FilteringPredictionTest.java new file mode 100644 index 00000000000..b8c1c72210e --- /dev/null +++ b/retail/interactive-tutorials/src/test/java/prediction/FilteringPredictionTest.java @@ -0,0 +1,68 @@ +/* + * Copyright 2022 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package prediction; + +import static com.google.common.truth.Truth.assertThat; +import static prediction.FilteringPrediction.predict; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class FilteringPredictionTest { + + private ByteArrayOutputStream bout; + private PrintStream originalPrintStream; + + @Before + public void setUp() throws IOException, InterruptedException { + bout = new ByteArrayOutputStream(); + PrintStream out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + } + + @Test + public void testPredict() { + String projectId = System.getenv("GOOGLE_CLOUD_PROJECT"); + String placementId = System.getenv("GOOGLE_CLOUD_PLACEMENT"); + String predictPlacement = + String.format( + "projects/%s/locations/global/catalogs/default_catalog/placements/%s", + projectId, placementId); + + predict(predictPlacement); + + String outputResult = bout.toString(); + + assertThat(outputResult).contains("Predict request"); + assertThat(outputResult).contains("filter: \"filterOutOfStockItems\""); + assertThat(outputResult).contains("Predict response"); + } + + @After + public void tearDown() { + System.out.flush(); + System.setOut(originalPrintStream); + } +} diff --git a/retail/interactive-tutorials/src/test/java/prediction/PredictionWithParametersTest.java b/retail/interactive-tutorials/src/test/java/prediction/PredictionWithParametersTest.java new file mode 100644 index 00000000000..0ed278f45ed --- /dev/null +++ b/retail/interactive-tutorials/src/test/java/prediction/PredictionWithParametersTest.java @@ -0,0 +1,75 @@ +/* + * Copyright 2022 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package prediction; + +import static com.google.common.truth.Truth.assertThat; +import static prediction.PredictionWithParameters.predict; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class PredictionWithParametersTest { + + private ByteArrayOutputStream bout; + private PrintStream originalPrintStream; + + @Before + public void setUp() throws IOException, InterruptedException { + bout = new ByteArrayOutputStream(); + PrintStream out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + } + + @Test + public void testPredict() { + String projectId = System.getenv("GOOGLE_CLOUD_PROJECT"); + String placementId = System.getenv("GOOGLE_CLOUD_PLACEMENT"); + String predictPlacement = + String.format( + "projects/%s/locations/global/catalogs/default_catalog/placements/%s", + projectId, placementId); + + predict(predictPlacement); + + String outputResult = bout.toString(); + + assertThat(outputResult).contains("Predict request"); + assertThat(outputResult) + .contains( + "params {\n" + + " key: \"priceRerankLevel\"\n" + + " value {\n" + + " string_value: \"low-price-reranking\"\n" + + " }\n" + + "}"); + assertThat(outputResult).contains("Predict response"); + } + + @After + public void tearDown() { + System.out.flush(); + System.setOut(originalPrintStream); + } +} diff --git a/retail/interactive-tutorials/src/test/java/prediction/SimplePredictionTest.java b/retail/interactive-tutorials/src/test/java/prediction/SimplePredictionTest.java new file mode 100644 index 00000000000..d301030edd8 --- /dev/null +++ b/retail/interactive-tutorials/src/test/java/prediction/SimplePredictionTest.java @@ -0,0 +1,75 @@ +/* + * Copyright 2022 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package prediction; + +import static com.google.common.truth.Truth.assertThat; +import static prediction.SimplePrediction.predict; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class SimplePredictionTest { + + private ByteArrayOutputStream bout; + private PrintStream originalPrintStream; + + @Before + public void setUp() throws IOException, InterruptedException { + bout = new ByteArrayOutputStream(); + PrintStream out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + } + + @Test + public void testPredict() { + String projectId = System.getenv("GOOGLE_CLOUD_PROJECT"); + String placementId = System.getenv("GOOGLE_CLOUD_PLACEMENT"); + String predictPlacement = + String.format( + "projects/%s/locations/global/catalogs/default_catalog/placements/%s", + projectId, placementId); + + predict(predictPlacement); + + String outputResult = bout.toString(); + + assertThat(outputResult).contains("Predict request"); + assertThat(outputResult) + .contains( + "params {\n" + + " key: \"returnProduct\"\n" + + " value {\n" + + " bool_value: true\n" + + " }\n" + + "}"); + assertThat(outputResult).contains("Predict response"); + } + + @After + public void tearDown() { + System.out.flush(); + System.setOut(originalPrintStream); + } +}