From 413b908fa1b94d6e433b1a638b790f30b414be37 Mon Sep 17 00:00:00 2001 From: Mike <45373284+munkhuushmgl@users.noreply.github.com> Date: Fri, 6 Nov 2020 11:35:26 -0800 Subject: [PATCH] samples: ucaip samples batch 3 of 6 (#18) * samples:samples: ucaip samples batch 3 of 6 * made requested the changes * changed all instance of tables into tabular * fixed the lint * reversed some comments --- aiplatform/snippets/pom.xml | 4 +- .../CancelTrainingPipelineSample.java | 57 ++++ .../CreateDatasetTabularBigquerySample.java | 89 ++++++ .../CreateDatasetTabularGcsSample.java | 88 ++++++ ...ngPipelineTabularClassificationSample.java | 255 ++++++++++++++++++ ...ainingPipelineTabularRegressionSample.java | 254 +++++++++++++++++ .../java/aiplatform/DeleteDatasetSample.java | 67 +++++ .../aiplatform/DeleteExportModelSample.java | 45 ++++ .../DeleteTrainingPipelineSample.java | 68 +++++ ...xportModelTabularClassificationSample.java | 79 ++++++ ...EvaluationTabularClassificationSample.java | 78 ++++++ ...odelEvaluationTabularRegressionSample.java | 78 ++++++ .../PredictTabularClassificationSample.java | 73 +++++ .../PredictTabularRegressionSample.java | 73 +++++ ...reateDatasetTabularBigquerySampleTest.java | 93 +++++++ .../CreateDatasetTabularGcsSampleTest.java | 93 +++++++ ...pelineTabularClassificationSampleTest.java | 121 +++++++++ ...ngPipelineTabularRegressionSampleTest.java | 119 ++++++++ ...tModelTabularClassificationSampleTest.java | 89 ++++++ ...uationTabularClassificationSampleTest.java | 80 ++++++ ...EvaluationTabularRegressionSampleTest.java | 80 ++++++ ...redictTabularClassificationSampleTest.java | 80 ++++++ .../PredictTabularRegressionSampleTest.java | 100 +++++++ 23 files changed, 2161 insertions(+), 2 deletions(-) create mode 100644 aiplatform/snippets/src/main/java/aiplatform/CancelTrainingPipelineSample.java create mode 100644 aiplatform/snippets/src/main/java/aiplatform/CreateDatasetTabularBigquerySample.java create mode 100644 aiplatform/snippets/src/main/java/aiplatform/CreateDatasetTabularGcsSample.java create mode 100644 aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineTabularClassificationSample.java create mode 100644 aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineTabularRegressionSample.java create mode 100644 aiplatform/snippets/src/main/java/aiplatform/DeleteDatasetSample.java create mode 100644 aiplatform/snippets/src/main/java/aiplatform/DeleteExportModelSample.java create mode 100644 aiplatform/snippets/src/main/java/aiplatform/DeleteTrainingPipelineSample.java create mode 100644 aiplatform/snippets/src/main/java/aiplatform/ExportModelTabularClassificationSample.java create mode 100644 aiplatform/snippets/src/main/java/aiplatform/GetModelEvaluationTabularClassificationSample.java create mode 100644 aiplatform/snippets/src/main/java/aiplatform/GetModelEvaluationTabularRegressionSample.java create mode 100644 aiplatform/snippets/src/main/java/aiplatform/PredictTabularClassificationSample.java create mode 100644 aiplatform/snippets/src/main/java/aiplatform/PredictTabularRegressionSample.java create mode 100644 aiplatform/snippets/src/test/java/aiplatform/CreateDatasetTabularBigquerySampleTest.java create mode 100644 aiplatform/snippets/src/test/java/aiplatform/CreateDatasetTabularGcsSampleTest.java create mode 100644 aiplatform/snippets/src/test/java/aiplatform/CreateTrainingPipelineTabularClassificationSampleTest.java create mode 100644 aiplatform/snippets/src/test/java/aiplatform/CreateTrainingPipelineTabularRegressionSampleTest.java create mode 100644 aiplatform/snippets/src/test/java/aiplatform/ExportModelTabularClassificationSampleTest.java create mode 100644 aiplatform/snippets/src/test/java/aiplatform/GetModelEvaluationTabularClassificationSampleTest.java create mode 100644 aiplatform/snippets/src/test/java/aiplatform/GetModelEvaluationTabularRegressionSampleTest.java create mode 100644 aiplatform/snippets/src/test/java/aiplatform/PredictTabularClassificationSampleTest.java create mode 100644 aiplatform/snippets/src/test/java/aiplatform/PredictTabularRegressionSampleTest.java diff --git a/aiplatform/snippets/pom.xml b/aiplatform/snippets/pom.xml index a89bbb4a567..d9652815f2f 100644 --- a/aiplatform/snippets/pom.xml +++ b/aiplatform/snippets/pom.xml @@ -23,14 +23,14 @@ UTF-8 + + - com.google.cloud google-cloud-aiplatform 0.0.1-SNAPSHOT - com.google.cloud google-cloud-storage diff --git a/aiplatform/snippets/src/main/java/aiplatform/CancelTrainingPipelineSample.java b/aiplatform/snippets/src/main/java/aiplatform/CancelTrainingPipelineSample.java new file mode 100644 index 00000000000..4dd2902f328 --- /dev/null +++ b/aiplatform/snippets/src/main/java/aiplatform/CancelTrainingPipelineSample.java @@ -0,0 +1,57 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_cancel_training_pipeline_sample] + +import com.google.cloud.aiplatform.v1beta1.PipelineServiceClient; +import com.google.cloud.aiplatform.v1beta1.PipelineServiceSettings; +import com.google.cloud.aiplatform.v1beta1.TrainingPipelineName; +import java.io.IOException; + +public class CancelTrainingPipelineSample { + + public static void main(String[] args) throws IOException { + // TODO(developer): Replace these variables before running the sample. + String trainingPipelineId = "YOUR_TRAINING_PIPELINE_ID"; + String project = "YOUR_PROJECT_ID"; + cancelTrainingPipelineSample(project, trainingPipelineId); + } + + static void cancelTrainingPipelineSample(String project, String trainingPipelineId) + throws IOException { + PipelineServiceSettings pipelineServiceSettings = + PipelineServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (PipelineServiceClient pipelineServiceClient = + PipelineServiceClient.create(pipelineServiceSettings)) { + String location = "us-central1"; + TrainingPipelineName trainingPipelineName = + TrainingPipelineName.of(project, location, trainingPipelineId); + + pipelineServiceClient.cancelTrainingPipeline(trainingPipelineName); + + System.out.println("Cancelled the Training Pipeline"); + } + } +} +// [END aiplatform_cancel_training_pipeline_sample] diff --git a/aiplatform/snippets/src/main/java/aiplatform/CreateDatasetTabularBigquerySample.java b/aiplatform/snippets/src/main/java/aiplatform/CreateDatasetTabularBigquerySample.java new file mode 100644 index 00000000000..bcaf5c94eee --- /dev/null +++ b/aiplatform/snippets/src/main/java/aiplatform/CreateDatasetTabularBigquerySample.java @@ -0,0 +1,89 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_create_dataset_tabular_bigquery_sample] + +import com.google.api.gax.longrunning.OperationFuture; +import com.google.cloud.aiplatform.v1beta1.CreateDatasetOperationMetadata; +import com.google.cloud.aiplatform.v1beta1.Dataset; +import com.google.cloud.aiplatform.v1beta1.DatasetServiceClient; +import com.google.cloud.aiplatform.v1beta1.DatasetServiceSettings; +import com.google.cloud.aiplatform.v1beta1.LocationName; +import com.google.protobuf.Value; +import com.google.protobuf.util.JsonFormat; +import java.io.IOException; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +public class CreateDatasetTabularBigquerySample { + + public static void main(String[] args) + throws InterruptedException, ExecutionException, TimeoutException, IOException { + // TODO(developer): Replace these variables before running the sample. + String project = "YOUR_PROJECT_ID"; + String bigqueryDisplayName = "YOUR_DATASET_DISPLAY_NAME"; + String bigqueryUri = + "bq://YOUR_GOOGLE_CLOUD_PROJECT_ID.BIGQUERY_DATASET_ID.BIGQUERY_TABLE_OR_VIEW_ID"; + createDatasetTableBigquery(project, bigqueryDisplayName, bigqueryUri); + } + + static void createDatasetTableBigquery( + String project, String bigqueryDisplayName, String bigqueryUri) + throws IOException, ExecutionException, InterruptedException, TimeoutException { + DatasetServiceSettings settings = + DatasetServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (DatasetServiceClient datasetServiceClient = DatasetServiceClient.create(settings)) { + String location = "us-central1"; + String metadataSchemaUri = + "gs://google-cloud-aiplatform/schema/dataset/metadata/tables_1.0.0.yaml"; + LocationName locationName = LocationName.of(project, location); + + String jsonString = + "{\"input_config\": {\"bigquery_source\": {\"uri\": \"" + bigqueryUri + "\"}}}"; + Value.Builder metaData = Value.newBuilder(); + JsonFormat.parser().merge(jsonString, metaData); + + Dataset dataset = + Dataset.newBuilder() + .setDisplayName(bigqueryDisplayName) + .setMetadataSchemaUri(metadataSchemaUri) + .setMetadata(metaData) + .build(); + + OperationFuture datasetFuture = + datasetServiceClient.createDatasetAsync(locationName, dataset); + System.out.format("Operation name: %s\n", datasetFuture.getInitialFuture().get().getName()); + System.out.println("Waiting for operation to finish..."); + Dataset datasetResponse = datasetFuture.get(300, TimeUnit.SECONDS); + + System.out.println("Create Dataset Table Bigquery sample"); + System.out.format("Name: %s\n", datasetResponse.getName()); + System.out.format("Display Name: %s\n", datasetResponse.getDisplayName()); + System.out.format("Metadata Schema Uri: %s\n", datasetResponse.getMetadataSchemaUri()); + System.out.format("Metadata: %s\n", datasetResponse.getMetadata()); + } + } +} +// [END aiplatform_create_dataset_tabular_bigquery_sample] diff --git a/aiplatform/snippets/src/main/java/aiplatform/CreateDatasetTabularGcsSample.java b/aiplatform/snippets/src/main/java/aiplatform/CreateDatasetTabularGcsSample.java new file mode 100644 index 00000000000..2b2f17f41ba --- /dev/null +++ b/aiplatform/snippets/src/main/java/aiplatform/CreateDatasetTabularGcsSample.java @@ -0,0 +1,88 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_create_dataset_tabular_gcs_sample] + +import com.google.api.gax.longrunning.OperationFuture; +import com.google.cloud.aiplatform.v1beta1.CreateDatasetOperationMetadata; +import com.google.cloud.aiplatform.v1beta1.Dataset; +import com.google.cloud.aiplatform.v1beta1.DatasetServiceClient; +import com.google.cloud.aiplatform.v1beta1.DatasetServiceSettings; +import com.google.cloud.aiplatform.v1beta1.LocationName; +import com.google.protobuf.Value; +import com.google.protobuf.util.JsonFormat; +import java.io.IOException; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +public class CreateDatasetTabularGcsSample { + + public static void main(String[] args) + throws InterruptedException, ExecutionException, TimeoutException, IOException { + // TODO(developer): Replace these variables before running the sample. + String project = "YOUR_PROJECT_ID"; + String datasetDisplayName = "YOUR_DATASET_DISPLAY_NAME"; + String gcsSourceUri = "gs://YOUR_GCS_SOURCE_BUCKET/path_to_your_gcs_table/file.csv"; + ; + createDatasetTableGcs(project, datasetDisplayName, gcsSourceUri); + } + + static void createDatasetTableGcs(String project, String datasetDisplayName, String gcsSourceUri) + throws IOException, ExecutionException, InterruptedException, TimeoutException { + DatasetServiceSettings settings = + DatasetServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (DatasetServiceClient datasetServiceClient = DatasetServiceClient.create(settings)) { + String location = "us-central1"; + String metadataSchemaUri = + "gs://google-cloud-aiplatform/schema/dataset/metadata/tables_1.0.0.yaml"; + LocationName locationName = LocationName.of(project, location); + + String jsonString = + "{\"input_config\": {\"gcs_source\": {\"uri\": [\"" + gcsSourceUri + "\"]}}}"; + Value.Builder metaData = Value.newBuilder(); + JsonFormat.parser().merge(jsonString, metaData); + + Dataset dataset = + Dataset.newBuilder() + .setDisplayName(datasetDisplayName) + .setMetadataSchemaUri(metadataSchemaUri) + .setMetadata(metaData) + .build(); + + OperationFuture datasetFuture = + datasetServiceClient.createDatasetAsync(locationName, dataset); + System.out.format("Operation name: %s\n", datasetFuture.getInitialFuture().get().getName()); + System.out.println("Waiting for operation to finish..."); + Dataset datasetResponse = datasetFuture.get(300, TimeUnit.SECONDS); + + System.out.println("Create Dataset Table GCS sample"); + System.out.format("Name: %s\n", datasetResponse.getName()); + System.out.format("Display Name: %s\n", datasetResponse.getDisplayName()); + System.out.format("Metadata Schema Uri: %s\n", datasetResponse.getMetadataSchemaUri()); + System.out.format("Metadata: %s\n", datasetResponse.getMetadata()); + } + } +} +// [END aiplatform_create_dataset_tabular_gcs_sample] diff --git a/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineTabularClassificationSample.java b/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineTabularClassificationSample.java new file mode 100644 index 00000000000..de54c0a0a07 --- /dev/null +++ b/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineTabularClassificationSample.java @@ -0,0 +1,255 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_create_training_pipeline_tabular_classification_sample] + +import com.google.cloud.aiplatform.v1beta1.DeployedModelRef; +import com.google.cloud.aiplatform.v1beta1.EnvVar; +import com.google.cloud.aiplatform.v1beta1.ExplanationMetadata; +import com.google.cloud.aiplatform.v1beta1.ExplanationParameters; +import com.google.cloud.aiplatform.v1beta1.ExplanationSpec; +import com.google.cloud.aiplatform.v1beta1.FilterSplit; +import com.google.cloud.aiplatform.v1beta1.FractionSplit; +import com.google.cloud.aiplatform.v1beta1.InputDataConfig; +import com.google.cloud.aiplatform.v1beta1.LocationName; +import com.google.cloud.aiplatform.v1beta1.Model; +import com.google.cloud.aiplatform.v1beta1.ModelContainerSpec; +import com.google.cloud.aiplatform.v1beta1.PipelineServiceClient; +import com.google.cloud.aiplatform.v1beta1.PipelineServiceSettings; +import com.google.cloud.aiplatform.v1beta1.Port; +import com.google.cloud.aiplatform.v1beta1.PredefinedSplit; +import com.google.cloud.aiplatform.v1beta1.PredictSchemata; +import com.google.cloud.aiplatform.v1beta1.SampledShapleyAttribution; +import com.google.cloud.aiplatform.v1beta1.TimestampSplit; +import com.google.cloud.aiplatform.v1beta1.TrainingPipeline; +import com.google.protobuf.Any; +import com.google.protobuf.Value; +import com.google.protobuf.util.JsonFormat; +import com.google.rpc.Status; +import java.io.IOException; +import java.util.List; + +public class CreateTrainingPipelineTabularClassificationSample { + + public static void main(String[] args) throws IOException { + // TODO(developer): Replace these variables before running the sample. + String project = "YOUR_PROJECT_ID"; + String modelDisplayName = "YOUR_DATASET_DISPLAY_NAME"; + String datasetId = "YOUR_DATASET_ID"; + String targetColumn = "TARGET_COLUMN"; + String transformation = + "[{TRANSFORMATION_TYPE: {columnName : COLUMN_NAME, invalidValuesAllowed : TRUE/FALSE }}]"; + createTrainingPipelineTableClassification( + project, modelDisplayName, datasetId, targetColumn, transformation); + } + + static void createTrainingPipelineTableClassification( + String project, + String modelDisplayName, + String datasetId, + String targetColumn, + String transformation) + throws IOException { + PipelineServiceSettings pipelineServiceSettings = + PipelineServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (PipelineServiceClient pipelineServiceClient = + PipelineServiceClient.create(pipelineServiceSettings)) { + String location = "us-central1"; + LocationName locationName = LocationName.of(project, location); + String trainingTaskDefinition = + "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_tables_1.0.0.yaml"; + String jsonString = + "{\"targetColumn\": \"" + + targetColumn + + "\",\"predictionType\": \"classification\",\"transformations\": " + + transformation + + ",\"trainBudgetMilliNodeHours\": 8000}"; + + Value.Builder trainingTaskInputs = Value.newBuilder(); + JsonFormat.parser().merge(jsonString, trainingTaskInputs); + + FractionSplit fractionSplit = + FractionSplit.newBuilder() + .setTrainingFraction(0.8) + .setValidationFraction(0.1) + .setTestFraction(0.1) + .build(); + + InputDataConfig inputDataConfig = + InputDataConfig.newBuilder() + .setDatasetId(datasetId) + .setFractionSplit(fractionSplit) + .build(); + Model modelToUpload = Model.newBuilder().setDisplayName(modelDisplayName).build(); + + TrainingPipeline trainingPipeline = + TrainingPipeline.newBuilder() + .setDisplayName(modelDisplayName) + .setTrainingTaskDefinition(trainingTaskDefinition) + .setTrainingTaskInputs(trainingTaskInputs) + .setInputDataConfig(inputDataConfig) + .setModelToUpload(modelToUpload) + .build(); + + TrainingPipeline trainingPipelineResponse = + pipelineServiceClient.createTrainingPipeline(locationName, trainingPipeline); + + System.out.println("Create Training Pipeline Tabular Classification Response"); + System.out.format("\tName: %s\n", trainingPipelineResponse.getName()); + System.out.format("\tDisplay Name: %s\n", trainingPipelineResponse.getDisplayName()); + System.out.format( + "\tTraining Task Definition: %s\n", trainingPipelineResponse.getTrainingTaskDefinition()); + System.out.format( + "\tTraining Task Inputs: %s\n", trainingPipelineResponse.getTrainingTaskInputs()); + System.out.format( + "\tTraining Task Metadata: %s\n", trainingPipelineResponse.getTrainingTaskMetadata()); + + System.out.format("\tState: %s\n", trainingPipelineResponse.getState()); + System.out.format("\tCreate Time: %s\n", trainingPipelineResponse.getCreateTime()); + System.out.format("\tStart Time: %s\n", trainingPipelineResponse.getStartTime()); + System.out.format("\tEnd Time: %s\n", trainingPipelineResponse.getEndTime()); + System.out.format("\tUpdate Time: %s\n", trainingPipelineResponse.getUpdateTime()); + System.out.format("\tLabels: %s\n", trainingPipelineResponse.getLabelsMap()); + + InputDataConfig inputDataConfigResponse = trainingPipelineResponse.getInputDataConfig(); + System.out.println("\tInput Data Config"); + System.out.format("\t\tDataset Id: %s\n", inputDataConfigResponse.getDatasetId()); + System.out.format( + "\t\tAnnotations Filter: %s\n", inputDataConfigResponse.getAnnotationsFilter()); + + FractionSplit fractionSplitResponse = inputDataConfigResponse.getFractionSplit(); + System.out.println("\t\tFraction Split"); + System.out.format( + "\t\t\tTraining Fraction: %s\n", fractionSplitResponse.getTrainingFraction()); + System.out.format( + "\t\t\tValidation Fraction: %s\n", fractionSplitResponse.getValidationFraction()); + System.out.format("\t\t\tTest Fraction: %s\n", fractionSplitResponse.getTestFraction()); + + FilterSplit filterSplit = inputDataConfigResponse.getFilterSplit(); + System.out.println("\t\tFilter Split"); + System.out.format("\t\t\tTraining Fraction: %s\n", filterSplit.getTrainingFilter()); + System.out.format("\t\t\tValidation Fraction: %s\n", filterSplit.getValidationFilter()); + System.out.format("\t\t\tTest Fraction: %s\n", filterSplit.getTestFilter()); + + PredefinedSplit predefinedSplit = inputDataConfigResponse.getPredefinedSplit(); + System.out.println("\t\tPredefined Split"); + System.out.format("\t\t\tKey: %s\n", predefinedSplit.getKey()); + + TimestampSplit timestampSplit = inputDataConfigResponse.getTimestampSplit(); + System.out.println("\t\tTimestamp Split"); + System.out.format("\t\t\tTraining Fraction: %s\n", timestampSplit.getTrainingFraction()); + System.out.format("\t\t\tValidation Fraction: %s\n", timestampSplit.getValidationFraction()); + System.out.format("\t\t\tTest Fraction: %s\n", timestampSplit.getTestFraction()); + System.out.format("\t\t\tKey: %s\n", timestampSplit.getKey()); + + Model modelResponse = trainingPipelineResponse.getModelToUpload(); + System.out.println("\tModel To Upload"); + System.out.format("\t\tName: %s\n", modelResponse.getName()); + System.out.format("\t\tDisplay Name: %s\n", modelResponse.getDisplayName()); + System.out.format("\t\tDescription: %s\n", modelResponse.getDescription()); + System.out.format("\t\tMetadata Schema Uri: %s\n", modelResponse.getMetadataSchemaUri()); + System.out.format("\t\tMeta Data: %s\n", modelResponse.getMetadata()); + System.out.format("\t\tTraining Pipeline: %s\n", modelResponse.getTrainingPipeline()); + System.out.format("\t\tArtifact Uri: %s\n", modelResponse.getArtifactUri()); + + System.out.format( + "\t\tSupported Deployment Resources Types: %s\n", + modelResponse.getSupportedDeploymentResourcesTypesList().toString()); + System.out.format( + "\t\tSupported Input Storage Formats: %s\n", + modelResponse.getSupportedInputStorageFormatsList().toString()); + System.out.format( + "\t\tSupported Output Storage Formats: %s\n", + modelResponse.getSupportedOutputStorageFormatsList().toString()); + + System.out.format("\t\tCreate Time: %s\n", modelResponse.getCreateTime()); + System.out.format("\t\tUpdate Time: %s\n", modelResponse.getUpdateTime()); + System.out.format("\t\tLables: %s\n", modelResponse.getLabelsMap()); + PredictSchemata predictSchemata = modelResponse.getPredictSchemata(); + + System.out.println("\tPredict Schemata"); + System.out.format("\t\tInstance Schema Uri: %s\n", predictSchemata.getInstanceSchemaUri()); + System.out.format( + "\t\tParameters Schema Uri: %s\n", predictSchemata.getParametersSchemaUri()); + System.out.format( + "\t\tPrediction Schema Uri: %s\n", predictSchemata.getPredictionSchemaUri()); + + for (Model.ExportFormat supportedExportFormat : + modelResponse.getSupportedExportFormatsList()) { + System.out.println("\tSupported Export Format"); + System.out.format("\t\tId: %s\n", supportedExportFormat.getId()); + } + ModelContainerSpec containerSpec = modelResponse.getContainerSpec(); + + System.out.println("\tContainer Spec"); + System.out.format("\t\tImage Uri: %s\n", containerSpec.getImageUri()); + System.out.format("\t\tCommand: %s\n", containerSpec.getCommandList()); + System.out.format("\t\tArgs: %s\n", containerSpec.getArgsList()); + System.out.format("\t\tPredict Route: %s\n", containerSpec.getPredictRoute()); + System.out.format("\t\tHealth Route: %s\n", containerSpec.getHealthRoute()); + + for (EnvVar envVar : containerSpec.getEnvList()) { + System.out.println("\t\tEnv"); + System.out.format("\t\t\tName: %s\n", envVar.getName()); + System.out.format("\t\t\tValue: %s\n", envVar.getValue()); + } + + for (Port port : containerSpec.getPortsList()) { + System.out.println("\t\tPort"); + System.out.format("\t\t\tContainer Port: %s\n", port.getContainerPort()); + } + + for (DeployedModelRef deployedModelRef : modelResponse.getDeployedModelsList()) { + System.out.println("\tDeployed Model"); + System.out.format("\t\tEndpoint: %s\n", deployedModelRef.getEndpoint()); + System.out.format("\t\tDeployed Model Id: %s\n", deployedModelRef.getDeployedModelId()); + } + + ExplanationSpec explanationSpec = modelResponse.getExplanationSpec(); + System.out.println("\tExplanation Spec"); + + ExplanationParameters explanationParameters = explanationSpec.getParameters(); + System.out.println("\t\tParameters"); + + SampledShapleyAttribution sampledShapleyAttribution = + explanationParameters.getSampledShapleyAttribution(); + System.out.println("\t\tSampled Shapley Attribution"); + System.out.format("\t\t\tPath Count: %s\n", sampledShapleyAttribution.getPathCount()); + + ExplanationMetadata explanationMetadata = explanationSpec.getMetadata(); + System.out.println("\t\tMetadata"); + System.out.format("\t\t\tInput: %s\n", explanationMetadata.getInputsMap()); + System.out.format("\t\t\tOutput: %s\n", explanationMetadata.getOutputsMap()); + System.out.format( + "\t\t\tFeature Attributions Schema Uri: %s\n", + explanationMetadata.getFeatureAttributionsSchemaUri()); + + Status status = trainingPipelineResponse.getError(); + System.out.println("\tError"); + System.out.format("\t\tCode: %s\n", status.getCode()); + System.out.format("\t\tMessage: %s\n", status.getMessage()); + } + } +} +// [END aiplatform_create_training_pipeline_tabular_classification_sample] diff --git a/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineTabularRegressionSample.java b/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineTabularRegressionSample.java new file mode 100644 index 00000000000..ca24862cad2 --- /dev/null +++ b/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineTabularRegressionSample.java @@ -0,0 +1,254 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_create_training_pipeline_tabular_regression_sample] + +import com.google.cloud.aiplatform.v1beta1.DeployedModelRef; +import com.google.cloud.aiplatform.v1beta1.EnvVar; +import com.google.cloud.aiplatform.v1beta1.ExplanationMetadata; +import com.google.cloud.aiplatform.v1beta1.ExplanationParameters; +import com.google.cloud.aiplatform.v1beta1.ExplanationSpec; +import com.google.cloud.aiplatform.v1beta1.FilterSplit; +import com.google.cloud.aiplatform.v1beta1.FractionSplit; +import com.google.cloud.aiplatform.v1beta1.InputDataConfig; +import com.google.cloud.aiplatform.v1beta1.LocationName; +import com.google.cloud.aiplatform.v1beta1.Model; +import com.google.cloud.aiplatform.v1beta1.ModelContainerSpec; +import com.google.cloud.aiplatform.v1beta1.PipelineServiceClient; +import com.google.cloud.aiplatform.v1beta1.PipelineServiceSettings; +import com.google.cloud.aiplatform.v1beta1.Port; +import com.google.cloud.aiplatform.v1beta1.PredefinedSplit; +import com.google.cloud.aiplatform.v1beta1.PredictSchemata; +import com.google.cloud.aiplatform.v1beta1.SampledShapleyAttribution; +import com.google.cloud.aiplatform.v1beta1.TimestampSplit; +import com.google.cloud.aiplatform.v1beta1.TrainingPipeline; +import com.google.protobuf.Any; +import com.google.protobuf.Value; +import com.google.protobuf.util.JsonFormat; +import com.google.rpc.Status; +import java.io.IOException; +import java.util.List; + +public class CreateTrainingPipelineTabularRegressionSample { + + public static void main(String[] args) throws IOException { + // TODO(developer): Replace these variables before running the sample. + String project = "YOUR_PROJECT_ID"; + String modelDisplayName = "YOUR_DATASET_DISPLAY_NAME"; + String datasetId = "YOUR_DATASET_ID"; + String targetColumn = "TARGET_COLUMN"; + String transformation = + "[{TRANSFORMATION_TYPE: {columnName : COLUMN_NAME, invalidValuesAllowed : TRUE/FALSE }}]"; + createTrainingPipelineTableRegression( + project, modelDisplayName, datasetId, targetColumn, transformation); + } + + static void createTrainingPipelineTableRegression( + String project, + String modelDisplayName, + String datasetId, + String targetColumn, + String transformation) + throws IOException { + PipelineServiceSettings pipelineServiceSettings = + PipelineServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (PipelineServiceClient pipelineServiceClient = + PipelineServiceClient.create(pipelineServiceSettings)) { + String location = "us-central1"; + LocationName locationName = LocationName.of(project, location); + String trainingTaskDefinition = + "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_tables_1.0.0.yaml"; + String jsonString = + "{\"targetColumn\": \"" + + targetColumn + + "\",\"predictionType\": \"regression\",\"transformations\": " + + transformation + + ",\"trainBudgetMilliNodeHours\": 8000}"; + Value.Builder trainingTaskInputs = Value.newBuilder(); + JsonFormat.parser().merge(jsonString, trainingTaskInputs); + + FractionSplit fractionSplit = + FractionSplit.newBuilder() + .setTrainingFraction(0.8) + .setValidationFraction(0.1) + .setTestFraction(0.1) + .build(); + + InputDataConfig inputDataConfig = + InputDataConfig.newBuilder() + .setDatasetId(datasetId) + .setFractionSplit(fractionSplit) + .build(); + Model modelToUpload = Model.newBuilder().setDisplayName(modelDisplayName).build(); + + TrainingPipeline trainingPipeline = + TrainingPipeline.newBuilder() + .setDisplayName(modelDisplayName) + .setTrainingTaskDefinition(trainingTaskDefinition) + .setTrainingTaskInputs(trainingTaskInputs) + .setInputDataConfig(inputDataConfig) + .setModelToUpload(modelToUpload) + .build(); + + TrainingPipeline trainingPipelineResponse = + pipelineServiceClient.createTrainingPipeline(locationName, trainingPipeline); + + System.out.println("Create Training Pipeline Tabular Regression Response"); + System.out.format("\tName: %s\n", trainingPipelineResponse.getName()); + System.out.format("\tDisplay Name: %s\n", trainingPipelineResponse.getDisplayName()); + System.out.format( + "\tTraining Task Definition: %s\n", trainingPipelineResponse.getTrainingTaskDefinition()); + System.out.format( + "\tTraining Task Inputs: %s\n", trainingPipelineResponse.getTrainingTaskInputs()); + System.out.format( + "\tTraining Task Metadata: %s\n", trainingPipelineResponse.getTrainingTaskMetadata()); + + System.out.format("\tState: %s\n", trainingPipelineResponse.getState()); + System.out.format("\tCreate Time: %s\n", trainingPipelineResponse.getCreateTime()); + System.out.format("\tStart Time: %s\n", trainingPipelineResponse.getStartTime()); + System.out.format("\tEnd Time: %s\n", trainingPipelineResponse.getEndTime()); + System.out.format("\tUpdate Time: %s\n", trainingPipelineResponse.getUpdateTime()); + System.out.format("\tLabels: %s\n", trainingPipelineResponse.getLabelsMap()); + + InputDataConfig inputDataConfigResponse = trainingPipelineResponse.getInputDataConfig(); + System.out.println("\tInput Data Config"); + System.out.format("\t\tDataset Id: %s\n", inputDataConfigResponse.getDatasetId()); + System.out.format( + "\t\tAnnotations Filter: %s\n", inputDataConfigResponse.getAnnotationsFilter()); + + FractionSplit fractionSplitResponse = inputDataConfigResponse.getFractionSplit(); + System.out.println("\t\tFraction Split"); + System.out.format( + "\t\t\tTraining Fraction: %s\n", fractionSplitResponse.getTrainingFraction()); + System.out.format( + "\t\t\tValidation Fraction: %s\n", fractionSplitResponse.getValidationFraction()); + System.out.format("\t\t\tTest Fraction: %s\n", fractionSplitResponse.getTestFraction()); + + FilterSplit filterSplit = inputDataConfigResponse.getFilterSplit(); + System.out.println("\t\tFilter Split"); + System.out.format("\t\t\tTraining Fraction: %s\n", filterSplit.getTrainingFilter()); + System.out.format("\t\t\tValidation Fraction: %s\n", filterSplit.getValidationFilter()); + System.out.format("\t\t\tTest Fraction: %s\n", filterSplit.getTestFilter()); + + PredefinedSplit predefinedSplit = inputDataConfigResponse.getPredefinedSplit(); + System.out.println("\t\tPredefined Split"); + System.out.format("\t\t\tKey: %s\n", predefinedSplit.getKey()); + + TimestampSplit timestampSplit = inputDataConfigResponse.getTimestampSplit(); + System.out.println("\t\tTimestamp Split"); + System.out.format("\t\t\tTraining Fraction: %s\n", timestampSplit.getTrainingFraction()); + System.out.format("\t\t\tValidation Fraction: %s\n", timestampSplit.getValidationFraction()); + System.out.format("\t\t\tTest Fraction: %s\n", timestampSplit.getTestFraction()); + System.out.format("\t\t\tKey: %s\n", timestampSplit.getKey()); + + Model modelResponse = trainingPipelineResponse.getModelToUpload(); + System.out.println("\tModel To Upload"); + System.out.format("\t\tName: %s\n", modelResponse.getName()); + System.out.format("\t\tDisplay Name: %s\n", modelResponse.getDisplayName()); + System.out.format("\t\tDescription: %s\n", modelResponse.getDescription()); + System.out.format("\t\tMetadata Schema Uri: %s\n", modelResponse.getMetadataSchemaUri()); + System.out.format("\t\tMeta Data: %s\n", modelResponse.getMetadata()); + System.out.format("\t\tTraining Pipeline: %s\n", modelResponse.getTrainingPipeline()); + System.out.format("\t\tArtifact Uri: %s\n", modelResponse.getArtifactUri()); + + System.out.format( + "\t\tSupported Deployment Resources Types: %s\n", + modelResponse.getSupportedDeploymentResourcesTypesList().toString()); + System.out.format( + "\t\tSupported Input Storage Formats: %s\n", + modelResponse.getSupportedInputStorageFormatsList().toString()); + System.out.format( + "\t\tSupported Output Storage Formats: %s\n", + modelResponse.getSupportedOutputStorageFormatsList().toString()); + + System.out.format("\t\tCreate Time: %s\n", modelResponse.getCreateTime()); + System.out.format("\t\tUpdate Time: %s\n", modelResponse.getUpdateTime()); + System.out.format("\t\tLables: %s\n", modelResponse.getLabelsMap()); + PredictSchemata predictSchemata = modelResponse.getPredictSchemata(); + + System.out.println("\tPredict Schemata"); + System.out.format("\t\tInstance Schema Uri: %s\n", predictSchemata.getInstanceSchemaUri()); + System.out.format( + "\t\tParameters Schema Uri: %s\n", predictSchemata.getParametersSchemaUri()); + System.out.format( + "\t\tPrediction Schema Uri: %s\n", predictSchemata.getPredictionSchemaUri()); + + for (Model.ExportFormat supportedExportFormat : + modelResponse.getSupportedExportFormatsList()) { + System.out.println("\tSupported Export Format"); + System.out.format("\t\tId: %s\n", supportedExportFormat.getId()); + } + ModelContainerSpec containerSpec = modelResponse.getContainerSpec(); + + System.out.println("\tContainer Spec"); + System.out.format("\t\tImage Uri: %s\n", containerSpec.getImageUri()); + System.out.format("\t\tCommand: %s\n", containerSpec.getCommandList()); + System.out.format("\t\tArgs: %s\n", containerSpec.getArgsList()); + System.out.format("\t\tPredict Route: %s\n", containerSpec.getPredictRoute()); + System.out.format("\t\tHealth Route: %s\n", containerSpec.getHealthRoute()); + + for (EnvVar envVar : containerSpec.getEnvList()) { + System.out.println("\t\tEnv"); + System.out.format("\t\t\tName: %s\n", envVar.getName()); + System.out.format("\t\t\tValue: %s\n", envVar.getValue()); + } + + for (Port port : containerSpec.getPortsList()) { + System.out.println("\t\tPort"); + System.out.format("\t\t\tContainer Port: %s\n", port.getContainerPort()); + } + + for (DeployedModelRef deployedModelRef : modelResponse.getDeployedModelsList()) { + System.out.println("\tDeployed Model"); + System.out.format("\t\tEndpoint: %s\n", deployedModelRef.getEndpoint()); + System.out.format("\t\tDeployed Model Id: %s\n", deployedModelRef.getDeployedModelId()); + } + + ExplanationSpec explanationSpec = modelResponse.getExplanationSpec(); + System.out.println("\tExplanation Spec"); + + ExplanationParameters explanationParameters = explanationSpec.getParameters(); + System.out.println("\t\tParameters"); + + SampledShapleyAttribution sampledShapleyAttribution = + explanationParameters.getSampledShapleyAttribution(); + System.out.println("\t\tSampled Shapley Attribution"); + System.out.format("\t\t\tPath Count: %s\n", sampledShapleyAttribution.getPathCount()); + + ExplanationMetadata explanationMetadata = explanationSpec.getMetadata(); + System.out.println("\t\tMetadata"); + System.out.format("\t\t\tInput: %s\n", explanationMetadata.getInputsMap()); + System.out.format("\t\t\tOutput: %s\n", explanationMetadata.getOutputsMap()); + System.out.format( + "\t\t\tFeature Attributions Schema Uri: %s\n", + explanationMetadata.getFeatureAttributionsSchemaUri()); + + Status status = trainingPipelineResponse.getError(); + System.out.println("\tError"); + System.out.format("\t\tCode: %s\n", status.getCode()); + System.out.format("\t\tMessage: %s\n", status.getMessage()); + } + } +} +// [END aiplatform_create_training_pipeline_tabular_regression_sample] diff --git a/aiplatform/snippets/src/main/java/aiplatform/DeleteDatasetSample.java b/aiplatform/snippets/src/main/java/aiplatform/DeleteDatasetSample.java new file mode 100644 index 00000000000..39ad52d0fdf --- /dev/null +++ b/aiplatform/snippets/src/main/java/aiplatform/DeleteDatasetSample.java @@ -0,0 +1,67 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_delete_dataset_sample] + +import com.google.api.gax.longrunning.OperationFuture; +import com.google.cloud.aiplatform.v1beta1.DatasetName; +import com.google.cloud.aiplatform.v1beta1.DatasetServiceClient; +import com.google.cloud.aiplatform.v1beta1.DatasetServiceSettings; +import com.google.cloud.aiplatform.v1beta1.DeleteOperationMetadata; +import com.google.protobuf.Empty; +import java.io.IOException; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +public class DeleteDatasetSample { + + public static void main(String[] args) + throws IOException, InterruptedException, ExecutionException, TimeoutException { + // TODO(developer): Replace these variables before running the sample. + String project = "YOUR_PROJECT_ID"; + String datasetId = "YOUR_DATASET_ID"; + deleteDatasetSample(project, datasetId); + } + + static void deleteDatasetSample(String project, String datasetId) + throws IOException, InterruptedException, ExecutionException, TimeoutException { + DatasetServiceSettings datasetServiceSettings = + DatasetServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (DatasetServiceClient datasetServiceClient = + DatasetServiceClient.create(datasetServiceSettings)) { + String location = "us-central1"; + DatasetName datasetName = DatasetName.of(project, location, datasetId); + + OperationFuture operationFuture = + datasetServiceClient.deleteDatasetAsync(datasetName); + System.out.format("Operation name: %s\n", operationFuture.getInitialFuture().get().getName()); + System.out.println("Waiting for operation to finish..."); + operationFuture.get(300, TimeUnit.SECONDS); + + System.out.format("Deleted Dataset."); + } + } +} +// [END aiplatform_delete_dataset_sample] diff --git a/aiplatform/snippets/src/main/java/aiplatform/DeleteExportModelSample.java b/aiplatform/snippets/src/main/java/aiplatform/DeleteExportModelSample.java new file mode 100644 index 00000000000..d6ed1995714 --- /dev/null +++ b/aiplatform/snippets/src/main/java/aiplatform/DeleteExportModelSample.java @@ -0,0 +1,45 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_delete_export_model_sample] + +import com.google.cloud.storage.Blob; +import com.google.cloud.storage.Storage; +import com.google.cloud.storage.StorageOptions; + +public class DeleteExportModelSample { + + public static void main(String[] args) { + // TODO(developer): Replace these variables before running the sample. + String project = "YOUR_PROJECT_ID"; + String bucketName = "YOUR_BUCKET_NAME"; + String folderName = "YOUR_FOLDER_NAME"; + deleteExportModelSample(project, bucketName, folderName); + } + + static void deleteExportModelSample(String project, String bucketName, String folderName) { + Storage storage = StorageOptions.newBuilder().setProjectId(project).build().getService(); + Iterable blobs = + storage.list(bucketName, Storage.BlobListOption.prefix(folderName)).iterateAll(); + for (Blob blob : blobs) { + blob.delete(Blob.BlobSourceOption.generationMatch()); + } + System.out.println("Export Model Deleted"); + } +} +// [END aiplatform_delete_export_model_sample] diff --git a/aiplatform/snippets/src/main/java/aiplatform/DeleteTrainingPipelineSample.java b/aiplatform/snippets/src/main/java/aiplatform/DeleteTrainingPipelineSample.java new file mode 100644 index 00000000000..d3819cd03ea --- /dev/null +++ b/aiplatform/snippets/src/main/java/aiplatform/DeleteTrainingPipelineSample.java @@ -0,0 +1,68 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_delete_training_pipeline_sample] + +import com.google.api.gax.longrunning.OperationFuture; +import com.google.cloud.aiplatform.v1beta1.DeleteOperationMetadata; +import com.google.cloud.aiplatform.v1beta1.PipelineServiceClient; +import com.google.cloud.aiplatform.v1beta1.PipelineServiceSettings; +import com.google.cloud.aiplatform.v1beta1.TrainingPipelineName; +import com.google.protobuf.Empty; +import java.io.IOException; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +public class DeleteTrainingPipelineSample { + + public static void main(String[] args) + throws IOException, InterruptedException, ExecutionException, TimeoutException { + // TODO(developer): Replace these variables before running the sample. + String trainingPipelineId = "YOUR_TRAINING_PIPELINE_ID"; + String project = "YOUR_PROJECT_ID"; + deleteTrainingPipelineSample(project, trainingPipelineId); + } + + static void deleteTrainingPipelineSample(String project, String trainingPipelineId) + throws IOException, InterruptedException, ExecutionException, TimeoutException { + PipelineServiceSettings pipelineServiceSettings = + PipelineServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (PipelineServiceClient pipelineServiceClient = + PipelineServiceClient.create(pipelineServiceSettings)) { + String location = "us-central1"; + TrainingPipelineName trainingPipelineName = + TrainingPipelineName.of(project, location, trainingPipelineId); + + OperationFuture operationFuture = + pipelineServiceClient.deleteTrainingPipelineAsync(trainingPipelineName); + System.out.format("Operation name: %s\n", operationFuture.getInitialFuture().get().getName()); + System.out.println("Waiting for operation to finish..."); + operationFuture.get(300, TimeUnit.SECONDS); + + System.out.format("Deleted Training Pipeline."); + } + } +} +// [END aiplatform_delete_training_pipeline_sample] diff --git a/aiplatform/snippets/src/main/java/aiplatform/ExportModelTabularClassificationSample.java b/aiplatform/snippets/src/main/java/aiplatform/ExportModelTabularClassificationSample.java new file mode 100644 index 00000000000..f3fedf710c2 --- /dev/null +++ b/aiplatform/snippets/src/main/java/aiplatform/ExportModelTabularClassificationSample.java @@ -0,0 +1,79 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_export_model_tabular_classification_sample] + +import com.google.api.gax.longrunning.OperationFuture; +import com.google.cloud.aiplatform.v1beta1.ExportModelOperationMetadata; +import com.google.cloud.aiplatform.v1beta1.ExportModelRequest; +import com.google.cloud.aiplatform.v1beta1.ExportModelResponse; +import com.google.cloud.aiplatform.v1beta1.GcsDestination; +import com.google.cloud.aiplatform.v1beta1.ModelName; +import com.google.cloud.aiplatform.v1beta1.ModelServiceClient; +import com.google.cloud.aiplatform.v1beta1.ModelServiceSettings; +import java.io.IOException; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +public class ExportModelTabularClassificationSample { + public static void main(String[] args) + throws InterruptedException, ExecutionException, TimeoutException, IOException { + // TODO(developer): Replace these variables before running the sample. + String gcsDestinationOutputUriPrefix = "gs://your-gcs-bucket/destination_path"; + String project = "YOUR_PROJECT_ID"; + String modelId = "YOUR_MODEL_ID"; + exportModelTableClassification(gcsDestinationOutputUriPrefix, project, modelId); + } + + static void exportModelTableClassification( + String gcsDestinationOutputUriPrefix, String project, String modelId) + throws IOException, ExecutionException, InterruptedException, TimeoutException { + ModelServiceSettings modelServiceSettings = + ModelServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (ModelServiceClient modelServiceClient = ModelServiceClient.create(modelServiceSettings)) { + String location = "us-central1"; + ModelName modelName = ModelName.of(project, location, modelId); + + GcsDestination.Builder gcsDestination = GcsDestination.newBuilder(); + gcsDestination.setOutputUriPrefix(gcsDestinationOutputUriPrefix); + ExportModelRequest.OutputConfig outputConfig = + ExportModelRequest.OutputConfig.newBuilder() + .setExportFormatId("tf-saved-model") + .setArtifactDestination(gcsDestination) + .build(); + + OperationFuture exportModelResponseFuture = + modelServiceClient.exportModelAsync(modelName, outputConfig); + System.out.format( + "Operation name: %s\n", exportModelResponseFuture.getInitialFuture().get().getName()); + System.out.println("Waiting for operation to finish..."); + ExportModelResponse exportModelResponse = + exportModelResponseFuture.get(300, TimeUnit.SECONDS); + System.out.format( + "Export Model Tabular Classification Response: %s", exportModelResponse.toString()); + } + } +} +// [END aiplatform_export_model_tabular_classification_sample] diff --git a/aiplatform/snippets/src/main/java/aiplatform/GetModelEvaluationTabularClassificationSample.java b/aiplatform/snippets/src/main/java/aiplatform/GetModelEvaluationTabularClassificationSample.java new file mode 100644 index 00000000000..e347bf820d6 --- /dev/null +++ b/aiplatform/snippets/src/main/java/aiplatform/GetModelEvaluationTabularClassificationSample.java @@ -0,0 +1,78 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_get_model_evaluation_tabular_classification_sample] + +import com.google.cloud.aiplatform.v1beta1.Attribution; +import com.google.cloud.aiplatform.v1beta1.ModelEvaluation; +import com.google.cloud.aiplatform.v1beta1.ModelEvaluationName; +import com.google.cloud.aiplatform.v1beta1.ModelExplanation; +import com.google.cloud.aiplatform.v1beta1.ModelServiceClient; +import com.google.cloud.aiplatform.v1beta1.ModelServiceSettings; +import java.io.IOException; + +public class GetModelEvaluationTabularClassificationSample { + + public static void main(String[] args) throws IOException { + // TODO(developer): Replace these variables before running the sample. + String project = "YOUR_PROJECT_ID"; + String modelId = "YOUR_MODEL_ID"; + String evaluationId = "YOUR_EVALUATION_ID"; + getModelEvaluationTabularClassification(project, modelId, evaluationId); + } + + static void getModelEvaluationTabularClassification( + String project, String modelId, String evaluationId) throws IOException { + ModelServiceSettings modelServiceSettings = + ModelServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (ModelServiceClient modelServiceClient = ModelServiceClient.create(modelServiceSettings)) { + String location = "us-central1"; + ModelEvaluationName modelEvaluationName = + ModelEvaluationName.of(project, location, modelId, evaluationId); + ModelEvaluation modelEvaluation = modelServiceClient.getModelEvaluation(modelEvaluationName); + + System.out.println("Get Model Evaluation Tabular Classification Response"); + System.out.format("\tName: %s\n", modelEvaluation.getName()); + System.out.format("\tMetrics Schema Uri: %s\n", modelEvaluation.getMetricsSchemaUri()); + System.out.format("\tMetrics: %s\n", modelEvaluation.getMetrics()); + System.out.format("\tCreate Time: %s\n", modelEvaluation.getCreateTime()); + System.out.format("\tSlice Dimensions: %s\n", modelEvaluation.getSliceDimensionsList()); + ModelExplanation modelExplanation = modelEvaluation.getModelExplanation(); + + System.out.println("\tModel Explanation"); + for (Attribution attribution : modelExplanation.getMeanAttributionsList()) { + System.out.println("\t\tMean Attributions"); + System.out.format( + "\t\t\tBaseline Output value: %s\n", attribution.getBaselineOutputValue()); + System.out.format( + "\t\t\tInstance Output value: %s\n", attribution.getInstanceOutputValue()); + System.out.format("\t\t\tFeature attributions: %s\n", attribution.getFeatureAttributions()); + System.out.format("\t\t\tOutput Index: %s\n", attribution.getOutputIndexList()); + System.out.format("\t\t\tOutput Index Name: %s\n", attribution.getOutputDisplayName()); + System.out.format("\t\t\tApproximation Error: %s\n", attribution.getApproximationError()); + } + } + } +} +// [END aiplatform_get_model_evaluation_tabular_classification_sample] diff --git a/aiplatform/snippets/src/main/java/aiplatform/GetModelEvaluationTabularRegressionSample.java b/aiplatform/snippets/src/main/java/aiplatform/GetModelEvaluationTabularRegressionSample.java new file mode 100644 index 00000000000..bc9910a793b --- /dev/null +++ b/aiplatform/snippets/src/main/java/aiplatform/GetModelEvaluationTabularRegressionSample.java @@ -0,0 +1,78 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_get_model_evaluation_tabular_regression_sample] + +import com.google.cloud.aiplatform.v1beta1.Attribution; +import com.google.cloud.aiplatform.v1beta1.ModelEvaluation; +import com.google.cloud.aiplatform.v1beta1.ModelEvaluationName; +import com.google.cloud.aiplatform.v1beta1.ModelExplanation; +import com.google.cloud.aiplatform.v1beta1.ModelServiceClient; +import com.google.cloud.aiplatform.v1beta1.ModelServiceSettings; +import java.io.IOException; + +public class GetModelEvaluationTabularRegressionSample { + + public static void main(String[] args) throws IOException { + // TODO(developer): Replace these variables before running the sample. + String project = "YOUR_PROJECT_ID"; + String modelId = "YOUR_MODEL_ID"; + String evaluationId = "YOUR_EVALUATION_ID"; + getModelEvaluationTabularRegression(project, modelId, evaluationId); + } + + static void getModelEvaluationTabularRegression( + String project, String modelId, String evaluationId) throws IOException { + ModelServiceSettings modelServiceSettings = + ModelServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (ModelServiceClient modelServiceClient = ModelServiceClient.create(modelServiceSettings)) { + String location = "us-central1"; + ModelEvaluationName modelEvaluationName = + ModelEvaluationName.of(project, location, modelId, evaluationId); + ModelEvaluation modelEvaluation = modelServiceClient.getModelEvaluation(modelEvaluationName); + + System.out.println("Get Model Evaluation Tabular Regression Response"); + System.out.format("\tName: %s\n", modelEvaluation.getName()); + System.out.format("\tMetrics Schema Uri: %s\n", modelEvaluation.getMetricsSchemaUri()); + System.out.format("\tMetrics: %s\n", modelEvaluation.getMetrics()); + System.out.format("\tCreate Time: %s\n", modelEvaluation.getCreateTime()); + System.out.format("\tSlice Dimensions: %s\n", modelEvaluation.getSliceDimensionsList()); + ModelExplanation modelExplanation = modelEvaluation.getModelExplanation(); + + System.out.println("\tModel Explanation"); + for (Attribution attribution : modelExplanation.getMeanAttributionsList()) { + System.out.println("\t\tMean Attributions"); + System.out.format( + "\t\t\tBaseline Output value: %s\n", attribution.getBaselineOutputValue()); + System.out.format( + "\t\t\tInstance Output value: %s\n", attribution.getInstanceOutputValue()); + System.out.format("\t\t\tFeature attributions: %s\n", attribution.getFeatureAttributions()); + System.out.format("\t\t\tOutput Index: %s\n", attribution.getOutputIndexList()); + System.out.format("\t\t\tOutput Index Name: %s\n", attribution.getOutputDisplayName()); + System.out.format("\t\t\tApproximation Error: %s\n", attribution.getApproximationError()); + } + } + } +} +// [END aiplatform_get_model_evaluation_tabular_regression_sample] diff --git a/aiplatform/snippets/src/main/java/aiplatform/PredictTabularClassificationSample.java b/aiplatform/snippets/src/main/java/aiplatform/PredictTabularClassificationSample.java new file mode 100644 index 00000000000..302af2d55bd --- /dev/null +++ b/aiplatform/snippets/src/main/java/aiplatform/PredictTabularClassificationSample.java @@ -0,0 +1,73 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_predict_tabular_classification_sample] + +import com.google.cloud.aiplatform.v1beta1.EndpointName; +import com.google.cloud.aiplatform.v1beta1.PredictResponse; +import com.google.cloud.aiplatform.v1beta1.PredictionServiceClient; +import com.google.cloud.aiplatform.v1beta1.PredictionServiceSettings; +import com.google.protobuf.ListValue; +import com.google.protobuf.Value; +import com.google.protobuf.util.JsonFormat; +import java.io.IOException; +import java.util.List; + +public class PredictTabularClassificationSample { + + public static void main(String[] args) throws IOException { + // TODO(developer): Replace these variables before running the sample. + String project = "YOUR_PROJECT_ID"; + String instance = "[{ “feature_column_a”: “value”, “feature_column_b”: “value”}]"; + String endpointId = "YOUR_ENDPOINT_ID"; + predictTabularClassification(instance, project, endpointId); + } + + static void predictTabularClassification(String instance, String project, String endpointId) + throws IOException { + PredictionServiceSettings predictionServiceSettings = + PredictionServiceSettings.newBuilder() + .setEndpoint("us-central1-prediction-aiplatform.googleapis.com:443") + .build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (PredictionServiceClient predictionServiceClient = + PredictionServiceClient.create(predictionServiceSettings)) { + String location = "us-central1"; + EndpointName endpointName = EndpointName.of(project, location, endpointId); + + ListValue.Builder listValue = ListValue.newBuilder(); + JsonFormat.parser().merge(instance, listValue); + List instanceList = listValue.getValuesList(); + + Value parameters = Value.newBuilder().build(); + PredictResponse predictResponse = + predictionServiceClient.predict(endpointName, instanceList, parameters); + System.out.println("Predict Tabular Classification Response"); + System.out.format("\tDeployed Model Id: %s\n", predictResponse.getDeployedModelId()); + + System.out.println("Predictions"); + for (Value prediction : predictResponse.getPredictionsList()) { + System.out.format("\tPrediction: %s\n", prediction); + } + } + } +} +// [END aiplatform_predict_tabular_classification_sample] diff --git a/aiplatform/snippets/src/main/java/aiplatform/PredictTabularRegressionSample.java b/aiplatform/snippets/src/main/java/aiplatform/PredictTabularRegressionSample.java new file mode 100644 index 00000000000..7520f554d84 --- /dev/null +++ b/aiplatform/snippets/src/main/java/aiplatform/PredictTabularRegressionSample.java @@ -0,0 +1,73 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_predict_tabular_regression_sample] + +import com.google.cloud.aiplatform.v1beta1.EndpointName; +import com.google.cloud.aiplatform.v1beta1.PredictResponse; +import com.google.cloud.aiplatform.v1beta1.PredictionServiceClient; +import com.google.cloud.aiplatform.v1beta1.PredictionServiceSettings; +import com.google.protobuf.ListValue; +import com.google.protobuf.Value; +import com.google.protobuf.util.JsonFormat; +import java.io.IOException; +import java.util.List; + +public class PredictTabularRegressionSample { + + public static void main(String[] args) throws IOException { + // TODO(developer): Replace these variables before running the sample. + String project = "YOUR_PROJECT_ID"; + String instance = "[{ “feature_column_a”: “value”, “feature_column_b”: “value”}]"; + String endpointId = "YOUR_ENDPOINT_ID"; + predictTabularRegression(instance, project, endpointId); + } + + static void predictTabularRegression(String instance, String project, String endpointId) + throws IOException { + PredictionServiceSettings predictionServiceSettings = + PredictionServiceSettings.newBuilder() + .setEndpoint("us-central1-prediction-aiplatform.googleapis.com:443") + .build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (PredictionServiceClient predictionServiceClient = + PredictionServiceClient.create(predictionServiceSettings)) { + String location = "us-central1"; + EndpointName endpointName = EndpointName.of(project, location, endpointId); + + ListValue.Builder listValue = ListValue.newBuilder(); + JsonFormat.parser().merge(instance, listValue); + List instanceList = listValue.getValuesList(); + + Value parameters = Value.newBuilder().build(); + PredictResponse predictResponse = + predictionServiceClient.predict(endpointName, instanceList, parameters); + System.out.println("Predict Tabular Regression Response"); + System.out.format("\tDisplay Model Id: %s\n", predictResponse.getDeployedModelId()); + + System.out.println("Predictions"); + for (Value prediction : predictResponse.getPredictionsList()) { + System.out.format("\tPrediction: %s\n", prediction); + } + } + } +} +// [END aiplatform_predict_tabular_regression_sample] diff --git a/aiplatform/snippets/src/test/java/aiplatform/CreateDatasetTabularBigquerySampleTest.java b/aiplatform/snippets/src/test/java/aiplatform/CreateDatasetTabularBigquerySampleTest.java new file mode 100644 index 00000000000..42b002514a5 --- /dev/null +++ b/aiplatform/snippets/src/test/java/aiplatform/CreateDatasetTabularBigquerySampleTest.java @@ -0,0 +1,93 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertNotNull; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import java.util.UUID; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeoutException; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; + +public class CreateDatasetTabularBigquerySampleTest { + + private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID"); + private static final String GCS_SOURCE_URI = "bq://ucaip-sample-tests.table_test.all_bq_types"; + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + private String datasetId; + + private static void requireEnvVar(String varName) { + String errorMessage = + String.format("Environment variable '%s' is required to perform these tests.", varName); + assertNotNull(errorMessage, System.getenv(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("UCAIP_PROJECT_ID"); + } + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + } + + @After + public void tearDown() + throws InterruptedException, ExecutionException, IOException, TimeoutException { + // Delete the created dataset + DeleteDatasetSample.deleteDatasetSample(PROJECT, datasetId); + + // Assert + String deleteResponse = bout.toString(); + assertThat(deleteResponse).contains("Deleted Dataset."); + System.out.flush(); + System.setOut(originalPrintStream); + } + + @Test + public void testCreateDatasetTabularBigquerySample() + throws IOException, InterruptedException, ExecutionException, TimeoutException { + // Act + String datasetDisplayName = + String.format( + "temp_create_dataset_table_bigquery_test_%s", + UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 26)); + + CreateDatasetTabularBigquerySample.createDatasetTableBigquery( + PROJECT, datasetDisplayName, GCS_SOURCE_URI); + + // Assert + String got = bout.toString(); + assertThat(got).contains(datasetDisplayName); + assertThat(got).contains("Create Dataset Table Bigquery sample"); + datasetId = got.split("Name: ")[1].split("datasets/")[1].split("\n")[0]; + } +} diff --git a/aiplatform/snippets/src/test/java/aiplatform/CreateDatasetTabularGcsSampleTest.java b/aiplatform/snippets/src/test/java/aiplatform/CreateDatasetTabularGcsSampleTest.java new file mode 100644 index 00000000000..10a26a5e144 --- /dev/null +++ b/aiplatform/snippets/src/test/java/aiplatform/CreateDatasetTabularGcsSampleTest.java @@ -0,0 +1,93 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertNotNull; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import java.util.UUID; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeoutException; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; + +public class CreateDatasetTabularGcsSampleTest { + + private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID"); + private static final String GCS_SOURCE_URI = "gs://cloud-ml-tables-data/bank-marketing.csv"; + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + private String datasetId; + + private static void requireEnvVar(String varName) { + String errorMessage = + String.format("Environment variable '%s' is required to perform these tests.", varName); + assertNotNull(errorMessage, System.getenv(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("UCAIP_PROJECT_ID"); + } + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + } + + @After + public void tearDown() + throws InterruptedException, ExecutionException, IOException, TimeoutException { + // Delete the created dataset + DeleteDatasetSample.deleteDatasetSample(PROJECT, datasetId); + + // Assert + String deleteResponse = bout.toString(); + assertThat(deleteResponse).contains("Deleted Dataset."); + System.out.flush(); + System.setOut(originalPrintStream); + } + + @Test + public void testCreateDatasetTabularGcsSample() + throws IOException, InterruptedException, ExecutionException, TimeoutException { + // Act + String datasetDisplayName = + String.format( + "temp_create_dataset_table_gcs_test_%s", + UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 26)); + + CreateDatasetTabularGcsSample.createDatasetTableGcs(PROJECT, + datasetDisplayName, GCS_SOURCE_URI); + + // Assert + String got = bout.toString(); + assertThat(got).contains(datasetDisplayName); + assertThat(got).contains("Create Dataset Table GCS sample"); + datasetId = got.split("Name: ")[1].split("datasets/")[1].split("\n")[0]; + } +} diff --git a/aiplatform/snippets/src/test/java/aiplatform/CreateTrainingPipelineTabularClassificationSampleTest.java b/aiplatform/snippets/src/test/java/aiplatform/CreateTrainingPipelineTabularClassificationSampleTest.java new file mode 100644 index 00000000000..65f7d041bf2 --- /dev/null +++ b/aiplatform/snippets/src/test/java/aiplatform/CreateTrainingPipelineTabularClassificationSampleTest.java @@ -0,0 +1,121 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.assertNotNull; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import java.util.UUID; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; + +public class CreateTrainingPipelineTabularClassificationSampleTest { + + private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID"); + private static final String DATASET_ID = + System.getenv("TRAINING_PIPELINE_TABLES_CLASSIFICATION_DATASET_ID"); + private static final String TARGET_COLUMN = "TripType"; + private static final String TRANSFORMATION = + "[{\"numeric\":{\"columnName\":\"Age\",\"invalidValuesAllowed\":false}}," + + "{\"categorical\":{\"columnName\":\"Job\"}}," + + "{\"categorical\":{\"columnName\":\"MaritalStatus\"}}," + + "{\"categorical\":{\"columnName\":\"Default\"}}," + + "{\"numeric\":{\"columnName\":\"Balance\",\"invalidValuesAllowed\":false}}," + + "{\"categorical\":{\"columnName\":\"Housing\"}}," + + "{\"categorical\":{\"columnName\":\"Loan\"}}," + + "{\"categorical\":{\"columnName\":\"Contact\"}}," + + "{\"numeric\":{\"columnName\":\"Day\",\"invalidValuesAllowed\":false}}," + + "{\"categorical\":{\"columnName\":\"Month\"}}," + + "{\"numeric\":{\"columnName\":\"Duration\",\"invalidValuesAllowed\":false}}," + + "{\"numeric\":{\"columnName\":\"Campaign\",\"invalidValuesAllowed\":false}}," + + "{\"numeric\":{\"columnName\":\"PDays\",\"invalidValuesAllowed\":false}}," + + "{\"numeric\":{\"columnName\":\"Previous\",\"invalidValuesAllowed\":false}}," + + "{\"categorical\":{\"columnName\":\"POutcome\"}}," + + "{\"categorical\":{\"columnName\":\"Deposit\"}}]"; + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + private String trainingPipelineId; + + private static void requireEnvVar(String varName) { + String errorMessage = + String.format("Environment variable '%s' is required to perform these tests.", varName); + assertNotNull(errorMessage, System.getenv(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("UCAIP_PROJECT_ID"); + requireEnvVar("TRAINING_PIPELINE_TABLES_CLASSIFICATION_DATASET_ID"); + } + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + } + + @After + public void tearDown() + throws InterruptedException, ExecutionException, IOException, TimeoutException { + // Cancel the Training Pipeline + CancelTrainingPipelineSample.cancelTrainingPipelineSample(PROJECT, trainingPipelineId); + + // Assert + String cancelResponse = bout.toString(); + assertThat(cancelResponse).contains("Cancelled the Training Pipeline"); + TimeUnit.MINUTES.sleep(2); + + // Delete the Training Pipeline + DeleteTrainingPipelineSample.deleteTrainingPipelineSample(PROJECT, trainingPipelineId); + + // Assert + String deleteResponse = bout.toString(); + assertThat(deleteResponse).contains("Deleted Training Pipeline."); + System.out.flush(); + System.setOut(originalPrintStream); + } + + @Test + public void createTrainingPipelineTabularClassification() throws IOException { + // Act + String modelDisplayName = + String.format( + "temp_create_training_pipelinetabularclassification_test_%s", + UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 26)); + + CreateTrainingPipelineTabularClassificationSample.createTrainingPipelineTableClassification( + PROJECT, modelDisplayName, DATASET_ID, TARGET_COLUMN, TRANSFORMATION); + + // Assert + String got = bout.toString(); + assertThat(got).contains(DATASET_ID); + assertThat(got).contains("Create Training Pipeline Tabular Classification Response"); + trainingPipelineId = got.split("Name: ")[1].split("trainingPipelines/")[1].split("\n")[0]; + } +} diff --git a/aiplatform/snippets/src/test/java/aiplatform/CreateTrainingPipelineTabularRegressionSampleTest.java b/aiplatform/snippets/src/test/java/aiplatform/CreateTrainingPipelineTabularRegressionSampleTest.java new file mode 100644 index 00000000000..3933106f3e6 --- /dev/null +++ b/aiplatform/snippets/src/test/java/aiplatform/CreateTrainingPipelineTabularRegressionSampleTest.java @@ -0,0 +1,119 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.assertNotNull; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import java.util.UUID; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; + +public class CreateTrainingPipelineTabularRegressionSampleTest { + + private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID"); + private static final String DATASET_ID = + System.getenv("TRAINING_PIPELINE_TABLES_REGRESSION_DATASET_ID"); + private static final String TARGET_COLUMN = "Amount"; + private static final String TRANSFORMATION = + "[{\"categorical\":{\"columnName\":\"SC_Group_Desc\"}}," + + "{\"categorical\":{\"columnName\":\"SC_GroupCommod_ID\"}}," + + "{\"categorical\":{\"columnName\":\"SC_GroupCommod_Desc\"}}," + + "{\"numeric\":{\"columnName\":\"SortOrder\",\"invalidValuesAllowed\":false}}," + + "{\"text\":{\"columnName\":\"SC_GeographyIndented_Desc\"}}," + + "{\"numeric\":{\"columnName\":\"SC_Commodity_ID\",\"invalidValuesAllowed\":false}}," + + "{\"text\":{\"columnName\":\"SC_Commodity_Desc\"}}," + + "{\"numeric\":{\"columnName\":\"SC_Attribute_ID\",\"invalidValuesAllowed\":false}}," + + "{\"text\":{\"columnName\":\"SC_Attribute_Desc\"}}," + + "{\"numeric\":{\"columnName\":\"SC_Unit_ID\",\"invalidValuesAllowed\":false}}," + + "{\"numeric\":{\"columnName\":\"Year_ID\",\"invalidValuesAllowed\":false}}," + + "{\"categorical\":{\"columnName\":\"SC_Frequency_Desc\"}}," + + "{\"numeric\":{\"columnName\":\"Timeperiod_ID\",\"invalidValuesAllowed\":false}}," + + "{\"text\":{\"columnName\":\"Timeperiod_Desc\"}}]"; + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + private String trainingPipelineId; + + private static void requireEnvVar(String varName) { + String errorMessage = + String.format("Environment variable '%s' is required to perform these tests.", varName); + assertNotNull(errorMessage, System.getenv(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("UCAIP_PROJECT_ID"); + requireEnvVar("TRAINING_PIPELINE_TABLES_REGRESSION_DATASET_ID"); + } + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + } + + @After + public void tearDown() + throws InterruptedException, ExecutionException, IOException, TimeoutException { + // Cancel the Training Pipeline + CancelTrainingPipelineSample.cancelTrainingPipelineSample(PROJECT, trainingPipelineId); + + // Assert + String cancelResponse = bout.toString(); + assertThat(cancelResponse).contains("Cancelled the Training Pipeline"); + TimeUnit.MINUTES.sleep(2); + + // Delete the Training Pipeline + DeleteTrainingPipelineSample.deleteTrainingPipelineSample(PROJECT, trainingPipelineId); + + // Assert + String deleteResponse = bout.toString(); + assertThat(deleteResponse).contains("Deleted Training Pipeline."); + System.out.flush(); + System.setOut(originalPrintStream); + } + + @Test + public void createTrainingPipelineTabularRegression() throws IOException { + // Act + String modelDisplayName = + String.format( + "temp_create_training_pipelinetabularregression_model_test_%s", + UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 26)); + + CreateTrainingPipelineTabularRegressionSample.createTrainingPipelineTableRegression( + PROJECT, modelDisplayName, DATASET_ID, TARGET_COLUMN, TRANSFORMATION); + + // Assert + String got = bout.toString(); + assertThat(got).contains(DATASET_ID); + assertThat(got).contains("Create Training Pipeline Tabular Regression Response"); + trainingPipelineId = got.split("Name: ")[1].split("trainingPipelines/")[1].split("\n")[0]; + } +} diff --git a/aiplatform/snippets/src/test/java/aiplatform/ExportModelTabularClassificationSampleTest.java b/aiplatform/snippets/src/test/java/aiplatform/ExportModelTabularClassificationSampleTest.java new file mode 100644 index 00000000000..9212dc3d920 --- /dev/null +++ b/aiplatform/snippets/src/test/java/aiplatform/ExportModelTabularClassificationSampleTest.java @@ -0,0 +1,89 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertNotNull; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeoutException; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; + +public class ExportModelTabularClassificationSampleTest { + + private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID"); + private static final String MODEL_ID = + System.getenv("EXPORT_MODEL_TABLES_CLASSIFICATION_MODEL_ID"); + private static final String GCS_DESTINATION_URI_PREFIX = + "gs://ucaip-samples-test-output/tmp/export_model_test"; + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + + private static void requireEnvVar(String varName) { + String errorMessage = + String.format("Environment variable '%s' is required to perform these tests.", varName); + assertNotNull(errorMessage, System.getenv(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("UCAIP_PROJECT_ID"); + requireEnvVar("EXPORT_MODEL_TABLES_CLASSIFICATION_MODEL_ID"); + } + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + } + + @After + public void tearDown() { + // Delete the export model + String bucketName = GCS_DESTINATION_URI_PREFIX.split("/", 4)[2]; + String objectName = (GCS_DESTINATION_URI_PREFIX.split("/", 4)[3]).concat("model-" + MODEL_ID); + DeleteExportModelSample.deleteExportModelSample(PROJECT, bucketName, objectName); + + // Assert + String deleteResponse = bout.toString(); + assertThat(deleteResponse).contains("Export Model Deleted"); + System.out.flush(); + System.setOut(originalPrintStream); + } + + @Test + public void exportModelTabularClassification() + throws InterruptedException, ExecutionException, TimeoutException, IOException { + // Act + ExportModelTabularClassificationSample.exportModelTableClassification( + GCS_DESTINATION_URI_PREFIX, PROJECT, MODEL_ID); + + // Assert + String got = bout.toString(); + assertThat(got).contains("Export Model Tabular Classification Response: "); + } +} diff --git a/aiplatform/snippets/src/test/java/aiplatform/GetModelEvaluationTabularClassificationSampleTest.java b/aiplatform/snippets/src/test/java/aiplatform/GetModelEvaluationTabularClassificationSampleTest.java new file mode 100644 index 00000000000..6995dcd9f8a --- /dev/null +++ b/aiplatform/snippets/src/test/java/aiplatform/GetModelEvaluationTabularClassificationSampleTest.java @@ -0,0 +1,80 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.assertNotNull; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; + +public class GetModelEvaluationTabularClassificationSampleTest { + + private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID"); + private static final String MODEL_ID = + System.getenv("MODEL_EVALUATION_TABLES_CLASSIFICATION_MODEL_ID"); + private static final String EVALUATION_ID = + System.getenv("MODEL_EVALUATION_TABLES_CLASSIFICATION_EVALUATION_ID"); + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + + private static void requireEnvVar(String varName) { + String errorMessage = + String.format("Environment variable '%s' is required to perform these tests.", varName); + assertNotNull(errorMessage, System.getenv(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("UCAIP_PROJECT_ID"); + requireEnvVar("MODEL_EVALUATION_TABLES_CLASSIFICATION_MODEL_ID"); + requireEnvVar("MODEL_EVALUATION_TABLES_CLASSIFICATION_EVALUATION_ID"); + } + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + } + + @After + public void tearDown() { + System.out.flush(); + System.setOut(originalPrintStream); + } + + @Test + public void getModelEvaluationTabularClassification() throws IOException { + // Act + GetModelEvaluationTabularClassificationSample.getModelEvaluationTabularClassification( + PROJECT, MODEL_ID, EVALUATION_ID); + + // Assert + String got = bout.toString(); + assertThat(got).contains(MODEL_ID); + assertThat(got).contains("Get Model Evaluation Tabular Classification Response"); + } +} diff --git a/aiplatform/snippets/src/test/java/aiplatform/GetModelEvaluationTabularRegressionSampleTest.java b/aiplatform/snippets/src/test/java/aiplatform/GetModelEvaluationTabularRegressionSampleTest.java new file mode 100644 index 00000000000..81daedecc6d --- /dev/null +++ b/aiplatform/snippets/src/test/java/aiplatform/GetModelEvaluationTabularRegressionSampleTest.java @@ -0,0 +1,80 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.assertNotNull; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; + +public class GetModelEvaluationTabularRegressionSampleTest { + + private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID"); + private static final String MODEL_ID = + System.getenv("MODEL_EVALUATION_TABLES_REGRESSION_MODEL_ID"); + private static final String EVALUATION_ID = + System.getenv("MODEL_EVALUATION_TABLES_REGRESSION_EVALUATION_ID"); + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + + private static void requireEnvVar(String varName) { + String errorMessage = + String.format("Environment variable '%s' is required to perform these tests.", varName); + assertNotNull(errorMessage, System.getenv(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("UCAIP_PROJECT_ID"); + requireEnvVar("MODEL_EVALUATION_TABLES_REGRESSION_MODEL_ID"); + requireEnvVar("MODEL_EVALUATION_TABLES_REGRESSION_EVALUATION_ID"); + } + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + } + + @After + public void tearDown() { + System.out.flush(); + System.setOut(originalPrintStream); + } + + @Test + public void getModelEvaluationTabularRegression() throws IOException { + // Act + GetModelEvaluationTabularRegressionSample.getModelEvaluationTabularRegression( + PROJECT, MODEL_ID, EVALUATION_ID); + + // Assert + String got = bout.toString(); + assertThat(got).contains(MODEL_ID); + assertThat(got).contains("Get Model Evaluation Tabular Regression Response"); + } +} diff --git a/aiplatform/snippets/src/test/java/aiplatform/PredictTabularClassificationSampleTest.java b/aiplatform/snippets/src/test/java/aiplatform/PredictTabularClassificationSampleTest.java new file mode 100644 index 00000000000..1574efe2ae1 --- /dev/null +++ b/aiplatform/snippets/src/test/java/aiplatform/PredictTabularClassificationSampleTest.java @@ -0,0 +1,80 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.assertNotNull; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; + +public class PredictTabularClassificationSampleTest { + + private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID"); + private static final String INSTANCE = + "[{\"petal_length\": '1.4'," + + " \"petal_width\": '1.3'," + + " \"sepal_length\": '5.1'," + + " \"sepal_width\": '2.8'}]"; + private static final String ENDPOINT_ID = + System.getenv("PREDICT_TABLES_CLASSIFCATION_ENDPOINT_ID"); + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + + private static void requireEnvVar(String varName) { + String errorMessage = + String.format("Environment variable '%s' is required to perform these tests.", varName); + assertNotNull(errorMessage, System.getenv(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("UCAIP_PROJECT_ID"); + requireEnvVar("PREDICT_TABLES_CLASSIFCATION_ENDPOINT_ID"); + } + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + } + + @After + public void tearDown() { + System.out.flush(); + System.setOut(originalPrintStream); + } + + @Test + public void testPredictTabularClassification() throws IOException { + // Act + PredictTabularClassificationSample.predictTabularClassification(INSTANCE, PROJECT, ENDPOINT_ID); + + // Assert + String got = bout.toString(); + assertThat(got).contains("Predict Tabular Classification Response"); + } +} diff --git a/aiplatform/snippets/src/test/java/aiplatform/PredictTabularRegressionSampleTest.java b/aiplatform/snippets/src/test/java/aiplatform/PredictTabularRegressionSampleTest.java new file mode 100644 index 00000000000..44f5bfdfa21 --- /dev/null +++ b/aiplatform/snippets/src/test/java/aiplatform/PredictTabularRegressionSampleTest.java @@ -0,0 +1,100 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.assertNotNull; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; + +public class PredictTabularRegressionSampleTest { + + private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID"); + private static final String INSTANCE = + "[{\n" + + " \"BOOLEAN_2unique_NULLABLE\": False,\n" + + " \"DATETIME_1unique_NULLABLE\": '2019-01-01 00:00:00',\n" + + " \"DATE_1unique_NULLABLE\": '2019-01-01',\n" + + " \"FLOAT_5000unique_NULLABLE\": 1611,\n" + + " \"FLOAT_5000unique_REPEATED\": [2320,1192],\n" + + " \"INTEGER_5000unique_NULLABLE\": '8',\n" + + " \"NUMERIC_5000unique_NULLABLE\": 16,\n" + + " \"STRING_5000unique_NULLABLE\": 'str-2',\n" + + " \"STRUCT_NULLABLE\": {\n" + + " 'BOOLEAN_2unique_NULLABLE': False,\n" + + " 'DATE_1unique_NULLABLE': '2019-01-01',\n" + + " 'DATETIME_1unique_NULLABLE': '2019-01-01 00:00:00',\n" + + " 'FLOAT_5000unique_NULLABLE': 1308,\n" + + " 'FLOAT_5000unique_REPEATED': [2323, 1178],\n" + + " 'FLOAT_5000unique_REQUIRED': 3089,\n" + + " 'INTEGER_5000unique_NULLABLE': '1777',\n" + + " 'NUMERIC_5000unique_NULLABLE': 3323,\n" + + " 'TIME_1unique_NULLABLE': '23:59:59.999999',\n" + + " 'STRING_5000unique_NULLABLE': 'str-49',\n" + + " 'TIMESTAMP_1unique_NULLABLE': '1546387199999999'\n" + + " },\n" + + " \"TIMESTAMP_1unique_NULLABLE\": '1546387199999999',\n" + + " \"TIME_1unique_NULLABLE\": '23:59:59.999999'\n" + + "}]"; + private static final String ENDPOINT_ID = System.getenv("PREDICT_TABLES_REGRESSION_ENDPOINT_ID"); + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + + private static void requireEnvVar(String varName) { + String errorMessage = + String.format("Environment variable '%s' is required to perform these tests.", varName); + assertNotNull(errorMessage, System.getenv(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("UCAIP_PROJECT_ID"); + requireEnvVar("PREDICT_TABLES_REGRESSION_ENDPOINT_ID"); + } + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + } + + @After + public void tearDown() { + System.out.flush(); + System.setOut(originalPrintStream); + } + + @Test + public void testPredictTabularRegression() throws IOException { + // Act + PredictTabularRegressionSample.predictTabularRegression(INSTANCE, PROJECT, ENDPOINT_ID); + + // Assert + String got = bout.toString(); + assertThat(got).contains("Predict Tabular Regression Response"); + } +}