diff --git a/aiplatform/snippets/format.sh b/aiplatform/snippets/format.sh
new file mode 100644
index 00000000000..153ed361ef9
--- /dev/null
+++ b/aiplatform/snippets/format.sh
@@ -0,0 +1,6 @@
+touch format.sh
+ chmod +rx format.sh
+
+git add .
+git reset HEAD format.sh
+./format.sh
\ No newline at end of file
diff --git a/aiplatform/snippets/pom.xml b/aiplatform/snippets/pom.xml
index 8ea1e440238..bad1d0b69f3 100644
--- a/aiplatform/snippets/pom.xml
+++ b/aiplatform/snippets/pom.xml
@@ -30,16 +30,6 @@
0.0.1-SNAPSHOT
-
- com.google.protobuf
- protobuf-java-util
- 4.0.0-rc-1
-
-
- com.google.cloud
- google-cloud-storage
- 1.111.0
-
com.google.cloud
google-cloud-storage
@@ -50,8 +40,6 @@
protobuf-java-util
4.0.0-rc-1
-
-
junit
junit
diff --git a/aiplatform/snippets/resources/caprese_salad.jpg b/aiplatform/snippets/resources/caprese_salad.jpg
new file mode 100644
index 00000000000..fbd7e6575c3
Binary files /dev/null and b/aiplatform/snippets/resources/caprese_salad.jpg differ
diff --git a/aiplatform/snippets/resources/image_flower_daisy.jpg b/aiplatform/snippets/resources/image_flower_daisy.jpg
new file mode 100644
index 00000000000..3ba1d67705a
Binary files /dev/null and b/aiplatform/snippets/resources/image_flower_daisy.jpg differ
diff --git a/aiplatform/snippets/src/main/java/aiplatform/CancelBatchPredictionJobSample.java b/aiplatform/snippets/src/main/java/aiplatform/CancelBatchPredictionJobSample.java
new file mode 100644
index 00000000000..61931a9fd2e
--- /dev/null
+++ b/aiplatform/snippets/src/main/java/aiplatform/CancelBatchPredictionJobSample.java
@@ -0,0 +1,56 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package aiplatform;
+
+// [START aiplatform_cancel_batch_prediction_job_sample]
+
+import com.google.cloud.aiplatform.v1beta1.BatchPredictionJobName;
+import com.google.cloud.aiplatform.v1beta1.JobServiceClient;
+import com.google.cloud.aiplatform.v1beta1.JobServiceSettings;
+import java.io.IOException;
+
+public class CancelBatchPredictionJobSample {
+
+ public static void main(String[] args) throws IOException {
+ // TODO(developer): Replace these variables before running the sample.
+ String project = "YOUR_PROJECT_ID";
+ String batchPredictionJobId = "YOUR_BATCH_PREDICTION_JOB_ID";
+ cancelBatchPredictionJobSample(project, batchPredictionJobId);
+ }
+
+ static void cancelBatchPredictionJobSample(String project, String batchPredictionJobId)
+ throws IOException {
+ JobServiceSettings jobServiceSettings =
+ JobServiceSettings.newBuilder()
+ .setEndpoint("us-central1-aiplatform.googleapis.com:443")
+ .build();
+
+ // Initialize client that will be used to send requests. This client only needs to be created
+ // once, and can be reused for multiple requests. After completing all of your requests, call
+ // the "close" method on the client to safely clean up any remaining background resources.
+ try (JobServiceClient jobServiceClient = JobServiceClient.create(jobServiceSettings)) {
+ String location = "us-central1";
+ BatchPredictionJobName batchPredictionJobName =
+ BatchPredictionJobName.of(project, location, batchPredictionJobId);
+
+ jobServiceClient.cancelBatchPredictionJob(batchPredictionJobName);
+
+ System.out.println("Cancelled the Batch Prediction Job");
+ }
+ }
+}
+// [END aiplatform_cancel_batch_prediction_job_sample]
diff --git a/aiplatform/snippets/src/main/java/aiplatform/CreateBatchPredictionJobVideoClassificationSample.java b/aiplatform/snippets/src/main/java/aiplatform/CreateBatchPredictionJobVideoClassificationSample.java
new file mode 100644
index 00000000000..a89f2bfe3d5
--- /dev/null
+++ b/aiplatform/snippets/src/main/java/aiplatform/CreateBatchPredictionJobVideoClassificationSample.java
@@ -0,0 +1,200 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package aiplatform;
+
+// [START aiplatform_create_batch_prediction_job_video_classification_sample]
+
+import com.google.cloud.aiplatform.v1beta1.BatchDedicatedResources;
+import com.google.cloud.aiplatform.v1beta1.BatchPredictionJob;
+import com.google.cloud.aiplatform.v1beta1.BatchPredictionJob.InputConfig;
+import com.google.cloud.aiplatform.v1beta1.BatchPredictionJob.OutputConfig;
+import com.google.cloud.aiplatform.v1beta1.BatchPredictionJob.OutputInfo;
+import com.google.cloud.aiplatform.v1beta1.BigQueryDestination;
+import com.google.cloud.aiplatform.v1beta1.BigQuerySource;
+import com.google.cloud.aiplatform.v1beta1.CompletionStats;
+import com.google.cloud.aiplatform.v1beta1.GcsDestination;
+import com.google.cloud.aiplatform.v1beta1.GcsSource;
+import com.google.cloud.aiplatform.v1beta1.JobServiceClient;
+import com.google.cloud.aiplatform.v1beta1.JobServiceSettings;
+import com.google.cloud.aiplatform.v1beta1.LocationName;
+import com.google.cloud.aiplatform.v1beta1.MachineSpec;
+import com.google.cloud.aiplatform.v1beta1.ManualBatchTuningParameters;
+import com.google.cloud.aiplatform.v1beta1.ModelName;
+import com.google.cloud.aiplatform.v1beta1.ResourcesConsumed;
+import com.google.protobuf.Any;
+import com.google.protobuf.Value;
+import com.google.protobuf.util.JsonFormat;
+import com.google.rpc.Status;
+import java.io.IOException;
+import java.util.List;
+
+public class CreateBatchPredictionJobVideoClassificationSample {
+
+ public static void main(String[] args) throws IOException {
+ String batchPredictionDisplayName = "YOUR_VIDEO_CLASSIFICATION_DISPLAY_NAME";
+ String modelId = "YOUR_MODEL_ID";
+ String gcsSourceUri =
+ "gs://YOUR_GCS_SOURCE_BUCKET/path_to_your_video_source/[file.csv/file.jsonl]";
+ String gcsDestinationOutputUriPrefix =
+ "gs://YOUR_GCS_SOURCE_BUCKET/destination_output_uri_prefix/";
+ String project = "YOUR_PROJECT_ID";
+ createBatchPredictionJobVideoClassification(
+ batchPredictionDisplayName, modelId, gcsSourceUri, gcsDestinationOutputUriPrefix, project);
+ }
+
+ static void createBatchPredictionJobVideoClassification(
+ String batchPredictionDisplayName,
+ String modelId,
+ String gcsSourceUri,
+ String gcsDestinationOutputUriPrefix,
+ String project)
+ throws IOException {
+ JobServiceSettings jobServiceSettings =
+ JobServiceSettings.newBuilder()
+ .setEndpoint("us-central1-aiplatform.googleapis.com:443")
+ .build();
+
+ // Initialize client that will be used to send requests. This client only needs to be created
+ // once, and can be reused for multiple requests. After completing all of your requests, call
+ // the "close" method on the client to safely clean up any remaining background resources.
+ try (JobServiceClient jobServiceClient = JobServiceClient.create(jobServiceSettings)) {
+ String location = "us-central1";
+ LocationName locationName = LocationName.of(project, location);
+
+ String jsonString =
+ "{\"confidenceThreshold\": 0.5,\"maxPredictions\": 10000,\"segmentClassification\":"
+ + " True,\"shotClassification\": True,\"oneSecIntervalClassification\": True}";
+ Value.Builder modelParameters = Value.newBuilder();
+ JsonFormat.parser().merge(jsonString, modelParameters);
+
+ ModelName modelName = ModelName.of(project, location, modelId);
+ GcsSource.Builder gcsSource = GcsSource.newBuilder();
+ gcsSource.addUris(gcsSourceUri);
+ InputConfig inputConfig =
+ InputConfig.newBuilder().setInstancesFormat("jsonl").setGcsSource(gcsSource).build();
+
+ GcsDestination gcsDestination =
+ GcsDestination.newBuilder().setOutputUriPrefix(gcsDestinationOutputUriPrefix).build();
+ OutputConfig outputConfig =
+ OutputConfig.newBuilder()
+ .setPredictionsFormat("jsonl")
+ .setGcsDestination(gcsDestination)
+ .build();
+
+ BatchPredictionJob batchPredictionJob =
+ BatchPredictionJob.newBuilder()
+ .setDisplayName(batchPredictionDisplayName)
+ .setModel(modelName.toString())
+ .setModelParameters(modelParameters)
+ .setInputConfig(inputConfig)
+ .setOutputConfig(outputConfig)
+ .build();
+ BatchPredictionJob batchPredictionJobResponse =
+ jobServiceClient.createBatchPredictionJob(locationName, batchPredictionJob);
+
+ System.out.println("Create Batch Prediction Job Video Classification Response");
+ System.out.format("\tName: %s\n", batchPredictionJobResponse.getName());
+ System.out.format("\tDisplay Name: %s\n", batchPredictionJobResponse.getDisplayName());
+ System.out.format("\tModel %s\n", batchPredictionJobResponse.getModel());
+ System.out.format(
+ "\tModel Parameters: %s\n", batchPredictionJobResponse.getModelParameters());
+ System.out.format(
+ "\tGenerate Explanation: %s\n", batchPredictionJobResponse.getGenerateExplanation());
+
+ System.out.format("\tState: %s\n", batchPredictionJobResponse.getState());
+ System.out.format("\tCreate Time: %s\n", batchPredictionJobResponse.getCreateTime());
+ System.out.format("\tStart Time: %s\n", batchPredictionJobResponse.getStartTime());
+ System.out.format("\tEnd Time: %s\n", batchPredictionJobResponse.getEndTime());
+ System.out.format("\tUpdate Time: %s\n", batchPredictionJobResponse.getUpdateTime());
+ System.out.format("\tLabels: %s\n", batchPredictionJobResponse.getLabelsMap());
+
+ InputConfig inputConfigResponse = batchPredictionJobResponse.getInputConfig();
+ System.out.println("\tInput Config");
+ System.out.format("\t\tInstances Format: %s\n", inputConfigResponse.getInstancesFormat());
+
+ GcsSource gcsSourceResponse = inputConfigResponse.getGcsSource();
+ System.out.println("\t\tGcs Source");
+ System.out.format("\t\t\tUris %s\n", gcsSourceResponse.getUrisList());
+
+ BigQuerySource bigQuerySource = inputConfigResponse.getBigquerySource();
+ System.out.println("\t\tBigquery Source");
+ System.out.format("\t\t\tInput_uri: %s\n", bigQuerySource.getInputUri());
+
+ OutputConfig outputConfigResponse = batchPredictionJobResponse.getOutputConfig();
+ System.out.println("\tOutput Config");
+ System.out.format(
+ "\t\tPredictions Format: %s\n", outputConfigResponse.getPredictionsFormat());
+
+ GcsDestination gcsDestinationResponse = outputConfigResponse.getGcsDestination();
+ System.out.println("\t\tGcs Destination");
+ System.out.format(
+ "\t\t\tOutput Uri Prefix: %s\n", gcsDestinationResponse.getOutputUriPrefix());
+
+ BigQueryDestination bigQueryDestination = outputConfigResponse.getBigqueryDestination();
+ System.out.println("\t\tBig Query Destination");
+ System.out.format("\t\t\tOutput Uri: %s\n", bigQueryDestination.getOutputUri());
+
+ BatchDedicatedResources batchDedicatedResources =
+ batchPredictionJobResponse.getDedicatedResources();
+ System.out.println("\tBatch Dedicated Resources");
+ System.out.format(
+ "\t\tStarting Replica Count: %s\n", batchDedicatedResources.getStartingReplicaCount());
+ System.out.format(
+ "\t\tMax Replica Count: %s\n", batchDedicatedResources.getMaxReplicaCount());
+
+ MachineSpec machineSpec = batchDedicatedResources.getMachineSpec();
+ System.out.println("\t\tMachine Spec");
+ System.out.format("\t\t\tMachine Type: %s\n", machineSpec.getMachineType());
+ System.out.format("\t\t\tAccelerator Type: %s\n", machineSpec.getAcceleratorType());
+ System.out.format("\t\t\tAccelerator Count: %s\n", machineSpec.getAcceleratorCount());
+
+ ManualBatchTuningParameters manualBatchTuningParameters =
+ batchPredictionJobResponse.getManualBatchTuningParameters();
+ System.out.println("\tManual Batch Tuning Parameters");
+ System.out.format("\t\tBatch Size: %s\n", manualBatchTuningParameters.getBatchSize());
+
+ OutputInfo outputInfo = batchPredictionJobResponse.getOutputInfo();
+ System.out.println("\tOutput Info");
+ System.out.format("\t\tGcs Output Directory: %s\n", outputInfo.getGcsOutputDirectory());
+ System.out.format("\t\tBigquery Output Dataset: %s\n", outputInfo.getBigqueryOutputDataset());
+
+ Status status = batchPredictionJobResponse.getError();
+ System.out.println("\tError");
+ System.out.format("\t\tCode: %s\n", status.getCode());
+ System.out.format("\t\tMessage: %s\n", status.getMessage());
+ List details = status.getDetailsList();
+
+ for (Status partialFailure : batchPredictionJobResponse.getPartialFailuresList()) {
+ System.out.println("\tPartial Failure");
+ System.out.format("\t\tCode: %s\n", partialFailure.getCode());
+ System.out.format("\t\tMessage: %s\n", partialFailure.getMessage());
+ List partialFailureDetailsList = partialFailure.getDetailsList();
+ }
+
+ ResourcesConsumed resourcesConsumed = batchPredictionJobResponse.getResourcesConsumed();
+ System.out.println("\tResources Consumed");
+ System.out.format("\t\tReplica Hours: %s\n", resourcesConsumed.getReplicaHours());
+
+ CompletionStats completionStats = batchPredictionJobResponse.getCompletionStats();
+ System.out.println("\tCompletion Stats");
+ System.out.format("\t\tSuccessful Count: %s\n", completionStats.getSuccessfulCount());
+ System.out.format("\t\tFailed Count: %s\n", completionStats.getFailedCount());
+ System.out.format("\t\tIncomplete Count: %s\n", completionStats.getIncompleteCount());
+ }
+ }
+}
+// [END aiplatform_create_batch_prediction_job_video_classification_sample]
diff --git a/aiplatform/snippets/src/main/java/aiplatform/CreateBatchPredictionJobVideoObjectTrackingSample.java b/aiplatform/snippets/src/main/java/aiplatform/CreateBatchPredictionJobVideoObjectTrackingSample.java
new file mode 100644
index 00000000000..da0550b2607
--- /dev/null
+++ b/aiplatform/snippets/src/main/java/aiplatform/CreateBatchPredictionJobVideoObjectTrackingSample.java
@@ -0,0 +1,199 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package aiplatform;
+
+// [START aiplatform_create_batch_prediction_job_video_object_tracking_sample]
+
+import com.google.cloud.aiplatform.v1beta1.BatchDedicatedResources;
+import com.google.cloud.aiplatform.v1beta1.BatchPredictionJob;
+import com.google.cloud.aiplatform.v1beta1.BatchPredictionJob.InputConfig;
+import com.google.cloud.aiplatform.v1beta1.BatchPredictionJob.OutputConfig;
+import com.google.cloud.aiplatform.v1beta1.BatchPredictionJob.OutputInfo;
+import com.google.cloud.aiplatform.v1beta1.BigQueryDestination;
+import com.google.cloud.aiplatform.v1beta1.BigQuerySource;
+import com.google.cloud.aiplatform.v1beta1.CompletionStats;
+import com.google.cloud.aiplatform.v1beta1.GcsDestination;
+import com.google.cloud.aiplatform.v1beta1.GcsSource;
+import com.google.cloud.aiplatform.v1beta1.JobServiceClient;
+import com.google.cloud.aiplatform.v1beta1.JobServiceSettings;
+import com.google.cloud.aiplatform.v1beta1.LocationName;
+import com.google.cloud.aiplatform.v1beta1.MachineSpec;
+import com.google.cloud.aiplatform.v1beta1.ManualBatchTuningParameters;
+import com.google.cloud.aiplatform.v1beta1.ModelName;
+import com.google.cloud.aiplatform.v1beta1.ResourcesConsumed;
+import com.google.protobuf.Any;
+import com.google.protobuf.Value;
+import com.google.protobuf.util.JsonFormat;
+import com.google.rpc.Status;
+import java.io.IOException;
+import java.util.List;
+
+public class CreateBatchPredictionJobVideoObjectTrackingSample {
+
+ public static void main(String[] args) throws IOException {
+ // TODO(developer): Replace these variables before running the sample.
+ String batchPredictionDisplayName = "YOUR_VIDEO_OBJECT_TRACKING_DISPLAY_NAME";
+ String modelId = "YOUR_MODEL_ID";
+ String gcsSourceUri =
+ "gs://YOUR_GCS_SOURCE_BUCKET/path_to_your_video_source/[file.csv/file.jsonl]";
+ String gcsDestinationOutputUriPrefix =
+ "gs://YOUR_GCS_SOURCE_BUCKET/destination_output_uri_prefix/";
+ String project = "YOUR_PROJECT_ID";
+ batchPredictionJobVideoObjectTracking(
+ batchPredictionDisplayName, modelId, gcsSourceUri, gcsDestinationOutputUriPrefix, project);
+ }
+
+ static void batchPredictionJobVideoObjectTracking(
+ String batchPredictionDisplayName,
+ String modelId,
+ String gcsSourceUri,
+ String gcsDestinationOutputUriPrefix,
+ String project)
+ throws IOException {
+ JobServiceSettings jobServiceSettings =
+ JobServiceSettings.newBuilder()
+ .setEndpoint("us-central1-aiplatform.googleapis.com:443")
+ .build();
+
+ // Initialize client that will be used to send requests. This client only needs to be created
+ // once, and can be reused for multiple requests. After completing all of your requests, call
+ // the "close" method on the client to safely clean up any remaining background resources.
+ try (JobServiceClient jobServiceClient = JobServiceClient.create(jobServiceSettings)) {
+ String location = "us-central1";
+ LocationName locationName = LocationName.of(project, location);
+ ModelName modelName = ModelName.of(project, location, modelId);
+
+ String jsonString = "{\"confidenceThreshold\": 0.0}";
+ Value.Builder modelParameters = Value.newBuilder();
+ JsonFormat.parser().merge(jsonString, modelParameters);
+
+ GcsSource.Builder gcsSource = GcsSource.newBuilder();
+ gcsSource.addUris(gcsSourceUri);
+ InputConfig inputConfig =
+ InputConfig.newBuilder().setInstancesFormat("jsonl").setGcsSource(gcsSource).build();
+
+ GcsDestination gcsDestination =
+ GcsDestination.newBuilder().setOutputUriPrefix(gcsDestinationOutputUriPrefix).build();
+ OutputConfig outputConfig =
+ OutputConfig.newBuilder()
+ .setPredictionsFormat("jsonl")
+ .setGcsDestination(gcsDestination)
+ .build();
+
+ BatchPredictionJob batchPredictionJob =
+ BatchPredictionJob.newBuilder()
+ .setDisplayName(batchPredictionDisplayName)
+ .setModel(modelName.toString())
+ .setModelParameters(modelParameters)
+ .setInputConfig(inputConfig)
+ .setOutputConfig(outputConfig)
+ .build();
+ BatchPredictionJob batchPredictionJobResponse =
+ jobServiceClient.createBatchPredictionJob(locationName, batchPredictionJob);
+
+ System.out.println("Create Batch Prediction Job Video Object Tracking Response");
+ System.out.format("\tName: %s\n", batchPredictionJobResponse.getName());
+ System.out.format("\tDisplay Name: %s\n", batchPredictionJobResponse.getDisplayName());
+ System.out.format("\tModel %s\n", batchPredictionJobResponse.getModel());
+ System.out.format(
+ "\tModel Parameters: %s\n", batchPredictionJobResponse.getModelParameters());
+ System.out.format(
+ "\tGenerate Explanation: %s\n", batchPredictionJobResponse.getGenerateExplanation());
+
+ System.out.format("\tState: %s\n", batchPredictionJobResponse.getState());
+ System.out.format("\tCreate Time: %s\n", batchPredictionJobResponse.getCreateTime());
+ System.out.format("\tStart Time: %s\n", batchPredictionJobResponse.getStartTime());
+ System.out.format("\tEnd Time: %s\n", batchPredictionJobResponse.getEndTime());
+ System.out.format("\tUpdate Time: %s\n", batchPredictionJobResponse.getUpdateTime());
+ System.out.format("\tLabels: %s\n", batchPredictionJobResponse.getLabelsMap());
+
+ InputConfig inputConfigResponse = batchPredictionJobResponse.getInputConfig();
+ System.out.println("\tInput Config");
+ System.out.format("\t\tInstances Format: %s\n", inputConfigResponse.getInstancesFormat());
+
+ GcsSource gcsSourceResponse = inputConfigResponse.getGcsSource();
+ System.out.println("\t\tGcs Source");
+ System.out.format("\t\t\tUris %s\n", gcsSourceResponse.getUrisList());
+
+ BigQuerySource bigQuerySource = inputConfigResponse.getBigquerySource();
+ System.out.println("\t\tBigquery Source");
+ System.out.format("\t\t\tInput_uri: %s\n", bigQuerySource.getInputUri());
+
+ OutputConfig outputConfigResponse = batchPredictionJobResponse.getOutputConfig();
+ System.out.println("\tOutput Config");
+ System.out.format(
+ "\t\tPredictions Format: %s\n", outputConfigResponse.getPredictionsFormat());
+
+ GcsDestination gcsDestinationResponse = outputConfigResponse.getGcsDestination();
+ System.out.println("\t\tGcs Destination");
+ System.out.format(
+ "\t\t\tOutput Uri Prefix: %s\n", gcsDestinationResponse.getOutputUriPrefix());
+
+ BigQueryDestination bigQueryDestination = outputConfigResponse.getBigqueryDestination();
+ System.out.println("\t\tBig Query Destination");
+ System.out.format("\t\t\tOutput Uri: %s\n", bigQueryDestination.getOutputUri());
+
+ BatchDedicatedResources batchDedicatedResources =
+ batchPredictionJobResponse.getDedicatedResources();
+ System.out.println("\tBatch Dedicated Resources");
+ System.out.format(
+ "\t\tStarting Replica Count: %s\n", batchDedicatedResources.getStartingReplicaCount());
+ System.out.format(
+ "\t\tMax Replica Count: %s\n", batchDedicatedResources.getMaxReplicaCount());
+
+ MachineSpec machineSpec = batchDedicatedResources.getMachineSpec();
+ System.out.println("\t\tMachine Spec");
+ System.out.format("\t\t\tMachine Type: %s\n", machineSpec.getMachineType());
+ System.out.format("\t\t\tAccelerator Type: %s\n", machineSpec.getAcceleratorType());
+ System.out.format("\t\t\tAccelerator Count: %s\n", machineSpec.getAcceleratorCount());
+
+ ManualBatchTuningParameters manualBatchTuningParameters =
+ batchPredictionJobResponse.getManualBatchTuningParameters();
+ System.out.println("\tManual Batch Tuning Parameters");
+ System.out.format("\t\tBatch Size: %s\n", manualBatchTuningParameters.getBatchSize());
+
+ OutputInfo outputInfo = batchPredictionJobResponse.getOutputInfo();
+ System.out.println("\tOutput Info");
+ System.out.format("\t\tGcs Output Directory: %s\n", outputInfo.getGcsOutputDirectory());
+ System.out.format("\t\tBigquery Output Dataset: %s\n", outputInfo.getBigqueryOutputDataset());
+
+ Status status = batchPredictionJobResponse.getError();
+ System.out.println("\tError");
+ System.out.format("\t\tCode: %s\n", status.getCode());
+ System.out.format("\t\tMessage: %s\n", status.getMessage());
+ List details = status.getDetailsList();
+
+ for (Status partialFailure : batchPredictionJobResponse.getPartialFailuresList()) {
+ System.out.println("\tPartial Failure");
+ System.out.format("\t\tCode: %s\n", partialFailure.getCode());
+ System.out.format("\t\tMessage: %s\n", partialFailure.getMessage());
+ List partialFailureDetailsList = partialFailure.getDetailsList();
+ }
+
+ ResourcesConsumed resourcesConsumed = batchPredictionJobResponse.getResourcesConsumed();
+ System.out.println("\tResources Consumed");
+ System.out.format("\t\tReplica Hours: %s\n", resourcesConsumed.getReplicaHours());
+
+ CompletionStats completionStats = batchPredictionJobResponse.getCompletionStats();
+ System.out.println("\tCompletion Stats");
+ System.out.format("\t\tSuccessful Count: %s\n", completionStats.getSuccessfulCount());
+ System.out.format("\t\tFailed Count: %s\n", completionStats.getFailedCount());
+ System.out.format("\t\tIncomplete Count: %s\n", completionStats.getIncompleteCount());
+ }
+ }
+}
+// [END aiplatform_create_batch_prediction_job_video_object_tracking_sample]
diff --git a/aiplatform/snippets/src/main/java/aiplatform/CreateDatasetImageSample.java b/aiplatform/snippets/src/main/java/aiplatform/CreateDatasetImageSample.java
new file mode 100644
index 00000000000..0ce9767c48f
--- /dev/null
+++ b/aiplatform/snippets/src/main/java/aiplatform/CreateDatasetImageSample.java
@@ -0,0 +1,81 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package aiplatform;
+
+// [START aiplatform_create_dataset_image_sample]
+
+import com.google.api.gax.longrunning.OperationFuture;
+import com.google.cloud.aiplatform.v1beta1.CreateDatasetOperationMetadata;
+import com.google.cloud.aiplatform.v1beta1.Dataset;
+import com.google.cloud.aiplatform.v1beta1.DatasetServiceClient;
+import com.google.cloud.aiplatform.v1beta1.DatasetServiceSettings;
+import com.google.cloud.aiplatform.v1beta1.LocationName;
+import java.io.IOException;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
+
+public class CreateDatasetImageSample {
+
+ public static void main(String[] args)
+ throws IOException, InterruptedException, ExecutionException, TimeoutException {
+ // TODO(developer): Replace these variables before running the sample.
+ String project = "YOUR_PROJECT_ID";
+ String datasetDisplayName = "YOUR_DATASET_DISPLAY_NAME";
+ createDatasetImageSample(project, datasetDisplayName);
+ }
+
+ static void createDatasetImageSample(String project, String datasetDisplayName)
+ throws IOException, InterruptedException, ExecutionException, TimeoutException {
+ DatasetServiceSettings datasetServiceSettings =
+ DatasetServiceSettings.newBuilder()
+ .setEndpoint("us-central1-aiplatform.googleapis.com:443")
+ .build();
+
+ // Initialize client that will be used to send requests. This client only needs to be created
+ // once, and can be reused for multiple requests. After completing all of your requests, call
+ // the "close" method on the client to safely clean up any remaining background resources.
+ try (DatasetServiceClient datasetServiceClient =
+ DatasetServiceClient.create(datasetServiceSettings)) {
+ String location = "us-central1";
+ String metadataSchemaUri =
+ "gs://google-cloud-aiplatform/schema/dataset/metadata/image_1.0.0.yaml";
+ LocationName locationName = LocationName.of(project, location);
+ Dataset dataset =
+ Dataset.newBuilder()
+ .setDisplayName(datasetDisplayName)
+ .setMetadataSchemaUri(metadataSchemaUri)
+ .build();
+
+ OperationFuture datasetFuture =
+ datasetServiceClient.createDatasetAsync(locationName, dataset);
+ System.out.format("Operation name: %s\n", datasetFuture.getInitialFuture().get().getName());
+ System.out.println("Waiting for operation to finish...");
+ Dataset datasetResponse = datasetFuture.get(120, TimeUnit.SECONDS);
+
+ System.out.println("Create Image Dataset Response");
+ System.out.format("Name: %s\n", datasetResponse.getName());
+ System.out.format("Display Name: %s\n", datasetResponse.getDisplayName());
+ System.out.format("Metadata Schema Uri: %s\n", datasetResponse.getMetadataSchemaUri());
+ System.out.format("Metadata: %s\n", datasetResponse.getMetadata());
+ System.out.format("Create Time: %s\n", datasetResponse.getCreateTime());
+ System.out.format("Update Time: %s\n", datasetResponse.getUpdateTime());
+ System.out.format("Labels: %s\n", datasetResponse.getLabelsMap());
+ }
+ }
+}
+// [END aiplatform_create_dataset_image_sample]
diff --git a/aiplatform/snippets/src/main/java/aiplatform/CreateDatasetVideoSample.java b/aiplatform/snippets/src/main/java/aiplatform/CreateDatasetVideoSample.java
new file mode 100644
index 00000000000..537525c8136
--- /dev/null
+++ b/aiplatform/snippets/src/main/java/aiplatform/CreateDatasetVideoSample.java
@@ -0,0 +1,81 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package aiplatform;
+
+// [START aiplatform_create_dataset_video_sample]
+
+import com.google.api.gax.longrunning.OperationFuture;
+import com.google.cloud.aiplatform.v1beta1.CreateDatasetOperationMetadata;
+import com.google.cloud.aiplatform.v1beta1.Dataset;
+import com.google.cloud.aiplatform.v1beta1.DatasetServiceClient;
+import com.google.cloud.aiplatform.v1beta1.DatasetServiceSettings;
+import com.google.cloud.aiplatform.v1beta1.LocationName;
+import java.io.IOException;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
+
+public class CreateDatasetVideoSample {
+
+ public static void main(String[] args)
+ throws IOException, InterruptedException, ExecutionException, TimeoutException {
+ // TODO(developer): Replace these variables before running the sample.
+ String project = "YOUR_PROJECT_ID";
+ String datasetVideoDisplayName = "YOUR_DATASET_VIDEO_DISPLAY_NAME";
+ createDatasetSample(datasetVideoDisplayName, project);
+ }
+
+ static void createDatasetSample(String datasetVideoDisplayName, String project)
+ throws IOException, InterruptedException, ExecutionException, TimeoutException {
+ DatasetServiceSettings datasetServiceSettings =
+ DatasetServiceSettings.newBuilder()
+ .setEndpoint("us-central1-aiplatform.googleapis.com:443")
+ .build();
+
+ // Initialize client that will be used to send requests. This client only needs to be created
+ // once, and can be reused for multiple requests. After completing all of your requests, call
+ // the "close" method on the client to safely clean up any remaining background resources.
+ try (DatasetServiceClient datasetServiceClient =
+ DatasetServiceClient.create(datasetServiceSettings)) {
+ String location = "us-central1";
+ String metadataSchemaUri =
+ "gs://google-cloud-aiplatform/schema/dataset/metadata/video_1.0.0.yaml";
+ LocationName locationName = LocationName.of(project, location);
+ Dataset dataset =
+ Dataset.newBuilder()
+ .setDisplayName(datasetVideoDisplayName)
+ .setMetadataSchemaUri(metadataSchemaUri)
+ .build();
+
+ OperationFuture datasetFuture =
+ datasetServiceClient.createDatasetAsync(locationName, dataset);
+ System.out.format("Operation name: %s\n", datasetFuture.getInitialFuture().get().getName());
+ System.out.println("Waiting for operation to finish...");
+ Dataset datasetResponse = datasetFuture.get(300, TimeUnit.SECONDS);
+
+ System.out.println("Create Dataset Video Response");
+ System.out.format("Name: %s\n", datasetResponse.getName());
+ System.out.format("Display Name: %s\n", datasetResponse.getDisplayName());
+ System.out.format("Metadata Schema Uri: %s\n", datasetResponse.getMetadataSchemaUri());
+ System.out.format("Metadata: %s\n", datasetResponse.getMetadata());
+ System.out.format("Create Time: %s\n", datasetResponse.getCreateTime());
+ System.out.format("Update Time: %s\n", datasetResponse.getUpdateTime());
+ System.out.format("Labels: %s\n", datasetResponse.getLabelsMap());
+ }
+ }
+}
+// [END aiplatform_create_dataset_video_sample]
diff --git a/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineImageClassificationSample.java b/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineImageClassificationSample.java
new file mode 100644
index 00000000000..7327cba9b20
--- /dev/null
+++ b/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineImageClassificationSample.java
@@ -0,0 +1,233 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package aiplatform;
+
+// [START aiplatform_create_training_pipeline_image_classification_sample]
+
+import com.google.cloud.aiplatform.v1beta1.DeployedModelRef;
+import com.google.cloud.aiplatform.v1beta1.EnvVar;
+import com.google.cloud.aiplatform.v1beta1.ExplanationMetadata;
+import com.google.cloud.aiplatform.v1beta1.ExplanationParameters;
+import com.google.cloud.aiplatform.v1beta1.ExplanationSpec;
+import com.google.cloud.aiplatform.v1beta1.FilterSplit;
+import com.google.cloud.aiplatform.v1beta1.FractionSplit;
+import com.google.cloud.aiplatform.v1beta1.InputDataConfig;
+import com.google.cloud.aiplatform.v1beta1.LocationName;
+import com.google.cloud.aiplatform.v1beta1.Model;
+import com.google.cloud.aiplatform.v1beta1.Model.ExportFormat;
+import com.google.cloud.aiplatform.v1beta1.ModelContainerSpec;
+import com.google.cloud.aiplatform.v1beta1.PipelineServiceClient;
+import com.google.cloud.aiplatform.v1beta1.PipelineServiceSettings;
+import com.google.cloud.aiplatform.v1beta1.Port;
+import com.google.cloud.aiplatform.v1beta1.PredefinedSplit;
+import com.google.cloud.aiplatform.v1beta1.PredictSchemata;
+import com.google.cloud.aiplatform.v1beta1.SampledShapleyAttribution;
+import com.google.cloud.aiplatform.v1beta1.TimestampSplit;
+import com.google.cloud.aiplatform.v1beta1.TrainingPipeline;
+import com.google.protobuf.Any;
+import com.google.protobuf.Value;
+import com.google.protobuf.util.JsonFormat;
+import com.google.rpc.Status;
+import java.io.IOException;
+import java.util.List;
+
+public class CreateTrainingPipelineImageClassificationSample {
+
+ public static void main(String[] args) throws IOException {
+ // TODO(developer): Replace these variables before running the sample.
+ String trainingPipelineDisplayName = "YOUR_TRAINING_PIPELINE_DISPLAY_NAME";
+ String project = "YOUR_PROJECT_ID";
+ String datasetId = "YOUR_DATASET_ID";
+ String modelDisplayName = "YOUR_MODEL_DISPLAY_NAME";
+ createTrainingPipelineImageClassificationSample(
+ project, trainingPipelineDisplayName, datasetId, modelDisplayName);
+ }
+
+ static void createTrainingPipelineImageClassificationSample(
+ String project, String trainingPipelineDisplayName, String datasetId, String modelDisplayName)
+ throws IOException {
+ PipelineServiceSettings pipelineServiceSettings =
+ PipelineServiceSettings.newBuilder()
+ .setEndpoint("us-central1-aiplatform.googleapis.com:443")
+ .build();
+
+ // Initialize client that will be used to send requests. This client only needs to be created
+ // once, and can be reused for multiple requests. After completing all of your requests, call
+ // the "close" method on the client to safely clean up any remaining background resources.
+ try (PipelineServiceClient pipelineServiceClient =
+ PipelineServiceClient.create(pipelineServiceSettings)) {
+ String location = "us-central1";
+ String trainingTaskDefinition =
+ "gs://google-cloud-aiplatform/schema/trainingjob/definition/"
+ + "automl_image_classification_1.0.0.yaml";
+ LocationName locationName = LocationName.of(project, location);
+
+ String jsonString =
+ "{\"multiLabel\": false, \"modelType\": \"CLOUD\", \"budgetMilliNodeHours\": 8000,"
+ + " \"disableEarlyStopping\": false}";
+ Value.Builder trainingTaskInputs = Value.newBuilder();
+ JsonFormat.parser().merge(jsonString, trainingTaskInputs);
+
+ InputDataConfig trainingInputDataConfig =
+ InputDataConfig.newBuilder().setDatasetId(datasetId).build();
+ Model model = Model.newBuilder().setDisplayName(modelDisplayName).build();
+ TrainingPipeline trainingPipeline =
+ TrainingPipeline.newBuilder()
+ .setDisplayName(trainingPipelineDisplayName)
+ .setTrainingTaskDefinition(trainingTaskDefinition)
+ .setTrainingTaskInputs(trainingTaskInputs)
+ .setInputDataConfig(trainingInputDataConfig)
+ .setModelToUpload(model)
+ .build();
+
+ TrainingPipeline trainingPipelineResponse =
+ pipelineServiceClient.createTrainingPipeline(locationName, trainingPipeline);
+
+ System.out.println("Create Training Pipeline Image Classification Response");
+ System.out.format("Name: %s\n", trainingPipelineResponse.getName());
+ System.out.format("Display Name: %s\n", trainingPipelineResponse.getDisplayName());
+
+ System.out.format(
+ "Training Task Definition %s\n", trainingPipelineResponse.getTrainingTaskDefinition());
+ System.out.format(
+ "Training Task Inputs: %s\n", trainingPipelineResponse.getTrainingTaskInputs());
+ System.out.format(
+ "Training Task Metadata: %s\n", trainingPipelineResponse.getTrainingTaskMetadata());
+ System.out.format("State: %s\n", trainingPipelineResponse.getState());
+
+ System.out.format("Create Time: %s\n", trainingPipelineResponse.getCreateTime());
+ System.out.format("StartTime %s\n", trainingPipelineResponse.getStartTime());
+ System.out.format("End Time: %s\n", trainingPipelineResponse.getEndTime());
+ System.out.format("Update Time: %s\n", trainingPipelineResponse.getUpdateTime());
+ System.out.format("Labels: %s\n", trainingPipelineResponse.getLabelsMap());
+
+ InputDataConfig inputDataConfig = trainingPipelineResponse.getInputDataConfig();
+ System.out.println("Input Data Config");
+ System.out.format("Dataset Id: %s", inputDataConfig.getDatasetId());
+ System.out.format("Annotations Filter: %s\n", inputDataConfig.getAnnotationsFilter());
+
+ FractionSplit fractionSplit = inputDataConfig.getFractionSplit();
+ System.out.println("Fraction Split");
+ System.out.format("Training Fraction: %s\n", fractionSplit.getTrainingFraction());
+ System.out.format("Validation Fraction: %s\n", fractionSplit.getValidationFraction());
+ System.out.format("Test Fraction: %s\n", fractionSplit.getTestFraction());
+
+ FilterSplit filterSplit = inputDataConfig.getFilterSplit();
+ System.out.println("Filter Split");
+ System.out.format("Training Filter: %s\n", filterSplit.getTrainingFilter());
+ System.out.format("Validation Filter: %s\n", filterSplit.getValidationFilter());
+ System.out.format("Test Filter: %s\n", filterSplit.getTestFilter());
+
+ PredefinedSplit predefinedSplit = inputDataConfig.getPredefinedSplit();
+ System.out.println("Predefined Split");
+ System.out.format("Key: %s\n", predefinedSplit.getKey());
+
+ TimestampSplit timestampSplit = inputDataConfig.getTimestampSplit();
+ System.out.println("Timestamp Split");
+ System.out.format("Training Fraction: %s\n", timestampSplit.getTrainingFraction());
+ System.out.format("Validation Fraction: %s\n", timestampSplit.getValidationFraction());
+ System.out.format("Test Fraction: %s\n", timestampSplit.getTestFraction());
+ System.out.format("Key: %s\n", timestampSplit.getKey());
+
+ Model modelResponse = trainingPipelineResponse.getModelToUpload();
+ System.out.println("Model To Upload");
+ System.out.format("Name: %s\n", modelResponse.getName());
+ System.out.format("Display Name: %s\n", modelResponse.getDisplayName());
+ System.out.format("Description: %s\n", modelResponse.getDescription());
+
+ System.out.format("Metadata Schema Uri: %s\n", modelResponse.getMetadataSchemaUri());
+ System.out.format("Metadata: %s\n", modelResponse.getMetadata());
+ System.out.format("Training Pipeline: %s\n", modelResponse.getTrainingPipeline());
+ System.out.format("Artifact Uri: %s\n", modelResponse.getArtifactUri());
+
+ System.out.format(
+ "Supported Deployment Resources Types: %s\n",
+ modelResponse.getSupportedDeploymentResourcesTypesList());
+ System.out.format(
+ "Supported Input Storage Formats: %s\n",
+ modelResponse.getSupportedInputStorageFormatsList());
+ System.out.format(
+ "Supported Output Storage Formats: %s\n",
+ modelResponse.getSupportedOutputStorageFormatsList());
+
+ System.out.format("Create Time: %s\n", modelResponse.getCreateTime());
+ System.out.format("Update Time: %s\n", modelResponse.getUpdateTime());
+ System.out.format("Labels: %sn\n", modelResponse.getLabelsMap());
+
+ PredictSchemata predictSchemata = modelResponse.getPredictSchemata();
+ System.out.println("Predict Schemata");
+ System.out.format("Instance Schema Uri: %s\n", predictSchemata.getInstanceSchemaUri());
+ System.out.format("Parameters Schema Uri: %s\n", predictSchemata.getParametersSchemaUri());
+ System.out.format("Prediction Schema Uri: %s\n", predictSchemata.getPredictionSchemaUri());
+
+ for (ExportFormat exportFormat : modelResponse.getSupportedExportFormatsList()) {
+ System.out.println("Supported Export Format");
+ System.out.format("Id: %s\n", exportFormat.getId());
+ }
+
+ ModelContainerSpec modelContainerSpec = modelResponse.getContainerSpec();
+ System.out.println("Container Spec");
+ System.out.format("Image Uri: %s\n", modelContainerSpec.getImageUri());
+ System.out.format("Command: %s\n", modelContainerSpec.getCommandList());
+ System.out.format("Args: %s\n", modelContainerSpec.getArgsList());
+ System.out.format("Predict Route: %s\n", modelContainerSpec.getPredictRoute());
+ System.out.format("Health Route: %s\n", modelContainerSpec.getHealthRoute());
+
+ for (EnvVar envVar : modelContainerSpec.getEnvList()) {
+ System.out.println("Env");
+ System.out.format("Name: %s\n", envVar.getName());
+ System.out.format("Value: %s\n", envVar.getValue());
+ }
+
+ for (Port port : modelContainerSpec.getPortsList()) {
+ System.out.println("Port");
+ System.out.format("Container Port: %s\n", port.getContainerPort());
+ }
+
+ for (DeployedModelRef deployedModelRef : modelResponse.getDeployedModelsList()) {
+ System.out.println("Deployed Model");
+ System.out.format("Endpoint: %s\n", deployedModelRef.getEndpoint());
+ System.out.format("Deployed Model Id: %s\n", deployedModelRef.getDeployedModelId());
+ }
+
+ ExplanationSpec explanationSpec = modelResponse.getExplanationSpec();
+ System.out.println("Explanation Spec");
+
+ ExplanationParameters explanationParameters = explanationSpec.getParameters();
+ System.out.println("Parameters");
+
+ SampledShapleyAttribution sampledShapleyAttribution =
+ explanationParameters.getSampledShapleyAttribution();
+ System.out.println("Sampled Shapley Attribution");
+ System.out.format("Path Count: %s\n", sampledShapleyAttribution.getPathCount());
+
+ ExplanationMetadata explanationMetadata = explanationSpec.getMetadata();
+ System.out.println("Metadata");
+ System.out.format("Inputs: %s\n", explanationMetadata.getInputsMap());
+ System.out.format("Outputs: %s\n", explanationMetadata.getOutputsMap());
+ System.out.format(
+ "Feature Attributions Schema_uri: %s\n",
+ explanationMetadata.getFeatureAttributionsSchemaUri());
+
+ Status status = trainingPipelineResponse.getError();
+ System.out.println("Error");
+ System.out.format("Code: %s\n", status.getCode());
+ System.out.format("Message: %s\n", status.getMessage());
+ }
+ }
+}
+// [END aiplatform_create_training_pipeline_image_classification_sample]
diff --git a/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineImageObjectDetectionSample.java b/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineImageObjectDetectionSample.java
new file mode 100644
index 00000000000..636ab022418
--- /dev/null
+++ b/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineImageObjectDetectionSample.java
@@ -0,0 +1,233 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package aiplatform;
+
+// [START aiplatform_create_training_pipeline_image_object_detection_sample]
+
+import com.google.cloud.aiplatform.v1beta1.DeployedModelRef;
+import com.google.cloud.aiplatform.v1beta1.EnvVar;
+import com.google.cloud.aiplatform.v1beta1.ExplanationMetadata;
+import com.google.cloud.aiplatform.v1beta1.ExplanationParameters;
+import com.google.cloud.aiplatform.v1beta1.ExplanationSpec;
+import com.google.cloud.aiplatform.v1beta1.FilterSplit;
+import com.google.cloud.aiplatform.v1beta1.FractionSplit;
+import com.google.cloud.aiplatform.v1beta1.InputDataConfig;
+import com.google.cloud.aiplatform.v1beta1.LocationName;
+import com.google.cloud.aiplatform.v1beta1.Model;
+import com.google.cloud.aiplatform.v1beta1.Model.ExportFormat;
+import com.google.cloud.aiplatform.v1beta1.ModelContainerSpec;
+import com.google.cloud.aiplatform.v1beta1.PipelineServiceClient;
+import com.google.cloud.aiplatform.v1beta1.PipelineServiceSettings;
+import com.google.cloud.aiplatform.v1beta1.Port;
+import com.google.cloud.aiplatform.v1beta1.PredefinedSplit;
+import com.google.cloud.aiplatform.v1beta1.PredictSchemata;
+import com.google.cloud.aiplatform.v1beta1.SampledShapleyAttribution;
+import com.google.cloud.aiplatform.v1beta1.TimestampSplit;
+import com.google.cloud.aiplatform.v1beta1.TrainingPipeline;
+import com.google.protobuf.Any;
+import com.google.protobuf.Value;
+import com.google.protobuf.util.JsonFormat;
+import com.google.rpc.Status;
+import java.io.IOException;
+import java.util.List;
+
+public class CreateTrainingPipelineImageObjectDetectionSample {
+
+ public static void main(String[] args) throws IOException {
+ // TODO(developer): Replace these variables before running the sample.
+ String trainingPipelineDisplayName = "YOUR_TRAINING_PIPELINE_DISPLAY_NAME";
+ String project = "YOUR_PROJECT_ID";
+ String datasetId = "YOUR_DATASET_ID";
+ String modelDisplayName = "YOUR_MODEL_DISPLAY_NAME";
+ createTrainingPipelineImageObjectDetectionSample(
+ project, trainingPipelineDisplayName, datasetId, modelDisplayName);
+ }
+
+ static void createTrainingPipelineImageObjectDetectionSample(
+ String project, String trainingPipelineDisplayName, String datasetId, String modelDisplayName)
+ throws IOException {
+ PipelineServiceSettings pipelineServiceSettings =
+ PipelineServiceSettings.newBuilder()
+ .setEndpoint("us-central1-aiplatform.googleapis.com:443")
+ .build();
+
+ // Initialize client that will be used to send requests. This client only needs to be created
+ // once, and can be reused for multiple requests. After completing all of your requests, call
+ // the "close" method on the client to safely clean up any remaining background resources.
+ try (PipelineServiceClient pipelineServiceClient =
+ PipelineServiceClient.create(pipelineServiceSettings)) {
+ String location = "us-central1";
+ String trainingTaskDefinition =
+ "gs://google-cloud-aiplatform/schema/trainingjob/definition/"
+ + "automl_image_object_detection_1.0.0.yaml";
+ LocationName locationName = LocationName.of(project, location);
+
+ String jsonString =
+ "{\"modelType\": \"CLOUD_HIGH_ACCURACY_1\", \"budgetMilliNodeHours\": 20000,"
+ + " \"disableEarlyStopping\": false}";
+ Value.Builder trainingTaskInputs = Value.newBuilder();
+ JsonFormat.parser().merge(jsonString, trainingTaskInputs);
+
+ InputDataConfig trainingInputDataConfig =
+ InputDataConfig.newBuilder().setDatasetId(datasetId).build();
+ Model model = Model.newBuilder().setDisplayName(modelDisplayName).build();
+ TrainingPipeline trainingPipeline =
+ TrainingPipeline.newBuilder()
+ .setDisplayName(trainingPipelineDisplayName)
+ .setTrainingTaskDefinition(trainingTaskDefinition)
+ .setTrainingTaskInputs(trainingTaskInputs)
+ .setInputDataConfig(trainingInputDataConfig)
+ .setModelToUpload(model)
+ .build();
+
+ TrainingPipeline trainingPipelineResponse =
+ pipelineServiceClient.createTrainingPipeline(locationName, trainingPipeline);
+
+ System.out.println("Create Training Pipeline Image Object Detection Response");
+ System.out.format("Name: %s\n", trainingPipelineResponse.getName());
+ System.out.format("Display Name: %s\n", trainingPipelineResponse.getDisplayName());
+
+ System.out.format(
+ "Training Task Definition %s\n", trainingPipelineResponse.getTrainingTaskDefinition());
+ System.out.format(
+ "Training Task Inputs: %s\n", trainingPipelineResponse.getTrainingTaskInputs());
+ System.out.format(
+ "Training Task Metadata: %s\n", trainingPipelineResponse.getTrainingTaskMetadata());
+ System.out.format("State: %s\n", trainingPipelineResponse.getState());
+
+ System.out.format("Create Time: %s\n", trainingPipelineResponse.getCreateTime());
+ System.out.format("StartTime %s\n", trainingPipelineResponse.getStartTime());
+ System.out.format("End Time: %s\n", trainingPipelineResponse.getEndTime());
+ System.out.format("Update Time: %s\n", trainingPipelineResponse.getUpdateTime());
+ System.out.format("Labels: %s\n", trainingPipelineResponse.getLabelsMap());
+
+ InputDataConfig inputDataConfig = trainingPipelineResponse.getInputDataConfig();
+ System.out.println("Input Data Config");
+ System.out.format("Dataset Id: %s", inputDataConfig.getDatasetId());
+ System.out.format("Annotations Filter: %s\n", inputDataConfig.getAnnotationsFilter());
+
+ FractionSplit fractionSplit = inputDataConfig.getFractionSplit();
+ System.out.println("Fraction Split");
+ System.out.format("Training Fraction: %s\n", fractionSplit.getTrainingFraction());
+ System.out.format("Validation Fraction: %s\n", fractionSplit.getValidationFraction());
+ System.out.format("Test Fraction: %s\n", fractionSplit.getTestFraction());
+
+ FilterSplit filterSplit = inputDataConfig.getFilterSplit();
+ System.out.println("Filter Split");
+ System.out.format("Training Filter: %s\n", filterSplit.getTrainingFilter());
+ System.out.format("Validation Filter: %s\n", filterSplit.getValidationFilter());
+ System.out.format("Test Filter: %s\n", filterSplit.getTestFilter());
+
+ PredefinedSplit predefinedSplit = inputDataConfig.getPredefinedSplit();
+ System.out.println("Predefined Split");
+ System.out.format("Key: %s\n", predefinedSplit.getKey());
+
+ TimestampSplit timestampSplit = inputDataConfig.getTimestampSplit();
+ System.out.println("Timestamp Split");
+ System.out.format("Training Fraction: %s\n", timestampSplit.getTrainingFraction());
+ System.out.format("Validation Fraction: %s\n", timestampSplit.getValidationFraction());
+ System.out.format("Test Fraction: %s\n", timestampSplit.getTestFraction());
+ System.out.format("Key: %s\n", timestampSplit.getKey());
+
+ Model modelResponse = trainingPipelineResponse.getModelToUpload();
+ System.out.println("Model To Upload");
+ System.out.format("Name: %s\n", modelResponse.getName());
+ System.out.format("Display Name: %s\n", modelResponse.getDisplayName());
+ System.out.format("Description: %s\n", modelResponse.getDescription());
+
+ System.out.format("Metadata Schema Uri: %s\n", modelResponse.getMetadataSchemaUri());
+ System.out.format("Metadata: %s\n", modelResponse.getMetadata());
+ System.out.format("Training Pipeline: %s\n", modelResponse.getTrainingPipeline());
+ System.out.format("Artifact Uri: %s\n", modelResponse.getArtifactUri());
+
+ System.out.format(
+ "Supported Deployment Resources Types: %s\n",
+ modelResponse.getSupportedDeploymentResourcesTypesList());
+ System.out.format(
+ "Supported Input Storage Formats: %s\n",
+ modelResponse.getSupportedInputStorageFormatsList());
+ System.out.format(
+ "Supported Output Storage Formats: %s\n",
+ modelResponse.getSupportedOutputStorageFormatsList());
+
+ System.out.format("Create Time: %s\n", modelResponse.getCreateTime());
+ System.out.format("Update Time: %s\n", modelResponse.getUpdateTime());
+ System.out.format("Labels: %sn\n", modelResponse.getLabelsMap());
+
+ PredictSchemata predictSchemata = modelResponse.getPredictSchemata();
+ System.out.println("Predict Schemata");
+ System.out.format("Instance Schema Uri: %s\n", predictSchemata.getInstanceSchemaUri());
+ System.out.format("Parameters Schema Uri: %s\n", predictSchemata.getParametersSchemaUri());
+ System.out.format("Prediction Schema Uri: %s\n", predictSchemata.getPredictionSchemaUri());
+
+ for (ExportFormat exportFormat : modelResponse.getSupportedExportFormatsList()) {
+ System.out.println("Supported Export Format");
+ System.out.format("Id: %s\n", exportFormat.getId());
+ }
+
+ ModelContainerSpec modelContainerSpec = modelResponse.getContainerSpec();
+ System.out.println("Container Spec");
+ System.out.format("Image Uri: %s\n", modelContainerSpec.getImageUri());
+ System.out.format("Command: %s\n", modelContainerSpec.getCommandList());
+ System.out.format("Args: %s\n", modelContainerSpec.getArgsList());
+ System.out.format("Predict Route: %s\n", modelContainerSpec.getPredictRoute());
+ System.out.format("Health Route: %s\n", modelContainerSpec.getHealthRoute());
+
+ for (EnvVar envVar : modelContainerSpec.getEnvList()) {
+ System.out.println("Env");
+ System.out.format("Name: %s\n", envVar.getName());
+ System.out.format("Value: %s\n", envVar.getValue());
+ }
+
+ for (Port port : modelContainerSpec.getPortsList()) {
+ System.out.println("Port");
+ System.out.format("Container Port: %s\n", port.getContainerPort());
+ }
+
+ for (DeployedModelRef deployedModelRef : modelResponse.getDeployedModelsList()) {
+ System.out.println("Deployed Model");
+ System.out.format("Endpoint: %s\n", deployedModelRef.getEndpoint());
+ System.out.format("Deployed Model Id: %s\n", deployedModelRef.getDeployedModelId());
+ }
+
+ ExplanationSpec explanationSpec = modelResponse.getExplanationSpec();
+ System.out.println("Explanation Spec");
+
+ ExplanationParameters explanationParameters = explanationSpec.getParameters();
+ System.out.println("Parameters");
+
+ SampledShapleyAttribution sampledShapleyAttribution =
+ explanationParameters.getSampledShapleyAttribution();
+ System.out.println("Sampled Shapley Attribution");
+ System.out.format("Path Count: %s\n", sampledShapleyAttribution.getPathCount());
+
+ ExplanationMetadata explanationMetadata = explanationSpec.getMetadata();
+ System.out.println("Metadata");
+ System.out.format("Inputs: %s\n", explanationMetadata.getInputsMap());
+ System.out.format("Outputs: %s\n", explanationMetadata.getOutputsMap());
+ System.out.format(
+ "Feature Attributions Schema_uri: %s\n",
+ explanationMetadata.getFeatureAttributionsSchemaUri());
+
+ Status status = trainingPipelineResponse.getError();
+ System.out.println("Error");
+ System.out.format("Code: %s\n", status.getCode());
+ System.out.format("Message: %s\n", status.getMessage());
+ }
+ }
+}
+// [END aiplatform_create_training_pipeline_image_object_detection_sample]
diff --git a/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineVideoClassificationSample.java b/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineVideoClassificationSample.java
new file mode 100644
index 00000000000..383e56954af
--- /dev/null
+++ b/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineVideoClassificationSample.java
@@ -0,0 +1,162 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package aiplatform;
+
+// [START aiplatform_create_training_pipeline_video_classification_sample]
+
+import com.google.cloud.aiplatform.v1beta1.FilterSplit;
+import com.google.cloud.aiplatform.v1beta1.FractionSplit;
+import com.google.cloud.aiplatform.v1beta1.InputDataConfig;
+import com.google.cloud.aiplatform.v1beta1.LocationName;
+import com.google.cloud.aiplatform.v1beta1.Model;
+import com.google.cloud.aiplatform.v1beta1.PipelineServiceClient;
+import com.google.cloud.aiplatform.v1beta1.PipelineServiceSettings;
+import com.google.cloud.aiplatform.v1beta1.PredefinedSplit;
+import com.google.cloud.aiplatform.v1beta1.TimestampSplit;
+import com.google.cloud.aiplatform.v1beta1.TrainingPipeline;
+import com.google.protobuf.Any;
+import com.google.protobuf.Value;
+import com.google.rpc.Status;
+import java.io.IOException;
+import java.util.List;
+
+public class CreateTrainingPipelineVideoClassificationSample {
+
+ public static void main(String[] args) throws IOException {
+ // TODO(developer): Replace these variables before running the sample.
+ String videoClassificationDisplayName =
+ "YOUR_TRAINING_PIPELINE_VIDEO_CLASSIFICATION_DISPLAY_NAME";
+ String datasetId = "YOUR_DATASET_ID";
+ String modelDisplayName = "YOUR_MODEL_DISPLAY_NAME";
+ String project = "YOUR_PROJECT_ID";
+ createTrainingPipelineVideoClassification(
+ videoClassificationDisplayName, datasetId, modelDisplayName, project);
+ }
+
+ static void createTrainingPipelineVideoClassification(
+ String videoClassificationDisplayName,
+ String datasetId,
+ String modelDisplayName,
+ String project)
+ throws IOException {
+ PipelineServiceSettings pipelineServiceSettings =
+ PipelineServiceSettings.newBuilder()
+ .setEndpoint("us-central1-aiplatform.googleapis.com:443")
+ .build();
+
+ // Initialize client that will be used to send requests. This client only needs to be created
+ // once, and can be reused for multiple requests. After completing all of your requests, call
+ // the "close" method on the client to safely clean up any remaining background resources.
+ try (PipelineServiceClient pipelineServiceClient =
+ PipelineServiceClient.create(pipelineServiceSettings)) {
+ String location = "us-central1";
+ LocationName locationName = LocationName.of(project, location);
+ String trainingTaskDefinition =
+ "gs://google-cloud-aiplatform/schema/trainingjob/definition/"
+ + "automl_video_classification_1.0.0.yaml";
+
+ InputDataConfig inputDataConfig =
+ InputDataConfig.newBuilder().setDatasetId(datasetId).build();
+ Model model = Model.newBuilder().setDisplayName(modelDisplayName).build();
+
+ TrainingPipeline trainingPipeline =
+ TrainingPipeline.newBuilder()
+ .setDisplayName(videoClassificationDisplayName)
+ .setTrainingTaskDefinition(trainingTaskDefinition)
+ .setTrainingTaskInputs(Value.newBuilder())
+ .setInputDataConfig(inputDataConfig)
+ .setModelToUpload(model)
+ .build();
+
+ TrainingPipeline trainingPipelineResponse =
+ pipelineServiceClient.createTrainingPipeline(locationName, trainingPipeline);
+
+ System.out.println("Create Training Pipeline Video Classification Response");
+ System.out.format("\tName: %s\n", trainingPipelineResponse.getName());
+ System.out.format("\tDisplay Name: %s\n", trainingPipelineResponse.getDisplayName());
+ System.out.format(
+ "\tTraining Task Definition: %s\n", trainingPipelineResponse.getTrainingTaskDefinition());
+ System.out.format(
+ "\tTraining Task Inputs: %s\n", trainingPipelineResponse.getTrainingTaskInputs());
+ System.out.format(
+ "\tTraining Task Metadata: %s\n", trainingPipelineResponse.getTrainingTaskMetadata());
+ System.out.format("\tState: %s\n", trainingPipelineResponse.getState());
+ System.out.format("\tCreate Time: %s\n", trainingPipelineResponse.getCreateTime());
+ System.out.format("\tStart Time: %s\n", trainingPipelineResponse.getStartTime());
+ System.out.format("\tEnd Time: %s\n", trainingPipelineResponse.getEndTime());
+ System.out.format("\tUpdate Time: %s\n", trainingPipelineResponse.getUpdateTime());
+ System.out.format("\tLabels: %s\n", trainingPipelineResponse.getLabelsMap());
+
+ InputDataConfig inputDataConfigResponse = trainingPipelineResponse.getInputDataConfig();
+ System.out.println("\tInput Data Config");
+ System.out.format("\t\tDataset Id: %s\n", inputDataConfigResponse.getDatasetId());
+ System.out.format(
+ "\t\tAnnotations Filter: %s\n", inputDataConfigResponse.getAnnotationsFilter());
+
+ FractionSplit fractionSplit = inputDataConfigResponse.getFractionSplit();
+ System.out.println("\t\tFraction Split");
+ System.out.format("\t\t\tTraining Fraction: %s\n", fractionSplit.getTrainingFraction());
+ System.out.format("\t\t\tValidation Fraction: %s\n", fractionSplit.getValidationFraction());
+ System.out.format("\t\t\tTest Fraction: %s\n", fractionSplit.getTestFraction());
+
+ FilterSplit filterSplit = inputDataConfigResponse.getFilterSplit();
+ System.out.println("\t\tFilter Split");
+ System.out.format("\t\t\tTraining Fraction: %s\n", filterSplit.getTrainingFilter());
+ System.out.format("\t\t\tValidation Fraction: %s\n", filterSplit.getValidationFilter());
+ System.out.format("\t\t\tTest Fraction: %s\n", filterSplit.getTestFilter());
+
+ PredefinedSplit predefinedSplit = inputDataConfigResponse.getPredefinedSplit();
+ System.out.println("\t\tPredefined Split");
+ System.out.format("\t\t\tKey: %s\n", predefinedSplit.getKey());
+
+ TimestampSplit timestampSplit = inputDataConfigResponse.getTimestampSplit();
+ System.out.println("\t\tTimestamp Split");
+ System.out.format("\t\t\tTraining Fraction: %s\n", timestampSplit.getTrainingFraction());
+ System.out.format("\t\t\tValidation Fraction: %s\n", timestampSplit.getValidationFraction());
+ System.out.format("\t\t\tTest Fraction: %s\n", timestampSplit.getTestFraction());
+ System.out.format("\t\t\tKey: %s\n", timestampSplit.getKey());
+
+ Model modelResponse = trainingPipelineResponse.getModelToUpload();
+ System.out.println("\tModel To Upload");
+ System.out.format("\t\tName: %s\n", modelResponse.getName());
+ System.out.format("\t\tDisplay Name: %s\n", modelResponse.getDisplayName());
+ System.out.format("\t\tDescription: %s\n", modelResponse.getDescription());
+ System.out.format("\t\tMetadata Schema Uri: %s\n", modelResponse.getMetadataSchemaUri());
+ System.out.format("\t\tMeta Data: %s\n", modelResponse.getMetadata());
+ System.out.format("\t\tTraining Pipeline: %s\n", modelResponse.getTrainingPipeline());
+ System.out.format("\t\tArtifact Uri: %s\n", modelResponse.getArtifactUri());
+ System.out.format(
+ "\t\tSupported Deployment Resources Types: %s\n",
+ modelResponse.getSupportedDeploymentResourcesTypesList().toString());
+ System.out.format(
+ "\t\tSupported Input Storage Formats: %s\n",
+ modelResponse.getSupportedInputStorageFormatsList().toString());
+ System.out.format(
+ "\t\tSupported Output Storage Formats: %s\n",
+ modelResponse.getSupportedOutputStorageFormatsList().toString());
+ System.out.format("\t\tCreate Time: %s\n", modelResponse.getCreateTime());
+ System.out.format("\t\tUpdate Time: %s\n", modelResponse.getUpdateTime());
+ System.out.format("\t\tLables: %s\n", modelResponse.getLabelsMap());
+
+ Status status = trainingPipelineResponse.getError();
+ System.out.println("\tError");
+ System.out.format("\t\tCode: %s\n", status.getCode());
+ System.out.format("\t\tMessage: %s\n", status.getMessage());
+ }
+ }
+}
+// [END aiplatform_create_training_pipeline_video_classification_sample]
diff --git a/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineVideoObjectTrackingSample.java b/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineVideoObjectTrackingSample.java
new file mode 100644
index 00000000000..d49fcff96dd
--- /dev/null
+++ b/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineVideoObjectTrackingSample.java
@@ -0,0 +1,174 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package aiplatform;
+
+// [START aiplatform_create_training_pipeline_video_object_tracking_sample]
+
+import com.google.cloud.aiplatform.v1beta1.FilterSplit;
+import com.google.cloud.aiplatform.v1beta1.FractionSplit;
+import com.google.cloud.aiplatform.v1beta1.InputDataConfig;
+import com.google.cloud.aiplatform.v1beta1.LocationName;
+import com.google.cloud.aiplatform.v1beta1.Model;
+import com.google.cloud.aiplatform.v1beta1.PipelineServiceClient;
+import com.google.cloud.aiplatform.v1beta1.PipelineServiceSettings;
+import com.google.cloud.aiplatform.v1beta1.PredefinedSplit;
+import com.google.cloud.aiplatform.v1beta1.TimestampSplit;
+import com.google.cloud.aiplatform.v1beta1.TrainingPipeline;
+import com.google.protobuf.Any;
+import com.google.protobuf.Value;
+import com.google.protobuf.util.JsonFormat;
+import com.google.rpc.Status;
+import java.io.IOException;
+import java.util.List;
+
+public class CreateTrainingPipelineVideoObjectTrackingSample {
+
+ public static void main(String[] args) throws IOException {
+ String trainingPipelineVideoObjectTracking =
+ "YOUR_TRAINING_PIPELINE_VIDEO_OBJECT_TRACKING_DISPLAY_NAME";
+ String datasetId = "YOUR_DATASET_ID";
+ String modelDisplayName = "YOUR_MODEL_DISPLAY_NAME";
+ String project = "YOUR_PROJECT_ID";
+ createTrainingPipelineVideoObjectTracking(
+ trainingPipelineVideoObjectTracking, datasetId, modelDisplayName, project);
+ }
+
+ static void createTrainingPipelineVideoObjectTracking(
+ String trainingPipelineVideoObjectTracking,
+ String datasetId,
+ String modelDisplayName,
+ String project)
+ throws IOException {
+ PipelineServiceSettings pipelineServiceSettings =
+ PipelineServiceSettings.newBuilder()
+ .setEndpoint("us-central1-aiplatform.googleapis.com:443")
+ .build();
+
+ // Initialize client that will be used to send requests. This client only needs to be created
+ // once, and can be reused for multiple requests. After completing all of your requests, call
+ // the "close" method on the client to safely clean up any remaining background resources.
+ try (PipelineServiceClient pipelineServiceClient =
+ PipelineServiceClient.create(pipelineServiceSettings)) {
+ String location = "us-central1";
+ String trainingTaskDefinition =
+ "gs://google-cloud-aiplatform/schema/trainingjob/definition/"
+ + "automl_video_object_tracking_1.0.0.yaml";
+ LocationName locationName = LocationName.of(project, location);
+
+ String jsonString = "{\"modelType\": \"CLOUD\"}";
+ Value.Builder trainingTaskInputs = Value.newBuilder();
+ JsonFormat.parser().merge(jsonString, trainingTaskInputs);
+
+ InputDataConfig inputDataConfig =
+ InputDataConfig.newBuilder().setDatasetId(datasetId).build();
+ Model modelToUpload = Model.newBuilder().setDisplayName(modelDisplayName).build();
+ TrainingPipeline trainingPipeline =
+ TrainingPipeline.newBuilder()
+ .setDisplayName(trainingPipelineVideoObjectTracking)
+ .setTrainingTaskDefinition(trainingTaskDefinition)
+ .setTrainingTaskInputs(trainingTaskInputs)
+ .setInputDataConfig(inputDataConfig)
+ .setModelToUpload(modelToUpload)
+ .build();
+
+ TrainingPipeline createTrainingPipelineResponse =
+ pipelineServiceClient.createTrainingPipeline(locationName, trainingPipeline);
+
+ System.out.println("Create Training Pipeline Video Object Tracking Response");
+ System.out.format("Name: %s\n", createTrainingPipelineResponse.getName());
+ System.out.format("Display Name: %s\n", createTrainingPipelineResponse.getDisplayName());
+
+ System.out.format(
+ "Training Task Definition %s\n",
+ createTrainingPipelineResponse.getTrainingTaskDefinition());
+ System.out.format(
+ "Training Task Inputs: %s\n",
+ createTrainingPipelineResponse.getTrainingTaskInputs().toString());
+ System.out.format(
+ "Training Task Metadata: %s\n",
+ createTrainingPipelineResponse.getTrainingTaskMetadata().toString());
+
+ System.out.format("State: %s\n", createTrainingPipelineResponse.getState().toString());
+ System.out.format(
+ "Create Time: %s\n", createTrainingPipelineResponse.getCreateTime().toString());
+ System.out.format("StartTime %s\n", createTrainingPipelineResponse.getStartTime().toString());
+ System.out.format("End Time: %s\n", createTrainingPipelineResponse.getEndTime().toString());
+ System.out.format(
+ "Update Time: %s\n", createTrainingPipelineResponse.getUpdateTime().toString());
+ System.out.format("Labels: %s\n", createTrainingPipelineResponse.getLabelsMap().toString());
+
+ InputDataConfig inputDataConfigResponse = createTrainingPipelineResponse.getInputDataConfig();
+ System.out.println("Input Data config");
+ System.out.format("Dataset Id: %s\n", inputDataConfigResponse.getDatasetId());
+ System.out.format("Annotations Filter: %s\n", inputDataConfigResponse.getAnnotationsFilter());
+
+ FractionSplit fractionSplit = inputDataConfigResponse.getFractionSplit();
+ System.out.println("Fraction split");
+ System.out.format("Training Fraction: %s\n", fractionSplit.getTrainingFraction());
+ System.out.format("Validation Fraction: %s\n", fractionSplit.getValidationFraction());
+ System.out.format("Test Fraction: %s\n", fractionSplit.getTestFraction());
+
+ FilterSplit filterSplit = inputDataConfigResponse.getFilterSplit();
+ System.out.println("Filter Split");
+ System.out.format("Training Filter: %s\n", filterSplit.getTrainingFilter());
+ System.out.format("Validation Filter: %s\n", filterSplit.getValidationFilter());
+ System.out.format("Test Filter: %s\n", filterSplit.getTestFilter());
+
+ PredefinedSplit predefinedSplit = inputDataConfigResponse.getPredefinedSplit();
+ System.out.println("Predefined Split");
+ System.out.format("Key: %s\n", predefinedSplit.getKey());
+
+ TimestampSplit timestampSplit = inputDataConfigResponse.getTimestampSplit();
+ System.out.println("Timestamp Split");
+ System.out.format("Training Fraction: %s\n", timestampSplit.getTrainingFraction());
+ System.out.format("Validation Fraction: %s\n", timestampSplit.getValidationFraction());
+ System.out.format("Test Fraction: %s\n", timestampSplit.getTestFraction());
+ System.out.format("Key: %s\n", timestampSplit.getKey());
+
+ Model modelResponse = createTrainingPipelineResponse.getModelToUpload();
+ System.out.println("Model To Upload");
+ System.out.format("Name: %s\n", modelResponse.getName());
+ System.out.format("Display Name: %s\n", modelResponse.getDisplayName());
+ System.out.format("Description: %s\n", modelResponse.getDescription());
+ System.out.format("Metadata Schema Uri: %s\n", modelResponse.getMetadataSchemaUri());
+ System.out.format("Metadata: %s\n", modelResponse.getMetadata());
+
+ System.out.format("Training Pipeline: %s\n", modelResponse.getTrainingPipeline());
+ System.out.format("Artifact Uri: %s\n", modelResponse.getArtifactUri());
+
+ System.out.format(
+ "Supported Deployment Resources Types: %s\n",
+ modelResponse.getSupportedDeploymentResourcesTypesList().toString());
+ System.out.format(
+ "Supported Input Storage Formats: %s\n",
+ modelResponse.getSupportedInputStorageFormatsList().toString());
+ System.out.format(
+ "Supported Output Storage Formats: %s\n",
+ modelResponse.getSupportedOutputStorageFormatsList().toString());
+
+ System.out.format("Create Time: %s\n", modelResponse.getCreateTime());
+ System.out.format("Update Time: %s\n", modelResponse.getUpdateTime());
+ System.out.format("Labels: %s\n", modelResponse.getLabelsMap());
+
+ Status status = createTrainingPipelineResponse.getError();
+ System.out.println("Error");
+ System.out.format("Code: %s\n", status.getCode());
+ System.out.format("Message: %s\n", status.getMessage());
+ }
+ }
+}
+// [END aiplatform_create_training_pipeline_video_object_tracking_sample]
diff --git a/aiplatform/snippets/src/main/java/aiplatform/DeleteBatchPredictionJobSample.java b/aiplatform/snippets/src/main/java/aiplatform/DeleteBatchPredictionJobSample.java
new file mode 100644
index 00000000000..c128689d78a
--- /dev/null
+++ b/aiplatform/snippets/src/main/java/aiplatform/DeleteBatchPredictionJobSample.java
@@ -0,0 +1,68 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package aiplatform;
+
+// [START aiplatform_delete_batch_prediction_job_sample]
+
+import com.google.api.gax.longrunning.OperationFuture;
+import com.google.cloud.aiplatform.v1beta1.BatchPredictionJobName;
+import com.google.cloud.aiplatform.v1beta1.DeleteOperationMetadata;
+import com.google.cloud.aiplatform.v1beta1.JobServiceClient;
+import com.google.cloud.aiplatform.v1beta1.JobServiceSettings;
+import com.google.protobuf.Empty;
+import java.io.IOException;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
+
+public class DeleteBatchPredictionJobSample {
+
+ public static void main(String[] args)
+ throws IOException, InterruptedException, ExecutionException, TimeoutException {
+ // TODO(developer): Replace these variables before running the sample.
+ String project = "YOUR_PROJECT_ID";
+ String batchPredictionJobId = "YOUR_BATCH_PREDICTION_JOB_ID";
+ deleteBatchPredictionJobSample(project, batchPredictionJobId);
+ }
+
+ static void deleteBatchPredictionJobSample(String project, String batchPredictionJobId)
+ throws IOException, InterruptedException, ExecutionException, TimeoutException {
+ JobServiceSettings jobServiceSettings =
+ JobServiceSettings.newBuilder()
+ .setEndpoint("us-central1-aiplatform.googleapis.com:443")
+ .build();
+
+ // Initialize client that will be used to send requests. This client only needs to be created
+ // once, and can be reused for multiple requests. After completing all of your requests, call
+ // the "close" method on the client to safely clean up any remaining background resources.
+ try (JobServiceClient jobServiceClient = JobServiceClient.create(jobServiceSettings)) {
+ String location = "us-central1";
+
+ BatchPredictionJobName batchPredictionJobName =
+ BatchPredictionJobName.of(project, location, batchPredictionJobId);
+
+ OperationFuture operationFuture =
+ jobServiceClient.deleteBatchPredictionJobAsync(batchPredictionJobName);
+ System.out.format("Operation name: %s\n", operationFuture.getInitialFuture().get().getName());
+ System.out.println("Waiting for operation to finish...");
+ operationFuture.get(300, TimeUnit.SECONDS);
+
+ System.out.println("Deleted Batch Prediction Job.");
+ }
+ }
+}
+// [END aiplatform_delete_batch_prediction_job_sample]
diff --git a/aiplatform/snippets/src/main/java/aiplatform/DeleteDatasetSample.java b/aiplatform/snippets/src/main/java/aiplatform/DeleteDatasetSample.java
index 39ad52d0fdf..a9989b564e1 100644
--- a/aiplatform/snippets/src/main/java/aiplatform/DeleteDatasetSample.java
+++ b/aiplatform/snippets/src/main/java/aiplatform/DeleteDatasetSample.java
@@ -59,7 +59,7 @@ static void deleteDatasetSample(String project, String datasetId)
System.out.format("Operation name: %s\n", operationFuture.getInitialFuture().get().getName());
System.out.println("Waiting for operation to finish...");
operationFuture.get(300, TimeUnit.SECONDS);
-
+
System.out.format("Deleted Dataset.");
}
}
diff --git a/aiplatform/snippets/src/main/java/aiplatform/GetModelEvaluationImageClassificationSample.java b/aiplatform/snippets/src/main/java/aiplatform/GetModelEvaluationImageClassificationSample.java
new file mode 100644
index 00000000000..7c1a3bfad6a
--- /dev/null
+++ b/aiplatform/snippets/src/main/java/aiplatform/GetModelEvaluationImageClassificationSample.java
@@ -0,0 +1,78 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package aiplatform;
+
+// [START aiplatform_get_model_evaluation_image_classification_sample]
+
+import com.google.cloud.aiplatform.v1beta1.Attribution;
+import com.google.cloud.aiplatform.v1beta1.ModelEvaluation;
+import com.google.cloud.aiplatform.v1beta1.ModelEvaluationName;
+import com.google.cloud.aiplatform.v1beta1.ModelExplanation;
+import com.google.cloud.aiplatform.v1beta1.ModelServiceClient;
+import com.google.cloud.aiplatform.v1beta1.ModelServiceSettings;
+import java.io.IOException;
+
+public class GetModelEvaluationImageClassificationSample {
+
+ public static void main(String[] args) throws IOException {
+ // TODO(developer): Replace these variables before running the sample.
+ String project = "YOUR_PROJECT_ID";
+ String modelId = "YOUR_MODEL_ID";
+ String evaluationId = "YOUR_EVALUATION_ID";
+ getModelEvaluationImageClassificationSample(project, modelId, evaluationId);
+ }
+
+ static void getModelEvaluationImageClassificationSample(
+ String project, String modelId, String evaluationId) throws IOException {
+ ModelServiceSettings modelServiceSettings =
+ ModelServiceSettings.newBuilder()
+ .setEndpoint("us-central1-aiplatform.googleapis.com:443")
+ .build();
+
+ // Initialize client that will be used to send requests. This client only needs to be created
+ // once, and can be reused for multiple requests. After completing all of your requests, call
+ // the "close" method on the client to safely clean up any remaining background resources.
+ try (ModelServiceClient modelServiceClient = ModelServiceClient.create(modelServiceSettings)) {
+ String location = "us-central1";
+ ModelEvaluationName modelEvaluationName =
+ ModelEvaluationName.of(project, location, modelId, evaluationId);
+
+ ModelEvaluation modelEvaluation = modelServiceClient.getModelEvaluation(modelEvaluationName);
+
+ System.out.println("Get Model Evaluation Image Classification Response");
+ System.out.format("Model Name: %s\n", modelEvaluation.getName());
+ System.out.format("Metrics Schema Uri: %s\n", modelEvaluation.getMetricsSchemaUri());
+ System.out.format("Metrics: %s\n", modelEvaluation.getMetrics());
+ System.out.format("Create Time: %s\n", modelEvaluation.getCreateTime());
+ System.out.format("Slice Dimensions: %s\n", modelEvaluation.getSliceDimensionsList());
+
+ ModelExplanation modelExplanation = modelEvaluation.getModelExplanation();
+ for (Attribution attribution : modelExplanation.getMeanAttributionsList()) {
+ System.out.println("\t\tMean Attribution");
+ System.out.format(
+ "\t\t\tBaseline Output Value: %s\n", attribution.getBaselineOutputValue());
+ System.out.format(
+ "\t\t\tInstance Output Value: %s\n", attribution.getInstanceOutputValue());
+ System.out.format("\t\t\tFeature Attributions: %s\n", attribution.getFeatureAttributions());
+ System.out.format("\t\t\tOutput Index: %s\n", attribution.getOutputIndexList());
+ System.out.format("\t\t\tOutput Display Name: %s\n", attribution.getOutputDisplayName());
+ System.out.format("\t\t\tApproximation Error: %s\n", attribution.getApproximationError());
+ }
+ }
+ }
+}
+// [END aiplatform_get_model_evaluation_image_classification_sample]
diff --git a/aiplatform/snippets/src/main/java/aiplatform/GetModelEvaluationImageObjectDetectionSample.java b/aiplatform/snippets/src/main/java/aiplatform/GetModelEvaluationImageObjectDetectionSample.java
new file mode 100644
index 00000000000..cd8f7a1cb43
--- /dev/null
+++ b/aiplatform/snippets/src/main/java/aiplatform/GetModelEvaluationImageObjectDetectionSample.java
@@ -0,0 +1,78 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package aiplatform;
+
+// [START aiplatform_get_model_evaluation_image_object_detection_sample]
+
+import com.google.cloud.aiplatform.v1beta1.Attribution;
+import com.google.cloud.aiplatform.v1beta1.ModelEvaluation;
+import com.google.cloud.aiplatform.v1beta1.ModelEvaluationName;
+import com.google.cloud.aiplatform.v1beta1.ModelExplanation;
+import com.google.cloud.aiplatform.v1beta1.ModelServiceClient;
+import com.google.cloud.aiplatform.v1beta1.ModelServiceSettings;
+import java.io.IOException;
+
+public class GetModelEvaluationImageObjectDetectionSample {
+
+ public static void main(String[] args) throws IOException {
+ // TODO(developer): Replace these variables before running the sample.
+ String project = "YOUR_PROJECT_ID";
+ String modelId = "YOUR_MODEL_ID";
+ String evaluationId = "YOUR_EVALUATION_ID";
+ getModelEvaluationImageObjectDetectionSample(project, modelId, evaluationId);
+ }
+
+ static void getModelEvaluationImageObjectDetectionSample(
+ String project, String modelId, String evaluationId) throws IOException {
+ ModelServiceSettings modelServiceSettings =
+ ModelServiceSettings.newBuilder()
+ .setEndpoint("us-central1-aiplatform.googleapis.com:443")
+ .build();
+
+ // Initialize client that will be used to send requests. This client only needs to be created
+ // once, and can be reused for multiple requests. After completing all of your requests, call
+ // the "close" method on the client to safely clean up any remaining background resources.
+ try (ModelServiceClient modelServiceClient = ModelServiceClient.create(modelServiceSettings)) {
+ String location = "us-central1";
+ ModelEvaluationName modelEvaluationName =
+ ModelEvaluationName.of(project, location, modelId, evaluationId);
+
+ ModelEvaluation modelEvaluation = modelServiceClient.getModelEvaluation(modelEvaluationName);
+
+ System.out.println("Get Model Evaluation Image Object Detection Response");
+ System.out.format("\tName: %s\n", modelEvaluation.getName());
+ System.out.format("\tMetrics Schema Uri: %s\n", modelEvaluation.getMetricsSchemaUri());
+ System.out.format("\tMetrics: %s\n", modelEvaluation.getMetrics());
+ System.out.format("\tCreate Time: %s\n", modelEvaluation.getCreateTime());
+ System.out.format("\tSlice Dimensions: %s\n", modelEvaluation.getSliceDimensionsList());
+
+ ModelExplanation modelExplanation = modelEvaluation.getModelExplanation();
+ for (Attribution attribution : modelExplanation.getMeanAttributionsList()) {
+ System.out.println("\t\tMean Attribution");
+ System.out.format(
+ "\t\t\tBaseline Output Value: %s\n", attribution.getBaselineOutputValue());
+ System.out.format(
+ "\t\t\tInstance Output Value: %s\n", attribution.getInstanceOutputValue());
+ System.out.format("\t\t\tFeature Attributions: %s\n", attribution.getFeatureAttributions());
+ System.out.format("\t\t\tOutput Index: %s\n", attribution.getOutputIndexList());
+ System.out.format("\t\t\tOutput Display Name: %s\n", attribution.getOutputDisplayName());
+ System.out.format("\t\t\tApproximation Error: %s\n", attribution.getApproximationError());
+ }
+ }
+ }
+}
+// [END aiplatform_get_model_evaluation_image_object_detection_sample]
diff --git a/aiplatform/snippets/src/main/java/aiplatform/GetModelEvaluationVideoClassificationSample.java b/aiplatform/snippets/src/main/java/aiplatform/GetModelEvaluationVideoClassificationSample.java
new file mode 100644
index 00000000000..5dc3d85ab43
--- /dev/null
+++ b/aiplatform/snippets/src/main/java/aiplatform/GetModelEvaluationVideoClassificationSample.java
@@ -0,0 +1,63 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package aiplatform;
+
+// [START aiplatform_get_model_evaluation_video_classification_sample]
+
+import com.google.cloud.aiplatform.v1beta1.ModelEvaluation;
+import com.google.cloud.aiplatform.v1beta1.ModelEvaluationName;
+import com.google.cloud.aiplatform.v1beta1.ModelServiceClient;
+import com.google.cloud.aiplatform.v1beta1.ModelServiceSettings;
+import java.io.IOException;
+
+public class GetModelEvaluationVideoClassificationSample {
+
+ public static void main(String[] args) throws IOException {
+ // TODO(developer): Replace these variables before running the sample.
+ String project = "YOUR_PROJECT_ID";
+ String modelId = "YOUR_MODEL_ID";
+ String evaluationId = "YOUR_EVALUATION_ID";
+ getModelEvaluationVideoClassification(project, modelId, evaluationId);
+ }
+
+ static void getModelEvaluationVideoClassification(
+ String project, String modelId, String evaluationId) throws IOException {
+ ModelServiceSettings modelServiceSettings =
+ ModelServiceSettings.newBuilder()
+ .setEndpoint("us-central1-aiplatform.googleapis.com:443")
+ .build();
+
+ // Initialize client that will be used to send requests. This client only needs to be created
+ // once, and can be reused for multiple requests. After completing all of your requests, call
+ // the "close" method on the client to safely clean up any remaining background resources.
+ try (ModelServiceClient modelServiceClient = ModelServiceClient.create(modelServiceSettings)) {
+ String location = "us-central1";
+ ModelEvaluationName modelEvaluationName =
+ ModelEvaluationName.of(project, location, modelId, evaluationId);
+
+ ModelEvaluation modelEvaluation = modelServiceClient.getModelEvaluation(modelEvaluationName);
+
+ System.out.println("Get Model Evaluation Video Classification Response");
+ System.out.format("Name: %s\n", modelEvaluation.getName());
+ System.out.format("Metrics Schema Uri: %s\n", modelEvaluation.getMetricsSchemaUri());
+ System.out.format("Metrics: %s\n", modelEvaluation.getMetrics());
+ System.out.format("Create Time: %s\n", modelEvaluation.getCreateTime());
+ System.out.format("Slice Dimensions: %s\n", modelEvaluation.getSliceDimensionsList());
+ }
+ }
+}
+// [END aiplatform_get_model_evaluation_video_classification_sample]
diff --git a/aiplatform/snippets/src/main/java/aiplatform/GetModelEvaluationVideoObjectTrackingSample.java b/aiplatform/snippets/src/main/java/aiplatform/GetModelEvaluationVideoObjectTrackingSample.java
new file mode 100644
index 00000000000..8cd4ccb5905
--- /dev/null
+++ b/aiplatform/snippets/src/main/java/aiplatform/GetModelEvaluationVideoObjectTrackingSample.java
@@ -0,0 +1,63 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package aiplatform;
+
+// [START aiplatform_get_model_evaluation_object_tracking_sample]
+
+import com.google.cloud.aiplatform.v1beta1.ModelEvaluation;
+import com.google.cloud.aiplatform.v1beta1.ModelEvaluationName;
+import com.google.cloud.aiplatform.v1beta1.ModelServiceClient;
+import com.google.cloud.aiplatform.v1beta1.ModelServiceSettings;
+import java.io.IOException;
+
+public class GetModelEvaluationVideoObjectTrackingSample {
+
+ public static void main(String[] args) throws IOException {
+ // TODO(developer): Replace these variables before running the sample.
+ String project = "YOUR_PROJECT_ID";
+ String modelId = "YOUR_MODEL_ID";
+ String evaluationId = "YOUR_EVALUATION_ID";
+ getModelEvaluationVideoObjectTracking(project, modelId, evaluationId);
+ }
+
+ static void getModelEvaluationVideoObjectTracking(
+ String project, String modelId, String evaluationId) throws IOException {
+ ModelServiceSettings modelServiceSettings =
+ ModelServiceSettings.newBuilder()
+ .setEndpoint("us-central1-aiplatform.googleapis.com:443")
+ .build();
+
+ // Initialize client that will be used to send requests. This client only needs to be created
+ // once, and can be reused for multiple requests. After completing all of your requests, call
+ // the "close" method on the client to safely clean up any remaining background resources.
+ try (ModelServiceClient modelServiceClient = ModelServiceClient.create(modelServiceSettings)) {
+ String location = "us-central1";
+ ModelEvaluationName modelEvaluationName =
+ ModelEvaluationName.of(project, location, modelId, evaluationId);
+
+ ModelEvaluation modelEvaluation = modelServiceClient.getModelEvaluation(modelEvaluationName);
+
+ System.out.println("Get Model Evaluation Video Object Tracking Response");
+ System.out.format("Name: %s\n", modelEvaluation.getName());
+ System.out.format("Metrics Schema Uri: %s\n", modelEvaluation.getMetricsSchemaUri());
+ System.out.format("Metrics: %s\n", modelEvaluation.getMetrics());
+ System.out.format("Create Time: %s\n", modelEvaluation.getCreateTime());
+ System.out.format("Slice Dimensions: %s\n", modelEvaluation.getSliceDimensionsList());
+ }
+ }
+}
+// [END aiplatform_get_model_evaluation_object_tracking_sample]
diff --git a/aiplatform/snippets/src/main/java/aiplatform/ImportDataImageClassificationSample.java b/aiplatform/snippets/src/main/java/aiplatform/ImportDataImageClassificationSample.java
new file mode 100644
index 00000000000..04beb8c54b2
--- /dev/null
+++ b/aiplatform/snippets/src/main/java/aiplatform/ImportDataImageClassificationSample.java
@@ -0,0 +1,89 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package aiplatform;
+
+// [START aiplatform_import_data_image_classification_sample]
+
+import com.google.api.gax.longrunning.OperationFuture;
+import com.google.cloud.aiplatform.v1beta1.DatasetName;
+import com.google.cloud.aiplatform.v1beta1.DatasetServiceClient;
+import com.google.cloud.aiplatform.v1beta1.DatasetServiceSettings;
+import com.google.cloud.aiplatform.v1beta1.GcsSource;
+import com.google.cloud.aiplatform.v1beta1.ImportDataConfig;
+import com.google.cloud.aiplatform.v1beta1.ImportDataOperationMetadata;
+import com.google.cloud.aiplatform.v1beta1.ImportDataResponse;
+import java.io.IOException;
+import java.util.Collections;
+import java.util.List;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
+
+public class ImportDataImageClassificationSample {
+
+ public static void main(String[] args)
+ throws IOException, InterruptedException, ExecutionException, TimeoutException {
+ // TODO(developer): Replace these variables before running the sample.
+ String project = "YOUR_PROJECT_ID";
+ String datasetId = "YOUR_DATASET_ID";
+ String gcsSourceUri =
+ "gs://YOUR_GCS_SOURCE_BUCKET/path_to_your_image_source/[file.csv/file.jsonl]";
+ importDataImageClassificationSample(project, datasetId, gcsSourceUri);
+ }
+
+ static void importDataImageClassificationSample(
+ String project, String datasetId, String gcsSourceUri)
+ throws IOException, InterruptedException, ExecutionException, TimeoutException {
+ DatasetServiceSettings datasetServiceSettings =
+ DatasetServiceSettings.newBuilder()
+ .setEndpoint("us-central1-aiplatform.googleapis.com:443")
+ .build();
+
+ // Initialize client that will be used to send requests. This client only needs to be created
+ // once, and can be reused for multiple requests. After completing all of your requests, call
+ // the "close" method on the client to safely clean up any remaining background resources.
+ try (DatasetServiceClient datasetServiceClient =
+ DatasetServiceClient.create(datasetServiceSettings)) {
+ String location = "us-central1";
+ String importSchemaUri =
+ "gs://google-cloud-aiplatform/schema/dataset/ioformat/"
+ + "image_classification_single_label_io_format_1.0.0.yaml";
+
+ GcsSource.Builder gcsSource = GcsSource.newBuilder();
+ gcsSource.addUris(gcsSourceUri);
+ DatasetName datasetName = DatasetName.of(project, location, datasetId);
+
+ List importDataConfigList =
+ Collections.singletonList(
+ ImportDataConfig.newBuilder()
+ .setGcsSource(gcsSource)
+ .setImportSchemaUri(importSchemaUri)
+ .build());
+
+ OperationFuture importDataResponseFuture =
+ datasetServiceClient.importDataAsync(datasetName, importDataConfigList);
+ System.out.format(
+ "Operation name: %s\n", importDataResponseFuture.getInitialFuture().get().getName());
+ System.out.println("Waiting for operation to finish...");
+ ImportDataResponse importDataResponse = importDataResponseFuture.get(300, TimeUnit.SECONDS);
+
+ System.out.format("Import Data Image Classification Response: %s\n",
+ importDataResponse.toString());
+ }
+ }
+}
+// [END aiplatform_import_data_image_classification_sample]
diff --git a/aiplatform/snippets/src/main/java/aiplatform/ImportDataImageObjectDetectionSample.java b/aiplatform/snippets/src/main/java/aiplatform/ImportDataImageObjectDetectionSample.java
new file mode 100644
index 00000000000..ae17cfd3a49
--- /dev/null
+++ b/aiplatform/snippets/src/main/java/aiplatform/ImportDataImageObjectDetectionSample.java
@@ -0,0 +1,88 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package aiplatform;
+
+// [START aiplatform_import_data_image_object_detection_sample]
+
+import com.google.api.gax.longrunning.OperationFuture;
+import com.google.cloud.aiplatform.v1beta1.DatasetName;
+import com.google.cloud.aiplatform.v1beta1.DatasetServiceClient;
+import com.google.cloud.aiplatform.v1beta1.DatasetServiceSettings;
+import com.google.cloud.aiplatform.v1beta1.GcsSource;
+import com.google.cloud.aiplatform.v1beta1.ImportDataConfig;
+import com.google.cloud.aiplatform.v1beta1.ImportDataOperationMetadata;
+import com.google.cloud.aiplatform.v1beta1.ImportDataResponse;
+import java.io.IOException;
+import java.util.Collections;
+import java.util.List;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
+
+public class ImportDataImageObjectDetectionSample {
+
+ public static void main(String[] args)
+ throws IOException, InterruptedException, ExecutionException, TimeoutException {
+ // TODO(developer): Replace these variables before running the sample.
+ String project = "YOUR_PROJECT_ID";
+ String datasetId = "YOUR_DATASET_ID";
+ String gcsSourceUri =
+ "gs://YOUR_GCS_SOURCE_BUCKET/path_to_your_image_source/[file.csv/file.jsonl]";
+ importDataImageObjectDetectionSample(project, datasetId, gcsSourceUri);
+ }
+
+ static void importDataImageObjectDetectionSample(
+ String project, String datasetId, String gcsSourceUri)
+ throws IOException, InterruptedException, ExecutionException, TimeoutException {
+ DatasetServiceSettings datasetServiceSettings =
+ DatasetServiceSettings.newBuilder()
+ .setEndpoint("us-central1-aiplatform.googleapis.com:443")
+ .build();
+
+ // Initialize client that will be used to send requests. This client only needs to be created
+ // once, and can be reused for multiple requests. After completing all of your requests, call
+ // the "close" method on the client to safely clean up any remaining background resources.
+ try (DatasetServiceClient datasetServiceClient =
+ DatasetServiceClient.create(datasetServiceSettings)) {
+ String location = "us-central1";
+ String importSchemaUri =
+ "gs://google-cloud-aiplatform/schema/dataset/ioformat/"
+ + "image_bounding_box_io_format_1.0.0.yaml";
+ GcsSource.Builder gcsSource = GcsSource.newBuilder();
+ gcsSource.addUris(gcsSourceUri);
+ DatasetName datasetName = DatasetName.of(project, location, datasetId);
+
+ List importDataConfigList =
+ Collections.singletonList(
+ ImportDataConfig.newBuilder()
+ .setGcsSource(gcsSource)
+ .setImportSchemaUri(importSchemaUri)
+ .build());
+
+ OperationFuture importDataResponseFuture =
+ datasetServiceClient.importDataAsync(datasetName, importDataConfigList);
+ System.out.format(
+ "Operation name: %s\n", importDataResponseFuture.getInitialFuture().get().getName());
+ System.out.println("Waiting for operation to finish...");
+ ImportDataResponse importDataResponse = importDataResponseFuture.get(300, TimeUnit.SECONDS);
+
+ System.out.format("Import Data Image Object Detection Response: %s\n",
+ importDataResponse.toString());
+ }
+ }
+}
+// [END aiplatform_import_data_image_object_detection_sample]
diff --git a/aiplatform/snippets/src/main/java/aiplatform/ImportDataVideoClassificationSample.java b/aiplatform/snippets/src/main/java/aiplatform/ImportDataVideoClassificationSample.java
new file mode 100644
index 00000000000..4bf2c37f303
--- /dev/null
+++ b/aiplatform/snippets/src/main/java/aiplatform/ImportDataVideoClassificationSample.java
@@ -0,0 +1,89 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package aiplatform;
+
+// [START aiplatform_import_data_video_classification_sample]
+
+import com.google.api.gax.longrunning.OperationFuture;
+import com.google.cloud.aiplatform.v1beta1.DatasetName;
+import com.google.cloud.aiplatform.v1beta1.DatasetServiceClient;
+import com.google.cloud.aiplatform.v1beta1.DatasetServiceSettings;
+import com.google.cloud.aiplatform.v1beta1.GcsSource;
+import com.google.cloud.aiplatform.v1beta1.ImportDataConfig;
+import com.google.cloud.aiplatform.v1beta1.ImportDataOperationMetadata;
+import com.google.cloud.aiplatform.v1beta1.ImportDataResponse;
+import java.io.IOException;
+import java.util.Collections;
+import java.util.List;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
+
+public class ImportDataVideoClassificationSample {
+
+ public static void main(String[] args)
+ throws InterruptedException, ExecutionException, TimeoutException, IOException {
+ // TODO(developer): Replace these variables before running the sample.
+ String gcsSourceUri =
+ "gs://YOUR_GCS_SOURCE_BUCKET/path_to_your_video_source/[file.csv/file.jsonl]";
+ String project = "YOUR_PROJECT_ID";
+ String datasetId = "YOUR_DATASET_ID";
+ importDataVideoClassification(gcsSourceUri, project, datasetId);
+ }
+
+ static void importDataVideoClassification(String gcsSourceUri, String project, String datasetId)
+ throws IOException, ExecutionException, InterruptedException, TimeoutException {
+ DatasetServiceSettings datasetServiceSettings =
+ DatasetServiceSettings.newBuilder()
+ .setEndpoint("us-central1-aiplatform.googleapis.com:443")
+ .build();
+
+ // Initialize client that will be used to send requests. This client only needs to be created
+ // once, and can be reused for multiple requests. After completing all of your requests, call
+ // the "close" method on the client to safely clean up any remaining background resources.
+ try (DatasetServiceClient datasetServiceClient =
+ DatasetServiceClient.create(datasetServiceSettings)) {
+ String location = "us-central1";
+ String importSchemaUri =
+ "gs://google-cloud-aiplatform/schema/dataset/ioformat/"
+ + "video_classification_io_format_1.0.0.yaml";
+
+ GcsSource.Builder gcsSource = GcsSource.newBuilder();
+ gcsSource.addUris(gcsSourceUri);
+
+ DatasetName datasetName = DatasetName.of(project, location, datasetId);
+ List importDataConfigs =
+ Collections.singletonList(
+ ImportDataConfig.newBuilder()
+ .setGcsSource(gcsSource)
+ .setImportSchemaUri(importSchemaUri)
+ .build());
+
+ OperationFuture importDataResponseFuture =
+ datasetServiceClient.importDataAsync(datasetName, importDataConfigs);
+ System.out.format(
+ "Operation name: %s\n", importDataResponseFuture.getInitialFuture().get().getName());
+ System.out.println("Waiting for operation to finish...");
+ ImportDataResponse importDataResponse = importDataResponseFuture.get(1800, TimeUnit.SECONDS);
+
+ System.out.format(
+ "Import Data Video Classification Response: %s\n",
+ importDataResponse.toString());
+ }
+ }
+}
+// [END aiplatform_import_data_video_classification_sample]
diff --git a/aiplatform/snippets/src/main/java/aiplatform/ImportDataVideoObjectTrackingSample.java b/aiplatform/snippets/src/main/java/aiplatform/ImportDataVideoObjectTrackingSample.java
new file mode 100644
index 00000000000..f8a07d91485
--- /dev/null
+++ b/aiplatform/snippets/src/main/java/aiplatform/ImportDataVideoObjectTrackingSample.java
@@ -0,0 +1,86 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package aiplatform;
+
+// [START aiplatform_import_data_video_object_tracking_sample]
+
+import com.google.api.gax.longrunning.OperationFuture;
+import com.google.cloud.aiplatform.v1beta1.DatasetName;
+import com.google.cloud.aiplatform.v1beta1.DatasetServiceClient;
+import com.google.cloud.aiplatform.v1beta1.DatasetServiceSettings;
+import com.google.cloud.aiplatform.v1beta1.GcsSource;
+import com.google.cloud.aiplatform.v1beta1.ImportDataConfig;
+import com.google.cloud.aiplatform.v1beta1.ImportDataOperationMetadata;
+import com.google.cloud.aiplatform.v1beta1.ImportDataResponse;
+import java.io.IOException;
+import java.util.Collections;
+import java.util.List;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
+
+public class ImportDataVideoObjectTrackingSample {
+
+ public static void main(String[] args)
+ throws IOException, ExecutionException, InterruptedException, TimeoutException {
+ String gcsSourceUri =
+ "gs://YOUR_GCS_SOURCE_BUCKET/path_to_your_video_source/[file.csv/file.jsonl]";
+ String project = "YOUR_PROJECT_ID";
+ String datasetId = "YOUR_DATASET_ID";
+ importDataVideObjectTracking(gcsSourceUri, project, datasetId);
+ }
+
+ static void importDataVideObjectTracking(String gcsSourceUri, String project, String datasetId)
+ throws IOException, ExecutionException, InterruptedException, TimeoutException {
+ DatasetServiceSettings datasetServiceSettings =
+ DatasetServiceSettings.newBuilder()
+ .setEndpoint("us-central1-aiplatform.googleapis.com:443")
+ .build();
+
+ // Initialize client that will be used to send requests. This client only needs to be created
+ // once, and can be reused for multiple requests. After completing all of your requests, call
+ // the "close" method on the client to safely clean up any remaining background resources.
+ try (DatasetServiceClient datasetServiceClient =
+ DatasetServiceClient.create(datasetServiceSettings)) {
+ String location = "us-central1";
+ String importSchemaUri =
+ "gs://google-cloud-aiplatform/schema/dataset/ioformat/"
+ + "video_object_tracking_io_format_1.0.0.yaml";
+
+ GcsSource.Builder gcsSource = GcsSource.newBuilder();
+ gcsSource.addUris(gcsSourceUri);
+ DatasetName datasetName = DatasetName.of(project, location, datasetId);
+ List importDataConfigs =
+ Collections.singletonList(
+ ImportDataConfig.newBuilder()
+ .setGcsSource(gcsSource)
+ .setImportSchemaUri(importSchemaUri)
+ .build());
+
+ OperationFuture importDataResponseFuture =
+ datasetServiceClient.importDataAsync(datasetName, importDataConfigs);
+ System.out.format(
+ "Operation name: %s\n", importDataResponseFuture.getInitialFuture().get().getName());
+ System.out.println("Waiting for operation to finish...");
+ ImportDataResponse importDataResponse = importDataResponseFuture.get(300, TimeUnit.SECONDS);
+
+ System.out.format("Import Data Video Object Tracking Response: %s\n",
+ importDataResponse.toString());
+ }
+ }
+}
+// [END aiplatform_import_data_video_object_tracking_sample]
diff --git a/aiplatform/snippets/src/main/java/aiplatform/PredictImageClassificationSample.java b/aiplatform/snippets/src/main/java/aiplatform/PredictImageClassificationSample.java
new file mode 100644
index 00000000000..b63a914005e
--- /dev/null
+++ b/aiplatform/snippets/src/main/java/aiplatform/PredictImageClassificationSample.java
@@ -0,0 +1,84 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package aiplatform;
+
+// [START aiplatform_predict_image_classification_sample]
+
+import com.google.api.client.util.Base64;
+import com.google.cloud.aiplatform.v1beta1.EndpointName;
+import com.google.cloud.aiplatform.v1beta1.PredictResponse;
+import com.google.cloud.aiplatform.v1beta1.PredictionServiceClient;
+import com.google.cloud.aiplatform.v1beta1.PredictionServiceSettings;
+import com.google.protobuf.Value;
+import com.google.protobuf.util.JsonFormat;
+import java.io.IOException;
+import java.nio.charset.StandardCharsets;
+import java.nio.file.Files;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.List;
+
+public class PredictImageClassificationSample {
+
+ public static void main(String[] args) throws IOException {
+ // TODO(developer): Replace these variables before running the sample.
+ String project = "YOUR_PROJECT_ID";
+ String fileName = "YOUR_IMAGE_FILE_PATH";
+ String endpointId = "YOUR_ENDPOINT_ID";
+ predictImageClassification(project, fileName, endpointId);
+ }
+
+ static void predictImageClassification(String project, String fileName, String endpointId)
+ throws IOException {
+ PredictionServiceSettings settings =
+ PredictionServiceSettings.newBuilder()
+ .setEndpoint("us-central1-prediction-aiplatform.googleapis.com:443")
+ .build();
+
+ // Initialize client that will be used to send requests. This client only needs to be created
+ // once, and can be reused for multiple requests. After completing all of your requests, call
+ // the "close" method on the client to safely clean up any remaining background resources.
+ try (PredictionServiceClient predictionServiceClient =
+ PredictionServiceClient.create(settings)) {
+ String location = "us-central1";
+ EndpointName endpointName = EndpointName.of(project, location, endpointId);
+
+ byte[] contents = Base64.encodeBase64(Files.readAllBytes(Paths.get(fileName)));
+ String content = new String(contents, StandardCharsets.UTF_8);
+
+ Value parameter = Value.newBuilder().setNumberValue(0).setNumberValue(5).build();
+
+ String contentDict = "{\"content\": \"" + content + "\"}";
+ Value.Builder instance = Value.newBuilder();
+ JsonFormat.parser().merge(contentDict, instance);
+
+ List instances = new ArrayList<>();
+ instances.add(instance.build());
+
+ PredictResponse predictResponse =
+ predictionServiceClient.predict(endpointName, instances, parameter);
+ System.out.println("Predict Image Classification Response");
+ System.out.format("\tDeployed Model Id: %s\n", predictResponse.getDeployedModelId());
+
+ System.out.println("Predictions");
+ for (Value prediction : predictResponse.getPredictionsList()) {
+ System.out.format("\tPrediction: %s\n", prediction);
+ }
+ }
+ }
+}
+// [END aiplatform_predict_image_classification_sample]
diff --git a/aiplatform/snippets/src/main/java/aiplatform/PredictImageObjectDetectionSample.java b/aiplatform/snippets/src/main/java/aiplatform/PredictImageObjectDetectionSample.java
new file mode 100644
index 00000000000..b7e832871f2
--- /dev/null
+++ b/aiplatform/snippets/src/main/java/aiplatform/PredictImageObjectDetectionSample.java
@@ -0,0 +1,84 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package aiplatform;
+
+// [START aiplatform_predict_image_object_detection_sample]
+
+import com.google.api.client.util.Base64;
+import com.google.cloud.aiplatform.v1beta1.EndpointName;
+import com.google.cloud.aiplatform.v1beta1.PredictResponse;
+import com.google.cloud.aiplatform.v1beta1.PredictionServiceClient;
+import com.google.cloud.aiplatform.v1beta1.PredictionServiceSettings;
+import com.google.protobuf.Value;
+import com.google.protobuf.util.JsonFormat;
+import java.io.IOException;
+import java.nio.charset.StandardCharsets;
+import java.nio.file.Files;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.List;
+
+public class PredictImageObjectDetectionSample {
+
+ public static void main(String[] args) throws IOException {
+ // TODO(developer): Replace these variables before running the sample.
+ String project = "YOUR_PROJECT_ID";
+ String fileName = "YOUR_IMAGE_FILE_PATH";
+ String endpointId = "YOUR_ENDPOINT_ID";
+ predictImageObjectDetection(project, fileName, endpointId);
+ }
+
+ static void predictImageObjectDetection(String project, String fileName, String endpointId)
+ throws IOException {
+ PredictionServiceSettings settings =
+ PredictionServiceSettings.newBuilder()
+ .setEndpoint("us-central1-prediction-aiplatform.googleapis.com:443")
+ .build();
+
+ // Initialize client that will be used to send requests. This client only needs to be created
+ // once, and can be reused for multiple requests. After completing all of your requests, call
+ // the "close" method on the client to safely clean up any remaining background resources.
+ try (PredictionServiceClient predictionServiceClient =
+ PredictionServiceClient.create(settings)) {
+ String location = "us-central1";
+ EndpointName endpointName = EndpointName.of(project, location, endpointId);
+
+ byte[] contents = Base64.encodeBase64(Files.readAllBytes(Paths.get(fileName)));
+ String content = new String(contents, StandardCharsets.UTF_8);
+
+ Value parameter = Value.newBuilder().setNumberValue(0).setNumberValue(5).build();
+
+ String contentDict = "{\"content\": \"" + content + "\"}";
+ Value.Builder instance = Value.newBuilder();
+ JsonFormat.parser().merge(contentDict, instance);
+
+ List instances = new ArrayList<>();
+ instances.add(instance.build());
+
+ PredictResponse predictResponse =
+ predictionServiceClient.predict(endpointName, instances, parameter);
+ System.out.println("Predict Image Object Detection Response");
+ System.out.format("\tDeployed Model Id: %s\n", predictResponse.getDeployedModelId());
+
+ System.out.println("Predictions");
+ for (Value prediction : predictResponse.getPredictionsList()) {
+ System.out.format("\tPrediction: %s\n", prediction);
+ }
+ }
+ }
+}
+// [END aiplatform_predict_image_object_detection_sample]
diff --git a/aiplatform/snippets/src/test/java/aiplatform/CreateBatchPredictionJobVideoClassificationSampleTest.java b/aiplatform/snippets/src/test/java/aiplatform/CreateBatchPredictionJobVideoClassificationSampleTest.java
new file mode 100644
index 00000000000..3fa42715cda
--- /dev/null
+++ b/aiplatform/snippets/src/test/java/aiplatform/CreateBatchPredictionJobVideoClassificationSampleTest.java
@@ -0,0 +1,109 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package aiplatform;
+
+import static com.google.common.truth.Truth.assertThat;
+import static junit.framework.TestCase.assertNotNull;
+
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.io.PrintStream;
+import java.util.UUID;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+public class CreateBatchPredictionJobVideoClassificationSampleTest {
+
+ private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID");
+ private static final String MODEL_ID = System.getenv("BATCH_PREDICTION_VIDEO_CLASS_MODEL_ID");
+ private static final String GCS_SOURCE_URI =
+ "gs://ucaip-samples-test-output/inputs/vcn_40_batch_prediction_input.jsonl";
+ private static final String GCS_DESTINATION_OUTPUT_URI_PREFIX = "gs://ucaip-samples-test-output/";
+ private ByteArrayOutputStream bout;
+ private PrintStream out;
+ private PrintStream originalPrintStream;
+ private String batchPredictionJobId;
+
+ private static void requireEnvVar(String varName) {
+ String errorMessage =
+ String.format("Environment variable '%s' is required to perform these tests.", varName);
+ assertNotNull(errorMessage, System.getenv(varName));
+ }
+
+ @BeforeClass
+ public static void checkRequirements() {
+ requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS");
+ requireEnvVar("UCAIP_PROJECT_ID");
+ requireEnvVar("BATCH_PREDICTION_VIDEO_CLASS_MODEL_ID");
+ }
+
+ @Before
+ public void setUp() {
+ bout = new ByteArrayOutputStream();
+ out = new PrintStream(bout);
+ originalPrintStream = System.out;
+ System.setOut(out);
+ }
+
+ @After
+ public void tearDown()
+ throws InterruptedException, ExecutionException, IOException, TimeoutException {
+ // Cancel the Batch Prediction Job
+ CancelBatchPredictionJobSample.cancelBatchPredictionJobSample(PROJECT, batchPredictionJobId);
+
+ // Assert
+ String cancelResponse = bout.toString();
+ assertThat(cancelResponse).contains("Cancelled the Batch Prediction Job");
+ TimeUnit.MINUTES.sleep(2);
+
+ // Delete the Batch Prediction Job
+ DeleteBatchPredictionJobSample.deleteBatchPredictionJobSample(PROJECT, batchPredictionJobId);
+
+ // Assert
+ String deleteResponse = bout.toString();
+ assertThat(deleteResponse).contains("Deleted Batch");
+ System.out.flush();
+ System.setOut(originalPrintStream);
+ }
+
+ @Test
+ public void testCreateBatchPredictionJobVideoClassificationSample() throws IOException {
+ // Act
+ String batchPredictionDisplayName =
+ String.format(
+ "batch_prediction_video_classification_display_name_%s",
+ UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 26));
+
+ CreateBatchPredictionJobVideoClassificationSample.createBatchPredictionJobVideoClassification(
+ batchPredictionDisplayName,
+ MODEL_ID,
+ GCS_SOURCE_URI,
+ GCS_DESTINATION_OUTPUT_URI_PREFIX,
+ PROJECT);
+
+ // Assert
+ String got = bout.toString();
+ assertThat(got).contains(batchPredictionDisplayName);
+ assertThat(got).contains("Create Batch Prediction Job Video Classification Response");
+ batchPredictionJobId = got.split("Name: ")[1].split("batchPredictionJobs/")[1].split("\n")[0];
+ }
+}
diff --git a/aiplatform/snippets/src/test/java/aiplatform/CreateBatchPredictionJobVideoObjectTrackingSampleTest.java b/aiplatform/snippets/src/test/java/aiplatform/CreateBatchPredictionJobVideoObjectTrackingSampleTest.java
new file mode 100644
index 00000000000..06f934f4960
--- /dev/null
+++ b/aiplatform/snippets/src/test/java/aiplatform/CreateBatchPredictionJobVideoObjectTrackingSampleTest.java
@@ -0,0 +1,109 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package aiplatform;
+
+import static com.google.common.truth.Truth.assertThat;
+import static junit.framework.TestCase.assertNotNull;
+
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.io.PrintStream;
+import java.util.UUID;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+public class CreateBatchPredictionJobVideoObjectTrackingSampleTest {
+
+ private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID");
+ private static final String MODEL_ID = System.getenv("BATCH_PREDICTION_VIDEO_OBJECT_MODEL_ID");
+ private static final String GCS_SOURCE_URI =
+ "gs://ucaip-samples-test-output/inputs/vot_batch_prediction_input.jsonl";
+ private static final String GCS_DESTINATION_OUTPUT_URI_PREFIX = "gs://ucaip-samples-test-output/";
+ private ByteArrayOutputStream bout;
+ private PrintStream out;
+ private PrintStream originalPrintStream;
+ private String batchPredictionJobId;
+
+ private static void requireEnvVar(String varName) {
+ String errorMessage =
+ String.format("Environment variable '%s' is required to perform these tests.", varName);
+ assertNotNull(errorMessage, System.getenv(varName));
+ }
+
+ @BeforeClass
+ public static void checkRequirements() {
+ requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS");
+ requireEnvVar("UCAIP_PROJECT_ID");
+ requireEnvVar("BATCH_PREDICTION_VIDEO_OBJECT_MODEL_ID");
+ }
+
+ @Before
+ public void setUp() {
+ bout = new ByteArrayOutputStream();
+ out = new PrintStream(bout);
+ originalPrintStream = System.out;
+ System.setOut(out);
+ }
+
+ @After
+ public void tearDown()
+ throws InterruptedException, ExecutionException, IOException, TimeoutException {
+ // Cancel the Batch Prediction Job
+ CancelBatchPredictionJobSample.cancelBatchPredictionJobSample(PROJECT, batchPredictionJobId);
+
+ // Assert
+ String cancelResponse = bout.toString();
+ assertThat(cancelResponse).contains("Cancelled the Batch Prediction Job");
+ TimeUnit.MINUTES.sleep(2);
+
+ // Delete the Batch Prediction Job
+ DeleteBatchPredictionJobSample.deleteBatchPredictionJobSample(PROJECT, batchPredictionJobId);
+
+ // Assert
+ String deleteResponse = bout.toString();
+ assertThat(deleteResponse).contains("Deleted Batch");
+ System.out.flush();
+ System.setOut(originalPrintStream);
+ }
+
+ @Test
+ public void testCreateBatchPredictionJobVideoObjectTrackingSample() throws IOException {
+ // Act
+ String batchPredictionDisplayName =
+ String.format(
+ "batch_prediction_video_object_tracking_display_name_%s",
+ UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 26));
+
+ CreateBatchPredictionJobVideoObjectTrackingSample.batchPredictionJobVideoObjectTracking(
+ batchPredictionDisplayName,
+ MODEL_ID,
+ GCS_SOURCE_URI,
+ GCS_DESTINATION_OUTPUT_URI_PREFIX,
+ PROJECT);
+
+ // Assert
+ String got = bout.toString();
+ assertThat(got).contains(batchPredictionDisplayName);
+ assertThat(got).contains("Create Batch Prediction Job Video Object Tracking Response");
+ batchPredictionJobId = got.split("Name: ")[1].split("batchPredictionJobs/")[1].split("\n")[0];
+ }
+}
diff --git a/aiplatform/snippets/src/test/java/aiplatform/CreateDatasetImageSampleTest.java b/aiplatform/snippets/src/test/java/aiplatform/CreateDatasetImageSampleTest.java
new file mode 100644
index 00000000000..f2b95b3d4ee
--- /dev/null
+++ b/aiplatform/snippets/src/test/java/aiplatform/CreateDatasetImageSampleTest.java
@@ -0,0 +1,94 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package aiplatform;
+
+import static com.google.common.truth.Truth.assertThat;
+import static junit.framework.TestCase.assertNotNull;
+
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.io.PrintStream;
+import java.util.UUID;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.TimeoutException;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.BeforeClass;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+@RunWith(JUnit4.class)
+public class CreateDatasetImageSampleTest {
+
+ private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID");
+ private ByteArrayOutputStream bout;
+ private PrintStream out;
+ private PrintStream originalPrintStream;
+ private String datasetId;
+
+ private static void requireEnvVar(String varName) {
+ String errorMessage =
+ String.format("Environment variable '%s' is required to perform these tests.", varName);
+ assertNotNull(errorMessage, System.getenv(varName));
+ }
+
+ @BeforeClass
+ public static void checkRequirements() {
+ requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS");
+ requireEnvVar("UCAIP_PROJECT_ID");
+ }
+
+ @Before
+ public void setUp() {
+ bout = new ByteArrayOutputStream();
+ out = new PrintStream(bout);
+ originalPrintStream = System.out;
+ System.setOut(out);
+ }
+
+ @After
+ public void tearDown()
+ throws InterruptedException, ExecutionException, IOException, TimeoutException {
+ // Delete the created dataset
+ DeleteDatasetSample.deleteDatasetSample(PROJECT, datasetId);
+
+ // Assert
+ String deleteResponse = bout.toString();
+ assertThat(deleteResponse).contains("Deleted Dataset");
+ System.out.flush();
+ System.setOut(originalPrintStream);
+ }
+
+ @Test
+ public void testCreateDatasetSample()
+ throws IOException, InterruptedException, ExecutionException, TimeoutException {
+ // Act
+ String datasetDisplayName =
+ String.format(
+ "temp_create_dataset_test_%s",
+ UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 26));
+
+ CreateDatasetImageSample.createDatasetImageSample(PROJECT, datasetDisplayName);
+
+ // Assert
+ String got = bout.toString();
+ assertThat(got).contains(datasetDisplayName);
+ assertThat(got).contains("Create Image Dataset Response");
+ datasetId = got.split("Name: ")[1].split("datasets/")[1].split("\n")[0];
+ }
+}
diff --git a/aiplatform/snippets/src/test/java/aiplatform/CreateDatasetVideoSampleTest.java b/aiplatform/snippets/src/test/java/aiplatform/CreateDatasetVideoSampleTest.java
new file mode 100644
index 00000000000..b979692fa0d
--- /dev/null
+++ b/aiplatform/snippets/src/test/java/aiplatform/CreateDatasetVideoSampleTest.java
@@ -0,0 +1,95 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package aiplatform;
+
+import static com.google.common.truth.Truth.assertThat;
+import static junit.framework.TestCase.assertNotNull;
+
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.io.PrintStream;
+import java.util.UUID;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.TimeoutException;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.BeforeClass;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+@RunWith(JUnit4.class)
+public class CreateDatasetVideoSampleTest {
+
+ private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID");
+ private ByteArrayOutputStream bout;
+ private PrintStream out;
+ private PrintStream originalPrintStream;
+
+ private String datasetId;
+
+ private static void requireEnvVar(String varName) {
+ String errorMessage =
+ String.format("Environment variable '%s' is required to perform these tests.", varName);
+ assertNotNull(errorMessage, System.getenv(varName));
+ }
+
+ @BeforeClass
+ public static void checkRequirements() {
+ requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS");
+ requireEnvVar("UCAIP_PROJECT_ID");
+ }
+
+ @Before
+ public void setUp() {
+ bout = new ByteArrayOutputStream();
+ out = new PrintStream(bout);
+ originalPrintStream = System.out;
+ System.setOut(out);
+ }
+
+ @After
+ public void tearDown()
+ throws InterruptedException, ExecutionException, IOException, TimeoutException {
+ // Delete the created dataset
+ DeleteDatasetSample.deleteDatasetSample(PROJECT, datasetId);
+
+ // Assert
+ String deleteResponse = bout.toString();
+ assertThat(deleteResponse).contains("Deleted Dataset");
+ System.out.flush();
+ System.setOut(originalPrintStream);
+ }
+
+ @Test
+ public void testCreateDatasetVideoSample()
+ throws IOException, InterruptedException, ExecutionException, TimeoutException {
+ // Act
+ String displayName =
+ String.format(
+ "temp_create_dataset_video_test_%s",
+ UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 26));
+
+ CreateDatasetVideoSample.createDatasetSample(displayName, PROJECT);
+
+ // Assert
+ String got = bout.toString();
+ assertThat(got).contains(displayName);
+ assertThat(got).contains("Create Dataset Video Response");
+ datasetId = got.split("Name: ")[1].split("datasets/")[1].split("\n")[0];
+ }
+}
diff --git a/aiplatform/snippets/src/test/java/aiplatform/CreateTrainingPipelineImageClassificationSampleTest.java b/aiplatform/snippets/src/test/java/aiplatform/CreateTrainingPipelineImageClassificationSampleTest.java
new file mode 100644
index 00000000000..747c9117fb6
--- /dev/null
+++ b/aiplatform/snippets/src/test/java/aiplatform/CreateTrainingPipelineImageClassificationSampleTest.java
@@ -0,0 +1,111 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package aiplatform;
+
+import static com.google.common.truth.Truth.assertThat;
+import static junit.framework.TestCase.assertNotNull;
+
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.io.PrintStream;
+import java.util.UUID;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.BeforeClass;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+@RunWith(JUnit4.class)
+public class CreateTrainingPipelineImageClassificationSampleTest {
+
+ private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID");
+ private static final String DATASET_ID =
+ System.getenv("TRAINING_PIPELINE_IMAGE_CLASS_DATASET_ID");
+ private ByteArrayOutputStream bout;
+ private PrintStream out;
+ private PrintStream originalPrintStream;
+ private String trainingPipelineId;
+
+ private static void requireEnvVar(String varName) {
+ String errorMessage =
+ String.format("Environment variable '%s' is required to perform these tests.", varName);
+ assertNotNull(errorMessage, System.getenv(varName));
+ }
+
+ @BeforeClass
+ public static void checkRequirements() {
+ requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS");
+ requireEnvVar("UCAIP_PROJECT_ID");
+ requireEnvVar("TRAINING_PIPELINE_IMAGE_CLASS_DATASET_ID");
+ }
+
+ @Before
+ public void setUp() {
+ bout = new ByteArrayOutputStream();
+ out = new PrintStream(bout);
+ originalPrintStream = System.out;
+ System.setOut(out);
+ }
+
+ @After
+ public void tearDown()
+ throws InterruptedException, ExecutionException, IOException, TimeoutException {
+ // Cancel the Training Pipeline
+ CancelTrainingPipelineSample.cancelTrainingPipelineSample(PROJECT, trainingPipelineId);
+
+ // Assert
+ String cancelResponse = bout.toString();
+ assertThat(cancelResponse).contains("Cancelled the Training Pipeline");
+ TimeUnit.MINUTES.sleep(2);
+
+ // Delete the Training Pipeline
+ DeleteTrainingPipelineSample.deleteTrainingPipelineSample(PROJECT, trainingPipelineId);
+
+ // Assert
+ String deleteResponse = bout.toString();
+ assertThat(deleteResponse).contains("Deleted Training Pipeline.");
+ System.out.flush();
+ System.setOut(originalPrintStream);
+ }
+
+ @Test
+ public void testCreateTrainingPipelineImageClassificationSample() throws IOException {
+ // Act
+ String trainingPipelineDisplayName =
+ String.format(
+ "temp_create_training_pipeline_test_%s",
+ UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 26));
+
+ String modelDisplayName =
+ String.format(
+ "temp_create_training_pipeline_model_test_%s",
+ UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 26));
+
+ CreateTrainingPipelineImageClassificationSample.createTrainingPipelineImageClassificationSample(
+ PROJECT, trainingPipelineDisplayName, DATASET_ID, modelDisplayName);
+
+ // Assert
+ String got = bout.toString();
+ assertThat(got).contains(DATASET_ID);
+ assertThat(got).contains("Create Training Pipeline Image Classification Response");
+ trainingPipelineId = got.split("Name: ")[1].split("trainingPipelines/")[1].split("\n")[0];
+ }
+}
diff --git a/aiplatform/snippets/src/test/java/aiplatform/CreateTrainingPipelineImageObjectDetectionSampleTest.java b/aiplatform/snippets/src/test/java/aiplatform/CreateTrainingPipelineImageObjectDetectionSampleTest.java
new file mode 100644
index 00000000000..c4295cb9440
--- /dev/null
+++ b/aiplatform/snippets/src/test/java/aiplatform/CreateTrainingPipelineImageObjectDetectionSampleTest.java
@@ -0,0 +1,109 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package aiplatform;
+
+import static com.google.common.truth.Truth.assertThat;
+import static junit.framework.TestCase.assertNotNull;
+
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.io.PrintStream;
+import java.util.UUID;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.BeforeClass;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+@RunWith(JUnit4.class)
+public class CreateTrainingPipelineImageObjectDetectionSampleTest {
+
+ private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID");
+ private static final String DATASET_ID =
+ System.getenv("TRAINING_PIPELINE_IMAGE_OBJECT_DETECT_DATASET_ID");
+ private ByteArrayOutputStream bout;
+ private PrintStream out;
+ private PrintStream originalPrintStream;
+ private String trainingPipelineId;
+
+ private static void requireEnvVar(String varName) {
+ String errorMessage =
+ String.format("Environment variable '%s' is required to perform these tests.", varName);
+ assertNotNull(errorMessage, System.getenv(varName));
+ }
+
+ @BeforeClass
+ public static void checkRequirements() {
+ requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS");
+ requireEnvVar("UCAIP_PROJECT_ID");
+ requireEnvVar("TRAINING_PIPELINE_IMAGE_OBJECT_DETECT_DATASET_ID");
+ }
+
+ @Before
+ public void setUp() {
+ bout = new ByteArrayOutputStream();
+ out = new PrintStream(bout);
+ originalPrintStream = System.out;
+ System.setOut(out);
+ }
+
+ @After
+ public void tearDown()
+ throws InterruptedException, ExecutionException, IOException, TimeoutException {
+ // Cancel the Training Pipeline
+ CancelTrainingPipelineSample.cancelTrainingPipelineSample(PROJECT, trainingPipelineId);
+
+ // Assert
+ String cancelResponse = bout.toString();
+ assertThat(cancelResponse).contains("Cancelled the Training Pipeline");
+ TimeUnit.MINUTES.sleep(2);
+
+ // Delete the Training Pipeline
+ DeleteTrainingPipelineSample.deleteTrainingPipelineSample(PROJECT, trainingPipelineId);
+
+ // Assert
+ String deleteResponse = bout.toString();
+ assertThat(deleteResponse).contains("Deleted Training Pipeline.");
+ System.out.flush();
+ System.setOut(originalPrintStream);
+ }
+
+ @Test
+ public void testCreateTrainingPipelineImageObjectDetectionSample() throws IOException {
+ String tempUuid = UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 26);
+ // Act
+ String trainingPipelineDisplayName =
+ String.format("temp_create_training_pipeline_test_%s", tempUuid);
+
+ String modelDisplayName =
+ String.format("temp_create_training_pipeline_model_test_%s", tempUuid);
+
+ CreateTrainingPipelineImageObjectDetectionSample
+ .createTrainingPipelineImageObjectDetectionSample(
+ PROJECT, trainingPipelineDisplayName, DATASET_ID, modelDisplayName);
+
+ // Assert
+ String got = bout.toString();
+ assertThat(got).contains(DATASET_ID);
+ assertThat(got).contains("Create Training Pipeline Image Object Detection Response");
+ trainingPipelineId = got.split("Name: ")[1].split("trainingPipelines/")[1].split("\n")[0];
+ }
+}
diff --git a/aiplatform/snippets/src/test/java/aiplatform/CreateTrainingPipelineVideoClassificationSampleTest.java b/aiplatform/snippets/src/test/java/aiplatform/CreateTrainingPipelineVideoClassificationSampleTest.java
new file mode 100644
index 00000000000..d58fd4fc63b
--- /dev/null
+++ b/aiplatform/snippets/src/test/java/aiplatform/CreateTrainingPipelineVideoClassificationSampleTest.java
@@ -0,0 +1,108 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package aiplatform;
+
+import static com.google.common.truth.Truth.assertThat;
+import static junit.framework.TestCase.assertNotNull;
+
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.io.PrintStream;
+import java.util.UUID;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+public class CreateTrainingPipelineVideoClassificationSampleTest {
+
+ private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID");
+ private static final String DATASET_ID =
+ System.getenv("TRAINING_PIPELINE_VIDEO_CLASS_DATASET_ID");
+ private ByteArrayOutputStream bout;
+ private PrintStream out;
+ private PrintStream originalPrintStream;
+ private String trainingPipelineId;
+
+ private static void requireEnvVar(String varName) {
+ String errorMessage =
+ String.format("Environment variable '%s' is required to perform these tests.", varName);
+ assertNotNull(errorMessage, System.getenv(varName));
+ }
+
+ @BeforeClass
+ public static void checkRequirements() {
+ requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS");
+ requireEnvVar("UCAIP_PROJECT_ID");
+ requireEnvVar("TRAINING_PIPELINE_VIDEO_CLASS_DATASET_ID");
+ }
+
+ @Before
+ public void setUp() {
+ bout = new ByteArrayOutputStream();
+ out = new PrintStream(bout);
+ originalPrintStream = System.out;
+ System.setOut(out);
+ }
+
+ @After
+ public void tearDown()
+ throws InterruptedException, ExecutionException, IOException, TimeoutException {
+ // Cancel the Training Pipeline
+ CancelTrainingPipelineSample.cancelTrainingPipelineSample(PROJECT, trainingPipelineId);
+
+ // Assert
+ String cancelResponse = bout.toString();
+ assertThat(cancelResponse).contains("Cancelled the Training Pipeline");
+ TimeUnit.MINUTES.sleep(2);
+
+ // Delete the Training Pipeline
+ DeleteTrainingPipelineSample.deleteTrainingPipelineSample(PROJECT, trainingPipelineId);
+
+ // Assert
+ String deleteResponse = bout.toString();
+ assertThat(deleteResponse).contains("Deleted Training Pipeline.");
+ System.out.flush();
+ System.setOut(originalPrintStream);
+ }
+
+ @Test
+ public void testCreateTrainingPipelineVideoClassificationSample() throws IOException {
+ // Act
+ String trainingPipelineDisplayName =
+ String.format(
+ "temp_create_training_pipeline_video_classification_test_%s",
+ UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 26));
+
+ String modelDisplayName =
+ String.format(
+ "temp_create_training_pipeline_video_classification_model_test_%s",
+ UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 26));
+
+ CreateTrainingPipelineVideoClassificationSample.createTrainingPipelineVideoClassification(
+ trainingPipelineDisplayName, DATASET_ID, modelDisplayName, PROJECT);
+
+ // Assert
+ String got = bout.toString();
+ assertThat(got).contains(DATASET_ID);
+ assertThat(got).contains("Create Training Pipeline Video Classification Response");
+ trainingPipelineId = got.split("Name: ")[1].split("trainingPipelines/")[1].split("\n")[0];
+ }
+}
diff --git a/aiplatform/snippets/src/test/java/aiplatform/CreateTrainingPipelineVideoObjectTrackingSampleTest.java b/aiplatform/snippets/src/test/java/aiplatform/CreateTrainingPipelineVideoObjectTrackingSampleTest.java
new file mode 100644
index 00000000000..010dcc07526
--- /dev/null
+++ b/aiplatform/snippets/src/test/java/aiplatform/CreateTrainingPipelineVideoObjectTrackingSampleTest.java
@@ -0,0 +1,108 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package aiplatform;
+
+import static com.google.common.truth.Truth.assertThat;
+import static junit.framework.TestCase.assertNotNull;
+
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.io.PrintStream;
+import java.util.UUID;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+public class CreateTrainingPipelineVideoObjectTrackingSampleTest {
+
+ private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID");
+ private static final String DATASET_ID =
+ System.getenv("TRAINING_PIPELINE_VIDEO_OBJECT_DETECT_DATASET_ID");
+ private ByteArrayOutputStream bout;
+ private PrintStream out;
+ private PrintStream originalPrintStream;
+ private String trainingPipelineId;
+
+ private static void requireEnvVar(String varName) {
+ String errorMessage =
+ String.format("Environment variable '%s' is required to perform these tests.", varName);
+ assertNotNull(errorMessage, System.getenv(varName));
+ }
+
+ @BeforeClass
+ public static void checkRequirements() {
+ requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS");
+ requireEnvVar("UCAIP_PROJECT_ID");
+ requireEnvVar("TRAINING_PIPELINE_VIDEO_OBJECT_DETECT_DATASET_ID");
+ }
+
+ @Before
+ public void setUp() {
+ bout = new ByteArrayOutputStream();
+ out = new PrintStream(bout);
+ originalPrintStream = System.out;
+ System.setOut(out);
+ }
+
+ @After
+ public void tearDown()
+ throws InterruptedException, ExecutionException, IOException, TimeoutException {
+ // Cancel the Training Pipeline
+ CancelTrainingPipelineSample.cancelTrainingPipelineSample(PROJECT, trainingPipelineId);
+
+ // Assert
+ String cancelResponse = bout.toString();
+ assertThat(cancelResponse).contains("Cancelled the Training Pipeline");
+ TimeUnit.MINUTES.sleep(2);
+
+ // Delete the Training Pipeline
+ DeleteTrainingPipelineSample.deleteTrainingPipelineSample(PROJECT, trainingPipelineId);
+
+ // Assert
+ String deleteResponse = bout.toString();
+ assertThat(deleteResponse).contains("Deleted Training Pipeline.");
+ System.out.flush();
+ System.setOut(originalPrintStream);
+ }
+
+ @Test
+ public void testCreateTrainingPipelineVideoObjectTrackingSample() throws IOException {
+ // Act
+ String trainingPipelineDisplayName =
+ String.format(
+ "temp_create_training_pipeline_video_object_tracking_test_%s",
+ UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 26));
+
+ String modelDisplayName =
+ String.format(
+ "temp_create_training_pipeline_video_object_tracking_model_test_%s",
+ UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 26));
+
+ CreateTrainingPipelineVideoObjectTrackingSample.createTrainingPipelineVideoObjectTracking(
+ trainingPipelineDisplayName, DATASET_ID, modelDisplayName, PROJECT);
+
+ // Assert
+ String got = bout.toString();
+ assertThat(got).contains(DATASET_ID);
+ assertThat(got).contains("Create Training Pipeline Video Object Tracking Response");
+ trainingPipelineId = got.split("Name: ")[1].split("trainingPipelines/")[1].split("\n")[0];
+ }
+}
diff --git a/aiplatform/snippets/src/test/java/aiplatform/GetModelEvaluationImageClassificationSampleTest.java b/aiplatform/snippets/src/test/java/aiplatform/GetModelEvaluationImageClassificationSampleTest.java
new file mode 100644
index 00000000000..228fa605577
--- /dev/null
+++ b/aiplatform/snippets/src/test/java/aiplatform/GetModelEvaluationImageClassificationSampleTest.java
@@ -0,0 +1,81 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package aiplatform;
+
+import static com.google.common.truth.Truth.assertThat;
+import static junit.framework.TestCase.assertNotNull;
+
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.io.PrintStream;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.BeforeClass;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+@RunWith(JUnit4.class)
+public class GetModelEvaluationImageClassificationSampleTest {
+
+ private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID");
+ private static final String MODEL_ID = System.getenv("IMAGE_CLASS_MODEL_ID");
+ private static final String EVALUATION_ID = System.getenv("IMAGE_CLASS_EVALUATION_ID");
+ private ByteArrayOutputStream bout;
+ private PrintStream out;
+ private PrintStream originalPrintStream;
+
+ private static void requireEnvVar(String varName) {
+ String errorMessage =
+ String.format("Environment variable '%s' is required to perform these tests.", varName);
+ assertNotNull(errorMessage, System.getenv(varName));
+ }
+
+ @BeforeClass
+ public static void checkRequirements() {
+ requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS");
+ requireEnvVar("UCAIP_PROJECT_ID");
+ requireEnvVar("IMAGE_CLASS_MODEL_ID");
+ requireEnvVar("IMAGE_CLASS_EVALUATION_ID");
+ }
+
+ @Before
+ public void setUp() {
+ bout = new ByteArrayOutputStream();
+ out = new PrintStream(bout);
+ originalPrintStream = System.out;
+ System.setOut(out);
+ }
+
+ @After
+ public void tearDown() {
+ System.out.flush();
+ System.setOut(originalPrintStream);
+ }
+
+ @Test
+ public void testGetModelEvaluationImageClassificationSample() throws IOException {
+ // Act
+ GetModelEvaluationImageClassificationSample.getModelEvaluationImageClassificationSample(
+ PROJECT, MODEL_ID, EVALUATION_ID);
+
+ // Assert
+ String got = bout.toString();
+ assertThat(got).contains(MODEL_ID);
+ assertThat(got).contains("Get Model Evaluation Image Classification Response");
+ }
+}
diff --git a/aiplatform/snippets/src/test/java/aiplatform/GetModelEvaluationImageObjectDetectionSampleTest.java b/aiplatform/snippets/src/test/java/aiplatform/GetModelEvaluationImageObjectDetectionSampleTest.java
new file mode 100644
index 00000000000..b78ec23a61d
--- /dev/null
+++ b/aiplatform/snippets/src/test/java/aiplatform/GetModelEvaluationImageObjectDetectionSampleTest.java
@@ -0,0 +1,81 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package aiplatform;
+
+import static com.google.common.truth.Truth.assertThat;
+import static junit.framework.TestCase.assertNotNull;
+
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.io.PrintStream;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.BeforeClass;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+@RunWith(JUnit4.class)
+public class GetModelEvaluationImageObjectDetectionSampleTest {
+
+ private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID");
+ private static final String MODEL_ID = System.getenv("IMAGE_OBJECT_DETECT_MODEL_ID");
+ private static final String EVALUATION_ID = System.getenv("IMAGE_OBJECT_DETECT_EVALUATION_ID");
+ private ByteArrayOutputStream bout;
+ private PrintStream out;
+ private PrintStream originalPrintStream;
+
+ private static void requireEnvVar(String varName) {
+ String errorMessage =
+ String.format("Environment variable '%s' is required to perform these tests.", varName);
+ assertNotNull(errorMessage, System.getenv(varName));
+ }
+
+ @BeforeClass
+ public static void checkRequirements() {
+ requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS");
+ requireEnvVar("UCAIP_PROJECT_ID");
+ requireEnvVar("IMAGE_OBJECT_DETECT_MODEL_ID");
+ requireEnvVar("IMAGE_OBJECT_DETECT_EVALUATION_ID");
+ }
+
+ @Before
+ public void setUp() {
+ bout = new ByteArrayOutputStream();
+ out = new PrintStream(bout);
+ originalPrintStream = System.out;
+ System.setOut(out);
+ }
+
+ @After
+ public void tearDown() {
+ System.out.flush();
+ System.setOut(originalPrintStream);
+ }
+
+ @Test
+ public void testGetModelEvaluationImageObjectDetectionSample() throws IOException {
+ // Act
+ GetModelEvaluationImageObjectDetectionSample.getModelEvaluationImageObjectDetectionSample(
+ PROJECT, MODEL_ID, EVALUATION_ID);
+
+ // Assert
+ String got = bout.toString();
+ assertThat(got).contains(MODEL_ID);
+ assertThat(got).contains("Get Model Evaluation Image Object Detection Response");
+ }
+}
diff --git a/aiplatform/snippets/src/test/java/aiplatform/GetModelEvaluationVideoClassificationSampleTest.java b/aiplatform/snippets/src/test/java/aiplatform/GetModelEvaluationVideoClassificationSampleTest.java
new file mode 100644
index 00000000000..4347485ef86
--- /dev/null
+++ b/aiplatform/snippets/src/test/java/aiplatform/GetModelEvaluationVideoClassificationSampleTest.java
@@ -0,0 +1,78 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package aiplatform;
+
+import static com.google.common.truth.Truth.assertThat;
+import static junit.framework.TestCase.assertNotNull;
+
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.io.PrintStream;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+public class GetModelEvaluationVideoClassificationSampleTest {
+
+ private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID");
+ private static final String MODEL_ID = System.getenv("VIDEO_CLASS_MODEL_ID");
+ private static final String EVALUATION_ID = System.getenv("VIDEO_CLASS_EVALUATION_ID");
+ private ByteArrayOutputStream bout;
+ private PrintStream out;
+ private PrintStream originalPrintStream;
+
+ private static void requireEnvVar(String varName) {
+ String errorMessage =
+ String.format("Environment variable '%s' is required to perform these tests.", varName);
+ assertNotNull(errorMessage, System.getenv(varName));
+ }
+
+ @BeforeClass
+ public static void checkRequirements() {
+ requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS");
+ requireEnvVar("UCAIP_PROJECT_ID");
+ requireEnvVar("VIDEO_CLASS_MODEL_ID");
+ requireEnvVar("VIDEO_CLASS_EVALUATION_ID");
+ }
+
+ @Before
+ public void setUp() {
+ bout = new ByteArrayOutputStream();
+ out = new PrintStream(bout);
+ originalPrintStream = System.out;
+ System.setOut(out);
+ }
+
+ @After
+ public void tearDown() {
+ System.out.flush();
+ System.setOut(originalPrintStream);
+ }
+
+ @Test
+ public void testGetModelEvaluationVideoClassificationSample() throws IOException {
+ // Act
+ GetModelEvaluationVideoClassificationSample.getModelEvaluationVideoClassification(
+ PROJECT, MODEL_ID, EVALUATION_ID);
+
+ // Assert
+ String got = bout.toString();
+ assertThat(got).contains(MODEL_ID);
+ assertThat(got).contains("Get Model Evaluation Video Classification Response");
+ }
+}
diff --git a/aiplatform/snippets/src/test/java/aiplatform/GetModelEvaluationVideoObjectTrackingSampleTest.java b/aiplatform/snippets/src/test/java/aiplatform/GetModelEvaluationVideoObjectTrackingSampleTest.java
new file mode 100644
index 00000000000..cefe40345d3
--- /dev/null
+++ b/aiplatform/snippets/src/test/java/aiplatform/GetModelEvaluationVideoObjectTrackingSampleTest.java
@@ -0,0 +1,78 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package aiplatform;
+
+import static com.google.common.truth.Truth.assertThat;
+import static junit.framework.TestCase.assertNotNull;
+
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.io.PrintStream;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+public class GetModelEvaluationVideoObjectTrackingSampleTest {
+
+ private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID");
+ private static final String MODEL_ID = System.getenv("VIDEO_OBJECT_DETECT_MODEL_ID");
+ private static final String EVALUATION_ID = System.getenv("VIDEO_OBJECT_DETECT_EVALUATION_ID");
+ private ByteArrayOutputStream bout;
+ private PrintStream out;
+ private PrintStream originalPrintStream;
+
+ private static void requireEnvVar(String varName) {
+ String errorMessage =
+ String.format("Environment variable '%s' is required to perform these tests.", varName);
+ assertNotNull(errorMessage, System.getenv(varName));
+ }
+
+ @BeforeClass
+ public static void checkRequirements() {
+ requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS");
+ requireEnvVar("UCAIP_PROJECT_ID");
+ requireEnvVar("VIDEO_OBJECT_DETECT_MODEL_ID");
+ requireEnvVar("VIDEO_OBJECT_DETECT_EVALUATION_ID");
+ }
+
+ @Before
+ public void setUp() {
+ bout = new ByteArrayOutputStream();
+ out = new PrintStream(bout);
+ originalPrintStream = System.out;
+ System.setOut(out);
+ }
+
+ @After
+ public void tearDown() {
+ System.out.flush();
+ System.setOut(originalPrintStream);
+ }
+
+ @Test
+ public void testGetModelEvaluationVideoObjectTrackingSample() throws IOException {
+ // Act
+ GetModelEvaluationVideoObjectTrackingSample.getModelEvaluationVideoObjectTracking(
+ PROJECT, MODEL_ID, EVALUATION_ID);
+
+ // Assert
+ String got = bout.toString();
+ assertThat(got).contains(MODEL_ID);
+ assertThat(got).contains("Get Model Evaluation Video Object Tracking Response");
+ }
+}
diff --git a/aiplatform/snippets/src/test/java/aiplatform/ImportDataImageClassificationSampleTest.java b/aiplatform/snippets/src/test/java/aiplatform/ImportDataImageClassificationSampleTest.java
new file mode 100644
index 00000000000..ed4d7111964
--- /dev/null
+++ b/aiplatform/snippets/src/test/java/aiplatform/ImportDataImageClassificationSampleTest.java
@@ -0,0 +1,131 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package aiplatform;
+
+import static com.google.common.truth.Truth.assertThat;
+import static junit.framework.TestCase.assertNotNull;
+
+import com.google.api.gax.longrunning.OperationFuture;
+import com.google.cloud.aiplatform.v1beta1.CreateDatasetOperationMetadata;
+import com.google.cloud.aiplatform.v1beta1.Dataset;
+import com.google.cloud.aiplatform.v1beta1.DatasetName;
+import com.google.cloud.aiplatform.v1beta1.DatasetServiceClient;
+import com.google.cloud.aiplatform.v1beta1.DatasetServiceSettings;
+import com.google.cloud.aiplatform.v1beta1.DeleteOperationMetadata;
+import com.google.cloud.aiplatform.v1beta1.LocationName;
+import com.google.protobuf.Empty;
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.io.PrintStream;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.BeforeClass;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+@RunWith(JUnit4.class)
+public class ImportDataImageClassificationSampleTest {
+ private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID");
+ private static final String LOCATION = "us-central1";
+
+ private static final String GCS_SOURCE_URI = "gs://ucaip-sample-resources/input.jsonl";
+ private String datasetId;
+ private ByteArrayOutputStream bout;
+ private PrintStream out;
+ private PrintStream originalPrintStream;
+
+ private static void requireEnvVar(String varName) {
+ String errorMessage =
+ String.format("Environment variable '%s' is required to perform these tests.", varName);
+ assertNotNull(errorMessage, System.getenv(varName));
+ }
+
+ @BeforeClass
+ public static void checkRequirements() {
+ requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS");
+ requireEnvVar("UCAIP_PROJECT_ID");
+ }
+
+ @Before
+ public void setUp()
+ throws IOException, InterruptedException, ExecutionException, TimeoutException {
+ bout = new ByteArrayOutputStream();
+ out = new PrintStream(bout);
+ originalPrintStream = System.out;
+ System.setOut(out);
+
+ // create a temp dataset for importing data
+ DatasetServiceSettings datasetServiceSettings =
+ DatasetServiceSettings.newBuilder()
+ .setEndpoint("us-central1-aiplatform.googleapis.com:443")
+ .build();
+ try (DatasetServiceClient datasetServiceClient =
+ DatasetServiceClient.create(datasetServiceSettings)) {
+
+ String metadataSchemaUri =
+ "gs://google-cloud-aiplatform/schema/dataset/metadata/image_1.0.0.yaml";
+ LocationName locationName = LocationName.of(PROJECT, LOCATION);
+ Dataset dataset =
+ Dataset.newBuilder()
+ .setDisplayName("test_dataset_display_name")
+ .setMetadataSchemaUri(metadataSchemaUri)
+ .build();
+
+ OperationFuture datasetFuture =
+ datasetServiceClient.createDatasetAsync(locationName, dataset);
+ Dataset datasetResponse = datasetFuture.get(120, TimeUnit.SECONDS);
+ String[] datasetValues = datasetResponse.getName().split("/");
+ datasetId = datasetValues[datasetValues.length - 1];
+ }
+ }
+
+ @After
+ public void tearDown() throws InterruptedException, ExecutionException, IOException {
+ // delete the temp dataset
+ DatasetServiceSettings datasetServiceSettings =
+ DatasetServiceSettings.newBuilder()
+ .setEndpoint("us-central1-aiplatform.googleapis.com:443")
+ .build();
+ try (DatasetServiceClient datasetServiceClient =
+ DatasetServiceClient.create(datasetServiceSettings)) {
+ DatasetName datasetName = DatasetName.of(PROJECT, LOCATION, datasetId);
+
+ OperationFuture operationFuture =
+ datasetServiceClient.deleteDatasetAsync(datasetName);
+ operationFuture.get();
+ }
+
+ System.out.flush();
+ System.setOut(originalPrintStream);
+ }
+
+ @Test
+ public void testImportDataSample()
+ throws IOException, InterruptedException, ExecutionException, TimeoutException {
+ // Act
+ ImportDataImageClassificationSample.importDataImageClassificationSample(
+ PROJECT, datasetId, GCS_SOURCE_URI);
+
+ // Assert
+ String got = bout.toString();
+ assertThat(got).contains("Import Data Image Classification Response: ");
+ }
+}
diff --git a/aiplatform/snippets/src/test/java/aiplatform/ImportDataImageObjectDetectionSampleTest.java b/aiplatform/snippets/src/test/java/aiplatform/ImportDataImageObjectDetectionSampleTest.java
new file mode 100644
index 00000000000..451a7c230fb
--- /dev/null
+++ b/aiplatform/snippets/src/test/java/aiplatform/ImportDataImageObjectDetectionSampleTest.java
@@ -0,0 +1,131 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package aiplatform;
+
+import static com.google.common.truth.Truth.assertThat;
+import static junit.framework.TestCase.assertNotNull;
+
+import com.google.api.gax.longrunning.OperationFuture;
+import com.google.cloud.aiplatform.v1beta1.CreateDatasetOperationMetadata;
+import com.google.cloud.aiplatform.v1beta1.Dataset;
+import com.google.cloud.aiplatform.v1beta1.DatasetName;
+import com.google.cloud.aiplatform.v1beta1.DatasetServiceClient;
+import com.google.cloud.aiplatform.v1beta1.DatasetServiceSettings;
+import com.google.cloud.aiplatform.v1beta1.DeleteOperationMetadata;
+import com.google.cloud.aiplatform.v1beta1.LocationName;
+import com.google.protobuf.Empty;
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.io.PrintStream;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.BeforeClass;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+@RunWith(JUnit4.class)
+public class ImportDataImageObjectDetectionSampleTest {
+
+ private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID");
+ private static final String LOCATION = "us-central1";
+ private static final String GCS_SOURCE_URI = "gs://ucaip-sample-resources/input.jsonl";
+
+ private String datasetId;
+ private ByteArrayOutputStream bout;
+ private PrintStream out;
+ private PrintStream originalPrintStream;
+
+ private static void requireEnvVar(String varName) {
+ String errorMessage =
+ String.format("Environment variable '%s' is required to perform these tests.", varName);
+ assertNotNull(errorMessage, System.getenv(varName));
+ }
+
+ @BeforeClass
+ public static void checkRequirements() {
+ requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS");
+ requireEnvVar("UCAIP_PROJECT_ID");
+ }
+
+ @Before
+ public void setUp()
+ throws InterruptedException, ExecutionException, TimeoutException, IOException {
+ bout = new ByteArrayOutputStream();
+ out = new PrintStream(bout);
+ originalPrintStream = System.out;
+ System.setOut(out);
+
+ // create a temp dataset for importing data
+ DatasetServiceSettings datasetServiceSettings =
+ DatasetServiceSettings.newBuilder()
+ .setEndpoint("us-central1-aiplatform.googleapis.com:443")
+ .build();
+ try (DatasetServiceClient datasetServiceClient =
+ DatasetServiceClient.create(datasetServiceSettings)) {
+ String metadataSchemaUri =
+ "gs://google-cloud-aiplatform/schema/dataset/metadata/image_1.0.0.yaml";
+ LocationName locationName = LocationName.of(PROJECT, LOCATION);
+ Dataset dataset =
+ Dataset.newBuilder()
+ .setDisplayName("test_dataset_display_name")
+ .setMetadataSchemaUri(metadataSchemaUri)
+ .build();
+
+ OperationFuture datasetFuture =
+ datasetServiceClient.createDatasetAsync(locationName, dataset);
+ Dataset datasetResponse = datasetFuture.get(120, TimeUnit.SECONDS);
+ String[] datasetValues = datasetResponse.getName().split("/");
+ datasetId = datasetValues[datasetValues.length - 1];
+ }
+ }
+
+ @After
+ public void tearDown()
+ throws InterruptedException, ExecutionException, IOException, TimeoutException {
+ // delete the temp dataset
+ DatasetServiceSettings datasetServiceSettings =
+ DatasetServiceSettings.newBuilder()
+ .setEndpoint("us-central1-aiplatform.googleapis.com:443")
+ .build();
+ try (DatasetServiceClient datasetServiceClient =
+ DatasetServiceClient.create(datasetServiceSettings)) {
+ DatasetName datasetName = DatasetName.of(PROJECT, LOCATION, datasetId);
+
+ OperationFuture operationFuture =
+ datasetServiceClient.deleteDatasetAsync(datasetName);
+ operationFuture.get();
+ }
+ System.out.flush();
+ System.setOut(originalPrintStream);
+ }
+
+ @Test
+ public void testImportDataSample()
+ throws IOException, InterruptedException, ExecutionException, TimeoutException {
+ // Act
+ ImportDataImageObjectDetectionSample.importDataImageObjectDetectionSample(
+ PROJECT, datasetId, GCS_SOURCE_URI);
+
+ // Assert
+ String got = bout.toString();
+ assertThat(got).contains("Import Data Image Object Detection Response: ");
+ }
+}
diff --git a/aiplatform/snippets/src/test/java/aiplatform/ImportDataVideoClassificationSampleTest.java b/aiplatform/snippets/src/test/java/aiplatform/ImportDataVideoClassificationSampleTest.java
new file mode 100644
index 00000000000..66b0237cf45
--- /dev/null
+++ b/aiplatform/snippets/src/test/java/aiplatform/ImportDataVideoClassificationSampleTest.java
@@ -0,0 +1,129 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package aiplatform;
+
+import static com.google.common.truth.Truth.assertThat;
+import static junit.framework.TestCase.assertNotNull;
+
+import com.google.api.gax.longrunning.OperationFuture;
+import com.google.cloud.aiplatform.v1beta1.CreateDatasetOperationMetadata;
+import com.google.cloud.aiplatform.v1beta1.Dataset;
+import com.google.cloud.aiplatform.v1beta1.DatasetName;
+import com.google.cloud.aiplatform.v1beta1.DatasetServiceClient;
+import com.google.cloud.aiplatform.v1beta1.DatasetServiceSettings;
+import com.google.cloud.aiplatform.v1beta1.DeleteOperationMetadata;
+import com.google.cloud.aiplatform.v1beta1.LocationName;
+import com.google.protobuf.Empty;
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.io.PrintStream;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+public class ImportDataVideoClassificationSampleTest {
+
+ private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID");
+ private static final String LOCATION = "us-central1";
+ private static final String GCS_SOURCE_URI =
+ "gs://automl-video-demo-data/traffic_videos/traffic_videos_train.csv";
+
+ private String datasetId;
+ private ByteArrayOutputStream bout;
+ private PrintStream out;
+ private PrintStream originalPrintStream;
+
+ private static void requireEnvVar(String varName) {
+ String errorMessage =
+ String.format("Environment variable '%s' is required to perform these tests.", varName);
+ assertNotNull(errorMessage, System.getenv(varName));
+ }
+
+ @BeforeClass
+ public static void checkRequirements() {
+ requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS");
+ requireEnvVar("UCAIP_PROJECT_ID");
+ }
+
+ @Before
+ public void setUp()
+ throws IOException, InterruptedException, ExecutionException, TimeoutException {
+ bout = new ByteArrayOutputStream();
+ out = new PrintStream(bout);
+ originalPrintStream = System.out;
+ System.setOut(out);
+
+ // create a temp dataset for importing data
+ DatasetServiceSettings datasetServiceSettings =
+ DatasetServiceSettings.newBuilder()
+ .setEndpoint("us-central1-aiplatform.googleapis.com:443")
+ .build();
+
+ try (DatasetServiceClient datasetServiceClient =
+ DatasetServiceClient.create(datasetServiceSettings)) {
+ String metadataSchemaUri =
+ "gs://google-cloud-aiplatform/schema/dataset/metadata/video_1.0.0.yaml";
+ LocationName locationName = LocationName.of(PROJECT, LOCATION);
+ Dataset dataset =
+ Dataset.newBuilder()
+ .setDisplayName("test_dataset_display_name")
+ .setMetadataSchemaUri(metadataSchemaUri)
+ .build();
+
+ OperationFuture datasetFuture =
+ datasetServiceClient.createDatasetAsync(locationName, dataset);
+ Dataset datasetResponse = datasetFuture.get(300, TimeUnit.SECONDS);
+ String[] datasetValues = datasetResponse.getName().split("/");
+ datasetId = datasetValues[datasetValues.length - 1];
+ }
+ }
+
+ @After
+ public void tearDown() throws InterruptedException, ExecutionException, IOException {
+ // delete the temp dataset
+ DatasetServiceSettings datasetServiceSettings =
+ DatasetServiceSettings.newBuilder()
+ .setEndpoint("us-central1-aiplatform.googleapis.com:443")
+ .build();
+ try (DatasetServiceClient datasetServiceClient =
+ DatasetServiceClient.create(datasetServiceSettings)) {
+ DatasetName datasetName = DatasetName.of(PROJECT, LOCATION, datasetId);
+
+ OperationFuture operationFuture =
+ datasetServiceClient.deleteDatasetAsync(datasetName);
+ operationFuture.get();
+ }
+ System.out.flush();
+ System.setOut(originalPrintStream);
+ }
+
+ @Test
+ public void testImportDataVideoClassificationSample()
+ throws IOException, InterruptedException, ExecutionException, TimeoutException {
+ // Act
+ ImportDataVideoClassificationSample.importDataVideoClassification(
+ GCS_SOURCE_URI, PROJECT, datasetId);
+
+ // Assert
+ String got = bout.toString();
+ assertThat(got).contains("Import Data Video Classification Response: ");
+ }
+}
diff --git a/aiplatform/snippets/src/test/java/aiplatform/ImportDataVideoObjectTrackingSampleTest.java b/aiplatform/snippets/src/test/java/aiplatform/ImportDataVideoObjectTrackingSampleTest.java
new file mode 100644
index 00000000000..6d8b5e7a7ed
--- /dev/null
+++ b/aiplatform/snippets/src/test/java/aiplatform/ImportDataVideoObjectTrackingSampleTest.java
@@ -0,0 +1,128 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package aiplatform;
+
+import static com.google.common.truth.Truth.assertThat;
+import static junit.framework.TestCase.assertNotNull;
+
+import com.google.api.gax.longrunning.OperationFuture;
+import com.google.cloud.aiplatform.v1beta1.CreateDatasetOperationMetadata;
+import com.google.cloud.aiplatform.v1beta1.Dataset;
+import com.google.cloud.aiplatform.v1beta1.DatasetName;
+import com.google.cloud.aiplatform.v1beta1.DatasetServiceClient;
+import com.google.cloud.aiplatform.v1beta1.DatasetServiceSettings;
+import com.google.cloud.aiplatform.v1beta1.DeleteOperationMetadata;
+import com.google.cloud.aiplatform.v1beta1.LocationName;
+import com.google.protobuf.Empty;
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.io.PrintStream;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+public class ImportDataVideoObjectTrackingSampleTest {
+
+ private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID");
+ private static final String LOCATION = "us-central1";
+ private static final String GCS_SOURCE_URI =
+ "gs://automl-video-demo-data/traffic_videos/traffic_videos_train.csv";
+ private String datasetId;
+ private ByteArrayOutputStream bout;
+ private PrintStream out;
+ private PrintStream originalPrintStream;
+
+ private static void requireEnvVar(String varName) {
+ String errorMessage =
+ String.format("Environment variable '%s' is required to perform these tests.", varName);
+ assertNotNull(errorMessage, System.getenv(varName));
+ }
+
+ @BeforeClass
+ public static void checkRequirements() {
+ requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS");
+ requireEnvVar("UCAIP_PROJECT_ID");
+ }
+
+ @Before
+ public void setUp()
+ throws InterruptedException, ExecutionException, TimeoutException, IOException {
+ bout = new ByteArrayOutputStream();
+ out = new PrintStream(bout);
+ originalPrintStream = System.out;
+ System.setOut(out);
+
+ // create a temp dataset for importing data
+ DatasetServiceSettings datasetServiceSettings =
+ DatasetServiceSettings.newBuilder()
+ .setEndpoint("us-central1-aiplatform.googleapis.com:443")
+ .build();
+
+ try (DatasetServiceClient datasetServiceClient =
+ DatasetServiceClient.create(datasetServiceSettings)) {
+ String metadataSchemaUri =
+ "gs://google-cloud-aiplatform/schema/dataset/metadata/video_1.0.0.yaml";
+ LocationName locationName = LocationName.of(PROJECT, LOCATION);
+ Dataset dataset =
+ Dataset.newBuilder()
+ .setDisplayName("test_dataset_display_name")
+ .setMetadataSchemaUri(metadataSchemaUri)
+ .build();
+
+ OperationFuture datasetFuture =
+ datasetServiceClient.createDatasetAsync(locationName, dataset);
+ Dataset datasetResponse = datasetFuture.get(300, TimeUnit.SECONDS);
+ String[] datasetValues = datasetResponse.getName().split("/");
+ datasetId = datasetValues[datasetValues.length - 1];
+ }
+ }
+
+ @After
+ public void tearDown() throws InterruptedException, ExecutionException, IOException {
+ // delete the temp dataset
+ DatasetServiceSettings datasetServiceSettings =
+ DatasetServiceSettings.newBuilder()
+ .setEndpoint("us-central1-aiplatform.googleapis.com:443")
+ .build();
+ try (DatasetServiceClient datasetServiceClient =
+ DatasetServiceClient.create(datasetServiceSettings)) {
+ DatasetName datasetName = DatasetName.of(PROJECT, LOCATION, datasetId);
+
+ OperationFuture operationFuture =
+ datasetServiceClient.deleteDatasetAsync(datasetName);
+ operationFuture.get();
+ }
+ System.out.flush();
+ System.setOut(originalPrintStream);
+ }
+
+ @Test
+ public void testImportDataVideoObjectTrackingSample()
+ throws IOException, InterruptedException, ExecutionException, TimeoutException {
+ // Act
+ ImportDataVideoObjectTrackingSample.importDataVideObjectTracking(
+ GCS_SOURCE_URI, PROJECT, datasetId);
+
+ // Assert
+ String got = bout.toString();
+ assertThat(got).contains("Import Data Video Object Tracking Response: ");
+ }
+}
diff --git a/aiplatform/snippets/src/test/java/aiplatform/PredictImageClassificationSampleTest.java b/aiplatform/snippets/src/test/java/aiplatform/PredictImageClassificationSampleTest.java
new file mode 100644
index 00000000000..8ca3fd95c5e
--- /dev/null
+++ b/aiplatform/snippets/src/test/java/aiplatform/PredictImageClassificationSampleTest.java
@@ -0,0 +1,75 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package aiplatform;
+
+import static com.google.common.truth.Truth.assertThat;
+import static junit.framework.TestCase.assertNotNull;
+
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.io.PrintStream;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+public class PredictImageClassificationSampleTest {
+
+ private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID");
+ private static final String FILE_NAME = "resources/image_flower_daisy.jpg";
+ private static final String ENDPOINT_ID = System.getenv("IMAGE_CLASS_ENDPOINT_ID");
+ private ByteArrayOutputStream bout;
+ private PrintStream out;
+ private PrintStream originalPrintStream;
+
+ private static void requireEnvVar(String varName) {
+ String errorMessage =
+ String.format("Environment variable '%s' is required to perform these tests.", varName);
+ assertNotNull(errorMessage, System.getenv(varName));
+ }
+
+ @BeforeClass
+ public static void checkRequirements() {
+ requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS");
+ requireEnvVar("UCAIP_PROJECT_ID");
+ requireEnvVar("IMAGE_CLASS_ENDPOINT_ID");
+ }
+
+ @Before
+ public void setUp() {
+ bout = new ByteArrayOutputStream();
+ out = new PrintStream(bout);
+ originalPrintStream = System.out;
+ System.setOut(out);
+ }
+
+ @After
+ public void tearDown() {
+ System.out.flush();
+ System.setOut(originalPrintStream);
+ }
+
+ @Test
+ public void testPredictImageClassification() throws IOException {
+ // Act
+ PredictImageClassificationSample.predictImageClassification(PROJECT, FILE_NAME, ENDPOINT_ID);
+
+ // Assert
+ String got = bout.toString();
+ assertThat(got).contains("Predict Image Classification Response");
+ }
+}
diff --git a/aiplatform/snippets/src/test/java/aiplatform/PredictImageObjectDetectionSampleTest.java b/aiplatform/snippets/src/test/java/aiplatform/PredictImageObjectDetectionSampleTest.java
new file mode 100644
index 00000000000..1539c7dfbc1
--- /dev/null
+++ b/aiplatform/snippets/src/test/java/aiplatform/PredictImageObjectDetectionSampleTest.java
@@ -0,0 +1,75 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package aiplatform;
+
+import static com.google.common.truth.Truth.assertThat;
+import static junit.framework.TestCase.assertNotNull;
+
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.io.PrintStream;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+public class PredictImageObjectDetectionSampleTest {
+
+ private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID");
+ private static final String FILE_NAME = "resources/caprese_salad.jpg";
+ private static final String ENDPOINT_ID = System.getenv("IMAGE_OBJECT_DETECTION_ENDPOINT_ID");
+ private ByteArrayOutputStream bout;
+ private PrintStream out;
+ private PrintStream originalPrintStream;
+
+ private static void requireEnvVar(String varName) {
+ String errorMessage =
+ String.format("Environment variable '%s' is required to perform these tests.", varName);
+ assertNotNull(errorMessage, System.getenv(varName));
+ }
+
+ @BeforeClass
+ public static void checkRequirements() {
+ requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS");
+ requireEnvVar("UCAIP_PROJECT_ID");
+ requireEnvVar("IMAGE_OBJECT_DETECTION_ENDPOINT_ID");
+ }
+
+ @Before
+ public void setUp() {
+ bout = new ByteArrayOutputStream();
+ out = new PrintStream(bout);
+ originalPrintStream = System.out;
+ System.setOut(out);
+ }
+
+ @After
+ public void tearDown() {
+ System.out.flush();
+ System.setOut(originalPrintStream);
+ }
+
+ @Test
+ public void testPredictImageObjectDetection() throws IOException {
+ // Act
+ PredictImageObjectDetectionSample.predictImageObjectDetection(PROJECT, FILE_NAME, ENDPOINT_ID);
+
+ // Assert
+ String got = bout.toString();
+ assertThat(got).contains("Predict Image Object Detection Response");
+ }
+}