diff --git a/aiplatform/snippets/pom.xml b/aiplatform/snippets/pom.xml index 335f524539a..8ea1e440238 100644 --- a/aiplatform/snippets/pom.xml +++ b/aiplatform/snippets/pom.xml @@ -22,8 +22,6 @@ 1.8 UTF-8 - - @@ -31,6 +29,17 @@ google-cloud-aiplatform 0.0.1-SNAPSHOT + + + com.google.protobuf + protobuf-java-util + 4.0.0-rc-1 + + + com.google.cloud + google-cloud-storage + 1.111.0 + com.google.cloud google-cloud-storage diff --git a/aiplatform/snippets/src/main/java/aiplatform/CreateDatasetTextSample.java b/aiplatform/snippets/src/main/java/aiplatform/CreateDatasetTextSample.java new file mode 100644 index 00000000000..ff3c93ee4d8 --- /dev/null +++ b/aiplatform/snippets/src/main/java/aiplatform/CreateDatasetTextSample.java @@ -0,0 +1,84 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_create_dataset_text_sample] + +import com.google.api.gax.longrunning.OperationFuture; +import com.google.cloud.aiplatform.v1beta1.CreateDatasetOperationMetadata; +import com.google.cloud.aiplatform.v1beta1.Dataset; +import com.google.cloud.aiplatform.v1beta1.DatasetServiceClient; +import com.google.cloud.aiplatform.v1beta1.DatasetServiceSettings; +import com.google.cloud.aiplatform.v1beta1.LocationName; +import java.io.IOException; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +public class CreateDatasetTextSample { + + public static void main(String[] args) + throws IOException, InterruptedException, ExecutionException, TimeoutException { + // TODO(developer): Replace these variables before running the sample. + String project = "YOUR_PROJECT_ID"; + String datasetDisplayName = "YOUR_DATASET_DISPLAY_NAME"; + + createDatasetTextSample(project, datasetDisplayName); + } + + static void createDatasetTextSample(String project, String datasetDisplayName) + throws IOException, InterruptedException, ExecutionException, TimeoutException { + DatasetServiceSettings datasetServiceSettings = + DatasetServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (DatasetServiceClient datasetServiceClient = + DatasetServiceClient.create(datasetServiceSettings)) { + String location = "us-central1"; + String metadataSchemaUri = + "gs://google-cloud-aiplatform/schema/dataset/metadata/text_1.0.0.yaml"; + + LocationName locationName = LocationName.of(project, location); + Dataset dataset = + Dataset.newBuilder() + .setDisplayName(datasetDisplayName) + .setMetadataSchemaUri(metadataSchemaUri) + .build(); + + OperationFuture datasetFuture = + datasetServiceClient.createDatasetAsync(locationName, dataset); + System.out.format("Operation name: %s\n", datasetFuture.getInitialFuture().get().getName()); + + System.out.println("Waiting for operation to finish..."); + Dataset datasetResponse = datasetFuture.get(120, TimeUnit.SECONDS); + + System.out.println("Create Text Dataset Response"); + System.out.format("\tName: %s\n", datasetResponse.getName()); + System.out.format("\tDisplay Name: %s\n", datasetResponse.getDisplayName()); + System.out.format("\tMetadata Schema Uri: %s\n", datasetResponse.getMetadataSchemaUri()); + System.out.format("\tMetadata: %s\n", datasetResponse.getMetadata()); + System.out.format("\tCreate Time: %s\n", datasetResponse.getCreateTime()); + System.out.format("\tUpdate Time: %s\n", datasetResponse.getUpdateTime()); + System.out.format("\tLabels: %s\n", datasetResponse.getLabelsMap()); + } + } +} +// [END aiplatform_create_dataset_text_sample] diff --git a/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineTextClassificationSample.java b/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineTextClassificationSample.java new file mode 100644 index 00000000000..c67e49a058e --- /dev/null +++ b/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineTextClassificationSample.java @@ -0,0 +1,235 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_create_training_pipeline_text_classification_sample] + +import com.google.cloud.aiplatform.v1beta1.DeployedModelRef; +import com.google.cloud.aiplatform.v1beta1.EnvVar; +import com.google.cloud.aiplatform.v1beta1.ExplanationMetadata; +import com.google.cloud.aiplatform.v1beta1.ExplanationParameters; +import com.google.cloud.aiplatform.v1beta1.ExplanationSpec; +import com.google.cloud.aiplatform.v1beta1.FilterSplit; +import com.google.cloud.aiplatform.v1beta1.FractionSplit; +import com.google.cloud.aiplatform.v1beta1.InputDataConfig; +import com.google.cloud.aiplatform.v1beta1.LocationName; +import com.google.cloud.aiplatform.v1beta1.Model; +import com.google.cloud.aiplatform.v1beta1.Model.ExportFormat; +import com.google.cloud.aiplatform.v1beta1.ModelContainerSpec; +import com.google.cloud.aiplatform.v1beta1.PipelineServiceClient; +import com.google.cloud.aiplatform.v1beta1.PipelineServiceSettings; +import com.google.cloud.aiplatform.v1beta1.Port; +import com.google.cloud.aiplatform.v1beta1.PredefinedSplit; +import com.google.cloud.aiplatform.v1beta1.PredictSchemata; +import com.google.cloud.aiplatform.v1beta1.SampledShapleyAttribution; +import com.google.cloud.aiplatform.v1beta1.TimestampSplit; +import com.google.cloud.aiplatform.v1beta1.TrainingPipeline; +import com.google.protobuf.Any; +import com.google.protobuf.Value; +import com.google.protobuf.util.JsonFormat; +import com.google.rpc.Status; +import java.io.IOException; +import java.util.List; + +public class CreateTrainingPipelineTextClassificationSample { + + public static void main(String[] args) throws IOException { + // TODO(developer): Replace these variables before running the sample. + String trainingPipelineDisplayName = "YOUR_TRAINING_PIPELINE_DISPLAY_NAME"; + String project = "YOUR_PROJECT_ID"; + String datasetId = "YOUR_DATASET_ID"; + String modelDisplayName = "YOUR_MODEL_DISPLAY_NAME"; + + createTrainingPipelineTextClassificationSample( + project, trainingPipelineDisplayName, datasetId, modelDisplayName); + } + + static void createTrainingPipelineTextClassificationSample( + String project, String trainingPipelineDisplayName, String datasetId, String modelDisplayName) + throws IOException { + PipelineServiceSettings pipelineServiceSettings = + PipelineServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (PipelineServiceClient pipelineServiceClient = + PipelineServiceClient.create(pipelineServiceSettings)) { + String location = "us-central1"; + String trainingTaskDefinition = + "gs://google-cloud-aiplatform/schema/trainingjob/definition/" + + "automl_text_classification_1.0.0.yaml"; + String jsonString = "{\"multiLabel\": false}"; + + LocationName locationName = LocationName.of(project, location); + + Value.Builder trainingTaskInputs = Value.newBuilder(); + JsonFormat.parser().merge(jsonString, trainingTaskInputs); + + InputDataConfig trainingInputDataConfig = + InputDataConfig.newBuilder().setDatasetId(datasetId).build(); + Model model = Model.newBuilder().setDisplayName(modelDisplayName).build(); + TrainingPipeline trainingPipeline = + TrainingPipeline.newBuilder() + .setDisplayName(trainingPipelineDisplayName) + .setTrainingTaskDefinition(trainingTaskDefinition) + .setTrainingTaskInputs(trainingTaskInputs) + .setInputDataConfig(trainingInputDataConfig) + .setModelToUpload(model) + .build(); + + TrainingPipeline trainingPipelineResponse = + pipelineServiceClient.createTrainingPipeline(locationName, trainingPipeline); + + System.out.println("Create Training Pipeline Text Classification Response"); + System.out.format("\tName: %s\n", trainingPipelineResponse.getName()); + System.out.format("\tDisplay Name: %s\n", trainingPipelineResponse.getDisplayName()); + + System.out.format( + "\tTraining Task Definition %s\n", trainingPipelineResponse.getTrainingTaskDefinition()); + System.out.format( + "\tTraining Task Inputs: %s\n", trainingPipelineResponse.getTrainingTaskInputs()); + System.out.format( + "\tTraining Task Metadata: %s\n", trainingPipelineResponse.getTrainingTaskMetadata()); + System.out.format("State: %s\n", trainingPipelineResponse.getState()); + + System.out.format("\tCreate Time: %s\n", trainingPipelineResponse.getCreateTime()); + System.out.format("\tStartTime %s\n", trainingPipelineResponse.getStartTime()); + System.out.format("\tEnd Time: %s\n", trainingPipelineResponse.getEndTime()); + System.out.format("\tUpdate Time: %s\n", trainingPipelineResponse.getUpdateTime()); + System.out.format("\tLabels: %s\n", trainingPipelineResponse.getLabelsMap()); + + InputDataConfig inputDataConfig = trainingPipelineResponse.getInputDataConfig(); + System.out.println("\tInput Data Config"); + System.out.format("\t\tDataset Id: %s", inputDataConfig.getDatasetId()); + System.out.format("\t\tAnnotations Filter: %s\n", inputDataConfig.getAnnotationsFilter()); + + FractionSplit fractionSplit = inputDataConfig.getFractionSplit(); + System.out.println("\t\tFraction Split"); + System.out.format("\t\t\tTraining Fraction: %s\n", fractionSplit.getTrainingFraction()); + System.out.format("\t\t\tValidation Fraction: %s\n", fractionSplit.getValidationFraction()); + System.out.format("\t\t\tTest Fraction: %s\n", fractionSplit.getTestFraction()); + + FilterSplit filterSplit = inputDataConfig.getFilterSplit(); + System.out.println("\t\tFilter Split"); + System.out.format("\t\t\tTraining Filter: %s\n", filterSplit.getTrainingFilter()); + System.out.format("\t\t\tValidation Filter: %s\n", filterSplit.getValidationFilter()); + System.out.format("\t\t\tTest Filter: %s\n", filterSplit.getTestFilter()); + + PredefinedSplit predefinedSplit = inputDataConfig.getPredefinedSplit(); + System.out.println("\t\tPredefined Split"); + System.out.format("\t\t\tKey: %s\n", predefinedSplit.getKey()); + + TimestampSplit timestampSplit = inputDataConfig.getTimestampSplit(); + System.out.println("\t\tTimestamp Split"); + System.out.format("\t\t\tTraining Fraction: %s\n", timestampSplit.getTrainingFraction()); + System.out.format("\t\t\tValidation Fraction: %s\n", timestampSplit.getValidationFraction()); + System.out.format("\t\t\tTest Fraction: %s\n", timestampSplit.getTestFraction()); + System.out.format("\t\t\tKey: %s\n", timestampSplit.getKey()); + + Model modelResponse = trainingPipelineResponse.getModelToUpload(); + System.out.println("\tModel To Upload"); + System.out.format("\t\tName: %s\n", modelResponse.getName()); + System.out.format("\t\tDisplay Name: %s\n", modelResponse.getDisplayName()); + System.out.format("\t\tDescription: %s\n", modelResponse.getDescription()); + + System.out.format("\t\tMetadata Schema Uri: %s\n", modelResponse.getMetadataSchemaUri()); + System.out.format("\t\tMetadata: %s\n", modelResponse.getMetadata()); + System.out.format("\t\tTraining Pipeline: %s\n", modelResponse.getTrainingPipeline()); + System.out.format("\t\tArtifact Uri: %s\n", modelResponse.getArtifactUri()); + + System.out.format( + "\t\tSupported Deployment Resources Types: %s\n", + modelResponse.getSupportedDeploymentResourcesTypesList()); + System.out.format( + "\t\tSupported Input Storage Formats: %s\n", + modelResponse.getSupportedInputStorageFormatsList()); + System.out.format( + "\t\tSupported Output Storage Formats: %s\n", + modelResponse.getSupportedOutputStorageFormatsList()); + + System.out.format("\t\tCreate Time: %s\n", modelResponse.getCreateTime()); + System.out.format("\t\tUpdate Time: %s\n", modelResponse.getUpdateTime()); + System.out.format("\t\tLabels: %sn\n", modelResponse.getLabelsMap()); + + PredictSchemata predictSchemata = modelResponse.getPredictSchemata(); + System.out.println("\t\tPredict Schemata"); + System.out.format("\t\t\tInstance Schema Uri: %s\n", predictSchemata.getInstanceSchemaUri()); + System.out.format( + "\t\t\tParameters Schema Uri: %s\n", predictSchemata.getParametersSchemaUri()); + System.out.format( + "\t\t\tPrediction Schema Uri: %s\n", predictSchemata.getPredictionSchemaUri()); + + for (ExportFormat exportFormat : modelResponse.getSupportedExportFormatsList()) { + System.out.println("\t\tSupported Export Format"); + System.out.format("\t\t\tId: %s\n", exportFormat.getId()); + } + + ModelContainerSpec modelContainerSpec = modelResponse.getContainerSpec(); + System.out.println("\t\tContainer Spec"); + System.out.format("\t\t\tImage Uri: %s\n", modelContainerSpec.getImageUri()); + System.out.format("\t\t\tCommand: %s\n", modelContainerSpec.getCommandList()); + System.out.format("\t\t\tArgs: %s\n", modelContainerSpec.getArgsList()); + System.out.format("\t\t\tPredict Route: %s\n", modelContainerSpec.getPredictRoute()); + System.out.format("\t\t\tHealth Route: %s\n", modelContainerSpec.getHealthRoute()); + + for (EnvVar envVar : modelContainerSpec.getEnvList()) { + System.out.println("\t\t\tEnv"); + System.out.format("\t\t\t\tName: %s\n", envVar.getName()); + System.out.format("\t\t\t\tValue: %s\n", envVar.getValue()); + } + + for (Port port : modelContainerSpec.getPortsList()) { + System.out.println("\t\t\tPort"); + System.out.format("\t\t\t\tContainer Port: %s\n", port.getContainerPort()); + } + + for (DeployedModelRef deployedModelRef : modelResponse.getDeployedModelsList()) { + System.out.println("\t\tDeployed Model"); + System.out.format("\t\t\tEndpoint: %s\n", deployedModelRef.getEndpoint()); + System.out.format("\t\t\tDeployed Model Id: %s\n", deployedModelRef.getDeployedModelId()); + } + + ExplanationSpec explanationSpec = modelResponse.getExplanationSpec(); + System.out.println("\t\tExplanation Spec"); + + ExplanationParameters explanationParameters = explanationSpec.getParameters(); + System.out.println("\t\t\tParameters"); + + SampledShapleyAttribution sampledShapleyAttribution = + explanationParameters.getSampledShapleyAttribution(); + System.out.println("\t\t\t\tSampled Shapley Attribution"); + System.out.format("\t\t\t\t\tPath Count: %s\n", sampledShapleyAttribution.getPathCount()); + + ExplanationMetadata explanationMetadata = explanationSpec.getMetadata(); + System.out.println("\t\t\tMetadata"); + System.out.format("\t\t\t\tInputs: %s\n", explanationMetadata.getInputsMap()); + System.out.format("\t\t\t\tOutputs: %s\n", explanationMetadata.getOutputsMap()); + System.out.format( + "\t\t\t\tFeature Attributions Schema_uri: %s\n", + explanationMetadata.getFeatureAttributionsSchemaUri()); + + Status status = trainingPipelineResponse.getError(); + System.out.println("\tError"); + System.out.format("\t\tCode: %s\n", status.getCode()); + System.out.format("\t\tMessage: %s\n", status.getMessage()); + } + } +} +// [END aiplatform_create_training_pipeline_text_classification_sample] diff --git a/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineTextEntityExtractionSample.java b/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineTextEntityExtractionSample.java new file mode 100644 index 00000000000..3aef27086ba --- /dev/null +++ b/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineTextEntityExtractionSample.java @@ -0,0 +1,236 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_create_training_pipeline_text_entity_extraction_sample] + +import com.google.cloud.aiplatform.v1beta1.DeployedModelRef; +import com.google.cloud.aiplatform.v1beta1.EnvVar; +import com.google.cloud.aiplatform.v1beta1.ExplanationMetadata; +import com.google.cloud.aiplatform.v1beta1.ExplanationParameters; +import com.google.cloud.aiplatform.v1beta1.ExplanationSpec; +import com.google.cloud.aiplatform.v1beta1.FilterSplit; +import com.google.cloud.aiplatform.v1beta1.FractionSplit; +import com.google.cloud.aiplatform.v1beta1.InputDataConfig; +import com.google.cloud.aiplatform.v1beta1.LocationName; +import com.google.cloud.aiplatform.v1beta1.Model; +import com.google.cloud.aiplatform.v1beta1.Model.ExportFormat; +import com.google.cloud.aiplatform.v1beta1.ModelContainerSpec; +import com.google.cloud.aiplatform.v1beta1.PipelineServiceClient; +import com.google.cloud.aiplatform.v1beta1.PipelineServiceSettings; +import com.google.cloud.aiplatform.v1beta1.Port; +import com.google.cloud.aiplatform.v1beta1.PredefinedSplit; +import com.google.cloud.aiplatform.v1beta1.PredictSchemata; +import com.google.cloud.aiplatform.v1beta1.SampledShapleyAttribution; +import com.google.cloud.aiplatform.v1beta1.TimestampSplit; +import com.google.cloud.aiplatform.v1beta1.TrainingPipeline; +import com.google.protobuf.Any; +import com.google.protobuf.Value; +import com.google.protobuf.util.JsonFormat; +import com.google.rpc.Status; +import java.io.IOException; +import java.util.List; + +public class CreateTrainingPipelineTextEntityExtractionSample { + + public static void main(String[] args) throws IOException { + // TODO(developer): Replace these variables before running the sample. + String trainingPipelineDisplayName = "YOUR_TRAINING_PIPELINE_DISPLAY_NAME"; + String project = "YOUR_PROJECT_ID"; + String datasetId = "YOUR_DATASET_ID"; + String modelDisplayName = "YOUR_MODEL_DISPLAY_NAME"; + + createTrainingPipelineTextEntityExtractionSample( + project, trainingPipelineDisplayName, datasetId, modelDisplayName); + } + + static void createTrainingPipelineTextEntityExtractionSample( + String project, String trainingPipelineDisplayName, String datasetId, String modelDisplayName) + throws IOException { + PipelineServiceSettings pipelineServiceSettings = + PipelineServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (PipelineServiceClient pipelineServiceClient = + PipelineServiceClient.create(pipelineServiceSettings)) { + String location = "us-central1"; + String trainingTaskDefinition = + "gs://google-cloud-aiplatform/schema/trainingjob/definition/" + + "automl_text_extraction_1.0.0.yaml"; + String jsonString = "{}"; + + LocationName locationName = LocationName.of(project, location); + + // Training task inputs are empty for text entity extraction + Value.Builder trainingTaskInputs = Value.newBuilder(); + JsonFormat.parser().merge(jsonString, trainingTaskInputs); + + InputDataConfig trainingInputDataConfig = + InputDataConfig.newBuilder().setDatasetId(datasetId).build(); + Model model = Model.newBuilder().setDisplayName(modelDisplayName).build(); + TrainingPipeline trainingPipeline = + TrainingPipeline.newBuilder() + .setDisplayName(trainingPipelineDisplayName) + .setTrainingTaskDefinition(trainingTaskDefinition) + .setTrainingTaskInputs(trainingTaskInputs) + .setInputDataConfig(trainingInputDataConfig) + .setModelToUpload(model) + .build(); + + TrainingPipeline trainingPipelineResponse = + pipelineServiceClient.createTrainingPipeline(locationName, trainingPipeline); + + System.out.println("Create Training Pipeline Text Entity Extraction Response"); + System.out.format("\tName: %s\n", trainingPipelineResponse.getName()); + System.out.format("\tDisplay Name: %s\n", trainingPipelineResponse.getDisplayName()); + + System.out.format( + "\tTraining Task Definition %s\n", trainingPipelineResponse.getTrainingTaskDefinition()); + System.out.format( + "\tTraining Task Inputs: %s\n", trainingPipelineResponse.getTrainingTaskInputs()); + System.out.format( + "\tTraining Task Metadata: %s\n", trainingPipelineResponse.getTrainingTaskMetadata()); + System.out.format("State: %s\n", trainingPipelineResponse.getState()); + + System.out.format("\tCreate Time: %s\n", trainingPipelineResponse.getCreateTime()); + System.out.format("\tStartTime %s\n", trainingPipelineResponse.getStartTime()); + System.out.format("\tEnd Time: %s\n", trainingPipelineResponse.getEndTime()); + System.out.format("\tUpdate Time: %s\n", trainingPipelineResponse.getUpdateTime()); + System.out.format("\tLabels: %s\n", trainingPipelineResponse.getLabelsMap()); + + InputDataConfig inputDataConfig = trainingPipelineResponse.getInputDataConfig(); + System.out.println("\tInput Data Config"); + System.out.format("\t\tDataset Id: %s", inputDataConfig.getDatasetId()); + System.out.format("\t\tAnnotations Filter: %s\n", inputDataConfig.getAnnotationsFilter()); + + FractionSplit fractionSplit = inputDataConfig.getFractionSplit(); + System.out.println("\t\tFraction Split"); + System.out.format("\t\t\tTraining Fraction: %s\n", fractionSplit.getTrainingFraction()); + System.out.format("\t\t\tValidation Fraction: %s\n", fractionSplit.getValidationFraction()); + System.out.format("\t\t\tTest Fraction: %s\n", fractionSplit.getTestFraction()); + + FilterSplit filterSplit = inputDataConfig.getFilterSplit(); + System.out.println("\t\tFilter Split"); + System.out.format("\t\t\tTraining Filter: %s\n", filterSplit.getTrainingFilter()); + System.out.format("\t\t\tValidation Filter: %s\n", filterSplit.getValidationFilter()); + System.out.format("\t\t\tTest Filter: %s\n", filterSplit.getTestFilter()); + + PredefinedSplit predefinedSplit = inputDataConfig.getPredefinedSplit(); + System.out.println("\t\tPredefined Split"); + System.out.format("\t\t\tKey: %s\n", predefinedSplit.getKey()); + + TimestampSplit timestampSplit = inputDataConfig.getTimestampSplit(); + System.out.println("\t\tTimestamp Split"); + System.out.format("\t\t\tTraining Fraction: %s\n", timestampSplit.getTrainingFraction()); + System.out.format("\t\t\tValidation Fraction: %s\n", timestampSplit.getValidationFraction()); + System.out.format("\t\t\tTest Fraction: %s\n", timestampSplit.getTestFraction()); + System.out.format("\t\t\tKey: %s\n", timestampSplit.getKey()); + + Model modelResponse = trainingPipelineResponse.getModelToUpload(); + System.out.println("\tModel To Upload"); + System.out.format("\t\tName: %s\n", modelResponse.getName()); + System.out.format("\t\tDisplay Name: %s\n", modelResponse.getDisplayName()); + System.out.format("\t\tDescription: %s\n", modelResponse.getDescription()); + + System.out.format("\t\tMetadata Schema Uri: %s\n", modelResponse.getMetadataSchemaUri()); + System.out.format("\t\tMetadata: %s\n", modelResponse.getMetadata()); + System.out.format("\t\tTraining Pipeline: %s\n", modelResponse.getTrainingPipeline()); + System.out.format("\t\tArtifact Uri: %s\n", modelResponse.getArtifactUri()); + + System.out.format( + "\t\tSupported Deployment Resources Types: %s\n", + modelResponse.getSupportedDeploymentResourcesTypesList()); + System.out.format( + "\t\tSupported Input Storage Formats: %s\n", + modelResponse.getSupportedInputStorageFormatsList()); + System.out.format( + "\t\tSupported Output Storage Formats: %s\n", + modelResponse.getSupportedOutputStorageFormatsList()); + + System.out.format("\t\tCreate Time: %s\n", modelResponse.getCreateTime()); + System.out.format("\t\tUpdate Time: %s\n", modelResponse.getUpdateTime()); + System.out.format("\t\tLabels: %sn\n", modelResponse.getLabelsMap()); + + PredictSchemata predictSchemata = modelResponse.getPredictSchemata(); + System.out.println("\t\tPredict Schemata"); + System.out.format("\t\t\tInstance Schema Uri: %s\n", predictSchemata.getInstanceSchemaUri()); + System.out.format( + "\t\t\tParameters Schema Uri: %s\n", predictSchemata.getParametersSchemaUri()); + System.out.format( + "\t\t\tPrediction Schema Uri: %s\n", predictSchemata.getPredictionSchemaUri()); + + for (ExportFormat exportFormat : modelResponse.getSupportedExportFormatsList()) { + System.out.println("\t\tSupported Export Format"); + System.out.format("\t\t\tId: %s\n", exportFormat.getId()); + } + + ModelContainerSpec modelContainerSpec = modelResponse.getContainerSpec(); + System.out.println("\t\tContainer Spec"); + System.out.format("\t\t\tImage Uri: %s\n", modelContainerSpec.getImageUri()); + System.out.format("\t\t\tCommand: %s\n", modelContainerSpec.getCommandList()); + System.out.format("\t\t\tArgs: %s\n", modelContainerSpec.getArgsList()); + System.out.format("\t\t\tPredict Route: %s\n", modelContainerSpec.getPredictRoute()); + System.out.format("\t\t\tHealth Route: %s\n", modelContainerSpec.getHealthRoute()); + + for (EnvVar envVar : modelContainerSpec.getEnvList()) { + System.out.println("\t\t\tEnv"); + System.out.format("\t\t\t\tName: %s\n", envVar.getName()); + System.out.format("\t\t\t\tValue: %s\n", envVar.getValue()); + } + + for (Port port : modelContainerSpec.getPortsList()) { + System.out.println("\t\t\tPort"); + System.out.format("\t\t\t\tContainer Port: %s\n", port.getContainerPort()); + } + + for (DeployedModelRef deployedModelRef : modelResponse.getDeployedModelsList()) { + System.out.println("\t\tDeployed Model"); + System.out.format("\t\t\tEndpoint: %s\n", deployedModelRef.getEndpoint()); + System.out.format("\t\t\tDeployed Model Id: %s\n", deployedModelRef.getDeployedModelId()); + } + + ExplanationSpec explanationSpec = modelResponse.getExplanationSpec(); + System.out.println("\t\tExplanation Spec"); + + ExplanationParameters explanationParameters = explanationSpec.getParameters(); + System.out.println("\t\t\tParameters"); + + SampledShapleyAttribution sampledShapleyAttribution = + explanationParameters.getSampledShapleyAttribution(); + System.out.println("\t\t\t\tSampled Shapley Attribution"); + System.out.format("\t\t\t\t\tPath Count: %s\n", sampledShapleyAttribution.getPathCount()); + + ExplanationMetadata explanationMetadata = explanationSpec.getMetadata(); + System.out.println("\t\t\tMetadata"); + System.out.format("\t\t\t\tInputs: %s\n", explanationMetadata.getInputsMap()); + System.out.format("\t\t\t\tOutputs: %s\n", explanationMetadata.getOutputsMap()); + System.out.format( + "\t\t\t\tFeature Attributions Schema_uri: %s\n", + explanationMetadata.getFeatureAttributionsSchemaUri()); + + Status status = trainingPipelineResponse.getError(); + System.out.println("\tError"); + System.out.format("\t\tCode: %s\n", status.getCode()); + System.out.format("\t\tMessage: %s\n", status.getMessage()); + } + } +} +// [END aiplatform_create_training_pipeline_text_entity_extraction_sample] diff --git a/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineTextSentimentAnalysisSample.java b/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineTextSentimentAnalysisSample.java new file mode 100644 index 00000000000..a6139405d31 --- /dev/null +++ b/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineTextSentimentAnalysisSample.java @@ -0,0 +1,238 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_create_training_pipeline_text_sentiment_analysis_sample] + +import com.google.cloud.aiplatform.v1beta1.DeployedModelRef; +import com.google.cloud.aiplatform.v1beta1.EnvVar; +import com.google.cloud.aiplatform.v1beta1.ExplanationMetadata; +import com.google.cloud.aiplatform.v1beta1.ExplanationParameters; +import com.google.cloud.aiplatform.v1beta1.ExplanationSpec; +import com.google.cloud.aiplatform.v1beta1.FilterSplit; +import com.google.cloud.aiplatform.v1beta1.FractionSplit; +import com.google.cloud.aiplatform.v1beta1.InputDataConfig; +import com.google.cloud.aiplatform.v1beta1.LocationName; +import com.google.cloud.aiplatform.v1beta1.Model; +import com.google.cloud.aiplatform.v1beta1.Model.ExportFormat; +import com.google.cloud.aiplatform.v1beta1.ModelContainerSpec; +import com.google.cloud.aiplatform.v1beta1.PipelineServiceClient; +import com.google.cloud.aiplatform.v1beta1.PipelineServiceSettings; +import com.google.cloud.aiplatform.v1beta1.Port; +import com.google.cloud.aiplatform.v1beta1.PredefinedSplit; +import com.google.cloud.aiplatform.v1beta1.PredictSchemata; +import com.google.cloud.aiplatform.v1beta1.SampledShapleyAttribution; +import com.google.cloud.aiplatform.v1beta1.TimestampSplit; +import com.google.cloud.aiplatform.v1beta1.TrainingPipeline; +import com.google.protobuf.Any; +import com.google.protobuf.Value; +import com.google.protobuf.util.JsonFormat; +import com.google.rpc.Status; +import java.io.IOException; +import java.util.List; + +public class CreateTrainingPipelineTextSentimentAnalysisSample { + + public static void main(String[] args) throws IOException { + // TODO(developer): Replace these variables before running the sample. + String trainingPipelineDisplayName = "YOUR_TRAINING_PIPELINE_DISPLAY_NAME"; + String project = "YOUR_PROJECT_ID"; + String datasetId = "YOUR_DATASET_ID"; + String modelDisplayName = "YOUR_MODEL_DISPLAY_NAME"; + + createTrainingPipelineTextSentimentAnalysisSample( + project, trainingPipelineDisplayName, datasetId, modelDisplayName); + } + + static void createTrainingPipelineTextSentimentAnalysisSample( + String project, String trainingPipelineDisplayName, String datasetId, String modelDisplayName) + throws IOException { + PipelineServiceSettings pipelineServiceSettings = + PipelineServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (PipelineServiceClient pipelineServiceClient = + PipelineServiceClient.create(pipelineServiceSettings)) { + String location = "us-central1"; + String trainingTaskDefinition = + "gs://google-cloud-aiplatform/schema/trainingjob/definition/" + + "automl_text_sentiment_1.0.0.yaml"; + + // Sentiment max must be between 1 and 10 inclusive. + // Higher value means positive sentiment. + String jsonString = "{\"sentimentMax\": 4 }"; + + LocationName locationName = LocationName.of(project, location); + + Value.Builder trainingTaskInputs = Value.newBuilder(); + JsonFormat.parser().merge(jsonString, trainingTaskInputs); + + InputDataConfig trainingInputDataConfig = + InputDataConfig.newBuilder().setDatasetId(datasetId).build(); + Model model = Model.newBuilder().setDisplayName(modelDisplayName).build(); + TrainingPipeline trainingPipeline = + TrainingPipeline.newBuilder() + .setDisplayName(trainingPipelineDisplayName) + .setTrainingTaskDefinition(trainingTaskDefinition) + .setTrainingTaskInputs(trainingTaskInputs) + .setInputDataConfig(trainingInputDataConfig) + .setModelToUpload(model) + .build(); + + TrainingPipeline trainingPipelineResponse = + pipelineServiceClient.createTrainingPipeline(locationName, trainingPipeline); + + System.out.println("Create Training Pipeline Text Sentiment Analysis Response"); + System.out.format("\tName: %s\n", trainingPipelineResponse.getName()); + System.out.format("\tDisplay Name: %s\n", trainingPipelineResponse.getDisplayName()); + + System.out.format( + "\tTraining Task Definition %s\n", trainingPipelineResponse.getTrainingTaskDefinition()); + System.out.format( + "\tTraining Task Inputs: %s\n", trainingPipelineResponse.getTrainingTaskInputs()); + System.out.format( + "\tTraining Task Metadata: %s\n", trainingPipelineResponse.getTrainingTaskMetadata()); + System.out.format("State: %s\n", trainingPipelineResponse.getState()); + + System.out.format("\tCreate Time: %s\n", trainingPipelineResponse.getCreateTime()); + System.out.format("\tStartTime %s\n", trainingPipelineResponse.getStartTime()); + System.out.format("\tEnd Time: %s\n", trainingPipelineResponse.getEndTime()); + System.out.format("\tUpdate Time: %s\n", trainingPipelineResponse.getUpdateTime()); + System.out.format("\tLabels: %s\n", trainingPipelineResponse.getLabelsMap()); + + InputDataConfig inputDataConfig = trainingPipelineResponse.getInputDataConfig(); + System.out.println("\tInput Data Config"); + System.out.format("\t\tDataset Id: %s", inputDataConfig.getDatasetId()); + System.out.format("\t\tAnnotations Filter: %s\n", inputDataConfig.getAnnotationsFilter()); + + FractionSplit fractionSplit = inputDataConfig.getFractionSplit(); + System.out.println("\t\tFraction Split"); + System.out.format("\t\t\tTraining Fraction: %s\n", fractionSplit.getTrainingFraction()); + System.out.format("\t\t\tValidation Fraction: %s\n", fractionSplit.getValidationFraction()); + System.out.format("\t\t\tTest Fraction: %s\n", fractionSplit.getTestFraction()); + + FilterSplit filterSplit = inputDataConfig.getFilterSplit(); + System.out.println("\t\tFilter Split"); + System.out.format("\t\t\tTraining Filter: %s\n", filterSplit.getTrainingFilter()); + System.out.format("\t\t\tValidation Filter: %s\n", filterSplit.getValidationFilter()); + System.out.format("\t\t\tTest Filter: %s\n", filterSplit.getTestFilter()); + + PredefinedSplit predefinedSplit = inputDataConfig.getPredefinedSplit(); + System.out.println("\t\tPredefined Split"); + System.out.format("\t\t\tKey: %s\n", predefinedSplit.getKey()); + + TimestampSplit timestampSplit = inputDataConfig.getTimestampSplit(); + System.out.println("\t\tTimestamp Split"); + System.out.format("\t\t\tTraining Fraction: %s\n", timestampSplit.getTrainingFraction()); + System.out.format("\t\t\tValidation Fraction: %s\n", timestampSplit.getValidationFraction()); + System.out.format("\t\t\tTest Fraction: %s\n", timestampSplit.getTestFraction()); + System.out.format("\t\t\tKey: %s\n", timestampSplit.getKey()); + + Model modelResponse = trainingPipelineResponse.getModelToUpload(); + System.out.println("\tModel To Upload"); + System.out.format("\t\tName: %s\n", modelResponse.getName()); + System.out.format("\t\tDisplay Name: %s\n", modelResponse.getDisplayName()); + System.out.format("\t\tDescription: %s\n", modelResponse.getDescription()); + + System.out.format("\t\tMetadata Schema Uri: %s\n", modelResponse.getMetadataSchemaUri()); + System.out.format("\t\tMetadata: %s\n", modelResponse.getMetadata()); + System.out.format("\t\tTraining Pipeline: %s\n", modelResponse.getTrainingPipeline()); + System.out.format("\t\tArtifact Uri: %s\n", modelResponse.getArtifactUri()); + + System.out.format( + "\t\tSupported Deployment Resources Types: %s\n", + modelResponse.getSupportedDeploymentResourcesTypesList()); + System.out.format( + "\t\tSupported Input Storage Formats: %s\n", + modelResponse.getSupportedInputStorageFormatsList()); + System.out.format( + "\t\tSupported Output Storage Formats: %s\n", + modelResponse.getSupportedOutputStorageFormatsList()); + + System.out.format("\t\tCreate Time: %s\n", modelResponse.getCreateTime()); + System.out.format("\t\tUpdate Time: %s\n", modelResponse.getUpdateTime()); + System.out.format("\t\tLabels: %sn\n", modelResponse.getLabelsMap()); + + PredictSchemata predictSchemata = modelResponse.getPredictSchemata(); + System.out.println("\t\tPredict Schemata"); + System.out.format("\t\t\tInstance Schema Uri: %s\n", predictSchemata.getInstanceSchemaUri()); + System.out.format( + "\t\t\tParameters Schema Uri: %s\n", predictSchemata.getParametersSchemaUri()); + System.out.format( + "\t\t\tPrediction Schema Uri: %s\n", predictSchemata.getPredictionSchemaUri()); + + for (ExportFormat exportFormat : modelResponse.getSupportedExportFormatsList()) { + System.out.println("\t\tSupported Export Format"); + System.out.format("\t\t\tId: %s\n", exportFormat.getId()); + } + + ModelContainerSpec modelContainerSpec = modelResponse.getContainerSpec(); + System.out.println("\t\tContainer Spec"); + System.out.format("\t\t\tImage Uri: %s\n", modelContainerSpec.getImageUri()); + System.out.format("\t\t\tCommand: %s\n", modelContainerSpec.getCommandList()); + System.out.format("\t\t\tArgs: %s\n", modelContainerSpec.getArgsList()); + System.out.format("\t\t\tPredict Route: %s\n", modelContainerSpec.getPredictRoute()); + System.out.format("\t\t\tHealth Route: %s\n", modelContainerSpec.getHealthRoute()); + + for (EnvVar envVar : modelContainerSpec.getEnvList()) { + System.out.println("\t\t\tEnv"); + System.out.format("\t\t\t\tName: %s\n", envVar.getName()); + System.out.format("\t\t\t\tValue: %s\n", envVar.getValue()); + } + + for (Port port : modelContainerSpec.getPortsList()) { + System.out.println("\t\t\tPort"); + System.out.format("\t\t\t\tContainer Port: %s\n", port.getContainerPort()); + } + + for (DeployedModelRef deployedModelRef : modelResponse.getDeployedModelsList()) { + System.out.println("\t\tDeployed Model"); + System.out.format("\t\t\tEndpoint: %s\n", deployedModelRef.getEndpoint()); + System.out.format("\t\t\tDeployed Model Id: %s\n", deployedModelRef.getDeployedModelId()); + } + + ExplanationSpec explanationSpec = modelResponse.getExplanationSpec(); + System.out.println("\t\tExplanation Spec"); + + ExplanationParameters explanationParameters = explanationSpec.getParameters(); + System.out.println("\t\t\tParameters"); + + SampledShapleyAttribution sampledShapleyAttribution = + explanationParameters.getSampledShapleyAttribution(); + System.out.println("\t\t\t\tSampled Shapley Attribution"); + System.out.format("\t\t\t\t\tPath Count: %s\n", sampledShapleyAttribution.getPathCount()); + + ExplanationMetadata explanationMetadata = explanationSpec.getMetadata(); + System.out.println("\t\t\tMetadata"); + System.out.format("\t\t\t\tInputs: %s\n", explanationMetadata.getInputsMap()); + System.out.format("\t\t\t\tOutputs: %s\n", explanationMetadata.getOutputsMap()); + System.out.format( + "\t\t\t\tFeature Attributions Schema_uri: %s\n", + explanationMetadata.getFeatureAttributionsSchemaUri()); + + Status status = trainingPipelineResponse.getError(); + System.out.println("\tError"); + System.out.format("\t\tCode: %s\n", status.getCode()); + System.out.format("\t\tMessage: %s\n", status.getMessage()); + } + } +} +// [END aiplatform_create_training_pipeline_text_sentiment_analysis_sample] diff --git a/aiplatform/snippets/src/main/java/aiplatform/GetModelEvaluationTextClassificationSample.java b/aiplatform/snippets/src/main/java/aiplatform/GetModelEvaluationTextClassificationSample.java new file mode 100644 index 00000000000..68820041ebe --- /dev/null +++ b/aiplatform/snippets/src/main/java/aiplatform/GetModelEvaluationTextClassificationSample.java @@ -0,0 +1,81 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_get_model_evaluation_text_classification_sample] + +import com.google.cloud.aiplatform.v1beta1.Attribution; +import com.google.cloud.aiplatform.v1beta1.ModelEvaluation; +import com.google.cloud.aiplatform.v1beta1.ModelEvaluationName; +import com.google.cloud.aiplatform.v1beta1.ModelExplanation; +import com.google.cloud.aiplatform.v1beta1.ModelServiceClient; +import com.google.cloud.aiplatform.v1beta1.ModelServiceSettings; +import java.io.IOException; + +public class GetModelEvaluationTextClassificationSample { + + public static void main(String[] args) throws IOException { + // TODO(developer): Replace these variables before running the sample. + String project = "YOUR_PROJECT_ID"; + String modelId = "YOUR_MODEL_ID"; + String evaluationId = "YOUR_EVALUATION_ID"; + + getModelEvaluationTextClassificationSample(project, modelId, evaluationId); + } + + static void getModelEvaluationTextClassificationSample( + String project, String modelId, String evaluationId) throws IOException { + ModelServiceSettings modelServiceSettings = + ModelServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (ModelServiceClient modelServiceClient = ModelServiceClient.create(modelServiceSettings)) { + String location = "us-central1"; + + ModelEvaluationName modelEvaluationName = + ModelEvaluationName.of(project, location, modelId, evaluationId); + ModelEvaluation modelEvaluation = modelServiceClient.getModelEvaluation(modelEvaluationName); + + System.out.println("Get Model Evaluation Text Classification Response"); + System.out.format("\tModel Name: %s\n", modelEvaluation.getName()); + System.out.format("\tMetrics Schema Uri: %s\n", modelEvaluation.getMetricsSchemaUri()); + System.out.format("\tMetrics: %s\n", modelEvaluation.getMetrics()); + System.out.format("\tCreate Time: %s\n", modelEvaluation.getCreateTime()); + System.out.format("\tSlice Dimensions: %s\n", modelEvaluation.getSliceDimensionsList()); + + System.out.println("\tModel Explanation"); + ModelExplanation modelExplanation = modelEvaluation.getModelExplanation(); + for (Attribution attribution : modelExplanation.getMeanAttributionsList()) { + + System.out.println("\t\tMean Attribution"); + System.out.format( + "\t\t\tBaseline Output Value: %s\n", attribution.getBaselineOutputValue()); + System.out.format( + "\t\t\tInstance Output Value: %s\n", attribution.getInstanceOutputValue()); + System.out.format("\t\t\tFeature Attributions: %s\n", attribution.getFeatureAttributions()); + System.out.format("\t\t\tOutput Index: %s\n", attribution.getOutputIndexList()); + System.out.format("\t\t\tOutput Display Name: %s\n", attribution.getOutputDisplayName()); + System.out.format("\t\t\tApproximation Error: %s\n", attribution.getApproximationError()); + } + } + } +} +// [END aiplatform_get_model_evaluation_text_classification_sample] diff --git a/aiplatform/snippets/src/main/java/aiplatform/GetModelEvaluationTextEntityExtractionSample.java b/aiplatform/snippets/src/main/java/aiplatform/GetModelEvaluationTextEntityExtractionSample.java new file mode 100644 index 00000000000..d85f9d3f3af --- /dev/null +++ b/aiplatform/snippets/src/main/java/aiplatform/GetModelEvaluationTextEntityExtractionSample.java @@ -0,0 +1,81 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_get_model_evaluation_text_entity_extraction_sample] + +import com.google.cloud.aiplatform.v1beta1.Attribution; +import com.google.cloud.aiplatform.v1beta1.ModelEvaluation; +import com.google.cloud.aiplatform.v1beta1.ModelEvaluationName; +import com.google.cloud.aiplatform.v1beta1.ModelExplanation; +import com.google.cloud.aiplatform.v1beta1.ModelServiceClient; +import com.google.cloud.aiplatform.v1beta1.ModelServiceSettings; +import java.io.IOException; + +public class GetModelEvaluationTextEntityExtractionSample { + + public static void main(String[] args) throws IOException { + // TODO(developer): Replace these variables before running the sample. + String project = "YOUR_PROJECT_ID"; + String modelId = "YOUR_MODEL_ID"; + String evaluationId = "YOUR_EVALUATION_ID"; + + getModelEvaluationTextEntityExtractionSample(project, modelId, evaluationId); + } + + static void getModelEvaluationTextEntityExtractionSample( + String project, String modelId, String evaluationId) throws IOException { + ModelServiceSettings modelServiceSettings = + ModelServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (ModelServiceClient modelServiceClient = ModelServiceClient.create(modelServiceSettings)) { + String location = "us-central1"; + + ModelEvaluationName modelEvaluationName = + ModelEvaluationName.of(project, location, modelId, evaluationId); + ModelEvaluation modelEvaluation = modelServiceClient.getModelEvaluation(modelEvaluationName); + + System.out.println("Get Model Evaluation Text Entity Extraction Response"); + System.out.format("\tModel Name: %s\n", modelEvaluation.getName()); + System.out.format("\tMetrics Schema Uri: %s\n", modelEvaluation.getMetricsSchemaUri()); + System.out.format("\tMetrics: %s\n", modelEvaluation.getMetrics()); + System.out.format("\tCreate Time: %s\n", modelEvaluation.getCreateTime()); + System.out.format("\tSlice Dimensions: %s\n", modelEvaluation.getSliceDimensionsList()); + + System.out.println("\tModel Explanation"); + ModelExplanation modelExplanation = modelEvaluation.getModelExplanation(); + for (Attribution attribution : modelExplanation.getMeanAttributionsList()) { + + System.out.println("\t\tMean Attribution"); + System.out.format( + "\t\t\tBaseline Output Value: %s\n", attribution.getBaselineOutputValue()); + System.out.format( + "\t\t\tInstance Output Value: %s\n", attribution.getInstanceOutputValue()); + System.out.format("\t\t\tFeature Attributions: %s\n", attribution.getFeatureAttributions()); + System.out.format("\t\t\tOutput Index: %s\n", attribution.getOutputIndexList()); + System.out.format("\t\t\tOutput Display Name: %s\n", attribution.getOutputDisplayName()); + System.out.format("\t\t\tApproximation Error: %s\n", attribution.getApproximationError()); + } + } + } +} +// [END aiplatform_get_model_evaluation_text_entity_extraction_sample] diff --git a/aiplatform/snippets/src/main/java/aiplatform/GetModelEvaluationTextSentimentAnalysisSample.java b/aiplatform/snippets/src/main/java/aiplatform/GetModelEvaluationTextSentimentAnalysisSample.java new file mode 100644 index 00000000000..0ccd5286898 --- /dev/null +++ b/aiplatform/snippets/src/main/java/aiplatform/GetModelEvaluationTextSentimentAnalysisSample.java @@ -0,0 +1,81 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_get_model_evaluation_text_sentiment_analysis_sample] + +import com.google.cloud.aiplatform.v1beta1.Attribution; +import com.google.cloud.aiplatform.v1beta1.ModelEvaluation; +import com.google.cloud.aiplatform.v1beta1.ModelEvaluationName; +import com.google.cloud.aiplatform.v1beta1.ModelExplanation; +import com.google.cloud.aiplatform.v1beta1.ModelServiceClient; +import com.google.cloud.aiplatform.v1beta1.ModelServiceSettings; +import java.io.IOException; + +public class GetModelEvaluationTextSentimentAnalysisSample { + + public static void main(String[] args) throws IOException { + // TODO(developer): Replace these variables before running the sample. + String project = "YOUR_PROJECT_ID"; + String modelId = "YOUR_MODEL_ID"; + String evaluationId = "YOUR_EVALUATION_ID"; + + getModelEvaluationTextSentimentAnalysisSample(project, modelId, evaluationId); + } + + static void getModelEvaluationTextSentimentAnalysisSample( + String project, String modelId, String evaluationId) throws IOException { + ModelServiceSettings modelServiceSettings = + ModelServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (ModelServiceClient modelServiceClient = ModelServiceClient.create(modelServiceSettings)) { + String location = "us-central1"; + + ModelEvaluationName modelEvaluationName = + ModelEvaluationName.of(project, location, modelId, evaluationId); + ModelEvaluation modelEvaluation = modelServiceClient.getModelEvaluation(modelEvaluationName); + + System.out.println("Get Model Evaluation Text Sentiment Analysis Response"); + System.out.format("\tModel Name: %s\n", modelEvaluation.getName()); + System.out.format("\tMetrics Schema Uri: %s\n", modelEvaluation.getMetricsSchemaUri()); + System.out.format("\tMetrics: %s\n", modelEvaluation.getMetrics()); + System.out.format("\tCreate Time: %s\n", modelEvaluation.getCreateTime()); + System.out.format("\tSlice Dimensions: %s\n", modelEvaluation.getSliceDimensionsList()); + + System.out.println("\tModel Explanation"); + ModelExplanation modelExplanation = modelEvaluation.getModelExplanation(); + for (Attribution attribution : modelExplanation.getMeanAttributionsList()) { + + System.out.println("\t\tMean Attribution"); + System.out.format( + "\t\t\tBaseline Output Value: %s\n", attribution.getBaselineOutputValue()); + System.out.format( + "\t\t\tInstance Output Value: %s\n", attribution.getInstanceOutputValue()); + System.out.format("\t\t\tFeature Attributions: %s\n", attribution.getFeatureAttributions()); + System.out.format("\t\t\tOutput Index: %s\n", attribution.getOutputIndexList()); + System.out.format("\t\t\tOutput Display Name: %s\n", attribution.getOutputDisplayName()); + System.out.format("\t\t\tApproximation Error: %s\n", attribution.getApproximationError()); + } + } + } +} +// [END aiplatform_get_model_evaluation_text_sentiment_analysis_sample] diff --git a/aiplatform/snippets/src/main/java/aiplatform/ImportDataTextClassificationSingleLabelSample.java b/aiplatform/snippets/src/main/java/aiplatform/ImportDataTextClassificationSingleLabelSample.java new file mode 100644 index 00000000000..cf08bd38ae9 --- /dev/null +++ b/aiplatform/snippets/src/main/java/aiplatform/ImportDataTextClassificationSingleLabelSample.java @@ -0,0 +1,90 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_import_data_text_classification_single_label_sample] + +import com.google.api.gax.longrunning.OperationFuture; +import com.google.cloud.aiplatform.v1beta1.DatasetName; +import com.google.cloud.aiplatform.v1beta1.DatasetServiceClient; +import com.google.cloud.aiplatform.v1beta1.DatasetServiceSettings; +import com.google.cloud.aiplatform.v1beta1.GcsSource; +import com.google.cloud.aiplatform.v1beta1.ImportDataConfig; +import com.google.cloud.aiplatform.v1beta1.ImportDataOperationMetadata; +import com.google.cloud.aiplatform.v1beta1.ImportDataResponse; +import java.io.IOException; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +public class ImportDataTextClassificationSingleLabelSample { + + public static void main(String[] args) + throws IOException, InterruptedException, ExecutionException, TimeoutException { + // TODO(developer): Replace these variables before running the sample. + String project = "YOUR_PROJECT_ID"; + String datasetId = "YOUR_DATASET_ID"; + String gcsSourceUri = + "gs://YOUR_GCS_SOURCE_BUCKET/path_to_your_text_source/[file.csv/file.jsonl]"; + + importDataTextClassificationSingleLabelSample(project, datasetId, gcsSourceUri); + } + + static void importDataTextClassificationSingleLabelSample( + String project, String datasetId, String gcsSourceUri) + throws IOException, InterruptedException, ExecutionException, TimeoutException { + DatasetServiceSettings datasetServiceSettings = + DatasetServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (DatasetServiceClient datasetServiceClient = + DatasetServiceClient.create(datasetServiceSettings)) { + String location = "us-central1"; + String importSchemaUri = + "gs://google-cloud-aiplatform/schema/dataset/ioformat/" + + "text_classification_single_label_io_format_1.0.0.yaml"; + + GcsSource.Builder gcsSource = GcsSource.newBuilder(); + gcsSource.addUris(gcsSourceUri); + DatasetName datasetName = DatasetName.of(project, location, datasetId); + + List importDataConfigList = + Collections.singletonList( + ImportDataConfig.newBuilder() + .setGcsSource(gcsSource) + .setImportSchemaUri(importSchemaUri) + .build()); + + OperationFuture importDataResponseFuture = + datasetServiceClient.importDataAsync(datasetName, importDataConfigList); + System.out.format( + "Operation name: %s\n", importDataResponseFuture.getInitialFuture().get().getName()); + + System.out.println("Waiting for operation to finish..."); + ImportDataResponse importDataResponse = importDataResponseFuture.get(300, TimeUnit.SECONDS); + System.out.format("Import Data Text Classification Response: %s\n", + importDataResponse.toString()); + } + } +} +// [END aiplatform_import_data_text_classification_single_label_sample] diff --git a/aiplatform/snippets/src/main/java/aiplatform/ImportDataTextEntityExtractionSample.java b/aiplatform/snippets/src/main/java/aiplatform/ImportDataTextEntityExtractionSample.java new file mode 100644 index 00000000000..6bd5c4f0297 --- /dev/null +++ b/aiplatform/snippets/src/main/java/aiplatform/ImportDataTextEntityExtractionSample.java @@ -0,0 +1,89 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_import_data_text_entity_extraction_sample] + +import com.google.api.gax.longrunning.OperationFuture; +import com.google.cloud.aiplatform.v1beta1.DatasetName; +import com.google.cloud.aiplatform.v1beta1.DatasetServiceClient; +import com.google.cloud.aiplatform.v1beta1.DatasetServiceSettings; +import com.google.cloud.aiplatform.v1beta1.GcsSource; +import com.google.cloud.aiplatform.v1beta1.ImportDataConfig; +import com.google.cloud.aiplatform.v1beta1.ImportDataOperationMetadata; +import com.google.cloud.aiplatform.v1beta1.ImportDataResponse; +import java.io.IOException; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +public class ImportDataTextEntityExtractionSample { + + public static void main(String[] args) + throws IOException, InterruptedException, ExecutionException, TimeoutException { + // TODO(developer): Replace these variables before running the sample. + String project = "YOUR_PROJECT_ID"; + String datasetId = "YOUR_DATASET_ID"; + String gcsSourceUri = "gs://YOUR_GCS_SOURCE_BUCKET/path_to_your_text_source/[file.jsonl]"; + + importDataTextEntityExtractionSample(project, datasetId, gcsSourceUri); + } + + static void importDataTextEntityExtractionSample( + String project, String datasetId, String gcsSourceUri) + throws IOException, InterruptedException, ExecutionException, TimeoutException { + DatasetServiceSettings datasetServiceSettings = + DatasetServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (DatasetServiceClient datasetServiceClient = + DatasetServiceClient.create(datasetServiceSettings)) { + String location = "us-central1"; + String importSchemaUri = + "gs://google-cloud-aiplatform/schema/dataset/ioformat/" + + "text_extraction_io_format_1.0.0.yaml"; + + GcsSource.Builder gcsSource = GcsSource.newBuilder(); + gcsSource.addUris(gcsSourceUri); + DatasetName datasetName = DatasetName.of(project, location, datasetId); + + List importDataConfigList = + Collections.singletonList( + ImportDataConfig.newBuilder() + .setGcsSource(gcsSource) + .setImportSchemaUri(importSchemaUri) + .build()); + + OperationFuture importDataResponseFuture = + datasetServiceClient.importDataAsync(datasetName, importDataConfigList); + System.out.format( + "Operation name: %s\n", importDataResponseFuture.getInitialFuture().get().getName()); + + System.out.println("Waiting for operation to finish..."); + ImportDataResponse importDataResponse = importDataResponseFuture.get(300, TimeUnit.SECONDS); + System.out.format("Import Data Text Entity Extraction Response: %s\n", + importDataResponse.toString()); + } + } +} +// [END aiplatform_import_data_text_entity_extraction_sample] diff --git a/aiplatform/snippets/src/main/java/aiplatform/ImportDataTextSentimentAnalysisSample.java b/aiplatform/snippets/src/main/java/aiplatform/ImportDataTextSentimentAnalysisSample.java new file mode 100644 index 00000000000..e1750678392 --- /dev/null +++ b/aiplatform/snippets/src/main/java/aiplatform/ImportDataTextSentimentAnalysisSample.java @@ -0,0 +1,90 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_import_data_text_sentiment_analysis_sample] + +import com.google.api.gax.longrunning.OperationFuture; +import com.google.cloud.aiplatform.v1beta1.DatasetName; +import com.google.cloud.aiplatform.v1beta1.DatasetServiceClient; +import com.google.cloud.aiplatform.v1beta1.DatasetServiceSettings; +import com.google.cloud.aiplatform.v1beta1.GcsSource; +import com.google.cloud.aiplatform.v1beta1.ImportDataConfig; +import com.google.cloud.aiplatform.v1beta1.ImportDataOperationMetadata; +import com.google.cloud.aiplatform.v1beta1.ImportDataResponse; +import java.io.IOException; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +public class ImportDataTextSentimentAnalysisSample { + + public static void main(String[] args) + throws IOException, InterruptedException, ExecutionException, TimeoutException { + // TODO(developer): Replace these variables before running the sample. + String project = "YOUR_PROJECT_ID"; + String datasetId = "YOUR_DATASET_ID"; + String gcsSourceUri = + "gs://YOUR_GCS_SOURCE_BUCKET/path_to_your_text_source/[file.csv/file.jsonl]"; + + importDataTextSentimentAnalysisSample(project, datasetId, gcsSourceUri); + } + + static void importDataTextSentimentAnalysisSample( + String project, String datasetId, String gcsSourceUri) + throws IOException, InterruptedException, ExecutionException, TimeoutException { + DatasetServiceSettings datasetServiceSettings = + DatasetServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (DatasetServiceClient datasetServiceClient = + DatasetServiceClient.create(datasetServiceSettings)) { + String location = "us-central1"; + String importSchemaUri = + "gs://google-cloud-aiplatform/schema/dataset/ioformat/" + + "text_sentiment_io_format_1.0.0.yaml"; + + GcsSource.Builder gcsSource = GcsSource.newBuilder(); + gcsSource.addUris(gcsSourceUri); + DatasetName datasetName = DatasetName.of(project, location, datasetId); + + List importDataConfigList = + Collections.singletonList( + ImportDataConfig.newBuilder() + .setGcsSource(gcsSource) + .setImportSchemaUri(importSchemaUri) + .build()); + + OperationFuture importDataResponseFuture = + datasetServiceClient.importDataAsync(datasetName, importDataConfigList); + System.out.format( + "Operation name: %s\n", importDataResponseFuture.getInitialFuture().get().getName()); + + System.out.println("Waiting for operation to finish..."); + ImportDataResponse importDataResponse = importDataResponseFuture.get(300, TimeUnit.SECONDS); + System.out.format("Import Data Text Sentiment Analysis Response: %s\n", + importDataResponse.toString()); + } + } +} +// [END aiplatform_import_data_text_sentiment_analysis_sample] diff --git a/aiplatform/snippets/src/main/java/aiplatform/PredictTextClassificationSingleLabelSample.java b/aiplatform/snippets/src/main/java/aiplatform/PredictTextClassificationSingleLabelSample.java new file mode 100644 index 00000000000..84e323f59a5 --- /dev/null +++ b/aiplatform/snippets/src/main/java/aiplatform/PredictTextClassificationSingleLabelSample.java @@ -0,0 +1,78 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_predict_text_classification_sample] + +import com.google.cloud.aiplatform.v1beta1.EndpointName; +import com.google.cloud.aiplatform.v1beta1.PredictResponse; +import com.google.cloud.aiplatform.v1beta1.PredictionServiceClient; +import com.google.cloud.aiplatform.v1beta1.PredictionServiceSettings; +import com.google.protobuf.Value; +import com.google.protobuf.util.JsonFormat; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +public class PredictTextClassificationSingleLabelSample { + + public static void main(String[] args) throws IOException { + // TODO(developer): Replace these variables before running the sample. + String project = "YOUR_PROJECT_ID"; + String content = "YOUR_TEXT_CONTENT"; + String endpointId = "YOUR_ENDPOINT_ID"; + + predictTextClassificationSingleLabel(project, content, endpointId); + } + + static void predictTextClassificationSingleLabel( + String project, String content, String endpointId) throws IOException { + PredictionServiceSettings predictionServiceSettings = + PredictionServiceSettings.newBuilder() + .setEndpoint("us-central1-prediction-aiplatform.googleapis.com:443") + .build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (PredictionServiceClient predictionServiceClient = + PredictionServiceClient.create(predictionServiceSettings)) { + String location = "us-central1"; + String jsonString = "{\"content\": \"" + content + "\"}"; + + EndpointName endpointName = EndpointName.of(project, location, endpointId); + + Value parameter = Value.newBuilder().setNumberValue(0).setNumberValue(5).build(); + Value.Builder instance = Value.newBuilder(); + JsonFormat.parser().merge(jsonString, instance); + + List instances = new ArrayList<>(); + instances.add(instance.build()); + + PredictResponse predictResponse = + predictionServiceClient.predict(endpointName, instances, parameter); + System.out.println("Predict Text Classification Response"); + System.out.format("\tDeployed Model Id: %s\n", predictResponse.getDeployedModelId()); + + System.out.println("Predictions"); + for (Value prediction : predictResponse.getPredictionsList()) { + System.out.format("\tPrediction: %s\n", prediction); + } + } + } +} +// [END aiplatform_predict_text_classification_sample] diff --git a/aiplatform/snippets/src/main/java/aiplatform/PredictTextEntityExtractionSample.java b/aiplatform/snippets/src/main/java/aiplatform/PredictTextEntityExtractionSample.java new file mode 100644 index 00000000000..3fb1559ec90 --- /dev/null +++ b/aiplatform/snippets/src/main/java/aiplatform/PredictTextEntityExtractionSample.java @@ -0,0 +1,78 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_predict_text_entity_extraction_sample] + +import com.google.cloud.aiplatform.v1beta1.EndpointName; +import com.google.cloud.aiplatform.v1beta1.PredictResponse; +import com.google.cloud.aiplatform.v1beta1.PredictionServiceClient; +import com.google.cloud.aiplatform.v1beta1.PredictionServiceSettings; +import com.google.protobuf.Value; +import com.google.protobuf.util.JsonFormat; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +public class PredictTextEntityExtractionSample { + + public static void main(String[] args) throws IOException { + // TODO(developer): Replace these variables before running the sample. + String project = "YOUR_PROJECT_ID"; + String content = "YOUR_TEXT_CONTENT"; + String endpointId = "YOUR_ENDPOINT_ID"; + + predictTextEntityExtraction(project, content, endpointId); + } + + static void predictTextEntityExtraction(String project, String content, String endpointId) + throws IOException { + PredictionServiceSettings predictionServiceSettings = + PredictionServiceSettings.newBuilder() + .setEndpoint("us-central1-prediction-aiplatform.googleapis.com:443") + .build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (PredictionServiceClient predictionServiceClient = + PredictionServiceClient.create(predictionServiceSettings)) { + String location = "us-central1"; + String jsonString = "{\"content\": \"" + content + "\"}"; + + EndpointName endpointName = EndpointName.of(project, location, endpointId); + + Value parameter = Value.newBuilder().setNumberValue(0).setNumberValue(5).build(); + Value.Builder instance = Value.newBuilder(); + JsonFormat.parser().merge(jsonString, instance); + + List instances = new ArrayList<>(); + instances.add(instance.build()); + + PredictResponse predictResponse = + predictionServiceClient.predict(endpointName, instances, parameter); + System.out.println("Predict Text Entity Extraction Response"); + System.out.format("\tDeployed Model Id: %s\n", predictResponse.getDeployedModelId()); + + System.out.println("Predictions"); + for (Value prediction : predictResponse.getPredictionsList()) { + System.out.format("\tPrediction: %s\n", prediction); + } + } + } +} +// [END aiplatform_predict_text_entity_extraction_sample] diff --git a/aiplatform/snippets/src/main/java/aiplatform/PredictTextSentimentAnalysisSample.java b/aiplatform/snippets/src/main/java/aiplatform/PredictTextSentimentAnalysisSample.java new file mode 100644 index 00000000000..8ac212a26ef --- /dev/null +++ b/aiplatform/snippets/src/main/java/aiplatform/PredictTextSentimentAnalysisSample.java @@ -0,0 +1,78 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_predict_text_sentiment_analysis_sample] + +import com.google.cloud.aiplatform.v1beta1.EndpointName; +import com.google.cloud.aiplatform.v1beta1.PredictResponse; +import com.google.cloud.aiplatform.v1beta1.PredictionServiceClient; +import com.google.cloud.aiplatform.v1beta1.PredictionServiceSettings; +import com.google.protobuf.Value; +import com.google.protobuf.util.JsonFormat; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +public class PredictTextSentimentAnalysisSample { + + public static void main(String[] args) throws IOException { + // TODO(developer): Replace these variables before running the sample. + String project = "YOUR_PROJECT_ID"; + String content = "YOUR_TEXT_CONTENT"; + String endpointId = "YOUR_ENDPOINT_ID"; + + predictTextSentimentAnalysis(project, content, endpointId); + } + + static void predictTextSentimentAnalysis(String project, String content, String endpointId) + throws IOException { + PredictionServiceSettings predictionServiceSettings = + PredictionServiceSettings.newBuilder() + .setEndpoint("us-central1-prediction-aiplatform.googleapis.com:443") + .build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (PredictionServiceClient predictionServiceClient = + PredictionServiceClient.create(predictionServiceSettings)) { + String location = "us-central1"; + String jsonString = "{\"content\": \"" + content + "\"}"; + + EndpointName endpointName = EndpointName.of(project, location, endpointId); + + Value parameter = Value.newBuilder().setNumberValue(0).setNumberValue(5).build(); + Value.Builder instance = Value.newBuilder(); + JsonFormat.parser().merge(jsonString, instance); + + List instances = new ArrayList<>(); + instances.add(instance.build()); + + PredictResponse predictResponse = + predictionServiceClient.predict(endpointName, instances, parameter); + System.out.println("Predict Text Sentiment Analysis Response"); + System.out.format("\tDeployed Model Id: %s\n", predictResponse.getDeployedModelId()); + + System.out.println("Predictions"); + for (Value prediction : predictResponse.getPredictionsList()) { + System.out.format("\tPrediction: %s\n", prediction); + } + } + } +} +// [END aiplatform_predict_text_sentiment_analysis_sample] diff --git a/aiplatform/snippets/src/test/java/aiplatform/CreateDatasetTextSampleTest.java b/aiplatform/snippets/src/test/java/aiplatform/CreateDatasetTextSampleTest.java new file mode 100644 index 00000000000..a4048e5d96d --- /dev/null +++ b/aiplatform/snippets/src/test/java/aiplatform/CreateDatasetTextSampleTest.java @@ -0,0 +1,94 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.assertNotNull; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import java.util.UUID; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeoutException; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class CreateDatasetTextSampleTest { + + private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID"); + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + private String datasetId; + + private static void requireEnvVar(String varName) { + String errorMessage = + String.format("Environment variable '%s' is required to perform these tests.", varName); + assertNotNull(errorMessage, System.getenv(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("UCAIP_PROJECT_ID"); + } + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + } + + @After + public void tearDown() + throws InterruptedException, ExecutionException, IOException, TimeoutException { + // Delete the created dataset + DeleteDatasetSample.deleteDatasetSample(PROJECT, datasetId); + + // Assert + String deleteResponse = bout.toString(); + assertThat(deleteResponse).contains("Deleted Dataset."); + System.out.flush(); + System.setOut(originalPrintStream); + } + + @Test + public void testCreateDatasetSample() + throws IOException, InterruptedException, ExecutionException, TimeoutException { + // Act + String datasetDisplayName = + String.format( + "temp_create_dataset_test_%s", + UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 26)); + + CreateDatasetTextSample.createDatasetTextSample(PROJECT, datasetDisplayName); + + // Assert + String got = bout.toString(); + assertThat(got).contains(datasetDisplayName); + assertThat(got).contains("Create Text Dataset Response"); + datasetId = got.split("Name: ")[1].split("datasets/")[1].split("\n")[0]; + } +} diff --git a/aiplatform/snippets/src/test/java/aiplatform/CreateTrainingPipelineTextClassificationSampleTest.java b/aiplatform/snippets/src/test/java/aiplatform/CreateTrainingPipelineTextClassificationSampleTest.java new file mode 100644 index 00000000000..5b68dab26f6 --- /dev/null +++ b/aiplatform/snippets/src/test/java/aiplatform/CreateTrainingPipelineTextClassificationSampleTest.java @@ -0,0 +1,110 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.assertNotNull; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import java.util.UUID; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class CreateTrainingPipelineTextClassificationSampleTest { + + private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID"); + private static final String DATASET_ID = System.getenv("TRAINING_PIPELINE_TEXT_CLASS_DATASET_ID"); + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + private String trainingPipelineId; + + private static void requireEnvVar(String varName) { + String errorMessage = + String.format("Environment variable '%s' is required to perform these tests.", varName); + assertNotNull(errorMessage, System.getenv(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("UCAIP_PROJECT_ID"); + requireEnvVar("TRAINING_PIPELINE_TEXT_CLASS_DATASET_ID"); + } + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + } + + @After + public void tearDown() + throws InterruptedException, ExecutionException, IOException, TimeoutException { + // Cancel the Training Pipeline + CancelTrainingPipelineSample.cancelTrainingPipelineSample(PROJECT, trainingPipelineId); + + // Assert + String cancelResponse = bout.toString(); + assertThat(cancelResponse).contains("Cancelled the Training Pipeline"); + TimeUnit.MINUTES.sleep(2); + + // Delete the Training Pipeline + DeleteTrainingPipelineSample.deleteTrainingPipelineSample(PROJECT, trainingPipelineId); + + // Assert + String deleteResponse = bout.toString(); + assertThat(deleteResponse).contains("Deleted Training Pipeline."); + System.out.flush(); + System.setOut(originalPrintStream); + } + + @Test + public void testCreateTrainingPipelineTextClassificationSample() throws IOException { + // Act + String trainingPipelineDisplayName = + String.format( + "temp_create_training_pipeline_test_%s", + UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 26)); + + String modelDisplayName = + String.format( + "temp_create_training_pipeline_model_test_%s", + UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 26)); + + CreateTrainingPipelineTextClassificationSample.createTrainingPipelineTextClassificationSample( + PROJECT, trainingPipelineDisplayName, DATASET_ID, modelDisplayName); + + // Assert + String got = bout.toString(); + assertThat(got).contains(DATASET_ID); + assertThat(got).contains("Create Training Pipeline Text Classification Response"); + trainingPipelineId = got.split("Name: ")[1].split("trainingPipelines/")[1].split("\n")[0]; + } +} diff --git a/aiplatform/snippets/src/test/java/aiplatform/CreateTrainingPipelineTextEntityExtractionSampleTest.java b/aiplatform/snippets/src/test/java/aiplatform/CreateTrainingPipelineTextEntityExtractionSampleTest.java new file mode 100644 index 00000000000..367b9b3ac5c --- /dev/null +++ b/aiplatform/snippets/src/test/java/aiplatform/CreateTrainingPipelineTextEntityExtractionSampleTest.java @@ -0,0 +1,112 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.assertNotNull; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import java.util.UUID; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class CreateTrainingPipelineTextEntityExtractionSampleTest { + + private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID"); + private static final String DATASET_ID = + System.getenv("TRAINING_PIPELINE_TEXT_ENTITY_DATASET_ID"); + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + private String trainingPipelineId; + + private static void requireEnvVar(String varName) { + String errorMessage = + String.format("Environment variable '%s' is required to perform these tests.", varName); + assertNotNull(errorMessage, System.getenv(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("UCAIP_PROJECT_ID"); + requireEnvVar("TRAINING_PIPELINE_TEXT_ENTITY_DATASET_ID"); + } + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + } + + @After + public void tearDown() + throws InterruptedException, ExecutionException, IOException, TimeoutException { + // Cancel the Training Pipeline + CancelTrainingPipelineSample.cancelTrainingPipelineSample(PROJECT, trainingPipelineId); + + // Assert + String cancelResponse = bout.toString(); + assertThat(cancelResponse).contains("Cancelled the Training Pipeline"); + TimeUnit.MINUTES.sleep(2); + + // Delete the Training Pipeline + DeleteTrainingPipelineSample.deleteTrainingPipelineSample(PROJECT, trainingPipelineId); + + // Assert + String deleteResponse = bout.toString(); + assertThat(deleteResponse).contains("Deleted Training Pipeline."); + System.out.flush(); + System.setOut(originalPrintStream); + } + + @Test + public void testCreateTrainingPipelineTextEntityExtractionSample() throws IOException { + // Act + String trainingPipelineDisplayName = + String.format( + "temp_create_training_pipeline_test_%s", + UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 26)); + + String modelDisplayName = + String.format( + "temp_create_training_pipeline_model_test_%s", + UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 26)); + + CreateTrainingPipelineTextEntityExtractionSample + .createTrainingPipelineTextEntityExtractionSample( + PROJECT, trainingPipelineDisplayName, DATASET_ID, modelDisplayName); + + // Assert + String got = bout.toString(); + assertThat(got).contains(DATASET_ID); + assertThat(got).contains("Create Training Pipeline Text Entity Extraction Response"); + trainingPipelineId = got.split("Name: ")[1].split("trainingPipelines/")[1].split("\n")[0]; + } +} diff --git a/aiplatform/snippets/src/test/java/aiplatform/CreateTrainingPipelineTextSentimentAnalysisSampleTest.java b/aiplatform/snippets/src/test/java/aiplatform/CreateTrainingPipelineTextSentimentAnalysisSampleTest.java new file mode 100644 index 00000000000..bd0f29461bf --- /dev/null +++ b/aiplatform/snippets/src/test/java/aiplatform/CreateTrainingPipelineTextSentimentAnalysisSampleTest.java @@ -0,0 +1,112 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.assertNotNull; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import java.util.UUID; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class CreateTrainingPipelineTextSentimentAnalysisSampleTest { + + private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID"); + private static final String DATASET_ID = System.getenv("TRAINING_PIPELINE_TEXT_SENTI_DATASET_ID"); + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + private String trainingPipelineId; + + private static void requireEnvVar(String varName) { + String errorMessage = + String.format("Environment variable '%s' is required to perform these tests.", varName); + assertNotNull(errorMessage, System.getenv(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("UCAIP_PROJECT_ID"); + requireEnvVar("TRAINING_PIPELINE_TEXT_SENTI_DATASET_ID"); + } + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + } + + @After + public void tearDown() + throws InterruptedException, ExecutionException, IOException, TimeoutException { + // Cancel the Training Pipeline + CancelTrainingPipelineSample.cancelTrainingPipelineSample(PROJECT, trainingPipelineId); + + // Assert + String cancelResponse = bout.toString(); + assertThat(cancelResponse).contains("Cancelled the Training Pipeline"); + TimeUnit.MINUTES.sleep(2); + + // Delete the Training Pipeline + DeleteTrainingPipelineSample.deleteTrainingPipelineSample(PROJECT, trainingPipelineId); + + // Assert + String deleteResponse = bout.toString(); + assertThat(deleteResponse).contains("Deleted Training Pipeline."); + System.out.flush(); + System.setOut(originalPrintStream); + } + + @Test + public void testCreateTrainingPipelineTextSentimentAnalysisSample() throws IOException { + String tempUuid = UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 26); + // Act + String trainingPipelineDisplayName = + String.format( + "temp_create_training_pipeline_test_%s", + tempUuid); + + String modelDisplayName = + String.format( + "temp_create_training_pipeline_model_test_%s", + tempUuid); + + CreateTrainingPipelineTextSentimentAnalysisSample + .createTrainingPipelineTextSentimentAnalysisSample( + PROJECT, trainingPipelineDisplayName, DATASET_ID, modelDisplayName); + + // Assert + String got = bout.toString(); + assertThat(got).contains(DATASET_ID); + assertThat(got).contains("Create Training Pipeline Text Sentiment Analysis Response"); + trainingPipelineId = got.split("Name: ")[1].split("trainingPipelines/")[1].split("\n")[0]; + } +} diff --git a/aiplatform/snippets/src/test/java/aiplatform/GetModelEvaluationTextClassificationSampleTest.java b/aiplatform/snippets/src/test/java/aiplatform/GetModelEvaluationTextClassificationSampleTest.java new file mode 100644 index 00000000000..b0b646723a6 --- /dev/null +++ b/aiplatform/snippets/src/test/java/aiplatform/GetModelEvaluationTextClassificationSampleTest.java @@ -0,0 +1,81 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.assertNotNull; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class GetModelEvaluationTextClassificationSampleTest { + + private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID"); + private static final String MODEL_ID = System.getenv("TEXT_CLASS_MODEL_ID"); + private static final String EVALUATION_ID = System.getenv("TEXT_CLASS_EVALUATION_ID"); + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + + private static void requireEnvVar(String varName) { + String errorMessage = + String.format("Environment variable '%s' is required to perform these tests.", varName); + assertNotNull(errorMessage, System.getenv(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("UCAIP_PROJECT_ID"); + requireEnvVar("TEXT_CLASS_MODEL_ID"); + requireEnvVar("TEXT_CLASS_EVALUATION_ID"); + } + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + } + + @After + public void tearDown() { + System.out.flush(); + System.setOut(originalPrintStream); + } + + @Test + public void testGetModelEvaluationTextClassificationSample() throws IOException { + // Act + GetModelEvaluationTextClassificationSample.getModelEvaluationTextClassificationSample( + PROJECT, MODEL_ID, EVALUATION_ID); + + // Assert + String got = bout.toString(); + assertThat(got).contains(MODEL_ID); + assertThat(got).contains("Get Model Evaluation Text Classification Response"); + } +} diff --git a/aiplatform/snippets/src/test/java/aiplatform/GetModelEvaluationTextEntityExtractionSampleTest.java b/aiplatform/snippets/src/test/java/aiplatform/GetModelEvaluationTextEntityExtractionSampleTest.java new file mode 100644 index 00000000000..55d89370254 --- /dev/null +++ b/aiplatform/snippets/src/test/java/aiplatform/GetModelEvaluationTextEntityExtractionSampleTest.java @@ -0,0 +1,81 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.assertNotNull; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class GetModelEvaluationTextEntityExtractionSampleTest { + + private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID"); + private static final String MODEL_ID = System.getenv("TEXT_ENTITY_MODEL_ID"); + private static final String EVALUATION_ID = System.getenv("TEXT_ENTITY_EVALUATION_ID"); + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + + private static void requireEnvVar(String varName) { + String errorMessage = + String.format("Environment variable '%s' is required to perform these tests.", varName); + assertNotNull(errorMessage, System.getenv(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("UCAIP_PROJECT_ID"); + requireEnvVar("TEXT_ENTITY_MODEL_ID"); + requireEnvVar("TEXT_ENTITY_EVALUATION_ID"); + } + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + } + + @After + public void tearDown() { + System.out.flush(); + System.setOut(originalPrintStream); + } + + @Test + public void testGetModelEvaluationTextEntityExtractionSample() throws IOException { + // Act + GetModelEvaluationTextEntityExtractionSample.getModelEvaluationTextEntityExtractionSample( + PROJECT, MODEL_ID, EVALUATION_ID); + + // Assert + String got = bout.toString(); + assertThat(got).contains(MODEL_ID); + assertThat(got).contains("Get Model Evaluation Text Entity Extraction Response"); + } +} diff --git a/aiplatform/snippets/src/test/java/aiplatform/GetModelEvaluationTextSentimentAnalysisSampleTest.java b/aiplatform/snippets/src/test/java/aiplatform/GetModelEvaluationTextSentimentAnalysisSampleTest.java new file mode 100644 index 00000000000..1aeb635914d --- /dev/null +++ b/aiplatform/snippets/src/test/java/aiplatform/GetModelEvaluationTextSentimentAnalysisSampleTest.java @@ -0,0 +1,81 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.assertNotNull; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class GetModelEvaluationTextSentimentAnalysisSampleTest { + + private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID"); + private static final String MODEL_ID = System.getenv("TEXT_SENTI_MODEL_ID"); + private static final String EVALUATION_ID = System.getenv("TEXT_SENTI_EVALUATION_ID"); + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + + private static void requireEnvVar(String varName) { + String errorMessage = + String.format("Environment variable '%s' is required to perform these tests.", varName); + assertNotNull(errorMessage, System.getenv(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("UCAIP_PROJECT_ID"); + requireEnvVar("TEXT_SENTI_MODEL_ID"); + requireEnvVar("TEXT_SENTI_EVALUATION_ID"); + } + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + } + + @After + public void tearDown() { + System.out.flush(); + System.setOut(originalPrintStream); + } + + @Test + public void testGetModelEvaluationTextSentimentAnalysisSample() throws IOException { + // Act + GetModelEvaluationTextSentimentAnalysisSample.getModelEvaluationTextSentimentAnalysisSample( + PROJECT, MODEL_ID, EVALUATION_ID); + + // Assert + String got = bout.toString(); + assertThat(got).contains(MODEL_ID); + assertThat(got).contains("Get Model Evaluation Text Sentiment Analysis Response"); + } +} diff --git a/aiplatform/snippets/src/test/java/aiplatform/ImportDataSampleTest.java b/aiplatform/snippets/src/test/java/aiplatform/ImportDataSampleTest.java new file mode 100644 index 00000000000..72d5ad90f5f --- /dev/null +++ b/aiplatform/snippets/src/test/java/aiplatform/ImportDataSampleTest.java @@ -0,0 +1,90 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.assertNotNull; + +import io.grpc.StatusRuntimeException; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeoutException; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class ImportDataSampleTest { + + private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID"); + private static final String DATASET_ID = "000000000000000000000"; + + private static final String GCS_SOURCE_URI = + "gs://automl-cloud-dataset/SMSSpamCollection_train_dataset_2.csv"; + + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + + private static void requireEnvVar(String varName) { + String errorMessage = + String.format("Environment variable '%s' is required to perform these tests.", varName); + assertNotNull(errorMessage, System.getenv(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("UCAIP_PROJECT_ID"); + } + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + } + + @After + public void tearDown() { + System.out.flush(); + System.setOut(originalPrintStream); + } + + @Test + public void testImportDataSample() + throws TimeoutException { + // As import data into dataset can take a long time, instead try to import data into a + // nonexistent dataset and confirm that the model was not found, but other + // elements of the request were valid. + try { + ImportDataTextClassificationSingleLabelSample.importDataTextClassificationSingleLabelSample( + PROJECT, DATASET_ID, GCS_SOURCE_URI); + // Assert + String got = bout.toString(); + assertThat(got).contains("The Dataset does not exist."); + } catch (StatusRuntimeException | ExecutionException | InterruptedException | IOException e) { + assertThat(e.getMessage()).contains("The Dataset does not exist."); + } + } +} diff --git a/aiplatform/snippets/src/test/java/aiplatform/PredictTextClassificationSingleLabelSampleTest.java b/aiplatform/snippets/src/test/java/aiplatform/PredictTextClassificationSingleLabelSampleTest.java new file mode 100644 index 00000000000..a47674098a9 --- /dev/null +++ b/aiplatform/snippets/src/test/java/aiplatform/PredictTextClassificationSingleLabelSampleTest.java @@ -0,0 +1,76 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.assertNotNull; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; + +public class PredictTextClassificationSingleLabelSampleTest { + + private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID"); + private static final String TEXT_CONTENT = "This is the test String!"; + private static final String ENDPOINT_ID = System.getenv("TEXT_CLASS_SINGLE_LABEL_ENDPOINT_ID"); + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + + private static void requireEnvVar(String varName) { + String errorMessage = + String.format("Environment variable '%s' is required to perform these tests.", varName); + assertNotNull(errorMessage, System.getenv(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("UCAIP_PROJECT_ID"); + requireEnvVar("TEXT_CLASS_SINGLE_LABEL_ENDPOINT_ID"); + } + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + } + + @After + public void tearDown() { + System.out.flush(); + System.setOut(originalPrintStream); + } + + @Test + public void testPredictTextClassification() throws IOException { + // Act + PredictTextClassificationSingleLabelSample.predictTextClassificationSingleLabel( + PROJECT, TEXT_CONTENT, ENDPOINT_ID); + + // Assert + String got = bout.toString(); + assertThat(got).contains("Predict Text Classification Response"); + } +} diff --git a/aiplatform/snippets/src/test/java/aiplatform/PredictTextEntityExtractionSampleTest.java b/aiplatform/snippets/src/test/java/aiplatform/PredictTextEntityExtractionSampleTest.java new file mode 100644 index 00000000000..dde83fa06d9 --- /dev/null +++ b/aiplatform/snippets/src/test/java/aiplatform/PredictTextEntityExtractionSampleTest.java @@ -0,0 +1,89 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.assertNotNull; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; + +public class PredictTextEntityExtractionSampleTest { + + private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID"); + private static final String TEXT_CONTENT = + "1127526\\tAnalbuminemia in a neonate.\\tA small-for-gestational-age infant , found to have" + + " analbuminemia in the neonatal period , is reported and the twelve cases recorded in" + + " the world literature are reviewed . Patients lacking this serum protein are" + + " essentially asymptomatic , apart from minimal ankle edema and ease of fatigue ." + + " Apparent compensatory mechanisms which come into play when serum albumin is low" + + " include prolonged half-life of albumin and transferrin , an increase in serum" + + " globulins , beta lipoprotein , and glycoproteins , arterial hypotension with reduced" + + " capillary hydrostatic pressure , and the ability to respond with rapid sodium and" + + " chloride diuresis in response to small volume changes . Examination of plasma amino" + + " acids , an investigation not previously reported , revealed an extremely low plasma" + + " tryptophan level , a finding which may be important in view of the role of" + + " tryptophan in albumin synthesis."; + private static final String ENDPOINT_ID = System.getenv("TEXT_ENTITY_ENDPOINT_ID"); + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + + private static void requireEnvVar(String varName) { + String errorMessage = + String.format("Environment variable '%s' is required to perform these tests.", varName); + assertNotNull(errorMessage, System.getenv(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("UCAIP_PROJECT_ID"); + requireEnvVar("TEXT_ENTITY_ENDPOINT_ID"); + } + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + } + + @After + public void tearDown() { + System.out.flush(); + System.setOut(originalPrintStream); + } + + @Test + public void testPredictTextEntityExtraction() throws IOException { + // Act + PredictTextEntityExtractionSample.predictTextEntityExtraction( + PROJECT, TEXT_CONTENT, ENDPOINT_ID); + + // Assert + String got = bout.toString(); + assertThat(got).contains("Predict Text Entity Extraction Response"); + } +} + diff --git a/aiplatform/snippets/src/test/java/aiplatform/PredictTextSentimentAnalysisSampleTest.java b/aiplatform/snippets/src/test/java/aiplatform/PredictTextSentimentAnalysisSampleTest.java new file mode 100644 index 00000000000..1189e391728 --- /dev/null +++ b/aiplatform/snippets/src/test/java/aiplatform/PredictTextSentimentAnalysisSampleTest.java @@ -0,0 +1,90 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.assertNotNull; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; + +public class PredictTextSentimentAnalysisSampleTest { + + private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID"); + private static final String TEXT_CONTENT = + "I was excited at the concept of my favorite comic book hero being on television... and" + + " sorely disappointed at the end result.

The only amazing thing was the" + + " wall crawling (despite the visibility of the cable). I didn't think Nick Hammond was" + + " Peter Parker... and he was visibly of a different build than the guy who did the" + + " stunts in the spider suit. You could tell they were two different actors.
Granted, I can also spot in the modern Spider-Man movies when I am looking at" + + " Tobey Macguire and when I am looking at CGI. But that is from a trained eye and" + + " experience working with CGI. Still, the 70's version could have been better despite" + + " lack of Special FX.

The webs were hokey and looked like ropes that seemed" + + " to wrap around things rather than stick to them. And what was up with giving him a" + + " spider mobile to ride around in. Hello? He's the web slinger people.
Sorry... didn't mean to get so worked up, but our beloved wall crawler deserved" + + " better."; + private static final String ENDPOINT_ID = System.getenv("TEXT_SENTI_ENDPOINT_ID"); + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + + private static void requireEnvVar(String varName) { + String errorMessage = + String.format("Environment variable '%s' is required to perform these tests.", varName); + assertNotNull(errorMessage, System.getenv(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("UCAIP_PROJECT_ID"); + requireEnvVar("TEXT_SENTI_ENDPOINT_ID"); + } + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + } + + @After + public void tearDown() { + System.out.flush(); + System.setOut(originalPrintStream); + } + + @Test + public void testPredictTextSentimentAnalysis() throws IOException { + // Act + PredictTextSentimentAnalysisSample.predictTextSentimentAnalysis( + PROJECT, TEXT_CONTENT, ENDPOINT_ID); + + // Assert + String got = bout.toString(); + assertThat(got).contains("Predict Text Sentiment Analysis Response"); + } +} +