Hello,
First, thanks for developing such an amazing package.
I am a newbie to oryx and was playing around with its functionality.
Perhaps naively, I was attempting to evaluate log_probs of blocked programs (the relevant lines of my code appear in the traceback below).
However, it returns:
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
Cell In[2], line 12
8 return ppl.random_variable(tfd.Normal(z,1e-1),name="x")(x_key)
11 blocked=ppl.block(latent_normal,names=["z"])
---> 12 ppl.joint_log_prob(blocked)({"x":10})
File /opt/conda/lib/python3.11/site-packages/oryx/core/interpreters/log_prob.py:71, in log_prob.<locals>.wrapped(sample, *args, **kwargs)
67 flat_incells = [
68 InverseAndILDJ.unknown(trace_util.get_shaped_aval(dummy_seed))
69 ] + [InverseAndILDJ.new(val) for val in flat_inargs]
70 flat_outcells = [InverseAndILDJ.new(a) for a in flat_outargs]
---> 71 return log_prob_jaxpr(jaxpr.jaxpr, constcells, flat_incells, flat_outcells)
File /opt/conda/lib/python3.11/site-packages/oryx/core/interpreters/log_prob.py:128, in log_prob_jaxpr(jaxpr, constcells, flat_incells, flat_outcells)
118 _, final_log_prob = propagate.propagate(
119 InverseAndILDJ,
120 log_prob_rules,
(...)
125 reducer=reducer,
126 initial_state=0.)
127 if final_log_prob is failed_log_prob:
--> 128 raise ValueError('Cannot compute log_prob of function.')
129 return final_log_prob
ValueError: Cannot compute log_prob of function.
Am I missing something?
Thanks again
Very Best
Giovanni
Hi @jhn-nt. I'm not a contributor or maintainer, but I think you might be missing the following. If you look at the description of block in the docs, it says:
"The block transformation takes in a program and a sequence of names and returns a program that behaves identically except that in downstream transformations (like joint_sample), the provided names are ignored."
When you run blocked=ppl.block(latent_normal,names=["z"]), you are telling downstream transformations (like joint_log_prob) to ignore the variable named 'z'. Because the variable named 'x' depends on 'z', which you blocked, joint_log_prob cannot "see" the value of 'z'. Without that value it cannot construct the distribution over 'x' (a Normal centered at z), so it cannot compute the log_prob of 10 under that distribution. That's why you're seeing the error.
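To make the dependency concrete, here is a minimal plain-Python sketch (no oryx involved) of the joint log density for a model shaped like the one in the traceback: log p(z, x) = log p(z) + log p(x | z). It takes x | z ~ Normal(z, 0.1) from line 8 of the traceback and assumes a standard-normal prior on z, which the traceback does not show. The helpers normal_logpdf and toy_joint_log_prob are hypothetical names, not oryx functions.

```python
import math


def normal_logpdf(value: float, loc: float, scale: float) -> float:
    """Log density of Normal(loc, scale) evaluated at `value`."""
    return (-0.5 * ((value - loc) / scale) ** 2
            - math.log(scale)
            - 0.5 * math.log(2 * math.pi))


def toy_joint_log_prob(sample: dict) -> float:
    """Hypothetical stand-in for a joint log-prob: log p(z) + log p(x | z).

    Assumes z ~ Normal(0, 1) (an assumption) and x | z ~ Normal(z, 0.1).
    """
    if "z" not in sample:
        # Without z there is no loc for x's distribution, so log p(x | z)
        # cannot be evaluated. This mirrors, but is not, oryx's error.
        raise ValueError("Cannot compute log_prob: 'x' depends on 'z', "
                         "which was not provided.")
    z, x = sample["z"], sample["x"]
    return normal_logpdf(z, 0.0, 1.0) + normal_logpdf(x, z, 0.1)


if __name__ == "__main__":
    # Both values known: every term can be evaluated.
    print(toy_joint_log_prob({"z": 9.9, "x": 10.0}))
    # Only x known (as with z blocked): the x term has no loc.
    try:
        toy_joint_log_prob({"x": 10.0})
    except ValueError as err:
        print("ValueError:", err)
```

The first call succeeds because both terms have every value they need; the second fails for the same structural reason as the blocked program, since the location of x's distribution is unknown.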
In general, I would guess you are unable to compute samples or log probabilities of a program with respect to a variable that depends on variables you previously blocked. @sharadmv, let me know if I missed anything. I am not sure whether there is a principled way to catch this sort of error and display a more helpful message, but I may look into it. And @jhn-nt, let me know if you have any further questions! Hope that helps.
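For contrast, a log_prob of x alone, with z hidden, would have to be the marginal log p(x) = log ∫ p(z) p(x | z) dz. Judging from the traceback, joint_log_prob appears to work by propagating known values (the InverseAndILDJ cells) through the traced program rather than by integrating latent variables out, which would explain why it gives up. The sketch below, again assuming z ~ Normal(0, 1) and x | z ~ Normal(z, 0.1), compares the closed-form Gaussian marginal x ~ Normal(0, sqrt(1 + 0.1^2)) with a brute-force trapezoid integral over z. All function names are hypothetical.

```python
import math


def normal_logpdf(value: float, loc: float, scale: float) -> float:
    """Log density of Normal(loc, scale) evaluated at `value`."""
    return (-0.5 * ((value - loc) / scale) ** 2
            - math.log(scale)
            - 0.5 * math.log(2 * math.pi))


def marginal_log_prob_x_closed_form(x: float, z_scale: float = 1.0,
                                    x_scale: float = 0.1) -> float:
    # x = z + noise with independent zero-mean Gaussians,
    # so x ~ Normal(0, sqrt(z_scale**2 + x_scale**2)).
    return normal_logpdf(x, 0.0, math.sqrt(z_scale ** 2 + x_scale ** 2))


def marginal_log_prob_x_quadrature(x: float, z_scale: float = 1.0,
                                   x_scale: float = 0.1,
                                   num_points: int = 20001,
                                   half_width: float = 12.0) -> float:
    # Trapezoid rule for the integral of p(z) * p(x | z) over
    # z in [-half_width, half_width].
    lo, hi = -half_width, half_width
    step = (hi - lo) / (num_points - 1)
    total = 0.0
    for i in range(num_points):
        z = lo + i * step
        weight = 0.5 if i in (0, num_points - 1) else 1.0
        total += weight * math.exp(normal_logpdf(z, 0.0, z_scale)
                                   + normal_logpdf(x, z, x_scale))
    return math.log(total * step)


if __name__ == "__main__":
    for x in (0.0, 1.0, 10.0):
        print(x, marginal_log_prob_x_closed_form(x),
              marginal_log_prob_x_quadrature(x))
```

The two columns should agree closely. The point is only that evaluating x's density without z requires an integral like this, which a value-propagation interpreter presumably has no rule for.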