Skip to content

Commit

Permalink
samples: Automl (#1162)
Browse files Browse the repository at this point in the history
* Test push

* Vision AutoML

* Vision AutoML updates + Translate AutoML

* Translate README fixes

* Fixing Kokoro failure issue

* Language AutoML

* Vision AutoML

* Translate AutoML files added

* Triggering tests

* Triggering tests

* Updates based on comments

* Updates after review comments

* Fixed build issue
  • Loading branch information
nirupa-kumar authored and anguillanneuf committed Dec 5, 2022
1 parent 92c4643 commit 7a973d5
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@

/**
* Google Cloud AutoML Translate API sample application. Example usage: mvn package exec:java
* -Dexec.mainClass ='com.google.cloud.vision.samples.automl.PredictionApi' -Dexec.args='predict
* [modelId] [path-to-image] [scoreThreshold]'
* -Dexec.mainClass ='com.google.cloud.translate.automl.PredictionApi' -Dexec.args='predict
* [modelId] [file-path]'
*/
public class PredictionApi {

Expand All @@ -61,16 +61,11 @@ public class PredictionApi {
* @param computeRegion the Region name.
* @param modelId the Id of the model which will be used for text classification.
* @param filePath the Local text file path of the content to be classified.
* @param translationAllowFallback set to true to use a Google translation.
* @throws IOException on Input/Output errors.
*/
public static void predict(
String projectId,
String computeRegion,
String modelId,
String filePath,
boolean translationAllowFallback)
throws IOException {
String projectId, String computeRegion, String modelId, String filePath) throws IOException {

// Instantiate client for prediction service.
PredictionServiceClient predictionClient = PredictionServiceClient.create();

Expand All @@ -87,9 +82,6 @@ public static void predict(

// Additional parameters that can be provided for prediction
Map<String, String> params = new HashMap<>();
if (translationAllowFallback) {
params.put("translation_allow_fallback", "True");//Allow Google Translation Model
}

PredictResponse response = predictionClient.predict(name, payload, params);
TextSnippet translatedContent = response.getPayload(0).getTranslation().getTranslatedContent();
Expand All @@ -104,20 +96,17 @@ public static void main(String[] args) throws IOException {
}

public static void argsHelper(String[] args, PrintStream out) throws IOException {
ArgumentParser parser = ArgumentParsers.newFor("PredictionApi")
.build()
.defaultHelp(true)
.description("Prediction API Operation");
ArgumentParser parser =
ArgumentParsers.newFor("PredictionApi")
.build()
.defaultHelp(true)
.description("Prediction API Operation");

Subparsers subparsers = parser.addSubparsers().dest("command");

Subparser predictParser = subparsers.addParser("predict");
predictParser.addArgument("modelId");
predictParser.addArgument("filePath");
predictParser
.addArgument("translationAllowFallback")
.nargs("?")
.type(Boolean.class)
.setDefault(Boolean.FALSE);

String projectId = System.getenv("PROJECT_ID");
String computeRegion = System.getenv("REGION_NAME");
Expand All @@ -126,12 +115,8 @@ public static void argsHelper(String[] args, PrintStream out) throws IOException
try {
ns = parser.parseArgs(args);
if (ns.get("command").equals("predict")) {
predict(
projectId,
computeRegion,
ns.getString("modelId"),
ns.getString("filePath"),
ns.getBoolean("translationAllowFallback"));
predict(projectId, computeRegion, ns.getString("modelId"), ns.getString("filePath"));

}
} catch (ArgumentParserException e) {
parser.handleError(e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ public void tearDown() {
@Test
public void testPredict() throws Exception {
// Act
PredictionApi.predict(PROJECT_ID, COMPUTE_REGION, modelId, filePath,FALSE);
PredictionApi.predict(PROJECT_ID, COMPUTE_REGION, modelId, filePath);

// Assert
String got = bout.toString();
Expand Down

0 comments on commit 7a973d5

Please sign in to comment.