@ operator returns wrong graph for batched matrices #451

Closed
ricardoV94 opened this issue Sep 23, 2023 · 1 comment · Fixed by #452
Labels
bug Something isn't working NumPy compatibility

ricardoV94 commented Sep 23, 2023

Description

import numpy as np
import pytensor.tensor as pt

x = pt.tensor("x", shape=(10, 3, 3))
x_val = np.random.normal(size=x.type.shape)

np.testing.assert_allclose(
  (x @ x).eval({x: x_val}),
  x_val @ x_val,
)
AssertionError: 
Not equal to tolerance rtol=1e-07, atol=0
(shapes (10, 3, 10, 3), (10, 3, 3) mismatch)
 x: array([[[[ 3.433369e-01, -2.325831e+00, -1.113984e+00],
         [-7.958884e-02,  1.460045e+00, -1.120295e+00],
         [-4.348782e-01, -4.313117e-01, -6.228517e-02],...
 y: array([[[ 0.343337, -2.325831, -1.113984],
        [ 0.614024,  0.805719, -0.136424],
        [-1.463383, -2.395001,  4.435603]],...
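For context, the mismatched shape (10, 3, 10, 3) is exactly what NumPy's np.dot returns for two (10, 3, 3) arrays. For N-D inputs, np.dot sums over the last axis of the first argument and the second-to-last axis of the second, so it pairs every batch with every other batch. By contrast, @ / np.matmul treats the leading axis as a batch dimension. The NumPy-only sketch below uses random data and does not involve PyTensor. It checks that the matmul result is the batch "diagonal" of the dot result. This is also why the first printed row of x and y above agree.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, n = 10, 3
a = rng.normal(size=(batch, n, n))
b = rng.normal(size=(batch, n, n))

# np.dot contracts the last axis of `a` with the second-to-last axis of `b`,
# pairing every batch of `a` with every batch of `b`.
dotted = np.dot(a, b)       # shape (10, 3, 10, 3)

# np.matmul (what `@` means in NumPy) broadcasts over the leading batch axis.
batched = np.matmul(a, b)   # shape (10, 3, 3)

# Keeping only the entries where both batch indices match recovers matmul.
diagonal = np.stack([dotted[i, :, i, :] for i in range(batch)])
print(dotted.shape, batched.shape, np.allclose(diagonal, batched))
```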

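This happens because the __matmul__ method does not return pt.matmul, which works correctly. Instead, __matmul__ is an alias of __dot__, which calls dense_dot, so @ follows dot semantics rather than matmul semantics: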

def __dot__(left, right):
    return at.math.dense_dot(left, right)

def __rdot__(right, left):
    return at.math.dense_dot(left, right)

dot = __dot__
__matmul__ = __dot__
__rmatmul__ = __rdot__

np.testing.assert_allclose(
  pt.matmul(x, x).eval({x: x_val}),
  x_val @ x_val,
)  # Fine
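Python evaluates a @ b by calling a's __matmul__ method, falling back to b's __rmatmul__ if needed. Aliasing __matmul__ = __dot__ therefore makes @ follow dot semantics. The toy classes below are hypothetical, NumPy-backed stand-ins, not PyTensor code. np.dot plays the role of dense_dot. The first class mimics the alias. The second routes @ to matmul semantics instead.

```python
import numpy as np


class AliasedTensor:
    """Hypothetical toy wrapper mirroring the alias above: `@` uses dot semantics."""

    def __init__(self, data):
        self.data = np.asarray(data)

    def __dot__(left, right):
        # Stand-in for dense_dot: np.dot contracts last axis with second-to-last.
        return np.dot(left.data, right.data)

    __matmul__ = __dot__  # the problematic alias: `a @ b` ends up in __dot__


class FixedTensor(AliasedTensor):
    """Same toy wrapper, but `@` dispatches to batched matmul semantics."""

    def __matmul__(left, right):
        return np.matmul(left.data, right.data)


x_val = np.random.default_rng(0).normal(size=(10, 3, 3))
aliased_shape = (AliasedTensor(x_val) @ AliasedTensor(x_val)).shape
fixed_shape = (FixedTensor(x_val) @ FixedTensor(x_val)).shape
print(aliased_shape, fixed_shape)  # (10, 3, 10, 3) (10, 3, 3)
```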
ricardoV94 commented

On the other hand, matmul doesn't have a C implementation or a gradient. Instead of switching to it, we should just wrap the matrix-matrix dot implementation in Blockwise.
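The idea behind "Blockwise" is to take an operation defined only on core inputs and apply it across any leading batch dimensions. Here the core operation is a 2-D matrix-matrix dot with signature (m, k), (k, n) -> (m, n). The sketch below illustrates those semantics in plain NumPy. Both function names are hypothetical, and this is not PyTensor's Blockwise Op. The explicit Python loop is only for clarity. The point of the proposal is presumably that the batched op can reuse what the existing dot implementation already provides.

```python
import numpy as np


def core_matrix_dot(a, b):
    """Hypothetical core op, defined only for 2-D inputs: (m, k), (k, n) -> (m, n)."""
    if a.ndim != 2 or b.ndim != 2:
        raise ValueError("core op expects 2-D inputs")
    return np.dot(a, b)


def blockwise_matrix_dot(a, b):
    """Hypothetical sketch: apply the 2-D core op over broadcast leading batch axes."""
    batch_shape = np.broadcast_shapes(a.shape[:-2], b.shape[:-2])
    a_b = np.broadcast_to(a, batch_shape + a.shape[-2:])
    b_b = np.broadcast_to(b, batch_shape + b.shape[-2:])
    out = np.empty(batch_shape + (a.shape[-2], b.shape[-1]), dtype=np.result_type(a, b))
    for idx in np.ndindex(batch_shape):
        # Each batch element is an ordinary matrix-matrix product.
        out[idx] = core_matrix_dot(a_b[idx], b_b[idx])
    return out


rng = np.random.default_rng(0)
x_val = rng.normal(size=(10, 3, 3))
same = np.allclose(blockwise_matrix_dot(x_val, x_val), x_val @ x_val)

# A 2-D right operand has no batch axes and is broadcast across the batch.
w = rng.normal(size=(3, 2))
bcast = blockwise_matrix_dot(x_val, w)
print(same, bcast.shape)
```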
