Skip to content

Commit

Permalink
Support manually specifying case labels for union validators (#841)
Browse files Browse the repository at this point in the history
  • Loading branch information
dmontagu authored Aug 15, 2023
1 parent 5c98b05 commit 09f0acf
Show file tree
Hide file tree
Showing 6 changed files with 96 additions and 32 deletions.
11 changes: 11 additions & 0 deletions generate_self_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,17 @@ def type_dict_schema(typed_dict) -> dict[str, Any]: # noqa: C901
schema = {'type': 'dict', 'keys_schema': {'type': 'str'}, 'values_schema': schema_ref_validator}
elif fr_arg == 'Dict[Hashable, CoreSchema]':
schema = {'type': 'dict', 'keys_schema': {'type': 'any'}, 'values_schema': schema_ref_validator}
elif fr_arg == 'List[Union[CoreSchema, Tuple[CoreSchema, str]]]':
schema = {
'type': 'list',
'items_schema': {
'type': 'union',
'choices': [
schema_ref_validator,
{'type': 'tuple-positional', 'items_schema': [schema_ref_validator, {'type': 'str'}]},
],
},
}
else:
raise ValueError(f'Unknown Schema forward ref: {fr_arg}')
else:
Expand Down
8 changes: 4 additions & 4 deletions python/pydantic_core/core_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from collections.abc import Mapping
from datetime import date, datetime, time, timedelta
from decimal import Decimal
from typing import TYPE_CHECKING, Any, Callable, Dict, Hashable, List, Set, Type, Union
from typing import TYPE_CHECKING, Any, Callable, Dict, Hashable, List, Set, Tuple, Type, Union

if sys.version_info < (3, 11):
from typing_extensions import Protocol, Required, TypeAlias
Expand Down Expand Up @@ -2454,7 +2454,7 @@ def nullable_schema(

class UnionSchema(TypedDict, total=False):
type: Required[Literal['union']]
choices: Required[List[CoreSchema]]
choices: Required[List[Union[CoreSchema, Tuple[CoreSchema, str]]]]
# default true, whether to automatically collapse unions with one element to the inner validator
auto_collapse: bool
custom_error_type: str
Expand All @@ -2467,7 +2467,7 @@ class UnionSchema(TypedDict, total=False):


def union_schema(
choices: list[CoreSchema],
choices: list[CoreSchema | tuple[CoreSchema, str]],
*,
auto_collapse: bool | None = None,
custom_error_type: str | None = None,
Expand All @@ -2491,7 +2491,7 @@ def union_schema(
```
Args:
choices: The schemas to match
choices: The schemas to match. If a tuple, the second item is used as the label for the case.
auto_collapse: whether to automatically collapse unions with one element to the inner validator, default true
custom_error_type: The custom error type to use if the validation fails
custom_error_message: The custom error message to use if the validation fails
Expand Down
10 changes: 8 additions & 2 deletions src/serializers/type_serializers/union.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use pyo3::intern;
use pyo3::prelude::*;
use pyo3::types::{PyDict, PyList};
use pyo3::types::{PyDict, PyList, PyTuple};
use std::borrow::Cow;

use crate::build_tools::py_schema_err;
Expand Down Expand Up @@ -31,7 +31,13 @@ impl BuildSerializer for UnionSerializer {
let choices: Vec<CombinedSerializer> = schema
.get_as_req::<&PyList>(intern!(py, "choices"))?
.iter()
.map(|choice| CombinedSerializer::build(choice.downcast()?, config, definitions))
.map(|choice| {
let choice: &PyAny = match choice.downcast::<PyTuple>() {
Ok(py_tuple) => py_tuple.get_item(0)?,
Err(_) => choice,
};
CombinedSerializer::build(choice.downcast()?, config, definitions)
})
.collect::<PyResult<Vec<CombinedSerializer>>>()?;

Self::from_choices(choices)
Expand Down
70 changes: 46 additions & 24 deletions src/validators/union.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::fmt::Write;

use pyo3::prelude::*;
use pyo3::types::{PyDict, PyList, PyString};
use pyo3::types::{PyDict, PyList, PyString, PyTuple};
use pyo3::{intern, PyTraverseError, PyVisit};

use crate::build_tools::py_schema_err;
Expand All @@ -19,7 +19,7 @@ use super::{build_validator, BuildValidator, CombinedValidator, Definitions, Def

#[derive(Debug, Clone)]
pub struct UnionValidator {
choices: Vec<CombinedValidator>,
choices: Vec<(CombinedValidator, Option<String>)>,
custom_error: Option<CustomError>,
strict: bool,
name: String,
Expand All @@ -36,18 +36,33 @@ impl BuildValidator for UnionValidator {
definitions: &mut DefinitionsBuilder<CombinedValidator>,
) -> PyResult<CombinedValidator> {
let py = schema.py();
let choices: Vec<CombinedValidator> = schema
let choices: Vec<(CombinedValidator, Option<String>)> = schema
.get_as_req::<&PyList>(intern!(py, "choices"))?
.iter()
.map(|choice| build_validator(choice, config, definitions))
.collect::<PyResult<Vec<CombinedValidator>>>()?;
.map(|choice| {
let mut label: Option<String> = None;
let choice: &PyAny = match choice.downcast::<PyTuple>() {
Ok(py_tuple) => {
let choice = py_tuple.get_item(0)?;
label = Some(py_tuple.get_item(1)?.to_string());
choice
}
Err(_) => choice,
};
Ok((build_validator(choice, config, definitions)?, label))
})
.collect::<PyResult<Vec<(CombinedValidator, Option<String>)>>>()?;

let auto_collapse = || schema.get_as_req(intern!(py, "auto_collapse")).unwrap_or(true);
match choices.len() {
0 => py_schema_err!("One or more union choices required"),
1 if auto_collapse() => Ok(choices.into_iter().next().unwrap()),
1 if auto_collapse() => Ok(choices.into_iter().next().unwrap().0),
_ => {
let descr = choices.iter().map(Validator::get_name).collect::<Vec<_>>().join(",");
let descr = choices
.iter()
.map(|(choice, label)| label.as_deref().unwrap_or(choice.get_name()))
.collect::<Vec<_>>()
.join(",");

Ok(Self {
choices,
Expand Down Expand Up @@ -77,7 +92,12 @@ impl UnionValidator {
}
}

impl_py_gc_traverse!(UnionValidator { choices });
impl PyGcTraverse for UnionValidator {
fn py_gc_traverse(&self, visit: &PyVisit<'_>) -> Result<(), PyTraverseError> {
self.choices.iter().try_for_each(|(v, _)| v.py_gc_traverse(visit))?;
Ok(())
}
}

impl Validator for UnionValidator {
fn validate<'s, 'data>(
Expand All @@ -94,7 +114,9 @@ impl Validator for UnionValidator {
if let Some(res) = self
.choices
.iter()
.map(|validator| validator.validate(py, input, &ultra_strict_extra, definitions, recursion_guard))
.map(|(validator, _label)| {
validator.validate(py, input, &ultra_strict_extra, definitions, recursion_guard)
})
.find(ValResult::is_ok)
{
return res;
Expand All @@ -108,18 +130,17 @@ impl Validator for UnionValidator {
};
let strict_extra = extra.as_strict(false);

for validator in &self.choices {
for (validator, label) in &self.choices {
let line_errors = match validator.validate(py, input, &strict_extra, definitions, recursion_guard) {
Err(ValError::LineErrors(line_errors)) => line_errors,
otherwise => return otherwise,
};

if let Some(ref mut errors) = errors {
errors.extend(
line_errors
.into_iter()
.map(|err| err.with_outer_location(validator.get_name().into())),
);
errors.extend(line_errors.into_iter().map(|err| {
let case_label = label.as_deref().unwrap_or(validator.get_name());
err.with_outer_location(case_label.into())
}));
}
}

Expand All @@ -132,7 +153,9 @@ impl Validator for UnionValidator {
if let Some(res) = self
.choices
.iter()
.map(|validator| validator.validate(py, input, &strict_extra, definitions, recursion_guard))
.map(|(validator, _label)| {
validator.validate(py, input, &strict_extra, definitions, recursion_guard)
})
.find(ValResult::is_ok)
{
return res;
Expand All @@ -145,18 +168,17 @@ impl Validator for UnionValidator {
};

// 2nd pass: check if the value can be coerced into one of the Union types, e.g. use validate
for validator in &self.choices {
for (validator, label) in &self.choices {
let line_errors = match validator.validate(py, input, extra, definitions, recursion_guard) {
Err(ValError::LineErrors(line_errors)) => line_errors,
success => return success,
};

if let Some(ref mut errors) = errors {
errors.extend(
line_errors
.into_iter()
.map(|err| err.with_outer_location(validator.get_name().into())),
);
errors.extend(line_errors.into_iter().map(|err| {
let case_label = label.as_deref().unwrap_or(validator.get_name());
err.with_outer_location(case_label.into())
}));
}
}

Expand All @@ -171,15 +193,15 @@ impl Validator for UnionValidator {
) -> bool {
self.choices
.iter()
.any(|v| v.different_strict_behavior(definitions, ultra_strict))
.any(|(v, _)| v.different_strict_behavior(definitions, ultra_strict))
}

fn get_name(&self) -> &str {
&self.name
}

fn complete(&mut self, definitions: &DefinitionsBuilder<CombinedValidator>) -> PyResult<()> {
self.choices.iter_mut().try_for_each(|v| v.complete(definitions))?;
self.choices.iter_mut().try_for_each(|(v, _)| v.complete(definitions))?;
self.strict_required = self.different_strict_behavior(Some(definitions), false);
self.ultra_strict_required = self.different_strict_behavior(Some(definitions), true);
Ok(())
Expand Down
9 changes: 7 additions & 2 deletions tests/serializers/test_union.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,14 @@ def __init__(self, **kwargs) -> None:
setattr(self, name, value)


@pytest.mark.parametrize('bool_case_label', [False, True])
@pytest.mark.parametrize('int_case_label', [False, True])
@pytest.mark.parametrize('input_value,expected_value', [(True, True), (False, False), (1, 1), (123, 123), (-42, -42)])
def test_union_bool_int(input_value, expected_value):
s = SchemaSerializer(core_schema.union_schema([core_schema.bool_schema(), core_schema.int_schema()]))
def test_union_bool_int(input_value, expected_value, bool_case_label, int_case_label):
bool_case = core_schema.bool_schema() if not bool_case_label else (core_schema.bool_schema(), 'my_bool_label')
int_case = core_schema.int_schema() if not int_case_label else (core_schema.int_schema(), 'my_int_label')
s = SchemaSerializer(core_schema.union_schema([bool_case, int_case]))

assert s.to_python(input_value) == expected_value
assert s.to_python(input_value, mode='json') == expected_value
assert s.to_json(input_value) == json.dumps(expected_value).encode()
Expand Down
20 changes: 20 additions & 0 deletions tests/validators/test_union.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,3 +411,23 @@ def test_strict_reference():

assert repr(v.validate_python((1, 2))) == '(1.0, 2)'
assert repr(v.validate_python((1.0, (2.0, 3)))) == '(1.0, (2.0, 3))'


def test_case_labels():
v = SchemaValidator(
{'type': 'union', 'choices': [{'type': 'none'}, ({'type': 'int'}, 'my_label'), {'type': 'str'}]}
)
assert v.validate_python(None) is None
assert v.validate_python(1) == 1
with pytest.raises(ValidationError, match=r'3 validation errors for union\[none,my_label,str]') as exc_info:
v.validate_python(1.5)
assert exc_info.value.errors(include_url=False) == [
{'input': 1.5, 'loc': ('none',), 'msg': 'Input should be None', 'type': 'none_required'},
{
'input': 1.5,
'loc': ('my_label',),
'msg': 'Input should be a valid integer, got a number with a fractional part',
'type': 'int_from_float',
},
{'input': 1.5, 'loc': ('str',), 'msg': 'Input should be a valid string', 'type': 'string_type'},
]

0 comments on commit 09f0acf

Please sign in to comment.