From 413b908fa1b94d6e433b1a638b790f30b414be37 Mon Sep 17 00:00:00 2001
From: Mike <45373284+munkhuushmgl@users.noreply.github.com>
Date: Fri, 6 Nov 2020 11:35:26 -0800
Subject: [PATCH] samples: ucaip samples batch 3 of 6 (#18)
* samples:samples: ucaip samples batch 3 of 6
* made requested the changes
* changed all instance of tables into tabular
* fixed the lint
* reversed some comments
---
aiplatform/snippets/pom.xml | 4 +-
.../CancelTrainingPipelineSample.java | 57 ++++
.../CreateDatasetTabularBigquerySample.java | 89 ++++++
.../CreateDatasetTabularGcsSample.java | 88 ++++++
...ngPipelineTabularClassificationSample.java | 255 ++++++++++++++++++
...ainingPipelineTabularRegressionSample.java | 254 +++++++++++++++++
.../java/aiplatform/DeleteDatasetSample.java | 67 +++++
.../aiplatform/DeleteExportModelSample.java | 45 ++++
.../DeleteTrainingPipelineSample.java | 68 +++++
...xportModelTabularClassificationSample.java | 79 ++++++
...EvaluationTabularClassificationSample.java | 78 ++++++
...odelEvaluationTabularRegressionSample.java | 78 ++++++
.../PredictTabularClassificationSample.java | 73 +++++
.../PredictTabularRegressionSample.java | 73 +++++
...reateDatasetTabularBigquerySampleTest.java | 93 +++++++
.../CreateDatasetTabularGcsSampleTest.java | 93 +++++++
...pelineTabularClassificationSampleTest.java | 121 +++++++++
...ngPipelineTabularRegressionSampleTest.java | 119 ++++++++
...tModelTabularClassificationSampleTest.java | 89 ++++++
...uationTabularClassificationSampleTest.java | 80 ++++++
...EvaluationTabularRegressionSampleTest.java | 80 ++++++
...redictTabularClassificationSampleTest.java | 80 ++++++
.../PredictTabularRegressionSampleTest.java | 100 +++++++
23 files changed, 2161 insertions(+), 2 deletions(-)
create mode 100644 aiplatform/snippets/src/main/java/aiplatform/CancelTrainingPipelineSample.java
create mode 100644 aiplatform/snippets/src/main/java/aiplatform/CreateDatasetTabularBigquerySample.java
create mode 100644 aiplatform/snippets/src/main/java/aiplatform/CreateDatasetTabularGcsSample.java
create mode 100644 aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineTabularClassificationSample.java
create mode 100644 aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineTabularRegressionSample.java
create mode 100644 aiplatform/snippets/src/main/java/aiplatform/DeleteDatasetSample.java
create mode 100644 aiplatform/snippets/src/main/java/aiplatform/DeleteExportModelSample.java
create mode 100644 aiplatform/snippets/src/main/java/aiplatform/DeleteTrainingPipelineSample.java
create mode 100644 aiplatform/snippets/src/main/java/aiplatform/ExportModelTabularClassificationSample.java
create mode 100644 aiplatform/snippets/src/main/java/aiplatform/GetModelEvaluationTabularClassificationSample.java
create mode 100644 aiplatform/snippets/src/main/java/aiplatform/GetModelEvaluationTabularRegressionSample.java
create mode 100644 aiplatform/snippets/src/main/java/aiplatform/PredictTabularClassificationSample.java
create mode 100644 aiplatform/snippets/src/main/java/aiplatform/PredictTabularRegressionSample.java
create mode 100644 aiplatform/snippets/src/test/java/aiplatform/CreateDatasetTabularBigquerySampleTest.java
create mode 100644 aiplatform/snippets/src/test/java/aiplatform/CreateDatasetTabularGcsSampleTest.java
create mode 100644 aiplatform/snippets/src/test/java/aiplatform/CreateTrainingPipelineTabularClassificationSampleTest.java
create mode 100644 aiplatform/snippets/src/test/java/aiplatform/CreateTrainingPipelineTabularRegressionSampleTest.java
create mode 100644 aiplatform/snippets/src/test/java/aiplatform/ExportModelTabularClassificationSampleTest.java
create mode 100644 aiplatform/snippets/src/test/java/aiplatform/GetModelEvaluationTabularClassificationSampleTest.java
create mode 100644 aiplatform/snippets/src/test/java/aiplatform/GetModelEvaluationTabularRegressionSampleTest.java
create mode 100644 aiplatform/snippets/src/test/java/aiplatform/PredictTabularClassificationSampleTest.java
create mode 100644 aiplatform/snippets/src/test/java/aiplatform/PredictTabularRegressionSampleTest.java
diff --git a/aiplatform/snippets/pom.xml b/aiplatform/snippets/pom.xml
index a89bbb4a567..d9652815f2f 100644
--- a/aiplatform/snippets/pom.xml
+++ b/aiplatform/snippets/pom.xml
@@ -23,14 +23,14 @@
UTF-8
+
+
-
com.google.cloud
google-cloud-aiplatform
0.0.1-SNAPSHOT
-
com.google.cloud
google-cloud-storage
diff --git a/aiplatform/snippets/src/main/java/aiplatform/CancelTrainingPipelineSample.java b/aiplatform/snippets/src/main/java/aiplatform/CancelTrainingPipelineSample.java
new file mode 100644
index 00000000000..4dd2902f328
--- /dev/null
+++ b/aiplatform/snippets/src/main/java/aiplatform/CancelTrainingPipelineSample.java
@@ -0,0 +1,57 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package aiplatform;
+
+// [START aiplatform_cancel_training_pipeline_sample]
+
+import com.google.cloud.aiplatform.v1beta1.PipelineServiceClient;
+import com.google.cloud.aiplatform.v1beta1.PipelineServiceSettings;
+import com.google.cloud.aiplatform.v1beta1.TrainingPipelineName;
+import java.io.IOException;
+
+public class CancelTrainingPipelineSample {
+
+ public static void main(String[] args) throws IOException {
+ // TODO(developer): Replace these variables before running the sample.
+ String trainingPipelineId = "YOUR_TRAINING_PIPELINE_ID";
+ String project = "YOUR_PROJECT_ID";
+ cancelTrainingPipelineSample(project, trainingPipelineId);
+ }
+
+ static void cancelTrainingPipelineSample(String project, String trainingPipelineId)
+ throws IOException {
+ PipelineServiceSettings pipelineServiceSettings =
+ PipelineServiceSettings.newBuilder()
+ .setEndpoint("us-central1-aiplatform.googleapis.com:443")
+ .build();
+
+ // Initialize client that will be used to send requests. This client only needs to be created
+ // once, and can be reused for multiple requests. After completing all of your requests, call
+ // the "close" method on the client to safely clean up any remaining background resources.
+ try (PipelineServiceClient pipelineServiceClient =
+ PipelineServiceClient.create(pipelineServiceSettings)) {
+ String location = "us-central1";
+ TrainingPipelineName trainingPipelineName =
+ TrainingPipelineName.of(project, location, trainingPipelineId);
+
+ pipelineServiceClient.cancelTrainingPipeline(trainingPipelineName);
+
+ System.out.println("Cancelled the Training Pipeline");
+ }
+ }
+}
+// [END aiplatform_cancel_training_pipeline_sample]
diff --git a/aiplatform/snippets/src/main/java/aiplatform/CreateDatasetTabularBigquerySample.java b/aiplatform/snippets/src/main/java/aiplatform/CreateDatasetTabularBigquerySample.java
new file mode 100644
index 00000000000..bcaf5c94eee
--- /dev/null
+++ b/aiplatform/snippets/src/main/java/aiplatform/CreateDatasetTabularBigquerySample.java
@@ -0,0 +1,89 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package aiplatform;
+
+// [START aiplatform_create_dataset_tabular_bigquery_sample]
+
+import com.google.api.gax.longrunning.OperationFuture;
+import com.google.cloud.aiplatform.v1beta1.CreateDatasetOperationMetadata;
+import com.google.cloud.aiplatform.v1beta1.Dataset;
+import com.google.cloud.aiplatform.v1beta1.DatasetServiceClient;
+import com.google.cloud.aiplatform.v1beta1.DatasetServiceSettings;
+import com.google.cloud.aiplatform.v1beta1.LocationName;
+import com.google.protobuf.Value;
+import com.google.protobuf.util.JsonFormat;
+import java.io.IOException;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
+
+public class CreateDatasetTabularBigquerySample {
+
+ public static void main(String[] args)
+ throws InterruptedException, ExecutionException, TimeoutException, IOException {
+ // TODO(developer): Replace these variables before running the sample.
+ String project = "YOUR_PROJECT_ID";
+ String bigqueryDisplayName = "YOUR_DATASET_DISPLAY_NAME";
+ String bigqueryUri =
+ "bq://YOUR_GOOGLE_CLOUD_PROJECT_ID.BIGQUERY_DATASET_ID.BIGQUERY_TABLE_OR_VIEW_ID";
+ createDatasetTableBigquery(project, bigqueryDisplayName, bigqueryUri);
+ }
+
+ static void createDatasetTableBigquery(
+ String project, String bigqueryDisplayName, String bigqueryUri)
+ throws IOException, ExecutionException, InterruptedException, TimeoutException {
+ DatasetServiceSettings settings =
+ DatasetServiceSettings.newBuilder()
+ .setEndpoint("us-central1-aiplatform.googleapis.com:443")
+ .build();
+
+ // Initialize client that will be used to send requests. This client only needs to be created
+ // once, and can be reused for multiple requests. After completing all of your requests, call
+ // the "close" method on the client to safely clean up any remaining background resources.
+ try (DatasetServiceClient datasetServiceClient = DatasetServiceClient.create(settings)) {
+ String location = "us-central1";
+ String metadataSchemaUri =
+ "gs://google-cloud-aiplatform/schema/dataset/metadata/tables_1.0.0.yaml";
+ LocationName locationName = LocationName.of(project, location);
+
+ String jsonString =
+ "{\"input_config\": {\"bigquery_source\": {\"uri\": \"" + bigqueryUri + "\"}}}";
+ Value.Builder metaData = Value.newBuilder();
+ JsonFormat.parser().merge(jsonString, metaData);
+
+ Dataset dataset =
+ Dataset.newBuilder()
+ .setDisplayName(bigqueryDisplayName)
+ .setMetadataSchemaUri(metadataSchemaUri)
+ .setMetadata(metaData)
+ .build();
+
+ OperationFuture datasetFuture =
+ datasetServiceClient.createDatasetAsync(locationName, dataset);
+ System.out.format("Operation name: %s\n", datasetFuture.getInitialFuture().get().getName());
+ System.out.println("Waiting for operation to finish...");
+ Dataset datasetResponse = datasetFuture.get(300, TimeUnit.SECONDS);
+
+ System.out.println("Create Dataset Table Bigquery sample");
+ System.out.format("Name: %s\n", datasetResponse.getName());
+ System.out.format("Display Name: %s\n", datasetResponse.getDisplayName());
+ System.out.format("Metadata Schema Uri: %s\n", datasetResponse.getMetadataSchemaUri());
+ System.out.format("Metadata: %s\n", datasetResponse.getMetadata());
+ }
+ }
+}
+// [END aiplatform_create_dataset_tabular_bigquery_sample]
diff --git a/aiplatform/snippets/src/main/java/aiplatform/CreateDatasetTabularGcsSample.java b/aiplatform/snippets/src/main/java/aiplatform/CreateDatasetTabularGcsSample.java
new file mode 100644
index 00000000000..2b2f17f41ba
--- /dev/null
+++ b/aiplatform/snippets/src/main/java/aiplatform/CreateDatasetTabularGcsSample.java
@@ -0,0 +1,88 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package aiplatform;
+
+// [START aiplatform_create_dataset_tabular_gcs_sample]
+
+import com.google.api.gax.longrunning.OperationFuture;
+import com.google.cloud.aiplatform.v1beta1.CreateDatasetOperationMetadata;
+import com.google.cloud.aiplatform.v1beta1.Dataset;
+import com.google.cloud.aiplatform.v1beta1.DatasetServiceClient;
+import com.google.cloud.aiplatform.v1beta1.DatasetServiceSettings;
+import com.google.cloud.aiplatform.v1beta1.LocationName;
+import com.google.protobuf.Value;
+import com.google.protobuf.util.JsonFormat;
+import java.io.IOException;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
+
+public class CreateDatasetTabularGcsSample {
+
+ public static void main(String[] args)
+ throws InterruptedException, ExecutionException, TimeoutException, IOException {
+ // TODO(developer): Replace these variables before running the sample.
+ String project = "YOUR_PROJECT_ID";
+ String datasetDisplayName = "YOUR_DATASET_DISPLAY_NAME";
+ String gcsSourceUri = "gs://YOUR_GCS_SOURCE_BUCKET/path_to_your_gcs_table/file.csv";
+ ;
+ createDatasetTableGcs(project, datasetDisplayName, gcsSourceUri);
+ }
+
+ static void createDatasetTableGcs(String project, String datasetDisplayName, String gcsSourceUri)
+ throws IOException, ExecutionException, InterruptedException, TimeoutException {
+ DatasetServiceSettings settings =
+ DatasetServiceSettings.newBuilder()
+ .setEndpoint("us-central1-aiplatform.googleapis.com:443")
+ .build();
+
+ // Initialize client that will be used to send requests. This client only needs to be created
+ // once, and can be reused for multiple requests. After completing all of your requests, call
+ // the "close" method on the client to safely clean up any remaining background resources.
+ try (DatasetServiceClient datasetServiceClient = DatasetServiceClient.create(settings)) {
+ String location = "us-central1";
+ String metadataSchemaUri =
+ "gs://google-cloud-aiplatform/schema/dataset/metadata/tables_1.0.0.yaml";
+ LocationName locationName = LocationName.of(project, location);
+
+ String jsonString =
+ "{\"input_config\": {\"gcs_source\": {\"uri\": [\"" + gcsSourceUri + "\"]}}}";
+ Value.Builder metaData = Value.newBuilder();
+ JsonFormat.parser().merge(jsonString, metaData);
+
+ Dataset dataset =
+ Dataset.newBuilder()
+ .setDisplayName(datasetDisplayName)
+ .setMetadataSchemaUri(metadataSchemaUri)
+ .setMetadata(metaData)
+ .build();
+
+ OperationFuture datasetFuture =
+ datasetServiceClient.createDatasetAsync(locationName, dataset);
+ System.out.format("Operation name: %s\n", datasetFuture.getInitialFuture().get().getName());
+ System.out.println("Waiting for operation to finish...");
+ Dataset datasetResponse = datasetFuture.get(300, TimeUnit.SECONDS);
+
+ System.out.println("Create Dataset Table GCS sample");
+ System.out.format("Name: %s\n", datasetResponse.getName());
+ System.out.format("Display Name: %s\n", datasetResponse.getDisplayName());
+ System.out.format("Metadata Schema Uri: %s\n", datasetResponse.getMetadataSchemaUri());
+ System.out.format("Metadata: %s\n", datasetResponse.getMetadata());
+ }
+ }
+}
+// [END aiplatform_create_dataset_tabular_gcs_sample]
diff --git a/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineTabularClassificationSample.java b/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineTabularClassificationSample.java
new file mode 100644
index 00000000000..de54c0a0a07
--- /dev/null
+++ b/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineTabularClassificationSample.java
@@ -0,0 +1,255 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package aiplatform;
+
+// [START aiplatform_create_training_pipeline_tabular_classification_sample]
+
+import com.google.cloud.aiplatform.v1beta1.DeployedModelRef;
+import com.google.cloud.aiplatform.v1beta1.EnvVar;
+import com.google.cloud.aiplatform.v1beta1.ExplanationMetadata;
+import com.google.cloud.aiplatform.v1beta1.ExplanationParameters;
+import com.google.cloud.aiplatform.v1beta1.ExplanationSpec;
+import com.google.cloud.aiplatform.v1beta1.FilterSplit;
+import com.google.cloud.aiplatform.v1beta1.FractionSplit;
+import com.google.cloud.aiplatform.v1beta1.InputDataConfig;
+import com.google.cloud.aiplatform.v1beta1.LocationName;
+import com.google.cloud.aiplatform.v1beta1.Model;
+import com.google.cloud.aiplatform.v1beta1.ModelContainerSpec;
+import com.google.cloud.aiplatform.v1beta1.PipelineServiceClient;
+import com.google.cloud.aiplatform.v1beta1.PipelineServiceSettings;
+import com.google.cloud.aiplatform.v1beta1.Port;
+import com.google.cloud.aiplatform.v1beta1.PredefinedSplit;
+import com.google.cloud.aiplatform.v1beta1.PredictSchemata;
+import com.google.cloud.aiplatform.v1beta1.SampledShapleyAttribution;
+import com.google.cloud.aiplatform.v1beta1.TimestampSplit;
+import com.google.cloud.aiplatform.v1beta1.TrainingPipeline;
+import com.google.protobuf.Any;
+import com.google.protobuf.Value;
+import com.google.protobuf.util.JsonFormat;
+import com.google.rpc.Status;
+import java.io.IOException;
+import java.util.List;
+
+public class CreateTrainingPipelineTabularClassificationSample {
+
+ public static void main(String[] args) throws IOException {
+ // TODO(developer): Replace these variables before running the sample.
+ String project = "YOUR_PROJECT_ID";
+ String modelDisplayName = "YOUR_DATASET_DISPLAY_NAME";
+ String datasetId = "YOUR_DATASET_ID";
+ String targetColumn = "TARGET_COLUMN";
+ String transformation =
+ "[{TRANSFORMATION_TYPE: {columnName : COLUMN_NAME, invalidValuesAllowed : TRUE/FALSE }}]";
+ createTrainingPipelineTableClassification(
+ project, modelDisplayName, datasetId, targetColumn, transformation);
+ }
+
+ static void createTrainingPipelineTableClassification(
+ String project,
+ String modelDisplayName,
+ String datasetId,
+ String targetColumn,
+ String transformation)
+ throws IOException {
+ PipelineServiceSettings pipelineServiceSettings =
+ PipelineServiceSettings.newBuilder()
+ .setEndpoint("us-central1-aiplatform.googleapis.com:443")
+ .build();
+
+ // Initialize client that will be used to send requests. This client only needs to be created
+ // once, and can be reused for multiple requests. After completing all of your requests, call
+ // the "close" method on the client to safely clean up any remaining background resources.
+ try (PipelineServiceClient pipelineServiceClient =
+ PipelineServiceClient.create(pipelineServiceSettings)) {
+ String location = "us-central1";
+ LocationName locationName = LocationName.of(project, location);
+ String trainingTaskDefinition =
+ "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_tables_1.0.0.yaml";
+ String jsonString =
+ "{\"targetColumn\": \""
+ + targetColumn
+ + "\",\"predictionType\": \"classification\",\"transformations\": "
+ + transformation
+ + ",\"trainBudgetMilliNodeHours\": 8000}";
+
+ Value.Builder trainingTaskInputs = Value.newBuilder();
+ JsonFormat.parser().merge(jsonString, trainingTaskInputs);
+
+ FractionSplit fractionSplit =
+ FractionSplit.newBuilder()
+ .setTrainingFraction(0.8)
+ .setValidationFraction(0.1)
+ .setTestFraction(0.1)
+ .build();
+
+ InputDataConfig inputDataConfig =
+ InputDataConfig.newBuilder()
+ .setDatasetId(datasetId)
+ .setFractionSplit(fractionSplit)
+ .build();
+ Model modelToUpload = Model.newBuilder().setDisplayName(modelDisplayName).build();
+
+ TrainingPipeline trainingPipeline =
+ TrainingPipeline.newBuilder()
+ .setDisplayName(modelDisplayName)
+ .setTrainingTaskDefinition(trainingTaskDefinition)
+ .setTrainingTaskInputs(trainingTaskInputs)
+ .setInputDataConfig(inputDataConfig)
+ .setModelToUpload(modelToUpload)
+ .build();
+
+ TrainingPipeline trainingPipelineResponse =
+ pipelineServiceClient.createTrainingPipeline(locationName, trainingPipeline);
+
+ System.out.println("Create Training Pipeline Tabular Classification Response");
+ System.out.format("\tName: %s\n", trainingPipelineResponse.getName());
+ System.out.format("\tDisplay Name: %s\n", trainingPipelineResponse.getDisplayName());
+ System.out.format(
+ "\tTraining Task Definition: %s\n", trainingPipelineResponse.getTrainingTaskDefinition());
+ System.out.format(
+ "\tTraining Task Inputs: %s\n", trainingPipelineResponse.getTrainingTaskInputs());
+ System.out.format(
+ "\tTraining Task Metadata: %s\n", trainingPipelineResponse.getTrainingTaskMetadata());
+
+ System.out.format("\tState: %s\n", trainingPipelineResponse.getState());
+ System.out.format("\tCreate Time: %s\n", trainingPipelineResponse.getCreateTime());
+ System.out.format("\tStart Time: %s\n", trainingPipelineResponse.getStartTime());
+ System.out.format("\tEnd Time: %s\n", trainingPipelineResponse.getEndTime());
+ System.out.format("\tUpdate Time: %s\n", trainingPipelineResponse.getUpdateTime());
+ System.out.format("\tLabels: %s\n", trainingPipelineResponse.getLabelsMap());
+
+ InputDataConfig inputDataConfigResponse = trainingPipelineResponse.getInputDataConfig();
+ System.out.println("\tInput Data Config");
+ System.out.format("\t\tDataset Id: %s\n", inputDataConfigResponse.getDatasetId());
+ System.out.format(
+ "\t\tAnnotations Filter: %s\n", inputDataConfigResponse.getAnnotationsFilter());
+
+ FractionSplit fractionSplitResponse = inputDataConfigResponse.getFractionSplit();
+ System.out.println("\t\tFraction Split");
+ System.out.format(
+ "\t\t\tTraining Fraction: %s\n", fractionSplitResponse.getTrainingFraction());
+ System.out.format(
+ "\t\t\tValidation Fraction: %s\n", fractionSplitResponse.getValidationFraction());
+ System.out.format("\t\t\tTest Fraction: %s\n", fractionSplitResponse.getTestFraction());
+
+ FilterSplit filterSplit = inputDataConfigResponse.getFilterSplit();
+ System.out.println("\t\tFilter Split");
+ System.out.format("\t\t\tTraining Fraction: %s\n", filterSplit.getTrainingFilter());
+ System.out.format("\t\t\tValidation Fraction: %s\n", filterSplit.getValidationFilter());
+ System.out.format("\t\t\tTest Fraction: %s\n", filterSplit.getTestFilter());
+
+ PredefinedSplit predefinedSplit = inputDataConfigResponse.getPredefinedSplit();
+ System.out.println("\t\tPredefined Split");
+ System.out.format("\t\t\tKey: %s\n", predefinedSplit.getKey());
+
+ TimestampSplit timestampSplit = inputDataConfigResponse.getTimestampSplit();
+ System.out.println("\t\tTimestamp Split");
+ System.out.format("\t\t\tTraining Fraction: %s\n", timestampSplit.getTrainingFraction());
+ System.out.format("\t\t\tValidation Fraction: %s\n", timestampSplit.getValidationFraction());
+ System.out.format("\t\t\tTest Fraction: %s\n", timestampSplit.getTestFraction());
+ System.out.format("\t\t\tKey: %s\n", timestampSplit.getKey());
+
+ Model modelResponse = trainingPipelineResponse.getModelToUpload();
+ System.out.println("\tModel To Upload");
+ System.out.format("\t\tName: %s\n", modelResponse.getName());
+ System.out.format("\t\tDisplay Name: %s\n", modelResponse.getDisplayName());
+ System.out.format("\t\tDescription: %s\n", modelResponse.getDescription());
+ System.out.format("\t\tMetadata Schema Uri: %s\n", modelResponse.getMetadataSchemaUri());
+ System.out.format("\t\tMeta Data: %s\n", modelResponse.getMetadata());
+ System.out.format("\t\tTraining Pipeline: %s\n", modelResponse.getTrainingPipeline());
+ System.out.format("\t\tArtifact Uri: %s\n", modelResponse.getArtifactUri());
+
+ System.out.format(
+ "\t\tSupported Deployment Resources Types: %s\n",
+ modelResponse.getSupportedDeploymentResourcesTypesList().toString());
+ System.out.format(
+ "\t\tSupported Input Storage Formats: %s\n",
+ modelResponse.getSupportedInputStorageFormatsList().toString());
+ System.out.format(
+ "\t\tSupported Output Storage Formats: %s\n",
+ modelResponse.getSupportedOutputStorageFormatsList().toString());
+
+ System.out.format("\t\tCreate Time: %s\n", modelResponse.getCreateTime());
+ System.out.format("\t\tUpdate Time: %s\n", modelResponse.getUpdateTime());
+ System.out.format("\t\tLables: %s\n", modelResponse.getLabelsMap());
+ PredictSchemata predictSchemata = modelResponse.getPredictSchemata();
+
+ System.out.println("\tPredict Schemata");
+ System.out.format("\t\tInstance Schema Uri: %s\n", predictSchemata.getInstanceSchemaUri());
+ System.out.format(
+ "\t\tParameters Schema Uri: %s\n", predictSchemata.getParametersSchemaUri());
+ System.out.format(
+ "\t\tPrediction Schema Uri: %s\n", predictSchemata.getPredictionSchemaUri());
+
+ for (Model.ExportFormat supportedExportFormat :
+ modelResponse.getSupportedExportFormatsList()) {
+ System.out.println("\tSupported Export Format");
+ System.out.format("\t\tId: %s\n", supportedExportFormat.getId());
+ }
+ ModelContainerSpec containerSpec = modelResponse.getContainerSpec();
+
+ System.out.println("\tContainer Spec");
+ System.out.format("\t\tImage Uri: %s\n", containerSpec.getImageUri());
+ System.out.format("\t\tCommand: %s\n", containerSpec.getCommandList());
+ System.out.format("\t\tArgs: %s\n", containerSpec.getArgsList());
+ System.out.format("\t\tPredict Route: %s\n", containerSpec.getPredictRoute());
+ System.out.format("\t\tHealth Route: %s\n", containerSpec.getHealthRoute());
+
+ for (EnvVar envVar : containerSpec.getEnvList()) {
+ System.out.println("\t\tEnv");
+ System.out.format("\t\t\tName: %s\n", envVar.getName());
+ System.out.format("\t\t\tValue: %s\n", envVar.getValue());
+ }
+
+ for (Port port : containerSpec.getPortsList()) {
+ System.out.println("\t\tPort");
+ System.out.format("\t\t\tContainer Port: %s\n", port.getContainerPort());
+ }
+
+ for (DeployedModelRef deployedModelRef : modelResponse.getDeployedModelsList()) {
+ System.out.println("\tDeployed Model");
+ System.out.format("\t\tEndpoint: %s\n", deployedModelRef.getEndpoint());
+ System.out.format("\t\tDeployed Model Id: %s\n", deployedModelRef.getDeployedModelId());
+ }
+
+ ExplanationSpec explanationSpec = modelResponse.getExplanationSpec();
+ System.out.println("\tExplanation Spec");
+
+ ExplanationParameters explanationParameters = explanationSpec.getParameters();
+ System.out.println("\t\tParameters");
+
+ SampledShapleyAttribution sampledShapleyAttribution =
+ explanationParameters.getSampledShapleyAttribution();
+ System.out.println("\t\tSampled Shapley Attribution");
+ System.out.format("\t\t\tPath Count: %s\n", sampledShapleyAttribution.getPathCount());
+
+ ExplanationMetadata explanationMetadata = explanationSpec.getMetadata();
+ System.out.println("\t\tMetadata");
+ System.out.format("\t\t\tInput: %s\n", explanationMetadata.getInputsMap());
+ System.out.format("\t\t\tOutput: %s\n", explanationMetadata.getOutputsMap());
+ System.out.format(
+ "\t\t\tFeature Attributions Schema Uri: %s\n",
+ explanationMetadata.getFeatureAttributionsSchemaUri());
+
+ Status status = trainingPipelineResponse.getError();
+ System.out.println("\tError");
+ System.out.format("\t\tCode: %s\n", status.getCode());
+ System.out.format("\t\tMessage: %s\n", status.getMessage());
+ }
+ }
+}
+// [END aiplatform_create_training_pipeline_tabular_classification_sample]
diff --git a/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineTabularRegressionSample.java b/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineTabularRegressionSample.java
new file mode 100644
index 00000000000..ca24862cad2
--- /dev/null
+++ b/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineTabularRegressionSample.java
@@ -0,0 +1,254 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package aiplatform;
+
+// [START aiplatform_create_training_pipeline_tabular_regression_sample]
+
+import com.google.cloud.aiplatform.v1beta1.DeployedModelRef;
+import com.google.cloud.aiplatform.v1beta1.EnvVar;
+import com.google.cloud.aiplatform.v1beta1.ExplanationMetadata;
+import com.google.cloud.aiplatform.v1beta1.ExplanationParameters;
+import com.google.cloud.aiplatform.v1beta1.ExplanationSpec;
+import com.google.cloud.aiplatform.v1beta1.FilterSplit;
+import com.google.cloud.aiplatform.v1beta1.FractionSplit;
+import com.google.cloud.aiplatform.v1beta1.InputDataConfig;
+import com.google.cloud.aiplatform.v1beta1.LocationName;
+import com.google.cloud.aiplatform.v1beta1.Model;
+import com.google.cloud.aiplatform.v1beta1.ModelContainerSpec;
+import com.google.cloud.aiplatform.v1beta1.PipelineServiceClient;
+import com.google.cloud.aiplatform.v1beta1.PipelineServiceSettings;
+import com.google.cloud.aiplatform.v1beta1.Port;
+import com.google.cloud.aiplatform.v1beta1.PredefinedSplit;
+import com.google.cloud.aiplatform.v1beta1.PredictSchemata;
+import com.google.cloud.aiplatform.v1beta1.SampledShapleyAttribution;
+import com.google.cloud.aiplatform.v1beta1.TimestampSplit;
+import com.google.cloud.aiplatform.v1beta1.TrainingPipeline;
+import com.google.protobuf.Any;
+import com.google.protobuf.Value;
+import com.google.protobuf.util.JsonFormat;
+import com.google.rpc.Status;
+import java.io.IOException;
+import java.util.List;
+
+public class CreateTrainingPipelineTabularRegressionSample {
+
+ public static void main(String[] args) throws IOException {
+ // TODO(developer): Replace these variables before running the sample.
+ String project = "YOUR_PROJECT_ID";
+ String modelDisplayName = "YOUR_DATASET_DISPLAY_NAME";
+ String datasetId = "YOUR_DATASET_ID";
+ String targetColumn = "TARGET_COLUMN";
+ String transformation =
+ "[{TRANSFORMATION_TYPE: {columnName : COLUMN_NAME, invalidValuesAllowed : TRUE/FALSE }}]";
+ createTrainingPipelineTableRegression(
+ project, modelDisplayName, datasetId, targetColumn, transformation);
+ }
+
+ static void createTrainingPipelineTableRegression(
+ String project,
+ String modelDisplayName,
+ String datasetId,
+ String targetColumn,
+ String transformation)
+ throws IOException {
+ PipelineServiceSettings pipelineServiceSettings =
+ PipelineServiceSettings.newBuilder()
+ .setEndpoint("us-central1-aiplatform.googleapis.com:443")
+ .build();
+
+ // Initialize client that will be used to send requests. This client only needs to be created
+ // once, and can be reused for multiple requests. After completing all of your requests, call
+ // the "close" method on the client to safely clean up any remaining background resources.
+ try (PipelineServiceClient pipelineServiceClient =
+ PipelineServiceClient.create(pipelineServiceSettings)) {
+ String location = "us-central1";
+ LocationName locationName = LocationName.of(project, location);
+ String trainingTaskDefinition =
+ "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_tables_1.0.0.yaml";
+ String jsonString =
+ "{\"targetColumn\": \""
+ + targetColumn
+ + "\",\"predictionType\": \"regression\",\"transformations\": "
+ + transformation
+ + ",\"trainBudgetMilliNodeHours\": 8000}";
+ Value.Builder trainingTaskInputs = Value.newBuilder();
+ JsonFormat.parser().merge(jsonString, trainingTaskInputs);
+
+ FractionSplit fractionSplit =
+ FractionSplit.newBuilder()
+ .setTrainingFraction(0.8)
+ .setValidationFraction(0.1)
+ .setTestFraction(0.1)
+ .build();
+
+ InputDataConfig inputDataConfig =
+ InputDataConfig.newBuilder()
+ .setDatasetId(datasetId)
+ .setFractionSplit(fractionSplit)
+ .build();
+ Model modelToUpload = Model.newBuilder().setDisplayName(modelDisplayName).build();
+
+ TrainingPipeline trainingPipeline =
+ TrainingPipeline.newBuilder()
+ .setDisplayName(modelDisplayName)
+ .setTrainingTaskDefinition(trainingTaskDefinition)
+ .setTrainingTaskInputs(trainingTaskInputs)
+ .setInputDataConfig(inputDataConfig)
+ .setModelToUpload(modelToUpload)
+ .build();
+
+ TrainingPipeline trainingPipelineResponse =
+ pipelineServiceClient.createTrainingPipeline(locationName, trainingPipeline);
+
+ System.out.println("Create Training Pipeline Tabular Regression Response");
+ System.out.format("\tName: %s\n", trainingPipelineResponse.getName());
+ System.out.format("\tDisplay Name: %s\n", trainingPipelineResponse.getDisplayName());
+ System.out.format(
+ "\tTraining Task Definition: %s\n", trainingPipelineResponse.getTrainingTaskDefinition());
+ System.out.format(
+ "\tTraining Task Inputs: %s\n", trainingPipelineResponse.getTrainingTaskInputs());
+ System.out.format(
+ "\tTraining Task Metadata: %s\n", trainingPipelineResponse.getTrainingTaskMetadata());
+
+ System.out.format("\tState: %s\n", trainingPipelineResponse.getState());
+ System.out.format("\tCreate Time: %s\n", trainingPipelineResponse.getCreateTime());
+ System.out.format("\tStart Time: %s\n", trainingPipelineResponse.getStartTime());
+ System.out.format("\tEnd Time: %s\n", trainingPipelineResponse.getEndTime());
+ System.out.format("\tUpdate Time: %s\n", trainingPipelineResponse.getUpdateTime());
+ System.out.format("\tLabels: %s\n", trainingPipelineResponse.getLabelsMap());
+
+ InputDataConfig inputDataConfigResponse = trainingPipelineResponse.getInputDataConfig();
+ System.out.println("\tInput Data Config");
+ System.out.format("\t\tDataset Id: %s\n", inputDataConfigResponse.getDatasetId());
+ System.out.format(
+ "\t\tAnnotations Filter: %s\n", inputDataConfigResponse.getAnnotationsFilter());
+
+ FractionSplit fractionSplitResponse = inputDataConfigResponse.getFractionSplit();
+ System.out.println("\t\tFraction Split");
+ System.out.format(
+ "\t\t\tTraining Fraction: %s\n", fractionSplitResponse.getTrainingFraction());
+ System.out.format(
+ "\t\t\tValidation Fraction: %s\n", fractionSplitResponse.getValidationFraction());
+ System.out.format("\t\t\tTest Fraction: %s\n", fractionSplitResponse.getTestFraction());
+
+ FilterSplit filterSplit = inputDataConfigResponse.getFilterSplit();
+ System.out.println("\t\tFilter Split");
+ System.out.format("\t\t\tTraining Fraction: %s\n", filterSplit.getTrainingFilter());
+ System.out.format("\t\t\tValidation Fraction: %s\n", filterSplit.getValidationFilter());
+ System.out.format("\t\t\tTest Fraction: %s\n", filterSplit.getTestFilter());
+
+ PredefinedSplit predefinedSplit = inputDataConfigResponse.getPredefinedSplit();
+ System.out.println("\t\tPredefined Split");
+ System.out.format("\t\t\tKey: %s\n", predefinedSplit.getKey());
+
+ TimestampSplit timestampSplit = inputDataConfigResponse.getTimestampSplit();
+ System.out.println("\t\tTimestamp Split");
+ System.out.format("\t\t\tTraining Fraction: %s\n", timestampSplit.getTrainingFraction());
+ System.out.format("\t\t\tValidation Fraction: %s\n", timestampSplit.getValidationFraction());
+ System.out.format("\t\t\tTest Fraction: %s\n", timestampSplit.getTestFraction());
+ System.out.format("\t\t\tKey: %s\n", timestampSplit.getKey());
+
+ Model modelResponse = trainingPipelineResponse.getModelToUpload();
+ System.out.println("\tModel To Upload");
+ System.out.format("\t\tName: %s\n", modelResponse.getName());
+ System.out.format("\t\tDisplay Name: %s\n", modelResponse.getDisplayName());
+ System.out.format("\t\tDescription: %s\n", modelResponse.getDescription());
+ System.out.format("\t\tMetadata Schema Uri: %s\n", modelResponse.getMetadataSchemaUri());
+ System.out.format("\t\tMeta Data: %s\n", modelResponse.getMetadata());
+ System.out.format("\t\tTraining Pipeline: %s\n", modelResponse.getTrainingPipeline());
+ System.out.format("\t\tArtifact Uri: %s\n", modelResponse.getArtifactUri());
+
+ System.out.format(
+ "\t\tSupported Deployment Resources Types: %s\n",
+ modelResponse.getSupportedDeploymentResourcesTypesList().toString());
+ System.out.format(
+ "\t\tSupported Input Storage Formats: %s\n",
+ modelResponse.getSupportedInputStorageFormatsList().toString());
+ System.out.format(
+ "\t\tSupported Output Storage Formats: %s\n",
+ modelResponse.getSupportedOutputStorageFormatsList().toString());
+
+ System.out.format("\t\tCreate Time: %s\n", modelResponse.getCreateTime());
+ System.out.format("\t\tUpdate Time: %s\n", modelResponse.getUpdateTime());
+ System.out.format("\t\tLables: %s\n", modelResponse.getLabelsMap());
+ PredictSchemata predictSchemata = modelResponse.getPredictSchemata();
+
+ System.out.println("\tPredict Schemata");
+ System.out.format("\t\tInstance Schema Uri: %s\n", predictSchemata.getInstanceSchemaUri());
+ System.out.format(
+ "\t\tParameters Schema Uri: %s\n", predictSchemata.getParametersSchemaUri());
+ System.out.format(
+ "\t\tPrediction Schema Uri: %s\n", predictSchemata.getPredictionSchemaUri());
+
+ for (Model.ExportFormat supportedExportFormat :
+ modelResponse.getSupportedExportFormatsList()) {
+ System.out.println("\tSupported Export Format");
+ System.out.format("\t\tId: %s\n", supportedExportFormat.getId());
+ }
+ ModelContainerSpec containerSpec = modelResponse.getContainerSpec();
+
+ System.out.println("\tContainer Spec");
+ System.out.format("\t\tImage Uri: %s\n", containerSpec.getImageUri());
+ System.out.format("\t\tCommand: %s\n", containerSpec.getCommandList());
+ System.out.format("\t\tArgs: %s\n", containerSpec.getArgsList());
+ System.out.format("\t\tPredict Route: %s\n", containerSpec.getPredictRoute());
+ System.out.format("\t\tHealth Route: %s\n", containerSpec.getHealthRoute());
+
+ for (EnvVar envVar : containerSpec.getEnvList()) {
+ System.out.println("\t\tEnv");
+ System.out.format("\t\t\tName: %s\n", envVar.getName());
+ System.out.format("\t\t\tValue: %s\n", envVar.getValue());
+ }
+
+ for (Port port : containerSpec.getPortsList()) {
+ System.out.println("\t\tPort");
+ System.out.format("\t\t\tContainer Port: %s\n", port.getContainerPort());
+ }
+
+ for (DeployedModelRef deployedModelRef : modelResponse.getDeployedModelsList()) {
+ System.out.println("\tDeployed Model");
+ System.out.format("\t\tEndpoint: %s\n", deployedModelRef.getEndpoint());
+ System.out.format("\t\tDeployed Model Id: %s\n", deployedModelRef.getDeployedModelId());
+ }
+
+ ExplanationSpec explanationSpec = modelResponse.getExplanationSpec();
+ System.out.println("\tExplanation Spec");
+
+ ExplanationParameters explanationParameters = explanationSpec.getParameters();
+ System.out.println("\t\tParameters");
+
+ SampledShapleyAttribution sampledShapleyAttribution =
+ explanationParameters.getSampledShapleyAttribution();
+ System.out.println("\t\tSampled Shapley Attribution");
+ System.out.format("\t\t\tPath Count: %s\n", sampledShapleyAttribution.getPathCount());
+
+ ExplanationMetadata explanationMetadata = explanationSpec.getMetadata();
+ System.out.println("\t\tMetadata");
+ System.out.format("\t\t\tInput: %s\n", explanationMetadata.getInputsMap());
+ System.out.format("\t\t\tOutput: %s\n", explanationMetadata.getOutputsMap());
+ System.out.format(
+ "\t\t\tFeature Attributions Schema Uri: %s\n",
+ explanationMetadata.getFeatureAttributionsSchemaUri());
+
+ Status status = trainingPipelineResponse.getError();
+ System.out.println("\tError");
+ System.out.format("\t\tCode: %s\n", status.getCode());
+ System.out.format("\t\tMessage: %s\n", status.getMessage());
+ }
+ }
+}
+// [END aiplatform_create_training_pipeline_tabular_regression_sample]
diff --git a/aiplatform/snippets/src/main/java/aiplatform/DeleteDatasetSample.java b/aiplatform/snippets/src/main/java/aiplatform/DeleteDatasetSample.java
new file mode 100644
index 00000000000..39ad52d0fdf
--- /dev/null
+++ b/aiplatform/snippets/src/main/java/aiplatform/DeleteDatasetSample.java
@@ -0,0 +1,67 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package aiplatform;
+
+// [START aiplatform_delete_dataset_sample]
+
+import com.google.api.gax.longrunning.OperationFuture;
+import com.google.cloud.aiplatform.v1beta1.DatasetName;
+import com.google.cloud.aiplatform.v1beta1.DatasetServiceClient;
+import com.google.cloud.aiplatform.v1beta1.DatasetServiceSettings;
+import com.google.cloud.aiplatform.v1beta1.DeleteOperationMetadata;
+import com.google.protobuf.Empty;
+import java.io.IOException;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
+
+public class DeleteDatasetSample {
+
+ public static void main(String[] args)
+ throws IOException, InterruptedException, ExecutionException, TimeoutException {
+ // TODO(developer): Replace these variables before running the sample.
+ String project = "YOUR_PROJECT_ID";
+ String datasetId = "YOUR_DATASET_ID";
+ deleteDatasetSample(project, datasetId);
+ }
+
+ static void deleteDatasetSample(String project, String datasetId)
+ throws IOException, InterruptedException, ExecutionException, TimeoutException {
+ DatasetServiceSettings datasetServiceSettings =
+ DatasetServiceSettings.newBuilder()
+ .setEndpoint("us-central1-aiplatform.googleapis.com:443")
+ .build();
+
+ // Initialize client that will be used to send requests. This client only needs to be created
+ // once, and can be reused for multiple requests. After completing all of your requests, call
+ // the "close" method on the client to safely clean up any remaining background resources.
+ try (DatasetServiceClient datasetServiceClient =
+ DatasetServiceClient.create(datasetServiceSettings)) {
+ String location = "us-central1";
+ DatasetName datasetName = DatasetName.of(project, location, datasetId);
+
+ OperationFuture operationFuture =
+ datasetServiceClient.deleteDatasetAsync(datasetName);
+ System.out.format("Operation name: %s\n", operationFuture.getInitialFuture().get().getName());
+ System.out.println("Waiting for operation to finish...");
+ operationFuture.get(300, TimeUnit.SECONDS);
+
+ System.out.format("Deleted Dataset.");
+ }
+ }
+}
+// [END aiplatform_delete_dataset_sample]
diff --git a/aiplatform/snippets/src/main/java/aiplatform/DeleteExportModelSample.java b/aiplatform/snippets/src/main/java/aiplatform/DeleteExportModelSample.java
new file mode 100644
index 00000000000..d6ed1995714
--- /dev/null
+++ b/aiplatform/snippets/src/main/java/aiplatform/DeleteExportModelSample.java
@@ -0,0 +1,45 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package aiplatform;
+
+// [START aiplatform_delete_export_model_sample]
+
+import com.google.cloud.storage.Blob;
+import com.google.cloud.storage.Storage;
+import com.google.cloud.storage.StorageOptions;
+
+public class DeleteExportModelSample {
+
+ public static void main(String[] args) {
+ // TODO(developer): Replace these variables before running the sample.
+ String project = "YOUR_PROJECT_ID";
+ String bucketName = "YOUR_BUCKET_NAME";
+ String folderName = "YOUR_FOLDER_NAME";
+ deleteExportModelSample(project, bucketName, folderName);
+ }
+
+ static void deleteExportModelSample(String project, String bucketName, String folderName) {
+ Storage storage = StorageOptions.newBuilder().setProjectId(project).build().getService();
+ Iterable blobs =
+ storage.list(bucketName, Storage.BlobListOption.prefix(folderName)).iterateAll();
+ for (Blob blob : blobs) {
+ blob.delete(Blob.BlobSourceOption.generationMatch());
+ }
+ System.out.println("Export Model Deleted");
+ }
+}
+// [END aiplatform_delete_export_model_sample]
diff --git a/aiplatform/snippets/src/main/java/aiplatform/DeleteTrainingPipelineSample.java b/aiplatform/snippets/src/main/java/aiplatform/DeleteTrainingPipelineSample.java
new file mode 100644
index 00000000000..d3819cd03ea
--- /dev/null
+++ b/aiplatform/snippets/src/main/java/aiplatform/DeleteTrainingPipelineSample.java
@@ -0,0 +1,68 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package aiplatform;
+
+// [START aiplatform_delete_training_pipeline_sample]
+
+import com.google.api.gax.longrunning.OperationFuture;
+import com.google.cloud.aiplatform.v1beta1.DeleteOperationMetadata;
+import com.google.cloud.aiplatform.v1beta1.PipelineServiceClient;
+import com.google.cloud.aiplatform.v1beta1.PipelineServiceSettings;
+import com.google.cloud.aiplatform.v1beta1.TrainingPipelineName;
+import com.google.protobuf.Empty;
+import java.io.IOException;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
+
+public class DeleteTrainingPipelineSample {
+
+ public static void main(String[] args)
+ throws IOException, InterruptedException, ExecutionException, TimeoutException {
+ // TODO(developer): Replace these variables before running the sample.
+ String trainingPipelineId = "YOUR_TRAINING_PIPELINE_ID";
+ String project = "YOUR_PROJECT_ID";
+ deleteTrainingPipelineSample(project, trainingPipelineId);
+ }
+
+ static void deleteTrainingPipelineSample(String project, String trainingPipelineId)
+ throws IOException, InterruptedException, ExecutionException, TimeoutException {
+ PipelineServiceSettings pipelineServiceSettings =
+ PipelineServiceSettings.newBuilder()
+ .setEndpoint("us-central1-aiplatform.googleapis.com:443")
+ .build();
+
+ // Initialize client that will be used to send requests. This client only needs to be created
+ // once, and can be reused for multiple requests. After completing all of your requests, call
+ // the "close" method on the client to safely clean up any remaining background resources.
+ try (PipelineServiceClient pipelineServiceClient =
+ PipelineServiceClient.create(pipelineServiceSettings)) {
+ String location = "us-central1";
+ TrainingPipelineName trainingPipelineName =
+ TrainingPipelineName.of(project, location, trainingPipelineId);
+
+ OperationFuture operationFuture =
+ pipelineServiceClient.deleteTrainingPipelineAsync(trainingPipelineName);
+ System.out.format("Operation name: %s\n", operationFuture.getInitialFuture().get().getName());
+ System.out.println("Waiting for operation to finish...");
+ operationFuture.get(300, TimeUnit.SECONDS);
+
+ System.out.format("Deleted Training Pipeline.");
+ }
+ }
+}
+// [END aiplatform_delete_training_pipeline_sample]
diff --git a/aiplatform/snippets/src/main/java/aiplatform/ExportModelTabularClassificationSample.java b/aiplatform/snippets/src/main/java/aiplatform/ExportModelTabularClassificationSample.java
new file mode 100644
index 00000000000..f3fedf710c2
--- /dev/null
+++ b/aiplatform/snippets/src/main/java/aiplatform/ExportModelTabularClassificationSample.java
@@ -0,0 +1,79 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package aiplatform;
+
+// [START aiplatform_export_model_tabular_classification_sample]
+
+import com.google.api.gax.longrunning.OperationFuture;
+import com.google.cloud.aiplatform.v1beta1.ExportModelOperationMetadata;
+import com.google.cloud.aiplatform.v1beta1.ExportModelRequest;
+import com.google.cloud.aiplatform.v1beta1.ExportModelResponse;
+import com.google.cloud.aiplatform.v1beta1.GcsDestination;
+import com.google.cloud.aiplatform.v1beta1.ModelName;
+import com.google.cloud.aiplatform.v1beta1.ModelServiceClient;
+import com.google.cloud.aiplatform.v1beta1.ModelServiceSettings;
+import java.io.IOException;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
+
+public class ExportModelTabularClassificationSample {
+ public static void main(String[] args)
+ throws InterruptedException, ExecutionException, TimeoutException, IOException {
+ // TODO(developer): Replace these variables before running the sample.
+ String gcsDestinationOutputUriPrefix = "gs://your-gcs-bucket/destination_path";
+ String project = "YOUR_PROJECT_ID";
+ String modelId = "YOUR_MODEL_ID";
+ exportModelTableClassification(gcsDestinationOutputUriPrefix, project, modelId);
+ }
+
+ static void exportModelTableClassification(
+ String gcsDestinationOutputUriPrefix, String project, String modelId)
+ throws IOException, ExecutionException, InterruptedException, TimeoutException {
+ ModelServiceSettings modelServiceSettings =
+ ModelServiceSettings.newBuilder()
+ .setEndpoint("us-central1-aiplatform.googleapis.com:443")
+ .build();
+
+ // Initialize client that will be used to send requests. This client only needs to be created
+ // once, and can be reused for multiple requests. After completing all of your requests, call
+ // the "close" method on the client to safely clean up any remaining background resources.
+ try (ModelServiceClient modelServiceClient = ModelServiceClient.create(modelServiceSettings)) {
+ String location = "us-central1";
+ ModelName modelName = ModelName.of(project, location, modelId);
+
+ GcsDestination.Builder gcsDestination = GcsDestination.newBuilder();
+ gcsDestination.setOutputUriPrefix(gcsDestinationOutputUriPrefix);
+ ExportModelRequest.OutputConfig outputConfig =
+ ExportModelRequest.OutputConfig.newBuilder()
+ .setExportFormatId("tf-saved-model")
+ .setArtifactDestination(gcsDestination)
+ .build();
+
+ OperationFuture exportModelResponseFuture =
+ modelServiceClient.exportModelAsync(modelName, outputConfig);
+ System.out.format(
+ "Operation name: %s\n", exportModelResponseFuture.getInitialFuture().get().getName());
+ System.out.println("Waiting for operation to finish...");
+ ExportModelResponse exportModelResponse =
+ exportModelResponseFuture.get(300, TimeUnit.SECONDS);
+ System.out.format(
+ "Export Model Tabular Classification Response: %s", exportModelResponse.toString());
+ }
+ }
+}
+// [END aiplatform_export_model_tabular_classification_sample]
diff --git a/aiplatform/snippets/src/main/java/aiplatform/GetModelEvaluationTabularClassificationSample.java b/aiplatform/snippets/src/main/java/aiplatform/GetModelEvaluationTabularClassificationSample.java
new file mode 100644
index 00000000000..e347bf820d6
--- /dev/null
+++ b/aiplatform/snippets/src/main/java/aiplatform/GetModelEvaluationTabularClassificationSample.java
@@ -0,0 +1,78 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package aiplatform;
+
+// [START aiplatform_get_model_evaluation_tabular_classification_sample]
+
+import com.google.cloud.aiplatform.v1beta1.Attribution;
+import com.google.cloud.aiplatform.v1beta1.ModelEvaluation;
+import com.google.cloud.aiplatform.v1beta1.ModelEvaluationName;
+import com.google.cloud.aiplatform.v1beta1.ModelExplanation;
+import com.google.cloud.aiplatform.v1beta1.ModelServiceClient;
+import com.google.cloud.aiplatform.v1beta1.ModelServiceSettings;
+import java.io.IOException;
+
+public class GetModelEvaluationTabularClassificationSample {
+
+ public static void main(String[] args) throws IOException {
+ // TODO(developer): Replace these variables before running the sample.
+ String project = "YOUR_PROJECT_ID";
+ String modelId = "YOUR_MODEL_ID";
+ String evaluationId = "YOUR_EVALUATION_ID";
+ getModelEvaluationTabularClassification(project, modelId, evaluationId);
+ }
+
+ static void getModelEvaluationTabularClassification(
+ String project, String modelId, String evaluationId) throws IOException {
+ ModelServiceSettings modelServiceSettings =
+ ModelServiceSettings.newBuilder()
+ .setEndpoint("us-central1-aiplatform.googleapis.com:443")
+ .build();
+
+ // Initialize client that will be used to send requests. This client only needs to be created
+ // once, and can be reused for multiple requests. After completing all of your requests, call
+ // the "close" method on the client to safely clean up any remaining background resources.
+ try (ModelServiceClient modelServiceClient = ModelServiceClient.create(modelServiceSettings)) {
+ String location = "us-central1";
+ ModelEvaluationName modelEvaluationName =
+ ModelEvaluationName.of(project, location, modelId, evaluationId);
+ ModelEvaluation modelEvaluation = modelServiceClient.getModelEvaluation(modelEvaluationName);
+
+ System.out.println("Get Model Evaluation Tabular Classification Response");
+ System.out.format("\tName: %s\n", modelEvaluation.getName());
+ System.out.format("\tMetrics Schema Uri: %s\n", modelEvaluation.getMetricsSchemaUri());
+ System.out.format("\tMetrics: %s\n", modelEvaluation.getMetrics());
+ System.out.format("\tCreate Time: %s\n", modelEvaluation.getCreateTime());
+ System.out.format("\tSlice Dimensions: %s\n", modelEvaluation.getSliceDimensionsList());
+ ModelExplanation modelExplanation = modelEvaluation.getModelExplanation();
+
+ System.out.println("\tModel Explanation");
+ for (Attribution attribution : modelExplanation.getMeanAttributionsList()) {
+ System.out.println("\t\tMean Attributions");
+ System.out.format(
+ "\t\t\tBaseline Output value: %s\n", attribution.getBaselineOutputValue());
+ System.out.format(
+ "\t\t\tInstance Output value: %s\n", attribution.getInstanceOutputValue());
+ System.out.format("\t\t\tFeature attributions: %s\n", attribution.getFeatureAttributions());
+ System.out.format("\t\t\tOutput Index: %s\n", attribution.getOutputIndexList());
+ System.out.format("\t\t\tOutput Index Name: %s\n", attribution.getOutputDisplayName());
+ System.out.format("\t\t\tApproximation Error: %s\n", attribution.getApproximationError());
+ }
+ }
+ }
+}
+// [END aiplatform_get_model_evaluation_tabular_classification_sample]
diff --git a/aiplatform/snippets/src/main/java/aiplatform/GetModelEvaluationTabularRegressionSample.java b/aiplatform/snippets/src/main/java/aiplatform/GetModelEvaluationTabularRegressionSample.java
new file mode 100644
index 00000000000..bc9910a793b
--- /dev/null
+++ b/aiplatform/snippets/src/main/java/aiplatform/GetModelEvaluationTabularRegressionSample.java
@@ -0,0 +1,78 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package aiplatform;
+
+// [START aiplatform_get_model_evaluation_tabular_regression_sample]
+
+import com.google.cloud.aiplatform.v1beta1.Attribution;
+import com.google.cloud.aiplatform.v1beta1.ModelEvaluation;
+import com.google.cloud.aiplatform.v1beta1.ModelEvaluationName;
+import com.google.cloud.aiplatform.v1beta1.ModelExplanation;
+import com.google.cloud.aiplatform.v1beta1.ModelServiceClient;
+import com.google.cloud.aiplatform.v1beta1.ModelServiceSettings;
+import java.io.IOException;
+
+public class GetModelEvaluationTabularRegressionSample {
+
+ public static void main(String[] args) throws IOException {
+ // TODO(developer): Replace these variables before running the sample.
+ String project = "YOUR_PROJECT_ID";
+ String modelId = "YOUR_MODEL_ID";
+ String evaluationId = "YOUR_EVALUATION_ID";
+ getModelEvaluationTabularRegression(project, modelId, evaluationId);
+ }
+
+ static void getModelEvaluationTabularRegression(
+ String project, String modelId, String evaluationId) throws IOException {
+ ModelServiceSettings modelServiceSettings =
+ ModelServiceSettings.newBuilder()
+ .setEndpoint("us-central1-aiplatform.googleapis.com:443")
+ .build();
+
+ // Initialize client that will be used to send requests. This client only needs to be created
+ // once, and can be reused for multiple requests. After completing all of your requests, call
+ // the "close" method on the client to safely clean up any remaining background resources.
+ try (ModelServiceClient modelServiceClient = ModelServiceClient.create(modelServiceSettings)) {
+ String location = "us-central1";
+ ModelEvaluationName modelEvaluationName =
+ ModelEvaluationName.of(project, location, modelId, evaluationId);
+ ModelEvaluation modelEvaluation = modelServiceClient.getModelEvaluation(modelEvaluationName);
+
+ System.out.println("Get Model Evaluation Tabular Regression Response");
+ System.out.format("\tName: %s\n", modelEvaluation.getName());
+ System.out.format("\tMetrics Schema Uri: %s\n", modelEvaluation.getMetricsSchemaUri());
+ System.out.format("\tMetrics: %s\n", modelEvaluation.getMetrics());
+ System.out.format("\tCreate Time: %s\n", modelEvaluation.getCreateTime());
+ System.out.format("\tSlice Dimensions: %s\n", modelEvaluation.getSliceDimensionsList());
+ ModelExplanation modelExplanation = modelEvaluation.getModelExplanation();
+
+ System.out.println("\tModel Explanation");
+ for (Attribution attribution : modelExplanation.getMeanAttributionsList()) {
+ System.out.println("\t\tMean Attributions");
+ System.out.format(
+ "\t\t\tBaseline Output value: %s\n", attribution.getBaselineOutputValue());
+ System.out.format(
+ "\t\t\tInstance Output value: %s\n", attribution.getInstanceOutputValue());
+ System.out.format("\t\t\tFeature attributions: %s\n", attribution.getFeatureAttributions());
+ System.out.format("\t\t\tOutput Index: %s\n", attribution.getOutputIndexList());
+ System.out.format("\t\t\tOutput Index Name: %s\n", attribution.getOutputDisplayName());
+ System.out.format("\t\t\tApproximation Error: %s\n", attribution.getApproximationError());
+ }
+ }
+ }
+}
+// [END aiplatform_get_model_evaluation_tabular_regression_sample]
diff --git a/aiplatform/snippets/src/main/java/aiplatform/PredictTabularClassificationSample.java b/aiplatform/snippets/src/main/java/aiplatform/PredictTabularClassificationSample.java
new file mode 100644
index 00000000000..302af2d55bd
--- /dev/null
+++ b/aiplatform/snippets/src/main/java/aiplatform/PredictTabularClassificationSample.java
@@ -0,0 +1,73 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package aiplatform;
+
+// [START aiplatform_predict_tabular_classification_sample]
+
+import com.google.cloud.aiplatform.v1beta1.EndpointName;
+import com.google.cloud.aiplatform.v1beta1.PredictResponse;
+import com.google.cloud.aiplatform.v1beta1.PredictionServiceClient;
+import com.google.cloud.aiplatform.v1beta1.PredictionServiceSettings;
+import com.google.protobuf.ListValue;
+import com.google.protobuf.Value;
+import com.google.protobuf.util.JsonFormat;
+import java.io.IOException;
+import java.util.List;
+
+public class PredictTabularClassificationSample {
+
+ public static void main(String[] args) throws IOException {
+ // TODO(developer): Replace these variables before running the sample.
+ String project = "YOUR_PROJECT_ID";
+ String instance = "[{ “feature_column_a”: “value”, “feature_column_b”: “value”}]";
+ String endpointId = "YOUR_ENDPOINT_ID";
+ predictTabularClassification(instance, project, endpointId);
+ }
+
+ static void predictTabularClassification(String instance, String project, String endpointId)
+ throws IOException {
+ PredictionServiceSettings predictionServiceSettings =
+ PredictionServiceSettings.newBuilder()
+ .setEndpoint("us-central1-prediction-aiplatform.googleapis.com:443")
+ .build();
+
+ // Initialize client that will be used to send requests. This client only needs to be created
+ // once, and can be reused for multiple requests. After completing all of your requests, call
+ // the "close" method on the client to safely clean up any remaining background resources.
+ try (PredictionServiceClient predictionServiceClient =
+ PredictionServiceClient.create(predictionServiceSettings)) {
+ String location = "us-central1";
+ EndpointName endpointName = EndpointName.of(project, location, endpointId);
+
+ ListValue.Builder listValue = ListValue.newBuilder();
+ JsonFormat.parser().merge(instance, listValue);
+ List instanceList = listValue.getValuesList();
+
+ Value parameters = Value.newBuilder().build();
+ PredictResponse predictResponse =
+ predictionServiceClient.predict(endpointName, instanceList, parameters);
+ System.out.println("Predict Tabular Classification Response");
+ System.out.format("\tDeployed Model Id: %s\n", predictResponse.getDeployedModelId());
+
+ System.out.println("Predictions");
+ for (Value prediction : predictResponse.getPredictionsList()) {
+ System.out.format("\tPrediction: %s\n", prediction);
+ }
+ }
+ }
+}
+// [END aiplatform_predict_tabular_classification_sample]
diff --git a/aiplatform/snippets/src/main/java/aiplatform/PredictTabularRegressionSample.java b/aiplatform/snippets/src/main/java/aiplatform/PredictTabularRegressionSample.java
new file mode 100644
index 00000000000..7520f554d84
--- /dev/null
+++ b/aiplatform/snippets/src/main/java/aiplatform/PredictTabularRegressionSample.java
@@ -0,0 +1,73 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package aiplatform;
+
+// [START aiplatform_predict_tabular_regression_sample]
+
+import com.google.cloud.aiplatform.v1beta1.EndpointName;
+import com.google.cloud.aiplatform.v1beta1.PredictResponse;
+import com.google.cloud.aiplatform.v1beta1.PredictionServiceClient;
+import com.google.cloud.aiplatform.v1beta1.PredictionServiceSettings;
+import com.google.protobuf.ListValue;
+import com.google.protobuf.Value;
+import com.google.protobuf.util.JsonFormat;
+import java.io.IOException;
+import java.util.List;
+
+public class PredictTabularRegressionSample {
+
+ public static void main(String[] args) throws IOException {
+ // TODO(developer): Replace these variables before running the sample.
+ String project = "YOUR_PROJECT_ID";
+ String instance = "[{ “feature_column_a”: “value”, “feature_column_b”: “value”}]";
+ String endpointId = "YOUR_ENDPOINT_ID";
+ predictTabularRegression(instance, project, endpointId);
+ }
+
+ static void predictTabularRegression(String instance, String project, String endpointId)
+ throws IOException {
+ PredictionServiceSettings predictionServiceSettings =
+ PredictionServiceSettings.newBuilder()
+ .setEndpoint("us-central1-prediction-aiplatform.googleapis.com:443")
+ .build();
+
+ // Initialize client that will be used to send requests. This client only needs to be created
+ // once, and can be reused for multiple requests. After completing all of your requests, call
+ // the "close" method on the client to safely clean up any remaining background resources.
+ try (PredictionServiceClient predictionServiceClient =
+ PredictionServiceClient.create(predictionServiceSettings)) {
+ String location = "us-central1";
+ EndpointName endpointName = EndpointName.of(project, location, endpointId);
+
+ ListValue.Builder listValue = ListValue.newBuilder();
+ JsonFormat.parser().merge(instance, listValue);
+ List instanceList = listValue.getValuesList();
+
+ Value parameters = Value.newBuilder().build();
+ PredictResponse predictResponse =
+ predictionServiceClient.predict(endpointName, instanceList, parameters);
+ System.out.println("Predict Tabular Regression Response");
+ System.out.format("\tDisplay Model Id: %s\n", predictResponse.getDeployedModelId());
+
+ System.out.println("Predictions");
+ for (Value prediction : predictResponse.getPredictionsList()) {
+ System.out.format("\tPrediction: %s\n", prediction);
+ }
+ }
+ }
+}
+// [END aiplatform_predict_tabular_regression_sample]
diff --git a/aiplatform/snippets/src/test/java/aiplatform/CreateDatasetTabularBigquerySampleTest.java b/aiplatform/snippets/src/test/java/aiplatform/CreateDatasetTabularBigquerySampleTest.java
new file mode 100644
index 00000000000..42b002514a5
--- /dev/null
+++ b/aiplatform/snippets/src/test/java/aiplatform/CreateDatasetTabularBigquerySampleTest.java
@@ -0,0 +1,93 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package aiplatform;
+
+import static com.google.common.truth.Truth.assertThat;
+import static org.junit.Assert.assertNotNull;
+
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.io.PrintStream;
+import java.util.UUID;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.TimeoutException;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+public class CreateDatasetTabularBigquerySampleTest {
+
+ private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID");
+ private static final String GCS_SOURCE_URI = "bq://ucaip-sample-tests.table_test.all_bq_types";
+ private ByteArrayOutputStream bout;
+ private PrintStream out;
+ private PrintStream originalPrintStream;
+ private String datasetId;
+
+ private static void requireEnvVar(String varName) {
+ String errorMessage =
+ String.format("Environment variable '%s' is required to perform these tests.", varName);
+ assertNotNull(errorMessage, System.getenv(varName));
+ }
+
+ @BeforeClass
+ public static void checkRequirements() {
+ requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS");
+ requireEnvVar("UCAIP_PROJECT_ID");
+ }
+
+ @Before
+ public void setUp() {
+ bout = new ByteArrayOutputStream();
+ out = new PrintStream(bout);
+ originalPrintStream = System.out;
+ System.setOut(out);
+ }
+
+ @After
+ public void tearDown()
+ throws InterruptedException, ExecutionException, IOException, TimeoutException {
+ // Delete the created dataset
+ DeleteDatasetSample.deleteDatasetSample(PROJECT, datasetId);
+
+ // Assert
+ String deleteResponse = bout.toString();
+ assertThat(deleteResponse).contains("Deleted Dataset.");
+ System.out.flush();
+ System.setOut(originalPrintStream);
+ }
+
+ @Test
+ public void testCreateDatasetTabularBigquerySample()
+ throws IOException, InterruptedException, ExecutionException, TimeoutException {
+ // Act
+ String datasetDisplayName =
+ String.format(
+ "temp_create_dataset_table_bigquery_test_%s",
+ UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 26));
+
+ CreateDatasetTabularBigquerySample.createDatasetTableBigquery(
+ PROJECT, datasetDisplayName, GCS_SOURCE_URI);
+
+ // Assert
+ String got = bout.toString();
+ assertThat(got).contains(datasetDisplayName);
+ assertThat(got).contains("Create Dataset Table Bigquery sample");
+ datasetId = got.split("Name: ")[1].split("datasets/")[1].split("\n")[0];
+ }
+}
diff --git a/aiplatform/snippets/src/test/java/aiplatform/CreateDatasetTabularGcsSampleTest.java b/aiplatform/snippets/src/test/java/aiplatform/CreateDatasetTabularGcsSampleTest.java
new file mode 100644
index 00000000000..10a26a5e144
--- /dev/null
+++ b/aiplatform/snippets/src/test/java/aiplatform/CreateDatasetTabularGcsSampleTest.java
@@ -0,0 +1,93 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package aiplatform;
+
+import static com.google.common.truth.Truth.assertThat;
+import static org.junit.Assert.assertNotNull;
+
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.io.PrintStream;
+import java.util.UUID;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.TimeoutException;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+public class CreateDatasetTabularGcsSampleTest {
+
+ private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID");
+ private static final String GCS_SOURCE_URI = "gs://cloud-ml-tables-data/bank-marketing.csv";
+ private ByteArrayOutputStream bout;
+ private PrintStream out;
+ private PrintStream originalPrintStream;
+ private String datasetId;
+
+ private static void requireEnvVar(String varName) {
+ String errorMessage =
+ String.format("Environment variable '%s' is required to perform these tests.", varName);
+ assertNotNull(errorMessage, System.getenv(varName));
+ }
+
+ @BeforeClass
+ public static void checkRequirements() {
+ requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS");
+ requireEnvVar("UCAIP_PROJECT_ID");
+ }
+
+ @Before
+ public void setUp() {
+ bout = new ByteArrayOutputStream();
+ out = new PrintStream(bout);
+ originalPrintStream = System.out;
+ System.setOut(out);
+ }
+
+ @After
+ public void tearDown()
+ throws InterruptedException, ExecutionException, IOException, TimeoutException {
+ // Delete the created dataset
+ DeleteDatasetSample.deleteDatasetSample(PROJECT, datasetId);
+
+ // Assert
+ String deleteResponse = bout.toString();
+ assertThat(deleteResponse).contains("Deleted Dataset.");
+ System.out.flush();
+ System.setOut(originalPrintStream);
+ }
+
+ @Test
+ public void testCreateDatasetTabularGcsSample()
+ throws IOException, InterruptedException, ExecutionException, TimeoutException {
+ // Act
+ String datasetDisplayName =
+ String.format(
+ "temp_create_dataset_table_gcs_test_%s",
+ UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 26));
+
+ CreateDatasetTabularGcsSample.createDatasetTableGcs(PROJECT,
+ datasetDisplayName, GCS_SOURCE_URI);
+
+ // Assert
+ String got = bout.toString();
+ assertThat(got).contains(datasetDisplayName);
+ assertThat(got).contains("Create Dataset Table GCS sample");
+ datasetId = got.split("Name: ")[1].split("datasets/")[1].split("\n")[0];
+ }
+}
diff --git a/aiplatform/snippets/src/test/java/aiplatform/CreateTrainingPipelineTabularClassificationSampleTest.java b/aiplatform/snippets/src/test/java/aiplatform/CreateTrainingPipelineTabularClassificationSampleTest.java
new file mode 100644
index 00000000000..65f7d041bf2
--- /dev/null
+++ b/aiplatform/snippets/src/test/java/aiplatform/CreateTrainingPipelineTabularClassificationSampleTest.java
@@ -0,0 +1,121 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package aiplatform;
+
+import static com.google.common.truth.Truth.assertThat;
+import static junit.framework.TestCase.assertNotNull;
+
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.io.PrintStream;
+import java.util.UUID;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+public class CreateTrainingPipelineTabularClassificationSampleTest {
+
+ private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID");
+ private static final String DATASET_ID =
+ System.getenv("TRAINING_PIPELINE_TABLES_CLASSIFICATION_DATASET_ID");
+ private static final String TARGET_COLUMN = "TripType";
+ private static final String TRANSFORMATION =
+ "[{\"numeric\":{\"columnName\":\"Age\",\"invalidValuesAllowed\":false}},"
+ + "{\"categorical\":{\"columnName\":\"Job\"}},"
+ + "{\"categorical\":{\"columnName\":\"MaritalStatus\"}},"
+ + "{\"categorical\":{\"columnName\":\"Default\"}},"
+ + "{\"numeric\":{\"columnName\":\"Balance\",\"invalidValuesAllowed\":false}},"
+ + "{\"categorical\":{\"columnName\":\"Housing\"}},"
+ + "{\"categorical\":{\"columnName\":\"Loan\"}},"
+ + "{\"categorical\":{\"columnName\":\"Contact\"}},"
+ + "{\"numeric\":{\"columnName\":\"Day\",\"invalidValuesAllowed\":false}},"
+ + "{\"categorical\":{\"columnName\":\"Month\"}},"
+ + "{\"numeric\":{\"columnName\":\"Duration\",\"invalidValuesAllowed\":false}},"
+ + "{\"numeric\":{\"columnName\":\"Campaign\",\"invalidValuesAllowed\":false}},"
+ + "{\"numeric\":{\"columnName\":\"PDays\",\"invalidValuesAllowed\":false}},"
+ + "{\"numeric\":{\"columnName\":\"Previous\",\"invalidValuesAllowed\":false}},"
+ + "{\"categorical\":{\"columnName\":\"POutcome\"}},"
+ + "{\"categorical\":{\"columnName\":\"Deposit\"}}]";
+ private ByteArrayOutputStream bout;
+ private PrintStream out;
+ private PrintStream originalPrintStream;
+ private String trainingPipelineId;
+
+ private static void requireEnvVar(String varName) {
+ String errorMessage =
+ String.format("Environment variable '%s' is required to perform these tests.", varName);
+ assertNotNull(errorMessage, System.getenv(varName));
+ }
+
+ @BeforeClass
+ public static void checkRequirements() {
+ requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS");
+ requireEnvVar("UCAIP_PROJECT_ID");
+ requireEnvVar("TRAINING_PIPELINE_TABLES_CLASSIFICATION_DATASET_ID");
+ }
+
+ @Before
+ public void setUp() {
+ bout = new ByteArrayOutputStream();
+ out = new PrintStream(bout);
+ originalPrintStream = System.out;
+ System.setOut(out);
+ }
+
+ @After
+ public void tearDown()
+ throws InterruptedException, ExecutionException, IOException, TimeoutException {
+ // Cancel the Training Pipeline
+ CancelTrainingPipelineSample.cancelTrainingPipelineSample(PROJECT, trainingPipelineId);
+
+ // Assert
+ String cancelResponse = bout.toString();
+ assertThat(cancelResponse).contains("Cancelled the Training Pipeline");
+ TimeUnit.MINUTES.sleep(2);
+
+ // Delete the Training Pipeline
+ DeleteTrainingPipelineSample.deleteTrainingPipelineSample(PROJECT, trainingPipelineId);
+
+ // Assert
+ String deleteResponse = bout.toString();
+ assertThat(deleteResponse).contains("Deleted Training Pipeline.");
+ System.out.flush();
+ System.setOut(originalPrintStream);
+ }
+
+ @Test
+ public void createTrainingPipelineTabularClassification() throws IOException {
+ // Act
+ String modelDisplayName =
+ String.format(
+ "temp_create_training_pipelinetabularclassification_test_%s",
+ UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 26));
+
+ CreateTrainingPipelineTabularClassificationSample.createTrainingPipelineTableClassification(
+ PROJECT, modelDisplayName, DATASET_ID, TARGET_COLUMN, TRANSFORMATION);
+
+ // Assert
+ String got = bout.toString();
+ assertThat(got).contains(DATASET_ID);
+ assertThat(got).contains("Create Training Pipeline Tabular Classification Response");
+ trainingPipelineId = got.split("Name: ")[1].split("trainingPipelines/")[1].split("\n")[0];
+ }
+}
diff --git a/aiplatform/snippets/src/test/java/aiplatform/CreateTrainingPipelineTabularRegressionSampleTest.java b/aiplatform/snippets/src/test/java/aiplatform/CreateTrainingPipelineTabularRegressionSampleTest.java
new file mode 100644
index 00000000000..3933106f3e6
--- /dev/null
+++ b/aiplatform/snippets/src/test/java/aiplatform/CreateTrainingPipelineTabularRegressionSampleTest.java
@@ -0,0 +1,119 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package aiplatform;
+
+import static com.google.common.truth.Truth.assertThat;
+import static junit.framework.TestCase.assertNotNull;
+
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.io.PrintStream;
+import java.util.UUID;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+public class CreateTrainingPipelineTabularRegressionSampleTest {
+
+ private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID");
+ private static final String DATASET_ID =
+ System.getenv("TRAINING_PIPELINE_TABLES_REGRESSION_DATASET_ID");
+ private static final String TARGET_COLUMN = "Amount";
+ private static final String TRANSFORMATION =
+ "[{\"categorical\":{\"columnName\":\"SC_Group_Desc\"}},"
+ + "{\"categorical\":{\"columnName\":\"SC_GroupCommod_ID\"}},"
+ + "{\"categorical\":{\"columnName\":\"SC_GroupCommod_Desc\"}},"
+ + "{\"numeric\":{\"columnName\":\"SortOrder\",\"invalidValuesAllowed\":false}},"
+ + "{\"text\":{\"columnName\":\"SC_GeographyIndented_Desc\"}},"
+ + "{\"numeric\":{\"columnName\":\"SC_Commodity_ID\",\"invalidValuesAllowed\":false}},"
+ + "{\"text\":{\"columnName\":\"SC_Commodity_Desc\"}},"
+ + "{\"numeric\":{\"columnName\":\"SC_Attribute_ID\",\"invalidValuesAllowed\":false}},"
+ + "{\"text\":{\"columnName\":\"SC_Attribute_Desc\"}},"
+ + "{\"numeric\":{\"columnName\":\"SC_Unit_ID\",\"invalidValuesAllowed\":false}},"
+ + "{\"numeric\":{\"columnName\":\"Year_ID\",\"invalidValuesAllowed\":false}},"
+ + "{\"categorical\":{\"columnName\":\"SC_Frequency_Desc\"}},"
+ + "{\"numeric\":{\"columnName\":\"Timeperiod_ID\",\"invalidValuesAllowed\":false}},"
+ + "{\"text\":{\"columnName\":\"Timeperiod_Desc\"}}]";
+ private ByteArrayOutputStream bout;
+ private PrintStream out;
+ private PrintStream originalPrintStream;
+ private String trainingPipelineId;
+
+ private static void requireEnvVar(String varName) {
+ String errorMessage =
+ String.format("Environment variable '%s' is required to perform these tests.", varName);
+ assertNotNull(errorMessage, System.getenv(varName));
+ }
+
+ @BeforeClass
+ public static void checkRequirements() {
+ requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS");
+ requireEnvVar("UCAIP_PROJECT_ID");
+ requireEnvVar("TRAINING_PIPELINE_TABLES_REGRESSION_DATASET_ID");
+ }
+
+ @Before
+ public void setUp() {
+ bout = new ByteArrayOutputStream();
+ out = new PrintStream(bout);
+ originalPrintStream = System.out;
+ System.setOut(out);
+ }
+
+ @After
+ public void tearDown()
+ throws InterruptedException, ExecutionException, IOException, TimeoutException {
+ // Cancel the Training Pipeline
+ CancelTrainingPipelineSample.cancelTrainingPipelineSample(PROJECT, trainingPipelineId);
+
+ // Assert
+ String cancelResponse = bout.toString();
+ assertThat(cancelResponse).contains("Cancelled the Training Pipeline");
+ TimeUnit.MINUTES.sleep(2);
+
+ // Delete the Training Pipeline
+ DeleteTrainingPipelineSample.deleteTrainingPipelineSample(PROJECT, trainingPipelineId);
+
+ // Assert
+ String deleteResponse = bout.toString();
+ assertThat(deleteResponse).contains("Deleted Training Pipeline.");
+ System.out.flush();
+ System.setOut(originalPrintStream);
+ }
+
+ @Test
+ public void createTrainingPipelineTabularRegression() throws IOException {
+ // Act
+ String modelDisplayName =
+ String.format(
+ "temp_create_training_pipelinetabularregression_model_test_%s",
+ UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 26));
+
+ CreateTrainingPipelineTabularRegressionSample.createTrainingPipelineTableRegression(
+ PROJECT, modelDisplayName, DATASET_ID, TARGET_COLUMN, TRANSFORMATION);
+
+ // Assert
+ String got = bout.toString();
+ assertThat(got).contains(DATASET_ID);
+ assertThat(got).contains("Create Training Pipeline Tabular Regression Response");
+ trainingPipelineId = got.split("Name: ")[1].split("trainingPipelines/")[1].split("\n")[0];
+ }
+}
diff --git a/aiplatform/snippets/src/test/java/aiplatform/ExportModelTabularClassificationSampleTest.java b/aiplatform/snippets/src/test/java/aiplatform/ExportModelTabularClassificationSampleTest.java
new file mode 100644
index 00000000000..9212dc3d920
--- /dev/null
+++ b/aiplatform/snippets/src/test/java/aiplatform/ExportModelTabularClassificationSampleTest.java
@@ -0,0 +1,89 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package aiplatform;
+
+import static com.google.common.truth.Truth.assertThat;
+import static org.junit.Assert.assertNotNull;
+
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.io.PrintStream;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.TimeoutException;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+public class ExportModelTabularClassificationSampleTest {
+
+ private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID");
+ private static final String MODEL_ID =
+ System.getenv("EXPORT_MODEL_TABLES_CLASSIFICATION_MODEL_ID");
+ private static final String GCS_DESTINATION_URI_PREFIX =
+ "gs://ucaip-samples-test-output/tmp/export_model_test";
+ private ByteArrayOutputStream bout;
+ private PrintStream out;
+ private PrintStream originalPrintStream;
+
+ private static void requireEnvVar(String varName) {
+ String errorMessage =
+ String.format("Environment variable '%s' is required to perform these tests.", varName);
+ assertNotNull(errorMessage, System.getenv(varName));
+ }
+
+ @BeforeClass
+ public static void checkRequirements() {
+ requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS");
+ requireEnvVar("UCAIP_PROJECT_ID");
+ requireEnvVar("EXPORT_MODEL_TABLES_CLASSIFICATION_MODEL_ID");
+ }
+
+ @Before
+ public void setUp() {
+ bout = new ByteArrayOutputStream();
+ out = new PrintStream(bout);
+ originalPrintStream = System.out;
+ System.setOut(out);
+ }
+
+ @After
+ public void tearDown() {
+ // Delete the export model
+ String bucketName = GCS_DESTINATION_URI_PREFIX.split("/", 4)[2];
+ String objectName = (GCS_DESTINATION_URI_PREFIX.split("/", 4)[3]).concat("model-" + MODEL_ID);
+ DeleteExportModelSample.deleteExportModelSample(PROJECT, bucketName, objectName);
+
+ // Assert
+ String deleteResponse = bout.toString();
+ assertThat(deleteResponse).contains("Export Model Deleted");
+ System.out.flush();
+ System.setOut(originalPrintStream);
+ }
+
+ @Test
+ public void exportModelTabularClassification()
+ throws InterruptedException, ExecutionException, TimeoutException, IOException {
+ // Act
+ ExportModelTabularClassificationSample.exportModelTableClassification(
+ GCS_DESTINATION_URI_PREFIX, PROJECT, MODEL_ID);
+
+ // Assert
+ String got = bout.toString();
+ assertThat(got).contains("Export Model Tabular Classification Response: ");
+ }
+}
diff --git a/aiplatform/snippets/src/test/java/aiplatform/GetModelEvaluationTabularClassificationSampleTest.java b/aiplatform/snippets/src/test/java/aiplatform/GetModelEvaluationTabularClassificationSampleTest.java
new file mode 100644
index 00000000000..6995dcd9f8a
--- /dev/null
+++ b/aiplatform/snippets/src/test/java/aiplatform/GetModelEvaluationTabularClassificationSampleTest.java
@@ -0,0 +1,80 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package aiplatform;
+
+import static com.google.common.truth.Truth.assertThat;
+import static junit.framework.TestCase.assertNotNull;
+
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.io.PrintStream;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+public class GetModelEvaluationTabularClassificationSampleTest {
+
+ private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID");
+ private static final String MODEL_ID =
+ System.getenv("MODEL_EVALUATION_TABLES_CLASSIFICATION_MODEL_ID");
+ private static final String EVALUATION_ID =
+ System.getenv("MODEL_EVALUATION_TABLES_CLASSIFICATION_EVALUATION_ID");
+ private ByteArrayOutputStream bout;
+ private PrintStream out;
+ private PrintStream originalPrintStream;
+
+ private static void requireEnvVar(String varName) {
+ String errorMessage =
+ String.format("Environment variable '%s' is required to perform these tests.", varName);
+ assertNotNull(errorMessage, System.getenv(varName));
+ }
+
+ @BeforeClass
+ public static void checkRequirements() {
+ requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS");
+ requireEnvVar("UCAIP_PROJECT_ID");
+ requireEnvVar("MODEL_EVALUATION_TABLES_CLASSIFICATION_MODEL_ID");
+ requireEnvVar("MODEL_EVALUATION_TABLES_CLASSIFICATION_EVALUATION_ID");
+ }
+
+ @Before
+ public void setUp() {
+ bout = new ByteArrayOutputStream();
+ out = new PrintStream(bout);
+ originalPrintStream = System.out;
+ System.setOut(out);
+ }
+
+ @After
+ public void tearDown() {
+ System.out.flush();
+ System.setOut(originalPrintStream);
+ }
+
+ @Test
+ public void getModelEvaluationTabularClassification() throws IOException {
+ // Act
+ GetModelEvaluationTabularClassificationSample.getModelEvaluationTabularClassification(
+ PROJECT, MODEL_ID, EVALUATION_ID);
+
+ // Assert
+ String got = bout.toString();
+ assertThat(got).contains(MODEL_ID);
+ assertThat(got).contains("Get Model Evaluation Tabular Classification Response");
+ }
+}
diff --git a/aiplatform/snippets/src/test/java/aiplatform/GetModelEvaluationTabularRegressionSampleTest.java b/aiplatform/snippets/src/test/java/aiplatform/GetModelEvaluationTabularRegressionSampleTest.java
new file mode 100644
index 00000000000..81daedecc6d
--- /dev/null
+++ b/aiplatform/snippets/src/test/java/aiplatform/GetModelEvaluationTabularRegressionSampleTest.java
@@ -0,0 +1,80 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package aiplatform;
+
+import static com.google.common.truth.Truth.assertThat;
+import static junit.framework.TestCase.assertNotNull;
+
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.io.PrintStream;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+public class GetModelEvaluationTabularRegressionSampleTest {
+
+ private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID");
+ private static final String MODEL_ID =
+ System.getenv("MODEL_EVALUATION_TABLES_REGRESSION_MODEL_ID");
+ private static final String EVALUATION_ID =
+ System.getenv("MODEL_EVALUATION_TABLES_REGRESSION_EVALUATION_ID");
+ private ByteArrayOutputStream bout;
+ private PrintStream out;
+ private PrintStream originalPrintStream;
+
+ private static void requireEnvVar(String varName) {
+ String errorMessage =
+ String.format("Environment variable '%s' is required to perform these tests.", varName);
+ assertNotNull(errorMessage, System.getenv(varName));
+ }
+
+ @BeforeClass
+ public static void checkRequirements() {
+ requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS");
+ requireEnvVar("UCAIP_PROJECT_ID");
+ requireEnvVar("MODEL_EVALUATION_TABLES_REGRESSION_MODEL_ID");
+ requireEnvVar("MODEL_EVALUATION_TABLES_REGRESSION_EVALUATION_ID");
+ }
+
+ @Before
+ public void setUp() {
+ bout = new ByteArrayOutputStream();
+ out = new PrintStream(bout);
+ originalPrintStream = System.out;
+ System.setOut(out);
+ }
+
+ @After
+ public void tearDown() {
+ System.out.flush();
+ System.setOut(originalPrintStream);
+ }
+
+ @Test
+ public void getModelEvaluationTabularRegression() throws IOException {
+ // Act
+ GetModelEvaluationTabularRegressionSample.getModelEvaluationTabularRegression(
+ PROJECT, MODEL_ID, EVALUATION_ID);
+
+ // Assert
+ String got = bout.toString();
+ assertThat(got).contains(MODEL_ID);
+ assertThat(got).contains("Get Model Evaluation Tabular Regression Response");
+ }
+}
diff --git a/aiplatform/snippets/src/test/java/aiplatform/PredictTabularClassificationSampleTest.java b/aiplatform/snippets/src/test/java/aiplatform/PredictTabularClassificationSampleTest.java
new file mode 100644
index 00000000000..1574efe2ae1
--- /dev/null
+++ b/aiplatform/snippets/src/test/java/aiplatform/PredictTabularClassificationSampleTest.java
@@ -0,0 +1,80 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package aiplatform;
+
+import static com.google.common.truth.Truth.assertThat;
+import static junit.framework.TestCase.assertNotNull;
+
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.io.PrintStream;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+public class PredictTabularClassificationSampleTest {
+
+ private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID");
+ private static final String INSTANCE =
+ "[{\"petal_length\": '1.4',"
+ + " \"petal_width\": '1.3',"
+ + " \"sepal_length\": '5.1',"
+ + " \"sepal_width\": '2.8'}]";
+ private static final String ENDPOINT_ID =
+ System.getenv("PREDICT_TABLES_CLASSIFCATION_ENDPOINT_ID");
+ private ByteArrayOutputStream bout;
+ private PrintStream out;
+ private PrintStream originalPrintStream;
+
+ private static void requireEnvVar(String varName) {
+ String errorMessage =
+ String.format("Environment variable '%s' is required to perform these tests.", varName);
+ assertNotNull(errorMessage, System.getenv(varName));
+ }
+
+ @BeforeClass
+ public static void checkRequirements() {
+ requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS");
+ requireEnvVar("UCAIP_PROJECT_ID");
+ requireEnvVar("PREDICT_TABLES_CLASSIFCATION_ENDPOINT_ID");
+ }
+
+ @Before
+ public void setUp() {
+ bout = new ByteArrayOutputStream();
+ out = new PrintStream(bout);
+ originalPrintStream = System.out;
+ System.setOut(out);
+ }
+
+ @After
+ public void tearDown() {
+ System.out.flush();
+ System.setOut(originalPrintStream);
+ }
+
+ @Test
+ public void testPredictTabularClassification() throws IOException {
+ // Act
+ PredictTabularClassificationSample.predictTabularClassification(INSTANCE, PROJECT, ENDPOINT_ID);
+
+ // Assert
+ String got = bout.toString();
+ assertThat(got).contains("Predict Tabular Classification Response");
+ }
+}
diff --git a/aiplatform/snippets/src/test/java/aiplatform/PredictTabularRegressionSampleTest.java b/aiplatform/snippets/src/test/java/aiplatform/PredictTabularRegressionSampleTest.java
new file mode 100644
index 00000000000..44f5bfdfa21
--- /dev/null
+++ b/aiplatform/snippets/src/test/java/aiplatform/PredictTabularRegressionSampleTest.java
@@ -0,0 +1,100 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package aiplatform;
+
+import static com.google.common.truth.Truth.assertThat;
+import static junit.framework.TestCase.assertNotNull;
+
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.io.PrintStream;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+public class PredictTabularRegressionSampleTest {
+
+ private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID");
+ private static final String INSTANCE =
+ "[{\n"
+ + " \"BOOLEAN_2unique_NULLABLE\": False,\n"
+ + " \"DATETIME_1unique_NULLABLE\": '2019-01-01 00:00:00',\n"
+ + " \"DATE_1unique_NULLABLE\": '2019-01-01',\n"
+ + " \"FLOAT_5000unique_NULLABLE\": 1611,\n"
+ + " \"FLOAT_5000unique_REPEATED\": [2320,1192],\n"
+ + " \"INTEGER_5000unique_NULLABLE\": '8',\n"
+ + " \"NUMERIC_5000unique_NULLABLE\": 16,\n"
+ + " \"STRING_5000unique_NULLABLE\": 'str-2',\n"
+ + " \"STRUCT_NULLABLE\": {\n"
+ + " 'BOOLEAN_2unique_NULLABLE': False,\n"
+ + " 'DATE_1unique_NULLABLE': '2019-01-01',\n"
+ + " 'DATETIME_1unique_NULLABLE': '2019-01-01 00:00:00',\n"
+ + " 'FLOAT_5000unique_NULLABLE': 1308,\n"
+ + " 'FLOAT_5000unique_REPEATED': [2323, 1178],\n"
+ + " 'FLOAT_5000unique_REQUIRED': 3089,\n"
+ + " 'INTEGER_5000unique_NULLABLE': '1777',\n"
+ + " 'NUMERIC_5000unique_NULLABLE': 3323,\n"
+ + " 'TIME_1unique_NULLABLE': '23:59:59.999999',\n"
+ + " 'STRING_5000unique_NULLABLE': 'str-49',\n"
+ + " 'TIMESTAMP_1unique_NULLABLE': '1546387199999999'\n"
+ + " },\n"
+ + " \"TIMESTAMP_1unique_NULLABLE\": '1546387199999999',\n"
+ + " \"TIME_1unique_NULLABLE\": '23:59:59.999999'\n"
+ + "}]";
+ private static final String ENDPOINT_ID = System.getenv("PREDICT_TABLES_REGRESSION_ENDPOINT_ID");
+ private ByteArrayOutputStream bout;
+ private PrintStream out;
+ private PrintStream originalPrintStream;
+
+ private static void requireEnvVar(String varName) {
+ String errorMessage =
+ String.format("Environment variable '%s' is required to perform these tests.", varName);
+ assertNotNull(errorMessage, System.getenv(varName));
+ }
+
+ @BeforeClass
+ public static void checkRequirements() {
+ requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS");
+ requireEnvVar("UCAIP_PROJECT_ID");
+ requireEnvVar("PREDICT_TABLES_REGRESSION_ENDPOINT_ID");
+ }
+
+ @Before
+ public void setUp() {
+ bout = new ByteArrayOutputStream();
+ out = new PrintStream(bout);
+ originalPrintStream = System.out;
+ System.setOut(out);
+ }
+
+ @After
+ public void tearDown() {
+ System.out.flush();
+ System.setOut(originalPrintStream);
+ }
+
+ @Test
+ public void testPredictTabularRegression() throws IOException {
+ // Act
+ PredictTabularRegressionSample.predictTabularRegression(INSTANCE, PROJECT, ENDPOINT_ID);
+
+ // Assert
+ String got = bout.toString();
+ assertThat(got).contains("Predict Tabular Regression Response");
+ }
+}