-
Notifications
You must be signed in to change notification settings - Fork 11
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feature/prior array #1021
Merged
Merged
feature/prior array #1021
Changes from 17 commits
Commits
Show all changes
20 commits
Select commit
Hold shift + click to select a range
ac570e8
fixed test
rhayes777 10ad4f9
creating an Array model
rhayes777 2588854
instance from prior medians for special case of 2d array
rhayes777 f9e4820
generalised instance for method
rhayes777 66d055e
use set and get item to simplify implementation
rhayes777 08f8a73
test modification
rhayes777 a0f0b87
modifying values + fix
rhayes777 dcc1195
more testing
rhayes777 f5116ab
array from dict
rhayes777 fc30f86
testing complex dict
rhayes777 7a50b5a
properly handling from dict
rhayes777 718fed9
docs and typws
rhayes777 00ba187
test 1d array
rhayes777 df708b7
modifying values on 1d arrays
rhayes777 5353631
tree flatten and unflatten for pytrees (jax)
rhayes777 fd44f53
array prior passing
rhayes777 ed5c582
docs
rhayes777 e88af1f
merge main
Jammy2211 b3ad3eb
model cookbook doc
Jammy2211 526f4c8
docs
Jammy2211 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,215 @@ | ||
from typing import Tuple, Dict, Optional, Union | ||
|
||
from autoconf.dictable import from_dict | ||
from .abstract import AbstractPriorModel | ||
from autofit.mapper.prior.abstract import Prior | ||
import numpy as np | ||
|
||
from autofit.jax_wrapper import register_pytree_node_class | ||
|
||
|
||
@register_pytree_node_class | ||
class Array(AbstractPriorModel): | ||
def __init__( | ||
self, | ||
shape: Tuple[int, ...], | ||
prior: Optional[Prior] = None, | ||
): | ||
""" | ||
An array of priors. | ||
|
||
Parameters | ||
---------- | ||
shape : (int, int) | ||
The shape of the array. | ||
prior : Prior | ||
The prior of every entry in the array. | ||
""" | ||
super().__init__() | ||
self.shape = shape | ||
self.indices = list(np.ndindex(*shape)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. you are converting |
||
|
||
if prior is not None: | ||
for index in self.indices: | ||
self[index] = prior.new() | ||
|
||
@staticmethod | ||
def _make_key(index: Tuple[int, ...]) -> str: | ||
""" | ||
Make a key for the prior. | ||
|
||
This is so an index (e.g. (1, 2)) can be used to access a | ||
prior (e.g. prior_1_2). | ||
|
||
Parameters | ||
---------- | ||
index | ||
The index of an element in an array. | ||
|
||
Returns | ||
------- | ||
The attribute name for the prior. | ||
""" | ||
if isinstance(index, int): | ||
suffix = str(index) | ||
else: | ||
suffix = "_".join(map(str, index)) | ||
return f"prior_{suffix}" | ||
|
||
def _instance_for_arguments( | ||
self, | ||
arguments: Dict[Prior, float], | ||
ignore_assertions: bool = False, | ||
) -> np.ndarray: | ||
""" | ||
Create an array where the prior at each index is replaced with the | ||
a concrete value. | ||
|
||
Parameters | ||
---------- | ||
arguments | ||
The arguments to replace the priors with. | ||
ignore_assertions | ||
Whether to ignore assertions in the priors. | ||
|
||
Returns | ||
------- | ||
The array with the priors replaced. | ||
""" | ||
array = np.zeros(self.shape) | ||
for index in self.indices: | ||
value = self[index] | ||
try: | ||
value = value.instance_for_arguments( | ||
arguments, | ||
ignore_assertions, | ||
) | ||
except AttributeError: | ||
pass | ||
|
||
array[index] = value | ||
return array | ||
|
||
def __setitem__( | ||
self, | ||
index: Union[int, Tuple[int, ...]], | ||
value: Union[float, Prior], | ||
): | ||
""" | ||
Set the value at an index. | ||
|
||
Parameters | ||
---------- | ||
index | ||
The index of the prior. | ||
value | ||
The new value. | ||
""" | ||
setattr( | ||
self, | ||
self._make_key(index), | ||
value, | ||
) | ||
|
||
def __getitem__( | ||
self, | ||
index: Union[int, Tuple[int, ...]], | ||
) -> Union[float, Prior]: | ||
""" | ||
Get the value at an index. | ||
|
||
Parameters | ||
---------- | ||
index | ||
The index of the value. | ||
|
||
Returns | ||
------- | ||
The value at the index. | ||
""" | ||
return getattr( | ||
self, | ||
self._make_key(index), | ||
) | ||
|
||
@classmethod | ||
def from_dict( | ||
cls, | ||
d, | ||
reference: Optional[Dict[str, str]] = None, | ||
loaded_ids: Optional[dict] = None, | ||
) -> "Array": | ||
""" | ||
Create an array from a dictionary. | ||
|
||
Parameters | ||
---------- | ||
d | ||
The dictionary. | ||
reference | ||
A dictionary of references. | ||
loaded_ids | ||
A dictionary of loaded ids. | ||
|
||
Returns | ||
------- | ||
The array. | ||
""" | ||
arguments = d["arguments"] | ||
shape = from_dict(arguments["shape"]) | ||
array = cls(shape) | ||
for key, value in arguments.items(): | ||
if key.startswith("prior"): | ||
setattr(array, key, from_dict(value)) | ||
|
||
return array | ||
|
||
def tree_flatten(self): | ||
""" | ||
Flatten this array model as a PyTree. | ||
""" | ||
members = [self[index] for index in self.indices] | ||
return members, (self.shape,) | ||
|
||
@classmethod | ||
def tree_unflatten(cls, aux_data, children): | ||
""" | ||
Unflatten a PyTree into an array model. | ||
""" | ||
(shape,) = aux_data | ||
instance = cls(shape) | ||
for index, child in zip(instance.indices, children): | ||
instance[index] = child | ||
|
||
return instance | ||
|
||
@property | ||
def prior_class_dict(self): | ||
return { | ||
**{ | ||
prior: cls | ||
for prior_model in self.direct_prior_model_tuples | ||
for prior, cls in prior_model[1].prior_class_dict.items() | ||
}, | ||
**{prior: np.ndarray for _, prior in self.direct_prior_tuples}, | ||
} | ||
|
||
def gaussian_prior_model_for_arguments(self, arguments: Dict[Prior, Prior]): | ||
""" | ||
Returns a new instance of model mapper with a set of Gaussian priors based on | ||
tuples provided by a previous nonlinear search. | ||
|
||
Parameters | ||
---------- | ||
arguments | ||
Tuples providing the mean and sigma of gaussians | ||
|
||
Returns | ||
------- | ||
A new model mapper populated with Gaussian priors | ||
""" | ||
new_array = Array(self.shape) | ||
for index in self.indices: | ||
new_array[index] = self[index].gaussian_prior_model_for_arguments(arguments) | ||
|
||
return new_array |
Empty file.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can you describe what happens if the parameter prior is not provided?