Skip to content

Commit

Permalink
docs(samples): replace deprecated fields in create_training_pipeline_…
Browse files Browse the repository at this point in the history
…tabular_forecasting_sample.py (#981)

* Update create_training_pipeline_tabular_forecasting_sample.py

* Update create_training_pipeline_tabular_forecasting_sample_test.py

Co-authored-by: Anthonios Partheniou <[email protected]>
Co-authored-by: Karl Weinmeister <[email protected]>
  • Loading branch information
3 people authored Feb 2, 2022
1 parent ea16849 commit 9ebc972
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,10 @@ def create_training_pipeline_tabular_forecasting_sample(
target_column: str,
time_series_identifier_column: str,
time_column: str,
static_columns: str,
time_variant_past_only_columns: str,
time_variant_past_and_future_columns: str,
forecast_window_end: int,
time_series_attribute_columns: str,
unavailable_at_forecast: str,
available_at_forecast: str,
forecast_horizon: int,
location: str = "us-central1",
api_endpoint: str = "us-central1-aiplatform.googleapis.com",
):
Expand All @@ -47,7 +47,7 @@ def create_training_pipeline_tabular_forecasting_sample(
{"auto": {"column_name": "deaths"}},
]

period = {"unit": "day", "quantity": 1}
data_granularity = {"unit": "day", "quantity": 1}

# the inputs should be formatted according to the training_task_definition yaml file
training_task_inputs_dict = {
Expand All @@ -56,13 +56,13 @@ def create_training_pipeline_tabular_forecasting_sample(
"timeSeriesIdentifierColumn": time_series_identifier_column,
"timeColumn": time_column,
"transformations": transformations,
"period": period,
"dataGranularity": data_granularity,
"optimizationObjective": "minimize-rmse",
"trainBudgetMilliNodeHours": 8000,
"staticColumns": static_columns,
"timeVariantPastOnlyColumns": time_variant_past_only_columns,
"timeVariantPastAndFutureColumns": time_variant_past_and_future_columns,
"forecastWindowEnd": forecast_window_end,
"timeSeriesAttributeColumns": time_series_attribute_columns,
"unavailableAtForecast": unavailable_at_forecast,
"availableAtForecast": available_at_forecast,
"forecastHorizon": forecast_horizon,
}

training_task_inputs = json_format.ParseDict(training_task_inputs_dict, Value())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,10 +77,10 @@ def test_ucaip_generated_create_training_pipeline_sample(capsys, shared_state):
target_column=TARGET_COLUMN,
time_series_identifier_column="county",
time_column="date",
static_columns=["state_name"],
time_variant_past_only_columns=["deaths"],
time_variant_past_and_future_columns=["date"],
forecast_window_end=10,
time_series_attribute_columns=["state_name"],
unavailable_at_forecast=["deaths"],
available_at_forecast=["date"],
forecast_horizon=10,
)

out, _ = capsys.readouterr()
Expand Down

0 comments on commit 9ebc972

Please sign in to comment.