Skip to content

Commit

Permalink
Fix bug in local_reduce_join rewrite.
Browse files Browse the repository at this point in the history
The helper `apply_local_dimshuffle_lift` requires a FunctionGraph when elemwise inputs are involved.
  • Loading branch information
ricardoV94 committed Nov 1, 2024
1 parent e934ac7 commit e73258b
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 1 deletion.
2 changes: 1 addition & 1 deletion pytensor/tensor/rewriting/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -1620,7 +1620,7 @@ def local_reduce_join(fgraph, node):
if not inp.type.broadcastable[join_axis]:
return None
# Most times inputs to join have an expand_dims, we eagerly clean up those here
new_input = apply_local_dimshuffle_lift(None, inp.squeeze(join_axis))
new_input = apply_local_dimshuffle_lift(fgraph, inp.squeeze(join_axis))
new_inputs.append(new_input)

ret = Elemwise(node.op.scalar_op)(*new_inputs)
Expand Down
19 changes: 19 additions & 0 deletions tests/tensor/rewriting/test_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@
local_mul_canonizer,
local_mul_switch_sink,
local_reduce_chain,
local_reduce_join,
local_sum_prod_of_mul_or_div,
mul_canonizer,
parse_mul_tree,
Expand Down Expand Up @@ -3415,6 +3416,24 @@ def test_not_supported_unequal_shapes(self):
f(x, y), np.sum(np.concatenate([x, y], axis=0), axis=0)
)

def test_non_ds_inputs(self):
"""Make sure rewrite works when inputs to join are not the usual DimShuffle.
Sum{axis=1} [id A] <Vector(float64, shape=(3,))>
└─ Join [id B] <Matrix(float64, shape=(3, 3))>
├─ 1 [id C] <Scalar(int8, shape=())>
├─ ExpandDims{axis=1} [id D] <Matrix(float64, shape=(3, 1))>
├─ Sub [id E] <Matrix(float64, shape=(3, 1))>
└─ Sub [id F] <Matrix(float64, shape=(3, 1))>
"""
x = vector("x")
out = join(0, exp(x[None]), log(x[None])).sum(axis=0)

fg = FunctionGraph([x], [out], clone=False)
[rewritten_out] = local_reduce_join.transform(fg, out.owner)
expected_out = add(exp(x), log(x))
assert equal_computations([rewritten_out], [expected_out])


def test_local_useless_adds():
default_mode = get_default_mode()
Expand Down

0 comments on commit e73258b

Please sign in to comment.