From ac542f24512da01d106ad528aa5cee6de9af2a22 Mon Sep 17 00:00:00 2001
From: Mike <45373284+munkhuushmgl@users.noreply.github.com>
Date: Fri, 6 Nov 2020 11:42:04 -0800
Subject: [PATCH] samples: ucaip samples batch 6 of 6 (#17)
---
aiplatform/snippets/pom.xml | 13 +-
.../aiplatform/CreateDatasetTextSample.java | 84 +++++++
...iningPipelineTextClassificationSample.java | 235 +++++++++++++++++
...ingPipelineTextEntityExtractionSample.java | 236 +++++++++++++++++
...ngPipelineTextSentimentAnalysisSample.java | 238 ++++++++++++++++++
...delEvaluationTextClassificationSample.java | 81 ++++++
...lEvaluationTextEntityExtractionSample.java | 81 ++++++
...EvaluationTextSentimentAnalysisSample.java | 81 ++++++
...taTextClassificationSingleLabelSample.java | 90 +++++++
.../ImportDataTextEntityExtractionSample.java | 89 +++++++
...ImportDataTextSentimentAnalysisSample.java | 90 +++++++
...ctTextClassificationSingleLabelSample.java | 78 ++++++
.../PredictTextEntityExtractionSample.java | 78 ++++++
.../PredictTextSentimentAnalysisSample.java | 78 ++++++
.../CreateDatasetTextSampleTest.java | 94 +++++++
...gPipelineTextClassificationSampleTest.java | 110 ++++++++
...ipelineTextEntityExtractionSampleTest.java | 112 +++++++++
...pelineTextSentimentAnalysisSampleTest.java | 112 +++++++++
...valuationTextClassificationSampleTest.java | 81 ++++++
...luationTextEntityExtractionSampleTest.java | 81 ++++++
...uationTextSentimentAnalysisSampleTest.java | 81 ++++++
.../java/aiplatform/ImportDataSampleTest.java | 90 +++++++
...xtClassificationSingleLabelSampleTest.java | 76 ++++++
...PredictTextEntityExtractionSampleTest.java | 89 +++++++
...redictTextSentimentAnalysisSampleTest.java | 90 +++++++
25 files changed, 2566 insertions(+), 2 deletions(-)
create mode 100644 aiplatform/snippets/src/main/java/aiplatform/CreateDatasetTextSample.java
create mode 100644 aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineTextClassificationSample.java
create mode 100644 aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineTextEntityExtractionSample.java
create mode 100644 aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineTextSentimentAnalysisSample.java
create mode 100644 aiplatform/snippets/src/main/java/aiplatform/GetModelEvaluationTextClassificationSample.java
create mode 100644 aiplatform/snippets/src/main/java/aiplatform/GetModelEvaluationTextEntityExtractionSample.java
create mode 100644 aiplatform/snippets/src/main/java/aiplatform/GetModelEvaluationTextSentimentAnalysisSample.java
create mode 100644 aiplatform/snippets/src/main/java/aiplatform/ImportDataTextClassificationSingleLabelSample.java
create mode 100644 aiplatform/snippets/src/main/java/aiplatform/ImportDataTextEntityExtractionSample.java
create mode 100644 aiplatform/snippets/src/main/java/aiplatform/ImportDataTextSentimentAnalysisSample.java
create mode 100644 aiplatform/snippets/src/main/java/aiplatform/PredictTextClassificationSingleLabelSample.java
create mode 100644 aiplatform/snippets/src/main/java/aiplatform/PredictTextEntityExtractionSample.java
create mode 100644 aiplatform/snippets/src/main/java/aiplatform/PredictTextSentimentAnalysisSample.java
create mode 100644 aiplatform/snippets/src/test/java/aiplatform/CreateDatasetTextSampleTest.java
create mode 100644 aiplatform/snippets/src/test/java/aiplatform/CreateTrainingPipelineTextClassificationSampleTest.java
create mode 100644 aiplatform/snippets/src/test/java/aiplatform/CreateTrainingPipelineTextEntityExtractionSampleTest.java
create mode 100644 aiplatform/snippets/src/test/java/aiplatform/CreateTrainingPipelineTextSentimentAnalysisSampleTest.java
create mode 100644 aiplatform/snippets/src/test/java/aiplatform/GetModelEvaluationTextClassificationSampleTest.java
create mode 100644 aiplatform/snippets/src/test/java/aiplatform/GetModelEvaluationTextEntityExtractionSampleTest.java
create mode 100644 aiplatform/snippets/src/test/java/aiplatform/GetModelEvaluationTextSentimentAnalysisSampleTest.java
create mode 100644 aiplatform/snippets/src/test/java/aiplatform/ImportDataSampleTest.java
create mode 100644 aiplatform/snippets/src/test/java/aiplatform/PredictTextClassificationSingleLabelSampleTest.java
create mode 100644 aiplatform/snippets/src/test/java/aiplatform/PredictTextEntityExtractionSampleTest.java
create mode 100644 aiplatform/snippets/src/test/java/aiplatform/PredictTextSentimentAnalysisSampleTest.java
diff --git a/aiplatform/snippets/pom.xml b/aiplatform/snippets/pom.xml
index 335f524539a..8ea1e440238 100644
--- a/aiplatform/snippets/pom.xml
+++ b/aiplatform/snippets/pom.xml
@@ -22,8 +22,6 @@
1.8
UTF-8
-
-
@@ -31,6 +29,17 @@
google-cloud-aiplatform
0.0.1-SNAPSHOT
+
+
+ com.google.protobuf
+ protobuf-java-util
+ 4.0.0-rc-1
+
+
+ com.google.cloud
+ google-cloud-storage
+ 1.111.0
+
com.google.cloud
google-cloud-storage
diff --git a/aiplatform/snippets/src/main/java/aiplatform/CreateDatasetTextSample.java b/aiplatform/snippets/src/main/java/aiplatform/CreateDatasetTextSample.java
new file mode 100644
index 00000000000..ff3c93ee4d8
--- /dev/null
+++ b/aiplatform/snippets/src/main/java/aiplatform/CreateDatasetTextSample.java
@@ -0,0 +1,84 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package aiplatform;
+
+// [START aiplatform_create_dataset_text_sample]
+
+import com.google.api.gax.longrunning.OperationFuture;
+import com.google.cloud.aiplatform.v1beta1.CreateDatasetOperationMetadata;
+import com.google.cloud.aiplatform.v1beta1.Dataset;
+import com.google.cloud.aiplatform.v1beta1.DatasetServiceClient;
+import com.google.cloud.aiplatform.v1beta1.DatasetServiceSettings;
+import com.google.cloud.aiplatform.v1beta1.LocationName;
+import java.io.IOException;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
+
+public class CreateDatasetTextSample {
+
+ public static void main(String[] args)
+ throws IOException, InterruptedException, ExecutionException, TimeoutException {
+ // TODO(developer): Replace these variables before running the sample.
+ String project = "YOUR_PROJECT_ID";
+ String datasetDisplayName = "YOUR_DATASET_DISPLAY_NAME";
+
+ createDatasetTextSample(project, datasetDisplayName);
+ }
+
+ static void createDatasetTextSample(String project, String datasetDisplayName)
+ throws IOException, InterruptedException, ExecutionException, TimeoutException {
+ DatasetServiceSettings datasetServiceSettings =
+ DatasetServiceSettings.newBuilder()
+ .setEndpoint("us-central1-aiplatform.googleapis.com:443")
+ .build();
+
+ // Initialize client that will be used to send requests. This client only needs to be created
+ // once, and can be reused for multiple requests. After completing all of your requests, call
+ // the "close" method on the client to safely clean up any remaining background resources.
+ try (DatasetServiceClient datasetServiceClient =
+ DatasetServiceClient.create(datasetServiceSettings)) {
+ String location = "us-central1";
+ String metadataSchemaUri =
+ "gs://google-cloud-aiplatform/schema/dataset/metadata/text_1.0.0.yaml";
+
+ LocationName locationName = LocationName.of(project, location);
+ Dataset dataset =
+ Dataset.newBuilder()
+ .setDisplayName(datasetDisplayName)
+ .setMetadataSchemaUri(metadataSchemaUri)
+ .build();
+
+ OperationFuture datasetFuture =
+ datasetServiceClient.createDatasetAsync(locationName, dataset);
+ System.out.format("Operation name: %s\n", datasetFuture.getInitialFuture().get().getName());
+
+ System.out.println("Waiting for operation to finish...");
+ Dataset datasetResponse = datasetFuture.get(120, TimeUnit.SECONDS);
+
+ System.out.println("Create Text Dataset Response");
+ System.out.format("\tName: %s\n", datasetResponse.getName());
+ System.out.format("\tDisplay Name: %s\n", datasetResponse.getDisplayName());
+ System.out.format("\tMetadata Schema Uri: %s\n", datasetResponse.getMetadataSchemaUri());
+ System.out.format("\tMetadata: %s\n", datasetResponse.getMetadata());
+ System.out.format("\tCreate Time: %s\n", datasetResponse.getCreateTime());
+ System.out.format("\tUpdate Time: %s\n", datasetResponse.getUpdateTime());
+ System.out.format("\tLabels: %s\n", datasetResponse.getLabelsMap());
+ }
+ }
+}
+// [END aiplatform_create_dataset_text_sample]
diff --git a/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineTextClassificationSample.java b/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineTextClassificationSample.java
new file mode 100644
index 00000000000..c67e49a058e
--- /dev/null
+++ b/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineTextClassificationSample.java
@@ -0,0 +1,235 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package aiplatform;
+
+// [START aiplatform_create_training_pipeline_text_classification_sample]
+
+import com.google.cloud.aiplatform.v1beta1.DeployedModelRef;
+import com.google.cloud.aiplatform.v1beta1.EnvVar;
+import com.google.cloud.aiplatform.v1beta1.ExplanationMetadata;
+import com.google.cloud.aiplatform.v1beta1.ExplanationParameters;
+import com.google.cloud.aiplatform.v1beta1.ExplanationSpec;
+import com.google.cloud.aiplatform.v1beta1.FilterSplit;
+import com.google.cloud.aiplatform.v1beta1.FractionSplit;
+import com.google.cloud.aiplatform.v1beta1.InputDataConfig;
+import com.google.cloud.aiplatform.v1beta1.LocationName;
+import com.google.cloud.aiplatform.v1beta1.Model;
+import com.google.cloud.aiplatform.v1beta1.Model.ExportFormat;
+import com.google.cloud.aiplatform.v1beta1.ModelContainerSpec;
+import com.google.cloud.aiplatform.v1beta1.PipelineServiceClient;
+import com.google.cloud.aiplatform.v1beta1.PipelineServiceSettings;
+import com.google.cloud.aiplatform.v1beta1.Port;
+import com.google.cloud.aiplatform.v1beta1.PredefinedSplit;
+import com.google.cloud.aiplatform.v1beta1.PredictSchemata;
+import com.google.cloud.aiplatform.v1beta1.SampledShapleyAttribution;
+import com.google.cloud.aiplatform.v1beta1.TimestampSplit;
+import com.google.cloud.aiplatform.v1beta1.TrainingPipeline;
+import com.google.protobuf.Any;
+import com.google.protobuf.Value;
+import com.google.protobuf.util.JsonFormat;
+import com.google.rpc.Status;
+import java.io.IOException;
+import java.util.List;
+
+public class CreateTrainingPipelineTextClassificationSample {
+
+ public static void main(String[] args) throws IOException {
+ // TODO(developer): Replace these variables before running the sample.
+ String trainingPipelineDisplayName = "YOUR_TRAINING_PIPELINE_DISPLAY_NAME";
+ String project = "YOUR_PROJECT_ID";
+ String datasetId = "YOUR_DATASET_ID";
+ String modelDisplayName = "YOUR_MODEL_DISPLAY_NAME";
+
+ createTrainingPipelineTextClassificationSample(
+ project, trainingPipelineDisplayName, datasetId, modelDisplayName);
+ }
+
+ static void createTrainingPipelineTextClassificationSample(
+ String project, String trainingPipelineDisplayName, String datasetId, String modelDisplayName)
+ throws IOException {
+ PipelineServiceSettings pipelineServiceSettings =
+ PipelineServiceSettings.newBuilder()
+ .setEndpoint("us-central1-aiplatform.googleapis.com:443")
+ .build();
+
+ // Initialize client that will be used to send requests. This client only needs to be created
+ // once, and can be reused for multiple requests. After completing all of your requests, call
+ // the "close" method on the client to safely clean up any remaining background resources.
+ try (PipelineServiceClient pipelineServiceClient =
+ PipelineServiceClient.create(pipelineServiceSettings)) {
+ String location = "us-central1";
+ String trainingTaskDefinition =
+ "gs://google-cloud-aiplatform/schema/trainingjob/definition/"
+ + "automl_text_classification_1.0.0.yaml";
+ String jsonString = "{\"multiLabel\": false}";
+
+ LocationName locationName = LocationName.of(project, location);
+
+ Value.Builder trainingTaskInputs = Value.newBuilder();
+ JsonFormat.parser().merge(jsonString, trainingTaskInputs);
+
+ InputDataConfig trainingInputDataConfig =
+ InputDataConfig.newBuilder().setDatasetId(datasetId).build();
+ Model model = Model.newBuilder().setDisplayName(modelDisplayName).build();
+ TrainingPipeline trainingPipeline =
+ TrainingPipeline.newBuilder()
+ .setDisplayName(trainingPipelineDisplayName)
+ .setTrainingTaskDefinition(trainingTaskDefinition)
+ .setTrainingTaskInputs(trainingTaskInputs)
+ .setInputDataConfig(trainingInputDataConfig)
+ .setModelToUpload(model)
+ .build();
+
+ TrainingPipeline trainingPipelineResponse =
+ pipelineServiceClient.createTrainingPipeline(locationName, trainingPipeline);
+
+ System.out.println("Create Training Pipeline Text Classification Response");
+ System.out.format("\tName: %s\n", trainingPipelineResponse.getName());
+ System.out.format("\tDisplay Name: %s\n", trainingPipelineResponse.getDisplayName());
+
+ System.out.format(
+ "\tTraining Task Definition %s\n", trainingPipelineResponse.getTrainingTaskDefinition());
+ System.out.format(
+ "\tTraining Task Inputs: %s\n", trainingPipelineResponse.getTrainingTaskInputs());
+ System.out.format(
+ "\tTraining Task Metadata: %s\n", trainingPipelineResponse.getTrainingTaskMetadata());
+ System.out.format("State: %s\n", trainingPipelineResponse.getState());
+
+ System.out.format("\tCreate Time: %s\n", trainingPipelineResponse.getCreateTime());
+ System.out.format("\tStartTime %s\n", trainingPipelineResponse.getStartTime());
+ System.out.format("\tEnd Time: %s\n", trainingPipelineResponse.getEndTime());
+ System.out.format("\tUpdate Time: %s\n", trainingPipelineResponse.getUpdateTime());
+ System.out.format("\tLabels: %s\n", trainingPipelineResponse.getLabelsMap());
+
+ InputDataConfig inputDataConfig = trainingPipelineResponse.getInputDataConfig();
+ System.out.println("\tInput Data Config");
+ System.out.format("\t\tDataset Id: %s", inputDataConfig.getDatasetId());
+ System.out.format("\t\tAnnotations Filter: %s\n", inputDataConfig.getAnnotationsFilter());
+
+ FractionSplit fractionSplit = inputDataConfig.getFractionSplit();
+ System.out.println("\t\tFraction Split");
+ System.out.format("\t\t\tTraining Fraction: %s\n", fractionSplit.getTrainingFraction());
+ System.out.format("\t\t\tValidation Fraction: %s\n", fractionSplit.getValidationFraction());
+ System.out.format("\t\t\tTest Fraction: %s\n", fractionSplit.getTestFraction());
+
+ FilterSplit filterSplit = inputDataConfig.getFilterSplit();
+ System.out.println("\t\tFilter Split");
+ System.out.format("\t\t\tTraining Filter: %s\n", filterSplit.getTrainingFilter());
+ System.out.format("\t\t\tValidation Filter: %s\n", filterSplit.getValidationFilter());
+ System.out.format("\t\t\tTest Filter: %s\n", filterSplit.getTestFilter());
+
+ PredefinedSplit predefinedSplit = inputDataConfig.getPredefinedSplit();
+ System.out.println("\t\tPredefined Split");
+ System.out.format("\t\t\tKey: %s\n", predefinedSplit.getKey());
+
+ TimestampSplit timestampSplit = inputDataConfig.getTimestampSplit();
+ System.out.println("\t\tTimestamp Split");
+ System.out.format("\t\t\tTraining Fraction: %s\n", timestampSplit.getTrainingFraction());
+ System.out.format("\t\t\tValidation Fraction: %s\n", timestampSplit.getValidationFraction());
+ System.out.format("\t\t\tTest Fraction: %s\n", timestampSplit.getTestFraction());
+ System.out.format("\t\t\tKey: %s\n", timestampSplit.getKey());
+
+ Model modelResponse = trainingPipelineResponse.getModelToUpload();
+ System.out.println("\tModel To Upload");
+ System.out.format("\t\tName: %s\n", modelResponse.getName());
+ System.out.format("\t\tDisplay Name: %s\n", modelResponse.getDisplayName());
+ System.out.format("\t\tDescription: %s\n", modelResponse.getDescription());
+
+ System.out.format("\t\tMetadata Schema Uri: %s\n", modelResponse.getMetadataSchemaUri());
+ System.out.format("\t\tMetadata: %s\n", modelResponse.getMetadata());
+ System.out.format("\t\tTraining Pipeline: %s\n", modelResponse.getTrainingPipeline());
+ System.out.format("\t\tArtifact Uri: %s\n", modelResponse.getArtifactUri());
+
+ System.out.format(
+ "\t\tSupported Deployment Resources Types: %s\n",
+ modelResponse.getSupportedDeploymentResourcesTypesList());
+ System.out.format(
+ "\t\tSupported Input Storage Formats: %s\n",
+ modelResponse.getSupportedInputStorageFormatsList());
+ System.out.format(
+ "\t\tSupported Output Storage Formats: %s\n",
+ modelResponse.getSupportedOutputStorageFormatsList());
+
+ System.out.format("\t\tCreate Time: %s\n", modelResponse.getCreateTime());
+ System.out.format("\t\tUpdate Time: %s\n", modelResponse.getUpdateTime());
+ System.out.format("\t\tLabels: %sn\n", modelResponse.getLabelsMap());
+
+ PredictSchemata predictSchemata = modelResponse.getPredictSchemata();
+ System.out.println("\t\tPredict Schemata");
+ System.out.format("\t\t\tInstance Schema Uri: %s\n", predictSchemata.getInstanceSchemaUri());
+ System.out.format(
+ "\t\t\tParameters Schema Uri: %s\n", predictSchemata.getParametersSchemaUri());
+ System.out.format(
+ "\t\t\tPrediction Schema Uri: %s\n", predictSchemata.getPredictionSchemaUri());
+
+ for (ExportFormat exportFormat : modelResponse.getSupportedExportFormatsList()) {
+ System.out.println("\t\tSupported Export Format");
+ System.out.format("\t\t\tId: %s\n", exportFormat.getId());
+ }
+
+ ModelContainerSpec modelContainerSpec = modelResponse.getContainerSpec();
+ System.out.println("\t\tContainer Spec");
+ System.out.format("\t\t\tImage Uri: %s\n", modelContainerSpec.getImageUri());
+ System.out.format("\t\t\tCommand: %s\n", modelContainerSpec.getCommandList());
+ System.out.format("\t\t\tArgs: %s\n", modelContainerSpec.getArgsList());
+ System.out.format("\t\t\tPredict Route: %s\n", modelContainerSpec.getPredictRoute());
+ System.out.format("\t\t\tHealth Route: %s\n", modelContainerSpec.getHealthRoute());
+
+ for (EnvVar envVar : modelContainerSpec.getEnvList()) {
+ System.out.println("\t\t\tEnv");
+ System.out.format("\t\t\t\tName: %s\n", envVar.getName());
+ System.out.format("\t\t\t\tValue: %s\n", envVar.getValue());
+ }
+
+ for (Port port : modelContainerSpec.getPortsList()) {
+ System.out.println("\t\t\tPort");
+ System.out.format("\t\t\t\tContainer Port: %s\n", port.getContainerPort());
+ }
+
+ for (DeployedModelRef deployedModelRef : modelResponse.getDeployedModelsList()) {
+ System.out.println("\t\tDeployed Model");
+ System.out.format("\t\t\tEndpoint: %s\n", deployedModelRef.getEndpoint());
+ System.out.format("\t\t\tDeployed Model Id: %s\n", deployedModelRef.getDeployedModelId());
+ }
+
+ ExplanationSpec explanationSpec = modelResponse.getExplanationSpec();
+ System.out.println("\t\tExplanation Spec");
+
+ ExplanationParameters explanationParameters = explanationSpec.getParameters();
+ System.out.println("\t\t\tParameters");
+
+ SampledShapleyAttribution sampledShapleyAttribution =
+ explanationParameters.getSampledShapleyAttribution();
+ System.out.println("\t\t\t\tSampled Shapley Attribution");
+ System.out.format("\t\t\t\t\tPath Count: %s\n", sampledShapleyAttribution.getPathCount());
+
+ ExplanationMetadata explanationMetadata = explanationSpec.getMetadata();
+ System.out.println("\t\t\tMetadata");
+ System.out.format("\t\t\t\tInputs: %s\n", explanationMetadata.getInputsMap());
+ System.out.format("\t\t\t\tOutputs: %s\n", explanationMetadata.getOutputsMap());
+ System.out.format(
+ "\t\t\t\tFeature Attributions Schema_uri: %s\n",
+ explanationMetadata.getFeatureAttributionsSchemaUri());
+
+ Status status = trainingPipelineResponse.getError();
+ System.out.println("\tError");
+ System.out.format("\t\tCode: %s\n", status.getCode());
+ System.out.format("\t\tMessage: %s\n", status.getMessage());
+ }
+ }
+}
+// [END aiplatform_create_training_pipeline_text_classification_sample]
diff --git a/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineTextEntityExtractionSample.java b/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineTextEntityExtractionSample.java
new file mode 100644
index 00000000000..3aef27086ba
--- /dev/null
+++ b/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineTextEntityExtractionSample.java
@@ -0,0 +1,236 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package aiplatform;
+
+// [START aiplatform_create_training_pipeline_text_entity_extraction_sample]
+
+import com.google.cloud.aiplatform.v1beta1.DeployedModelRef;
+import com.google.cloud.aiplatform.v1beta1.EnvVar;
+import com.google.cloud.aiplatform.v1beta1.ExplanationMetadata;
+import com.google.cloud.aiplatform.v1beta1.ExplanationParameters;
+import com.google.cloud.aiplatform.v1beta1.ExplanationSpec;
+import com.google.cloud.aiplatform.v1beta1.FilterSplit;
+import com.google.cloud.aiplatform.v1beta1.FractionSplit;
+import com.google.cloud.aiplatform.v1beta1.InputDataConfig;
+import com.google.cloud.aiplatform.v1beta1.LocationName;
+import com.google.cloud.aiplatform.v1beta1.Model;
+import com.google.cloud.aiplatform.v1beta1.Model.ExportFormat;
+import com.google.cloud.aiplatform.v1beta1.ModelContainerSpec;
+import com.google.cloud.aiplatform.v1beta1.PipelineServiceClient;
+import com.google.cloud.aiplatform.v1beta1.PipelineServiceSettings;
+import com.google.cloud.aiplatform.v1beta1.Port;
+import com.google.cloud.aiplatform.v1beta1.PredefinedSplit;
+import com.google.cloud.aiplatform.v1beta1.PredictSchemata;
+import com.google.cloud.aiplatform.v1beta1.SampledShapleyAttribution;
+import com.google.cloud.aiplatform.v1beta1.TimestampSplit;
+import com.google.cloud.aiplatform.v1beta1.TrainingPipeline;
+import com.google.protobuf.Any;
+import com.google.protobuf.Value;
+import com.google.protobuf.util.JsonFormat;
+import com.google.rpc.Status;
+import java.io.IOException;
+import java.util.List;
+
+public class CreateTrainingPipelineTextEntityExtractionSample {
+
+ public static void main(String[] args) throws IOException {
+ // TODO(developer): Replace these variables before running the sample.
+ String trainingPipelineDisplayName = "YOUR_TRAINING_PIPELINE_DISPLAY_NAME";
+ String project = "YOUR_PROJECT_ID";
+ String datasetId = "YOUR_DATASET_ID";
+ String modelDisplayName = "YOUR_MODEL_DISPLAY_NAME";
+
+ createTrainingPipelineTextEntityExtractionSample(
+ project, trainingPipelineDisplayName, datasetId, modelDisplayName);
+ }
+
+ static void createTrainingPipelineTextEntityExtractionSample(
+ String project, String trainingPipelineDisplayName, String datasetId, String modelDisplayName)
+ throws IOException {
+ PipelineServiceSettings pipelineServiceSettings =
+ PipelineServiceSettings.newBuilder()
+ .setEndpoint("us-central1-aiplatform.googleapis.com:443")
+ .build();
+
+ // Initialize client that will be used to send requests. This client only needs to be created
+ // once, and can be reused for multiple requests. After completing all of your requests, call
+ // the "close" method on the client to safely clean up any remaining background resources.
+ try (PipelineServiceClient pipelineServiceClient =
+ PipelineServiceClient.create(pipelineServiceSettings)) {
+ String location = "us-central1";
+ String trainingTaskDefinition =
+ "gs://google-cloud-aiplatform/schema/trainingjob/definition/"
+ + "automl_text_extraction_1.0.0.yaml";
+ String jsonString = "{}";
+
+ LocationName locationName = LocationName.of(project, location);
+
+ // Training task inputs are empty for text entity extraction
+ Value.Builder trainingTaskInputs = Value.newBuilder();
+ JsonFormat.parser().merge(jsonString, trainingTaskInputs);
+
+ InputDataConfig trainingInputDataConfig =
+ InputDataConfig.newBuilder().setDatasetId(datasetId).build();
+ Model model = Model.newBuilder().setDisplayName(modelDisplayName).build();
+ TrainingPipeline trainingPipeline =
+ TrainingPipeline.newBuilder()
+ .setDisplayName(trainingPipelineDisplayName)
+ .setTrainingTaskDefinition(trainingTaskDefinition)
+ .setTrainingTaskInputs(trainingTaskInputs)
+ .setInputDataConfig(trainingInputDataConfig)
+ .setModelToUpload(model)
+ .build();
+
+ TrainingPipeline trainingPipelineResponse =
+ pipelineServiceClient.createTrainingPipeline(locationName, trainingPipeline);
+
+ System.out.println("Create Training Pipeline Text Entity Extraction Response");
+ System.out.format("\tName: %s\n", trainingPipelineResponse.getName());
+ System.out.format("\tDisplay Name: %s\n", trainingPipelineResponse.getDisplayName());
+
+ System.out.format(
+ "\tTraining Task Definition %s\n", trainingPipelineResponse.getTrainingTaskDefinition());
+ System.out.format(
+ "\tTraining Task Inputs: %s\n", trainingPipelineResponse.getTrainingTaskInputs());
+ System.out.format(
+ "\tTraining Task Metadata: %s\n", trainingPipelineResponse.getTrainingTaskMetadata());
+ System.out.format("State: %s\n", trainingPipelineResponse.getState());
+
+ System.out.format("\tCreate Time: %s\n", trainingPipelineResponse.getCreateTime());
+ System.out.format("\tStartTime %s\n", trainingPipelineResponse.getStartTime());
+ System.out.format("\tEnd Time: %s\n", trainingPipelineResponse.getEndTime());
+ System.out.format("\tUpdate Time: %s\n", trainingPipelineResponse.getUpdateTime());
+ System.out.format("\tLabels: %s\n", trainingPipelineResponse.getLabelsMap());
+
+ InputDataConfig inputDataConfig = trainingPipelineResponse.getInputDataConfig();
+ System.out.println("\tInput Data Config");
+ System.out.format("\t\tDataset Id: %s", inputDataConfig.getDatasetId());
+ System.out.format("\t\tAnnotations Filter: %s\n", inputDataConfig.getAnnotationsFilter());
+
+ FractionSplit fractionSplit = inputDataConfig.getFractionSplit();
+ System.out.println("\t\tFraction Split");
+ System.out.format("\t\t\tTraining Fraction: %s\n", fractionSplit.getTrainingFraction());
+ System.out.format("\t\t\tValidation Fraction: %s\n", fractionSplit.getValidationFraction());
+ System.out.format("\t\t\tTest Fraction: %s\n", fractionSplit.getTestFraction());
+
+ FilterSplit filterSplit = inputDataConfig.getFilterSplit();
+ System.out.println("\t\tFilter Split");
+ System.out.format("\t\t\tTraining Filter: %s\n", filterSplit.getTrainingFilter());
+ System.out.format("\t\t\tValidation Filter: %s\n", filterSplit.getValidationFilter());
+ System.out.format("\t\t\tTest Filter: %s\n", filterSplit.getTestFilter());
+
+ PredefinedSplit predefinedSplit = inputDataConfig.getPredefinedSplit();
+ System.out.println("\t\tPredefined Split");
+ System.out.format("\t\t\tKey: %s\n", predefinedSplit.getKey());
+
+ TimestampSplit timestampSplit = inputDataConfig.getTimestampSplit();
+ System.out.println("\t\tTimestamp Split");
+ System.out.format("\t\t\tTraining Fraction: %s\n", timestampSplit.getTrainingFraction());
+ System.out.format("\t\t\tValidation Fraction: %s\n", timestampSplit.getValidationFraction());
+ System.out.format("\t\t\tTest Fraction: %s\n", timestampSplit.getTestFraction());
+ System.out.format("\t\t\tKey: %s\n", timestampSplit.getKey());
+
+ Model modelResponse = trainingPipelineResponse.getModelToUpload();
+ System.out.println("\tModel To Upload");
+ System.out.format("\t\tName: %s\n", modelResponse.getName());
+ System.out.format("\t\tDisplay Name: %s\n", modelResponse.getDisplayName());
+ System.out.format("\t\tDescription: %s\n", modelResponse.getDescription());
+
+ System.out.format("\t\tMetadata Schema Uri: %s\n", modelResponse.getMetadataSchemaUri());
+ System.out.format("\t\tMetadata: %s\n", modelResponse.getMetadata());
+ System.out.format("\t\tTraining Pipeline: %s\n", modelResponse.getTrainingPipeline());
+ System.out.format("\t\tArtifact Uri: %s\n", modelResponse.getArtifactUri());
+
+ System.out.format(
+ "\t\tSupported Deployment Resources Types: %s\n",
+ modelResponse.getSupportedDeploymentResourcesTypesList());
+ System.out.format(
+ "\t\tSupported Input Storage Formats: %s\n",
+ modelResponse.getSupportedInputStorageFormatsList());
+ System.out.format(
+ "\t\tSupported Output Storage Formats: %s\n",
+ modelResponse.getSupportedOutputStorageFormatsList());
+
+ System.out.format("\t\tCreate Time: %s\n", modelResponse.getCreateTime());
+ System.out.format("\t\tUpdate Time: %s\n", modelResponse.getUpdateTime());
+ System.out.format("\t\tLabels: %sn\n", modelResponse.getLabelsMap());
+
+ PredictSchemata predictSchemata = modelResponse.getPredictSchemata();
+ System.out.println("\t\tPredict Schemata");
+ System.out.format("\t\t\tInstance Schema Uri: %s\n", predictSchemata.getInstanceSchemaUri());
+ System.out.format(
+ "\t\t\tParameters Schema Uri: %s\n", predictSchemata.getParametersSchemaUri());
+ System.out.format(
+ "\t\t\tPrediction Schema Uri: %s\n", predictSchemata.getPredictionSchemaUri());
+
+ for (ExportFormat exportFormat : modelResponse.getSupportedExportFormatsList()) {
+ System.out.println("\t\tSupported Export Format");
+ System.out.format("\t\t\tId: %s\n", exportFormat.getId());
+ }
+
+ ModelContainerSpec modelContainerSpec = modelResponse.getContainerSpec();
+ System.out.println("\t\tContainer Spec");
+ System.out.format("\t\t\tImage Uri: %s\n", modelContainerSpec.getImageUri());
+ System.out.format("\t\t\tCommand: %s\n", modelContainerSpec.getCommandList());
+ System.out.format("\t\t\tArgs: %s\n", modelContainerSpec.getArgsList());
+ System.out.format("\t\t\tPredict Route: %s\n", modelContainerSpec.getPredictRoute());
+ System.out.format("\t\t\tHealth Route: %s\n", modelContainerSpec.getHealthRoute());
+
+ for (EnvVar envVar : modelContainerSpec.getEnvList()) {
+ System.out.println("\t\t\tEnv");
+ System.out.format("\t\t\t\tName: %s\n", envVar.getName());
+ System.out.format("\t\t\t\tValue: %s\n", envVar.getValue());
+ }
+
+ for (Port port : modelContainerSpec.getPortsList()) {
+ System.out.println("\t\t\tPort");
+ System.out.format("\t\t\t\tContainer Port: %s\n", port.getContainerPort());
+ }
+
+ for (DeployedModelRef deployedModelRef : modelResponse.getDeployedModelsList()) {
+ System.out.println("\t\tDeployed Model");
+ System.out.format("\t\t\tEndpoint: %s\n", deployedModelRef.getEndpoint());
+ System.out.format("\t\t\tDeployed Model Id: %s\n", deployedModelRef.getDeployedModelId());
+ }
+
+ ExplanationSpec explanationSpec = modelResponse.getExplanationSpec();
+ System.out.println("\t\tExplanation Spec");
+
+ ExplanationParameters explanationParameters = explanationSpec.getParameters();
+ System.out.println("\t\t\tParameters");
+
+ SampledShapleyAttribution sampledShapleyAttribution =
+ explanationParameters.getSampledShapleyAttribution();
+ System.out.println("\t\t\t\tSampled Shapley Attribution");
+ System.out.format("\t\t\t\t\tPath Count: %s\n", sampledShapleyAttribution.getPathCount());
+
+ ExplanationMetadata explanationMetadata = explanationSpec.getMetadata();
+ System.out.println("\t\t\tMetadata");
+ System.out.format("\t\t\t\tInputs: %s\n", explanationMetadata.getInputsMap());
+ System.out.format("\t\t\t\tOutputs: %s\n", explanationMetadata.getOutputsMap());
+ System.out.format(
+ "\t\t\t\tFeature Attributions Schema_uri: %s\n",
+ explanationMetadata.getFeatureAttributionsSchemaUri());
+
+ Status status = trainingPipelineResponse.getError();
+ System.out.println("\tError");
+ System.out.format("\t\tCode: %s\n", status.getCode());
+ System.out.format("\t\tMessage: %s\n", status.getMessage());
+ }
+ }
+}
+// [END aiplatform_create_training_pipeline_text_entity_extraction_sample]
diff --git a/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineTextSentimentAnalysisSample.java b/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineTextSentimentAnalysisSample.java
new file mode 100644
index 00000000000..a6139405d31
--- /dev/null
+++ b/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineTextSentimentAnalysisSample.java
@@ -0,0 +1,238 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package aiplatform;
+
+// [START aiplatform_create_training_pipeline_text_sentiment_analysis_sample]
+
+import com.google.cloud.aiplatform.v1beta1.DeployedModelRef;
+import com.google.cloud.aiplatform.v1beta1.EnvVar;
+import com.google.cloud.aiplatform.v1beta1.ExplanationMetadata;
+import com.google.cloud.aiplatform.v1beta1.ExplanationParameters;
+import com.google.cloud.aiplatform.v1beta1.ExplanationSpec;
+import com.google.cloud.aiplatform.v1beta1.FilterSplit;
+import com.google.cloud.aiplatform.v1beta1.FractionSplit;
+import com.google.cloud.aiplatform.v1beta1.InputDataConfig;
+import com.google.cloud.aiplatform.v1beta1.LocationName;
+import com.google.cloud.aiplatform.v1beta1.Model;
+import com.google.cloud.aiplatform.v1beta1.Model.ExportFormat;
+import com.google.cloud.aiplatform.v1beta1.ModelContainerSpec;
+import com.google.cloud.aiplatform.v1beta1.PipelineServiceClient;
+import com.google.cloud.aiplatform.v1beta1.PipelineServiceSettings;
+import com.google.cloud.aiplatform.v1beta1.Port;
+import com.google.cloud.aiplatform.v1beta1.PredefinedSplit;
+import com.google.cloud.aiplatform.v1beta1.PredictSchemata;
+import com.google.cloud.aiplatform.v1beta1.SampledShapleyAttribution;
+import com.google.cloud.aiplatform.v1beta1.TimestampSplit;
+import com.google.cloud.aiplatform.v1beta1.TrainingPipeline;
+import com.google.protobuf.Any;
+import com.google.protobuf.Value;
+import com.google.protobuf.util.JsonFormat;
+import com.google.rpc.Status;
+import java.io.IOException;
+import java.util.List;
+
+public class CreateTrainingPipelineTextSentimentAnalysisSample {
+
+ public static void main(String[] args) throws IOException {
+ // TODO(developer): Replace these variables before running the sample.
+ String trainingPipelineDisplayName = "YOUR_TRAINING_PIPELINE_DISPLAY_NAME";
+ String project = "YOUR_PROJECT_ID";
+ String datasetId = "YOUR_DATASET_ID";
+ String modelDisplayName = "YOUR_MODEL_DISPLAY_NAME";
+
+ createTrainingPipelineTextSentimentAnalysisSample(
+ project, trainingPipelineDisplayName, datasetId, modelDisplayName);
+ }
+
+ static void createTrainingPipelineTextSentimentAnalysisSample(
+ String project, String trainingPipelineDisplayName, String datasetId, String modelDisplayName)
+ throws IOException {
+ PipelineServiceSettings pipelineServiceSettings =
+ PipelineServiceSettings.newBuilder()
+ .setEndpoint("us-central1-aiplatform.googleapis.com:443")
+ .build();
+
+ // Initialize client that will be used to send requests. This client only needs to be created
+ // once, and can be reused for multiple requests. After completing all of your requests, call
+ // the "close" method on the client to safely clean up any remaining background resources.
+ try (PipelineServiceClient pipelineServiceClient =
+ PipelineServiceClient.create(pipelineServiceSettings)) {
+ String location = "us-central1";
+ String trainingTaskDefinition =
+ "gs://google-cloud-aiplatform/schema/trainingjob/definition/"
+ + "automl_text_sentiment_1.0.0.yaml";
+
+ // Sentiment max must be between 1 and 10 inclusive.
+ // Higher value means positive sentiment.
+ String jsonString = "{\"sentimentMax\": 4 }";
+
+ LocationName locationName = LocationName.of(project, location);
+
+ Value.Builder trainingTaskInputs = Value.newBuilder();
+ JsonFormat.parser().merge(jsonString, trainingTaskInputs);
+
+ InputDataConfig trainingInputDataConfig =
+ InputDataConfig.newBuilder().setDatasetId(datasetId).build();
+ Model model = Model.newBuilder().setDisplayName(modelDisplayName).build();
+ TrainingPipeline trainingPipeline =
+ TrainingPipeline.newBuilder()
+ .setDisplayName(trainingPipelineDisplayName)
+ .setTrainingTaskDefinition(trainingTaskDefinition)
+ .setTrainingTaskInputs(trainingTaskInputs)
+ .setInputDataConfig(trainingInputDataConfig)
+ .setModelToUpload(model)
+ .build();
+
+ TrainingPipeline trainingPipelineResponse =
+ pipelineServiceClient.createTrainingPipeline(locationName, trainingPipeline);
+
+ System.out.println("Create Training Pipeline Text Sentiment Analysis Response");
+ System.out.format("\tName: %s\n", trainingPipelineResponse.getName());
+ System.out.format("\tDisplay Name: %s\n", trainingPipelineResponse.getDisplayName());
+
+ System.out.format(
+ "\tTraining Task Definition %s\n", trainingPipelineResponse.getTrainingTaskDefinition());
+ System.out.format(
+ "\tTraining Task Inputs: %s\n", trainingPipelineResponse.getTrainingTaskInputs());
+ System.out.format(
+ "\tTraining Task Metadata: %s\n", trainingPipelineResponse.getTrainingTaskMetadata());
+ System.out.format("State: %s\n", trainingPipelineResponse.getState());
+
+ System.out.format("\tCreate Time: %s\n", trainingPipelineResponse.getCreateTime());
+ System.out.format("\tStartTime %s\n", trainingPipelineResponse.getStartTime());
+ System.out.format("\tEnd Time: %s\n", trainingPipelineResponse.getEndTime());
+ System.out.format("\tUpdate Time: %s\n", trainingPipelineResponse.getUpdateTime());
+ System.out.format("\tLabels: %s\n", trainingPipelineResponse.getLabelsMap());
+
+ InputDataConfig inputDataConfig = trainingPipelineResponse.getInputDataConfig();
+ System.out.println("\tInput Data Config");
+ System.out.format("\t\tDataset Id: %s", inputDataConfig.getDatasetId());
+ System.out.format("\t\tAnnotations Filter: %s\n", inputDataConfig.getAnnotationsFilter());
+
+ FractionSplit fractionSplit = inputDataConfig.getFractionSplit();
+ System.out.println("\t\tFraction Split");
+ System.out.format("\t\t\tTraining Fraction: %s\n", fractionSplit.getTrainingFraction());
+ System.out.format("\t\t\tValidation Fraction: %s\n", fractionSplit.getValidationFraction());
+ System.out.format("\t\t\tTest Fraction: %s\n", fractionSplit.getTestFraction());
+
+ FilterSplit filterSplit = inputDataConfig.getFilterSplit();
+ System.out.println("\t\tFilter Split");
+ System.out.format("\t\t\tTraining Filter: %s\n", filterSplit.getTrainingFilter());
+ System.out.format("\t\t\tValidation Filter: %s\n", filterSplit.getValidationFilter());
+ System.out.format("\t\t\tTest Filter: %s\n", filterSplit.getTestFilter());
+
+ PredefinedSplit predefinedSplit = inputDataConfig.getPredefinedSplit();
+ System.out.println("\t\tPredefined Split");
+ System.out.format("\t\t\tKey: %s\n", predefinedSplit.getKey());
+
+ TimestampSplit timestampSplit = inputDataConfig.getTimestampSplit();
+ System.out.println("\t\tTimestamp Split");
+ System.out.format("\t\t\tTraining Fraction: %s\n", timestampSplit.getTrainingFraction());
+ System.out.format("\t\t\tValidation Fraction: %s\n", timestampSplit.getValidationFraction());
+ System.out.format("\t\t\tTest Fraction: %s\n", timestampSplit.getTestFraction());
+ System.out.format("\t\t\tKey: %s\n", timestampSplit.getKey());
+
+ Model modelResponse = trainingPipelineResponse.getModelToUpload();
+ System.out.println("\tModel To Upload");
+ System.out.format("\t\tName: %s\n", modelResponse.getName());
+ System.out.format("\t\tDisplay Name: %s\n", modelResponse.getDisplayName());
+ System.out.format("\t\tDescription: %s\n", modelResponse.getDescription());
+
+ System.out.format("\t\tMetadata Schema Uri: %s\n", modelResponse.getMetadataSchemaUri());
+ System.out.format("\t\tMetadata: %s\n", modelResponse.getMetadata());
+ System.out.format("\t\tTraining Pipeline: %s\n", modelResponse.getTrainingPipeline());
+ System.out.format("\t\tArtifact Uri: %s\n", modelResponse.getArtifactUri());
+
+ System.out.format(
+ "\t\tSupported Deployment Resources Types: %s\n",
+ modelResponse.getSupportedDeploymentResourcesTypesList());
+ System.out.format(
+ "\t\tSupported Input Storage Formats: %s\n",
+ modelResponse.getSupportedInputStorageFormatsList());
+ System.out.format(
+ "\t\tSupported Output Storage Formats: %s\n",
+ modelResponse.getSupportedOutputStorageFormatsList());
+
+ System.out.format("\t\tCreate Time: %s\n", modelResponse.getCreateTime());
+ System.out.format("\t\tUpdate Time: %s\n", modelResponse.getUpdateTime());
+ System.out.format("\t\tLabels: %sn\n", modelResponse.getLabelsMap());
+
+ PredictSchemata predictSchemata = modelResponse.getPredictSchemata();
+ System.out.println("\t\tPredict Schemata");
+ System.out.format("\t\t\tInstance Schema Uri: %s\n", predictSchemata.getInstanceSchemaUri());
+ System.out.format(
+ "\t\t\tParameters Schema Uri: %s\n", predictSchemata.getParametersSchemaUri());
+ System.out.format(
+ "\t\t\tPrediction Schema Uri: %s\n", predictSchemata.getPredictionSchemaUri());
+
+ for (ExportFormat exportFormat : modelResponse.getSupportedExportFormatsList()) {
+ System.out.println("\t\tSupported Export Format");
+ System.out.format("\t\t\tId: %s\n", exportFormat.getId());
+ }
+
+ ModelContainerSpec modelContainerSpec = modelResponse.getContainerSpec();
+ System.out.println("\t\tContainer Spec");
+ System.out.format("\t\t\tImage Uri: %s\n", modelContainerSpec.getImageUri());
+ System.out.format("\t\t\tCommand: %s\n", modelContainerSpec.getCommandList());
+ System.out.format("\t\t\tArgs: %s\n", modelContainerSpec.getArgsList());
+ System.out.format("\t\t\tPredict Route: %s\n", modelContainerSpec.getPredictRoute());
+ System.out.format("\t\t\tHealth Route: %s\n", modelContainerSpec.getHealthRoute());
+
+ for (EnvVar envVar : modelContainerSpec.getEnvList()) {
+ System.out.println("\t\t\tEnv");
+ System.out.format("\t\t\t\tName: %s\n", envVar.getName());
+ System.out.format("\t\t\t\tValue: %s\n", envVar.getValue());
+ }
+
+ for (Port port : modelContainerSpec.getPortsList()) {
+ System.out.println("\t\t\tPort");
+ System.out.format("\t\t\t\tContainer Port: %s\n", port.getContainerPort());
+ }
+
+ for (DeployedModelRef deployedModelRef : modelResponse.getDeployedModelsList()) {
+ System.out.println("\t\tDeployed Model");
+ System.out.format("\t\t\tEndpoint: %s\n", deployedModelRef.getEndpoint());
+ System.out.format("\t\t\tDeployed Model Id: %s\n", deployedModelRef.getDeployedModelId());
+ }
+
+ ExplanationSpec explanationSpec = modelResponse.getExplanationSpec();
+ System.out.println("\t\tExplanation Spec");
+
+ ExplanationParameters explanationParameters = explanationSpec.getParameters();
+ System.out.println("\t\t\tParameters");
+
+ SampledShapleyAttribution sampledShapleyAttribution =
+ explanationParameters.getSampledShapleyAttribution();
+ System.out.println("\t\t\t\tSampled Shapley Attribution");
+ System.out.format("\t\t\t\t\tPath Count: %s\n", sampledShapleyAttribution.getPathCount());
+
+ ExplanationMetadata explanationMetadata = explanationSpec.getMetadata();
+ System.out.println("\t\t\tMetadata");
+ System.out.format("\t\t\t\tInputs: %s\n", explanationMetadata.getInputsMap());
+ System.out.format("\t\t\t\tOutputs: %s\n", explanationMetadata.getOutputsMap());
+ System.out.format(
+ "\t\t\t\tFeature Attributions Schema_uri: %s\n",
+ explanationMetadata.getFeatureAttributionsSchemaUri());
+
+ Status status = trainingPipelineResponse.getError();
+ System.out.println("\tError");
+ System.out.format("\t\tCode: %s\n", status.getCode());
+ System.out.format("\t\tMessage: %s\n", status.getMessage());
+ }
+ }
+}
+// [END aiplatform_create_training_pipeline_text_sentiment_analysis_sample]
diff --git a/aiplatform/snippets/src/main/java/aiplatform/GetModelEvaluationTextClassificationSample.java b/aiplatform/snippets/src/main/java/aiplatform/GetModelEvaluationTextClassificationSample.java
new file mode 100644
index 00000000000..68820041ebe
--- /dev/null
+++ b/aiplatform/snippets/src/main/java/aiplatform/GetModelEvaluationTextClassificationSample.java
@@ -0,0 +1,81 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package aiplatform;
+
+// [START aiplatform_get_model_evaluation_text_classification_sample]
+
+import com.google.cloud.aiplatform.v1beta1.Attribution;
+import com.google.cloud.aiplatform.v1beta1.ModelEvaluation;
+import com.google.cloud.aiplatform.v1beta1.ModelEvaluationName;
+import com.google.cloud.aiplatform.v1beta1.ModelExplanation;
+import com.google.cloud.aiplatform.v1beta1.ModelServiceClient;
+import com.google.cloud.aiplatform.v1beta1.ModelServiceSettings;
+import java.io.IOException;
+
+public class GetModelEvaluationTextClassificationSample {
+
+ public static void main(String[] args) throws IOException {
+ // TODO(developer): Replace these variables before running the sample.
+ String project = "YOUR_PROJECT_ID";
+ String modelId = "YOUR_MODEL_ID";
+ String evaluationId = "YOUR_EVALUATION_ID";
+
+ getModelEvaluationTextClassificationSample(project, modelId, evaluationId);
+ }
+
+ static void getModelEvaluationTextClassificationSample(
+ String project, String modelId, String evaluationId) throws IOException {
+ ModelServiceSettings modelServiceSettings =
+ ModelServiceSettings.newBuilder()
+ .setEndpoint("us-central1-aiplatform.googleapis.com:443")
+ .build();
+
+ // Initialize client that will be used to send requests. This client only needs to be created
+ // once, and can be reused for multiple requests. After completing all of your requests, call
+ // the "close" method on the client to safely clean up any remaining background resources.
+ try (ModelServiceClient modelServiceClient = ModelServiceClient.create(modelServiceSettings)) {
+ String location = "us-central1";
+
+ ModelEvaluationName modelEvaluationName =
+ ModelEvaluationName.of(project, location, modelId, evaluationId);
+ ModelEvaluation modelEvaluation = modelServiceClient.getModelEvaluation(modelEvaluationName);
+
+ System.out.println("Get Model Evaluation Text Classification Response");
+ System.out.format("\tModel Name: %s\n", modelEvaluation.getName());
+ System.out.format("\tMetrics Schema Uri: %s\n", modelEvaluation.getMetricsSchemaUri());
+ System.out.format("\tMetrics: %s\n", modelEvaluation.getMetrics());
+ System.out.format("\tCreate Time: %s\n", modelEvaluation.getCreateTime());
+ System.out.format("\tSlice Dimensions: %s\n", modelEvaluation.getSliceDimensionsList());
+
+ System.out.println("\tModel Explanation");
+ ModelExplanation modelExplanation = modelEvaluation.getModelExplanation();
+ for (Attribution attribution : modelExplanation.getMeanAttributionsList()) {
+
+ System.out.println("\t\tMean Attribution");
+ System.out.format(
+ "\t\t\tBaseline Output Value: %s\n", attribution.getBaselineOutputValue());
+ System.out.format(
+ "\t\t\tInstance Output Value: %s\n", attribution.getInstanceOutputValue());
+ System.out.format("\t\t\tFeature Attributions: %s\n", attribution.getFeatureAttributions());
+ System.out.format("\t\t\tOutput Index: %s\n", attribution.getOutputIndexList());
+ System.out.format("\t\t\tOutput Display Name: %s\n", attribution.getOutputDisplayName());
+ System.out.format("\t\t\tApproximation Error: %s\n", attribution.getApproximationError());
+ }
+ }
+ }
+}
+// [END aiplatform_get_model_evaluation_text_classification_sample]
diff --git a/aiplatform/snippets/src/main/java/aiplatform/GetModelEvaluationTextEntityExtractionSample.java b/aiplatform/snippets/src/main/java/aiplatform/GetModelEvaluationTextEntityExtractionSample.java
new file mode 100644
index 00000000000..d85f9d3f3af
--- /dev/null
+++ b/aiplatform/snippets/src/main/java/aiplatform/GetModelEvaluationTextEntityExtractionSample.java
@@ -0,0 +1,81 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package aiplatform;
+
+// [START aiplatform_get_model_evaluation_text_entity_extraction_sample]
+
+import com.google.cloud.aiplatform.v1beta1.Attribution;
+import com.google.cloud.aiplatform.v1beta1.ModelEvaluation;
+import com.google.cloud.aiplatform.v1beta1.ModelEvaluationName;
+import com.google.cloud.aiplatform.v1beta1.ModelExplanation;
+import com.google.cloud.aiplatform.v1beta1.ModelServiceClient;
+import com.google.cloud.aiplatform.v1beta1.ModelServiceSettings;
+import java.io.IOException;
+
+public class GetModelEvaluationTextEntityExtractionSample {
+
+ public static void main(String[] args) throws IOException {
+ // TODO(developer): Replace these variables before running the sample.
+ String project = "YOUR_PROJECT_ID";
+ String modelId = "YOUR_MODEL_ID";
+ String evaluationId = "YOUR_EVALUATION_ID";
+
+ getModelEvaluationTextEntityExtractionSample(project, modelId, evaluationId);
+ }
+
+ static void getModelEvaluationTextEntityExtractionSample(
+ String project, String modelId, String evaluationId) throws IOException {
+ ModelServiceSettings modelServiceSettings =
+ ModelServiceSettings.newBuilder()
+ .setEndpoint("us-central1-aiplatform.googleapis.com:443")
+ .build();
+
+ // Initialize client that will be used to send requests. This client only needs to be created
+ // once, and can be reused for multiple requests. After completing all of your requests, call
+ // the "close" method on the client to safely clean up any remaining background resources.
+ try (ModelServiceClient modelServiceClient = ModelServiceClient.create(modelServiceSettings)) {
+ String location = "us-central1";
+
+ ModelEvaluationName modelEvaluationName =
+ ModelEvaluationName.of(project, location, modelId, evaluationId);
+ ModelEvaluation modelEvaluation = modelServiceClient.getModelEvaluation(modelEvaluationName);
+
+ System.out.println("Get Model Evaluation Text Entity Extraction Response");
+ System.out.format("\tModel Name: %s\n", modelEvaluation.getName());
+ System.out.format("\tMetrics Schema Uri: %s\n", modelEvaluation.getMetricsSchemaUri());
+ System.out.format("\tMetrics: %s\n", modelEvaluation.getMetrics());
+ System.out.format("\tCreate Time: %s\n", modelEvaluation.getCreateTime());
+ System.out.format("\tSlice Dimensions: %s\n", modelEvaluation.getSliceDimensionsList());
+
+ System.out.println("\tModel Explanation");
+ ModelExplanation modelExplanation = modelEvaluation.getModelExplanation();
+ for (Attribution attribution : modelExplanation.getMeanAttributionsList()) {
+
+ System.out.println("\t\tMean Attribution");
+ System.out.format(
+ "\t\t\tBaseline Output Value: %s\n", attribution.getBaselineOutputValue());
+ System.out.format(
+ "\t\t\tInstance Output Value: %s\n", attribution.getInstanceOutputValue());
+ System.out.format("\t\t\tFeature Attributions: %s\n", attribution.getFeatureAttributions());
+ System.out.format("\t\t\tOutput Index: %s\n", attribution.getOutputIndexList());
+ System.out.format("\t\t\tOutput Display Name: %s\n", attribution.getOutputDisplayName());
+ System.out.format("\t\t\tApproximation Error: %s\n", attribution.getApproximationError());
+ }
+ }
+ }
+}
+// [END aiplatform_get_model_evaluation_text_entity_extraction_sample]
diff --git a/aiplatform/snippets/src/main/java/aiplatform/GetModelEvaluationTextSentimentAnalysisSample.java b/aiplatform/snippets/src/main/java/aiplatform/GetModelEvaluationTextSentimentAnalysisSample.java
new file mode 100644
index 00000000000..0ccd5286898
--- /dev/null
+++ b/aiplatform/snippets/src/main/java/aiplatform/GetModelEvaluationTextSentimentAnalysisSample.java
@@ -0,0 +1,81 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package aiplatform;
+
+// [START aiplatform_get_model_evaluation_text_sentiment_analysis_sample]
+
+import com.google.cloud.aiplatform.v1beta1.Attribution;
+import com.google.cloud.aiplatform.v1beta1.ModelEvaluation;
+import com.google.cloud.aiplatform.v1beta1.ModelEvaluationName;
+import com.google.cloud.aiplatform.v1beta1.ModelExplanation;
+import com.google.cloud.aiplatform.v1beta1.ModelServiceClient;
+import com.google.cloud.aiplatform.v1beta1.ModelServiceSettings;
+import java.io.IOException;
+
+public class GetModelEvaluationTextSentimentAnalysisSample {
+
+ public static void main(String[] args) throws IOException {
+ // TODO(developer): Replace these variables before running the sample.
+ String project = "YOUR_PROJECT_ID";
+ String modelId = "YOUR_MODEL_ID";
+ String evaluationId = "YOUR_EVALUATION_ID";
+
+ getModelEvaluationTextSentimentAnalysisSample(project, modelId, evaluationId);
+ }
+
+ static void getModelEvaluationTextSentimentAnalysisSample(
+ String project, String modelId, String evaluationId) throws IOException {
+ ModelServiceSettings modelServiceSettings =
+ ModelServiceSettings.newBuilder()
+ .setEndpoint("us-central1-aiplatform.googleapis.com:443")
+ .build();
+
+ // Initialize client that will be used to send requests. This client only needs to be created
+ // once, and can be reused for multiple requests. After completing all of your requests, call
+ // the "close" method on the client to safely clean up any remaining background resources.
+ try (ModelServiceClient modelServiceClient = ModelServiceClient.create(modelServiceSettings)) {
+ String location = "us-central1";
+
+ ModelEvaluationName modelEvaluationName =
+ ModelEvaluationName.of(project, location, modelId, evaluationId);
+ ModelEvaluation modelEvaluation = modelServiceClient.getModelEvaluation(modelEvaluationName);
+
+ System.out.println("Get Model Evaluation Text Sentiment Analysis Response");
+ System.out.format("\tModel Name: %s\n", modelEvaluation.getName());
+ System.out.format("\tMetrics Schema Uri: %s\n", modelEvaluation.getMetricsSchemaUri());
+ System.out.format("\tMetrics: %s\n", modelEvaluation.getMetrics());
+ System.out.format("\tCreate Time: %s\n", modelEvaluation.getCreateTime());
+ System.out.format("\tSlice Dimensions: %s\n", modelEvaluation.getSliceDimensionsList());
+
+ System.out.println("\tModel Explanation");
+ ModelExplanation modelExplanation = modelEvaluation.getModelExplanation();
+ for (Attribution attribution : modelExplanation.getMeanAttributionsList()) {
+
+ System.out.println("\t\tMean Attribution");
+ System.out.format(
+ "\t\t\tBaseline Output Value: %s\n", attribution.getBaselineOutputValue());
+ System.out.format(
+ "\t\t\tInstance Output Value: %s\n", attribution.getInstanceOutputValue());
+ System.out.format("\t\t\tFeature Attributions: %s\n", attribution.getFeatureAttributions());
+ System.out.format("\t\t\tOutput Index: %s\n", attribution.getOutputIndexList());
+ System.out.format("\t\t\tOutput Display Name: %s\n", attribution.getOutputDisplayName());
+ System.out.format("\t\t\tApproximation Error: %s\n", attribution.getApproximationError());
+ }
+ }
+ }
+}
+// [END aiplatform_get_model_evaluation_text_sentiment_analysis_sample]
diff --git a/aiplatform/snippets/src/main/java/aiplatform/ImportDataTextClassificationSingleLabelSample.java b/aiplatform/snippets/src/main/java/aiplatform/ImportDataTextClassificationSingleLabelSample.java
new file mode 100644
index 00000000000..cf08bd38ae9
--- /dev/null
+++ b/aiplatform/snippets/src/main/java/aiplatform/ImportDataTextClassificationSingleLabelSample.java
@@ -0,0 +1,90 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package aiplatform;
+
+// [START aiplatform_import_data_text_classification_single_label_sample]
+
+import com.google.api.gax.longrunning.OperationFuture;
+import com.google.cloud.aiplatform.v1beta1.DatasetName;
+import com.google.cloud.aiplatform.v1beta1.DatasetServiceClient;
+import com.google.cloud.aiplatform.v1beta1.DatasetServiceSettings;
+import com.google.cloud.aiplatform.v1beta1.GcsSource;
+import com.google.cloud.aiplatform.v1beta1.ImportDataConfig;
+import com.google.cloud.aiplatform.v1beta1.ImportDataOperationMetadata;
+import com.google.cloud.aiplatform.v1beta1.ImportDataResponse;
+import java.io.IOException;
+import java.util.Collections;
+import java.util.List;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
+
+public class ImportDataTextClassificationSingleLabelSample {
+
+ public static void main(String[] args)
+ throws IOException, InterruptedException, ExecutionException, TimeoutException {
+ // TODO(developer): Replace these variables before running the sample.
+ String project = "YOUR_PROJECT_ID";
+ String datasetId = "YOUR_DATASET_ID";
+ String gcsSourceUri =
+ "gs://YOUR_GCS_SOURCE_BUCKET/path_to_your_text_source/[file.csv/file.jsonl]";
+
+ importDataTextClassificationSingleLabelSample(project, datasetId, gcsSourceUri);
+ }
+
+ static void importDataTextClassificationSingleLabelSample(
+ String project, String datasetId, String gcsSourceUri)
+ throws IOException, InterruptedException, ExecutionException, TimeoutException {
+ DatasetServiceSettings datasetServiceSettings =
+ DatasetServiceSettings.newBuilder()
+ .setEndpoint("us-central1-aiplatform.googleapis.com:443")
+ .build();
+
+ // Initialize client that will be used to send requests. This client only needs to be created
+ // once, and can be reused for multiple requests. After completing all of your requests, call
+ // the "close" method on the client to safely clean up any remaining background resources.
+ try (DatasetServiceClient datasetServiceClient =
+ DatasetServiceClient.create(datasetServiceSettings)) {
+ String location = "us-central1";
+ String importSchemaUri =
+ "gs://google-cloud-aiplatform/schema/dataset/ioformat/"
+ + "text_classification_single_label_io_format_1.0.0.yaml";
+
+ GcsSource.Builder gcsSource = GcsSource.newBuilder();
+ gcsSource.addUris(gcsSourceUri);
+ DatasetName datasetName = DatasetName.of(project, location, datasetId);
+
+ List importDataConfigList =
+ Collections.singletonList(
+ ImportDataConfig.newBuilder()
+ .setGcsSource(gcsSource)
+ .setImportSchemaUri(importSchemaUri)
+ .build());
+
+ OperationFuture importDataResponseFuture =
+ datasetServiceClient.importDataAsync(datasetName, importDataConfigList);
+ System.out.format(
+ "Operation name: %s\n", importDataResponseFuture.getInitialFuture().get().getName());
+
+ System.out.println("Waiting for operation to finish...");
+ ImportDataResponse importDataResponse = importDataResponseFuture.get(300, TimeUnit.SECONDS);
+ System.out.format("Import Data Text Classification Response: %s\n",
+ importDataResponse.toString());
+ }
+ }
+}
+// [END aiplatform_import_data_text_classification_single_label_sample]
diff --git a/aiplatform/snippets/src/main/java/aiplatform/ImportDataTextEntityExtractionSample.java b/aiplatform/snippets/src/main/java/aiplatform/ImportDataTextEntityExtractionSample.java
new file mode 100644
index 00000000000..6bd5c4f0297
--- /dev/null
+++ b/aiplatform/snippets/src/main/java/aiplatform/ImportDataTextEntityExtractionSample.java
@@ -0,0 +1,89 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package aiplatform;
+
+// [START aiplatform_import_data_text_entity_extraction_sample]
+
+import com.google.api.gax.longrunning.OperationFuture;
+import com.google.cloud.aiplatform.v1beta1.DatasetName;
+import com.google.cloud.aiplatform.v1beta1.DatasetServiceClient;
+import com.google.cloud.aiplatform.v1beta1.DatasetServiceSettings;
+import com.google.cloud.aiplatform.v1beta1.GcsSource;
+import com.google.cloud.aiplatform.v1beta1.ImportDataConfig;
+import com.google.cloud.aiplatform.v1beta1.ImportDataOperationMetadata;
+import com.google.cloud.aiplatform.v1beta1.ImportDataResponse;
+import java.io.IOException;
+import java.util.Collections;
+import java.util.List;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
+
+public class ImportDataTextEntityExtractionSample {
+
+ public static void main(String[] args)
+ throws IOException, InterruptedException, ExecutionException, TimeoutException {
+ // TODO(developer): Replace these variables before running the sample.
+ String project = "YOUR_PROJECT_ID";
+ String datasetId = "YOUR_DATASET_ID";
+ String gcsSourceUri = "gs://YOUR_GCS_SOURCE_BUCKET/path_to_your_text_source/[file.jsonl]";
+
+ importDataTextEntityExtractionSample(project, datasetId, gcsSourceUri);
+ }
+
+ static void importDataTextEntityExtractionSample(
+ String project, String datasetId, String gcsSourceUri)
+ throws IOException, InterruptedException, ExecutionException, TimeoutException {
+ DatasetServiceSettings datasetServiceSettings =
+ DatasetServiceSettings.newBuilder()
+ .setEndpoint("us-central1-aiplatform.googleapis.com:443")
+ .build();
+
+ // Initialize client that will be used to send requests. This client only needs to be created
+ // once, and can be reused for multiple requests. After completing all of your requests, call
+ // the "close" method on the client to safely clean up any remaining background resources.
+ try (DatasetServiceClient datasetServiceClient =
+ DatasetServiceClient.create(datasetServiceSettings)) {
+ String location = "us-central1";
+ String importSchemaUri =
+ "gs://google-cloud-aiplatform/schema/dataset/ioformat/"
+ + "text_extraction_io_format_1.0.0.yaml";
+
+ GcsSource.Builder gcsSource = GcsSource.newBuilder();
+ gcsSource.addUris(gcsSourceUri);
+ DatasetName datasetName = DatasetName.of(project, location, datasetId);
+
+ List importDataConfigList =
+ Collections.singletonList(
+ ImportDataConfig.newBuilder()
+ .setGcsSource(gcsSource)
+ .setImportSchemaUri(importSchemaUri)
+ .build());
+
+ OperationFuture importDataResponseFuture =
+ datasetServiceClient.importDataAsync(datasetName, importDataConfigList);
+ System.out.format(
+ "Operation name: %s\n", importDataResponseFuture.getInitialFuture().get().getName());
+
+ System.out.println("Waiting for operation to finish...");
+ ImportDataResponse importDataResponse = importDataResponseFuture.get(300, TimeUnit.SECONDS);
+ System.out.format("Import Data Text Entity Extraction Response: %s\n",
+ importDataResponse.toString());
+ }
+ }
+}
+// [END aiplatform_import_data_text_entity_extraction_sample]
diff --git a/aiplatform/snippets/src/main/java/aiplatform/ImportDataTextSentimentAnalysisSample.java b/aiplatform/snippets/src/main/java/aiplatform/ImportDataTextSentimentAnalysisSample.java
new file mode 100644
index 00000000000..e1750678392
--- /dev/null
+++ b/aiplatform/snippets/src/main/java/aiplatform/ImportDataTextSentimentAnalysisSample.java
@@ -0,0 +1,90 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package aiplatform;
+
+// [START aiplatform_import_data_text_sentiment_analysis_sample]
+
+import com.google.api.gax.longrunning.OperationFuture;
+import com.google.cloud.aiplatform.v1beta1.DatasetName;
+import com.google.cloud.aiplatform.v1beta1.DatasetServiceClient;
+import com.google.cloud.aiplatform.v1beta1.DatasetServiceSettings;
+import com.google.cloud.aiplatform.v1beta1.GcsSource;
+import com.google.cloud.aiplatform.v1beta1.ImportDataConfig;
+import com.google.cloud.aiplatform.v1beta1.ImportDataOperationMetadata;
+import com.google.cloud.aiplatform.v1beta1.ImportDataResponse;
+import java.io.IOException;
+import java.util.Collections;
+import java.util.List;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
+
+public class ImportDataTextSentimentAnalysisSample {
+
+ public static void main(String[] args)
+ throws IOException, InterruptedException, ExecutionException, TimeoutException {
+ // TODO(developer): Replace these variables before running the sample.
+ String project = "YOUR_PROJECT_ID";
+ String datasetId = "YOUR_DATASET_ID";
+ String gcsSourceUri =
+ "gs://YOUR_GCS_SOURCE_BUCKET/path_to_your_text_source/[file.csv/file.jsonl]";
+
+ importDataTextSentimentAnalysisSample(project, datasetId, gcsSourceUri);
+ }
+
+ static void importDataTextSentimentAnalysisSample(
+ String project, String datasetId, String gcsSourceUri)
+ throws IOException, InterruptedException, ExecutionException, TimeoutException {
+ DatasetServiceSettings datasetServiceSettings =
+ DatasetServiceSettings.newBuilder()
+ .setEndpoint("us-central1-aiplatform.googleapis.com:443")
+ .build();
+
+ // Initialize client that will be used to send requests. This client only needs to be created
+ // once, and can be reused for multiple requests. After completing all of your requests, call
+ // the "close" method on the client to safely clean up any remaining background resources.
+ try (DatasetServiceClient datasetServiceClient =
+ DatasetServiceClient.create(datasetServiceSettings)) {
+ String location = "us-central1";
+ String importSchemaUri =
+ "gs://google-cloud-aiplatform/schema/dataset/ioformat/"
+ + "text_sentiment_io_format_1.0.0.yaml";
+
+ GcsSource.Builder gcsSource = GcsSource.newBuilder();
+ gcsSource.addUris(gcsSourceUri);
+ DatasetName datasetName = DatasetName.of(project, location, datasetId);
+
+ List importDataConfigList =
+ Collections.singletonList(
+ ImportDataConfig.newBuilder()
+ .setGcsSource(gcsSource)
+ .setImportSchemaUri(importSchemaUri)
+ .build());
+
+ OperationFuture importDataResponseFuture =
+ datasetServiceClient.importDataAsync(datasetName, importDataConfigList);
+ System.out.format(
+ "Operation name: %s\n", importDataResponseFuture.getInitialFuture().get().getName());
+
+ System.out.println("Waiting for operation to finish...");
+ ImportDataResponse importDataResponse = importDataResponseFuture.get(300, TimeUnit.SECONDS);
+ System.out.format("Import Data Text Sentiment Analysis Response: %s\n",
+ importDataResponse.toString());
+ }
+ }
+}
+// [END aiplatform_import_data_text_sentiment_analysis_sample]
diff --git a/aiplatform/snippets/src/main/java/aiplatform/PredictTextClassificationSingleLabelSample.java b/aiplatform/snippets/src/main/java/aiplatform/PredictTextClassificationSingleLabelSample.java
new file mode 100644
index 00000000000..84e323f59a5
--- /dev/null
+++ b/aiplatform/snippets/src/main/java/aiplatform/PredictTextClassificationSingleLabelSample.java
@@ -0,0 +1,78 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package aiplatform;
+
+// [START aiplatform_predict_text_classification_sample]
+
+import com.google.cloud.aiplatform.v1beta1.EndpointName;
+import com.google.cloud.aiplatform.v1beta1.PredictResponse;
+import com.google.cloud.aiplatform.v1beta1.PredictionServiceClient;
+import com.google.cloud.aiplatform.v1beta1.PredictionServiceSettings;
+import com.google.protobuf.Value;
+import com.google.protobuf.util.JsonFormat;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+
+public class PredictTextClassificationSingleLabelSample {
+
+ public static void main(String[] args) throws IOException {
+ // TODO(developer): Replace these variables before running the sample.
+ String project = "YOUR_PROJECT_ID";
+ String content = "YOUR_TEXT_CONTENT";
+ String endpointId = "YOUR_ENDPOINT_ID";
+
+ predictTextClassificationSingleLabel(project, content, endpointId);
+ }
+
+ static void predictTextClassificationSingleLabel(
+ String project, String content, String endpointId) throws IOException {
+ PredictionServiceSettings predictionServiceSettings =
+ PredictionServiceSettings.newBuilder()
+ .setEndpoint("us-central1-prediction-aiplatform.googleapis.com:443")
+ .build();
+
+ // Initialize client that will be used to send requests. This client only needs to be created
+ // once, and can be reused for multiple requests. After completing all of your requests, call
+ // the "close" method on the client to safely clean up any remaining background resources.
+ try (PredictionServiceClient predictionServiceClient =
+ PredictionServiceClient.create(predictionServiceSettings)) {
+ String location = "us-central1";
+ String jsonString = "{\"content\": \"" + content + "\"}";
+
+ EndpointName endpointName = EndpointName.of(project, location, endpointId);
+
+ Value parameter = Value.newBuilder().setNumberValue(0).setNumberValue(5).build();
+ Value.Builder instance = Value.newBuilder();
+ JsonFormat.parser().merge(jsonString, instance);
+
+ List instances = new ArrayList<>();
+ instances.add(instance.build());
+
+ PredictResponse predictResponse =
+ predictionServiceClient.predict(endpointName, instances, parameter);
+ System.out.println("Predict Text Classification Response");
+ System.out.format("\tDeployed Model Id: %s\n", predictResponse.getDeployedModelId());
+
+ System.out.println("Predictions");
+ for (Value prediction : predictResponse.getPredictionsList()) {
+ System.out.format("\tPrediction: %s\n", prediction);
+ }
+ }
+ }
+}
+// [END aiplatform_predict_text_classification_sample]
diff --git a/aiplatform/snippets/src/main/java/aiplatform/PredictTextEntityExtractionSample.java b/aiplatform/snippets/src/main/java/aiplatform/PredictTextEntityExtractionSample.java
new file mode 100644
index 00000000000..3fb1559ec90
--- /dev/null
+++ b/aiplatform/snippets/src/main/java/aiplatform/PredictTextEntityExtractionSample.java
@@ -0,0 +1,78 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package aiplatform;
+
+// [START aiplatform_predict_text_entity_extraction_sample]
+
+import com.google.cloud.aiplatform.v1beta1.EndpointName;
+import com.google.cloud.aiplatform.v1beta1.PredictResponse;
+import com.google.cloud.aiplatform.v1beta1.PredictionServiceClient;
+import com.google.cloud.aiplatform.v1beta1.PredictionServiceSettings;
+import com.google.protobuf.Value;
+import com.google.protobuf.util.JsonFormat;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+
+public class PredictTextEntityExtractionSample {
+
+ public static void main(String[] args) throws IOException {
+ // TODO(developer): Replace these variables before running the sample.
+ String project = "YOUR_PROJECT_ID";
+ String content = "YOUR_TEXT_CONTENT";
+ String endpointId = "YOUR_ENDPOINT_ID";
+
+ predictTextEntityExtraction(project, content, endpointId);
+ }
+
+ static void predictTextEntityExtraction(String project, String content, String endpointId)
+ throws IOException {
+ PredictionServiceSettings predictionServiceSettings =
+ PredictionServiceSettings.newBuilder()
+ .setEndpoint("us-central1-prediction-aiplatform.googleapis.com:443")
+ .build();
+
+ // Initialize client that will be used to send requests. This client only needs to be created
+ // once, and can be reused for multiple requests. After completing all of your requests, call
+ // the "close" method on the client to safely clean up any remaining background resources.
+ try (PredictionServiceClient predictionServiceClient =
+ PredictionServiceClient.create(predictionServiceSettings)) {
+ String location = "us-central1";
+ String jsonString = "{\"content\": \"" + content + "\"}";
+
+ EndpointName endpointName = EndpointName.of(project, location, endpointId);
+
+ Value parameter = Value.newBuilder().setNumberValue(0).setNumberValue(5).build();
+ Value.Builder instance = Value.newBuilder();
+ JsonFormat.parser().merge(jsonString, instance);
+
+ List instances = new ArrayList<>();
+ instances.add(instance.build());
+
+ PredictResponse predictResponse =
+ predictionServiceClient.predict(endpointName, instances, parameter);
+ System.out.println("Predict Text Entity Extraction Response");
+ System.out.format("\tDeployed Model Id: %s\n", predictResponse.getDeployedModelId());
+
+ System.out.println("Predictions");
+ for (Value prediction : predictResponse.getPredictionsList()) {
+ System.out.format("\tPrediction: %s\n", prediction);
+ }
+ }
+ }
+}
+// [END aiplatform_predict_text_entity_extraction_sample]
diff --git a/aiplatform/snippets/src/main/java/aiplatform/PredictTextSentimentAnalysisSample.java b/aiplatform/snippets/src/main/java/aiplatform/PredictTextSentimentAnalysisSample.java
new file mode 100644
index 00000000000..8ac212a26ef
--- /dev/null
+++ b/aiplatform/snippets/src/main/java/aiplatform/PredictTextSentimentAnalysisSample.java
@@ -0,0 +1,78 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package aiplatform;
+
+// [START aiplatform_predict_text_sentiment_analysis_sample]
+
+import com.google.cloud.aiplatform.v1beta1.EndpointName;
+import com.google.cloud.aiplatform.v1beta1.PredictResponse;
+import com.google.cloud.aiplatform.v1beta1.PredictionServiceClient;
+import com.google.cloud.aiplatform.v1beta1.PredictionServiceSettings;
+import com.google.protobuf.Value;
+import com.google.protobuf.util.JsonFormat;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+
+public class PredictTextSentimentAnalysisSample {
+
+ public static void main(String[] args) throws IOException {
+ // TODO(developer): Replace these variables before running the sample.
+ String project = "YOUR_PROJECT_ID";
+ String content = "YOUR_TEXT_CONTENT";
+ String endpointId = "YOUR_ENDPOINT_ID";
+
+ predictTextSentimentAnalysis(project, content, endpointId);
+ }
+
+ static void predictTextSentimentAnalysis(String project, String content, String endpointId)
+ throws IOException {
+ PredictionServiceSettings predictionServiceSettings =
+ PredictionServiceSettings.newBuilder()
+ .setEndpoint("us-central1-prediction-aiplatform.googleapis.com:443")
+ .build();
+
+ // Initialize client that will be used to send requests. This client only needs to be created
+ // once, and can be reused for multiple requests. After completing all of your requests, call
+ // the "close" method on the client to safely clean up any remaining background resources.
+ try (PredictionServiceClient predictionServiceClient =
+ PredictionServiceClient.create(predictionServiceSettings)) {
+ String location = "us-central1";
+ String jsonString = "{\"content\": \"" + content + "\"}";
+
+ EndpointName endpointName = EndpointName.of(project, location, endpointId);
+
+ Value parameter = Value.newBuilder().setNumberValue(0).setNumberValue(5).build();
+ Value.Builder instance = Value.newBuilder();
+ JsonFormat.parser().merge(jsonString, instance);
+
+ List instances = new ArrayList<>();
+ instances.add(instance.build());
+
+ PredictResponse predictResponse =
+ predictionServiceClient.predict(endpointName, instances, parameter);
+ System.out.println("Predict Text Sentiment Analysis Response");
+ System.out.format("\tDeployed Model Id: %s\n", predictResponse.getDeployedModelId());
+
+ System.out.println("Predictions");
+ for (Value prediction : predictResponse.getPredictionsList()) {
+ System.out.format("\tPrediction: %s\n", prediction);
+ }
+ }
+ }
+}
+// [END aiplatform_predict_text_sentiment_analysis_sample]
diff --git a/aiplatform/snippets/src/test/java/aiplatform/CreateDatasetTextSampleTest.java b/aiplatform/snippets/src/test/java/aiplatform/CreateDatasetTextSampleTest.java
new file mode 100644
index 00000000000..a4048e5d96d
--- /dev/null
+++ b/aiplatform/snippets/src/test/java/aiplatform/CreateDatasetTextSampleTest.java
@@ -0,0 +1,94 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package aiplatform;
+
+import static com.google.common.truth.Truth.assertThat;
+import static junit.framework.TestCase.assertNotNull;
+
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.io.PrintStream;
+import java.util.UUID;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.TimeoutException;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.BeforeClass;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+@RunWith(JUnit4.class)
+public class CreateDatasetTextSampleTest {
+
+ private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID");
+ private ByteArrayOutputStream bout;
+ private PrintStream out;
+ private PrintStream originalPrintStream;
+ private String datasetId;
+
+ private static void requireEnvVar(String varName) {
+ String errorMessage =
+ String.format("Environment variable '%s' is required to perform these tests.", varName);
+ assertNotNull(errorMessage, System.getenv(varName));
+ }
+
+ @BeforeClass
+ public static void checkRequirements() {
+ requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS");
+ requireEnvVar("UCAIP_PROJECT_ID");
+ }
+
+ @Before
+ public void setUp() {
+ bout = new ByteArrayOutputStream();
+ out = new PrintStream(bout);
+ originalPrintStream = System.out;
+ System.setOut(out);
+ }
+
+ @After
+ public void tearDown()
+ throws InterruptedException, ExecutionException, IOException, TimeoutException {
+ // Delete the created dataset
+ DeleteDatasetSample.deleteDatasetSample(PROJECT, datasetId);
+
+ // Assert
+ String deleteResponse = bout.toString();
+ assertThat(deleteResponse).contains("Deleted Dataset.");
+ System.out.flush();
+ System.setOut(originalPrintStream);
+ }
+
+ @Test
+ public void testCreateDatasetSample()
+ throws IOException, InterruptedException, ExecutionException, TimeoutException {
+ // Act
+ String datasetDisplayName =
+ String.format(
+ "temp_create_dataset_test_%s",
+ UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 26));
+
+ CreateDatasetTextSample.createDatasetTextSample(PROJECT, datasetDisplayName);
+
+ // Assert
+ String got = bout.toString();
+ assertThat(got).contains(datasetDisplayName);
+ assertThat(got).contains("Create Text Dataset Response");
+ datasetId = got.split("Name: ")[1].split("datasets/")[1].split("\n")[0];
+ }
+}
diff --git a/aiplatform/snippets/src/test/java/aiplatform/CreateTrainingPipelineTextClassificationSampleTest.java b/aiplatform/snippets/src/test/java/aiplatform/CreateTrainingPipelineTextClassificationSampleTest.java
new file mode 100644
index 00000000000..5b68dab26f6
--- /dev/null
+++ b/aiplatform/snippets/src/test/java/aiplatform/CreateTrainingPipelineTextClassificationSampleTest.java
@@ -0,0 +1,110 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package aiplatform;
+
+import static com.google.common.truth.Truth.assertThat;
+import static junit.framework.TestCase.assertNotNull;
+
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.io.PrintStream;
+import java.util.UUID;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.BeforeClass;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+@RunWith(JUnit4.class)
+public class CreateTrainingPipelineTextClassificationSampleTest {
+
+ private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID");
+ private static final String DATASET_ID = System.getenv("TRAINING_PIPELINE_TEXT_CLASS_DATASET_ID");
+ private ByteArrayOutputStream bout;
+ private PrintStream out;
+ private PrintStream originalPrintStream;
+ private String trainingPipelineId;
+
+ private static void requireEnvVar(String varName) {
+ String errorMessage =
+ String.format("Environment variable '%s' is required to perform these tests.", varName);
+ assertNotNull(errorMessage, System.getenv(varName));
+ }
+
+ @BeforeClass
+ public static void checkRequirements() {
+ requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS");
+ requireEnvVar("UCAIP_PROJECT_ID");
+ requireEnvVar("TRAINING_PIPELINE_TEXT_CLASS_DATASET_ID");
+ }
+
+ @Before
+ public void setUp() {
+ bout = new ByteArrayOutputStream();
+ out = new PrintStream(bout);
+ originalPrintStream = System.out;
+ System.setOut(out);
+ }
+
+ @After
+ public void tearDown()
+ throws InterruptedException, ExecutionException, IOException, TimeoutException {
+ // Cancel the Training Pipeline
+ CancelTrainingPipelineSample.cancelTrainingPipelineSample(PROJECT, trainingPipelineId);
+
+ // Assert
+ String cancelResponse = bout.toString();
+ assertThat(cancelResponse).contains("Cancelled the Training Pipeline");
+ TimeUnit.MINUTES.sleep(2);
+
+ // Delete the Training Pipeline
+ DeleteTrainingPipelineSample.deleteTrainingPipelineSample(PROJECT, trainingPipelineId);
+
+ // Assert
+ String deleteResponse = bout.toString();
+ assertThat(deleteResponse).contains("Deleted Training Pipeline.");
+ System.out.flush();
+ System.setOut(originalPrintStream);
+ }
+
+ @Test
+ public void testCreateTrainingPipelineTextClassificationSample() throws IOException {
+ // Act
+ String trainingPipelineDisplayName =
+ String.format(
+ "temp_create_training_pipeline_test_%s",
+ UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 26));
+
+ String modelDisplayName =
+ String.format(
+ "temp_create_training_pipeline_model_test_%s",
+ UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 26));
+
+ CreateTrainingPipelineTextClassificationSample.createTrainingPipelineTextClassificationSample(
+ PROJECT, trainingPipelineDisplayName, DATASET_ID, modelDisplayName);
+
+ // Assert
+ String got = bout.toString();
+ assertThat(got).contains(DATASET_ID);
+ assertThat(got).contains("Create Training Pipeline Text Classification Response");
+ trainingPipelineId = got.split("Name: ")[1].split("trainingPipelines/")[1].split("\n")[0];
+ }
+}
diff --git a/aiplatform/snippets/src/test/java/aiplatform/CreateTrainingPipelineTextEntityExtractionSampleTest.java b/aiplatform/snippets/src/test/java/aiplatform/CreateTrainingPipelineTextEntityExtractionSampleTest.java
new file mode 100644
index 00000000000..367b9b3ac5c
--- /dev/null
+++ b/aiplatform/snippets/src/test/java/aiplatform/CreateTrainingPipelineTextEntityExtractionSampleTest.java
@@ -0,0 +1,112 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package aiplatform;
+
+import static com.google.common.truth.Truth.assertThat;
+import static junit.framework.TestCase.assertNotNull;
+
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.io.PrintStream;
+import java.util.UUID;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.BeforeClass;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+@RunWith(JUnit4.class)
+public class CreateTrainingPipelineTextEntityExtractionSampleTest {
+
+ private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID");
+ private static final String DATASET_ID =
+ System.getenv("TRAINING_PIPELINE_TEXT_ENTITY_DATASET_ID");
+ private ByteArrayOutputStream bout;
+ private PrintStream out;
+ private PrintStream originalPrintStream;
+ private String trainingPipelineId;
+
+ private static void requireEnvVar(String varName) {
+ String errorMessage =
+ String.format("Environment variable '%s' is required to perform these tests.", varName);
+ assertNotNull(errorMessage, System.getenv(varName));
+ }
+
+ @BeforeClass
+ public static void checkRequirements() {
+ requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS");
+ requireEnvVar("UCAIP_PROJECT_ID");
+ requireEnvVar("TRAINING_PIPELINE_TEXT_ENTITY_DATASET_ID");
+ }
+
+ @Before
+ public void setUp() {
+ bout = new ByteArrayOutputStream();
+ out = new PrintStream(bout);
+ originalPrintStream = System.out;
+ System.setOut(out);
+ }
+
+ @After
+ public void tearDown()
+ throws InterruptedException, ExecutionException, IOException, TimeoutException {
+ // Cancel the Training Pipeline
+ CancelTrainingPipelineSample.cancelTrainingPipelineSample(PROJECT, trainingPipelineId);
+
+ // Assert
+ String cancelResponse = bout.toString();
+ assertThat(cancelResponse).contains("Cancelled the Training Pipeline");
+ TimeUnit.MINUTES.sleep(2);
+
+ // Delete the Training Pipeline
+ DeleteTrainingPipelineSample.deleteTrainingPipelineSample(PROJECT, trainingPipelineId);
+
+ // Assert
+ String deleteResponse = bout.toString();
+ assertThat(deleteResponse).contains("Deleted Training Pipeline.");
+ System.out.flush();
+ System.setOut(originalPrintStream);
+ }
+
+ @Test
+ public void testCreateTrainingPipelineTextEntityExtractionSample() throws IOException {
+ // Act
+ String trainingPipelineDisplayName =
+ String.format(
+ "temp_create_training_pipeline_test_%s",
+ UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 26));
+
+ String modelDisplayName =
+ String.format(
+ "temp_create_training_pipeline_model_test_%s",
+ UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 26));
+
+ CreateTrainingPipelineTextEntityExtractionSample
+ .createTrainingPipelineTextEntityExtractionSample(
+ PROJECT, trainingPipelineDisplayName, DATASET_ID, modelDisplayName);
+
+ // Assert
+ String got = bout.toString();
+ assertThat(got).contains(DATASET_ID);
+ assertThat(got).contains("Create Training Pipeline Text Entity Extraction Response");
+ trainingPipelineId = got.split("Name: ")[1].split("trainingPipelines/")[1].split("\n")[0];
+ }
+}
diff --git a/aiplatform/snippets/src/test/java/aiplatform/CreateTrainingPipelineTextSentimentAnalysisSampleTest.java b/aiplatform/snippets/src/test/java/aiplatform/CreateTrainingPipelineTextSentimentAnalysisSampleTest.java
new file mode 100644
index 00000000000..bd0f29461bf
--- /dev/null
+++ b/aiplatform/snippets/src/test/java/aiplatform/CreateTrainingPipelineTextSentimentAnalysisSampleTest.java
@@ -0,0 +1,112 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package aiplatform;
+
+import static com.google.common.truth.Truth.assertThat;
+import static junit.framework.TestCase.assertNotNull;
+
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.io.PrintStream;
+import java.util.UUID;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.BeforeClass;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+@RunWith(JUnit4.class)
+public class CreateTrainingPipelineTextSentimentAnalysisSampleTest {
+
+ private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID");
+ private static final String DATASET_ID = System.getenv("TRAINING_PIPELINE_TEXT_SENTI_DATASET_ID");
+ private ByteArrayOutputStream bout;
+ private PrintStream out;
+ private PrintStream originalPrintStream;
+ private String trainingPipelineId;
+
+ private static void requireEnvVar(String varName) {
+ String errorMessage =
+ String.format("Environment variable '%s' is required to perform these tests.", varName);
+ assertNotNull(errorMessage, System.getenv(varName));
+ }
+
+ @BeforeClass
+ public static void checkRequirements() {
+ requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS");
+ requireEnvVar("UCAIP_PROJECT_ID");
+ requireEnvVar("TRAINING_PIPELINE_TEXT_SENTI_DATASET_ID");
+ }
+
+ @Before
+ public void setUp() {
+ bout = new ByteArrayOutputStream();
+ out = new PrintStream(bout);
+ originalPrintStream = System.out;
+ System.setOut(out);
+ }
+
+ @After
+ public void tearDown()
+ throws InterruptedException, ExecutionException, IOException, TimeoutException {
+ // Cancel the Training Pipeline
+ CancelTrainingPipelineSample.cancelTrainingPipelineSample(PROJECT, trainingPipelineId);
+
+ // Assert
+ String cancelResponse = bout.toString();
+ assertThat(cancelResponse).contains("Cancelled the Training Pipeline");
+ TimeUnit.MINUTES.sleep(2);
+
+ // Delete the Training Pipeline
+ DeleteTrainingPipelineSample.deleteTrainingPipelineSample(PROJECT, trainingPipelineId);
+
+ // Assert
+ String deleteResponse = bout.toString();
+ assertThat(deleteResponse).contains("Deleted Training Pipeline.");
+ System.out.flush();
+ System.setOut(originalPrintStream);
+ }
+
+ @Test
+ public void testCreateTrainingPipelineTextSentimentAnalysisSample() throws IOException {
+ String tempUuid = UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 26);
+ // Act
+ String trainingPipelineDisplayName =
+ String.format(
+ "temp_create_training_pipeline_test_%s",
+ tempUuid);
+
+ String modelDisplayName =
+ String.format(
+ "temp_create_training_pipeline_model_test_%s",
+ tempUuid);
+
+ CreateTrainingPipelineTextSentimentAnalysisSample
+ .createTrainingPipelineTextSentimentAnalysisSample(
+ PROJECT, trainingPipelineDisplayName, DATASET_ID, modelDisplayName);
+
+ // Assert
+ String got = bout.toString();
+ assertThat(got).contains(DATASET_ID);
+ assertThat(got).contains("Create Training Pipeline Text Sentiment Analysis Response");
+ trainingPipelineId = got.split("Name: ")[1].split("trainingPipelines/")[1].split("\n")[0];
+ }
+}
diff --git a/aiplatform/snippets/src/test/java/aiplatform/GetModelEvaluationTextClassificationSampleTest.java b/aiplatform/snippets/src/test/java/aiplatform/GetModelEvaluationTextClassificationSampleTest.java
new file mode 100644
index 00000000000..b0b646723a6
--- /dev/null
+++ b/aiplatform/snippets/src/test/java/aiplatform/GetModelEvaluationTextClassificationSampleTest.java
@@ -0,0 +1,81 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package aiplatform;
+
+import static com.google.common.truth.Truth.assertThat;
+import static junit.framework.TestCase.assertNotNull;
+
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.io.PrintStream;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.BeforeClass;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+@RunWith(JUnit4.class)
+public class GetModelEvaluationTextClassificationSampleTest {
+
+ private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID");
+ private static final String MODEL_ID = System.getenv("TEXT_CLASS_MODEL_ID");
+ private static final String EVALUATION_ID = System.getenv("TEXT_CLASS_EVALUATION_ID");
+ private ByteArrayOutputStream bout;
+ private PrintStream out;
+ private PrintStream originalPrintStream;
+
+ private static void requireEnvVar(String varName) {
+ String errorMessage =
+ String.format("Environment variable '%s' is required to perform these tests.", varName);
+ assertNotNull(errorMessage, System.getenv(varName));
+ }
+
+ @BeforeClass
+ public static void checkRequirements() {
+ requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS");
+ requireEnvVar("UCAIP_PROJECT_ID");
+ requireEnvVar("TEXT_CLASS_MODEL_ID");
+ requireEnvVar("TEXT_CLASS_EVALUATION_ID");
+ }
+
+ @Before
+ public void setUp() {
+ bout = new ByteArrayOutputStream();
+ out = new PrintStream(bout);
+ originalPrintStream = System.out;
+ System.setOut(out);
+ }
+
+ @After
+ public void tearDown() {
+ System.out.flush();
+ System.setOut(originalPrintStream);
+ }
+
+ @Test
+ public void testGetModelEvaluationTextClassificationSample() throws IOException {
+ // Act
+ GetModelEvaluationTextClassificationSample.getModelEvaluationTextClassificationSample(
+ PROJECT, MODEL_ID, EVALUATION_ID);
+
+ // Assert
+ String got = bout.toString();
+ assertThat(got).contains(MODEL_ID);
+ assertThat(got).contains("Get Model Evaluation Text Classification Response");
+ }
+}
diff --git a/aiplatform/snippets/src/test/java/aiplatform/GetModelEvaluationTextEntityExtractionSampleTest.java b/aiplatform/snippets/src/test/java/aiplatform/GetModelEvaluationTextEntityExtractionSampleTest.java
new file mode 100644
index 00000000000..55d89370254
--- /dev/null
+++ b/aiplatform/snippets/src/test/java/aiplatform/GetModelEvaluationTextEntityExtractionSampleTest.java
@@ -0,0 +1,81 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package aiplatform;
+
+import static com.google.common.truth.Truth.assertThat;
+import static junit.framework.TestCase.assertNotNull;
+
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.io.PrintStream;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.BeforeClass;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+@RunWith(JUnit4.class)
+public class GetModelEvaluationTextEntityExtractionSampleTest {
+
+ private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID");
+ private static final String MODEL_ID = System.getenv("TEXT_ENTITY_MODEL_ID");
+ private static final String EVALUATION_ID = System.getenv("TEXT_ENTITY_EVALUATION_ID");
+ private ByteArrayOutputStream bout;
+ private PrintStream out;
+ private PrintStream originalPrintStream;
+
+ private static void requireEnvVar(String varName) {
+ String errorMessage =
+ String.format("Environment variable '%s' is required to perform these tests.", varName);
+ assertNotNull(errorMessage, System.getenv(varName));
+ }
+
+ @BeforeClass
+ public static void checkRequirements() {
+ requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS");
+ requireEnvVar("UCAIP_PROJECT_ID");
+ requireEnvVar("TEXT_ENTITY_MODEL_ID");
+ requireEnvVar("TEXT_ENTITY_EVALUATION_ID");
+ }
+
+ @Before
+ public void setUp() {
+ bout = new ByteArrayOutputStream();
+ out = new PrintStream(bout);
+ originalPrintStream = System.out;
+ System.setOut(out);
+ }
+
+ @After
+ public void tearDown() {
+ System.out.flush();
+ System.setOut(originalPrintStream);
+ }
+
+ @Test
+ public void testGetModelEvaluationTextEntityExtractionSample() throws IOException {
+ // Act
+ GetModelEvaluationTextEntityExtractionSample.getModelEvaluationTextEntityExtractionSample(
+ PROJECT, MODEL_ID, EVALUATION_ID);
+
+ // Assert
+ String got = bout.toString();
+ assertThat(got).contains(MODEL_ID);
+ assertThat(got).contains("Get Model Evaluation Text Entity Extraction Response");
+ }
+}
diff --git a/aiplatform/snippets/src/test/java/aiplatform/GetModelEvaluationTextSentimentAnalysisSampleTest.java b/aiplatform/snippets/src/test/java/aiplatform/GetModelEvaluationTextSentimentAnalysisSampleTest.java
new file mode 100644
index 00000000000..1aeb635914d
--- /dev/null
+++ b/aiplatform/snippets/src/test/java/aiplatform/GetModelEvaluationTextSentimentAnalysisSampleTest.java
@@ -0,0 +1,81 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package aiplatform;
+
+import static com.google.common.truth.Truth.assertThat;
+import static junit.framework.TestCase.assertNotNull;
+
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.io.PrintStream;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.BeforeClass;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+@RunWith(JUnit4.class)
+public class GetModelEvaluationTextSentimentAnalysisSampleTest {
+
+ private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID");
+ private static final String MODEL_ID = System.getenv("TEXT_SENTI_MODEL_ID");
+ private static final String EVALUATION_ID = System.getenv("TEXT_SENTI_EVALUATION_ID");
+ private ByteArrayOutputStream bout;
+ private PrintStream out;
+ private PrintStream originalPrintStream;
+
+ private static void requireEnvVar(String varName) {
+ String errorMessage =
+ String.format("Environment variable '%s' is required to perform these tests.", varName);
+ assertNotNull(errorMessage, System.getenv(varName));
+ }
+
+ @BeforeClass
+ public static void checkRequirements() {
+ requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS");
+ requireEnvVar("UCAIP_PROJECT_ID");
+ requireEnvVar("TEXT_SENTI_MODEL_ID");
+ requireEnvVar("TEXT_SENTI_EVALUATION_ID");
+ }
+
+ @Before
+ public void setUp() {
+ bout = new ByteArrayOutputStream();
+ out = new PrintStream(bout);
+ originalPrintStream = System.out;
+ System.setOut(out);
+ }
+
+ @After
+ public void tearDown() {
+ System.out.flush();
+ System.setOut(originalPrintStream);
+ }
+
+ @Test
+ public void testGetModelEvaluationTextSentimentAnalysisSample() throws IOException {
+ // Act
+ GetModelEvaluationTextSentimentAnalysisSample.getModelEvaluationTextSentimentAnalysisSample(
+ PROJECT, MODEL_ID, EVALUATION_ID);
+
+ // Assert
+ String got = bout.toString();
+ assertThat(got).contains(MODEL_ID);
+ assertThat(got).contains("Get Model Evaluation Text Sentiment Analysis Response");
+ }
+}
diff --git a/aiplatform/snippets/src/test/java/aiplatform/ImportDataSampleTest.java b/aiplatform/snippets/src/test/java/aiplatform/ImportDataSampleTest.java
new file mode 100644
index 00000000000..72d5ad90f5f
--- /dev/null
+++ b/aiplatform/snippets/src/test/java/aiplatform/ImportDataSampleTest.java
@@ -0,0 +1,90 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package aiplatform;
+
+import static com.google.common.truth.Truth.assertThat;
+import static junit.framework.TestCase.assertNotNull;
+
+import io.grpc.StatusRuntimeException;
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.io.PrintStream;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.TimeoutException;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.BeforeClass;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+@RunWith(JUnit4.class)
+public class ImportDataSampleTest {
+
+ private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID");
+ private static final String DATASET_ID = "000000000000000000000";
+
+ private static final String GCS_SOURCE_URI =
+ "gs://automl-cloud-dataset/SMSSpamCollection_train_dataset_2.csv";
+
+ private ByteArrayOutputStream bout;
+ private PrintStream out;
+ private PrintStream originalPrintStream;
+
+ private static void requireEnvVar(String varName) {
+ String errorMessage =
+ String.format("Environment variable '%s' is required to perform these tests.", varName);
+ assertNotNull(errorMessage, System.getenv(varName));
+ }
+
+ @BeforeClass
+ public static void checkRequirements() {
+ requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS");
+ requireEnvVar("UCAIP_PROJECT_ID");
+ }
+
+ @Before
+ public void setUp() {
+ bout = new ByteArrayOutputStream();
+ out = new PrintStream(bout);
+ originalPrintStream = System.out;
+ System.setOut(out);
+ }
+
+ @After
+ public void tearDown() {
+ System.out.flush();
+ System.setOut(originalPrintStream);
+ }
+
+ @Test
+ public void testImportDataSample()
+ throws TimeoutException {
+ // As import data into dataset can take a long time, instead try to import data into a
+ // nonexistent dataset and confirm that the model was not found, but other
+ // elements of the request were valid.
+ try {
+ ImportDataTextClassificationSingleLabelSample.importDataTextClassificationSingleLabelSample(
+ PROJECT, DATASET_ID, GCS_SOURCE_URI);
+ // Assert
+ String got = bout.toString();
+ assertThat(got).contains("The Dataset does not exist.");
+ } catch (StatusRuntimeException | ExecutionException | InterruptedException | IOException e) {
+ assertThat(e.getMessage()).contains("The Dataset does not exist.");
+ }
+ }
+}
diff --git a/aiplatform/snippets/src/test/java/aiplatform/PredictTextClassificationSingleLabelSampleTest.java b/aiplatform/snippets/src/test/java/aiplatform/PredictTextClassificationSingleLabelSampleTest.java
new file mode 100644
index 00000000000..a47674098a9
--- /dev/null
+++ b/aiplatform/snippets/src/test/java/aiplatform/PredictTextClassificationSingleLabelSampleTest.java
@@ -0,0 +1,76 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package aiplatform;
+
+import static com.google.common.truth.Truth.assertThat;
+import static junit.framework.TestCase.assertNotNull;
+
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.io.PrintStream;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+public class PredictTextClassificationSingleLabelSampleTest {
+
+ private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID");
+ private static final String TEXT_CONTENT = "This is the test String!";
+ private static final String ENDPOINT_ID = System.getenv("TEXT_CLASS_SINGLE_LABEL_ENDPOINT_ID");
+ private ByteArrayOutputStream bout;
+ private PrintStream out;
+ private PrintStream originalPrintStream;
+
+ private static void requireEnvVar(String varName) {
+ String errorMessage =
+ String.format("Environment variable '%s' is required to perform these tests.", varName);
+ assertNotNull(errorMessage, System.getenv(varName));
+ }
+
+ @BeforeClass
+ public static void checkRequirements() {
+ requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS");
+ requireEnvVar("UCAIP_PROJECT_ID");
+ requireEnvVar("TEXT_CLASS_SINGLE_LABEL_ENDPOINT_ID");
+ }
+
+ @Before
+ public void setUp() {
+ bout = new ByteArrayOutputStream();
+ out = new PrintStream(bout);
+ originalPrintStream = System.out;
+ System.setOut(out);
+ }
+
+ @After
+ public void tearDown() {
+ System.out.flush();
+ System.setOut(originalPrintStream);
+ }
+
+ @Test
+ public void testPredictTextClassification() throws IOException {
+ // Act
+ PredictTextClassificationSingleLabelSample.predictTextClassificationSingleLabel(
+ PROJECT, TEXT_CONTENT, ENDPOINT_ID);
+
+ // Assert
+ String got = bout.toString();
+ assertThat(got).contains("Predict Text Classification Response");
+ }
+}
diff --git a/aiplatform/snippets/src/test/java/aiplatform/PredictTextEntityExtractionSampleTest.java b/aiplatform/snippets/src/test/java/aiplatform/PredictTextEntityExtractionSampleTest.java
new file mode 100644
index 00000000000..dde83fa06d9
--- /dev/null
+++ b/aiplatform/snippets/src/test/java/aiplatform/PredictTextEntityExtractionSampleTest.java
@@ -0,0 +1,89 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package aiplatform;
+
+import static com.google.common.truth.Truth.assertThat;
+import static junit.framework.TestCase.assertNotNull;
+
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.io.PrintStream;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+public class PredictTextEntityExtractionSampleTest {
+
+ private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID");
+ private static final String TEXT_CONTENT =
+ "1127526\\tAnalbuminemia in a neonate.\\tA small-for-gestational-age infant , found to have"
+ + " analbuminemia in the neonatal period , is reported and the twelve cases recorded in"
+ + " the world literature are reviewed . Patients lacking this serum protein are"
+ + " essentially asymptomatic , apart from minimal ankle edema and ease of fatigue ."
+ + " Apparent compensatory mechanisms which come into play when serum albumin is low"
+ + " include prolonged half-life of albumin and transferrin , an increase in serum"
+ + " globulins , beta lipoprotein , and glycoproteins , arterial hypotension with reduced"
+ + " capillary hydrostatic pressure , and the ability to respond with rapid sodium and"
+ + " chloride diuresis in response to small volume changes . Examination of plasma amino"
+ + " acids , an investigation not previously reported , revealed an extremely low plasma"
+ + " tryptophan level , a finding which may be important in view of the role of"
+ + " tryptophan in albumin synthesis.";
+ private static final String ENDPOINT_ID = System.getenv("TEXT_ENTITY_ENDPOINT_ID");
+ private ByteArrayOutputStream bout;
+ private PrintStream out;
+ private PrintStream originalPrintStream;
+
+ private static void requireEnvVar(String varName) {
+ String errorMessage =
+ String.format("Environment variable '%s' is required to perform these tests.", varName);
+ assertNotNull(errorMessage, System.getenv(varName));
+ }
+
+ @BeforeClass
+ public static void checkRequirements() {
+ requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS");
+ requireEnvVar("UCAIP_PROJECT_ID");
+ requireEnvVar("TEXT_ENTITY_ENDPOINT_ID");
+ }
+
+ @Before
+ public void setUp() {
+ bout = new ByteArrayOutputStream();
+ out = new PrintStream(bout);
+ originalPrintStream = System.out;
+ System.setOut(out);
+ }
+
+ @After
+ public void tearDown() {
+ System.out.flush();
+ System.setOut(originalPrintStream);
+ }
+
+ @Test
+ public void testPredictTextEntityExtraction() throws IOException {
+ // Act
+ PredictTextEntityExtractionSample.predictTextEntityExtraction(
+ PROJECT, TEXT_CONTENT, ENDPOINT_ID);
+
+ // Assert
+ String got = bout.toString();
+ assertThat(got).contains("Predict Text Entity Extraction Response");
+ }
+}
+
diff --git a/aiplatform/snippets/src/test/java/aiplatform/PredictTextSentimentAnalysisSampleTest.java b/aiplatform/snippets/src/test/java/aiplatform/PredictTextSentimentAnalysisSampleTest.java
new file mode 100644
index 00000000000..1189e391728
--- /dev/null
+++ b/aiplatform/snippets/src/test/java/aiplatform/PredictTextSentimentAnalysisSampleTest.java
@@ -0,0 +1,90 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package aiplatform;
+
+import static com.google.common.truth.Truth.assertThat;
+import static junit.framework.TestCase.assertNotNull;
+
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.io.PrintStream;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+public class PredictTextSentimentAnalysisSampleTest {
+
+ private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID");
+ private static final String TEXT_CONTENT =
+ "I was excited at the concept of my favorite comic book hero being on television... and"
+ + " sorely disappointed at the end result.
The only amazing thing was the"
+ + " wall crawling (despite the visibility of the cable). I didn't think Nick Hammond was"
+ + " Peter Parker... and he was visibly of a different build than the guy who did the"
+ + " stunts in the spider suit. You could tell they were two different actors.
Granted, I can also spot in the modern Spider-Man movies when I am looking at"
+ + " Tobey Macguire and when I am looking at CGI. But that is from a trained eye and"
+ + " experience working with CGI. Still, the 70's version could have been better despite"
+ + " lack of Special FX.
The webs were hokey and looked like ropes that seemed"
+ + " to wrap around things rather than stick to them. And what was up with giving him a"
+ + " spider mobile to ride around in. Hello? He's the web slinger people.
Sorry... didn't mean to get so worked up, but our beloved wall crawler deserved"
+ + " better.";
+ private static final String ENDPOINT_ID = System.getenv("TEXT_SENTI_ENDPOINT_ID");
+ private ByteArrayOutputStream bout;
+ private PrintStream out;
+ private PrintStream originalPrintStream;
+
+ private static void requireEnvVar(String varName) {
+ String errorMessage =
+ String.format("Environment variable '%s' is required to perform these tests.", varName);
+ assertNotNull(errorMessage, System.getenv(varName));
+ }
+
+ @BeforeClass
+ public static void checkRequirements() {
+ requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS");
+ requireEnvVar("UCAIP_PROJECT_ID");
+ requireEnvVar("TEXT_SENTI_ENDPOINT_ID");
+ }
+
+ @Before
+ public void setUp() {
+ bout = new ByteArrayOutputStream();
+ out = new PrintStream(bout);
+ originalPrintStream = System.out;
+ System.setOut(out);
+ }
+
+ @After
+ public void tearDown() {
+ System.out.flush();
+ System.setOut(originalPrintStream);
+ }
+
+ @Test
+ public void testPredictTextSentimentAnalysis() throws IOException {
+ // Act
+ PredictTextSentimentAnalysisSample.predictTextSentimentAnalysis(
+ PROJECT, TEXT_CONTENT, ENDPOINT_ID);
+
+ // Assert
+ String got = bout.toString();
+ assertThat(got).contains("Predict Text Sentiment Analysis Response");
+ }
+}
+