Skip to content

Commit

Permalink
blacken
Browse files Browse the repository at this point in the history
  • Loading branch information
TheMichaelHu committed May 24, 2022
1 parent c0c1ab9 commit d1e8352
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 18 deletions.
78 changes: 69 additions & 9 deletions google/cloud/aiplatform/training_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4579,21 +4579,81 @@ def _validate_training_task_inputs(training_task_inputs: Dict[str, Any]):
# TODO(TheMichaelHu): Validate all training task inputs.
if training_task_inputs.get("holidayRegions"):
target_regions = {
"GLOBAL", "NA", "JAPAC", "EMEA", "LAC", "AE", "AR", "AT", "AU",
"BE", "BR", "CA", "CH", "CL", "CN", "CO", "CZ", "DE", "DK",
"DZ", "EC", "EE", "EG", "ES", "FI", "FR", "GB", "GR", "HK",
"HU", "ID", "IE", "IL", "IN", "IR", "IT", "JP", "KR", "LV",
"MA", "MX", "MY", "NG", "NL", "NO", "NZ", "PE", "PH", "PK",
"PL", "PT", "RO", "RS", "RU", "SA", "SE", "SG", "SI", "SK",
"TH", "TR", "TW", "UA", "US", "VE", "VN", "ZA ",
"GLOBAL",
"NA",
"JAPAC",
"EMEA",
"LAC",
"AE",
"AR",
"AT",
"AU",
"BE",
"BR",
"CA",
"CH",
"CL",
"CN",
"CO",
"CZ",
"DE",
"DK",
"DZ",
"EC",
"EE",
"EG",
"ES",
"FI",
"FR",
"GB",
"GR",
"HK",
"HU",
"ID",
"IE",
"IL",
"IN",
"IR",
"IT",
"JP",
"KR",
"LV",
"MA",
"MX",
"MY",
"NG",
"NL",
"NO",
"NZ",
"PE",
"PH",
"PK",
"PL",
"PT",
"RO",
"RS",
"RU",
"SA",
"SE",
"SG",
"SI",
"SK",
"TH",
"TR",
"TW",
"UA",
"US",
"VE",
"VN",
"ZA ",
}
for region in training_task_inputs.get("holidayRegions"):
if region.upper() not in target_regions:
raise ValueError(f"Invalid holiday region: {region}.")
if training_task_inputs["dataGranularity"]["unit"].lower() != "day":
raise ValueError(
"Holiday regions are only supported at day-level "
"granularity.")
"Holiday regions are only supported at day-level " "granularity."
)

@property
def _model_upload_fail_string(self) -> str:
Expand Down
23 changes: 14 additions & 9 deletions tests/unit/aiplatform/test_automl_forecasting_training_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1032,17 +1032,22 @@ def test_splits_default(

def test_validate_with_invalid_holiday_region_fails(self):
with pytest.raises(ValueError) as e:
AutoMLForecastingTrainingJob._validate_training_task_inputs({
"dataGranularity": {"unit": "day"},
"holidayRegions": ["NEPTUNE"],
})
AutoMLForecastingTrainingJob._validate_training_task_inputs(
{
"dataGranularity": {"unit": "day"},
"holidayRegions": ["NEPTUNE"],
}
)
assert e.value.args[0] == "Invalid holiday region: NEPTUNE."

def test_validate_with_wrong_granularity_fails(self):
with pytest.raises(ValueError) as e:
AutoMLForecastingTrainingJob._validate_training_task_inputs({
"dataGranularity": {"unit": "week"},
"holidayRegions": ["GLOBAL"],
})
AutoMLForecastingTrainingJob._validate_training_task_inputs(
{
"dataGranularity": {"unit": "week"},
"holidayRegions": ["GLOBAL"],
}
)
assert e.value.args[0] == (
"Holiday regions are only supported at day-level granularity.")
"Holiday regions are only supported at day-level granularity."
)

0 comments on commit d1e8352

Please sign in to comment.