Skip to content

Commit

Permalink
Fix tagged union serialization warning when using aliases (#1442)
Browse files Browse the repository at this point in the history
  • Loading branch information
sydney-runkle authored Sep 6, 2024
1 parent c462f77 commit f2a0bb8
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 30 deletions.
49 changes: 24 additions & 25 deletions src/lookup_key.rs
Original file line number Diff line number Diff line change
Expand Up @@ -191,34 +191,10 @@ impl LookupKey {
}
}

pub fn py_get_attr<'py, 's>(
pub fn simple_py_get_attr<'py, 's>(
&'s self,
obj: &Bound<'py, PyAny>,
kwargs: Option<&Bound<'py, PyDict>>,
) -> ValResult<Option<(&'s LookupPath, Bound<'py, PyAny>)>> {
match self._py_get_attr(obj, kwargs) {
Ok(v) => Ok(v),
Err(err) => {
let error = py_err_string(obj.py(), err);
Err(ValError::new(
ErrorType::GetAttributeError { error, context: None },
obj,
))
}
}
}

pub fn _py_get_attr<'py, 's>(
&'s self,
obj: &Bound<'py, PyAny>,
kwargs: Option<&Bound<'py, PyDict>>,
) -> PyResult<Option<(&'s LookupPath, Bound<'py, PyAny>)>> {
if let Some(dict) = kwargs {
if let Ok(Some(item)) = self.py_get_dict_item(dict) {
return Ok(Some(item));
}
}

match self {
Self::Simple { py_key, path, .. } => match py_get_attrs(obj, py_key)? {
Some(value) => Ok(Some((path, value))),
Expand Down Expand Up @@ -260,6 +236,29 @@ impl LookupKey {
}
}

pub fn py_get_attr<'py, 's>(
&'s self,
obj: &Bound<'py, PyAny>,
kwargs: Option<&Bound<'py, PyDict>>,
) -> ValResult<Option<(&'s LookupPath, Bound<'py, PyAny>)>> {
if let Some(dict) = kwargs {
if let Ok(Some(item)) = self.py_get_dict_item(dict) {
return Ok(Some(item));
}
}

match self.simple_py_get_attr(obj) {
Ok(v) => Ok(v),
Err(err) => {
let error = py_err_string(obj.py(), err);
Err(ValError::new(
ErrorType::GetAttributeError { error, context: None },
obj,
))
}
}
}

pub fn json_get<'a, 'data, 's>(
&'s self,
dict: &'a JsonObject<'data>,
Expand Down
9 changes: 4 additions & 5 deletions src/serializers/type_serializers/union.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ use std::borrow::Cow;
use crate::build_tools::py_schema_err;
use crate::common::union::{Discriminator, SMALL_UNION_THRESHOLD};
use crate::definitions::DefinitionsBuilder;
use crate::lookup_key::LookupKey;
use crate::serializers::type_serializers::py_err_se_err;
use crate::tools::{truncate_safe_repr, SchemaDict};
use crate::PydanticSerializationUnexpectedValue;
Expand Down Expand Up @@ -438,10 +437,10 @@ impl TaggedUnionSerializer {
fn get_discriminator_value(&self, value: &Bound<'_, PyAny>, extra: &Extra) -> Option<Py<PyAny>> {
let py = value.py();
let discriminator_value = match &self.discriminator {
Discriminator::LookupKey(lookup_key) => match lookup_key {
LookupKey::Simple { py_key, .. } => value.getattr(py_key).ok().map(|obj| obj.to_object(py)),
_ => None,
},
Discriminator::LookupKey(lookup_key) => lookup_key
.simple_py_get_attr(value)
.ok()
.and_then(|opt| opt.map(|(_, bound)| bound.to_object(py))),
Discriminator::Function(func) => func.call1(py, (value,)).ok(),
};
if discriminator_value.is_none() {
Expand Down
59 changes: 59 additions & 0 deletions tests/serializers/test_union.py
Original file line number Diff line number Diff line change
Expand Up @@ -711,3 +711,62 @@ def test_custom_serializer() -> None:
print(s)
assert s.to_python([{'id': 1}, {'id': 2}]) == [1, 2]
assert s.to_python({'id': 1}) == 1


def test_tagged_union_with_aliases() -> None:
@dataclasses.dataclass
class ModelA:
field: int
tag: Literal['a'] = 'a'

@dataclasses.dataclass
class ModelB:
field: int
tag: Literal['b'] = 'b'

s = SchemaSerializer(
core_schema.tagged_union_schema(
choices={
'a': core_schema.dataclass_schema(
ModelA,
core_schema.dataclass_args_schema(
'ModelA',
[
core_schema.dataclass_field(name='field', schema=core_schema.int_schema()),
core_schema.dataclass_field(
name='tag',
schema=core_schema.literal_schema(['a']),
validation_alias='TAG',
serialization_alias='TAG',
),
],
),
['field', 'tag'],
),
'b': core_schema.dataclass_schema(
ModelB,
core_schema.dataclass_args_schema(
'ModelB',
[
core_schema.dataclass_field(name='field', schema=core_schema.int_schema()),
core_schema.dataclass_field(
name='tag',
schema=core_schema.literal_schema(['b']),
validation_alias='TAG',
serialization_alias='TAG',
),
],
),
['field', 'tag'],
),
},
discriminator=[['tag'], ['TAG']],
)
)

assert 'TaggedUnionSerializer' in repr(s)

model_a = ModelA(field=1)
model_b = ModelB(field=1)
assert s.to_python(model_a) == {'field': 1, 'TAG': 'a'}
assert s.to_python(model_b) == {'field': 1, 'TAG': 'b'}

0 comments on commit f2a0bb8

Please sign in to comment.