Make 'RestrictedTransformForConditional' conform to 'torch.Transform' interface #1057

Open
Baschdl opened this issue Mar 20, 2024 · 2 comments
Labels
architecture Internal changes without API consequences

Comments

Baschdl commented Mar 20, 2024

`RestrictedTransformForConditional` is currently typed as a `torch.Transform` while itself holding a `torch.Transform` as an attribute:

```python
from typing import List

from torch import Tensor
from torch.distributions import transforms as torch_tf


class RestrictedTransformForConditional(torch_tf.Transform):
    """
    Class to restrict the transform to fewer dimensions for conditional sampling.

    The resulting transform transforms only the free dimensions of the
    conditional. Notably, the `log_abs_det` is computed given all dimensions.
    However, the `log_abs_det` stemming from the fixed dimensions is a constant
    and drops out during MCMC.

    All methods work in a similar way: `full_theta` will first have all entries
    of the `condition` and then override the entries that should be sampled
    with `theta`. In case `theta` is a batch of parameters (e.g. multi-chain
    MCMC), we have to repeat `theta_condition` to match the batch size.

    This is needed for the MCMC initialization functions when conditioning and
    when transforming the samples back into the original theta space after
    sampling.
    """

    def __init__(
        self,
        transform: torch_tf.Transform,
        condition: Tensor,
        dims_to_sample: List[int],
    ) -> None:
        ...
```
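To make the docstring concrete, here is a minimal sketch of the `full_theta` construction it describes (the free-standing helper `build_full_theta` is hypothetical; in sbi this logic lives inside the class's methods):

```python
from typing import List

import torch
from torch import Tensor


def build_full_theta(
    theta: Tensor, condition: Tensor, dims_to_sample: List[int]
) -> Tensor:
    """Embed a batch of free parameters into the full parameter vector.

    `condition` has shape (1, dim_full); `theta` has shape (batch, dim_free).
    """
    # Repeat the condition to match the batch size (e.g. multi-chain MCMC).
    full_theta = condition.repeat(theta.shape[0], 1)
    # Override the entries that should be sampled with `theta`.
    full_theta[:, dims_to_sample] = theta
    return full_theta


# Example: 3-D condition, sampling only dims 0 and 2 with a batch of 2.
full = build_full_theta(
    torch.tensor([[1.0, 2.0], [3.0, 4.0]]),
    torch.tensor([[0.0, 5.0, 0.0]]),
    [0, 2],
)
# full == tensor([[1., 5., 2.], [3., 5., 4.]])
```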

Conditioning on `theta` makes the interface incompatible with `torch.Transform`: on a standard `Transform`, `inv` is a property that takes no arguments and returns the inverse `Transform`, whereas our `inv(theta)` is a method called with an argument. This particular clash could be fixed by renaming the method to `restricted_inv(theta)`.
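For illustration, the mismatch can be seen with a stock transform (a minimal sketch; `restricted_inv` is the renaming proposed above, not an existing method):

```python
import torch
from torch.distributions import transforms as torch_tf

t = torch_tf.ExpTransform()
y = t(torch.tensor([0.0, 1.0]))  # forward pass via __call__

# For a standard torch Transform, `inv` is a property that returns the
# inverse Transform; it is accessed without arguments and the result is
# then applied to a value:
x = t.inv(y)  # t.inv is itself a Transform

# `RestrictedTransformForConditional` instead defines `inv` as a regular
# method, `def inv(self, theta)`, shadowing the property. Its `inv` is then
# a bound method, not a Transform, so code that expects a Transform (e.g.
# accessing `t.inv.log_abs_det_jacobian`) breaks. Renaming the method to
# `restricted_inv(theta)` would avoid the shadowing.
```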

Baschdl added the architecture label Mar 20, 2024
janfb added this to the Hackathon and release 2024 milestone Aug 6, 2024

gmoss13 commented Aug 20, 2024

The problem is caused by the fact that `RestrictedTransformForConditional` is a torch `Transform`, but also takes a torch `Transform` as an argument. This can be refactored, but I also think this is not a priority for the release milestone, as `conditional_potential` is not used by anything at the moment.
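One way such a refactor could look (a speculative sketch, not sbi's implementation; `EmbedConditionTransform` and all details below are hypothetical): make the restriction itself a proper `Transform` that embeds the free dimensions into the full space, and chain it with the inner transform via `ComposeTransform`, so the standard interface, including `inv` as a property, is preserved:

```python
import torch
from torch import Tensor
from torch.distributions import constraints
from torch.distributions import transforms as torch_tf


class EmbedConditionTransform(torch_tf.Transform):
    """Hypothetical: maps free dims `theta` <-> the full parameter vector.

    Treated as bijective onto the slice where the fixed dimensions equal the
    condition; the Jacobian on the free dimensions is the identity.
    """

    bijective = True
    domain = constraints.real_vector
    codomain = constraints.real_vector

    def __init__(self, condition: Tensor, dims_to_sample: list) -> None:
        super().__init__()
        self.condition = condition  # shape (1, dim_full)
        self.dims_to_sample = dims_to_sample

    def _call(self, theta: Tensor) -> Tensor:
        # Repeat the condition to the batch size, then override free dims.
        full_theta = self.condition.repeat(theta.shape[0], 1)
        full_theta[:, self.dims_to_sample] = theta
        return full_theta

    def _inverse(self, full_theta: Tensor) -> Tensor:
        return full_theta[:, self.dims_to_sample]

    def log_abs_det_jacobian(self, theta: Tensor, full_theta: Tensor) -> Tensor:
        # The embedding contributes no volume change on the free dimensions.
        return torch.zeros(theta.shape[0])


# Composition then yields a restricted transform with the standard
# `torch.Transform` interface:
# restricted = torch_tf.ComposeTransform(
#     [EmbedConditionTransform(condition, dims_to_sample), inner_transform]
# )
```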

gmoss13 removed this from the Hackathon and release 0.23 milestone Aug 20, 2024

janfb commented Aug 20, 2024

Thanks for the clarification.

Then this is related to the recent question in #1223, and we will fix it when we fix the general handling of custom potentials.
