Skip to content

Commit

Permalink
Add subexpression kwarg to expression macro
Browse files Browse the repository at this point in the history
  • Loading branch information
odow committed Nov 12, 2024
1 parent 5e4273d commit 5c100eb
Show file tree
Hide file tree
Showing 2 changed files with 131 additions and 2 deletions.
45 changes: 43 additions & 2 deletions src/macros/@expression.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,24 +69,30 @@ macro expression(input_args...)
error_fn,
input_args;
num_positional_args = 2:3,
valid_kwargs = [:container],
valid_kwargs = [:container, :subexpression],
)
if Meta.isexpr(args[2], :block)
error_fn("Invalid syntax. Did you mean to use `@expressions`?")
end
is_subexpression = get(kwargs, :subexpression, false)
name_expr = length(args) == 3 ? args[2] : nothing
name, index_vars, indices = Containers.parse_ref_sets(
error_fn,
name_expr;
invalid_index_variables = [args[1]],
)
name_expr = Containers.build_name_expr(name, index_vars, kwargs)
model = esc(args[1])
expr, build_code = _rewrite_expression(args[end])
code = quote
$build_code
# Don't leak a `_MA.Zero` if the expression is an empty summation, or
# other structure that returns `_MA.Zero()`.
_replace_zero($model, $expr)
if $is_subexpression
_build_subexpression($error_fn, $model, $expr, $name_expr)
else
_replace_zero($model, $expr)
end
end
return _finalize_macro(
model,
Expand All @@ -97,6 +103,41 @@ macro expression(input_args...)
)
end

function _build_subexpression(
::Function,
model::AbstractModel,
expr::AbstractJuMPScalar,
name::String,
)
y = @variable(model)
set_name(y, name)
@constraint(model, y == expr)
return y
end

function _build_subexpression(
::Function,
model::AbstractModel,
expr::Array{<:AbstractJuMPScalar},
name::String,
)
y = [@variable(model) for _ in expr]
set_name.(y, name)
@constraint(model, y .== expr)
return y
end

function _build_subexpression(
error_fn::Function,
::AbstractModel,
expr::Any,
::String,
)
return error_fn(
"Unable to build a subexpression for the type $(typeof(expr))",
)
end

"""
@expressions(model, args...)
Expand Down
88 changes: 88 additions & 0 deletions test/test_macros.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2487,4 +2487,92 @@ function test_array_scalar_sets()
return
end

function test_subexpression_kwarg()
model = Model()
@variable(model, x)
@expression(model, ex, sin(x), subexpression = true)
@test ex isa VariableRef
@test model[:ex] isa VariableRef
@test model[:ex] === ex
@test occursin("ex - sin(x) = 0", sprint(print, model))
@test num_variables(model) == 2
return
end

function test_subexpression_kwarg_array()
model = Model()
@variable(model, x[1:2])
@expression(model, ex[i in 1:2], sin(x[i]), subexpression = true)
@test ex isa Vector{VariableRef}
@test model[:ex] === ex
@test occursin("ex[1] - sin(x[1]) = 0", sprint(print, model))
@test occursin("ex[2] - sin(x[2]) = 0", sprint(print, model))
@test num_variables(model) == 4
return
end

function test_subexpression_kwarg_dense_axis_array()
model = Model()
@variable(model, x[2:3])
@expression(model, ex[i in 2:3], sin(x[i]), subexpression = true)
@test ex isa Containers.DenseAxisArray{VariableRef}
@test model[:ex] === ex
@test occursin("ex[2] - sin(x[2]) = 0", sprint(print, model))
@test occursin("ex[3] - sin(x[3]) = 0", sprint(print, model))
@test num_variables(model) == 4
return
end

function test_subexpression_kwarg_dense_axis_array()
model = Model()
@variable(model, x[i in 1:3; isodd(i)])
@expression(model, ex[i in 1:3; isodd(i)], sin(x[i]), subexpression = true)
@test ex isa Containers.SparseAxisArray{VariableRef}
@test model[:ex] === ex
@test occursin("ex[1] - sin(x[1]) = 0", sprint(print, model))
@test occursin("ex[3] - sin(x[3]) = 0", sprint(print, model))
@test num_variables(model) == 4
return
end

function test_subexpression_kwarg_vector_element()
model = Model()
@variable(model, x[i in 1:2])
@expression(model, ex, sin.(x), subexpression = true)
@test ex isa Vector{VariableRef}
@test model[:ex] === ex
@test occursin("ex - sin(x[1]) = 0", sprint(print, model))
@test occursin("ex - sin(x[2]) = 0", sprint(print, model))
@test num_variables(model) == 4
return
end

function test_subexpression_kwarg_no_name()
model = Model()
@variable(model, x)
ex = @expression(model, sin(x), subexpression = true)
@test ex isa VariableRef
@test !haskey(model, :ex)
@test occursin("_[2] - sin(x) = 0", sprint(print, model))
@test num_variables(model) == 2
return
end

function test_subexpression_kwarg_dict_element()
model = Model()
@variable(model, x[i in 1:2])
@test_throws_runtime(
ErrorException(
"In `@expression(model, ex, Dict((i => x[i] for i = 1:2)), subexpression = true)`: Unable to build a subexpression for the type Dict{Int64, VariableRef}",
),
@expression(
model,
ex,
Dict(i => x[i] for i in 1:2),
subexpression = true,
),
)
return
end

end # module

0 comments on commit 5c100eb

Please sign in to comment.