
Add rewrite to merge multiple SVD Ops with different settings #732

Closed
jessegrabowski opened this issue Apr 28, 2024 · 4 comments · Fixed by #769

Comments

@jessegrabowski
Member

Description

SVD comes with a number of keyword arguments, the most important of which is compute_uv. If it is False, only the singular values are returned for a given matrix. This is nice if you want to save on computation, but it can actually be inefficient if the user wants gradients. In reverse mode we need to compute the U and V matrices anyway, and indeed the L_op for SVD (added in #614) adds a 2nd SVD Op to the graph with compute_uv=True.

When we see two SVD Ops with the same inputs in a graph, differing only by compute_uv, we should change compute_uv=False to True everywhere. This lets PyTensor see that those outputs are equivalent and re-use them, rather than computing the decomposition multiple times.
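A minimal sketch of the duplication described above (assuming the SVD gradient support from #614 is available; the exact `dprint` output will vary by version):

```python
import pytensor
import pytensor.tensor as pt
from pytensor.tensor.nlinalg import svd

A = pt.matrix("A")
s = svd(A, compute_uv=False)   # user only asks for the singular values
loss = s.sum()
g = pytensor.grad(loss, A)     # the L_op inserts a second SVD with compute_uv=True

# Before the proposed rewrite, the compiled graph contains two SVD nodes
# that decompose the same matrix.
fn = pytensor.function([A], [loss, g])
pytensor.dprint(fn)
```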

@ricardoV94
Member

The rewrite can do the merge immediately; it's just not a local rewrite but a global one then.

Also, if an Op has compute_uv=True but the U/V arrays are not used in the graph, we can set it to False. That can be a local rewrite, but it's probably fine to handle both cases together in the same global rewrite.
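A rough sketch of that second idea (dropping unused U/V). The rewrite name, the assumption that SVD nodes appear wrapped in Blockwise, and the dict-return convention are illustrative, not a final implementation:

```python
from pytensor.graph.rewriting.basic import node_rewriter
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.nlinalg import SVD, svd

@node_rewriter([Blockwise])
def local_svd_uv_to_values(fgraph, node):
    # Assumes SVD reaches the graph wrapped in Blockwise, as most linalg Ops do.
    core_op = node.op.core_op
    if not (isinstance(core_op, SVD) and core_op.compute_uv):
        return None
    U, s, V = node.outputs
    # If nothing in the graph consumes U or V, only the singular values are needed.
    if fgraph.clients[U] or fgraph.clients[V]:
        return None
    (A,) = node.inputs
    new_s = svd(A, full_matrices=core_op.full_matrices, compute_uv=False)
    return {s: new_s}
```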

@HangenYuu
Contributor

Hi, I want to work on this. As I understand it, I will need to add a class SVDSimplify(GraphRewriter) to the file pytensor/tensor/rewriting/linalg.py that checks for SVD Ops in the PyTensor graph and changes the keyword argument compute_uv of all of them to True if

  1. there is more than one SVD Op, and
  2. at least one of them has compute_uv=True.

From the documentation, I know roughly how I should do it. I am still not 100% sure about all the decorators used (e.g., @register_canonicalize, @register_stabilize, @register_specialize, @node_rewriter(), etc.), but I will ask for your input in the draft PR.
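Roughly, a skeleton of such a class could look like this (the traversal details, and the assumption that SVD is applied through Blockwise, are guesses, not the merged implementation):

```python
from pytensor.graph.rewriting.basic import GraphRewriter
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.nlinalg import SVD

class SVDSimplify(GraphRewriter):
    def apply(self, fgraph):
        # Collect every SVD node in the graph.
        svd_nodes = [
            node
            for node in fgraph.toposort()
            if isinstance(node.op, Blockwise) and isinstance(node.op.core_op, SVD)
        ]
        if len(svd_nodes) < 2:
            return
        if not any(node.op.core_op.compute_uv for node in svd_nodes):
            return
        # ... replace the compute_uv=False Ops (on matching inputs) with
        # compute_uv=True versions so the decompositions can be deduplicated ...
```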

@jessegrabowski
Member Author

local_det_chol in pytensor/tensor/rewriting/linalg.py is a good rewrite to look at, because it also uses the full FunctionGraph (the first argument to the rewrite function, usually called fgraph) to perform the rewrite.

The decorators tell PyTensor at which step of the rewriting process the rewrite should be performed. This one can come last, so I guess it should be @register_specialize. It's a @node_rewriter because it changes a single node of computation (an SVD Op with compute_uv=False), as opposed to a @graph_rewriter that operates on a whole group of nodes.
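For a concrete picture, a local rewrite in the spirit of local_det_chol might look roughly like this (the name, the Blockwise unwrapping, and the client handling are assumptions, not the final code):

```python
from pytensor.graph.rewriting.basic import node_rewriter
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.nlinalg import SVD

@node_rewriter([Blockwise])
def local_svd_uv_merge(fgraph, node):
    core_op = node.op.core_op
    if not isinstance(core_op, SVD) or core_op.compute_uv:
        return None
    (A,) = node.inputs
    # Look through the other consumers of A for an SVD that already computes
    # U and V, and reuse its singular-values output.
    for client, _ in fgraph.clients[A]:
        other_op = getattr(client, "op", None)
        if (
            isinstance(other_op, Blockwise)
            and isinstance(other_op.core_op, SVD)
            and other_op.core_op.compute_uv
            and other_op.core_op.full_matrices == core_op.full_matrices
        ):
            _, other_s, _ = client.outputs
            return [other_s]
    return None
```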

Tag me on your draft PR and I'm happy to walk you through the sharp bits.

@ricardoV94
Member

This one can come last, so I guess it should be @register_specialize

This one is pretty cheap, so we can run it in all 3 stages. It will only be triggered if there's an SVD Op in the graph anyway.
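Running it in all three stages is just a matter of stacking the registration decorators (the helpers live in pytensor/tensor/rewriting/basic.py; the rewrite body is elided here):

```python
from pytensor.graph.rewriting.basic import node_rewriter
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.rewriting.basic import (
    register_canonicalize,
    register_specialize,
    register_stabilize,
)

@register_canonicalize
@register_stabilize
@register_specialize
@node_rewriter([Blockwise])
def svd_uv_merge(fgraph, node):
    ...  # the actual merging logic goes here
```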
