Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Derive logprob of matmul #7542

Merged
merged 3 commits into from
Oct 21, 2024
Merged

Derive logprob of matmul #7542

merged 3 commits into from
Oct 21, 2024

Conversation

ricardoV94
Copy link
Member

@ricardoV94 ricardoV94 commented Oct 18, 2024

This can be pretty useful for timeseries models as well as normalizing flows.

Here is a simple example:

import numpy as np
import pytensor.tensor as pt

import pymc as pm

rng = np.random.default_rng(37)
x = pm.MvNormal.dist(cov=np.eye(2), size=(128,))

n_layers = 3
A = pt.tensor("A", shape=(n_layers, 2, 2))
b = pt.tensor("b", shape=(n_layers, 2,))

# Repeated layers of Affine transform -> Tanh
for i in range(n_layers):
    y = A[i] @ x + b[i]
    # parametrized leaky-relu would be nicer: https://github.com/pymc-devs/pymc/issues/7543
    # y = pt.switch(y > 0, y, c[i] * y)
    y = pt.tanh(y)

A_test = rng.normal(size=A.type.shape)
b_test = rng.normal(size=b.type.shape)
y_test = rng.uniform(-1, 1, size=y.type.shape)
pm.logp(y, y_test).sum().eval({A: A_test, b: b_test})  # array(-3.54498234)

Copy link

codecov bot commented Oct 18, 2024

Codecov Report

Attention: Patch coverage is 87.77778% with 11 lines in your changes missing coverage. Please review.

Project coverage is 92.82%. Comparing base (5352798) to head (5e5e077).

Files with missing lines Patch % Lines
pymc/logprob/tensor.py 79.41% 7 Missing ⚠️
pymc/logprob/linalg.py 91.48% 4 Missing ⚠️
Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff             @@
##             main    #7542      +/-   ##
==========================================
- Coverage   92.85%   92.82%   -0.04%     
==========================================
  Files         105      106       +1     
  Lines       17591    17669      +78     
==========================================
+ Hits        16335    16402      +67     
- Misses       1256     1267      +11     
Files with missing lines Coverage Δ
pymc/logprob/__init__.py 100.00% <100.00%> (ø)
pymc/logprob/abstract.py 94.28% <100.00%> (+0.16%) ⬆️
pymc/logprob/basic.py 94.28% <100.00%> (ø)
pymc/logprob/mixture.py 95.70% <100.00%> (ø)
pymc/logprob/rewriting.py 100.00% <100.00%> (ø)
pymc/logprob/scan.py 94.90% <ø> (ø)
pymc/logprob/transform_value.py 98.14% <100.00%> (ø)
pymc/logprob/utils.py 92.46% <100.00%> (ø)
pymc/logprob/linalg.py 91.48% <91.48%> (ø)
pymc/logprob/tensor.py 94.48% <79.41%> (-5.52%) ⬇️

@@ -320,7 +320,7 @@ def find_negated_var(var):
return None


def get_related_valued_nodes(node: Apply, fgraph: FunctionGraph) -> list[Apply]:
def get_related_valued_nodes(fgraph: FunctionGraph, node: Apply) -> list[Apply]:
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is much more natural order used in all sorts of pytensor utilities that require a node/variable and its' fgraph

Comment on lines +277 to +285
# In cases where DimShuffle transposes dimensions, we only apply this rewrite when only Elemwise
# operations separate it from the valued node. Further transformations likely need to know where
# the support axes are for a correct implementation (and thus assume they are the rightmost axes).
# TODO: When we include the support axis as meta information in each intermediate MeasurableVariable,
# we can lift this restriction (see https://github.com/pymc-devs/pymc/issues/6360)
if tuple(node.op.shuffle) != tuple(sorted(node.op.shuffle)) and not _elemwise_univariate_chain(
fgraph, node
):
return None
Copy link
Member Author

@ricardoV94 ricardoV94 Oct 18, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These DimShuffle changes were needed to naturally accommodate A @ x when x is a vector, which looks like:

import pytensor.tensor as pt

A = pt.matrix("A")
x = pt.vector("x")
y = A @ x
y.dprint()
# DropDims{axis=1} [id A]
#  └─ Blockwise{dot, (m,k),(k,n)->(m,n)} [id B]
#     ├─ A [id C]
#     └─ ExpandDims{axis=1} [id D]
#        └─ x [id E]

It's also more strict / correct than the limitation we had before, because the concerns are much more about what's after the DimShuffle not so much before.

Copy link
Member

@jessegrabowski jessegrabowski left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lgtm : )

@ricardoV94 ricardoV94 merged commit 1249c86 into pymc-devs:main Oct 21, 2024
18 of 20 checks passed
@ricardoV94 ricardoV94 deleted the matmul branch October 21, 2024 09:15
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants