Skip to content

Commit

Permalink
Remove tensor__local_elemwise_fusion config.
Browse files Browse the repository at this point in the history
The same behavior can be obtained with `optimizer_excluding` (for example, by excluding the `local_elemwise_fusion` tag).

The `local_careduce_fusion` rewrite is now included in this database. Otherwise it would usually not be applied, because it ran before the fusion rewrites.
  • Loading branch information
ricardoV94 committed Jul 4, 2023
1 parent 671cb44 commit e20dd0b
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 72 deletions.
10 changes: 0 additions & 10 deletions pytensor/configdefaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -640,16 +640,6 @@ def add_tensor_configvars():
in_c_key=False,
)

config.add(
"tensor__local_elemwise_fusion",
(
"Enable or not in fast_run mode(fast_run optimization) the elemwise "
"fusion optimization"
),
BoolParam(True),
in_c_key=False,
)

# http://developer.amd.com/CPU/LIBRARIES/LIBM/Pages/default.aspx
config.add(
"lib__amblibm",
Expand Down
66 changes: 36 additions & 30 deletions pytensor/tensor/rewriting/elemwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -1085,38 +1085,10 @@ def print_profile(stream, prof, level=0):
print(blanc, " time_toposort", prof[7], file=stream)


if config.tensor__local_elemwise_fusion:
# Must be after gpu(48.5) and before AddDestroyHandler(49.5)
fuse_seqopt = SequenceDB()
fuse_seqopt.register(
"local_add_mul_fusion",
EquilibriumGraphRewriter(rewriters=[local_add_mul_fusion], max_use_ratio=1000),
"fast_run",
"fusion",
position=0,
)
fuse_seqopt.register(
"composite_elemwise_fusion",
FusionOptimizer(),
"fast_run",
"fusion",
position=1,
)
compile.optdb.register(
"elemwise_fusion",
fuse_seqopt,
"fast_run",
"fusion",
"local_elemwise_fusion",
"FusionOptimizer",
position=49,
)


@register_canonicalize
@register_specialize
@node_rewriter([Elemwise])
def local_useless_composite(fgraph, node):
def local_useless_composite_outputs(fgraph, node):
"""Remove inputs and outputs of Composite Ops that are not used anywhere."""
if not isinstance(node.op, Elemwise) or not isinstance(
node.op.scalar_op, aes.Composite
Expand Down Expand Up @@ -1231,11 +1203,45 @@ def local_careduce_fusion(fgraph, node):
return [new_car_op(*elm_inputs)]


# Register the fusion database just before AddDestroyHandler(49.5) (inplace
# rewrites). The database is always registered now; per the commit message,
# fusion can be disabled through `optimizer_excluding` instead of a config flag.
fuse_seqopt = SequenceDB()
compile.optdb.register(
    "elemwise_fusion",
    fuse_seqopt,
    "fast_run",
    "fusion",
    "local_elemwise_fusion",
    "FusionOptimizer",
    position=49,
)

# The `position` values below are relative to `fuse_seqopt`, not to
# `compile.optdb`.
fuse_seqopt.register(
    "local_add_mul_fusion",
    EquilibriumGraphRewriter(rewriters=[local_add_mul_fusion], max_use_ratio=1000),
    "fast_run",
    "fusion",
    position=0,
)
fuse_seqopt.register(
    "composite_elemwise_fusion",
    FusionOptimizer(),
    "fast_run",
    "fusion",
    position=1,
)
fuse_seqopt.register(
    "local_useless_composite_outputs",
    in2out(local_useless_composite_outputs),
    "fast_run",
    "fusion",
    position=2,
)
# CAReduce fusion lives in this sequence, after the Elemwise fusion rewrites:
# registered on its own it ran before them and so usually was not applied.
# Only `position=10` is kept; the stale `position=49` duplicated the keyword.
fuse_seqopt.register(
    "local_careduce_fusion",
    in2out(local_careduce_fusion),
    "fast_run",
    "fusion",
    position=10,
)


Expand Down
65 changes: 33 additions & 32 deletions tests/tensor/rewriting/test_elemwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -1425,39 +1425,40 @@ def test_nested_composite(self):
fval = f([1, 2, 3])
assert np.all(fval == [6, 12, 18])

def test_local_useless_composite(self):
x = aes.float32()
y = aes.float32()
z = aes.float32()
c = aes.Composite([x, y, z], [x + 1, y - 1])
X = matrix("X")
Y = matrix("Y")
Z = matrix("Z")
o1, o2 = Elemwise(scalar_op=c)(X, Y, Z)
mode = get_default_mode().including("local_useless_composite")

f = function([X, Y, Z], [o1, o2], mode=mode)
topo = f.maker.fgraph.toposort()
assert len(topo) == 1
assert len(topo[0].inputs) == 2
assert len(topo[0].outputs) == 2
res1, res2 = f([[1.0]], [[1.0]], [[np.nan]])
utt.assert_allclose(res1, [[2.0]])
utt.assert_allclose(res2, [[0.0]])

f = function([X, Y, Z], o1, mode=mode)
topo = f.maker.fgraph.toposort()
assert len(topo) == 1
assert len(topo[0].inputs) == 1
assert len(topo[0].outputs) == 1
utt.assert_allclose(f([[1.0]], [[np.nan]], [[np.nan]]), [[2.0]])

f = function([X, Y, Z], o2, mode=mode)
topo = f.maker.fgraph.toposort()
assert len(topo) == 1
assert len(topo[0].inputs) == 1
assert len(topo[0].outputs) == 1
utt.assert_allclose(f([[np.nan]], [[1.0]], [[np.nan]]), [[0.0]])
def test_local_useless_composite_outputs():
    """Check that unused inputs and outputs of a Composite Elemwise are pruned.

    The Composite takes three scalar inputs but only ``x`` and ``y`` feed its
    two outputs, so ``z`` is never needed. Each compiled function is expected
    to contain a single node that keeps only the inputs/outputs actually
    required by the requested results.
    """
    x = aes.float32()
    y = aes.float32()
    z = aes.float32()
    # `z` is deliberately unused by both outputs.
    c = aes.Composite([x, y, z], [x + 1, y - 1])
    X = matrix("X")
    Y = matrix("Y")
    Z = matrix("Z")
    o1, o2 = Elemwise(scalar_op=c)(X, Y, Z)
    # Use the rewrite's current name: it was renamed from
    # `local_useless_composite`, so the old tag no longer refers to it.
    mode = get_default_mode().including("local_useless_composite_outputs")

    # Both outputs requested: only the unused input `Z` is dropped.
    f = function([X, Y, Z], [o1, o2], mode=mode)
    topo = f.maker.fgraph.toposort()
    assert len(topo) == 1
    assert len(topo[0].inputs) == 2
    assert len(topo[0].outputs) == 2
    # NaN is fed to the pruned input; it must not leak into the results.
    res1, res2 = f([[1.0]], [[1.0]], [[np.nan]])
    utt.assert_allclose(res1, [[2.0]])
    utt.assert_allclose(res2, [[0.0]])

    # Only `o1` requested: the node keeps a single input and a single output.
    f = function([X, Y, Z], o1, mode=mode)
    topo = f.maker.fgraph.toposort()
    assert len(topo) == 1
    assert len(topo[0].inputs) == 1
    assert len(topo[0].outputs) == 1
    utt.assert_allclose(f([[1.0]], [[np.nan]], [[np.nan]]), [[2.0]])

    # Only `o2` requested: likewise a single input and a single output.
    f = function([X, Y, Z], o2, mode=mode)
    topo = f.maker.fgraph.toposort()
    assert len(topo) == 1
    assert len(topo[0].inputs) == 1
    assert len(topo[0].outputs) == 1
    utt.assert_allclose(f([[np.nan]], [[1.0]], [[np.nan]]), [[0.0]])


def test_local_useless_dimshuffle_makevector():
Expand Down

0 comments on commit e20dd0b

Please sign in to comment.