diff --git a/src/input/return_enums.rs b/src/input/return_enums.rs index 35341de82..483b1918c 100644 --- a/src/input/return_enums.rs +++ b/src/input/return_enums.rs @@ -22,9 +22,8 @@ use pyo3::PyTypeInfo; use serde::{ser::Error, Serialize, Serializer}; use crate::errors::{py_err_string, ErrorType, ErrorTypeDefaults, InputValue, ValError, ValLineError, ValResult}; -use crate::recursion_guard::RecursionGuard; use crate::tools::py_err; -use crate::validators::{CombinedValidator, Extra, Validator}; +use crate::validators::{CombinedValidator, ValidationState, Validator}; use super::parse_json::{JsonArray, JsonInput, JsonObject}; use super::{py_error_on_minusone, Input}; @@ -157,15 +156,13 @@ fn validate_iter_to_vec<'a, 's>( capacity: usize, mut max_length_check: MaxLengthCheck<'a, impl Input<'a>>, validator: &'s CombinedValidator, - extra: &Extra, - definitions: &'a [CombinedValidator], - recursion_guard: &'s mut RecursionGuard, + state: &mut ValidationState, ) -> ValResult<'a, Vec> { let mut output: Vec = Vec::with_capacity(capacity); let mut errors: Vec = Vec::new(); for (index, item_result) in iter.enumerate() { let item = item_result.map_err(|e| any_next_error!(py, e, max_length_check.input, index))?; - match validator.validate(py, item, extra, definitions, recursion_guard) { + match validator.validate(py, item, state) { Ok(item) => { max_length_check.incr()?; output.push(item); @@ -226,14 +223,12 @@ fn validate_iter_to_set<'a, 's>( field_type: &'static str, max_length: Option, validator: &'s CombinedValidator, - extra: &Extra, - definitions: &'a [CombinedValidator], - recursion_guard: &'s mut RecursionGuard, + state: &mut ValidationState, ) -> ValResult<'a, ()> { let mut errors: Vec = Vec::new(); for (index, item_result) in iter.enumerate() { let item = item_result.map_err(|e| any_next_error!(py, e, input, index))?; - match validator.validate(py, item, extra, definitions, recursion_guard) { + match validator.validate(py, item, state) { Ok(item) => { set.build_add(item)?; if let Some(max_length) = max_length { @@ -315,9 +310,7 @@ impl<'a> GenericIterable<'a> { max_length: Option, field_type: &'static str, validator: &'s CombinedValidator, - extra: &Extra, - definitions: &'a [CombinedValidator], - recursion_guard: &'s mut RecursionGuard, + state: &mut ValidationState, ) -> ValResult<'a, Vec> { let capacity = self .generic_len() @@ -326,16 +319,7 @@ impl<'a> GenericIterable<'a> { macro_rules! validate { ($iter:expr) => { - validate_iter_to_vec( - py, - $iter, - capacity, - max_length_check, - validator, - extra, - definitions, - recursion_guard, - ) + validate_iter_to_vec(py, $iter, capacity, max_length_check, validator, state) }; } @@ -360,24 +344,11 @@ impl<'a> GenericIterable<'a> { max_length: Option, field_type: &'static str, validator: &'s CombinedValidator, - extra: &Extra, - definitions: &'a [CombinedValidator], - recursion_guard: &'s mut RecursionGuard, + state: &mut ValidationState, ) -> ValResult<'a, ()> { macro_rules! validate_set { ($iter:expr) => { - validate_iter_to_set( - py, - set, - $iter, - input, - field_type, - max_length, - validator, - extra, - definitions, - recursion_guard, - ) + validate_iter_to_set(py, set, $iter, input, field_type, max_length, validator, state) }; } diff --git a/src/validators/any.rs b/src/validators/any.rs index 7f452ff20..3f10ee34d 100644 --- a/src/validators/any.rs +++ b/src/validators/any.rs @@ -4,9 +4,7 @@ use pyo3::types::PyDict; use crate::errors::ValResult; use crate::input::Input; -use crate::recursion_guard::RecursionGuard; - -use super::{BuildValidator, CombinedValidator, Definitions, DefinitionsBuilder, Extra, Validator}; +use super::{validation_state::ValidationState, BuildValidator, CombinedValidator, DefinitionsBuilder, Validator}; /// This might seem useless, but it's useful in DictValidator to avoid Option a lot #[derive(Debug, Clone)] @@ -31,11 +29,8 @@ impl Validator for AnyValidator { &'s self, py: Python<'data>, input: &'data impl Input<'data>, - _extra: &Extra, - _definitions: &'data Definitions, - _recursion_guard: &'s mut RecursionGuard, + _state: &mut ValidationState, ) -> ValResult<'data, PyObject> { - // Ok(input.clone().into_py(py)) Ok(input.to_object(py)) } diff --git a/src/validators/arguments.rs b/src/validators/arguments.rs index 6496cfe8a..03b2443a7 100644 --- a/src/validators/arguments.rs +++ b/src/validators/arguments.rs @@ -10,10 +10,10 @@ use crate::errors::{ErrorTypeDefaults, ValError, ValLineError, ValResult}; use crate::input::{GenericArguments, Input}; use crate::lookup_key::LookupKey; -use crate::recursion_guard::RecursionGuard; use crate::tools::SchemaDict; -use super::{build_validator, BuildValidator, CombinedValidator, Definitions, DefinitionsBuilder, Extra, Validator}; +use super::validation_state::ValidationState; +use super::{build_validator, BuildValidator, CombinedValidator, DefinitionsBuilder, Validator}; #[derive(Debug, Clone)] struct Parameter { @@ -165,9 +165,7 @@ impl Validator for ArgumentsValidator { &'s self, py: Python<'data>, input: &'data impl Input<'data>, - extra: &Extra, - definitions: &'data Definitions, - recursion_guard: &'s mut RecursionGuard, + state: &mut ValidationState, ) -> ValResult<'data, PyObject> { let args = input.validate_args()?; @@ -205,9 +203,7 @@ impl Validator for ArgumentsValidator { )); } (Some(pos_value), None) => { - match parameter - .validator - .validate(py, pos_value, extra, definitions, recursion_guard) + match parameter.validator.validate(py, pos_value, state) { Ok(value) => output_args.push(value), Err(ValError::LineErrors(line_errors)) => { @@ -217,9 +213,7 @@ impl Validator for ArgumentsValidator { } } (None, Some((lookup_path, kw_value))) => { - match parameter - .validator - .validate(py, kw_value, extra, definitions, recursion_guard) + match parameter.validator.validate(py, kw_value, state) { Ok(value) => output_kwargs.set_item(parameter.kwarg_key.as_ref().unwrap(), value)?, Err(ValError::LineErrors(line_errors)) => { @@ -231,7 +225,7 @@ impl Validator for ArgumentsValidator { } } (None, None) => { - if let Some(value) = parameter.validator.default_value(py, Some(parameter.name.as_str()), extra, definitions, recursion_guard)? { + if let Some(value) = parameter.validator.default_value(py, Some(parameter.name.as_str()), state)? { if let Some(ref kwarg_key) = parameter.kwarg_key { output_kwargs.set_item(kwarg_key, value)?; } else { @@ -261,7 +255,7 @@ impl Validator for ArgumentsValidator { if len > self.positional_params_count { if let Some(ref validator) = self.var_args_validator { for (index, item) in $slice_macro!(args, self.positional_params_count, len).iter().enumerate() { - match validator.validate(py, item, extra, definitions, recursion_guard) { + match validator.validate(py, item, state) { Ok(value) => output_args.push(value), Err(ValError::LineErrors(line_errors)) => { errors.extend( @@ -303,7 +297,7 @@ impl Validator for ArgumentsValidator { }; if !used_kwargs.contains(either_str.as_cow()?.as_ref()) { match self.var_kwargs_validator { - Some(ref validator) => match validator.validate(py, value, extra, definitions, recursion_guard) { + Some(ref validator) => match validator.validate(py, value, state) { Ok(value) => output_kwargs.set_item(either_str.as_py_string(py), value)?, Err(ValError::LineErrors(line_errors)) => { for err in line_errors { diff --git a/src/validators/bool.rs b/src/validators/bool.rs index 137562350..700117feb 100644 --- a/src/validators/bool.rs +++ b/src/validators/bool.rs @@ -5,9 +5,7 @@ use crate::build_tools::is_strict; use crate::errors::ValResult; use crate::input::Input; -use crate::recursion_guard::RecursionGuard; - -use super::{BuildValidator, CombinedValidator, Definitions, DefinitionsBuilder, Extra, Validator}; +use super::{BuildValidator, CombinedValidator, DefinitionsBuilder, ValidationState, Validator}; #[derive(Debug, Clone)] pub struct BoolValidator { @@ -36,13 +34,12 @@ impl Validator for BoolValidator { &'s self, py: Python<'data>, input: &'data impl Input<'data>, - extra: &Extra, - _definitions: &'data Definitions, - _recursion_guard: &'s mut RecursionGuard, + state: &mut ValidationState, ) -> ValResult<'data, PyObject> { // TODO in theory this could be quicker if we used PyBool rather than going to a bool // and back again, might be worth profiling? - Ok(input.validate_bool(extra.strict.unwrap_or(self.strict))?.into_py(py)) + let strict = state.strict_or(self.strict); + Ok(input.validate_bool(strict)?.into_py(py)) } fn different_strict_behavior( diff --git a/src/validators/bytes.rs b/src/validators/bytes.rs index 454d06020..f44f0e08a 100644 --- a/src/validators/bytes.rs +++ b/src/validators/bytes.rs @@ -6,10 +6,9 @@ use crate::build_tools::is_strict; use crate::errors::{ErrorType, ValError, ValResult}; use crate::input::Input; -use crate::recursion_guard::RecursionGuard; use crate::tools::SchemaDict; -use super::{BuildValidator, CombinedValidator, Definitions, DefinitionsBuilder, Extra, Validator}; +use super::{BuildValidator, CombinedValidator, DefinitionsBuilder, ValidationState, Validator}; #[derive(Debug, Clone)] pub struct BytesValidator { @@ -45,11 +44,9 @@ impl Validator for BytesValidator { &'s self, py: Python<'data>, input: &'data impl Input<'data>, - extra: &Extra, - _definitions: &'data Definitions, - _recursion_guard: &'s mut RecursionGuard, + state: &mut ValidationState, ) -> ValResult<'data, PyObject> { - let either_bytes = input.validate_bytes(extra.strict.unwrap_or(self.strict))?; + let either_bytes = input.validate_bytes(state.strict_or(self.strict))?; Ok(either_bytes.into_py(py)) } @@ -84,11 +81,9 @@ impl Validator for BytesConstrainedValidator { &'s self, py: Python<'data>, input: &'data impl Input<'data>, - extra: &Extra, - _definitions: &'data Definitions, - _recursion_guard: &'s mut RecursionGuard, + state: &mut ValidationState, ) -> ValResult<'data, PyObject> { - let either_bytes = input.validate_bytes(extra.strict.unwrap_or(self.strict))?; + let either_bytes = input.validate_bytes(state.strict_or(self.strict))?; let len = either_bytes.len()?; if let Some(min_length) = self.min_length { diff --git a/src/validators/call.rs b/src/validators/call.rs index 6e023ea66..940a6e1d9 100644 --- a/src/validators/call.rs +++ b/src/validators/call.rs @@ -6,10 +6,10 @@ use pyo3::types::{PyDict, PyTuple}; use crate::errors::ValResult; use crate::input::Input; -use crate::recursion_guard::RecursionGuard; use crate::tools::SchemaDict; -use super::{build_validator, BuildValidator, CombinedValidator, Definitions, DefinitionsBuilder, Extra, Validator}; +use super::validation_state::ValidationState; +use super::{build_validator, BuildValidator, CombinedValidator, DefinitionsBuilder, Validator}; #[derive(Debug, Clone)] pub struct CallValidator { @@ -76,13 +76,9 @@ impl Validator for CallValidator { &'s self, py: Python<'data>, input: &'data impl Input<'data>, - extra: &Extra, - definitions: &'data Definitions, - recursion_guard: &'s mut RecursionGuard, + state: &mut ValidationState, ) -> ValResult<'data, PyObject> { - let args = self - .arguments_validator - .validate(py, input, extra, definitions, recursion_guard)?; + let args = self.arguments_validator.validate(py, input, state)?; let return_value = if let Ok((args, kwargs)) = args.extract::<(&PyTuple, &PyDict)>(py) { self.function.call(py, args, Some(kwargs))? @@ -95,7 +91,7 @@ impl Validator for CallValidator { if let Some(return_validator) = &self.return_validator { return_validator - .validate(py, return_value.into_ref(py), extra, definitions, recursion_guard) + .validate(py, return_value.into_ref(py), state) .map_err(|e| e.with_outer_location("return".into())) } else { Ok(return_value.to_object(py)) diff --git a/src/validators/callable.rs b/src/validators/callable.rs index dc57612f5..3793ddc4e 100644 --- a/src/validators/callable.rs +++ b/src/validators/callable.rs @@ -4,9 +4,7 @@ use pyo3::types::PyDict; use crate::errors::{ErrorTypeDefaults, ValError, ValResult}; use crate::input::Input; -use crate::recursion_guard::RecursionGuard; - -use super::{BuildValidator, CombinedValidator, Definitions, DefinitionsBuilder, Extra, Validator}; +use super::{BuildValidator, CombinedValidator, DefinitionsBuilder, ValidationState, Validator}; #[derive(Debug, Clone)] pub struct CallableValidator; @@ -30,9 +28,7 @@ impl Validator for CallableValidator { &'s self, py: Python<'data>, input: &'data impl Input<'data>, - _extra: &Extra, - _definitions: &'data Definitions, - _recursion_guard: &'s mut RecursionGuard, + _state: &mut ValidationState, ) -> ValResult<'data, PyObject> { match input.callable() { true => Ok(input.to_object(py)), diff --git a/src/validators/chain.rs b/src/validators/chain.rs index bc08cd98b..55852834a 100644 --- a/src/validators/chain.rs +++ b/src/validators/chain.rs @@ -5,10 +5,10 @@ use pyo3::types::{PyDict, PyList}; use crate::build_tools::py_schema_err; use crate::errors::ValResult; use crate::input::Input; -use crate::recursion_guard::RecursionGuard; use crate::tools::SchemaDict; -use super::{build_validator, BuildValidator, CombinedValidator, Definitions, DefinitionsBuilder, Extra, Validator}; +use super::validation_state::ValidationState; +use super::{build_validator, BuildValidator, CombinedValidator, DefinitionsBuilder, Validator}; #[derive(Debug, Clone)] pub struct ChainValidator { @@ -74,17 +74,13 @@ impl Validator for ChainValidator { &'s self, py: Python<'data>, input: &'data impl Input<'data>, - extra: &Extra, - definitions: &'data Definitions, - recursion_guard: &'s mut RecursionGuard, + state: &mut ValidationState, ) -> ValResult<'data, PyObject> { let mut steps_iter = self.steps.iter(); let first_step = steps_iter.next().unwrap(); - let value = first_step.validate(py, input, extra, definitions, recursion_guard)?; + let value = first_step.validate(py, input, state)?; - steps_iter.try_fold(value, |v, step| { - step.validate(py, v.into_ref(py), extra, definitions, recursion_guard) - }) + steps_iter.try_fold(value, |v, step| step.validate(py, v.into_ref(py), state)) } fn different_strict_behavior( diff --git a/src/validators/custom_error.rs b/src/validators/custom_error.rs index d61c324c2..2bde6bdb7 100644 --- a/src/validators/custom_error.rs +++ b/src/validators/custom_error.rs @@ -5,10 +5,10 @@ use pyo3::types::PyDict; use crate::build_tools::py_schema_err; use crate::errors::{ErrorType, PydanticCustomError, PydanticKnownError, ValError, ValResult}; use crate::input::Input; -use crate::recursion_guard::RecursionGuard; use crate::tools::SchemaDict; -use super::{build_validator, BuildValidator, CombinedValidator, Definitions, DefinitionsBuilder, Extra, Validator}; +use super::validation_state::ValidationState; +use super::{build_validator, BuildValidator, CombinedValidator, DefinitionsBuilder, Validator}; #[derive(Debug, Clone)] pub enum CustomError { @@ -92,12 +92,10 @@ impl Validator for CustomErrorValidator { &'s self, py: Python<'data>, input: &'data impl Input<'data>, - extra: &Extra, - definitions: &'data Definitions, - recursion_guard: &'s mut RecursionGuard, + state: &mut ValidationState, ) -> ValResult<'data, PyObject> { self.validator - .validate(py, input, extra, definitions, recursion_guard) + .validate(py, input, state) .map_err(|_| self.custom_error.as_val_error(input)) } diff --git a/src/validators/dataclass.rs b/src/validators/dataclass.rs index 3f1cd7e4e..8d7ee4183 100644 --- a/src/validators/dataclass.rs +++ b/src/validators/dataclass.rs @@ -10,13 +10,14 @@ use crate::build_tools::{is_strict, schema_or_config_same, ExtraBehavior}; use crate::errors::{ErrorType, ErrorTypeDefaults, ValError, ValLineError, ValResult}; use crate::input::{GenericArguments, Input}; use crate::lookup_key::LookupKey; -use crate::recursion_guard::RecursionGuard; use crate::tools::SchemaDict; use crate::validators::function::convert_err; use super::arguments::{json_get, json_slice, py_get, py_slice}; use super::model::{create_class, force_setattr, Revalidate}; -use super::{build_validator, BuildValidator, CombinedValidator, Definitions, DefinitionsBuilder, Extra, Validator}; +use super::{ + build_validator, BuildValidator, CombinedValidator, DefinitionsBuilder, Extra, ValidationState, Validator, +}; #[derive(Debug, Clone)] struct Field { @@ -132,9 +133,7 @@ impl Validator for DataclassArgsValidator { &'s self, py: Python<'data>, input: &'data impl Input<'data>, - extra: &Extra, - definitions: &'data Definitions, - recursion_guard: &'s mut RecursionGuard, + state: &mut ValidationState, ) -> ValResult<'data, PyObject> { let args = input.validate_dataclass_args(&self.dataclass_name)?; @@ -144,162 +143,157 @@ impl Validator for DataclassArgsValidator { let mut errors: Vec = Vec::new(); let mut used_keys: AHashSet<&str> = AHashSet::with_capacity(self.fields.len()); - let extra = Extra { - data: Some(output_dict), - ..*extra - }; - - macro_rules! set_item { - ($field:ident, $value:expr) => {{ - let py_name = $field.py_name.as_ref(py); - if $field.init_only { - if let Some(ref mut init_only_args) = init_only_args { - init_only_args.push($value); - } - } else { - output_dict.set_item(py_name, $value)?; + state.with_new_extra( + Extra { + data: Some(output_dict), + ..*state.extra() + }, + |state| { + macro_rules! set_item { + ($field:ident, $value:expr) => {{ + let py_name = $field.py_name.as_ref(py); + if $field.init_only { + if let Some(ref mut init_only_args) = init_only_args { + init_only_args.push($value); + } + } else { + output_dict.set_item(py_name, $value)?; + } + }}; } - }}; - } - macro_rules! process { - ($args:ident, $get_method:ident, $get_macro:ident, $slice_macro:ident) => {{ - // go through fields getting the value from args or kwargs and validating it - for (index, field) in self.fields.iter().enumerate() { - let mut pos_value = None; - if let Some(args) = $args.args { - if !field.kw_only { - pos_value = $get_macro!(args, index); - } - } + macro_rules! process { + ($args:ident, $get_method:ident, $get_macro:ident, $slice_macro:ident) => {{ + // go through fields getting the value from args or kwargs and validating it + for (index, field) in self.fields.iter().enumerate() { + let mut pos_value = None; + if let Some(args) = $args.args { + if !field.kw_only { + pos_value = $get_macro!(args, index); + } + } - let mut kw_value = None; - if let Some(kwargs) = $args.kwargs { - if let Some((lookup_path, value)) = field.lookup_key.$get_method(kwargs)? { - used_keys.insert(lookup_path.first_key()); - kw_value = Some((lookup_path, value)); - } - } - - match (pos_value, kw_value) { - // found both positional and keyword arguments, error - (Some(_), Some((_, kw_value))) => { - errors.push(ValLineError::new_with_loc( - ErrorTypeDefaults::MultipleArgumentValues, - kw_value, - field.name.clone(), - )); - } - // found a positional argument, validate it - (Some(pos_value), None) => { - match field - .validator - .validate(py, pos_value, &extra, definitions, recursion_guard) - { - Ok(value) => set_item!(field, value), - Err(ValError::LineErrors(line_errors)) => { - errors.extend( - line_errors - .into_iter() - .map(|err| err.with_outer_location(index.into())), - ); + let mut kw_value = None; + if let Some(kwargs) = $args.kwargs { + if let Some((lookup_path, value)) = field.lookup_key.$get_method(kwargs)? { + used_keys.insert(lookup_path.first_key()); + kw_value = Some((lookup_path, value)); } - Err(err) => return Err(err), } - } - // found a keyword argument, validate it - (None, Some((lookup_path, kw_value))) => { - match field - .validator - .validate(py, kw_value, &extra, definitions, recursion_guard) - { - Ok(value) => set_item!(field, value), - Err(ValError::LineErrors(line_errors)) => { - errors.extend( - line_errors.into_iter().map(|err| { - lookup_path.apply_error_loc(err, self.loc_by_alias, &field.name) - }), - ); + + match (pos_value, kw_value) { + // found both positional and keyword arguments, error + (Some(_), Some((_, kw_value))) => { + errors.push(ValLineError::new_with_loc( + ErrorTypeDefaults::MultipleArgumentValues, + kw_value, + field.name.clone(), + )); + } + // found a positional argument, validate it + (Some(pos_value), None) => match field.validator.validate(py, pos_value, state) { + Ok(value) => set_item!(field, value), + Err(ValError::LineErrors(line_errors)) => { + errors.extend( + line_errors + .into_iter() + .map(|err| err.with_outer_location(index.into())), + ); + } + Err(err) => return Err(err), + }, + // found a keyword argument, validate it + (None, Some((lookup_path, kw_value))) => { + match field.validator.validate(py, kw_value, state) { + Ok(value) => set_item!(field, value), + Err(ValError::LineErrors(line_errors)) => { + errors.extend(line_errors.into_iter().map(|err| { + lookup_path.apply_error_loc(err, self.loc_by_alias, &field.name) + })); + } + Err(err) => return Err(err), + } + } + // found neither, check if there is a default value, otherwise error + (None, None) => { + if let Some(value) = + field + .validator + .default_value(py, Some(field.name.as_str()), state)? + { + set_item!(field, value); + } else { + errors.push(field.lookup_key.error( + ErrorTypeDefaults::Missing, + input, + self.loc_by_alias, + &field.name, + )); + } } - Err(err) => return Err(err), } } - // found neither, check if there is a default value, otherwise error - (None, None) => { - if let Some(value) = field.validator.default_value( - py, - Some(field.name.as_str()), - &extra, - definitions, - recursion_guard, - )? { - set_item!(field, value); - } else { - errors.push(field.lookup_key.error( - ErrorTypeDefaults::Missing, - input, - self.loc_by_alias, - &field.name, - )); + // if there are more args than positional_count, add an error for each one + if let Some(args) = $args.args { + let len = args.len(); + if len > self.positional_count { + for (index, item) in $slice_macro!(args, self.positional_count, len) + .iter() + .enumerate() + { + errors.push(ValLineError::new_with_loc( + ErrorTypeDefaults::UnexpectedPositionalArgument, + item, + index + self.positional_count, + )); + } } } - } - } - // if there are more args than positional_count, add an error for each one - if let Some(args) = $args.args { - let len = args.len(); - if len > self.positional_count { - for (index, item) in $slice_macro!(args, self.positional_count, len).iter().enumerate() { - errors.push(ValLineError::new_with_loc( - ErrorTypeDefaults::UnexpectedPositionalArgument, - item, - index + self.positional_count, - )); - } - } - } - // if there are kwargs check any that haven't been processed yet - if let Some(kwargs) = $args.kwargs { - if kwargs.len() != used_keys.len() { - for (raw_key, value) in kwargs.iter() { - match raw_key.strict_str() { - Ok(either_str) => { - if !used_keys.contains(either_str.as_cow()?.as_ref()) { - // Unknown / extra field - match self.extra_behavior { - ExtraBehavior::Forbid => { - errors.push(ValLineError::new_with_loc( - ErrorTypeDefaults::UnexpectedKeywordArgument, - value, - raw_key.as_loc_item(), - )); + // if there are kwargs check any that haven't been processed yet + if let Some(kwargs) = $args.kwargs { + if kwargs.len() != used_keys.len() { + for (raw_key, value) in kwargs.iter() { + match raw_key.strict_str() { + Ok(either_str) => { + if !used_keys.contains(either_str.as_cow()?.as_ref()) { + // Unknown / extra field + match self.extra_behavior { + ExtraBehavior::Forbid => { + errors.push(ValLineError::new_with_loc( + ErrorTypeDefaults::UnexpectedKeywordArgument, + value, + raw_key.as_loc_item(), + )); + } + ExtraBehavior::Ignore => {} + ExtraBehavior::Allow => { + output_dict.set_item(either_str.as_py_string(py), value)? + } + } } - ExtraBehavior::Ignore => {} - ExtraBehavior::Allow => { - output_dict.set_item(either_str.as_py_string(py), value)? + } + Err(ValError::LineErrors(line_errors)) => { + for err in line_errors { + errors.push( + err.with_outer_location(raw_key.as_loc_item()) + .with_type(ErrorTypeDefaults::InvalidKey), + ); } } + Err(err) => return Err(err), } } - Err(ValError::LineErrors(line_errors)) => { - for err in line_errors { - errors.push( - err.with_outer_location(raw_key.as_loc_item()) - .with_type(ErrorTypeDefaults::InvalidKey), - ); - } - } - Err(err) => return Err(err), } } - } + }}; } - }}; - } - match args { - GenericArguments::Py(a) => process!(a, py_get_dict_item, py_get, py_slice), - GenericArguments::Json(a) => process!(a, json_get, json_get, json_slice), - } + match args { + GenericArguments::Py(a) => process!(a, py_get_dict_item, py_get, py_slice), + GenericArguments::Json(a) => process!(a, json_get, json_get, json_slice), + } + Ok(()) + }, + )?; if errors.is_empty() { if let Some(init_only_args) = init_only_args { Ok((output_dict, PyTuple::new(py, init_only_args)).to_object(py)) @@ -317,9 +311,7 @@ impl Validator for DataclassArgsValidator { obj: &'data PyAny, field_name: &'data str, field_value: &'data PyAny, - extra: &Extra, - definitions: &'data Definitions, - recursion_guard: &'s mut RecursionGuard, + state: &mut ValidationState, ) -> ValResult<'data, PyObject> { let dict: &PyDict = obj.downcast()?; @@ -348,14 +340,13 @@ impl Validator for DataclassArgsValidator { return Err(err.into()); } } - let next_extra = Extra { - data: Some(data_dict), - ..*extra - }; - match field - .validator - .validate(py, field_value, &next_extra, definitions, recursion_guard) - { + match state.with_new_extra( + Extra { + data: Some(data_dict), + ..*state.extra() + }, + |state| field.validator.validate(py, field_value, state), + ) { Ok(output) => ok(output), Err(ValError::LineErrors(line_errors)) => { let errors = line_errors @@ -479,13 +470,11 @@ impl Validator for DataclassValidator { &'s self, py: Python<'data>, input: &'data impl Input<'data>, - extra: &Extra, - definitions: &'data Definitions, - recursion_guard: &'s mut RecursionGuard, + state: &mut ValidationState, ) -> ValResult<'data, PyObject> { - if let Some(self_instance) = extra.self_instance { + if let Some(self_instance) = state.extra().self_instance { // in the case that self_instance is Some, we're calling validation from within `BaseModel.__init__` - return self.validate_init(py, self_instance, input, extra, definitions, recursion_guard); + return self.validate_init(py, self_instance, input, state); } // same logic as on models @@ -493,16 +482,14 @@ impl Validator for DataclassValidator { if let Some(py_input) = input.input_is_instance(class) { if self.revalidate.should_revalidate(py_input, class) { let input_dict: &PyAny = self.dataclass_to_dict(py, py_input)?; - let val_output = self - .validator - .validate(py, input_dict, extra, definitions, recursion_guard)?; + let val_output = self.validator.validate(py, input_dict, state)?; let dc = create_class(self.class.as_ref(py))?; self.set_dict_call(py, dc.as_ref(py), val_output, input)?; Ok(dc) } else { Ok(input.to_object(py)) } - } else if extra.strict.unwrap_or(self.strict) && input.is_python() { + } else if state.strict_or(self.strict) && input.is_python() { Err(ValError::new( ErrorType::DataclassExactType { class_name: self.get_name().to_string(), @@ -511,9 +498,7 @@ impl Validator for DataclassValidator { input, )) } else { - let val_output = self - .validator - .validate(py, input, extra, definitions, recursion_guard)?; + let val_output = self.validator.validate(py, input, state)?; let dc = create_class(self.class.as_ref(py))?; self.set_dict_call(py, dc.as_ref(py), val_output, input)?; Ok(dc) @@ -526,9 +511,7 @@ impl Validator for DataclassValidator { obj: &'data PyAny, field_name: &'data str, field_value: &'data PyAny, - extra: &Extra, - definitions: &'data Definitions, - recursion_guard: &'s mut RecursionGuard, + state: &mut ValidationState, ) -> ValResult<'data, PyObject> { if self.frozen { return Err(ValError::new(ErrorTypeDefaults::FrozenInstance, field_value)); @@ -538,15 +521,9 @@ impl Validator for DataclassValidator { new_dict.set_item(field_name, field_value)?; - let val_assignment_result = self.validator.validate_assignment( - py, - new_dict, - field_name, - field_value, - extra, - definitions, - recursion_guard, - )?; + let val_assignment_result = self + .validator + .validate_assignment(py, new_dict, field_name, field_value, state)?; let (dc_dict, _): (&PyDict, PyObject) = val_assignment_result.extract(py)?; @@ -590,19 +567,17 @@ impl DataclassValidator { py: Python<'data>, self_instance: &'s PyAny, input: &'data impl Input<'data>, - extra: &Extra, - definitions: &'data Definitions, - recursion_guard: &'s mut RecursionGuard, + state: &mut ValidationState, ) -> ValResult<'data, PyObject> { // we need to set `self_instance` to None for nested validators as we don't want to operate on the self_instance // instance anymore - let new_extra = Extra { - self_instance: None, - ..*extra - }; - let val_output = self - .validator - .validate(py, input, &new_extra, definitions, recursion_guard)?; + let val_output = state.with_new_extra( + Extra { + self_instance: None, + ..*state.extra() + }, + |state| self.validator.validate(py, input, state), + )?; self.set_dict_call(py, self_instance, val_output, input)?; diff --git a/src/validators/date.rs b/src/validators/date.rs index 41cd7ce66..ad64c6b47 100644 --- a/src/validators/date.rs +++ b/src/validators/date.rs @@ -8,11 +8,10 @@ use crate::build_tools::{is_strict, py_schema_error_type}; use crate::errors::{ErrorType, ErrorTypeDefaults, ValError, ValResult}; use crate::input::{EitherDate, Input}; -use crate::recursion_guard::RecursionGuard; use crate::tools::SchemaDict; use crate::validators::datetime::{NowConstraint, NowOp}; -use super::{BuildValidator, CombinedValidator, Definitions, DefinitionsBuilder, Extra, Validator}; +use super::{BuildValidator, CombinedValidator, DefinitionsBuilder, ValidationState, Validator}; #[derive(Debug, Clone)] pub struct DateValidator { @@ -43,11 +42,9 @@ impl Validator for DateValidator { &'s self, py: Python<'data>, input: &'data impl Input<'data>, - extra: &Extra, - _definitions: &'data Definitions, - _recursion_guard: &'s mut RecursionGuard, + state: &mut ValidationState, ) -> ValResult<'data, PyObject> { - let strict = extra.strict.unwrap_or(self.strict); + let strict = state.strict_or(self.strict); let date = match input.validate_date(strict) { Ok(date) => date, // if the error was a parsing error, in lax mode we allow datetimes at midnight diff --git a/src/validators/datetime.rs b/src/validators/datetime.rs index 56b380cf5..4d6b50736 100644 --- a/src/validators/datetime.rs +++ b/src/validators/datetime.rs @@ -11,10 +11,9 @@ use crate::build_tools::{py_schema_err, schema_or_config_same}; use crate::errors::{py_err_string, ErrorType, ErrorTypeDefaults, ValError, ValResult}; use crate::input::{EitherDateTime, Input}; -use crate::recursion_guard::RecursionGuard; use crate::tools::SchemaDict; -use super::{BuildValidator, CombinedValidator, Definitions, DefinitionsBuilder, Extra, Validator}; +use super::{BuildValidator, CombinedValidator, DefinitionsBuilder, ValidationState, Validator}; #[derive(Debug, Clone)] pub struct DateTimeValidator { @@ -63,11 +62,10 @@ impl Validator for DateTimeValidator { &'s self, py: Python<'data>, input: &'data impl Input<'data>, - extra: &Extra, - _definitions: &'data Definitions, - _recursion_guard: &'s mut RecursionGuard, + state: &mut ValidationState, ) -> ValResult<'data, PyObject> { - let datetime = input.validate_datetime(extra.strict.unwrap_or(self.strict), self.microseconds_precision)?; + let strict = state.strict_or(self.strict); + let datetime = input.validate_datetime(strict, self.microseconds_precision)?; if let Some(constraints) = &self.constraints { // if we get an error from as_speedate, it's probably because the input datetime was invalid // specifically had an invalid tzinfo, hence here we return a validation error diff --git a/src/validators/decimal.rs b/src/validators/decimal.rs index 4ca15b7c1..ff31687fd 100644 --- a/src/validators/decimal.rs +++ b/src/validators/decimal.rs @@ -9,10 +9,9 @@ use crate::errors::ValResult; use crate::errors::{ErrorType, InputValue}; use crate::errors::{ErrorTypeDefaults, Number}; use crate::input::Input; -use crate::recursion_guard::RecursionGuard; use crate::tools::SchemaDict; -use super::{BuildValidator, CombinedValidator, Definitions, DefinitionsBuilder, Extra, Validator}; +use super::{BuildValidator, CombinedValidator, DefinitionsBuilder, ValidationState, Validator}; #[derive(Debug, Clone)] pub struct DecimalValidator { @@ -79,12 +78,10 @@ impl Validator for DecimalValidator { &'s self, py: Python<'data>, input: &'data impl Input<'data>, - extra: &Extra, - _definitions: &'data Definitions, - _recursion_guard: &'s mut RecursionGuard, + state: &mut ValidationState, ) -> ValResult<'data, PyObject> { let decimal = input.validate_decimal( - extra.strict.unwrap_or(self.strict), + state.strict_or(self.strict), // Safety: self and py both outlive this call unsafe { py.from_borrowed_ptr(self.decimal_type.as_ptr()) }, )?; diff --git a/src/validators/definitions.rs b/src/validators/definitions.rs index 2c4b70175..2183ed246 100644 --- a/src/validators/definitions.rs +++ b/src/validators/definitions.rs @@ -5,10 +5,9 @@ use pyo3::types::{PyDict, PyList}; use crate::errors::{ErrorTypeDefaults, ValError, ValResult}; use crate::input::Input; -use crate::recursion_guard::RecursionGuard; use crate::tools::SchemaDict; -use super::{build_validator, BuildValidator, CombinedValidator, Definitions, DefinitionsBuilder, Extra, Validator}; +use super::{build_validator, BuildValidator, CombinedValidator, DefinitionsBuilder, ValidationState, Validator}; #[derive(Debug, Clone)] pub struct DefinitionsValidatorBuilder; @@ -78,25 +77,23 @@ impl Validator for DefinitionRefValidator { &'s self, py: Python<'data>, input: &'data impl Input<'data>, - extra: &Extra, - definitions: &'data Definitions, - recursion_guard: &'s mut RecursionGuard, + state: &mut ValidationState, ) -> ValResult<'data, PyObject> { if let Some(id) = input.identity() { - if recursion_guard.contains_or_insert(id, self.validator_id) { + if state.recursion_guard.contains_or_insert(id, self.validator_id) { // we don't remove id here, we leave that to the validator which originally added id to `recursion_guard` Err(ValError::new(ErrorTypeDefaults::RecursionLoop, input)) } else { - if recursion_guard.incr_depth() { + if state.recursion_guard.incr_depth() { return Err(ValError::new(ErrorTypeDefaults::RecursionLoop, input)); } - let output = validate(self.validator_id, py, input, extra, definitions, recursion_guard); - recursion_guard.remove(id, self.validator_id); - recursion_guard.decr_depth(); + let output = validate(self.validator_id, py, input, state); + state.recursion_guard.remove(id, self.validator_id); + state.recursion_guard.decr_depth(); output } } else { - validate(self.validator_id, py, input, extra, definitions, recursion_guard) + validate(self.validator_id, py, input, state) } } @@ -106,43 +103,23 @@ impl Validator for DefinitionRefValidator { obj: &'data PyAny, field_name: &'data str, field_value: &'data PyAny, - extra: &Extra, - definitions: &'data Definitions, - recursion_guard: &'s mut RecursionGuard, + state: &mut ValidationState, ) -> ValResult<'data, PyObject> { if let Some(id) = obj.identity() { - if recursion_guard.contains_or_insert(id, self.validator_id) { + if state.recursion_guard.contains_or_insert(id, self.validator_id) { // we don't remove id here, we leave that to the validator which originally added id to `recursion_guard` Err(ValError::new(ErrorTypeDefaults::RecursionLoop, obj)) } else { - if recursion_guard.incr_depth() { + if state.recursion_guard.incr_depth() { return Err(ValError::new(ErrorTypeDefaults::RecursionLoop, obj)); } - let output = validate_assignment( - self.validator_id, - py, - obj, - field_name, - field_value, - extra, - definitions, - recursion_guard, - ); - recursion_guard.remove(id, self.validator_id); - recursion_guard.decr_depth(); + let output = validate_assignment(self.validator_id, py, obj, field_name, field_value, state); + state.recursion_guard.remove(id, self.validator_id); + state.recursion_guard.decr_depth(); output } } else { - validate_assignment( - self.validator_id, - py, - obj, - field_name, - field_value, - extra, - definitions, - recursion_guard, - ) + validate_assignment(self.validator_id, py, obj, field_name, field_value, state) } } @@ -176,12 +153,10 @@ fn validate<'s, 'data>( validator_id: usize, py: Python<'data>, input: &'data impl Input<'data>, - extra: &Extra, - definitions: &'data Definitions, - recursion_guard: &'s mut RecursionGuard, + state: &mut ValidationState, ) -> ValResult<'data, PyObject> { - let validator = definitions.get(validator_id).unwrap(); - validator.validate(py, input, extra, definitions, recursion_guard) + let validator = state.definitions.get(validator_id).unwrap(); + validator.validate(py, input, state) } #[allow(clippy::too_many_arguments)] @@ -191,10 +166,8 @@ fn validate_assignment<'data>( obj: &'data PyAny, field_name: &'data str, field_value: &'data PyAny, - extra: &Extra, - definitions: &'data Definitions, - recursion_guard: &mut RecursionGuard, + state: &mut ValidationState, ) -> ValResult<'data, PyObject> { - let validator = definitions.get(validator_id).unwrap(); - validator.validate_assignment(py, obj, field_name, field_value, extra, definitions, recursion_guard) + let validator = state.definitions.get(validator_id).unwrap(); + validator.validate_assignment(py, obj, field_name, field_value, state) } diff --git a/src/validators/dict.rs b/src/validators/dict.rs index 933d40b9c..b47d82702 100644 --- a/src/validators/dict.rs +++ b/src/validators/dict.rs @@ -8,12 +8,11 @@ use crate::input::{ DictGenericIterator, GenericMapping, Input, JsonObject, JsonObjectGenericIterator, MappingGenericIterator, }; -use crate::recursion_guard::RecursionGuard; use crate::tools::SchemaDict; use super::any::AnyValidator; use super::list::length_check; -use super::{build_validator, BuildValidator, CombinedValidator, Definitions, DefinitionsBuilder, Extra, Validator}; +use super::{build_validator, BuildValidator, CombinedValidator, DefinitionsBuilder, ValidationState, Validator}; #[derive(Debug, Clone)] pub struct DictValidator { @@ -70,22 +69,15 @@ impl Validator for DictValidator { &'s self, py: Python<'data>, input: &'data impl Input<'data>, - extra: &Extra, - definitions: &'data Definitions, - recursion_guard: &'s mut RecursionGuard, + state: &mut ValidationState, ) -> ValResult<'data, PyObject> { - let dict = input.validate_dict(extra.strict.unwrap_or(self.strict))?; + let strict = state.strict_or(self.strict); + let dict = input.validate_dict(strict)?; match dict { - GenericMapping::PyDict(py_dict) => { - self.validate_dict(py, input, py_dict, extra, definitions, recursion_guard) - } - GenericMapping::PyMapping(mapping) => { - self.validate_mapping(py, input, mapping, extra, definitions, recursion_guard) - } + GenericMapping::PyDict(py_dict) => self.validate_dict(py, input, py_dict, state), + GenericMapping::PyMapping(mapping) => self.validate_mapping(py, input, mapping, state), GenericMapping::PyGetAttr(_, _) => unreachable!(), - GenericMapping::JsonObject(json_object) => { - self.validate_json_object(py, input, json_object, extra, definitions, recursion_guard) - } + GenericMapping::JsonObject(json_object) => self.validate_json_object(py, input, json_object, state), } } @@ -119,9 +111,7 @@ macro_rules! build_validate { py: Python<'data>, input: &'data impl Input<'data>, dict: &'data $dict_type, - extra: &Extra, - definitions: &'data Definitions, - recursion_guard: &'s mut RecursionGuard, + state: &mut ValidationState, ) -> ValResult<'data, PyObject> { let output = PyDict::new(py); let mut errors: Vec = Vec::new(); @@ -130,7 +120,7 @@ macro_rules! build_validate { let value_validator = self.value_validator.as_ref(); for item_result in <$iter>::new(dict)? { let (key, value) = item_result?; - let output_key = match key_validator.validate(py, key, extra, definitions, recursion_guard) { + let output_key = match key_validator.validate(py, key, state) { Ok(value) => Some(value), Err(ValError::LineErrors(line_errors)) => { for err in line_errors { @@ -145,7 +135,7 @@ macro_rules! build_validate { Err(ValError::Omit) => continue, Err(err) => return Err(err), }; - let output_value = match value_validator.validate(py, value, extra, definitions, recursion_guard) { + let output_value = match value_validator.validate(py, value, state) { Ok(value) => Some(value), Err(ValError::LineErrors(line_errors)) => { for err in line_errors { diff --git a/src/validators/float.rs b/src/validators/float.rs index 0a9ab51f5..dd58b2ee8 100644 --- a/src/validators/float.rs +++ b/src/validators/float.rs @@ -5,10 +5,9 @@ use pyo3::types::PyDict; use crate::build_tools::{is_strict, schema_or_config_same}; use crate::errors::{ErrorType, ErrorTypeDefaults, ValError, ValResult}; use crate::input::Input; -use crate::recursion_guard::RecursionGuard; use crate::tools::SchemaDict; -use super::{BuildValidator, CombinedValidator, Definitions, DefinitionsBuilder, Extra, Validator}; +use super::{BuildValidator, CombinedValidator, DefinitionsBuilder, ValidationState, Validator}; pub struct FloatBuilder; @@ -67,11 +66,10 @@ impl Validator for FloatValidator { &'s self, py: Python<'data>, input: &'data impl Input<'data>, - extra: &Extra, - _definitions: &'data Definitions, - _recursion_guard: &'s mut RecursionGuard, + state: &mut ValidationState, ) -> ValResult<'data, PyObject> { - let either_float = input.validate_float(extra.strict.unwrap_or(self.strict), extra.ultra_strict)?; + let strict = state.strict_or(self.strict); + let either_float = input.validate_float(strict, state.extra().ultra_strict)?; if !self.allow_inf_nan && !either_float.as_f64().is_finite() { return Err(ValError::new(ErrorTypeDefaults::FiniteNumber, input)); } @@ -113,11 +111,10 @@ impl Validator for ConstrainedFloatValidator { &'s self, py: Python<'data>, input: &'data impl Input<'data>, - extra: &Extra, - _definitions: &'data Definitions, - _recursion_guard: &'s mut RecursionGuard, + state: &mut ValidationState, ) -> ValResult<'data, PyObject> { - let either_float = input.validate_float(extra.strict.unwrap_or(self.strict), extra.ultra_strict)?; + let strict = state.strict_or(self.strict); + let either_float = input.validate_float(strict, state.extra().ultra_strict)?; let float: f64 = either_float.as_f64(); if !self.allow_inf_nan && !float.is_finite() { return Err(ValError::new(ErrorTypeDefaults::FiniteNumber, input)); diff --git a/src/validators/frozenset.rs b/src/validators/frozenset.rs index 54dd694b1..c64a09e39 100644 --- a/src/validators/frozenset.rs +++ b/src/validators/frozenset.rs @@ -3,12 +3,12 @@ use pyo3::types::{PyDict, PyFrozenSet}; use crate::errors::ValResult; use crate::input::Input; -use crate::recursion_guard::RecursionGuard; use crate::tools::SchemaDict; use super::list::min_length_check; use super::set::set_build; -use super::{BuildValidator, CombinedValidator, Definitions, DefinitionsBuilder, Extra, Validator}; +use super::validation_state::ValidationState; +use super::{BuildValidator, CombinedValidator, DefinitionsBuilder, Validator}; #[derive(Debug, Clone)] pub struct FrozenSetValidator { @@ -31,11 +31,9 @@ impl Validator for FrozenSetValidator { &'s self, py: Python<'data>, input: &'data impl Input<'data>, - extra: &Extra, - definitions: &'data Definitions, - recursion_guard: &'s mut RecursionGuard, + state: &mut ValidationState, ) -> ValResult<'data, PyObject> { - let collection = input.validate_frozenset(extra.strict.unwrap_or(self.strict))?; + let collection = input.validate_frozenset(state.strict_or(self.strict))?; let f_set = PyFrozenSet::empty(py)?; collection.validate_to_set( py, @@ -44,9 +42,7 @@ impl Validator for FrozenSetValidator { self.max_length, "Frozenset", &self.item_validator, - extra, - definitions, - recursion_guard, + state, )?; min_length_check!(input, "Frozenset", self.min_length, f_set); Ok(f_set.into_py(py)) diff --git a/src/validators/function.rs b/src/validators/function.rs index b6fbc153e..5e6b57af2 100644 --- a/src/validators/function.rs +++ b/src/validators/function.rs @@ -8,13 +8,13 @@ use crate::errors::{ }; use crate::input::Input; use crate::py_gc::PyGcTraverse; -use crate::recursion_guard::RecursionGuard; use crate::tools::{function_name, py_err, SchemaDict}; use crate::PydanticUseDefault; use super::generator::InternalValidator; use super::{ - build_validator, BuildValidator, CombinedValidator, Definitions, DefinitionsBuilder, Extra, InputType, Validator, + build_validator, BuildValidator, CombinedValidator, DefinitionsBuilder, Extra, InputType, ValidationState, + Validator, }; struct FunctionInfo { @@ -97,13 +97,10 @@ macro_rules! impl_validator { &'s self, py: Python<'data>, input: &'data impl Input<'data>, - extra: &Extra, - definitions: &'data Definitions, - recursion_guard: &'s mut RecursionGuard, + state: &mut ValidationState<'_>, ) -> ValResult<'data, PyObject> { - let validate = - move |v: &'data PyAny, e: &Extra| self.validator.validate(py, v, e, definitions, recursion_guard); - self._validate(validate, py, input.to_object(py).into_ref(py), extra) + let validate = |v, s: &mut ValidationState<'_>| self.validator.validate(py, v, s); + self._validate(validate, py, input.to_object(py).into_ref(py), state) } fn validate_assignment<'s, 'data: 's>( &'s self, @@ -111,15 +108,13 @@ macro_rules! impl_validator { obj: &'data PyAny, field_name: &'data str, field_value: &'data PyAny, - extra: &Extra, - definitions: &'data Definitions, - recursion_guard: &'s mut RecursionGuard, + state: &mut ValidationState, ) -> ValResult<'data, PyObject> { - let validate = move |v: &'data PyAny, e: &Extra| { + let validate = move |v, s: &mut ValidationState<'_>| { self.validator - .validate_assignment(py, v, field_name, field_value, e, definitions, recursion_guard) + .validate_assignment(py, v, field_name, field_value, s) }; - self._validate(validate, py, obj, extra) + self._validate(validate, py, obj, state) } fn different_strict_behavior( @@ -161,19 +156,19 @@ impl_build!(FunctionBeforeValidator, "function-before"); impl FunctionBeforeValidator { fn _validate<'s, 'data>( &'s self, - mut call: impl FnMut(&'data PyAny, &Extra) -> ValResult<'data, PyObject>, + call: impl FnOnce(&'data PyAny, &mut ValidationState<'_>) -> ValResult<'data, PyObject>, py: Python<'data>, input: &'data PyAny, - extra: &Extra, + state: &'s mut ValidationState<'_>, ) -> ValResult<'data, PyObject> { let r = if self.info_arg { - let info = ValidationInfo::new(py, extra, &self.config, self.field_name.clone()); + let info = ValidationInfo::new(py, state.extra(), &self.config, self.field_name.clone()); self.func.call1(py, (input.to_object(py), info)) } else { self.func.call1(py, (input.to_object(py),)) }; let value = r.map_err(|e| convert_err(py, e, input))?; - call(value.into_ref(py), extra) + call(value.into_ref(py), state) } } @@ -194,14 +189,14 @@ impl_build!(FunctionAfterValidator, "function-after"); impl FunctionAfterValidator { fn _validate<'s, 'data>( &'s self, - mut call: impl FnMut(&'data PyAny, &Extra) -> ValResult<'data, PyObject>, + call: impl FnOnce(&'data PyAny, &mut ValidationState<'_>) -> ValResult<'data, PyObject>, py: Python<'data>, input: &'data PyAny, - extra: &Extra, + state: &mut ValidationState<'_>, ) -> ValResult<'data, PyObject> { - let v = call(input, extra)?; + let v = call(input, state)?; let r = if self.info_arg { - let info = ValidationInfo::new(py, extra, &self.config, self.field_name.clone()); + let info = ValidationInfo::new(py, state.extra(), &self.config, self.field_name.clone()); self.func.call1(py, (v.to_object(py), info)) } else { self.func.call1(py, (v.to_object(py),)) @@ -255,12 +250,10 @@ impl Validator for FunctionPlainValidator { &'s self, py: Python<'data>, input: &'data impl Input<'data>, - extra: &Extra, - _definitions: &'data Definitions, - _recursion_guard: &'s mut RecursionGuard, + state: &mut ValidationState, ) -> ValResult<'data, PyObject> { let r = if self.info_arg { - let info = ValidationInfo::new(py, extra, &self.config, self.field_name.clone()); + let info = ValidationInfo::new(py, state.extra(), &self.config, self.field_name.clone()); self.func.call1(py, (input.to_object(py), info)) } else { self.func.call1(py, (input.to_object(py),)) @@ -331,10 +324,10 @@ impl FunctionWrapValidator { handler: &'s PyAny, py: Python<'data>, input: &'data PyAny, - extra: &Extra, + state: &mut ValidationState, ) -> ValResult<'data, PyObject> { let r = if self.info_arg { - let info = ValidationInfo::new(py, extra, &self.config, self.field_name.clone()); + let info = ValidationInfo::new(py, state.extra(), &self.config, self.field_name.clone()); self.func.call1(py, (input.to_object(py), handler, info)) } else { self.func.call1(py, (input.to_object(py), handler)) @@ -354,18 +347,14 @@ impl Validator for FunctionWrapValidator { &'s self, py: Python<'data>, input: &'data impl Input<'data>, - extra: &Extra, - definitions: &'data Definitions, - recursion_guard: &'s mut RecursionGuard, + state: &mut ValidationState, ) -> ValResult<'data, PyObject> { let handler = ValidatorCallable { validator: InternalValidator::new( py, "ValidatorCallable", &self.validator, - definitions, - extra, - recursion_guard, + state, self.hide_input_in_errors, ), }; @@ -373,7 +362,7 @@ impl Validator for FunctionWrapValidator { Py::new(py, handler)?.into_ref(py), py, input.to_object(py).into_ref(py), - extra, + state, ) } @@ -383,24 +372,20 @@ impl Validator for FunctionWrapValidator { obj: &'data PyAny, field_name: &'data str, field_value: &'data PyAny, - extra: &Extra, - definitions: &'data Definitions, - recursion_guard: &'s mut RecursionGuard, + state: &mut ValidationState, ) -> ValResult<'data, PyObject> { let handler = AssignmentValidatorCallable { validator: InternalValidator::new( py, "AssignmentValidatorCallable", &self.validator, - definitions, - extra, - recursion_guard, + state, self.hide_input_in_errors, ), updated_field_name: field_name.to_string(), updated_field_value: field_value.to_object(py), }; - self._validate(Py::new(py, handler)?.into_ref(py), py, obj, extra) + self._validate(Py::new(py, handler)?.into_ref(py), py, obj, state) } fn different_strict_behavior( diff --git a/src/validators/generator.rs b/src/validators/generator.rs index 87947f1cb..44fe3b187 100644 --- a/src/validators/generator.rs +++ b/src/validators/generator.rs @@ -10,7 +10,7 @@ use crate::tools::SchemaDict; use crate::ValidationError; use super::list::get_items_schema; -use super::{BuildValidator, CombinedValidator, Definitions, DefinitionsBuilder, Extra, InputType, Validator}; +use super::{BuildValidator, CombinedValidator, DefinitionsBuilder, Extra, InputType, ValidationState, Validator}; #[derive(Debug, Clone)] pub struct GeneratorValidator { @@ -55,22 +55,13 @@ impl Validator for GeneratorValidator { &'s self, py: Python<'data>, input: &'data impl Input<'data>, - extra: &Extra, - definitions: &'data Definitions, - recursion_guard: &'s mut RecursionGuard, + state: &mut ValidationState, ) -> ValResult<'data, PyObject> { let iterator = input.validate_iter()?; - let validator = self.item_validator.as_ref().map(|v| { - InternalValidator::new( - py, - "ValidatorIterator", - v, - definitions, - extra, - recursion_guard, - self.hide_input_in_errors, - ) - }); + let validator = self + .item_validator + .as_ref() + .map(|v| InternalValidator::new(py, "ValidatorIterator", v, state, self.hide_input_in_errors)); let v_iterator = ValidatorIterator { iterator, @@ -239,21 +230,20 @@ impl InternalValidator { py: Python, name: &str, validator: &CombinedValidator, - definitions: &[CombinedValidator], - extra: &Extra, - recursion_guard: &RecursionGuard, + state: &ValidationState, hide_input_in_errors: bool, ) -> Self { + let extra = state.extra(); Self { name: name.to_string(), validator: validator.clone(), - definitions: definitions.to_vec(), + definitions: state.definitions.to_vec(), data: extra.data.map(|d| d.into_py(py)), strict: extra.strict, from_attributes: extra.from_attributes, context: extra.context.map(|d| d.into_py(py)), self_instance: extra.self_instance.map(|d| d.into_py(py)), - recursion_guard: recursion_guard.clone(), + recursion_guard: state.recursion_guard.clone(), validation_mode: extra.mode, hide_input_in_errors, } @@ -276,16 +266,9 @@ impl InternalValidator { context: self.context.as_ref().map(|data| data.as_ref(py)), self_instance: self.self_instance.as_ref().map(|data| data.as_ref(py)), }; + let mut state = ValidationState::new(extra, &self.definitions, &mut self.recursion_guard); self.validator - .validate_assignment( - py, - model, - field_name, - field_value, - &extra, - &self.definitions, - &mut self.recursion_guard, - ) + .validate_assignment(py, model, field_name, field_value, &mut state) .map_err(|e| { ValidationError::from_val_error( py, @@ -316,18 +299,17 @@ impl InternalValidator { context: self.context.as_ref().map(|data| data.as_ref(py)), self_instance: self.self_instance.as_ref().map(|data| data.as_ref(py)), }; - self.validator - .validate(py, input, &extra, &self.definitions, &mut self.recursion_guard) - .map_err(|e| { - ValidationError::from_val_error( - py, - self.name.to_object(py), - ErrorMode::Python, - e, - outer_location, - self.hide_input_in_errors, - ) - }) + let mut state = ValidationState::new(extra, &self.definitions, &mut self.recursion_guard); + self.validator.validate(py, input, &mut state).map_err(|e| { + ValidationError::from_val_error( + py, + self.name.to_object(py), + ErrorMode::Python, + e, + outer_location, + self.hide_input_in_errors, + ) + }) } } diff --git a/src/validators/int.rs b/src/validators/int.rs index 7c802a016..ce9967295 100644 --- a/src/validators/int.rs +++ b/src/validators/int.rs @@ -6,10 +6,10 @@ use pyo3::types::PyDict; use crate::build_tools::is_strict; use crate::errors::{ErrorType, ValError, ValResult}; use crate::input::{Input, Int}; -use crate::recursion_guard::RecursionGuard; use crate::tools::SchemaDict; -use super::{BuildValidator, CombinedValidator, Definitions, DefinitionsBuilder, Extra, Validator}; +use super::ValidationState; +use super::{BuildValidator, CombinedValidator, DefinitionsBuilder, Validator}; #[derive(Debug, Clone)] pub struct IntValidator { @@ -48,11 +48,10 @@ impl Validator for IntValidator { &'s self, py: Python<'data>, input: &'data impl Input<'data>, - extra: &Extra, - _definitions: &'data Definitions, - _recursion_guard: &'s mut RecursionGuard, + state: &mut ValidationState, ) -> ValResult<'data, PyObject> { - Ok(input.validate_int(extra.strict.unwrap_or(self.strict))?.into_py(py)) + let either_int = input.validate_int(state.strict_or(self.strict))?; + Ok(either_int.into_py(py)) } fn different_strict_behavior( @@ -89,11 +88,9 @@ impl Validator for ConstrainedIntValidator { &'s self, py: Python<'data>, input: &'data impl Input<'data>, - extra: &Extra, - _definitions: &'data Definitions, - _recursion_guard: &'s mut RecursionGuard, + state: &mut ValidationState, ) -> ValResult<'data, PyObject> { - let either_int = input.validate_int(extra.strict.unwrap_or(self.strict))?; + let either_int = input.validate_int(state.strict_or(self.strict))?; let int_value = either_int.as_int()?; if let Some(ref multiple_of) = self.multiple_of { diff --git a/src/validators/is_instance.rs b/src/validators/is_instance.rs index 86487c9eb..86a56d4a8 100644 --- a/src/validators/is_instance.rs +++ b/src/validators/is_instance.rs @@ -6,10 +6,10 @@ use pyo3::types::{PyDict, PyType}; use crate::build_tools::py_schema_err; use crate::errors::{ErrorType, ValError, ValResult}; use crate::input::Input; -use crate::recursion_guard::RecursionGuard; use crate::tools::SchemaDict; -use super::{BuildValidator, CombinedValidator, Definitions, DefinitionsBuilder, Extra, Validator}; +use super::ValidationState; +use super::{BuildValidator, CombinedValidator, DefinitionsBuilder, Validator}; #[derive(Debug, Clone)] pub struct IsInstanceValidator { @@ -61,9 +61,7 @@ impl Validator for IsInstanceValidator { &'s self, py: Python<'data>, input: &'data impl Input<'data>, - _extra: &Extra, - _definitions: &'data Definitions, - _recursion_guard: &'s mut RecursionGuard, + _state: &mut ValidationState, ) -> ValResult<'data, PyObject> { if !input.is_python() { return Err(ValError::InternalErr(PyNotImplementedError::new_err( diff --git a/src/validators/is_subclass.rs b/src/validators/is_subclass.rs index 2e2c92d51..c0af25082 100644 --- a/src/validators/is_subclass.rs +++ b/src/validators/is_subclass.rs @@ -4,10 +4,10 @@ use pyo3::types::{PyDict, PyType}; use crate::errors::{ErrorType, ValError, ValResult}; use crate::input::Input; -use crate::recursion_guard::RecursionGuard; use crate::tools::SchemaDict; -use super::{BuildValidator, CombinedValidator, Definitions, DefinitionsBuilder, Extra, Validator}; +use super::ValidationState; +use super::{BuildValidator, CombinedValidator, DefinitionsBuilder, Validator}; #[derive(Debug, Clone)] pub struct IsSubclassValidator { @@ -48,9 +48,7 @@ impl Validator for IsSubclassValidator { &'s self, py: Python<'data>, input: &'data impl Input<'data>, - _extra: &Extra, - _definitions: &'data Definitions, - _recursion_guard: &'s mut RecursionGuard, + _state: &mut ValidationState, ) -> ValResult<'data, PyObject> { match input.input_is_subclass(self.class.as_ref(py))? { true => Ok(input.to_object(py)), diff --git a/src/validators/json.rs b/src/validators/json.rs index 0815fdcce..d4d522577 100644 --- a/src/validators/json.rs +++ b/src/validators/json.rs @@ -4,10 +4,10 @@ use pyo3::types::PyDict; use crate::errors::ValResult; use crate::input::Input; -use crate::recursion_guard::RecursionGuard; use crate::tools::SchemaDict; -use super::{build_validator, BuildValidator, CombinedValidator, Definitions, DefinitionsBuilder, Extra, Validator}; +use super::ValidationState; +use super::{build_validator, BuildValidator, CombinedValidator, DefinitionsBuilder, Validator}; #[derive(Debug, Clone)] pub struct JsonValidator { @@ -49,13 +49,11 @@ impl Validator for JsonValidator { &'s self, py: Python<'data>, input: &'data impl Input<'data>, - extra: &Extra, - definitions: &'data Definitions, - recursion_guard: &'s mut RecursionGuard, + state: &mut ValidationState, ) -> ValResult<'data, PyObject> { let json_value = input.parse_json()?; match self.validator { - Some(ref validator) => match validator.validate(py, &json_value, extra, definitions, recursion_guard) { + Some(ref validator) => match validator.validate(py, &json_value, state) { Ok(v) => Ok(v), Err(err) => Err(err.duplicate(py)), }, diff --git a/src/validators/json_or_python.rs b/src/validators/json_or_python.rs index 5649abecd..f693a7861 100644 --- a/src/validators/json_or_python.rs +++ b/src/validators/json_or_python.rs @@ -5,11 +5,11 @@ use pyo3::types::PyDict; use crate::definitions::DefinitionsBuilder; use crate::errors::ValResult; use crate::input::Input; -use crate::recursion_guard::RecursionGuard; use crate::tools::SchemaDict; use super::InputType; -use super::{build_validator, BuildValidator, CombinedValidator, Definitions, Extra, Validator}; +use super::ValidationState; +use super::{build_validator, BuildValidator, CombinedValidator, Validator}; #[derive(Debug, Clone)] pub struct JsonOrPython { @@ -55,13 +55,11 @@ impl Validator for JsonOrPython { &'s self, py: Python<'data>, input: &'data impl Input<'data>, - extra: &Extra, - definitions: &'data Definitions, - recursion_guard: &'s mut RecursionGuard, + state: &mut ValidationState, ) -> ValResult<'data, PyObject> { - match extra.mode { - InputType::Python => self.python.validate(py, input, extra, definitions, recursion_guard), - InputType::Json => self.json.validate(py, input, extra, definitions, recursion_guard), + match state.extra().mode { + InputType::Python => self.python.validate(py, input, state), + InputType::Json => self.json.validate(py, input, state), } } diff --git a/src/validators/lax_or_strict.rs b/src/validators/lax_or_strict.rs index 75ae12671..d559ee2a8 100644 --- a/src/validators/lax_or_strict.rs +++ b/src/validators/lax_or_strict.rs @@ -5,10 +5,10 @@ use pyo3::types::PyDict; use crate::build_tools::is_strict; use crate::errors::ValResult; use crate::input::Input; -use crate::recursion_guard::RecursionGuard; use crate::tools::SchemaDict; -use super::{build_validator, BuildValidator, CombinedValidator, Definitions, DefinitionsBuilder, Extra, Validator}; +use super::ValidationState; +use super::{build_validator, BuildValidator, CombinedValidator, DefinitionsBuilder, Validator}; #[derive(Debug, Clone)] pub struct LaxOrStrictValidator { @@ -59,16 +59,12 @@ impl Validator for LaxOrStrictValidator { &'s self, py: Python<'data>, input: &'data impl Input<'data>, - extra: &Extra, - definitions: &'data Definitions, - recursion_guard: &'s mut RecursionGuard, + state: &mut ValidationState, ) -> ValResult<'data, PyObject> { - if extra.strict.unwrap_or(self.strict) { - self.strict_validator - .validate(py, input, extra, definitions, recursion_guard) + if state.strict_or(self.strict) { + self.strict_validator.validate(py, input, state) } else { - self.lax_validator - .validate(py, input, extra, definitions, recursion_guard) + self.lax_validator.validate(py, input, state) } } diff --git a/src/validators/list.rs b/src/validators/list.rs index ae3a0a0a7..ec2640b91 100644 --- a/src/validators/list.rs +++ b/src/validators/list.rs @@ -3,10 +3,9 @@ use pyo3::types::PyDict; use crate::errors::ValResult; use crate::input::{GenericIterable, Input}; -use crate::recursion_guard::RecursionGuard; use crate::tools::SchemaDict; -use super::{build_validator, BuildValidator, CombinedValidator, Definitions, DefinitionsBuilder, Extra, Validator}; +use super::{build_validator, BuildValidator, CombinedValidator, DefinitionsBuilder, ValidationState, Validator}; #[derive(Debug, Clone)] pub struct ListValidator { @@ -120,23 +119,12 @@ impl Validator for ListValidator { &'s self, py: Python<'data>, input: &'data impl Input<'data>, - extra: &Extra, - definitions: &'data Definitions, - recursion_guard: &'s mut RecursionGuard, + state: &mut ValidationState, ) -> ValResult<'data, PyObject> { - let seq = input.validate_list(extra.strict.unwrap_or(self.strict))?; + let seq = input.validate_list(state.strict_or(self.strict))?; let output = match self.item_validator { - Some(ref v) => seq.validate_to_vec( - py, - input, - self.max_length, - "List", - v, - extra, - definitions, - recursion_guard, - )?, + Some(ref v) => seq.validate_to_vec(py, input, self.max_length, "List", v, state)?, None => match seq { GenericIterable::List(list) => { length_check!(input, "List", self.min_length, self.max_length, list); diff --git a/src/validators/literal.rs b/src/validators/literal.rs index e3ae3adf6..45c673c78 100644 --- a/src/validators/literal.rs +++ b/src/validators/literal.rs @@ -11,10 +11,9 @@ use crate::build_tools::{py_schema_err, py_schema_error_type}; use crate::errors::{ErrorType, ValError, ValResult}; use crate::input::Input; use crate::py_gc::PyGcTraverse; -use crate::recursion_guard::RecursionGuard; use crate::tools::SchemaDict; -use super::{BuildValidator, CombinedValidator, Definitions, DefinitionsBuilder, Extra, Validator}; +use super::{BuildValidator, CombinedValidator, DefinitionsBuilder, ValidationState, Validator}; #[derive(Debug, Clone)] struct BoolLiteral { @@ -185,9 +184,7 @@ impl Validator for LiteralValidator { &'s self, py: Python<'data>, input: &'data impl Input<'data>, - _extra: &Extra, - _definitions: &'data Definitions, - _recursion_guard: &'s mut RecursionGuard, + _state: &mut ValidationState, ) -> ValResult<'data, PyObject> { match self.lookup.validate(py, input)? { Some((_, v)) => Ok(v.clone()), diff --git a/src/validators/mod.rs b/src/validators/mod.rs index 9a1b2eaa1..17ec7968d 100644 --- a/src/validators/mod.rs +++ b/src/validators/mod.rs @@ -9,7 +9,7 @@ use pyo3::types::{PyAny, PyDict, PyTuple, PyType}; use pyo3::{intern, PyTraverseError, PyVisit}; use crate::build_tools::{py_schema_err, py_schema_error_type, SchemaError}; -use crate::definitions::{Definitions, DefinitionsBuilder}; +use crate::definitions::DefinitionsBuilder; use crate::errors::{ErrorMode, LocItem, ValError, ValResult, ValidationError}; use crate::input::{Input, InputType}; use crate::py_gc::PyGcTraverse; @@ -55,11 +55,13 @@ mod typed_dict; mod union; mod url; mod uuid; +mod validation_state; mod with_default; pub use with_default::DefaultType; use self::definitions::DefinitionRefValidator; +pub use self::validation_state::ValidationState; #[pyclass(module = "pydantic_core._pydantic_core", name = "Some")] pub struct PySome { @@ -156,7 +158,7 @@ impl SchemaValidator { context: Option<&PyAny>, self_instance: Option<&PyAny>, ) -> PyResult { - let r = self._validate( + self._validate( py, input, InputType::Python, @@ -164,8 +166,9 @@ impl SchemaValidator { from_attributes, context, self_instance, - ); - r.map_err(|e| self.prepare_validation_err(py, e, ErrorMode::Python)) + &mut RecursionGuard::default(), + ) + .map_err(|e| self.prepare_validation_err(py, e, ErrorMode::Python)) } #[pyo3(signature = (input, *, strict=None, from_attributes=None, context=None, self_instance=None))] @@ -186,6 +189,7 @@ impl SchemaValidator { from_attributes, context, self_instance, + &mut RecursionGuard::default(), ) { Ok(_) => Ok(true), Err(ValError::InternalErr(err)) => Err(err), @@ -204,11 +208,20 @@ impl SchemaValidator { context: Option<&PyAny>, self_instance: Option<&PyAny>, ) -> PyResult { + let recursion_guard = &mut RecursionGuard::default(); match input.parse_json() { - Ok(input) => { - let r = self._validate(py, &input, InputType::Json, strict, None, context, self_instance); - r.map_err(|e| self.prepare_validation_err(py, e, ErrorMode::Json)) - } + Ok(input) => self + ._validate( + py, + &input, + InputType::Json, + strict, + None, + context, + self_instance, + recursion_guard, + ) + .map_err(|e| self.prepare_validation_err(py, e, ErrorMode::Json)), Err(err) => Err(self.prepare_validation_err(py, err, ErrorMode::Json)), } } @@ -236,8 +249,9 @@ impl SchemaValidator { }; let guard = &mut RecursionGuard::default(); + let mut state = ValidationState::new(extra, &self.definitions, guard); self.validator - .validate_assignment(py, obj, field_name, field_value, &extra, &self.definitions, guard) + .validate_assignment(py, obj, field_name, field_value, &mut state) .map_err(|e| self.prepare_validation_err(py, e, ErrorMode::Python)) } @@ -253,9 +267,8 @@ impl SchemaValidator { self_instance: None, }; let recursion_guard = &mut RecursionGuard::default(); - let r = self - .validator - .default_value(py, None::, &extra, &self.definitions, recursion_guard); + let mut state = ValidationState::new(extra, &self.definitions, recursion_guard); + let r = self.validator.default_value(py, None::, &mut state); match r { Ok(maybe_default) => match maybe_default { Some(v) => Ok(PySome::new(v).into_py(py)), @@ -295,17 +308,17 @@ impl SchemaValidator { from_attributes: Option, context: Option<&'data PyAny>, self_instance: Option<&PyAny>, + recursion_guard: &'data mut RecursionGuard, ) -> ValResult<'data, PyObject> where 's: 'data, { - self.validator.validate( - py, - input, - &Extra::new(strict, from_attributes, context, self_instance, mode), + let mut state = ValidationState::new( + Extra::new(strict, from_attributes, context, self_instance, mode), &self.definitions, - &mut RecursionGuard::default(), - ) + recursion_guard, + ); + self.validator.validate(py, input, &mut state) } fn prepare_validation_err(&self, py: Python, error: ValError, error_mode: ErrorMode) -> PyErr { @@ -337,14 +350,13 @@ impl<'py> SelfValidator<'py> { } pub fn validate_schema(&self, py: Python<'py>, schema: &'py PyAny) -> PyResult<&'py PyAny> { - let extra = Extra::new(None, None, None, None, InputType::Python); - match self.validator.validator.validate( - py, - schema, - &extra, + let mut recursion_guard = RecursionGuard::default(); + let mut state = ValidationState::new( + Extra::new(None, None, None, None, InputType::Python), &self.validator.definitions, - &mut RecursionGuard::default(), - ) { + &mut recursion_guard, + ); + match self.validator.validator.validate(py, schema, &mut state) { Ok(schema_obj) => Ok(schema_obj.into_ref(py)), Err(e) => Err(SchemaError::from_val_error(py, e)), } @@ -675,19 +687,15 @@ pub trait Validator: Send + Sync + Clone + Debug { &'s self, py: Python<'data>, input: &'data impl Input<'data>, - extra: &Extra, - definitions: &'data Definitions, - recursion_guard: &'s mut RecursionGuard, + state: &mut ValidationState, ) -> ValResult<'data, PyObject>; /// Get a default value, currently only used by `WithDefaultValidator` - fn default_value<'s, 'data>( - &'s self, + fn default_value<'data>( + &self, _py: Python<'data>, _outer_loc: Option>, - _extra: &Extra, - _definitions: &'data Definitions, - _recursion_guard: &'s mut RecursionGuard, + _state: &mut ValidationState, ) -> ValResult<'data, Option> { Ok(None) } @@ -700,9 +708,7 @@ pub trait Validator: Send + Sync + Clone + Debug { _obj: &'data PyAny, _field_name: &'data str, _field_value: &'data PyAny, - _extra: &Extra, - _definitions: &'data Definitions, - _recursion_guard: &'s mut RecursionGuard, + _state: &mut ValidationState, ) -> ValResult<'data, PyObject> { let py_err = PyTypeError::new_err(format!("validate_assignment is not supported for {}", self.get_name())); Err(py_err.into()) diff --git a/src/validators/model.rs b/src/validators/model.rs index e1f4327a6..918e96173 100644 --- a/src/validators/model.rs +++ b/src/validators/model.rs @@ -7,12 +7,13 @@ use pyo3::types::{PyDict, PySet, PyString, PyTuple, PyType}; use pyo3::{ffi, intern}; use super::function::convert_err; -use super::{build_validator, BuildValidator, CombinedValidator, Definitions, DefinitionsBuilder, Extra, Validator}; +use super::{ + build_validator, BuildValidator, CombinedValidator, DefinitionsBuilder, Extra, ValidationState, Validator, +}; use crate::build_tools::py_schema_err; use crate::build_tools::schema_or_config_same; use crate::errors::{ErrorType, ErrorTypeDefaults, ValError, ValResult}; use crate::input::{py_error_on_minusone, Input}; -use crate::recursion_guard::RecursionGuard; use crate::tools::{py_err, SchemaDict}; use crate::PydanticUndefinedType; @@ -106,13 +107,11 @@ impl Validator for ModelValidator { &'s self, py: Python<'data>, input: &'data impl Input<'data>, - extra: &Extra, - definitions: &'data Definitions, - recursion_guard: &'s mut RecursionGuard, + state: &mut ValidationState, ) -> ValResult<'data, PyObject> { - if let Some(self_instance) = extra.self_instance { + if let Some(self_instance) = state.extra().self_instance { // in the case that self_instance is Some, we're calling validation from within `BaseModel.__init__` - return self.validate_init(py, self_instance, input, extra, definitions, recursion_guard); + return self.validate_init(py, self_instance, input, state); } // if we're in strict mode, we require an exact instance of the class (from python, with JSON an object is ok) @@ -125,7 +124,7 @@ impl Validator for ModelValidator { let fields_set = py_input.getattr(intern!(py, DUNDER_FIELDS_SET_KEY))?; if self.root_model { let inner_input = py_input.getattr(intern!(py, ROOT_FIELD))?; - self.validate_construct(py, inner_input, Some(fields_set), extra, definitions, recursion_guard) + self.validate_construct(py, inner_input, Some(fields_set), state) } else { // get dict here so from_attributes logic doesn't apply let dict = py_input.getattr(intern!(py, DUNDER_DICT))?; @@ -138,13 +137,13 @@ impl Validator for ModelValidator { full_model_dict.update(model_extra.downcast()?)?; full_model_dict }; - self.validate_construct(py, inner_input, Some(fields_set), extra, definitions, recursion_guard) + self.validate_construct(py, inner_input, Some(fields_set), state) } } else { Ok(input.to_object(py)) } } else { - self.validate_construct(py, input, None, extra, definitions, recursion_guard) + self.validate_construct(py, input, None, state) } } @@ -154,9 +153,7 @@ impl Validator for ModelValidator { model: &'data PyAny, field_name: &'data str, field_value: &'data PyAny, - extra: &Extra, - definitions: &'data Definitions, - recursion_guard: &'s mut RecursionGuard, + state: &mut ValidationState, ) -> ValResult<'data, PyObject> { if self.frozen { return Err(ValError::new(ErrorTypeDefaults::FrozenInstance, field_value)); @@ -171,10 +168,7 @@ impl Validator for ModelValidator { field_name.to_string(), )) } else { - let field_extra = Extra { ..*extra }; - let output = self - .validator - .validate(py, field_value, &field_extra, definitions, recursion_guard)?; + let output = self.validator.validate(py, field_value, state)?; force_setattr(py, model, intern!(py, ROOT_FIELD), output)?; Ok(model.into_py(py)) @@ -189,15 +183,9 @@ impl Validator for ModelValidator { } input_dict.set_item(field_name, field_value)?; - let output = self.validator.validate_assignment( - py, - input_dict, - field_name, - field_value, - extra, - definitions, - recursion_guard, - )?; + let output = self + .validator + .validate_assignment(py, input_dict, field_name, field_value, state)?; let (validated_dict, validated_extra, validated_fields_set): (&PyDict, &PyAny, &PySet) = output.extract(py)?; @@ -246,20 +234,17 @@ impl ModelValidator { py: Python<'data>, self_instance: &'s PyAny, input: &'data impl Input<'data>, - extra: &Extra, - definitions: &'data Definitions, - recursion_guard: &'s mut RecursionGuard, + state: &mut ValidationState, ) -> ValResult<'data, PyObject> { // we need to set `self_instance` to None for nested validators as we don't want to operate on self_instance // anymore - let new_extra = Extra { - self_instance: None, - ..*extra - }; - - let output = self - .validator - .validate(py, input, &new_extra, definitions, recursion_guard)?; + let output = state.with_new_extra( + Extra { + self_instance: None, + ..*state.extra() + }, + |state| self.validator.validate(py, input, state), + )?; if self.root_model { let fields_set = if input.to_object(py).is(&PydanticUndefinedType::py_undefined()) { @@ -273,7 +258,7 @@ impl ModelValidator { let (model_dict, model_extra, fields_set): (&PyAny, &PyAny, &PyAny) = output.extract(py)?; set_model_attrs(self_instance, model_dict, model_extra, fields_set)?; } - self.call_post_init(py, self_instance.into_py(py), input, extra) + self.call_post_init(py, self_instance.into_py(py), input, state.extra()) } fn validate_construct<'s, 'data>( @@ -281,9 +266,7 @@ impl ModelValidator { py: Python<'data>, input: &'data impl Input<'data>, existing_fields_set: Option<&'data PyAny>, - extra: &Extra, - definitions: &'data Definitions, - recursion_guard: &'s mut RecursionGuard, + state: &mut ValidationState, ) -> ValResult<'data, PyObject> { if self.custom_init { // If we wanted, we could introspect the __init__ signature, and store the @@ -296,14 +279,7 @@ impl ModelValidator { } } - let output = if self.root_model { - let field_extra = Extra { ..*extra }; - self.validator - .validate(py, input, &field_extra, definitions, recursion_guard)? - } else { - self.validator - .validate(py, input, extra, definitions, recursion_guard)? - }; + let output = self.validator.validate(py, input, state)?; let instance = create_class(self.class.as_ref(py))?; let instance_ref = instance.as_ref(py); @@ -321,7 +297,7 @@ impl ModelValidator { let fields_set = existing_fields_set.unwrap_or(val_fields_set); set_model_attrs(instance_ref, model_dict, model_extra, fields_set)?; } - self.call_post_init(py, instance, input, extra) + self.call_post_init(py, instance, input, state.extra()) } fn call_post_init<'s, 'data>( diff --git a/src/validators/model_fields.rs b/src/validators/model_fields.rs index cebf8434e..695a44e6b 100644 --- a/src/validators/model_fields.rs +++ b/src/validators/model_fields.rs @@ -13,10 +13,12 @@ use crate::input::{ MappingGenericIterator, }; use crate::lookup_key::LookupKey; -use crate::recursion_guard::RecursionGuard; use crate::tools::SchemaDict; -use super::{build_validator, BuildValidator, CombinedValidator, Definitions, DefinitionsBuilder, Extra, Validator}; +use super::ValidationState; +use super::{build_validator, BuildValidator, CombinedValidator, DefinitionsBuilder, Extra, Validator}; + +use std::ops::ControlFlow; #[derive(Debug, Clone)] struct Field { @@ -119,12 +121,10 @@ impl Validator for ModelFieldsValidator { &'s self, py: Python<'data>, input: &'data impl Input<'data>, - extra: &Extra, - definitions: &'data Definitions, - recursion_guard: &'s mut RecursionGuard, + state: &mut ValidationState, ) -> ValResult<'data, PyObject> { - let strict = extra.strict.unwrap_or(self.strict); - let from_attributes = extra.from_attributes.unwrap_or(self.from_attributes); + let strict = state.strict_or(self.strict); + let from_attributes = state.extra().from_attributes.unwrap_or(self.from_attributes); // we convert the DictType error to a ModelType error let dict = match input.validate_model_fields(strict, from_attributes) { @@ -162,60 +162,74 @@ impl Validator for ModelFieldsValidator { _ => None, }; + macro_rules! control_flow { + ($e: expr) => { + match $e { + Ok(v) => ControlFlow::Continue(v), + Err(err) => ControlFlow::Break(ValError::from(err)), + } + }; + } + macro_rules! process { ($dict:ident, $get_method:ident, $iter:ty $(,$kwargs:ident)?) => {{ - for field in &self.fields { - let extra = Extra { - data: Some(model_dict), - ..*extra - }; - let op_key_value = match field.lookup_key.$get_method($dict $(, $kwargs )? ) { - Ok(v) => v, - Err(err) => { - errors.push(ValLineError::new_with_loc( - ErrorType::GetAttributeError { - error: py_err_string(py, err), - context: None, - }, - input, - field.name.clone(), - )); - continue; - } - }; - if let Some((lookup_path, value)) = op_key_value { - if let Some(ref mut used_keys) = used_keys { - // key is "used" whether or not validation passes, since we want to skip this key in - // extra logic either way - used_keys.insert(lookup_path.first_key()); - } - match field - .validator - .validate(py, value, &extra, definitions, recursion_guard) - { - Ok(value) => { - model_dict.set_item(&field.name_py, value)?; - fields_set_vec.push(field.name_py.clone_ref(py)); + match state.with_new_extra(Extra { + data: Some(model_dict), + ..*state.extra() + }, |state| { + for field in &self.fields { + let op_key_value = match field.lookup_key.$get_method($dict $(, $kwargs )? ) { + Ok(v) => v, + Err(err) => { + errors.push(ValLineError::new_with_loc( + ErrorType::GetAttributeError { + error: py_err_string(py, err), + context: None, + }, + input, + field.name.clone(), + )); + continue; } - Err(ValError::Omit) => continue, - Err(ValError::LineErrors(line_errors)) => { - for err in line_errors { - errors.push(lookup_path.apply_error_loc(err, self.loc_by_alias, &field.name)); + }; + if let Some((lookup_path, value)) = op_key_value { + if let Some(ref mut used_keys) = used_keys { + // key is "used" whether or not validation passes, since we want to skip this key in + // extra logic either way + used_keys.insert(lookup_path.first_key()); + } + match field + .validator + .validate(py, value, state) + { + Ok(value) => { + control_flow!(model_dict.set_item(&field.name_py, value))?; + fields_set_vec.push(field.name_py.clone_ref(py)); + } + Err(ValError::Omit) => continue, + Err(ValError::LineErrors(line_errors)) => { + for err in line_errors { + errors.push(lookup_path.apply_error_loc(err, self.loc_by_alias, &field.name)); + } } + Err(err) => return ControlFlow::Break(err), } - Err(err) => return Err(err), + continue; + } else if let Some(value) = control_flow!(field.validator.default_value(py, Some(field.name.as_str()), state))? { + control_flow!(model_dict.set_item(&field.name_py, value))?; + } else { + errors.push(field.lookup_key.error( + ErrorTypeDefaults::Missing, + input, + self.loc_by_alias, + &field.name + )); } - continue; - } else if let Some(value) = field.validator.default_value(py, Some(field.name.as_str()), &extra, definitions, recursion_guard)? { - model_dict.set_item(&field.name_py, value)?; - } else { - errors.push(field.lookup_key.error( - ErrorTypeDefaults::Missing, - input, - self.loc_by_alias, - &field.name - )); } + ControlFlow::Continue(()) + }) { + ControlFlow::Continue(()) => {} + ControlFlow::Break(err) => return Err(err), } if let Some(ref mut used_keys) = used_keys { @@ -252,7 +266,7 @@ impl Validator for ModelFieldsValidator { ExtraBehavior::Allow => { let py_key = either_str.as_py_string(py); if let Some(ref validator) = self.extra_validator { - match validator.validate(py, value, &extra, definitions, recursion_guard) { + match validator.validate(py, value, state) { Ok(value) => { model_extra_dict.set_item(py_key, value)?; fields_set_vec.push(py_key.into_py(py)); @@ -305,9 +319,7 @@ impl Validator for ModelFieldsValidator { obj: &'data PyAny, field_name: &'data str, field_value: &'data PyAny, - extra: &Extra, - definitions: &'data Definitions, - recursion_guard: &'s mut RecursionGuard, + state: &mut ValidationState, ) -> ValResult<'data, PyObject> { let dict: &PyDict = obj.downcast()?; @@ -337,9 +349,9 @@ impl Validator for ModelFieldsValidator { } } - let extra = Extra { + let new_extra = Extra { data: Some(data_dict), - ..*extra + ..*state.extra() }; let new_data = if let Some(field) = self.fields.iter().find(|f| f.name == field_name) { @@ -351,9 +363,7 @@ impl Validator for ModelFieldsValidator { )) } else { prepare_result( - field - .validator - .validate(py, field_value, &extra, definitions, recursion_guard), + state.with_new_extra(new_extra, |state| field.validator.validate(py, field_value, state)), ) } } else { @@ -364,9 +374,9 @@ impl Validator for ModelFieldsValidator { // unless the user explicitly set extra_behavior to 'allow' match self.extra_behavior { ExtraBehavior::Allow => match self.extra_validator { - Some(ref validator) => { - prepare_result(validator.validate(py, field_value, &extra, definitions, recursion_guard)) - } + Some(ref validator) => prepare_result( + state.with_new_extra(new_extra, |state| validator.validate(py, field_value, state)), + ), None => get_updated_dict(field_value.to_object(py)), }, ExtraBehavior::Forbid | ExtraBehavior::Ignore => { diff --git a/src/validators/none.rs b/src/validators/none.rs index ff1c94b35..ce04040b4 100644 --- a/src/validators/none.rs +++ b/src/validators/none.rs @@ -3,9 +3,8 @@ use pyo3::types::PyDict; use crate::errors::{ErrorTypeDefaults, ValError, ValResult}; use crate::input::Input; -use crate::recursion_guard::RecursionGuard; -use super::{BuildValidator, CombinedValidator, Definitions, DefinitionsBuilder, Extra, Validator}; +use super::{BuildValidator, CombinedValidator, DefinitionsBuilder, ValidationState, Validator}; #[derive(Debug, Clone)] pub struct NoneValidator; @@ -29,9 +28,7 @@ impl Validator for NoneValidator { &'s self, py: Python<'data>, input: &'data impl Input<'data>, - _extra: &Extra, - _definitions: &'data Definitions, - _recursion_guard: &'s mut RecursionGuard, + _state: &mut ValidationState, ) -> ValResult<'data, PyObject> { match input.is_none() { true => Ok(py.None()), diff --git a/src/validators/nullable.rs b/src/validators/nullable.rs index ae460c634..af8bd860f 100644 --- a/src/validators/nullable.rs +++ b/src/validators/nullable.rs @@ -4,10 +4,10 @@ use pyo3::types::PyDict; use crate::errors::ValResult; use crate::input::Input; -use crate::recursion_guard::RecursionGuard; use crate::tools::SchemaDict; -use super::{build_validator, BuildValidator, CombinedValidator, Definitions, DefinitionsBuilder, Extra, Validator}; +use super::ValidationState; +use super::{build_validator, BuildValidator, CombinedValidator, DefinitionsBuilder, Validator}; #[derive(Debug, Clone)] pub struct NullableValidator { @@ -37,13 +37,11 @@ impl Validator for NullableValidator { &'s self, py: Python<'data>, input: &'data impl Input<'data>, - extra: &Extra, - definitions: &'data Definitions, - recursion_guard: &'s mut RecursionGuard, + state: &mut ValidationState, ) -> ValResult<'data, PyObject> { match input.is_none() { true => Ok(py.None()), - false => self.validator.validate(py, input, extra, definitions, recursion_guard), + false => self.validator.validate(py, input, state), } } diff --git a/src/validators/set.rs b/src/validators/set.rs index 68920447d..c0d4224b3 100644 --- a/src/validators/set.rs +++ b/src/validators/set.rs @@ -3,11 +3,10 @@ use pyo3::types::{PyDict, PySet}; use crate::errors::ValResult; use crate::input::Input; -use crate::recursion_guard::RecursionGuard; use crate::tools::SchemaDict; use super::list::min_length_check; -use super::{BuildValidator, CombinedValidator, Definitions, DefinitionsBuilder, Extra, Validator}; +use super::{BuildValidator, CombinedValidator, DefinitionsBuilder, ValidationState, Validator}; #[derive(Debug, Clone)] pub struct SetValidator { @@ -62,23 +61,11 @@ impl Validator for SetValidator { &'s self, py: Python<'data>, input: &'data impl Input<'data>, - extra: &Extra, - definitions: &'data Definitions, - recursion_guard: &'s mut RecursionGuard, + state: &mut ValidationState, ) -> ValResult<'data, PyObject> { - let collection = input.validate_set(extra.strict.unwrap_or(self.strict))?; + let collection = input.validate_set(state.strict_or(self.strict))?; let set = PySet::empty(py)?; - collection.validate_to_set( - py, - set, - input, - self.max_length, - "Set", - &self.item_validator, - extra, - definitions, - recursion_guard, - )?; + collection.validate_to_set(py, set, input, self.max_length, "Set", &self.item_validator, state)?; min_length_check!(input, "Set", self.min_length, set); Ok(set.into_py(py)) } diff --git a/src/validators/string.rs b/src/validators/string.rs index 4af1fb078..8b296c84a 100644 --- a/src/validators/string.rs +++ b/src/validators/string.rs @@ -6,10 +6,9 @@ use regex::Regex; use crate::build_tools::{is_strict, py_schema_error_type, schema_or_config}; use crate::errors::{ErrorType, ValError, ValResult}; use crate::input::Input; -use crate::recursion_guard::RecursionGuard; use crate::tools::SchemaDict; -use super::{BuildValidator, CombinedValidator, Definitions, DefinitionsBuilder, Extra, Validator}; +use super::{BuildValidator, CombinedValidator, DefinitionsBuilder, ValidationState, Validator}; #[derive(Debug, Clone)] pub struct StrValidator { @@ -44,11 +43,10 @@ impl Validator for StrValidator { &'s self, py: Python<'data>, input: &'data impl Input<'data>, - extra: &Extra, - _definitions: &'data Definitions, - _recursion_guard: &'s mut RecursionGuard, + state: &mut ValidationState, ) -> ValResult<'data, PyObject> { - Ok(input.validate_str(extra.strict.unwrap_or(self.strict))?.into_py(py)) + let either_str = input.validate_str(state.strict_or(self.strict))?; + Ok(either_str.into_py(py)) } fn different_strict_behavior( @@ -87,11 +85,9 @@ impl Validator for StrConstrainedValidator { &'s self, py: Python<'data>, input: &'data impl Input<'data>, - extra: &Extra, - _definitions: &'data Definitions, - _recursion_guard: &'s mut RecursionGuard, + state: &mut ValidationState, ) -> ValResult<'data, PyObject> { - let either_str = input.validate_str(extra.strict.unwrap_or(self.strict))?; + let either_str = input.validate_str(state.strict_or(self.strict))?; let cow = either_str.as_cow()?; let mut str = cow.as_ref(); if self.strip_whitespace { diff --git a/src/validators/time.rs b/src/validators/time.rs index 774b588f0..004ee54ee 100644 --- a/src/validators/time.rs +++ b/src/validators/time.rs @@ -7,12 +7,11 @@ use speedate::Time; use crate::build_tools::is_strict; use crate::errors::{ErrorType, ValError, ValResult}; use crate::input::{EitherTime, Input}; -use crate::recursion_guard::RecursionGuard; use crate::tools::SchemaDict; use super::datetime::extract_microseconds_precision; use super::datetime::TZConstraint; -use super::{BuildValidator, CombinedValidator, Definitions, DefinitionsBuilder, Extra, Validator}; +use super::{BuildValidator, CombinedValidator, DefinitionsBuilder, ValidationState, Validator}; #[derive(Debug, Clone)] pub struct TimeValidator { @@ -45,11 +44,9 @@ impl Validator for TimeValidator { &'s self, py: Python<'data>, input: &'data impl Input<'data>, - extra: &Extra, - _definitions: &'data Definitions, - _recursion_guard: &'s mut RecursionGuard, + state: &mut ValidationState, ) -> ValResult<'data, PyObject> { - let time = input.validate_time(extra.strict.unwrap_or(self.strict), self.microseconds_precision)?; + let time = input.validate_time(state.strict_or(self.strict), self.microseconds_precision)?; if let Some(constraints) = &self.constraints { let raw_time = time.as_raw()?; diff --git a/src/validators/timedelta.rs b/src/validators/timedelta.rs index 55da51d82..878a04048 100644 --- a/src/validators/timedelta.rs +++ b/src/validators/timedelta.rs @@ -5,10 +5,9 @@ use speedate::Duration; use crate::build_tools::is_strict; use crate::errors::{ErrorType, ValError, ValResult}; use crate::input::{duration_as_pytimedelta, EitherTimedelta, Input}; -use crate::recursion_guard::RecursionGuard; use super::datetime::extract_microseconds_precision; -use super::{BuildValidator, CombinedValidator, Definitions, DefinitionsBuilder, Extra, Validator}; +use super::{BuildValidator, CombinedValidator, DefinitionsBuilder, ValidationState, Validator}; #[derive(Debug, Clone)] pub struct TimeDeltaValidator { @@ -70,11 +69,9 @@ impl Validator for TimeDeltaValidator { &'s self, py: Python<'data>, input: &'data impl Input<'data>, - extra: &Extra, - _definitions: &'data Definitions, - _recursion_guard: &'s mut RecursionGuard, + state: &mut ValidationState, ) -> ValResult<'data, PyObject> { - let timedelta = input.validate_timedelta(extra.strict.unwrap_or(self.strict), self.microseconds_precision)?; + let timedelta = input.validate_timedelta(state.strict_or(self.strict), self.microseconds_precision)?; let py_timedelta = timedelta.try_into_py(py)?; if let Some(constraints) = &self.constraints { let raw_timedelta = timedelta.to_duration()?; diff --git a/src/validators/tuple.rs b/src/validators/tuple.rs index de5fb13c9..55c283829 100644 --- a/src/validators/tuple.rs +++ b/src/validators/tuple.rs @@ -5,11 +5,10 @@ use pyo3::types::{PyDict, PyList, PyTuple}; use crate::build_tools::is_strict; use crate::errors::{ErrorType, ErrorTypeDefaults, ValError, ValLineError, ValResult}; use crate::input::{GenericIterable, Input}; -use crate::recursion_guard::RecursionGuard; use crate::tools::SchemaDict; use super::list::{get_items_schema, min_length_check}; -use super::{build_validator, BuildValidator, CombinedValidator, Definitions, DefinitionsBuilder, Extra, Validator}; +use super::{build_validator, BuildValidator, CombinedValidator, DefinitionsBuilder, ValidationState, Validator}; #[derive(Debug, Clone)] pub struct TupleVariableValidator { @@ -49,23 +48,12 @@ impl Validator for TupleVariableValidator { &'s self, py: Python<'data>, input: &'data impl Input<'data>, - extra: &Extra, - definitions: &'data Definitions, - recursion_guard: &'s mut RecursionGuard, + state: &mut ValidationState, ) -> ValResult<'data, PyObject> { - let seq = input.validate_tuple(extra.strict.unwrap_or(self.strict))?; + let seq = input.validate_tuple(state.strict_or(self.strict))?; let output = match self.item_validator { - Some(ref v) => seq.validate_to_vec( - py, - input, - self.max_length, - "Tuple", - v, - extra, - definitions, - recursion_guard, - )?, + Some(ref v) => seq.validate_to_vec(py, input, self.max_length, "Tuple", v, state)?, None => seq.to_vec(py, input, "Tuple", self.max_length)?, }; min_length_check!(input, "Tuple", self.min_length, output); @@ -143,9 +131,7 @@ impl BuildValidator for TuplePositionalValidator { fn validate_tuple_positional<'s, 'data, T: Iterator>, I: Input<'data> + 'data>( py: Python<'data>, input: &'data impl Input<'data>, - extra: &Extra, - definitions: &'data Definitions, - recursion_guard: &'s mut RecursionGuard, + state: &mut ValidationState, output: &mut Vec, errors: &mut Vec>, extra_validator: &Option>, @@ -156,7 +142,7 @@ fn validate_tuple_positional<'s, 'data, T: Iterator>, ) -> ValResult<'data, ()> { for (index, validator) in items_validators.iter().enumerate() { match collection_iter.next() { - Some(result) => match validator.validate(py, result?, extra, definitions, recursion_guard) { + Some(result) => match validator.validate(py, result?, state) { Ok(item) => output.push(item), Err(ValError::LineErrors(line_errors)) => { errors.extend(line_errors.into_iter().map(|err| err.with_outer_location(index.into()))); @@ -164,7 +150,7 @@ fn validate_tuple_positional<'s, 'data, T: Iterator>, Err(err) => return Err(err), }, None => { - if let Some(value) = validator.default_value(py, Some(index), extra, definitions, recursion_guard)? { + if let Some(value) = validator.default_value(py, Some(index), state)? { output.push(value); } else { errors.push(ValLineError::new_with_loc(ErrorTypeDefaults::Missing, input, index)); @@ -175,20 +161,18 @@ fn validate_tuple_positional<'s, 'data, T: Iterator>, for (index, result) in collection_iter.enumerate() { let item = result?; match extra_validator { - Some(ref extra_validator) => { - match extra_validator.validate(py, item, extra, definitions, recursion_guard) { - Ok(item) => output.push(item), - Err(ValError::LineErrors(line_errors)) => { - errors.extend( - line_errors - .into_iter() - .map(|err| err.with_outer_location((index + expected_length).into())), - ); - } - Err(ValError::Omit) => (), - Err(err) => return Err(err), + Some(ref extra_validator) => match extra_validator.validate(py, item, state) { + Ok(item) => output.push(item), + Err(ValError::LineErrors(line_errors)) => { + errors.extend( + line_errors + .into_iter() + .map(|err| err.with_outer_location((index + expected_length).into())), + ); } - } + Err(ValError::Omit) => (), + Err(err) => return Err(err), + }, None => { errors.push(ValLineError::new( ErrorType::TooLong { @@ -217,11 +201,9 @@ impl Validator for TuplePositionalValidator { &'s self, py: Python<'data>, input: &'data impl Input<'data>, - extra: &Extra, - definitions: &'data Definitions, - recursion_guard: &'s mut RecursionGuard, + state: &mut ValidationState, ) -> ValResult<'data, PyObject> { - let collection = input.validate_tuple(extra.strict.unwrap_or(self.strict))?; + let collection = input.validate_tuple(state.strict_or(self.strict))?; let expected_length = self.items_validators.len(); let collection_len = collection.generic_len(); @@ -233,9 +215,7 @@ impl Validator for TuplePositionalValidator { validate_tuple_positional( py, input, - extra, - definitions, - recursion_guard, + state, &mut output, &mut errors, &self.extra_validator, diff --git a/src/validators/typed_dict.rs b/src/validators/typed_dict.rs index 8083f74fe..6066596c0 100644 --- a/src/validators/typed_dict.rs +++ b/src/validators/typed_dict.rs @@ -12,10 +12,13 @@ use crate::input::{ MappingGenericIterator, }; use crate::lookup_key::LookupKey; -use crate::recursion_guard::RecursionGuard; use crate::tools::SchemaDict; -use super::{build_validator, BuildValidator, CombinedValidator, Definitions, DefinitionsBuilder, Extra, Validator}; +use super::{ + build_validator, BuildValidator, CombinedValidator, DefinitionsBuilder, Extra, ValidationState, Validator, +}; + +use std::ops::ControlFlow; #[derive(Debug, Clone)] struct TypedDictField { @@ -144,11 +147,9 @@ impl Validator for TypedDictValidator { &'s self, py: Python<'data>, input: &'data impl Input<'data>, - extra: &Extra, - definitions: &'data Definitions, - recursion_guard: &'s mut RecursionGuard, + state: &mut ValidationState, ) -> ValResult<'data, PyObject> { - let strict = extra.strict.unwrap_or(self.strict); + let strict = state.strict_or(self.strict); let dict = input.validate_dict(strict)?; let output_dict = PyDict::new(py); @@ -162,59 +163,70 @@ impl Validator for TypedDictValidator { _ => None, }; + macro_rules! control_flow { + ($e: expr) => { + match $e { + Ok(v) => ControlFlow::Continue(v), + Err(err) => ControlFlow::Break(ValError::from(err)), + } + }; + } + macro_rules! process { ($dict:ident, $get_method:ident, $iter:ty $(,$kwargs:ident)?) => {{ - for field in &self.fields { - let extra = Extra { - data: Some(output_dict), - ..*extra - }; - let op_key_value = match field.lookup_key.$get_method($dict $(, $kwargs )? ) { - Ok(v) => v, - Err(err) => { - errors.push(ValLineError::new_with_loc( - ErrorType::GetAttributeError { - error: py_err_string(py, err), - context: None, - }, - input, - field.name.clone(), - )); - continue; - } - }; - if let Some((lookup_path, value)) = op_key_value { - if let Some(ref mut used_keys) = used_keys { - // key is "used" whether or not validation passes, since we want to skip this key in - // extra logic either way - used_keys.insert(lookup_path.first_key()); - } - match field - .validator - .validate(py, value, &extra, definitions, recursion_guard) - { - Ok(value) => { - output_dict.set_item(&field.name_py, value)?; + match state.with_new_extra(Extra { + data: Some(output_dict), + ..*state.extra() + }, |state| { + for field in &self.fields { + let op_key_value = match field.lookup_key.$get_method($dict $(, $kwargs )? ) { + Ok(v) => v, + Err(err) => { + errors.push(ValLineError::new_with_loc( + ErrorType::GetAttributeError { + error: py_err_string(py, err), + context: None, + }, + input, + field.name.clone(), + )); + continue; } - Err(ValError::Omit) => continue, - Err(ValError::LineErrors(line_errors)) => { - for err in line_errors { - errors.push(lookup_path.apply_error_loc(err, self.loc_by_alias, &field.name)); + }; + if let Some((lookup_path, value)) = op_key_value { + if let Some(ref mut used_keys) = used_keys { + // key is "used" whether or not validation passes, since we want to skip this key in + // extra logic either way + used_keys.insert(lookup_path.first_key()); + } + match field.validator.validate(py, value, state) { + Ok(value) => { + control_flow!(output_dict.set_item(&field.name_py, value))?; } + Err(ValError::Omit) => continue, + Err(ValError::LineErrors(line_errors)) => { + for err in line_errors { + errors.push(lookup_path.apply_error_loc(err, self.loc_by_alias, &field.name)); + } + } + Err(err) => return ControlFlow::Break(err), } - Err(err) => return Err(err), + continue; + } else if let Some(value) = control_flow!(field.validator.default_value(py, Some(field.name.as_str()), state))? { + control_flow!(output_dict.set_item(&field.name_py, value))?; + } else if field.required { + errors.push(field.lookup_key.error( + ErrorTypeDefaults::Missing, + input, + self.loc_by_alias, + &field.name + )); } - continue; - } else if let Some(value) = field.validator.default_value(py, Some(field.name.as_str()), &extra, definitions, recursion_guard)? { - output_dict.set_item(&field.name_py, value)?; - } else if field.required { - errors.push(field.lookup_key.error( - ErrorTypeDefaults::Missing, - input, - self.loc_by_alias, - &field.name - )); } + ControlFlow::Continue(()) + }) { + ControlFlow::Continue(()) => {} + ControlFlow::Break(err) => return Err(err), } if let Some(ref mut used_keys) = used_keys { @@ -250,7 +262,7 @@ impl Validator for TypedDictValidator { ExtraBehavior::Allow => { let py_key = either_str.as_py_string(py); if let Some(ref validator) = self.extra_validator { - match validator.validate(py, value, &extra, definitions, recursion_guard) { + match validator.validate(py, value, state) { Ok(value) => { output_dict.set_item(py_key, value)?; } diff --git a/src/validators/union.rs b/src/validators/union.rs index b9451725d..c3dfdd976 100644 --- a/src/validators/union.rs +++ b/src/validators/union.rs @@ -10,12 +10,11 @@ use crate::errors::{ErrorType, LocItem, ValError, ValLineError, ValResult}; use crate::input::{GenericMapping, Input}; use crate::lookup_key::LookupKey; use crate::py_gc::PyGcTraverse; -use crate::recursion_guard::RecursionGuard; use crate::tools::SchemaDict; use super::custom_error::CustomError; use super::literal::LiteralLookup; -use super::{build_validator, BuildValidator, CombinedValidator, Definitions, DefinitionsBuilder, Extra, Validator}; +use super::{build_validator, BuildValidator, CombinedValidator, DefinitionsBuilder, ValidationState, Validator}; #[derive(Debug, Clone)] pub struct UnionValidator { @@ -104,60 +103,52 @@ impl Validator for UnionValidator { &'s self, py: Python<'data>, input: &'data impl Input<'data>, - extra: &Extra, - definitions: &'data Definitions, - recursion_guard: &'s mut RecursionGuard, + state: &mut ValidationState, ) -> ValResult<'data, PyObject> { if self.ultra_strict_required { // do an ultra strict check first - let ultra_strict_extra = extra.as_strict(true); - if let Some(res) = self - .choices - .iter() - .map(|(validator, _label)| { - validator.validate(py, input, &ultra_strict_extra, definitions, recursion_guard) - }) - .find(ValResult::is_ok) - { + if let Some(res) = state.with_new_extra(state.extra().as_strict(true), |state| { + self.choices + .iter() + .map(|(validator, _label)| validator.validate(py, input, state)) + .find(ValResult::is_ok) + }) { return res; } } - if extra.strict.unwrap_or(self.strict) { + if state.strict_or(self.strict) { let mut errors: Option> = match self.custom_error { None => Some(Vec::with_capacity(self.choices.len())), _ => None, }; - let strict_extra = extra.as_strict(false); - - for (validator, label) in &self.choices { - let line_errors = match validator.validate(py, input, &strict_extra, definitions, recursion_guard) { - Err(ValError::LineErrors(line_errors)) => line_errors, - otherwise => return otherwise, - }; - - if let Some(ref mut errors) = errors { - errors.extend(line_errors.into_iter().map(|err| { - let case_label = label.as_deref().unwrap_or(validator.get_name()); - err.with_outer_location(case_label.into()) - })); + state.with_new_extra(state.extra().as_strict(false), |state| { + for (validator, label) in &self.choices { + let line_errors = match validator.validate(py, input, state) { + Err(ValError::LineErrors(line_errors)) => line_errors, + otherwise => return otherwise, + }; + + if let Some(ref mut errors) = errors { + errors.extend(line_errors.into_iter().map(|err| { + let case_label = label.as_deref().unwrap_or(validator.get_name()); + err.with_outer_location(case_label.into()) + })); + } } - } - Err(self.or_custom_error(errors, input)) + Err(self.or_custom_error(errors, input)) + }) } else { if self.strict_required { // 1st pass: check if the value is an exact instance of one of the Union types, // e.g. use validate in strict mode - let strict_extra = extra.as_strict(false); - if let Some(res) = self - .choices - .iter() - .map(|(validator, _label)| { - validator.validate(py, input, &strict_extra, definitions, recursion_guard) - }) - .find(ValResult::is_ok) - { + if let Some(res) = state.with_new_extra(state.extra().as_strict(false), |state| { + self.choices + .iter() + .map(|(validator, _label)| validator.validate(py, input, state)) + .find(ValResult::is_ok) + }) { return res; } } @@ -169,7 +160,7 @@ impl Validator for UnionValidator { // 2nd pass: check if the value can be coerced into one of the Union types, e.g. use validate for (validator, label) in &self.choices { - let line_errors = match validator.validate(py, input, extra, definitions, recursion_guard) { + let line_errors = match validator.validate(py, input, state) { Err(ValError::LineErrors(line_errors)) => line_errors, success => return success, }; @@ -329,9 +320,7 @@ impl Validator for TaggedUnionValidator { &'s self, py: Python<'data>, input: &'data impl Input<'data>, - extra: &Extra, - definitions: &'data Definitions, - recursion_guard: &'s mut RecursionGuard, + state: &mut ValidationState, ) -> ValResult<'data, PyObject> { match self.discriminator { Discriminator::LookupKey(ref lookup_key) => { @@ -347,7 +336,7 @@ impl Validator for TaggedUnionValidator { } }}; } - let from_attributes = extra.from_attributes.unwrap_or(self.from_attributes); + let from_attributes = state.extra().from_attributes.unwrap_or(self.from_attributes); let dict = input.validate_model_fields(self.strict, from_attributes)?; let tag = match dict { GenericMapping::PyDict(dict) => find_validator!(py_get_dict_item, dict), @@ -355,24 +344,19 @@ impl Validator for TaggedUnionValidator { GenericMapping::PyMapping(mapping) => find_validator!(py_get_mapping_item, mapping), GenericMapping::JsonObject(mapping) => find_validator!(json_get, mapping), }?; - self.find_call_validator(py, tag, input, extra, definitions, recursion_guard) + self.find_call_validator(py, tag, input, state) } Discriminator::Function(ref func) => { let tag = func.call1(py, (input.to_object(py),))?; if tag.is_none(py) { Err(self.tag_not_found(input)) } else { - self.find_call_validator(py, tag.into_ref(py), input, extra, definitions, recursion_guard) + self.find_call_validator(py, tag.into_ref(py), input, state) } } - Discriminator::SelfSchema => self.find_call_validator( - py, - self.self_schema_tag(py, input)?.as_ref(), - input, - extra, - definitions, - recursion_guard, - ), + Discriminator::SelfSchema => { + self.find_call_validator(py, self.self_schema_tag(py, input)?.as_ref(), input, state) + } } } @@ -450,12 +434,10 @@ impl TaggedUnionValidator { py: Python<'data>, tag: &'data PyAny, input: &'data impl Input<'data>, - extra: &Extra, - definitions: &'data Definitions, - recursion_guard: &'s mut RecursionGuard, + state: &mut ValidationState, ) -> ValResult<'data, PyObject> { if let Ok(Some((tag, validator))) = self.lookup.validate(py, tag) { - return match validator.validate(py, input, extra, definitions, recursion_guard) { + return match validator.validate(py, input, state) { Ok(res) => Ok(res), Err(err) => Err(err.with_outer_location(LocItem::try_from(tag.to_object(py).into_ref(py))?)), }; diff --git a/src/validators/url.rs b/src/validators/url.rs index 8c30a1656..f3be1e81d 100644 --- a/src/validators/url.rs +++ b/src/validators/url.rs @@ -12,12 +12,11 @@ use url::{ParseError, SyntaxViolation, Url}; use crate::build_tools::{is_strict, py_schema_err}; use crate::errors::{ErrorType, ErrorTypeDefaults, ValError, ValResult}; use crate::input::Input; -use crate::recursion_guard::RecursionGuard; use crate::tools::SchemaDict; use crate::url::{schema_is_special, PyMultiHostUrl, PyUrl}; use super::literal::expected_repr_name; -use super::{BuildValidator, CombinedValidator, Definitions, DefinitionsBuilder, Extra, Validator}; +use super::{BuildValidator, CombinedValidator, DefinitionsBuilder, ValidationState, Validator}; type AllowedSchemas = Option<(AHashSet, String)>; @@ -64,11 +63,9 @@ impl Validator for UrlValidator { &'s self, py: Python<'data>, input: &'data impl Input<'data>, - extra: &Extra, - _definitions: &'data Definitions, - _recursion_guard: &'s mut RecursionGuard, + state: &mut ValidationState, ) -> ValResult<'data, PyObject> { - let mut lib_url = self.get_url(input, extra.strict.unwrap_or(self.strict))?; + let mut lib_url = self.get_url(input, state.strict_or(self.strict))?; if let Some((ref allowed_schemes, ref expected_schemes_repr)) = self.allowed_schemes { if !allowed_schemes.contains(lib_url.scheme()) { @@ -207,11 +204,9 @@ impl Validator for MultiHostUrlValidator { &'s self, py: Python<'data>, input: &'data impl Input<'data>, - extra: &Extra, - _definitions: &'data Definitions, - _recursion_guard: &'s mut RecursionGuard, + state: &mut ValidationState, ) -> ValResult<'data, PyObject> { - let mut multi_url = self.get_url(input, extra.strict.unwrap_or(self.strict))?; + let mut multi_url = self.get_url(input, state.strict_or(self.strict))?; if let Some((ref allowed_schemes, ref expected_schemes_repr)) = self.allowed_schemes { if !allowed_schemes.contains(multi_url.scheme()) { diff --git a/src/validators/uuid.rs b/src/validators/uuid.rs index 6a15080f5..df00719b4 100644 --- a/src/validators/uuid.rs +++ b/src/validators/uuid.rs @@ -9,12 +9,11 @@ use uuid::Uuid; use crate::build_tools::is_strict; use crate::errors::{ErrorType, ErrorTypeDefaults, ValError, ValResult}; use crate::input::Input; -use crate::recursion_guard::RecursionGuard; use crate::tools::SchemaDict; use super::model::create_class; use super::model::force_setattr; -use super::{BuildValidator, CombinedValidator, Definitions, DefinitionsBuilder, Extra, Validator}; +use super::{BuildValidator, CombinedValidator, DefinitionsBuilder, ValidationState, Validator}; const UUID_INT: &str = "int"; const UUID_IS_SAFE: &str = "is_safe"; @@ -88,9 +87,7 @@ impl Validator for UuidValidator { &'s self, py: Python<'data>, input: &'data impl Input<'data>, - extra: &Extra, - _definitions: &'data Definitions, - _recursion_guard: &'s mut RecursionGuard, + state: &mut ValidationState, ) -> ValResult<'data, PyObject> { let class = get_uuid_type(py)?; if let Some(py_input) = input.input_is_instance(class) { @@ -108,7 +105,7 @@ impl Validator for UuidValidator { } } Ok(py_input.to_object(py)) - } else if extra.strict.unwrap_or(self.strict) && input.is_python() { + } else if state.strict_or(self.strict) && input.is_python() { Err(ValError::new( ErrorType::IsInstanceOf { class: class.name().unwrap_or("UUID").to_string(), diff --git a/src/validators/validation_state.rs b/src/validators/validation_state.rs new file mode 100644 index 000000000..fd15e56bc --- /dev/null +++ b/src/validators/validation_state.rs @@ -0,0 +1,47 @@ +use crate::{definitions::Definitions, recursion_guard::RecursionGuard}; + +use super::{CombinedValidator, Extra}; + +pub struct ValidationState<'a> { + pub recursion_guard: &'a mut RecursionGuard, + pub definitions: &'a Definitions, + // deliberately make Extra readonly + extra: Extra<'a>, +} + +impl<'a> ValidationState<'a> { + pub fn new( + extra: Extra<'a>, + definitions: &'a Definitions, + recursion_guard: &'a mut RecursionGuard, + ) -> Self { + Self { + recursion_guard, + definitions, + extra, + } + } + + pub fn with_new_extra<'r, R: 'r>( + &mut self, + extra: Extra<'_>, + f: impl for<'s> FnOnce(&'s mut ValidationState<'_>) -> R, + ) -> R { + // TODO: It would be nice to implement this function with a drop guard instead of a closure, + // but lifetimes get in a tangle. Maybe someone brave wants to have a go at unpicking lifetimes. + let mut new_state = ValidationState { + recursion_guard: self.recursion_guard, + definitions: self.definitions, + extra, + }; + f(&mut new_state) + } + + pub fn extra(&self) -> &'_ Extra<'a> { + &self.extra + } + + pub fn strict_or(&self, default: bool) -> bool { + self.extra.strict.unwrap_or(default) + } +} diff --git a/src/validators/with_default.rs b/src/validators/with_default.rs index a89028ecb..819ca1973 100644 --- a/src/validators/with_default.rs +++ b/src/validators/with_default.rs @@ -5,13 +5,12 @@ use pyo3::types::PyDict; use pyo3::PyTraverseError; use pyo3::PyVisit; -use super::{build_validator, BuildValidator, CombinedValidator, Definitions, DefinitionsBuilder, Extra, Validator}; +use super::{build_validator, BuildValidator, CombinedValidator, DefinitionsBuilder, ValidationState, Validator}; use crate::build_tools::py_schema_err; use crate::build_tools::schema_or_config_same; use crate::errors::{LocItem, ValError, ValResult}; use crate::input::Input; use crate::py_gc::PyGcTraverse; -use crate::recursion_guard::RecursionGuard; use crate::tools::SchemaDict; use crate::PydanticUndefinedType; @@ -131,26 +130,18 @@ impl Validator for WithDefaultValidator { &'s self, py: Python<'data>, input: &'data impl Input<'data>, - extra: &Extra, - definitions: &'data Definitions, - recursion_guard: &'s mut RecursionGuard, + state: &mut ValidationState, ) -> ValResult<'data, PyObject> { if input.to_object(py).is(&PydanticUndefinedType::py_undefined()) { - Ok(self - .default_value(py, None::, extra, definitions, recursion_guard)? - .unwrap()) + Ok(self.default_value(py, None::, state)?.unwrap()) } else { - match self.validator.validate(py, input, extra, definitions, recursion_guard) { + match self.validator.validate(py, input, state) { Ok(v) => Ok(v), Err(e) => match e { - ValError::UseDefault => Ok(self - .default_value(py, None::, extra, definitions, recursion_guard)? - .ok_or(e)?), + ValError::UseDefault => Ok(self.default_value(py, None::, state)?.ok_or(e)?), e => match self.on_error { OnError::Raise => Err(e), - OnError::Default => Ok(self - .default_value(py, None::, extra, definitions, recursion_guard)? - .ok_or(e)?), + OnError::Default => Ok(self.default_value(py, None::, state)?.ok_or(e)?), OnError::Omit => Err(ValError::Omit), }, }, @@ -158,13 +149,11 @@ impl Validator for WithDefaultValidator { } } - fn default_value<'s, 'data>( - &'s self, + fn default_value<'data>( + &self, py: Python<'data>, outer_loc: Option>, - extra: &Extra, - definitions: &'data Definitions, - recursion_guard: &'s mut RecursionGuard, + state: &mut ValidationState, ) -> ValResult<'data, Option> { match self.default.default_value(py)? { Some(stored_dft) => { @@ -175,7 +164,7 @@ impl Validator for WithDefaultValidator { stored_dft }; if self.validate_default { - match self.validate(py, dft.into_ref(py), extra, definitions, recursion_guard) { + match self.validate(py, dft.into_ref(py), state) { Ok(v) => Ok(Some(v)), Err(e) => { if let Some(outer_loc) = outer_loc {