Skip to content

Commit

Permalink
Merge pull request #1021 from rhayes777/feature/prior_array
Browse files Browse the repository at this point in the history
feature/prior array
  • Loading branch information
Jammy2211 authored Jul 26, 2024
2 parents a468317 + 526f4c8 commit 355af89
Show file tree
Hide file tree
Showing 15 changed files with 764 additions and 11 deletions.
2 changes: 1 addition & 1 deletion autofit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@
from .mapper.prior_model.annotation import AnnotationPriorModel
from .mapper.prior_model.collection import Collection
from .mapper.prior_model.prior_model import Model
from .mapper.prior_model.prior_model import Model
from .mapper.prior_model.array import Array
from .non_linear.search.abstract_search import NonLinearSearch
from .non_linear.analysis.visualize import Visualizer
from .non_linear.analysis.analysis import Analysis
Expand Down
7 changes: 7 additions & 0 deletions autofit/mapper/model_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ def from_dict(
from autofit.mapper.prior_model.collection import Collection
from autofit.mapper.prior_model.prior_model import Model
from autofit.mapper.prior.abstract import Prior
from autofit.mapper.prior.gaussian import GaussianPrior
from autofit.mapper.prior.tuple_prior import TuplePrior
from autofit.mapper.prior.arithmetic.compound import Compound
from autofit.mapper.prior.arithmetic.compound import ModifiedPrior
Expand Down Expand Up @@ -234,7 +235,10 @@ def get_class_path():
f"Could not find type for class path {class_path}. Defaulting to Instance placeholder."
)
instance = ModelInstance()
elif type_ == "array":
from autofit.mapper.prior_model.array import Array

return Array.from_dict(d)
else:
try:
return Prior.from_dict(d, loaded_ids=loaded_ids)
Expand Down Expand Up @@ -276,6 +280,7 @@ def dict(self) -> dict:
from autofit.mapper.prior_model.collection import Collection
from autofit.mapper.prior_model.prior_model import Model
from autofit.mapper.prior.tuple_prior import TuplePrior
from autofit.mapper.prior_model.array import Array

if isinstance(self, Collection):
type_ = "collection"
Expand All @@ -285,6 +290,8 @@ def dict(self) -> dict:
type_ = "model"
elif isinstance(self, TuplePrior):
type_ = "tuple_prior"
elif isinstance(self, Array):
type_ = "array"
else:
raise AssertionError(
f"{self.__class__.__name__} cannot be serialised to dict"
Expand Down
2 changes: 0 additions & 2 deletions autofit/mapper/prior/abstract.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import itertools
import os
import random
from abc import ABC, abstractmethod
from copy import copy
Expand Down Expand Up @@ -115,7 +114,6 @@ def factor(self):
return self.message.factor

def assert_within_limits(self, value):

if jax_wrapper.use_jax:
return

Expand Down
215 changes: 215 additions & 0 deletions autofit/mapper/prior_model/array.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,215 @@
from typing import Tuple, Dict, Optional, Union

from autoconf.dictable import from_dict
from .abstract import AbstractPriorModel
from autofit.mapper.prior.abstract import Prior
import numpy as np

from autofit.jax_wrapper import register_pytree_node_class


@register_pytree_node_class
class Array(AbstractPriorModel):
def __init__(
self,
shape: Tuple[int, ...],
prior: Optional[Prior] = None,
):
"""
An array of priors.
Parameters
----------
shape : (int, int)
The shape of the array.
prior : Prior
The prior of every entry in the array.
"""
super().__init__()
self.shape = shape
self.indices = list(np.ndindex(*shape))

if prior is not None:
for index in self.indices:
self[index] = prior.new()

@staticmethod
def _make_key(index: Tuple[int, ...]) -> str:
"""
Make a key for the prior.
This is so an index (e.g. (1, 2)) can be used to access a
prior (e.g. prior_1_2).
Parameters
----------
index
The index of an element in an array.
Returns
-------
The attribute name for the prior.
"""
if isinstance(index, int):
suffix = str(index)
else:
suffix = "_".join(map(str, index))
return f"prior_{suffix}"

def _instance_for_arguments(
self,
arguments: Dict[Prior, float],
ignore_assertions: bool = False,
) -> np.ndarray:
"""
Create an array where the prior at each index is replaced with the
a concrete value.
Parameters
----------
arguments
The arguments to replace the priors with.
ignore_assertions
Whether to ignore assertions in the priors.
Returns
-------
The array with the priors replaced.
"""
array = np.zeros(self.shape)
for index in self.indices:
value = self[index]
try:
value = value.instance_for_arguments(
arguments,
ignore_assertions,
)
except AttributeError:
pass

array[index] = value
return array

def __setitem__(
self,
index: Union[int, Tuple[int, ...]],
value: Union[float, Prior],
):
"""
Set the value at an index.
Parameters
----------
index
The index of the prior.
value
The new value.
"""
setattr(
self,
self._make_key(index),
value,
)

def __getitem__(
self,
index: Union[int, Tuple[int, ...]],
) -> Union[float, Prior]:
"""
Get the value at an index.
Parameters
----------
index
The index of the value.
Returns
-------
The value at the index.
"""
return getattr(
self,
self._make_key(index),
)

@classmethod
def from_dict(
cls,
d,
reference: Optional[Dict[str, str]] = None,
loaded_ids: Optional[dict] = None,
) -> "Array":
"""
Create an array from a dictionary.
Parameters
----------
d
The dictionary.
reference
A dictionary of references.
loaded_ids
A dictionary of loaded ids.
Returns
-------
The array.
"""
arguments = d["arguments"]
shape = from_dict(arguments["shape"])
array = cls(shape)
for key, value in arguments.items():
if key.startswith("prior"):
setattr(array, key, from_dict(value))

return array

def tree_flatten(self):
"""
Flatten this array model as a PyTree.
"""
members = [self[index] for index in self.indices]
return members, (self.shape,)

@classmethod
def tree_unflatten(cls, aux_data, children):
"""
Unflatten a PyTree into an array model.
"""
(shape,) = aux_data
instance = cls(shape)
for index, child in zip(instance.indices, children):
instance[index] = child

return instance

@property
def prior_class_dict(self):
return {
**{
prior: cls
for prior_model in self.direct_prior_model_tuples
for prior, cls in prior_model[1].prior_class_dict.items()
},
**{prior: np.ndarray for _, prior in self.direct_prior_tuples},
}

def gaussian_prior_model_for_arguments(self, arguments: Dict[Prior, Prior]):
"""
Returns a new instance of model mapper with a set of Gaussian priors based on
tuples provided by a previous nonlinear search.
Parameters
----------
arguments
Tuples providing the mean and sigma of gaussians
Returns
-------
A new model mapper populated with Gaussian priors
"""
new_array = Array(self.shape)
for index in self.indices:
new_array[index] = self[index].gaussian_prior_model_for_arguments(arguments)

return new_array
Loading

0 comments on commit 355af89

Please sign in to comment.