Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feature/prior array #1021

Merged
merged 20 commits into from
Jul 26, 2024
Merged

feature/prior array #1021

merged 20 commits into from
Jul 26, 2024

Conversation

rhayes777
Copy link
Owner

Adds the Array class which functions as a PriorModel that creates numpy arrays of floats.

An array is defined by its shape and a prior.

array = af.Array(
    shape=(2, 2),
    prior=af.GaussianPrior(mean=0.0, sigma=1.0),
)

The prior is copied to each index meaning that a 2x2 array has four independent priors.

print(array.prior_count)
> 4

Arrays can be accessed and modified using indexing.

array[0, 0] = 1.0
print(array.prior_count)
> 3

They can be instantiated just as any other model class

instance = array.instance_from_prior_medians()
print(instance)
> [
     [1.0, 0.0],
     [0.0, 0.0],
  ]

Copy link
Collaborator

@Jammy2211 Jammy2211 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't merge yet -- I'll build a test case this week and see whether any issues come up.

Couple of things to think about, albeit best not to worry about them until we've done some testing:

  1. Does this all play nicely with our JAX / Pytrees wrapping?
  2. How does this play with search chaining API?
  3. Is there a clean interface with the default-prior yaml configs?

@rhayes777
Copy link
Owner Author

Need to check JAX and search chaining. We mostly do config on the type -> prior relationship so it was hard to see how to carry this over

@Jammy2211
Copy link
Collaborator

I would make it so that the config default gives the same prior to everything on the numpy array.

We can worry about JAX and search chaining later on then!

@rhayes777
Copy link
Owner Author

I would make it so that the config default gives the same prior to everything on the numpy array.

We can worry about JAX and search chaining later on then!

So at the moment you define one prior and that gets copied to every slot. Would we want a config to determine the default for all Arrays globally?

@rhayes777
Copy link
Owner Author

Should now work with prior passing and JAX

----------
shape : (int, int)
The shape of the array.
prior : Prior

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you describe what happens if the parameter prior is not provided?

"""
super().__init__()
self.shape = shape
self.indices = list(np.ndindex(*shape))

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you are converting self.indices to a list but you are usint the typing Tuple[int, ...] everywhere.

@Jammy2211 Jammy2211 merged commit 355af89 into main Jul 26, 2024
0 of 4 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants