diff --git a/src/nlp_expr.jl b/src/nlp_expr.jl index 6e99783c501..dda46d90350 100644 --- a/src/nlp_expr.jl +++ b/src/nlp_expr.jl @@ -673,52 +673,59 @@ function _evaluate_expr( return convert(Float64, expr) end -function _evaluate_user_defined_function( - registry, - f, - expr::GenericNonlinearExpr, -) - model = owner_model(expr) - op, nargs = expr.head, length(expr.args) - udf = MOI.get(model, MOI.UserDefinedFunction(op, nargs)) - if udf === nothing - return error( - "Unable to evaluate nonlinear operator $op because it was " * - "not added as an operator.", - ) - end - args = [_evaluate_expr(registry, f, arg) for arg in expr.args] - return first(udf)(args...) -end - function _evaluate_expr( registry::MOI.Nonlinear.OperatorRegistry, f::Function, expr::GenericNonlinearExpr, ) - op = expr.head - # TODO(odow): uses private function - if !MOI.Nonlinear._is_registered(registry, op, length(expr.args)) - return _evaluate_user_defined_function(registry, f, expr) - end - if length(expr.args) == 1 && haskey(registry.univariate_operator_to_id, op) - arg = _evaluate_expr(registry, f, expr.args[1]) - return MOI.Nonlinear.eval_univariate_function(registry, op, arg) - elseif haskey(registry.multivariate_operator_to_id, op) - args = [_evaluate_expr(registry, f, arg) for arg in expr.args] - return MOI.Nonlinear.eval_multivariate_function(registry, op, args) - elseif haskey(registry.logic_operator_to_id, op) - @assert length(expr.args) == 2 - x = _evaluate_expr(registry, f, expr.args[1]) - y = _evaluate_expr(registry, f, expr.args[2]) - return MOI.Nonlinear.eval_logic_function(registry, op, x, y) - else - @assert haskey(registry.comparison_operator_to_id, op) - @assert length(expr.args) == 2 - x = _evaluate_expr(registry, f, expr.args[1]) - y = _evaluate_expr(registry, f, expr.args[2]) - return MOI.Nonlinear.eval_comparison_function(registry, op, x, y) + # The result_stack needs to be ::Real because operators like || return a + # ::Bool. Also, some inputs may be ::Int. + stack, result_stack = Any[expr], Real[] + while !isempty(stack) + arg = pop!(stack) + if arg isa GenericNonlinearExpr + push!(stack, (arg,)) # wrap in (,) to catch when we should eval it. + for child in arg.args + push!(stack, child) + end + elseif arg isa Tuple{<:GenericNonlinearExpr} + f_expr = only(arg) + op, nargs = f_expr.head, length(f_expr.args) + # TODO(odow): uses private function + result = if !MOI.Nonlinear._is_registered(registry, op, nargs) + model = owner_model(f_expr) + udf = MOI.get(model, MOI.UserDefinedFunction(op, nargs)) + if udf === nothing + return error( + "Unable to evaluate nonlinear operator $op because " * + "it was not added as an operator.", + ) + end + first(udf)((pop!(result_stack) for _ in 1:nargs)...) + elseif nargs == 1 && haskey(registry.univariate_operator_to_id, op) + x = pop!(result_stack) + MOI.Nonlinear.eval_univariate_function(registry, op, x) + elseif haskey(registry.multivariate_operator_to_id, op) + args = Real[pop!(result_stack) for _ in 1:nargs] + MOI.Nonlinear.eval_multivariate_function(registry, op, args) + elseif haskey(registry.logic_operator_to_id, op) + @assert nargs == 2 + x = pop!(result_stack) + y = pop!(result_stack) + MOI.Nonlinear.eval_logic_function(registry, op, x, y) + else + @assert haskey(registry.comparison_operator_to_id, op) + @assert nargs == 2 + x = pop!(result_stack) + y = pop!(result_stack) + MOI.Nonlinear.eval_comparison_function(registry, op, x, y) + end + push!(result_stack, result) + else + push!(result_stack, _evaluate_expr(registry, f, arg)) + end end + return only(result_stack) end # MutableArithmetics.jl and promotion diff --git a/test/test_nlp_expr.jl b/test/test_nlp_expr.jl index 8bb131dfa4f..bfcbfbf8829 100644 --- a/test/test_nlp_expr.jl +++ b/test/test_nlp_expr.jl @@ -371,6 +371,33 @@ function test_extension_recursion_stackoverflow( return end +function test_evaluate_expr_stackoverflow() + N = 10_000 + model = Model() + @variable(model, x[1:N], start = 0) + f = prod(sum(x[1:i]) for i in 1:N) + @test value(start_value, f) == 0.0 + set_start_value.(x, 1.0) + @test value(start_value, f) == Inf + return +end + +function test_evaluate_expr_stackoverflow_user_defined_function() + N = 10_000 + f(x, y) = *(x, y) + model = Model() + @variable(model, x[1:N], start = 0) + @operator(model, op_f, 2, f) + y = x[1] + for i in 2:N + y = op_f(x[i], y) + end + @test value(start_value, y) == 0.0 + set_start_value.(x, 1.0) + @test value(start_value, y) == 1.0 + return +end + function test_nlobjective_with_nlexpr() model = Model() @variable(model, x)