diff --git a/core/src/main/java/feast/core/job/dataflow/DataflowRunnerConfig.java b/core/src/main/java/feast/core/job/dataflow/DataflowRunnerConfig.java index df0cb32538..804d258f46 100644 --- a/core/src/main/java/feast/core/job/dataflow/DataflowRunnerConfig.java +++ b/core/src/main/java/feast/core/job/dataflow/DataflowRunnerConfig.java @@ -42,6 +42,7 @@ public DataflowRunnerConfig(DataflowRunnerConfigOptions runnerConfigOptions) { this.tempLocation = runnerConfigOptions.getTempLocation(); this.maxNumWorkers = runnerConfigOptions.getMaxNumWorkers(); this.deadLetterTableSpec = runnerConfigOptions.getDeadLetterTableSpec(); + this.diskSizeGb = runnerConfigOptions.getDiskSizeGb(); this.labels = runnerConfigOptions.getLabelsMap(); validate(); } @@ -85,6 +86,9 @@ public DataflowRunnerConfig(DataflowRunnerConfigOptions runnerConfigOptions) { /* BigQuery table specification, e.g. PROJECT_ID:DATASET_ID.PROJECT_ID */ public String deadLetterTableSpec; + /* Disk size to use on each remote Compute Engine worker instance */ + public Integer diskSizeGb; + public Map labels; /** Validates Dataflow runner configuration options */ diff --git a/core/src/test/java/feast/core/job/dataflow/DataflowRunnerConfigTest.java b/core/src/test/java/feast/core/job/dataflow/DataflowRunnerConfigTest.java index d5c69af6a7..925e48aec1 100644 --- a/core/src/test/java/feast/core/job/dataflow/DataflowRunnerConfigTest.java +++ b/core/src/test/java/feast/core/job/dataflow/DataflowRunnerConfigTest.java @@ -42,6 +42,7 @@ public void shouldConvertToPipelineArgs() throws IllegalAccessException { .setUsePublicIps(false) .setWorkerMachineType("n1-standard-1") .setDeadLetterTableSpec("project_id:dataset_id.table_id") + .setDiskSizeGb(100) .putLabels("key", "value") .build(); @@ -60,9 +61,46 @@ public void shouldConvertToPipelineArgs() throws IllegalAccessException { "--usePublicIps=false", "--workerMachineType=n1-standard-1", "--deadLetterTableSpec=project_id:dataset_id.table_id", + "--diskSizeGb=100", "--labels={\"key\":\"value\"}") .toArray(String[]::new); assertThat(args.size(), equalTo(expectedArgs.length)); assertThat(args, containsInAnyOrder(expectedArgs)); } + + @Test + public void shouldIgnoreOptionalArguments() throws IllegalAccessException { + DataflowRunnerConfigOptions opts = + DataflowRunnerConfigOptions.newBuilder() + .setProject("my-project") + .setRegion("asia-east1") + .setZone("asia-east1-a") + .setTempLocation("gs://bucket/tempLocation") + .setNetwork("default") + .setSubnetwork("regions/asia-east1/subnetworks/mysubnetwork") + .setMaxNumWorkers(1) + .setAutoscalingAlgorithm("THROUGHPUT_BASED") + .setUsePublicIps(false) + .setWorkerMachineType("n1-standard-1") + .build(); + + DataflowRunnerConfig dataflowRunnerConfig = new DataflowRunnerConfig(opts); + List args = Lists.newArrayList(dataflowRunnerConfig.toArgs()); + String[] expectedArgs = + Arrays.asList( + "--project=my-project", + "--region=asia-east1", + "--zone=asia-east1-a", + "--tempLocation=gs://bucket/tempLocation", + "--network=default", + "--subnetwork=regions/asia-east1/subnetworks/mysubnetwork", + "--maxNumWorkers=1", + "--autoscalingAlgorithm=THROUGHPUT_BASED", + "--usePublicIps=false", + "--workerMachineType=n1-standard-1", + "--labels={}") + .toArray(String[]::new); + assertThat(args.size(), equalTo(expectedArgs.length)); + assertThat(args, containsInAnyOrder(expectedArgs)); + } } diff --git a/protos/feast/core/Runner.proto b/protos/feast/core/Runner.proto index 1591ecb2bf..0684356f8d 100644 --- a/protos/feast/core/Runner.proto +++ b/protos/feast/core/Runner.proto @@ -77,4 +77,8 @@ message DataflowRunnerConfigOptions { /* Labels to apply to the dataflow job */ map labels = 13; + + /* Disk size to use on each remote Compute Engine worker instance */ + int32 diskSizeGb = 14; + } \ No newline at end of file