diff --git a/src/errors/types.rs b/src/errors/types.rs index 4986ac313..c61062727 100644 --- a/src/errors/types.rs +++ b/src/errors/types.rs @@ -336,7 +336,7 @@ error_types! { error: {ctx_type: Cow<'static, str>, ctx_fn: cow_field_from_context}, }, DateFromDatetimeParsing { - error: {ctx_type: String, ctx_fn: field_from_context}, + error: {ctx_type: Cow<'static, str>, ctx_fn: cow_field_from_context}, }, DateFromDatetimeInexact {}, DatePast {}, diff --git a/src/validators/date.rs b/src/validators/date.rs index 53d8a8177..41cd7ce66 100644 --- a/src/validators/date.rs +++ b/src/validators/date.rs @@ -47,16 +47,12 @@ impl Validator for DateValidator { _definitions: &'data Definitions, _recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { - let date = match input.validate_date(extra.strict.unwrap_or(self.strict)) { + let strict = extra.strict.unwrap_or(self.strict); + let date = match input.validate_date(strict) { Ok(date) => date, - // if the date error was an internal error, return that immediately - Err(ValError::InternalErr(internal_err)) => return Err(ValError::InternalErr(internal_err)), - Err(date_err) => match self.strict { - // if we're in strict mode, we doing try coercing from a date - true => return Err(date_err), - // otherwise, try creating a date from a datetime input - false => date_from_datetime(input, date_err), - }?, + // if the error was a parsing error, in lax mode we allow datetimes at midnight + Err(line_errors @ ValError::LineErrors(..)) if !strict => date_from_datetime(input)?.ok_or(line_errors)?, + Err(otherwise) => return Err(otherwise), }; if let Some(constraints) = &self.constraints { let raw_date = date.as_raw()?; @@ -122,35 +118,31 @@ impl Validator for DateValidator { /// In lax mode, if the input is not a date, we try parsing the input as a datetime, then check it is an /// "exact date", e.g. has a zero time component. -fn date_from_datetime<'data>( - input: &'data impl Input<'data>, - date_err: ValError<'data>, -) -> ValResult<'data, EitherDate<'data>> { +/// +/// Ok(None) means that this is not relevant to dates (the input was not a datetime nor a string) +fn date_from_datetime<'data>(input: &'data impl Input<'data>) -> Result>, ValError<'data>> { let either_dt = match input.validate_datetime(false, speedate::MicrosecondsPrecisionOverflowBehavior::Truncate) { Ok(dt) => dt, - Err(dt_err) => { - return match dt_err { - ValError::LineErrors(mut line_errors) => { - // if we got a errors while parsing the datetime, - // convert DateTimeParsing -> DateFromDatetimeParsing but keep the rest of the error unchanged - for line_error in &mut line_errors { - match line_error.error_type { - ErrorType::DatetimeParsing { ref error, .. } => { - line_error.error_type = ErrorType::DateFromDatetimeParsing { - error: error.to_string(), - context: None, - }; - } - _ => { - return Err(date_err); - } - } - } - Err(ValError::LineErrors(line_errors)) + // if the error was a parsing error, update the error type from DatetimeParsing to DateFromDatetimeParsing + // and return it + Err(ValError::LineErrors(mut line_errors)) => { + if line_errors.iter_mut().fold(false, |has_parsing_error, line_error| { + if let ErrorType::DatetimeParsing { error, .. } = &mut line_error.error_type { + line_error.error_type = ErrorType::DateFromDatetimeParsing { + error: std::mem::take(error), + context: None, + }; + true + } else { + has_parsing_error } - other => Err(other), - }; + }) { + return Err(ValError::LineErrors(line_errors)); + } + return Ok(None); } + // for any other error, don't return it + Err(_) => return Ok(None), }; let dt = either_dt.as_raw()?; let zero_time = Time { @@ -161,7 +153,7 @@ fn date_from_datetime<'data>( tz_offset: dt.time.tz_offset, }; if dt.time == zero_time { - Ok(EitherDate::Raw(dt.date)) + Ok(Some(EitherDate::Raw(dt.date))) } else { Err(ValError::new(ErrorTypeDefaults::DateFromDatetimeInexact, input)) } diff --git a/tests/conftest.py b/tests/conftest.py index c17a98174..a5f5cc344 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -91,6 +91,25 @@ class ChosenPyAndJsonValidator(PyAndJsonValidator): return ChosenPyAndJsonValidator +class StrictModeType: + def __init__(self, schema: bool, extra: bool): + assert schema or extra + self.schema = schema + self.validator_args = {'strict': True} if extra else {} + + +@pytest.fixture( + params=[ + StrictModeType(schema=True, extra=False), + StrictModeType(schema=False, extra=True), + StrictModeType(schema=True, extra=True), + ], + ids=['strict-schema', 'strict-extra', 'strict-both'], +) +def strict_mode_type(request) -> StrictModeType: + return request.param + + @pytest.fixture def tmp_work_path(tmp_path: Path): """ diff --git a/tests/validators/test_date.py b/tests/validators/test_date.py index 07e0ee849..616771bf0 100644 --- a/tests/validators/test_date.py +++ b/tests/validators/test_date.py @@ -127,13 +127,13 @@ def test_date_json(py_and_json: PyAndJson, input_value, expected): ], ids=repr, ) -def test_date_strict(input_value, expected): - v = SchemaValidator({'type': 'date', 'strict': True}) +def test_date_strict(input_value, expected, strict_mode_type): + v = SchemaValidator({'type': 'date', 'strict': strict_mode_type.schema}) if isinstance(expected, Err): with pytest.raises(ValidationError, match=re.escape(expected.message)): - v.validate_python(input_value) + v.validate_python(input_value, **strict_mode_type.validator_args) else: - output = v.validate_python(input_value) + output = v.validate_python(input_value, **strict_mode_type.validator_args) assert output == expected @@ -148,13 +148,13 @@ def test_date_strict(input_value, expected): ('1654646400', Err('Input should be a valid date [type=date_type')), ], ) -def test_date_strict_json(input_value, expected): - v = SchemaValidator({'type': 'date', 'strict': True}) +def test_date_strict_json(input_value, expected, strict_mode_type): + v = SchemaValidator({'type': 'date', 'strict': strict_mode_type.schema}) if isinstance(expected, Err): with pytest.raises(ValidationError, match=re.escape(expected.message)): - v.validate_json(input_value) + v.validate_json(input_value, **strict_mode_type.validator_args) else: - output = v.validate_json(input_value) + output = v.validate_json(input_value, **strict_mode_type.validator_args) assert output == expected