diff --git a/marshmallow_jsonschema/base.py b/marshmallow_jsonschema/base.py index 6ee3fa3..700cae4 100644 --- a/marshmallow_jsonschema/base.py +++ b/marshmallow_jsonschema/base.py @@ -292,7 +292,7 @@ def _from_nested_schema(self, obj, field): only = field.only exclude = field.exclude nested_cls = nested - nested_instance = nested(only=only, exclude=exclude) + nested_instance = nested(only=only, exclude=exclude, context=obj.context) else: nested_cls = nested.__class__ name = nested_cls.__name__ diff --git a/tests/test_dump.py b/tests/test_dump.py index 50b9fa3..b1d53d5 100644 --- a/tests/test_dump.py +++ b/tests/test_dump.py @@ -133,6 +133,32 @@ class TestSchema(Schema): assert nested_def["properties"]["foo"]["type"] == "integer" +def test_nested_context(): + class TestNestedSchema(Schema): + def __init__(self, *args, **kwargs): + if kwargs.get("context", {}).get("hide", False): + kwargs["exclude"] = ["foo"] + super().__init__(*args, **kwargs) + + foo = fields.Integer(required=True) + bar = fields.Integer(required=True) + + class TestSchema(Schema): + bar = fields.Nested(TestNestedSchema) + + schema = TestSchema() + dumped_show = validate_and_dump(schema) + + schema = TestSchema(context={"hide": True}) + dumped_hide = validate_and_dump(schema) + + nested_show = dumped_show["definitions"]["TestNestedSchema"]["properties"] + nested_hide = dumped_hide["definitions"]["TestNestedSchema"]["properties"] + + assert "bar" in nested_show and "foo" in nested_show + assert "bar" in nested_hide and "foo" not in nested_hide + + def test_list(): class ListSchema(Schema): foo = fields.List(fields.String(), required=True)