Handle non-square intermediate Blockwise operations #1017

Open
ricardoV94 opened this issue Oct 6, 2024 · 0 comments
Labels: enhancement, vectorization

Description

When an Op's output shape depends on its input values, a Blockwise of that Op is not necessarily valid at runtime. For example, the following graph cannot be evaluated by PyTensor, because the intermediate Blockwise(Arange) would have to produce a ragged array (one row per batch entry, each with a different length):

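import pytensor.tensor as pt
from pytensor.graph import vectorize_graph

# Core graph: the length of the arange depends on the value of `i`
i = pt.scalar("i", dtype=int)
y = pt.sum(pt.arange(0, i))

# Batched version: each entry of `new_i` would need an arange of a different length
new_i = pt.vector("new_i", dtype=int)
new_y = vectorize_graph(y, {i: new_i})
new_y.eval({new_i: [1, 2, 3, 4]})  # ValueError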
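
To make the ragged intermediate concrete, here is a minimal sketch in plain NumPy. It is only a stand-in for what the batched Arange would have to produce, not PyTensor's actual implementation; the names `batch`, `rows`, `row_lengths` and `stacked_ok` are illustrative.

```python
import numpy as np

batch = [1, 2, 3, 4]

# One arange per batch entry: the rows have lengths 1, 2, 3 and 4.
rows = [np.arange(0, i) for i in batch]
row_lengths = [len(r) for r in rows]

# A dense (non-ragged) array needs every row to have the same length,
# so stacking these rows into a single array fails.
try:
    np.stack(rows)
    stacked_ok = True
except ValueError:
    stacked_ok = False
```

The rows cannot be stored in a single rectangular array, which is the situation the intermediate Blockwise(Arange) runs into.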

However, if we wrap the Arange + Sum in an OpFromGraph, the ragged dimension is reduced inside the wrapped subgraph and never appears as a batched intermediate. That subgraph is a valid Blockwise, and PyTensor is happy to evaluate it:

import pytensor.tensor as pt
from pytensor.graph import vectorize_graph
from pytensor.compile.builders import OpFromGraph

# Core graph wrapped in a single Op: scalar in, scalar out
i = pt.scalar("i", dtype=int)
y_ = pt.sum(pt.arange(0, i))
y = OpFromGraph([i], [y_])(i)

# The Blockwise now loops over the whole Arange + Sum subgraph
new_i = pt.vector("new_i", dtype=int)
new_y = vectorize_graph(y, {i: new_i})
new_y.eval({new_i: [1, 2, 3, 4]})  # [0, 1, 3, 6]
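
The difference between the two graphs can be sketched with `np.vectorize` and a gufunc-style `signature`, used here as a rough analogy for Blockwise (an assumption for illustration, not a description of PyTensor internals). Vectorizing the arange alone has a core output dimension whose size changes per batch entry, while vectorizing arange-then-sum as one unit has a scalar core output. The names `arange_only`, `arange_then_sum`, `intermediate_ok` and `fused` are illustrative.

```python
import numpy as np

new_i = np.array([1, 2, 3, 4])

# Analogue of Blockwise(Arange): the core output dimension "n" would need
# a different size for each batch entry, so the call fails.
arange_only = np.vectorize(lambda i: np.arange(0, i), signature="()->(n)")
try:
    arange_only(new_i)
    intermediate_ok = True
except ValueError:
    intermediate_ok = False

# Analogue of Blockwise(OpFromGraph(Arange + Sum)): the ragged dimension is
# reduced inside the core function, so every batch entry returns a scalar.
arange_then_sum = np.vectorize(lambda i: np.arange(0, i).sum(), signature="()->()")
fused = arange_then_sum(new_i)
```

Only the second call succeeds, because its output shape no longer depends on the input values.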

It would be nice to use this trick automatically to support end-to-end vectorization in these cases. Some of the logic needed to infer whether an Op has a square (non-ragged) output shape is being developed in #1015.

Some of the logic developed in pymc-devs/pymc-experimental#300 to understand how dimensions propagate over nodes could be repurposed to figure out in which cases a subgraph collapses ragged dimensions.

We can start very simple and only allow cases where the ragged dimension is reduced immediately, as in the Arange + Sum example above.

@ricardoV94 ricardoV94 added enhancement New feature or request vectorization labels Oct 6, 2024