Skip to content

Commit

Permalink
feat: Add hierarchy and window configs to Vertex Forecasting training…
Browse files Browse the repository at this point in the history
… job (#1255)

Adds support for hierarchical forecasting and window filtering.

---

Thank you for opening a Pull Request! Before submitting your PR, there are a few things you can do to make sure it goes smoothly:
- [x] Make sure to open an issue as a [bug/issue](https://github.com/googleapis/python-aiplatform/issues/new/choose) before writing your code! That way we can discuss the change, evaluate designs, and agree on the general idea.
- [x] Ensure the tests and linter pass
- [x] Code coverage does not decrease (if any source code was changed)
- [x] Appropriate docs were updated (if necessary)

Fixes b/229907889 b/228499154 🦕
  • Loading branch information
TheMichaelHu authored May 27, 2022
1 parent e82c179 commit 8560fa8
Show file tree
Hide file tree
Showing 2 changed files with 224 additions and 1 deletion.
132 changes: 131 additions & 1 deletion google/cloud/aiplatform/training_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4037,6 +4037,13 @@ def run(
model_display_name: Optional[str] = None,
model_labels: Optional[Dict[str, str]] = None,
additional_experiments: Optional[List[str]] = None,
hierarchy_group_columns: Optional[List[str]] = None,
hierarchy_group_total_weight: Optional[float] = None,
hierarchy_temporal_total_weight: Optional[float] = None,
hierarchy_group_temporal_total_weight: Optional[float] = None,
window_column: Optional[str] = None,
window_stride_length: Optional[int] = None,
window_max_count: Optional[int] = None,
sync: bool = True,
create_request_timeout: Optional[float] = None,
) -> models.Model:
Expand Down Expand Up @@ -4157,7 +4164,7 @@ def run(
Applies only if [export_evaluated_data_items] is True and
[export_evaluated_data_items_bigquery_destination_uri] is specified.
quantiles (List[float]):
Quantiles to use for the `minimize-quantile-loss`
Quantiles to use for the ``minimize-quantile-loss``
[AutoMLForecastingTrainingJob.optimization_objective]. This argument is required in
this case.
Expand Down Expand Up @@ -4200,6 +4207,37 @@ def run(
Optional. Additional experiment flags for the time series forecasting training.
create_request_timeout (float):
Optional. The timeout for the create request in seconds.
hierarchy_group_columns (List[str]):
Optional. A list of time series attribute column names that
define the time series hierarchy. Only one level of hierarchy is
supported, ex. ``region`` for a hierarchy of stores or
``department`` for a hierarchy of products. If multiple columns
are specified, time series will be grouped by their combined
values, ex. (``blue``, ``large``) for ``color`` and ``size``, up
to 5 columns are accepted. If no group columns are specified,
all time series are considered to be part of the same group.
hierarchy_group_total_weight (float):
Optional. The weight of the loss for predictions aggregated over
time series in the same hierarchy group.
hierarchy_temporal_total_weight (float):
Optional. The weight of the loss for predictions aggregated over
the horizon for a single time series.
hierarchy_group_temporal_total_weight (float):
Optional. The weight of the loss for predictions aggregated over
both the horizon and time series in the same hierarchy group.
window_column (str):
Optional. Name of the column that should be used to filter input
rows. The column should contain either booleans or string
booleans; if the value of the row is True, generate a sliding
window from that row.
window_stride_length (int):
Optional. Step length used to generate input examples. Every
``window_stride_length`` rows will be used to generate a sliding
window.
window_max_count (int):
Optional. Number of rows that should be used to generate input
examples. If the total row count is larger than this number, the
input data will be randomly sampled to hit the count.
sync (bool):
Whether to execute this method synchronously. If False, this method
will be executed in concurrent Future and any downstream object will
Expand Down Expand Up @@ -4254,6 +4292,13 @@ def run(
validation_options=validation_options,
model_display_name=model_display_name,
model_labels=model_labels,
hierarchy_group_columns=hierarchy_group_columns,
hierarchy_group_total_weight=hierarchy_group_total_weight,
hierarchy_temporal_total_weight=hierarchy_temporal_total_weight,
hierarchy_group_temporal_total_weight=hierarchy_group_temporal_total_weight,
window_column=window_column,
window_stride_length=window_stride_length,
window_max_count=window_max_count,
sync=sync,
create_request_timeout=create_request_timeout,
)
Expand Down Expand Up @@ -4286,6 +4331,13 @@ def _run(
budget_milli_node_hours: int = 1000,
model_display_name: Optional[str] = None,
model_labels: Optional[Dict[str, str]] = None,
hierarchy_group_columns: Optional[List[str]] = None,
hierarchy_group_total_weight: Optional[float] = None,
hierarchy_temporal_total_weight: Optional[float] = None,
hierarchy_group_temporal_total_weight: Optional[float] = None,
window_column: Optional[str] = None,
window_stride_length: Optional[int] = None,
window_max_count: Optional[int] = None,
sync: bool = True,
create_request_timeout: Optional[float] = None,
) -> models.Model:
Expand Down Expand Up @@ -4453,6 +4505,37 @@ def _run(
are allowed.
See https://goo.gl/xmQnxf for more information
and examples of labels.
hierarchy_group_columns (List[str]):
Optional. A list of time series attribute column names that
define the time series hierarchy. Only one level of hierarchy is
supported, ex. ``region`` for a hierarchy of stores or
``department`` for a hierarchy of products. If multiple columns
are specified, time series will be grouped by their combined
values, ex. (``blue``, ``large``) for ``color`` and ``size``, up
to 5 columns are accepted. If no group columns are specified,
all time series are considered to be part of the same group.
hierarchy_group_total_weight (float):
Optional. The weight of the loss for predictions aggregated over
time series in the same hierarchy group.
hierarchy_temporal_total_weight (float):
Optional. The weight of the loss for predictions aggregated over
the horizon for a single time series.
hierarchy_group_temporal_total_weight (float):
Optional. The weight of the loss for predictions aggregated over
both the horizon and time series in the same hierarchy group.
window_column (str):
Optional. Name of the column that should be used to filter input
rows. The column should contain either booleans or string
booleans; if the value of the row is True, generate a sliding
window from that row.
window_stride_length (int):
Optional. Step length used to generate input examples. Every
``window_stride_length`` rows will be used to generate a sliding
window.
window_max_count (int):
Optional. Number of rows that should be used to generate input
examples. If the total row count is larger than this number, the
input data will be randomly sampled to hit the count.
sync (bool):
Whether to execute this method synchronously. If False, this method
will be executed in concurrent Future and any downstream object will
Expand Down Expand Up @@ -4482,6 +4565,12 @@ def _run(
% column_names
)

window_config = self._create_window_config(
column=window_column,
stride_length=window_stride_length,
max_count=window_max_count,
)

training_task_inputs_dict = {
# required inputs
"targetColumn": target_column,
Expand All @@ -4505,6 +4594,24 @@ def _run(
"optimizationObjective": self._optimization_objective,
}

# TODO(TheMichaelHu): Remove the ifs once the API supports these inputs.
if any(
[
hierarchy_group_columns,
hierarchy_group_total_weight,
hierarchy_temporal_total_weight,
hierarchy_group_temporal_total_weight,
]
):
training_task_inputs_dict["hierarchyConfig"] = {
"groupColumns": hierarchy_group_columns,
"groupTotalWeight": hierarchy_group_total_weight,
"temporalTotalWeight": hierarchy_temporal_total_weight,
"groupTemporalTotalWeight": hierarchy_group_temporal_total_weight,
}
if window_config:
training_task_inputs_dict["windowConfig"] = window_config

final_export_eval_bq_uri = export_evaluated_data_items_bigquery_destination_uri
if final_export_eval_bq_uri and not final_export_eval_bq_uri.startswith(
"bq://"
Expand Down Expand Up @@ -4582,6 +4689,29 @@ def _add_additional_experiments(self, additional_experiments: List[str]):
"""
self._additional_experiments.extend(additional_experiments)

@staticmethod
def _create_window_config(
column: Optional[str] = None,
stride_length: Optional[int] = None,
max_count: Optional[int] = None,
) -> Optional[Dict[str, Union[int, str]]]:
"""Creates a window config from training job arguments."""
configs = {
"column": column,
"strideLength": stride_length,
"maxCount": max_count,
}
present_configs = {k: v for k, v in configs.items() if v is not None}
if not present_configs:
return None
if len(present_configs) > 1:
raise ValueError(
"More than one windowing strategy provided. Make sure only one "
"of window_column, window_stride_length, or window_max_count "
"is specified."
)
return present_configs


class AutoMLImageTrainingJob(_TrainingJob):
_supported_training_schemas = (
Expand Down
Loading

0 comments on commit 8560fa8

Please sign in to comment.