diff --git a/test_autofit/mapper/model/test_model_instance.py b/test_autofit/mapper/model/test_model_instance.py index efd57c534..05e29ee37 100644 --- a/test_autofit/mapper/model/test_model_instance.py +++ b/test_autofit/mapper/model/test_model_instance.py @@ -170,20 +170,36 @@ class Child(af.Gaussian): pass -@pytest.fixture(name="instance") -def make_instance(): - return af.ModelInstance({"child": Child(), "gaussian": af.Gaussian()}) +class Child2(af.Gaussian): + pass + +@pytest.fixture(name="exclude_instance") +def make_excluded_instance(): + return af.ModelInstance( + {"child": Child(), "gaussian": af.Gaussian(), "child2": Child2(),} + ) -def test_single_argument(instance): - model = instance.as_model(af.Gaussian) + +def test_single_argument(exclude_instance): + model = exclude_instance.as_model(af.Gaussian) assert isinstance(model.gaussian, af.Model) assert isinstance(model.child, af.Model) + assert isinstance(model.child2, af.Model) + + +def test_filter_child(exclude_instance): + model = exclude_instance.as_model(af.Gaussian, excluded_classes=Child) + + assert isinstance(model.gaussian, af.Model) + assert not isinstance(model.child, af.Model) + assert isinstance(model.child2, af.Model) -def test_filter_child(instance): - model = instance.as_model(af.Gaussian, excluded_classes=Child) +def test_filter_multiple(exclude_instance): + model = exclude_instance.as_model(af.Gaussian, excluded_classes=(Child, Child2),) assert isinstance(model.gaussian, af.Model) assert not isinstance(model.child, af.Model) + assert not isinstance(model.child2, af.Model)