From 799368d3683fb05f87b7b2f9e66b54cbba8e41ea Mon Sep 17 00:00:00 2001 From: Noah Negrey Date: Fri, 25 Oct 2019 13:39:30 -0600 Subject: [PATCH] Fix failing tests: (#1599) * Update samples to use try for clients and update model IDs due to API backend changes * Update to latest library to fix timeout --- language/automl/pom.xml | 2 +- .../cloud/language/samples/DatasetApi.java | 239 ++++++------- .../cloud/language/samples/ModelApi.java | 313 +++++++++--------- .../cloud/language/samples/PredictionApi.java | 37 ++- .../cloud/language/samples/DatasetApiIT.java | 3 +- .../cloud/language/samples/ModelApiIT.java | 3 +- .../language/samples/PredictionApiIT.java | 4 +- 7 files changed, 307 insertions(+), 294 deletions(-) diff --git a/language/automl/pom.xml b/language/automl/pom.xml index a64d7fdd18a..7a96e47166f 100644 --- a/language/automl/pom.xml +++ b/language/automl/pom.xml @@ -40,7 +40,7 @@ com.google.cloud google-cloud-automl - 0.55.1-beta + 0.114.0-beta diff --git a/language/automl/src/main/java/com/google/cloud/language/samples/DatasetApi.java b/language/automl/src/main/java/com/google/cloud/language/samples/DatasetApi.java index 5169ca5c922..5f44f95cf43 100644 --- a/language/automl/src/main/java/com/google/cloud/language/samples/DatasetApi.java +++ b/language/automl/src/main/java/com/google/cloud/language/samples/DatasetApi.java @@ -32,6 +32,7 @@ import java.io.IOException; import java.io.PrintStream; +import java.util.Arrays; import net.sourceforge.argparse4j.ArgumentParsers; import net.sourceforge.argparse4j.inf.ArgumentParser; @@ -62,76 +63,33 @@ public static void createDataset( String projectId, String computeRegion, String datasetName, Boolean multiLabel) throws IOException { // Instantiates a client - AutoMlClient client = AutoMlClient.create(); - - // A resource that represents Google Cloud Platform location. - LocationName projectLocation = LocationName.of(projectId, computeRegion); - - // Classification type assigned based on multilabel value. - ClassificationType classificationType = - multiLabel ? ClassificationType.MULTILABEL : ClassificationType.MULTICLASS; - - // Specify the text classification type for the dataset. - TextClassificationDatasetMetadata textClassificationDatasetMetadata = - TextClassificationDatasetMetadata.newBuilder() - .setClassificationType(classificationType) - .build(); - - // Set dataset name and dataset metadata. - Dataset myDataset = - Dataset.newBuilder() - .setDisplayName(datasetName) - .setTextClassificationDatasetMetadata(textClassificationDatasetMetadata) - .build(); - - // Create a dataset with the dataset metadata in the region. - Dataset dataset = client.createDataset(projectLocation, myDataset); - - // Display the dataset information. - System.out.println(String.format("Dataset name: %s", dataset.getName())); - System.out.println( - String.format( - "Dataset id: %s", - dataset.getName().split("/")[dataset.getName().split("/").length - 1])); - System.out.println(String.format("Dataset display name: %s", dataset.getDisplayName())); - System.out.println("Text classification dataset metadata:"); - System.out.print(String.format("\t%s", dataset.getTextClassificationDatasetMetadata())); - System.out.println(String.format("Dataset example count: %d", dataset.getExampleCount())); - System.out.println("Dataset create time:"); - System.out.println(String.format("\tseconds: %s", dataset.getCreateTime().getSeconds())); - System.out.println(String.format("\tnanos: %s", dataset.getCreateTime().getNanos())); - } - // [END automl_language_create_dataset] + try (AutoMlClient client = AutoMlClient.create()) { - // [START automl_language_list_datasets] - /** - * Demonstrates using the AutoML client to list all datasets. - * - * @param projectId the Id of the project. - * @param computeRegion the Region name. - * @param filter the Filter expression. - * @throws IOException on Input/Output errors. - */ - public static void listDatasets(String projectId, String computeRegion, String filter) - throws IOException { - // Instantiates a client - AutoMlClient client = AutoMlClient.create(); + // A resource that represents Google Cloud Platform location. + LocationName projectLocation = LocationName.of(projectId, computeRegion); - // A resource that represents Google Cloud Platform location. - LocationName projectLocation = LocationName.of(projectId, computeRegion); + // Classification type assigned based on multilabel value. + ClassificationType classificationType = + multiLabel ? ClassificationType.MULTILABEL : ClassificationType.MULTICLASS; - // Build the List datasets request - ListDatasetsRequest request = - ListDatasetsRequest.newBuilder() - .setParent(projectLocation.toString()) - .setFilter(filter) - .build(); + // Specify the text classification type for the dataset. + TextClassificationDatasetMetadata textClassificationDatasetMetadata = + TextClassificationDatasetMetadata.newBuilder() + .setClassificationType(classificationType) + .build(); + + // Set dataset name and dataset metadata. + Dataset myDataset = + Dataset.newBuilder() + .setDisplayName(datasetName) + .setTextClassificationDatasetMetadata(textClassificationDatasetMetadata) + .build(); + + // Create a dataset with the dataset metadata in the region. + Dataset dataset = client.createDataset(projectLocation, myDataset); - // List all the datasets available in the region by applying filter. - System.out.println("List of datasets:"); - for (Dataset dataset : client.listDatasets(request).iterateAll()) { // Display the dataset information. - System.out.println(String.format("\nDataset name: %s", dataset.getName())); + System.out.println(String.format("Dataset name: %s", dataset.getName())); System.out.println( String.format( "Dataset id: %s", @@ -145,6 +103,51 @@ public static void listDatasets(String projectId, String computeRegion, String f System.out.println(String.format("\tnanos: %s", dataset.getCreateTime().getNanos())); } } + // [END automl_language_create_dataset] + + // [START automl_language_list_datasets] + /** + * Demonstrates using the AutoML client to list all datasets. + * + * @param projectId the Id of the project. + * @param computeRegion the Region name. + * @param filter the Filter expression. + * @throws IOException on Input/Output errors. + */ + public static void listDatasets(String projectId, String computeRegion, String filter) + throws IOException { + // Instantiates a client + try (AutoMlClient client = AutoMlClient.create()) { + + // A resource that represents Google Cloud Platform location. + LocationName projectLocation = LocationName.of(projectId, computeRegion); + + // Build the List datasets request + ListDatasetsRequest request = + ListDatasetsRequest.newBuilder() + .setParent(projectLocation.toString()) + .setFilter(filter) + .build(); + + // List all the datasets available in the region by applying filter. + System.out.println("List of datasets:"); + for (Dataset dataset : client.listDatasets(request).iterateAll()) { + // Display the dataset information. + System.out.println(String.format("\nDataset name: %s", dataset.getName())); + System.out.println( + String.format( + "Dataset id: %s", + dataset.getName().split("/")[dataset.getName().split("/").length - 1])); + System.out.println(String.format("Dataset display name: %s", dataset.getDisplayName())); + System.out.println("Text classification dataset metadata:"); + System.out.print(String.format("\t%s", dataset.getTextClassificationDatasetMetadata())); + System.out.println(String.format("Dataset example count: %d", dataset.getExampleCount())); + System.out.println("Dataset create time:"); + System.out.println(String.format("\tseconds: %s", dataset.getCreateTime().getSeconds())); + System.out.println(String.format("\tnanos: %s", dataset.getCreateTime().getNanos())); + } + } + } // [END automl_language_list_datasets] // [START automl_language_get_dataset] @@ -159,27 +162,28 @@ public static void listDatasets(String projectId, String computeRegion, String f public static void getDataset(String projectId, String computeRegion, String datasetId) throws IOException { // Instantiates a client - AutoMlClient client = AutoMlClient.create(); - - // Get the complete path of the dataset. - DatasetName datasetFullId = DatasetName.of(projectId, computeRegion, datasetId); - - // Get all the information about a given dataset. - Dataset dataset = client.getDataset(datasetFullId); - - // Display the dataset information. - System.out.println(String.format("Dataset name: %s", dataset.getName())); - System.out.println( - String.format( - "Dataset id: %s", - dataset.getName().split("/")[dataset.getName().split("/").length - 1])); - System.out.println(String.format("Dataset display name: %s", dataset.getDisplayName())); - System.out.println("Text classification dataset metadata:"); - System.out.print(String.format("\t%s", dataset.getTextClassificationDatasetMetadata())); - System.out.println(String.format("Dataset example count: %d", dataset.getExampleCount())); - System.out.println("Dataset create time:"); - System.out.println(String.format("\tseconds: %s", dataset.getCreateTime().getSeconds())); - System.out.println(String.format("\tnanos: %s", dataset.getCreateTime().getNanos())); + try (AutoMlClient client = AutoMlClient.create()) { + + // Get the complete path of the dataset. + DatasetName datasetFullId = DatasetName.of(projectId, computeRegion, datasetId); + + // Get all the information about a given dataset. + Dataset dataset = client.getDataset(datasetFullId); + + // Display the dataset information. + System.out.println(String.format("Dataset name: %s", dataset.getName())); + System.out.println( + String.format( + "Dataset id: %s", + dataset.getName().split("/")[dataset.getName().split("/").length - 1])); + System.out.println(String.format("Dataset display name: %s", dataset.getDisplayName())); + System.out.println("Text classification dataset metadata:"); + System.out.print(String.format("\t%s", dataset.getTextClassificationDatasetMetadata())); + System.out.println(String.format("Dataset example count: %d", dataset.getExampleCount())); + System.out.println("Dataset create time:"); + System.out.println(String.format("\tseconds: %s", dataset.getCreateTime().getSeconds())); + System.out.println(String.format("\tnanos: %s", dataset.getCreateTime().getNanos())); + } } // [END automl_language_get_dataset] @@ -197,25 +201,22 @@ public static void getDataset(String projectId, String computeRegion, String dat public static void importData( String projectId, String computeRegion, String datasetId, String path) throws Exception { // Instantiates a client - AutoMlClient client = AutoMlClient.create(); - - // Get the complete path of the dataset. - DatasetName datasetFullId = DatasetName.of(projectId, computeRegion, datasetId); + try (AutoMlClient client = AutoMlClient.create()) { - GcsSource.Builder gcsSource = GcsSource.newBuilder(); + // Get the complete path of the dataset. + DatasetName datasetFullId = DatasetName.of(projectId, computeRegion, datasetId); - // Get multiple training data files to be imported - String[] inputUris = path.split(","); - for (String inputUri : inputUris) { - gcsSource.addInputUris(inputUri); - } + // Get multiple training data files to be imported + GcsSource gcsSource = + GcsSource.newBuilder().addAllInputUris(Arrays.asList(path.split(","))).build(); - // Import data from the input URI - InputConfig inputConfig = InputConfig.newBuilder().setGcsSource(gcsSource).build(); - System.out.println("Processing import..."); + // Import data from the input URI + InputConfig inputConfig = InputConfig.newBuilder().setGcsSource(gcsSource).build(); + System.out.println("Processing import..."); - Empty response = client.importDataAsync(datasetFullId, inputConfig).get(); - System.out.println(String.format("Dataset imported. %s", response)); + Empty response = client.importDataAsync(datasetFullId, inputConfig).get(); + System.out.println(String.format("Dataset imported. %s", response)); + } } // [END automl_language_import_data] @@ -232,20 +233,23 @@ public static void importData( public static void exportData( String projectId, String computeRegion, String datasetId, String gcsUri) throws Exception { // Instantiates a client - AutoMlClient client = AutoMlClient.create(); + try (AutoMlClient client = AutoMlClient.create()) { - // Get the complete path of the dataset. - DatasetName datasetFullId = DatasetName.of(projectId, computeRegion, datasetId); + // Get the complete path of the dataset. + DatasetName datasetFullId = DatasetName.of(projectId, computeRegion, datasetId); - // Set the output URI. - GcsDestination gcsDestination = GcsDestination.newBuilder().setOutputUriPrefix(gcsUri).build(); + // Set the output URI. + GcsDestination gcsDestination = + GcsDestination.newBuilder().setOutputUriPrefix(gcsUri).build(); - // Export the data to the output URI. - OutputConfig outputConfig = OutputConfig.newBuilder().setGcsDestination(gcsDestination).build(); - System.out.println(String.format("Processing export...")); + // Export the data to the output URI. + OutputConfig outputConfig = + OutputConfig.newBuilder().setGcsDestination(gcsDestination).build(); + System.out.println(String.format("Processing export...")); - Empty response = client.exportDataAsync(datasetFullId, outputConfig).get(); - System.out.println(String.format("Dataset exported. %s", response)); + Empty response = client.exportDataAsync(datasetFullId, outputConfig).get(); + System.out.println(String.format("Dataset exported. %s", response)); + } } // [END automl_language_export_data] @@ -261,15 +265,16 @@ public static void exportData( public static void deleteDataset(String projectId, String computeRegion, String datasetId) throws Exception { // Instantiates a client - AutoMlClient client = AutoMlClient.create(); + try (AutoMlClient client = AutoMlClient.create()) { - // Get the complete path of the dataset. - DatasetName datasetFullId = DatasetName.of(projectId, computeRegion, datasetId); + // Get the complete path of the dataset. + DatasetName datasetFullId = DatasetName.of(projectId, computeRegion, datasetId); - // Delete a dataset. - Empty response = client.deleteDatasetAsync(datasetFullId).get(); + // Delete a dataset. + Empty response = client.deleteDatasetAsync(datasetFullId).get(); - System.out.println(String.format("Dataset deleted. %s", response)); + System.out.println(String.format("Dataset deleted. %s", response)); + } } // [END automl_language_delete_dataset] diff --git a/language/automl/src/main/java/com/google/cloud/language/samples/ModelApi.java b/language/automl/src/main/java/com/google/cloud/language/samples/ModelApi.java index 16883b221cf..f827742ca2a 100644 --- a/language/automl/src/main/java/com/google/cloud/language/samples/ModelApi.java +++ b/language/automl/src/main/java/com/google/cloud/language/samples/ModelApi.java @@ -67,30 +67,32 @@ public static void createModel( throws IOException, InterruptedException, ExecutionException { // Instantiates a client - AutoMlClient client = AutoMlClient.create(); - - // A resource that represents Google Cloud Platform location. - LocationName projectLocation = LocationName.of(projectId, computeRegion); - - // Set model meta data - TextClassificationModelMetadata textClassificationModelMetadata = - TextClassificationModelMetadata.newBuilder().build(); - - // Set model name, dataset and metadata. - Model myModel = - Model.newBuilder() - .setDisplayName(modelName) - .setDatasetId(dataSetId) - .setTextClassificationModelMetadata(textClassificationModelMetadata) - .build(); - - // Create a model with the model metadata in the region. - OperationFuture response = - client.createModelAsync(projectLocation, myModel); - - System.out.println( - String.format("Training operation name: %s", response.getInitialFuture().get().getName())); - System.out.println("Training started..."); + try (AutoMlClient client = AutoMlClient.create()) { + + // A resource that represents Google Cloud Platform location. + LocationName projectLocation = LocationName.of(projectId, computeRegion); + + // Set model meta data + TextClassificationModelMetadata textClassificationModelMetadata = + TextClassificationModelMetadata.newBuilder().build(); + + // Set model name, dataset and metadata. + Model myModel = + Model.newBuilder() + .setDisplayName(modelName) + .setDatasetId(dataSetId) + .setTextClassificationModelMetadata(textClassificationModelMetadata) + .build(); + + // Create a model with the model metadata in the region. + OperationFuture response = + client.createModelAsync(projectLocation, myModel); + + System.out.println( + String.format( + "Training operation name: %s", response.getInitialFuture().get().getName())); + System.out.println("Training started..."); + } } // [END automl_language_create_model] @@ -104,12 +106,13 @@ public static void createModel( */ public static void getOperationStatus(String operationFullId) throws IOException { // Instantiates a client - AutoMlClient client = AutoMlClient.create(); + try (AutoMlClient client = AutoMlClient.create()) { - // Get the latest state of a long-running operation. - Operation response = client.getOperationsClient().getOperation(operationFullId); + // Get the latest state of a long-running operation. + Operation response = client.getOperationsClient().getOperation(operationFullId); - System.out.println(String.format("Operation status: %s", response)); + System.out.println(String.format("Operation status: %s", response)); + } } // [END automl_language_get_operation_status] @@ -125,30 +128,31 @@ public static void getOperationStatus(String operationFullId) throws IOException public static void listModels(String projectId, String computeRegion, String filter) throws IOException { // Instantiates a client - AutoMlClient client = AutoMlClient.create(); - - // A resource that represents Google Cloud Platform location. - LocationName projectLocation = LocationName.of(projectId, computeRegion); - - // Create list models request. - ListModelsRequest listModelsRequest = - ListModelsRequest.newBuilder() - .setParent(projectLocation.toString()) - .setFilter(filter) - .build(); - - System.out.println("List of models:"); - for (Model model : client.listModels(listModelsRequest).iterateAll()) { - // Display the model information. - System.out.println(String.format("Model name: %s", model.getName())); - System.out.println( - String.format( - "Model id: %s", model.getName().split("/")[model.getName().split("/").length - 1])); - System.out.println(String.format("Model display name: %s", model.getDisplayName())); - System.out.println("Model create time:"); - System.out.println(String.format("\tseconds: %s", model.getCreateTime().getSeconds())); - System.out.println(String.format("\tnanos: %s", model.getCreateTime().getNanos())); - System.out.println(String.format("Model deployment state: %s", model.getDeploymentState())); + try (AutoMlClient client = AutoMlClient.create()) { + + // A resource that represents Google Cloud Platform location. + LocationName projectLocation = LocationName.of(projectId, computeRegion); + + // Create list models request. + ListModelsRequest listModelsRequest = + ListModelsRequest.newBuilder() + .setParent(projectLocation.toString()) + .setFilter(filter) + .build(); + + System.out.println("List of models:"); + for (Model model : client.listModels(listModelsRequest).iterateAll()) { + // Display the model information. + System.out.println(String.format("Model name: %s", model.getName())); + System.out.println( + String.format( + "Model id: %s", model.getName().split("/")[model.getName().split("/").length - 1])); + System.out.println(String.format("Model display name: %s", model.getDisplayName())); + System.out.println("Model create time:"); + System.out.println(String.format("\tseconds: %s", model.getCreateTime().getSeconds())); + System.out.println(String.format("\tnanos: %s", model.getCreateTime().getNanos())); + System.out.println(String.format("Model deployment state: %s", model.getDeploymentState())); + } } } // [END automl_language_list_models] @@ -165,24 +169,25 @@ public static void listModels(String projectId, String computeRegion, String fil public static void getModel(String projectId, String computeRegion, String modelId) throws IOException { // Instantiates a client - AutoMlClient client = AutoMlClient.create(); - - // Get the full path of the model. - ModelName modelFullId = ModelName.of(projectId, computeRegion, modelId); - - // Get complete detail of the model. - Model model = client.getModel(modelFullId); - - // Display the model information. - System.out.println(String.format("Model name: %s", model.getName())); - System.out.println( - String.format( - "Model id: %s", model.getName().split("/")[model.getName().split("/").length - 1])); - System.out.println(String.format("Model display name: %s", model.getDisplayName())); - System.out.println("Model create time:"); - System.out.println(String.format("\tseconds: %s", model.getCreateTime().getSeconds())); - System.out.println(String.format("\tnanos: %s", model.getCreateTime().getNanos())); - System.out.println(String.format("Model deployment state: %s", model.getDeploymentState())); + try (AutoMlClient client = AutoMlClient.create()) { + + // Get the full path of the model. + ModelName modelFullId = ModelName.of(projectId, computeRegion, modelId); + + // Get complete detail of the model. + Model model = client.getModel(modelFullId); + + // Display the model information. + System.out.println(String.format("Model name: %s", model.getName())); + System.out.println( + String.format( + "Model id: %s", model.getName().split("/")[model.getName().split("/").length - 1])); + System.out.println(String.format("Model display name: %s", model.getDisplayName())); + System.out.println("Model create time:"); + System.out.println(String.format("\tseconds: %s", model.getCreateTime().getSeconds())); + System.out.println(String.format("\tnanos: %s", model.getCreateTime().getNanos())); + System.out.println(String.format("Model deployment state: %s", model.getDeploymentState())); + } } // [END automl_language_get_model] @@ -199,22 +204,23 @@ public static void getModel(String projectId, String computeRegion, String model public static void listModelEvaluations( String projectId, String computeRegion, String modelId, String filter) throws IOException { // Instantiates a client - AutoMlClient client = AutoMlClient.create(); - - // Get the full path of the model. - ModelName modelFullId = ModelName.of(projectId, computeRegion, modelId); - - // Create list model evaluations request. - ListModelEvaluationsRequest modelEvaluationsRequest = - ListModelEvaluationsRequest.newBuilder() - .setParent(modelFullId.toString()) - .setFilter(filter) - .build(); - - // List all the model evaluations in the model by applying filter. - for (ModelEvaluation element : - client.listModelEvaluations(modelEvaluationsRequest).iterateAll()) { - System.out.println(element); + try (AutoMlClient client = AutoMlClient.create()) { + + // Get the full path of the model. + ModelName modelFullId = ModelName.of(projectId, computeRegion, modelId); + + // Create list model evaluations request. + ListModelEvaluationsRequest modelEvaluationsRequest = + ListModelEvaluationsRequest.newBuilder() + .setParent(modelFullId.toString()) + .setFilter(filter) + .build(); + + // List all the model evaluations in the model by applying filter. + for (ModelEvaluation element : + client.listModelEvaluations(modelEvaluationsRequest).iterateAll()) { + System.out.println(element); + } } } // [END automl_language_list_model_evaluations] @@ -233,16 +239,17 @@ public static void getModelEvaluation( String projectId, String computeRegion, String modelId, String modelEvaluationId) throws IOException { // Instantiates a client - AutoMlClient client = AutoMlClient.create(); + try (AutoMlClient client = AutoMlClient.create()) { - // Get the full path of the model evaluation. - ModelEvaluationName modelEvaluationFullId = - ModelEvaluationName.of(projectId, computeRegion, modelId, modelEvaluationId); + // Get the full path of the model evaluation. + ModelEvaluationName modelEvaluationFullId = + ModelEvaluationName.of(projectId, computeRegion, modelId, modelEvaluationId); - // Get complete detail of the model evaluation. - ModelEvaluation response = client.getModelEvaluation(modelEvaluationFullId); + // Get complete detail of the model evaluation. + ModelEvaluation response = client.getModelEvaluation(modelEvaluationFullId); - System.out.println(response); + System.out.println(response); + } } // [END automl_language_get_model_evaluation] @@ -259,61 +266,62 @@ public static void getModelEvaluation( public static void displayEvaluation( String projectId, String computeRegion, String modelId, String filter) throws IOException { // Instantiates a client - AutoMlClient client = AutoMlClient.create(); - - // Get the full path of the model. - ModelName modelFullId = ModelName.of(projectId, computeRegion, modelId); - - // List all the model evaluations in the model by applying. - ListModelEvaluationsRequest modelEvaluationsrequest = - ListModelEvaluationsRequest.newBuilder() - .setParent(modelFullId.toString()) - .setFilter(filter) - .build(); - - // Iterate through the results. - String modelEvaluationId = ""; - for (ModelEvaluation element : - client.listModelEvaluations(modelEvaluationsrequest).iterateAll()) { - if (element.getAnnotationSpecId() != null) { - modelEvaluationId = element.getName().split("/")[element.getName().split("/").length - 1]; + try (AutoMlClient client = AutoMlClient.create()) { + + // Get the full path of the model. + ModelName modelFullId = ModelName.of(projectId, computeRegion, modelId); + + // List all the model evaluations in the model by applying. + ListModelEvaluationsRequest modelEvaluationsrequest = + ListModelEvaluationsRequest.newBuilder() + .setParent(modelFullId.toString()) + .setFilter(filter) + .build(); + + // Iterate through the results. + String modelEvaluationId = ""; + for (ModelEvaluation element : + client.listModelEvaluations(modelEvaluationsrequest).iterateAll()) { + if (element.getAnnotationSpecId() != null) { + modelEvaluationId = element.getName().split("/")[element.getName().split("/").length - 1]; + } } - } - - // Resource name for the model evaluation. - ModelEvaluationName modelEvaluationFullId = - ModelEvaluationName.of(projectId, computeRegion, modelId, modelEvaluationId); - - // Get a model evaluation. - ModelEvaluation modelEvaluation = client.getModelEvaluation(modelEvaluationFullId); - ClassificationEvaluationMetrics classMetrics = - modelEvaluation.getClassificationEvaluationMetrics(); - List confidenceMetricsEntries = - classMetrics.getConfidenceMetricsEntryList(); - - // Showing model score based on threshold of 0.5 - for (ConfidenceMetricsEntry confidenceMetricsEntry : confidenceMetricsEntries) { - if (confidenceMetricsEntry.getConfidenceThreshold() == 0.5) { - System.out.println("Precision and recall are based on a score threshold of 0.5"); - System.out.println( - String.format("Model Precision: %.2f ", confidenceMetricsEntry.getPrecision() * 100) - + '%'); - System.out.println( - String.format("Model Recall: %.2f ", confidenceMetricsEntry.getRecall() * 100) + '%'); - System.out.println( - String.format("Model F1 Score: %.2f ", confidenceMetricsEntry.getF1Score() * 100) - + '%'); - System.out.println( - String.format( - "Model Precision@1: %.2f ", confidenceMetricsEntry.getPrecisionAt1() * 100) - + '%'); - System.out.println( - String.format("Model Recall@1: %.2f ", confidenceMetricsEntry.getRecallAt1() * 100) - + '%'); - System.out.println( - String.format("Model F1 Score@1: %.2f ", confidenceMetricsEntry.getF1ScoreAt1() * 100) - + '%'); + // Resource name for the model evaluation. + ModelEvaluationName modelEvaluationFullId = + ModelEvaluationName.of(projectId, computeRegion, modelId, modelEvaluationId); + + // Get a model evaluation. + ModelEvaluation modelEvaluation = client.getModelEvaluation(modelEvaluationFullId); + + ClassificationEvaluationMetrics classMetrics = + modelEvaluation.getClassificationEvaluationMetrics(); + List confidenceMetricsEntries = + classMetrics.getConfidenceMetricsEntryList(); + + // Showing model score based on threshold of 0.5 + for (ConfidenceMetricsEntry confidenceMetricsEntry : confidenceMetricsEntries) { + if (confidenceMetricsEntry.getConfidenceThreshold() == 0.5) { + System.out.println("Precision and recall are based on a score threshold of 0.5"); + System.out.println( + String.format("Model Precision: %.2f ", confidenceMetricsEntry.getPrecision() * 100) + + '%'); + System.out.println( + String.format("Model Recall: %.2f ", confidenceMetricsEntry.getRecall() * 100) + '%'); + System.out.println( + String.format("Model F1 Score: %.2f ", confidenceMetricsEntry.getF1Score() * 100) + + '%'); + System.out.println( + String.format( + "Model Precision@1: %.2f ", confidenceMetricsEntry.getPrecisionAt1() * 100) + + '%'); + System.out.println( + String.format("Model Recall@1: %.2f ", confidenceMetricsEntry.getRecallAt1() * 100) + + '%'); + System.out.println( + String.format("Model F1 Score@1: %.2f ", confidenceMetricsEntry.getF1ScoreAt1() * 100) + + '%'); + } } } } @@ -331,15 +339,16 @@ public static void displayEvaluation( public static void deleteModel(String projectId, String computeRegion, String modelId) throws InterruptedException, ExecutionException, IOException { // Instantiates a client - AutoMlClient client = AutoMlClient.create(); + try (AutoMlClient client = AutoMlClient.create()) { - // Get the full path of the model. - ModelName modelFullId = ModelName.of(projectId, computeRegion, modelId); + // Get the full path of the model. + ModelName modelFullId = ModelName.of(projectId, computeRegion, modelId); - // Delete a model. - Empty response = client.deleteModelAsync(modelFullId).get(); + // Delete a model. + Empty response = client.deleteModelAsync(modelFullId).get(); - System.out.println("Model deletion started..."); + System.out.println("Model deletion started..."); + } } // [END automl_language_delete_model] diff --git a/language/automl/src/main/java/com/google/cloud/language/samples/PredictionApi.java b/language/automl/src/main/java/com/google/cloud/language/samples/PredictionApi.java index 18da9d8c61f..6985470c312 100644 --- a/language/automl/src/main/java/com/google/cloud/language/samples/PredictionApi.java +++ b/language/automl/src/main/java/com/google/cloud/language/samples/PredictionApi.java @@ -60,29 +60,30 @@ public static void predict( String projectId, String computeRegion, String modelId, String filePath) throws IOException { // Create client for prediction service. - PredictionServiceClient predictionClient = PredictionServiceClient.create(); + try (PredictionServiceClient predictionClient = PredictionServiceClient.create()) { - // Get full path of model - ModelName name = ModelName.of(projectId, computeRegion, modelId); + // Get full path of model + ModelName name = ModelName.of(projectId, computeRegion, modelId); - // Read the file content for prediction. - String content = new String(Files.readAllBytes(Paths.get(filePath))); + // Read the file content for prediction. + String content = new String(Files.readAllBytes(Paths.get(filePath))); - // Set the payload by giving the content and type of the file. - TextSnippet textSnippet = - TextSnippet.newBuilder().setContent(content).setMimeType("text/plain").build(); - ExamplePayload payload = ExamplePayload.newBuilder().setTextSnippet(textSnippet).build(); + // Set the payload by giving the content and type of the file. + TextSnippet textSnippet = + TextSnippet.newBuilder().setContent(content).setMimeType("text/plain").build(); + ExamplePayload payload = ExamplePayload.newBuilder().setTextSnippet(textSnippet).build(); - // params is additional domain-specific parameters. - // currently there is no additional parameters supported. - Map params = new HashMap(); - PredictResponse response = predictionClient.predict(name, payload, params); + // params is additional domain-specific parameters. + // currently there is no additional parameters supported. + Map params = new HashMap(); + PredictResponse response = predictionClient.predict(name, payload, params); - System.out.println("Prediction results:"); - for (AnnotationPayload annotationPayload : response.getPayloadList()) { - System.out.println("Predicted Class name :" + annotationPayload.getDisplayName()); - System.out.println( - "Predicted Class Score :" + annotationPayload.getClassification().getScore()); + System.out.println("Prediction results:"); + for (AnnotationPayload annotationPayload : response.getPayloadList()) { + System.out.println("Predicted Class name :" + annotationPayload.getDisplayName()); + System.out.println( + "Predicted Class Score :" + annotationPayload.getClassification().getScore()); + } } } // [END automl_language_predict] diff --git a/language/automl/src/test/java/com/google/cloud/language/samples/DatasetApiIT.java b/language/automl/src/test/java/com/google/cloud/language/samples/DatasetApiIT.java index 793c64ed11f..c10c9918cc1 100644 --- a/language/automl/src/test/java/com/google/cloud/language/samples/DatasetApiIT.java +++ b/language/automl/src/test/java/com/google/cloud/language/samples/DatasetApiIT.java @@ -39,9 +39,8 @@ public class DatasetApiIT { private static final String DATASET_NAME = "test_language_dataset"; private ByteArrayOutputStream bout; private PrintStream out; - private DatasetApi app; private String datasetId; - private String getdatasetId = "8477830379477056918"; + private String getdatasetId = "TCN8477830379477056918"; @Before public void setUp() { diff --git a/language/automl/src/test/java/com/google/cloud/language/samples/ModelApiIT.java b/language/automl/src/test/java/com/google/cloud/language/samples/ModelApiIT.java index c142c3ef5b2..49f93ebb089 100644 --- a/language/automl/src/test/java/com/google/cloud/language/samples/ModelApiIT.java +++ b/language/automl/src/test/java/com/google/cloud/language/samples/ModelApiIT.java @@ -35,9 +35,8 @@ public class ModelApiIT { private static final String COMPUTE_REGION = "us-central1"; private ByteArrayOutputStream bout; private PrintStream out; - private ModelApi app; private String modelId; - private String modelIdGetevaluation = "342705131419266916"; + private String modelIdGetevaluation = "TCN342705131419266916"; private String modelEvaluationId = "3666189665418739402"; @Before diff --git a/language/automl/src/test/java/com/google/cloud/language/samples/PredictionApiIT.java b/language/automl/src/test/java/com/google/cloud/language/samples/PredictionApiIT.java index 3549eb244ee..ed5a115852a 100644 --- a/language/automl/src/test/java/com/google/cloud/language/samples/PredictionApiIT.java +++ b/language/automl/src/test/java/com/google/cloud/language/samples/PredictionApiIT.java @@ -32,8 +32,8 @@ @SuppressWarnings("checkstyle:abbreviationaswordinname") public class PredictionApiIT { private static final String COMPUTE_REGION = "us-central1"; - private static final String PROJECT_ID = "java-docs-samples-testing"; - private static final String modelId = "342705131419266916"; + private static final String PROJECT_ID = System.getenv("GOOGLE_CLOUD_PROJECT"); + private static final String modelId = "TCN6871084728972835631"; private static final String filePath = "./resources/input.txt"; private ByteArrayOutputStream bout; private PrintStream out;