Skip to content

Commit

Permalink
add validation for holiday regions
Browse files Browse the repository at this point in the history
  • Loading branch information
TheMichaelHu committed May 24, 2022
1 parent 8babe6d commit c0c1ab9
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 1 deletion.
34 changes: 33 additions & 1 deletion google/cloud/aiplatform/training_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

import datetime
import time
from typing import Dict, List, Optional, Sequence, Tuple, Union
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

import abc

Expand Down Expand Up @@ -4518,6 +4518,7 @@ def _run(
"optimizationObjective": self._optimization_objective,
"holidayRegions": holiday_regions,
}
self._validate_training_task_inputs(training_task_inputs_dict)

final_export_eval_bq_uri = export_evaluated_data_items_bigquery_destination_uri
if final_export_eval_bq_uri and not final_export_eval_bq_uri.startswith(
Expand Down Expand Up @@ -4563,6 +4564,37 @@ def _run(

return new_model

@staticmethod
def _validate_training_task_inputs(training_task_inputs: Dict[str, Any]):
"""Validates the given training task inputs.
Args:
training_task_inputs (Dict[str, Any]):
Required. The training task's input that corresponds to the
training_task_definition parameter.
Raises:
ValueError: If a training task input is invalid.
"""
# TODO(TheMichaelHu): Validate all training task inputs.
if training_task_inputs.get("holidayRegions"):
target_regions = {
"GLOBAL", "NA", "JAPAC", "EMEA", "LAC", "AE", "AR", "AT", "AU",
"BE", "BR", "CA", "CH", "CL", "CN", "CO", "CZ", "DE", "DK",
"DZ", "EC", "EE", "EG", "ES", "FI", "FR", "GB", "GR", "HK",
"HU", "ID", "IE", "IL", "IN", "IR", "IT", "JP", "KR", "LV",
"MA", "MX", "MY", "NG", "NL", "NO", "NZ", "PE", "PH", "PK",
"PL", "PT", "RO", "RS", "RU", "SA", "SE", "SG", "SI", "SK",
"TH", "TR", "TW", "UA", "US", "VE", "VN", "ZA ",
}
for region in training_task_inputs.get("holidayRegions"):
if region.upper() not in target_regions:
raise ValueError(f"Invalid holiday region: {region}.")
if training_task_inputs["dataGranularity"]["unit"].lower() != "day":
raise ValueError(
"Holiday regions are only supported at day-level "
"granularity.")

@property
def _model_upload_fail_string(self) -> str:
"""Helper property for model upload failure."""
Expand Down
17 changes: 17 additions & 0 deletions tests/unit/aiplatform/test_automl_forecasting_training_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1029,3 +1029,20 @@ def test_splits_default(
training_pipeline=true_training_pipeline,
timeout=None,
)

def test_validate_with_invalid_holiday_region_fails(self):
with pytest.raises(ValueError) as e:
AutoMLForecastingTrainingJob._validate_training_task_inputs({
"dataGranularity": {"unit": "day"},
"holidayRegions": ["NEPTUNE"],
})
assert e.value.args[0] == "Invalid holiday region: NEPTUNE."

def test_validate_with_wrong_granularity_fails(self):
with pytest.raises(ValueError) as e:
AutoMLForecastingTrainingJob._validate_training_task_inputs({
"dataGranularity": {"unit": "week"},
"holidayRegions": ["GLOBAL"],
})
assert e.value.args[0] == (
"Holiday regions are only supported at day-level granularity.")

0 comments on commit c0c1ab9

Please sign in to comment.