Skip to content

Commit

Permalink
Flatten expressions at end, not during overloads
Browse files Browse the repository at this point in the history
  • Loading branch information
odow committed Jun 19, 2023
1 parent dc1ef75 commit 472daa3
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 88 deletions.
8 changes: 7 additions & 1 deletion src/macros.jl
Original file line number Diff line number Diff line change
Expand Up @@ -523,7 +523,13 @@ expressions.
"""
function _rewrite_expression(expr)
new_expr = MacroTools.postwalk(_rewrite_to_jump_logic, expr)
return _MA.rewrite(new_expr; move_factors_into_sums = false)
new_aff, parse_aff = _MA.rewrite(new_expr; move_factors_into_sums = false)
ret = gensym()
code = quote
$parse_aff
$ret = $flatten($new_aff)
end
return ret, code
end

function parse_constraint_head(
Expand Down
101 changes: 55 additions & 46 deletions src/nlp_expr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -282,57 +282,66 @@ for f in (:+, :-, :*, :^, :/, :atan)
end
end

# Thse n-ary operators are associative. Instead of creating deeply nested
# binary trees, flatten arguments where possible.
for f in (:+, :*)
op = Meta.quot(f)
@eval begin
function Base.$f(x::NonlinearExpr{V}, y::_Constant) where {V}
y2 = convert(Float64, _constant_to_number(y))
if x.head == $op
return NonlinearExpr{V}($op, vcat(x.args, y2))
end
return NonlinearExpr{V}($op, x, y2)
end
function Base.$f(x::_Constant, y::NonlinearExpr{V}) where {V}
x2 = convert(Float64, _constant_to_number(x))
if y.head == $op
return NonlinearExpr{V}($op, vcat(x2, y.args))
end
return NonlinearExpr{V}($op, x2, y)
end
function Base.$f(x::NonlinearExpr{V}, y::AbstractJuMPScalar) where {V}
if x.head == $op
return NonlinearExpr{V}($op, vcat(x.args, y))
end
return NonlinearExpr{V}($op, x, y)
end
function Base.$f(x::AbstractJuMPScalar, y::NonlinearExpr{V}) where {V}
if y.head == $op
return NonlinearExpr{V}($op, vcat(x, y.args))
end
return NonlinearExpr{V}($op, x, y)
end
function Base.$f(x::NonlinearExpr{V}, y::NonlinearExpr{V}) where {V}
if x.head == $op && y.head == $op
return NonlinearExpr{V}($op, vcat(x.args, y.args))
elseif x.head == $op
return NonlinearExpr{V}($op, vcat(x.args, y))
elseif y.head == $op
return NonlinearExpr{V}($op, vcat(x, y.args))
function _MA.operate!!(
::typeof(_MA.add_mul),
x::NonlinearExpr,
y::AbstractJuMPScalar,
)
if x.head == :+
push!(x.args, y)
return x
end
return +(x, y)
end

"""
flatten(expr::NonlinearExpr)
Flatten a nonlinear expression by lifting nested `+` and `*` nodes into a single
n-ary operation.
## Motivation
Nonlinear expressions created using operator overloading can be deeply nested
and unbalanced. For example, `prod(x for i in 1:4)` creates
`*(x, *(x, *(x, x)))` instead of the more preferable `*(x, x, x, x)`.
## Example
```jldoctest
julia> model = Model();
julia> @variable(model, x)
x
julia> y = prod(x for i in 1:4)
((x² * x) * x)
julia> flatten(y)
(x² * x * x)
```
"""
function flatten(expr::NonlinearExpr{V}) where {V}
if !(expr.head in (:+, :*))
return expr
end
args = Any[]
nodes_to_visit = Any[arg for arg in reverse(expr.args)]
while !isempty(nodes_to_visit)
arg = pop!(nodes_to_visit)
if arg isa NonlinearExpr && arg.head == expr.head
for n in reverse(arg.args)
push!(nodes_to_visit, n)
end
return NonlinearExpr{V}($op, x, y)
end
function Base.$f(x::NonlinearExpr{U}, y::NonlinearExpr{V}) where {U,V}
return error(
"Unable to call ",
$op,
" with nonlinear expressions of different variable type",
)
else
push!(args, flatten(arg))
end
end
return NonlinearExpr{V}(expr.head, args)
end

flatten(expr) = expr

function _ifelse(a::AbstractJuMPScalar, x, y)
return NonlinearExpr{variable_ref_type(a)}(:ifelse, Any[a, x, y])
end
Expand Down
62 changes: 21 additions & 41 deletions test/test_nlp_expr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -93,47 +93,27 @@ function test_extension_flatten_nary(
expr_plus = NonlinearExpr{VariableRefType}(:+, Any[x])
expr_mult = NonlinearExpr{VariableRefType}(:*, Any[x])
expr_sin = NonlinearExpr{VariableRefType}(:sin, Any[x])
@test string(+(expr_plus, 1)) == "(x + 1.0)"
@test string(+(1, expr_plus)) == "(1.0 + x)"
@test string(+(expr_plus, x)) == "(x + x)"
@test string(+(expr_sin, x)) == "(sin(x) + x)"
@test string(+(x, expr_plus)) == "(x + x)"
@test string(+(x, expr_sin)) == "(x + sin(x))"
@test string(+(expr_plus, expr_plus)) == "(x + x)"
@test string(+(expr_plus, expr_sin)) == "(x + sin(x))"
@test string(+(expr_sin, expr_plus)) == "(sin(x) + x)"
@test string(+(expr_sin, expr_sin)) == "(sin(x) + sin(x))"
@test string(*(expr_mult, 2)) == "(x * 2.0)"
@test string(*(2, expr_mult)) == "(2.0 * x)"
@test string(*(expr_mult, x)) == "(x * x)"
@test string(*(expr_sin, x)) == "(sin(x) * x)"
@test string(*(x, expr_mult)) == "(x * x)"
@test string(*(x, expr_sin)) == "(x * sin(x))"
@test string(*(expr_mult, expr_mult)) == "(x * x)"
@test string(*(expr_mult, expr_sin)) == "(x * sin(x))"
@test string(*(expr_sin, expr_mult)) == "(sin(x) * x)"
@test string(*(expr_sin, expr_sin)) == "(sin(x) * sin(x))"
return
end

function test_extension_error_associative(
ModelType = Model,
VariableRefType = VariableRef,
)
if VariableRefType == VariableRef
return
end
model1 = ModelType()
@variable(model1, x1)
model2 = Model()
@variable(model2, x2)
@test_throws(
ErrorException(
"Unable to call + with nonlinear expressions of different " *
"variable type",
),
+(sin(x1), sin(x2)),
)
to_string(x) = string(flatten(x))
@test to_string(+(expr_plus, 1)) == "(x + 1.0)"
@test to_string(+(1, expr_plus)) == "(1.0 + x)"
@test to_string(+(expr_plus, x)) == "(x + x)"
@test to_string(+(expr_sin, x)) == "(sin(x) + x)"
@test to_string(+(x, expr_plus)) == "(x + x)"
@test to_string(+(x, expr_sin)) == "(x + sin(x))"
@test to_string(+(expr_plus, expr_plus)) == "(x + x)"
@test to_string(+(expr_plus, expr_sin)) == "(x + sin(x))"
@test to_string(+(expr_sin, expr_plus)) == "(sin(x) + x)"
@test to_string(+(expr_sin, expr_sin)) == "(sin(x) + sin(x))"
@test to_string(*(expr_mult, 2)) == "(x * 2.0)"
@test to_string(*(2, expr_mult)) == "(2.0 * x)"
@test to_string(*(expr_mult, x)) == "(x * x)"
@test to_string(*(expr_sin, x)) == "(sin(x) * x)"
@test to_string(*(x, expr_mult)) == "(x * x)"
@test to_string(*(x, expr_sin)) == "(x * sin(x))"
@test to_string(*(expr_mult, expr_mult)) == "(x * x)"
@test to_string(*(expr_mult, expr_sin)) == "(x * sin(x))"
@test to_string(*(expr_sin, expr_mult)) == "(sin(x) * x)"
@test to_string(*(expr_sin, expr_sin)) == "(sin(x) * sin(x))"
return
end

Expand Down

0 comments on commit 472daa3

Please sign in to comment.