BUG: local_det_chol fails when function inputs are included in outputs #392
Describe the issue:
When compiling a graph in JAX mode with a scan Op whose inner function returns the determinant of a sit-sot tap (is that the right term? I mean a recursive input), the local_det_chol rewrite raises an error. Something to do with the recurrent variable? In the minimal example, P0 must be in outputs_info, or else the error will not trigger.
An easy fix is to change:
To:
But this seems like it's treating a symptom rather than the underlying problem. Maybe that's nit-picking?
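One plausible mechanism, consistent with the issue title: in PyTensor versions of this era, `fgraph.clients[x]` lists `(client, index)` pairs. When `x` is also an output of the graph (as the recurrent P0 tap presumably is in the scan inner function), one of those clients is the string `"output"` rather than an Apply node. A rewrite that loops over the clients looking for an existing Cholesky and reads `cl.op` would then raise `AttributeError`. The sketch below is pure Python with hypothetical stand-ins (`Apply`, `Cholesky`, `find_cholesky_naive`, `find_cholesky_guarded`). It is not PyTensor's actual code and only illustrates why skipping `"output"` clients avoids the crash.

```python
# Hypothetical stand-ins for PyTensor graph objects (NOT the real PyTensor API).
class Cholesky:
    """Stand-in for a Cholesky Op."""


class Apply:
    """Stand-in for an Apply node: an op plus the variables it produces."""

    def __init__(self, op, outputs):
        self.op = op
        self.outputs = outputs


def find_cholesky_naive(clients):
    # Assumes every client is an Apply node, so it breaks on the "output" marker.
    for cl, _idx in clients:
        if isinstance(cl.op, Cholesky):
            return cl.outputs[0]
    return None


def find_cholesky_guarded(clients):
    for cl, _idx in clients:
        if cl == "output":  # x is itself a graph output: there is no Op to inspect
            continue
        if isinstance(cl.op, Cholesky):
            return cl.outputs[0]
    return None


# x feeds a Cholesky node AND is returned as a graph output.
chol_node = Apply(Cholesky(), outputs=["L"])
clients = [("output", 0), (chol_node, 0)]

try:
    find_cholesky_naive(clients)
except AttributeError as err:
    print("naive loop fails:", err)

print("guarded loop finds:", find_cholesky_guarded(clients))
```

Whether skipping the `"output"` client is the right fix, or only hides a deeper problem in how the scan inner graph is built, is exactly the open question above.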
Reproducible code example:
Error message:
PyTensor version information:
Context for the issue:
I want to scan over a stack of multivariate distributions and compute the sequence of logps, which involves taking determinants in a scan.