Derive logprob of matmul #7542
Conversation
Codecov Report

Attention: Patch coverage is …

Additional details and impacted files

@@            Coverage Diff             @@
##             main    #7542      +/-   ##
==========================================
- Coverage   92.85%   92.82%   -0.04%
==========================================
  Files         105      106       +1
  Lines       17591    17669      +78
==========================================
+ Hits        16335    16402      +67
- Misses       1256     1267      +11
… direct valued nodes
@@ -320,7 +320,7 @@ def find_negated_var(var):
     return None


-def get_related_valued_nodes(node: Apply, fgraph: FunctionGraph) -> list[Apply]:
+def get_related_valued_nodes(fgraph: FunctionGraph, node: Apply) -> list[Apply]:
This is a much more natural order, used in all sorts of pytensor utilities that require a node/variable and its fgraph.
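For context, this matches the order used by pytensor's node rewriters, which also receive the graph before the node. A minimal sketch of that convention (my own illustration, not code from this PR):

import pytensor.tensor as pt
from pytensor.graph.rewriting.basic import node_rewriter
from pytensor.tensor.elemwise import Elemwise

@node_rewriter([Elemwise])
def my_noop_rewrite(fgraph, node):
    # pytensor node rewriters are called as rewrite(fgraph, node): the graph first,
    # then the node being rewritten, which is the same order the renamed helper now uses.
    return None  # returning None means "leave the node unchanged"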
# In cases where DimShuffle transposes dimensions, we only apply this rewrite when only Elemwise
# operations separate it from the valued node. Further transformations likely need to know where
# the support axes are for a correct implementation (and thus assume they are the rightmost axes).
# TODO: When we include the support axis as meta information in each intermediate MeasurableVariable,
# we can lift this restriction (see https://github.com/pymc-devs/pymc/issues/6360)
if tuple(node.op.shuffle) != tuple(sorted(node.op.shuffle)) and not _elemwise_univariate_chain(
    fgraph, node
):
    return None
These DimShuffle changes were needed to naturally accommodate A @ x when x is a vector, which looks like:
import pytensor.tensor as pt
A = pt.matrix("A")
x = pt.vector("x")
y = A @ x
y.dprint()
# DropDims{axis=1} [id A]
# └─ Blockwise{dot, (m,k),(k,n)->(m,n)} [id B]
# ├─ A [id C]
# └─ ExpandDims{axis=1} [id D]
# └─ x [id E]
It's also stricter / more correct than the limitation we had before, because the concern is really about what comes after the DimShuffle, not so much what comes before it.
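To make the distinction concrete, here is a small sketch (my own illustration, not from the PR) of the two kinds of DimShuffle involved: the ExpandDims/DropDims pattern produced by A @ x above keeps the existing dimensions in their original relative order, while a transposing DimShuffle permutes them and therefore only passes the check above when a pure Elemwise chain separates it from the valued node.

import pytensor.tensor as pt

x = pt.vector("x")
A = pt.matrix("A")

expand = pt.expand_dims(x, -1)  # ExpandDims: adds a new axis, existing dims keep their order
transpose = A.T                 # Transpose: permutes the existing dims

print(expand.owner.op.new_order)     # (0, 'x'): kept dims remain sorted, so the rewrite applies
print(transpose.owner.op.new_order)  # (1, 0): kept dims are permuted, so the rewrite only
                                     # applies when purely Elemwise ops follow it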
lgtm : )
This can be pretty useful for timeseries models as well as normalizing flows.
Here is a simple example:
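The commenter's snippet is not included above. As a rough, hypothetical sketch of the kind of usage this PR enables (assuming a PyMC version that contains this rewrite), one can ask for the logprob of a random variable that only appears through a matrix multiplication:

import numpy as np
import pymc as pm
import pytensor.tensor as pt

# Hypothetical example, not the author's original snippet.
# A fixed invertible matrix applied to a vector of iid normals:
A = np.array([[1.0, 0.5],
              [0.0, 2.0]])
x = pm.Normal.dist(mu=0.0, sigma=1.0, size=2)
y = pt.as_tensor(A) @ x  # y is a linear transformation of the random variable x

y_value = np.array([0.3, -1.2])
# With the matmul rewrite, PyMC can derive logp(y) via the change-of-variables formula,
# i.e. the normal logp evaluated at solve(A, y_value) minus log|det(A)|.
print(pm.logp(y, y_value).eval())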