Skip to content

Commit

Permalink
Use new format for union/tagged-union error locs
Browse files Browse the repository at this point in the history
  • Loading branch information
dmontagu committed Jul 28, 2023
1 parent 1aaa035 commit be5fc6e
Show file tree
Hide file tree
Showing 24 changed files with 144 additions and 80 deletions.
11 changes: 11 additions & 0 deletions generate_self_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,17 @@ def type_dict_schema(typed_dict) -> dict[str, Any]: # noqa: C901
schema = {'type': 'dict', 'keys_schema': {'type': 'str'}, 'values_schema': schema_ref_validator}
elif fr_arg == 'Dict[Hashable, CoreSchema]':
schema = {'type': 'dict', 'keys_schema': {'type': 'any'}, 'values_schema': schema_ref_validator}
elif fr_arg == 'List[Union[CoreSchema, Tuple[CoreSchema, str]]]':
schema = {
'type': 'list',
'items_schema': {
'type': 'union',
'choices': [
schema_ref_validator,
{'type': 'tuple-positional', 'items_schema': [schema_ref_validator, {'type': 'str'}]},
],
},
}
else:
raise ValueError(f'Unknown Schema forward ref: {fr_arg}')
else:
Expand Down
8 changes: 4 additions & 4 deletions python/pydantic_core/core_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import sys
from collections.abc import Mapping
from datetime import date, datetime, time, timedelta
from typing import TYPE_CHECKING, Any, Callable, Dict, Hashable, List, Set, Type, Union
from typing import TYPE_CHECKING, Any, Callable, Dict, Hashable, List, Set, Tuple, Type, Union

if sys.version_info < (3, 11):
from typing_extensions import Protocol, Required, TypeAlias
Expand Down Expand Up @@ -2379,7 +2379,7 @@ def nullable_schema(

class UnionSchema(TypedDict, total=False):
type: Required[Literal['union']]
choices: Required[List[CoreSchema]]
choices: Required[List[Union[CoreSchema, Tuple[CoreSchema, str]]]]
# default true, whether to automatically collapse unions with one element to the inner validator
auto_collapse: bool
custom_error_type: str
Expand All @@ -2392,7 +2392,7 @@ class UnionSchema(TypedDict, total=False):


def union_schema(
choices: list[CoreSchema],
choices: list[CoreSchema | tuple[CoreSchema, str]],
*,
auto_collapse: bool | None = None,
custom_error_type: str | None = None,
Expand All @@ -2416,7 +2416,7 @@ def union_schema(
```
Args:
choices: The schemas to match
choices: The schemas to match. If a tuple, the second item is used as the label for the case.
auto_collapse: whether to automatically collapse unions with one element to the inner validator, default true
custom_error_type: The custom error type to use if the validation fails
custom_error_message: The custom error message to use if the validation fails
Expand Down
54 changes: 37 additions & 17 deletions src/validators/union.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
use std::fmt::Write;

use pyo3::prelude::*;
use pyo3::types::{PyDict, PyList, PyString};
use pyo3::types::{PyDict, PyList, PyString, PyTuple};
use pyo3::{intern, PyTraverseError, PyVisit};

use crate::build_tools::py_schema_err;
use crate::build_tools::{is_strict, schema_or_config};
use crate::errors::{ErrorType, LocItem, ValError, ValLineError, ValResult};
use crate::errors::{ErrorType, ValError, ValLineError, ValResult};
use crate::input::{GenericMapping, Input};
use crate::lookup_key::LookupKey;
use crate::py_gc::PyGcTraverse;
Expand All @@ -20,6 +20,7 @@ use super::{build_validator, BuildValidator, CombinedValidator, Definitions, Def
#[derive(Debug, Clone)]
pub struct UnionValidator {
choices: Vec<CombinedValidator>,
labels: Vec<Option<String>>,
custom_error: Option<CustomError>,
strict: bool,
name: String,
Expand All @@ -36,21 +37,42 @@ impl BuildValidator for UnionValidator {
definitions: &mut DefinitionsBuilder<CombinedValidator>,
) -> PyResult<CombinedValidator> {
let py = schema.py();
let mut labels: Vec<Option<String>> = vec![];
let choices: Vec<CombinedValidator> = schema
.get_as_req::<&PyList>(intern!(py, "choices"))?
.iter()
.map(|choice| build_validator(choice, config, definitions))
.map(|choice| {
let choice: &PyAny = match choice.downcast::<PyTuple>() {
Ok(py_tuple) => {
let choice = py_tuple.get_item(0)?;
let label = py_tuple.get_item(1)?;
labels.push(Some(label.to_string()));
choice
}
Err(_) => {
labels.push(None);
choice
}
};
build_validator(choice, config, definitions)
})
.collect::<PyResult<Vec<CombinedValidator>>>()?;

let auto_collapse = || schema.get_as_req(intern!(py, "auto_collapse")).unwrap_or(true);
match choices.len() {
0 => py_schema_err!("One or more union choices required"),
1 if auto_collapse() => Ok(choices.into_iter().next().unwrap()),
_ => {
let descr = choices.iter().map(Validator::get_name).collect::<Vec<_>>().join(",");
let descr = choices
.iter()
.zip(&labels)
.map(|(choice, label)| label.as_deref().unwrap_or(choice.get_name()))
.collect::<Vec<_>>()
.join(",");

Ok(Self {
choices,
labels,
custom_error: CustomError::build(schema, config, definitions)?,
strict: is_strict(schema, config)?,
name: format!("{}[{descr}]", Self::EXPECTED_TYPE),
Expand Down Expand Up @@ -108,18 +130,17 @@ impl Validator for UnionValidator {
};
let strict_extra = extra.as_strict(false);

for validator in &self.choices {
for (validator, label) in self.choices.iter().zip(&self.labels) {
let line_errors = match validator.validate(py, input, &strict_extra, definitions, recursion_guard) {
Err(ValError::LineErrors(line_errors)) => line_errors,
otherwise => return otherwise,
};

if let Some(ref mut errors) = errors {
errors.extend(
line_errors
.into_iter()
.map(|err| err.with_outer_location(validator.get_name().into())),
);
errors.extend(line_errors.into_iter().map(|err| {
let case_label = label.as_deref().unwrap_or(validator.get_name());
err.with_outer_location(format!("[case:{case_label}]").into())
}));
}
}

Expand All @@ -145,18 +166,17 @@ impl Validator for UnionValidator {
};

// 2nd pass: check if the value can be coerced into one of the Union types, e.g. use validate
for validator in &self.choices {
for (validator, label) in self.choices.iter().zip(self.labels.iter()) {
let line_errors = match validator.validate(py, input, extra, definitions, recursion_guard) {
Err(ValError::LineErrors(line_errors)) => line_errors,
success => return success,
};

if let Some(ref mut errors) = errors {
errors.extend(
line_errors
.into_iter()
.map(|err| err.with_outer_location(validator.get_name().into())),
);
errors.extend(line_errors.into_iter().map(|err| {
let case_label = label.as_deref().unwrap_or(validator.get_name());
err.with_outer_location(format!("[case:{case_label}]").into())
}));
}
}

Expand Down Expand Up @@ -435,7 +455,7 @@ impl TaggedUnionValidator {
if let Ok(Some((tag, validator))) = self.lookup.validate(py, tag) {
return match validator.validate(py, input, extra, definitions, recursion_guard) {
Ok(res) => Ok(res),
Err(err) => Err(err.with_outer_location(LocItem::try_from(tag.to_object(py).into_ref(py))?)),
Err(err) => Err(err.with_outer_location(format!("[tag:{}]", tag.repr()?).into())),
};
}
match self.custom_error {
Expand Down
2 changes: 1 addition & 1 deletion tests/serializers/test_definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def test_def_error():
)

assert str(exc_info.value).startswith(
"Invalid Schema:\ndefinitions.definitions.1\n Input tag 'wrong' found using 'type'"
"Invalid Schema:\n[tag:'definitions'].definitions.1\n Input tag 'wrong' found using 'type'"
)


Expand Down
4 changes: 2 additions & 2 deletions tests/validators/test_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,11 +116,11 @@ def my_function(a):
assert exc_info.value.errors(include_url=False) == [
{
'type': 'unexpected_positional_argument',
'loc': ('call[my_function]', 1),
'loc': ('[case:call[my_function]]', 1),
'msg': 'Unexpected positional argument',
'input': 2,
},
{'type': 'int_type', 'loc': ('int',), 'msg': 'Input should be a valid integer', 'input': (1, 2)},
{'type': 'int_type', 'loc': ('[case:int]',), 'msg': 'Input should be a valid integer', 'input': (1, 2)},
]


Expand Down
2 changes: 1 addition & 1 deletion tests/validators/test_callable.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,5 +57,5 @@ def test_repr():
assert v.isinstance_python(func) is True
assert v.isinstance_python('foo') is False

with pytest.raises(ValidationError, match=r'callable\s+Input should be callable'):
with pytest.raises(ValidationError, match=r'\[case:callable]\s+Input should be callable'):
v.validate_python('foo')
4 changes: 2 additions & 2 deletions tests/validators/test_dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -1224,11 +1224,11 @@ def test_custom_dataclass_names():
{
'ctx': {'class_name': 'FooDataclass[dataclass_args_schema]'},
'input': 123,
'loc': ('foo', 'FooDataclass[cls_name]'),
'loc': ('foo', '[case:FooDataclass[cls_name]]'),
'msg': 'Input should be a dictionary or an instance of FooDataclass[dataclass_args_schema]',
'type': 'dataclass_type',
},
{'input': 123, 'loc': ('foo', 'none'), 'msg': 'Input should be None', 'type': 'none_required'},
{'input': 123, 'loc': ('foo', '[case:none]'), 'msg': 'Input should be None', 'type': 'none_required'},
]


Expand Down
2 changes: 1 addition & 1 deletion tests/validators/test_date.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ def test_date_kwargs(kwargs: Dict[str, Any], input_value, expected):


def test_invalid_constraint():
with pytest.raises(SchemaError, match=r'date\.gt\n Input should be a valid date or datetime'):
with pytest.raises(SchemaError, match=r"\[tag:'date']\.gt\n Input should be a valid date or datetime"):
SchemaValidator({'type': 'date', 'gt': 'foobar'})


Expand Down
2 changes: 1 addition & 1 deletion tests/validators/test_datetime.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,7 @@ def test_union():


def test_invalid_constraint():
with pytest.raises(SchemaError, match=r'datetime\.gt\n Input should be a valid datetime'):
with pytest.raises(SchemaError, match=r"\[tag:'datetime']\.gt\n Input should be a valid datetime"):
SchemaValidator({'type': 'datetime', 'gt': 'foobar'})


Expand Down
2 changes: 1 addition & 1 deletion tests/validators/test_definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def test_def_error():
)
)
assert str(exc_info.value).startswith(
"Invalid Schema:\ndefinitions.definitions.1\n Input tag 'wrong' found using 'type'"
"Invalid Schema:\n[tag:'definitions'].definitions.1\n Input tag 'wrong' found using 'type'"
)
assert exc_info.value.error_count() == 1

Expand Down
18 changes: 9 additions & 9 deletions tests/validators/test_definitions_recursive.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,13 +126,13 @@ def test_nullable_error():
assert exc_info.value.errors(include_url=False) == [
{
'type': 'none_required',
'loc': ('sub_branch', 'none'),
'loc': ('sub_branch', '[case:none]'),
'msg': 'Input should be None',
'input': {'width': 'wrong'},
},
{
'type': 'int_parsing',
'loc': ('sub_branch', 'typed-dict', 'width'),
'loc': ('sub_branch', '[case:typed-dict]', 'width'),
'msg': 'Input should be a valid integer, unable to parse string as an integer',
'input': 'wrong',
},
Expand Down Expand Up @@ -606,8 +606,8 @@ def test_union_ref_strictness():
v.validate_python({'a': 1, 'b': []})

assert exc_info.value.errors(include_url=False) == [
{'type': 'int_type', 'loc': ('b', 'int'), 'msg': 'Input should be a valid integer', 'input': []},
{'type': 'string_type', 'loc': ('b', 'str'), 'msg': 'Input should be a valid string', 'input': []},
{'type': 'int_type', 'loc': ('b', '[case:int]'), 'msg': 'Input should be a valid integer', 'input': []},
{'type': 'string_type', 'loc': ('b', '[case:str]'), 'msg': 'Input should be a valid string', 'input': []},
]


Expand All @@ -631,8 +631,8 @@ def test_union_container_strictness():
v.validate_python({'a': 1, 'b': []})

assert exc_info.value.errors(include_url=False) == [
{'type': 'int_type', 'loc': ('b', 'int'), 'msg': 'Input should be a valid integer', 'input': []},
{'type': 'string_type', 'loc': ('b', 'str'), 'msg': 'Input should be a valid string', 'input': []},
{'type': 'int_type', 'loc': ('b', '[case:int]'), 'msg': 'Input should be a valid integer', 'input': []},
{'type': 'string_type', 'loc': ('b', '[case:str]'), 'msg': 'Input should be a valid string', 'input': []},
]


Expand Down Expand Up @@ -669,7 +669,7 @@ def test_union_cycle(strict: bool):
assert exc_info.value.errors(include_url=False) == [
{
'type': 'recursion_loop',
'loc': ('typed-dict', 'foobar', 0),
'loc': ('[case:typed-dict]', 'foobar', 0),
'msg': 'Recursion error - cyclic reference detected',
'input': {'foobar': [{'foobar': IsList(length=1)}]},
}
Expand Down Expand Up @@ -703,13 +703,13 @@ def f(input_value, info):
assert exc_info.value.errors(include_url=False) == [
{
'type': 'recursion_loop',
'loc': ('function-after[f(), ...]',),
'loc': ('[case:function-after[f(), ...]]',),
'msg': 'Recursion error - cyclic reference detected',
'input': 'input value',
},
{
'type': 'int_parsing',
'loc': ('int',),
'loc': ('[case:int]',),
'msg': 'Input should be a valid integer, unable to parse string as an integer',
'input': 'input value',
},
Expand Down
8 changes: 4 additions & 4 deletions tests/validators/test_float.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,10 +142,10 @@ def test_union_float(py_and_json: PyAndJson):
with pytest.raises(ValidationError) as exc_info:
v.validate_test('5')
assert exc_info.value.errors(include_url=False) == [
{'type': 'float_type', 'loc': ('float',), 'msg': 'Input should be a valid number', 'input': '5'},
{'type': 'float_type', 'loc': ('[case:float]',), 'msg': 'Input should be a valid number', 'input': '5'},
{
'type': 'multiple_of',
'loc': ('constrained-float',),
'loc': ('[case:constrained-float]',),
'msg': 'Input should be a multiple of 7',
'input': '5',
'ctx': {'multiple_of': 7.0},
Expand All @@ -162,13 +162,13 @@ def test_union_float_simple(py_and_json: PyAndJson):
assert exc_info.value.errors(include_url=False) == [
{
'type': 'float_parsing',
'loc': ('float',),
'loc': ('[case:float]',),
'msg': 'Input should be a valid number, unable to parse string as a number',
'input': 'xxx',
},
{
'type': 'list_type',
'loc': ('list[any]',),
'loc': ('[case:list[any]]',),
'msg': IsStr(regex='Input should be a valid (list|array)'),
'input': 'xxx',
},
Expand Down
4 changes: 2 additions & 2 deletions tests/validators/test_frozenset.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,14 +198,14 @@ def test_union_frozenset_list(input_value, expected):
errors=[
{
'type': 'int_type',
'loc': ('frozenset[int]', 1),
'loc': ('[case:frozenset[int]]', 1),
'msg': 'Input should be a valid integer',
'input': 'a',
},
# second because validation on the string choice comes second
{
'type': 'string_type',
'loc': ('frozenset[str]', 0),
'loc': ('[case:frozenset[str]]', 0),
'msg': 'Input should be a valid string',
'input': 1,
},
Expand Down
8 changes: 5 additions & 3 deletions tests/validators/test_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,12 +215,14 @@ def f(input_value, validator, info):


def test_function_wrap_not_callable():
with pytest.raises(SchemaError, match='function-wrap.function.typed-dict.function\n Input should be callable'):
with pytest.raises(
SchemaError, match=r"\[tag:'function-wrap'].function.\[case:typed-dict].function\n Input should be callable"
):
SchemaValidator(
{'type': 'function-wrap', 'function': {'type': 'general', 'function': []}, 'schema': {'type': 'str'}}
)

with pytest.raises(SchemaError, match='function-wrap.function\n Field required'):
with pytest.raises(SchemaError, match=r"\[tag:'function-wrap'].function\n Field required"):
SchemaValidator({'type': 'function-wrap', 'schema': {'type': 'str'}})


Expand Down Expand Up @@ -442,7 +444,7 @@ def f(input_value):


def test_plain_with_schema():
with pytest.raises(SchemaError, match='function-plain.schema\n Extra inputs are not permitted'):
with pytest.raises(SchemaError, match=r"\[tag:'function-plain'].schema\n Extra inputs are not permitted"):
SchemaValidator(
{
'type': 'function-plain',
Expand Down
Loading

0 comments on commit be5fc6e

Please sign in to comment.