
log_prob of blocked programs #81

Open
jhn-nt opened this issue Aug 29, 2024 · 1 comment


jhn-nt commented Aug 29, 2024

Hello,

First, thanks for developing such an amazing package.
I am new to Oryx and have been playing around with its functionality.
Perhaps naively, I have been trying to evaluate the log_prob of blocked programs, as below:

```python
from jax.random import split
from oryx.core import ppl
import tensorflow_probability.substrates.jax.distributions as tfd

def latent_normal(key):
    z_key,x_key= split(key)
    z=ppl.random_variable(tfd.Normal(0,1),name="z")(z_key)
    return ppl.random_variable(tfd.Normal(z,1e-1),name="x")(x_key)


blocked=ppl.block(latent_normal,names=["z"])
ppl.joint_log_prob(blocked)({"x":10})
```

However, it raises the following error:
```
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[2], line 12
      8     return ppl.random_variable(tfd.Normal(z,1e-1),name="x")(x_key)
     11 blocked=ppl.block(latent_normal,names=["z"])
---> 12 ppl.joint_log_prob(blocked)({"x":10})

File /opt/conda/lib/python3.11/site-packages/oryx/core/interpreters/log_prob.py:71, in log_prob.<locals>.wrapped(sample, *args, **kwargs)
     67 flat_incells = [
     68     InverseAndILDJ.unknown(trace_util.get_shaped_aval(dummy_seed))
     69 ] + [InverseAndILDJ.new(val) for val in flat_inargs]
     70 flat_outcells = [InverseAndILDJ.new(a) for a in flat_outargs]
---> 71 return log_prob_jaxpr(jaxpr.jaxpr, constcells, flat_incells, flat_outcells)

File /opt/conda/lib/python3.11/site-packages/oryx/core/interpreters/log_prob.py:128, in log_prob_jaxpr(jaxpr, constcells, flat_incells, flat_outcells)
    118 _, final_log_prob = propagate.propagate(
    119     InverseAndILDJ,
    120     log_prob_rules,
   (...)
    125     reducer=reducer,
    126     initial_state=0.)
    127 if final_log_prob is failed_log_prob:
--> 128   raise ValueError('Cannot compute log_prob of function.')
    129 return final_log_prob

ValueError: Cannot compute log_prob of function.
```

Am I missing something?

Thanks again

Very Best
Giovanni


PaulScemama commented Dec 21, 2024

Hi @jhn-nt. I'm not a contributor or maintainer, but I think you might be missing the following. If you look at the description of `block`, it says:

"The block transformation takes in a program and a sequence of names and returns a program that behaves identically except that in downstream transformations (like joint_sample), the provided names are ignored."

When you run `blocked = ppl.block(latent_normal, names=["z"])`, you are telling downstream transformations (like `joint_log_prob`) to ignore the variable named `'z'`. Because `'x'` depends on `'z'`, which you blocked, `joint_log_prob` cannot "see" the value of `'z'`, so it cannot build the distribution over `'x'` and compute the log_prob of 10 under that distribution. That's why you're seeing the error.

In general, I would guess that you cannot compute samples or log probabilities of a program with respect to a variable that depends on variables you have blocked. @sharadmv, let me know if I missed anything. I am not sure whether there is a principled way to catch this sort of error and display a more helpful error message; I may look into it. @jhn-nt, let me know if you have any further questions! Hope that helps.
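Put another way: once `z` is hidden, the only named variable left is `x`, so a correct log-density for the blocked program would be the marginal log p(x) = log ∫ p(z) p(x | z) dz, which requires integrating `z` out rather than replaying or inverting the program. The plain-Python sketch below (no Oryx involved; the function names are hypothetical) computes that marginal for this particular model in two ways: in closed form, using the fact that z ~ Normal(0, 1) and x | z ~ Normal(z, 0.1) give x ~ Normal(0, sqrt(1 + 0.1**2)), and by brute-force quadrature over `z`. It only illustrates what the answer would have to be; it says nothing about whether Oryx could or should compute it.

```python
import math

# Hypothetical helpers for this one model only; no Oryx involved.


def normal_log_pdf(x, loc, scale):
    """Log-density of Normal(loc, scale) at x."""
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2.0 * math.pi)


def log_joint(z, x, obs_scale=0.1):
    """log p(z) + log p(x | z) for z ~ Normal(0, 1), x ~ Normal(z, obs_scale)."""
    return normal_log_pdf(z, 0.0, 1.0) + normal_log_pdf(x, z, obs_scale)


def log_marginal_closed_form(x, obs_scale=0.1):
    """log p(x): integrating z out gives x ~ Normal(0, sqrt(1 + obs_scale**2))."""
    return normal_log_pdf(x, 0.0, math.sqrt(1.0 + obs_scale ** 2))


def log_marginal_quadrature(x, obs_scale=0.1, lo=-20.0, hi=20.0, n=40_001):
    """log of a Riemann-sum approximation to the integral of p(z) p(x | z) over z."""
    h = (hi - lo) / (n - 1)
    terms = [log_joint(lo + i * h, x, obs_scale) for i in range(n)]
    m = max(terms)  # log-sum-exp shift so exp() does not underflow
    return m + math.log(sum(math.exp(t - m) for t in terms)) + math.log(h)


print(log_marginal_closed_form(10.0))
print(log_marginal_quadrature(10.0))
```

The two numbers should agree closely. For non-conjugate models there is usually no closed form, which is presumably one reason a general-purpose log-prob transformation cannot simply skip over a blocked variable.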
