Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feature/prior array #1021

Merged
merged 20 commits into from
Jul 26, 2024
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion autofit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@
from .mapper.prior_model.annotation import AnnotationPriorModel
from .mapper.prior_model.collection import Collection
from .mapper.prior_model.prior_model import Model
from .mapper.prior_model.prior_model import Model
from .mapper.prior_model.array import Array
from .non_linear.search.abstract_search import NonLinearSearch
from .non_linear.analysis.visualize import Visualizer
from .non_linear.analysis.analysis import Analysis
Expand Down
7 changes: 7 additions & 0 deletions autofit/mapper/model_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ def from_dict(
from autofit.mapper.prior_model.collection import Collection
from autofit.mapper.prior_model.prior_model import Model
from autofit.mapper.prior.abstract import Prior
from autofit.mapper.prior.gaussian import GaussianPrior
from autofit.mapper.prior.tuple_prior import TuplePrior
from autofit.mapper.prior.arithmetic.compound import Compound
from autofit.mapper.prior.arithmetic.compound import ModifiedPrior
Expand Down Expand Up @@ -234,7 +235,10 @@ def get_class_path():
f"Could not find type for class path {class_path}. Defaulting to Instance placeholder."
)
instance = ModelInstance()
elif type_ == "array":
from autofit.mapper.prior_model.array import Array

return Array.from_dict(d)
else:
try:
return Prior.from_dict(d, loaded_ids=loaded_ids)
Expand Down Expand Up @@ -276,6 +280,7 @@ def dict(self) -> dict:
from autofit.mapper.prior_model.collection import Collection
from autofit.mapper.prior_model.prior_model import Model
from autofit.mapper.prior.tuple_prior import TuplePrior
from autofit.mapper.prior_model.array import Array

if isinstance(self, Collection):
type_ = "collection"
Expand All @@ -285,6 +290,8 @@ def dict(self) -> dict:
type_ = "model"
elif isinstance(self, TuplePrior):
type_ = "tuple_prior"
elif isinstance(self, Array):
type_ = "array"
else:
raise AssertionError(
f"{self.__class__.__name__} cannot be serialised to dict"
Expand Down
2 changes: 0 additions & 2 deletions autofit/mapper/prior/abstract.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import itertools
import os
import random
from abc import ABC, abstractmethod
from copy import copy
Expand Down Expand Up @@ -115,7 +114,6 @@ def factor(self):
return self.message.factor

def assert_within_limits(self, value):

if jax_wrapper.use_jax:
return

Expand Down
215 changes: 215 additions & 0 deletions autofit/mapper/prior_model/array.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,215 @@
from typing import Tuple, Dict, Optional, Union

from autoconf.dictable import from_dict
from .abstract import AbstractPriorModel
from autofit.mapper.prior.abstract import Prior
import numpy as np

from autofit.jax_wrapper import register_pytree_node_class


@register_pytree_node_class
class Array(AbstractPriorModel):
def __init__(
self,
shape: Tuple[int, ...],
prior: Optional[Prior] = None,
):
"""
An array of priors.

Parameters
----------
shape : (int, int)
The shape of the array.
prior : Prior

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you describe what happens if the parameter prior is not provided?

The prior of every entry in the array.
"""
super().__init__()
self.shape = shape
self.indices = list(np.ndindex(*shape))

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you are converting self.indices to a list but you are usint the typing Tuple[int, ...] everywhere.


if prior is not None:
for index in self.indices:
self[index] = prior.new()

@staticmethod
def _make_key(index: Tuple[int, ...]) -> str:
"""
Make a key for the prior.

This is so an index (e.g. (1, 2)) can be used to access a
prior (e.g. prior_1_2).

Parameters
----------
index
The index of an element in an array.

Returns
-------
The attribute name for the prior.
"""
if isinstance(index, int):
suffix = str(index)
else:
suffix = "_".join(map(str, index))
return f"prior_{suffix}"

def _instance_for_arguments(
self,
arguments: Dict[Prior, float],
ignore_assertions: bool = False,
) -> np.ndarray:
"""
Create an array where the prior at each index is replaced with the
a concrete value.

Parameters
----------
arguments
The arguments to replace the priors with.
ignore_assertions
Whether to ignore assertions in the priors.

Returns
-------
The array with the priors replaced.
"""
array = np.zeros(self.shape)
for index in self.indices:
value = self[index]
try:
value = value.instance_for_arguments(
arguments,
ignore_assertions,
)
except AttributeError:
pass

array[index] = value
return array

def __setitem__(
self,
index: Union[int, Tuple[int, ...]],
value: Union[float, Prior],
):
"""
Set the value at an index.

Parameters
----------
index
The index of the prior.
value
The new value.
"""
setattr(
self,
self._make_key(index),
value,
)

def __getitem__(
self,
index: Union[int, Tuple[int, ...]],
) -> Union[float, Prior]:
"""
Get the value at an index.

Parameters
----------
index
The index of the value.

Returns
-------
The value at the index.
"""
return getattr(
self,
self._make_key(index),
)

@classmethod
def from_dict(
cls,
d,
reference: Optional[Dict[str, str]] = None,
loaded_ids: Optional[dict] = None,
) -> "Array":
"""
Create an array from a dictionary.

Parameters
----------
d
The dictionary.
reference
A dictionary of references.
loaded_ids
A dictionary of loaded ids.

Returns
-------
The array.
"""
arguments = d["arguments"]
shape = from_dict(arguments["shape"])
array = cls(shape)
for key, value in arguments.items():
if key.startswith("prior"):
setattr(array, key, from_dict(value))

return array

def tree_flatten(self):
"""
Flatten this array model as a PyTree.
"""
members = [self[index] for index in self.indices]
return members, (self.shape,)

@classmethod
def tree_unflatten(cls, aux_data, children):
"""
Unflatten a PyTree into an array model.
"""
(shape,) = aux_data
instance = cls(shape)
for index, child in zip(instance.indices, children):
instance[index] = child

return instance

@property
def prior_class_dict(self):
return {
**{
prior: cls
for prior_model in self.direct_prior_model_tuples
for prior, cls in prior_model[1].prior_class_dict.items()
},
**{prior: np.ndarray for _, prior in self.direct_prior_tuples},
}

def gaussian_prior_model_for_arguments(self, arguments: Dict[Prior, Prior]):
"""
Returns a new instance of model mapper with a set of Gaussian priors based on
tuples provided by a previous nonlinear search.

Parameters
----------
arguments
Tuples providing the mean and sigma of gaussians

Returns
-------
A new model mapper populated with Gaussian priors
"""
new_array = Array(self.shape)
for index in self.indices:
new_array[index] = self[index].gaussian_prior_model_for_arguments(arguments)

return new_array
Empty file.
Loading
Loading