Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for simple deduplication of terms (for improved performance) when building functions #698

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 26 additions & 1 deletion src/build_function.jl
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,8 @@ Special Keyword Argumnets:
filling function is 0.
- `fillzeros`: Whether to perform `fill(out,0)` before the calculations to ensure
safety with `skipzeros`.
- `deduplicate_terms`: List of Terms to be computed in separate variables and
substituted into each equation (de-duplicating the computation of those terms).
"""
function _build_function(target::JuliaTarget, rhss::AbstractArray, args...;
conv = toexpr, expression = Val{true},
Expand All @@ -224,7 +226,7 @@ function _build_function(target::JuliaTarget, rhss::AbstractArray, args...;
convert_oop = true, force_SA = false,
skipzeros = outputidxs===nothing,
fillzeros = skipzeros && !(typeof(rhss)<:SparseMatrixCSC),
parallel=SerialForm(), kwargs...)
parallel=SerialForm(), deduplicate_terms=nothing, kwargs...)
if multithread isa Bool
@warn("multithraded is deprecated for the parallel argument. See the documentation.")
parallel = multithread ? MultithreadedForm() : SerialForm()
Expand All @@ -233,6 +235,16 @@ function _build_function(target::JuliaTarget, rhss::AbstractArray, args...;
argnames = [gensym(:MTKArg) for i in 1:length(args)]
symsdict = Dict()
arg_pairs = map((x,y)->vars_to_pairs(x,y, symsdict), argnames, args)

# Terms that should be deduplicated get converted into new variables and substituted into the RHS equations
dedupe_exprs = Expr[]
if !isnothing(deduplicate_terms)
!isa(parallel, SerialForm) && error("`deduplicate_terms` is not yet supported with `parallel`!")
# NOTE: Also updates symsdict so that variables get substituted using process() below.
dedupe_pairs = map((x,y)->vars_to_pairs(x, y, symsdict), conv.(deduplicate_terms), deduplicate_terms)
dedupe_exprs = [Expr(:(=), ModelingToolkit.build_expr(:tuple, varname), ModelingToolkit.build_expr(:tuple, term)) for (varname, term) in dedupe_pairs]
end

process = unflatten_long_ops∘(x->substitute(x, symsdict, fold=false))

ls = reduce(vcat,conv.(first.(arg_pairs)))
Expand Down Expand Up @@ -326,6 +338,14 @@ function _build_function(target::JuliaTarget, rhss::AbstractArray, args...;

ip_let_expr = Expr(:let, var_eqs, build_expr(:block, ip_sys_exprs))

# Insert the deduplication variables inside the Let block, but before the Expressions.
# Of course, it would be simple to simply add the dedupe variables to the Let expression (i.e. `var_eqs`),
# but then we cannot use any of the variables as they were named by the user. So instead, we inject
# the dedupe variables into the Let block.
# NOTE: This won't play nicely with the parallel code below, but we already threw in error if they try!
# TODO: Add support for parallel functions (e.g. copy or compute the dedup variables in each thread?)
prepend!(ip_let_expr.args[2].args, collect(dedupe_exprs))

if parallel isa MultithreadedForm
lens = Int(ceil(length(ip_let_expr.args[2].args)/Threads.nthreads()))
threaded_exprs = vcat([quote
Expand Down Expand Up @@ -442,7 +462,12 @@ function _build_function(target::JuliaTarget, rhss::AbstractArray, args...;

let_expr = Expr(:let, var_eqs, tuple_sys_expr)
arr_let_expr = Expr(:let, var_eqs, arr_sys_expr)

# As done above with ip_let_expr, we inject the dedupe variables into the Let expression here too.
prepend!(arr_let_expr.args[2].args, collect(dedupe_exprs))

bounds_block = checkbounds ? let_expr : :(@inbounds begin $let_expr end)

oop_bounds_block = checkbounds ? arr_let_expr : :(@inbounds begin $arr_let_expr end)
ip_bounds_block = checkbounds ? ip_let_expr : :(@inbounds begin $ip_let_expr end)

Expand Down
2 changes: 1 addition & 1 deletion src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ Symbol("z⦗t⦘")
```
"""
function tosymbol(t::Term; states=nothing, escape=true)
if t.op isa Sym
if t.op isa Union{Sym, Function}
if states !== nothing && !(any(isequal(t), states))
return nameof(t.op)
end
Expand Down
46 changes: 46 additions & 0 deletions test/deduplication_terms_test.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
using Revise # DELETE ME
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

.

ChrisRackauckas marked this conversation as resolved.
Show resolved Hide resolved

using Test
using ModelingToolkit
using LinearAlgebra.BLAS
ChrisRackauckas marked this conversation as resolved.
Show resolved Hide resolved

@parameters t
@variables vars[1:10]
@derivatives D'~t

const sleep_time = 0.1
function untrace_dupe_f(a::Float64, t::Float64)
# This is a function that is both expensive and untraceable (symbolically),
# and whose result is used several times in our ODE system of equations.
sleep(sleep_time) # Expensive function!
BLAS.dot(10, fill(a, 10), 1, fill(t, 20), 2) # Untraceable result!
end
@test untrace_dupe_f(5., 12.) == 600.0

# Confirm that our function cannot be traced symbolically...
@test_throws MethodError untrace_dupe_f(vars[1], t)

# Register the function so we can now use it without having to trace it.
@register untrace_dupe_f(a, t)

# Build an ODE system of equations that uses several duplicated, untraceable terms.
dupe_terms = (untrace_dupe_f(x, t) for x in vars)
eqs = (D.(vars) .~ prod(dupe_terms))
@test length(eqs) == length(vars)

# Build the ODE functions.
ode_system = ODESystem(eqs, t, vars, [])
ode_function_naive = ODEFunction(ode_system; jac=false, tgrad=false, eval_expression=false)
ode_function_deduplicated = ODEFunction(ode_system; jac=false, tgrad=false, eval_expression=false, deduplicate_terms=dupe_terms)

# Run both functions and compare...
u0 = rand(Float64, length(x))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

x is not defined anywhere. I suggest you to use @safetestset to run the tests locally, so that you won't pollute your test environment.

t0 = 10.0
result_naive = ode_function_naive(u0, [], t0)
result_deduplicated = ode_function_deduplicated(u0, [], t0)
@test result_naive == result_deduplicated

# The deduplicated version should run much faster, since it does not need to recompute the terms multiple times.
fuzzy_factor = 0.9
expected_speedup = fuzzy_factor * length(eqs)
@test @elapsed(ode_function_naive(u0, [], t0)) > (expected_speedup * @elapsed(ode_function_deduplicated(u0, [], t0)))
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,4 @@ using SafeTestsets, Test
println("Last test requires gcc available in the path!")
@safetestset "C Compilation Test" begin include("ccompile.jl") end
@safetestset "Latexify recipes Test" begin include("latexify.jl") end
@safetestset "Deduplication of Terms Test" begin include("deduplication_terms_test.jl") end