Skip to content

Commit

Permalink
Merge pull request #691 from rhayes777/feature/filter
Browse files Browse the repository at this point in the history
feature/filter
  • Loading branch information
Jammy2211 authored Mar 6, 2023
2 parents 82050f0 + 8bbf30d commit dfa33e3
Show file tree
Hide file tree
Showing 5 changed files with 143 additions and 127 deletions.
22 changes: 20 additions & 2 deletions autofit/mapper/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,8 +422,26 @@ def values(self):
def __len__(self):
return len(self.values())

def as_model(self, model_classes=tuple()):
def as_model(
self,
model_classes: Union[type, Iterable[type]] = tuple(),
excluded_classes: Union[type, Iterable[type]] = tuple(),
):
"""
Convert this instance to a model
Parameters
----------
model_classes
The classes to convert to models
excluded_classes
The classes to exclude from conversion
Returns
-------
A model
"""

from autofit.mapper.prior_model.abstract import AbstractPriorModel

return AbstractPriorModel.from_instance(self, model_classes)
return AbstractPriorModel.from_instance(self, model_classes, excluded_classes,)
6 changes: 6 additions & 0 deletions autofit/mapper/model_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,12 @@ def from_dict(d):
instance = TuplePrior()
elif type_ == "dict":
return {key: ModelObject.from_dict(value) for key, value in d.items()}
elif type_ == "instance":
d.pop("type")
cls = get_class(d.pop("class_path"))
return cls(
**{key: ModelObject.from_dict(value) for key, value in d.items()}
)
else:
try:
return Prior.from_dict(d)
Expand Down
37 changes: 27 additions & 10 deletions autofit/mapper/prior_model/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,9 +444,7 @@ def assert_no_assertions(obj):
try:
item = copy.copy(source)
if isinstance(item, dict):
from autofit.mapper.prior_model.collection import (
Collection,
)
from autofit.mapper.prior_model.collection import Collection

item = Collection(item)
for attribute in path:
Expand Down Expand Up @@ -1008,13 +1006,20 @@ def random_instance(self, ignore_prior_limits=False):

@staticmethod
@DynamicRecursionCache()
def from_instance(instance, model_classes=tuple()):
def from_instance(
instance,
model_classes: Union[type, Iterable[type]] = tuple(),
exclude_classes: Union[type, Iterable[type]] = tuple(),
):
"""
Recursively create an prior object model from an object model.
Recursively create a prior object model from an object model.
Parameters
----------
model_classes
A tuple of classes that should be converted to a prior model
exclude_classes
A tuple of classes that should not be converted to a prior model
instance
A dictionary, list, class instance or model instance
Returns
Expand All @@ -1024,12 +1029,18 @@ def from_instance(instance, model_classes=tuple()):
"""
from autofit.mapper.prior_model import collection

if isinstance(instance, exclude_classes):
return instance
if isinstance(instance, (Prior, AbstractPriorModel)):
return instance
elif isinstance(instance, list):
result = collection.Collection(
[
AbstractPriorModel.from_instance(item, model_classes=model_classes)
AbstractPriorModel.from_instance(
item,
model_classes=model_classes,
exclude_classes=exclude_classes,
)
for item in instance
]
)
Expand All @@ -1042,14 +1053,18 @@ def from_instance(instance, model_classes=tuple()):
result,
key,
AbstractPriorModel.from_instance(
value, model_classes=model_classes
value,
model_classes=model_classes,
exclude_classes=exclude_classes,
),
)
elif isinstance(instance, dict):
result = collection.Collection(
{
key: AbstractPriorModel.from_instance(
value, model_classes=model_classes
value,
model_classes=model_classes,
exclude_classes=exclude_classes,
)
for key, value in instance.items()
}
Expand All @@ -1064,15 +1079,17 @@ def from_instance(instance, model_classes=tuple()):
instance.__class__,
**{
key: AbstractPriorModel.from_instance(
value, model_classes=model_classes
value,
model_classes=model_classes,
exclude_classes=exclude_classes,
)
for key, value in instance.__dict__.items()
if key != "cls"
},
)
except AttributeError:
return instance
if any([isinstance(instance, cls) for cls in model_classes]):
if isinstance(instance, model_classes):
return result.as_model()
return result

Expand Down
140 changes: 33 additions & 107 deletions test_autofit/mapper/model/test_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,54 +5,37 @@

import autofit as af

@pytest.fixture(
name="model_dict"
)

@pytest.fixture(name="model_dict")
def make_model_dict():
return {
"type": "model",
"class_path": "autofit.example.model.Gaussian",
"centre": {'lower_limit': 0.0, 'type': 'Uniform', 'upper_limit': 2.0},
"normalization": {'lower_limit': 0.0, 'type': 'Uniform', 'upper_limit': 1.0},
"sigma": {'lower_limit': 0.0, 'type': 'Uniform', 'upper_limit': 1.0},
"centre": {"lower_limit": 0.0, "type": "Uniform", "upper_limit": 2.0},
"normalization": {"lower_limit": 0.0, "type": "Uniform", "upper_limit": 1.0},
"sigma": {"lower_limit": 0.0, "type": "Uniform", "upper_limit": 1.0},
}


@pytest.fixture(
name="instance_dict"
)
@pytest.fixture(name="instance_dict")
def make_instance_dict():
return {
"type": "instance",
"class_path": "autofit.example.model.Gaussian",
"centre": 0.0,
"normalization": 0.1,
"sigma": 0.01
"sigma": 0.01,
}


@pytest.fixture(
name="collection_dict"
)
def make_collection_dict(
model_dict
):
return {
"gaussian": model_dict,
"type": "collection"
}
@pytest.fixture(name="collection_dict")
def make_collection_dict(model_dict):
return {"gaussian": model_dict, "type": "collection"}


@pytest.fixture(
name="model"
)
@pytest.fixture(name="model")
def make_model():
return af.Model(
af.Gaussian,
centre=af.UniformPrior(
upper_limit=2.0
)
)
return af.Model(af.Gaussian, centre=af.UniformPrior(upper_limit=2.0))


class TestTuple:
Expand All @@ -61,118 +44,61 @@ def test_tuple_prior(self):
tuple_prior.tup_0 = 0
tuple_prior.tup_1 = 1

result = af.Model.from_dict(
tuple_prior.dict()
)
assert isinstance(
result,
af.TuplePrior
)
result = af.Model.from_dict(tuple_prior.dict())
assert isinstance(result, af.TuplePrior)

def test_model_with_tuple(self):
tuple_model = af.Model(af.m.MockWithTuple)
tuple_model.instance_from_prior_medians()
model_dict = tuple_model.dict()

model = af.Model.from_dict(
model_dict
)
model = af.Model.from_dict(model_dict)
instance = model.instance_from_prior_medians()
assert instance.tup == (0.5, 0.5)


class TestFromDict:
def test_model_from_dict(
self,
model_dict
):
model = af.Model.from_dict(
model_dict
)
def test_model_from_dict(self, model_dict):
model = af.Model.from_dict(model_dict)
assert model.cls == af.Gaussian
assert model.prior_count == 3
assert model.centre.upper_limit == 2.0

def test_instance_from_dict(
self,
instance_dict
):
instance = af.Model.from_dict(
instance_dict
)
assert isinstance(
instance,
af.Gaussian
)
def test_instance_from_dict(self, instance_dict):
instance = af.Model.from_dict(instance_dict)
assert isinstance(instance, af.Gaussian)
assert instance.centre == 0.0
assert instance.normalization == 0.1
assert instance.sigma == 0.01

def test_collection_from_dict(
self,
collection_dict
):
collection = af.Model.from_dict(
collection_dict
)
assert isinstance(
collection,
af.Collection
)
def test_collection_from_dict(self, collection_dict):
collection = af.Model.from_dict(collection_dict)
assert isinstance(collection, af.Collection)
assert len(collection) == 1


class TestToDict:
def test_model_priors(
self,
model,
model_dict
):
def test_model_priors(self, model, model_dict):
assert model.dict() == model_dict

def test_model_floats(
self,
instance_dict
):
model = af.Model(
af.Gaussian,
centre=0.0,
normalization=0.1,
sigma=0.01
)
def test_model_floats(self, instance_dict):
model = af.Model(af.Gaussian, centre=0.0, normalization=0.1, sigma=0.01)

assert model.dict() == instance_dict

def test_collection(
self,
model,
collection_dict
):
collection = af.Collection(
gaussian=model
)
def test_collection(self, model, collection_dict):
collection = af.Collection(gaussian=model)
assert collection.dict() == collection_dict

def test_collection_instance(
self,
instance_dict
):
collection = af.Collection(
gaussian=af.Gaussian()
)
assert collection.dict() == {
"gaussian": instance_dict,
"type": "collection"
}
def test_collection_instance(self, instance_dict):
collection = af.Collection(gaussian=af.Gaussian())
assert collection.dict() == {"gaussian": instance_dict, "type": "collection"}


class TestFromJson:

def test__from_json(self, model_dict):

model = af.Model.from_dict(
model_dict
)
model = af.Model.from_dict(model_dict)

model_file = Path(__file__).parent / "model.json"

Expand All @@ -190,4 +116,4 @@ def test__from_json(self, model_dict):
assert model.prior_count == 3
assert model.centre.upper_limit == 2.0

os.remove(model_file)
os.remove(model_file)
Loading

0 comments on commit dfa33e3

Please sign in to comment.