diff --git a/aiplatform/snippets/format.sh b/aiplatform/snippets/format.sh new file mode 100644 index 00000000000..153ed361ef9 --- /dev/null +++ b/aiplatform/snippets/format.sh @@ -0,0 +1,6 @@ +touch format.sh + chmod +rx format.sh + +git add . +git reset HEAD format.sh +./format.sh \ No newline at end of file diff --git a/aiplatform/snippets/pom.xml b/aiplatform/snippets/pom.xml index 8ea1e440238..bad1d0b69f3 100644 --- a/aiplatform/snippets/pom.xml +++ b/aiplatform/snippets/pom.xml @@ -30,16 +30,6 @@ 0.0.1-SNAPSHOT - - com.google.protobuf - protobuf-java-util - 4.0.0-rc-1 - - - com.google.cloud - google-cloud-storage - 1.111.0 - com.google.cloud google-cloud-storage @@ -50,8 +40,6 @@ protobuf-java-util 4.0.0-rc-1 - - junit junit diff --git a/aiplatform/snippets/resources/caprese_salad.jpg b/aiplatform/snippets/resources/caprese_salad.jpg new file mode 100644 index 00000000000..fbd7e6575c3 Binary files /dev/null and b/aiplatform/snippets/resources/caprese_salad.jpg differ diff --git a/aiplatform/snippets/resources/image_flower_daisy.jpg b/aiplatform/snippets/resources/image_flower_daisy.jpg new file mode 100644 index 00000000000..3ba1d67705a Binary files /dev/null and b/aiplatform/snippets/resources/image_flower_daisy.jpg differ diff --git a/aiplatform/snippets/src/main/java/aiplatform/CancelBatchPredictionJobSample.java b/aiplatform/snippets/src/main/java/aiplatform/CancelBatchPredictionJobSample.java new file mode 100644 index 00000000000..61931a9fd2e --- /dev/null +++ b/aiplatform/snippets/src/main/java/aiplatform/CancelBatchPredictionJobSample.java @@ -0,0 +1,56 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_cancel_batch_prediction_job_sample] + +import com.google.cloud.aiplatform.v1beta1.BatchPredictionJobName; +import com.google.cloud.aiplatform.v1beta1.JobServiceClient; +import com.google.cloud.aiplatform.v1beta1.JobServiceSettings; +import java.io.IOException; + +public class CancelBatchPredictionJobSample { + + public static void main(String[] args) throws IOException { + // TODO(developer): Replace these variables before running the sample. + String project = "YOUR_PROJECT_ID"; + String batchPredictionJobId = "YOUR_BATCH_PREDICTION_JOB_ID"; + cancelBatchPredictionJobSample(project, batchPredictionJobId); + } + + static void cancelBatchPredictionJobSample(String project, String batchPredictionJobId) + throws IOException { + JobServiceSettings jobServiceSettings = + JobServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (JobServiceClient jobServiceClient = JobServiceClient.create(jobServiceSettings)) { + String location = "us-central1"; + BatchPredictionJobName batchPredictionJobName = + BatchPredictionJobName.of(project, location, batchPredictionJobId); + + jobServiceClient.cancelBatchPredictionJob(batchPredictionJobName); + + System.out.println("Cancelled the Batch Prediction Job"); + } + } +} +// [END aiplatform_cancel_batch_prediction_job_sample] diff --git a/aiplatform/snippets/src/main/java/aiplatform/CreateBatchPredictionJobVideoClassificationSample.java b/aiplatform/snippets/src/main/java/aiplatform/CreateBatchPredictionJobVideoClassificationSample.java new file mode 100644 index 00000000000..a89f2bfe3d5 --- /dev/null +++ b/aiplatform/snippets/src/main/java/aiplatform/CreateBatchPredictionJobVideoClassificationSample.java @@ -0,0 +1,200 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_create_batch_prediction_job_video_classification_sample] + +import com.google.cloud.aiplatform.v1beta1.BatchDedicatedResources; +import com.google.cloud.aiplatform.v1beta1.BatchPredictionJob; +import com.google.cloud.aiplatform.v1beta1.BatchPredictionJob.InputConfig; +import com.google.cloud.aiplatform.v1beta1.BatchPredictionJob.OutputConfig; +import com.google.cloud.aiplatform.v1beta1.BatchPredictionJob.OutputInfo; +import com.google.cloud.aiplatform.v1beta1.BigQueryDestination; +import com.google.cloud.aiplatform.v1beta1.BigQuerySource; +import com.google.cloud.aiplatform.v1beta1.CompletionStats; +import com.google.cloud.aiplatform.v1beta1.GcsDestination; +import com.google.cloud.aiplatform.v1beta1.GcsSource; +import com.google.cloud.aiplatform.v1beta1.JobServiceClient; +import com.google.cloud.aiplatform.v1beta1.JobServiceSettings; +import com.google.cloud.aiplatform.v1beta1.LocationName; +import com.google.cloud.aiplatform.v1beta1.MachineSpec; +import com.google.cloud.aiplatform.v1beta1.ManualBatchTuningParameters; +import com.google.cloud.aiplatform.v1beta1.ModelName; +import com.google.cloud.aiplatform.v1beta1.ResourcesConsumed; +import com.google.protobuf.Any; +import com.google.protobuf.Value; +import com.google.protobuf.util.JsonFormat; +import com.google.rpc.Status; +import java.io.IOException; +import java.util.List; + +public class CreateBatchPredictionJobVideoClassificationSample { + + public static void main(String[] args) throws IOException { + String batchPredictionDisplayName = "YOUR_VIDEO_CLASSIFICATION_DISPLAY_NAME"; + String modelId = "YOUR_MODEL_ID"; + String gcsSourceUri = + "gs://YOUR_GCS_SOURCE_BUCKET/path_to_your_video_source/[file.csv/file.jsonl]"; + String gcsDestinationOutputUriPrefix = + "gs://YOUR_GCS_SOURCE_BUCKET/destination_output_uri_prefix/"; + String project = "YOUR_PROJECT_ID"; + createBatchPredictionJobVideoClassification( + batchPredictionDisplayName, modelId, gcsSourceUri, gcsDestinationOutputUriPrefix, project); + } + + static void createBatchPredictionJobVideoClassification( + String batchPredictionDisplayName, + String modelId, + String gcsSourceUri, + String gcsDestinationOutputUriPrefix, + String project) + throws IOException { + JobServiceSettings jobServiceSettings = + JobServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (JobServiceClient jobServiceClient = JobServiceClient.create(jobServiceSettings)) { + String location = "us-central1"; + LocationName locationName = LocationName.of(project, location); + + String jsonString = + "{\"confidenceThreshold\": 0.5,\"maxPredictions\": 10000,\"segmentClassification\":" + + " True,\"shotClassification\": True,\"oneSecIntervalClassification\": True}"; + Value.Builder modelParameters = Value.newBuilder(); + JsonFormat.parser().merge(jsonString, modelParameters); + + ModelName modelName = ModelName.of(project, location, modelId); + GcsSource.Builder gcsSource = GcsSource.newBuilder(); + gcsSource.addUris(gcsSourceUri); + InputConfig inputConfig = + InputConfig.newBuilder().setInstancesFormat("jsonl").setGcsSource(gcsSource).build(); + + GcsDestination gcsDestination = + GcsDestination.newBuilder().setOutputUriPrefix(gcsDestinationOutputUriPrefix).build(); + OutputConfig outputConfig = + OutputConfig.newBuilder() + .setPredictionsFormat("jsonl") + .setGcsDestination(gcsDestination) + .build(); + + BatchPredictionJob batchPredictionJob = + BatchPredictionJob.newBuilder() + .setDisplayName(batchPredictionDisplayName) + .setModel(modelName.toString()) + .setModelParameters(modelParameters) + .setInputConfig(inputConfig) + .setOutputConfig(outputConfig) + .build(); + BatchPredictionJob batchPredictionJobResponse = + jobServiceClient.createBatchPredictionJob(locationName, batchPredictionJob); + + System.out.println("Create Batch Prediction Job Video Classification Response"); + System.out.format("\tName: %s\n", batchPredictionJobResponse.getName()); + System.out.format("\tDisplay Name: %s\n", batchPredictionJobResponse.getDisplayName()); + System.out.format("\tModel %s\n", batchPredictionJobResponse.getModel()); + System.out.format( + "\tModel Parameters: %s\n", batchPredictionJobResponse.getModelParameters()); + System.out.format( + "\tGenerate Explanation: %s\n", batchPredictionJobResponse.getGenerateExplanation()); + + System.out.format("\tState: %s\n", batchPredictionJobResponse.getState()); + System.out.format("\tCreate Time: %s\n", batchPredictionJobResponse.getCreateTime()); + System.out.format("\tStart Time: %s\n", batchPredictionJobResponse.getStartTime()); + System.out.format("\tEnd Time: %s\n", batchPredictionJobResponse.getEndTime()); + System.out.format("\tUpdate Time: %s\n", batchPredictionJobResponse.getUpdateTime()); + System.out.format("\tLabels: %s\n", batchPredictionJobResponse.getLabelsMap()); + + InputConfig inputConfigResponse = batchPredictionJobResponse.getInputConfig(); + System.out.println("\tInput Config"); + System.out.format("\t\tInstances Format: %s\n", inputConfigResponse.getInstancesFormat()); + + GcsSource gcsSourceResponse = inputConfigResponse.getGcsSource(); + System.out.println("\t\tGcs Source"); + System.out.format("\t\t\tUris %s\n", gcsSourceResponse.getUrisList()); + + BigQuerySource bigQuerySource = inputConfigResponse.getBigquerySource(); + System.out.println("\t\tBigquery Source"); + System.out.format("\t\t\tInput_uri: %s\n", bigQuerySource.getInputUri()); + + OutputConfig outputConfigResponse = batchPredictionJobResponse.getOutputConfig(); + System.out.println("\tOutput Config"); + System.out.format( + "\t\tPredictions Format: %s\n", outputConfigResponse.getPredictionsFormat()); + + GcsDestination gcsDestinationResponse = outputConfigResponse.getGcsDestination(); + System.out.println("\t\tGcs Destination"); + System.out.format( + "\t\t\tOutput Uri Prefix: %s\n", gcsDestinationResponse.getOutputUriPrefix()); + + BigQueryDestination bigQueryDestination = outputConfigResponse.getBigqueryDestination(); + System.out.println("\t\tBig Query Destination"); + System.out.format("\t\t\tOutput Uri: %s\n", bigQueryDestination.getOutputUri()); + + BatchDedicatedResources batchDedicatedResources = + batchPredictionJobResponse.getDedicatedResources(); + System.out.println("\tBatch Dedicated Resources"); + System.out.format( + "\t\tStarting Replica Count: %s\n", batchDedicatedResources.getStartingReplicaCount()); + System.out.format( + "\t\tMax Replica Count: %s\n", batchDedicatedResources.getMaxReplicaCount()); + + MachineSpec machineSpec = batchDedicatedResources.getMachineSpec(); + System.out.println("\t\tMachine Spec"); + System.out.format("\t\t\tMachine Type: %s\n", machineSpec.getMachineType()); + System.out.format("\t\t\tAccelerator Type: %s\n", machineSpec.getAcceleratorType()); + System.out.format("\t\t\tAccelerator Count: %s\n", machineSpec.getAcceleratorCount()); + + ManualBatchTuningParameters manualBatchTuningParameters = + batchPredictionJobResponse.getManualBatchTuningParameters(); + System.out.println("\tManual Batch Tuning Parameters"); + System.out.format("\t\tBatch Size: %s\n", manualBatchTuningParameters.getBatchSize()); + + OutputInfo outputInfo = batchPredictionJobResponse.getOutputInfo(); + System.out.println("\tOutput Info"); + System.out.format("\t\tGcs Output Directory: %s\n", outputInfo.getGcsOutputDirectory()); + System.out.format("\t\tBigquery Output Dataset: %s\n", outputInfo.getBigqueryOutputDataset()); + + Status status = batchPredictionJobResponse.getError(); + System.out.println("\tError"); + System.out.format("\t\tCode: %s\n", status.getCode()); + System.out.format("\t\tMessage: %s\n", status.getMessage()); + List details = status.getDetailsList(); + + for (Status partialFailure : batchPredictionJobResponse.getPartialFailuresList()) { + System.out.println("\tPartial Failure"); + System.out.format("\t\tCode: %s\n", partialFailure.getCode()); + System.out.format("\t\tMessage: %s\n", partialFailure.getMessage()); + List partialFailureDetailsList = partialFailure.getDetailsList(); + } + + ResourcesConsumed resourcesConsumed = batchPredictionJobResponse.getResourcesConsumed(); + System.out.println("\tResources Consumed"); + System.out.format("\t\tReplica Hours: %s\n", resourcesConsumed.getReplicaHours()); + + CompletionStats completionStats = batchPredictionJobResponse.getCompletionStats(); + System.out.println("\tCompletion Stats"); + System.out.format("\t\tSuccessful Count: %s\n", completionStats.getSuccessfulCount()); + System.out.format("\t\tFailed Count: %s\n", completionStats.getFailedCount()); + System.out.format("\t\tIncomplete Count: %s\n", completionStats.getIncompleteCount()); + } + } +} +// [END aiplatform_create_batch_prediction_job_video_classification_sample] diff --git a/aiplatform/snippets/src/main/java/aiplatform/CreateBatchPredictionJobVideoObjectTrackingSample.java b/aiplatform/snippets/src/main/java/aiplatform/CreateBatchPredictionJobVideoObjectTrackingSample.java new file mode 100644 index 00000000000..da0550b2607 --- /dev/null +++ b/aiplatform/snippets/src/main/java/aiplatform/CreateBatchPredictionJobVideoObjectTrackingSample.java @@ -0,0 +1,199 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_create_batch_prediction_job_video_object_tracking_sample] + +import com.google.cloud.aiplatform.v1beta1.BatchDedicatedResources; +import com.google.cloud.aiplatform.v1beta1.BatchPredictionJob; +import com.google.cloud.aiplatform.v1beta1.BatchPredictionJob.InputConfig; +import com.google.cloud.aiplatform.v1beta1.BatchPredictionJob.OutputConfig; +import com.google.cloud.aiplatform.v1beta1.BatchPredictionJob.OutputInfo; +import com.google.cloud.aiplatform.v1beta1.BigQueryDestination; +import com.google.cloud.aiplatform.v1beta1.BigQuerySource; +import com.google.cloud.aiplatform.v1beta1.CompletionStats; +import com.google.cloud.aiplatform.v1beta1.GcsDestination; +import com.google.cloud.aiplatform.v1beta1.GcsSource; +import com.google.cloud.aiplatform.v1beta1.JobServiceClient; +import com.google.cloud.aiplatform.v1beta1.JobServiceSettings; +import com.google.cloud.aiplatform.v1beta1.LocationName; +import com.google.cloud.aiplatform.v1beta1.MachineSpec; +import com.google.cloud.aiplatform.v1beta1.ManualBatchTuningParameters; +import com.google.cloud.aiplatform.v1beta1.ModelName; +import com.google.cloud.aiplatform.v1beta1.ResourcesConsumed; +import com.google.protobuf.Any; +import com.google.protobuf.Value; +import com.google.protobuf.util.JsonFormat; +import com.google.rpc.Status; +import java.io.IOException; +import java.util.List; + +public class CreateBatchPredictionJobVideoObjectTrackingSample { + + public static void main(String[] args) throws IOException { + // TODO(developer): Replace these variables before running the sample. + String batchPredictionDisplayName = "YOUR_VIDEO_OBJECT_TRACKING_DISPLAY_NAME"; + String modelId = "YOUR_MODEL_ID"; + String gcsSourceUri = + "gs://YOUR_GCS_SOURCE_BUCKET/path_to_your_video_source/[file.csv/file.jsonl]"; + String gcsDestinationOutputUriPrefix = + "gs://YOUR_GCS_SOURCE_BUCKET/destination_output_uri_prefix/"; + String project = "YOUR_PROJECT_ID"; + batchPredictionJobVideoObjectTracking( + batchPredictionDisplayName, modelId, gcsSourceUri, gcsDestinationOutputUriPrefix, project); + } + + static void batchPredictionJobVideoObjectTracking( + String batchPredictionDisplayName, + String modelId, + String gcsSourceUri, + String gcsDestinationOutputUriPrefix, + String project) + throws IOException { + JobServiceSettings jobServiceSettings = + JobServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (JobServiceClient jobServiceClient = JobServiceClient.create(jobServiceSettings)) { + String location = "us-central1"; + LocationName locationName = LocationName.of(project, location); + ModelName modelName = ModelName.of(project, location, modelId); + + String jsonString = "{\"confidenceThreshold\": 0.0}"; + Value.Builder modelParameters = Value.newBuilder(); + JsonFormat.parser().merge(jsonString, modelParameters); + + GcsSource.Builder gcsSource = GcsSource.newBuilder(); + gcsSource.addUris(gcsSourceUri); + InputConfig inputConfig = + InputConfig.newBuilder().setInstancesFormat("jsonl").setGcsSource(gcsSource).build(); + + GcsDestination gcsDestination = + GcsDestination.newBuilder().setOutputUriPrefix(gcsDestinationOutputUriPrefix).build(); + OutputConfig outputConfig = + OutputConfig.newBuilder() + .setPredictionsFormat("jsonl") + .setGcsDestination(gcsDestination) + .build(); + + BatchPredictionJob batchPredictionJob = + BatchPredictionJob.newBuilder() + .setDisplayName(batchPredictionDisplayName) + .setModel(modelName.toString()) + .setModelParameters(modelParameters) + .setInputConfig(inputConfig) + .setOutputConfig(outputConfig) + .build(); + BatchPredictionJob batchPredictionJobResponse = + jobServiceClient.createBatchPredictionJob(locationName, batchPredictionJob); + + System.out.println("Create Batch Prediction Job Video Object Tracking Response"); + System.out.format("\tName: %s\n", batchPredictionJobResponse.getName()); + System.out.format("\tDisplay Name: %s\n", batchPredictionJobResponse.getDisplayName()); + System.out.format("\tModel %s\n", batchPredictionJobResponse.getModel()); + System.out.format( + "\tModel Parameters: %s\n", batchPredictionJobResponse.getModelParameters()); + System.out.format( + "\tGenerate Explanation: %s\n", batchPredictionJobResponse.getGenerateExplanation()); + + System.out.format("\tState: %s\n", batchPredictionJobResponse.getState()); + System.out.format("\tCreate Time: %s\n", batchPredictionJobResponse.getCreateTime()); + System.out.format("\tStart Time: %s\n", batchPredictionJobResponse.getStartTime()); + System.out.format("\tEnd Time: %s\n", batchPredictionJobResponse.getEndTime()); + System.out.format("\tUpdate Time: %s\n", batchPredictionJobResponse.getUpdateTime()); + System.out.format("\tLabels: %s\n", batchPredictionJobResponse.getLabelsMap()); + + InputConfig inputConfigResponse = batchPredictionJobResponse.getInputConfig(); + System.out.println("\tInput Config"); + System.out.format("\t\tInstances Format: %s\n", inputConfigResponse.getInstancesFormat()); + + GcsSource gcsSourceResponse = inputConfigResponse.getGcsSource(); + System.out.println("\t\tGcs Source"); + System.out.format("\t\t\tUris %s\n", gcsSourceResponse.getUrisList()); + + BigQuerySource bigQuerySource = inputConfigResponse.getBigquerySource(); + System.out.println("\t\tBigquery Source"); + System.out.format("\t\t\tInput_uri: %s\n", bigQuerySource.getInputUri()); + + OutputConfig outputConfigResponse = batchPredictionJobResponse.getOutputConfig(); + System.out.println("\tOutput Config"); + System.out.format( + "\t\tPredictions Format: %s\n", outputConfigResponse.getPredictionsFormat()); + + GcsDestination gcsDestinationResponse = outputConfigResponse.getGcsDestination(); + System.out.println("\t\tGcs Destination"); + System.out.format( + "\t\t\tOutput Uri Prefix: %s\n", gcsDestinationResponse.getOutputUriPrefix()); + + BigQueryDestination bigQueryDestination = outputConfigResponse.getBigqueryDestination(); + System.out.println("\t\tBig Query Destination"); + System.out.format("\t\t\tOutput Uri: %s\n", bigQueryDestination.getOutputUri()); + + BatchDedicatedResources batchDedicatedResources = + batchPredictionJobResponse.getDedicatedResources(); + System.out.println("\tBatch Dedicated Resources"); + System.out.format( + "\t\tStarting Replica Count: %s\n", batchDedicatedResources.getStartingReplicaCount()); + System.out.format( + "\t\tMax Replica Count: %s\n", batchDedicatedResources.getMaxReplicaCount()); + + MachineSpec machineSpec = batchDedicatedResources.getMachineSpec(); + System.out.println("\t\tMachine Spec"); + System.out.format("\t\t\tMachine Type: %s\n", machineSpec.getMachineType()); + System.out.format("\t\t\tAccelerator Type: %s\n", machineSpec.getAcceleratorType()); + System.out.format("\t\t\tAccelerator Count: %s\n", machineSpec.getAcceleratorCount()); + + ManualBatchTuningParameters manualBatchTuningParameters = + batchPredictionJobResponse.getManualBatchTuningParameters(); + System.out.println("\tManual Batch Tuning Parameters"); + System.out.format("\t\tBatch Size: %s\n", manualBatchTuningParameters.getBatchSize()); + + OutputInfo outputInfo = batchPredictionJobResponse.getOutputInfo(); + System.out.println("\tOutput Info"); + System.out.format("\t\tGcs Output Directory: %s\n", outputInfo.getGcsOutputDirectory()); + System.out.format("\t\tBigquery Output Dataset: %s\n", outputInfo.getBigqueryOutputDataset()); + + Status status = batchPredictionJobResponse.getError(); + System.out.println("\tError"); + System.out.format("\t\tCode: %s\n", status.getCode()); + System.out.format("\t\tMessage: %s\n", status.getMessage()); + List details = status.getDetailsList(); + + for (Status partialFailure : batchPredictionJobResponse.getPartialFailuresList()) { + System.out.println("\tPartial Failure"); + System.out.format("\t\tCode: %s\n", partialFailure.getCode()); + System.out.format("\t\tMessage: %s\n", partialFailure.getMessage()); + List partialFailureDetailsList = partialFailure.getDetailsList(); + } + + ResourcesConsumed resourcesConsumed = batchPredictionJobResponse.getResourcesConsumed(); + System.out.println("\tResources Consumed"); + System.out.format("\t\tReplica Hours: %s\n", resourcesConsumed.getReplicaHours()); + + CompletionStats completionStats = batchPredictionJobResponse.getCompletionStats(); + System.out.println("\tCompletion Stats"); + System.out.format("\t\tSuccessful Count: %s\n", completionStats.getSuccessfulCount()); + System.out.format("\t\tFailed Count: %s\n", completionStats.getFailedCount()); + System.out.format("\t\tIncomplete Count: %s\n", completionStats.getIncompleteCount()); + } + } +} +// [END aiplatform_create_batch_prediction_job_video_object_tracking_sample] diff --git a/aiplatform/snippets/src/main/java/aiplatform/CreateDatasetImageSample.java b/aiplatform/snippets/src/main/java/aiplatform/CreateDatasetImageSample.java new file mode 100644 index 00000000000..0ce9767c48f --- /dev/null +++ b/aiplatform/snippets/src/main/java/aiplatform/CreateDatasetImageSample.java @@ -0,0 +1,81 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_create_dataset_image_sample] + +import com.google.api.gax.longrunning.OperationFuture; +import com.google.cloud.aiplatform.v1beta1.CreateDatasetOperationMetadata; +import com.google.cloud.aiplatform.v1beta1.Dataset; +import com.google.cloud.aiplatform.v1beta1.DatasetServiceClient; +import com.google.cloud.aiplatform.v1beta1.DatasetServiceSettings; +import com.google.cloud.aiplatform.v1beta1.LocationName; +import java.io.IOException; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +public class CreateDatasetImageSample { + + public static void main(String[] args) + throws IOException, InterruptedException, ExecutionException, TimeoutException { + // TODO(developer): Replace these variables before running the sample. + String project = "YOUR_PROJECT_ID"; + String datasetDisplayName = "YOUR_DATASET_DISPLAY_NAME"; + createDatasetImageSample(project, datasetDisplayName); + } + + static void createDatasetImageSample(String project, String datasetDisplayName) + throws IOException, InterruptedException, ExecutionException, TimeoutException { + DatasetServiceSettings datasetServiceSettings = + DatasetServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (DatasetServiceClient datasetServiceClient = + DatasetServiceClient.create(datasetServiceSettings)) { + String location = "us-central1"; + String metadataSchemaUri = + "gs://google-cloud-aiplatform/schema/dataset/metadata/image_1.0.0.yaml"; + LocationName locationName = LocationName.of(project, location); + Dataset dataset = + Dataset.newBuilder() + .setDisplayName(datasetDisplayName) + .setMetadataSchemaUri(metadataSchemaUri) + .build(); + + OperationFuture datasetFuture = + datasetServiceClient.createDatasetAsync(locationName, dataset); + System.out.format("Operation name: %s\n", datasetFuture.getInitialFuture().get().getName()); + System.out.println("Waiting for operation to finish..."); + Dataset datasetResponse = datasetFuture.get(120, TimeUnit.SECONDS); + + System.out.println("Create Image Dataset Response"); + System.out.format("Name: %s\n", datasetResponse.getName()); + System.out.format("Display Name: %s\n", datasetResponse.getDisplayName()); + System.out.format("Metadata Schema Uri: %s\n", datasetResponse.getMetadataSchemaUri()); + System.out.format("Metadata: %s\n", datasetResponse.getMetadata()); + System.out.format("Create Time: %s\n", datasetResponse.getCreateTime()); + System.out.format("Update Time: %s\n", datasetResponse.getUpdateTime()); + System.out.format("Labels: %s\n", datasetResponse.getLabelsMap()); + } + } +} +// [END aiplatform_create_dataset_image_sample] diff --git a/aiplatform/snippets/src/main/java/aiplatform/CreateDatasetVideoSample.java b/aiplatform/snippets/src/main/java/aiplatform/CreateDatasetVideoSample.java new file mode 100644 index 00000000000..537525c8136 --- /dev/null +++ b/aiplatform/snippets/src/main/java/aiplatform/CreateDatasetVideoSample.java @@ -0,0 +1,81 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_create_dataset_video_sample] + +import com.google.api.gax.longrunning.OperationFuture; +import com.google.cloud.aiplatform.v1beta1.CreateDatasetOperationMetadata; +import com.google.cloud.aiplatform.v1beta1.Dataset; +import com.google.cloud.aiplatform.v1beta1.DatasetServiceClient; +import com.google.cloud.aiplatform.v1beta1.DatasetServiceSettings; +import com.google.cloud.aiplatform.v1beta1.LocationName; +import java.io.IOException; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +public class CreateDatasetVideoSample { + + public static void main(String[] args) + throws IOException, InterruptedException, ExecutionException, TimeoutException { + // TODO(developer): Replace these variables before running the sample. + String project = "YOUR_PROJECT_ID"; + String datasetVideoDisplayName = "YOUR_DATASET_VIDEO_DISPLAY_NAME"; + createDatasetSample(datasetVideoDisplayName, project); + } + + static void createDatasetSample(String datasetVideoDisplayName, String project) + throws IOException, InterruptedException, ExecutionException, TimeoutException { + DatasetServiceSettings datasetServiceSettings = + DatasetServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (DatasetServiceClient datasetServiceClient = + DatasetServiceClient.create(datasetServiceSettings)) { + String location = "us-central1"; + String metadataSchemaUri = + "gs://google-cloud-aiplatform/schema/dataset/metadata/video_1.0.0.yaml"; + LocationName locationName = LocationName.of(project, location); + Dataset dataset = + Dataset.newBuilder() + .setDisplayName(datasetVideoDisplayName) + .setMetadataSchemaUri(metadataSchemaUri) + .build(); + + OperationFuture datasetFuture = + datasetServiceClient.createDatasetAsync(locationName, dataset); + System.out.format("Operation name: %s\n", datasetFuture.getInitialFuture().get().getName()); + System.out.println("Waiting for operation to finish..."); + Dataset datasetResponse = datasetFuture.get(300, TimeUnit.SECONDS); + + System.out.println("Create Dataset Video Response"); + System.out.format("Name: %s\n", datasetResponse.getName()); + System.out.format("Display Name: %s\n", datasetResponse.getDisplayName()); + System.out.format("Metadata Schema Uri: %s\n", datasetResponse.getMetadataSchemaUri()); + System.out.format("Metadata: %s\n", datasetResponse.getMetadata()); + System.out.format("Create Time: %s\n", datasetResponse.getCreateTime()); + System.out.format("Update Time: %s\n", datasetResponse.getUpdateTime()); + System.out.format("Labels: %s\n", datasetResponse.getLabelsMap()); + } + } +} +// [END aiplatform_create_dataset_video_sample] diff --git a/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineImageClassificationSample.java b/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineImageClassificationSample.java new file mode 100644 index 00000000000..7327cba9b20 --- /dev/null +++ b/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineImageClassificationSample.java @@ -0,0 +1,233 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_create_training_pipeline_image_classification_sample] + +import com.google.cloud.aiplatform.v1beta1.DeployedModelRef; +import com.google.cloud.aiplatform.v1beta1.EnvVar; +import com.google.cloud.aiplatform.v1beta1.ExplanationMetadata; +import com.google.cloud.aiplatform.v1beta1.ExplanationParameters; +import com.google.cloud.aiplatform.v1beta1.ExplanationSpec; +import com.google.cloud.aiplatform.v1beta1.FilterSplit; +import com.google.cloud.aiplatform.v1beta1.FractionSplit; +import com.google.cloud.aiplatform.v1beta1.InputDataConfig; +import com.google.cloud.aiplatform.v1beta1.LocationName; +import com.google.cloud.aiplatform.v1beta1.Model; +import com.google.cloud.aiplatform.v1beta1.Model.ExportFormat; +import com.google.cloud.aiplatform.v1beta1.ModelContainerSpec; +import com.google.cloud.aiplatform.v1beta1.PipelineServiceClient; +import com.google.cloud.aiplatform.v1beta1.PipelineServiceSettings; +import com.google.cloud.aiplatform.v1beta1.Port; +import com.google.cloud.aiplatform.v1beta1.PredefinedSplit; +import com.google.cloud.aiplatform.v1beta1.PredictSchemata; +import com.google.cloud.aiplatform.v1beta1.SampledShapleyAttribution; +import com.google.cloud.aiplatform.v1beta1.TimestampSplit; +import com.google.cloud.aiplatform.v1beta1.TrainingPipeline; +import com.google.protobuf.Any; +import com.google.protobuf.Value; +import com.google.protobuf.util.JsonFormat; +import com.google.rpc.Status; +import java.io.IOException; +import java.util.List; + +public class CreateTrainingPipelineImageClassificationSample { + + public static void main(String[] args) throws IOException { + // TODO(developer): Replace these variables before running the sample. + String trainingPipelineDisplayName = "YOUR_TRAINING_PIPELINE_DISPLAY_NAME"; + String project = "YOUR_PROJECT_ID"; + String datasetId = "YOUR_DATASET_ID"; + String modelDisplayName = "YOUR_MODEL_DISPLAY_NAME"; + createTrainingPipelineImageClassificationSample( + project, trainingPipelineDisplayName, datasetId, modelDisplayName); + } + + static void createTrainingPipelineImageClassificationSample( + String project, String trainingPipelineDisplayName, String datasetId, String modelDisplayName) + throws IOException { + PipelineServiceSettings pipelineServiceSettings = + PipelineServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (PipelineServiceClient pipelineServiceClient = + PipelineServiceClient.create(pipelineServiceSettings)) { + String location = "us-central1"; + String trainingTaskDefinition = + "gs://google-cloud-aiplatform/schema/trainingjob/definition/" + + "automl_image_classification_1.0.0.yaml"; + LocationName locationName = LocationName.of(project, location); + + String jsonString = + "{\"multiLabel\": false, \"modelType\": \"CLOUD\", \"budgetMilliNodeHours\": 8000," + + " \"disableEarlyStopping\": false}"; + Value.Builder trainingTaskInputs = Value.newBuilder(); + JsonFormat.parser().merge(jsonString, trainingTaskInputs); + + InputDataConfig trainingInputDataConfig = + InputDataConfig.newBuilder().setDatasetId(datasetId).build(); + Model model = Model.newBuilder().setDisplayName(modelDisplayName).build(); + TrainingPipeline trainingPipeline = + TrainingPipeline.newBuilder() + .setDisplayName(trainingPipelineDisplayName) + .setTrainingTaskDefinition(trainingTaskDefinition) + .setTrainingTaskInputs(trainingTaskInputs) + .setInputDataConfig(trainingInputDataConfig) + .setModelToUpload(model) + .build(); + + TrainingPipeline trainingPipelineResponse = + pipelineServiceClient.createTrainingPipeline(locationName, trainingPipeline); + + System.out.println("Create Training Pipeline Image Classification Response"); + System.out.format("Name: %s\n", trainingPipelineResponse.getName()); + System.out.format("Display Name: %s\n", trainingPipelineResponse.getDisplayName()); + + System.out.format( + "Training Task Definition %s\n", trainingPipelineResponse.getTrainingTaskDefinition()); + System.out.format( + "Training Task Inputs: %s\n", trainingPipelineResponse.getTrainingTaskInputs()); + System.out.format( + "Training Task Metadata: %s\n", trainingPipelineResponse.getTrainingTaskMetadata()); + System.out.format("State: %s\n", trainingPipelineResponse.getState()); + + System.out.format("Create Time: %s\n", trainingPipelineResponse.getCreateTime()); + System.out.format("StartTime %s\n", trainingPipelineResponse.getStartTime()); + System.out.format("End Time: %s\n", trainingPipelineResponse.getEndTime()); + System.out.format("Update Time: %s\n", trainingPipelineResponse.getUpdateTime()); + System.out.format("Labels: %s\n", trainingPipelineResponse.getLabelsMap()); + + InputDataConfig inputDataConfig = trainingPipelineResponse.getInputDataConfig(); + System.out.println("Input Data Config"); + System.out.format("Dataset Id: %s", inputDataConfig.getDatasetId()); + System.out.format("Annotations Filter: %s\n", inputDataConfig.getAnnotationsFilter()); + + FractionSplit fractionSplit = inputDataConfig.getFractionSplit(); + System.out.println("Fraction Split"); + System.out.format("Training Fraction: %s\n", fractionSplit.getTrainingFraction()); + System.out.format("Validation Fraction: %s\n", fractionSplit.getValidationFraction()); + System.out.format("Test Fraction: %s\n", fractionSplit.getTestFraction()); + + FilterSplit filterSplit = inputDataConfig.getFilterSplit(); + System.out.println("Filter Split"); + System.out.format("Training Filter: %s\n", filterSplit.getTrainingFilter()); + System.out.format("Validation Filter: %s\n", filterSplit.getValidationFilter()); + System.out.format("Test Filter: %s\n", filterSplit.getTestFilter()); + + PredefinedSplit predefinedSplit = inputDataConfig.getPredefinedSplit(); + System.out.println("Predefined Split"); + System.out.format("Key: %s\n", predefinedSplit.getKey()); + + TimestampSplit timestampSplit = inputDataConfig.getTimestampSplit(); + System.out.println("Timestamp Split"); + System.out.format("Training Fraction: %s\n", timestampSplit.getTrainingFraction()); + System.out.format("Validation Fraction: %s\n", timestampSplit.getValidationFraction()); + System.out.format("Test Fraction: %s\n", timestampSplit.getTestFraction()); + System.out.format("Key: %s\n", timestampSplit.getKey()); + + Model modelResponse = trainingPipelineResponse.getModelToUpload(); + System.out.println("Model To Upload"); + System.out.format("Name: %s\n", modelResponse.getName()); + System.out.format("Display Name: %s\n", modelResponse.getDisplayName()); + System.out.format("Description: %s\n", modelResponse.getDescription()); + + System.out.format("Metadata Schema Uri: %s\n", modelResponse.getMetadataSchemaUri()); + System.out.format("Metadata: %s\n", modelResponse.getMetadata()); + System.out.format("Training Pipeline: %s\n", modelResponse.getTrainingPipeline()); + System.out.format("Artifact Uri: %s\n", modelResponse.getArtifactUri()); + + System.out.format( + "Supported Deployment Resources Types: %s\n", + modelResponse.getSupportedDeploymentResourcesTypesList()); + System.out.format( + "Supported Input Storage Formats: %s\n", + modelResponse.getSupportedInputStorageFormatsList()); + System.out.format( + "Supported Output Storage Formats: %s\n", + modelResponse.getSupportedOutputStorageFormatsList()); + + System.out.format("Create Time: %s\n", modelResponse.getCreateTime()); + System.out.format("Update Time: %s\n", modelResponse.getUpdateTime()); + System.out.format("Labels: %sn\n", modelResponse.getLabelsMap()); + + PredictSchemata predictSchemata = modelResponse.getPredictSchemata(); + System.out.println("Predict Schemata"); + System.out.format("Instance Schema Uri: %s\n", predictSchemata.getInstanceSchemaUri()); + System.out.format("Parameters Schema Uri: %s\n", predictSchemata.getParametersSchemaUri()); + System.out.format("Prediction Schema Uri: %s\n", predictSchemata.getPredictionSchemaUri()); + + for (ExportFormat exportFormat : modelResponse.getSupportedExportFormatsList()) { + System.out.println("Supported Export Format"); + System.out.format("Id: %s\n", exportFormat.getId()); + } + + ModelContainerSpec modelContainerSpec = modelResponse.getContainerSpec(); + System.out.println("Container Spec"); + System.out.format("Image Uri: %s\n", modelContainerSpec.getImageUri()); + System.out.format("Command: %s\n", modelContainerSpec.getCommandList()); + System.out.format("Args: %s\n", modelContainerSpec.getArgsList()); + System.out.format("Predict Route: %s\n", modelContainerSpec.getPredictRoute()); + System.out.format("Health Route: %s\n", modelContainerSpec.getHealthRoute()); + + for (EnvVar envVar : modelContainerSpec.getEnvList()) { + System.out.println("Env"); + System.out.format("Name: %s\n", envVar.getName()); + System.out.format("Value: %s\n", envVar.getValue()); + } + + for (Port port : modelContainerSpec.getPortsList()) { + System.out.println("Port"); + System.out.format("Container Port: %s\n", port.getContainerPort()); + } + + for (DeployedModelRef deployedModelRef : modelResponse.getDeployedModelsList()) { + System.out.println("Deployed Model"); + System.out.format("Endpoint: %s\n", deployedModelRef.getEndpoint()); + System.out.format("Deployed Model Id: %s\n", deployedModelRef.getDeployedModelId()); + } + + ExplanationSpec explanationSpec = modelResponse.getExplanationSpec(); + System.out.println("Explanation Spec"); + + ExplanationParameters explanationParameters = explanationSpec.getParameters(); + System.out.println("Parameters"); + + SampledShapleyAttribution sampledShapleyAttribution = + explanationParameters.getSampledShapleyAttribution(); + System.out.println("Sampled Shapley Attribution"); + System.out.format("Path Count: %s\n", sampledShapleyAttribution.getPathCount()); + + ExplanationMetadata explanationMetadata = explanationSpec.getMetadata(); + System.out.println("Metadata"); + System.out.format("Inputs: %s\n", explanationMetadata.getInputsMap()); + System.out.format("Outputs: %s\n", explanationMetadata.getOutputsMap()); + System.out.format( + "Feature Attributions Schema_uri: %s\n", + explanationMetadata.getFeatureAttributionsSchemaUri()); + + Status status = trainingPipelineResponse.getError(); + System.out.println("Error"); + System.out.format("Code: %s\n", status.getCode()); + System.out.format("Message: %s\n", status.getMessage()); + } + } +} +// [END aiplatform_create_training_pipeline_image_classification_sample] diff --git a/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineImageObjectDetectionSample.java b/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineImageObjectDetectionSample.java new file mode 100644 index 00000000000..636ab022418 --- /dev/null +++ b/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineImageObjectDetectionSample.java @@ -0,0 +1,233 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_create_training_pipeline_image_object_detection_sample] + +import com.google.cloud.aiplatform.v1beta1.DeployedModelRef; +import com.google.cloud.aiplatform.v1beta1.EnvVar; +import com.google.cloud.aiplatform.v1beta1.ExplanationMetadata; +import com.google.cloud.aiplatform.v1beta1.ExplanationParameters; +import com.google.cloud.aiplatform.v1beta1.ExplanationSpec; +import com.google.cloud.aiplatform.v1beta1.FilterSplit; +import com.google.cloud.aiplatform.v1beta1.FractionSplit; +import com.google.cloud.aiplatform.v1beta1.InputDataConfig; +import com.google.cloud.aiplatform.v1beta1.LocationName; +import com.google.cloud.aiplatform.v1beta1.Model; +import com.google.cloud.aiplatform.v1beta1.Model.ExportFormat; +import com.google.cloud.aiplatform.v1beta1.ModelContainerSpec; +import com.google.cloud.aiplatform.v1beta1.PipelineServiceClient; +import com.google.cloud.aiplatform.v1beta1.PipelineServiceSettings; +import com.google.cloud.aiplatform.v1beta1.Port; +import com.google.cloud.aiplatform.v1beta1.PredefinedSplit; +import com.google.cloud.aiplatform.v1beta1.PredictSchemata; +import com.google.cloud.aiplatform.v1beta1.SampledShapleyAttribution; +import com.google.cloud.aiplatform.v1beta1.TimestampSplit; +import com.google.cloud.aiplatform.v1beta1.TrainingPipeline; +import com.google.protobuf.Any; +import com.google.protobuf.Value; +import com.google.protobuf.util.JsonFormat; +import com.google.rpc.Status; +import java.io.IOException; +import java.util.List; + +public class CreateTrainingPipelineImageObjectDetectionSample { + + public static void main(String[] args) throws IOException { + // TODO(developer): Replace these variables before running the sample. + String trainingPipelineDisplayName = "YOUR_TRAINING_PIPELINE_DISPLAY_NAME"; + String project = "YOUR_PROJECT_ID"; + String datasetId = "YOUR_DATASET_ID"; + String modelDisplayName = "YOUR_MODEL_DISPLAY_NAME"; + createTrainingPipelineImageObjectDetectionSample( + project, trainingPipelineDisplayName, datasetId, modelDisplayName); + } + + static void createTrainingPipelineImageObjectDetectionSample( + String project, String trainingPipelineDisplayName, String datasetId, String modelDisplayName) + throws IOException { + PipelineServiceSettings pipelineServiceSettings = + PipelineServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (PipelineServiceClient pipelineServiceClient = + PipelineServiceClient.create(pipelineServiceSettings)) { + String location = "us-central1"; + String trainingTaskDefinition = + "gs://google-cloud-aiplatform/schema/trainingjob/definition/" + + "automl_image_object_detection_1.0.0.yaml"; + LocationName locationName = LocationName.of(project, location); + + String jsonString = + "{\"modelType\": \"CLOUD_HIGH_ACCURACY_1\", \"budgetMilliNodeHours\": 20000," + + " \"disableEarlyStopping\": false}"; + Value.Builder trainingTaskInputs = Value.newBuilder(); + JsonFormat.parser().merge(jsonString, trainingTaskInputs); + + InputDataConfig trainingInputDataConfig = + InputDataConfig.newBuilder().setDatasetId(datasetId).build(); + Model model = Model.newBuilder().setDisplayName(modelDisplayName).build(); + TrainingPipeline trainingPipeline = + TrainingPipeline.newBuilder() + .setDisplayName(trainingPipelineDisplayName) + .setTrainingTaskDefinition(trainingTaskDefinition) + .setTrainingTaskInputs(trainingTaskInputs) + .setInputDataConfig(trainingInputDataConfig) + .setModelToUpload(model) + .build(); + + TrainingPipeline trainingPipelineResponse = + pipelineServiceClient.createTrainingPipeline(locationName, trainingPipeline); + + System.out.println("Create Training Pipeline Image Object Detection Response"); + System.out.format("Name: %s\n", trainingPipelineResponse.getName()); + System.out.format("Display Name: %s\n", trainingPipelineResponse.getDisplayName()); + + System.out.format( + "Training Task Definition %s\n", trainingPipelineResponse.getTrainingTaskDefinition()); + System.out.format( + "Training Task Inputs: %s\n", trainingPipelineResponse.getTrainingTaskInputs()); + System.out.format( + "Training Task Metadata: %s\n", trainingPipelineResponse.getTrainingTaskMetadata()); + System.out.format("State: %s\n", trainingPipelineResponse.getState()); + + System.out.format("Create Time: %s\n", trainingPipelineResponse.getCreateTime()); + System.out.format("StartTime %s\n", trainingPipelineResponse.getStartTime()); + System.out.format("End Time: %s\n", trainingPipelineResponse.getEndTime()); + System.out.format("Update Time: %s\n", trainingPipelineResponse.getUpdateTime()); + System.out.format("Labels: %s\n", trainingPipelineResponse.getLabelsMap()); + + InputDataConfig inputDataConfig = trainingPipelineResponse.getInputDataConfig(); + System.out.println("Input Data Config"); + System.out.format("Dataset Id: %s", inputDataConfig.getDatasetId()); + System.out.format("Annotations Filter: %s\n", inputDataConfig.getAnnotationsFilter()); + + FractionSplit fractionSplit = inputDataConfig.getFractionSplit(); + System.out.println("Fraction Split"); + System.out.format("Training Fraction: %s\n", fractionSplit.getTrainingFraction()); + System.out.format("Validation Fraction: %s\n", fractionSplit.getValidationFraction()); + System.out.format("Test Fraction: %s\n", fractionSplit.getTestFraction()); + + FilterSplit filterSplit = inputDataConfig.getFilterSplit(); + System.out.println("Filter Split"); + System.out.format("Training Filter: %s\n", filterSplit.getTrainingFilter()); + System.out.format("Validation Filter: %s\n", filterSplit.getValidationFilter()); + System.out.format("Test Filter: %s\n", filterSplit.getTestFilter()); + + PredefinedSplit predefinedSplit = inputDataConfig.getPredefinedSplit(); + System.out.println("Predefined Split"); + System.out.format("Key: %s\n", predefinedSplit.getKey()); + + TimestampSplit timestampSplit = inputDataConfig.getTimestampSplit(); + System.out.println("Timestamp Split"); + System.out.format("Training Fraction: %s\n", timestampSplit.getTrainingFraction()); + System.out.format("Validation Fraction: %s\n", timestampSplit.getValidationFraction()); + System.out.format("Test Fraction: %s\n", timestampSplit.getTestFraction()); + System.out.format("Key: %s\n", timestampSplit.getKey()); + + Model modelResponse = trainingPipelineResponse.getModelToUpload(); + System.out.println("Model To Upload"); + System.out.format("Name: %s\n", modelResponse.getName()); + System.out.format("Display Name: %s\n", modelResponse.getDisplayName()); + System.out.format("Description: %s\n", modelResponse.getDescription()); + + System.out.format("Metadata Schema Uri: %s\n", modelResponse.getMetadataSchemaUri()); + System.out.format("Metadata: %s\n", modelResponse.getMetadata()); + System.out.format("Training Pipeline: %s\n", modelResponse.getTrainingPipeline()); + System.out.format("Artifact Uri: %s\n", modelResponse.getArtifactUri()); + + System.out.format( + "Supported Deployment Resources Types: %s\n", + modelResponse.getSupportedDeploymentResourcesTypesList()); + System.out.format( + "Supported Input Storage Formats: %s\n", + modelResponse.getSupportedInputStorageFormatsList()); + System.out.format( + "Supported Output Storage Formats: %s\n", + modelResponse.getSupportedOutputStorageFormatsList()); + + System.out.format("Create Time: %s\n", modelResponse.getCreateTime()); + System.out.format("Update Time: %s\n", modelResponse.getUpdateTime()); + System.out.format("Labels: %sn\n", modelResponse.getLabelsMap()); + + PredictSchemata predictSchemata = modelResponse.getPredictSchemata(); + System.out.println("Predict Schemata"); + System.out.format("Instance Schema Uri: %s\n", predictSchemata.getInstanceSchemaUri()); + System.out.format("Parameters Schema Uri: %s\n", predictSchemata.getParametersSchemaUri()); + System.out.format("Prediction Schema Uri: %s\n", predictSchemata.getPredictionSchemaUri()); + + for (ExportFormat exportFormat : modelResponse.getSupportedExportFormatsList()) { + System.out.println("Supported Export Format"); + System.out.format("Id: %s\n", exportFormat.getId()); + } + + ModelContainerSpec modelContainerSpec = modelResponse.getContainerSpec(); + System.out.println("Container Spec"); + System.out.format("Image Uri: %s\n", modelContainerSpec.getImageUri()); + System.out.format("Command: %s\n", modelContainerSpec.getCommandList()); + System.out.format("Args: %s\n", modelContainerSpec.getArgsList()); + System.out.format("Predict Route: %s\n", modelContainerSpec.getPredictRoute()); + System.out.format("Health Route: %s\n", modelContainerSpec.getHealthRoute()); + + for (EnvVar envVar : modelContainerSpec.getEnvList()) { + System.out.println("Env"); + System.out.format("Name: %s\n", envVar.getName()); + System.out.format("Value: %s\n", envVar.getValue()); + } + + for (Port port : modelContainerSpec.getPortsList()) { + System.out.println("Port"); + System.out.format("Container Port: %s\n", port.getContainerPort()); + } + + for (DeployedModelRef deployedModelRef : modelResponse.getDeployedModelsList()) { + System.out.println("Deployed Model"); + System.out.format("Endpoint: %s\n", deployedModelRef.getEndpoint()); + System.out.format("Deployed Model Id: %s\n", deployedModelRef.getDeployedModelId()); + } + + ExplanationSpec explanationSpec = modelResponse.getExplanationSpec(); + System.out.println("Explanation Spec"); + + ExplanationParameters explanationParameters = explanationSpec.getParameters(); + System.out.println("Parameters"); + + SampledShapleyAttribution sampledShapleyAttribution = + explanationParameters.getSampledShapleyAttribution(); + System.out.println("Sampled Shapley Attribution"); + System.out.format("Path Count: %s\n", sampledShapleyAttribution.getPathCount()); + + ExplanationMetadata explanationMetadata = explanationSpec.getMetadata(); + System.out.println("Metadata"); + System.out.format("Inputs: %s\n", explanationMetadata.getInputsMap()); + System.out.format("Outputs: %s\n", explanationMetadata.getOutputsMap()); + System.out.format( + "Feature Attributions Schema_uri: %s\n", + explanationMetadata.getFeatureAttributionsSchemaUri()); + + Status status = trainingPipelineResponse.getError(); + System.out.println("Error"); + System.out.format("Code: %s\n", status.getCode()); + System.out.format("Message: %s\n", status.getMessage()); + } + } +} +// [END aiplatform_create_training_pipeline_image_object_detection_sample] diff --git a/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineVideoClassificationSample.java b/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineVideoClassificationSample.java new file mode 100644 index 00000000000..383e56954af --- /dev/null +++ b/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineVideoClassificationSample.java @@ -0,0 +1,162 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_create_training_pipeline_video_classification_sample] + +import com.google.cloud.aiplatform.v1beta1.FilterSplit; +import com.google.cloud.aiplatform.v1beta1.FractionSplit; +import com.google.cloud.aiplatform.v1beta1.InputDataConfig; +import com.google.cloud.aiplatform.v1beta1.LocationName; +import com.google.cloud.aiplatform.v1beta1.Model; +import com.google.cloud.aiplatform.v1beta1.PipelineServiceClient; +import com.google.cloud.aiplatform.v1beta1.PipelineServiceSettings; +import com.google.cloud.aiplatform.v1beta1.PredefinedSplit; +import com.google.cloud.aiplatform.v1beta1.TimestampSplit; +import com.google.cloud.aiplatform.v1beta1.TrainingPipeline; +import com.google.protobuf.Any; +import com.google.protobuf.Value; +import com.google.rpc.Status; +import java.io.IOException; +import java.util.List; + +public class CreateTrainingPipelineVideoClassificationSample { + + public static void main(String[] args) throws IOException { + // TODO(developer): Replace these variables before running the sample. + String videoClassificationDisplayName = + "YOUR_TRAINING_PIPELINE_VIDEO_CLASSIFICATION_DISPLAY_NAME"; + String datasetId = "YOUR_DATASET_ID"; + String modelDisplayName = "YOUR_MODEL_DISPLAY_NAME"; + String project = "YOUR_PROJECT_ID"; + createTrainingPipelineVideoClassification( + videoClassificationDisplayName, datasetId, modelDisplayName, project); + } + + static void createTrainingPipelineVideoClassification( + String videoClassificationDisplayName, + String datasetId, + String modelDisplayName, + String project) + throws IOException { + PipelineServiceSettings pipelineServiceSettings = + PipelineServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (PipelineServiceClient pipelineServiceClient = + PipelineServiceClient.create(pipelineServiceSettings)) { + String location = "us-central1"; + LocationName locationName = LocationName.of(project, location); + String trainingTaskDefinition = + "gs://google-cloud-aiplatform/schema/trainingjob/definition/" + + "automl_video_classification_1.0.0.yaml"; + + InputDataConfig inputDataConfig = + InputDataConfig.newBuilder().setDatasetId(datasetId).build(); + Model model = Model.newBuilder().setDisplayName(modelDisplayName).build(); + + TrainingPipeline trainingPipeline = + TrainingPipeline.newBuilder() + .setDisplayName(videoClassificationDisplayName) + .setTrainingTaskDefinition(trainingTaskDefinition) + .setTrainingTaskInputs(Value.newBuilder()) + .setInputDataConfig(inputDataConfig) + .setModelToUpload(model) + .build(); + + TrainingPipeline trainingPipelineResponse = + pipelineServiceClient.createTrainingPipeline(locationName, trainingPipeline); + + System.out.println("Create Training Pipeline Video Classification Response"); + System.out.format("\tName: %s\n", trainingPipelineResponse.getName()); + System.out.format("\tDisplay Name: %s\n", trainingPipelineResponse.getDisplayName()); + System.out.format( + "\tTraining Task Definition: %s\n", trainingPipelineResponse.getTrainingTaskDefinition()); + System.out.format( + "\tTraining Task Inputs: %s\n", trainingPipelineResponse.getTrainingTaskInputs()); + System.out.format( + "\tTraining Task Metadata: %s\n", trainingPipelineResponse.getTrainingTaskMetadata()); + System.out.format("\tState: %s\n", trainingPipelineResponse.getState()); + System.out.format("\tCreate Time: %s\n", trainingPipelineResponse.getCreateTime()); + System.out.format("\tStart Time: %s\n", trainingPipelineResponse.getStartTime()); + System.out.format("\tEnd Time: %s\n", trainingPipelineResponse.getEndTime()); + System.out.format("\tUpdate Time: %s\n", trainingPipelineResponse.getUpdateTime()); + System.out.format("\tLabels: %s\n", trainingPipelineResponse.getLabelsMap()); + + InputDataConfig inputDataConfigResponse = trainingPipelineResponse.getInputDataConfig(); + System.out.println("\tInput Data Config"); + System.out.format("\t\tDataset Id: %s\n", inputDataConfigResponse.getDatasetId()); + System.out.format( + "\t\tAnnotations Filter: %s\n", inputDataConfigResponse.getAnnotationsFilter()); + + FractionSplit fractionSplit = inputDataConfigResponse.getFractionSplit(); + System.out.println("\t\tFraction Split"); + System.out.format("\t\t\tTraining Fraction: %s\n", fractionSplit.getTrainingFraction()); + System.out.format("\t\t\tValidation Fraction: %s\n", fractionSplit.getValidationFraction()); + System.out.format("\t\t\tTest Fraction: %s\n", fractionSplit.getTestFraction()); + + FilterSplit filterSplit = inputDataConfigResponse.getFilterSplit(); + System.out.println("\t\tFilter Split"); + System.out.format("\t\t\tTraining Fraction: %s\n", filterSplit.getTrainingFilter()); + System.out.format("\t\t\tValidation Fraction: %s\n", filterSplit.getValidationFilter()); + System.out.format("\t\t\tTest Fraction: %s\n", filterSplit.getTestFilter()); + + PredefinedSplit predefinedSplit = inputDataConfigResponse.getPredefinedSplit(); + System.out.println("\t\tPredefined Split"); + System.out.format("\t\t\tKey: %s\n", predefinedSplit.getKey()); + + TimestampSplit timestampSplit = inputDataConfigResponse.getTimestampSplit(); + System.out.println("\t\tTimestamp Split"); + System.out.format("\t\t\tTraining Fraction: %s\n", timestampSplit.getTrainingFraction()); + System.out.format("\t\t\tValidation Fraction: %s\n", timestampSplit.getValidationFraction()); + System.out.format("\t\t\tTest Fraction: %s\n", timestampSplit.getTestFraction()); + System.out.format("\t\t\tKey: %s\n", timestampSplit.getKey()); + + Model modelResponse = trainingPipelineResponse.getModelToUpload(); + System.out.println("\tModel To Upload"); + System.out.format("\t\tName: %s\n", modelResponse.getName()); + System.out.format("\t\tDisplay Name: %s\n", modelResponse.getDisplayName()); + System.out.format("\t\tDescription: %s\n", modelResponse.getDescription()); + System.out.format("\t\tMetadata Schema Uri: %s\n", modelResponse.getMetadataSchemaUri()); + System.out.format("\t\tMeta Data: %s\n", modelResponse.getMetadata()); + System.out.format("\t\tTraining Pipeline: %s\n", modelResponse.getTrainingPipeline()); + System.out.format("\t\tArtifact Uri: %s\n", modelResponse.getArtifactUri()); + System.out.format( + "\t\tSupported Deployment Resources Types: %s\n", + modelResponse.getSupportedDeploymentResourcesTypesList().toString()); + System.out.format( + "\t\tSupported Input Storage Formats: %s\n", + modelResponse.getSupportedInputStorageFormatsList().toString()); + System.out.format( + "\t\tSupported Output Storage Formats: %s\n", + modelResponse.getSupportedOutputStorageFormatsList().toString()); + System.out.format("\t\tCreate Time: %s\n", modelResponse.getCreateTime()); + System.out.format("\t\tUpdate Time: %s\n", modelResponse.getUpdateTime()); + System.out.format("\t\tLables: %s\n", modelResponse.getLabelsMap()); + + Status status = trainingPipelineResponse.getError(); + System.out.println("\tError"); + System.out.format("\t\tCode: %s\n", status.getCode()); + System.out.format("\t\tMessage: %s\n", status.getMessage()); + } + } +} +// [END aiplatform_create_training_pipeline_video_classification_sample] diff --git a/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineVideoObjectTrackingSample.java b/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineVideoObjectTrackingSample.java new file mode 100644 index 00000000000..d49fcff96dd --- /dev/null +++ b/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineVideoObjectTrackingSample.java @@ -0,0 +1,174 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_create_training_pipeline_video_object_tracking_sample] + +import com.google.cloud.aiplatform.v1beta1.FilterSplit; +import com.google.cloud.aiplatform.v1beta1.FractionSplit; +import com.google.cloud.aiplatform.v1beta1.InputDataConfig; +import com.google.cloud.aiplatform.v1beta1.LocationName; +import com.google.cloud.aiplatform.v1beta1.Model; +import com.google.cloud.aiplatform.v1beta1.PipelineServiceClient; +import com.google.cloud.aiplatform.v1beta1.PipelineServiceSettings; +import com.google.cloud.aiplatform.v1beta1.PredefinedSplit; +import com.google.cloud.aiplatform.v1beta1.TimestampSplit; +import com.google.cloud.aiplatform.v1beta1.TrainingPipeline; +import com.google.protobuf.Any; +import com.google.protobuf.Value; +import com.google.protobuf.util.JsonFormat; +import com.google.rpc.Status; +import java.io.IOException; +import java.util.List; + +public class CreateTrainingPipelineVideoObjectTrackingSample { + + public static void main(String[] args) throws IOException { + String trainingPipelineVideoObjectTracking = + "YOUR_TRAINING_PIPELINE_VIDEO_OBJECT_TRACKING_DISPLAY_NAME"; + String datasetId = "YOUR_DATASET_ID"; + String modelDisplayName = "YOUR_MODEL_DISPLAY_NAME"; + String project = "YOUR_PROJECT_ID"; + createTrainingPipelineVideoObjectTracking( + trainingPipelineVideoObjectTracking, datasetId, modelDisplayName, project); + } + + static void createTrainingPipelineVideoObjectTracking( + String trainingPipelineVideoObjectTracking, + String datasetId, + String modelDisplayName, + String project) + throws IOException { + PipelineServiceSettings pipelineServiceSettings = + PipelineServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (PipelineServiceClient pipelineServiceClient = + PipelineServiceClient.create(pipelineServiceSettings)) { + String location = "us-central1"; + String trainingTaskDefinition = + "gs://google-cloud-aiplatform/schema/trainingjob/definition/" + + "automl_video_object_tracking_1.0.0.yaml"; + LocationName locationName = LocationName.of(project, location); + + String jsonString = "{\"modelType\": \"CLOUD\"}"; + Value.Builder trainingTaskInputs = Value.newBuilder(); + JsonFormat.parser().merge(jsonString, trainingTaskInputs); + + InputDataConfig inputDataConfig = + InputDataConfig.newBuilder().setDatasetId(datasetId).build(); + Model modelToUpload = Model.newBuilder().setDisplayName(modelDisplayName).build(); + TrainingPipeline trainingPipeline = + TrainingPipeline.newBuilder() + .setDisplayName(trainingPipelineVideoObjectTracking) + .setTrainingTaskDefinition(trainingTaskDefinition) + .setTrainingTaskInputs(trainingTaskInputs) + .setInputDataConfig(inputDataConfig) + .setModelToUpload(modelToUpload) + .build(); + + TrainingPipeline createTrainingPipelineResponse = + pipelineServiceClient.createTrainingPipeline(locationName, trainingPipeline); + + System.out.println("Create Training Pipeline Video Object Tracking Response"); + System.out.format("Name: %s\n", createTrainingPipelineResponse.getName()); + System.out.format("Display Name: %s\n", createTrainingPipelineResponse.getDisplayName()); + + System.out.format( + "Training Task Definition %s\n", + createTrainingPipelineResponse.getTrainingTaskDefinition()); + System.out.format( + "Training Task Inputs: %s\n", + createTrainingPipelineResponse.getTrainingTaskInputs().toString()); + System.out.format( + "Training Task Metadata: %s\n", + createTrainingPipelineResponse.getTrainingTaskMetadata().toString()); + + System.out.format("State: %s\n", createTrainingPipelineResponse.getState().toString()); + System.out.format( + "Create Time: %s\n", createTrainingPipelineResponse.getCreateTime().toString()); + System.out.format("StartTime %s\n", createTrainingPipelineResponse.getStartTime().toString()); + System.out.format("End Time: %s\n", createTrainingPipelineResponse.getEndTime().toString()); + System.out.format( + "Update Time: %s\n", createTrainingPipelineResponse.getUpdateTime().toString()); + System.out.format("Labels: %s\n", createTrainingPipelineResponse.getLabelsMap().toString()); + + InputDataConfig inputDataConfigResponse = createTrainingPipelineResponse.getInputDataConfig(); + System.out.println("Input Data config"); + System.out.format("Dataset Id: %s\n", inputDataConfigResponse.getDatasetId()); + System.out.format("Annotations Filter: %s\n", inputDataConfigResponse.getAnnotationsFilter()); + + FractionSplit fractionSplit = inputDataConfigResponse.getFractionSplit(); + System.out.println("Fraction split"); + System.out.format("Training Fraction: %s\n", fractionSplit.getTrainingFraction()); + System.out.format("Validation Fraction: %s\n", fractionSplit.getValidationFraction()); + System.out.format("Test Fraction: %s\n", fractionSplit.getTestFraction()); + + FilterSplit filterSplit = inputDataConfigResponse.getFilterSplit(); + System.out.println("Filter Split"); + System.out.format("Training Filter: %s\n", filterSplit.getTrainingFilter()); + System.out.format("Validation Filter: %s\n", filterSplit.getValidationFilter()); + System.out.format("Test Filter: %s\n", filterSplit.getTestFilter()); + + PredefinedSplit predefinedSplit = inputDataConfigResponse.getPredefinedSplit(); + System.out.println("Predefined Split"); + System.out.format("Key: %s\n", predefinedSplit.getKey()); + + TimestampSplit timestampSplit = inputDataConfigResponse.getTimestampSplit(); + System.out.println("Timestamp Split"); + System.out.format("Training Fraction: %s\n", timestampSplit.getTrainingFraction()); + System.out.format("Validation Fraction: %s\n", timestampSplit.getValidationFraction()); + System.out.format("Test Fraction: %s\n", timestampSplit.getTestFraction()); + System.out.format("Key: %s\n", timestampSplit.getKey()); + + Model modelResponse = createTrainingPipelineResponse.getModelToUpload(); + System.out.println("Model To Upload"); + System.out.format("Name: %s\n", modelResponse.getName()); + System.out.format("Display Name: %s\n", modelResponse.getDisplayName()); + System.out.format("Description: %s\n", modelResponse.getDescription()); + System.out.format("Metadata Schema Uri: %s\n", modelResponse.getMetadataSchemaUri()); + System.out.format("Metadata: %s\n", modelResponse.getMetadata()); + + System.out.format("Training Pipeline: %s\n", modelResponse.getTrainingPipeline()); + System.out.format("Artifact Uri: %s\n", modelResponse.getArtifactUri()); + + System.out.format( + "Supported Deployment Resources Types: %s\n", + modelResponse.getSupportedDeploymentResourcesTypesList().toString()); + System.out.format( + "Supported Input Storage Formats: %s\n", + modelResponse.getSupportedInputStorageFormatsList().toString()); + System.out.format( + "Supported Output Storage Formats: %s\n", + modelResponse.getSupportedOutputStorageFormatsList().toString()); + + System.out.format("Create Time: %s\n", modelResponse.getCreateTime()); + System.out.format("Update Time: %s\n", modelResponse.getUpdateTime()); + System.out.format("Labels: %s\n", modelResponse.getLabelsMap()); + + Status status = createTrainingPipelineResponse.getError(); + System.out.println("Error"); + System.out.format("Code: %s\n", status.getCode()); + System.out.format("Message: %s\n", status.getMessage()); + } + } +} +// [END aiplatform_create_training_pipeline_video_object_tracking_sample] diff --git a/aiplatform/snippets/src/main/java/aiplatform/DeleteBatchPredictionJobSample.java b/aiplatform/snippets/src/main/java/aiplatform/DeleteBatchPredictionJobSample.java new file mode 100644 index 00000000000..c128689d78a --- /dev/null +++ b/aiplatform/snippets/src/main/java/aiplatform/DeleteBatchPredictionJobSample.java @@ -0,0 +1,68 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_delete_batch_prediction_job_sample] + +import com.google.api.gax.longrunning.OperationFuture; +import com.google.cloud.aiplatform.v1beta1.BatchPredictionJobName; +import com.google.cloud.aiplatform.v1beta1.DeleteOperationMetadata; +import com.google.cloud.aiplatform.v1beta1.JobServiceClient; +import com.google.cloud.aiplatform.v1beta1.JobServiceSettings; +import com.google.protobuf.Empty; +import java.io.IOException; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +public class DeleteBatchPredictionJobSample { + + public static void main(String[] args) + throws IOException, InterruptedException, ExecutionException, TimeoutException { + // TODO(developer): Replace these variables before running the sample. + String project = "YOUR_PROJECT_ID"; + String batchPredictionJobId = "YOUR_BATCH_PREDICTION_JOB_ID"; + deleteBatchPredictionJobSample(project, batchPredictionJobId); + } + + static void deleteBatchPredictionJobSample(String project, String batchPredictionJobId) + throws IOException, InterruptedException, ExecutionException, TimeoutException { + JobServiceSettings jobServiceSettings = + JobServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (JobServiceClient jobServiceClient = JobServiceClient.create(jobServiceSettings)) { + String location = "us-central1"; + + BatchPredictionJobName batchPredictionJobName = + BatchPredictionJobName.of(project, location, batchPredictionJobId); + + OperationFuture operationFuture = + jobServiceClient.deleteBatchPredictionJobAsync(batchPredictionJobName); + System.out.format("Operation name: %s\n", operationFuture.getInitialFuture().get().getName()); + System.out.println("Waiting for operation to finish..."); + operationFuture.get(300, TimeUnit.SECONDS); + + System.out.println("Deleted Batch Prediction Job."); + } + } +} +// [END aiplatform_delete_batch_prediction_job_sample] diff --git a/aiplatform/snippets/src/main/java/aiplatform/DeleteDatasetSample.java b/aiplatform/snippets/src/main/java/aiplatform/DeleteDatasetSample.java index 39ad52d0fdf..a9989b564e1 100644 --- a/aiplatform/snippets/src/main/java/aiplatform/DeleteDatasetSample.java +++ b/aiplatform/snippets/src/main/java/aiplatform/DeleteDatasetSample.java @@ -59,7 +59,7 @@ static void deleteDatasetSample(String project, String datasetId) System.out.format("Operation name: %s\n", operationFuture.getInitialFuture().get().getName()); System.out.println("Waiting for operation to finish..."); operationFuture.get(300, TimeUnit.SECONDS); - + System.out.format("Deleted Dataset."); } } diff --git a/aiplatform/snippets/src/main/java/aiplatform/GetModelEvaluationImageClassificationSample.java b/aiplatform/snippets/src/main/java/aiplatform/GetModelEvaluationImageClassificationSample.java new file mode 100644 index 00000000000..7c1a3bfad6a --- /dev/null +++ b/aiplatform/snippets/src/main/java/aiplatform/GetModelEvaluationImageClassificationSample.java @@ -0,0 +1,78 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_get_model_evaluation_image_classification_sample] + +import com.google.cloud.aiplatform.v1beta1.Attribution; +import com.google.cloud.aiplatform.v1beta1.ModelEvaluation; +import com.google.cloud.aiplatform.v1beta1.ModelEvaluationName; +import com.google.cloud.aiplatform.v1beta1.ModelExplanation; +import com.google.cloud.aiplatform.v1beta1.ModelServiceClient; +import com.google.cloud.aiplatform.v1beta1.ModelServiceSettings; +import java.io.IOException; + +public class GetModelEvaluationImageClassificationSample { + + public static void main(String[] args) throws IOException { + // TODO(developer): Replace these variables before running the sample. + String project = "YOUR_PROJECT_ID"; + String modelId = "YOUR_MODEL_ID"; + String evaluationId = "YOUR_EVALUATION_ID"; + getModelEvaluationImageClassificationSample(project, modelId, evaluationId); + } + + static void getModelEvaluationImageClassificationSample( + String project, String modelId, String evaluationId) throws IOException { + ModelServiceSettings modelServiceSettings = + ModelServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (ModelServiceClient modelServiceClient = ModelServiceClient.create(modelServiceSettings)) { + String location = "us-central1"; + ModelEvaluationName modelEvaluationName = + ModelEvaluationName.of(project, location, modelId, evaluationId); + + ModelEvaluation modelEvaluation = modelServiceClient.getModelEvaluation(modelEvaluationName); + + System.out.println("Get Model Evaluation Image Classification Response"); + System.out.format("Model Name: %s\n", modelEvaluation.getName()); + System.out.format("Metrics Schema Uri: %s\n", modelEvaluation.getMetricsSchemaUri()); + System.out.format("Metrics: %s\n", modelEvaluation.getMetrics()); + System.out.format("Create Time: %s\n", modelEvaluation.getCreateTime()); + System.out.format("Slice Dimensions: %s\n", modelEvaluation.getSliceDimensionsList()); + + ModelExplanation modelExplanation = modelEvaluation.getModelExplanation(); + for (Attribution attribution : modelExplanation.getMeanAttributionsList()) { + System.out.println("\t\tMean Attribution"); + System.out.format( + "\t\t\tBaseline Output Value: %s\n", attribution.getBaselineOutputValue()); + System.out.format( + "\t\t\tInstance Output Value: %s\n", attribution.getInstanceOutputValue()); + System.out.format("\t\t\tFeature Attributions: %s\n", attribution.getFeatureAttributions()); + System.out.format("\t\t\tOutput Index: %s\n", attribution.getOutputIndexList()); + System.out.format("\t\t\tOutput Display Name: %s\n", attribution.getOutputDisplayName()); + System.out.format("\t\t\tApproximation Error: %s\n", attribution.getApproximationError()); + } + } + } +} +// [END aiplatform_get_model_evaluation_image_classification_sample] diff --git a/aiplatform/snippets/src/main/java/aiplatform/GetModelEvaluationImageObjectDetectionSample.java b/aiplatform/snippets/src/main/java/aiplatform/GetModelEvaluationImageObjectDetectionSample.java new file mode 100644 index 00000000000..cd8f7a1cb43 --- /dev/null +++ b/aiplatform/snippets/src/main/java/aiplatform/GetModelEvaluationImageObjectDetectionSample.java @@ -0,0 +1,78 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_get_model_evaluation_image_object_detection_sample] + +import com.google.cloud.aiplatform.v1beta1.Attribution; +import com.google.cloud.aiplatform.v1beta1.ModelEvaluation; +import com.google.cloud.aiplatform.v1beta1.ModelEvaluationName; +import com.google.cloud.aiplatform.v1beta1.ModelExplanation; +import com.google.cloud.aiplatform.v1beta1.ModelServiceClient; +import com.google.cloud.aiplatform.v1beta1.ModelServiceSettings; +import java.io.IOException; + +public class GetModelEvaluationImageObjectDetectionSample { + + public static void main(String[] args) throws IOException { + // TODO(developer): Replace these variables before running the sample. + String project = "YOUR_PROJECT_ID"; + String modelId = "YOUR_MODEL_ID"; + String evaluationId = "YOUR_EVALUATION_ID"; + getModelEvaluationImageObjectDetectionSample(project, modelId, evaluationId); + } + + static void getModelEvaluationImageObjectDetectionSample( + String project, String modelId, String evaluationId) throws IOException { + ModelServiceSettings modelServiceSettings = + ModelServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (ModelServiceClient modelServiceClient = ModelServiceClient.create(modelServiceSettings)) { + String location = "us-central1"; + ModelEvaluationName modelEvaluationName = + ModelEvaluationName.of(project, location, modelId, evaluationId); + + ModelEvaluation modelEvaluation = modelServiceClient.getModelEvaluation(modelEvaluationName); + + System.out.println("Get Model Evaluation Image Object Detection Response"); + System.out.format("\tName: %s\n", modelEvaluation.getName()); + System.out.format("\tMetrics Schema Uri: %s\n", modelEvaluation.getMetricsSchemaUri()); + System.out.format("\tMetrics: %s\n", modelEvaluation.getMetrics()); + System.out.format("\tCreate Time: %s\n", modelEvaluation.getCreateTime()); + System.out.format("\tSlice Dimensions: %s\n", modelEvaluation.getSliceDimensionsList()); + + ModelExplanation modelExplanation = modelEvaluation.getModelExplanation(); + for (Attribution attribution : modelExplanation.getMeanAttributionsList()) { + System.out.println("\t\tMean Attribution"); + System.out.format( + "\t\t\tBaseline Output Value: %s\n", attribution.getBaselineOutputValue()); + System.out.format( + "\t\t\tInstance Output Value: %s\n", attribution.getInstanceOutputValue()); + System.out.format("\t\t\tFeature Attributions: %s\n", attribution.getFeatureAttributions()); + System.out.format("\t\t\tOutput Index: %s\n", attribution.getOutputIndexList()); + System.out.format("\t\t\tOutput Display Name: %s\n", attribution.getOutputDisplayName()); + System.out.format("\t\t\tApproximation Error: %s\n", attribution.getApproximationError()); + } + } + } +} +// [END aiplatform_get_model_evaluation_image_object_detection_sample] diff --git a/aiplatform/snippets/src/main/java/aiplatform/GetModelEvaluationVideoClassificationSample.java b/aiplatform/snippets/src/main/java/aiplatform/GetModelEvaluationVideoClassificationSample.java new file mode 100644 index 00000000000..5dc3d85ab43 --- /dev/null +++ b/aiplatform/snippets/src/main/java/aiplatform/GetModelEvaluationVideoClassificationSample.java @@ -0,0 +1,63 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_get_model_evaluation_video_classification_sample] + +import com.google.cloud.aiplatform.v1beta1.ModelEvaluation; +import com.google.cloud.aiplatform.v1beta1.ModelEvaluationName; +import com.google.cloud.aiplatform.v1beta1.ModelServiceClient; +import com.google.cloud.aiplatform.v1beta1.ModelServiceSettings; +import java.io.IOException; + +public class GetModelEvaluationVideoClassificationSample { + + public static void main(String[] args) throws IOException { + // TODO(developer): Replace these variables before running the sample. + String project = "YOUR_PROJECT_ID"; + String modelId = "YOUR_MODEL_ID"; + String evaluationId = "YOUR_EVALUATION_ID"; + getModelEvaluationVideoClassification(project, modelId, evaluationId); + } + + static void getModelEvaluationVideoClassification( + String project, String modelId, String evaluationId) throws IOException { + ModelServiceSettings modelServiceSettings = + ModelServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (ModelServiceClient modelServiceClient = ModelServiceClient.create(modelServiceSettings)) { + String location = "us-central1"; + ModelEvaluationName modelEvaluationName = + ModelEvaluationName.of(project, location, modelId, evaluationId); + + ModelEvaluation modelEvaluation = modelServiceClient.getModelEvaluation(modelEvaluationName); + + System.out.println("Get Model Evaluation Video Classification Response"); + System.out.format("Name: %s\n", modelEvaluation.getName()); + System.out.format("Metrics Schema Uri: %s\n", modelEvaluation.getMetricsSchemaUri()); + System.out.format("Metrics: %s\n", modelEvaluation.getMetrics()); + System.out.format("Create Time: %s\n", modelEvaluation.getCreateTime()); + System.out.format("Slice Dimensions: %s\n", modelEvaluation.getSliceDimensionsList()); + } + } +} +// [END aiplatform_get_model_evaluation_video_classification_sample] diff --git a/aiplatform/snippets/src/main/java/aiplatform/GetModelEvaluationVideoObjectTrackingSample.java b/aiplatform/snippets/src/main/java/aiplatform/GetModelEvaluationVideoObjectTrackingSample.java new file mode 100644 index 00000000000..8cd4ccb5905 --- /dev/null +++ b/aiplatform/snippets/src/main/java/aiplatform/GetModelEvaluationVideoObjectTrackingSample.java @@ -0,0 +1,63 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_get_model_evaluation_object_tracking_sample] + +import com.google.cloud.aiplatform.v1beta1.ModelEvaluation; +import com.google.cloud.aiplatform.v1beta1.ModelEvaluationName; +import com.google.cloud.aiplatform.v1beta1.ModelServiceClient; +import com.google.cloud.aiplatform.v1beta1.ModelServiceSettings; +import java.io.IOException; + +public class GetModelEvaluationVideoObjectTrackingSample { + + public static void main(String[] args) throws IOException { + // TODO(developer): Replace these variables before running the sample. + String project = "YOUR_PROJECT_ID"; + String modelId = "YOUR_MODEL_ID"; + String evaluationId = "YOUR_EVALUATION_ID"; + getModelEvaluationVideoObjectTracking(project, modelId, evaluationId); + } + + static void getModelEvaluationVideoObjectTracking( + String project, String modelId, String evaluationId) throws IOException { + ModelServiceSettings modelServiceSettings = + ModelServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (ModelServiceClient modelServiceClient = ModelServiceClient.create(modelServiceSettings)) { + String location = "us-central1"; + ModelEvaluationName modelEvaluationName = + ModelEvaluationName.of(project, location, modelId, evaluationId); + + ModelEvaluation modelEvaluation = modelServiceClient.getModelEvaluation(modelEvaluationName); + + System.out.println("Get Model Evaluation Video Object Tracking Response"); + System.out.format("Name: %s\n", modelEvaluation.getName()); + System.out.format("Metrics Schema Uri: %s\n", modelEvaluation.getMetricsSchemaUri()); + System.out.format("Metrics: %s\n", modelEvaluation.getMetrics()); + System.out.format("Create Time: %s\n", modelEvaluation.getCreateTime()); + System.out.format("Slice Dimensions: %s\n", modelEvaluation.getSliceDimensionsList()); + } + } +} +// [END aiplatform_get_model_evaluation_object_tracking_sample] diff --git a/aiplatform/snippets/src/main/java/aiplatform/ImportDataImageClassificationSample.java b/aiplatform/snippets/src/main/java/aiplatform/ImportDataImageClassificationSample.java new file mode 100644 index 00000000000..04beb8c54b2 --- /dev/null +++ b/aiplatform/snippets/src/main/java/aiplatform/ImportDataImageClassificationSample.java @@ -0,0 +1,89 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_import_data_image_classification_sample] + +import com.google.api.gax.longrunning.OperationFuture; +import com.google.cloud.aiplatform.v1beta1.DatasetName; +import com.google.cloud.aiplatform.v1beta1.DatasetServiceClient; +import com.google.cloud.aiplatform.v1beta1.DatasetServiceSettings; +import com.google.cloud.aiplatform.v1beta1.GcsSource; +import com.google.cloud.aiplatform.v1beta1.ImportDataConfig; +import com.google.cloud.aiplatform.v1beta1.ImportDataOperationMetadata; +import com.google.cloud.aiplatform.v1beta1.ImportDataResponse; +import java.io.IOException; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +public class ImportDataImageClassificationSample { + + public static void main(String[] args) + throws IOException, InterruptedException, ExecutionException, TimeoutException { + // TODO(developer): Replace these variables before running the sample. + String project = "YOUR_PROJECT_ID"; + String datasetId = "YOUR_DATASET_ID"; + String gcsSourceUri = + "gs://YOUR_GCS_SOURCE_BUCKET/path_to_your_image_source/[file.csv/file.jsonl]"; + importDataImageClassificationSample(project, datasetId, gcsSourceUri); + } + + static void importDataImageClassificationSample( + String project, String datasetId, String gcsSourceUri) + throws IOException, InterruptedException, ExecutionException, TimeoutException { + DatasetServiceSettings datasetServiceSettings = + DatasetServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (DatasetServiceClient datasetServiceClient = + DatasetServiceClient.create(datasetServiceSettings)) { + String location = "us-central1"; + String importSchemaUri = + "gs://google-cloud-aiplatform/schema/dataset/ioformat/" + + "image_classification_single_label_io_format_1.0.0.yaml"; + + GcsSource.Builder gcsSource = GcsSource.newBuilder(); + gcsSource.addUris(gcsSourceUri); + DatasetName datasetName = DatasetName.of(project, location, datasetId); + + List importDataConfigList = + Collections.singletonList( + ImportDataConfig.newBuilder() + .setGcsSource(gcsSource) + .setImportSchemaUri(importSchemaUri) + .build()); + + OperationFuture importDataResponseFuture = + datasetServiceClient.importDataAsync(datasetName, importDataConfigList); + System.out.format( + "Operation name: %s\n", importDataResponseFuture.getInitialFuture().get().getName()); + System.out.println("Waiting for operation to finish..."); + ImportDataResponse importDataResponse = importDataResponseFuture.get(300, TimeUnit.SECONDS); + + System.out.format("Import Data Image Classification Response: %s\n", + importDataResponse.toString()); + } + } +} +// [END aiplatform_import_data_image_classification_sample] diff --git a/aiplatform/snippets/src/main/java/aiplatform/ImportDataImageObjectDetectionSample.java b/aiplatform/snippets/src/main/java/aiplatform/ImportDataImageObjectDetectionSample.java new file mode 100644 index 00000000000..ae17cfd3a49 --- /dev/null +++ b/aiplatform/snippets/src/main/java/aiplatform/ImportDataImageObjectDetectionSample.java @@ -0,0 +1,88 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_import_data_image_object_detection_sample] + +import com.google.api.gax.longrunning.OperationFuture; +import com.google.cloud.aiplatform.v1beta1.DatasetName; +import com.google.cloud.aiplatform.v1beta1.DatasetServiceClient; +import com.google.cloud.aiplatform.v1beta1.DatasetServiceSettings; +import com.google.cloud.aiplatform.v1beta1.GcsSource; +import com.google.cloud.aiplatform.v1beta1.ImportDataConfig; +import com.google.cloud.aiplatform.v1beta1.ImportDataOperationMetadata; +import com.google.cloud.aiplatform.v1beta1.ImportDataResponse; +import java.io.IOException; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +public class ImportDataImageObjectDetectionSample { + + public static void main(String[] args) + throws IOException, InterruptedException, ExecutionException, TimeoutException { + // TODO(developer): Replace these variables before running the sample. + String project = "YOUR_PROJECT_ID"; + String datasetId = "YOUR_DATASET_ID"; + String gcsSourceUri = + "gs://YOUR_GCS_SOURCE_BUCKET/path_to_your_image_source/[file.csv/file.jsonl]"; + importDataImageObjectDetectionSample(project, datasetId, gcsSourceUri); + } + + static void importDataImageObjectDetectionSample( + String project, String datasetId, String gcsSourceUri) + throws IOException, InterruptedException, ExecutionException, TimeoutException { + DatasetServiceSettings datasetServiceSettings = + DatasetServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (DatasetServiceClient datasetServiceClient = + DatasetServiceClient.create(datasetServiceSettings)) { + String location = "us-central1"; + String importSchemaUri = + "gs://google-cloud-aiplatform/schema/dataset/ioformat/" + + "image_bounding_box_io_format_1.0.0.yaml"; + GcsSource.Builder gcsSource = GcsSource.newBuilder(); + gcsSource.addUris(gcsSourceUri); + DatasetName datasetName = DatasetName.of(project, location, datasetId); + + List importDataConfigList = + Collections.singletonList( + ImportDataConfig.newBuilder() + .setGcsSource(gcsSource) + .setImportSchemaUri(importSchemaUri) + .build()); + + OperationFuture importDataResponseFuture = + datasetServiceClient.importDataAsync(datasetName, importDataConfigList); + System.out.format( + "Operation name: %s\n", importDataResponseFuture.getInitialFuture().get().getName()); + System.out.println("Waiting for operation to finish..."); + ImportDataResponse importDataResponse = importDataResponseFuture.get(300, TimeUnit.SECONDS); + + System.out.format("Import Data Image Object Detection Response: %s\n", + importDataResponse.toString()); + } + } +} +// [END aiplatform_import_data_image_object_detection_sample] diff --git a/aiplatform/snippets/src/main/java/aiplatform/ImportDataVideoClassificationSample.java b/aiplatform/snippets/src/main/java/aiplatform/ImportDataVideoClassificationSample.java new file mode 100644 index 00000000000..4bf2c37f303 --- /dev/null +++ b/aiplatform/snippets/src/main/java/aiplatform/ImportDataVideoClassificationSample.java @@ -0,0 +1,89 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_import_data_video_classification_sample] + +import com.google.api.gax.longrunning.OperationFuture; +import com.google.cloud.aiplatform.v1beta1.DatasetName; +import com.google.cloud.aiplatform.v1beta1.DatasetServiceClient; +import com.google.cloud.aiplatform.v1beta1.DatasetServiceSettings; +import com.google.cloud.aiplatform.v1beta1.GcsSource; +import com.google.cloud.aiplatform.v1beta1.ImportDataConfig; +import com.google.cloud.aiplatform.v1beta1.ImportDataOperationMetadata; +import com.google.cloud.aiplatform.v1beta1.ImportDataResponse; +import java.io.IOException; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +public class ImportDataVideoClassificationSample { + + public static void main(String[] args) + throws InterruptedException, ExecutionException, TimeoutException, IOException { + // TODO(developer): Replace these variables before running the sample. + String gcsSourceUri = + "gs://YOUR_GCS_SOURCE_BUCKET/path_to_your_video_source/[file.csv/file.jsonl]"; + String project = "YOUR_PROJECT_ID"; + String datasetId = "YOUR_DATASET_ID"; + importDataVideoClassification(gcsSourceUri, project, datasetId); + } + + static void importDataVideoClassification(String gcsSourceUri, String project, String datasetId) + throws IOException, ExecutionException, InterruptedException, TimeoutException { + DatasetServiceSettings datasetServiceSettings = + DatasetServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (DatasetServiceClient datasetServiceClient = + DatasetServiceClient.create(datasetServiceSettings)) { + String location = "us-central1"; + String importSchemaUri = + "gs://google-cloud-aiplatform/schema/dataset/ioformat/" + + "video_classification_io_format_1.0.0.yaml"; + + GcsSource.Builder gcsSource = GcsSource.newBuilder(); + gcsSource.addUris(gcsSourceUri); + + DatasetName datasetName = DatasetName.of(project, location, datasetId); + List importDataConfigs = + Collections.singletonList( + ImportDataConfig.newBuilder() + .setGcsSource(gcsSource) + .setImportSchemaUri(importSchemaUri) + .build()); + + OperationFuture importDataResponseFuture = + datasetServiceClient.importDataAsync(datasetName, importDataConfigs); + System.out.format( + "Operation name: %s\n", importDataResponseFuture.getInitialFuture().get().getName()); + System.out.println("Waiting for operation to finish..."); + ImportDataResponse importDataResponse = importDataResponseFuture.get(1800, TimeUnit.SECONDS); + + System.out.format( + "Import Data Video Classification Response: %s\n", + importDataResponse.toString()); + } + } +} +// [END aiplatform_import_data_video_classification_sample] diff --git a/aiplatform/snippets/src/main/java/aiplatform/ImportDataVideoObjectTrackingSample.java b/aiplatform/snippets/src/main/java/aiplatform/ImportDataVideoObjectTrackingSample.java new file mode 100644 index 00000000000..f8a07d91485 --- /dev/null +++ b/aiplatform/snippets/src/main/java/aiplatform/ImportDataVideoObjectTrackingSample.java @@ -0,0 +1,86 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_import_data_video_object_tracking_sample] + +import com.google.api.gax.longrunning.OperationFuture; +import com.google.cloud.aiplatform.v1beta1.DatasetName; +import com.google.cloud.aiplatform.v1beta1.DatasetServiceClient; +import com.google.cloud.aiplatform.v1beta1.DatasetServiceSettings; +import com.google.cloud.aiplatform.v1beta1.GcsSource; +import com.google.cloud.aiplatform.v1beta1.ImportDataConfig; +import com.google.cloud.aiplatform.v1beta1.ImportDataOperationMetadata; +import com.google.cloud.aiplatform.v1beta1.ImportDataResponse; +import java.io.IOException; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +public class ImportDataVideoObjectTrackingSample { + + public static void main(String[] args) + throws IOException, ExecutionException, InterruptedException, TimeoutException { + String gcsSourceUri = + "gs://YOUR_GCS_SOURCE_BUCKET/path_to_your_video_source/[file.csv/file.jsonl]"; + String project = "YOUR_PROJECT_ID"; + String datasetId = "YOUR_DATASET_ID"; + importDataVideObjectTracking(gcsSourceUri, project, datasetId); + } + + static void importDataVideObjectTracking(String gcsSourceUri, String project, String datasetId) + throws IOException, ExecutionException, InterruptedException, TimeoutException { + DatasetServiceSettings datasetServiceSettings = + DatasetServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (DatasetServiceClient datasetServiceClient = + DatasetServiceClient.create(datasetServiceSettings)) { + String location = "us-central1"; + String importSchemaUri = + "gs://google-cloud-aiplatform/schema/dataset/ioformat/" + + "video_object_tracking_io_format_1.0.0.yaml"; + + GcsSource.Builder gcsSource = GcsSource.newBuilder(); + gcsSource.addUris(gcsSourceUri); + DatasetName datasetName = DatasetName.of(project, location, datasetId); + List importDataConfigs = + Collections.singletonList( + ImportDataConfig.newBuilder() + .setGcsSource(gcsSource) + .setImportSchemaUri(importSchemaUri) + .build()); + + OperationFuture importDataResponseFuture = + datasetServiceClient.importDataAsync(datasetName, importDataConfigs); + System.out.format( + "Operation name: %s\n", importDataResponseFuture.getInitialFuture().get().getName()); + System.out.println("Waiting for operation to finish..."); + ImportDataResponse importDataResponse = importDataResponseFuture.get(300, TimeUnit.SECONDS); + + System.out.format("Import Data Video Object Tracking Response: %s\n", + importDataResponse.toString()); + } + } +} +// [END aiplatform_import_data_video_object_tracking_sample] diff --git a/aiplatform/snippets/src/main/java/aiplatform/PredictImageClassificationSample.java b/aiplatform/snippets/src/main/java/aiplatform/PredictImageClassificationSample.java new file mode 100644 index 00000000000..b63a914005e --- /dev/null +++ b/aiplatform/snippets/src/main/java/aiplatform/PredictImageClassificationSample.java @@ -0,0 +1,84 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_predict_image_classification_sample] + +import com.google.api.client.util.Base64; +import com.google.cloud.aiplatform.v1beta1.EndpointName; +import com.google.cloud.aiplatform.v1beta1.PredictResponse; +import com.google.cloud.aiplatform.v1beta1.PredictionServiceClient; +import com.google.cloud.aiplatform.v1beta1.PredictionServiceSettings; +import com.google.protobuf.Value; +import com.google.protobuf.util.JsonFormat; +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.nio.file.Paths; +import java.util.ArrayList; +import java.util.List; + +public class PredictImageClassificationSample { + + public static void main(String[] args) throws IOException { + // TODO(developer): Replace these variables before running the sample. + String project = "YOUR_PROJECT_ID"; + String fileName = "YOUR_IMAGE_FILE_PATH"; + String endpointId = "YOUR_ENDPOINT_ID"; + predictImageClassification(project, fileName, endpointId); + } + + static void predictImageClassification(String project, String fileName, String endpointId) + throws IOException { + PredictionServiceSettings settings = + PredictionServiceSettings.newBuilder() + .setEndpoint("us-central1-prediction-aiplatform.googleapis.com:443") + .build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (PredictionServiceClient predictionServiceClient = + PredictionServiceClient.create(settings)) { + String location = "us-central1"; + EndpointName endpointName = EndpointName.of(project, location, endpointId); + + byte[] contents = Base64.encodeBase64(Files.readAllBytes(Paths.get(fileName))); + String content = new String(contents, StandardCharsets.UTF_8); + + Value parameter = Value.newBuilder().setNumberValue(0).setNumberValue(5).build(); + + String contentDict = "{\"content\": \"" + content + "\"}"; + Value.Builder instance = Value.newBuilder(); + JsonFormat.parser().merge(contentDict, instance); + + List instances = new ArrayList<>(); + instances.add(instance.build()); + + PredictResponse predictResponse = + predictionServiceClient.predict(endpointName, instances, parameter); + System.out.println("Predict Image Classification Response"); + System.out.format("\tDeployed Model Id: %s\n", predictResponse.getDeployedModelId()); + + System.out.println("Predictions"); + for (Value prediction : predictResponse.getPredictionsList()) { + System.out.format("\tPrediction: %s\n", prediction); + } + } + } +} +// [END aiplatform_predict_image_classification_sample] diff --git a/aiplatform/snippets/src/main/java/aiplatform/PredictImageObjectDetectionSample.java b/aiplatform/snippets/src/main/java/aiplatform/PredictImageObjectDetectionSample.java new file mode 100644 index 00000000000..b7e832871f2 --- /dev/null +++ b/aiplatform/snippets/src/main/java/aiplatform/PredictImageObjectDetectionSample.java @@ -0,0 +1,84 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_predict_image_object_detection_sample] + +import com.google.api.client.util.Base64; +import com.google.cloud.aiplatform.v1beta1.EndpointName; +import com.google.cloud.aiplatform.v1beta1.PredictResponse; +import com.google.cloud.aiplatform.v1beta1.PredictionServiceClient; +import com.google.cloud.aiplatform.v1beta1.PredictionServiceSettings; +import com.google.protobuf.Value; +import com.google.protobuf.util.JsonFormat; +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.nio.file.Paths; +import java.util.ArrayList; +import java.util.List; + +public class PredictImageObjectDetectionSample { + + public static void main(String[] args) throws IOException { + // TODO(developer): Replace these variables before running the sample. + String project = "YOUR_PROJECT_ID"; + String fileName = "YOUR_IMAGE_FILE_PATH"; + String endpointId = "YOUR_ENDPOINT_ID"; + predictImageObjectDetection(project, fileName, endpointId); + } + + static void predictImageObjectDetection(String project, String fileName, String endpointId) + throws IOException { + PredictionServiceSettings settings = + PredictionServiceSettings.newBuilder() + .setEndpoint("us-central1-prediction-aiplatform.googleapis.com:443") + .build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (PredictionServiceClient predictionServiceClient = + PredictionServiceClient.create(settings)) { + String location = "us-central1"; + EndpointName endpointName = EndpointName.of(project, location, endpointId); + + byte[] contents = Base64.encodeBase64(Files.readAllBytes(Paths.get(fileName))); + String content = new String(contents, StandardCharsets.UTF_8); + + Value parameter = Value.newBuilder().setNumberValue(0).setNumberValue(5).build(); + + String contentDict = "{\"content\": \"" + content + "\"}"; + Value.Builder instance = Value.newBuilder(); + JsonFormat.parser().merge(contentDict, instance); + + List instances = new ArrayList<>(); + instances.add(instance.build()); + + PredictResponse predictResponse = + predictionServiceClient.predict(endpointName, instances, parameter); + System.out.println("Predict Image Object Detection Response"); + System.out.format("\tDeployed Model Id: %s\n", predictResponse.getDeployedModelId()); + + System.out.println("Predictions"); + for (Value prediction : predictResponse.getPredictionsList()) { + System.out.format("\tPrediction: %s\n", prediction); + } + } + } +} +// [END aiplatform_predict_image_object_detection_sample] diff --git a/aiplatform/snippets/src/test/java/aiplatform/CreateBatchPredictionJobVideoClassificationSampleTest.java b/aiplatform/snippets/src/test/java/aiplatform/CreateBatchPredictionJobVideoClassificationSampleTest.java new file mode 100644 index 00000000000..3fa42715cda --- /dev/null +++ b/aiplatform/snippets/src/test/java/aiplatform/CreateBatchPredictionJobVideoClassificationSampleTest.java @@ -0,0 +1,109 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.assertNotNull; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import java.util.UUID; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; + +public class CreateBatchPredictionJobVideoClassificationSampleTest { + + private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID"); + private static final String MODEL_ID = System.getenv("BATCH_PREDICTION_VIDEO_CLASS_MODEL_ID"); + private static final String GCS_SOURCE_URI = + "gs://ucaip-samples-test-output/inputs/vcn_40_batch_prediction_input.jsonl"; + private static final String GCS_DESTINATION_OUTPUT_URI_PREFIX = "gs://ucaip-samples-test-output/"; + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + private String batchPredictionJobId; + + private static void requireEnvVar(String varName) { + String errorMessage = + String.format("Environment variable '%s' is required to perform these tests.", varName); + assertNotNull(errorMessage, System.getenv(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("UCAIP_PROJECT_ID"); + requireEnvVar("BATCH_PREDICTION_VIDEO_CLASS_MODEL_ID"); + } + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + } + + @After + public void tearDown() + throws InterruptedException, ExecutionException, IOException, TimeoutException { + // Cancel the Batch Prediction Job + CancelBatchPredictionJobSample.cancelBatchPredictionJobSample(PROJECT, batchPredictionJobId); + + // Assert + String cancelResponse = bout.toString(); + assertThat(cancelResponse).contains("Cancelled the Batch Prediction Job"); + TimeUnit.MINUTES.sleep(2); + + // Delete the Batch Prediction Job + DeleteBatchPredictionJobSample.deleteBatchPredictionJobSample(PROJECT, batchPredictionJobId); + + // Assert + String deleteResponse = bout.toString(); + assertThat(deleteResponse).contains("Deleted Batch"); + System.out.flush(); + System.setOut(originalPrintStream); + } + + @Test + public void testCreateBatchPredictionJobVideoClassificationSample() throws IOException { + // Act + String batchPredictionDisplayName = + String.format( + "batch_prediction_video_classification_display_name_%s", + UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 26)); + + CreateBatchPredictionJobVideoClassificationSample.createBatchPredictionJobVideoClassification( + batchPredictionDisplayName, + MODEL_ID, + GCS_SOURCE_URI, + GCS_DESTINATION_OUTPUT_URI_PREFIX, + PROJECT); + + // Assert + String got = bout.toString(); + assertThat(got).contains(batchPredictionDisplayName); + assertThat(got).contains("Create Batch Prediction Job Video Classification Response"); + batchPredictionJobId = got.split("Name: ")[1].split("batchPredictionJobs/")[1].split("\n")[0]; + } +} diff --git a/aiplatform/snippets/src/test/java/aiplatform/CreateBatchPredictionJobVideoObjectTrackingSampleTest.java b/aiplatform/snippets/src/test/java/aiplatform/CreateBatchPredictionJobVideoObjectTrackingSampleTest.java new file mode 100644 index 00000000000..06f934f4960 --- /dev/null +++ b/aiplatform/snippets/src/test/java/aiplatform/CreateBatchPredictionJobVideoObjectTrackingSampleTest.java @@ -0,0 +1,109 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.assertNotNull; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import java.util.UUID; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; + +public class CreateBatchPredictionJobVideoObjectTrackingSampleTest { + + private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID"); + private static final String MODEL_ID = System.getenv("BATCH_PREDICTION_VIDEO_OBJECT_MODEL_ID"); + private static final String GCS_SOURCE_URI = + "gs://ucaip-samples-test-output/inputs/vot_batch_prediction_input.jsonl"; + private static final String GCS_DESTINATION_OUTPUT_URI_PREFIX = "gs://ucaip-samples-test-output/"; + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + private String batchPredictionJobId; + + private static void requireEnvVar(String varName) { + String errorMessage = + String.format("Environment variable '%s' is required to perform these tests.", varName); + assertNotNull(errorMessage, System.getenv(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("UCAIP_PROJECT_ID"); + requireEnvVar("BATCH_PREDICTION_VIDEO_OBJECT_MODEL_ID"); + } + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + } + + @After + public void tearDown() + throws InterruptedException, ExecutionException, IOException, TimeoutException { + // Cancel the Batch Prediction Job + CancelBatchPredictionJobSample.cancelBatchPredictionJobSample(PROJECT, batchPredictionJobId); + + // Assert + String cancelResponse = bout.toString(); + assertThat(cancelResponse).contains("Cancelled the Batch Prediction Job"); + TimeUnit.MINUTES.sleep(2); + + // Delete the Batch Prediction Job + DeleteBatchPredictionJobSample.deleteBatchPredictionJobSample(PROJECT, batchPredictionJobId); + + // Assert + String deleteResponse = bout.toString(); + assertThat(deleteResponse).contains("Deleted Batch"); + System.out.flush(); + System.setOut(originalPrintStream); + } + + @Test + public void testCreateBatchPredictionJobVideoObjectTrackingSample() throws IOException { + // Act + String batchPredictionDisplayName = + String.format( + "batch_prediction_video_object_tracking_display_name_%s", + UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 26)); + + CreateBatchPredictionJobVideoObjectTrackingSample.batchPredictionJobVideoObjectTracking( + batchPredictionDisplayName, + MODEL_ID, + GCS_SOURCE_URI, + GCS_DESTINATION_OUTPUT_URI_PREFIX, + PROJECT); + + // Assert + String got = bout.toString(); + assertThat(got).contains(batchPredictionDisplayName); + assertThat(got).contains("Create Batch Prediction Job Video Object Tracking Response"); + batchPredictionJobId = got.split("Name: ")[1].split("batchPredictionJobs/")[1].split("\n")[0]; + } +} diff --git a/aiplatform/snippets/src/test/java/aiplatform/CreateDatasetImageSampleTest.java b/aiplatform/snippets/src/test/java/aiplatform/CreateDatasetImageSampleTest.java new file mode 100644 index 00000000000..f2b95b3d4ee --- /dev/null +++ b/aiplatform/snippets/src/test/java/aiplatform/CreateDatasetImageSampleTest.java @@ -0,0 +1,94 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.assertNotNull; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import java.util.UUID; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeoutException; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class CreateDatasetImageSampleTest { + + private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID"); + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + private String datasetId; + + private static void requireEnvVar(String varName) { + String errorMessage = + String.format("Environment variable '%s' is required to perform these tests.", varName); + assertNotNull(errorMessage, System.getenv(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("UCAIP_PROJECT_ID"); + } + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + } + + @After + public void tearDown() + throws InterruptedException, ExecutionException, IOException, TimeoutException { + // Delete the created dataset + DeleteDatasetSample.deleteDatasetSample(PROJECT, datasetId); + + // Assert + String deleteResponse = bout.toString(); + assertThat(deleteResponse).contains("Deleted Dataset"); + System.out.flush(); + System.setOut(originalPrintStream); + } + + @Test + public void testCreateDatasetSample() + throws IOException, InterruptedException, ExecutionException, TimeoutException { + // Act + String datasetDisplayName = + String.format( + "temp_create_dataset_test_%s", + UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 26)); + + CreateDatasetImageSample.createDatasetImageSample(PROJECT, datasetDisplayName); + + // Assert + String got = bout.toString(); + assertThat(got).contains(datasetDisplayName); + assertThat(got).contains("Create Image Dataset Response"); + datasetId = got.split("Name: ")[1].split("datasets/")[1].split("\n")[0]; + } +} diff --git a/aiplatform/snippets/src/test/java/aiplatform/CreateDatasetVideoSampleTest.java b/aiplatform/snippets/src/test/java/aiplatform/CreateDatasetVideoSampleTest.java new file mode 100644 index 00000000000..b979692fa0d --- /dev/null +++ b/aiplatform/snippets/src/test/java/aiplatform/CreateDatasetVideoSampleTest.java @@ -0,0 +1,95 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.assertNotNull; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import java.util.UUID; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeoutException; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class CreateDatasetVideoSampleTest { + + private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID"); + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + + private String datasetId; + + private static void requireEnvVar(String varName) { + String errorMessage = + String.format("Environment variable '%s' is required to perform these tests.", varName); + assertNotNull(errorMessage, System.getenv(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("UCAIP_PROJECT_ID"); + } + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + } + + @After + public void tearDown() + throws InterruptedException, ExecutionException, IOException, TimeoutException { + // Delete the created dataset + DeleteDatasetSample.deleteDatasetSample(PROJECT, datasetId); + + // Assert + String deleteResponse = bout.toString(); + assertThat(deleteResponse).contains("Deleted Dataset"); + System.out.flush(); + System.setOut(originalPrintStream); + } + + @Test + public void testCreateDatasetVideoSample() + throws IOException, InterruptedException, ExecutionException, TimeoutException { + // Act + String displayName = + String.format( + "temp_create_dataset_video_test_%s", + UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 26)); + + CreateDatasetVideoSample.createDatasetSample(displayName, PROJECT); + + // Assert + String got = bout.toString(); + assertThat(got).contains(displayName); + assertThat(got).contains("Create Dataset Video Response"); + datasetId = got.split("Name: ")[1].split("datasets/")[1].split("\n")[0]; + } +} diff --git a/aiplatform/snippets/src/test/java/aiplatform/CreateTrainingPipelineImageClassificationSampleTest.java b/aiplatform/snippets/src/test/java/aiplatform/CreateTrainingPipelineImageClassificationSampleTest.java new file mode 100644 index 00000000000..747c9117fb6 --- /dev/null +++ b/aiplatform/snippets/src/test/java/aiplatform/CreateTrainingPipelineImageClassificationSampleTest.java @@ -0,0 +1,111 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.assertNotNull; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import java.util.UUID; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class CreateTrainingPipelineImageClassificationSampleTest { + + private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID"); + private static final String DATASET_ID = + System.getenv("TRAINING_PIPELINE_IMAGE_CLASS_DATASET_ID"); + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + private String trainingPipelineId; + + private static void requireEnvVar(String varName) { + String errorMessage = + String.format("Environment variable '%s' is required to perform these tests.", varName); + assertNotNull(errorMessage, System.getenv(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("UCAIP_PROJECT_ID"); + requireEnvVar("TRAINING_PIPELINE_IMAGE_CLASS_DATASET_ID"); + } + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + } + + @After + public void tearDown() + throws InterruptedException, ExecutionException, IOException, TimeoutException { + // Cancel the Training Pipeline + CancelTrainingPipelineSample.cancelTrainingPipelineSample(PROJECT, trainingPipelineId); + + // Assert + String cancelResponse = bout.toString(); + assertThat(cancelResponse).contains("Cancelled the Training Pipeline"); + TimeUnit.MINUTES.sleep(2); + + // Delete the Training Pipeline + DeleteTrainingPipelineSample.deleteTrainingPipelineSample(PROJECT, trainingPipelineId); + + // Assert + String deleteResponse = bout.toString(); + assertThat(deleteResponse).contains("Deleted Training Pipeline."); + System.out.flush(); + System.setOut(originalPrintStream); + } + + @Test + public void testCreateTrainingPipelineImageClassificationSample() throws IOException { + // Act + String trainingPipelineDisplayName = + String.format( + "temp_create_training_pipeline_test_%s", + UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 26)); + + String modelDisplayName = + String.format( + "temp_create_training_pipeline_model_test_%s", + UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 26)); + + CreateTrainingPipelineImageClassificationSample.createTrainingPipelineImageClassificationSample( + PROJECT, trainingPipelineDisplayName, DATASET_ID, modelDisplayName); + + // Assert + String got = bout.toString(); + assertThat(got).contains(DATASET_ID); + assertThat(got).contains("Create Training Pipeline Image Classification Response"); + trainingPipelineId = got.split("Name: ")[1].split("trainingPipelines/")[1].split("\n")[0]; + } +} diff --git a/aiplatform/snippets/src/test/java/aiplatform/CreateTrainingPipelineImageObjectDetectionSampleTest.java b/aiplatform/snippets/src/test/java/aiplatform/CreateTrainingPipelineImageObjectDetectionSampleTest.java new file mode 100644 index 00000000000..c4295cb9440 --- /dev/null +++ b/aiplatform/snippets/src/test/java/aiplatform/CreateTrainingPipelineImageObjectDetectionSampleTest.java @@ -0,0 +1,109 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.assertNotNull; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import java.util.UUID; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class CreateTrainingPipelineImageObjectDetectionSampleTest { + + private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID"); + private static final String DATASET_ID = + System.getenv("TRAINING_PIPELINE_IMAGE_OBJECT_DETECT_DATASET_ID"); + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + private String trainingPipelineId; + + private static void requireEnvVar(String varName) { + String errorMessage = + String.format("Environment variable '%s' is required to perform these tests.", varName); + assertNotNull(errorMessage, System.getenv(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("UCAIP_PROJECT_ID"); + requireEnvVar("TRAINING_PIPELINE_IMAGE_OBJECT_DETECT_DATASET_ID"); + } + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + } + + @After + public void tearDown() + throws InterruptedException, ExecutionException, IOException, TimeoutException { + // Cancel the Training Pipeline + CancelTrainingPipelineSample.cancelTrainingPipelineSample(PROJECT, trainingPipelineId); + + // Assert + String cancelResponse = bout.toString(); + assertThat(cancelResponse).contains("Cancelled the Training Pipeline"); + TimeUnit.MINUTES.sleep(2); + + // Delete the Training Pipeline + DeleteTrainingPipelineSample.deleteTrainingPipelineSample(PROJECT, trainingPipelineId); + + // Assert + String deleteResponse = bout.toString(); + assertThat(deleteResponse).contains("Deleted Training Pipeline."); + System.out.flush(); + System.setOut(originalPrintStream); + } + + @Test + public void testCreateTrainingPipelineImageObjectDetectionSample() throws IOException { + String tempUuid = UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 26); + // Act + String trainingPipelineDisplayName = + String.format("temp_create_training_pipeline_test_%s", tempUuid); + + String modelDisplayName = + String.format("temp_create_training_pipeline_model_test_%s", tempUuid); + + CreateTrainingPipelineImageObjectDetectionSample + .createTrainingPipelineImageObjectDetectionSample( + PROJECT, trainingPipelineDisplayName, DATASET_ID, modelDisplayName); + + // Assert + String got = bout.toString(); + assertThat(got).contains(DATASET_ID); + assertThat(got).contains("Create Training Pipeline Image Object Detection Response"); + trainingPipelineId = got.split("Name: ")[1].split("trainingPipelines/")[1].split("\n")[0]; + } +} diff --git a/aiplatform/snippets/src/test/java/aiplatform/CreateTrainingPipelineVideoClassificationSampleTest.java b/aiplatform/snippets/src/test/java/aiplatform/CreateTrainingPipelineVideoClassificationSampleTest.java new file mode 100644 index 00000000000..d58fd4fc63b --- /dev/null +++ b/aiplatform/snippets/src/test/java/aiplatform/CreateTrainingPipelineVideoClassificationSampleTest.java @@ -0,0 +1,108 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.assertNotNull; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import java.util.UUID; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; + +public class CreateTrainingPipelineVideoClassificationSampleTest { + + private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID"); + private static final String DATASET_ID = + System.getenv("TRAINING_PIPELINE_VIDEO_CLASS_DATASET_ID"); + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + private String trainingPipelineId; + + private static void requireEnvVar(String varName) { + String errorMessage = + String.format("Environment variable '%s' is required to perform these tests.", varName); + assertNotNull(errorMessage, System.getenv(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("UCAIP_PROJECT_ID"); + requireEnvVar("TRAINING_PIPELINE_VIDEO_CLASS_DATASET_ID"); + } + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + } + + @After + public void tearDown() + throws InterruptedException, ExecutionException, IOException, TimeoutException { + // Cancel the Training Pipeline + CancelTrainingPipelineSample.cancelTrainingPipelineSample(PROJECT, trainingPipelineId); + + // Assert + String cancelResponse = bout.toString(); + assertThat(cancelResponse).contains("Cancelled the Training Pipeline"); + TimeUnit.MINUTES.sleep(2); + + // Delete the Training Pipeline + DeleteTrainingPipelineSample.deleteTrainingPipelineSample(PROJECT, trainingPipelineId); + + // Assert + String deleteResponse = bout.toString(); + assertThat(deleteResponse).contains("Deleted Training Pipeline."); + System.out.flush(); + System.setOut(originalPrintStream); + } + + @Test + public void testCreateTrainingPipelineVideoClassificationSample() throws IOException { + // Act + String trainingPipelineDisplayName = + String.format( + "temp_create_training_pipeline_video_classification_test_%s", + UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 26)); + + String modelDisplayName = + String.format( + "temp_create_training_pipeline_video_classification_model_test_%s", + UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 26)); + + CreateTrainingPipelineVideoClassificationSample.createTrainingPipelineVideoClassification( + trainingPipelineDisplayName, DATASET_ID, modelDisplayName, PROJECT); + + // Assert + String got = bout.toString(); + assertThat(got).contains(DATASET_ID); + assertThat(got).contains("Create Training Pipeline Video Classification Response"); + trainingPipelineId = got.split("Name: ")[1].split("trainingPipelines/")[1].split("\n")[0]; + } +} diff --git a/aiplatform/snippets/src/test/java/aiplatform/CreateTrainingPipelineVideoObjectTrackingSampleTest.java b/aiplatform/snippets/src/test/java/aiplatform/CreateTrainingPipelineVideoObjectTrackingSampleTest.java new file mode 100644 index 00000000000..010dcc07526 --- /dev/null +++ b/aiplatform/snippets/src/test/java/aiplatform/CreateTrainingPipelineVideoObjectTrackingSampleTest.java @@ -0,0 +1,108 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.assertNotNull; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import java.util.UUID; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; + +public class CreateTrainingPipelineVideoObjectTrackingSampleTest { + + private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID"); + private static final String DATASET_ID = + System.getenv("TRAINING_PIPELINE_VIDEO_OBJECT_DETECT_DATASET_ID"); + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + private String trainingPipelineId; + + private static void requireEnvVar(String varName) { + String errorMessage = + String.format("Environment variable '%s' is required to perform these tests.", varName); + assertNotNull(errorMessage, System.getenv(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("UCAIP_PROJECT_ID"); + requireEnvVar("TRAINING_PIPELINE_VIDEO_OBJECT_DETECT_DATASET_ID"); + } + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + } + + @After + public void tearDown() + throws InterruptedException, ExecutionException, IOException, TimeoutException { + // Cancel the Training Pipeline + CancelTrainingPipelineSample.cancelTrainingPipelineSample(PROJECT, trainingPipelineId); + + // Assert + String cancelResponse = bout.toString(); + assertThat(cancelResponse).contains("Cancelled the Training Pipeline"); + TimeUnit.MINUTES.sleep(2); + + // Delete the Training Pipeline + DeleteTrainingPipelineSample.deleteTrainingPipelineSample(PROJECT, trainingPipelineId); + + // Assert + String deleteResponse = bout.toString(); + assertThat(deleteResponse).contains("Deleted Training Pipeline."); + System.out.flush(); + System.setOut(originalPrintStream); + } + + @Test + public void testCreateTrainingPipelineVideoObjectTrackingSample() throws IOException { + // Act + String trainingPipelineDisplayName = + String.format( + "temp_create_training_pipeline_video_object_tracking_test_%s", + UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 26)); + + String modelDisplayName = + String.format( + "temp_create_training_pipeline_video_object_tracking_model_test_%s", + UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 26)); + + CreateTrainingPipelineVideoObjectTrackingSample.createTrainingPipelineVideoObjectTracking( + trainingPipelineDisplayName, DATASET_ID, modelDisplayName, PROJECT); + + // Assert + String got = bout.toString(); + assertThat(got).contains(DATASET_ID); + assertThat(got).contains("Create Training Pipeline Video Object Tracking Response"); + trainingPipelineId = got.split("Name: ")[1].split("trainingPipelines/")[1].split("\n")[0]; + } +} diff --git a/aiplatform/snippets/src/test/java/aiplatform/GetModelEvaluationImageClassificationSampleTest.java b/aiplatform/snippets/src/test/java/aiplatform/GetModelEvaluationImageClassificationSampleTest.java new file mode 100644 index 00000000000..228fa605577 --- /dev/null +++ b/aiplatform/snippets/src/test/java/aiplatform/GetModelEvaluationImageClassificationSampleTest.java @@ -0,0 +1,81 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.assertNotNull; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class GetModelEvaluationImageClassificationSampleTest { + + private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID"); + private static final String MODEL_ID = System.getenv("IMAGE_CLASS_MODEL_ID"); + private static final String EVALUATION_ID = System.getenv("IMAGE_CLASS_EVALUATION_ID"); + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + + private static void requireEnvVar(String varName) { + String errorMessage = + String.format("Environment variable '%s' is required to perform these tests.", varName); + assertNotNull(errorMessage, System.getenv(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("UCAIP_PROJECT_ID"); + requireEnvVar("IMAGE_CLASS_MODEL_ID"); + requireEnvVar("IMAGE_CLASS_EVALUATION_ID"); + } + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + } + + @After + public void tearDown() { + System.out.flush(); + System.setOut(originalPrintStream); + } + + @Test + public void testGetModelEvaluationImageClassificationSample() throws IOException { + // Act + GetModelEvaluationImageClassificationSample.getModelEvaluationImageClassificationSample( + PROJECT, MODEL_ID, EVALUATION_ID); + + // Assert + String got = bout.toString(); + assertThat(got).contains(MODEL_ID); + assertThat(got).contains("Get Model Evaluation Image Classification Response"); + } +} diff --git a/aiplatform/snippets/src/test/java/aiplatform/GetModelEvaluationImageObjectDetectionSampleTest.java b/aiplatform/snippets/src/test/java/aiplatform/GetModelEvaluationImageObjectDetectionSampleTest.java new file mode 100644 index 00000000000..b78ec23a61d --- /dev/null +++ b/aiplatform/snippets/src/test/java/aiplatform/GetModelEvaluationImageObjectDetectionSampleTest.java @@ -0,0 +1,81 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.assertNotNull; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class GetModelEvaluationImageObjectDetectionSampleTest { + + private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID"); + private static final String MODEL_ID = System.getenv("IMAGE_OBJECT_DETECT_MODEL_ID"); + private static final String EVALUATION_ID = System.getenv("IMAGE_OBJECT_DETECT_EVALUATION_ID"); + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + + private static void requireEnvVar(String varName) { + String errorMessage = + String.format("Environment variable '%s' is required to perform these tests.", varName); + assertNotNull(errorMessage, System.getenv(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("UCAIP_PROJECT_ID"); + requireEnvVar("IMAGE_OBJECT_DETECT_MODEL_ID"); + requireEnvVar("IMAGE_OBJECT_DETECT_EVALUATION_ID"); + } + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + } + + @After + public void tearDown() { + System.out.flush(); + System.setOut(originalPrintStream); + } + + @Test + public void testGetModelEvaluationImageObjectDetectionSample() throws IOException { + // Act + GetModelEvaluationImageObjectDetectionSample.getModelEvaluationImageObjectDetectionSample( + PROJECT, MODEL_ID, EVALUATION_ID); + + // Assert + String got = bout.toString(); + assertThat(got).contains(MODEL_ID); + assertThat(got).contains("Get Model Evaluation Image Object Detection Response"); + } +} diff --git a/aiplatform/snippets/src/test/java/aiplatform/GetModelEvaluationVideoClassificationSampleTest.java b/aiplatform/snippets/src/test/java/aiplatform/GetModelEvaluationVideoClassificationSampleTest.java new file mode 100644 index 00000000000..4347485ef86 --- /dev/null +++ b/aiplatform/snippets/src/test/java/aiplatform/GetModelEvaluationVideoClassificationSampleTest.java @@ -0,0 +1,78 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.assertNotNull; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; + +public class GetModelEvaluationVideoClassificationSampleTest { + + private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID"); + private static final String MODEL_ID = System.getenv("VIDEO_CLASS_MODEL_ID"); + private static final String EVALUATION_ID = System.getenv("VIDEO_CLASS_EVALUATION_ID"); + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + + private static void requireEnvVar(String varName) { + String errorMessage = + String.format("Environment variable '%s' is required to perform these tests.", varName); + assertNotNull(errorMessage, System.getenv(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("UCAIP_PROJECT_ID"); + requireEnvVar("VIDEO_CLASS_MODEL_ID"); + requireEnvVar("VIDEO_CLASS_EVALUATION_ID"); + } + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + } + + @After + public void tearDown() { + System.out.flush(); + System.setOut(originalPrintStream); + } + + @Test + public void testGetModelEvaluationVideoClassificationSample() throws IOException { + // Act + GetModelEvaluationVideoClassificationSample.getModelEvaluationVideoClassification( + PROJECT, MODEL_ID, EVALUATION_ID); + + // Assert + String got = bout.toString(); + assertThat(got).contains(MODEL_ID); + assertThat(got).contains("Get Model Evaluation Video Classification Response"); + } +} diff --git a/aiplatform/snippets/src/test/java/aiplatform/GetModelEvaluationVideoObjectTrackingSampleTest.java b/aiplatform/snippets/src/test/java/aiplatform/GetModelEvaluationVideoObjectTrackingSampleTest.java new file mode 100644 index 00000000000..cefe40345d3 --- /dev/null +++ b/aiplatform/snippets/src/test/java/aiplatform/GetModelEvaluationVideoObjectTrackingSampleTest.java @@ -0,0 +1,78 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.assertNotNull; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; + +public class GetModelEvaluationVideoObjectTrackingSampleTest { + + private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID"); + private static final String MODEL_ID = System.getenv("VIDEO_OBJECT_DETECT_MODEL_ID"); + private static final String EVALUATION_ID = System.getenv("VIDEO_OBJECT_DETECT_EVALUATION_ID"); + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + + private static void requireEnvVar(String varName) { + String errorMessage = + String.format("Environment variable '%s' is required to perform these tests.", varName); + assertNotNull(errorMessage, System.getenv(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("UCAIP_PROJECT_ID"); + requireEnvVar("VIDEO_OBJECT_DETECT_MODEL_ID"); + requireEnvVar("VIDEO_OBJECT_DETECT_EVALUATION_ID"); + } + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + } + + @After + public void tearDown() { + System.out.flush(); + System.setOut(originalPrintStream); + } + + @Test + public void testGetModelEvaluationVideoObjectTrackingSample() throws IOException { + // Act + GetModelEvaluationVideoObjectTrackingSample.getModelEvaluationVideoObjectTracking( + PROJECT, MODEL_ID, EVALUATION_ID); + + // Assert + String got = bout.toString(); + assertThat(got).contains(MODEL_ID); + assertThat(got).contains("Get Model Evaluation Video Object Tracking Response"); + } +} diff --git a/aiplatform/snippets/src/test/java/aiplatform/ImportDataImageClassificationSampleTest.java b/aiplatform/snippets/src/test/java/aiplatform/ImportDataImageClassificationSampleTest.java new file mode 100644 index 00000000000..ed4d7111964 --- /dev/null +++ b/aiplatform/snippets/src/test/java/aiplatform/ImportDataImageClassificationSampleTest.java @@ -0,0 +1,131 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.assertNotNull; + +import com.google.api.gax.longrunning.OperationFuture; +import com.google.cloud.aiplatform.v1beta1.CreateDatasetOperationMetadata; +import com.google.cloud.aiplatform.v1beta1.Dataset; +import com.google.cloud.aiplatform.v1beta1.DatasetName; +import com.google.cloud.aiplatform.v1beta1.DatasetServiceClient; +import com.google.cloud.aiplatform.v1beta1.DatasetServiceSettings; +import com.google.cloud.aiplatform.v1beta1.DeleteOperationMetadata; +import com.google.cloud.aiplatform.v1beta1.LocationName; +import com.google.protobuf.Empty; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class ImportDataImageClassificationSampleTest { + private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID"); + private static final String LOCATION = "us-central1"; + + private static final String GCS_SOURCE_URI = "gs://ucaip-sample-resources/input.jsonl"; + private String datasetId; + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + + private static void requireEnvVar(String varName) { + String errorMessage = + String.format("Environment variable '%s' is required to perform these tests.", varName); + assertNotNull(errorMessage, System.getenv(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("UCAIP_PROJECT_ID"); + } + + @Before + public void setUp() + throws IOException, InterruptedException, ExecutionException, TimeoutException { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + + // create a temp dataset for importing data + DatasetServiceSettings datasetServiceSettings = + DatasetServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + try (DatasetServiceClient datasetServiceClient = + DatasetServiceClient.create(datasetServiceSettings)) { + + String metadataSchemaUri = + "gs://google-cloud-aiplatform/schema/dataset/metadata/image_1.0.0.yaml"; + LocationName locationName = LocationName.of(PROJECT, LOCATION); + Dataset dataset = + Dataset.newBuilder() + .setDisplayName("test_dataset_display_name") + .setMetadataSchemaUri(metadataSchemaUri) + .build(); + + OperationFuture datasetFuture = + datasetServiceClient.createDatasetAsync(locationName, dataset); + Dataset datasetResponse = datasetFuture.get(120, TimeUnit.SECONDS); + String[] datasetValues = datasetResponse.getName().split("/"); + datasetId = datasetValues[datasetValues.length - 1]; + } + } + + @After + public void tearDown() throws InterruptedException, ExecutionException, IOException { + // delete the temp dataset + DatasetServiceSettings datasetServiceSettings = + DatasetServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + try (DatasetServiceClient datasetServiceClient = + DatasetServiceClient.create(datasetServiceSettings)) { + DatasetName datasetName = DatasetName.of(PROJECT, LOCATION, datasetId); + + OperationFuture operationFuture = + datasetServiceClient.deleteDatasetAsync(datasetName); + operationFuture.get(); + } + + System.out.flush(); + System.setOut(originalPrintStream); + } + + @Test + public void testImportDataSample() + throws IOException, InterruptedException, ExecutionException, TimeoutException { + // Act + ImportDataImageClassificationSample.importDataImageClassificationSample( + PROJECT, datasetId, GCS_SOURCE_URI); + + // Assert + String got = bout.toString(); + assertThat(got).contains("Import Data Image Classification Response: "); + } +} diff --git a/aiplatform/snippets/src/test/java/aiplatform/ImportDataImageObjectDetectionSampleTest.java b/aiplatform/snippets/src/test/java/aiplatform/ImportDataImageObjectDetectionSampleTest.java new file mode 100644 index 00000000000..451a7c230fb --- /dev/null +++ b/aiplatform/snippets/src/test/java/aiplatform/ImportDataImageObjectDetectionSampleTest.java @@ -0,0 +1,131 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.assertNotNull; + +import com.google.api.gax.longrunning.OperationFuture; +import com.google.cloud.aiplatform.v1beta1.CreateDatasetOperationMetadata; +import com.google.cloud.aiplatform.v1beta1.Dataset; +import com.google.cloud.aiplatform.v1beta1.DatasetName; +import com.google.cloud.aiplatform.v1beta1.DatasetServiceClient; +import com.google.cloud.aiplatform.v1beta1.DatasetServiceSettings; +import com.google.cloud.aiplatform.v1beta1.DeleteOperationMetadata; +import com.google.cloud.aiplatform.v1beta1.LocationName; +import com.google.protobuf.Empty; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class ImportDataImageObjectDetectionSampleTest { + + private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID"); + private static final String LOCATION = "us-central1"; + private static final String GCS_SOURCE_URI = "gs://ucaip-sample-resources/input.jsonl"; + + private String datasetId; + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + + private static void requireEnvVar(String varName) { + String errorMessage = + String.format("Environment variable '%s' is required to perform these tests.", varName); + assertNotNull(errorMessage, System.getenv(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("UCAIP_PROJECT_ID"); + } + + @Before + public void setUp() + throws InterruptedException, ExecutionException, TimeoutException, IOException { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + + // create a temp dataset for importing data + DatasetServiceSettings datasetServiceSettings = + DatasetServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + try (DatasetServiceClient datasetServiceClient = + DatasetServiceClient.create(datasetServiceSettings)) { + String metadataSchemaUri = + "gs://google-cloud-aiplatform/schema/dataset/metadata/image_1.0.0.yaml"; + LocationName locationName = LocationName.of(PROJECT, LOCATION); + Dataset dataset = + Dataset.newBuilder() + .setDisplayName("test_dataset_display_name") + .setMetadataSchemaUri(metadataSchemaUri) + .build(); + + OperationFuture datasetFuture = + datasetServiceClient.createDatasetAsync(locationName, dataset); + Dataset datasetResponse = datasetFuture.get(120, TimeUnit.SECONDS); + String[] datasetValues = datasetResponse.getName().split("/"); + datasetId = datasetValues[datasetValues.length - 1]; + } + } + + @After + public void tearDown() + throws InterruptedException, ExecutionException, IOException, TimeoutException { + // delete the temp dataset + DatasetServiceSettings datasetServiceSettings = + DatasetServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + try (DatasetServiceClient datasetServiceClient = + DatasetServiceClient.create(datasetServiceSettings)) { + DatasetName datasetName = DatasetName.of(PROJECT, LOCATION, datasetId); + + OperationFuture operationFuture = + datasetServiceClient.deleteDatasetAsync(datasetName); + operationFuture.get(); + } + System.out.flush(); + System.setOut(originalPrintStream); + } + + @Test + public void testImportDataSample() + throws IOException, InterruptedException, ExecutionException, TimeoutException { + // Act + ImportDataImageObjectDetectionSample.importDataImageObjectDetectionSample( + PROJECT, datasetId, GCS_SOURCE_URI); + + // Assert + String got = bout.toString(); + assertThat(got).contains("Import Data Image Object Detection Response: "); + } +} diff --git a/aiplatform/snippets/src/test/java/aiplatform/ImportDataVideoClassificationSampleTest.java b/aiplatform/snippets/src/test/java/aiplatform/ImportDataVideoClassificationSampleTest.java new file mode 100644 index 00000000000..66b0237cf45 --- /dev/null +++ b/aiplatform/snippets/src/test/java/aiplatform/ImportDataVideoClassificationSampleTest.java @@ -0,0 +1,129 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.assertNotNull; + +import com.google.api.gax.longrunning.OperationFuture; +import com.google.cloud.aiplatform.v1beta1.CreateDatasetOperationMetadata; +import com.google.cloud.aiplatform.v1beta1.Dataset; +import com.google.cloud.aiplatform.v1beta1.DatasetName; +import com.google.cloud.aiplatform.v1beta1.DatasetServiceClient; +import com.google.cloud.aiplatform.v1beta1.DatasetServiceSettings; +import com.google.cloud.aiplatform.v1beta1.DeleteOperationMetadata; +import com.google.cloud.aiplatform.v1beta1.LocationName; +import com.google.protobuf.Empty; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; + +public class ImportDataVideoClassificationSampleTest { + + private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID"); + private static final String LOCATION = "us-central1"; + private static final String GCS_SOURCE_URI = + "gs://automl-video-demo-data/traffic_videos/traffic_videos_train.csv"; + + private String datasetId; + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + + private static void requireEnvVar(String varName) { + String errorMessage = + String.format("Environment variable '%s' is required to perform these tests.", varName); + assertNotNull(errorMessage, System.getenv(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("UCAIP_PROJECT_ID"); + } + + @Before + public void setUp() + throws IOException, InterruptedException, ExecutionException, TimeoutException { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + + // create a temp dataset for importing data + DatasetServiceSettings datasetServiceSettings = + DatasetServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + + try (DatasetServiceClient datasetServiceClient = + DatasetServiceClient.create(datasetServiceSettings)) { + String metadataSchemaUri = + "gs://google-cloud-aiplatform/schema/dataset/metadata/video_1.0.0.yaml"; + LocationName locationName = LocationName.of(PROJECT, LOCATION); + Dataset dataset = + Dataset.newBuilder() + .setDisplayName("test_dataset_display_name") + .setMetadataSchemaUri(metadataSchemaUri) + .build(); + + OperationFuture datasetFuture = + datasetServiceClient.createDatasetAsync(locationName, dataset); + Dataset datasetResponse = datasetFuture.get(300, TimeUnit.SECONDS); + String[] datasetValues = datasetResponse.getName().split("/"); + datasetId = datasetValues[datasetValues.length - 1]; + } + } + + @After + public void tearDown() throws InterruptedException, ExecutionException, IOException { + // delete the temp dataset + DatasetServiceSettings datasetServiceSettings = + DatasetServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + try (DatasetServiceClient datasetServiceClient = + DatasetServiceClient.create(datasetServiceSettings)) { + DatasetName datasetName = DatasetName.of(PROJECT, LOCATION, datasetId); + + OperationFuture operationFuture = + datasetServiceClient.deleteDatasetAsync(datasetName); + operationFuture.get(); + } + System.out.flush(); + System.setOut(originalPrintStream); + } + + @Test + public void testImportDataVideoClassificationSample() + throws IOException, InterruptedException, ExecutionException, TimeoutException { + // Act + ImportDataVideoClassificationSample.importDataVideoClassification( + GCS_SOURCE_URI, PROJECT, datasetId); + + // Assert + String got = bout.toString(); + assertThat(got).contains("Import Data Video Classification Response: "); + } +} diff --git a/aiplatform/snippets/src/test/java/aiplatform/ImportDataVideoObjectTrackingSampleTest.java b/aiplatform/snippets/src/test/java/aiplatform/ImportDataVideoObjectTrackingSampleTest.java new file mode 100644 index 00000000000..6d8b5e7a7ed --- /dev/null +++ b/aiplatform/snippets/src/test/java/aiplatform/ImportDataVideoObjectTrackingSampleTest.java @@ -0,0 +1,128 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.assertNotNull; + +import com.google.api.gax.longrunning.OperationFuture; +import com.google.cloud.aiplatform.v1beta1.CreateDatasetOperationMetadata; +import com.google.cloud.aiplatform.v1beta1.Dataset; +import com.google.cloud.aiplatform.v1beta1.DatasetName; +import com.google.cloud.aiplatform.v1beta1.DatasetServiceClient; +import com.google.cloud.aiplatform.v1beta1.DatasetServiceSettings; +import com.google.cloud.aiplatform.v1beta1.DeleteOperationMetadata; +import com.google.cloud.aiplatform.v1beta1.LocationName; +import com.google.protobuf.Empty; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; + +public class ImportDataVideoObjectTrackingSampleTest { + + private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID"); + private static final String LOCATION = "us-central1"; + private static final String GCS_SOURCE_URI = + "gs://automl-video-demo-data/traffic_videos/traffic_videos_train.csv"; + private String datasetId; + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + + private static void requireEnvVar(String varName) { + String errorMessage = + String.format("Environment variable '%s' is required to perform these tests.", varName); + assertNotNull(errorMessage, System.getenv(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("UCAIP_PROJECT_ID"); + } + + @Before + public void setUp() + throws InterruptedException, ExecutionException, TimeoutException, IOException { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + + // create a temp dataset for importing data + DatasetServiceSettings datasetServiceSettings = + DatasetServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + + try (DatasetServiceClient datasetServiceClient = + DatasetServiceClient.create(datasetServiceSettings)) { + String metadataSchemaUri = + "gs://google-cloud-aiplatform/schema/dataset/metadata/video_1.0.0.yaml"; + LocationName locationName = LocationName.of(PROJECT, LOCATION); + Dataset dataset = + Dataset.newBuilder() + .setDisplayName("test_dataset_display_name") + .setMetadataSchemaUri(metadataSchemaUri) + .build(); + + OperationFuture datasetFuture = + datasetServiceClient.createDatasetAsync(locationName, dataset); + Dataset datasetResponse = datasetFuture.get(300, TimeUnit.SECONDS); + String[] datasetValues = datasetResponse.getName().split("/"); + datasetId = datasetValues[datasetValues.length - 1]; + } + } + + @After + public void tearDown() throws InterruptedException, ExecutionException, IOException { + // delete the temp dataset + DatasetServiceSettings datasetServiceSettings = + DatasetServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + try (DatasetServiceClient datasetServiceClient = + DatasetServiceClient.create(datasetServiceSettings)) { + DatasetName datasetName = DatasetName.of(PROJECT, LOCATION, datasetId); + + OperationFuture operationFuture = + datasetServiceClient.deleteDatasetAsync(datasetName); + operationFuture.get(); + } + System.out.flush(); + System.setOut(originalPrintStream); + } + + @Test + public void testImportDataVideoObjectTrackingSample() + throws IOException, InterruptedException, ExecutionException, TimeoutException { + // Act + ImportDataVideoObjectTrackingSample.importDataVideObjectTracking( + GCS_SOURCE_URI, PROJECT, datasetId); + + // Assert + String got = bout.toString(); + assertThat(got).contains("Import Data Video Object Tracking Response: "); + } +} diff --git a/aiplatform/snippets/src/test/java/aiplatform/PredictImageClassificationSampleTest.java b/aiplatform/snippets/src/test/java/aiplatform/PredictImageClassificationSampleTest.java new file mode 100644 index 00000000000..8ca3fd95c5e --- /dev/null +++ b/aiplatform/snippets/src/test/java/aiplatform/PredictImageClassificationSampleTest.java @@ -0,0 +1,75 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.assertNotNull; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; + +public class PredictImageClassificationSampleTest { + + private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID"); + private static final String FILE_NAME = "resources/image_flower_daisy.jpg"; + private static final String ENDPOINT_ID = System.getenv("IMAGE_CLASS_ENDPOINT_ID"); + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + + private static void requireEnvVar(String varName) { + String errorMessage = + String.format("Environment variable '%s' is required to perform these tests.", varName); + assertNotNull(errorMessage, System.getenv(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("UCAIP_PROJECT_ID"); + requireEnvVar("IMAGE_CLASS_ENDPOINT_ID"); + } + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + } + + @After + public void tearDown() { + System.out.flush(); + System.setOut(originalPrintStream); + } + + @Test + public void testPredictImageClassification() throws IOException { + // Act + PredictImageClassificationSample.predictImageClassification(PROJECT, FILE_NAME, ENDPOINT_ID); + + // Assert + String got = bout.toString(); + assertThat(got).contains("Predict Image Classification Response"); + } +} diff --git a/aiplatform/snippets/src/test/java/aiplatform/PredictImageObjectDetectionSampleTest.java b/aiplatform/snippets/src/test/java/aiplatform/PredictImageObjectDetectionSampleTest.java new file mode 100644 index 00000000000..1539c7dfbc1 --- /dev/null +++ b/aiplatform/snippets/src/test/java/aiplatform/PredictImageObjectDetectionSampleTest.java @@ -0,0 +1,75 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.assertNotNull; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; + +public class PredictImageObjectDetectionSampleTest { + + private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID"); + private static final String FILE_NAME = "resources/caprese_salad.jpg"; + private static final String ENDPOINT_ID = System.getenv("IMAGE_OBJECT_DETECTION_ENDPOINT_ID"); + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + + private static void requireEnvVar(String varName) { + String errorMessage = + String.format("Environment variable '%s' is required to perform these tests.", varName); + assertNotNull(errorMessage, System.getenv(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("UCAIP_PROJECT_ID"); + requireEnvVar("IMAGE_OBJECT_DETECTION_ENDPOINT_ID"); + } + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + } + + @After + public void tearDown() { + System.out.flush(); + System.setOut(originalPrintStream); + } + + @Test + public void testPredictImageObjectDetection() throws IOException { + // Act + PredictImageObjectDetectionSample.predictImageObjectDetection(PROJECT, FILE_NAME, ENDPOINT_ID); + + // Assert + String got = bout.toString(); + assertThat(got).contains("Predict Image Object Detection Response"); + } +}