BUG: local_det_chol fails when function inputs are included in outputs #392

Closed
jessegrabowski opened this issue Jul 20, 2023 · 3 comments · Fixed by #393
Labels
bug Something isn't working graph rewriting

Comments

@jessegrabowski
Member

jessegrabowski commented Jul 20, 2023

Describe the issue:

When compiling a graph in JAX mode that contains a scan Op whose inner function returns the determinant of a sit-sot tap (is that the right term? I mean a recursive input), the local_det_chol rewrite raises an error. Something to do with the recurrent variable? In the minimal example, P0 must be in outputs_info; otherwise the error is not triggered.

An easy fix is to change:

if isinstance(node.op, Det):
    (x,) = node.inputs
    for cl, xpos in fgraph.clients[x]:
        if isinstance(cl.op, Cholesky):
            L = cl.outputs[0]
            return [prod(at.extract_diag(L) ** 2)]

To:

if isinstance(node.op, Det):
    (x,) = node.inputs
    for cl, xpos in fgraph.clients[x]:
        op = getattr(cl, 'op', None)
        if isinstance(op, Cholesky):
            L = cl.outputs[0]
            return [prod(at.extract_diag(L) ** 2)]

But this seems to treat the symptom rather than the underlying problem. Maybe that's nit-picking?
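
To see why the original loop crashes and the getattr version does not, here is a minimal pure-Python mock of the client list. It assumes, consistent with the traceback below, that when a variable is also a function output, fgraph.clients[x] contains an ("output", index) entry whose first element is the string "output" rather than an Apply node. The Cholesky and MockNode classes are hypothetical stand-ins, not PyTensor's real classes.

```python
# Hypothetical stand-ins for PyTensor's Cholesky Op and Apply nodes (not the real API).
class Cholesky:
    pass


class MockNode:
    def __init__(self, op, output_name):
        self.op = op
        self.outputs = [output_name]  # placeholder for the node's output variable


# Assumed shape of fgraph.clients[x] when x is both a Det input and a function
# output: (client, input_index) pairs, where the output client is the string "output".
clients_of_x = [
    ("output", 0),
    (MockNode(Cholesky(), "L"), 0),
]


def find_cholesky_original(clients):
    # Mirrors the original loop: assumes every client has an .op attribute.
    for cl, xpos in clients:
        if isinstance(cl.op, Cholesky):
            return cl.outputs[0]
    return None


def find_cholesky_getattr(clients):
    # Mirrors the proposed fix: a string client has no .op, so getattr returns None.
    for cl, xpos in clients:
        op = getattr(cl, "op", None)
        if isinstance(op, Cholesky):
            return cl.outputs[0]
    return None


try:
    find_cholesky_original(clients_of_x)
except AttributeError as err:
    print(err)  # 'str' object has no attribute 'op'

print(find_cholesky_getattr(clients_of_x))  # L
```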

Reproducible code example:

import pytensor
import pytensor.tensor as pt

def update(P):
    P_det = pt.linalg.det(P)
    return P, P_det

P0 = pt.eye(2)

outputs, _ = pytensor.scan(update, outputs_info=[P0, None], n_steps=2, mode='JAX')
f = pytensor.function([], outputs, mode='JAX')

Error message:

ERROR (pytensor.graph.rewriting.basic): Rewrite failure due to: local_det_chol
ERROR (pytensor.graph.rewriting.basic): node: Det(Add.0)
ERROR (pytensor.graph.rewriting.basic): TRACEBACK:
ERROR (pytensor.graph.rewriting.basic): Traceback (most recent call last):
  File "/Users/jessegrabowski/mambaforge/envs/pymc-experimental/lib/python3.10/site-packages/pytensor/graph/rewriting/basic.py", line 1914, in process_node
    replacements = node_rewriter.transform(fgraph, node)
  File "/Users/jessegrabowski/mambaforge/envs/pymc-experimental/lib/python3.10/site-packages/pytensor/graph/rewriting/basic.py", line 1074, in transform
    return self.fn(fgraph, node)
  File "/Users/jessegrabowski/mambaforge/envs/pymc-experimental/lib/python3.10/site-packages/pytensor/tensor/rewriting/linalg.py", line 165, in local_det_chol
    if isinstance(cl.op, Cholesky):
AttributeError: 'str' object has no attribute 'op'

PyTensor version information:

pytensor version: 2.13.1

Context for the issue:

I want to scan over a stack of multivariate distributions and compute the sequence of logps, which involves taking determinants in a scan.
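
For reference, local_det_chol relies on the identity det(P) = prod(diag(L))**2 for a symmetric positive-definite P with Cholesky factor L (P = L L^T): det(P) = det(L) det(L^T), and the determinant of a triangular matrix is the product of its diagonal. A quick NumPy check, using an arbitrary SPD test matrix chosen for illustration; for logps, the log form 2 * sum(log(diag(L))) avoids overflow and underflow in the product.

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.normal(size=(3, 3))
P = A @ A.T + 3 * np.eye(3)  # symmetric positive-definite by construction

L = np.linalg.cholesky(P)  # lower-triangular factor, P = L @ L.T

det_direct = np.linalg.det(P)
det_via_chol = np.prod(np.diag(L) ** 2)  # same expression the rewrite produces
logdet_via_chol = 2 * np.sum(np.log(np.diag(L)))  # log-determinant for logp terms

print(np.isclose(det_direct, det_via_chol))  # True
print(np.isclose(np.log(det_direct), logdet_via_chol))  # True
```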

@jessegrabowski jessegrabowski added the bug Something isn't working label Jul 20, 2023
@ricardoV94
Member

ricardoV94 commented Jul 20, 2023

The fix sounds good; it isn't masking any issue.

I would probably do:

if cl == "output":
    continue

Everything else must be a normal node with an Op.
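
A sketch of that explicit-skip variant, run against a hypothetical mock client list (the Cholesky and MockNode classes are stand-ins, not PyTensor's). Unlike the getattr version, any other client lacking an .op would still raise, so a genuinely malformed client list stays visible instead of being silently skipped.

```python
# Hypothetical stand-ins for PyTensor's Cholesky Op and Apply nodes.
class Cholesky:
    pass


class MockNode:
    def __init__(self, op, output_name):
        self.op = op
        self.outputs = [output_name]


def find_cholesky_output(clients):
    """Return the output of the first Cholesky client, skipping function-output entries."""
    for cl, xpos in clients:
        if cl == "output":
            # Dummy client recording that x is a function output; it has no Op.
            continue
        # Any remaining client is assumed to be a node with an .op attribute.
        if isinstance(cl.op, Cholesky):
            return cl.outputs[0]
    return None


clients_of_x = [("output", 0), (MockNode(Cholesky(), "L"), 0)]
print(find_cholesky_output(clients_of_x))  # L
```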

This shows up a lot when working with clients. I wonder if we could remove those "dummy output" clients and keep track of outputs separately.
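
One way to picture keeping track of outputs separately is a client map that only ever holds real nodes, with function outputs recorded in their own structure. This is a hypothetical sketch (ClientMap, Op, and Node are invented for illustration), not how PyTensor's FunctionGraph is implemented.

```python
from collections import defaultdict


class ClientMap:
    """Hypothetical graph bookkeeping: node clients and output positions kept apart."""

    def __init__(self):
        self.clients = defaultdict(list)  # var -> [(node, input_index), ...]
        self.output_positions = defaultdict(list)  # var -> [output_index, ...]

    def add_node_client(self, var, node, input_index):
        self.clients[var].append((node, input_index))

    def mark_output(self, var, output_index):
        self.output_positions[var].append(output_index)

    def is_output(self, var):
        # .get avoids inserting an empty entry for unknown variables.
        return bool(self.output_positions.get(var))


# Minimal mock Op and node types for the usage example.
class Op:
    def __init__(self, name):
        self.name = name


class Node:
    def __init__(self, op):
        self.op = op


cmap = ClientMap()
x = "X"
cmap.add_node_client(x, Node(Op("Det")), 0)
cmap.mark_output(x, 0)

# Every client now has an .op, so a rewrite can loop over clients without special cases.
print([cl.op.name for cl, _ in cmap.clients[x]])  # ['Det']
print(cmap.is_output(x))  # True
```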

@jessegrabowski
Member Author

The scan isn't essential to the bug; all that's needed is for the input to Det to also be an output of the function, as in:

import pytensor
import pytensor.tensor as pt

X = pt.dmatrix('X')
det_X = pt.linalg.det(X)
f = pytensor.function([X], [X, det_X])

(Thanks @ricardoV94 for finding this).

I'll open a PR to fix it.

@jessegrabowski jessegrabowski changed the title BUG: local_det_chol fails in recursive scans when mode=JAX BUG: local_det_chol fails when function inputs are included in outputs Jul 20, 2023
@jessegrabowski
Member Author

Closed by #393
